aboutsummaryrefslogtreecommitdiffstats
path: root/extras/packetforge/ProtocolHeader.py
diff options
context:
space:
mode:
Diffstat (limited to 'extras/packetforge/ProtocolHeader.py')
-rw-r--r--extras/packetforge/ProtocolHeader.py387
1 files changed, 387 insertions, 0 deletions
diff --git a/extras/packetforge/ProtocolHeader.py b/extras/packetforge/ProtocolHeader.py
new file mode 100644
index 00000000000..398a52d3455
--- /dev/null
+++ b/extras/packetforge/ProtocolHeader.py
@@ -0,0 +1,387 @@
+# Copyright (c) 2022 Intel and/or its affiliates.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ProtocolHeaderAttribute import *
+from ProtocolHeaderField import *
+from InputFormat import *
+import ExpressionConverter
+import copy
+
+
+class ProtocolHeader:
+ def __init__(self, node):
+ self.fields = []
+ self.attributes = []
+ self.fieldDict = {}
+ self.attributeDict = {}
+ self.Buffer = []
+ self.Mask = []
+
+ self.node = node
+ for field in self.node.fields:
+ phf = ProtocolHeaderField(field.Size, field.DefaultValue, None, field)
+ self.fields.append(phf)
+ if field.Name != "reserved":
+ self.fieldDict[field.Name] = phf
+
+ for attr in self.node.attributes:
+ pha = ProtocolHeaderAttribute(attr.Size, attr.DefaultValue, attr)
+ self.attributes.append(pha)
+ self.attributeDict[attr.Name] = pha
+
+ def Name(self):
+ return self.node.Name
+
+ def Fields(self):
+ return self.fields
+
+ def Attributes(self):
+ return self.attributes
+
+ def setField(self, name, expression, auto):
+ if name == "reserved":
+ return False
+
+ if name not in self.fieldDict:
+ return False
+
+ field = self.fieldDict[name]
+
+ if field.UpdateValue(expression, auto):
+ field.UpdateSize()
+ return True
+
+ return False
+
+ def SetField(self, name, expression):
+ return self.setField(name, expression, False)
+
+ def SetFieldAuto(self, name, expression):
+ return self.setField(name, expression, True)
+
+ def SetAttribute(self, name, expression):
+ if name not in self.attributeDict:
+ return False
+ attr = self.attributeDict[name]
+
+ return attr.UpdateValue(expression)
+
+ def SetMask(self, name, expression):
+ if name not in self.fieldDict:
+ return False
+ field = self.fieldDict[name]
+
+ return field.UpdateMask(expression)
+
+ def resolveOptional(self, condition):
+ if condition == None:
+ return True
+
+ tokens = condition.split("|")
+
+ if len(tokens) > 1:
+ result = False
+
+ for token in tokens:
+ result |= self.resolveOptional(token)
+
+ return result
+
+ tokens = condition.split("&")
+
+ if len(tokens) > 1:
+ result = True
+
+ for token in tokens:
+ result &= self.resolveOptional(token)
+
+ return result
+
+ key = None
+ value = None
+
+ if "!=" in tokens[0]:
+ index = tokens[0].find("!=")
+ key = tokens[0][:index].strip()
+ value = tokens[0][index + 1 :].strip()
+ elif "=" in tokens[0]:
+ index = tokens[0].find("=")
+ key = tokens[0][:index].strip()
+ value = tokens[0][index + 1 :].strip()
+ else:
+ return False
+
+ if key not in self.fieldDict:
+ return False
+
+ f = self.fieldDict[key]
+ return ExpressionConverter.Equal(f.Value, value)
+
+ def resolveSize(self, exp):
+ shift = 0
+ key = exp
+
+ if "<<" in exp:
+ offset = exp.find("<<")
+ key = exp[0:offset].strip()
+ shift = int(exp[offset + 2 :].strip())
+
+ if key in self.fieldDict:
+ field = self.fieldDict[key]
+ _, u16 = ExpressionConverter.ToNum(field.Value)
+ if u16:
+ return u16 << shift
+ else:
+ return 0
+
+ if key in self.attributeDict:
+ attr = self.attributeDict[key]
+ _, u16 = ExpressionConverter.ToNum(attr.Value)
+ if u16:
+ return u16 << shift
+ else:
+ return 0
+
+ return 0
+
+ def Adjust(self):
+ autoIncreases = []
+ increaseHeaders = []
+
+ self.resolveAllSize()
+
+ for phf in self.fields:
+ if phf.Field.IsAutoIncrease:
+ autoIncreases.append(phf)
+ if phf.Field.IsIncreaseLength and self.resolveOptional(phf.Field.Optional):
+ increaseHeaders.append(phf)
+
+ for f1 in autoIncreases:
+ for f2 in increaseHeaders:
+ f1.UpdateValue(
+ ExpressionConverter.IncreaseValue(f1.Value, f2.Size >> 3), True
+ )
+
+ def resolveAllSize(self):
+ for phf in self.fields:
+ if phf.Field.Optional != None and not self.resolveOptional(
+ phf.Field.Optional
+ ):
+ size = 0
+ else:
+ if phf.Field.VariableSize != None:
+ size = self.resolveSize(phf.Field.VariableSize)
+ else:
+ size = phf.Field.Size
+ phf.Size = size
+
+ def GetSize(self):
+ size = 0
+
+ for field in self.fields:
+ size += field.Size
+
+ return size >> 3
+
+ def AppendAuto(self, size):
+ for phf in self.fields:
+ if not phf.Field.IsAutoIncrease:
+ continue
+
+ phf.UpdateValue(ExpressionConverter.IncreaseValue(phf.Value, size), True)
+
+ def getField(self, name):
+ if name not in self.fieldDict:
+ return None
+ field = self.fieldDict[name]
+
+ return field.Value
+
+ def getAttribute(self, name):
+ if name not in self.attributeDict:
+ return None
+
+ return self.attributeDict[name].Value
+
+ def GetValue(self, name):
+ result = self.getField(name)
+
+ if result == None:
+ return self.getAttribute(name)
+
+ return result
+
+ def appendNum(self, big, exp, size):
+ num = 0
+ if exp != None:
+ _, num = ExpressionConverter.ToNum(exp)
+ if num == None:
+ print("Invalid byte expression")
+ return None
+
+ # cut msb
+ num = num & ((1 << size) - 1)
+ big = big << size
+ big = big | num
+ return big
+
+ def appendUInt64(self, big, exp, size):
+ u64 = 0
+ if exp != None:
+ _, u64 = ExpressionConverter.ToNum(exp)
+ if not u64:
+ print("Invalid UInt32 expression")
+ return False
+
+ # cut msb
+ if size < 64:
+ u64 = u64 & ((1 << size) - 1)
+ big = big << size
+ big = big | u64
+ return big
+
+ def appendIPv4(self, big, exp):
+ ipv4 = bytes(4)
+ if exp != None:
+ _, ipv4 = ExpressionConverter.ToIPv4Address(exp)
+ if not ipv4:
+ print("Inavalid IPv4 Address")
+ return False
+
+ for i in range(len(ipv4)):
+ big = big << 8
+ big = big | ipv4[i]
+
+ return big
+
+ def appendIPv6(self, big, exp):
+ ipv6 = bytes(16)
+ if exp != None:
+ _, ipv6 = ExpressionConverter.ToIPv6Address(exp)
+ if not ipv6:
+ print("Inavalid IPv6 Address")
+ return False
+
+ for i in range(16):
+ big = big << 8
+ big = big | ipv6[i]
+
+ return big
+
+ def appendMAC(self, big, exp):
+ mac = bytes(6)
+ if exp != None:
+ _, mac = ExpressionConverter.ToMacAddress(exp)
+ if not mac:
+ print("Inavalid MAC Address")
+ return False
+
+ for i in range(6):
+ big = big << 8
+ big = big | mac[i]
+
+ return big
+
+ def appendByteArray(self, big, exp, size):
+ array = bytes(size >> 3)
+ if exp != None:
+ _, array = ExpressionConverter.ToByteArray(exp)
+ if not array:
+ print("Invalid byte array")
+ return False
+
+ for i in range(size >> 3):
+ big = big << 8
+ if i < len(array):
+ big = big | array[i]
+
+ return big
+
+ def append(self, big, phf):
+ bigVal = big["bigVal"]
+ bigMsk = big["bigMsk"]
+
+ if phf.Field.IsReserved:
+ bigVal <<= phf.Size
+ bigMsk <<= phf.Size
+ big.update(bigVal=bigVal, bigMsk=bigMsk)
+ return big, phf.Size
+
+ size = phf.Size
+
+ if (
+ phf.Field.Format == InputFormat.u8
+ or phf.Field.Format == InputFormat.u16
+ or phf.Field.Format == InputFormat.u32
+ ):
+ bigVal = self.appendNum(bigVal, phf.Value, size)
+ bigMsk = self.appendNum(bigMsk, phf.Mask, size)
+
+ elif phf.Field.Format == InputFormat.u64:
+ bigVal = self.appendUInt64(bigVal, phf.Value, size)
+ bigMsk = self.appendUInt64(bigMsk, phf.Mask, size)
+
+ elif phf.Field.Format == InputFormat.ipv4:
+ bigVal = self.appendIPv4(bigVal, phf.Value)
+ bigMsk = self.appendIPv4(bigMsk, phf.Mask)
+
+ elif phf.Field.Format == InputFormat.ipv6:
+ bigVal = self.appendIPv6(bigVal, phf.Value)
+ bigMsk = self.appendIPv6(bigMsk, phf.Mask)
+
+ elif phf.Field.Format == InputFormat.mac:
+ bigVal = self.appendMAC(bigVal, phf.Value)
+ bigMsk = self.appendMAC(bigMsk, phf.Mask)
+
+ elif phf.Field.Format == InputFormat.bytearray:
+ bigVal = self.appendByteArray(bigVal, phf.Value, size)
+ bigMsk = self.appendByteArray(bigMsk, phf.Mask, size)
+
+ else:
+ print("Invalid input format")
+
+ big.update(bigVal=bigVal, bigMsk=bigMsk)
+ return big, size
+
+ def Resolve(self):
+ big = {"bigVal": 0, "bigMsk": 0}
+ offset = 0
+
+ for phf in self.fields:
+ if phf.Size == 0:
+ continue
+
+ big, bits = self.append(big, phf)
+
+ offset += bits
+
+ byteList1 = []
+ byteList2 = []
+
+ bigVal = big["bigVal"]
+ bigMsk = big["bigMsk"]
+
+ while offset > 0:
+ byteList1.append(bigVal & 0xFF)
+ byteList2.append(bigMsk & 0xFF)
+ bigVal = bigVal >> 8
+ bigMsk = bigMsk >> 8
+ offset -= 8
+
+ byteList1.reverse()
+ byteList2.reverse()
+ buffer = copy.deepcopy(byteList1)
+ mask = copy.deepcopy(byteList2)
+
+ self.Buffer = buffer
+ self.Mask = mask