diff options
author | Matus Mrekaj <matus.mrekaj@pantheon.tech> | 2019-10-22 15:05:39 +0200 |
---|---|---|
committer | Matus Mrekaj <matus.mrekaj@pantheon.tech> | 2019-10-30 14:42:35 +0100 |
commit | 58601b470bbd4e5ef534fed83511aa5a7f1c2d1e (patch) | |
tree | 1c0c1176567d66e1b7be45c51f445dd5baa28dee | |
parent | cc80dbcaaaca8bf1b6042fead850d456cf589a4e (diff) |
fix data races in proxy server
Signed-off-by: Matus Mrekaj <matus.mrekaj@pantheon.tech>
Change-Id: I932d560548ee816e28683243a7318a2a7fbbb24a
-rw-r--r-- | cmd/vpp-proxy/main.go | 9 | ||||
-rw-r--r-- | core/connection.go | 31 | ||||
-rw-r--r-- | core/control_ping.go | 4 | ||||
-rw-r--r-- | core/request_handler.go | 2 | ||||
-rw-r--r-- | go.mod | 2 | ||||
-rw-r--r-- | go.sum | 2 | ||||
-rw-r--r-- | proxy/client.go | 73 | ||||
-rw-r--r-- | proxy/log.go | 30 | ||||
-rw-r--r-- | proxy/proxy.go | 84 | ||||
-rw-r--r-- | proxy/server.go | 331 |
10 files changed, 442 insertions, 126 deletions
diff --git a/cmd/vpp-proxy/main.go b/cmd/vpp-proxy/main.go index de1b7b4..5221218 100644 --- a/cmd/vpp-proxy/main.go +++ b/cmd/vpp-proxy/main.go @@ -120,7 +120,10 @@ func runClient() { } func runServer() { - p := proxy.NewServer() + p, err := proxy.NewServer() + if err != nil { + log.Fatalln(err) + } statsAdapter := statsclient.NewStatsClient(*statsSocket) binapiAdapter := socketclient.NewVppClient(*binapiSocket) @@ -135,5 +138,7 @@ func runServer() { } defer p.DisconnectBinapi() - p.ListenAndServe(*proxyAddr) + if err := p.ListenAndServe(*proxyAddr); err != nil { + log.Fatalln(err) + } } diff --git a/core/connection.go b/core/connection.go index 6f82616..264ec43 100644 --- a/core/connection.go +++ b/core/connection.go @@ -111,6 +111,9 @@ type Connection struct { lastReplyLock sync.Mutex // lock for the last reply lastReply time.Time // time of the last received reply from VPP + + msgControlPing api.Message + msgControlPingReply api.Message } func newConnection(binapi adapter.VppAPI, attempts int, interval time.Duration) *Connection { @@ -122,14 +125,16 @@ func newConnection(binapi adapter.VppAPI, attempts int, interval time.Duration) } c := &Connection{ - vppClient: binapi, - maxAttempts: attempts, - recInterval: interval, - codec: &codec.MsgCodec{}, - msgIDs: make(map[string]uint16), - msgMap: make(map[uint16]api.Message), - channels: make(map[uint16]*Channel), - subscriptions: make(map[uint16][]*subscriptionCtx), + vppClient: binapi, + maxAttempts: attempts, + recInterval: interval, + codec: &codec.MsgCodec{}, + msgIDs: make(map[string]uint16), + msgMap: make(map[uint16]api.Message), + channels: make(map[uint16]*Channel), + subscriptions: make(map[uint16][]*subscriptionCtx), + msgControlPing: msgControlPing, + msgControlPingReply: msgControlPingReply, } binapi.SetMsgCallback(c.msgCallback) return c @@ -314,7 +319,7 @@ func (c *Connection) healthCheckLoop(connChan chan ConnectionEvent) { } // send the control ping request - ch.reqChan <- &vppRequest{msg: msgControlPing} + ch.reqChan <- &vppRequest{msg: c.msgControlPing} for { // expect response within timeout period @@ -427,12 +432,12 @@ func (c *Connection) retrieveMessageIDs() (err error) { } n++ - if c.pingReqID == 0 && msg.GetMessageName() == msgControlPing.GetMessageName() { + if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() { c.pingReqID = msgID - msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) - } else if c.pingReplyID == 0 && msg.GetMessageName() == msgControlPingReply.GetMessageName() { + c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) + } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() { c.pingReplyID = msgID - msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) + c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) } if debugMsgIDs { diff --git a/core/control_ping.go b/core/control_ping.go index cd447b7..b39fd3f 100644 --- a/core/control_ping.go +++ b/core/control_ping.go @@ -1,6 +1,8 @@ package core -import "git.fd.io/govpp.git/api" +import ( + "git.fd.io/govpp.git/api" +) var ( msgControlPing api.Message = new(ControlPing) diff --git a/core/request_handler.go b/core/request_handler.go index d3f7bdc..ddd5307 100644 --- a/core/request_handler.go +++ b/core/request_handler.go @@ -110,7 +110,7 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error { if req.multi { // send a control ping to determine end of the multipart response - pingData, _ := c.codec.EncodeMsg(msgControlPing, c.pingReqID) + pingData, _ := c.codec.EncodeMsg(c.msgControlPing, c.pingReqID) log.WithFields(logger.Fields{ "channel": ch.id, @@ -9,7 +9,7 @@ require ( github.com/golang/protobuf v1.3.2 // indirect github.com/hpcloud/tail v1.0.0 // indirect github.com/kr/pretty v0.1.0 // indirect - github.com/lunixbochs/struc v0.0.0-20180408203800-02e4c2afbb2a + github.com/lunixbochs/struc v0.0.0-20190916212049-a5c72983bc42 github.com/onsi/ginkgo v1.8.0 // indirect github.com/onsi/gomega v1.1.0 github.com/pkg/profile v1.2.1 @@ -17,6 +17,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lunixbochs/struc v0.0.0-20180408203800-02e4c2afbb2a h1:axFx97V2Lyke5LbeygrJlzc07mwVhHt2ZHeI/Nv8Aq4= github.com/lunixbochs/struc v0.0.0-20180408203800-02e4c2afbb2a/go.mod h1:iOJu9pApjjmEmNq7PqlA5R9mDu/HMF5EM3llWKX/TyA= +github.com/lunixbochs/struc v0.0.0-20190916212049-a5c72983bc42 h1:PzBD7QuxXSgSu61TKXxRwVGzWO5d9QZ0HxFFpndZMCg= +github.com/lunixbochs/struc v0.0.0-20190916212049-a5c72983bc42/go.mod h1:vy1vK6wD6j7xX6O6hXe621WabdtNkou2h7uRtTfRMyg= github.com/onsi/ginkgo v1.8.0 h1:VkHVNpR4iVnU8XQR6DBm8BqYjN7CRzw+xKUbVVbbW9w= github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.1.0 h1:e3YP4dN/HYPpGh29X1ZkcxcEICsOls9huyVCRBaxjq8= diff --git a/proxy/client.go b/proxy/client.go index 4f2df0f..7f92946 100644 --- a/proxy/client.go +++ b/proxy/client.go @@ -1,8 +1,22 @@ +// Copyright (c) 2019 Cisco and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package proxy import ( "fmt" - "log" + "git.fd.io/govpp.git/core" "net/rpc" "reflect" "time" @@ -40,7 +54,8 @@ func (c *Client) NewStatsClient() (*StatsClient, error) { // NewBinapiClient returns new BinapiClient which implements api.Channel. func (c *Client) NewBinapiClient() (*BinapiClient, error) { binapi := &BinapiClient{ - rpc: c.rpc, + rpc: c.rpc, + timeout: core.DefaultReplyTimeout, } return binapi, nil } @@ -103,27 +118,31 @@ func (s *StatsClient) GetBufferStats(bufStats *api.BufferStats) error { } type BinapiClient struct { - rpc *rpc.Client + rpc *rpc.Client + timeout time.Duration } func (b *BinapiClient) SendRequest(msg api.Message) api.RequestCtx { req := &requestCtx{ - rpc: b.rpc, - req: msg, + rpc: b.rpc, + timeout: b.timeout, + req: msg, } - log.Printf("SendRequest: %T %+v", msg, msg) + log.Debugf("SendRequest: %T %+v", msg, msg) return req } type requestCtx struct { - rpc *rpc.Client - req api.Message + rpc *rpc.Client + req api.Message + timeout time.Duration } func (r *requestCtx) ReceiveReply(msg api.Message) error { req := BinapiRequest{ Msg: r.req, ReplyMsg: msg, + Timeout: r.timeout, } resp := BinapiResponse{} @@ -140,16 +159,18 @@ func (r *requestCtx) ReceiveReply(msg api.Message) error { func (b *BinapiClient) SendMultiRequest(msg api.Message) api.MultiRequestCtx { req := &multiRequestCtx{ - rpc: b.rpc, - req: msg, + rpc: b.rpc, + timeout: b.timeout, + req: msg, } - log.Printf("SendMultiRequest: %T %+v", msg, msg) + log.Debugf("SendMultiRequest: %T %+v", msg, msg) return req } type multiRequestCtx struct { - rpc *rpc.Client - req api.Message + rpc *rpc.Client + req api.Message + timeout time.Duration index int replies []api.Message @@ -162,6 +183,7 @@ func (r *multiRequestCtx) ReceiveReply(msg api.Message) (stop bool, err error) { Msg: r.req, ReplyMsg: msg, IsMulti: true, + Timeout: r.timeout, } resp := BinapiResponse{} @@ -189,24 +211,23 @@ func (b *BinapiClient) SubscribeNotification(notifChan chan api.Message, event a } func (b *BinapiClient) SetReplyTimeout(timeout time.Duration) { - req := BinapiTimeoutRequest{Timeout: timeout} - resp := BinapiTimeoutResponse{} - if err := b.rpc.Call("BinapiRPC.SetTimeout", req, &resp); err != nil { - log.Println(err) - } + b.timeout = timeout } func (b *BinapiClient) CheckCompatiblity(msgs ...api.Message) error { + msgNamesCrscs := make([]string, 0, len(msgs)) + for _, msg := range msgs { - req := BinapiCompatibilityRequest{ - MsgName: msg.GetMessageName(), - Crc: msg.GetCrcString(), - } - resp := BinapiCompatibilityResponse{} - if err := b.rpc.Call("BinapiRPC.Compatibility", req, &resp); err != nil { - return err - } + msgNamesCrscs = append(msgNamesCrscs, msg.GetMessageName()+"_"+msg.GetCrcString()) } + + req := BinapiCompatibilityRequest{MsgNameCrcs: msgNamesCrscs} + resp := BinapiCompatibilityResponse{} + + if err := b.rpc.Call("BinapiRPC.Compatibility", req, &resp); err != nil { + return err + } + return nil } diff --git a/proxy/log.go b/proxy/log.go new file mode 100644 index 0000000..2810528 --- /dev/null +++ b/proxy/log.go @@ -0,0 +1,30 @@ +package proxy + +import ( + "github.com/sirupsen/logrus" + "os" +) + +var ( + debug = os.Getenv("DEBUG_GOVPP_PROXY") != "" + + log = logrus.New() +) + +func init() { + log.Out = os.Stdout + if debug { + log.Level = logrus.DebugLevel + log.Debugf("govpp/proxy: debug mode enabled") + } +} + +// SetLogger sets global logger to l. +func SetLogger(l *logrus.Logger) { + log = l +} + +// SetLogLevel sets global logger level to lvl. +func SetLogLevel(lvl logrus.Level) { + log.Level = lvl +}
\ No newline at end of file diff --git a/proxy/proxy.go b/proxy/proxy.go index 1f8f824..33cf05f 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -15,88 +15,78 @@ package proxy import ( - "log" + "fmt" + "io" "net" "net/http" "net/rpc" "git.fd.io/govpp.git/adapter" - "git.fd.io/govpp.git/core" ) // Server defines a proxy server that serves client requests to stats and binapi. type Server struct { rpc *rpc.Server - statsConn *core.StatsConnection - binapiConn *core.Connection + statsRPC *StatsRPC + binapiRPC *BinapiRPC } -func NewServer() *Server { - return &Server{ - rpc: rpc.NewServer(), +func NewServer() (*Server, error) { + srv := &Server{ + rpc: rpc.NewServer(), + statsRPC: &StatsRPC{}, + binapiRPC: &BinapiRPC{}, } + + if err := srv.rpc.Register(srv.statsRPC); err != nil { + return nil, err + } + + if err := srv.rpc.Register(srv.binapiRPC); err != nil { + return nil, err + } + + return srv, nil } func (p *Server) ConnectStats(stats adapter.StatsAPI) error { - var err error - p.statsConn, err = core.ConnectStats(stats) - if err != nil { - return err - } - return nil + return p.statsRPC.Connect(stats) } func (p *Server) DisconnectStats() { - if p.statsConn != nil { - p.statsConn.Disconnect() - } + p.statsRPC.Disconnect() } func (p *Server) ConnectBinapi(binapi adapter.VppAPI) error { - var err error - p.binapiConn, err = core.Connect(binapi) - if err != nil { - return err - } - return nil + return p.binapiRPC.Connect(binapi) } func (p *Server) DisconnectBinapi() { - if p.binapiConn != nil { - p.binapiConn.Disconnect() - } + p.binapiRPC.Disconnect() } -func (p *Server) ListenAndServe(addr string) { - if p.statsConn != nil { - statsRPC := NewStatsRPC(p.statsConn) - if err := p.rpc.Register(statsRPC); err != nil { - panic(err) - } - } - if p.binapiConn != nil { - ch, err := p.binapiConn.NewAPIChannel() - if err != nil { - panic(err) - } - binapiRPC := NewBinapiRPC(ch) - if err := p.rpc.Register(binapiRPC); err != nil { - panic(err) - } - } +func (p *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { + p.rpc.ServeHTTP(w, req) +} + +func (p *Server) ServeCodec(codec rpc.ServerCodec) { + p.rpc.ServeCodec(codec) +} +func (p *Server) ServeConn(conn io.ReadWriteCloser) { + p.rpc.ServeConn(conn) +} + +func (p *Server) ListenAndServe(addr string) error { p.rpc.HandleHTTP(rpc.DefaultRPCPath, rpc.DefaultDebugPath) l, e := net.Listen("tcp", addr) if e != nil { - log.Fatal("listen error:", e) + return fmt.Errorf("listen error:", e) } defer l.Close() log.Printf("proxy serving on: %v", addr) - - if err := http.Serve(l, nil); err != nil { - log.Fatalln(err) - } + return http.Serve(l, nil) } diff --git a/proxy/server.go b/proxy/server.go index 20f01f0..ecb0e8d 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -1,12 +1,48 @@ +// Copyright (c) 2019 Cisco and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package proxy import ( + "errors" "fmt" - "log" "reflect" + "sync" + "sync/atomic" "time" + "git.fd.io/govpp.git/adapter" "git.fd.io/govpp.git/api" + "git.fd.io/govpp.git/core" +) + +const ( + binapiErrorMsg = ` +------------------------------------------------------------ + received binapi request while VPP connection is down! + - is VPP running ? + - have you called Connect on the binapi RPC ? +------------------------------------------------------------ +` + statsErrorMsg = ` +------------------------------------------------------------ + received stats request while stats connection is down! + - is VPP running ? + - is the correct socket name configured ? + - have you called Connect on the stats RPC ? +------------------------------------------------------------ +` ) type StatsRequest struct { @@ -23,34 +59,147 @@ type StatsResponse struct { // StatsRPC is a RPC server for proxying client request to api.StatsProvider. type StatsRPC struct { - stats api.StatsProvider + statsConn *core.StatsConnection + stats adapter.StatsAPI + + done chan struct{} + // non-zero if the RPC service is available + available uint32 + // non-zero if connected to stats file. + isConnected uint32 + // synchronizes access to statsConn. + mu sync.Mutex } // NewStatsRPC returns new StatsRPC to be used as RPC server // proxying request to given api.StatsProvider. -func NewStatsRPC(stats api.StatsProvider) *StatsRPC { - return &StatsRPC{stats: stats} +func NewStatsRPC(stats adapter.StatsAPI) (*StatsRPC, error) { + rpc := new(StatsRPC) + if err := rpc.Connect(stats); err != nil { + return nil, err + } + return rpc, nil +} + +func (s *StatsRPC) watchConnection() { + heartbeatTicker := time.NewTicker(10 * time.Second).C + atomic.StoreUint32(&s.available, 1) + log.Println("enabling statsRPC service") + + count := 0 + prev := new(api.SystemStats) + + s.mu.Lock() + if err := s.statsConn.GetSystemStats(prev); err != nil { + atomic.StoreUint32(&s.available, 0) + log.Warnf("disabling statsRPC service, reason:", err) + } + s.mu.Unlock() + + for { + select { + case <-heartbeatTicker: + // If disconnect was called exit. + if atomic.LoadUint32(&s.isConnected) == 0 { + atomic.StoreUint32(&s.available, 0) + return + } + + curr := new(api.SystemStats) + + s.mu.Lock() + if err := s.statsConn.GetSystemStats(curr); err != nil { + atomic.StoreUint32(&s.available, 0) + log.Warnf("disabling statsRPC service, reason:", err) + } + s.mu.Unlock() + + if curr.Heartbeat <= prev.Heartbeat { + count++ + // vpp might have crashed/reset... try reconnecting + if count == 5 { + count = 0 + atomic.StoreUint32(&s.available, 0) + log.Warnln("disabling statsRPC service, reason: vpp might have crashed/reset...") + s.statsConn.Disconnect() + for { + var err error + s.statsConn, err = core.ConnectStats(s.stats) + if err == nil { + atomic.StoreUint32(&s.available, 1) + log.Println("enabling statsRPC service") + break + } + time.Sleep(5 * time.Second) + } + } + } else { + count = 0 + } + + prev = curr + case <-s.done: + return + } + } +} + +func (s *StatsRPC) Connect(stats adapter.StatsAPI) error { + if atomic.LoadUint32(&s.isConnected) == 1 { + return errors.New("connection already exists") + } + s.stats = stats + var err error + s.statsConn, err = core.ConnectStats(s.stats) + if err != nil { + return err + } + s.done = make(chan struct{}) + atomic.StoreUint32(&s.isConnected, 1) + + go s.watchConnection() + return nil +} + +func (s *StatsRPC) Disconnect() { + if atomic.LoadUint32(&s.isConnected) == 1 { + atomic.StoreUint32(&s.isConnected, 0) + close(s.done) + s.statsConn.Disconnect() + s.statsConn = nil + } +} + +func (s *StatsRPC) serviceAvailable() bool { + return atomic.LoadUint32(&s.available) == 1 } func (s *StatsRPC) GetStats(req StatsRequest, resp *StatsResponse) error { - log.Printf("StatsRPC.GetStats - REQ: %+v", req) + if !s.serviceAvailable() { + log.Println(statsErrorMsg) + return errors.New("server does not support 'get stats' at this time, try again later") + } + log.Debugf("StatsRPC.GetStats - REQ: %+v", req) + + s.mu.Lock() + defer s.mu.Unlock() switch req.StatsType { case "system": resp.SysStats = new(api.SystemStats) - return s.stats.GetSystemStats(resp.SysStats) + return s.statsConn.GetSystemStats(resp.SysStats) case "node": resp.NodeStats = new(api.NodeStats) - return s.stats.GetNodeStats(resp.NodeStats) + return s.statsConn.GetNodeStats(resp.NodeStats) case "interface": resp.IfaceStats = new(api.InterfaceStats) - return s.stats.GetInterfaceStats(resp.IfaceStats) + return s.statsConn.GetInterfaceStats(resp.IfaceStats) case "error": resp.ErrStats = new(api.ErrorStats) - return s.stats.GetErrorStats(resp.ErrStats) + return s.statsConn.GetErrorStats(resp.ErrStats) case "buffer": resp.BufStats = new(api.BufferStats) - return s.stats.GetBufferStats(resp.BufStats) + return s.statsConn.GetBufferStats(resp.BufStats) default: return fmt.Errorf("unknown stats type: %s", req.StatsType) } @@ -60,6 +209,7 @@ type BinapiRequest struct { Msg api.Message IsMulti bool ReplyMsg api.Message + Timeout time.Duration } type BinapiResponse struct { @@ -68,36 +218,124 @@ type BinapiResponse struct { } type BinapiCompatibilityRequest struct { - MsgName string - Crc string + MsgNameCrcs []string } type BinapiCompatibilityResponse struct { -} - -type BinapiTimeoutRequest struct { - Timeout time.Duration -} - -type BinapiTimeoutResponse struct { + CompatibleMsgs []string + IncompatibleMsgs []string } // BinapiRPC is a RPC server for proxying client request to api.Channel. type BinapiRPC struct { - binapi api.Channel + binapiConn *core.Connection + binapi adapter.VppAPI + + events chan core.ConnectionEvent + done chan struct{} + // non-zero if the RPC service is available + available uint32 + // non-zero if connected to vpp. + isConnected uint32 } // NewBinapiRPC returns new BinapiRPC to be used as RPC server // proxying request to given api.Channel. -func NewBinapiRPC(binapi api.Channel) *BinapiRPC { - return &BinapiRPC{binapi: binapi} +func NewBinapiRPC(binapi adapter.VppAPI) (*BinapiRPC, error) { + rpc := new(BinapiRPC) + if err := rpc.Connect(binapi); err != nil { + return nil, err + } + return rpc, nil +} + +func (s *BinapiRPC) watchConnection() { + for { + select { + case e := <-s.events: + // If disconnect was called exit. + if atomic.LoadUint32(&s.isConnected) == 0 { + atomic.StoreUint32(&s.available, 0) + return + } + + switch e.State { + case core.Connected: + if !s.serviceAvailable() { + atomic.StoreUint32(&s.available, 1) + log.Println("enabling binapiRPC service") + } + case core.Disconnected: + if s.serviceAvailable() { + atomic.StoreUint32(&s.available, 0) + log.Warnf("disabling binapiRPC, reason: %v\n", e.Error) + } + case core.Failed: + if s.serviceAvailable() { + atomic.StoreUint32(&s.available, 0) + log.Warnf("disabling binapiRPC, reason: %v\n", e.Error) + } + // vpp might have crashed/reset... reconnect + s.binapiConn.Disconnect() + + var err error + s.binapiConn, s.events, err = core.AsyncConnect(s.binapi, 3, 5*time.Second) + if err != nil { + log.Println(err) + } + } + case <-s.done: + return + } + } +} + +func (s *BinapiRPC) Connect(binapi adapter.VppAPI) error { + if atomic.LoadUint32(&s.isConnected) == 1 { + return errors.New("connection already exists") + } + s.binapi = binapi + var err error + s.binapiConn, s.events, err = core.AsyncConnect(binapi, 3, time.Second) + if err != nil { + return err + } + s.done = make(chan struct{}) + atomic.StoreUint32(&s.isConnected, 1) + + go s.watchConnection() + return nil +} + +func (s *BinapiRPC) Disconnect() { + if atomic.LoadUint32(&s.isConnected) == 1 { + atomic.StoreUint32(&s.isConnected, 0) + close(s.done) + s.binapiConn.Disconnect() + s.binapiConn = nil + } +} + +func (s *BinapiRPC) serviceAvailable() bool { + return atomic.LoadUint32(&s.available) == 1 } func (s *BinapiRPC) Invoke(req BinapiRequest, resp *BinapiResponse) error { - log.Printf("BinapiRPC.Invoke - REQ: %#v", req) + if !s.serviceAvailable() { + log.Println(binapiErrorMsg) + return errors.New("server does not support 'invoke' at this time, try again later") + } + log.Debugf("BinapiRPC.Invoke - REQ: %#v", req) + + ch, err := s.binapiConn.NewAPIChannel() + if err != nil { + return err + } + defer ch.Close() + ch.SetReplyTimeout(req.Timeout) if req.IsMulti { - multi := s.binapi.SendMultiRequest(req.Msg) + multi := ch.SendMultiRequest(req.Msg) for { // create new message in response of type ReplyMsg msg := reflect.New(reflect.TypeOf(req.ReplyMsg).Elem()).Interface().(api.Message) @@ -115,7 +353,7 @@ func (s *BinapiRPC) Invoke(req BinapiRequest, resp *BinapiResponse) error { // create new message in response of type ReplyMsg resp.Msg = reflect.New(reflect.TypeOf(req.ReplyMsg).Elem()).Interface().(api.Message) - err := s.binapi.SendRequest(req.Msg).ReceiveReply(resp.Msg) + err := ch.SendRequest(req.Msg).ReceiveReply(resp.Msg) if err != nil { return err } @@ -124,16 +362,39 @@ func (s *BinapiRPC) Invoke(req BinapiRequest, resp *BinapiResponse) error { return nil } -func (s *BinapiRPC) SetTimeout(req BinapiTimeoutRequest, _ *BinapiTimeoutResponse) error { - log.Printf("BinapiRPC.SetTimeout - REQ: %#v", req) - s.binapi.SetReplyTimeout(req.Timeout) - return nil -} +func (s *BinapiRPC) Compatibility(req BinapiCompatibilityRequest, resp *BinapiCompatibilityResponse) error { + if !s.serviceAvailable() { + log.Println(binapiErrorMsg) + return errors.New("server does not support 'compatibility check' at this time, try again later") + } + log.Debugf("BinapiRPC.Compatiblity - REQ: %#v", req) + + ch, err := s.binapiConn.NewAPIChannel() + if err != nil { + return err + } + defer ch.Close() + + resp.CompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs)) + resp.IncompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs)) -func (s *BinapiRPC) Compatibility(req BinapiCompatibilityRequest, _ *BinapiCompatibilityResponse) error { - log.Printf("BinapiRPC.Compatiblity - REQ: %#v", req) - if val, ok := api.GetRegisteredMessages()[req.MsgName+"_"+req.Crc]; ok { - return s.binapi.CheckCompatiblity(val) + for _, msg := range req.MsgNameCrcs { + val, ok := api.GetRegisteredMessages()[msg] + if !ok { + resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg) + continue + } + + if err = ch.CheckCompatiblity(val); err != nil { + resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg) + } else { + resp.CompatibleMsgs = append(resp.CompatibleMsgs, msg) + } } - return fmt.Errorf("compatibility check failed for the message: %s", req.MsgName+"_"+req.Crc) + + if len(resp.IncompatibleMsgs) > 0 { + return fmt.Errorf("compatibility check failed for messages: %v", resp.IncompatibleMsgs) + } + + return nil } |