diff options
Diffstat (limited to 'src/tools')
-rwxr-xr-x | src/tools/vppapigen/test_vppapigen.py | 33 | ||||
-rwxr-xr-x | src/tools/vppapigen/vppapigen.py | 49 |
2 files changed, 56 insertions, 26 deletions
diff --git a/src/tools/vppapigen/test_vppapigen.py b/src/tools/vppapigen/test_vppapigen.py index a8a0a49a8db..5b64310e51f 100755 --- a/src/tools/vppapigen/test_vppapigen.py +++ b/src/tools/vppapigen/test_vppapigen.py @@ -26,8 +26,8 @@ class TestTypedef(unittest.TestCase): def test_duplicatetype(self): test_string = ''' - typeonly define foo1 { u8 dummy; }; - typeonly define foo1 { u8 dummy; }; + typedef foo1 { u8 dummy; }; + typedef foo1 { u8 dummy; }; ''' self.assertRaises(KeyError, self.parser.parse_string, test_string) @@ -39,23 +39,29 @@ class TestDefine(unittest.TestCase): def test_unknowntype(self): test_string = 'define foo { foobar foo;};' - self.assertRaises(ParseError, self.parser.parse_string, test_string) + with self.assertRaises(ParseError) as ctx: + self.parser.parse_string(test_string) + self.assertIn('Undefined type: foobar', str(ctx.exception)) + test_string = 'define { u8 foo;};' - self.assertRaises(ParseError, self.parser.parse_string, test_string) + with self.assertRaises(ParseError) as ctx: + self.parser.parse_string(test_string) def test_flags(self): test_string = ''' manual_print dont_trace manual_endian define foo { u8 foo; }; + define foo_reply {u32 context; i32 retval; }; ''' r = self.parser.parse_string(test_string) self.assertIsNotNone(r) s = self.parser.process(r) self.assertIsNotNone(s) - for d in s['defines']: - self.assertTrue(d.dont_trace) - self.assertTrue(d.manual_endian) - self.assertTrue(d.manual_print) - self.assertFalse(d.autoreply) + for d in s['Define']: + if d.name == 'foo': + self.assertTrue(d.dont_trace) + self.assertTrue(d.manual_endian) + self.assertTrue(d.manual_print) + self.assertFalse(d.autoreply) test_string = ''' nonexisting_flag define foo { u8 foo; }; @@ -71,11 +77,14 @@ class TestService(unittest.TestCase): def test_service(self): test_string = ''' - service foo { rpc foo (show_version) returns (show_version) }; + autoreply define show_version { u8 foo;}; + service { rpc show_version returns show_version_reply; }; ''' r = self.parser.parse_string(test_string) - print('R', r) + s = self.parser.process(r) + self.assertEqual(s['Service'][0].caller, 'show_version') + self.assertEqual(s['Service'][0].reply, 'show_version_reply') if __name__ == '__main__': - unittest.main() + unittest.main(verbosity=2) diff --git a/src/tools/vppapigen/vppapigen.py b/src/tools/vppapigen/vppapigen.py index fa7e47afb73..bb4e2c4f5cf 100755 --- a/src/tools/vppapigen/vppapigen.py +++ b/src/tools/vppapigen/vppapigen.py @@ -22,10 +22,15 @@ sys.dont_write_bytecode = True # Global dictionary of new types (including enums) global_types = {} +seen_imports = {} + def global_type_add(name, obj): '''Add new type to the dictionary of types ''' type_name = 'vl_api_' + name + '_t' + if type_name in global_types: + raise KeyError("Attempted redefinition of {!r} with {!r}.".format( + name, obj)) global_types[type_name] = obj @@ -320,20 +325,35 @@ class Enum(): class Import(): - def __init__(self, filename): - self.filename = filename - # Deal with imports - parser = VPPAPI(filename=filename) - dirlist = dirlist_get() - f = filename - for dir in dirlist: - f = os.path.join(dir, filename) - if os.path.exists(f): - break + def __new__(cls, *args, **kwargs): + if args[0] not in seen_imports: + instance = super().__new__(cls) + instance._initialized = False + seen_imports[args[0]] = instance + + return seen_imports[args[0]] - with open(f, encoding='utf-8') as fd: - self.result = parser.parse_file(fd, None) + def __init__(self, filename): + if self._initialized: + return + else: + self.filename = filename + # Deal with imports + parser = VPPAPI(filename=filename) + dirlist = dirlist_get() + f = filename + for dir in dirlist: + f = os.path.join(dir, filename) + if os.path.exists(f): + break + if sys.version[0] == '2': + with open(f) as fd: + self.result = parser.parse_file(fd, None) + else: + with open(f, encoding='utf-8') as fd: + self.result = parser.parse_file(fd, None) + self._initialized = True def __repr__(self): return self.filename @@ -859,9 +879,10 @@ class VPPAPI(object): continue if isinstance(o, Import): result.append(o) - self.process_imports(o.result, True, result) + result = self.process_imports(o.result, True, result) else: result.append(o) + return result # Add message ids to each message. @@ -955,7 +976,7 @@ def main(): if args.output_module == 'C': s = parser.process(parsed_objects) else: - parser.process_imports(parsed_objects, False, result) + result = parser.process_imports(parsed_objects, False, result) s = parser.process(result) # Add msg_id field |