diff options
Diffstat (limited to 'src/vpp-api/python/vpp_papi')
-rw-r--r-- | src/vpp-api/python/vpp_papi/tests/test_vpp_papi.py | 6 | ||||
-rw-r--r-- | src/vpp-api/python/vpp_papi/vpp_papi.py | 68 |
2 files changed, 41 insertions, 33 deletions
diff --git a/src/vpp-api/python/vpp_papi/tests/test_vpp_papi.py b/src/vpp-api/python/vpp_papi/tests/test_vpp_papi.py index 2b21c83966a..51c024aa3ab 100644 --- a/src/vpp-api/python/vpp_papi/tests/test_vpp_papi.py +++ b/src/vpp-api/python/vpp_papi/tests/test_vpp_papi.py @@ -24,8 +24,7 @@ from vpp_papi import vpp_transport_shmem class TestVppPapiVPPApiClient(unittest.TestCase): def test_getcontext(self): - vpp_papi.VPPApiClient.apidir = "." - c = vpp_papi.VPPApiClient(testmode=True, use_socket=True) + c = vpp_papi.VPPApiClient(apidir=".", testmode=True, use_socket=True) # reset initialization at module load time. c.get_context.context = mp.Value(ctypes.c_uint, 0) @@ -39,8 +38,7 @@ class TestVppPapiVPPApiClientMp(unittest.TestCase): # run_tests.py (eg. make test TEST_JOBS=10) def test_get_context_mp(self): - vpp_papi.VPPApiClient.apidir = "." - c = vpp_papi.VPPApiClient(testmode=True, use_socket=True) + c = vpp_papi.VPPApiClient(apidir=".", testmode=True, use_socket=True) # reset initialization at module load time. c.get_context.context = mp.Value(ctypes.c_uint, 0) diff --git a/src/vpp-api/python/vpp_papi/vpp_papi.py b/src/vpp-api/python/vpp_papi/vpp_papi.py index a9edfed81be..5c089647e59 100644 --- a/src/vpp-api/python/vpp_papi/vpp_papi.py +++ b/src/vpp-api/python/vpp_papi/vpp_papi.py @@ -281,16 +281,15 @@ class VPPApiJSONFiles: @classmethod def process_json_file(self, apidef_file): - api = json.load(apidef_file) - return self._process_json(api) + return self._process_json(apidef_file.read()) @classmethod def process_json_str(self, json_str): - api = json.loads(json_str) - return self._process_json(api) + return self._process_json(json_str) @staticmethod - def _process_json(api): # -> Tuple[Dict, Dict] + def _process_json(json_str): # -> Tuple[Dict, Dict] + api = json.loads(json_str) types = {} services = {} messages = {} @@ -380,6 +379,30 @@ class VPPApiJSONFiles: pass return messages, services + @staticmethod + def load_api(apifiles=None, apidir=None): + messages = {} + services = {} + if not apifiles: + # Pick up API definitions from default directory + try: + if isinstance(apidir, list): + apifiles = [] + for d in apidir: + apifiles += VPPApiJSONFiles.find_api_files(d) + else: + apifiles = VPPApiJSONFiles.find_api_files(apidir) + except (RuntimeError, VPPApiError): + raise VPPRuntimeError + + for file in apifiles: + with open(file) as apidef_file: + m, s = VPPApiJSONFiles.process_json_file(apidef_file) + messages.update(m) + services.update(s) + + return apifiles, messages, services + class VPPApiClient: """VPP interface. @@ -394,7 +417,6 @@ class VPPApiClient: these messages in a background thread. """ - apidir = None VPPApiError = VPPApiError VPPRuntimeError = VPPRuntimeError VPPValueError = VPPValueError @@ -405,6 +427,7 @@ class VPPApiClient: self, *, apifiles=None, + apidir=None, testmode=False, async_thread=True, logger=None, @@ -439,6 +462,7 @@ class VPPApiClient: self.id_msgdef = [] self.header = VPPType("header", [["u16", "msgid"], ["u32", "client_index"]]) self.apifiles = [] + self.apidir = apidir self.event_callback = None self.message_queue = queue.Queue() self.read_timeout = read_timeout @@ -449,29 +473,15 @@ class VPPApiClient: self._apifiles = apifiles self.stats = {} - if not apifiles: - # Pick up API definitions from default directory - try: - if isinstance(self.apidir, list): - apifiles = [] - for d in self.apidir: - apifiles += VPPApiJSONFiles.find_api_files(d) - else: - apifiles = VPPApiJSONFiles.find_api_files(self.apidir) - except (RuntimeError, VPPApiError): - # In test mode we don't care that we can't find the API files - if testmode: - apifiles = [] - else: - raise VPPRuntimeError - - for file in apifiles: - with open(file) as apidef_file: - m, s = VPPApiJSONFiles.process_json_file(apidef_file) - self.messages.update(m) - self.services.update(s) - - self.apifiles = apifiles + try: + self.apifiles, self.messages, self.services = VPPApiJSONFiles.load_api( + apifiles, apidir + ) + except VPPRuntimeError as e: + if testmode: + self.apifiles = [] + else: + raise e # Basic sanity check if len(self.messages) == 0 and not testmode: |