aboutsummaryrefslogtreecommitdiffstats
path: root/resources/libraries/python/MLRsearch/MeasurementDatabase.py
blob: 2f601d626000e0b51299c58c66570e0bb9a13cd1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright (c) 2021 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module defining MeasurementDatabase class."""

from .ReceiveRateInterval import ReceiveRateInterval
from .PerDurationDatabase import PerDurationDatabase


class MeasurementDatabase:
    """A structure holding measurement results.

    The implementation uses a dict from duration values
    to PerDurationDatabase instances.

    Several utility methods are added, accomplishing tasks useful for MLRsearch.

    This class contains the "find tightest bounds" parts of logic required
    by MLRsearch. One exception is lack of any special handling for maximal
    or minimal rates.
    """

    def __init__(self, measurements):
        """Store measurement results in per-duration databases.

        TODO: Move processing to a factory method,
        keep constructor only to store (presumably valid) values.

        If the measurements argument is a dict,
        the constructor assumes it contains the processed databases
        (duration -> PerDurationDatabase) and stores it as-is, without copying.

        The largest duration present becomes the current duration,
        the second largest becomes the previous duration.
        Either is None when there are not enough distinct durations.

        :param measurements: The measurement results to store,
            or an already processed mapping of per-duration databases.
        :type measurements: Iterable[ReceiveRateMeasurement]
            or Dict[float, PerDurationDatabase]
        """
        if isinstance(measurements, dict):
            self.data_for_duration = measurements
        else:
            self.data_for_duration = {}
            # TODO: There is overlap with add() code. Worth extracting?
            for measurement in measurements:
                duration = measurement.duration
                if duration in self.data_for_duration:
                    self.data_for_duration[duration].add(measurement)
                else:
                    self.data_for_duration[duration] = PerDurationDatabase(
                        duration, [measurement]
                    )
        durations = sorted(self.data_for_duration.keys())
        # Test the sorted list itself; the loop variable above is unbound
        # for dict input and for empty input.
        self.current_duration = durations[-1] if durations else None
        self.previous_duration = durations[-2] if len(durations) > 1 else None

    def __repr__(self):
        """Return string executable to get equivalent instance.

        :returns: Code to construct equivalent instance.
        :rtype: str
        """
        return f"MeasurementDatabase(measurements={self.data_for_duration!r})"

    def set_current_duration(self, duration):
        """Remember what MLRsearch considers the current duration.

        Setting the same duration is allowed (no-op), setting smaller is not.
        Setting a larger duration (or any duration when there is no current
        duration yet) shifts the old current duration to previous
        and creates a new, empty per-duration database for the new one.

        :param duration: Target trial duration of current phase, in seconds.
        :type duration: float
        :raises ValueError: If the duration is smaller than previous.
        """
        current = self.current_duration
        # An empty database has no current duration; comparing with None
        # would raise TypeError, so any duration is accepted then.
        if current is not None and duration < current:
            raise ValueError(
                f"Duration {duration} shorter than current duration"
                f" {self.current_duration}"
            )
        if current is None or duration > current:
            self.previous_duration = current
            self.current_duration = duration
            self.data_for_duration[duration] = PerDurationDatabase(
                duration, list()
            )
        # Else no-op.

    def add(self, measurement):
        """Add a measurement. Duration has to match the set one.

        :param measurement: Measurement result to add to the database.
        :type measurement: ReceiveRateMeasurement
        :raises ValueError: If the measurement duration differs
            from the current duration.
        """
        duration = measurement.duration
        if duration != self.current_duration:
            raise ValueError(
                f"{measurement!r} duration different than"
                f" {self.current_duration}"
            )
        self.data_for_duration[duration].add(measurement)

    def get_bounds(self, ratio):
        """Return 6 bounds: lower/upper, current/previous, tightest/second.

        Second tightest bounds are only returned for current duration.
        None instead of a measurement if there is no measurement of that type.

        The result contains bounds in this order:
        1. Tightest lower bound for current duration.
        2. Tightest upper bound for current duration.
        3. Tightest lower bound for previous duration.
        4. Tightest upper bound for previous duration.
        5. Second tightest lower bound for current duration.
        6. Second tightest upper bound for current duration.

        :param ratio: Target ratio, valid has to be lower or equal.
        :type ratio: float
        :returns: Measurements acting as various bounds.
        :rtype: 6-tuple of Optional[ReceiveRateMeasurement]
        """
        cur_lo1, cur_hi1, pre_lo, pre_hi, cur_lo2, cur_hi2 = [None] * 6
        duration = self.current_duration
        if duration is not None:
            data = self.data_for_duration[duration]
            cur_lo1, cur_hi1, cur_lo2, cur_hi2 = data.get_valid_bounds(ratio)
        duration = self.previous_duration
        if duration is not None:
            data = self.data_for_duration[duration]
            # Only the tightest bounds are needed for the previous duration.
            pre_lo, pre_hi, _, _ = data.get_valid_bounds(ratio)
        return cur_lo1, cur_hi1, pre_lo, pre_hi, cur_lo2, cur_hi2

    def get_results(self, ratio_list):
        """Return list of intervals for given ratios, from current duration.

        Attempt to construct valid intervals. If a valid bound is missing,
        use smallest/biggest target_tr for lower/upper bound.
        This can result in degenerate intervals.

        NOTE(review): the fallback takes the first/last element of the
        current per-duration measurements, which presumes that list is sorted
        by target_tr and non-empty - confirm in PerDurationDatabase.

        :param ratio_list: Ratios to create intervals for.
        :type ratio_list: Iterable[float]
        :returns: List of intervals.
        :rtype: List[ReceiveRateInterval]
        """
        ret_list = []
        current_data = self.data_for_duration[self.current_duration]
        for ratio in ratio_list:
            lower_bound, upper_bound, _, _, _, _ = self.get_bounds(ratio)
            if lower_bound is None:
                lower_bound = current_data.measurements[0]
            if upper_bound is None:
                upper_bound = current_data.measurements[-1]
            ret_list.append(ReceiveRateInterval(lower_bound, upper_bound))
        return ret_list