From 281b230ba982f9f6ad589fb6e44f121a6a46531f Mon Sep 17 00:00:00 2001 From: Miroslav Los Date: Fri, 16 Aug 2019 15:09:39 +0200 Subject: Framework: Refactor complex functions in PLRSearch Signed-off-by: Miroslav Los Change-Id: Ie2f19a2e3b37e8d85656ab31ece59b89c76bea25 --- resources/libraries/python/PLRsearch/Integrator.py | 45 +++++++------- resources/libraries/python/PLRsearch/PLRsearch.py | 72 +++++++++++++--------- 2 files changed, 68 insertions(+), 49 deletions(-) (limited to 'resources/libraries') diff --git a/resources/libraries/python/PLRsearch/Integrator.py b/resources/libraries/python/PLRsearch/Integrator.py index 82abe5f8a3..035afd848c 100644 --- a/resources/libraries/python/PLRsearch/Integrator.py +++ b/resources/libraries/python/PLRsearch/Integrator.py @@ -45,9 +45,24 @@ def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): raise -# TODO: Pylint reports multiple complexity violations. -# Refactor the code, using less (but structured) variables -# and function calls for (relatively) loosly coupled code blocks. +def generate_sample(averages, covariance_matrix, dimension, scale_coeff): + """Generate next sample for estimate_nd""" + covariance_matrix = copy.deepcopy(covariance_matrix) + for first in range(dimension): + for second in range(dimension): + covariance_matrix[first][second] *= scale_coeff + while 1: + sample_point = random.multivariate_normal( + averages, covariance_matrix, 1)[0].tolist() + # Multivariate Gauss can fall outside (-1, 1) interval + for first in range(dimension): + sample_coordinate = sample_point[first] + if sample_coordinate <= -1.0 or sample_coordinate >= 1.0: + break + else: + return sample_point + + def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): """Use Bayesian inference from control queue, put result to result queue. @@ -148,6 +163,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): communication_pipe.recv()) debug_list.append("Called with param_focus_tracker {tracker!r}" .format(tracker=param_focus_tracker)) + def trace(name, value): """ Add a variable (name and value) to trace list (if enabled). @@ -163,6 +179,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): """ if trace_enabled: trace_list.append(name + " " + repr(value)) + value_logweight_function = dill.loads(dilled_function) samples = 0 # Importance sampling produces samples of higher weight (important) @@ -180,28 +197,14 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False): else: # Focus tracker has probably too high weight. param_focus_tracker.log_sum_weight = None - # TODO: Teach pylint the used version of numpy.random does have this member. random.seed(0) while not communication_pipe.poll(): if max_samples and samples >= max_samples: break - # Generate next sample. - averages = param_focus_tracker.averages - covariance_matrix = copy.deepcopy(param_focus_tracker.covariance_matrix) - for first in range(dimension): - for second in range(dimension): - covariance_matrix[first][second] *= scale_coeff - while 1: - # TODO: Teach pylint that numpy.random does also have this member. - sample_point = random.multivariate_normal( - averages, covariance_matrix, 1)[0].tolist() - # Multivariate Gauss can fall outside (-1, 1) interval - for first in range(dimension): - sample_coordinate = sample_point[first] - if sample_coordinate <= -1.0 or sample_coordinate >= 1.0: - break - else: # These two breaks implement "level two continue". - break + sample_point = generate_sample(param_focus_tracker.averages, + param_focus_tracker.covariance_matrix, + dimension, + scale_coeff) trace("sample_point", sample_point) samples += 1 trace("samples", samples) diff --git a/resources/libraries/python/PLRsearch/PLRsearch.py b/resources/libraries/python/PLRsearch/PLRsearch.py index db870c55dc..4205818d91 100644 --- a/resources/libraries/python/PLRsearch/PLRsearch.py +++ b/resources/libraries/python/PLRsearch/PLRsearch.py @@ -17,17 +17,17 @@ import logging import math import multiprocessing import time +from collections import namedtuple import dill -# TODO: Inform pylint about scipy (of correct version) being available. from scipy.special import erfcx, erfc # TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH, # then switch to absolute imports within PLRsearch package. # Current usage of relative imports is just a short term workaround. -import Integrator # pylint: disable=relative-import -from log_plus import log_plus, log_minus # pylint: disable=relative-import -import stat_trackers # pylint: disable=relative-import +from . import Integrator +from .log_plus import log_plus, log_minus +from . import stat_trackers class PLRsearch(object): @@ -461,8 +461,6 @@ class PLRsearch(object): trace("log_trial_likelihood", log_trial_likelihood) return log_likelihood - # TODO: Refactor (somehow) so pylint stops complaining about - # too many local variables. def measure_and_compute( self, trial_duration, transmit_rate, trial_result_list, min_rate, max_rate, focus_trackers=(None, None), max_samples=None): @@ -531,6 +529,7 @@ class PLRsearch(object): erf_focus_tracker = stat_trackers.VectorStatTracker(dimension) erf_focus_tracker.unit_reset() old_trackers = stretch_focus_tracker.copy(), erf_focus_tracker.copy() + def start_computing(fitting_function, focus_tracker): """Just a block of code to be used for each fitting function. @@ -546,6 +545,7 @@ class PLRsearch(object): :returns: Boss end of communication pipe. :rtype: multiprocessing.Connection """ + def value_logweight_func(trace, x_mrr, x_spread): """Return log of critical rate and log of likelihood. @@ -585,6 +585,7 @@ class PLRsearch(object): trace, fitting_function, min_rate, max_rate, self.packet_loss_ratio_target, mrr, spread)) return value, logweight + dilled_function = dill.dumps(value_logweight_func) boss_pipe_end, worker_pipe_end = multiprocessing.Pipe() boss_pipe_end.send( @@ -595,12 +596,15 @@ class PLRsearch(object): worker.daemon = True worker.start() return boss_pipe_end + erf_pipe = start_computing( self.lfit_erf, erf_focus_tracker) stretch_pipe = start_computing( self.lfit_stretch, stretch_focus_tracker) + # Measurement phase. measurement = self.measurer.measure(trial_duration, transmit_rate) + # Processing phase. def stop_computing(name, pipe): """Just a block of code to be used for each worker. @@ -637,30 +641,42 @@ class PLRsearch(object): logging.debug(message) logging.debug("trackers: value %(val)r focus %(foc)r", { "val": value_tracker, "foc": focus_tracker}) - return value_tracker, focus_tracker, sampls - stretch_value_tracker, stretch_focus_tracker, stretch_samples = ( - stop_computing("stretch", stretch_pipe)) - erf_value_tracker, erf_focus_tracker, erf_samples = ( - stop_computing("erf", erf_pipe)) - stretch_avg = stretch_value_tracker.average - erf_avg = erf_value_tracker.average - # TODO: Take into account secondary stats. - stretch_stdev = math.exp(stretch_value_tracker.log_variance / 2) - erf_stdev = math.exp(erf_value_tracker.log_variance / 2) - avg = math.exp((stretch_avg + erf_avg) / 2.0) - var = (stretch_stdev * stretch_stdev + erf_stdev * erf_stdev) / 2.0 - var += (stretch_avg - erf_avg) * (stretch_avg - erf_avg) / 4.0 - stdev = avg * math.sqrt(var) - focus_trackers = (stretch_focus_tracker, erf_focus_tracker) + return _PartialResult(value_tracker, focus_tracker, sampls) + + stretch_result = stop_computing("stretch", stretch_pipe) + erf_result = stop_computing("erf", erf_pipe) + result = PLRsearch._get_result(measurement, stretch_result, erf_result) logging.info( "measure_and_compute finished with trial result %(res)r " "avg %(avg)r stdev %(stdev)r stretch %(a1)r erf %(a2)r " "new trackers %(nt)r old trackers %(ot)r stretch samples %(ss)r " "erf samples %(es)r", - {"res": measurement, "avg": avg, "stdev": stdev, - "a1": math.exp(stretch_avg), "a2": math.exp(erf_avg), - "nt": focus_trackers, "ot": old_trackers, "ss": stretch_samples, - "es": erf_samples}) - return ( - measurement, avg, stdev, math.exp(stretch_avg), - math.exp(erf_avg), focus_trackers) + {"res": result.measurement, + "avg": result.avg, "stdev": result.stdev, + "a1": result.stretch_exp_avg, "a2": result.erf_exp_avg, + "nt": result.trackers, "ot": old_trackers, + "ss": stretch_result.samples, "es": erf_result.samples}) + return result + + @staticmethod + def _get_result(measurement, stretch_result, erf_result): + """Collate results from measure_and_compute""" + stretch_avg = stretch_result.value_tracker.average + erf_avg = erf_result.value_tracker.average + # TODO: Take into account secondary stats. + stretch_stdev = math.exp(stretch_result.value_tracker.log_variance / 2) + erf_stdev = math.exp(erf_result.value_tracker.log_variance / 2) + avg = math.exp((stretch_avg + erf_avg) / 2.0) + var = (stretch_stdev * stretch_stdev + erf_stdev * erf_stdev) / 2.0 + var += (stretch_avg - erf_avg) * (stretch_avg - erf_avg) / 4.0 + stdev = avg * math.sqrt(var) + trackers = (stretch_result.focus_tracker, erf_result.focus_tracker) + sea = math.exp(stretch_avg) + eea = math.exp(erf_avg) + return _ComputeResult(measurement, avg, stdev, sea, eea, trackers) + + +_PartialResult = namedtuple('_PartialResult', + 'value_tracker focus_tracker samples') +_ComputeResult = namedtuple('_ComputeResult', 'measurement avg stdev ' + + 'stretch_exp_avg erf_exp_avg trackers') -- cgit 1.2.3-korg