From 281b230ba982f9f6ad589fb6e44f121a6a46531f Mon Sep 17 00:00:00 2001 From: Miroslav Los Date: Fri, 16 Aug 2019 15:09:39 +0200 Subject: Framework: Refactor complex functions in PLRSearch Signed-off-by: Miroslav Los Change-Id: Ie2f19a2e3b37e8d85656ab31ece59b89c76bea25 --- resources/libraries/python/PLRsearch/Integrator.py | 45 ++++++++++++---------- 1 file changed, 24 insertions(+), 21 deletions(-) (limited to 'resources/libraries/python/PLRsearch/Integrator.py') diff --git a/resources/libraries/python/PLRsearch/Integrator.py b/resources/libraries/python/PLRsearch/Integrator.py index 82abe5f8a3..035afd848c 100644 --- a/resources/libraries/python/PLRsearch/Integrator.py +++ b/resources/libraries/python/PLRsearch/Integrator.py @@ -45,9 +45,24 @@ def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): raise -# TODO: Pylint reports multiple complexity violations. -# Refactor the code, using less (but structured) variables -# and function calls for (relatively) loosly coupled code blocks. +def generate_sample(averages, covariance_matrix, dimension, scale_coeff): + """Generate next sample for estimate_nd""" + covariance_matrix = copy.deepcopy(covariance_matrix) + for first in range(dimension): + for second in range(dimension): + covariance_matrix[first][second] *= scale_coeff + while 1: + sample_point = random.multivariate_normal( + averages, covariance_matrix, 1)[0].tolist() + # Multivariate Gauss can fall outside (-1, 1) interval + for first in range(dimension): + sample_coordinate = sample_point[first] + if sample_coordinate <= -1.0 or sample_coordinate >= 1.0: + break + else: + return sample_point + + def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): """Use Bayesian inference from control queue, put result to result queue. @@ -148,6 +163,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): communication_pipe.recv()) debug_list.append("Called with param_focus_tracker {tracker!r}" .format(tracker=param_focus_tracker)) + def trace(name, value): """ Add a variable (name and value) to trace list (if enabled). @@ -163,6 +179,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): """ if trace_enabled: trace_list.append(name + " " + repr(value)) + value_logweight_function = dill.loads(dilled_function) samples = 0 # Importance sampling produces samples of higher weight (important) @@ -180,28 +197,14 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): else: # Focus tracker has probably too high weight. param_focus_tracker.log_sum_weight = None - # TODO: Teach pylint the used version of numpy.random does have this member. random.seed(0) while not communication_pipe.poll(): if max_samples and samples >= max_samples: break - # Generate next sample. - averages = param_focus_tracker.averages - covariance_matrix = copy.deepcopy(param_focus_tracker.covariance_matrix) - for first in range(dimension): - for second in range(dimension): - covariance_matrix[first][second] *= scale_coeff - while 1: - # TODO: Teach pylint that numpy.random does also have this member. - sample_point = random.multivariate_normal( - averages, covariance_matrix, 1)[0].tolist() - # Multivariate Gauss can fall outside (-1, 1) interval - for first in range(dimension): - sample_coordinate = sample_point[first] - if sample_coordinate <= -1.0 or sample_coordinate >= 1.0: - break - else: # These two breaks implement "level two continue". - break + sample_point = generate_sample(param_focus_tracker.averages, + param_focus_tracker.covariance_matrix, + dimension, + scale_coeff) trace("sample_point", sample_point) samples += 1 trace("samples", samples) -- cgit 1.2.3-korg