diff options
author | Vratko Polak <vrpolak@cisco.com> | 2024-01-12 14:50:52 +0100 |
---|---|---|
committer | Vratko Polak <vrpolak@cisco.com> | 2024-01-24 12:28:58 +0100 |
commit | fcd0677317970062b37e196b4d1a15a135f51cca (patch) | |
tree | f9355ec08094de2e863c2fceea465f7f3c6b54cc /resources/libraries/python/PLRsearch/Integrator.py | |
parent | 852f60f525fdc6080387fe6a3b297736c83f0834 (diff) |
style(PLRsearch): format according to black
Change-Id: I26e0ce172740f7f440469578294f8c13fec5850b
Signed-off-by: Vratko Polak <vrpolak@cisco.com>
Diffstat (limited to 'resources/libraries/python/PLRsearch/Integrator.py')
-rw-r--r-- | resources/libraries/python/PLRsearch/Integrator.py | 53 |
1 files changed, 32 insertions, 21 deletions
diff --git a/resources/libraries/python/PLRsearch/Integrator.py b/resources/libraries/python/PLRsearch/Integrator.py index 7f118db00d..f80110ce29 100644 --- a/resources/libraries/python/PLRsearch/Integrator.py +++ b/resources/libraries/python/PLRsearch/Integrator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 Cisco and/or its affiliates. +# Copyright (c) 2024 Cisco and/or its affiliates. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at: @@ -192,9 +192,12 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): debug_list = list() trace_list = list() # Block until input object appears. - dimension, dilled_function, param_focus_tracker, max_samples = ( - communication_pipe.recv() - ) + ( + dimension, + dilled_function, + param_focus_tracker, + max_samples, + ) = communication_pipe.recv() debug_list.append( f"Called with param_focus_tracker {param_focus_tracker!r}" ) @@ -237,16 +240,18 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): if max_samples and samples >= max_samples: break sample_point = generate_sample( - param_focus_tracker.averages, param_focus_tracker.covariance_matrix, - dimension, scale_coeff + param_focus_tracker.averages, + param_focus_tracker.covariance_matrix, + dimension, + scale_coeff, ) - trace(u"sample_point", sample_point) + trace("sample_point", sample_point) samples += 1 - trace(u"samples", samples) + trace("samples", samples) value, log_weight = value_logweight_function(trace, *sample_point) - trace(u"value", value) - trace(u"log_weight", log_weight) - trace(u"focus tracker before adding", param_focus_tracker) + trace("value", value) + trace("log_weight", log_weight) + trace("focus tracker before adding", param_focus_tracker) # Update focus related statistics. param_distance = param_focus_tracker.add_without_dominance_get_distance( sample_point, log_weight @@ -254,22 +259,28 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): # The code above looked at weight (not importance). # The code below looks at importance (not weight). log_rarity = param_distance / 2.0 / scale_coeff - trace(u"log_rarity", log_rarity) + trace("log_rarity", log_rarity) log_importance = log_weight + log_rarity - trace(u"log_importance", log_importance) + trace("log_importance", log_importance) value_tracker.add(value, log_importance) # Update sampled statistics. param_sampled_tracker.add_get_shift(sample_point, log_importance) debug_list.append(f"integrator used {samples!s} samples") debug_list.append( - u" ".join([ - u"value_avg", str(value_tracker.average), - u"param_sampled_avg", repr(param_sampled_tracker.averages), - u"param_sampled_cov", repr(param_sampled_tracker.covariance_matrix), - u"value_log_variance", str(value_tracker.log_variance), - u"value_log_secondary_variance", - str(value_tracker.secondary.log_variance) - ]) + " ".join( + [ + "value_avg", + str(value_tracker.average), + "param_sampled_avg", + repr(param_sampled_tracker.averages), + "param_sampled_cov", + repr(param_sampled_tracker.covariance_matrix), + "value_log_variance", + str(value_tracker.log_variance), + "value_log_secondary_variance", + str(value_tracker.secondary.log_variance), + ] + ) ) communication_pipe.send( (value_tracker, param_focus_tracker, debug_list, trace_list, samples) |