From 85a341d645b57b7cd88a26ed2ea0a314704240ea Mon Sep 17 00:00:00 2001 From: Jordan Augé Date: Fri, 24 Feb 2017 14:58:01 +0100 Subject: Initial commit: vICN MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I7ce66c4e84a6a1921c63442f858b49e083adc7a7 Signed-off-by: Jordan Augé --- netmodel/__init__.py | 0 netmodel/interfaces/__init__.py | 0 netmodel/interfaces/local.py | 60 +++++ netmodel/interfaces/process/__init__.py | 215 +++++++++++++++ netmodel/interfaces/socket/__init__.py | 171 ++++++++++++ netmodel/interfaces/socket/tcp.py | 86 ++++++ netmodel/interfaces/socket/udp.py | 88 ++++++ netmodel/interfaces/socket/unix.py | 87 ++++++ netmodel/interfaces/vicn.py | 100 +++++++ netmodel/interfaces/websocket/__init__.py | 358 +++++++++++++++++++++++++ netmodel/model/__init__.py | 0 netmodel/model/attribute.py | 262 ++++++++++++++++++ netmodel/model/field_names.py | 396 +++++++++++++++++++++++++++ netmodel/model/filter.py | 397 +++++++++++++++++++++++++++ netmodel/model/mapper.py | 20 ++ netmodel/model/object.py | 231 ++++++++++++++++ netmodel/model/predicate.py | 306 +++++++++++++++++++++ netmodel/model/query.py | 147 ++++++++++ netmodel/model/result_value.py | 185 +++++++++++++ netmodel/model/sql_parser.py | 221 +++++++++++++++ netmodel/model/type.py | 89 +++++++ netmodel/network/__init__.py | 0 netmodel/network/fib.py | 63 +++++ netmodel/network/flow.py | 57 ++++ netmodel/network/flow_table.py | 150 +++++++++++ netmodel/network/interface.py | 281 +++++++++++++++++++ netmodel/network/packet.py | 429 ++++++++++++++++++++++++++++++ netmodel/network/prefix.py | 37 +++ netmodel/network/router.py | 257 ++++++++++++++++++ netmodel/util/__init__.py | 0 netmodel/util/argparse.py | 22 ++ netmodel/util/color.py | 113 ++++++++ netmodel/util/daemon.py | 245 +++++++++++++++++ netmodel/util/debug.py | 45 ++++ netmodel/util/deprecated.py | 35 +++ netmodel/util/log.py | 120 +++++++++ netmodel/util/meta.py | 30 +++ netmodel/util/misc.py | 62 +++++ netmodel/util/process.py | 35 +++ netmodel/util/sa_compat.py | 265 ++++++++++++++++++ netmodel/util/singleton.py | 41 +++ netmodel/util/socket.py | 31 +++ netmodel/util/toposort.py | 82 ++++++ 43 files changed, 5819 insertions(+) create mode 100644 netmodel/__init__.py create mode 100644 netmodel/interfaces/__init__.py create mode 100644 netmodel/interfaces/local.py create mode 100644 netmodel/interfaces/process/__init__.py create mode 100644 netmodel/interfaces/socket/__init__.py create mode 100644 netmodel/interfaces/socket/tcp.py create mode 100644 netmodel/interfaces/socket/udp.py create mode 100644 netmodel/interfaces/socket/unix.py create mode 100644 netmodel/interfaces/vicn.py create mode 100644 netmodel/interfaces/websocket/__init__.py create mode 100644 netmodel/model/__init__.py create mode 100644 netmodel/model/attribute.py create mode 100644 netmodel/model/field_names.py create mode 100644 netmodel/model/filter.py create mode 100644 netmodel/model/mapper.py create mode 100644 netmodel/model/object.py create mode 100644 netmodel/model/predicate.py create mode 100644 netmodel/model/query.py create mode 100644 netmodel/model/result_value.py create mode 100644 netmodel/model/sql_parser.py create mode 100644 netmodel/model/type.py create mode 100644 netmodel/network/__init__.py create mode 100644 netmodel/network/fib.py create mode 100644 netmodel/network/flow.py create mode 100644 netmodel/network/flow_table.py create mode 100644 netmodel/network/interface.py create mode 100644 netmodel/network/packet.py create mode 100644 netmodel/network/prefix.py create mode 100644 netmodel/network/router.py create mode 100644 netmodel/util/__init__.py create mode 100644 netmodel/util/argparse.py create mode 100644 netmodel/util/color.py create mode 100644 netmodel/util/daemon.py create mode 100644 netmodel/util/debug.py create mode 100644 netmodel/util/deprecated.py create mode 100644 netmodel/util/log.py create mode 100644 netmodel/util/meta.py create mode 100644 netmodel/util/misc.py create mode 100644 netmodel/util/process.py create mode 100644 netmodel/util/sa_compat.py create mode 100644 netmodel/util/singleton.py create mode 100644 netmodel/util/socket.py create mode 100644 netmodel/util/toposort.py (limited to 'netmodel') diff --git a/netmodel/__init__.py b/netmodel/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netmodel/interfaces/__init__.py b/netmodel/interfaces/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netmodel/interfaces/local.py b/netmodel/interfaces/local.py new file mode 100644 index 00000000..c68dec7e --- /dev/null +++ b/netmodel/interfaces/local.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from netmodel.model.attribute import Attribute +from netmodel.model.query import Query, ACTION_INSERT +from netmodel.model.object import Object +from netmodel.model.type import String +from netmodel.network.interface import Interface, InterfaceState +from netmodel.network.packet import Packet +from netmodel.network.prefix import Prefix +from netmodel.util.misc import lookahead + +class LocalObjectInterface(Object): + __type__ = 'local/interface' + + name = Attribute(String) + type = Attribute(String) + status = Attribute(String) + description = Attribute(String) + + @classmethod + def get(cls, query, ingress_interface): + cb = ingress_interface._callback + interfaces = ingress_interface._router.get_interfaces() + for interface, last in lookahead(interfaces): + interface_dict = { + 'name': interface.name, + 'type': interface.__interface__, + 'status': interface.get_status(), + 'description': interface.get_description(), + } + reply = Query(ACTION_INSERT, query.object_name, params = + interface_dict) + reply.last = last + packet = Packet.from_query(reply, reply = True) + cb(packet, ingress_interface = ingress_interface) + +class LocalInterface(Interface): + __interface__ = 'local' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._router = kwargs.pop('router') + self.register_object(LocalObjectInterface) + diff --git a/netmodel/interfaces/process/__init__.py b/netmodel/interfaces/process/__init__.py new file mode 100644 index 00000000..b985c32f --- /dev/null +++ b/netmodel/interfaces/process/__init__.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import shlex +import socket +import subprocess +import threading + +from netmodel.network.interface import Interface as BaseInterface +from netmodel.network.packet import Packet +from netmodel.network.prefix import Prefix +from netmodel.model.attribute import Attribute +from netmodel.model.filter import Filter +from netmodel.model.object import Object +from netmodel.model.query import Query, ACTION_UPDATE +from netmodel.model.query import ACTION_SUBSCRIBE, FUNCTION_SUM +from netmodel.model.type import String, Integer, Double + +DEFAULT_INTERVAL = 1 # s +KEY_FIELD = 'device_name' + +class Interface(Object): + __type__ = 'interface' + + node = Attribute(String) + device_name = Attribute(String) + bw_upstream = Attribute(Double) # bytes + bw_downstream = Attribute(Double) # bytes + +class Process(threading.Thread): + pass + +class BWMThread(Process): + + SEP=';' + CMD="stdbuf -oL bwm-ng -t 1000 -N -o csv -c 0 -C '%s'" + + # Parsing information (from README, specs section) + # https://github.com/jgjl/bwm-ng/blob/master/README + # + # Type rate: + FIELDS_RATE = ['timestamp', 'iface_name', 'bytes_out_s', 'bytes_in_s', + 'bytes_total_s', 'bytes_in', 'bytes_out', 'packets_out_s', + 'packets_in_s', 'packets_total_s', 'packets_in', 'packets_out', + 'errors_out_s', 'errors_in_s', 'errors_in', 'errors_out'] + # Type svg, sum, max + FIELDS_SUM = ['timestamp', 'iface_name', 'bytes_out', 'bytes_in', + 'bytes_total', 'packets_out', 'packets_in', 'packets_total', + 'errors_out', 'errors_in'] + + def __init__(self, interfaces, callback): + threading.Thread.__init__(self) + + # The list of interfaces is used for filtering + self.groups_of_interfaces = set(interfaces) + + self._callback = callback + self._is_running = False + + def run(self): + cmd = self.CMD % (self.SEP) + p = subprocess.Popen(shlex.split(cmd), stdout = subprocess.PIPE, + stderr = subprocess.STDOUT) + stdout = [] + self._is_running = True + self.bwm_stats = dict() + while self._is_running: + line = p.stdout.readline().decode() + if line == '' and p.poll() is not None: + break + if line: + record = self._parse_line(line.strip()) + # We use 'total' to push the statistics back to VICN + if record['iface_name'] == 'total': + for interfaces in self.groups_of_interfaces: + if not len(interfaces) > 1: + # If the tuple contains only one interface, grab + # the information from bwm_stats and sends it back + # to VICN + if interfaces[0] not in self.bwm_stats: + continue + interface = self.bwm_stats[interfaces[0]] + f_list = [[KEY_FIELD, '==', interface.device_name]] + query = Query(ACTION_UPDATE, Interface.__type__, + filter = Filter.from_list(f_list), + params = interface.get_attribute_dict()) + self._callback(query) + else: + # Iterate over each tuple of interfaces to create + # the aggregated filter and paramters to send back + # Currently, we only support sum among the stats + # when VICN subscribes to a tuple of interfaces + aggregated_filters = list() + aggregated_interface = Interface( + node = socket.gethostname(), + device_name = 'sum', + bw_upstream = 0, + bw_downstream = 0) + predicate = list() + predicate.append(KEY_FIELD) + predicate.append('INCLUDED') + for interface in interfaces: + if interface not in self.bwm_stats: + continue + iface = self.bwm_stats[interface] + aggregated_filters.append(iface.device_name) + aggregated_interface.bw_upstream += \ + iface.bw_upstream + aggregated_interface.bw_downstream += \ + iface.bw_downstream + + if not aggregated_filters: + continue + predicate.append(aggregated_filters) + + # We support mulitple interfaces only if tied up + # with the SUM function. The update must have the + # sum function specified because it is used to + # match the subscribe query + attrs = aggregated_interface.get_attribute_dict() + query = Query(ACTION_UPDATE, Interface.__type__, + filter = Filter.from_list([predicate]), + params = attrs, + aggregate = FUNCTION_SUM) + self._callback(query) + else: + # Statistics from netmodel.network.interface will be stored + # in self.bwm_stats and used later to construct the update + # queries + interface = Interface( + node = socket.gethostname(), + device_name = record['iface_name'], + bw_upstream = float(record['bytes_out_s']), + bw_downstream = float(record['bytes_in_s']), + ) + + self.bwm_stats[record['iface_name']] = interface + + rc = p.poll() + return rc + + def stop(self): + self._is_running = False + + def _parse_line(self, line): + return dict(zip(self.FIELDS_RATE, line.split(self.SEP))) + +class BWMInterface(BaseInterface): + __interface__ = 'bwm' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._thread = None + + self.register_object(Interface) + + def terminate(self): + self._thread.stop() + + def _on_reply(self, reply): + packet = Packet.from_query(reply, reply = True) + self.receive(packet) + + #-------------------------------------------------------------------------- + # Router interface + #-------------------------------------------------------------------------- + + def send_impl(self, packet): + query = packet.to_query() + + assert query.action == ACTION_SUBSCRIBE + interval = query.params.get('interval', DEFAULT_INTERVAL) \ + if query.params else DEFAULT_INTERVAL + assert interval + + # TODO: Add the sum operator. If sum the list of interfaces is + # added to the BWMThread as a tuple, otherwise every single + # interface will be added singularly + + # We currently simply extract it from the filter + interfaces_list = [p.value for p in query.filter if p.key == KEY_FIELD] + + # interfaces is a list of tuple. If someone sbscribe to mulitple + # interfaces interfaces will be a list of 1 tuple. The tuple will + # contain the set of interfaces + assert len(interfaces_list) == 1 + interfaces = interfaces_list[0] \ + if isinstance(interfaces_list[0], tuple) \ + else tuple([interfaces_list[0]]) + + # Check if interfaces is more than one. In this case, we only support + # The SUM function on the list of field. + if len(interfaces) > 1: + assert query.aggregate == FUNCTION_SUM + + if self._thread is None: + self._thread = BWMThread(tuple([interfaces]), self._on_reply) + self._thread.start() + else: + self._thread.groups_of_interfaces.add(interfaces) diff --git a/netmodel/interfaces/socket/__init__.py b/netmodel/interfaces/socket/__init__.py new file mode 100644 index 00000000..00581eb4 --- /dev/null +++ b/netmodel/interfaces/socket/__init__.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import json + +from netmodel.network.packet import Packet +from netmodel.network.interface import Interface, InterfaceState + +class Protocol(asyncio.Protocol): + def __init__(self): + self._transport = None + + def terminate(self): + if self._transport: + self._transport.close() + + def send(self, packet): + if isinstance(packet, Packet): + data = json.dumps(packet.to_query().to_dict()) + else: + data = packet + self.send_impl(data) + + def receive(self, data, ingress_interface): + try: + packet = Packet.from_query(Query.from_dict(json.loads(data))) + except: + packet = data + self.receive(packet, ingress_interface) + + +class ServerProtocol(Protocol): + # asyncio.Protocol + + def __init__(self): + super().__init__() + + def connection_made(self, transport): + """ + Called when a connection is made. + The argument is the _transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + self._transport = transport + self.set_state(InterfaceState.Up) + + def connection_lost(self, exc): + """ + Called when the connection is lost or closed. + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + self.set_state(InterfaceState.Down) + +class ClientProtocol(Protocol): + def __init__(self, interface, *args, **kwargs): + super().__init__(*args, **kwargs) + self._interface = interface + + # asyncio.Protocol + + def connection_made(self, transport): + """ + Called when a connection is made. + The argument is the _transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + self._transport = transport + self._interface.set_state(InterfaceState.Up) + + def connection_lost(self, exc): + """ + Called when the connection is lost or closed. + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + self._interface.set_state(InterfaceState.Down) + + + +#------------------------------------------------------------------------------ + +class SocketServer: + def __init__(self, *args, **kwargs): + # For a server, an instance of asyncio.base_events.Server + self._transport = None + self._clients = list() + + def terminate(self): + """Close the server and terminate all clients. + """ + self._socket.close() + for client in self._clients: + self._clients.terminate() + + def send(self, packet): + """Broadcast packet to all connected clients. + """ + for client in self._clients: + self._clients.send(packet) + + def receive(self, packet): + raise RuntimeError('Unexpected packet received by server interface') + + def __repr__(self): + if self._transport: + peername = self._transport.get_extra_info('peername') + else: + peername = 'not connected' + return ''.format(self.__interface__, peername) + + async def pending_up_impl(self): + try: + self._server = await self._create_socket() + except Exception as e: + await self._set_state(InterfaceState.Error) + self._error = str(e) + + # Only the server interface is up once the socket has been created and + # is listening... + await self._set_state(InterfaceState.Up) + +class SocketClient: + def __init__(self, *args, **kwargs): + # For a client connection, this is a tuple + # (_SelectorSocketTransport, protocol) + self._transport = None + self._protocol = None + + def send_impl(self, packet): + self._protocol.send(packet) + + async def pending_up_impl(self): + try: + self._transport, self._protocol = await self._create_socket() + except Exception as e: + await self._set_state(InterfaceState.Error) + self._error = str(e) + + def pending_down_impl(self): + if self._socket: + self._transport.close() + + def __repr__(self): + if self._socket: + peername = self._transport.get_extra_info('peername') + else: + peername = 'not connected' + return ''.format(self.__interface__, peername) + + diff --git a/netmodel/interfaces/socket/tcp.py b/netmodel/interfaces/socket/tcp.py new file mode 100644 index 00000000..5b886d9a --- /dev/null +++ b/netmodel/interfaces/socket/tcp.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio + +from netmodel.interfaces.socket import ServerProtocol, ClientProtocol +from netmodel.interfaces.socket import SocketClient, SocketServer +from netmodel.network.interface import Interface + +DEFAULT_ADDRESS = '127.0.0.1' +DEFAULT_PORT = 7000 + +class TCPProtocol: + def send_impl(self, data): + msg = data.encode() + self._transport.write(msg) + + # asyncio.Protocol + + def data_received(self, data): + """ + Called when some data is received. + The argument is a bytes object. + """ + msg = data.decode() + self.receive(msg, ingress_interface=self) + +class TCPServerProtocol(TCPProtocol, ServerProtocol, Interface): + __interface__ = 'tcp' + + def __init__(self, *args, **kwargs): + # Note: super() does not call all parents' constructors + Interface.__init__(self, *args, **kwargs) + ServerProtocol.__init__(self) + +class TCPClientProtocol(TCPProtocol, ClientProtocol): + pass + +#------------------------------------------------------------------------------ + +class TCPServerInterface(SocketServer, Interface): + __interface__ = 'tcpserver' + + def __init__(self, *args, **kwargs): + SocketServer.__init__(self) + self._address = kwargs.pop('address', DEFAULT_ADDRESS) + self._port = kwargs.pop('port', DEFAULT_PORT) + Interface.__init__(self, *args, **kwargs) + + def new_protocol(self): + p = TcpServerProtocol(callback = self._callback, hook=self._hook) + self.spawn_interface(p) + return p + + def _create_socket(self): + loop = asyncio.get_event_loop() + return loop.create_server(self.new_protocol, self._address, self._port) + +class TCPClientInterface(SocketClient, Interface): + __interface__ = 'tcpclient' + + def __init__(self, *args, **kwargs): + SocketClient.__init__(self) + self._address = kwargs.pop('address', DEFAULT_ADDRESS) + self._port = kwargs.pop('port', DEFAULT_PORT) + Interface.__init__(self, *args, **kwargs) + + def _create_socket(self): + loop = asyncio.get_event_loop() + protocol = lambda : TCPClientProtocol(self) + return loop.create_connection(protocol, self._address, self._port) diff --git a/netmodel/interfaces/socket/udp.py b/netmodel/interfaces/socket/udp.py new file mode 100644 index 00000000..d3fdb696 --- /dev/null +++ b/netmodel/interfaces/socket/udp.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio + +from netmodel.interfaces.socket import ServerProtocol, ClientProtocol +from netmodel.interfaces.socket import SocketClient, SocketServer +from netmodel.network.interface import Interface + +DEFAULT_ADDRESS = '127.0.0.1' +DEFAULT_PORT = 7000 + +class UDPProtocol: + + def send_impl(self, data): + msg = data.encode() + self._transport.sendto(msg) + + # asyncio.Protocol + + def datagram_received(self, data, addr): + msg = data.decode() + self.receive(msg, ingress_interface=self) + + def error_received(self, exc): + print('Error received:', exc) + +class UDPServerProtocol(UDPProtocol, ServerProtocol, Interface): + __interface__ = 'udp' + + def __init__(self, *args, **kwargs): + # Note: super() does not call all parents' constructors + Interface.__init__(self, *args, **kwargs) + ServerProtocol.__init__(self) + +class UDPClientProtocol(UDPProtocol, ClientProtocol): + pass + +#------------------------------------------------------------------------------ + +class UDPServerInterface(SocketServer, Interface): + __interface__ = 'udpserver' + + def __init__(self, *args, **kwargs): + SocketServer.__init__(self) + self._address = kwargs.pop('address', DEFAULT_ADDRESS) + self._port = kwargs.pop('port', DEFAULT_PORT) + Interface.__init__(self, *args, **kwargs) + + def new_protocol(self): + p = UdpServerProtocol(callback = self._callback, hook=self._hook) + self.spawn_interface(p) + return p + + def _create_socket(self): + loop = asyncio.get_event_loop() + return loop.create_datagram_endpoint(self.new_protocol, + local_addr=(self._address, self._port)) + +class UDPClientInterface(SocketClient, Interface): + __interface__ = 'udpclient' + + def __init__(self, *args, **kwargs): + SocketClient.__init__(self) + self._address = kwargs.pop('address', DEFAULT_ADDRESS) + self._port = kwargs.pop('port', DEFAULT_PORT) + Interface.__init__(self, *args, **kwargs) + + def _create_socket(self): + loop = asyncio.get_event_loop() + protocol = lambda : UDPClientProtocol(self) + return loop.create_datagram_endpoint(protocol, + remote_addr=(self._address, self._port)) diff --git a/netmodel/interfaces/socket/unix.py b/netmodel/interfaces/socket/unix.py new file mode 100644 index 00000000..eec3d680 --- /dev/null +++ b/netmodel/interfaces/socket/unix.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio + +from netmodel.interfaces.socket import ServerProtocol, ClientProtocol +from netmodel.interfaces.socket import SocketClient, SocketServer +from netmodel.network.interface import Interface +from netmodel.util.misc import silentremove + +DEFAULT_PATH = '/tmp/unix_interface' + +class UnixProtocol: + + def send_impl(self, data): + msg = data.encode() + self._transport.write(msg) + + def data_received(self, data): + """ + Called when some data is received. + The argument is a bytes object. + """ + msg = data.decode() + self.receive(msg, ingress_interface=self) + +class UnixServerProtocol(UnixProtocol, ServerProtocol, Interface): + __interface__ = 'unix' + + def __init__(self, *args, **kwargs): + # Note: super() does not call all parents' constructors + Interface.__init__(self, *args, **kwargs) + ServerProtocol.__init__(self) + +class UnixClientProtocol(UnixProtocol, ClientProtocol): + pass + +#------------------------------------------------------------------------------ + +class UnixServerInterface(SocketServer, Interface): + __interface__ = 'unixserver' + + def __init__(self, *args, **kwargs): + SocketServer.__init__(self) + self._path = kwargs.pop('path', DEFAULT_PATH) + Interface.__init__(self, *args, **kwargs) + + def terminate(self): + silentremove(self._path) + + def new_protocol(self): + p = UnixServerProtocol(callback=self._callback, hook=self._hook) + self.spawn_interface(p) + return p + + def _create_socket(self): + loop = asyncio.get_event_loop() + silentremove(self._path) + return loop.create_unix_server(self.new_protocol, self._path) + +class UnixClientInterface(SocketClient, Interface): + __interface__ = 'unixclient' + + def __init__(self, *args, **kwargs): + SocketClient.__init__(self) + self._path = kwargs.pop('path', DEFAULT_PATH) + Interface.__init__(self, *args, **kwargs) + + def _create_socket(self): + loop = asyncio.get_event_loop() + protocol = lambda : UnixClientProtocol(self) + return loop.create_unix_connection(protocol, self._path) diff --git a/netmodel/interfaces/vicn.py b/netmodel/interfaces/vicn.py new file mode 100644 index 00000000..9ec9672e --- /dev/null +++ b/netmodel/interfaces/vicn.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from vicn.core.task import BashTask +from netmodel.model.object import Object +from netmodel.model.attribute import Attribute +from netmodel.model.query import Query, ACTION_INSERT +from netmodel.model.type import String +from netmodel.network.interface import Interface, InterfaceState +from netmodel.network.packet import Packet +from netmodel.network.prefix import Prefix +from netmodel.util.misc import lookahead + +class VICNBaseResource(Object): + __type__ = 'vicn/' + + @classmethod + def get(cls, query, interface): + cb = interface._callback + + if query.object_name == 'script': + predicates = query.filter.to_list() + assert len(predicates) == 1 + _, _, name = predicates[0] + script = '{}/{}'.format(interface._manager._base, name) + + task = BashTask(None, script) + interface._manager.schedule(task) + return + + elif query.object_name == 'gui': + interface._manager._broadcast(query) + return + + elif query.object_name == 'resource': + resources = interface._manager.get_resources() + else: + resources = interface._manager.by_type_str(query.object_name) + + for resource, last in lookahead(resources): + params = resource.get_attribute_dict(aggregates = True) + params['id'] = resource._state.uuid._uuid + params['type'] = resource.get_types() + params['state'] = resource._state.state + params['log'] = resource._state.log + reply = Query(ACTION_INSERT, query.object_name, params = params) + reply.last = last + packet = Packet.from_query(reply, reply = True) + cb(packet, ingress_interface = interface) + +class L2Graph(Object): + __type__ = 'vicn/l2graph' + + @classmethod + def get(cls, query, interface): + cb = interface._callback + + from vicn.resource.central import _get_l2_graph + G = _get_l2_graph(interface._manager, with_managed=True) + + nodes = G.nodes() + edges = G.edges() + params = {'nodes': nodes, 'edges': edges} + reply = Query(ACTION_INSERT, query.object_name, params = params) + reply.last = True + packet = Packet.from_query(reply, reply = True) + cb(packet, ingress_interface = interface) + +class VICNInterface(Interface): + __interface__ = 'vicn' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._manager = kwargs.pop('manager') + + # Resources + resources = list() + resources.extend(self._manager.get_resource_type_names()) + resources.append('resource') + for resource in resources: + class VICNResource(VICNBaseResource): + __type__ = '{}'.format(resource.lower()) + self.register_object(VICNResource) + + self.register_object(L2Graph) diff --git a/netmodel/interfaces/websocket/__init__.py b/netmodel/interfaces/websocket/__init__.py new file mode 100644 index 00000000..cb79fc39 --- /dev/null +++ b/netmodel/interfaces/websocket/__init__.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import logging +import json + +from netmodel.network.interface import Interface, InterfaceState +from netmodel.network.packet import Packet +from netmodel.model.query import Query +from netmodel.model.query import ACTION_INSERT, ACTION_SELECT +from netmodel.model.query import ACTION_UPDATE, ACTION_DELETE +from netmodel.model.query import ACTION_EXECUTE + +from autobahn.asyncio.websocket import WebSocketClientProtocol, \ + WebSocketClientFactory +from autobahn.asyncio.websocket import WebSocketServerProtocol, \ + WebSocketServerFactory + +log = logging.getLogger(__name__) + +DEFAULT_ADDRESS = '0.0.0.0' +DEFAULT_CLIENT_ADDRESS = '127.0.0.1' +DEFAULT_PORT = 9000 +DEFAULT_TIMEOUT = 2 + +#------------------------------------------------------------------------------ + +from json import JSONEncoder +class DictEncoder(JSONEncoder): + """Default JSON encoder + + Because some classes are not JSON serializable, we define here our own + encoder which is based on the member variables of the object. + + The ideal solution would be to make all objects JSON serializable, but this + encoder is useful for user-defined classes that would otherwise make the + whole program to fail. It might though raise a warning to incitate + developers to make their class conformant. + + Reference: + http://stackoverflow.com/questions/3768895/how-to-make-a-class-json-serializable + """ + def default(self, o): + try: + return vars(o) + except: + return {} + +#------------------------------------------------------------------------------ + +class ClientProtocol(WebSocketClientProtocol): + """ + Default WebSocket client protocol. + + This protocol is mainly used to relay events to the Interface, which is + pointer to by the factory. + """ + + #-------------------------------------------------------------------------- + # Internal methods + #-------------------------------------------------------------------------- + + def send_impl(self, packet): + msg = json.dumps(packet.to_query().to_dict()) + self.sendMessage(msg.encode(), False) + + #-------------------------------------------------------------------------- + # WebSocket events + #-------------------------------------------------------------------------- + + # Websocket events + + def onConnect(self, response): + """ + Websocket opening handshake is started by the client. + """ + self.factory.interface._on_client_connect(self, response) + + def onOpen(self): + """ + Websocket opening handshake has completed. + """ + self.factory.interface._on_client_open(self) + + def onMessage(self, payload, isBinary): + self.factory.interface._on_client_message(self, payload, isBinary) + + def onClose(self, wasClean, code, reason): + self.factory.interface._on_client_close(self, wasClean, code, reason) + +#------------------------------------------------------------------------------ + +class WebSocketClientInterface(Interface): + """ + All messages are exchanged using text (non-binary) mode. + """ + __interface__ = 'websocketclient' + + def __init__(self, *args, **kwargs): + """ + Constructor. + + Args: + address (str) : Address of the remote websocket server. Defaults to + 127.0.0.1 (localhost). + port (int) : Port of the remote websocket server. Defaults to 9999. + + This constructor triggers the initialization of a WebSocket client + factory, which is associated a ClientProtocol, as well as a reference + to the current interface. + + PendingUp --- connect --- Up ...disconnect... Down + A | + +-----+ + retry + + All messages are exchanged using text (non-binary) mode. + """ + + self._address = kwargs.pop('address', DEFAULT_CLIENT_ADDRESS) + self._port = kwargs.pop('port', DEFAULT_PORT) + self._timeout = kwargs.pop('timeout', DEFAULT_TIMEOUT) + + super().__init__(*args, **kwargs) + + self._factory = WebSocketClientFactory("ws://{}:{}".format( + self._address, self._port)) + self._factory.protocol = ClientProtocol + self._factory.interface = self + + self._instance = None + + # Holds the instance of the connect client protocol + self._client = None + + #-------------------------------------------------------------------------- + # Interface API + #-------------------------------------------------------------------------- + + async def pending_up_impl(self): + await self._connect() + + def send_impl(self, packet): + if not self._client: + log.error('interface is up but has no client') + self._client.send_impl(packet) + + #-------------------------------------------------------------------------- + # Internal methods + #-------------------------------------------------------------------------- + + async def _connect(self): + loop = asyncio.get_event_loop() + try: + self._instance = await loop.create_connection(self._factory, + self._address, self._port) + except Exception as e: + log.warning('Connect failed : {}'.format(e)) + self._instance = None + # don't await for retry, since it cause an infinite recursion... + asyncio.ensure_future(self._retry()) + + async def _retry(self): + """ + Timer: retry connection after timeout. + """ + log.info('Reconnecting in {} seconds...'.format(self._timeout)) + await asyncio.sleep(self._timeout) + log.info('Reconnecting...') + await self._connect() + + # WebSocket events (from the underlying protocol) + + def _on_client_connect(self, client, response): + self._client = client + + def _on_client_open(self, client): + self.set_state(InterfaceState.Up) + + def _on_client_message(self, client, payload, isBinary): + """ + Event: a message is received by the WebSocket client connection. + """ + + assert not isBinary + + args = json.loads(payload.decode('utf-8')) + query, record = None, None + if len(args) == 2: + query, record = args + else: + query = args + + if isinstance(query, dict): + query = Query.from_dict(query) + else: + query = Query(ACTION_SELECT, query) + + packet = Packet.from_query(query) + + self.receive(packet) + + def _on_client_close(self, client, wasClean, code, reason): + self._client = None + self._instance = None + + self.set_state(InterfaceState.Down) + + # Schedule reconnection + asyncio.ensure_future(self._retry()) + +#------------------------------------------------------------------------------ + +class ServerProtocol(WebSocketServerProtocol, Interface): + """ + Default WebSocket server protocol. + + This protocol is used for every server-side accepted WebSocket connection. + As such it is an Interface on its own, and should handle the Interface state + machinery. + + We would better triggering code directly + """ + __interface__ = 'websocket' + + def __init__(self, callback, hook): + """ + Constructor. + + Args: + callback (Function[ -> ]) : + hook (Function[->]) : Hook method to be called for every packet to + be sent on the interface. Processing continues with the packet + returned by this function, or is interrupted in case of a None + value. Defaults to None = no hook. + """ + WebSocketServerProtocol.__init__(self) + Interface.__init__(self, callback=callback, hook=hook) + + #-------------------------------------------------------------------------- + # Interface API + #-------------------------------------------------------------------------- + + async def pending_up_impl(self): + await self._set_state(InterfaceState.Up) + + def send_impl(self, packet): + # We assume we only send records... + msg = json.dumps(packet.to_query().to_dict(), cls=DictEncoder) + self.sendMessage(msg.encode(), False) + + #-------------------------------------------------------------------------- + # Internal methods + #-------------------------------------------------------------------------- + + # Websocket events + + def onConnect(self, request): + self.factory._instances.append(self) + self.set_state(InterfaceState.Up) + + def onOpen(self): + #print("WebSocket connection open.") + pass + + def onMessage(self, payload, isBinary): + assert not isBinary, "Binary message received: {0} bytes".format( + len(payload)) + query_dict = json.loads(payload.decode('utf8')) + query = Query.from_dict(query_dict) + packet = Packet.from_query(query) + self.receive(packet) + + def onClose(self, wasClean, code, reason): + self.set_state(InterfaceState.Down) + try: + self.factory._instances.remove(self) + except: pass + + self.delete_interface(self) + +#------------------------------------------------------------------------------ + +class WebSocketServerInterface(Interface): + """ + This virtual interface only listens for incoming connections in order to + dynamically instanciate new interfaces upon client connection. + + It is also used to broadcast packets to all connected clients. + + All messages are exchanged using text (non-binary) mode. + """ + + __interface__ = 'websocketserver' + + def __init__(self, *args, **kwargs): + self._address = kwargs.pop('address', DEFAULT_ADDRESS) + self._port = kwargs.pop('port', DEFAULT_PORT) + + super().__init__(*args, **kwargs) + + def new_server_protocol(): + p = ServerProtocol(self._callback, self._hook) + self.spawn_interface(p) + return p + + ws_url = u"ws://{}:{}".format(self._address, self._port) + self._factory = WebSocketServerFactory(ws_url) + # see comment in MyWebSocketServerFactory + self._factory.protocol = new_server_protocol + self._factory._callback = self._callback + self._factory._interface = self + + # A list of all connected instances (= interfaces), used to broadcast + # packets. + self._factory._instances = list() + + #-------------------------------------------------------------------------- + # Interface API + #-------------------------------------------------------------------------- + + async def pending_up_impl(self): + """ + As we have no feedback for when the server is actually started, we mark + the interface up as soon as the create_server method returns. + """ + loop = asyncio.get_event_loop() + # Websocket server + log.info('WebSocket server started') + self._server = await loop.create_server(self._factory, self._address, + self._port) + await self._set_state(InterfaceState.Up) + + async def pending_down_impl(self): + raise NotImplementedError + + def send_impl(self, packet): + """ + Sends a packet to all connected clients (broadcast). + """ + for instance in self._factory._instances: + instance.send(packet) diff --git a/netmodel/model/__init__.py b/netmodel/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netmodel/model/attribute.py b/netmodel/model/attribute.py new file mode 100644 index 00000000..b69ee1bf --- /dev/null +++ b/netmodel/model/attribute.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import copy +import logging +import operator +import types + +from netmodel.model.mapper import ObjectSpecification +from netmodel.model.type import is_type +from netmodel.util.meta import inheritors +from netmodel.util.misc import is_iterable +from vicn.core.sa_collections import InstrumentedList, _list_decorators + +log = logging.getLogger(__name__) +instance_dict = operator.attrgetter('__dict__') + +class NEVER_SET: None + +#------------------------------------------------------------------------------ +# Attribute Multiplicity +#------------------------------------------------------------------------------ + +class Multiplicity: + _1_1 = '1_1' + _1_N = '1_N' + _N_1 = 'N_1' + _N_N = 'N_N' + + + @staticmethod + def reverse(value): + reverse_map = { + Multiplicity._1_1: Multiplicity._1_1, + Multiplicity._1_N: Multiplicity._N_1, + Multiplicity._N_1: Multiplicity._1_N, + Multiplicity._N_N: Multiplicity._N_N, + } + return reverse_map[value] + + +# Default attribute properties values (default to None) +DEFAULT = { + 'multiplicity' : Multiplicity._1_1, + 'mandatory' : False, +} + +#------------------------------------------------------------------------------ +# Attribute +#------------------------------------------------------------------------------ + +class Attribute(abc.ABC, ObjectSpecification): + properties = [ + 'name', + 'type', + 'description', + 'default', + 'choices', + 'mandatory', + 'multiplicity', + 'ro', + 'auto', + 'func', + 'requirements', + 'reverse_name', + 'reverse_description', + 'reverse_auto' + ] + + def __init__(self, *args, **kwargs): + for key in Attribute.properties: + value = kwargs.pop(key, NEVER_SET) + setattr(self, key, value) + + if len(args) == 1: + self.type, = args + elif len(args) == 2: + self.name, self.type = args + assert is_type(self.type) + + self.is_aggregate = False + + self._reverse_attributes = list() + + #-------------------------------------------------------------------------- + # Display + #-------------------------------------------------------------------------- + + def __repr__(self): + return ''.format(self.name) + + __str__ = __repr__ + + #-------------------------------------------------------------------------- + # Descriptor protocol + # + # see. https://docs.python.org/3/howto/descriptor.html + #-------------------------------------------------------------------------- + + def __get__(self, instance, owner=None): + if instance is None: + return self + + value = instance_dict(instance).get(self.name, NEVER_SET) + + # Case : collection attribute + if self.is_collection: + if value is NEVER_SET: + if isinstance(self.default, types.FunctionType): + default = self.default(instance) + else: + default = self.default + value = InstrumentedList(default) + value._attribute = self + value._instance = instance + self.__set__(instance, value) + return value + return value + + # Case : scalar attribute + + if value in (None, NEVER_SET) and self.auto not in (None, NEVER_SET): + # Automatic instanciation + if not self.requirements in (None, NEVER_SET) and \ + self.requirements: + log.warning('Ignored requirement {}'.format(self.requirements)) + value = instance.auto_instanciate(self) + self.__set__(instance, value) + return value + + if value is NEVER_SET: + if isinstance(self.default, types.FunctionType): + value = self.default(instance) + else: + value = copy.deepcopy(self.default) + self.__set__(instance, value) + return value + + return value + + def __set__(self, instance, value): + if instance is None: + return + + if self.is_collection: + if not isinstance(value, InstrumentedList): + value = InstrumentedList(value) + value._attribute = self + value._instance = instance + + instance_dict(instance)[self.name] = value + if hasattr(instance, '_state'): + instance._state.attr_dirty.add(self.name) + instance._state.dirty = True + + def __delete__(self, instance): + raise NotImplementedError + + #-------------------------------------------------------------------------- + # Accessors + #-------------------------------------------------------------------------- + + def __getattribute__(self, name): + value = super().__getattribute__(name) + if value is NEVER_SET: + if name == 'default': + return list() if self.is_collection else None + return DEFAULT.get(name, None) + return value + + # Shortcuts + + def has_reverse_attribute(self): + return self.reverse_name and self.multiplicity + + @property + def is_collection(self): + return self.multiplicity in (Multiplicity._1_N, Multiplicity._N_N) + + def is_set(self, instance): + return self.name in instance_dict(instance) + + #-------------------------------------------------------------------------- + # Operations + #-------------------------------------------------------------------------- + + def merge(self, parent): + for prop in Attribute.properties: + # NOTE: we cannot use getattr otherwise we get the default value, + # and we never override + value = vars(self).get(prop, NEVER_SET) + if value is not NEVER_SET and not is_iterable(value): + continue + + parent_value = vars(parent).get(prop, NEVER_SET) + if parent_value is NEVER_SET: + continue + + if parent_value: + if is_iterable(value): + value.extend(parent_value) + else: + setattr(self, prop, parent_value) + + #-------------------------------------------------------------------------- + # Attribute values + #-------------------------------------------------------------------------- + + def _handle_getitem(self, instance, item): + return item + + def _handle_add(self, instance, item): + instance._state.dirty = True + instance._state.attr_dirty.add(self.name) + print('marking', self.name, 'as dirty') + return item + + def _handle_remove(self, instance, item): + instance._state.dirty = True + instance._state.attr_dirty.add(self.name) + print('marking', self.name, 'as dirty') + + def _handle_before_remove(self, instance): + pass + + #-------------------------------------------------------------------------- + # Attribute values + #-------------------------------------------------------------------------- + +class Relation(Attribute): + properties = Attribute.properties[:] + properties.extend([ + 'reverse_name', + 'reverse_description', + 'multiplicity', + ]) + +class SelfRelation(Relation): + def __init__(self, *args, **kwargs): + if args: + if not len(args) == 1: + raise ValueError('Bad initialized for SelfRelation') + name, = args + super().__init__(name, None, *args, **kwargs) + else: + super().__init__(None, *args, **kwargs) diff --git a/netmodel/model/field_names.py b/netmodel/model/field_names.py new file mode 100644 index 00000000..82881998 --- /dev/null +++ b/netmodel/model/field_names.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +FIELD_SEPARATOR = '.' +DEFAULT_IS_STAR = False + +class FieldNames(list): + """ + A FieldNames instance gather a set of field_names or represents *. + THIS IS NOT a set(Field). + + The distinction between parent and children fields is based on the + FieldNames.FIELD_SEPARATOR character. + """ + + #-------------------------------------------------------------------------- + # Constructor + #-------------------------------------------------------------------------- + + def __init__(self, *args, **kwargs): + """ + Constructor. + """ + star = kwargs.pop('star', DEFAULT_IS_STAR) + list.__init__(self, *args, **kwargs) + size = len(self) + if star and size != 0: + raise ValueError("Inconsistent parameter (star = %s size = %s)" % \ + (star, size)) + # self._star == False and len(self) == 0 occurs when we create + # FieldNames() (to use later |=) and must behaves as FieldNames(star = + # False) + self._star = star + + def __repr__(self): + """ + Returns: + The %r representation of this FieldNames instance. + """ + if self.is_star(): + return "" + else: + return "" % [x for x in self] + + def __hash__(self): + return hash((self._star,) + tuple(self)) + + def is_star(self): + """ + Returns: + True iif this FieldNames instance correspond to "*". + Example : SELECT * FROM foo + """ + try: + return self._star + except: + # This is due to a bug in early versions of Python 2.7 which are + # present on PlanetLab. During copy.deepcopy(), the object is + # reconstructed using append before the state (self.__dict__ is + # reconstructed). Hence the object has no _star when append is + # called and this raises a crazy Exception: + # I could not reproduce in a smaller example + # http://pastie.org/private/5nf15jg0qcvd05pbmnrp8g + return False + + def set_star(self): + """ + Update this FieldNames instance to make it corresponds to "*" + """ + self._star = True + self.clear() + + def unset_star(self, field_names): + """ + Update this FieldNames instance to make it corresponds to a set of + FieldNames + + Args: + field_names: A FieldNames instance or a set of String instances + (field names) + """ + assert len(field_names) > 0 + self._star = False + if field_names: + self |= field_names + + def is_empty(self): + """ + Returns: + True iif FieldNames instance designates contains least one field + name. + """ + return not self.is_star() and len(self) == 0 + + def copy(self): + """ + Returns: + A copy of this FieldNames instance. + """ + return FieldNames(self[:]) + + #-------------------------------------------------------------------------- + # Iterators + #-------------------------------------------------------------------------- + + def iter_field_subfield(self): + for f in self: + field, _, subfield = f.partition(FIELD_SEPARATOR) + yield (field, subfield) + + #-------------------------------------------------------------------------- + # Overloaded set internal functions + #-------------------------------------------------------------------------- + + def __or__(self, fields): + """ + Compute the union of two FieldNames instances. + Args: + fields: a set of String (corresponding to field names) or a + FieldNames instance. + Returns: + The union of the both FieldNames instance. + """ + if self.is_star() or fields.is_star(): + return FieldNames(star = True) + else: + l = self[:] + l.extend([x for x in fields if x not in l]) + return FieldNames(l) + + def __ior__(self, fields): + """ + Compute the union of two FieldNames instances. + Args: + fields: a set of Field instances or a FieldNames instance. + Returns: + The updated FieldNames instance. + """ + if fields.is_star(): + self.set_star() + return self + else: + self.extend([x for x in fields if x not in self]) + return self + + def __and__(self, fields): + """ + Compute the intersection of two FieldNames instances. + Args: + fields: a set of Field instances or a FieldNames instance. + Returns: + The intersection of the both FieldNames instances. + """ + if self.is_star(): + return fields.copy() + elif isinstance(fields, FieldNames) and fields.is_star(): + return self.copy() + else: + return FieldNames([x for x in self if x in fields]) + + def __iand__(self, fields): + """ + Compute the intersection of two FieldNames instances. + Args: + fields: a set of Field instances or a FieldNames instance. + Returns: + The updated FieldNames instance. + """ + if self.is_star(): + self.unset_star(fields) + elif fields.is_star(): + pass + else: + self[:] = [x for x in self if x in fields] + return self + + def __nonzero__(self): + return self.is_star() or bool(list(self)) + + # Python>=3 + __bool__ = __nonzero__ + + __add__ = __or__ + + def __sub__(self, fields): + if fields.is_star(): + return FieldNames(star = False) + else: + if self.is_star(): + # * - x,y,z = ??? + return FieldNames(star = True) + else: + return FieldNames([x for x in self if x not in fields]) + + def __isub__(self, fields): + raise NotImplemented + + def __iadd__(self, fields): + raise NotImplemented + + #-------------------------------------------------------------------------- + # Overloaded set comparison functions + #-------------------------------------------------------------------------- + + def __eq__(self, other): + """ + Test whether this FieldNames instance corresponds to another one. + Args: + other: The FieldNames instance compared to self. + Returns: + True if the both FieldNames instance matches. + """ + return self.is_star() and other.is_star() or set(self) == set(other) + + def __le__(self, other): + """ + Test whether this FieldNames instance in included in + (or equal to) another one. + Args: + other: The FieldNames instance compared to self or + Returns: + True if the both FieldNames instance matches. + """ + assert isinstance(other, FieldNames),\ + "Invalid other = %s (%s)" % (other, type(other)) + + return (self.is_star() and other.is_star())\ + or (not self.is_star() and other.is_star())\ + or (set(self) <= set(other)) # list.__le__(self, other) + + # Defined with respect of previous functions + + def __ne__(self, other): + """ + Test whether this FieldNames instance differs to another one. + Args: + other: The FieldNames instance compared to self. + Returns: + True if the both FieldNames instance differs. + """ + return not self == other + + def __lt__(self, other): + """ + Test whether this FieldNames instance in strictly included in + another one. + Args: + other: The FieldNames instance compared to self. + Returns: + True if self is strictly included in other. + """ + return self <= other and self != other + + def __ge__(self, other): + return other.__le__(self) + + def __gt__(self, other): + return other.__lt__(self) + + #-------------------------------------------------------------------------- + # Overloaded set functions + #-------------------------------------------------------------------------- + + def add(self, field_name): + # DEPRECATED + assert isinstance(field_name, str) + self.append(field_name) + + def set(self, field_names): + assert isinstance(field_names, FieldNames) + if field_names.is_star(): + self.set_star() + return + assert len(field_names) > 0 + self._star = False + self.clear() + self |= field_names + + def append(self, field_name): + if not isinstance(field_name, str): + raise TypeError("Invalid field_name %s (string expected, got %s)" \ + % (field_name, type(field_name))) + + if not self.is_star(): + list.append(self, field_name) + + def clear(self): + self._star = True + del self[:] + + def rename(self, aliases): + """ + Rename all the field names involved in self according to a dict. + Args: + aliases: A {String : String} mapping the old field name and + the new field name. + Returns: + The updated FieldNames instance. + """ + s = self.copy() + for element in s: + if element in aliases: + s.remove(element) + s.add(aliases[element]) + self.clear() + self |= s + return self + + @staticmethod + def join(field, subfield): + return "%s%s%s" % (field, FIELD_SEPARATOR, subfield) + + @staticmethod + def after_path(field, path, allow_shortcuts = True): + """ + Returns the part of the field after path + + Args: + path (list): + allow_shortcuts (bool): Default to True. + """ + if not path: + return (field, None) + last = None + field_parts = field.split(FIELD_SEPARATOR) + for path_element in path[1:]: + if path_element == field_parts[0]: + field_parts.pop(0) + last = None + else: + last = path_element + return (FIELD_SEPARATOR.join(field_parts), last) + + def split_subfields(self, include_parent = True, current_path = None, + allow_shortcuts = True): + """ + Args: + include_parent (bool): is the parent field included in the list of + returned FieldNames (1st part of the tuple). + current_path (list): the path of fields that will be skipped at the + beginning + path_shortcuts (bool): do we allow shortcuts in the path + + Returns: A tuple made of 4 operands: + fields: + map_method_subfields: + map_original_field: + rename: + + Example path = ROOT.A.B + split_subfields(A.B.C.D, A.B.C.D', current_path=[ROOT,A,B]) => + (FieldNames(), { C: [D, D'] }) + split_subfields(A.E.B.C.D, A.E.B.C.D', current_path=[ROOT,A,B]) => + (FieldNames(), { C: [D, D'] }) + """ + field_names = FieldNames() + map_method_subfields = dict() + map_original_field = dict() + rename = dict() + + for original_field in self: + # The current_path can be seen as a set of fields that have to be + # passed through before we can consider a field + field, last = FieldNames.after_path(original_field, current_path, + allow_shortcuts) + + field_name, _, subfield = field.partition(FIELD_SEPARATOR) + + if not subfield: + field_names.add(field_name) + else: + if include_parent: + field_names.add(field_name) + if not field_name in map_method_subfields: + map_method_subfields[field_name] = FieldNames() + map_method_subfields[field_name].add(subfield) + + map_original_field[field_name] = original_field + rename[field_name] = last + + return (field_names, map_method_subfields, map_original_field, rename) diff --git a/netmodel/model/filter.py b/netmodel/model/filter.py new file mode 100644 index 00000000..d0790e3e --- /dev/null +++ b/netmodel/model/filter.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy + +from netmodel.model.field_names import FieldNames +from netmodel.model.predicate import Predicate, eq, included +from netmodel.util.misc import is_iterable + +class Filter(set): + """ + A Filter is a set of Predicate instances + """ + + @staticmethod + def from_list(l): + """ + Create a Filter instance by using an input list. + Args: + l: A list of Predicate instances. + """ + f = Filter() + try: + for element in l: + f.add(Predicate(*element)) + except Exception as e: + #print("Error in setting Filter from list", e) + return None + return f + + @staticmethod + def from_dict(d): + """ + Create a Filter instance by using an input dict. + Args: + d: A dict {key : value} instance where each + key-value pair leads to a Predicate. + 'key' could start with the operator to be + used in the predicate, otherwise we use + '=' by default. + """ + f = Filter() + for key, value in d.items(): + if key[0] in Predicate.operators.keys(): + f.add(Predicate(key[1:], key[0], value)) + else: + f.add(Predicate(key, '=', value)) + return f + + def to_list(self): + """ + Returns: + The list corresponding to this Filter instance. + """ + ret = list() + for predicate in self: + ret.append(predicate.to_list()) + return ret + + @staticmethod + def from_clause(clause): + """ + NOTE: We can only handle simple clauses formed of AND fields. + """ + raise NotImplementedError + + @staticmethod + def from_string(string): + """ + """ + from netmodel.model.sql_parser import SQLParser + p = SQLParser() + ret = p.filter.parseString(string, parseAll=True) + return ret[0] if ret else None + + def filter_by(self, predicate): + """ + Update this Filter by adding a Predicate. + Args: + predicate: A Predicate instance. + Returns: + The resulting Filter instance. + """ + assert isinstance(predicate, Predicate),\ + "Invalid predicate = %s (%s)" % (predicate, type(predicate)) + self.add(predicate) + return self + + def unfilter_by(self, *args): + assert len(args) == 1 or len(args) == 3, \ + "Invalid expression for filter" + + if not self.is_empty(): + if len(args) == 1: + # we got a Filter, or a set, or a list, or a tuple or None. + filters = args[0] + if filters != None: + if not isinstance(filters, (set, list, tuple, Filter)): + filters = [filters] + for predicate in set(filters): + self.remove(predicate) + elif len(args) == 3: + # we got three args: (field_name, op, value) + predicate = Predicate(*args) + self.remove(predicate) + + assert isinstance(self, Filter),\ + "Invalid filters = %s" % (self, type(self)) + return self + + def add(self, predicate_or_filter): + """ + Adds a predicate or a filter (a set of predicate) -- or a list thereof + -- to the current filter. + """ + if is_iterable(predicate_or_filter): + map(self.add, predicate_or_filter) + return + + assert isinstance(predicate_or_filter, Predicate) + set.add(self, predicate_or_filter) + + def is_empty(self): + """ + Tests whether this Filter is empty or not. + Returns: + True iif this Filter is empty. + """ + return len(self) == 0 + + def __str__(self): + """ + Returns: + The '%s' representation of this Filter. + """ + if self.is_empty(): + return "" + else: + return " AND ".join([str(pred) for pred in self]) + + def __repr__(self): + """ + Returns: + The '%r' representation of this Filter. + """ + return '' % self + + def __key(self): + return tuple([hash(pred) for pred in self]) + + def __hash__(self): + return hash(self.__key()) + + def __additem__(self, value): + if not isinstance(value, Predicate): + raise TypeError("Element of class Predicate expected, received %s"\ + % value.__class__.__name__) + set.__additem__(self, value) + + + def copy(self): + return copy.deepcopy(self) + + def keys(self): + """ + Returns: + A set of String corresponding to each field name + involved in this Filter. + """ + return set([x.key for x in self]) + + def has(self, key): + for x in self: + if x.key == key: + return True + return False + + def has_op(self, key, op): + for x in self: + if x.key == key and x.op == op: + return True + return False + + def has_eq(self, key): + return self.has_op(key, eq) + + def get(self, key): + ret = [] + for x in self: + if x.key == key: + ret.append(x) + return ret + + def delete(self, key): + to_del = [] + for x in self: + if x.key == key: + to_del.append(x) + for x in to_del: + self.remove(x) + + def get_op(self, key, op): + if isinstance(op, (list, tuple, set)): + for x in self: + if x.key == key and x.op in op: + return x.value + else: + for x in self: + if x.key == key and x.op == op: + return x.value + return None + + def get_eq(self, key): + return self.get_op(key, eq) + + def set_op(self, key, op, value): + for x in self: + if x.key == key and x.op == op: + x.value = value + return + raise KeyError(key) + + def set_eq(self, key, value): + return self.set_op(key, eq, value) + + def get_predicates(self, key): + ret = [] + for x in self: + if x.key == key: + ret.append(x) + return ret + + def match(self, dic, ignore_missing=True): + for predicate in self: + if not predicate.match(dic, ignore_missing): + return False + return True + + def filter(self, l): + output = [] + for x in l: + if self.match(x): + output.append(x) + return output + + def get_field_names(self): + field_names = FieldNames() + for predicate in self: + field_names |= predicate.get_field_names() + return field_names + + def grep(self, fun): + return Filter([x for x in self if fun(x)]) + + def rgrep(self, fun): + return Filter([x for x in self if not fun(x)]) + + def split(self, fun, true_only = False): + true_filter, false_filter = Filter(), Filter() + for predicate in self: + if fun(predicate): + true_filter.add(predicate) + else: + false_filter.add(predicate) + if true_only: + return true_filter + else: + return (true_filter, false_filter) + + + def split_fields(self, fields, true_only = False): + return self.split(lambda predicate: predicate.get_key() in fields, + true_only) + + def provides_key_field(self, key_fields): + # No support for tuples + for field in key_fields: + if not self.has_op(field, eq) and not self.has_op(field, included): + # Missing key fields in query filters + return False + return True + + def rename(self, aliases): + for predicate in self: + predicate.rename(aliases) + return self + + def get_field_values(self, field): + """ + This function returns the values that are determined by the filters for + a given field, or None is the filter is not *setting* determined values. + + Returns: list : a list of fields + """ + value_list = list() + for predicate in self: + key, op, value = predicate.get_tuple() + + if key == field: + extract_tuple = False + elif key == (field, ): + extract_tuple = True + else: + continue + + if op == eq: + if extract_tuple: + value = value[0] + value_list.append(value) + elif op == included: + if extract_tuple: + value = [x[0] for x in value] + value_list.extend(value) + else: + continue + + return list(set(value_list)) + + def update_field_value_eq(self, field, value): + for predicate in self: + p_field, p_op, p_value = predicate.get_tuple() + if p_field == field: + predicate.set_op(eq) + predicate.set_value(value) + break # assuming there is a single predicate with field/op + + def __and__(self, other): + # Note: we assume the predicates in self and other are already in + # minimal form, eg. not the same fields twice... We could break after + # a predicate with the same key is found btw... + s = self.copy() + for o_predicate in other: + o_key, o_op, o_value = o_predicate.get_tuple() + + key_found = False + for predicate in s: + key, op, value = predicate.get_tuple() + if key != o_key: + continue + + # We already have a predicate with the same key + key_found = True + + if op == eq: + if o_op == eq: + # Similar filters... + if value != o_value: + # ... with different values + return None + else: + # ... with same values + pass + elif o_op == included: + # Inclusion + if value not in o_value: + # no overlap + return None + else: + # We already have the more restrictive predicate... + pass + + elif op == included: + if o_op == eq: + if o_value not in value: + return None + else: + # One value overlaps... update the initial predicate + # with the more restrictive one + predicate.set_op(eq) + predicate.set_value(value) + elif o_op == included: + intersection = set(o_value) & set(value) + if not set(o_value) & set(value): + return None + else: + predicate.set_value(tuple(intersection)) + + # No conflict found, we can add the predicate to s + if not key_found: + s.add(o_predicate) + + return s diff --git a/netmodel/model/mapper.py b/netmodel/model/mapper.py new file mode 100644 index 00000000..9be46a14 --- /dev/null +++ b/netmodel/model/mapper.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +class ObjectSpecification: + pass diff --git a/netmodel/model/object.py b/netmodel/model/object.py new file mode 100644 index 00000000..32d3a833 --- /dev/null +++ b/netmodel/model/object.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta + +from netmodel.model.attribute import Attribute +from netmodel.model.type import BaseType +from netmodel.model.mapper import ObjectSpecification + +# Warning and error messages + +E_UNK_RES_NAME = 'Unknown resource name for attribute {} in {} ({}) : {}' + +class ObjectMetaclass(ABCMeta): + """ + Object metaclass allowing non-uniform attribute declaration. + """ + + def __init__(cls, class_name, parents, attrs): + """ + Args: + cls: The class type we're registering. + class_name: A String containing the class_name. + parents: The parent class types of 'cls'. + attrs: The attribute (members) of 'cls'. + """ + super().__init__(class_name, parents, attrs) + cls._sanitize() + +class Object(BaseType, metaclass = ObjectMetaclass): + + def __init__(self, **kwargs): + """ + Object constructor. + + Args: + kwargs: named arguments consisting in object attributes to be + initialized at construction. + """ + mandatory = { a.name for a in self.iter_attributes() if a.mandatory } + + for key, value in kwargs.items(): + attribute = self.get_attribute(key) + if issubclass(attribute.type, Object): + if attribute.is_collection: + new_value = list() + for x in value: + if isinstance(x, str): + resource = self._state.manager.by_name(x) + elif isinstance(x, UUID): + resource = self._state.manager.by_uuid(x) + else: + resource = x + if not resource: + raise LurchException(E_UNK_RES_NAME.format(key, + self.name, self.__class__.__name__, x)) + new_value.append(resource._state.uuid) + value = new_value + else: + if isinstance(value, str): + resource = self._state.manager.by_name(value) + elif isinstance(value, UUID): + resource = self._state.manager.by_uuid(value) + else: + resource = value + if not resource: + raise LurchException(E_UNK_RES_NAME.format(key, + self.name, self.__class__.__name__, value)) + value = resource._state.uuid + setattr(self, key, value) + mandatory -= { key } + + # Check that all mandatory atttributes have been set + # Mandatory resource attributes will be marked as pending since they + # might be discovered + # Eventually, their absence will be discovered at runtime + if mandatory: + raise Exception('Mandatory attributes not set: %r' % (mandatory,)) + + # Assign backreferences (we need attribute to be initialized, so it has + # to be done at the end of __init__ + for other_instance, attribute in self.iter_backrefs(): + if attribute.is_collection: + collection = getattr(other_instance, attribute.name) + collection.append(self) + else: + setattr(other_instance, attribute.name, self) + + #-------------------------------------------------------------------------- + # Object model + #-------------------------------------------------------------------------- + + @classmethod + def get_attribute(cls, key): + return getattr(cls, key) + + @classmethod + def _sanitize(cls): + """Sanitize the object model to accomodate for multiple declaration + styles + + In particular, this method: + - set names to all attributes + """ + cls._reverse_attributes = dict() + cur_reverse_attributes = dict() + for name, obj in vars(cls).items(): + if not isinstance(obj, ObjectSpecification): + continue + if isinstance(obj, Attribute): + obj.name = name + + # Remember whether a reverse_name is defined before loading + # inherited properties from parent + has_reverse = bool(obj.reverse_name) + + # Handle overloaded attributes + # By recursion, it is sufficient to look into the parent + for base in cls.__bases__: + if hasattr(base, name): + parent_attribute = getattr(base, name) + obj.merge(parent_attribute) + assert obj.type + + # Handle reverse attribute + # + # NOTE: we need to do this after merging to be sure we get all + # properties inherited from parent (eg. multiplicity) + if has_reverse: + a = { + 'name' : obj.reverse_name, + 'description' : obj.reverse_description, + 'multiplicity' : Multiplicity.reverse(obj.multiplicity), + 'auto' : obj.reverse_auto, + } + reverse_attribute = Attribute(cls, **a) + reverse_attribute.is_aggregate = True + + cur_reverse_attributes[obj.type] = reverse_attribute + + if not obj in cls._reverse_attributes: + cls._reverse_attributes[obj] = list() + cls._reverse_attributes[obj].append(reverse_attribute) + + for cls, a in cur_reverse_attributes.items(): + setattr(cls, a.name, a) + + @classmethod + def iter_attributes(cls, aggregates = False): + for name in dir(cls): + attribute = getattr(cls, name) + if not isinstance(attribute, Attribute): + continue + if attribute.is_aggregate and not aggregates: + continue + + yield attribute + + def get_attributes(self, aggregates = False): + return list(self.iter_attributes(aggregates = aggregates)) + + def get_attribute_names(self, aggregates = False): + return set(a.name for a in self.iter_attributes(aggregates = \ + aggregates)) + + def get_attribute_dict(self, field_names = None, aggregates = False, + uuid = True): + assert not field_names or field_names.is_star() + attributes = self.get_attributes(aggregates = aggregates) + + ret = dict() + for a in attributes: + if not a.is_set(self): + continue + value = getattr(self, a.name) + if a.is_collection: + ret[a.name] = list() + for x in value: + if uuid and isinstance(x, Object): + x = x._state.uuid._uuid + ret[a.name].append(x) + else: + if uuid and isinstance(value, Object): + value = value._state.uuid._uuid + ret[a.name] = value + return ret + + def get_tuple(self): + return (self.__class__, self._get_attribute_dict()) + + def format(self, fmt): + return fmt.format(**self.get_attribute_dict(uuid = False)) + + def iter_backrefs(self): + for base in self.__class__.mro(): + if not hasattr(base, '_reverse_attributes'): + continue + for attr, rattrs in base._reverse_attributes.items(): + instances = getattr(self, attr.name) + if not attr.is_collection: + instances = [instances] + for instance in instances: + # - instance = node + if instance in (None, NEVER_SET): + continue + for rattr in rattrs: + yield instance, rattr + + #-------------------------------------------------------------------------- + # Accessors + #-------------------------------------------------------------------------- + + @classmethod + def has_attribute(cls, name): + return name in [a.name for a in cls.attributes()] + diff --git a/netmodel/model/predicate.py b/netmodel/model/predicate.py new file mode 100644 index 00000000..08ed956a --- /dev/null +++ b/netmodel/model/predicate.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy + +from netmodel.model.field_names import FieldNames, FIELD_SEPARATOR + +from operator import ( + and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq, neg +) + +# Define the inclusion operators +class contains(type): pass +class included(type): pass + +class Predicate: + + operators = { + '==' : eq, + '!=' : ne, + '<' : lt, + '<=' : le, + '>' : gt, + '>=' : ge, + 'CONTAINS' : contains, + 'INCLUDED' : included + } + + operators_short = { + '=' : eq, + '~' : ne, + '<' : lt, + '[' : le, + '>' : gt, + ']' : ge, + '}' : contains, + '{' : included + } + + def __init__(self, *args, **kwargs): + """ + Build a Predicate instance. + Args: + kwargs: You can pass: + - 3 args (left, operator, right) + left: The left operand (it may be a String instance or a + tuple) + operator: See Predicate.operators, this is the binary + operator involved in this Predicate. + right: The right value (it may be a String instance + or a literal (String, numerical value, tuple...)) + - 1 argument (list or tuple), containing three arguments + (variable, operator, value) + """ + if len(args) == 3: + key, op, value = args + elif len(args) == 1 and isinstance(args[0], (tuple, list)) and \ + len(args[0]) == 3: + key, op, value = args[0] + elif len(args) == 1 and isinstance(args[0], Predicate): + key, op, value = args[0].get_tuple() + else: + raise Exception("Bad initializer for Predicate (args = %r)" % args) + + assert not isinstance(value, (frozenset, dict, set)), \ + "Invalid value type (type = %r)" % type(value) + if isinstance(value, list): + value = tuple(value) + + self.key = key + if isinstance(op, str): + op = op.upper() + if op in self.operators.keys(): + self.op = self.operators[op] + elif op in self.operators_short.keys(): + self.op = self.operators_short[op] + else: + self.op = op + + if isinstance(value, list): + self.value = tuple(value) + else: + self.value = value + + def __str__(self): + """ + Returns: + The '%s' representation of this Predicate. + """ + return repr(self) + + def __repr__(self): + """ + Returns: + The '%r' representation of this Predicate. + """ + key, op, value = self.get_str_tuple() + if isinstance(value, (tuple, list, set, frozenset)): + value = [repr(v) for v in value] + value = "(%s)" % ", ".join(value) + return "%s %s %r" % (key, op, value) + + def __hash__(self): + """ + Returns: + The hash of this Predicate (this allows to define set of + Predicate instances). + """ + return hash(self.get_tuple()) + + def __eq__(self, predicate): + """ + Returns: + True iif self == predicate. + """ + if not predicate: + return False + return self.get_tuple() == predicate.get_tuple() + + def copy(self): + return copy.deepcopy(self) + + def get_key(self): + """ + Returns: + The left operand of this Predicate. It may be a String + or a tuple of Strings. + """ + return self.key + + def set_key(self, key): + """ + Set the left operand of this Predicate. + Params: + key: The new left operand. + """ + self.key = key + + def update_key(self, function): + self.set_key(function(self.get_key())) + + def get_op(self): + return self.op + + def set_op(self, op): + self.op = op + + def get_value(self): + return self.value + + def set_value(self, value): + self.value = value + + def get_tuple(self): + return (self.key, self.op, self.value) + + def get_tuple_ext(self): + key, op, value = self.get_tuple() + key_field, _, key_subfield = key.partition(FIELD_SEPARATOR) + return (key_field, key_subfield, op, value) + + def get_str_op(self): + op_str = [s for s, op in self.operators.items() if op == self.op] + return op_str[0] + + def get_str_tuple(self): + return (self.key, self.get_str_op(), self.value,) + + def to_list(self): + return list(self.get_str_tuple()) + + def match(self, dic, ignore_missing=False): + # Can we match ? + if self.key not in dic: + return ignore_missing + + if self.op == eq: + if isinstance(self.value, list): + return (dic[self.key] in self.value) + else: + return (dic[self.key] == self.value) + elif self.op == ne: + if isinstance(self.value, list): + return (dic[self.key] not in self.value) + else: + return (dic[self.key] != self.value) + elif self.op == lt: + if isinstance(self.value, str): + # prefix match + return dic[self.key].startswith('%s.' % self.value) + else: + return (dic[self.key] < self.value) + elif self.op == le: + if isinstance(self.value, str): + return dic[self.key] == self.value or \ + dic[self.key].startswith('%s.' % self.value) + else: + return (dic[self.key] <= self.value) + elif self.op == gt: + if isinstance(self.value, str): + # prefix match + return self.value.startswith('%s.' % dic[self.key]) + else: + return (dic[self.key] > self.value) + elif self.op == ge: + if isinstance(self.value, str): + # prefix match + return dic[self.key] == self.value or \ + self.value.startswith('%s.' % dic[self.key]) + else: + return (dic[self.key] >= self.value) + elif self.op == and_: + return (dic[self.key] & self.value) + elif self.op == or_: + return (dic[self.key] | self.value) + elif self.op == contains: + try: + method, subfield = self.key.split('.', 1) + return not not [ x for x in dic[method] \ + if x[subfield] == self.value] + except ValueError: # split has failed + return self.value in dic[self.key] + elif self.op == included: + return dic[self.key] in self.value + else: + raise Exception("Unexpected table format: %r" % dic) + + def filter(self, dic): + """ + Filter dic according to the current predicate. + """ + + if '.' in self.key: + # users.hrn + method, subfield = self.key.split('.', 1) + if not method in dic: + return None + + if isinstance(dic[method], dict): + subpred = Predicate(subfield, self.op, self.value) + match = subpred.match(dic[method]) + return dic if match else None + + elif isinstance(dic[method], (list, tuple)): + # 1..N relationships + match = False + if self.op == contains: + return dic if self.match(dic) else None + else: + subpred = Predicate(subfield, self.op, self.value) + dic[method] = subpred.filter(dic[method]) + return dic + else: + raise Exception("Unexpected table format: %r", dic) + + + else: + # Individual field operations + return dic if self.match(dic) else None + + def get_field_names(self): + if isinstance(self.key, (list, tuple, set, frozenset)): + return FieldNames(self.key) + else: + return FieldNames([self.key]) + + def get_value_names(self): + if isinstance(self.value, (list, tuple, set, frozenset)): + return FieldNames(self.value) + else: + return FieldNames([self.value]) + + def has_empty_value(self): + if isinstance(self.value, (list, tuple, set, frozenset)): + return not any(self.value) + else: + return not self.value + + def is_composite(self): + """ + Returns: + True iif this Predicate instance involves + a tuple key (and tuple value). + """ + return isinstance(self.get_key(), tuple) + + def rename(self, aliases): + if self.is_composite(): + raise NotImplemented + if self.key in aliases: + self.key = aliases[self.key] diff --git a/netmodel/model/query.py b/netmodel/model/query.py new file mode 100644 index 00000000..c182cb45 --- /dev/null +++ b/netmodel/model/query.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from netmodel.model.filter import Filter +from netmodel.model.field_names import FieldNames + +ACTION_INSERT = 1 +ACTION_SELECT = 2 +ACTION_UPDATE = 3 +ACTION_DELETE = 4 +ACTION_EXECUTE = 5 +ACTION_SUBSCRIBE = 6 +ACTION_UNSUBSCRIBE = 7 + +ACTION2STR = { + ACTION_INSERT : 'insert', + ACTION_SELECT : 'select', + ACTION_UPDATE : 'update', + ACTION_DELETE : 'delete', + ACTION_EXECUTE : 'execute', + ACTION_SUBSCRIBE : 'subscribe', + ACTION_UNSUBSCRIBE : 'unsubscribe', +} +STR2ACTION = dict((v, k) for k, v in ACTION2STR.items()) + +FUNCTION_SUM = 1 + +FUNCTION2STR = { + FUNCTION_SUM : 'sum' +} +STR2FUNCTION = dict((v, k) for k, v in FUNCTION2STR.items()) + +class Query: + def __init__(self, action, object_name, filter = None, params = None, + field_names = None, aggregate = None, last = False, reply = False): + self.action = action + self.object_name = object_name + + if filter: + if isinstance(filter, Filter): + self.filter = filter + else: + self.filter = Filter.from_list(filter) + else: + self.filter = Filter() + + self.params = params + + if field_names: + if isinstance(field_names, FieldNames): + self.field_names = field_names + else: + self.field_names = FieldNames(field_names) + else: + self.field_names = FieldNames() + + self.aggregate = aggregate + + self.last = last + self.reply = reply + + def to_dict(self): + aggregate = FUNCTION2STR[self.aggregate] if self.aggregate else None + return { + 'action': ACTION2STR[self.action], + 'object_name': self.object_name, + 'filter': self.filter.to_list(), + 'params': self.params, + 'field_names': self.field_names, + 'aggregate': aggregate, + 'reply': self.reply, + 'last': self.last + } + + @staticmethod + def from_dict(dic): + action = STR2ACTION[dic.get('action').lower()] + object_name = dic.get('object_name') + filter = dic.get('filter', None) + params = dic.get('params', None) + field_names = dic.get('field_names', None) + aggregate = STR2FUNCTION[dic.get('aggregate').lower()] \ + if dic.get('aggregate') else None + if field_names == '*': + field_names = FieldNames(star = True) + last = dic.get('last', False) + reply = dic.get('reply', False) + return Query(action, object_name, filter, params, field_names, + aggregate, last) + + def to_sql(self, multiline = False): + """ + Args: + platform: A String corresponding to a namespace (or platform name) + multiline: A boolean indicating whether the String could contain + carriage return. + Returns: + The String representing this Query. + """ + get_params_str = lambda : ", ".join(["%s = %r" % (k, v) \ + for k, v in self.params.items()]) + + object_name = self.object_name + field_names = self.field_names + field_names_str = ('*' if field_names.is_star() \ + else ', '.join([field for field in field_names])) + select = "SELECT %s" % ((FUNCTION2STR[self.aggregate] + "(%s)") \ + if self.aggregate else '%s') % field_names_str + filter = "WHERE %s" % self.filter if self.filter else '' + #at = "AT %s" % self.get_timestamp() if self.get_timestamp() else "" + at = '' + params = "SET %s" % get_params_str() if self.params else '' + + sep = " " if not multiline else "\n " + + strmap = { + ACTION_SELECT : "%(select)s%(sep)s%(at)s%(sep)sFROM %(object_name)s%(sep)s%(filter)s", + ACTION_UPDATE : "UPDATE %(object_name)s%(sep)s%(params)s%(sep)s%(filter)s%(sep)s%(select)s", + ACTION_INSERT : "INSERT INTO %(object_name)s%(sep)s%(params)s", + ACTION_DELETE : "DELETE FROM %(object_name)s%(sep)s%(filter)s", + ACTION_SUBSCRIBE : "SUBSCRIBE : %(select)s%(sep)s%(at)s%(sep)sFROM %(object_name)s%(sep)s%(filter)s", + ACTION_UNSUBSCRIBE : "UNSUBSCRIBE : %(select)s%(sep)s%(at)s%(sep)sFROM %(object_name)s%(sep)s%(filter)s", + ACTION_EXECUTE : "EXECUTE : %(select)s%(sep)s%(at)s%(sep)sFROM %(object_name)s%(sep)s%(filter)s", + } + + return strmap[self.action] % locals() + + def __str__(self): + return self.to_sql() + + def __repr__(self): + return self.to_sql() diff --git a/netmodel/model/result_value.py b/netmodel/model/result_value.py new file mode 100644 index 00000000..1812d5c4 --- /dev/null +++ b/netmodel/model/result_value.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pprint +import time +import traceback + +from netmodel.network.packet import ErrorPacket +from netmodel.model.query import Query as Record + +# type +SUCCESS = 0 +WARNING = 1 +ERROR = 2 + +# origin +CORE = 0 +GATEWAY = 1 + +class ResultValue(dict): + + ALLOWED_FIELDS = set(["origin", "type", "code", "value", "description", + "traceback", "ts"]) + + def __init__(self, *args, **kwargs): + if args: + if kwargs: + raise Exception("Bad initialization for ResultValue") + + if len(args) == 1 and isinstance(args[0], dict): + kwargs = args[0] + + given = set(kwargs.keys()) + cstr_success = set(["code", "origin", "value"]) <= given + cstr_error = set(["code", "type", "origin", "description"]) <= given + assert given <= self.ALLOWED_FIELDS, \ + "Wrong fields in ResultValue constructor: %r" % \ + (given - self.ALLOWED_FIELDS) + assert cstr_success or cstr_error, \ + "Incomplete set of fields in ResultValue constructor: %r" % given + + dict.__init__(self, **kwargs) + + # Set missing fields to None + for field in self.ALLOWED_FIELDS - given: + self[field] = None + if not "ts" in self: + self["ts"] = time.time() + + def get_code(self): + """ + Returns: + The code transported in this ResultValue instance/ + """ + return self["code"] + + @classmethod + def get(self, records, errors): + num_errors = len(errors) + + if num_errors == 0: + return ResultValue.success(records) + elif records: + return ResultValue.warning(records, errors) + else: + return ResultValue.errors(errors) + + @classmethod + def success(self, result): + return ResultValue( + code = SUCCESS, + type = SUCCESS, + origin = [CORE, 0], + value = result + ) + + @staticmethod + def warning(result, errors): + return ResultValue( + code = ERROR, + type = WARNING, + origin = [CORE, 0], + value = result, + description = errors + ) + + @staticmethod + def error(description, code = ERROR): + assert isinstance(description, str),\ + "Invalid description = %s (%s)" % (description, type(description)) + assert isinstance(code, int),\ + "Invalid code = %s (%s)" % (code, type(code)) + + return ResultValue( + type = ERROR, + code = code, + origin = [CORE, 0], + description = [ErrorPacket(type = ERROR, code = code, + message = description, traceback = None)] + ) + + @staticmethod + def errors(errors): + """ + Make a ResultValue corresponding to an error and + gathering a set of ErrorPacket instances. + Args: + errors: A list of ErrorPacket instances. + Returns: + The corresponding ResultValue instance. + """ + assert isinstance(errors, list),\ + "Invalid errors = %s (%s)" % (errors, type(errors)) + + return ResultValue( + type = ERROR, + code = ERROR, + origin = [CORE, 0], + description = errors + ) + + def is_warning(self): + return self["type"] == WARNING + + def is_success(self): + return self["type"] == SUCCESS and self["code"] == SUCCESS + + def get_all(self): + """ + Retrieve the Records embedded in this ResultValue. + Raises: + RuntimeError: in case of failure. + Returns: + A Records instance. + """ + if not self.is_success() and not self.is_warning(): + raise RuntimeError("Error executing query: %s" % \ + (self["description"])) + try: + records = self["value"] + if len(records) > 0 and not isinstance(records[0], Record): + raise TypeError("Please put Record instances in ResultValue") + return records + except AttributeError as e: + raise RuntimeError(e) + + def get_one(self): + """ + Retrieve the only Record embeded in this ResultValue. + Raises: + RuntimeError: if there is 0 or more that 1 Record in + this ResultValue. + Returns: + A list of Records (and not of dict). + """ + records = self.get_all() + num_records = len(records) + if num_records != 1: + raise RuntimeError('Cannot call get_one() with multiple records') + return records.get_one() + + def get_error_message(self): + return "%r" % self["description"] + + @staticmethod + def to_html(raw_dict): + return pprint.pformat(raw_dict).replace("\\n","
") + + def to_dict(self): + return dict(self) diff --git a/netmodel/model/sql_parser.py b/netmodel/model/sql_parser.py new file mode 100644 index 00000000..862c0a54 --- /dev/null +++ b/netmodel/model/sql_parser.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re +import sys +import pyparsing as pp + +from netmodel.model.query import Query +from netmodel.model.filter import Filter +from netmodel.model.predicate import Predicate + +DEBUG = False + +def debug(args): + if DEBUG: print(args) + +class SQLParser(object): + + def __init__(self): + """ + Our simple BNF: + SELECT [fields[*] FROM table WHERE clause + """ + + integer = pp.Combine(pp.Optional(pp.oneOf("+ -")) + + pp.Word(pp.nums)).setParseAction(lambda t:int(t[0])) + floatNumber = pp.Regex(r'\d+(\.\d*)?([eE]\d+)?') + point = pp.Literal(".") + e = pp.CaselessLiteral("E") + + kw_store = pp.CaselessKeyword('=') + kw_select = pp.CaselessKeyword('select') + kw_subscribe = pp.CaselessKeyword('subscribe') + kw_update = pp.CaselessKeyword('update') + kw_insert = pp.CaselessKeyword('insert') + kw_delete = pp.CaselessKeyword('delete') + kw_execute = pp.CaselessKeyword('execute') + + kw_from = pp.CaselessKeyword('from') + kw_into = pp.CaselessKeyword('into') + kw_where = pp.CaselessKeyword('where') + kw_at = pp.CaselessKeyword('at') + kw_set = pp.CaselessKeyword('set') + kw_true = pp.CaselessKeyword('true').setParseAction(lambda t: 1) + kw_false = pp.CaselessKeyword('false').setParseAction(lambda t: 0) + kw_with = pp.CaselessKeyword('with') + + sum_function = pp.CaselessLiteral('sum') + + # Regex string representing the set of possible operators + # Example : ">=|<=|!=|>|<|=" + OPERATOR_RX = "(?i)%s" % '|'.join([re.sub('\|', '\|', o) \ + for o in Predicate.operators.keys()]) + + # predicate + field = pp.Word(pp.alphanums + '_' + '.') + operator = pp.Regex(OPERATOR_RX).setName("operator") + variable = pp.Literal('$').suppress() + pp.Word(pp.alphanums \ + + '_' + '.').setParseAction(lambda t: "$%s" % t[0]) + filename = pp.Regex('([a-z]+?://)?(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+') + + obj = pp.Forward() + value = obj | pp.QuotedString('"') | pp.QuotedString("'") | \ + kw_true | kw_false | integer | variable + + def handle_value_list(s, l, t): + t = t.asList() + new_t = tuple(t) + debug("[handle_value_list] s = %(s)s ** l = %(l)s ** t = %(t)s" \ + % locals()) + debug(" new_t = %(new_t)s" % locals()) + return new_t + + value_list = value \ + | (pp.Literal("[").suppress() + pp.Literal("]").suppress()) \ + .setParseAction(lambda s, l, t: [[]]) \ + | pp.Literal("[").suppress() \ + + pp.delimitedList(value).setParseAction(handle_value_list) \ + + pp.Literal("]") \ + .suppress() + + table = pp.Word(pp.alphanums + ':_-/').setResultsName('object_name') + field_list = pp.Literal("*") | pp.delimitedList(field).setParseAction(lambda tokens: set(tokens)) + + assoc = (field + pp.Literal(":").suppress() + value_list).setParseAction(lambda tokens: [tokens.asList()]) + obj << pp.Literal("{").suppress() \ + + pp.delimitedList(assoc).setParseAction(lambda t: dict(t.asList())) \ + + pp.Literal("}").suppress() + + # PARAMETER (SET) + # X = Y --> t=(X, Y) + def handle_param(s, l, t): + t = t.asList() + assert len(t) == 2 + new_t = tuple(t) + debug("[handle_param] s = %(s)s ** l = %(l)s ** t = %(t)s" % locals()) + debug(" new_t = %(new_t)s" % locals()) + return new_t + + param = (field + pp.Literal("=").suppress() + value_list) \ + .setParseAction(handle_param) + + # PARAMETERS (SET) + # PARAMETER[, PARAMETER[, ...]] --> dict() + def handle_parameters(s, l, t): + t = t.asList() + new_t = dict(t) if t else dict() + debug("[handle_parameters] s = %(s)s ** l = %(l)s ** t = %(t)s" % locals()) + debug(" new_t = %(new_t)s" % locals()) + return new_t + + parameters = pp.delimitedList(param) \ + .setParseAction(handle_parameters) + + predicate = (field + operator + value_list).setParseAction(self.handlePredicate) + + # For the time being, we only support simple filters and not full clauses + filter = pp.delimitedList(predicate, delim='&&').setParseAction(lambda tokens: Filter(tokens.asList())) + + datetime = pp.Regex(r'....-..-.. ..:..:..') + + timestamp = pp.CaselessKeyword('now') | datetime + + store_elt = (variable.setResultsName("variable") + kw_store.suppress()) + fields_elt = field_list.setResultsName('field_names') + aggregate_elt = sum_function.setResultsName('aggregate') + pp.Literal("(").suppress() + fields_elt + pp.Literal(")").suppress() + select_elt = (kw_select.suppress() + fields_elt) + subscribe_elt = (kw_subscribe.suppress() + fields_elt) + where_elt = (kw_where.suppress() + filter.setResultsName('filter')) + set_elt = (kw_set.suppress() + parameters.setResultsName('params')) + at_elt = (kw_at.suppress() + timestamp.setResultsName('timestamp')) + into_elt = (kw_into.suppress() + filename.setResultsName('receiver')) + + # SELECT [SUM(]*|field_list[)] [AT timestamp] FROM table [WHERE clause] + select = ( + pp.Optional(store_elt)\ + + kw_select.suppress() \ + + pp.Optional(into_elt) \ + + (aggregate_elt | fields_elt)\ + + pp.Optional(at_elt)\ + + kw_from.suppress()\ + + table\ + + pp.Optional(where_elt) + ).setParseAction(lambda args: self.action(args, 'select')) + + subscribe = ( + pp.Optional(store_elt) \ + + kw_subscribe.suppress() \ + + pp.Optional(into_elt) \ + + (aggregate_elt | fields_elt) \ + + pp.Optional(at_elt)\ + + kw_from.suppress()\ + + table\ + + pp.Optional(where_elt) + + pp.Optional(set_elt) + ).setParseAction(lambda args: self.action(args, 'subscribe')) + + # UPDATE table SET parameters [WHERE clause] [SELECT *|field_list] + update = ( + kw_update \ + + table \ + + set_elt \ + + pp.Optional(where_elt) \ + + pp.Optional(select_elt) + ).setParseAction(lambda args: self.action(args, 'update')) + + # INSERT INTO table SET parameters [SELECT *|field_list] + insert = ( + kw_insert + kw_into + table + + set_elt + + pp.Optional(select_elt) + ).setParseAction(lambda args: self.action(args, 'insert')) + + # DELETE FROM table [WHERE clause] + delete = ( + kw_delete \ + + kw_from \ + + table \ + + pp.Optional(where_elt) + ).setParseAction(lambda args: self.action(args, 'delete')) + + # + execute = ( + kw_execute + kw_from + table + + set_elt + + pp.Optional(where_elt) + ).setParseAction(lambda args: self.action(args, 'execute')) + + annotation = pp.Optional(kw_with + \ + parameters.setResultsName('annotation')) + + self.bnf = (select | update | insert | delete | subscribe | execute) \ + + annotation + + # For reusing parser: + self.filter = filter + + def action(self, args, action): + args['action'] = action + + def handlePredicate(self, args): + return Predicate(*args) + + def parse(self, string): + result = self.bnf.parseString(string, parseAll = True) + return dict(result.items()) diff --git a/netmodel/model/type.py b/netmodel/model/type.py new file mode 100644 index 00000000..20dc2580 --- /dev/null +++ b/netmodel/model/type.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from netmodel.util.meta import inheritors + +class BaseType: + @staticmethod + def name(): + return self.__class__.__name__.lower() + +class String(BaseType): + def __init__(self, *args, **kwargs): + self._min_size = kwargs.pop('min_size', None) + self._max_size = kwargs.pop('max_size', None) + self._ascii = kwargs.pop('ascii', False) + self._forbidden = kwargs.pop('forbidden', None) + super().__init__() + +class Integer(BaseType): + def __init__(self, *args, **kwargs): + self._min_value = kwargs.pop('min_value', None) + self._max_value = kwargs.pop('max_value', None) + super().__init__() + +class Double(BaseType): + def __init__(self, *args, **kwargs): + self._min_value = kwargs.pop('min_value', None) + self._max_value = kwargs.pop('max_value', None) + super().__init__() + +class Bool(BaseType): + pass + +class Dict(BaseType): + pass + +class Self(BaseType): + """Self-reference + """ + +class Type: + BASE_TYPES = (String, Integer, Double, Bool) + _registry = dict() + + @staticmethod + def from_string(type_name, raise_exception=True): + """Returns a type corresponding to the type name. + + Params: + type_name (str) : Name of the type + + Returns + Type : Type class of the requested type name + """ + type_cls = [t for t in Type.BASE_TYPES if t.name == type_name] + if type_cls: + return type_cls[0] + + type_cls = Type._registry.get(type_name, None) + if not type_cls: + raise Exception("No type found: {}".format(type_name)) + return type_cls + + @staticmethod + def is_base_type(type_cls): + return type_cls in Type.BASE_TYPES + + @staticmethod + def exists(typ): + return (isinstance(typ, type) and typ in inheritors(BaseType)) \ + or isinstance(typ, BaseType) + +is_base_type = Type.is_base_type +is_type = Type.exists diff --git a/netmodel/network/__init__.py b/netmodel/network/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netmodel/network/fib.py b/netmodel/network/fib.py new file mode 100644 index 00000000..e6b81607 --- /dev/null +++ b/netmodel/network/fib.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +class FIBEntry: + def __init__(self, prefix, next_hops = None): + if next_hops is None: + next_hops = set() + + self._prefix = prefix + self._next_hops = next_hops + + def update(self, next_hops = None): + if not next_hops: + return + self._next_hops |= next_hops + + def remove(self, next_hops = None): + if not next_hops: + return + self._next_hops &= next_hops + +class FIB: + def __init__(self): + self._entries = dict() + + def add(self, prefix, next_hops = None): + self._entries[prefix] = FIBEntry(prefix, next_hops) + + def update(self, prefix, next_hops = None): + entry = self._entries.get(prefix) + if not entry: + raise Exception('prefix not found') + entry.update(next_hops) + + def remove(self, prefix, next_hops = None): + if next_hop: + entry = self._entries.get(prefix) + if not entry: + raise Exception('prefix not found') + entry.remove(next_hops) + return + + del self._entries[prefix] + + def get(self, object_name): + for entry in self._entries.values(): + if entry._prefix.object_name == object_name: + return next(iter(entry._next_hops)) diff --git a/netmodel/network/flow.py b/netmodel/network/flow.py new file mode 100644 index 00000000..bce4512c --- /dev/null +++ b/netmodel/network/flow.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +class Flow: + def __init__(self, source, destination): + """ + Constructor. + Args: + source: The source Prefix of this Flow. + destination: The destination Prefix of this Flow. + """ + self.source = source + self.destination = destination + + def get_reverse(self): + """ + Make the reverse Flow of this Flow. + Returns: + The reverse Flow. + """ + return Flow(self.destination, self.source) + + def __eq__(self, other): + """ + Tests whether two Flows are equal or not. + Args: + other: A Flow instance. + Returns: + True iif self == other. + """ + if self.source and other.source and self.source != other.source: + return False + if self.destination and other.destination and self.destination != other.destination: + return False + return True + + def __hash__(self): + # Order is important + return hash((self.source, self.destination)) + + def __repr__(self): + return " %s>" % (self.source, self.destination) diff --git a/netmodel/network/flow_table.py b/netmodel/network/flow_table.py new file mode 100644 index 00000000..86c6e52e --- /dev/null +++ b/netmodel/network/flow_table.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from netmodel.model.query import ACTION_SUBSCRIBE, ACTION_UNSUBSCRIBE + +# Per-interface flow table +class SubFlowTable: + def __init__(self, interface): + # Interface to which this flow table is associated + self._interface = interface + + # Flow -> ingress interface + self._flows = dict() + + # Ingress interface -> list of subscription + self._subscriptions = dict() + + def add(self, packet, ingress_interface): + flow = packet.get_flow() + self._flows[flow] = ingress_interface + + def match(self, packet): + flow = packet.get_flow() + return self._flows.get(flow.get_reverse()) + +class Subscription: + def __init__(self, packet, ingress_list, egress_list): + """ + Args: + packet : subscription packet + ingress_list (List[Interface]) : list of ingress interface + egress_list (List[Interface]) : list of egress interface + """ + self._packet = packet + self._ingress_list = ingress_list + self._egress_list = egress_list + +class FlowTable: + """ + The flow table managed flow association between packets, as well as the + subscription state for each flow. + + Event management + ================ + + The flow table has to be notified from netmodel.network.interface becomes + up, and when it is deleted. This is handled through the following events: + + _on_interface_up + . resend any pending subscription on the interface + _on_interface_deleted + . delete any subscription + """ + + def __init__(self): + # Per interface (sub) flow tables + # flow are added when forwarded using FIB + # matched upon returning + self._sub_flow_tables = dict() + + # The Flow Table also maintains a list of subscriptions doubly indexed + # by both subscriptors, and egress interfaces + + # ingress_interface -> list of subscriptions + self._ingress_subscriptions = dict() + + # egress_interface -> list of subscriptions + self._egress_subscriptions = dict() + + + def match(self, packet, interface): + """ + Check whether the packet arriving on interface is a reply. + + Returns: + interface that originally requested the packet + None if not found + """ + sub_flow_table = self._sub_flow_tables.get(interface) + if not sub_flow_table: + return None + return sub_flow_table.match(packet) + + def add(self, packet, ingress_interface, interface): + sub_flow_table = self._sub_flow_tables.get(interface) + if not sub_flow_table: + sub_flow_table = SubFlowTable(interface) + self._sub_flow_tables[interface] = sub_flow_table + sub_flow_table.add(packet, ingress_interface) + + # If the flow is a subscription, we need to associate it to the list + query = packet.to_query() + if query.action == ACTION_SUBSCRIBE: + print('adding subscription', query.to_dict()) + # XXX we currently don't merge subscriptions, and assume a single + # next hop interface + s = Subscription(packet, [ingress_interface], [interface]) + + if ingress_interface: + if not ingress_interface in self._ingress_subscriptions: + self._ingress_subscriptions[ingress_interface] = list() + self._ingress_subscriptions[ingress_interface].append(s) + + if not interface in self._egress_subscriptions: + self._egress_subscriptions[interface] = list() + self._egress_subscriptions[interface].append(s) + + elif query.action == ACTION_UNSUBSCRIBE: + raise NotImplementedError + + + # Events + + def _on_interface_up(self, interface): + """ + Callback: an interface gets back up after downtime. + + Resend all pending subscriptions when an interface comes back up. + """ + subscriptions = self._egress_subscriptions.get(interface) + if not subscriptions: + return + for s in subscriptions: + interface.send(s._packet) + + def _on_interface_delete(self, interface): + """ + Callback: an interface has been deleted + + Cancel all subscriptions that have been issues from + netmodel.network.interface. + Remove all pointers to subscriptions pending on this interface + """ + if interface in self._ingress_subscriptions: + del self._ingress_subscriptions[interface] diff --git a/netmodel/network/interface.py b/netmodel/network/interface.py new file mode 100644 index 00000000..c9e31422 --- /dev/null +++ b/netmodel/network/interface.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import enum +import inspect +import logging +import pkgutil +import sys +import traceback + +import netmodel.interfaces as interfaces +from netmodel.network.packet import Packet +from netmodel.network.prefix import Prefix + +# Tag identifying an interface name +INTERFACE_NAME_PROPERTY = '__interface__' + +RECONNECTION_DELAY = 10 + +log = logging.getLogger(__name__) + +class InterfaceState(enum.Enum): + Up = 'up' + Down = 'down' + PendingUp = 'up (pending)' + PendingDown = 'down (pending)' + Error = 'error' + +def register_interfaces(): + Interface._factory = dict() + for loader, module_name, is_pkg in pkgutil.walk_packages(interfaces.__path__, + interfaces.__name__ + '.'): + # Note: we cannot skip files using is_pkg otherwise it will ignore + # interfaces defined outside of __init__.py + #if not is_pkg: + # continue + try: + module = loader.find_module(module_name).load_module(module_name) + for _, obj in inspect.getmembers(module): + if not inspect.isclass(obj): + continue + if not issubclass(obj, Interface): + continue + if obj is Interface: + continue + if not hasattr(obj, INTERFACE_NAME_PROPERTY): + log.warning('Ignored interface' + module_name + \ + 'with no ' + INTERFACE_NAME_PROPERTY + ' property') + continue + name = getattr(obj, INTERFACE_NAME_PROPERTY) + Interface._factory[name] = obj + + except ImportError as e: + log.warning('Interface {} automatically disabled. ' \ + 'Please install dependencies if you wish to use it: {}'\ + .format(module_name, e)) + + except Exception as e: + log.warning('Failed to load interface {}: {}'.format( + module_name, e)) + +#------------------------------------------------------------------------------ + +class Interface: + @classmethod + def get(cls, name): + if not hasattr(cls, '_factory'): + register_interfaces() + return cls._factory.get(name, None) + + STATE_DOWN = 0 + STATE_PENDING_UP = 1 + STATE_UP = 2 + STATE_PENDING_DOWN = 3 + + def __init__(self, *args, **kwargs): + self._callback = kwargs.pop('callback', None) + self._hook = kwargs.pop('hook', None) + + self._tx_buffer = list() + self._state = InterfaceState.Down + self._error = None + self._reconnecting = True + self._reconnection_delay = RECONNECTION_DELAY + + self._registered_objects = dict() + + # Callbacks + self._up_callbacks = list() + self._down_callbacks = list() + self._spawn_callbacks = list() + self._delete_callbacks = list() + + # Set upon registration + self._name = None + + def terminate(self): + self.set_state(InterfaceState.PendingDown) + + def __repr__(self): + return "" % (self.__class__.__name__) + + def __hash__(self): + return hash(self._name) + + #--------------------------------------------------------------------------- + + def register_object(self, obj): + self._registered_objects[obj.__type__] = obj + + def get_prefixes(self): + return [ Prefix(v.__type__) for v in self._registered_objects.values() ] + + #--------------------------------------------------------------------------- + # State management, callbacks + #--------------------------------------------------------------------------- + + def set_state(self, state): + asyncio.ensure_future(self._set_state(state)) + + async def _set_state(self, state): + self._state = state + + if state == InterfaceState.PendingUp: + await self.pending_up_impl() + elif state == InterfaceState.PendingDown: + await self.pending_down_impl() + elif state == InterfaceState.Error: + pass + elif state == InterfaceState.Up: + log.info("Interface {} : new state UP.".format(self.__interface__,)) + if self._tx_buffer: + log.info("Platform %s: sending %d buffered packets." % + (self.__interface__, len(self._tx_buffer))) + while self._tx_buffer: + packet = self._tx_buffer.pop() + self.send_impl(packet) + # Trigger callbacks to inform interface is up + for cb, args, kwargs in self._up_callbacks: + cb(self, *args, **kwargs) + elif state == InterfaceState.Down: + log.info("Interface %s: new state DOWN." % (self.__interface__,)) + self._state = self.STATE_DOWN + # Trigger callbacks to inform interface is down + for cb, args, kwargs in self._down_callbacks: + cb(self, *args, **kwargs) + + def spawn_interface(self, interface): + #print('spawn interface', interface) + for cb, args, kwargs in self._spawn_callbacks: + cb(interface, *args, **kwargs) + + def delete_interface(self, interface): + for cb, args, kwargs in self._delete_callbacks: + cb(interface, *args, **kwargs) + + #-------------------------------------------------------------------------- + + def set_reconnecting(self, reconnecting): + self._reconnecting = reconnecting + + def get_interface_type(self): + return self.__interface__ + + def get_description(self): + return str(self) + + def get_status(self): + return 'UP' if self.is_up() else 'ERROR' if self.is_error() else 'DOWN' + + def is_up(self): + return self._state == InterfaceState.Up + + def is_down(self): + return not self.is_up() + + def is_error(self): + return self.is_down() and self._error is not None + + def reinit_impl(self): + pass + + def reinit(self, **platform_config): + self.set_down() + if platform_config: + self.reconnect_impl(self, **platform_config) + self.set_up() + + #-------------------------------------------------------------------------- + # Callback management + #-------------------------------------------------------------------------- + + def add_up_callback(self, callback, *args, **kwargs): + cb_tuple = (callback, args, kwargs) + self._up_callbacks.append(cb_tuple) + + def del_up_callback(self, callback): + self._up_callbacks = [cb for cb in self._up_callbacks \ + if cb[0] == callback] + + def add_down_callback(self, callback, *args, **kwargs): + cb_tuple = (callback, args, kwargs) + self._down_callbacks.append(cb_tuple) + + def del_down_callback(self, callback): + self._down_callbacks = [cb for cb in self._down_callbacks \ + if cb[0] == callback] + + def add_spawn_callback(self, callback, *args, **kwargs): + cb_tuple = (callback, args, kwargs) + self._spawn_callbacks.append(cb_tuple) + + def del_spawn_callback(self, callback): + self._spawn_callbacks = [cb for cb in self._spawn_callbacks \ + if cb[0] == callback] + + def add_delete_callback(self, callback, *args, **kwargs): + cb_tuple = (callback, args, kwargs) + self._delete_callbacks.append(cb_tuple) + + def del_delete_callback(self, callback): + self._delete_callbacks = [cb for cb in self._delete_callbacks \ + if cb[0] == callback] + + #-------------------------------------------------------------------------- + # Interface API + #-------------------------------------------------------------------------- + + async def pending_up_impl(self): + self.set_state(InterfaceState.Up) + + def send_impl(self, packet): + query = packet.to_query() + obj = self._registered_objects.get(query.object_name) + obj.get(query, self) + + def receive_impl(self, packet): + ingress_interface = self + cb = self._callback + if cb is None: + return + if self._hook: + new_packet = self._hook(packet) + if new_packet is not None: + cb(new_packet, ingress_interface=ingress_interface) + return + cb(packet, ingress_interface=ingress_interface) + + #-------------------------------------------------------------------------- + + def send(self, packet): + if self.is_up(): + self.send_impl(packet) + else: + self._tx_buffer.append(packet) + + def receive(self, packet): + """ + For packets received from outside (eg. a remote server). + """ + self.receive_impl(packet) + + def execute(self, query): + self.send(Packet.from_query(query)) + diff --git a/netmodel/network/packet.py b/netmodel/network/packet.py new file mode 100644 index 00000000..9552b0e7 --- /dev/null +++ b/netmodel/network/packet.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +import enum +import logging +import pickle +import traceback +import uuid + +from types import GeneratorType + +from netmodel.network.flow import Flow +from netmodel.network.prefix import Prefix as Prefix_ +from netmodel.model.attribute import Attribute +from netmodel.model.type import String +from netmodel.model.query import Query +from netmodel.model.filter import Filter as Filter_ +from netmodel.model.field_names import FieldNames as FieldNames_ +from netmodel.model.object import Object + +log = logging.getLogger(__name__) + +class NetmodelException(Exception): + pass + +#------------------------------------------------------------------------------ + +NETMODEL_TLV_TYPELEN_STR = '!H' +NETMODEL_TLV_SIZE = 2 +NETMODEL_TLV_TYPE_MASK = 0xfe00 +NETMODEL_TLV_TYPE_SHIFT = 9 +NETMODEL_TLV_LENGTH_MASK = 0x01ff + +NETMODEL_TLV_OBJECT_NAME = 1 +NETMODEL_TLV_FIELD = 2 +NETMODEL_TLV_FIELDS = 3 +NETMODEL_TLV_PREDICATE = 4 +NETMODEL_TLV_FILTER = 5 +NETMODEL_TLV_SRC = 6 +NETMODEL_TLV_DST = 7 +NETMODEL_TLV_PROTOCOL = 8 +NETMODEL_TLV_FLAGS = 9 +NETMODEL_TLV_PAYLOAD = 10 +NETMODEL_TLV_PACKET = 11 + +SUCCESS = 0 +WARNING = 1 +ERROR = 2 + +FLAG_LAST = 1 << 0 +FLAG_REPLY = 1 << 1 + +#------------------------------------------------------------------------------ + +class PacketProtocol(enum.Enum): + Query = 'query' + Error = 'error' + +#------------------------------------------------------------------------------ + +class VICNTLV: + + _LEN_MIN = 0 + _LEN_MAX = 511 + tlv_type = None + + _tlv_parsers = {} # Attributes... + tlvs = [] + + def decode(self, buf): + (self.typelen, ) = struct.unpack( + NETMODEL_TLV_TYPELEN_STR, buf[:NETMODEL_TLV_SIZE]) + tlv_type = \ + (self.typelen & NETMODEL_TLV_TYPE_MASK) >> NETMODEL_TLV_TYPE_SHIFT + assert self.tlv_type == tlv_type + + self.len = self.typelen & NETMODEL_TLV_LENGTH_MASK + assert len(buf) >= self.len + NETMODEL_TLV_SIZE + + self.tlv_info = buf[NETMODEL_TLV_SIZE:] + self.tlv_info = self.tlv_info[:self.len] + + #-------------------------------------------------------------------------- + # Descriptor protocol + #-------------------------------------------------------------------------- + + @classmethod + def _get_tlv_parsers(cls): + if not cls._tlv_parsers: + cls._tlv_parsers = None + return cls._tlv_parsers + + + @staticmethod + def get_type(buf): + (typelen, ) = struct.unpack(NETMODEL_TLV_TYPELEN_STR, + buf[:NETMODEL_TLV_SIZE]) + return (typelen & NETMODEL_TLV_TYPE_MASK) >> NETMODEL_TLV_TYPE_SHIFT + + + def _len_valid(self): + return self._LEN_MIN <= self.len and self.len <= self._LEN_MAX + + #-------------------------------------------------------------------------- + + @classmethod + def _parser(cls, buf): + tlvs = [] + + while buf: + tlv_type = VICNTLV.get_type(buf) + tlv = cls._tlv_parsers[tlv_type](buf) + tlvs.append(tlv) + offset = NETMODEL_TLV_SIZE + tlv.len + buf = buf[offset:] + if tlv.tlv_type == NETMODEL_TLV_END: + break + assert len(buf) > 0 + + pkt = cls(tlvs) + + assert pkt._tlvs_len_valid() + assert pkt._tlvs_valid() + + return pkt, None, buf + + @classmethod + def parser(cls, buf): + try: + return cls._parser(buf) + except: + return None, None, buf + + def serialize(self, payload, prev): + data = bytearray() + for tlv in self.tlvs: + data += tlv.serialize() + + return data + + @classmethod + def set_type(cls, tlv_cls): + cls._tlv_parsers[tlv_cls.tlv_type] = tlv_cls + + @classmethod + def get_type(cls, tlv_type): + return cls._tlv_parsers[tlv_type] + + @classmethod + def set_tlv_type(cls, tlv_type): + def _set_type(tlv_cls): + tlv_cls.tlv_type = tlv_type + #cls.set_type(tlv_cls) + return tlv_cls + return _set_type + + def __len__(self): + return sum(NETMODEL_TLV_SIZE + tlv.len for tlv in self.tlvs) + +#------------------------------------------------------------------------------ + +@VICNTLV.set_tlv_type(NETMODEL_TLV_OBJECT_NAME) +class ObjectName(VICNTLV): pass + +@VICNTLV.set_tlv_type(NETMODEL_TLV_FIELD) +class Field(VICNTLV): + """Field == STR + """ + +@VICNTLV.set_tlv_type(NETMODEL_TLV_PREDICATE) +class Predicate(VICNTLV): + """Predicate == key, op, value + """ + +@VICNTLV.set_tlv_type(NETMODEL_TLV_FILTER) +class Filter(Filter_, VICNTLV): + """Filter == Array + """ + +@VICNTLV.set_tlv_type(NETMODEL_TLV_FIELDS) +class FieldNames(FieldNames_, VICNTLV): + """Fields == Array + """ + + +class Prefix(Object, Prefix_, VICNTLV): + object_name = ObjectName() + filter = Filter() + field_names = FieldNames() + + def __init__(self, *args, **kwargs): + Object.__init__(self) + Prefix_.__init__(self, *args, **kwargs) + VICNTLV.__init__(self) + + def get_tuple(self): + return (self.object_name, self.filter, self.field_names) + + def __eq__(self, other): + return self.get_tuple() == other.get_tuple() + + def __hash__(self): + return hash(self.get_tuple()) + +@VICNTLV.set_tlv_type(NETMODEL_TLV_SRC) +class Source(Prefix): + """Source address + """ + +@VICNTLV.set_tlv_type(NETMODEL_TLV_DST) +class Destination(Prefix): + """Destination address + """ + +@VICNTLV.set_tlv_type(NETMODEL_TLV_PROTOCOL) +class Protocol(Attribute, VICNTLV): + """Protocol + """ + + +@VICNTLV.set_tlv_type(NETMODEL_TLV_FLAGS) +class Flags(Attribute, VICNTLV): + """Flags: last, ... + """ + +@VICNTLV.set_tlv_type(NETMODEL_TLV_PAYLOAD) +class Payload(Attribute, VICNTLV): + """Payload + """ + +@VICNTLV.set_tlv_type(NETMODEL_TLV_PACKET) +class Packet(Object, VICNTLV): + """Base packet class + """ + source = Source() + destination = Destination(Prefix) + protocol = Protocol(String, default = 'query') + flags = Flags() + payload = Payload() + + # This should be dispatched across L3 L4 L7 + + def __init__(self, source = None, destination = None, protocol = None, + flags = 0, payload = None): + self.source = source + self.destination = destination + self.protocol = protocol + self.flags = flags + self.payload = payload + + def get_flow(self): + return Flow(self.source, self.destination) + + @staticmethod + def from_query(query, src_query = None, reply = False): + packet = Packet() + if src_query: + address = Prefix( + object_name = src_query.object_name, + filter = src_query.filter, + field_names = src_query.field_names, + aggregate = src_query.aggregate) + if reply: + packet.destination = address + else: + packet.source = address + + if query: + address = Prefix( + object_name = query.object_name, + filter = query.filter, + field_names = query.field_names, + aggregate = query.aggregate) + + if reply: + packet.source = address + else: + packet.destination = address + + packet.payload = (query.action, query.params) + + packet.protocol = 'sync' + packet.last = not query or query.last + packet.reply = reply + + return packet + + def to_query(self): + action, params = self.payload + + address = self.source if self.reply else self.destination + object_name = address.object_name + filter = address.filter + field_names = address.field_names + aggregate = address.aggregate + + return Query(action, object_name, filter, params, field_names, + aggregate = aggregate, last = self.last, reply = self.reply) + + @property + def last(self): + return self.flags & FLAG_LAST + + @last.setter + def last(self, last): + if last: + self.flags |= FLAG_LAST + else: + self.flags &= ~FLAG_LAST + + @property + def reply(self): + return self.flags & FLAG_REPLY + + @reply.setter + def reply(self, reply): + if reply: + self.flags |= FLAG_REPLY + else: + self.flags &= ~FLAG_REPLY + +class ErrorPacket(Packet): + """ + Analog with ICMP errors packets in IP networks + """ + + #-------------------------------------------------------------------------- + # Constructor + #-------------------------------------------------------------------------- + + def __init__(self, type = ERROR, code = ERROR, message = None, + traceback = None, **kwargs): + assert not traceback or isinstance(traceback, str) + + Packet.__init__(self, **kwargs) + self.protocol = PacketProtocol.Error + self.last = True + self._type = type + self._code = code + self._message = message + self._traceback = traceback + + #-------------------------------------------------------------------------- + # Static methods + #-------------------------------------------------------------------------- + + @staticmethod + def from_exception(packet, e): + if isinstance(e, NetmodelException): + error_packet = ErrorPacket( + type = e.TYPE, # eg. ERROR + code = e.CODE, # eg. BADARGS + message = str(e), #e.message, + traceback = traceback.format_exc(), + last = True + ) + else: + error_packet = ErrorPacket( + type = ERROR, + code = UNHANDLED_EXCEPTION, + message = str(e), + traceback = traceback.format_exc(), + last = True + ) + error_packet.set_source(packet.get_destination()) + error_packet.set_destination(packet.get_source()) + return error_packet + + def get_message(self): + """ + Returns: + The error message related to this ErrorPacket. + """ + return self._message + + def get_traceback(self): + """ + Returns: + The traceback related to this ErrorPacket. + """ + return self._traceback + + def get_origin(self): + """ + Returns: + A value among {code::CORE, code::GATEWAY} + identifying who is the origin of this ErrorPacket. + """ + return self._origin + + def get_code(self): + """ + Returns: + The error code of the Error carried by this ErrorPacket. + """ + return self._code + + def get_type(self): + """ + Returns: + The error type of the Error carried by this ErrorPacket. + """ + return self._type + + def __repr__(self): + """ + Returns: + The '%r' representation of this ERROR Packet. + """ + return "" % ( + Packet.get_protocol_name(self.get_protocol()), + self.get_message() + ) diff --git a/netmodel/network/prefix.py b/netmodel/network/prefix.py new file mode 100644 index 00000000..00b5db71 --- /dev/null +++ b/netmodel/network/prefix.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +class Prefix: + def __init__(self, object_name = None, filter = None, field_names = None, + aggregate = None): + self.object_name = object_name + self.filter = filter + self.field_names = field_names + self.aggregate = aggregate + + def __hash__(self): + return hash(self.get_tuple()) + + def get_tuple(self): + return (self.object_name, self.filter, self.field_names, + self.aggregate) + + def __repr__(self): + return ''.format(self.get_tuple()) + + __str__ = __repr__ diff --git a/netmodel/network/router.py b/netmodel/network/router.py new file mode 100644 index 00000000..84d69dca --- /dev/null +++ b/netmodel/network/router.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import logging +import random +import string +import traceback + +from netmodel.network.interface import Interface, InterfaceState +from netmodel.network.fib import FIB +from netmodel.network.flow_table import FlowTable +from netmodel.network.packet import Packet, ErrorPacket + +log = logging.getLogger(__name__) + +class Router: + + #-------------------------------------------------------------------------- + # Constructor, destructor, accessors + #-------------------------------------------------------------------------- + + def __init__(self, vicn_callback = None): + """ + Constructor. + Args: + allowed_capabilities: A Capabilities instance which defines which + operation can be performed by this Router. Pass None if there + is no restriction. + """ + # FIB + self._fib = FIB() + + # Per-interface flow table + self._flow_table = FlowTable() + + # interface_uuid -> interface + self._interfaces = dict() + + self._vicn_callback = vicn_callback + + def terminate(self): + for interface in self._interfaces.values(): + interface.terminate() + + # Accessors + + def get_fib(self): + return self._fib + + #-------------------------------------------------------------------------- + # Collection management + #-------------------------------------------------------------------------- + + def register_local_collection(self, cls): + self.get_interface(LOCAL_NAMESPACE).register_collection(cls, + LOCAL_NAMESPACE) + + def register_collection(self, cls, namespace=None): + self.get_interface(LOCAL_NAMESPACE).register_collection(cls, namespace) + + #-------------------------------------------------------------------------- + # Interface management + #-------------------------------------------------------------------------- + + def _register_interface(self, interface, name=None): + if not name: + name = 'interface-' + ''.join(random.choice(string.ascii_uppercase + + string.digits) for _ in range(3)) + interface.name = name + self._interfaces[name] = interface + + # Populate interface callbacks + interface.add_up_callback(self.on_interface_up) + interface.add_down_callback(self.on_interface_down) + interface.add_spawn_callback(self.on_interface_spawn) + interface.add_delete_callback(self.on_interface_delete) + + log.info('Successfully created interface {} with name {}'.format( + interface.__interface__, name)) + + interface.set_state(InterfaceState.PendingUp) + + for prefix in interface.get_prefixes(): + self._fib.add(prefix, [interface]) + + return interface + + def _unregister_interface(self, interface): + del self._interfaces[interface.name] + + # Interface events + + #-------------------------------------------------------------------------- + # Interface management + #-------------------------------------------------------------------------- + + def on_interface_up(self, interface): + """ + This callback is triggered when an interface becomes up. + + The router will request metadata. + The flow table is notified. + """ + self._flow_table._on_interface_up(interface) + + def on_interface_down(self, interface): + # We need to remove corresponding FIB entries + log.info("Router interface is now down") + + def on_interface_spawn(self, interface): + self._register_interface(interface) + + def on_interface_delete(self, interface): + """Callback : an interface has been deleted. + + - TODO : purge the FIB + - inform the flow table for managing pending subscriptions. + """ + self._unregister_interface(interface) + self._flow_table._on_interface_delete(interface) + + #--------------------------------------------------------------------------- + # Public API + #--------------------------------------------------------------------------- + + def add_interface(self, interface_type, name=None, namespace=None, + **platform_config): + """ + namespace is used to force appending of a namespace to the tables. + existing namespaces are thus ignored. + + # This is the public facing interface, which internally uses + # _register_interface. + """ + interface_cls = Interface.get(interface_type) + if interface_cls is None: + log.warning("Could not create a %(interface_type)s interface" % \ + locals()) + return None + + try: + # passes a callback to the Interface + # no hook + platform_config['callback'] = self._on_receive + interface = interface_cls(self, **platform_config) + except Exception as e: + traceback.print_exc() + raise Exception("Cannot create interface %s of type %s with parameters %r: %s" + % (name, interface_type, + platform_config, e)) + self._register_interface(interface, name) + return interface + + def is_interface_up(self, interface_name): + interface = self._interfaces.get(interface_name) + if not interface: + return False + return self._interfaces[interface_name].is_up() + + def del_platform(self, platform_name, rebuild = True): + """ + Remove a platform from this Router. This platform is no more + registered. The corresponding Announces are also removed. + Args: + platform_name: A String containing a platform name. + rebuild: True if the DbGraph must be rebuild. + Returns: + True if it altered this Router. + """ + ret = False + try: + del self._interfaces[platform_name] + ret = True + except KeyError: + pass + + self.disable_platform(platform_name, rebuild) + return ret + + def get_interface(self, platform_name): + """ + Retrieve the Interface instance corresponding to a platform. + Args: + platform_name: A String containing the name of the platform. + Raises: + ValueError: if platform_name is invalid. + RuntimeError: in case of failure. + Returns: + The corresponding Interface if found, None otherwise. + """ + if platform_name.lower() != platform_name: + raise ValueError("Invalid platform_name = %s, must be lower case" \ + % platform_name) + + if platform_name in self._interfaces: + return self._interfaces[platform_name] + + raise RuntimeError("%s is not yet registered" % platform_name) + + def get_interface_names(self): + return self._interfaces.keys() + + def get_interfaces(self): + return self._interfaces.values() + + #-------------------------------------------------------------------------- + # Packet operations + #-------------------------------------------------------------------------- + + def _on_receive(self, packet, ingress_interface): + """Handles reception of a new packet. + + An incoming packet is forwarder either: + - using the reverse path is there is a match with the ingress + interface flow table + - using the FIB if no match is found + """ + orig_interface = self._flow_table.match(packet, ingress_interface) + if orig_interface: + orig_interface.send(packet) + return + + if isinstance(packet, str): + # Workaround : internal command + if self._vicn_callback: + self._vicn_callback(packet) + return + + if packet.source is None and packet.destination is None: + log.warning('TODO: handle NULL packet, need source on all packets') + return + + # Get route from FIB + interface = self._fib.get(packet.destination.object_name) + if not interface: + return + + # Update flow table before sending + self._flow_table.add(packet, ingress_interface, interface) + + interface.send(packet) diff --git a/netmodel/util/__init__.py b/netmodel/util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netmodel/util/argparse.py b/netmodel/util/argparse.py new file mode 100644 index 00000000..c9678922 --- /dev/null +++ b/netmodel/util/argparse.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse + +class ArgumentParser(argparse.ArgumentParser): + pass diff --git a/netmodel/util/color.py b/netmodel/util/color.py new file mode 100644 index 00000000..55719469 --- /dev/null +++ b/netmodel/util/color.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# ANSI escape codes for terminals. +# X11 xterm: always works, all platforms +# cygwin dosbox: run through |cat and then colors work +# linux: works on console & gnome-terminal +# mac: untested + +BLACK = "\033[0;30m" +BLUE = "\033[0;34m" +GREEN = "\033[0;32m" +CYAN = "\033[0;36m" +RED = "\033[0;31m" +PURPLE = "\033[0;35m" +BROWN = "\033[0;33m" +GRAY = "\033[0;37m" +BOLDGRAY = "\033[1;30m" +BOLDBLUE = "\033[1;34m" +BOLDGREEN = "\033[1;32m" +BOLDCYAN = "\033[1;36m" +BOLDRED = "\033[1;31m" +BOLDPURPLE = "\033[1;35m" +BOLDYELLOW = "\033[1;33m" +WHITE = "\033[1;37m" + +MYCYAN = "\033[96m" +MYGREEN = '\033[92m' +MYBLUE = '\033[94m' +MYWARNING = '\033[93m' +MYRED = '\033[91m' +MYHEADER = '\033[95m' +MYEND = '\033[0m' + +NORMAL = "\033[0m" + +colors = { + 'white': "\033[1;37m", + 'yellow': "\033[1;33m", + 'green': "\033[1;32m", + 'blue': "\033[1;34m", + 'cyan': "\033[1;36m", + 'red': "\033[1;31m", + 'magenta': "\033[1;35m", + 'black': "\033[1;30m", + 'darkwhite': "\033[0;37m", + 'darkyellow': "\033[0;33m", + 'darkgreen': "\033[0;32m", + 'darkblue': "\033[0;34m", + 'darkcyan': "\033[0;36m", + 'darkred': "\033[0;31m", + 'darkmagenta': "\033[0;35m", + 'darkblack': "\033[0;30m", + 'off': "\033[0;0m" +} + +def textcolor(color, string): + """ + This function is useful to output information to the stdout by exploiting + different colors, depending on the result of the last command executed. + + It is possible to chose one of the following colors: + - white + - yellow + - green + - blue + - cyan + - red + - magenta + - black + - darkwhite + - darkyellow + - darkgreen + - darkblue + - darkcyan + - darkred + - darkmagenta + - darkblack + - off + + :param color: The color of the output string, chosen from the previous + list. + :param string: The string to color + :return: The colored string if the color is valid, the original string + otherwise. + """ + + try: + return colors[color] + string + colors['off'] + except: + return string + +if __name__ == '__main__': + # Display color names in their color + for name, color in locals().items(): + if name.startswith('__'): continue + print(color, name, MYEND) + diff --git a/netmodel/util/daemon.py b/netmodel/util/daemon.py new file mode 100644 index 00000000..29683a54 --- /dev/null +++ b/netmodel/util/daemon.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# see also: http://www.jejik.com/files/examples/daemon3x.py + +# This is used to import the daemon package instead of the local module which is +# named identically... +from __future__ import absolute_import + +import os +import sys +import time +import logging +import atexit +import signal +import lockfile +import traceback + +log = logging.getLogger(__name__) + +class Daemon: + + #-------------------------------------------------------------------------- + # Checks + #-------------------------------------------------------------------------- + + def check_python_daemon(self): + """ + Check whether python-daemon is properly installed. + Returns: + True iiff everything is fine. + """ + # http://www.python.org/dev/peps/pep-3143/ + ret = False + try: + import daemon + getattr(daemon, "DaemonContext") + ret = True + except AttributeError as e: + # daemon and python-daemon conflict with each other + log.critical("Please install python-daemon instead of daemon." \ + "Remove daemon first.") + except ImportError: + log.critical("Please install python-daemon.") + return ret + + #-------------------------------------------------------------------------- + # Initialization + #-------------------------------------------------------------------------- + + def __init__(self, name, + uid = os.getuid(), + gid = os.getgid(), + working_directory = '/', + debug_mode = False, + no_daemon = False, + pid_filename = None + ): + self._name = name + self._uid = uid + self._gid = gid + self._working_directory = working_directory + self._debug_mode = debug_mode + self._no_daemon = no_daemon + self._pid_filename = pid_filename if pid_filename \ + else '/var/run/{}.pid'.format(name) + + # Reference which file descriptors must remain opened while + # daemonizing (for instance the file descriptor related to + # the logger, a socket file created before daemonization etc.) + self._files_to_keep = list() + self._lock_file = None + + self.initialize() + + #------------------------------------------------------------------------ + + + def remove_pid_file(self): + """ + (Internal usage) + Remove the PID file + """ + if os.path.exists(self._pid_filename) == True: + log.info("Removing %s" % self._pid_filename) + os.remove(self._pid_filename) + + if self._lock_file and self._lock_file.is_locked(): + self._lock_file.release() + + def make_pid_file(self): + """ + Create a PID file if required in which we store the PID of the daemon + if needed + """ + if self._pid_filename and not self._no_daemon: + atexit.register(self.remove_pid_file) + open(self._pid_filename, "w+").write("%s\n" % str(os.getpid())) + + def get_pid_from_pid_file(self): + """ + Retrieve the PID of the daemon thanks to the pid file. + Returns: + An integer containing the PID of this daemon. + None if the pid file is not readable or does not exists + """ + pid = None + if self._pid_filename: + try: + f_pid = file(self._pid_filename, "r") + pid = int(f_pid.read().strip()) + f_pid.close() + except IOError: + pid = None + return pid + + def make_lock_file(self): + """ + Prepare the lock file required to manage the pid file. + Initialize self.lock_file + Returns: + True iif successful. + """ + if self._pid_filename and not self._no_daemon: + log.debug("Daemonizing using pid file '%s'" % self._pid_filename) + self.lock_file = lockfile.FileLock(self._pid_filename) + if self.lock_file.is_locked() == True: + log.error("'%s' is already running ('%s' is locked)." % \ + (self._name, self._pid_filename)) + return False + self.lock_file.acquire() + else: + self.lock_file = None + return True + + def start(self): + """ + Start the daemon. + """ + # Check whether daemon module is properly installed + if self.check_python_daemon() == False: + self._terminate() + import daemon + + # Prepare self.lock_file + if not self.make_lock_file(): + sys.exit(1) + + # We might need to preserve a few files from logging handlers + files_to_keep = list() + #for handler in log.handlers: + # preserve_files + + if self._no_daemon: + self.main() + return + + # Prepare the daemon context + dcontext = daemon.DaemonContext( + detach_process = not self._no_daemon, + working_directory = self._working_directory, + pidfile = self.lock_file, + stdin = sys.stdin, + stdout = sys.stdout, + stderr = sys.stderr, + uid = self._uid, + gid = self._gid, + files_preserve = files_to_keep + ) + + # Prepare signal handling to stop properly if the daemon is killed + # Note that signal.SIGKILL can't be handled: + # http://crunchtools.com/unixlinux-signals-101/ + dcontext.signal_map = { + signal.SIGTERM : self.signal_handler, + signal.SIGQUIT : self.signal_handler, + signal.SIGINT : self.signal_handler + } + + with dcontext: + log.info("Entering daemonization") + self.make_pid_file() + + try: + self.main() + except Exception as e: + log.error("Unhandled exception in start: %s" % e) + log.error(traceback.format_exc()) + finally: + self._terminate() + + def signal_handler(self, signal_id, frame): + """ + (Internal use) + Stop the daemon (signal handler) + Args: + signal_id: The integer identifying the signal + (see also "man 7 signal") + Example: 15 if the received signal is signal.SIGTERM + frame: + """ + self._terminate() + + def _terminate(self): + """ + Stops gracefully the daemon. + Note: + The lockfile should implicitly released by the daemon package. + """ + log.info("Stopping %s" % self.__class__.__name__) + self.terminate() + self.remove_pid_file() + self.leave() + + def leave(self): + """ + Overload this method if you use twisted (see xmlrpc.py) + """ + sys.exit(0) + + # Overload these... + + def initialize(self): + pass + + def main(self): + raise NotImplementedError + + def terminate(self): + pass diff --git a/netmodel/util/debug.py b/netmodel/util/debug.py new file mode 100644 index 00000000..dfa7d127 --- /dev/null +++ b/netmodel/util/debug.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# FROM: https://gist.github.com/techtonik/2151727 +# Public Domain, i.e. feel free to copy/paste +# Considered a hack in Python 2 +# + +import inspect +import traceback + +def print_call_stack(): + for line in traceback.format_stack(): + print(line.strip()) + +def caller_name(skip=2): + """Get a name of a caller in the format module.class.method + + `skip` specifies how many levels of stack to skip while getting caller + name. skip=1 means "who calls me", skip=2 "who calls my caller" etc. + + An empty string is returned if skipped levels exceed stack height + """ + stack = inspect.stack() + start = 0 + skip + if len(stack) < start + 1: + return '' + parentframe = stack[start][0] + + name = [] + module = inspect.getmodule(parentframe) + # `modname` can be None when frame is executed directly in console + if module: + name.append(module.__name__) + # detect classname + if 'self' in parentframe.f_locals: + # there seems to be no way to detect static method call - it will + # be just a function call + name.append(parentframe.f_locals['self'].__class__.__name__) + codename = parentframe.f_code.co_name + if codename != '': # top level usually + name.append( codename ) # function or a method + del parentframe + return ".".join(name) + diff --git a/netmodel/util/deprecated.py b/netmodel/util/deprecated.py new file mode 100644 index 00000000..faa80ac2 --- /dev/null +++ b/netmodel/util/deprecated.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +import functools + +def deprecated(func): + """This is a decorator which can be used to mark functions + as deprecated. It will result in a warning being emmitted + when the function is used.""" + + @functools.wraps(func) + def new_func(*args, **kwargs): + #warnings.simplefilter('always', DeprecationWarning) + warnings.warn("Call to deprecated function {}.".format(func.__name__),\ + category=DeprecationWarning, stacklevel=2) + #warnings.simplefilter('default', DeprecationWarning) + return func(*args, **kwargs) + + return new_func diff --git a/netmodel/util/log.py b/netmodel/util/log.py new file mode 100644 index 00000000..68eb9a7f --- /dev/null +++ b/netmodel/util/log.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import logging.config +import os +import sys + +colors = { + 'white': "\033[1;37m", + 'yellow': "\033[1;33m", + 'green': "\033[1;32m", + 'blue': "\033[1;34m", + 'cyan': "\033[1;36m", + 'red': "\033[1;31m", + 'magenta': "\033[1;35m", + 'black': "\033[1;30m", + 'darkwhite': "\033[0;37m", + 'darkyellow': "\033[0;33m", + 'darkgreen': "\033[0;32m", + 'darkblue': "\033[0;34m", + 'darkcyan': "\033[0;36m", + 'darkred': "\033[0;31m", + 'darkmagenta': "\033[0;35m", + 'darkblack': "\033[0;30m", + 'off': "\033[0;0m" +} + +def textcolor(color, string): + """ + This function is useful to output information to the stdout by exploiting + different colors, depending on the result of the last command executed. + + It is possible to chose one of the following colors: white, yellow, green, + blue, cyan, red, magenta, black, darkwhite, darkyellow, darkgreen, + darkblue, darkcyan, darkred, darkmagenta, darkblack, off + + :param color: The color of the output string, chosen from the previous + list. + :param string: The string to color + :return: The colored string if the color is valid, the original string + otherwise. + """ + try: + return colors[color] + string + colors['off'] + except: + return string + +FMT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + +MAP_LEVELNAME_COLOR = { + 'DEBUG' : 'off', + 'INFO' : 'cyan', + 'WARNING' : 'yellow', + 'ERROR' : 'red', + 'CRITICAL' : 'red' +} + +DEFAULT_COLOR = 'blue' + +DEPRECATED_REPEAT = False +DEPRECATED_DONE = list() + +DEBUG_PATHS = 'vicn' + +class DebugPathsFilter: + def __init__(self, debug_paths): + self._debug_paths = debug_paths + + def filter(self, record): + return record.levelname != 'DEBUG' or record.name in self._debug_paths + +class ColorFormatter(logging.Formatter): + def __init__(self, *args, debug_paths = None, **kwargs): + self._debug_paths = debug_paths + super().__init__(*args, **kwargs) + + def format(self, record): + formatted = super().format(record) + + try: + color = record.category + except AttributeError: + color = MAP_LEVELNAME_COLOR.get(record.levelname, DEFAULT_COLOR) + + return textcolor(color, formatted) + +def initialize_logging(): + # Use logger config + config_path = os.path.join(os.path.dirname(__file__), + os.path.pardir, os.path.pardir, 'config', 'logging.conf') + if os.path.exists(config_path): + logging.config.fileConfig(config_path, disable_existing_loggers=False) + + root = logging.getLogger() + root.setLevel(logging.DEBUG) + + # Stdout handler + ch = logging.StreamHandler(sys.stdout) + ch.setLevel(logging.INFO) + #formatter = logging.Formatter(FMT) + formatter = ColorFormatter(FMT, debug_paths = DEBUG_PATHS) + ch.setFormatter(formatter) + ch.addFilter(DebugPathsFilter(DEBUG_PATHS)) + root.addHandler(ch) diff --git a/netmodel/util/meta.py b/netmodel/util/meta.py new file mode 100644 index 00000000..355beb7e --- /dev/null +++ b/netmodel/util/meta.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# http://stackoverflow.com/questions/5881873/python-find-all-classes-which-inherit-from-this-one + +def inheritors(klass): + subclasses = set() + work = [klass] + while work: + parent = work.pop() + for child in parent.__subclasses__(): + if child not in subclasses: + subclasses.add(child) + work.append(child) + return subclasses diff --git a/netmodel/util/misc.py b/netmodel/util/misc.py new file mode 100644 index 00000000..315887b3 --- /dev/null +++ b/netmodel/util/misc.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import collections + +def is_iterable(x): + return isinstance(x, collections.Iterable) and not isinstance(x, str) + +#------------------------------------------------------------------------------ +from itertools import tee + +# https://docs.python.org/3/library/itertools.html#itertools-recipes +def pairwise(iterable): + "s -> (s0,s1), (s1,s2), (s2, s3), ..." + a, b = tee(iterable) + next(b, None) + return zip(a, b) + + +#------------------------------------------------------------------------------ +# http://stackoverflow.com/questions/1630320/what-is-the-pythonic-way-to-detect-the-last-element-in-a-python-for-loop +def lookahead(iterable): + it = iter(iterable) + last = next(it) + for val in it: + yield last, False + last = val + yield last, True + +#------------------------------------------------------------------------------ +# http://stackoverflow.com/questions/10840533/most-pythonic-way-to-delete-a-file-which-may-not-exist + +import os, errno + +def silentremove(filename): + try: + os.remove(filename) + except OSError as e: # this would be "except OSError, e:" before Python 2.6 + if e.errno != errno.ENOENT: # errno.ENOENT = no such file or directory + raise # re-raise exception if a different error occured + +#------------------------------------------------------------------------------ +import socket + +def is_local_host(hostname): + return hostname in ['localhost', '127.0.0.1'] or \ + hostname == socket.gethostname() diff --git a/netmodel/util/process.py b/netmodel/util/process.py new file mode 100644 index 00000000..954c0098 --- /dev/null +++ b/netmodel/util/process.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import subprocess +import shlex + +from vicn.core.commands import ReturnValue + +def execute_local(cmd, output=True): + cmd = shlex.split(cmd) + if output: + p = subprocess.Popen(cmd, stdout=subprocess.PIPE) + stdout, stderr = p.communicate() + ret = p.returncode + return ReturnValue(ret, stdout, stderr) + else: + p = subprocess.Popen(cmd, stdout=subprocess.DEVNULL) + ret = p.wait() + return ReturnValue(ret) + diff --git a/netmodel/util/sa_compat.py b/netmodel/util/sa_compat.py new file mode 100644 index 00000000..ee4a20f9 --- /dev/null +++ b/netmodel/util/sa_compat.py @@ -0,0 +1,265 @@ +# util/compat.py +# Copyright (C) 2005-2016 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Handle Python version/platform incompatibilities.""" + +import sys + +try: + import threading +except ImportError: + import dummy_threading as threading + +py36 = sys.version_info >= (3, 6) +py33 = sys.version_info >= (3, 3) +py32 = sys.version_info >= (3, 2) +py3k = sys.version_info >= (3, 0) +py2k = sys.version_info < (3, 0) +py265 = sys.version_info >= (2, 6, 5) +jython = sys.platform.startswith('java') +pypy = hasattr(sys, 'pypy_version_info') +win32 = sys.platform.startswith('win') +cpython = not pypy and not jython # TODO: something better for this ? + +import collections +next = next + +if py3k: + import pickle +else: + try: + import cPickle as pickle + except ImportError: + import pickle + +# work around http://bugs.python.org/issue2646 +if py265: + safe_kwarg = lambda arg: arg +else: + safe_kwarg = str + +ArgSpec = collections.namedtuple("ArgSpec", + ["args", "varargs", "keywords", "defaults"]) + +if py3k: + import builtins + + from inspect import getfullargspec as inspect_getfullargspec + from urllib.parse import (quote_plus, unquote_plus, + parse_qsl, quote, unquote) + import configparser + from io import StringIO + + from io import BytesIO as byte_buffer + + def inspect_getargspec(func): + return ArgSpec( + *inspect_getfullargspec(func)[0:4] + ) + + string_types = str, + binary_types = bytes, + binary_type = bytes + text_type = str + int_types = int, + iterbytes = iter + + def u(s): + return s + + def ue(s): + return s + + def b(s): + return s.encode("latin-1") + + if py32: + callable = callable + else: + def callable(fn): + return hasattr(fn, '__call__') + + def cmp(a, b): + return (a > b) - (a < b) + + from functools import reduce + + print_ = getattr(builtins, "print") + + import_ = getattr(builtins, '__import__') + + import itertools + itertools_filterfalse = itertools.filterfalse + itertools_filter = filter + itertools_imap = map + from itertools import zip_longest + + import base64 + + def b64encode(x): + return base64.b64encode(x).decode('ascii') + + def b64decode(x): + return base64.b64decode(x.encode('ascii')) + +else: + from inspect import getargspec as inspect_getfullargspec + inspect_getargspec = inspect_getfullargspec + from urllib import quote_plus, unquote_plus, quote, unquote + from urlparse import parse_qsl + import ConfigParser as configparser + from StringIO import StringIO + from cStringIO import StringIO as byte_buffer + + string_types = basestring, + binary_types = bytes, + binary_type = str + text_type = unicode + int_types = int, long + + def iterbytes(buf): + return (ord(byte) for byte in buf) + + def u(s): + # this differs from what six does, which doesn't support non-ASCII + # strings - we only use u() with + # literal source strings, and all our source files with non-ascii + # in them (all are tests) are utf-8 encoded. + return unicode(s, "utf-8") + + def ue(s): + return unicode(s, "unicode_escape") + + def b(s): + return s + + def import_(*args): + if len(args) == 4: + args = args[0:3] + ([str(arg) for arg in args[3]],) + return __import__(*args) + + callable = callable + cmp = cmp + reduce = reduce + + import base64 + b64encode = base64.b64encode + b64decode = base64.b64decode + + def print_(*args, **kwargs): + fp = kwargs.pop("file", sys.stdout) + if fp is None: + return + for arg in enumerate(args): + if not isinstance(arg, basestring): + arg = str(arg) + fp.write(arg) + + import itertools + itertools_filterfalse = itertools.ifilterfalse + itertools_filter = itertools.ifilter + itertools_imap = itertools.imap + from itertools import izip_longest as zip_longest + + +import time +if win32 or jython: + time_func = time.clock +else: + time_func = time.time + +from collections import namedtuple +from operator import attrgetter as dottedgetter + + +if py3k: + def reraise(tp, value, tb=None, cause=None): + if cause is not None: + assert cause is not value, "Same cause emitted" + value.__cause__ = cause + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + +else: + # not as nice as that of Py3K, but at least preserves + # the code line where the issue occurred + exec("def reraise(tp, value, tb=None, cause=None):\n" + " if cause is not None:\n" + " assert cause is not value, 'Same cause emitted'\n" + " raise tp, value, tb\n") + + +def raise_from_cause(exception, exc_info=None): + if exc_info is None: + exc_info = sys.exc_info() + exc_type, exc_value, exc_tb = exc_info + cause = exc_value if exc_value is not exception else None + reraise(type(exception), exception, tb=exc_tb, cause=cause) + +if py3k: + exec_ = getattr(builtins, 'exec') +else: + def exec_(func_text, globals_, lcl=None): + if lcl is None: + exec('exec func_text in globals_') + else: + exec('exec func_text in globals_, lcl') + + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass. + + Drops the middle class upon creation. + + Source: http://lucumr.pocoo.org/2013/5/21/porting-to-python-3-redux/ + + """ + + class metaclass(meta): + __call__ = type.__call__ + __init__ = type.__init__ + + def __new__(cls, name, this_bases, d): + if this_bases is None: + return type.__new__(cls, name, (), d) + return meta(name, bases, d) + return metaclass('temporary_class', None, {}) + + +from contextlib import contextmanager + +try: + from contextlib import nested +except ImportError: + # removed in py3k, credit to mitsuhiko for + # workaround + + @contextmanager + def nested(*managers): + exits = [] + vars = [] + exc = (None, None, None) + try: + for mgr in managers: + exit = mgr.__exit__ + enter = mgr.__enter__ + vars.append(enter()) + exits.append(exit) + yield vars + except: + exc = sys.exc_info() + finally: + while exits: + exit = exits.pop() + try: + if exit(*exc): + exc = (None, None, None) + except: + exc = sys.exc_info() + if exc != (None, None, None): + reraise(exc[0], exc[1], exc[2]) diff --git a/netmodel/util/singleton.py b/netmodel/util/singleton.py new file mode 100644 index 00000000..4aaff9ee --- /dev/null +++ b/netmodel/util/singleton.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +class Singleton(type): + """ + Classes that inherit from Singleton can be instanciated only once + + See also + http://stackoverflow.com/questions/6760685/creating-a-singleton-in-python + """ + + def __init__(cls, name, bases, dic): + super(Singleton,cls).__init__(name,bases,dic) + cls.instance=None + + def __call__(cls, *args, **kw): + if cls.instance is None: + cls.instance=super(Singleton,cls).__call__(*args,**kw) + return cls.instance + + def _drop(cls): + "Drop the instance (for testing purposes)." + if cls.instance is not None: + del cls.instance + cls.instance = None + diff --git a/netmodel/util/socket.py b/netmodel/util/socket.py new file mode 100644 index 00000000..e0cab384 --- /dev/null +++ b/netmodel/util/socket.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright (c) 2017 Cisco and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import socket + +log = logging.getLogger(__name__) + +def check_port(address, port): + log.info('Check port status for {address}:{port}'.format(**locals())) + s = socket.socket() + try: + s.connect((address, port)) + return True + except socket.error as e: + return False diff --git a/netmodel/util/toposort.py b/netmodel/util/toposort.py new file mode 100644 index 00000000..64931c32 --- /dev/null +++ b/netmodel/util/toposort.py @@ -0,0 +1,82 @@ +####################################################################### +# Implements a topological sort algorithm. +# +# Copyright 2014 True Blade Systems, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Notes: +# Based on http://code.activestate.com/recipes/578272-topological-sort +# with these major changes: +# Added unittests. +# Deleted doctests (maybe not the best idea in the world, but it cleans +# up the docstring). +# Moved functools import to the top of the file. +# Changed assert to a ValueError. +# Changed iter[items|keys] to [items|keys], for python 3 +# compatibility. I don't think it matters for python 2 these are +# now lists instead of iterables. +# Copy the input so as to leave it unmodified. +# Renamed function from toposort2 to toposort. +# Handle empty input. +# Switch tests to use set literals. +# +######################################################################## + +from functools import reduce as _reduce + +__all__ = ['toposort', 'toposort_flatten'] + +def toposort(data): + """Dependencies are expressed as a dictionary whose keys are items +and whose values are a set of dependent items. Output is a list of +sets in topological order. The first set consists of items with no +dependences, each subsequent set consists of items that depend upon +items in the preceeding sets. +""" + + # Special case empty input. + if len(data) == 0: + return + + # Copy the input so as to leave it unmodified. + data = data.copy() + + # Ignore self dependencies. + for k, v in data.items(): + v.discard(k) + # Find all items that don't depend on anything. + extra_items_in_deps = _reduce(set.union, data.values()) - set(data.keys()) + # Add empty dependences where needed. + data.update({item:set() for item in extra_items_in_deps}) + while True: + ordered = set(item for item, dep in data.items() if len(dep) == 0) + if not ordered: + break + yield ordered + data = {item: (dep - ordered) + for item, dep in data.items() + if item not in ordered} + if len(data) != 0: + raise ValueError('Cyclic dependencies exist among these items: {}'.format(', '.join(repr(x) for x in data.items()))) + + +def toposort_flatten(data, sort=True): + """Returns a single list of dependencies. For any set returned by +toposort(), those items are sorted and appended to the result (just to +make the results deterministic).""" + + result = [] + for d in toposort(data): + result.extend((sorted if sort else list)(d)) + return result -- cgit 1.2.3-korg