From 1975220bed9599e8c6da81571d7c13b74d1df7d3 Mon Sep 17 00:00:00 2001
From: Yaroslav Brustinov <ybrustin@cisco.com>
Date: Wed, 31 Aug 2016 17:02:06 +0300
Subject: generation of native code: fix imports

---
 .../stl/trex_stl_lib/trex_stl_streams.py           | 35 ++++++++++++++++++----
 1 file changed, 29 insertions(+), 6 deletions(-)

(limited to 'scripts/automation/trex_control_plane/stl/trex_stl_lib')

diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py
index fc0bc78c..752f14b5 100755
--- a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py
+++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_streams.py
@@ -3,7 +3,7 @@
 from .trex_stl_exceptions import *
 from .trex_stl_types import verify_exclusive_arg, validate_type
 from .trex_stl_packet_builder_interface import CTrexPktBuilderInterface
-from .trex_stl_packet_builder_scapy import STLPktBuilder, Ether, IP, UDP, TCP, RawPcapReader
+from .trex_stl_packet_builder_scapy import *
 from collections import OrderedDict, namedtuple
 
 from scapy.utils import ltoa
@@ -538,7 +538,33 @@ class STLStream(object):
         """ Convert to Python code as profile  """
         packet = Ether(self.pkt)
         layer = packet
-        while layer:                    # remove checksums
+        imports_arr = []
+        # remove checksums, add imports if needed
+        while layer:
+            layer_class = layer.__class__.__name__
+            try: # check if class can be instantiated
+                eval('%s()' % layer_class)
+            except NameError: # no such layer
+                found_import = False
+                for module_path, module in sys.modules.items():
+                    import_string = 'from %s import %s' % (module_path, layer_class)
+                    if import_string in imports_arr:
+                        found_import = True
+                        break
+                    if not module_path.startswith(('scapy.layers', 'scapy.contrib')):
+                        continue
+                    check_layer = getattr(module, layer_class, None)
+                    if not check_layer:
+                        continue
+                    try:
+                        check_layer()
+                        imports_arr.append(import_string)
+                        found_import = True
+                        break
+                    except: # can't by instantiated
+                        continue
+                if not found_import:
+                    raise STLError('Could not determine import of layer %s' % layer.name)
             for chksum_name in ('cksum', 'chksum'):
                 if chksum_name in layer.fields:
                     del layer.fields[chksum_name]
@@ -546,10 +572,7 @@ class STLStream(object):
         packet.hide_defaults()          # remove fields with default values
         payload = packet.getlayer('Raw')
         packet_command = packet.command()
-        imports_arr = []
-        if 'MPLS(' in packet_command:
-            imports_arr.append('from scapy.contrib.mpls import MPLS')
-            
+
         imports = '\n'.join(imports_arr)
         if payload:
             payload.remove_payload() # fcs etc.
-- 
cgit