#!/usr/bin/env python3

import unittest

from framework import VppTestCase
from asfframework import VppTestRunner

from scapy.layers.inet import IP, TCP
from scapy.layers.inet6 import IPv6
from scapy.layers.l2 import Ether
from scapy.packet import Raw


class TestMSSClamp(VppTestCase):
    """TCP MSS Clamping Test Case"""

    # Direction values passed to mss_clamp_enable_disable. The scenarios
    # below use 1 for "RX only", 2 for "TX only", 3 for both directions
    # and 0 for disabled.
    DIR_NONE = 0
    DIR_RX = 1
    DIR_TX = 2
    DIR_BOTH = DIR_RX | DIR_TX

    # Number of copies of the SYN sent per burst. The clamped-packet counter
    # check expects exactly this many clamps after the first clamping burst.
    N_PKTS = 65

    def setUp(self):
        super(TestMSSClamp, self).setUp()

        # create 2 pg interfaces
        self.create_pg_interfaces(range(2))

        for i in self.pg_interfaces:
            i.admin_up()
            i.config_ip4()
            i.resolve_arp()
            i.config_ip6()
            i.resolve_ndp()

    def tearDown(self):
        for i in self.pg_interfaces:
            i.unconfig_ip4()
            i.unconfig_ip6()
            i.admin_down()
        super(TestMSSClamp, self).tearDown()

    def verify_pkt(self, rx, expected_mss):
        """Check that rx carries the expected MSS and correct checksums.

        The first TCP option must be MSS with value expected_mss. The TCP
        checksum (and, for IPv4, the IP header checksum) as received must
        equal what scapy recomputes when the packet is rebuilt, i.e. the
        clamping node must have kept the checksums consistent.
        """
        tcp = rx[TCP]
        tcp_csum = tcp.chksum
        # Deleting the field makes scapy recompute it on re-serialisation.
        del tcp.chksum
        ip_csum = 0
        if rx.haslayer(IP):
            ip_csum = rx[IP].chksum
            del rx[IP].chksum

        opt = tcp.options
        self.assertEqual(opt[0][0], "MSS")
        self.assertEqual(opt[0][1], expected_mss)
        # recalculate checksums
        rx = rx.__class__(bytes(rx))
        tcp = rx[TCP]
        self.assertEqual(tcp_csum, tcp.chksum)
        if rx.haslayer(IP):
            self.assertEqual(ip_csum, rx[IP].chksum)

    def _send_and_verify(self, src_pg, dst_pg, l3, mss, expected_mss):
        """Send N_PKTS TCP SYNs with an MSS option and verify each copy.

        :param src_pg: interface the packets are injected on.
        :param dst_pg: interface the packets are expected to leave on.
        :param l3: IP or IPv6 header (host on src_pg -> host on dst_pg).
        :param mss: MSS value placed in the transmitted SYN.
        :param expected_mss: MSS value every received packet must carry.
        """
        p = (
            Ether(dst=src_pg.local_mac, src=src_pg.remote_mac)
            / l3
            / TCP(
                sport=1234,
                dport=1234,
                flags="S",
                options=[("MSS", mss), ("EOL", None)],
            )
            # bytes, not str: a str load would be UTF-8 encoded by scapy,
            # turning each "\xa5" into the two bytes c2 a5.
            / Raw(b"\xa5" * 100)
        )

        rxs = self.send_and_expect(src_pg, p * self.N_PKTS, dst_pg)

        for rx in rxs:
            self.verify_pkt(rx, expected_mss)

    def send_and_verify_ip4(self, src_pg, dst_pg, mss, expected_mss):
        """Send IPv4 TCP SYNs carrying `mss` and expect `expected_mss`."""
        l3 = IP(src=src_pg.remote_ip4, dst=dst_pg.remote_ip4)
        self._send_and_verify(src_pg, dst_pg, l3, mss, expected_mss)

    def send_and_verify_ip6(self, src_pg, dst_pg, mss, expected_mss):
        """Send IPv6 TCP SYNs carrying `mss` and expect `expected_mss`."""
        l3 = IPv6(src=src_pg.remote_ip6, dst=dst_pg.remote_ip6)
        self._send_and_verify(src_pg, dst_pg, l3, mss, expected_mss)

    def _set_clamp(self, is_ip6, mss, direction):
        """Configure MSS clamping on pg1 for a single address family.

        The other address family is always configured with mss 0 and
        DIR_NONE, which is what every scenario in this class uses.
        """
        if is_ip6:
            ip4_cfg, ip6_cfg = (0, self.DIR_NONE), (mss, direction)
        else:
            ip4_cfg, ip6_cfg = (mss, direction), (0, self.DIR_NONE)
        self.vapi.mss_clamp_enable_disable(
            self.pg1.sw_if_index,
            ipv4_mss=ip4_cfg[0],
            ipv6_mss=ip6_cfg[0],
            ipv4_direction=ip4_cfg[1],
            ipv6_direction=ip6_cfg[1],
        )

    def _check_clamping(self, is_ip6, direction):
        """Run the shared MSS clamping scenario on pg1.

        With direction == DIR_TX the SYNs flow pg0 -> pg1, so clamping
        happens on pg1 output; with DIR_RX they flow pg1 -> pg0, so it
        happens on pg1 input. Steps:
          1. clamp both directions to 1424 and read the config back;
          2. MSS 1460 is lowered to 1424 and the clamp counter advances;
          3. MSS 1400 (already below the limit) is left unchanged;
          4. clamp only `direction` at 1420: 1460 is lowered to 1420;
          5. clamp only the opposite direction: 1460 passes unchanged;
          6. disable clamping: 1460 passes unchanged.
        """
        if direction == self.DIR_TX:
            src_pg, dst_pg = self.pg0, self.pg1
            other_direction, node_suffix = self.DIR_RX, "out"
        else:
            src_pg, dst_pg = self.pg1, self.pg0
            other_direction, node_suffix = self.DIR_TX, "in"

        if is_ip6:
            send_and_verify, af = self.send_and_verify_ip6, "ip6"
        else:
            send_and_verify, af = self.send_and_verify_ip4, "ip4"

        # enable the TCP MSS clamping feature to lower the MSS to 1424.
        self._set_clamp(is_ip6, 1424, self.DIR_BOTH)

        # Verify that the feature is enabled.
        rv, reply = self.vapi.mss_clamp_get(sw_if_index=self.pg1.sw_if_index)
        cfg = reply[0]
        if is_ip6:
            self.assertEqual(cfg.ipv6_mss, 1424)
            self.assertEqual(cfg.ipv6_direction, self.DIR_BOTH)
        else:
            self.assertEqual(cfg.ipv4_mss, 1424)
            self.assertEqual(cfg.ipv4_direction, self.DIR_BOTH)

        # Send syn packets and verify that the MSS value is lowered.
        send_and_verify(src_pg, dst_pg, 1460, 1424)

        # check the stats
        stats = self.statistics.get_counter(
            f"/err/tcp-mss-clamping-{af}-{node_suffix}/clamped"
        )
        self.assertEqual(sum(stats), self.N_PKTS)

        # Send syn packets with small enough MSS values and verify they are
        # unchanged.
        send_and_verify(src_pg, dst_pg, 1400, 1400)

        # enable the feature only in the direction under test
        # and change the max MSS value
        self._set_clamp(is_ip6, 1420, direction)

        # Send syn packets and verify that the MSS value is lowered.
        send_and_verify(src_pg, dst_pg, 1460, 1420)

        # enable the feature only in the opposite direction
        self._set_clamp(is_ip6, 1424, other_direction)

        # Send the packets again and ensure they are unchanged.
        send_and_verify(src_pg, dst_pg, 1460, 1460)

        # disable the feature
        self._set_clamp(is_ip6, 0, self.DIR_NONE)

        # Send the packets again and ensure they are unchanged.
        send_and_verify(src_pg, dst_pg, 1460, 1460)

    def test_tcp_mss_clamping_ip4_tx(self):
        """IP4 TCP MSS Clamping TX"""
        self._check_clamping(is_ip6=False, direction=self.DIR_TX)

    def test_tcp_mss_clamping_ip4_rx(self):
        """IP4 TCP MSS Clamping RX"""
        self._check_clamping(is_ip6=False, direction=self.DIR_RX)

    def test_tcp_mss_clamping_ip6_tx(self):
        """IP6 TCP MSS Clamping TX"""
        self._check_clamping(is_ip6=True, direction=self.DIR_TX)

    def test_tcp_mss_clamping_ip6_rx(self):
        """IP6 TCP MSS Clamping RX"""
        self._check_clamping(is_ip6=True, direction=self.DIR_RX)


if __name__ == "__main__":
    # Run this module's tests through the VPP-specific unittest runner.
    unittest.main(testRunner=VppTestRunner)