diff options
Diffstat (limited to 'extras/packetforge/ProtocolHeader.py')
-rw-r--r-- | extras/packetforge/ProtocolHeader.py | 387 |
1 files changed, 387 insertions, 0 deletions
diff --git a/extras/packetforge/ProtocolHeader.py b/extras/packetforge/ProtocolHeader.py new file mode 100644 index 00000000000..272b6557460 --- /dev/null +++ b/extras/packetforge/ProtocolHeader.py @@ -0,0 +1,387 @@ +# Copyright (c) 2022 Intel and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ProtocolHeaderAttribute import * +from ProtocolHeaderField import * +from InputFormat import * +import ExpressionConverter +import copy + + +class ProtocolHeader: + def __init__(self, node): + self.fields = [] + self.attributes = [] + self.fieldDict = {} + self.attributeDict = {} + self.Buffer = [] + self.Mask = [] + + self.node = node + for field in self.node.fields: + phf = ProtocolHeaderField(field.Size, field.DefaultValue, None, field) + self.fields.append(phf) + if field.Name != "reserved": + self.fieldDict[field.Name] = phf + + for attr in self.node.attributes: + pha = ProtocolHeaderAttribute(attr.Size, attr.DefaultValue, attr) + self.attributes.append(pha) + self.attributeDict[attr.Name] = pha + + def Name(self): + return self.node.Name + + def Fields(self): + return self.fields + + def Attributes(self): + return self.attributes + + def setField(self, name, expression, auto): + if name == "reserved": + return False + + if name not in self.fieldDict: + return False + + field = self.fieldDict[name] + + if field.UpdateValue(expression, auto): + field.UpdateSize() + return True + + return False + + def SetField(self, name, expression): + return self.setField(name, expression, False) + + def SetFieldAuto(self, name, expression): + return self.setField(name, expression, True) + + def SetAttribute(self, name, expression): + if name not in self.attributeDict: + return False + attr = self.attributeDict[name] + + return attr.UpdateValue(expression) + + def SetMask(self, name, expression): + if name not in self.fieldDict: + return False + field = self.fieldDict[name] + + return field.UpdateMask(expression) + + def resolveOptional(self, condition): + if condition == None: + return True + + tokens = condition.split("|") + + if len(tokens) > 1: + result = False + + for token in tokens: + result |= self.resolveOptional(token) + + return result + + tokens = condition.split("&") + + if len(tokens) > 1: + result = True + + for token in tokens: + result &= self.resolveOptional(token) + + return result + + key = None + value = None + + if "!=" in tokens[0]: + index = tokens[0].find("!=") + key = tokens[0][:index].strip() + value = tokens[0][index + 1 :].strip() + elif "=" in tokens[0]: + index = tokens[0].find("=") + key = tokens[0][:index].strip() + value = tokens[0][index + 1 :].strip() + else: + return False + + if key not in self.fieldDict: + return False + + f = self.fieldDict[key] + return ExpressionConverter.Equal(f.Value, value) + + def resolveSize(self, exp): + shift = 0 + key = exp + + if "<<" in exp: + offset = exp.find("<<") + key = exp[0:offset].strip() + shift = int(exp[offset + 2 :].strip()) + + if self.fieldDict.has_key(key): + field = self.fieldDict[key] + _, u16 = ExpressionConverter.ToNum(field.Value) + if u16: + return u16 << shift + else: + return 0 + + if self.attributeDict.has_key(key): + attr = self.attributeDict[key] + _, u16 = ExpressionConverter.ToNum(attr.Value) + if u16: + return u16 << shift + else: + return 0 + + return 0 + + def Adjust(self): + autoIncreases = [] + increaseHeaders = [] + + self.resolveAllSize() + + for phf in self.fields: + if phf.Field.IsAutoIncrease: + autoIncreases.append(phf) + if phf.Field.IsIncreaseLength and self.resolveOptional(phf.Field.Optional): + increaseHeaders.append(phf) + + for f1 in autoIncreases: + for f2 in increaseHeaders: + f1.UpdateValue( + ExpressionConverter.IncreaseValue(f1.Value, f2.Size >> 3), True + ) + + def resolveAllSize(self): + for phf in self.fields: + if phf.Field.Optional != None and not self.resolveOptional( + phf.Field.Optional + ): + size = 0 + else: + if phf.Field.VariableSize != None: + size = self.resolveSize(phf.Field.VariableSize) + else: + size = phf.Field.Size + phf.Size = size + + def GetSize(self): + size = 0 + + for field in self.fields: + size += field.Size + + return size >> 3 + + def AppendAuto(self, size): + for phf in self.fields: + if not phf.Field.IsAutoIncrease: + continue + + phf.UpdateValue(ExpressionConverter.IncreaseValue(phf.Value, size), True) + + def getField(self, name): + if not self.fieldDict.has_key(name): + return None + field = self.fieldDict[name] + + return field.Value + + def getAttribute(self, name): + if not self.attributeDict.has_key(name): + return None + + return self.attributeDict[name].Value + + def GetValue(self, name): + result = self.getField(name) + + if result == None: + return self.getAttribute(name) + + return result + + def appendNum(self, big, exp, size): + num = 0 + if exp != None: + _, num = ExpressionConverter.ToNum(exp) + if num == None: + print("Invalid byte expression") + return None + + # cut msb + num = num & ((1 << size) - 1) + big = big << size + big = big | num + return big + + def appendUInt64(self, big, exp, size): + u64 = 0 + if exp != None: + _, u64 = ExpressionConverter.ToNum(exp) + if not u64: + print("Invalid UInt32 expression") + return False + + # cut msb + if size < 64: + u64 = u64 & ((1 << size) - 1) + big = big << size + big = big | u64 + return big + + def appendIPv4(self, big, exp): + ipv4 = bytes(4) + if exp != None: + _, ipv4 = ExpressionConverter.ToIPv4Address(exp) + if not ipv4: + print("Inavalid IPv4 Address") + return False + + for i in range(len(ipv4)): + big = big << 8 + big = big | ipv4[i] + + return big + + def appendIPv6(self, big, exp): + ipv6 = bytes(16) + if exp != None: + _, ipv6 = ExpressionConverter.ToIPv6Address(exp) + if not ipv6: + print("Inavalid IPv6 Address") + return False + + for i in range(16): + big = big << 8 + big = big | ipv6[i] + + return big + + def appendMAC(self, big, exp): + mac = bytes(6) + if exp != None: + _, mac = ExpressionConverter.ToMacAddress(exp) + if not mac: + print("Inavalid MAC Address") + return False + + for i in range(6): + big = big << 8 + big = big | mac[i] + + return big + + def appendByteArray(self, big, exp, size): + array = bytes(size >> 3) + if exp != None: + _, array = ExpressionConverter.ToByteArray(exp) + if not array: + print("Invalid byte array") + return False + + for i in range(size >> 3): + big = big << 8 + if i < len(array): + big = big | array[i] + + return big + + def append(self, big, phf): + bigVal = big["bigVal"] + bigMsk = big["bigMsk"] + + if phf.Field.IsReserved: + bigVal <<= phf.Size + bigMsk <<= phf.Size + big.update(bigVal=bigVal, bigMsk=bigMsk) + return big, phf.Size + + size = phf.Size + + if ( + phf.Field.Format == InputFormat.u8 + or phf.Field.Format == InputFormat.u16 + or phf.Field.Format == InputFormat.u32 + ): + bigVal = self.appendNum(bigVal, phf.Value, size) + bigMsk = self.appendNum(bigMsk, phf.Mask, size) + + elif phf.Field.Format == InputFormat.u64: + bigVal = self.appendUInt64(bigVal, phf.Value, size) + bigMsk = self.appendUInt64(bigMsk, phf.Mask, size) + + elif phf.Field.Format == InputFormat.ipv4: + bigVal = self.appendIPv4(bigVal, phf.Value) + bigMsk = self.appendIPv4(bigMsk, phf.Mask) + + elif phf.Field.Format == InputFormat.ipv6: + bigVal = self.appendIPv6(bigVal, phf.Value) + bigMsk = self.appendIPv6(bigMsk, phf.Mask) + + elif phf.Field.Format == InputFormat.mac: + bigVal = self.appendMAC(bigVal, phf.Value) + bigMsk = self.appendMAC(bigMsk, phf.Mask) + + elif phf.Field.Format == InputFormat.bytearray: + bigVal = self.appendByteArray(bigVal, phf.Value, size) + bigMsk = self.appendByteArray(bigMsk, phf.Mask, size) + + else: + print("Invalid input format") + + big.update(bigVal=bigVal, bigMsk=bigMsk) + return big, size + + def Resolve(self): + big = {"bigVal": 0, "bigMsk": 0} + offset = 0 + + for phf in self.fields: + if phf.Size == 0: + continue + + big, bits = self.append(big, phf) + + offset += bits + + byteList1 = [] + byteList2 = [] + + bigVal = big["bigVal"] + bigMsk = big["bigMsk"] + + while offset > 0: + byteList1.append(bigVal & 0xFF) + byteList2.append(bigMsk & 0xFF) + bigVal = bigVal >> 8 + bigMsk = bigMsk >> 8 + offset -= 8 + + byteList1.reverse() + byteList2.reverse() + buffer = copy.deepcopy(byteList1) + mask = copy.deepcopy(byteList2) + + self.Buffer = buffer + self.Mask = mask |