summaryrefslogtreecommitdiffstats
path: root/src/vppinfra
diff options
context:
space:
mode:
Diffstat (limited to 'src/vppinfra')
-rw-r--r--src/vppinfra/socket.c130
-rw-r--r--src/vppinfra/socket.h49
-rw-r--r--src/vppinfra/test_socket.c6
3 files changed, 161 insertions, 24 deletions
diff --git a/src/vppinfra/socket.c b/src/vppinfra/socket.c
index 37dcbbfd8c6..87a9333f904 100644
--- a/src/vppinfra/socket.c
+++ b/src/vppinfra/socket.c
@@ -35,17 +35,18 @@
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
-#include <sys/un.h>
+#include <stdio.h>
+#include <string.h> /* strchr */
+#define __USE_GNU
#include <sys/types.h>
#include <sys/socket.h>
+#include <sys/un.h>
#include <sys/stat.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h>
-#include <stdio.h>
#include <fcntl.h>
-#include <string.h> /* strchr */
#include <vppinfra/mem.h>
#include <vppinfra/vec.h>
@@ -233,7 +234,7 @@ default_socket_read (clib_socket_t * sock, int n_bytes)
u8 *buf;
/* RX side of socket is down once end of file is reached. */
- if (sock->flags & SOCKET_RX_END_OF_FILE)
+ if (sock->flags & CLIB_SOCKET_F_RX_END_OF_FILE)
return 0;
fd = sock->fd;
@@ -255,7 +256,7 @@ default_socket_read (clib_socket_t * sock, int n_bytes)
/* Other side closed the socket. */
if (n_read == 0)
- sock->flags |= SOCKET_RX_END_OF_FILE;
+ sock->flags |= CLIB_SOCKET_F_RX_END_OF_FILE;
non_fatal:
_vec_len (sock->rx_buffer) += n_read - n_bytes;
@@ -271,6 +272,91 @@ default_socket_close (clib_socket_t * s)
return 0;
}
+static clib_error_t *
+default_socket_sendmsg (clib_socket_t * s, void *msg, int msglen,
+ int fds[], int num_fds)
+{
+ struct msghdr mh = { 0 };
+ struct iovec iov[1];
+ char ctl[CMSG_SPACE (sizeof (int)) * num_fds];
+ int rv;
+
+ iov[0].iov_base = msg;
+ iov[0].iov_len = msglen;
+ mh.msg_iov = iov;
+ mh.msg_iovlen = 1;
+
+ if (num_fds > 0)
+ {
+ struct cmsghdr *cmsg;
+ memset (&ctl, 0, sizeof (ctl));
+ mh.msg_control = ctl;
+ mh.msg_controllen = sizeof (ctl);
+ cmsg = CMSG_FIRSTHDR (&mh);
+ cmsg->cmsg_len = CMSG_LEN (sizeof (int) * num_fds);
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ memcpy (CMSG_DATA (cmsg), fds, sizeof (int) * num_fds);
+ }
+ rv = sendmsg (s->fd, &mh, 0);
+ if (rv < 0)
+ return clib_error_return_unix (0, "sendmsg");
+ return 0;
+}
+
+
+static clib_error_t *
+default_socket_recvmsg (clib_socket_t * s, void *msg, int msglen,
+ int fds[], int num_fds)
+{
+ char ctl[CMSG_SPACE (sizeof (int) * num_fds) +
+ CMSG_SPACE (sizeof (struct ucred))];
+ struct msghdr mh = { 0 };
+ struct iovec iov[1];
+ ssize_t size;
+ struct ucred *cr = 0;
+ struct cmsghdr *cmsg;
+
+ iov[0].iov_base = msg;
+ iov[0].iov_len = msglen;
+ mh.msg_iov = iov;
+ mh.msg_iovlen = 1;
+ mh.msg_control = ctl;
+ mh.msg_controllen = sizeof (ctl);
+
+ memset (ctl, 0, sizeof (ctl));
+
+ /* receive the incoming message */
+ size = recvmsg (s->fd, &mh, 0);
+ if (size != msglen)
+ {
+ return (size == 0) ? clib_error_return (0, "disconnected") :
+ clib_error_return_unix (0, "recvmsg: malformed message (fd %d, '%s')",
+ s->fd, s->config);
+ }
+
+ cmsg = CMSG_FIRSTHDR (&mh);
+ while (cmsg)
+ {
+ if (cmsg->cmsg_level == SOL_SOCKET)
+ {
+ if (cmsg->cmsg_type == SCM_CREDENTIALS)
+ {
+ cr = (struct ucred *) CMSG_DATA (cmsg);
+ s->uid = cr->uid;
+ s->gid = cr->gid;
+ s->pid = cr->pid;
+ }
+ else if (cmsg->cmsg_type == SCM_RIGHTS)
+ {
+ clib_memcpy (fds, CMSG_DATA (cmsg), num_fds * sizeof (int));
+ }
+ }
+ cmsg = CMSG_NXTHDR (&mh, cmsg);
+ }
+ return 0;
+}
+
static void
socket_init_funcs (clib_socket_t * s)
{
@@ -280,6 +366,10 @@ socket_init_funcs (clib_socket_t * s)
s->read_func = default_socket_read;
if (!s->close_func)
s->close_func = default_socket_close;
+ if (!s->sendmsg_func)
+ s->sendmsg_func = default_socket_sendmsg;
+ if (!s->recvmsg_func)
+ s->recvmsg_func = default_socket_recvmsg;
}
clib_error_t *
@@ -291,18 +381,22 @@ clib_socket_init (clib_socket_t * s)
struct sockaddr_un su;
} addr;
socklen_t addr_len = 0;
+ int socket_type;
clib_error_t *error = 0;
word port;
error = socket_config (s->config, &addr.sa, &addr_len,
- (s->flags & SOCKET_IS_SERVER
+ (s->flags & CLIB_SOCKET_F_IS_SERVER
? INADDR_LOOPBACK : INADDR_ANY));
if (error)
goto done;
socket_init_funcs (s);
- s->fd = socket (addr.sa.sa_family, SOCK_STREAM, 0);
+ socket_type = s->flags & CLIB_SOCKET_F_SEQPACKET ?
+ SOCK_SEQPACKET : SOCK_STREAM;
+
+ s->fd = socket (addr.sa.sa_family, socket_type, 0);
if (s->fd < 0)
{
error = clib_error_return_unix (0, "socket (fd %d, '%s')",
@@ -314,7 +408,7 @@ clib_socket_init (clib_socket_t * s)
if (addr.sa.sa_family == PF_INET)
port = ((struct sockaddr_in *) &addr)->sin_port;
- if (s->flags & SOCKET_IS_SERVER)
+ if (s->flags & CLIB_SOCKET_F_IS_SERVER)
{
uword need_bind = 1;
@@ -342,6 +436,18 @@ clib_socket_init (clib_socket_t * s)
clib_unix_warning ("setsockopt SO_REUSEADDR fails");
}
+ if (addr.sa.sa_family == PF_LOCAL && s->flags & CLIB_SOCKET_F_PASSCRED)
+ {
+ int x = 1;
+ if (setsockopt (s->fd, SOL_SOCKET, SO_PASSCRED, &x, sizeof (x)) < 0)
+ {
+ error = clib_error_return_unix (0, "setsockopt (SO_PASSCRED, "
+ "fd %d, '%s')", s->fd,
+ s->config);
+ goto done;
+ }
+ }
+
if (need_bind && bind (s->fd, &addr.sa, addr_len) < 0)
{
error = clib_error_return_unix (0, "bind (fd %d, '%s')",
@@ -356,7 +462,7 @@ clib_socket_init (clib_socket_t * s)
goto done;
}
if (addr.sa.sa_family == PF_LOCAL
- && s->flags & SOCKET_ALLOW_GROUP_WRITE)
+ && s->flags & CLIB_SOCKET_F_ALLOW_GROUP_WRITE)
{
struct stat st = { 0 };
if (stat (((struct sockaddr_un *) &addr)->sun_path, &st) < 0)
@@ -378,7 +484,7 @@ clib_socket_init (clib_socket_t * s)
}
else
{
- if ((s->flags & SOCKET_NON_BLOCKING_CONNECT)
+ if ((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT)
&& fcntl (s->fd, F_SETFL, O_NONBLOCK) < 0)
{
error = clib_error_return_unix (0, "fcntl NONBLOCK (fd %d, '%s')",
@@ -387,7 +493,7 @@ clib_socket_init (clib_socket_t * s)
}
if (connect (s->fd, &addr.sa, addr_len) < 0
- && !((s->flags & SOCKET_NON_BLOCKING_CONNECT) &&
+ && !((s->flags & CLIB_SOCKET_F_NON_BLOCKING_CONNECT) &&
errno == EINPROGRESS))
{
error = clib_error_return_unix (0, "connect (fd %d, '%s')",
@@ -434,7 +540,7 @@ clib_socket_accept (clib_socket_t * server, clib_socket_t * client)
goto close_client;
}
- client->flags = SOCKET_IS_CLIENT;
+ client->flags = CLIB_SOCKET_F_IS_CLIENT;
socket_init_funcs (client);
return 0;
diff --git a/src/vppinfra/socket.h b/src/vppinfra/socket.h
index 75037208d5d..4f9e9509342 100644
--- a/src/vppinfra/socket.h
+++ b/src/vppinfra/socket.h
@@ -55,13 +55,14 @@ typedef struct _socket_t
char *config;
u32 flags;
-#define SOCKET_IS_SERVER (1 << 0)
-#define SOCKET_IS_CLIENT (0 << 0)
-#define SOCKET_NON_BLOCKING_CONNECT (1 << 1)
-#define SOCKET_ALLOW_GROUP_WRITE (1 << 2)
+#define CLIB_SOCKET_F_IS_SERVER (1 << 0)
+#define CLIB_SOCKET_F_IS_CLIENT (0 << 0)
+#define CLIB_SOCKET_F_RX_END_OF_FILE (1 << 2)
+#define CLIB_SOCKET_F_NON_BLOCKING_CONNECT (1 << 3)
+#define CLIB_SOCKET_F_ALLOW_GROUP_WRITE (1 << 4)
+#define CLIB_SOCKET_F_SEQPACKET (1 << 5)
+#define CLIB_SOCKET_F_PASSCRED (1 << 6)
- /* Read returned end-of-file. */
-#define SOCKET_RX_END_OF_FILE (1 << 2)
/* Transmit buffer. Holds data waiting to be written. */
u8 *tx_buffer;
@@ -72,10 +73,19 @@ typedef struct _socket_t
/* Peer socket we are connected to. */
struct sockaddr_in peer;
+ /* Credentials, populated if CLIB_SOCKET_F_PASSCRED is set */
+ pid_t pid;
+ uid_t uid;
+ gid_t gid;
+
clib_error_t *(*write_func) (struct _socket_t * sock);
clib_error_t *(*read_func) (struct _socket_t * sock, int min_bytes);
clib_error_t *(*close_func) (struct _socket_t * sock);
- void *private_data;
+ clib_error_t *(*recvmsg_func) (struct _socket_t * s, void *msg, int msglen,
+ int fds[], int num_fds);
+ clib_error_t *(*sendmsg_func) (struct _socket_t * s, void *msg, int msglen,
+ int fds[], int num_fds);
+ uword private_data;
} clib_socket_t;
/* socket config format is host:port.
@@ -89,7 +99,7 @@ clib_error_t *clib_socket_accept (clib_socket_t * server,
always_inline uword
clib_socket_is_server (clib_socket_t * sock)
{
- return (sock->flags & SOCKET_IS_SERVER) != 0;
+ return (sock->flags & CLIB_SOCKET_F_IS_SERVER) != 0;
}
always_inline uword
@@ -98,10 +108,17 @@ clib_socket_is_client (clib_socket_t * s)
return !clib_socket_is_server (s);
}
+always_inline uword
+clib_socket_is_connected (clib_socket_t * sock)
+{
+ return sock->fd > 0;
+}
+
+
always_inline int
clib_socket_rx_end_of_file (clib_socket_t * s)
{
- return s->flags & SOCKET_RX_END_OF_FILE;
+ return s->flags & CLIB_SOCKET_F_RX_END_OF_FILE;
}
always_inline void *
@@ -130,6 +147,20 @@ clib_socket_rx (clib_socket_t * s, int n_bytes)
return s->read_func (s, n_bytes);
}
+always_inline clib_error_t *
+clib_socket_sendmsg (clib_socket_t * s, void *msg, int msglen,
+ int fds[], int num_fds)
+{
+ return s->sendmsg_func (s, msg, msglen, fds, num_fds);
+}
+
+always_inline clib_error_t *
+clib_socket_recvmsg (clib_socket_t * s, void *msg, int msglen,
+ int fds[], int num_fds)
+{
+ return s->recvmsg_func (s, msg, msglen, fds, num_fds);
+}
+
always_inline void
clib_socket_free (clib_socket_t * s)
{
diff --git a/src/vppinfra/test_socket.c b/src/vppinfra/test_socket.c
index 0b05467af80..2f25eccd91f 100644
--- a/src/vppinfra/test_socket.c
+++ b/src/vppinfra/test_socket.c
@@ -50,15 +50,15 @@ test_socket_main (unformat_input_t * input)
clib_error_t *error;
s->config = "localhost:22";
- s->flags = SOCKET_IS_CLIENT;
+ s->flags = CLIB_SOCKET_F_IS_CLIENT;
while (unformat_check_input (input) != UNFORMAT_END_OF_INPUT)
{
if (unformat (input, "server %s %=", &config,
- &s->flags, SOCKET_IS_SERVER))
+ &s->flags, CLIB_SOCKET_F_IS_SERVER))
;
else if (unformat (input, "client %s %=", &config,
- &s->flags, SOCKET_IS_CLIENT))
+ &s->flags, CLIB_SOCKET_F_IS_CLIENT))
;
else
{