diff options
author | Miroslav Los <miroslav.los@pantheon.tech> | 2019-08-16 15:09:39 +0200 |
---|---|---|
committer | Vratko Polak <vrpolak@cisco.com> | 2019-08-19 08:41:05 +0000 |
commit | 281b230ba982f9f6ad589fb6e44f121a6a46531f (patch) | |
tree | b4527dab2a6d859ba12f11bdaa4540bac8ae56a0 /resources/libraries/python/PLRsearch/PLRsearch.py | |
parent | 93358d8b82778cb00436f78d0e8396f01acbd279 (diff) |
Framework: Refactor complex functions in PLRSearch
Signed-off-by: Miroslav Los <miroslav.los@pantheon.tech>
Change-Id: Ie2f19a2e3b37e8d85656ab31ece59b89c76bea25
Diffstat (limited to 'resources/libraries/python/PLRsearch/PLRsearch.py')
-rw-r--r-- | resources/libraries/python/PLRsearch/PLRsearch.py | 72 |
1 files changed, 44 insertions, 28 deletions
diff --git a/resources/libraries/python/PLRsearch/PLRsearch.py b/resources/libraries/python/PLRsearch/PLRsearch.py index db870c55dc..4205818d91 100644 --- a/resources/libraries/python/PLRsearch/PLRsearch.py +++ b/resources/libraries/python/PLRsearch/PLRsearch.py @@ -17,17 +17,17 @@ import logging import math import multiprocessing import time +from collections import namedtuple import dill -# TODO: Inform pylint about scipy (of correct version) being available. from scipy.special import erfcx, erfc # TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH, # then switch to absolute imports within PLRsearch package. # Current usage of relative imports is just a short term workaround. -import Integrator # pylint: disable=relative-import -from log_plus import log_plus, log_minus # pylint: disable=relative-import -import stat_trackers # pylint: disable=relative-import +from . import Integrator +from .log_plus import log_plus, log_minus +from . import stat_trackers class PLRsearch(object): @@ -461,8 +461,6 @@ class PLRsearch(object): trace("log_trial_likelihood", log_trial_likelihood) return log_likelihood - # TODO: Refactor (somehow) so pylint stops complaining about - # too many local variables. def measure_and_compute( self, trial_duration, transmit_rate, trial_result_list, min_rate, max_rate, focus_trackers=(None, None), max_samples=None): @@ -531,6 +529,7 @@ class PLRsearch(object): erf_focus_tracker = stat_trackers.VectorStatTracker(dimension) erf_focus_tracker.unit_reset() old_trackers = stretch_focus_tracker.copy(), erf_focus_tracker.copy() + def start_computing(fitting_function, focus_tracker): """Just a block of code to be used for each fitting function. @@ -546,6 +545,7 @@ class PLRsearch(object): :returns: Boss end of communication pipe. :rtype: multiprocessing.Connection """ + def value_logweight_func(trace, x_mrr, x_spread): """Return log of critical rate and log of likelihood. @@ -585,6 +585,7 @@ class PLRsearch(object): trace, fitting_function, min_rate, max_rate, self.packet_loss_ratio_target, mrr, spread)) return value, logweight + dilled_function = dill.dumps(value_logweight_func) boss_pipe_end, worker_pipe_end = multiprocessing.Pipe() boss_pipe_end.send( @@ -595,12 +596,15 @@ class PLRsearch(object): worker.daemon = True worker.start() return boss_pipe_end + erf_pipe = start_computing( self.lfit_erf, erf_focus_tracker) stretch_pipe = start_computing( self.lfit_stretch, stretch_focus_tracker) + # Measurement phase. measurement = self.measurer.measure(trial_duration, transmit_rate) + # Processing phase. def stop_computing(name, pipe): """Just a block of code to be used for each worker. @@ -637,30 +641,42 @@ class PLRsearch(object): logging.debug(message) logging.debug("trackers: value %(val)r focus %(foc)r", { "val": value_tracker, "foc": focus_tracker}) - return value_tracker, focus_tracker, sampls - stretch_value_tracker, stretch_focus_tracker, stretch_samples = ( - stop_computing("stretch", stretch_pipe)) - erf_value_tracker, erf_focus_tracker, erf_samples = ( - stop_computing("erf", erf_pipe)) - stretch_avg = stretch_value_tracker.average - erf_avg = erf_value_tracker.average - # TODO: Take into account secondary stats. - stretch_stdev = math.exp(stretch_value_tracker.log_variance / 2) - erf_stdev = math.exp(erf_value_tracker.log_variance / 2) - avg = math.exp((stretch_avg + erf_avg) / 2.0) - var = (stretch_stdev * stretch_stdev + erf_stdev * erf_stdev) / 2.0 - var += (stretch_avg - erf_avg) * (stretch_avg - erf_avg) / 4.0 - stdev = avg * math.sqrt(var) - focus_trackers = (stretch_focus_tracker, erf_focus_tracker) + return _PartialResult(value_tracker, focus_tracker, sampls) + + stretch_result = stop_computing("stretch", stretch_pipe) + erf_result = stop_computing("erf", erf_pipe) + result = PLRsearch._get_result(measurement, stretch_result, erf_result) logging.info( "measure_and_compute finished with trial result %(res)r " "avg %(avg)r stdev %(stdev)r stretch %(a1)r erf %(a2)r " "new trackers %(nt)r old trackers %(ot)r stretch samples %(ss)r " "erf samples %(es)r", - {"res": measurement, "avg": avg, "stdev": stdev, - "a1": math.exp(stretch_avg), "a2": math.exp(erf_avg), - "nt": focus_trackers, "ot": old_trackers, "ss": stretch_samples, - "es": erf_samples}) - return ( - measurement, avg, stdev, math.exp(stretch_avg), - math.exp(erf_avg), focus_trackers) + {"res": result.measurement, + "avg": result.avg, "stdev": result.stdev, + "a1": result.stretch_exp_avg, "a2": result.erf_exp_avg, + "nt": result.trackers, "ot": old_trackers, + "ss": stretch_result.samples, "es": erf_result.samples}) + return result + + @staticmethod + def _get_result(measurement, stretch_result, erf_result): + """Collate results from measure_and_compute""" + stretch_avg = stretch_result.value_tracker.average + erf_avg = erf_result.value_tracker.average + # TODO: Take into account secondary stats. + stretch_stdev = math.exp(stretch_result.value_tracker.log_variance / 2) + erf_stdev = math.exp(erf_result.value_tracker.log_variance / 2) + avg = math.exp((stretch_avg + erf_avg) / 2.0) + var = (stretch_stdev * stretch_stdev + erf_stdev * erf_stdev) / 2.0 + var += (stretch_avg - erf_avg) * (stretch_avg - erf_avg) / 4.0 + stdev = avg * math.sqrt(var) + trackers = (stretch_result.focus_tracker, erf_result.focus_tracker) + sea = math.exp(stretch_avg) + eea = math.exp(erf_avg) + return _ComputeResult(measurement, avg, stdev, sea, eea, trackers) + + +_PartialResult = namedtuple('_PartialResult', + 'value_tracker focus_tracker samples') +_ComputeResult = namedtuple('_ComputeResult', 'measurement avg stdev ' + + 'stretch_exp_avg erf_exp_avg trackers') |