From 94ce0dcd7f93fe82e667f38d805f56d6d828f824 Mon Sep 17 00:00:00 2001
From: imarom <imarom@cisco.com>
Date: Thu, 14 Apr 2016 19:29:03 +0300
Subject: fixed partial ports console

---
 .../stl/trex_stl_lib/trex_stl_client.py            | 143 ++++++++++++++-------
 .../stl/trex_stl_lib/utils/common.py               |   5 +
 .../stl/trex_stl_lib/utils/parsing_opts.py         |  42 +++++-
 3 files changed, 134 insertions(+), 56 deletions(-)

(limited to 'scripts/automation/trex_control_plane/stl/trex_stl_lib')

diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py
index 98f3fe3a..aa95f037 100755
--- a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py
+++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py
@@ -12,6 +12,7 @@ from .trex_stl_types import *
 from .trex_stl_async_client import CTRexAsyncClient
 
 from .utils import parsing_opts, text_tables, common
+from .utils.common import list_intersect, list_difference, is_sub_list
 from .utils.text_opts import *
 from functools import wraps
 
@@ -1141,23 +1142,39 @@ class STLClient(object):
                 if port_obj.is_acquired()]
 
     # get all active ports (TX or pause)
-    def get_active_ports(self):
-        return [port_id
-                for port_id, port_obj in self.ports.items()
-                if port_obj.is_active()]
+    def get_active_ports(self, owned = True):
+        if owned:
+            return [port_id
+                    for port_id, port_obj in self.ports.items()
+                    if port_obj.is_active() and port_obj.is_acquired()]
+        else:
+            return [port_id
+                    for port_id, port_obj in self.ports.items()
+                    if port_obj.is_active()]
 
 
     # get paused ports
-    def get_paused_ports (self):
-        return [port_id
-                for port_id, port_obj in self.ports.items()
-                if port_obj.is_paused()]
+    def get_paused_ports (self, owned = True):
+        if owned:
+            return [port_id
+                    for port_id, port_obj in self.ports.items()
+                    if port_obj.is_paused() and port_obj.is_acquired()]
+        else:
+            return [port_id
+                    for port_id, port_obj in self.ports.items()
+                    if port_obj.is_paused()]
+
 
     # get all TX ports
-    def get_transmitting_ports (self):
-        return [port_id
-                for port_id, port_obj in self.ports.items()
-                if port_obj.is_transmitting()]
+    def get_transmitting_ports (self, owned = True):
+        if owned:
+            return [port_id
+                    for port_id, port_obj in self.ports.items()
+                    if port_obj.is_transmitting() and port_obj.is_acquired()]
+        else:
+            return [port_id
+                    for port_id, port_obj in self.ports.items()
+                    if port_obj.is_transmitting()]
 
 
     # get stats
@@ -2031,8 +2048,7 @@ class STLClient(object):
                                          parsing_opts.PORT_LIST_WITH_ALL,
                                          parsing_opts.FORCE)
 
-        opts = parser.parse_args(line.split())
-
+        opts = parser.parse_args(line.split(), default_ports = self.get_all_ports())
         if opts is None:
             return
 
@@ -2042,6 +2058,7 @@ class STLClient(object):
         # true means print time
         return True
 
+
     @__console
     def acquire_line (self, line):
         '''Acquire ports\n'''
@@ -2053,15 +2070,14 @@ class STLClient(object):
                                          parsing_opts.PORT_LIST_WITH_ALL,
                                          parsing_opts.FORCE)
 
-        opts = parser.parse_args(line.split())
-
+        opts = parser.parse_args(line.split(), default_ports = self.get_all_ports())
         if opts is None:
             return
 
-        # call the API
-        ports = [x for x in opts.ports if x not in self.get_acquired_ports()]
+        # filter out all the already owned ports
+        ports = list_difference(opts.ports, self.get_acquired_ports())
         if not ports:
-            self.logger.log("Port(s) {0} are already acquired\n".format(opts.ports))
+            self.logger.log("acquire - all port(s) {0} are already acquired".format(opts.ports))
             return
 
         self.acquire(ports = ports, force = opts.force)
@@ -2080,22 +2096,26 @@ class STLClient(object):
                                          self.release_line.__doc__,
                                          parsing_opts.PORT_LIST_WITH_ALL)
 
-        opts = parser.parse_args(line.split())
-
+        opts = parser.parse_args(line.split(), default_ports = self.get_acquired_ports())
         if opts is None:
             return
 
-        # call the API
-        ports = [x for x in opts.ports if x in self.get_acquired_ports()]
+        ports = list_intersect(opts.ports, self.get_acquired_ports())
         if not ports:
-            self.logger.log("Port(s) {0} are not owned by you\n".format(opts.ports))
-            return
+            if not opts.ports:
+                self.logger.log("release - no acquired ports")
+                return
+            else:
+                self.logger.log("release - none of port(s) {0} are acquired".format(opts.ports))
+                return
 
+        
         self.release(ports = ports)
 
         # true means print time
         return True
 
+
     @__console
     def disconnect_line (self, line):
         self.disconnect()
@@ -2104,12 +2124,24 @@ class STLClient(object):
 
     @__console
     def reset_line (self, line):
-        self.reset()
+        '''Reset ports - if no ports are provided all acquired ports will be reset'''
+
+        parser = parsing_opts.gen_parser(self,
+                                         "reset",
+                                         self.reset_line.__doc__,
+                                         parsing_opts.PORT_LIST_WITH_ALL)
+
+        opts = parser.parse_args(line.split(), default_ports = self.get_acquired_ports(), verify_acquired = True)
+        if opts is None:
+            return
+
+        self.reset(ports = opts.ports)
 
         # true means print time
         return True
 
 
+
     @__console
     def start_line (self, line):
         '''Start selected traffic on specified ports on TRex\n'''
@@ -2126,15 +2158,11 @@ class STLClient(object):
                                          parsing_opts.MULTIPLIER_STRICT,
                                          parsing_opts.DRY_RUN)
 
-        opts = parser.parse_args(line.split())
-
-
+        opts = parser.parse_args(line.split(), default_ports = self.get_acquired_ports(), verify_acquired = True)
         if opts is None:
             return
 
-
-        active_ports = list(set(self.get_active_ports()).intersection(opts.ports))
-
+        active_ports = list_intersect(self.get_active_ports(), opts.ports)
         if active_ports:
             if not opts.force:
                 msg = "Port(s) {0} are active - please stop them or add '--force'\n".format(active_ports)
@@ -2205,17 +2233,21 @@ class STLClient(object):
                                          self.stop_line.__doc__,
                                          parsing_opts.PORT_LIST_WITH_ALL)
 
-        opts = parser.parse_args(line.split())
+        opts = parser.parse_args(line.split(), default_ports = self.get_active_ports(), verify_acquired = True)
         if opts is None:
             return
 
-        # find the relevant ports
-        ports = list(set(self.get_active_ports()).intersection(opts.ports))
 
+        # find the relevant ports
+        ports = list_intersect(opts.ports, self.get_active_ports())
         if not ports:
-            self.logger.log(format_text("No active traffic on provided ports\n", 'bold'))
+            if not opts.ports:
+                self.logger.log('stop - no active ports')
+            else:
+                self.logger.log('stop - no active traffic on ports {0}'.format(opts.ports))
             return
 
+        # call API
         self.stop(ports)
 
         # true means print time
@@ -2233,15 +2265,18 @@ class STLClient(object):
                                          parsing_opts.TOTAL,
                                          parsing_opts.FORCE)
 
-        opts = parser.parse_args(line.split())
+        opts = parser.parse_args(line.split(), default_ports = self.get_active_ports(), verify_acquired = True)
         if opts is None:
             return
 
-         # find the relevant ports
-        ports = list(set(self.get_active_ports()).intersection(opts.ports))
 
+        # find the relevant ports
+        ports = list_intersect(opts.ports, self.get_active_ports())
         if not ports:
-            self.logger.log(format_text("No ports in valid state to update\n", 'bold'))
+            if not opts.ports:
+                self.logger.log('update - no active ports')
+            else:
+                self.logger.log('update - no active traffic on ports {0}'.format(opts.ports))
             return
 
         self.update(ports, opts.mult, opts.total, opts.force)
@@ -2258,15 +2293,22 @@ class STLClient(object):
                                          self.pause_line.__doc__,
                                          parsing_opts.PORT_LIST_WITH_ALL)
 
-        opts = parser.parse_args(line.split())
+        opts = parser.parse_args(line.split(), default_ports = self.get_transmitting_ports(), verify_acquired = True)
         if opts is None:
             return
 
-        # find the relevant ports
-        ports = list(set(self.get_transmitting_ports()).intersection(opts.ports))
+        # check for already paused case
+        if opts.ports and is_sub_list(opts.ports, self.get_paused_ports()):
+            self.logger.log('pause - all of port(s) {0} are already paused'.format(opts.ports))
+            return
 
+        # find the relevant ports
+        ports = list_intersect(opts.ports, self.get_transmitting_ports())
         if not ports:
-            self.logger.log(format_text("No ports in valid state to pause\n", 'bold'))
+            if not opts.ports:
+                self.logger.log('pause - no transmitting ports')
+            else:
+                self.logger.log('pause - none of ports {0} are transmitting'.format(opts.ports))
             return
 
         self.pause(ports)
@@ -2283,18 +2325,21 @@ class STLClient(object):
                                          self.resume_line.__doc__,
                                          parsing_opts.PORT_LIST_WITH_ALL)
 
-        opts = parser.parse_args(line.split())
+        opts = parser.parse_args(line.split(), default_ports = self.get_paused_ports(), verify_acquired = True)
         if opts is None:
             return
 
         # find the relevant ports
-        ports = list(set(self.get_paused_ports()).intersection(opts.ports))
-
+        ports = list_intersect(opts.ports, self.get_paused_ports())
         if not ports:
-            self.logger.log(format_text("No ports in valid state to resume\n", 'bold'))
+            if not opts.ports:
+                self.logger.log('resume - no paused ports')
+            else:
+                self.logger.log('resume - none of ports {0} are paused'.format(opts.ports))
             return
 
-        return self.resume(ports)
+
+        self.resume(ports)
 
         # true means print time
         return True
diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/common.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/common.py
index e176ca99..b4903e81 100644
--- a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/common.py
+++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/common.py
@@ -59,3 +59,8 @@ def get_number(input):
 def list_intersect(l1, l2):
     return list(filter(lambda x: x in l2, l1))
 
+def list_difference (l1, l2):
+    return list(filter(lambda x: x not in l2, l1))
+
+def is_sub_list (l1, l2):
+    return set(l1) <= set(l2)
diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/parsing_opts.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/parsing_opts.py
index c4f2b358..ad46625b 100755
--- a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/parsing_opts.py
+++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/parsing_opts.py
@@ -1,5 +1,6 @@
 import argparse
 from collections import namedtuple
+from .common import list_intersect, list_difference
 import sys
 import re
 import os
@@ -262,7 +263,7 @@ OPTIONS_DB = {MULTIPLIER: ArgumentPack(['-m', '--multiplier'],
                                             'action': "store_false"}),
 
 
-              PORT_LIST: ArgumentPack(['--port'],
+              PORT_LIST: ArgumentPack(['--port', '-p'],
                                         {"nargs": '+',
                                          'dest':'ports',
                                          'metavar': 'PORTS',
@@ -374,22 +375,49 @@ class CCmdArgParser(argparse.ArgumentParser):
     def __init__(self, stateless_client, *args, **kwargs):
         super(CCmdArgParser, self).__init__(*args, **kwargs)
         self.stateless_client = stateless_client
+        self.cmd_name = kwargs.get('prog')
 
-    def parse_args(self, args=None, namespace=None):
+
+    def has_ports_cfg (self, opts):
+        return hasattr(opts, "all_ports") or hasattr(opts, "ports")
+
+    def parse_args(self, args=None, namespace=None, default_ports=None, verify_acquired=False):
         try:
             opts = super(CCmdArgParser, self).parse_args(args, namespace)
             if opts is None:
                 return None
 
+            if not self.has_ports_cfg(opts):
+                return opts
+
             # if all ports are marked or 
             if (getattr(opts, "all_ports", None) == True) or (getattr(opts, "ports", None) == []):
-                opts.ports = self.stateless_client.get_all_ports()
+                if default_ports is None:
+                    opts.ports = self.stateless_client.get_acquired_ports()
+                else:
+                    opts.ports = default_ports
 
             # so maybe we have ports configured
-            elif getattr(opts, "ports", None):
-                for port in opts.ports:
-                    if not self.stateless_client._validate_port_list(port):
-                        self.error("port id '{0}' is not a valid port id\n".format(port))
+            invalid_ports = list_difference(opts.ports, self.stateless_client.get_all_ports())
+            if invalid_ports:
+                self.stateless_client.logger.log("{0}: port(s) {1} are not valid port IDs".format(self.cmd_name, invalid_ports))
+                return None
+
+            # verify acquired ports
+            if verify_acquired:
+                acquired_ports = self.stateless_client.get_acquired_ports()
+
+                diff = list_difference(opts.ports, acquired_ports)
+                if diff:
+                    self.stateless_client.logger.log("{0} - port(s) {1} are not acquired".format(self.cmd_name, diff))
+                    return None
+
+                # no acquire ports at all
+                if not acquired_ports:
+                    self.stateless_client.logger.log("{0} - no acquired ports".format(self.cmd_name))
+                    return None
+
+
 
             return opts
 
-- 
cgit