aboutsummaryrefslogtreecommitdiffstats
path: root/resources/libraries/python/PLRsearch/stat_trackers.py
diff options
context:
space:
mode:
Diffstat (limited to 'resources/libraries/python/PLRsearch/stat_trackers.py')
-rw-r--r--resources/libraries/python/PLRsearch/stat_trackers.py56
1 files changed, 38 insertions, 18 deletions
diff --git a/resources/libraries/python/PLRsearch/stat_trackers.py b/resources/libraries/python/PLRsearch/stat_trackers.py
index e0b21dc3a9..d19eebedb7 100644
--- a/resources/libraries/python/PLRsearch/stat_trackers.py
+++ b/resources/libraries/python/PLRsearch/stat_trackers.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 Cisco and/or its affiliates.
+# Copyright (c) 2024 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
@@ -64,8 +64,10 @@ class ScalarStatTracker:
:returns: Expression constructing an equivalent instance.
:rtype: str
"""
- return f"ScalarStatTracker(log_sum_weight={self.log_sum_weight!r}," \
+ return (
+ f"ScalarStatTracker(log_sum_weight={self.log_sum_weight!r},"
f"average={self.average!r},log_variance={self.log_variance!r})"
+ )
def copy(self):
"""Return new ScalarStatTracker instance with the same state as self.
@@ -110,7 +112,8 @@ class ScalarStatTracker:
if absolute_shift > 0.0:
log_square_shift = 2 * math.log(absolute_shift)
log_variance = log_plus(
- log_variance, log_square_shift + log_sample_ratio)
+ log_variance, log_square_shift + log_sample_ratio
+ )
if log_variance is not None:
log_variance += old_log_sum_weight - new_log_sum_weight
self.log_sum_weight = new_log_sum_weight
@@ -133,10 +136,17 @@ class ScalarDualStatTracker(ScalarStatTracker):
One typical use is for Monte Carlo integrator to decide whether
the partial sums so far are reliable enough.
"""
+
def __init__(
- self, log_sum_weight=None, average=0.0, log_variance=None,
- log_sum_secondary_weight=None, secondary_average=0.0,
- log_secondary_variance=None, max_log_weight=None):
+ self,
+ log_sum_weight=None,
+ average=0.0,
+ log_variance=None,
+ log_sum_secondary_weight=None,
+ secondary_average=0.0,
+ log_secondary_variance=None,
+ max_log_weight=None,
+ ):
"""Initialize new tracker instance, empty by default.
:param log_sum_weight: Natural logarithm of sum of weights
@@ -177,12 +187,14 @@ class ScalarDualStatTracker(ScalarStatTracker):
:rtype: str
"""
sec = self.secondary
- return f"ScalarDualStatTracker(log_sum_weight={self.log_sum_weight!r},"\
- f"average={self.average!r},log_variance={self.log_variance!r}," \
- f"log_sum_secondary_weight={sec.log_sum_weight!r}," \
- f"secondary_average={sec.average!r}," \
- f"log_secondary_variance={sec.log_variance!r}," \
+ return (
+ f"ScalarDualStatTracker(log_sum_weight={self.log_sum_weight!r},"
+ f"average={self.average!r},log_variance={self.log_variance!r},"
+ f"log_sum_secondary_weight={sec.log_sum_weight!r},"
+ f"secondary_average={sec.average!r},"
+ f"log_secondary_variance={sec.log_variance!r},"
f"max_log_weight={self.max_log_weight!r})"
+ )
def add(self, scalar_value, log_weight=0.0):
"""Return updated both stats after addition of another sample.
@@ -242,8 +254,12 @@ class VectorStatTracker:
"""
def __init__(
- self, dimension=2, log_sum_weight=None, averages=None,
- covariance_matrix=None):
+ self,
+ dimension=2,
+ log_sum_weight=None,
+ averages=None,
+ covariance_matrix=None,
+ ):
"""Initialize new tracker instance, two-dimensional empty by default.
If any of latter two arguments is None, it means
@@ -272,10 +288,12 @@ class VectorStatTracker:
:returns: Expression constructing an equivalent instance.
:rtype: str
"""
- return f"VectorStatTracker(dimension={self.dimension!r}," \
- f"log_sum_weight={self.log_sum_weight!r}," \
- f"averages={self.averages!r}," \
+ return (
+ f"VectorStatTracker(dimension={self.dimension!r},"
+ f"log_sum_weight={self.log_sum_weight!r},"
+ f"averages={self.averages!r},"
f"covariance_matrix={self.covariance_matrix!r})"
+ )
def copy(self):
"""Return new instance with the same state as self.
@@ -287,8 +305,10 @@ class VectorStatTracker:
:rtype: VectorStatTracker
"""
return VectorStatTracker(
- self.dimension, self.log_sum_weight, self.averages[:],
- copy.deepcopy(self.covariance_matrix)
+ self.dimension,
+ self.log_sum_weight,
+ self.averages[:],
+ copy.deepcopy(self.covariance_matrix),
)
def reset(self):