aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_node_variants.py
blob: a5c9137f9bafac4de9646103e40013c4df20d5b8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python3
import re
import unittest
import platform
from framework import VppTestCase


def checkX86():
    """Report whether this host is a 64-bit x86 machine.

    Returns:
        bool: True when ``platform.machine()`` is ``"x86_64"`` or
        ``"AMD64"`` (the two spellings used for that architecture),
        False for anything else.
    """
    machine = platform.machine()
    return machine == "x86_64" or machine == "AMD64"


def skipVariant(variant, cpuinfo_path="/proc/cpuinfo"):
    """Return True when the host CPU advertises the *variant* feature flag.

    Despite its name, this is the predicate passed to ``unittest.skipUnless``:
    a True result means the test should RUN.

    The check succeeds only on x86 hosts (see ``checkX86``) whose cpuinfo
    lists *variant* verbatim in a ``flags`` line.  *variant* is compared as a
    literal flag name, never as a regular expression, and only whole flags
    match.  For example, ``avx512f`` does not match ``avx512fp16``.

    :param variant: cpuinfo flag name, e.g. ``"avx2"``.
    :param cpuinfo_path: file to read; defaults to ``/proc/cpuinfo``.
    :returns: bool
    """
    try:
        with open(cpuinfo_path) as f:
            cpuinfo = f.read()
    except OSError:
        # This runs at class-definition time via decorators.  A host without
        # a readable cpuinfo should skip the test, not break collection.
        return False

    # Collect every token from lines whose key is exactly "flags".  Lines
    # such as "vmx flags" are deliberately excluded.  A set gives whole-word
    # matches, including the last flag on a line.
    flags = set()
    for line in cpuinfo.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "flags":
            flags.update(value.split())

    return variant in flags and checkX86()


class TestNodeVariant(VppTestCase):
    """Test Node Variants

    Base class for per-variant test cases.  Subclasses override
    ``setUpConstants`` to choose the variant ``ip4-rewrite`` is configured
    with, then call ``checkVariant`` from their test method.
    """

    @classmethod
    def setUpConstants(cls, variant):
        """Build the VPP command line, selecting *variant* for ip4-rewrite.

        On x86 hosts, this replaces the argument that follows the ``"node { "``
        token in ``cls.vpp_cmdline``.  The replacement configures
        ``default { variant default }`` and ``ip4-rewrite { variant <variant> }``.
        On other hosts, the command line from the parent class is left as is.

        :param variant: node-variant name to select for ``ip4-rewrite``.
        :raises ValueError: if ``"node { "`` is not in the command line.
        """
        super(TestNodeVariant, cls).setUpConstants()

        if checkX86():
            # The element right after the "node { " token holds the node
            # stanza's contents; overwrite it with our variant selection.
            node_variants = cls.vpp_cmdline.index("node { ") + 1
            cls.vpp_cmdline[node_variants] = (
                "default { variant default } "
                "ip4-rewrite { variant " + variant + " } "
            )

    def getActiveVariant(self, node):
        """Return ``(name, priority, "yes")`` for the active variant of *node*.

        Runs ``show node <node>`` and takes the first row of the form
        ``<name> <priority> yes``.  Element 0 of the result is the variant name.
        If no such row exists, the test fails with the CLI output attached.
        Previously this raised an AttributeError on ``None``.
        """
        node_desc = self.vapi.cli("show node " + node)
        self.logger.info(node_desc)

        # Groups: variant name, priority, and the "yes" marking the active
        # row.  The original "(:?yes)" was a typo for a plain "yes" match.
        # It is kept as a capturing group so the 3-tuple return is unchanged.
        match = re.search(r"\s+(\S+)\s+(\d+)\s+(yes)", node_desc)
        if match is None:
            self.fail(f"no active variant found for node {node} in:\n{node_desc}")

        return match.groups()

    def checkVariant(self, variant):
        """Assert ip4-lookup runs "default" and ip4-rewrite runs *variant*.

        This matches the configuration written by ``setUpConstants``.
        """
        variant_info = self.getActiveVariant("ip4-lookup")
        self.assertEqual(variant_info[0], "default")

        variant_info = self.getActiveVariant("ip4-rewrite")
        self.assertEqual(variant_info[0], variant)


class TestICLVariant(TestNodeVariant):
    """Test icl Node Variants

    Configures ip4-rewrite with the ``icl`` variant.  The test runs only when
    ``skipVariant`` reports the ``avx512_bitalg`` CPU flag (x86 hosts only).
    """

    # VPP node-variant name under test.
    VARIANT = "icl"
    # cpuinfo flag that must be present for the test to run.
    LINUX_VARIANT = "avx512_bitalg"

    @classmethod
    def setUpConstants(cls):
        super(TestICLVariant, cls).setUpConstants(cls.VARIANT)

    @unittest.skipUnless(
        skipVariant(LINUX_VARIANT), VARIANT + " not a supported variant, skip."
    )
    def test_icl(self):
        """Check that ip4-rewrite runs the icl variant."""
        self.checkVariant(self.VARIANT)


class TestSKXVariant(TestNodeVariant):
    """Test skx Node Variants

    Configures ip4-rewrite with the ``skx`` variant.  The test runs only when
    ``skipVariant`` reports the ``avx512f`` CPU flag (x86 hosts only).
    """

    # VPP node-variant name under test.
    VARIANT = "skx"
    # cpuinfo flag that must be present for the test to run.
    LINUX_VARIANT = "avx512f"

    @classmethod
    def setUpConstants(cls):
        super(TestSKXVariant, cls).setUpConstants(cls.VARIANT)

    @unittest.skipUnless(
        skipVariant(LINUX_VARIANT), VARIANT + " not a supported variant, skip."
    )
    def test_skx(self):
        """Check that ip4-rewrite runs the skx variant."""
        self.checkVariant(self.VARIANT)


class TestHSWVariant(TestNodeVariant):
    """Test hsw Node Variants

    Configures ip4-rewrite with the ``hsw`` variant.  The test runs only when
    ``skipVariant`` reports the ``avx2`` CPU flag (x86 hosts only).
    """

    # VPP node-variant name under test.
    VARIANT = "hsw"
    # cpuinfo flag that must be present for the test to run.
    LINUX_VARIANT = "avx2"

    @classmethod
    def setUpConstants(cls):
        super(TestHSWVariant, cls).setUpConstants(cls.VARIANT)

    @unittest.skipUnless(
        skipVariant(LINUX_VARIANT), VARIANT + " not a supported variant, skip."
    )
    def test_hsw(self):
        """Check that ip4-rewrite runs the hsw variant."""
        self.checkVariant(self.VARIANT)