aboutsummaryrefslogtreecommitdiffstats
path: root/resources/libraries/python/PLRsearch/Integrator.py
diff options
context:
space:
mode:
Diffstat (limited to 'resources/libraries/python/PLRsearch/Integrator.py')
-rw-r--r--resources/libraries/python/PLRsearch/Integrator.py59
1 files changed, 35 insertions, 24 deletions
diff --git a/resources/libraries/python/PLRsearch/Integrator.py b/resources/libraries/python/PLRsearch/Integrator.py
index a7a59391ed..cc8f838fe6 100644
--- a/resources/libraries/python/PLRsearch/Integrator.py
+++ b/resources/libraries/python/PLRsearch/Integrator.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 Cisco and/or its affiliates.
+# Copyright (c) 2024 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
@@ -189,12 +189,15 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
:raises numpy.linalg.LinAlgError: If the focus shape gets singular
(due to rounding errors). Try changing scale_coeff.
"""
- debug_list = list()
- trace_list = list()
+ debug_list = []
+ trace_list = []
# Block until input object appears.
- dimension, dilled_function, param_focus_tracker, max_samples = (
- communication_pipe.recv()
- )
+ (
+ dimension,
+ dilled_function,
+ param_focus_tracker,
+ max_samples,
+ ) = communication_pipe.recv()
debug_list.append(
f"Called with param_focus_tracker {param_focus_tracker!r}"
)
@@ -237,39 +240,47 @@ def estimate_nd(communication_pipe, scale_coeff=8.0, trace_enabled=False):
if max_samples and samples >= max_samples:
break
sample_point = generate_sample(
- param_focus_tracker.averages, param_focus_tracker.covariance_matrix,
- dimension, scale_coeff
+ param_focus_tracker.averages,
+ param_focus_tracker.covariance_matrix,
+ dimension,
+ scale_coeff,
)
- trace(u"sample_point", sample_point)
+ trace("sample_point", sample_point)
samples += 1
- trace(u"samples", samples)
+ trace("samples", samples)
value, log_weight = value_logweight_function(trace, *sample_point)
- trace(u"value", value)
- trace(u"log_weight", log_weight)
- trace(u"focus tracker before adding", param_focus_tracker)
+ trace("value", value)
+ trace("log_weight", log_weight)
+ trace("focus tracker before adding", param_focus_tracker)
# Update focus related statistics.
param_distance = param_focus_tracker.add_without_dominance_get_distance(
sample_point, log_weight
)
# The code above looked at weight (not importance).
# The code below looks at importance (not weight).
- log_rarity = param_distance / 2.0
- trace(u"log_rarity", log_rarity)
+ log_rarity = param_distance / 2.0 / scale_coeff
+ trace("log_rarity", log_rarity)
log_importance = log_weight + log_rarity
- trace(u"log_importance", log_importance)
+ trace("log_importance", log_importance)
value_tracker.add(value, log_importance)
# Update sampled statistics.
param_sampled_tracker.add_get_shift(sample_point, log_importance)
debug_list.append(f"integrator used {samples!s} samples")
debug_list.append(
- u" ".join([
- u"value_avg", str(value_tracker.average),
- u"param_sampled_avg", repr(param_sampled_tracker.averages),
- u"param_sampled_cov", repr(param_sampled_tracker.covariance_matrix),
- u"value_log_variance", str(value_tracker.log_variance),
- u"value_log_secondary_variance",
- str(value_tracker.secondary.log_variance)
- ])
+ " ".join(
+ [
+ "value_avg",
+ str(value_tracker.average),
+ "param_sampled_avg",
+ repr(param_sampled_tracker.averages),
+ "param_sampled_cov",
+ repr(param_sampled_tracker.covariance_matrix),
+ "value_log_variance",
+ str(value_tracker.log_variance),
+ "value_log_secondary_variance",
+ str(value_tracker.secondary.log_variance),
+ ]
+ )
)
communication_pipe.send(
(value_tracker, param_focus_tracker, debug_list, trace_list, samples)