diff options
-rw-r--r-- | src/vnet/ipsec/esp_encrypt.c | 21 | ||||
-rw-r--r-- | test/template_ipsec.py | 4 |
2 files changed, 18 insertions, 7 deletions
diff --git a/src/vnet/ipsec/esp_encrypt.c b/src/vnet/ipsec/esp_encrypt.c index f1153d92e8c..5db10b520e3 100644 --- a/src/vnet/ipsec/esp_encrypt.c +++ b/src/vnet/ipsec/esp_encrypt.c @@ -112,19 +112,26 @@ esp_add_footer_and_icv (vlib_buffer_t * b, u8 block_size, u8 icv_sz) static_always_inline void esp_update_ip4_hdr (ip4_header_t * ip4, u16 len, int is_transport, int is_udp) { - ip_csum_t sum = ip4->checksum; - u16 old_len = 0; + ip_csum_t sum; + u16 old_len; + + len = clib_net_to_host_u16 (len); + old_len = ip4->length; if (is_transport) { u8 prot = is_udp ? IP_PROTOCOL_UDP : IP_PROTOCOL_IPSEC_ESP; - old_len = ip4->length; - sum = ip_csum_update (sum, ip4->protocol, prot, ip4_header_t, protocol); + + sum = ip_csum_update (ip4->checksum, ip4->protocol, + prot, ip4_header_t, protocol); ip4->protocol = prot; + + sum = ip_csum_update (sum, old_len, len, ip4_header_t, length); } + else + sum = ip_csum_update (ip4->checksum, old_len, len, ip4_header_t, length); - ip4->length = len = clib_net_to_host_u16 (len); - sum = ip_csum_update (ip4->checksum, old_len, len, ip4_header_t, length); + ip4->length = len; ip4->checksum = ip_csum_fold (sum); } @@ -411,7 +418,7 @@ esp_encrypt_inline (vlib_main_t * vm, vlib_node_runtime_t * node, u16 len; ip4_header_t *ip4 = (ip4_header_t *) (ip_hdr); *next_hdr_ptr = ip4->protocol; - len = payload_len + hdr_len + l2_len; + len = payload_len + hdr_len - l2_len; if (udp) { esp_update_ip4_hdr (ip4, len, /* is_transport */ 1, 1); diff --git a/test/template_ipsec.py b/test/template_ipsec.py index b954af1c824..73ae24a4295 100644 --- a/test/template_ipsec.py +++ b/test/template_ipsec.py @@ -400,6 +400,8 @@ class IpsecTra4(object): recv_pkts = self.send_and_expect(self.tra_if, send_pkts, self.tra_if) for rx in recv_pkts: + self.assertEqual(len(rx) - len(Ether()), rx[IP].len) + self.assert_packet_checksums_valid(rx) try: decrypted = p.vpp_tra_sa.decrypt(rx[IP]) self.assert_packet_checksums_valid(decrypted) @@ -522,6 +524,8 @@ class IpsecTun4(object): def verify_encrypted(self, p, sa, rxs): decrypt_pkts = [] for rx in rxs: + self.assert_packet_checksums_valid(rx) + self.assertEqual(len(rx) - len(Ether()), rx[IP].len) try: decrypt_pkt = p.vpp_tun_sa.decrypt(rx[IP]) if not decrypt_pkt.haslayer(IP): |