aboutsummaryrefslogtreecommitdiffstats
path: root/netmodel/model/type.py
diff options
context:
space:
mode:
Diffstat (limited to 'netmodel/model/type.py')
-rw-r--r--netmodel/model/type.py185
1 files changed, 177 insertions, 8 deletions
diff --git a/netmodel/model/type.py b/netmodel/model/type.py
index 20dc2580..9a7b8740 100644
--- a/netmodel/model/type.py
+++ b/netmodel/model/type.py
@@ -16,27 +16,47 @@
# limitations under the License.
#
+from socket import inet_pton, inet_ntop, AF_INET6
+from struct import unpack, pack
+from abc import ABCMeta
+
from netmodel.util.meta import inheritors
class BaseType:
+ __choices__ = None
+
@staticmethod
def name():
return self.__class__.__name__.lower()
+ @classmethod
+ def restrict(cls, **kwargs):
+ class BaseType(cls):
+ __choices__ = kwargs.pop('choices', None)
+ return BaseType
+
class String(BaseType):
- def __init__(self, *args, **kwargs):
- self._min_size = kwargs.pop('min_size', None)
- self._max_size = kwargs.pop('max_size', None)
- self._ascii = kwargs.pop('ascii', False)
- self._forbidden = kwargs.pop('forbidden', None)
- super().__init__()
+ __min_size__ = None
+ __max_size__ = None
+ __ascii__ = None
+ __forbidden__ = None
+
+ @classmethod
+ def restrict(cls, **kwargs):
+ base = super().restrict(**kwargs)
+ class String(base):
+ __max_size__ = kwargs.pop('max_size', None)
+ __min_size__ = kwargs.pop('min_size', None)
+ __ascii__ = kwargs.pop('ascii', None)
+ __forbidden__ = kwargs.pop('forbidden', None)
+ return String
class Integer(BaseType):
def __init__(self, *args, **kwargs):
self._min_value = kwargs.pop('min_value', None)
self._max_value = kwargs.pop('max_value', None)
super().__init__()
-
+
class Double(BaseType):
def __init__(self, *args, **kwargs):
self._min_value = kwargs.pop('min_value', None)
@@ -49,12 +69,161 @@ class Bool(BaseType):
class Dict(BaseType):
pass
+class PrefixTreeException(Exception): pass
+class NotEnoughAddresses(PrefixTreeException): pass
+class UnassignablePrefix(PrefixTreeException): pass
+
+class Prefix(BaseType, metaclass=ABCMeta):
+
+ def __init__(self, ip_address, prefix_len=None):
+ if not prefix_len:
+ if not isinstance(ip_address, str):
+ import pdb; pdb.set_trace()
+ if '/' in ip_address:
+ ip_address, prefix_len = ip_address.split('/')
+ prefix_len = int(prefix_len)
+ else:
+ prefix_len = self.MAX_PREFIX_SIZE
+ if isinstance(ip_address, str):
+ ip_address = self.aton(ip_address)
+ self.ip_address = ip_address
+ self.prefix_len = prefix_len
+
+ def __contains__(self, obj):
+ #it can be an IP as a integer
+ if isinstance(obj, int):
+ obj = type(self)(obj, self.MAX_PREFIX_SIZE)
+ #Or it's an IP string
+ if isinstance(obj, str):
+ #It's a prefix as 'IP/prefix'
+ if '/' in obj:
+ split_obj = obj.split('/')
+ obj = type(self)(split_obj[0], int(split_obj[1]))
+ else:
+ obj = type(self)(obj, self.MAX_PREFIX_SIZE)
+
+ return self._contains_prefix(obj)
+
+ @classmethod
+ def mask(cls):
+ mask_len = cls.MAX_PREFIX_SIZE//8 #Converts from bits to bytes
+ mask = 0
+ for step in range(0,mask_len):
+ mask = (mask << 8) | 0xff
+ return mask
+
+ def _contains_prefix(self, prefix):
+ assert isinstance(prefix, type(self))
+ return (prefix.prefix_len >= self.prefix_len and
+ prefix.ip_address >= self.first_prefix_address() and
+ prefix.ip_address <= self.last_prefix_address())
+
+ #Returns the first address of a prefix
+ def first_prefix_address(self):
+ return self.ip_address & (self.mask() << (self.MAX_PREFIX_SIZE-self.prefix_len))
+
+ def canonical_prefix(self):
+ return type(self)(self.first_prefix_address(), self.prefix_len)
+
+ def last_prefix_address(self):
+ return self.ip_address | (self.mask() >> self.prefix_len)
+
+ def limits(self):
+ return self.first_prefix_address(), self.last_prefix_address()
+
+ def __str__(self):
+ return "{}/{}".format(self.ntoa(self.first_prefix_address()), self.prefix_len)
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return False
+ return self.get_tuple() == other.get_tuple()
+
+ def get_tuple(self):
+ return (self.first_prefix_address(), self.prefix_len)
+
+ def __hash__(self):
+ return hash(self.get_tuple())
+
+ def __iter__(self):
+ return self.get_iterator()
+
+ #Iterates by steps of prefix_len, e.g., on all available /31 in a /24
+ def get_iterator(self, prefix_len=None):
+ if prefix_len is None:
+ prefix_len=self.MAX_PREFIX_SIZE
+ assert (prefix_len >= self.prefix_len and prefix_len<=self.MAX_PREFIX_SIZE)
+ step = 2**(self.MAX_PREFIX_SIZE - prefix_len)
+ for ip in range(self.first_prefix_address(), self.last_prefix_address()+1, step):
+ yield type(self)(ip, prefix_len)
+
+class Inet4Prefix(Prefix):
+
+ MAX_PREFIX_SIZE = 32
+
+ @classmethod
+ def aton(cls, address):
+ ret = 0
+ components = address.split('.')
+ for comp in components:
+ ret = (ret << 8) + int(comp)
+ return ret
+
+ @classmethod
+ def ntoa(cls, address):
+ components = []
+ for _ in range(0,4):
+ components.insert(0,'{}'.format(address % 256))
+ address = address >> 8
+ return '.'.join(components)
+
+class Inet6Prefix(Prefix):
+
+ MAX_PREFIX_SIZE = 128
+
+ @classmethod
+ def aton (cls, address):
+ prefix, suffix = unpack(">QQ", inet_pton(AF_INET6, address))
+ return (prefix << 64) | suffix
+
+ @classmethod
+ def ntoa (cls, address):
+ return inet_ntop(AF_INET6, pack(">QQ", address >> 64, address & ((1 << 64) -1)))
+
+ #skip_internet_address: skip a:b::0, as v6 often use default /64 prefixes
+ def get_iterator(self, prefix_len=None, skip_internet_address=None):
+ if skip_internet_address is None:
+ #We skip the internet address if we iterate over Addresses
+ if prefix_len is None:
+ skip_internet_address = True
+ #But not if we iterate over prefixes
+ else:
+ skip_internet_address = False
+ it = super().get_iterator(prefix_len)
+ if skip_internet_address:
+ next(it)
+ return it
+
+class InetAddress(Prefix):
+
+ def get_tuple(self):
+ return (self.ip_address, self.prefix_len)
+
+ def __str__(self):
+ return self.ntoa(self.ip_address)
+
+class Inet4Address(InetAddress, Inet4Prefix):
+ pass
+
+class Inet6Address(InetAddress, Inet6Prefix):
+ pass
+
class Self(BaseType):
"""Self-reference
"""
class Type:
- BASE_TYPES = (String, Integer, Double, Bool)
+ BASE_TYPES = (String, Integer, Double, Bool)
_registry = dict()
@staticmethod