diff options
Diffstat (limited to 'scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py')
-rw-r--r-- | scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py | 386 |
1 files changed, 386 insertions, 0 deletions
diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py new file mode 100644 index 00000000..abfa32cd --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py @@ -0,0 +1,386 @@ +#!/router/bin/python + +from trex_stl_exceptions import * +from trex_stl_packet_builder_interface import CTrexPktBuilderInterface +from trex_stl_packet_builder_scapy import CScapyTRexPktBuilder, Ether, IP +from collections import OrderedDict, namedtuple + +from dpkt import pcap +import random +import yaml +import base64 +import string +import traceback + +def random_name (l): + return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(l)) + + +# base class for TX mode +class STLTXMode(object): + def __init__ (self): + self.fields = {} + + def to_json (self): + return self.fields + + +# continuous mode +class STLTXCont(STLTXMode): + + def __init__ (self, pps = 1): + + if not isinstance(pps, (int, float)): + raise STLArgumentError('pps', pps) + + super(STLTXCont, self).__init__() + + self.fields['type'] = 'continuous' + self.fields['pps'] = pps + + +# single burst mode +class STLTXSingleBurst(STLTXMode): + + def __init__ (self, pps = 1, total_pkts = 1): + + if not isinstance(pps, (int, float)): + raise STLArgumentError('pps', pps) + + if not isinstance(total_pkts, int): + raise STLArgumentError('total_pkts', total_pkts) + + super(STLTXSingleBurst, self).__init__() + + self.fields['type'] = 'single_burst' + self.fields['pps'] = pps + self.fields['total_pkts'] = total_pkts + + +# multi burst mode +class STLTXMultiBurst(STLTXMode): + + def __init__ (self, + pps = 1, + pkts_per_burst = 1, + ibg = 0.0, + count = 1): + + if not isinstance(pps, (int, float)): + raise STLArgumentError('pps', pps) + + if not isinstance(pkts_per_burst, int): + raise STLArgumentError('pkts_per_burst', pkts_per_burst) + + if not isinstance(ibg, (int, float)): + raise STLArgumentError('ibg', ibg) + + if not isinstance(count, int): + raise STLArgumentError('count', count) + + super(STLTXMultiBurst, self).__init__() + + self.fields['type'] = 'multi_burst' + self.fields['pps'] = pps + self.fields['pkts_per_burst'] = pkts_per_burst + self.fields['ibg'] = ibg + self.fields['count'] = count + + +class STLStream(object): + + def __init__ (self, + name = random_name(8), + packet = None, + mode = STLTXCont(1), + enabled = True, + self_start = True, + isg = 0.0, + rx_stats = None, + next = None, + stream_id = None): + + # type checking + if not isinstance(mode, STLTXMode): + raise STLArgumentError('mode', mode) + + if packet and not isinstance(packet, CTrexPktBuilderInterface): + raise STLArgumentError('packet', packet) + + if not isinstance(enabled, bool): + raise STLArgumentError('enabled', enabled) + + if not isinstance(self_start, bool): + raise STLArgumentError('self_start', self_start) + + if not isinstance(isg, (int, float)): + raise STLArgumentError('isg', isg) + + if (type(mode) == STLTXCont) and (next != None): + raise STLError("continuous stream cannot have a next stream ID") + + # tag for the stream and next - can be anything + self.name = name + self.next = next + self.set_id(stream_id) + + self.fields = {} + + # basic fields + self.fields['enabled'] = enabled + self.fields['self_start'] = self_start + self.fields['isg'] = isg + + # mode + self.fields['mode'] = mode.to_json() + + self.fields['packet'] = {} + self.fields['vm'] = {} + + if not packet: + packet = CScapyTRexPktBuilder(pkt = Ether()/IP()) + + # packet builder + packet.compile() + + # packet and VM + self.fields['packet'] = packet.dump_pkt() + self.fields['vm'] = packet.get_vm_data() + + if not rx_stats: + self.fields['rx_stats'] = {} + self.fields['rx_stats']['enabled'] = False + else: + self.fields['rx_stats'] = rx_stats + + + def __str__ (self): + return json.dumps(self.fields, indent = 4, separators=(',', ': '), sort_keys = True) + + def to_json (self): + return self.fields + + def get_id (self): + return self.id + + def set_id (self, id): + self.id = id + + def get_name (self): + return self.name + + def get_next (self): + return self.next + + + def to_yaml (self): + return {'name': self.name, 'stream': self.fields} + + + +class YAMLLoader(object): + + def __init__ (self, yaml_file): + self.yaml_path = os.path.dirname(yaml_file) + self.yaml_file = yaml_file + + + def __parse_packet (self, packet_dict): + builder = CScapyTRexPktBuilder() + + packet_type = set(packet_dict).intersection(['binary', 'pcap']) + if len(packet_type) != 1: + raise STLError("packet section must contain either 'binary' or 'pcap'") + + if 'binary' in packet_type: + try: + pkt_str = base64.b64decode(packet_dict['binary']) + except TypeError: + raise STLError("'binary' field is not a valid packet format") + + builder.set_pkt_as_str(pkt_str) + + elif 'pcap' in packet_type: + pcap = os.path.join(self.yaml_path, packet_dict['pcap']) + + if not os.path.exists(pcap): + raise STLError("'pcap' - cannot find '{0}'".format(pcap)) + + builder.set_packet(pcap) + + return builder + + + def __parse_mode (self, mode_obj): + + mode_type = mode_obj.get('type') + + if mode_type == 'continuous': + defaults = STLTXCont() + mode = STLTXCont(pps = mode_obj.get('pps', defaults.fields['pps'])) + + elif mode_type == 'single_burst': + defaults = STLTXSingleBurst() + mode = STLTXSingleBurst(pps = mode_obj.get('pps', defaults.fields['pps']), + total_pkts = mode_obj.get('total_pkts', defaults.fields['total_pkts'])) + + elif mode_type == 'multi_burst': + defaults = STLTXMultiBurst() + mode = STLTXMultiBurst(pps = mode_obj.get('pps', defaults.fields['pps']), + pkts_per_burst = mode_obj.get('pkts_per_burst', defaults.fields['pkts_per_burst']), + ibg = mode_obj.get('ibg', defaults.fields['ibg']), + count = mode_obj.get('count', defaults.fields['count'])) + + else: + raise STLError("mode type can be 'continuous', 'single_burst' or 'multi_burst") + + + return mode + + + def __parse_stream (self, yaml_object): + s_obj = yaml_object['stream'] + + # parse packet + packet = s_obj.get('packet') + if not packet: + raise STLError("YAML file must contain 'packet' field") + + builder = self.__parse_packet(packet) + + + # mode + mode_obj = s_obj.get('mode') + if not mode_obj: + raise STLError("YAML file must contain 'mode' field") + + mode = self.__parse_mode(mode_obj) + + + defaults = STLStream() + + # create the stream + stream = STLStream(name = yaml_object.get('name'), + packet = builder, + mode = mode, + enabled = s_obj.get('enabled', defaults.fields['enabled']), + self_start = s_obj.get('self_start', defaults.fields['self_start']), + isg = s_obj.get('isg', defaults.fields['isg']), + rx_stats = s_obj.get('rx_stats', defaults.fields['rx_stats']), + next = yaml_object.get('next')) + + # hack the VM fields for now + if 'vm' in s_obj: + stream.fields['vm'].update(s_obj['vm']) + + return stream + + + def parse (self): + with open(self.yaml_file, 'r') as f: + # read YAML and pass it down to stream object + yaml_str = f.read() + + try: + objects = yaml.load(yaml_str) + except yaml.parser.ParserError as e: + raise STLError(str(e)) + + streams = [self.__parse_stream(object) for object in objects] + + return streams + + +# profile class +class STLProfile(object): + def __init__ (self, streams = None): + if streams == None: + streams = [] + + if not type(streams) == list: + streams = [streams] + + if not all([isinstance(stream, STLStream) for stream in streams]): + raise STLArgumentError('streams', streams) + + self.streams = streams + + + def get_streams (self): + return self.streams + + def __str__ (self): + return '\n'.join([str(stream) for stream in self.streams]) + + + @staticmethod + def load_yaml (yaml_file): + # check filename + if not os.path.isfile(yaml_file): + raise STLError("file '{0}' does not exists".format(yaml_file)) + + yaml_loader = YAMLLoader(yaml_file) + streams = yaml_loader.parse() + + return STLProfile(streams) + + + @staticmethod + def load_py (python_file): + # check filename + if not os.path.isfile(python_file): + raise STLError("file '{0}' does not exists".format(python_file)) + + basedir = os.path.dirname(python_file) + sys.path.append(basedir) + + try: + file = os.path.basename(python_file).split('.')[0] + module = __import__(file, globals(), locals(), [], -1) + reload(module) # reload the update + + streams = module.register().get_streams() + + return STLProfile(streams) + + except Exception as e: + a, b, tb = sys.exc_info() + x =''.join(traceback.format_list(traceback.extract_tb(tb)[1:])) + a.__name__ + ": " + str(b) + "\n" + + summary = "\nPython Traceback follows:\n\n" + x + raise STLError(summary) + + + finally: + sys.path.remove(basedir) + + + @staticmethod + def load (filename): + x = os.path.basename(filename).split('.') + suffix = x[1] if (len(x) == 2) else None + + if suffix == 'py': + profile = STLProfile.load_py(filename) + + elif suffix == 'yaml': + profile = STLProfile.load_yaml(filename) + + else: + raise STLError("unknown profile file type: '{0}'".format(suffix)) + + return profile + + + def dump_to_yaml (self, yaml_file = None): + yaml_list = [stream.to_yaml() for stream in self.streams] + yaml_str = yaml.dump(yaml_list, default_flow_style = False) + + # write to file if provided + if yaml_file: + with open(yaml_file, 'w') as f: + f.write(yaml_str) + + return yaml_str + + |