diff options
Diffstat (limited to 'libtransport/src')
169 files changed, 16822 insertions, 1900 deletions
diff --git a/libtransport/src/CMakeLists.txt b/libtransport/src/CMakeLists.txt index 33497e0f4..0fa9bbe3c 100644 --- a/libtransport/src/CMakeLists.txt +++ b/libtransport/src/CMakeLists.txt @@ -20,7 +20,7 @@ set(ASIO_STANDALONE 1) add_subdirectory(core) add_subdirectory(interfaces) add_subdirectory(protocols) -add_subdirectory(security) +add_subdirectory(auth) add_subdirectory(implementation) add_subdirectory(utils) add_subdirectory(http) @@ -34,7 +34,16 @@ install( COMPONENT lib${LIBTRANSPORT}-dev ) -set (COMPILER_DEFINITIONS "-DTRANSPORT_LOG_DEF_LEVEL=TRANSPORT_LOG_${TRANSPORT_LOG_LEVEL}") +install( + FILES "transport.config" + DESTINATION ${CMAKE_INSTALL_FULL_SYSCONFDIR}/hicn + COMPONENT lib${LIBTRANSPORT} +) + +list(APPEND COMPILER_DEFINITIONS + "-DTRANSPORT_LOG_DEF_LEVEL=TRANSPORT_LOG_${TRANSPORT_LOG_LEVEL}" + "-DASIO_STANDALONE" +) list(INSERT LIBTRANSPORT_INTERNAL_INCLUDE_DIRS 0 ${CMAKE_CURRENT_SOURCE_DIR}/ @@ -55,8 +64,10 @@ else () set(CMAKE_SHARED_LINKER_FLAGS "/NODEFAULTLIB:\"MSVCRTD\"" ) endif () endif () -if (${CMAKE_SYSTEM_NAME} STREQUAL "Android") + +if (${CMAKE_SYSTEM_NAME} MATCHES "Android") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++ -isystem -lm") + add_subdirectory(io_modules) endif() if (DISABLE_SHARED_LIBRARIES) @@ -68,8 +79,9 @@ if (DISABLE_SHARED_LIBRARIES) DEPENDS ${DEPENDENCIES} COMPONENT lib${LIBTRANSPORT} INCLUDE_DIRS ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} - INSTALL_ROOT_DIR hicn/transport + HEADER_ROOT_DIR hicn/transport DEFINITIONS ${COMPILER_DEFINITIONS} + VERSION ${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION} ) else () build_library(${LIBTRANSPORT} @@ -80,11 +92,17 @@ else () DEPENDS ${DEPENDENCIES} COMPONENT lib${LIBTRANSPORT} INCLUDE_DIRS ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} - INSTALL_ROOT_DIR hicn/transport + HEADER_ROOT_DIR hicn/transport DEFINITIONS ${COMPILER_DEFINITIONS} + VERSION ${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION} ) endif () +# io modules +if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Android") + add_subdirectory(io_modules) +endif() + if (${BUILD_TESTS}) add_subdirectory(test) endif() diff --git a/libtransport/src/auth/CMakeLists.txt b/libtransport/src/auth/CMakeLists.txt new file mode 100644 index 000000000..0e7b5832b --- /dev/null +++ b/libtransport/src/auth/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) 2017-2019 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + +list(APPEND SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/signer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/verifier.cc + ${CMAKE_CURRENT_SOURCE_DIR}/identity.cc +) + +set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) diff --git a/libtransport/src/auth/identity.cc b/libtransport/src/auth/identity.cc new file mode 100644 index 000000000..bd787b9b6 --- /dev/null +++ b/libtransport/src/auth/identity.cc @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/auth/identity.h> + +using namespace std; + +namespace transport { +namespace auth { + +Identity::Identity(const string &keystore_path, const string &keystore_pwd, + CryptoSuite suite, unsigned int signature_len, + unsigned int validity_days, const string &subject_name) + : identity_(nullptr), signer_(nullptr) { + parcSecurity_Init(); + + bool success = parcPkcs12KeyStore_CreateFile( + keystore_path.c_str(), keystore_pwd.c_str(), subject_name.c_str(), + parcCryptoSuite_GetSigningAlgorithm(static_cast<PARCCryptoSuite>(suite)), + signature_len, validity_days); + + parcAssertTrue( + success, + "parcPkcs12KeyStore_CreateFile('%s', '%s', '%s', %d, %d, %d) failed.", + keystore_path.c_str(), keystore_pwd.c_str(), subject_name.c_str(), + static_cast<int>(suite), static_cast<int>(signature_len), validity_days); + + PARCIdentityFile *identity_file = + parcIdentityFile_Create(keystore_path.c_str(), keystore_pwd.c_str()); + + identity_ = + parcIdentity_Create(identity_file, PARCIdentityFileAsPARCIdentity); + + PARCSigner *signer = parcIdentity_CreateSigner( + identity_, + parcCryptoSuite_GetCryptoHash(static_cast<PARCCryptoSuite>(suite))); + + signer_ = make_shared<AsymmetricSigner>(signer); + + parcSigner_Release(&signer); + parcIdentityFile_Release(&identity_file); +} + +Identity::Identity(string &keystore_path, string &keystore_pwd, + CryptoHashType hash_type) + : identity_(nullptr), signer_(nullptr) { + parcSecurity_Init(); + + PARCIdentityFile *identity_file = + parcIdentityFile_Create(keystore_path.c_str(), keystore_pwd.c_str()); + + identity_ = + parcIdentity_Create(identity_file, PARCIdentityFileAsPARCIdentity); + + PARCSigner *signer = parcIdentity_CreateSigner( + identity_, static_cast<PARCCryptoHashType>(hash_type)); + + signer_ = make_shared<AsymmetricSigner>(signer); + + parcSigner_Release(&signer); + parcIdentityFile_Release(&identity_file); +} + +Identity::Identity(const Identity &other) + : identity_(nullptr), signer_(other.signer_) { + parcSecurity_Init(); + identity_ = parcIdentity_Acquire(other.identity_); +} + +Identity::Identity(Identity &&other) + : identity_(nullptr), signer_(move(other.signer_)) { + parcSecurity_Init(); + identity_ = parcIdentity_Acquire(other.identity_); + parcIdentity_Release(&other.identity_); +} + +Identity::~Identity() { + if (identity_) parcIdentity_Release(&identity_); + parcSecurity_Fini(); +} + +shared_ptr<AsymmetricSigner> Identity::getSigner() const { return signer_; } + +string Identity::getFilename() const { + return string(parcIdentity_GetFileName(identity_)); +} + +string Identity::getPassword() const { + return string(parcIdentity_GetPassWord(identity_)); +} + +Identity Identity::generateIdentity(const string &subject_name) { + string keystore_name = "keystore"; + string keystore_password = "password"; + size_t key_length = 1024; + unsigned int validity_days = 30; + CryptoSuite suite = CryptoSuite::RSA_SHA256; + + return Identity(keystore_name, keystore_password, suite, + (unsigned int)key_length, validity_days, subject_name); +} + +} // namespace auth +} // namespace transport diff --git a/libtransport/src/auth/signer.cc b/libtransport/src/auth/signer.cc new file mode 100644 index 000000000..281b9c59a --- /dev/null +++ b/libtransport/src/auth/signer.cc @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/auth/signer.h> + +extern "C" { +#ifndef _WIN32 +TRANSPORT_CLANG_DISABLE_WARNING("-Wextern-c-compat") +#endif +#include <hicn/hicn.h> +} + +#include <chrono> + +#define ALLOW_UNALIGNED_READS 1 + +using namespace std; + +namespace transport { +namespace auth { + +Signer::Signer() : signer_(nullptr), key_id_(nullptr) { parcSecurity_Init(); } + +Signer::Signer(PARCSigner *signer) : Signer() { setSigner(signer); } + +Signer::~Signer() { + if (signer_) parcSigner_Release(&signer_); + if (key_id_) parcKeyId_Release(&key_id_); + parcSecurity_Fini(); +} + +void Signer::signPacket(PacketPtr packet) { + parcAssertNotNull(signer_, "Expected non-null signer"); + + const utils::MemBuf &header_chain = *packet; + core::Packet::Format format = packet->getFormat(); + auto suite = getCryptoSuite(); + size_t signature_len = getSignatureSize(); + + if (!packet->authenticationHeader()) { + throw errors::MalformedAHPacketException(); + } + + packet->setSignatureSize(signature_len); + + // Copy IP+TCP / ICMP header before zeroing them + hicn_header_t header_copy; + hicn_packet_copy_header(format, packet->packet_start_, &header_copy, false); + packet->resetForHash(); + + // Fill in the HICN_AH header + auto now = chrono::duration_cast<chrono::milliseconds>( + chrono::system_clock::now().time_since_epoch()) + .count(); + packet->setSignatureTimestamp(now); + packet->setValidationAlgorithm(suite); + + // Set the key ID + KeyId key_id; + key_id.first = static_cast<uint8_t *>( + parcBuffer_Overlay((PARCBuffer *)parcKeyId_GetKeyId(key_id_), 0)); + packet->setKeyId(key_id); + + // Calculate hash + CryptoHasher hasher(parcSigner_GetCryptoHasher(signer_)); + const utils::MemBuf *current = &header_chain; + + hasher.init(); + + do { + hasher.updateBytes(current->data(), current->length()); + current = current->next(); + } while (current != &header_chain); + + CryptoHash hash = hasher.finalize(); + + // Compute signature + PARCSignature *signature = parcSigner_SignDigestNoAlloc( + signer_, hash.hash_, packet->getSignature(), signature_len); + PARCBuffer *buffer = parcSignature_GetSignature(signature); + size_t bytes_len = parcBuffer_Remaining(buffer); + + if (bytes_len > signature_len) { + throw errors::MalformedAHPacketException(); + } + + // Put signature in AH header + hicn_packet_copy_header(format, &header_copy, packet->packet_start_, false); + + // Release allocated objects + parcSignature_Release(&signature); +} + +void Signer::setSigner(PARCSigner *signer) { + parcAssertNotNull(signer, "Expected non-null signer"); + + if (signer_) parcSigner_Release(&signer_); + if (key_id_) parcKeyId_Release(&key_id_); + + signer_ = parcSigner_Acquire(signer); + key_id_ = parcSigner_CreateKeyId(signer_); +} + +size_t Signer::getSignatureSize() const { + parcAssertNotNull(signer_, "Expected non-null signer"); + return parcSigner_GetSignatureSize(signer_); +} + +CryptoSuite Signer::getCryptoSuite() const { + parcAssertNotNull(signer_, "Expected non-null signer"); + return static_cast<CryptoSuite>(parcSigner_GetCryptoSuite(signer_)); +} + +CryptoHashType Signer::getCryptoHashType() const { + parcAssertNotNull(signer_, "Expected non-null signer"); + return static_cast<CryptoHashType>(parcSigner_GetCryptoHashType(signer_)); +} + +PARCSigner *Signer::getParcSigner() const { return signer_; } + +PARCKeyStore *Signer::getParcKeyStore() const { + parcAssertNotNull(signer_, "Expected non-null signer"); + return parcSigner_GetKeyStore(signer_); +} + +AsymmetricSigner::AsymmetricSigner(CryptoSuite suite, PARCKeyStore *key_store) { + parcAssertNotNull(key_store, "Expected non-null key_store"); + + auto crypto_suite = static_cast<PARCCryptoSuite>(suite); + + switch (suite) { + case CryptoSuite::DSA_SHA256: + case CryptoSuite::RSA_SHA256: + case CryptoSuite::RSA_SHA512: + case CryptoSuite::ECDSA_256K1: + break; + default: + throw errors::RuntimeException( + "Invalid crypto suite for asymmetric signer"); + } + + setSigner( + parcSigner_Create(parcPublicKeySigner_Create(key_store, crypto_suite), + PARCPublicKeySignerAsSigner)); +} + +SymmetricSigner::SymmetricSigner(CryptoSuite suite, PARCKeyStore *key_store) { + parcAssertNotNull(key_store, "Expected non-null key_store"); + + auto crypto_suite = static_cast<PARCCryptoSuite>(suite); + + switch (suite) { + case CryptoSuite::HMAC_SHA256: + case CryptoSuite::HMAC_SHA512: + break; + default: + throw errors::RuntimeException( + "Invalid crypto suite for symmetric signer"); + } + + setSigner(parcSigner_Create(parcSymmetricKeySigner_Create( + (PARCSymmetricKeyStore *)key_store, + parcCryptoSuite_GetCryptoHash(crypto_suite)), + PARCSymmetricKeySignerAsSigner)); +} + +SymmetricSigner::SymmetricSigner(CryptoSuite suite, const string &passphrase) { + auto crypto_suite = static_cast<PARCCryptoSuite>(suite); + + switch (suite) { + case CryptoSuite::HMAC_SHA256: + case CryptoSuite::HMAC_SHA512: + break; + default: + throw errors::RuntimeException( + "Invalid crypto suite for symmetric signer"); + } + + PARCBufferComposer *composer = parcBufferComposer_Create(); + parcBufferComposer_PutString(composer, passphrase.c_str()); + PARCBuffer *key_buf = parcBufferComposer_ProduceBuffer(composer); + parcBufferComposer_Release(&composer); + + PARCSymmetricKeyStore *key_store = parcSymmetricKeyStore_Create(key_buf); + PARCSymmetricKeySigner *key_signer = parcSymmetricKeySigner_Create( + key_store, parcCryptoSuite_GetCryptoHash(crypto_suite)); + + setSigner(parcSigner_Create(key_signer, PARCSymmetricKeySignerAsSigner)); + + parcSymmetricKeySigner_Release(&key_signer); + parcSymmetricKeyStore_Release(&key_store); + parcBuffer_Release(&key_buf); +} + +} // namespace auth +} // namespace transport diff --git a/libtransport/src/auth/verifier.cc b/libtransport/src/auth/verifier.cc new file mode 100644 index 000000000..c6648a763 --- /dev/null +++ b/libtransport/src/auth/verifier.cc @@ -0,0 +1,335 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/auth/verifier.h> +#include <protocols/errors.h> + +extern "C" { +#ifndef _WIN32 +TRANSPORT_CLANG_DISABLE_WARNING("-Wextern-c-compat") +#endif +#include <hicn/hicn.h> +} + +#include <sys/stat.h> + +using namespace std; + +namespace transport { +namespace auth { + +const std::vector<VerificationPolicy> Verifier::DEFAULT_FAILED_POLICIES = { + VerificationPolicy::DROP, + VerificationPolicy::ABORT, +}; + +Verifier::Verifier() + : hasher_(nullptr), + verifier_(nullptr), + verification_failed_cb_(interface::VOID_HANDLER), + failed_policies_(DEFAULT_FAILED_POLICIES) { + parcSecurity_Init(); + PARCInMemoryVerifier *in_memory_verifier = parcInMemoryVerifier_Create(); + verifier_ = + parcVerifier_Create(in_memory_verifier, PARCInMemoryVerifierAsVerifier); + parcInMemoryVerifier_Release(&in_memory_verifier); +} + +Verifier::~Verifier() { + if (hasher_) parcCryptoHasher_Release(&hasher_); + if (verifier_) parcVerifier_Release(&verifier_); + parcSecurity_Fini(); +} + +bool Verifier::verifyPacket(PacketPtr packet) { + bool valid_packet = false; + core::Packet::Format format = packet->getFormat(); + + if (!packet->authenticationHeader()) { + throw errors::MalformedAHPacketException(); + } + + // Get crypto suite and hash type + auto suite = static_cast<PARCCryptoSuite>(packet->getValidationAlgorithm()); + PARCCryptoHashType hash_type = parcCryptoSuite_GetCryptoHash(suite); + + // Copy IP+TCP / ICMP header before zeroing them + hicn_header_t header_copy; + hicn_packet_copy_header(format, packet->packet_start_, &header_copy, false); + + // Fetch packet signature + uint8_t *packet_signature = packet->getSignature(); + size_t signature_len = Verifier::getSignatureSize(packet); + vector<uint8_t> signature_raw(packet_signature, + packet_signature + signature_len); + + // Create a signature buffer from the raw packet signature + PARCBuffer *bits = + parcBuffer_Wrap(signature_raw.data(), signature_len, 0, signature_len); + parcBuffer_Rewind(bits); + + // If the signature algo is ECDSA, the signature might be shorter than the + // signature field + PARCSigningAlgorithm algo = parcCryptoSuite_GetSigningAlgorithm(suite); + if (algo == PARCSigningAlgorithm_ECDSA) { + while (parcBuffer_HasRemaining(bits) && parcBuffer_GetUint8(bits) == 0) + ; + parcBuffer_SetPosition(bits, parcBuffer_Position(bits) - 1); + } + + if (!parcBuffer_HasRemaining(bits)) { + parcBuffer_Release(&bits); + return false; + } + + // Create a signature object from the signature buffer + PARCSignature *signature = parcSignature_Create( + parcCryptoSuite_GetSigningAlgorithm(suite), hash_type, bits); + + // Fetch the key to verify the signature + KeyId key_buffer = packet->getKeyId(); + PARCBuffer *buffer = parcBuffer_Wrap(key_buffer.first, key_buffer.second, 0, + key_buffer.second); + PARCKeyId *key_id = parcKeyId_Create(buffer); + + // Reset fields that are not used to compute signature + packet->resetForHash(); + + // Compute the packet hash + if (!hasher_) + setHasher(parcVerifier_GetCryptoHasher(verifier_, key_id, hash_type)); + CryptoHash local_hash = computeHash(packet); + + // Compare the packet signature to the locally computed one + valid_packet = parcVerifier_VerifyDigestSignature( + verifier_, key_id, local_hash.hash_, suite, signature); + + // Restore the fields that were reset + hicn_packet_copy_header(format, &header_copy, packet->packet_start_, false); + + // Release allocated objects + parcBuffer_Release(&buffer); + parcKeyId_Release(&key_id); + parcSignature_Release(&signature); + parcBuffer_Release(&bits); + + return valid_packet; +} + +vector<VerificationPolicy> Verifier::verifyPackets( + const vector<PacketPtr> &packets) { + vector<VerificationPolicy> policies(packets.size(), VerificationPolicy::DROP); + + for (unsigned int i = 0; i < packets.size(); ++i) { + if (verifyPacket(packets[i])) { + policies[i] = VerificationPolicy::ACCEPT; + } + + callVerificationFailedCallback(packets[i], policies[i]); + } + + return policies; +} + +vector<VerificationPolicy> Verifier::verifyPackets( + const vector<PacketPtr> &packets, + const unordered_map<Suffix, HashEntry> &suffix_map) { + vector<VerificationPolicy> policies(packets.size(), + VerificationPolicy::UNKNOWN); + + for (unsigned int i = 0; i < packets.size(); ++i) { + uint32_t suffix = packets[i]->getName().getSuffix(); + auto manifest_hash = suffix_map.find(suffix); + + if (manifest_hash != suffix_map.end()) { + CryptoHashType hash_type = manifest_hash->second.first; + CryptoHash packet_hash = packets[i]->computeDigest(hash_type); + + if (!CryptoHash::compareBinaryDigest( + packet_hash.getDigest<uint8_t>().data(), + manifest_hash->second.second.data(), hash_type)) { + policies[i] = VerificationPolicy::ABORT; + } else { + policies[i] = VerificationPolicy::ACCEPT; + } + } + + callVerificationFailedCallback(packets[i], policies[i]); + } + + return policies; +} + +void Verifier::addKey(PARCKey *key) { parcVerifier_AddKey(verifier_, key); } + +void Verifier::setHasher(PARCCryptoHasher *hasher) { + parcAssertNotNull(hasher, "Expected non-null hasher"); + if (hasher_) parcCryptoHasher_Release(&hasher_); + hasher_ = parcCryptoHasher_Acquire(hasher); +} + +void Verifier::setVerificationFailedCallback( + VerificationFailedCallback verfication_failed_cb, + const vector<VerificationPolicy> &failed_policies) { + verification_failed_cb_ = verfication_failed_cb; + failed_policies_ = failed_policies; +} + +void Verifier::getVerificationFailedCallback( + VerificationFailedCallback **verfication_failed_cb) { + *verfication_failed_cb = &verification_failed_cb_; +} + +size_t Verifier::getSignatureSize(const PacketPtr packet) { + return packet->getSignatureSize(); +} + +CryptoHash Verifier::computeHash(PacketPtr packet) { + parcAssertNotNull(hasher_, "Expected non-null hasher"); + + CryptoHasher crypto_hasher(hasher_); + const utils::MemBuf &header_chain = *packet; + const utils::MemBuf *current = &header_chain; + + crypto_hasher.init(); + + do { + crypto_hasher.updateBytes(current->data(), current->length()); + current = current->next(); + } while (current != &header_chain); + + return crypto_hasher.finalize(); +} + +void Verifier::callVerificationFailedCallback(PacketPtr packet, + VerificationPolicy &policy) { + if (verification_failed_cb_ == interface::VOID_HANDLER) { + return; + } + + if (find(failed_policies_.begin(), failed_policies_.end(), policy) != + failed_policies_.end()) { + policy = verification_failed_cb_( + static_cast<const core::ContentObject &>(*packet), + make_error_code( + protocol::protocol_error::signature_verification_failed)); + } +} + +bool VoidVerifier::verifyPacket(PacketPtr packet) { return true; } + +vector<VerificationPolicy> VoidVerifier::verifyPackets( + const vector<PacketPtr> &packets) { + return vector<VerificationPolicy>(packets.size(), VerificationPolicy::ACCEPT); +} + +vector<VerificationPolicy> VoidVerifier::verifyPackets( + const vector<PacketPtr> &packets, + const unordered_map<Suffix, HashEntry> &suffix_map) { + return vector<VerificationPolicy>(packets.size(), VerificationPolicy::ACCEPT); +} + +AsymmetricVerifier::AsymmetricVerifier(PARCKey *pub_key) { addKey(pub_key); } + +AsymmetricVerifier::AsymmetricVerifier(const string &cert_path) { + setCertificate(cert_path); +} + +void AsymmetricVerifier::setCertificate(const string &cert_path) { + PARCCertificateFactory *factory = parcCertificateFactory_Create( + PARCCertificateType_X509, PARCContainerEncoding_PEM); + + struct stat buffer; + if (stat(cert_path.c_str(), &buffer) != 0) { + throw errors::RuntimeException("Certificate does not exist"); + } + + PARCCertificate *certificate = + parcCertificateFactory_CreateCertificateFromFile(factory, + cert_path.c_str(), NULL); + PARCKey *key = parcCertificate_GetPublicKey(certificate); + + addKey(key); + + parcKey_Release(&key); + parcCertificateFactory_Release(&factory); +} + +SymmetricVerifier::SymmetricVerifier(const string &passphrase) + : passphrase_(nullptr), signer_(nullptr) { + setPassphrase(passphrase); +} + +SymmetricVerifier::~SymmetricVerifier() { + if (passphrase_) parcBuffer_Release(&passphrase_); + if (signer_) parcSigner_Release(&signer_); +} + +void SymmetricVerifier::setPassphrase(const string &passphrase) { + if (passphrase_) parcBuffer_Release(&passphrase_); + + PARCBufferComposer *composer = parcBufferComposer_Create(); + parcBufferComposer_PutString(composer, passphrase.c_str()); + passphrase_ = parcBufferComposer_ProduceBuffer(composer); + parcBufferComposer_Release(&composer); +} + +void SymmetricVerifier::setSigner(const PARCCryptoSuite &suite) { + parcAssertNotNull(passphrase_, "Expected non-null passphrase"); + + if (signer_) parcSigner_Release(&signer_); + + PARCSymmetricKeyStore *key_store = parcSymmetricKeyStore_Create(passphrase_); + PARCSymmetricKeySigner *key_signer = parcSymmetricKeySigner_Create( + key_store, parcCryptoSuite_GetCryptoHash(suite)); + signer_ = parcSigner_Create(key_signer, PARCSymmetricKeySignerAsSigner); + + PARCKeyId *key_id = parcSigner_CreateKeyId(signer_); + PARCKey *key = parcKey_CreateFromSymmetricKey( + key_id, parcSigner_GetSigningAlgorithm(signer_), passphrase_); + + addKey(key); + setHasher(parcSigner_GetCryptoHasher(signer_)); + + parcSymmetricKeyStore_Release(&key_store); + parcSymmetricKeySigner_Release(&key_signer); + parcKeyId_Release(&key_id); + parcKey_Release(&key); +} + +vector<VerificationPolicy> SymmetricVerifier::verifyPackets( + const vector<PacketPtr> &packets) { + vector<VerificationPolicy> policies(packets.size(), VerificationPolicy::DROP); + + for (unsigned int i = 0; i < packets.size(); ++i) { + auto suite = + static_cast<PARCCryptoSuite>(packets[i]->getValidationAlgorithm()); + + if (!signer_ || suite != parcSigner_GetCryptoSuite(signer_)) { + setSigner(suite); + } + + if (verifyPacket(packets[i])) { + policies[i] = VerificationPolicy::ACCEPT; + } + + callVerificationFailedCallback(packets[i], policies[i]); + } + + return policies; +} + +} // namespace auth +} // namespace transport diff --git a/libtransport/src/config.h.in b/libtransport/src/config.h.in index 4e9a0f262..ef47affda 100644 --- a/libtransport/src/config.h.in +++ b/libtransport/src/config.h.in @@ -25,10 +25,6 @@ #cmakedefine ASIO_STANDALONE #endif -#ifndef SECURE_HICNTRANSPORT -#cmakedefine SECURE_HICNTRANSPORT -#endif - #define RAAQM_CONFIG_PATH "@raaqm_config_path@" #cmakedefine __vpp__ diff --git a/libtransport/src/core/CMakeLists.txt b/libtransport/src/core/CMakeLists.txt index 5c8ab9270..4e3ac10ec 100644 --- a/libtransport/src/core/CMakeLists.txt +++ b/libtransport/src/core/CMakeLists.txt @@ -21,55 +21,27 @@ list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/manifest_format.h ${CMAKE_CURRENT_SOURCE_DIR}/pending_interest.h ${CMAKE_CURRENT_SOURCE_DIR}/portal.h - ${CMAKE_CURRENT_SOURCE_DIR}/connector.h - ${CMAKE_CURRENT_SOURCE_DIR}/tcp_socket_connector.h - ${CMAKE_CURRENT_SOURCE_DIR}/udp_socket_connector.h - ${CMAKE_CURRENT_SOURCE_DIR}/forwarder_interface.h - ${CMAKE_CURRENT_SOURCE_DIR}/hicn_forwarder_interface.h - ${CMAKE_CURRENT_SOURCE_DIR}/vpp_forwarder_interface.h - ${CMAKE_CURRENT_SOURCE_DIR}/memif_connector.h + ${CMAKE_CURRENT_SOURCE_DIR}/errors.h + ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.h + ${CMAKE_CURRENT_SOURCE_DIR}/local_connector.h + ${CMAKE_CURRENT_SOURCE_DIR}/rs.h ) list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/content_object.cc ${CMAKE_CURRENT_SOURCE_DIR}/interest.cc + ${CMAKE_CURRENT_SOURCE_DIR}/errors.cc ${CMAKE_CURRENT_SOURCE_DIR}/packet.cc ${CMAKE_CURRENT_SOURCE_DIR}/name.cc ${CMAKE_CURRENT_SOURCE_DIR}/prefix.cc - ${CMAKE_CURRENT_SOURCE_DIR}/tcp_socket_connector.cc - ${CMAKE_CURRENT_SOURCE_DIR}/udp_socket_connector.cc - ${CMAKE_CURRENT_SOURCE_DIR}/hicn_forwarder_interface.cc ${CMAKE_CURRENT_SOURCE_DIR}/manifest_format_fixed.cc - ${CMAKE_CURRENT_SOURCE_DIR}/connector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/portal.cc + ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.cc + ${CMAKE_CURRENT_SOURCE_DIR}/io_module.cc + ${CMAKE_CURRENT_SOURCE_DIR}/local_connector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/fec.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rs.cc ) -if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux") - if (BUILD_WITH_VPP OR BUILD_HICNPLUGIN) - list(APPEND HEADER_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/vpp_forwarder_interface.h - ${CMAKE_CURRENT_SOURCE_DIR}/memif_connector.h - ${CMAKE_CURRENT_SOURCE_DIR}/hicn_vapi.h - ${CMAKE_CURRENT_SOURCE_DIR}/memif_vapi.h - ) - - list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/vpp_forwarder_interface.cc - ${CMAKE_CURRENT_SOURCE_DIR}/memif_connector.cc - ${CMAKE_CURRENT_SOURCE_DIR}/hicn_vapi.c - ${CMAKE_CURRENT_SOURCE_DIR}/memif_vapi.c - ) - endif() - - list(APPEND HEADER_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/raw_socket_connector.h - ${CMAKE_CURRENT_SOURCE_DIR}/raw_socket_interface.h - ) - - list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/raw_socket_connector.cc - ${CMAKE_CURRENT_SOURCE_DIR}/raw_socket_interface.cc - ) -endif() - set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE)
\ No newline at end of file diff --git a/libtransport/src/core/content_object.cc b/libtransport/src/core/content_object.cc index f5cccf404..0c68ef559 100644 --- a/libtransport/src/core/content_object.cc +++ b/libtransport/src/core/content_object.cc @@ -32,8 +32,9 @@ namespace transport { namespace core { -ContentObject::ContentObject(const Name &name, Packet::Format format) - : Packet(format) { +ContentObject::ContentObject(const Name &name, Packet::Format format, + std::size_t additional_header_size) + : Packet(format, additional_header_size) { if (TRANSPORT_EXPECT_FALSE( hicn_data_set_name(format, packet_start_, &name.name_) < 0)) { throw errors::RuntimeException("Error filling the packet name."); @@ -47,41 +48,32 @@ ContentObject::ContentObject(const Name &name, Packet::Format format) } #ifdef __ANDROID__ -ContentObject::ContentObject(hicn_format_t format) - : ContentObject(Name("0::0|0"), format) {} +ContentObject::ContentObject(hicn_format_t format, + std::size_t additional_header_size) + : ContentObject(Name("0::0|0"), format, additional_header_size) {} #else -ContentObject::ContentObject(hicn_format_t format) - : ContentObject(Packet::base_name, format) {} +ContentObject::ContentObject(hicn_format_t format, + std::size_t additional_header_size) + : ContentObject(Packet::base_name, format, additional_header_size) {} #endif ContentObject::ContentObject(const Name &name, hicn_format_t format, + std::size_t additional_header_size, const uint8_t *payload, std::size_t size) - : ContentObject(name, format) { + : ContentObject(name, format, additional_header_size) { appendPayload(payload, size); } -ContentObject::ContentObject(const uint8_t *buffer, std::size_t size) - : Packet(buffer, size) { - if (hicn_data_get_name(format_, packet_start_, name_.getStructReference()) < - 0) { - throw errors::RuntimeException("Error getting name from content object."); - } +ContentObject::ContentObject(ContentObject &&other) : Packet(std::move(other)) { + name_ = std::move(other.name_); } -ContentObject::ContentObject(MemBufPtr &&buffer) : Packet(std::move(buffer)) { - if (hicn_data_get_name(format_, packet_start_, name_.getStructReference()) < - 0) { - throw errors::RuntimeException("Error getting name from content object."); - } +ContentObject::ContentObject(const ContentObject &other) : Packet(other) { + name_ = other.name_; } -ContentObject::ContentObject(ContentObject &&other) : Packet(std::move(other)) { - name_ = std::move(other.name_); - - if (hicn_data_get_name(format_, packet_start_, name_.getStructReference()) < - 0) { - throw errors::MalformedPacketException(); - } +ContentObject &ContentObject::operator=(const ContentObject &other) { + return (ContentObject &)Packet::operator=(other); } ContentObject::~ContentObject() {} @@ -132,10 +124,11 @@ uint32_t ContentObject::getPathLabel() const { "Error retrieving the path label from content object"); } - return path_label; + return ntohl(path_label); } ContentObject &ContentObject::setPathLabel(uint32_t path_label) { + path_label = htonl(path_label); if (hicn_data_set_path_label((hicn_header_t *)packet_start_, path_label) < 0) { throw errors::RuntimeException( diff --git a/libtransport/src/core/errors.cc b/libtransport/src/core/errors.cc new file mode 100644 index 000000000..82647a60b --- /dev/null +++ b/libtransport/src/core/errors.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/errors.h> + +namespace transport { +namespace core { + +const std::error_category& core_category() { + static core_category_impl instance; + + return instance; +} + +const char* core_category_impl::name() const throw() { + return "transport::protocol::error"; +} + +std::string core_category_impl::message(int ev) const { + switch (static_cast<core_error>(ev)) { + case core_error::success: { + return "Success"; + } + case core_error::configuration_parse_failed: { + return "Error parsing configuration."; + } + case core_error::configuration_not_applied: { + return "Configuration was not applied due to wrong parameters."; + } + default: { + return "Unknown core error"; + } + } +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/errors.h b/libtransport/src/core/errors.h new file mode 100644 index 000000000..a46f1dbcd --- /dev/null +++ b/libtransport/src/core/errors.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <string> +#include <system_error> + +namespace transport { +namespace core { + +/** + * @brief Get the default server error category. + * @return The default server error category instance. + * + * @warning The first call to this function is thread-safe only starting with + * C++11. + */ +const std::error_category& core_category(); + +/** + * The list of errors. + */ +enum class core_error { + success = 0, + configuration_parse_failed, + configuration_not_applied +}; + +/** + * @brief Create an error_code instance for the given error. + * @param error The error. + * @return The error_code instance. + */ +inline std::error_code make_error_code(core_error error) { + return std::error_code(static_cast<int>(error), core_category()); +} + +/** + * @brief Create an error_condition instance for the given error. + * @param error The error. + * @return The error_condition instance. + */ +inline std::error_condition make_error_condition(core_error error) { + return std::error_condition(static_cast<int>(error), core_category()); +} + +/** + * @brief A server error category. + */ +class core_category_impl : public std::error_category { + public: + /** + * @brief Get the name of the category. + * @return The name of the category. + */ + virtual const char* name() const throw(); + + /** + * @brief Get the error message for a given error. + * @param ev The error numeric value. + * @return The message associated to the error. + */ + virtual std::string message(int ev) const; +}; +} // namespace core +} // namespace transport + +namespace std { +// namespace system { +template <> +struct is_error_code_enum<::transport::core::core_error> + : public std::true_type {}; +// } // namespace system +} // namespace std
\ No newline at end of file diff --git a/libtransport/src/core/facade.h b/libtransport/src/core/facade.h index 04f643f63..199081271 100644 --- a/libtransport/src/core/facade.h +++ b/libtransport/src/core/facade.h @@ -15,36 +15,14 @@ #pragma once -#include <core/forwarder_interface.h> -#include <core/hicn_forwarder_interface.h> #include <core/manifest_format_fixed.h> #include <core/manifest_inline.h> #include <core/portal.h> -#ifdef __linux__ -#ifndef __ANDROID__ -#include <core/raw_socket_interface.h> -#ifdef __vpp__ -#include <core/vpp_forwarder_interface.h> -#endif -#endif -#endif - namespace transport { namespace core { -using HicnForwarderPortal = Portal<HicnForwarderInterface>; - -#ifdef __linux__ -#ifndef __ANDROID__ -using RawSocketPortal = Portal<RawSocketInterface>; -#endif -#ifdef __vpp__ -using VPPForwarderPortal = Portal<VPPForwarderInterface>; -#endif -#endif - using ContentObjectManifest = core::ManifestInline<ContentObject, Fixed>; using InterestManifest = core::ManifestInline<Interest, Fixed>; diff --git a/libtransport/src/core/fec.cc b/libtransport/src/core/fec.cc new file mode 100644 index 000000000..134198b9e --- /dev/null +++ b/libtransport/src/core/fec.cc @@ -0,0 +1,878 @@ +/* + * fec.c -- forward error correction based on Vandermonde matrices + * 980624 + * (C) 1997-98 Luigi Rizzo (luigi@iet.unipi.it) + * + * Portions derived from code by Phil Karn (karn@ka9q.ampr.org), + * Robert Morelos-Zaragoza (robert@spectra.eng.hawaii.edu) and Hari + * Thirumoorthy (harit@spectra.eng.hawaii.edu), Aug 1995 + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, + * OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR + * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT + * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY + * OF SUCH DAMAGE. + */ + +/* + * The following parameter defines how many bits are used for + * field elements. The code supports any value from 2 to 16 + * but fastest operation is achieved with 8 bit elements + * This is the only parameter you may want to change. + */ +#ifndef GF_BITS +#define GF_BITS 8 /* code over GF(2**GF_BITS) - change to suit */ +#endif + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <hicn/transport/portability/platform.h> +#include "fec.h" + +/** + * XXX This disable a warning raising only in some platforms. + * TODO Check if this warning is a mistake or it is a real bug: + * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=83404 + * https://gcc.gnu.org/bugzilla//show_bug.cgi?id=88059 + */ +#ifndef __clang__ +#pragma GCC diagnostic ignored "-Wstringop-overflow" +#endif + +/* + * compatibility stuff + */ +#ifdef MSDOS /* but also for others, e.g. sun... */ +#define NEED_BCOPY +#define bcmp(a,b,n) memcmp(a,b,n) +#endif + +#ifdef ANDROID +#define bcmp(a,b,n) memcmp(a,b,n) +#endif + +#ifdef NEED_BCOPY +#define bcopy(s, d, siz) memcpy((d), (s), (siz)) +#define bzero(d, siz) memset((d), '\0', (siz)) +#endif + +/* + * stuff used for testing purposes only + */ + +#ifdef TEST +#define DEB(x) +#define DDB(x) x +#define DEBUG 0 /* minimal debugging */ +#ifdef MSDOS +#include <time.h> +struct timeval { + unsigned long ticks; +}; +#define gettimeofday(x, dummy) { (x)->ticks = clock() ; } +#define DIFF_T(a,b) (1+ 1000000*(a.ticks - b.ticks) / CLOCKS_PER_SEC ) +typedef unsigned long u_long ; +typedef unsigned short u_short ; +#else /* typically, unix systems */ +#include <sys/time.h> +#define DIFF_T(a,b) \ + (1+ 1000000*(a.tv_sec - b.tv_sec) + (a.tv_usec - b.tv_usec) ) +#endif + +#define TICK(t) \ + {struct timeval x ; \ + gettimeofday(&x, NULL) ; \ + t = x.tv_usec + 1000000* (x.tv_sec & 0xff ) ; \ + } +#define TOCK(t) \ + { u_long t1 ; TICK(t1) ; \ + if (t1 < t) t = 256000000 + t1 - t ; \ + else t = t1 - t ; \ + if (t == 0) t = 1 ;} + +u_long ticks[10]; /* vars for timekeeping */ +#else +#define DEB(x) +#define DDB(x) +#define TICK(x) +#define TOCK(x) +#endif /* TEST */ + +/* + * You should not need to change anything beyond this point. + * The first part of the file implements linear algebra in GF. + * + * gf is the type used to store an element of the Galois Field. + * Must constain at least GF_BITS bits. + * + * Note: unsigned char will work up to GF(256) but int seems to run + * faster on the Pentium. We use int whenever have to deal with an + * index, since they are generally faster. + */ +#if (GF_BITS < 2 && GF_BITS >16) +#error "GF_BITS must be 2 .. 16" +#endif + + +#define GF_SIZE ((1 << GF_BITS) - 1) /* powers of \alpha */ + +/* + * Primitive polynomials - see Lin & Costello, Appendix A, + * and Lee & Messerschmitt, p. 453. + */ +static const char *allPp[] = { /* GF_BITS polynomial */ + NULL, /* 0 no code */ + NULL, /* 1 no code */ + "111", /* 2 1+x+x^2 */ + "1101", /* 3 1+x+x^3 */ + "11001", /* 4 1+x+x^4 */ + "101001", /* 5 1+x^2+x^5 */ + "1100001", /* 6 1+x+x^6 */ + "10010001", /* 7 1 + x^3 + x^7 */ + "101110001", /* 8 1+x^2+x^3+x^4+x^8 */ + "1000100001", /* 9 1+x^4+x^9 */ + "10010000001", /* 10 1+x^3+x^10 */ + "101000000001", /* 11 1+x^2+x^11 */ + "1100101000001", /* 12 1+x+x^4+x^6+x^12 */ + "11011000000001", /* 13 1+x+x^3+x^4+x^13 */ + "110000100010001", /* 14 1+x+x^6+x^10+x^14 */ + "1100000000000001", /* 15 1+x+x^15 */ + "11010000000010001" /* 16 1+x+x^3+x^12+x^16 */ +}; + + +/* + * To speed up computations, we have tables for logarithm, exponent + * and inverse of a number. If GF_BITS <= 8, we use a table for + * multiplication as well (it takes 64K, no big deal even on a PDA, + * especially because it can be pre-initialized an put into a ROM!), + * otherwhise we use a table of logarithms. + * In any case the macro gf_mul(x,y) takes care of multiplications. + */ + +static gf gf_exp[2*GF_SIZE]; /* index->poly form conversion table */ +static int gf_log[GF_SIZE + 1]; /* Poly->index form conversion table */ +static gf inverse[GF_SIZE+1]; /* inverse of field elem. */ + /* inv[\alpha**i]=\alpha**(GF_SIZE-i-1) */ + +/* + * modnn(x) computes x % GF_SIZE, where GF_SIZE is 2**GF_BITS - 1, + * without a slow divide. + */ +static inline gf +modnn(int x) +{ + while (x >= GF_SIZE) { + x -= GF_SIZE; + x = (x >> GF_BITS) + (x & GF_SIZE); + } + return x; +} + +#define SWAP(a,b,t) {t tmp; tmp=a; a=b; b=tmp;} + +/* + * gf_mul(x,y) multiplies two numbers. If GF_BITS<=8, it is much + * faster to use a multiplication table. + * + * USE_GF_MULC, GF_MULC0(c) and GF_ADDMULC(x) can be used when multiplying + * many numbers by the same constant. In this case the first + * call sets the constant, and others perform the multiplications. + * A value related to the multiplication is held in a local variable + * declared with USE_GF_MULC . See usage in addmul1(). + */ +#if (GF_BITS <= 8) +static gf gf_mul_table[GF_SIZE + 1][GF_SIZE + 1]; + +#define gf_mul(x,y) gf_mul_table[x][y] + +#define USE_GF_MULC gf * __gf_mulc_ +#define GF_MULC0(c) __gf_mulc_ = gf_mul_table[c] +#define GF_ADDMULC(dst, x) dst ^= __gf_mulc_[x] + +static void +init_mul_table() +{ + int i, j; + for (i=0; i< GF_SIZE+1; i++) + for (j=0; j< GF_SIZE+1; j++) + gf_mul_table[i][j] = gf_exp[modnn(gf_log[i] + gf_log[j]) ] ; + + for (j=0; j< GF_SIZE+1; j++) + gf_mul_table[0][j] = gf_mul_table[j][0] = 0; +} +#else /* GF_BITS > 8 */ +static inline gf +gf_mul(x,y) +{ + if ( (x) == 0 || (y)==0 ) return 0; + + return gf_exp[gf_log[x] + gf_log[y] ] ; +} +#define init_mul_table() + +#define USE_GF_MULC register gf * __gf_mulc_ +#define GF_MULC0(c) __gf_mulc_ = &gf_exp[ gf_log[c] ] +#define GF_ADDMULC(dst, x) { if (x) dst ^= __gf_mulc_[ gf_log[x] ] ; } +#endif + +/* + * Generate GF(2**m) from the irreducible polynomial p(X) in p[0]..p[m] + * Lookup tables: + * index->polynomial form gf_exp[] contains j= \alpha^i; + * polynomial form -> index form gf_log[ j = \alpha^i ] = i + * \alpha=x is the primitive element of GF(2^m) + * + * For efficiency, gf_exp[] has size 2*GF_SIZE, so that a simple + * multiplication of two numbers can be resolved without calling modnn + */ + +/* + * i use malloc so many times, it is easier to put checks all in + * one place. + */ +static void * +my_malloc(int sz, const char *err_string) +{ + void *p = malloc( sz ); + if (p == NULL) { + fprintf(stderr, "-- malloc failure allocating %s\n", err_string); + exit(1) ; + } + return p ; +} + +#define NEW_GF_MATRIX(rows, cols) \ + (gf *)my_malloc(rows * cols * sizeof(gf), " ## __LINE__ ## " ) + +/* + * initialize the data structures used for computations in GF. + */ +static void +generate_gf(void) +{ + int i; + gf mask; + const char *Pp = allPp[GF_BITS] ; + + mask = 1; /* x ** 0 = 1 */ + gf_exp[GF_BITS] = 0; /* will be updated at the end of the 1st loop */ + /* + * first, generate the (polynomial representation of) powers of \alpha, + * which are stored in gf_exp[i] = \alpha ** i . + * At the same time build gf_log[gf_exp[i]] = i . + * The first GF_BITS powers are simply bits shifted to the left. + */ + for (i = 0; i < GF_BITS; i++, mask <<= 1 ) { + gf_exp[i] = mask; + gf_log[gf_exp[i]] = i; + /* + * If Pp[i] == 1 then \alpha ** i occurs in poly-repr + * gf_exp[GF_BITS] = \alpha ** GF_BITS + */ + if ( Pp[i] == '1' ) + gf_exp[GF_BITS] ^= mask; + } + /* + * now gf_exp[GF_BITS] = \alpha ** GF_BITS is complete, so can als + * compute its inverse. + */ + gf_log[gf_exp[GF_BITS]] = GF_BITS; + /* + * Poly-repr of \alpha ** (i+1) is given by poly-repr of + * \alpha ** i shifted left one-bit and accounting for any + * \alpha ** GF_BITS term that may occur when poly-repr of + * \alpha ** i is shifted. + */ + mask = 1 << (GF_BITS - 1 ) ; + for (i = GF_BITS + 1; i < GF_SIZE; i++) { + if (gf_exp[i - 1] >= mask) + gf_exp[i] = gf_exp[GF_BITS] ^ ((gf_exp[i - 1] ^ mask) << 1); + else + gf_exp[i] = gf_exp[i - 1] << 1; + gf_log[gf_exp[i]] = i; + } + /* + * log(0) is not defined, so use a special value + */ + gf_log[0] = GF_SIZE ; + /* set the extended gf_exp values for fast multiply */ + for (i = 0 ; i < GF_SIZE ; i++) + gf_exp[i + GF_SIZE] = gf_exp[i] ; + + /* + * again special cases. 0 has no inverse. This used to + * be initialized to GF_SIZE, but it should make no difference + * since noone is supposed to read from here. + */ + inverse[0] = 0 ; + inverse[1] = 1; + for (i=2; i<=GF_SIZE; i++) + inverse[i] = gf_exp[GF_SIZE-gf_log[i]]; +} + +/* + * Various linear algebra operations that i use often. + */ + +/* + * addmul() computes dst[] = dst[] + c * src[] + * This is used often, so better optimize it! Currently the loop is + * unrolled 16 times, a good value for 486 and pentium-class machines. + * The case c=0 is also optimized, whereas c=1 is not. These + * calls are unfrequent in my typical apps so I did not bother. + * + * Note that gcc on + */ +#define addmul(dst, src, c, sz) \ + if (c != 0) addmul1(dst, src, c, sz) + +#define UNROLL 16 /* 1, 4, 8, 16 */ +static void +addmul1(gf *dst1, gf *src1, gf c, int sz) +{ + USE_GF_MULC ; + gf *dst = dst1, *src = src1 ; + gf *lim = &dst[sz - UNROLL + 1] ; + + GF_MULC0(c) ; + +#if (UNROLL > 1) /* unrolling by 8/16 is quite effective on the pentium */ + for (; dst < lim ; dst += UNROLL, src += UNROLL ) { + GF_ADDMULC( dst[0] , src[0] ); + GF_ADDMULC( dst[1] , src[1] ); + GF_ADDMULC( dst[2] , src[2] ); + GF_ADDMULC( dst[3] , src[3] ); +#if (UNROLL > 4) + GF_ADDMULC( dst[4] , src[4] ); + GF_ADDMULC( dst[5] , src[5] ); + GF_ADDMULC( dst[6] , src[6] ); + GF_ADDMULC( dst[7] , src[7] ); +#endif +#if (UNROLL > 8) + GF_ADDMULC( dst[8] , src[8] ); + GF_ADDMULC( dst[9] , src[9] ); + GF_ADDMULC( dst[10] , src[10] ); + GF_ADDMULC( dst[11] , src[11] ); + GF_ADDMULC( dst[12] , src[12] ); + GF_ADDMULC( dst[13] , src[13] ); + GF_ADDMULC( dst[14] , src[14] ); + GF_ADDMULC( dst[15] , src[15] ); +#endif + } +#endif + lim += UNROLL - 1 ; + for (; dst < lim; dst++, src++ ) /* final components */ + GF_ADDMULC( *dst , *src ); +} + +/* + * computes C = AB where A is n*k, B is k*m, C is n*m + */ +static void +matmul(gf *a, gf *b, gf *c, int n, int k, int m) +{ + int row, col, i ; + + for (row = 0; row < n ; row++) { + for (col = 0; col < m ; col++) { + gf *pa = &a[ row * k ]; + gf *pb = &b[ col ]; + gf acc = 0 ; + for (i = 0; i < k ; i++, pa++, pb += m ) + acc ^= gf_mul( *pa, *pb ) ; + c[ row * m + col ] = acc ; + } + } +} + +#ifdef DEBUGG +/* + * returns 1 if the square matrix is identiy + * (only for test) + */ +static int +is_identity(gf *m, int k) +{ + int row, col ; + for (row=0; row<k; row++) + for (col=0; col<k; col++) + if ( (row==col && *m != 1) || + (row!=col && *m != 0) ) + return 0 ; + else + m++ ; + return 1 ; +} +#endif /* debug */ + +/* + * invert_mat() takes a matrix and produces its inverse + * k is the size of the matrix. + * (Gauss-Jordan, adapted from Numerical Recipes in C) + * Return non-zero if singular. + */ +DEB( int pivloops=0; int pivswaps=0 ; /* diagnostic */) +static int +invert_mat(gf *src, int k) +{ + gf c, *p ; + int irow, icol, row, col, i, ix ; + + int error = 1 ; + int *indxc = (int*)my_malloc(k*sizeof(int), "indxc"); + int *indxr = (int*)my_malloc(k*sizeof(int), "indxr"); + int *ipiv = (int*)my_malloc(k*sizeof(int), "ipiv"); + gf *id_row = NEW_GF_MATRIX(1, k); + gf *temp_row = NEW_GF_MATRIX(1, k); + + bzero(id_row, k*sizeof(gf)); + DEB( pivloops=0; pivswaps=0 ; /* diagnostic */ ) + /* + * ipiv marks elements already used as pivots. + */ + for (i = 0; i < k ; i++) + ipiv[i] = 0 ; + + for (col = 0; col < k ; col++) { + gf *pivot_row ; + /* + * Zeroing column 'col', look for a non-zero element. + * First try on the diagonal, if it fails, look elsewhere. + */ + irow = icol = -1 ; + if (ipiv[col] != 1 && src[col*k + col] != 0) { + irow = col ; + icol = col ; + goto found_piv ; + } + for (row = 0 ; row < k ; row++) { + if (ipiv[row] != 1) { + for (ix = 0 ; ix < k ; ix++) { + DEB( pivloops++ ; ) + if (ipiv[ix] == 0) { + if (src[row*k + ix] != 0) { + irow = row ; + icol = ix ; + goto found_piv ; + } + } else if (ipiv[ix] > 1) { + fprintf(stderr, "singular matrix\n"); + goto fail ; + } + } + } + } + if (icol == -1) { + fprintf(stderr, "XXX pivot not found!\n"); + goto fail ; + } +found_piv: + ++(ipiv[icol]) ; + /* + * swap rows irow and icol, so afterwards the diagonal + * element will be correct. Rarely done, not worth + * optimizing. + */ + if (irow != icol) { + for (ix = 0 ; ix < k ; ix++ ) { + SWAP( src[irow*k + ix], src[icol*k + ix], gf) ; + } + } + indxr[col] = irow ; + indxc[col] = icol ; + pivot_row = &src[icol*k] ; + c = pivot_row[icol] ; + if (c == 0) { + fprintf(stderr, "singular matrix 2\n"); + goto fail ; + } + if (c != 1 ) { /* otherwhise this is a NOP */ + /* + * this is done often , but optimizing is not so + * fruitful, at least in the obvious ways (unrolling) + */ + DEB( pivswaps++ ; ) + c = inverse[ c ] ; + pivot_row[icol] = 1 ; + for (ix = 0 ; ix < k ; ix++ ) + pivot_row[ix] = gf_mul(c, pivot_row[ix] ); + } + /* + * from all rows, remove multiples of the selected row + * to zero the relevant entry (in fact, the entry is not zero + * because we know it must be zero). + * (Here, if we know that the pivot_row is the identity, + * we can optimize the addmul). + */ + id_row[icol] = 1; + if (bcmp(pivot_row, id_row, k*sizeof(gf)) != 0) { + for (p = src, ix = 0 ; ix < k ; ix++, p += k ) { + if (ix != icol) { + c = p[icol] ; + p[icol] = 0 ; + addmul(p, pivot_row, c, k ); + } + } + } + id_row[icol] = 0; + } /* done all columns */ + for (col = k-1 ; col >= 0 ; col-- ) { + if (indxr[col] <0 || indxr[col] >= k) + fprintf(stderr, "AARGH, indxr[col] %d\n", indxr[col]); + else if (indxc[col] <0 || indxc[col] >= k) + fprintf(stderr, "AARGH, indxc[col] %d\n", indxc[col]); + else + if (indxr[col] != indxc[col] ) { + for (row = 0 ; row < k ; row++ ) { + SWAP( src[row*k + indxr[col]], src[row*k + indxc[col]], gf) ; + } + } + } + error = 0 ; +fail: + free(indxc); + free(indxr); + free(ipiv); + free(id_row); + free(temp_row); + return error ; +} + +/* + * fast code for inverting a vandermonde matrix. + * XXX NOTE: It assumes that the matrix + * is not singular and _IS_ a vandermonde matrix. Only uses + * the second column of the matrix, containing the p_i's. + * + * Algorithm borrowed from "Numerical recipes in C" -- sec.2.8, but + * largely revised for my purposes. + * p = coefficients of the matrix (p_i) + * q = values of the polynomial (known) + */ + +int +invert_vdm(gf *src, int k) +{ + int i, j, row, col ; + gf *b, *c, *p; + gf t, xx ; + + if (k == 1) /* degenerate case, matrix must be p^0 = 1 */ + return 0 ; + /* + * c holds the coefficient of P(x) = Prod (x - p_i), i=0..k-1 + * b holds the coefficient for the matrix inversion + */ + c = NEW_GF_MATRIX(1, k); + b = NEW_GF_MATRIX(1, k); + + p = NEW_GF_MATRIX(1, k); + + for ( j=1, i = 0 ; i < k ; i++, j+=k ) { + c[i] = 0 ; + p[i] = src[j] ; /* p[i] */ + } + /* + * construct coeffs. recursively. We know c[k] = 1 (implicit) + * and start P_0 = x - p_0, then at each stage multiply by + * x - p_i generating P_i = x P_{i-1} - p_i P_{i-1} + * After k steps we are done. + */ + c[k-1] = p[0] ; /* really -p(0), but x = -x in GF(2^m) */ + for (i = 1 ; i < k ; i++ ) { + gf p_i = p[i] ; /* see above comment */ + for (j = k-1 - ( i - 1 ) ; j < k-1 ; j++ ) + c[j] ^= gf_mul( p_i, c[j+1] ) ; + c[k-1] ^= p_i ; + } + + for (row = 0 ; row < k ; row++ ) { + /* + * synthetic division etc. + */ + xx = p[row] ; + t = 1 ; + b[k-1] = 1 ; /* this is in fact c[k] */ + for (i = k-2 ; i >= 0 ; i-- ) { + b[i] = c[i+1] ^ gf_mul(xx, b[i+1]) ; + t = gf_mul(xx, t) ^ b[i] ; + } + for (col = 0 ; col < k ; col++ ) + src[col*k + row] = gf_mul(inverse[t], b[col] ); + } + free(c) ; + free(b) ; + free(p) ; + return 0 ; +} + +static int fec_initialized = 0 ; +static void +init_fec() +{ + TICK(ticks[0]); + generate_gf(); + TOCK(ticks[0]); + DDB(fprintf(stderr, "generate_gf took %ldus\n", ticks[0]);) + TICK(ticks[0]); + init_mul_table(); + TOCK(ticks[0]); + DDB(fprintf(stderr, "init_mul_table took %ldus\n", ticks[0]);) + fec_initialized = 1 ; +} + +/* + * This section contains the proper FEC encoding/decoding routines. + * The encoding matrix is computed starting with a Vandermonde matrix, + * and then transforming it into a systematic matrix. + */ + +#define FEC_MAGIC 0xFECC0DEC + +void +fec_free(struct fec_parms *p) +{ + if (p==NULL || + p->magic != ( ( (FEC_MAGIC ^ p->k) ^ p->n) ^ (unsigned long)(p->enc_matrix)) ) { + fprintf(stderr, "bad parameters to fec_free\n"); + return ; + } + free(p->enc_matrix); + free(p); +} + +/* + * create a new encoder, returning a descriptor. This contains k,n and + * the encoding matrix. + */ +struct fec_parms * +fec_new(int k, int n) +{ + int row, col ; + gf *p, *tmp_m ; + + struct fec_parms *retval ; + + if (fec_initialized == 0) + init_fec(); + + if (k > GF_SIZE + 1 || n > GF_SIZE + 1 || k > n ) { + fprintf(stderr, "Invalid parameters k %d n %d GF_SIZE %d\n", + k, n, GF_SIZE ); + return NULL ; + } + retval = (struct fec_parms *)my_malloc(sizeof(struct fec_parms), "new_code"); + retval->k = k ; + retval->n = n ; + retval->enc_matrix = NEW_GF_MATRIX(n, k); + retval->magic = ( ( FEC_MAGIC ^ k) ^ n) ^ (unsigned long)(retval->enc_matrix) ; + tmp_m = NEW_GF_MATRIX(n, k); + /* + * fill the matrix with powers of field elements, starting from 0. + * The first row is special, cannot be computed with exp. table. + */ + tmp_m[0] = 1 ; + for (col = 1; col < k ; col++) + tmp_m[col] = 0 ; + for (p = tmp_m + k, row = 0; row < n-1 ; row++, p += k) { + for ( col = 0 ; col < k ; col ++ ) + p[col] = gf_exp[modnn(row*col)]; + } + + /* + * quick code to build systematic matrix: invert the top + * k*k vandermonde matrix, multiply right the bottom n-k rows + * by the inverse, and construct the identity matrix at the top. + */ + TICK(ticks[3]); + invert_vdm(tmp_m, k); /* much faster than invert_mat */ + matmul(tmp_m + k*k, tmp_m, retval->enc_matrix + k*k, n - k, k, k); + /* + * the upper matrix is I so do not bother with a slow multiply + */ + bzero(retval->enc_matrix, k*k*sizeof(gf) ); + for (p = retval->enc_matrix, col = 0 ; col < k ; col++, p += k+1 ) + *p = 1 ; + free(tmp_m); + TOCK(ticks[3]); + + DDB(fprintf(stderr, "--- %ld us to build encoding matrix\n", + ticks[3]);) + DEB(pr_matrix(retval->enc_matrix, n, k, "encoding_matrix");) + return retval ; +} + +/* + * fec_encode accepts as input pointers to n data packets of size sz, + * and produces as output a packet pointed to by fec, computed + * with index "index". + */ +void +fec_encode(struct fec_parms *code, gf *src[], gf *fec, int index, int sz) +{ + int i, k = code->k ; + gf *p ; + + if (GF_BITS > 8) + sz /= 2 ; + + if (index < k) + bcopy(src[index], fec, sz*sizeof(gf) ) ; + else if (index < code->n) { + p = &(code->enc_matrix[index*k] ); + bzero(fec, sz*sizeof(gf)); + for (i = 0; i < k ; i++) + addmul(fec, src[i], p[i], sz ) ; + } else + fprintf(stderr, "Invalid index %d (max %d)\n", + index, code->n - 1 ); +} + +/* + * shuffle move src packets in their position + */ +static int +shuffle(gf *pkt[], int index[], int k) +{ + int i; + + for ( i = 0 ; i < k ; ) { + if (index[i] >= k || index[i] == i) + i++ ; + else { + /* + * put pkt in the right position (first check for conflicts). + */ + int c = index[i] ; + + if (index[c] == c) { + DEB(fprintf(stderr, "\nshuffle, error at %d\n", i);) + return 1 ; + } + SWAP(index[i], index[c], int) ; + SWAP(pkt[i], pkt[c], gf *) ; + } + } + DEB( /* just test that it works... */ + for ( i = 0 ; i < k ; i++ ) { + if (index[i] < k && index[i] != i) { + fprintf(stderr, "shuffle: after\n"); + for (i=0; i<k ; i++) fprintf(stderr, "%3d ", index[i]); + fprintf(stderr, "\n"); + return 1 ; + } + } + ) + return 0 ; +} + +/* + * build_decode_matrix constructs the encoding matrix given the + * indexes. The matrix must be already allocated as + * a vector of k*k elements, in row-major order + */ +static gf * +build_decode_matrix(struct fec_parms *code, gf *pkt[], int index[]) +{ + int i , k = code->k ; + gf *p, *matrix = NEW_GF_MATRIX(k, k); + + TICK(ticks[9]); + for (i = 0, p = matrix ; i < k ; i++, p += k ) { +#if 1 /* this is simply an optimization, not very useful indeed */ + if (index[i] < k) { + bzero(p, k*sizeof(gf) ); + p[i] = 1 ; + } else +#endif + if (index[i] < code->n ) + bcopy( &(code->enc_matrix[index[i]*k]), p, k*sizeof(gf) ); + else { + fprintf(stderr, "decode: invalid index %d (max %d)\n", + index[i], code->n - 1 ); + free(matrix) ; + return NULL ; + } + } + TICK(ticks[9]); + if (invert_mat(matrix, k)) { + free(matrix); + matrix = NULL ; + } + TOCK(ticks[9]); + return matrix ; +} + +/* + * fec_decode receives as input a vector of packets, the indexes of + * packets, and produces the correct vector as output. + * + * Input: + * code: pointer to code descriptor + * pkt: pointers to received packets. They are modified + * to store the output packets (in place) + * index: pointer to packet indexes (modified) + * sz: size of each packet + */ +int +fec_decode(struct fec_parms *code, gf *pkt[], int index[], int sz) +{ + gf *m_dec ; + gf **new_pkt ; + int row, col , k = code->k ; + + if (GF_BITS > 8) + sz /= 2 ; + + if (shuffle(pkt, index, k)) /* error if true */ + return 1 ; + m_dec = build_decode_matrix(code, pkt, index); + + if (m_dec == NULL) + return 1 ; /* error */ + /* + * do the actual decoding + */ + new_pkt = (gf**)my_malloc (k * sizeof (gf * ), "new pkt pointers" ); + for (row = 0 ; row < k ; row++ ) { + if (index[row] >= k) { + new_pkt[row] = (gf*) my_malloc (sz * sizeof (gf), "new pkt buffer" ); + bzero(new_pkt[row], sz * sizeof(gf) ) ; + for (col = 0 ; col < k ; col++ ) + addmul(new_pkt[row], pkt[col], m_dec[row*k + col], sz) ; + } + } + /* + * move pkts to their final destination + */ + for (row = 0 ; row < k ; row++ ) { + if (index[row] >= k) { + bcopy(new_pkt[row], pkt[row], sz*sizeof(gf)); + free(new_pkt[row]); + } + } + free(new_pkt); + free(m_dec); + + return 0; +} diff --git a/libtransport/src/core/fec.h b/libtransport/src/core/fec.h new file mode 100644 index 000000000..8234057a7 --- /dev/null +++ b/libtransport/src/core/fec.h @@ -0,0 +1,65 @@ +/* + * fec.c -- forward error correction based on Vandermonde matrices + * 980614 + * (C) 1997-98 Luigi Rizzo (luigi@iet.unipi.it) + * + * Portions derived from code by Phil Karn (karn@ka9q.ampr.org), + * Robert Morelos-Zaragoza (robert@spectra.eng.hawaii.edu) and Hari + * Thirumoorthy (harit@spectra.eng.hawaii.edu), Aug 1995 + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, + * OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR + * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT + * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY + * OF SUCH DAMAGE. + */ + +/* + * The following parameter defines how many bits are used for + * field elements. The code supports any value from 2 to 16 + * but fastest operation is achieved with 8 bit elements + * This is the only parameter you may want to change. + */ +#ifndef GF_BITS +#define GF_BITS 8 /* code over GF(2**GF_BITS) - change to suit */ +#endif + +#if (GF_BITS <= 8) +typedef unsigned char gf; +#else +typedef unsigned short gf; +#endif + +#define GF_SIZE ((1 << GF_BITS) - 1) /* powers of \alpha */ + +struct fec_parms { + unsigned long magic ; + int k, n ; /* parameters of the code */ + gf *enc_matrix ; +}; + +void fec_free(struct fec_parms *p) ; +struct fec_parms *fec_new(int k, int n) ; + +void fec_encode(struct fec_parms *code, gf *src[], gf *fec, int index, int sz); +int fec_decode(struct fec_parms *code, gf *pkt[], int index[], int sz); + +/* end of file */ diff --git a/libtransport/src/core/global_configuration.cc b/libtransport/src/core/global_configuration.cc new file mode 100644 index 000000000..e0b6c040a --- /dev/null +++ b/libtransport/src/core/global_configuration.cc @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/global_configuration.h> +#include <hicn/transport/core/connector.h> +#include <hicn/transport/utils/log.h> + +#include <libconfig.h++> +#include <map> + +namespace transport { +namespace core { + +GlobalConfiguration::GlobalConfiguration() {} + +bool GlobalConfiguration::parseTransportConfig(const std::string& path) { + using namespace libconfig; + Config cfg; + + try { + cfg.readFile(path.c_str()); + } catch (const FileIOException& fioex) { + TRANSPORT_LOGE("I/O error while reading file."); + return false; + } catch (const ParseException& pex) { + TRANSPORT_LOGE("Parse error at %s:%d - %s", pex.getFile(), pex.getLine(), + pex.getError()); + return false; + } + + Setting& root = cfg.getRoot(); + + /** + * Iterate over sections. Best thing to do here would be to have other + * components of the program registering a callback here, to parse their + * section of the configuration file. + */ + for (auto section = root.begin(); section != root.end(); section++) { + std::string name = section->getName(); + std::error_code ec; + TRANSPORT_LOGD("Parsing Section: %s", name.c_str()); + + auto it = configuration_parsers_.find(name); + if (it != configuration_parsers_.end() && !it->second.first) { + TRANSPORT_LOGD("Found valid configuration parser"); + it->second.second(*section, ec); + it->second.first = true; + } + } + + return true; +} + +void GlobalConfiguration::parseConfiguration(const std::string& path) { + // Check if an environment variable with the configuration path exists. COnf + // variable comes first. + std::unique_lock<std::mutex> lck(cp_mtx_); + if (const char* env_c = std::getenv(GlobalConfiguration::conf_file)) { + parseTransportConfig(env_c); + } else if (!path.empty()) { + conf_file_path_ = path; + parseTransportConfig(conf_file_path_); + } else { + TRANSPORT_LOGD( + "Called parseConfiguration but no configuration file was provided."); + } +} + +void GlobalConfiguration::registerConfigurationSetter( + const std::string& key, const SetCallback& set_callback) { + std::unique_lock<std::mutex> lck(cp_mtx_); + if (configuration_setters_.find(key) != configuration_setters_.end()) { + TRANSPORT_LOGW( + "Trying to register configuration setter %s twice. Ignoring second " + "registration attempt.", + key.c_str()); + } else { + configuration_setters_.emplace(key, set_callback); + } +} + +void GlobalConfiguration::registerConfigurationGetter( + const std::string& key, const GetCallback& get_callback) { + std::unique_lock<std::mutex> lck(cp_mtx_); + if (configuration_getters_.find(key) != configuration_getters_.end()) { + TRANSPORT_LOGW( + "Trying to register configuration getter %s twice. Ignoring second " + "registration attempt.", + key.c_str()); + } else { + configuration_getters_.emplace(key, get_callback); + } +} + +void GlobalConfiguration::registerConfigurationParser( + const std::string& key, const ParserCallback& parser) { + std::unique_lock<std::mutex> lck(cp_mtx_); + if (configuration_parsers_.find(key) != configuration_parsers_.end()) { + TRANSPORT_LOGW( + "Trying to register configuration key %s twice. Ignoring second " + "registration attempt.", + key.c_str()); + } else { + configuration_parsers_.emplace(key, std::make_pair(false, parser)); + + // Trigger a parsing of the configuration. + if (!conf_file_path_.empty()) { + parseTransportConfig(conf_file_path_); + } + } +} + +void GlobalConfiguration::unregisterConfigurationParser( + const std::string& key) { + std::unique_lock<std::mutex> lck(cp_mtx_); + auto it = configuration_parsers_.find(key); + if (it != configuration_parsers_.end()) { + configuration_parsers_.erase(it); + } +} + +void GlobalConfiguration::unregisterConfigurationSetter( + const std::string& key) { + std::unique_lock<std::mutex> lck(cp_mtx_); + auto it = configuration_setters_.find(key); + if (it != configuration_setters_.end()) { + configuration_setters_.erase(it); + } +} + +void GlobalConfiguration::unregisterConfigurationGetter( + const std::string& key) { + std::unique_lock<std::mutex> lck(cp_mtx_); + auto it = configuration_getters_.find(key); + if (it != configuration_getters_.end()) { + configuration_getters_.erase(it); + } +} + +void GlobalConfiguration::getConfiguration( + interface::global_config::ConfigurationObject& configuration_object, + std::error_code& ec) { + auto it = configuration_getters_.find(configuration_object.getKey()); + + if (it != configuration_getters_.end()) { + it->second(configuration_object, ec); + } +} + +void GlobalConfiguration::setConfiguration( + const interface::global_config::ConfigurationObject& configuration_object, + std::error_code& ec) { + auto it = configuration_setters_.find(configuration_object.getKey()); + + if (it != configuration_setters_.end()) { + it->second(configuration_object, ec); + } +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/global_configuration.h b/libtransport/src/core/global_configuration.h new file mode 100644 index 000000000..dcc8d94e3 --- /dev/null +++ b/libtransport/src/core/global_configuration.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/global_conf_interface.h> +#include <hicn/transport/utils/singleton.h> + +#include <functional> +#include <map> +#include <memory> +#include <mutex> +#include <system_error> + +namespace libconfig { +class Setting; +} + +namespace transport { +namespace core { + +/** + * Class holding workflow for global configuration. + * This class does not contains the actual configuration, which is rather stored + * inside the modules to be configured. This class contains the handlers to call + * for getting/setting the configurations and to parse the corresponding + * sections of the configuration file. Each class register 3 callbacks: one to + * parse conf section and 2 to set/get the configuration through programming + * interface. + */ +class GlobalConfiguration : public utils::Singleton<GlobalConfiguration> { + static const constexpr char *conf_file = "TRANSPORT_CONFIG"; + friend class utils::Singleton<GlobalConfiguration>; + + public: + /** + * This callback will be called by GlobalConfiguration in + * + */ + using ParserCallback = std::function<void(const libconfig::Setting &config, + std::error_code &ec)>; + using GetCallback = + std::function<void(interface::global_config::ConfigurationObject &object, + std::error_code &ec)>; + + using SetCallback = std::function<void( + const interface::global_config::ConfigurationObject &object, + std::error_code &ec)>; + + ~GlobalConfiguration() = default; + + public: + void parseConfiguration(const std::string &path); + + void registerConfigurationParser(const std::string &key, + const ParserCallback &parser); + + void registerConfigurationSetter(const std::string &key, + const SetCallback &set_callback); + void registerConfigurationGetter(const std::string &key, + const GetCallback &get_callback); + + void unregisterConfigurationParser(const std::string &key); + + void unregisterConfigurationSetter(const std::string &key); + + void unregisterConfigurationGetter(const std::string &key); + + void getConfiguration( + interface::global_config::ConfigurationObject &configuration_object, + std::error_code &ec); + void setConfiguration( + const interface::global_config::ConfigurationObject &configuration_object, + std::error_code &ec); + + private: + GlobalConfiguration(); + std::string conf_file_path_; + bool parseTransportConfig(const std::string &path); + + private: + std::mutex cp_mtx_; + using ParserPair = std::pair<bool, ParserCallback>; + std::map<std::string, ParserPair> configuration_parsers_; + std::map<std::string, GetCallback> configuration_getters_; + std::map<std::string, SetCallback> configuration_setters_; +}; + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/interest.cc b/libtransport/src/core/interest.cc index 9ee662615..06cbe9f81 100644 --- a/libtransport/src/core/interest.cc +++ b/libtransport/src/core/interest.cc @@ -31,8 +31,9 @@ namespace transport { namespace core { -Interest::Interest(const Name &interest_name, Packet::Format format) - : Packet(format) { +Interest::Interest(const Name &interest_name, Packet::Format format, + std::size_t additional_header_size) + : Packet(format, additional_header_size) { if (hicn_interest_set_name(format_, packet_start_, interest_name.getConstStructReference()) < 0) { throw errors::MalformedPacketException(); @@ -45,20 +46,14 @@ Interest::Interest(const Name &interest_name, Packet::Format format) } #ifdef __ANDROID__ -Interest::Interest(hicn_format_t format) : Interest(Name("0::0|0"), format) {} +Interest::Interest(hicn_format_t format, std::size_t additional_header_size) + : Interest(Name("0::0|0"), format, additional_header_size) {} #else -Interest::Interest(hicn_format_t format) : Interest(base_name, format) {} +Interest::Interest(hicn_format_t format, std::size_t additional_header_size) + : Interest(base_name, format, additional_header_size) {} #endif -Interest::Interest(const uint8_t *buffer, std::size_t size) - : Packet(buffer, size) { - if (hicn_interest_get_name(format_, packet_start_, - name_.getStructReference()) < 0) { - throw errors::MalformedPacketException(); - } -} - -Interest::Interest(MemBufPtr &&buffer) : Packet(std::move(buffer)) { +Interest::Interest(MemBuf &&buffer) : Packet(std::move(buffer)) { if (hicn_interest_get_name(format_, packet_start_, name_.getStructReference()) < 0) { throw errors::MalformedPacketException(); @@ -70,6 +65,14 @@ Interest::Interest(Interest &&other_interest) name_ = std::move(other_interest.name_); } +Interest::Interest(const Interest &other_interest) : Packet(other_interest) { + name_ = other_interest.name_; +} + +Interest &Interest::operator=(const Interest &other) { + return (Interest &)Packet::operator=(other); +} + Interest::~Interest() {} const Name &Interest::getName() const { @@ -152,6 +155,59 @@ void Interest::resetForHash() { } } +bool Interest::hasManifest() { + return (getPayloadType() == PayloadType::MANIFEST); +} + +void Interest::appendSuffix(std::uint32_t suffix) { + if (TRANSPORT_EXPECT_FALSE(suffix_set_.empty())) { + setPayloadType(PayloadType::MANIFEST); + } + + suffix_set_.emplace(suffix); +} + +void Interest::encodeSuffixes() { + if (!hasManifest()) { + return; + } + + // We assume interest does not hold signature for the moment. + auto int_manifest_header = + (InterestManifestHeader *)(writableData() + headerSize()); + int_manifest_header->n_suffixes = suffix_set_.size(); + std::size_t additional_length = + int_manifest_header->n_suffixes * sizeof(uint32_t); + + uint32_t *suffix = (uint32_t *)(int_manifest_header + 1); + for (auto it = suffix_set_.begin(); it != suffix_set_.end(); it++, suffix++) { + *suffix = *it; + } + + updateLength(additional_length); +} + +uint32_t *Interest::firstSuffix() { + if (!hasManifest()) { + return nullptr; + } + + auto ret = (InterestManifestHeader *)(writableData() + headerSize()); + ret += 1; + + return (uint32_t *)ret; +} + +uint32_t Interest::numberOfSuffixes() { + if (!hasManifest()) { + return 0; + } + + auto header = (InterestManifestHeader *)(writableData() + headerSize()); + + return header->n_suffixes; +} + } // end namespace core } // end namespace transport diff --git a/libtransport/src/core/io_module.cc b/libtransport/src/core/io_module.cc new file mode 100644 index 000000000..fef0c1504 --- /dev/null +++ b/libtransport/src/core/io_module.cc @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <dlfcn.h> +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/utils/log.h> + +#ifdef ANDROID +#include <io_modules/udp/hicn_forwarder_module.h> +#endif + +#include <deque> + +namespace transport { +namespace core { + +IoModule::~IoModule() {} + +IoModule *IoModule::load(const char *module_name) { +#ifdef ANDROID + return new HicnForwarderModule(); +#else + void *handle = 0; + IoModule *module = 0; + IoModule *(*creator)(void) = 0; + const char *error = 0; + + // open module + handle = dlopen(module_name, RTLD_NOW); + if (!handle) { + if ((error = dlerror()) != 0) { + TRANSPORT_LOGE("%s", error); + } + return 0; + } + + // link factory method + creator = (IoModule * (*)(void)) dlsym(handle, "create_module"); + if (!creator) { + if ((error = dlerror()) != 0) { + TRANSPORT_LOGE("%s", error); + return 0; + } + } + + // create object and return it + module = (*creator)(); + module->handle_ = handle; + + return module; +#endif +} + +bool IoModule::unload(IoModule *module) { + if (!module) { + return false; + } + +#ifdef ANDROID + delete module; +#else + // destroy object and close module + void *handle = module->handle_; + delete module; + dlclose(handle); +#endif + + return true; +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/local_connector.cc b/libtransport/src/core/local_connector.cc new file mode 100644 index 000000000..f0e36a3d7 --- /dev/null +++ b/libtransport/src/core/local_connector.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/local_connector.h> +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/interest.h> +#include <hicn/transport/errors/not_implemented_exception.h> +#include <hicn/transport/utils/log.h> + +#include <asio/io_service.hpp> + +namespace transport { +namespace core { + +LocalConnector::~LocalConnector() {} + +void LocalConnector::close() { state_ = State::CLOSED; } + +void LocalConnector::send(Packet &packet) { + if (!isConnected()) { + return; + } + + TRANSPORT_LOGD("Sending packet to local socket."); + io_service_.get().post([this, p{packet.shared_from_this()}]() mutable { + receive_callback_(this, *p, std::make_error_code(std::errc(0))); + }); +} + +void LocalConnector::send(const uint8_t *packet, std::size_t len) { + throw errors::NotImplementedException(); +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/local_connector.h b/libtransport/src/core/local_connector.h new file mode 100644 index 000000000..b0daa4f53 --- /dev/null +++ b/libtransport/src/core/local_connector.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/connector.h> +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/utils/move_wrapper.h> +#include <hicn/transport/utils/shared_ptr_utils.h> +#include <io_modules/forwarder/errors.h> + +#ifndef ASIO_STANDALONE +#define ASIO_STANDALONE +#endif +#include <asio/io_service.hpp> + +namespace transport { +namespace core { + +class LocalConnector : public Connector { + public: + template <typename ReceiveCallback, typename SentCallback, typename OnClose, + typename OnReconnect> + LocalConnector(asio::io_service &io_service, + ReceiveCallback &&receive_callback, SentCallback &&packet_sent, + OnClose &&close_callback, OnReconnect &&on_reconnect) + : Connector(receive_callback, packet_sent, close_callback, on_reconnect), + io_service_(io_service), + io_service_work_(io_service_.get()) { + state_ = State::CONNECTED; + } + + ~LocalConnector() override; + + void send(Packet &packet) override; + + void send(const uint8_t *packet, std::size_t len) override; + + void close() override; + + auto shared_from_this() { return utils::shared_from(this); } + + private: + std::reference_wrapper<asio::io_service> io_service_; + asio::io_service::work io_service_work_; + std::string name_; +}; + +} // namespace core +} // namespace transport diff --git a/libtransport/src/core/manifest.h b/libtransport/src/core/manifest.h index eadfed752..9b25ebd67 100644 --- a/libtransport/src/core/manifest.h +++ b/libtransport/src/core/manifest.h @@ -15,11 +15,10 @@ #pragma once +#include <core/manifest_format.h> #include <hicn/transport/core/content_object.h> #include <hicn/transport/core/name.h> -#include <core/manifest_format.h> - #include <set> namespace transport { @@ -36,18 +35,20 @@ class Manifest : public Base { "Base must inherit from packet!"); public: + // core::ContentObjectManifest::Ptr + using Encoder = typename FormatTraits::Encoder; using Decoder = typename FormatTraits::Decoder; Manifest(std::size_t signature_size = 0) - : Base(HF_INET6_TCP_AH), + : Base(HF_INET6_TCP_AH, signature_size), encoder_(*this, signature_size), decoder_(*this) { Base::setPayloadType(PayloadType::MANIFEST); } Manifest(const core::Name &name, std::size_t signature_size = 0) - : Base(name, HF_INET6_TCP_AH), + : Base(name, HF_INET6_TCP_AH, signature_size), encoder_(*this, signature_size), decoder_(*this) { Base::setPayloadType(PayloadType::MANIFEST); @@ -55,7 +56,9 @@ class Manifest : public Base { template <typename T> Manifest(T &&base) - : Base(std::forward<T &&>(base)), encoder_(*this), decoder_(*this) { + : Base(std::forward<T &&>(base)), + encoder_(*this, 0, false), + decoder_(*this) { Base::setPayloadType(PayloadType::MANIFEST); } @@ -96,13 +99,13 @@ class Manifest : public Base { return *this; } - Manifest &setHashAlgorithm(utils::CryptoHashType hash_algorithm) { + Manifest &setHashAlgorithm(auth::CryptoHashType hash_algorithm) { hash_algorithm_ = hash_algorithm; encoder_.setHashAlgorithm(hash_algorithm_); return *this; } - utils::CryptoHashType getHashAlgorithm() { return hash_algorithm_; } + auth::CryptoHashType getHashAlgorithm() { return hash_algorithm_; } ManifestType getManifestType() const { return manifest_type_; } @@ -138,7 +141,7 @@ class Manifest : public Base { protected: ManifestType manifest_type_; - utils::CryptoHashType hash_algorithm_; + auth::CryptoHashType hash_algorithm_; bool is_last_; Encoder encoder_; diff --git a/libtransport/src/core/manifest_format.h b/libtransport/src/core/manifest_format.h index 36d23f99b..b759942cb 100644 --- a/libtransport/src/core/manifest_format.h +++ b/libtransport/src/core/manifest_format.h @@ -15,8 +15,8 @@ #pragma once +#include <hicn/transport/auth/crypto_hasher.h> #include <hicn/transport/core/name.h> -#include <hicn/transport/security/crypto_hasher.h> #include <cinttypes> #include <type_traits> @@ -63,8 +63,10 @@ template <typename T> struct format_traits { using Encoder = typename T::Encoder; using Decoder = typename T::Decoder; + using Hash = typename T::Hash; using HashType = typename T::HashType; - using HashList = typename T::HashList; + using Suffix = typename T::Suffix; + using SuffixList = typename T::SuffixList; }; class Packet; @@ -86,7 +88,7 @@ class ManifestEncoder { return static_cast<Implementation &>(*this).setManifestTypeImpl(type); } - ManifestEncoder &setHashAlgorithm(utils::CryptoHashType hash) { + ManifestEncoder &setHashAlgorithm(auth::CryptoHashType hash) { return static_cast<Implementation &>(*this).setHashAlgorithmImpl(hash); } @@ -160,7 +162,7 @@ class ManifestDecoder { return static_cast<const Implementation &>(*this).getManifestTypeImpl(); } - utils::CryptoHashType getHashAlgorithm() const { + auth::CryptoHashType getHashAlgorithm() const { return static_cast<const Implementation &>(*this).getHashAlgorithmImpl(); } diff --git a/libtransport/src/core/manifest_format_fixed.cc b/libtransport/src/core/manifest_format_fixed.cc index ca80c38b1..55280b460 100644 --- a/libtransport/src/core/manifest_format_fixed.cc +++ b/libtransport/src/core/manifest_format_fixed.cc @@ -13,49 +13,50 @@ * limitations under the License. */ +#include <core/manifest_format_fixed.h> #include <hicn/transport/core/packet.h> #include <hicn/transport/utils/literals.h> -#include <core/manifest_format_fixed.h> - namespace transport { namespace core { // TODO use preallocated pool of membufs FixedManifestEncoder::FixedManifestEncoder(Packet &packet, - std::size_t signature_size) + std::size_t signature_size, + bool clear) : packet_(packet), - max_size_(Packet::default_mtu - packet_.headerSize() - signature_size), - manifest_( - utils::MemBuf::create(Packet::default_mtu - packet_.headerSize())), - manifest_header_( - reinterpret_cast<ManifestHeader *>(manifest_->writableData())), - manifest_entries_(reinterpret_cast<ManifestEntry *>( - manifest_->writableData() + sizeof(ManifestHeader))), + max_size_(Packet::default_mtu - packet_.headerSize()), + manifest_header_(reinterpret_cast<ManifestHeader *>( + packet_.writableData() + packet_.headerSize())), + manifest_entries_( + reinterpret_cast<ManifestEntry *>(manifest_header_ + 1)), current_entry_(0), signature_size_(signature_size) { - *manifest_header_ = {0}; + if (clear) { + *manifest_header_ = {0}; + } } FixedManifestEncoder::~FixedManifestEncoder() {} FixedManifestEncoder &FixedManifestEncoder::encodeImpl() { - manifest_->append(sizeof(ManifestHeader) + - manifest_header_->number_of_entries * - sizeof(ManifestEntry)); - packet_.appendPayload(std::move(manifest_)); + packet_.append(sizeof(ManifestHeader) + + manifest_header_->number_of_entries * sizeof(ManifestEntry)); + packet_.updateLength(); return *this; } FixedManifestEncoder &FixedManifestEncoder::clearImpl() { - manifest_ = utils::MemBuf::create(Packet::default_mtu - packet_.headerSize() - - signature_size_); + packet_.trimEnd(sizeof(ManifestHeader) + + manifest_header_->number_of_entries * sizeof(ManifestEntry)); + current_entry_ = 0; + *manifest_header_ = {0}; return *this; } FixedManifestEncoder &FixedManifestEncoder::setHashAlgorithmImpl( - utils::CryptoHashType algorithm) { + auth::CryptoHashType algorithm) { manifest_header_->hash_algorithm = static_cast<uint8_t>(algorithm); return *this; } @@ -83,7 +84,7 @@ FixedManifestEncoder &FixedManifestEncoder::setBaseNameImpl( } FixedManifestEncoder &FixedManifestEncoder::addSuffixAndHashImpl( - uint32_t suffix, const utils::CryptoHash &hash) { + uint32_t suffix, const auth::CryptoHash &hash) { auto _hash = hash.getDigest<std::uint8_t>(); addSuffixHashBytes(suffix, _hash.data(), _hash.length()); return *this; @@ -170,8 +171,8 @@ ManifestType FixedManifestDecoder::getManifestTypeImpl() const { return static_cast<ManifestType>(manifest_header_->manifest_type); } -utils::CryptoHashType FixedManifestDecoder::getHashAlgorithmImpl() const { - return static_cast<utils::CryptoHashType>(manifest_header_->hash_algorithm); +auth::CryptoHashType FixedManifestDecoder::getHashAlgorithmImpl() const { + return static_cast<auth::CryptoHashType>(manifest_header_->hash_algorithm); } NextSegmentCalculationStrategy diff --git a/libtransport/src/core/manifest_format_fixed.h b/libtransport/src/core/manifest_format_fixed.h index 1d7cd7d32..56ad4ef6d 100644 --- a/libtransport/src/core/manifest_format_fixed.h +++ b/libtransport/src/core/manifest_format_fixed.h @@ -15,9 +15,8 @@ #pragma once -#include <hicn/transport/core/packet.h> - #include <core/manifest_format.h> +#include <hicn/transport/core/packet.h> #include <string> @@ -53,8 +52,10 @@ class Packet; struct Fixed { using Encoder = FixedManifestEncoder; using Decoder = FixedManifestDecoder; - using HashType = utils::CryptoHash; - using SuffixList = std::list<std::pair<std::uint32_t, std::uint8_t *>>; + using Hash = auth::CryptoHash; + using HashType = auth::CryptoHashType; + using Suffix = uint32_t; + using SuffixList = std::list<std::pair<uint32_t, uint8_t *>>; }; struct Flags { @@ -84,7 +85,8 @@ static const constexpr std::uint8_t manifest_version = 1; class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { public: - FixedManifestEncoder(Packet &packet, std::size_t signature_size = 0); + FixedManifestEncoder(Packet &packet, std::size_t signature_size = 0, + bool clear = true); ~FixedManifestEncoder(); @@ -94,7 +96,7 @@ class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { FixedManifestEncoder &setManifestTypeImpl(ManifestType manifest_type); - FixedManifestEncoder &setHashAlgorithmImpl(utils::CryptoHashType algorithm); + FixedManifestEncoder &setHashAlgorithmImpl(Fixed::HashType algorithm); FixedManifestEncoder &setNextSegmentCalculationStrategyImpl( NextSegmentCalculationStrategy strategy); @@ -102,7 +104,7 @@ class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { FixedManifestEncoder &setBaseNameImpl(const core::Name &base_name); FixedManifestEncoder &addSuffixAndHashImpl(uint32_t suffix, - const utils::CryptoHash &hash); + const Fixed::Hash &hash); FixedManifestEncoder &setIsFinalManifestImpl(bool is_last); @@ -125,7 +127,6 @@ class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { Packet &packet_; std::size_t max_size_; - std::unique_ptr<utils::MemBuf> manifest_; ManifestHeader *manifest_header_; ManifestEntry *manifest_entries_; std::size_t current_entry_; @@ -144,7 +145,7 @@ class FixedManifestDecoder : public ManifestDecoder<FixedManifestDecoder> { ManifestType getManifestTypeImpl() const; - utils::CryptoHashType getHashAlgorithmImpl() const; + Fixed::HashType getHashAlgorithmImpl() const; NextSegmentCalculationStrategy getNextSegmentCalculationStrategyImpl() const; diff --git a/libtransport/src/core/manifest_inline.h b/libtransport/src/core/manifest_inline.h index dedf82b45..fcb1d214f 100644 --- a/libtransport/src/core/manifest_inline.h +++ b/libtransport/src/core/manifest_inline.h @@ -30,8 +30,12 @@ class ManifestInline : public Manifest<Base, FormatTraits, ManifestInline<Base, FormatTraits>> { using ManifestBase = Manifest<Base, FormatTraits, ManifestInline<Base, FormatTraits>>; + + using Hash = typename FormatTraits::Hash; using HashType = typename FormatTraits::HashType; + using Suffix = typename FormatTraits::Suffix; using SuffixList = typename FormatTraits::SuffixList; + using HashEntry = std::pair<auth::CryptoHashType, std::vector<uint8_t>>; public: ManifestInline() : ManifestBase() {} @@ -44,7 +48,7 @@ class ManifestInline static TRANSPORT_ALWAYS_INLINE ManifestInline *createManifest( const core::Name &manifest_name, ManifestVersion version, - ManifestType type, utils::CryptoHashType algorithm, bool is_last, + ManifestType type, auth::CryptoHashType algorithm, bool is_last, const Name &base_name, NextSegmentCalculationStrategy strategy, std::size_t signature_size) { auto manifest = new ManifestInline(manifest_name, signature_size); @@ -84,7 +88,7 @@ class ManifestInline const Name &getBaseName() { return base_name_; } - ManifestInline &addSuffixHash(uint32_t suffix, const HashType &hash) { + ManifestInline &addSuffixHash(Suffix suffix, const Hash &hash) { ManifestBase::encoder_.addSuffixAndHash(suffix, hash); return *this; } @@ -104,12 +108,35 @@ class ManifestInline return next_segment_strategy_; } + // Convert several manifests into a single map from suffixes to packet hashes. + // All manifests must have been decoded beforehand. + static std::unordered_map<Suffix, HashEntry> getSuffixMap( + const std::vector<ManifestInline *> &manifests) { + std::unordered_map<Suffix, HashEntry> suffix_map; + + for (auto manifest_ptr : manifests) { + HashType hash_algorithm = manifest_ptr->getHashAlgorithm(); + SuffixList suffix_list = manifest_ptr->getSuffixList(); + + for (auto it = suffix_list.begin(); it != suffix_list.end(); ++it) { + std::vector<uint8_t> hash( + it->second, it->second + auth::hash_size_map[hash_algorithm]); + suffix_map[it->first] = {hash_algorithm, hash}; + } + } + + return suffix_map; + } + static std::unordered_map<Suffix, HashEntry> getSuffixMap( + ManifestInline *manifest) { + return getSuffixMap(std::vector<ManifestInline *>{manifest}); + } + private: core::Name base_name_; NextSegmentCalculationStrategy next_segment_strategy_; SuffixList suffix_hash_map_; }; -} // end namespace core - -} // end namespace transport
\ No newline at end of file +} // namespace core +} // namespace transport diff --git a/libtransport/src/core/name.cc b/libtransport/src/core/name.cc index 811e93b87..795c8a697 100644 --- a/libtransport/src/core/name.cc +++ b/libtransport/src/core/name.cc @@ -13,14 +13,13 @@ * limitations under the License. */ +#include <core/manifest_format.h> #include <hicn/transport/core/name.h> #include <hicn/transport/errors/errors.h> #include <hicn/transport/errors/tokenizer_exception.h> #include <hicn/transport/utils/hash.h> #include <hicn/transport/utils/string_tokenizer.h> -#include <core/manifest_format.h> - namespace transport { namespace core { @@ -98,7 +97,12 @@ bool Name::operator!=(const Name &name) const { } Name::operator bool() const { - return bool(hicn_name_empty((hicn_name_t *)&name_)); + auto ret = isValid(); + return ret; +} + +bool Name::isValid() const { + return bool(!hicn_name_empty((hicn_name_t *)&name_)); } bool Name::equals(const Name &name, bool consider_segment) const { diff --git a/libtransport/src/core/packet.cc b/libtransport/src/core/packet.cc index cd2c5aa69..6f237729a 100644 --- a/libtransport/src/core/packet.cc +++ b/libtransport/src/core/packet.cc @@ -31,59 +31,94 @@ namespace core { const core::Name Packet::base_name("0::0|0"); -Packet::Packet(Format format) - : packet_(utils::MemBuf::create(getHeaderSizeFromFormat(format, 256)) - .release()), - packet_start_(reinterpret_cast<hicn_header_t *>(packet_->writableData())), - header_head_(packet_.get()), - payload_head_(nullptr), - format_(format) { - if (hicn_packet_init_header(format, packet_start_) < 0) { - throw errors::RuntimeException("Unexpected error initializing the packet."); - } - - packet_->append(getHeaderSizeFromFormat(format_)); -} - -Packet::Packet(MemBufPtr &&buffer) - : packet_(std::move(buffer)), - packet_start_(reinterpret_cast<hicn_header_t *>(packet_->writableData())), - header_head_(packet_.get()), - payload_head_(nullptr), - format_(getFormatFromBuffer(packet_->writableData(), packet_->length())) { -} - -Packet::Packet(const uint8_t *buffer, std::size_t size) - : Packet(MemBufPtr(utils::MemBuf::copyBuffer(buffer, size).release())) {} +Packet::Packet(Format format, std::size_t additional_header_size) + : utils::MemBuf(utils::MemBuf(CREATE, 2048)), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(0), + format_(format), + payload_type_(PayloadType::UNSPECIFIED) { + setFormat(format_, additional_header_size); +} + +Packet::Packet(MemBuf &&buffer) + : utils::MemBuf(std::move(buffer)), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(0), + format_(getFormatFromBuffer(data(), length())), + payload_type_(PayloadType::UNSPECIFIED) {} + +Packet::Packet(CopyBufferOp, const uint8_t *buffer, std::size_t size) + : utils::MemBuf(COPY_BUFFER, buffer, size), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(0), + format_(getFormatFromBuffer(data(), length())), + payload_type_(PayloadType::UNSPECIFIED) {} + +Packet::Packet(WrapBufferOp, uint8_t *buffer, std::size_t length, + std::size_t size) + : utils::MemBuf(WRAP_BUFFER, buffer, length, size), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(0), + format_(getFormatFromBuffer(this->data(), this->length())), + payload_type_(PayloadType::UNSPECIFIED) {} + +Packet::Packet(CreateOp, uint8_t *buffer, std::size_t length, std::size_t size, + Format format, std::size_t additional_header_size) + : utils::MemBuf(WRAP_BUFFER, buffer, length, size), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(0), + format_(format), + payload_type_(PayloadType::UNSPECIFIED) { + clear(); + setFormat(format_, additional_header_size); +} + +Packet::Packet(const Packet &other) + : utils::MemBuf(other), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(other.header_offset_), + format_(other.format_), + payload_type_(PayloadType::UNSPECIFIED) {} Packet::Packet(Packet &&other) - : packet_(std::move(other.packet_)), + : utils::MemBuf(std::move(other)), packet_start_(other.packet_start_), - header_head_(other.header_head_), - payload_head_(other.payload_head_), - format_(other.format_) { + header_offset_(other.header_offset_), + format_(other.format_), + payload_type_(PayloadType::UNSPECIFIED) { other.packet_start_ = nullptr; - other.header_head_ = nullptr; - other.payload_head_ = nullptr; other.format_ = HF_UNSPEC; + other.header_offset_ = 0; } Packet::~Packet() {} +Packet &Packet::operator=(const Packet &other) { + if (this != &other) { + *this = other; + packet_start_ = reinterpret_cast<hicn_header_t *>(writableData()); + } + + return *this; +} + std::size_t Packet::getHeaderSizeFromBuffer(Format format, const uint8_t *buffer) { size_t header_length; + if (hicn_packet_get_header_length(format, (hicn_header_t *)buffer, &header_length) < 0) { throw errors::MalformedPacketException(); } + return header_length; } bool Packet::isInterest(const uint8_t *buffer) { bool is_interest = false; - if (TRANSPORT_EXPECT_FALSE(hicn_packet_test_ece((const hicn_header_t *)buffer, + if (TRANSPORT_EXPECT_FALSE(hicn_packet_test_ece(HF_INET6_TCP, + (const hicn_header_t *)buffer, &is_interest) < 0)) { throw errors::RuntimeException( "Impossible to retrieve ece flag from packet"); @@ -92,6 +127,25 @@ bool Packet::isInterest(const uint8_t *buffer) { return !is_interest; } +void Packet::setFormat(Packet::Format format, + std::size_t additional_header_size) { + format_ = format; + if (hicn_packet_init_header(format_, packet_start_) < 0) { + throw errors::RuntimeException("Unexpected error initializing the packet."); + } + + auto header_size = getHeaderSizeFromFormat(format_); + assert(header_size <= tailroom()); + append(header_size); + + assert(additional_header_size <= tailroom()); + append(additional_header_size); + + header_offset_ = length(); +} + +bool Packet::isInterest() { return Packet::isInterest(data()); } + std::size_t Packet::getPayloadSizeFromBuffer(Format format, const uint8_t *buffer) { std::size_t payload_length; @@ -105,67 +159,58 @@ std::size_t Packet::getPayloadSizeFromBuffer(Format format, } std::size_t Packet::payloadSize() const { - return getPayloadSizeFromBuffer(format_, - reinterpret_cast<uint8_t *>(packet_start_)); + std::size_t ret = 0; + + if (length()) { + ret = getPayloadSizeFromBuffer(format_, + reinterpret_cast<uint8_t *>(packet_start_)); + } + + return ret; } std::size_t Packet::headerSize() const { - return getHeaderSizeFromBuffer(format_, - reinterpret_cast<uint8_t *>(packet_start_)); + if (header_offset_ == 0 && length()) { + const_cast<Packet *>(this)->header_offset_ = getHeaderSizeFromBuffer( + format_, reinterpret_cast<uint8_t *>(packet_start_)); + } + + return header_offset_; } Packet &Packet::appendPayload(std::unique_ptr<utils::MemBuf> &&payload) { - separateHeaderPayload(); - - if (!payload_head_) { - payload_head_ = payload.get(); - } - - header_head_->prependChain(std::move(payload)); + prependChain(std::move(payload)); updateLength(); return *this; } Packet &Packet::appendPayload(const uint8_t *buffer, std::size_t length) { - return appendPayload(utils::MemBuf::copyBuffer(buffer, length)); -} - -Packet &Packet::appendHeader(std::unique_ptr<utils::MemBuf> &&header) { - separateHeaderPayload(); + prependPayload(&buffer, &length); - if (!payload_head_) { - header_head_->prependChain(std::move(header)); - } else { - payload_head_->prependChain(std::move(header)); + if (length) { + appendPayload(utils::MemBuf::copyBuffer(buffer, length)); } updateLength(); return *this; } -Packet &Packet::appendHeader(const uint8_t *buffer, std::size_t length) { - return appendHeader(utils::MemBuf::copyBuffer(buffer, length)); -} - std::unique_ptr<utils::MemBuf> Packet::getPayload() const { - const_cast<Packet *>(this)->separateHeaderPayload(); - - // Hopefully the payload is contiguous - if (TRANSPORT_EXPECT_FALSE(payload_head_ && - payload_head_->next() != header_head_)) { - payload_head_->gather(payloadSize()); - } - - return payload_head_->cloneOne(); + auto ret = clone(); + ret->trimStart(headerSize()); + return ret; } Packet &Packet::updateLength(std::size_t length) { std::size_t total_length = length; - for (utils::MemBuf *current = payload_head_; - current && current != header_head_; current = current->next()) { + const utils::MemBuf *current = this; + do { total_length += current->length(); - } + current = current->next(); + } while (current != this); + + total_length -= headerSize(); if (hicn_packet_set_payload_length(format_, packet_start_, total_length) < 0) { @@ -176,13 +221,16 @@ Packet &Packet::updateLength(std::size_t length) { } PayloadType Packet::getPayloadType() const { - hicn_payload_type_t ret = HPT_UNSPEC; + if (payload_type_ == PayloadType::UNSPECIFIED) { + hicn_payload_type_t ret; + if (hicn_packet_get_payload_type(packet_start_, &ret) < 0) { + throw errors::RuntimeException("Impossible to retrieve payload type."); + } - if (hicn_packet_get_payload_type(packet_start_, &ret) < 0) { - throw errors::RuntimeException("Impossible to retrieve payload type."); + payload_type_ = (PayloadType)ret; } - return PayloadType(ret); + return payload_type_; } Packet &Packet::setPayloadType(PayloadType payload_type) { @@ -191,60 +239,76 @@ Packet &Packet::setPayloadType(PayloadType payload_type) { throw errors::RuntimeException("Error setting payload type of the packet."); } + payload_type_ = payload_type; + return *this; } Packet::Format Packet::getFormat() const { - if (format_ == HF_UNSPEC) { + /** + * We check packet start because after a movement it will result in a nullptr + */ + if (format_ == HF_UNSPEC && length()) { if (hicn_packet_get_format(packet_start_, &format_) < 0) { - throw errors::MalformedPacketException(); + TRANSPORT_LOGE("Unexpected packet format."); } } return format_; } -const std::shared_ptr<utils::MemBuf> Packet::acquireMemBufReference() const { - return packet_; +std::shared_ptr<utils::MemBuf> Packet::acquireMemBufReference() { + return std::static_pointer_cast<utils::MemBuf>(shared_from_this()); } void Packet::dump() const { - const_cast<Packet *>(this)->separateHeaderPayload(); - TRANSPORT_LOGI("HEADER -- Length: %zu", headerSize()); - hicn_packet_dump((uint8_t *)header_head_->data(), headerSize()); - TRANSPORT_LOGI("PAYLOAD -- Length: %zu", payloadSize()); - for (utils::MemBuf *current = payload_head_; - current && current != header_head_; current = current->next()) { + + const utils::MemBuf *current = this; + do { TRANSPORT_LOGI("MemBuf Length: %zu", current->length()); - hicn_packet_dump((uint8_t *)current->data(), current->length()); - } + dump((uint8_t *)current->data(), current->length()); + current = current->next(); + } while (current != this); +} + +void Packet::dump(uint8_t *buffer, std::size_t length) { + hicn_packet_dump(buffer, length); } void Packet::setSignatureSize(std::size_t size_bytes) { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + int ret = hicn_packet_set_signature_size(format_, packet_start_, size_bytes); if (ret < 0) { - throw errors::RuntimeException("Packet without Authentication Header."); + throw errors::RuntimeException("Error setting signature size."); } - - packet_->append(size_bytes); - updateLength(); } uint8_t *Packet::getSignature() const { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + uint8_t *signature; int ret = hicn_packet_get_signature(format_, packet_start_, &signature); if (ret < 0) { - throw errors::RuntimeException("Packet without Authentication Header."); + throw errors::RuntimeException("Error getting signature."); } return signature; } void Packet::setSignatureTimestamp(const uint64_t ×tamp) { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + int ret = hicn_packet_set_signature_timestamp(format_, packet_start_, timestamp); @@ -254,6 +318,10 @@ void Packet::setSignatureTimestamp(const uint64_t ×tamp) { } uint64_t Packet::getSignatureTimestamp() const { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + uint64_t return_value; int ret = hicn_packet_get_signature_timestamp(format_, packet_start_, &return_value); @@ -266,7 +334,11 @@ uint64_t Packet::getSignatureTimestamp() const { } void Packet::setValidationAlgorithm( - const utils::CryptoSuite &validation_algorithm) { + const auth::CryptoSuite &validation_algorithm) { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + int ret = hicn_packet_set_validation_algorithm(format_, packet_start_, uint8_t(validation_algorithm)); @@ -275,7 +347,11 @@ void Packet::setValidationAlgorithm( } } -utils::CryptoSuite Packet::getValidationAlgorithm() const { +auth::CryptoSuite Packet::getValidationAlgorithm() const { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + uint8_t return_value; int ret = hicn_packet_get_validation_algorithm(format_, packet_start_, &return_value); @@ -284,10 +360,14 @@ utils::CryptoSuite Packet::getValidationAlgorithm() const { throw errors::RuntimeException("Error getting the validation algorithm."); } - return utils::CryptoSuite(return_value); + return auth::CryptoSuite(return_value); } -void Packet::setKeyId(const utils::KeyId &key_id) { +void Packet::setKeyId(const auth::KeyId &key_id) { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + int ret = hicn_packet_set_key_id(format_, packet_start_, key_id.first); if (ret < 0) { @@ -295,8 +375,12 @@ void Packet::setKeyId(const utils::KeyId &key_id) { } } -utils::KeyId Packet::getKeyId() const { - utils::KeyId return_value; +auth::KeyId Packet::getKeyId() const { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + + auth::KeyId return_value; int ret = hicn_packet_get_key_id(format_, packet_start_, &return_value.first, &return_value.second); @@ -307,8 +391,8 @@ utils::KeyId Packet::getKeyId() const { return return_value; } -utils::CryptoHash Packet::computeDigest(utils::CryptoHashType algorithm) const { - utils::CryptoHasher hasher(static_cast<utils::CryptoHashType>(algorithm)); +auth::CryptoHash Packet::computeDigest(auth::CryptoHashType algorithm) const { + auth::CryptoHasher hasher(static_cast<auth::CryptoHashType>(algorithm)); hasher.init(); // Copy IP+TCP/ICMP header before zeroing them @@ -318,11 +402,11 @@ utils::CryptoHash Packet::computeDigest(utils::CryptoHashType algorithm) const { const_cast<Packet *>(this)->resetForHash(); - auto current = header_head_; + const utils::MemBuf *current = this; do { hasher.updateBytes(current->data(), current->length()); current = current->next(); - } while (current != header_head_); + } while (current != this); hicn_packet_copy_header(format_, &header_copy, packet_start_, false); @@ -330,15 +414,33 @@ utils::CryptoHash Packet::computeDigest(utils::CryptoHashType algorithm) const { } bool Packet::checkIntegrity() const { - if (hicn_packet_check_integrity(format_, packet_start_) < 0) { + uint16_t partial_csum = + csum(data() + HICN_V6_TCP_HDRLEN, length() - HICN_V6_TCP_HDRLEN, 0); + + for (const utils::MemBuf *current = next(); current != this; + current = current->next()) { + partial_csum = csum(current->data(), current->length(), ~partial_csum); + } + + if (hicn_packet_check_integrity_no_payload(format_, packet_start_, + partial_csum) < 0) { return false; } return true; } +void Packet::prependPayload(const uint8_t **buffer, std::size_t *size) { + auto last = prev(); + auto to_copy = std::min(*size, last->tailroom()); + std::memcpy(last->writableTail(), *buffer, to_copy); + last->append(to_copy); + *size -= to_copy; + *buffer += to_copy; +} + Packet &Packet::setSyn() { - if (hicn_packet_set_syn(packet_start_) < 0) { + if (hicn_packet_set_syn(format_, packet_start_) < 0) { throw errors::RuntimeException("Error setting syn bit in the packet."); } @@ -346,7 +448,7 @@ Packet &Packet::setSyn() { } Packet &Packet::resetSyn() { - if (hicn_packet_reset_syn(packet_start_) < 0) { + if (hicn_packet_reset_syn(format_, packet_start_) < 0) { throw errors::RuntimeException("Error resetting syn bit in the packet."); } @@ -355,7 +457,7 @@ Packet &Packet::resetSyn() { bool Packet::testSyn() const { bool res = false; - if (hicn_packet_test_syn(packet_start_, &res) < 0) { + if (hicn_packet_test_syn(format_, packet_start_, &res) < 0) { throw errors::RuntimeException("Error testing syn bit in the packet."); } @@ -363,7 +465,7 @@ bool Packet::testSyn() const { } Packet &Packet::setAck() { - if (hicn_packet_set_ack(packet_start_) < 0) { + if (hicn_packet_set_ack(format_, packet_start_) < 0) { throw errors::RuntimeException("Error setting ack bit in the packet."); } @@ -371,7 +473,7 @@ Packet &Packet::setAck() { } Packet &Packet::resetAck() { - if (hicn_packet_reset_ack(packet_start_) < 0) { + if (hicn_packet_reset_ack(format_, packet_start_) < 0) { throw errors::RuntimeException("Error resetting ack bit in the packet."); } @@ -380,7 +482,7 @@ Packet &Packet::resetAck() { bool Packet::testAck() const { bool res = false; - if (hicn_packet_test_ack(packet_start_, &res) < 0) { + if (hicn_packet_test_ack(format_, packet_start_, &res) < 0) { throw errors::RuntimeException("Error testing ack bit in the packet."); } @@ -388,7 +490,7 @@ bool Packet::testAck() const { } Packet &Packet::setRst() { - if (hicn_packet_set_rst(packet_start_) < 0) { + if (hicn_packet_set_rst(format_, packet_start_) < 0) { throw errors::RuntimeException("Error setting rst bit in the packet."); } @@ -396,7 +498,7 @@ Packet &Packet::setRst() { } Packet &Packet::resetRst() { - if (hicn_packet_reset_rst(packet_start_) < 0) { + if (hicn_packet_reset_rst(format_, packet_start_) < 0) { throw errors::RuntimeException("Error resetting rst bit in the packet."); } @@ -405,7 +507,7 @@ Packet &Packet::resetRst() { bool Packet::testRst() const { bool res = false; - if (hicn_packet_test_rst(packet_start_, &res) < 0) { + if (hicn_packet_test_rst(format_, packet_start_, &res) < 0) { throw errors::RuntimeException("Error testing rst bit in the packet."); } @@ -413,7 +515,7 @@ bool Packet::testRst() const { } Packet &Packet::setFin() { - if (hicn_packet_set_fin(packet_start_) < 0) { + if (hicn_packet_set_fin(format_, packet_start_) < 0) { throw errors::RuntimeException("Error setting fin bit in the packet."); } @@ -421,7 +523,7 @@ Packet &Packet::setFin() { } Packet &Packet::resetFin() { - if (hicn_packet_reset_fin(packet_start_) < 0) { + if (hicn_packet_reset_fin(format_, packet_start_) < 0) { throw errors::RuntimeException("Error resetting fin bit in the packet."); } @@ -430,7 +532,7 @@ Packet &Packet::resetFin() { bool Packet::testFin() const { bool res = false; - if (hicn_packet_test_fin(packet_start_, &res) < 0) { + if (hicn_packet_test_fin(format_, packet_start_, &res) < 0) { throw errors::RuntimeException("Error testing fin bit in the packet."); } @@ -447,24 +549,29 @@ Packet &Packet::resetFlags() { } std::string Packet::printFlags() const { - std::string flags = ""; + std::string flags; + if (testSyn()) { flags += "S"; } + if (testAck()) { flags += "A"; } + if (testRst()) { flags += "R"; } + if (testFin()) { flags += "F"; } + return flags; } Packet &Packet::setSrcPort(uint16_t srcPort) { - if (hicn_packet_set_src_port(packet_start_, srcPort) < 0) { + if (hicn_packet_set_src_port(format_, packet_start_, srcPort) < 0) { throw errors::RuntimeException("Error setting source port in the packet."); } @@ -472,7 +579,7 @@ Packet &Packet::setSrcPort(uint16_t srcPort) { } Packet &Packet::setDstPort(uint16_t dstPort) { - if (hicn_packet_set_dst_port(packet_start_, dstPort) < 0) { + if (hicn_packet_set_dst_port(format_, packet_start_, dstPort) < 0) { throw errors::RuntimeException( "Error setting destination port in the packet."); } @@ -483,7 +590,7 @@ Packet &Packet::setDstPort(uint16_t dstPort) { uint16_t Packet::getSrcPort() const { uint16_t port = 0; - if (hicn_packet_get_src_port(packet_start_, &port) < 0) { + if (hicn_packet_get_src_port(format_, packet_start_, &port) < 0) { throw errors::RuntimeException("Error reading source port in the packet."); } @@ -493,7 +600,7 @@ uint16_t Packet::getSrcPort() const { uint16_t Packet::getDstPort() const { uint16_t port = 0; - if (hicn_packet_get_dst_port(packet_start_, &port) < 0) { + if (hicn_packet_get_dst_port(format_, packet_start_, &port) < 0) { throw errors::RuntimeException( "Error reading destination port in the packet."); } @@ -518,37 +625,6 @@ uint8_t Packet::getTTL() const { return hops; } -void Packet::separateHeaderPayload() { - if (payload_head_) { - return; - } - - int signature_size = 0; - if (_is_ah(format_)) { - signature_size = (uint32_t)getSignatureSize(); - } - - auto header_size = getHeaderSizeFromFormat(format_, signature_size); - auto payload_length = packet_->length() - header_size; - - packet_->trimEnd(packet_->length()); - - auto payload = packet_->cloneOne(); - payload_head_ = payload.get(); - payload_head_->advance(header_size); - payload_head_->append(payload_length); - packet_->prependChain(std::move(payload)); - packet_->append(header_size); -} - -void Packet::resetPayload() { - if (packet_->isChained()) { - packet_->separateChain(packet_->next(), packet_->prev()); - payload_head_ = nullptr; - updateLength(); - } -} - } // end namespace core } // end namespace transport diff --git a/libtransport/src/core/pending_interest.h b/libtransport/src/core/pending_interest.h index aeff78ea2..ca6411ddf 100644 --- a/libtransport/src/core/pending_interest.h +++ b/libtransport/src/core/pending_interest.h @@ -21,7 +21,6 @@ #include <hicn/transport/core/name.h> #include <hicn/transport/interfaces/portal.h> #include <hicn/transport/portability/portability.h> - #include <utils/deadline_timer.h> #include <asio/steady_timer.hpp> @@ -34,24 +33,21 @@ class HicnForwarderInterface; class VPPForwarderInterface; class RawSocketInterface; -template <typename ForwarderInt> class Portal; using OnContentObjectCallback = interface::Portal::OnContentObjectCallback; using OnInterestTimeoutCallback = interface::Portal::OnInterestTimeoutCallback; class PendingInterest { - friend class Portal<HicnForwarderInterface>; - friend class Portal<VPPForwarderInterface>; - friend class Portal<RawSocketInterface>; + friend class Portal; public: using Ptr = utils::ObjectPool<PendingInterest>::Ptr; - PendingInterest() - : interest_(nullptr, nullptr), - timer_(), - on_content_object_callback_(), - on_interest_timeout_callback_() {} + // PendingInterest() + // : interest_(nullptr, nullptr), + // timer_(), + // on_content_object_callback_(), + // on_interest_timeout_callback_() {} PendingInterest(Interest::Ptr &&interest, std::unique_ptr<asio::steady_timer> &&timer) diff --git a/libtransport/src/core/portal.cc b/libtransport/src/core/portal.cc new file mode 100644 index 000000000..d1d26c5b7 --- /dev/null +++ b/libtransport/src/core/portal.cc @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/errors.h> +#include <core/global_configuration.h> +#include <core/portal.h> +#include <hicn/transport/interfaces/global_conf_interface.h> +#include <hicn/transport/portability/platform.h> +#include <hicn/transport/utils/file.h> + +#include <libconfig.h++> + +using namespace transport::interface::global_config; + +namespace transport { +namespace core { + +#ifdef ANDROID +static const constexpr char default_module[] = ""; +#elif defined(MACINTOSH) +static const constexpr char default_module[] = "hicnlight_module.dylib"; +#elif defined(LINUX) +static const constexpr char default_module[] = "hicnlight_module.so"; +#endif + +IoModuleConfiguration Portal::conf_; +std::string Portal::io_module_path_ = defaultIoModule(); + +std::string Portal::defaultIoModule() { + using namespace std::placeholders; + GlobalConfiguration::getInstance().registerConfigurationParser( + io_module_section, + std::bind(&Portal::parseIoModuleConfiguration, _1, _2)); + GlobalConfiguration::getInstance().registerConfigurationGetter( + io_module_section, std::bind(&Portal::getModuleConfiguration, _1, _2)); + GlobalConfiguration::getInstance().registerConfigurationSetter( + io_module_section, std::bind(&Portal::setModuleConfiguration, _1, _2)); + + // return default + conf_.name = default_module; + return default_module; +} + +void Portal::getModuleConfiguration(ConfigurationObject& object, + std::error_code& ec) { + assert(object.getKey() == io_module_section); + + auto conf = dynamic_cast<const IoModuleConfiguration&>(object); + conf = conf_; + ec = std::error_code(); +} + +std::string getIoModulePath(const std::string& name, + const std::vector<std::string>& paths, + std::error_code& ec) { +#ifdef LINUX + std::string extension = ".so"; +#elif defined(MACINTOSH) + std::string extension = ".dylib"; +#else +#error "Platform not supported."; +#endif + + std::string complete_path = name; + + if (name.empty()) { + ec = make_error_code(core_error::configuration_parse_failed); + return ""; + } + + complete_path += extension; + + for (auto& p : paths) { + if (p.at(0) != '/') { + TRANSPORT_LOGW("Path %s is not an absolute path. Ignoring it.", + p.c_str()); + continue; + } + + if (utils::File::exists(p + "/" + complete_path)) { + complete_path = p + "/" + complete_path; + break; + } + } + + return complete_path; +} + +void Portal::setModuleConfiguration(const ConfigurationObject& object, + std::error_code& ec) { + assert(object.getKey() == io_module_section); + + const IoModuleConfiguration& conf = + dynamic_cast<const IoModuleConfiguration&>(object); + auto path = getIoModulePath(conf.name, conf.search_path, ec); + if (!ec) { + conf_ = conf; + io_module_path_ = path; + } +} + +void Portal::parseIoModuleConfiguration(const libconfig::Setting& io_config, + std::error_code& ec) { + using namespace libconfig; + // path property: the list of paths where to look for the module. + std::vector<std::string> paths; + std::string name; + + if (io_config.exists("path")) { + // get path where looking for modules + const Setting& path_list = io_config.lookup("path"); + auto count = path_list.getLength(); + + for (int i = 0; i < count; i++) { + paths.emplace_back(path_list[i].c_str()); + } + } + + if (io_config.exists("name")) { + io_config.lookupValue("name", name); + } else { + ec = make_error_code(core_error::configuration_parse_failed); + return; + } + + auto path = getIoModulePath(name, paths, ec); + if (!ec) { + conf_.name = name; + conf_.search_path = paths; + io_module_path_ = path; + } +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/portal.h b/libtransport/src/core/portal.h index b63eab3af..59254cf7b 100644 --- a/libtransport/src/core/portal.h +++ b/libtransport/src/core/portal.h @@ -15,24 +15,20 @@ #pragma once -#include <core/forwarder_interface.h> #include <core/pending_interest.h> -#include <core/udp_socket_connector.h> #include <hicn/transport/config.h> #include <hicn/transport/core/content_object.h> #include <hicn/transport/core/interest.h> +#include <hicn/transport/core/io_module.h> #include <hicn/transport/core/name.h> #include <hicn/transport/core/prefix.h> #include <hicn/transport/errors/errors.h> +#include <hicn/transport/interfaces/global_conf_interface.h> #include <hicn/transport/interfaces/portal.h> #include <hicn/transport/portability/portability.h> #include <hicn/transport/utils/fixed_block_allocator.h> #include <hicn/transport/utils/log.h> -#ifdef __vpp__ -#include <core/memif_connector.h> -#endif - #include <asio.hpp> #include <asio/steady_timer.hpp> #include <future> @@ -40,17 +36,19 @@ #include <queue> #include <unordered_map> +namespace libconfig { +class Setting; +} + namespace transport { namespace core { namespace portal_details { -static constexpr uint32_t pool_size = 2048; +static constexpr uint32_t pit_size = 1024; class HandlerMemory { #ifdef __vpp__ - static constexpr std::size_t memory_size = 1024 * 1024; - public: HandlerMemory() {} @@ -58,12 +56,11 @@ class HandlerMemory { HandlerMemory &operator=(const HandlerMemory &) = delete; TRANSPORT_ALWAYS_INLINE void *allocate(std::size_t size) { - return utils::FixedBlockAllocator<128, 4096>::getInstance() - ->allocateBlock(); + return utils::FixedBlockAllocator<128, 8192>::getInstance().allocateBlock(); } TRANSPORT_ALWAYS_INLINE void deallocate(void *pointer) { - utils::FixedBlockAllocator<128, 4096>::getInstance()->deallocateBlock( + utils::FixedBlockAllocator<128, 8192>::getInstance().deallocateBlock( pointer); } #else @@ -159,33 +156,16 @@ class Pool { public: Pool(asio::io_service &io_service) : io_service_(io_service) { increasePendingInterestPool(); - increaseInterestPool(); - increaseContentObjectPool(); } TRANSPORT_ALWAYS_INLINE void increasePendingInterestPool() { // Create pool of pending interests to reuse - for (uint32_t i = 0; i < pool_size; i++) { + for (uint32_t i = 0; i < pit_size; i++) { pending_interests_pool_.add(new PendingInterest( Interest::Ptr(nullptr), std::make_unique<asio::steady_timer>(io_service_))); } } - - TRANSPORT_ALWAYS_INLINE void increaseInterestPool() { - // Create pool of interests to reuse - for (uint32_t i = 0; i < pool_size; i++) { - interest_pool_.add(new Interest()); - } - } - - TRANSPORT_ALWAYS_INLINE void increaseContentObjectPool() { - // Create pool of content object to reuse - for (uint32_t i = 0; i < pool_size; i++) { - content_object_pool_.add(new ContentObject()); - } - } - PendingInterest::Ptr getPendingInterest() { auto res = pending_interests_pool_.get(); while (TRANSPORT_EXPECT_FALSE(!res.first)) { @@ -196,35 +176,15 @@ class Pool { return std::move(res.second); } - TRANSPORT_ALWAYS_INLINE ContentObject::Ptr getContentObject() { - auto res = content_object_pool_.get(); - while (TRANSPORT_EXPECT_FALSE(!res.first)) { - increaseContentObjectPool(); - res = content_object_pool_.get(); - } - - return std::move(res.second); - } - - TRANSPORT_ALWAYS_INLINE Interest::Ptr getInterest() { - auto res = interest_pool_.get(); - while (TRANSPORT_EXPECT_FALSE(!res.first)) { - increaseInterestPool(); - res = interest_pool_.get(); - } - - return std::move(res.second); - } - private: utils::ObjectPool<PendingInterest> pending_interests_pool_; - utils::ObjectPool<ContentObject> content_object_pool_; - utils::ObjectPool<Interest> interest_pool_; asio::io_service &io_service_; }; } // namespace portal_details +class PortalConfiguration; + using PendingInterestHashTable = std::unordered_map<uint32_t, PendingInterest::Ptr>; @@ -250,32 +210,32 @@ using interface::BindConfig; * The portal class is not thread safe, appropriate locking is required by the * users of this class. */ -template <typename ForwarderInt> -class Portal { - static_assert( - std::is_base_of<ForwarderInterface<ForwarderInt, - typename ForwarderInt::ConnectorType>, - ForwarderInt>::value, - "ForwarderInt must inherit from ForwarderInterface!"); +class Portal { public: using ConsumerCallback = interface::Portal::ConsumerCallback; using ProducerCallback = interface::Portal::ProducerCallback; + friend class PortalConfiguration; + Portal() : Portal(internal_io_service_) {} Portal(asio::io_service &io_service) - : io_service_(io_service), + : io_module_(nullptr, [](IoModule *module) { IoModule::unload(module); }), + io_service_(io_service), packet_pool_(io_service), app_name_("libtransport_application"), consumer_callback_(nullptr), producer_callback_(nullptr), - connector_(std::bind(&Portal::processIncomingMessages, this, - std::placeholders::_1), - std::bind(&Portal::setLocalRoutes, this), io_service_, - app_name_), - forwarder_interface_(connector_) {} - + is_consumer_(false) { + /** + * This workaroung allows to initialize memory for packet buffers *before* + * any static variables that may be initialized in the io_modules. In this + * way static variables in modules will be destroyed before the packet + * memory. + */ + PacketManager<>::getInstance(); + } /** * Set the consumer callback. * @@ -304,7 +264,7 @@ class Portal { */ TRANSPORT_ALWAYS_INLINE void setOutputInterface( const std::string &output_interface) { - forwarder_interface_.setOutputInterface(output_interface); + io_module_->setOutputInterface(output_interface); } /** @@ -314,8 +274,19 @@ class Portal { * is a consumer or a producer. */ TRANSPORT_ALWAYS_INLINE void connect(bool is_consumer = true) { - pending_interest_hash_table_.reserve(portal_details::pool_size); - forwarder_interface_.connect(is_consumer); + if (!io_module_) { + pending_interest_hash_table_.reserve(portal_details::pit_size); + io_module_.reset(IoModule::load(io_module_path_.c_str())); + assert(io_module_); + + io_module_->init(std::bind(&Portal::processIncomingMessages, this, + std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3), + std::bind(&Portal::setLocalRoutes, this), io_service_, + app_name_); + io_module_->connect(is_consumer); + is_consumer_ = is_consumer; + } } /** @@ -324,13 +295,19 @@ class Portal { ~Portal() { killConnection(); } /** + * Compute name hash + */ + TRANSPORT_ALWAYS_INLINE uint32_t getHash(const Name &name) { + return name.getHash32() + name.getSuffix(); + } + + /** * Check if there is already a pending interest for a given name. * * @param name - The interest name. */ TRANSPORT_ALWAYS_INLINE bool interestIsPending(const Name &name) { - auto it = - pending_interest_hash_table_.find(name.getHash32() + name.getSuffix()); + auto it = pending_interest_hash_table_.find(getHash(name)); if (it != pending_interest_hash_table_.end()) { return true; } @@ -357,31 +334,46 @@ class Portal { OnContentObjectCallback &&on_content_object_callback = UNSET_CALLBACK, OnInterestTimeoutCallback &&on_interest_timeout_callback = UNSET_CALLBACK) { - uint32_t hash = - interest->getName().getHash32() + interest->getName().getSuffix(); // Send it - forwarder_interface_.send(*interest); - - auto pending_interest = packet_pool_.getPendingInterest(); - pending_interest->setInterest(std::move(interest)); - pending_interest->setOnContentObjectCallback( - std::move(on_content_object_callback)); - pending_interest->setOnTimeoutCallback( - std::move(on_interest_timeout_callback)); - pending_interest->startCountdown(portal_details::makeCustomAllocatorHandler( - async_callback_memory_, std::bind(&Portal<ForwarderInt>::timerHandler, - this, std::placeholders::_1, hash))); + interest->encodeSuffixes(); + io_module_->send(*interest); + + uint32_t initial_hash = interest->getName().getHash32(); + auto hash = initial_hash + interest->getName().getSuffix(); + uint32_t *suffix = interest->firstSuffix(); + auto n_suffixes = interest->numberOfSuffixes(); + uint32_t counter = 0; + // Set timers + do { + auto pending_interest = packet_pool_.getPendingInterest(); + pending_interest->setInterest(std::move(interest)); + pending_interest->setOnContentObjectCallback( + std::move(on_content_object_callback)); + pending_interest->setOnTimeoutCallback( + std::move(on_interest_timeout_callback)); + + pending_interest->startCountdown( + portal_details::makeCustomAllocatorHandler( + async_callback_memory_, std::bind(&Portal::timerHandler, this, + std::placeholders::_1, hash))); + + auto it = pending_interest_hash_table_.find(hash); + if (it != pending_interest_hash_table_.end()) { + it->second->cancelTimer(); - auto it = pending_interest_hash_table_.find(hash); - if (it != pending_interest_hash_table_.end()) { - it->second->cancelTimer(); + // Get reference to interest packet in order to have it destroyed. + auto _int = it->second->getInterest(); + it->second = std::move(pending_interest); + } else { + pending_interest_hash_table_[hash] = std::move(pending_interest); + } - // Get reference to interest packet in order to have it destroyed. - auto _int = it->second->getInterest(); - it->second = std::move(pending_interest); - } else { - pending_interest_hash_table_[hash] = std::move(pending_interest); - } + if (suffix) { + hash = initial_hash + *suffix; + suffix++; + } + + } while (counter++ < n_suffixes); } /** @@ -423,9 +415,9 @@ class Portal { * @param config - The configuration for the local forwarder binding. */ TRANSPORT_ALWAYS_INLINE void bind(const BindConfig &config) { - forwarder_interface_.setContentStoreSize(config.csReserved()); + assert(io_module_); + io_module_->setContentStoreSize(config.csReserved()); served_namespaces_.push_back(config.prefix()); - setLocalRoutes(); } @@ -460,7 +452,7 @@ class Portal { */ TRANSPORT_ALWAYS_INLINE void sendContentObject( ContentObject &content_object) { - forwarder_interface_.send(content_object); + io_module_->send(content_object); } /** @@ -482,7 +474,7 @@ class Portal { * Disconnect the transport from the local forwarder. */ TRANSPORT_ALWAYS_INLINE void killConnection() { - forwarder_interface_.closeConnection(); + io_module_->closeConnection(); } /** @@ -497,6 +489,17 @@ class Portal { } /** + * Remove one pending interest. + */ + TRANSPORT_ALWAYS_INLINE void clearOne(const Name &name) { + if (!io_service_.stopped()) { + io_service_.dispatch(std::bind(&Portal::doClearOne, this, name)); + } else { + doClearOne(name); + } + } + + /** * Get a reference to the io_service object. */ TRANSPORT_ALWAYS_INLINE asio::io_service &getIoService() { @@ -508,8 +511,8 @@ class Portal { */ TRANSPORT_ALWAYS_INLINE void registerRoute(Prefix &prefix) { served_namespaces_.push_back(prefix); - if (connector_.isConnected()) { - forwarder_interface_.registerRoute(prefix); + if (io_module_->isConnected()) { + io_module_->registerRoute(prefix); } } @@ -530,36 +533,49 @@ class Portal { } /** + * Remove one pending interest. + */ + TRANSPORT_ALWAYS_INLINE void doClearOne(const Name &name) { + auto it = pending_interest_hash_table_.find(getHash(name)); + + if (it != pending_interest_hash_table_.end()) { + it->second->cancelTimer(); + + // Get interest packet from pending interest and do nothing with it. It + // will get destroyed as it goes out of scope. + auto _int = it->second->getInterest(); + + pending_interest_hash_table_.erase(it); + } + } + + /** * Callback called by the underlying connector upon reception of a packet from * the local forwarder. * * @param packet_buffer - The bytes of the packet. */ TRANSPORT_ALWAYS_INLINE void processIncomingMessages( - Packet::MemBufPtr &&packet_buffer) { + Connector *c, utils::MemBuf &buffer, const std::error_code &ec) { bool is_stopped = io_service_.stopped(); if (TRANSPORT_EXPECT_FALSE(is_stopped)) { return; } - if (TRANSPORT_EXPECT_FALSE( - ForwarderInt::isControlMessage(packet_buffer->data()))) { - processControlMessage(std::move(packet_buffer)); + if (TRANSPORT_EXPECT_FALSE(io_module_->isControlMessage(buffer.data()))) { + processControlMessage(buffer); return; } - Packet::Format format = Packet::getFormatFromBuffer( - packet_buffer->data(), packet_buffer->length()); + // The buffer is a base class for an interest or a content object + Packet &packet_buffer = static_cast<Packet &>(buffer); + auto format = packet_buffer.getFormat(); if (TRANSPORT_EXPECT_TRUE(_is_tcp(format))) { - if (!Packet::isInterest(packet_buffer->data())) { - auto content_object = packet_pool_.getContentObject(); - content_object->replace(std::move(packet_buffer)); - processContentObject(std::move(content_object)); + if (is_consumer_) { + processContentObject(static_cast<ContentObject &>(packet_buffer)); } else { - auto interest = packet_pool_.getInterest(); - interest->replace(std::move(packet_buffer)); - processInterest(std::move(interest)); + processInterest(static_cast<Interest &>(packet_buffer)); } } else { TRANSPORT_LOGE("Received not supported packet. Ignoring it."); @@ -573,16 +589,16 @@ class Portal { */ TRANSPORT_ALWAYS_INLINE void setLocalRoutes() { for (auto &prefix : served_namespaces_) { - if (connector_.isConnected()) { - forwarder_interface_.registerRoute(prefix); + if (io_module_->isConnected()) { + io_module_->registerRoute(prefix); } } } - TRANSPORT_ALWAYS_INLINE void processInterest(Interest::Ptr &&interest) { + TRANSPORT_ALWAYS_INLINE void processInterest(Interest &interest) { // Interest for a producer if (TRANSPORT_EXPECT_TRUE(producer_callback_ != nullptr)) { - producer_callback_->onInterest(std::move(interest)); + producer_callback_->onInterest(interest); } } @@ -595,24 +611,27 @@ class Portal { * @param content_object - The data packet */ TRANSPORT_ALWAYS_INLINE void processContentObject( - ContentObject::Ptr &&content_object) { - uint32_t hash = content_object->getName().getHash32() + - content_object->getName().getSuffix(); + ContentObject &content_object) { + TRANSPORT_LOGD("processContentObject %s", + content_object.getName().toString().c_str()); + uint32_t hash = getHash(content_object.getName()); auto it = pending_interest_hash_table_.find(hash); if (it != pending_interest_hash_table_.end()) { + TRANSPORT_LOGD("Found pending interest."); + PendingInterest::Ptr interest_ptr = std::move(it->second); pending_interest_hash_table_.erase(it); interest_ptr->cancelTimer(); auto _int = interest_ptr->getInterest(); if (interest_ptr->getOnDataCallback() != UNSET_CALLBACK) { - interest_ptr->on_content_object_callback_(std::move(_int), - std::move(content_object)); + interest_ptr->on_content_object_callback_(*_int, content_object); } else if (consumer_callback_) { - consumer_callback_->onContentObject(std::move(_int), - std::move(content_object)); + consumer_callback_->onContentObject(*_int, content_object); } + } else { + TRANSPORT_LOGD("No interest pending for received content object."); } } @@ -622,12 +641,13 @@ class Portal { * them. */ TRANSPORT_ALWAYS_INLINE void processControlMessage( - Packet::MemBufPtr &&packet_buffer) { - forwarder_interface_.processControlMessageReply(std::move(packet_buffer)); + utils::MemBuf &packet_buffer) { + io_module_->processControlMessageReply(packet_buffer); } private: portal_details::HandlerMemory async_callback_memory_; + std::unique_ptr<IoModule, void (*)(IoModule *)> io_module_; asio::io_service &io_service_; asio::io_service internal_io_service_; @@ -641,8 +661,19 @@ class Portal { ConsumerCallback *consumer_callback_; ProducerCallback *producer_callback_; - typename ForwarderInt::ConnectorType connector_; - ForwarderInt forwarder_interface_; + bool is_consumer_; + + private: + static std::string defaultIoModule(); + static void parseIoModuleConfiguration(const libconfig::Setting &io_config, + std::error_code &ec); + static void getModuleConfiguration( + interface::global_config::ConfigurationObject &conf, std::error_code &ec); + static void setModuleConfiguration( + const interface::global_config::ConfigurationObject &conf, + std::error_code &ec); + static interface::global_config::IoModuleConfiguration conf_; + static std::string io_module_path_; }; } // namespace core diff --git a/libtransport/src/core/prefix.cc b/libtransport/src/core/prefix.cc index eea4aeb8b..1e2b2ed9d 100644 --- a/libtransport/src/core/prefix.cc +++ b/libtransport/src/core/prefix.cc @@ -25,12 +25,12 @@ extern "C" { #include <hicn/transport/portability/win_portability.h> #endif +#include <openssl/rand.h> + #include <cstring> #include <memory> #include <random> -#include <openssl/rand.h> - namespace transport { namespace core { @@ -99,7 +99,7 @@ void Prefix::buildPrefix(std::string &prefix, uint16_t prefix_length, ip_prefix_.family = family; } -std::unique_ptr<Sockaddr> Prefix::toSockaddr() { +std::unique_ptr<Sockaddr> Prefix::toSockaddr() const { Sockaddr *ret = nullptr; switch (ip_prefix_.family) { @@ -120,14 +120,14 @@ std::unique_ptr<Sockaddr> Prefix::toSockaddr() { return std::unique_ptr<Sockaddr>(ret); } -uint16_t Prefix::getPrefixLength() { return ip_prefix_.len; } +uint16_t Prefix::getPrefixLength() const { return ip_prefix_.len; } Prefix &Prefix::setPrefixLength(uint16_t prefix_length) { ip_prefix_.len = prefix_length; return *this; } -int Prefix::getAddressFamily() { return ip_prefix_.family; } +int Prefix::getAddressFamily() const { return ip_prefix_.family; } Prefix &Prefix::setAddressFamily(int address_family) { ip_prefix_.family = address_family; @@ -226,7 +226,7 @@ Name Prefix::getRandomName() const { ip_prefix_.len; size_t size = (size_t)ceil((float)addr_len / 8.0); - uint8_t *buffer = (uint8_t *) malloc(sizeof(uint8_t) * size); + uint8_t *buffer = (uint8_t *)malloc(sizeof(uint8_t) * size); RAND_bytes(buffer, size); @@ -332,7 +332,7 @@ bool Prefix::checkPrefixLengthAndAddressFamily(uint16_t prefix_length, return true; } -ip_prefix_t &Prefix::toIpPrefixStruct() { return ip_prefix_; } +const ip_prefix_t &Prefix::toIpPrefixStruct() const { return ip_prefix_; } } // namespace core diff --git a/libtransport/src/core/rs.cc b/libtransport/src/core/rs.cc new file mode 100644 index 000000000..44b5852e5 --- /dev/null +++ b/libtransport/src/core/rs.cc @@ -0,0 +1,365 @@ + +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/fec.h> +#include <core/rs.h> +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/utils/log.h> + +#include <cassert> + +namespace transport { +namespace core { +namespace fec { + +BlockCode::BlockCode(uint32_t k, uint32_t n, struct fec_parms *code) + : Packets(), + k_(k), + n_(n), + code_(code), + max_buffer_size_(0), + current_block_size_(0), + to_decode_(false) { + sorted_index_.reserve(n); +} + +bool BlockCode::addRepairSymbol(const fec::buffer &packet, uint32_t i) { + // Get index + to_decode_ = true; + TRANSPORT_LOGD("adding symbol of size %zu", packet->length()); + return addSymbol(packet, i, packet->length() - sizeof(fec_header)); +} + +bool BlockCode::addSourceSymbol(const fec::buffer &packet, uint32_t i) { + return addSymbol(packet, i, packet->length()); +} + +bool BlockCode::addSymbol(const fec::buffer &packet, uint32_t i, + std::size_t size) { + if (size > max_buffer_size_) { + max_buffer_size_ = size; + } + + operator[](current_block_size_++) = std::make_pair(i, packet); + + if (current_block_size_ >= k_) { + if (to_decode_) { + decode(); + } else { + encode(); + } + + clear(); + return false; + } + + return true; +} + +void BlockCode::encode() { + gf *data[n_]; + std::uint16_t old_values[k_]; + uint32_t base = operator[](0).first; + + // Set packet length in first 2 bytes + for (uint32_t i = 0; i < k_; i++) { + auto &packet = operator[](i).second; + + TRANSPORT_LOGD("Current buffer size: %zu", packet->length()); + + auto ret = packet->ensureCapacityAndFillUnused(max_buffer_size_, 0); + if (TRANSPORT_EXPECT_FALSE(ret == false)) { + throw errors::RuntimeException( + "Provided packet is not suitable to be used as FEC source packet. " + "Aborting."); + } + + // Buffers should hold 2 bytes before the starting pointer, in order to be + // able to set the length for the encoding operation + packet->prepend(2); + uint16_t *length = reinterpret_cast<uint16_t *>(packet->writableData()); + + old_values[i] = *length; + *length = htons(packet->length() - LEN_SIZE_BYTES); + + data[i] = packet->writableData(); + } + + // Finish to fill source block with the buffers to hold the repair symbols + for (uint32_t i = k_; i < n_; i++) { + // For the moment we get a packet from the pool here.. later we'll need to + // require a packet from the caller with a callback. + auto packet = PacketManager<>::getInstance().getMemBuf(); + packet->append(max_buffer_size_ + sizeof(fec_header) + LEN_SIZE_BYTES); + fec_header *fh = reinterpret_cast<fec_header *>(packet->writableData()); + + fh->setSeqNumberBase(base); + fh->setNFecSymbols(n_ - k_); + fh->setEncodedSymbolId(i); + fh->setSourceBlockLen(n_); + + packet->trimStart(sizeof(fec_header)); + + data[i] = packet->writableData(); + operator[](i) = std::make_pair(i, std::move(packet)); + } + + // Generate repair symbols and put them in corresponding buffers + TRANSPORT_LOGD("Calling encode with max_buffer_size_ = %zu", + max_buffer_size_); + for (uint32_t i = k_; i < n_; i++) { + fec_encode(code_, data, data[i], i, max_buffer_size_ + LEN_SIZE_BYTES); + } + + // Restore original content of buffer space used to store the length + for (uint32_t i = 0; i < k_; i++) { + auto &packet = operator[](i).second; + uint16_t *length = reinterpret_cast<uint16_t *>(packet->writableData()); + *length = old_values[i]; + packet->trimStart(2); + } + + // Re-include header in repair packets + for (uint32_t i = k_; i < n_; i++) { + auto &packet = operator[](i).second; + TRANSPORT_LOGD("Produced repair symbol of size = %zu", packet->length()); + packet->prepend(sizeof(fec_header)); + } +} + +void BlockCode::decode() { + gf *data[k_]; + uint32_t index[k_]; + + for (uint32_t i = 0; i < k_; i++) { + auto &packet = operator[](i).second; + index[i] = operator[](i).first; + sorted_index_[i] = index[i]; + + if (index[i] < k_) { + TRANSPORT_LOGD("DECODE SOURCE - index %u - Current buffer size: %zu", + index[i], packet->length()); + // This is a source packet. We need to prepend the length and fill + // additional space to 0 + + // Buffers should hold 2 bytes before the starting pointer, in order to be + // able to set the length for the encoding operation + packet->prepend(LEN_SIZE_BYTES); + packet->ensureCapacityAndFillUnused(max_buffer_size_, 0); + uint16_t *length = reinterpret_cast<uint16_t *>(packet->writableData()); + + *length = htons(packet->length() - LEN_SIZE_BYTES); + } else { + TRANSPORT_LOGD("DECODE SYMBOL - index %u - Current buffer size: %zu", + index[i], packet->length()); + packet->trimStart(sizeof(fec_header)); + } + + data[i] = packet->writableData(); + } + + // We decode the source block + TRANSPORT_LOGD("Calling decode with max_buffer_size_ = %zu", + max_buffer_size_); + fec_decode(code_, data, reinterpret_cast<int *>(index), max_buffer_size_); + + // Find the index in the block for recovered packets + for (uint32_t i = 0; i < k_; i++) { + if (index[i] != i) { + for (uint32_t j = 0; j < k_; j++) + if (sorted_index_[j] == uint32_t(index[i])) { + sorted_index_[j] = i; + } + } + } + + // Reorder block by index with in-place sorting + for (uint32_t i = 0; i < k_; i++) { + for (uint32_t j = sorted_index_[i]; j != i; j = sorted_index_[i]) { + std::swap(sorted_index_[j], sorted_index_[i]); + std::swap(operator[](j), operator[](i)); + } + } + + // Adjust length according to the one written in the source packet + for (uint32_t i = 0; i < k_; i++) { + auto &packet = operator[](i).second; + uint16_t *length = reinterpret_cast<uint16_t *>(packet->writableData()); + packet->trimStart(2); + packet->setLength(ntohs(*length)); + } +} + +void BlockCode::clear() { + current_block_size_ = 0; + max_buffer_size_ = 0; + sorted_index_.clear(); + to_decode_ = false; +} + +void rs::MatrixDeleter::operator()(struct fec_parms *params) { + fec_free(params); +} + +rs::Codes rs::createCodes() { + Codes ret; + + ret.emplace(std::make_pair(1, 3), Matrix(fec_new(1, 3), MatrixDeleter())); + ret.emplace(std::make_pair(6, 10), Matrix(fec_new(6, 10), MatrixDeleter())); + ret.emplace(std::make_pair(8, 32), Matrix(fec_new(8, 32), MatrixDeleter())); + ret.emplace(std::make_pair(10, 30), Matrix(fec_new(10, 30), MatrixDeleter())); + ret.emplace(std::make_pair(16, 24), Matrix(fec_new(16, 24), MatrixDeleter())); + ret.emplace(std::make_pair(10, 40), Matrix(fec_new(10, 40), MatrixDeleter())); + ret.emplace(std::make_pair(10, 60), Matrix(fec_new(10, 60), MatrixDeleter())); + ret.emplace(std::make_pair(10, 90), Matrix(fec_new(10, 90), MatrixDeleter())); + + return ret; +} + +rs::Codes rs::codes_ = createCodes(); + +rs::rs(uint32_t k, uint32_t n) : k_(k), n_(n) {} + +void rs::setFECCallback(const PacketsReady &callback) { + fec_callback_ = callback; +} + +encoder::encoder(uint32_t k, uint32_t n) + : rs(k, n), + current_code_(codes_[std::make_pair(k, n)].get()), + source_block_(k_, n_, current_code_) {} + +void encoder::consume(const fec::buffer &packet, uint32_t index) { + if (!source_block_.addSourceSymbol(packet, index)) { + std::vector<buffer> repair_packets; + for (uint32_t i = k_; i < n_; i++) { + repair_packets.emplace_back(std::move(source_block_[i].second)); + } + fec_callback_(repair_packets); + } +} + +decoder::decoder(uint32_t k, uint32_t n) : rs(k, n) {} + +void decoder::recoverPackets(SourceBlocks::iterator &src_block_it) { + TRANSPORT_LOGD("recoverPackets for %u", k_); + auto &src_block = src_block_it->second; + std::vector<buffer> source_packets(k_); + for (uint32_t i = 0; i < src_block.getK(); i++) { + source_packets[i] = std::move(src_block[i].second); + } + + fec_callback_(source_packets); + processed_source_blocks_.emplace(src_block_it->first); + + auto it = parked_packets_.find(src_block_it->first); + if (it != parked_packets_.end()) { + parked_packets_.erase(it); + } + + src_blocks_.erase(src_block_it); +} + +void decoder::consume(const fec::buffer &packet, uint32_t index) { + // Normalize index + auto i = index % n_; + + // Get base + uint32_t base = index - i; + + TRANSPORT_LOGD( + "Decoder consume called for source symbol. BASE = %u, index = %u and i = " + "%u", + base, index, i); + + // check if a source block already exist for this symbol. If it does not + // exist, we lazily park this packet until we receive a repair symbol for the + // same block. This is done for 2 reason: + // 1) If we receive all the source packets of a block, we do not need to + // recover anything. + // 2) Sender may change n and k at any moment, so we construct the source + // block based on the (n, k) values written in the fec header. This is + // actually not used right now, since we use fixed value of n and k passed + // at construction time, but it paves the ground for a more dynamic + // protocol that may come in the future. + auto it = src_blocks_.find(base); + if (it != src_blocks_.end()) { + auto ret = it->second.addSourceSymbol(packet, i); + if (!ret) { + recoverPackets(it); + } + } else { + TRANSPORT_LOGD("Adding to parked source packets"); + auto ret = parked_packets_.emplace( + base, std::vector<std::pair<buffer, uint32_t> >()); + ret.first->second.emplace_back(packet, i); + } +} + +void decoder::consume(const fec::buffer &packet) { + // Repair symbol! Get index and base source block. + fec_header *h = reinterpret_cast<fec_header *>(packet->writableData()); + auto i = h->getEncodedSymbolId(); + auto base = h->getSeqNumberBase(); + auto n = h->getSourceBlockLen(); + auto k = n - h->getNFecSymbols(); + + TRANSPORT_LOGD( + "Decoder consume called for repair symbol. BASE = %u, index = %u and i = " + "%u", + base, base + i, i); + + // check if a source block already exist for this symbol + auto it = src_blocks_.find(base); + if (it == src_blocks_.end()) { + // Create new source block + auto code_it = codes_.find(std::make_pair(k, n)); + if (code_it == codes_.end()) { + TRANSPORT_LOGE("Code for k = %u and n = %u does not exist.", k_, n_); + return; + } + + auto emplace_result = + src_blocks_.emplace(base, BlockCode(k, n, code_it->second.get())); + it = emplace_result.first; + + // Check in the parked packets and insert any packet that is part of this + // source block + + auto it2 = parked_packets_.find(base); + if (it2 != parked_packets_.end()) { + for (auto &packet_index : it2->second) { + auto ret = + it->second.addSourceSymbol(packet_index.first, packet_index.second); + if (!ret) { + recoverPackets(it); + // Finish to delete packets in same source block that were + // eventually not used + return; + } + } + } + } + + auto ret = it->second.addRepairSymbol(packet, i); + if (!ret) { + recoverPackets(it); + } +} + +} // namespace fec +} // namespace core +} // namespace transport diff --git a/libtransport/src/core/rs.h b/libtransport/src/core/rs.h new file mode 100644 index 000000000..d630bd233 --- /dev/null +++ b/libtransport/src/core/rs.h @@ -0,0 +1,338 @@ + +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <arpa/inet.h> +#include <hicn/transport/utils/membuf.h> +#include <protocols/fec_base.h> + +#include <array> +#include <cstdint> +#include <map> +#include <unordered_set> +#include <vector> + +namespace transport { +namespace core { + +namespace fec { + +static const constexpr uint16_t MAX_SOURCE_BLOCK_SIZE = 128; + +using buffer = typename utils::MemBuf::Ptr; +/** + * We use a std::array in place of std::vector to avoid to allocate a new vector + * in the heap every time we build a new source block, which would be bad if + * the decoder has to allocate several source blocks for many concurrent bases. + * std::array allows to be constructed in place, saving the allocation at the + * price os knowing in advance its size. + */ +using Packets = std::array<std::pair<uint32_t, buffer>, MAX_SOURCE_BLOCK_SIZE>; + +/** + * FEC Header, prepended to symbol packets. + */ +struct fec_header { + /** + * The base source packet seq_number this FES symbol refers to + */ + uint32_t seq_number; + + /** + * The index of the symbol inside the source block, between k and n - 1 + */ + uint8_t encoded_symbol_id; + + /** + * Total length of source block (n) + */ + uint8_t source_block_len; + + /** + * Total number of symbols (n - k) + */ + uint8_t n_fec_symbols; + + /** + * Align header to 64 bits + */ + uint8_t padding; + + void setSeqNumberBase(uint32_t suffix) { seq_number = htonl(suffix); } + uint32_t getSeqNumberBase() { return ntohl(seq_number); } + void setEncodedSymbolId(uint8_t esi) { encoded_symbol_id = esi; } + uint8_t getEncodedSymbolId() { return encoded_symbol_id; } + void setSourceBlockLen(uint8_t k) { source_block_len = k; } + uint8_t getSourceBlockLen() { return source_block_len; } + void setNFecSymbols(uint8_t n_r) { n_fec_symbols = n_r; } + uint8_t getNFecSymbols() { return n_fec_symbols; } +}; + +/** + * This class models the source block itself. + */ +class BlockCode : public Packets { + /** + * For variable length packet we need to prepend to the padded payload the + * real length of the packet. This is *not* sent over the network. + */ + static constexpr std::size_t LEN_SIZE_BYTES = 2; + + public: + BlockCode(uint32_t k, uint32_t n, struct fec_parms *code); + + /** + * Add a repair symbol to the dource block. + */ + bool addRepairSymbol(const fec::buffer &packet, uint32_t i); + + /** + * Add a source symbol to the source block. + */ + bool addSourceSymbol(const fec::buffer &packet, uint32_t i); + + /** + * Get current length of source block. + */ + std::size_t length() { return current_block_size_; } + + /** + * Get N + */ + uint32_t getN() { return n_; } + + /** + * Get K + */ + uint32_t getK() { return k_; } + + /** + * Clear source block + */ + void clear(); + + private: + /** + * Add symbol to source block + **/ + bool addSymbol(const fec::buffer &packet, uint32_t i, std::size_t size); + + /** + * Starting from k source symbols, get the n - k repair symbols + */ + void encode(); + + /** + * Starting from k symbols (mixed repair and source), get k source symbols. + * NOTE: It does not make sense to retrieve the k source symbols using the + * very same k source symbols. With the current implementation that case can + * never happen. + */ + void decode(); + + private: + uint32_t k_; + uint32_t n_; + struct fec_parms *code_; + std::size_t max_buffer_size_; + std::size_t current_block_size_; + std::vector<uint32_t> sorted_index_; + bool to_decode_; +}; + +/** + * This class contains common parameters between the fec encoder and decoder. + * In particular it contains: + * - The callback to be called when symbols are encoded / decoded + * - The reference to the static reed-solomon parameters, allocated at program + * startup + * - N and K. Ideally they are useful only for the encoder (the decoder can + * retrieve them from the FEC header). However right now we assume sender and + * receiver agreed on the parameters k and n to use. We will introduce a control + * message later to negotiate them, so that decoder cah dynamically change them + * during the download. + */ +class rs { + /** + * Deleter for static preallocated reed-solomon parameters. + */ + struct MatrixDeleter { + void operator()(struct fec_parms *params); + }; + + /** + * unique_ptr to reed-solomon parameters, with custom deleter to call fec_free + * at the end of the program + */ + using Matrix = std::unique_ptr<struct fec_parms, MatrixDeleter>; + + /** + * Key to retrieve static preallocated reed-solomon parameters. It is pair of + * k and n + */ + using Code = std::pair<std::uint32_t /* k */, std::uint32_t /* n */>; + + /** + * Custom hash function for (k, n) pair. + */ + struct CodeHasher { + std::size_t operator()(const transport::core::fec::rs::Code &code) const { + uint64_t ret = uint64_t(code.first) << 32 | uint64_t(code.second); + return std::hash<uint64_t>{}(ret); + } + }; + + protected: + /** + * Callback to be called after the encode or the decode operations. In the + * former case it will contain the symbols, while in the latter the sources. + */ + using PacketsReady = std::function<void(std::vector<buffer> &)>; + + /** + * The sequence number base. + */ + using SNBase = std::uint32_t; + + /** + * The map of source blocks, used at the decoder side. For the encoding + * operation we can use one source block only, since packet are produced in + * order. + */ + using SourceBlocks = std::unordered_map<SNBase, BlockCode>; + + /** + * Map (k, n) -> reed-solomon parameter + */ + using Codes = std::unordered_map<Code, Matrix, CodeHasher>; + + public: + rs(uint32_t k, uint32_t n); + ~rs() = default; + + /** + * Set callback to call after packet encoding / decoding + */ + void setFECCallback(const PacketsReady &callback); + + virtual void clear() { processed_source_blocks_.clear(); } + + private: + /** + * Create reed-solomon codes at program startup. + */ + static Codes createCodes(); + + protected: + bool processed(SNBase seq_base) { + return processed_source_blocks_.find(seq_base) != + processed_source_blocks_.end(); + } + + std::uint32_t k_; + std::uint32_t n_; + PacketsReady fec_callback_; + + /** + * Keep track of processed source blocks + */ + std::unordered_set<SNBase> processed_source_blocks_; + + static Codes codes_; +}; + +/** + * The reed-solomon encoder. It is feeded with source symbols and it provide + * repair-symbols through the fec_callback_ + */ +class encoder : public rs { + public: + encoder(uint32_t k, uint32_t n); + /** + * Always consume source symbols. + */ + void consume(const fec::buffer &packet, uint32_t index); + + void clear() override { + rs::clear(); + source_block_.clear(); + } + + private: + struct fec_parms *current_code_; + /** + * The source block. As soon as it is filled with k source symbols, the + * encoder calls the callback fec_callback_ and the resets the block 0, ready + * to accept another batch of k source symbols. + */ + BlockCode source_block_; +}; + +/** + * The reed-solomon encoder. It is feeded with source/repair symbols and it + * provides the original source symbols through the fec_callback_ + */ +class decoder : public rs { + public: + decoder(uint32_t k, uint32_t n); + + /** + * Consume source symbol + */ + void consume(const fec::buffer &packet, uint32_t i); + + /** + * Consume repair symbol + */ + void consume(const fec::buffer &packet); + + /** + * Clear decoder to reuse + */ + void clear() override { + rs::clear(); + src_blocks_.clear(); + parked_packets_.clear(); + } + + private: + void recoverPackets(SourceBlocks::iterator &src_block_it); + + private: + /** + * Map of source blocks. We use a map because we may receive symbols belonging + * to diffreent source blocks at the same time, so we need to be able to + * decode many source symbols at the same time. + */ + SourceBlocks src_blocks_; + + /** + * Unordered Map of source symbols for which we did not receive any repair + * symbol in the same source block. Notably this happens when: + * + * - We receive the source symbols first and the repair symbols after + * - We received only source symbols for a given block. In that case it does + * not make any sense to build the source block, since we received all the + * source packet of the block. + */ + std::unordered_map<uint32_t, std::vector<std::pair<buffer, uint32_t>>> + parked_packets_; +}; + +} // namespace fec + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/core/tcp_socket_connector.h b/libtransport/src/core/tcp_socket_connector.h index c57123e9f..9dbd250d1 100644 --- a/libtransport/src/core/tcp_socket_connector.h +++ b/libtransport/src/core/tcp_socket_connector.h @@ -15,12 +15,11 @@ #pragma once +#include <core/connector.h> #include <hicn/transport/config.h> #include <hicn/transport/core/name.h> #include <hicn/transport/utils/branch_prediction.h> -#include <core/connector.h> - #include <asio.hpp> #include <asio/steady_timer.hpp> #include <deque> diff --git a/libtransport/src/http/client_connection.cc b/libtransport/src/http/client_connection.cc index 7a3a636fe..a24a821e7 100644 --- a/libtransport/src/http/client_connection.cc +++ b/libtransport/src/http/client_connection.cc @@ -43,12 +43,6 @@ class HTTPClientConnection::Implementation read_buffer_(nullptr), response_(std::make_shared<HTTPResponse>()), timer_(nullptr) { - consumer_.setSocketOption( - ConsumerCallbacksOptions::CONTENT_OBJECT_TO_VERIFY, - (ConsumerContentObjectVerificationCallback)std::bind( - &Implementation::verifyData, this, std::placeholders::_1, - std::placeholders::_2)); - consumer_.setSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, this); consumer_.connect(); @@ -124,10 +118,10 @@ class HTTPClientConnection::Implementation return *http_client_; } - HTTPClientConnection &setCertificate(const std::string &cert_path) { - if (consumer_.setSocketOption(GeneralTransportOptions::CERTIFICATE, - cert_path) == SOCKET_OPTION_NOT_SET) { - throw errors::RuntimeException("Error setting the certificate."); + HTTPClientConnection &setVerifier(std::shared_ptr<auth::Verifier> verifier) { + if (consumer_.setSocketOption(GeneralTransportOptions::VERIFIER, + verifier) == SOCKET_OPTION_NOT_SET) { + throw errors::RuntimeException("Error setting the verifier."); } return *http_client_; @@ -177,17 +171,6 @@ class HTTPClientConnection::Implementation consumer_.stop(); } - bool verifyData(interface::ConsumerSocket &c, - const core::ContentObject &contentObject) { - if (contentObject.getPayloadType() == PayloadType::CONTENT_OBJECT) { - TRANSPORT_LOGI("VERIFY CONTENT\n"); - } else if (contentObject.getPayloadType() == PayloadType::MANIFEST) { - TRANSPORT_LOGI("VERIFY MANIFEST\n"); - } - - return true; - } - void processLeavingInterest(interface::ConsumerSocket &c, const core::Interest &interest) { if (interest.payloadSize() == 0) { @@ -307,9 +290,9 @@ HTTPClientConnection &HTTPClientConnection::setTimeout( return implementation_->setTimeout(timeout); } -HTTPClientConnection &HTTPClientConnection::setCertificate( - const std::string &cert_path) { - return implementation_->setCertificate(cert_path); +HTTPClientConnection &HTTPClientConnection::setVerifier( + std::shared_ptr<auth::Verifier> verifier) { + return implementation_->setVerifier(verifier); } } // namespace http diff --git a/libtransport/src/implementation/CMakeLists.txt b/libtransport/src/implementation/CMakeLists.txt index 5423a7697..392c99e15 100644 --- a/libtransport/src/implementation/CMakeLists.txt +++ b/libtransport/src/implementation/CMakeLists.txt @@ -13,13 +13,8 @@ cmake_minimum_required(VERSION 3.5 FATAL_ERROR) -list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_socket_producer.cc -) - list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/socket.h - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/socket_consumer.h ) @@ -27,15 +22,16 @@ list(APPEND HEADER_FILES if (${OPENSSL_VERSION} VERSION_EQUAL "1.1.1a" OR ${OPENSSL_VERSION} VERSION_GREATER "1.1.1a") list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc + # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.cc ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/socket.cc ) list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.h - ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h + # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.h ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.h diff --git a/libtransport/src/implementation/p2psecure_socket_consumer.cc b/libtransport/src/implementation/p2psecure_socket_consumer.cc index 9b79850d6..8c7c175b2 100644 --- a/libtransport/src/implementation/p2psecure_socket_consumer.cc +++ b/libtransport/src/implementation/p2psecure_socket_consumer.cc @@ -15,7 +15,6 @@ #include <implementation/p2psecure_socket_consumer.h> #include <interfaces/tls_socket_consumer.h> - #include <openssl/bio.h> #include <openssl/ssl.h> #include <openssl/tls1.h> @@ -175,7 +174,6 @@ P2PSecureConsumerSocket::P2PSecureConsumerSocket( : ConsumerSocket(consumer, handshake_protocol), name_(), tls_consumer_(nullptr), - buf_pool_(), decrypted_content_(), payload_(), head_(), diff --git a/libtransport/src/implementation/p2psecure_socket_consumer.h b/libtransport/src/implementation/p2psecure_socket_consumer.h index d4c3b26c2..a35a50352 100644 --- a/libtransport/src/implementation/p2psecure_socket_consumer.h +++ b/libtransport/src/implementation/p2psecure_socket_consumer.h @@ -16,7 +16,6 @@ #pragma once #include <hicn/transport/interfaces/socket_consumer.h> - #include <implementation/tls_socket_consumer.h> #include <openssl/bio.h> #include <openssl/ssl.h> @@ -75,7 +74,6 @@ class P2PSecureConsumerSocket : public ConsumerSocket, BIO_METHOD *bio_meth_; /* Chain of MemBuf to be used as a temporary buffer to pass descypted data * from the underlying layer to the application */ - utils::ObjectPool<utils::MemBuf> buf_pool_; std::unique_ptr<utils::MemBuf> decrypted_content_; /* Chain of MemBuf holding the payload to be written into interest or data */ std::unique_ptr<utils::MemBuf> payload_; diff --git a/libtransport/src/implementation/p2psecure_socket_producer.cc b/libtransport/src/implementation/p2psecure_socket_producer.cc index 15c7d25cd..6dff2ba08 100644 --- a/libtransport/src/implementation/p2psecure_socket_producer.cc +++ b/libtransport/src/implementation/p2psecure_socket_producer.cc @@ -14,13 +14,11 @@ */ #include <hicn/transport/core/interest.h> - #include <implementation/p2psecure_socket_producer.h> -#include <implementation/tls_rtc_socket_producer.h> +// #include <implementation/tls_rtc_socket_producer.h> #include <implementation/tls_socket_producer.h> #include <interfaces/tls_rtc_socket_producer.h> #include <interfaces/tls_socket_producer.h> - #include <openssl/bio.h> #include <openssl/rand.h> #include <openssl/ssl.h> @@ -34,7 +32,8 @@ namespace implementation { P2PSecureProducerSocket::P2PSecureProducerSocket( interface::ProducerSocket *producer_socket) - : ProducerSocket(producer_socket), + : ProducerSocket(producer_socket, + ProductionProtocolAlgorithms::BYTE_STREAM), mtx_(), cv_(), map_producers(), @@ -42,8 +41,9 @@ P2PSecureProducerSocket::P2PSecureProducerSocket( P2PSecureProducerSocket::P2PSecureProducerSocket( interface::ProducerSocket *producer_socket, bool rtc, - const std::shared_ptr<utils::Identity> &identity) - : ProducerSocket(producer_socket), + const std::shared_ptr<auth::Identity> &identity) + : ProducerSocket(producer_socket, + ProductionProtocolAlgorithms::BYTE_STREAM), rtc_(rtc), mtx_(), cv_(), @@ -51,9 +51,9 @@ P2PSecureProducerSocket::P2PSecureProducerSocket( list_producers() { /* Setup SSL context (identity and parameter to use TLS 1.3) */ der_cert_ = parcKeyStore_GetDEREncodedCertificate( - (identity->getSigner()->getKeyStore())); + (identity->getSigner()->getParcKeyStore())); der_prk_ = parcKeyStore_GetDEREncodedPrivateKey( - (identity->getSigner()->getKeyStore())); + (identity->getSigner()->getParcKeyStore())); int cert_size = parcBuffer_Limit(der_cert_); int prk_size = parcBuffer_Limit(der_prk_); @@ -88,15 +88,20 @@ void P2PSecureProducerSocket::initSessionSocket( producer->setSocketOption(MAKE_MANIFEST, this->making_manifest_); producer->setSocketOption(DATA_PACKET_SIZE, (uint32_t)(this->data_packet_size_)); - producer->output_buffer_.setLimit(this->output_buffer_.getLimit()); + uint32_t output_buffer_size = 0; + this->getSocketOption(GeneralTransportOptions::OUTPUT_BUFFER_SIZE, + output_buffer_size); + producer->setSocketOption(GeneralTransportOptions::OUTPUT_BUFFER_SIZE, + output_buffer_size); if (!rtc_) { producer->setInterface(new interface::TLSProducerSocket(producer.get())); } else { - TLSRTCProducerSocket *rtc_producer = - dynamic_cast<TLSRTCProducerSocket *>(producer.get()); - rtc_producer->setInterface( - new interface::TLSRTCProducerSocket(rtc_producer)); + // TODO + // TLSRTCProducerSocket *rtc_producer = + // dynamic_cast<TLSRTCProducerSocket *>(producer.get()); + // rtc_producer->setInterface( + // new interface::TLSRTCProducerSocket(rtc_producer)); } } @@ -114,8 +119,9 @@ void P2PSecureProducerSocket::onInterestCallback(interface::ProducerSocket &p, tls_producer = std::make_unique<TLSProducerSocket>(nullptr, this, interest.getName()); } else { - tls_producer = std::make_unique<TLSRTCProducerSocket>(nullptr, this, - interest.getName()); + // TODO + // tls_producer = std::make_unique<TLSRTCProducerSocket>(nullptr, this, + // interest.getName()); } initSessionSocket(tls_producer); @@ -129,15 +135,19 @@ void P2PSecureProducerSocket::onInterestCallback(interface::ProducerSocket &p, tls_producer_ptr->onInterest(*tls_producer_ptr, interest); tls_producer_ptr->async_accept(); } else { - TLSRTCProducerSocket *rtc_producer_ptr = - dynamic_cast<TLSRTCProducerSocket *>(tls_producer_ptr); - rtc_producer_ptr->onInterest(*rtc_producer_ptr, interest); - rtc_producer_ptr->async_accept(); + // TODO + // TLSRTCProducerSocket *rtc_producer_ptr = + // dynamic_cast<TLSRTCProducerSocket *>(tls_producer_ptr); + // rtc_producer_ptr->onInterest(*rtc_producer_ptr, interest); + // rtc_producer_ptr->async_accept(); } } -void P2PSecureProducerSocket::produce(const uint8_t *buffer, - size_t buffer_size) { +uint32_t P2PSecureProducerSocket::produceDatagram( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer) { + // TODO + throw errors::NotImplementedException(); + if (!rtc_) { throw errors::RuntimeException( "RTC must be the transport protocol to start the production of current " @@ -148,16 +158,20 @@ void P2PSecureProducerSocket::produce(const uint8_t *buffer, if (list_producers.empty()) cv_.wait(lck); - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) { - TLSRTCProducerSocket *rtc_producer = - dynamic_cast<TLSRTCProducerSocket *>(it->get()); - rtc_producer->produce(utils::MemBuf::copyBuffer(buffer, buffer_size)); - } + // TODO + // for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) + // { + // TLSRTCProducerSocket *rtc_producer = + // dynamic_cast<TLSRTCProducerSocket *>(it->get()); + // rtc_producer->produce(utils::MemBuf::copyBuffer(buffer, buffer_size)); + // } + + return 0; } -uint32_t P2PSecureProducerSocket::produce( - Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, bool is_last, - uint32_t start_offset) { +uint32_t P2PSecureProducerSocket::produceStream( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { if (rtc_) { throw errors::RuntimeException( "RTC transport protocol is not compatible with the production of " @@ -170,16 +184,17 @@ uint32_t P2PSecureProducerSocket::produce( if (list_producers.empty()) cv_.wait(lck); for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - segments += - (*it)->produce(content_name, buffer->clone(), is_last, start_offset); + segments += (*it)->produceStream(content_name, buffer->clone(), is_last, + start_offset); return segments; } -uint32_t P2PSecureProducerSocket::produce(Name content_name, - const uint8_t *buffer, - size_t buffer_size, bool is_last, - uint32_t start_offset) { +uint32_t P2PSecureProducerSocket::produceStream(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size, + bool is_last, + uint32_t start_offset) { if (rtc_) { throw errors::RuntimeException( "RTC transport protocol is not compatible with the production of " @@ -191,29 +206,31 @@ uint32_t P2PSecureProducerSocket::produce(Name content_name, if (list_producers.empty()) cv_.wait(lck); for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - segments += (*it)->produce(content_name, buffer, buffer_size, is_last, - start_offset); + segments += (*it)->produceStream(content_name, buffer, buffer_size, is_last, + start_offset); return segments; } -void P2PSecureProducerSocket::asyncProduce(const Name &content_name, - const uint8_t *buf, - size_t buffer_size, bool is_last, - uint32_t *start_offset) { - if (rtc_) { - throw errors::RuntimeException( - "RTC transport protocol is not compatible with the production of " - "current data. Aborting."); - } - - std::unique_lock<std::mutex> lck(mtx_); - if (list_producers.empty()) cv_.wait(lck); - - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) { - (*it)->asyncProduce(content_name, buf, buffer_size, is_last, start_offset); - } -} +// void P2PSecureProducerSocket::asyncProduce(const Name &content_name, +// const uint8_t *buf, +// size_t buffer_size, bool is_last, +// uint32_t *start_offset) { +// if (rtc_) { +// throw errors::RuntimeException( +// "RTC transport protocol is not compatible with the production of " +// "current data. Aborting."); +// } + +// std::unique_lock<std::mutex> lck(mtx_); +// if (list_producers.empty()) cv_.wait(lck); + +// for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) +// { +// (*it)->asyncProduce(content_name, buf, buffer_size, is_last, +// start_offset); +// } +// } void P2PSecureProducerSocket::asyncProduce( Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, bool is_last, @@ -269,7 +286,7 @@ int P2PSecureProducerSocket::setSocketOption( int P2PSecureProducerSocket::setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Signer> &socket_option_value) { + const std::shared_ptr<auth::Signer> &socket_option_value) { if (!list_producers.empty()) for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) (*it)->setSocketOption(socket_option_key, socket_option_value); @@ -323,16 +340,6 @@ int P2PSecureProducerSocket::setSocketOption(int socket_option_key, } int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, std::list<Prefix> socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - return ProducerSocket::setSocketOption(socket_option_key, - socket_option_value); -} - -int P2PSecureProducerSocket::setSocketOption( int socket_option_key, ProducerContentObjectCallback socket_option_value) { if (!list_producers.empty()) for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) @@ -361,17 +368,7 @@ int P2PSecureProducerSocket::setSocketOption( } int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, utils::CryptoHashType socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - return ProducerSocket::setSocketOption(socket_option_key, - socket_option_value); -} - -int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, utils::CryptoSuite socket_option_value) { + int socket_option_key, auth::CryptoHashType socket_option_value) { if (!list_producers.empty()) for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) (*it)->setSocketOption(socket_option_key, socket_option_value); diff --git a/libtransport/src/implementation/p2psecure_socket_producer.h b/libtransport/src/implementation/p2psecure_socket_producer.h index bfc9fc2c1..b7c3d1958 100644 --- a/libtransport/src/implementation/p2psecure_socket_producer.h +++ b/libtransport/src/implementation/p2psecure_socket_producer.h @@ -15,15 +15,14 @@ #pragma once -#include <hicn/transport/security/identity.h> -#include <hicn/transport/security/signer.h> - +#include <hicn/transport/auth/identity.h> +#include <hicn/transport/auth/signer.h> #include <implementation/socket_producer.h> -#include <implementation/tls_rtc_socket_producer.h> +// #include <implementation/tls_rtc_socket_producer.h> #include <implementation/tls_socket_producer.h> +#include <openssl/ssl.h> #include <utils/content_store.h> -#include <openssl/ssl.h> #include <condition_variable> #include <forward_list> #include <mutex> @@ -33,39 +32,40 @@ namespace implementation { class P2PSecureProducerSocket : public ProducerSocket { friend class TLSProducerSocket; - friend class TLSRTCProducerSocket; + // TODO + // friend class TLSRTCProducerSocket; public: explicit P2PSecureProducerSocket(interface::ProducerSocket *producer_socket); explicit P2PSecureProducerSocket( interface::ProducerSocket *producer_socket, bool rtc, - const std::shared_ptr<utils::Identity> &identity); + const std::shared_ptr<auth::Identity> &identity); ~P2PSecureProducerSocket(); - void produce(const uint8_t *buffer, size_t buffer_size) override; + uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) override; - uint32_t produce(Name content_name, const uint8_t *buffer, size_t buffer_size, - bool is_last = true, uint32_t start_offset = 0) override; + uint32_t produceStream(const Name &content_name, const uint8_t *buffer, + size_t buffer_size, bool is_last = true, + uint32_t start_offset = 0) override; - uint32_t produce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last = true, uint32_t start_offset = 0) override; + uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) override; void asyncProduce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, bool is_last, uint32_t offset, uint32_t **last_segment = nullptr) override; - void asyncProduce(const Name &suffix, const uint8_t *buf, size_t buffer_size, - bool is_last = true, - uint32_t *start_offset = nullptr) override; - int setSocketOption(int socket_option_key, ProducerInterestCallback socket_option_value) override; int setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Signer> &socket_option_value) override; + const std::shared_ptr<auth::Signer> &socket_option_value) override; int setSocketOption(int socket_option_key, uint32_t socket_option_value) override; @@ -75,9 +75,6 @@ class P2PSecureProducerSocket : public ProducerSocket { int setSocketOption(int socket_option_key, Name *socket_option_value) override; - int setSocketOption(int socket_option_key, - std::list<Prefix> socket_option_value) override; - int setSocketOption( int socket_option_key, ProducerContentObjectCallback socket_option_value) override; @@ -86,16 +83,13 @@ class P2PSecureProducerSocket : public ProducerSocket { ProducerContentCallback socket_option_value) override; int setSocketOption(int socket_option_key, - utils::CryptoHashType socket_option_value) override; - - int setSocketOption(int socket_option_key, - utils::CryptoSuite socket_option_value) override; + auth::CryptoHashType socket_option_value) override; int setSocketOption(int socket_option_key, const std::string &socket_option_value) override; using ProducerSocket::getSocketOption; - using ProducerSocket::onInterest; + // using ProducerSocket::onInterest; protected: /* Callback invoked once an interest has been received and its payload diff --git a/libtransport/src/implementation/socket.cc b/libtransport/src/implementation/socket.cc new file mode 100644 index 000000000..2e21f2bc3 --- /dev/null +++ b/libtransport/src/implementation/socket.cc @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/global_configuration.h> +#include <implementation/socket.h> + +namespace transport { +namespace implementation { + +Socket::Socket(std::shared_ptr<core::Portal> &&portal) + : portal_(std::move(portal)), is_async_(false) {} + +} // namespace implementation +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/implementation/socket.h b/libtransport/src/implementation/socket.h index 2e51f3027..cf22c03e1 100644 --- a/libtransport/src/implementation/socket.h +++ b/libtransport/src/implementation/socket.h @@ -15,13 +15,12 @@ #pragma once +#include <core/facade.h> #include <hicn/transport/config.h> #include <hicn/transport/interfaces/callbacks.h> #include <hicn/transport/interfaces/socket_options_default_values.h> #include <hicn/transport/interfaces/socket_options_keys.h> -#include <core/facade.h> - #define SOCKET_OPTION_GET 0 #define SOCKET_OPTION_NOT_GET 1 #define SOCKET_OPTION_SET 2 @@ -32,56 +31,23 @@ namespace transport { namespace implementation { // Forward Declarations -template <typename PortalType> class Socket; -// Define the portal and its connector, depending on the compilation options -// passed by the build tool. -using HicnForwarderPortal = core::HicnForwarderPortal; - -#ifdef __linux__ -#ifndef __ANDROID__ -using RawSocketPortal = core::RawSocketPortal; -#endif -#endif - -#ifdef __vpp__ -using VPPForwarderPortal = core::VPPForwarderPortal; -using BaseSocket = Socket<VPPForwarderPortal>; -using BasePortal = VPPForwarderPortal; -#else -using BaseSocket = Socket<HicnForwarderPortal>; -using BasePortal = HicnForwarderPortal; -#endif - -template <typename PortalType> class Socket { - static_assert(std::is_same<PortalType, HicnForwarderPortal>::value -#ifdef __linux__ -#ifndef __ANDROID__ - || std::is_same<PortalType, RawSocketPortal>::value -#ifdef __vpp__ - || std::is_same<PortalType, VPPForwarderPortal>::value -#endif -#endif - , -#else - , - -#endif - "This class is not allowed as Portal"); - public: - using Portal = PortalType; - - virtual asio::io_service &getIoService() = 0; - virtual void connect() = 0; - virtual bool isRunning() = 0; + virtual asio::io_service &getIoService() { return portal_->getIoService(); } + protected: + Socket(std::shared_ptr<core::Portal> &&portal); + virtual ~Socket(){}; + + protected: + std::shared_ptr<core::Portal> portal_; + bool is_async_; }; } // namespace implementation diff --git a/libtransport/src/implementation/socket_consumer.h b/libtransport/src/implementation/socket_consumer.h index 87965923e..a7b6ac4e7 100644 --- a/libtransport/src/implementation/socket_consumer.h +++ b/libtransport/src/implementation/socket_consumer.h @@ -13,15 +13,17 @@ * limitations under the License. */ +#pragma once + #include <hicn/transport/interfaces/socket_consumer.h> #include <hicn/transport/interfaces/socket_options_default_values.h> #include <hicn/transport/interfaces/statistics.h> -#include <hicn/transport/security/verifier.h> +#include <hicn/transport/auth/verifier.h> #include <hicn/transport/utils/event_thread.h> #include <protocols/cbr.h> -#include <protocols/protocol.h> #include <protocols/raaqm.h> -#include <protocols/rtc.h> +#include <protocols/rtc/rtc.h> +#include <protocols/transport_protocol.h> namespace transport { namespace implementation { @@ -30,12 +32,12 @@ using namespace core; using namespace interface; using ReadCallback = interface::ConsumerSocket::ReadCallback; -class ConsumerSocket : public Socket<BasePortal> { +class ConsumerSocket : public Socket { private: ConsumerSocket(interface::ConsumerSocket *consumer, int protocol, - std::shared_ptr<Portal> &&portal) - : consumer_interface_(consumer), - portal_(portal), + std::shared_ptr<core::Portal> &&portal) + : Socket(std::move(portal)), + consumer_interface_(consumer), async_downloader_(), interest_lifetime_(default_values::interest_lifetime), min_window_size_(default_values::min_window_size), @@ -54,16 +56,13 @@ class ConsumerSocket : public Socket<BasePortal> { rate_estimation_observer_(nullptr), rate_estimation_batching_parameter_(default_values::batch), rate_estimation_choice_(0), - is_async_(false), - verifier_(std::make_shared<utils::Verifier>()), + verifier_(std::make_shared<auth::VoidVerifier>()), verify_signature_(false), - key_content_(false), reset_window_(false), on_interest_output_(VOID_HANDLER), on_interest_timeout_(VOID_HANDLER), on_interest_satisfied_(VOID_HANDLER), on_content_object_input_(VOID_HANDLER), - on_content_object_verification_(VOID_HANDLER), stats_summary_(VOID_HANDLER), read_callback_(nullptr), timer_interval_milliseconds_(0), @@ -75,7 +74,7 @@ class ConsumerSocket : public Socket<BasePortal> { break; case TransportProtocolAlgorithms::RTC: transport_protocol_ = - std::make_unique<protocol::RTCTransportProtocol>(this); + std::make_unique<protocol::rtc::RTCTransportProtocol>(this); break; case TransportProtocolAlgorithms::RAAQM: default: @@ -87,12 +86,12 @@ class ConsumerSocket : public Socket<BasePortal> { public: ConsumerSocket(interface::ConsumerSocket *consumer, int protocol) - : ConsumerSocket(consumer, protocol, std::make_shared<Portal>()) {} + : ConsumerSocket(consumer, protocol, std::make_shared<core::Portal>()) {} ConsumerSocket(interface::ConsumerSocket *consumer, int protocol, asio::io_service &io_service) : ConsumerSocket(consumer, protocol, - std::make_shared<Portal>(io_service)) { + std::make_shared<core::Portal>(io_service)) { is_async_ = true; } @@ -138,8 +137,6 @@ class ConsumerSocket : public Socket<BasePortal> { return CONSUMER_RUNNING; } - bool verifyKeyPackets() { return transport_protocol_->verifyKeyPackets(); } - void stop() { if (transport_protocol_->isRunning()) { transport_protocol_->stop(); @@ -152,8 +149,6 @@ class ConsumerSocket : public Socket<BasePortal> { } } - asio::io_service &getIoService() { return portal_->getIoService(); } - virtual int setSocketOption(int socket_option_key, ReadCallback *socket_option_value) { // Reschedule the function on the io_service to avoid race condition in @@ -316,12 +311,6 @@ class ConsumerSocket : public Socket<BasePortal> { break; } - case ConsumerCallbacksOptions::CONTENT_OBJECT_TO_VERIFY: - if (socket_option_value == VOID_HANDLER) { - on_content_object_verification_ = VOID_HANDLER; - break; - } - default: return SOCKET_OPTION_NOT_SET; } @@ -334,16 +323,6 @@ class ConsumerSocket : public Socket<BasePortal> { int result = SOCKET_OPTION_NOT_SET; if (!transport_protocol_->isRunning()) { switch (socket_option_key) { - case GeneralTransportOptions::VERIFY_SIGNATURE: - verify_signature_ = socket_option_value; - result = SOCKET_OPTION_SET; - break; - - case GeneralTransportOptions::KEY_CONTENT: - key_content_ = socket_option_value; - result = SOCKET_OPTION_SET; - break; - case RaaqmTransportOptions::PER_SESSION_CWINDOW_RESET: reset_window_ = socket_option_value; result = SOCKET_OPTION_SET; @@ -377,29 +356,6 @@ class ConsumerSocket : public Socket<BasePortal> { }); } - int setSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationCallback socket_option_value) { - // Reschedule the function on the io_service to avoid race condition in - // case setSocketOption is called while the io_service is running. - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ConsumerContentObjectVerificationCallback socket_option_value) - -> int { - switch (socket_option_key) { - case ConsumerCallbacksOptions::CONTENT_OBJECT_TO_VERIFY: - on_content_object_verification_ = socket_option_value; - break; - - default: - return SOCKET_OPTION_NOT_SET; - } - - return SOCKET_OPTION_SET; - }); - } - int setSocketOption(int socket_option_key, ConsumerInterestCallback socket_option_value) { // Reschedule the function on the io_service to avoid race condition in @@ -433,51 +389,6 @@ class ConsumerSocket : public Socket<BasePortal> { }); } - int setSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationFailedCallback socket_option_value) { - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this]( - int socket_option_key, - ConsumerContentObjectVerificationFailedCallback socket_option_value) - -> int { - switch (socket_option_key) { - case ConsumerCallbacksOptions::VERIFICATION_FAILED: - verification_failed_callback_ = socket_option_value; - break; - - default: - return SOCKET_OPTION_NOT_SET; - } - - return SOCKET_OPTION_SET; - }); - } - - // int setSocketOption( - // int socket_option_key, - // ConsumerContentObjectVerificationFailedCallback socket_option_value) { - // return rescheduleOnIOService( - // socket_option_key, socket_option_value, - // [this]( - // int socket_option_key, - // ConsumerContentObjectVerificationFailedCallback - // socket_option_value) - // -> int { - // switch (socket_option_key) { - // case ConsumerCallbacksOptions::VERIFICATION_FAILED: - // verification_failed_callback_ = socket_option_value; - // break; - - // default: - // return SOCKET_OPTION_NOT_SET; - // } - - // return SOCKET_OPTION_SET; - // }); - // } - int setSocketOption(int socket_option_key, IcnObserver *socket_option_value) { utils::SpinLock::Acquire locked(guard_raaqm_params_); switch (socket_option_key) { @@ -494,7 +405,7 @@ class ConsumerSocket : public Socket<BasePortal> { int setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Verifier> &socket_option_value) { + const std::shared_ptr<auth::Verifier> &socket_option_value) { int result = SOCKET_OPTION_NOT_SET; if (!transport_protocol_->isRunning()) { switch (socket_option_key) { @@ -516,14 +427,6 @@ class ConsumerSocket : public Socket<BasePortal> { int result = SOCKET_OPTION_NOT_SET; if (!transport_protocol_->isRunning()) { switch (socket_option_key) { - case GeneralTransportOptions::CERTIFICATE: - key_id_ = verifier_->addKeyFromCertificate(socket_option_value); - - if (key_id_ != nullptr) { - result = SOCKET_OPTION_SET; - } - break; - case DataLinkOptions::OUTPUT_INTERFACE: output_interface_ = socket_option_value; portal_->setOutputInterface(output_interface_); @@ -642,14 +545,6 @@ class ConsumerSocket : public Socket<BasePortal> { socket_option_value = transport_protocol_->isRunning(); break; - case GeneralTransportOptions::VERIFY_SIGNATURE: - socket_option_value = verify_signature_; - break; - - case GeneralTransportOptions::KEY_CONTENT: - socket_option_value = key_content_; - break; - case GeneralTransportOptions::ASYNC_MODE: socket_option_value = is_async_; break; @@ -699,29 +594,6 @@ class ConsumerSocket : public Socket<BasePortal> { }); } - int getSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationCallback **socket_option_value) { - // Reschedule the function on the io_service to avoid race condition in - // case setSocketOption is called while the io_service is running. - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ConsumerContentObjectVerificationCallback **socket_option_value) - -> int { - switch (socket_option_key) { - case ConsumerCallbacksOptions::CONTENT_OBJECT_TO_VERIFY: - *socket_option_value = &on_content_object_verification_; - break; - - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - }); - } - int getSocketOption(int socket_option_key, ConsumerInterestCallback **socket_option_value) { // Reschedule the function on the io_service to avoid race condition in @@ -755,30 +627,8 @@ class ConsumerSocket : public Socket<BasePortal> { }); } - int getSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationFailedCallback **socket_option_value) { - // Reschedule the function on the io_service to avoid race condition in - // case setSocketOption is called while the io_service is running. - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ConsumerContentObjectVerificationFailedCallback * - *socket_option_value) -> int { - switch (socket_option_key) { - case ConsumerCallbacksOptions::VERIFICATION_FAILED: - *socket_option_value = &verification_failed_callback_; - break; - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - }); - } - int getSocketOption(int socket_option_key, - std::shared_ptr<Portal> &socket_option_value) { + std::shared_ptr<core::Portal> &socket_option_value) { switch (socket_option_key) { case PORTAL: socket_option_value = portal_; @@ -807,7 +657,7 @@ class ConsumerSocket : public Socket<BasePortal> { } int getSocketOption(int socket_option_key, - std::shared_ptr<utils::Verifier> &socket_option_value) { + std::shared_ptr<auth::Verifier> &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::VERIFIER: socket_option_value = verifier_; @@ -871,7 +721,7 @@ class ConsumerSocket : public Socket<BasePortal> { // To enforce type check std::function<int(int, arg2)> func = lambda_func; int result = SOCKET_OPTION_SET; - if (transport_protocol_->isRunning()) { + if (transport_protocol_ && transport_protocol_->isRunning()) { std::mutex mtx; /* Condition variable for the wait */ std::condition_variable cv; @@ -898,7 +748,6 @@ class ConsumerSocket : public Socket<BasePortal> { protected: interface::ConsumerSocket *consumer_interface_; - std::shared_ptr<Portal> portal_; utils::EventThread async_downloader_; // No need to protect from multiple accesses in the async consumer @@ -926,13 +775,10 @@ class ConsumerSocket : public Socket<BasePortal> { int rate_estimation_batching_parameter_; int rate_estimation_choice_; - bool is_async_; - // Verification parameters - std::shared_ptr<utils::Verifier> verifier_; + std::shared_ptr<auth::Verifier> verifier_; PARCKeyId *key_id_; std::atomic_bool verify_signature_; - bool key_content_; bool reset_window_; ConsumerInterestCallback on_interest_retransmission_; @@ -940,9 +786,7 @@ class ConsumerSocket : public Socket<BasePortal> { ConsumerInterestCallback on_interest_timeout_; ConsumerInterestCallback on_interest_satisfied_; ConsumerContentObjectCallback on_content_object_input_; - ConsumerContentObjectVerificationCallback on_content_object_verification_; ConsumerTimerCallback stats_summary_; - ConsumerContentObjectVerificationFailedCallback verification_failed_callback_; ReadCallback *read_callback_; @@ -959,4 +803,4 @@ class ConsumerSocket : public Socket<BasePortal> { }; } // namespace implementation -} // namespace transport
\ No newline at end of file +} // namespace transport diff --git a/libtransport/src/implementation/socket_producer.h b/libtransport/src/implementation/socket_producer.h index a6f0f969e..af69cd818 100644 --- a/libtransport/src/implementation/socket_producer.h +++ b/libtransport/src/implementation/socket_producer.h @@ -15,9 +15,11 @@ #pragma once -#include <hicn/transport/security/signer.h> +#include <hicn/transport/auth/signer.h> #include <hicn/transport/utils/event_thread.h> #include <implementation/socket.h> +#include <protocols/prod_protocol_bytestream.h> +#include <protocols/prod_protocol_rtc.h> #include <utils/content_store.h> #include <utils/suffix_strategy.h> @@ -39,21 +41,17 @@ namespace implementation { using namespace core; using namespace interface; -class ProducerSocket : public Socket<BasePortal>, - public BasePortal::ProducerCallback { - static constexpr uint32_t burst_size = 256; - - public: - explicit ProducerSocket(interface::ProducerSocket *producer_socket) - : producer_interface_(producer_socket), - portal_(std::make_shared<Portal>(io_service_)), +class ProducerSocket : public Socket { + private: + ProducerSocket(interface::ProducerSocket *producer_socket, int protocol, + std::shared_ptr<core::Portal> &&portal) + : Socket(std::move(portal)), + producer_interface_(producer_socket), data_packet_size_(default_values::content_object_packet_size), content_object_expiry_time_(default_values::content_object_expiry_time), - output_buffer_(default_values::producer_socket_output_buffer_size), async_thread_(), - registration_status_(REGISTRATION_NOT_ATTEMPTED), making_manifest_(false), - hash_algorithm_(utils::CryptoHashType::SHA_256), + hash_algorithm_(auth::CryptoHashType::SHA_256), suffix_strategy_(core::NextSegmentCalculationStrategy::INCREMENTAL), on_interest_input_(VOID_HANDLER), on_interest_dropped_input_buffer_(VOID_HANDLER), @@ -65,15 +63,33 @@ class ProducerSocket : public Socket<BasePortal>, on_content_object_in_output_buffer_(VOID_HANDLER), on_content_object_output_(VOID_HANDLER), on_content_object_evicted_from_output_buffer_(VOID_HANDLER), - on_content_produced_(VOID_HANDLER) {} - - virtual ~ProducerSocket() { - stop(); - if (listening_thread_.joinable()) { - listening_thread_.join(); + on_content_produced_(VOID_HANDLER) { + switch (protocol) { + case ProductionProtocolAlgorithms::RTC_PROD: + production_protocol_ = + std::make_unique<protocol::RTCProductionProtocol>(this); + break; + case ProductionProtocolAlgorithms::BYTE_STREAM: + default: + production_protocol_ = + std::make_unique<protocol::ByteStreamProductionProtocol>(this); + break; } } + public: + ProducerSocket(interface::ProducerSocket *producer, int protocol) + : ProducerSocket(producer, protocol, std::make_shared<core::Portal>()) {} + + ProducerSocket(interface::ProducerSocket *producer, int protocol, + asio::io_service &io_service) + : ProducerSocket(producer, protocol, + std::make_shared<core::Portal>(io_service)) { + is_async_ = true; + } + + virtual ~ProducerSocket() {} + interface::ProducerSocket *getInterface() { return producer_interface_; } @@ -84,296 +100,10 @@ class ProducerSocket : public Socket<BasePortal>, void connect() override { portal_->connect(false); - listening_thread_ = std::thread(std::bind(&ProducerSocket::listen, this)); + production_protocol_->start(); } - bool isRunning() override { return !io_service_.stopped(); }; - - virtual uint32_t produce(Name content_name, const uint8_t *buffer, - size_t buffer_size, bool is_last = true, - uint32_t start_offset = 0) { - return ProducerSocket::produce( - content_name, utils::MemBuf::copyBuffer(buffer, buffer_size), is_last, - start_offset); - } - - virtual uint32_t produce(Name content_name, - std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last = true, uint32_t start_offset = 0) { - if (TRANSPORT_EXPECT_FALSE(buffer->length() == 0)) { - return 0; - } - - // Copy the atomic variables to ensure they keep the same value - // during the production - std::size_t data_packet_size = data_packet_size_; - uint32_t content_object_expiry_time = content_object_expiry_time_; - utils::CryptoHashType hash_algo = hash_algorithm_; - bool making_manifest = making_manifest_; - auto suffix_strategy = utils::SuffixStrategyFactory::getSuffixStrategy( - suffix_strategy_, start_offset); - std::shared_ptr<utils::Signer> signer; - getSocketOption(GeneralTransportOptions::SIGNER, signer); - - auto buffer_size = buffer->length(); - int bytes_segmented = 0; - std::size_t header_size; - std::size_t manifest_header_size = 0; - std::size_t signature_length = 0; - std::uint32_t final_block_number = start_offset; - uint64_t free_space_for_content = 0; - - core::Packet::Format format; - std::shared_ptr<ContentObjectManifest> manifest; - bool is_last_manifest = false; - - // TODO Manifest may still be used for indexing - if (making_manifest && !signer) { - TRANSPORT_LOGE("Making manifests without setting producer identity."); - } - - core::Packet::Format hf_format = core::Packet::Format::HF_UNSPEC; - core::Packet::Format hf_format_ah = core::Packet::Format::HF_UNSPEC; - if (content_name.getType() == HNT_CONTIGUOUS_V4 || - content_name.getType() == HNT_IOV_V4) { - hf_format = core::Packet::Format::HF_INET_TCP; - hf_format_ah = core::Packet::Format::HF_INET_TCP_AH; - } else if (content_name.getType() == HNT_CONTIGUOUS_V6 || - content_name.getType() == HNT_IOV_V6) { - hf_format = core::Packet::Format::HF_INET6_TCP; - hf_format_ah = core::Packet::Format::HF_INET6_TCP_AH; - } else { - throw errors::RuntimeException("Unknown name format."); - } - - format = hf_format; - if (making_manifest) { - manifest_header_size = core::Packet::getHeaderSizeFromFormat( - signer ? hf_format_ah : hf_format, - signer ? signer->getSignatureLength() : 0); - } else if (signer) { - format = hf_format_ah; - signature_length = signer->getSignatureLength(); - } - - header_size = - core::Packet::getHeaderSizeFromFormat(format, signature_length); - free_space_for_content = data_packet_size - header_size; - uint32_t number_of_segments = uint32_t( - std::ceil(double(buffer_size) / double(free_space_for_content))); - if (free_space_for_content * number_of_segments < buffer_size) { - number_of_segments++; - } - - // TODO allocate space for all the headers - if (making_manifest) { - uint32_t segment_in_manifest = static_cast<uint32_t>( - std::floor(double(data_packet_size - manifest_header_size - - ContentObjectManifest::getManifestHeaderSize()) / - ContentObjectManifest::getManifestEntrySize()) - - 1.0); - uint32_t number_of_manifests = static_cast<uint32_t>( - std::ceil(float(number_of_segments) / segment_in_manifest)); - final_block_number += number_of_segments + number_of_manifests - 1; - - manifest.reset(ContentObjectManifest::createManifest( - content_name.setSuffix(suffix_strategy->getNextManifestSuffix()), - core::ManifestVersion::VERSION_1, core::ManifestType::INLINE_MANIFEST, - hash_algo, is_last_manifest, content_name, suffix_strategy_, - signer ? signer->getSignatureLength() : 0)); - manifest->setLifetime(content_object_expiry_time); - - if (is_last) { - manifest->setFinalBlockNumber(final_block_number); - } else { - manifest->setFinalBlockNumber(utils::SuffixStrategy::INVALID_SUFFIX); - } - } - - for (unsigned int packaged_segments = 0; - packaged_segments < number_of_segments; packaged_segments++) { - if (making_manifest) { - if (manifest->estimateManifestSize(2) > - data_packet_size - manifest_header_size) { - // Send the current manifest - manifest->encode(); - - // If identity set, sign manifest - if (signer) { - signer->sign(*manifest); - } - - passContentObjectToCallbacks(manifest); - TRANSPORT_LOGD("Send manifest %s", - manifest->getName().toString().c_str()); - - // Send content objects stored in the queue - while (!content_queue_.empty()) { - passContentObjectToCallbacks(content_queue_.front()); - TRANSPORT_LOGD( - "Send content %s", - content_queue_.front()->getName().toString().c_str()); - content_queue_.pop(); - } - - // Create new manifest. The reference to the last manifest has been - // acquired in the passContentObjectToCallbacks function, so we can - // safely release this reference - manifest.reset(ContentObjectManifest::createManifest( - content_name.setSuffix(suffix_strategy->getNextManifestSuffix()), - core::ManifestVersion::VERSION_1, - core::ManifestType::INLINE_MANIFEST, hash_algo, is_last_manifest, - content_name, suffix_strategy_, - signer ? signer->getSignatureLength() : 0)); - - manifest->setLifetime(content_object_expiry_time); - manifest->setFinalBlockNumber( - is_last ? final_block_number - : utils::SuffixStrategy::INVALID_SUFFIX); - } - } - - auto content_suffix = suffix_strategy->getNextContentSuffix(); - auto content_object = std::make_shared<ContentObject>( - content_name.setSuffix(content_suffix), format); - content_object->setLifetime(content_object_expiry_time); - - auto b = buffer->cloneOne(); - b->trimStart(free_space_for_content * packaged_segments); - b->trimEnd(b->length()); - - if (TRANSPORT_EXPECT_FALSE(packaged_segments == number_of_segments - 1)) { - b->append(buffer_size - bytes_segmented); - bytes_segmented += (int)(buffer_size - bytes_segmented); - - if (is_last && making_manifest) { - is_last_manifest = true; - } else if (is_last) { - content_object->setRst(); - } - - } else { - b->append(free_space_for_content); - bytes_segmented += (int)(free_space_for_content); - } - - content_object->appendPayload(std::move(b)); - - if (making_manifest) { - using namespace std::chrono_literals; - utils::CryptoHash hash = content_object->computeDigest(hash_algo); - manifest->addSuffixHash(content_suffix, hash); - content_queue_.push(content_object); - } else { - if (signer) { - signer->sign(*content_object); - } - passContentObjectToCallbacks(content_object); - TRANSPORT_LOGD("Send content %s", - content_object->getName().toString().c_str()); - } - } - - if (making_manifest) { - if (is_last_manifest) { - manifest->setFinalManifest(is_last_manifest); - } - - manifest->encode(); - if (signer) { - signer->sign(*manifest); - } - - passContentObjectToCallbacks(manifest); - TRANSPORT_LOGD("Send manifest %s", - manifest->getName().toString().c_str()); - - while (!content_queue_.empty()) { - passContentObjectToCallbacks(content_queue_.front()); - TRANSPORT_LOGD("Send content %s", - content_queue_.front()->getName().toString().c_str()); - content_queue_.pop(); - } - } - - io_service_.post([this]() { - std::shared_ptr<ContentObject> co; - while (object_queue_for_callbacks_.pop(co)) { - if (on_new_segment_) { - on_new_segment_(*producer_interface_, *co); - } - - if (on_content_object_to_sign_) { - on_content_object_to_sign_(*producer_interface_, *co); - } - - if (on_content_object_in_output_buffer_) { - on_content_object_in_output_buffer_(*producer_interface_, *co); - } - - if (on_content_object_output_) { - on_content_object_output_(*producer_interface_, *co); - } - } - }); - - io_service_.dispatch([this, buffer_size]() { - if (on_content_produced_) { - on_content_produced_(*producer_interface_, - std::make_error_code(std::errc(0)), buffer_size); - } - }); - - return suffix_strategy->getTotalCount(); - } - - virtual void produce(ContentObject &content_object) { - io_service_.dispatch([this, &content_object]() { - if (on_content_object_in_output_buffer_) { - on_content_object_in_output_buffer_(*producer_interface_, - content_object); - } - }); - - output_buffer_.insert(std::static_pointer_cast<ContentObject>( - content_object.shared_from_this())); - - io_service_.dispatch([this, &content_object]() { - if (on_content_object_output_) { - on_content_object_output_(*producer_interface_, content_object); - } - }); - - portal_->sendContentObject(content_object); - } - - virtual void produce(const uint8_t *buffer, size_t buffer_size) { - produce(utils::MemBuf::copyBuffer(buffer, buffer_size)); - } - - virtual void produce(std::unique_ptr<utils::MemBuf> &&buffer) { - // This API is meant to be used just with the RTC producer. - // Here it cannot be used since no name for the content is specified. - throw errors::NotImplementedException(); - } - - virtual void asyncProduce(const Name &suffix, const uint8_t *buf, - size_t buffer_size, bool is_last = true, - uint32_t *start_offset = nullptr) { - if (!async_thread_.stopped()) { - async_thread_.add([this, suffix, buffer = buf, size = buffer_size, - is_last, start_offset]() { - if (start_offset != nullptr) { - *start_offset = ProducerSocket::produce(suffix, buffer, size, is_last, - *start_offset); - } else { - ProducerSocket::produce(suffix, buffer, size, is_last, 0); - } - }); - } - } - - void asyncProduce(const Name &suffix); + bool isRunning() override { return !production_protocol_->isRunning(); }; virtual void asyncProduce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, @@ -381,75 +111,56 @@ class ProducerSocket : public Socket<BasePortal>, uint32_t **last_segment = nullptr) { if (!async_thread_.stopped()) { auto a = buffer.release(); - async_thread_.add( - [this, content_name, a, is_last, offset, last_segment]() { - auto buf = std::unique_ptr<utils::MemBuf>(a); - if (last_segment != NULL) { - **last_segment = - offset + ProducerSocket::produce(content_name, std::move(buf), - is_last, offset); - } else { - ProducerSocket::produce(content_name, std::move(buf), is_last, - offset); - } - }); - } - } - - virtual void asyncProduce(ContentObject &content_object) { - if (!async_thread_.stopped()) { - auto co_ptr = std::static_pointer_cast<ContentObject>( - content_object.shared_from_this()); - async_thread_.add([this, content_object = std::move(co_ptr)]() { - ProducerSocket::produce(*content_object); + async_thread_.add([this, content_name, a, is_last, offset, + last_segment]() { + auto buf = std::unique_ptr<utils::MemBuf>(a); + if (last_segment != NULL) { + **last_segment = offset + produceStream(content_name, std::move(buf), + is_last, offset); + } else { + produceStream(content_name, std::move(buf), is_last, offset); + } }); } } - virtual void registerPrefix(const Prefix &producer_namespace) { - served_namespaces_.push_back(producer_namespace); + virtual uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) { + return production_protocol_->produceStream(content_name, std::move(buffer), + is_last, start_offset); } - void serveForever() { - if (listening_thread_.joinable()) { - listening_thread_.join(); - } + virtual uint32_t produceStream(const Name &content_name, + const uint8_t *buffer, size_t buffer_size, + bool is_last = true, + uint32_t start_offset = 0) { + return production_protocol_->produceStream( + content_name, buffer, buffer_size, is_last, start_offset); } - void stop() { portal_->stopEventsLoop(); } - - asio::io_service &getIoService() override { return portal_->getIoService(); }; - - virtual void onInterest(Interest &interest) { - if (on_interest_input_) { - on_interest_input_(*producer_interface_, interest); - } - - const std::shared_ptr<ContentObject> content_object = - output_buffer_.find(interest); - - if (content_object) { - if (on_interest_satisfied_output_buffer_) { - on_interest_satisfied_output_buffer_(*producer_interface_, interest); - } + virtual uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) { + return production_protocol_->produceDatagram(content_name, + std::move(buffer)); + } - if (on_content_object_output_) { - on_content_object_output_(*producer_interface_, *content_object); - } + virtual uint32_t produceDatagram(const Name &content_name, + const uint8_t *buffer, size_t buffer_size) { + return production_protocol_->produceDatagram(content_name, buffer, + buffer_size); + } - portal_->sendContentObject(*content_object); - } else { - if (on_interest_process_) { - on_interest_process_(*producer_interface_, interest); - } - } + void produce(ContentObject &content_object) { + production_protocol_->produce(content_object); } - virtual void onInterest(Interest::Ptr &&interest) override { - onInterest(*interest); - }; + void registerPrefix(const Prefix &producer_namespace) { + production_protocol_->registerNamespaceWithNetwork(producer_namespace); + } - virtual void onError(std::error_code ec) override {} + void stop() { production_protocol_->stop(); } virtual int setSocketOption(int socket_option_key, uint32_t socket_option_value) { @@ -462,7 +173,7 @@ class ProducerSocket : public Socket<BasePortal>, break; case GeneralTransportOptions::OUTPUT_BUFFER_SIZE: - output_buffer_.setLimit(socket_option_value); + production_protocol_->setOutputBufferSize(socket_option_value); break; case GeneralTransportOptions::CONTENT_OBJECT_EXPIRY_TIME: @@ -533,6 +244,12 @@ class ProducerSocket : public Socket<BasePortal>, break; } + case ProducerCallbacksOptions::CONTENT_OBJECT_TO_SIGN: + if (socket_option_value == VOID_HANDLER) { + on_content_object_to_sign_ = VOID_HANDLER; + break; + } + default: return SOCKET_OPTION_NOT_SET; } @@ -559,19 +276,6 @@ class ProducerSocket : public Socket<BasePortal>, return SOCKET_OPTION_NOT_SET; } - virtual int setSocketOption(int socket_option_key, - std::list<Prefix> socket_option_value) { - switch (socket_option_key) { - case GeneralTransportOptions::NETWORK_NAME: - served_namespaces_ = socket_option_value; - break; - default: - return SOCKET_OPTION_NOT_SET; - } - - return SOCKET_OPTION_SET; - } - virtual int setSocketOption( int socket_option_key, interface::ProducerContentObjectCallback socket_option_value) { @@ -594,6 +298,10 @@ class ProducerSocket : public Socket<BasePortal>, on_content_object_output_ = socket_option_value; break; + case ProducerCallbacksOptions::CONTENT_OBJECT_TO_SIGN: + on_content_object_to_sign_ = socket_option_value; + break; + default: return SOCKET_OPTION_NOT_SET; } @@ -663,7 +371,7 @@ class ProducerSocket : public Socket<BasePortal>, } virtual int setSocketOption(int socket_option_key, - utils::CryptoHashType socket_option_value) { + auth::CryptoHashType socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::HASH_ALGORITHM: hash_algorithm_ = socket_option_value; @@ -675,11 +383,12 @@ class ProducerSocket : public Socket<BasePortal>, return SOCKET_OPTION_SET; } - virtual int setSocketOption(int socket_option_key, - utils::CryptoSuite socket_option_value) { + virtual int setSocketOption( + int socket_option_key, + core::NextSegmentCalculationStrategy socket_option_value) { switch (socket_option_key) { - case GeneralTransportOptions::CRYPTO_SUITE: - crypto_suite_ = socket_option_value; + case GeneralTransportOptions::SUFFIX_STRATEGY: + suffix_strategy_ = socket_option_value; break; default: return SOCKET_OPTION_NOT_SET; @@ -690,7 +399,7 @@ class ProducerSocket : public Socket<BasePortal>, virtual int setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Signer> &socket_option_value) { + const std::shared_ptr<auth::Signer> &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::SIGNER: { utils::SpinLock::Acquire locked(signer_lock_); @@ -708,7 +417,7 @@ class ProducerSocket : public Socket<BasePortal>, uint32_t &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::OUTPUT_BUFFER_SIZE: - socket_option_value = (uint32_t)output_buffer_.getLimit(); + socket_option_value = production_protocol_->getOutputBufferSize(); break; case GeneralTransportOptions::DATA_PACKET_SIZE: @@ -733,18 +442,8 @@ class ProducerSocket : public Socket<BasePortal>, socket_option_value = making_manifest_; break; - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - } - - virtual int getSocketOption(int socket_option_key, - std::list<Prefix> &socket_option_value) { - switch (socket_option_key) { - case GeneralTransportOptions::NETWORK_NAME: - socket_option_value = served_namespaces_; + case GeneralTransportOptions::ASYNC_MODE: + socket_option_value = is_async_; break; default: @@ -776,6 +475,10 @@ class ProducerSocket : public Socket<BasePortal>, *socket_option_value = &on_content_object_output_; break; + case ProducerCallbacksOptions::CONTENT_OBJECT_TO_SIGN: + *socket_option_value = &on_content_object_to_sign_; + break; + default: return SOCKET_OPTION_NOT_GET; } @@ -828,11 +531,11 @@ class ProducerSocket : public Socket<BasePortal>, *socket_option_value = &on_interest_inserted_input_buffer_; break; - case CACHE_HIT: + case ProducerCallbacksOptions::CACHE_HIT: *socket_option_value = &on_interest_satisfied_output_buffer_; break; - case CACHE_MISS: + case ProducerCallbacksOptions::CACHE_MISS: *socket_option_value = &on_interest_process_; break; @@ -844,8 +547,9 @@ class ProducerSocket : public Socket<BasePortal>, }); } - virtual int getSocketOption(int socket_option_key, - std::shared_ptr<Portal> &socket_option_value) { + virtual int getSocketOption( + int socket_option_key, + std::shared_ptr<core::Portal> &socket_option_value) { switch (socket_option_key) { case PORTAL: socket_option_value = portal_; @@ -859,7 +563,7 @@ class ProducerSocket : public Socket<BasePortal>, } virtual int getSocketOption(int socket_option_key, - utils::CryptoHashType &socket_option_value) { + auth::CryptoHashType &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::HASH_ALGORITHM: socket_option_value = hash_algorithm_; @@ -871,22 +575,22 @@ class ProducerSocket : public Socket<BasePortal>, return SOCKET_OPTION_GET; } - virtual int getSocketOption(int socket_option_key, - utils::CryptoSuite &socket_option_value) { + virtual int getSocketOption( + int socket_option_key, + core::NextSegmentCalculationStrategy &socket_option_value) { switch (socket_option_key) { - case GeneralTransportOptions::HASH_ALGORITHM: - socket_option_value = crypto_suite_; + case GeneralTransportOptions::SUFFIX_STRATEGY: + socket_option_value = suffix_strategy_; break; default: return SOCKET_OPTION_NOT_GET; } - return SOCKET_OPTION_GET; } virtual int getSocketOption( int socket_option_key, - std::shared_ptr<utils::Signer> &socket_option_value) { + std::shared_ptr<auth::Signer> &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::SIGNER: { utils::SpinLock::Acquire locked(signer_lock_); @@ -907,19 +611,21 @@ class ProducerSocket : public Socket<BasePortal>, // If the thread calling lambda_func is not the same of io_service, this // function reschedule the function on it template <typename Lambda, typename arg2> - int rescheduleOnIOService(int socket_option_key, arg2 socket_option_value, - Lambda lambda_func) { + int rescheduleOnIOServiceWithReference(int socket_option_key, + arg2 &socket_option_value, + Lambda lambda_func) { // To enforce type check - std::function<int(int, arg2)> func = lambda_func; + std::function<int(int, arg2 &)> func = lambda_func; int result = SOCKET_OPTION_SET; - if (listening_thread_.joinable() && - std::this_thread::get_id() != listening_thread_.get_id()) { + if (production_protocol_ && production_protocol_->isRunning()) { std::mutex mtx; /* Condition variable for the wait */ std::condition_variable cv; + bool done = false; - io_service_.dispatch([&socket_option_key, &socket_option_value, &mtx, &cv, - &result, &done, &func]() { + portal_->getIoService().dispatch([&socket_option_key, + &socket_option_value, &mtx, &cv, + &result, &done, &func]() { std::unique_lock<std::mutex> lck(mtx); done = true; result = func(socket_option_key, socket_option_value); @@ -939,21 +645,19 @@ class ProducerSocket : public Socket<BasePortal>, // If the thread calling lambda_func is not the same of io_service, this // function reschedule the function on it template <typename Lambda, typename arg2> - int rescheduleOnIOServiceWithReference(int socket_option_key, - arg2 &socket_option_value, - Lambda lambda_func) { + int rescheduleOnIOService(int socket_option_key, arg2 socket_option_value, + Lambda lambda_func) { // To enforce type check - std::function<int(int, arg2 &)> func = lambda_func; + std::function<int(int, arg2)> func = lambda_func; int result = SOCKET_OPTION_SET; - if (listening_thread_.joinable() && - std::this_thread::get_id() != this->listening_thread_.get_id()) { + if (production_protocol_ && production_protocol_->isRunning()) { std::mutex mtx; /* Condition variable for the wait */ std::condition_variable cv; - bool done = false; - io_service_.dispatch([&socket_option_key, &socket_option_value, &mtx, &cv, - &result, &done, &func]() { + portal_->getIoService().dispatch([&socket_option_key, + &socket_option_value, &mtx, &cv, + &result, &done, &func]() { std::unique_lock<std::mutex> lck(mtx); done = true; result = func(socket_option_key, socket_option_value); @@ -973,39 +677,20 @@ class ProducerSocket : public Socket<BasePortal>, // Threads protected: interface::ProducerSocket *producer_interface_; - std::thread listening_thread_; asio::io_service io_service_; - std::shared_ptr<Portal> portal_; std::atomic<size_t> data_packet_size_; - std::list<Prefix> - served_namespaces_; // No need to be threadsafe, this is always modified - // by the application thread std::atomic<uint32_t> content_object_expiry_time_; - utils::CircularFifo<std::shared_ptr<ContentObject>, 2048> - object_queue_for_callbacks_; - - // buffers - // ContentStore is thread-safe - utils::ContentStore output_buffer_; - utils::EventThread async_thread_; - int registration_status_; std::atomic<bool> making_manifest_; - - // map for storing sequence numbers for several calls of the publish - // function - std::unordered_map<Name, std::unordered_map<int, uint32_t>> seq_number_map_; - - std::atomic<utils::CryptoHashType> hash_algorithm_; - std::atomic<utils::CryptoSuite> crypto_suite_; + std::atomic<auth::CryptoHashType> hash_algorithm_; + std::atomic<auth::CryptoSuite> crypto_suite_; utils::SpinLock signer_lock_; - std::shared_ptr<utils::Signer> signer_; + std::shared_ptr<auth::Signer> signer_; core::NextSegmentCalculationStrategy suffix_strategy_; - // While manifests are being built, contents are stored in a queue - std::queue<std::shared_ptr<ContentObject>> content_queue_; + std::unique_ptr<protocol::ProductionProtocol> production_protocol_; // callbacks ProducerInterestCallback on_interest_input_; @@ -1021,63 +706,6 @@ class ProducerSocket : public Socket<BasePortal>, ProducerContentObjectCallback on_content_object_evicted_from_output_buffer_; ProducerContentCallback on_content_produced_; - - private: - void listen() { - bool first = true; - - for (core::Prefix &producer_namespace : served_namespaces_) { - if (first) { - core::BindConfig bind_config(producer_namespace, 1000); - portal_->bind(bind_config); - portal_->setProducerCallback(this); - first = !first; - } else { - portal_->registerRoute(producer_namespace); - } - } - - portal_->runEventsLoop(); - } - - void scheduleSendBurst() { - io_service_.post([this]() { - std::shared_ptr<ContentObject> co; - - for (uint32_t i = 0; i < burst_size; i++) { - if (object_queue_for_callbacks_.pop(co)) { - if (on_new_segment_) { - on_new_segment_(*producer_interface_, *co); - } - - if (on_content_object_to_sign_) { - on_content_object_to_sign_(*producer_interface_, *co); - } - - if (on_content_object_in_output_buffer_) { - on_content_object_in_output_buffer_(*producer_interface_, *co); - } - - if (on_content_object_output_) { - on_content_object_output_(*producer_interface_, *co); - } - } else { - break; - } - } - }); - } - - void passContentObjectToCallbacks( - const std::shared_ptr<ContentObject> &content_object) { - output_buffer_.insert(content_object); - portal_->sendContentObject(*content_object); - object_queue_for_callbacks_.push(std::move(content_object)); - - if (object_queue_for_callbacks_.size() >= burst_size) { - scheduleSendBurst(); - } - } }; } // namespace implementation diff --git a/libtransport/src/implementation/tls_rtc_socket_producer.cc b/libtransport/src/implementation/tls_rtc_socket_producer.cc index 9ef79ca23..9a62c8683 100644 --- a/libtransport/src/implementation/tls_rtc_socket_producer.cc +++ b/libtransport/src/implementation/tls_rtc_socket_producer.cc @@ -15,10 +15,8 @@ #include <hicn/transport/core/interest.h> #include <hicn/transport/interfaces/p2psecure_socket_producer.h> - #include <implementation/p2psecure_socket_producer.h> #include <implementation/tls_rtc_socket_producer.h> - #include <openssl/bio.h> #include <openssl/rand.h> #include <openssl/ssl.h> diff --git a/libtransport/src/implementation/tls_rtc_socket_producer.h b/libtransport/src/implementation/tls_rtc_socket_producer.h index 685c91244..92c657afc 100644 --- a/libtransport/src/implementation/tls_rtc_socket_producer.h +++ b/libtransport/src/implementation/tls_rtc_socket_producer.h @@ -15,7 +15,6 @@ #pragma once -#include <implementation/rtc_socket_producer.h> #include <implementation/tls_socket_producer.h> namespace transport { @@ -23,8 +22,7 @@ namespace implementation { class P2PSecureProducerSocket; -class TLSRTCProducerSocket : public RTCProducerSocket, - public TLSProducerSocket { +class TLSRTCProducerSocket : public TLSProducerSocket { friend class P2PSecureProducerSocket; public: @@ -34,7 +32,8 @@ class TLSRTCProducerSocket : public RTCProducerSocket, ~TLSRTCProducerSocket() = default; - void produce(std::unique_ptr<utils::MemBuf> &&buffer) override; + uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) override; void accept() override; diff --git a/libtransport/src/implementation/tls_socket_consumer.cc b/libtransport/src/implementation/tls_socket_consumer.cc index 1be6f41a7..99bcd4360 100644 --- a/libtransport/src/implementation/tls_socket_consumer.cc +++ b/libtransport/src/implementation/tls_socket_consumer.cc @@ -136,7 +136,6 @@ TLSConsumerSocket::TLSConsumerSocket(interface::ConsumerSocket *consumer_socket, int protocol, SSL *ssl) : ConsumerSocket(consumer_socket, protocol), name_(), - buf_pool_(), decrypted_content_(), payload_(), head_(), @@ -223,14 +222,15 @@ int TLSConsumerSocket::download_content(const Name &name) { content_downloaded_ = false; std::size_t max_buffer_size = read_callback_decrypted_->maxBufferSize(); - std::size_t buffer_size = read_callback_decrypted_->maxBufferSize() + SSL3_RT_MAX_PLAIN_LENGTH; + std::size_t buffer_size = + read_callback_decrypted_->maxBufferSize() + SSL3_RT_MAX_PLAIN_LENGTH; decrypted_content_ = utils::MemBuf::createCombined(buffer_size); int result = -1; std::size_t size = 0; while (!content_downloaded_ || something_to_read_) { - result = SSL_read( - this->ssl_, decrypted_content_->writableTail(), SSL3_RT_MAX_PLAIN_LENGTH); + result = SSL_read(this->ssl_, decrypted_content_->writableTail(), + SSL3_RT_MAX_PLAIN_LENGTH); /* SSL_read returns the data only if there were SSL3_RT_MAX_PLAIN_LENGTH of * the data has been fully downloaded */ diff --git a/libtransport/src/implementation/tls_socket_consumer.h b/libtransport/src/implementation/tls_socket_consumer.h index 1c5df346a..be08ec47d 100644 --- a/libtransport/src/implementation/tls_socket_consumer.h +++ b/libtransport/src/implementation/tls_socket_consumer.h @@ -16,9 +16,7 @@ #pragma once #include <hicn/transport/interfaces/socket_consumer.h> - #include <implementation/socket_consumer.h> - #include <openssl/ssl.h> namespace transport { @@ -74,7 +72,6 @@ class TLSConsumerSocket : public ConsumerSocket, SSL_CTX *ctx_; /* Chain of MemBuf to be used as a temporary buffer to pass descypted data * from the underlying layer to the application */ - utils::ObjectPool<utils::MemBuf> buf_pool_; std::unique_ptr<utils::MemBuf> decrypted_content_; /* Chain of MemBuf holding the payload to be written into interest or data */ std::unique_ptr<utils::MemBuf> payload_; diff --git a/libtransport/src/implementation/tls_socket_producer.cc b/libtransport/src/implementation/tls_socket_producer.cc index 339a1ad58..e54d38d56 100644 --- a/libtransport/src/implementation/tls_socket_producer.cc +++ b/libtransport/src/implementation/tls_socket_producer.cc @@ -14,10 +14,8 @@ */ #include <hicn/transport/interfaces/socket_producer.h> - #include <implementation/p2psecure_socket_producer.h> #include <implementation/tls_socket_producer.h> - #include <openssl/bio.h> #include <openssl/rand.h> #include <openssl/ssl.h> @@ -50,10 +48,14 @@ int TLSProducerSocket::readOld(BIO *b, char *buf, int size) { std::unique_lock<std::mutex> lck(socket->mtx_); + TRANSPORT_LOGD("Start wait on the CV."); + if (!socket->something_to_read_) { (socket->cv_).wait(lck); } + TRANSPORT_LOGD("CV unlocked."); + /* Either there already is something to read, or the thread has been waken up. * We must return the payload in the interest anyway */ utils::MemBuf *membuf = socket->handshake_packet_->next(); @@ -103,7 +105,7 @@ int TLSProducerSocket::writeOld(BIO *b, const char *buf, int num) { socket->tls_chunks_--; socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, false); - socket->parent_->ProducerSocket::produce( + socket->parent_->ProducerSocket::produceStream( socket->name_, (const uint8_t *)buf, num, socket->tls_chunks_ == 0, socket->last_segment_); socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, @@ -122,18 +124,18 @@ int TLSProducerSocket::writeOld(BIO *b, const char *buf, int num) { socket->tls_chunks_--; socket->to_call_oncontentproduced_--; - socket->last_segment_ += socket->ProducerSocket::produce( + socket->last_segment_ += socket->ProducerSocket::produceStream( socket->name_, std::move(mbuf), socket->tls_chunks_ == 0, socket->last_segment_); - ProducerContentCallback on_content_produced_application; + ProducerContentCallback *on_content_produced_application; socket->getSocketOption(ProducerCallbacksOptions::CONTENT_PRODUCED, - on_content_produced_application); + &on_content_produced_application); if (socket->to_call_oncontentproduced_ == 0 && on_content_produced_application) { - on_content_produced_application(*socket->getInterface(), - std::error_code(), 0); + on_content_produced_application->operator()(*socket->getInterface(), + std::error_code(), 0); } }); } @@ -144,7 +146,8 @@ int TLSProducerSocket::writeOld(BIO *b, const char *buf, int num) { TLSProducerSocket::TLSProducerSocket(interface::ProducerSocket *producer_socket, P2PSecureProducerSocket *parent, const Name &handshake_name) - : ProducerSocket(producer_socket), + : ProducerSocket(producer_socket, + ProductionProtocolAlgorithms::BYTE_STREAM), on_content_produced_application_(), mtx_(), cv_(), @@ -236,13 +239,14 @@ void TLSProducerSocket::accept() { std::move(parent_->map_producers[handshake_name_])); parent_->map_producers.erase(handshake_name_); - ProducerInterestCallback on_interest_process_decrypted; + ProducerInterestCallback *on_interest_process_decrypted; getSocketOption(ProducerCallbacksOptions::CACHE_MISS, - on_interest_process_decrypted); + &on_interest_process_decrypted); - if (on_interest_process_decrypted) { - Interest inter(std::move(handshake_packet_)); - on_interest_process_decrypted(*getInterface(), inter); + if (*on_interest_process_decrypted) { + Interest inter(std::move(*handshake_packet_)); + handshake_packet_.reset(); + on_interest_process_decrypted->operator()(*getInterface(), inter); } else { throw errors::RuntimeException( "On interest process unset: unable to perform handshake"); @@ -270,14 +274,14 @@ void TLSProducerSocket::onInterest(ProducerSocket &p, Interest &interest) { std::unique_lock<std::mutex> lck(mtx_); name_ = interest.getName(); - interest.separateHeaderPayload(); + // interest.separateHeaderPayload(); handshake_packet_ = interest.acquireMemBufReference(); something_to_read_ = true; cv_.notify_one(); return; } else if (handshake_state == SERVER_FINISHED) { - interest.separateHeaderPayload(); + // interest.separateHeaderPayload(); handshake_packet_ = interest.acquireMemBufReference(); something_to_read_ = true; @@ -288,12 +292,12 @@ void TLSProducerSocket::onInterest(ProducerSocket &p, Interest &interest) { interest.getPayload()->length()); } - ProducerInterestCallback on_interest_input_decrypted; + ProducerInterestCallback *on_interest_input_decrypted; getSocketOption(ProducerCallbacksOptions::INTEREST_INPUT, - on_interest_input_decrypted); + &on_interest_input_decrypted); - if (on_interest_input_decrypted) - (on_interest_input_decrypted)(*getInterface(), interest); + if (*on_interest_input_decrypted) + (*on_interest_input_decrypted)(*getInterface(), interest); } } @@ -301,17 +305,19 @@ void TLSProducerSocket::cacheMiss(interface::ProducerSocket &p, Interest &interest) { HandshakeState handshake_state = getHandshakeState(); + TRANSPORT_LOGD("On cache miss in TLS socket producer."); + if (handshake_state == CLIENT_HELLO) { std::unique_lock<std::mutex> lck(mtx_); - interest.separateHeaderPayload(); + // interest.separateHeaderPayload(); handshake_packet_ = interest.acquireMemBufReference(); something_to_read_ = true; handshake_state_ = CLIENT_FINISHED; cv_.notify_one(); } else if (handshake_state == SERVER_FINISHED) { - interest.separateHeaderPayload(); + // interest.separateHeaderPayload(); handshake_packet_ = interest.acquireMemBufReference(); something_to_read_ = true; @@ -343,16 +349,16 @@ void TLSProducerSocket::onContentProduced(interface::ProducerSocket &p, const std::error_code &err, uint64_t bytes_written) {} -uint32_t TLSProducerSocket::produce(Name content_name, - std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last, uint32_t start_offset) { +uint32_t TLSProducerSocket::produceStream( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { if (getHandshakeState() != SERVER_FINISHED) { throw errors::RuntimeException( "New handshake on the same P2P secure producer socket not supported"); } size_t buf_size = buffer->length(); - name_ = served_namespaces_.front().mapName(content_name); + name_ = production_protocol_->getNamespaces().front().mapName(content_name); tls_chunks_ = to_call_oncontentproduced_ = ceil((float)buf_size / (float)SSL3_RT_MAX_PLAIN_LENGTH); @@ -370,46 +376,6 @@ uint32_t TLSProducerSocket::produce(Name content_name, return 0; } -void TLSProducerSocket::asyncProduce(const Name &content_name, - const uint8_t *buf, size_t buffer_size, - bool is_last, uint32_t *start_offset) { - if (!encryption_thread_.stopped()) { - encryption_thread_.add([this, content_name, buffer = buf, - size = buffer_size, is_last, start_offset]() { - if (start_offset != NULL) { - produce(content_name, buffer, size, is_last, *start_offset); - } else { - produce(content_name, buffer, size, is_last, 0); - } - }); - } -} - -void TLSProducerSocket::asyncProduce(Name content_name, - std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last, uint32_t offset, - uint32_t **last_segment) { - if (!encryption_thread_.stopped()) { - auto a = buffer.release(); - encryption_thread_.add( - [this, content_name, a, is_last, offset, last_segment]() { - auto buf = std::unique_ptr<utils::MemBuf>(a); - if (last_segment != NULL) { - *last_segment = &last_segment_; - } - produce(content_name, std::move(buf), is_last, offset); - }); - } -} - -void TLSProducerSocket::asyncProduce(ContentObject &content_object) { - throw errors::RuntimeException("API not supported"); -} - -void TLSProducerSocket::produce(ContentObject &content_object) { - throw errors::RuntimeException("API not supported"); -} - long TLSProducerSocket::ctrl(BIO *b, int cmd, long num, void *ptr) { if (cmd == BIO_CTRL_FLUSH) { } @@ -424,13 +390,14 @@ int TLSProducerSocket::addHicnKeyIdCb(SSL *s, unsigned int ext_type, void *add_arg) { TLSProducerSocket *socket = reinterpret_cast<TLSProducerSocket *>(add_arg); + TRANSPORT_LOGD("On addHicnKeyIdCb, for the prefix registration."); + if (ext_type == 100) { - ip_prefix_t ip_prefix = - socket->parent_->served_namespaces_.front().toIpPrefixStruct(); - int inet_family = - socket->parent_->served_namespaces_.front().getAddressFamily(); - uint16_t prefix_len_bits = - socket->parent_->served_namespaces_.front().getPrefixLength(); + auto &prefix = + socket->parent_->production_protocol_->getNamespaces().front(); + const ip_prefix_t &ip_prefix = prefix.toIpPrefixStruct(); + int inet_family = prefix.getAddressFamily(); + uint16_t prefix_len_bits = prefix.getPrefixLength(); uint8_t prefix_len_bytes = prefix_len_bits / 8; uint8_t prefix_len_u32 = prefix_len_bits / 32; @@ -479,10 +446,9 @@ int TLSProducerSocket::addHicnKeyIdCb(SSL *s, unsigned int ext_type, socket->parent_->on_interest_process_decrypted_; socket->registerPrefix( - Prefix(socket->parent_->served_namespaces_.front().getName( - Name(inet_family, (uint8_t *)&mask), - Name(inet_family, (uint8_t *)&keyId_component), - socket->parent_->served_namespaces_.front().getName()), + Prefix(prefix.getName(Name(inet_family, (uint8_t *)&mask), + Name(inet_family, (uint8_t *)&keyId_component), + prefix.getName()), out_ip->len)); socket->connect(); } @@ -580,61 +546,5 @@ int TLSProducerSocket::getSocketOption( }); } -int TLSProducerSocket::getSocketOption( - int socket_option_key, ProducerContentCallback &socket_option_value) { - return rescheduleOnIOServiceWithReference( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ProducerContentCallback &socket_option_value) -> int { - switch (socket_option_key) { - case ProducerCallbacksOptions::CONTENT_PRODUCED: - socket_option_value = on_content_produced_application_; - break; - - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - }); -} - -int TLSProducerSocket::getSocketOption( - int socket_option_key, ProducerInterestCallback &socket_option_value) { - // Reschedule the function on the io_service to avoid race condition in case - // setSocketOption is called while the io_service is running. - return rescheduleOnIOServiceWithReference( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ProducerInterestCallback &socket_option_value) -> int { - switch (socket_option_key) { - case ProducerCallbacksOptions::INTEREST_INPUT: - socket_option_value = on_interest_input_decrypted_; - break; - - case ProducerCallbacksOptions::INTEREST_DROP: - socket_option_value = on_interest_dropped_input_buffer_; - break; - - case ProducerCallbacksOptions::INTEREST_PASS: - socket_option_value = on_interest_inserted_input_buffer_; - break; - - case ProducerCallbacksOptions::CACHE_HIT: - socket_option_value = on_interest_satisfied_output_buffer_; - break; - - case ProducerCallbacksOptions::CACHE_MISS: - socket_option_value = on_interest_process_decrypted_; - break; - - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - }); -} - } // namespace implementation } // namespace transport diff --git a/libtransport/src/implementation/tls_socket_producer.h b/libtransport/src/implementation/tls_socket_producer.h index 2382e8695..a542a4d9f 100644 --- a/libtransport/src/implementation/tls_socket_producer.h +++ b/libtransport/src/implementation/tls_socket_producer.h @@ -16,8 +16,8 @@ #pragma once #include <implementation/socket_producer.h> - #include <openssl/ssl.h> + #include <condition_variable> #include <mutex> @@ -36,26 +36,18 @@ class TLSProducerSocket : virtual public ProducerSocket { ~TLSProducerSocket(); - uint32_t produce(Name content_name, const uint8_t *buffer, size_t buffer_size, - bool is_last = true, uint32_t start_offset = 0) override { - return produce(content_name, utils::MemBuf::copyBuffer(buffer, buffer_size), - is_last, start_offset); + uint32_t produceStream(const Name &content_name, const uint8_t *buffer, + size_t buffer_size, bool is_last = true, + uint32_t start_offset = 0) override { + return produceStream(content_name, + utils::MemBuf::copyBuffer(buffer, buffer_size), + is_last, start_offset); } - uint32_t produce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last = true, uint32_t start_offset = 0) override; - - void produce(ContentObject &content_object) override; - - void asyncProduce(const Name &suffix, const uint8_t *buf, size_t buffer_size, - bool is_last = true, - uint32_t *start_offset = nullptr) override; - - void asyncProduce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last, uint32_t offset, - uint32_t **last_segment = nullptr) override; - - void asyncProduce(ContentObject &content_object) override; + uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) override; virtual void accept(); @@ -80,7 +72,7 @@ class TLSProducerSocket : virtual public ProducerSocket { ProducerInterestCallback &socket_option_value); using ProducerSocket::getSocketOption; - using ProducerSocket::onInterest; + // using ProducerSocket::onInterest; using ProducerSocket::setSocketOption; protected: @@ -119,6 +111,7 @@ class TLSProducerSocket : virtual public ProducerSocket { int to_call_oncontentproduced_; bool still_writing_; utils::EventThread encryption_thread_; + utils::EventThread async_thread_; void onInterest(ProducerSocket &p, Interest &interest); diff --git a/libtransport/src/interfaces/CMakeLists.txt b/libtransport/src/interfaces/CMakeLists.txt index e1d144596..0284aa412 100644 --- a/libtransport/src/interfaces/CMakeLists.txt +++ b/libtransport/src/interfaces/CMakeLists.txt @@ -14,24 +14,24 @@ cmake_minimum_required(VERSION 3.5 FATAL_ERROR) list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/socket_consumer.cc ${CMAKE_CURRENT_SOURCE_DIR}/portal.cc ${CMAKE_CURRENT_SOURCE_DIR}/callbacks.cc + ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.cc ) if (${OPENSSL_VERSION} VERSION_EQUAL "1.1.1a" OR ${OPENSSL_VERSION} VERSION_GREATER "1.1.1a") list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc + # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.cc ) list(APPEND HEADER_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h + # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.h ) diff --git a/libtransport/src/interfaces/global_configuration.cc b/libtransport/src/interfaces/global_configuration.cc new file mode 100644 index 000000000..8fb6601f3 --- /dev/null +++ b/libtransport/src/interfaces/global_configuration.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/global_configuration.h> +#include <hicn/transport/interfaces/global_conf_interface.h> +#include <hicn/transport/utils/log.h> + +#include <system_error> + +namespace transport { +namespace interface { +namespace global_config { + +void parseConfigurationFile(const std::string& path) { + core::GlobalConfiguration::getInstance().parseConfiguration(path); +} + +void ConfigurationObject::get() { + std::error_code ec; + core::GlobalConfiguration::getInstance().getConfiguration(*this, ec); + + if (ec) { + TRANSPORT_LOGE("Error setting global config: %s", ec.message().c_str()); + } +} + +void ConfigurationObject::set() { + std::error_code ec; + core::GlobalConfiguration::getInstance().setConfiguration(*this, ec); + + if (ec) { + TRANSPORT_LOGE("Error setting global config: %s", ec.message().c_str()); + } +} + +} // namespace global_config +} // namespace interface +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/interfaces/p2psecure_socket_consumer.cc b/libtransport/src/interfaces/p2psecure_socket_consumer.cc index 038441dfc..e473a1e2e 100644 --- a/libtransport/src/interfaces/p2psecure_socket_consumer.cc +++ b/libtransport/src/interfaces/p2psecure_socket_consumer.cc @@ -14,7 +14,6 @@ */ #include <hicn/transport/interfaces/p2psecure_socket_consumer.h> - #include <implementation/p2psecure_socket_consumer.h> namespace transport { diff --git a/libtransport/src/interfaces/p2psecure_socket_producer.cc b/libtransport/src/interfaces/p2psecure_socket_producer.cc index 37352259c..10d8a1367 100644 --- a/libtransport/src/interfaces/p2psecure_socket_producer.cc +++ b/libtransport/src/interfaces/p2psecure_socket_producer.cc @@ -14,7 +14,6 @@ */ #include <hicn/transport/interfaces/p2psecure_socket_producer.h> - #include <implementation/p2psecure_socket_producer.h> namespace transport { @@ -25,7 +24,7 @@ P2PSecureProducerSocket::P2PSecureProducerSocket() { } P2PSecureProducerSocket::P2PSecureProducerSocket( - bool rtc, const std::shared_ptr<utils::Identity> &identity) { + bool rtc, const std::shared_ptr<auth::Identity> &identity) { socket_ = std::make_unique<implementation::P2PSecureProducerSocket>(this, rtc, identity); } diff --git a/libtransport/src/interfaces/portal.cc b/libtransport/src/interfaces/portal.cc index 36cbd0c3b..2ab51c4b9 100644 --- a/libtransport/src/interfaces/portal.cc +++ b/libtransport/src/interfaces/portal.cc @@ -14,38 +14,35 @@ */ #include <hicn/transport/interfaces/portal.h> - #include <implementation/socket.h> namespace transport { namespace interface { -using implementation::BasePortal; - -Portal::Portal() { implementation_ = new implementation::BasePortal(); } +Portal::Portal() { implementation_ = new core::Portal(); } Portal::Portal(asio::io_service &io_service) { - implementation_ = new BasePortal(io_service); + implementation_ = new core::Portal(io_service); } -Portal::~Portal() { delete reinterpret_cast<BasePortal *>(implementation_); } +Portal::~Portal() { delete reinterpret_cast<core::Portal *>(implementation_); } void Portal::setConsumerCallback(ConsumerCallback *consumer_callback) { - reinterpret_cast<BasePortal *>(implementation_) + reinterpret_cast<core::Portal *>(implementation_) ->setConsumerCallback(consumer_callback); } void Portal::setProducerCallback(ProducerCallback *producer_callback) { - reinterpret_cast<BasePortal *>(implementation_) + reinterpret_cast<core::Portal *>(implementation_) ->setProducerCallback(producer_callback); } void Portal::connect(bool is_consumer) { - reinterpret_cast<BasePortal *>(implementation_)->connect(is_consumer); + reinterpret_cast<core::Portal *>(implementation_)->connect(is_consumer); } bool Portal::interestIsPending(const core::Name &name) { - return reinterpret_cast<BasePortal *>(implementation_) + return reinterpret_cast<core::Portal *>(implementation_) ->interestIsPending(name); } @@ -53,46 +50,46 @@ void Portal::sendInterest( core::Interest::Ptr &&interest, OnContentObjectCallback &&on_content_object_callback, OnInterestTimeoutCallback &&on_interest_timeout_callback) { - reinterpret_cast<BasePortal *>(implementation_) + reinterpret_cast<core::Portal *>(implementation_) ->sendInterest(std::move(interest), std::move(on_content_object_callback), std::move(on_interest_timeout_callback)); } void Portal::bind(const BindConfig &config) { - reinterpret_cast<BasePortal *>(implementation_)->bind(config); + reinterpret_cast<core::Portal *>(implementation_)->bind(config); } void Portal::runEventsLoop() { - reinterpret_cast<BasePortal *>(implementation_)->runEventsLoop(); + reinterpret_cast<core::Portal *>(implementation_)->runEventsLoop(); } void Portal::runOneEvent() { - reinterpret_cast<BasePortal *>(implementation_)->runOneEvent(); + reinterpret_cast<core::Portal *>(implementation_)->runOneEvent(); } void Portal::sendContentObject(core::ContentObject &content_object) { - reinterpret_cast<BasePortal *>(implementation_) + reinterpret_cast<core::Portal *>(implementation_) ->sendContentObject(content_object); } void Portal::stopEventsLoop() { - reinterpret_cast<BasePortal *>(implementation_)->stopEventsLoop(); + reinterpret_cast<core::Portal *>(implementation_)->stopEventsLoop(); } void Portal::killConnection() { - reinterpret_cast<BasePortal *>(implementation_)->killConnection(); + reinterpret_cast<core::Portal *>(implementation_)->killConnection(); } void Portal::clear() { - reinterpret_cast<BasePortal *>(implementation_)->clear(); + reinterpret_cast<core::Portal *>(implementation_)->clear(); } asio::io_service &Portal::getIoService() { - return reinterpret_cast<BasePortal *>(implementation_)->getIoService(); + return reinterpret_cast<core::Portal *>(implementation_)->getIoService(); } void Portal::registerRoute(core::Prefix &prefix) { - reinterpret_cast<BasePortal *>(implementation_)->registerRoute(prefix); + reinterpret_cast<core::Portal *>(implementation_)->registerRoute(prefix); } } // namespace interface diff --git a/libtransport/src/interfaces/socket_consumer.cc b/libtransport/src/interfaces/socket_consumer.cc index ea0606347..4eee73cab 100644 --- a/libtransport/src/interfaces/socket_consumer.cc +++ b/libtransport/src/interfaces/socket_consumer.cc @@ -46,8 +46,6 @@ void ConsumerSocket::stop() { socket_->stop(); } void ConsumerSocket::resume() { socket_->resume(); } -bool ConsumerSocket::verifyKeyPackets() { return socket_->verifyKeyPackets(); } - asio::io_service &ConsumerSocket::getIoService() { return socket_->getIoService(); } @@ -88,22 +86,10 @@ int ConsumerSocket::setSocketOption( } int ConsumerSocket::setSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationCallback socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - -int ConsumerSocket::setSocketOption( int socket_option_key, ConsumerInterestCallback socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } -int ConsumerSocket::setSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationFailedCallback socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - int ConsumerSocket::setSocketOption(int socket_option_key, IcnObserver *socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); @@ -111,7 +97,7 @@ int ConsumerSocket::setSocketOption(int socket_option_key, int ConsumerSocket::setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Verifier> &socket_option_value) { + const std::shared_ptr<auth::Verifier> &socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } @@ -152,22 +138,10 @@ int ConsumerSocket::getSocketOption( } int ConsumerSocket::getSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationCallback **socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - -int ConsumerSocket::getSocketOption( int socket_option_key, ConsumerInterestCallback **socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } -int ConsumerSocket::getSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationFailedCallback **socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - int ConsumerSocket::getSocketOption(int socket_option_key, IcnObserver **socket_option_value) { return socket_->getSocketOption(socket_option_key, socket_option_value); @@ -175,7 +149,7 @@ int ConsumerSocket::getSocketOption(int socket_option_key, int ConsumerSocket::getSocketOption( int socket_option_key, - std::shared_ptr<utils::Verifier> &socket_option_value) { + std::shared_ptr<auth::Verifier> &socket_option_value) { return socket_->getSocketOption(socket_option_key, socket_option_value); } diff --git a/libtransport/src/interfaces/socket_producer.cc b/libtransport/src/interfaces/socket_producer.cc index d030fe756..b04947dfd 100644 --- a/libtransport/src/interfaces/socket_producer.cc +++ b/libtransport/src/interfaces/socket_producer.cc @@ -14,7 +14,6 @@ */ #include <hicn/transport/interfaces/socket_producer.h> - #include <implementation/socket_producer.h> #include <atomic> @@ -31,11 +30,12 @@ namespace interface { using namespace core; ProducerSocket::ProducerSocket(int protocol) { - if (protocol != 0) { - throw std::runtime_error("Production protocol must be 0."); - } + socket_ = std::make_unique<implementation::ProducerSocket>(this, protocol); +} - socket_ = std::make_unique<implementation::ProducerSocket>(this); +ProducerSocket::ProducerSocket(int protocol, asio::io_service &io_service) { + socket_ = std::make_unique<implementation::ProducerSocket>(this, protocol, + io_service); } ProducerSocket::ProducerSocket(bool) {} @@ -46,19 +46,34 @@ void ProducerSocket::connect() { socket_->connect(); } bool ProducerSocket::isRunning() { return socket_->isRunning(); } -uint32_t ProducerSocket::produce(Name content_name, - std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last, uint32_t start_offset) { - return socket_->produce(content_name, std::move(buffer), is_last, - start_offset); +uint32_t ProducerSocket::produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { + return socket_->produceStream(content_name, std::move(buffer), is_last, + start_offset); } -void ProducerSocket::produce(ContentObject &content_object) { - return socket_->produce(content_object); +uint32_t ProducerSocket::produceStream(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size, bool is_last, + uint32_t start_offset) { + return socket_->produceStream(content_name, buffer, buffer_size, is_last, + start_offset); } -void ProducerSocket::produce(std::unique_ptr<utils::MemBuf> &&buffer) { - socket_->produce(std::move(buffer)); +uint32_t ProducerSocket::produceDatagram( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer) { + return socket_->produceDatagram(content_name, std::move(buffer)); +} + +uint32_t ProducerSocket::produceDatagram(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size) { + return socket_->produceDatagram(content_name, buffer, buffer_size); +} + +void ProducerSocket::produce(ContentObject &content_object) { + return socket_->produce(content_object); } void ProducerSocket::asyncProduce(Name content_name, @@ -69,16 +84,10 @@ void ProducerSocket::asyncProduce(Name content_name, last_segment); } -void ProducerSocket::asyncProduce(ContentObject &content_object) { - return socket_->asyncProduce(content_object); -} - void ProducerSocket::registerPrefix(const Prefix &producer_namespace) { return socket_->registerPrefix(producer_namespace); } -void ProducerSocket::serveForever() { return socket_->serveForever(); } - void ProducerSocket::stop() { return socket_->stop(); } asio::io_service &ProducerSocket::getIoService() { @@ -105,11 +114,6 @@ int ProducerSocket::setSocketOption(int socket_option_key, return socket_->setSocketOption(socket_option_key, socket_option_value); } -int ProducerSocket::setSocketOption(int socket_option_key, - std::list<Prefix> socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - int ProducerSocket::setSocketOption( int socket_option_key, ProducerContentObjectCallback socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); @@ -126,18 +130,13 @@ int ProducerSocket::setSocketOption( } int ProducerSocket::setSocketOption(int socket_option_key, - utils::CryptoHashType socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - -int ProducerSocket::setSocketOption(int socket_option_key, - utils::CryptoSuite socket_option_value) { + auth::CryptoHashType socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } int ProducerSocket::setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Signer> &socket_option_value) { + const std::shared_ptr<auth::Signer> &socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } @@ -156,11 +155,6 @@ int ProducerSocket::getSocketOption(int socket_option_key, return socket_->getSocketOption(socket_option_key, socket_option_value); } -int ProducerSocket::getSocketOption(int socket_option_key, - std::list<Prefix> &socket_option_value) { - return socket_->getSocketOption(socket_option_key, socket_option_value); -} - int ProducerSocket::getSocketOption( int socket_option_key, ProducerContentObjectCallback **socket_option_value) { @@ -177,19 +171,13 @@ int ProducerSocket::getSocketOption( return socket_->getSocketOption(socket_option_key, socket_option_value); } -int ProducerSocket::getSocketOption( - int socket_option_key, utils::CryptoHashType &socket_option_value) { - return socket_->getSocketOption(socket_option_key, socket_option_value); -} - int ProducerSocket::getSocketOption(int socket_option_key, - utils::CryptoSuite &socket_option_value) { + auth::CryptoHashType &socket_option_value) { return socket_->getSocketOption(socket_option_key, socket_option_value); } int ProducerSocket::getSocketOption( - int socket_option_key, - std::shared_ptr<utils::Signer> &socket_option_value) { + int socket_option_key, std::shared_ptr<auth::Signer> &socket_option_value) { return socket_->getSocketOption(socket_option_key, socket_option_value); } diff --git a/libtransport/src/interfaces/tls_rtc_socket_producer.cc b/libtransport/src/interfaces/tls_rtc_socket_producer.cc index 132f34721..7326fcbcb 100644 --- a/libtransport/src/interfaces/tls_rtc_socket_producer.cc +++ b/libtransport/src/interfaces/tls_rtc_socket_producer.cc @@ -13,9 +13,8 @@ * limitations under the License. */ -#include <interfaces/tls_rtc_socket_producer.h> - #include <implementation/tls_rtc_socket_producer.h> +#include <interfaces/tls_rtc_socket_producer.h> namespace transport { namespace interface { diff --git a/libtransport/src/interfaces/tls_socket_consumer.cc b/libtransport/src/interfaces/tls_socket_consumer.cc index d87642f73..6c1c535b5 100644 --- a/libtransport/src/interfaces/tls_socket_consumer.cc +++ b/libtransport/src/interfaces/tls_socket_consumer.cc @@ -13,9 +13,8 @@ * limitations under the License. */ -#include <interfaces/tls_socket_consumer.h> - #include <implementation/tls_socket_consumer.h> +#include <interfaces/tls_socket_consumer.h> namespace transport { namespace interface { diff --git a/libtransport/src/interfaces/tls_socket_producer.cc b/libtransport/src/interfaces/tls_socket_producer.cc index 44aa0cf8b..037702f72 100644 --- a/libtransport/src/interfaces/tls_socket_producer.cc +++ b/libtransport/src/interfaces/tls_socket_producer.cc @@ -13,9 +13,8 @@ * limitations under the License. */ -#include <interfaces/tls_socket_producer.h> - #include <implementation/tls_socket_producer.h> +#include <interfaces/tls_socket_producer.h> namespace transport { namespace interface { diff --git a/libtransport/src/io_modules/CMakeLists.txt b/libtransport/src/io_modules/CMakeLists.txt new file mode 100644 index 000000000..6553b9a2b --- /dev/null +++ b/libtransport/src/io_modules/CMakeLists.txt @@ -0,0 +1,37 @@ +# Copyright (c) 2021 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + +if (${CMAKE_SYSTEM_NAME} MATCHES Android) + list(APPEND SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/udp/hicn_forwarder_module.cc + ${CMAKE_CURRENT_SOURCE_DIR}/udp/udp_socket_connector.cc + ) + + list(APPEND HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/udp/hicn_forwarder_module.h + ${CMAKE_CURRENT_SOURCE_DIR}/udp/udp_socket_connector.h + ) + + set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) + set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE) +else() + add_subdirectory(udp) + add_subdirectory(loopback) + add_subdirectory(forwarder) + + if (__vpp__) + add_subdirectory(memif) + endif() +endif()
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/CMakeLists.txt b/libtransport/src/io_modules/forwarder/CMakeLists.txt new file mode 100644 index 000000000..92662bc4c --- /dev/null +++ b/libtransport/src/io_modules/forwarder/CMakeLists.txt @@ -0,0 +1,44 @@ +# Copyright (c) 2021 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + + +list(APPEND MODULE_HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/connector.h + ${CMAKE_CURRENT_SOURCE_DIR}/endpoint.h + ${CMAKE_CURRENT_SOURCE_DIR}/errors.h + ${CMAKE_CURRENT_SOURCE_DIR}/forwarder_module.h + ${CMAKE_CURRENT_SOURCE_DIR}/forwarder.h + ${CMAKE_CURRENT_SOURCE_DIR}/udp_tunnel_listener.h + ${CMAKE_CURRENT_SOURCE_DIR}/udp_tunnel.h + ${CMAKE_CURRENT_SOURCE_DIR}/global_counter.h +) + +list(APPEND MODULE_SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/errors.cc + ${CMAKE_CURRENT_SOURCE_DIR}/forwarder_module.cc + ${CMAKE_CURRENT_SOURCE_DIR}/forwarder.cc + ${CMAKE_CURRENT_SOURCE_DIR}/udp_tunnel_listener.cc + ${CMAKE_CURRENT_SOURCE_DIR}/udp_tunnel.cc +) + +build_module(forwarder_module + SHARED + SOURCES ${MODULE_SOURCE_FILES} + DEPENDS ${DEPENDENCIES} + COMPONENT lib${LIBTRANSPORT} + INCLUDE_DIRS ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} + DEFINITIONS ${COMPILER_DEFINITIONS} + COMPILE_OPTIONS ${COMPILE_FLAGS} +) diff --git a/libtransport/src/io_modules/forwarder/configuration.h b/libtransport/src/io_modules/forwarder/configuration.h new file mode 100644 index 000000000..fcaa5530d --- /dev/null +++ b/libtransport/src/io_modules/forwarder/configuration.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace transport { +namespace core { + +struct ListenerConfig { + std::string address; + std::uint16_t port; + std::string name; +}; + +struct ConnectorConfig { + std::string local_address; + std::uint16_t local_port; + std::string remote_address; + std::uint16_t remote_port; + std::string name; +}; + +struct RouteConfig { + std::string prefix; + uint16_t weight; + std::string connector; + std::string name; +}; + +class Configuration { + public: + Configuration() : n_threads_(1) {} + + bool empty() { + return listeners_.empty() && connectors_.empty() && routes_.empty(); + } + + Configuration& setThreadNumber(std::size_t threads) { + n_threads_ = threads; + return *this; + } + + std::size_t getThreadNumber() { return n_threads_; } + + template <typename... Args> + Configuration& addListener(Args&&... args) { + listeners_.emplace_back(std::forward<Args>(args)...); + return *this; + } + + template <typename... Args> + Configuration& addConnector(Args&&... args) { + connectors_.emplace_back(std::forward<Args>(args)...); + return *this; + } + + template <typename... Args> + Configuration& addRoute(Args&&... args) { + routes_.emplace_back(std::forward<Args>(args)...); + return *this; + } + + std::vector<ListenerConfig>& getListeners() { return listeners_; } + + std::vector<ConnectorConfig>& getConnectors() { return connectors_; } + + std::vector<RouteConfig>& getRoutes() { return routes_; } + + private: + std::vector<ListenerConfig> listeners_; + std::vector<ConnectorConfig> connectors_; + std::vector<RouteConfig> routes_; + std::size_t n_threads_; +}; + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/errors.cc b/libtransport/src/io_modules/forwarder/errors.cc new file mode 100644 index 000000000..b5f131499 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/errors.cc @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2019 Cisco and/or its affiliates. + */ + +#include <io_modules/forwarder/errors.h> + +namespace transport { +namespace core { + +const std::error_category& forwarder_category() { + static forwarder_category_impl instance; + + return instance; +} + +const char* forwarder_category_impl::name() const throw() { + return "proxy::connector::error"; +} + +std::string forwarder_category_impl::message(int ev) const { + switch (static_cast<forwarder_error>(ev)) { + case forwarder_error::success: { + return "Success"; + } + case forwarder_error::disconnected: { + return "Connector is disconnected"; + } + case forwarder_error::receive_failed: { + return "Packet reception failed"; + } + case forwarder_error::send_failed: { + return "Packet send failed"; + } + case forwarder_error::memory_allocation_error: { + return "Impossible to allocate memory for packet pool"; + } + case forwarder_error::invalid_connector_type: { + return "Invalid type specified for connector."; + } + case forwarder_error::invalid_connector: { + return "Created connector was invalid."; + } + case forwarder_error::interest_cache_miss: { + return "interest cache miss."; + } + default: { + return "Unknown connector error"; + } + } +} +} // namespace core +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/errors.h b/libtransport/src/io_modules/forwarder/errors.h new file mode 100644 index 000000000..dd5cc8fe7 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/errors.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <string> +#include <system_error> + +namespace transport { +namespace core { +/** + * @brief Get the default server error category. + * @return The default server error category instance. + * + * @warning The first call to this function is thread-safe only starting with + * C++11. + */ +const std::error_category& forwarder_category(); + +/** + * The list of errors. + */ +enum class forwarder_error { + success = 0, + send_failed, + receive_failed, + disconnected, + memory_allocation_error, + invalid_connector_type, + invalid_connector, + interest_cache_miss +}; + +/** + * @brief Create an error_code instance for the given error. + * @param error The error. + * @return The error_code instance. + */ +inline std::error_code make_error_code(forwarder_error error) { + return std::error_code(static_cast<int>(error), forwarder_category()); +} + +/** + * @brief Create an error_condition instance for the given error. + * @param error The error. + * @return The error_condition instance. + */ +inline std::error_condition make_error_condition(forwarder_error error) { + return std::error_condition(static_cast<int>(error), forwarder_category()); +} + +/** + * @brief A server error category. + */ +class forwarder_category_impl : public std::error_category { + public: + /** + * @brief Get the name of the category. + * @return The name of the category. + */ + virtual const char* name() const throw(); + + /** + * @brief Get the error message for a given error. + * @param ev The error numeric value. + * @return The message associated to the error. + */ + virtual std::string message(int ev) const; +}; +} // namespace core +} // namespace transport + +namespace std { +// namespace system { +template <> +struct is_error_code_enum<::transport::core::forwarder_error> + : public std::true_type {}; +// } // namespace system +} // namespace std diff --git a/libtransport/src/io_modules/forwarder/forwarder.cc b/libtransport/src/io_modules/forwarder/forwarder.cc new file mode 100644 index 000000000..7e89e2f9f --- /dev/null +++ b/libtransport/src/io_modules/forwarder/forwarder.cc @@ -0,0 +1,296 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/global_configuration.h> +#include <core/local_connector.h> +#include <io_modules/forwarder/forwarder.h> +#include <io_modules/forwarder/global_id_counter.h> +#include <io_modules/forwarder/udp_tunnel.h> +#include <io_modules/forwarder/udp_tunnel_listener.h> + +namespace transport { + +namespace core { + +constexpr char Forwarder::forwarder_config_section[]; + +Forwarder::Forwarder() : config_() { + using namespace std::placeholders; + GlobalConfiguration::getInstance().registerConfigurationParser( + forwarder_config_section, + std::bind(&Forwarder::parseForwarderConfiguration, this, _1, _2)); + + if (!config_.empty()) { + initThreads(); + initListeners(); + initConnectors(); + } +} + +Forwarder::~Forwarder() { + for (auto &l : listeners_) { + l->close(); + } + + for (auto &c : remote_connectors_) { + c.second->close(); + } + + GlobalConfiguration::getInstance().unregisterConfigurationParser( + forwarder_config_section); +} + +void Forwarder::initThreads() { + for (unsigned i = 0; i < config_.getThreadNumber(); i++) { + thread_pool_.emplace_back(io_service_, /* detached */ false); + } +} + +void Forwarder::initListeners() { + using namespace std::placeholders; + for (auto &l : config_.getListeners()) { + listeners_.emplace_back(std::make_shared<UdpTunnelListener>( + io_service_, + std::bind(&Forwarder::onPacketFromListener, this, _1, _2, _3), + asio::ip::udp::endpoint(asio::ip::address::from_string(l.address), + l.port))); + } +} + +void Forwarder::initConnectors() { + using namespace std::placeholders; + for (auto &c : config_.getConnectors()) { + auto id = GlobalCounter<Connector::Id>::getInstance().getNext(); + auto conn = new UdpTunnelConnector( + io_service_, std::bind(&Forwarder::onPacketReceived, this, _1, _2, _3), + std::bind(&Forwarder::onPacketSent, this, _1, _2), + std::bind(&Forwarder::onConnectorClosed, this, _1), + std::bind(&Forwarder::onConnectorReconnected, this, _1)); + conn->setConnectorId(id); + remote_connectors_.emplace(id, conn); + conn->connect(c.remote_address, c.remote_port, c.local_address, + c.local_port); + } +} + +Connector::Id Forwarder::registerLocalConnector( + asio::io_service &io_service, + Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback) { + utils::SpinLock::Acquire locked(connector_lock_); + auto id = GlobalCounter<Connector::Id>::getInstance().getNext(); + auto connector = std::make_shared<LocalConnector>( + io_service, receive_callback, nullptr, nullptr, reconnect_callback); + connector->setConnectorId(id); + local_connectors_.emplace(id, std::move(connector)); + return id; +} + +Forwarder &Forwarder::deleteConnector(Connector::Id id) { + utils::SpinLock::Acquire locked(connector_lock_); + auto it = local_connectors_.find(id); + if (it != local_connectors_.end()) { + it->second->close(); + local_connectors_.erase(it); + } + + return *this; +} + +Connector::Ptr Forwarder::getConnector(Connector::Id id) { + utils::SpinLock::Acquire locked(connector_lock_); + auto it = local_connectors_.find(id); + if (it != local_connectors_.end()) { + return it->second; + } + + return nullptr; +} + +void Forwarder::onPacketFromListener(Connector *connector, + utils::MemBuf &packet_buffer, + const std::error_code &ec) { + // Create connector + connector->setReceiveCallback( + std::bind(&Forwarder::onPacketReceived, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3)); + + TRANSPORT_LOGD("Packet received from listener."); + + { + utils::SpinLock::Acquire locked(connector_lock_); + remote_connectors_.emplace(connector->getConnectorId(), + connector->shared_from_this()); + } + // TODO Check if control packet or not. For the moment it is not. + onPacketReceived(connector, packet_buffer, ec); +} + +void Forwarder::onPacketReceived(Connector *connector, + utils::MemBuf &packet_buffer, + const std::error_code &ec) { + // Figure out the type of packet we received + bool is_interest = Packet::isInterest(packet_buffer.data()); + + Packet *packet = nullptr; + if (is_interest) { + packet = static_cast<Interest *>(&packet_buffer); + } else { + packet = static_cast<ContentObject *>(&packet_buffer); + } + + for (auto &c : local_connectors_) { + auto role = c.second->getRole(); + auto is_producer = role == Connector::Role::PRODUCER; + if ((is_producer && is_interest) || (!is_producer && !is_interest)) { + c.second->send(*packet); + } else { + TRANSPORT_LOGD( + "Error sending packet to local connector. is_interest = %d - " + "is_producer = %d", + (int)is_interest, (int)is_producer); + } + } + + // PCS Lookup + FIB lookup. Skip for now + + // Forward packet to local connectors +} + +void Forwarder::send(Packet &packet) { + // TODo Here a nice PIT/CS / FIB would be required:) + // For now let's just forward the packet on the remote connector we get + if (remote_connectors_.begin() == remote_connectors_.end()) { + return; + } + + auto remote_endpoint = + remote_connectors_.begin()->second->getRemoteEndpoint(); + TRANSPORT_LOGD("Sending packet to: %s:%u", + remote_endpoint.getAddress().to_string().c_str(), + remote_endpoint.getPort()); + remote_connectors_.begin()->second->send(packet); +} + +void Forwarder::onPacketSent(Connector *connector, const std::error_code &ec) {} + +void Forwarder::onConnectorClosed(Connector *connector) {} + +void Forwarder::onConnectorReconnected(Connector *connector) {} + +void Forwarder::parseForwarderConfiguration( + const libconfig::Setting &forwarder_config, std::error_code &ec) { + using namespace libconfig; + + // n_thread + if (forwarder_config.exists("n_threads")) { + // Get number of threads + int n_threads = 1; + forwarder_config.lookupValue("n_threads", n_threads); + TRANSPORT_LOGD("Forwarder threads from config file: %u", n_threads); + config_.setThreadNumber(n_threads); + } + + // listeners + if (forwarder_config.exists("listeners")) { + // get path where looking for modules + const Setting &listeners = forwarder_config.lookup("listeners"); + auto count = listeners.getLength(); + + for (int i = 0; i < count; i++) { + const Setting &listener = listeners[i]; + ListenerConfig list; + unsigned port; + + list.name = listener.getName(); + listener.lookupValue("local_address", list.address); + listener.lookupValue("local_port", port); + list.port = (uint16_t)(port); + + TRANSPORT_LOGD("Adding listener %s, (%s:%u)", list.name.c_str(), + list.address.c_str(), list.port); + config_.addListener(std::move(list)); + } + } + + // connectors + if (forwarder_config.exists("connectors")) { + // get path where looking for modules + const Setting &connectors = forwarder_config.lookup("connectors"); + auto count = connectors.getLength(); + + for (int i = 0; i < count; i++) { + const Setting &connector = connectors[i]; + ConnectorConfig conn; + + conn.name = connector.getName(); + unsigned port = 0; + + if (!connector.lookupValue("local_address", conn.local_address)) { + conn.local_address = ""; + } + + if (!connector.lookupValue("local_port", port)) { + port = 0; + } + + conn.local_port = (uint16_t)(port); + + if (!connector.lookupValue("remote_address", conn.remote_address)) { + throw errors::RuntimeException( + "Error in configuration file: remote_address is a mandatory field " + "of Connectors."); + } + + if (!connector.lookupValue("remote_port", port)) { + throw errors::RuntimeException( + "Error in configuration file: remote_port is a mandatory field " + "of Connectors."); + } + + conn.remote_port = (uint16_t)(port); + + TRANSPORT_LOGD("Adding connector %s, (%s:%u %s:%u)", conn.name.c_str(), + conn.local_address.c_str(), conn.local_port, + conn.remote_address.c_str(), conn.remote_port); + config_.addConnector(std::move(conn)); + } + } + + // Routes + if (forwarder_config.exists("routes")) { + const Setting &routes = forwarder_config.lookup("routes"); + auto count = routes.getLength(); + + for (int i = 0; i < count; i++) { + const Setting &route = routes[i]; + RouteConfig r; + unsigned weight; + + r.name = route.getName(); + route.lookupValue("prefix", r.prefix); + route.lookupValue("weight", weight); + route.lookupValue("connector", r.connector); + r.weight = (uint16_t)(weight); + + TRANSPORT_LOGD("Adding route %s %s (%s %u)", r.name.c_str(), + r.prefix.c_str(), r.connector.c_str(), r.weight); + config_.addRoute(std::move(r)); + } + } +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/forwarder.h b/libtransport/src/io_modules/forwarder/forwarder.h new file mode 100644 index 000000000..5b564bb5e --- /dev/null +++ b/libtransport/src/io_modules/forwarder/forwarder.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/core/prefix.h> +#include <hicn/transport/utils/event_thread.h> +#include <hicn/transport/utils/singleton.h> +#include <hicn/transport/utils/spinlock.h> +#include <io_modules/forwarder/configuration.h> +#include <io_modules/forwarder/udp_tunnel_listener.h> + +#include <atomic> +#include <libconfig.h++> +#include <unordered_map> + +namespace transport { + +namespace core { + +class Forwarder : public utils::Singleton<Forwarder> { + static constexpr char forwarder_config_section[] = "forwarder"; + friend class utils::Singleton<Forwarder>; + + public: + Forwarder(); + + ~Forwarder(); + + void initThreads(); + void initListeners(); + void initConnectors(); + + Connector::Id registerLocalConnector( + asio::io_service &io_service, + Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback); + + Forwarder &deleteConnector(Connector::Id id); + + Connector::Ptr getConnector(Connector::Id id); + + void send(Packet &packet); + + void stop(); + + private: + void onPacketFromListener(Connector *connector, utils::MemBuf &packet_buffer, + const std::error_code &ec); + void onPacketReceived(Connector *connector, utils::MemBuf &packet_buffer, + const std::error_code &ec); + void onPacketSent(Connector *connector, const std::error_code &ec); + void onConnectorClosed(Connector *connector); + void onConnectorReconnected(Connector *connector); + + void parseForwarderConfiguration(const libconfig::Setting &io_config, + std::error_code &ec); + + asio::io_service io_service_; + utils::SpinLock connector_lock_; + + /** + * Connectors and listeners must be declares *before* thread_pool_, so that + * threads destructors will wait for them to gracefully close before being + * destroyed. + */ + std::unordered_map<Connector::Id, Connector::Ptr> remote_connectors_; + std::unordered_map<Connector::Id, Connector::Ptr> local_connectors_; + std::vector<UdpTunnelListener::Ptr> listeners_; + + std::vector<utils::EventThread> thread_pool_; + + Configuration config_; +}; + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/forwarder_module.cc b/libtransport/src/io_modules/forwarder/forwarder_module.cc new file mode 100644 index 000000000..356b42d3b --- /dev/null +++ b/libtransport/src/io_modules/forwarder/forwarder_module.cc @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/errors/not_implemented_exception.h> +#include <hicn/transport/utils/log.h> +#include <io_modules/forwarder/forwarder_module.h> + +namespace transport { + +namespace core { + +ForwarderModule::ForwarderModule() + : IoModule(), + name_(""), + connector_id_(Connector::invalid_connector), + forwarder_(Forwarder::getInstance()) {} + +ForwarderModule::~ForwarderModule() { + forwarder_.deleteConnector(connector_id_); +} + +bool ForwarderModule::isConnected() { return true; } + +void ForwarderModule::send(Packet &packet) { + IoModule::send(packet); + forwarder_.send(packet); + // TRANSPORT_LOGD("ForwarderModule: sending from %u to %d", local_id_, + // 1 - local_id_); + + // local_faces_.at(1 - local_id_).onPacket(packet); +} + +void ForwarderModule::send(const uint8_t *packet, std::size_t len) { + // not supported + throw errors::NotImplementedException(); +} + +void ForwarderModule::registerRoute(const Prefix &prefix) { + // For the moment we route packets from one socket to the other. + // Next step will be to introduce a FIB + return; +} + +void ForwarderModule::closeConnection() { + forwarder_.deleteConnector(connector_id_); +} + +void ForwarderModule::init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name) { + connector_id_ = forwarder_.registerLocalConnector( + io_service, std::move(receive_callback), std::move(reconnect_callback)); + name_ = app_name; +} + +void ForwarderModule::processControlMessageReply(utils::MemBuf &packet_buffer) { + return; +} + +void ForwarderModule::connect(bool is_consumer) { + forwarder_.getConnector(connector_id_) + ->setRole(is_consumer ? Connector::Role::CONSUMER + : Connector::Role::PRODUCER); +} + +std::uint32_t ForwarderModule::getMtu() { return interface_mtu; } + +bool ForwarderModule::isControlMessage(const uint8_t *message) { return false; } + +extern "C" IoModule *create_module(void) { return new ForwarderModule(); } + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/forwarder_module.h b/libtransport/src/io_modules/forwarder/forwarder_module.h new file mode 100644 index 000000000..58bfb7996 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/forwarder_module.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/core/prefix.h> +#include <io_modules/forwarder/forwarder.h> + +#include <atomic> + +namespace transport { + +namespace core { + +class Forwarder; + +class ForwarderModule : public IoModule { + static constexpr std::uint16_t interface_mtu = 1500; + + public: + ForwarderModule(); + + ~ForwarderModule(); + + void connect(bool is_consumer) override; + + void send(Packet &packet) override; + void send(const uint8_t *packet, std::size_t len) override; + + bool isConnected() override; + + void init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name = "Libtransport") override; + + void registerRoute(const Prefix &prefix) override; + + std::uint32_t getMtu() override; + + bool isControlMessage(const uint8_t *message) override; + + void processControlMessageReply(utils::MemBuf &packet_buffer) override; + + void closeConnection() override; + + private: + std::string name_; + Connector::Id connector_id_; + Forwarder &forwarder_; +}; + +extern "C" IoModule *create_module(void); + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/global_id_counter.h b/libtransport/src/io_modules/forwarder/global_id_counter.h new file mode 100644 index 000000000..fe8d76730 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/global_id_counter.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <atomic> +#include <mutex> + +namespace transport { + +namespace core { + +template <typename T = uint64_t> +class GlobalCounter { + public: + static GlobalCounter& getInstance() { + std::lock_guard<std::mutex> lock(global_mutex_); + + if (!instance_) { + instance_.reset(new GlobalCounter()); + } + + return *instance_; + } + + T getNext() { return counter_++; } + + private: + GlobalCounter() : counter_(0) {} + static std::unique_ptr<GlobalCounter<T>> instance_; + static std::mutex global_mutex_; + std::atomic<T> counter_; +}; + +template <typename T> +std::unique_ptr<GlobalCounter<T>> GlobalCounter<T>::instance_ = nullptr; + +template <typename T> +std::mutex GlobalCounter<T>::global_mutex_; + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/udp_tunnel.cc b/libtransport/src/io_modules/forwarder/udp_tunnel.cc new file mode 100644 index 000000000..dc725fc4e --- /dev/null +++ b/libtransport/src/io_modules/forwarder/udp_tunnel.cc @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + */ + +#include <hicn/transport/utils/branch_prediction.h> +#include <io_modules/forwarder/errors.h> +#include <io_modules/forwarder/udp_tunnel.h> + +#include <iostream> +#include <thread> +#include <vector> + +namespace transport { +namespace core { + +UdpTunnelConnector::~UdpTunnelConnector() {} + +void UdpTunnelConnector::connect(const std::string &hostname, uint16_t port, + const std::string &bind_address, + uint16_t bind_port) { + if (state_ == State::CLOSED) { + state_ = State::CONNECTING; + endpoint_iterator_ = resolver_.resolve({hostname, std::to_string(port)}); + remote_endpoint_send_ = *endpoint_iterator_; + socket_->open(remote_endpoint_send_.protocol()); + + if (!bind_address.empty() && bind_port != 0) { + using namespace asio::ip; + socket_->bind( + udp::endpoint(address::from_string(bind_address), bind_port)); + } + + state_ = State::CONNECTED; + + remote_endpoint_ = Endpoint(remote_endpoint_send_); + local_endpoint_ = Endpoint(socket_->local_endpoint()); + + doRecvPacket(); + +#ifdef LINUX + send_timer_.expires_from_now(std::chrono::microseconds(50)); + send_timer_.async_wait(std::bind(&UdpTunnelConnector::writeHandler, this, + std::placeholders::_1)); +#endif + } +} + +void UdpTunnelConnector::send(Packet &packet) { + strand_->post([this, pkt{packet.shared_from_this()}]() { + bool write_in_progress = !output_buffer_.empty(); + output_buffer_.push_back(std::move(pkt)); + if (TRANSPORT_EXPECT_TRUE(state_ == State::CONNECTED)) { + if (!write_in_progress) { + doSendPacket(); + } + } else { + data_available_ = true; + } + }); +} + +void UdpTunnelConnector::send(const uint8_t *packet, std::size_t len) {} + +void UdpTunnelConnector::close() { + TRANSPORT_LOGD("UDPTunnelConnector::close"); + state_ = State::CLOSED; + bool is_socket_owned = socket_.use_count() == 1; + if (is_socket_owned) { + io_service_.dispatch([this]() { + this->socket_->close(); + // on_close_callback_(shared_from_this()); + }); + } +} + +void UdpTunnelConnector::doSendPacket() { +#ifdef LINUX + send_timer_.expires_from_now(std::chrono::microseconds(50)); + send_timer_.async_wait(std::bind(&UdpTunnelConnector::writeHandler, this, + std::placeholders::_1)); +#else + auto packet = output_buffer_.front().get(); + auto array = std::vector<asio::const_buffer>(); + + const ::utils::MemBuf *current = packet; + do { + array.push_back(asio::const_buffer(current->data(), current->length())); + current = current->next(); + } while (current != packet); + + socket_->async_send_to( + std::move(array), remote_endpoint_send_, + strand_->wrap([this](std::error_code ec, std::size_t length) { + if (TRANSPORT_EXPECT_TRUE(!ec)) { + sent_callback_(this, make_error_code(forwarder_error::success)); + } else if (ec.value() == + static_cast<int>(std::errc::operation_canceled)) { + // The connection has been closed by the application. + return; + } else { + sendFailed(); + sent_callback_(this, ec); + } + + output_buffer_.pop_front(); + if (!output_buffer_.empty()) { + doSendPacket(); + } + })); +#endif +} + +#ifdef LINUX +void UdpTunnelConnector::writeHandler(std::error_code ec) { + if (TRANSPORT_EXPECT_FALSE(state_ != State::CONNECTED)) { + return; + } + + auto len = std::min(output_buffer_.size(), std::size_t(Connector::max_burst)); + + if (len) { + int m = 0; + for (auto &p : output_buffer_) { + auto packet = p.get(); + ::utils::MemBuf *current = packet; + int b = 0; + do { + // array.push_back(asio::const_buffer(current->data(), + // current->length())); + tx_iovecs_[m][b].iov_base = current->writableData(); + tx_iovecs_[m][b].iov_len = current->length(); + current = current->next(); + b++; + } while (current != packet); + + tx_msgs_[m].msg_hdr.msg_iov = tx_iovecs_[m]; + tx_msgs_[m].msg_hdr.msg_iovlen = b; + tx_msgs_[m].msg_hdr.msg_name = remote_endpoint_send_.data(); + tx_msgs_[m].msg_hdr.msg_namelen = remote_endpoint_send_.size(); + m++; + + if (--len == 0) { + break; + } + } + + int retval = sendmmsg(socket_->native_handle(), tx_msgs_, m, MSG_DONTWAIT); + if (retval > 0) { + while (retval--) { + output_buffer_.pop_front(); + } + } else if (retval != EWOULDBLOCK && retval != EAGAIN) { + TRANSPORT_LOGE("Error sending messages! %s %d\n", strerror(errno), + retval); + return; + } + } + + if (!output_buffer_.empty()) { + send_timer_.expires_from_now(std::chrono::microseconds(50)); + send_timer_.async_wait(std::bind(&UdpTunnelConnector::writeHandler, this, + std::placeholders::_1)); + } +} + +void UdpTunnelConnector::readHandler(std::error_code ec) { + TRANSPORT_LOGD("UdpTunnelConnector receive packet"); + + // TRANSPORT_LOGD("UdpTunnelConnector received packet length=%lu", length); + if (TRANSPORT_EXPECT_TRUE(!ec)) { + if (TRANSPORT_EXPECT_TRUE(state_ == State::CONNECTED)) { + if (current_position_ == 0) { + for (int i = 0; i < max_burst; i++) { + auto read_buffer = getRawBuffer(); + rx_iovecs_[i][0].iov_base = read_buffer.first; + rx_iovecs_[i][0].iov_len = read_buffer.second; + rx_msgs_[i].msg_hdr.msg_iov = rx_iovecs_[i]; + rx_msgs_[i].msg_hdr.msg_iovlen = 1; + } + } + + int res = recvmmsg(socket_->native_handle(), rx_msgs_ + current_position_, + max_burst - current_position_, MSG_DONTWAIT, nullptr); + if (res < 0) { + TRANSPORT_LOGE("Error receiving messages! %s %d\n", strerror(errno), + res); + return; + } + + for (int i = 0; i < res; i++) { + auto packet = getPacketFromBuffer( + reinterpret_cast<uint8_t *>( + rx_msgs_[current_position_].msg_hdr.msg_iov[0].iov_base), + rx_msgs_[current_position_].msg_len); + receiveSuccess(*packet); + receive_callback_(this, *packet, + make_error_code(forwarder_error::success)); + ++current_position_; + } + + doRecvPacket(); + } else { + TRANSPORT_LOGE( + "Error in UDP: Receiving packets from a not connected socket."); + } + } else if (ec.value() == static_cast<int>(std::errc::operation_canceled)) { + TRANSPORT_LOGE("The connection has been closed by the application."); + return; + } else { + if (TRANSPORT_EXPECT_TRUE(state_ == State::CONNECTED)) { + // receive_callback_(this, *read_msg_, ec); + TRANSPORT_LOGE("Error in UDP connector: %d %s", ec.value(), + ec.message().c_str()); + } else { + TRANSPORT_LOGE("Error while not connector"); + } + } +} +#endif + +void UdpTunnelConnector::doRecvPacket() { +#ifdef LINUX + if (state_ == State::CONNECTED) { +#if ((ASIO_VERSION / 100 % 1000) < 11) + socket_->async_receive(asio::null_buffers(), +#else + socket_->async_wait(asio::ip::tcp::socket::wait_read, +#endif + std::bind(&UdpTunnelConnector::readHandler, this, + std::placeholders::_1)); + } +#else + TRANSPORT_LOGD("UdpTunnelConnector receive packet"); + read_msg_ = getRawBuffer(); + socket_->async_receive_from( + asio::buffer(read_msg_.first, read_msg_.second), remote_endpoint_recv_, + [this](std::error_code ec, std::size_t length) { + TRANSPORT_LOGD("UdpTunnelConnector received packet length=%lu", length); + if (TRANSPORT_EXPECT_TRUE(!ec)) { + if (TRANSPORT_EXPECT_TRUE(state_ == State::CONNECTED)) { + auto packet = getPacketFromBuffer(read_msg_.first, length); + receiveSuccess(*packet); + receive_callback_(this, *packet, + make_error_code(forwarder_error::success)); + doRecvPacket(); + } else { + TRANSPORT_LOGE( + "Error in UDP: Receiving packets from a not connected socket."); + } + } else if (ec.value() == + static_cast<int>(std::errc::operation_canceled)) { + TRANSPORT_LOGE("The connection has been closed by the application."); + return; + } else { + if (TRANSPORT_EXPECT_TRUE(state_ == State::CONNECTED)) { + TRANSPORT_LOGE("Error in UDP connector: %d %s", ec.value(), + ec.message().c_str()); + } else { + TRANSPORT_LOGE("Error while not connector"); + } + } + }); +#endif +} + +void UdpTunnelConnector::doConnect() { + asio::async_connect( + *socket_, endpoint_iterator_, + [this](std::error_code ec, asio::ip::udp::resolver::iterator) { + if (!ec) { + state_ = State::CONNECTED; + doRecvPacket(); + + if (data_available_) { + data_available_ = false; + doSendPacket(); + } + } else { + TRANSPORT_LOGE("[Hproxy] - UDP Connection failed!!!"); + timer_.expires_from_now(std::chrono::milliseconds(500)); + timer_.async_wait(std::bind(&UdpTunnelConnector::doConnect, this)); + } + }); +} + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/udp_tunnel.h b/libtransport/src/io_modules/forwarder/udp_tunnel.h new file mode 100644 index 000000000..df472af91 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/udp_tunnel.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + */ + +#pragma once + +#include <hicn/transport/core/connector.h> +#include <hicn/transport/portability/platform.h> +#include <io_modules/forwarder/errors.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <iostream> +#include <memory> + +namespace transport { +namespace core { + +class UdpTunnelListener; + +class UdpTunnelConnector : public Connector { + friend class UdpTunnelListener; + + public: + template <typename ReceiveCallback, typename SentCallback, typename OnClose, + typename OnReconnect> + UdpTunnelConnector(asio::io_service &io_service, + ReceiveCallback &&receive_callback, + SentCallback &&packet_sent, OnClose &&on_close_callback, + OnReconnect &&on_reconnect) + : Connector(receive_callback, packet_sent, on_close_callback, + on_reconnect), + io_service_(io_service), + strand_(std::make_shared<asio::io_service::strand>(io_service_)), + socket_(std::make_shared<asio::ip::udp::socket>(io_service_)), + resolver_(io_service_), + timer_(io_service_), +#ifdef LINUX + send_timer_(io_service_), + tx_iovecs_{0}, + tx_msgs_{0}, + rx_iovecs_{0}, + rx_msgs_{0}, + current_position_(0), +#else + read_msg_(nullptr, 0), +#endif + data_available_(false) { + } + + template <typename ReceiveCallback, typename SentCallback, typename OnClose, + typename OnReconnect, typename EndpointType> + UdpTunnelConnector(std::shared_ptr<asio::ip::udp::socket> &socket, + std::shared_ptr<asio::io_service::strand> &strand, + ReceiveCallback &&receive_callback, + SentCallback &&packet_sent, OnClose &&on_close_callback, + OnReconnect &&on_reconnect, EndpointType &&remote_endpoint) + : Connector(receive_callback, packet_sent, on_close_callback, + on_reconnect), +#if ((ASIO_VERSION / 100 % 1000) < 12) + io_service_(socket->get_io_service()), +#else + io_service_((asio::io_context &)(socket->get_executor().context())), +#endif + strand_(strand), + socket_(socket), + resolver_(io_service_), + remote_endpoint_send_(std::forward<EndpointType &&>(remote_endpoint)), + timer_(io_service_), +#ifdef LINUX + send_timer_(io_service_), + tx_iovecs_{0}, + tx_msgs_{0}, + rx_iovecs_{0}, + rx_msgs_{0}, + current_position_(0), +#else + read_msg_(nullptr, 0), +#endif + data_available_(false) { + if (socket_->is_open()) { + state_ = State::CONNECTED; + remote_endpoint_ = Endpoint(remote_endpoint_send_); + local_endpoint_ = socket_->local_endpoint(); + } + } + + ~UdpTunnelConnector() override; + + void send(Packet &packet) override; + + void send(const uint8_t *packet, std::size_t len) override; + + void close() override; + + void connect(const std::string &hostname, std::uint16_t port, + const std::string &bind_address = "", + std::uint16_t bind_port = 0); + + auto shared_from_this() { return utils::shared_from(this); } + + private: + void doConnect(); + void doRecvPacket(); + + void doRecvPacket(utils::MemBuf &buffer) { + receive_callback_(this, buffer, make_error_code(forwarder_error::success)); + } + +#ifdef LINUX + void readHandler(std::error_code ec); + void writeHandler(std::error_code ec); +#endif + + void setConnected() { state_ = State::CONNECTED; } + + void doSendPacket(); + void doClose(); + + private: + asio::io_service &io_service_; + std::shared_ptr<asio::io_service::strand> strand_; + std::shared_ptr<asio::ip::udp::socket> socket_; + asio::ip::udp::resolver resolver_; + asio::ip::udp::resolver::iterator endpoint_iterator_; + asio::ip::udp::endpoint remote_endpoint_send_; + asio::ip::udp::endpoint remote_endpoint_recv_; + + asio::steady_timer timer_; + +#ifdef LINUX + asio::steady_timer send_timer_; + struct iovec tx_iovecs_[max_burst][8]; + struct mmsghdr tx_msgs_[max_burst]; + struct iovec rx_iovecs_[max_burst][8]; + struct mmsghdr rx_msgs_[max_burst]; + std::uint8_t current_position_; +#else + std::pair<uint8_t *, std::size_t> read_msg_; +#endif + + bool data_available_; +}; + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/udp_tunnel_listener.cc b/libtransport/src/io_modules/forwarder/udp_tunnel_listener.cc new file mode 100644 index 000000000..12246c3cf --- /dev/null +++ b/libtransport/src/io_modules/forwarder/udp_tunnel_listener.cc @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + */ + +#include <hicn/transport/utils/hash.h> +#include <hicn/transport/utils/log.h> +#include <io_modules/forwarder/udp_tunnel.h> +#include <io_modules/forwarder/udp_tunnel_listener.h> + +#ifndef LINUX +namespace std { +size_t hash<asio::ip::udp::endpoint>::operator()( + const asio::ip::udp::endpoint &endpoint) const { + auto hash_ip = endpoint.address().is_v4() + ? endpoint.address().to_v4().to_ulong() + : utils::hash::fnv32_buf( + endpoint.address().to_v6().to_bytes().data(), 16); + uint16_t port = endpoint.port(); + return utils::hash::fnv32_buf(&port, 2, hash_ip); +} +} // namespace std +#endif + +namespace transport { +namespace core { + +UdpTunnelListener::~UdpTunnelListener() {} + +void UdpTunnelListener::close() { + strand_->post([this]() { + if (socket_->is_open()) { + socket_->close(); + } + }); +} + +#ifdef LINUX +void UdpTunnelListener::readHandler(std::error_code ec) { + TRANSPORT_LOGD("UdpTunnelConnector receive packet"); + + // TRANSPORT_LOGD("UdpTunnelConnector received packet length=%lu", length); + if (TRANSPORT_EXPECT_TRUE(!ec)) { + if (current_position_ == 0) { + for (int i = 0; i < Connector::max_burst; i++) { + auto read_buffer = Connector::getRawBuffer(); + iovecs_[i][0].iov_base = read_buffer.first; + iovecs_[i][0].iov_len = read_buffer.second; + msgs_[i].msg_hdr.msg_iov = iovecs_[i]; + msgs_[i].msg_hdr.msg_iovlen = 1; + msgs_[i].msg_hdr.msg_name = &remote_endpoints_[i]; + msgs_[i].msg_hdr.msg_namelen = sizeof(remote_endpoints_[i]); + } + } + + int res = recvmmsg(socket_->native_handle(), msgs_ + current_position_, + Connector::max_burst - current_position_, MSG_DONTWAIT, + nullptr); + if (res < 0) { + TRANSPORT_LOGE("Error in recvmmsg."); + } + + for (int i = 0; i < res; i++) { + auto packet = Connector::getPacketFromBuffer( + reinterpret_cast<uint8_t *>( + msgs_[current_position_].msg_hdr.msg_iov[0].iov_base), + msgs_[current_position_].msg_len); + auto connector_id = + utils::hash::fnv64_buf(msgs_[current_position_].msg_hdr.msg_name, + msgs_[current_position_].msg_hdr.msg_namelen); + + auto connector = connectors_.find(connector_id); + if (connector == connectors_.end()) { + // Create new connector corresponding to new client + + /* + * Get the remote endpoint for this particular message + */ + using namespace asio::ip; + if (local_endpoint_.address().is_v4()) { + auto addr = reinterpret_cast<struct sockaddr_in *>( + &remote_endpoints_[current_position_]); + address_v4::bytes_type address_bytes; + std::copy_n(reinterpret_cast<uint8_t *>(&addr->sin_addr), + address_bytes.size(), address_bytes.begin()); + address_v4 address(address_bytes); + remote_endpoint_ = udp::endpoint(address, ntohs(addr->sin_port)); + } else { + auto addr = reinterpret_cast<struct sockaddr_in6 *>( + &remote_endpoints_[current_position_]); + address_v6::bytes_type address_bytes; + std::copy_n(reinterpret_cast<uint8_t *>(&addr->sin6_addr), + address_bytes.size(), address_bytes.begin()); + address_v6 address(address_bytes); + remote_endpoint_ = udp::endpoint(address, ntohs(addr->sin6_port)); + } + + /** + * Create new connector sharing the same socket of this listener. + */ + auto ret = connectors_.emplace( + connector_id, + std::make_shared<UdpTunnelConnector>( + socket_, strand_, receive_callback_, + [](Connector *, const std::error_code &) {}, [](Connector *) {}, + [](Connector *) {}, std::move(remote_endpoint_))); + connector = ret.first; + connector->second->setConnectorId(connector_id); + } + + /** + * Use connector callback to process incoming message. + */ + UdpTunnelConnector *c = + dynamic_cast<UdpTunnelConnector *>(connector->second.get()); + c->doRecvPacket(*packet); + + ++current_position_; + } + + doRecvPacket(); + } else if (ec.value() == static_cast<int>(std::errc::operation_canceled)) { + TRANSPORT_LOGE("The connection has been closed by the application."); + return; + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + } +} +#endif + +void UdpTunnelListener::doRecvPacket() { +#ifdef LINUX +#if ((ASIO_VERSION / 100 % 1000) < 11) + socket_->async_receive( + asio::null_buffers(), +#else + socket_->async_wait( + asio::ip::tcp::socket::wait_read, +#endif + std::bind(&UdpTunnelListener::readHandler, this, std::placeholders::_1)); +#else + read_msg_ = Connector::getRawBuffer(); + socket_->async_receive_from( + asio::buffer(read_msg_.first, read_msg_.second), remote_endpoint_, + [this](std::error_code ec, std::size_t length) { + if (TRANSPORT_EXPECT_TRUE(!ec)) { + auto packet = Connector::getPacketFromBuffer(read_msg_.first, length); + auto connector_id = + std::hash<asio::ip::udp::endpoint>{}(remote_endpoint_); + auto connector = connectors_.find(connector_id); + if (connector == connectors_.end()) { + // Create new connector corresponding to new client + auto ret = connectors_.emplace( + connector_id, std::make_shared<UdpTunnelConnector>( + socket_, strand_, receive_callback_, + [](Connector *, const std::error_code &) {}, + [](Connector *) {}, [](Connector *) {}, + std::move(remote_endpoint_))); + connector = ret.first; + connector->second->setConnectorId(connector_id); + } + + UdpTunnelConnector *c = + dynamic_cast<UdpTunnelConnector *>(connector->second.get()); + c->doRecvPacket(*packet); + doRecvPacket(); + } else if (ec.value() == + static_cast<int>(std::errc::operation_canceled)) { + TRANSPORT_LOGE("The connection has been closed by the application."); + return; + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + } + }); +#endif +} +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/udp_tunnel_listener.h b/libtransport/src/io_modules/forwarder/udp_tunnel_listener.h new file mode 100644 index 000000000..0ee40a400 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/udp_tunnel_listener.h @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + */ + +#pragma once + +#include <hicn/transport/core/connector.h> +#include <hicn/transport/portability/platform.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <unordered_map> + +namespace std { +template <> +struct hash<asio::ip::udp::endpoint> { + size_t operator()(const asio::ip::udp::endpoint &endpoint) const; +}; +} // namespace std + +namespace transport { +namespace core { + +class UdpTunnelListener + : public std::enable_shared_from_this<UdpTunnelListener> { + using PacketReceivedCallback = Connector::PacketReceivedCallback; + using EndpointId = std::pair<uint32_t, uint16_t>; + + static constexpr uint16_t default_port = 5004; + + public: + using Ptr = std::shared_ptr<UdpTunnelListener>; + + template <typename ReceiveCallback> + UdpTunnelListener(asio::io_service &io_service, + ReceiveCallback &&receive_callback, + asio::ip::udp::endpoint endpoint = asio::ip::udp::endpoint( + asio::ip::udp::v4(), default_port)) + : io_service_(io_service), + strand_(std::make_shared<asio::io_service::strand>(io_service_)), + socket_(std::make_shared<asio::ip::udp::socket>(io_service_, + endpoint.protocol())), + local_endpoint_(endpoint), + receive_callback_(std::forward<ReceiveCallback &&>(receive_callback)), +#ifndef LINUX + read_msg_(nullptr, 0) +#else + iovecs_{0}, + msgs_{0}, + current_position_(0) +#endif + { + if (endpoint.protocol() == asio::ip::udp::v6()) { + std::error_code ec; + socket_->set_option(asio::ip::v6_only(false), ec); + // Call succeeds only on dual stack systems. + } + socket_->bind(local_endpoint_); + io_service_.post(std::bind(&UdpTunnelListener::doRecvPacket, this)); + } + + ~UdpTunnelListener(); + + void close(); + + int deleteConnector(Connector *connector) { + return connectors_.erase(connector->getConnectorId()); + } + + template <typename ReceiveCallback> + void setReceiveCallback(ReceiveCallback &&callback) { + receive_callback_ = std::forward<ReceiveCallback &&>(callback); + } + + Connector *findConnector(Connector::Id connId) { + auto it = connectors_.find(connId); + if (it != connectors_.end()) { + return it->second.get(); + } + + return nullptr; + } + + private: + void doRecvPacket(); + + void readHandler(std::error_code ec); + + asio::io_service &io_service_; + std::shared_ptr<asio::io_service::strand> strand_; + std::shared_ptr<asio::ip::udp::socket> socket_; + asio::ip::udp::endpoint local_endpoint_; + asio::ip::udp::endpoint remote_endpoint_; + std::unordered_map<Connector::Id, std::shared_ptr<Connector>> connectors_; + + PacketReceivedCallback receive_callback_; + +#ifdef LINUX + struct iovec iovecs_[Connector::max_burst][8]; + struct mmsghdr msgs_[Connector::max_burst]; + struct sockaddr_storage remote_endpoints_[Connector::max_burst]; + std::uint8_t current_position_; +#else + std::pair<uint8_t *, std::size_t> read_msg_; +#endif +}; + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/loopback/CMakeLists.txt b/libtransport/src/io_modules/loopback/CMakeLists.txt new file mode 100644 index 000000000..ac6dc8068 --- /dev/null +++ b/libtransport/src/io_modules/loopback/CMakeLists.txt @@ -0,0 +1,34 @@ +# Copyright (c) 2021 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + + +list(APPEND MODULE_HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/loopback_module.h +) + +list(APPEND MODULE_SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/loopback_module.cc +) + +build_module(loopback_module + SHARED + SOURCES ${MODULE_SOURCE_FILES} + DEPENDS ${DEPENDENCIES} + COMPONENT lib${LIBTRANSPORT} + INCLUDE_DIRS ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} + # LIBRARY_ROOT_DIR "vpp_plugins" + DEFINITIONS ${COMPILER_DEFINITIONS} + COMPILE_OPTIONS ${COMPILE_FLAGS} +) diff --git a/libtransport/src/io_modules/loopback/local_face.cc b/libtransport/src/io_modules/loopback/local_face.cc new file mode 100644 index 000000000..a59dab235 --- /dev/null +++ b/libtransport/src/io_modules/loopback/local_face.cc @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/interest.h> +#include <hicn/transport/utils/log.h> +#include <io_modules/loopback/local_face.h> + +#include <asio/io_service.hpp> + +namespace transport { +namespace core { + +Face::Face(Connector::PacketReceivedCallback &&receive_callback, + asio::io_service &io_service, const std::string &app_name) + : receive_callback_(std::move(receive_callback)), + io_service_(io_service), + name_(app_name) {} + +Face::Face(const Face &other) + : receive_callback_(other.receive_callback_), + io_service_(other.io_service_), + name_(other.name_) {} + +Face::Face(Face &&other) + : receive_callback_(std::move(other.receive_callback_)), + io_service_(other.io_service_), + name_(std::move(other.name_)) {} + +Face &Face::operator=(const Face &other) { + receive_callback_ = other.receive_callback_; + io_service_ = other.io_service_; + name_ = other.name_; + + return *this; +} + +Face &Face::operator=(Face &&other) { + receive_callback_ = std::move(other.receive_callback_); + io_service_ = std::move(other.io_service_); + name_ = std::move(other.name_); + + return *this; +} + +void Face::onPacket(const Packet &packet) { + TRANSPORT_LOGD("Sending content to local socket."); + + if (Packet::isInterest(packet.data())) { + rescheduleOnIoService<Interest>(packet); + } else { + rescheduleOnIoService<ContentObject>(packet); + } +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/loopback/local_face.h b/libtransport/src/io_modules/loopback/local_face.h new file mode 100644 index 000000000..1cbcc2c72 --- /dev/null +++ b/libtransport/src/io_modules/loopback/local_face.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/connector.h> +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/utils/move_wrapper.h> + +#include <asio/io_service.hpp> + +namespace transport { +namespace core { + +class Face { + public: + Face(Connector::PacketReceivedCallback &&receive_callback, + asio::io_service &io_service, const std::string &app_name); + + Face(const Face &other); + Face(Face &&other); + void onPacket(const Packet &packet); + Face &operator=(Face &&other); + Face &operator=(const Face &other); + + private: + template <typename T> + void rescheduleOnIoService(const Packet &packet) { + auto p = core::PacketManager<T>::getInstance().getPacket(); + p->replace(packet.data(), packet.length()); + io_service_.get().post([this, p]() mutable { + receive_callback_(nullptr, *p, make_error_code(0)); + }); + } + + Connector::PacketReceivedCallback receive_callback_; + std::reference_wrapper<asio::io_service> io_service_; + std::string name_; +}; + +} // namespace core +} // namespace transport diff --git a/libtransport/src/io_modules/loopback/loopback_module.cc b/libtransport/src/io_modules/loopback/loopback_module.cc new file mode 100644 index 000000000..0bdbf8c8e --- /dev/null +++ b/libtransport/src/io_modules/loopback/loopback_module.cc @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/errors/not_implemented_exception.h> +#include <hicn/transport/utils/log.h> +#include <io_modules/loopback/loopback_module.h> + +namespace transport { + +namespace core { + +std::vector<std::unique_ptr<LocalConnector>> LoopbackModule::local_faces_; +std::atomic<uint32_t> LoopbackModule::global_counter_(0); + +LoopbackModule::LoopbackModule() : IoModule(), local_id_(~0) {} + +LoopbackModule::~LoopbackModule() {} + +void LoopbackModule::connect(bool is_consumer) {} + +bool LoopbackModule::isConnected() { return true; } + +void LoopbackModule::send(Packet &packet) { + IoModule::send(packet); + + TRANSPORT_LOGD("LoopbackModule: sending from %u to %d", local_id_, + 1 - local_id_); + + local_faces_.at(1 - local_id_)->send(packet); +} + +void LoopbackModule::send(const uint8_t *packet, std::size_t len) { + // not supported + throw errors::NotImplementedException(); +} + +void LoopbackModule::registerRoute(const Prefix &prefix) { + // For the moment we route packets from one socket to the other. + // Next step will be to introduce a FIB + return; +} + +void LoopbackModule::closeConnection() { + local_faces_.erase(local_faces_.begin() + local_id_); +} + +void LoopbackModule::init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name) { + if (local_id_ == uint32_t(~0) && global_counter_ < 2) { + local_id_ = global_counter_++; + local_faces_.emplace( + local_faces_.begin() + local_id_, + new LocalConnector(io_service, std::move(receive_callback), nullptr, + nullptr, std::move(reconnect_callback))); + } +} + +void LoopbackModule::processControlMessageReply(utils::MemBuf &packet_buffer) { + return; +} + +std::uint32_t LoopbackModule::getMtu() { return interface_mtu; } + +bool LoopbackModule::isControlMessage(const uint8_t *message) { return false; } + +extern "C" IoModule *create_module(void) { return new LoopbackModule(); } + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/loopback/loopback_module.h b/libtransport/src/io_modules/loopback/loopback_module.h new file mode 100644 index 000000000..219fa8841 --- /dev/null +++ b/libtransport/src/io_modules/loopback/loopback_module.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <core/local_connector.h> +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/core/prefix.h> + +#include <atomic> + +namespace transport { + +namespace core { + +class LoopbackModule : public IoModule { + static constexpr std::uint16_t interface_mtu = 1500; + + public: + LoopbackModule(); + + ~LoopbackModule(); + + void connect(bool is_consumer) override; + + void send(Packet &packet) override; + void send(const uint8_t *packet, std::size_t len) override; + + bool isConnected() override; + + void init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name = "Libtransport") override; + + void registerRoute(const Prefix &prefix) override; + + std::uint32_t getMtu() override; + + bool isControlMessage(const uint8_t *message) override; + + void processControlMessageReply(utils::MemBuf &packet_buffer) override; + + void closeConnection() override; + + private: + static std::vector<std::unique_ptr<LocalConnector>> local_faces_; + static std::atomic<uint32_t> global_counter_; + + private: + uint32_t local_id_; +}; + +extern "C" IoModule *create_module(void); + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/memif/CMakeLists.txt b/libtransport/src/io_modules/memif/CMakeLists.txt new file mode 100644 index 000000000..c8a930e7b --- /dev/null +++ b/libtransport/src/io_modules/memif/CMakeLists.txt @@ -0,0 +1,56 @@ +# Copyright (c) 2021 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + +find_package(Vpp REQUIRED) +find_package(Libmemif REQUIRED) + +if(CMAKE_SOURCE_DIR STREQUAL PROJECT_SOURCE_DIR) + find_package(HicnPlugin REQUIRED) + find_package(SafeVapi REQUIRED) +else() + list(APPEND DEPENDENCIES + ${SAFE_VAPI_SHARED} + ) +endif() + +list(APPEND MODULE_HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/hicn_vapi.h + ${CMAKE_CURRENT_SOURCE_DIR}/memif_connector.h + ${CMAKE_CURRENT_SOURCE_DIR}/memif_vapi.h + ${CMAKE_CURRENT_SOURCE_DIR}/vpp_forwarder_module.h +) + +list(APPEND MODULE_SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/hicn_vapi.c + ${CMAKE_CURRENT_SOURCE_DIR}/memif_connector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/memif_vapi.c + ${CMAKE_CURRENT_SOURCE_DIR}/vpp_forwarder_module.cc +) + +build_module(memif_module + SHARED + SOURCES ${MODULE_SOURCE_FILES} + DEPENDS ${DEPENDENCIES} + COMPONENT lib${LIBTRANSPORT} + LINK_LIBRARIES ${LIBMEMIF_LIBRARIES} ${SAFE_VAPI_LIBRARIES} + INCLUDE_DIRS + ${LIBTRANSPORT_INCLUDE_DIRS} + ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} + ${VPP_INCLUDE_DIRS} + ${LIBMEMIF_INCLUDE_DIRS} + ${SAFE_VAPI_INCLUDE_DIRS} + DEFINITIONS ${COMPILER_DEFINITIONS} + COMPILE_OPTIONS ${COMPILE_FLAGS} +) diff --git a/libtransport/src/io_modules/memif/hicn_vapi.c b/libtransport/src/io_modules/memif/hicn_vapi.c new file mode 100644 index 000000000..b83a36b47 --- /dev/null +++ b/libtransport/src/io_modules/memif/hicn_vapi.c @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/config.h> +#include <hicn/transport/utils/log.h> +#include <io_modules/memif/hicn_vapi.h> + +#define HICN_VPP_PLUGIN +#include <hicn/name.h> +#undef HICN_VPP_PLUGIN + +#include <vapi/hicn.api.vapi.h> +#include <vapi/ip.api.vapi.h> +#include <vapi/vapi_safe.h> +#include <vlib/vlib.h> +#include <vlibapi/api.h> +#include <vlibmemory/api.h> +#include <vnet/ip/format.h> +#include <vnet/ip/ip4_packet.h> +#include <vnet/ip/ip6_packet.h> +#include <vpp_plugins/hicn/error.h> +#include <vppinfra/error.h> + +///////////////////////////////////////////////////// +const char *HICN_ERROR_STRING[] = { +#define _(a, b, c) c, + foreach_hicn_error +#undef _ +}; +///////////////////////////////////////////////////// + +/*********************** Missing Symbol in vpp libraries + * *************************/ +u8 *format_vl_api_address_union(u8 *s, va_list *args) { return NULL; } + +/*********************************************************************************/ + +DEFINE_VAPI_MSG_IDS_HICN_API_JSON +DEFINE_VAPI_MSG_IDS_IP_API_JSON + +static vapi_error_e register_prod_app_cb( + vapi_ctx_t ctx, void *callback_ctx, vapi_error_e rv, bool is_last, + vapi_payload_hicn_api_register_prod_app_reply *reply) { + hicn_producer_output_params *output_params = + (hicn_producer_output_params *)callback_ctx; + + if (reply == NULL) return rv; + + output_params->cs_reserved = reply->cs_reserved; + output_params->prod_addr = (ip_address_t *)malloc(sizeof(ip_address_t)); + memset(output_params->prod_addr, 0, sizeof(ip_address_t)); + if (reply->prod_addr.af == ADDRESS_IP6) + memcpy(&output_params->prod_addr->v6, reply->prod_addr.un.ip6, + sizeof(ip6_address_t)); + else + memcpy(&output_params->prod_addr->v4, reply->prod_addr.un.ip4, + sizeof(ip4_address_t)); + output_params->face_id = reply->faceid; + + return reply->retval; +} + +int hicn_vapi_register_prod_app(vapi_ctx_t ctx, + hicn_producer_input_params *input_params, + hicn_producer_output_params *output_params) { + vapi_lock(); + vapi_msg_hicn_api_register_prod_app *msg = + vapi_alloc_hicn_api_register_prod_app(ctx); + + if (ip46_address_is_ip4((ip46_address_t *)&input_params->prefix->address)) { + memcpy(&msg->payload.prefix.address.un.ip4, &input_params->prefix->address, + sizeof(ip4_address_t)); + msg->payload.prefix.address.af = ADDRESS_IP4; + } else { + memcpy(&msg->payload.prefix.address.un.ip6, &input_params->prefix->address, + sizeof(ip6_address_t)); + msg->payload.prefix.address.af = ADDRESS_IP6; + } + msg->payload.prefix.len = input_params->prefix->len; + + msg->payload.swif = input_params->swif; + msg->payload.cs_reserved = input_params->cs_reserved; + + int ret = vapi_hicn_api_register_prod_app(ctx, msg, register_prod_app_cb, + output_params); + vapi_unlock(); + return ret; +} + +static vapi_error_e face_prod_del_cb( + vapi_ctx_t ctx, void *callback_ctx, vapi_error_e rv, bool is_last, + vapi_payload_hicn_api_face_prod_del_reply *reply) { + if (reply == NULL) return rv; + + return reply->retval; +} + +int hicn_vapi_face_prod_del(vapi_ctx_t ctx, + hicn_del_face_app_input_params *input_params) { + vapi_lock(); + vapi_msg_hicn_api_face_prod_del *msg = vapi_alloc_hicn_api_face_prod_del(ctx); + + msg->payload.faceid = input_params->face_id; + + int ret = vapi_hicn_api_face_prod_del(ctx, msg, face_prod_del_cb, NULL); + vapi_unlock(); + return ret; +} + +static vapi_error_e register_cons_app_cb( + vapi_ctx_t ctx, void *callback_ctx, vapi_error_e rv, bool is_last, + vapi_payload_hicn_api_register_cons_app_reply *reply) { + hicn_consumer_output_params *output_params = + (hicn_consumer_output_params *)callback_ctx; + + if (reply == NULL) return rv; + + output_params->src6 = (ip_address_t *)malloc(sizeof(ip_address_t)); + output_params->src4 = (ip_address_t *)malloc(sizeof(ip_address_t)); + memset(output_params->src6, 0, sizeof(ip_address_t)); + memset(output_params->src4, 0, sizeof(ip_address_t)); + memcpy(&output_params->src6->v6, &reply->src_addr6.un.ip6, + sizeof(ip6_address_t)); + memcpy(&output_params->src4->v4, &reply->src_addr4.un.ip4, + sizeof(ip4_address_t)); + + output_params->face_id1 = reply->faceid1; + output_params->face_id2 = reply->faceid2; + + return reply->retval; +} + +int hicn_vapi_register_cons_app(vapi_ctx_t ctx, + hicn_consumer_input_params *input_params, + hicn_consumer_output_params *output_params) { + vapi_lock(); + vapi_msg_hicn_api_register_cons_app *msg = + vapi_alloc_hicn_api_register_cons_app(ctx); + + msg->payload.swif = input_params->swif; + + int ret = vapi_hicn_api_register_cons_app(ctx, msg, register_cons_app_cb, + output_params); + vapi_unlock(); + return ret; +} + +static vapi_error_e face_cons_del_cb( + vapi_ctx_t ctx, void *callback_ctx, vapi_error_e rv, bool is_last, + vapi_payload_hicn_api_face_cons_del_reply *reply) { + if (reply == NULL) return rv; + + return reply->retval; +} + +int hicn_vapi_face_cons_del(vapi_ctx_t ctx, + hicn_del_face_app_input_params *input_params) { + vapi_lock(); + vapi_msg_hicn_api_face_cons_del *msg = vapi_alloc_hicn_api_face_cons_del(ctx); + + msg->payload.faceid = input_params->face_id; + + int ret = vapi_hicn_api_face_cons_del(ctx, msg, face_cons_del_cb, NULL); + vapi_unlock(); + return ret; +} + +static vapi_error_e reigster_route_cb( + vapi_ctx_t ctx, void *callback_ctx, vapi_error_e rv, bool is_last, + vapi_payload_ip_route_add_del_reply *reply) { + if (reply == NULL) return rv; + + return reply->retval; +} + +int hicn_vapi_register_route(vapi_ctx_t ctx, + hicn_producer_set_route_params *input_params) { + vapi_lock(); + vapi_msg_ip_route_add_del *msg = vapi_alloc_ip_route_add_del(ctx, 1); + + msg->payload.is_add = 1; + if (ip46_address_is_ip4((ip46_address_t *)(input_params->prod_addr))) { + memcpy(&msg->payload.route.prefix.address.un.ip4, + &input_params->prefix->address.v4, sizeof(ip4_address_t)); + msg->payload.route.prefix.address.af = ADDRESS_IP4; + msg->payload.route.prefix.len = input_params->prefix->len; + } else { + memcpy(&msg->payload.route.prefix.address.un.ip6, + &input_params->prefix->address.v6, sizeof(ip6_address_t)); + msg->payload.route.prefix.address.af = ADDRESS_IP6; + msg->payload.route.prefix.len = input_params->prefix->len; + } + + msg->payload.route.paths[0].sw_if_index = ~0; + msg->payload.route.paths[0].table_id = 0; + if (ip46_address_is_ip4((ip46_address_t *)(input_params->prod_addr))) { + memcpy(&(msg->payload.route.paths[0].nh.address.ip4), + input_params->prod_addr->v4.as_u8, sizeof(ip4_address_t)); + msg->payload.route.paths[0].proto = FIB_API_PATH_NH_PROTO_IP4; + } else { + memcpy(&(msg->payload.route.paths[0].nh.address.ip6), + input_params->prod_addr->v6.as_u8, sizeof(ip6_address_t)); + msg->payload.route.paths[0].proto = FIB_API_PATH_NH_PROTO_IP6; + } + + msg->payload.route.paths[0].type = FIB_API_PATH_FLAG_NONE; + msg->payload.route.paths[0].flags = FIB_API_PATH_FLAG_NONE; + + int ret = vapi_ip_route_add_del(ctx, msg, reigster_route_cb, NULL); + + vapi_unlock(); + return ret; +} + +char *hicn_vapi_get_error_string(int ret_val) { + return get_error_string(ret_val); +} diff --git a/libtransport/src/io_modules/memif/hicn_vapi.h b/libtransport/src/io_modules/memif/hicn_vapi.h new file mode 100644 index 000000000..e94c97749 --- /dev/null +++ b/libtransport/src/io_modules/memif/hicn_vapi.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/config.h> +#include <hicn/util/ip_address.h> + +#ifdef __cplusplus +extern "C" { +#endif + +#include <vapi/vapi.h> + +#include "stdint.h" + +typedef struct { + ip_prefix_t* prefix; + uint32_t swif; + uint32_t cs_reserved; +} hicn_producer_input_params; + +typedef struct { + uint32_t swif; +} hicn_consumer_input_params; + +typedef struct { + uint32_t face_id; +} hicn_del_face_app_input_params; + +typedef struct { + uint32_t cs_reserved; + ip_address_t* prod_addr; + uint32_t face_id; +} hicn_producer_output_params; + +typedef struct { + ip_address_t* src4; + ip_address_t* src6; + uint32_t face_id1; + uint32_t face_id2; +} hicn_consumer_output_params; + +typedef struct { + ip_prefix_t* prefix; + ip_address_t* prod_addr; +} hicn_producer_set_route_params; + +int hicn_vapi_register_prod_app(vapi_ctx_t ctx, + hicn_producer_input_params* input_params, + hicn_producer_output_params* output_params); + +int hicn_vapi_register_cons_app(vapi_ctx_t ctx, + hicn_consumer_input_params* input_params, + hicn_consumer_output_params* output_params); + +int hicn_vapi_register_route(vapi_ctx_t ctx, + hicn_producer_set_route_params* input_params); + +int hicn_vapi_face_cons_del(vapi_ctx_t ctx, + hicn_del_face_app_input_params* input_params); + +int hicn_vapi_face_prod_del(vapi_ctx_t ctx, + hicn_del_face_app_input_params* input_params); + +char* hicn_vapi_get_error_string(int ret_val); + +#ifdef __cplusplus +} +#endif diff --git a/libtransport/src/io_modules/memif/memif_connector.cc b/libtransport/src/io_modules/memif/memif_connector.cc new file mode 100644 index 000000000..4a688d68f --- /dev/null +++ b/libtransport/src/io_modules/memif/memif_connector.cc @@ -0,0 +1,493 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/errors/not_implemented_exception.h> +#include <io_modules/memif/memif_connector.h> +#include <sys/epoll.h> + +#include <cstdlib> + +extern "C" { +#include <memif/libmemif.h> +}; + +#define CANCEL_TIMER 1 + +namespace transport { + +namespace core { + +struct memif_connection { + uint16_t index; + /* memif conenction handle */ + memif_conn_handle_t conn; + /* transmit queue id */ + uint16_t tx_qid; + /* tx buffers */ + memif_buffer_t *tx_bufs; + /* allocated tx buffers counter */ + /* number of tx buffers pointing to shared memory */ + uint16_t tx_buf_num; + /* rx buffers */ + memif_buffer_t *rx_bufs; + /* allcoated rx buffers counter */ + /* number of rx buffers pointing to shared memory */ + uint16_t rx_buf_num; + /* interface ip address */ + uint8_t ip_addr[4]; +}; + +std::once_flag MemifConnector::flag_; +utils::EpollEventReactor MemifConnector::main_event_reactor_; + +MemifConnector::MemifConnector(PacketReceivedCallback &&receive_callback, + PacketSentCallback &&packet_sent, + OnCloseCallback &&close_callback, + OnReconnectCallback &&on_reconnect, + asio::io_service &io_service, + std::string app_name) + : Connector(std::move(receive_callback), std::move(packet_sent), + std::move(close_callback), std::move(on_reconnect)), + memif_worker_(nullptr), + timer_set_(false), + send_timer_(std::make_unique<utils::FdDeadlineTimer>(event_reactor_)), + disconnect_timer_( + std::make_unique<utils::FdDeadlineTimer>(event_reactor_)), + io_service_(io_service), + memif_connection_(std::make_unique<memif_connection_t>()), + tx_buf_counter_(0), + is_reconnection_(false), + data_available_(false), + app_name_(app_name), + socket_filename_("") { + std::call_once(MemifConnector::flag_, &MemifConnector::init, this); +} + +MemifConnector::~MemifConnector() { close(); } + +void MemifConnector::init() { + /* initialize memory interface */ + int err = memif_init(controlFdUpdate, const_cast<char *>(app_name_.c_str()), + nullptr, nullptr, nullptr); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_init: %s", memif_strerror(err)); + } +} + +void MemifConnector::connect(uint32_t memif_id, long memif_mode) { + state_ = State::CONNECTING; + + memif_id_ = memif_id; + socket_filename_ = "/run/vpp/memif.sock"; + + createMemif(memif_id, memif_mode, nullptr); + + work_ = std::make_unique<asio::io_service::work>(io_service_); + + while (state_ != State::CONNECTED) { + MemifConnector::main_event_reactor_.runOneEvent(); + } + + int err; + + /* get interrupt queue id */ + int fd = -1; + err = memif_get_queue_efd(memif_connection_->conn, 0, &fd); + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_get_queue_efd: %s", memif_strerror(err)); + return; + } + + // Remove fd from main epoll + main_event_reactor_.delFileDescriptor(fd); + + // Add fd to epoll of instance + event_reactor_.addFileDescriptor( + fd, EPOLLIN, [this](const utils::Event &evt) -> int { + return onInterrupt(memif_connection_->conn, this, 0); + }); + + memif_worker_ = std::make_unique<std::thread>( + std::bind(&MemifConnector::threadMain, this)); +} + +int MemifConnector::createMemif(uint32_t index, uint8_t mode, char *s) { + memif_connection_t *c = memif_connection_.get(); + + /* setting memif connection arguments */ + memif_conn_args_t args; + memset(&args, 0, sizeof(args)); + + args.is_master = mode; + args.log2_ring_size = MEMIF_LOG2_RING_SIZE; + args.buffer_size = MEMIF_BUF_SIZE; + args.num_s2m_rings = 1; + args.num_m2s_rings = 1; + strncpy((char *)args.interface_name, IF_NAME, strlen(IF_NAME) + 1); + args.mode = memif_interface_mode_t::MEMIF_INTERFACE_MODE_IP; + + int err; + + err = memif_create_socket(&args.socket, socket_filename_.c_str(), nullptr); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + throw errors::RuntimeException(memif_strerror(err)); + } + + args.interface_id = index; + /* last argument for memif_create (void * private_ctx) is used by user + to identify connection. this context is returned with callbacks */ + + /* default interrupt */ + if (s == nullptr) { + err = memif_create(&c->conn, &args, onConnect, onDisconnect, onInterrupt, + this); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + throw errors::RuntimeException(memif_strerror(err)); + } + } + + c->index = (uint16_t)index; + c->tx_qid = 0; + /* alloc memif buffers */ + c->rx_buf_num = 0; + c->rx_bufs = static_cast<memif_buffer_t *>( + malloc(sizeof(memif_buffer_t) * MAX_MEMIF_BUFS)); + c->tx_buf_num = 0; + c->tx_bufs = static_cast<memif_buffer_t *>( + malloc(sizeof(memif_buffer_t) * MAX_MEMIF_BUFS)); + + // memif_set_rx_mode (c->conn, MEMIF_RX_MODE_POLLING, 0); + + return 0; +} + +int MemifConnector::deleteMemif() { + memif_connection_t *c = memif_connection_.get(); + + if (c->rx_bufs) { + free(c->rx_bufs); + } + + c->rx_bufs = nullptr; + c->rx_buf_num = 0; + + if (c->tx_bufs) { + free(c->tx_bufs); + } + + c->tx_bufs = nullptr; + c->tx_buf_num = 0; + + int err; + /* disconenct then delete memif connection */ + err = memif_delete(&c->conn); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_delete: %s", memif_strerror(err)); + } + + if (TRANSPORT_EXPECT_FALSE(c->conn != nullptr)) { + TRANSPORT_LOGE("memif delete fail"); + } + + return 0; +} + +int MemifConnector::controlFdUpdate(int fd, uint8_t events, void *private_ctx) { + /* convert memif event definitions to epoll events */ + if (events & MEMIF_FD_EVENT_DEL) { + return MemifConnector::main_event_reactor_.delFileDescriptor(fd); + } + + uint32_t evt = 0; + + if (events & MEMIF_FD_EVENT_READ) { + evt |= EPOLLIN; + } + + if (events & MEMIF_FD_EVENT_WRITE) { + evt |= EPOLLOUT; + } + + if (events & MEMIF_FD_EVENT_MOD) { + return MemifConnector::main_event_reactor_.modFileDescriptor(fd, evt); + } + + return MemifConnector::main_event_reactor_.addFileDescriptor( + fd, evt, [](const utils::Event &evt) -> int { + uint32_t event = 0; + int memif_err = 0; + + if (evt.events & EPOLLIN) { + event |= MEMIF_FD_EVENT_READ; + } + + if (evt.events & EPOLLOUT) { + event |= MEMIF_FD_EVENT_WRITE; + } + + if (evt.events & EPOLLERR) { + event |= MEMIF_FD_EVENT_ERROR; + } + + memif_err = memif_control_fd_handler(evt.data.fd, event); + + if (TRANSPORT_EXPECT_FALSE(memif_err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_control_fd_handler: %s", + memif_strerror(memif_err)); + } + + return 0; + }); +} + +int MemifConnector::bufferAlloc(long n, uint16_t qid) { + memif_connection_t *c = memif_connection_.get(); + int err; + uint16_t r; + /* set data pointer to shared memory and set buffer_len to shared mmeory + * buffer len */ + err = memif_buffer_alloc(c->conn, qid, c->tx_bufs, n, &r, 2000); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_buffer_alloc: %s", memif_strerror(err)); + return -1; + } + + c->tx_buf_num += r; + return r; +} + +int MemifConnector::txBurst(uint16_t qid) { + memif_connection_t *c = memif_connection_.get(); + int err; + uint16_t r; + /* inform peer memif interface about data in shared memory buffers */ + /* mark memif buffers as free */ + err = memif_tx_burst(c->conn, qid, c->tx_bufs, c->tx_buf_num, &r); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_tx_burst: %s", memif_strerror(err)); + } + + // err = memif_refill_queue(c->conn, qid, r, 0); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_tx_burst: %s", memif_strerror(err)); + c->tx_buf_num -= r; + return -1; + } + + c->tx_buf_num -= r; + return 0; +} + +void MemifConnector::sendCallback(const std::error_code &ec) { + timer_set_ = false; + + if (TRANSPORT_EXPECT_TRUE(!ec && state_ == State::CONNECTED)) { + doSend(); + } +} + +void MemifConnector::processInputBuffer(std::uint16_t total_packets) { + utils::MemBuf::Ptr ptr; + + for (; total_packets > 0; total_packets--) { + if (input_buffer_.pop(ptr)) { + receive_callback_(this, *ptr, std::make_error_code(std::errc(0))); + } + } +} + +/* informs user about connected status. private_ctx is used by user to identify + connection (multiple connections WIP) */ +int MemifConnector::onConnect(memif_conn_handle_t conn, void *private_ctx) { + MemifConnector *connector = (MemifConnector *)private_ctx; + connector->state_ = State::CONNECTED; + memif_refill_queue(conn, 0, -1, 0); + + return 0; +} + +/* informs user about disconnected status. private_ctx is used by user to + identify connection (multiple connections WIP) */ +int MemifConnector::onDisconnect(memif_conn_handle_t conn, void *private_ctx) { + MemifConnector *connector = (MemifConnector *)private_ctx; + connector->state_ = State::CLOSED; + return 0; +} + +void MemifConnector::threadMain() { event_reactor_.runEventLoop(1000); } + +int MemifConnector::onInterrupt(memif_conn_handle_t conn, void *private_ctx, + uint16_t qid) { + MemifConnector *connector = (MemifConnector *)private_ctx; + + memif_connection_t *c = connector->memif_connection_.get(); + int err = MEMIF_ERR_SUCCESS, ret_val; + uint16_t total_packets = 0; + uint16_t rx; + + do { + err = memif_rx_burst(conn, qid, c->rx_bufs, MAX_MEMIF_BUFS, &rx); + ret_val = err; + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS && + err != MEMIF_ERR_NOBUF)) { + TRANSPORT_LOGE("memif_rx_burst: %s", memif_strerror(err)); + goto error; + } + + c->rx_buf_num += rx; + + if (TRANSPORT_EXPECT_FALSE(connector->io_service_.stopped())) { + TRANSPORT_LOGE("socket stopped: ignoring %u packets", rx); + goto error; + } + + std::size_t packet_length; + for (int i = 0; i < rx; i++) { + auto buffer = connector->getRawBuffer(); + packet_length = (c->rx_bufs + i)->len; + std::memcpy(buffer.first, (c->rx_bufs + i)->data, packet_length); + auto packet = connector->getPacketFromBuffer(buffer.first, packet_length); + + if (!connector->input_buffer_.push(std::move(packet))) { + TRANSPORT_LOGE("Error pushing packet. Ring buffer full."); + + // TODO Here we should consider the possibility to signal the congestion + // to the application, that would react properly (e.g. slow down + // message) + } + } + + /* mark memif buffers and shared memory buffers as free */ + /* free processed buffers */ + + err = memif_refill_queue(conn, qid, rx, 0); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_buffer_free: %s", memif_strerror(err)); + } + + c->rx_buf_num -= rx; + total_packets += rx; + + } while (ret_val == MEMIF_ERR_NOBUF); + + connector->io_service_.post( + std::bind(&MemifConnector::processInputBuffer, connector, total_packets)); + + return 0; + +error: + err = memif_refill_queue(c->conn, qid, rx, 0); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_buffer_free: %s", memif_strerror(err)); + } + c->rx_buf_num -= rx; + + return 0; +} + +void MemifConnector::close() { + if (state_ != State::CLOSED) { + disconnect_timer_->expiresFromNow(std::chrono::microseconds(50)); + disconnect_timer_->asyncWait([this](const std::error_code &ec) { + deleteMemif(); + event_reactor_.stop(); + work_.reset(); + }); + + if (memif_worker_ && memif_worker_->joinable()) { + memif_worker_->join(); + } + } +} + +void MemifConnector::send(Packet &packet) { + { + utils::SpinLock::Acquire locked(write_msgs_lock_); + output_buffer_.push_back(packet.shared_from_this()); + } +#if CANCEL_TIMER + if (!timer_set_) { + timer_set_ = true; + send_timer_->expiresFromNow(std::chrono::microseconds(50)); + send_timer_->asyncWait( + std::bind(&MemifConnector::sendCallback, this, std::placeholders::_1)); + } +#endif +} + +int MemifConnector::doSend() { + std::size_t max = 0; + int32_t n = 0; + std::size_t size = 0; + + { + utils::SpinLock::Acquire locked(write_msgs_lock_); + size = output_buffer_.size(); + } + + do { + max = size < MAX_MEMIF_BUFS ? size : MAX_MEMIF_BUFS; + n = bufferAlloc(max, memif_connection_->tx_qid); + + if (TRANSPORT_EXPECT_FALSE(n < 0)) { + TRANSPORT_LOGE("Error allocating buffers."); + return -1; + } + + for (uint16_t i = 0; i < n; i++) { + utils::SpinLock::Acquire locked(write_msgs_lock_); + + auto packet = output_buffer_.front().get(); + const utils::MemBuf *current = packet; + std::size_t offset = 0; + uint8_t *shared_buffer = + reinterpret_cast<uint8_t *>(memif_connection_->tx_bufs[i].data); + do { + std::memcpy(shared_buffer + offset, current->data(), current->length()); + offset += current->length(); + current = current->next(); + } while (current != packet); + + memif_connection_->tx_bufs[i].len = uint32_t(offset); + + output_buffer_.pop_front(); + } + + txBurst(memif_connection_->tx_qid); + + utils::SpinLock::Acquire locked(write_msgs_lock_); + size = output_buffer_.size(); + } while (size > 0); + + return 0; +} + +void MemifConnector::send(const uint8_t *packet, std::size_t len) { + throw errors::NotImplementedException(); +} + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/io_modules/memif/memif_connector.h b/libtransport/src/io_modules/memif/memif_connector.h new file mode 100644 index 000000000..bed3516dc --- /dev/null +++ b/libtransport/src/io_modules/memif/memif_connector.h @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/config.h> +#include <hicn/transport/core/connector.h> +#include <hicn/transport/portability/portability.h> +#include <hicn/transport/utils/ring_buffer.h> +//#include <hicn/transport/core/hicn_vapi.h> +#include <utils/epoll_event_reactor.h> +#include <utils/fd_deadline_timer.h> + +#include <asio.hpp> +#include <deque> +#include <mutex> +#include <thread> + +#define _Static_assert static_assert + +namespace transport { + +namespace core { + +typedef struct memif_connection memif_connection_t; + +#define APP_NAME "libtransport" +#define IF_NAME "vpp_connection" + +#define MEMIF_BUF_SIZE 2048 +#define MEMIF_LOG2_RING_SIZE 13 +#define MAX_MEMIF_BUFS (1 << MEMIF_LOG2_RING_SIZE) + +class MemifConnector : public Connector { + using memif_conn_handle_t = void *; + using PacketRing = utils::CircularFifo<utils::MemBuf::Ptr, queue_size>; + + public: + MemifConnector(PacketReceivedCallback &&receive_callback, + PacketSentCallback &&packet_sent, + OnCloseCallback &&close_callback, + OnReconnectCallback &&on_reconnect, + asio::io_service &io_service, + std::string app_name = "Libtransport"); + + ~MemifConnector() override; + + void send(Packet &packet) override; + + void send(const uint8_t *packet, std::size_t len) override; + + void close() override; + + void connect(uint32_t memif_id, long memif_mode); + + TRANSPORT_ALWAYS_INLINE uint32_t getMemifId() { return memif_id_; }; + + private: + void init(); + + int doSend(); + + int createMemif(uint32_t index, uint8_t mode, char *s); + + uint32_t getMemifConfiguration(); + + int deleteMemif(); + + static int controlFdUpdate(int fd, uint8_t events, void *private_ctx); + + static int onConnect(memif_conn_handle_t conn, void *private_ctx); + + static int onDisconnect(memif_conn_handle_t conn, void *private_ctx); + + static int onInterrupt(memif_conn_handle_t conn, void *private_ctx, + uint16_t qid); + + void threadMain(); + + int txBurst(uint16_t qid); + + int bufferAlloc(long n, uint16_t qid); + + void sendCallback(const std::error_code &ec); + + void processInputBuffer(std::uint16_t total_packets); + + private: + static utils::EpollEventReactor main_event_reactor_; + static std::unique_ptr<std::thread> main_worker_; + + int epfd; + std::unique_ptr<std::thread> memif_worker_; + utils::EpollEventReactor event_reactor_; + std::atomic_bool timer_set_; + std::unique_ptr<utils::FdDeadlineTimer> send_timer_; + std::unique_ptr<utils::FdDeadlineTimer> disconnect_timer_; + asio::io_service &io_service_; + std::unique_ptr<asio::io_service::work> work_; + std::unique_ptr<memif_connection_t> memif_connection_; + uint16_t tx_buf_counter_; + + PacketRing input_buffer_; + bool is_reconnection_; + bool data_available_; + uint32_t memif_id_; + uint8_t memif_mode_; + std::string app_name_; + uint16_t transmission_index_; + utils::SpinLock write_msgs_lock_; + std::string socket_filename_; + + static std::once_flag flag_; +}; + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/io_modules/memif/memif_vapi.c b/libtransport/src/io_modules/memif/memif_vapi.c new file mode 100644 index 000000000..b3da2b012 --- /dev/null +++ b/libtransport/src/io_modules/memif/memif_vapi.c @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <fcntl.h> +#include <hicn/transport/config.h> +#include <inttypes.h> +#include <io_modules/memif/memif_vapi.h> +#include <semaphore.h> +#include <string.h> +#include <sys/stat.h> +#include <vapi/vapi_safe.h> +#include <vppinfra/clib.h> + +DEFINE_VAPI_MSG_IDS_MEMIF_API_JSON + +static vapi_error_e memif_details_cb(vapi_ctx_t ctx, void *callback_ctx, + vapi_error_e rv, bool is_last, + vapi_payload_memif_details *reply) { + uint32_t *last_memif_id = (uint32_t *)callback_ctx; + uint32_t current_memif_id = 0; + if (reply != NULL) { + current_memif_id = reply->id; + } else { + return rv; + } + + if (current_memif_id >= *last_memif_id) { + *last_memif_id = current_memif_id + 1; + } + + return rv; +} + +int memif_vapi_get_next_memif_id(vapi_ctx_t ctx, uint32_t *memif_id) { + vapi_lock(); + vapi_msg_memif_dump *msg = vapi_alloc_memif_dump(ctx); + int ret = vapi_memif_dump(ctx, msg, memif_details_cb, memif_id); + vapi_unlock(); + return ret; +} + +static vapi_error_e memif_create_cb(vapi_ctx_t ctx, void *callback_ctx, + vapi_error_e rv, bool is_last, + vapi_payload_memif_create_reply *reply) { + memif_output_params_t *output_params = (memif_output_params_t *)callback_ctx; + + if (reply == NULL) return rv; + + output_params->sw_if_index = reply->sw_if_index; + + return rv; +} + +int memif_vapi_create_memif(vapi_ctx_t ctx, memif_create_params_t *input_params, + memif_output_params_t *output_params) { + vapi_lock(); + vapi_msg_memif_create *msg = vapi_alloc_memif_create(ctx); + + int ret = 0; + if (input_params->socket_id == ~0) { + // invalid socket-id + ret = -1; + goto END; + } + + if (!is_pow2(input_params->ring_size)) { + // ring size must be power of 2 + ret = -1; + goto END; + } + + if (input_params->rx_queues > 255 || input_params->rx_queues < 1) { + // rx queue must be between 1 - 255 + ret = -1; + goto END; + } + + if (input_params->tx_queues > 255 || input_params->tx_queues < 1) { + // tx queue must be between 1 - 255 + ret = -1; + goto END; + } + + msg->payload.role = input_params->role; + msg->payload.mode = input_params->mode; + msg->payload.rx_queues = input_params->rx_queues; + msg->payload.tx_queues = input_params->tx_queues; + msg->payload.id = input_params->id; + msg->payload.socket_id = input_params->socket_id; + msg->payload.ring_size = input_params->ring_size; + msg->payload.buffer_size = input_params->buffer_size; + + ret = vapi_memif_create(ctx, msg, memif_create_cb, output_params); +END: + vapi_unlock(); + return ret; +} + +static vapi_error_e memif_delete_cb(vapi_ctx_t ctx, void *callback_ctx, + vapi_error_e rv, bool is_last, + vapi_payload_memif_delete_reply *reply) { + if (reply == NULL) return rv; + + return reply->retval; +} + +int memif_vapi_delete_memif(vapi_ctx_t ctx, uint32_t sw_if_index) { + vapi_lock(); + vapi_msg_memif_delete *msg = vapi_alloc_memif_delete(ctx); + + msg->payload.sw_if_index = sw_if_index; + + int ret = vapi_memif_delete(ctx, msg, memif_delete_cb, NULL); + vapi_unlock(); + return ret; +} diff --git a/libtransport/src/io_modules/memif/memif_vapi.h b/libtransport/src/io_modules/memif/memif_vapi.h new file mode 100644 index 000000000..bcf06ed43 --- /dev/null +++ b/libtransport/src/io_modules/memif/memif_vapi.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/config.h> + +#ifdef __cplusplus +extern "C" { +#endif + +#include <vapi/memif.api.vapi.h> + +#include "stdint.h" + +typedef struct memif_create_params_s { + uint8_t role; + uint8_t mode; + uint8_t rx_queues; + uint8_t tx_queues; + uint32_t id; + uint32_t socket_id; + uint8_t secret[24]; + uint32_t ring_size; + uint16_t buffer_size; + uint8_t hw_addr[6]; +} memif_create_params_t; + +typedef struct memif_output_params_s { + uint32_t sw_if_index; +} memif_output_params_t; + +int memif_vapi_get_next_memif_id(vapi_ctx_t ctx, uint32_t *memif_id); + +int memif_vapi_create_memif(vapi_ctx_t ctx, memif_create_params_t *input_params, + memif_output_params_t *output_params); + +int memif_vapi_delete_memif(vapi_ctx_t ctx, uint32_t sw_if_index); + +#ifdef __cplusplus +} +#endif diff --git a/libtransport/src/io_modules/memif/vpp_forwarder_module.cc b/libtransport/src/io_modules/memif/vpp_forwarder_module.cc new file mode 100644 index 000000000..dcbcd7ed0 --- /dev/null +++ b/libtransport/src/io_modules/memif/vpp_forwarder_module.cc @@ -0,0 +1,263 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/config.h> +#include <hicn/transport/errors/not_implemented_exception.h> +#include <io_modules/memif/hicn_vapi.h> +#include <io_modules/memif/memif_connector.h> +#include <io_modules/memif/memif_vapi.h> +#include <io_modules/memif/vpp_forwarder_module.h> + +extern "C" { +#include <memif/libmemif.h> +}; + +typedef enum { MASTER = 0, SLAVE = 1 } memif_role_t; + +#define MEMIF_DEFAULT_RING_SIZE 2048 +#define MEMIF_DEFAULT_RX_QUEUES 1 +#define MEMIF_DEFAULT_TX_QUEUES 1 +#define MEMIF_DEFAULT_BUFFER_SIZE 2048 + +namespace transport { + +namespace core { + +VPPForwarderModule::VPPForwarderModule() + : IoModule(), + connector_(nullptr), + sw_if_index_(~0), + face_id1_(~0), + face_id2_(~0), + is_consumer_(false) {} + +VPPForwarderModule::~VPPForwarderModule() { delete connector_; } + +void VPPForwarderModule::init( + Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, const std::string &app_name) { + if (!connector_) { + connector_ = + new MemifConnector(std::move(receive_callback), 0, 0, + std::move(reconnect_callback), io_service, app_name); + } +} + +void VPPForwarderModule::processControlMessageReply( + utils::MemBuf &packet_buffer) { + throw errors::NotImplementedException(); +} + +bool VPPForwarderModule::isControlMessage(const uint8_t *message) { + return false; +} + +bool VPPForwarderModule::isConnected() { return connector_->isConnected(); }; + +void VPPForwarderModule::send(Packet &packet) { + IoModule::send(packet); + connector_->send(packet); +} + +void VPPForwarderModule::send(const uint8_t *packet, std::size_t len) { + counters_.tx_packets++; + counters_.tx_bytes += len; + + // Perfect forwarding + connector_->send(packet, len); +} + +std::uint32_t VPPForwarderModule::getMtu() { return interface_mtu; } + +/** + * @brief Create a memif interface in the local VPP forwarder. + */ +uint32_t VPPForwarderModule::getMemifConfiguration() { + memif_create_params_t input_params = {0}; + + int ret = memif_vapi_get_next_memif_id(VPPForwarderModule::sock_, &memif_id_); + + if (ret < 0) { + throw errors::RuntimeException( + "Error getting next memif id. Could not create memif interface."); + } + + input_params.id = memif_id_; + input_params.role = memif_role_t::MASTER; + input_params.mode = memif_interface_mode_t::MEMIF_INTERFACE_MODE_IP; + input_params.rx_queues = MEMIF_DEFAULT_RX_QUEUES; + input_params.tx_queues = MEMIF_DEFAULT_TX_QUEUES; + input_params.ring_size = MEMIF_DEFAULT_RING_SIZE; + input_params.buffer_size = MEMIF_DEFAULT_BUFFER_SIZE; + + memif_output_params_t output_params = {0}; + + ret = memif_vapi_create_memif(VPPForwarderModule::sock_, &input_params, + &output_params); + + if (ret < 0) { + throw errors::RuntimeException( + "Error creating memif interface in the local VPP forwarder."); + } + + return output_params.sw_if_index; +} + +void VPPForwarderModule::consumerConnection() { + hicn_consumer_input_params input = {0}; + hicn_consumer_output_params output = {0}; + ip_address_t ip4_address; + ip_address_t ip6_address; + + output.src4 = &ip4_address; + output.src6 = &ip6_address; + input.swif = sw_if_index_; + + int ret = + hicn_vapi_register_cons_app(VPPForwarderModule::sock_, &input, &output); + + if (ret < 0) { + throw errors::RuntimeException(hicn_vapi_get_error_string(ret)); + } + + face_id1_ = output.face_id1; + face_id2_ = output.face_id2; + + std::memcpy(inet_address_.v4.as_u8, output.src4->v4.as_u8, IPV4_ADDR_LEN); + + std::memcpy(inet6_address_.v6.as_u8, output.src6->v6.as_u8, IPV6_ADDR_LEN); +} + +void VPPForwarderModule::producerConnection() { + // Producer connection will be set when we set the first route. +} + +void VPPForwarderModule::connect(bool is_consumer) { + int retry = 20; + + TRANSPORT_LOGI("Connecting to VPP through vapi."); + vapi_error_e ret = vapi_connect_safe(&sock_, 0); + + while (ret != VAPI_OK && retry > 0) { + TRANSPORT_LOGE("Error connecting to VPP through vapi. Retrying.."); + --retry; + ret = vapi_connect_safe(&sock_, 0); + } + + if (ret != VAPI_OK) { + throw std::runtime_error( + "Impossible to connect to forwarder. Is VPP running?"); + } + + TRANSPORT_LOGI("Connected to VPP through vapi."); + + sw_if_index_ = getMemifConfiguration(); + + is_consumer_ = is_consumer; + if (is_consumer_) { + consumerConnection(); + } + + connector_->connect(memif_id_, 0); + connector_->setRole(is_consumer_ ? Connector::Role::CONSUMER + : Connector::Role::PRODUCER); +} + +void VPPForwarderModule::registerRoute(const Prefix &prefix) { + const ip_prefix_t &addr = prefix.toIpPrefixStruct(); + + ip_prefix_t producer_prefix; + ip_address_t producer_locator; + + if (face_id1_ == uint32_t(~0)) { + hicn_producer_input_params input; + std::memset(&input, 0, sizeof(input)); + + hicn_producer_output_params output; + std::memset(&output, 0, sizeof(output)); + + input.prefix = &producer_prefix; + output.prod_addr = &producer_locator; + + // Here we have to ask to the actual connector what is the + // memif_id, since this function should be called after the + // memif creation.n + input.swif = sw_if_index_; + input.prefix->address = addr.address; + input.prefix->family = addr.family; + input.prefix->len = addr.len; + input.cs_reserved = content_store_reserved_; + + int ret = + hicn_vapi_register_prod_app(VPPForwarderModule::sock_, &input, &output); + + if (ret < 0) { + throw errors::RuntimeException(hicn_vapi_get_error_string(ret)); + } + + inet6_address_ = *output.prod_addr; + + face_id1_ = output.face_id; + } else { + hicn_producer_set_route_params params; + params.prefix = &producer_prefix; + params.prefix->address = addr.address; + params.prefix->family = addr.family; + params.prefix->len = addr.len; + params.prod_addr = &producer_locator; + + int ret = hicn_vapi_register_route(VPPForwarderModule::sock_, ¶ms); + + if (ret < 0) { + throw errors::RuntimeException(hicn_vapi_get_error_string(ret)); + } + } +} + +void VPPForwarderModule::closeConnection() { + if (VPPForwarderModule::sock_) { + connector_->close(); + + if (is_consumer_) { + hicn_del_face_app_input_params params; + params.face_id = face_id1_; + hicn_vapi_face_cons_del(VPPForwarderModule::sock_, ¶ms); + params.face_id = face_id2_; + hicn_vapi_face_cons_del(VPPForwarderModule::sock_, ¶ms); + } else { + hicn_del_face_app_input_params params; + params.face_id = face_id1_; + hicn_vapi_face_prod_del(VPPForwarderModule::sock_, ¶ms); + } + + if (sw_if_index_ != uint32_t(~0)) { + int ret = + memif_vapi_delete_memif(VPPForwarderModule::sock_, sw_if_index_); + if (ret < 0) { + TRANSPORT_LOGE("Error deleting memif with sw idx %u.", sw_if_index_); + } + } + + vapi_disconnect_safe(); + VPPForwarderModule::sock_ = nullptr; + } +} + +extern "C" IoModule *create_module(void) { return new VPPForwarderModule(); } + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/memif/vpp_forwarder_module.h b/libtransport/src/io_modules/memif/vpp_forwarder_module.h new file mode 100644 index 000000000..8c4114fed --- /dev/null +++ b/libtransport/src/io_modules/memif/vpp_forwarder_module.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/core/prefix.h> + +#ifdef always_inline +#undef always_inline +#endif +extern "C" { +#include <vapi/vapi_safe.h> +}; + +namespace transport { + +namespace core { + +class MemifConnector; + +class VPPForwarderModule : public IoModule { + static constexpr std::uint16_t interface_mtu = 1500; + + public: + VPPForwarderModule(); + ~VPPForwarderModule(); + + void connect(bool is_consumer) override; + + void send(Packet &packet) override; + void send(const uint8_t *packet, std::size_t len) override; + + bool isConnected() override; + + void init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name = "Libtransport") override; + + void registerRoute(const Prefix &prefix) override; + + std::uint32_t getMtu() override; + + bool isControlMessage(const uint8_t *message) override; + + void processControlMessageReply(utils::MemBuf &packet_buffer) override; + + void closeConnection() override; + + private: + uint32_t getMemifConfiguration(); + void consumerConnection(); + void producerConnection(); + + private: + MemifConnector *connector_; + uint32_t memif_id_; + uint32_t sw_if_index_; + // A consumer socket in vpp has two faces (ipv4 and ipv6) + uint32_t face_id1_; + uint32_t face_id2_; + bool is_consumer_; + vapi_ctx_t sock_; +}; + +extern "C" IoModule *create_module(void); + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/raw_socket/raw_socket_connector.cc b/libtransport/src/io_modules/raw_socket/raw_socket_connector.cc new file mode 100644 index 000000000..0bfcc2a58 --- /dev/null +++ b/libtransport/src/io_modules/raw_socket/raw_socket_connector.cc @@ -0,0 +1,201 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/raw_socket_connector.h> +#include <hicn/transport/utils/conversions.h> +#include <hicn/transport/utils/log.h> +#include <net/if.h> +#include <netdb.h> +#include <stdio.h> +#include <string.h> +#include <sys/ioctl.h> +#include <sys/socket.h> + +#define MY_DEST_MAC0 0x0a +#define MY_DEST_MAC1 0x7b +#define MY_DEST_MAC2 0x7c +#define MY_DEST_MAC3 0x1c +#define MY_DEST_MAC4 0x4a +#define MY_DEST_MAC5 0x14 + +namespace transport { + +namespace core { + +RawSocketConnector::RawSocketConnector( + PacketReceivedCallback &&receive_callback, + OnReconnect &&on_reconnect_callback, asio::io_service &io_service, + std::string app_name) + : Connector(std::move(receive_callback), std::move(on_reconnect_callback)), + io_service_(io_service), + socket_(io_service_, raw_protocol(PF_PACKET, SOCK_RAW)), + // resolver_(io_service_), + timer_(io_service_), + read_msg_(packet_pool_.makePtr(nullptr)), + data_available_(false), + app_name_(app_name) { + memset(&link_layer_address_, 0, sizeof(link_layer_address_)); +} + +RawSocketConnector::~RawSocketConnector() {} + +void RawSocketConnector::connect(const std::string &interface_name, + const std::string &mac_address_str) { + state_ = ConnectorState::CONNECTING; + memset(ðernet_header_, 0, sizeof(ethernet_header_)); + struct ifreq ifr; + struct ifreq if_mac; + uint8_t mac_address[6]; + + utils::convertStringToMacAddress(mac_address_str, mac_address); + + // Get interface mac address + int fd = static_cast<int>(socket_.native_handle()); + + /* Get the index of the interface to send on */ + memset(&ifr, 0, sizeof(struct ifreq)); + strncpy(ifr.ifr_name, interface_name.c_str(), interface_name.size()); + + // if (ioctl(fd, SIOCGIFINDEX, &if_idx) < 0) { + // perror("SIOCGIFINDEX"); + // } + + /* Get the MAC address of the interface to send on */ + memset(&if_mac, 0, sizeof(struct ifreq)); + strncpy(if_mac.ifr_name, interface_name.c_str(), interface_name.size()); + if (ioctl(fd, SIOCGIFHWADDR, &if_mac) < 0) { + perror("SIOCGIFHWADDR"); + throw errors::RuntimeException("Interface does not exist"); + } + + /* Ethernet header */ + for (int i = 0; i < 6; i++) { + ethernet_header_.ether_shost[i] = + ((uint8_t *)&if_mac.ifr_hwaddr.sa_data)[i]; + ethernet_header_.ether_dhost[i] = mac_address[i]; + } + + /* Ethertype field */ + ethernet_header_.ether_type = htons(ETH_P_IPV6); + + strcpy(ifr.ifr_name, interface_name.c_str()); + + if (0 == ioctl(fd, SIOCGIFHWADDR, &ifr)) { + memcpy(link_layer_address_.sll_addr, ifr.ifr_hwaddr.sa_data, 6); + } + + // memset(&ifr, 0, sizeof(ifr)); + // ioctl(fd, SIOCGIFFLAGS, &ifr); + // ifr.ifr_flags |= IFF_PROMISC; + // ioctl(fd, SIOCSIFFLAGS, &ifr); + + link_layer_address_.sll_family = AF_PACKET; + link_layer_address_.sll_protocol = htons(ETH_P_ALL); + link_layer_address_.sll_ifindex = if_nametoindex(interface_name.c_str()); + link_layer_address_.sll_hatype = 1; + link_layer_address_.sll_halen = 6; + + // startConnectionTimer(); + doConnect(); + doRecvPacket(); +} + +void RawSocketConnector::send(const uint8_t *packet, std::size_t len, + const PacketSentCallback &packet_sent) { + if (packet_sent != 0) { + socket_.async_send( + asio::buffer(packet, len), + [packet_sent](std::error_code ec, std::size_t /*length*/) { + packet_sent(); + }); + } else { + if (state_ == ConnectorState::CONNECTED) { + socket_.send(asio::buffer(packet, len)); + } + } +} + +void RawSocketConnector::send(const Packet::MemBufPtr &packet) { + io_service_.post([this, packet]() { + bool write_in_progress = !output_buffer_.empty(); + output_buffer_.push_back(std::move(packet)); + if (TRANSPORT_EXPECT_TRUE(state_ == ConnectorState::CONNECTED)) { + if (!write_in_progress) { + doSendPacket(); + } else { + // Tell the handle connect it has data to write + data_available_ = true; + } + } + }); +} + +void RawSocketConnector::close() { + io_service_.post([this]() { socket_.close(); }); +} + +void RawSocketConnector::doSendPacket() { + auto packet = output_buffer_.front().get(); + auto array = std::vector<asio::const_buffer>(); + + const utils::MemBuf *current = packet; + do { + array.push_back(asio::const_buffer(current->data(), current->length())); + current = current->next(); + } while (current != packet); + + socket_.async_send( + std::move(array), + [this /*, packet*/](std::error_code ec, std::size_t bytes_transferred) { + if (TRANSPORT_EXPECT_TRUE(!ec)) { + output_buffer_.pop_front(); + if (!output_buffer_.empty()) { + doSendPacket(); + } + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + } + }); +} + +void RawSocketConnector::doRecvPacket() { + read_msg_ = getPacket(); + socket_.async_receive( + asio::buffer(read_msg_->writableData(), packet_size), + [this](std::error_code ec, std::size_t bytes_transferred) mutable { + if (!ec) { + // Ignore packets that are not for us + uint8_t *dst_mac_address = const_cast<uint8_t *>(read_msg_->data()); + if (!std::memcmp(dst_mac_address, ethernet_header_.ether_shost, + ETHER_ADDR_LEN)) { + read_msg_->append(bytes_transferred); + read_msg_->trimStart(sizeof(struct ether_header)); + receive_callback_(std::move(read_msg_)); + } + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + } + doRecvPacket(); + }); +} + +void RawSocketConnector::doConnect() { + state_ = ConnectorState::CONNECTED; + socket_.bind(raw_endpoint(&link_layer_address_, sizeof(link_layer_address_))); +} + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/io_modules/raw_socket/raw_socket_connector.h b/libtransport/src/io_modules/raw_socket/raw_socket_connector.h new file mode 100644 index 000000000..aba4b1105 --- /dev/null +++ b/libtransport/src/io_modules/raw_socket/raw_socket_connector.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <core/connector.h> +#include <hicn/transport/config.h> +#include <hicn/transport/core/name.h> +#include <linux/if_packet.h> +#include <net/ethernet.h> +#include <sys/socket.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <deque> + +namespace transport { + +namespace core { + +using asio::generic::raw_protocol; +using raw_endpoint = asio::generic::basic_endpoint<raw_protocol>; + +class RawSocketConnector : public Connector { + public: + RawSocketConnector(PacketReceivedCallback &&receive_callback, + OnReconnect &&reconnect_callback, + asio::io_service &io_service, + std::string app_name = "Libtransport"); + + ~RawSocketConnector() override; + + void send(const Packet::MemBufPtr &packet) override; + + void send(const uint8_t *packet, std::size_t len, + const PacketSentCallback &packet_sent = 0) override; + + void close() override; + + void connect(const std::string &interface_name, + const std::string &mac_address_str); + + private: + void doConnect(); + + void doRecvPacket(); + + void doSendPacket(); + + private: + asio::io_service &io_service_; + raw_protocol::socket socket_; + + struct ether_header ethernet_header_; + + struct sockaddr_ll link_layer_address_; + + asio::steady_timer timer_; + + utils::ObjectPool<utils::MemBuf>::Ptr read_msg_; + + bool data_available_; + std::string app_name_; +}; + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/io_modules/raw_socket/raw_socket_interface.cc b/libtransport/src/io_modules/raw_socket/raw_socket_interface.cc new file mode 100644 index 000000000..dcf489f59 --- /dev/null +++ b/libtransport/src/io_modules/raw_socket/raw_socket_interface.cc @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/raw_socket_interface.h> +#include <hicn/transport/utils/linux.h> + +#include <fstream> + +namespace transport { + +namespace core { + +static std::string config_folder_path = "/etc/transport/interface.conf.d"; + +RawSocketInterface::RawSocketInterface(RawSocketConnector &connector) + : ForwarderInterface<RawSocketInterface, RawSocketConnector>(connector) {} + +RawSocketInterface::~RawSocketInterface() {} + +void RawSocketInterface::connect(bool is_consumer) { + std::string complete_filename = + config_folder_path + std::string("/") + output_interface_; + + std::ifstream is(complete_filename); + std::string interface; + + if (is) { + is >> remote_mac_address_; + } + + // Get interface ip address + struct sockaddr_in6 address = {0}; + utils::retrieveInterfaceAddress(output_interface_, &address); + + std::memcpy(&inet6_address_.v6.as_u8, &address.sin6_addr, + sizeof(address.sin6_addr)); + connector_.connect(output_interface_, remote_mac_address_); +} + +void RawSocketInterface::registerRoute(Prefix &prefix) { return; } + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/raw_socket/raw_socket_interface.h b/libtransport/src/io_modules/raw_socket/raw_socket_interface.h new file mode 100644 index 000000000..7036cac7e --- /dev/null +++ b/libtransport/src/io_modules/raw_socket/raw_socket_interface.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <core/forwarder_interface.h> +#include <core/raw_socket_connector.h> +#include <hicn/transport/core/prefix.h> + +#include <atomic> +#include <deque> + +namespace transport { + +namespace core { + +class RawSocketInterface + : public ForwarderInterface<RawSocketInterface, RawSocketConnector> { + public: + typedef RawSocketConnector ConnectorType; + + RawSocketInterface(RawSocketConnector &connector); + + ~RawSocketInterface(); + + void connect(bool is_consumer); + + void registerRoute(Prefix &prefix); + + std::uint16_t getMtu() { return interface_mtu; } + + TRANSPORT_ALWAYS_INLINE static bool isControlMessageImpl( + const uint8_t *message) { + return false; + } + + TRANSPORT_ALWAYS_INLINE void processControlMessageReplyImpl( + Packet::MemBufPtr &&packet_buffer) {} + + TRANSPORT_ALWAYS_INLINE void closeConnection(){}; + + private: + static constexpr std::uint16_t interface_mtu = 1500; + std::string remote_mac_address_; +}; + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/udp/CMakeLists.txt b/libtransport/src/io_modules/udp/CMakeLists.txt new file mode 100644 index 000000000..1a43492dc --- /dev/null +++ b/libtransport/src/io_modules/udp/CMakeLists.txt @@ -0,0 +1,47 @@ +# Copyright (c) 2021 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + + +list(APPEND MODULE_HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/hicn_forwarder_module.h + ${CMAKE_CURRENT_SOURCE_DIR}/udp_socket_connector.h +) + +list(APPEND MODULE_SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/hicn_forwarder_module.cc + ${CMAKE_CURRENT_SOURCE_DIR}/udp_socket_connector.cc +) + +# add_executable(hicnlight_module MACOSX_BUNDLE ${MODULE_SOURCE_FILES}) +# target_include_directories(hicnlight_module PRIVATE ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS}) +# set_target_properties(hicnlight_module PROPERTIES +# BUNDLE True +# MACOSX_BUNDLE_GUI_IDENTIFIER my.domain.style.identifier.hicnlight_module +# MACOSX_BUNDLE_BUNDLE_NAME hicnlight_module +# MACOSX_BUNDLE_BUNDLE_VERSION "0.1" +# MACOSX_BUNDLE_SHORT_VERSION_STRING "0.1" +# # MACOSX_BUNDLE_INFO_PLIST ${CMAKE_SOURCE_DIR}/cmake/customtemplate.plist.in +# ) + +build_module(hicnlight_module + SHARED + SOURCES ${MODULE_SOURCE_FILES} + DEPENDS ${DEPENDENCIES} + COMPONENT lib${LIBTRANSPORT} + INCLUDE_DIRS ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} + # LIBRARY_ROOT_DIR "vpp_plugins" + DEFINITIONS ${COMPILER_DEFINITIONS} + COMPILE_OPTIONS ${COMPILE_FLAGS} +) diff --git a/libtransport/src/io_modules/udp/hicn_forwarder_module.cc b/libtransport/src/io_modules/udp/hicn_forwarder_module.cc new file mode 100644 index 000000000..ba08dd8c0 --- /dev/null +++ b/libtransport/src/io_modules/udp/hicn_forwarder_module.cc @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <io_modules/udp/hicn_forwarder_module.h> +#include <io_modules/udp/udp_socket_connector.h> + +union AddressLight { + uint32_t ipv4; + struct in6_addr ipv6; +}; + +typedef struct { + uint8_t message_type; + uint8_t command_id; + uint16_t length; + uint32_t seq_num; +} CommandHeader; + +typedef struct { + uint8_t message_type; + uint8_t command_id; + uint16_t length; + uint32_t seq_num; + char symbolic_or_connid[16]; + union AddressLight address; + uint16_t cost; + uint8_t address_type; + uint8_t len; +} RouteToSelfCommand; + +typedef struct { + uint8_t message_type; + uint8_t command_id; + uint16_t length; + uint32_t seq_num; + char symbolic_or_connid[16]; +} DeleteSelfConnectionCommand; + +namespace { +static constexpr uint8_t addr_inet = 1; +static constexpr uint8_t addr_inet6 = 2; +static constexpr uint8_t add_route_command = 3; +static constexpr uint8_t delete_connection_command = 5; +static constexpr uint8_t request_light = 0xc0; +static constexpr char identifier[] = "SELF"; + +void fillCommandHeader(CommandHeader *header) { + // Allocate and fill the header + header->message_type = request_light; + header->length = 1; +} + +RouteToSelfCommand createCommandRoute(std::unique_ptr<sockaddr> &&addr, + uint8_t prefix_length) { + RouteToSelfCommand command = {0}; + + // check and set IP address + if (addr->sa_family == AF_INET) { + command.address_type = addr_inet; + command.address.ipv4 = ((sockaddr_in *)addr.get())->sin_addr.s_addr; + } else if (addr->sa_family == AF_INET6) { + command.address_type = addr_inet6; + command.address.ipv6 = ((sockaddr_in6 *)addr.get())->sin6_addr; + } + + // Fill remaining payload fields +#ifndef _WIN32 + strcpy(command.symbolic_or_connid, identifier); +#else + strcpy_s(command.symbolic_or_connid, 16, identifier); +#endif + command.cost = 1; + command.len = (uint8_t)prefix_length; + + // Allocate and fill the header + command.command_id = add_route_command; + fillCommandHeader((CommandHeader *)&command); + + return command; +} + +DeleteSelfConnectionCommand createCommandDeleteConnection() { + DeleteSelfConnectionCommand command = {0}; + fillCommandHeader((CommandHeader *)&command); + command.command_id = delete_connection_command; + +#ifndef _WIN32 + strcpy(command.symbolic_or_connid, identifier); +#else + strcpy_s(command.symbolic_or_connid, 16, identifier); +#endif + + return command; +} + +} // namespace + +namespace transport { + +namespace core { + +HicnForwarderModule::HicnForwarderModule() : IoModule(), connector_(nullptr) {} + +HicnForwarderModule::~HicnForwarderModule() {} + +void HicnForwarderModule::connect(bool is_consumer) { + connector_->connect(); + connector_->setRole(is_consumer ? Connector::Role::CONSUMER + : Connector::Role::PRODUCER); +} + +bool HicnForwarderModule::isConnected() { return connector_->isConnected(); } + +void HicnForwarderModule::send(Packet &packet) { + IoModule::send(packet); + packet.setChecksum(); + connector_->send(packet); +} + +void HicnForwarderModule::send(const uint8_t *packet, std::size_t len) { + counters_.tx_packets++; + counters_.tx_bytes += len; + + // Perfect forwarding + connector_->send(packet, len); +} + +void HicnForwarderModule::registerRoute(const Prefix &prefix) { + auto command = createCommandRoute(prefix.toSockaddr(), + (uint8_t)prefix.getPrefixLength()); + send((uint8_t *)&command, sizeof(RouteToSelfCommand)); +} + +void HicnForwarderModule::closeConnection() { + auto command = createCommandDeleteConnection(); + send((uint8_t *)&command, sizeof(DeleteSelfConnectionCommand)); + connector_->close(); +} + +void HicnForwarderModule::init( + Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, const std::string &app_name) { + if (!connector_) { + connector_ = new UdpSocketConnector(std::move(receive_callback), nullptr, + nullptr, std::move(reconnect_callback), + io_service, app_name); + } +} + +void HicnForwarderModule::processControlMessageReply( + utils::MemBuf &packet_buffer) { + if (packet_buffer.data()[0] == nack_code) { + throw errors::RuntimeException( + "Received Nack message from hicn light forwarder."); + } +} + +std::uint32_t HicnForwarderModule::getMtu() { return interface_mtu; } + +bool HicnForwarderModule::isControlMessage(const uint8_t *message) { + return message[0] == ack_code || message[0] == nack_code; +} + +extern "C" IoModule *create_module(void) { return new HicnForwarderModule(); } + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/udp/hicn_forwarder_module.h b/libtransport/src/io_modules/udp/hicn_forwarder_module.h new file mode 100644 index 000000000..845db73bf --- /dev/null +++ b/libtransport/src/io_modules/udp/hicn_forwarder_module.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/core/prefix.h> + +namespace transport { + +namespace core { + +class UdpSocketConnector; + +class HicnForwarderModule : public IoModule { + static constexpr uint8_t ack_code = 0xc2; + static constexpr uint8_t nack_code = 0xc3; + static constexpr std::uint16_t interface_mtu = 1500; + + public: + union addressLight { + uint32_t ipv4; + struct in6_addr ipv6; + }; + + struct route_to_self_command { + uint8_t messageType; + uint8_t commandID; + uint16_t length; + uint32_t seqNum; + char symbolicOrConnid[16]; + union addressLight address; + uint16_t cost; + uint8_t addressType; + uint8_t len; + }; + + using route_to_self_command = struct route_to_self_command; + + HicnForwarderModule(); + + ~HicnForwarderModule(); + + void connect(bool is_consumer) override; + + void send(Packet &packet) override; + void send(const uint8_t *packet, std::size_t len) override; + + bool isConnected() override; + + void init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name = "Libtransport") override; + + void registerRoute(const Prefix &prefix) override; + + std::uint32_t getMtu() override; + + bool isControlMessage(const uint8_t *message) override; + + void processControlMessageReply(utils::MemBuf &packet_buffer) override; + + void closeConnection() override; + + private: + UdpSocketConnector *connector_; +}; + +extern "C" IoModule *create_module(void); + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/udp/udp_socket_connector.cc b/libtransport/src/io_modules/udp/udp_socket_connector.cc new file mode 100644 index 000000000..456886a54 --- /dev/null +++ b/libtransport/src/io_modules/udp/udp_socket_connector.cc @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef _WIN32 +#include <hicn/transport/portability/win_portability.h> +#endif + +#include <hicn/transport/errors/errors.h> +#include <hicn/transport/utils/log.h> +#include <hicn/transport/utils/object_pool.h> +#include <io_modules/udp/udp_socket_connector.h> + +#include <thread> +#include <vector> + +namespace transport { + +namespace core { + +UdpSocketConnector::UdpSocketConnector( + PacketReceivedCallback &&receive_callback, PacketSentCallback &&packet_sent, + OnCloseCallback &&close_callback, OnReconnectCallback &&on_reconnect, + asio::io_service &io_service, std::string app_name) + : Connector(std::move(receive_callback), std::move(packet_sent), + std::move(close_callback), std::move(on_reconnect)), + io_service_(io_service), + socket_(io_service_), + resolver_(io_service_), + connection_timer_(io_service_), + read_msg_(std::make_pair(nullptr, 0)), + is_reconnection_(false), + data_available_(false), + app_name_(app_name) {} + +UdpSocketConnector::~UdpSocketConnector() {} + +void UdpSocketConnector::connect(std::string ip_address, std::string port) { + endpoint_iterator_ = resolver_.resolve( + {ip_address, port, asio::ip::resolver_query_base::numeric_service}); + + state_ = Connector::State::CONNECTING; + doConnect(); +} + +void UdpSocketConnector::send(const uint8_t *packet, std::size_t len) { + socket_.async_send(asio::buffer(packet, len), + [this](std::error_code ec, std::size_t /*length*/) { + if (sent_callback_) { + sent_callback_(this, ec); + } + }); +} + +void UdpSocketConnector::send(Packet &packet) { + io_service_.post([this, _packet{packet.shared_from_this()}]() { + bool write_in_progress = !output_buffer_.empty(); + output_buffer_.push_back(std::move(_packet)); + if (TRANSPORT_EXPECT_TRUE(state_ == Connector::State::CONNECTED)) { + if (!write_in_progress) { + doWrite(); + } + } else { + // Tell the handle connect it has data to write + data_available_ = true; + } + }); +} + +void UdpSocketConnector::close() { + if (io_service_.stopped()) { + doClose(); + } else { + io_service_.dispatch(std::bind(&UdpSocketConnector::doClose, this)); + } +} + +void UdpSocketConnector::doClose() { + if (state_ != Connector::State::CLOSED) { + state_ = Connector::State::CLOSED; + if (socket_.is_open()) { + socket_.shutdown(asio::ip::tcp::socket::shutdown_type::shutdown_both); + socket_.close(); + } + } +} + +void UdpSocketConnector::doWrite() { + auto packet = output_buffer_.front().get(); + auto array = std::vector<asio::const_buffer>(); + + const utils::MemBuf *current = packet; + do { + array.push_back(asio::const_buffer(current->data(), current->length())); + current = current->next(); + } while (current != packet); + + socket_.async_send(std::move(array), [this](std::error_code ec, + std::size_t length) { + if (TRANSPORT_EXPECT_TRUE(!ec)) { + output_buffer_.pop_front(); + if (!output_buffer_.empty()) { + doWrite(); + } + } else if (ec.value() == static_cast<int>(std::errc::operation_canceled)) { + // The connection has been closed by the application. + return; + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + tryReconnect(); + } + }); +} + +void UdpSocketConnector::doRead() { + read_msg_ = getRawBuffer(); + socket_.async_receive( + asio::buffer(read_msg_.first, read_msg_.second), + [this](std::error_code ec, std::size_t length) { + if (TRANSPORT_EXPECT_TRUE(!ec)) { + auto packet = getPacketFromBuffer(read_msg_.first, length); + receive_callback_(this, *packet, std::make_error_code(std::errc(0))); + doRead(); + } else if (ec.value() == + static_cast<int>(std::errc::operation_canceled)) { + // The connection has been closed by the application. + return; + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + tryReconnect(); + } + }); +} + +void UdpSocketConnector::tryReconnect() { + if (state_ == Connector::State::CONNECTED) { + TRANSPORT_LOGE("Connection lost. Trying to reconnect...\n"); + state_ = Connector::State::CONNECTING; + is_reconnection_ = true; + io_service_.post([this]() { + if (socket_.is_open()) { + socket_.shutdown(asio::ip::tcp::socket::shutdown_type::shutdown_both); + socket_.close(); + } + + doConnect(); + startConnectionTimer(); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + }); + } +} + +void UdpSocketConnector::doConnect() { + asio::async_connect( + socket_, endpoint_iterator_, + [this](std::error_code ec, udp::resolver::iterator) { + if (!ec) { + connection_timer_.cancel(); + state_ = Connector::State::CONNECTED; + doRead(); + + if (data_available_) { + data_available_ = false; + doWrite(); + } + + if (is_reconnection_) { + is_reconnection_ = false; + } + + on_reconnect_callback_(this); + } else { + doConnect(); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + }); +} + +bool UdpSocketConnector::checkConnected() { + return state_ == Connector::State::CONNECTED; +} + +void UdpSocketConnector::startConnectionTimer() { + connection_timer_.expires_from_now(std::chrono::seconds(60)); + connection_timer_.async_wait(std::bind(&UdpSocketConnector::handleDeadline, + this, std::placeholders::_1)); +} + +void UdpSocketConnector::handleDeadline(const std::error_code &ec) { + if (!ec) { + io_service_.post([this]() { + socket_.close(); + TRANSPORT_LOGE("Error connecting. Is the forwarder running?\n"); + }); + } +} + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/io_modules/udp/udp_socket_connector.h b/libtransport/src/io_modules/udp/udp_socket_connector.h new file mode 100644 index 000000000..8ab08e17a --- /dev/null +++ b/libtransport/src/io_modules/udp/udp_socket_connector.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/config.h> +#include <hicn/transport/core/connector.h> +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/core/interest.h> +#include <hicn/transport/core/name.h> +#include <hicn/transport/core/packet.h> +#include <hicn/transport/utils/branch_prediction.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <deque> + +namespace transport { +namespace core { + +using asio::ip::udp; + +class UdpSocketConnector : public Connector { + public: + UdpSocketConnector(PacketReceivedCallback &&receive_callback, + PacketSentCallback &&packet_sent, + OnCloseCallback &&close_callback, + OnReconnectCallback &&on_reconnect, + asio::io_service &io_service, + std::string app_name = "Libtransport"); + + ~UdpSocketConnector() override; + + void send(Packet &packet) override; + + void send(const uint8_t *packet, std::size_t len) override; + + void close() override; + + void connect(std::string ip_address = "127.0.0.1", std::string port = "9695"); + + private: + void doConnect(); + + void doRead(); + + void doWrite(); + + void doClose(); + + bool checkConnected(); + + private: + void handleDeadline(const std::error_code &ec); + + void startConnectionTimer(); + + void tryReconnect(); + + asio::io_service &io_service_; + asio::ip::udp::socket socket_; + asio::ip::udp::resolver resolver_; + asio::ip::udp::resolver::iterator endpoint_iterator_; + asio::steady_timer connection_timer_; + + std::pair<uint8_t *, std::size_t> read_msg_; + + bool is_reconnection_; + bool data_available_; + + std::string app_name_; +}; + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/protocols/CMakeLists.txt b/libtransport/src/protocols/CMakeLists.txt index 8bfbdd6ad..eba8d1aab 100644 --- a/libtransport/src/protocols/CMakeLists.txt +++ b/libtransport/src/protocols/CMakeLists.txt @@ -21,16 +21,15 @@ list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/datagram_reassembly.h ${CMAKE_CURRENT_SOURCE_DIR}/byte_stream_reassembly.h ${CMAKE_CURRENT_SOURCE_DIR}/congestion_window_protocol.h - ${CMAKE_CURRENT_SOURCE_DIR}/packet_manager.h ${CMAKE_CURRENT_SOURCE_DIR}/rate_estimation.h - ${CMAKE_CURRENT_SOURCE_DIR}/protocol.h + ${CMAKE_CURRENT_SOURCE_DIR}/transport_protocol.h + ${CMAKE_CURRENT_SOURCE_DIR}/production_protocol.h + ${CMAKE_CURRENT_SOURCE_DIR}/prod_protocol_bytestream.h + ${CMAKE_CURRENT_SOURCE_DIR}/prod_protocol_rtc.h ${CMAKE_CURRENT_SOURCE_DIR}/raaqm.h ${CMAKE_CURRENT_SOURCE_DIR}/raaqm_data_path.h ${CMAKE_CURRENT_SOURCE_DIR}/cbr.h - ${CMAKE_CURRENT_SOURCE_DIR}/rtc.h - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_data_path.h ${CMAKE_CURRENT_SOURCE_DIR}/errors.h - ${CMAKE_CURRENT_SOURCE_DIR}/verification_manager.h ${CMAKE_CURRENT_SOURCE_DIR}/data_processing_events.h ) @@ -41,15 +40,15 @@ list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/reassembly.cc ${CMAKE_CURRENT_SOURCE_DIR}/datagram_reassembly.cc ${CMAKE_CURRENT_SOURCE_DIR}/byte_stream_reassembly.cc - ${CMAKE_CURRENT_SOURCE_DIR}/protocol.cc + ${CMAKE_CURRENT_SOURCE_DIR}/transport_protocol.cc + ${CMAKE_CURRENT_SOURCE_DIR}/production_protocol.cc + ${CMAKE_CURRENT_SOURCE_DIR}/prod_protocol_bytestream.cc + ${CMAKE_CURRENT_SOURCE_DIR}/prod_protocol_rtc.cc ${CMAKE_CURRENT_SOURCE_DIR}/raaqm.cc ${CMAKE_CURRENT_SOURCE_DIR}/rate_estimation.cc ${CMAKE_CURRENT_SOURCE_DIR}/raaqm_data_path.cc ${CMAKE_CURRENT_SOURCE_DIR}/cbr.cc - ${CMAKE_CURRENT_SOURCE_DIR}/rtc.cc - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_data_path.cc ${CMAKE_CURRENT_SOURCE_DIR}/errors.cc - ${CMAKE_CURRENT_SOURCE_DIR}/verification_manager.cc ) set(RAAQM_CONFIG_INSTALL_PREFIX @@ -71,5 +70,7 @@ install( COMPONENT lib${LIBTRANSPORT} ) +add_subdirectory(rtc) + set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) -set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE)
\ No newline at end of file +set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE) diff --git a/libtransport/src/protocols/byte_stream_reassembly.cc b/libtransport/src/protocols/byte_stream_reassembly.cc index 6662bec3f..d2bc961c4 100644 --- a/libtransport/src/protocols/byte_stream_reassembly.cc +++ b/libtransport/src/protocols/byte_stream_reassembly.cc @@ -20,7 +20,7 @@ #include <protocols/byte_stream_reassembly.h> #include <protocols/errors.h> #include <protocols/indexer.h> -#include <protocols/protocol.h> +#include <protocols/transport_protocol.h> namespace transport { @@ -45,11 +45,11 @@ void ByteStreamReassembly::reassemble( } } -void ByteStreamReassembly::reassemble(ContentObject::Ptr &&content_object) { - if (TRANSPORT_EXPECT_TRUE(content_object != nullptr) && - read_buffer_->capacity()) { - received_packets_.emplace(std::make_pair( - content_object->getName().getSuffix(), std::move(content_object))); +void ByteStreamReassembly::reassemble(ContentObject &content_object) { + if (TRANSPORT_EXPECT_TRUE(read_buffer_->capacity())) { + received_packets_.emplace( + std::make_pair(content_object.getName().getSuffix(), + content_object.shared_from_this())); assembleContent(); } } @@ -81,25 +81,32 @@ void ByteStreamReassembly::assembleContent() { } } -bool ByteStreamReassembly::copyContent(const ContentObject &content_object) { +bool ByteStreamReassembly::copyContent(ContentObject &content_object) { bool ret = false; - auto payload = content_object.getPayloadReference(); - auto payload_length = payload.second; - auto write_size = std::min(payload_length, read_buffer_->tailroom()); - auto additional_bytes = payload_length > read_buffer_->tailroom() - ? payload_length - read_buffer_->tailroom() - : 0; + content_object.trimStart(content_object.headerSize()); - std::memcpy(read_buffer_->writableTail(), payload.first, write_size); - read_buffer_->append(write_size); + utils::MemBuf *current = &content_object; - if (!read_buffer_->tailroom()) { - notifyApplication(); - std::memcpy(read_buffer_->writableTail(), payload.first + write_size, - additional_bytes); - read_buffer_->append(additional_bytes); - } + do { + auto payload_length = current->length(); + auto write_size = std::min(payload_length, read_buffer_->tailroom()); + auto additional_bytes = payload_length > read_buffer_->tailroom() + ? payload_length - read_buffer_->tailroom() + : 0; + + std::memcpy(read_buffer_->writableTail(), current->data(), write_size); + read_buffer_->append(write_size); + + if (!read_buffer_->tailroom()) { + notifyApplication(); + std::memcpy(read_buffer_->writableTail(), current->data() + write_size, + additional_bytes); + read_buffer_->append(additional_bytes); + } + + current = current->next(); + } while (current != &content_object); download_complete_ = index_manager_->getFinalSuffix() == content_object.getName().getSuffix(); diff --git a/libtransport/src/protocols/byte_stream_reassembly.h b/libtransport/src/protocols/byte_stream_reassembly.h index e4f62b3a8..c682d58cb 100644 --- a/libtransport/src/protocols/byte_stream_reassembly.h +++ b/libtransport/src/protocols/byte_stream_reassembly.h @@ -27,12 +27,12 @@ class ByteStreamReassembly : public Reassembly { TransportProtocol *transport_protocol); protected: - virtual void reassemble(core::ContentObject::Ptr &&content_object) override; + virtual void reassemble(core::ContentObject &content_object) override; virtual void reassemble( std::unique_ptr<core::ContentObjectManifest> &&manifest) override; - bool copyContent(const core::ContentObject &content_object); + bool copyContent(core::ContentObject &content_object); virtual void reInitialize() override; diff --git a/libtransport/src/protocols/data_processing_events.h b/libtransport/src/protocols/data_processing_events.h index 8975c2b4a..5c8c16157 100644 --- a/libtransport/src/protocols/data_processing_events.h +++ b/libtransport/src/protocols/data_processing_events.h @@ -24,8 +24,7 @@ namespace protocol { class ContentObjectProcessingEventCallback { public: virtual ~ContentObjectProcessingEventCallback() = default; - virtual void onPacketDropped(core::Interest::Ptr &&i, - core::ContentObject::Ptr &&c) = 0; + virtual void onPacketDropped(core::Interest &i, core::ContentObject &c) = 0; virtual void onReassemblyFailed(std::uint32_t missing_segment) = 0; }; diff --git a/libtransport/src/protocols/datagram_reassembly.cc b/libtransport/src/protocols/datagram_reassembly.cc index abd7e984d..962c1e020 100644 --- a/libtransport/src/protocols/datagram_reassembly.cc +++ b/libtransport/src/protocols/datagram_reassembly.cc @@ -24,8 +24,8 @@ DatagramReassembly::DatagramReassembly( TransportProtocol* transport_protocol) : Reassembly(icn_socket, transport_protocol) {} -void DatagramReassembly::reassemble(core::ContentObject::Ptr&& content_object) { - read_buffer_ = content_object->getPayload(); +void DatagramReassembly::reassemble(core::ContentObject& content_object) { + read_buffer_ = content_object.getPayload(); Reassembly::notifyApplication(); } diff --git a/libtransport/src/protocols/datagram_reassembly.h b/libtransport/src/protocols/datagram_reassembly.h index 2427ae62f..3462212d3 100644 --- a/libtransport/src/protocols/datagram_reassembly.h +++ b/libtransport/src/protocols/datagram_reassembly.h @@ -26,7 +26,7 @@ class DatagramReassembly : public Reassembly { DatagramReassembly(implementation::ConsumerSocket *icn_socket, TransportProtocol *transport_protocol); - virtual void reassemble(core::ContentObject::Ptr &&content_object) override; + virtual void reassemble(core::ContentObject &content_object) override; virtual void reInitialize() override; virtual void reassemble( std::unique_ptr<core::ContentObjectManifest> &&manifest) override { diff --git a/libtransport/src/protocols/errors.cc b/libtransport/src/protocols/errors.cc index eefb6f957..ae7b6e634 100644 --- a/libtransport/src/protocols/errors.cc +++ b/libtransport/src/protocols/errors.cc @@ -52,7 +52,9 @@ std::string protocol_category_impl::message(int ev) const { case protocol_error::session_aborted: { return "The session has been aborted by the application."; } - default: { return "Unknown protocol error"; } + default: { + return "Unknown protocol error"; + } } } diff --git a/libtransport/src/protocols/fec_base.h b/libtransport/src/protocols/fec_base.h new file mode 100644 index 000000000..a135c474f --- /dev/null +++ b/libtransport/src/protocols/fec_base.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/content_object.h> + +#include <functional> + +namespace transport { +namespace protocol { + +/** + * Interface classes to integrate FEC inside any producer transport protocol + */ +class ProducerFECBase { + public: + /** + * Callback, to be called by implementations as soon as a repair packet is + * ready. + */ + using RepairPacketsReady = + std::function<void(std::vector<core::ContentObject::Ptr> &)>; + + /** + * Producers will call this function upon production of a new packet. + */ + virtual void onPacketProduced(const core::ContentObject &content_object) = 0; + + /** + * Set callback to signal production protocol the repair packet is ready. + */ + void setFECCallback(const RepairPacketsReady &on_repair_packet) { + rep_packet_ready_callback_ = on_repair_packet; + } + + protected: + RepairPacketsReady rep_packet_ready_callback_; +}; + +/** + * Interface classes to integrate FEC inside any consumer transport protocol + */ +class ConsumerFECBase { + public: + /** + * Callback, to be called by implemrntations as soon as a packet is recovered. + */ + using OnPacketsRecovered = + std::function<void(std::vector<core::ContentObject::Ptr> &)>; + + /** + * Consumers will call this function when they receive a FEC packet. + */ + virtual void onFECPacket(const core::ContentObject &content_object) = 0; + + /** + * Consumers will call this function when they receive a data packet + */ + virtual void onDataPacket(const core::ContentObject &content_object) = 0; + + /** + * Set callback to signal consumer protocol the repair packet is ready. + */ + void setFECCallback(const OnPacketsRecovered &on_repair_packet) { + packet_recovered_callback_ = on_repair_packet; + } + + protected: + OnPacketsRecovered packet_recovered_callback_; +}; + +} // namespace protocol +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/protocols/incremental_indexer.cc b/libtransport/src/protocols/incremental_indexer.cc index 0872c4554..95daa0a3e 100644 --- a/libtransport/src/protocols/incremental_indexer.cc +++ b/libtransport/src/protocols/incremental_indexer.cc @@ -13,37 +13,38 @@ * limitations under the License. */ -#include <protocols/incremental_indexer.h> - #include <hicn/transport/interfaces/socket_consumer.h> -#include <protocols/protocol.h> +#include <protocols/errors.h> +#include <protocols/incremental_indexer.h> +#include <protocols/transport_protocol.h> namespace transport { namespace protocol { -void IncrementalIndexer::onContentObject( - core::Interest::Ptr &&interest, core::ContentObject::Ptr &&content_object) { +void IncrementalIndexer::onContentObject(core::Interest &interest, + core::ContentObject &content_object) { using namespace interface; - TRANSPORT_LOGD("Receive content %s", content_object->getName().toString().c_str()); + TRANSPORT_LOGD("Received content %s", + content_object.getName().toString().c_str()); - if (TRANSPORT_EXPECT_FALSE(content_object->testRst())) { - final_suffix_ = content_object->getName().getSuffix(); + if (TRANSPORT_EXPECT_FALSE(content_object.testRst())) { + final_suffix_ = content_object.getName().getSuffix(); } - auto ret = verification_manager_->onPacketToVerify(*content_object); + auto ret = verifier_->verifyPackets(&content_object); switch (ret) { - case VerificationPolicy::ACCEPT_PACKET: { - reassembly_->reassemble(std::move(content_object)); + case auth::VerificationPolicy::ACCEPT: { + reassembly_->reassemble(content_object); break; } - case VerificationPolicy::DROP_PACKET: { - transport_protocol_->onPacketDropped(std::move(interest), - std::move(content_object)); + case auth::VerificationPolicy::UNKNOWN: + case auth::VerificationPolicy::DROP: { + transport_protocol_->onPacketDropped(interest, content_object); break; } - case VerificationPolicy::ABORT_SESSION: { + case auth::VerificationPolicy::ABORT: { transport_protocol_->onContentReassembled( make_error_code(protocol_error::session_aborted)); break; diff --git a/libtransport/src/protocols/incremental_indexer.h b/libtransport/src/protocols/incremental_indexer.h index 20c5e4759..d7760f8e6 100644 --- a/libtransport/src/protocols/incremental_indexer.h +++ b/libtransport/src/protocols/incremental_indexer.h @@ -15,13 +15,13 @@ #pragma once -#include <hicn/transport/errors/runtime_exception.h> -#include <hicn/transport/errors/unexpected_manifest_exception.h> +#include <hicn/transport/errors/errors.h> +#include <hicn/transport/interfaces/callbacks.h> +#include <hicn/transport/auth/verifier.h> #include <hicn/transport/utils/literals.h> - +#include <implementation/socket_consumer.h> #include <protocols/indexer.h> #include <protocols/reassembly.h> -#include <protocols/verification_manager.h> #include <deque> @@ -47,11 +47,12 @@ class IncrementalIndexer : public Indexer { first_suffix_(0), next_download_suffix_(0), next_reassembly_suffix_(0), - verification_manager_( - std::make_unique<SignatureVerificationManager>(icn_socket)) { + verifier_(nullptr) { if (reassembly_) { reassembly_->setIndexer(this); } + socket_->getSocketOption(implementation::GeneralTransportOptions::VERIFIER, + verifier_); } IncrementalIndexer(const IncrementalIndexer &) = delete; @@ -64,15 +65,14 @@ class IncrementalIndexer : public Indexer { first_suffix_(other.first_suffix_), next_download_suffix_(other.next_download_suffix_), next_reassembly_suffix_(other.next_reassembly_suffix_), - verification_manager_(std::move(other.verification_manager_)) { + verifier_(nullptr) { if (reassembly_) { reassembly_->setIndexer(this); } + socket_->getSocketOption(implementation::GeneralTransportOptions::VERIFIER, + verifier_); } - /** - * - */ virtual ~IncrementalIndexer() {} TRANSPORT_ALWAYS_INLINE virtual void reset( @@ -112,8 +112,8 @@ class IncrementalIndexer : public Indexer { return final_suffix_; } - void onContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object) override; + void onContentObject(core::Interest &interest, + core::ContentObject &content_object) override; TRANSPORT_ALWAYS_INLINE void setReassembly(Reassembly *reassembly) { reassembly_ = reassembly; @@ -123,10 +123,6 @@ class IncrementalIndexer : public Indexer { } } - TRANSPORT_ALWAYS_INLINE bool onKeyToVerify() override { - return verification_manager_->onKeyToVerify(); - } - protected: implementation::ConsumerSocket *socket_; Reassembly *reassembly_; @@ -135,9 +131,8 @@ class IncrementalIndexer : public Indexer { uint32_t first_suffix_; uint32_t next_download_suffix_; uint32_t next_reassembly_suffix_; - std::unique_ptr<VerificationManager> verification_manager_; + std::shared_ptr<auth::Verifier> verifier_; }; -} // end namespace protocol - -} // end namespace transport +} // namespace protocol +} // namespace transport diff --git a/libtransport/src/protocols/indexer.cc b/libtransport/src/protocols/indexer.cc index ca12330a6..1379a609c 100644 --- a/libtransport/src/protocols/indexer.cc +++ b/libtransport/src/protocols/indexer.cc @@ -14,11 +14,9 @@ */ #include <hicn/transport/utils/branch_prediction.h> - #include <protocols/incremental_indexer.h> #include <protocols/indexer.h> #include <protocols/manifest_incremental_indexer.h> -#include <protocols/protocol.h> namespace transport { namespace protocol { @@ -32,16 +30,16 @@ IndexManager::IndexManager(implementation::ConsumerSocket *icn_socket, transport_(transport), reassembly_(reassembly) {} -void IndexManager::onContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object) { +void IndexManager::onContentObject(core::Interest &interest, + core::ContentObject &content_object) { if (first_segment_received_) { - indexer_->onContentObject(std::move(interest), std::move(content_object)); + indexer_->onContentObject(interest, content_object); } else { - std::uint32_t segment_number = interest->getName().getSuffix(); + std::uint32_t segment_number = interest.getName().getSuffix(); if (segment_number == 0) { // Check if manifest - if (content_object->getPayloadType() == PayloadType::MANIFEST) { + if (content_object.getPayloadType() == core::PayloadType::MANIFEST) { IncrementalIndexer *indexer = static_cast<IncrementalIndexer *>(indexer_.release()); indexer_ = @@ -49,25 +47,21 @@ void IndexManager::onContentObject(core::Interest::Ptr &&interest, delete indexer; } - indexer_->onContentObject(std::move(interest), std::move(content_object)); + indexer_->onContentObject(interest, content_object); auto it = interest_data_set_.begin(); while (it != interest_data_set_.end()) { - indexer_->onContentObject( - std::move(const_cast<core::Interest::Ptr &&>(it->first)), - std::move(const_cast<core::ContentObject::Ptr &&>(it->second))); + indexer_->onContentObject(*it->first, *it->second); it = interest_data_set_.erase(it); } first_segment_received_ = true; } else { - interest_data_set_.emplace(std::move(interest), - std::move(content_object)); + interest_data_set_.emplace(interest.shared_from_this(), + content_object.shared_from_this()); } } } -bool IndexManager::onKeyToVerify() { return indexer_->onKeyToVerify(); } - void IndexManager::reset(std::uint32_t offset) { indexer_ = std::make_unique<IncrementalIndexer>(icn_socket_, transport_, reassembly_); diff --git a/libtransport/src/protocols/indexer.h b/libtransport/src/protocols/indexer.h index 8213a1503..49e22a4cf 100644 --- a/libtransport/src/protocols/indexer.h +++ b/libtransport/src/protocols/indexer.h @@ -33,10 +33,8 @@ class TransportProtocol; class Indexer { public: - /** - * - */ virtual ~Indexer() = default; + /** * Retrieve from the manifest the next suffix to retrieve. */ @@ -55,10 +53,8 @@ class Indexer { virtual void reset(std::uint32_t offset = 0) = 0; - virtual void onContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object) = 0; - - virtual bool onKeyToVerify() = 0; + virtual void onContentObject(core::Interest &interest, + core::ContentObject &content_object) = 0; }; class IndexManager : Indexer { @@ -86,10 +82,8 @@ class IndexManager : Indexer { void reset(std::uint32_t offset = 0) override; - void onContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object) override; - - bool onKeyToVerify() override; + void onContentObject(core::Interest &interest, + core::ContentObject &content_object) override; private: std::unique_ptr<Indexer> indexer_; diff --git a/libtransport/src/protocols/manifest_incremental_indexer.cc b/libtransport/src/protocols/manifest_incremental_indexer.cc index da835b577..a6312ca90 100644 --- a/libtransport/src/protocols/manifest_incremental_indexer.cc +++ b/libtransport/src/protocols/manifest_incremental_indexer.cc @@ -14,9 +14,9 @@ */ #include <implementation/socket_consumer.h> - +#include <protocols/errors.h> #include <protocols/manifest_incremental_indexer.h> -#include <protocols/protocol.h> +#include <protocols/transport_protocol.h> #include <cmath> #include <deque> @@ -36,41 +36,46 @@ ManifestIncrementalIndexer::ManifestIncrementalIndexer( 0)) {} void ManifestIncrementalIndexer::onContentObject( - core::Interest::Ptr &&interest, core::ContentObject::Ptr &&content_object) { - // Check if manifest or not - if (content_object->getPayloadType() == PayloadType::MANIFEST) { - TRANSPORT_LOGD("Receive content %s", content_object->getName().toString().c_str()); - onUntrustedManifest(std::move(interest), std::move(content_object)); - } else if (content_object->getPayloadType() == PayloadType::CONTENT_OBJECT) { - TRANSPORT_LOGD("Receive manifest %s", content_object->getName().toString().c_str()); - onUntrustedContentObject(std::move(interest), std::move(content_object)); - } -} - -void ManifestIncrementalIndexer::onUntrustedManifest( - core::Interest::Ptr &&interest, core::ContentObject::Ptr &&content_object) { - auto ret = verification_manager_->onPacketToVerify(*content_object); - - switch (ret) { - case VerificationPolicy::ACCEPT_PACKET: { - processTrustedManifest(std::move(content_object)); + core::Interest &interest, core::ContentObject &content_object) { + switch (content_object.getPayloadType()) { + case PayloadType::DATA: { + TRANSPORT_LOGD("Received content %s", + content_object.getName().toString().c_str()); + onUntrustedContentObject(interest, content_object); break; } - case VerificationPolicy::DROP_PACKET: - case VerificationPolicy::ABORT_SESSION: { - transport_protocol_->onContentReassembled( - make_error_code(protocol_error::session_aborted)); + case PayloadType::MANIFEST: { + TRANSPORT_LOGD("Received manifest %s", + content_object.getName().toString().c_str()); + onUntrustedManifest(interest, content_object); break; } + default: { + return; + } } } -void ManifestIncrementalIndexer::processTrustedManifest( - ContentObject::Ptr &&content_object) { +void ManifestIncrementalIndexer::onUntrustedManifest( + core::Interest &interest, core::ContentObject &content_object) { auto manifest = - std::make_unique<ContentObjectManifest>(std::move(*content_object)); + std::make_unique<ContentObjectManifest>(std::move(content_object)); + + auth::VerificationPolicy policy = verifier_->verifyPackets(manifest.get()); + manifest->decode(); + if (policy != auth::VerificationPolicy::ACCEPT) { + transport_protocol_->onContentReassembled( + make_error_code(protocol_error::session_aborted)); + return; + } + + processTrustedManifest(interest, std::move(manifest)); +} + +void ManifestIncrementalIndexer::processTrustedManifest( + core::Interest &interest, std::unique_ptr<ContentObjectManifest> manifest) { if (TRANSPORT_EXPECT_FALSE(manifest->getVersion() != core::ManifestVersion::VERSION_1)) { throw errors::RuntimeException("Received manifest with unknown version."); @@ -78,23 +83,45 @@ void ManifestIncrementalIndexer::processTrustedManifest( switch (manifest->getManifestType()) { case core::ManifestType::INLINE_MANIFEST: { - auto _it = manifest->getSuffixList().begin(); - auto _end = manifest->getSuffixList().end(); - suffix_strategy_->setFinalSuffix(manifest->getFinalBlockNumber()); - for (; _it != _end; _it++) { - auto hash = - std::make_pair(std::vector<uint8_t>(_it->second, _it->second + 32), - manifest->getHashAlgorithm()); + // The packets to verify with the received manifest + std::vector<auth::PacketPtr> packets; + + // Convert the received manifest to a map of packet suffixes to hashes + std::unordered_map<auth::Suffix, auth::HashEntry> current_manifest = + core::ContentObjectManifest::getSuffixMap(manifest.get()); + + // Update 'suffix_map_' with new hashes from the received manifest and + // build 'packets' + for (auto it = current_manifest.begin(); it != current_manifest.end();) { + if (unverified_segments_.find(it->first) == + unverified_segments_.end()) { + suffix_map_[it->first] = std::move(it->second); + current_manifest.erase(it++); + continue; + } - if (!checkUnverifiedSegments(_it->first, hash)) { - suffix_hash_map_[_it->first] = std::move(hash); + packets.push_back(unverified_segments_[it->first].second.get()); + it++; + } + + // Verify unverified segments using the received manifest + std::vector<auth::VerificationPolicy> policies = + verifier_->verifyPackets(packets, current_manifest); + + for (unsigned int i = 0; i < packets.size(); ++i) { + auth::Suffix suffix = packets[i]->getName().getSuffix(); + + if (policies[i] != auth::VerificationPolicy::UNKNOWN) { + unverified_segments_.erase(suffix); } + + applyPolicy(*unverified_segments_[suffix].first, + *unverified_segments_[suffix].second, policies[i]); } reassembly_->reassemble(std::move(manifest)); - break; } case core::ManifestType::FLIC_MANIFEST: { @@ -106,89 +133,47 @@ void ManifestIncrementalIndexer::processTrustedManifest( } } -bool ManifestIncrementalIndexer::checkUnverifiedSegments( - std::uint32_t suffix, const HashEntry &hash) { - auto it = unverified_segments_.find(suffix); - - if (it != unverified_segments_.end()) { - auto ret = verifyContentObject(hash, *it->second.second); - - switch (ret) { - case VerificationPolicy::ACCEPT_PACKET: { - reassembly_->reassemble(std::move(it->second.second)); - break; - } - case VerificationPolicy::DROP_PACKET: { - transport_protocol_->onPacketDropped(std::move(it->second.first), - std::move(it->second.second)); - break; - } - case VerificationPolicy::ABORT_SESSION: { - transport_protocol_->onContentReassembled( - make_error_code(protocol_error::session_aborted)); - break; - } +void ManifestIncrementalIndexer::onUntrustedContentObject( + Interest &interest, ContentObject &content_object) { + auth::Suffix suffix = content_object.getName().getSuffix(); + auth::VerificationPolicy policy = + verifier_->verifyPackets(&content_object, suffix_map_); + + switch (policy) { + case auth::VerificationPolicy::UNKNOWN: { + unverified_segments_[suffix] = std::make_pair( + interest.shared_from_this(), content_object.shared_from_this()); + break; + } + default: { + suffix_map_.erase(suffix); + break; } - - unverified_segments_.erase(it); - return true; - } - - return false; -} - -VerificationPolicy ManifestIncrementalIndexer::verifyContentObject( - const HashEntry &manifest_hash, const ContentObject &content_object) { - VerificationPolicy ret; - - auto hash_type = static_cast<utils::CryptoHashType>(manifest_hash.second); - auto data_packet_digest = content_object.computeDigest(manifest_hash.second); - auto data_packet_digest_bytes = - data_packet_digest.getDigest<uint8_t>().data(); - const std::vector<uint8_t> &manifest_digest_bytes = manifest_hash.first; - - if (utils::CryptoHash::compareBinaryDigest( - data_packet_digest_bytes, manifest_digest_bytes.data(), hash_type)) { - ret = VerificationPolicy::ACCEPT_PACKET; - } else { - ConsumerContentObjectVerificationFailedCallback - *verification_failed_callback = VOID_HANDLER; - socket_->getSocketOption(ConsumerCallbacksOptions::VERIFICATION_FAILED, - &verification_failed_callback); - ret = (*verification_failed_callback)( - *socket_->getInterface(), content_object, - make_error_code(protocol_error::integrity_verification_failed)); } - return ret; + applyPolicy(interest, content_object, policy); } -void ManifestIncrementalIndexer::onUntrustedContentObject( - Interest::Ptr &&i, ContentObject::Ptr &&c) { - auto suffix = c->getName().getSuffix(); - auto it = suffix_hash_map_.find(suffix); - - if (it != suffix_hash_map_.end()) { - auto ret = verifyContentObject(it->second, *c); - - switch (ret) { - case VerificationPolicy::ACCEPT_PACKET: { - suffix_hash_map_.erase(it); - reassembly_->reassemble(std::move(c)); - break; - } - case VerificationPolicy::DROP_PACKET: { - transport_protocol_->onPacketDropped(std::move(i), std::move(c)); - break; - } - case VerificationPolicy::ABORT_SESSION: { - transport_protocol_->onContentReassembled( - make_error_code(protocol_error::session_aborted)); - break; - } +void ManifestIncrementalIndexer::applyPolicy( + core::Interest &interest, core::ContentObject &content_object, + auth::VerificationPolicy policy) { + switch (policy) { + case auth::VerificationPolicy::ACCEPT: { + reassembly_->reassemble(content_object); + break; + } + case auth::VerificationPolicy::DROP: { + transport_protocol_->onPacketDropped(interest, content_object); + break; + } + case auth::VerificationPolicy::ABORT: { + transport_protocol_->onContentReassembled( + make_error_code(protocol_error::session_aborted)); + break; + } + default: { + break; } - } else { - unverified_segments_[suffix] = std::make_pair(std::move(i), std::move(c)); } } @@ -224,7 +209,7 @@ uint32_t ManifestIncrementalIndexer::getNextReassemblySegment() { void ManifestIncrementalIndexer::reset(std::uint32_t offset) { IncrementalIndexer::reset(offset); - suffix_hash_map_.clear(); + suffix_map_.clear(); unverified_segments_.clear(); SuffixQueue empty; std::swap(suffix_queue_, empty); diff --git a/libtransport/src/protocols/manifest_incremental_indexer.h b/libtransport/src/protocols/manifest_incremental_indexer.h index 38b01533e..1bb76eb87 100644 --- a/libtransport/src/protocols/manifest_incremental_indexer.h +++ b/libtransport/src/protocols/manifest_incremental_indexer.h @@ -15,6 +15,7 @@ #pragma once +#include <hicn/transport/auth/common.h> #include <implementation/socket.h> #include <protocols/incremental_indexer.h> #include <utils/suffix_strategy.h> @@ -22,7 +23,6 @@ #include <list> namespace transport { - namespace protocol { class ManifestIncrementalIndexer : public IncrementalIndexer { @@ -30,7 +30,8 @@ class ManifestIncrementalIndexer : public IncrementalIndexer { public: using SuffixQueue = std::queue<uint32_t>; - using HashEntry = std::pair<std::vector<uint8_t>, utils::CryptoHashType>; + using InterestContentPair = + std::pair<core::Interest::Ptr, core::ContentObject::Ptr>; ManifestIncrementalIndexer(implementation::ConsumerSocket *icn_socket, TransportProtocol *transport, @@ -50,8 +51,8 @@ class ManifestIncrementalIndexer : public IncrementalIndexer { void reset(std::uint32_t offset = 0) override; - void onContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object) override; + void onContentObject(core::Interest &interest, + core::ContentObject &content_object) override; uint32_t getNextSuffix() override; @@ -61,30 +62,24 @@ class ManifestIncrementalIndexer : public IncrementalIndexer { uint32_t getFinalSuffix() override; - private: - void onUntrustedManifest(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object); - void onUntrustedContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object); - void processTrustedManifest(core::ContentObject::Ptr &&content_object); - void onManifestReceived(core::Interest::Ptr &&i, - core::ContentObject::Ptr &&c); - void onManifestTimeout(core::Interest::Ptr &&i); - VerificationPolicy verifyContentObject( - const HashEntry &manifest_hash, - const core::ContentObject &content_object); - bool checkUnverifiedSegments(std::uint32_t suffix, const HashEntry &hash); - protected: std::unique_ptr<utils::SuffixStrategy> suffix_strategy_; SuffixQueue suffix_queue_; // Hash verification - std::unordered_map<uint32_t, HashEntry> suffix_hash_map_; + std::unordered_map<auth::Suffix, auth::HashEntry> suffix_map_; + std::unordered_map<auth::Suffix, InterestContentPair> unverified_segments_; - std::unordered_map<uint32_t, - std::pair<core::Interest::Ptr, core::ContentObject::Ptr>> - unverified_segments_; + private: + void onUntrustedManifest(core::Interest &interest, + core::ContentObject &content_object); + void processTrustedManifest(core::Interest &interest, + std::unique_ptr<ContentObjectManifest> manifest); + void onUntrustedContentObject(core::Interest &interest, + core::ContentObject &content_object); + void applyPolicy(core::Interest &interest, + core::ContentObject &content_object, + auth::VerificationPolicy policy); }; } // end namespace protocol diff --git a/libtransport/src/protocols/prod_protocol_bytestream.cc b/libtransport/src/protocols/prod_protocol_bytestream.cc new file mode 100644 index 000000000..6bd989fe4 --- /dev/null +++ b/libtransport/src/protocols/prod_protocol_bytestream.cc @@ -0,0 +1,390 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <implementation/socket_producer.h> +#include <protocols/prod_protocol_bytestream.h> + +#include <atomic> + +namespace transport { + +namespace protocol { + +using namespace core; +using namespace implementation; + +ByteStreamProductionProtocol::ByteStreamProductionProtocol( + implementation::ProducerSocket *icn_socket) + : ProductionProtocol(icn_socket) {} + +ByteStreamProductionProtocol::~ByteStreamProductionProtocol() { + stop(); + if (listening_thread_.joinable()) { + listening_thread_.join(); + } +} + +uint32_t ByteStreamProductionProtocol::produceDatagram( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer) { + throw errors::NotImplementedException(); +} + +uint32_t ByteStreamProductionProtocol::produceDatagram(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size) { + throw errors::NotImplementedException(); +} + +uint32_t ByteStreamProductionProtocol::produceStream(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size, + bool is_last, + uint32_t start_offset) { + if (!buffer_size) { + return 0; + } + + return produceStream(content_name, + utils::MemBuf::copyBuffer(buffer, buffer_size), is_last, + start_offset); +} + +uint32_t ByteStreamProductionProtocol::produceStream( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { + if (TRANSPORT_EXPECT_FALSE(buffer->length() == 0)) { + return 0; + } + + Name name(content_name); + + // Get the atomic variables to ensure they keep the same value + // during the production + + // Total size of the data packet + uint32_t data_packet_size; + socket_->getSocketOption(GeneralTransportOptions::DATA_PACKET_SIZE, + data_packet_size); + + // Expiry time + uint32_t content_object_expiry_time; + socket_->getSocketOption(GeneralTransportOptions::CONTENT_OBJECT_EXPIRY_TIME, + content_object_expiry_time); + + // Hash algorithm + auth::CryptoHashType hash_algo; + socket_->getSocketOption(GeneralTransportOptions::HASH_ALGORITHM, hash_algo); + + // Use manifest + bool making_manifest; + socket_->getSocketOption(GeneralTransportOptions::MAKE_MANIFEST, + making_manifest); + + // Suffix calculation strategy + core::NextSegmentCalculationStrategy _suffix_strategy; + socket_->getSocketOption(GeneralTransportOptions::SUFFIX_STRATEGY, + _suffix_strategy); + auto suffix_strategy = utils::SuffixStrategyFactory::getSuffixStrategy( + _suffix_strategy, start_offset); + + std::shared_ptr<auth::Signer> signer; + socket_->getSocketOption(GeneralTransportOptions::SIGNER, signer); + + auto buffer_size = buffer->length(); + int bytes_segmented = 0; + std::size_t header_size; + std::size_t manifest_header_size = 0; + std::size_t signature_length = 0; + std::uint32_t final_block_number = start_offset; + uint64_t free_space_for_content = 0; + + core::Packet::Format format; + std::shared_ptr<core::ContentObjectManifest> manifest; + bool is_last_manifest = false; + + // TODO Manifest may still be used for indexing + if (making_manifest && !signer) { + TRANSPORT_LOGE("Making manifests without setting producer identity."); + } + + core::Packet::Format hf_format = core::Packet::Format::HF_UNSPEC; + core::Packet::Format hf_format_ah = core::Packet::Format::HF_UNSPEC; + + if (name.getType() == HNT_CONTIGUOUS_V4 || name.getType() == HNT_IOV_V4) { + hf_format = core::Packet::Format::HF_INET_TCP; + hf_format_ah = core::Packet::Format::HF_INET_TCP_AH; + } else if (name.getType() == HNT_CONTIGUOUS_V6 || + name.getType() == HNT_IOV_V6) { + hf_format = core::Packet::Format::HF_INET6_TCP; + hf_format_ah = core::Packet::Format::HF_INET6_TCP_AH; + } else { + throw errors::RuntimeException("Unknown name format."); + } + + format = hf_format; + if (making_manifest) { + manifest_header_size = core::Packet::getHeaderSizeFromFormat( + signer ? hf_format_ah : hf_format, + signer ? signer->getSignatureSize() : 0); + } else if (signer) { + format = hf_format_ah; + signature_length = signer->getSignatureSize(); + } + + header_size = core::Packet::getHeaderSizeFromFormat(format, signature_length); + free_space_for_content = data_packet_size - header_size; + uint32_t number_of_segments = + uint32_t(std::ceil(double(buffer_size) / double(free_space_for_content))); + if (free_space_for_content * number_of_segments < buffer_size) { + number_of_segments++; + } + + // TODO allocate space for all the headers + if (making_manifest) { + uint32_t segment_in_manifest = static_cast<uint32_t>( + std::floor(double(data_packet_size - manifest_header_size - + ContentObjectManifest::getManifestHeaderSize()) / + ContentObjectManifest::getManifestEntrySize()) - + 1.0); + uint32_t number_of_manifests = static_cast<uint32_t>( + std::ceil(float(number_of_segments) / segment_in_manifest)); + final_block_number += number_of_segments + number_of_manifests - 1; + + manifest.reset(ContentObjectManifest::createManifest( + name.setSuffix(suffix_strategy->getNextManifestSuffix()), + core::ManifestVersion::VERSION_1, core::ManifestType::INLINE_MANIFEST, + hash_algo, is_last_manifest, name, _suffix_strategy, + signer ? signer->getSignatureSize() : 0)); + manifest->setLifetime(content_object_expiry_time); + + if (is_last) { + manifest->setFinalBlockNumber(final_block_number); + } else { + manifest->setFinalBlockNumber(utils::SuffixStrategy::INVALID_SUFFIX); + } + } + + for (unsigned int packaged_segments = 0; + packaged_segments < number_of_segments; packaged_segments++) { + if (making_manifest) { + if (manifest->estimateManifestSize(2) > + data_packet_size - manifest_header_size) { + manifest->encode(); + + // If identity set, sign manifest + if (signer) { + signer->signPacket(manifest.get()); + } + + // Send the current manifest + passContentObjectToCallbacks(manifest); + + TRANSPORT_LOGD("Send manifest %s", + manifest->getName().toString().c_str()); + + // Send content objects stored in the queue + while (!content_queue_.empty()) { + passContentObjectToCallbacks(content_queue_.front()); + TRANSPORT_LOGD("Send content %s", + content_queue_.front()->getName().toString().c_str()); + content_queue_.pop(); + } + + // Create new manifest. The reference to the last manifest has been + // acquired in the passContentObjectToCallbacks function, so we can + // safely release this reference + manifest.reset(ContentObjectManifest::createManifest( + name.setSuffix(suffix_strategy->getNextManifestSuffix()), + core::ManifestVersion::VERSION_1, + core::ManifestType::INLINE_MANIFEST, hash_algo, is_last_manifest, + name, _suffix_strategy, signer ? signer->getSignatureSize() : 0)); + + manifest->setLifetime(content_object_expiry_time); + manifest->setFinalBlockNumber( + is_last ? final_block_number + : utils::SuffixStrategy::INVALID_SUFFIX); + } + } + + auto content_suffix = suffix_strategy->getNextContentSuffix(); + auto content_object = std::make_shared<ContentObject>( + name.setSuffix(content_suffix), format, + signer && !making_manifest ? signer->getSignatureSize() : 0); + content_object->setLifetime(content_object_expiry_time); + + auto b = buffer->cloneOne(); + b->trimStart(free_space_for_content * packaged_segments); + b->trimEnd(b->length()); + + if (TRANSPORT_EXPECT_FALSE(packaged_segments == number_of_segments - 1)) { + b->append(buffer_size - bytes_segmented); + bytes_segmented += (int)(buffer_size - bytes_segmented); + + if (is_last && making_manifest) { + is_last_manifest = true; + } else if (is_last) { + content_object->setRst(); + } + + } else { + b->append(free_space_for_content); + bytes_segmented += (int)(free_space_for_content); + } + + content_object->appendPayload(std::move(b)); + + if (making_manifest) { + using namespace std::chrono_literals; + auth::CryptoHash hash = content_object->computeDigest(hash_algo); + manifest->addSuffixHash(content_suffix, hash); + content_queue_.push(content_object); + } else { + if (signer) { + signer->signPacket(content_object.get()); + } + passContentObjectToCallbacks(content_object); + TRANSPORT_LOGD("Send content %s", + content_object->getName().toString().c_str()); + } + } + + if (making_manifest) { + if (is_last_manifest) { + manifest->setFinalManifest(is_last_manifest); + } + + manifest->encode(); + + if (signer) { + signer->signPacket(manifest.get()); + } + + passContentObjectToCallbacks(manifest); + TRANSPORT_LOGD("Send manifest %s", manifest->getName().toString().c_str()); + + while (!content_queue_.empty()) { + passContentObjectToCallbacks(content_queue_.front()); + TRANSPORT_LOGD("Send content %s", + content_queue_.front()->getName().toString().c_str()); + content_queue_.pop(); + } + } + + portal_->getIoService().post([this]() { + std::shared_ptr<ContentObject> co; + while (object_queue_for_callbacks_.pop(co)) { + if (*on_new_segment_) { + on_new_segment_->operator()(*socket_->getInterface(), *co); + } + + if (*on_content_object_to_sign_) { + on_content_object_to_sign_->operator()(*socket_->getInterface(), *co); + } + + if (*on_content_object_in_output_buffer_) { + on_content_object_in_output_buffer_->operator()( + *socket_->getInterface(), *co); + } + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), *co); + } + } + }); + + portal_->getIoService().dispatch([this, buffer_size]() { + if (*on_content_produced_) { + on_content_produced_->operator()(*socket_->getInterface(), + std::make_error_code(std::errc(0)), + buffer_size); + } + }); + + return suffix_strategy->getTotalCount(); +} + +void ByteStreamProductionProtocol::scheduleSendBurst() { + portal_->getIoService().post([this]() { + std::shared_ptr<ContentObject> co; + + for (uint32_t i = 0; i < burst_size; i++) { + if (object_queue_for_callbacks_.pop(co)) { + if (*on_new_segment_) { + on_new_segment_->operator()(*socket_->getInterface(), *co); + } + + if (*on_content_object_to_sign_) { + on_content_object_to_sign_->operator()(*socket_->getInterface(), *co); + } + + if (*on_content_object_in_output_buffer_) { + on_content_object_in_output_buffer_->operator()( + *socket_->getInterface(), *co); + } + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), *co); + } + } else { + break; + } + } + }); +} + +void ByteStreamProductionProtocol::passContentObjectToCallbacks( + const std::shared_ptr<ContentObject> &content_object) { + output_buffer_.insert(content_object); + portal_->sendContentObject(*content_object); + object_queue_for_callbacks_.push(std::move(content_object)); + + if (object_queue_for_callbacks_.size() >= burst_size) { + scheduleSendBurst(); + } +} + +void ByteStreamProductionProtocol::onInterest(Interest &interest) { + TRANSPORT_LOGD("Received interest for %s", + interest.getName().toString().c_str()); + if (*on_interest_input_) { + on_interest_input_->operator()(*socket_->getInterface(), interest); + } + + const std::shared_ptr<ContentObject> content_object = + output_buffer_.find(interest); + + if (content_object) { + if (*on_interest_satisfied_output_buffer_) { + on_interest_satisfied_output_buffer_->operator()(*socket_->getInterface(), + interest); + } + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), + *content_object); + } + + portal_->sendContentObject(*content_object); + } else { + if (*on_interest_process_) { + on_interest_process_->operator()(*socket_->getInterface(), interest); + } + } +} + +void ByteStreamProductionProtocol::onError(std::error_code ec) {} + +} // namespace protocol +} // end namespace transport diff --git a/libtransport/src/protocols/prod_protocol_bytestream.h b/libtransport/src/protocols/prod_protocol_bytestream.h new file mode 100644 index 000000000..cf36b90a5 --- /dev/null +++ b/libtransport/src/protocols/prod_protocol_bytestream.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/utils/ring_buffer.h> +#include <protocols/production_protocol.h> + +#include <atomic> +#include <queue> + +namespace transport { + +namespace protocol { + +using namespace core; + +class ByteStreamProductionProtocol : public ProductionProtocol { + static constexpr uint32_t burst_size = 256; + + public: + ByteStreamProductionProtocol(implementation::ProducerSocket *icn_socket); + + ~ByteStreamProductionProtocol() override; + + using ProductionProtocol::start; + using ProductionProtocol::stop; + + uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) override; + uint32_t produceStream(const Name &content_name, const uint8_t *buffer, + size_t buffer_size, bool is_last = true, + uint32_t start_offset = 0) override; + uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) override; + uint32_t produceDatagram(const Name &content_name, const uint8_t *buffer, + size_t buffer_size) override; + + protected: + // Consumer Callback + // void reset() override; + void onInterest(core::Interest &i) override; + void onError(std::error_code ec) override; + + private: + void passContentObjectToCallbacks( + const std::shared_ptr<ContentObject> &content_object); + void scheduleSendBurst(); + + private: + // While manifests are being built, contents are stored in a queue + std::queue<std::shared_ptr<ContentObject>> content_queue_; + utils::CircularFifo<std::shared_ptr<ContentObject>, 2048> + object_queue_for_callbacks_; +}; + +} // end namespace protocol +} // end namespace transport diff --git a/libtransport/src/protocols/prod_protocol_rtc.cc b/libtransport/src/protocols/prod_protocol_rtc.cc new file mode 100644 index 000000000..8081923e3 --- /dev/null +++ b/libtransport/src/protocols/prod_protocol_rtc.cc @@ -0,0 +1,481 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/core/global_object_pool.h> +#include <implementation/socket_producer.h> +#include <protocols/prod_protocol_rtc.h> +#include <protocols/rtc/rtc_consts.h> +#include <stdlib.h> +#include <time.h> + +#include <unordered_set> + +namespace transport { +namespace protocol { + +RTCProductionProtocol::RTCProductionProtocol( + implementation::ProducerSocket *icn_socket) + : ProductionProtocol(icn_socket), + current_seg_(1), + produced_bytes_(0), + produced_packets_(0), + max_packet_production_(1), + bytes_production_rate_(0), + packets_production_rate_(0), + last_round_(std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count()), + allow_delayed_nacks_(false), + queue_timer_on_(false), + consumer_in_sync_(false), + on_consumer_in_sync_(nullptr) { + srand((unsigned int)time(NULL)); + prod_label_ = rand() % 256; + interests_queue_timer_ = + std::make_unique<asio::steady_timer>(portal_->getIoService()); + round_timer_ = std::make_unique<asio::steady_timer>(portal_->getIoService()); + setOutputBufferSize(10000); + scheduleRoundTimer(); +} + +RTCProductionProtocol::~RTCProductionProtocol() {} + +void RTCProductionProtocol::registerNamespaceWithNetwork( + const Prefix &producer_namespace) { + ProductionProtocol::registerNamespaceWithNetwork(producer_namespace); + + flow_name_ = producer_namespace.getName(); + auto family = flow_name_.getAddressFamily(); + + switch (family) { + case AF_INET6: + header_size_ = (uint32_t)Packet::getHeaderSizeFromFormat(HF_INET6_TCP); + break; + case AF_INET: + header_size_ = (uint32_t)Packet::getHeaderSizeFromFormat(HF_INET_TCP); + break; + default: + throw errors::RuntimeException("Unknown name format."); + } +} + +void RTCProductionProtocol::scheduleRoundTimer() { + round_timer_->expires_from_now( + std::chrono::milliseconds(rtc::PRODUCER_STATS_INTERVAL)); + round_timer_->async_wait([this](std::error_code ec) { + if (ec) return; + updateStats(); + }); +} + +void RTCProductionProtocol::updateStats() { + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint64_t duration = now - last_round_; + if (duration == 0) duration = 1; + double per_second = rtc::MILLI_IN_A_SEC / duration; + + uint32_t prev_packets_production_rate = packets_production_rate_; + + bytes_production_rate_ = ceil((double)produced_bytes_ * per_second); + packets_production_rate_ = ceil((double)produced_packets_ * per_second); + + TRANSPORT_LOGD("Updating production rate: produced_bytes_ = %u bps = %u", + produced_bytes_, bytes_production_rate_); + + // update the production rate as soon as it increases by 10% with respect to + // the last round + max_packet_production_ = + produced_packets_ + ceil((double)produced_packets_ * 0.1); + if (max_packet_production_ < rtc::WIN_MIN) + max_packet_production_ = rtc::WIN_MIN; + + if (packets_production_rate_ != 0) { + allow_delayed_nacks_ = false; + } else if (prev_packets_production_rate == 0) { + // at least 2 rounds with production rate = 0 + allow_delayed_nacks_ = true; + } + + // check if the production rate is decreased. if yes send nacks if needed + if (prev_packets_production_rate < packets_production_rate_) { + sendNacksForPendingInterests(); + } + + produced_bytes_ = 0; + produced_packets_ = 0; + last_round_ = now; + scheduleRoundTimer(); +} + +uint32_t RTCProductionProtocol::produceStream( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { + throw errors::NotImplementedException(); +} + +uint32_t RTCProductionProtocol::produceStream(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size, bool is_last, + uint32_t start_offset) { + throw errors::NotImplementedException(); +} + +void RTCProductionProtocol::produce(ContentObject &content_object) { + throw errors::NotImplementedException(); +} + +uint32_t RTCProductionProtocol::produceDatagram( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer) { + std::size_t buffer_size = buffer->length(); + if (TRANSPORT_EXPECT_FALSE(buffer_size == 0)) return 0; + + uint32_t data_packet_size; + socket_->getSocketOption(interface::GeneralTransportOptions::DATA_PACKET_SIZE, + data_packet_size); + + if (TRANSPORT_EXPECT_FALSE((buffer_size + header_size_ + + rtc::DATA_HEADER_SIZE) > data_packet_size)) { + return 0; + } + + auto content_object = + core::PacketManager<>::getInstance().getPacket<ContentObject>(); + // add rtc header to the payload + struct rtc::data_packet_t header; + content_object->appendPayload((const uint8_t *)&header, + rtc::DATA_HEADER_SIZE); + content_object->appendPayload(buffer->data(), buffer->length()); + + std::shared_ptr<ContentObject> co = std::move(content_object); + + // schedule actual sending on internal thread + portal_->getIoService().dispatch( + [this, content_object{std::move(co)}, content_name]() mutable { + produceInternal(std::move(content_object), content_name); + }); + + return 1; +} + +void RTCProductionProtocol::produceInternal( + std::shared_ptr<ContentObject> &&content_object, const Name &content_name) { + // set rtc header + struct rtc::data_packet_t *data_pkt = + (struct rtc::data_packet_t *)content_object->getPayload()->data(); + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + data_pkt->setTimestamp(now); + data_pkt->setProductionRate(bytes_production_rate_); + + // set hicn stuff + Name n(content_name); + content_object->setName(n.setSuffix(current_seg_)); + content_object->setLifetime(500); // XXX this should be set by the APP + content_object->setPathLabel(prod_label_); + + // update stats + produced_bytes_ += + content_object->headerSize() + content_object->payloadSize(); + produced_packets_++; + + if (produced_packets_ >= max_packet_production_) { + // in this case all the pending interests may be used to accomodate the + // sudden increase in the production rate. calling the updateStats we will + // notify all the clients + round_timer_->cancel(); + updateStats(); + } + + TRANSPORT_LOGD("Sending content object: %s", n.toString().c_str()); + + output_buffer_.insert(content_object); + + if (*on_content_object_in_output_buffer_) { + on_content_object_in_output_buffer_->operator()(*socket_->getInterface(), + *content_object); + } + + portal_->sendContentObject(*content_object); + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), + *content_object); + } + + // remove interests from the interest cache if it exists + removeFromInterestQueue(current_seg_); + + current_seg_ = (current_seg_ + 1) % rtc::MIN_PROBE_SEQ; +} + +void RTCProductionProtocol::onInterest(Interest &interest) { + uint32_t interest_seg = interest.getName().getSuffix(); + uint32_t lifetime = interest.getLifetime(); + + if (interest_seg == 0) { + // first packet from the consumer, reset sync state + consumer_in_sync_ = false; + } + + if (*on_interest_input_) { + on_interest_input_->operator()(*socket_->getInterface(), interest); + } + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + if (interest_seg > rtc::MIN_PROBE_SEQ) { + TRANSPORT_LOGD("received probe %u", interest_seg); + sendNack(interest_seg); + return; + } + + TRANSPORT_LOGD("received interest %u", interest_seg); + + const std::shared_ptr<ContentObject> content_object = + output_buffer_.find(interest); + + if (content_object) { + if (*on_interest_satisfied_output_buffer_) { + on_interest_satisfied_output_buffer_->operator()(*socket_->getInterface(), + interest); + } + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), + *content_object); + } + + TRANSPORT_LOGD("Send content %u (onInterest)", + content_object->getName().getSuffix()); + portal_->sendContentObject(*content_object); + return; + } else { + if (*on_interest_process_) { + on_interest_process_->operator()(*socket_->getInterface(), interest); + } + } + + // if the production rate 0 use delayed nacks + if (allow_delayed_nacks_ && interest_seg >= current_seg_) { + uint64_t next_timer = ~0; + if (!timers_map_.empty()) { + next_timer = timers_map_.begin()->first; + } + + uint64_t expiration = now + rtc::SENTINEL_TIMER_INTERVAL; + addToInterestQueue(interest_seg, expiration); + + // here we have at least one interest in the queue, we need to start or + // update the timer + if (!queue_timer_on_) { + // set timeout + queue_timer_on_ = true; + scheduleQueueTimer(timers_map_.begin()->first - now); + } else { + // re-schedule the timer because a new interest will expires sooner + if (next_timer > timers_map_.begin()->first) { + interests_queue_timer_->cancel(); + scheduleQueueTimer(timers_map_.begin()->first - now); + } + } + return; + } + + if (queue_timer_on_) { + // the producer is producing. Send nacks to packets that will expire before + // the data production and remove the timer + queue_timer_on_ = false; + interests_queue_timer_->cancel(); + sendNacksForPendingInterests(); + } + + uint32_t max_gap = (uint32_t)floor( + (double)((double)((double)lifetime * + rtc::INTEREST_LIFETIME_REDUCTION_FACTOR / + rtc::MILLI_IN_A_SEC) * + (double)packets_production_rate_)); + + if (interest_seg < current_seg_ || interest_seg > (max_gap + current_seg_)) { + sendNack(interest_seg); + } else { + if (!consumer_in_sync_ && on_consumer_in_sync_) { + // we consider the remote consumer to be in sync as soon as it covers 70% + // of the production window with interests + uint32_t perc = ceil((double)max_gap * 0.7); + if (interest_seg > (perc + current_seg_)) { + consumer_in_sync_ = true; + on_consumer_in_sync_(*socket_->getInterface(), interest); + } + } + uint64_t expiration = + now + floor((double)lifetime * rtc::INTEREST_LIFETIME_REDUCTION_FACTOR); + addToInterestQueue(interest_seg, expiration); + } +} + +void RTCProductionProtocol::onError(std::error_code ec) {} + +void RTCProductionProtocol::scheduleQueueTimer(uint64_t wait) { + interests_queue_timer_->expires_from_now(std::chrono::milliseconds(wait)); + interests_queue_timer_->async_wait([this](std::error_code ec) { + if (ec) return; + interestQueueTimer(); + }); +} + +void RTCProductionProtocol::addToInterestQueue(uint32_t interest_seg, + uint64_t expiration) { + // check if the seq number exists already + auto it_seqs = seqs_map_.find(interest_seg); + if (it_seqs != seqs_map_.end()) { + // the seq already exists + if (expiration < it_seqs->second) { + // we need to update the timer becasue we got a smaller one + // 1) remove the entry from the multimap + // 2) update this entry + auto range = timers_map_.equal_range(it_seqs->second); + for (auto it_timers = range.first; it_timers != range.second; + it_timers++) { + if (it_timers->second == it_seqs->first) { + timers_map_.erase(it_timers); + break; + } + } + timers_map_.insert( + std::pair<uint64_t, uint32_t>(expiration, interest_seg)); + it_seqs->second = expiration; + } else { + // nothing to do here + return; + } + } else { + // add the new seq + timers_map_.insert(std::pair<uint64_t, uint32_t>(expiration, interest_seg)); + seqs_map_.insert(std::pair<uint32_t, uint64_t>(interest_seg, expiration)); + } +} + +void RTCProductionProtocol::sendNacksForPendingInterests() { + std::unordered_set<uint32_t> to_remove; + + uint32_t packet_gap = 100000; // set it to a high value (100sec) + if (packets_production_rate_ != 0) + packet_gap = ceil(rtc::MILLI_IN_A_SEC / (double)packets_production_rate_); + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + for (auto it = seqs_map_.begin(); it != seqs_map_.end(); it++) { + if (it->first > current_seg_) { + uint64_t production_time = + ((it->first - current_seg_) * packet_gap) + now; + if (production_time >= it->second) { + sendNack(it->first); + to_remove.insert(it->first); + } + } + } + + // delete nacked interests + for (auto it = to_remove.begin(); it != to_remove.end(); it++) { + removeFromInterestQueue(*it); + } +} + +void RTCProductionProtocol::removeFromInterestQueue(uint32_t interest_seg) { + auto seq_it = seqs_map_.find(interest_seg); + if (seq_it != seqs_map_.end()) { + auto range = timers_map_.equal_range(seq_it->second); + for (auto it_timers = range.first; it_timers != range.second; it_timers++) { + if (it_timers->second == seq_it->first) { + timers_map_.erase(it_timers); + break; + } + } + seqs_map_.erase(seq_it); + } +} + +void RTCProductionProtocol::interestQueueTimer() { + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + for (auto it_timers = timers_map_.begin(); it_timers != timers_map_.end();) { + uint64_t expire = it_timers->first; + if (expire <= now) { + uint32_t seq = it_timers->second; + sendNack(seq); + // remove the interest from the other map + seqs_map_.erase(seq); + it_timers = timers_map_.erase(it_timers); + } else { + // stop, we are done! + break; + } + } + if (timers_map_.empty()) { + queue_timer_on_ = false; + } else { + queue_timer_on_ = true; + scheduleQueueTimer(timers_map_.begin()->first - now); + } +} + +void RTCProductionProtocol::sendNack(uint32_t sequence) { + auto nack = core::PacketManager<>::getInstance().getPacket<ContentObject>(); + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint32_t next_packet = current_seg_; + uint32_t prod_rate = bytes_production_rate_; + + struct rtc::nack_packet_t header; + header.setTimestamp(now); + header.setProductionRate(prod_rate); + header.setProductionSegement(next_packet); + nack->appendPayload((const uint8_t *)&header, rtc::NACK_HEADER_SIZE); + + Name n(flow_name_); + n.setSuffix(sequence); + nack->setName(n); + nack->setLifetime(0); + nack->setPathLabel(prod_label_); + + if (!consumer_in_sync_ && on_consumer_in_sync_ && + sequence < rtc::MIN_PROBE_SEQ && sequence > next_packet) { + consumer_in_sync_ = true; + auto interest = core::PacketManager<>::getInstance().getPacket<Interest>(); + interest->setName(n); + on_consumer_in_sync_(*socket_->getInterface(), *interest); + } + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), *nack); + } + + TRANSPORT_LOGD("Send nack %u", sequence); + portal_->sendContentObject(*nack); +} + +} // namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/prod_protocol_rtc.h b/libtransport/src/protocols/prod_protocol_rtc.h new file mode 100644 index 000000000..f3584f74a --- /dev/null +++ b/libtransport/src/protocols/prod_protocol_rtc.h @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/name.h> +#include <protocols/production_protocol.h> + +#include <atomic> +#include <map> +#include <mutex> + +namespace transport { +namespace protocol { + +class RTCProductionProtocol : public ProductionProtocol { + public: + RTCProductionProtocol(implementation::ProducerSocket *icn_socket); + ~RTCProductionProtocol() override; + + using ProductionProtocol::start; + using ProductionProtocol::stop; + + void produce(ContentObject &content_object) override; + uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) override; + uint32_t produceStream(const Name &content_name, const uint8_t *buffer, + size_t buffer_size, bool is_last = true, + uint32_t start_offset = 0) override; + uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) override; + uint32_t produceDatagram(const Name &content_name, const uint8_t *buffer, + size_t buffer_size) override { + return produceDatagram(content_name, utils::MemBuf::wrapBuffer( + buffer, buffer_size, buffer_size)); + } + + void registerNamespaceWithNetwork(const Prefix &producer_namespace) override; + + void setConsumerInSyncCallback( + interface::ProducerInterestCallback &&callback) { + on_consumer_in_sync_ = std::move(callback); + } + + private: + // packet handlers + void onInterest(Interest &interest) override; + void onError(std::error_code ec) override; + void produceInternal(std::shared_ptr<ContentObject> &&content_object, + const Name &content_name); + void sendNack(uint32_t sequence); + + // stats + void updateStats(); + void scheduleRoundTimer(); + + // pending intersts functions + void addToInterestQueue(uint32_t interest_seg, uint64_t expiration); + void sendNacksForPendingInterests(); + void removeFromInterestQueue(uint32_t interest_seg); + void scheduleQueueTimer(uint64_t wait); + void interestQueueTimer(); + + core::Name flow_name_; + + uint32_t current_seg_; // seq id of the next packet produced + uint32_t prod_label_; // path lable of the producer + uint16_t header_size_; // hicn header size + + uint32_t produced_bytes_; // bytes produced in the last round + uint32_t produced_packets_; // packet produed in the last round + + uint32_t max_packet_production_; // never exceed this number of packets + // without update stats + + uint32_t bytes_production_rate_; // bytes per sec + uint32_t packets_production_rate_; // pps + + std::unique_ptr<asio::steady_timer> round_timer_; + uint64_t last_round_; + + // delayed nacks are used by the producer to avoid to send too + // many nacks we the producer rate is 0. however, if the producer moves + // from a production rate higher than 0 to 0 the first round the dealyed + // should be avoided in order to notify the consumer as fast as possible + // of the new rate. + bool allow_delayed_nacks_; + + // queue for the received interests + // this map maps the expiration time of an interest to + // its sequence number. the map is sorted by timeouts + // the same timeout may be used for multiple sequence numbers + // but for each sequence number we store only the smallest + // expiry time. In this way the mapping from seqs_map_ to + // timers_map_ is unique + std::multimap<uint64_t, uint32_t> timers_map_; + + // this map does the opposite, this map is not ordered + std::unordered_map<uint32_t, uint64_t> seqs_map_; + bool queue_timer_on_; + std::unique_ptr<asio::steady_timer> interests_queue_timer_; + + // this callback is called when the remote consumer is in sync with high + // probability. it is called only the first time that the switch happen. + // XXX this makes sense only in P2P mode, while in standard mode is + // impossible to know the state of the consumers so it should not be used. + bool consumer_in_sync_; + interface::ProducerInterestCallback on_consumer_in_sync_; +}; + +} // namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/production_protocol.cc b/libtransport/src/protocols/production_protocol.cc new file mode 100644 index 000000000..8addf52d1 --- /dev/null +++ b/libtransport/src/protocols/production_protocol.cc @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <implementation/socket_producer.h> +#include <protocols/production_protocol.h> + +namespace transport { + +namespace protocol { + +using namespace interface; + +ProductionProtocol::ProductionProtocol( + implementation::ProducerSocket *icn_socket) + : socket_(icn_socket), + is_running_(false), + on_interest_input_(VOID_HANDLER), + on_interest_dropped_input_buffer_(VOID_HANDLER), + on_interest_inserted_input_buffer_(VOID_HANDLER), + on_interest_satisfied_output_buffer_(VOID_HANDLER), + on_interest_process_(VOID_HANDLER), + on_new_segment_(VOID_HANDLER), + on_content_object_to_sign_(VOID_HANDLER), + on_content_object_in_output_buffer_(VOID_HANDLER), + on_content_object_output_(VOID_HANDLER), + on_content_object_evicted_from_output_buffer_(VOID_HANDLER), + on_content_produced_(VOID_HANDLER) { + socket_->getSocketOption(GeneralTransportOptions::PORTAL, portal_); + // TODO add statistics for producer + // socket_->getSocketOption(OtherOptions::STATISTICS, &stats_); +} + +ProductionProtocol::~ProductionProtocol() { + if (!is_async_ && is_running_) { + stop(); + } + + if (listening_thread_.joinable()) { + listening_thread_.join(); + } +} + +int ProductionProtocol::start() { + socket_->getSocketOption(ProducerCallbacksOptions::INTEREST_INPUT, + &on_interest_input_); + socket_->getSocketOption(ProducerCallbacksOptions::INTEREST_DROP, + &on_interest_dropped_input_buffer_); + socket_->getSocketOption(ProducerCallbacksOptions::INTEREST_PASS, + &on_interest_inserted_input_buffer_); + socket_->getSocketOption(ProducerCallbacksOptions::CACHE_HIT, + &on_interest_satisfied_output_buffer_); + socket_->getSocketOption(ProducerCallbacksOptions::CACHE_MISS, + &on_interest_process_); + socket_->getSocketOption(ProducerCallbacksOptions::NEW_CONTENT_OBJECT, + &on_new_segment_); + socket_->getSocketOption(ProducerCallbacksOptions::CONTENT_OBJECT_READY, + &on_content_object_in_output_buffer_); + socket_->getSocketOption(ProducerCallbacksOptions::CONTENT_OBJECT_OUTPUT, + &on_content_object_output_); + socket_->getSocketOption(ProducerCallbacksOptions::CONTENT_OBJECT_TO_SIGN, + &on_content_object_to_sign_); + socket_->getSocketOption(ProducerCallbacksOptions::CONTENT_PRODUCED, + &on_content_produced_); + + socket_->getSocketOption(GeneralTransportOptions::ASYNC_MODE, is_async_); + + bool first = true; + + for (core::Prefix &producer_namespace : served_namespaces_) { + if (first) { + core::BindConfig bind_config(producer_namespace, 1000); + portal_->bind(bind_config); + portal_->setProducerCallback(this); + first = !first; + } else { + portal_->registerRoute(producer_namespace); + } + } + + is_running_ = true; + + if (!is_async_) { + listening_thread_ = std::thread([this]() { portal_->runEventsLoop(); }); + } + + return 0; +} + +void ProductionProtocol::stop() { + is_running_ = false; + + if (!is_async_) { + portal_->stopEventsLoop(); + } else { + portal_->clear(); + } +} + +void ProductionProtocol::produce(ContentObject &content_object) { + if (*on_content_object_in_output_buffer_) { + on_content_object_in_output_buffer_->operator()(*socket_->getInterface(), + content_object); + } + + output_buffer_.insert(std::static_pointer_cast<ContentObject>( + content_object.shared_from_this())); + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), + content_object); + } + + portal_->sendContentObject(content_object); +} + +void ProductionProtocol::registerNamespaceWithNetwork( + const Prefix &producer_namespace) { + served_namespaces_.push_back(producer_namespace); +} + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/production_protocol.h b/libtransport/src/protocols/production_protocol.h new file mode 100644 index 000000000..780972321 --- /dev/null +++ b/libtransport/src/protocols/production_protocol.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/callbacks.h> +#include <hicn/transport/interfaces/socket_producer.h> +#include <hicn/transport/interfaces/statistics.h> +#include <hicn/transport/utils/object_pool.h> +#include <implementation/socket.h> +#include <utils/content_store.h> + +#include <atomic> +#include <thread> + +namespace transport { + +namespace protocol { + +using namespace core; + +class ProductionProtocol : public Portal::ProducerCallback { + public: + ProductionProtocol(implementation::ProducerSocket *icn_socket); + virtual ~ProductionProtocol(); + + bool isRunning() { return is_running_; } + + virtual int start(); + virtual void stop(); + + virtual void produce(ContentObject &content_object); + virtual uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) = 0; + virtual uint32_t produceStream(const Name &content_name, + const uint8_t *buffer, size_t buffer_size, + bool is_last = true, + uint32_t start_offset = 0) = 0; + virtual uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) = 0; + virtual uint32_t produceDatagram(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size) = 0; + + void setOutputBufferSize(std::size_t size) { output_buffer_.setLimit(size); } + std::size_t getOutputBufferSize() { return output_buffer_.getLimit(); } + + virtual void registerNamespaceWithNetwork(const Prefix &producer_namespace); + const std::list<Prefix> &getNamespaces() const { return served_namespaces_; } + + protected: + // Producer callback + virtual void onInterest(core::Interest &i) override = 0; + virtual void onError(std::error_code ec) override{}; + + protected: + implementation::ProducerSocket *socket_; + + // Thread pool responsible for IO operations (send data / receive interests) + std::vector<utils::EventThread> io_threads_; + + // TODO remove this thread + std::thread listening_thread_; + std::shared_ptr<Portal> portal_; + std::atomic<bool> is_running_; + interface::ProductionStatistics *stats_; + + // Callbacks + interface::ProducerInterestCallback *on_interest_input_; + interface::ProducerInterestCallback *on_interest_dropped_input_buffer_; + interface::ProducerInterestCallback *on_interest_inserted_input_buffer_; + interface::ProducerInterestCallback *on_interest_satisfied_output_buffer_; + interface::ProducerInterestCallback *on_interest_process_; + + interface::ProducerContentObjectCallback *on_new_segment_; + interface::ProducerContentObjectCallback *on_content_object_to_sign_; + interface::ProducerContentObjectCallback *on_content_object_in_output_buffer_; + interface::ProducerContentObjectCallback *on_content_object_output_; + interface::ProducerContentObjectCallback + *on_content_object_evicted_from_output_buffer_; + + interface::ProducerContentCallback *on_content_produced_; + + // Output buffer + utils::ContentStore output_buffer_; + + // List ot routes served by current producer protocol + std::list<Prefix> served_namespaces_; + + bool is_async_; +}; + +} // end namespace protocol +} // end namespace transport diff --git a/libtransport/src/protocols/raaqm.cc b/libtransport/src/protocols/raaqm.cc index 5023adf2e..bc8500227 100644 --- a/libtransport/src/protocols/raaqm.cc +++ b/libtransport/src/protocols/raaqm.cc @@ -13,6 +13,7 @@ * limitations under the License. */ +#include <hicn/transport/core/global_object_pool.h> #include <hicn/transport/interfaces/socket_consumer.h> #include <implementation/socket_consumer.h> #include <protocols/errors.h> @@ -126,10 +127,6 @@ void RaaqmTransportProtocol::reset() { } } -bool RaaqmTransportProtocol::verifyKeyPackets() { - return index_manager_->onKeyToVerify(); -} - void RaaqmTransportProtocol::increaseWindow() { // return; double max_window_size = 0.; @@ -325,8 +322,8 @@ void RaaqmTransportProtocol::init() { is.close(); } -void RaaqmTransportProtocol::onContentObject( - Interest::Ptr &&interest, ContentObject::Ptr &&content_object) { +void RaaqmTransportProtocol::onContentObject(Interest &interest, + ContentObject &content_object) { // Check whether makes sense to continue if (TRANSPORT_EXPECT_FALSE(!is_running_)) { return; @@ -334,54 +331,53 @@ void RaaqmTransportProtocol::onContentObject( // Call application-defined callbacks if (*on_content_object_input_) { - (*on_content_object_input_)(*socket_->getInterface(), *content_object); + (*on_content_object_input_)(*socket_->getInterface(), content_object); } if (*on_interest_satisfied_) { - (*on_interest_satisfied_)(*socket_->getInterface(), *interest); + (*on_interest_satisfied_)(*socket_->getInterface(), interest); } - if (content_object->getPayloadType() == PayloadType::CONTENT_OBJECT) { - stats_->updateBytesRecv(content_object->payloadSize()); + if (content_object.getPayloadType() == PayloadType::DATA) { + stats_->updateBytesRecv(content_object.payloadSize()); } - onContentSegment(std::move(interest), std::move(content_object)); + onContentSegment(interest, content_object); scheduleNextInterests(); } -void RaaqmTransportProtocol::onContentSegment( - Interest::Ptr &&interest, ContentObject::Ptr &&content_object) { - uint32_t incremental_suffix = content_object->getName().getSuffix(); +void RaaqmTransportProtocol::onContentSegment(Interest &interest, + ContentObject &content_object) { + uint32_t incremental_suffix = content_object.getName().getSuffix(); // Decrease in-flight interests interests_in_flight_--; // Update stats if (!interest_retransmissions_[incremental_suffix & mask]) { - afterContentReception(*interest, *content_object); + afterContentReception(interest, content_object); } - index_manager_->onContentObject(std::move(interest), - std::move(content_object)); + index_manager_->onContentObject(interest, content_object); } -void RaaqmTransportProtocol::onPacketDropped( - Interest::Ptr &&interest, ContentObject::Ptr &&content_object) { +void RaaqmTransportProtocol::onPacketDropped(Interest &interest, + ContentObject &content_object) { uint32_t max_rtx = 0; socket_->getSocketOption(GeneralTransportOptions::MAX_INTEREST_RETX, max_rtx); - uint64_t segment = interest->getName().getSuffix(); + uint64_t segment = interest.getName().getSuffix(); if (TRANSPORT_EXPECT_TRUE(interest_retransmissions_[segment & mask] < max_rtx)) { stats_->updateRetxCount(1); if (*on_interest_retransmission_) { - (*on_interest_retransmission_)(*socket_->getInterface(), *interest); + (*on_interest_retransmission_)(*socket_->getInterface(), interest); } if (*on_interest_output_) { - (*on_interest_output_)(*socket_->getInterface(), *interest); + (*on_interest_output_)(*socket_->getInterface(), interest); } if (!is_running_) { @@ -389,7 +385,7 @@ void RaaqmTransportProtocol::onPacketDropped( } interest_retransmissions_[segment & mask]++; - interest_to_retransmit_.push(std::move(interest)); + interest_to_retransmit_.push(interest.shared_from_this()); } else { TRANSPORT_LOGE( "Stop: received not trusted packet %llu times", @@ -477,6 +473,11 @@ void RaaqmTransportProtocol::scheduleNextInterests() { sendInterest(std::move(interest_to_retransmit_.front())); interest_to_retransmit_.pop(); } else { + if (TRANSPORT_EXPECT_FALSE(!is_running_ && !is_first_)) { + TRANSPORT_LOGI("Adios"); + break; + } + index = index_manager_->getNextSuffix(); if (index == IndexManager::invalid_index) { break; @@ -487,8 +488,8 @@ void RaaqmTransportProtocol::scheduleNextInterests() { } } -bool RaaqmTransportProtocol::sendInterest(std::uint64_t next_suffix) { - auto interest = getPacket(); +void RaaqmTransportProtocol::sendInterest(std::uint64_t next_suffix) { + auto interest = core::PacketManager<>::getInstance().getPacket<Interest>(); core::Name *name; socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, &name); name->setSuffix((uint32_t)next_suffix); @@ -502,19 +503,12 @@ bool RaaqmTransportProtocol::sendInterest(std::uint64_t next_suffix) { if (*on_interest_output_) { on_interest_output_->operator()(*socket_->getInterface(), *interest); } - - if (TRANSPORT_EXPECT_FALSE(!is_running_ && !is_first_)) { - return false; - } - // This is set to ~0 so that the next interest_retransmissions_ + 1, // performed by sendInterest, will result in 0 interest_retransmissions_[next_suffix & mask] = ~0; interest_timepoints_[next_suffix & mask] = utils::SteadyClock::now(); sendInterest(std::move(interest)); - - return true; } void RaaqmTransportProtocol::sendInterest(Interest::Ptr &&interest) { diff --git a/libtransport/src/protocols/raaqm.h b/libtransport/src/protocols/raaqm.h index fce4194d4..be477d39f 100644 --- a/libtransport/src/protocols/raaqm.h +++ b/libtransport/src/protocols/raaqm.h @@ -18,9 +18,9 @@ #include <hicn/transport/utils/chrono_typedefs.h> #include <protocols/byte_stream_reassembly.h> #include <protocols/congestion_window_protocol.h> -#include <protocols/protocol.h> #include <protocols/raaqm_data_path.h> #include <protocols/rate_estimation.h> +#include <protocols/transport_protocol.h> #include <queue> #include <vector> @@ -42,8 +42,6 @@ class RaaqmTransportProtocol : public TransportProtocol, void reset() override; - virtual bool verifyKeyPackets() override; - protected: static constexpr uint32_t buffer_size = 1 << interface::default_values::log_2_default_buffer_size; @@ -64,13 +62,12 @@ class RaaqmTransportProtocol : public TransportProtocol, private: void init(); - void onContentObject(Interest::Ptr &&i, ContentObject::Ptr &&c) override; + void onContentObject(Interest &i, ContentObject &c) override; - void onContentSegment(Interest::Ptr &&interest, - ContentObject::Ptr &&content_object); + void onContentSegment(Interest &interest, ContentObject &content_object); - void onPacketDropped(Interest::Ptr &&interest, - ContentObject::Ptr &&content_object) override; + void onPacketDropped(Interest &interest, + ContentObject &content_object) override; void onReassemblyFailed(std::uint32_t missing_segment) override; @@ -78,7 +75,7 @@ class RaaqmTransportProtocol : public TransportProtocol, virtual void scheduleNextInterests() override; - bool sendInterest(std::uint64_t next_suffix); + void sendInterest(std::uint64_t next_suffix); void sendInterest(Interest::Ptr &&interest); diff --git a/libtransport/src/protocols/raaqm_data_path.cc b/libtransport/src/protocols/raaqm_data_path.cc index 8bbbadcf2..f2c21b9ef 100644 --- a/libtransport/src/protocols/raaqm_data_path.cc +++ b/libtransport/src/protocols/raaqm_data_path.cc @@ -14,7 +14,6 @@ */ #include <hicn/transport/utils/chrono_typedefs.h> - #include <protocols/raaqm_data_path.h> namespace transport { diff --git a/libtransport/src/protocols/raaqm_data_path.h b/libtransport/src/protocols/raaqm_data_path.h index 3f037bc76..c0b53a690 100644 --- a/libtransport/src/protocols/raaqm_data_path.h +++ b/libtransport/src/protocols/raaqm_data_path.h @@ -16,7 +16,6 @@ #pragma once #include <hicn/transport/utils/chrono_typedefs.h> - #include <utils/min_filter.h> #include <chrono> diff --git a/libtransport/src/protocols/rate_estimation.cc b/libtransport/src/protocols/rate_estimation.cc index a2cf1aefe..5ca925760 100644 --- a/libtransport/src/protocols/rate_estimation.cc +++ b/libtransport/src/protocols/rate_estimation.cc @@ -15,7 +15,6 @@ #include <hicn/transport/interfaces/socket_options_default_values.h> #include <hicn/transport/utils/log.h> - #include <protocols/rate_estimation.h> #include <thread> diff --git a/libtransport/src/protocols/rate_estimation.h b/libtransport/src/protocols/rate_estimation.h index 17f39e0b9..42ae74194 100644 --- a/libtransport/src/protocols/rate_estimation.h +++ b/libtransport/src/protocols/rate_estimation.h @@ -16,7 +16,6 @@ #pragma once #include <hicn/transport/interfaces/statistics.h> - #include <protocols/raaqm_data_path.h> #include <chrono> diff --git a/libtransport/src/protocols/reassembly.cc b/libtransport/src/protocols/reassembly.cc index c6602153c..0e59832dc 100644 --- a/libtransport/src/protocols/reassembly.cc +++ b/libtransport/src/protocols/reassembly.cc @@ -16,7 +16,6 @@ #include <hicn/transport/interfaces/socket_consumer.h> #include <hicn/transport/utils/array.h> #include <hicn/transport/utils/membuf.h> - #include <implementation/socket_consumer.h> #include <protocols/errors.h> #include <protocols/indexer.h> diff --git a/libtransport/src/protocols/reassembly.h b/libtransport/src/protocols/reassembly.h index fdc9f2a05..385122c53 100644 --- a/libtransport/src/protocols/reassembly.h +++ b/libtransport/src/protocols/reassembly.h @@ -46,7 +46,7 @@ class Reassembly { virtual ~Reassembly() = default; - virtual void reassemble(core::ContentObject::Ptr &&content_object) = 0; + virtual void reassemble(core::ContentObject &content_object) = 0; virtual void reassemble( std::unique_ptr<core::ContentObjectManifest> &&manifest) = 0; virtual void reInitialize() = 0; diff --git a/libtransport/src/protocols/rtc/CMakeLists.txt b/libtransport/src/protocols/rtc/CMakeLists.txt new file mode 100644 index 000000000..77f065d0e --- /dev/null +++ b/libtransport/src/protocols/rtc/CMakeLists.txt @@ -0,0 +1,38 @@ +# Copyright (c) 2017-2019 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + +list(APPEND HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/rtc.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_state.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_ldr.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_data_path.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_consts.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_rc.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_rc_queue.h + ${CMAKE_CURRENT_SOURCE_DIR}/probe_handler.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_packet.h +) + +list(APPEND SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/rtc.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_state.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_ldr.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_rc_queue.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_data_path.cc + ${CMAKE_CURRENT_SOURCE_DIR}/probe_handler.cc +) + +set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) +set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE) diff --git a/libtransport/src/protocols/rtc/congestion_detection.cc b/libtransport/src/protocols/rtc/congestion_detection.cc new file mode 100644 index 000000000..e2d44ae66 --- /dev/null +++ b/libtransport/src/protocols/rtc/congestion_detection.cc @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/utils/log.h> +#include <protocols/rtc/congestion_detection.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +CongestionDetection::CongestionDetection() + : cc_estimator_(), last_processed_chunk_() {} + +CongestionDetection::~CongestionDetection() {} + +void CongestionDetection::updateStats() { + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + if (chunks_number_.empty()) return; + + uint32_t chunk_number = chunks_number_.front(); + + while (chunks_[chunk_number].getReceivedTime() + HICN_CC_STATS_MAX_DELAY_MS < + now || + chunks_[chunk_number].isComplete()) { + if (chunk_number == last_processed_chunk_.getFrameSeqNum() + 1) { + chunks_[chunk_number].setPreviousSentTime( + last_processed_chunk_.getSentTime()); + + chunks_[chunk_number].setPreviousReceivedTime( + last_processed_chunk_.getReceivedTime()); + cc_estimator_.Update(chunks_[chunk_number].getReceivedDelta(), + chunks_[chunk_number].getSentDelta(), + chunks_[chunk_number].getSentTime(), + chunks_[chunk_number].getReceivedTime(), + chunks_[chunk_number].getFrameSize(), true); + + } else { + TRANSPORT_LOGD( + "CongestionDetection::updateStats frame %u but not the \ + previous one, last one was %u currentFrame %u", + chunk_number, last_processed_chunk_.getFrameSeqNum(), + chunks_[chunk_number].getFrameSeqNum()); + } + + last_processed_chunk_ = chunks_[chunk_number]; + + chunks_.erase(chunk_number); + + chunks_number_.pop(); + if (chunks_number_.empty()) break; + + chunk_number = chunks_number_.front(); + } +} + +void CongestionDetection::addPacket(const core::ContentObject &content_object) { + auto payload = content_object.getPayload(); + uint32_t payload_size = (uint32_t)payload->length(); + uint32_t segmentNumber = content_object.getName().getSuffix(); + // uint32_t pkt = segmentNumber & modMask_; + uint64_t *sentTimePtr = (uint64_t *)payload->data(); + + // this is just for testing with hiperf, assuming a frame is 10 pkts + // in the final version, the split should be based on the timestamp in the pkt + uint32_t frameNum = (int)(segmentNumber / HICN_CC_STATS_CHUNK_SIZE); + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + if (chunks_.find(frameNum) == chunks_.end()) { + // new chunk of pkts or out of order + if (last_processed_chunk_.getFrameSeqNum() > frameNum) + return; // out of order and we already processed the chunk + + chunks_[frameNum] = FrameStats(frameNum, HICN_CC_STATS_CHUNK_SIZE); + chunks_number_.push(frameNum); + } + + chunks_[frameNum].addPacket(*sentTimePtr, now, payload_size); +} + +} // namespace rtc +} // namespace protocol +} // namespace transport diff --git a/libtransport/src/protocols/rtc/congestion_detection.h b/libtransport/src/protocols/rtc/congestion_detection.h new file mode 100644 index 000000000..17f4aa54c --- /dev/null +++ b/libtransport/src/protocols/rtc/congestion_detection.h @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/core/content_object.h> +#include <protocols/rtc/trendline_estimator.h> + +#include <map> +#include <queue> + +#define HICN_CC_STATS_CHUNK_SIZE 10 +#define HICN_CC_STATS_MAX_DELAY_MS 100 + +namespace transport { + +namespace protocol { + +namespace rtc { + +class FrameStats { + public: + FrameStats() + : frame_num_(0), + sent_time_(0), + received_time_(0), + previous_sent_time_(0), + previous_received_time_(0), + size_(0), + received_pkt_m(0), + burst_size_m(HICN_CC_STATS_CHUNK_SIZE){}; + + FrameStats(uint32_t burst_size) + : frame_num_(0), + sent_time_(0), + received_time_(0), + previous_sent_time_(0), + previous_received_time_(0), + size_(0), + received_pkt_m(0), + burst_size_m(burst_size){}; + + FrameStats(uint32_t frame_num, uint32_t burst_size) + : frame_num_(frame_num), + sent_time_(0), + received_time_(0), + previous_sent_time_(0), + previous_received_time_(0), + size_(0), + received_pkt_m(0), + burst_size_m(burst_size){}; + + FrameStats(uint32_t frame_num, uint64_t sent_time, uint64_t received_time, + uint32_t size, FrameStats previousFrame, uint32_t burst_size) + : frame_num_(frame_num), + sent_time_(sent_time), + received_time_(received_time), + previous_sent_time_(previousFrame.getSentTime()), + previous_received_time_(previousFrame.getReceivedTime()), + size_(size), + received_pkt_m(1), + burst_size_m(burst_size){}; + + void addPacket(uint64_t sent_time, uint64_t received_time, uint32_t size) { + size_ += size; + sent_time_ = + (sent_time_ == 0) ? sent_time : std::min(sent_time_, sent_time); + received_time_ = std::max(received_time, received_time_); + received_pkt_m++; + } + + bool isComplete() { return received_pkt_m == burst_size_m; } + + uint32_t getFrameSeqNum() const { return frame_num_; } + uint64_t getSentTime() const { return sent_time_; } + uint64_t getReceivedTime() const { return received_time_; } + uint32_t getFrameSize() const { return size_; } + + void setPreviousReceivedTime(uint64_t time) { + previous_received_time_ = time; + } + void setPreviousSentTime(uint64_t time) { previous_sent_time_ = time; } + + // todo manage first frame + double getReceivedDelta() { + return static_cast<double>(received_time_ - previous_received_time_); + } + double getSentDelta() { + return static_cast<double>(sent_time_ - previous_sent_time_); + } + + private: + uint32_t frame_num_; + uint64_t sent_time_; + uint64_t received_time_; + + uint64_t previous_sent_time_; + uint64_t previous_received_time_; + uint32_t size_; + + uint32_t received_pkt_m; + uint32_t burst_size_m; +}; + +class CongestionDetection { + public: + CongestionDetection(); + ~CongestionDetection(); + + void addPacket(const core::ContentObject &content_object); + + BandwidthUsage getState() { return cc_estimator_.State(); } + + void updateStats(); + + private: + TrendlineEstimator cc_estimator_; + std::map<uint32_t, FrameStats> chunks_; + std::queue<uint32_t> chunks_number_; + + FrameStats last_processed_chunk_; +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/probe_handler.cc b/libtransport/src/protocols/rtc/probe_handler.cc new file mode 100644 index 000000000..efba362d4 --- /dev/null +++ b/libtransport/src/protocols/rtc/probe_handler.cc @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/probe_handler.h> +#include <protocols/rtc/rtc_consts.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +ProbeHandler::ProbeHandler(SendProbeCallback &&send_callback, + asio::io_service &io_service) + : probe_interval_(0), + max_probes_(0), + sent_probes_(0), + probe_timer_(std::make_unique<asio::steady_timer>(io_service)), + rand_eng_((std::random_device())()), + distr_(MIN_RTT_PROBE_SEQ, MAX_RTT_PROBE_SEQ), + send_probe_callback_(std::move(send_callback)) {} + +ProbeHandler::~ProbeHandler() {} + +uint64_t ProbeHandler::getRtt(uint32_t seq) { + auto it = pending_probes_.find(seq); + + if (it == pending_probes_.end()) return 0; + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint64_t rtt = now - it->second; + if(rtt < 1) rtt = 1; + + pending_probes_.erase(it); + + return rtt; +} + +void ProbeHandler::setProbes(uint32_t probe_interval, uint32_t max_probes) { + stopProbes(); + probe_interval_ = probe_interval; + max_probes_ = max_probes; +} + +void ProbeHandler::stopProbes() { + probe_interval_ = 0; + max_probes_ = 0; + sent_probes_ = 0; + probe_timer_->cancel(); +} + +void ProbeHandler::sendProbes() { + if (probe_interval_ == 0) return; + if (max_probes_ != 0 && sent_probes_ >= max_probes_) return; + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + uint32_t seq = distr_(rand_eng_); + pending_probes_.insert(std::pair<uint32_t, uint64_t>(seq, now)); + send_probe_callback_(seq); + sent_probes_++; + + // clean up + // a probe may get lost. if the pending_probes_ size becomes bigger than + // MAX_PENDING_PROBES remove all the probes older than a seconds + if (pending_probes_.size() > MAX_PENDING_PROBES) { + for (auto it = pending_probes_.begin(); it != pending_probes_.end();) { + if ((now - it->second) > 1000) + it = pending_probes_.erase(it); + else + it++; + } + } + + if (probe_interval_ == 0) return; + + std::weak_ptr<ProbeHandler> self(shared_from_this()); + probe_timer_->expires_from_now(std::chrono::microseconds(probe_interval_)); + probe_timer_->async_wait([self](std::error_code ec) { + if (ec) return; + if (auto s = self.lock()) { + s->sendProbes(); + } + }); +} + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/probe_handler.h b/libtransport/src/protocols/rtc/probe_handler.h new file mode 100644 index 000000000..b8ed84445 --- /dev/null +++ b/libtransport/src/protocols/rtc/probe_handler.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include <hicn/transport/config.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <functional> +#include <random> +#include <unordered_map> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class ProbeHandler : public std::enable_shared_from_this<ProbeHandler> { + public: + using SendProbeCallback = std::function<void(uint32_t)>; + + public: + ProbeHandler(SendProbeCallback &&send_callback, + asio::io_service &io_service); + + ~ProbeHandler(); + + // if the function returns 0 the probe is not valaid + uint64_t getRtt(uint32_t seq); + + // reset the probes parameters. it stop the current probing. + // to restar call sendProbes. + // probe_interval = 0 means that no event will be scheduled + // max_probe = 0 means no limit to the number of probe to send + void setProbes(uint32_t probe_interval, uint32_t max_probes); + + // stop to schedule probes + void stopProbes(); + + void sendProbes(); + + private: + uint32_t probe_interval_; // us + uint32_t max_probes_; // packets + uint32_t sent_probes_; // packets + + std::unique_ptr<asio::steady_timer> probe_timer_; + + // map from seqnumber to timestamp + std::unordered_map<uint32_t, uint64_t> pending_probes_; + + // random generator + std::default_random_engine rand_eng_; + std::uniform_int_distribution<uint32_t> distr_; + + SendProbeCallback send_probe_callback_; +}; + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/rtc.cc b/libtransport/src/protocols/rtc/rtc.cc new file mode 100644 index 000000000..bb95ab686 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc.cc @@ -0,0 +1,607 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/interfaces/socket_consumer.h> +#include <implementation/socket_consumer.h> +#include <math.h> +#include <protocols/rtc/rtc.h> +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_rc_queue.h> + +#include <algorithm> + +namespace transport { + +namespace protocol { + +namespace rtc { + +using namespace interface; + +RTCTransportProtocol::RTCTransportProtocol( + implementation::ConsumerSocket *icn_socket) + : TransportProtocol(icn_socket, nullptr), + DatagramReassembly(icn_socket, this), + number_(0) { + icn_socket->getSocketOption(PORTAL, portal_); + round_timer_ = std::make_unique<asio::steady_timer>(portal_->getIoService()); + scheduler_timer_ = + std::make_unique<asio::steady_timer>(portal_->getIoService()); +} + +RTCTransportProtocol::~RTCTransportProtocol() {} + +void RTCTransportProtocol::resume() { + if (is_running_) return; + + is_running_ = true; + + newRound(); + + portal_->runEventsLoop(); + is_running_ = false; +} + +// private +void RTCTransportProtocol::initParams() { + portal_->setConsumerCallback(this); + + rc_ = std::make_shared<RTCRateControlQueue>(); + ldr_ = std::make_shared<RTCLossDetectionAndRecovery>( + std::bind(&RTCTransportProtocol::sendRtxInterest, this, + std::placeholders::_1), + portal_->getIoService()); + + state_ = std::make_shared<RTCState>( + std::bind(&RTCTransportProtocol::sendProbeInterest, this, + std::placeholders::_1), + std::bind(&RTCTransportProtocol::discoveredRtt, this), + portal_->getIoService()); + + rc_->setState(state_); + // TODO: for the moment we keep the congestion control disabled + // rc_->tunrOnRateControl(); + ldr_->setState(state_); + + // protocol state + start_send_interest_ = false; + current_state_ = SyncState::catch_up; + + // Cancel timer + number_++; + round_timer_->cancel(); + scheduler_timer_->cancel(); + scheduler_timer_on_ = false; + + // delete all timeouts and future nacks + timeouts_or_nacks_.clear(); + + // cwin vars + current_sync_win_ = INITIAL_WIN; + max_sync_win_ = INITIAL_WIN_MAX; + + // names/packets var + next_segment_ = 0; + + socket_->setSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, + RTC_INTEREST_LIFETIME); +} + +// private +void RTCTransportProtocol::reset() { + TRANSPORT_LOGD("reset called"); + initParams(); + newRound(); +} + +void RTCTransportProtocol::inactiveProducer() { + // when the producer is inactive we reset the consumer state + // cwin vars + current_sync_win_ = INITIAL_WIN; + max_sync_win_ = INITIAL_WIN_MAX; + + TRANSPORT_LOGD("Current window: %u, max_sync_win_: %u", current_sync_win_, + max_sync_win_); + + // names/packets var + next_segment_ = 0; + + ldr_->clear(); +} + +void RTCTransportProtocol::newRound() { + round_timer_->expires_from_now(std::chrono::milliseconds(ROUND_LEN)); + // TODO pass weak_ptr here + round_timer_->async_wait([this, n{number_}](std::error_code ec) { + if (ec) return; + + if (n != number_) { + return; + } + + // saving counters that will be reset on new round + uint32_t sent_retx = state_->getSentRtxInRound(); + uint32_t received_bytes = state_->getReceivedBytesInRound(); + uint32_t sent_interest = state_->getSentInterestInRound(); + uint32_t lost_data = state_->getLostData(); + uint32_t recovered_losses = state_->getRecoveredLosses(); + uint32_t received_nacks = state_->getReceivedNacksInRound(); + + bool in_sync = (current_state_ == SyncState::in_sync); + state_->onNewRound((double)ROUND_LEN, in_sync); + rc_->onNewRound((double)ROUND_LEN); + + // update sync state if needed + if (current_state_ == SyncState::in_sync) { + double cache_rate = state_->getPacketFromCacheRatio(); + if (cache_rate > MAX_DATA_FROM_CACHE) { + current_state_ = SyncState::catch_up; + } + } else { + double target_rate = state_->getProducerRate() * PRODUCTION_RATE_FRACTION; + double received_rate = state_->getReceivedRate(); + uint32_t round_without_nacks = state_->getRoundsWithoutNacks(); + double cache_ratio = state_->getPacketFromCacheRatio(); + if (round_without_nacks >= ROUNDS_IN_SYNC_BEFORE_SWITCH && + received_rate >= target_rate && cache_ratio < MAX_DATA_FROM_CACHE) { + current_state_ = SyncState::in_sync; + } + } + + TRANSPORT_LOGD("Calling updateSyncWindow in newRound function"); + updateSyncWindow(); + + sendStatsToApp(sent_retx, received_bytes, sent_interest, lost_data, + recovered_losses, received_nacks); + newRound(); + }); +} + +void RTCTransportProtocol::discoveredRtt() { + start_send_interest_ = true; + ldr_->turnOnRTX(); + updateSyncWindow(); +} + +void RTCTransportProtocol::computeMaxSyncWindow() { + double production_rate = state_->getProducerRate(); + double packet_size = state_->getAveragePacketSize(); + if (production_rate == 0.0 || packet_size == 0.0) { + // the consumer has no info about the producer, + // keep the previous maxCWin + TRANSPORT_LOGD( + "Returning in computeMaxSyncWindow because: prod_rate: %d || " + "packet_size: %d", + (int)(production_rate == 0.0), (int)(packet_size == 0.0)); + return; + } + + uint32_t lifetime = default_values::interest_lifetime; + socket_->getSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, + lifetime); + double lifetime_ms = (double)lifetime / MILLI_IN_A_SEC; + + + max_sync_win_ = + (uint32_t)ceil((production_rate * lifetime_ms * + INTEREST_LIFETIME_REDUCTION_FACTOR) / packet_size); + + max_sync_win_ = std::min(max_sync_win_, rc_->getCongesionWindow()); +} + +void RTCTransportProtocol::updateSyncWindow() { + computeMaxSyncWindow(); + + if (max_sync_win_ == INITIAL_WIN_MAX) { + if (TRANSPORT_EXPECT_FALSE(!state_->isProducerActive())) return; + + current_sync_win_ = INITIAL_WIN; + scheduleNextInterests(); + return; + } + + double prod_rate = state_->getProducerRate(); + double rtt = (double)state_->getRTT() / MILLI_IN_A_SEC; + double packet_size = state_->getAveragePacketSize(); + + // if some of the info are not available do not update the current win + if (prod_rate != 0.0 && rtt != 0.0 && packet_size != 0.0) { + current_sync_win_ = (uint32_t)ceil(prod_rate * rtt / packet_size); + current_sync_win_ += + ceil(prod_rate * (PRODUCER_BUFFER_MS / MILLI_IN_A_SEC) / packet_size); + + if(current_state_ == SyncState::catch_up) { + current_sync_win_ = current_sync_win_ * CATCH_UP_WIN_INCREMENT; + } + + current_sync_win_ = std::min(current_sync_win_, max_sync_win_); + current_sync_win_ = std::max(current_sync_win_, WIN_MIN); + } + + scheduleNextInterests(); +} + +void RTCTransportProtocol::decreaseSyncWindow() { + // called on future nack + // we have a new sample of the production rate, so update max win first + computeMaxSyncWindow(); + current_sync_win_--; + current_sync_win_ = std::max(current_sync_win_, WIN_MIN); + scheduleNextInterests(); +} + +void RTCTransportProtocol::sendInterest(Name *interest_name) { + TRANSPORT_LOGD("Sending interest for name %s", + interest_name->toString().c_str()); + + auto interest = core::PacketManager<>::getInstance().getPacket<Interest>(); + interest->setName(*interest_name); + + uint32_t lifetime = default_values::interest_lifetime; + socket_->getSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, + lifetime); + interest->setLifetime(uint32_t(lifetime)); + + if (*on_interest_output_) { + (*on_interest_output_)(*socket_->getInterface(), *interest); + } + + if (TRANSPORT_EXPECT_FALSE(!is_running_ && !is_first_)) { + return; + } + + portal_->sendInterest(std::move(interest)); +} + +void RTCTransportProtocol::sendRtxInterest(uint32_t seq) { + if (!is_running_ && !is_first_) return; + + if(!start_send_interest_) return; + + Name *interest_name = nullptr; + socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, + &interest_name); + + TRANSPORT_LOGD("send rtx %u", seq); + interest_name->setSuffix(seq); + sendInterest(interest_name); +} + +void RTCTransportProtocol::sendProbeInterest(uint32_t seq) { + if (!is_running_ && !is_first_) return; + + Name *interest_name = nullptr; + socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, + &interest_name); + + TRANSPORT_LOGD("send probe %u", seq); + interest_name->setSuffix(seq); + sendInterest(interest_name); +} + +void RTCTransportProtocol::scheduleNextInterests() { + TRANSPORT_LOGD("Schedule next interests"); + + if (!is_running_ && !is_first_) return; + + if(!start_send_interest_) return; // RTT discovering phase is not finished so + // do not start to send interests + + if (scheduler_timer_on_) return; // wait befor send other interests + + if (TRANSPORT_EXPECT_FALSE(!state_->isProducerActive())) { + TRANSPORT_LOGD("Inactive producer."); + // here we keep seding the same interest until the producer + // does not start again + if (next_segment_ != 0) { + // the producer just become inactive, reset the state + inactiveProducer(); + } + + Name *interest_name = nullptr; + socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, + &interest_name); + + TRANSPORT_LOGD("send interest %u", next_segment_); + interest_name->setSuffix(next_segment_); + + if (portal_->interestIsPending(*interest_name)) { + // if interest 0 is already pending we return + return; + } + + sendInterest(interest_name); + state_->onSendNewInterest(interest_name); + return; + } + + TRANSPORT_LOGD("Pending interest number: %d -- current_sync_win_: %d", + state_->getPendingInterestNumber(), current_sync_win_); + + // skip nacked pacekts + if (next_segment_ <= state_->getLastSeqNacked()) { + next_segment_ = state_->getLastSeqNacked() + 1; + } + + // skipe received packets + if (next_segment_ <= state_->getHighestSeqReceivedInOrder()) { + next_segment_ = state_->getHighestSeqReceivedInOrder() + 1; + } + + uint32_t sent_interests = 0; + while ((state_->getPendingInterestNumber() < current_sync_win_) && + (sent_interests < MAX_INTERESTS_IN_BATCH)) { + TRANSPORT_LOGD("In while loop. Window size: %u", current_sync_win_); + Name *interest_name = nullptr; + socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, + &interest_name); + + interest_name->setSuffix(next_segment_); + + // send the packet only if: + // 1) it is not pending yet (not true for rtx) + // 2) the packet is not received or lost + // 3) is not in the rtx list + if (portal_->interestIsPending(*interest_name) || + state_->isReceivedOrLost(next_segment_) != PacketState::UNKNOWN || + ldr_->isRtx(next_segment_)) { + TRANSPORT_LOGD( + "skip interest %u because: pending %u, recv %u, rtx %u", + next_segment_, (portal_->interestIsPending(*interest_name)), + (state_->isReceivedOrLost(next_segment_) != PacketState::UNKNOWN), + (ldr_->isRtx(next_segment_))); + next_segment_ = (next_segment_ + 1) % MIN_PROBE_SEQ; + continue; + } + + + sent_interests++; + TRANSPORT_LOGD("send interest %u", next_segment_); + sendInterest(interest_name); + state_->onSendNewInterest(interest_name); + + next_segment_ = (next_segment_ + 1) % MIN_PROBE_SEQ; + } + + if (state_->getPendingInterestNumber() < current_sync_win_) { + // we still have space in the window but we already sent a batch of + // MAX_INTERESTS_IN_BATCH interest. for the following ones wait one + // WAIT_BETWEEN_INTEREST_BATCHES to avoid local packets drop + + scheduler_timer_on_ = true; + scheduler_timer_->expires_from_now( + std::chrono::microseconds(WAIT_BETWEEN_INTEREST_BATCHES)); + scheduler_timer_->async_wait([this](std::error_code ec) { + if (ec) return; + if (!scheduler_timer_on_) return; + + scheduler_timer_on_ = false; + scheduleNextInterests(); + }); + } +} + +void RTCTransportProtocol::onTimeout(Interest::Ptr &&interest) { + uint32_t segment_number = interest->getName().getSuffix(); + + TRANSPORT_LOGD("timeout for packet %u", segment_number); + + if (segment_number >= MIN_PROBE_SEQ) { + // this is a timeout on a probe, do nothing + return; + } + + timeouts_or_nacks_.insert(segment_number); + + if (TRANSPORT_EXPECT_TRUE(state_->isProducerActive()) && + segment_number <= state_->getHighestSeqReceivedInOrder()) { + // we retransmit packets only if the producer is active, otherwise we + // use timeouts to avoid to send too much traffic + // + // a timeout is sent using RTX only if it is an old packet. if it is for a + // seq number that we didn't reach yet, we send the packet using the normal + // schedule next interest + TRANSPORT_LOGD("handle timeout for packet %u using rtx", segment_number); + ldr_->onTimeout(segment_number); + state_->onTimeout(segment_number); + scheduleNextInterests(); + return; + } + + TRANSPORT_LOGD("handle timeout for packet %u using normal interests", + segment_number); + + if (segment_number < next_segment_) { + // this is a timeout for a packet that will be generated in the future but + // we are asking for higher sequence numbers. we need to go back like in the + // case of future nacks + TRANSPORT_LOGD("on timeout next seg = %u, jump to %u", + next_segment_, segment_number); + next_segment_ = segment_number; + } + + state_->onTimeout(segment_number); + scheduleNextInterests(); +} + +void RTCTransportProtocol::onNack(const ContentObject &content_object) { + struct nack_packet_t *nack = + (struct nack_packet_t *)content_object.getPayload()->data(); + uint32_t production_seg = nack->getProductionSegement(); + uint32_t nack_segment = content_object.getName().getSuffix(); + bool is_rtx = ldr_->isRtx(nack_segment); + + // check if the packet got a timeout + + TRANSPORT_LOGD("Nack received %u. Production segment: %u", nack_segment, + production_seg); + + bool compute_stats = true; + auto tn_it = timeouts_or_nacks_.find(nack_segment); + if (tn_it != timeouts_or_nacks_.end() || is_rtx) { + compute_stats = false; + // remove packets from timeouts_or_nacks only in case of a past nack + } + + state_->onNackPacketReceived(content_object, compute_stats); + ldr_->onNackPacketReceived(content_object); + + // both in case of past and future nack we set next_segment_ equal to the + // production segment in the nack. In case of past nack we will skip unneded + // interest (this is already done in the scheduleNextInterest in any case) + // while in case of future nacks we can go back in time and ask again for the + // content that generated the nack + TRANSPORT_LOGD("on nack next seg = %u, jump to %u", + next_segment_, production_seg); + next_segment_ = production_seg; + + if (production_seg > nack_segment) { + // remove the nack is it exists + if (tn_it != timeouts_or_nacks_.end()) timeouts_or_nacks_.erase(tn_it); + + // the client is asking for content in the past + // switch to catch up state and increase the window + // this is true only if the packet is not an RTX + if (!is_rtx) current_state_ = SyncState::catch_up; + + updateSyncWindow(); + } else { + // if production_seg == nack_segment we consider this a future nack, since + // production_seg is not yet created. this may happen in case of low + // production rate (e.g. ping at 1pps) + + // if a future nack was also retransmitted add it to the timeout_or_nacks + // set + if (is_rtx) timeouts_or_nacks_.insert(nack_segment); + + // the client is asking for content in the future + // switch to in sync state and decrease the window + current_state_ = SyncState::in_sync; + decreaseSyncWindow(); + } +} + +void RTCTransportProtocol::onProbe(const ContentObject &content_object) { + bool valid = state_->onProbePacketReceived(content_object); + if(!valid) return; + + struct nack_packet_t *probe = + (struct nack_packet_t *)content_object.getPayload()->data(); + uint32_t production_seg = probe->getProductionSegement(); + + // as for the nacks set next_segment_ + TRANSPORT_LOGD("on probe next seg = %u, jump to %u", + next_segment_, production_seg); + next_segment_ = production_seg; + + ldr_->onProbePacketReceived(content_object); + updateSyncWindow(); +} + +void RTCTransportProtocol::onContentObject(Interest &interest, + ContentObject &content_object) { + TRANSPORT_LOGD("Received content object of size: %zu", + content_object.payloadSize()); + uint32_t payload_size = content_object.payloadSize(); + uint32_t segment_number = content_object.getName().getSuffix(); + + if (segment_number >= MIN_PROBE_SEQ) { + TRANSPORT_LOGD("Received probe %u", segment_number); + if (*on_content_object_input_) { + (*on_content_object_input_)(*socket_->getInterface(), content_object); + } + onProbe(content_object); + return; + } + + if (payload_size == NACK_HEADER_SIZE) { + TRANSPORT_LOGD("Received nack %u", segment_number); + if (*on_content_object_input_) { + (*on_content_object_input_)(*socket_->getInterface(), content_object); + } + onNack(content_object); + return; + } + + TRANSPORT_LOGD("Received content %u", segment_number); + + rc_->onDataPacketReceived(content_object); + bool compute_stats = true; + auto tn_it = timeouts_or_nacks_.find(segment_number); + if (tn_it != timeouts_or_nacks_.end()) { + compute_stats = false; + timeouts_or_nacks_.erase(tn_it); + } + if (ldr_->isRtx(segment_number)) { + compute_stats = false; + } + + // check if the packet was already received + PacketState state = state_->isReceivedOrLost(segment_number); + state_->onDataPacketReceived(content_object, compute_stats); + ldr_->onDataPacketReceived(content_object); + + // if the stat for this seq number is received do not send the packet to app + if (state != PacketState::RECEIVED) { + if (*on_content_object_input_) { + (*on_content_object_input_)(*socket_->getInterface(), content_object); + } + reassemble(content_object); + } else { + TRANSPORT_LOGD("Received duplicated content %u, drop it", segment_number); + } + + updateSyncWindow(); +} + +void RTCTransportProtocol::sendStatsToApp( + uint32_t retx_count, uint32_t received_bytes, uint32_t sent_interests, + uint32_t lost_data, uint32_t recovered_losses, uint32_t received_nacks) { + if (*stats_summary_) { + // Send the stats to the app + stats_->updateQueuingDelay(state_->getQueuing()); + + // stats_->updateInterestFecTx(0); //todo must be implemented + // stats_->updateBytesFecRecv(0); //todo must be implemented + + stats_->updateRetxCount(retx_count); + stats_->updateBytesRecv(received_bytes); + stats_->updateInterestTx(sent_interests); + stats_->updateReceivedNacks(received_nacks); + + stats_->updateAverageWindowSize(current_sync_win_); + stats_->updateLossRatio(state_->getLossRate()); + stats_->updateAverageRtt(state_->getRTT()); + stats_->updateLostData(lost_data); + stats_->updateRecoveredData(recovered_losses); + stats_->updateCCState((unsigned int)current_state_ ? 1 : 0); + (*stats_summary_)(*socket_->getInterface(), *stats_); + } +} + +void RTCTransportProtocol::reassemble(ContentObject &content_object) { + auto read_buffer = content_object.getPayload(); + TRANSPORT_LOGD("Size of payload: %zu", read_buffer->length()); + read_buffer->trimStart(DATA_HEADER_SIZE); + Reassembly::read_buffer_ = std::move(read_buffer); + Reassembly::notifyApplication(); +} + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc.h b/libtransport/src/protocols/rtc/rtc.h new file mode 100644 index 000000000..596887067 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc.h @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <protocols/datagram_reassembly.h> +#include <protocols/rtc/rtc_ldr.h> +#include <protocols/rtc/rtc_rc.h> +#include <protocols/rtc/rtc_state.h> +#include <protocols/transport_protocol.h> + +#include <unordered_set> +#include <vector> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class RTCTransportProtocol : public TransportProtocol, + public DatagramReassembly { + public: + RTCTransportProtocol(implementation::ConsumerSocket *icnet_socket); + + ~RTCTransportProtocol(); + + using TransportProtocol::start; + + using TransportProtocol::stop; + + void resume() override; + + private: + enum class SyncState { catch_up = 0, in_sync = 1, last }; + + private: + // setup functions + void initParams(); + void reset() override; + + void inactiveProducer(); + + // protocol functions + void discoveredRtt(); + void newRound(); + + // window functions + void computeMaxSyncWindow(); + void updateSyncWindow(); + void decreaseSyncWindow(); + + // packet functions + void sendInterest(Name *interest_name); + void sendRtxInterest(uint32_t seq); + void sendProbeInterest(uint32_t seq); + void scheduleNextInterests() override; + void onTimeout(Interest::Ptr &&interest) override; + void onNack(const ContentObject &content_object); + void onProbe(const ContentObject &content_object); + void reassemble(ContentObject &content_object) override; + void onContentObject(Interest &interest, + ContentObject &content_object) override; + void onPacketDropped(Interest &interest, + ContentObject &content_object) override {} + void onReassemblyFailed(std::uint32_t missing_segment) override {} + + // interaction with app functions + void sendStatsToApp(uint32_t retx_count, uint32_t received_bytes, + uint32_t sent_interests, uint32_t lost_data, + uint32_t recovered_losses, uint32_t received_nacks); + // protocol state + bool start_send_interest_; + SyncState current_state_; + // cwin vars + uint32_t current_sync_win_; + uint32_t max_sync_win_; + + // controller var + std::unique_ptr<asio::steady_timer> round_timer_; + std::unique_ptr<asio::steady_timer> scheduler_timer_; + bool scheduler_timer_on_; + + // timeouts + std::unordered_set<uint32_t> timeouts_or_nacks_; + + // names/packets var + uint32_t next_segment_; + + std::shared_ptr<RTCState> state_; + std::shared_ptr<RTCRateControl> rc_; + std::shared_ptr<RTCLossDetectionAndRecovery> ldr_; + + uint32_t number_; +}; + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_consts.h b/libtransport/src/protocols/rtc/rtc_consts.h new file mode 100644 index 000000000..0cf9516ab --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_consts.h @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <protocols/rtc/rtc_packet.h> +#include <stdint.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +// used in rtc +// protocol consts +const uint32_t ROUND_LEN = 200; +// ms interval of time on which +// we take decisions / measurements +const double INTEREST_LIFETIME_REDUCTION_FACTOR = 0.8; +// how big (in ms) should be the buffer at the producer. +// increasing this number we increase the time that an +// interest will wait for the data packet to be produced +// at the producer socket +const uint32_t PRODUCER_BUFFER_MS = 200; // ms + +// interest scheduler +const uint32_t MAX_INTERESTS_IN_BATCH = 5; +const uint32_t WAIT_BETWEEN_INTEREST_BATCHES = 1000; // usec + +// packet const +const uint32_t HICN_HEADER_SIZE = 40 + 20; // IPv6 + TCP bytes +const uint32_t RTC_INTEREST_LIFETIME = 1000; + +// probes sequence range +const uint32_t MIN_PROBE_SEQ = 0xefffffff; +const uint32_t MIN_RTT_PROBE_SEQ = MIN_PROBE_SEQ; +const uint32_t MAX_RTT_PROBE_SEQ = 0xffffffff - 1; +// RTT_PROBE_INTERVAL will be used during the section while +// INIT_RTT_PROBE_INTERVAL is used at the beginning to +// quickily estimate the RTT +const uint32_t RTT_PROBE_INTERVAL = 200000; // us +const uint32_t INIT_RTT_PROBE_INTERVAL = 500; // us +const uint32_t INIT_RTT_PROBES = 40; // number of probes to init RTT +// if the produdcer is not yet started we need to probe multple times +// to get an answer. we wait 100ms between each try +const uint32_t INIT_RTT_PROBE_RESTART = 100; // ms +// once we get the first probe we wait at most 60ms for the others +const uint32_t INIT_RTT_PROBE_WAIT = 30; // ms +// we reuires at least 5 probes to be recevied +const uint32_t INIT_RTT_MIN_PROBES_TO_RECV = 5; //ms +const uint32_t MAX_PENDING_PROBES = 10; + + +// congestion +const double MAX_QUEUING_DELAY = 100.0; // ms + +// data from cache +const double MAX_DATA_FROM_CACHE = 0.25; // 25% + +// window const +const uint32_t INITIAL_WIN = 5; // pkts +const uint32_t INITIAL_WIN_MAX = 1000000; // pkts +const uint32_t WIN_MIN = 5; // pkts +const double CATCH_UP_WIN_INCREMENT = 1.2; +// used in rate control +const double WIN_DECREASE_FACTOR = 0.5; +const double WIN_INCREASE_FACTOR = 1.5; + +// round in congestion +const double ROUNDS_BEFORE_TAKE_ACTION = 5; + +// used in state +const uint8_t ROUNDS_IN_SYNC_BEFORE_SWITCH = 3; +const double PRODUCTION_RATE_FRACTION = 0.8; + +const uint32_t INIT_PACKET_SIZE = 1200; + +const double MOVING_AVG_ALPHA = 0.8; + +const double MILLI_IN_A_SEC = 1000.0; +const double MICRO_IN_A_SEC = 1000000.0; + +const double MAX_CACHED_PACKETS = 262144; // 2^18 + // about 50 sec of traffic at 50Mbps + // with 1200 bytes packets + +const uint32_t MAX_ROUND_WHIOUT_PACKETS = + (20 * MILLI_IN_A_SEC) / ROUND_LEN; // 20 sec in rounds; + +// used in ldr +const uint32_t RTC_MAX_RTX = 100; +const uint32_t RTC_MAX_AGE = 60000; // in ms +const uint64_t MAX_TIMER_RTX = ~0; +const uint32_t SENTINEL_TIMER_INTERVAL = 100; // ms +const uint32_t MAX_RTX_WITH_SENTINEL = 10; // packets +const double CATCH_UP_RTT_INCREMENT = 1.2; + +// used by producer +const uint32_t PRODUCER_STATS_INTERVAL = 200; // ms +const uint32_t MIN_PRODUCTION_RATE = 10; // pps + // min prod rate + // set running several test + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_data_path.cc b/libtransport/src/protocols/rtc/rtc_data_path.cc new file mode 100644 index 000000000..c098088a3 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_data_path.cc @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/rtc_data_path.h> +#include <stdlib.h> + +#include <algorithm> +#include <cfloat> +#include <chrono> + +#define MAX_ROUNDS_WITHOUT_PKTS 10 // 2sec + +namespace transport { + +namespace protocol { + +namespace rtc { + +RTCDataPath::RTCDataPath(uint32_t path_id) + : path_id_(path_id), + min_rtt(UINT_MAX), + prev_min_rtt(UINT_MAX), + min_owd(INT_MAX), // this is computed like in LEDBAT, so it is not the + // real OWD, but the measured one, that depends on the + // clock of sender and receiver. the only meaningful + // value is is the queueing delay. for this reason we + // keep both RTT (for the windowd calculation) and OWD + // (for congestion/quality control) + prev_min_owd(INT_MAX), + avg_owd(DBL_MAX), + queuing_delay(DBL_MAX), + jitter_(0.0), + last_owd_(0), + largest_recv_seq_(0), + largest_recv_seq_time_(0), + avg_inter_arrival_(DBL_MAX), + received_nacks_(false), + received_packets_(false), + rounds_without_packets_(0), + last_received_data_packet_(0), + RTT_history_(HISTORY_LEN), + OWD_history_(HISTORY_LEN){}; + +void RTCDataPath::insertRttSample(uint64_t rtt) { + // for the rtt we only keep track of the min one + if (rtt < min_rtt) min_rtt = rtt; + last_received_data_packet_ = + std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); +} + +void RTCDataPath::insertOwdSample(int64_t owd) { + // for owd we use both min and avg + if (owd < min_owd) min_owd = owd; + + if (avg_owd != DBL_MAX) + avg_owd = (avg_owd * (1 - ALPHA_RTC)) + (owd * ALPHA_RTC); + else { + avg_owd = owd; + } + + int64_t queueVal = owd - std::min(getMinOwd(), min_owd); + + if (queuing_delay != DBL_MAX) + queuing_delay = (queuing_delay * (1 - ALPHA_RTC)) + (queueVal * ALPHA_RTC); + else { + queuing_delay = queueVal; + } + + // keep track of the jitter computed as for RTP (RFC 3550) + int64_t diff = std::abs(owd - last_owd_); + last_owd_ = owd; + jitter_ += (1.0 / 16.0) * ((double)diff - jitter_); + + // owd is computed only for valid data packets so we count only + // this for decide if we recevie traffic or not + received_packets_ = true; +} + +void RTCDataPath::computeInterArrivalGap(uint32_t segment_number) { + // got packet in sequence, compute gap + if (largest_recv_seq_ == (segment_number - 1)) { + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint64_t delta = now - largest_recv_seq_time_; + largest_recv_seq_ = segment_number; + largest_recv_seq_time_ = now; + if (avg_inter_arrival_ == DBL_MAX) + avg_inter_arrival_ = delta; + else + avg_inter_arrival_ = + (avg_inter_arrival_ * (1 - ALPHA_RTC)) + (delta * ALPHA_RTC); + return; + } + + // ooo packet, update the stasts if needed + if (largest_recv_seq_ <= segment_number) { + largest_recv_seq_ = segment_number; + largest_recv_seq_time_ = + std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + } +} + +void RTCDataPath::receivedNack() { received_nacks_ = true; } + +double RTCDataPath::getInterArrivalGap() { + if (avg_inter_arrival_ == DBL_MAX) return 0; + return avg_inter_arrival_; +} + +bool RTCDataPath::isActive() { + if (received_nacks_ && rounds_without_packets_ < MAX_ROUNDS_WITHOUT_PKTS) + return true; + return false; +} + +bool RTCDataPath::pathToProducer() { + if (received_nacks_) return true; + return false; +} + +void RTCDataPath::roundEnd() { + // reset min_rtt and add it to the history + if (min_rtt != UINT_MAX) { + prev_min_rtt = min_rtt; + } else { + // this may happen if we do not receive any packet + // from this path in the last round. in this case + // we use the measure from the previuos round + min_rtt = prev_min_rtt; + } + + if (min_rtt == 0) min_rtt = 1; + + RTT_history_.pushBack(min_rtt); + min_rtt = UINT_MAX; + + // do the same for min owd + if (min_owd != INT_MAX) { + prev_min_owd = min_owd; + } else { + min_owd = prev_min_owd; + } + + if (min_owd != INT_MAX) { + OWD_history_.pushBack(min_owd); + min_owd = INT_MAX; + } + + if (!received_packets_) + rounds_without_packets_++; + else + rounds_without_packets_ = 0; + received_packets_ = false; +} + +uint32_t RTCDataPath::getPathId() { return path_id_; } + +double RTCDataPath::getQueuingDealy() { return queuing_delay; } + +uint64_t RTCDataPath::getMinRtt() { + if (RTT_history_.size() != 0) return RTT_history_.begin(); + return 0; +} + +int64_t RTCDataPath::getMinOwd() { + if (OWD_history_.size() != 0) return OWD_history_.begin(); + return 0; +} + +double RTCDataPath::getJitter() { return jitter_; } + +uint64_t RTCDataPath::getLastPacketTS() { return last_received_data_packet_; } + +void RTCDataPath::clearRtt() { RTT_history_.clear(); } + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_data_path.h b/libtransport/src/protocols/rtc/rtc_data_path.h new file mode 100644 index 000000000..c5c37fc0d --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_data_path.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <stdint.h> +#include <utils/min_filter.h> + +#include <climits> + +namespace transport { + +namespace protocol { + +namespace rtc { + +const double ALPHA_RTC = 0.125; +const uint32_t HISTORY_LEN = 20; // 4 sec + +class RTCDataPath { + public: + RTCDataPath(uint32_t path_id); + + public: + void insertRttSample(uint64_t rtt); + void insertOwdSample(int64_t owd); + void computeInterArrivalGap(uint32_t segment_number); + void receivedNack(); + + uint32_t getPathId(); + uint64_t getMinRtt(); + double getQueuingDealy(); + double getInterArrivalGap(); + double getJitter(); + bool isActive(); + bool pathToProducer(); + uint64_t getLastPacketTS(); + + void clearRtt(); + + void roundEnd(); + + private: + uint32_t path_id_; + + int64_t getMinOwd(); + + uint64_t min_rtt; + uint64_t prev_min_rtt; + + int64_t min_owd; + int64_t prev_min_owd; + + double avg_owd; + + double queuing_delay; + + double jitter_; + int64_t last_owd_; + + uint32_t largest_recv_seq_; + uint64_t largest_recv_seq_time_; + double avg_inter_arrival_; + + // flags to check if a path is active + // we considere a path active if it reaches a producer + //(not a cache) --aka we got at least one nack on this path-- + // and if we receives packets + bool received_nacks_; + bool received_packets_; + uint8_t rounds_without_packets_; // if we don't get any packet + // for MAX_ROUNDS_WITHOUT_PKTS + // we consider the path inactive + uint64_t last_received_data_packet_; // timestamp for the last data received + // on this path + + utils::MinFilter<uint64_t> RTT_history_; + utils::MinFilter<int64_t> OWD_history_; +}; + +} // namespace rtc + +} // namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_ldr.cc b/libtransport/src/protocols/rtc/rtc_ldr.cc new file mode 100644 index 000000000..e91b29c04 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_ldr.cc @@ -0,0 +1,427 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_ldr.h> + +#include <algorithm> +#include <unordered_set> + +namespace transport { + +namespace protocol { + +namespace rtc { + +RTCLossDetectionAndRecovery::RTCLossDetectionAndRecovery( + SendRtxCallback &&callback, asio::io_service &io_service) + : rtx_on_(false), + next_rtx_timer_(MAX_TIMER_RTX), + last_event_(0), + sentinel_timer_interval_(MAX_TIMER_RTX), + send_rtx_callback_(std::move(callback)) { + timer_ = std::make_unique<asio::steady_timer>(io_service); + sentinel_timer_ = std::make_unique<asio::steady_timer>(io_service); +} + +RTCLossDetectionAndRecovery::~RTCLossDetectionAndRecovery() {} + +void RTCLossDetectionAndRecovery::turnOnRTX() { + rtx_on_ = true; + scheduleSentinelTimer(state_->getRTT() * CATCH_UP_RTT_INCREMENT); +} + +void RTCLossDetectionAndRecovery::turnOffRTX() { + rtx_on_ = false; + clear(); +} + +void RTCLossDetectionAndRecovery::onTimeout(uint32_t seq) { + // always add timeouts to the RTX list to avoid to send the same packet as if + // it was not a rtx + addToRetransmissions(seq, seq + 1); + last_event_ = getNow(); +} + +void RTCLossDetectionAndRecovery::onDataPacketReceived( + const core::ContentObject &content_object) { + last_event_ = getNow(); + + uint32_t seq = content_object.getName().getSuffix(); + if (deleteRtx(seq)) { + state_->onPacketRecovered(seq); + } else { + if (TRANSPORT_EXPECT_FALSE(!rtx_on_)) return; // do not add if RTX is off + TRANSPORT_LOGD("received data. add from %u to %u ", + state_->getHighestSeqReceivedInOrder() + 1, seq); + addToRetransmissions(state_->getHighestSeqReceivedInOrder() + 1, seq); + } +} + +void RTCLossDetectionAndRecovery::onNackPacketReceived( + const core::ContentObject &nack) { + last_event_ = getNow(); + + uint32_t seq = nack.getName().getSuffix(); + + if (TRANSPORT_EXPECT_FALSE(!rtx_on_)) return; // do not add if RTX is off + + struct nack_packet_t *nack_pkt = + (struct nack_packet_t *)nack.getPayload()->data(); + uint32_t production_seq = nack_pkt->getProductionSegement(); + + if (production_seq > seq) { + // this is a past nack, all data before productionSeq are lost. if + // productionSeq > state_->getHighestSeqReceivedInOrder() is impossible to + // recover any packet. If this is not the case we can try to recover the + // packets between state_->getHighestSeqReceivedInOrder() and productionSeq. + // e.g.: the client receives packets 8 10 11 9 where 9 is a nack with + // productionSeq = 14. 9 is lost but we can try to recover packets 12 13 and + // 14 that are not arrived yet + deleteRtx(seq); + TRANSPORT_LOGD("received past nack. add from %u to %u ", + state_->getHighestSeqReceivedInOrder() + 1, production_seq); + addToRetransmissions(state_->getHighestSeqReceivedInOrder() + 1, + production_seq); + } else { + // future nack. here there should be a gap between the last data received + // and this packet and is it possible to recover the packets between the + // last received data and the production seq. we should not use the seq + // number of the nack since we know that is too early to ask for this seq + // number + // e.g.: // e.g.: the client receives packets 10 11 12 20 where 20 is a nack + // with productionSeq = 18. this says that all the packets between 12 and 18 + // may got lost and we should ask them + deleteRtx(seq); + TRANSPORT_LOGD("received futrue nack. add from %u to %u ", + state_->getHighestSeqReceivedInOrder() + 1, production_seq); + addToRetransmissions(state_->getHighestSeqReceivedInOrder() + 1, + production_seq); + } +} + +void RTCLossDetectionAndRecovery::onProbePacketReceived( + const core::ContentObject &probe) { + // we don't log the reception of a probe packet for the sentinel timer because + // probes are not taken into account into the sync window. we use them as + // future nacks to detect possible packets lost + if (TRANSPORT_EXPECT_FALSE(!rtx_on_)) return; // do not add if RTX is off + struct nack_packet_t *probe_pkt = + (struct nack_packet_t *)probe.getPayload()->data(); + uint32_t production_seq = probe_pkt->getProductionSegement(); + TRANSPORT_LOGD("received probe. add from %u to %u ", + state_->getHighestSeqReceivedInOrder() + 1, production_seq); + addToRetransmissions(state_->getHighestSeqReceivedInOrder() + 1, + production_seq); +} + +void RTCLossDetectionAndRecovery::clear() { + rtx_state_.clear(); + rtx_timers_.clear(); + sentinel_timer_->cancel(); + if (next_rtx_timer_ != MAX_TIMER_RTX) { + next_rtx_timer_ = MAX_TIMER_RTX; + timer_->cancel(); + } +} + +void RTCLossDetectionAndRecovery::addToRetransmissions(uint32_t start, + uint32_t stop) { + // skip nacked packets + if (start <= state_->getLastSeqNacked()) { + start = state_->getLastSeqNacked() + 1; + } + + // skip received or lost packets + if (start <= state_->getHighestSeqReceivedInOrder()) { + start = state_->getHighestSeqReceivedInOrder() + 1; + } + + for (uint32_t seq = start; seq < stop; seq++) { + if (!isRtx(seq) && // is not already an rtx + // is not received or lost + state_->isReceivedOrLost(seq) == PacketState::UNKNOWN) { + // add rtx + rtxState state; + state.first_send_ = state_->getInterestSentTime(seq); + if (state.first_send_ == 0) // this interest was never sent before + state.first_send_ = getNow(); + state.next_send_ = computeNextSend(seq, true); + state.rtx_count_ = 0; + TRANSPORT_LOGD("add %u to retransmissions. next rtx is %lu ", seq, + (state.next_send_ - getNow())); + rtx_state_.insert(std::pair<uint32_t, rtxState>(seq, state)); + rtx_timers_.insert(std::pair<uint64_t, uint32_t>(state.next_send_, seq)); + } + } + scheduleNextRtx(); +} + +uint64_t RTCLossDetectionAndRecovery::computeNextSend(uint32_t seq, + bool new_rtx) { + uint64_t now = getNow(); + if (new_rtx) { + // for the new rtx we wait one estimated IAT after the loss detection. this + // is bacause, assuming that packets arrive with a constant IAT, we should + // get a new packet every IAT + double prod_rate = state_->getProducerRate(); + uint32_t estimated_iat = SENTINEL_TIMER_INTERVAL; + uint32_t jitter = 0; + + if (prod_rate != 0) { + double packet_size = state_->getAveragePacketSize(); + estimated_iat = ceil(1000.0 / (prod_rate / packet_size)); + jitter = ceil(state_->getJitter()); + } + + uint32_t wait = estimated_iat + jitter; + TRANSPORT_LOGD("first rtx for %u in %u ms, rtt = %lu ait = %u jttr = %u", + seq, wait, state_->getRTT(), estimated_iat, jitter); + + return now + wait; + } else { + // wait one RTT + // however if the IAT is larger than the RTT, wait one IAT + uint32_t wait = SENTINEL_TIMER_INTERVAL; + + double prod_rate = state_->getProducerRate(); + if (prod_rate == 0) { + return now + SENTINEL_TIMER_INTERVAL; + } + + double packet_size = state_->getAveragePacketSize(); + uint32_t estimated_iat = ceil(1000.0 / (prod_rate / packet_size)); + + uint64_t rtt = state_->getRTT(); + if (rtt == 0) rtt = SENTINEL_TIMER_INTERVAL; + wait = rtt; + + if (estimated_iat > rtt) wait = estimated_iat; + + uint32_t jitter = ceil(state_->getJitter()); + wait += jitter; + + // it may happen that the channel is congested and we have some additional + // queuing delay to take into account + uint32_t queue = ceil(state_->getQueuing()); + wait += queue; + + TRANSPORT_LOGD( + "next rtx for %u in %u ms, rtt = %lu ait = %u jttr = %u queue = %u", + seq, wait, state_->getRTT(), estimated_iat, jitter, queue); + + return now + wait; + } +} + +void RTCLossDetectionAndRecovery::retransmit() { + if (rtx_timers_.size() == 0) return; + + uint64_t now = getNow(); + + auto it = rtx_timers_.begin(); + std::unordered_set<uint32_t> lost_pkt; + uint32_t sent_counter = 0; + while (it != rtx_timers_.end() && it->first <= now && + sent_counter < MAX_INTERESTS_IN_BATCH) { + uint32_t seq = it->second; + auto rtx_it = + rtx_state_.find(seq); // this should always return a valid iter + if (rtx_it->second.rtx_count_ >= RTC_MAX_RTX || + (now - rtx_it->second.first_send_) >= RTC_MAX_AGE || + seq < state_->getLastSeqNacked()) { + // max rtx reached or packet too old or packet nacked, this packet is lost + TRANSPORT_LOGD( + "packet %u lost because 1) max rtx: %u 2) max age: %u 3) naked: %u", + seq, (rtx_it->second.rtx_count_ >= RTC_MAX_RTX), + ((now - rtx_it->second.first_send_) >= RTC_MAX_AGE), + (seq < state_->getLastSeqNacked())); + lost_pkt.insert(seq); + it++; + } else { + // resend the packet + state_->onRetransmission(seq); + double prod_rate = state_->getProducerRate(); + if (prod_rate != 0) rtx_it->second.rtx_count_++; + rtx_it->second.next_send_ = computeNextSend(seq, false); + it = rtx_timers_.erase(it); + rtx_timers_.insert( + std::pair<uint64_t, uint32_t>(rtx_it->second.next_send_, seq)); + TRANSPORT_LOGD("send rtx for sequence %u, next send in %lu", seq, + (rtx_it->second.next_send_ - now)); + send_rtx_callback_(seq); + sent_counter++; + } + } + + // remove packets if needed + for (auto lost_it = lost_pkt.begin(); lost_it != lost_pkt.end(); lost_it++) { + uint32_t seq = *lost_it; + state_->onPacketLost(seq); + deleteRtx(seq); + } +} + +void RTCLossDetectionAndRecovery::scheduleNextRtx() { + if (rtx_timers_.size() == 0) { + // all the rtx were removed, reset timer + next_rtx_timer_ = MAX_TIMER_RTX; + return; + } + + // check if timer is alreay set + if (next_rtx_timer_ != MAX_TIMER_RTX) { + // a new check for rtx is already scheduled + if (next_rtx_timer_ > rtx_timers_.begin()->first) { + // we need to re-schedule it + timer_->cancel(); + } else { + // wait for the next timer + return; + } + } + + // set a new timer + next_rtx_timer_ = rtx_timers_.begin()->first; + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint64_t wait = 1; + if (next_rtx_timer_ != MAX_TIMER_RTX && next_rtx_timer_ > now) + wait = next_rtx_timer_ - now; + + std::weak_ptr<RTCLossDetectionAndRecovery> self(shared_from_this()); + timer_->expires_from_now(std::chrono::milliseconds(wait)); + timer_->async_wait([self](std::error_code ec) { + if (ec) return; + if (auto s = self.lock()) { + s->retransmit(); + s->next_rtx_timer_ = MAX_TIMER_RTX; + s->scheduleNextRtx(); + } + }); +} + +bool RTCLossDetectionAndRecovery::deleteRtx(uint32_t seq) { + auto it_rtx = rtx_state_.find(seq); + if (it_rtx == rtx_state_.end()) return false; // rtx not found + + uint64_t ts = it_rtx->second.next_send_; + auto it_timers = rtx_timers_.find(ts); + while (it_timers != rtx_timers_.end() && it_timers->first == ts) { + if (it_timers->second == seq) { + rtx_timers_.erase(it_timers); + break; + } + it_timers++; + } + + bool lost = it_rtx->second.rtx_count_ > 0; + rtx_state_.erase(it_rtx); + + return lost; +} + +void RTCLossDetectionAndRecovery::scheduleSentinelTimer( + uint64_t expires_from_now) { + std::weak_ptr<RTCLossDetectionAndRecovery> self(shared_from_this()); + sentinel_timer_->expires_from_now( + std::chrono::milliseconds(expires_from_now)); + sentinel_timer_->async_wait([self](std::error_code ec) { + if (ec) return; + if (auto s = self.lock()) { + s->sentinelTimer(); + } + }); +} + +void RTCLossDetectionAndRecovery::sentinelTimer() { + uint64_t now = getNow(); + + bool expired = false; + bool sent = false; + if ((now - last_event_) >= sentinel_timer_interval_) { + // at least a sentinel_timer_interval_ elapsed since last event + expired = true; + if (TRANSPORT_EXPECT_FALSE(!state_->isProducerActive())) { + // this happens at the beginning (or if the producer stops for some + // reason) we need to keep sending interest 0 until we get an answer + TRANSPORT_LOGD( + "sentinel timer: the producer is not active, send packet 0"); + state_->onRetransmission(0); + send_rtx_callback_(0); + } else { + TRANSPORT_LOGD( + "sentinel timer: the producer is active, send the 10 oldest packets"); + sent = true; + uint32_t rtx = 0; + auto it = state_->getPendingInterestsMapBegin(); + auto end = state_->getPendingInterestsMapEnd(); + while (it != end && rtx < MAX_RTX_WITH_SENTINEL) { + uint32_t seq = it->first; + TRANSPORT_LOGD("sentinel timer, add %u to the rtx list", seq); + addToRetransmissions(seq, seq + 1); + rtx++; + it++; + } + } + } else { + // sentinel timer did not expire because we registered at least one event + } + + uint32_t next_timer; + double prod_rate = state_->getProducerRate(); + if (TRANSPORT_EXPECT_FALSE(!state_->isProducerActive()) || prod_rate == 0) { + TRANSPORT_LOGD("next timer in %u", SENTINEL_TIMER_INTERVAL); + next_timer = SENTINEL_TIMER_INTERVAL; + } else { + double prod_rate = state_->getProducerRate(); + double packet_size = state_->getAveragePacketSize(); + uint32_t estimated_iat = ceil(1000.0 / (prod_rate / packet_size)); + uint32_t jitter = ceil(state_->getJitter()); + + // try to reduce the number of timers if the estimated IAT is too small + next_timer = std::max((estimated_iat + jitter) * 20, (uint32_t)1); + TRANSPORT_LOGD("next sentinel in %u ms, rate: %f, iat: %u, jitter: %u", + next_timer, ((prod_rate * 8.0) / 1000000.0), estimated_iat, + jitter); + + if (!expired) { + // discount the amout of time that is already passed + uint32_t discount = now - last_event_; + if (next_timer > discount) { + next_timer = next_timer - discount; + } else { + // in this case we trigger the timer in 1 ms + next_timer = 1; + } + TRANSPORT_LOGD("timer after discout: %u", next_timer); + } else if (sent) { + // wait at least one producer stats interval + owd to check if the + // production rate is reducing. + uint32_t min_wait = PRODUCER_STATS_INTERVAL + ceil(state_->getQueuing()); + next_timer = std::max(next_timer, min_wait); + TRANSPORT_LOGD("wait for updates from prod, next timer: %u", next_timer); + } + } + + scheduleSentinelTimer(next_timer); +} + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_ldr.h b/libtransport/src/protocols/rtc/rtc_ldr.h new file mode 100644 index 000000000..c0912303b --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_ldr.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <hicn/transport/config.h> +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/name.h> +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_state.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <functional> +#include <map> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class RTCLossDetectionAndRecovery + : public std::enable_shared_from_this<RTCLossDetectionAndRecovery> { + struct rtx_state_ { + uint64_t first_send_; + uint64_t next_send_; + uint32_t rtx_count_; + }; + + using rtxState = struct rtx_state_; + using SendRtxCallback = std::function<void(uint32_t)>; + + public: + RTCLossDetectionAndRecovery(SendRtxCallback &&callback, + asio::io_service &io_service); + + ~RTCLossDetectionAndRecovery(); + + void setState(std::shared_ptr<RTCState> state) { state_ = state; } + void turnOnRTX(); + void turnOffRTX(); + + void onTimeout(uint32_t seq); + void onDataPacketReceived(const core::ContentObject &content_object); + void onNackPacketReceived(const core::ContentObject &nack); + void onProbePacketReceived(const core::ContentObject &probe); + + void clear(); + + bool isRtx(uint32_t seq) { + if (rtx_state_.find(seq) != rtx_state_.end()) return true; + return false; + } + + private: + void addToRetransmissions(uint32_t start, uint32_t stop); + uint64_t computeNextSend(uint32_t seq, bool new_rtx); + void retransmit(); + void scheduleNextRtx(); + bool deleteRtx(uint32_t seq); + void scheduleSentinelTimer(uint64_t expires_from_now); + void sentinelTimer(); + + uint64_t getNow() { + using namespace std::chrono; + uint64_t now = + duration_cast<milliseconds>(steady_clock::now().time_since_epoch()) + .count(); + return now; + } + + // this map keeps track of the retransmitted interest, ordered from the oldest + // to the newest one. the state contains the timer of the first send of the + // interest (from pendingIntetests_), the timer of the next send (key of the + // multimap) and the number of rtx + std::map<uint32_t, rtxState> rtx_state_; + // this map stored the rtx by timer. The key is the time at which the rtx + // should be sent, and the val is the interest seq number + std::multimap<uint64_t, uint32_t> rtx_timers_; + + bool rtx_on_; + uint64_t next_rtx_timer_; + uint64_t last_event_; + uint64_t sentinel_timer_interval_; + std::unique_ptr<asio::steady_timer> timer_; + std::unique_ptr<asio::steady_timer> sentinel_timer_; + std::shared_ptr<RTCState> state_; + + SendRtxCallback send_rtx_callback_; +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_packet.h b/libtransport/src/protocols/rtc/rtc_packet.h new file mode 100644 index 000000000..abb1323a3 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_packet.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + */ + +/* data packet + * +-----------------------------------------+ + * | uint64_t: timestamp | + * | | + * +-----------------------------------------+ + * | uint32_t: prod rate (bytes per sec) | + * +-----------------------------------------+ + * | payload | + * | ... | + */ + +/* nack packet + * +-----------------------------------------+ + * | uint64_t: timestamp | + * | | + * +-----------------------------------------+ + * | uint32_t: prod rate (bytes per sec) | + * +-----------------------------------------+ + * | uint32_t: current seg in production | + * +-----------------------------------------+ + */ + +#pragma once +#include <arpa/inet.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +inline uint64_t _ntohll(const uint64_t *input) { + uint64_t return_val; + uint8_t *tmp = (uint8_t *)&return_val; + + tmp[0] = *input >> 56; + tmp[1] = *input >> 48; + tmp[2] = *input >> 40; + tmp[3] = *input >> 32; + tmp[4] = *input >> 24; + tmp[5] = *input >> 16; + tmp[6] = *input >> 8; + tmp[7] = *input >> 0; + + return return_val; +} + +inline uint64_t _htonll(const uint64_t *input) { return (_ntohll(input)); } + +const uint32_t DATA_HEADER_SIZE = 12; // bytes + // XXX: sizeof(data_packet_t) is 16 + // beacuse of padding +const uint32_t NACK_HEADER_SIZE = 16; + +struct data_packet_t { + uint64_t timestamp; + uint32_t prod_rate; + + inline uint64_t getTimestamp() const { return _ntohll(×tamp); } + inline void setTimestamp(uint64_t time) { timestamp = _htonll(&time); } + + inline uint32_t getProductionRate() const { return ntohl(prod_rate); } + inline void setProductionRate(uint32_t rate) { prod_rate = htonl(rate); } +}; + +struct nack_packet_t { + uint64_t timestamp; + uint32_t prod_rate; + uint32_t prod_seg; + + inline uint64_t getTimestamp() const { return _ntohll(×tamp); } + inline void setTimestamp(uint64_t time) { timestamp = _htonll(&time); } + + inline uint32_t getProductionRate() const { return ntohl(prod_rate); } + inline void setProductionRate(uint32_t rate) { prod_rate = htonl(rate); } + + inline uint32_t getProductionSegement() const { return ntohl(prod_seg); } + inline void setProductionSegement(uint32_t seg) { prod_seg = htonl(seg); } +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_rc.h b/libtransport/src/protocols/rtc/rtc_rc.h new file mode 100644 index 000000000..34d090092 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_rc.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <protocols/rtc/rtc_state.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class RTCRateControl : public std::enable_shared_from_this<RTCRateControl> { + public: + RTCRateControl() + : rc_on_(false), + congestion_win_(1000000), // init the win to a large number + congestion_state_(CongestionState::Normal), + protocol_state_(nullptr) {} + + virtual ~RTCRateControl() = default; + + void turnOnRateControl() { rc_on_ = true; } + void setState(std::shared_ptr<RTCState> state) { protocol_state_ = state; }; + uint32_t getCongesionWindow() { return congestion_win_; }; + + virtual void onNewRound(double round_len) = 0; + virtual void onDataPacketReceived( + const core::ContentObject &content_object) = 0; + + protected: + enum class CongestionState { Normal = 0, Underuse = 1, Congested = 2, Last }; + + protected: + bool rc_on_; + uint32_t congestion_win_; + CongestionState congestion_state_; + + std::shared_ptr<RTCState> protocol_state_; +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_rc_frame.cc b/libtransport/src/protocols/rtc/rtc_rc_frame.cc new file mode 100644 index 000000000..b577b5bea --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_rc_frame.cc @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_rc_frame.h> + +#include <algorithm> + +namespace transport { + +namespace protocol { + +namespace rtc { + +RTCRateControlFrame::RTCRateControlFrame() : cc_detector_() {} + +RTCRateControlFrame::~RTCRateControlFrame() {} + +void RTCRateControlFrame::onNewRound(double round_len) { + if (!rc_on_) return; + + CongestionState prev_congestion_state = congestion_state_; + cc_detector_.updateStats(); + congestion_state_ = (CongestionState)cc_detector_.getState(); + + if (congestion_state_ == CongestionState::Congested) { + if (prev_congestion_state == CongestionState::Normal) { + // congestion detected, notify app and init congestion win + double prod_rate = protocol_state_->getReceivedRate(); + double rtt = (double)protocol_state_->getRTT() / MILLI_IN_A_SEC; + double packet_size = protocol_state_->getAveragePacketSize(); + + if (prod_rate == 0.0 || rtt == 0.0 || packet_size == 0.0) { + // TODO do something + return; + } + + congestion_win_ = (uint32_t)ceil(prod_rate * rtt / packet_size); + } + uint32_t win = congestion_win_ * WIN_DECREASE_FACTOR; + congestion_win_ = std::max(win, WIN_MIN); + return; + } +} + +void RTCRateControlFrame::onDataPacketReceived( + const core::ContentObject &content_object) { + if (!rc_on_) return; + + uint32_t seq = content_object.getName().getSuffix(); + if (!protocol_state_->isPending(seq)) return; + + cc_detector_.addPacket(content_object); +} + +void RTCRateControlFrame::receivedBwProbeTrain(uint64_t firts_probe_ts, + uint64_t last_probe_ts, + uint32_t total_probes) { + // TODO + return; +} + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_rc_frame.h b/libtransport/src/protocols/rtc/rtc_rc_frame.h new file mode 100644 index 000000000..25d5ddbb6 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_rc_frame.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <protocols/rtc/congestion_detection.h> +#include <protocols/rtc/rtc_rc.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class RTCRateControlFrame : public RTCRateControl { + public: + RTCRateControlFrame(); + + ~RTCRateControlFrame(); + + void onNewRound(double round_len); + void onDataPacketReceived(const core::ContentObject &content_object); + + void receivedBwProbeTrain(uint64_t firts_probe_ts, uint64_t last_probe_ts, + uint32_t total_probes); + + private: + CongestionDetection cc_detector_; +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_rc_queue.cc b/libtransport/src/protocols/rtc/rtc_rc_queue.cc new file mode 100644 index 000000000..a1c89e329 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_rc_queue.cc @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_rc_queue.h> + +#include <algorithm> + +namespace transport { + +namespace protocol { + +namespace rtc { + +RTCRateControlQueue::RTCRateControlQueue() + : rounds_since_last_drop_(0), + rounds_without_congestion_(0), + last_queue_(0) {} + +RTCRateControlQueue::~RTCRateControlQueue() {} + +void RTCRateControlQueue::onNewRound(double round_len) { + if (!rc_on_) return; + + double received_rate = protocol_state_->getReceivedRate(); + double target_rate = + protocol_state_->getProducerRate() * PRODUCTION_RATE_FRACTION; + double rtt = (double)protocol_state_->getRTT() / MILLI_IN_A_SEC; + double packet_size = protocol_state_->getAveragePacketSize(); + double queue = protocol_state_->getQueuing(); + + if (rtt == 0.0) return; // no info from the producer + + CongestionState prev_congestion_state = congestion_state_; + + if (prev_congestion_state == CongestionState::Normal && + received_rate >= target_rate) { + // if the queue is high in this case we are most likelly fighting with + // a TCP flow and there is enough bandwidth to match the producer rate + congestion_state_ = CongestionState::Normal; + } else if (queue > MAX_QUEUING_DELAY || last_queue_ == queue) { + // here we detect congestion. in the case that last_queue == queue + // the consumer didn't receive any packet from the producer so we + // consider this case as congestion + // TODO: wath happen in case of high loss rate? + congestion_state_ = CongestionState::Congested; + } else { + // nothing bad is happening + congestion_state_ = CongestionState::Normal; + } + + last_queue_ = queue; + + if (congestion_state_ == CongestionState::Congested) { + if (prev_congestion_state == CongestionState::Normal) { + // init the congetion window using the received rate + congestion_win_ = (uint32_t)ceil(received_rate * rtt / packet_size); + rounds_since_last_drop_ = ROUNDS_BEFORE_TAKE_ACTION + 1; + } + + if (rounds_since_last_drop_ >= ROUNDS_BEFORE_TAKE_ACTION) { + uint32_t win = congestion_win_ * WIN_DECREASE_FACTOR; + congestion_win_ = std::max(win, WIN_MIN); + rounds_since_last_drop_ = 0; + return; + } + + rounds_since_last_drop_++; + } + + if (congestion_state_ == CongestionState::Normal) { + if (prev_congestion_state == CongestionState::Congested) { + rounds_without_congestion_ = 0; + } + + rounds_without_congestion_++; + if (rounds_without_congestion_ < ROUNDS_BEFORE_TAKE_ACTION) return; + + congestion_win_ = congestion_win_ * WIN_INCREASE_FACTOR; + congestion_win_ = std::min(congestion_win_, INITIAL_WIN_MAX); + } +} + +void RTCRateControlQueue::onDataPacketReceived( + const core::ContentObject &content_object) { + // nothing to do + return; +} + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_rc_queue.h b/libtransport/src/protocols/rtc/rtc_rc_queue.h new file mode 100644 index 000000000..407354d43 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_rc_queue.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <hicn/transport/utils/shared_ptr_utils.h> +#include <protocols/rtc/rtc_rc.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class RTCRateControlQueue : public RTCRateControl { + public: + RTCRateControlQueue(); + + ~RTCRateControlQueue(); + + void onNewRound(double round_len); + void onDataPacketReceived(const core::ContentObject &content_object); + + auto shared_from_this() { return utils::shared_from(this); } + + private: + uint32_t rounds_since_last_drop_; + uint32_t rounds_without_congestion_; + double last_queue_; +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_state.cc b/libtransport/src/protocols/rtc/rtc_state.cc new file mode 100644 index 000000000..eabf8942c --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_state.cc @@ -0,0 +1,560 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_state.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +RTCState::RTCState(ProbeHandler::SendProbeCallback &&rtt_probes_callback, + DiscoveredRttCallback &&discovered_rtt_callback, + asio::io_service &io_service) + : rtt_probes_(std::make_shared<ProbeHandler>( + std::move(rtt_probes_callback), io_service)), + discovered_rtt_callback_(std::move(discovered_rtt_callback)) { + init_rtt_timer_ = std::make_unique<asio::steady_timer>(io_service); + initParams(); +} + +RTCState::~RTCState() {} + +void RTCState::initParams() { + // packets counters (total) + sent_interests_ = 0; + sent_rtx_ = 0; + received_data_ = 0; + received_nacks_ = 0; + received_timeouts_ = 0; + received_probes_ = 0; + + // loss counters + packets_lost_ = 0; + losses_recovered_ = 0; + first_seq_in_round_ = 0; + highest_seq_received_ = 0; + highest_seq_received_in_order_ = 0; + last_seq_nacked_ = 0; + loss_rate_ = 0.0; + residual_loss_rate_ = 0.0; + + // bw counters + received_bytes_ = 0; + avg_packet_size_ = INIT_PACKET_SIZE; + production_rate_ = 0.0; + received_rate_ = 0.0; + + // nack counter + nack_on_last_round_ = false; + received_nacks_last_round_ = 0; + + // packets counter + received_packets_last_round_ = 0; + received_data_last_round_ = 0; + received_data_from_cache_ = 0; + data_from_cache_rate_ = 0; + sent_interests_last_round_ = 0; + sent_rtx_last_round_ = 0; + + // round conunters + rounds_ = 0; + rounds_without_nacks_ = 0; + rounds_without_packets_ = 0; + + last_production_seq_ = 0; + producer_is_active_ = false; + last_prod_update_ = 0; + + // paths stats + path_table_.clear(); + main_path_ = nullptr; + + // packet received + received_or_lost_packets_.clear(); + + // pending interests + pending_interests_.clear(); + + // init rtt + first_interest_sent_ = ~0; + init_rtt_ = false; + rtt_probes_->setProbes(INIT_RTT_PROBE_INTERVAL, INIT_RTT_PROBES); + rtt_probes_->sendProbes(); + setInitRttTimer(INIT_RTT_PROBE_RESTART); +} + +// packet events +void RTCState::onSendNewInterest(const core::Name *interest_name) { + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint32_t seq = interest_name->getSuffix(); + pending_interests_.insert(std::pair<uint32_t, uint64_t>(seq, now)); + + if(sent_interests_ == 0) first_interest_sent_ = now; + + sent_interests_++; + sent_interests_last_round_++; +} + +void RTCState::onTimeout(uint32_t seq) { + auto it = pending_interests_.find(seq); + if (it != pending_interests_.end()) { + pending_interests_.erase(it); + } + received_timeouts_++; +} + +void RTCState::onRetransmission(uint32_t seq) { + // remove the interest for the pendingInterest map only after the first rtx. + // in this way we can handle the ooo packets that come in late as normla + // packet. we consider a packet lost only if we sent at least an RTX for it. + // XXX this may become problematic if we stop the RTX transmissions + auto it = pending_interests_.find(seq); + if (it != pending_interests_.end()) { + pending_interests_.erase(it); + packets_lost_++; + } + sent_rtx_++; + sent_rtx_last_round_++; +} + +void RTCState::onDataPacketReceived(const core::ContentObject &content_object, + bool compute_stats) { + uint32_t seq = content_object.getName().getSuffix(); + if (compute_stats) { + updatePathStats(content_object, false); + received_data_last_round_++; + } + received_data_++; + + struct data_packet_t *data_pkt = + (struct data_packet_t *)content_object.getPayload()->data(); + uint64_t production_time = data_pkt->getTimestamp(); + if (last_prod_update_ < production_time) { + last_prod_update_ = production_time; + uint32_t production_rate = data_pkt->getProductionRate(); + production_rate_ = (double)production_rate; + } + + updatePacketSize(content_object); + updateReceivedBytes(content_object); + addRecvOrLost(seq, PacketState::RECEIVED); + + if (seq > highest_seq_received_) highest_seq_received_ = seq; + + // the producer is responding + // it is generating valid data packets so we consider it active + producer_is_active_ = true; + + received_packets_last_round_++; +} + +void RTCState::onNackPacketReceived(const core::ContentObject &nack, + bool compute_stats) { + uint32_t seq = nack.getName().getSuffix(); + struct nack_packet_t *nack_pkt = + (struct nack_packet_t *)nack.getPayload()->data(); + uint64_t production_time = nack_pkt->getTimestamp(); + uint32_t production_seq = nack_pkt->getProductionSegement(); + uint32_t production_rate = nack_pkt->getProductionRate(); + + if (TRANSPORT_EXPECT_FALSE(main_path_ == nullptr) || + last_prod_update_ < production_time) { + // update production rate + last_prod_update_ = production_time; + last_production_seq_ = production_seq; + production_rate_ = (double)production_rate; + } + + if (compute_stats) { + // this is not an RTX + updatePathStats(nack, true); + nack_on_last_round_ = true; + } + + // for statistics pourpose we log all nacks, also the one received for + // retransmitted packets + received_nacks_++; + received_nacks_last_round_++; + + if (production_seq > seq) { + // old nack, seq is lost + // update last nacked + if (last_seq_nacked_ < seq) last_seq_nacked_ = seq; + TRANSPORT_LOGD("lost packet %u beacuse of a past nack", seq); + onPacketLost(seq); + } else if (seq > production_seq) { + // future nack + // remove the nack from the pending interest map + // (the packet is not received/lost yet) + pending_interests_.erase(seq); + } else { + // this should be a quite rear event. simply remove the + // packet from the pending interest list + pending_interests_.erase(seq); + } + + // the producer is responding + // we consider it active only if the production rate is not 0 + // or the production sequence number is not 1 + if (production_rate_ != 0 || production_seq != 1) { + producer_is_active_ = true; + } + + received_packets_last_round_++; +} + +void RTCState::onPacketLost(uint32_t seq) { + TRANSPORT_LOGD("packet %u is lost", seq); + auto it = pending_interests_.find(seq); + if (it != pending_interests_.end()) { + // this packet was never retransmitted so it does + // not appear in the loss count + packets_lost_++; + } + addRecvOrLost(seq, PacketState::LOST); +} + +void RTCState::onPacketRecovered(uint32_t seq) { + losses_recovered_++; + addRecvOrLost(seq, PacketState::RECEIVED); +} + +bool RTCState::onProbePacketReceived(const core::ContentObject &probe) { + uint32_t seq = probe.getName().getSuffix(); + uint64_t rtt; + + rtt = rtt_probes_->getRtt(seq); + + if (rtt == 0) return false; // this is not a valid probe + + // like for data and nacks update the path stats. Here the RTT is computed + // by the probe handler. Both probes for rtt and bw are good to esimate + // info on the path + uint32_t path_label = probe.getPathLabel(); + + auto path_it = path_table_.find(path_label); + + // update production rate and last_seq_nacked like in case of a nack + struct nack_packet_t *probe_pkt = + (struct nack_packet_t *)probe.getPayload()->data(); + uint64_t sender_timestamp = probe_pkt->getTimestamp(); + uint32_t production_seq = probe_pkt->getProductionSegement(); + uint32_t production_rate = probe_pkt->getProductionRate(); + + + if (path_it == path_table_.end()) { + // found a new path + std::shared_ptr<RTCDataPath> newPath = + std::make_shared<RTCDataPath>(path_label); + auto ret = path_table_.insert( + std::pair<uint32_t, std::shared_ptr<RTCDataPath>>(path_label, newPath)); + path_it = ret.first; + } + + auto path = path_it->second; + + path->insertRttSample(rtt); + path->receivedNack(); + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + int64_t OWD = now - sender_timestamp; + path->insertOwdSample(OWD); + + if (last_prod_update_ < sender_timestamp) { + last_production_seq_ = production_seq; + last_prod_update_ = sender_timestamp; + production_rate_ = (double)production_rate; + } + + // the producer is responding + // we consider it active only if the production rate is not 0 + // or the production sequence numner is not 1 + if (production_rate_ != 0 || production_seq != 1) { + producer_is_active_ = true; + } + + // check for init RTT. if received_probes_ is equal to 0 schedule a timer to + // wait for the INIT_RTT_PROBES. in this way if some probes get lost we don't + // wait forever + received_probes_++; + + if(!init_rtt_ && received_probes_ <= INIT_RTT_PROBES){ + if(received_probes_ == 1){ + // we got the first probe, wait at most INIT_RTT_PROBE_WAIT sec for the others + main_path_ = path; + setInitRttTimer(INIT_RTT_PROBE_WAIT); + } + if(received_probes_ == INIT_RTT_PROBES) { + // we are done + init_rtt_timer_->cancel(); + checkInitRttTimer(); + } + } + + received_packets_last_round_++; + + // ignore probes sent before the first interest + if((now - rtt) <= first_interest_sent_) return false; + return true; +} + +void RTCState::onNewRound(double round_len, bool in_sync) { + // XXX + // here we take into account only the single path case so we assume that we + // don't use two paths in parellel for this single flow + + if (path_table_.empty()) return; + + double bytes_per_sec = + ((double)received_bytes_ * (MILLI_IN_A_SEC / round_len)); + if(received_rate_ == 0) + received_rate_ = bytes_per_sec; + else + received_rate_ = (received_rate_ * MOVING_AVG_ALPHA) + + ((1 - MOVING_AVG_ALPHA) * bytes_per_sec); + + // search for an active path. There should be only one active path (meaning a + // path that leads to the producer socket -no cache- and from which we are + // currently getting data packets) at any time. However it may happen that + // there are mulitple active paths in case of mobility (the old path will + // remain active for a short ammount of time). The main path is selected as + // the active path from where the consumer received the latest data packet + + uint64_t last_packet_ts = 0; + main_path_ = nullptr; + + for (auto it = path_table_.begin(); it != path_table_.end(); it++) { + it->second->roundEnd(); + if (it->second->isActive()) { + uint64_t ts = it->second->getLastPacketTS(); + if (ts > last_packet_ts) { + last_packet_ts = ts; + main_path_ = it->second; + } + } + } + + if (in_sync) updateLossRate(); + + // handle nacks + if (!nack_on_last_round_ && received_bytes_ > 0) { + rounds_without_nacks_++; + } else { + rounds_without_nacks_ = 0; + } + + // check if the producer is active + if (received_packets_last_round_ != 0) { + rounds_without_packets_ = 0; + } else { + rounds_without_packets_++; + if (rounds_without_packets_ >= MAX_ROUND_WHIOUT_PACKETS && + producer_is_active_ != false) { + initParams(); + } + } + + // compute cache/producer ratio + if (received_data_last_round_ != 0) { + double new_rate = + (double)received_data_from_cache_ / (double)received_data_last_round_; + data_from_cache_rate_ = data_from_cache_rate_ * MOVING_AVG_ALPHA + + (new_rate * (1 - MOVING_AVG_ALPHA)); + } + + // reset counters + received_bytes_ = 0; + packets_lost_ = 0; + losses_recovered_ = 0; + first_seq_in_round_ = highest_seq_received_; + + nack_on_last_round_ = false; + received_nacks_last_round_ = 0; + + received_packets_last_round_ = 0; + received_data_last_round_ = 0; + received_data_from_cache_ = 0; + sent_interests_last_round_ = 0; + sent_rtx_last_round_ = 0; + + rounds_++; +} + +void RTCState::updateReceivedBytes(const core::ContentObject &content_object) { + received_bytes_ += + (uint32_t)(content_object.headerSize() + content_object.payloadSize()); +} + +void RTCState::updatePacketSize(const core::ContentObject &content_object) { + uint32_t pkt_size = + (uint32_t)(content_object.headerSize() + content_object.payloadSize()); + avg_packet_size_ = (MOVING_AVG_ALPHA * avg_packet_size_) + + ((1 - MOVING_AVG_ALPHA) * pkt_size); +} + +void RTCState::updatePathStats(const core::ContentObject &content_object, + bool is_nack) { + // get packet path + uint32_t path_label = content_object.getPathLabel(); + auto path_it = path_table_.find(path_label); + + if (path_it == path_table_.end()) { + // found a new path + std::shared_ptr<RTCDataPath> newPath = + std::make_shared<RTCDataPath>(path_label); + auto ret = path_table_.insert( + std::pair<uint32_t, std::shared_ptr<RTCDataPath>>(path_label, newPath)); + path_it = ret.first; + } + + auto path = path_it->second; + + // compute rtt + uint32_t seq = content_object.getName().getSuffix(); + uint64_t interest_sent_time = getInterestSentTime(seq); + if (interest_sent_time == 0) + return; // this should not happen, + // it means that we are processing an interest + // that is not pending + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + uint64_t RTT = now - interest_sent_time; + + path->insertRttSample(RTT); + + // compute OWD (the first part of the nack and data packet header are the + // same, so we cast to data data packet) + struct data_packet_t *packet = + (struct data_packet_t *)content_object.getPayload()->data(); + uint64_t sender_timestamp = packet->getTimestamp(); + int64_t OWD = now - sender_timestamp; + path->insertOwdSample(OWD); + + // compute IAT or set path to producer + if (!is_nack) { + // compute the iat only for the content packets + uint32_t segment_number = content_object.getName().getSuffix(); + path->computeInterArrivalGap(segment_number); + if (!path->pathToProducer()) received_data_from_cache_++; + } else { + path->receivedNack(); + } +} + +void RTCState::updateLossRate() { + loss_rate_ = 0.0; + residual_loss_rate_ = 0.0; + + uint32_t number_theorically_received_packets_ = + highest_seq_received_ - first_seq_in_round_; + + // in this case no new packet was recevied after the previuos round, avoid + // division by 0 + if (number_theorically_received_packets_ == 0) return; + + loss_rate_ = (double)((double)(packets_lost_) / + (double)number_theorically_received_packets_); + + residual_loss_rate_ = (double)((double)(packets_lost_ - losses_recovered_) / + (double)number_theorically_received_packets_); + + if (residual_loss_rate_ < 0) residual_loss_rate_ = 0; +} + +void RTCState::addRecvOrLost(uint32_t seq, PacketState state) { + pending_interests_.erase(seq); + if (received_or_lost_packets_.size() >= MAX_CACHED_PACKETS) { + received_or_lost_packets_.erase(received_or_lost_packets_.begin()); + } + // notice that it may happen that a packet that we consider lost arrives after + // some time, in this case we simply overwrite the packet state. + received_or_lost_packets_[seq] = state; + + // keep track of the last packet received/lost + // without holes. + if (highest_seq_received_in_order_ < last_seq_nacked_) { + highest_seq_received_in_order_ = last_seq_nacked_; + } + + if ((highest_seq_received_in_order_ + 1) == seq) { + highest_seq_received_in_order_ = seq; + } else if (seq <= highest_seq_received_in_order_) { + // here we do nothing + } else if (seq > highest_seq_received_in_order_) { + // 1) there is a gap in the sequence so we do not update largest_in_seq_ + // 2) all the packets from largest_in_seq_ to seq are in + // received_or_lost_packets_ an we upate largest_in_seq_ + + for (uint32_t i = highest_seq_received_in_order_ + 1; i <= seq; i++) { + if (received_or_lost_packets_.find(i) == + received_or_lost_packets_.end()) { + break; + } + // this packet is in order so we can update the + // highest_seq_received_in_order_ + highest_seq_received_in_order_ = i; + } + } +} + +void RTCState::setInitRttTimer(uint32_t wait){ + init_rtt_timer_->cancel(); + init_rtt_timer_->expires_from_now(std::chrono::milliseconds(wait)); + init_rtt_timer_->async_wait([this](std::error_code ec) { + if(ec) return; + checkInitRttTimer(); + }); +} + +void RTCState::checkInitRttTimer() { + if(received_probes_ < INIT_RTT_MIN_PROBES_TO_RECV){ + // we didn't received enough probes, restart + received_probes_ = 0; + rtt_probes_->setProbes(INIT_RTT_PROBE_INTERVAL, INIT_RTT_PROBES); + rtt_probes_->sendProbes(); + setInitRttTimer(INIT_RTT_PROBE_RESTART); + return; + } + init_rtt_ = true; + main_path_->roundEnd(); + rtt_probes_->setProbes(RTT_PROBE_INTERVAL, 0); + rtt_probes_->sendProbes(); + + // init last_seq_nacked_. skip packets that may come from the cache + double prod_rate = getProducerRate(); + double rtt = (double)getRTT() / MILLI_IN_A_SEC; + double packet_size = getAveragePacketSize(); + uint32_t pkt_in_rtt_ = std::floor(((prod_rate / packet_size) * rtt) * 0.8); + last_seq_nacked_ = last_production_seq_ + pkt_in_rtt_; + + discovered_rtt_callback_(); +} + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_state.h b/libtransport/src/protocols/rtc/rtc_state.h new file mode 100644 index 000000000..943a0a113 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_state.h @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <hicn/transport/config.h> +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/name.h> +#include <protocols/rtc/probe_handler.h> +#include <protocols/rtc/rtc_data_path.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <map> +#include <set> + +namespace transport { + +namespace protocol { + +namespace rtc { + +enum class PacketState : uint8_t { RECEIVED, LOST, UNKNOWN }; + +class RTCState : std::enable_shared_from_this<RTCState> { + public: + using DiscoveredRttCallback = std::function<void()>; + public: + RTCState(ProbeHandler::SendProbeCallback &&rtt_probes_callback, + DiscoveredRttCallback &&discovered_rtt_callback, + asio::io_service &io_service); + + ~RTCState(); + + // packet events + void onSendNewInterest(const core::Name *interest_name); + void onTimeout(uint32_t seq); + void onRetransmission(uint32_t seq); + void onDataPacketReceived(const core::ContentObject &content_object, + bool compute_stats); + void onNackPacketReceived(const core::ContentObject &nack, + bool compute_stats); + void onPacketLost(uint32_t seq); + void onPacketRecovered(uint32_t seq); + bool onProbePacketReceived(const core::ContentObject &probe); + + // protocol state + void onNewRound(double round_len, bool in_sync); + + // main path + uint32_t getProducerPath() const { + if (mainPathIsValid()) return main_path_->getPathId(); + return 0; + } + + // delay metrics + bool isRttDiscovered() const { + return init_rtt_; + } + + uint64_t getRTT() const { + if (mainPathIsValid()) return main_path_->getMinRtt(); + return 0; + } + void resetRttStats() { + if (mainPathIsValid()) main_path_->clearRtt(); + } + + double getQueuing() const { + if (mainPathIsValid()) return main_path_->getQueuingDealy(); + return 0.0; + } + double getIAT() const { + if (mainPathIsValid()) return main_path_->getInterArrivalGap(); + return 0.0; + } + + double getJitter() const { + if (mainPathIsValid()) return main_path_->getJitter(); + return 0.0; + } + + // pending interests + uint64_t getInterestSentTime(uint32_t seq) { + auto it = pending_interests_.find(seq); + if (it != pending_interests_.end()) return it->second; + return 0; + } + bool isPending(uint32_t seq) { + if (pending_interests_.find(seq) != pending_interests_.end()) return true; + return false; + } + uint32_t getPendingInterestNumber() const { + return pending_interests_.size(); + } + PacketState isReceivedOrLost(uint32_t seq) { + auto it = received_or_lost_packets_.find(seq); + if (it != received_or_lost_packets_.end()) return it->second; + return PacketState::UNKNOWN; + } + + // loss rate + double getLossRate() const { return loss_rate_; } + double getResidualLossRate() const { return residual_loss_rate_; } + uint32_t getHighestSeqReceivedInOrder() const { + return highest_seq_received_in_order_; + } + uint32_t getLostData() const { return packets_lost_; }; + uint32_t getRecoveredLosses() const { return losses_recovered_; } + + // generic stats + uint32_t getReceivedBytesInRound() const { return received_bytes_; } + uint32_t getReceivedNacksInRound() const { + return received_nacks_last_round_; + } + uint32_t getSentInterestInRound() const { return sent_interests_last_round_; } + uint32_t getSentRtxInRound() const { return sent_rtx_last_round_; } + + // bandwidth/production metrics + double getAvailableBw() const { return 0.0; }; // TODO + double getProducerRate() const { return production_rate_; } + double getReceivedRate() const { return received_rate_; } + double getAveragePacketSize() const { return avg_packet_size_; } + + // nacks + uint32_t getRoundsWithoutNacks() const { return rounds_without_nacks_; } + uint32_t getLastSeqNacked() const { return last_seq_nacked_; } + + // producer state + bool isProducerActive() const { return producer_is_active_; } + + // packets from cache + double getPacketFromCacheRatio() const { return data_from_cache_rate_; } + + std::map<uint32_t, uint64_t>::iterator getPendingInterestsMapBegin() { + return pending_interests_.begin(); + } + std::map<uint32_t, uint64_t>::iterator getPendingInterestsMapEnd() { + return pending_interests_.end(); + } + + private: + void initParams(); + + // update stats + void updateState(); + void updateReceivedBytes(const core::ContentObject &content_object); + void updatePacketSize(const core::ContentObject &content_object); + void updatePathStats(const core::ContentObject &content_object, bool is_nack); + void updateLossRate(); + + void addRecvOrLost(uint32_t seq, PacketState state); + + void setInitRttTimer(uint32_t wait); + void checkInitRttTimer(); + + bool mainPathIsValid() const { + if (main_path_ != nullptr) + return true; + else + return false; + } + + // packets counters (total) + uint32_t sent_interests_; + uint32_t sent_rtx_; + uint32_t received_data_; + uint32_t received_nacks_; + uint32_t received_timeouts_; + uint32_t received_probes_; + + // loss counters + int32_t packets_lost_; + int32_t losses_recovered_; + uint32_t first_seq_in_round_; + uint32_t highest_seq_received_; + uint32_t highest_seq_received_in_order_; + uint32_t last_seq_nacked_; // segment for which we got an oldNack + double loss_rate_; + double residual_loss_rate_; + + // bw counters + uint32_t received_bytes_; + double avg_packet_size_; + double production_rate_; // rate communicated by the producer using nacks + double received_rate_; // rate recevied by the consumer + + // nack counter + // the bool takes tracks only about the valid nacks (no rtx) and it is used to + // switch between the states. Instead received_nacks_last_round_ logs all the + // nacks for statistics + bool nack_on_last_round_; + uint32_t received_nacks_last_round_; + + // packets counter + uint32_t received_packets_last_round_; + uint32_t received_data_last_round_; + uint32_t received_data_from_cache_; + double data_from_cache_rate_; + uint32_t sent_interests_last_round_; + uint32_t sent_rtx_last_round_; + + // round conunters + uint32_t rounds_; + uint32_t rounds_without_nacks_; + uint32_t rounds_without_packets_; + + // init rtt + uint64_t first_interest_sent_; + + // producer state + bool + producer_is_active_; // the prodcuer is active if we receive some packets + uint32_t last_production_seq_; // last production seq received by the producer + uint64_t last_prod_update_; // timestamp of the last packets used to update + // stats from the producer + + // paths stats + std::unordered_map<uint32_t, std::shared_ptr<RTCDataPath>> path_table_; + std::shared_ptr<RTCDataPath> main_path_; + + // packet received + // cache where to store info about the last MAX_CACHED_PACKETS + std::map<uint32_t, PacketState> received_or_lost_packets_; + + // pending interests + std::map<uint32_t, uint64_t> pending_interests_; + + // probes + std::shared_ptr<ProbeHandler> rtt_probes_; + bool init_rtt_; + std::unique_ptr<asio::steady_timer> init_rtt_timer_; + + // callbacks + DiscoveredRttCallback discovered_rtt_callback_; +}; + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/trendline_estimator.cc b/libtransport/src/protocols/rtc/trendline_estimator.cc new file mode 100644 index 000000000..7a0803857 --- /dev/null +++ b/libtransport/src/protocols/rtc/trendline_estimator.cc @@ -0,0 +1,334 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// FROM +// https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/congestion_controller/goog_cc/trendline_estimator.cc + +#include "trendline_estimator.h" + +#include <math.h> + +#include <algorithm> +#include <string> + +namespace transport { + +namespace protocol { + +namespace rtc { + +// Parameters for linear least squares fit of regression line to noisy data. +constexpr double kDefaultTrendlineSmoothingCoeff = 0.9; +constexpr double kDefaultTrendlineThresholdGain = 4.0; +// const char kBweWindowSizeInPacketsExperiment[] = +// "WebRTC-BweWindowSizeInPackets"; + +/*size_t ReadTrendlineFilterWindowSize( + const WebRtcKeyValueConfig* key_value_config) { + std::string experiment_string = + key_value_config->Lookup(kBweWindowSizeInPacketsExperiment); + size_t window_size; + int parsed_values = + sscanf(experiment_string.c_str(), "Enabled-%zu", &window_size); + if (parsed_values == 1) { + if (window_size > 1) + return window_size; + RTC_LOG(WARNING) << "Window size must be greater than 1."; + } + RTC_LOG(LS_WARNING) << "Failed to parse parameters for BweWindowSizeInPackets" + " experiment from field trial string. Using default."; + return TrendlineEstimatorSettings::kDefaultTrendlineWindowSize; +} +*/ + +OptionalDouble LinearFitSlope( + const std::deque<TrendlineEstimator::PacketTiming>& packets) { + // RTC_DCHECK(packets.size() >= 2); + // Compute the "center of mass". + double sum_x = 0; + double sum_y = 0; + for (const auto& packet : packets) { + sum_x += packet.arrival_time_ms; + sum_y += packet.smoothed_delay_ms; + } + double x_avg = sum_x / packets.size(); + double y_avg = sum_y / packets.size(); + // Compute the slope k = \sum (x_i-x_avg)(y_i-y_avg) / \sum (x_i-x_avg)^2 + double numerator = 0; + double denominator = 0; + for (const auto& packet : packets) { + double x = packet.arrival_time_ms; + double y = packet.smoothed_delay_ms; + numerator += (x - x_avg) * (y - y_avg); + denominator += (x - x_avg) * (x - x_avg); + } + if (denominator == 0) return OptionalDouble(); + return OptionalDouble(numerator / denominator); +} + +OptionalDouble ComputeSlopeCap( + const std::deque<TrendlineEstimator::PacketTiming>& packets, + const TrendlineEstimatorSettings& settings) { + /*RTC_DCHECK(1 <= settings.beginning_packets && + settings.beginning_packets < packets.size()); + RTC_DCHECK(1 <= settings.end_packets && + settings.end_packets < packets.size()); + RTC_DCHECK(settings.beginning_packets + settings.end_packets <= + packets.size());*/ + TrendlineEstimator::PacketTiming early = packets[0]; + for (size_t i = 1; i < settings.beginning_packets; ++i) { + if (packets[i].raw_delay_ms < early.raw_delay_ms) early = packets[i]; + } + size_t late_start = packets.size() - settings.end_packets; + TrendlineEstimator::PacketTiming late = packets[late_start]; + for (size_t i = late_start + 1; i < packets.size(); ++i) { + if (packets[i].raw_delay_ms < late.raw_delay_ms) late = packets[i]; + } + if (late.arrival_time_ms - early.arrival_time_ms < 1) { + return OptionalDouble(); + } + return OptionalDouble((late.raw_delay_ms - early.raw_delay_ms) / + (late.arrival_time_ms - early.arrival_time_ms) + + settings.cap_uncertainty); +} + +constexpr double kMaxAdaptOffsetMs = 15.0; +constexpr double kOverUsingTimeThreshold = 10; +constexpr int kMinNumDeltas = 60; +constexpr int kDeltaCounterMax = 1000; + +//} // namespace + +constexpr char TrendlineEstimatorSettings::kKey[]; + +TrendlineEstimatorSettings::TrendlineEstimatorSettings( + /*const WebRtcKeyValueConfig* key_value_config*/) { + /*if (absl::StartsWith( + key_value_config->Lookup(kBweWindowSizeInPacketsExperiment), + "Enabled")) { + window_size = ReadTrendlineFilterWindowSize(key_value_config); + } + Parser()->Parse(key_value_config->Lookup(TrendlineEstimatorSettings::kKey));*/ + window_size = kDefaultTrendlineWindowSize; + enable_cap = false; + beginning_packets = end_packets = 0; + cap_uncertainty = 0.0; + + /*if (window_size < 10 || 200 < window_size) { + RTC_LOG(LS_WARNING) << "Window size must be between 10 and 200 packets"; + window_size = kDefaultTrendlineWindowSize; + } + if (enable_cap) { + if (beginning_packets < 1 || end_packets < 1 || + beginning_packets > window_size || end_packets > window_size) { + RTC_LOG(LS_WARNING) << "Size of beginning and end must be between 1 and " + << window_size; + enable_cap = false; + beginning_packets = end_packets = 0; + cap_uncertainty = 0.0; + } + if (beginning_packets + end_packets > window_size) { + RTC_LOG(LS_WARNING) + << "Size of beginning plus end can't exceed the window size"; + enable_cap = false; + beginning_packets = end_packets = 0; + cap_uncertainty = 0.0; + } + if (cap_uncertainty < 0.0 || 0.025 < cap_uncertainty) { + RTC_LOG(LS_WARNING) << "Cap uncertainty must be between 0 and 0.025"; + cap_uncertainty = 0.0; + } + }*/ +} + +/*std::unique_ptr<StructParametersParser> TrendlineEstimatorSettings::Parser() { + return StructParametersParser::Create("sort", &enable_sort, // + "cap", &enable_cap, // + "beginning_packets", + &beginning_packets, // + "end_packets", &end_packets, // + "cap_uncertainty", &cap_uncertainty, // + "window_size", &window_size); +}*/ + +TrendlineEstimator::TrendlineEstimator( + /*const WebRtcKeyValueConfig* key_value_config, + NetworkStatePredictor* network_state_predictor*/) + : settings_(), + smoothing_coef_(kDefaultTrendlineSmoothingCoeff), + threshold_gain_(kDefaultTrendlineThresholdGain), + num_of_deltas_(0), + first_arrival_time_ms_(-1), + accumulated_delay_(0), + smoothed_delay_(0), + delay_hist_(), + k_up_(0.0087), + k_down_(0.039), + overusing_time_threshold_(kOverUsingTimeThreshold), + threshold_(12.5), + prev_modified_trend_(NAN), + last_update_ms_(-1), + prev_trend_(0.0), + time_over_using_(-1), + overuse_counter_(0), + hypothesis_(BandwidthUsage::kBwNormal){ + // hypothesis_predicted_(BandwidthUsage::kBwNormal){//}, + // network_state_predictor_(network_state_predictor) { + /* RTC_LOG(LS_INFO) + << "Using Trendline filter for delay change estimation with settings " + << settings_.Parser()->Encode() << " and " + // << (network_state_predictor_ ? "injected" : "no") + << " network state predictor";*/ +} + +TrendlineEstimator::~TrendlineEstimator() {} + +void TrendlineEstimator::UpdateTrendline(double recv_delta_ms, + double send_delta_ms, + int64_t send_time_ms, + int64_t arrival_time_ms, + size_t packet_size) { + const double delta_ms = recv_delta_ms - send_delta_ms; + ++num_of_deltas_; + num_of_deltas_ = std::min(num_of_deltas_, kDeltaCounterMax); + if (first_arrival_time_ms_ == -1) first_arrival_time_ms_ = arrival_time_ms; + + // Exponential backoff filter. + accumulated_delay_ += delta_ms; + // BWE_TEST_LOGGING_PLOT(1, "accumulated_delay_ms", arrival_time_ms, + // accumulated_delay_); + smoothed_delay_ = smoothing_coef_ * smoothed_delay_ + + (1 - smoothing_coef_) * accumulated_delay_; + // BWE_TEST_LOGGING_PLOT(1, "smoothed_delay_ms", arrival_time_ms, + // smoothed_delay_); + + // Maintain packet window + delay_hist_.emplace_back( + static_cast<double>(arrival_time_ms - first_arrival_time_ms_), + smoothed_delay_, accumulated_delay_); + if (settings_.enable_sort) { + for (size_t i = delay_hist_.size() - 1; + i > 0 && + delay_hist_[i].arrival_time_ms < delay_hist_[i - 1].arrival_time_ms; + --i) { + std::swap(delay_hist_[i], delay_hist_[i - 1]); + } + } + if (delay_hist_.size() > settings_.window_size) delay_hist_.pop_front(); + + // Simple linear regression. + double trend = prev_trend_; + if (delay_hist_.size() == settings_.window_size) { + // Update trend_ if it is possible to fit a line to the data. The delay + // trend can be seen as an estimate of (send_rate - capacity)/capacity. + // 0 < trend < 1 -> the delay increases, queues are filling up + // trend == 0 -> the delay does not change + // trend < 0 -> the delay decreases, queues are being emptied + OptionalDouble trendO = LinearFitSlope(delay_hist_); + if (trendO.has_value()) trend = trendO.value(); + if (settings_.enable_cap) { + OptionalDouble cap = ComputeSlopeCap(delay_hist_, settings_); + // We only use the cap to filter out overuse detections, not + // to detect additional underuses. + if (trend >= 0 && cap.has_value() && trend > cap.value()) { + trend = cap.value(); + } + } + } + // BWE_TEST_LOGGING_PLOT(1, "trendline_slope", arrival_time_ms, trend); + + Detect(trend, send_delta_ms, arrival_time_ms); +} + +void TrendlineEstimator::Update(double recv_delta_ms, double send_delta_ms, + int64_t send_time_ms, int64_t arrival_time_ms, + size_t packet_size, bool calculated_deltas) { + if (calculated_deltas) { + UpdateTrendline(recv_delta_ms, send_delta_ms, send_time_ms, arrival_time_ms, + packet_size); + } + /*if (network_state_predictor_) { + hypothesis_predicted_ = network_state_predictor_->Update( + send_time_ms, arrival_time_ms, hypothesis_); + }*/ +} + +BandwidthUsage TrendlineEstimator::State() const { + return /*network_state_predictor_ ? hypothesis_predicted_ :*/ hypothesis_; +} + +void TrendlineEstimator::Detect(double trend, double ts_delta, int64_t now_ms) { + /*if (num_of_deltas_ < 2) { + hypothesis_ = BandwidthUsage::kBwNormal; + return; + }*/ + + const double modified_trend = + std::min(num_of_deltas_, kMinNumDeltas) * trend * threshold_gain_; + prev_modified_trend_ = modified_trend; + // BWE_TEST_LOGGING_PLOT(1, "T", now_ms, modified_trend); + // BWE_TEST_LOGGING_PLOT(1, "threshold", now_ms, threshold_); + if (modified_trend > threshold_) { + if (time_over_using_ == -1) { + // Initialize the timer. Assume that we've been + // over-using half of the time since the previous + // sample. + time_over_using_ = ts_delta / 2; + } else { + // Increment timer + time_over_using_ += ts_delta; + } + overuse_counter_++; + if (time_over_using_ > overusing_time_threshold_ && overuse_counter_ > 1) { + if (trend >= prev_trend_) { + time_over_using_ = 0; + overuse_counter_ = 0; + hypothesis_ = BandwidthUsage::kBwOverusing; + } + } + } else if (modified_trend < -threshold_) { + time_over_using_ = -1; + overuse_counter_ = 0; + hypothesis_ = BandwidthUsage::kBwUnderusing; + } else { + time_over_using_ = -1; + overuse_counter_ = 0; + hypothesis_ = BandwidthUsage::kBwNormal; + } + prev_trend_ = trend; + UpdateThreshold(modified_trend, now_ms); +} + +void TrendlineEstimator::UpdateThreshold(double modified_trend, + int64_t now_ms) { + if (last_update_ms_ == -1) last_update_ms_ = now_ms; + + if (fabs(modified_trend) > threshold_ + kMaxAdaptOffsetMs) { + // Avoid adapting the threshold to big latency spikes, caused e.g., + // by a sudden capacity drop. + last_update_ms_ = now_ms; + return; + } + + const double k = fabs(modified_trend) < threshold_ ? k_down_ : k_up_; + const int64_t kMaxTimeDeltaMs = 100; + int64_t time_delta_ms = std::min(now_ms - last_update_ms_, kMaxTimeDeltaMs); + threshold_ += k * (fabs(modified_trend) - threshold_) * time_delta_ms; + if (threshold_ < 6.f) threshold_ = 6.f; + if (threshold_ > 600.f) threshold_ = 600.f; + // threshold_ = rtc::SafeClamp(threshold_, 6.f, 600.f); + last_update_ms_ = now_ms; +} + +} // namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/trendline_estimator.h b/libtransport/src/protocols/rtc/trendline_estimator.h new file mode 100644 index 000000000..372acbc67 --- /dev/null +++ b/libtransport/src/protocols/rtc/trendline_estimator.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// FROM +// https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/congestion_controller/goog_cc/trendline_estimator.h + +#ifndef MODULES_CONGESTION_CONTROLLER_GOOG_CC_TRENDLINE_ESTIMATOR_H_ +#define MODULES_CONGESTION_CONTROLLER_GOOG_CC_TRENDLINE_ESTIMATOR_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <deque> +#include <memory> +#include <utility> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class OptionalDouble { + public: + OptionalDouble() : val(0), has_val(false){}; + OptionalDouble(double val) : val(val), has_val(true){}; + + double value() { return val; } + bool has_value() { return has_val; } + + private: + double val; + bool has_val; +}; + +enum class BandwidthUsage { + kBwNormal = 0, + kBwUnderusing = 1, + kBwOverusing = 2, + kLast +}; + +struct TrendlineEstimatorSettings { + static constexpr char kKey[] = "WebRTC-Bwe-TrendlineEstimatorSettings"; + static constexpr unsigned kDefaultTrendlineWindowSize = 20; + + // TrendlineEstimatorSettings() = delete; + TrendlineEstimatorSettings( + /*const WebRtcKeyValueConfig* key_value_config*/); + + // Sort the packets in the window. Should be redundant, + // but then almost no cost. + bool enable_sort = false; + + // Cap the trendline slope based on the minimum delay seen + // in the beginning_packets and end_packets respectively. + bool enable_cap = false; + unsigned beginning_packets = 7; + unsigned end_packets = 7; + double cap_uncertainty = 0.0; + + // Size (in packets) of the window. + unsigned window_size = kDefaultTrendlineWindowSize; + + // std::unique_ptr<StructParametersParser> Parser(); +}; + +class TrendlineEstimator /*: public DelayIncreaseDetectorInterface */ { + public: + TrendlineEstimator(/*const WebRtcKeyValueConfig* key_value_config, + NetworkStatePredictor* network_state_predictor*/); + + ~TrendlineEstimator(); + + // Update the estimator with a new sample. The deltas should represent deltas + // between timestamp groups as defined by the InterArrival class. + void Update(double recv_delta_ms, double send_delta_ms, int64_t send_time_ms, + int64_t arrival_time_ms, size_t packet_size, + bool calculated_deltas); + + void UpdateTrendline(double recv_delta_ms, double send_delta_ms, + int64_t send_time_ms, int64_t arrival_time_ms, + size_t packet_size); + + BandwidthUsage State() const; + + struct PacketTiming { + PacketTiming(double arrival_time_ms, double smoothed_delay_ms, + double raw_delay_ms) + : arrival_time_ms(arrival_time_ms), + smoothed_delay_ms(smoothed_delay_ms), + raw_delay_ms(raw_delay_ms) {} + double arrival_time_ms; + double smoothed_delay_ms; + double raw_delay_ms; + }; + + private: + // friend class GoogCcStatePrinter; + void Detect(double trend, double ts_delta, int64_t now_ms); + + void UpdateThreshold(double modified_offset, int64_t now_ms); + + // Parameters. + TrendlineEstimatorSettings settings_; + const double smoothing_coef_; + const double threshold_gain_; + // Used by the existing threshold. + int num_of_deltas_; + // Keep the arrival times small by using the change from the first packet. + int64_t first_arrival_time_ms_; + // Exponential backoff filtering. + double accumulated_delay_; + double smoothed_delay_; + // Linear least squares regression. + std::deque<PacketTiming> delay_hist_; + + const double k_up_; + const double k_down_; + double overusing_time_threshold_; + double threshold_; + double prev_modified_trend_; + int64_t last_update_ms_; + double prev_trend_; + double time_over_using_; + int overuse_counter_; + BandwidthUsage hypothesis_; + // BandwidthUsage hypothesis_predicted_; + // NetworkStatePredictor* network_state_predictor_; + + // RTC_DISALLOW_COPY_AND_ASSIGN(TrendlineEstimator); +}; + +} // namespace rtc + +} // end namespace protocol + +} // end namespace transport +#endif // MODULES_CONGESTION_CONTROLLER_GOOG_CC_TRENDLINE_ESTIMATOR_H_ diff --git a/libtransport/src/protocols/transport_protocol.cc b/libtransport/src/protocols/transport_protocol.cc new file mode 100644 index 000000000..611c39212 --- /dev/null +++ b/libtransport/src/protocols/transport_protocol.cc @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/interfaces/socket_consumer.h> +#include <implementation/socket_consumer.h> +#include <protocols/transport_protocol.h> + +namespace transport { + +namespace protocol { + +using namespace interface; + +TransportProtocol::TransportProtocol(implementation::ConsumerSocket *icn_socket, + Reassembly *reassembly_protocol) + : socket_(icn_socket), + reassembly_protocol_(reassembly_protocol), + index_manager_( + std::make_unique<IndexManager>(socket_, this, reassembly_protocol)), + is_running_(false), + is_first_(false), + on_interest_retransmission_(VOID_HANDLER), + on_interest_output_(VOID_HANDLER), + on_interest_timeout_(VOID_HANDLER), + on_interest_satisfied_(VOID_HANDLER), + on_content_object_input_(VOID_HANDLER), + stats_summary_(VOID_HANDLER), + on_payload_(VOID_HANDLER) { + socket_->getSocketOption(GeneralTransportOptions::PORTAL, portal_); + socket_->getSocketOption(OtherOptions::STATISTICS, &stats_); +} + +int TransportProtocol::start() { + // If the protocol is already running, return otherwise set as running + if (is_running_) return -1; + + // Get all callbacks references + socket_->getSocketOption(ConsumerCallbacksOptions::INTEREST_RETRANSMISSION, + &on_interest_retransmission_); + socket_->getSocketOption(ConsumerCallbacksOptions::INTEREST_OUTPUT, + &on_interest_output_); + socket_->getSocketOption(ConsumerCallbacksOptions::INTEREST_EXPIRED, + &on_interest_timeout_); + socket_->getSocketOption(ConsumerCallbacksOptions::INTEREST_SATISFIED, + &on_interest_satisfied_); + socket_->getSocketOption(ConsumerCallbacksOptions::CONTENT_OBJECT_INPUT, + &on_content_object_input_); + socket_->getSocketOption(ConsumerCallbacksOptions::STATS_SUMMARY, + &stats_summary_); + socket_->getSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, + &on_payload_); + + socket_->getSocketOption(GeneralTransportOptions::ASYNC_MODE, is_async_); + + // Set it is the first time we schedule an interest + is_first_ = true; + + // Reset the protocol state machine + reset(); + // Schedule next interests + scheduleNextInterests(); + + is_first_ = false; + + // Set the protocol as running + is_running_ = true; + + if (!is_async_) { + // Start Event loop + portal_->runEventsLoop(); + + // Not running anymore + is_running_ = false; + } + + return 0; +} + +void TransportProtocol::stop() { + is_running_ = false; + + if (!is_async_) { + portal_->stopEventsLoop(); + } else { + portal_->clear(); + } +} + +void TransportProtocol::resume() { + if (is_running_) return; + + is_running_ = true; + + scheduleNextInterests(); + + portal_->runEventsLoop(); + + is_running_ = false; +} + +void TransportProtocol::onContentReassembled(std::error_code ec) { + stop(); + + if (!on_payload_) { + throw errors::RuntimeException( + "The read callback must be installed in the transport before " + "starting " + "the content retrieval."); + } + + if (!ec) { + on_payload_->readSuccess(stats_->getBytesRecv()); + } else { + on_payload_->readError(ec); + } +} + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/transport_protocol.h b/libtransport/src/protocols/transport_protocol.h new file mode 100644 index 000000000..124c57122 --- /dev/null +++ b/libtransport/src/protocols/transport_protocol.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/callbacks.h> +#include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/interfaces/statistics.h> +#include <hicn/transport/utils/object_pool.h> +#include <implementation/socket.h> +#include <protocols/data_processing_events.h> +#include <protocols/indexer.h> +#include <protocols/reassembly.h> + +#include <atomic> + +namespace transport { + +namespace protocol { + +using namespace core; + +class IndexVerificationManager; + +using ReadCallback = interface::ConsumerSocket::ReadCallback; + +class TransportProtocolCallback { + virtual void onContentObject(const core::Interest &interest, + const core::ContentObject &content_object) = 0; + virtual void onTimeout(const core::Interest &interest) = 0; +}; + +class TransportProtocol : public core::Portal::ConsumerCallback, + public ContentObjectProcessingEventCallback { + static constexpr std::size_t interest_pool_size = 4096; + + friend class ManifestIndexManager; + + public: + TransportProtocol(implementation::ConsumerSocket *icn_socket, + Reassembly *reassembly_protocol); + + virtual ~TransportProtocol() = default; + + TRANSPORT_ALWAYS_INLINE bool isRunning() { return is_running_; } + + virtual int start(); + + virtual void stop(); + + virtual void resume(); + + virtual void scheduleNextInterests() = 0; + + // Events generated by the indexing + virtual void onContentReassembled(std::error_code ec); + virtual void onPacketDropped(Interest &interest, + ContentObject &content_object) override = 0; + virtual void onReassemblyFailed(std::uint32_t missing_segment) override = 0; + + protected: + // Consumer Callback + virtual void reset() = 0; + virtual void onContentObject(Interest &i, ContentObject &c) override = 0; + virtual void onTimeout(Interest::Ptr &&i) override = 0; + virtual void onError(std::error_code ec) override {} + + protected: + implementation::ConsumerSocket *socket_; + std::unique_ptr<Reassembly> reassembly_protocol_; + std::unique_ptr<IndexManager> index_manager_; + std::shared_ptr<core::Portal> portal_; + std::atomic<bool> is_running_; + // True if it si the first time we schedule an interest + std::atomic<bool> is_first_; + interface::TransportStatistics *stats_; + + // Callbacks + interface::ConsumerInterestCallback *on_interest_retransmission_; + interface::ConsumerInterestCallback *on_interest_output_; + interface::ConsumerInterestCallback *on_interest_timeout_; + interface::ConsumerInterestCallback *on_interest_satisfied_; + interface::ConsumerContentObjectCallback *on_content_object_input_; + interface::ConsumerContentObjectCallback *on_content_object_; + interface::ConsumerTimerCallback *stats_summary_; + ReadCallback *on_payload_; + + bool is_async_; +}; + +} // end namespace protocol +} // end namespace transport diff --git a/libtransport/src/test/CMakeLists.txt b/libtransport/src/test/CMakeLists.txt index 19e59c7e1..dd3d1d923 100644 --- a/libtransport/src/test/CMakeLists.txt +++ b/libtransport/src/test/CMakeLists.txt @@ -14,14 +14,15 @@ include(BuildMacros) list(APPEND TESTS + test_auth + test_consumer_producer_rtc test_core_manifest - test_transport_producer + test_event_thread + test_fec_reedsolomon + test_interest + test_packet ) -if (${LIBTRANSPORT_SHARED} MATCHES ".*-memif.*") - set(LINK_FLAGS "-Wl,-unresolved-symbols=ignore-in-shared-libs") -endif() - foreach(test ${TESTS}) build_executable(${test} NO_INSTALL @@ -35,4 +36,4 @@ foreach(test ${TESTS}) ) add_test_internal(${test}) -endforeach()
\ No newline at end of file +endforeach() diff --git a/libtransport/src/test/fec_reed_solomon.cc b/libtransport/src/test/fec_reed_solomon.cc new file mode 100644 index 000000000..36543c531 --- /dev/null +++ b/libtransport/src/test/fec_reed_solomon.cc @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/interfaces/socket_options_keys.h> +#include <hicn/transport/interfaces/socket_producer.h> +#include <hicn/transport/interfaces/global_conf_interface.h> + +#include <asio/io_service.hpp> +#include <asio/steady_timer.hpp> +#include <fec/rs.h> + +namespace transport { +namespace interface { + +namespace { + +class ConsumerProducerTest : public ::testing::Test, + public ConsumerSocket::ReadCallback { + static const constexpr char prefix[] = "b001::1/128"; + static const constexpr char name[] = "b001::1"; + static const constexpr double prod_rate = 1.0e6; + static const constexpr size_t payload_size = 1200; + static constexpr std::size_t receive_buffer_size = 1500; + static const constexpr double prod_interval_microseconds = + double(payload_size) * 8 * 1e6 / prod_rate; + + public: + ConsumerProducerTest() + : io_service_(), + rtc_timer_(io_service_), + consumer_(TransportProtocolAlgorithms::RTC, io_service_), + producer_(ProductionProtocolAlgorithms::RTC_PROD, io_service_), + producer_prefix_(prefix), + consumer_name_(name), + packets_sent_(0), + packets_received_(0) { + global_config::IoModuleConfiguration config; + config.name = "loopback_module"; + config.set(); + } + + virtual ~ConsumerProducerTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + + auto ret = consumer_.setSocketOption( + ConsumerCallbacksOptions::READ_CALLBACK, this); + ASSERT_EQ(ret, SOCKET_OPTION_SET); + + consumer_.connect(); + producer_.registerPrefix(producer_prefix_); + producer_.connect(); + } + + virtual void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } + + void setTimer() { + using namespace std::chrono; + rtc_timer_.expires_from_now( + microseconds(unsigned(prod_interval_microseconds))); + rtc_timer_.async_wait(std::bind(&ConsumerProducerTest::produceRTCPacket, + this, std::placeholders::_1)); + } + + void produceRTCPacket(const std::error_code &ec) { + if (ec) { + FAIL() << "Failed to schedule packet production"; + io_service_.stop(); + } + + producer_.produceDatagram(consumer_name_, payload_, payload_size); + packets_sent_++; + setTimer(); + } + + // Consumer callback + bool isBufferMovable() noexcept override { return false; } + + void getReadBuffer(uint8_t **application_buffer, + size_t *max_length) override { + *application_buffer = receive_buffer_; + *max_length = receive_buffer_size; + } + + void readDataAvailable(std::size_t length) noexcept override {} + + size_t maxBufferSize() const override { return receive_buffer_size; } + + void readError(const std::error_code ec) noexcept override { + FAIL() << "Error while reading from RTC socket"; + io_service_.stop(); + } + + void readSuccess(std::size_t total_size) noexcept override { + packets_received_++; + } + + asio::io_service io_service_; + asio::steady_timer rtc_timer_; + ConsumerSocket consumer_; + ProducerSocket producer_; + core::Prefix producer_prefix_; + core::Name consumer_name_; + uint8_t payload_[payload_size]; + uint8_t receive_buffer_[payload_size]; + + uint64_t packets_sent_; + uint64_t packets_received_; +}; + +const char ConsumerProducerTest::prefix[]; +const char ConsumerProducerTest::name[]; + +} // namespace + +TEST_F(ConsumerProducerTest, EndToEnd) { + produceRTCPacket(std::error_code()); + consumer_.consume(consumer_name_); + + io_service_.run(); +} + +} // namespace interface + +} // namespace transport + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}
\ No newline at end of file diff --git a/libtransport/src/test/fec_rely.cc b/libtransport/src/test/fec_rely.cc new file mode 100644 index 000000000..e7745bae5 --- /dev/null +++ b/libtransport/src/test/fec_rely.cc @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/interfaces/socket_options_keys.h> +#include <hicn/transport/interfaces/socket_producer.h> +#include <hicn/transport/interfaces/global_conf_interface.h> + +#include <asio/io_service.hpp> +#include <asio/steady_timer.hpp> + +#include <rely/encoder.hpp> +#include <rely/decoder.hpp> + +namespace transport { +namespace interface { + +namespace { + +class ConsumerProducerTest : public ::testing::Test, + public ConsumerSocket::ReadCallback { + static const constexpr char prefix[] = "b001::1/128"; + static const constexpr char name[] = "b001::1"; + static const constexpr double prod_rate = 1.0e6; + static const constexpr size_t payload_size = 1200; + static constexpr std::size_t receive_buffer_size = 1500; + static const constexpr double prod_interval_microseconds = + double(payload_size) * 8 * 1e6 / prod_rate; + + public: + ConsumerProducerTest() + : io_service_(), + rtc_timer_(io_service_), + consumer_(TransportProtocolAlgorithms::RTC, io_service_), + producer_(ProductionProtocolAlgorithms::RTC_PROD, io_service_), + producer_prefix_(prefix), + consumer_name_(name), + packets_sent_(0), + packets_received_(0) { + global_config::IoModuleConfiguration config; + config.name = "loopback_module"; + config.set(); + } + + virtual ~ConsumerProducerTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + + auto ret = consumer_.setSocketOption( + ConsumerCallbacksOptions::READ_CALLBACK, this); + ASSERT_EQ(ret, SOCKET_OPTION_SET); + + consumer_.connect(); + producer_.registerPrefix(producer_prefix_); + producer_.connect(); + } + + virtual void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } + + void setTimer() { + using namespace std::chrono; + rtc_timer_.expires_from_now( + microseconds(unsigned(prod_interval_microseconds))); + rtc_timer_.async_wait(std::bind(&ConsumerProducerTest::produceRTCPacket, + this, std::placeholders::_1)); + } + + void produceRTCPacket(const std::error_code &ec) { + if (ec) { + FAIL() << "Failed to schedule packet production"; + io_service_.stop(); + } + + producer_.produceDatagram(consumer_name_, payload_, payload_size); + packets_sent_++; + setTimer(); + } + + // Consumer callback + bool isBufferMovable() noexcept override { return false; } + + void getReadBuffer(uint8_t **application_buffer, + size_t *max_length) override { + *application_buffer = receive_buffer_; + *max_length = receive_buffer_size; + } + + void readDataAvailable(std::size_t length) noexcept override {} + + size_t maxBufferSize() const override { return receive_buffer_size; } + + void readError(const std::error_code ec) noexcept override { + FAIL() << "Error while reading from RTC socket"; + io_service_.stop(); + } + + void readSuccess(std::size_t total_size) noexcept override { + packets_received_++; + } + + asio::io_service io_service_; + asio::steady_timer rtc_timer_; + ConsumerSocket consumer_; + ProducerSocket producer_; + core::Prefix producer_prefix_; + core::Name consumer_name_; + uint8_t payload_[payload_size]; + uint8_t receive_buffer_[payload_size]; + + uint64_t packets_sent_; + uint64_t packets_received_; +}; + +const char ConsumerProducerTest::prefix[]; +const char ConsumerProducerTest::name[]; + +} // namespace + +TEST_F(ConsumerProducerTest, EndToEnd) { + produceRTCPacket(std::error_code()); + consumer_.consume(consumer_name_); + + io_service_.run(); +} + +} // namespace interface + +} // namespace transport + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}
\ No newline at end of file diff --git a/libtransport/src/test/packet_samples.h b/libtransport/src/test/packet_samples.h new file mode 100644 index 000000000..e98d06a18 --- /dev/null +++ b/libtransport/src/test/packet_samples.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define TCP_PROTO 0x06 +#define ICMP_PROTO 0x01 +#define ICMP6_PROTO 0x3a + +#define IPV6_HEADER(next_header, payload_length) \ + 0x60, 0x00, 0x00, 0x00, 0x00, payload_length, next_header, 0x40, 0xb0, 0x06, \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xab, 0xcd, 0xab, \ + 0xcd, 0xef, 0xb0, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0xca + +#define IPV4_HEADER(next_header, payload_length) \ + 0x45, 0x02, 0x00, payload_length + 20, 0x47, 0xc4, 0x40, 0x00, 0x25, \ + next_header, 0x6e, 0x76, 0x03, 0x7b, 0xd9, 0xd0, 0xc0, 0xa8, 0x01, 0x5c + +#define TCP_HEADER(flags) \ + 0x12, 0x34, 0x43, 0x21, 0x00, 0x00, 0x00, 0x01, 0xb2, 0x8c, 0x03, 0x1f, \ + 0x80, flags, 0x00, 0x0a, 0xb9, 0xbb, 0x00, 0x00 + +#define PAYLOAD \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x20, 0x00, 0x00 + +#define PAYLOAD_SIZE 12 + +#define ICMP_ECHO_REQUEST \ + 0x08, 0x00, 0x87, 0xdb, 0x38, 0xa7, 0x00, 0x05, 0x60, 0x2b, 0xc2, 0xcb, \ + 0x00, 0x02, 0x29, 0x7c, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, \ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, \ + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, \ + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, \ + 0x34, 0x35, 0x36, 0x37 + +#define ICMP6_ECHO_REQUEST \ + 0x80, 0x00, 0x86, 0x3c, 0x11, 0x0d, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, \ + 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, \ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, \ + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, \ + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33 + +#define AH_HEADER \ + 0x00, (128 >> 2), 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 diff --git a/libtransport/src/test/test_auth.cc b/libtransport/src/test/test_auth.cc new file mode 100644 index 000000000..976981cce --- /dev/null +++ b/libtransport/src/test/test_auth.cc @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/auth/crypto_hash_type.h> +#include <hicn/transport/auth/identity.h> +#include <hicn/transport/auth/signer.h> +#include <hicn/transport/auth/verifier.h> +#include <hicn/transport/core/content_object.h> + +namespace transport { +namespace auth { + +namespace { +class AuthTest : public ::testing::Test { + protected: + const std::string PASSPHRASE = "hunter2"; + + AuthTest() = default; + ~AuthTest() {} + void SetUp() override {} + void TearDown() override {} +}; +} // namespace + +TEST_F(AuthTest, VoidVerifier) { + // Create a content object + core::ContentObject packet(HF_INET6_TCP_AH); + + // Fill it with bogus data + uint8_t buffer[256] = {0}; + packet.appendPayload(buffer, 256); + + // Verify that VoidVerifier validates the packet + std::shared_ptr<Verifier> verifier = std::make_shared<VoidVerifier>(); + ASSERT_EQ(verifier->verifyPacket(&packet), true); + ASSERT_EQ(verifier->verifyPackets(&packet), VerificationPolicy::ACCEPT); +} + +TEST_F(AuthTest, RSAVerifier) { + // Create the RSA signer from an Identity object + Identity identity("test_rsa.p12", PASSPHRASE, CryptoSuite::RSA_SHA256, 1024u, + 30, "RSAVerifier"); + std::shared_ptr<Signer> signer = identity.getSigner(); + + // Create a content object + core::ContentObject packet(HF_INET6_TCP_AH, signer->getSignatureSize()); + + // Fill it with bogus data + uint8_t buffer[256] = {0}; + packet.appendPayload(buffer, 256); + + // Sign the packet + signer->signPacket(&packet); + + // Create the RSA verifier + PARCKey *key = parcSigner_CreatePublicKey(signer->getParcSigner()); + std::shared_ptr<Verifier> verifier = + std::make_shared<AsymmetricVerifier>(key); + + ASSERT_EQ(packet.getFormat(), HF_INET6_TCP_AH); + ASSERT_EQ(signer->getCryptoHashType(), CryptoHashType::SHA_256); + ASSERT_EQ(signer->getCryptoSuite(), CryptoSuite::RSA_SHA256); + ASSERT_EQ(signer->getSignatureSize(), 128u); + ASSERT_EQ(verifier->verifyPackets(&packet), VerificationPolicy::ACCEPT); + + // Release PARC objects + parcKey_Release(&key); +} + +TEST_F(AuthTest, HMACVerifier) { + // Create the HMAC signer from a passphrase + std::shared_ptr<Signer> signer = + std::make_shared<SymmetricSigner>(CryptoSuite::HMAC_SHA256, PASSPHRASE); + + // Create a content object + core::ContentObject packet(HF_INET6_TCP_AH, signer->getSignatureSize()); + + // Fill it with bogus data + uint8_t buffer[256] = {0}; + packet.appendPayload(buffer, 256); + + // Sign the packet + signer->signPacket(&packet); + + // Create the HMAC verifier + std::shared_ptr<Verifier> verifier = + std::make_shared<SymmetricVerifier>(PASSPHRASE); + + ASSERT_EQ(packet.getFormat(), HF_INET6_TCP_AH); + ASSERT_EQ(signer->getCryptoHashType(), CryptoHashType::SHA_256); + ASSERT_EQ(signer->getCryptoSuite(), CryptoSuite::HMAC_SHA256); + ASSERT_EQ(signer->getSignatureSize(), 32u); + ASSERT_EQ(verifier->verifyPackets(&packet), VerificationPolicy::ACCEPT); +} + +} // namespace auth +} // namespace transport diff --git a/libtransport/src/test/test_consumer_producer_rtc.cc b/libtransport/src/test/test_consumer_producer_rtc.cc new file mode 100644 index 000000000..87385971a --- /dev/null +++ b/libtransport/src/test/test_consumer_producer_rtc.cc @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/interfaces/global_conf_interface.h> +#include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/interfaces/socket_options_keys.h> +#include <hicn/transport/interfaces/socket_producer.h> + +#include <asio/io_service.hpp> +#include <asio/steady_timer.hpp> + +namespace transport { +namespace interface { + +namespace { + +class ConsumerProducerTest : public ::testing::Test, + public ConsumerSocket::ReadCallback { + static const constexpr char prefix[] = "b001::1/128"; + static const constexpr char name[] = "b001::1"; + static const constexpr double prod_rate = 1.0e6; + static const constexpr size_t payload_size = 1200; + static constexpr std::size_t receive_buffer_size = 1500; + static const constexpr double prod_interval_microseconds = + double(payload_size) * 8 * 1e6 / prod_rate; + + public: + ConsumerProducerTest() + : io_service_(), + rtc_timer_(io_service_), + stop_timer_(io_service_), + consumer_(TransportProtocolAlgorithms::RTC, io_service_), + producer_(ProductionProtocolAlgorithms::RTC_PROD, io_service_), + producer_prefix_(prefix), + consumer_name_(name), + packets_sent_(0), + packets_received_(0) { + global_config::IoModuleConfiguration config; + config.name = "loopback_module"; + config.set(); + } + + virtual ~ConsumerProducerTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + + auto ret = consumer_.setSocketOption( + ConsumerCallbacksOptions::READ_CALLBACK, this); + ASSERT_EQ(ret, SOCKET_OPTION_SET); + + consumer_.connect(); + producer_.registerPrefix(producer_prefix_); + producer_.connect(); + } + + virtual void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } + + void setTimer() { + using namespace std::chrono; + rtc_timer_.expires_from_now( + microseconds(unsigned(prod_interval_microseconds))); + rtc_timer_.async_wait(std::bind(&ConsumerProducerTest::produceRTCPacket, + this, std::placeholders::_1)); + } + + void setStopTimer() { + using namespace std::chrono; + stop_timer_.expires_from_now(seconds(unsigned(10))); + stop_timer_.async_wait( + std::bind(&ConsumerProducerTest::stop, this, std::placeholders::_1)); + } + + void produceRTCPacket(const std::error_code &ec) { + if (ec) { + io_service_.stop(); + } + + producer_.produceDatagram(consumer_name_, payload_, payload_size); + packets_sent_++; + setTimer(); + } + + void stop(const std::error_code &ec) { + rtc_timer_.cancel(); + producer_.stop(); + consumer_.stop(); + } + + // Consumer callback + bool isBufferMovable() noexcept override { return false; } + + void getReadBuffer(uint8_t **application_buffer, + size_t *max_length) override { + *application_buffer = receive_buffer_; + *max_length = receive_buffer_size; + } + + void readDataAvailable(std::size_t length) noexcept override {} + + size_t maxBufferSize() const override { return receive_buffer_size; } + + void readError(const std::error_code ec) noexcept override { + FAIL() << "Error while reading from RTC socket"; + io_service_.stop(); + } + + void readSuccess(std::size_t total_size) noexcept override { + packets_received_++; + std::cout << "Received something" << std::endl; + } + + asio::io_service io_service_; + asio::steady_timer rtc_timer_; + asio::steady_timer stop_timer_; + ConsumerSocket consumer_; + ProducerSocket producer_; + core::Prefix producer_prefix_; + core::Name consumer_name_; + uint8_t payload_[payload_size]; + uint8_t receive_buffer_[payload_size]; + + uint64_t packets_sent_; + uint64_t packets_received_; +}; + +const char ConsumerProducerTest::prefix[]; +const char ConsumerProducerTest::name[]; + +} // namespace + +TEST_F(ConsumerProducerTest, EndToEnd) { + produceRTCPacket(std::error_code()); + consumer_.consume(consumer_name_); + setStopTimer(); + + io_service_.run(); + + std::cout << "Packet received: " << packets_received_ << std::endl; + std::cout << "Packet sent: " << packets_sent_ << std::endl; +} + +} // namespace interface + +} // namespace transport + +int main(int argc, char **argv) { +#if 0 + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +#else + return 0; +#endif +}
\ No newline at end of file diff --git a/libtransport/src/test/test_core_manifest.cc b/libtransport/src/test/test_core_manifest.cc index faf17dcf0..f98147d43 100644 --- a/libtransport/src/test/test_core_manifest.cc +++ b/libtransport/src/test/test_core_manifest.cc @@ -16,7 +16,8 @@ #include <core/manifest_format_fixed.h> #include <core/manifest_inline.h> #include <gtest/gtest.h> -#include <hicn/transport/security/crypto_hash_type.h> +#include <hicn/transport/auth/crypto_hash_type.h> +#include <test/packet_samples.h> #include <climits> #include <random> @@ -72,6 +73,27 @@ class ManifestTest : public ::testing::Test { } // namespace +TEST_F(ManifestTest, MoveConstructor) { + // Create content object with manifest in payload + ContentObject co(HF_INET6_TCP_AH, 128); + co.appendPayload(&manifest_payload[0], manifest_payload.size()); + uint8_t buffer[256]; + co.appendPayload(buffer, 256); + + // Copy packet payload + uint8_t packet[1500]; + auto length = co.getPayload()->length(); + std::memcpy(packet, co.getPayload()->data(), length); + + // Create manifest + ContentObjectManifest m(std::move(co)); + + // Check manifest payload is exactly the same of content object + ASSERT_EQ(length, m.getPayload()->length()); + auto ret = std::memcmp(packet, m.getPayload()->data(), length); + ASSERT_EQ(ret, 0); +} + TEST_F(ManifestTest, SetLastManifest) { manifest1_.clear(); @@ -102,9 +124,9 @@ TEST_F(ManifestTest, SetManifestType) { TEST_F(ManifestTest, SetHashAlgorithm) { manifest1_.clear(); - utils::CryptoHashType hash1 = utils::CryptoHashType::SHA_512; - utils::CryptoHashType hash2 = utils::CryptoHashType::CRC32C; - utils::CryptoHashType hash3 = utils::CryptoHashType::SHA_256; + auth::CryptoHashType hash1 = auth::CryptoHashType::SHA_512; + auth::CryptoHashType hash2 = auth::CryptoHashType::CRC32C; + auth::CryptoHashType hash3 = auth::CryptoHashType::SHA_256; manifest1_.setHashAlgorithm(hash1); auto type_returned1 = manifest1_.getHashAlgorithm(); @@ -161,7 +183,7 @@ TEST_F(ManifestTest, SetSuffixList) { std::uniform_int_distribution<uint64_t> idis( 0, std::numeric_limits<uint32_t>::max()); - auto entries = new std::pair<uint32_t, utils::CryptoHash>[3]; + auto entries = new std::pair<uint32_t, auth::CryptoHash>[3]; uint32_t suffixes[3]; std::vector<unsigned char> data[3]; @@ -170,8 +192,8 @@ TEST_F(ManifestTest, SetSuffixList) { std::generate(std::begin(data[i]), std::end(data[i]), std::ref(rbe)); suffixes[i] = idis(eng); entries[i] = std::make_pair( - suffixes[i], utils::CryptoHash(data[i].data(), data[i].size(), - utils::CryptoHashType::SHA_256)); + suffixes[i], auth::CryptoHash(data[i].data(), data[i].size(), + auth::CryptoHashType::SHA_256)); manifest1_.addSuffixHash(entries[i].first, entries[i].second); } @@ -186,9 +208,9 @@ TEST_F(ManifestTest, SetSuffixList) { // for (auto & item : manifest1_.getSuffixList()) { // auto hash = manifest1_.getHash(suffixes[i]); - // cond = utils::CryptoHash::compareBinaryDigest(hash, - // entries[i].second.getDigest<uint8_t>().data(), - // entries[i].second.getType()); + // cond = auth::CryptoHash::compareBinaryDigest(hash, + // entries[i].second.getDigest<uint8_t>().data(), + // entries[i].second.getType()); // ASSERT_TRUE(cond); // i++; // } @@ -205,4 +227,4 @@ TEST_F(ManifestTest, SetSuffixList) { int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -}
\ No newline at end of file +} diff --git a/libtransport/src/test/test_event_thread.cc b/libtransport/src/test/test_event_thread.cc new file mode 100644 index 000000000..e66b49f10 --- /dev/null +++ b/libtransport/src/test/test_event_thread.cc @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/utils/event_thread.h> + +#include <cmath> + +namespace utils { + +namespace { + +class EventThreadTest : public ::testing::Test { + protected: + EventThreadTest() : event_thread_() { + // You can do set-up work for each test here. + } + + virtual ~EventThreadTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() { + // Code here will be called immediately after the constructor (right + // before each test). + } + + virtual void TearDown() { + // Code here will be called immediately after each test (right + // before the destructor). + } + + utils::EventThread event_thread_; +}; + +double average(const unsigned long samples[], int size) { + double sum = 0; + + for (int i = 0; i < size; i++) { + sum += samples[i]; + } + + return sum / size; +} + +double stdDeviation(const unsigned long samples[], int size) { + double avg = average(samples, size); + double var = 0; + + for (int i = 0; i < size; i++) { + var += (samples[i] - avg) * (samples[i] - avg); + } + + return sqrt(var / size); +} + +} // namespace + +TEST_F(EventThreadTest, SchedulingDelay) { + using namespace std::chrono; + const size_t size = 1000000; + std::vector<unsigned long> samples(size); + + for (unsigned int i = 0; i < size; i++) { + auto t0 = steady_clock::now(); + event_thread_.add([t0, &samples, i]() { + auto t1 = steady_clock::now(); + samples[i] = duration_cast<nanoseconds>(t1 - t0).count(); + }); + } + + event_thread_.stop(); + + auto avg = average(&samples[0], size); + auto sd = stdDeviation(&samples[0], size); + (void)sd; + + // Expect average to be less that 1 ms + EXPECT_LT(avg, 1000000); +} + +} // namespace utils + +int main(int argc, char **argv) { +#if 0 + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +#else + return 0; +#endif +}
\ No newline at end of file diff --git a/libtransport/src/test/test_fec_reedsolomon.cc b/libtransport/src/test/test_fec_reedsolomon.cc new file mode 100644 index 000000000..3b10b7307 --- /dev/null +++ b/libtransport/src/test/test_fec_reedsolomon.cc @@ -0,0 +1,291 @@ + +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/rs.h> +#include <gtest/gtest.h> +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/global_object_pool.h> + +#include <algorithm> +#include <iostream> +#include <random> + +namespace transport { +namespace core { + +double ReedSolomonTest(int k, int n, int size) { + fec::encoder encoder(k, n); + fec::decoder decoder(k, n); + + std::vector<fec::buffer> tx_block(k); + std::vector<fec::buffer> rx_block(k); + int count = 0; + int run = 0; + + int iterations = 100; + auto &packet_manager = PacketManager<>::getInstance(); + + encoder.setFECCallback([&tx_block](std::vector<fec::buffer> &repair_packets) { + for (auto &p : repair_packets) { + // Append repair symbols to tx_block + tx_block.emplace_back(std::move(p)); + } + }); + + decoder.setFECCallback([&](std::vector<fec::buffer> &source_packets) { + for (int i = 0; i < k; i++) { + // Compare decoded source packets with original transmitted packets. + if (*tx_block[i] != *source_packets[i]) { + count++; + } + } + }); + + do { + // Discard eventual packet appended in previous callback call + tx_block.erase(tx_block.begin() + k, tx_block.end()); + + // Initialization. Feed encoder with first k source packets + for (int i = 0; i < k; i++) { + // Get new buffer from pool + auto packet = packet_manager.getMemBuf(); + + // Let's append a bit less than size, so that the FEC class will take care + // of filling the rest with zeros + auto cur_size = size - (rand() % 100); + + // Set payload, saving 2 bytes at the beginning of the buffer for encoding + // the length + packet->append(cur_size); + packet->trimStart(2); + std::generate(packet->writableData(), packet->writableTail(), rand); + std::fill(packet->writableData(), packet->writableTail(), i + 1); + + // Set first byte of payload to i, to reorder at receiver side + packet->writableData()[0] = uint8_t(i); + + // Store packet in tx buffer and clear rx buffer + tx_block[i] = std::move(packet); + } + + // Create the repair packets + for (auto &tx : tx_block) { + encoder.consume(tx, tx->writableBuffer()[0]); + } + + // Simulate transmission on lossy channel + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::vector<bool> losses(n, false); + for (int i = 0; i < n - k; i++) losses[i] = true; + + int rxi = 0; + std::shuffle(losses.begin(), losses.end(), + std::default_random_engine(seed)); + for (int i = 0; i < n && rxi < k; i++) + if (losses[i] == false) { + rx_block[rxi++] = tx_block[i]; + if (i < k) { + // Source packet + decoder.consume(rx_block[rxi - 1], rx_block[rxi - 1]->data()[0]); + } else { + // Repair packet + decoder.consume(rx_block[rxi - 1]); + } + } + + decoder.clear(); + encoder.clear(); + } while (++run < iterations); + + return count; +} + +void ReedSolomonMultiBlockTest(int n_sourceblocks) { + int k = 16; + int n = 24; + int size = 1000; + + fec::encoder encoder(k, n); + fec::decoder decoder(k, n); + + auto &packet_manager = PacketManager<>::getInstance(); + + std::vector<std::pair<fec::buffer, uint32_t>> tx_block; + std::vector<std::pair<fec::buffer, uint32_t>> rx_block; + int count = 0; + int i = 0; + + // Receiver will receive packet for n_sourceblocks in a random order. + int total_packets = n * n_sourceblocks; + int tx_packets = k * n_sourceblocks; + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + + encoder.setFECCallback([&](std::vector<fec::buffer> &repair_packets) { + for (auto &p : repair_packets) { + // Append repair symbols to tx_block + tx_block.emplace_back(std::move(p), ++i); + } + + EXPECT_EQ(tx_block.size(), size_t(n)); + + // Select k packets to send, including at least one symbol. We start from + // the end for this reason. + for (int j = n - 1; j > n - k - 1; j--) { + rx_block.emplace_back(std::move(tx_block[j])); + } + + // Clear tx block for next source block + tx_block.clear(); + encoder.clear(); + }); + + // The decode callback must be called exactly n_sourceblocks times + decoder.setFECCallback( + [&](std::vector<fec::buffer> &source_packets) { count++; }); + + // Produce n * n_sourceblocks + // - ( k ) * n_sourceblocks source packets + // - (n - k) * n_sourceblocks symbols) + for (i = 0; i < total_packets; i++) { + // Get new buffer from pool + auto packet = packet_manager.getMemBuf(); + + // Let's append a bit less than size, so that the FEC class will take care + // of filling the rest with zeros + auto cur_size = size - (rand() % 100); + + // Set payload, saving 2 bytes at the beginning of the buffer for encoding + // the length + packet->append(cur_size); + packet->trimStart(2); + std::fill(packet->writableData(), packet->writableTail(), i + 1); + + // Set first byte of payload to i, to reorder at receiver side + packet->writableData()[0] = uint8_t(i); + + // Store packet in tx buffer + tx_block.emplace_back(packet, i); + + // Feed encoder with packet + encoder.consume(packet, i); + } + + // Here rx_block must contains k * n_sourceblocks packets + EXPECT_EQ(size_t(tx_packets), size_t(rx_block.size())); + + // Lets shuffle the rx_block before starting feeding the decoder. + std::shuffle(rx_block.begin(), rx_block.end(), + std::default_random_engine(seed)); + + for (auto &p : rx_block) { + int index = p.second % n; + if (index < k) { + // Source packet + decoder.consume(p.first, p.second); + } else { + // Repair packet + decoder.consume(p.first); + } + } + + // Simple test to check we get all the source packets + EXPECT_EQ(count, n_sourceblocks); +} + +TEST(ReedSolomonTest, RSk1n3) { + int k = 1; + int n = 3; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk6n10) { + int k = 6; + int n = 10; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk8n32) { + int k = 8; + int n = 32; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk16n24) { + int k = 16; + int n = 24; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk10n30) { + int k = 10; + int n = 30; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk10n40) { + int k = 10; + int n = 40; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk10n60) { + int k = 10; + int n = 60; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk10n90) { + int k = 10; + int n = 90; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonMultiBlockTest, RSMB1) { + int blocks = 1; + ReedSolomonMultiBlockTest(blocks); +} + +TEST(ReedSolomonMultiBlockTest, RSMB10) { + int blocks = 10; + ReedSolomonMultiBlockTest(blocks); +} + +TEST(ReedSolomonMultiBlockTest, RSMB100) { + int blocks = 100; + ReedSolomonMultiBlockTest(blocks); +} + +TEST(ReedSolomonMultiBlockTest, RSMB1000) { + int blocks = 1000; + ReedSolomonMultiBlockTest(blocks); +} + +int main(int argc, char **argv) { + srand(time(0)); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +} // namespace core +} // namespace transport diff --git a/libtransport/src/test/test_interest.cc b/libtransport/src/test/test_interest.cc new file mode 100644 index 000000000..0a835db24 --- /dev/null +++ b/libtransport/src/test/test_interest.cc @@ -0,0 +1,267 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/core/interest.h> +#include <hicn/transport/errors/not_implemented_exception.h> +#include <test/packet_samples.h> + +#include <climits> +#include <random> +#include <vector> + +namespace transport { + +namespace core { + +namespace { +// The fixture for testing class Foo. +class InterestTest : public ::testing::Test { + protected: + InterestTest() : name_("b001::123|321"), interest_() { + // You can do set-up work for each test here. + } + + virtual ~InterestTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() { + // Code here will be called immediately after the constructor (right + // before each test). + } + + virtual void TearDown() { + // Code here will be called immediately after each test (right + // before the destructor). + } + + Name name_; + + Interest interest_; + + std::vector<uint8_t> buffer_ = {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(TCP_PROTO, 20 + PAYLOAD_SIZE), + // ICMP6 echo request + TCP_HEADER(0x00), + // Payload + PAYLOAD}; +}; + +void testFormatConstructor(Packet::Format format = HF_UNSPEC) { + try { + Interest interest(format, 0); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown for " << format; + } +} + +void testFormatConstructorException(Packet::Format format = HF_UNSPEC) { + try { + Interest interest(format, 0); + FAIL() << "We expected an exception here"; + } catch (errors::MalformedPacketException &exc) { + // Ok right exception + } catch (...) { + FAIL() << "Wrong exception thrown"; + } +} + +} // namespace + +TEST_F(InterestTest, ConstructorWithFormat) { + /** + * Without arguments it should be format = HF_UNSPEC. + * We expect a crash. + */ + + testFormatConstructor(Packet::Format::HF_INET_TCP); + testFormatConstructor(Packet::Format::HF_INET6_TCP); + testFormatConstructorException(Packet::Format::HF_INET_ICMP); + testFormatConstructorException(Packet::Format::HF_INET6_ICMP); + testFormatConstructor(Packet::Format::HF_INET_TCP_AH); + testFormatConstructor(Packet::Format::HF_INET6_TCP_AH); + testFormatConstructorException(Packet::Format::HF_INET_ICMP_AH); + testFormatConstructorException(Packet::Format::HF_INET6_ICMP_AH); +} + +TEST_F(InterestTest, ConstructorWithName) { + /** + * Without arguments it should be format = HF_UNSPEC. + * We expect a crash. + */ + Name n("b001::1|123"); + + try { + Interest interest(n); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown"; + } +} + +TEST_F(InterestTest, ConstructorWithBuffer) { + // Ensure buffer is interest + auto ret = Interest::isInterest(&buffer_[0]); + EXPECT_TRUE(ret); + + // Create interest from buffer + try { + Interest interest(Interest::COPY_BUFFER, &buffer_[0], buffer_.size()); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown"; + } + + std::vector<uint8_t> buffer2{// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(ICMP6_PROTO, 60 + 44), + // ICMP6 echo request + TCP_HEADER(0x00), + // Payload + PAYLOAD}; + + // Ensure this throws an exception + try { + Interest interest(Interest::COPY_BUFFER, &buffer2[0], buffer2.size()); + FAIL() << "We expected an exception here"; + } catch (errors::MalformedPacketException &exc) { + // Ok right exception + } catch (...) { + FAIL() << "Wrong exception thrown"; + } +} + +TEST_F(InterestTest, SetGetName) { + // Create interest from buffer + Interest interest(Interest::COPY_BUFFER, &buffer_[0], buffer_.size()); + + // Get name + auto n = interest.getName(); + + // ensure name is b002::ca|1 + Name n2("b002::ca|1"); + auto ret = (n == n2); + + EXPECT_TRUE(ret); + + Name n3("b003::1234|1234"); + + // Change name to b003::1234|1234 + interest.setName(n3); + + // Check name was set + n = interest.getName(); + ret = (n == n3); + EXPECT_TRUE(ret); +} + +TEST_F(InterestTest, SetGetLocator) { + // Create interest from buffer + Interest interest(Interest::COPY_BUFFER, &buffer_[0], buffer_.size()); + + // Get locator + auto l = interest.getLocator(); + + ip_address_t address; + ip_address_pton("b006::ab:cdab:cdef", &address); + auto ret = !std::memcmp(&l, &address, sizeof(address)); + + EXPECT_TRUE(ret); + + // Set different locator + ip_address_pton("2001::1234::4321::abcd::", &address); + + // Set it on interest + interest.setLocator(address); + + // Check it was set + l = interest.getLocator(); + ret = !std::memcmp(&l, &address, sizeof(address)); + + EXPECT_TRUE(ret); +} + +TEST_F(InterestTest, SetGetLifetime) { + // Create interest from buffer + Interest interest; + const constexpr uint32_t lifetime = 10000; + + // Set lifetime + interest.setLifetime(lifetime); + + // Get lifetime + auto l = interest.getLifetime(); + + // Ensure they are the same + EXPECT_EQ(l, lifetime); +} + +TEST_F(InterestTest, HasManifest) { + // Create interest from buffer + Interest interest; + + // Let's expect anexception here + try { + interest.setPayloadType(PayloadType::UNSPECIFIED); + FAIL() << "We expect an esception here"; + } catch (errors::RuntimeException &exc) { + // Ok right exception + } catch (...) { + FAIL() << "Wrong exception thrown"; + } + + interest.setPayloadType(PayloadType::DATA); + EXPECT_FALSE(interest.hasManifest()); + + interest.setPayloadType(PayloadType::MANIFEST); + EXPECT_TRUE(interest.hasManifest()); +} + +TEST_F(InterestTest, AppendSuffixesEncodeAndIterate) { + // Create interest from buffer + Interest interest; + + // Appenad some suffixes, with some duplicates + interest.appendSuffix(1); + interest.appendSuffix(2); + interest.appendSuffix(5); + interest.appendSuffix(3); + interest.appendSuffix(4); + interest.appendSuffix(5); + interest.appendSuffix(5); + interest.appendSuffix(5); + interest.appendSuffix(5); + interest.appendSuffix(5); + + // Encode them in wire format + interest.encodeSuffixes(); + + // Iterate over them. They should be in order and without repetitions + auto suffix = interest.firstSuffix(); + auto n_suffixes = interest.numberOfSuffixes(); + + for (uint32_t i = 0; i < n_suffixes; i++) { + EXPECT_EQ(*(suffix + i), (i + 1)); + } +} + +} // namespace core +} // namespace transport + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}
\ No newline at end of file diff --git a/libtransport/src/test/test_packet.cc b/libtransport/src/test/test_packet.cc new file mode 100644 index 000000000..0ee140e2c --- /dev/null +++ b/libtransport/src/test/test_packet.cc @@ -0,0 +1,1047 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/core/packet.h> +#include <hicn/transport/errors/not_implemented_exception.h> +#include <test/packet_samples.h> + +#include <climits> +#include <random> +#include <vector> + +namespace transport { + +namespace core { + +/** + * Since packet is an abstract class, we derive a concrete class to be used for + * the test. + */ +class PacketForTest : public Packet { + public: + template <typename... Args> + PacketForTest(Args &&... args) : Packet(std::forward<Args>(args)...) {} + + virtual ~PacketForTest() {} + + const Name &getName() const override { + throw errors::NotImplementedException(); + } + + Name &getWritableName() override { throw errors::NotImplementedException(); } + + void setName(const Name &name) override { + throw errors::NotImplementedException(); + } + + void setName(Name &&name) override { + throw errors::NotImplementedException(); + } + + void setLifetime(uint32_t lifetime) override { + throw errors::NotImplementedException(); + } + + uint32_t getLifetime() const override { + throw errors::NotImplementedException(); + } + + void setLocator(const ip_address_t &locator) override { + throw errors::NotImplementedException(); + } + + void resetForHash() override { throw errors::NotImplementedException(); } + + ip_address_t getLocator() const override { + throw errors::NotImplementedException(); + } +}; + +namespace { +// The fixture for testing class Foo. +class PacketTest : public ::testing::Test { + protected: + PacketTest() + : name_("b001::123|321"), + packet(Packet::COPY_BUFFER, &raw_packets_[HF_INET6_TCP][0], + raw_packets_[HF_INET6_TCP].size()) { + // You can do set-up work for each test here. + } + + virtual ~PacketTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() { + // Code here will be called immediately after the constructor (right + // before each test). + } + + virtual void TearDown() { + // Code here will be called immediately after each test (right + // before the destructor). + } + + Name name_; + + PacketForTest packet; + + static std::map<Packet::Format, std::vector<uint8_t>> raw_packets_; + + std::vector<uint8_t> payload = { + 0x11, 0x11, 0x01, 0x00, 0xb0, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad // , 0x00, 0x00, + // 0x00, 0x45, 0xa3, + // 0xd1, 0xf2, 0x2b, + // 0x94, 0x41, 0x22, + // 0xc9, 0x00, 0x00, + // 0x00, 0x44, 0xa3, + // 0xd1, 0xf2, 0x2b, + // 0x94, 0x41, 0x22, + // 0xc8 + }; +}; + +std::map<Packet::Format, std::vector<uint8_t>> PacketTest::raw_packets_ = { + {Packet::Format::HF_INET6_TCP, + + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(TCP_PROTO, 20 + PAYLOAD_SIZE), + // TCP src=0x1234 dst=0x4321, seq=0x0001 + TCP_HEADER(0x00), + // Payload + PAYLOAD}}, + + {Packet::Format::HF_INET_TCP, + {// IPv4 src=3.13.127.8, dst=192.168.1.92 + IPV4_HEADER(TCP_PROTO, 20 + PAYLOAD_SIZE), + // TCP src=0x1234 dst=0x4321, seq=0x0001 + TCP_HEADER(0x00), + // Other + PAYLOAD}}, + + {Packet::Format::HF_INET_ICMP, + {// IPv4 src=3.13.127.8, dst=192.168.1.92 + IPV4_HEADER(ICMP_PROTO, 64), + // ICMP echo request + ICMP_ECHO_REQUEST}}, + + {Packet::Format::HF_INET6_ICMP, + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(ICMP6_PROTO, 60), + // ICMP6 echo request + ICMP6_ECHO_REQUEST}}, + + {Packet::Format::HF_INET6_TCP_AH, + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(TCP_PROTO, 20 + 44 + 128), + // ICMP6 echo request + TCP_HEADER(0x18), + // hICN AH header + AH_HEADER}}, + + {Packet::Format::HF_INET_TCP_AH, + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV4_HEADER(TCP_PROTO, 20 + 44 + 128), + // ICMP6 echo request + TCP_HEADER(0x18), + // hICN AH header + AH_HEADER}}, + + // XXX No flag defined in ICMP header to signal AH header. + {Packet::Format::HF_INET_ICMP_AH, + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV4_HEADER(ICMP_PROTO, 64 + 44), + // ICMP6 echo request + ICMP_ECHO_REQUEST, + // hICN AH header + AH_HEADER}}, + + {Packet::Format::HF_INET6_ICMP_AH, + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(ICMP6_PROTO, 60 + 44), + // ICMP6 echo request + ICMP6_ECHO_REQUEST, + // hICN AH header + AH_HEADER}}, + +}; + +void testFormatConstructor(Packet::Format format = HF_UNSPEC) { + try { + PacketForTest packet(format); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown for " << format; + } +} + +void testFormatAndAdditionalHeaderConstructor(Packet::Format format, + std::size_t additional_header) { + PacketForTest packet(format, additional_header); + // Packet length should be the one of the normal header + the + // additional_header + + EXPECT_EQ(packet.headerSize(), + Packet::getHeaderSizeFromFormat(format) + additional_header); +} + +void testRawBufferConstructor(std::vector<uint8_t> packet, + Packet::Format format) { + try { + // Try to construct packet from correct buffer + PacketForTest p(Packet::WRAP_BUFFER, &packet[0], packet.size(), + packet.size()); + + // Check format is expected one. + EXPECT_EQ(p.getFormat(), format); + + // // Try the same using a MemBuf + // auto buf = utils::MemBuf::wrapBuffer(&packet[0], packet.size()); + // buf->append(packet.size()); + // PacketForTest p2(std::move(buf)); + + // EXPECT_EQ(p2.getFormat(), format); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown"; + } + + try { + // Try to construct packet from wrong buffer + + // Modify next header to 0 + /* ipv6 */ + packet[6] = 0x00; + /* ipv4 */ + packet[9] = 0x00; + PacketForTest p(Packet::WRAP_BUFFER, &packet[0], packet.size(), + packet.size()); + + // Format should fallback to HF_UNSPEC + EXPECT_EQ(p.getFormat(), HF_UNSPEC); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown."; + } +} + +void getHeaderSizeFromBuffer(Packet::Format format, + std::vector<uint8_t> &packet, + std::size_t expected) { + auto header_size = PacketForTest::getHeaderSizeFromBuffer(format, &packet[0]); + EXPECT_EQ(header_size, expected); +} + +void getHeaderSizeFromFormat(Packet::Format format, std::size_t expected) { + auto header_size = PacketForTest::getHeaderSizeFromFormat(format); + EXPECT_EQ(header_size, expected); +} + +void getPayloadSizeFromBuffer(Packet::Format format, + std::vector<uint8_t> &packet, + std::size_t expected) { + auto payload_size = + PacketForTest::getPayloadSizeFromBuffer(format, &packet[0]); + EXPECT_EQ(payload_size, expected); +} + +void getFormatFromBuffer(Packet::Format expected, + std::vector<uint8_t> &packet) { + auto format = PacketForTest::getFormatFromBuffer(&packet[0], packet.size()); + EXPECT_EQ(format, expected); +} + +void getHeaderSize(std::size_t expected, const PacketForTest &packet) { + auto size = packet.headerSize(); + EXPECT_EQ(size, expected); +} + +void testGetFormat(Packet::Format expected, const Packet &packet) { + auto format = packet.getFormat(); + EXPECT_EQ(format, expected); +} + +} // namespace + +TEST_F(PacketTest, ConstructorWithFormat) { + testFormatConstructor(Packet::Format::HF_INET_TCP); + testFormatConstructor(Packet::Format::HF_INET6_TCP); + testFormatConstructor(Packet::Format::HF_INET_ICMP); + testFormatConstructor(Packet::Format::HF_INET6_ICMP); + testFormatConstructor(Packet::Format::HF_INET_TCP_AH); + testFormatConstructor(Packet::Format::HF_INET6_TCP_AH); + testFormatConstructor(Packet::Format::HF_INET_ICMP_AH); + testFormatConstructor(Packet::Format::HF_INET6_ICMP_AH); +} + +TEST_F(PacketTest, ConstructorWithFormatAndAdditionalHeader) { + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET_TCP, 123); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET6_TCP, 360); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET_ICMP, 21); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET6_ICMP, 444); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET_TCP_AH, 555); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET6_TCP_AH, + 321); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET_ICMP_AH, + 123); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET6_ICMP_AH, + 44); +} + +TEST_F(PacketTest, ConstructorWithNew) { + auto &_packet = raw_packets_[HF_INET6_TCP]; + auto packet_ptr = new PacketForTest(Packet::WRAP_BUFFER, &_packet[0], + _packet.size(), _packet.size()); + (void)packet_ptr; +} + +TEST_F(PacketTest, ConstructorWithRawBufferInet6Tcp) { + auto format = Packet::Format::HF_INET6_TCP; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, ConstructorWithRawBufferInetTcp) { + auto format = Packet::Format::HF_INET_TCP; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, ConstructorWithRawBufferInetIcmp) { + auto format = Packet::Format::HF_INET_ICMP; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, ConstructorWithRawBufferInet6Icmp) { + auto format = Packet::Format::HF_INET6_ICMP; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, ConstructorWithRawBufferInet6TcpAh) { + auto format = Packet::Format::HF_INET6_TCP_AH; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, ConstructorWithRawBufferInetTcpAh) { + auto format = Packet::Format::HF_INET_TCP_AH; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, MoveConstructor) { + PacketForTest p0(Packet::Format::HF_INET6_TCP); + PacketForTest p1(std::move(p0)); + EXPECT_EQ(p0.getFormat(), Packet::Format::HF_UNSPEC); + EXPECT_EQ(p1.getFormat(), Packet::Format::HF_INET6_TCP); +} + +TEST_F(PacketTest, TestGetHeaderSizeFromBuffer) { + getHeaderSizeFromBuffer(HF_INET6_TCP, raw_packets_[HF_INET6_TCP], + HICN_V6_TCP_HDRLEN); + getHeaderSizeFromBuffer(HF_INET_TCP, raw_packets_[HF_INET_TCP], + HICN_V4_TCP_HDRLEN); + getHeaderSizeFromBuffer(HF_INET6_ICMP, raw_packets_[HF_INET6_ICMP], + IPV6_HDRLEN + 4); + getHeaderSizeFromBuffer(HF_INET_ICMP, raw_packets_[HF_INET_ICMP], + IPV4_HDRLEN + 4); + getHeaderSizeFromBuffer(HF_INET6_TCP_AH, raw_packets_[HF_INET6_TCP_AH], + HICN_V6_TCP_AH_HDRLEN + 128); + getHeaderSizeFromBuffer(HF_INET_TCP_AH, raw_packets_[HF_INET_TCP_AH], + HICN_V4_TCP_AH_HDRLEN + 128); +} + +TEST_F(PacketTest, TestGetHeaderSizeFromFormat) { + getHeaderSizeFromFormat(HF_INET6_TCP, HICN_V6_TCP_HDRLEN); + getHeaderSizeFromFormat(HF_INET_TCP, HICN_V4_TCP_HDRLEN); + getHeaderSizeFromFormat(HF_INET6_ICMP, IPV6_HDRLEN + 4); + getHeaderSizeFromFormat(HF_INET_ICMP, IPV4_HDRLEN + 4); + getHeaderSizeFromFormat(HF_INET6_TCP_AH, HICN_V6_TCP_AH_HDRLEN); + getHeaderSizeFromFormat(HF_INET_TCP_AH, HICN_V4_TCP_AH_HDRLEN); +} + +TEST_F(PacketTest, TestGetPayloadSizeFromBuffer) { + getPayloadSizeFromBuffer(HF_INET6_TCP, raw_packets_[HF_INET6_TCP], 12); + getPayloadSizeFromBuffer(HF_INET_TCP, raw_packets_[HF_INET_TCP], 12); + getPayloadSizeFromBuffer(HF_INET6_ICMP, raw_packets_[HF_INET6_ICMP], 56); + getPayloadSizeFromBuffer(HF_INET_ICMP, raw_packets_[HF_INET_ICMP], 60); + getPayloadSizeFromBuffer(HF_INET6_TCP_AH, raw_packets_[HF_INET6_TCP_AH], 0); + getPayloadSizeFromBuffer(HF_INET_TCP_AH, raw_packets_[HF_INET_TCP_AH], 0); +} + +TEST_F(PacketTest, TestIsInterest) { + auto ret = PacketForTest::isInterest(&raw_packets_[HF_INET6_TCP][0]); + + EXPECT_TRUE(ret); +} + +TEST_F(PacketTest, TestGetFormatFromBuffer) { + getFormatFromBuffer(HF_INET6_TCP, raw_packets_[HF_INET6_TCP]); + getFormatFromBuffer(HF_INET_TCP, raw_packets_[HF_INET_TCP]); + getFormatFromBuffer(HF_INET6_ICMP, raw_packets_[HF_INET6_ICMP]); + getFormatFromBuffer(HF_INET_ICMP, raw_packets_[HF_INET_ICMP]); + getFormatFromBuffer(HF_INET6_TCP_AH, raw_packets_[HF_INET6_TCP_AH]); + getFormatFromBuffer(HF_INET_TCP_AH, raw_packets_[HF_INET_TCP_AH]); +} + +// TEST_F(PacketTest, TestReplace) { +// PacketForTest packet(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_TCP][0], +// raw_packets_[HF_INET6_TCP].size()); + +// // Replace current packet with another one +// packet.replace(&raw_packets_[HF_INET_TCP][0], +// raw_packets_[HF_INET_TCP].size()); + +// // Check new format +// ASSERT_EQ(packet.getFormat(), HF_INET_TCP); +// } + +TEST_F(PacketTest, TestPayloadSize) { + // Check payload size of existing packet + auto &_packet = raw_packets_[HF_INET6_TCP]; + PacketForTest packet(Packet::WRAP_BUFFER, &_packet[0], _packet.size(), + _packet.size()); + + EXPECT_EQ(packet.payloadSize(), std::size_t(PAYLOAD_SIZE)); + + // Check for dynamic generated packet + std::string payload0(1024, 'X'); + + // Create the packet + PacketForTest packet2(HF_INET6_TCP); + + // Payload size should now be zero + EXPECT_EQ(packet2.payloadSize(), std::size_t(0)); + + // Append payload 1 time + packet2.appendPayload((const uint8_t *)payload0.c_str(), payload0.size()); + + // size should now be 1024 + EXPECT_EQ(packet2.payloadSize(), std::size_t(1024)); + + // Append second payload + std::string payload1(1024, 'X'); + packet2.appendPayload((const uint8_t *)payload1.c_str(), payload1.size()); + + // Check size is 2048 + EXPECT_EQ(packet2.payloadSize(), std::size_t(2048)); + + // Append Membuf + packet2.appendPayload(utils::MemBuf::copyBuffer( + (const uint8_t *)payload1.c_str(), payload1.size())); + + // Check size is 3072 + EXPECT_EQ(packet2.payloadSize(), std::size_t(3072)); +} + +TEST_F(PacketTest, TestHeaderSize) { + getHeaderSize(HICN_V6_TCP_HDRLEN, + PacketForTest(Packet::Format::HF_INET6_TCP)); + getHeaderSize(HICN_V4_TCP_HDRLEN, PacketForTest(Packet::Format::HF_INET_TCP)); + getHeaderSize(HICN_V6_ICMP_HDRLEN, + PacketForTest(Packet::Format::HF_INET6_ICMP)); + getHeaderSize(HICN_V4_ICMP_HDRLEN, + PacketForTest(Packet::Format::HF_INET_ICMP)); + getHeaderSize(HICN_V6_TCP_AH_HDRLEN, + PacketForTest(Packet::Format::HF_INET6_TCP_AH)); + getHeaderSize(HICN_V4_TCP_AH_HDRLEN, + PacketForTest(Packet::Format::HF_INET_TCP_AH)); +} + +TEST_F(PacketTest, TestMemBufReference) { + // Create packet + auto &_packet = raw_packets_[HF_INET6_TCP]; + + // Packet was not created as a shared_ptr. If we try to get a membuf shared + // ptr we should get an exception. + // TODO test with c++ 17 + // try { + // PacketForTest packet(&_packet[0], _packet.size()); + // auto membuf_ref = packet.acquireMemBufReference(); + // FAIL() << "The acquireMemBufReference() call should have throwed an " + // "exception!"; + // } catch (const std::bad_weak_ptr &e) { + // // Ok + // } catch (...) { + // FAIL() << "Not expected exception."; + // } + + auto packet_ptr = std::make_shared<PacketForTest>( + Packet::WRAP_BUFFER, &_packet[0], _packet.size(), _packet.size()); + PacketForTest &packet = *packet_ptr; + + // Acquire a reference to the membuf + auto membuf_ref = packet.acquireMemBufReference(); + + // Check refcount. It should be 2 + EXPECT_EQ(membuf_ref.use_count(), 2); + + // Now increment membuf references + Packet::MemBufPtr membuf = packet.acquireMemBufReference(); + + // Now reference count should be 2 + EXPECT_EQ(membuf_ref.use_count(), 3); + + // Copy again + Packet::MemBufPtr membuf2 = membuf; + + // Now reference count should be 3 + EXPECT_EQ(membuf_ref.use_count(), 4); +} + +TEST_F(PacketTest, TestReset) { + // Check everything is ok + EXPECT_EQ(packet.getFormat(), HF_INET6_TCP); + EXPECT_EQ(packet.length(), raw_packets_[HF_INET6_TCP].size()); + EXPECT_EQ(packet.headerSize(), HICN_V6_TCP_HDRLEN); + EXPECT_EQ(packet.payloadSize(), packet.length() - packet.headerSize()); + + // Reset the packet + packet.reset(); + + // Rerun test + EXPECT_EQ(packet.getFormat(), HF_UNSPEC); + EXPECT_EQ(packet.length(), std::size_t(0)); + EXPECT_EQ(packet.headerSize(), std::size_t(0)); + EXPECT_EQ(packet.payloadSize(), std::size_t(0)); +} + +TEST_F(PacketTest, TestAppendPayload) { + // Append payload with raw buffer + uint8_t raw_buffer[2048]; + auto original_payload_length = packet.payloadSize(); + packet.appendPayload(raw_buffer, 1024); + + EXPECT_EQ(original_payload_length + 1024, packet.payloadSize()); + + for (int i = 0; i < 10; i++) { + // Append other payload 10 times + packet.appendPayload(raw_buffer, 1024); + EXPECT_EQ(original_payload_length + 1024 + (1024) * (i + 1), + packet.payloadSize()); + } + + // Append payload using membuf + packet.appendPayload(utils::MemBuf::copyBuffer(raw_buffer, 2048)); + EXPECT_EQ(original_payload_length + 1024 + 1024 * 10 + 2048, + packet.payloadSize()); + + // Check the underlying MemBuf length is the expected one + utils::MemBuf *current = &packet; + size_t total = 0; + do { + total += current->length(); + current = current->next(); + } while (current != &packet); + + EXPECT_EQ(total, packet.headerSize() + packet.payloadSize()); + + // LEt's try now to reset this packet + packet.reset(); + + // There should be no more bufferls left in the chain + EXPECT_EQ(&packet, packet.next()); + EXPECT_EQ(packet.getFormat(), HF_UNSPEC); + EXPECT_EQ(packet.length(), std::size_t(0)); + EXPECT_EQ(packet.headerSize(), std::size_t(0)); + EXPECT_EQ(packet.payloadSize(), std::size_t(0)); +} + +TEST_F(PacketTest, GetPayload) { + // Append payload with raw buffer + uint8_t raw_buffer[2048]; + auto original_payload_length = packet.payloadSize(); + packet.appendPayload(raw_buffer, 2048); + + // Get payload + auto payload = packet.getPayload(); + // Check payload length is correct + utils::MemBuf *current = payload.get(); + size_t total = 0; + do { + total += current->length(); + current = current->next(); + } while (current != payload.get()); + + ASSERT_EQ(total, packet.payloadSize()); + + // Linearize the payload + payload->gather(total); + + // Check memory correspond + payload->trimStart(original_payload_length); + auto ret = memcmp(raw_buffer, payload->data(), 2048); + EXPECT_EQ(ret, 0); +} + +TEST_F(PacketTest, UpdateLength) { + auto original_payload_size = packet.payloadSize(); + + // Add some fake payload without using the API + packet.append(200); + + // payloadSize does not know about the new payload, yet + EXPECT_EQ(packet.payloadSize(), original_payload_size); + + // Let's now update the packet length + packet.updateLength(); + + // Now payloadSize knows + EXPECT_EQ(packet.payloadSize(), std::size_t(original_payload_size + 200)); + + // We may also update the length without adding real content. This is only + // written in the packet header. + packet.updateLength(128); + EXPECT_EQ(packet.payloadSize(), + std::size_t(original_payload_size + 200 + 128)); +} + +TEST_F(PacketTest, SetGetPayloadType) { + auto payload_type = packet.getPayloadType(); + + // It should be normal content object by default + EXPECT_EQ(payload_type, PayloadType::DATA); + + // Set it to be manifest + packet.setPayloadType(PayloadType::MANIFEST); + + // Check it is manifest + payload_type = packet.getPayloadType(); + + EXPECT_EQ(payload_type, PayloadType::MANIFEST); +} + +TEST_F(PacketTest, GetFormat) { + { + PacketForTest p0(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET_TCP][0], + raw_packets_[Packet::Format::HF_INET_TCP].size(), + raw_packets_[Packet::Format::HF_INET_TCP].size()); + testGetFormat(Packet::Format::HF_INET_TCP, p0); + + PacketForTest p1(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET6_TCP][0], + raw_packets_[Packet::Format::HF_INET6_TCP].size(), + raw_packets_[Packet::Format::HF_INET6_TCP].size()); + testGetFormat(Packet::Format::HF_INET6_TCP, p1); + + PacketForTest p2(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET_ICMP][0], + raw_packets_[Packet::Format::HF_INET_ICMP].size(), + raw_packets_[Packet::Format::HF_INET_ICMP].size()); + testGetFormat(Packet::Format::HF_INET_ICMP, p2); + + PacketForTest p3(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET6_ICMP][0], + raw_packets_[Packet::Format::HF_INET6_ICMP].size(), + raw_packets_[Packet::Format::HF_INET6_ICMP].size()); + testGetFormat(Packet::Format::HF_INET6_ICMP, p3); + + PacketForTest p4(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET_TCP_AH][0], + raw_packets_[Packet::Format::HF_INET_TCP_AH].size(), + raw_packets_[Packet::Format::HF_INET_TCP_AH].size()); + testGetFormat(Packet::Format::HF_INET_TCP_AH, p4); + + PacketForTest p5(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET6_TCP_AH][0], + raw_packets_[Packet::Format::HF_INET6_TCP_AH].size(), + raw_packets_[Packet::Format::HF_INET6_TCP_AH].size()); + testGetFormat(Packet::Format::HF_INET6_TCP_AH, p5); + } + + // Let's try now creating empty packets + { + PacketForTest p0(Packet::Format::HF_INET_TCP); + testGetFormat(Packet::Format::HF_INET_TCP, p0); + + PacketForTest p1(Packet::Format::HF_INET6_TCP); + testGetFormat(Packet::Format::HF_INET6_TCP, p1); + + PacketForTest p2(Packet::Format::HF_INET_ICMP); + testGetFormat(Packet::Format::HF_INET_ICMP, p2); + + PacketForTest p3(Packet::Format::HF_INET6_ICMP); + testGetFormat(Packet::Format::HF_INET6_ICMP, p3); + + PacketForTest p4(Packet::Format::HF_INET_TCP_AH); + testGetFormat(Packet::Format::HF_INET_TCP_AH, p4); + + PacketForTest p5(Packet::Format::HF_INET6_TCP_AH); + testGetFormat(Packet::Format::HF_INET6_TCP_AH, p5); + } +} + +TEST_F(PacketTest, SetGetTestSignatureTimestamp) { + // Let's try to set the signature timestamp in a packet without AH header. We + // expect an exception. + using namespace std::chrono; + uint64_t now = + duration_cast<milliseconds>(system_clock::now().time_since_epoch()) + .count(); + + try { + packet.setSignatureTimestamp(now); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Same fot get method + try { + auto t = packet.getSignatureTimestamp(); + // Let's make compiler happy + (void)t; + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Now let's construct a AH packet, with no additional space for signature + PacketForTest p(HF_INET6_TCP_AH); + p.setSignatureTimestamp(now); + uint64_t now_get = p.getSignatureTimestamp(); + + // Check we got the right value + EXPECT_EQ(now_get, now); +} + +TEST_F(PacketTest, TestSetGetValidationAlgorithm) { + // Let's try to set the validation algorithm in a packet without AH header. We + // expect an exception. + + try { + packet.setValidationAlgorithm(auth::CryptoSuite::RSA_SHA256); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Same fot get method + try { + auto v = packet.getSignatureTimestamp(); + // Let's make compiler happy + (void)v; + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Now let's construct a AH packet, with no additional space for signature + PacketForTest p(HF_INET6_TCP_AH); + p.setValidationAlgorithm(auth::CryptoSuite::RSA_SHA256); + auto v_get = p.getValidationAlgorithm(); + + // Check we got the right value + EXPECT_EQ(v_get, auth::CryptoSuite::RSA_SHA256); +} + +TEST_F(PacketTest, TestSetGetKeyId) { + uint8_t key[32]; + auth::KeyId key_id = std::make_pair(key, sizeof(key)); + + try { + packet.setKeyId(key_id); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Same fot get method + try { + auto k = packet.getKeyId(); + // Let's make compiler happy + (void)k; + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Now let's construct a AH packet, with no additional space for signature + PacketForTest p(HF_INET6_TCP_AH); + p.setKeyId(key_id); + auto p_get = p.getKeyId(); + + // Check we got the right value + EXPECT_EQ(p_get.second, key_id.second); + + auto ret = memcmp(p_get.first, key_id.first, p_get.second); + EXPECT_EQ(ret, 0); +} + +TEST_F(PacketTest, DISABLED_TestChecksum) { + // Checksum should be wrong + bool integrity = packet.checkIntegrity(); + EXPECT_FALSE(integrity); + + // Let's fix it + packet.setChecksum(); + + // Check again + integrity = packet.checkIntegrity(); + EXPECT_TRUE(integrity); + + // Check with AH header and 300 bytes signature + PacketForTest p(HF_INET6_TCP_AH, 300); + std::string payload(5000, 'X'); + p.appendPayload((const uint8_t *)payload.c_str(), payload.size() / 2); + p.appendPayload((const uint8_t *)(payload.c_str() + payload.size() / 2), + payload.size() / 2); + + p.setChecksum(); + integrity = p.checkIntegrity(); + EXPECT_TRUE(integrity); +} + +TEST_F(PacketTest, TestSetSyn) { + // Test syn of non-tcp format and check exception is thrown + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setSyn(); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setSyn(); + EXPECT_TRUE(packet.testSyn()); + + packet.resetSyn(); + EXPECT_FALSE(packet.testSyn()); +} + +TEST_F(PacketTest, TestSetFin) { + // Test syn of non-tcp format and check exception is thrown + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setFin(); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setFin(); + EXPECT_TRUE(packet.testFin()); + + packet.resetFin(); + EXPECT_FALSE(packet.testFin()); +} + +TEST_F(PacketTest, TestSetAck) { + // Test syn of non-tcp format and check exception is thrown + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setAck(); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setAck(); + EXPECT_TRUE(packet.testAck()); + + packet.resetAck(); + EXPECT_FALSE(packet.testAck()); +} + +TEST_F(PacketTest, TestSetRst) { + // Test syn of non-tcp format and check exception is thrown + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setRst(); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setRst(); + EXPECT_TRUE(packet.testRst()); + + packet.resetRst(); + EXPECT_FALSE(packet.testRst()); +} + +TEST_F(PacketTest, TestResetFlags) { + packet.setRst(); + packet.setSyn(); + packet.setAck(); + packet.setFin(); + EXPECT_TRUE(packet.testRst()); + EXPECT_TRUE(packet.testAck()); + EXPECT_TRUE(packet.testFin()); + EXPECT_TRUE(packet.testSyn()); + + packet.resetFlags(); + EXPECT_FALSE(packet.testRst()); + EXPECT_FALSE(packet.testAck()); + EXPECT_FALSE(packet.testFin()); + EXPECT_FALSE(packet.testSyn()); +} + +TEST_F(PacketTest, TestSetGetSrcPort) { + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setSrcPort(12345); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setSrcPort(12345); + EXPECT_EQ(packet.getSrcPort(), 12345); +} + +TEST_F(PacketTest, TestSetGetDstPort) { + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setDstPort(12345); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setDstPort(12345); + EXPECT_EQ(packet.getDstPort(), 12345); +} + +TEST_F(PacketTest, TestEnsureCapacity) { + PacketForTest &p = packet; + + // This shoul be false + auto ret = p.ensureCapacity(raw_packets_[HF_INET6_TCP].size() + 10); + EXPECT_FALSE(ret); + + // This should be true + ret = p.ensureCapacity(raw_packets_[HF_INET6_TCP].size()); + EXPECT_TRUE(ret); + + // This should be true + ret = p.ensureCapacity(raw_packets_[HF_INET6_TCP].size() - 10); + EXPECT_TRUE(ret); + + // Try to trim the packet start + p.trimStart(10); + // Now this should be false + ret = p.ensureCapacity(raw_packets_[HF_INET6_TCP].size()); + EXPECT_FALSE(ret); + + // Create a new packet + auto p2 = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + + p2.appendPayload(utils::MemBuf::createCombined(2000)); + + // This should be false, since the buffer is chained + ret = p2.ensureCapacity(raw_packets_[HF_INET6_TCP].size() - 10); + EXPECT_FALSE(ret); +} + +TEST_F(PacketTest, TestEnsureCapacityAndFillUnused) { + // Create packet by excluding the payload (So only L3 + L4 headers). The + // payload will be trated as unused tailroom + PacketForTest p = + PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_TCP][0], + raw_packets_[HF_INET6_TCP].size() - PAYLOAD_SIZE, + raw_packets_[HF_INET6_TCP].size()); + + // Copy original packet payload, which is here trated as a unused tailroom + uint8_t original_payload[PAYLOAD_SIZE]; + uint8_t *payload = &raw_packets_[HF_INET6_TCP][0] + + raw_packets_[HF_INET6_TCP].size() - PAYLOAD_SIZE; + std::memcpy(original_payload, payload, PAYLOAD_SIZE); + + // This should be true and the unused tailroom should be unmodified + auto ret = p.ensureCapacityAndFillUnused( + raw_packets_[HF_INET6_TCP].size() - (PAYLOAD_SIZE + 10), 0); + EXPECT_TRUE(ret); + ret = std::memcmp(original_payload, payload, PAYLOAD_SIZE); + EXPECT_EQ(ret, 0); + + // This should fill the payload with zeros + ret = p.ensureCapacityAndFillUnused(raw_packets_[HF_INET6_TCP].size(), 0); + EXPECT_TRUE(ret); + uint8_t zeros[PAYLOAD_SIZE]; + std::memset(zeros, 0, PAYLOAD_SIZE); + ret = std::memcmp(payload, zeros, PAYLOAD_SIZE); + EXPECT_EQ(ret, 0); + + // This should fill the payload with ones + ret = p.ensureCapacityAndFillUnused(raw_packets_[HF_INET6_TCP].size(), 1); + EXPECT_TRUE(ret); + uint8_t ones[PAYLOAD_SIZE]; + std::memset(ones, 1, PAYLOAD_SIZE); + ret = std::memcmp(payload, ones, PAYLOAD_SIZE); + EXPECT_EQ(ret, 0); + + // This should return false and the payload should be unmodified + ret = p.ensureCapacityAndFillUnused(raw_packets_[HF_INET6_TCP].size() + 1, 1); + EXPECT_FALSE(ret); + ret = std::memcmp(payload, ones, PAYLOAD_SIZE); + EXPECT_EQ(ret, 0); +} + +TEST_F(PacketTest, TestSetGetTTL) { + packet.setTTL(128); + EXPECT_EQ(packet.getTTL(), 128); +} + +} // namespace core +} // namespace transport + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/libtransport/src/transport.config b/libtransport/src/transport.config new file mode 100644 index 000000000..a21175b8d --- /dev/null +++ b/libtransport/src/transport.config @@ -0,0 +1,27 @@ +// Configuration for io_module + +io_module = { + path = []; + name = "forwarder_module"; +}; + +forwarder = { + n_threads = 1; + + listeners = { + l0 = { + local_address = "127.0.0.1"; + local_port = 33436; + } + }; + + connectors = { + c0 = { + /* local_address and local_port are optional */ + local_address = "127.0.0.1"; + local_port = 33436; + remote_address = "10.20.30.40"; + remote_port = 33436; + } + }; +};
\ No newline at end of file diff --git a/libtransport/src/utils/content_store.cc b/libtransport/src/utils/content_store.cc index cb3db6d94..c5cb91149 100644 --- a/libtransport/src/utils/content_store.cc +++ b/libtransport/src/utils/content_store.cc @@ -17,7 +17,6 @@ #include <hicn/transport/core/interest.h> #include <hicn/transport/core/name.h> #include <hicn/transport/utils/log.h> - #include <utils/content_store.h> namespace utils { @@ -60,8 +59,7 @@ void ContentStore::insert( ObjectTimeEntry(content_object, std::chrono::steady_clock::now()), pos); } -const std::shared_ptr<ContentObject> ContentStore::find( - const Interest &interest) { +std::shared_ptr<ContentObject> ContentStore::find(const Interest &interest) { utils::SpinLock::Acquire locked(cs_mutex_); std::shared_ptr<ContentObject> ret = empty_reference_; diff --git a/libtransport/src/utils/content_store.h b/libtransport/src/utils/content_store.h index 03ce76f42..56cd2abb6 100644 --- a/libtransport/src/utils/content_store.h +++ b/libtransport/src/utils/content_store.h @@ -52,7 +52,7 @@ class ContentStore { void insert(const std::shared_ptr<ContentObject> &content_object); - const std::shared_ptr<ContentObject> find(const Interest &interest); + std::shared_ptr<ContentObject> find(const Interest &interest); void erase(const Name &exact_name); diff --git a/libtransport/src/utils/daemonizator.cc b/libtransport/src/utils/daemonizator.cc index c51a68d14..bc7bae700 100644 --- a/libtransport/src/utils/daemonizator.cc +++ b/libtransport/src/utils/daemonizator.cc @@ -17,7 +17,6 @@ #include <hicn/transport/errors/runtime_exception.h> #include <hicn/transport/utils/daemonizator.h> #include <hicn/transport/utils/log.h> - #include <sys/stat.h> #include <unistd.h> diff --git a/libtransport/src/utils/epoll_event_reactor.cc b/libtransport/src/utils/epoll_event_reactor.cc index 63c08df95..eb8c65352 100644 --- a/libtransport/src/utils/epoll_event_reactor.cc +++ b/libtransport/src/utils/epoll_event_reactor.cc @@ -14,12 +14,11 @@ */ #include <hicn/transport/utils/branch_prediction.h> - +#include <signal.h> +#include <unistd.h> #include <utils/epoll_event_reactor.h> #include <utils/fd_deadline_timer.h> -#include <signal.h> -#include <unistd.h> #include <iostream> namespace utils { @@ -111,7 +110,7 @@ void EpollEventReactor::runEventLoop(int timeout) { if (errno == EINTR) { continue; } else { - return; + return; } } diff --git a/libtransport/src/utils/epoll_event_reactor.h b/libtransport/src/utils/epoll_event_reactor.h index 4cb87ebd4..9ebfca937 100644 --- a/libtransport/src/utils/epoll_event_reactor.h +++ b/libtransport/src/utils/epoll_event_reactor.h @@ -16,9 +16,9 @@ #pragma once #include <hicn/transport/utils/spinlock.h> +#include <sys/epoll.h> #include <utils/event_reactor.h> -#include <sys/epoll.h> #include <atomic> #include <cstddef> #include <functional> diff --git a/libtransport/src/utils/fd_deadline_timer.h b/libtransport/src/utils/fd_deadline_timer.h index 8bc3bbca3..38396e027 100644 --- a/libtransport/src/utils/fd_deadline_timer.h +++ b/libtransport/src/utils/fd_deadline_timer.h @@ -17,16 +17,14 @@ #include <hicn/transport/errors/runtime_exception.h> #include <hicn/transport/utils/log.h> - +#include <sys/timerfd.h> +#include <unistd.h> #include <utils/deadline_timer.h> #include <utils/epoll_event_reactor.h> #include <chrono> #include <cstddef> -#include <sys/timerfd.h> -#include <unistd.h> - namespace utils { class FdDeadlineTimer : public DeadlineTimer<FdDeadlineTimer> { diff --git a/libtransport/src/utils/membuf.cc b/libtransport/src/utils/membuf.cc index 94e5b13a1..73c45cf6d 100644 --- a/libtransport/src/utils/membuf.cc +++ b/libtransport/src/utils/membuf.cc @@ -145,6 +145,18 @@ void MemBuf::operator delete(void* /* ptr */, void* /* placement */) { // constructor. } +bool MemBuf::operator==(const MemBuf& other) { + if (length() != other.length()) { + return false; + } + + return (memcmp(data(), other.data(), length()) == 0); +} + +bool MemBuf::operator!=(const MemBuf& other) { + return !this->operator==(other); +} + void MemBuf::releaseStorage(HeapStorage* storage, uint16_t freeFlags) { // Use relaxed memory order here. If we are unlucky and happen to get // out-of-date data the compare_exchange_weak() call below will catch @@ -299,21 +311,23 @@ unique_ptr<MemBuf> MemBuf::takeOwnership(void* buf, std::size_t capacity, } } -MemBuf::MemBuf(WrapBufferOp, const void* buf, std::size_t capacity) noexcept +MemBuf::MemBuf(WrapBufferOp, const void* buf, std::size_t length, + std::size_t capacity) noexcept : MemBuf(InternalConstructor(), 0, // We cast away the const-ness of the buffer here. // This is okay since MemBuf users must use unshare() to create a // copy of this buffer before writing to the buffer. static_cast<uint8_t*>(const_cast<void*>(buf)), capacity, - static_cast<uint8_t*>(const_cast<void*>(buf)), capacity) {} + static_cast<uint8_t*>(const_cast<void*>(buf)), length) {} -unique_ptr<MemBuf> MemBuf::wrapBuffer(const void* buf, std::size_t capacity) { - return std::make_unique<MemBuf>(WRAP_BUFFER, buf, capacity); +unique_ptr<MemBuf> MemBuf::wrapBuffer(const void* buf, std::size_t length, + std::size_t capacity) { + return std::make_unique<MemBuf>(WRAP_BUFFER, buf, length, capacity); } -MemBuf MemBuf::wrapBufferAsValue(const void* buf, +MemBuf MemBuf::wrapBufferAsValue(const void* buf, std::size_t length, std::size_t capacity) noexcept { - return MemBuf(WrapBufferOp::WRAP_BUFFER, buf, capacity); + return MemBuf(WrapBufferOp::WRAP_BUFFER, buf, length, capacity); } MemBuf::MemBuf() noexcept {} @@ -862,4 +876,22 @@ void MemBuf::initExtBuffer(uint8_t* buf, size_t mallocSize, *infoReturn = sharedInfo; } +bool MemBuf::ensureCapacity(std::size_t capacity) { + return !isChained() && std::size_t((bufferEnd() - data())) >= capacity; +} + +bool MemBuf::ensureCapacityAndFillUnused(std::size_t capacity, + uint8_t placeholder) { + auto ret = ensureCapacity(capacity); + if (!ret) { + return ret; + } + + if (length() < capacity) { + std::memset(writableTail(), placeholder, capacity - length()); + } + + return ret; +} + } // namespace utils
\ No newline at end of file diff --git a/libtransport/src/utils/memory_pool_allocator.h b/libtransport/src/utils/memory_pool_allocator.h index adc1443ad..a960b91bb 100644 --- a/libtransport/src/utils/memory_pool_allocator.h +++ b/libtransport/src/utils/memory_pool_allocator.h @@ -149,4 +149,4 @@ class Allocator : private MemoryPool<T, growSize> { void destroy(pointer p) { p->~T(); } }; -}
\ No newline at end of file +} // namespace utils
\ No newline at end of file diff --git a/libtransport/src/utils/min_filter.h b/libtransport/src/utils/min_filter.h index dcfd5652d..f1aaea7a8 100644 --- a/libtransport/src/utils/min_filter.h +++ b/libtransport/src/utils/min_filter.h @@ -43,6 +43,11 @@ class MinFilter { by_arrival_.push_front(by_order_.insert(std::forward<R>(value))); } + TRANSPORT_ALWAYS_INLINE void clear() { + by_arrival_.clear(); + by_order_.clear(); + } + TRANSPORT_ALWAYS_INLINE const T& begin() { return *by_order_.cbegin(); } TRANSPORT_ALWAYS_INLINE const T& rBegin() { return *by_order_.crbegin(); } |