aboutsummaryrefslogtreecommitdiffstats
path: root/test/packetdrill/wire_conn.c
blob: 945f4b00c5e36ffe587433469773aba991205458 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
/*
 * Copyright 2013 Google Inc.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 */
/*
 * Author: ncardwell@google.com (Neal Cardwell)
 *
 * TCP connection handling for remote on-the-wire testing using a real NIC.
 */

#include "wire_conn.h"

#include <errno.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <stdlib.h>
#include <unistd.h>

#include "logging.h"
#include "tcp.h"
#include "wrap.h"

/* Cap the max message payload (the bytes following the struct wire_header)
 * we're willing to read, so the remote side can't OOM us by announcing a
 * huge length.
 */
#define MAX_MESSAGE_BYTES (10*1000*1000)

/* Allocate a new, zero-initialized wire_conn with no socket attached
 * (fd == -1). Dies if the allocation fails. Release the result with
 * wire_conn_free().
 */
struct wire_conn *wire_conn_new(void)
{
	struct wire_conn *wire_conn;

	DEBUGP("wire_conn_new\n");

	wire_conn = calloc(1, sizeof(*wire_conn));
	if (wire_conn == NULL)
		die("wire_conn_new: out of memory\n");

	wire_conn->fd = -1;	/* no socket opened yet */

	return wire_conn;
}

/* Tear down a wire_conn: close its socket if one was ever opened
 * (fd != -1), release its receive buffer, then release the struct.
 */
void wire_conn_free(struct wire_conn *conn)
{
	int sock = conn->fd;
	void *rx_buf = conn->in.buf;

	if (sock != -1)
		close(sock);
	free(rx_buf);

	/* Defensive: wipe the struct before releasing it, to help surface
	 * use-after-free bugs.
	 */
	memset(conn, 0, sizeof(struct wire_conn));
	free(conn);
}

/* Open a stream (TCP) socket for the given IP version and record its
 * descriptor in conn->fd. The conn must not already hold a socket.
 */
static void create_tcp_socket(struct wire_conn *conn,
				enum ip_version_t ip_version)
{
	int new_fd;

	/* Guard against silently overwriting (and leaking) an existing fd. */
	assert(conn->fd == -1);

	new_fd = wrap_socket(ip_version, SOCK_STREAM);
	conn->fd = new_fd;
}

/* Set default TCP socket options on conn->fd for decent performance:
 * TCP_NODELAY enabled, and 128 KiB receive and send buffers.
 * Dies on any setsockopt() failure.
 */
static void set_default_tcp_options(struct wire_conn *conn)
{
	int val;

	DEBUGP("set_default_tcp_options fd %d\n", conn->fd);

	/* Disable Nagle algorithm so packets go out ASAP regardless of size.
	 * POSIX specifies IPPROTO_TCP as the option level for TCP_NODELAY;
	 * SOL_TCP is a Linux-specific alias with the same value.
	 */
	val = 1;
	if (setsockopt(conn->fd, IPPROTO_TCP, TCP_NODELAY, &val,
		       sizeof(val)) < 0) {
		die_perror("setsockopt TCP_NODELAY");
	}

	/* Set receive buffer to allow high throughput. */
	val = 128*1024;
	if (setsockopt(conn->fd, SOL_SOCKET, SO_RCVBUF, &val,
		       sizeof(val)) < 0) {
		die_perror("setsockopt SO_RCVBUF");
	}

	/* Set send buffer to allow high throughput and avoid blocking. */
	val = 128*1024;
	if (setsockopt(conn->fd, SOL_SOCKET, SO_SNDBUF, &val,
		       sizeof(val)) < 0) {
		die_perror("setsockopt SO_SNDBUF");
	}
}

/* Create a TCP socket for conn (with the default options) and do a
 * blocking connect to the wire server at ip:port. Dies on failure,
 * reporting the server address and the connect() error.
 */
void wire_conn_connect(struct wire_conn *conn,
			const struct ip_address *ip,
			u16 port,
			enum ip_version_t ip_version)
{
	struct sockaddr_storage sa;
	socklen_t length = 0;

	DEBUGP("wire_conn_connect\n");

	create_tcp_socket(conn, ip_version);
	set_default_tcp_options(conn);

	/* Do a blocking connect to the server. */
	ip_to_sockaddr(ip, port, (struct sockaddr *)&sa, &length);
	if (connect(conn->fd, (struct sockaddr *)&sa, length) < 0) {
		/* Capture errno immediately: the evaluation order of die()'s
		 * arguments is unspecified, so ip_to_string() might run
		 * before strerror() and overwrite errno.
		 */
		int connect_errno = errno;
		char ip_string[ADDR_STR_LEN];

		die("error connecting to wire server at %s:%d: %s\n",
		    ip_to_string(ip, ip_string), port,
		    strerror(connect_errno));
	}
}

/* Create a TCP listening socket for listen_conn: open the socket, enable
 * SO_REUSEADDR, then bind and listen on the given port via
 * wrap_bind_listen(). Dies if SO_REUSEADDR cannot be set.
 */
void wire_conn_bind_listen(struct wire_conn *listen_conn,
				u16 port,
				enum ip_version_t ip_version)
{
	int reuse_addr = 1;

	DEBUGP("wire_conn_bind_listen\n");

	create_tcp_socket(listen_conn, ip_version);

	/* Let the server rebind its port right away after a restart. */
	if (setsockopt(listen_conn->fd, SOL_SOCKET, SO_REUSEADDR,
		       &reuse_addr, sizeof(reuse_addr)) < 0)
		die_perror("setsockopt SO_REUSEADDR");

	wrap_bind_listen(listen_conn->fd, ip_version, port);
}

/* Block until a client connects to listen_conn, then store a freshly
 * allocated wire_conn for the accepted socket (with the default TCP
 * options applied) in *accepted_conn. Dies on accept() failure.
 */
void wire_conn_accept(struct wire_conn *listen_conn,
		      struct wire_conn **accepted_conn)
{
	int fd;

	DEBUGP("wire_conn_accept\n");

	/* Retry when a signal interrupts the blocking accept(), matching the
	 * EINTR handling in read_bytes()/write_bytes(); any other error is
	 * fatal.
	 */
	do {
		fd = accept(listen_conn->fd, NULL, NULL);
	} while (fd < 0 && errno == EINTR);
	if (fd < 0)
		die_perror("accept");

	DEBUGP("accepted fd %d\n", fd);

	*accepted_conn = wire_conn_new();
	(*accepted_conn)->fd = fd;

	set_default_tcp_options(*accepted_conn);
}

/* Do blocking writes until all buf_len bytes of buf are written,
 * retrying on EINTR/EAGAIN. Given our large socket buffer size and
 * typically small write sizes, in practice all the writes should
 * complete in one call.
 * Returns STATUS_OK on success, or STATUS_ERR (after printing the
 * error) on a write failure.
 */
static int write_bytes(struct wire_conn *conn,
		       const void *buf, int buf_len)
{
	/* Advance through the data as bytes: pointer arithmetic on void *
	 * is a GNU extension, not ISO C.
	 */
	const char *pos = buf;

	while (buf_len > 0) {
		ssize_t bytes_written = write(conn->fd, pos, buf_len);
		if (bytes_written < 0) {
			if (errno == EINTR || errno == EAGAIN)
				continue;
			perror("TCP socket write");
			return STATUS_ERR;
		}
		assert(bytes_written <= buf_len);
		buf_len -= (int)bytes_written;
		pos += bytes_written;
	}
	return STATUS_OK;
}

int wire_conn_write(struct wire_conn *conn,
		    enum wire_op_t op,
		    const void *buf, int buf_len)
{
	DEBUGP("wire_conn_write -> op: %s\n",
	       wire_op_to_string(op));
	struct wire_header header;

	header.length	= htonl(sizeof(header) + buf_len);
	header.op	= htonl(op);

	if (write_bytes(conn, &header, sizeof(header)))
		return STATUS_ERR;

	if (write_bytes(conn, buf, buf_len))
		return STATUS_ERR;

	return STATUS_OK;
}

/* Do blocking reads until exactly buf_len bytes have been read into buf,
 * retrying on EINTR/EAGAIN.
 * Returns STATUS_OK on success, or STATUS_ERR (after printing a message)
 * on a read error or if the remote side closes the connection early.
 */
static int read_bytes(struct wire_conn *conn,
		      void *buf, int buf_len)
{
	/* Advance through the buffer as bytes: pointer arithmetic on void *
	 * is a GNU extension, not ISO C.
	 */
	char *pos = buf;

	while (buf_len > 0) {
		ssize_t bytes_read = read(conn->fd, pos, buf_len);
		if (bytes_read < 0) {
			if (errno == EINTR || errno == EAGAIN)
				continue;
			perror("TCP socket read");
			return STATUS_ERR;
		} else if (bytes_read == 0) {
			fprintf(stderr, "remote side closed connection\n");
			return STATUS_ERR;
		}
		assert(bytes_read <= buf_len);
		buf_len -= (int)bytes_read;
		pos += bytes_read;
	}
	return STATUS_OK;
}

/* Read one framed message from conn: a struct wire_header followed by its
 * payload. On success sets *op, sets *buf_len to the payload length, and
 * points *buf at the payload. *buf points into conn's reusable receive
 * buffer, which is owned by conn and overwritten by the next read.
 * Returns STATUS_OK on success, or STATUS_ERR on I/O error, an invalid
 * announced length, or allocation failure.
 */
int wire_conn_read(struct wire_conn *conn,
		   enum wire_op_t *op,
		   void **buf, int *buf_len)
{
	struct wire_header header;
	long long payload_len;

	DEBUGP("wire_conn_read\n");

	if (read_bytes(conn, &header, sizeof(header)))
		return STATUS_ERR;

	*op = ntohl(header.op);

	DEBUGP("wire_conn_read -> op: %s\n", wire_op_to_string(*op));

	/* header.length includes the header itself. Subtract in a wide
	 * signed type so that a remote length smaller than the header (or
	 * close to 2^32) yields an out-of-range value. This avoids relying on
	 * unsigned wrap-around and implementation-defined narrowing to int.
	 */
	payload_len = (long long)ntohl(header.length) -
		      (long long)sizeof(header);
	if (payload_len < 0 || payload_len > MAX_MESSAGE_BYTES) {
		fprintf(stderr, "invalid length %lld from remote wire conn\n",
			payload_len);
		return STATUS_ERR;
	}
	*buf_len = (int)payload_len;

	/* Grow the reusable receive buffer when needed, leaving 2x slack to
	 * amortize reallocations. The old contents need not be preserved.
	 */
	if (conn->in.buf_space < *buf_len) {
		/* At most 2 * MAX_MESSAGE_BYTES, so this fits in an int. */
		int new_space = 2 * *buf_len;
		void *new_buf;

		free(conn->in.buf);
		/* Keep conn consistent if malloc fails: no dangling pointer
		 * and no claimed capacity.
		 */
		conn->in.buf = NULL;
		conn->in.buf_space = 0;

		new_buf = malloc(new_space);
		if (new_buf == NULL) {
			fprintf(stderr,
				"out of memory for %d-byte message from remote wire conn\n",
				*buf_len);
			return STATUS_ERR;
		}
		conn->in.buf = new_buf;
		conn->in.buf_space = new_space;
	}

	*buf = conn->in.buf;

	if (read_bytes(conn, *buf, *buf_len))
		return STATUS_ERR;

	return STATUS_OK;
}