From ceed73403bdb61387d04be8b47183e9c4a970749 Mon Sep 17 00:00:00 2001 From: Ondrej Fabry Date: Tue, 23 Jun 2020 14:10:53 +0200 Subject: Fix codec fallback and generate type imports Change-Id: Idd76c7f19d952939caf153928ac60175845078ff Signed-off-by: Ondrej Fabry --- binapigen/binapigen.go | 9 -- binapigen/generate.go | 306 +++++++++++++++++------------------ binapigen/generate_rpc.go | 49 ++++-- binapigen/generate_test.go | 25 ++- binapigen/generator.go | 127 +++++++++++++-- binapigen/generator_test.go | 9 ++ binapigen/run.go | 65 ++++---- binapigen/validate.go | 3 +- binapigen/vppapi/integration_test.go | 2 - 9 files changed, 343 insertions(+), 252 deletions(-) (limited to 'binapigen') diff --git a/binapigen/binapigen.go b/binapigen/binapigen.go index 0178476..c5a976b 100644 --- a/binapigen/binapigen.go +++ b/binapigen/binapigen.go @@ -98,9 +98,7 @@ func (file *File) importedFiles(gen *Generator) []*File { logf("file %s import %s not found API files", file.Name, imp) continue } - //if gen.ImportTypes || impFile.Generate { files = append(files, impFile) - //} } return files } @@ -302,13 +300,6 @@ func newField(gen *Generator, file *File, apitype vppapi.Field) *Field { return typ } -func (f *Field) resolveType(gen *Generator) error { - switch { - - } - return nil -} - type Service = vppapi.Service type RPC = vppapi.RPC diff --git a/binapigen/generate.go b/binapigen/generate.go index 1f9b89a..8a34445 100644 --- a/binapigen/generate.go +++ b/binapigen/generate.go @@ -15,16 +15,13 @@ package binapigen import ( - "bytes" "fmt" "io" - "os/exec" - "path" - "path/filepath" "sort" "strings" "git.fd.io/govpp.git/version" + "github.com/sirupsen/logrus" ) // generatedCodeVersion indicates a version of the generated code. @@ -33,7 +30,7 @@ import ( // a constant, api.GoVppAPIPackageIsVersionN (where N is generatedCodeVersion). const generatedCodeVersion = 2 -// message field names +// common message fields const ( msgIdField = "_vl_msg_id" clientIndexField = "client_index" @@ -41,23 +38,16 @@ const ( retvalField = "retval" ) +// global API info const ( - outputFileExt = ".ba.go" // file extension of the Go generated files - rpcFileSuffix = "_rpc" // file name suffix for the RPC services - constModuleName = "ModuleName" // module name constant constAPIVersion = "APIVersion" // API version constant constVersionCrc = "VersionCrc" // version CRC constant +) +// generated fiels +const ( unionDataField = "XXX_UnionData" // name for the union data field - - serviceApiName = "RPCService" // name for the RPC service interface - serviceImplName = "serviceClient" // name for the RPC service implementation - serviceClientName = "ServiceClient" // name for the RPC service client - - // TODO: register service descriptor - //serviceDescType = "ServiceDesc" // name for service descriptor type - //serviceDescName = "_ServiceRPC_serviceDesc" // name for service descriptor var ) // MessageType represents the type of a VPP message @@ -70,22 +60,90 @@ const ( otherMessage // other VPP message ) -type GenFile struct { - *Generator - filename string - file *File - packageDir string - buf bytes.Buffer -} - -func generatePackage(ctx *GenFile, w io.Writer) { +func generateFileBinapi(ctx *GenFile, w io.Writer) { logf("----------------------------") - logf("generating binapi package: %q", ctx.file.PackageName) + logf("generating BINAPI file package: %q", ctx.file.PackageName) logf("----------------------------") - generateHeader(ctx, w) + // generate file header + fmt.Fprintln(w, "// Code generated by GoVPP's binapi-generator. DO NOT EDIT.") + fmt.Fprintln(w, "// versions:") + fmt.Fprintf(w, "// binapi-generator: %s\n", version.Version()) + if ctx.IncludeVppVersion { + fmt.Fprintf(w, "// VPP: %s\n", ctx.VPPVersion) + } + fmt.Fprintf(w, "// source: %s\n", ctx.file.Path) + fmt.Fprintln(w) + + generatePackageHeader(ctx, w) generateImports(ctx, w) + generateApiInfo(ctx, w) + generateTypes(ctx, w) + generateMessages(ctx, w) + + generateImportRefs(ctx, w) +} + +func generatePackageHeader(ctx *GenFile, w io.Writer) { + fmt.Fprintln(w, "/*") + fmt.Fprintf(w, "Package %s contains generated code for VPP API file %s.api (%s).\n", + ctx.file.PackageName, ctx.file.Name, ctx.file.Version()) + fmt.Fprintln(w) + fmt.Fprintln(w, "It consists of:") + printObjNum := func(obj string, num int) { + if num > 0 { + if num > 1 { + if strings.HasSuffix(obj, "s") { + obj += "es" + } else { + obj += "s" + } + } + fmt.Fprintf(w, "\t%3d %s\n", num, obj) + } + } + printObjNum("alias", len(ctx.file.Aliases)) + printObjNum("enum", len(ctx.file.Enums)) + printObjNum("message", len(ctx.file.Messages)) + printObjNum("type", len(ctx.file.Structs)) + printObjNum("union", len(ctx.file.Unions)) + fmt.Fprintln(w, "*/") + fmt.Fprintf(w, "package %s\n", ctx.file.PackageName) + fmt.Fprintln(w) +} + +func generateImports(ctx *GenFile, w io.Writer) { + fmt.Fprintln(w, "import (") + fmt.Fprintln(w, ` "bytes"`) + fmt.Fprintln(w, ` "context"`) + fmt.Fprintln(w, ` "encoding/binary"`) + fmt.Fprintln(w, ` "io"`) + fmt.Fprintln(w, ` "math"`) + fmt.Fprintln(w, ` "strconv"`) + fmt.Fprintln(w) + fmt.Fprintf(w, "\tapi \"%s\"\n", "git.fd.io/govpp.git/api") + fmt.Fprintf(w, "\tcodec \"%s\"\n", "git.fd.io/govpp.git/codec") + fmt.Fprintf(w, "\tstruc \"%s\"\n", "github.com/lunixbochs/struc") + imports := listImports(ctx) + if len(imports) > 0 { + fmt.Fprintln(w) + for imp, importPath := range imports { + fmt.Fprintf(w, "\t%s \"%s\"\n", imp, importPath) + } + } + fmt.Fprintln(w, ")") + fmt.Fprintln(w) + + fmt.Fprintln(w, "// This is a compile-time assertion to ensure that this generated file") + fmt.Fprintln(w, "// is compatible with the GoVPP api package it is being compiled against.") + fmt.Fprintln(w, "// A compilation error at this line likely means your copy of the") + fmt.Fprintln(w, "// GoVPP api package needs to be updated.") + fmt.Fprintf(w, "const _ = api.GoVppAPIPackageIsVersion%d // please upgrade the GoVPP api package\n", generatedCodeVersion) + fmt.Fprintln(w) +} + +func generateApiInfo(ctx *GenFile, w io.Writer) { // generate module desc fmt.Fprintln(w, "const (") fmt.Fprintf(w, "\t// %s is the name of this module.\n", constModuleName) @@ -99,7 +157,9 @@ func generatePackage(ctx *GenFile, w io.Writer) { } fmt.Fprintln(w, ")") fmt.Fprintln(w) +} +func generateTypes(ctx *GenFile, w io.Writer) { // generate enums if len(ctx.file.Enums) > 0 { for _, enum := range ctx.file.Enums { @@ -143,129 +203,41 @@ func generatePackage(ctx *GenFile, w io.Writer) { generateUnion(ctx, w, union) } } - - // generate messages - if len(ctx.file.Messages) > 0 { - for _, msg := range ctx.file.Messages { - generateMessage(ctx, w, msg) - } - - initFnName := fmt.Sprintf("file_%s_binapi_init", ctx.file.PackageName) - - // generate message registrations - fmt.Fprintf(w, "func init() { %s() }\n", initFnName) - fmt.Fprintf(w, "func %s() {\n", initFnName) - for _, msg := range ctx.file.Messages { - fmt.Fprintf(w, "\tapi.RegisterMessage((*%s)(nil), \"%s\")\n", - msg.GoName, ctx.file.Name+"."+msg.GoName) - } - fmt.Fprintln(w, "}") - fmt.Fprintln(w) - - // generate list of messages - fmt.Fprintf(w, "// Messages returns list of all messages in this module.\n") - fmt.Fprintln(w, "func AllMessages() []api.Message {") - fmt.Fprintln(w, "\treturn []api.Message{") - for _, msg := range ctx.file.Messages { - fmt.Fprintf(w, "\t(*%s)(nil),\n", msg.GoName) - } - fmt.Fprintln(w, "}") - fmt.Fprintln(w, "}") - } - - generateFooter(ctx, w) - } -func generateHeader(ctx *GenFile, w io.Writer) { - fmt.Fprintln(w, "// Code generated by GoVPP's binapi-generator. DO NOT EDIT.") - fmt.Fprintln(w, "// versions:") - fmt.Fprintf(w, "// binapi-generator: %s\n", version.Version()) - if ctx.IncludeVppVersion { - fmt.Fprintf(w, "// VPP: %s\n", ctx.VPPVersion) +func generateMessages(ctx *GenFile, w io.Writer) { + if len(ctx.file.Messages) == 0 { + return } - fmt.Fprintf(w, "// source: %s\n", ctx.file.Path) - fmt.Fprintln(w) - - fmt.Fprintln(w, "/*") - fmt.Fprintf(w, "Package %s contains generated code for VPP binary API defined by %s.api (version %s).\n", - ctx.file.PackageName, ctx.file.Name, ctx.file.Version()) - fmt.Fprintln(w) - fmt.Fprintln(w, "It consists of:") - printObjNum := func(obj string, num int) { - if num > 0 { - if num > 1 { - if strings.HasSuffix(obj, "s") { - obj += "es" - } else { - obj += "s" - } - } - fmt.Fprintf(w, "\t%3d %s\n", num, obj) - } + for _, msg := range ctx.file.Messages { + generateMessage(ctx, w, msg) } - //printObjNum("RPC", len(ctx.file.Service.RPCs)) - printObjNum("alias", len(ctx.file.Aliases)) - printObjNum("enum", len(ctx.file.Enums)) - printObjNum("message", len(ctx.file.Messages)) - printObjNum("type", len(ctx.file.Structs)) - printObjNum("union", len(ctx.file.Unions)) - fmt.Fprintln(w, "*/") - fmt.Fprintf(w, "package %s\n", ctx.file.PackageName) - fmt.Fprintln(w) -} -func generateImports(ctx *GenFile, w io.Writer) { - fmt.Fprintln(w, "import (") - fmt.Fprintln(w, ` "bytes"`) - fmt.Fprintln(w, ` "context"`) - fmt.Fprintln(w, ` "encoding/binary"`) - fmt.Fprintln(w, ` "io"`) - fmt.Fprintln(w, ` "math"`) - fmt.Fprintln(w, ` "strconv"`) - fmt.Fprintln(w) - fmt.Fprintf(w, "\tapi \"%s\"\n", "git.fd.io/govpp.git/api") - fmt.Fprintf(w, "\tcodec \"%s\"\n", "git.fd.io/govpp.git/codec") - fmt.Fprintf(w, "\tstruc \"%s\"\n", "github.com/lunixbochs/struc") - if len(ctx.file.Imports) > 0 { - fmt.Fprintln(w) - for _, imp := range ctx.file.Imports { - importPath := path.Join(ctx.ImportPrefix, imp) - if ctx.ImportPrefix == "" { - importPath = getImportPath(ctx.packageDir, imp) - } - fmt.Fprintf(w, "\t%s \"%s\"\n", imp, strings.TrimSpace(importPath)) - } - } - fmt.Fprintln(w, ")") - fmt.Fprintln(w) + // generate message registrations + initFnName := fmt.Sprintf("file_%s_binapi_init", ctx.file.PackageName) - fmt.Fprintln(w, "// This is a compile-time assertion to ensure that this generated file") - fmt.Fprintln(w, "// is compatible with the GoVPP api package it is being compiled against.") - fmt.Fprintln(w, "// A compilation error at this line likely means your copy of the") - fmt.Fprintln(w, "// GoVPP api package needs to be updated.") - fmt.Fprintf(w, "const _ = api.GoVppAPIPackageIsVersion%d // please upgrade the GoVPP api package\n", generatedCodeVersion) + fmt.Fprintf(w, "func init() { %s() }\n", initFnName) + fmt.Fprintf(w, "func %s() {\n", initFnName) + for _, msg := range ctx.file.Messages { + fmt.Fprintf(w, "\tapi.RegisterMessage((*%s)(nil), \"%s\")\n", + msg.GoName, ctx.file.Name+"."+msg.GoName) + } + fmt.Fprintln(w, "}") fmt.Fprintln(w) -} -func getImportPath(outputDir string, pkg string) string { - absPath, err := filepath.Abs(filepath.Join(outputDir, "..", pkg)) - if err != nil { - panic(err) - } - cmd := exec.Command("go", "list", absPath) - var errbuf, outbuf bytes.Buffer - cmd.Stdout = &outbuf - cmd.Stderr = &errbuf - if err := cmd.Run(); err != nil { - fmt.Printf("ERR: %v\n", errbuf.String()) - panic(err) + // generate list of messages + fmt.Fprintf(w, "// Messages returns list of all messages in this module.\n") + fmt.Fprintln(w, "func AllMessages() []api.Message {") + fmt.Fprintln(w, "\treturn []api.Message{") + for _, msg := range ctx.file.Messages { + fmt.Fprintf(w, "\t(*%s)(nil),\n", msg.GoName) } - return outbuf.String() + fmt.Fprintln(w, "}") + fmt.Fprintln(w, "}") } -func generateFooter(ctx *GenFile, w io.Writer) { +func generateImportRefs(ctx *GenFile, w io.Writer) { fmt.Fprintf(w, "// Reference imports to suppress errors if they are not otherwise used.\n") fmt.Fprintf(w, "var _ = api.RegisterMessage\n") fmt.Fprintf(w, "var _ = codec.DecodeString\n") @@ -522,7 +494,7 @@ func generateMessage(ctx *GenFile, w io.Writer, msg *Message) { // skip internal fields switch strings.ToLower(field.Name) { - case /*crcField,*/ msgIdField: + case msgIdField: continue case clientIndexField, contextField: if n == 0 { @@ -590,20 +562,22 @@ func generateMessageSize(ctx *GenFile, w io.Writer, name string, fields []*Field } lvl := 0 - var encodeFields func(fields []*Field, parentName string) - encodeFields = func(fields []*Field, parentName string) { + var sizeFields func(fields []*Field, parentName string) + sizeFields = func(fields []*Field, parentName string) { lvl++ defer func() { lvl-- }() n := 0 for _, field := range fields { - // skip internal fields - switch strings.ToLower(field.Name) { - case /*crcField,*/ msgIdField: - continue - case clientIndexField, contextField: - if n == 0 { + if field.ParentMessage != nil { + // skip internal fields + switch strings.ToLower(field.Name) { + case msgIdField: continue + case clientIndexField, contextField: + if n == 0 { + continue + } } } n++ @@ -646,12 +620,12 @@ func generateMessageSize(ctx *GenFile, w io.Writer, name string, fields []*Field } else if alias := getAliasByRef(ctx.file, field.Type); alias != nil { if encodeBaseType(alias.Type, name, alias.Length, "") { } else if typ := getTypeByRef(ctx.file, alias.Type); typ != nil { - encodeFields(typ.Fields, name) + sizeFields(typ.Fields, name) } else { fmt.Fprintf(w, "\t// ??? ALIAS %s %s\n", name, alias.Type) } } else if typ := getTypeByRef(ctx.file, field.Type); typ != nil { - encodeFields(typ.Fields, name) + sizeFields(typ.Fields, name) } else if union := getUnionByRef(ctx.file, field.Type); union != nil { maxSize := getUnionSize(ctx.file, union) fmt.Fprintf(w, "\tsize += %d\n", maxSize) @@ -665,7 +639,7 @@ func generateMessageSize(ctx *GenFile, w io.Writer, name string, fields []*Field } } - encodeFields(fields, "m") + sizeFields(fields, "m") fmt.Fprintf(w, "return size\n") @@ -786,13 +760,15 @@ func generateMessageMarshal(ctx *GenFile, w io.Writer, name string, fields []*Fi n := 0 for _, field := range fields { - // skip internal fields - switch strings.ToLower(field.Name) { - case /*crcField,*/ msgIdField: - continue - case clientIndexField, contextField: - if n == 0 { + if field.ParentMessage != nil { + // skip internal fields + switch strings.ToLower(field.Name) { + case msgIdField: continue + case clientIndexField, contextField: + if n == 0 { + continue + } } } n++ @@ -1004,13 +980,15 @@ func generateMessageUnmarshal(ctx *GenFile, w io.Writer, name string, fields []* n := 0 for _, field := range fields { - // skip internal fields - switch strings.ToLower(field.Name) { - case /*crcField,*/ msgIdField: - continue - case clientIndexField, contextField: - if n == 0 { + if field.ParentMessage != nil { + // skip internal fields + switch strings.ToLower(field.Name) { + case msgIdField: continue + case clientIndexField, contextField: + if n == 0 { + continue + } } } n++ @@ -1239,3 +1217,7 @@ func generateMessageTypeGetter(w io.Writer, structName string, msgType MessageTy fmt.Fprintln(w, "}") fmt.Fprintln(w) } + +func logf(f string, v ...interface{}) { + logrus.Debugf(f, v...) +} diff --git a/binapigen/generate_rpc.go b/binapigen/generate_rpc.go index b480f4a..4beec04 100644 --- a/binapigen/generate_rpc.go +++ b/binapigen/generate_rpc.go @@ -20,17 +20,31 @@ import ( "strings" ) -func generatePackageRPC(ctx *GenFile, w io.Writer) { +// generated service names +const ( + serviceApiName = "RPCService" // name for the RPC service interface + serviceImplName = "serviceClient" // name for the RPC service implementation + serviceClientName = "ServiceClient" // name for the RPC service client + + // TODO: register service descriptor + //serviceDescType = "ServiceDesc" // name for service descriptor type + //serviceDescName = "_ServiceRPC_serviceDesc" // name for service descriptor var +) + +func generateFileRPC(ctx *GenFile, w io.Writer) { logf("----------------------------") - logf("generating RPC package: %q", ctx.file.PackageName) + logf("generating RPC file package: %q", ctx.file.PackageName) logf("----------------------------") + // generate file header fmt.Fprintln(w, "// Code generated by GoVPP's binapi-generator. DO NOT EDIT.") fmt.Fprintln(w) + // generate package header fmt.Fprintf(w, "package %s\n", ctx.file.PackageName) fmt.Fprintln(w) + // generate imports fmt.Fprintln(w, "import (") fmt.Fprintln(w, ` "context"`) fmt.Fprintln(w, ` "io"`) @@ -39,9 +53,9 @@ func generatePackageRPC(ctx *GenFile, w io.Writer) { fmt.Fprintln(w, ")") fmt.Fprintln(w) - // generate services + // generate RPC service if ctx.file.Service != nil && len(ctx.file.Service.RPCs) > 0 { - generateServiceMethods(ctx, w, ctx.file.Service.RPCs) + generateService(ctx, w, ctx.file.Service) } // generate message registrations @@ -50,6 +64,7 @@ func generatePackageRPC(ctx *GenFile, w io.Writer) { fmt.Fprintln(w, "}") fmt.Fprintln(w)*/ + // generate import refs fmt.Fprintf(w, "// Reference imports to suppress errors if they are not otherwise used.\n") fmt.Fprintf(w, "var _ = api.RegisterMessage\n") fmt.Fprintf(w, "var _ = context.Background\n") @@ -57,14 +72,14 @@ func generatePackageRPC(ctx *GenFile, w io.Writer) { } -func generateServiceMethods(ctx *GenFile, w io.Writer, methods []RPC) { +func generateService(ctx *GenFile, w io.Writer, svc *Service) { // generate services comment generateComment(ctx, w, serviceApiName, "services", "service") // generate service api fmt.Fprintf(w, "type %s interface {\n", serviceApiName) - for _, svc := range methods { - generateServiceMethod(ctx, w, &svc) + for _, rpc := range svc.RPCs { + generateRPCMethod(ctx, w, &rpc) fmt.Fprintln(w) } fmt.Fprintln(w, "}") @@ -82,21 +97,21 @@ func generateServiceMethods(ctx *GenFile, w io.Writer, methods []RPC) { fmt.Fprintln(w, "}") fmt.Fprintln(w) - for _, met := range methods { - method := camelCaseName(met.RequestMsg) + for _, rpc := range svc.RPCs { + method := camelCaseName(rpc.RequestMsg) if m := strings.TrimSuffix(method, "Dump"); method != m { method = "Dump" + m } fmt.Fprintf(w, "func (c *%s) ", serviceImplName) - generateServiceMethod(ctx, w, &met) + generateRPCMethod(ctx, w, &rpc) fmt.Fprintln(w, " {") - if met.Stream { + if rpc.Stream { streamImpl := fmt.Sprintf("%s_%sClient", serviceImplName, method) fmt.Fprintf(w, "\tstream := c.ch.SendMultiRequest(in)\n") fmt.Fprintf(w, "\tx := &%s{stream}\n", streamImpl) fmt.Fprintf(w, "\treturn x, nil\n") - } else if replyTyp := camelCaseName(met.ReplyMsg); replyTyp != "" { + } else if replyTyp := camelCaseName(rpc.ReplyMsg); replyTyp != "" { fmt.Fprintf(w, "\tout := new(%s)\n", replyTyp) fmt.Fprintf(w, "\terr:= c.ch.SendRequest(in).ReceiveReply(out)\n") fmt.Fprintf(w, "\tif err != nil { return nil, err }\n") @@ -108,9 +123,9 @@ func generateServiceMethods(ctx *GenFile, w io.Writer, methods []RPC) { fmt.Fprintln(w, "}") fmt.Fprintln(w) - if met.Stream { - replyTyp := camelCaseName(met.ReplyMsg) - method := camelCaseName(met.RequestMsg) + if rpc.Stream { + replyTyp := camelCaseName(rpc.ReplyMsg) + method := camelCaseName(rpc.RequestMsg) if m := strings.TrimSuffix(method, "Dump"); method != m { method = "Dump" + m } @@ -143,7 +158,7 @@ func generateServiceMethods(ctx *GenFile, w io.Writer, methods []RPC) { fmt.Fprintf(w, "\tServiceName: \"%s\",\n", ctx.moduleName) fmt.Fprintf(w, "\tHandlerType: (*%s)(nil),\n", serviceApiName) fmt.Fprintf(w, "\tMethods: []api.MethodDesc{\n") - for _, method := range methods { + for _, method := range rpcs { fmt.Fprintf(w, "\t {\n") fmt.Fprintf(w, "\t MethodName: \"%s\",\n", method.Name) fmt.Fprintf(w, "\t },\n") @@ -157,7 +172,7 @@ func generateServiceMethods(ctx *GenFile, w io.Writer, methods []RPC) { fmt.Fprintln(w) } -func generateServiceMethod(ctx *GenFile, w io.Writer, rpc *RPC) { +func generateRPCMethod(ctx *GenFile, w io.Writer, rpc *RPC) { reqTyp := camelCaseName(rpc.RequestMsg) logf(" writing RPC: %+v", reqTyp) diff --git a/binapigen/generate_test.go b/binapigen/generate_test.go index 5a2a07a..aab62cd 100644 --- a/binapigen/generate_test.go +++ b/binapigen/generate_test.go @@ -25,14 +25,13 @@ import ( const testOutputDir = "test_output_directory" -func GenerateFromFile(apiDir, outputDir string, opts Options) error { - // parse API files - apifiles, err := vppapi.ParseDir(apiDir) +func GenerateFromFile(file, outputDir string, opts Options) error { + apifile, err := vppapi.ParseFile(file) if err != nil { return err } - g, err := New(opts, apifiles) + g, err := New(opts, []*vppapi.File{apifile}) if err != nil { return err } @@ -40,7 +39,7 @@ func GenerateFromFile(apiDir, outputDir string, opts Options) error { if !file.Generate { continue } - GenerateBinapiFile(g, file, outputDir) + GenerateBinapi(g, file, outputDir) if file.Service != nil { GenerateRPC(g, file, outputDir) } @@ -59,7 +58,7 @@ func TestGenerateFromFile(t *testing.T) { // remove directory created during test defer os.RemoveAll(testOutputDir) - err := GenerateFromFile("testdata/acl.api.json", testOutputDir, Options{}) + err := GenerateFromFile("vppapi/testdata/acl.api.json", testOutputDir, Options{FilesToGenerate: []string{"acl"}}) Expect(err).ShouldNot(HaveOccurred()) fileInfo, err := os.Stat(testOutputDir + "/acl/acl.ba.go") Expect(err).ShouldNot(HaveOccurred()) @@ -70,17 +69,17 @@ func TestGenerateFromFile(t *testing.T) { func TestGenerateFromFileInputError(t *testing.T) { RegisterTestingT(t) - err := GenerateFromFile("testdata/nonexisting.json", testOutputDir, Options{}) + err := GenerateFromFile("vppapi/testdata/nonexisting.json", testOutputDir, Options{}) Expect(err).Should(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("invalid input file name")) + Expect(err.Error()).To(ContainSubstring("unsupported")) } func TestGenerateFromFileReadJsonError(t *testing.T) { RegisterTestingT(t) - err := GenerateFromFile("testdata/input-read-json-error.json", testOutputDir, Options{}) + err := GenerateFromFile("vppapi/testdata/input-read-json-error.json", testOutputDir, Options{}) Expect(err).Should(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("invalid input file name")) + Expect(err.Error()).To(ContainSubstring("unsupported")) } func TestGenerateFromFileGeneratePackageError(t *testing.T) { @@ -94,7 +93,7 @@ func TestGenerateFromFileGeneratePackageError(t *testing.T) { os.RemoveAll(testOutputDir) }() - err := GenerateFromFile("testdata/input-generate-error.json", testOutputDir, Options{}) + err := GenerateFromFile("vppapi/testdata/input-generate-error.json", testOutputDir, Options{}) Expect(err).Should(HaveOccurred()) } @@ -147,7 +146,7 @@ func TestGetContextInterfaceJson(t *testing.T) { // prepare writer writer := bufio.NewWriter(outFile) Expect(writer.Buffered()).To(BeZero()) - err = generatePackage(testCtx, writer) + err = generateFileBinapi(testCtx, writer) Expect(err).ShouldNot(HaveOccurred()) } @@ -313,7 +312,7 @@ func TestGeneratePackageHeader(t *testing.T) { // prepare writer writer := bufio.NewWriter(outFile) Expect(writer.Buffered()).To(BeZero()) - generateHeader(testCtx, writer, inFile) + generatePackageHeader(testCtx, writer, inFile) Expect(writer.Buffered()).ToNot(BeZero()) } diff --git a/binapigen/generator.go b/binapigen/generator.go index 9471462..07c1b13 100644 --- a/binapigen/generator.go +++ b/binapigen/generator.go @@ -20,7 +20,9 @@ import ( "go/format" "io/ioutil" "os" + "path" "path/filepath" + "regexp" "github.com/sirupsen/logrus" @@ -94,29 +96,45 @@ func New(opts Options, apifiles []*vppapi.File) (*Generator, error) { } } - logrus.Debugf("Checking %d files to generate: %v", len(opts.FilesToGenerate), opts.FilesToGenerate) - for _, genfile := range opts.FilesToGenerate { - file, ok := g.FilesByPath[genfile] - if !ok { - file, ok = g.FilesByName[genfile] + if len(opts.FilesToGenerate) > 0 { + logrus.Debugf("Checking %d files to generate: %v", len(opts.FilesToGenerate), opts.FilesToGenerate) + for _, genfile := range opts.FilesToGenerate { + file, ok := g.FilesByPath[genfile] if !ok { - return nil, fmt.Errorf("no API file found for: %v", genfile) + file, ok = g.FilesByName[genfile] + if !ok { + return nil, fmt.Errorf("no API file found for: %v", genfile) + } } - } - file.Generate = true - if opts.ImportTypes { - for _, impFile := range file.importedFiles(g) { - impFile.Generate = true + file.Generate = true + if opts.ImportTypes { + // generate all imported files + for _, impFile := range file.importedFiles(g) { + impFile.Generate = true + } } } + } else { + logrus.Debugf("Files to generate not specified, marking all %d files to generate", len(g.Files)) + for _, file := range g.Files { + file.Generate = true + } } logrus.Debugf("Resolving imported types") for _, file := range g.Files { if !file.Generate { + // skip resolving for non-generated files continue } - importedFiles := file.importedFiles(g) + var importedFiles []*File + for _, impFile := range file.importedFiles(g) { + if !impFile.Generate { + // exclude imports of non-generated files + continue + } + importedFiles = append(importedFiles, impFile) + } file.loadTypeImports(g, importedFiles) } @@ -130,13 +148,21 @@ func (g *Generator) Generate() error { logrus.Infof("Generating %d files", len(g.genfiles)) for _, genfile := range g.genfiles { - if err := writeSourceTo(genfile.filename, genfile.buf.Bytes()); err != nil { + if err := writeSourceTo(genfile.filename, genfile.Content()); err != nil { return fmt.Errorf("writing source for RPC package %s failed: %v", genfile.filename, err) } } return nil } +type GenFile struct { + *Generator + filename string + file *File + outputDir string + buf bytes.Buffer +} + func (g *Generator) NewGenFile(filename string) *GenFile { f := &GenFile{ Generator: g, @@ -146,6 +172,10 @@ func (g *Generator) NewGenFile(filename string) *GenFile { return f } +func (f *GenFile) Content() []byte { + return f.buf.Bytes() +} + func writeSourceTo(outputFile string, b []byte) error { // create output directory packageDir := filepath.Dir(outputFile) @@ -170,3 +200,74 @@ func writeSourceTo(outputFile string, b []byte) error { return nil } + +func listImports(genfile *GenFile) map[string]string { + var importPath = genfile.ImportPrefix + if importPath == "" { + importPath = resolveImportPath(genfile.outputDir) + logrus.Debugf("resolved import path: %s", importPath) + } + imports := map[string]string{} + for _, imp := range genfile.file.imports { + if _, ok := imports[imp]; !ok { + imports[imp] = path.Join(importPath, imp) + } + } + return imports +} + +func resolveImportPath(outputDir string) string { + absPath, err := filepath.Abs(outputDir) + if err != nil { + panic(err) + } + modRoot := findModuleRoot(absPath) + if modRoot == "" { + logrus.Fatalf("module root not found at: %s", absPath) + } + modPath := findModulePath(path.Join(modRoot, "go.mod")) + if modPath == "" { + logrus.Fatalf("module path not found") + } + relDir, err := filepath.Rel(modRoot, absPath) + if err != nil { + panic(err) + } + return filepath.Join(modPath, relDir) +} + +func findModuleRoot(dir string) (root string) { + if dir == "" { + panic("dir not set") + } + dir = filepath.Clean(dir) + + // Look for enclosing go.mod. + for { + if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() { + return dir + } + d := filepath.Dir(dir) + if d == dir { + break + } + dir = d + } + return "" +} + +var ( + modulePathRE = regexp.MustCompile(`module[ \t]+([^ \t\r\n]+)`) +) + +func findModulePath(file string) string { + data, err := ioutil.ReadFile(file) + if err != nil { + return "" + } + m := modulePathRE.FindSubmatch(data) + if m == nil { + return "" + } + return string(m[1]) +} diff --git a/binapigen/generator_test.go b/binapigen/generator_test.go index ddbda99..9e5b342 100644 --- a/binapigen/generator_test.go +++ b/binapigen/generator_test.go @@ -18,6 +18,15 @@ import ( "testing" ) +func TestModule(t *testing.T) { + const expected = "git.fd.io/govpp.git/examples/binapi" + + impPath := resolveImportPath("../examples/binapi") + if impPath != expected { + t.Fatalf("expected: %q, got: %q", expected, impPath) + } +} + func TestBinapiTypeSizes(t *testing.T) { tests := []struct { name string diff --git a/binapigen/run.go b/binapigen/run.go index 441c43d..e6086ee 100644 --- a/binapigen/run.go +++ b/binapigen/run.go @@ -19,44 +19,13 @@ import ( "os" "path/filepath" - "github.com/sirupsen/logrus" - "git.fd.io/govpp.git/binapigen/vppapi" ) -var debugMode = true - -func logf(f string, v ...interface{}) { - if debugMode { - logrus.Debugf(f, v...) - } -} - -func GenerateBinapiFile(gen *Generator, file *File, outputDir string) *GenFile { - packageDir := filepath.Join(outputDir, file.PackageName) - filename := filepath.Join(packageDir, file.PackageName+outputFileExt) - - g := gen.NewGenFile(filename) - g.file = file - g.packageDir = filepath.Join(outputDir, file.PackageName) - - generatePackage(g, &g.buf) - - return g -} - -func GenerateRPC(gen *Generator, file *File, outputDir string) *GenFile { - packageDir := filepath.Join(outputDir, file.PackageName) - filename := filepath.Join(packageDir, file.PackageName+rpcFileSuffix+outputFileExt) - - g := gen.NewGenFile(filename) - g.file = file - g.packageDir = filepath.Join(outputDir, file.PackageName) - - generatePackageRPC(g, &g.buf) - - return g -} +const ( + outputFileExt = ".ba.go" // file extension of the Go generated files + rpcFileSuffix = "_rpc" // file name suffix for the RPC services +) func Run(apiDir string, opts Options, f func(*Generator) error) { if err := run(apiDir, opts, f); err != nil { @@ -87,3 +56,29 @@ func run(apiDir string, opts Options, f func(*Generator) error) error { return nil } + +func GenerateBinapi(gen *Generator, file *File, outputDir string) *GenFile { + packageDir := filepath.Join(outputDir, file.PackageName) + filename := filepath.Join(packageDir, file.PackageName+outputFileExt) + + g := gen.NewGenFile(filename) + g.file = file + g.outputDir = outputDir + + generateFileBinapi(g, &g.buf) + + return g +} + +func GenerateRPC(gen *Generator, file *File, outputDir string) *GenFile { + packageDir := filepath.Join(outputDir, file.PackageName) + filename := filepath.Join(packageDir, file.PackageName+rpcFileSuffix+outputFileExt) + + g := gen.NewGenFile(filename) + g.file = file + g.outputDir = outputDir + + generateFileRPC(g, &g.buf) + + return g +} diff --git a/binapigen/validate.go b/binapigen/validate.go index 2dae903..a79e148 100644 --- a/binapigen/validate.go +++ b/binapigen/validate.go @@ -17,8 +17,9 @@ package binapigen import ( "strings" - "git.fd.io/govpp.git/binapigen/vppapi" "github.com/sirupsen/logrus" + + "git.fd.io/govpp.git/binapigen/vppapi" ) const ( diff --git a/binapigen/vppapi/integration_test.go b/binapigen/vppapi/integration_test.go index 142017a..9d619b8 100644 --- a/binapigen/vppapi/integration_test.go +++ b/binapigen/vppapi/integration_test.go @@ -30,7 +30,6 @@ func TestParse(t *testing.T) { } for _, file := range files { - //t.Logf(" - %s: %+v", path, module) b, err := json.MarshalIndent(file, "\t", " ") if err != nil { t.Fatal(err) @@ -39,5 +38,4 @@ func TestParse(t *testing.T) { } t.Logf("parsed %d files", len(files)) - } -- cgit 1.2.3-korg