#!/usr/bin/env python3 import ipaddress import random import socket import struct import unittest from io import BytesIO import scapy.compat from config import config from framework import VppTestCase from asfframework import ( tag_fixme_vpp_workers, tag_fixme_ubuntu2204, is_distro_ubuntu2204, VppTestRunner, ) from ipfix import IPFIX, Set, Template, Data, IPFIXDecoder from scapy.data import IP_PROTOS from scapy.layers.inet import IP, TCP, UDP, ICMP from scapy.layers.inet import IPerror, TCPerror, UDPerror, ICMPerror from scapy.layers.inet6 import ICMPv6DestUnreach, IPerror6, IPv6ExtHdrFragment from scapy.layers.inet6 import ( IPv6, ICMPv6EchoRequest, ICMPv6EchoReply, ICMPv6ND_NS, ICMPv6ND_NA, ICMPv6NDOptDstLLAddr, fragment6, ) from scapy.layers.l2 import Ether, GRE from scapy.packet import Raw from syslog_rfc5424_parser import SyslogMessage, ParseError from syslog_rfc5424_parser.constants import SyslogSeverity from util import ppc, ppp from vpp_papi import VppEnum from config import config @tag_fixme_vpp_workers @tag_fixme_ubuntu2204 @unittest.skipIf("nat" in config.excluded_plugins, "Exclude NAT plugin tests") class TestNAT64(VppTestCase): """NAT64 Test Cases""" @property def SYSLOG_SEVERITY(self): return VppEnum.vl_api_syslog_severity_t @property def config_flags(self): return VppEnum.vl_api_nat_config_flags_t @classmethod def setUpClass(cls): super(TestNAT64, cls).setUpClass() cls.tcp_port_in = 6303 cls.tcp_port_out = 6303 cls.udp_port_in = 6304 cls.udp_port_out = 6304 cls.icmp_id_in = 6305 cls.icmp_id_out = 6305 cls.tcp_external_port = 80 cls.nat_addr = "10.0.0.3" cls.nat_addr_n = socket.inet_pton(socket.AF_INET, cls.nat_addr) cls.vrf1_id = 10 cls.vrf1_nat_addr = "10.0.10.3" cls.ipfix_src_port = 4739 cls.ipfix_domain_id = 1 cls.create_pg_interfaces(range(6)) cls.ip6_interfaces = list(cls.pg_interfaces[0:1]) cls.ip6_interfaces.append(cls.pg_interfaces[2]) cls.ip4_interfaces = list(cls.pg_interfaces[1:2]) cls.vapi.ip_table_add_del_v2( is_add=1, table={"table_id": 
cls.vrf1_id, "is_ip6": 1} ) cls.pg_interfaces[2].set_table_ip6(cls.vrf1_id) cls.pg0.generate_remote_hosts(2) for i in cls.ip6_interfaces: i.admin_up() i.config_ip6() i.configure_ipv6_neighbors() for i in cls.ip4_interfaces: i.admin_up() i.config_ip4() i.resolve_arp() cls.pg3.admin_up() cls.pg3.config_ip4() cls.pg3.resolve_arp() cls.pg3.config_ip6() cls.pg3.configure_ipv6_neighbors() cls.pg5.admin_up() cls.pg5.config_ip6() @classmethod def tearDownClass(cls): super(TestNAT64, cls).tearDownClass() def setUp(self): super(TestNAT64, self).setUp() self.vapi.nat64_plugin_enable_disable(enable=1, bib_buckets=128, st_buckets=256) def tearDown(self): super(TestNAT64, self).tearDown() if not self.vpp_dead: self.vapi.nat64_plugin_enable_disable(enable=0) def show_commands_at_teardown(self): self.logger.info(self.vapi.cli("show nat64 pool")) self.logger.info(self.vapi.cli("show nat64 interfaces")) self.logger.info(self.vapi.cli("show nat64 prefix")) self.logger.info(self.vapi.cli("show nat64 bib all")) self.logger.info(self.vapi.cli("show nat64 session table all")) def create_stream_in_ip6(self, in_if, out_if, hlim=64, pref=None, plen=0): """ Create IPv6 packet stream for inside network :param in_if: Inside interface :param out_if: Outside interface :param ttl: Hop Limit of generated packets :param pref: NAT64 prefix :param plen: NAT64 prefix length """ pkts = [] if pref is None: dst = "".join(["64:ff9b::", out_if.remote_ip4]) else: dst = self.compose_ip6(out_if.remote_ip4, pref, plen) # TCP p = ( Ether(dst=in_if.local_mac, src=in_if.remote_mac) / IPv6(src=in_if.remote_ip6, dst=dst, hlim=hlim) / TCP(sport=self.tcp_port_in, dport=20) ) pkts.append(p) # UDP p = ( Ether(dst=in_if.local_mac, src=in_if.remote_mac) / IPv6(src=in_if.remote_ip6, dst=dst, hlim=hlim) / UDP(sport=self.udp_port_in, dport=20) ) pkts.append(p) # ICMP p = ( Ether(dst=in_if.local_mac, src=in_if.remote_mac) / IPv6(src=in_if.remote_ip6, dst=dst, hlim=hlim) / ICMPv6EchoRequest(id=self.icmp_id_in) ) 
pkts.append(p) return pkts def create_stream_out(self, out_if, dst_ip=None, ttl=64, use_inside_ports=False): """ Create packet stream for outside network :param out_if: Outside interface :param dst_ip: Destination IP address (Default use global NAT address) :param ttl: TTL of generated packets :param use_inside_ports: Use inside NAT ports as destination ports instead of outside ports """ if dst_ip is None: dst_ip = self.nat_addr if not use_inside_ports: tcp_port = self.tcp_port_out udp_port = self.udp_port_out icmp_id = self.icmp_id_out else: tcp_port = self.tcp_port_in udp_port = self.udp_port_in icmp_id = self.icmp_id_in pkts = [] # TCP p = ( Ether(dst=out_if.local_mac, src=out_if.remote_mac) / IP(src=out_if.remote_ip4, dst=dst_ip, ttl=ttl) / TCP(dport=tcp_port, sport=20) ) pkts.extend([p, p]) # UDP p = ( Ether(dst=out_if.local_mac, src=out_if.remote_mac) / IP(src=out_if.remote_ip4, dst=dst_ip, ttl=ttl) / UDP(dport=udp_port, sport=20) ) pkts.append(p) # ICMP p = ( Ether(dst=out_if.local_mac, src=out_if.remote_mac) / IP(src=out_if.remote_ip4, dst=dst_ip, ttl=ttl) / ICMP(id=icmp_id, type="echo-reply") ) pkts.append(p) return pkts def verify_capture_out( self, capture, nat_ip=None, same_port=False, dst_ip=None, is_ip6=False, ignore_port=False, ): """ Verify captured packets on outside network :param capture: Captured packets :param nat_ip: Translated IP address (Default use global NAT address) :param same_port: Source port number is not translated (Default False) :param dst_ip: Destination IP address (Default do not verify) :param is_ip6: If L3 protocol is IPv6 (Default False) """ if is_ip6: IP46 = IPv6 ICMP46 = ICMPv6EchoRequest else: IP46 = IP ICMP46 = ICMP if nat_ip is None: nat_ip = self.nat_addr for packet in capture: try: if not is_ip6: self.assert_packet_checksums_valid(packet) self.assertEqual(packet[IP46].src, nat_ip) if dst_ip is not None: self.assertEqual(packet[IP46].dst, dst_ip) if packet.haslayer(TCP): if not ignore_port: if same_port: 
self.assertEqual(packet[TCP].sport, self.tcp_port_in) else: self.assertNotEqual(packet[TCP].sport, self.tcp_port_in) self.tcp_port_out = packet[TCP].sport self.assert_packet_checksums_valid(packet) elif packet.haslayer(UDP): if not ignore_port: if same_port: self.assertEqual(packet[UDP].sport, self.udp_port_in) else: self.assertNotEqual(packet[UDP].sport, self.udp_port_in) self.udp_port_out = packet[UDP].sport else: if not ignore_port: if same_port: self.assertEqual(packet[ICMP46].id, self.icmp_id_in) else: self.assertNotEqual(packet[ICMP46].id, self.icmp_id_in) self.icmp_id_out = packet[ICMP46].id self.assert_packet_checksums_valid(packet) except: self.logger.error( ppp("Unexpected or invalid packet (outside network):", packet) ) raise def verify_capture_in_ip6(self, capture, src_ip, dst_ip): """ Verify captured IPv6 packets on inside network :param capture: Captured packets :param src_ip: Source IP :param dst_ip: Destination IP address """ for packet in capture: try: self.assertEqual(packet[IPv6].src, src_ip) self.assertEqual(packet[IPv6].dst, dst_ip) self.assert_packet_checksums_valid(packet) if packet.haslayer(TCP): self.assertEqual(packet[TCP].dport, self.tcp_port_in) elif packet.haslayer(UDP): self.assertEqual(packet[UDP].dport, self.udp_port_in) else: self.assertEqual(packet[ICMPv6EchoReply].id, self.icmp_id_in) except: self.logger.error( ppp("Unexpected or invalid packet (inside network):", packet) ) raise def create_stream_frag( self, src_if, dst, sport, dport, data, proto=IP_PROTOS.tcp, echo_reply=False ): """ Create fragmented packet stream :param src_if: Source interface :param dst: Destination IPv4 address :param sport: Source port :param dport: Destination port :param data: Payload data :param proto: protocol (TCP, UDP, ICMP) :param echo_reply: use echo_reply if protocol is ICMP :returns: Fragments """ if proto == IP_PROTOS.tcp: p = ( IP(src=src_if.remote_ip4, dst=dst) / TCP(sport=sport, dport=dport) / Raw(data) ) p = p.__class__(scapy.compat.raw(p)) 
chksum = p[TCP].chksum proto_header = TCP(sport=sport, dport=dport, chksum=chksum) elif proto == IP_PROTOS.udp: proto_header = UDP(sport=sport, dport=dport) elif proto == IP_PROTOS.icmp: if not echo_reply: proto_header = ICMP(id=sport, type="echo-request") else: proto_header = ICMP(id=sport, type="echo-reply") else: raise Exception("Unsupported protocol") id = random.randint(0, 65535) pkts = [] if proto == IP_PROTOS.tcp: raw = Raw(data[0:4]) else: raw = Raw(data[0:16]) p = ( Ether(src=src_if.remote_mac, dst=src_if.local_mac) / IP(src=src_if.remote_ip4, dst=dst, flags="MF", frag=0, id=id) / proto_header / raw ) pkts.append(p) if proto == IP_PROTOS.tcp: raw = Raw(data[4:20]) else: raw = Raw(data[16:32]) p = ( Ether(src=src_if.remote_mac, dst=src_if.local_mac) / IP(src=src_if.remote_ip4, dst=dst, flags="MF", frag=3, id=id, proto=proto) / raw ) pkts.append(p) if proto == IP_PROTOS.tcp: raw = Raw(data[20:]) else: raw = Raw(data[32:]) p = ( Ether(src=src_if.remote_mac, dst=src_if.local_mac) / IP(src=src_if.remote_ip4, dst=dst, frag=5, proto=proto, id=id) / raw ) pkts.append(p) return pkts def create_stream_frag_ip6( self, src_if, dst, sport, dport, data, pref=None, plen=0, frag_size=128 ): """ Create fragmented packet stream :param src_if: Source interface :param dst: Destination IPv4 address :param sport: Source TCP port :param dport: Destination TCP port :param data: Payload data :param pref: NAT64 prefix :param plen: NAT64 prefix length :param fragsize: size of fragments :returns: Fragments """ if pref is None: dst_ip6 = "".join(["64:ff9b::", dst]) else: dst_ip6 = self.compose_ip6(dst, pref, plen) p = ( Ether(dst=src_if.local_mac, src=src_if.remote_mac) / IPv6(src=src_if.remote_ip6, dst=dst_ip6) / IPv6ExtHdrFragment(id=random.randint(0, 65535)) / TCP(sport=sport, dport=dport) / Raw(data) ) return fragment6(p, frag_size) def reass_frags_and_verify(self, frags, src, dst): """ Reassemble and verify fragmented packet :param frags: Captured fragments :param src: Source 
IPv4 address to verify :param dst: Destination IPv4 address to verify :returns: Reassembled IPv4 packet """ buffer = BytesIO() for p in frags: self.assertEqual(p[IP].src, src) self.assertEqual(p[IP].dst, dst) self.assert_ip_checksum_valid(p) buffer.seek(p[IP].frag * 8) buffer.write(bytes(p[IP].payload)) ip = IP(src=frags[0][IP].src, dst=frags[0][IP].dst, proto=frags[0][IP].proto) if ip.proto == IP_PROTOS.tcp: p = ip / TCP(buffer.getvalue()) self.logger.debug(ppp("Reassembled:", p)) self.assert_tcp_checksum_valid(p) elif ip.proto == IP_PROTOS.udp: p = ip / UDP(buffer.getvalue()[:8]) / Raw(buffer.getvalue()[8:]) elif ip.proto == IP_PROTOS.icmp: p = ip / ICMP(buffer.getvalue()) return p def reass_frags_and_verify_ip6(self, frags, src, dst): """ Reassemble and verify fragmented packet :param frags: Captured fragments :param src: Source IPv6 address to verify :param dst: Destination IPv6 address to verify :returns: Reassembled IPv6 packet """ buffer = BytesIO() for p in frags: self.assertEqual(p[IPv6].src, src) self.assertEqual(p[IPv6].dst, dst) buffer.seek(p[IPv6ExtHdrFragment].offset * 8) buffer.write(bytes(p[IPv6ExtHdrFragment].payload)) ip = IPv6( src=frags[0][IPv6].src, dst=frags[0][IPv6].dst, nh=frags[0][IPv6ExtHdrFragment].nh, ) if ip.nh == IP_PROTOS.tcp: p = ip / TCP(buffer.getvalue()) elif ip.nh == IP_PROTOS.udp: p = ip / UDP(buffer.getvalue()) self.logger.debug(ppp("Reassembled:", p)) self.assert_packet_checksums_valid(p) return p def verify_ipfix_max_bibs(self, data, limit): """ Verify IPFIX maximum BIB entries exceeded event :param data: Decoded IPFIX data records :param limit: Number of maximum BIB entries that can be created. 
""" self.assertEqual(1, len(data)) record = data[0] # natEvent self.assertEqual(scapy.compat.orb(record[230]), 13) # natQuotaExceededEvent self.assertEqual(struct.pack("!I", 2), record[466]) # maxBIBEntries self.assertEqual(struct.pack("!I", limit), record[472]) return len(data) def verify_ipfix_bib(self, data, is_create, src_addr): """ Verify IPFIX NAT64 BIB create and delete events :param data: Decoded IPFIX data records :param is_create: Create event if nonzero value otherwise delete event :param src_addr: IPv6 source address """ self.assertEqual(1, len(data)) record = data[0] # natEvent if is_create: self.assertEqual(scapy.compat.orb(record[230]), 10) else: self.assertEqual(scapy.compat.orb(record[230]), 11) # sourceIPv6Address self.assertEqual(src_addr, str(ipaddress.IPv6Address(record[27]))) # postNATSourceIPv4Address self.assertEqual(self.nat_addr_n, record[225]) # protocolIdentifier self.assertEqual(IP_PROTOS.tcp, scapy.compat.orb(record[4])) # ingressVRFID self.assertEqual(struct.pack("!I", 0), record[234]) # sourceTransportPort self.assertEqual(struct.pack("!H", self.tcp_port_in), record[7]) # postNAPTSourceTransportPort self.assertEqual(struct.pack("!H", self.tcp_port_out), record[227]) def verify_ipfix_nat64_ses(self, data, is_create, src_addr, dst_addr, dst_port): """ Verify IPFIX NAT64 session create and delete events :param data: Decoded IPFIX data records :param is_create: Create event if nonzero value otherwise delete event :param src_addr: IPv6 source address :param dst_addr: IPv4 destination address :param dst_port: destination TCP port """ self.assertEqual(1, len(data)) record = data[0] # natEvent if is_create: self.assertEqual(scapy.compat.orb(record[230]), 6) else: self.assertEqual(scapy.compat.orb(record[230]), 7) # sourceIPv6Address self.assertEqual(src_addr, str(ipaddress.IPv6Address(record[27]))) # destinationIPv6Address self.assertEqual( socket.inet_pton( socket.AF_INET6, self.compose_ip6(dst_addr, "64:ff9b::", 96) ), record[28], ) # 
postNATSourceIPv4Address self.assertEqual(self.nat_addr_n, record[225]) # postNATDestinationIPv4Address self.assertEqual(socket.inet_pton(socket.AF_INET, dst_addr), record[226]) # protocolIdentifier self.assertEqual(IP_PROTOS.tcp, scapy.compat.orb(record[4])) # ingressVRFID self.assertEqual(struct.pack("!I", 0), record[234]) # sourceTransportPort self.assertEqual(struct.pack("!H", self.tcp_port_in), record[7]) # postNAPTSourceTransportPort self.assertEqual(struct.pack("!H", self.tcp_port_out), record[227]) # destinationTransportPort self.assertEqual(struct.pack("!H", dst_port), record[11]) # postNAPTDestinationTransportPort self.assertEqual(struct.pack("!H", dst_port), record[228]) def verify_syslog_sess(self, data, is_add=True, is_ip6=False): message = data.decode("utf-8") try: message = SyslogMessage.parse(message) except ParseError as e: self.logger.error(e) raise else: self.assertEqual(message.severity, SyslogSeverity.info) self.assertEqual(message.appname, "NAT") self.assertEqual(message.msgid, "SADD" if is_add else "SDEL") sd_params = message.sd.get("nsess") self.assertTrue(sd_params is not None) if is_ip6: self.assertEqual(sd_params.get("IATYP"), "IPv6") self.assertEqual(sd_params.get("ISADDR"), self.pg0.remote_ip6) else: self.assertEqual(sd_params.get("IATYP"), "IPv4") self.assertEqual(sd_params.get("ISADDR"), self.pg0.remote_ip4) self.assertTrue(sd_params.get("SSUBIX") is not None) self.assertEqual(sd_params.get("ISPORT"), "%d" % self.tcp_port_in) self.assertEqual(sd_params.get("XATYP"), "IPv4") self.assertEqual(sd_params.get("XSADDR"), self.nat_addr) self.assertEqual(sd_params.get("XSPORT"), "%d" % self.tcp_port_out) self.assertEqual(sd_params.get("PROTO"), "%d" % IP_PROTOS.tcp) self.assertEqual(sd_params.get("SVLAN"), "0") self.assertEqual(sd_params.get("XDADDR"), self.pg1.remote_ip4) self.assertEqual(sd_params.get("XDPORT"), "%d" % self.tcp_external_port) def compose_ip6(self, ip4, pref, plen): """ Compose IPv4-embedded IPv6 addresses :param ip4: IPv4 
address :param pref: IPv6 prefix :param plen: IPv6 prefix length :returns: IPv4-embedded IPv6 addresses """ pref_n = list(socket.inet_pton(socket.AF_INET6, pref)) ip4_n = list(socket.inet_pton(socket.AF_INET, ip4)) if plen == 32: pref_n[4] = ip4_n[0] pref_n[5] = ip4_n[1] pref_n[6] = ip4_n[2] pref_n[7] = ip4_n[3] elif plen == 40: pref_n[5] = ip4_n[0] pref_n[6] = ip4_n[1] pref_n[7] = ip4_n[2] pref_n[9] = ip4_n[3] elif plen == 48: pref_n[6] = ip4_n[0] pref_n[7] = ip4_n[1] pref_n[9] = ip4_n[2] pref_n[10] = ip4_n[3] elif plen == 56: pref_n[7] = ip4_n[0] pref_n[9] = ip4_n[1] pref_n[10] = ip4_n[2] pref_n[11] = ip4_n[3] elif plen == 64: pref_n[9] = ip4_n[0] pref_n[10] = ip4_n[1] pref_n[11] = ip4_n[2] pref_n[12] = ip4_n[3] elif plen == 96: pref_n[12] = ip4_n[0] pref_n[13] = ip4_n[1] pref_n[14] = ip4_n[2] pref_n[15] = ip4_n[3] packed_pref_n = b"".join([scapy.compat.chb(x) for x in pref_n]) return socket.inet_ntop(socket.AF_INET6, packed_pref_n) def verify_ipfix_max_sessions(self, data, limit): """ Verify IPFIX maximum session entries exceeded event :param data: Decoded IPFIX data records :param limit: Number of maximum session entries that can be created. 
""" self.assertEqual(1, len(data)) record = data[0] # natEvent self.assertEqual(scapy.compat.orb(record[230]), 13) # natQuotaExceededEvent self.assertEqual(struct.pack("!I", 1), record[466]) # maxSessionEntries self.assertEqual(struct.pack("!I", limit), record[471]) return len(data) def test_nat64_inside_interface_handles_neighbor_advertisement(self): """NAT64 inside interface handles Neighbor Advertisement""" flags = self.config_flags.NAT_IS_INSIDE self.vapi.nat64_add_del_interface( is_add=1, flags=flags, sw_if_index=self.pg5.sw_if_index ) # Try to send ping ping = ( Ether(dst=self.pg5.local_mac, src=self.pg5.remote_mac) / IPv6(src=self.pg5.remote_ip6, dst=self.pg5.local_ip6) / ICMPv6EchoRequest() ) pkts = [ping] self.pg5.add_stream(pkts) self.pg_enable_capture(self.pg_interfaces) self.pg_start() # Wait for Neighbor Solicitation capture = self.pg5.get_capture(len(pkts)) packet = capture[0] try: self.assertEqual(packet[IPv6].src, self.pg5.local_ip6_ll) self.assertEqual(packet.haslayer(ICMPv6ND_NS), 1) tgt = packet[ICMPv6ND_NS].tgt except: self.logger.error(ppp("Unexpected or invalid packet:", packet)) raise # Send Neighbor Advertisement p = ( Ether(dst=self.pg5.local_mac, src=self.pg5.remote_mac) / IPv6(src=self.pg5.remote_ip6, dst=self.pg5.local_ip6) / ICMPv6ND_NA(tgt=
/*
* Copyright (c) 2015 Cisco and/or its affiliates.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
Copyright (c) 2001, 2002, 2003, 2005 Eliot Dresselhaus
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <stdio.h>
#include <string.h> /* strchr */
#define __USE_GNU
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/stat.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <vppinfra/mem.h>
#include <vppinfra/vec.h>
#include <vppinfra/socket.h>
#include <vppinfra/format.h>
#include <vppinfra/error.h>
/* Variadic front end for clib_socket_tx_add_va_formatted: collects the
   arguments that follow FMT into a va_list and forwards them, together
   with socket S and FMT, to the va_list variant. */
void
clib_socket_tx_add_formatted (clib_socket_t * s, char *fmt, ...)
{
  va_list args;

  va_start (args, fmt);
  clib_socket_tx_add_va_formatted (s, fmt, &args);
  va_end (args);
}
/* Bind SOCK to the first port, counting up from IPPORT_USERRESERVED,
   for which bind() on INADDR_ANY succeeds.  Returns that port, or -1 if
   every port below 65536 failed to bind. */
static word
find_free_port (word sock)
{
  word port = IPPORT_USERRESERVED;

  while (port < (1 << 16))
    {
      struct sockaddr_in sin;

      clib_memset (&sin, 0, sizeof (sin));	/* Warnings be gone */
      sin.sin_family = PF_INET;
      sin.sin_addr.s_addr = INADDR_ANY;
      sin.sin_port = htons (port);

      if (bind (sock, (struct sockaddr *) &sin, sizeof (sin)) >= 0)
        return port;

      port++;
    }

  return -1;
}
/* Convert a config string to a struct sockaddr and length for use
   with bind or connect.

   Accepted CONFIG forms:
     "/path"                   PF_LOCAL socket bound/connected at PATH
     "host", "host:port",      PF_INET socket; PORT may be decimal or
     "host:0xport"             0x-prefixed hex; HOST may be "localhost",
                               a dotted IPv4 address or a name resolved
                               with gethostbyname
     "" or NULL                PF_INET socket, port 0, default address

   NOTE(review): a bare number such as "1234" is taken as a host (it is
   accepted by inet_aton), not as a port.

   ADDR must point to storage large enough for a struct sockaddr_un (for
   "/path" configs) or a struct sockaddr_in (otherwise); *ADDR_LEN gets
   the size of the address structure written.  IP4_DEFAULT_ADDRESS, in
   host byte order, is used when no host name is given.

   Returns 0 on success or a clib error describing the failure. */
static clib_error_t *
socket_config (char *config,
               void *addr, socklen_t * addr_len, u32 ip4_default_address)
{
  clib_error_t *error = 0;

  if (!config)
    config = "";

  /* Anything that begins with a / is a local PF_LOCAL socket. */
  if (config[0] == '/')
    {
      struct sockaddr_un *su = addr;
      size_t path_len = strlen (config);

      /* Refuse paths that do not fit together with their terminating
         NUL, instead of silently using a truncated, unterminated path. */
      if (path_len >= sizeof (su->sun_path))
        {
          error = clib_error_return (0, "socket path too long `%s'", config);
          goto done;
        }

      su->sun_family = PF_LOCAL;
      clib_memcpy (&su->sun_path, config, path_len + 1);
      *addr_len = sizeof (su[0]);
    }

  /* Hostname or hostname:port. */
  else
    {
      char *host_name = 0;
      int port = -1;
      struct sockaddr_in *sa = addr;

      if (config[0] != 0)
        {
          unformat_input_t i;

          unformat_init_string (&i, config, strlen (config));

          /* Try the hex form first: "%d" would otherwise accept the
             leading "0" of "0x..." and yield port 0. */
          if (unformat (&i, "%s:0x%x", &host_name, &port))
            ;
          else
            {
              /* A failed attempt may already have stored a host name;
                 release it before the next attempt overwrites it. */
              vec_free (host_name);
              if (unformat (&i, "%s:%d", &host_name, &port))
                ;
              else
                {
                  vec_free (host_name);
                  if (!unformat (&i, "%s", &host_name))
                    error = clib_error_return (0, "unknown input `%U'",
                                               format_unformat_error, &i);
                }
            }
          unformat_free (&i);

          if (error)
            goto done;
        }

      sa->sin_family = PF_INET;
      *addr_len = sizeof (sa[0]);
      if (port != -1)
        sa->sin_port = htons (port);
      else
        sa->sin_port = 0;

      if (!host_name || host_name[0] == 0)
        {
          /* No host given: use the caller-supplied default address. */
          sa->sin_addr.s_addr = htonl (ip4_default_address);
        }
      else
        {
          struct in_addr host_addr;

          /* Recognize localhost to avoid host lookup in most common case. */
          if (!strcmp (host_name, "localhost"))
            sa->sin_addr.s_addr = htonl (INADDR_LOOPBACK);

          else if (inet_aton (host_name, &host_addr))
            sa->sin_addr = host_addr;

          else
            {
              struct hostent *host = gethostbyname (host_name);

              /* Only a 4-byte IPv4 result fits in sin_addr; copying any
                 other h_length would overrun it. */
              if (!host || host->h_addrtype != AF_INET
                  || host->h_length != (int) sizeof (sa->sin_addr.s_addr))
                error = clib_error_return (0, "unknown host `%s'", config);
              else
                clib_memcpy (&sa->sin_addr.s_addr, host->h_addr_list[0],
                             host->h_length);
            }
        }

      vec_free (host_name);
    }

done:
  return error;
}
/* Default write_func: hand the whole of s->tx_buffer to a single write()
   call on s->fd (fd 0 is redirected to fd 1).

   - Non-fatal write errors are treated as "nothing written".
   - A fatal error frees the buffer and is returned to the caller.
   - Bytes actually written are removed from the front of the buffer.
   - If nothing was written while more than 64 KiB is queued, the queued
     data is discarded. */
static clib_error_t *
default_socket_write (clib_socket_t * s)
{
  /* Typically fd is a socket for which read/write both work; stdin is
     mapped to stdout. */
  word fd = (s->fd == 0) ? 1 : s->fd;
  word tx_len = vec_len (s->tx_buffer);
  word written = write (fd, s->tx_buffer, tx_len);

  if (written < 0)
    {
      if (unix_error_is_fatal (errno))
        {
          clib_error_t *err;

          err = clib_error_return_unix (0, "write %wd bytes (fd %d, '%s')",
                                        tx_len, s->fd, s->config);
          vec_free (s->tx_buffer);
          return err;
        }
      /* Transient condition: behave as if nothing was written. */
      written = 0;
    }

  if (written > 0)
    {
      /* Drop the transmitted prefix; reset length when all of it went. */
      if (written == tx_len)
        _vec_len (s->tx_buffer) = 0;
      else
        vec_delete (s->tx_buffer, written, 0);
    }
  else if (tx_len > 64 * 1024)
    {
      /* Nothing went out and the backlog is large: give up on it. */
      vec_free (s->tx_buffer);
    }

  return 0;
}
/* Default read_func: append whatever a single read() on sock->fd returns
   (reading at most max (n_bytes, 4096) bytes) to sock->rx_buffer.

   When read returns 0 the CLIB_SOCKET_F_RX_END_OF_FILE flag is set, and
   every later call returns immediately.  Non-fatal read errors are
   ignored.  A fatal error is returned, and rx_buffer is left exactly as
   it was on entry. */
static clib_error_t *
default_socket_read (clib_socket_t * sock, int n_bytes)
{
  word fd, n_read;
  u8 *buf;

  /* RX side of socket is down once end of file is reached. */
  if (sock->flags & CLIB_SOCKET_F_RX_END_OF_FILE)
    return 0;

  fd = sock->fd;

  n_bytes = clib_max (n_bytes, 4096);
  /* Reserve room at the tail of rx_buffer; trimmed to what was read. */
  vec_add2 (sock->rx_buffer, buf, n_bytes);

  if ((n_read = read (fd, buf, n_bytes)) < 0)
    {
      clib_error_t *err;

      n_read = 0;

      /* Ignore certain errors. */
      if (!unix_error_is_fatal (errno))
        goto non_fatal;

      /* Build the error first so errno is captured, then drop the
         reserved-but-unfilled tail so the caller does not see garbage
         bytes appended to rx_buffer. */
      err = clib_error_return_unix (0, "read %d bytes (fd %d, '%s')",
                                    n_bytes, sock->fd, sock->config);
      _vec_len (sock->rx_buffer) -= n_bytes;
      return err;
    }

  /* Other side closed the socket. */
  if (n_read == 0)
    sock->flags |= CLIB_SOCKET_F_RX_END_OF_FILE;

non_fatal:
  _vec_len (sock->rx_buffer) += n_read - n_bytes;
  return 0;
}
/* Default close_func: close s->fd and turn a failing close() into a
   clib error; returns 0 on success. */
static clib_error_t *
default_socket_close (clib_socket_t * s)
{
  int rv = close (s->fd);

  if (rv >= 0)
    return 0;

  return clib_error_return_unix (0, "close (fd %d, %s)", s->fd, s->config);
}
/* Default sendmsg_func: send MSGLEN bytes from MSG on s->fd with one
   sendmsg() call.  When NUM_FDS > 0, the descriptors in FDS are attached
   as a single SCM_RIGHTS control message.  Only a negative sendmsg
   return is reported; a short send is not detected here. */
static clib_error_t *
default_socket_sendmsg (clib_socket_t * s, void *msg, int msglen,
                        int fds[], int num_fds)
{
  char ctl[CMSG_SPACE (sizeof (int) * num_fds)];
  struct iovec iov = {.iov_base = msg,.iov_len = msglen };
  struct msghdr mh = { 0 };

  mh.msg_iov = &iov;
  mh.msg_iovlen = 1;

  if (num_fds > 0)
    {
      size_t fds_size = sizeof (int) * num_fds;
      struct cmsghdr *cmsg;

      clib_memset (ctl, 0, sizeof (ctl));
      mh.msg_control = ctl;
      mh.msg_controllen = sizeof (ctl);

      cmsg = CMSG_FIRSTHDR (&mh);
      cmsg->cmsg_level = SOL_SOCKET;
      cmsg->cmsg_type = SCM_RIGHTS;
      cmsg->cmsg_len = CMSG_LEN (fds_size);
      memcpy (CMSG_DATA (cmsg), fds, fds_size);
    }

  if (sendmsg (s->fd, &mh, 0) < 0)
    return clib_error_return_unix (0, "sendmsg");

  return 0;
}
/* Default recvmsg_func: receive exactly MSGLEN bytes into MSG from s->fd
   with one recvmsg() call, then walk the control messages that arrive
   with it.

   - SCM_RIGHTS: the descriptors actually carried (at most NUM_FDS of
     them) are stored in FDS.  Entries for descriptors that were not
     received are left untouched.  Any surplus descriptors are closed.
   - SCM_CREDENTIALS (Linux only): the sender's credentials are stored in
     s->uid, s->gid and s->pid.

   A zero-length receive is reported as "disconnected".  Any other size
   mismatch, including a failed recvmsg, is reported as a malformed
   message. */
static clib_error_t *
default_socket_recvmsg (clib_socket_t * s, void *msg, int msglen,
                        int fds[], int num_fds)
{
#ifdef __linux__
  char ctl[CMSG_SPACE (sizeof (int) * num_fds) +
           CMSG_SPACE (sizeof (struct ucred))];
  struct ucred *cr = 0;
#else
  char ctl[CMSG_SPACE (sizeof (int) * num_fds)];
#endif
  struct msghdr mh = { 0 };
  struct iovec iov[1];
  ssize_t size;
  struct cmsghdr *cmsg;

  iov[0].iov_base = msg;
  iov[0].iov_len = msglen;
  mh.msg_iov = iov;
  mh.msg_iovlen = 1;
  mh.msg_control = ctl;
  mh.msg_controllen = sizeof (ctl);

  clib_memset (ctl, 0, sizeof (ctl));

  /* receive the incoming message */
  size = recvmsg (s->fd, &mh, 0);
  if (size != msglen)
    {
      return (size == 0) ? clib_error_return (0, "disconnected") :
        clib_error_return_unix (0, "recvmsg: malformed message (fd %d, '%s')",
                                s->fd, s->config);
    }

  cmsg = CMSG_FIRSTHDR (&mh);
  while (cmsg)
    {
      if (cmsg->cmsg_level == SOL_SOCKET)
        {
#ifdef __linux__
          if (cmsg->cmsg_type == SCM_CREDENTIALS)
            {
              cr = (struct ucred *) CMSG_DATA (cmsg);
              s->uid = cr->uid;
              s->gid = cr->gid;
              s->pid = cr->pid;
            }
          else
#endif
          if (cmsg->cmsg_type == SCM_RIGHTS)
            {
              /* Use the descriptor count this message really carries;
                 copying num_fds ints blindly would pass padding or the
                 next control header to the caller as descriptors. */
              int *rx_fds = (int *) CMSG_DATA (cmsg);
              int n_rx = (cmsg->cmsg_len - CMSG_LEN (0)) / sizeof (int);
              int n_copy = clib_min (n_rx, num_fds);
              int j;

              clib_memcpy_fast (fds, rx_fds, n_copy * sizeof (int));

              /* Descriptors beyond what the caller asked for were still
                 received into this process; close them so they do not
                 leak (ctl may have room for extra fds on Linux). */
              for (j = n_copy; j < n_rx; j++)
                close (rx_fds[j]);
            }
        }
      cmsg = CMSG_NXTHDR (&mh, cmsg);
    }

  return 0;
}
/* Fill in every I/O hook of S that the caller left NULL with the
   matching default_socket_* implementation; hooks that are already set
   are kept as they are. */
static void
socket_init_funcs (clib_socket_t * s)
{
  s->write_func = s->write_func ? s->write_func : default_socket_write;
  s->read_func = s->read_func ? s->read_func : default_socket_read;
  s->close_func = s->close_func ? s->close_func : default_socket_close;
  s->sendmsg_func =
    s->sendmsg_func ? s->sendmsg_func : default_socket_sendmsg;
  s->recvmsg_func =
    s->recvmsg_func ? s->recvmsg_func : default_socket_recvmsg;
}
clib_error_t *
clib_socket_init (clib_socket_t