diff options
-rw-r--r-- | Gopkg.lock | 9 | ||||
-rw-r--r-- | adapter/socketclient/socketclient.go | 14 | ||||
-rw-r--r-- | adapter/stats_api.go | 11 | ||||
-rw-r--r-- | adapter/statsclient/statsclient.go | 460 | ||||
-rw-r--r-- | adapter/vppapiclient/stat_client.go | 21 | ||||
-rw-r--r-- | examples/stats-api/stats_api.go | 11 | ||||
-rw-r--r-- | vendor/github.com/ftrvxmtrx/fd/LICENSE.MIT | 18 | ||||
-rw-r--r-- | vendor/github.com/ftrvxmtrx/fd/README.md | 25 | ||||
-rw-r--r-- | vendor/github.com/ftrvxmtrx/fd/fd.go | 104 |
9 files changed, 655 insertions, 18 deletions
@@ -16,6 +16,14 @@ revision = "4da3e2cfbabc9f751898f250b49f2439785783a1" [[projects]] + branch = "master" + digest = "1:ea797b536b154f62d2e1c49b61d5a1088782111563eb59837ff2b83fd2a65184" + name = "github.com/ftrvxmtrx/fd" + packages = ["."] + pruneopts = "UT" + revision = "c6d800382fff6dc1412f34269f71b7f83bd059ad" + +[[projects]] digest = "1:81259d6c2b9aa336c627a31074078d5473788c1f54a373e4392d4e722716d74d" name = "github.com/google/gopacket" packages = [ @@ -90,6 +98,7 @@ input-imports = [ "github.com/bennyscetbun/jsongo", "github.com/fsnotify/fsnotify", + "github.com/ftrvxmtrx/fd", "github.com/google/gopacket", "github.com/google/gopacket/layers", "github.com/lunixbochs/struc", diff --git a/adapter/socketclient/socketclient.go b/adapter/socketclient/socketclient.go index 19fff7a..e56f89c 100644 --- a/adapter/socketclient/socketclient.go +++ b/adapter/socketclient/socketclient.go @@ -1,3 +1,17 @@ +// Copyright (c) 2019 Cisco and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package socketclient import ( diff --git a/adapter/stats_api.go b/adapter/stats_api.go index 146914d..798fcbd 100644 --- a/adapter/stats_api.go +++ b/adapter/stats_api.go @@ -15,9 +15,20 @@ package adapter import ( + "errors" "fmt" ) +var ( + ErrStatDirBusy = errors.New("stat dir busy") + ErrStatDumpBusy = errors.New("stat dump busy") +) + +var ( + // DefaultStatsSocket is the default path for the VPP stat socket file. + DefaultStatsSocket = "/run/vpp/stats.sock" +) + // StatsAPI provides connection to VPP stats API. type StatsAPI interface { // Connect establishes client connection to the stats API. diff --git a/adapter/statsclient/statsclient.go b/adapter/statsclient/statsclient.go new file mode 100644 index 0000000..07fcc49 --- /dev/null +++ b/adapter/statsclient/statsclient.go @@ -0,0 +1,460 @@ +// Copyright (c) 2019 Cisco and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package statsclient + +import ( + "bytes" + "fmt" + "net" + "os" + "regexp" + "sync/atomic" + "syscall" + "time" + "unsafe" + + "github.com/ftrvxmtrx/fd" + logger "github.com/sirupsen/logrus" + + "git.fd.io/govpp.git/adapter" +) + +var ( + // Debug is global variable that determines debug mode + Debug = os.Getenv("DEBUG_GOVPP_STATS") != "" + + // Log is global logger + Log = logger.New() +) + +// init initializes global logger, which logs debug level messages to stdout. +func init() { + Log.Out = os.Stdout + if Debug { + Log.Level = logger.DebugLevel + Log.Debug("enabled debug mode") + } +} + +// StatsClient is the pure Go implementation for VPP stats API. +type StatsClient struct { + sockAddr string + + currentEpoch int64 + sharedHeader []byte + directoryVector uintptr + memorySize int +} + +// NewStatsClient returns new VPP stats API client. +func NewStatsClient(socketName string) *StatsClient { + return &StatsClient{ + sockAddr: socketName, + } +} + +func (c *StatsClient) Connect() error { + var sockName string + if c.sockAddr == "" { + sockName = adapter.DefaultStatsSocket + } else { + sockName = c.sockAddr + } + + if _, err := os.Stat(sockName); err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("stats socket file %q does not exists, ensure that VPP is running with `statseg { ... }` section in config", sockName) + } + return fmt.Errorf("stats socket file error: %v", err) + } + + if err := c.statSegmentConnect(sockName); err != nil { + return err + } + + return nil +} + +const statshmFilename = "statshm" + +func (c *StatsClient) statSegmentConnect(sockName string) error { + addr := &net.UnixAddr{ + Net: "unixpacket", + Name: sockName, + } + + Log.Debugf("connecting to: %v", addr) + + conn, err := net.DialUnix(addr.Net, nil, addr) + if err != nil { + Log.Warnf("connecting to socket %s failed: %s", addr, err) + return err + } + defer func() { + if err := conn.Close(); err != nil { + Log.Warnf("closing socket failed: %v", err) + } + }() + + Log.Debugf("connected to socket: %v", addr) + + files, err := fd.Get(conn, 1, []string{statshmFilename}) + if err != nil { + return fmt.Errorf("getting file descriptor over socket failed: %v", err) + } else if len(files) == 0 { + return fmt.Errorf("no files received over socket") + } + defer func() { + for _, f := range files { + if err := f.Close(); err != nil { + Log.Warnf("closing file %s failed: %v", f.Name(), err) + } + } + }() + + Log.Debugf("received %d files over socket", len(files)) + + f := files[0] + + info, err := f.Stat() + if err != nil { + return err + } + + size := int(info.Size()) + + Log.Debugf("fd: name=%v size=%v", info.Name(), size) + + data, err := syscall.Mmap(int(f.Fd()), 0, size, syscall.PROT_READ, syscall.MAP_SHARED) + if err != nil { + Log.Warnf("mapping shared memory failed: %v", err) + return fmt.Errorf("mapping shared memory failed: %v", err) + } + + Log.Debugf("successfuly mapped shared memory") + + c.sharedHeader = data + c.memorySize = size + + return nil +} + +func (c *StatsClient) Disconnect() error { + err := syscall.Munmap(c.sharedHeader) + if err != nil { + Log.Warnf("unmapping shared memory failed: %v", err) + return fmt.Errorf("unmapping shared memory failed: %v", err) + } + + Log.Debugf("successfuly unmapped shared memory") + + return nil +} + +func nameMatches(name string, patterns []string) bool { + if len(patterns) == 0 { + return true + } + for _, pattern := range patterns { + matched, err := regexp.MatchString(pattern, name) + if err == nil && matched { + return true + } + } + return false +} + +func (c *StatsClient) ListStats(patterns ...string) (statNames []string, err error) { + sa := c.accessStart() + if sa == nil { + return nil, fmt.Errorf("access failed") + } + + dirOffset, _, _ := c.readOffsets() + Log.Debugf("dirOffset: %v", dirOffset) + + vecLen := vectorLen(unsafe.Pointer(&c.sharedHeader[dirOffset])) + Log.Debugf("vecLen: %v", vecLen) + Log.Debugf("unsafe.Sizeof(statSegDirectoryEntry{}): %v", unsafe.Sizeof(statSegDirectoryEntry{})) + + for i := uint64(0); i < vecLen; i++ { + offset := uintptr(i) * unsafe.Sizeof(statSegDirectoryEntry{}) + dirEntry := (*statSegDirectoryEntry)(add(unsafe.Pointer(&c.sharedHeader[dirOffset]), offset)) + + nul := bytes.IndexByte(dirEntry.name[:], '\x00') + if nul < 0 { + Log.Warnf("no zero byte found for: %q", dirEntry.name[:]) + continue + } + name := string(dirEntry.name[:nul]) + + Log.Debugf(" %80q (type: %v, data: %d, offset: %d) ", name, dirEntry.directoryType, dirEntry.unionData, dirEntry.offsetVector) + + if nameMatches(name, patterns) { + statNames = append(statNames, name) + } + + // TODO: copy the listed entries elsewhere + } + + if !c.accessEnd(sa) { + return nil, adapter.ErrStatDirBusy + } + + c.currentEpoch = sa.epoch + + return statNames, nil +} + +func (c *StatsClient) DumpStats(patterns ...string) (entries []*adapter.StatEntry, err error) { + epoch, _ := c.readEpoch() + if c.currentEpoch > 0 && c.currentEpoch != epoch { // TODO: do list stats before dump + return nil, fmt.Errorf("old data") + } + + sa := c.accessStart() + if sa == nil { + return nil, fmt.Errorf("access failed") + } + + dirOffset, _, _ := c.readOffsets() + vecLen := vectorLen(unsafe.Pointer(&c.sharedHeader[dirOffset])) + + for i := uint64(0); i < vecLen; i++ { + offset := uintptr(i) * unsafe.Sizeof(statSegDirectoryEntry{}) + dirEntry := (*statSegDirectoryEntry)(add(unsafe.Pointer(&c.sharedHeader[dirOffset]), offset)) + + entry := c.copyData(dirEntry) + if nameMatches(entry.Name, patterns) { + entries = append(entries, &entry) + } + } + + if !c.accessEnd(sa) { + return nil, adapter.ErrStatDumpBusy + } + + return entries, nil +} + +func (c *StatsClient) copyData(dirEntry *statSegDirectoryEntry) (statEntry adapter.StatEntry) { + name := dirEntry.name[:] + if nul := bytes.IndexByte(name, '\x00'); nul < 0 { + Log.Warnf("no zero byte found for: %q", dirEntry.name[:]) + } else { + name = dirEntry.name[:nul] + } + + statEntry.Name = string(name) + statEntry.Type = adapter.StatType(dirEntry.directoryType) + + Log.Debugf(" - %s (type: %v, data: %v, offset: %v) ", statEntry.Name, statEntry.Type, dirEntry.unionData, dirEntry.offsetVector) + + switch statEntry.Type { + case adapter.ScalarIndex: + statEntry.Data = adapter.ScalarStat(dirEntry.unionData) + + case adapter.ErrorIndex: + _, errOffset, _ := c.readOffsets() + offsetVector := unsafe.Pointer(&c.sharedHeader[errOffset]) + vecLen := vectorLen(offsetVector) + + var errData adapter.Counter + for i := uint64(0); i < vecLen; i++ { + cb := *(*uint64)(add(offsetVector, uintptr(i)*unsafe.Sizeof(uint64(0)))) + offset := uintptr(cb) + uintptr(dirEntry.unionData)*unsafe.Sizeof(adapter.Counter(0)) + val := *(*adapter.Counter)(add(unsafe.Pointer(&c.sharedHeader[0]), offset)) + errData += val + } + statEntry.Data = adapter.ErrorStat(errData) + + case adapter.SimpleCounterVector: + if dirEntry.unionData == 0 { + Log.Debugf("\toffset is not valid") + break + } else if dirEntry.unionData >= uint64(len(c.sharedHeader)) { + Log.Debugf("\toffset out of range") + break + } + + simpleCounter := unsafe.Pointer(&c.sharedHeader[dirEntry.unionData]) // offset + vecLen := vectorLen(simpleCounter) + offsetVector := add(unsafe.Pointer(&c.sharedHeader[0]), uintptr(dirEntry.offsetVector)) + + data := make([][]adapter.Counter, vecLen) + for i := uint64(0); i < vecLen; i++ { + cb := *(*uint64)(add(offsetVector, uintptr(i)*unsafe.Sizeof(uint64(0)))) + counterVec := unsafe.Pointer(&c.sharedHeader[uintptr(cb)]) + vecLen2 := vectorLen(counterVec) + for j := uint64(0); j < vecLen2; j++ { + offset := uintptr(j) * unsafe.Sizeof(adapter.Counter(0)) + val := *(*adapter.Counter)(add(counterVec, offset)) + data[i] = append(data[i], val) + } + } + statEntry.Data = adapter.SimpleCounterStat(data) + + case adapter.CombinedCounterVector: + if dirEntry.unionData == 0 { + Log.Debugf("\toffset is not valid") + break + } else if dirEntry.unionData >= uint64(len(c.sharedHeader)) { + Log.Debugf("\toffset out of range") + break + } + + combinedCounter := unsafe.Pointer(&c.sharedHeader[dirEntry.unionData]) // offset + vecLen := vectorLen(combinedCounter) + offsetVector := add(unsafe.Pointer(&c.sharedHeader[0]), uintptr(dirEntry.offsetVector)) + + data := make([][]adapter.CombinedCounter, vecLen) + for i := uint64(0); i < vecLen; i++ { + cb := *(*uint64)(add(offsetVector, uintptr(i)*unsafe.Sizeof(uint64(0)))) + counterVec := unsafe.Pointer(&c.sharedHeader[uintptr(cb)]) + vecLen2 := vectorLen(counterVec) + for j := uint64(0); j < vecLen2; j++ { + offset := uintptr(j) * unsafe.Sizeof(adapter.CombinedCounter{}) + val := *(*adapter.CombinedCounter)(add(counterVec, offset)) + data[i] = append(data[i], val) + } + } + statEntry.Data = adapter.CombinedCounterStat(data) + + case adapter.NameVector: + if dirEntry.unionData == 0 { + Log.Debugf("\toffset is not valid") + break + } else if dirEntry.unionData >= uint64(len(c.sharedHeader)) { + Log.Debugf("\toffset out of range") + break + } + + nameVector := unsafe.Pointer(&c.sharedHeader[dirEntry.unionData]) // offset + vecLen := vectorLen(nameVector) + offsetVector := add(unsafe.Pointer(&c.sharedHeader[0]), uintptr(dirEntry.offsetVector)) + + data := make([]adapter.Name, vecLen) + for i := uint64(0); i < vecLen; i++ { + cb := *(*uint64)(add(offsetVector, uintptr(i)*unsafe.Sizeof(uint64(0)))) + nameVec := unsafe.Pointer(&c.sharedHeader[uintptr(cb)]) + vecLen2 := vectorLen(nameVec) + + var nameStr []byte + for j := uint64(0); j < vecLen2; j++ { + offset := uintptr(j) * unsafe.Sizeof(byte(0)) + val := *(*byte)(add(nameVec, offset)) + if val > 0 { + nameStr = append(nameStr, val) + } + } + data[i] = adapter.Name(nameStr) + } + statEntry.Data = adapter.NameStat(data) + + default: + Log.Warnf("Unknown type %d for stat entry: %s", statEntry.Type, statEntry.Name) + } + + Log.Debugf("\tentry data: %#v", statEntry.Data) + + return statEntry +} + +type statDirectoryType int32 + +func (t statDirectoryType) String() string { + return adapter.StatType(t).String() +} + +type statSegDirectoryEntry struct { + directoryType statDirectoryType + // unionData can represent: offset, index or value + unionData uint64 + offsetVector uint64 + name [128]byte +} + +type statSegSharedHeader struct { + version uint64 + epoch int64 + inProgress int64 + directoryOffset int64 + errorOffset int64 + statsOffset int64 +} + +func (c *StatsClient) readVersion() uint64 { + header := *(*statSegSharedHeader)(unsafe.Pointer(&c.sharedHeader[0])) + version := atomic.LoadUint64(&header.version) + return version +} + +func (c *StatsClient) readEpoch() (int64, bool) { + header := *(*statSegSharedHeader)(unsafe.Pointer(&c.sharedHeader[0])) + epoch := atomic.LoadInt64(&header.epoch) + inprog := atomic.LoadInt64(&header.inProgress) + return epoch, inprog != 0 +} + +func (c *StatsClient) readOffsets() (dir, err, stat int64) { + header := *(*statSegSharedHeader)(unsafe.Pointer(&c.sharedHeader[0])) + dirOffset := atomic.LoadInt64(&header.directoryOffset) + errOffset := atomic.LoadInt64(&header.errorOffset) + statOffset := atomic.LoadInt64(&header.statsOffset) + return dirOffset, errOffset, statOffset +} + +type statSegAccess struct { + epoch int64 +} + +var maxWaitInProgress = 1 * time.Second + +func (c *StatsClient) accessStart() *statSegAccess { + epoch, inprog := c.readEpoch() + t := time.Now() + for inprog { + if time.Since(t) > maxWaitInProgress { + return nil + } + epoch, inprog = c.readEpoch() + } + return &statSegAccess{ + epoch: epoch, + } +} + +func (c *StatsClient) accessEnd(acc *statSegAccess) bool { + epoch, inprog := c.readEpoch() + if acc.epoch != epoch || inprog { + return false + } + return true +} + +type vecHeader struct { + length uint64 + vectorData [0]uint8 +} + +func vectorLen(v unsafe.Pointer) uint64 { + vec := *(*vecHeader)(unsafe.Pointer(uintptr(v) - unsafe.Sizeof(uintptr(0)))) + return vec.length +} + +//go:nosplit +func add(p unsafe.Pointer, x uintptr) unsafe.Pointer { + return unsafe.Pointer(uintptr(p) + x) +} diff --git a/adapter/vppapiclient/stat_client.go b/adapter/vppapiclient/stat_client.go index 389c93b..148f618 100644 --- a/adapter/vppapiclient/stat_client.go +++ b/adapter/vppapiclient/stat_client.go @@ -25,7 +25,6 @@ package vppapiclient import "C" import ( - "errors" "fmt" "os" "unsafe" @@ -33,21 +32,11 @@ import ( "git.fd.io/govpp.git/adapter" ) -var ( - ErrStatDirBusy = errors.New("stat dir busy") - ErrStatDumpBusy = errors.New("stat dump busy") -) - -var ( - // DefaultStatSocket is the default path for the VPP stat socket file. - DefaultStatSocket = "/run/vpp/stats.sock" -) - // global VPP stats API client, library vppapiclient only supports // single connection at a time var globalStatClient *statClient -// stubStatClient is the default implementation of StatsAPI. +// statClient is the default implementation of StatsAPI. type statClient struct { socketName string } @@ -66,7 +55,7 @@ func (c *statClient) Connect() error { var sockName string if c.socketName == "" { - sockName = DefaultStatSocket + sockName = adapter.DefaultStatsSocket } else { sockName = c.socketName } @@ -97,7 +86,7 @@ func (c *statClient) Disconnect() error { func (c *statClient) ListStats(patterns ...string) (stats []string, err error) { dir := C.govpp_stat_segment_ls(convertStringSlice(patterns)) if dir == nil { - return nil, ErrStatDirBusy + return nil, adapter.ErrStatDirBusy } defer C.govpp_stat_segment_vec_free(unsafe.Pointer(dir)) @@ -114,13 +103,13 @@ func (c *statClient) ListStats(patterns ...string) (stats []string, err error) { func (c *statClient) DumpStats(patterns ...string) (stats []*adapter.StatEntry, err error) { dir := C.govpp_stat_segment_ls(convertStringSlice(patterns)) if dir == nil { - return nil, ErrStatDirBusy + return nil, adapter.ErrStatDirBusy } defer C.govpp_stat_segment_vec_free(unsafe.Pointer(dir)) dump := C.govpp_stat_segment_dump(dir) if dump == nil { - return nil, ErrStatDumpBusy + return nil, adapter.ErrStatDumpBusy } defer C.govpp_stat_segment_data_free(dump) diff --git a/examples/stats-api/stats_api.go b/examples/stats-api/stats_api.go index f74a055..e20ce7a 100644 --- a/examples/stats-api/stats_api.go +++ b/examples/stats-api/stats_api.go @@ -22,6 +22,7 @@ import ( "strings" "git.fd.io/govpp.git/adapter" + "git.fd.io/govpp.git/adapter/statsclient" "git.fd.io/govpp.git/adapter/vppapiclient" "git.fd.io/govpp.git/core" ) @@ -34,7 +35,8 @@ import ( // ------------------------------------------------------------------ var ( - statsSocket = flag.String("socket", vppapiclient.DefaultStatSocket, "VPP stats segment socket") + statsSocket = flag.String("socket", adapter.DefaultStatsSocket, "Path to VPP stats socket") + goclient = flag.Bool("goclient", true, "Use pure Go client for stats API") dumpAll = flag.Bool("all", false, "Dump all stats including ones with zero values") ) @@ -62,7 +64,12 @@ func main() { patterns = flag.Args()[1:] } - client := vppapiclient.NewStatClient(*statsSocket) + var client adapter.StatsAPI + if *goclient { + client = statsclient.NewStatsClient(*statsSocket) + } else { + client = vppapiclient.NewStatClient(*statsSocket) + } fmt.Printf("Connecting to stats socket: %s\n", *statsSocket) diff --git a/vendor/github.com/ftrvxmtrx/fd/LICENSE.MIT b/vendor/github.com/ftrvxmtrx/fd/LICENSE.MIT new file mode 100644 index 0000000..136e69e --- /dev/null +++ b/vendor/github.com/ftrvxmtrx/fd/LICENSE.MIT @@ -0,0 +1,18 @@ +Copyright © 2012 Serge Zirukin + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/ftrvxmtrx/fd/README.md b/vendor/github.com/ftrvxmtrx/fd/README.md new file mode 100644 index 0000000..7a8a239 --- /dev/null +++ b/vendor/github.com/ftrvxmtrx/fd/README.md @@ -0,0 +1,25 @@ +# fd + +Package fd provides a simple API to pass file descriptors +between different OS processes. + +It can be useful if you want to inherit network connections +from another process without closing them. + +Example scenario: + + * Running server receives a "let's upgrade" message + * Server opens a Unix domain socket for the "upgrade" + * Server starts a new copy of itself and passes Unix domain + socket name + * New copy starts reading data from the socket + * Server sends its state over the socket, also sending the number + of network connections to inherit, then it sends those connections + using fd.Put() + * New server copy reads the state and inherits connections using fd.Get(), + checks that everything is OK and writes an "OK" message to the socket + * Server receives "OK" message and kills itself + +## Documentation + +[fd on godoc.org](http://godoc.org/github.com/ftrvxmtrx/fd) diff --git a/vendor/github.com/ftrvxmtrx/fd/fd.go b/vendor/github.com/ftrvxmtrx/fd/fd.go new file mode 100644 index 0000000..a5a4d48 --- /dev/null +++ b/vendor/github.com/ftrvxmtrx/fd/fd.go @@ -0,0 +1,104 @@ +// Package fd provides a simple API to pass file descriptors +// between different OS processes. +// +// It can be useful if you want to inherit network connections +// from another process without closing them. +// +// Example scenario: +// +// 1) Running server receives a "let's upgrade" message +// 2) Server opens a Unix domain socket for the "upgrade" +// 3) Server starts a new copy of itself and passes Unix domain socket name +// 4) New copy starts reading for the socket +// 5) Server sends its state over the socket, also sending the number +// of network connections to inherit, then it sends those connections +// using fd.Put() +// 6) New copy reads the state and inherits connections using fd.Get(), +// checks that everything is OK and sends the "OK" message to the socket +// 7) Server receives "OK" message and kills itself +package fd + +import ( + "net" + "os" + "syscall" +) + +// Get receives file descriptors from a Unix domain socket. +// +// Num specifies the expected number of file descriptors in one message. +// Internal files' names to be assigned are specified via optional filenames +// argument. +// +// You need to close all files in the returned slice. The slice can be +// non-empty even if this function returns an error. +// +// Use net.FileConn() if you're receiving a network connection. +func Get(via *net.UnixConn, num int, filenames []string) ([]*os.File, error) { + if num < 1 { + return nil, nil + } + + // get the underlying socket + viaf, err := via.File() + if err != nil { + return nil, err + } + socket := int(viaf.Fd()) + defer viaf.Close() + + // recvmsg + buf := make([]byte, syscall.CmsgSpace(num*4)) + _, _, _, _, err = syscall.Recvmsg(socket, nil, buf, 0) + if err != nil { + return nil, err + } + + // parse control msgs + var msgs []syscall.SocketControlMessage + msgs, err = syscall.ParseSocketControlMessage(buf) + + // convert fds to files + res := make([]*os.File, 0, len(msgs)) + for i := 0; i < len(msgs) && err == nil; i++ { + var fds []int + fds, err = syscall.ParseUnixRights(&msgs[i]) + + for fi, fd := range fds { + var filename string + if fi < len(filenames) { + filename = filenames[fi] + } + + res = append(res, os.NewFile(uintptr(fd), filename)) + } + } + + return res, err +} + +// Put sends file descriptors to Unix domain socket. +// +// Please note that the number of descriptors in one message is limited +// and is rather small. +// Use conn.File() to get a file if you want to put a network connection. +func Put(via *net.UnixConn, files ...*os.File) error { + if len(files) == 0 { + return nil + } + + viaf, err := via.File() + if err != nil { + return err + } + socket := int(viaf.Fd()) + defer viaf.Close() + + fds := make([]int, len(files)) + for i := range files { + fds[i] = int(files[i].Fd()) + } + + rights := syscall.UnixRights(fds...) + return syscall.Sendmsg(socket, nil, rights, nil, 0) +} |