aboutsummaryrefslogtreecommitdiffstats
path: root/resources/libraries/python/MLRsearch/discrete_width.py
diff options
context:
space:
mode:
Diffstat (limited to 'resources/libraries/python/MLRsearch/discrete_width.py')
-rw-r--r--resources/libraries/python/MLRsearch/discrete_width.py197
1 files changed, 197 insertions, 0 deletions
diff --git a/resources/libraries/python/MLRsearch/discrete_width.py b/resources/libraries/python/MLRsearch/discrete_width.py
new file mode 100644
index 0000000000..8a4845a83f
--- /dev/null
+++ b/resources/libraries/python/MLRsearch/discrete_width.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2023 Cisco and/or its affiliates.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module defining DiscreteWidth class."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+from .load_rounding import LoadRounding
+
+
+# TODO: Make properly frozen.
+@dataclass(order=True)
+class DiscreteWidth:
+ """Structure to store float width together with its rounded integer form.
+
+ The width does not have to be positive, i.e. the computed integer width
+ does not have to be larger than zero.
+
+ LoadRounding instance is needed to enable conversion between two forms.
+
+ Conversion and arithmetic methods are added for convenience.
+ Division and non-integer multiplication are intentionally not supported,
+ as MLRsearch should not seek unround widths when round ones are available.
+
+ The instance is effectively immutable, but not hashable as it refers
+ to the rounding instance, which is implemented as mutable
+ (although the mutations are not visible).
+ """
+
+ # For most debugs, rounding in repr just takes space.
+ rounding: LoadRounding = field(repr=False, compare=False)
+ """Rounding instance to use for conversion."""
+ float_width: float = None
+ """Relative width of float intended load.
+ This is treated as a constructor argument, and does not need to match
+ the int width. Int width is computed to be no wider than this."""
+ int_width: int = field(compare=False, default=None)
+ """Integer form, difference of integer loads.
+ This is the primary quantity used by most computations."""
+
+ def __post_init__(self) -> None:
+ """Ensure types, compute missing information.
+
+ At this point, it is allowed for float width to be slightly larger
+ than the implied int width.
+
+ If both forms are specified, the float form is taken as primary
+ (thus the integer form is recomputed to match).
+
+ :raises RuntimeError: If both init arguments are None.
+ """
+ if self.float_width is None and self.int_width is None:
+ raise RuntimeError("Float or int value is needed.")
+ if self.float_width is None:
+ self.int_width = int(self.int_width)
+ min_load = self.rounding.int2float(0)
+ increased_load = self.rounding.int2float(self.int_width)
+ self.float_width = (increased_load - min_load) / increased_load
+ return
+ self.float_width = float(self.float_width)
+ min_load = self.rounding.int2float(0)
+ increased_load = min_load / (1.0 - self.float_width)
+ int_load = self.rounding.float2int(increased_load)
+ verify_load = self.rounding.int2float(int_load)
+ if verify_load > increased_load:
+ int_load -= 1
+ self.int_width = int_load
+
+ def __str__(self) -> str:
+ """Convert into a short human-readable string.
+
+ :returns: The short string.
+ :rtype: str
+ """
+ return f"int_width={int(self)}"
+
+ def __int__(self) -> int:
+ """Return the integer form.
+
+ :returns: The int field value.
+ :rtype: int
+ """
+ return self.int_width
+
+ def __float__(self) -> float:
+ """Return the float form.
+
+ :returns: The float field value.
+ :rtype: float
+ """
+ return self.float_width
+
+ def __hash__(self) -> int:
+ """Return a hash based on the float value.
+
+ With this, the instance can be used as if it was immutable and hashable,
+ e.g. it can be a key in a dict.
+
+ :returns: Hash value for this instance.
+ :rtype: int
+ """
+ return hash(float(self))
+
+ def rounded_down(self) -> DiscreteWidth:
+ """Create and return new instance with float form matching int.
+
+ :returns: New instance with same int form and float form rounded down.
+ :rtype: DiscreteWidth
+ """
+ return DiscreteWidth(rounding=self.rounding, int_width=int(self))
+
+ def __add__(self, width: DiscreteWidth) -> DiscreteWidth:
+ """Return newly constructed instance with int widths added.
+
+ Rounding instance (reference) is copied from self.
+
+ Argument type is checked, to avoid caller adding something unsupported.
+
+ :param width: Value to add to int width.
+ :type width: DiscreteWidth
+ :returns: New instance.
+ :rtype: DiscreteWidth
+ :raises RuntimeError: When argument has unexpected type.
+ """
+ if not isinstance(width, DiscreteWidth):
+ raise RuntimeError(f"Not width: {width!r}")
+ return DiscreteWidth(
+ rounding=self.rounding,
+ int_width=self.int_width + int(width),
+ )
+
+ def __sub__(self, width: DiscreteWidth) -> DiscreteWidth:
+ """Return newly constructed instance with int widths subtracted.
+
+ Rounding instance (reference) is copied from self.
+
+ Argument type is checked, to avoid caller adding something unsupported.
+ Non-positive results are disallowed by constructor.
+
+ :param width: Value to subtract to int width.
+ :type width: DiscreteWidth
+ :returns: New instance.
+ :rtype: DiscreteWidth
+ :raises RuntimeError: When argument has unexpected type.
+ """
+ if not isinstance(width, DiscreteWidth):
+ raise RuntimeError(f"Not width: {type(width)}")
+ return DiscreteWidth(
+ rounding=self.rounding,
+ int_width=self.int_width - int(width),
+ )
+
+ def __mul__(self, coefficient: int) -> DiscreteWidth:
+ """Construct new instance with int value multiplied.
+
+ Rounding instance (reference) is copied from self.
+
+ :param coefficient: Constant to multiply int width with.
+ :type coefficient: int
+ :returns: New instance with multiplied int width.
+ :rtype: DiscreteWidth
+ :raises RuntimeError: If argument value does not meet requirements.
+ """
+ if not isinstance(coefficient, int):
+ raise RuntimeError(f"Coefficient not int: {coefficient!r}")
+ if coefficient < 1:
+ raise RuntimeError(f"Coefficient not positive: {coefficient!r}")
+ return DiscreteWidth(
+ rounding=self.rounding,
+ int_width=self.int_width * coefficient,
+ )
+
+ def half_rounded_down(self) -> DiscreteWidth:
+ """Contruct new instance of half the integer width.
+
+ If the current integer width is odd, round the half width down.
+
+ :returns: New instance with half int width.
+ :rtype: DiscreteWidth
+ :raises RuntimeError: If the resulting integerl width is not positive.
+ """
+ return DiscreteWidth(
+ rounding=self.rounding,
+ int_width=self.int_width // 2,
+ )