/* SPDX-License-Identifier: Apache-2.0 * Copyright(c) 2024 Cisco Systems, Inc. */ #include /* * rfc8446#section-4.1.2 * struct { * ProtocolVersion legacy_version = 0x0303; // TLS v1.2 * Random random; * opaque legacy_session_id<0..32>; * CipherSuite cipher_suites<2..2^16-2>; * opaque legacy_compression_methods<1..2^8-1>; * Extension extensions<8..2^16-1>; * } ClientHello; */ tls_handshake_parse_error_t tls_handshake_client_hello_parse (u8 *b, int len, tls_handshake_msg_info_t *info) { u8 *p = b; if (PREDICT_FALSE (len < 2 + 32 + 1 + 2 + 2 + 2)) return TLS_HS_PARSE_ERR_INVALID_LEN; /* skip legacy version and random */ p += 2 + 32; /* legacy_session_id */ info->legacy_session_id_len = *p; info->legacy_session_id = p + 1; p = info->legacy_session_id + info->legacy_session_id_len; if (PREDICT_FALSE (p - b >= len)) return TLS_HS_PARSE_ERR_SESSION_ID_LEN; /* cipher_suites */ info->cipher_suite_len = clib_net_to_host_u16 (*(u16 *) p); info->cipher_suites = p + 2; p = info->cipher_suites + info->cipher_suite_len; if (PREDICT_FALSE (p - b >= len)) return TLS_HS_PARSE_ERR_CIPHER_SUITE_LEN; /* legacy_compression_method, only support null */ if (PREDICT_FALSE (*p != 1 || *(p + 1) != 0)) return TLS_HS_PARSE_ERR_COMPRESSION_METHOD; p += 2; /* extensions */ info->extensions_len = clib_net_to_host_u16 (*(u16 *) p); info->extensions = p + 2; if (PREDICT_FALSE (info->extensions + info->extensions_len - b > len)) return TLS_HS_PARSE_ERR_CIPHER_SUITE_LEN; return TLS_HS_PARSE_ERR_OK; } typedef tls_handshake_parse_error_t (*tls_handshake_msg_parser) ( u8 *b, int len, tls_handshake_msg_info_t *info); static tls_handshake_msg_parser tls_handshake_msg_parsers[] = { [TLS_HS_CLIENT_HELLO] = tls_handshake_client_hello_parse, }; static inline u32 tls_handshake_ext_requested (const tls_handshake_ext_info_t *req_exts, u32 n_reqs, tls_handshake_ext_type_t ext_type) { for (int i = 0; i < n_reqs; i++) { if (req_exts[i].type == ext_type) return i; } return ~0; } tls_handshake_parse_error_t tls_hanshake_extensions_parse (tls_handshake_msg_info_t *info, tls_handshake_ext_info_t **exts) { tls_handshake_ext_info_t *ext; u16 ext_type, ext_len; u8 *b, *b_end; ASSERT (info->extensions != 0); if (info->extensions_len < 2) return TLS_HS_PARSE_ERR_EXTENSIONS_LEN; b = info->extensions; b_end = info->extensions + info->extensions_len; while (b < b_end) { ext_type = clib_net_to_host_u16 (*(u16 *) b); b += 2; ext_len = clib_net_to_host_u16 (*(u16 *) b); b += 2; if (b + ext_len > b_end) return TLS_HS_PARSE_ERR_EXTENSIONS_LEN; vec_add2 (*exts, ext, 1); ext->type = ext_type; ext->len = ext_len; ext->data = b; b += ext_len; } return TLS_HS_PARSE_ERR_OK; } tls_handshake_parse_error_t tls_hanshake_extensions_try_parse (tls_handshake_msg_info_t *info, tls_handshake_ext_info_t *req_exts, u32 n_reqs) { u8 *b, *b_end; u16 ext_type, ext_len; u32 n_found = 0, ext_pos; ASSERT (info->extensions != 0); if (info->extensions_len < 2) return TLS_HS_PARSE_ERR_EXTENSIONS_LEN; b = info->extensions; b_end = info->extensions + info->extensions_len; while (b < b_end && n_found < n_reqs) { ext_type = clib_net_to_host_u16 (*(u16 *) b); b += 2; ext_len = clib_net_to_host_u16 (*(u16 *) b); b += 2; if (b + ext_len > b_end) return TLS_HS_PARSE_ERR_EXTENSIONS_LEN; ext_pos = tls_handshake_ext_requested (req_exts, n_reqs, ext_type); if (ext_pos == ~0) { b += ext_len; continue; } req_exts[ext_pos].len = ext_len; req_exts[ext_pos].data = b; b += ext_len; n_found++; } return TLS_HS_PARSE_ERR_OK; } tls_handshake_parse_error_t tls_handshake_message_try_parse (u8 *msg, int len, tls_handshake_msg_info_t *info) { tls_handshake_msg_t *msg_hdr = (tls_handshake_msg_t *) msg; u8 *b = msg_hdr->message; info->len = tls_handshake_message_len (msg_hdr); if (info->len > len) return info->len > TLS_FRAGMENT_MAX_ENC_LEN ? TLS_HS_PARSE_ERR_INVALID_LEN : TLS_HS_PARSE_ERR_WANT_MORE; if (msg_hdr->msg_type >= ARRAY_LEN (tls_handshake_msg_parsers) || !tls_handshake_msg_parsers[msg_hdr->msg_type]) return TLS_HS_PARSE_ERR_UNSUPPORTED; return tls_handshake_msg_parsers[msg_hdr->msg_type](b, info->len, info); } /** * As per rfc6066#section-3 * struct { * NameType name_type; * select (name_type) { * case host_name: HostName; * } name; * } ServerName; * * enum { * host_name(0), (255) * } NameType; * * opaque HostName<1..2^16-1>; * * struct { * ServerName server_name_list<1..2^16-1> * } ServerNameList; */ tls_handshake_parse_error_t tls_handshake_ext_sni_parse (tls_handshake_ext_info_t *ext_info, tls_handshake_ext_t *ext) { tls_handshake_ext_sni_t *sni = (tls_handshake_ext_sni_t *) ext; tls_handshake_ext_sni_sn_t *sn; u16 n_names, sn_len; u8 *b, *b_end; b = ext_info->data; b_end = b + ext_info->len; sni->ext.type = ext_info->type; sni->names = 0; n_names = clib_net_to_host_u16 (*(u16 *) b); b += 2; while (b < b_end && vec_len (sni->names) < n_names) { /* only host name supported */ if (b[0] != 0) return TLS_HS_PARSE_ERR_EXT_SNI_NAME_TYPE; b++; /* server name length */ sn_len = clib_net_to_host_u16 (*(u16 *) b); if (sn_len > TLS_EXT_SNI_MAX_LEN) return TLS_HS_PARSE_ERR_EXT_SNI_LEN; b += 2; vec_add2 (sni->names, sn, 1); sn->name_type = 0; vec_validate (sn->host_name, sn_len - 1); clib_memcpy (sn->host_name, b, sn_len); b += sn_len; } return TLS_HS_PARSE_ERR_OK; } typedef tls_handshake_parse_error_t (*tls_handshake_ext_parser) ( tls_handshake_ext_info_t *ext_info, tls_handshake_ext_t *ext); static tls_handshake_ext_parser tls_handshake_ext_parsers[] = { [TLS_EXT_SERVER_NAME] = tls_handshake_ext_sni_parse, }; tls_handshake_parse_error_t tls_handshake_ext_parse (tls_handshake_ext_info_t *ext_info, tls_handshake_ext_t *ext) { if (ext_info->type >= ARRAY_LEN (tls_handshake_ext_parsers) || !tls_handshake_ext_parsers[ext_info->type]) return TLS_HS_PARSE_ERR_UNSUPPORTED; return tls_handshake_ext_parsers[ext_info->type](ext_info, ext); } static void tls_handshake_ext_sni_free (tls_handshake_ext_t *ext) { tls_handshake_ext_sni_t *sni = (tls_handshake_ext_sni_t *) ext; tls_handshake_ext_sni_sn_t *sn; vec_foreach (sn, sni->names) vec_free (sn->host_name); vec_free (sni->names); } typedef void (*tls_handshake_ext_free_fn) (tls_handshake_ext_t *ext); static tls_handshake_ext_free_fn tls_handshake_ext_free_fns[] = { [TLS_EXT_SERVER_NAME] = tls_handshake_ext_sni_free, }; void tls_handshake_ext_free (tls_handshake_ext_t *ext) { if (ext->type >= ARRAY_LEN (tls_handshake_ext_free_fns) || !tls_handshake_ext_free_fns[ext->type]) return; tls_handshake_ext_free_fns[ext->type](ext); }