aboutsummaryrefslogtreecommitdiffstats
path: root/netmodel/network/packet.py
diff options
context:
space:
mode:
Diffstat (limited to 'netmodel/network/packet.py')
-rw-r--r--netmodel/network/packet.py429
1 files changed, 429 insertions, 0 deletions
diff --git a/netmodel/network/packet.py b/netmodel/network/packet.py
new file mode 100644
index 00000000..9552b0e7
--- /dev/null
+++ b/netmodel/network/packet.py
@@ -0,0 +1,429 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2017 Cisco and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+import enum
+import logging
+import pickle
+import traceback
+import uuid
+
+from types import GeneratorType
+
+from netmodel.network.flow import Flow
+from netmodel.network.prefix import Prefix as Prefix_
+from netmodel.model.attribute import Attribute
+from netmodel.model.type import String
+from netmodel.model.query import Query
+from netmodel.model.filter import Filter as Filter_
+from netmodel.model.field_names import FieldNames as FieldNames_
+from netmodel.model.object import Object
+
+log = logging.getLogger(__name__)
+
+class NetmodelException(Exception):
+ pass
+
+#------------------------------------------------------------------------------
+
+NETMODEL_TLV_TYPELEN_STR = '!H'
+NETMODEL_TLV_SIZE = 2
+NETMODEL_TLV_TYPE_MASK = 0xfe00
+NETMODEL_TLV_TYPE_SHIFT = 9
+NETMODEL_TLV_LENGTH_MASK = 0x01ff
+
+NETMODEL_TLV_OBJECT_NAME = 1
+NETMODEL_TLV_FIELD = 2
+NETMODEL_TLV_FIELDS = 3
+NETMODEL_TLV_PREDICATE = 4
+NETMODEL_TLV_FILTER = 5
+NETMODEL_TLV_SRC = 6
+NETMODEL_TLV_DST = 7
+NETMODEL_TLV_PROTOCOL = 8
+NETMODEL_TLV_FLAGS = 9
+NETMODEL_TLV_PAYLOAD = 10
+NETMODEL_TLV_PACKET = 11
+
+SUCCESS = 0
+WARNING = 1
+ERROR = 2
+
+FLAG_LAST = 1 << 0
+FLAG_REPLY = 1 << 1
+
+#------------------------------------------------------------------------------
+
+class PacketProtocol(enum.Enum):
+ Query = 'query'
+ Error = 'error'
+
+#------------------------------------------------------------------------------
+
+class VICNTLV:
+
+ _LEN_MIN = 0
+ _LEN_MAX = 511
+ tlv_type = None
+
+ _tlv_parsers = {} # Attributes...
+ tlvs = []
+
+ def decode(self, buf):
+ (self.typelen, ) = struct.unpack(
+ NETMODEL_TLV_TYPELEN_STR, buf[:NETMODEL_TLV_SIZE])
+ tlv_type = \
+ (self.typelen & NETMODEL_TLV_TYPE_MASK) >> NETMODEL_TLV_TYPE_SHIFT
+ assert self.tlv_type == tlv_type
+
+ self.len = self.typelen & NETMODEL_TLV_LENGTH_MASK
+ assert len(buf) >= self.len + NETMODEL_TLV_SIZE
+
+ self.tlv_info = buf[NETMODEL_TLV_SIZE:]
+ self.tlv_info = self.tlv_info[:self.len]
+
+ #--------------------------------------------------------------------------
+ # Descriptor protocol
+ #--------------------------------------------------------------------------
+
+ @classmethod
+ def _get_tlv_parsers(cls):
+ if not cls._tlv_parsers:
+ cls._tlv_parsers = None
+ return cls._tlv_parsers
+
+
+ @staticmethod
+ def get_type(buf):
+ (typelen, ) = struct.unpack(NETMODEL_TLV_TYPELEN_STR,
+ buf[:NETMODEL_TLV_SIZE])
+ return (typelen & NETMODEL_TLV_TYPE_MASK) >> NETMODEL_TLV_TYPE_SHIFT
+
+
+ def _len_valid(self):
+ return self._LEN_MIN <= self.len and self.len <= self._LEN_MAX
+
+ #--------------------------------------------------------------------------
+
+ @classmethod
+ def _parser(cls, buf):
+ tlvs = []
+
+ while buf:
+ tlv_type = VICNTLV.get_type(buf)
+ tlv = cls._tlv_parsers[tlv_type](buf)
+ tlvs.append(tlv)
+ offset = NETMODEL_TLV_SIZE + tlv.len
+ buf = buf[offset:]
+ if tlv.tlv_type == NETMODEL_TLV_END:
+ break
+ assert len(buf) > 0
+
+ pkt = cls(tlvs)
+
+ assert pkt._tlvs_len_valid()
+ assert pkt._tlvs_valid()
+
+ return pkt, None, buf
+
+ @classmethod
+ def parser(cls, buf):
+ try:
+ return cls._parser(buf)
+ except:
+ return None, None, buf
+
+ def serialize(self, payload, prev):
+ data = bytearray()
+ for tlv in self.tlvs:
+ data += tlv.serialize()
+
+ return data
+
+ @classmethod
+ def set_type(cls, tlv_cls):
+ cls._tlv_parsers[tlv_cls.tlv_type] = tlv_cls
+
+ @classmethod
+ def get_type(cls, tlv_type):
+ return cls._tlv_parsers[tlv_type]
+
+ @classmethod
+ def set_tlv_type(cls, tlv_type):
+ def _set_type(tlv_cls):
+ tlv_cls.tlv_type = tlv_type
+ #cls.set_type(tlv_cls)
+ return tlv_cls
+ return _set_type
+
+ def __len__(self):
+ return sum(NETMODEL_TLV_SIZE + tlv.len for tlv in self.tlvs)
+
+#------------------------------------------------------------------------------
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_OBJECT_NAME)
+class ObjectName(VICNTLV): pass
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_FIELD)
+class Field(VICNTLV):
+ """Field == STR
+ """
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_PREDICATE)
+class Predicate(VICNTLV):
+ """Predicate == key, op, value
+ """
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_FILTER)
+class Filter(Filter_, VICNTLV):
+ """Filter == Array<Predicates>
+ """
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_FIELDS)
+class FieldNames(FieldNames_, VICNTLV):
+ """Fields == Array<Field>
+ """
+
+
+class Prefix(Object, Prefix_, VICNTLV):
+ object_name = ObjectName()
+ filter = Filter()
+ field_names = FieldNames()
+
+ def __init__(self, *args, **kwargs):
+ Object.__init__(self)
+ Prefix_.__init__(self, *args, **kwargs)
+ VICNTLV.__init__(self)
+
+ def get_tuple(self):
+ return (self.object_name, self.filter, self.field_names)
+
+ def __eq__(self, other):
+ return self.get_tuple() == other.get_tuple()
+
+ def __hash__(self):
+ return hash(self.get_tuple())
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_SRC)
+class Source(Prefix):
+ """Source address
+ """
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_DST)
+class Destination(Prefix):
+ """Destination address
+ """
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_PROTOCOL)
+class Protocol(Attribute, VICNTLV):
+ """Protocol
+ """
+
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_FLAGS)
+class Flags(Attribute, VICNTLV):
+ """Flags: last, ...
+ """
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_PAYLOAD)
+class Payload(Attribute, VICNTLV):
+ """Payload
+ """
+
+@VICNTLV.set_tlv_type(NETMODEL_TLV_PACKET)
+class Packet(Object, VICNTLV):
+ """Base packet class
+ """
+ source = Source()
+ destination = Destination(Prefix)
+ protocol = Protocol(String, default = 'query')
+ flags = Flags()
+ payload = Payload()
+
+ # This should be dispatched across L3 L4 L7
+
+ def __init__(self, source = None, destination = None, protocol = None,
+ flags = 0, payload = None):
+ self.source = source
+ self.destination = destination
+ self.protocol = protocol
+ self.flags = flags
+ self.payload = payload
+
+ def get_flow(self):
+ return Flow(self.source, self.destination)
+
+ @staticmethod
+ def from_query(query, src_query = None, reply = False):
+ packet = Packet()
+ if src_query:
+ address = Prefix(
+ object_name = src_query.object_name,
+ filter = src_query.filter,
+ field_names = src_query.field_names,
+ aggregate = src_query.aggregate)
+ if reply:
+ packet.destination = address
+ else:
+ packet.source = address
+
+ if query:
+ address = Prefix(
+ object_name = query.object_name,
+ filter = query.filter,
+ field_names = query.field_names,
+ aggregate = query.aggregate)
+
+ if reply:
+ packet.source = address
+ else:
+ packet.destination = address
+
+ packet.payload = (query.action, query.params)
+
+ packet.protocol = 'sync'
+ packet.last = not query or query.last
+ packet.reply = reply
+
+ return packet
+
+ def to_query(self):
+ action, params = self.payload
+
+ address = self.source if self.reply else self.destination
+ object_name = address.object_name
+ filter = address.filter
+ field_names = address.field_names
+ aggregate = address.aggregate
+
+ return Query(action, object_name, filter, params, field_names,
+ aggregate = aggregate, last = self.last, reply = self.reply)
+
+ @property
+ def last(self):
+ return self.flags & FLAG_LAST
+
+ @last.setter
+ def last(self, last):
+ if last:
+ self.flags |= FLAG_LAST
+ else:
+ self.flags &= ~FLAG_LAST
+
+ @property
+ def reply(self):
+ return self.flags & FLAG_REPLY
+
+ @reply.setter
+ def reply(self, reply):
+ if reply:
+ self.flags |= FLAG_REPLY
+ else:
+ self.flags &= ~FLAG_REPLY
+
+class ErrorPacket(Packet):
+ """
+ Analog with ICMP errors packets in IP networks
+ """
+
+ #--------------------------------------------------------------------------
+ # Constructor
+ #--------------------------------------------------------------------------
+
+ def __init__(self, type = ERROR, code = ERROR, message = None,
+ traceback = None, **kwargs):
+ assert not traceback or isinstance(traceback, str)
+
+ Packet.__init__(self, **kwargs)
+ self.protocol = PacketProtocol.Error
+ self.last = True
+ self._type = type
+ self._code = code
+ self._message = message
+ self._traceback = traceback
+
+ #--------------------------------------------------------------------------
+ # Static methods
+ #--------------------------------------------------------------------------
+
+ @staticmethod
+ def from_exception(packet, e):
+ if isinstance(e, NetmodelException):
+ error_packet = ErrorPacket(
+ type = e.TYPE, # eg. ERROR
+ code = e.CODE, # eg. BADARGS
+ message = str(e), #e.message,
+ traceback = traceback.format_exc(),
+ last = True
+ )
+ else:
+ error_packet = ErrorPacket(
+ type = ERROR,
+ code = UNHANDLED_EXCEPTION,
+ message = str(e),
+ traceback = traceback.format_exc(),
+ last = True
+ )
+ error_packet.set_source(packet.get_destination())
+ error_packet.set_destination(packet.get_source())
+ return error_packet
+
+ def get_message(self):
+ """
+ Returns:
+ The error message related to this ErrorPacket.
+ """
+ return self._message
+
+ def get_traceback(self):
+ """
+ Returns:
+ The traceback related to this ErrorPacket.
+ """
+ return self._traceback
+
+ def get_origin(self):
+ """
+ Returns:
+ A value among {code::CORE, code::GATEWAY}
+ identifying who is the origin of this ErrorPacket.
+ """
+ return self._origin
+
+ def get_code(self):
+ """
+ Returns:
+ The error code of the Error carried by this ErrorPacket.
+ """
+ return self._code
+
+ def get_type(self):
+ """
+ Returns:
+ The error type of the Error carried by this ErrorPacket.
+ """
+ return self._type
+
+ def __repr__(self):
+ """
+ Returns:
+ The '%r' representation of this ERROR Packet.
+ """
+ return "<Packet.%s: %s>" % (
+ Packet.get_protocol_name(self.get_protocol()),
+ self.get_message()
+ )