diff options
Diffstat (limited to 'resources/libraries/python/PLRsearch/Integrator.py')
-rw-r--r-- | resources/libraries/python/PLRsearch/Integrator.py | 59 |
1 files changed, 34 insertions, 25 deletions
diff --git a/resources/libraries/python/PLRsearch/Integrator.py b/resources/libraries/python/PLRsearch/Integrator.py index 86181eaa56..331bd8475b 100644 --- a/resources/libraries/python/PLRsearch/Integrator.py +++ b/resources/libraries/python/PLRsearch/Integrator.py @@ -23,6 +23,7 @@ import copy import traceback import dill + from numpy import random # TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH, @@ -58,7 +59,7 @@ def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): # so we have to catch them all. traceback_string = traceback.format_exc() communication_pipe.send(traceback_string) - # After sendig, re-raise, so usages other than "one process per call" + # After sending, re-raise, so usages other than "one process per call" # keep behaving correctly. raise @@ -86,7 +87,8 @@ def generate_sample(averages, covariance_matrix, dimension, scale_coeff): covariance_matrix[first][second] *= scale_coeff while 1: sample_point = random.multivariate_normal( - averages, covariance_matrix, 1)[0].tolist() + averages, covariance_matrix, 1 + )[0].tolist() # Multivariate Gauss can fall outside (-1, 1) interval for first in range(dimension): sample_coordinate = sample_point[first] @@ -187,14 +189,15 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): :raises numpy.linalg.LinAlgError: If the focus shape gets singular (due to rounding errors). Try changing scale_coeff. """ - debug_list = list() trace_list = list() # Block until input object appears. dimension, dilled_function, param_focus_tracker, max_samples = ( - communication_pipe.recv()) - debug_list.append("Called with param_focus_tracker {tracker!r}" - .format(tracker=param_focus_tracker)) + communication_pipe.recv() + ) + debug_list.append( + f"Called with param_focus_tracker {param_focus_tracker!r}" + ) def trace(name, value): """ @@ -210,7 +213,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): :type value: object """ if trace_enabled: - trace_list.append(name + " " + repr(value)) + trace_list.append(f"{name} {value!r}") value_logweight_function = dill.loads(dilled_function) samples = 0 @@ -235,33 +238,39 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): break sample_point = generate_sample( param_focus_tracker.averages, param_focus_tracker.covariance_matrix, - dimension, scale_coeff) - trace("sample_point", sample_point) + dimension, scale_coeff + ) + trace(u"sample_point", sample_point) samples += 1 - trace("samples", samples) + trace(u"samples", samples) value, log_weight = value_logweight_function(trace, *sample_point) - trace("value", value) - trace("log_weight", log_weight) - trace("focus tracker before adding", param_focus_tracker) + trace(u"value", value) + trace(u"log_weight", log_weight) + trace(u"focus tracker before adding", param_focus_tracker) # Update focus related statistics. param_distance = param_focus_tracker.add_without_dominance_get_distance( - sample_point, log_weight) + sample_point, log_weight + ) # The code above looked at weight (not importance). # The code below looks at importance (not weight). log_rarity = param_distance / 2.0 - trace("log_rarity", log_rarity) + trace(u"log_rarity", log_rarity) log_importance = log_weight + log_rarity - trace("log_importance", log_importance) + trace(u"log_importance", log_importance) value_tracker.add(value, log_importance) # Update sampled statistics. param_sampled_tracker.add_get_shift(sample_point, log_importance) - debug_list.append("integrator used " + str(samples) + " samples") - debug_list.append(" ".join([ - "value_avg", str(value_tracker.average), - "param_sampled_avg", repr(param_sampled_tracker.averages), - "param_sampled_cov", repr(param_sampled_tracker.covariance_matrix), - "value_log_variance", str(value_tracker.log_variance), - "value_log_secondary_variance", - str(value_tracker.secondary.log_variance)])) + debug_list.append(f"integrator used {samples!s} samples") + debug_list.append( + u" ".join([ + u"value_avg", str(value_tracker.average), + u"param_sampled_avg", repr(param_sampled_tracker.averages), + u"param_sampled_cov", repr(param_sampled_tracker.covariance_matrix), + u"value_log_variance", str(value_tracker.log_variance), + u"value_log_secondary_variance", + str(value_tracker.secondary.log_variance) + ]) + ) communication_pipe.send( - (value_tracker, param_focus_tracker, debug_list, trace_list, samples)) + (value_tracker, param_focus_tracker, debug_list, trace_list, samples) + ) |