summaryrefslogtreecommitdiffstats
path: root/src/plugins/dns/test/test_dns.py
blob: 0f878831b0772f1372398cdd7822b4035ccd0c6c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python

import unittest

from framework import VppTestCase, VppTestRunner
from vpp_ip_route import VppIpTable, VppIpRoute, VppRoutePath
from vpp_ip import VppIpPrefix
from ipaddress import *

import scapy.compat
from scapy.contrib.mpls import MPLS
from scapy.layers.inet import IP, UDP, TCP, ICMP, icmptypes, icmpcodes
from scapy.layers.l2 import Ether
from scapy.packet import Raw
from scapy.layers.dns import DNSRR, DNS, DNSQR


class TestDns(VppTestCase):
    """ Dns Test Cases """

    # Exercises the DNS name resolver three ways: through the binary API,
    # through the CLI, and by injecting DNS request packets on a single
    # packet-generator (pg) interface.

    @classmethod
    def setUpClass(cls):
        super(TestDns, cls).setUpClass()

    @classmethod
    def tearDownClass(cls):
        super(TestDns, cls).tearDownClass()

    def setUp(self):
        """Create one pg interface, bring it up and configure IPv4 on it."""
        super(TestDns, self).setUp()

        self.create_pg_interfaces(range(1))

        for pg_if in self.pg_interfaces:
            pg_if.admin_up()
            pg_if.config_ip4()
            pg_if.resolve_arp()

    def tearDown(self):
        super(TestDns, self).tearDown()

    def create_stream(self, src_if):
        """Create input packet stream for defined interface.

        Two DNS queries are built: one for "bozo.clown.org" and one for
        "no.clown.org". Both are UDP datagrams from port 1234 to port 53
        with the recursion-desired bit set.

        :param VppInterface src_if: Interface to create packet stream for.
        :returns: List of the two request packets, in the order above.
        """
        def dns_request(qname):
            # No IP destination is given, so scapy's default applies.
            # NOTE(review): presumably this is why the test configures
            # 127.0.0.1/8 on the interface -- confirm.
            return (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
                    IP(src=src_if.remote_ip4) /
                    UDP(sport=1234, dport=53) /
                    DNS(rd=1, qd=DNSQR(qname=qname)))

        return [dns_request("bozo.clown.org"), dns_request("no.clown.org")]

    def verify_capture(self, dst_if, capture):
        """Verify captured input packet stream for defined interface.

        Every captured packet must carry a DNS layer whose first answer
        record has rdata '1.2.3.4'.

        :param VppInterface dst_if: Interface to verify captured packet stream
            for.
        :param list capture: Captured packet stream.
        """
        self.logger.info("Verifying capture on interface %s" % dst_if.name)
        for packet in capture:
            dns = packet[DNS]
            self.assertEqual(dns.an[0].rdata, '1.2.3.4')

    def test_dns_unittest(self):
        """ DNS Name Resolver Basic Functional Test """

        # Set up an upstream name resolver. We won't actually go there
        self.vapi.dns_name_server_add_del(
            is_ip6=0, is_add=1, server_address=IPv4Address(u'8.8.8.8').packed)

        # Enable name resolution
        self.vapi.dns_enable_disable(enable=1)

        # Manually add a static dns cache entry
        self.logger.info(self.vapi.cli("dns cache add bozo.clown.org 1.2.3.4"))

        # Test the binary API
        rv = self.vapi.dns_resolve_name(name=b'bozo.clown.org')
        self.assertEqual(rv.ip4_address, IPv4Address(u'1.2.3.4').packed)

        # Configure 127.0.0.1/8 on the pg interface
        self.vapi.sw_interface_add_del_address(
            sw_if_index=self.pg0.sw_if_index,
            prefix=VppIpPrefix("127.0.0.1", 8).encode())

        # Send a couple of DNS request packets, one for bozo.clown.org
        # and one for no.clown.org which won't resolve

        requests = self.create_stream(self.pg0)
        self.pg0.add_stream(requests)
        self.pg_enable_capture(self.pg_interfaces)

        self.pg_start()
        # Two requests went out but only a single reply is expected back:
        # the name that does not resolve produces no answer packet.
        replies = self.pg0.get_capture(1)
        self.verify_capture(self.pg0, replies)

        # Make sure that the cache contents are correct
        # (named cache_dump rather than str to avoid shadowing the builtin)
        cache_dump = self.vapi.cli("show dns cache verbose")
        self.assertIn('1.2.3.4', cache_dump)
        # NOTE(review): "[P]" presumably marks a pending (unresolved)
        # cache entry -- confirm against the DNS plugin's CLI output.
        self.assertIn('[P] no.clown.org:', cache_dump)

# Allow running this test module directly as a script; the VPP-specific
# runner is handed to unittest in place of its default one.
if __name__ == "__main__":
    unittest.main(
        testRunner=VppTestRunner,
    )