aboutsummaryrefslogtreecommitdiffstats
path: root/resources/libraries/python/PLRsearch
diff options
context:
space:
mode:
Diffstat (limited to 'resources/libraries/python/PLRsearch')
-rw-r--r--resources/libraries/python/PLRsearch/Integrator.py59
-rw-r--r--resources/libraries/python/PLRsearch/PLRsearch.py186
-rw-r--r--resources/libraries/python/PLRsearch/log_plus.py8
-rw-r--r--resources/libraries/python/PLRsearch/stat_trackers.py58
4 files changed, 192 insertions, 119 deletions
diff --git a/resources/libraries/python/PLRsearch/Integrator.py b/resources/libraries/python/PLRsearch/Integrator.py
index a7a59391ed..cc8f838fe6 100644
--- a/resources/libraries/python/PLRsearch/Integrator.py
+++ b/resources/libraries/python/PLRsearch/Integrator.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 Cisco and/or its affiliates.
+# Copyright (c) 2024 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
@@ -189,12 +189,15 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
:raises numpy.linalg.LinAlgError: If the focus shape gets singular
(due to rounding errors). Try changing scale_coeff.
"""
- debug_list = list()
- trace_list = list()
+ debug_list = []
+ trace_list = []
# Block until input object appears.
- dimension, dilled_function, param_focus_tracker, max_samples = (
- communication_pipe.recv()
- )
+ (
+ dimension,
+ dilled_function,
+ param_focus_tracker,
+ max_samples,
+ ) = communication_pipe.recv()
debug_list.append(
f"Called with param_focus_tracker {param_focus_tracker!r}"
)
@@ -237,39 +240,47 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
if max_samples and samples >= max_samples:
break
sample_point = generate_sample(
- param_focus_tracker.averages, param_focus_tracker.covariance_matrix,
- dimension, scale_coeff
+ param_focus_tracker.averages,
+ param_focus_tracker.covariance_matrix,
+ dimension,
+ scale_coeff,
)
- trace(u"sample_point", sample_point)
+ trace("sample_point", sample_point)
samples += 1
- trace(u"samples", samples)
+ trace("samples", samples)
value, log_weight = value_logweight_function(trace, *sample_point)
- trace(u"value", value)
- trace(u"log_weight", log_weight)
- trace(u"focus tracker before adding", param_focus_tracker)
+ trace("value", value)
+ trace("log_weight", log_weight)
+ trace("focus tracker before adding", param_focus_tracker)
# Update focus related statistics.
param_distance = param_focus_tracker.add_without_dominance_get_distance(
sample_point, log_weight
)
# The code above looked at weight (not importance).
# The code below looks at importance (not weight).
- log_rarity = param_distance / 2.0
- trace(u"log_rarity", log_rarity)
+ log_rarity = param_distance / 2.0 / scale_coeff
+ trace("log_rarity", log_rarity)
log_importance = log_weight + log_rarity
- trace(u"log_importance", log_importance)
+ trace("log_importance", log_importance)
value_tracker.add(value, log_importance)
# Update sampled statistics.
param_sampled_tracker.add_get_shift(sample_point, log_importance)
debug_list.append(f"integrator used {samples!s} samples")
debug_list.append(
- u" ".join([
- u"value_avg", str(value_tracker.average),
- u"param_sampled_avg", repr(param_sampled_tracker.averages),
- u"param_sampled_cov", repr(param_sampled_tracker.covariance_matrix),
- u"value_log_variance", str(value_tracker.log_variance),
- u"value_log_secondary_variance",
- str(value_tracker.secondary.log_variance)
- ])
+ " ".join(
+ [
+ "value_avg",
+ str(value_tracker.average),
+ "param_sampled_avg",
+ repr(param_sampled_tracker.averages),
+ "param_sampled_cov",
+ repr(param_sampled_tracker.covariance_matrix),
+ "value_log_variance",
+ str(value_tracker.log_variance),
+ "value_log_secondary_variance",
+ str(value_tracker.secondary.log_variance),
+ ]
+ )
)
communication_pipe.send(
(value_tracker, param_focus_tracker, debug_list, trace_list, samples)
diff --git a/resources/libraries/python/PLRsearch/PLRsearch.py b/resources/libraries/python/PLRsearch/PLRsearch.py
index 0e78cc936d..326aa2e2d2 100644
--- a/resources/libraries/python/PLRsearch/PLRsearch.py
+++ b/resources/libraries/python/PLRsearch/PLRsearch.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 Cisco and/or its affiliates.
+# Copyright (c) 2024 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
@@ -53,8 +53,14 @@ class PLRsearch:
log_xerfcx_10 = math.log(xerfcx_limit - math.exp(10) * erfcx(math.exp(10)))
def __init__(
- self, measurer, trial_duration_per_trial, packet_loss_ratio_target,
- trial_number_offset=0, timeout=7200.0, trace_enabled=False):
+ self,
+ measurer,
+ trial_duration_per_trial,
+ packet_loss_ratio_target,
+ trial_number_offset=0,
+ timeout=7200.0,
+ trace_enabled=False,
+ ):
"""Store rate measurer and additional parameters.
The measurer must never report negative loss count.
@@ -176,7 +182,7 @@ class PLRsearch:
f"Started search with min_rate {min_rate!r}, "
f"max_rate {max_rate!r}"
)
- trial_result_list = list()
+ trial_result_list = []
trial_number = self.trial_number_offset
focus_trackers = (None, None)
transmit_rate = (min_rate + max_rate) / 2.0
@@ -186,34 +192,54 @@ class PLRsearch:
trial_number += 1
logging.info(f"Trial {trial_number!r}")
results = self.measure_and_compute(
- self.trial_duration_per_trial * trial_number, transmit_rate,
- trial_result_list, min_rate, max_rate, focus_trackers
+ self.trial_duration_per_trial * trial_number,
+ transmit_rate,
+ trial_result_list,
+ min_rate,
+ max_rate,
+ focus_trackers,
)
measurement, average, stdev, avg1, avg2, focus_trackers = results
+ # Workaround for unsent packets and other anomalies.
+ measurement.plr_loss_count = min(
+ measurement.intended_count,
+ int(measurement.intended_count * measurement.loss_ratio + 0.9),
+ )
+ logging.debug(
+ f"loss ratio {measurement.plr_loss_count}"
+ f" / {measurement.intended_count}"
+ )
zeros += 1
# TODO: Ratio of fill rate to drain rate seems to have
# exponential impact. Make it configurable, or is 4:3 good enough?
- if measurement.loss_ratio >= self.packet_loss_ratio_target:
+ if measurement.plr_loss_count >= (
+ measurement.intended_count * self.packet_loss_ratio_target
+ ):
for _ in range(4 * zeros):
- lossy_loads.append(measurement.target_tr)
- if measurement.loss_count > 0:
+ lossy_loads.append(measurement.intended_load)
+ lossy_loads.sort()
zeros = 0
- lossy_loads.sort()
+ logging.debug("High enough loss, lossy loads added.")
+ else:
+ logging.debug(
+ f"Not a high loss, zero counter bumped to {zeros}."
+ )
if stop_time <= time.time():
return average, stdev
trial_result_list.append(measurement)
if (trial_number - self.trial_number_offset) <= 1:
next_load = max_rate
elif (trial_number - self.trial_number_offset) <= 3:
- next_load = (measurement.relative_receive_rate / (
- 1.0 - self.packet_loss_ratio_target))
+ next_load = measurement.relative_forwarding_rate / (
+ 1.0 - self.packet_loss_ratio_target
+ )
else:
next_load = (avg1 + avg2) / 2.0
if zeros > 0:
if lossy_loads[0] > next_load:
diminisher = math.pow(2.0, 1 - zeros)
next_load = lossy_loads[0] + diminisher * next_load
- next_load /= (1.0 + diminisher)
+ next_load /= 1.0 + diminisher
# On zero measurement, we need to drain obsoleted low losses
# even if we did not use them to increase next_load,
# in order to get to usable loses at higher loads.
@@ -263,22 +289,22 @@ class PLRsearch:
# TODO: chi is from https://en.wikipedia.org/wiki/Nondimensionalization
chi = (load - mrr) / spread
chi0 = -mrr / spread
- trace(u"stretch: load", load)
- trace(u"mrr", mrr)
- trace(u"spread", spread)
- trace(u"chi", chi)
- trace(u"chi0", chi0)
+ trace("stretch: load", load)
+ trace("mrr", mrr)
+ trace("spread", spread)
+ trace("chi", chi)
+ trace("chi0", chi0)
if chi > 0:
log_lps = math.log(
load - mrr + (log_plus(0, -chi) - log_plus(0, chi0)) * spread
)
- trace(u"big loss direct log_lps", log_lps)
+ trace("big loss direct log_lps", log_lps)
else:
two_positive = log_plus(chi, 2 * chi0 - log_2)
two_negative = log_plus(chi0, 2 * chi - log_2)
if two_positive <= two_negative:
log_lps = log_minus(chi, chi0) + log_spread
- trace(u"small loss crude log_lps", log_lps)
+ trace("small loss crude log_lps", log_lps)
return log_lps
two = log_minus(two_positive, two_negative)
three_positive = log_plus(two_positive, 3 * chi - log_3)
@@ -286,11 +312,11 @@ class PLRsearch:
three = log_minus(three_positive, three_negative)
if two == three:
log_lps = two + log_spread
- trace(u"small loss approx log_lps", log_lps)
+ trace("small loss approx log_lps", log_lps)
else:
log_lps = math.log(log_plus(0, chi) - log_plus(0, chi0))
log_lps += log_spread
- trace(u"small loss direct log_lps", log_lps)
+ trace("small loss direct log_lps", log_lps)
return log_lps
@staticmethod
@@ -329,26 +355,26 @@ class PLRsearch:
# TODO: The stretch sign is just to have less minuses. Worth changing?
chi = (mrr - load) / spread
chi0 = mrr / spread
- trace(u"Erf: load", load)
- trace(u"mrr", mrr)
- trace(u"spread", spread)
- trace(u"chi", chi)
- trace(u"chi0", chi0)
+ trace("Erf: load", load)
+ trace("mrr", mrr)
+ trace("spread", spread)
+ trace("chi", chi)
+ trace("chi0", chi0)
if chi >= -1.0:
- trace(u"positive, b roughly bigger than m", None)
+ trace("positive, b roughly bigger than m", None)
if chi > math.exp(10):
first = PLRsearch.log_xerfcx_10 + 2 * (math.log(chi) - 10)
- trace(u"approximated first", first)
+ trace("approximated first", first)
else:
first = math.log(PLRsearch.xerfcx_limit - chi * erfcx(chi))
- trace(u"exact first", first)
+ trace("exact first", first)
first -= chi * chi
second = math.log(PLRsearch.xerfcx_limit - chi * erfcx(chi0))
second -= chi0 * chi0
intermediate = log_minus(first, second)
- trace(u"first", first)
+ trace("first", first)
else:
- trace(u"negative, b roughly smaller than m", None)
+ trace("negative, b roughly smaller than m", None)
exp_first = PLRsearch.xerfcx_limit + chi * erfcx(-chi)
exp_first *= math.exp(-chi * chi)
exp_first -= 2 * chi
@@ -359,17 +385,17 @@ class PLRsearch:
second = math.log(PLRsearch.xerfcx_limit - chi * erfcx(chi0))
second -= chi0 * chi0
intermediate = math.log(exp_first - math.exp(second))
- trace(u"exp_first", exp_first)
- trace(u"second", second)
- trace(u"intermediate", intermediate)
+ trace("exp_first", exp_first)
+ trace("second", second)
+ trace("intermediate", intermediate)
result = intermediate + math.log(spread) - math.log(erfc(-chi0))
- trace(u"result", result)
+ trace("result", result)
return result
@staticmethod
def find_critical_rate(
- trace, lfit_func, min_rate, max_rate, loss_ratio_target,
- mrr, spread):
+ trace, lfit_func, min_rate, max_rate, loss_ratio_target, mrr, spread
+ ):
"""Given ratio target and parameters, return the achieving offered load.
This is basically an inverse function to lfit_func
@@ -411,12 +437,12 @@ class PLRsearch:
loss_rate = math.exp(lfit_func(trace, rate, mrr, spread))
loss_ratio = loss_rate / rate
if loss_ratio > loss_ratio_target:
- trace(u"halving down", rate)
+ trace("halving down", rate)
rate_hi = rate
elif loss_ratio < loss_ratio_target:
- trace(u"halving up", rate)
+ trace("halving up", rate)
rate_lo = rate
- trace(u"found", rate)
+ trace("found", rate)
return rate
@staticmethod
@@ -441,7 +467,7 @@ class PLRsearch:
Instead, the expected average loss is scaled according to the number
of packets actually sent.
- TODO: Copy ReceiveRateMeasurement from MLRsearch.
+ TODO: Copy MeasurementResult from MLRsearch.
:param trace: A multiprocessing-friendly logging function (closure).
:param lfit_func: Fitting function, typically lfit_spread or lfit_erf.
@@ -450,40 +476,47 @@ class PLRsearch:
:param spread: The spread parameter for the fitting function.
:type trace: function (str, object) -> None
:type lfit_func: Function from 3 floats to float.
- :type trial_result_list: list of MLRsearch.ReceiveRateMeasurement
+ :type trial_result_list: list of MLRsearch.MeasurementResult
:type mrr: float
:type spread: float
:returns: Logarithm of result weight for given function and parameters.
:rtype: float
"""
log_likelihood = 0.0
- trace(u"log_weight for mrr", mrr)
- trace(u"spread", spread)
+ trace("log_weight for mrr", mrr)
+ trace("spread", spread)
for result in trial_result_list:
- trace(u"for tr", result.target_tr)
- trace(u"lc", result.loss_count)
- trace(u"d", result.duration)
- # _rel_ values use units of target_tr (transactions per second).
+ trace("for tr", result.intended_load)
+ trace("plc", result.plr_loss_count)
+ trace("d", result.intended_duration)
+ # _rel_ values use units of intended_load (transactions per second).
log_avg_rel_loss_per_second = lfit_func(
- trace, result.target_tr, mrr, spread
+ trace, result.intended_load, mrr, spread
)
# _abs_ values use units of loss count (maybe packets).
# There can be multiple packets per transaction.
log_avg_abs_loss_per_trial = log_avg_rel_loss_per_second + math.log(
- result.transmit_count / result.target_tr
+ result.offered_count / result.intended_load
)
# Geometric probability computation for logarithms.
log_trial_likelihood = log_plus(0.0, -log_avg_abs_loss_per_trial)
- log_trial_likelihood *= -result.loss_count
+ log_trial_likelihood *= -result.plr_loss_count
log_trial_likelihood -= log_plus(0.0, +log_avg_abs_loss_per_trial)
log_likelihood += log_trial_likelihood
- trace(u"avg_loss_per_trial", math.exp(log_avg_abs_loss_per_trial))
- trace(u"log_trial_likelihood", log_trial_likelihood)
+ trace("avg_loss_per_trial", math.exp(log_avg_abs_loss_per_trial))
+ trace("log_trial_likelihood", log_trial_likelihood)
return log_likelihood
def measure_and_compute(
- self, trial_duration, transmit_rate, trial_result_list,
- min_rate, max_rate, focus_trackers=(None, None), max_samples=None):
+ self,
+ trial_duration,
+ transmit_rate,
+ trial_result_list,
+ min_rate,
+ max_rate,
+ focus_trackers=(None, None),
+ max_samples=None,
+ ):
"""Perform both measurement and computation at once.
High level steps: Prepare and launch computation worker processes,
@@ -524,7 +557,7 @@ class PLRsearch:
:param max_samples: Limit for integrator samples, for debugging.
:type trial_duration: float
:type transmit_rate: float
- :type trial_result_list: list of MLRsearch.ReceiveRateMeasurement
+ :type trial_result_list: list of MLRsearch.MeasurementResult
:type min_rate: float
:type max_rate: float
:type focus_trackers: 2-tuple of None or stat_trackers.VectorStatTracker
@@ -572,7 +605,7 @@ class PLRsearch:
# See https://stackoverflow.com/questions/15137292/large-objects-and-multiprocessing-pipes-and-send
worker = multiprocessing.Process(
target=Integrator.try_estimate_nd,
- args=(worker_pipe_end, 10.0, self.trace_enabled)
+ args=(worker_pipe_end, 5.0, self.trace_enabled),
)
worker.daemon = True
worker.start()
@@ -616,8 +649,13 @@ class PLRsearch:
)
value = math.log(
self.find_critical_rate(
- trace, fitting_function, min_rate, max_rate,
- self.packet_loss_ratio_target, mrr, spread
+ trace,
+ fitting_function,
+ min_rate,
+ max_rate,
+ self.packet_loss_ratio_target,
+ mrr,
+ spread,
)
)
return value, logweight
@@ -664,14 +702,18 @@ class PLRsearch:
raise RuntimeError(f"Worker {name} did not finish!")
result_or_traceback = pipe.recv()
try:
- value_tracker, focus_tracker, debug_list, trace_list, sampls = (
- result_or_traceback
- )
- except ValueError:
+ (
+ value_tracker,
+ focus_tracker,
+ debug_list,
+ trace_list,
+ sampls,
+ ) = result_or_traceback
+ except ValueError as exc:
raise RuntimeError(
f"Worker {name} failed with the following traceback:\n"
f"{result_or_traceback}"
- )
+ ) from exc
logging.info(f"Logs from worker {name!r}:")
for message in debug_list:
logging.info(message)
@@ -682,8 +724,8 @@ class PLRsearch:
)
return _PartialResult(value_tracker, focus_tracker, sampls)
- stretch_result = stop_computing(u"stretch", stretch_pipe)
- erf_result = stop_computing(u"erf", erf_pipe)
+ stretch_result = stop_computing("stretch", stretch_pipe)
+ erf_result = stop_computing("erf", erf_pipe)
result = PLRsearch._get_result(measurement, stretch_result, erf_result)
logging.info(
f"measure_and_compute finished with trial result "
@@ -705,7 +747,7 @@ class PLRsearch:
:param measurement: The trial measurement obtained during computation.
:param stretch_result: Computation output for stretch fitting function.
:param erf_result: Computation output for erf fitting function.
- :type measurement: ReceiveRateMeasurement
+ :type measurement: MeasurementResult
:type stretch_result: _PartialResult
:type erf_result: _PartialResult
:returns: Combined results.
@@ -730,7 +772,7 @@ class PLRsearch:
# Named tuples, for multiple local variables to be passed as return value.
_PartialResult = namedtuple(
- u"_PartialResult", u"value_tracker focus_tracker samples"
+ "_PartialResult", "value_tracker focus_tracker samples"
)
"""Two stat trackers and sample counter.
@@ -743,8 +785,8 @@ _PartialResult = namedtuple(
"""
_ComputeResult = namedtuple(
- u"_ComputeResult",
- u"measurement avg stdev stretch_exp_avg erf_exp_avg trackers"
+ "_ComputeResult",
+ "measurement avg stdev stretch_exp_avg erf_exp_avg trackers",
)
"""Measurement, 4 computation result values, pair of trackers.
@@ -754,7 +796,7 @@ _ComputeResult = namedtuple(
:param stretch_exp_avg: Stretch fitting function estimate average exponentiated.
:param erf_exp_avg: Erf fitting function estimate average, exponentiated.
:param trackers: Pair of focus trackers to start next iteration with.
-:type measurement: ReceiveRateMeasurement
+:type measurement: MeasurementResult
:type avg: float
:type stdev: float
:type stretch_exp_avg: float
diff --git a/resources/libraries/python/PLRsearch/log_plus.py b/resources/libraries/python/PLRsearch/log_plus.py
index 8ede2909c6..aabefdb5be 100644
--- a/resources/libraries/python/PLRsearch/log_plus.py
+++ b/resources/libraries/python/PLRsearch/log_plus.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 Cisco and/or its affiliates.
+# Copyright (c) 2024 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
@@ -76,14 +76,14 @@ def log_minus(first, second):
:raises RuntimeError: If the difference would be non-positive.
"""
if first is None:
- raise RuntimeError(u"log_minus: does not support None first")
+ raise RuntimeError("log_minus: does not support None first")
if second is None:
return first
if second >= first:
- raise RuntimeError(u"log_minus: first has to be bigger than second")
+ raise RuntimeError("log_minus: first has to be bigger than second")
factor = -math.expm1(second - first)
if factor <= 0.0:
- msg = u"log_minus: non-positive number to log"
+ msg = "log_minus: non-positive number to log"
else:
return first + math.log(factor)
raise RuntimeError(msg)
diff --git a/resources/libraries/python/PLRsearch/stat_trackers.py b/resources/libraries/python/PLRsearch/stat_trackers.py
index e0b21dc3a9..e598fd840e 100644
--- a/resources/libraries/python/PLRsearch/stat_trackers.py
+++ b/resources/libraries/python/PLRsearch/stat_trackers.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 Cisco and/or its affiliates.
+# Copyright (c) 2024 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
@@ -64,8 +64,10 @@ class ScalarStatTracker:
:returns: Expression constructing an equivalent instance.
:rtype: str
"""
- return f"ScalarStatTracker(log_sum_weight={self.log_sum_weight!r}," \
+ return (
+ f"ScalarStatTracker(log_sum_weight={self.log_sum_weight!r},"
f"average={self.average!r},log_variance={self.log_variance!r})"
+ )
def copy(self):
"""Return new ScalarStatTracker instance with the same state as self.
@@ -110,7 +112,8 @@ class ScalarStatTracker:
if absolute_shift > 0.0:
log_square_shift = 2 * math.log(absolute_shift)
log_variance = log_plus(
- log_variance, log_square_shift + log_sample_ratio)
+ log_variance, log_square_shift + log_sample_ratio
+ )
if log_variance is not None:
log_variance += old_log_sum_weight - new_log_sum_weight
self.log_sum_weight = new_log_sum_weight
@@ -133,10 +136,17 @@ class ScalarDualStatTracker(ScalarStatTracker):
One typical use is for Monte Carlo integrator to decide whether
the partial sums so far are reliable enough.
"""
+
def __init__(
- self, log_sum_weight=None, average=0.0, log_variance=None,
- log_sum_secondary_weight=None, secondary_average=0.0,
- log_secondary_variance=None, max_log_weight=None):
+ self,
+ log_sum_weight=None,
+ average=0.0,
+ log_variance=None,
+ log_sum_secondary_weight=None,
+ secondary_average=0.0,
+ log_secondary_variance=None,
+ max_log_weight=None,
+ ):
"""Initialize new tracker instance, empty by default.
:param log_sum_weight: Natural logarithm of sum of weights
@@ -177,12 +187,14 @@ class ScalarDualStatTracker(ScalarStatTracker):
:rtype: str
"""
sec = self.secondary
- return f"ScalarDualStatTracker(log_sum_weight={self.log_sum_weight!r},"\
- f"average={self.average!r},log_variance={self.log_variance!r}," \
- f"log_sum_secondary_weight={sec.log_sum_weight!r}," \
- f"secondary_average={sec.average!r}," \
- f"log_secondary_variance={sec.log_variance!r}," \
+ return (
+ f"ScalarDualStatTracker(log_sum_weight={self.log_sum_weight!r},"
+ f"average={self.average!r},log_variance={self.log_variance!r},"
+ f"log_sum_secondary_weight={sec.log_sum_weight!r},"
+ f"secondary_average={sec.average!r},"
+ f"log_secondary_variance={sec.log_variance!r},"
f"max_log_weight={self.max_log_weight!r})"
+ )
def add(self, scalar_value, log_weight=0.0):
"""Return updated both stats after addition of another sample.
@@ -197,7 +209,7 @@ class ScalarDualStatTracker(ScalarStatTracker):
"""
# Using super() as copy() and add() are not expected to change
# signature, so this way diamond inheritance will be supported.
- primary = super(ScalarDualStatTracker, self)
+ primary = super()
if self.max_log_weight is None or log_weight >= self.max_log_weight:
self.max_log_weight = log_weight
self.secondary = primary.copy()
@@ -242,8 +254,12 @@ class VectorStatTracker:
"""
def __init__(
- self, dimension=2, log_sum_weight=None, averages=None,
- covariance_matrix=None):
+ self,
+ dimension=2,
+ log_sum_weight=None,
+ averages=None,
+ covariance_matrix=None,
+ ):
"""Initialize new tracker instance, two-dimensional empty by default.
If any of latter two arguments is None, it means
@@ -272,10 +288,12 @@ class VectorStatTracker:
:returns: Expression constructing an equivalent instance.
:rtype: str
"""
- return f"VectorStatTracker(dimension={self.dimension!r}," \
- f"log_sum_weight={self.log_sum_weight!r}," \
- f"averages={self.averages!r}," \
+ return (
+ f"VectorStatTracker(dimension={self.dimension!r},"
+ f"log_sum_weight={self.log_sum_weight!r},"
+ f"averages={self.averages!r},"
f"covariance_matrix={self.covariance_matrix!r})"
+ )
def copy(self):
"""Return new instance with the same state as self.
@@ -287,8 +305,10 @@ class VectorStatTracker:
:rtype: VectorStatTracker
"""
return VectorStatTracker(
- self.dimension, self.log_sum_weight, self.averages[:],
- copy.deepcopy(self.covariance_matrix)
+ self.dimension,
+ self.log_sum_weight,
+ self.averages[:],
+ copy.deepcopy(self.covariance_matrix),
)
def reset(self):