summaryrefslogtreecommitdiffstats
path: root/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py')
-rw-r--r--scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py386
1 files changed, 386 insertions, 0 deletions
diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py
new file mode 100644
index 00000000..abfa32cd
--- /dev/null
+++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py
@@ -0,0 +1,386 @@
+#!/router/bin/python
+
+from trex_stl_exceptions import *
+from trex_stl_packet_builder_interface import CTrexPktBuilderInterface
+from trex_stl_packet_builder_scapy import CScapyTRexPktBuilder, Ether, IP
+from collections import OrderedDict, namedtuple
+
+from dpkt import pcap
+import random
+import yaml
+import base64
+import string
+import traceback
+
+def random_name (l):
+ return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(l))
+
+
+# base class for TX mode
+class STLTXMode(object):
+ def __init__ (self):
+ self.fields = {}
+
+ def to_json (self):
+ return self.fields
+
+
+# continuous mode
+class STLTXCont(STLTXMode):
+
+ def __init__ (self, pps = 1):
+
+ if not isinstance(pps, (int, float)):
+ raise STLArgumentError('pps', pps)
+
+ super(STLTXCont, self).__init__()
+
+ self.fields['type'] = 'continuous'
+ self.fields['pps'] = pps
+
+
+# single burst mode
+class STLTXSingleBurst(STLTXMode):
+
+ def __init__ (self, pps = 1, total_pkts = 1):
+
+ if not isinstance(pps, (int, float)):
+ raise STLArgumentError('pps', pps)
+
+ if not isinstance(total_pkts, int):
+ raise STLArgumentError('total_pkts', total_pkts)
+
+ super(STLTXSingleBurst, self).__init__()
+
+ self.fields['type'] = 'single_burst'
+ self.fields['pps'] = pps
+ self.fields['total_pkts'] = total_pkts
+
+
+# multi burst mode
+class STLTXMultiBurst(STLTXMode):
+
+ def __init__ (self,
+ pps = 1,
+ pkts_per_burst = 1,
+ ibg = 0.0,
+ count = 1):
+
+ if not isinstance(pps, (int, float)):
+ raise STLArgumentError('pps', pps)
+
+ if not isinstance(pkts_per_burst, int):
+ raise STLArgumentError('pkts_per_burst', pkts_per_burst)
+
+ if not isinstance(ibg, (int, float)):
+ raise STLArgumentError('ibg', ibg)
+
+ if not isinstance(count, int):
+ raise STLArgumentError('count', count)
+
+ super(STLTXMultiBurst, self).__init__()
+
+ self.fields['type'] = 'multi_burst'
+ self.fields['pps'] = pps
+ self.fields['pkts_per_burst'] = pkts_per_burst
+ self.fields['ibg'] = ibg
+ self.fields['count'] = count
+
+
+class STLStream(object):
+
+ def __init__ (self,
+ name = random_name(8),
+ packet = None,
+ mode = STLTXCont(1),
+ enabled = True,
+ self_start = True,
+ isg = 0.0,
+ rx_stats = None,
+ next = None,
+ stream_id = None):
+
+ # type checking
+ if not isinstance(mode, STLTXMode):
+ raise STLArgumentError('mode', mode)
+
+ if packet and not isinstance(packet, CTrexPktBuilderInterface):
+ raise STLArgumentError('packet', packet)
+
+ if not isinstance(enabled, bool):
+ raise STLArgumentError('enabled', enabled)
+
+ if not isinstance(self_start, bool):
+ raise STLArgumentError('self_start', self_start)
+
+ if not isinstance(isg, (int, float)):
+ raise STLArgumentError('isg', isg)
+
+ if (type(mode) == STLTXCont) and (next != None):
+ raise STLError("continuous stream cannot have a next stream ID")
+
+ # tag for the stream and next - can be anything
+ self.name = name
+ self.next = next
+ self.set_id(stream_id)
+
+ self.fields = {}
+
+ # basic fields
+ self.fields['enabled'] = enabled
+ self.fields['self_start'] = self_start
+ self.fields['isg'] = isg
+
+ # mode
+ self.fields['mode'] = mode.to_json()
+
+ self.fields['packet'] = {}
+ self.fields['vm'] = {}
+
+ if not packet:
+ packet = CScapyTRexPktBuilder(pkt = Ether()/IP())
+
+ # packet builder
+ packet.compile()
+
+ # packet and VM
+ self.fields['packet'] = packet.dump_pkt()
+ self.fields['vm'] = packet.get_vm_data()
+
+ if not rx_stats:
+ self.fields['rx_stats'] = {}
+ self.fields['rx_stats']['enabled'] = False
+ else:
+ self.fields['rx_stats'] = rx_stats
+
+
+ def __str__ (self):
+ return json.dumps(self.fields, indent = 4, separators=(',', ': '), sort_keys = True)
+
+ def to_json (self):
+ return self.fields
+
+ def get_id (self):
+ return self.id
+
+ def set_id (self, id):
+ self.id = id
+
+ def get_name (self):
+ return self.name
+
+ def get_next (self):
+ return self.next
+
+
+ def to_yaml (self):
+ return {'name': self.name, 'stream': self.fields}
+
+
+
+class YAMLLoader(object):
+
+ def __init__ (self, yaml_file):
+ self.yaml_path = os.path.dirname(yaml_file)
+ self.yaml_file = yaml_file
+
+
+ def __parse_packet (self, packet_dict):
+ builder = CScapyTRexPktBuilder()
+
+ packet_type = set(packet_dict).intersection(['binary', 'pcap'])
+ if len(packet_type) != 1:
+ raise STLError("packet section must contain either 'binary' or 'pcap'")
+
+ if 'binary' in packet_type:
+ try:
+ pkt_str = base64.b64decode(packet_dict['binary'])
+ except TypeError:
+ raise STLError("'binary' field is not a valid packet format")
+
+ builder.set_pkt_as_str(pkt_str)
+
+ elif 'pcap' in packet_type:
+ pcap = os.path.join(self.yaml_path, packet_dict['pcap'])
+
+ if not os.path.exists(pcap):
+ raise STLError("'pcap' - cannot find '{0}'".format(pcap))
+
+ builder.set_packet(pcap)
+
+ return builder
+
+
+ def __parse_mode (self, mode_obj):
+
+ mode_type = mode_obj.get('type')
+
+ if mode_type == 'continuous':
+ defaults = STLTXCont()
+ mode = STLTXCont(pps = mode_obj.get('pps', defaults.fields['pps']))
+
+ elif mode_type == 'single_burst':
+ defaults = STLTXSingleBurst()
+ mode = STLTXSingleBurst(pps = mode_obj.get('pps', defaults.fields['pps']),
+ total_pkts = mode_obj.get('total_pkts', defaults.fields['total_pkts']))
+
+ elif mode_type == 'multi_burst':
+ defaults = STLTXMultiBurst()
+ mode = STLTXMultiBurst(pps = mode_obj.get('pps', defaults.fields['pps']),
+ pkts_per_burst = mode_obj.get('pkts_per_burst', defaults.fields['pkts_per_burst']),
+ ibg = mode_obj.get('ibg', defaults.fields['ibg']),
+ count = mode_obj.get('count', defaults.fields['count']))
+
+ else:
+ raise STLError("mode type can be 'continuous', 'single_burst' or 'multi_burst")
+
+
+ return mode
+
+
+ def __parse_stream (self, yaml_object):
+ s_obj = yaml_object['stream']
+
+ # parse packet
+ packet = s_obj.get('packet')
+ if not packet:
+ raise STLError("YAML file must contain 'packet' field")
+
+ builder = self.__parse_packet(packet)
+
+
+ # mode
+ mode_obj = s_obj.get('mode')
+ if not mode_obj:
+ raise STLError("YAML file must contain 'mode' field")
+
+ mode = self.__parse_mode(mode_obj)
+
+
+ defaults = STLStream()
+
+ # create the stream
+ stream = STLStream(name = yaml_object.get('name'),
+ packet = builder,
+ mode = mode,
+ enabled = s_obj.get('enabled', defaults.fields['enabled']),
+ self_start = s_obj.get('self_start', defaults.fields['self_start']),
+ isg = s_obj.get('isg', defaults.fields['isg']),
+ rx_stats = s_obj.get('rx_stats', defaults.fields['rx_stats']),
+ next = yaml_object.get('next'))
+
+ # hack the VM fields for now
+ if 'vm' in s_obj:
+ stream.fields['vm'].update(s_obj['vm'])
+
+ return stream
+
+
+ def parse (self):
+ with open(self.yaml_file, 'r') as f:
+ # read YAML and pass it down to stream object
+ yaml_str = f.read()
+
+ try:
+ objects = yaml.load(yaml_str)
+ except yaml.parser.ParserError as e:
+ raise STLError(str(e))
+
+ streams = [self.__parse_stream(object) for object in objects]
+
+ return streams
+
+
+# profile class
+class STLProfile(object):
+ def __init__ (self, streams = None):
+ if streams == None:
+ streams = []
+
+ if not type(streams) == list:
+ streams = [streams]
+
+ if not all([isinstance(stream, STLStream) for stream in streams]):
+ raise STLArgumentError('streams', streams)
+
+ self.streams = streams
+
+
+ def get_streams (self):
+ return self.streams
+
+ def __str__ (self):
+ return '\n'.join([str(stream) for stream in self.streams])
+
+
+ @staticmethod
+ def load_yaml (yaml_file):
+ # check filename
+ if not os.path.isfile(yaml_file):
+ raise STLError("file '{0}' does not exists".format(yaml_file))
+
+ yaml_loader = YAMLLoader(yaml_file)
+ streams = yaml_loader.parse()
+
+ return STLProfile(streams)
+
+
+ @staticmethod
+ def load_py (python_file):
+ # check filename
+ if not os.path.isfile(python_file):
+ raise STLError("file '{0}' does not exists".format(python_file))
+
+ basedir = os.path.dirname(python_file)
+ sys.path.append(basedir)
+
+ try:
+ file = os.path.basename(python_file).split('.')[0]
+ module = __import__(file, globals(), locals(), [], -1)
+ reload(module) # reload the update
+
+ streams = module.register().get_streams()
+
+ return STLProfile(streams)
+
+ except Exception as e:
+ a, b, tb = sys.exc_info()
+ x =''.join(traceback.format_list(traceback.extract_tb(tb)[1:])) + a.__name__ + ": " + str(b) + "\n"
+
+ summary = "\nPython Traceback follows:\n\n" + x
+ raise STLError(summary)
+
+
+ finally:
+ sys.path.remove(basedir)
+
+
+ @staticmethod
+ def load (filename):
+ x = os.path.basename(filename).split('.')
+ suffix = x[1] if (len(x) == 2) else None
+
+ if suffix == 'py':
+ profile = STLProfile.load_py(filename)
+
+ elif suffix == 'yaml':
+ profile = STLProfile.load_yaml(filename)
+
+ else:
+ raise STLError("unknown profile file type: '{0}'".format(suffix))
+
+ return profile
+
+
+ def dump_to_yaml (self, yaml_file = None):
+ yaml_list = [stream.to_yaml() for stream in self.streams]
+ yaml_str = yaml.dump(yaml_list, default_flow_style = False)
+
+ # write to file if provided
+ if yaml_file:
+ with open(yaml_file, 'w') as f:
+ f.write(yaml_str)
+
+ return yaml_str
+
+