aboutsummaryrefslogtreecommitdiffstats
path: root/resources
diff options
context:
space:
mode:
authorMiroslav Los <miroslav.los@pantheon.tech>2019-08-16 15:09:39 +0200
committerVratko Polak <vrpolak@cisco.com>2019-08-23 11:02:11 +0000
commit51511689eb9f93134878c314ba0349f28ef2ec4f (patch)
tree4a2efbf2dc4c4431874a4b56e421aa14c345b326 /resources
parentd95a4661e5ea956f71e9f3c5aed46491cf20a9a3 (diff)
Framework: Refactor complex functions in PLRSearch
Signed-off-by: Miroslav Los <miroslav.los@pantheon.tech> Change-Id: Ie2f19a2e3b37e8d85656ab31ece59b89c76bea25 (cherry picked from commit 281b230ba982f9f6ad589fb6e44f121a6a46531f)
Diffstat (limited to 'resources')
-rw-r--r--resources/libraries/python/PLRsearch/Integrator.py45
-rw-r--r--resources/libraries/python/PLRsearch/PLRsearch.py72
2 files changed, 68 insertions, 49 deletions
diff --git a/resources/libraries/python/PLRsearch/Integrator.py b/resources/libraries/python/PLRsearch/Integrator.py
index 82abe5f8a3..035afd848c 100644
--- a/resources/libraries/python/PLRsearch/Integrator.py
+++ b/resources/libraries/python/PLRsearch/Integrator.py
@@ -45,9 +45,24 @@ def try_estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
raise
-# TODO: Pylint reports multiple complexity violations.
-# Refactor the code, using less (but structured) variables
-# and function calls for (relatively) loosly coupled code blocks.
+def generate_sample(averages, covariance_matrix, dimension, scale_coeff):
+ """Generate next sample for estimate_nd"""
+ covariance_matrix = copy.deepcopy(covariance_matrix)
+ for first in range(dimension):
+ for second in range(dimension):
+ covariance_matrix[first][second] *= scale_coeff
+ while 1:
+ sample_point = random.multivariate_normal(
+ averages, covariance_matrix, 1)[0].tolist()
+ # Multivariate Gauss can fall outside (-1, 1) interval
+ for first in range(dimension):
+ sample_coordinate = sample_point[first]
+ if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
+ break
+ else:
+ return sample_point
+
+
def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
"""Use Bayesian inference from control queue, put result to result queue.
@@ -148,6 +163,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
communication_pipe.recv())
debug_list.append("Called with param_focus_tracker {tracker!r}"
.format(tracker=param_focus_tracker))
+
def trace(name, value):
"""
Add a variable (name and value) to trace list (if enabled).
@@ -163,6 +179,7 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
"""
if trace_enabled:
trace_list.append(name + " " + repr(value))
+
value_logweight_function = dill.loads(dilled_function)
samples = 0
# Importance sampling produces samples of higher weight (important)
@@ -180,28 +197,14 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
else:
# Focus tracker has probably too high weight.
param_focus_tracker.log_sum_weight = None
- # TODO: Teach pylint the used version of numpy.random does have this member.
random.seed(0)
while not communication_pipe.poll():
if max_samples and samples >= max_samples:
break
- # Generate next sample.
- averages = param_focus_tracker.averages
- covariance_matrix = copy.deepcopy(param_focus_tracker.covariance_matrix)
- for first in range(dimension):
- for second in range(dimension):
- covariance_matrix[first][second] *= scale_coeff
- while 1:
- # TODO: Teach pylint that numpy.random does also have this member.
- sample_point = random.multivariate_normal(
- averages, covariance_matrix, 1)[0].tolist()
- # Multivariate Gauss can fall outside (-1, 1) interval
- for first in range(dimension):
- sample_coordinate = sample_point[first]
- if sample_coordinate <= -1.0 or sample_coordinate >= 1.0:
- break
- else: # These two breaks implement "level two continue".
- break
+ sample_point = generate_sample(param_focus_tracker.averages,
+ param_focus_tracker.covariance_matrix,
+ dimension,
+ scale_coeff)
trace("sample_point", sample_point)
samples += 1
trace("samples", samples)
diff --git a/resources/libraries/python/PLRsearch/PLRsearch.py b/resources/libraries/python/PLRsearch/PLRsearch.py
index db870c55dc..4205818d91 100644
--- a/resources/libraries/python/PLRsearch/PLRsearch.py
+++ b/resources/libraries/python/PLRsearch/PLRsearch.py
@@ -17,17 +17,17 @@ import logging
import math
import multiprocessing
import time
+from collections import namedtuple
import dill
-# TODO: Inform pylint about scipy (of correct version) being available.
from scipy.special import erfcx, erfc
# TODO: Teach FD.io CSIT to use multiple dirs in PYTHONPATH,
# then switch to absolute imports within PLRsearch package.
# Current usage of relative imports is just a short term workaround.
-import Integrator # pylint: disable=relative-import
-from log_plus import log_plus, log_minus # pylint: disable=relative-import
-import stat_trackers # pylint: disable=relative-import
+from . import Integrator
+from .log_plus import log_plus, log_minus
+from . import stat_trackers
class PLRsearch(object):
@@ -461,8 +461,6 @@ class PLRsearch(object):
trace("log_trial_likelihood", log_trial_likelihood)
return log_likelihood
- # TODO: Refactor (somehow) so pylint stops complaining about
- # too many local variables.
def measure_and_compute(
self, trial_duration, transmit_rate, trial_result_list,
min_rate, max_rate, focus_trackers=(None, None), max_samples=None):
@@ -531,6 +529,7 @@ class PLRsearch(object):
erf_focus_tracker = stat_trackers.VectorStatTracker(dimension)
erf_focus_tracker.unit_reset()
old_trackers = stretch_focus_tracker.copy(), erf_focus_tracker.copy()
+
def start_computing(fitting_function, focus_tracker):
"""Just a block of code to be used for each fitting function.
@@ -546,6 +545,7 @@ class PLRsearch(object):
:returns: Boss end of communication pipe.
:rtype: multiprocessing.Connection
"""
+
def value_logweight_func(trace, x_mrr, x_spread):
"""Return log of critical rate and log of likelihood.
@@ -585,6 +585,7 @@ class PLRsearch(object):
trace, fitting_function, min_rate, max_rate,
self.packet_loss_ratio_target, mrr, spread))
return value, logweight
+
dilled_function = dill.dumps(value_logweight_func)
boss_pipe_end, worker_pipe_end = multiprocessing.Pipe()
boss_pipe_end.send(
@@ -595,12 +596,15 @@ class PLRsearch(object):
worker.daemon = True
worker.start()
return boss_pipe_end
+
erf_pipe = start_computing(
self.lfit_erf, erf_focus_tracker)
stretch_pipe = start_computing(
self.lfit_stretch, stretch_focus_tracker)
+
# Measurement phase.
measurement = self.measurer.measure(trial_duration, transmit_rate)
+
# Processing phase.
def stop_computing(name, pipe):
"""Just a block of code to be used for each worker.
@@ -637,30 +641,42 @@ class PLRsearch(object):
logging.debug(message)
logging.debug("trackers: value %(val)r focus %(foc)r", {
"val": value_tracker, "foc": focus_tracker})
- return value_tracker, focus_tracker, sampls
- stretch_value_tracker, stretch_focus_tracker, stretch_samples = (
- stop_computing("stretch", stretch_pipe))
- erf_value_tracker, erf_focus_tracker, erf_samples = (
- stop_computing("erf", erf_pipe))
- stretch_avg = stretch_value_tracker.average
- erf_avg = erf_value_tracker.average
- # TODO: Take into account secondary stats.
- stretch_stdev = math.exp(stretch_value_tracker.log_variance / 2)
- erf_stdev = math.exp(erf_value_tracker.log_variance / 2)
- avg = math.exp((stretch_avg + erf_avg) / 2.0)
- var = (stretch_stdev * stretch_stdev + erf_stdev * erf_stdev) / 2.0
- var += (stretch_avg - erf_avg) * (stretch_avg - erf_avg) / 4.0
- stdev = avg * math.sqrt(var)
- focus_trackers = (stretch_focus_tracker, erf_focus_tracker)
+ return _PartialResult(value_tracker, focus_tracker, sampls)
+
+ stretch_result = stop_computing("stretch", stretch_pipe)
+ erf_result = stop_computing("erf", erf_pipe)
+ result = PLRsearch._get_result(measurement, stretch_result, erf_result)
logging.info(
"measure_and_compute finished with trial result %(res)r "
"avg %(avg)r stdev %(stdev)r stretch %(a1)r erf %(a2)r "
"new trackers %(nt)r old trackers %(ot)r stretch samples %(ss)r "
"erf samples %(es)r",
- {"res": measurement, "avg": avg, "stdev": stdev,
- "a1": math.exp(stretch_avg), "a2": math.exp(erf_avg),
- "nt": focus_trackers, "ot": old_trackers, "ss": stretch_samples,
- "es": erf_samples})
- return (
- measurement, avg, stdev, math.exp(stretch_avg),
- math.exp(erf_avg), focus_trackers)
+ {"res": result.measurement,
+ "avg": result.avg, "stdev": result.stdev,
+ "a1": result.stretch_exp_avg, "a2": result.erf_exp_avg,
+ "nt": result.trackers, "ot": old_trackers,
+ "ss": stretch_result.samples, "es": erf_result.samples})
+ return result
+
+ @staticmethod
+ def _get_result(measurement, stretch_result, erf_result):
+ """Collate results from measure_and_compute"""
+ stretch_avg = stretch_result.value_tracker.average
+ erf_avg = erf_result.value_tracker.average
+ # TODO: Take into account secondary stats.
+ stretch_stdev = math.exp(stretch_result.value_tracker.log_variance / 2)
+ erf_stdev = math.exp(erf_result.value_tracker.log_variance / 2)
+ avg = math.exp((stretch_avg + erf_avg) / 2.0)
+ var = (stretch_stdev * stretch_stdev + erf_stdev * erf_stdev) / 2.0
+ var += (stretch_avg - erf_avg) * (stretch_avg - erf_avg) / 4.0
+ stdev = avg * math.sqrt(var)
+ trackers = (stretch_result.focus_tracker, erf_result.focus_tracker)
+ sea = math.exp(stretch_avg)
+ eea = math.exp(erf_avg)
+ return _ComputeResult(measurement, avg, stdev, sea, eea, trackers)
+
+
+_PartialResult = namedtuple('_PartialResult',
+ 'value_tracker focus_tracker samples')
+_ComputeResult = namedtuple('_ComputeResult', 'measurement avg stdev ' +
+ 'stretch_exp_avg erf_exp_avg trackers')