diff options
author | 2016-05-03 16:09:00 +0300 | |
---|---|---|
committer | 2016-05-03 16:09:00 +0300 | |
commit | f27bfbab72f9221b221a7f36f4063e9ff8e9d307 (patch) | |
tree | a8775aa1c5063b5ed02fd13fbb94d541b696e91b /scripts/automation/trex_control_plane/stl/trex_stl_lib | |
parent | dbe44f64ec7cff8d55400327e5f6994ac1fc87bb (diff) | |
parent | 737d92948a1bfd9f94fd31a4e207d0f9e6f028c0 (diff) |
Merge trex-204
Diffstat (limited to 'scripts/automation/trex_control_plane/stl/trex_stl_lib')
7 files changed, 607 insertions, 65 deletions
diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py index aa95f037..862a9979 100755 --- a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py @@ -264,22 +264,55 @@ class EventsHandler(object): self.__async_event_port_job_done(port_id) show_event = True - # port was stolen... + # port was acquired - maybe stolen... elif (type == 5): session_id = data['session_id'] - # false alarm, its us + port_id = int(data['port_id']) + who = data['who'] + force = data['force'] + + # if we hold the port and it was not taken by this session - show it + if port_id in self.client.get_acquired_ports() and session_id != self.client.session_id: + show_event = True + + # format the thief/us... if session_id == self.client.session_id: - return + user = 'you' + elif who == self.client.username: + user = 'another session of you' + else: + user = "'{0}'".format(who) - port_id = int(data['port_id']) - who = data['who'] + if force: + ev = "Port {0} was forcely taken by {1}".format(port_id, user) + else: + ev = "Port {0} was taken by {1}".format(port_id, user) - ev = "Port {0} was forcely taken by '{1}'".format(port_id, who) + # call the handler in case its not this session + if session_id != self.client.session_id: + self.__async_event_port_acquired(port_id, who) + + + # port was released + elif (type == 6): + port_id = int(data['port_id']) + who = data['who'] + session_id = data['session_id'] + + if session_id == self.client.session_id: + user = 'you' + elif who == self.client.username: + user = 'another session of you' + else: + user = "'{0}'".format(who) + + ev = "Port {0} was released by {1}".format(port_id, user) + + # call the handler in case its not this session + if session_id != self.client.session_id: + self.__async_event_port_released(port_id) - # call the handler - self.__async_event_port_forced_acquired(port_id, who) - show_event = True # server stopped elif (type == 100): @@ -317,9 +350,11 @@ class EventsHandler(object): self.client.ports[port_id].async_event_port_resumed() - def __async_event_port_forced_acquired (self, port_id, who): - self.client.ports[port_id].async_event_forced_acquired(who) + def __async_event_port_acquired (self, port_id, who): + self.client.ports[port_id].async_event_acquired(who) + def __async_event_port_released (self, port_id): + self.client.ports[port_id].async_event_released() def __async_event_server_stopped (self): self.client.connected = False @@ -462,6 +497,11 @@ class STLClient(object): self.session_id = random.getrandbits(32) self.connected = False + # API classes + self.api_vers = [ {'type': 'core', 'major': 1, 'minor':2 } + ] + self.api_h = {'core': None} + # logger self.logger = DefaultLogger() if not logger else logger @@ -505,10 +545,7 @@ class STLClient(object): self.flow_stats) - # API classes - self.api_vers = [ {'type': 'core', 'major': 1, 'minor':1 } - ] - self.api_h = {'core': None} + ############# private functions - used by the class itself ########### @@ -545,13 +582,13 @@ class STLClient(object): return rc # acquire ports, if port_list is none - get all - def __acquire (self, port_id_list = None, force = False): + def __acquire (self, port_id_list = None, force = False, sync_streams = True): port_id_list = self.__ports(port_id_list) rc = RC() for port_id in port_id_list: - rc.add(self.ports[port_id].acquire(force)) + rc.add(self.ports[port_id].acquire(force, sync_streams)) return rc @@ -1335,16 +1372,20 @@ class STLClient(object): @__api_check(True) - def acquire (self, ports = None, force = False): + def acquire (self, ports = None, force = False, sync_streams = True): """ Acquires ports for executing commands :parameters: ports : list Ports on which to execute the command + force : bool Force acquire the ports. + sync_streams: bool + sync with the server about the configured streams + :raises: + :exc:`STLError` @@ -1359,7 +1400,7 @@ class STLClient(object): else: self.logger.pre_cmd("Acquiring ports {0}:".format(ports)) - rc = self.__acquire(ports, force) + rc = self.__acquire(ports, force, sync_streams) self.logger.post_cmd(rc) @@ -1459,7 +1500,8 @@ class STLClient(object): ports = ports if ports is not None else self.get_all_ports() ports = self._validate_port_list(ports) - self.acquire(ports, force = True) + # force take the port and ignore any streams on it + self.acquire(ports, force = True, sync_streams = False) self.stop(ports, rx_delay_ms = 0) self.remove_all_streams(ports) self.clear_stats(ports) @@ -2038,6 +2080,11 @@ class STLClient(object): return wrap + @__console + def ping_line (self, line): + '''pings the server''' + self.ping() + return True @__console def connect_line (self, line): @@ -2117,6 +2164,28 @@ class STLClient(object): @__console + def reacquire_line (self, line): + '''reacquire all the ports under your username which are not acquired by your session''' + + parser = parsing_opts.gen_parser(self, + "reacquire", + self.reacquire_line.__doc__) + + opts = parser.parse_args(line.split()) + if opts is None: + return + + # find all the on-owned ports under your name + my_unowned_ports = list_difference([k for k, v in self.ports.items() if v.get_owner() == self.username], self.get_acquired_ports()) + if not my_unowned_ports: + self.logger.log("reacquire - no unowned ports under '{0}'".format(self.username)) + return + + self.acquire(ports = my_unowned_ports, force = True) + return True + + + @__console def disconnect_line (self, line): self.disconnect() @@ -2491,15 +2560,20 @@ class STLClient(object): '''Sets port attributes ''' parser = parsing_opts.gen_parser(self, - "port", + "port_attr", self.set_port_attr_line.__doc__, parsing_opts.PORT_LIST_WITH_ALL, parsing_opts.PROMISCUOUS_SWITCH) - opts = parser.parse_args(line.split()) + opts = parser.parse_args(line.split(), default_ports = self.get_acquired_ports(), verify_acquired = True) if opts is None: return + # if no attributes - fall back to printing the status + if opts.prom is None: + self.show_stats_line("--ps --port {0}".format(' '.join(str(port) for port in opts.ports))) + return + self.set_port_attr(opts.ports, opts.prom) @@ -2592,5 +2666,4 @@ class STLClient(object): if opts.clear: self.clear_events() - self.logger.log(format_text("\nEvent log was cleared\n")) - +
\ No newline at end of file diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_jsonrpc_client.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_jsonrpc_client.py index bd5ba8e7..fa04b9f6 100644 --- a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_jsonrpc_client.py +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_jsonrpc_client.py @@ -47,7 +47,7 @@ class JsonRpcClient(object): MSG_COMPRESS_HEADER_MAGIC = 0xABE85CEA def __init__ (self, default_server, default_port, client): - self.client = client + self.client_api = client.api_h self.logger = client.logger self.connected = False @@ -104,7 +104,7 @@ class JsonRpcClient(object): # if this RPC has an API class - add it's handler if api_class: - msg["params"]["api_h"] = self.client.api_h[api_class] + msg["params"]["api_h"] = self.client_api[api_class] if encode: diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_port.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_port.py index 6f6f50b1..e8f89b27 100644 --- a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_port.py +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_port.py @@ -76,29 +76,55 @@ class Port(object): def get_speed_bps (self): return (self.info['speed'] * 1000 * 1000 * 1000) + def get_formatted_speed (self): + return "{0} Gbps".format(self.info['speed']) + # take the port - def acquire(self, force = False): + def acquire(self, force = False, sync_streams = True): params = {"port_id": self.port_id, "user": self.user, "session_id": self.session_id, "force": force} rc = self.transmit("acquire", params) - if rc.good(): - self.handler = rc.data() - return self.ok() + if not rc: + return self.err(rc.err()) + + self.handler = rc.data() + + if sync_streams: + return self.sync_streams() else: + return self.ok() + + + # sync all the streams with the server + def sync_streams (self): + params = {"port_id": self.port_id} + + rc = self.transmit("get_all_streams", params) + if rc.bad(): return self.err(rc.err()) + for k, v in rc.data()['streams'].items(): + self.streams[k] = {'next_id': v['next_stream_id'], + 'pkt' : base64.b64decode(v['packet']['binary']), + 'mode' : v['mode']['type'], + 'rate' : STLStream.get_rate_from_field(v['mode']['rate'])} + return self.ok() + # release the port def release(self): params = {"port_id": self.port_id, "handler": self.handler} rc = self.transmit("release", params) - self.handler = None - + if rc.good(): + + self.handler = None + self.owner = '' + return self.ok() else: return self.err(rc.err()) @@ -151,19 +177,7 @@ class Port(object): # attributes self.attr = rc.data()['attr'] - # sync the streams - params = {"port_id": self.port_id} - - rc = self.transmit("get_all_streams", params) - if rc.bad(): - return self.err(rc.err()) - - for k, v in rc.data()['streams'].items(): - self.streams[k] = {'next_id': v['next_stream_id'], - 'pkt' : base64.b64decode(v['packet']['binary']), - 'mode' : v['mode']['type'], - 'rate' : STLStream.get_rate_from_field(v['mode']['rate'])} - + return self.ok() @@ -679,7 +693,10 @@ class Port(object): if not self.is_acquired(): self.state = self.STATE_TX - def async_event_forced_acquired (self, who): + def async_event_acquired (self, who): self.handler = None self.owner = who + def async_event_released (self): + self.owner = '' + diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_stats.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_stats.py index 6b1185ef..c7513144 100644 --- a/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_stats.py +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_stats.py @@ -201,6 +201,10 @@ class CTRexInfoGenerator(object): def _get_rational_block_char(value, range_start, interval): # in Konsole, utf-8 is sometimes printed with artifacts, return ascii for now #return 'X' if value >= range_start + float(interval) / 2 else ' ' + + if sys.__stdout__.encoding != 'UTF-8': + return 'X' if value >= range_start + float(interval) / 2 else ' ' + value -= range_start ratio = float(value) / interval if ratio <= 0.0625: @@ -208,7 +212,7 @@ class CTRexInfoGenerator(object): if ratio <= 0.1875: return u'\u2581' # 1/8 if ratio <= 0.3125: - return u'\u2582' # 2/4 + return u'\u2582' # 2/8 if ratio <= 0.4375: return u'\u2583' # 3/8 if ratio <= 0.5625: @@ -247,6 +251,7 @@ class CTRexInfoGenerator(object): return_stats_data = {} per_field_stats = OrderedDict([("owner", []), ("state", []), + ("speed", []), ("--", []), ("Tx bps L2", []), ("Tx bps L1", []), @@ -532,7 +537,12 @@ class CTRexStats(object): v = self.get_trend(field, use_raw) value = abs(v) - arrow = u'\u25b2' if v > 0 else u'\u25bc' + + # use arrows if utf-8 is supported + if sys.__stdout__.encoding == 'UTF-8': + arrow = u'\u25b2' if v > 0 else u'\u25bc' + else: + arrow = '' if sys.version_info < (3,0): arrow = arrow.encode('utf-8') @@ -584,6 +594,7 @@ class CGlobalStats(CTRexStats): # absolute stats['cpu_util'] = self.get("m_cpu_util") stats['rx_cpu_util'] = self.get("m_rx_cpu_util") + stats['bw_per_core'] = self.get("m_bw_per_core") stats['tx_bps'] = self.get("m_tx_bps") stats['tx_pps'] = self.get("m_tx_pps") @@ -688,12 +699,13 @@ class CPortStats(CTRexStats): else: self.__merge_dicts(self.reference_stats, x.reference_stats) - # history - if not self.history: - self.history = copy.deepcopy(x.history) - else: - for h1, h2 in zip(self.history, x.history): - self.__merge_dicts(h1, h2) + # history - should be traverse with a lock + with self.lock, x.lock: + if not self.history: + self.history = copy.deepcopy(x.history) + else: + for h1, h2 in zip(self.history, x.history): + self.__merge_dicts(h1, h2) return self @@ -758,9 +770,17 @@ class CPortStats(CTRexStats): else: state = format_text(state, 'bold') + # mark owned ports by color + if self._port_obj: + owner = self._port_obj.get_owner() + if self._port_obj.is_acquired(): + owner = format_text(owner, 'green') + else: + owner = '' - return {"owner": self._port_obj.get_owner() if self._port_obj else "", + return {"owner": owner, "state": "{0}".format(state), + "speed": self._port_obj.get_formatted_speed() if self._port_obj else '', "--": " ", "---": " ", diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/GAObjClass.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/GAObjClass.py new file mode 100755 index 00000000..164aae7a --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/GAObjClass.py @@ -0,0 +1,295 @@ +#import requests # need external lib for that
+try: # Python2
+ import Queue
+ from urllib2 import *
+except: # Python3
+ import queue as Queue
+ from urllib.request import *
+ from urllib.error import *
+import threading
+import sys
+from time import sleep
+
+"""
+GAObjClass is a class destined to send Google Analytics Information.
+
+cid - unique number per user.
+command - the Event Category rubric appears on site. type: TEXT
+action - the Event Action rubric appears on site - type: TEXT
+label - the Event Label rubric - type: TEXT
+value - the event value metric - type: INTEGER
+
+QUOTAS:
+1 single payload - up to 8192Bytes
+batched:
+A maximum of 20 hits can be specified per request.
+The total size of all hit payloads cannot be greater than 16K bytes.
+No single hit payload can be greater than 8K bytes.
+"""
+
+url_single = 'http://www.google-analytics.com/collect' #sending single event
+url_batched = 'http://www.google-analytics.com/batch' #sending batched events
+url_debug = 'http://www.google-analytics.com/debug/collect' #verifying hit is valid
+url_conn = 'http://172.217.2.196' # testing internet connection to this address (google-analytics server)
+
+
+#..................................................................class GA_EVENT_ObjClass................................................................
+class GA_EVENT_ObjClass:
+ def __init__(self,cid,trackerID,command,action,label,value,appName,appVer):
+ self.cid = cid
+ self.trackerID = trackerID
+ self.command = command
+ self.action = action
+ self.label = label
+ self.value = value
+ self.appName = appName
+ self.appVer = appVer
+ self.generate_payload()
+ self.size = sys.getsizeof(self.payload)
+
+ def generate_payload(self):
+ self.payload ='v=1&t=event&tid='+str(self.trackerID)
+ self.payload+='&cid='+str(self.cid)
+ self.payload+='&ec='+str(self.command)
+ self.payload+='&ea='+str(self.action)
+ self.payload+='&el='+str(self.label)
+ self.payload+='&ev='+str(self.value)
+ self.payload+='&an='+str(self.appName)
+ self.payload+='&av='+str(self.appVer)
+
+#..................................................................class GA_EXCEPTION_ObjClass................................................................
+#ExceptionFatal - BOOLEAN
+class GA_EXCEPTION_ObjClass:
+ def __init__(self,cid,trackerID,ExceptionName,ExceptionFatal,appName,appVer):
+ self.cid = cid
+ self.trackerID = trackerID
+ self.ExceptionName = ExceptionName
+ self.ExceptionFatal = ExceptionFatal
+ self.appName = appName
+ self.appVer = appVer
+ self.generate_payload()
+
+ def generate_payload(self):
+ self.payload ='v=1&t=exception&tid='+str(self.trackerID)
+ self.payload+='&cid='+str(self.cid)
+ self.payload+='&exd='+str(self.ExceptionName)
+ self.payload+='&exf='+str(self.ExceptionFatal)
+ self.payload+='&an='+str(self.appName)
+ self.payload+='&av='+str(self.appVer)
+
+#.....................................................................class ga_Thread.................................................................
+"""
+
+Google analytics thread manager:
+
+will report and empty queue of google analytics items to GA server, every Timeout (parameter given on initialization)
+will perform connectivity check every timeout*10 seconds
+
+
+"""
+
+class ga_Thread (threading.Thread):
+ def __init__(self,threadID,gManager):
+ threading.Thread.__init__(self)
+ self.threadID = threadID
+ self.gManager = gManager
+
+ def run(self):
+ keepAliveCounter=0
+ #sys.stdout.write('thread started \n')
+ #sys.stdout.flush()
+ while True:
+ if (keepAliveCounter==10):
+ keepAliveCounter=0
+ if (self.gManager.internet_on()==True):
+ self.gManager.connectedToInternet=1
+ else:
+ self.gManager.connectedToInternet=0
+ sleep(self.gManager.Timeout)
+ keepAliveCounter+=1
+ if not self.gManager.GA_q.empty():
+ self.gManager.threadLock.acquire(1)
+# sys.stdout.write('lock acquired: reporting to GA \n')
+# sys.stdout.flush()
+ if (self.gManager.connectedToInternet==1):
+ self.gManager.emptyAndReportQ()
+ self.gManager.threadLock.release()
+# sys.stdout.write('finished \n')
+# sys.stdout.flush()
+
+
+
+#.....................................................................class GAmanager.................................................................
+"""
+
+Google ID - specify tracker property, example: UA-75220362-2 (when the suffix '2' specifies the analytics property profile)
+
+UserID - unique userID, this will differ between users on GA
+
+appName - s string to determine app name
+
+appVer - a string to determine app version
+
+QueueSize - the size of the queue that holds reported items. once the Queue is full:
+ on blocking mode:
+ will block program until next submission to GA server, which will make new space
+ on non-blocking mode:
+ will drop new requests
+
+Timout - the timeout the queue uses between data transmissions. Timeout should be shorter than the time it takes to generate 20 events. MIN VALUE = 11 seconds
+
+User Permission - the user must accept data transmission, use this flag as 1/0 flag, when UserPermission=1 allows data collection
+
+BlockingMode - set to 1 if you wish every Google Analytic Object will be submitted and processed, with no drops allowed.
+ this will block the running of the program until every item is processed
+
+*** Restriction - Google's restriction for amount of packages being sent per session per second is: 1 event per second, per session. session length is 30min ***
+"""
+class GAmanager:
+ def __init__(self,GoogleID,UserID,appName,appVer,QueueSize,Timeout,UserPermission,BlockingMode):
+ self.UserID = UserID
+ self.GoogleID = GoogleID
+ self.QueueSize = QueueSize
+ self.Timeout = Timeout
+ self.appName = appName
+ self.appVer = appVer
+ self.UserPermission = UserPermission
+ self.GA_q = Queue.Queue(QueueSize)
+ self.thread = ga_Thread(UserID,self)
+ self.threadLock = threading.Lock()
+ self.BlockingMode = BlockingMode
+ self.connectedToInternet =0
+ if (self.internet_on()==True):
+# sys.stdout.write('internet connection active \n')
+# sys.stdout.flush()
+ self.connectedToInternet=1
+ else:
+ self.connectedToInternet=0
+
+ def gaAddAction(self,Event,action,label,value):
+ self.gaAddObject(GA_EVENT_ObjClass(self.UserID,self.GoogleID,Event,action,label,value,self.appName,self.appVer))
+
+ def gaAddException(self,ExceptionName,ExceptionFatal):
+ self.gaAddObject(GA_EXCEPTION_ObjClass(self.UserID,self.GoogleID,ExceptionName,ExceptionFatal,self.appName,self.appVer))
+
+ def gaAddObject(self,Object):
+ if self.BlockingMode==1:
+ while self.GA_q.full():
+ sleep(self.Timeout)
+# sys.stdout.write('blocking mode=1 \n queue full - sleeping for timeout \n') # within Timout, the thread will empty part of the queue
+# sys.stdout.flush()
+ lockState = self.threadLock.acquire(self.BlockingMode)
+ if lockState==1:
+# sys.stdout.write('got lock, adding item \n')
+# sys.stdout.flush()
+ try:
+ self.GA_q.put_nowait(Object)
+# sys.stdout.write('got lock, item added \n')
+# sys.stdout.flush()
+ except Queue.Full:
+# sys.stdout.write('Queue full \n')
+# sys.stdout.flush()
+ pass
+ self.threadLock.release()
+
+ def emptyQueueToList(self,obj_list):
+ items=0
+ while ((not self.GA_q.empty()) and (items<20)):
+ obj_list.append(self.GA_q.get_nowait().payload)
+ items+=1
+# print items
+
+ def reportBatched(self,batched):
+ req = Request(url_batched, data=batched.encode('ascii'))
+ urlopen(req)
+ #requests.post(url_batched,data=batched)
+
+ def emptyAndReportQ(self):
+ obj_list = []
+ self.emptyQueueToList(obj_list)
+ if not len(obj_list):
+ return
+ batched = '\n'.join(obj_list)
+# print batched # - for debug
+ self.reportBatched(batched)
+
+ def printSelf(self):
+ print('remaining in queue:')
+ while not self.GA_q.empty():
+ obj = self.GA_q.get_nowait()
+ print(obj.payload)
+
+ def internet_on(self):
+ try:
+ urlopen(url_conn,timeout=10)
+ return True
+ except URLError as err: pass
+ return False
+
+ def activate(self):
+ if (self.UserPermission==1):
+ self.thread.start()
+
+
+
+#***************************************------TEST--------------**************************************
+
+if __name__ == '__main__':
+ g = GAmanager(GoogleID='UA-75220362-4',UserID="Foo",QueueSize=100,Timeout=5,UserPermission=1,BlockingMode=1,appName='TRex',appVer='1.11.232') #timeout in seconds
+#for i in range(0,35,1):
+#i = 42
+ g.gaAddAction(Event='stl',action='stl/udp_1pkt_simple.py {packet_count:1000,packet_len:9000}',label='Boo',value=20)
+ #g.gaAddAction(Event='test',action='start',label='Boo1',value=20)
+
+#g.gaAddException('MEMFAULT',1)
+#g.gaAddException('MEMFAULT',1)
+#g.gaAddException('MEMFAULT',1)
+#g.gaAddException('MEMFAULT',1)
+#g.gaAddException('MEMFAULT',1)
+#g.gaAddException('MEMFAULT',1)
+ g.emptyAndReportQ()
+# g.printSelf()
+#print g.payload
+#print g.size
+
+
+
+
+#g.activate()
+#g.gaAddAction(Event='test',action='start',label='1',value='1')
+#sys.stdout.write('element added \n')
+#sys.stdout.flush()
+#g.gaAddAction(Event='test',action='start',label='2',value='1')
+#sys.stdout.write('element added \n')
+#sys.stdout.flush()
+#g.gaAddAction(Event='test',action='start',label='3',value='1')
+#sys.stdout.write('element added \n')
+#sys.stdout.flush()
+
+#testdata = "v=1&t=event&tid=UA-75220362-4&cid=2&ec=test&ea=testing&el=testpacket&ev=2"
+#r = requests.post(url_debug,data=testdata)
+#print r
+
+#thread1 = ga_Thread(1,g)
+#thread1.start()
+#thread1.join()
+#for i in range(1,10,1):
+# sys.stdout.write('yesh %d'% (i))
+# sys.stdout.flush()
+
+
+# add timing mechanism - DONE
+# add exception mechanism - DONE
+# add version mechanism - DONE
+# ask Itay for unique ID generation per user
+
+
+
+
+
+
+
+
+
+
+
diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/filters.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/filters.py new file mode 100644 index 00000000..714f7807 --- /dev/null +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/filters.py @@ -0,0 +1,144 @@ + +def shallow_copy(x): + return type(x)(x) + + +class ToggleFilter(object): + """ + This class provides a "sticky" filter, that works by "toggling" items of the original database on and off. + """ + def __init__(self, db_ref, show_by_default=True): + """ + Instantiate a ToggleFilter object + + :parameters: + db_ref : iterable + an iterable object (i.e. list, set etc) that would serve as the reference db of the instance. + Changes in that object will affect the output of ToggleFilter instance. + + show_by_default: bool + decide if by default all the items are "on", i.e. these items will be presented if no other + toggling occurred. + + default value : **True** + + """ + self._data = db_ref + self._toggle_db = set() + self._filter_method = filter + self.__set_initial_state(show_by_default) + + def reset (self): + """ + Toggles off all the items + """ + self._toggle_db = set() + + + def toggle_item(self, item_key): + """ + Toggle a single item in/out. + + :parameters: + item_key : + an item the by its value the filter can decide to toggle or not. + Example: int, str and so on. + + :return: + + **True** if item toggled **into** the filtered items + + **False** if item toggled **out from** the filtered items + + :raises: + + KeyError, in case if item key is not part of the toggled list and not part of the referenced db. + + """ + if item_key in self._toggle_db: + self._toggle_db.remove(item_key) + return False + elif item_key in self._data: + self._toggle_db.add(item_key) + return True + else: + raise KeyError("Provided item key isn't a key of the referenced data structure.") + + def toggle_items(self, *args): + """ + Toggle multiple items in/out with a single call. Each item will be ha. + + :parameters: + args : iterable + an iterable object containing all item keys to be toggled in/out + + :return: + + **True** if all toggled items were toggled **into** the filtered items + + **False** if at least one of the items was toggled **out from** the filtered items + + :raises: + + KeyError, in case if ont of the item keys was not part of the toggled list and not part of the referenced db. + + """ + # in python 3, 'map' returns an iterator, so wrapping with 'list' call creates same effect for both python 2 and 3 + return all(list(map(self.toggle_item, args))) + + def filter_items(self): + """ + Filters the pointed database by showing only the items mapped at toggle_db set. + + :returns: + Filtered data of the original object. + + """ + return self._filter_method(self.__toggle_filter, self._data) + + # private methods + + def __set_initial_state(self, show_by_default): + try: + _ = (x for x in self._data) + if isinstance(self._data, dict): + self._filter_method = ToggleFilter.dict_filter + if show_by_default: + self._toggle_db = set(self._data.keys()) + return + elif isinstance(self._data, list): + self._filter_method = ToggleFilter.list_filter + elif isinstance(self._data, set): + self._filter_method = ToggleFilter.set_filter + elif isinstance(self._data, tuple): + self._filter_method = ToggleFilter.tuple_filter + if show_by_default: + self._toggle_db = set(shallow_copy(self._data)) # assuming all relevant data with unique identifier + return + except TypeError: + raise TypeError("provided data object is not iterable") + + def __toggle_filter(self, x): + return (x in self._toggle_db) + + # static utility methods + + @staticmethod + def dict_filter(function, iterable): + assert isinstance(iterable, dict) + return {k: v + for k,v in iterable.items() + if function(k)} + + @staticmethod + def list_filter(function, iterable): + # in python 3, filter returns an iterator, so wrapping with list creates same effect for both python 2 and 3 + return list(filter(function, iterable)) + + @staticmethod + def set_filter(function, iterable): + return {x + for x in iterable + if function(x)} + + @staticmethod + def tuple_filter(function, iterable): + return tuple(filter(function, iterable)) + + +if __name__ == "__main__": + pass diff --git a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_opts.py b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_opts.py index bc2d44f4..5c0dfb14 100644 --- a/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_opts.py +++ b/scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_opts.py @@ -124,16 +124,9 @@ def underline(text): def text_attribute(text, attribute): - if isinstance(text, str): - return "{start}{txt}{stop}".format(start=TEXT_CODES[attribute]['start'], - txt=text, - stop=TEXT_CODES[attribute]['end']) - elif isinstance(text, unicode): - return u"{start}{txt}{stop}".format(start=TEXT_CODES[attribute]['start'], - txt=text, - stop=TEXT_CODES[attribute]['end']) - else: - raise Exception("not a string") + return "{start}{txt}{stop}".format(start=TEXT_CODES[attribute]['start'], + txt=text, + stop=TEXT_CODES[attribute]['end']) FUNC_DICT = {'blue': blue, |