aboutsummaryrefslogtreecommitdiffstats
path: root/netmodel
diff options
context:
space:
mode:
Diffstat (limited to 'netmodel')
-rw-r--r--netmodel/interfaces/process/__init__.py49
-rw-r--r--netmodel/interfaces/vicn.py39
-rw-r--r--netmodel/interfaces/vpp/__init__.py225
-rw-r--r--netmodel/interfaces/websocket/__init__.py38
-rw-r--r--netmodel/model/attribute.py120
-rw-r--r--netmodel/model/collection.py5
-rw-r--r--netmodel/model/key.py19
-rw-r--r--netmodel/model/mapper.py4
-rw-r--r--netmodel/model/object.py126
-rw-r--r--netmodel/model/query.py13
-rw-r--r--netmodel/model/sa_collections.py265
-rw-r--r--netmodel/model/sa_compat.py270
-rw-r--r--netmodel/model/type.py185
-rw-r--r--netmodel/model/uuid.py51
-rw-r--r--netmodel/network/fib.py10
-rw-r--r--netmodel/network/flow_table.py145
-rw-r--r--netmodel/network/interface.py20
-rw-r--r--netmodel/network/packet.py34
-rw-r--r--netmodel/network/prefix.py13
-rw-r--r--netmodel/network/router.py35
-rw-r--r--netmodel/util/daemon.py2
21 files changed, 1429 insertions, 239 deletions
diff --git a/netmodel/interfaces/process/__init__.py b/netmodel/interfaces/process/__init__.py
index b985c32f..59bf6f9f 100644
--- a/netmodel/interfaces/process/__init__.py
+++ b/netmodel/interfaces/process/__init__.py
@@ -16,6 +16,7 @@
# limitations under the License.
#
+import logging
import shlex
import socket
import subprocess
@@ -27,10 +28,13 @@ from netmodel.network.prefix import Prefix
from netmodel.model.attribute import Attribute
from netmodel.model.filter import Filter
from netmodel.model.object import Object
-from netmodel.model.query import Query, ACTION_UPDATE
-from netmodel.model.query import ACTION_SUBSCRIBE, FUNCTION_SUM
+from netmodel.model.query import Query, ACTION_UPDATE, ACTION2STR
+from netmodel.model.query import ACTION_SUBSCRIBE, ACTION_UNSUBSCRIBE
+from netmodel.model.query import FUNCTION_SUM
from netmodel.model.type import String, Integer, Double
+log = logging.getLogger(__name__)
+
DEFAULT_INTERVAL = 1 # s
KEY_FIELD = 'device_name'
@@ -48,19 +52,19 @@ class Process(threading.Thread):
class BWMThread(Process):
SEP=';'
- CMD="stdbuf -oL bwm-ng -t 1000 -N -o csv -c 0 -C '%s'"
+ CMD="stdbuf -oL bwm-ng -t 1000 -N -o csv -c 0 -C '%s'"
# Parsing information (from README, specs section)
# https://github.com/jgjl/bwm-ng/blob/master/README
#
# Type rate:
- FIELDS_RATE = ['timestamp', 'iface_name', 'bytes_out_s', 'bytes_in_s',
+ FIELDS_RATE = ['timestamp', 'iface_name', 'bytes_out_s', 'bytes_in_s',
'bytes_total_s', 'bytes_in', 'bytes_out', 'packets_out_s',
'packets_in_s', 'packets_total_s', 'packets_in', 'packets_out',
'errors_out_s', 'errors_in_s', 'errors_in', 'errors_out']
# Type svg, sum, max
- FIELDS_SUM = ['timestamp', 'iface_name', 'bytes_out', 'bytes_in',
- 'bytes_total', 'packets_out', 'packets_in', 'packets_total',
+ FIELDS_SUM = ['timestamp', 'iface_name', 'bytes_out', 'bytes_in',
+ 'bytes_total', 'packets_out', 'packets_in', 'packets_total',
'errors_out', 'errors_in']
def __init__(self, interfaces, callback):
@@ -74,7 +78,7 @@ class BWMThread(Process):
def run(self):
cmd = self.CMD % (self.SEP)
- p = subprocess.Popen(shlex.split(cmd), stdout = subprocess.PIPE,
+ p = subprocess.Popen(shlex.split(cmd), stdout = subprocess.PIPE,
stderr = subprocess.STDOUT)
stdout = []
self._is_running = True
@@ -85,18 +89,18 @@ class BWMThread(Process):
break
if line:
record = self._parse_line(line.strip())
- # We use 'total' to push the statistics back to VICN
+ # We use 'total' to push the statistics back to VICN
if record['iface_name'] == 'total':
for interfaces in self.groups_of_interfaces:
if not len(interfaces) > 1:
# If the tuple contains only one interface, grab
# the information from bwm_stats and sends it back
- # to VICN
+ # to VICN
if interfaces[0] not in self.bwm_stats:
continue
interface = self.bwm_stats[interfaces[0]]
f_list = [[KEY_FIELD, '==', interface.device_name]]
- query = Query(ACTION_UPDATE, Interface.__type__,
+ query = Query(ACTION_UPDATE, Interface.__type__,
filter = Filter.from_list(f_list),
params = interface.get_attribute_dict())
self._callback(query)
@@ -133,7 +137,7 @@ class BWMThread(Process):
# sum function specified because it is used to
# match the subscribe query
attrs = aggregated_interface.get_attribute_dict()
- query = Query(ACTION_UPDATE, Interface.__type__,
+ query = Query(ACTION_UPDATE, Interface.__type__,
filter = Filter.from_list([predicate]),
params = attrs,
aggregate = FUNCTION_SUM)
@@ -148,7 +152,7 @@ class BWMThread(Process):
bw_upstream = float(record['bytes_out_s']),
bw_downstream = float(record['bytes_in_s']),
)
-
+
self.bwm_stats[record['iface_name']] = interface
rc = p.poll()
@@ -176,17 +180,17 @@ class BWMInterface(BaseInterface):
packet = Packet.from_query(reply, reply = True)
self.receive(packet)
- #--------------------------------------------------------------------------
+ #--------------------------------------------------------------------------
# Router interface
- #--------------------------------------------------------------------------
+ #--------------------------------------------------------------------------
def send_impl(self, packet):
query = packet.to_query()
- assert query.action == ACTION_SUBSCRIBE
- interval = query.params.get('interval', DEFAULT_INTERVAL) \
- if query.params else DEFAULT_INTERVAL
- assert interval
+ if query.action not in [ACTION_SUBSCRIBE]: # , ACTION_UNSUBSCRIBE]:
+ log.warning("Ignore unknown action {}".format(
+ ACTION2STR[query.action]))
+ return
# TODO: Add the sum operator. If sum the list of interfaces is
# added to the BWMThread as a tuple, otherwise every single
@@ -194,15 +198,18 @@ class BWMInterface(BaseInterface):
# We currently simply extract it from the filter
interfaces_list = [p.value for p in query.filter if p.key == KEY_FIELD]
-
+
# interfaces is a list of tuple. If someone sbscribe to mulitple
# interfaces interfaces will be a list of 1 tuple. The tuple will
# contain the set of interfaces
- assert len(interfaces_list) == 1
+ if len(interfaces_list) != 1:
+ log.warning("interfaces_list should have len = 1: {}".format(interfaces_list))
+ return
+
interfaces = interfaces_list[0] \
if isinstance(interfaces_list[0], tuple) \
else tuple([interfaces_list[0]])
-
+
# Check if interfaces is more than one. In this case, we only support
# The SUM function on the list of field.
if len(interfaces) > 1:
diff --git a/netmodel/interfaces/vicn.py b/netmodel/interfaces/vicn.py
index 9ec9672e..80195f9a 100644
--- a/netmodel/interfaces/vicn.py
+++ b/netmodel/interfaces/vicn.py
@@ -16,16 +16,20 @@
# limitations under the License.
#
+import logging
+
from vicn.core.task import BashTask
from netmodel.model.object import Object
from netmodel.model.attribute import Attribute
-from netmodel.model.query import Query, ACTION_INSERT
+from netmodel.model.query import Query, ACTION_INSERT, ACTION_SELECT
from netmodel.model.type import String
from netmodel.network.interface import Interface, InterfaceState
from netmodel.network.packet import Packet
from netmodel.network.prefix import Prefix
from netmodel.util.misc import lookahead
+log = logging.getLogger(__name__)
+
class VICNBaseResource(Object):
__type__ = 'vicn/'
@@ -50,18 +54,27 @@ class VICNBaseResource(Object):
elif query.object_name == 'resource':
resources = interface._manager.get_resources()
else:
- resources = interface._manager.by_type_str(query.object_name)
-
- for resource, last in lookahead(resources):
- params = resource.get_attribute_dict(aggregates = True)
- params['id'] = resource._state.uuid._uuid
- params['type'] = resource.get_types()
- params['state'] = resource._state.state
- params['log'] = resource._state.log
- reply = Query(ACTION_INSERT, query.object_name, params = params)
- reply.last = last
- packet = Packet.from_query(reply, reply = True)
- cb(packet, ingress_interface = interface)
+ _resources = interface._manager.by_type_str(query.object_name)
+ resources = list()
+ for resource in _resources:
+ group_names = [r.name for r in resource.groups]
+ resources.append(resource)
+
+ if query.action == ACTION_SELECT:
+ for resource, last in lookahead(resources):
+ params = resource.get_attribute_dict(aggregates = True)
+ params['id'] = resource._state.uuid._uuid
+ params['type'] = resource.get_types()
+ params['state'] = resource._state.state
+ params['log'] = resource._state.log
+ reply = Query(ACTION_INSERT, query.object_name, params = params)
+ reply.last = last
+ packet = Packet.from_query(reply, reply = True)
+ cb(packet, ingress_interface = interface)
+ else:
+ log.warning("Unknown action in query {}".format(query))
+
+ interface._manager._broadcast(query)
class L2Graph(Object):
__type__ = 'vicn/l2graph'
diff --git a/netmodel/interfaces/vpp/__init__.py b/netmodel/interfaces/vpp/__init__.py
new file mode 100644
index 00000000..34d106fd
--- /dev/null
+++ b/netmodel/interfaces/vpp/__init__.py
@@ -0,0 +1,225 @@
+#!/usr/bin/env python3
+
+# Dependency: vpp-api-python
+
+import asyncio
+import asyncio.subprocess
+import collections
+import copy
+import logging
+import pyparsing as pp
+import socket
+import time
+
+from netmodel.model.attribute import Attribute
+from netmodel.model.filter import Filter
+from netmodel.model.query import Query, ACTION2STR, ACTION_UPDATE
+from netmodel.model.query import ACTION_SUBSCRIBE, ACTION_UNSUBSCRIBE
+from netmodel.model.object import Object
+from netmodel.model.type import Double, String
+from netmodel.network.interface import Interface as BaseInterface
+from netmodel.network.packet import Packet
+
+log = logging.getLogger(__name__)
+
+DEFAULT_INTERVAL = 1 # s
+KEY_FIELD = 'device_name'
+
+def parse(s):
+ kw_name = pp.Keyword('Name')
+ kw_idx = pp.Keyword('Idx')
+ kw_state = pp.Keyword('State')
+ kw_counter = pp.Keyword('Counter')
+ kw_count = pp.Keyword('Count')
+
+ kw_up = pp.CaselessKeyword('up')
+ kw_down = pp.CaselessKeyword('down')
+ kw_rx_packets = pp.CaselessKeyword('rx packets')
+ kw_rx_bytes = pp.CaselessKeyword('rx bytes')
+ kw_tx_packets = pp.CaselessKeyword('tx packets')
+ kw_tx_bytes = pp.CaselessKeyword('tx bytes')
+ kw_drops = pp.CaselessKeyword('drops')
+ kw_ip4 = pp.CaselessKeyword('ip4')
+ kw_ip6 = pp.CaselessKeyword('ip6')
+ kw_tx_error = pp.CaselessKeyword('tx-error')
+ kw_rx_miss = pp.CaselessKeyword('rx-miss')
+
+ header = kw_name + kw_idx + kw_state + kw_counter + kw_count
+
+ interface = (pp.Word(pp.alphanums + '/' + '-').setResultsName('device_name') + \
+ pp.Word(pp.nums).setResultsName('index') + \
+ pp.oneOf(['up', 'down']).setResultsName('state') + \
+ pp.Optional(kw_rx_packets + pp.Word(pp.nums).setResultsName('rx_packets')) + \
+ pp.Optional(kw_rx_bytes + pp.Word(pp.nums).setResultsName('rx_bytes')) + \
+ pp.Optional(kw_tx_packets + pp.Word(pp.nums).setResultsName('tx_packets')) + \
+ pp.Optional(kw_tx_bytes + pp.Word(pp.nums).setResultsName('tx_bytes')) + \
+ pp.Optional(kw_drops + pp.Word(pp.nums).setResultsName('drops')) + \
+ pp.Optional(kw_ip4 + pp.Word(pp.nums).setResultsName('ip4')) + \
+ pp.Optional(kw_ip6 + pp.Word(pp.nums).setResultsName('ip6')) + \
+ pp.Optional(kw_rx_miss + pp.Word(pp.nums).setResultsName('rx_miss')) + \
+ pp.Optional(kw_tx_error + pp.Word(pp.nums).setResultsName('tx_error'))
+ ).setParseAction(lambda t: t.asDict())
+
+ bnf = (
+ header.suppress() +
+ pp.OneOrMore(interface)
+ ).setParseAction(lambda t: t.asList())
+
+ return bnf.parseString(s, parseAll = True).asList()
+
+class VPPInterface(Object):
+ __type__ = 'vpp_interface'
+
+ node = Attribute(String)
+ device_name = Attribute(String)
+ bw_upstream = Attribute(Double) # bytes
+ bw_downstream = Attribute(Double) # bytes
+
+class VPPCtlInterface(BaseInterface):
+ __interface__ = 'vppctl'
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.register_object(VPPInterface)
+
+ # Set of monitored interfaces
+ self._interfaces = collections.defaultdict(int)
+ self._running = False
+
+ # interface -> (time, rx, tx)
+ self._last = dict()
+
+ async def _tick(self):
+ while self._running:
+ try:
+ create = asyncio.create_subprocess_exec(
+ 'vppctl_wrapper', 'show', 'int',
+ stdout=asyncio.subprocess.PIPE,
+ )
+ proc = await create
+ await proc.wait()
+ stdout = await proc.stdout.read()
+
+ if proc.returncode:
+ log.error("error")
+ return
+
+ interfaces = parse(stdout.decode())
+ last = copy.copy(self._last)
+ self._last = dict()
+ now = time.time()
+ for interface in interfaces:
+ if not interface['device_name'] in self._interfaces:
+ continue
+ tx = float(interface['tx_bytes'])
+ rx = float(interface['rx_bytes'])
+ self._last[interface['device_name']] = (now, rx, tx)
+
+ if not interface['device_name'] in last:
+ continue
+ prev_now, prev_rx, prev_tx = last[interface['device_name']]
+
+ # Per interface throughput computation
+ ret = {
+ 'node' : socket.gethostname(),
+ 'device_name' : interface['device_name'],
+ 'bw_upstream' : (tx - prev_tx) / (now - prev_now),
+ 'bw_downstream' : (rx - prev_rx) / (now - prev_now),
+ }
+
+ f_list = [[KEY_FIELD, '==', interface['device_name']]]
+ del interface['device_name']
+ query = Query(ACTION_UPDATE, VPPInterface.__type__,
+ filter = Filter.from_list(f_list),
+ params = ret)
+ self.receive(Packet.from_query(query, reply = True))
+ except Exception as e:
+ import traceback; traceback.print_exc()
+ log.error("Could not perform measurement {}".format(e))
+
+ await asyncio.sleep(DEFAULT_INTERVAL)
+
+ #--------------------------------------------------------------------------
+ # Router interface
+ #--------------------------------------------------------------------------
+
+ def send_impl(self, packet):
+ query = packet.to_query()
+
+ if query.action not in (ACTION_SUBSCRIBE, ACTION_UNSUBSCRIBE):
+ log.warning("Ignore unknown action {}".format(
+ ACTION2STR[query.action]))
+ return
+
+ # We currently simply extract it from the filter
+ interfaces = set([p.value for p in query.filter if p.key == KEY_FIELD])
+
+ for interface in interfaces:
+ if query.action == ACTION_SUBSCRIBE:
+ self._interfaces[interface] += 1
+ else:
+ self._interfaces[interface] -= 1
+
+ all_interfaces = set([k for k, v in self._interfaces.items() if v > 0])
+
+ if all_interfaces and not self._running:
+ self._running = True
+ asyncio.ensure_future(self._tick())
+ elif not all_interfaces and self._running:
+ self._running = False
+
+
+#-------------------------------------------------------------------------------
+
+if __name__ == '__main__':
+ x=""" Name Idx State Counter Count
+ TenGigabitEthernetc/0/1 1 up rx packets 3511586
+ rx bytes 4785592030
+ tx packets 3511678
+ tx bytes 313021701
+ drops 7
+ ip4 161538
+ ip6 3350047
+ tx-error 2
+ host-bh1 4 up rx packets 5
+ rx bytes 394
+ tx packets 10
+ tx bytes 860
+ drops 4
+ ip6 4
+ host-bh2 6 up rx packets 3164301
+ rx bytes 287315869
+ tx packets 3164238
+ tx bytes 4290944332
+ drops 4
+ ip4 161539
+ ip6 3002759
+ host-bh3 7 up rx packets 33066
+ rx bytes 2446928
+ tx packets 33060
+ tx bytes 47058708
+ drops 5
+ ip6 33065
+ host-bh4 5 up rx packets 114407
+ rx bytes 8466166
+ tx packets 114412
+ tx bytes 162905294
+ drops 7
+ ip6 114406
+ host-bh5 3 up rx packets 150574
+ rx bytes 11142524
+ tx packets 150578
+ tx bytes 214407016
+ drops 7
+ ip6 150573
+ host-bh6 2 up rx packets 49380
+ rx bytes 3654160
+ tx packets 49368
+ tx bytes 70283976
+ drops 9
+ ip6 49377
+ local0 0 down drops 3
+ """
+
+ r = parse(x)
+ print(r)
diff --git a/netmodel/interfaces/websocket/__init__.py b/netmodel/interfaces/websocket/__init__.py
index cb79fc39..b6402aca 100644
--- a/netmodel/interfaces/websocket/__init__.py
+++ b/netmodel/interfaces/websocket/__init__.py
@@ -97,7 +97,7 @@ class ClientProtocol(WebSocketClientProtocol):
Websocket opening handshake has completed.
"""
self.factory.interface._on_client_open(self)
-
+
def onMessage(self, payload, isBinary):
self.factory.interface._on_client_message(self, payload, isBinary)
@@ -109,7 +109,7 @@ class ClientProtocol(WebSocketClientProtocol):
class WebSocketClientInterface(Interface):
"""
All messages are exchanged using text (non-binary) mode.
- """
+ """
__interface__ = 'websocketclient'
def __init__(self, *args, **kwargs):
@@ -149,6 +149,11 @@ class WebSocketClientInterface(Interface):
# Holds the instance of the connect client protocol
self._client = None
+ def __repr__(self):
+ return '<WebSocketClientInterface {}>'.format(self._client)
+
+ __str__ = __repr__
+
#--------------------------------------------------------------------------
# Interface API
#--------------------------------------------------------------------------
@@ -168,10 +173,10 @@ class WebSocketClientInterface(Interface):
async def _connect(self):
loop = asyncio.get_event_loop()
try:
- self._instance = await loop.create_connection(self._factory,
+ self._instance = await loop.create_connection(self._factory,
self._address, self._port)
except Exception as e:
- log.warning('Connect failed : {}'.format(e))
+ log.warning('Connect failed on {} : {}'.format(self, e))
self._instance = None
# don't await for retry, since it cause an infinite recursion...
asyncio.ensure_future(self._retry())
@@ -206,7 +211,7 @@ class WebSocketClientInterface(Interface):
query, record = args
else:
query = args
-
+
if isinstance(query, dict):
query = Query.from_dict(query)
else:
@@ -226,7 +231,7 @@ class WebSocketClientInterface(Interface):
asyncio.ensure_future(self._retry())
#------------------------------------------------------------------------------
-
+
class ServerProtocol(WebSocketServerProtocol, Interface):
"""
Default WebSocket server protocol.
@@ -244,7 +249,7 @@ class ServerProtocol(WebSocketServerProtocol, Interface):
Constructor.
Args:
- callback (Function[ -> ]) :
+ callback (Function[ -> ]) :
hook (Function[->]) : Hook method to be called for every packet to
be sent on the interface. Processing continues with the packet
returned by this function, or is interrupted in case of a None
@@ -252,6 +257,12 @@ class ServerProtocol(WebSocketServerProtocol, Interface):
"""
WebSocketServerProtocol.__init__(self)
Interface.__init__(self, callback=callback, hook=hook)
+ self._last_peer = None
+
+ def __repr__(self):
+ return '<WebSocketInterface {}>'.format(self._last_peer if self._last_peer else 'N/A')
+
+ __str__ = __repr__
#--------------------------------------------------------------------------
# Interface API
@@ -272,9 +283,10 @@ class ServerProtocol(WebSocketServerProtocol, Interface):
# Websocket events
def onConnect(self, request):
+ self._last_peer = request.peer
self.factory._instances.append(self)
self.set_state(InterfaceState.Up)
-
+
def onOpen(self):
#print("WebSocket connection open.")
pass
@@ -305,7 +317,7 @@ class WebSocketServerInterface(Interface):
It is also used to broadcast packets to all connected clients.
All messages are exchanged using text (non-binary) mode.
- """
+ """
__interface__ = 'websocketserver'
@@ -331,6 +343,12 @@ class WebSocketServerInterface(Interface):
# packets.
self._factory._instances = list()
+ def __repr__(self):
+ return '<WebSocketServerInterface ws://{}:{}'.format(
+ self._address, self._port)
+
+ __str__ = __repr__
+
#--------------------------------------------------------------------------
# Interface API
#--------------------------------------------------------------------------
@@ -343,7 +361,7 @@ class WebSocketServerInterface(Interface):
loop = asyncio.get_event_loop()
# Websocket server
log.info('WebSocket server started')
- self._server = await loop.create_server(self._factory, self._address,
+ self._server = await loop.create_server(self._factory, self._address,
self._port)
await self._set_state(InterfaceState.Up)
diff --git a/netmodel/model/attribute.py b/netmodel/model/attribute.py
index b69ee1bf..b2fa2331 100644
--- a/netmodel/model/attribute.py
+++ b/netmodel/model/attribute.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python3
+#!/usr/bin/env python2
# -*- coding: utf-8 -*-
#
# Copyright (c) 2017 Cisco and/or its affiliates.
@@ -22,11 +22,10 @@ import logging
import operator
import types
-from netmodel.model.mapper import ObjectSpecification
-from netmodel.model.type import is_type
-from netmodel.util.meta import inheritors
-from netmodel.util.misc import is_iterable
-from vicn.core.sa_collections import InstrumentedList, _list_decorators
+from netmodel.model.mapper import ObjectSpecification
+from netmodel.model.type import Type, Self
+from netmodel.util.misc import is_iterable
+from netmodel.model.collection import Collection
log = logging.getLogger(__name__)
instance_dict = operator.attrgetter('__dict__')
@@ -38,26 +37,25 @@ class NEVER_SET: None
#------------------------------------------------------------------------------
class Multiplicity:
- _1_1 = '1_1'
- _1_N = '1_N'
- _N_1 = 'N_1'
- _N_N = 'N_N'
-
+ OneToOne = '1_1'
+ OneToMany = '1_N'
+ ManyToOne = 'N_1'
+ ManyToMany = 'N_N'
@staticmethod
def reverse(value):
reverse_map = {
- Multiplicity._1_1: Multiplicity._1_1,
- Multiplicity._1_N: Multiplicity._N_1,
- Multiplicity._N_1: Multiplicity._1_N,
- Multiplicity._N_N: Multiplicity._N_N,
+ Multiplicity.OneToOne: Multiplicity.OneToOne,
+ Multiplicity.OneToMany: Multiplicity.ManyToOne,
+ Multiplicity.ManyToOne: Multiplicity.OneToMany,
+ Multiplicity.ManyToMany: Multiplicity.ManyToMany,
}
return reverse_map[value]
# Default attribute properties values (default to None)
DEFAULT = {
- 'multiplicity' : Multiplicity._1_1,
+ 'multiplicity' : Multiplicity.OneToOne,
'mandatory' : False,
}
@@ -71,33 +69,36 @@ class Attribute(abc.ABC, ObjectSpecification):
'type',
'description',
'default',
- 'choices',
'mandatory',
'multiplicity',
'ro',
'auto',
'func',
- 'requirements',
'reverse_name',
'reverse_description',
'reverse_auto'
]
def __init__(self, *args, **kwargs):
- for key in Attribute.properties:
+ for key in kwargs.keys():
+ if not key in self.properties:
+ raise ValueError("Invalid attribute property {}".format(key))
+ for key in self.properties:
value = kwargs.pop(key, NEVER_SET)
setattr(self, key, value)
if len(args) == 1:
self.type, = args
- elif len(args) == 2:
- self.name, self.type = args
- assert is_type(self.type)
+ # self.type is optional since the type can be inherited. Although we
+ # will have to verify the attribute is complete at some point
+ if isinstance(self.type, str):
+ self.type = Type.from_string(self.type)
+ assert self.type is Self or Type.exists(self.type)
self.is_aggregate = False
self._reverse_attributes = list()
-
+
#--------------------------------------------------------------------------
# Display
#--------------------------------------------------------------------------
@@ -107,6 +108,15 @@ class Attribute(abc.ABC, ObjectSpecification):
__str__ = __repr__
+ # The following functions are required to allow comparing attributes, and
+ # using them as dict keys
+
+ def __eq__(self, other):
+ return self.name == other.name
+
+ def __hash__(self):
+ return hash(self.name)
+
#--------------------------------------------------------------------------
# Descriptor protocol
#
@@ -118,7 +128,7 @@ class Attribute(abc.ABC, ObjectSpecification):
return self
value = instance_dict(instance).get(self.name, NEVER_SET)
-
+
# Case : collection attribute
if self.is_collection:
if value is NEVER_SET:
@@ -126,12 +136,12 @@ class Attribute(abc.ABC, ObjectSpecification):
default = self.default(instance)
else:
default = self.default
- value = InstrumentedList(default)
+ value = Collection(default)
value._attribute = self
value._instance = instance
self.__set__(instance, value)
return value
- return value
+ return value
# Case : scalar attribute
@@ -159,8 +169,8 @@ class Attribute(abc.ABC, ObjectSpecification):
return
if self.is_collection:
- if not isinstance(value, InstrumentedList):
- value = InstrumentedList(value)
+ if not isinstance(value, Collection):
+ value = Collection(value)
value._attribute = self
value._instance = instance
@@ -172,6 +182,10 @@ class Attribute(abc.ABC, ObjectSpecification):
def __delete__(self, instance):
raise NotImplementedError
+ def __set_name__(self, owner, name):
+ self.name = name
+ self.owner = owner
+
#--------------------------------------------------------------------------
# Accessors
#--------------------------------------------------------------------------
@@ -184,17 +198,16 @@ class Attribute(abc.ABC, ObjectSpecification):
return DEFAULT.get(name, None)
return value
- # Shortcuts
-
def has_reverse_attribute(self):
return self.reverse_name and self.multiplicity
@property
def is_collection(self):
- return self.multiplicity in (Multiplicity._1_N, Multiplicity._N_N)
+ return self.multiplicity in (Multiplicity.OneToMany,
+ Multiplicity.ManyToMany)
def is_set(self, instance):
- return self.name in instance_dict(instance)
+ return instance.is_set(self.name)
#--------------------------------------------------------------------------
# Operations
@@ -217,46 +230,3 @@ class Attribute(abc.ABC, ObjectSpecification):
value.extend(parent_value)
else:
setattr(self, prop, parent_value)
-
- #--------------------------------------------------------------------------
- # Attribute values
- #--------------------------------------------------------------------------
-
- def _handle_getitem(self, instance, item):
- return item
-
- def _handle_add(self, instance, item):
- instance._state.dirty = True
- instance._state.attr_dirty.add(self.name)
- print('marking', self.name, 'as dirty')
- return item
-
- def _handle_remove(self, instance, item):
- instance._state.dirty = True
- instance._state.attr_dirty.add(self.name)
- print('marking', self.name, 'as dirty')
-
- def _handle_before_remove(self, instance):
- pass
-
- #--------------------------------------------------------------------------
- # Attribute values
- #--------------------------------------------------------------------------
-
-class Relation(Attribute):
- properties = Attribute.properties[:]
- properties.extend([
- 'reverse_name',
- 'reverse_description',
- 'multiplicity',
- ])
-
-class SelfRelation(Relation):
- def __init__(self, *args, **kwargs):
- if args:
- if not len(args) == 1:
- raise ValueError('Bad initialized for SelfRelation')
- name, = args
- super().__init__(name, None, *args, **kwargs)
- else:
- super().__init__(None, *args, **kwargs)
diff --git a/netmodel/model/collection.py b/netmodel/model/collection.py
index 21be84d8..01a63299 100644
--- a/netmodel/model/collection.py
+++ b/netmodel/model/collection.py
@@ -16,9 +16,10 @@
# limitations under the License.
#
-from netmodel.model.filter import Filter
+from netmodel.model.sa_collections import InstrumentedList
+from netmodel.model.filter import Filter
-class Collection(list):
+class Collection(InstrumentedList):
"""
A collection corresponds to a list of objects, and includes processing functionalities to
manipulate them.
diff --git a/netmodel/model/key.py b/netmodel/model/key.py
new file mode 100644
index 00000000..bc49af03
--- /dev/null
+++ b/netmodel/model/key.py
@@ -0,0 +1,19 @@
+from netmodel.model.mapper import ObjectSpecification
+
+class Key(ObjectSpecification):
+ def __init__(self, *attributes):
+ self._attributes = attributes
+
+ #--------------------------------------------------------------------------
+ # Descriptor protocol
+ #
+ # see. https://docs.python.org/3/howto/descriptor.html
+ #--------------------------------------------------------------------------
+
+ def __set_name__(self, owner, name):
+ self._name = name
+ self._owner = owner
+
+ def __iter__(self):
+ for attribute in self._attributes:
+ yield attribute
diff --git a/netmodel/model/mapper.py b/netmodel/model/mapper.py
index 9be46a14..856238c7 100644
--- a/netmodel/model/mapper.py
+++ b/netmodel/model/mapper.py
@@ -16,5 +16,7 @@
# limitations under the License.
#
-class ObjectSpecification:
+import sys
+
+class ObjectSpecification:
pass
diff --git a/netmodel/model/object.py b/netmodel/model/object.py
index 32d3a833..99dbe0c2 100644
--- a/netmodel/model/object.py
+++ b/netmodel/model/object.py
@@ -18,7 +18,10 @@
from abc import ABCMeta
-from netmodel.model.attribute import Attribute
+import sys
+
+from netmodel.model.attribute import Attribute, Multiplicity
+from netmodel.model.key import Key
from netmodel.model.type import BaseType
from netmodel.model.mapper import ObjectSpecification
@@ -26,11 +29,21 @@ from netmodel.model.mapper import ObjectSpecification
E_UNK_RES_NAME = 'Unknown resource name for attribute {} in {} ({}) : {}'
-class ObjectMetaclass(ABCMeta):
+class ObjectMetaclass(type):
"""
Object metaclass allowing non-uniform attribute declaration.
"""
+ def __new__(mcls, name, bases, attrs):
+ cls = super(ObjectMetaclass, mcls).__new__(mcls, name, bases, attrs)
+ if (sys.version_info < (3, 6)):
+ # Before Python 3.6, descriptor protocol does not include __set_name__.
+ # We use a metaclass to emulate the functionality.
+ for attr, obj in attrs.items():
+ if isinstance(obj, ObjectSpecification):
+ obj.__set_name__(cls, attr)
+ return cls
+
def __init__(cls, class_name, parents, attrs):
"""
Args:
@@ -67,19 +80,19 @@ class Object(BaseType, metaclass = ObjectMetaclass):
else:
resource = x
if not resource:
- raise LurchException(E_UNK_RES_NAME.format(key,
+ raise LurchException(E_UNK_RES_NAME.format(key,
self.name, self.__class__.__name__, x))
new_value.append(resource._state.uuid)
value = new_value
else:
if isinstance(value, str):
- resource = self._state.manager.by_name(value)
+ resource = self._state.manager.by_name(value)
elif isinstance(value, UUID):
- resource = self._state.manager.by_uuid(value)
+ resource = self._state.manager.by_uuid(value)
else:
resource = value
if not resource:
- raise LurchException(E_UNK_RES_NAME.format(key,
+ raise LurchException(E_UNK_RES_NAME.format(key,
self.name, self.__class__.__name__, value))
value = resource._state.uuid
setattr(self, key, value)
@@ -111,19 +124,20 @@ class Object(BaseType, metaclass = ObjectMetaclass):
@classmethod
def _sanitize(cls):
- """Sanitize the object model to accomodate for multiple declaration
- styles
+ """
+ This methods performs sanitization of the object declaration.
+
+ More specifically:
+ - it goes over all attributes and sets their name based on the python
+ object attribute name.
+ - it establishes mutual object relationships through reverse attributes.
- In particular, this method:
- - set names to all attributes
"""
cls._reverse_attributes = dict()
cur_reverse_attributes = dict()
for name, obj in vars(cls).items():
- if not isinstance(obj, ObjectSpecification):
+ if not isinstance(obj, Attribute):
continue
- if isinstance(obj, Attribute):
- obj.name = name
# Remember whether a reverse_name is defined before loading
# inherited properties from parent
@@ -135,30 +149,72 @@ class Object(BaseType, metaclass = ObjectMetaclass):
if hasattr(base, name):
parent_attribute = getattr(base, name)
obj.merge(parent_attribute)
- assert obj.type
+ assert obj.type, "No type for obj={} cls={}, base={}".format(obj, cls, base)
# Handle reverse attribute
#
# NOTE: we need to do this after merging to be sure we get all
# properties inherited from parent (eg. multiplicity)
+ #
+ # See "Reverse attributes" section in BaseResource docstring.
+ #
+ # Continueing with the same example, let's detail how it is handled:
+ #
+ # Original declaration:
+ # >>>
+ # class Group(Resource):
+ # resources = Attribute(Resource, description = 'Resources belonging to the group',
+ # multiplicity = Multiplicity.ManyToMany,
+ # default = [],
+ # reverse_name = 'groups',
+ # reverse_description = 'Groups to which the resource belongs')
+ # <<<
+ #
+ # Local variables:
+ # cls = <class 'vicn.resource.group.Group'>
+ # obj = <Attribute resources>
+ # obj.type = <class 'vicn.core.Resource'>
+ # reverse_attribute = <Attribute groups>
+ #
+ # Result:
+ # 1) Group._reverse_attributes =
+ # { <Attribute resources> : [<Attribute groups>, ...], ...}
+ # 2) Add attribute <Attribute groups> to class Resource
+ # 3) Resource._reverse_attributes =
+ # { <Attribute groups> : [<Attribute resources], ...], ...}
+ #
if has_reverse:
a = {
- 'name' : obj.reverse_name,
- 'description' : obj.reverse_description,
- 'multiplicity' : Multiplicity.reverse(obj.multiplicity),
- 'auto' : obj.reverse_auto,
+ 'name' : obj.reverse_name,
+ 'description' : obj.reverse_description,
+ 'multiplicity' : Multiplicity.reverse(obj.multiplicity),
+ 'reverse_name' : obj.name,
+ 'reverse_description' : obj.description,
+ 'auto' : obj.reverse_auto,
}
- reverse_attribute = Attribute(cls, **a)
+
+ # We need to use the same class as the Attribute !
+ reverse_attribute = obj.__class__(cls, **a)
reverse_attribute.is_aggregate = True
+ # 1) Store the reverse attributes to be later inserted in the
+ # remote class, at the end of the function
+ # TODO : clarify the reasons to perform this in two steps
cur_reverse_attributes[obj.type] = reverse_attribute
+ # 2)
if not obj in cls._reverse_attributes:
cls._reverse_attributes[obj] = list()
cls._reverse_attributes[obj].append(reverse_attribute)
- for cls, a in cur_reverse_attributes.items():
- setattr(cls, a.name, a)
+ # 3)
+ if not reverse_attribute in obj.type._reverse_attributes:
+ obj.type._reverse_attributes[reverse_attribute] = list()
+ obj.type._reverse_attributes[reverse_attribute].append(obj)
+
+ # Insert newly created reverse attributes in the remote class
+ for kls, a in cur_reverse_attributes.items():
+ setattr(kls, a.name, a)
@classmethod
def iter_attributes(cls, aggregates = False):
@@ -168,7 +224,7 @@ class Object(BaseType, metaclass = ObjectMetaclass):
continue
if attribute.is_aggregate and not aggregates:
continue
-
+
yield attribute
def get_attributes(self, aggregates = False):
@@ -178,7 +234,7 @@ class Object(BaseType, metaclass = ObjectMetaclass):
return set(a.name for a in self.iter_attributes(aggregates = \
aggregates))
- def get_attribute_dict(self, field_names = None, aggregates = False,
+ def get_attribute_dict(self, field_names = None, aggregates = False,
uuid = True):
assert not field_names or field_names.is_star()
attributes = self.get_attributes(aggregates = aggregates)
@@ -192,14 +248,29 @@ class Object(BaseType, metaclass = ObjectMetaclass):
ret[a.name] = list()
for x in value:
if uuid and isinstance(x, Object):
- x = x._state.uuid._uuid
+ x = x._state.uuid._uuid
ret[a.name].append(x)
else:
if uuid and isinstance(value, Object):
- value = value._state.uuid._uuid
+ value = value._state.uuid._uuid
ret[a.name] = value
return ret
+ @classmethod
+ def iter_keys(cls):
+ for name in dir(cls):
+ key = getattr(cls, name)
+ if not isinstance(key, Key):
+ continue
+ yield key
+
+ @classmethod
+ def get_keys(cls):
+ return list(cls.iter_keys())
+
+ def get_key_dicts(self):
+ return [{attribute: self.get(attribute.name) for attribute in key} for key in self.iter_keys()]
+
def get_tuple(self):
return (self.__class__, self._get_attribute_dict())
@@ -229,3 +300,8 @@ class Object(BaseType, metaclass = ObjectMetaclass):
def has_attribute(cls, name):
return name in [a.name for a in cls.attributes()]
+ def get(self, attribute_name):
+ raise NotImplementedError
+
+ def set(self, attribute_name, value):
+ raise NotImplementedError
diff --git a/netmodel/model/query.py b/netmodel/model/query.py
index c182cb45..a1d331fb 100644
--- a/netmodel/model/query.py
+++ b/netmodel/model/query.py
@@ -46,7 +46,7 @@ FUNCTION2STR = {
STR2FUNCTION = dict((v, k) for k, v in FUNCTION2STR.items())
class Query:
- def __init__(self, action, object_name, filter = None, params = None,
+ def __init__(self, action, object_name, filter = None, params = None,
field_names = None, aggregate = None, last = False, reply = False):
self.action = action
self.object_name = object_name
@@ -64,13 +64,13 @@ class Query:
if field_names:
if isinstance(field_names, FieldNames):
self.field_names = field_names
- else:
+ else:
self.field_names = FieldNames(field_names)
else:
self.field_names = FieldNames()
self.aggregate = aggregate
-
+
self.last = last
self.reply = reply
@@ -100,7 +100,7 @@ class Query:
field_names = FieldNames(star = True)
last = dic.get('last', False)
reply = dic.get('reply', False)
- return Query(action, object_name, filter, params, field_names,
+ return Query(action, object_name, filter, params, field_names,
aggregate, last)
def to_sql(self, multiline = False):
@@ -140,8 +140,7 @@ class Query:
return strmap[self.action] % locals()
- def __str__(self):
- return self.to_sql()
-
def __repr__(self):
return self.to_sql()
+
+ __str__ = __repr__
diff --git a/netmodel/model/sa_collections.py b/netmodel/model/sa_collections.py
new file mode 100644
index 00000000..5e651061
--- /dev/null
+++ b/netmodel/model/sa_collections.py
@@ -0,0 +1,265 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# This module is derived from code from SQLAlchemy
+#
+# orm/collections.py
+# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+#
+
+import logging
+
+from netmodel.model.sa_compat import py2k
+from netmodel.model.uuid import UUID
+
+class InstrumentedListException(Exception): pass
+
+log = logging.getLogger(__name__)
+
+def _list_decorators():
+ """Tailored instrumentation wrappers for any list-like class."""
+
+ def _tidy(fn):
+ fn._sa_instrumented = True
+ fn.__doc__ = getattr(list, fn.__name__).__doc__
+
+ def append(fn):
+ def append(self, item):
+ try:
+ item = self._attribute.do_list_add(self._instance, item)
+ fn(self, item)
+ except InstrumentedListException as e:
+ pass
+ _tidy(append)
+ return append
+
+ def remove(fn):
+ def remove(self, value):
+ # testlib.pragma exempt:__eq__
+ try:
+ self._attribute.do_list_remove(self._instance, value)
+ fn(self, value)
+ except : pass
+ _tidy(remove)
+ return remove
+
+ def insert(fn):
+ def insert(self, index, value):
+ try:
+ value = self._attribute.do_list_add(self._instance, item)
+ fn(self, index, value)
+ except : pass
+ _tidy(insert)
+ return insert
+
+ def __getitem__(fn):
+ def __getitem__(self, index):
+ item = fn(self, index)
+ return self._attribute.handle_getitem(self._instance, item)
+ _tidy(__getitem__)
+ return __getitem__
+
+ def __setitem__(fn):
+ def __setitem__(self, index, value):
+ if not isinstance(index, slice):
+ existing = self[index]
+ if existing is not None:
+ try:
+ self._attribute.do_list_remove(self._instance, existing)
+ except: pass
+ try:
+ value = self._attribute.do_list_add(self._instance, value)
+ fn(self, index, value)
+ except: pass
+ else:
+ # slice assignment requires __delitem__, insert, __len__
+ step = index.step or 1
+ start = index.start or 0
+ if start < 0:
+ start += len(self)
+ if index.stop is not None:
+ stop = index.stop
+ else:
+ stop = len(self)
+ if stop < 0:
+ stop += len(self)
+
+ if step == 1:
+ for i in range(start, stop, step):
+ if len(self) > start:
+ del self[start]
+
+ for i, item in enumerate(value):
+ self.insert(i + start, item)
+ else:
+ rng = list(range(start, stop, step))
+ if len(value) != len(rng):
+ raise ValueError(
+ "attempt to assign sequence of size %s to "
+ "extended slice of size %s" % (len(value),
+ len(rng)))
+ for i, item in zip(rng, value):
+ self.__setitem__(i, item)
+ _tidy(__setitem__)
+ return __setitem__
+
+ def __delitem__(fn):
+ def __delitem__(self, index):
+ if not isinstance(index, slice):
+ item = self[index]
+ try:
+ self._attribute.do_list_remove(self._instance, item)
+ fn(self, index)
+ except : pass
+ else:
+ # slice deletion requires __getslice__ and a slice-groking
+ # __getitem__ for stepped deletion
+ # note: not breaking this into atomic dels
+ has_except = False
+ for item in self[index]:
+ try:
+ self._attribute.do_list_remove(self._instance, item)
+ except : has_except = True
+ if not has_except:
+ fn(self, index)
+ _tidy(__delitem__)
+ return __delitem__
+
+ if py2k:
+ def __setslice__(fn):
+ def __setslice__(self, start, end, values):
+ has_except = False
+ for value in self[start:end]:
+ try:
+ self._attribute.do_list_remove(self._instance, value)
+ except : has_except = True
+ #values = [self._attribute.do_list_add(self._instance, value) for value in values]
+ _values = list()
+ for value in values:
+ try:
+ _values.append(self._attribute.do_list_add(self._instance, value))
+ except: has_except = True
+ if not has_except:
+ fn(self, start, end, _values)
+ _tidy(__setslice__)
+ return __setslice__
+
+ def __delslice__(fn):
+ def __delslice__(self, start, end):
+ has_except = False
+ for value in self[start:end]:
+ try:
+ self._attribute.do_list_remove(self._instance, value)
+ except : has_except = True
+ if not has_except:
+ fn(self, start, end)
+ _tidy(__delslice__)
+ return __delslice__
+
+ def extend(fn):
+ def extend(self, iterable):
+ for value in iterable:
+ self.append(value)
+ _tidy(extend)
+ return extend
+
+ def __iadd__(fn):
+ def __iadd__(self, iterable):
+ # list.__iadd__ takes any iterable and seems to let TypeError
+ # raise as-is instead of returning NotImplemented
+ for value in iterable:
+ self.append(value)
+ return self
+ _tidy(__iadd__)
+ return __iadd__
+
+ def pop(fn):
+ def pop(self, index=-1):
+ try:
+ self._attribute.do_list_remove(self._instance, item)
+ item = fn(self, index)
+ return item
+ except : return None
+ _tidy(pop)
+ return pop
+
+ def __iter__(fn):
+ def __iter__(self):
+ for item in fn(self):
+ yield self._attribute.handle_getitem(self._instance, item)
+ _tidy(__iter__)
+ return __iter__
+
+ def __repr__(fn):
+ def __repr__(self):
+ return '<Collection {} {}>'.format(id(self), list.__repr__(self))
+ _tidy(__repr__)
+ return __repr__
+
+ __str__ = __repr__
+ #def __str__(fn):
+ # def __str__(self):
+ # return str(list(self))
+ # _tidy(__str__)
+ # return __str__
+
+ if not py2k:
+ def clear(fn):
+ def clear(self, index=-1):
+ has_except = False
+ for item in self:
+ try:
+ self._attribute.do_list_remove(self._instance, item)
+ except : has_except = True
+ if not has_except:
+ fn(self)
+ _tidy(clear)
+ return clear
+
+ # __imul__ : not wrapping this. all members of the collection are already
+ # present, so no need to fire appends... wrapping it with an explicit
+ # decorator is still possible, so events on *= can be had if they're
+ # desired. hard to imagine a use case for __imul__, though.
+
+ l = locals().copy()
+ l.pop('_tidy')
+ return l
+
+def _instrument_list(cls):
+ # inspired by sqlalchemy
+ for method, decorator in _list_decorators().items():
+ fn = getattr(cls, method, None)
+ if fn:
+ #if (fn and method not in methods and
+ # not hasattr(fn, '_sa_instrumented')):
+ setattr(cls, method, decorator(fn))
+
+class InstrumentedList(list):
+
+ @classmethod
+ def from_list(cls, value, instance, attribute):
+ lst = list()
+ if value:
+ for x in value:
+ if isinstance(x, UUID):
+ x = instance.from_uuid(x)
+ lst.append(x)
+ # Having a class method is important for inheritance
+ value = cls(lst)
+ value._attribute = attribute
+ value._instance = instance
+ return value
+
+ def __contains__(self, key):
+ from vicn.core.resource import Resource
+ if isinstance(key, Resource):
+ key = key.get_uuid()
+ return list.__contains__(self, key)
+
+ def __lshift__(self, item):
+ self.append(item)
+
+_instrument_list(InstrumentedList)
diff --git a/netmodel/model/sa_compat.py b/netmodel/model/sa_compat.py
new file mode 100644
index 00000000..34211455
--- /dev/null
+++ b/netmodel/model/sa_compat.py
@@ -0,0 +1,270 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# This module originates from SQLAlchemy
+#
+# util/compat.py
+# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+#
+
+"""Handle Python version/platform incompatibilities."""
+
+import sys
+
+try:
+ import threading
+except ImportError:
+ import dummy_threading as threading
+
+py36 = sys.version_info >= (3, 6)
+py33 = sys.version_info >= (3, 3)
+py32 = sys.version_info >= (3, 2)
+py3k = sys.version_info >= (3, 0)
+py2k = sys.version_info < (3, 0)
+py265 = sys.version_info >= (2, 6, 5)
+jython = sys.platform.startswith('java')
+pypy = hasattr(sys, 'pypy_version_info')
+win32 = sys.platform.startswith('win')
+cpython = not pypy and not jython # TODO: something better for this ?
+
+import collections
+next = next
+
+if py3k:
+ import pickle
+else:
+ try:
+ import cPickle as pickle
+ except ImportError:
+ import pickle
+
+# work around http://bugs.python.org/issue2646
+if py265:
+ safe_kwarg = lambda arg: arg
+else:
+ safe_kwarg = str
+
+ArgSpec = collections.namedtuple("ArgSpec",
+ ["args", "varargs", "keywords", "defaults"])
+
+if py3k:
+ import builtins
+
+ from inspect import getfullargspec as inspect_getfullargspec
+ from urllib.parse import (quote_plus, unquote_plus,
+ parse_qsl, quote, unquote)
+ import configparser
+ from io import StringIO
+
+ from io import BytesIO as byte_buffer
+
+ def inspect_getargspec(func):
+ return ArgSpec(
+ *inspect_getfullargspec(func)[0:4]
+ )
+
+ string_types = str,
+ binary_types = bytes,
+ binary_type = bytes
+ text_type = str
+ int_types = int,
+ iterbytes = iter
+
+ def u(s):
+ return s
+
+ def ue(s):
+ return s
+
+ def b(s):
+ return s.encode("latin-1")
+
+ if py32:
+ callable = callable
+ else:
+ def callable(fn):
+ return hasattr(fn, '__call__')
+
+ def cmp(a, b):
+ return (a > b) - (a < b)
+
+ from functools import reduce
+
+ print_ = getattr(builtins, "print")
+
+ import_ = getattr(builtins, '__import__')
+
+ import itertools
+ itertools_filterfalse = itertools.filterfalse
+ itertools_filter = filter
+ itertools_imap = map
+ from itertools import zip_longest
+
+ import base64
+
+ def b64encode(x):
+ return base64.b64encode(x).decode('ascii')
+
+ def b64decode(x):
+ return base64.b64decode(x.encode('ascii'))
+
+else:
+ from inspect import getargspec as inspect_getfullargspec
+ inspect_getargspec = inspect_getfullargspec
+ from urllib import quote_plus, unquote_plus, quote, unquote
+ from urlparse import parse_qsl
+ import ConfigParser as configparser
+ from StringIO import StringIO
+ from cStringIO import StringIO as byte_buffer
+
+ string_types = basestring,
+ binary_types = bytes,
+ binary_type = str
+ text_type = unicode
+ int_types = int, long
+
+ def iterbytes(buf):
+ return (ord(byte) for byte in buf)
+
+ def u(s):
+ # this differs from what six does, which doesn't support non-ASCII
+ # strings - we only use u() with
+ # literal source strings, and all our source files with non-ascii
+ # in them (all are tests) are utf-8 encoded.
+ return unicode(s, "utf-8")
+
+ def ue(s):
+ return unicode(s, "unicode_escape")
+
+ def b(s):
+ return s
+
+ def import_(*args):
+ if len(args) == 4:
+ args = args[0:3] + ([str(arg) for arg in args[3]],)
+ return __import__(*args)
+
+ callable = callable
+ cmp = cmp
+ reduce = reduce
+
+ import base64
+ b64encode = base64.b64encode
+ b64decode = base64.b64decode
+
+ def print_(*args, **kwargs):
+ fp = kwargs.pop("file", sys.stdout)
+ if fp is None:
+ return
+ for arg in enumerate(args):
+ if not isinstance(arg, basestring):
+ arg = str(arg)
+ fp.write(arg)
+
+ import itertools
+ itertools_filterfalse = itertools.ifilterfalse
+ itertools_filter = itertools.ifilter
+ itertools_imap = itertools.imap
+ from itertools import izip_longest as zip_longest
+
+
+import time
+if win32 or jython:
+ time_func = time.clock
+else:
+ time_func = time.time
+
+from collections import namedtuple
+from operator import attrgetter as dottedgetter
+
+
+if py3k:
+ def reraise(tp, value, tb=None, cause=None):
+ if cause is not None:
+ assert cause is not value, "Same cause emitted"
+ value.__cause__ = cause
+ if value.__traceback__ is not tb:
+ raise value.with_traceback(tb)
+ raise value
+
+else:
+ # not as nice as that of Py3K, but at least preserves
+ # the code line where the issue occurred
+ exec("def reraise(tp, value, tb=None, cause=None):\n"
+ " if cause is not None:\n"
+ " assert cause is not value, 'Same cause emitted'\n"
+ " raise tp, value, tb\n")
+
+
+def raise_from_cause(exception, exc_info=None):
+ if exc_info is None:
+ exc_info = sys.exc_info()
+ exc_type, exc_value, exc_tb = exc_info
+ cause = exc_value if exc_value is not exception else None
+ reraise(type(exception), exception, tb=exc_tb, cause=cause)
+
+if py3k:
+ exec_ = getattr(builtins, 'exec')
+else:
+ def exec_(func_text, globals_, lcl=None):
+ if lcl is None:
+ exec('exec func_text in globals_')
+ else:
+ exec('exec func_text in globals_, lcl')
+
+
+def with_metaclass(meta, *bases):
+ """Create a base class with a metaclass.
+
+ Drops the middle class upon creation.
+
+ Source: http://lucumr.pocoo.org/2013/5/21/porting-to-python-3-redux/
+
+ """
+
+ class metaclass(meta):
+ __call__ = type.__call__
+ __init__ = type.__init__
+
+ def __new__(cls, name, this_bases, d):
+ if this_bases is None:
+ return type.__new__(cls, name, (), d)
+ return meta(name, bases, d)
+ return metaclass('temporary_class', None, {})
+
+
+from contextlib import contextmanager
+
+try:
+ from contextlib import nested
+except ImportError:
+ # removed in py3k, credit to mitsuhiko for
+ # workaround
+
+ @contextmanager
+ def nested(*managers):
+ exits = []
+ vars = []
+ exc = (None, None, None)
+ try:
+ for mgr in managers:
+ exit = mgr.__exit__
+ enter = mgr.__enter__
+ vars.append(enter())
+ exits.append(exit)
+ yield vars
+ except:
+ exc = sys.exc_info()
+ finally:
+ while exits:
+ exit = exits.pop()
+ try:
+ if exit(*exc):
+ exc = (None, None, None)
+ except:
+ exc = sys.exc_info()
+ if exc != (None, None, None):
+ reraise(exc[0], exc[1], exc[2])
diff --git a/netmodel/model/type.py b/netmodel/model/type.py
index 20dc2580..9a7b8740 100644
--- a/netmodel/model/type.py
+++ b/netmodel/model/type.py
@@ -16,27 +16,47 @@
# limitations under the License.
#
+from socket import inet_pton, inet_ntop, AF_INET6
+from struct import unpack, pack
+from abc import ABCMeta
+
from netmodel.util.meta import inheritors
class BaseType:
+ __choices__ = None
+
@staticmethod
def name():
return self.__class__.__name__.lower()
+ @classmethod
+ def restrict(cls, **kwargs):
+ class BaseType(cls):
+ __choices__ = kwargs.pop('choices', None)
+ return BaseType
+
class String(BaseType):
- def __init__(self, *args, **kwargs):
- self._min_size = kwargs.pop('min_size', None)
- self._max_size = kwargs.pop('max_size', None)
- self._ascii = kwargs.pop('ascii', False)
- self._forbidden = kwargs.pop('forbidden', None)
- super().__init__()
+ __min_size__ = None
+ __max_size__ = None
+ __ascii__ = None
+ __forbidden__ = None
+
+ @classmethod
+ def restrict(cls, **kwargs):
+ base = super().restrict(**kwargs)
+ class String(base):
+ __max_size__ = kwargs.pop('max_size', None)
+ __min_size__ = kwargs.pop('min_size', None)
+ __ascii__ = kwargs.pop('ascii', None)
+ __forbidden__ = kwargs.pop('forbidden', None)
+ return String
class Integer(BaseType):
def __init__(self, *args, **kwargs):
self._min_value = kwargs.pop('min_value', None)
self._max_value = kwargs.pop('max_value', None)
super().__init__()
-
+
class Double(BaseType):
def __init__(self, *args, **kwargs):
self._min_value = kwargs.pop('min_value', None)
@@ -49,12 +69,161 @@ class Bool(BaseType):
class Dict(BaseType):
pass
+class PrefixTreeException(Exception): pass
+class NotEnoughAddresses(PrefixTreeException): pass
+class UnassignablePrefix(PrefixTreeException): pass
+
+class Prefix(BaseType, metaclass=ABCMeta):
+
+ def __init__(self, ip_address, prefix_len=None):
+ if not prefix_len:
+ if not isinstance(ip_address, str):
+ import pdb; pdb.set_trace()
+ if '/' in ip_address:
+ ip_address, prefix_len = ip_address.split('/')
+ prefix_len = int(prefix_len)
+ else:
+ prefix_len = self.MAX_PREFIX_SIZE
+ if isinstance(ip_address, str):
+ ip_address = self.aton(ip_address)
+ self.ip_address = ip_address
+ self.prefix_len = prefix_len
+
+ def __contains__(self, obj):
+ #it can be an IP as a integer
+ if isinstance(obj, int):
+ obj = type(self)(obj, self.MAX_PREFIX_SIZE)
+ #Or it's an IP string
+ if isinstance(obj, str):
+ #It's a prefix as 'IP/prefix'
+ if '/' in obj:
+ split_obj = obj.split('/')
+ obj = type(self)(split_obj[0], int(split_obj[1]))
+ else:
+ obj = type(self)(obj, self.MAX_PREFIX_SIZE)
+
+ return self._contains_prefix(obj)
+
+ @classmethod
+ def mask(cls):
+ mask_len = cls.MAX_PREFIX_SIZE//8 #Converts from bits to bytes
+ mask = 0
+ for step in range(0,mask_len):
+ mask = (mask << 8) | 0xff
+ return mask
+
+ def _contains_prefix(self, prefix):
+ assert isinstance(prefix, type(self))
+ return (prefix.prefix_len >= self.prefix_len and
+ prefix.ip_address >= self.first_prefix_address() and
+ prefix.ip_address <= self.last_prefix_address())
+
+ #Returns the first address of a prefix
+ def first_prefix_address(self):
+ return self.ip_address & (self.mask() << (self.MAX_PREFIX_SIZE-self.prefix_len))
+
+ def canonical_prefix(self):
+ return type(self)(self.first_prefix_address(), self.prefix_len)
+
+ def last_prefix_address(self):
+ return self.ip_address | (self.mask() >> self.prefix_len)
+
+ def limits(self):
+ return self.first_prefix_address(), self.last_prefix_address()
+
+ def __str__(self):
+ return "{}/{}".format(self.ntoa(self.first_prefix_address()), self.prefix_len)
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return False
+ return self.get_tuple() == other.get_tuple()
+
+ def get_tuple(self):
+ return (self.first_prefix_address(), self.prefix_len)
+
+ def __hash__(self):
+ return hash(self.get_tuple())
+
+ def __iter__(self):
+ return self.get_iterator()
+
+ #Iterates by steps of prefix_len, e.g., on all available /31 in a /24
+ def get_iterator(self, prefix_len=None):
+ if prefix_len is None:
+ prefix_len=self.MAX_PREFIX_SIZE
+ assert (prefix_len >= self.prefix_len and prefix_len<=self.MAX_PREFIX_SIZE)
+ step = 2**(self.MAX_PREFIX_SIZE - prefix_len)
+ for ip in range(self.first_prefix_address(), self.last_prefix_address()+1, step):
+ yield type(self)(ip, prefix_len)
+
+class Inet4Prefix(Prefix):
+
+ MAX_PREFIX_SIZE = 32
+
+ @classmethod
+ def aton(cls, address):
+ ret = 0
+ components = address.split('.')
+ for comp in components:
+ ret = (ret << 8) + int(comp)
+ return ret
+
+ @classmethod
+ def ntoa(cls, address):
+ components = []
+ for _ in range(0,4):
+ components.insert(0,'{}'.format(address % 256))
+ address = address >> 8
+ return '.'.join(components)
+
+class Inet6Prefix(Prefix):
+
+ MAX_PREFIX_SIZE = 128
+
+ @classmethod
+ def aton (cls, address):
+ prefix, suffix = unpack(">QQ", inet_pton(AF_INET6, address))
+ return (prefix << 64) | suffix
+
+ @classmethod
+ def ntoa (cls, address):
+ return inet_ntop(AF_INET6, pack(">QQ", address >> 64, address & ((1 << 64) -1)))
+
+ #skip_internet_address: skip a:b::0, as v6 often use default /64 prefixes
+ def get_iterator(self, prefix_len=None, skip_internet_address=None):
+ if skip_internet_address is None:
+ #We skip the internet address if we iterate over Addresses
+ if prefix_len is None:
+ skip_internet_address = True
+ #But not if we iterate over prefixes
+ else:
+ skip_internet_address = False
+ it = super().get_iterator(prefix_len)
+ if skip_internet_address:
+ next(it)
+ return it
+
+class InetAddress(Prefix):
+
+ def get_tuple(self):
+ return (self.ip_address, self.prefix_len)
+
+ def __str__(self):
+ return self.ntoa(self.ip_address)
+
+class Inet4Address(InetAddress, Inet4Prefix):
+ pass
+
+class Inet6Address(InetAddress, Inet6Prefix):
+ pass
+
class Self(BaseType):
"""Self-reference
"""
class Type:
- BASE_TYPES = (String, Integer, Double, Bool)
+ BASE_TYPES = (String, Integer, Double, Bool)
_registry = dict()
@staticmethod
diff --git a/netmodel/model/uuid.py b/netmodel/model/uuid.py
new file mode 100644
index 00000000..dae51d75
--- /dev/null
+++ b/netmodel/model/uuid.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2017 Cisco and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import random
+import string
+
+# Separator for components of the UUID
+UUID_SEP = '-'
+
+# Length of the random component of the UUID
+UUID_LEN = 5
+
+class UUID:
+ def __init__(self, name, cls):
+ self._uuid = self._make_uuid(name, cls)
+
+ def _make_uuid(self, name, cls):
+ """Generate a unique resource identifier
+
+ The UUID consists in the type of the resource, to which is added a
+ random identifier of length UUID_LEN. Components of the UUID are
+ separated by UUID_SEP.
+ """
+ uuid = ''.join(random.choice(string.ascii_uppercase + string.digits)
+ for _ in range(UUID_LEN))
+ if name:
+ uuid = name # + UUID_SEP + uuid
+ return UUID_SEP.join([cls.__name__, uuid])
+
+ def __repr__(self):
+ return '<UUID {}>'.format(self._uuid)
+
+ def __lt__(self, other):
+ return self._uuid < other._uuid
+
+ __str__ = __repr__
diff --git a/netmodel/network/fib.py b/netmodel/network/fib.py
index e6b81607..11b90b22 100644
--- a/netmodel/network/fib.py
+++ b/netmodel/network/fib.py
@@ -39,7 +39,10 @@ class FIB:
self._entries = dict()
def add(self, prefix, next_hops = None):
- self._entries[prefix] = FIBEntry(prefix, next_hops)
+ if prefix not in self._entries:
+ self._entries[prefix] = FIBEntry(prefix, next_hops)
+ else:
+ self._entries[prefix].update(next_hops)
def update(self, prefix, next_hops = None):
entry = self._entries.get(prefix)
@@ -54,10 +57,11 @@ class FIB:
raise Exception('prefix not found')
entry.remove(next_hops)
return
-
+
del self._entries[prefix]
def get(self, object_name):
for entry in self._entries.values():
if entry._prefix.object_name == object_name:
- return next(iter(entry._next_hops))
+ return entry._next_hops
+ return None
diff --git a/netmodel/network/flow_table.py b/netmodel/network/flow_table.py
index 99e42e99..2da40d1c 100644
--- a/netmodel/network/flow_table.py
+++ b/netmodel/network/flow_table.py
@@ -16,7 +16,10 @@
# limitations under the License.
#
-from netmodel.model.query import ACTION_SUBSCRIBE, ACTION_UNSUBSCRIBE
+import copy
+from collections import defaultdict, Counter
+
+from netmodel.model.query import Query, ACTION_SUBSCRIBE, ACTION_UNSUBSCRIBE
# Per-interface flow table
class SubFlowTable:
@@ -25,21 +28,33 @@ class SubFlowTable:
self._interface = interface
# Flow -> ingress interface
- self._flows = dict()
-
- # Ingress interface -> list of subscription
- self._subscriptions = dict()
+ self._flows = defaultdict(set)
def add(self, packet, ingress_interface):
flow = packet.get_flow()
- self._flows[flow] = ingress_interface
+ self._flows[flow].add(ingress_interface)
def match(self, packet):
flow = packet.get_flow()
return self._flows.get(flow.get_reverse())
+ def _on_interface_delete(self, interface):
+ """
+ Returns:
+ False is the flow table is empty.
+ """
+ to_remove = set()
+ for flow, ingress_interfaces in self._flows.items():
+ ingress_interfaces.discard(interface)
+ if not ingress_interfaces:
+ to_remove.add(flow)
+ for flow in to_remove:
+ del self._flows[flow]
+
+ return len(self._flows) > 0
+
class Subscription:
- def __init__(self, packet, ingress_list, egress_list):
+ def __init__(self, packet, ingress_set, egress_set):
"""
Args:
packet : subscription packet
@@ -47,8 +62,22 @@ class Subscription:
egress_list (List[Interface]) : list of egress interface
"""
self._packet = packet
- self._ingress_list = ingress_list
- self._egress_list = egress_list
+ #self._ingress_set = ingress_set
+ #self._egress_set = egress_set
+
+ def get_tuple(self):
+ return (self._packet,)
+
+ def __eq__(self, other):
+ return self.get_tuple() == other.get_tuple()
+
+ def __hash__(self):
+ return hash(self.get_tuple())
+
+ def __repr__(self):
+ return '<Subscription {}>'.format(self._packet.to_query())
+
+ __str__ = __repr__
class FlowTable:
"""
@@ -76,12 +105,30 @@ class FlowTable:
# The Flow Table also maintains a list of subscriptions doubly indexed
# by both subscriptors, and egress interfaces
- # ingress_interface -> list of subscriptions
- self._ingress_subscriptions = dict()
-
- # egress_interface -> list of subscriptions
- self._egress_subscriptions = dict()
-
+ # ingress_interface -> bag of subscriptions
+ self._ingress_subscriptions = defaultdict(Counter)
+
+ # egress_interface -> bag of subscriptions
+ self._egress_subscriptions = defaultdict(Counter)
+
+ def dump(self, msg=''):
+ print("="*80)
+ print("FLOW TABLE {}".format(msg))
+ print("-" * 80)
+ print("SubFlowTables")
+ for interface, flow_table in self._sub_flow_tables.items():
+ for k, v in flow_table._flows.items():
+ print(interface, "\t", k, "\t", v)
+ print("-" * 80)
+ print("Ingress subscriptions")
+ for interface, subscriptions in self._ingress_subscriptions.items():
+ print(interface, "\t", subscriptions)
+ print("-" * 80)
+ print("Egress subscriptions")
+ for interface, subscriptions in self._egress_subscriptions.items():
+ print(interface, "\t", subscriptions)
+ print("=" * 80)
+ print("")
def match(self, packet, interface):
"""
@@ -95,34 +142,30 @@ class FlowTable:
if not sub_flow_table:
return None
return sub_flow_table.match(packet)
-
- def add(self, packet, ingress_interface, interface):
- sub_flow_table = self._sub_flow_tables.get(interface)
- if not sub_flow_table:
- sub_flow_table = SubFlowTable(interface)
- self._sub_flow_tables[interface] = sub_flow_table
- sub_flow_table.add(packet, ingress_interface)
+
+ def add(self, packet, ingress_interface, interfaces):
+ for interface in interfaces:
+ sub_flow_table = self._sub_flow_tables.get(interface)
+ if not sub_flow_table:
+ sub_flow_table = SubFlowTable(interface)
+ self._sub_flow_tables[interface] = sub_flow_table
+ sub_flow_table.add(packet, ingress_interface)
# If the flow is a subscription, we need to associate it to the list
query = packet.to_query()
if query.action == ACTION_SUBSCRIBE:
# XXX we currently don't merge subscriptions, and assume a single
# next hop interface
- s = Subscription(packet, [ingress_interface], [interface])
+ s = Subscription(packet, set([ingress_interface]), interfaces)
if ingress_interface:
- if not ingress_interface in self._ingress_subscriptions:
- self._ingress_subscriptions[ingress_interface] = list()
- self._ingress_subscriptions[ingress_interface].append(s)
-
- if not interface in self._egress_subscriptions:
- self._egress_subscriptions[interface] = list()
- self._egress_subscriptions[interface].append(s)
+ self._ingress_subscriptions[ingress_interface] += Counter([s])
+ for interface in interfaces:
+ self._egress_subscriptions[interface] += Counter([s])
elif query.action == ACTION_UNSUBSCRIBE:
raise NotImplementedError
-
# Events
def _on_interface_up(self, interface):
@@ -141,9 +184,47 @@ class FlowTable:
"""
Callback: an interface has been deleted
- Cancel all subscriptions that have been issues from
+ Cancel all subscriptions that have been issues from
netmodel.network.interface.
Remove all pointers to subscriptions pending on this interface
"""
+
if interface in self._ingress_subscriptions:
+ # If the interface we delete at the origin of the subscription,
+ # let's also remove corresponding egress subscriptions
+ subs = self._ingress_subscriptions[interface]
+ if not subs:
+ return
+
+ to_remove = set()
+ for _interface, subscriptions in self._egress_subscriptions.items():
+
+ removed = subs & subscriptions
+ if removed:
+ for s in removed:
+ # We found a subscription of this interface on an other
+ # interface; send unsubscribe...
+ action, params = s._packet.payload
+ p = copy.deepcopy(s._packet)
+ p.payload = (ACTION_UNSUBSCRIBE, params)
+ _interface.send(p)
+ # ... and remove them
+ subscriptions -= removed
+
+ # if the interface has no more subscription remove it.
+ if not subscriptions:
+ to_remove.add(_interface)
+
+ for i in to_remove:
+ del self._egress_subscriptions[i]
+
del self._ingress_subscriptions[interface]
+
+ # Remove interface from flow table destination
+ to_remove = set()
+ for _interface, sub_flow_table in self._sub_flow_tables.items():
+ remove = sub_flow_table._on_interface_delete(interface)
+ if not remove:
+ to_remove.add(_interface)
+ for _interface in to_remove:
+ del self._sub_flow_tables[_interface]
diff --git a/netmodel/network/interface.py b/netmodel/network/interface.py
index c9e31422..3bad4c41 100644
--- a/netmodel/network/interface.py
+++ b/netmodel/network/interface.py
@@ -44,7 +44,7 @@ class InterfaceState(enum.Enum):
def register_interfaces():
Interface._factory = dict()
- for loader, module_name, is_pkg in pkgutil.walk_packages(interfaces.__path__,
+ for loader, module_name, is_pkg in pkgutil.walk_packages(interfaces.__path__,
interfaces.__name__ + '.'):
# Note: we cannot skip files using is_pkg otherwise it will ignore
# interfaces defined outside of __init__.py
@@ -95,12 +95,12 @@ class Interface:
self._tx_buffer = list()
self._state = InterfaceState.Down
- self._error = None
+ self._error = None
self._reconnecting = True
self._reconnection_delay = RECONNECTION_DELAY
self._registered_objects = dict()
-
+
# Callbacks
self._up_callbacks = list()
self._down_callbacks = list()
@@ -119,7 +119,7 @@ class Interface:
def __hash__(self):
return hash(self._name)
- #---------------------------------------------------------------------------
+ #---------------------------------------------------------------------------
def register_object(self, obj):
self._registered_objects[obj.__type__] = obj
@@ -127,9 +127,9 @@ class Interface:
def get_prefixes(self):
return [ Prefix(v.__type__) for v in self._registered_objects.values() ]
- #---------------------------------------------------------------------------
+ #---------------------------------------------------------------------------
# State management, callbacks
- #---------------------------------------------------------------------------
+ #---------------------------------------------------------------------------
def set_state(self, state):
asyncio.ensure_future(self._set_state(state))
@@ -170,7 +170,7 @@ class Interface:
for cb, args, kwargs in self._delete_callbacks:
cb(interface, *args, **kwargs)
- #--------------------------------------------------------------------------
+ #--------------------------------------------------------------------------
def set_reconnecting(self, reconnecting):
self._reconnecting = reconnecting
@@ -253,13 +253,13 @@ class Interface:
def receive_impl(self, packet):
ingress_interface = self
cb = self._callback
- if cb is None:
- return
if self._hook:
new_packet = self._hook(packet)
- if new_packet is not None:
+ if cb is not None and new_packet is not None:
cb(new_packet, ingress_interface=ingress_interface)
return
+ if cb is None:
+ return
cb(packet, ingress_interface=ingress_interface)
#--------------------------------------------------------------------------
diff --git a/netmodel/network/packet.py b/netmodel/network/packet.py
index 9552b0e7..7edcccc4 100644
--- a/netmodel/network/packet.py
+++ b/netmodel/network/packet.py
@@ -109,7 +109,7 @@ class VICNTLV:
@staticmethod
def get_type(buf):
- (typelen, ) = struct.unpack(NETMODEL_TLV_TYPELEN_STR,
+ (typelen, ) = struct.unpack(NETMODEL_TLV_TYPELEN_STR,
buf[:NETMODEL_TLV_SIZE])
return (typelen & NETMODEL_TLV_TYPE_MASK) >> NETMODEL_TLV_TYPE_SHIFT
@@ -179,7 +179,7 @@ class VICNTLV:
class ObjectName(VICNTLV): pass
@VICNTLV.set_tlv_type(NETMODEL_TLV_FIELD)
-class Field(VICNTLV):
+class Field(VICNTLV):
"""Field == STR
"""
@@ -217,7 +217,7 @@ class Prefix(Object, Prefix_, VICNTLV):
def __hash__(self):
return hash(self.get_tuple())
-
+
@VICNTLV.set_tlv_type(NETMODEL_TLV_SRC)
class Source(Prefix):
"""Source address
@@ -251,12 +251,12 @@ class Packet(Object, VICNTLV):
source = Source()
destination = Destination(Prefix)
protocol = Protocol(String, default = 'query')
- flags = Flags()
- payload = Payload()
+ flags = Flags(String)
+ payload = Payload(String)
# This should be dispatched across L3 L4 L7
- def __init__(self, source = None, destination = None, protocol = None,
+ def __init__(self, source = None, destination = None, protocol = None,
flags = 0, payload = None):
self.source = source
self.destination = destination
@@ -272,8 +272,8 @@ class Packet(Object, VICNTLV):
packet = Packet()
if src_query:
address = Prefix(
- object_name = src_query.object_name,
- filter = src_query.filter,
+ object_name = src_query.object_name,
+ filter = src_query.filter,
field_names = src_query.field_names,
aggregate = src_query.aggregate)
if reply:
@@ -283,8 +283,8 @@ class Packet(Object, VICNTLV):
if query:
address = Prefix(
- object_name = query.object_name,
- filter = query.filter,
+ object_name = query.object_name,
+ filter = query.filter,
field_names = query.field_names,
aggregate = query.aggregate)
@@ -310,7 +310,7 @@ class Packet(Object, VICNTLV):
field_names = address.field_names
aggregate = address.aggregate
- return Query(action, object_name, filter, params, field_names,
+ return Query(action, object_name, filter, params, field_names,
aggregate = aggregate, last = self.last, reply = self.reply)
@property
@@ -335,6 +335,16 @@ class Packet(Object, VICNTLV):
else:
self.flags &= ~FLAG_REPLY
+ def get_tuple(self):
+ return (self.source, self.destination, self.protocol, self.flags,
+ self.payload)
+
+ def __eq__(self, other):
+ return self.get_tuple() == other.get_tuple()
+
+ def __hash__(self):
+ return hash(self.get_tuple())
+
class ErrorPacket(Packet):
"""
Analog with ICMP errors packets in IP networks
@@ -344,7 +354,7 @@ class ErrorPacket(Packet):
# Constructor
#--------------------------------------------------------------------------
- def __init__(self, type = ERROR, code = ERROR, message = None,
+ def __init__(self, type = ERROR, code = ERROR, message = None,
traceback = None, **kwargs):
assert not traceback or isinstance(traceback, str)
diff --git a/netmodel/network/prefix.py b/netmodel/network/prefix.py
index 00b5db71..d444a56d 100644
--- a/netmodel/network/prefix.py
+++ b/netmodel/network/prefix.py
@@ -17,20 +17,23 @@
#
class Prefix:
- def __init__(self, object_name = None, filter = None, field_names = None,
+ def __init__(self, object_name = None, filter = None, field_names = None,
aggregate = None):
self.object_name = object_name
self.filter = filter
self.field_names = field_names
self.aggregate = aggregate
- def __hash__(self):
- return hash(self.get_tuple())
-
def get_tuple(self):
- return (self.object_name, self.filter, self.field_names,
+ return (self.object_name, self.filter, self.field_names,
self.aggregate)
+ def __eq__(self, other):
+ return self.get_tuple() == other.get_tuple()
+
+ def __hash__(self):
+ return hash(self.get_tuple())
+
def __repr__(self):
return '<Prefix {}>'.format(self.get_tuple())
diff --git a/netmodel/network/router.py b/netmodel/network/router.py
index 84d69dca..871cefe0 100644
--- a/netmodel/network/router.py
+++ b/netmodel/network/router.py
@@ -68,7 +68,7 @@ class Router:
#--------------------------------------------------------------------------
def register_local_collection(self, cls):
- self.get_interface(LOCAL_NAMESPACE).register_collection(cls,
+ self.get_interface(LOCAL_NAMESPACE).register_collection(cls,
LOCAL_NAMESPACE)
def register_collection(self, cls, namespace=None):
@@ -97,7 +97,7 @@ class Router:
interface.set_state(InterfaceState.PendingUp)
for prefix in interface.get_prefixes():
- self._fib.add(prefix, [interface])
+ self._fib.add(prefix, set([interface]))
return interface
@@ -112,8 +112,8 @@ class Router:
def on_interface_up(self, interface):
"""
- This callback is triggered when an interface becomes up.
-
+ This callback is triggered when an interface becomes up.
+
The router will request metadata.
The flow table is notified.
"""
@@ -139,7 +139,7 @@ class Router:
# Public API
#---------------------------------------------------------------------------
- def add_interface(self, interface_type, name=None, namespace=None,
+ def add_interface(self, interface_type, name=None, namespace=None,
**platform_config):
"""
namespace is used to force appending of a namespace to the tables.
@@ -161,7 +161,7 @@ class Router:
interface = interface_cls(self, **platform_config)
except Exception as e:
traceback.print_exc()
- raise Exception("Cannot create interface %s of type %s with parameters %r: %s"
+ raise Exception("Cannot create interface %s of type %s with parameters %r: %s"
% (name, interface_type,
platform_config, e))
self._register_interface(interface, name)
@@ -231,9 +231,10 @@ class Router:
interface flow table
- using the FIB if no match is found
"""
- orig_interface = self._flow_table.match(packet, ingress_interface)
- if orig_interface:
- orig_interface.send(packet)
+ orig_interfaces = self._flow_table.match(packet, ingress_interface)
+ if orig_interfaces:
+ for orig_interface in orig_interfaces:
+ orig_interface.send(packet)
return
if isinstance(packet, str):
@@ -247,11 +248,17 @@ class Router:
return
# Get route from FIB
- interface = self._fib.get(packet.destination.object_name)
- if not interface:
+ if packet.destination is None:
+ log.warning("Ignored reply packet with no match in flow table {}".format(
+ packet.to_query()))
+ return
+ interfaces = self._fib.get(packet.destination.object_name)
+ if not interfaces:
+ log.error('No match in FIB for {}'.format(
+ packet.destination.object_name))
return
# Update flow table before sending
- self._flow_table.add(packet, ingress_interface, interface)
-
- interface.send(packet)
+ self._flow_table.add(packet, ingress_interface, interfaces)
+ for interface in interfaces:
+ interface.send(packet)
diff --git a/netmodel/util/daemon.py b/netmodel/util/daemon.py
index 29683a54..eb8cd1a2 100644
--- a/netmodel/util/daemon.py
+++ b/netmodel/util/daemon.py
@@ -231,7 +231,7 @@ class Daemon:
"""
Overload this method if you use twisted (see xmlrpc.py)
"""
- sys.exit(0)
+ os._exit(0)
# Overload these...