diff options
Diffstat (limited to 'resources/libraries/python/MLRsearch/load_rounding.py')
-rw-r--r-- | resources/libraries/python/MLRsearch/load_rounding.py | 205 |
1 files changed, 205 insertions, 0 deletions
diff --git a/resources/libraries/python/MLRsearch/load_rounding.py b/resources/libraries/python/MLRsearch/load_rounding.py new file mode 100644 index 0000000000..0ac4487be9 --- /dev/null +++ b/resources/libraries/python/MLRsearch/load_rounding.py @@ -0,0 +1,205 @@ +# Copyright (c) 2023 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module defining LoadRounding class.""" + +import math + +from dataclasses import dataclass +from typing import List, Tuple + +from .dataclass import secondary_field + + +@dataclass +class LoadRounding: + """Class encapsulating stateful utilities that round intended load values. + + For MLRsearch algorithm logic to be correct, it is important that + interval width expansion and narrowing are exactly reversible, + which is not true in general for floating point number arithmetics. + + This class offers conversion to and from an integer quantity. + Operations in the integer realm are guaranteed to be reversible, + so the only risk is when converting between float and integer realm. + + Which relative width corresponds to the unit integer + is computed in initialization from width goals, + striking a balance between memory requirements and precision. + + There are two quality knobs. One restricts how far + can an integer be from the exact float value. + The other restrict how close it can be. That is to make sure + even with unpredictable rounding errors during the conversion, + the converted integer value is never bigger than the intended float value, + to ensure the intervals returned from MLRsearch will always + meet the relative width goal. + + An instance of this class is mutable only in the sense it contains + a growing cache of previously computed values. + """ + + # TODO: Hide the cache and present as frozen hashable object. + + min_load: float + """Minimal intended load [tps] to support, must be positive.""" + max_load: float + """Maximal intended load [tps] to support, must be bigger than min load.""" + float_goals: Tuple[float] + """Relative width goals to approximate, each must be positive + and smaller than one. Deduplicated and sorted in post init.""" + quality_lower: float = 0.99 + """Minimal multiple of each goal to be achievable.""" + quality_upper: float = 0.999999 + """Maximal multiple of each goal to be achievable.""" + # Primary fields above, computed fields below. + max_int_load: int = secondary_field() + """Integer for max load (min load int is zero).""" + _int2load: List[Tuple[int, float]] = secondary_field() + """Known int values (sorted) and their float equivalents.""" + + def __post_init__(self) -> None: + """Ensure types, perform checks, initialize conversion structures. + + :raises RuntimeError: If a requirement is not met. + """ + self.min_load = float(self.min_load) + self.max_load = float(self.max_load) + if not 0.0 < self.min_load < self.max_load: + raise RuntimeError("Load limits not supported: {self}") + self.quality_lower = float(self.quality_lower) + self.quality_upper = float(self.quality_upper) + if not 0.0 < self.quality_lower < self.quality_upper < 1.0: + raise RuntimeError("Qualities not supported: {self}") + goals = [] + for goal in self.float_goals: + goal = float(goal) + if not 0.0 < goal < 1.0: + raise RuntimeError(f"Goal width {goal} is not supported.") + goals.append(goal) + self.float_goals = tuple(sorted(set(goals))) + self.max_int_load = self._find_ints() + self._int2load = [] + self._int2load.append((0, self.min_load)) + self._int2load.append((self.max_int_load, self.max_load)) + + def _find_ints(self) -> int: + """Find and return value for max_int_load. + + Separated out of post init, as this is less conversion and checking, + and more math and searching. + + A dumb implementation would start with 1 and kept increasing by 1 + until all goals are within quality limits. + An actual implementation is smarter with the increment, + so it is expected to find the resulting values somewhat faster. + + :returns: Value to be stored as max_int_load. + :rtype: int + """ + minmax_log_width = math.log(self.max_load) - math.log(self.min_load) + log_goals = [-math.log1p(-goal) for goal in self.float_goals] + candidate = 1 + while 1: + log_width_unit = minmax_log_width / candidate + # Fallback to increment by one if rounding errors make tries bad. + next_tries = [candidate + 1] + acceptable = True + for log_goal in log_goals: + units = log_goal / log_width_unit + int_units = math.floor(units) + quality = int_units / units + if not self.quality_lower <= quality <= self.quality_upper: + acceptable = False + target = (int_units + 1) / self.quality_upper + next_try = (target / units) * candidate + next_tries.append(next_try) + # Else quality acceptable, not bumping the candidate. + if acceptable: + return candidate + candidate = int(math.ceil(max(next_tries))) + + def int2float(self, int_load: int) -> float: + """Convert from int to float tps load. Expand internal table as needed. + + Too low or too high ints result in min or max load respectively. + + :param int_load: Integer quantity to turn back into float load. + :type int_load: int + :returns: Converted load in tps. + :rtype: float + :raises RuntimeError: If internal inconsistency is detected. + """ + if int_load <= 0: + return self.min_load + if int_load >= self.max_int_load: + return self.max_load + lo_index, hi_index = 0, len(self._int2load) + lo_int, hi_int = 0, self.max_int_load + lo_load, hi_load = self.min_load, self.max_load + while hi_int - lo_int >= 2: + mid_index = (hi_index + lo_index + 1) // 2 + if mid_index >= hi_index: + mid_int = (hi_int + lo_int) // 2 + log_coeff = math.log(hi_load) - math.log(lo_load) + log_coeff *= (mid_int - lo_int) / (hi_int - lo_int) + mid_load = lo_load * math.exp(log_coeff) + self._int2load.insert(mid_index, (mid_int, mid_load)) + hi_index += 1 + mid_int, mid_load = self._int2load[mid_index] + if mid_int < int_load: + lo_index, lo_int, lo_load = mid_index, mid_int, mid_load + continue + if mid_int > int_load: + hi_index, hi_int, hi_load = mid_index, mid_int, mid_load + continue + return mid_load + raise RuntimeError("Bisect in int2float failed.") + + def float2int(self, float_load: float) -> int: + """Convert and round from tps load to int. Maybe expand internal table. + + Too low or too high load result in zero or max int respectively. + + Result value is rounded down to an integer. + + :param float_load: Tps quantity to convert into int. + :type float_load: float + :returns: Converted integer value suitable for halving. + :rtype: int + """ + if float_load <= self.min_load: + return 0 + if float_load >= self.max_load: + return self.max_int_load + lo_index, hi_index = 0, len(self._int2load) + lo_int, hi_int = 0, self.max_int_load + lo_load, hi_load = self.min_load, self.max_load + while hi_int - lo_int >= 2: + mid_index = (hi_index + lo_index + 1) // 2 + if mid_index >= hi_index: + mid_int = (hi_int + lo_int) // 2 + log_coeff = math.log(hi_load) - math.log(lo_load) + log_coeff *= (mid_int - lo_int) / (hi_int - lo_int) + mid_load = lo_load * math.exp(log_coeff) + self._int2load.insert(mid_index, (mid_int, mid_load)) + hi_index += 1 + mid_int, mid_load = self._int2load[mid_index] + if mid_load < float_load: + lo_index, lo_int, lo_load = mid_index, mid_int, mid_load + continue + if mid_load > float_load: + hi_index, hi_int, hi_load = mid_index, mid_int, mid_load + continue + return mid_int + return lo_int |