diff options
Diffstat (limited to 'src/vppinfra')
-rw-r--r-- | src/vppinfra/socket.c | 130 | ||||
-rw-r--r-- | src/vppinfra/socket.h | 49 | ||||
-rw-r--r-- | src/vppinfra/test_socket.c | 6 |
3 files changed, 161 insertions, 24 deletions
diff --git a/src/vppinfra/socket.c b/src/vppinfra/socket.c index 37dcbbfd8c6..87a9333f904 100644 --- a/src/vppinfra/socket.c +++ b/src/vppinfra/socket.c @@ -35,17 +35,18 @@ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ -#include <sys/un.h> +#include <stdio.h> +#include <string.h> /* strchr */ +#define __USE_GNU #include <sys/types.h> #include <sys/socket.h> +#include <sys/un.h> #include <sys/stat.h> #include <netinet/in.h> #include <arpa/inet.h> #include <netdb.h> #include <unistd.h> -#include <stdio.h> #include <fcntl.h> -#include <string.h> /* strchr */ #include <vppinfra/mem.h> #include <vppinfra/vec.h> @@ -233,7 +234,7 @@ default_socket_read (clib_socket_t * sock, int n_bytes) u8 *buf; /* RX side of socket is down once end of file is reached. */ - if (sock->flags & SOCKET_RX_END_OF_FILE) + if (sock->flags & CLIB_SOCKET_F_RX_END_OF_FILE) return 0; fd = sock->fd; @@ -255,7 +256,7 @@ default_socket_read (clib_socket_t * sock, int n_bytes) /* Other side closed the socket. */ if (n_read == 0) - sock->flags |= SOCKET_RX_END_OF_FILE; + sock->flags |= CLIB_SOCKET_F_RX_END_OF_FILE; non_fatal: _vec_len (sock->rx_buffer) += n_read - n_bytes; @@ -271,6 +272,91 @@ default_socket_close (clib_socket_t * s) return 0; } +static clib_error_t * +default_socket_sendmsg (clib_socket_t * s, void *msg, int msglen, + int fds[], int num_fds) +{ + struct msghdr mh = { 0 }; + struct iovec iov[1]; + char ctl[CMSG_SPACE (sizeof (int)) * num_fds]; + int rv; + + iov[0].iov_base = msg; + iov[0].iov_len = msglen; + mh.msg_iov = iov; + mh.msg_iovlen = 1; + + if (num_fds > 0) + { + struct cmsghdr *cmsg; + memset (&ctl, 0, sizeof (ctl)); + mh.msg_control = ctl; + mh.msg_controllen = sizeof (ctl); + cmsg = CMSG_FIRSTHDR (&mh); + cmsg->cmsg_len = CMSG_LEN (sizeof (int) * num_fds); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + memcpy (CMSG_DATA (cmsg), fds, sizeof (int) * num_fds); + } + rv = sendmsg (s->fd, &mh, 0); + if (rv < 0) + return clib_error_return_unix (0, "sendmsg"); + return 0; +} + + +static clib_error_t * +default_socket_recvmsg (clib_socket_t * s, void *msg, int msglen, + int fds[], int num_fds) +{ + char ctl[CMSG_SPACE (sizeof (int) * num_fds) + + CMSG_SPACE (sizeof (struct ucred))]; + struct msghdr mh = { 0 }; + struct iovec iov[1]; + ssize_t size; + struct ucred *cr = 0; + struct cmsghdr *cmsg; + + iov[0].iov_base = msg; + iov[0].iov_len = msglen; + mh.msg_iov = iov; + mh.msg_iovlen = 1; + mh.msg_control = ctl; + mh.msg_controllen = sizeof (ctl); + + memset (ctl, 0, sizeof (ctl)); + + /* receive the incoming message */ + size = recvmsg (s->fd, &mh, 0); + if (size != msglen) + { + return (size == 0) ? clib_error_return (0, "disconnected") : + clib_error_return_unix (0, "recvmsg: malformed message (fd %d, '%s')", + s->fd, s->config); + } + + cmsg = CMSG_FIRSTHDR (&mh); + while (cmsg) + { + if (cmsg->cmsg_level == SOL_SOCKET) + { + if (cmsg->cmsg_type == SCM_CREDENTIALS) + { + cr = (struct ucred *) CMSG_DATA (cmsg); + s->uid = cr->uid; + s->gid = cr->gid; + s->pid = cr->pid; + } + else if (cmsg->cmsg_type == SCM_RIGHTS) + { + clib_memcpy (fds, CMSG_DATA (cmsg), num_fds * sizeof (int)); + } + } + cmsg = CMSG_NXTHDR (&mh, cmsg); + } + return 0; +} + static void socket_init_funcs (clib_socket_t * s) { @@ -280,6 +366,10 @@ socket_init_funcs (clib_socket_t * s) s->read_func = default_socket_read; if (!s->close_func) s->close_func = default_socket_close; + if (!s->sendmsg_func) + s->sendmsg_func = default_socket_sendmsg; + if (!s->recvmsg_func) + s->recvmsg_func = default_socket_recvmsg; } clib_error_t * @@ -291,18 +381,22 @@ clib_socket_init (clib_socket_t * s) struct sockaddr_un su; } addr; socklen_t addr_len = 0; + int socket_type; clib_error_t *error = 0; word port; error = socket_config (s->config, &addr.sa, &addr_len, - (s->flags & SOCKET_IS_SERVER + (s->flags & CLIB_SOCKET_F_IS_SERVER ? INADDR_LOOPBACK : INADDR_ANY)); if (error) goto done; socket_init_funcs (s); - s->fd = socket (addr.sa.sa_family, SOCK_STREAM, 0); + socket_type = s->flags & CLIB_SOCKET_F_SEQPACKET ? + SOCK_SEQPACKET : SOCK_STREAM; + + s->fd = socket (addr.sa.sa_family, socket_type, 0); if (s->fd < 0) { error = clib_error_return_unix (0, "socket (fd %d, '%s')", @@ -314,7 +408,7 @@ clib_socket_init (clib_socket_t * s) if (addr.sa.sa_family == PF_INET) port = ((struct sockaddr_in *) &addr)->sin_port; - if (s->flags & SOCKET_IS_SERVER) + if (s->flags & CLIB_SOCKET_F_IS_SERVER) { uword need_bind = 1; @@ -342,6 +436,18 @@ clib_socket_init (clib_socket_t * s) clib_unix_warning ("setsockopt SO_REUSEADDR fails"); } + if (addr.sa.sa_family == PF_LOCAL && s->flags & CLIB_SOCKET_F_PASSCRED) + { + int x = 1; + if (setsockopt (s->fd, SOL_SOCKET, SO_PASSCRED, &x, sizeof (x)) < 0) + { + error = clib_error_return_unix (0, "setsockopt (SO_PASSCRED, " + "fd %d, '%s')", s->fd, + s->config); + goto done; + } + } + if (need_bind && bind (s->fd, &addr.sa, addr_len) < 0) { error = clib_error_return_unix (0, "bind (fd %d, '%s')", @@ -356,7 +462,7 @@ clib_socket_init (clib_socket_t * s) goto done; } if (addr.sa.sa_family == PF_LOCAL - && s->flags & SOCKET_ALLOW_GROUP_WRITE) + && s->flags & CLIB_SOCKET_F_ALLOW_GROUP_WRITE) { struct stat st = { 0 }; if (stat (((struct sockaddr_un *) &addr)->sun_path, &st) < 0) @@ -378,7 +484,7 @@ clib_socket_init (clib_socket_t * s) } else { - if ((s->flags & SOCKET_NON_BLOCKING_CONNECT) + if ((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT) && fcntl (s->fd, F_SETFL, O_NONBLOCK) < 0) { error = clib_error_return_unix (0, "fcntl NONBLOCK (fd %d, '%s')", @@ -387,7 +493,7 @@ clib_socket_init (clib_socket_t * s) } if (connect (s->fd, &addr.sa, addr_len) < 0 - && !((s->flags & SOCKET_NON_BLOCKING_CONNECT) && + && !((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT) && errno == EINPROGRESS)) { error = clib_error_return_unix (0, "connect (fd %d, '%s')", @@ -434,7 +540,7 @@ clib_socket_accept (clib_socket_t * server, clib_socket_t * client) goto close_client; } - client->flags = SOCKET_IS_CLIENT; + client->flags = CLIB_SOCKET_F_IS_CLIENT; socket_init_funcs (client); return 0; diff --git a/src/vppinfra/socket.h b/src/vppinfra/socket.h index 75037208d5d..4f9e9509342 100644 --- a/src/vppinfra/socket.h +++ b/src/vppinfra/socket.h @@ -55,13 +55,14 @@ typedef struct _socket_t char *config; u32 flags; -#define SOCKET_IS_SERVER (1 << 0) -#define SOCKET_IS_CLIENT (0 << 0) -#define SOCKET_NON_BLOCKING_CONNECT (1 << 1) -#define SOCKET_ALLOW_GROUP_WRITE (1 << 2) +#define CLIB_SOCKET_F_IS_SERVER (1 << 0) +#define CLIB_SOCKET_F_IS_CLIENT (0 << 0) +#define CLIB_SOCKET_F_RX_END_OF_FILE (1 << 2) +#define CLIB_SOCKET_F_NON_BLOCKING_CONNECT (1 << 3) +#define CLIB_SOCKET_F_ALLOW_GROUP_WRITE (1 << 4) +#define CLIB_SOCKET_F_SEQPACKET (1 << 5) +#define CLIB_SOCKET_F_PASSCRED (1 << 6) - /* Read returned end-of-file. */ -#define SOCKET_RX_END_OF_FILE (1 << 2) /* Transmit buffer. Holds data waiting to be written. */ u8 *tx_buffer; @@ -72,10 +73,19 @@ typedef struct _socket_t /* Peer socket we are connected to. */ struct sockaddr_in peer; + /* Credentials, populated if CLIB_SOCKET_F_PASSCRED is set */ + pid_t pid; + uid_t uid; + gid_t gid; + clib_error_t *(*write_func) (struct _socket_t * sock); clib_error_t *(*read_func) (struct _socket_t * sock, int min_bytes); clib_error_t *(*close_func) (struct _socket_t * sock); - void *private_data; + clib_error_t *(*recvmsg_func) (struct _socket_t * s, void *msg, int msglen, + int fds[], int num_fds); + clib_error_t *(*sendmsg_func) (struct _socket_t * s, void *msg, int msglen, + int fds[], int num_fds); + uword private_data; } clib_socket_t; /* socket config format is host:port. @@ -89,7 +99,7 @@ clib_error_t *clib_socket_accept (clib_socket_t * server, always_inline uword clib_socket_is_server (clib_socket_t * sock) { - return (sock->flags & SOCKET_IS_SERVER) != 0; + return (sock->flags & CLIB_SOCKET_F_IS_SERVER) != 0; } always_inline uword @@ -98,10 +108,17 @@ clib_socket_is_client (clib_socket_t * s) return !clib_socket_is_server (s); } +always_inline uword +clib_socket_is_connected (clib_socket_t * sock) +{ + return sock->fd > 0; +} + + always_inline int clib_socket_rx_end_of_file (clib_socket_t * s) { - return s->flags & SOCKET_RX_END_OF_FILE; + return s->flags & CLIB_SOCKET_F_RX_END_OF_FILE; } always_inline void * @@ -130,6 +147,20 @@ clib_socket_rx (clib_socket_t * s, int n_bytes) return s->read_func (s, n_bytes); } +always_inline clib_error_t * +clib_socket_sendmsg (clib_socket_t * s, void *msg, int msglen, + int fds[], int num_fds) +{ + return s->sendmsg_func (s, msg, msglen, fds, num_fds); +} + +always_inline clib_error_t * +clib_socket_recvmsg (clib_socket_t * s, void *msg, int msglen, + int fds[], int num_fds) +{ + return s->recvmsg_func (s, msg, msglen, fds, num_fds); +} + always_inline void clib_socket_free (clib_socket_t * s) { diff --git a/src/vppinfra/test_socket.c b/src/vppinfra/test_socket.c index 0b05467af80..2f25eccd91f 100644 --- a/src/vppinfra/test_socket.c +++ b/src/vppinfra/test_socket.c @@ -50,15 +50,15 @@ test_socket_main (unformat_input_t * input) clib_error_t *error; s->config = "localhost:22"; - s->flags = SOCKET_IS_CLIENT; + s->flags = CLIB_SOCKET_F_IS_CLIENT; while (unformat_check_input (input) != UNFORMAT_END_OF_INPUT) { if (unformat (input, "server %s %=", &config, - &s->flags, SOCKET_IS_SERVER)) + &s->flags, CLIB_SOCKET_F_IS_SERVER)) ; else if (unformat (input, "client %s %=", &config, - &s->flags, SOCKET_IS_CLIENT)) + &s->flags, CLIB_SOCKET_F_IS_CLIENT)) ; else { |