diff options
author | 2016-02-08 10:55:20 -0500 | |
---|---|---|
committer | 2016-02-08 10:55:20 -0500 | |
commit | 6107c1ca4aa485c5971ff3326513b8f4934f7ac1 (patch) | |
tree | a109d80501bd087e3219f68c186fb55bc17e090a /scripts/automation/trex_control_plane/stl/trex_stl_lib | |
parent | f5a5e50bfe046148a20f6ce578d6082119dec2c0 (diff) |
huge refactor - again
Diffstat (limited to 'scripts/automation/trex_control_plane/stl/trex_stl_lib')
20 files changed, 6199 insertions, 0 deletions
diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/__init__.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/__init__.py diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/api.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/api.py new file mode 100644 index 00000000..4c0c10fa --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/api.py @@ -0,0 +1,31 @@ + +# get external libs +import trex_stl_ext + +# client and exceptions +from trex_stl_exceptions import * +from trex_stl_client import STLClient, LoggerApi + +# streams +from trex_stl_streams import * + +# packet builder +from trex_stl_packet_builder_scapy import * +from scapy.all import * + +# packet builder +STLPktBuilder = CScapyTRexPktBuilder + +# VM +STLVmFlowVar = CTRexVmDescFlowVar +STLVmWriteFlowVar = CTRexVmDescWrFlowVar +STLVmFixIpv4 = CTRexVmDescFixIpv4 +STLVmTrimPktSize = CTRexVmDescTrimPktSize +STLVmTupleGen = CTRexVmDescTupleGen + + +# simulator +from trex_stl_sim import STLSim + +# std lib (various lib functions) +from trex_stl_std import * diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_async_client.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_async_client.py new file mode 100644 index 00000000..410482b9 --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_async_client.py @@ -0,0 +1,322 @@ +#!/router/bin/python + +import json +import threading +import time +import datetime +import zmq +import re +import random + +from trex_stl_jsonrpc_client import JsonRpcClient, BatchMessage + +from utils.text_opts import * +from trex_stl_stats import * +from trex_stl_types import * + +# basic async stats class +class CTRexAsyncStats(object): + def __init__ (self): + self.ref_point = None + self.current = {} + self.last_update_ts = datetime.datetime.now() + + def update (self, snapshot): + + #update + self.last_update_ts = datetime.datetime.now() + + self.current = snapshot + + if self.ref_point == None: + self.ref_point = self.current + + def clear(self): + self.ref_point = self.current + + + def get(self, field, format=False, suffix=""): + + if not field in self.current: + return "N/A" + + if not format: + return self.current[field] + else: + return format_num(self.current[field], suffix) + + def get_rel (self, field, format=False, suffix=""): + if not field in self.current: + return "N/A" + + if not format: + return (self.current[field] - self.ref_point[field]) + else: + return format_num(self.current[field] - self.ref_point[field], suffix) + + + # return true if new data has arrived in the past 2 seconds + def is_online (self): + delta_ms = (datetime.datetime.now() - self.last_update_ts).total_seconds() * 1000 + return (delta_ms < 2000) + +# describes the general stats provided by TRex +class CTRexAsyncStatsGeneral(CTRexAsyncStats): + def __init__ (self): + super(CTRexAsyncStatsGeneral, self).__init__() + + +# per port stats +class CTRexAsyncStatsPort(CTRexAsyncStats): + def __init__ (self): + super(CTRexAsyncStatsPort, self).__init__() + + def get_stream_stats (self, stream_id): + return None + +# stats manager +class CTRexAsyncStatsManager(): + def __init__ (self): + + self.general_stats = CTRexAsyncStatsGeneral() + self.port_stats = {} + + + def get_general_stats(self): + return self.general_stats + + def get_port_stats (self, port_id): + + if not str(port_id) in self.port_stats: + return None + + return self.port_stats[str(port_id)] + + + def update(self, data): + self.__handle_snapshot(data) + + def __handle_snapshot(self, snapshot): + + general_stats = {} + port_stats = {} + + # filter the values per port and general + for key, value in snapshot.iteritems(): + + # match a pattern of ports + m = re.search('(.*)\-([0-8])', key) + if m: + + port_id = m.group(2) + field_name = m.group(1) + + if not port_id in port_stats: + port_stats[port_id] = {} + + port_stats[port_id][field_name] = value + + else: + # no port match - general stats + general_stats[key] = value + + # update the general object with the snapshot + self.general_stats.update(general_stats) + + # update all ports + for port_id, data in port_stats.iteritems(): + + if not port_id in self.port_stats: + self.port_stats[port_id] = CTRexAsyncStatsPort() + + self.port_stats[port_id].update(data) + + + + + +class CTRexAsyncClient(): + def __init__ (self, server, port, stateless_client): + + self.port = port + self.server = server + + self.stateless_client = stateless_client + + self.event_handler = stateless_client.event_handler + self.logger = self.stateless_client.logger + + self.raw_snapshot = {} + + self.stats = CTRexAsyncStatsManager() + + self.last_data_recv_ts = 0 + self.async_barrier = None + + self.connected = False + + # connects the async channel + def connect (self): + + if self.connected: + self.disconnect() + + self.tr = "tcp://{0}:{1}".format(self.server, self.port) + + # Socket to talk to server + self.context = zmq.Context() + self.socket = self.context.socket(zmq.SUB) + + + # before running the thread - mark as active + self.active = True + self.t = threading.Thread(target = self._run) + + # kill this thread on exit and don't add it to the join list + self.t.setDaemon(True) + self.t.start() + + self.connected = True + + rc = self.barrier() + if not rc: + self.disconnect() + return rc + + return RC_OK() + + + + + # disconnect + def disconnect (self): + if not self.connected: + return + + # signal that the context was destroyed (exit the thread loop) + self.context.term() + + # mark for join and join + self.active = False + self.t.join() + + # done + self.connected = False + + + # thread function + def _run (self): + + # socket must be created on the same thread + self.socket.setsockopt(zmq.SUBSCRIBE, '') + self.socket.setsockopt(zmq.RCVTIMEO, 5000) + self.socket.connect(self.tr) + + got_data = False + + while self.active: + try: + + line = self.socket.recv_string() + self.last_data_recv_ts = time.time() + + # signal once + if not got_data: + self.event_handler.on_async_alive() + got_data = True + + + # got a timeout - mark as not alive and retry + except zmq.Again: + + # signal once + if got_data: + self.event_handler.on_async_dead() + got_data = False + + continue + + except zmq.ContextTerminated: + # outside thread signaled us to exit + break + + msg = json.loads(line) + + name = msg['name'] + data = msg['data'] + type = msg['type'] + self.raw_snapshot[name] = data + + self.__dispatch(name, type, data) + + + # closing of socket must be from the same thread + self.socket.close(linger = 0) + + + # did we get info for the last 3 seconds ? + def is_alive (self): + if self.last_data_recv_ts == None: + return False + + return ( (time.time() - self.last_data_recv_ts) < 3 ) + + def get_stats (self): + return self.stats + + def get_raw_snapshot (self): + return self.raw_snapshot + + # dispatch the message to the right place + def __dispatch (self, name, type, data): + # stats + if name == "trex-global": + self.event_handler.handle_async_stats_update(data) + + # events + elif name == "trex-event": + self.event_handler.handle_async_event(type, data) + + # barriers + elif name == "trex-barrier": + self.handle_async_barrier(type, data) + else: + pass + + + # async barrier handling routine + def handle_async_barrier (self, type, data): + if self.async_barrier['key'] == type: + self.async_barrier['ack'] = True + + + # block on barrier for async channel + def barrier(self, timeout = 5): + + # set a random key + key = random.getrandbits(32) + self.async_barrier = {'key': key, 'ack': False} + + # expr time + expr = time.time() + timeout + + while not self.async_barrier['ack']: + + # inject + rc = self.stateless_client._transmit("publish_now", params = {'key' : key}) + if not rc: + return rc + + # fast loop + for i in xrange(0, 100): + if self.async_barrier['ack']: + break + time.sleep(0.001) + + if time.time() > expr: + return RC_ERR("*** [subscriber] - timeout - no data flow from server at : " + self.tr) + + return RC_OK() + + + diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py new file mode 100644 index 00000000..ed11791b --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py @@ -0,0 +1,2020 @@ +#!/router/bin/python + +# for API usage the path name must be full +from trex_stl_lib.trex_stl_exceptions import * +from trex_stl_lib.trex_stl_streams import * + +from trex_stl_jsonrpc_client import JsonRpcClient, BatchMessage +import trex_stl_stats + +from trex_stl_port import Port +from trex_stl_types import * +from trex_stl_async_client import CTRexAsyncClient + +from utils import parsing_opts, text_tables, common +from utils.text_opts import * + + +from collections import namedtuple +from yaml import YAMLError +import time +import datetime +import re +import random +import json +import traceback + +############################ logger ############################# +############################ ############################# +############################ ############################# + +# logger API for the client +class LoggerApi(object): + # verbose levels + VERBOSE_QUIET = 0 + VERBOSE_REGULAR = 1 + VERBOSE_HIGH = 2 + + def __init__(self): + self.level = LoggerApi.VERBOSE_REGULAR + + # implemented by specific logger + def write(self, msg, newline = True): + raise Exception("implement this") + + # implemented by specific logger + def flush(self): + raise Exception("implement this") + + def set_verbose (self, level): + if not level in xrange(self.VERBOSE_QUIET, self.VERBOSE_HIGH + 1): + raise ValueError("bad value provided for logger") + + self.level = level + + def get_verbose (self): + return self.level + + + def check_verbose (self, level): + return (self.level >= level) + + + # simple log message with verbose + def log (self, msg, level = VERBOSE_REGULAR, newline = True): + if not self.check_verbose(level): + return + + self.write(msg, newline) + + # logging that comes from async event + def async_log (self, msg, level = VERBOSE_REGULAR, newline = True): + self.log(msg, level, newline) + + + def pre_cmd (self, desc): + self.log(format_text('\n{:<60}'.format(desc), 'bold'), newline = False) + self.flush() + + def post_cmd (self, rc): + if rc: + self.log(format_text("[SUCCESS]\n", 'green', 'bold')) + else: + self.log(format_text("[FAILED]\n", 'red', 'bold')) + + + def log_cmd (self, desc): + self.pre_cmd(desc) + self.post_cmd(True) + + + # supress object getter + def supress (self): + class Supress(object): + def __init__ (self, logger): + self.logger = logger + + def __enter__ (self): + self.saved_level = self.logger.get_verbose() + self.logger.set_verbose(LoggerApi.VERBOSE_QUIET) + + def __exit__ (self, type, value, traceback): + self.logger.set_verbose(self.saved_level) + + return Supress(self) + + + +# default logger - to stdout +class DefaultLogger(LoggerApi): + + def __init__ (self): + super(DefaultLogger, self).__init__() + + def write (self, msg, newline = True): + if newline: + print msg + else: + print msg, + + def flush (self): + sys.stdout.flush() + + +############################ async event hander ############################# +############################ ############################# +############################ ############################# + +# handles different async events given to the client +class AsyncEventHandler(object): + + def __init__ (self, client): + self.client = client + self.logger = self.client.logger + + self.events = [] + + # public functions + + def get_events (self): + return self.events + + + def clear_events (self): + self.events = [] + + + def on_async_dead (self): + if self.client.connected: + msg = 'lost connection to server' + self.__add_event_log(msg, 'local', True) + self.client.connected = False + + + def on_async_alive (self): + pass + + + # handles an async stats update from the subscriber + def handle_async_stats_update(self, dump_data): + global_stats = {} + port_stats = {} + + # filter the values per port and general + for key, value in dump_data.iteritems(): + # match a pattern of ports + m = re.search('(.*)\-([0-8])', key) + if m: + port_id = int(m.group(2)) + field_name = m.group(1) + if self.client.ports.has_key(port_id): + if not port_id in port_stats: + port_stats[port_id] = {} + port_stats[port_id][field_name] = value + else: + continue + else: + # no port match - general stats + global_stats[key] = value + + # update the general object with the snapshot + self.client.global_stats.update(global_stats) + + # update all ports + for port_id, data in port_stats.iteritems(): + self.client.ports[port_id].port_stats.update(data) + + + # dispatcher for server async events (port started, port stopped and etc.) + def handle_async_event (self, type, data): + # DP stopped + show_event = False + + # port started + if (type == 0): + port_id = int(data['port_id']) + ev = "Port {0} has started".format(port_id) + self.__async_event_port_started(port_id) + + # port stopped + elif (type == 1): + port_id = int(data['port_id']) + ev = "Port {0} has stopped".format(port_id) + + # call the handler + self.__async_event_port_stopped(port_id) + + + # port paused + elif (type == 2): + port_id = int(data['port_id']) + ev = "Port {0} has paused".format(port_id) + + # call the handler + self.__async_event_port_paused(port_id) + + # port resumed + elif (type == 3): + port_id = int(data['port_id']) + ev = "Port {0} has resumed".format(port_id) + + # call the handler + self.__async_event_port_resumed(port_id) + + # port finished traffic + elif (type == 4): + port_id = int(data['port_id']) + ev = "Port {0} job done".format(port_id) + + # call the handler + self.__async_event_port_stopped(port_id) + show_event = True + + # port was stolen... + elif (type == 5): + session_id = data['session_id'] + + # false alarm, its us + if session_id == self.client.session_id: + return + + port_id = int(data['port_id']) + who = data['who'] + + ev = "Port {0} was forcely taken by '{1}'".format(port_id, who) + + # call the handler + self.__async_event_port_forced_acquired(port_id) + show_event = True + + # server stopped + elif (type == 100): + ev = "Server has stopped" + self.__async_event_server_stopped() + show_event = True + + + else: + # unknown event - ignore + return + + + self.__add_event_log(ev, 'server', show_event) + + + # private functions + + def __async_event_port_stopped (self, port_id): + self.client.ports[port_id].async_event_port_stopped() + + + def __async_event_port_started (self, port_id): + self.client.ports[port_id].async_event_port_started() + + + def __async_event_port_paused (self, port_id): + self.client.ports[port_id].async_event_port_paused() + + + def __async_event_port_resumed (self, port_id): + self.client.ports[port_id].async_event_port_resumed() + + + def __async_event_port_forced_acquired (self, port_id): + self.client.ports[port_id].async_event_forced_acquired() + + + def __async_event_server_stopped (self): + self.client.connected = False + + + # add event to log + def __add_event_log (self, msg, ev_type, show = False): + + if ev_type == "server": + prefix = "[server]" + elif ev_type == "local": + prefix = "[local]" + + ts = time.time() + st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S') + self.events.append("{:<10} - {:^8} - {:}".format(st, prefix, format_text(msg, 'bold'))) + + if show: + self.logger.async_log(format_text("\n\n{:^8} - {:}".format(prefix, format_text(msg, 'bold')))) + + + + + +############################ RPC layer ############################# +############################ ############################# +############################ ############################# + +class CCommLink(object): + """describes the connectivity of the stateless client method""" + def __init__(self, server="localhost", port=5050, virtual=False, prn_func = None): + self.virtual = virtual + self.server = server + self.port = port + self.rpc_link = JsonRpcClient(self.server, self.port, prn_func) + + @property + def is_connected(self): + if not self.virtual: + return self.rpc_link.connected + else: + return True + + def get_server (self): + return self.server + + def get_port (self): + return self.port + + def connect(self): + if not self.virtual: + return self.rpc_link.connect() + + def disconnect(self): + if not self.virtual: + return self.rpc_link.disconnect() + + def transmit(self, method_name, params={}): + if self.virtual: + self._prompt_virtual_tx_msg() + _, msg = self.rpc_link.create_jsonrpc_v2(method_name, params) + print msg + return + else: + return self.rpc_link.invoke_rpc_method(method_name, params) + + def transmit_batch(self, batch_list): + if self.virtual: + self._prompt_virtual_tx_msg() + print [msg + for _, msg in [self.rpc_link.create_jsonrpc_v2(command.method, command.params) + for command in batch_list]] + else: + batch = self.rpc_link.create_batch() + for command in batch_list: + batch.add(command.method, command.params) + # invoke the batch + return batch.invoke() + + def _prompt_virtual_tx_msg(self): + print "Transmitting virtually over tcp://{server}:{port}".format(server=self.server, + port=self.port) + + + +############################ client ############################# +############################ ############################# +############################ ############################# + +class STLClient(object): + """docstring for STLClient""" + + def __init__(self, + username = common.get_current_user(), + server = "localhost", + sync_port = 4501, + async_port = 4500, + verbose_level = LoggerApi.VERBOSE_QUIET, + logger = None, + virtual = False): + + + self.username = username + + # init objects + self.ports = {} + self.server_version = {} + self.system_info = {} + self.session_id = random.getrandbits(32) + self.connected = False + + # logger + self.logger = DefaultLogger() if not logger else logger + + # initial verbose + self.logger.set_verbose(verbose_level) + + # low level RPC layer + self.comm_link = CCommLink(server, + sync_port, + virtual, + self.logger) + + # async event handler manager + self.event_handler = AsyncEventHandler(self) + + # async subscriber level + self.async_client = CTRexAsyncClient(server, + async_port, + self) + + + + + # stats + self.connection_info = {"username": username, + "server": server, + "sync_port": sync_port, + "async_port": async_port, + "virtual": virtual} + + + self.global_stats = trex_stl_stats.CGlobalStats(self.connection_info, + self.server_version, + self.ports) + + self.stats_generator = trex_stl_stats.CTRexInfoGenerator(self.global_stats, + self.ports) + + + + ############# private functions - used by the class itself ########### + + # some preprocessing for port argument + def __ports (self, port_id_list): + + # none means all + if port_id_list == None: + return range(0, self.get_port_count()) + + # always list + if isinstance(port_id_list, int): + port_id_list = [port_id_list] + + if not isinstance(port_id_list, list): + raise ValueError("bad port id list: {0}".format(port_id_list)) + + for port_id in port_id_list: + if not isinstance(port_id, int) or (port_id < 0) or (port_id > self.get_port_count()): + raise ValueError("bad port id {0}".format(port_id)) + + return port_id_list + + + # sync ports + def __sync_ports (self, port_id_list = None, force = False): + port_id_list = self.__ports(port_id_list) + + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].sync()) + + return rc + + # acquire ports, if port_list is none - get all + def __acquire (self, port_id_list = None, force = False): + port_id_list = self.__ports(port_id_list) + + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].acquire(force)) + + return rc + + # release ports + def __release (self, port_id_list = None): + port_id_list = self.__ports(port_id_list) + + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].release()) + + return rc + + + def __add_streams(self, stream_list, port_id_list = None): + + port_id_list = self.__ports(port_id_list) + + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].add_streams(stream_list)) + + return rc + + + + def __remove_streams(self, stream_id_list, port_id_list = None): + + port_id_list = self.__ports(port_id_list) + + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].remove_streams(stream_id_list)) + + return rc + + + + def __remove_all_streams(self, port_id_list = None): + port_id_list = self.__ports(port_id_list) + + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].remove_all_streams()) + + return rc + + + def __get_stream(self, stream_id, port_id, get_pkt = False): + + return self.ports[port_id].get_stream(stream_id) + + + def __get_all_streams(self, port_id, get_pkt = False): + + return self.ports[port_id].get_all_streams() + + + def __get_stream_id_list(self, port_id): + + return self.ports[port_id].get_stream_id_list() + + + def __start (self, multiplier, duration, port_id_list = None, force = False): + + port_id_list = self.__ports(port_id_list) + + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].start(multiplier, duration, force)) + + return rc + + + def __resume (self, port_id_list = None, force = False): + + port_id_list = self.__ports(port_id_list) + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].resume()) + + return rc + + def __pause (self, port_id_list = None, force = False): + + port_id_list = self.__ports(port_id_list) + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].pause()) + + return rc + + + def __stop (self, port_id_list = None, force = False): + + port_id_list = self.__ports(port_id_list) + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].stop(force)) + + return rc + + + def __update (self, mult, port_id_list = None, force = False): + + port_id_list = self.__ports(port_id_list) + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].update(mult, force)) + + return rc + + + def __validate (self, port_id_list = None): + port_id_list = self.__ports(port_id_list) + + rc = RC() + + for port_id in port_id_list: + rc.add(self.ports[port_id].validate()) + + return rc + + + + # connect to server + def __connect(self): + + # first disconnect if already connected + if self.is_connected(): + self.__disconnect() + + # clear this flag + self.connected = False + + # connect sync channel + self.logger.pre_cmd("Connecting to RPC server on {0}:{1}".format(self.connection_info['server'], self.connection_info['sync_port'])) + rc = self.comm_link.connect() + self.logger.post_cmd(rc) + + if not rc: + return rc + + # version + rc = self._transmit("get_version") + if not rc: + return rc + + + self.server_version = rc.data() + self.global_stats.server_version = rc.data() + + # cache system info + rc = self._transmit("get_system_info") + if not rc: + return rc + + self.system_info = rc.data() + + # cache supported commands + rc = self._transmit("get_supported_cmds") + if not rc: + return rc + + self.supported_cmds = rc.data() + + # create ports + for port_id in xrange(self.system_info["port_count"]): + speed = self.system_info['ports'][port_id]['speed'] + driver = self.system_info['ports'][port_id]['driver'] + + self.ports[port_id] = Port(port_id, + speed, + driver, + self.username, + self.comm_link, + self.session_id) + + + # sync the ports + rc = self.__sync_ports() + if not rc: + return rc + + + # connect async channel + self.logger.pre_cmd("connecting to publisher server on {0}:{1}".format(self.connection_info['server'], self.connection_info['async_port'])) + rc = self.async_client.connect() + self.logger.post_cmd(rc) + + if not rc: + return rc + + self.connected = True + + return RC_OK() + + + # disconenct from server + def __disconnect(self, release_ports = True): + # release any previous acquired ports + if self.is_connected() and release_ports: + self.__release(self.get_acquired_ports()) + + self.comm_link.disconnect() + self.async_client.disconnect() + + self.connected = False + + return RC_OK() + + + # clear stats + def __clear_stats(self, port_id_list, clear_global): + + for port_id in port_id_list: + self.ports[port_id].clear_stats() + + if clear_global: + self.global_stats.clear_stats() + + self.logger.log_cmd("clearing stats on port(s) {0}:".format(port_id_list)) + + return RC + + + # get stats + def __get_stats (self, port_id_list): + stats = {} + + stats['global'] = self.global_stats.get_stats() + + total = {} + for port_id in port_id_list: + port_stats = self.ports[port_id].get_stats() + stats[port_id] = port_stats + + for k, v in port_stats.iteritems(): + if not k in total: + total[k] = v + else: + total[k] += v + + stats['total'] = total + + return stats + + + ############ functions used by other classes but not users ############## + + def _verify_port_id_list (self, port_id_list): + # check arguments + if not isinstance(port_id_list, list): + return RC_ERR("ports should be an instance of 'list' not {0}".format(type(port_id_list))) + + # all ports are valid ports + if not port_id_list or not all([port_id in self.get_all_ports() for port_id in port_id_list]): + return RC_ERR("") + + return RC_OK() + + def _validate_port_list(self, port_id_list): + if not isinstance(port_id_list, list): + return False + + # check each item of the sequence + return (port_id_list and all([port_id in self.get_all_ports() for port_id in port_id_list])) + + + + # transmit request on the RPC link + def _transmit(self, method_name, params={}): + return self.comm_link.transmit(method_name, params) + + # transmit batch request on the RPC link + def _transmit_batch(self, batch_list): + return self.comm_link.transmit_batch(batch_list) + + # stats + def _get_formatted_stats(self, port_id_list, stats_mask = trex_stl_stats.COMPACT): + stats_opts = trex_stl_stats.ALL_STATS_OPTS.intersection(stats_mask) + + stats_obj = {} + for stats_type in stats_opts: + stats_obj.update(self.stats_generator.generate_single_statistic(port_id_list, stats_type)) + + return stats_obj + + def _get_streams(self, port_id_list, streams_mask=set()): + + streams_obj = self.stats_generator.generate_streams_info(port_id_list, streams_mask) + + return streams_obj + + + def _invalidate_stats (self, port_id_list): + for port_id in port_id_list: + self.ports[port_id].invalidate_stats() + + self.global_stats.invalidate() + + return RC_OK() + + + + + + ################################# + # ------ private methods ------ # + @staticmethod + def __get_mask_keys(ok_values={True}, **kwargs): + masked_keys = set() + for key, val in kwargs.iteritems(): + if val in ok_values: + masked_keys.add(key) + return masked_keys + + @staticmethod + def __filter_namespace_args(namespace, ok_values): + return {k: v for k, v in namespace.__dict__.items() if k in ok_values} + + + # API decorator - double wrap because of argument + def __api_check(connected = True): + + def wrap (f): + def wrap2(*args, **kwargs): + client = args[0] + + func_name = f.__name__ + + # check connection + if connected and not client.is_connected(): + raise STLStateError(func_name, 'disconnected') + + ret = f(*args, **kwargs) + return ret + return wrap2 + + return wrap + + + + ############################ API ############################# + ############################ ############################# + ############################ ############################# + def __enter__ (self): + self.connect() + self.acquire(force = True) + self.reset() + return self + + def __exit__ (self, type, value, traceback): + if self.get_active_ports(): + self.stop(self.get_active_ports()) + self.disconnect() + + ############################ Getters ############################# + ############################ ############################# + ############################ ############################# + + + # return verbose level of the logger + def get_verbose (self): + return self.logger.get_verbose() + + # is the client on read only mode ? + def is_all_ports_acquired (self): + return not (self.get_all_ports() == self.get_acquired_ports()) + + # is the client connected ? + def is_connected (self): + return self.connected and self.comm_link.is_connected + + + # get connection info + def get_connection_info (self): + return self.connection_info + + + # get supported commands by the server + def get_server_supported_cmds(self): + return self.supported_cmds + + # get server version + def get_server_version(self): + return self.server_version + + # get server system info + def get_server_system_info(self): + return self.system_info + + # get port count + def get_port_count(self): + return len(self.ports) + + + # returns the port object + def get_port (self, port_id): + port = self.ports.get(port_id, None) + if (port != None): + return port + else: + raise STLArgumentError('port id', port_id, valid_values = self.get_all_ports()) + + + # get all ports as IDs + def get_all_ports (self): + return self.ports.keys() + + # get all acquired ports + def get_acquired_ports(self): + return [port_id + for port_id, port_obj in self.ports.iteritems() + if port_obj.is_acquired()] + + # get all active ports (TX or pause) + def get_active_ports(self): + return [port_id + for port_id, port_obj in self.ports.iteritems() + if port_obj.is_active()] + + # get paused ports + def get_paused_ports (self): + return [port_id + for port_id, port_obj in self.ports.iteritems() + if port_obj.is_paused()] + + # get all TX ports + def get_transmitting_ports (self): + return [port_id + for port_id, port_obj in self.ports.iteritems() + if port_obj.is_transmitting()] + + + # get stats + def get_stats (self, ports = None, async_barrier = True): + # by default use all ports + if ports == None: + ports = self.get_acquired_ports() + else: + ports = self.__ports(ports) + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + # check async barrier + if not type(async_barrier) is bool: + raise STLArgumentError('async_barrier', async_barrier) + + + # if the user requested a barrier - use it + if async_barrier: + rc = self.async_client.barrier() + if not rc: + raise STLError(rc) + + return self.__get_stats(ports) + + # return all async events + def get_events (self): + return self.event_handler.get_events() + + ############################ Commands ############################# + ############################ ############################# + ############################ ############################# + + + """ + Sets verbose level + + :parameters: + level : str + "high" + "low" + "normal" + + :raises: + None + + """ + def set_verbose (self, level): + modes = {'low' : LoggerApi.VERBOSE_QUIET, 'normal': LoggerApi.VERBOSE_REGULAR, 'high': LoggerApi.VERBOSE_HIGH} + + if not level in modes.keys(): + raise STLArgumentError('level', level) + + self.logger.set_verbose(modes[level]) + + + """ + Connects to the TRex server + + :parameters: + None + + :raises: + + :exc:`STLError` + + """ + @__api_check(False) + def connect (self): + rc = self.__connect() + if not rc: + raise STLError(rc) + + + """ + Disconnects from the server + + :parameters: + stop_traffic : bool + tries to stop traffic before disconnecting + release_ports : bool + tries to release all the acquired ports + + """ + @__api_check(False) + def disconnect (self, stop_traffic = True, release_ports = True): + + # try to stop ports but do nothing if not possible + if stop_traffic: + try: + self.stop() + except STLError: + pass + + + self.logger.pre_cmd("Disconnecting from server at '{0}':'{1}'".format(self.connection_info['server'], + self.connection_info['sync_port'])) + rc = self.__disconnect(release_ports) + self.logger.post_cmd(rc) + + + + """ + Acquires ports for executing commands + + :parameters: + ports : list + ports to execute the command + force : bool + force acquire the ports + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def acquire (self, ports = None, force = False): + # by default use all ports + if ports == None: + ports = self.get_all_ports() + + # verify ports + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + # verify valid port id list + if force: + self.logger.pre_cmd("Force acquiring ports {0}:".format(ports)) + else: + self.logger.pre_cmd("Acquiring ports {0}:".format(ports)) + + rc = self.__acquire(ports, force) + + self.logger.post_cmd(rc) + + if not rc: + # cleanup + self.__release(ports) + raise STLError(rc) + + + """ + Release ports + + :parameters: + ports : list + ports to execute the command + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def release (self, ports = None): + # by default use all acquired ports + if ports == None: + ports = self.get_acquired_ports() + + # verify ports + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + self.logger.pre_cmd("Releasing ports {0}:".format(ports)) + rc = self.__release(ports) + self.logger.post_cmd(rc) + + if not rc: + raise STLError(rc) + + """ + Pings the server + + :parameters: + None + + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def ping(self): + self.logger.pre_cmd( "Pinging the server on '{0}' port '{1}': ".format(self.connection_info['server'], + self.connection_info['sync_port'])) + rc = self._transmit("ping") + + self.logger.post_cmd(rc) + + if not rc: + raise STLError(rc) + + + + """ + force acquire ports, stop the traffic, remove all streams and clear stats + + :parameters: + ports : list + ports to execute the command + + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def reset(self, ports = None): + + # by default use all ports + if ports == None: + ports = self.get_all_ports() + + # verify ports + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + self.acquire(ports, force = True) + self.stop(ports) + self.remove_all_streams(ports) + self.clear_stats(ports) + + + """ + remove all streams from port(s) + + :parameters: + ports : list + ports to execute the command + + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def remove_all_streams (self, ports = None): + + # by default use all ports + if ports == None: + ports = self.get_acquired_ports() + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + self.logger.pre_cmd("Removing all streams from port(s) {0}:".format(ports)) + rc = self.__remove_all_streams(ports) + self.logger.post_cmd(rc) + + if not rc: + raise STLError(rc) + + + """ + add a list of streams to port(s) + + :parameters: + ports : list + ports to execute the command + streams: list + streams to attach + + :returns: + list of stream IDs in order of the stream list + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def add_streams (self, streams, ports = None): + # by default use all ports + if ports == None: + ports = self.get_acquired_ports() + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + # transform single stream + if not isinstance(streams, list): + streams = [streams] + + # check streams + if not all([isinstance(stream, STLStream) for stream in streams]): + raise STLArgumentError('streams', streams) + + self.logger.pre_cmd("Attaching {0} streams to port(s) {1}:".format(len(streams), ports)) + rc = self.__add_streams(streams, ports) + self.logger.post_cmd(rc) + + if not rc: + raise STLError(rc) + + return [stream.get_id() for stream in streams] + + + """ + remove a list of streams from ports + + :parameters: + ports : list + ports to execute the command + stream_id_list: list + stream id list to remove + + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def remove_streams (self, stream_id_list, ports = None): + # by default use all ports + if ports == None: + ports = self.get_acquired_ports() + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + # transform single stream + if not isinstance(stream_id_list, list): + stream_id_list = [stream_id_list] + + # check streams + if not all([isinstance(stream_id, long) for stream_id in stream_id_list]): + raise STLArgumentError('stream_id_list', stream_id_list) + + # remove streams + self.logger.pre_cmd("Removing {0} streams from port(s) {1}:".format(len(stream_id_list), ports)) + rc = self.__remove_streams(stream_id_list, ports) + self.logger.post_cmd(rc) + + if not rc: + raise STLError(rc) + + + """ + load a profile from file + + :parameters: + filename : str + filename to load + + :returns: + list of streams from the profile + + :raises: + + :exc:`STLError` + + """ + @staticmethod + def load_profile (filename): + + # check filename + if not os.path.isfile(filename): + raise STLError("file '{0}' does not exists".format(filename)) + + streams = None + + # try YAML + try: + streams = STLStream.load_from_yaml(filename) + print "***** YAML IS NOT WORKING !!! *********" + + + except YAMLError: + # try python loader + try: + basedir = os.path.dirname(filename) + + sys.path.append(basedir) + file = os.path.basename(filename).split('.')[0] + module = __import__(file, globals(), locals(), [], -1) + reload(module) # reload the update + + streams = module.register().get_streams() + + except Exception as e : + print str(e); + traceback.print_exc(file=sys.stdout) + raise STLError("Unexpected error: '{0}'".format(filename)) + + return streams + + + + + """ + start traffic on port(s) + + :parameters: + ports : list + ports to execute command + + mult : str + multiplier in a form of pps, bps, or line util in % + examples: "5kpps", "10gbps", "85%", "32mbps" + + force : bool + imply stopping the port of active and also + forces a profile that exceeds the L1 BW + + duration : int + limit the run for time in seconds + -1 means unlimited + + total : bool + should the B/W be divided by the ports + or duplicated for each + + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def start (self, + ports = None, + mult = "1", + force = False, + duration = -1, + total = False): + + + # by default use all ports + if ports == None: + ports = self.get_acquired_ports() + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + # verify multiplier + mult_obj = parsing_opts.decode_multiplier(mult, + allow_update = False, + divide_count = len(ports) if total else 1) + if not mult_obj: + raise STLArgumentError('mult', mult) + + # some type checkings + + if not type(force) is bool: + raise STLArgumentError('force', force) + + if not isinstance(duration, (int, float)): + raise STLArgumentError('duration', duration) + + if not type(total) is bool: + raise STLArgumentError('total', total) + + + # verify ports are stopped or force stop them + active_ports = list(set(self.get_active_ports()).intersection(ports)) + if active_ports: + if not force: + raise STLError("Port(s) {0} are active - please stop them or specify 'force'".format(active_ports)) + else: + rc = self.stop(active_ports) + if not rc: + raise STLError(rc) + + + # start traffic + self.logger.pre_cmd("Starting traffic on port(s) {0}:".format(ports)) + rc = self.__start(mult_obj, duration, ports, force) + self.logger.post_cmd(rc) + + if not rc: + raise STLError(rc) + + + + + """ + stop port(s) + + :parameters: + ports : list + ports to execute the command + + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def stop (self, ports = None): + + # by default the user means all the active ports + if ports == None: + ports = self.get_active_ports() + if not ports: + return + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + self.logger.pre_cmd("Stopping traffic on port(s) {0}:".format(ports)) + rc = self.__stop(ports) + self.logger.post_cmd(rc) + + if not rc: + raise STLError(rc) + + + + """ + update traffic on port(s) + + :parameters: + ports : list + ports to execute command + + mult : str + multiplier in a form of pps, bps, or line util in % + and also with +/- + examples: "5kpps+", "10gbps-", "85%", "32mbps", "20%+" + + force : bool + forces a profile that exceeds the L1 BW + + total : bool + should the B/W be divided by the ports + or duplicated for each + + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def update (self, ports = None, mult = "1", total = False, force = False): + + # by default the user means all the active ports + if ports == None: + ports = self.get_active_ports() + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + # verify multiplier + mult_obj = parsing_opts.decode_multiplier(mult, + allow_update = True, + divide_count = len(ports) if total else 1) + if not mult_obj: + raise STLArgumentError('mult', mult) + + # verify total + if not type(total) is bool: + raise STLArgumentError('total', total) + + + # call low level functions + self.logger.pre_cmd("Updating traffic on port(s) {0}:".format(ports)) + rc = self.__update(mult, ports, force) + self.logger.post_cmd(rc) + + if not rc: + raise STLError(rc) + + + + """ + pause traffic on port(s) + + :parameters: + ports : list + ports to execute command + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def pause (self, ports = None): + + # by default the user means all the TX ports + if ports == None: + ports = self.get_transmitting_ports() + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + self.logger.pre_cmd("Pausing traffic on port(s) {0}:".format(ports)) + rc = self.__pause(ports) + self.logger.post_cmd(rc) + + if not rc: + raise STLError(rc) + + + + """ + resume traffic on port(s) + + :parameters: + ports : list + ports to execute command + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def resume (self, ports = None): + + # by default the user means all the paused ports + if ports == None: + ports = self.get_paused_ports() + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + self.logger.pre_cmd("Resume traffic on port(s) {0}:".format(ports)) + rc = self.__resume(ports) + self.logger.post_cmd(rc) + + if not rc: + raise STLError(rc) + + + """ + validate port(s) configuration + + :parameters: + ports : list + ports to execute command + + mult : str + multiplier in a form of pps, bps, or line util in % + examples: "5kpps", "10gbps", "85%", "32mbps" + + duration : int + limit the run for time in seconds + -1 means unlimited + + total : bool + should the B/W be divided by the ports + or duplicated for each + + :raises: + + :exc:`STLError` + + """ + @__api_check(True) + def validate (self, ports = None, mult = "1", duration = "-1", total = False): + if ports == None: + ports = self.get_acquired_ports() + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + # verify multiplier + mult_obj = parsing_opts.decode_multiplier(mult, + allow_update = True, + divide_count = len(ports) if total else 1) + if not mult_obj: + raise STLArgumentError('mult', mult) + + + if not isinstance(duration, (int, float)): + raise STLArgumentError('duration', duration) + + + self.logger.pre_cmd("Validating streams on port(s) {0}:".format(ports)) + rc = self.__validate(ports) + self.logger.post_cmd(rc) + + + for port in ports: + self.ports[port].print_profile(mult_obj, duration) + + + """ + clear stats on port(s) + + :parameters: + ports : list + ports to execute command + + clear_global : bool + clear the global stats + + :raises: + + :exc:`STLError` + + """ + @__api_check(False) + def clear_stats (self, ports = None, clear_global = True): + + # by default use all ports + if ports == None: + ports = self.get_all_ports() + else: + ports = self.__ports(ports) + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + # verify clear global + if not type(clear_global) is bool: + raise STLArgumentError('clear_global', clear_global) + + + rc = self.__clear_stats(ports, clear_global) + if not rc: + raise STLError(rc) + + + + + + """ + block until specify port(s) traffic has ended + + :parameters: + ports : list + ports to execute command + + timeout : int + timeout in seconds + + :raises: + + :exc:`STLTimeoutError` - in case timeout has expired + + :exe:'STLError' + + """ + @__api_check(True) + def wait_on_traffic (self, ports = None, timeout = 60): + + # by default use all acquired ports + if ports == None: + ports = self.get_acquired_ports() + + # verify valid port id list + rc = self._validate_port_list(ports) + if not rc: + raise STLArgumentError('ports', ports, valid_values = self.get_all_ports()) + + expr = time.time() + timeout + + # wait while any of the required ports are active + while set(self.get_active_ports()).intersection(ports): + time.sleep(0.01) + if time.time() > expr: + raise STLTimeoutError(timeout) + + + """ + clear all events + + :parameters: + None + + :raises: + None + + """ + def clear_events (self): + self.event_handler.clear_events() + + + ############################ Line ############################# + ############################ Commands ############################# + ############################ ############################# + + # console decorator + def __console(f): + def wrap(*args): + client = args[0] + + time1 = time.time() + + try: + rc = f(*args) + except STLError as e: + client.logger.log("Log:\n" + format_text(e.brief() + "\n", 'bold')) + return + + # if got true - print time + if rc: + delta = time.time() - time1 + client.logger.log(format_time(delta) + "\n") + + + return wrap + + + @__console + def connect_line (self, line): + '''Connects to the TRex server''' + # define a parser + parser = parsing_opts.gen_parser(self, + "connect", + self.connect_line.__doc__, + parsing_opts.FORCE) + + opts = parser.parse_args(line.split()) + + if opts is None: + return + + # call the API + self.connect() + self.acquire(force = opts.force) + + # true means print time + return True + + @__console + def disconnect_line (self, line): + self.disconnect() + + + + @__console + def reset_line (self, line): + self.reset() + + # true means print time + return True + + + @__console + def start_line (self, line): + '''Start selected traffic in specified ports on TRex\n''' + # define a parser + parser = parsing_opts.gen_parser(self, + "start", + self.start_line.__doc__, + parsing_opts.PORT_LIST_WITH_ALL, + parsing_opts.TOTAL, + parsing_opts.FORCE, + parsing_opts.STREAM_FROM_PATH_OR_FILE, + parsing_opts.DURATION, + parsing_opts.MULTIPLIER_STRICT, + parsing_opts.DRY_RUN) + + opts = parser.parse_args(line.split()) + + + if opts is None: + return + + + active_ports = list(set(self.get_active_ports()).intersection(opts.ports)) + + if active_ports: + if not opts.force: + msg = "Port(s) {0} are active - please stop them or add '--force'\n".format(active_ports) + self.logger.log(format_text(msg, 'bold')) + return + else: + self.stop(active_ports) + + + # remove all streams + self.remove_all_streams(opts.ports) + + # pack the profile + streams = self.load_profile(opts.file[0]) + self.add_streams(streams, ports = opts.ports) + + if opts.dry: + self.validate(opts.ports, opts.mult, opts.duration, opts.total) + else: + self.start(opts.ports, + opts.mult, + opts.force, + opts.duration, + opts.total) + + # true means print time + return True + + + + @__console + def stop_line (self, line): + '''Stop active traffic in specified ports on TRex\n''' + parser = parsing_opts.gen_parser(self, + "stop", + self.stop_line.__doc__, + parsing_opts.PORT_LIST_WITH_ALL) + + opts = parser.parse_args(line.split()) + if opts is None: + return + + # find the relevant ports + ports = list(set(self.get_active_ports()).intersection(opts.ports)) + + if not ports: + self.logger.log(format_text("No active traffic on provided ports\n", 'bold')) + return + + self.stop(ports) + + # true means print time + return True + + + @__console + def update_line (self, line): + '''Update port(s) speed currently active\n''' + parser = parsing_opts.gen_parser(self, + "update", + self.update_line.__doc__, + parsing_opts.PORT_LIST_WITH_ALL, + parsing_opts.MULTIPLIER, + parsing_opts.TOTAL, + parsing_opts.FORCE) + + opts = parser.parse_args(line.split()) + if opts is None: + return + + # find the relevant ports + ports = list(set(self.get_active_ports()).intersection(opts.ports)) + + if not ports: + self.logger.log(format_text("No ports in valid state to update\n", 'bold')) + return + + self.update(ports, opts.mult, opts.total, opts.force) + + # true means print time + return True + + + @__console + def pause_line (self, line): + '''Pause active traffic in specified ports on TRex\n''' + parser = parsing_opts.gen_parser(self, + "pause", + self.pause_line.__doc__, + parsing_opts.PORT_LIST_WITH_ALL) + + opts = parser.parse_args(line.split()) + if opts is None: + return + + # find the relevant ports + ports = list(set(self.get_transmitting_ports()).intersection(opts.ports)) + + if not ports: + self.logger.log(format_text("No ports in valid state to pause\n", 'bold')) + return + + self.pause(ports) + + # true means print time + return True + + + @__console + def resume_line (self, line): + '''Resume active traffic in specified ports on TRex\n''' + parser = parsing_opts.gen_parser(self, + "resume", + self.resume_line.__doc__, + parsing_opts.PORT_LIST_WITH_ALL) + + opts = parser.parse_args(line.split()) + if opts is None: + return + + # find the relevant ports + ports = list(set(self.get_paused_ports()).intersection(opts.ports)) + + if not ports: + self.logger.log(format_text("No ports in valid state to resume\n", 'bold')) + return + + return self.resume(ports) + + # true means print time + return True + + + @__console + def clear_stats_line (self, line): + '''Clear cached local statistics\n''' + # define a parser + parser = parsing_opts.gen_parser(self, + "clear", + self.clear_stats_line.__doc__, + parsing_opts.PORT_LIST_WITH_ALL) + + opts = parser.parse_args(line.split()) + + if opts is None: + return + + self.clear_stats(opts.ports) + + + + + @__console + def show_stats_line (self, line): + '''Fetch statistics from TRex server by port\n''' + # define a parser + parser = parsing_opts.gen_parser(self, + "stats", + self.show_stats_line.__doc__, + parsing_opts.PORT_LIST_WITH_ALL, + parsing_opts.STATS_MASK) + + opts = parser.parse_args(line.split()) + + if opts is None: + return + + # determine stats mask + mask = self.__get_mask_keys(**self.__filter_namespace_args(opts, trex_stl_stats.ALL_STATS_OPTS)) + if not mask: + # set to show all stats if no filter was given + mask = trex_stl_stats.ALL_STATS_OPTS + + stats_opts = trex_stl_stats.ALL_STATS_OPTS.intersection(mask) + + stats = self._get_formatted_stats(opts.ports, mask) + + + # print stats to screen + for stat_type, stat_data in stats.iteritems(): + text_tables.print_table_with_header(stat_data.text_table, stat_type) + + + @__console + def show_streams_line(self, line): + '''Fetch streams statistics from TRex server by port\n''' + # define a parser + parser = parsing_opts.gen_parser(self, + "streams", + self.show_streams_line.__doc__, + parsing_opts.PORT_LIST_WITH_ALL, + parsing_opts.STREAMS_MASK) + + opts = parser.parse_args(line.split()) + + if opts is None: + return + + streams = self._get_streams(opts.ports, set(opts.streams)) + if not streams: + self.logger.log(format_text("No streams found with desired filter.\n", "bold", "magenta")) + + else: + # print stats to screen + for stream_hdr, port_streams_data in streams.iteritems(): + text_tables.print_table_with_header(port_streams_data.text_table, + header= stream_hdr.split(":")[0] + ":", + untouched_header= stream_hdr.split(":")[1]) + + + + + @__console + def validate_line (self, line): + '''validates port(s) stream configuration\n''' + + parser = parsing_opts.gen_parser(self, + "validate", + self.validate_line.__doc__, + parsing_opts.PORT_LIST_WITH_ALL) + + opts = parser.parse_args(line.split()) + if opts is None: + return + + self.validate(opts.ports) + + + +
\ No newline at end of file diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_exceptions.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_exceptions.py new file mode 100644 index 00000000..45acc72e --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_exceptions.py @@ -0,0 +1,54 @@ +import os +import sys + +from utils.text_opts import * + +# basic error for API +class STLError(Exception): + def __init__ (self, msg): + self.msg = str(msg) + + def __str__ (self): + exc_type, exc_obj, exc_tb = sys.exc_info() + fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1] + + + s = "\n******\n" + s += "Error at {0}:{1}\n\n".format(format_text(fname, 'bold'), format_text(exc_tb.tb_lineno), 'bold') + s += "specific error:\n\n{0}\n".format(format_text(self.msg, 'bold')) + + return s + + def brief (self): + return self.msg + + +# raised when the client state is invalid for operation +class STLStateError(STLError): + def __init__ (self, op, state): + self.msg = "Operation '{0}' is not valid while '{1}'".format(op, state) + + +# port state error +class STLPortStateError(STLError): + def __init__ (self, port, op, state): + self.msg = "Operation '{0}' on port(s) '{1}' is not valid while port(s) '{2}'".format(op, port, state) + + +# raised when argument is not valid for operation +class STLArgumentError(STLError): + def __init__ (self, name, got, valid_values = None, extended = None): + self.msg = "Argument: '{0}' invalid value: '{1}'".format(name, got) + if valid_values: + self.msg += " - valid values are '{0}'".format(valid_values) + + if extended: + self.msg += "\n{0}".format(extended) + +# raised when timeout occurs +class STLTimeoutError(STLError): + def __init__ (self, timeout): + self.msg = "Timeout: operation took more than '{0}' seconds".format(timeout) + + + diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_ext.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_ext.py new file mode 100644 index 00000000..1092679a --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_ext.py @@ -0,0 +1,74 @@ +import sys +import os +import warnings + +# if not set - set it to default +if not 'TREX_STL_EXT_PATH' in globals(): + CURRENT_PATH = os.path.dirname(os.path.realpath(__file__)) + # ../../../../external_libs + TREX_STL_EXT_PATH = os.path.abspath(os.path.join(CURRENT_PATH, os.pardir, os.pardir, os.pardir, os.pardir, 'external_libs')) + +# check path exists +if not os.path.exists(TREX_STL_EXT_PATH): + print "Unable to find external packages path: '{0}'".format(TREX_STL_EXT_PATH) + print "Please provide the correct path using TREX_STL_EXT_PATH variable" + exit(0) + +# the modules required +CLIENT_UTILS_MODULES = ['dpkt-1.8.6', + 'yaml-3.11', + 'texttable-0.8.4', + 'scapy-2.3.1' + ] + +def import_client_utils_modules(): + import_module_list(CLIENT_UTILS_MODULES) + + +def import_module_list(modules_list): + assert(isinstance(modules_list, list)) + for p in modules_list: + full_path = os.path.join(TREX_STL_EXT_PATH, p) + fix_path = os.path.normcase(full_path) + sys.path.insert(1, full_path) + + + import_platform_dirs() + + + +def import_platform_dirs (): + # handle platform dirs + + # try fedora 18 first and then cel5.9 + # we are using the ZMQ module to determine the right platform + + full_path = os.path.join(TREX_STL_EXT_PATH, 'platform/fedora18') + fix_path = os.path.normcase(full_path) + sys.path.insert(0, full_path) + try: + # try to import and delete it from the namespace + import zmq + del zmq + return + except: + sys.path.pop(0) + pass + + full_path = os.path.join(TREX_STL_EXT_PATH, 'platform/cel59') + fix_path = os.path.normcase(full_path) + sys.path.insert(0, full_path) + try: + # try to import and delete it from the namespace + import zmq + del zmq + return + + except: + sys.path.pop(0) + sys.modules['zmq'] = None + warnings.warn("unable to determine platform type for ZMQ import") + + +import_client_utils_modules() + diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_jsonrpc_client.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_jsonrpc_client.py new file mode 100644 index 00000000..ab3c7282 --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_jsonrpc_client.py @@ -0,0 +1,244 @@ +#!/router/bin/python + +import zmq +import json +import re +from time import sleep +from collections import namedtuple +from trex_stl_types import * +from utils.common import random_id_gen + +class bcolors: + BLUE = '\033[94m' + GREEN = '\033[32m' + YELLOW = '\033[93m' + RED = '\033[31m' + MAGENTA = '\033[35m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + +# sub class to describe a batch +class BatchMessage(object): + def __init__ (self, rpc_client): + self.rpc_client = rpc_client + self.batch_list = [] + + def add (self, method_name, params={}): + + id, msg = self.rpc_client.create_jsonrpc_v2(method_name, params, encode = False) + self.batch_list.append(msg) + + def invoke(self, block = False): + if not self.rpc_client.connected: + return RC_ERR("Not connected to server") + + msg = json.dumps(self.batch_list) + + return self.rpc_client.send_raw_msg(msg) + + +# JSON RPC v2.0 client +class JsonRpcClient(object): + + def __init__ (self, default_server, default_port, logger): + self.logger = logger + self.connected = False + + # default values + self.port = default_port + self.server = default_server + + self.id_gen = random_id_gen() + + + def get_connection_details (self): + rc = {} + rc['server'] = self.server + rc['port'] = self.port + + return rc + + # pretty print for JSON + def pretty_json (self, json_str, use_colors = True): + pretty_str = json.dumps(json.loads(json_str), indent = 4, separators=(',', ': '), sort_keys = True) + + if not use_colors: + return pretty_str + + try: + # int numbers + pretty_str = re.sub(r'([ ]*:[ ]+)(\-?[1-9][0-9]*[^.])',r'\1{0}\2{1}'.format(bcolors.BLUE, bcolors.ENDC), pretty_str) + # float + pretty_str = re.sub(r'([ ]*:[ ]+)(\-?[1-9][0-9]*\.[0-9]+)',r'\1{0}\2{1}'.format(bcolors.MAGENTA, bcolors.ENDC), pretty_str) + # strings + + pretty_str = re.sub(r'([ ]*:[ ]+)("[^"]*")',r'\1{0}\2{1}'.format(bcolors.RED, bcolors.ENDC), pretty_str) + pretty_str = re.sub(r"('[^']*')", r'{0}\1{1}'.format(bcolors.MAGENTA, bcolors.RED), pretty_str) + except : + pass + + return pretty_str + + def verbose_msg (self, msg): + self.logger.log("\n\n[verbose] " + msg, level = self.logger.VERBOSE_HIGH) + + + # batch messages + def create_batch (self): + return BatchMessage(self) + + def create_jsonrpc_v2 (self, method_name, params = {}, encode = True): + msg = {} + msg["jsonrpc"] = "2.0" + msg["method"] = method_name + + msg["params"] = params + + msg["id"] = self.id_gen.next() + + if encode: + return id, json.dumps(msg) + else: + return id, msg + + + def invoke_rpc_method (self, method_name, params = {}): + if not self.connected: + return RC_ERR("Not connected to server") + + id, msg = self.create_jsonrpc_v2(method_name, params) + + return self.send_raw_msg(msg) + + + # low level send of string message + def send_raw_msg (self, msg): + + self.verbose_msg("Sending Request To Server:\n\n" + self.pretty_json(msg) + "\n") + + tries = 0 + while True: + try: + self.socket.send(msg) + break + except zmq.Again: + tries += 1 + if tries > 5: + self.disconnect() + return RC_ERR("*** [RPC] - Failed to send message to server") + + + tries = 0 + while True: + try: + response = self.socket.recv() + break + except zmq.Again: + tries += 1 + if tries > 5: + self.disconnect() + return RC_ERR("*** [RPC] - Failed to get server response at {0}".format(self.transport)) + + + self.verbose_msg("Server Response:\n\n" + self.pretty_json(response) + "\n") + + # decode + + # batch ? + response_json = json.loads(response) + + if isinstance(response_json, list): + rc_batch = RC() + + for single_response in response_json: + rc = self.process_single_response(single_response) + rc_batch.add(rc) + + return rc_batch + + else: + return self.process_single_response(response_json) + + + def process_single_response (self, response_json): + + if (response_json.get("jsonrpc") != "2.0"): + return RC_ERR("Malformed Response ({0})".format(str(response_json))) + + # error reported by server + if ("error" in response_json): + if "specific_err" in response_json["error"]: + return RC_ERR(response_json["error"]["specific_err"]) + else: + return RC_ERR(response_json["error"]["message"]) + + + # if no error there should be a result + if ("result" not in response_json): + return RC_ERR("Malformed Response ({0})".format(str(response_json))) + + return RC_OK(response_json["result"]) + + + + def disconnect (self): + if self.connected: + self.socket.close(linger = 0) + self.context.destroy(linger = 0) + self.connected = False + return RC_OK() + else: + return RC_ERR("Not connected to server") + + + def connect(self, server = None, port = None): + if self.connected: + self.disconnect() + + self.context = zmq.Context() + + self.server = (server if server else self.server) + self.port = (port if port else self.port) + + # Socket to talk to server + self.transport = "tcp://{0}:{1}".format(self.server, self.port) + + self.socket = self.context.socket(zmq.REQ) + try: + self.socket.connect(self.transport) + except zmq.error.ZMQError as e: + return RC_ERR("ZMQ Error: Bad server or port name: " + str(e)) + + self.socket.setsockopt(zmq.SNDTIMEO, 1000) + self.socket.setsockopt(zmq.RCVTIMEO, 1000) + + self.connected = True + + rc = self.invoke_rpc_method('ping') + if not rc: + self.connected = False + return rc + + return RC_OK() + + + def reconnect(self): + # connect using current values + return self.connect() + + if not self.connected: + return RC_ERR("Not connected to server") + + # reconnect + return self.connect(self.server, self.port) + + + def is_connected(self): + return self.connected + + def __del__(self): + self.logger.log("Shutting down RPC client\n") + if hasattr(self, "context"): + self.context.destroy(linger=0) + diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_packet_builder_interface.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_packet_builder_interface.py new file mode 100644 index 00000000..b6e7c026 --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_packet_builder_interface.py @@ -0,0 +1,43 @@ + +# base object class for a packet builder +class CTrexPktBuilderInterface(object): + + def compile (self): + """ + Compiles the packet and VM + """ + raise Exception("implement me") + + + def dump_pkt(self): + """ + Dumps the packet as a decimal array of bytes (each item x gets value between 0-255) + + :parameters: + None + + :return: + + packet representation as array of bytes + + :raises: + + :exc:`CTRexPktBuilder.EmptyPacketError`, in case packet is empty. + + """ + + raise Exception("implement me") + + + def get_vm_data(self): + """ + Dumps the instructions + + :parameters: + None + + :return: + + json object of instructions + + """ + + raise Exception("implement me") + diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_packet_builder_scapy.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_packet_builder_scapy.py new file mode 100644 index 00000000..0811209a --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_packet_builder_scapy.py @@ -0,0 +1,808 @@ +import random +import string +import struct +import socket +import json +import yaml +import binascii +import base64 + +from trex_stl_packet_builder_interface import CTrexPktBuilderInterface + +from scapy.all import * + + + +class CTRexPacketBuildException(Exception): + """ + This is the general Packet Building error exception class. + """ + def __init__(self, code, message): + self.code = code + self.message = message + + def __str__(self): + return self.__repr__() + + def __repr__(self): + return u"[errcode:%r] %r" % (self.code, self.message) + +################################################################################################ + +def ipv4_str_to_num (ipv4_buffer): + + assert type(ipv4_buffer)==str, 'type of ipv4_buffer is not str' + assert len(ipv4_buffer)==4, 'size of ipv4_buffer is not 4' + res=0 + shift=24 + for i in ipv4_buffer: + res = res + (ord(i)<<shift); + shift =shift -8 + return res + + + +def is_valid_ipv4(ip_addr): + """ + return buffer in network order + """ + if type(ip_addr)==str and len(ip_addr) == 4: + return ip_addr + + if type(ip_addr)==int : + ip_addr = socket.inet_ntoa(struct.pack("!I", ip_addr)) + + try: + return socket.inet_pton(socket.AF_INET, ip_addr) + except AttributeError: # no inet_pton here, sorry + return socket.inet_aton(ip_addr) + except socket.error: # not a valid address + raise CTRexPacketBuildException(-10,"not valid ipv4 format"); + + +class CTRexScriptsBase(object): + """ + VM Script base class + """ + def clone (self): + return copy.deepcopy(self) + + +class CTRexScFieldRangeBase(CTRexScriptsBase): + + FILED_TYPES = ['inc', 'dec', 'rand'] + + def __init__(self, field_name, + field_type + ): + super(CTRexScFieldRangeBase, self).__init__() + self.field_name =field_name + self.field_type =field_type + if not self.field_type in CTRexScFieldRangeBase.FILED_TYPES : + raise CTRexPacketBuildException(-12, 'field type should be in %s' % FILED_TYPES); + + +class CTRexScFieldRangeValue(CTRexScFieldRangeBase): + """ + range of field value + """ + def __init__(self, field_name, + field_type, + min_value, + max_value + ): + super(CTRexScFieldRangeValue, self).__init__(field_name,field_type) + self.min_value =min_value; + self.max_value =max_value; + if min_value > max_value: + raise CTRexPacketBuildException(-12, 'min is greater than max'); + if min_value == max_value: + raise CTRexPacketBuildException(-13, "min value is equal to max value, you can't use this type of range"); + + +class CTRexScIpv4SimpleRange(CTRexScFieldRangeBase): + """ + range of ipv4 ip + """ + def __init__(self, field_name, field_type, min_ip, max_ip): + super(CTRexScIpv4SimpleRange, self).__init__(field_name,field_type) + self.min_ip = min_ip + self.max_ip = max_ip + mmin=ipv4_str_to_num (is_valid_ipv4(min_ip)) + mmax=ipv4_str_to_num (is_valid_ipv4(max_ip)) + if mmin > mmax : + raise CTRexPacketBuildException(-11, 'CTRexScIpv4SimpleRange m_min ip is bigger than max'); + + +class CTRexScIpv4TupleGen(CTRexScriptsBase): + """ + range tuple + """ + FLAGS_ULIMIT_FLOWS =1 + + def __init__(self, min_ipv4, max_ipv4, num_flows=100000, min_port=1025, max_port=65535, flags=0): + super(CTRexScIpv4TupleGen, self).__init__() + self.min_ip = min_ipv4 + self.max_ip = max_ipv4 + mmin=ipv4_str_to_num (is_valid_ipv4(min_ipv4)) + mmax=ipv4_str_to_num (is_valid_ipv4(max_ipv4)) + if mmin > mmax : + raise CTRexPacketBuildException(-11, 'CTRexScIpv4SimpleRange m_min ip is bigger than max'); + + self.num_flows=num_flows; + + self.min_port =min_port + self.max_port =max_port + self.flags = flags + + +class CTRexScTrimPacketSize(CTRexScriptsBase): + """ + trim packet size. field type is CTRexScFieldRangeBase.FILED_TYPES = ["inc","dec","rand"] + """ + def __init__(self,field_type="rand",min_pkt_size=None, max_pkt_size=None): + super(CTRexScTrimPacketSize, self).__init__() + self.field_type = field_type + self.min_pkt_size = min_pkt_size + self.max_pkt_size = max_pkt_size + if max_pkt_size != None and min_pkt_size !=None : + if min_pkt_size == max_pkt_size: + raise CTRexPacketBuildException(-11, 'CTRexScTrimPacketSize min_pkt_size is the same as max_pkt_size '); + + if min_pkt_size > max_pkt_size: + raise CTRexPacketBuildException(-11, 'CTRexScTrimPacketSize min_pkt_size is bigger than max_pkt_size '); + +class CTRexScRaw(CTRexScriptsBase): + """ + raw instructions + """ + def __init__(self,list_of_commands=None): + super(CTRexScRaw, self).__init__() + if list_of_commands==None: + self.commands =[] + else: + self.commands = list_of_commands + + def add_cmd (self,cmd): + self.commands.append(cmd) + + + +################################################################################################ +# VM raw instructions +################################################################################################ + +class CTRexVmInsBase(object): + """ + instruction base + """ + def __init__(self, ins_type): + self.type = ins_type + assert type(ins_type)==str, 'type of ins_type is not str' + +class CTRexVmInsFixIpv4(CTRexVmInsBase): + def __init__(self, offset): + super(CTRexVmInsFixIpv4, self).__init__("fix_checksum_ipv4") + self.pkt_offset = offset + assert type(offset)==int, 'type of offset is not int' + + +class CTRexVmInsFlowVar(CTRexVmInsBase): + #TBD add more validation tests + + OPERATIONS =['inc', 'dec', 'random'] + VALID_SIZES =[1, 2, 4, 8] + + def __init__(self, fv_name, size, op, init_value, min_value, max_value): + super(CTRexVmInsFlowVar, self).__init__("flow_var") + self.name = fv_name; + assert type(fv_name)==str, 'type of fv_name is not str' + self.size = size + self.op = op + self.init_value = init_value + assert type(init_value)==int, 'type of init_value is not int' + self.min_value=min_value + assert type(min_value)==int, 'type of min_value is not int' + self.max_value=max_value + assert type(max_value)==int, 'type of min_value is not int' + +class CTRexVmInsWrFlowVar(CTRexVmInsBase): + def __init__(self, fv_name, pkt_offset, add_value=0, is_big_endian=True): + super(CTRexVmInsWrFlowVar, self).__init__("write_flow_var") + self.name = fv_name + assert type(fv_name)==str, 'type of fv_name is not str' + self.pkt_offset = pkt_offset + assert type(pkt_offset)==int, 'type of pkt_offset is not int' + self.add_value = add_value + assert type(add_value)==int, 'type of add_value is not int' + self.is_big_endian = is_big_endian + assert type(is_big_endian)==bool, 'type of is_big_endian is not bool' + +class CTRexVmInsTrimPktSize(CTRexVmInsBase): + def __init__(self,fv_name): + super(CTRexVmInsTrimPktSize, self).__init__("trim_pkt_size") + self.name = fv_name + assert type(fv_name)==str, 'type of fv_name is not str' + +class CTRexVmInsTupleGen(CTRexVmInsBase): + def __init__(self, fv_name, ip_min, ip_max, port_min, port_max, limit_flows, flags=0): + super(CTRexVmInsTupleGen, self).__init__("tuple_flow_var") + self.name =fv_name + assert type(fv_name)==str, 'type of fv_name is not str' + self.ip_min = ip_min; + self.ip_max = ip_max; + self.port_min = port_min; + self.port_max = port_max; + self.limit_flows = limit_flows; + self.flags =flags; + + +################################################################################################ +# +class CTRexVmEngine(object): + + def __init__(self): + """ + inlcude list of instruction + """ + super(CTRexVmEngine, self).__init__() + self.ins=[] + self.split_by_var = '' + + # return as json + def get_json (self): + inst_array = []; + # dump it as dict + for obj in self.ins: + inst_array.append(obj.__dict__); + + return {'instructions': inst_array, 'split_by_var': self.split_by_var} + + def add_ins (self,ins): + #assert issubclass(ins, CTRexVmInsBase) + self.ins.append(ins); + + def dump (self): + cnt=0; + for obj in self.ins: + print "ins",cnt + cnt = cnt +1 + print obj.__dict__ + + def dump_bjson (self): + print json.dumps(self.get_json(), sort_keys=True, indent=4) + + def dump_as_yaml (self): + print yaml.dump(self.get_json(), default_flow_style=False) + + + +################################################################################################ + +class CTRexScapyPktUtl(object): + + def __init__(self, scapy_pkt): + self.pkt = scapy_pkt + + def pkt_iter (self): + p=self.pkt; + while True: + yield p + p=p.payload + if p ==None or isinstance(p,NoPayload): + break; + + def get_list_iter(self): + l=list(self.pkt_iter()) + return l + + + def get_pkt_layers(self): + """ + return string 'IP:UDP:TCP' + """ + l=self.get_list_iter (); + l1=map(lambda p: p.name,l ); + return ":".join(l1); + + def _layer_offset(self, name, cnt = 0): + """ + return offset of layer e.g 'IP',1 will return offfset of layer ip:1 + """ + save_cnt=cnt + for pkt in self.pkt_iter (): + if pkt.name == name: + if cnt==0: + return (pkt, pkt.offset) + else: + cnt=cnt -1 + + raise CTRexPacketBuildException(-11,("no layer %s-%d" % (name, save_cnt))); + + + def layer_offset(self, name, cnt = 0): + """ + return offset of layer e.g 'IP',1 will return offfset of layer ip:1 + """ + save_cnt=cnt + for pkt in self.pkt_iter (): + if pkt.name == name: + if cnt==0: + return pkt.offset + else: + cnt=cnt -1 + + raise CTRexPacketBuildException(-11,("no layer %s-%d" % (name, save_cnt))); + + def get_field_offet(self, layer, layer_cnt, field_name): + """ + return offset of layer e.g 'IP',1 will return offfset of layer ip:1 + """ + t=self._layer_offset(layer,layer_cnt); + l_offset=t[1]; + layer_pkt=t[0] + + #layer_pkt.dump_fields_offsets () + + for f in layer_pkt.fields_desc: + if f.name == field_name: + return (l_offset+f.offset,f.get_size_bytes ()); + + raise CTRexPacketBuildException(-11, "no layer %s-%d." % (name, save_cnt, field_name)); + + def get_layer_offet_by_str(self, layer_des): + """ + return layer offset by string + + :parameters: + + IP:0 + IP:1 + return offset + + + """ + l1=layer_des.split(":") + layer="" + layer_cnt=0; + + if len(l1)==1: + layer=l1[0]; + else: + layer=l1[0]; + layer_cnt=int(l1[1]); + + return self.layer_offset(layer, layer_cnt) + + + + def get_field_offet_by_str(self, field_des): + """ + return field_des (offset,size) layer:cnt.field + for example + 802|1Q.vlan get 802.1Q->valn replace | with . + IP.src + IP:0.src (first IP.src like IP.src) + for example IP:1.src for internal IP + + return (offset, size) as tuple + + + """ + + s=field_des.split("."); + if len(s)!=2: + raise CTRexPacketBuildException(-11, ("field desription should be layer:cnt.field e.g IP.src or IP:1.src")); + + + layer_ex = s[0].replace("|",".") + field = s[1] + + l1=layer_ex.split(":") + layer="" + layer_cnt=0; + + if len(l1)==1: + layer=l1[0]; + else: + layer=l1[0]; + layer_cnt=int(l1[1]); + + return self.get_field_offet(layer,layer_cnt,field) + + def has_IPv4 (self): + return self.pkt.has_layer("IP"); + + def has_IPv6 (self): + return self.pkt.has_layer("IPv6"); + + def has_UDP (self): + return self.pkt.has_layer("UDP"); + +################################################################################################ + +class CTRexVmDescBase(object): + """ + instruction base + """ + def __init__(self): + pass; + + def get_obj(self): + return self; + + def get_json(self): + return self.get_obj().__dict__ + + def dump_bjson(self): + print json.dumps(self.get_json(), sort_keys=True, indent=4) + + def dump_as_yaml(self): + print yaml.dump(self.get_json(), default_flow_style=False) + + + def get_var_ref (self): + ''' + virtual function return a ref var name + ''' + return None + + def get_var_name(self): + ''' + virtual function return the varible name if exists + ''' + return None + + def compile(self,parent): + ''' + virtual function to take parent than has function name_to_offset + ''' + pass; + + +def valid_fv_size (size): + if not (size in CTRexVmInsFlowVar.VALID_SIZES): + raise CTRexPacketBuildException(-11,("flow var has not valid size %d ") % size ); + +def valid_fv_ops (op): + if not (op in CTRexVmInsFlowVar.OPERATIONS): + raise CTRexPacketBuildException(-11,("flow var does not have a valid op %s ") % op ); + +def convert_val (val): + if type(val) == int: + return val + else: + if type(val) == str: + return ipv4_str_to_num (is_valid_ipv4(val)) + else: + raise CTRexPacketBuildException(-11,("init val not valid %s ") % val ); + +def check_for_int (val): + assert type(val)==int, 'type of vcal is not int' + + +class CTRexVmDescFlowVar(CTRexVmDescBase): + def __init__(self, name, init_value=None, min_value=0, max_value=255, size=4, op="inc"): + super(CTRexVmDescFlowVar, self).__init__() + self.name = name; + assert type(name)==str, 'type of name is not str' + self.size =size + valid_fv_size(size) + self.op =op + valid_fv_ops (op) + + # choose default value for init val + if init_value == None: + init_value = max_value if op == "dec" else min_value + + self.init_value = convert_val (init_value) + self.min_value = convert_val (min_value); + self.max_value = convert_val (max_value) + + if self.min_value > self.max_value : + raise CTRexPacketBuildException(-11,("max %d is lower than min %d ") % (self.max_value,self.min_value) ); + + def get_obj (self): + return CTRexVmInsFlowVar(self.name,self.size,self.op,self.init_value,self.min_value,self.max_value); + + def get_var_name(self): + return [self.name] + + +class CTRexVmDescFixIpv4(CTRexVmDescBase): + def __init__(self, offset): + super(CTRexVmDescFixIpv4, self).__init__() + self.offset = offset; # could be a name of offset + + def get_obj (self): + return CTRexVmInsFixIpv4(self.offset); + + def compile(self,parent): + if type(self.offset)==str: + self.offset = parent._pkt_layer_offset(self.offset); + +class CTRexVmDescWrFlowVar(CTRexVmDescBase): + def __init__(self, fv_name, pkt_offset, offset_fixup=0, add_val=0, is_big=True): + super(CTRexVmDescWrFlowVar, self).__init__() + self.name =fv_name + assert type(fv_name)==str, 'type of fv_name is not str' + self.offset_fixup =offset_fixup + assert type(offset_fixup)==int, 'type of offset_fixup is not int' + self.pkt_offset =pkt_offset + self.add_val =add_val + assert type(add_val)==int,'type of add_val is not int' + self.is_big =is_big; + assert type(is_big)==bool,'type of is_big_endian is not bool' + + def get_var_ref (self): + return self.name + + def get_obj (self): + return CTRexVmInsWrFlowVar(self.name,self.pkt_offset+self.offset_fixup,self.add_val,self.is_big) + + def compile(self,parent): + if type(self.pkt_offset)==str: + t=parent._name_to_offset(self.pkt_offset) + self.pkt_offset = t[0] + + +class CTRexVmDescTrimPktSize(CTRexVmDescBase): + def __init__(self,fv_name): + super(CTRexVmDescTrimPktSize, self).__init__() + self.name = fv_name + assert type(fv_name)==str, 'type of fv_name is not str' + + def get_var_ref (self): + return self.name + + def get_obj (self): + return CTRexVmInsTrimPktSize(self.name) + + + +class CTRexVmDescTupleGen(CTRexVmDescBase): + def __init__(self,name, ip_min="0.0.0.1", ip_max="0.0.0.10", port_min=1025, port_max=65535, limit_flows=100000, flags=0): + super(CTRexVmDescTupleGen, self).__init__() + self.name = name + assert type(name)==str, 'type of fv_name is not str' + self.ip_min = convert_val(ip_min); + self.ip_max = convert_val(ip_max); + self.port_min = port_min; + check_for_int (port_min) + self.port_max = port_max; + check_for_int(port_max) + self.limit_flows = limit_flows; + check_for_int(limit_flows) + self.flags =flags; + check_for_int(flags) + + def get_var_name(self): + return [self.name+".ip",self.name+".port"] + + def get_obj (self): + return CTRexVmInsTupleGen(self.name, self.ip_min, self.ip_max, self.port_min, self.port_max, self.limit_flows, self.flags); + + +################################################################################################ + +class CScapyTRexPktBuilder(CTrexPktBuilderInterface): + + """ + This class defines the TRex API of building a packet using dpkt package. + Using this class the user can also define how TRex will handle the packet by specifying the VM setting. + pkt could be Scapy pkt or pcap file name + """ + def __init__(self, pkt = None, vm = None): + """ + Instantiate a CTRexPktBuilder object + + :parameters: + None + + """ + super(CScapyTRexPktBuilder, self).__init__() + + self.pkt = None # as input + self.pkt_raw = None # from raw pcap file + self.vm_scripts = [] # list of high level instructions + self.vm_low_level = None + self.metadata="" + was_set=False + + + # process packet + if pkt != None: + self.set_packet(pkt) + was_set=True + + # process VM + if vm != None: + if not isinstance(vm, (CTRexScRaw, list)): + raise CTRexPacketBuildException(-14, "bad value for variable vm") + + self.add_command(vm if isinstance(vm, CTRexScRaw) else CTRexScRaw(vm)) + was_set=True + + if was_set: + self.compile () + + + def dump_vm_data_as_yaml(self): + print yaml.dump(self.get_vm_data(), default_flow_style=False) + + def get_vm_data(self): + """ + Dumps the instructions + + :parameters: + None + + :return: + + json object of instructions + + :raises: + + :exc:`AssertionError`, in case VM is not compiled (is None). + """ + + assert self.vm_low_level is not None, 'vm_low_level is None, please use compile()' + + return self.vm_low_level.get_json() + + def dump_pkt(self, encode = True): + """ + Dumps the packet as a decimal array of bytes (each item x gets value between 0-255) + + :parameters: + encode : bool + Encode using base64. (disable for debug) + + Default: **True** + + :return: + + packet representation as array of bytes + + :raises: + + :exc:`AssertionError`, in case packet is empty. + + """ + + assert self.pkt, 'empty packet' + pkt_buf = self._get_pkt_as_str() + + return {'binary': base64.b64encode(pkt_buf) if encode else pkt_buf, + 'meta': self.metadata} + + def dump_pkt_to_pcap(self, file_path): + wrpcap(file_path, self._get_pkt_as_str()) + + def add_command (self, script): + self.vm_scripts.append(script.clone()); + + def dump_scripts (self): + self.vm_low_level.dump_as_yaml() + + def dump_as_hex (self): + pkt_buf = self._get_pkt_as_str() + print hexdump(pkt_buf) + + def pkt_layers_desc (self): + """ + return layer description like this IP:TCP:Pyload + + """ + pkt_buf = self._get_pkt_as_str() + scapy_pkt = Ether(pkt_buf); + pkt_utl = CTRexScapyPktUtl(scapy_pkt); + return pkt_utl.get_pkt_layers() + + def set_pkt_as_str (self, pkt_buffer): + assert type(pkt_buffer)==str, "pkt_buffer should be string" + self.pkt_raw = pkt_buffer + + def set_pcap_file (self, pcap_file): + """ + load raw pcap file into a buffer. load only the first packet + + :parameters: + pcap_file : file_name + + :raises: + + :exc:`AssertionError`, in case packet is empty. + + """ + + p=RawPcapReader(pcap_file) + was_set = False + + for pkt in p: + was_set=True; + self.pkt_raw = str(pkt[0]) + break + if not was_set : + raise CTRexPacketBuildException(-14, "no buffer inside the pcap file") + + def set_packet (self, pkt): + """ + Scapy packet Ether()/IP(src="16.0.0.1",dst="48.0.0.1")/UDP(dport=12,sport=1025)/IP()/"A"*10 + """ + if isinstance(pkt, Packet): + self.pkt = pkt; + else: + if isinstance(pkt, str): + self.set_pcap_file(pkt) + else: + raise CTRexPacketBuildException(-14, "bad packet" ) + + + + def compile (self): + self.vm_low_level=CTRexVmEngine() + if self.pkt == None and self.pkt_raw == None: + raise CTRexPacketBuildException(-14, "Packet is empty") + + if self.pkt: + self.pkt.build(); + + for sc in self.vm_scripts: + if isinstance(sc, CTRexScRaw): + self._compile_raw(sc) + + #for obj in self.vm_scripts: + # # tuple gen script + # if isinstance(obj, CTRexScIpv4TupleGen) + # self._add_tuple_gen(tuple_gen) + + #################################################### + # private + + def _compile_raw (self,obj): + + # make sure we have varibles once + vars={}; + + # add it add var to dit + for desc in obj.commands: + var_names = desc.get_var_name() + + if var_names : + for var_name in var_names: + if vars.has_key(var_name): + raise CTRexPacketBuildException(-11,("variable %s define twice ") % (var_name) ); + else: + vars[var_name]=1 + + # check that all write exits + for desc in obj.commands: + var_name = desc.get_var_ref() + if var_name : + if not vars.has_key(var_name): + raise CTRexPacketBuildException(-11,("variable %s does not exists ") % (var_name) ); + desc.compile(self); + + for desc in obj.commands: + self.vm_low_level.add_ins(desc.get_obj()); + + + def _pkt_layer_offset (self,layer_name): + assert self.pkt != None, 'empty packet' + p_utl=CTRexScapyPktUtl(self.pkt); + return p_utl.get_layer_offet_by_str(layer_name) + + def _name_to_offset(self,field_name): + assert self.pkt != None, 'empty packet' + p_utl=CTRexScapyPktUtl(self.pkt); + return p_utl.get_field_offet_by_str(field_name) + + def _get_pkt_as_str(self): + if self.pkt: + return str(self.pkt) + if self.pkt_raw: + return self.pkt_raw + raise CTRexPacketBuildException(-11, 'empty packet'); + + def _add_tuple_gen(self,tuple_gen): + + pass; + + + diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_port.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_port.py new file mode 100644 index 00000000..b2cf1c90 --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_port.py @@ -0,0 +1,547 @@ + +from collections import namedtuple, OrderedDict + +import trex_stl_stats +from trex_stl_types import * + +StreamOnPort = namedtuple('StreamOnPort', ['compiled_stream', 'metadata']) + +########## utlity ############ +def mult_to_factor (mult, max_bps_l2, max_pps, line_util): + if mult['type'] == 'raw': + return mult['value'] + + if mult['type'] == 'bps': + return mult['value'] / max_bps_l2 + + if mult['type'] == 'pps': + return mult['value'] / max_pps + + if mult['type'] == 'percentage': + return mult['value'] / line_util + + +# describes a single port +class Port(object): + STATE_DOWN = 0 + STATE_IDLE = 1 + STATE_STREAMS = 2 + STATE_TX = 3 + STATE_PAUSE = 4 + PortState = namedtuple('PortState', ['state_id', 'state_name']) + STATES_MAP = {STATE_DOWN: "DOWN", + STATE_IDLE: "IDLE", + STATE_STREAMS: "IDLE", + STATE_TX: "ACTIVE", + STATE_PAUSE: "PAUSE"} + + + def __init__ (self, port_id, speed, driver, user, comm_link, session_id): + self.port_id = port_id + self.state = self.STATE_IDLE + self.handler = None + self.comm_link = comm_link + self.transmit = comm_link.transmit + self.transmit_batch = comm_link.transmit_batch + self.user = user + self.driver = driver + self.speed = speed + self.streams = {} + self.profile = None + self.session_id = session_id + + self.port_stats = trex_stl_stats.CPortStats(self) + + self.next_available_id = 1 + + + def err(self, msg): + return RC_ERR("port {0} : {1}".format(self.port_id, msg)) + + def ok(self, data = ""): + return RC_OK(data) + + def get_speed_bps (self): + return (self.speed * 1000 * 1000 * 1000) + + # take the port + def acquire(self, force = False): + params = {"port_id": self.port_id, + "user": self.user, + "session_id": self.session_id, + "force": force} + + command = RpcCmdData("acquire", params) + rc = self.transmit(command.method, command.params) + if rc.good(): + self.handler = rc.data() + return self.ok() + else: + return self.err(rc.err()) + + # release the port + def release(self): + params = {"port_id": self.port_id, + "handler": self.handler} + + command = RpcCmdData("release", params) + rc = self.transmit(command.method, command.params) + self.handler = None + + if rc.good(): + return self.ok() + else: + return self.err(rc.err()) + + def is_acquired(self): + return (self.handler != None) + + def is_active(self): + return(self.state == self.STATE_TX ) or (self.state == self.STATE_PAUSE) + + def is_transmitting (self): + return (self.state == self.STATE_TX) + + def is_paused (self): + return (self.state == self.STATE_PAUSE) + + + def sync(self): + params = {"port_id": self.port_id} + + command = RpcCmdData("get_port_status", params) + rc = self.transmit(command.method, command.params) + if rc.bad(): + return self.err(rc.err()) + + # sync the port + port_state = rc.data()['state'] + + if port_state == "DOWN": + self.state = self.STATE_DOWN + elif port_state == "IDLE": + self.state = self.STATE_IDLE + elif port_state == "STREAMS": + self.state = self.STATE_STREAMS + elif port_state == "TX": + self.state = self.STATE_TX + elif port_state == "PAUSE": + self.state = self.STATE_PAUSE + else: + raise Exception("port {0}: bad state received from server '{1}'".format(self.port_id, port_state)) + + # TODO: handle syncing the streams into stream_db + + self.next_available_id = rc.data()['max_stream_id'] + + return self.ok() + + + # return TRUE if write commands + def is_port_writable (self): + # operations on port can be done on state idle or state streams + return ((self.state == self.STATE_IDLE) or (self.state == self.STATE_STREAMS)) + + + def __allocate_stream_id (self): + id = self.next_available_id + self.next_available_id += 1 + return id + + + # add streams + def add_streams (self, streams_list): + + if not self.is_acquired(): + return self.err("port is not owned") + + if not self.is_port_writable(): + return self.err("Please stop port before attempting to add streams") + + # listify + streams_list = streams_list if isinstance(streams_list, list) else [streams_list] + + lookup = {} + + # allocate IDs + for stream in streams_list: + if stream.get_id() == None: + stream.set_id(self.__allocate_stream_id()) + + lookup[stream.get_name()] = stream.get_id() + + # resolve names + for stream in streams_list: + next_id = -1 + + next = stream.get_next() + if next: + if not next in lookup: + return self.err("stream dependency error - unable to find '{0}'".format(next)) + next_id = lookup[next] + + stream.fields['next_stream_id'] = next_id + + + batch = [] + for stream in streams_list: + + params = {"handler": self.handler, + "port_id": self.port_id, + "stream_id": stream.get_id(), + "stream": stream.to_json()} + + cmd = RpcCmdData('add_stream', params) + batch.append(cmd) + + # meta data for show streams + self.streams[stream.get_id()] = StreamOnPort(stream.to_json(), + Port._generate_stream_metadata(stream)) + + rc = self.transmit_batch(batch) + if not rc: + return self.err(rc.err()) + + + + # the only valid state now + self.state = self.STATE_STREAMS + + return self.ok() + + + + # remove stream from port + def remove_streams (self, stream_id_list): + + if not self.is_acquired(): + return self.err("port is not owned") + + if not self.is_port_writable(): + return self.err("Please stop port before attempting to remove streams") + + # single element to list + stream_id_list = stream_id_list if isinstance(stream_id_list, list) else [stream_id_list] + + # verify existance + if not all([stream_id in self.streams for stream_id in stream_id_list]): + return self.err("stream {0} does not exists".format(stream_id)) + + batch = [] + + for stream_id in stream_id_list: + params = {"handler": self.handler, + "port_id": self.port_id, + "stream_id": stream_id} + + cmd = RpcCmdData('remove_stream', params) + batch.append(cmd) + + del self.streams[stream_id] + + + rc = self.transmit_batch(batch) + if not rc: + return self.err(rc.err()) + + self.state = self.STATE_STREAMS if (len(self.streams) > 0) else self.STATE_IDLE + + return self.ok() + + + # remove all the streams + def remove_all_streams (self): + + if not self.is_acquired(): + return self.err("port is not owned") + + if not self.is_port_writable(): + return self.err("Please stop port before attempting to remove streams") + + params = {"handler": self.handler, + "port_id": self.port_id} + + rc = self.transmit("remove_all_streams", params) + if not rc: + return self.err(rc.err()) + + self.streams = {} + + self.state = self.STATE_IDLE + + return self.ok() + + # get a specific stream + def get_stream (self, stream_id): + if stream_id in self.streams: + return self.streams[stream_id] + else: + return None + + def get_all_streams (self): + return self.streams + + # start traffic + def start (self, mul, duration, force): + if not self.is_acquired(): + return self.err("port is not owned") + + if self.state == self.STATE_DOWN: + return self.err("Unable to start traffic - port is down") + + if self.state == self.STATE_IDLE: + return self.err("Unable to start traffic - no streams attached to port") + + if self.state == self.STATE_TX: + return self.err("Unable to start traffic - port is already transmitting") + + params = {"handler": self.handler, + "port_id": self.port_id, + "mul": mul, + "duration": duration, + "force": force} + + rc = self.transmit("start_traffic", params) + if rc.bad(): + return self.err(rc.err()) + + self.state = self.STATE_TX + + return self.ok() + + # stop traffic + # with force ignores the cached state and sends the command + def stop (self, force = False): + + if not self.is_acquired(): + return self.err("port is not owned") + + # port is already stopped + if not force: + if (self.state == self.STATE_IDLE) or (self.state == self.state == self.STATE_STREAMS): + return self.ok() + + + + params = {"handler": self.handler, + "port_id": self.port_id} + + rc = self.transmit("stop_traffic", params) + if rc.bad(): + return self.err(rc.err()) + + # only valid state after stop + self.state = self.STATE_STREAMS + + return self.ok() + + def pause (self): + + if not self.is_acquired(): + return self.err("port is not owned") + + if (self.state != self.STATE_TX) : + return self.err("port is not transmitting") + + params = {"handler": self.handler, + "port_id": self.port_id} + + rc = self.transmit("pause_traffic", params) + if rc.bad(): + return self.err(rc.err()) + + # only valid state after stop + self.state = self.STATE_PAUSE + + return self.ok() + + + def resume (self): + + if not self.is_acquired(): + return self.err("port is not owned") + + if (self.state != self.STATE_PAUSE) : + return self.err("port is not in pause mode") + + params = {"handler": self.handler, + "port_id": self.port_id} + + rc = self.transmit("resume_traffic", params) + if rc.bad(): + return self.err(rc.err()) + + # only valid state after stop + self.state = self.STATE_TX + + return self.ok() + + + def update (self, mul, force): + + if not self.is_acquired(): + return self.err("port is not owned") + + if (self.state != self.STATE_TX) : + return self.err("port is not transmitting") + + params = {"handler": self.handler, + "port_id": self.port_id, + "mul": mul, + "force": force} + + rc = self.transmit("update_traffic", params) + if rc.bad(): + return self.err(rc.err()) + + return self.ok() + + + def validate (self): + + if not self.is_acquired(): + return self.err("port is not owned") + + if (self.state == self.STATE_DOWN): + return self.err("port is down") + + if (self.state == self.STATE_IDLE): + return self.err("no streams attached to port") + + params = {"handler": self.handler, + "port_id": self.port_id} + + rc = self.transmit("validate", params) + if rc.bad(): + return self.err(rc.err()) + + self.profile = rc.data() + + return self.ok() + + def get_profile (self): + return self.profile + + + def print_profile (self, mult, duration): + if not self.get_profile(): + return + + rate = self.get_profile()['rate'] + graph = self.get_profile()['graph'] + + print format_text("Profile Map Per Port\n", 'underline', 'bold') + + factor = mult_to_factor(mult, rate['max_bps_l2'], rate['max_pps'], rate['max_line_util']) + + print "Profile max BPS L2 (base / req): {:^12} / {:^12}".format(format_num(rate['max_bps_l2'], suffix = "bps"), + format_num(rate['max_bps_l2'] * factor, suffix = "bps")) + + print "Profile max BPS L1 (base / req): {:^12} / {:^12}".format(format_num(rate['max_bps_l1'], suffix = "bps"), + format_num(rate['max_bps_l1'] * factor, suffix = "bps")) + + print "Profile max PPS (base / req): {:^12} / {:^12}".format(format_num(rate['max_pps'], suffix = "pps"), + format_num(rate['max_pps'] * factor, suffix = "pps"),) + + print "Profile line util. (base / req): {:^12} / {:^12}".format(format_percentage(rate['max_line_util']), + format_percentage(rate['max_line_util'] * factor)) + + + # duration + exp_time_base_sec = graph['expected_duration'] / (1000 * 1000) + exp_time_factor_sec = exp_time_base_sec / factor + + # user configured a duration + if duration > 0: + if exp_time_factor_sec > 0: + exp_time_factor_sec = min(exp_time_factor_sec, duration) + else: + exp_time_factor_sec = duration + + + print "Duration (base / req): {:^12} / {:^12}".format(format_time(exp_time_base_sec), + format_time(exp_time_factor_sec)) + print "\n" + + + def get_port_state_name(self): + return self.STATES_MAP.get(self.state, "Unknown") + + ################# stats handler ###################### + def generate_port_stats(self): + return self.port_stats.generate_stats() + + def generate_port_status(self): + return {"type": self.driver, + "maximum": "{speed} Gb/s".format(speed=self.speed), + "status": self.get_port_state_name() + } + + def clear_stats(self): + return self.port_stats.clear_stats() + + + def get_stats (self): + return self.port_stats.get_stats() + + + def invalidate_stats(self): + return self.port_stats.invalidate() + + ################# stream printout ###################### + def generate_loaded_streams_sum(self, stream_id_list): + if self.state == self.STATE_DOWN or self.state == self.STATE_STREAMS: + return {} + streams_data = {} + + if not stream_id_list: + # if no mask has been provided, apply to all streams on port + stream_id_list = self.streams.keys() + + + streams_data = {stream_id: self.streams[stream_id].metadata.get('stream_sum', ["N/A"] * 6) + for stream_id in stream_id_list + if stream_id in self.streams} + + + return {"referring_file" : "", + "streams" : streams_data} + + @staticmethod + def _generate_stream_metadata(stream): + meta_dict = {} + # create packet stream description + #pkt_bld_obj = packet_builder.CTRexPktBuilder() + #pkt_bld_obj.load_from_stream_obj(compiled_stream_obj) + # generate stream summary based on that + + #next_stream = "None" if stream['next_stream_id']==-1 else stream['next_stream_id'] + + meta_dict['stream_sum'] = OrderedDict([("id", stream.get_id()), + ("packet_type", "FIXME!!!"), + ("length", "FIXME!!!"), + ("mode", "FIXME!!!"), + ("rate_pps", "FIXME!!!"), + ("next_stream", "FIXME!!!") + ]) + return meta_dict + + ################# events handler ###################### + def async_event_port_stopped (self): + self.state = self.STATE_STREAMS + + + def async_event_port_started (self): + self.state = self.STATE_TX + + + def async_event_port_paused (self): + self.state = self.STATE_PAUSE + + + def async_event_port_resumed (self): + self.state = self.STATE_TX + + def async_event_forced_acquired (self): + self.handler = None + diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_sim.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_sim.py new file mode 100644 index 00000000..1252b752 --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_sim.py @@ -0,0 +1,422 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Itay Marom +Cisco Systems, Inc. + +Copyright (c) 2015-2015 Cisco Systems, Inc. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +# simulator can be run as a standalone +import trex_stl_ext + +from trex_stl_exceptions import * +from yaml import YAMLError +from trex_stl_streams import * +from utils import parsing_opts +from trex_stl_client import STLClient + +import re +import json + + + +import argparse +import tempfile +import subprocess +import os +from dpkt import pcap +from operator import itemgetter + +class BpSimException(Exception): + pass + +def merge_cap_files (pcap_file_list, out_filename, delete_src = False): + + out_pkts = [] + if not all([os.path.exists(f) for f in pcap_file_list]): + print "failed to merge cap file list...\nnot all files exist\n" + return + + # read all packets to a list + for src in pcap_file_list: + f = open(src, 'r') + reader = pcap.Reader(f) + pkts = reader.readpkts() + out_pkts += pkts + f.close() + if delete_src: + os.unlink(src) + + # sort by the timestamp + out_pkts = sorted(out_pkts, key=itemgetter(0)) + + + out = open(out_filename, 'w') + out_writer = pcap.Writer(out) + + for ts, pkt in out_pkts: + out_writer.writepkt(pkt, ts) + + out.close() + + + +# stateless simulation +class STLSim(object): + def __init__ (self, bp_sim_path = None, handler = 0, port_id = 0): + + if not bp_sim_path: + # auto find scripts + m = re.match(".*/trex-core", os.getcwd()) + if not m: + raise STLError('cannot find BP sim path, please provide it') + + self.bp_sim_path = os.path.join(m.group(0), 'scripts') + + else: + self.bp_sim_path = bp_sim_path + + # dummies + self.handler = handler + self.port_id = port_id + + + def generate_start_cmd (self, mult = "1", force = True, duration = -1): + return {"id":1, + "jsonrpc": "2.0", + "method": "start_traffic", + "params": {"handler": self.handler, + "force": force, + "port_id": self.port_id, + "mul": parsing_opts.decode_multiplier(mult), + "duration": duration} + } + + + + # run command + # input_list - a list of streams or YAML files + # outfile - pcap file to save output, if None its a dry run + # dp_core_count - how many DP cores to use + # dp_core_index - simulate only specific dp core without merging + # is_debug - debug or release image + # pkt_limit - how many packets to simulate + # mult - multiplier + # mode - can be 'valgrind, 'gdb', 'json' or 'none' + def run (self, + input_list, + outfile = None, + dp_core_count = 1, + dp_core_index = None, + is_debug = True, + pkt_limit = 5000, + mult = "1", + duration = -1, + mode = 'none'): + + if not mode in ['none', 'gdb', 'valgrind', 'json']: + raise STLArgumentError('mode', mode) + + # listify + input_list = input_list if isinstance(input_list, list) else [input_list] + + # check streams arguments + if not all([isinstance(i, (STLStream, str)) for i in input_list]): + raise STLArgumentError('input_list', input_list) + + # split to two type + input_files = [x for x in input_list if isinstance(x, str)] + stream_list = [x for x in input_list if isinstance(x, STLStream)] + + # handle YAMLs + for input_file in input_files: + stream_list += STLClient.load_profile(input_file) + + + # load streams + cmds_json = [] + + id = 1 + + lookup = {} + # allocate IDs + for stream in stream_list: + if stream.get_id() == None: + stream.set_id(id) + id += 1 + + lookup[stream.get_name()] = stream.get_id() + + # resolve names + for stream in stream_list: + next_id = -1 + next = stream.get_next() + if next: + if not next in lookup: + raise STLError("stream dependency error - unable to find '{0}'".format(next)) + next_id = lookup[next] + + stream.fields['next_stream_id'] = next_id + + for stream in stream_list: + cmd = {"id":1, + "jsonrpc": "2.0", + "method": "add_stream", + "params": {"handler": self.handler, + "port_id": self.port_id, + "stream_id": stream.get_id(), + "stream": stream.to_json()} + } + + cmds_json.append(cmd) + + # generate start command + cmds_json.append(self.generate_start_cmd(mult = mult, + force = True, + duration = duration)) + + if mode == 'json': + print json.dumps(cmds_json, indent = 4, separators=(',', ': '), sort_keys = True) + return + + # start simulation + self.outfile = outfile + self.dp_core_count = dp_core_count + self.dp_core_index = dp_core_index + self.is_debug = is_debug + self.pkt_limit = pkt_limit + self.mult = mult + self.duration = duration, + self.mode = mode + + self.__run(cmds_json) + + + # internal run + def __run (self, cmds_json): + + # write to temp file + f = tempfile.NamedTemporaryFile(delete = False) + f.write(json.dumps(cmds_json)) + f.close() + + # launch bp-sim + try: + self.execute_bp_sim(f.name) + finally: + os.unlink(f.name) + + + + def execute_bp_sim (self, json_filename): + if self.is_debug: + exe = os.path.join(self.bp_sim_path, 'bp-sim-64-debug') + else: + exe = os.path.join(self.bp_sim_path, 'bp-sim-64') + + if not os.path.exists(exe): + raise STLError("'{0}' does not exists, please build it before calling the simulation".format(exe)) + + + cmd = [exe, + '--pcap', + '--sl', + '--cores', + str(self.dp_core_count), + '--limit', + str(self.pkt_limit), + '-f', + json_filename] + + # out or dry + if not self.outfile: + cmd += ['--dry'] + cmd += ['-o', '/dev/null'] + else: + cmd += ['-o', self.outfile] + + if self.dp_core_index != None: + cmd += ['--core_index', str(self.dp_core_index)] + + if self.mode == 'valgrind': + cmd = ['valgrind', '--leak-check=full', '--error-exitcode=1'] + cmd + + elif self.mode == 'gdb': + cmd = ['/bin/gdb', '--args'] + cmd + + print "executing command: '{0}'".format(" ".join(cmd)) + rc = subprocess.call(cmd) + if rc != 0: + raise STLError('simulation has failed with error code {0}'.format(rc)) + + self.merge_results() + + + def merge_results (self): + if not self.outfile: + return + + if self.dp_core_count == 1: + return + + if self.dp_core_index != None: + return + + + print "Mering cores output to a single pcap file...\n" + inputs = ["{0}-{1}".format(self.outfile, index) for index in xrange(0, self.dp_core_count)] + merge_cap_files(inputs, self.outfile, delete_src = True) + + + + +def is_valid_file(filename): + if not os.path.isfile(filename): + raise argparse.ArgumentTypeError("The file '%s' does not exist" % filename) + + return filename + + +def unsigned_int (x): + x = int(x) + if x < 0: + raise argparse.ArgumentTypeError("argument must be >= 0") + + return x + +def setParserOptions(): + parser = argparse.ArgumentParser(prog="stl_sim.py") + + parser.add_argument("-f", + dest ="input_file", + help = "input file in YAML or Python format", + type = is_valid_file, + required=True) + + parser.add_argument("-o", + dest = "output_file", + default = None, + help = "output file in ERF format") + + + parser.add_argument("-c", "--cores", + help = "DP core count [default is 1]", + dest = "dp_core_count", + default = 1, + type = int, + choices = xrange(1, 9)) + + parser.add_argument("-n", "--core_index", + help = "Record only a specific core", + dest = "dp_core_index", + default = None, + type = int) + + parser.add_argument("-r", "--release", + help = "runs on release image instead of debug [default is False]", + action = "store_true", + default = False) + + + parser.add_argument("-l", "--limit", + help = "limit test total packet count [default is 5000]", + default = 5000, + type = unsigned_int) + + parser.add_argument('-m', '--multiplier', + help = parsing_opts.match_multiplier_help, + dest = 'mult', + default = "1", + type = parsing_opts.match_multiplier_strict) + + parser.add_argument('-d', '--duration', + help = "run duration", + dest = 'duration', + default = -1, + type = float) + + + group = parser.add_mutually_exclusive_group() + + group.add_argument("-x", "--valgrind", + help = "run under valgrind [default is False]", + action = "store_true", + default = False) + + group.add_argument("-g", "--gdb", + help = "run under GDB [default is False]", + action = "store_true", + default = False) + + group.add_argument("--json", + help = "generate JSON output only to stdout [default is False]", + action = "store_true", + default = False) + + return parser + + +def validate_args (parser, options): + + if options.dp_core_index: + if not options.dp_core_index in xrange(0, options.dp_core_count): + parser.error("DP core index valid range is 0 to {0}".format(options.dp_core_count - 1)) + + # zero is ok - no limit, but other values must be at least as the number of cores + if (options.limit != 0) and options.limit < options.dp_core_count: + parser.error("limit cannot be lower than number of DP cores") + + +def main (): + parser = setParserOptions() + options = parser.parse_args() + + validate_args(parser, options) + + + + if options.valgrind: + mode = 'valgrind' + elif options.gdb: + mode = 'gdb' + elif options.json: + mode = 'json' + else: + mode = 'none' + + try: + r = STLSim() + r.run(input_list = options.input_file, + outfile = options.output_file, + dp_core_count = options.dp_core_count, + dp_core_index = options.dp_core_index, + is_debug = (not options.release), + pkt_limit = options.limit, + mult = options.mult, + duration = options.duration, + mode = mode) + + except KeyboardInterrupt as e: + print "\n\n*** Caught Ctrl + C... Exiting...\n\n" + exit(1) + + except STLError as e: + print e + exit(1) + + exit(0) + +if __name__ == '__main__': + main() + + diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_stats.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_stats.py new file mode 100644 index 00000000..3f09e47c --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_stats.py @@ -0,0 +1,581 @@ +#!/router/bin/python + +from utils import text_tables +from utils.text_opts import format_text, format_threshold, format_num + +from trex_stl_async_client import CTRexAsyncStats + +from collections import namedtuple, OrderedDict, deque +import copy +import datetime +import time +import re +import math +import copy + +GLOBAL_STATS = 'g' +PORT_STATS = 'p' +PORT_STATUS = 'ps' +ALL_STATS_OPTS = {GLOBAL_STATS, PORT_STATS, PORT_STATUS} +COMPACT = {GLOBAL_STATS, PORT_STATS} + +ExportableStats = namedtuple('ExportableStats', ['raw_data', 'text_table']) + +# use to calculate diffs relative to the previous values +# for example, BW +def calculate_diff (samples): + total = 0.0 + + weight_step = 1.0 / sum(xrange(0, len(samples))) + weight = weight_step + + for i in xrange(0, len(samples) - 1): + current = samples[i] if samples[i] > 0 else 1 + next = samples[i + 1] if samples[i + 1] > 0 else 1 + + s = 100 * ((float(next) / current) - 1.0) + + # block change by 100% + total += (min(s, 100) * weight) + weight += weight_step + + return total + + +# calculate by absolute values and not relatives (useful for CPU usage in % and etc.) +def calculate_diff_raw (samples): + total = 0.0 + + weight_step = 1.0 / sum(xrange(0, len(samples))) + weight = weight_step + + for i in xrange(0, len(samples) - 1): + current = samples[i] + next = samples[i + 1] + + total += ( (next - current) * weight ) + weight += weight_step + + return total + + +class CTRexInfoGenerator(object): + """ + This object is responsible of generating stats and information from objects maintained at + STLClient and the ports. + """ + + def __init__(self, global_stats_ref, ports_dict_ref): + self._global_stats = global_stats_ref + self._ports_dict = ports_dict_ref + + def generate_single_statistic(self, port_id_list, statistic_type): + if statistic_type == GLOBAL_STATS: + return self._generate_global_stats() + elif statistic_type == PORT_STATS: + return self._generate_port_stats(port_id_list) + pass + elif statistic_type == PORT_STATUS: + return self._generate_port_status(port_id_list) + else: + # ignore by returning empty object + return {} + + def generate_streams_info(self, port_id_list, stream_id_list): + relevant_ports = self.__get_relevant_ports(port_id_list) + + return_data = OrderedDict() + + for port_obj in relevant_ports: + streams_data = self._generate_single_port_streams_info(port_obj, stream_id_list) + if not streams_data: + continue + hdr_key = "Port {port}: {yaml_file}".format(port= port_obj.port_id, + yaml_file= streams_data.raw_data.get('referring_file', '')) + + # TODO: test for other ports with same stream structure, and join them + return_data[hdr_key] = streams_data + + return return_data + + def _generate_global_stats(self): + stats_data = self._global_stats.generate_stats() + + # build table representation + stats_table = text_tables.TRexTextInfo() + stats_table.set_cols_align(["l", "l"]) + + stats_table.add_rows([[k.replace("_", " ").title(), v] + for k, v in stats_data.iteritems()], + header=False) + + return {"global_statistics": ExportableStats(stats_data, stats_table)} + + def _generate_port_stats(self, port_id_list): + relevant_ports = self.__get_relevant_ports(port_id_list) + + return_stats_data = {} + per_field_stats = OrderedDict([("owner", []), + ("state", []), + ("--", []), + ("Tx bps", []), + ("Tx pps", []), + + ("---", []), + ("Rx bps", []), + ("Rx pps", []), + + ("----", []), + ("opackets", []), + ("ipackets", []), + ("obytes", []), + ("ibytes", []), + ("tx-bytes", []), + ("rx-bytes", []), + ("tx-pkts", []), + ("rx-pkts", []), + + ("-----", []), + ("oerrors", []), + ("ierrors", []), + + ] + ) + + total_stats = CPortStats(None) + + for port_obj in relevant_ports: + # fetch port data + port_stats = port_obj.generate_port_stats() + + total_stats += port_obj.port_stats + + # populate to data structures + return_stats_data[port_obj.port_id] = port_stats + self.__update_per_field_dict(port_stats, per_field_stats) + + total_cols = len(relevant_ports) + header = ["port"] + [port.port_id for port in relevant_ports] + + if (total_cols > 1): + self.__update_per_field_dict(total_stats.generate_stats(), per_field_stats) + header += ['total'] + total_cols += 1 + + stats_table = text_tables.TRexTextTable() + stats_table.set_cols_align(["l"] + ["r"] * total_cols) + stats_table.set_cols_width([10] + [17] * total_cols) + stats_table.set_cols_dtype(['t'] + ['t'] * total_cols) + + stats_table.add_rows([[k] + v + for k, v in per_field_stats.iteritems()], + header=False) + + stats_table.header(header) + + return {"port_statistics": ExportableStats(return_stats_data, stats_table)} + + def _generate_port_status(self, port_id_list): + relevant_ports = self.__get_relevant_ports(port_id_list) + + return_stats_data = {} + per_field_status = OrderedDict([("type", []), + ("maximum", []), + ("status", []) + ] + ) + + for port_obj in relevant_ports: + # fetch port data + # port_stats = self._async_stats.get_port_stats(port_obj.port_id) + port_status = port_obj.generate_port_status() + + # populate to data structures + return_stats_data[port_obj.port_id] = port_status + + self.__update_per_field_dict(port_status, per_field_status) + + stats_table = text_tables.TRexTextTable() + stats_table.set_cols_align(["l"] + ["c"]*len(relevant_ports)) + stats_table.set_cols_width([10] + [20] * len(relevant_ports)) + + stats_table.add_rows([[k] + v + for k, v in per_field_status.iteritems()], + header=False) + stats_table.header(["port"] + [port.port_id + for port in relevant_ports]) + + return {"port_status": ExportableStats(return_stats_data, stats_table)} + + def _generate_single_port_streams_info(self, port_obj, stream_id_list): + + return_streams_data = port_obj.generate_loaded_streams_sum(stream_id_list) + + if not return_streams_data.get("streams"): + # we got no streams available + return None + + # FORMAT VALUES ON DEMAND + + # because we mutate this - deep copy before + return_streams_data = copy.deepcopy(return_streams_data) + + for stream_id, stream_id_sum in return_streams_data['streams'].iteritems(): + stream_id_sum['rate_pps'] = format_num(stream_id_sum['rate_pps'], suffix='pps') + stream_id_sum['packet_type'] = self._trim_packet_headers(stream_id_sum['packet_type'], 20) + + info_table = text_tables.TRexTextTable() + info_table.set_cols_align(["c"] + ["l"] + ["r"] + ["c"] + ["r"] + ["c"]) + info_table.set_cols_width([10] + [20] + [8] + [16] + [10] + [12]) + info_table.set_cols_dtype(["t"] + ["t"] + ["t"] + ["t"] + ["t"] + ["t"]) + + info_table.add_rows([v.values() + for k, v in return_streams_data['streams'].iteritems()], + header=False) + info_table.header(["ID", "packet type", "length", "mode", "rate", "next stream"]) + + return ExportableStats(return_streams_data, info_table) + + + def __get_relevant_ports(self, port_id_list): + # fetch owned ports + ports = [port_obj + for _, port_obj in self._ports_dict.iteritems() + if port_obj.port_id in port_id_list] + + # display only the first FOUR options, by design + if len(ports) > 4: + print format_text("[WARNING]: ", 'magenta', 'bold'), format_text("displaying up to 4 ports", 'magenta') + ports = ports[:4] + return ports + + def __update_per_field_dict(self, dict_src_data, dict_dest_ref): + for key, val in dict_src_data.iteritems(): + if key in dict_dest_ref: + dict_dest_ref[key].append(val) + + @staticmethod + def _trim_packet_headers(headers_str, trim_limit): + if len(headers_str) < trim_limit: + # do nothing + return headers_str + else: + return (headers_str[:trim_limit-3] + "...") + + + +class CTRexStats(object): + """ This is an abstract class to represent a stats object """ + + def __init__(self): + self.reference_stats = {} + self.latest_stats = {} + self.last_update_ts = time.time() + self.history = deque(maxlen = 10) + + def __getitem__(self, item): + # override this to allow quick and clean access to fields + if not item in self.latest_stats: + return "N/A" + + # item must exist + m = re.search('_(([a-z])ps)$', item) + if m: + # this is a non-relative item + unit = m.group(2) + if unit == "b": + return self.get(item, format=True, suffix="b/sec") + elif unit == "p": + return self.get(item, format=True, suffix="pkt/sec") + else: + return self.get(item, format=True, suffix=m.group(1)) + + m = re.search('^[i|o](a-z+)$', item) + if m: + # this is a non-relative item + type = m.group(1) + if type == "bytes": + return self.get_rel(item, format=True, suffix="B") + elif type == "packets": + return self.get_rel(item, format=True, suffix="pkts") + else: + # do not format with suffix + return self.get_rel(item, format=True) + + # can't match to any known pattern, return N/A + return "N/A" + + + def generate_stats(self): + # must be implemented by designated classes (such as port/ global stats) + raise NotImplementedError() + + def update(self, snapshot): + # update + self.latest_stats = snapshot + self.history.append(snapshot) + + diff_time = time.time() - self.last_update_ts + + # 3 seconds is too much - this is the new reference + if (not self.reference_stats) or (diff_time > 3): + self.reference_stats = self.latest_stats + + self.last_update_ts = time.time() + + + def clear_stats(self): + self.reference_stats = self.latest_stats + + + def invalidate (self): + self.latest_stats = {} + + def get(self, field, format=False, suffix=""): + if not field in self.latest_stats: + return "N/A" + if not format: + return self.latest_stats[field] + else: + return format_num(self.latest_stats[field], suffix) + + def get_rel(self, field, format=False, suffix=""): + if not field in self.latest_stats: + return "N/A" + + if not format: + if not field in self.reference_stats: + print "REF: " + str(self.reference_stats) + print "BASE: " + str(self.latest_stats) + + return (self.latest_stats[field] - self.reference_stats[field]) + else: + return format_num(self.latest_stats[field] - self.reference_stats[field], suffix) + + # get trend for a field + def get_trend (self, field, use_raw = False, percision = 10.0): + if not field in self.latest_stats: + return 0 + + # not enough history - no trend + if len(self.history) < 5: + return 0 + + # absolute value is too low 0 considered noise + if self.latest_stats[field] < percision: + return 0 + + field_samples = [sample[field] for sample in self.history] + + if use_raw: + return calculate_diff_raw(field_samples) + else: + return calculate_diff(field_samples) + + + def get_trend_gui (self, field, show_value = False, use_raw = False, up_color = 'red', down_color = 'green'): + v = self.get_trend(field, use_raw) + + value = abs(v) + arrow = u'\u25b2' if v > 0 else u'\u25bc' + color = up_color if v > 0 else down_color + + # change in 1% is not meaningful + if value < 1: + return "" + + elif value > 5: + + if show_value: + return format_text(u"{0}{0}{0} {1:.2f}%".format(arrow,v), color) + else: + return format_text(u"{0}{0}{0}".format(arrow), color) + + elif value > 2: + + if show_value: + return format_text(u"{0}{0} {1:.2f}%".format(arrow,v), color) + else: + return format_text(u"{0}{0}".format(arrow), color) + + else: + if show_value: + return format_text(u"{0} {1:.2f}%".format(arrow,v), color) + else: + return format_text(u"{0}".format(arrow), color) + + + +class CGlobalStats(CTRexStats): + + def __init__(self, connection_info, server_version, ports_dict_ref): + super(CGlobalStats, self).__init__() + self.connection_info = connection_info + self.server_version = server_version + self._ports_dict = ports_dict_ref + + def get_stats (self): + stats = {} + + # absolute + stats['cpu_util'] = self.get("m_cpu_util") + stats['tx_bps'] = self.get("m_tx_bps") + stats['tx_pps'] = self.get("m_tx_pps") + + stats['rx_bps'] = self.get("m_rx_bps") + stats['rx_pps'] = self.get("m_rx_pps") + stats['rx_drop_bps'] = self.get("m_rx_drop_bps") + + # relatives + stats['queue_full'] = self.get_rel("m_total_queue_full") + + return stats + + + def generate_stats(self): + return OrderedDict([("connection", "{host}, Port {port}".format(host=self.connection_info.get("server"), + port=self.connection_info.get("sync_port"))), + ("version", "{ver}, UUID: {uuid}".format(ver=self.server_version.get("version", "N/A"), + uuid="N/A")), + + ("cpu_util", u"{0}% {1}".format( format_threshold(self.get("m_cpu_util"), [85, 100], [0, 85]), + self.get_trend_gui("m_cpu_util", use_raw = True))), + + (" ", ""), + + ("total_tx", u"{0} {1}".format( self.get("m_tx_bps", format=True, suffix="b/sec"), + self.get_trend_gui("m_tx_bps"))), + + ("total_rx", u"{0} {1}".format( self.get("m_rx_bps", format=True, suffix="b/sec"), + self.get_trend_gui("m_rx_bps"))), + + ("total_pps", u"{0} {1}".format( self.get("m_tx_pps", format=True, suffix="pkt/sec"), + self.get_trend_gui("m_tx_pps"))), + + (" ", ""), + + ("drop_rate", "{0}".format( format_num(self.get("m_rx_drop_bps"), + suffix = 'b/sec', + opts = 'green' if (self.get("m_rx_drop_bps")== 0) else 'red'))), + + ("queue_full", "{0}".format( format_num(self.get_rel("m_total_queue_full"), + suffix = 'pkts', + compact = False, + opts = 'green' if (self.get_rel("m_total_queue_full")== 0) else 'red'))), + + ] + ) + +class CPortStats(CTRexStats): + + def __init__(self, port_obj): + super(CPortStats, self).__init__() + self._port_obj = port_obj + + @staticmethod + def __merge_dicts (target, src): + for k, v in src.iteritems(): + if k in target: + target[k] += v + else: + target[k] = v + + + def __add__ (self, x): + if not isinstance(x, CPortStats): + raise TypeError("cannot add non stats object to stats") + + # main stats + if not self.latest_stats: + self.latest_stats = {} + + self.__merge_dicts(self.latest_stats, x.latest_stats) + + # reference stats + if x.reference_stats: + if not self.reference_stats: + self.reference_stats = x.reference_stats.copy() + else: + self.__merge_dicts(self.reference_stats, x.reference_stats) + + # history + if not self.history: + self.history = copy.deepcopy(x.history) + else: + for h1, h2 in zip(self.history, x.history): + self.__merge_dicts(h1, h2) + + return self + + # for port we need to do something smarter + def get_stats (self): + stats = {} + + stats['opackets'] = self.get_rel("opackets") + stats['ipackets'] = self.get_rel("ipackets") + stats['obytes'] = self.get_rel("obytes") + stats['ibytes'] = self.get_rel("ibytes") + stats['oerrors'] = self.get_rel("oerrors") + stats['ierrors'] = self.get_rel("ierrors") + stats['tx_bps'] = self.get("m_total_tx_bps") + stats['tx_pps'] = self.get("m_total_tx_pps") + stats['rx_bps'] = self.get("m_total_rx_bps") + stats['rx_pps'] = self.get("m_total_rx_pps") + + return stats + + + def generate_stats(self): + + state = self._port_obj.get_port_state_name() if self._port_obj else "" + if state == "ACTIVE": + state = format_text(state, 'green', 'bold') + elif state == "PAUSE": + state = format_text(state, 'magenta', 'bold') + else: + state = format_text(state, 'bold') + + return {"owner": self._port_obj.user if self._port_obj else "", + "state": "{0}".format(state), + + "--": " ", + "---": " ", + "----": " ", + "-----": " ", + + "Tx bps": u"{0} {1}".format(self.get_trend_gui("m_total_tx_bps", show_value = False), + self.get("m_total_tx_bps", format = True, suffix = "bps")), + + "Rx bps": u"{0} {1}".format(self.get_trend_gui("m_total_rx_bps", show_value = False), + self.get("m_total_rx_bps", format = True, suffix = "bps")), + + "Tx pps": u"{0} {1}".format(self.get_trend_gui("m_total_tx_pps", show_value = False), + self.get("m_total_tx_pps", format = True, suffix = "pps")), + + "Rx pps": u"{0} {1}".format(self.get_trend_gui("m_total_rx_pps", show_value = False), + self.get("m_total_rx_pps", format = True, suffix = "pps")), + + "opackets" : self.get_rel("opackets"), + "ipackets" : self.get_rel("ipackets"), + "obytes" : self.get_rel("obytes"), + "ibytes" : self.get_rel("ibytes"), + + "tx-bytes": self.get_rel("obytes", format = True, suffix = "B"), + "rx-bytes": self.get_rel("ibytes", format = True, suffix = "B"), + "tx-pkts": self.get_rel("opackets", format = True, suffix = "pkts"), + "rx-pkts": self.get_rel("ipackets", format = True, suffix = "pkts"), + + "oerrors" : format_num(self.get_rel("oerrors"), + compact = False, + opts = 'green' if (self.get_rel("oerrors")== 0) else 'red'), + + "ierrors" : format_num(self.get_rel("ierrors"), + compact = False, + opts = 'green' if (self.get_rel("ierrors")== 0) else 'red'), + + } + + + +if __name__ == "__main__": + pass diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_std.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_std.py new file mode 100644 index 00000000..72a5ea52 --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_std.py @@ -0,0 +1,67 @@ +from trex_stl_streams import * +from trex_stl_packet_builder_scapy import * + +# map ports +# will destroy all streams/data on the ports +def stl_map_ports (client, ports = None): + + # by default use all ports + if ports == None: + ports = client.get_all_ports() + + # reset the ports + client.reset(ports) + + # generate streams + base_pkt = CScapyTRexPktBuilder(pkt = Ether()/IP()) + + pkts = 1 + for port in ports: + stream = STLStream(packet = base_pkt, + mode = STLTXSingleBurst(pps = 100000, total_pkts = pkts)) + + client.add_streams(stream, [port]) + pkts = pkts * 2 + + # inject + client.clear_stats() + client.start(ports, mult = "1mpps") + client.wait_on_traffic(ports) + + stats = client.get_stats() + + # cleanup + client.reset(ports = ports) + + table = {} + for port in ports: + table[port] = None + + for port in ports: + ipackets = stats[port]["ipackets"] + + exp = 1 + while ipackets >= exp: + if ((ipackets & exp) == (exp)): + source = int(math.log(exp, 2)) + table[source] = port + + exp *= 2 + + if not all(x != None for x in table.values()): + raise STLError('unable to map ports') + + dir_a = set() + dir_b = set() + for src, dst in table.iteritems(): + # src is not in + if src not in (dir_a, dir_b): + if dst in dir_a: + dir_b.add(src) + else: + dir_a.add(src) + + table['dir'] = [list(dir_a), list(dir_b)] + + return table + diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py new file mode 100644 index 00000000..d8e86fef --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py @@ -0,0 +1,259 @@ +#!/router/bin/python + +from trex_stl_exceptions import * +from trex_stl_packet_builder_interface import CTrexPktBuilderInterface +from trex_stl_packet_builder_scapy import CScapyTRexPktBuilder, Ether, IP +from collections import OrderedDict, namedtuple + +from dpkt import pcap +import random +import yaml +import base64 +import string + +def random_name (l): + return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(l)) + + +# base class for TX mode +class STLTXMode(object): + def __init__ (self): + self.fields = {} + + def to_json (self): + return self.fields + + +# continuous mode +class STLTXCont(STLTXMode): + + def __init__ (self, pps = 1): + + if not isinstance(pps, (int, float)): + raise STLArgumentError('pps', pps) + + super(STLTXCont, self).__init__() + + self.fields['type'] = 'continuous' + self.fields['pps'] = pps + + +# single burst mode +class STLTXSingleBurst(STLTXMode): + + def __init__ (self, pps = 1, total_pkts = 1): + + if not isinstance(pps, (int, float)): + raise STLArgumentError('pps', pps) + + if not isinstance(total_pkts, int): + raise STLArgumentError('total_pkts', total_pkts) + + super(STLTXSingleBurst, self).__init__() + + self.fields['type'] = 'single_burst' + self.fields['pps'] = pps + self.fields['total_pkts'] = total_pkts + + +# multi burst mode +class STLTXMultiBurst(STLTXMode): + + def __init__ (self, + pps = 1, + pkts_per_burst = 1, + ibg = 0.0, + count = 1): + + if not isinstance(pps, (int, float)): + raise STLArgumentError('pps', pps) + + if not isinstance(pkts_per_burst, int): + raise STLArgumentError('pkts_per_burst', pkts_per_burst) + + if not isinstance(ibg, (int, float)): + raise STLArgumentError('ibg', ibg) + + if not isinstance(count, int): + raise STLArgumentError('count', count) + + super(STLTXMultiBurst, self).__init__() + + self.fields['type'] = 'multi_burst' + self.fields['pps'] = pps + self.fields['pkts_per_burst'] = pkts_per_burst + self.fields['ibg'] = ibg + self.fields['count'] = count + + +class STLStream(object): + + def __init__ (self, + name = random_name(8), + packet = None, + mode = STLTXCont(1), + enabled = True, + self_start = True, + isg = 0.0, + rx_stats = None, + next = None, + stream_id = None): + + # type checking + if not isinstance(mode, STLTXMode): + raise STLArgumentError('mode', mode) + + if packet and not isinstance(packet, CTrexPktBuilderInterface): + raise STLArgumentError('packet', packet) + + if not isinstance(enabled, bool): + raise STLArgumentError('enabled', enabled) + + if not isinstance(self_start, bool): + raise STLArgumentError('self_start', self_start) + + if not isinstance(isg, (int, float)): + raise STLArgumentError('isg', isg) + + if (type(mode) == STLTXCont) and (next != None): + raise STLError("continuous stream cannot have a next stream ID") + + # tag for the stream and next - can be anything + self.name = name + self.next = next + self.set_id(stream_id) + + self.fields = {} + + # basic fields + self.fields['enabled'] = enabled + self.fields['self_start'] = self_start + self.fields['isg'] = isg + + # mode + self.fields['mode'] = mode.to_json() + + self.fields['packet'] = {} + self.fields['vm'] = {} + + if not packet: + packet = CScapyTRexPktBuilder(pkt = Ether()/IP()) + + # packet builder + packet.compile() + + # packet and VM + self.fields['packet'] = packet.dump_pkt() + self.fields['vm'] = packet.get_vm_data() + + self.fields['rx_stats'] = {} + if not rx_stats: + self.fields['rx_stats']['enabled'] = False + + + def __str__ (self): + return json.dumps(self.fields, indent = 4, separators=(',', ': '), sort_keys = True) + + def to_json (self): + return self.fields + + def get_id (self): + return self.id + + def set_id (self, id): + self.id = id + + def get_name (self): + return self.name + + def get_next (self): + return self.next + + + def to_yaml (self): + fields = dict(stream.fields) + + + @staticmethod + def dump_to_yaml (stream_list, yaml_file = None): + + # type check + if isinstance(stream_list, STLStream): + stream_list = [stream_list] + + if not all([isinstance(stream, STLStream) for stream in stream_list]): + raise STLArgumentError('stream_list', stream_list) + + + names = {} + for i, stream in enumerate(stream_list): + names[stream.get_id()] = "stream-{0}".format(i) + + yaml_lst = [] + for stream in stream_list: + + fields = dict(stream.fields) + + # handle the next stream id + if fields['next_stream_id'] == -1: + del fields['next_stream_id'] + + else: + if not stream.get_id() in names: + raise STLError('broken dependencies in stream list') + + fields['next_stream'] = names[stream.get_id()] + + # add to list + yaml_lst.append({'name': names[stream.get_id()], 'stream': fields}) + + # write to file + x = yaml.dump(yaml_lst, default_flow_style=False) + if yaml_file: + with open(yaml_file, 'w') as f: + f.write(x) + + return x + + + @staticmethod + def load_from_yaml (yaml_file): + + with open(yaml_file, 'r') as f: + yaml_str = f.read() + + + # load YAML + lst = yaml.load(yaml_str) + + # decode to streams + streams = [] + for stream in lst: + # for defaults + defaults = STLStream() + s = STLStream(packet = None, + mode = STLTXCont(1), + enabled = True, + self_start = True, + isg = 0.0, + rx_stats = None, + next_stream_id = -1, + stream_id = None + ) + + streams.append(s) + + return streams + + +class STLYAMLLoader(object): + def __init__ (self, yaml_file): + self.yaml_file = yaml_file + + def load (self): + with open(self.yaml_file, 'r') as f: + yaml_str = f.read() + objects = yaml.load(yaml_str) + + for object in objects: + pass diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_types.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_types.py new file mode 100644 index 00000000..1164076b --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_types.py @@ -0,0 +1,95 @@ + +from collections import namedtuple +from utils.text_opts import * + +RpcCmdData = namedtuple('RpcCmdData', ['method', 'params']) + +class RpcResponseStatus(namedtuple('RpcResponseStatus', ['success', 'id', 'msg'])): + __slots__ = () + def __str__(self): + return "{id:^3} - {msg} ({stat})".format(id=self.id, + msg=self.msg, + stat="success" if self.success else "fail") + +# simple class to represent complex return value +class RC(): + + def __init__ (self, rc = None, data = None, is_warn = False): + self.rc_list = [] + + if (rc != None): + tuple_rc = namedtuple('RC', ['rc', 'data', 'is_warn']) + self.rc_list.append(tuple_rc(rc, data, is_warn)) + + def __nonzero__ (self): + return self.good() + + + def add (self, rc): + self.rc_list += rc.rc_list + + def good (self): + return all([x.rc for x in self.rc_list]) + + def bad (self): + return not self.good() + + def warn (self): + return any([x.is_warn for x in self.rc_list]) + + def data (self): + d = [x.data if x.rc else "" for x in self.rc_list] + return (d if len(d) != 1 else d[0]) + + def err (self): + e = [x.data if not x.rc else "" for x in self.rc_list] + return (e if len(e) != 1 else e[0]) + + def __str__ (self): + s = "" + for x in self.rc_list: + if x.data: + s += format_text("\n{0}".format(x.data), 'bold') + return s + + def prn_func (self, msg, newline = True): + if newline: + print msg + else: + print msg, + + def annotate (self, log_func = None, desc = None, show_status = True): + + if not log_func: + log_func = self.prn_func + + if desc: + log_func(format_text('\n{:<60}'.format(desc), 'bold'), newline = False) + else: + log_func("") + + if self.bad(): + # print all the errors + print "" + for x in self.rc_list: + if not x.rc: + log_func(format_text("\n{0}".format(x.data), 'bold')) + + print "" + if show_status: + log_func(format_text("[FAILED]\n", 'red', 'bold')) + + + else: + if show_status: + log_func(format_text("[SUCCESS]\n", 'green', 'bold')) + + +def RC_OK(data = ""): + return RC(True, data) + +def RC_ERR (err): + return RC(False, err) + +def RC_WARN (warn): + return RC(True, warn, is_warn = True) diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/__init__.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/__init__.py diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/common.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/common.py new file mode 100644 index 00000000..117017c3 --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/common.py @@ -0,0 +1,47 @@ +import os +import sys +import string +import random + +try: + import pwd +except ImportError: + import getpass + pwd = None + +using_python_3 = True if sys.version_info.major == 3 else False + +def get_current_user(): + if pwd: + return pwd.getpwuid(os.geteuid()).pw_name + else: + return getpass.getuser() + + +def user_input(): + if using_python_3: + return input() + else: + # using python version 2 + return raw_input() + + +def random_id_gen(length=8): + """ + A generator for creating a random chars id of specific length + + :parameters: + length : int + the desired length of the generated id + + default: 8 + + :return: + a random id with each next() request. + """ + id_chars = string.ascii_lowercase + string.digits + while True: + return_id = '' + for i in range(length): + return_id += random.choice(id_chars) + yield return_id diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/parsing_opts.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/parsing_opts.py new file mode 100755 index 00000000..968bbb7e --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/parsing_opts.py @@ -0,0 +1,362 @@ +import argparse +from collections import namedtuple +import sys +import re +import os + +ArgumentPack = namedtuple('ArgumentPack', ['name_or_flags', 'options']) +ArgumentGroup = namedtuple('ArgumentGroup', ['type', 'args', 'options']) + + +# list of available parsing options +MULTIPLIER = 1 +MULTIPLIER_STRICT = 2 +PORT_LIST = 3 +ALL_PORTS = 4 +PORT_LIST_WITH_ALL = 5 +FILE_PATH = 6 +FILE_FROM_DB = 7 +SERVER_IP = 8 +STREAM_FROM_PATH_OR_FILE = 9 +DURATION = 10 +FORCE = 11 +DRY_RUN = 12 +XTERM = 13 +TOTAL = 14 +FULL_OUTPUT = 15 + +GLOBAL_STATS = 50 +PORT_STATS = 51 +PORT_STATUS = 52 +STATS_MASK = 53 + +STREAMS_MASK = 60 +# ALL_STREAMS = 61 +# STREAM_LIST_WITH_ALL = 62 + + + +# list of ArgumentGroup types +MUTEX = 1 + +def check_negative(value): + ivalue = int(value) + if ivalue < 0: + raise argparse.ArgumentTypeError("non positive value provided: '{0}'".format(value)) + return ivalue + +def match_time_unit(val): + '''match some val against time shortcut inputs ''' + match = re.match("^(\d+(\.\d+)?)([m|h]?)$", val) + if match: + digit = float(match.group(1)) + unit = match.group(3) + if not unit: + return digit + elif unit == 'm': + return digit*60 + else: + return digit*60*60 + else: + raise argparse.ArgumentTypeError("Duration should be passed in the following format: \n" + "-d 100 : in sec \n" + "-d 10m : in min \n" + "-d 1h : in hours") + +match_multiplier_help = """Multiplier should be passed in the following format: + [number][<empty> | bps | kbps | mbps | gbps | pps | kpps | mpps | %% ]. + no suffix will provide an absoulute factor and percentage + will provide a percentage of the line rate. examples + : '-m 10', '-m 10kbps', '-m 10mpps', '-m 23%%' """ + + +# decodes multiplier +# if allow_update - no +/- is allowed +# divide states between how many entities the +# value should be divided +def decode_multiplier(val, allow_update = False, divide_count = 1): + + # must be string + if not isinstance(val, str): + return None + + # do we allow updates ? +/- + if not allow_update: + match = re.match("^(\d+(\.\d+)?)(bps|kbps|mbps|gbps|pps|kpps|mpps|%?)$", val) + op = None + else: + match = re.match("^(\d+(\.\d+)?)(bps|kbps|mbps|gbps|pps|kpps|mpps|%?)([\+\-])?$", val) + if match: + op = match.group(4) + else: + op = None + + result = {} + + if match: + + value = float(match.group(1)) + unit = match.group(3) + + + + # raw type (factor) + if not unit: + result['type'] = 'raw' + result['value'] = value + + elif unit == 'bps': + result['type'] = 'bps' + result['value'] = value + + elif unit == 'kbps': + result['type'] = 'bps' + result['value'] = value * 1000 + + elif unit == 'mbps': + result['type'] = 'bps' + result['value'] = value * 1000 * 1000 + + elif unit == 'gbps': + result['type'] = 'bps' + result['value'] = value * 1000 * 1000 * 1000 + + elif unit == 'pps': + result['type'] = 'pps' + result['value'] = value + + elif unit == "kpps": + result['type'] = 'pps' + result['value'] = value * 1000 + + elif unit == "mpps": + result['type'] = 'pps' + result['value'] = value * 1000 * 1000 + + elif unit == "%": + result['type'] = 'percentage' + result['value'] = value + + + if op == "+": + result['op'] = "add" + elif op == "-": + result['op'] = "sub" + else: + result['op'] = "abs" + + if result['op'] != 'percentage': + result['value'] = result['value'] / divide_count + + return result + + else: + return None + + +def match_multiplier(val): + '''match some val against multiplier shortcut inputs ''' + result = decode_multiplier(val, allow_update = True) + if not result: + raise argparse.ArgumentTypeError(match_multiplier_help) + + return val + + +def match_multiplier_strict(val): + '''match some val against multiplier shortcut inputs ''' + result = decode_multiplier(val, allow_update = False) + if not result: + raise argparse.ArgumentTypeError(match_multiplier_help) + + return val + + +def is_valid_file(filename): + if not os.path.isfile(filename): + raise argparse.ArgumentTypeError("The file '%s' does not exist" % filename) + + return filename + + +OPTIONS_DB = {MULTIPLIER: ArgumentPack(['-m', '--multiplier'], + {'help': match_multiplier_help, + 'dest': "mult", + 'default': "1", + 'type': match_multiplier}), + + MULTIPLIER_STRICT: ArgumentPack(['-m', '--multiplier'], + {'help': match_multiplier_help, + 'dest': "mult", + 'default': "1", + 'type': match_multiplier_strict}), + + TOTAL: ArgumentPack(['-t', '--total'], + {'help': "traffic will be divided between all ports specified", + 'dest': "total", + 'default': False, + 'action': "store_true"}), + + PORT_LIST: ArgumentPack(['--port'], + {"nargs": '+', + 'dest':'ports', + 'metavar': 'PORTS', + 'type': int, + 'help': "A list of ports on which to apply the command", + 'default': []}), + + ALL_PORTS: ArgumentPack(['-a'], + {"action": "store_true", + "dest": "all_ports", + 'help': "Set this flag to apply the command on all available ports", + 'default': False},), + + DURATION: ArgumentPack(['-d'], + {'action': "store", + 'metavar': 'TIME', + 'dest': 'duration', + 'type': match_time_unit, + 'default': -1.0, + 'help': "Set duration time for TRex."}), + + FORCE: ArgumentPack(['--force'], + {"action": "store_true", + 'default': False, + 'help': "Set if you want to stop active ports before applying new TRex run on them."}), + + FILE_PATH: ArgumentPack(['-f'], + {'metavar': 'FILE', + 'dest': 'file', + 'nargs': 1, + 'type': is_valid_file, + 'help': "File path to YAML file that describes a stream pack. "}), + + FILE_FROM_DB: ArgumentPack(['--db'], + {'metavar': 'LOADED_STREAM_PACK', + 'help': "A stream pack which already loaded into console cache."}), + + SERVER_IP: ArgumentPack(['--server'], + {'metavar': 'SERVER', + 'help': "server IP"}), + + DRY_RUN: ArgumentPack(['-n', '--dry'], + {'action': 'store_true', + 'dest': 'dry', + 'default': False, + 'help': "Dry run - no traffic will be injected"}), + + + XTERM: ArgumentPack(['-x', '--xterm'], + {'action': 'store_true', + 'dest': 'xterm', + 'default': False, + 'help': "Starts TUI in xterm window"}), + + + FULL_OUTPUT: ArgumentPack(['--full'], + {'action': 'store_true', + 'help': "Prompt full info in a JSON format"}), + + GLOBAL_STATS: ArgumentPack(['-g'], + {'action': 'store_true', + 'help': "Fetch only global statistics"}), + + PORT_STATS: ArgumentPack(['-p'], + {'action': 'store_true', + 'help': "Fetch only port statistics"}), + + PORT_STATUS: ArgumentPack(['--ps'], + {'action': 'store_true', + 'help': "Fetch only port status data"}), + + STREAMS_MASK: ArgumentPack(['--streams'], + {"nargs": '+', + 'dest':'streams', + 'metavar': 'STREAMS', + 'type': int, + 'help': "A list of stream IDs to query about. Default: analyze all streams", + 'default': []}), + + + # advanced options + PORT_LIST_WITH_ALL: ArgumentGroup(MUTEX, [PORT_LIST, + ALL_PORTS], + {'required': False}), + + STREAM_FROM_PATH_OR_FILE: ArgumentGroup(MUTEX, [FILE_PATH, + FILE_FROM_DB], + {'required': True}), + STATS_MASK: ArgumentGroup(MUTEX, [GLOBAL_STATS, + PORT_STATS, + PORT_STATUS], + {}) + } + + +class CCmdArgParser(argparse.ArgumentParser): + + def __init__(self, stateless_client, *args, **kwargs): + super(CCmdArgParser, self).__init__(*args, **kwargs) + self.stateless_client = stateless_client + + def parse_args(self, args=None, namespace=None): + try: + opts = super(CCmdArgParser, self).parse_args(args, namespace) + if opts is None: + return None + + # if all ports are marked or + if (getattr(opts, "all_ports", None) == True) or (getattr(opts, "ports", None) == []): + opts.ports = self.stateless_client.get_all_ports() + + # so maybe we have ports configured + elif getattr(opts, "ports", None): + for port in opts.ports: + if not self.stateless_client._validate_port_list([port]): + self.error("port id '{0}' is not a valid port id\n".format(port)) + + return opts + + except SystemExit: + # recover from system exit scenarios, such as "help", or bad arguments. + return None + + +def get_flags (opt): + return OPTIONS_DB[opt].name_or_flags + +def gen_parser(stateless_client, op_name, description, *args): + parser = CCmdArgParser(stateless_client, prog=op_name, conflict_handler='resolve', + description=description) + for param in args: + try: + + if isinstance(param, int): + argument = OPTIONS_DB[param] + else: + argument = param + + if isinstance(argument, ArgumentGroup): + if argument.type == MUTEX: + # handle as mutually exclusive group + group = parser.add_mutually_exclusive_group(**argument.options) + for sub_argument in argument.args: + group.add_argument(*OPTIONS_DB[sub_argument].name_or_flags, + **OPTIONS_DB[sub_argument].options) + else: + # ignore invalid objects + continue + elif isinstance(argument, ArgumentPack): + parser.add_argument(*argument.name_or_flags, + **argument.options) + else: + # ignore invalid objects + continue + except KeyError as e: + cause = e.args[0] + raise KeyError("The attribute '{0}' is missing as a field of the {1} option.\n".format(cause, param)) + return parser + + +if __name__ == "__main__": + pass
\ No newline at end of file diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_opts.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_opts.py new file mode 100644 index 00000000..78a0ab1f --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_opts.py @@ -0,0 +1,192 @@ +import json +import re + +TEXT_CODES = {'bold': {'start': '\x1b[1m', + 'end': '\x1b[22m'}, + 'cyan': {'start': '\x1b[36m', + 'end': '\x1b[39m'}, + 'blue': {'start': '\x1b[34m', + 'end': '\x1b[39m'}, + 'red': {'start': '\x1b[31m', + 'end': '\x1b[39m'}, + 'magenta': {'start': '\x1b[35m', + 'end': '\x1b[39m'}, + 'green': {'start': '\x1b[32m', + 'end': '\x1b[39m'}, + 'yellow': {'start': '\x1b[33m', + 'end': '\x1b[39m'}, + 'underline': {'start': '\x1b[4m', + 'end': '\x1b[24m'}} + +class TextCodesStripper: + keys = [re.escape(v['start']) for k,v in TEXT_CODES.iteritems()] + keys += [re.escape(v['end']) for k,v in TEXT_CODES.iteritems()] + pattern = re.compile("|".join(keys)) + + @staticmethod + def strip (s): + return re.sub(TextCodesStripper.pattern, '', s) + +def format_num (size, suffix = "", compact = True, opts = ()): + txt = "NaN" + + if type(size) == str: + return "N/A" + + u = '' + + if compact: + for unit in ['','K','M','G','T','P']: + if abs(size) < 1000.0: + u = unit + break + size /= 1000.0 + + if isinstance(size, float): + txt = "%3.2f" % (size) + else: + txt = "{:,}".format(size) + + if u or suffix: + txt += " {:}{:}".format(u, suffix) + + if isinstance(opts, tuple): + return format_text(txt, *opts) + else: + return format_text(txt, (opts)) + + + +def format_time (t_sec): + if t_sec < 0: + return "infinite" + + if t_sec < 1: + # low numbers + for unit in ['ms', 'usec', 'ns']: + t_sec *= 1000.0 + if t_sec >= 1.0: + return '{:,.2f} [{:}]'.format(t_sec, unit) + + return "NaN" + + else: + # seconds + if t_sec < 60.0: + return '{:,.2f} [{:}]'.format(t_sec, 'sec') + + # minutes + t_sec /= 60.0 + if t_sec < 60.0: + return '{:,.2f} [{:}]'.format(t_sec, 'minutes') + + # hours + t_sec /= 60.0 + if t_sec < 24.0: + return '{:,.2f} [{:}]'.format(t_sec, 'hours') + + # days + t_sec /= 24.0 + return '{:,.2f} [{:}]'.format(t_sec, 'days') + + +def format_percentage (size): + return "%0.2f %%" % (size) + +def bold(text): + return text_attribute(text, 'bold') + + +def cyan(text): + return text_attribute(text, 'cyan') + + +def blue(text): + return text_attribute(text, 'blue') + + +def red(text): + return text_attribute(text, 'red') + + +def magenta(text): + return text_attribute(text, 'magenta') + + +def green(text): + return text_attribute(text, 'green') + +def yellow(text): + return text_attribute(text, 'yellow') + +def underline(text): + return text_attribute(text, 'underline') + + +def text_attribute(text, attribute): + if isinstance(text, str): + return "{start}{txt}{stop}".format(start=TEXT_CODES[attribute]['start'], + txt=text, + stop=TEXT_CODES[attribute]['end']) + elif isinstance(text, unicode): + return u"{start}{txt}{stop}".format(start=TEXT_CODES[attribute]['start'], + txt=text, + stop=TEXT_CODES[attribute]['end']) + else: + raise Exception("not a string") + + +FUNC_DICT = {'blue': blue, + 'bold': bold, + 'green': green, + 'yellow': yellow, + 'cyan': cyan, + 'magenta': magenta, + 'underline': underline, + 'red': red} + + +def format_text(text, *args): + return_string = text + for i in args: + func = FUNC_DICT.get(i) + if func: + return_string = func(return_string) + + return return_string + + +def format_threshold (value, red_zone, green_zone): + if value >= red_zone[0] and value <= red_zone[1]: + return format_text("{0}".format(value), 'red') + + if value >= green_zone[0] and value <= green_zone[1]: + return format_text("{0}".format(value), 'green') + + return "{0}".format(value) + +# pretty print for JSON +def pretty_json (json_str, use_colors = True): + pretty_str = json.dumps(json.loads(json_str), indent = 4, separators=(',', ': '), sort_keys = True) + + if not use_colors: + return pretty_str + + try: + # int numbers + pretty_str = re.sub(r'([ ]*:[ ]+)(\-?[1-9][0-9]*[^.])',r'\1{0}'.format(blue(r'\2')), pretty_str) + # float + pretty_str = re.sub(r'([ ]*:[ ]+)(\-?[1-9][0-9]*\.[0-9]+)',r'\1{0}'.format(magenta(r'\2')), pretty_str) + # # strings + # + pretty_str = re.sub(r'([ ]*:[ ]+)("[^"]*")',r'\1{0}'.format(red(r'\2')), pretty_str) + pretty_str = re.sub(r"('[^']*')", r'{0}\1{1}'.format(TEXT_CODES['magenta']['start'], + TEXT_CODES['red']['start']), pretty_str) + except : + pass + + return pretty_str + + +if __name__ == "__main__": + pass diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_tables.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_tables.py new file mode 100644 index 00000000..07753fda --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_tables.py @@ -0,0 +1,31 @@ +from texttable import Texttable +from text_opts import format_text + +class TRexTextTable(Texttable): + + def __init__(self): + Texttable.__init__(self) + # set class attributes so that it'll be more like TRex standard output + self.set_chars(['-', '|', '-', '-']) + self.set_deco(Texttable.HEADER | Texttable.VLINES) + +class TRexTextInfo(Texttable): + + def __init__(self): + Texttable.__init__(self) + # set class attributes so that it'll be more like TRex standard output + self.set_chars(['-', ':', '-', '-']) + self.set_deco(Texttable.VLINES) + +def generate_trex_stats_table(): + pass + +def print_table_with_header(texttable_obj, header="", untouched_header=""): + header = header.replace("_", " ").title() + untouched_header + print format_text(header, 'cyan', 'underline') + "\n" + + print (texttable_obj.draw() + "\n").encode('utf-8') + +if __name__ == "__main__": + pass + |