// Copyright (c) 2020 Cisco and/or its affiliates. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at: // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package binapigen import ( "fmt" "path" "github.com/sirupsen/logrus" ) func init() { RegisterPlugin("rpc", GenerateRPC) } // library dependencies const ( contextPkg = GoImportPath("context") ioPkg = GoImportPath("io") ) // generated names const ( serviceApiName = "RPCService" // name for the RPC service interface serviceImplName = "serviceClient" // name for the RPC service implementation serviceClientName = "ServiceClient" // name for the RPC service client // TODO: register service descriptor //serviceDescType = "ServiceDesc" // name for service descriptor type //serviceDescName = "_ServiceRPC_serviceDesc" // name for service descriptor var ) func GenerateRPC(gen *Generator, file *File) *GenFile { if file.Service == nil { return nil } logf("----------------------------") logf(" Generate RPC - %s", file.Desc.Name) logf("----------------------------") filename := path.Join(file.FilenamePrefix, file.Desc.Name+"_rpc.ba.go") g := gen.NewGenFile(filename, file.GoImportPath) g.file = file // generate file header g.P("// Code generated by GoVPP's binapi-generator. DO NOT EDIT.") g.P() g.P("package ", file.PackageName) g.P() // generate RPC service if len(file.Service.RPCs) > 0 { genService(g, file.Service) } return g } func genService(g *GenFile, svc *Service) { // generate comment g.P("// ", serviceApiName, " defines RPC service ", g.file.Desc.Name, ".") // generate service interface g.P("type ", serviceApiName, " interface {") for _, rpc := range svc.RPCs { g.P(rpcMethodSignature(g, rpc)) } g.P("}") g.P() // generate client implementation g.P("type ", serviceImplName, " struct {") g.P("conn ", govppApiPkg.Ident("Connection")) g.P("}") g.P() // generate client constructor g.P("func New", serviceClientName, "(conn ", govppApiPkg.Ident("Connection"), ") ", serviceApiName, " {") g.P("return &", serviceImplName, "{conn}") g.P("}") g.P() msgControlPingReply, ok := g.gen.messagesByName["control_ping_reply"] if !ok { logrus.Fatalf("no message for %v", "control_ping_reply") } msgControlPing, ok := g.gen.messagesByName["control_ping"] if !ok { logrus.Fatalf("no message for %v", "control_ping") } for _, rpc := range svc.RPCs { logf(" gen RPC: %v (%s)", rpc.GoName, rpc.VPP.Request) g.P("func (c *", serviceImplName, ") ", rpcMethodSignature(g, rpc), " {") if rpc.VPP.Stream { streamImpl := fmt.Sprintf("%s_%sClient", serviceImplName, rpc.GoName) streamApi := fmt.Sprintf("%s_%sClient", serviceApiName, rpc.GoName) msgDetails := rpc.MsgReply var msgReply *Message if rpc.MsgStream != nil { msgDetails = rpc.MsgStream msgReply = rpc.MsgReply } else { msgDetails = rpc.MsgReply msgReply = msgControlPingReply } g.P("stream, err := c.conn.NewStream(ctx)") g.P("if err != nil { return nil, err }") g.P("x := &", streamImpl, "{stream}") g.P("if err := x.Stream.SendMsg(in); err != nil {") g.P(" return nil, err") g.P("}") if rpc.MsgStream == nil { g.P("if err = x.Stream.SendMsg(&", msgControlPing.GoIdent, "{}); err != nil {") g.P(" return nil, err") g.P("}") } g.P("return x, nil") g.P("}") g.P() g.P("type ", streamApi, " interface {") g.P(" Recv() (*", msgDetails.GoIdent, ", error)") g.P(" ", govppApiPkg.Ident("Stream")) g.P("}") g.P() g.P("type ", streamImpl, " struct {") g.P(" ", govppApiPkg.Ident("Stream")) g.P("}") g.P() g.P("func (c *", streamImpl, ") Recv() (*", msgDetails.GoIdent, ", error) {") g.P(" msg, err := c.Stream.RecvMsg()") g.P(" if err != nil { return nil, err }") g.P(" switch m := msg.(type) {") g.P(" case *", msgDetails.GoIdent, ":") g.P(" return m, nil") g.P(" case *", msgReply.GoIdent, ":") g.P(" err = c.Stream.Close()") g.P(" if err != nil { return nil, err }") g.P(" return nil, ", ioPkg.Ident("EOF")) g.P(" default:") g.P(" return nil, ", fmtPkg.Ident("Errorf"), "(\"unexpected message: %T %v\", m, m)") g.P("}") } else if rpc.MsgReply != nil { g.P("out := new(", rpc.MsgReply.GoIdent, ")") g.P("err := c.conn.Invoke(ctx, in, out)") g.P("if err != nil { return nil, err }") if retvalField := getRetvalField(rpc.MsgReply); retvalField != nil { if fieldType := getFieldType(g, retvalField); fieldType == "int32" { g.P("return out, ", govppApiPkg.Ident("RetvalToVPPApiError"), "(out.", retvalField.GoName, ")") } else { g.P("return out, ", govppApiPkg.Ident("RetvalToVPPApiError"), "(int32(out.", retvalField.GoName, "))") } } else { g.P("return out, nil") } } else { g.P("stream, err := c.conn.NewStream(ctx)") g.P("if err != nil { return err }") g.P("err = stream.SendMsg(in)") g.P("if err != nil { return err }") g.P("err = stream.Close()") g.P("if err != nil { return err }") g.P("return nil") } g.P("}") g.P() } // TODO: generate service descriptor /*fmt.Fprintf(w, "var %s = api.%s{\n", serviceDescName, serviceDescType) fmt.Fprintf(w, "\tServiceName: \"%s\",\n", ctx.moduleName) fmt.Fprintf(w, "\tHandlerType: (*%s)(nil),\n", serviceApiName) fmt.Fprintf(w, "\tMethods: []api.MethodDesc{\n") for _, method := range rpcs { fmt.Fprintf(w, "\t {\n") fmt.Fprintf(w, "\t MethodName: \"%s\",\n", method.Name) fmt.Fprintf(w, "\t },\n") } fmt.Fprintf(w, "\t},\n") //fmt.Fprintf(w, "\tCompatibility: %s,\n", messageCrcName) //fmt.Fprintf(w, "\tMetadata: reflect.TypeOf((*%s)(nil)).Elem().PkgPath(),\n", serviceApiName) fmt.Fprintf(w, "\tMetadata: \"%s\",\n", ctx.inputFile) fmt.Fprintln(w, "}")*/ g.P() } func rpcMethodSignature(g *GenFile, rpc *RPC) string { s := rpc.GoName + "(ctx " + g.GoIdent(contextPkg.Ident("Context")) s += ", in *" + g.GoIdent(rpc.MsgRequest.GoIdent) + ") (" if rpc.VPP.Stream { s += serviceApiName + "_" + rpc.GoName + "Client, " } else if rpc.MsgReply != nil { s += "*" + g.GoIdent(rpc.MsgReply.GoIdent) + ", " } s += "error)" return s }