diff options
author | Luca Muscariello <muscariello@ieee.org> | 2021-04-15 09:05:46 +0200 |
---|---|---|
committer | Mauro Sardara <msardara@cisco.com> | 2021-04-15 16:36:16 +0200 |
commit | e92e9e839ca2cf42b56322b2489ccc0d8bf767af (patch) | |
tree | 9f1647c83a87fbf982ae329e800af25dbfb226b5 /libtransport/src | |
parent | 3e541d7c947cc2f9db145f26c9274efd29a6fb56 (diff) |
[HICN-690] Transport Library Major Refactory
The current patch provides a major refactory of the transportlibrary.
A summary of the different components that underwent major modifications is
reported below.
- Transport protocol updates
The hierarchy of classes has been optimized to have common transport services
across different transport protocols. This can allow to customize a transport
protocol with new features.
- A new real-time communication protocol
The RTC protocol has been optimized in terms of algorithms to reduce
consumer-producer synchronization latency.
- A novel socket API
The API has been reworked to be easier to consumer but also to have a more
efficient integration in L4 proxies.
- Several performance improvements
A large number of performance improvements have been included in
particular to make the entire stack zero-copy and optimize cache miss.
- New memory buffer framework
Memory management has been reworked entirely to provide a more efficient infra
with a richer API. Buffers are now allocated in blocks and a single buffer
holds the memory for (1) the shared_ptr control block, (2) the metadata of the
packet (e.g. name, pointer to other buffers if buffer is chained and relevant
offsets), and (3) the packet itself, as it is sent/received over the network.
- A new slab allocator
Dynamic memory allocation is now managed by a novel slab allocator that is
optimised for packet processing and connection management. Memory is organized
in pools of blocks all of the same size which are used during the processing of
outgoing/incoming packets. When a memory block Is allocated is always taken
from a global pool and when it is deallocated is returned to the pool, thus
avoiding the cost of any heap allocation in the data path.
- New transport connectors
Consumer and producer end-points can communication either using an hicn packet
forwarder or with direct connector based on shared memories or sockets.
The usage of transport connectors typically for unit and funcitonal
testing but may have additional usage.
- Support for FEC/ECC for transport services
FEC/ECC via reed solomon is supported by default and made available to
transport services as a modular component. Reed solomon block codes is a
default FEC model that can be replaced in a modular way by many other
codes including RLNC not avaiable in this distribution.
The current FEC framework support variable size padding and efficiently
makes use of the infra memory buffers to avoid additiona copies.
- Secure transport framework for signature computation and verification
Crypto support is nativelty used in hICN for integrity and authenticity.
Novel support that includes RTC has been implemented and made modular
and reusable acrosso different transport protocols.
- TLS - Transport layer security over hicn
Point to point confidentiality is provided by integrating TLS on top of
hICN reliable and non-reliable transport. The integration is common and
makes a different use of the TLS record.
- MLS - Messaging layer security over hicn
MLS integration on top of hICN is made by using the MLSPP implemetation
open sourced by Cisco. We have included instrumentation tools to deploy
performance and functional tests of groups of end-points.
- Android support
The overall code has been heavily tested in Android environments and
has received heavy lifting to better run natively in recent Android OS.
Co-authored-by: Mauro Sardara <msardara@cisco.com>
Co-authored-by: Michele Papalini <micpapal@cisco.com>
Co-authored-by: Olivier Roques <oroques+fdio@cisco.com>
Co-authored-by: Giulio Grassi <gigrassi@cisco.com>
Change-Id: If477ba2fa686e6f47bdf96307ac60938766aef69
Signed-off-by: Luca Muscariello <muscariello@ieee.org>
Diffstat (limited to 'libtransport/src')
169 files changed, 16822 insertions, 1900 deletions
diff --git a/libtransport/src/CMakeLists.txt b/libtransport/src/CMakeLists.txt index 33497e0f4..0fa9bbe3c 100644 --- a/libtransport/src/CMakeLists.txt +++ b/libtransport/src/CMakeLists.txt @@ -20,7 +20,7 @@ set(ASIO_STANDALONE 1) add_subdirectory(core) add_subdirectory(interfaces) add_subdirectory(protocols) -add_subdirectory(security) +add_subdirectory(auth) add_subdirectory(implementation) add_subdirectory(utils) add_subdirectory(http) @@ -34,7 +34,16 @@ install( COMPONENT lib${LIBTRANSPORT}-dev ) -set (COMPILER_DEFINITIONS "-DTRANSPORT_LOG_DEF_LEVEL=TRANSPORT_LOG_${TRANSPORT_LOG_LEVEL}") +install( + FILES "transport.config" + DESTINATION ${CMAKE_INSTALL_FULL_SYSCONFDIR}/hicn + COMPONENT lib${LIBTRANSPORT} +) + +list(APPEND COMPILER_DEFINITIONS + "-DTRANSPORT_LOG_DEF_LEVEL=TRANSPORT_LOG_${TRANSPORT_LOG_LEVEL}" + "-DASIO_STANDALONE" +) list(INSERT LIBTRANSPORT_INTERNAL_INCLUDE_DIRS 0 ${CMAKE_CURRENT_SOURCE_DIR}/ @@ -55,8 +64,10 @@ else () set(CMAKE_SHARED_LINKER_FLAGS "/NODEFAULTLIB:\"MSVCRTD\"" ) endif () endif () -if (${CMAKE_SYSTEM_NAME} STREQUAL "Android") + +if (${CMAKE_SYSTEM_NAME} MATCHES "Android") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++ -isystem -lm") + add_subdirectory(io_modules) endif() if (DISABLE_SHARED_LIBRARIES) @@ -68,8 +79,9 @@ if (DISABLE_SHARED_LIBRARIES) DEPENDS ${DEPENDENCIES} COMPONENT lib${LIBTRANSPORT} INCLUDE_DIRS ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} - INSTALL_ROOT_DIR hicn/transport + HEADER_ROOT_DIR hicn/transport DEFINITIONS ${COMPILER_DEFINITIONS} + VERSION ${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION} ) else () build_library(${LIBTRANSPORT} @@ -80,11 +92,17 @@ else () DEPENDS ${DEPENDENCIES} COMPONENT lib${LIBTRANSPORT} INCLUDE_DIRS ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} - INSTALL_ROOT_DIR hicn/transport + HEADER_ROOT_DIR hicn/transport DEFINITIONS ${COMPILER_DEFINITIONS} + VERSION ${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_REVISION} ) endif () +# io modules +if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Android") + add_subdirectory(io_modules) +endif() + if (${BUILD_TESTS}) add_subdirectory(test) endif() diff --git a/libtransport/src/auth/CMakeLists.txt b/libtransport/src/auth/CMakeLists.txt new file mode 100644 index 000000000..0e7b5832b --- /dev/null +++ b/libtransport/src/auth/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) 2017-2019 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + +list(APPEND SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/signer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/verifier.cc + ${CMAKE_CURRENT_SOURCE_DIR}/identity.cc +) + +set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) diff --git a/libtransport/src/auth/identity.cc b/libtransport/src/auth/identity.cc new file mode 100644 index 000000000..bd787b9b6 --- /dev/null +++ b/libtransport/src/auth/identity.cc @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/auth/identity.h> + +using namespace std; + +namespace transport { +namespace auth { + +Identity::Identity(const string &keystore_path, const string &keystore_pwd, + CryptoSuite suite, unsigned int signature_len, + unsigned int validity_days, const string &subject_name) + : identity_(nullptr), signer_(nullptr) { + parcSecurity_Init(); + + bool success = parcPkcs12KeyStore_CreateFile( + keystore_path.c_str(), keystore_pwd.c_str(), subject_name.c_str(), + parcCryptoSuite_GetSigningAlgorithm(static_cast<PARCCryptoSuite>(suite)), + signature_len, validity_days); + + parcAssertTrue( + success, + "parcPkcs12KeyStore_CreateFile('%s', '%s', '%s', %d, %d, %d) failed.", + keystore_path.c_str(), keystore_pwd.c_str(), subject_name.c_str(), + static_cast<int>(suite), static_cast<int>(signature_len), validity_days); + + PARCIdentityFile *identity_file = + parcIdentityFile_Create(keystore_path.c_str(), keystore_pwd.c_str()); + + identity_ = + parcIdentity_Create(identity_file, PARCIdentityFileAsPARCIdentity); + + PARCSigner *signer = parcIdentity_CreateSigner( + identity_, + parcCryptoSuite_GetCryptoHash(static_cast<PARCCryptoSuite>(suite))); + + signer_ = make_shared<AsymmetricSigner>(signer); + + parcSigner_Release(&signer); + parcIdentityFile_Release(&identity_file); +} + +Identity::Identity(string &keystore_path, string &keystore_pwd, + CryptoHashType hash_type) + : identity_(nullptr), signer_(nullptr) { + parcSecurity_Init(); + + PARCIdentityFile *identity_file = + parcIdentityFile_Create(keystore_path.c_str(), keystore_pwd.c_str()); + + identity_ = + parcIdentity_Create(identity_file, PARCIdentityFileAsPARCIdentity); + + PARCSigner *signer = parcIdentity_CreateSigner( + identity_, static_cast<PARCCryptoHashType>(hash_type)); + + signer_ = make_shared<AsymmetricSigner>(signer); + + parcSigner_Release(&signer); + parcIdentityFile_Release(&identity_file); +} + +Identity::Identity(const Identity &other) + : identity_(nullptr), signer_(other.signer_) { + parcSecurity_Init(); + identity_ = parcIdentity_Acquire(other.identity_); +} + +Identity::Identity(Identity &&other) + : identity_(nullptr), signer_(move(other.signer_)) { + parcSecurity_Init(); + identity_ = parcIdentity_Acquire(other.identity_); + parcIdentity_Release(&other.identity_); +} + +Identity::~Identity() { + if (identity_) parcIdentity_Release(&identity_); + parcSecurity_Fini(); +} + +shared_ptr<AsymmetricSigner> Identity::getSigner() const { return signer_; } + +string Identity::getFilename() const { + return string(parcIdentity_GetFileName(identity_)); +} + +string Identity::getPassword() const { + return string(parcIdentity_GetPassWord(identity_)); +} + +Identity Identity::generateIdentity(const string &subject_name) { + string keystore_name = "keystore"; + string keystore_password = "password"; + size_t key_length = 1024; + unsigned int validity_days = 30; + CryptoSuite suite = CryptoSuite::RSA_SHA256; + + return Identity(keystore_name, keystore_password, suite, + (unsigned int)key_length, validity_days, subject_name); +} + +} // namespace auth +} // namespace transport diff --git a/libtransport/src/auth/signer.cc b/libtransport/src/auth/signer.cc new file mode 100644 index 000000000..281b9c59a --- /dev/null +++ b/libtransport/src/auth/signer.cc @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/auth/signer.h> + +extern "C" { +#ifndef _WIN32 +TRANSPORT_CLANG_DISABLE_WARNING("-Wextern-c-compat") +#endif +#include <hicn/hicn.h> +} + +#include <chrono> + +#define ALLOW_UNALIGNED_READS 1 + +using namespace std; + +namespace transport { +namespace auth { + +Signer::Signer() : signer_(nullptr), key_id_(nullptr) { parcSecurity_Init(); } + +Signer::Signer(PARCSigner *signer) : Signer() { setSigner(signer); } + +Signer::~Signer() { + if (signer_) parcSigner_Release(&signer_); + if (key_id_) parcKeyId_Release(&key_id_); + parcSecurity_Fini(); +} + +void Signer::signPacket(PacketPtr packet) { + parcAssertNotNull(signer_, "Expected non-null signer"); + + const utils::MemBuf &header_chain = *packet; + core::Packet::Format format = packet->getFormat(); + auto suite = getCryptoSuite(); + size_t signature_len = getSignatureSize(); + + if (!packet->authenticationHeader()) { + throw errors::MalformedAHPacketException(); + } + + packet->setSignatureSize(signature_len); + + // Copy IP+TCP / ICMP header before zeroing them + hicn_header_t header_copy; + hicn_packet_copy_header(format, packet->packet_start_, &header_copy, false); + packet->resetForHash(); + + // Fill in the HICN_AH header + auto now = chrono::duration_cast<chrono::milliseconds>( + chrono::system_clock::now().time_since_epoch()) + .count(); + packet->setSignatureTimestamp(now); + packet->setValidationAlgorithm(suite); + + // Set the key ID + KeyId key_id; + key_id.first = static_cast<uint8_t *>( + parcBuffer_Overlay((PARCBuffer *)parcKeyId_GetKeyId(key_id_), 0)); + packet->setKeyId(key_id); + + // Calculate hash + CryptoHasher hasher(parcSigner_GetCryptoHasher(signer_)); + const utils::MemBuf *current = &header_chain; + + hasher.init(); + + do { + hasher.updateBytes(current->data(), current->length()); + current = current->next(); + } while (current != &header_chain); + + CryptoHash hash = hasher.finalize(); + + // Compute signature + PARCSignature *signature = parcSigner_SignDigestNoAlloc( + signer_, hash.hash_, packet->getSignature(), signature_len); + PARCBuffer *buffer = parcSignature_GetSignature(signature); + size_t bytes_len = parcBuffer_Remaining(buffer); + + if (bytes_len > signature_len) { + throw errors::MalformedAHPacketException(); + } + + // Put signature in AH header + hicn_packet_copy_header(format, &header_copy, packet->packet_start_, false); + + // Release allocated objects + parcSignature_Release(&signature); +} + +void Signer::setSigner(PARCSigner *signer) { + parcAssertNotNull(signer, "Expected non-null signer"); + + if (signer_) parcSigner_Release(&signer_); + if (key_id_) parcKeyId_Release(&key_id_); + + signer_ = parcSigner_Acquire(signer); + key_id_ = parcSigner_CreateKeyId(signer_); +} + +size_t Signer::getSignatureSize() const { + parcAssertNotNull(signer_, "Expected non-null signer"); + return parcSigner_GetSignatureSize(signer_); +} + +CryptoSuite Signer::getCryptoSuite() const { + parcAssertNotNull(signer_, "Expected non-null signer"); + return static_cast<CryptoSuite>(parcSigner_GetCryptoSuite(signer_)); +} + +CryptoHashType Signer::getCryptoHashType() const { + parcAssertNotNull(signer_, "Expected non-null signer"); + return static_cast<CryptoHashType>(parcSigner_GetCryptoHashType(signer_)); +} + +PARCSigner *Signer::getParcSigner() const { return signer_; } + +PARCKeyStore *Signer::getParcKeyStore() const { + parcAssertNotNull(signer_, "Expected non-null signer"); + return parcSigner_GetKeyStore(signer_); +} + +AsymmetricSigner::AsymmetricSigner(CryptoSuite suite, PARCKeyStore *key_store) { + parcAssertNotNull(key_store, "Expected non-null key_store"); + + auto crypto_suite = static_cast<PARCCryptoSuite>(suite); + + switch (suite) { + case CryptoSuite::DSA_SHA256: + case CryptoSuite::RSA_SHA256: + case CryptoSuite::RSA_SHA512: + case CryptoSuite::ECDSA_256K1: + break; + default: + throw errors::RuntimeException( + "Invalid crypto suite for asymmetric signer"); + } + + setSigner( + parcSigner_Create(parcPublicKeySigner_Create(key_store, crypto_suite), + PARCPublicKeySignerAsSigner)); +} + +SymmetricSigner::SymmetricSigner(CryptoSuite suite, PARCKeyStore *key_store) { + parcAssertNotNull(key_store, "Expected non-null key_store"); + + auto crypto_suite = static_cast<PARCCryptoSuite>(suite); + + switch (suite) { + case CryptoSuite::HMAC_SHA256: + case CryptoSuite::HMAC_SHA512: + break; + default: + throw errors::RuntimeException( + "Invalid crypto suite for symmetric signer"); + } + + setSigner(parcSigner_Create(parcSymmetricKeySigner_Create( + (PARCSymmetricKeyStore *)key_store, + parcCryptoSuite_GetCryptoHash(crypto_suite)), + PARCSymmetricKeySignerAsSigner)); +} + +SymmetricSigner::SymmetricSigner(CryptoSuite suite, const string &passphrase) { + auto crypto_suite = static_cast<PARCCryptoSuite>(suite); + + switch (suite) { + case CryptoSuite::HMAC_SHA256: + case CryptoSuite::HMAC_SHA512: + break; + default: + throw errors::RuntimeException( + "Invalid crypto suite for symmetric signer"); + } + + PARCBufferComposer *composer = parcBufferComposer_Create(); + parcBufferComposer_PutString(composer, passphrase.c_str()); + PARCBuffer *key_buf = parcBufferComposer_ProduceBuffer(composer); + parcBufferComposer_Release(&composer); + + PARCSymmetricKeyStore *key_store = parcSymmetricKeyStore_Create(key_buf); + PARCSymmetricKeySigner *key_signer = parcSymmetricKeySigner_Create( + key_store, parcCryptoSuite_GetCryptoHash(crypto_suite)); + + setSigner(parcSigner_Create(key_signer, PARCSymmetricKeySignerAsSigner)); + + parcSymmetricKeySigner_Release(&key_signer); + parcSymmetricKeyStore_Release(&key_store); + parcBuffer_Release(&key_buf); +} + +} // namespace auth +} // namespace transport diff --git a/libtransport/src/auth/verifier.cc b/libtransport/src/auth/verifier.cc new file mode 100644 index 000000000..c6648a763 --- /dev/null +++ b/libtransport/src/auth/verifier.cc @@ -0,0 +1,335 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/auth/verifier.h> +#include <protocols/errors.h> + +extern "C" { +#ifndef _WIN32 +TRANSPORT_CLANG_DISABLE_WARNING("-Wextern-c-compat") +#endif +#include <hicn/hicn.h> +} + +#include <sys/stat.h> + +using namespace std; + +namespace transport { +namespace auth { + +const std::vector<VerificationPolicy> Verifier::DEFAULT_FAILED_POLICIES = { + VerificationPolicy::DROP, + VerificationPolicy::ABORT, +}; + +Verifier::Verifier() + : hasher_(nullptr), + verifier_(nullptr), + verification_failed_cb_(interface::VOID_HANDLER), + failed_policies_(DEFAULT_FAILED_POLICIES) { + parcSecurity_Init(); + PARCInMemoryVerifier *in_memory_verifier = parcInMemoryVerifier_Create(); + verifier_ = + parcVerifier_Create(in_memory_verifier, PARCInMemoryVerifierAsVerifier); + parcInMemoryVerifier_Release(&in_memory_verifier); +} + +Verifier::~Verifier() { + if (hasher_) parcCryptoHasher_Release(&hasher_); + if (verifier_) parcVerifier_Release(&verifier_); + parcSecurity_Fini(); +} + +bool Verifier::verifyPacket(PacketPtr packet) { + bool valid_packet = false; + core::Packet::Format format = packet->getFormat(); + + if (!packet->authenticationHeader()) { + throw errors::MalformedAHPacketException(); + } + + // Get crypto suite and hash type + auto suite = static_cast<PARCCryptoSuite>(packet->getValidationAlgorithm()); + PARCCryptoHashType hash_type = parcCryptoSuite_GetCryptoHash(suite); + + // Copy IP+TCP / ICMP header before zeroing them + hicn_header_t header_copy; + hicn_packet_copy_header(format, packet->packet_start_, &header_copy, false); + + // Fetch packet signature + uint8_t *packet_signature = packet->getSignature(); + size_t signature_len = Verifier::getSignatureSize(packet); + vector<uint8_t> signature_raw(packet_signature, + packet_signature + signature_len); + + // Create a signature buffer from the raw packet signature + PARCBuffer *bits = + parcBuffer_Wrap(signature_raw.data(), signature_len, 0, signature_len); + parcBuffer_Rewind(bits); + + // If the signature algo is ECDSA, the signature might be shorter than the + // signature field + PARCSigningAlgorithm algo = parcCryptoSuite_GetSigningAlgorithm(suite); + if (algo == PARCSigningAlgorithm_ECDSA) { + while (parcBuffer_HasRemaining(bits) && parcBuffer_GetUint8(bits) == 0) + ; + parcBuffer_SetPosition(bits, parcBuffer_Position(bits) - 1); + } + + if (!parcBuffer_HasRemaining(bits)) { + parcBuffer_Release(&bits); + return false; + } + + // Create a signature object from the signature buffer + PARCSignature *signature = parcSignature_Create( + parcCryptoSuite_GetSigningAlgorithm(suite), hash_type, bits); + + // Fetch the key to verify the signature + KeyId key_buffer = packet->getKeyId(); + PARCBuffer *buffer = parcBuffer_Wrap(key_buffer.first, key_buffer.second, 0, + key_buffer.second); + PARCKeyId *key_id = parcKeyId_Create(buffer); + + // Reset fields that are not used to compute signature + packet->resetForHash(); + + // Compute the packet hash + if (!hasher_) + setHasher(parcVerifier_GetCryptoHasher(verifier_, key_id, hash_type)); + CryptoHash local_hash = computeHash(packet); + + // Compare the packet signature to the locally computed one + valid_packet = parcVerifier_VerifyDigestSignature( + verifier_, key_id, local_hash.hash_, suite, signature); + + // Restore the fields that were reset + hicn_packet_copy_header(format, &header_copy, packet->packet_start_, false); + + // Release allocated objects + parcBuffer_Release(&buffer); + parcKeyId_Release(&key_id); + parcSignature_Release(&signature); + parcBuffer_Release(&bits); + + return valid_packet; +} + +vector<VerificationPolicy> Verifier::verifyPackets( + const vector<PacketPtr> &packets) { + vector<VerificationPolicy> policies(packets.size(), VerificationPolicy::DROP); + + for (unsigned int i = 0; i < packets.size(); ++i) { + if (verifyPacket(packets[i])) { + policies[i] = VerificationPolicy::ACCEPT; + } + + callVerificationFailedCallback(packets[i], policies[i]); + } + + return policies; +} + +vector<VerificationPolicy> Verifier::verifyPackets( + const vector<PacketPtr> &packets, + const unordered_map<Suffix, HashEntry> &suffix_map) { + vector<VerificationPolicy> policies(packets.size(), + VerificationPolicy::UNKNOWN); + + for (unsigned int i = 0; i < packets.size(); ++i) { + uint32_t suffix = packets[i]->getName().getSuffix(); + auto manifest_hash = suffix_map.find(suffix); + + if (manifest_hash != suffix_map.end()) { + CryptoHashType hash_type = manifest_hash->second.first; + CryptoHash packet_hash = packets[i]->computeDigest(hash_type); + + if (!CryptoHash::compareBinaryDigest( + packet_hash.getDigest<uint8_t>().data(), + manifest_hash->second.second.data(), hash_type)) { + policies[i] = VerificationPolicy::ABORT; + } else { + policies[i] = VerificationPolicy::ACCEPT; + } + } + + callVerificationFailedCallback(packets[i], policies[i]); + } + + return policies; +} + +void Verifier::addKey(PARCKey *key) { parcVerifier_AddKey(verifier_, key); } + +void Verifier::setHasher(PARCCryptoHasher *hasher) { + parcAssertNotNull(hasher, "Expected non-null hasher"); + if (hasher_) parcCryptoHasher_Release(&hasher_); + hasher_ = parcCryptoHasher_Acquire(hasher); +} + +void Verifier::setVerificationFailedCallback( + VerificationFailedCallback verfication_failed_cb, + const vector<VerificationPolicy> &failed_policies) { + verification_failed_cb_ = verfication_failed_cb; + failed_policies_ = failed_policies; +} + +void Verifier::getVerificationFailedCallback( + VerificationFailedCallback **verfication_failed_cb) { + *verfication_failed_cb = &verification_failed_cb_; +} + +size_t Verifier::getSignatureSize(const PacketPtr packet) { + return packet->getSignatureSize(); +} + +CryptoHash Verifier::computeHash(PacketPtr packet) { + parcAssertNotNull(hasher_, "Expected non-null hasher"); + + CryptoHasher crypto_hasher(hasher_); + const utils::MemBuf &header_chain = *packet; + const utils::MemBuf *current = &header_chain; + + crypto_hasher.init(); + + do { + crypto_hasher.updateBytes(current->data(), current->length()); + current = current->next(); + } while (current != &header_chain); + + return crypto_hasher.finalize(); +} + +void Verifier::callVerificationFailedCallback(PacketPtr packet, + VerificationPolicy &policy) { + if (verification_failed_cb_ == interface::VOID_HANDLER) { + return; + } + + if (find(failed_policies_.begin(), failed_policies_.end(), policy) != + failed_policies_.end()) { + policy = verification_failed_cb_( + static_cast<const core::ContentObject &>(*packet), + make_error_code( + protocol::protocol_error::signature_verification_failed)); + } +} + +bool VoidVerifier::verifyPacket(PacketPtr packet) { return true; } + +vector<VerificationPolicy> VoidVerifier::verifyPackets( + const vector<PacketPtr> &packets) { + return vector<VerificationPolicy>(packets.size(), VerificationPolicy::ACCEPT); +} + +vector<VerificationPolicy> VoidVerifier::verifyPackets( + const vector<PacketPtr> &packets, + const unordered_map<Suffix, HashEntry> &suffix_map) { + return vector<VerificationPolicy>(packets.size(), VerificationPolicy::ACCEPT); +} + +AsymmetricVerifier::AsymmetricVerifier(PARCKey *pub_key) { addKey(pub_key); } + +AsymmetricVerifier::AsymmetricVerifier(const string &cert_path) { + setCertificate(cert_path); +} + +void AsymmetricVerifier::setCertificate(const string &cert_path) { + PARCCertificateFactory *factory = parcCertificateFactory_Create( + PARCCertificateType_X509, PARCContainerEncoding_PEM); + + struct stat buffer; + if (stat(cert_path.c_str(), &buffer) != 0) { + throw errors::RuntimeException("Certificate does not exist"); + } + + PARCCertificate *certificate = + parcCertificateFactory_CreateCertificateFromFile(factory, + cert_path.c_str(), NULL); + PARCKey *key = parcCertificate_GetPublicKey(certificate); + + addKey(key); + + parcKey_Release(&key); + parcCertificateFactory_Release(&factory); +} + +SymmetricVerifier::SymmetricVerifier(const string &passphrase) + : passphrase_(nullptr), signer_(nullptr) { + setPassphrase(passphrase); +} + +SymmetricVerifier::~SymmetricVerifier() { + if (passphrase_) parcBuffer_Release(&passphrase_); + if (signer_) parcSigner_Release(&signer_); +} + +void SymmetricVerifier::setPassphrase(const string &passphrase) { + if (passphrase_) parcBuffer_Release(&passphrase_); + + PARCBufferComposer *composer = parcBufferComposer_Create(); + parcBufferComposer_PutString(composer, passphrase.c_str()); + passphrase_ = parcBufferComposer_ProduceBuffer(composer); + parcBufferComposer_Release(&composer); +} + +void SymmetricVerifier::setSigner(const PARCCryptoSuite &suite) { + parcAssertNotNull(passphrase_, "Expected non-null passphrase"); + + if (signer_) parcSigner_Release(&signer_); + + PARCSymmetricKeyStore *key_store = parcSymmetricKeyStore_Create(passphrase_); + PARCSymmetricKeySigner *key_signer = parcSymmetricKeySigner_Create( + key_store, parcCryptoSuite_GetCryptoHash(suite)); + signer_ = parcSigner_Create(key_signer, PARCSymmetricKeySignerAsSigner); + + PARCKeyId *key_id = parcSigner_CreateKeyId(signer_); + PARCKey *key = parcKey_CreateFromSymmetricKey( + key_id, parcSigner_GetSigningAlgorithm(signer_), passphrase_); + + addKey(key); + setHasher(parcSigner_GetCryptoHasher(signer_)); + + parcSymmetricKeyStore_Release(&key_store); + parcSymmetricKeySigner_Release(&key_signer); + parcKeyId_Release(&key_id); + parcKey_Release(&key); +} + +vector<VerificationPolicy> SymmetricVerifier::verifyPackets( + const vector<PacketPtr> &packets) { + vector<VerificationPolicy> policies(packets.size(), VerificationPolicy::DROP); + + for (unsigned int i = 0; i < packets.size(); ++i) { + auto suite = + static_cast<PARCCryptoSuite>(packets[i]->getValidationAlgorithm()); + + if (!signer_ || suite != parcSigner_GetCryptoSuite(signer_)) { + setSigner(suite); + } + + if (verifyPacket(packets[i])) { + policies[i] = VerificationPolicy::ACCEPT; + } + + callVerificationFailedCallback(packets[i], policies[i]); + } + + return policies; +} + +} // namespace auth +} // namespace transport diff --git a/libtransport/src/config.h.in b/libtransport/src/config.h.in index 4e9a0f262..ef47affda 100644 --- a/libtransport/src/config.h.in +++ b/libtransport/src/config.h.in @@ -25,10 +25,6 @@ #cmakedefine ASIO_STANDALONE #endif -#ifndef SECURE_HICNTRANSPORT -#cmakedefine SECURE_HICNTRANSPORT -#endif - #define RAAQM_CONFIG_PATH "@raaqm_config_path@" #cmakedefine __vpp__ diff --git a/libtransport/src/core/CMakeLists.txt b/libtransport/src/core/CMakeLists.txt index 5c8ab9270..4e3ac10ec 100644 --- a/libtransport/src/core/CMakeLists.txt +++ b/libtransport/src/core/CMakeLists.txt @@ -21,55 +21,27 @@ list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/manifest_format.h ${CMAKE_CURRENT_SOURCE_DIR}/pending_interest.h ${CMAKE_CURRENT_SOURCE_DIR}/portal.h - ${CMAKE_CURRENT_SOURCE_DIR}/connector.h - ${CMAKE_CURRENT_SOURCE_DIR}/tcp_socket_connector.h - ${CMAKE_CURRENT_SOURCE_DIR}/udp_socket_connector.h - ${CMAKE_CURRENT_SOURCE_DIR}/forwarder_interface.h - ${CMAKE_CURRENT_SOURCE_DIR}/hicn_forwarder_interface.h - ${CMAKE_CURRENT_SOURCE_DIR}/vpp_forwarder_interface.h - ${CMAKE_CURRENT_SOURCE_DIR}/memif_connector.h + ${CMAKE_CURRENT_SOURCE_DIR}/errors.h + ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.h + ${CMAKE_CURRENT_SOURCE_DIR}/local_connector.h + ${CMAKE_CURRENT_SOURCE_DIR}/rs.h ) list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/content_object.cc ${CMAKE_CURRENT_SOURCE_DIR}/interest.cc + ${CMAKE_CURRENT_SOURCE_DIR}/errors.cc ${CMAKE_CURRENT_SOURCE_DIR}/packet.cc ${CMAKE_CURRENT_SOURCE_DIR}/name.cc ${CMAKE_CURRENT_SOURCE_DIR}/prefix.cc - ${CMAKE_CURRENT_SOURCE_DIR}/tcp_socket_connector.cc - ${CMAKE_CURRENT_SOURCE_DIR}/udp_socket_connector.cc - ${CMAKE_CURRENT_SOURCE_DIR}/hicn_forwarder_interface.cc ${CMAKE_CURRENT_SOURCE_DIR}/manifest_format_fixed.cc - ${CMAKE_CURRENT_SOURCE_DIR}/connector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/portal.cc + ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.cc + ${CMAKE_CURRENT_SOURCE_DIR}/io_module.cc + ${CMAKE_CURRENT_SOURCE_DIR}/local_connector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/fec.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rs.cc ) -if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux") - if (BUILD_WITH_VPP OR BUILD_HICNPLUGIN) - list(APPEND HEADER_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/vpp_forwarder_interface.h - ${CMAKE_CURRENT_SOURCE_DIR}/memif_connector.h - ${CMAKE_CURRENT_SOURCE_DIR}/hicn_vapi.h - ${CMAKE_CURRENT_SOURCE_DIR}/memif_vapi.h - ) - - list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/vpp_forwarder_interface.cc - ${CMAKE_CURRENT_SOURCE_DIR}/memif_connector.cc - ${CMAKE_CURRENT_SOURCE_DIR}/hicn_vapi.c - ${CMAKE_CURRENT_SOURCE_DIR}/memif_vapi.c - ) - endif() - - list(APPEND HEADER_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/raw_socket_connector.h - ${CMAKE_CURRENT_SOURCE_DIR}/raw_socket_interface.h - ) - - list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/raw_socket_connector.cc - ${CMAKE_CURRENT_SOURCE_DIR}/raw_socket_interface.cc - ) -endif() - set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE)
\ No newline at end of file diff --git a/libtransport/src/core/content_object.cc b/libtransport/src/core/content_object.cc index f5cccf404..0c68ef559 100644 --- a/libtransport/src/core/content_object.cc +++ b/libtransport/src/core/content_object.cc @@ -32,8 +32,9 @@ namespace transport { namespace core { -ContentObject::ContentObject(const Name &name, Packet::Format format) - : Packet(format) { +ContentObject::ContentObject(const Name &name, Packet::Format format, + std::size_t additional_header_size) + : Packet(format, additional_header_size) { if (TRANSPORT_EXPECT_FALSE( hicn_data_set_name(format, packet_start_, &name.name_) < 0)) { throw errors::RuntimeException("Error filling the packet name."); @@ -47,41 +48,32 @@ ContentObject::ContentObject(const Name &name, Packet::Format format) } #ifdef __ANDROID__ -ContentObject::ContentObject(hicn_format_t format) - : ContentObject(Name("0::0|0"), format) {} +ContentObject::ContentObject(hicn_format_t format, + std::size_t additional_header_size) + : ContentObject(Name("0::0|0"), format, additional_header_size) {} #else -ContentObject::ContentObject(hicn_format_t format) - : ContentObject(Packet::base_name, format) {} +ContentObject::ContentObject(hicn_format_t format, + std::size_t additional_header_size) + : ContentObject(Packet::base_name, format, additional_header_size) {} #endif ContentObject::ContentObject(const Name &name, hicn_format_t format, + std::size_t additional_header_size, const uint8_t *payload, std::size_t size) - : ContentObject(name, format) { + : ContentObject(name, format, additional_header_size) { appendPayload(payload, size); } -ContentObject::ContentObject(const uint8_t *buffer, std::size_t size) - : Packet(buffer, size) { - if (hicn_data_get_name(format_, packet_start_, name_.getStructReference()) < - 0) { - throw errors::RuntimeException("Error getting name from content object."); - } +ContentObject::ContentObject(ContentObject &&other) : Packet(std::move(other)) { + name_ = std::move(other.name_); } -ContentObject::ContentObject(MemBufPtr &&buffer) : Packet(std::move(buffer)) { - if (hicn_data_get_name(format_, packet_start_, name_.getStructReference()) < - 0) { - throw errors::RuntimeException("Error getting name from content object."); - } +ContentObject::ContentObject(const ContentObject &other) : Packet(other) { + name_ = other.name_; } -ContentObject::ContentObject(ContentObject &&other) : Packet(std::move(other)) { - name_ = std::move(other.name_); - - if (hicn_data_get_name(format_, packet_start_, name_.getStructReference()) < - 0) { - throw errors::MalformedPacketException(); - } +ContentObject &ContentObject::operator=(const ContentObject &other) { + return (ContentObject &)Packet::operator=(other); } ContentObject::~ContentObject() {} @@ -132,10 +124,11 @@ uint32_t ContentObject::getPathLabel() const { "Error retrieving the path label from content object"); } - return path_label; + return ntohl(path_label); } ContentObject &ContentObject::setPathLabel(uint32_t path_label) { + path_label = htonl(path_label); if (hicn_data_set_path_label((hicn_header_t *)packet_start_, path_label) < 0) { throw errors::RuntimeException( diff --git a/libtransport/src/core/errors.cc b/libtransport/src/core/errors.cc new file mode 100644 index 000000000..82647a60b --- /dev/null +++ b/libtransport/src/core/errors.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/errors.h> + +namespace transport { +namespace core { + +const std::error_category& core_category() { + static core_category_impl instance; + + return instance; +} + +const char* core_category_impl::name() const throw() { + return "transport::protocol::error"; +} + +std::string core_category_impl::message(int ev) const { + switch (static_cast<core_error>(ev)) { + case core_error::success: { + return "Success"; + } + case core_error::configuration_parse_failed: { + return "Error parsing configuration."; + } + case core_error::configuration_not_applied: { + return "Configuration was not applied due to wrong parameters."; + } + default: { + return "Unknown core error"; + } + } +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/errors.h b/libtransport/src/core/errors.h new file mode 100644 index 000000000..a46f1dbcd --- /dev/null +++ b/libtransport/src/core/errors.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <string> +#include <system_error> + +namespace transport { +namespace core { + +/** + * @brief Get the default server error category. + * @return The default server error category instance. + * + * @warning The first call to this function is thread-safe only starting with + * C++11. + */ +const std::error_category& core_category(); + +/** + * The list of errors. + */ +enum class core_error { + success = 0, + configuration_parse_failed, + configuration_not_applied +}; + +/** + * @brief Create an error_code instance for the given error. + * @param error The error. + * @return The error_code instance. + */ +inline std::error_code make_error_code(core_error error) { + return std::error_code(static_cast<int>(error), core_category()); +} + +/** + * @brief Create an error_condition instance for the given error. + * @param error The error. + * @return The error_condition instance. + */ +inline std::error_condition make_error_condition(core_error error) { + return std::error_condition(static_cast<int>(error), core_category()); +} + +/** + * @brief A server error category. + */ +class core_category_impl : public std::error_category { + public: + /** + * @brief Get the name of the category. + * @return The name of the category. + */ + virtual const char* name() const throw(); + + /** + * @brief Get the error message for a given error. + * @param ev The error numeric value. + * @return The message associated to the error. + */ + virtual std::string message(int ev) const; +}; +} // namespace core +} // namespace transport + +namespace std { +// namespace system { +template <> +struct is_error_code_enum<::transport::core::core_error> + : public std::true_type {}; +// } // namespace system +} // namespace std
\ No newline at end of file diff --git a/libtransport/src/core/facade.h b/libtransport/src/core/facade.h index 04f643f63..199081271 100644 --- a/libtransport/src/core/facade.h +++ b/libtransport/src/core/facade.h @@ -15,36 +15,14 @@ #pragma once -#include <core/forwarder_interface.h> -#include <core/hicn_forwarder_interface.h> #include <core/manifest_format_fixed.h> #include <core/manifest_inline.h> #include <core/portal.h> -#ifdef __linux__ -#ifndef __ANDROID__ -#include <core/raw_socket_interface.h> -#ifdef __vpp__ -#include <core/vpp_forwarder_interface.h> -#endif -#endif -#endif - namespace transport { namespace core { -using HicnForwarderPortal = Portal<HicnForwarderInterface>; - -#ifdef __linux__ -#ifndef __ANDROID__ -using RawSocketPortal = Portal<RawSocketInterface>; -#endif -#ifdef __vpp__ -using VPPForwarderPortal = Portal<VPPForwarderInterface>; -#endif -#endif - using ContentObjectManifest = core::ManifestInline<ContentObject, Fixed>; using InterestManifest = core::ManifestInline<Interest, Fixed>; diff --git a/libtransport/src/core/fec.cc b/libtransport/src/core/fec.cc new file mode 100644 index 000000000..134198b9e --- /dev/null +++ b/libtransport/src/core/fec.cc @@ -0,0 +1,878 @@ +/* + * fec.c -- forward error correction based on Vandermonde matrices + * 980624 + * (C) 1997-98 Luigi Rizzo (luigi@iet.unipi.it) + * + * Portions derived from code by Phil Karn (karn@ka9q.ampr.org), + * Robert Morelos-Zaragoza (robert@spectra.eng.hawaii.edu) and Hari + * Thirumoorthy (harit@spectra.eng.hawaii.edu), Aug 1995 + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, + * OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR + * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT + * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY + * OF SUCH DAMAGE. + */ + +/* + * The following parameter defines how many bits are used for + * field elements. The code supports any value from 2 to 16 + * but fastest operation is achieved with 8 bit elements + * This is the only parameter you may want to change. + */ +#ifndef GF_BITS +#define GF_BITS 8 /* code over GF(2**GF_BITS) - change to suit */ +#endif + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <hicn/transport/portability/platform.h> +#include "fec.h" + +/** + * XXX This disable a warning raising only in some platforms. + * TODO Check if this warning is a mistake or it is a real bug: + * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=83404 + * https://gcc.gnu.org/bugzilla//show_bug.cgi?id=88059 + */ +#ifndef __clang__ +#pragma GCC diagnostic ignored "-Wstringop-overflow" +#endif + +/* + * compatibility stuff + */ +#ifdef MSDOS /* but also for others, e.g. sun... */ +#define NEED_BCOPY +#define bcmp(a,b,n) memcmp(a,b,n) +#endif + +#ifdef ANDROID +#define bcmp(a,b,n) memcmp(a,b,n) +#endif + +#ifdef NEED_BCOPY +#define bcopy(s, d, siz) memcpy((d), (s), (siz)) +#define bzero(d, siz) memset((d), '\0', (siz)) +#endif + +/* + * stuff used for testing purposes only + */ + +#ifdef TEST +#define DEB(x) +#define DDB(x) x +#define DEBUG 0 /* minimal debugging */ +#ifdef MSDOS +#include <time.h> +struct timeval { + unsigned long ticks; +}; +#define gettimeofday(x, dummy) { (x)->ticks = clock() ; } +#define DIFF_T(a,b) (1+ 1000000*(a.ticks - b.ticks) / CLOCKS_PER_SEC ) +typedef unsigned long u_long ; +typedef unsigned short u_short ; +#else /* typically, unix systems */ +#include <sys/time.h> +#define DIFF_T(a,b) \ + (1+ 1000000*(a.tv_sec - b.tv_sec) + (a.tv_usec - b.tv_usec) ) +#endif + +#define TICK(t) \ + {struct timeval x ; \ + gettimeofday(&x, NULL) ; \ + t = x.tv_usec + 1000000* (x.tv_sec & 0xff ) ; \ + } +#define TOCK(t) \ + { u_long t1 ; TICK(t1) ; \ + if (t1 < t) t = 256000000 + t1 - t ; \ + else t = t1 - t ; \ + if (t == 0) t = 1 ;} + +u_long ticks[10]; /* vars for timekeeping */ +#else +#define DEB(x) +#define DDB(x) +#define TICK(x) +#define TOCK(x) +#endif /* TEST */ + +/* + * You should not need to change anything beyond this point. + * The first part of the file implements linear algebra in GF. + * + * gf is the type used to store an element of the Galois Field. + * Must constain at least GF_BITS bits. + * + * Note: unsigned char will work up to GF(256) but int seems to run + * faster on the Pentium. We use int whenever have to deal with an + * index, since they are generally faster. + */ +#if (GF_BITS < 2 && GF_BITS >16) +#error "GF_BITS must be 2 .. 16" +#endif + + +#define GF_SIZE ((1 << GF_BITS) - 1) /* powers of \alpha */ + +/* + * Primitive polynomials - see Lin & Costello, Appendix A, + * and Lee & Messerschmitt, p. 453. + */ +static const char *allPp[] = { /* GF_BITS polynomial */ + NULL, /* 0 no code */ + NULL, /* 1 no code */ + "111", /* 2 1+x+x^2 */ + "1101", /* 3 1+x+x^3 */ + "11001", /* 4 1+x+x^4 */ + "101001", /* 5 1+x^2+x^5 */ + "1100001", /* 6 1+x+x^6 */ + "10010001", /* 7 1 + x^3 + x^7 */ + "101110001", /* 8 1+x^2+x^3+x^4+x^8 */ + "1000100001", /* 9 1+x^4+x^9 */ + "10010000001", /* 10 1+x^3+x^10 */ + "101000000001", /* 11 1+x^2+x^11 */ + "1100101000001", /* 12 1+x+x^4+x^6+x^12 */ + "11011000000001", /* 13 1+x+x^3+x^4+x^13 */ + "110000100010001", /* 14 1+x+x^6+x^10+x^14 */ + "1100000000000001", /* 15 1+x+x^15 */ + "11010000000010001" /* 16 1+x+x^3+x^12+x^16 */ +}; + + +/* + * To speed up computations, we have tables for logarithm, exponent + * and inverse of a number. If GF_BITS <= 8, we use a table for + * multiplication as well (it takes 64K, no big deal even on a PDA, + * especially because it can be pre-initialized an put into a ROM!), + * otherwhise we use a table of logarithms. + * In any case the macro gf_mul(x,y) takes care of multiplications. + */ + +static gf gf_exp[2*GF_SIZE]; /* index->poly form conversion table */ +static int gf_log[GF_SIZE + 1]; /* Poly->index form conversion table */ +static gf inverse[GF_SIZE+1]; /* inverse of field elem. */ + /* inv[\alpha**i]=\alpha**(GF_SIZE-i-1) */ + +/* + * modnn(x) computes x % GF_SIZE, where GF_SIZE is 2**GF_BITS - 1, + * without a slow divide. + */ +static inline gf +modnn(int x) +{ + while (x >= GF_SIZE) { + x -= GF_SIZE; + x = (x >> GF_BITS) + (x & GF_SIZE); + } + return x; +} + +#define SWAP(a,b,t) {t tmp; tmp=a; a=b; b=tmp;} + +/* + * gf_mul(x,y) multiplies two numbers. If GF_BITS<=8, it is much + * faster to use a multiplication table. + * + * USE_GF_MULC, GF_MULC0(c) and GF_ADDMULC(x) can be used when multiplying + * many numbers by the same constant. In this case the first + * call sets the constant, and others perform the multiplications. + * A value related to the multiplication is held in a local variable + * declared with USE_GF_MULC . See usage in addmul1(). + */ +#if (GF_BITS <= 8) +static gf gf_mul_table[GF_SIZE + 1][GF_SIZE + 1]; + +#define gf_mul(x,y) gf_mul_table[x][y] + +#define USE_GF_MULC gf * __gf_mulc_ +#define GF_MULC0(c) __gf_mulc_ = gf_mul_table[c] +#define GF_ADDMULC(dst, x) dst ^= __gf_mulc_[x] + +static void +init_mul_table() +{ + int i, j; + for (i=0; i< GF_SIZE+1; i++) + for (j=0; j< GF_SIZE+1; j++) + gf_mul_table[i][j] = gf_exp[modnn(gf_log[i] + gf_log[j]) ] ; + + for (j=0; j< GF_SIZE+1; j++) + gf_mul_table[0][j] = gf_mul_table[j][0] = 0; +} +#else /* GF_BITS > 8 */ +static inline gf +gf_mul(x,y) +{ + if ( (x) == 0 || (y)==0 ) return 0; + + return gf_exp[gf_log[x] + gf_log[y] ] ; +} +#define init_mul_table() + +#define USE_GF_MULC register gf * __gf_mulc_ +#define GF_MULC0(c) __gf_mulc_ = &gf_exp[ gf_log[c] ] +#define GF_ADDMULC(dst, x) { if (x) dst ^= __gf_mulc_[ gf_log[x] ] ; } +#endif + +/* + * Generate GF(2**m) from the irreducible polynomial p(X) in p[0]..p[m] + * Lookup tables: + * index->polynomial form gf_exp[] contains j= \alpha^i; + * polynomial form -> index form gf_log[ j = \alpha^i ] = i + * \alpha=x is the primitive element of GF(2^m) + * + * For efficiency, gf_exp[] has size 2*GF_SIZE, so that a simple + * multiplication of two numbers can be resolved without calling modnn + */ + +/* + * i use malloc so many times, it is easier to put checks all in + * one place. + */ +static void * +my_malloc(int sz, const char *err_string) +{ + void *p = malloc( sz ); + if (p == NULL) { + fprintf(stderr, "-- malloc failure allocating %s\n", err_string); + exit(1) ; + } + return p ; +} + +#define NEW_GF_MATRIX(rows, cols) \ + (gf *)my_malloc(rows * cols * sizeof(gf), " ## __LINE__ ## " ) + +/* + * initialize the data structures used for computations in GF. + */ +static void +generate_gf(void) +{ + int i; + gf mask; + const char *Pp = allPp[GF_BITS] ; + + mask = 1; /* x ** 0 = 1 */ + gf_exp[GF_BITS] = 0; /* will be updated at the end of the 1st loop */ + /* + * first, generate the (polynomial representation of) powers of \alpha, + * which are stored in gf_exp[i] = \alpha ** i . + * At the same time build gf_log[gf_exp[i]] = i . + * The first GF_BITS powers are simply bits shifted to the left. + */ + for (i = 0; i < GF_BITS; i++, mask <<= 1 ) { + gf_exp[i] = mask; + gf_log[gf_exp[i]] = i; + /* + * If Pp[i] == 1 then \alpha ** i occurs in poly-repr + * gf_exp[GF_BITS] = \alpha ** GF_BITS + */ + if ( Pp[i] == '1' ) + gf_exp[GF_BITS] ^= mask; + } + /* + * now gf_exp[GF_BITS] = \alpha ** GF_BITS is complete, so can als + * compute its inverse. + */ + gf_log[gf_exp[GF_BITS]] = GF_BITS; + /* + * Poly-repr of \alpha ** (i+1) is given by poly-repr of + * \alpha ** i shifted left one-bit and accounting for any + * \alpha ** GF_BITS term that may occur when poly-repr of + * \alpha ** i is shifted. + */ + mask = 1 << (GF_BITS - 1 ) ; + for (i = GF_BITS + 1; i < GF_SIZE; i++) { + if (gf_exp[i - 1] >= mask) + gf_exp[i] = gf_exp[GF_BITS] ^ ((gf_exp[i - 1] ^ mask) << 1); + else + gf_exp[i] = gf_exp[i - 1] << 1; + gf_log[gf_exp[i]] = i; + } + /* + * log(0) is not defined, so use a special value + */ + gf_log[0] = GF_SIZE ; + /* set the extended gf_exp values for fast multiply */ + for (i = 0 ; i < GF_SIZE ; i++) + gf_exp[i + GF_SIZE] = gf_exp[i] ; + + /* + * again special cases. 0 has no inverse. This used to + * be initialized to GF_SIZE, but it should make no difference + * since noone is supposed to read from here. + */ + inverse[0] = 0 ; + inverse[1] = 1; + for (i=2; i<=GF_SIZE; i++) + inverse[i] = gf_exp[GF_SIZE-gf_log[i]]; +} + +/* + * Various linear algebra operations that i use often. + */ + +/* + * addmul() computes dst[] = dst[] + c * src[] + * This is used often, so better optimize it! Currently the loop is + * unrolled 16 times, a good value for 486 and pentium-class machines. + * The case c=0 is also optimized, whereas c=1 is not. These + * calls are unfrequent in my typical apps so I did not bother. + * + * Note that gcc on + */ +#define addmul(dst, src, c, sz) \ + if (c != 0) addmul1(dst, src, c, sz) + +#define UNROLL 16 /* 1, 4, 8, 16 */ +static void +addmul1(gf *dst1, gf *src1, gf c, int sz) +{ + USE_GF_MULC ; + gf *dst = dst1, *src = src1 ; + gf *lim = &dst[sz - UNROLL + 1] ; + + GF_MULC0(c) ; + +#if (UNROLL > 1) /* unrolling by 8/16 is quite effective on the pentium */ + for (; dst < lim ; dst += UNROLL, src += UNROLL ) { + GF_ADDMULC( dst[0] , src[0] ); + GF_ADDMULC( dst[1] , src[1] ); + GF_ADDMULC( dst[2] , src[2] ); + GF_ADDMULC( dst[3] , src[3] ); +#if (UNROLL > 4) + GF_ADDMULC( dst[4] , src[4] ); + GF_ADDMULC( dst[5] , src[5] ); + GF_ADDMULC( dst[6] , src[6] ); + GF_ADDMULC( dst[7] , src[7] ); +#endif +#if (UNROLL > 8) + GF_ADDMULC( dst[8] , src[8] ); + GF_ADDMULC( dst[9] , src[9] ); + GF_ADDMULC( dst[10] , src[10] ); + GF_ADDMULC( dst[11] , src[11] ); + GF_ADDMULC( dst[12] , src[12] ); + GF_ADDMULC( dst[13] , src[13] ); + GF_ADDMULC( dst[14] , src[14] ); + GF_ADDMULC( dst[15] , src[15] ); +#endif + } +#endif + lim += UNROLL - 1 ; + for (; dst < lim; dst++, src++ ) /* final components */ + GF_ADDMULC( *dst , *src ); +} + +/* + * computes C = AB where A is n*k, B is k*m, C is n*m + */ +static void +matmul(gf *a, gf *b, gf *c, int n, int k, int m) +{ + int row, col, i ; + + for (row = 0; row < n ; row++) { + for (col = 0; col < m ; col++) { + gf *pa = &a[ row * k ]; + gf *pb = &b[ col ]; + gf acc = 0 ; + for (i = 0; i < k ; i++, pa++, pb += m ) + acc ^= gf_mul( *pa, *pb ) ; + c[ row * m + col ] = acc ; + } + } +} + +#ifdef DEBUGG +/* + * returns 1 if the square matrix is identiy + * (only for test) + */ +static int +is_identity(gf *m, int k) +{ + int row, col ; + for (row=0; row<k; row++) + for (col=0; col<k; col++) + if ( (row==col && *m != 1) || + (row!=col && *m != 0) ) + return 0 ; + else + m++ ; + return 1 ; +} +#endif /* debug */ + +/* + * invert_mat() takes a matrix and produces its inverse + * k is the size of the matrix. + * (Gauss-Jordan, adapted from Numerical Recipes in C) + * Return non-zero if singular. + */ +DEB( int pivloops=0; int pivswaps=0 ; /* diagnostic */) +static int +invert_mat(gf *src, int k) +{ + gf c, *p ; + int irow, icol, row, col, i, ix ; + + int error = 1 ; + int *indxc = (int*)my_malloc(k*sizeof(int), "indxc"); + int *indxr = (int*)my_malloc(k*sizeof(int), "indxr"); + int *ipiv = (int*)my_malloc(k*sizeof(int), "ipiv"); + gf *id_row = NEW_GF_MATRIX(1, k); + gf *temp_row = NEW_GF_MATRIX(1, k); + + bzero(id_row, k*sizeof(gf)); + DEB( pivloops=0; pivswaps=0 ; /* diagnostic */ ) + /* + * ipiv marks elements already used as pivots. + */ + for (i = 0; i < k ; i++) + ipiv[i] = 0 ; + + for (col = 0; col < k ; col++) { + gf *pivot_row ; + /* + * Zeroing column 'col', look for a non-zero element. + * First try on the diagonal, if it fails, look elsewhere. + */ + irow = icol = -1 ; + if (ipiv[col] != 1 && src[col*k + col] != 0) { + irow = col ; + icol = col ; + goto found_piv ; + } + for (row = 0 ; row < k ; row++) { + if (ipiv[row] != 1) { + for (ix = 0 ; ix < k ; ix++) { + DEB( pivloops++ ; ) + if (ipiv[ix] == 0) { + if (src[row*k + ix] != 0) { + irow = row ; + icol = ix ; + goto found_piv ; + } + } else if (ipiv[ix] > 1) { + fprintf(stderr, "singular matrix\n"); + goto fail ; + } + } + } + } + if (icol == -1) { + fprintf(stderr, "XXX pivot not found!\n"); + goto fail ; + } +found_piv: + ++(ipiv[icol]) ; + /* + * swap rows irow and icol, so afterwards the diagonal + * element will be correct. Rarely done, not worth + * optimizing. + */ + if (irow != icol) { + for (ix = 0 ; ix < k ; ix++ ) { + SWAP( src[irow*k + ix], src[icol*k + ix], gf) ; + } + } + indxr[col] = irow ; + indxc[col] = icol ; + pivot_row = &src[icol*k] ; + c = pivot_row[icol] ; + if (c == 0) { + fprintf(stderr, "singular matrix 2\n"); + goto fail ; + } + if (c != 1 ) { /* otherwhise this is a NOP */ + /* + * this is done often , but optimizing is not so + * fruitful, at least in the obvious ways (unrolling) + */ + DEB( pivswaps++ ; ) + c = inverse[ c ] ; + pivot_row[icol] = 1 ; + for (ix = 0 ; ix < k ; ix++ ) + pivot_row[ix] = gf_mul(c, pivot_row[ix] ); + } + /* + * from all rows, remove multiples of the selected row + * to zero the relevant entry (in fact, the entry is not zero + * because we know it must be zero). + * (Here, if we know that the pivot_row is the identity, + * we can optimize the addmul). + */ + id_row[icol] = 1; + if (bcmp(pivot_row, id_row, k*sizeof(gf)) != 0) { + for (p = src, ix = 0 ; ix < k ; ix++, p += k ) { + if (ix != icol) { + c = p[icol] ; + p[icol] = 0 ; + addmul(p, pivot_row, c, k ); + } + } + } + id_row[icol] = 0; + } /* done all columns */ + for (col = k-1 ; col >= 0 ; col-- ) { + if (indxr[col] <0 || indxr[col] >= k) + fprintf(stderr, "AARGH, indxr[col] %d\n", indxr[col]); + else if (indxc[col] <0 || indxc[col] >= k) + fprintf(stderr, "AARGH, indxc[col] %d\n", indxc[col]); + else + if (indxr[col] != indxc[col] ) { + for (row = 0 ; row < k ; row++ ) { + SWAP( src[row*k + indxr[col]], src[row*k + indxc[col]], gf) ; + } + } + } + error = 0 ; +fail: + free(indxc); + free(indxr); + free(ipiv); + free(id_row); + free(temp_row); + return error ; +} + +/* + * fast code for inverting a vandermonde matrix. + * XXX NOTE: It assumes that the matrix + * is not singular and _IS_ a vandermonde matrix. Only uses + * the second column of the matrix, containing the p_i's. + * + * Algorithm borrowed from "Numerical recipes in C" -- sec.2.8, but + * largely revised for my purposes. + * p = coefficients of the matrix (p_i) + * q = values of the polynomial (known) + */ + +int +invert_vdm(gf *src, int k) +{ + int i, j, row, col ; + gf *b, *c, *p; + gf t, xx ; + + if (k == 1) /* degenerate case, matrix must be p^0 = 1 */ + return 0 ; + /* + * c holds the coefficient of P(x) = Prod (x - p_i), i=0..k-1 + * b holds the coefficient for the matrix inversion + */ + c = NEW_GF_MATRIX(1, k); + b = NEW_GF_MATRIX(1, k); + + p = NEW_GF_MATRIX(1, k); + + for ( j=1, i = 0 ; i < k ; i++, j+=k ) { + c[i] = 0 ; + p[i] = src[j] ; /* p[i] */ + } + /* + * construct coeffs. recursively. We know c[k] = 1 (implicit) + * and start P_0 = x - p_0, then at each stage multiply by + * x - p_i generating P_i = x P_{i-1} - p_i P_{i-1} + * After k steps we are done. + */ + c[k-1] = p[0] ; /* really -p(0), but x = -x in GF(2^m) */ + for (i = 1 ; i < k ; i++ ) { + gf p_i = p[i] ; /* see above comment */ + for (j = k-1 - ( i - 1 ) ; j < k-1 ; j++ ) + c[j] ^= gf_mul( p_i, c[j+1] ) ; + c[k-1] ^= p_i ; + } + + for (row = 0 ; row < k ; row++ ) { + /* + * synthetic division etc. + */ + xx = p[row] ; + t = 1 ; + b[k-1] = 1 ; /* this is in fact c[k] */ + for (i = k-2 ; i >= 0 ; i-- ) { + b[i] = c[i+1] ^ gf_mul(xx, b[i+1]) ; + t = gf_mul(xx, t) ^ b[i] ; + } + for (col = 0 ; col < k ; col++ ) + src[col*k + row] = gf_mul(inverse[t], b[col] ); + } + free(c) ; + free(b) ; + free(p) ; + return 0 ; +} + +static int fec_initialized = 0 ; +static void +init_fec() +{ + TICK(ticks[0]); + generate_gf(); + TOCK(ticks[0]); + DDB(fprintf(stderr, "generate_gf took %ldus\n", ticks[0]);) + TICK(ticks[0]); + init_mul_table(); + TOCK(ticks[0]); + DDB(fprintf(stderr, "init_mul_table took %ldus\n", ticks[0]);) + fec_initialized = 1 ; +} + +/* + * This section contains the proper FEC encoding/decoding routines. + * The encoding matrix is computed starting with a Vandermonde matrix, + * and then transforming it into a systematic matrix. + */ + +#define FEC_MAGIC 0xFECC0DEC + +void +fec_free(struct fec_parms *p) +{ + if (p==NULL || + p->magic != ( ( (FEC_MAGIC ^ p->k) ^ p->n) ^ (unsigned long)(p->enc_matrix)) ) { + fprintf(stderr, "bad parameters to fec_free\n"); + return ; + } + free(p->enc_matrix); + free(p); +} + +/* + * create a new encoder, returning a descriptor. This contains k,n and + * the encoding matrix. + */ +struct fec_parms * +fec_new(int k, int n) +{ + int row, col ; + gf *p, *tmp_m ; + + struct fec_parms *retval ; + + if (fec_initialized == 0) + init_fec(); + + if (k > GF_SIZE + 1 || n > GF_SIZE + 1 || k > n ) { + fprintf(stderr, "Invalid parameters k %d n %d GF_SIZE %d\n", + k, n, GF_SIZE ); + return NULL ; + } + retval = (struct fec_parms *)my_malloc(sizeof(struct fec_parms), "new_code"); + retval->k = k ; + retval->n = n ; + retval->enc_matrix = NEW_GF_MATRIX(n, k); + retval->magic = ( ( FEC_MAGIC ^ k) ^ n) ^ (unsigned long)(retval->enc_matrix) ; + tmp_m = NEW_GF_MATRIX(n, k); + /* + * fill the matrix with powers of field elements, starting from 0. + * The first row is special, cannot be computed with exp. table. + */ + tmp_m[0] = 1 ; + for (col = 1; col < k ; col++) + tmp_m[col] = 0 ; + for (p = tmp_m + k, row = 0; row < n-1 ; row++, p += k) { + for ( col = 0 ; col < k ; col ++ ) + p[col] = gf_exp[modnn(row*col)]; + } + + /* + * quick code to build systematic matrix: invert the top + * k*k vandermonde matrix, multiply right the bottom n-k rows + * by the inverse, and construct the identity matrix at the top. + */ + TICK(ticks[3]); + invert_vdm(tmp_m, k); /* much faster than invert_mat */ + matmul(tmp_m + k*k, tmp_m, retval->enc_matrix + k*k, n - k, k, k); + /* + * the upper matrix is I so do not bother with a slow multiply + */ + bzero(retval->enc_matrix, k*k*sizeof(gf) ); + for (p = retval->enc_matrix, col = 0 ; col < k ; col++, p += k+1 ) + *p = 1 ; + free(tmp_m); + TOCK(ticks[3]); + + DDB(fprintf(stderr, "--- %ld us to build encoding matrix\n", + ticks[3]);) + DEB(pr_matrix(retval->enc_matrix, n, k, "encoding_matrix");) + return retval ; +} + +/* + * fec_encode accepts as input pointers to n data packets of size sz, + * and produces as output a packet pointed to by fec, computed + * with index "index". + */ +void +fec_encode(struct fec_parms *code, gf *src[], gf *fec, int index, int sz) +{ + int i, k = code->k ; + gf *p ; + + if (GF_BITS > 8) + sz /= 2 ; + + if (index < k) + bcopy(src[index], fec, sz*sizeof(gf) ) ; + else if (index < code->n) { + p = &(code->enc_matrix[index*k] ); + bzero(fec, sz*sizeof(gf)); + for (i = 0; i < k ; i++) + addmul(fec, src[i], p[i], sz ) ; + } else + fprintf(stderr, "Invalid index %d (max %d)\n", + index, code->n - 1 ); +} + +/* + * shuffle move src packets in their position + */ +static int +shuffle(gf *pkt[], int index[], int k) +{ + int i; + + for ( i = 0 ; i < k ; ) { + if (index[i] >= k || index[i] == i) + i++ ; + else { + /* + * put pkt in the right position (first check for conflicts). + */ + int c = index[i] ; + + if (index[c] == c) { + DEB(fprintf(stderr, "\nshuffle, error at %d\n", i);) + return 1 ; + } + SWAP(index[i], index[c], int) ; + SWAP(pkt[i], pkt[c], gf *) ; + } + } + DEB( /* just test that it works... */ + for ( i = 0 ; i < k ; i++ ) { + if (index[i] < k && index[i] != i) { + fprintf(stderr, "shuffle: after\n"); + for (i=0; i<k ; i++) fprintf(stderr, "%3d ", index[i]); + fprintf(stderr, "\n"); + return 1 ; + } + } + ) + return 0 ; +} + +/* + * build_decode_matrix constructs the encoding matrix given the + * indexes. The matrix must be already allocated as + * a vector of k*k elements, in row-major order + */ +static gf * +build_decode_matrix(struct fec_parms *code, gf *pkt[], int index[]) +{ + int i , k = code->k ; + gf *p, *matrix = NEW_GF_MATRIX(k, k); + + TICK(ticks[9]); + for (i = 0, p = matrix ; i < k ; i++, p += k ) { +#if 1 /* this is simply an optimization, not very useful indeed */ + if (index[i] < k) { + bzero(p, k*sizeof(gf) ); + p[i] = 1 ; + } else +#endif + if (index[i] < code->n ) + bcopy( &(code->enc_matrix[index[i]*k]), p, k*sizeof(gf) ); + else { + fprintf(stderr, "decode: invalid index %d (max %d)\n", + index[i], code->n - 1 ); + free(matrix) ; + return NULL ; + } + } + TICK(ticks[9]); + if (invert_mat(matrix, k)) { + free(matrix); + matrix = NULL ; + } + TOCK(ticks[9]); + return matrix ; +} + +/* + * fec_decode receives as input a vector of packets, the indexes of + * packets, and produces the correct vector as output. + * + * Input: + * code: pointer to code descriptor + * pkt: pointers to received packets. They are modified + * to store the output packets (in place) + * index: pointer to packet indexes (modified) + * sz: size of each packet + */ +int +fec_decode(struct fec_parms *code, gf *pkt[], int index[], int sz) +{ + gf *m_dec ; + gf **new_pkt ; + int row, col , k = code->k ; + + if (GF_BITS > 8) + sz /= 2 ; + + if (shuffle(pkt, index, k)) /* error if true */ + return 1 ; + m_dec = build_decode_matrix(code, pkt, index); + + if (m_dec == NULL) + return 1 ; /* error */ + /* + * do the actual decoding + */ + new_pkt = (gf**)my_malloc (k * sizeof (gf * ), "new pkt pointers" ); + for (row = 0 ; row < k ; row++ ) { + if (index[row] >= k) { + new_pkt[row] = (gf*) my_malloc (sz * sizeof (gf), "new pkt buffer" ); + bzero(new_pkt[row], sz * sizeof(gf) ) ; + for (col = 0 ; col < k ; col++ ) + addmul(new_pkt[row], pkt[col], m_dec[row*k + col], sz) ; + } + } + /* + * move pkts to their final destination + */ + for (row = 0 ; row < k ; row++ ) { + if (index[row] >= k) { + bcopy(new_pkt[row], pkt[row], sz*sizeof(gf)); + free(new_pkt[row]); + } + } + free(new_pkt); + free(m_dec); + + return 0; +} diff --git a/libtransport/src/core/fec.h b/libtransport/src/core/fec.h new file mode 100644 index 000000000..8234057a7 --- /dev/null +++ b/libtransport/src/core/fec.h @@ -0,0 +1,65 @@ +/* + * fec.c -- forward error correction based on Vandermonde matrices + * 980614 + * (C) 1997-98 Luigi Rizzo (luigi@iet.unipi.it) + * + * Portions derived from code by Phil Karn (karn@ka9q.ampr.org), + * Robert Morelos-Zaragoza (robert@spectra.eng.hawaii.edu) and Hari + * Thirumoorthy (harit@spectra.eng.hawaii.edu), Aug 1995 + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, + * OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, + * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR + * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT + * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY + * OF SUCH DAMAGE. + */ + +/* + * The following parameter defines how many bits are used for + * field elements. The code supports any value from 2 to 16 + * but fastest operation is achieved with 8 bit elements + * This is the only parameter you may want to change. + */ +#ifndef GF_BITS +#define GF_BITS 8 /* code over GF(2**GF_BITS) - change to suit */ +#endif + +#if (GF_BITS <= 8) +typedef unsigned char gf; +#else +typedef unsigned short gf; +#endif + +#define GF_SIZE ((1 << GF_BITS) - 1) /* powers of \alpha */ + +struct fec_parms { + unsigned long magic ; + int k, n ; /* parameters of the code */ + gf *enc_matrix ; +}; + +void fec_free(struct fec_parms *p) ; +struct fec_parms *fec_new(int k, int n) ; + +void fec_encode(struct fec_parms *code, gf *src[], gf *fec, int index, int sz); +int fec_decode(struct fec_parms *code, gf *pkt[], int index[], int sz); + +/* end of file */ diff --git a/libtransport/src/core/global_configuration.cc b/libtransport/src/core/global_configuration.cc new file mode 100644 index 000000000..e0b6c040a --- /dev/null +++ b/libtransport/src/core/global_configuration.cc @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/global_configuration.h> +#include <hicn/transport/core/connector.h> +#include <hicn/transport/utils/log.h> + +#include <libconfig.h++> +#include <map> + +namespace transport { +namespace core { + +GlobalConfiguration::GlobalConfiguration() {} + +bool GlobalConfiguration::parseTransportConfig(const std::string& path) { + using namespace libconfig; + Config cfg; + + try { + cfg.readFile(path.c_str()); + } catch (const FileIOException& fioex) { + TRANSPORT_LOGE("I/O error while reading file."); + return false; + } catch (const ParseException& pex) { + TRANSPORT_LOGE("Parse error at %s:%d - %s", pex.getFile(), pex.getLine(), + pex.getError()); + return false; + } + + Setting& root = cfg.getRoot(); + + /** + * Iterate over sections. Best thing to do here would be to have other + * components of the program registering a callback here, to parse their + * section of the configuration file. + */ + for (auto section = root.begin(); section != root.end(); section++) { + std::string name = section->getName(); + std::error_code ec; + TRANSPORT_LOGD("Parsing Section: %s", name.c_str()); + + auto it = configuration_parsers_.find(name); + if (it != configuration_parsers_.end() && !it->second.first) { + TRANSPORT_LOGD("Found valid configuration parser"); + it->second.second(*section, ec); + it->second.first = true; + } + } + + return true; +} + +void GlobalConfiguration::parseConfiguration(const std::string& path) { + // Check if an environment variable with the configuration path exists. COnf + // variable comes first. + std::unique_lock<std::mutex> lck(cp_mtx_); + if (const char* env_c = std::getenv(GlobalConfiguration::conf_file)) { + parseTransportConfig(env_c); + } else if (!path.empty()) { + conf_file_path_ = path; + parseTransportConfig(conf_file_path_); + } else { + TRANSPORT_LOGD( + "Called parseConfiguration but no configuration file was provided."); + } +} + +void GlobalConfiguration::registerConfigurationSetter( + const std::string& key, const SetCallback& set_callback) { + std::unique_lock<std::mutex> lck(cp_mtx_); + if (configuration_setters_.find(key) != configuration_setters_.end()) { + TRANSPORT_LOGW( + "Trying to register configuration setter %s twice. Ignoring second " + "registration attempt.", + key.c_str()); + } else { + configuration_setters_.emplace(key, set_callback); + } +} + +void GlobalConfiguration::registerConfigurationGetter( + const std::string& key, const GetCallback& get_callback) { + std::unique_lock<std::mutex> lck(cp_mtx_); + if (configuration_getters_.find(key) != configuration_getters_.end()) { + TRANSPORT_LOGW( + "Trying to register configuration getter %s twice. Ignoring second " + "registration attempt.", + key.c_str()); + } else { + configuration_getters_.emplace(key, get_callback); + } +} + +void GlobalConfiguration::registerConfigurationParser( + const std::string& key, const ParserCallback& parser) { + std::unique_lock<std::mutex> lck(cp_mtx_); + if (configuration_parsers_.find(key) != configuration_parsers_.end()) { + TRANSPORT_LOGW( + "Trying to register configuration key %s twice. Ignoring second " + "registration attempt.", + key.c_str()); + } else { + configuration_parsers_.emplace(key, std::make_pair(false, parser)); + + // Trigger a parsing of the configuration. + if (!conf_file_path_.empty()) { + parseTransportConfig(conf_file_path_); + } + } +} + +void GlobalConfiguration::unregisterConfigurationParser( + const std::string& key) { + std::unique_lock<std::mutex> lck(cp_mtx_); + auto it = configuration_parsers_.find(key); + if (it != configuration_parsers_.end()) { + configuration_parsers_.erase(it); + } +} + +void GlobalConfiguration::unregisterConfigurationSetter( + const std::string& key) { + std::unique_lock<std::mutex> lck(cp_mtx_); + auto it = configuration_setters_.find(key); + if (it != configuration_setters_.end()) { + configuration_setters_.erase(it); + } +} + +void GlobalConfiguration::unregisterConfigurationGetter( + const std::string& key) { + std::unique_lock<std::mutex> lck(cp_mtx_); + auto it = configuration_getters_.find(key); + if (it != configuration_getters_.end()) { + configuration_getters_.erase(it); + } +} + +void GlobalConfiguration::getConfiguration( + interface::global_config::ConfigurationObject& configuration_object, + std::error_code& ec) { + auto it = configuration_getters_.find(configuration_object.getKey()); + + if (it != configuration_getters_.end()) { + it->second(configuration_object, ec); + } +} + +void GlobalConfiguration::setConfiguration( + const interface::global_config::ConfigurationObject& configuration_object, + std::error_code& ec) { + auto it = configuration_setters_.find(configuration_object.getKey()); + + if (it != configuration_setters_.end()) { + it->second(configuration_object, ec); + } +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/global_configuration.h b/libtransport/src/core/global_configuration.h new file mode 100644 index 000000000..dcc8d94e3 --- /dev/null +++ b/libtransport/src/core/global_configuration.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/global_conf_interface.h> +#include <hicn/transport/utils/singleton.h> + +#include <functional> +#include <map> +#include <memory> +#include <mutex> +#include <system_error> + +namespace libconfig { +class Setting; +} + +namespace transport { +namespace core { + +/** + * Class holding workflow for global configuration. + * This class does not contains the actual configuration, which is rather stored + * inside the modules to be configured. This class contains the handlers to call + * for getting/setting the configurations and to parse the corresponding + * sections of the configuration file. Each class register 3 callbacks: one to + * parse conf section and 2 to set/get the configuration through programming + * interface. + */ +class GlobalConfiguration : public utils::Singleton<GlobalConfiguration> { + static const constexpr char *conf_file = "TRANSPORT_CONFIG"; + friend class utils::Singleton<GlobalConfiguration>; + + public: + /** + * This callback will be called by GlobalConfiguration in + * + */ + using ParserCallback = std::function<void(const libconfig::Setting &config, + std::error_code &ec)>; + using GetCallback = + std::function<void(interface::global_config::ConfigurationObject &object, + std::error_code &ec)>; + + using SetCallback = std::function<void( + const interface::global_config::ConfigurationObject &object, + std::error_code &ec)>; + + ~GlobalConfiguration() = default; + + public: + void parseConfiguration(const std::string &path); + + void registerConfigurationParser(const std::string &key, + const ParserCallback &parser); + + void registerConfigurationSetter(const std::string &key, + const SetCallback &set_callback); + void registerConfigurationGetter(const std::string &key, + const GetCallback &get_callback); + + void unregisterConfigurationParser(const std::string &key); + + void unregisterConfigurationSetter(const std::string &key); + + void unregisterConfigurationGetter(const std::string &key); + + void getConfiguration( + interface::global_config::ConfigurationObject &configuration_object, + std::error_code &ec); + void setConfiguration( + const interface::global_config::ConfigurationObject &configuration_object, + std::error_code &ec); + + private: + GlobalConfiguration(); + std::string conf_file_path_; + bool parseTransportConfig(const std::string &path); + + private: + std::mutex cp_mtx_; + using ParserPair = std::pair<bool, ParserCallback>; + std::map<std::string, ParserPair> configuration_parsers_; + std::map<std::string, GetCallback> configuration_getters_; + std::map<std::string, SetCallback> configuration_setters_; +}; + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/interest.cc b/libtransport/src/core/interest.cc index 9ee662615..06cbe9f81 100644 --- a/libtransport/src/core/interest.cc +++ b/libtransport/src/core/interest.cc @@ -31,8 +31,9 @@ namespace transport { namespace core { -Interest::Interest(const Name &interest_name, Packet::Format format) - : Packet(format) { +Interest::Interest(const Name &interest_name, Packet::Format format, + std::size_t additional_header_size) + : Packet(format, additional_header_size) { if (hicn_interest_set_name(format_, packet_start_, interest_name.getConstStructReference()) < 0) { throw errors::MalformedPacketException(); @@ -45,20 +46,14 @@ Interest::Interest(const Name &interest_name, Packet::Format format) } #ifdef __ANDROID__ -Interest::Interest(hicn_format_t format) : Interest(Name("0::0|0"), format) {} +Interest::Interest(hicn_format_t format, std::size_t additional_header_size) + : Interest(Name("0::0|0"), format, additional_header_size) {} #else -Interest::Interest(hicn_format_t format) : Interest(base_name, format) {} +Interest::Interest(hicn_format_t format, std::size_t additional_header_size) + : Interest(base_name, format, additional_header_size) {} #endif -Interest::Interest(const uint8_t *buffer, std::size_t size) - : Packet(buffer, size) { - if (hicn_interest_get_name(format_, packet_start_, - name_.getStructReference()) < 0) { - throw errors::MalformedPacketException(); - } -} - -Interest::Interest(MemBufPtr &&buffer) : Packet(std::move(buffer)) { +Interest::Interest(MemBuf &&buffer) : Packet(std::move(buffer)) { if (hicn_interest_get_name(format_, packet_start_, name_.getStructReference()) < 0) { throw errors::MalformedPacketException(); @@ -70,6 +65,14 @@ Interest::Interest(Interest &&other_interest) name_ = std::move(other_interest.name_); } +Interest::Interest(const Interest &other_interest) : Packet(other_interest) { + name_ = other_interest.name_; +} + +Interest &Interest::operator=(const Interest &other) { + return (Interest &)Packet::operator=(other); +} + Interest::~Interest() {} const Name &Interest::getName() const { @@ -152,6 +155,59 @@ void Interest::resetForHash() { } } +bool Interest::hasManifest() { + return (getPayloadType() == PayloadType::MANIFEST); +} + +void Interest::appendSuffix(std::uint32_t suffix) { + if (TRANSPORT_EXPECT_FALSE(suffix_set_.empty())) { + setPayloadType(PayloadType::MANIFEST); + } + + suffix_set_.emplace(suffix); +} + +void Interest::encodeSuffixes() { + if (!hasManifest()) { + return; + } + + // We assume interest does not hold signature for the moment. + auto int_manifest_header = + (InterestManifestHeader *)(writableData() + headerSize()); + int_manifest_header->n_suffixes = suffix_set_.size(); + std::size_t additional_length = + int_manifest_header->n_suffixes * sizeof(uint32_t); + + uint32_t *suffix = (uint32_t *)(int_manifest_header + 1); + for (auto it = suffix_set_.begin(); it != suffix_set_.end(); it++, suffix++) { + *suffix = *it; + } + + updateLength(additional_length); +} + +uint32_t *Interest::firstSuffix() { + if (!hasManifest()) { + return nullptr; + } + + auto ret = (InterestManifestHeader *)(writableData() + headerSize()); + ret += 1; + + return (uint32_t *)ret; +} + +uint32_t Interest::numberOfSuffixes() { + if (!hasManifest()) { + return 0; + } + + auto header = (InterestManifestHeader *)(writableData() + headerSize()); + + return header->n_suffixes; +} + } // end namespace core } // end namespace transport diff --git a/libtransport/src/core/io_module.cc b/libtransport/src/core/io_module.cc new file mode 100644 index 000000000..fef0c1504 --- /dev/null +++ b/libtransport/src/core/io_module.cc @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <dlfcn.h> +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/utils/log.h> + +#ifdef ANDROID +#include <io_modules/udp/hicn_forwarder_module.h> +#endif + +#include <deque> + +namespace transport { +namespace core { + +IoModule::~IoModule() {} + +IoModule *IoModule::load(const char *module_name) { +#ifdef ANDROID + return new HicnForwarderModule(); +#else + void *handle = 0; + IoModule *module = 0; + IoModule *(*creator)(void) = 0; + const char *error = 0; + + // open module + handle = dlopen(module_name, RTLD_NOW); + if (!handle) { + if ((error = dlerror()) != 0) { + TRANSPORT_LOGE("%s", error); + } + return 0; + } + + // link factory method + creator = (IoModule * (*)(void)) dlsym(handle, "create_module"); + if (!creator) { + if ((error = dlerror()) != 0) { + TRANSPORT_LOGE("%s", error); + return 0; + } + } + + // create object and return it + module = (*creator)(); + module->handle_ = handle; + + return module; +#endif +} + +bool IoModule::unload(IoModule *module) { + if (!module) { + return false; + } + +#ifdef ANDROID + delete module; +#else + // destroy object and close module + void *handle = module->handle_; + delete module; + dlclose(handle); +#endif + + return true; +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/local_connector.cc b/libtransport/src/core/local_connector.cc new file mode 100644 index 000000000..f0e36a3d7 --- /dev/null +++ b/libtransport/src/core/local_connector.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/local_connector.h> +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/interest.h> +#include <hicn/transport/errors/not_implemented_exception.h> +#include <hicn/transport/utils/log.h> + +#include <asio/io_service.hpp> + +namespace transport { +namespace core { + +LocalConnector::~LocalConnector() {} + +void LocalConnector::close() { state_ = State::CLOSED; } + +void LocalConnector::send(Packet &packet) { + if (!isConnected()) { + return; + } + + TRANSPORT_LOGD("Sending packet to local socket."); + io_service_.get().post([this, p{packet.shared_from_this()}]() mutable { + receive_callback_(this, *p, std::make_error_code(std::errc(0))); + }); +} + +void LocalConnector::send(const uint8_t *packet, std::size_t len) { + throw errors::NotImplementedException(); +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/local_connector.h b/libtransport/src/core/local_connector.h new file mode 100644 index 000000000..b0daa4f53 --- /dev/null +++ b/libtransport/src/core/local_connector.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/connector.h> +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/utils/move_wrapper.h> +#include <hicn/transport/utils/shared_ptr_utils.h> +#include <io_modules/forwarder/errors.h> + +#ifndef ASIO_STANDALONE +#define ASIO_STANDALONE +#endif +#include <asio/io_service.hpp> + +namespace transport { +namespace core { + +class LocalConnector : public Connector { + public: + template <typename ReceiveCallback, typename SentCallback, typename OnClose, + typename OnReconnect> + LocalConnector(asio::io_service &io_service, + ReceiveCallback &&receive_callback, SentCallback &&packet_sent, + OnClose &&close_callback, OnReconnect &&on_reconnect) + : Connector(receive_callback, packet_sent, close_callback, on_reconnect), + io_service_(io_service), + io_service_work_(io_service_.get()) { + state_ = State::CONNECTED; + } + + ~LocalConnector() override; + + void send(Packet &packet) override; + + void send(const uint8_t *packet, std::size_t len) override; + + void close() override; + + auto shared_from_this() { return utils::shared_from(this); } + + private: + std::reference_wrapper<asio::io_service> io_service_; + asio::io_service::work io_service_work_; + std::string name_; +}; + +} // namespace core +} // namespace transport diff --git a/libtransport/src/core/manifest.h b/libtransport/src/core/manifest.h index eadfed752..9b25ebd67 100644 --- a/libtransport/src/core/manifest.h +++ b/libtransport/src/core/manifest.h @@ -15,11 +15,10 @@ #pragma once +#include <core/manifest_format.h> #include <hicn/transport/core/content_object.h> #include <hicn/transport/core/name.h> -#include <core/manifest_format.h> - #include <set> namespace transport { @@ -36,18 +35,20 @@ class Manifest : public Base { "Base must inherit from packet!"); public: + // core::ContentObjectManifest::Ptr + using Encoder = typename FormatTraits::Encoder; using Decoder = typename FormatTraits::Decoder; Manifest(std::size_t signature_size = 0) - : Base(HF_INET6_TCP_AH), + : Base(HF_INET6_TCP_AH, signature_size), encoder_(*this, signature_size), decoder_(*this) { Base::setPayloadType(PayloadType::MANIFEST); } Manifest(const core::Name &name, std::size_t signature_size = 0) - : Base(name, HF_INET6_TCP_AH), + : Base(name, HF_INET6_TCP_AH, signature_size), encoder_(*this, signature_size), decoder_(*this) { Base::setPayloadType(PayloadType::MANIFEST); @@ -55,7 +56,9 @@ class Manifest : public Base { template <typename T> Manifest(T &&base) - : Base(std::forward<T &&>(base)), encoder_(*this), decoder_(*this) { + : Base(std::forward<T &&>(base)), + encoder_(*this, 0, false), + decoder_(*this) { Base::setPayloadType(PayloadType::MANIFEST); } @@ -96,13 +99,13 @@ class Manifest : public Base { return *this; } - Manifest &setHashAlgorithm(utils::CryptoHashType hash_algorithm) { + Manifest &setHashAlgorithm(auth::CryptoHashType hash_algorithm) { hash_algorithm_ = hash_algorithm; encoder_.setHashAlgorithm(hash_algorithm_); return *this; } - utils::CryptoHashType getHashAlgorithm() { return hash_algorithm_; } + auth::CryptoHashType getHashAlgorithm() { return hash_algorithm_; } ManifestType getManifestType() const { return manifest_type_; } @@ -138,7 +141,7 @@ class Manifest : public Base { protected: ManifestType manifest_type_; - utils::CryptoHashType hash_algorithm_; + auth::CryptoHashType hash_algorithm_; bool is_last_; Encoder encoder_; diff --git a/libtransport/src/core/manifest_format.h b/libtransport/src/core/manifest_format.h index 36d23f99b..b759942cb 100644 --- a/libtransport/src/core/manifest_format.h +++ b/libtransport/src/core/manifest_format.h @@ -15,8 +15,8 @@ #pragma once +#include <hicn/transport/auth/crypto_hasher.h> #include <hicn/transport/core/name.h> -#include <hicn/transport/security/crypto_hasher.h> #include <cinttypes> #include <type_traits> @@ -63,8 +63,10 @@ template <typename T> struct format_traits { using Encoder = typename T::Encoder; using Decoder = typename T::Decoder; + using Hash = typename T::Hash; using HashType = typename T::HashType; - using HashList = typename T::HashList; + using Suffix = typename T::Suffix; + using SuffixList = typename T::SuffixList; }; class Packet; @@ -86,7 +88,7 @@ class ManifestEncoder { return static_cast<Implementation &>(*this).setManifestTypeImpl(type); } - ManifestEncoder &setHashAlgorithm(utils::CryptoHashType hash) { + ManifestEncoder &setHashAlgorithm(auth::CryptoHashType hash) { return static_cast<Implementation &>(*this).setHashAlgorithmImpl(hash); } @@ -160,7 +162,7 @@ class ManifestDecoder { return static_cast<const Implementation &>(*this).getManifestTypeImpl(); } - utils::CryptoHashType getHashAlgorithm() const { + auth::CryptoHashType getHashAlgorithm() const { return static_cast<const Implementation &>(*this).getHashAlgorithmImpl(); } diff --git a/libtransport/src/core/manifest_format_fixed.cc b/libtransport/src/core/manifest_format_fixed.cc index ca80c38b1..55280b460 100644 --- a/libtransport/src/core/manifest_format_fixed.cc +++ b/libtransport/src/core/manifest_format_fixed.cc @@ -13,49 +13,50 @@ * limitations under the License. */ +#include <core/manifest_format_fixed.h> #include <hicn/transport/core/packet.h> #include <hicn/transport/utils/literals.h> -#include <core/manifest_format_fixed.h> - namespace transport { namespace core { // TODO use preallocated pool of membufs FixedManifestEncoder::FixedManifestEncoder(Packet &packet, - std::size_t signature_size) + std::size_t signature_size, + bool clear) : packet_(packet), - max_size_(Packet::default_mtu - packet_.headerSize() - signature_size), - manifest_( - utils::MemBuf::create(Packet::default_mtu - packet_.headerSize())), - manifest_header_( - reinterpret_cast<ManifestHeader *>(manifest_->writableData())), - manifest_entries_(reinterpret_cast<ManifestEntry *>( - manifest_->writableData() + sizeof(ManifestHeader))), + max_size_(Packet::default_mtu - packet_.headerSize()), + manifest_header_(reinterpret_cast<ManifestHeader *>( + packet_.writableData() + packet_.headerSize())), + manifest_entries_( + reinterpret_cast<ManifestEntry *>(manifest_header_ + 1)), current_entry_(0), signature_size_(signature_size) { - *manifest_header_ = {0}; + if (clear) { + *manifest_header_ = {0}; + } } FixedManifestEncoder::~FixedManifestEncoder() {} FixedManifestEncoder &FixedManifestEncoder::encodeImpl() { - manifest_->append(sizeof(ManifestHeader) + - manifest_header_->number_of_entries * - sizeof(ManifestEntry)); - packet_.appendPayload(std::move(manifest_)); + packet_.append(sizeof(ManifestHeader) + + manifest_header_->number_of_entries * sizeof(ManifestEntry)); + packet_.updateLength(); return *this; } FixedManifestEncoder &FixedManifestEncoder::clearImpl() { - manifest_ = utils::MemBuf::create(Packet::default_mtu - packet_.headerSize() - - signature_size_); + packet_.trimEnd(sizeof(ManifestHeader) + + manifest_header_->number_of_entries * sizeof(ManifestEntry)); + current_entry_ = 0; + *manifest_header_ = {0}; return *this; } FixedManifestEncoder &FixedManifestEncoder::setHashAlgorithmImpl( - utils::CryptoHashType algorithm) { + auth::CryptoHashType algorithm) { manifest_header_->hash_algorithm = static_cast<uint8_t>(algorithm); return *this; } @@ -83,7 +84,7 @@ FixedManifestEncoder &FixedManifestEncoder::setBaseNameImpl( } FixedManifestEncoder &FixedManifestEncoder::addSuffixAndHashImpl( - uint32_t suffix, const utils::CryptoHash &hash) { + uint32_t suffix, const auth::CryptoHash &hash) { auto _hash = hash.getDigest<std::uint8_t>(); addSuffixHashBytes(suffix, _hash.data(), _hash.length()); return *this; @@ -170,8 +171,8 @@ ManifestType FixedManifestDecoder::getManifestTypeImpl() const { return static_cast<ManifestType>(manifest_header_->manifest_type); } -utils::CryptoHashType FixedManifestDecoder::getHashAlgorithmImpl() const { - return static_cast<utils::CryptoHashType>(manifest_header_->hash_algorithm); +auth::CryptoHashType FixedManifestDecoder::getHashAlgorithmImpl() const { + return static_cast<auth::CryptoHashType>(manifest_header_->hash_algorithm); } NextSegmentCalculationStrategy diff --git a/libtransport/src/core/manifest_format_fixed.h b/libtransport/src/core/manifest_format_fixed.h index 1d7cd7d32..56ad4ef6d 100644 --- a/libtransport/src/core/manifest_format_fixed.h +++ b/libtransport/src/core/manifest_format_fixed.h @@ -15,9 +15,8 @@ #pragma once -#include <hicn/transport/core/packet.h> - #include <core/manifest_format.h> +#include <hicn/transport/core/packet.h> #include <string> @@ -53,8 +52,10 @@ class Packet; struct Fixed { using Encoder = FixedManifestEncoder; using Decoder = FixedManifestDecoder; - using HashType = utils::CryptoHash; - using SuffixList = std::list<std::pair<std::uint32_t, std::uint8_t *>>; + using Hash = auth::CryptoHash; + using HashType = auth::CryptoHashType; + using Suffix = uint32_t; + using SuffixList = std::list<std::pair<uint32_t, uint8_t *>>; }; struct Flags { @@ -84,7 +85,8 @@ static const constexpr std::uint8_t manifest_version = 1; class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { public: - FixedManifestEncoder(Packet &packet, std::size_t signature_size = 0); + FixedManifestEncoder(Packet &packet, std::size_t signature_size = 0, + bool clear = true); ~FixedManifestEncoder(); @@ -94,7 +96,7 @@ class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { FixedManifestEncoder &setManifestTypeImpl(ManifestType manifest_type); - FixedManifestEncoder &setHashAlgorithmImpl(utils::CryptoHashType algorithm); + FixedManifestEncoder &setHashAlgorithmImpl(Fixed::HashType algorithm); FixedManifestEncoder &setNextSegmentCalculationStrategyImpl( NextSegmentCalculationStrategy strategy); @@ -102,7 +104,7 @@ class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { FixedManifestEncoder &setBaseNameImpl(const core::Name &base_name); FixedManifestEncoder &addSuffixAndHashImpl(uint32_t suffix, - const utils::CryptoHash &hash); + const Fixed::Hash &hash); FixedManifestEncoder &setIsFinalManifestImpl(bool is_last); @@ -125,7 +127,6 @@ class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { Packet &packet_; std::size_t max_size_; - std::unique_ptr<utils::MemBuf> manifest_; ManifestHeader *manifest_header_; ManifestEntry *manifest_entries_; std::size_t current_entry_; @@ -144,7 +145,7 @@ class FixedManifestDecoder : public ManifestDecoder<FixedManifestDecoder> { ManifestType getManifestTypeImpl() const; - utils::CryptoHashType getHashAlgorithmImpl() const; + Fixed::HashType getHashAlgorithmImpl() const; NextSegmentCalculationStrategy getNextSegmentCalculationStrategyImpl() const; diff --git a/libtransport/src/core/manifest_inline.h b/libtransport/src/core/manifest_inline.h index dedf82b45..fcb1d214f 100644 --- a/libtransport/src/core/manifest_inline.h +++ b/libtransport/src/core/manifest_inline.h @@ -30,8 +30,12 @@ class ManifestInline : public Manifest<Base, FormatTraits, ManifestInline<Base, FormatTraits>> { using ManifestBase = Manifest<Base, FormatTraits, ManifestInline<Base, FormatTraits>>; + + using Hash = typename FormatTraits::Hash; using HashType = typename FormatTraits::HashType; + using Suffix = typename FormatTraits::Suffix; using SuffixList = typename FormatTraits::SuffixList; + using HashEntry = std::pair<auth::CryptoHashType, std::vector<uint8_t>>; public: ManifestInline() : ManifestBase() {} @@ -44,7 +48,7 @@ class ManifestInline static TRANSPORT_ALWAYS_INLINE ManifestInline *createManifest( const core::Name &manifest_name, ManifestVersion version, - ManifestType type, utils::CryptoHashType algorithm, bool is_last, + ManifestType type, auth::CryptoHashType algorithm, bool is_last, const Name &base_name, NextSegmentCalculationStrategy strategy, std::size_t signature_size) { auto manifest = new ManifestInline(manifest_name, signature_size); @@ -84,7 +88,7 @@ class ManifestInline const Name &getBaseName() { return base_name_; } - ManifestInline &addSuffixHash(uint32_t suffix, const HashType &hash) { + ManifestInline &addSuffixHash(Suffix suffix, const Hash &hash) { ManifestBase::encoder_.addSuffixAndHash(suffix, hash); return *this; } @@ -104,12 +108,35 @@ class ManifestInline return next_segment_strategy_; } + // Convert several manifests into a single map from suffixes to packet hashes. + // All manifests must have been decoded beforehand. + static std::unordered_map<Suffix, HashEntry> getSuffixMap( + const std::vector<ManifestInline *> &manifests) { + std::unordered_map<Suffix, HashEntry> suffix_map; + + for (auto manifest_ptr : manifests) { + HashType hash_algorithm = manifest_ptr->getHashAlgorithm(); + SuffixList suffix_list = manifest_ptr->getSuffixList(); + + for (auto it = suffix_list.begin(); it != suffix_list.end(); ++it) { + std::vector<uint8_t> hash( + it->second, it->second + auth::hash_size_map[hash_algorithm]); + suffix_map[it->first] = {hash_algorithm, hash}; + } + } + + return suffix_map; + } + static std::unordered_map<Suffix, HashEntry> getSuffixMap( + ManifestInline *manifest) { + return getSuffixMap(std::vector<ManifestInline *>{manifest}); + } + private: core::Name base_name_; NextSegmentCalculationStrategy next_segment_strategy_; SuffixList suffix_hash_map_; }; -} // end namespace core - -} // end namespace transport
\ No newline at end of file +} // namespace core +} // namespace transport diff --git a/libtransport/src/core/name.cc b/libtransport/src/core/name.cc index 811e93b87..795c8a697 100644 --- a/libtransport/src/core/name.cc +++ b/libtransport/src/core/name.cc @@ -13,14 +13,13 @@ * limitations under the License. */ +#include <core/manifest_format.h> #include <hicn/transport/core/name.h> #include <hicn/transport/errors/errors.h> #include <hicn/transport/errors/tokenizer_exception.h> #include <hicn/transport/utils/hash.h> #include <hicn/transport/utils/string_tokenizer.h> -#include <core/manifest_format.h> - namespace transport { namespace core { @@ -98,7 +97,12 @@ bool Name::operator!=(const Name &name) const { } Name::operator bool() const { - return bool(hicn_name_empty((hicn_name_t *)&name_)); + auto ret = isValid(); + return ret; +} + +bool Name::isValid() const { + return bool(!hicn_name_empty((hicn_name_t *)&name_)); } bool Name::equals(const Name &name, bool consider_segment) const { diff --git a/libtransport/src/core/packet.cc b/libtransport/src/core/packet.cc index cd2c5aa69..6f237729a 100644 --- a/libtransport/src/core/packet.cc +++ b/libtransport/src/core/packet.cc @@ -31,59 +31,94 @@ namespace core { const core::Name Packet::base_name("0::0|0"); -Packet::Packet(Format format) - : packet_(utils::MemBuf::create(getHeaderSizeFromFormat(format, 256)) - .release()), - packet_start_(reinterpret_cast<hicn_header_t *>(packet_->writableData())), - header_head_(packet_.get()), - payload_head_(nullptr), - format_(format) { - if (hicn_packet_init_header(format, packet_start_) < 0) { - throw errors::RuntimeException("Unexpected error initializing the packet."); - } - - packet_->append(getHeaderSizeFromFormat(format_)); -} - -Packet::Packet(MemBufPtr &&buffer) - : packet_(std::move(buffer)), - packet_start_(reinterpret_cast<hicn_header_t *>(packet_->writableData())), - header_head_(packet_.get()), - payload_head_(nullptr), - format_(getFormatFromBuffer(packet_->writableData(), packet_->length())) { -} - -Packet::Packet(const uint8_t *buffer, std::size_t size) - : Packet(MemBufPtr(utils::MemBuf::copyBuffer(buffer, size).release())) {} +Packet::Packet(Format format, std::size_t additional_header_size) + : utils::MemBuf(utils::MemBuf(CREATE, 2048)), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(0), + format_(format), + payload_type_(PayloadType::UNSPECIFIED) { + setFormat(format_, additional_header_size); +} + +Packet::Packet(MemBuf &&buffer) + : utils::MemBuf(std::move(buffer)), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(0), + format_(getFormatFromBuffer(data(), length())), + payload_type_(PayloadType::UNSPECIFIED) {} + +Packet::Packet(CopyBufferOp, const uint8_t *buffer, std::size_t size) + : utils::MemBuf(COPY_BUFFER, buffer, size), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(0), + format_(getFormatFromBuffer(data(), length())), + payload_type_(PayloadType::UNSPECIFIED) {} + +Packet::Packet(WrapBufferOp, uint8_t *buffer, std::size_t length, + std::size_t size) + : utils::MemBuf(WRAP_BUFFER, buffer, length, size), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(0), + format_(getFormatFromBuffer(this->data(), this->length())), + payload_type_(PayloadType::UNSPECIFIED) {} + +Packet::Packet(CreateOp, uint8_t *buffer, std::size_t length, std::size_t size, + Format format, std::size_t additional_header_size) + : utils::MemBuf(WRAP_BUFFER, buffer, length, size), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(0), + format_(format), + payload_type_(PayloadType::UNSPECIFIED) { + clear(); + setFormat(format_, additional_header_size); +} + +Packet::Packet(const Packet &other) + : utils::MemBuf(other), + packet_start_(reinterpret_cast<hicn_header_t *>(writableData())), + header_offset_(other.header_offset_), + format_(other.format_), + payload_type_(PayloadType::UNSPECIFIED) {} Packet::Packet(Packet &&other) - : packet_(std::move(other.packet_)), + : utils::MemBuf(std::move(other)), packet_start_(other.packet_start_), - header_head_(other.header_head_), - payload_head_(other.payload_head_), - format_(other.format_) { + header_offset_(other.header_offset_), + format_(other.format_), + payload_type_(PayloadType::UNSPECIFIED) { other.packet_start_ = nullptr; - other.header_head_ = nullptr; - other.payload_head_ = nullptr; other.format_ = HF_UNSPEC; + other.header_offset_ = 0; } Packet::~Packet() {} +Packet &Packet::operator=(const Packet &other) { + if (this != &other) { + *this = other; + packet_start_ = reinterpret_cast<hicn_header_t *>(writableData()); + } + + return *this; +} + std::size_t Packet::getHeaderSizeFromBuffer(Format format, const uint8_t *buffer) { size_t header_length; + if (hicn_packet_get_header_length(format, (hicn_header_t *)buffer, &header_length) < 0) { throw errors::MalformedPacketException(); } + return header_length; } bool Packet::isInterest(const uint8_t *buffer) { bool is_interest = false; - if (TRANSPORT_EXPECT_FALSE(hicn_packet_test_ece((const hicn_header_t *)buffer, + if (TRANSPORT_EXPECT_FALSE(hicn_packet_test_ece(HF_INET6_TCP, + (const hicn_header_t *)buffer, &is_interest) < 0)) { throw errors::RuntimeException( "Impossible to retrieve ece flag from packet"); @@ -92,6 +127,25 @@ bool Packet::isInterest(const uint8_t *buffer) { return !is_interest; } +void Packet::setFormat(Packet::Format format, + std::size_t additional_header_size) { + format_ = format; + if (hicn_packet_init_header(format_, packet_start_) < 0) { + throw errors::RuntimeException("Unexpected error initializing the packet."); + } + + auto header_size = getHeaderSizeFromFormat(format_); + assert(header_size <= tailroom()); + append(header_size); + + assert(additional_header_size <= tailroom()); + append(additional_header_size); + + header_offset_ = length(); +} + +bool Packet::isInterest() { return Packet::isInterest(data()); } + std::size_t Packet::getPayloadSizeFromBuffer(Format format, const uint8_t *buffer) { std::size_t payload_length; @@ -105,67 +159,58 @@ std::size_t Packet::getPayloadSizeFromBuffer(Format format, } std::size_t Packet::payloadSize() const { - return getPayloadSizeFromBuffer(format_, - reinterpret_cast<uint8_t *>(packet_start_)); + std::size_t ret = 0; + + if (length()) { + ret = getPayloadSizeFromBuffer(format_, + reinterpret_cast<uint8_t *>(packet_start_)); + } + + return ret; } std::size_t Packet::headerSize() const { - return getHeaderSizeFromBuffer(format_, - reinterpret_cast<uint8_t *>(packet_start_)); + if (header_offset_ == 0 && length()) { + const_cast<Packet *>(this)->header_offset_ = getHeaderSizeFromBuffer( + format_, reinterpret_cast<uint8_t *>(packet_start_)); + } + + return header_offset_; } Packet &Packet::appendPayload(std::unique_ptr<utils::MemBuf> &&payload) { - separateHeaderPayload(); - - if (!payload_head_) { - payload_head_ = payload.get(); - } - - header_head_->prependChain(std::move(payload)); + prependChain(std::move(payload)); updateLength(); return *this; } Packet &Packet::appendPayload(const uint8_t *buffer, std::size_t length) { - return appendPayload(utils::MemBuf::copyBuffer(buffer, length)); -} - -Packet &Packet::appendHeader(std::unique_ptr<utils::MemBuf> &&header) { - separateHeaderPayload(); + prependPayload(&buffer, &length); - if (!payload_head_) { - header_head_->prependChain(std::move(header)); - } else { - payload_head_->prependChain(std::move(header)); + if (length) { + appendPayload(utils::MemBuf::copyBuffer(buffer, length)); } updateLength(); return *this; } -Packet &Packet::appendHeader(const uint8_t *buffer, std::size_t length) { - return appendHeader(utils::MemBuf::copyBuffer(buffer, length)); -} - std::unique_ptr<utils::MemBuf> Packet::getPayload() const { - const_cast<Packet *>(this)->separateHeaderPayload(); - - // Hopefully the payload is contiguous - if (TRANSPORT_EXPECT_FALSE(payload_head_ && - payload_head_->next() != header_head_)) { - payload_head_->gather(payloadSize()); - } - - return payload_head_->cloneOne(); + auto ret = clone(); + ret->trimStart(headerSize()); + return ret; } Packet &Packet::updateLength(std::size_t length) { std::size_t total_length = length; - for (utils::MemBuf *current = payload_head_; - current && current != header_head_; current = current->next()) { + const utils::MemBuf *current = this; + do { total_length += current->length(); - } + current = current->next(); + } while (current != this); + + total_length -= headerSize(); if (hicn_packet_set_payload_length(format_, packet_start_, total_length) < 0) { @@ -176,13 +221,16 @@ Packet &Packet::updateLength(std::size_t length) { } PayloadType Packet::getPayloadType() const { - hicn_payload_type_t ret = HPT_UNSPEC; + if (payload_type_ == PayloadType::UNSPECIFIED) { + hicn_payload_type_t ret; + if (hicn_packet_get_payload_type(packet_start_, &ret) < 0) { + throw errors::RuntimeException("Impossible to retrieve payload type."); + } - if (hicn_packet_get_payload_type(packet_start_, &ret) < 0) { - throw errors::RuntimeException("Impossible to retrieve payload type."); + payload_type_ = (PayloadType)ret; } - return PayloadType(ret); + return payload_type_; } Packet &Packet::setPayloadType(PayloadType payload_type) { @@ -191,60 +239,76 @@ Packet &Packet::setPayloadType(PayloadType payload_type) { throw errors::RuntimeException("Error setting payload type of the packet."); } + payload_type_ = payload_type; + return *this; } Packet::Format Packet::getFormat() const { - if (format_ == HF_UNSPEC) { + /** + * We check packet start because after a movement it will result in a nullptr + */ + if (format_ == HF_UNSPEC && length()) { if (hicn_packet_get_format(packet_start_, &format_) < 0) { - throw errors::MalformedPacketException(); + TRANSPORT_LOGE("Unexpected packet format."); } } return format_; } -const std::shared_ptr<utils::MemBuf> Packet::acquireMemBufReference() const { - return packet_; +std::shared_ptr<utils::MemBuf> Packet::acquireMemBufReference() { + return std::static_pointer_cast<utils::MemBuf>(shared_from_this()); } void Packet::dump() const { - const_cast<Packet *>(this)->separateHeaderPayload(); - TRANSPORT_LOGI("HEADER -- Length: %zu", headerSize()); - hicn_packet_dump((uint8_t *)header_head_->data(), headerSize()); - TRANSPORT_LOGI("PAYLOAD -- Length: %zu", payloadSize()); - for (utils::MemBuf *current = payload_head_; - current && current != header_head_; current = current->next()) { + + const utils::MemBuf *current = this; + do { TRANSPORT_LOGI("MemBuf Length: %zu", current->length()); - hicn_packet_dump((uint8_t *)current->data(), current->length()); - } + dump((uint8_t *)current->data(), current->length()); + current = current->next(); + } while (current != this); +} + +void Packet::dump(uint8_t *buffer, std::size_t length) { + hicn_packet_dump(buffer, length); } void Packet::setSignatureSize(std::size_t size_bytes) { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + int ret = hicn_packet_set_signature_size(format_, packet_start_, size_bytes); if (ret < 0) { - throw errors::RuntimeException("Packet without Authentication Header."); + throw errors::RuntimeException("Error setting signature size."); } - - packet_->append(size_bytes); - updateLength(); } uint8_t *Packet::getSignature() const { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + uint8_t *signature; int ret = hicn_packet_get_signature(format_, packet_start_, &signature); if (ret < 0) { - throw errors::RuntimeException("Packet without Authentication Header."); + throw errors::RuntimeException("Error getting signature."); } return signature; } void Packet::setSignatureTimestamp(const uint64_t ×tamp) { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + int ret = hicn_packet_set_signature_timestamp(format_, packet_start_, timestamp); @@ -254,6 +318,10 @@ void Packet::setSignatureTimestamp(const uint64_t ×tamp) { } uint64_t Packet::getSignatureTimestamp() const { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + uint64_t return_value; int ret = hicn_packet_get_signature_timestamp(format_, packet_start_, &return_value); @@ -266,7 +334,11 @@ uint64_t Packet::getSignatureTimestamp() const { } void Packet::setValidationAlgorithm( - const utils::CryptoSuite &validation_algorithm) { + const auth::CryptoSuite &validation_algorithm) { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + int ret = hicn_packet_set_validation_algorithm(format_, packet_start_, uint8_t(validation_algorithm)); @@ -275,7 +347,11 @@ void Packet::setValidationAlgorithm( } } -utils::CryptoSuite Packet::getValidationAlgorithm() const { +auth::CryptoSuite Packet::getValidationAlgorithm() const { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + uint8_t return_value; int ret = hicn_packet_get_validation_algorithm(format_, packet_start_, &return_value); @@ -284,10 +360,14 @@ utils::CryptoSuite Packet::getValidationAlgorithm() const { throw errors::RuntimeException("Error getting the validation algorithm."); } - return utils::CryptoSuite(return_value); + return auth::CryptoSuite(return_value); } -void Packet::setKeyId(const utils::KeyId &key_id) { +void Packet::setKeyId(const auth::KeyId &key_id) { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + int ret = hicn_packet_set_key_id(format_, packet_start_, key_id.first); if (ret < 0) { @@ -295,8 +375,12 @@ void Packet::setKeyId(const utils::KeyId &key_id) { } } -utils::KeyId Packet::getKeyId() const { - utils::KeyId return_value; +auth::KeyId Packet::getKeyId() const { + if (!authenticationHeader()) { + throw errors::RuntimeException("Packet without Authentication Header."); + } + + auth::KeyId return_value; int ret = hicn_packet_get_key_id(format_, packet_start_, &return_value.first, &return_value.second); @@ -307,8 +391,8 @@ utils::KeyId Packet::getKeyId() const { return return_value; } -utils::CryptoHash Packet::computeDigest(utils::CryptoHashType algorithm) const { - utils::CryptoHasher hasher(static_cast<utils::CryptoHashType>(algorithm)); +auth::CryptoHash Packet::computeDigest(auth::CryptoHashType algorithm) const { + auth::CryptoHasher hasher(static_cast<auth::CryptoHashType>(algorithm)); hasher.init(); // Copy IP+TCP/ICMP header before zeroing them @@ -318,11 +402,11 @@ utils::CryptoHash Packet::computeDigest(utils::CryptoHashType algorithm) const { const_cast<Packet *>(this)->resetForHash(); - auto current = header_head_; + const utils::MemBuf *current = this; do { hasher.updateBytes(current->data(), current->length()); current = current->next(); - } while (current != header_head_); + } while (current != this); hicn_packet_copy_header(format_, &header_copy, packet_start_, false); @@ -330,15 +414,33 @@ utils::CryptoHash Packet::computeDigest(utils::CryptoHashType algorithm) const { } bool Packet::checkIntegrity() const { - if (hicn_packet_check_integrity(format_, packet_start_) < 0) { + uint16_t partial_csum = + csum(data() + HICN_V6_TCP_HDRLEN, length() - HICN_V6_TCP_HDRLEN, 0); + + for (const utils::MemBuf *current = next(); current != this; + current = current->next()) { + partial_csum = csum(current->data(), current->length(), ~partial_csum); + } + + if (hicn_packet_check_integrity_no_payload(format_, packet_start_, + partial_csum) < 0) { return false; } return true; } +void Packet::prependPayload(const uint8_t **buffer, std::size_t *size) { + auto last = prev(); + auto to_copy = std::min(*size, last->tailroom()); + std::memcpy(last->writableTail(), *buffer, to_copy); + last->append(to_copy); + *size -= to_copy; + *buffer += to_copy; +} + Packet &Packet::setSyn() { - if (hicn_packet_set_syn(packet_start_) < 0) { + if (hicn_packet_set_syn(format_, packet_start_) < 0) { throw errors::RuntimeException("Error setting syn bit in the packet."); } @@ -346,7 +448,7 @@ Packet &Packet::setSyn() { } Packet &Packet::resetSyn() { - if (hicn_packet_reset_syn(packet_start_) < 0) { + if (hicn_packet_reset_syn(format_, packet_start_) < 0) { throw errors::RuntimeException("Error resetting syn bit in the packet."); } @@ -355,7 +457,7 @@ Packet &Packet::resetSyn() { bool Packet::testSyn() const { bool res = false; - if (hicn_packet_test_syn(packet_start_, &res) < 0) { + if (hicn_packet_test_syn(format_, packet_start_, &res) < 0) { throw errors::RuntimeException("Error testing syn bit in the packet."); } @@ -363,7 +465,7 @@ bool Packet::testSyn() const { } Packet &Packet::setAck() { - if (hicn_packet_set_ack(packet_start_) < 0) { + if (hicn_packet_set_ack(format_, packet_start_) < 0) { throw errors::RuntimeException("Error setting ack bit in the packet."); } @@ -371,7 +473,7 @@ Packet &Packet::setAck() { } Packet &Packet::resetAck() { - if (hicn_packet_reset_ack(packet_start_) < 0) { + if (hicn_packet_reset_ack(format_, packet_start_) < 0) { throw errors::RuntimeException("Error resetting ack bit in the packet."); } @@ -380,7 +482,7 @@ Packet &Packet::resetAck() { bool Packet::testAck() const { bool res = false; - if (hicn_packet_test_ack(packet_start_, &res) < 0) { + if (hicn_packet_test_ack(format_, packet_start_, &res) < 0) { throw errors::RuntimeException("Error testing ack bit in the packet."); } @@ -388,7 +490,7 @@ bool Packet::testAck() const { } Packet &Packet::setRst() { - if (hicn_packet_set_rst(packet_start_) < 0) { + if (hicn_packet_set_rst(format_, packet_start_) < 0) { throw errors::RuntimeException("Error setting rst bit in the packet."); } @@ -396,7 +498,7 @@ Packet &Packet::setRst() { } Packet &Packet::resetRst() { - if (hicn_packet_reset_rst(packet_start_) < 0) { + if (hicn_packet_reset_rst(format_, packet_start_) < 0) { throw errors::RuntimeException("Error resetting rst bit in the packet."); } @@ -405,7 +507,7 @@ Packet &Packet::resetRst() { bool Packet::testRst() const { bool res = false; - if (hicn_packet_test_rst(packet_start_, &res) < 0) { + if (hicn_packet_test_rst(format_, packet_start_, &res) < 0) { throw errors::RuntimeException("Error testing rst bit in the packet."); } @@ -413,7 +515,7 @@ bool Packet::testRst() const { } Packet &Packet::setFin() { - if (hicn_packet_set_fin(packet_start_) < 0) { + if (hicn_packet_set_fin(format_, packet_start_) < 0) { throw errors::RuntimeException("Error setting fin bit in the packet."); } @@ -421,7 +523,7 @@ Packet &Packet::setFin() { } Packet &Packet::resetFin() { - if (hicn_packet_reset_fin(packet_start_) < 0) { + if (hicn_packet_reset_fin(format_, packet_start_) < 0) { throw errors::RuntimeException("Error resetting fin bit in the packet."); } @@ -430,7 +532,7 @@ Packet &Packet::resetFin() { bool Packet::testFin() const { bool res = false; - if (hicn_packet_test_fin(packet_start_, &res) < 0) { + if (hicn_packet_test_fin(format_, packet_start_, &res) < 0) { throw errors::RuntimeException("Error testing fin bit in the packet."); } @@ -447,24 +549,29 @@ Packet &Packet::resetFlags() { } std::string Packet::printFlags() const { - std::string flags = ""; + std::string flags; + if (testSyn()) { flags += "S"; } + if (testAck()) { flags += "A"; } + if (testRst()) { flags += "R"; } + if (testFin()) { flags += "F"; } + return flags; } Packet &Packet::setSrcPort(uint16_t srcPort) { - if (hicn_packet_set_src_port(packet_start_, srcPort) < 0) { + if (hicn_packet_set_src_port(format_, packet_start_, srcPort) < 0) { throw errors::RuntimeException("Error setting source port in the packet."); } @@ -472,7 +579,7 @@ Packet &Packet::setSrcPort(uint16_t srcPort) { } Packet &Packet::setDstPort(uint16_t dstPort) { - if (hicn_packet_set_dst_port(packet_start_, dstPort) < 0) { + if (hicn_packet_set_dst_port(format_, packet_start_, dstPort) < 0) { throw errors::RuntimeException( "Error setting destination port in the packet."); } @@ -483,7 +590,7 @@ Packet &Packet::setDstPort(uint16_t dstPort) { uint16_t Packet::getSrcPort() const { uint16_t port = 0; - if (hicn_packet_get_src_port(packet_start_, &port) < 0) { + if (hicn_packet_get_src_port(format_, packet_start_, &port) < 0) { throw errors::RuntimeException("Error reading source port in the packet."); } @@ -493,7 +600,7 @@ uint16_t Packet::getSrcPort() const { uint16_t Packet::getDstPort() const { uint16_t port = 0; - if (hicn_packet_get_dst_port(packet_start_, &port) < 0) { + if (hicn_packet_get_dst_port(format_, packet_start_, &port) < 0) { throw errors::RuntimeException( "Error reading destination port in the packet."); } @@ -518,37 +625,6 @@ uint8_t Packet::getTTL() const { return hops; } -void Packet::separateHeaderPayload() { - if (payload_head_) { - return; - } - - int signature_size = 0; - if (_is_ah(format_)) { - signature_size = (uint32_t)getSignatureSize(); - } - - auto header_size = getHeaderSizeFromFormat(format_, signature_size); - auto payload_length = packet_->length() - header_size; - - packet_->trimEnd(packet_->length()); - - auto payload = packet_->cloneOne(); - payload_head_ = payload.get(); - payload_head_->advance(header_size); - payload_head_->append(payload_length); - packet_->prependChain(std::move(payload)); - packet_->append(header_size); -} - -void Packet::resetPayload() { - if (packet_->isChained()) { - packet_->separateChain(packet_->next(), packet_->prev()); - payload_head_ = nullptr; - updateLength(); - } -} - } // end namespace core } // end namespace transport diff --git a/libtransport/src/core/pending_interest.h b/libtransport/src/core/pending_interest.h index aeff78ea2..ca6411ddf 100644 --- a/libtransport/src/core/pending_interest.h +++ b/libtransport/src/core/pending_interest.h @@ -21,7 +21,6 @@ #include <hicn/transport/core/name.h> #include <hicn/transport/interfaces/portal.h> #include <hicn/transport/portability/portability.h> - #include <utils/deadline_timer.h> #include <asio/steady_timer.hpp> @@ -34,24 +33,21 @@ class HicnForwarderInterface; class VPPForwarderInterface; class RawSocketInterface; -template <typename ForwarderInt> class Portal; using OnContentObjectCallback = interface::Portal::OnContentObjectCallback; using OnInterestTimeoutCallback = interface::Portal::OnInterestTimeoutCallback; class PendingInterest { - friend class Portal<HicnForwarderInterface>; - friend class Portal<VPPForwarderInterface>; - friend class Portal<RawSocketInterface>; + friend class Portal; public: using Ptr = utils::ObjectPool<PendingInterest>::Ptr; - PendingInterest() - : interest_(nullptr, nullptr), - timer_(), - on_content_object_callback_(), - on_interest_timeout_callback_() {} + // PendingInterest() + // : interest_(nullptr, nullptr), + // timer_(), + // on_content_object_callback_(), + // on_interest_timeout_callback_() {} PendingInterest(Interest::Ptr &&interest, std::unique_ptr<asio::steady_timer> &&timer) diff --git a/libtransport/src/core/portal.cc b/libtransport/src/core/portal.cc new file mode 100644 index 000000000..d1d26c5b7 --- /dev/null +++ b/libtransport/src/core/portal.cc @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/errors.h> +#include <core/global_configuration.h> +#include <core/portal.h> +#include <hicn/transport/interfaces/global_conf_interface.h> +#include <hicn/transport/portability/platform.h> +#include <hicn/transport/utils/file.h> + +#include <libconfig.h++> + +using namespace transport::interface::global_config; + +namespace transport { +namespace core { + +#ifdef ANDROID +static const constexpr char default_module[] = ""; +#elif defined(MACINTOSH) +static const constexpr char default_module[] = "hicnlight_module.dylib"; +#elif defined(LINUX) +static const constexpr char default_module[] = "hicnlight_module.so"; +#endif + +IoModuleConfiguration Portal::conf_; +std::string Portal::io_module_path_ = defaultIoModule(); + +std::string Portal::defaultIoModule() { + using namespace std::placeholders; + GlobalConfiguration::getInstance().registerConfigurationParser( + io_module_section, + std::bind(&Portal::parseIoModuleConfiguration, _1, _2)); + GlobalConfiguration::getInstance().registerConfigurationGetter( + io_module_section, std::bind(&Portal::getModuleConfiguration, _1, _2)); + GlobalConfiguration::getInstance().registerConfigurationSetter( + io_module_section, std::bind(&Portal::setModuleConfiguration, _1, _2)); + + // return default + conf_.name = default_module; + return default_module; +} + +void Portal::getModuleConfiguration(ConfigurationObject& object, + std::error_code& ec) { + assert(object.getKey() == io_module_section); + + auto conf = dynamic_cast<const IoModuleConfiguration&>(object); + conf = conf_; + ec = std::error_code(); +} + +std::string getIoModulePath(const std::string& name, + const std::vector<std::string>& paths, + std::error_code& ec) { +#ifdef LINUX + std::string extension = ".so"; +#elif defined(MACINTOSH) + std::string extension = ".dylib"; +#else +#error "Platform not supported."; +#endif + + std::string complete_path = name; + + if (name.empty()) { + ec = make_error_code(core_error::configuration_parse_failed); + return ""; + } + + complete_path += extension; + + for (auto& p : paths) { + if (p.at(0) != '/') { + TRANSPORT_LOGW("Path %s is not an absolute path. Ignoring it.", + p.c_str()); + continue; + } + + if (utils::File::exists(p + "/" + complete_path)) { + complete_path = p + "/" + complete_path; + break; + } + } + + return complete_path; +} + +void Portal::setModuleConfiguration(const ConfigurationObject& object, + std::error_code& ec) { + assert(object.getKey() == io_module_section); + + const IoModuleConfiguration& conf = + dynamic_cast<const IoModuleConfiguration&>(object); + auto path = getIoModulePath(conf.name, conf.search_path, ec); + if (!ec) { + conf_ = conf; + io_module_path_ = path; + } +} + +void Portal::parseIoModuleConfiguration(const libconfig::Setting& io_config, + std::error_code& ec) { + using namespace libconfig; + // path property: the list of paths where to look for the module. + std::vector<std::string> paths; + std::string name; + + if (io_config.exists("path")) { + // get path where looking for modules + const Setting& path_list = io_config.lookup("path"); + auto count = path_list.getLength(); + + for (int i = 0; i < count; i++) { + paths.emplace_back(path_list[i].c_str()); + } + } + + if (io_config.exists("name")) { + io_config.lookupValue("name", name); + } else { + ec = make_error_code(core_error::configuration_parse_failed); + return; + } + + auto path = getIoModulePath(name, paths, ec); + if (!ec) { + conf_.name = name; + conf_.search_path = paths; + io_module_path_ = path; + } +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/portal.h b/libtransport/src/core/portal.h index b63eab3af..59254cf7b 100644 --- a/libtransport/src/core/portal.h +++ b/libtransport/src/core/portal.h @@ -15,24 +15,20 @@ #pragma once -#include <core/forwarder_interface.h> #include <core/pending_interest.h> -#include <core/udp_socket_connector.h> #include <hicn/transport/config.h> #include <hicn/transport/core/content_object.h> #include <hicn/transport/core/interest.h> +#include <hicn/transport/core/io_module.h> #include <hicn/transport/core/name.h> #include <hicn/transport/core/prefix.h> #include <hicn/transport/errors/errors.h> +#include <hicn/transport/interfaces/global_conf_interface.h> #include <hicn/transport/interfaces/portal.h> #include <hicn/transport/portability/portability.h> #include <hicn/transport/utils/fixed_block_allocator.h> #include <hicn/transport/utils/log.h> -#ifdef __vpp__ -#include <core/memif_connector.h> -#endif - #include <asio.hpp> #include <asio/steady_timer.hpp> #include <future> @@ -40,17 +36,19 @@ #include <queue> #include <unordered_map> +namespace libconfig { +class Setting; +} + namespace transport { namespace core { namespace portal_details { -static constexpr uint32_t pool_size = 2048; +static constexpr uint32_t pit_size = 1024; class HandlerMemory { #ifdef __vpp__ - static constexpr std::size_t memory_size = 1024 * 1024; - public: HandlerMemory() {} @@ -58,12 +56,11 @@ class HandlerMemory { HandlerMemory &operator=(const HandlerMemory &) = delete; TRANSPORT_ALWAYS_INLINE void *allocate(std::size_t size) { - return utils::FixedBlockAllocator<128, 4096>::getInstance() - ->allocateBlock(); + return utils::FixedBlockAllocator<128, 8192>::getInstance().allocateBlock(); } TRANSPORT_ALWAYS_INLINE void deallocate(void *pointer) { - utils::FixedBlockAllocator<128, 4096>::getInstance()->deallocateBlock( + utils::FixedBlockAllocator<128, 8192>::getInstance().deallocateBlock( pointer); } #else @@ -159,33 +156,16 @@ class Pool { public: Pool(asio::io_service &io_service) : io_service_(io_service) { increasePendingInterestPool(); - increaseInterestPool(); - increaseContentObjectPool(); } TRANSPORT_ALWAYS_INLINE void increasePendingInterestPool() { // Create pool of pending interests to reuse - for (uint32_t i = 0; i < pool_size; i++) { + for (uint32_t i = 0; i < pit_size; i++) { pending_interests_pool_.add(new PendingInterest( Interest::Ptr(nullptr), std::make_unique<asio::steady_timer>(io_service_))); } } - - TRANSPORT_ALWAYS_INLINE void increaseInterestPool() { - // Create pool of interests to reuse - for (uint32_t i = 0; i < pool_size; i++) { - interest_pool_.add(new Interest()); - } - } - - TRANSPORT_ALWAYS_INLINE void increaseContentObjectPool() { - // Create pool of content object to reuse - for (uint32_t i = 0; i < pool_size; i++) { - content_object_pool_.add(new ContentObject()); - } - } - PendingInterest::Ptr getPendingInterest() { auto res = pending_interests_pool_.get(); while (TRANSPORT_EXPECT_FALSE(!res.first)) { @@ -196,35 +176,15 @@ class Pool { return std::move(res.second); } - TRANSPORT_ALWAYS_INLINE ContentObject::Ptr getContentObject() { - auto res = content_object_pool_.get(); - while (TRANSPORT_EXPECT_FALSE(!res.first)) { - increaseContentObjectPool(); - res = content_object_pool_.get(); - } - - return std::move(res.second); - } - - TRANSPORT_ALWAYS_INLINE Interest::Ptr getInterest() { - auto res = interest_pool_.get(); - while (TRANSPORT_EXPECT_FALSE(!res.first)) { - increaseInterestPool(); - res = interest_pool_.get(); - } - - return std::move(res.second); - } - private: utils::ObjectPool<PendingInterest> pending_interests_pool_; - utils::ObjectPool<ContentObject> content_object_pool_; - utils::ObjectPool<Interest> interest_pool_; asio::io_service &io_service_; }; } // namespace portal_details +class PortalConfiguration; + using PendingInterestHashTable = std::unordered_map<uint32_t, PendingInterest::Ptr>; @@ -250,32 +210,32 @@ using interface::BindConfig; * The portal class is not thread safe, appropriate locking is required by the * users of this class. */ -template <typename ForwarderInt> -class Portal { - static_assert( - std::is_base_of<ForwarderInterface<ForwarderInt, - typename ForwarderInt::ConnectorType>, - ForwarderInt>::value, - "ForwarderInt must inherit from ForwarderInterface!"); +class Portal { public: using ConsumerCallback = interface::Portal::ConsumerCallback; using ProducerCallback = interface::Portal::ProducerCallback; + friend class PortalConfiguration; + Portal() : Portal(internal_io_service_) {} Portal(asio::io_service &io_service) - : io_service_(io_service), + : io_module_(nullptr, [](IoModule *module) { IoModule::unload(module); }), + io_service_(io_service), packet_pool_(io_service), app_name_("libtransport_application"), consumer_callback_(nullptr), producer_callback_(nullptr), - connector_(std::bind(&Portal::processIncomingMessages, this, - std::placeholders::_1), - std::bind(&Portal::setLocalRoutes, this), io_service_, - app_name_), - forwarder_interface_(connector_) {} - + is_consumer_(false) { + /** + * This workaroung allows to initialize memory for packet buffers *before* + * any static variables that may be initialized in the io_modules. In this + * way static variables in modules will be destroyed before the packet + * memory. + */ + PacketManager<>::getInstance(); + } /** * Set the consumer callback. * @@ -304,7 +264,7 @@ class Portal { */ TRANSPORT_ALWAYS_INLINE void setOutputInterface( const std::string &output_interface) { - forwarder_interface_.setOutputInterface(output_interface); + io_module_->setOutputInterface(output_interface); } /** @@ -314,8 +274,19 @@ class Portal { * is a consumer or a producer. */ TRANSPORT_ALWAYS_INLINE void connect(bool is_consumer = true) { - pending_interest_hash_table_.reserve(portal_details::pool_size); - forwarder_interface_.connect(is_consumer); + if (!io_module_) { + pending_interest_hash_table_.reserve(portal_details::pit_size); + io_module_.reset(IoModule::load(io_module_path_.c_str())); + assert(io_module_); + + io_module_->init(std::bind(&Portal::processIncomingMessages, this, + std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3), + std::bind(&Portal::setLocalRoutes, this), io_service_, + app_name_); + io_module_->connect(is_consumer); + is_consumer_ = is_consumer; + } } /** @@ -324,13 +295,19 @@ class Portal { ~Portal() { killConnection(); } /** + * Compute name hash + */ + TRANSPORT_ALWAYS_INLINE uint32_t getHash(const Name &name) { + return name.getHash32() + name.getSuffix(); + } + + /** * Check if there is already a pending interest for a given name. * * @param name - The interest name. */ TRANSPORT_ALWAYS_INLINE bool interestIsPending(const Name &name) { - auto it = - pending_interest_hash_table_.find(name.getHash32() + name.getSuffix()); + auto it = pending_interest_hash_table_.find(getHash(name)); if (it != pending_interest_hash_table_.end()) { return true; } @@ -357,31 +334,46 @@ class Portal { OnContentObjectCallback &&on_content_object_callback = UNSET_CALLBACK, OnInterestTimeoutCallback &&on_interest_timeout_callback = UNSET_CALLBACK) { - uint32_t hash = - interest->getName().getHash32() + interest->getName().getSuffix(); // Send it - forwarder_interface_.send(*interest); - - auto pending_interest = packet_pool_.getPendingInterest(); - pending_interest->setInterest(std::move(interest)); - pending_interest->setOnContentObjectCallback( - std::move(on_content_object_callback)); - pending_interest->setOnTimeoutCallback( - std::move(on_interest_timeout_callback)); - pending_interest->startCountdown(portal_details::makeCustomAllocatorHandler( - async_callback_memory_, std::bind(&Portal<ForwarderInt>::timerHandler, - this, std::placeholders::_1, hash))); + interest->encodeSuffixes(); + io_module_->send(*interest); + + uint32_t initial_hash = interest->getName().getHash32(); + auto hash = initial_hash + interest->getName().getSuffix(); + uint32_t *suffix = interest->firstSuffix(); + auto n_suffixes = interest->numberOfSuffixes(); + uint32_t counter = 0; + // Set timers + do { + auto pending_interest = packet_pool_.getPendingInterest(); + pending_interest->setInterest(std::move(interest)); + pending_interest->setOnContentObjectCallback( + std::move(on_content_object_callback)); + pending_interest->setOnTimeoutCallback( + std::move(on_interest_timeout_callback)); + + pending_interest->startCountdown( + portal_details::makeCustomAllocatorHandler( + async_callback_memory_, std::bind(&Portal::timerHandler, this, + std::placeholders::_1, hash))); + + auto it = pending_interest_hash_table_.find(hash); + if (it != pending_interest_hash_table_.end()) { + it->second->cancelTimer(); - auto it = pending_interest_hash_table_.find(hash); - if (it != pending_interest_hash_table_.end()) { - it->second->cancelTimer(); + // Get reference to interest packet in order to have it destroyed. + auto _int = it->second->getInterest(); + it->second = std::move(pending_interest); + } else { + pending_interest_hash_table_[hash] = std::move(pending_interest); + } - // Get reference to interest packet in order to have it destroyed. - auto _int = it->second->getInterest(); - it->second = std::move(pending_interest); - } else { - pending_interest_hash_table_[hash] = std::move(pending_interest); - } + if (suffix) { + hash = initial_hash + *suffix; + suffix++; + } + + } while (counter++ < n_suffixes); } /** @@ -423,9 +415,9 @@ class Portal { * @param config - The configuration for the local forwarder binding. */ TRANSPORT_ALWAYS_INLINE void bind(const BindConfig &config) { - forwarder_interface_.setContentStoreSize(config.csReserved()); + assert(io_module_); + io_module_->setContentStoreSize(config.csReserved()); served_namespaces_.push_back(config.prefix()); - setLocalRoutes(); } @@ -460,7 +452,7 @@ class Portal { */ TRANSPORT_ALWAYS_INLINE void sendContentObject( ContentObject &content_object) { - forwarder_interface_.send(content_object); + io_module_->send(content_object); } /** @@ -482,7 +474,7 @@ class Portal { * Disconnect the transport from the local forwarder. */ TRANSPORT_ALWAYS_INLINE void killConnection() { - forwarder_interface_.closeConnection(); + io_module_->closeConnection(); } /** @@ -497,6 +489,17 @@ class Portal { } /** + * Remove one pending interest. + */ + TRANSPORT_ALWAYS_INLINE void clearOne(const Name &name) { + if (!io_service_.stopped()) { + io_service_.dispatch(std::bind(&Portal::doClearOne, this, name)); + } else { + doClearOne(name); + } + } + + /** * Get a reference to the io_service object. */ TRANSPORT_ALWAYS_INLINE asio::io_service &getIoService() { @@ -508,8 +511,8 @@ class Portal { */ TRANSPORT_ALWAYS_INLINE void registerRoute(Prefix &prefix) { served_namespaces_.push_back(prefix); - if (connector_.isConnected()) { - forwarder_interface_.registerRoute(prefix); + if (io_module_->isConnected()) { + io_module_->registerRoute(prefix); } } @@ -530,36 +533,49 @@ class Portal { } /** + * Remove one pending interest. + */ + TRANSPORT_ALWAYS_INLINE void doClearOne(const Name &name) { + auto it = pending_interest_hash_table_.find(getHash(name)); + + if (it != pending_interest_hash_table_.end()) { + it->second->cancelTimer(); + + // Get interest packet from pending interest and do nothing with it. It + // will get destroyed as it goes out of scope. + auto _int = it->second->getInterest(); + + pending_interest_hash_table_.erase(it); + } + } + + /** * Callback called by the underlying connector upon reception of a packet from * the local forwarder. * * @param packet_buffer - The bytes of the packet. */ TRANSPORT_ALWAYS_INLINE void processIncomingMessages( - Packet::MemBufPtr &&packet_buffer) { + Connector *c, utils::MemBuf &buffer, const std::error_code &ec) { bool is_stopped = io_service_.stopped(); if (TRANSPORT_EXPECT_FALSE(is_stopped)) { return; } - if (TRANSPORT_EXPECT_FALSE( - ForwarderInt::isControlMessage(packet_buffer->data()))) { - processControlMessage(std::move(packet_buffer)); + if (TRANSPORT_EXPECT_FALSE(io_module_->isControlMessage(buffer.data()))) { + processControlMessage(buffer); return; } - Packet::Format format = Packet::getFormatFromBuffer( - packet_buffer->data(), packet_buffer->length()); + // The buffer is a base class for an interest or a content object + Packet &packet_buffer = static_cast<Packet &>(buffer); + auto format = packet_buffer.getFormat(); if (TRANSPORT_EXPECT_TRUE(_is_tcp(format))) { - if (!Packet::isInterest(packet_buffer->data())) { - auto content_object = packet_pool_.getContentObject(); - content_object->replace(std::move(packet_buffer)); - processContentObject(std::move(content_object)); + if (is_consumer_) { + processContentObject(static_cast<ContentObject &>(packet_buffer)); } else { - auto interest = packet_pool_.getInterest(); - interest->replace(std::move(packet_buffer)); - processInterest(std::move(interest)); + processInterest(static_cast<Interest &>(packet_buffer)); } } else { TRANSPORT_LOGE("Received not supported packet. Ignoring it."); @@ -573,16 +589,16 @@ class Portal { */ TRANSPORT_ALWAYS_INLINE void setLocalRoutes() { for (auto &prefix : served_namespaces_) { - if (connector_.isConnected()) { - forwarder_interface_.registerRoute(prefix); + if (io_module_->isConnected()) { + io_module_->registerRoute(prefix); } } } - TRANSPORT_ALWAYS_INLINE void processInterest(Interest::Ptr &&interest) { + TRANSPORT_ALWAYS_INLINE void processInterest(Interest &interest) { // Interest for a producer if (TRANSPORT_EXPECT_TRUE(producer_callback_ != nullptr)) { - producer_callback_->onInterest(std::move(interest)); + producer_callback_->onInterest(interest); } } @@ -595,24 +611,27 @@ class Portal { * @param content_object - The data packet */ TRANSPORT_ALWAYS_INLINE void processContentObject( - ContentObject::Ptr &&content_object) { - uint32_t hash = content_object->getName().getHash32() + - content_object->getName().getSuffix(); + ContentObject &content_object) { + TRANSPORT_LOGD("processContentObject %s", + content_object.getName().toString().c_str()); + uint32_t hash = getHash(content_object.getName()); auto it = pending_interest_hash_table_.find(hash); if (it != pending_interest_hash_table_.end()) { + TRANSPORT_LOGD("Found pending interest."); + PendingInterest::Ptr interest_ptr = std::move(it->second); pending_interest_hash_table_.erase(it); interest_ptr->cancelTimer(); auto _int = interest_ptr->getInterest(); if (interest_ptr->getOnDataCallback() != UNSET_CALLBACK) { - interest_ptr->on_content_object_callback_(std::move(_int), - std::move(content_object)); + interest_ptr->on_content_object_callback_(*_int, content_object); } else if (consumer_callback_) { - consumer_callback_->onContentObject(std::move(_int), - std::move(content_object)); + consumer_callback_->onContentObject(*_int, content_object); } + } else { + TRANSPORT_LOGD("No interest pending for received content object."); } } @@ -622,12 +641,13 @@ class Portal { * them. */ TRANSPORT_ALWAYS_INLINE void processControlMessage( - Packet::MemBufPtr &&packet_buffer) { - forwarder_interface_.processControlMessageReply(std::move(packet_buffer)); + utils::MemBuf &packet_buffer) { + io_module_->processControlMessageReply(packet_buffer); } private: portal_details::HandlerMemory async_callback_memory_; + std::unique_ptr<IoModule, void (*)(IoModule *)> io_module_; asio::io_service &io_service_; asio::io_service internal_io_service_; @@ -641,8 +661,19 @@ class Portal { ConsumerCallback *consumer_callback_; ProducerCallback *producer_callback_; - typename ForwarderInt::ConnectorType connector_; - ForwarderInt forwarder_interface_; + bool is_consumer_; + + private: + static std::string defaultIoModule(); + static void parseIoModuleConfiguration(const libconfig::Setting &io_config, + std::error_code &ec); + static void getModuleConfiguration( + interface::global_config::ConfigurationObject &conf, std::error_code &ec); + static void setModuleConfiguration( + const interface::global_config::ConfigurationObject &conf, + std::error_code &ec); + static interface::global_config::IoModuleConfiguration conf_; + static std::string io_module_path_; }; } // namespace core diff --git a/libtransport/src/core/prefix.cc b/libtransport/src/core/prefix.cc index eea4aeb8b..1e2b2ed9d 100644 --- a/libtransport/src/core/prefix.cc +++ b/libtransport/src/core/prefix.cc @@ -25,12 +25,12 @@ extern "C" { #include <hicn/transport/portability/win_portability.h> #endif +#include <openssl/rand.h> + #include <cstring> #include <memory> #include <random> -#include <openssl/rand.h> - namespace transport { namespace core { @@ -99,7 +99,7 @@ void Prefix::buildPrefix(std::string &prefix, uint16_t prefix_length, ip_prefix_.family = family; } -std::unique_ptr<Sockaddr> Prefix::toSockaddr() { +std::unique_ptr<Sockaddr> Prefix::toSockaddr() const { Sockaddr *ret = nullptr; switch (ip_prefix_.family) { @@ -120,14 +120,14 @@ std::unique_ptr<Sockaddr> Prefix::toSockaddr() { return std::unique_ptr<Sockaddr>(ret); } -uint16_t Prefix::getPrefixLength() { return ip_prefix_.len; } +uint16_t Prefix::getPrefixLength() const { return ip_prefix_.len; } Prefix &Prefix::setPrefixLength(uint16_t prefix_length) { ip_prefix_.len = prefix_length; return *this; } -int Prefix::getAddressFamily() { return ip_prefix_.family; } +int Prefix::getAddressFamily() const { return ip_prefix_.family; } Prefix &Prefix::setAddressFamily(int address_family) { ip_prefix_.family = address_family; @@ -226,7 +226,7 @@ Name Prefix::getRandomName() const { ip_prefix_.len; size_t size = (size_t)ceil((float)addr_len / 8.0); - uint8_t *buffer = (uint8_t *) malloc(sizeof(uint8_t) * size); + uint8_t *buffer = (uint8_t *)malloc(sizeof(uint8_t) * size); RAND_bytes(buffer, size); @@ -332,7 +332,7 @@ bool Prefix::checkPrefixLengthAndAddressFamily(uint16_t prefix_length, return true; } -ip_prefix_t &Prefix::toIpPrefixStruct() { return ip_prefix_; } +const ip_prefix_t &Prefix::toIpPrefixStruct() const { return ip_prefix_; } } // namespace core diff --git a/libtransport/src/core/rs.cc b/libtransport/src/core/rs.cc new file mode 100644 index 000000000..44b5852e5 --- /dev/null +++ b/libtransport/src/core/rs.cc @@ -0,0 +1,365 @@ + +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/fec.h> +#include <core/rs.h> +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/utils/log.h> + +#include <cassert> + +namespace transport { +namespace core { +namespace fec { + +BlockCode::BlockCode(uint32_t k, uint32_t n, struct fec_parms *code) + : Packets(), + k_(k), + n_(n), + code_(code), + max_buffer_size_(0), + current_block_size_(0), + to_decode_(false) { + sorted_index_.reserve(n); +} + +bool BlockCode::addRepairSymbol(const fec::buffer &packet, uint32_t i) { + // Get index + to_decode_ = true; + TRANSPORT_LOGD("adding symbol of size %zu", packet->length()); + return addSymbol(packet, i, packet->length() - sizeof(fec_header)); +} + +bool BlockCode::addSourceSymbol(const fec::buffer &packet, uint32_t i) { + return addSymbol(packet, i, packet->length()); +} + +bool BlockCode::addSymbol(const fec::buffer &packet, uint32_t i, + std::size_t size) { + if (size > max_buffer_size_) { + max_buffer_size_ = size; + } + + operator[](current_block_size_++) = std::make_pair(i, packet); + + if (current_block_size_ >= k_) { + if (to_decode_) { + decode(); + } else { + encode(); + } + + clear(); + return false; + } + + return true; +} + +void BlockCode::encode() { + gf *data[n_]; + std::uint16_t old_values[k_]; + uint32_t base = operator[](0).first; + + // Set packet length in first 2 bytes + for (uint32_t i = 0; i < k_; i++) { + auto &packet = operator[](i).second; + + TRANSPORT_LOGD("Current buffer size: %zu", packet->length()); + + auto ret = packet->ensureCapacityAndFillUnused(max_buffer_size_, 0); + if (TRANSPORT_EXPECT_FALSE(ret == false)) { + throw errors::RuntimeException( + "Provided packet is not suitable to be used as FEC source packet. " + "Aborting."); + } + + // Buffers should hold 2 bytes before the starting pointer, in order to be + // able to set the length for the encoding operation + packet->prepend(2); + uint16_t *length = reinterpret_cast<uint16_t *>(packet->writableData()); + + old_values[i] = *length; + *length = htons(packet->length() - LEN_SIZE_BYTES); + + data[i] = packet->writableData(); + } + + // Finish to fill source block with the buffers to hold the repair symbols + for (uint32_t i = k_; i < n_; i++) { + // For the moment we get a packet from the pool here.. later we'll need to + // require a packet from the caller with a callback. + auto packet = PacketManager<>::getInstance().getMemBuf(); + packet->append(max_buffer_size_ + sizeof(fec_header) + LEN_SIZE_BYTES); + fec_header *fh = reinterpret_cast<fec_header *>(packet->writableData()); + + fh->setSeqNumberBase(base); + fh->setNFecSymbols(n_ - k_); + fh->setEncodedSymbolId(i); + fh->setSourceBlockLen(n_); + + packet->trimStart(sizeof(fec_header)); + + data[i] = packet->writableData(); + operator[](i) = std::make_pair(i, std::move(packet)); + } + + // Generate repair symbols and put them in corresponding buffers + TRANSPORT_LOGD("Calling encode with max_buffer_size_ = %zu", + max_buffer_size_); + for (uint32_t i = k_; i < n_; i++) { + fec_encode(code_, data, data[i], i, max_buffer_size_ + LEN_SIZE_BYTES); + } + + // Restore original content of buffer space used to store the length + for (uint32_t i = 0; i < k_; i++) { + auto &packet = operator[](i).second; + uint16_t *length = reinterpret_cast<uint16_t *>(packet->writableData()); + *length = old_values[i]; + packet->trimStart(2); + } + + // Re-include header in repair packets + for (uint32_t i = k_; i < n_; i++) { + auto &packet = operator[](i).second; + TRANSPORT_LOGD("Produced repair symbol of size = %zu", packet->length()); + packet->prepend(sizeof(fec_header)); + } +} + +void BlockCode::decode() { + gf *data[k_]; + uint32_t index[k_]; + + for (uint32_t i = 0; i < k_; i++) { + auto &packet = operator[](i).second; + index[i] = operator[](i).first; + sorted_index_[i] = index[i]; + + if (index[i] < k_) { + TRANSPORT_LOGD("DECODE SOURCE - index %u - Current buffer size: %zu", + index[i], packet->length()); + // This is a source packet. We need to prepend the length and fill + // additional space to 0 + + // Buffers should hold 2 bytes before the starting pointer, in order to be + // able to set the length for the encoding operation + packet->prepend(LEN_SIZE_BYTES); + packet->ensureCapacityAndFillUnused(max_buffer_size_, 0); + uint16_t *length = reinterpret_cast<uint16_t *>(packet->writableData()); + + *length = htons(packet->length() - LEN_SIZE_BYTES); + } else { + TRANSPORT_LOGD("DECODE SYMBOL - index %u - Current buffer size: %zu", + index[i], packet->length()); + packet->trimStart(sizeof(fec_header)); + } + + data[i] = packet->writableData(); + } + + // We decode the source block + TRANSPORT_LOGD("Calling decode with max_buffer_size_ = %zu", + max_buffer_size_); + fec_decode(code_, data, reinterpret_cast<int *>(index), max_buffer_size_); + + // Find the index in the block for recovered packets + for (uint32_t i = 0; i < k_; i++) { + if (index[i] != i) { + for (uint32_t j = 0; j < k_; j++) + if (sorted_index_[j] == uint32_t(index[i])) { + sorted_index_[j] = i; + } + } + } + + // Reorder block by index with in-place sorting + for (uint32_t i = 0; i < k_; i++) { + for (uint32_t j = sorted_index_[i]; j != i; j = sorted_index_[i]) { + std::swap(sorted_index_[j], sorted_index_[i]); + std::swap(operator[](j), operator[](i)); + } + } + + // Adjust length according to the one written in the source packet + for (uint32_t i = 0; i < k_; i++) { + auto &packet = operator[](i).second; + uint16_t *length = reinterpret_cast<uint16_t *>(packet->writableData()); + packet->trimStart(2); + packet->setLength(ntohs(*length)); + } +} + +void BlockCode::clear() { + current_block_size_ = 0; + max_buffer_size_ = 0; + sorted_index_.clear(); + to_decode_ = false; +} + +void rs::MatrixDeleter::operator()(struct fec_parms *params) { + fec_free(params); +} + +rs::Codes rs::createCodes() { + Codes ret; + + ret.emplace(std::make_pair(1, 3), Matrix(fec_new(1, 3), MatrixDeleter())); + ret.emplace(std::make_pair(6, 10), Matrix(fec_new(6, 10), MatrixDeleter())); + ret.emplace(std::make_pair(8, 32), Matrix(fec_new(8, 32), MatrixDeleter())); + ret.emplace(std::make_pair(10, 30), Matrix(fec_new(10, 30), MatrixDeleter())); + ret.emplace(std::make_pair(16, 24), Matrix(fec_new(16, 24), MatrixDeleter())); + ret.emplace(std::make_pair(10, 40), Matrix(fec_new(10, 40), MatrixDeleter())); + ret.emplace(std::make_pair(10, 60), Matrix(fec_new(10, 60), MatrixDeleter())); + ret.emplace(std::make_pair(10, 90), Matrix(fec_new(10, 90), MatrixDeleter())); + + return ret; +} + +rs::Codes rs::codes_ = createCodes(); + +rs::rs(uint32_t k, uint32_t n) : k_(k), n_(n) {} + +void rs::setFECCallback(const PacketsReady &callback) { + fec_callback_ = callback; +} + +encoder::encoder(uint32_t k, uint32_t n) + : rs(k, n), + current_code_(codes_[std::make_pair(k, n)].get()), + source_block_(k_, n_, current_code_) {} + +void encoder::consume(const fec::buffer &packet, uint32_t index) { + if (!source_block_.addSourceSymbol(packet, index)) { + std::vector<buffer> repair_packets; + for (uint32_t i = k_; i < n_; i++) { + repair_packets.emplace_back(std::move(source_block_[i].second)); + } + fec_callback_(repair_packets); + } +} + +decoder::decoder(uint32_t k, uint32_t n) : rs(k, n) {} + +void decoder::recoverPackets(SourceBlocks::iterator &src_block_it) { + TRANSPORT_LOGD("recoverPackets for %u", k_); + auto &src_block = src_block_it->second; + std::vector<buffer> source_packets(k_); + for (uint32_t i = 0; i < src_block.getK(); i++) { + source_packets[i] = std::move(src_block[i].second); + } + + fec_callback_(source_packets); + processed_source_blocks_.emplace(src_block_it->first); + + auto it = parked_packets_.find(src_block_it->first); + if (it != parked_packets_.end()) { + parked_packets_.erase(it); + } + + src_blocks_.erase(src_block_it); +} + +void decoder::consume(const fec::buffer &packet, uint32_t index) { + // Normalize index + auto i = index % n_; + + // Get base + uint32_t base = index - i; + + TRANSPORT_LOGD( + "Decoder consume called for source symbol. BASE = %u, index = %u and i = " + "%u", + base, index, i); + + // check if a source block already exist for this symbol. If it does not + // exist, we lazily park this packet until we receive a repair symbol for the + // same block. This is done for 2 reason: + // 1) If we receive all the source packets of a block, we do not need to + // recover anything. + // 2) Sender may change n and k at any moment, so we construct the source + // block based on the (n, k) values written in the fec header. This is + // actually not used right now, since we use fixed value of n and k passed + // at construction time, but it paves the ground for a more dynamic + // protocol that may come in the future. + auto it = src_blocks_.find(base); + if (it != src_blocks_.end()) { + auto ret = it->second.addSourceSymbol(packet, i); + if (!ret) { + recoverPackets(it); + } + } else { + TRANSPORT_LOGD("Adding to parked source packets"); + auto ret = parked_packets_.emplace( + base, std::vector<std::pair<buffer, uint32_t> >()); + ret.first->second.emplace_back(packet, i); + } +} + +void decoder::consume(const fec::buffer &packet) { + // Repair symbol! Get index and base source block. + fec_header *h = reinterpret_cast<fec_header *>(packet->writableData()); + auto i = h->getEncodedSymbolId(); + auto base = h->getSeqNumberBase(); + auto n = h->getSourceBlockLen(); + auto k = n - h->getNFecSymbols(); + + TRANSPORT_LOGD( + "Decoder consume called for repair symbol. BASE = %u, index = %u and i = " + "%u", + base, base + i, i); + + // check if a source block already exist for this symbol + auto it = src_blocks_.find(base); + if (it == src_blocks_.end()) { + // Create new source block + auto code_it = codes_.find(std::make_pair(k, n)); + if (code_it == codes_.end()) { + TRANSPORT_LOGE("Code for k = %u and n = %u does not exist.", k_, n_); + return; + } + + auto emplace_result = + src_blocks_.emplace(base, BlockCode(k, n, code_it->second.get())); + it = emplace_result.first; + + // Check in the parked packets and insert any packet that is part of this + // source block + + auto it2 = parked_packets_.find(base); + if (it2 != parked_packets_.end()) { + for (auto &packet_index : it2->second) { + auto ret = + it->second.addSourceSymbol(packet_index.first, packet_index.second); + if (!ret) { + recoverPackets(it); + // Finish to delete packets in same source block that were + // eventually not used + return; + } + } + } + } + + auto ret = it->second.addRepairSymbol(packet, i); + if (!ret) { + recoverPackets(it); + } +} + +} // namespace fec +} // namespace core +} // namespace transport diff --git a/libtransport/src/core/rs.h b/libtransport/src/core/rs.h new file mode 100644 index 000000000..d630bd233 --- /dev/null +++ b/libtransport/src/core/rs.h @@ -0,0 +1,338 @@ + +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <arpa/inet.h> +#include <hicn/transport/utils/membuf.h> +#include <protocols/fec_base.h> + +#include <array> +#include <cstdint> +#include <map> +#include <unordered_set> +#include <vector> + +namespace transport { +namespace core { + +namespace fec { + +static const constexpr uint16_t MAX_SOURCE_BLOCK_SIZE = 128; + +using buffer = typename utils::MemBuf::Ptr; +/** + * We use a std::array in place of std::vector to avoid to allocate a new vector + * in the heap every time we build a new source block, which would be bad if + * the decoder has to allocate several source blocks for many concurrent bases. + * std::array allows to be constructed in place, saving the allocation at the + * price os knowing in advance its size. + */ +using Packets = std::array<std::pair<uint32_t, buffer>, MAX_SOURCE_BLOCK_SIZE>; + +/** + * FEC Header, prepended to symbol packets. + */ +struct fec_header { + /** + * The base source packet seq_number this FES symbol refers to + */ + uint32_t seq_number; + + /** + * The index of the symbol inside the source block, between k and n - 1 + */ + uint8_t encoded_symbol_id; + + /** + * Total length of source block (n) + */ + uint8_t source_block_len; + + /** + * Total number of symbols (n - k) + */ + uint8_t n_fec_symbols; + + /** + * Align header to 64 bits + */ + uint8_t padding; + + void setSeqNumberBase(uint32_t suffix) { seq_number = htonl(suffix); } + uint32_t getSeqNumberBase() { return ntohl(seq_number); } + void setEncodedSymbolId(uint8_t esi) { encoded_symbol_id = esi; } + uint8_t getEncodedSymbolId() { return encoded_symbol_id; } + void setSourceBlockLen(uint8_t k) { source_block_len = k; } + uint8_t getSourceBlockLen() { return source_block_len; } + void setNFecSymbols(uint8_t n_r) { n_fec_symbols = n_r; } + uint8_t getNFecSymbols() { return n_fec_symbols; } +}; + +/** + * This class models the source block itself. + */ +class BlockCode : public Packets { + /** + * For variable length packet we need to prepend to the padded payload the + * real length of the packet. This is *not* sent over the network. + */ + static constexpr std::size_t LEN_SIZE_BYTES = 2; + + public: + BlockCode(uint32_t k, uint32_t n, struct fec_parms *code); + + /** + * Add a repair symbol to the dource block. + */ + bool addRepairSymbol(const fec::buffer &packet, uint32_t i); + + /** + * Add a source symbol to the source block. + */ + bool addSourceSymbol(const fec::buffer &packet, uint32_t i); + + /** + * Get current length of source block. + */ + std::size_t length() { return current_block_size_; } + + /** + * Get N + */ + uint32_t getN() { return n_; } + + /** + * Get K + */ + uint32_t getK() { return k_; } + + /** + * Clear source block + */ + void clear(); + + private: + /** + * Add symbol to source block + **/ + bool addSymbol(const fec::buffer &packet, uint32_t i, std::size_t size); + + /** + * Starting from k source symbols, get the n - k repair symbols + */ + void encode(); + + /** + * Starting from k symbols (mixed repair and source), get k source symbols. + * NOTE: It does not make sense to retrieve the k source symbols using the + * very same k source symbols. With the current implementation that case can + * never happen. + */ + void decode(); + + private: + uint32_t k_; + uint32_t n_; + struct fec_parms *code_; + std::size_t max_buffer_size_; + std::size_t current_block_size_; + std::vector<uint32_t> sorted_index_; + bool to_decode_; +}; + +/** + * This class contains common parameters between the fec encoder and decoder. + * In particular it contains: + * - The callback to be called when symbols are encoded / decoded + * - The reference to the static reed-solomon parameters, allocated at program + * startup + * - N and K. Ideally they are useful only for the encoder (the decoder can + * retrieve them from the FEC header). However right now we assume sender and + * receiver agreed on the parameters k and n to use. We will introduce a control + * message later to negotiate them, so that decoder cah dynamically change them + * during the download. + */ +class rs { + /** + * Deleter for static preallocated reed-solomon parameters. + */ + struct MatrixDeleter { + void operator()(struct fec_parms *params); + }; + + /** + * unique_ptr to reed-solomon parameters, with custom deleter to call fec_free + * at the end of the program + */ + using Matrix = std::unique_ptr<struct fec_parms, MatrixDeleter>; + + /** + * Key to retrieve static preallocated reed-solomon parameters. It is pair of + * k and n + */ + using Code = std::pair<std::uint32_t /* k */, std::uint32_t /* n */>; + + /** + * Custom hash function for (k, n) pair. + */ + struct CodeHasher { + std::size_t operator()(const transport::core::fec::rs::Code &code) const { + uint64_t ret = uint64_t(code.first) << 32 | uint64_t(code.second); + return std::hash<uint64_t>{}(ret); + } + }; + + protected: + /** + * Callback to be called after the encode or the decode operations. In the + * former case it will contain the symbols, while in the latter the sources. + */ + using PacketsReady = std::function<void(std::vector<buffer> &)>; + + /** + * The sequence number base. + */ + using SNBase = std::uint32_t; + + /** + * The map of source blocks, used at the decoder side. For the encoding + * operation we can use one source block only, since packet are produced in + * order. + */ + using SourceBlocks = std::unordered_map<SNBase, BlockCode>; + + /** + * Map (k, n) -> reed-solomon parameter + */ + using Codes = std::unordered_map<Code, Matrix, CodeHasher>; + + public: + rs(uint32_t k, uint32_t n); + ~rs() = default; + + /** + * Set callback to call after packet encoding / decoding + */ + void setFECCallback(const PacketsReady &callback); + + virtual void clear() { processed_source_blocks_.clear(); } + + private: + /** + * Create reed-solomon codes at program startup. + */ + static Codes createCodes(); + + protected: + bool processed(SNBase seq_base) { + return processed_source_blocks_.find(seq_base) != + processed_source_blocks_.end(); + } + + std::uint32_t k_; + std::uint32_t n_; + PacketsReady fec_callback_; + + /** + * Keep track of processed source blocks + */ + std::unordered_set<SNBase> processed_source_blocks_; + + static Codes codes_; +}; + +/** + * The reed-solomon encoder. It is feeded with source symbols and it provide + * repair-symbols through the fec_callback_ + */ +class encoder : public rs { + public: + encoder(uint32_t k, uint32_t n); + /** + * Always consume source symbols. + */ + void consume(const fec::buffer &packet, uint32_t index); + + void clear() override { + rs::clear(); + source_block_.clear(); + } + + private: + struct fec_parms *current_code_; + /** + * The source block. As soon as it is filled with k source symbols, the + * encoder calls the callback fec_callback_ and the resets the block 0, ready + * to accept another batch of k source symbols. + */ + BlockCode source_block_; +}; + +/** + * The reed-solomon encoder. It is feeded with source/repair symbols and it + * provides the original source symbols through the fec_callback_ + */ +class decoder : public rs { + public: + decoder(uint32_t k, uint32_t n); + + /** + * Consume source symbol + */ + void consume(const fec::buffer &packet, uint32_t i); + + /** + * Consume repair symbol + */ + void consume(const fec::buffer &packet); + + /** + * Clear decoder to reuse + */ + void clear() override { + rs::clear(); + src_blocks_.clear(); + parked_packets_.clear(); + } + + private: + void recoverPackets(SourceBlocks::iterator &src_block_it); + + private: + /** + * Map of source blocks. We use a map because we may receive symbols belonging + * to diffreent source blocks at the same time, so we need to be able to + * decode many source symbols at the same time. + */ + SourceBlocks src_blocks_; + + /** + * Unordered Map of source symbols for which we did not receive any repair + * symbol in the same source block. Notably this happens when: + * + * - We receive the source symbols first and the repair symbols after + * - We received only source symbols for a given block. In that case it does + * not make any sense to build the source block, since we received all the + * source packet of the block. + */ + std::unordered_map<uint32_t, std::vector<std::pair<buffer, uint32_t>>> + parked_packets_; +}; + +} // namespace fec + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/core/tcp_socket_connector.h b/libtransport/src/core/tcp_socket_connector.h index c57123e9f..9dbd250d1 100644 --- a/libtransport/src/core/tcp_socket_connector.h +++ b/libtransport/src/core/tcp_socket_connector.h @@ -15,12 +15,11 @@ #pragma once +#include <core/connector.h> #include <hicn/transport/config.h> #include <hicn/transport/core/name.h> #include <hicn/transport/utils/branch_prediction.h> -#include <core/connector.h> - #include <asio.hpp> #include <asio/steady_timer.hpp> #include <deque> diff --git a/libtransport/src/http/client_connection.cc b/libtransport/src/http/client_connection.cc index 7a3a636fe..a24a821e7 100644 --- a/libtransport/src/http/client_connection.cc +++ b/libtransport/src/http/client_connection.cc @@ -43,12 +43,6 @@ class HTTPClientConnection::Implementation read_buffer_(nullptr), response_(std::make_shared<HTTPResponse>()), timer_(nullptr) { - consumer_.setSocketOption( - ConsumerCallbacksOptions::CONTENT_OBJECT_TO_VERIFY, - (ConsumerContentObjectVerificationCallback)std::bind( - &Implementation::verifyData, this, std::placeholders::_1, - std::placeholders::_2)); - consumer_.setSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, this); consumer_.connect(); @@ -124,10 +118,10 @@ class HTTPClientConnection::Implementation return *http_client_; } - HTTPClientConnection &setCertificate(const std::string &cert_path) { - if (consumer_.setSocketOption(GeneralTransportOptions::CERTIFICATE, - cert_path) == SOCKET_OPTION_NOT_SET) { - throw errors::RuntimeException("Error setting the certificate."); + HTTPClientConnection &setVerifier(std::shared_ptr<auth::Verifier> verifier) { + if (consumer_.setSocketOption(GeneralTransportOptions::VERIFIER, + verifier) == SOCKET_OPTION_NOT_SET) { + throw errors::RuntimeException("Error setting the verifier."); } return *http_client_; @@ -177,17 +171,6 @@ class HTTPClientConnection::Implementation consumer_.stop(); } - bool verifyData(interface::ConsumerSocket &c, - const core::ContentObject &contentObject) { - if (contentObject.getPayloadType() == PayloadType::CONTENT_OBJECT) { - TRANSPORT_LOGI("VERIFY CONTENT\n"); - } else if (contentObject.getPayloadType() == PayloadType::MANIFEST) { - TRANSPORT_LOGI("VERIFY MANIFEST\n"); - } - - return true; - } - void processLeavingInterest(interface::ConsumerSocket &c, const core::Interest &interest) { if (interest.payloadSize() == 0) { @@ -307,9 +290,9 @@ HTTPClientConnection &HTTPClientConnection::setTimeout( return implementation_->setTimeout(timeout); } -HTTPClientConnection &HTTPClientConnection::setCertificate( - const std::string &cert_path) { - return implementation_->setCertificate(cert_path); +HTTPClientConnection &HTTPClientConnection::setVerifier( + std::shared_ptr<auth::Verifier> verifier) { + return implementation_->setVerifier(verifier); } } // namespace http diff --git a/libtransport/src/implementation/CMakeLists.txt b/libtransport/src/implementation/CMakeLists.txt index 5423a7697..392c99e15 100644 --- a/libtransport/src/implementation/CMakeLists.txt +++ b/libtransport/src/implementation/CMakeLists.txt @@ -13,13 +13,8 @@ cmake_minimum_required(VERSION 3.5 FATAL_ERROR) -list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_socket_producer.cc -) - list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/socket.h - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/socket_consumer.h ) @@ -27,15 +22,16 @@ list(APPEND HEADER_FILES if (${OPENSSL_VERSION} VERSION_EQUAL "1.1.1a" OR ${OPENSSL_VERSION} VERSION_GREATER "1.1.1a") list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc + # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.cc ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/socket.cc ) list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.h - ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h + # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.h ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.h diff --git a/libtransport/src/implementation/p2psecure_socket_consumer.cc b/libtransport/src/implementation/p2psecure_socket_consumer.cc index 9b79850d6..8c7c175b2 100644 --- a/libtransport/src/implementation/p2psecure_socket_consumer.cc +++ b/libtransport/src/implementation/p2psecure_socket_consumer.cc @@ -15,7 +15,6 @@ #include <implementation/p2psecure_socket_consumer.h> #include <interfaces/tls_socket_consumer.h> - #include <openssl/bio.h> #include <openssl/ssl.h> #include <openssl/tls1.h> @@ -175,7 +174,6 @@ P2PSecureConsumerSocket::P2PSecureConsumerSocket( : ConsumerSocket(consumer, handshake_protocol), name_(), tls_consumer_(nullptr), - buf_pool_(), decrypted_content_(), payload_(), head_(), diff --git a/libtransport/src/implementation/p2psecure_socket_consumer.h b/libtransport/src/implementation/p2psecure_socket_consumer.h index d4c3b26c2..a35a50352 100644 --- a/libtransport/src/implementation/p2psecure_socket_consumer.h +++ b/libtransport/src/implementation/p2psecure_socket_consumer.h @@ -16,7 +16,6 @@ #pragma once #include <hicn/transport/interfaces/socket_consumer.h> - #include <implementation/tls_socket_consumer.h> #include <openssl/bio.h> #include <openssl/ssl.h> @@ -75,7 +74,6 @@ class P2PSecureConsumerSocket : public ConsumerSocket, BIO_METHOD *bio_meth_; /* Chain of MemBuf to be used as a temporary buffer to pass descypted data * from the underlying layer to the application */ - utils::ObjectPool<utils::MemBuf> buf_pool_; std::unique_ptr<utils::MemBuf> decrypted_content_; /* Chain of MemBuf holding the payload to be written into interest or data */ std::unique_ptr<utils::MemBuf> payload_; diff --git a/libtransport/src/implementation/p2psecure_socket_producer.cc b/libtransport/src/implementation/p2psecure_socket_producer.cc index 15c7d25cd..6dff2ba08 100644 --- a/libtransport/src/implementation/p2psecure_socket_producer.cc +++ b/libtransport/src/implementation/p2psecure_socket_producer.cc @@ -14,13 +14,11 @@ */ #include <hicn/transport/core/interest.h> - #include <implementation/p2psecure_socket_producer.h> -#include <implementation/tls_rtc_socket_producer.h> +// #include <implementation/tls_rtc_socket_producer.h> #include <implementation/tls_socket_producer.h> #include <interfaces/tls_rtc_socket_producer.h> #include <interfaces/tls_socket_producer.h> - #include <openssl/bio.h> #include <openssl/rand.h> #include <openssl/ssl.h> @@ -34,7 +32,8 @@ namespace implementation { P2PSecureProducerSocket::P2PSecureProducerSocket( interface::ProducerSocket *producer_socket) - : ProducerSocket(producer_socket), + : ProducerSocket(producer_socket, + ProductionProtocolAlgorithms::BYTE_STREAM), mtx_(), cv_(), map_producers(), @@ -42,8 +41,9 @@ P2PSecureProducerSocket::P2PSecureProducerSocket( P2PSecureProducerSocket::P2PSecureProducerSocket( interface::ProducerSocket *producer_socket, bool rtc, - const std::shared_ptr<utils::Identity> &identity) - : ProducerSocket(producer_socket), + const std::shared_ptr<auth::Identity> &identity) + : ProducerSocket(producer_socket, + ProductionProtocolAlgorithms::BYTE_STREAM), rtc_(rtc), mtx_(), cv_(), @@ -51,9 +51,9 @@ P2PSecureProducerSocket::P2PSecureProducerSocket( list_producers() { /* Setup SSL context (identity and parameter to use TLS 1.3) */ der_cert_ = parcKeyStore_GetDEREncodedCertificate( - (identity->getSigner()->getKeyStore())); + (identity->getSigner()->getParcKeyStore())); der_prk_ = parcKeyStore_GetDEREncodedPrivateKey( - (identity->getSigner()->getKeyStore())); + (identity->getSigner()->getParcKeyStore())); int cert_size = parcBuffer_Limit(der_cert_); int prk_size = parcBuffer_Limit(der_prk_); @@ -88,15 +88,20 @@ void P2PSecureProducerSocket::initSessionSocket( producer->setSocketOption(MAKE_MANIFEST, this->making_manifest_); producer->setSocketOption(DATA_PACKET_SIZE, (uint32_t)(this->data_packet_size_)); - producer->output_buffer_.setLimit(this->output_buffer_.getLimit()); + uint32_t output_buffer_size = 0; + this->getSocketOption(GeneralTransportOptions::OUTPUT_BUFFER_SIZE, + output_buffer_size); + producer->setSocketOption(GeneralTransportOptions::OUTPUT_BUFFER_SIZE, + output_buffer_size); if (!rtc_) { producer->setInterface(new interface::TLSProducerSocket(producer.get())); } else { - TLSRTCProducerSocket *rtc_producer = - dynamic_cast<TLSRTCProducerSocket *>(producer.get()); - rtc_producer->setInterface( - new interface::TLSRTCProducerSocket(rtc_producer)); + // TODO + // TLSRTCProducerSocket *rtc_producer = + // dynamic_cast<TLSRTCProducerSocket *>(producer.get()); + // rtc_producer->setInterface( + // new interface::TLSRTCProducerSocket(rtc_producer)); } } @@ -114,8 +119,9 @@ void P2PSecureProducerSocket::onInterestCallback(interface::ProducerSocket &p, tls_producer = std::make_unique<TLSProducerSocket>(nullptr, this, interest.getName()); } else { - tls_producer = std::make_unique<TLSRTCProducerSocket>(nullptr, this, - interest.getName()); + // TODO + // tls_producer = std::make_unique<TLSRTCProducerSocket>(nullptr, this, + // interest.getName()); } initSessionSocket(tls_producer); @@ -129,15 +135,19 @@ void P2PSecureProducerSocket::onInterestCallback(interface::ProducerSocket &p, tls_producer_ptr->onInterest(*tls_producer_ptr, interest); tls_producer_ptr->async_accept(); } else { - TLSRTCProducerSocket *rtc_producer_ptr = - dynamic_cast<TLSRTCProducerSocket *>(tls_producer_ptr); - rtc_producer_ptr->onInterest(*rtc_producer_ptr, interest); - rtc_producer_ptr->async_accept(); + // TODO + // TLSRTCProducerSocket *rtc_producer_ptr = + // dynamic_cast<TLSRTCProducerSocket *>(tls_producer_ptr); + // rtc_producer_ptr->onInterest(*rtc_producer_ptr, interest); + // rtc_producer_ptr->async_accept(); } } -void P2PSecureProducerSocket::produce(const uint8_t *buffer, - size_t buffer_size) { +uint32_t P2PSecureProducerSocket::produceDatagram( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer) { + // TODO + throw errors::NotImplementedException(); + if (!rtc_) { throw errors::RuntimeException( "RTC must be the transport protocol to start the production of current " @@ -148,16 +158,20 @@ void P2PSecureProducerSocket::produce(const uint8_t *buffer, if (list_producers.empty()) cv_.wait(lck); - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) { - TLSRTCProducerSocket *rtc_producer = - dynamic_cast<TLSRTCProducerSocket *>(it->get()); - rtc_producer->produce(utils::MemBuf::copyBuffer(buffer, buffer_size)); - } + // TODO + // for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) + // { + // TLSRTCProducerSocket *rtc_producer = + // dynamic_cast<TLSRTCProducerSocket *>(it->get()); + // rtc_producer->produce(utils::MemBuf::copyBuffer(buffer, buffer_size)); + // } + + return 0; } -uint32_t P2PSecureProducerSocket::produce( - Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, bool is_last, - uint32_t start_offset) { +uint32_t P2PSecureProducerSocket::produceStream( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { if (rtc_) { throw errors::RuntimeException( "RTC transport protocol is not compatible with the production of " @@ -170,16 +184,17 @@ uint32_t P2PSecureProducerSocket::produce( if (list_producers.empty()) cv_.wait(lck); for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - segments += - (*it)->produce(content_name, buffer->clone(), is_last, start_offset); + segments += (*it)->produceStream(content_name, buffer->clone(), is_last, + start_offset); return segments; } -uint32_t P2PSecureProducerSocket::produce(Name content_name, - const uint8_t *buffer, - size_t buffer_size, bool is_last, - uint32_t start_offset) { +uint32_t P2PSecureProducerSocket::produceStream(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size, + bool is_last, + uint32_t start_offset) { if (rtc_) { throw errors::RuntimeException( "RTC transport protocol is not compatible with the production of " @@ -191,29 +206,31 @@ uint32_t P2PSecureProducerSocket::produce(Name content_name, if (list_producers.empty()) cv_.wait(lck); for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - segments += (*it)->produce(content_name, buffer, buffer_size, is_last, - start_offset); + segments += (*it)->produceStream(content_name, buffer, buffer_size, is_last, + start_offset); return segments; } -void P2PSecureProducerSocket::asyncProduce(const Name &content_name, - const uint8_t *buf, - size_t buffer_size, bool is_last, - uint32_t *start_offset) { - if (rtc_) { - throw errors::RuntimeException( - "RTC transport protocol is not compatible with the production of " - "current data. Aborting."); - } - - std::unique_lock<std::mutex> lck(mtx_); - if (list_producers.empty()) cv_.wait(lck); - - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) { - (*it)->asyncProduce(content_name, buf, buffer_size, is_last, start_offset); - } -} +// void P2PSecureProducerSocket::asyncProduce(const Name &content_name, +// const uint8_t *buf, +// size_t buffer_size, bool is_last, +// uint32_t *start_offset) { +// if (rtc_) { +// throw errors::RuntimeException( +// "RTC transport protocol is not compatible with the production of " +// "current data. Aborting."); +// } + +// std::unique_lock<std::mutex> lck(mtx_); +// if (list_producers.empty()) cv_.wait(lck); + +// for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) +// { +// (*it)->asyncProduce(content_name, buf, buffer_size, is_last, +// start_offset); +// } +// } void P2PSecureProducerSocket::asyncProduce( Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, bool is_last, @@ -269,7 +286,7 @@ int P2PSecureProducerSocket::setSocketOption( int P2PSecureProducerSocket::setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Signer> &socket_option_value) { + const std::shared_ptr<auth::Signer> &socket_option_value) { if (!list_producers.empty()) for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) (*it)->setSocketOption(socket_option_key, socket_option_value); @@ -323,16 +340,6 @@ int P2PSecureProducerSocket::setSocketOption(int socket_option_key, } int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, std::list<Prefix> socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - return ProducerSocket::setSocketOption(socket_option_key, - socket_option_value); -} - -int P2PSecureProducerSocket::setSocketOption( int socket_option_key, ProducerContentObjectCallback socket_option_value) { if (!list_producers.empty()) for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) @@ -361,17 +368,7 @@ int P2PSecureProducerSocket::setSocketOption( } int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, utils::CryptoHashType socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - return ProducerSocket::setSocketOption(socket_option_key, - socket_option_value); -} - -int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, utils::CryptoSuite socket_option_value) { + int socket_option_key, auth::CryptoHashType socket_option_value) { if (!list_producers.empty()) for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) (*it)->setSocketOption(socket_option_key, socket_option_value); diff --git a/libtransport/src/implementation/p2psecure_socket_producer.h b/libtransport/src/implementation/p2psecure_socket_producer.h index bfc9fc2c1..b7c3d1958 100644 --- a/libtransport/src/implementation/p2psecure_socket_producer.h +++ b/libtransport/src/implementation/p2psecure_socket_producer.h @@ -15,15 +15,14 @@ #pragma once -#include <hicn/transport/security/identity.h> -#include <hicn/transport/security/signer.h> - +#include <hicn/transport/auth/identity.h> +#include <hicn/transport/auth/signer.h> #include <implementation/socket_producer.h> -#include <implementation/tls_rtc_socket_producer.h> +// #include <implementation/tls_rtc_socket_producer.h> #include <implementation/tls_socket_producer.h> +#include <openssl/ssl.h> #include <utils/content_store.h> -#include <openssl/ssl.h> #include <condition_variable> #include <forward_list> #include <mutex> @@ -33,39 +32,40 @@ namespace implementation { class P2PSecureProducerSocket : public ProducerSocket { friend class TLSProducerSocket; - friend class TLSRTCProducerSocket; + // TODO + // friend class TLSRTCProducerSocket; public: explicit P2PSecureProducerSocket(interface::ProducerSocket *producer_socket); explicit P2PSecureProducerSocket( interface::ProducerSocket *producer_socket, bool rtc, - const std::shared_ptr<utils::Identity> &identity); + const std::shared_ptr<auth::Identity> &identity); ~P2PSecureProducerSocket(); - void produce(const uint8_t *buffer, size_t buffer_size) override; + uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) override; - uint32_t produce(Name content_name, const uint8_t *buffer, size_t buffer_size, - bool is_last = true, uint32_t start_offset = 0) override; + uint32_t produceStream(const Name &content_name, const uint8_t *buffer, + size_t buffer_size, bool is_last = true, + uint32_t start_offset = 0) override; - uint32_t produce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last = true, uint32_t start_offset = 0) override; + uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) override; void asyncProduce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, bool is_last, uint32_t offset, uint32_t **last_segment = nullptr) override; - void asyncProduce(const Name &suffix, const uint8_t *buf, size_t buffer_size, - bool is_last = true, - uint32_t *start_offset = nullptr) override; - int setSocketOption(int socket_option_key, ProducerInterestCallback socket_option_value) override; int setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Signer> &socket_option_value) override; + const std::shared_ptr<auth::Signer> &socket_option_value) override; int setSocketOption(int socket_option_key, uint32_t socket_option_value) override; @@ -75,9 +75,6 @@ class P2PSecureProducerSocket : public ProducerSocket { int setSocketOption(int socket_option_key, Name *socket_option_value) override; - int setSocketOption(int socket_option_key, - std::list<Prefix> socket_option_value) override; - int setSocketOption( int socket_option_key, ProducerContentObjectCallback socket_option_value) override; @@ -86,16 +83,13 @@ class P2PSecureProducerSocket : public ProducerSocket { ProducerContentCallback socket_option_value) override; int setSocketOption(int socket_option_key, - utils::CryptoHashType socket_option_value) override; - - int setSocketOption(int socket_option_key, - utils::CryptoSuite socket_option_value) override; + auth::CryptoHashType socket_option_value) override; int setSocketOption(int socket_option_key, const std::string &socket_option_value) override; using ProducerSocket::getSocketOption; - using ProducerSocket::onInterest; + // using ProducerSocket::onInterest; protected: /* Callback invoked once an interest has been received and its payload diff --git a/libtransport/src/implementation/socket.cc b/libtransport/src/implementation/socket.cc new file mode 100644 index 000000000..2e21f2bc3 --- /dev/null +++ b/libtransport/src/implementation/socket.cc @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/global_configuration.h> +#include <implementation/socket.h> + +namespace transport { +namespace implementation { + +Socket::Socket(std::shared_ptr<core::Portal> &&portal) + : portal_(std::move(portal)), is_async_(false) {} + +} // namespace implementation +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/implementation/socket.h b/libtransport/src/implementation/socket.h index 2e51f3027..cf22c03e1 100644 --- a/libtransport/src/implementation/socket.h +++ b/libtransport/src/implementation/socket.h @@ -15,13 +15,12 @@ #pragma once +#include <core/facade.h> #include <hicn/transport/config.h> #include <hicn/transport/interfaces/callbacks.h> #include <hicn/transport/interfaces/socket_options_default_values.h> #include <hicn/transport/interfaces/socket_options_keys.h> -#include <core/facade.h> - #define SOCKET_OPTION_GET 0 #define SOCKET_OPTION_NOT_GET 1 #define SOCKET_OPTION_SET 2 @@ -32,56 +31,23 @@ namespace transport { namespace implementation { // Forward Declarations -template <typename PortalType> class Socket; -// Define the portal and its connector, depending on the compilation options -// passed by the build tool. -using HicnForwarderPortal = core::HicnForwarderPortal; - -#ifdef __linux__ -#ifndef __ANDROID__ -using RawSocketPortal = core::RawSocketPortal; -#endif -#endif - -#ifdef __vpp__ -using VPPForwarderPortal = core::VPPForwarderPortal; -using BaseSocket = Socket<VPPForwarderPortal>; -using BasePortal = VPPForwarderPortal; -#else -using BaseSocket = Socket<HicnForwarderPortal>; -using BasePortal = HicnForwarderPortal; -#endif - -template <typename PortalType> class Socket { - static_assert(std::is_same<PortalType, HicnForwarderPortal>::value -#ifdef __linux__ -#ifndef __ANDROID__ - || std::is_same<PortalType, RawSocketPortal>::value -#ifdef __vpp__ - || std::is_same<PortalType, VPPForwarderPortal>::value -#endif -#endif - , -#else - , - -#endif - "This class is not allowed as Portal"); - public: - using Portal = PortalType; - - virtual asio::io_service &getIoService() = 0; - virtual void connect() = 0; - virtual bool isRunning() = 0; + virtual asio::io_service &getIoService() { return portal_->getIoService(); } + protected: + Socket(std::shared_ptr<core::Portal> &&portal); + virtual ~Socket(){}; + + protected: + std::shared_ptr<core::Portal> portal_; + bool is_async_; }; } // namespace implementation diff --git a/libtransport/src/implementation/socket_consumer.h b/libtransport/src/implementation/socket_consumer.h index 87965923e..a7b6ac4e7 100644 --- a/libtransport/src/implementation/socket_consumer.h +++ b/libtransport/src/implementation/socket_consumer.h @@ -13,15 +13,17 @@ * limitations under the License. */ +#pragma once + #include <hicn/transport/interfaces/socket_consumer.h> #include <hicn/transport/interfaces/socket_options_default_values.h> #include <hicn/transport/interfaces/statistics.h> -#include <hicn/transport/security/verifier.h> +#include <hicn/transport/auth/verifier.h> #include <hicn/transport/utils/event_thread.h> #include <protocols/cbr.h> -#include <protocols/protocol.h> #include <protocols/raaqm.h> -#include <protocols/rtc.h> +#include <protocols/rtc/rtc.h> +#include <protocols/transport_protocol.h> namespace transport { namespace implementation { @@ -30,12 +32,12 @@ using namespace core; using namespace interface; using ReadCallback = interface::ConsumerSocket::ReadCallback; -class ConsumerSocket : public Socket<BasePortal> { +class ConsumerSocket : public Socket { private: ConsumerSocket(interface::ConsumerSocket *consumer, int protocol, - std::shared_ptr<Portal> &&portal) - : consumer_interface_(consumer), - portal_(portal), + std::shared_ptr<core::Portal> &&portal) + : Socket(std::move(portal)), + consumer_interface_(consumer), async_downloader_(), interest_lifetime_(default_values::interest_lifetime), min_window_size_(default_values::min_window_size), @@ -54,16 +56,13 @@ class ConsumerSocket : public Socket<BasePortal> { rate_estimation_observer_(nullptr), rate_estimation_batching_parameter_(default_values::batch), rate_estimation_choice_(0), - is_async_(false), - verifier_(std::make_shared<utils::Verifier>()), + verifier_(std::make_shared<auth::VoidVerifier>()), verify_signature_(false), - key_content_(false), reset_window_(false), on_interest_output_(VOID_HANDLER), on_interest_timeout_(VOID_HANDLER), on_interest_satisfied_(VOID_HANDLER), on_content_object_input_(VOID_HANDLER), - on_content_object_verification_(VOID_HANDLER), stats_summary_(VOID_HANDLER), read_callback_(nullptr), timer_interval_milliseconds_(0), @@ -75,7 +74,7 @@ class ConsumerSocket : public Socket<BasePortal> { break; case TransportProtocolAlgorithms::RTC: transport_protocol_ = - std::make_unique<protocol::RTCTransportProtocol>(this); + std::make_unique<protocol::rtc::RTCTransportProtocol>(this); break; case TransportProtocolAlgorithms::RAAQM: default: @@ -87,12 +86,12 @@ class ConsumerSocket : public Socket<BasePortal> { public: ConsumerSocket(interface::ConsumerSocket *consumer, int protocol) - : ConsumerSocket(consumer, protocol, std::make_shared<Portal>()) {} + : ConsumerSocket(consumer, protocol, std::make_shared<core::Portal>()) {} ConsumerSocket(interface::ConsumerSocket *consumer, int protocol, asio::io_service &io_service) : ConsumerSocket(consumer, protocol, - std::make_shared<Portal>(io_service)) { + std::make_shared<core::Portal>(io_service)) { is_async_ = true; } @@ -138,8 +137,6 @@ class ConsumerSocket : public Socket<BasePortal> { return CONSUMER_RUNNING; } - bool verifyKeyPackets() { return transport_protocol_->verifyKeyPackets(); } - void stop() { if (transport_protocol_->isRunning()) { transport_protocol_->stop(); @@ -152,8 +149,6 @@ class ConsumerSocket : public Socket<BasePortal> { } } - asio::io_service &getIoService() { return portal_->getIoService(); } - virtual int setSocketOption(int socket_option_key, ReadCallback *socket_option_value) { // Reschedule the function on the io_service to avoid race condition in @@ -316,12 +311,6 @@ class ConsumerSocket : public Socket<BasePortal> { break; } - case ConsumerCallbacksOptions::CONTENT_OBJECT_TO_VERIFY: - if (socket_option_value == VOID_HANDLER) { - on_content_object_verification_ = VOID_HANDLER; - break; - } - default: return SOCKET_OPTION_NOT_SET; } @@ -334,16 +323,6 @@ class ConsumerSocket : public Socket<BasePortal> { int result = SOCKET_OPTION_NOT_SET; if (!transport_protocol_->isRunning()) { switch (socket_option_key) { - case GeneralTransportOptions::VERIFY_SIGNATURE: - verify_signature_ = socket_option_value; - result = SOCKET_OPTION_SET; - break; - - case GeneralTransportOptions::KEY_CONTENT: - key_content_ = socket_option_value; - result = SOCKET_OPTION_SET; - break; - case RaaqmTransportOptions::PER_SESSION_CWINDOW_RESET: reset_window_ = socket_option_value; result = SOCKET_OPTION_SET; @@ -377,29 +356,6 @@ class ConsumerSocket : public Socket<BasePortal> { }); } - int setSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationCallback socket_option_value) { - // Reschedule the function on the io_service to avoid race condition in - // case setSocketOption is called while the io_service is running. - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ConsumerContentObjectVerificationCallback socket_option_value) - -> int { - switch (socket_option_key) { - case ConsumerCallbacksOptions::CONTENT_OBJECT_TO_VERIFY: - on_content_object_verification_ = socket_option_value; - break; - - default: - return SOCKET_OPTION_NOT_SET; - } - - return SOCKET_OPTION_SET; - }); - } - int setSocketOption(int socket_option_key, ConsumerInterestCallback socket_option_value) { // Reschedule the function on the io_service to avoid race condition in @@ -433,51 +389,6 @@ class ConsumerSocket : public Socket<BasePortal> { }); } - int setSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationFailedCallback socket_option_value) { - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this]( - int socket_option_key, - ConsumerContentObjectVerificationFailedCallback socket_option_value) - -> int { - switch (socket_option_key) { - case ConsumerCallbacksOptions::VERIFICATION_FAILED: - verification_failed_callback_ = socket_option_value; - break; - - default: - return SOCKET_OPTION_NOT_SET; - } - - return SOCKET_OPTION_SET; - }); - } - - // int setSocketOption( - // int socket_option_key, - // ConsumerContentObjectVerificationFailedCallback socket_option_value) { - // return rescheduleOnIOService( - // socket_option_key, socket_option_value, - // [this]( - // int socket_option_key, - // ConsumerContentObjectVerificationFailedCallback - // socket_option_value) - // -> int { - // switch (socket_option_key) { - // case ConsumerCallbacksOptions::VERIFICATION_FAILED: - // verification_failed_callback_ = socket_option_value; - // break; - - // default: - // return SOCKET_OPTION_NOT_SET; - // } - - // return SOCKET_OPTION_SET; - // }); - // } - int setSocketOption(int socket_option_key, IcnObserver *socket_option_value) { utils::SpinLock::Acquire locked(guard_raaqm_params_); switch (socket_option_key) { @@ -494,7 +405,7 @@ class ConsumerSocket : public Socket<BasePortal> { int setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Verifier> &socket_option_value) { + const std::shared_ptr<auth::Verifier> &socket_option_value) { int result = SOCKET_OPTION_NOT_SET; if (!transport_protocol_->isRunning()) { switch (socket_option_key) { @@ -516,14 +427,6 @@ class ConsumerSocket : public Socket<BasePortal> { int result = SOCKET_OPTION_NOT_SET; if (!transport_protocol_->isRunning()) { switch (socket_option_key) { - case GeneralTransportOptions::CERTIFICATE: - key_id_ = verifier_->addKeyFromCertificate(socket_option_value); - - if (key_id_ != nullptr) { - result = SOCKET_OPTION_SET; - } - break; - case DataLinkOptions::OUTPUT_INTERFACE: output_interface_ = socket_option_value; portal_->setOutputInterface(output_interface_); @@ -642,14 +545,6 @@ class ConsumerSocket : public Socket<BasePortal> { socket_option_value = transport_protocol_->isRunning(); break; - case GeneralTransportOptions::VERIFY_SIGNATURE: - socket_option_value = verify_signature_; - break; - - case GeneralTransportOptions::KEY_CONTENT: - socket_option_value = key_content_; - break; - case GeneralTransportOptions::ASYNC_MODE: socket_option_value = is_async_; break; @@ -699,29 +594,6 @@ class ConsumerSocket : public Socket<BasePortal> { }); } - int getSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationCallback **socket_option_value) { - // Reschedule the function on the io_service to avoid race condition in - // case setSocketOption is called while the io_service is running. - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ConsumerContentObjectVerificationCallback **socket_option_value) - -> int { - switch (socket_option_key) { - case ConsumerCallbacksOptions::CONTENT_OBJECT_TO_VERIFY: - *socket_option_value = &on_content_object_verification_; - break; - - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - }); - } - int getSocketOption(int socket_option_key, ConsumerInterestCallback **socket_option_value) { // Reschedule the function on the io_service to avoid race condition in @@ -755,30 +627,8 @@ class ConsumerSocket : public Socket<BasePortal> { }); } - int getSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationFailedCallback **socket_option_value) { - // Reschedule the function on the io_service to avoid race condition in - // case setSocketOption is called while the io_service is running. - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ConsumerContentObjectVerificationFailedCallback * - *socket_option_value) -> int { - switch (socket_option_key) { - case ConsumerCallbacksOptions::VERIFICATION_FAILED: - *socket_option_value = &verification_failed_callback_; - break; - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - }); - } - int getSocketOption(int socket_option_key, - std::shared_ptr<Portal> &socket_option_value) { + std::shared_ptr<core::Portal> &socket_option_value) { switch (socket_option_key) { case PORTAL: socket_option_value = portal_; @@ -807,7 +657,7 @@ class ConsumerSocket : public Socket<BasePortal> { } int getSocketOption(int socket_option_key, - std::shared_ptr<utils::Verifier> &socket_option_value) { + std::shared_ptr<auth::Verifier> &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::VERIFIER: socket_option_value = verifier_; @@ -871,7 +721,7 @@ class ConsumerSocket : public Socket<BasePortal> { // To enforce type check std::function<int(int, arg2)> func = lambda_func; int result = SOCKET_OPTION_SET; - if (transport_protocol_->isRunning()) { + if (transport_protocol_ && transport_protocol_->isRunning()) { std::mutex mtx; /* Condition variable for the wait */ std::condition_variable cv; @@ -898,7 +748,6 @@ class ConsumerSocket : public Socket<BasePortal> { protected: interface::ConsumerSocket *consumer_interface_; - std::shared_ptr<Portal> portal_; utils::EventThread async_downloader_; // No need to protect from multiple accesses in the async consumer @@ -926,13 +775,10 @@ class ConsumerSocket : public Socket<BasePortal> { int rate_estimation_batching_parameter_; int rate_estimation_choice_; - bool is_async_; - // Verification parameters - std::shared_ptr<utils::Verifier> verifier_; + std::shared_ptr<auth::Verifier> verifier_; PARCKeyId *key_id_; std::atomic_bool verify_signature_; - bool key_content_; bool reset_window_; ConsumerInterestCallback on_interest_retransmission_; @@ -940,9 +786,7 @@ class ConsumerSocket : public Socket<BasePortal> { ConsumerInterestCallback on_interest_timeout_; ConsumerInterestCallback on_interest_satisfied_; ConsumerContentObjectCallback on_content_object_input_; - ConsumerContentObjectVerificationCallback on_content_object_verification_; ConsumerTimerCallback stats_summary_; - ConsumerContentObjectVerificationFailedCallback verification_failed_callback_; ReadCallback *read_callback_; @@ -959,4 +803,4 @@ class ConsumerSocket : public Socket<BasePortal> { }; } // namespace implementation -} // namespace transport
\ No newline at end of file +} // namespace transport diff --git a/libtransport/src/implementation/socket_producer.h b/libtransport/src/implementation/socket_producer.h index a6f0f969e..af69cd818 100644 --- a/libtransport/src/implementation/socket_producer.h +++ b/libtransport/src/implementation/socket_producer.h @@ -15,9 +15,11 @@ #pragma once -#include <hicn/transport/security/signer.h> +#include <hicn/transport/auth/signer.h> #include <hicn/transport/utils/event_thread.h> #include <implementation/socket.h> +#include <protocols/prod_protocol_bytestream.h> +#include <protocols/prod_protocol_rtc.h> #include <utils/content_store.h> #include <utils/suffix_strategy.h> @@ -39,21 +41,17 @@ namespace implementation { using namespace core; using namespace interface; -class ProducerSocket : public Socket<BasePortal>, - public BasePortal::ProducerCallback { - static constexpr uint32_t burst_size = 256; - - public: - explicit ProducerSocket(interface::ProducerSocket *producer_socket) - : producer_interface_(producer_socket), - portal_(std::make_shared<Portal>(io_service_)), +class ProducerSocket : public Socket { + private: + ProducerSocket(interface::ProducerSocket *producer_socket, int protocol, + std::shared_ptr<core::Portal> &&portal) + : Socket(std::move(portal)), + producer_interface_(producer_socket), data_packet_size_(default_values::content_object_packet_size), content_object_expiry_time_(default_values::content_object_expiry_time), - output_buffer_(default_values::producer_socket_output_buffer_size), async_thread_(), - registration_status_(REGISTRATION_NOT_ATTEMPTED), making_manifest_(false), - hash_algorithm_(utils::CryptoHashType::SHA_256), + hash_algorithm_(auth::CryptoHashType::SHA_256), suffix_strategy_(core::NextSegmentCalculationStrategy::INCREMENTAL), on_interest_input_(VOID_HANDLER), on_interest_dropped_input_buffer_(VOID_HANDLER), @@ -65,15 +63,33 @@ class ProducerSocket : public Socket<BasePortal>, on_content_object_in_output_buffer_(VOID_HANDLER), on_content_object_output_(VOID_HANDLER), on_content_object_evicted_from_output_buffer_(VOID_HANDLER), - on_content_produced_(VOID_HANDLER) {} - - virtual ~ProducerSocket() { - stop(); - if (listening_thread_.joinable()) { - listening_thread_.join(); + on_content_produced_(VOID_HANDLER) { + switch (protocol) { + case ProductionProtocolAlgorithms::RTC_PROD: + production_protocol_ = + std::make_unique<protocol::RTCProductionProtocol>(this); + break; + case ProductionProtocolAlgorithms::BYTE_STREAM: + default: + production_protocol_ = + std::make_unique<protocol::ByteStreamProductionProtocol>(this); + break; } } + public: + ProducerSocket(interface::ProducerSocket *producer, int protocol) + : ProducerSocket(producer, protocol, std::make_shared<core::Portal>()) {} + + ProducerSocket(interface::ProducerSocket *producer, int protocol, + asio::io_service &io_service) + : ProducerSocket(producer, protocol, + std::make_shared<core::Portal>(io_service)) { + is_async_ = true; + } + + virtual ~ProducerSocket() {} + interface::ProducerSocket *getInterface() { return producer_interface_; } @@ -84,296 +100,10 @@ class ProducerSocket : public Socket<BasePortal>, void connect() override { portal_->connect(false); - listening_thread_ = std::thread(std::bind(&ProducerSocket::listen, this)); + production_protocol_->start(); } - bool isRunning() override { return !io_service_.stopped(); }; - - virtual uint32_t produce(Name content_name, const uint8_t *buffer, - size_t buffer_size, bool is_last = true, - uint32_t start_offset = 0) { - return ProducerSocket::produce( - content_name, utils::MemBuf::copyBuffer(buffer, buffer_size), is_last, - start_offset); - } - - virtual uint32_t produce(Name content_name, - std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last = true, uint32_t start_offset = 0) { - if (TRANSPORT_EXPECT_FALSE(buffer->length() == 0)) { - return 0; - } - - // Copy the atomic variables to ensure they keep the same value - // during the production - std::size_t data_packet_size = data_packet_size_; - uint32_t content_object_expiry_time = content_object_expiry_time_; - utils::CryptoHashType hash_algo = hash_algorithm_; - bool making_manifest = making_manifest_; - auto suffix_strategy = utils::SuffixStrategyFactory::getSuffixStrategy( - suffix_strategy_, start_offset); - std::shared_ptr<utils::Signer> signer; - getSocketOption(GeneralTransportOptions::SIGNER, signer); - - auto buffer_size = buffer->length(); - int bytes_segmented = 0; - std::size_t header_size; - std::size_t manifest_header_size = 0; - std::size_t signature_length = 0; - std::uint32_t final_block_number = start_offset; - uint64_t free_space_for_content = 0; - - core::Packet::Format format; - std::shared_ptr<ContentObjectManifest> manifest; - bool is_last_manifest = false; - - // TODO Manifest may still be used for indexing - if (making_manifest && !signer) { - TRANSPORT_LOGE("Making manifests without setting producer identity."); - } - - core::Packet::Format hf_format = core::Packet::Format::HF_UNSPEC; - core::Packet::Format hf_format_ah = core::Packet::Format::HF_UNSPEC; - if (content_name.getType() == HNT_CONTIGUOUS_V4 || - content_name.getType() == HNT_IOV_V4) { - hf_format = core::Packet::Format::HF_INET_TCP; - hf_format_ah = core::Packet::Format::HF_INET_TCP_AH; - } else if (content_name.getType() == HNT_CONTIGUOUS_V6 || - content_name.getType() == HNT_IOV_V6) { - hf_format = core::Packet::Format::HF_INET6_TCP; - hf_format_ah = core::Packet::Format::HF_INET6_TCP_AH; - } else { - throw errors::RuntimeException("Unknown name format."); - } - - format = hf_format; - if (making_manifest) { - manifest_header_size = core::Packet::getHeaderSizeFromFormat( - signer ? hf_format_ah : hf_format, - signer ? signer->getSignatureLength() : 0); - } else if (signer) { - format = hf_format_ah; - signature_length = signer->getSignatureLength(); - } - - header_size = - core::Packet::getHeaderSizeFromFormat(format, signature_length); - free_space_for_content = data_packet_size - header_size; - uint32_t number_of_segments = uint32_t( - std::ceil(double(buffer_size) / double(free_space_for_content))); - if (free_space_for_content * number_of_segments < buffer_size) { - number_of_segments++; - } - - // TODO allocate space for all the headers - if (making_manifest) { - uint32_t segment_in_manifest = static_cast<uint32_t>( - std::floor(double(data_packet_size - manifest_header_size - - ContentObjectManifest::getManifestHeaderSize()) / - ContentObjectManifest::getManifestEntrySize()) - - 1.0); - uint32_t number_of_manifests = static_cast<uint32_t>( - std::ceil(float(number_of_segments) / segment_in_manifest)); - final_block_number += number_of_segments + number_of_manifests - 1; - - manifest.reset(ContentObjectManifest::createManifest( - content_name.setSuffix(suffix_strategy->getNextManifestSuffix()), - core::ManifestVersion::VERSION_1, core::ManifestType::INLINE_MANIFEST, - hash_algo, is_last_manifest, content_name, suffix_strategy_, - signer ? signer->getSignatureLength() : 0)); - manifest->setLifetime(content_object_expiry_time); - - if (is_last) { - manifest->setFinalBlockNumber(final_block_number); - } else { - manifest->setFinalBlockNumber(utils::SuffixStrategy::INVALID_SUFFIX); - } - } - - for (unsigned int packaged_segments = 0; - packaged_segments < number_of_segments; packaged_segments++) { - if (making_manifest) { - if (manifest->estimateManifestSize(2) > - data_packet_size - manifest_header_size) { - // Send the current manifest - manifest->encode(); - - // If identity set, sign manifest - if (signer) { - signer->sign(*manifest); - } - - passContentObjectToCallbacks(manifest); - TRANSPORT_LOGD("Send manifest %s", - manifest->getName().toString().c_str()); - - // Send content objects stored in the queue - while (!content_queue_.empty()) { - passContentObjectToCallbacks(content_queue_.front()); - TRANSPORT_LOGD( - "Send content %s", - content_queue_.front()->getName().toString().c_str()); - content_queue_.pop(); - } - - // Create new manifest. The reference to the last manifest has been - // acquired in the passContentObjectToCallbacks function, so we can - // safely release this reference - manifest.reset(ContentObjectManifest::createManifest( - content_name.setSuffix(suffix_strategy->getNextManifestSuffix()), - core::ManifestVersion::VERSION_1, - core::ManifestType::INLINE_MANIFEST, hash_algo, is_last_manifest, - content_name, suffix_strategy_, - signer ? signer->getSignatureLength() : 0)); - - manifest->setLifetime(content_object_expiry_time); - manifest->setFinalBlockNumber( - is_last ? final_block_number - : utils::SuffixStrategy::INVALID_SUFFIX); - } - } - - auto content_suffix = suffix_strategy->getNextContentSuffix(); - auto content_object = std::make_shared<ContentObject>( - content_name.setSuffix(content_suffix), format); - content_object->setLifetime(content_object_expiry_time); - - auto b = buffer->cloneOne(); - b->trimStart(free_space_for_content * packaged_segments); - b->trimEnd(b->length()); - - if (TRANSPORT_EXPECT_FALSE(packaged_segments == number_of_segments - 1)) { - b->append(buffer_size - bytes_segmented); - bytes_segmented += (int)(buffer_size - bytes_segmented); - - if (is_last && making_manifest) { - is_last_manifest = true; - } else if (is_last) { - content_object->setRst(); - } - - } else { - b->append(free_space_for_content); - bytes_segmented += (int)(free_space_for_content); - } - - content_object->appendPayload(std::move(b)); - - if (making_manifest) { - using namespace std::chrono_literals; - utils::CryptoHash hash = content_object->computeDigest(hash_algo); - manifest->addSuffixHash(content_suffix, hash); - content_queue_.push(content_object); - } else { - if (signer) { - signer->sign(*content_object); - } - passContentObjectToCallbacks(content_object); - TRANSPORT_LOGD("Send content %s", - content_object->getName().toString().c_str()); - } - } - - if (making_manifest) { - if (is_last_manifest) { - manifest->setFinalManifest(is_last_manifest); - } - - manifest->encode(); - if (signer) { - signer->sign(*manifest); - } - - passContentObjectToCallbacks(manifest); - TRANSPORT_LOGD("Send manifest %s", - manifest->getName().toString().c_str()); - - while (!content_queue_.empty()) { - passContentObjectToCallbacks(content_queue_.front()); - TRANSPORT_LOGD("Send content %s", - content_queue_.front()->getName().toString().c_str()); - content_queue_.pop(); - } - } - - io_service_.post([this]() { - std::shared_ptr<ContentObject> co; - while (object_queue_for_callbacks_.pop(co)) { - if (on_new_segment_) { - on_new_segment_(*producer_interface_, *co); - } - - if (on_content_object_to_sign_) { - on_content_object_to_sign_(*producer_interface_, *co); - } - - if (on_content_object_in_output_buffer_) { - on_content_object_in_output_buffer_(*producer_interface_, *co); - } - - if (on_content_object_output_) { - on_content_object_output_(*producer_interface_, *co); - } - } - }); - - io_service_.dispatch([this, buffer_size]() { - if (on_content_produced_) { - on_content_produced_(*producer_interface_, - std::make_error_code(std::errc(0)), buffer_size); - } - }); - - return suffix_strategy->getTotalCount(); - } - - virtual void produce(ContentObject &content_object) { - io_service_.dispatch([this, &content_object]() { - if (on_content_object_in_output_buffer_) { - on_content_object_in_output_buffer_(*producer_interface_, - content_object); - } - }); - - output_buffer_.insert(std::static_pointer_cast<ContentObject>( - content_object.shared_from_this())); - - io_service_.dispatch([this, &content_object]() { - if (on_content_object_output_) { - on_content_object_output_(*producer_interface_, content_object); - } - }); - - portal_->sendContentObject(content_object); - } - - virtual void produce(const uint8_t *buffer, size_t buffer_size) { - produce(utils::MemBuf::copyBuffer(buffer, buffer_size)); - } - - virtual void produce(std::unique_ptr<utils::MemBuf> &&buffer) { - // This API is meant to be used just with the RTC producer. - // Here it cannot be used since no name for the content is specified. - throw errors::NotImplementedException(); - } - - virtual void asyncProduce(const Name &suffix, const uint8_t *buf, - size_t buffer_size, bool is_last = true, - uint32_t *start_offset = nullptr) { - if (!async_thread_.stopped()) { - async_thread_.add([this, suffix, buffer = buf, size = buffer_size, - is_last, start_offset]() { - if (start_offset != nullptr) { - *start_offset = ProducerSocket::produce(suffix, buffer, size, is_last, - *start_offset); - } else { - ProducerSocket::produce(suffix, buffer, size, is_last, 0); - } - }); - } - } - - void asyncProduce(const Name &suffix); + bool isRunning() override { return !production_protocol_->isRunning(); }; virtual void asyncProduce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, @@ -381,75 +111,56 @@ class ProducerSocket : public Socket<BasePortal>, uint32_t **last_segment = nullptr) { if (!async_thread_.stopped()) { auto a = buffer.release(); - async_thread_.add( - [this, content_name, a, is_last, offset, last_segment]() { - auto buf = std::unique_ptr<utils::MemBuf>(a); - if (last_segment != NULL) { - **last_segment = - offset + ProducerSocket::produce(content_name, std::move(buf), - is_last, offset); - } else { - ProducerSocket::produce(content_name, std::move(buf), is_last, - offset); - } - }); - } - } - - virtual void asyncProduce(ContentObject &content_object) { - if (!async_thread_.stopped()) { - auto co_ptr = std::static_pointer_cast<ContentObject>( - content_object.shared_from_this()); - async_thread_.add([this, content_object = std::move(co_ptr)]() { - ProducerSocket::produce(*content_object); + async_thread_.add([this, content_name, a, is_last, offset, + last_segment]() { + auto buf = std::unique_ptr<utils::MemBuf>(a); + if (last_segment != NULL) { + **last_segment = offset + produceStream(content_name, std::move(buf), + is_last, offset); + } else { + produceStream(content_name, std::move(buf), is_last, offset); + } }); } } - virtual void registerPrefix(const Prefix &producer_namespace) { - served_namespaces_.push_back(producer_namespace); + virtual uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) { + return production_protocol_->produceStream(content_name, std::move(buffer), + is_last, start_offset); } - void serveForever() { - if (listening_thread_.joinable()) { - listening_thread_.join(); - } + virtual uint32_t produceStream(const Name &content_name, + const uint8_t *buffer, size_t buffer_size, + bool is_last = true, + uint32_t start_offset = 0) { + return production_protocol_->produceStream( + content_name, buffer, buffer_size, is_last, start_offset); } - void stop() { portal_->stopEventsLoop(); } - - asio::io_service &getIoService() override { return portal_->getIoService(); }; - - virtual void onInterest(Interest &interest) { - if (on_interest_input_) { - on_interest_input_(*producer_interface_, interest); - } - - const std::shared_ptr<ContentObject> content_object = - output_buffer_.find(interest); - - if (content_object) { - if (on_interest_satisfied_output_buffer_) { - on_interest_satisfied_output_buffer_(*producer_interface_, interest); - } + virtual uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) { + return production_protocol_->produceDatagram(content_name, + std::move(buffer)); + } - if (on_content_object_output_) { - on_content_object_output_(*producer_interface_, *content_object); - } + virtual uint32_t produceDatagram(const Name &content_name, + const uint8_t *buffer, size_t buffer_size) { + return production_protocol_->produceDatagram(content_name, buffer, + buffer_size); + } - portal_->sendContentObject(*content_object); - } else { - if (on_interest_process_) { - on_interest_process_(*producer_interface_, interest); - } - } + void produce(ContentObject &content_object) { + production_protocol_->produce(content_object); } - virtual void onInterest(Interest::Ptr &&interest) override { - onInterest(*interest); - }; + void registerPrefix(const Prefix &producer_namespace) { + production_protocol_->registerNamespaceWithNetwork(producer_namespace); + } - virtual void onError(std::error_code ec) override {} + void stop() { production_protocol_->stop(); } virtual int setSocketOption(int socket_option_key, uint32_t socket_option_value) { @@ -462,7 +173,7 @@ class ProducerSocket : public Socket<BasePortal>, break; case GeneralTransportOptions::OUTPUT_BUFFER_SIZE: - output_buffer_.setLimit(socket_option_value); + production_protocol_->setOutputBufferSize(socket_option_value); break; case GeneralTransportOptions::CONTENT_OBJECT_EXPIRY_TIME: @@ -533,6 +244,12 @@ class ProducerSocket : public Socket<BasePortal>, break; } + case ProducerCallbacksOptions::CONTENT_OBJECT_TO_SIGN: + if (socket_option_value == VOID_HANDLER) { + on_content_object_to_sign_ = VOID_HANDLER; + break; + } + default: return SOCKET_OPTION_NOT_SET; } @@ -559,19 +276,6 @@ class ProducerSocket : public Socket<BasePortal>, return SOCKET_OPTION_NOT_SET; } - virtual int setSocketOption(int socket_option_key, - std::list<Prefix> socket_option_value) { - switch (socket_option_key) { - case GeneralTransportOptions::NETWORK_NAME: - served_namespaces_ = socket_option_value; - break; - default: - return SOCKET_OPTION_NOT_SET; - } - - return SOCKET_OPTION_SET; - } - virtual int setSocketOption( int socket_option_key, interface::ProducerContentObjectCallback socket_option_value) { @@ -594,6 +298,10 @@ class ProducerSocket : public Socket<BasePortal>, on_content_object_output_ = socket_option_value; break; + case ProducerCallbacksOptions::CONTENT_OBJECT_TO_SIGN: + on_content_object_to_sign_ = socket_option_value; + break; + default: return SOCKET_OPTION_NOT_SET; } @@ -663,7 +371,7 @@ class ProducerSocket : public Socket<BasePortal>, } virtual int setSocketOption(int socket_option_key, - utils::CryptoHashType socket_option_value) { + auth::CryptoHashType socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::HASH_ALGORITHM: hash_algorithm_ = socket_option_value; @@ -675,11 +383,12 @@ class ProducerSocket : public Socket<BasePortal>, return SOCKET_OPTION_SET; } - virtual int setSocketOption(int socket_option_key, - utils::CryptoSuite socket_option_value) { + virtual int setSocketOption( + int socket_option_key, + core::NextSegmentCalculationStrategy socket_option_value) { switch (socket_option_key) { - case GeneralTransportOptions::CRYPTO_SUITE: - crypto_suite_ = socket_option_value; + case GeneralTransportOptions::SUFFIX_STRATEGY: + suffix_strategy_ = socket_option_value; break; default: return SOCKET_OPTION_NOT_SET; @@ -690,7 +399,7 @@ class ProducerSocket : public Socket<BasePortal>, virtual int setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Signer> &socket_option_value) { + const std::shared_ptr<auth::Signer> &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::SIGNER: { utils::SpinLock::Acquire locked(signer_lock_); @@ -708,7 +417,7 @@ class ProducerSocket : public Socket<BasePortal>, uint32_t &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::OUTPUT_BUFFER_SIZE: - socket_option_value = (uint32_t)output_buffer_.getLimit(); + socket_option_value = production_protocol_->getOutputBufferSize(); break; case GeneralTransportOptions::DATA_PACKET_SIZE: @@ -733,18 +442,8 @@ class ProducerSocket : public Socket<BasePortal>, socket_option_value = making_manifest_; break; - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - } - - virtual int getSocketOption(int socket_option_key, - std::list<Prefix> &socket_option_value) { - switch (socket_option_key) { - case GeneralTransportOptions::NETWORK_NAME: - socket_option_value = served_namespaces_; + case GeneralTransportOptions::ASYNC_MODE: + socket_option_value = is_async_; break; default: @@ -776,6 +475,10 @@ class ProducerSocket : public Socket<BasePortal>, *socket_option_value = &on_content_object_output_; break; + case ProducerCallbacksOptions::CONTENT_OBJECT_TO_SIGN: + *socket_option_value = &on_content_object_to_sign_; + break; + default: return SOCKET_OPTION_NOT_GET; } @@ -828,11 +531,11 @@ class ProducerSocket : public Socket<BasePortal>, *socket_option_value = &on_interest_inserted_input_buffer_; break; - case CACHE_HIT: + case ProducerCallbacksOptions::CACHE_HIT: *socket_option_value = &on_interest_satisfied_output_buffer_; break; - case CACHE_MISS: + case ProducerCallbacksOptions::CACHE_MISS: *socket_option_value = &on_interest_process_; break; @@ -844,8 +547,9 @@ class ProducerSocket : public Socket<BasePortal>, }); } - virtual int getSocketOption(int socket_option_key, - std::shared_ptr<Portal> &socket_option_value) { + virtual int getSocketOption( + int socket_option_key, + std::shared_ptr<core::Portal> &socket_option_value) { switch (socket_option_key) { case PORTAL: socket_option_value = portal_; @@ -859,7 +563,7 @@ class ProducerSocket : public Socket<BasePortal>, } virtual int getSocketOption(int socket_option_key, - utils::CryptoHashType &socket_option_value) { + auth::CryptoHashType &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::HASH_ALGORITHM: socket_option_value = hash_algorithm_; @@ -871,22 +575,22 @@ class ProducerSocket : public Socket<BasePortal>, return SOCKET_OPTION_GET; } - virtual int getSocketOption(int socket_option_key, - utils::CryptoSuite &socket_option_value) { + virtual int getSocketOption( + int socket_option_key, + core::NextSegmentCalculationStrategy &socket_option_value) { switch (socket_option_key) { - case GeneralTransportOptions::HASH_ALGORITHM: - socket_option_value = crypto_suite_; + case GeneralTransportOptions::SUFFIX_STRATEGY: + socket_option_value = suffix_strategy_; break; default: return SOCKET_OPTION_NOT_GET; } - return SOCKET_OPTION_GET; } virtual int getSocketOption( int socket_option_key, - std::shared_ptr<utils::Signer> &socket_option_value) { + std::shared_ptr<auth::Signer> &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::SIGNER: { utils::SpinLock::Acquire locked(signer_lock_); @@ -907,19 +611,21 @@ class ProducerSocket : public Socket<BasePortal>, // If the thread calling lambda_func is not the same of io_service, this // function reschedule the function on it template <typename Lambda, typename arg2> - int rescheduleOnIOService(int socket_option_key, arg2 socket_option_value, - Lambda lambda_func) { + int rescheduleOnIOServiceWithReference(int socket_option_key, + arg2 &socket_option_value, + Lambda lambda_func) { // To enforce type check - std::function<int(int, arg2)> func = lambda_func; + std::function<int(int, arg2 &)> func = lambda_func; int result = SOCKET_OPTION_SET; - if (listening_thread_.joinable() && - std::this_thread::get_id() != listening_thread_.get_id()) { + if (production_protocol_ && production_protocol_->isRunning()) { std::mutex mtx; /* Condition variable for the wait */ std::condition_variable cv; + bool done = false; - io_service_.dispatch([&socket_option_key, &socket_option_value, &mtx, &cv, - &result, &done, &func]() { + portal_->getIoService().dispatch([&socket_option_key, + &socket_option_value, &mtx, &cv, + &result, &done, &func]() { std::unique_lock<std::mutex> lck(mtx); done = true; result = func(socket_option_key, socket_option_value); @@ -939,21 +645,19 @@ class ProducerSocket : public Socket<BasePortal>, // If the thread calling lambda_func is not the same of io_service, this // function reschedule the function on it template <typename Lambda, typename arg2> - int rescheduleOnIOServiceWithReference(int socket_option_key, - arg2 &socket_option_value, - Lambda lambda_func) { + int rescheduleOnIOService(int socket_option_key, arg2 socket_option_value, + Lambda lambda_func) { // To enforce type check - std::function<int(int, arg2 &)> func = lambda_func; + std::function<int(int, arg2)> func = lambda_func; int result = SOCKET_OPTION_SET; - if (listening_thread_.joinable() && - std::this_thread::get_id() != this->listening_thread_.get_id()) { + if (production_protocol_ && production_protocol_->isRunning()) { std::mutex mtx; /* Condition variable for the wait */ std::condition_variable cv; - bool done = false; - io_service_.dispatch([&socket_option_key, &socket_option_value, &mtx, &cv, - &result, &done, &func]() { + portal_->getIoService().dispatch([&socket_option_key, + &socket_option_value, &mtx, &cv, + &result, &done, &func]() { std::unique_lock<std::mutex> lck(mtx); done = true; result = func(socket_option_key, socket_option_value); @@ -973,39 +677,20 @@ class ProducerSocket : public Socket<BasePortal>, // Threads protected: interface::ProducerSocket *producer_interface_; - std::thread listening_thread_; asio::io_service io_service_; - std::shared_ptr<Portal> portal_; std::atomic<size_t> data_packet_size_; - std::list<Prefix> - served_namespaces_; // No need to be threadsafe, this is always modified - // by the application thread std::atomic<uint32_t> content_object_expiry_time_; - utils::CircularFifo<std::shared_ptr<ContentObject>, 2048> - object_queue_for_callbacks_; - - // buffers - // ContentStore is thread-safe - utils::ContentStore output_buffer_; - utils::EventThread async_thread_; - int registration_status_; std::atomic<bool> making_manifest_; - - // map for storing sequence numbers for several calls of the publish - // function - std::unordered_map<Name, std::unordered_map<int, uint32_t>> seq_number_map_; - - std::atomic<utils::CryptoHashType> hash_algorithm_; - std::atomic<utils::CryptoSuite> crypto_suite_; + std::atomic<auth::CryptoHashType> hash_algorithm_; + std::atomic<auth::CryptoSuite> crypto_suite_; utils::SpinLock signer_lock_; - std::shared_ptr<utils::Signer> signer_; + std::shared_ptr<auth::Signer> signer_; core::NextSegmentCalculationStrategy suffix_strategy_; - // While manifests are being built, contents are stored in a queue - std::queue<std::shared_ptr<ContentObject>> content_queue_; + std::unique_ptr<protocol::ProductionProtocol> production_protocol_; // callbacks ProducerInterestCallback on_interest_input_; @@ -1021,63 +706,6 @@ class ProducerSocket : public Socket<BasePortal>, ProducerContentObjectCallback on_content_object_evicted_from_output_buffer_; ProducerContentCallback on_content_produced_; - - private: - void listen() { - bool first = true; - - for (core::Prefix &producer_namespace : served_namespaces_) { - if (first) { - core::BindConfig bind_config(producer_namespace, 1000); - portal_->bind(bind_config); - portal_->setProducerCallback(this); - first = !first; - } else { - portal_->registerRoute(producer_namespace); - } - } - - portal_->runEventsLoop(); - } - - void scheduleSendBurst() { - io_service_.post([this]() { - std::shared_ptr<ContentObject> co; - - for (uint32_t i = 0; i < burst_size; i++) { - if (object_queue_for_callbacks_.pop(co)) { - if (on_new_segment_) { - on_new_segment_(*producer_interface_, *co); - } - - if (on_content_object_to_sign_) { - on_content_object_to_sign_(*producer_interface_, *co); - } - - if (on_content_object_in_output_buffer_) { - on_content_object_in_output_buffer_(*producer_interface_, *co); - } - - if (on_content_object_output_) { - on_content_object_output_(*producer_interface_, *co); - } - } else { - break; - } - } - }); - } - - void passContentObjectToCallbacks( - const std::shared_ptr<ContentObject> &content_object) { - output_buffer_.insert(content_object); - portal_->sendContentObject(*content_object); - object_queue_for_callbacks_.push(std::move(content_object)); - - if (object_queue_for_callbacks_.size() >= burst_size) { - scheduleSendBurst(); - } - } }; } // namespace implementation diff --git a/libtransport/src/implementation/tls_rtc_socket_producer.cc b/libtransport/src/implementation/tls_rtc_socket_producer.cc index 9ef79ca23..9a62c8683 100644 --- a/libtransport/src/implementation/tls_rtc_socket_producer.cc +++ b/libtransport/src/implementation/tls_rtc_socket_producer.cc @@ -15,10 +15,8 @@ #include <hicn/transport/core/interest.h> #include <hicn/transport/interfaces/p2psecure_socket_producer.h> - #include <implementation/p2psecure_socket_producer.h> #include <implementation/tls_rtc_socket_producer.h> - #include <openssl/bio.h> #include <openssl/rand.h> #include <openssl/ssl.h> diff --git a/libtransport/src/implementation/tls_rtc_socket_producer.h b/libtransport/src/implementation/tls_rtc_socket_producer.h index 685c91244..92c657afc 100644 --- a/libtransport/src/implementation/tls_rtc_socket_producer.h +++ b/libtransport/src/implementation/tls_rtc_socket_producer.h @@ -15,7 +15,6 @@ #pragma once -#include <implementation/rtc_socket_producer.h> #include <implementation/tls_socket_producer.h> namespace transport { @@ -23,8 +22,7 @@ namespace implementation { class P2PSecureProducerSocket; -class TLSRTCProducerSocket : public RTCProducerSocket, - public TLSProducerSocket { +class TLSRTCProducerSocket : public TLSProducerSocket { friend class P2PSecureProducerSocket; public: @@ -34,7 +32,8 @@ class TLSRTCProducerSocket : public RTCProducerSocket, ~TLSRTCProducerSocket() = default; - void produce(std::unique_ptr<utils::MemBuf> &&buffer) override; + uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) override; void accept() override; diff --git a/libtransport/src/implementation/tls_socket_consumer.cc b/libtransport/src/implementation/tls_socket_consumer.cc index 1be6f41a7..99bcd4360 100644 --- a/libtransport/src/implementation/tls_socket_consumer.cc +++ b/libtransport/src/implementation/tls_socket_consumer.cc @@ -136,7 +136,6 @@ TLSConsumerSocket::TLSConsumerSocket(interface::ConsumerSocket *consumer_socket, int protocol, SSL *ssl) : ConsumerSocket(consumer_socket, protocol), name_(), - buf_pool_(), decrypted_content_(), payload_(), head_(), @@ -223,14 +222,15 @@ int TLSConsumerSocket::download_content(const Name &name) { content_downloaded_ = false; std::size_t max_buffer_size = read_callback_decrypted_->maxBufferSize(); - std::size_t buffer_size = read_callback_decrypted_->maxBufferSize() + SSL3_RT_MAX_PLAIN_LENGTH; + std::size_t buffer_size = + read_callback_decrypted_->maxBufferSize() + SSL3_RT_MAX_PLAIN_LENGTH; decrypted_content_ = utils::MemBuf::createCombined(buffer_size); int result = -1; std::size_t size = 0; while (!content_downloaded_ || something_to_read_) { - result = SSL_read( - this->ssl_, decrypted_content_->writableTail(), SSL3_RT_MAX_PLAIN_LENGTH); + result = SSL_read(this->ssl_, decrypted_content_->writableTail(), + SSL3_RT_MAX_PLAIN_LENGTH); /* SSL_read returns the data only if there were SSL3_RT_MAX_PLAIN_LENGTH of * the data has been fully downloaded */ diff --git a/libtransport/src/implementation/tls_socket_consumer.h b/libtransport/src/implementation/tls_socket_consumer.h index 1c5df346a..be08ec47d 100644 --- a/libtransport/src/implementation/tls_socket_consumer.h +++ b/libtransport/src/implementation/tls_socket_consumer.h @@ -16,9 +16,7 @@ #pragma once #include <hicn/transport/interfaces/socket_consumer.h> - #include <implementation/socket_consumer.h> - #include <openssl/ssl.h> namespace transport { @@ -74,7 +72,6 @@ class TLSConsumerSocket : public ConsumerSocket, SSL_CTX *ctx_; /* Chain of MemBuf to be used as a temporary buffer to pass descypted data * from the underlying layer to the application */ - utils::ObjectPool<utils::MemBuf> buf_pool_; std::unique_ptr<utils::MemBuf> decrypted_content_; /* Chain of MemBuf holding the payload to be written into interest or data */ std::unique_ptr<utils::MemBuf> payload_; diff --git a/libtransport/src/implementation/tls_socket_producer.cc b/libtransport/src/implementation/tls_socket_producer.cc index 339a1ad58..e54d38d56 100644 --- a/libtransport/src/implementation/tls_socket_producer.cc +++ b/libtransport/src/implementation/tls_socket_producer.cc @@ -14,10 +14,8 @@ */ #include <hicn/transport/interfaces/socket_producer.h> - #include <implementation/p2psecure_socket_producer.h> #include <implementation/tls_socket_producer.h> - #include <openssl/bio.h> #include <openssl/rand.h> #include <openssl/ssl.h> @@ -50,10 +48,14 @@ int TLSProducerSocket::readOld(BIO *b, char *buf, int size) { std::unique_lock<std::mutex> lck(socket->mtx_); + TRANSPORT_LOGD("Start wait on the CV."); + if (!socket->something_to_read_) { (socket->cv_).wait(lck); } + TRANSPORT_LOGD("CV unlocked."); + /* Either there already is something to read, or the thread has been waken up. * We must return the payload in the interest anyway */ utils::MemBuf *membuf = socket->handshake_packet_->next(); @@ -103,7 +105,7 @@ int TLSProducerSocket::writeOld(BIO *b, const char *buf, int num) { socket->tls_chunks_--; socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, false); - socket->parent_->ProducerSocket::produce( + socket->parent_->ProducerSocket::produceStream( socket->name_, (const uint8_t *)buf, num, socket->tls_chunks_ == 0, socket->last_segment_); socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, @@ -122,18 +124,18 @@ int TLSProducerSocket::writeOld(BIO *b, const char *buf, int num) { socket->tls_chunks_--; socket->to_call_oncontentproduced_--; - socket->last_segment_ += socket->ProducerSocket::produce( + socket->last_segment_ += socket->ProducerSocket::produceStream( socket->name_, std::move(mbuf), socket->tls_chunks_ == 0, socket->last_segment_); - ProducerContentCallback on_content_produced_application; + ProducerContentCallback *on_content_produced_application; socket->getSocketOption(ProducerCallbacksOptions::CONTENT_PRODUCED, - on_content_produced_application); + &on_content_produced_application); if (socket->to_call_oncontentproduced_ == 0 && on_content_produced_application) { - on_content_produced_application(*socket->getInterface(), - std::error_code(), 0); + on_content_produced_application->operator()(*socket->getInterface(), + std::error_code(), 0); } }); } @@ -144,7 +146,8 @@ int TLSProducerSocket::writeOld(BIO *b, const char *buf, int num) { TLSProducerSocket::TLSProducerSocket(interface::ProducerSocket *producer_socket, P2PSecureProducerSocket *parent, const Name &handshake_name) - : ProducerSocket(producer_socket), + : ProducerSocket(producer_socket, + ProductionProtocolAlgorithms::BYTE_STREAM), on_content_produced_application_(), mtx_(), cv_(), @@ -236,13 +239,14 @@ void TLSProducerSocket::accept() { std::move(parent_->map_producers[handshake_name_])); parent_->map_producers.erase(handshake_name_); - ProducerInterestCallback on_interest_process_decrypted; + ProducerInterestCallback *on_interest_process_decrypted; getSocketOption(ProducerCallbacksOptions::CACHE_MISS, - on_interest_process_decrypted); + &on_interest_process_decrypted); - if (on_interest_process_decrypted) { - Interest inter(std::move(handshake_packet_)); - on_interest_process_decrypted(*getInterface(), inter); + if (*on_interest_process_decrypted) { + Interest inter(std::move(*handshake_packet_)); + handshake_packet_.reset(); + on_interest_process_decrypted->operator()(*getInterface(), inter); } else { throw errors::RuntimeException( "On interest process unset: unable to perform handshake"); @@ -270,14 +274,14 @@ void TLSProducerSocket::onInterest(ProducerSocket &p, Interest &interest) { std::unique_lock<std::mutex> lck(mtx_); name_ = interest.getName(); - interest.separateHeaderPayload(); + // interest.separateHeaderPayload(); handshake_packet_ = interest.acquireMemBufReference(); something_to_read_ = true; cv_.notify_one(); return; } else if (handshake_state == SERVER_FINISHED) { - interest.separateHeaderPayload(); + // interest.separateHeaderPayload(); handshake_packet_ = interest.acquireMemBufReference(); something_to_read_ = true; @@ -288,12 +292,12 @@ void TLSProducerSocket::onInterest(ProducerSocket &p, Interest &interest) { interest.getPayload()->length()); } - ProducerInterestCallback on_interest_input_decrypted; + ProducerInterestCallback *on_interest_input_decrypted; getSocketOption(ProducerCallbacksOptions::INTEREST_INPUT, - on_interest_input_decrypted); + &on_interest_input_decrypted); - if (on_interest_input_decrypted) - (on_interest_input_decrypted)(*getInterface(), interest); + if (*on_interest_input_decrypted) + (*on_interest_input_decrypted)(*getInterface(), interest); } } @@ -301,17 +305,19 @@ void TLSProducerSocket::cacheMiss(interface::ProducerSocket &p, Interest &interest) { HandshakeState handshake_state = getHandshakeState(); + TRANSPORT_LOGD("On cache miss in TLS socket producer."); + if (handshake_state == CLIENT_HELLO) { std::unique_lock<std::mutex> lck(mtx_); - interest.separateHeaderPayload(); + // interest.separateHeaderPayload(); handshake_packet_ = interest.acquireMemBufReference(); something_to_read_ = true; handshake_state_ = CLIENT_FINISHED; cv_.notify_one(); } else if (handshake_state == SERVER_FINISHED) { - interest.separateHeaderPayload(); + // interest.separateHeaderPayload(); handshake_packet_ = interest.acquireMemBufReference(); something_to_read_ = true; @@ -343,16 +349,16 @@ void TLSProducerSocket::onContentProduced(interface::ProducerSocket &p, const std::error_code &err, uint64_t bytes_written) {} -uint32_t TLSProducerSocket::produce(Name content_name, - std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last, uint32_t start_offset) { +uint32_t TLSProducerSocket::produceStream( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { if (getHandshakeState() != SERVER_FINISHED) { throw errors::RuntimeException( "New handshake on the same P2P secure producer socket not supported"); } size_t buf_size = buffer->length(); - name_ = served_namespaces_.front().mapName(content_name); + name_ = production_protocol_->getNamespaces().front().mapName(content_name); tls_chunks_ = to_call_oncontentproduced_ = ceil((float)buf_size / (float)SSL3_RT_MAX_PLAIN_LENGTH); @@ -370,46 +376,6 @@ uint32_t TLSProducerSocket::produce(Name content_name, return 0; } -void TLSProducerSocket::asyncProduce(const Name &content_name, - const uint8_t *buf, size_t buffer_size, - bool is_last, uint32_t *start_offset) { - if (!encryption_thread_.stopped()) { - encryption_thread_.add([this, content_name, buffer = buf, - size = buffer_size, is_last, start_offset]() { - if (start_offset != NULL) { - produce(content_name, buffer, size, is_last, *start_offset); - } else { - produce(content_name, buffer, size, is_last, 0); - } - }); - } -} - -void TLSProducerSocket::asyncProduce(Name content_name, - std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last, uint32_t offset, - uint32_t **last_segment) { - if (!encryption_thread_.stopped()) { - auto a = buffer.release(); - encryption_thread_.add( - [this, content_name, a, is_last, offset, last_segment]() { - auto buf = std::unique_ptr<utils::MemBuf>(a); - if (last_segment != NULL) { - *last_segment = &last_segment_; - } - produce(content_name, std::move(buf), is_last, offset); - }); - } -} - -void TLSProducerSocket::asyncProduce(ContentObject &content_object) { - throw errors::RuntimeException("API not supported"); -} - -void TLSProducerSocket::produce(ContentObject &content_object) { - throw errors::RuntimeException("API not supported"); -} - long TLSProducerSocket::ctrl(BIO *b, int cmd, long num, void *ptr) { if (cmd == BIO_CTRL_FLUSH) { } @@ -424,13 +390,14 @@ int TLSProducerSocket::addHicnKeyIdCb(SSL *s, unsigned int ext_type, void *add_arg) { TLSProducerSocket *socket = reinterpret_cast<TLSProducerSocket *>(add_arg); + TRANSPORT_LOGD("On addHicnKeyIdCb, for the prefix registration."); + if (ext_type == 100) { - ip_prefix_t ip_prefix = - socket->parent_->served_namespaces_.front().toIpPrefixStruct(); - int inet_family = - socket->parent_->served_namespaces_.front().getAddressFamily(); - uint16_t prefix_len_bits = - socket->parent_->served_namespaces_.front().getPrefixLength(); + auto &prefix = + socket->parent_->production_protocol_->getNamespaces().front(); + const ip_prefix_t &ip_prefix = prefix.toIpPrefixStruct(); + int inet_family = prefix.getAddressFamily(); + uint16_t prefix_len_bits = prefix.getPrefixLength(); uint8_t prefix_len_bytes = prefix_len_bits / 8; uint8_t prefix_len_u32 = prefix_len_bits / 32; @@ -479,10 +446,9 @@ int TLSProducerSocket::addHicnKeyIdCb(SSL *s, unsigned int ext_type, socket->parent_->on_interest_process_decrypted_; socket->registerPrefix( - Prefix(socket->parent_->served_namespaces_.front().getName( - Name(inet_family, (uint8_t *)&mask), - Name(inet_family, (uint8_t *)&keyId_component), - socket->parent_->served_namespaces_.front().getName()), + Prefix(prefix.getName(Name(inet_family, (uint8_t *)&mask), + Name(inet_family, (uint8_t *)&keyId_component), + prefix.getName()), out_ip->len)); socket->connect(); } @@ -580,61 +546,5 @@ int TLSProducerSocket::getSocketOption( }); } -int TLSProducerSocket::getSocketOption( - int socket_option_key, ProducerContentCallback &socket_option_value) { - return rescheduleOnIOServiceWithReference( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ProducerContentCallback &socket_option_value) -> int { - switch (socket_option_key) { - case ProducerCallbacksOptions::CONTENT_PRODUCED: - socket_option_value = on_content_produced_application_; - break; - - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - }); -} - -int TLSProducerSocket::getSocketOption( - int socket_option_key, ProducerInterestCallback &socket_option_value) { - // Reschedule the function on the io_service to avoid race condition in case - // setSocketOption is called while the io_service is running. - return rescheduleOnIOServiceWithReference( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ProducerInterestCallback &socket_option_value) -> int { - switch (socket_option_key) { - case ProducerCallbacksOptions::INTEREST_INPUT: - socket_option_value = on_interest_input_decrypted_; - break; - - case ProducerCallbacksOptions::INTEREST_DROP: - socket_option_value = on_interest_dropped_input_buffer_; - break; - - case ProducerCallbacksOptions::INTEREST_PASS: - socket_option_value = on_interest_inserted_input_buffer_; - break; - - case ProducerCallbacksOptions::CACHE_HIT: - socket_option_value = on_interest_satisfied_output_buffer_; - break; - - case ProducerCallbacksOptions::CACHE_MISS: - socket_option_value = on_interest_process_decrypted_; - break; - - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - }); -} - } // namespace implementation } // namespace transport diff --git a/libtransport/src/implementation/tls_socket_producer.h b/libtransport/src/implementation/tls_socket_producer.h index 2382e8695..a542a4d9f 100644 --- a/libtransport/src/implementation/tls_socket_producer.h +++ b/libtransport/src/implementation/tls_socket_producer.h @@ -16,8 +16,8 @@ #pragma once #include <implementation/socket_producer.h> - #include <openssl/ssl.h> + #include <condition_variable> #include <mutex> @@ -36,26 +36,18 @@ class TLSProducerSocket : virtual public ProducerSocket { ~TLSProducerSocket(); - uint32_t produce(Name content_name, const uint8_t *buffer, size_t buffer_size, - bool is_last = true, uint32_t start_offset = 0) override { - return produce(content_name, utils::MemBuf::copyBuffer(buffer, buffer_size), - is_last, start_offset); + uint32_t produceStream(const Name &content_name, const uint8_t *buffer, + size_t buffer_size, bool is_last = true, + uint32_t start_offset = 0) override { + return produceStream(content_name, + utils::MemBuf::copyBuffer(buffer, buffer_size), + is_last, start_offset); } - uint32_t produce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last = true, uint32_t start_offset = 0) override; - - void produce(ContentObject &content_object) override; - - void asyncProduce(const Name &suffix, const uint8_t *buf, size_t buffer_size, - bool is_last = true, - uint32_t *start_offset = nullptr) override; - - void asyncProduce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last, uint32_t offset, - uint32_t **last_segment = nullptr) override; - - void asyncProduce(ContentObject &content_object) override; + uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) override; virtual void accept(); @@ -80,7 +72,7 @@ class TLSProducerSocket : virtual public ProducerSocket { ProducerInterestCallback &socket_option_value); using ProducerSocket::getSocketOption; - using ProducerSocket::onInterest; + // using ProducerSocket::onInterest; using ProducerSocket::setSocketOption; protected: @@ -119,6 +111,7 @@ class TLSProducerSocket : virtual public ProducerSocket { int to_call_oncontentproduced_; bool still_writing_; utils::EventThread encryption_thread_; + utils::EventThread async_thread_; void onInterest(ProducerSocket &p, Interest &interest); diff --git a/libtransport/src/interfaces/CMakeLists.txt b/libtransport/src/interfaces/CMakeLists.txt index e1d144596..0284aa412 100644 --- a/libtransport/src/interfaces/CMakeLists.txt +++ b/libtransport/src/interfaces/CMakeLists.txt @@ -14,24 +14,24 @@ cmake_minimum_required(VERSION 3.5 FATAL_ERROR) list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/socket_consumer.cc ${CMAKE_CURRENT_SOURCE_DIR}/portal.cc ${CMAKE_CURRENT_SOURCE_DIR}/callbacks.cc + ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.cc ) if (${OPENSSL_VERSION} VERSION_EQUAL "1.1.1a" OR ${OPENSSL_VERSION} VERSION_GREATER "1.1.1a") list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc + # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.cc ) list(APPEND HEADER_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h + # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.h ) diff --git a/libtransport/src/interfaces/global_configuration.cc b/libtransport/src/interfaces/global_configuration.cc new file mode 100644 index 000000000..8fb6601f3 --- /dev/null +++ b/libtransport/src/interfaces/global_configuration.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/global_configuration.h> +#include <hicn/transport/interfaces/global_conf_interface.h> +#include <hicn/transport/utils/log.h> + +#include <system_error> + +namespace transport { +namespace interface { +namespace global_config { + +void parseConfigurationFile(const std::string& path) { + core::GlobalConfiguration::getInstance().parseConfiguration(path); +} + +void ConfigurationObject::get() { + std::error_code ec; + core::GlobalConfiguration::getInstance().getConfiguration(*this, ec); + + if (ec) { + TRANSPORT_LOGE("Error setting global config: %s", ec.message().c_str()); + } +} + +void ConfigurationObject::set() { + std::error_code ec; + core::GlobalConfiguration::getInstance().setConfiguration(*this, ec); + + if (ec) { + TRANSPORT_LOGE("Error setting global config: %s", ec.message().c_str()); + } +} + +} // namespace global_config +} // namespace interface +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/interfaces/p2psecure_socket_consumer.cc b/libtransport/src/interfaces/p2psecure_socket_consumer.cc index 038441dfc..e473a1e2e 100644 --- a/libtransport/src/interfaces/p2psecure_socket_consumer.cc +++ b/libtransport/src/interfaces/p2psecure_socket_consumer.cc @@ -14,7 +14,6 @@ */ #include <hicn/transport/interfaces/p2psecure_socket_consumer.h> - #include <implementation/p2psecure_socket_consumer.h> namespace transport { diff --git a/libtransport/src/interfaces/p2psecure_socket_producer.cc b/libtransport/src/interfaces/p2psecure_socket_producer.cc index 37352259c..10d8a1367 100644 --- a/libtransport/src/interfaces/p2psecure_socket_producer.cc +++ b/libtransport/src/interfaces/p2psecure_socket_producer.cc @@ -14,7 +14,6 @@ */ #include <hicn/transport/interfaces/p2psecure_socket_producer.h> - #include <implementation/p2psecure_socket_producer.h> namespace transport { @@ -25,7 +24,7 @@ P2PSecureProducerSocket::P2PSecureProducerSocket() { } P2PSecureProducerSocket::P2PSecureProducerSocket( - bool rtc, const std::shared_ptr<utils::Identity> &identity) { + bool rtc, const std::shared_ptr<auth::Identity> &identity) { socket_ = std::make_unique<implementation::P2PSecureProducerSocket>(this, rtc, identity); } diff --git a/libtransport/src/interfaces/portal.cc b/libtransport/src/interfaces/portal.cc index 36cbd0c3b..2ab51c4b9 100644 --- a/libtransport/src/interfaces/portal.cc +++ b/libtransport/src/interfaces/portal.cc @@ -14,38 +14,35 @@ */ #include <hicn/transport/interfaces/portal.h> - #include <implementation/socket.h> namespace transport { namespace interface { -using implementation::BasePortal; - -Portal::Portal() { implementation_ = new implementation::BasePortal(); } +Portal::Portal() { implementation_ = new core::Portal(); } Portal::Portal(asio::io_service &io_service) { - implementation_ = new BasePortal(io_service); + implementation_ = new core::Portal(io_service); } -Portal::~Portal() { delete reinterpret_cast<BasePortal *>(implementation_); } +Portal::~Portal() { delete reinterpret_cast<core::Portal *>(implementation_); } void Portal::setConsumerCallback(ConsumerCallback *consumer_callback) { - reinterpret_cast<BasePortal *>(implementation_) + reinterpret_cast<core::Portal *>(implementation_) ->setConsumerCallback(consumer_callback); } void Portal::setProducerCallback(ProducerCallback *producer_callback) { - reinterpret_cast<BasePortal *>(implementation_) + reinterpret_cast<core::Portal *>(implementation_) ->setProducerCallback(producer_callback); } void Portal::connect(bool is_consumer) { - reinterpret_cast<BasePortal *>(implementation_)->connect(is_consumer); + reinterpret_cast<core::Portal *>(implementation_)->connect(is_consumer); } bool Portal::interestIsPending(const core::Name &name) { - return reinterpret_cast<BasePortal *>(implementation_) + return reinterpret_cast<core::Portal *>(implementation_) ->interestIsPending(name); } @@ -53,46 +50,46 @@ void Portal::sendInterest( core::Interest::Ptr &&interest, OnContentObjectCallback &&on_content_object_callback, OnInterestTimeoutCallback &&on_interest_timeout_callback) { - reinterpret_cast<BasePortal *>(implementation_) + reinterpret_cast<core::Portal *>(implementation_) ->sendInterest(std::move(interest), std::move(on_content_object_callback), std::move(on_interest_timeout_callback)); } void Portal::bind(const BindConfig &config) { - reinterpret_cast<BasePortal *>(implementation_)->bind(config); + reinterpret_cast<core::Portal *>(implementation_)->bind(config); } void Portal::runEventsLoop() { - reinterpret_cast<BasePortal *>(implementation_)->runEventsLoop(); + reinterpret_cast<core::Portal *>(implementation_)->runEventsLoop(); } void Portal::runOneEvent() { - reinterpret_cast<BasePortal *>(implementation_)->runOneEvent(); + reinterpret_cast<core::Portal *>(implementation_)->runOneEvent(); } void Portal::sendContentObject(core::ContentObject &content_object) { - reinterpret_cast<BasePortal *>(implementation_) + reinterpret_cast<core::Portal *>(implementation_) ->sendContentObject(content_object); } void Portal::stopEventsLoop() { - reinterpret_cast<BasePortal *>(implementation_)->stopEventsLoop(); + reinterpret_cast<core::Portal *>(implementation_)->stopEventsLoop(); } void Portal::killConnection() { - reinterpret_cast<BasePortal *>(implementation_)->killConnection(); + reinterpret_cast<core::Portal *>(implementation_)->killConnection(); } void Portal::clear() { - reinterpret_cast<BasePortal *>(implementation_)->clear(); + reinterpret_cast<core::Portal *>(implementation_)->clear(); } asio::io_service &Portal::getIoService() { - return reinterpret_cast<BasePortal *>(implementation_)->getIoService(); + return reinterpret_cast<core::Portal *>(implementation_)->getIoService(); } void Portal::registerRoute(core::Prefix &prefix) { - reinterpret_cast<BasePortal *>(implementation_)->registerRoute(prefix); + reinterpret_cast<core::Portal *>(implementation_)->registerRoute(prefix); } } // namespace interface diff --git a/libtransport/src/interfaces/socket_consumer.cc b/libtransport/src/interfaces/socket_consumer.cc index ea0606347..4eee73cab 100644 --- a/libtransport/src/interfaces/socket_consumer.cc +++ b/libtransport/src/interfaces/socket_consumer.cc @@ -46,8 +46,6 @@ void ConsumerSocket::stop() { socket_->stop(); } void ConsumerSocket::resume() { socket_->resume(); } -bool ConsumerSocket::verifyKeyPackets() { return socket_->verifyKeyPackets(); } - asio::io_service &ConsumerSocket::getIoService() { return socket_->getIoService(); } @@ -88,22 +86,10 @@ int ConsumerSocket::setSocketOption( } int ConsumerSocket::setSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationCallback socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - -int ConsumerSocket::setSocketOption( int socket_option_key, ConsumerInterestCallback socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } -int ConsumerSocket::setSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationFailedCallback socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - int ConsumerSocket::setSocketOption(int socket_option_key, IcnObserver *socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); @@ -111,7 +97,7 @@ int ConsumerSocket::setSocketOption(int socket_option_key, int ConsumerSocket::setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Verifier> &socket_option_value) { + const std::shared_ptr<auth::Verifier> &socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } @@ -152,22 +138,10 @@ int ConsumerSocket::getSocketOption( } int ConsumerSocket::getSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationCallback **socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - -int ConsumerSocket::getSocketOption( int socket_option_key, ConsumerInterestCallback **socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } -int ConsumerSocket::getSocketOption( - int socket_option_key, - ConsumerContentObjectVerificationFailedCallback **socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - int ConsumerSocket::getSocketOption(int socket_option_key, IcnObserver **socket_option_value) { return socket_->getSocketOption(socket_option_key, socket_option_value); @@ -175,7 +149,7 @@ int ConsumerSocket::getSocketOption(int socket_option_key, int ConsumerSocket::getSocketOption( int socket_option_key, - std::shared_ptr<utils::Verifier> &socket_option_value) { + std::shared_ptr<auth::Verifier> &socket_option_value) { return socket_->getSocketOption(socket_option_key, socket_option_value); } diff --git a/libtransport/src/interfaces/socket_producer.cc b/libtransport/src/interfaces/socket_producer.cc index d030fe756..b04947dfd 100644 --- a/libtransport/src/interfaces/socket_producer.cc +++ b/libtransport/src/interfaces/socket_producer.cc @@ -14,7 +14,6 @@ */ #include <hicn/transport/interfaces/socket_producer.h> - #include <implementation/socket_producer.h> #include <atomic> @@ -31,11 +30,12 @@ namespace interface { using namespace core; ProducerSocket::ProducerSocket(int protocol) { - if (protocol != 0) { - throw std::runtime_error("Production protocol must be 0."); - } + socket_ = std::make_unique<implementation::ProducerSocket>(this, protocol); +} - socket_ = std::make_unique<implementation::ProducerSocket>(this); +ProducerSocket::ProducerSocket(int protocol, asio::io_service &io_service) { + socket_ = std::make_unique<implementation::ProducerSocket>(this, protocol, + io_service); } ProducerSocket::ProducerSocket(bool) {} @@ -46,19 +46,34 @@ void ProducerSocket::connect() { socket_->connect(); } bool ProducerSocket::isRunning() { return socket_->isRunning(); } -uint32_t ProducerSocket::produce(Name content_name, - std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last, uint32_t start_offset) { - return socket_->produce(content_name, std::move(buffer), is_last, - start_offset); +uint32_t ProducerSocket::produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { + return socket_->produceStream(content_name, std::move(buffer), is_last, + start_offset); } -void ProducerSocket::produce(ContentObject &content_object) { - return socket_->produce(content_object); +uint32_t ProducerSocket::produceStream(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size, bool is_last, + uint32_t start_offset) { + return socket_->produceStream(content_name, buffer, buffer_size, is_last, + start_offset); } -void ProducerSocket::produce(std::unique_ptr<utils::MemBuf> &&buffer) { - socket_->produce(std::move(buffer)); +uint32_t ProducerSocket::produceDatagram( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer) { + return socket_->produceDatagram(content_name, std::move(buffer)); +} + +uint32_t ProducerSocket::produceDatagram(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size) { + return socket_->produceDatagram(content_name, buffer, buffer_size); +} + +void ProducerSocket::produce(ContentObject &content_object) { + return socket_->produce(content_object); } void ProducerSocket::asyncProduce(Name content_name, @@ -69,16 +84,10 @@ void ProducerSocket::asyncProduce(Name content_name, last_segment); } -void ProducerSocket::asyncProduce(ContentObject &content_object) { - return socket_->asyncProduce(content_object); -} - void ProducerSocket::registerPrefix(const Prefix &producer_namespace) { return socket_->registerPrefix(producer_namespace); } -void ProducerSocket::serveForever() { return socket_->serveForever(); } - void ProducerSocket::stop() { return socket_->stop(); } asio::io_service &ProducerSocket::getIoService() { @@ -105,11 +114,6 @@ int ProducerSocket::setSocketOption(int socket_option_key, return socket_->setSocketOption(socket_option_key, socket_option_value); } -int ProducerSocket::setSocketOption(int socket_option_key, - std::list<Prefix> socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - int ProducerSocket::setSocketOption( int socket_option_key, ProducerContentObjectCallback socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); @@ -126,18 +130,13 @@ int ProducerSocket::setSocketOption( } int ProducerSocket::setSocketOption(int socket_option_key, - utils::CryptoHashType socket_option_value) { - return socket_->setSocketOption(socket_option_key, socket_option_value); -} - -int ProducerSocket::setSocketOption(int socket_option_key, - utils::CryptoSuite socket_option_value) { + auth::CryptoHashType socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } int ProducerSocket::setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Signer> &socket_option_value) { + const std::shared_ptr<auth::Signer> &socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } @@ -156,11 +155,6 @@ int ProducerSocket::getSocketOption(int socket_option_key, return socket_->getSocketOption(socket_option_key, socket_option_value); } -int ProducerSocket::getSocketOption(int socket_option_key, - std::list<Prefix> &socket_option_value) { - return socket_->getSocketOption(socket_option_key, socket_option_value); -} - int ProducerSocket::getSocketOption( int socket_option_key, ProducerContentObjectCallback **socket_option_value) { @@ -177,19 +171,13 @@ int ProducerSocket::getSocketOption( return socket_->getSocketOption(socket_option_key, socket_option_value); } -int ProducerSocket::getSocketOption( - int socket_option_key, utils::CryptoHashType &socket_option_value) { - return socket_->getSocketOption(socket_option_key, socket_option_value); -} - int ProducerSocket::getSocketOption(int socket_option_key, - utils::CryptoSuite &socket_option_value) { + auth::CryptoHashType &socket_option_value) { return socket_->getSocketOption(socket_option_key, socket_option_value); } int ProducerSocket::getSocketOption( - int socket_option_key, - std::shared_ptr<utils::Signer> &socket_option_value) { + int socket_option_key, std::shared_ptr<auth::Signer> &socket_option_value) { return socket_->getSocketOption(socket_option_key, socket_option_value); } diff --git a/libtransport/src/interfaces/tls_rtc_socket_producer.cc b/libtransport/src/interfaces/tls_rtc_socket_producer.cc index 132f34721..7326fcbcb 100644 --- a/libtransport/src/interfaces/tls_rtc_socket_producer.cc +++ b/libtransport/src/interfaces/tls_rtc_socket_producer.cc @@ -13,9 +13,8 @@ * limitations under the License. */ -#include <interfaces/tls_rtc_socket_producer.h> - #include <implementation/tls_rtc_socket_producer.h> +#include <interfaces/tls_rtc_socket_producer.h> namespace transport { namespace interface { diff --git a/libtransport/src/interfaces/tls_socket_consumer.cc b/libtransport/src/interfaces/tls_socket_consumer.cc index d87642f73..6c1c535b5 100644 --- a/libtransport/src/interfaces/tls_socket_consumer.cc +++ b/libtransport/src/interfaces/tls_socket_consumer.cc @@ -13,9 +13,8 @@ * limitations under the License. */ -#include <interfaces/tls_socket_consumer.h> - #include <implementation/tls_socket_consumer.h> +#include <interfaces/tls_socket_consumer.h> namespace transport { namespace interface { diff --git a/libtransport/src/interfaces/tls_socket_producer.cc b/libtransport/src/interfaces/tls_socket_producer.cc index 44aa0cf8b..037702f72 100644 --- a/libtransport/src/interfaces/tls_socket_producer.cc +++ b/libtransport/src/interfaces/tls_socket_producer.cc @@ -13,9 +13,8 @@ * limitations under the License. */ -#include <interfaces/tls_socket_producer.h> - #include <implementation/tls_socket_producer.h> +#include <interfaces/tls_socket_producer.h> namespace transport { namespace interface { diff --git a/libtransport/src/io_modules/CMakeLists.txt b/libtransport/src/io_modules/CMakeLists.txt new file mode 100644 index 000000000..6553b9a2b --- /dev/null +++ b/libtransport/src/io_modules/CMakeLists.txt @@ -0,0 +1,37 @@ +# Copyright (c) 2021 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + +if (${CMAKE_SYSTEM_NAME} MATCHES Android) + list(APPEND SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/udp/hicn_forwarder_module.cc + ${CMAKE_CURRENT_SOURCE_DIR}/udp/udp_socket_connector.cc + ) + + list(APPEND HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/udp/hicn_forwarder_module.h + ${CMAKE_CURRENT_SOURCE_DIR}/udp/udp_socket_connector.h + ) + + set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) + set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE) +else() + add_subdirectory(udp) + add_subdirectory(loopback) + add_subdirectory(forwarder) + + if (__vpp__) + add_subdirectory(memif) + endif() +endif()
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/CMakeLists.txt b/libtransport/src/io_modules/forwarder/CMakeLists.txt new file mode 100644 index 000000000..92662bc4c --- /dev/null +++ b/libtransport/src/io_modules/forwarder/CMakeLists.txt @@ -0,0 +1,44 @@ +# Copyright (c) 2021 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + + +list(APPEND MODULE_HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/connector.h + ${CMAKE_CURRENT_SOURCE_DIR}/endpoint.h + ${CMAKE_CURRENT_SOURCE_DIR}/errors.h + ${CMAKE_CURRENT_SOURCE_DIR}/forwarder_module.h + ${CMAKE_CURRENT_SOURCE_DIR}/forwarder.h + ${CMAKE_CURRENT_SOURCE_DIR}/udp_tunnel_listener.h + ${CMAKE_CURRENT_SOURCE_DIR}/udp_tunnel.h + ${CMAKE_CURRENT_SOURCE_DIR}/global_counter.h +) + +list(APPEND MODULE_SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/errors.cc + ${CMAKE_CURRENT_SOURCE_DIR}/forwarder_module.cc + ${CMAKE_CURRENT_SOURCE_DIR}/forwarder.cc + ${CMAKE_CURRENT_SOURCE_DIR}/udp_tunnel_listener.cc + ${CMAKE_CURRENT_SOURCE_DIR}/udp_tunnel.cc +) + +build_module(forwarder_module + SHARED + SOURCES ${MODULE_SOURCE_FILES} + DEPENDS ${DEPENDENCIES} + COMPONENT lib${LIBTRANSPORT} + INCLUDE_DIRS ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} + DEFINITIONS ${COMPILER_DEFINITIONS} + COMPILE_OPTIONS ${COMPILE_FLAGS} +) diff --git a/libtransport/src/io_modules/forwarder/configuration.h b/libtransport/src/io_modules/forwarder/configuration.h new file mode 100644 index 000000000..fcaa5530d --- /dev/null +++ b/libtransport/src/io_modules/forwarder/configuration.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace transport { +namespace core { + +struct ListenerConfig { + std::string address; + std::uint16_t port; + std::string name; +}; + +struct ConnectorConfig { + std::string local_address; + std::uint16_t local_port; + std::string remote_address; + std::uint16_t remote_port; + std::string name; +}; + +struct RouteConfig { + std::string prefix; + uint16_t weight; + std::string connector; + std::string name; +}; + +class Configuration { + public: + Configuration() : n_threads_(1) {} + + bool empty() { + return listeners_.empty() && connectors_.empty() && routes_.empty(); + } + + Configuration& setThreadNumber(std::size_t threads) { + n_threads_ = threads; + return *this; + } + + std::size_t getThreadNumber() { return n_threads_; } + + template <typename... Args> + Configuration& addListener(Args&&... args) { + listeners_.emplace_back(std::forward<Args>(args)...); + return *this; + } + + template <typename... Args> + Configuration& addConnector(Args&&... args) { + connectors_.emplace_back(std::forward<Args>(args)...); + return *this; + } + + template <typename... Args> + Configuration& addRoute(Args&&... args) { + routes_.emplace_back(std::forward<Args>(args)...); + return *this; + } + + std::vector<ListenerConfig>& getListeners() { return listeners_; } + + std::vector<ConnectorConfig>& getConnectors() { return connectors_; } + + std::vector<RouteConfig>& getRoutes() { return routes_; } + + private: + std::vector<ListenerConfig> listeners_; + std::vector<ConnectorConfig> connectors_; + std::vector<RouteConfig> routes_; + std::size_t n_threads_; +}; + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/errors.cc b/libtransport/src/io_modules/forwarder/errors.cc new file mode 100644 index 000000000..b5f131499 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/errors.cc @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2019 Cisco and/or its affiliates. + */ + +#include <io_modules/forwarder/errors.h> + +namespace transport { +namespace core { + +const std::error_category& forwarder_category() { + static forwarder_category_impl instance; + + return instance; +} + +const char* forwarder_category_impl::name() const throw() { + return "proxy::connector::error"; +} + +std::string forwarder_category_impl::message(int ev) const { + switch (static_cast<forwarder_error>(ev)) { + case forwarder_error::success: { + return "Success"; + } + case forwarder_error::disconnected: { + return "Connector is disconnected"; + } + case forwarder_error::receive_failed: { + return "Packet reception failed"; + } + case forwarder_error::send_failed: { + return "Packet send failed"; + } + case forwarder_error::memory_allocation_error: { + return "Impossible to allocate memory for packet pool"; + } + case forwarder_error::invalid_connector_type: { + return "Invalid type specified for connector."; + } + case forwarder_error::invalid_connector: { + return "Created connector was invalid."; + } + case forwarder_error::interest_cache_miss: { + return "interest cache miss."; + } + default: { + return "Unknown connector error"; + } + } +} +} // namespace core +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/errors.h b/libtransport/src/io_modules/forwarder/errors.h new file mode 100644 index 000000000..dd5cc8fe7 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/errors.h @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <string> +#include <system_error> + +namespace transport { +namespace core { +/** + * @brief Get the default server error category. + * @return The default server error category instance. + * + * @warning The first call to this function is thread-safe only starting with + * C++11. + */ +const std::error_category& forwarder_category(); + +/** + * The list of errors. + */ +enum class forwarder_error { + success = 0, + send_failed, + receive_failed, + disconnected, + memory_allocation_error, + invalid_connector_type, + invalid_connector, + interest_cache_miss +}; + +/** + * @brief Create an error_code instance for the given error. + * @param error The error. + * @return The error_code instance. + */ +inline std::error_code make_error_code(forwarder_error error) { + return std::error_code(static_cast<int>(error), forwarder_category()); +} + +/** + * @brief Create an error_condition instance for the given error. + * @param error The error. + * @return The error_condition instance. + */ +inline std::error_condition make_error_condition(forwarder_error error) { + return std::error_condition(static_cast<int>(error), forwarder_category()); +} + +/** + * @brief A server error category. + */ +class forwarder_category_impl : public std::error_category { + public: + /** + * @brief Get the name of the category. + * @return The name of the category. + */ + virtual const char* name() const throw(); + + /** + * @brief Get the error message for a given error. + * @param ev The error numeric value. + * @return The message associated to the error. + */ + virtual std::string message(int ev) const; +}; +} // namespace core +} // namespace transport + +namespace std { +// namespace system { +template <> +struct is_error_code_enum<::transport::core::forwarder_error> + : public std::true_type {}; +// } // namespace system +} // namespace std diff --git a/libtransport/src/io_modules/forwarder/forwarder.cc b/libtransport/src/io_modules/forwarder/forwarder.cc new file mode 100644 index 000000000..7e89e2f9f --- /dev/null +++ b/libtransport/src/io_modules/forwarder/forwarder.cc @@ -0,0 +1,296 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/global_configuration.h> +#include <core/local_connector.h> +#include <io_modules/forwarder/forwarder.h> +#include <io_modules/forwarder/global_id_counter.h> +#include <io_modules/forwarder/udp_tunnel.h> +#include <io_modules/forwarder/udp_tunnel_listener.h> + +namespace transport { + +namespace core { + +constexpr char Forwarder::forwarder_config_section[]; + +Forwarder::Forwarder() : config_() { + using namespace std::placeholders; + GlobalConfiguration::getInstance().registerConfigurationParser( + forwarder_config_section, + std::bind(&Forwarder::parseForwarderConfiguration, this, _1, _2)); + + if (!config_.empty()) { + initThreads(); + initListeners(); + initConnectors(); + } +} + +Forwarder::~Forwarder() { + for (auto &l : listeners_) { + l->close(); + } + + for (auto &c : remote_connectors_) { + c.second->close(); + } + + GlobalConfiguration::getInstance().unregisterConfigurationParser( + forwarder_config_section); +} + +void Forwarder::initThreads() { + for (unsigned i = 0; i < config_.getThreadNumber(); i++) { + thread_pool_.emplace_back(io_service_, /* detached */ false); + } +} + +void Forwarder::initListeners() { + using namespace std::placeholders; + for (auto &l : config_.getListeners()) { + listeners_.emplace_back(std::make_shared<UdpTunnelListener>( + io_service_, + std::bind(&Forwarder::onPacketFromListener, this, _1, _2, _3), + asio::ip::udp::endpoint(asio::ip::address::from_string(l.address), + l.port))); + } +} + +void Forwarder::initConnectors() { + using namespace std::placeholders; + for (auto &c : config_.getConnectors()) { + auto id = GlobalCounter<Connector::Id>::getInstance().getNext(); + auto conn = new UdpTunnelConnector( + io_service_, std::bind(&Forwarder::onPacketReceived, this, _1, _2, _3), + std::bind(&Forwarder::onPacketSent, this, _1, _2), + std::bind(&Forwarder::onConnectorClosed, this, _1), + std::bind(&Forwarder::onConnectorReconnected, this, _1)); + conn->setConnectorId(id); + remote_connectors_.emplace(id, conn); + conn->connect(c.remote_address, c.remote_port, c.local_address, + c.local_port); + } +} + +Connector::Id Forwarder::registerLocalConnector( + asio::io_service &io_service, + Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback) { + utils::SpinLock::Acquire locked(connector_lock_); + auto id = GlobalCounter<Connector::Id>::getInstance().getNext(); + auto connector = std::make_shared<LocalConnector>( + io_service, receive_callback, nullptr, nullptr, reconnect_callback); + connector->setConnectorId(id); + local_connectors_.emplace(id, std::move(connector)); + return id; +} + +Forwarder &Forwarder::deleteConnector(Connector::Id id) { + utils::SpinLock::Acquire locked(connector_lock_); + auto it = local_connectors_.find(id); + if (it != local_connectors_.end()) { + it->second->close(); + local_connectors_.erase(it); + } + + return *this; +} + +Connector::Ptr Forwarder::getConnector(Connector::Id id) { + utils::SpinLock::Acquire locked(connector_lock_); + auto it = local_connectors_.find(id); + if (it != local_connectors_.end()) { + return it->second; + } + + return nullptr; +} + +void Forwarder::onPacketFromListener(Connector *connector, + utils::MemBuf &packet_buffer, + const std::error_code &ec) { + // Create connector + connector->setReceiveCallback( + std::bind(&Forwarder::onPacketReceived, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3)); + + TRANSPORT_LOGD("Packet received from listener."); + + { + utils::SpinLock::Acquire locked(connector_lock_); + remote_connectors_.emplace(connector->getConnectorId(), + connector->shared_from_this()); + } + // TODO Check if control packet or not. For the moment it is not. + onPacketReceived(connector, packet_buffer, ec); +} + +void Forwarder::onPacketReceived(Connector *connector, + utils::MemBuf &packet_buffer, + const std::error_code &ec) { + // Figure out the type of packet we received + bool is_interest = Packet::isInterest(packet_buffer.data()); + + Packet *packet = nullptr; + if (is_interest) { + packet = static_cast<Interest *>(&packet_buffer); + } else { + packet = static_cast<ContentObject *>(&packet_buffer); + } + + for (auto &c : local_connectors_) { + auto role = c.second->getRole(); + auto is_producer = role == Connector::Role::PRODUCER; + if ((is_producer && is_interest) || (!is_producer && !is_interest)) { + c.second->send(*packet); + } else { + TRANSPORT_LOGD( + "Error sending packet to local connector. is_interest = %d - " + "is_producer = %d", + (int)is_interest, (int)is_producer); + } + } + + // PCS Lookup + FIB lookup. Skip for now + + // Forward packet to local connectors +} + +void Forwarder::send(Packet &packet) { + // TODo Here a nice PIT/CS / FIB would be required:) + // For now let's just forward the packet on the remote connector we get + if (remote_connectors_.begin() == remote_connectors_.end()) { + return; + } + + auto remote_endpoint = + remote_connectors_.begin()->second->getRemoteEndpoint(); + TRANSPORT_LOGD("Sending packet to: %s:%u", + remote_endpoint.getAddress().to_string().c_str(), + remote_endpoint.getPort()); + remote_connectors_.begin()->second->send(packet); +} + +void Forwarder::onPacketSent(Connector *connector, const std::error_code &ec) {} + +void Forwarder::onConnectorClosed(Connector *connector) {} + +void Forwarder::onConnectorReconnected(Connector *connector) {} + +void Forwarder::parseForwarderConfiguration( + const libconfig::Setting &forwarder_config, std::error_code &ec) { + using namespace libconfig; + + // n_thread + if (forwarder_config.exists("n_threads")) { + // Get number of threads + int n_threads = 1; + forwarder_config.lookupValue("n_threads", n_threads); + TRANSPORT_LOGD("Forwarder threads from config file: %u", n_threads); + config_.setThreadNumber(n_threads); + } + + // listeners + if (forwarder_config.exists("listeners")) { + // get path where looking for modules + const Setting &listeners = forwarder_config.lookup("listeners"); + auto count = listeners.getLength(); + + for (int i = 0; i < count; i++) { + const Setting &listener = listeners[i]; + ListenerConfig list; + unsigned port; + + list.name = listener.getName(); + listener.lookupValue("local_address", list.address); + listener.lookupValue("local_port", port); + list.port = (uint16_t)(port); + + TRANSPORT_LOGD("Adding listener %s, (%s:%u)", list.name.c_str(), + list.address.c_str(), list.port); + config_.addListener(std::move(list)); + } + } + + // connectors + if (forwarder_config.exists("connectors")) { + // get path where looking for modules + const Setting &connectors = forwarder_config.lookup("connectors"); + auto count = connectors.getLength(); + + for (int i = 0; i < count; i++) { + const Setting &connector = connectors[i]; + ConnectorConfig conn; + + conn.name = connector.getName(); + unsigned port = 0; + + if (!connector.lookupValue("local_address", conn.local_address)) { + conn.local_address = ""; + } + + if (!connector.lookupValue("local_port", port)) { + port = 0; + } + + conn.local_port = (uint16_t)(port); + + if (!connector.lookupValue("remote_address", conn.remote_address)) { + throw errors::RuntimeException( + "Error in configuration file: remote_address is a mandatory field " + "of Connectors."); + } + + if (!connector.lookupValue("remote_port", port)) { + throw errors::RuntimeException( + "Error in configuration file: remote_port is a mandatory field " + "of Connectors."); + } + + conn.remote_port = (uint16_t)(port); + + TRANSPORT_LOGD("Adding connector %s, (%s:%u %s:%u)", conn.name.c_str(), + conn.local_address.c_str(), conn.local_port, + conn.remote_address.c_str(), conn.remote_port); + config_.addConnector(std::move(conn)); + } + } + + // Routes + if (forwarder_config.exists("routes")) { + const Setting &routes = forwarder_config.lookup("routes"); + auto count = routes.getLength(); + + for (int i = 0; i < count; i++) { + const Setting &route = routes[i]; + RouteConfig r; + unsigned weight; + + r.name = route.getName(); + route.lookupValue("prefix", r.prefix); + route.lookupValue("weight", weight); + route.lookupValue("connector", r.connector); + r.weight = (uint16_t)(weight); + + TRANSPORT_LOGD("Adding route %s %s (%s %u)", r.name.c_str(), + r.prefix.c_str(), r.connector.c_str(), r.weight); + config_.addRoute(std::move(r)); + } + } +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/forwarder.h b/libtransport/src/io_modules/forwarder/forwarder.h new file mode 100644 index 000000000..5b564bb5e --- /dev/null +++ b/libtransport/src/io_modules/forwarder/forwarder.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/core/prefix.h> +#include <hicn/transport/utils/event_thread.h> +#include <hicn/transport/utils/singleton.h> +#include <hicn/transport/utils/spinlock.h> +#include <io_modules/forwarder/configuration.h> +#include <io_modules/forwarder/udp_tunnel_listener.h> + +#include <atomic> +#include <libconfig.h++> +#include <unordered_map> + +namespace transport { + +namespace core { + +class Forwarder : public utils::Singleton<Forwarder> { + static constexpr char forwarder_config_section[] = "forwarder"; + friend class utils::Singleton<Forwarder>; + + public: + Forwarder(); + + ~Forwarder(); + + void initThreads(); + void initListeners(); + void initConnectors(); + + Connector::Id registerLocalConnector( + asio::io_service &io_service, + Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback); + + Forwarder &deleteConnector(Connector::Id id); + + Connector::Ptr getConnector(Connector::Id id); + + void send(Packet &packet); + + void stop(); + + private: + void onPacketFromListener(Connector *connector, utils::MemBuf &packet_buffer, + const std::error_code &ec); + void onPacketReceived(Connector *connector, utils::MemBuf &packet_buffer, + const std::error_code &ec); + void onPacketSent(Connector *connector, const std::error_code &ec); + void onConnectorClosed(Connector *connector); + void onConnectorReconnected(Connector *connector); + + void parseForwarderConfiguration(const libconfig::Setting &io_config, + std::error_code &ec); + + asio::io_service io_service_; + utils::SpinLock connector_lock_; + + /** + * Connectors and listeners must be declares *before* thread_pool_, so that + * threads destructors will wait for them to gracefully close before being + * destroyed. + */ + std::unordered_map<Connector::Id, Connector::Ptr> remote_connectors_; + std::unordered_map<Connector::Id, Connector::Ptr> local_connectors_; + std::vector<UdpTunnelListener::Ptr> listeners_; + + std::vector<utils::EventThread> thread_pool_; + + Configuration config_; +}; + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/forwarder_module.cc b/libtransport/src/io_modules/forwarder/forwarder_module.cc new file mode 100644 index 000000000..356b42d3b --- /dev/null +++ b/libtransport/src/io_modules/forwarder/forwarder_module.cc @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/errors/not_implemented_exception.h> +#include <hicn/transport/utils/log.h> +#include <io_modules/forwarder/forwarder_module.h> + +namespace transport { + +namespace core { + +ForwarderModule::ForwarderModule() + : IoModule(), + name_(""), + connector_id_(Connector::invalid_connector), + forwarder_(Forwarder::getInstance()) {} + +ForwarderModule::~ForwarderModule() { + forwarder_.deleteConnector(connector_id_); +} + +bool ForwarderModule::isConnected() { return true; } + +void ForwarderModule::send(Packet &packet) { + IoModule::send(packet); + forwarder_.send(packet); + // TRANSPORT_LOGD("ForwarderModule: sending from %u to %d", local_id_, + // 1 - local_id_); + + // local_faces_.at(1 - local_id_).onPacket(packet); +} + +void ForwarderModule::send(const uint8_t *packet, std::size_t len) { + // not supported + throw errors::NotImplementedException(); +} + +void ForwarderModule::registerRoute(const Prefix &prefix) { + // For the moment we route packets from one socket to the other. + // Next step will be to introduce a FIB + return; +} + +void ForwarderModule::closeConnection() { + forwarder_.deleteConnector(connector_id_); +} + +void ForwarderModule::init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name) { + connector_id_ = forwarder_.registerLocalConnector( + io_service, std::move(receive_callback), std::move(reconnect_callback)); + name_ = app_name; +} + +void ForwarderModule::processControlMessageReply(utils::MemBuf &packet_buffer) { + return; +} + +void ForwarderModule::connect(bool is_consumer) { + forwarder_.getConnector(connector_id_) + ->setRole(is_consumer ? Connector::Role::CONSUMER + : Connector::Role::PRODUCER); +} + +std::uint32_t ForwarderModule::getMtu() { return interface_mtu; } + +bool ForwarderModule::isControlMessage(const uint8_t *message) { return false; } + +extern "C" IoModule *create_module(void) { return new ForwarderModule(); } + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/forwarder_module.h b/libtransport/src/io_modules/forwarder/forwarder_module.h new file mode 100644 index 000000000..58bfb7996 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/forwarder_module.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/core/prefix.h> +#include <io_modules/forwarder/forwarder.h> + +#include <atomic> + +namespace transport { + +namespace core { + +class Forwarder; + +class ForwarderModule : public IoModule { + static constexpr std::uint16_t interface_mtu = 1500; + + public: + ForwarderModule(); + + ~ForwarderModule(); + + void connect(bool is_consumer) override; + + void send(Packet &packet) override; + void send(const uint8_t *packet, std::size_t len) override; + + bool isConnected() override; + + void init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name = "Libtransport") override; + + void registerRoute(const Prefix &prefix) override; + + std::uint32_t getMtu() override; + + bool isControlMessage(const uint8_t *message) override; + + void processControlMessageReply(utils::MemBuf &packet_buffer) override; + + void closeConnection() override; + + private: + std::string name_; + Connector::Id connector_id_; + Forwarder &forwarder_; +}; + +extern "C" IoModule *create_module(void); + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/global_id_counter.h b/libtransport/src/io_modules/forwarder/global_id_counter.h new file mode 100644 index 000000000..fe8d76730 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/global_id_counter.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <atomic> +#include <mutex> + +namespace transport { + +namespace core { + +template <typename T = uint64_t> +class GlobalCounter { + public: + static GlobalCounter& getInstance() { + std::lock_guard<std::mutex> lock(global_mutex_); + + if (!instance_) { + instance_.reset(new GlobalCounter()); + } + + return *instance_; + } + + T getNext() { return counter_++; } + + private: + GlobalCounter() : counter_(0) {} + static std::unique_ptr<GlobalCounter<T>> instance_; + static std::mutex global_mutex_; + std::atomic<T> counter_; +}; + +template <typename T> +std::unique_ptr<GlobalCounter<T>> GlobalCounter<T>::instance_ = nullptr; + +template <typename T> +std::mutex GlobalCounter<T>::global_mutex_; + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/udp_tunnel.cc b/libtransport/src/io_modules/forwarder/udp_tunnel.cc new file mode 100644 index 000000000..dc725fc4e --- /dev/null +++ b/libtransport/src/io_modules/forwarder/udp_tunnel.cc @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + */ + +#include <hicn/transport/utils/branch_prediction.h> +#include <io_modules/forwarder/errors.h> +#include <io_modules/forwarder/udp_tunnel.h> + +#include <iostream> +#include <thread> +#include <vector> + +namespace transport { +namespace core { + +UdpTunnelConnector::~UdpTunnelConnector() {} + +void UdpTunnelConnector::connect(const std::string &hostname, uint16_t port, + const std::string &bind_address, + uint16_t bind_port) { + if (state_ == State::CLOSED) { + state_ = State::CONNECTING; + endpoint_iterator_ = resolver_.resolve({hostname, std::to_string(port)}); + remote_endpoint_send_ = *endpoint_iterator_; + socket_->open(remote_endpoint_send_.protocol()); + + if (!bind_address.empty() && bind_port != 0) { + using namespace asio::ip; + socket_->bind( + udp::endpoint(address::from_string(bind_address), bind_port)); + } + + state_ = State::CONNECTED; + + remote_endpoint_ = Endpoint(remote_endpoint_send_); + local_endpoint_ = Endpoint(socket_->local_endpoint()); + + doRecvPacket(); + +#ifdef LINUX + send_timer_.expires_from_now(std::chrono::microseconds(50)); + send_timer_.async_wait(std::bind(&UdpTunnelConnector::writeHandler, this, + std::placeholders::_1)); +#endif + } +} + +void UdpTunnelConnector::send(Packet &packet) { + strand_->post([this, pkt{packet.shared_from_this()}]() { + bool write_in_progress = !output_buffer_.empty(); + output_buffer_.push_back(std::move(pkt)); + if (TRANSPORT_EXPECT_TRUE(state_ == State::CONNECTED)) { + if (!write_in_progress) { + doSendPacket(); + } + } else { + data_available_ = true; + } + }); +} + +void UdpTunnelConnector::send(const uint8_t *packet, std::size_t len) {} + +void UdpTunnelConnector::close() { + TRANSPORT_LOGD("UDPTunnelConnector::close"); + state_ = State::CLOSED; + bool is_socket_owned = socket_.use_count() == 1; + if (is_socket_owned) { + io_service_.dispatch([this]() { + this->socket_->close(); + // on_close_callback_(shared_from_this()); + }); + } +} + +void UdpTunnelConnector::doSendPacket() { +#ifdef LINUX + send_timer_.expires_from_now(std::chrono::microseconds(50)); + send_timer_.async_wait(std::bind(&UdpTunnelConnector::writeHandler, this, + std::placeholders::_1)); +#else + auto packet = output_buffer_.front().get(); + auto array = std::vector<asio::const_buffer>(); + + const ::utils::MemBuf *current = packet; + do { + array.push_back(asio::const_buffer(current->data(), current->length())); + current = current->next(); + } while (current != packet); + + socket_->async_send_to( + std::move(array), remote_endpoint_send_, + strand_->wrap([this](std::error_code ec, std::size_t length) { + if (TRANSPORT_EXPECT_TRUE(!ec)) { + sent_callback_(this, make_error_code(forwarder_error::success)); + } else if (ec.value() == + static_cast<int>(std::errc::operation_canceled)) { + // The connection has been closed by the application. + return; + } else { + sendFailed(); + sent_callback_(this, ec); + } + + output_buffer_.pop_front(); + if (!output_buffer_.empty()) { + doSendPacket(); + } + })); +#endif +} + +#ifdef LINUX +void UdpTunnelConnector::writeHandler(std::error_code ec) { + if (TRANSPORT_EXPECT_FALSE(state_ != State::CONNECTED)) { + return; + } + + auto len = std::min(output_buffer_.size(), std::size_t(Connector::max_burst)); + + if (len) { + int m = 0; + for (auto &p : output_buffer_) { + auto packet = p.get(); + ::utils::MemBuf *current = packet; + int b = 0; + do { + // array.push_back(asio::const_buffer(current->data(), + // current->length())); + tx_iovecs_[m][b].iov_base = current->writableData(); + tx_iovecs_[m][b].iov_len = current->length(); + current = current->next(); + b++; + } while (current != packet); + + tx_msgs_[m].msg_hdr.msg_iov = tx_iovecs_[m]; + tx_msgs_[m].msg_hdr.msg_iovlen = b; + tx_msgs_[m].msg_hdr.msg_name = remote_endpoint_send_.data(); + tx_msgs_[m].msg_hdr.msg_namelen = remote_endpoint_send_.size(); + m++; + + if (--len == 0) { + break; + } + } + + int retval = sendmmsg(socket_->native_handle(), tx_msgs_, m, MSG_DONTWAIT); + if (retval > 0) { + while (retval--) { + output_buffer_.pop_front(); + } + } else if (retval != EWOULDBLOCK && retval != EAGAIN) { + TRANSPORT_LOGE("Error sending messages! %s %d\n", strerror(errno), + retval); + return; + } + } + + if (!output_buffer_.empty()) { + send_timer_.expires_from_now(std::chrono::microseconds(50)); + send_timer_.async_wait(std::bind(&UdpTunnelConnector::writeHandler, this, + std::placeholders::_1)); + } +} + +void UdpTunnelConnector::readHandler(std::error_code ec) { + TRANSPORT_LOGD("UdpTunnelConnector receive packet"); + + // TRANSPORT_LOGD("UdpTunnelConnector received packet length=%lu", length); + if (TRANSPORT_EXPECT_TRUE(!ec)) { + if (TRANSPORT_EXPECT_TRUE(state_ == State::CONNECTED)) { + if (current_position_ == 0) { + for (int i = 0; i < max_burst; i++) { + auto read_buffer = getRawBuffer(); + rx_iovecs_[i][0].iov_base = read_buffer.first; + rx_iovecs_[i][0].iov_len = read_buffer.second; + rx_msgs_[i].msg_hdr.msg_iov = rx_iovecs_[i]; + rx_msgs_[i].msg_hdr.msg_iovlen = 1; + } + } + + int res = recvmmsg(socket_->native_handle(), rx_msgs_ + current_position_, + max_burst - current_position_, MSG_DONTWAIT, nullptr); + if (res < 0) { + TRANSPORT_LOGE("Error receiving messages! %s %d\n", strerror(errno), + res); + return; + } + + for (int i = 0; i < res; i++) { + auto packet = getPacketFromBuffer( + reinterpret_cast<uint8_t *>( + rx_msgs_[current_position_].msg_hdr.msg_iov[0].iov_base), + rx_msgs_[current_position_].msg_len); + receiveSuccess(*packet); + receive_callback_(this, *packet, + make_error_code(forwarder_error::success)); + ++current_position_; + } + + doRecvPacket(); + } else { + TRANSPORT_LOGE( + "Error in UDP: Receiving packets from a not connected socket."); + } + } else if (ec.value() == static_cast<int>(std::errc::operation_canceled)) { + TRANSPORT_LOGE("The connection has been closed by the application."); + return; + } else { + if (TRANSPORT_EXPECT_TRUE(state_ == State::CONNECTED)) { + // receive_callback_(this, *read_msg_, ec); + TRANSPORT_LOGE("Error in UDP connector: %d %s", ec.value(), + ec.message().c_str()); + } else { + TRANSPORT_LOGE("Error while not connector"); + } + } +} +#endif + +void UdpTunnelConnector::doRecvPacket() { +#ifdef LINUX + if (state_ == State::CONNECTED) { +#if ((ASIO_VERSION / 100 % 1000) < 11) + socket_->async_receive(asio::null_buffers(), +#else + socket_->async_wait(asio::ip::tcp::socket::wait_read, +#endif + std::bind(&UdpTunnelConnector::readHandler, this, + std::placeholders::_1)); + } +#else + TRANSPORT_LOGD("UdpTunnelConnector receive packet"); + read_msg_ = getRawBuffer(); + socket_->async_receive_from( + asio::buffer(read_msg_.first, read_msg_.second), remote_endpoint_recv_, + [this](std::error_code ec, std::size_t length) { + TRANSPORT_LOGD("UdpTunnelConnector received packet length=%lu", length); + if (TRANSPORT_EXPECT_TRUE(!ec)) { + if (TRANSPORT_EXPECT_TRUE(state_ == State::CONNECTED)) { + auto packet = getPacketFromBuffer(read_msg_.first, length); + receiveSuccess(*packet); + receive_callback_(this, *packet, + make_error_code(forwarder_error::success)); + doRecvPacket(); + } else { + TRANSPORT_LOGE( + "Error in UDP: Receiving packets from a not connected socket."); + } + } else if (ec.value() == + static_cast<int>(std::errc::operation_canceled)) { + TRANSPORT_LOGE("The connection has been closed by the application."); + return; + } else { + if (TRANSPORT_EXPECT_TRUE(state_ == State::CONNECTED)) { + TRANSPORT_LOGE("Error in UDP connector: %d %s", ec.value(), + ec.message().c_str()); + } else { + TRANSPORT_LOGE("Error while not connector"); + } + } + }); +#endif +} + +void UdpTunnelConnector::doConnect() { + asio::async_connect( + *socket_, endpoint_iterator_, + [this](std::error_code ec, asio::ip::udp::resolver::iterator) { + if (!ec) { + state_ = State::CONNECTED; + doRecvPacket(); + + if (data_available_) { + data_available_ = false; + doSendPacket(); + } + } else { + TRANSPORT_LOGE("[Hproxy] - UDP Connection failed!!!"); + timer_.expires_from_now(std::chrono::milliseconds(500)); + timer_.async_wait(std::bind(&UdpTunnelConnector::doConnect, this)); + } + }); +} + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/udp_tunnel.h b/libtransport/src/io_modules/forwarder/udp_tunnel.h new file mode 100644 index 000000000..df472af91 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/udp_tunnel.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + */ + +#pragma once + +#include <hicn/transport/core/connector.h> +#include <hicn/transport/portability/platform.h> +#include <io_modules/forwarder/errors.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <iostream> +#include <memory> + +namespace transport { +namespace core { + +class UdpTunnelListener; + +class UdpTunnelConnector : public Connector { + friend class UdpTunnelListener; + + public: + template <typename ReceiveCallback, typename SentCallback, typename OnClose, + typename OnReconnect> + UdpTunnelConnector(asio::io_service &io_service, + ReceiveCallback &&receive_callback, + SentCallback &&packet_sent, OnClose &&on_close_callback, + OnReconnect &&on_reconnect) + : Connector(receive_callback, packet_sent, on_close_callback, + on_reconnect), + io_service_(io_service), + strand_(std::make_shared<asio::io_service::strand>(io_service_)), + socket_(std::make_shared<asio::ip::udp::socket>(io_service_)), + resolver_(io_service_), + timer_(io_service_), +#ifdef LINUX + send_timer_(io_service_), + tx_iovecs_{0}, + tx_msgs_{0}, + rx_iovecs_{0}, + rx_msgs_{0}, + current_position_(0), +#else + read_msg_(nullptr, 0), +#endif + data_available_(false) { + } + + template <typename ReceiveCallback, typename SentCallback, typename OnClose, + typename OnReconnect, typename EndpointType> + UdpTunnelConnector(std::shared_ptr<asio::ip::udp::socket> &socket, + std::shared_ptr<asio::io_service::strand> &strand, + ReceiveCallback &&receive_callback, + SentCallback &&packet_sent, OnClose &&on_close_callback, + OnReconnect &&on_reconnect, EndpointType &&remote_endpoint) + : Connector(receive_callback, packet_sent, on_close_callback, + on_reconnect), +#if ((ASIO_VERSION / 100 % 1000) < 12) + io_service_(socket->get_io_service()), +#else + io_service_((asio::io_context &)(socket->get_executor().context())), +#endif + strand_(strand), + socket_(socket), + resolver_(io_service_), + remote_endpoint_send_(std::forward<EndpointType &&>(remote_endpoint)), + timer_(io_service_), +#ifdef LINUX + send_timer_(io_service_), + tx_iovecs_{0}, + tx_msgs_{0}, + rx_iovecs_{0}, + rx_msgs_{0}, + current_position_(0), +#else + read_msg_(nullptr, 0), +#endif + data_available_(false) { + if (socket_->is_open()) { + state_ = State::CONNECTED; + remote_endpoint_ = Endpoint(remote_endpoint_send_); + local_endpoint_ = socket_->local_endpoint(); + } + } + + ~UdpTunnelConnector() override; + + void send(Packet &packet) override; + + void send(const uint8_t *packet, std::size_t len) override; + + void close() override; + + void connect(const std::string &hostname, std::uint16_t port, + const std::string &bind_address = "", + std::uint16_t bind_port = 0); + + auto shared_from_this() { return utils::shared_from(this); } + + private: + void doConnect(); + void doRecvPacket(); + + void doRecvPacket(utils::MemBuf &buffer) { + receive_callback_(this, buffer, make_error_code(forwarder_error::success)); + } + +#ifdef LINUX + void readHandler(std::error_code ec); + void writeHandler(std::error_code ec); +#endif + + void setConnected() { state_ = State::CONNECTED; } + + void doSendPacket(); + void doClose(); + + private: + asio::io_service &io_service_; + std::shared_ptr<asio::io_service::strand> strand_; + std::shared_ptr<asio::ip::udp::socket> socket_; + asio::ip::udp::resolver resolver_; + asio::ip::udp::resolver::iterator endpoint_iterator_; + asio::ip::udp::endpoint remote_endpoint_send_; + asio::ip::udp::endpoint remote_endpoint_recv_; + + asio::steady_timer timer_; + +#ifdef LINUX + asio::steady_timer send_timer_; + struct iovec tx_iovecs_[max_burst][8]; + struct mmsghdr tx_msgs_[max_burst]; + struct iovec rx_iovecs_[max_burst][8]; + struct mmsghdr rx_msgs_[max_burst]; + std::uint8_t current_position_; +#else + std::pair<uint8_t *, std::size_t> read_msg_; +#endif + + bool data_available_; +}; + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/udp_tunnel_listener.cc b/libtransport/src/io_modules/forwarder/udp_tunnel_listener.cc new file mode 100644 index 000000000..12246c3cf --- /dev/null +++ b/libtransport/src/io_modules/forwarder/udp_tunnel_listener.cc @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + */ + +#include <hicn/transport/utils/hash.h> +#include <hicn/transport/utils/log.h> +#include <io_modules/forwarder/udp_tunnel.h> +#include <io_modules/forwarder/udp_tunnel_listener.h> + +#ifndef LINUX +namespace std { +size_t hash<asio::ip::udp::endpoint>::operator()( + const asio::ip::udp::endpoint &endpoint) const { + auto hash_ip = endpoint.address().is_v4() + ? endpoint.address().to_v4().to_ulong() + : utils::hash::fnv32_buf( + endpoint.address().to_v6().to_bytes().data(), 16); + uint16_t port = endpoint.port(); + return utils::hash::fnv32_buf(&port, 2, hash_ip); +} +} // namespace std +#endif + +namespace transport { +namespace core { + +UdpTunnelListener::~UdpTunnelListener() {} + +void UdpTunnelListener::close() { + strand_->post([this]() { + if (socket_->is_open()) { + socket_->close(); + } + }); +} + +#ifdef LINUX +void UdpTunnelListener::readHandler(std::error_code ec) { + TRANSPORT_LOGD("UdpTunnelConnector receive packet"); + + // TRANSPORT_LOGD("UdpTunnelConnector received packet length=%lu", length); + if (TRANSPORT_EXPECT_TRUE(!ec)) { + if (current_position_ == 0) { + for (int i = 0; i < Connector::max_burst; i++) { + auto read_buffer = Connector::getRawBuffer(); + iovecs_[i][0].iov_base = read_buffer.first; + iovecs_[i][0].iov_len = read_buffer.second; + msgs_[i].msg_hdr.msg_iov = iovecs_[i]; + msgs_[i].msg_hdr.msg_iovlen = 1; + msgs_[i].msg_hdr.msg_name = &remote_endpoints_[i]; + msgs_[i].msg_hdr.msg_namelen = sizeof(remote_endpoints_[i]); + } + } + + int res = recvmmsg(socket_->native_handle(), msgs_ + current_position_, + Connector::max_burst - current_position_, MSG_DONTWAIT, + nullptr); + if (res < 0) { + TRANSPORT_LOGE("Error in recvmmsg."); + } + + for (int i = 0; i < res; i++) { + auto packet = Connector::getPacketFromBuffer( + reinterpret_cast<uint8_t *>( + msgs_[current_position_].msg_hdr.msg_iov[0].iov_base), + msgs_[current_position_].msg_len); + auto connector_id = + utils::hash::fnv64_buf(msgs_[current_position_].msg_hdr.msg_name, + msgs_[current_position_].msg_hdr.msg_namelen); + + auto connector = connectors_.find(connector_id); + if (connector == connectors_.end()) { + // Create new connector corresponding to new client + + /* + * Get the remote endpoint for this particular message + */ + using namespace asio::ip; + if (local_endpoint_.address().is_v4()) { + auto addr = reinterpret_cast<struct sockaddr_in *>( + &remote_endpoints_[current_position_]); + address_v4::bytes_type address_bytes; + std::copy_n(reinterpret_cast<uint8_t *>(&addr->sin_addr), + address_bytes.size(), address_bytes.begin()); + address_v4 address(address_bytes); + remote_endpoint_ = udp::endpoint(address, ntohs(addr->sin_port)); + } else { + auto addr = reinterpret_cast<struct sockaddr_in6 *>( + &remote_endpoints_[current_position_]); + address_v6::bytes_type address_bytes; + std::copy_n(reinterpret_cast<uint8_t *>(&addr->sin6_addr), + address_bytes.size(), address_bytes.begin()); + address_v6 address(address_bytes); + remote_endpoint_ = udp::endpoint(address, ntohs(addr->sin6_port)); + } + + /** + * Create new connector sharing the same socket of this listener. + */ + auto ret = connectors_.emplace( + connector_id, + std::make_shared<UdpTunnelConnector>( + socket_, strand_, receive_callback_, + [](Connector *, const std::error_code &) {}, [](Connector *) {}, + [](Connector *) {}, std::move(remote_endpoint_))); + connector = ret.first; + connector->second->setConnectorId(connector_id); + } + + /** + * Use connector callback to process incoming message. + */ + UdpTunnelConnector *c = + dynamic_cast<UdpTunnelConnector *>(connector->second.get()); + c->doRecvPacket(*packet); + + ++current_position_; + } + + doRecvPacket(); + } else if (ec.value() == static_cast<int>(std::errc::operation_canceled)) { + TRANSPORT_LOGE("The connection has been closed by the application."); + return; + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + } +} +#endif + +void UdpTunnelListener::doRecvPacket() { +#ifdef LINUX +#if ((ASIO_VERSION / 100 % 1000) < 11) + socket_->async_receive( + asio::null_buffers(), +#else + socket_->async_wait( + asio::ip::tcp::socket::wait_read, +#endif + std::bind(&UdpTunnelListener::readHandler, this, std::placeholders::_1)); +#else + read_msg_ = Connector::getRawBuffer(); + socket_->async_receive_from( + asio::buffer(read_msg_.first, read_msg_.second), remote_endpoint_, + [this](std::error_code ec, std::size_t length) { + if (TRANSPORT_EXPECT_TRUE(!ec)) { + auto packet = Connector::getPacketFromBuffer(read_msg_.first, length); + auto connector_id = + std::hash<asio::ip::udp::endpoint>{}(remote_endpoint_); + auto connector = connectors_.find(connector_id); + if (connector == connectors_.end()) { + // Create new connector corresponding to new client + auto ret = connectors_.emplace( + connector_id, std::make_shared<UdpTunnelConnector>( + socket_, strand_, receive_callback_, + [](Connector *, const std::error_code &) {}, + [](Connector *) {}, [](Connector *) {}, + std::move(remote_endpoint_))); + connector = ret.first; + connector->second->setConnectorId(connector_id); + } + + UdpTunnelConnector *c = + dynamic_cast<UdpTunnelConnector *>(connector->second.get()); + c->doRecvPacket(*packet); + doRecvPacket(); + } else if (ec.value() == + static_cast<int>(std::errc::operation_canceled)) { + TRANSPORT_LOGE("The connection has been closed by the application."); + return; + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + } + }); +#endif +} +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/forwarder/udp_tunnel_listener.h b/libtransport/src/io_modules/forwarder/udp_tunnel_listener.h new file mode 100644 index 000000000..0ee40a400 --- /dev/null +++ b/libtransport/src/io_modules/forwarder/udp_tunnel_listener.h @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + */ + +#pragma once + +#include <hicn/transport/core/connector.h> +#include <hicn/transport/portability/platform.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <unordered_map> + +namespace std { +template <> +struct hash<asio::ip::udp::endpoint> { + size_t operator()(const asio::ip::udp::endpoint &endpoint) const; +}; +} // namespace std + +namespace transport { +namespace core { + +class UdpTunnelListener + : public std::enable_shared_from_this<UdpTunnelListener> { + using PacketReceivedCallback = Connector::PacketReceivedCallback; + using EndpointId = std::pair<uint32_t, uint16_t>; + + static constexpr uint16_t default_port = 5004; + + public: + using Ptr = std::shared_ptr<UdpTunnelListener>; + + template <typename ReceiveCallback> + UdpTunnelListener(asio::io_service &io_service, + ReceiveCallback &&receive_callback, + asio::ip::udp::endpoint endpoint = asio::ip::udp::endpoint( + asio::ip::udp::v4(), default_port)) + : io_service_(io_service), + strand_(std::make_shared<asio::io_service::strand>(io_service_)), + socket_(std::make_shared<asio::ip::udp::socket>(io_service_, + endpoint.protocol())), + local_endpoint_(endpoint), + receive_callback_(std::forward<ReceiveCallback &&>(receive_callback)), +#ifndef LINUX + read_msg_(nullptr, 0) +#else + iovecs_{0}, + msgs_{0}, + current_position_(0) +#endif + { + if (endpoint.protocol() == asio::ip::udp::v6()) { + std::error_code ec; + socket_->set_option(asio::ip::v6_only(false), ec); + // Call succeeds only on dual stack systems. + } + socket_->bind(local_endpoint_); + io_service_.post(std::bind(&UdpTunnelListener::doRecvPacket, this)); + } + + ~UdpTunnelListener(); + + void close(); + + int deleteConnector(Connector *connector) { + return connectors_.erase(connector->getConnectorId()); + } + + template <typename ReceiveCallback> + void setReceiveCallback(ReceiveCallback &&callback) { + receive_callback_ = std::forward<ReceiveCallback &&>(callback); + } + + Connector *findConnector(Connector::Id connId) { + auto it = connectors_.find(connId); + if (it != connectors_.end()) { + return it->second.get(); + } + + return nullptr; + } + + private: + void doRecvPacket(); + + void readHandler(std::error_code ec); + + asio::io_service &io_service_; + std::shared_ptr<asio::io_service::strand> strand_; + std::shared_ptr<asio::ip::udp::socket> socket_; + asio::ip::udp::endpoint local_endpoint_; + asio::ip::udp::endpoint remote_endpoint_; + std::unordered_map<Connector::Id, std::shared_ptr<Connector>> connectors_; + + PacketReceivedCallback receive_callback_; + +#ifdef LINUX + struct iovec iovecs_[Connector::max_burst][8]; + struct mmsghdr msgs_[Connector::max_burst]; + struct sockaddr_storage remote_endpoints_[Connector::max_burst]; + std::uint8_t current_position_; +#else + std::pair<uint8_t *, std::size_t> read_msg_; +#endif +}; + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/loopback/CMakeLists.txt b/libtransport/src/io_modules/loopback/CMakeLists.txt new file mode 100644 index 000000000..ac6dc8068 --- /dev/null +++ b/libtransport/src/io_modules/loopback/CMakeLists.txt @@ -0,0 +1,34 @@ +# Copyright (c) 2021 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + + +list(APPEND MODULE_HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/loopback_module.h +) + +list(APPEND MODULE_SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/loopback_module.cc +) + +build_module(loopback_module + SHARED + SOURCES ${MODULE_SOURCE_FILES} + DEPENDS ${DEPENDENCIES} + COMPONENT lib${LIBTRANSPORT} + INCLUDE_DIRS ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} + # LIBRARY_ROOT_DIR "vpp_plugins" + DEFINITIONS ${COMPILER_DEFINITIONS} + COMPILE_OPTIONS ${COMPILE_FLAGS} +) diff --git a/libtransport/src/io_modules/loopback/local_face.cc b/libtransport/src/io_modules/loopback/local_face.cc new file mode 100644 index 000000000..a59dab235 --- /dev/null +++ b/libtransport/src/io_modules/loopback/local_face.cc @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/interest.h> +#include <hicn/transport/utils/log.h> +#include <io_modules/loopback/local_face.h> + +#include <asio/io_service.hpp> + +namespace transport { +namespace core { + +Face::Face(Connector::PacketReceivedCallback &&receive_callback, + asio::io_service &io_service, const std::string &app_name) + : receive_callback_(std::move(receive_callback)), + io_service_(io_service), + name_(app_name) {} + +Face::Face(const Face &other) + : receive_callback_(other.receive_callback_), + io_service_(other.io_service_), + name_(other.name_) {} + +Face::Face(Face &&other) + : receive_callback_(std::move(other.receive_callback_)), + io_service_(other.io_service_), + name_(std::move(other.name_)) {} + +Face &Face::operator=(const Face &other) { + receive_callback_ = other.receive_callback_; + io_service_ = other.io_service_; + name_ = other.name_; + + return *this; +} + +Face &Face::operator=(Face &&other) { + receive_callback_ = std::move(other.receive_callback_); + io_service_ = std::move(other.io_service_); + name_ = std::move(other.name_); + + return *this; +} + +void Face::onPacket(const Packet &packet) { + TRANSPORT_LOGD("Sending content to local socket."); + + if (Packet::isInterest(packet.data())) { + rescheduleOnIoService<Interest>(packet); + } else { + rescheduleOnIoService<ContentObject>(packet); + } +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/io_modules/loopback/local_face.h b/libtransport/src/io_modules/loopback/local_face.h new file mode 100644 index 000000000..1cbcc2c72 --- /dev/null +++ b/libtransport/src/io_modules/loopback/local_face.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/connector.h> +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/utils/move_wrapper.h> + +#include <asio/io_service.hpp> + +namespace transport { +namespace core { + +class Face { + public: + Face(Connector::PacketReceivedCallback &&receive_callback, + asio::io_service &io_service, const std::string &app_name); + + Face(const Face &other); + Face(Face &&other); + void onPacket(const Packet &packet); + Face &operator=(Face &&other); + Face &operator=(const Face &other); + + private: + template <typename T> + void rescheduleOnIoService(const Packet &packet) { + auto p = core::PacketManager<T>::getInstance().getPacket(); + p->replace(packet.data(), packet.length()); + io_service_.get().post([this, p]() mutable { + receive_callback_(nullptr, *p, make_error_code(0)); + }); + } + + Connector::PacketReceivedCallback receive_callback_; + std::reference_wrapper<asio::io_service> io_service_; + std::string name_; +}; + +} // namespace core +} // namespace transport diff --git a/libtransport/src/io_modules/loopback/loopback_module.cc b/libtransport/src/io_modules/loopback/loopback_module.cc new file mode 100644 index 000000000..0bdbf8c8e --- /dev/null +++ b/libtransport/src/io_modules/loopback/loopback_module.cc @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/errors/not_implemented_exception.h> +#include <hicn/transport/utils/log.h> +#include <io_modules/loopback/loopback_module.h> + +namespace transport { + +namespace core { + +std::vector<std::unique_ptr<LocalConnector>> LoopbackModule::local_faces_; +std::atomic<uint32_t> LoopbackModule::global_counter_(0); + +LoopbackModule::LoopbackModule() : IoModule(), local_id_(~0) {} + +LoopbackModule::~LoopbackModule() {} + +void LoopbackModule::connect(bool is_consumer) {} + +bool LoopbackModule::isConnected() { return true; } + +void LoopbackModule::send(Packet &packet) { + IoModule::send(packet); + + TRANSPORT_LOGD("LoopbackModule: sending from %u to %d", local_id_, + 1 - local_id_); + + local_faces_.at(1 - local_id_)->send(packet); +} + +void LoopbackModule::send(const uint8_t *packet, std::size_t len) { + // not supported + throw errors::NotImplementedException(); +} + +void LoopbackModule::registerRoute(const Prefix &prefix) { + // For the moment we route packets from one socket to the other. + // Next step will be to introduce a FIB + return; +} + +void LoopbackModule::closeConnection() { + local_faces_.erase(local_faces_.begin() + local_id_); +} + +void LoopbackModule::init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name) { + if (local_id_ == uint32_t(~0) && global_counter_ < 2) { + local_id_ = global_counter_++; + local_faces_.emplace( + local_faces_.begin() + local_id_, + new LocalConnector(io_service, std::move(receive_callback), nullptr, + nullptr, std::move(reconnect_callback))); + } +} + +void LoopbackModule::processControlMessageReply(utils::MemBuf &packet_buffer) { + return; +} + +std::uint32_t LoopbackModule::getMtu() { return interface_mtu; } + +bool LoopbackModule::isControlMessage(const uint8_t *message) { return false; } + +extern "C" IoModule *create_module(void) { return new LoopbackModule(); } + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/loopback/loopback_module.h b/libtransport/src/io_modules/loopback/loopback_module.h new file mode 100644 index 000000000..219fa8841 --- /dev/null +++ b/libtransport/src/io_modules/loopback/loopback_module.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <core/local_connector.h> +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/core/prefix.h> + +#include <atomic> + +namespace transport { + +namespace core { + +class LoopbackModule : public IoModule { + static constexpr std::uint16_t interface_mtu = 1500; + + public: + LoopbackModule(); + + ~LoopbackModule(); + + void connect(bool is_consumer) override; + + void send(Packet &packet) override; + void send(const uint8_t *packet, std::size_t len) override; + + bool isConnected() override; + + void init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name = "Libtransport") override; + + void registerRoute(const Prefix &prefix) override; + + std::uint32_t getMtu() override; + + bool isControlMessage(const uint8_t *message) override; + + void processControlMessageReply(utils::MemBuf &packet_buffer) override; + + void closeConnection() override; + + private: + static std::vector<std::unique_ptr<LocalConnector>> local_faces_; + static std::atomic<uint32_t> global_counter_; + + private: + uint32_t local_id_; +}; + +extern "C" IoModule *create_module(void); + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/memif/CMakeLists.txt b/libtransport/src/io_modules/memif/CMakeLists.txt new file mode 100644 index 000000000..c8a930e7b --- /dev/null +++ b/libtransport/src/io_modules/memif/CMakeLists.txt @@ -0,0 +1,56 @@ +# Copyright (c) 2021 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + +find_package(Vpp REQUIRED) +find_package(Libmemif REQUIRED) + +if(CMAKE_SOURCE_DIR STREQUAL PROJECT_SOURCE_DIR) + find_package(HicnPlugin REQUIRED) + find_package(SafeVapi REQUIRED) +else() + list(APPEND DEPENDENCIES + ${SAFE_VAPI_SHARED} + ) +endif() + +list(APPEND MODULE_HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/hicn_vapi.h + ${CMAKE_CURRENT_SOURCE_DIR}/memif_connector.h + ${CMAKE_CURRENT_SOURCE_DIR}/memif_vapi.h + ${CMAKE_CURRENT_SOURCE_DIR}/vpp_forwarder_module.h +) + +list(APPEND MODULE_SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/hicn_vapi.c + ${CMAKE_CURRENT_SOURCE_DIR}/memif_connector.cc + ${CMAKE_CURRENT_SOURCE_DIR}/memif_vapi.c + ${CMAKE_CURRENT_SOURCE_DIR}/vpp_forwarder_module.cc +) + +build_module(memif_module + SHARED + SOURCES ${MODULE_SOURCE_FILES} + DEPENDS ${DEPENDENCIES} + COMPONENT lib${LIBTRANSPORT} + LINK_LIBRARIES ${LIBMEMIF_LIBRARIES} ${SAFE_VAPI_LIBRARIES} + INCLUDE_DIRS + ${LIBTRANSPORT_INCLUDE_DIRS} + ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} + ${VPP_INCLUDE_DIRS} + ${LIBMEMIF_INCLUDE_DIRS} + ${SAFE_VAPI_INCLUDE_DIRS} + DEFINITIONS ${COMPILER_DEFINITIONS} + COMPILE_OPTIONS ${COMPILE_FLAGS} +) diff --git a/libtransport/src/io_modules/memif/hicn_vapi.c b/libtransport/src/io_modules/memif/hicn_vapi.c new file mode 100644 index 000000000..b83a36b47 --- /dev/null +++ b/libtransport/src/io_modules/memif/hicn_vapi.c @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/config.h> +#include <hicn/transport/utils/log.h> +#include <io_modules/memif/hicn_vapi.h> + +#define HICN_VPP_PLUGIN +#include <hicn/name.h> +#undef HICN_VPP_PLUGIN + +#include <vapi/hicn.api.vapi.h> +#include <vapi/ip.api.vapi.h> +#include <vapi/vapi_safe.h> +#include <vlib/vlib.h> +#include <vlibapi/api.h> +#include <vlibmemory/api.h> +#include <vnet/ip/format.h> +#include <vnet/ip/ip4_packet.h> +#include <vnet/ip/ip6_packet.h> +#include <vpp_plugins/hicn/error.h> +#include <vppinfra/error.h> + +///////////////////////////////////////////////////// +const char *HICN_ERROR_STRING[] = { +#define _(a, b, c) c, + foreach_hicn_error +#undef _ +}; +///////////////////////////////////////////////////// + +/*********************** Missing Symbol in vpp libraries + * *************************/ +u8 *format_vl_api_address_union(u8 *s, va_list *args) { return NULL; } + +/*********************************************************************************/ + +DEFINE_VAPI_MSG_IDS_HICN_API_JSON +DEFINE_VAPI_MSG_IDS_IP_API_JSON + +static vapi_error_e register_prod_app_cb( + vapi_ctx_t ctx, void *callback_ctx, vapi_error_e rv, bool is_last, + vapi_payload_hicn_api_register_prod_app_reply *reply) { + hicn_producer_output_params *output_params = + (hicn_producer_output_params *)callback_ctx; + + if (reply == NULL) return rv; + + output_params->cs_reserved = reply->cs_reserved; + output_params->prod_addr = (ip_address_t *)malloc(sizeof(ip_address_t)); + memset(output_params->prod_addr, 0, sizeof(ip_address_t)); + if (reply->prod_addr.af == ADDRESS_IP6) + memcpy(&output_params->prod_addr->v6, reply->prod_addr.un.ip6, + sizeof(ip6_address_t)); + else + memcpy(&output_params->prod_addr->v4, reply->prod_addr.un.ip4, + sizeof(ip4_address_t)); + output_params->face_id = reply->faceid; + + return reply->retval; +} + +int hicn_vapi_register_prod_app(vapi_ctx_t ctx, + hicn_producer_input_params *input_params, + hicn_producer_output_params *output_params) { + vapi_lock(); + vapi_msg_hicn_api_register_prod_app *msg = + vapi_alloc_hicn_api_register_prod_app(ctx); + + if (ip46_address_is_ip4((ip46_address_t *)&input_params->prefix->address)) { + memcpy(&msg->payload.prefix.address.un.ip4, &input_params->prefix->address, + sizeof(ip4_address_t)); + msg->payload.prefix.address.af = ADDRESS_IP4; + } else { + memcpy(&msg->payload.prefix.address.un.ip6, &input_params->prefix->address, + sizeof(ip6_address_t)); + msg->payload.prefix.address.af = ADDRESS_IP6; + } + msg->payload.prefix.len = input_params->prefix->len; + + msg->payload.swif = input_params->swif; + msg->payload.cs_reserved = input_params->cs_reserved; + + int ret = vapi_hicn_api_register_prod_app(ctx, msg, register_prod_app_cb, + output_params); + vapi_unlock(); + return ret; +} + +static vapi_error_e face_prod_del_cb( + vapi_ctx_t ctx, void *callback_ctx, vapi_error_e rv, bool is_last, + vapi_payload_hicn_api_face_prod_del_reply *reply) { + if (reply == NULL) return rv; + + return reply->retval; +} + +int hicn_vapi_face_prod_del(vapi_ctx_t ctx, + hicn_del_face_app_input_params *input_params) { + vapi_lock(); + vapi_msg_hicn_api_face_prod_del *msg = vapi_alloc_hicn_api_face_prod_del(ctx); + + msg->payload.faceid = input_params->face_id; + + int ret = vapi_hicn_api_face_prod_del(ctx, msg, face_prod_del_cb, NULL); + vapi_unlock(); + return ret; +} + +static vapi_error_e register_cons_app_cb( + vapi_ctx_t ctx, void *callback_ctx, vapi_error_e rv, bool is_last, + vapi_payload_hicn_api_register_cons_app_reply *reply) { + hicn_consumer_output_params *output_params = + (hicn_consumer_output_params *)callback_ctx; + + if (reply == NULL) return rv; + + output_params->src6 = (ip_address_t *)malloc(sizeof(ip_address_t)); + output_params->src4 = (ip_address_t *)malloc(sizeof(ip_address_t)); + memset(output_params->src6, 0, sizeof(ip_address_t)); + memset(output_params->src4, 0, sizeof(ip_address_t)); + memcpy(&output_params->src6->v6, &reply->src_addr6.un.ip6, + sizeof(ip6_address_t)); + memcpy(&output_params->src4->v4, &reply->src_addr4.un.ip4, + sizeof(ip4_address_t)); + + output_params->face_id1 = reply->faceid1; + output_params->face_id2 = reply->faceid2; + + return reply->retval; +} + +int hicn_vapi_register_cons_app(vapi_ctx_t ctx, + hicn_consumer_input_params *input_params, + hicn_consumer_output_params *output_params) { + vapi_lock(); + vapi_msg_hicn_api_register_cons_app *msg = + vapi_alloc_hicn_api_register_cons_app(ctx); + + msg->payload.swif = input_params->swif; + + int ret = vapi_hicn_api_register_cons_app(ctx, msg, register_cons_app_cb, + output_params); + vapi_unlock(); + return ret; +} + +static vapi_error_e face_cons_del_cb( + vapi_ctx_t ctx, void *callback_ctx, vapi_error_e rv, bool is_last, + vapi_payload_hicn_api_face_cons_del_reply *reply) { + if (reply == NULL) return rv; + + return reply->retval; +} + +int hicn_vapi_face_cons_del(vapi_ctx_t ctx, + hicn_del_face_app_input_params *input_params) { + vapi_lock(); + vapi_msg_hicn_api_face_cons_del *msg = vapi_alloc_hicn_api_face_cons_del(ctx); + + msg->payload.faceid = input_params->face_id; + + int ret = vapi_hicn_api_face_cons_del(ctx, msg, face_cons_del_cb, NULL); + vapi_unlock(); + return ret; +} + +static vapi_error_e reigster_route_cb( + vapi_ctx_t ctx, void *callback_ctx, vapi_error_e rv, bool is_last, + vapi_payload_ip_route_add_del_reply *reply) { + if (reply == NULL) return rv; + + return reply->retval; +} + +int hicn_vapi_register_route(vapi_ctx_t ctx, + hicn_producer_set_route_params *input_params) { + vapi_lock(); + vapi_msg_ip_route_add_del *msg = vapi_alloc_ip_route_add_del(ctx, 1); + + msg->payload.is_add = 1; + if (ip46_address_is_ip4((ip46_address_t *)(input_params->prod_addr))) { + memcpy(&msg->payload.route.prefix.address.un.ip4, + &input_params->prefix->address.v4, sizeof(ip4_address_t)); + msg->payload.route.prefix.address.af = ADDRESS_IP4; + msg->payload.route.prefix.len = input_params->prefix->len; + } else { + memcpy(&msg->payload.route.prefix.address.un.ip6, + &input_params->prefix->address.v6, sizeof(ip6_address_t)); + msg->payload.route.prefix.address.af = ADDRESS_IP6; + msg->payload.route.prefix.len = input_params->prefix->len; + } + + msg->payload.route.paths[0].sw_if_index = ~0; + msg->payload.route.paths[0].table_id = 0; + if (ip46_address_is_ip4((ip46_address_t *)(input_params->prod_addr))) { + memcpy(&(msg->payload.route.paths[0].nh.address.ip4), + input_params->prod_addr->v4.as_u8, sizeof(ip4_address_t)); + msg->payload.route.paths[0].proto = FIB_API_PATH_NH_PROTO_IP4; + } else { + memcpy(&(msg->payload.route.paths[0].nh.address.ip6), + input_params->prod_addr->v6.as_u8, sizeof(ip6_address_t)); + msg->payload.route.paths[0].proto = FIB_API_PATH_NH_PROTO_IP6; + } + + msg->payload.route.paths[0].type = FIB_API_PATH_FLAG_NONE; + msg->payload.route.paths[0].flags = FIB_API_PATH_FLAG_NONE; + + int ret = vapi_ip_route_add_del(ctx, msg, reigster_route_cb, NULL); + + vapi_unlock(); + return ret; +} + +char *hicn_vapi_get_error_string(int ret_val) { + return get_error_string(ret_val); +} diff --git a/libtransport/src/io_modules/memif/hicn_vapi.h b/libtransport/src/io_modules/memif/hicn_vapi.h new file mode 100644 index 000000000..e94c97749 --- /dev/null +++ b/libtransport/src/io_modules/memif/hicn_vapi.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/config.h> +#include <hicn/util/ip_address.h> + +#ifdef __cplusplus +extern "C" { +#endif + +#include <vapi/vapi.h> + +#include "stdint.h" + +typedef struct { + ip_prefix_t* prefix; + uint32_t swif; + uint32_t cs_reserved; +} hicn_producer_input_params; + +typedef struct { + uint32_t swif; +} hicn_consumer_input_params; + +typedef struct { + uint32_t face_id; +} hicn_del_face_app_input_params; + +typedef struct { + uint32_t cs_reserved; + ip_address_t* prod_addr; + uint32_t face_id; +} hicn_producer_output_params; + +typedef struct { + ip_address_t* src4; + ip_address_t* src6; + uint32_t face_id1; + uint32_t face_id2; +} hicn_consumer_output_params; + +typedef struct { + ip_prefix_t* prefix; + ip_address_t* prod_addr; +} hicn_producer_set_route_params; + +int hicn_vapi_register_prod_app(vapi_ctx_t ctx, + hicn_producer_input_params* input_params, + hicn_producer_output_params* output_params); + +int hicn_vapi_register_cons_app(vapi_ctx_t ctx, + hicn_consumer_input_params* input_params, + hicn_consumer_output_params* output_params); + +int hicn_vapi_register_route(vapi_ctx_t ctx, + hicn_producer_set_route_params* input_params); + +int hicn_vapi_face_cons_del(vapi_ctx_t ctx, + hicn_del_face_app_input_params* input_params); + +int hicn_vapi_face_prod_del(vapi_ctx_t ctx, + hicn_del_face_app_input_params* input_params); + +char* hicn_vapi_get_error_string(int ret_val); + +#ifdef __cplusplus +} +#endif diff --git a/libtransport/src/io_modules/memif/memif_connector.cc b/libtransport/src/io_modules/memif/memif_connector.cc new file mode 100644 index 000000000..4a688d68f --- /dev/null +++ b/libtransport/src/io_modules/memif/memif_connector.cc @@ -0,0 +1,493 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/errors/not_implemented_exception.h> +#include <io_modules/memif/memif_connector.h> +#include <sys/epoll.h> + +#include <cstdlib> + +extern "C" { +#include <memif/libmemif.h> +}; + +#define CANCEL_TIMER 1 + +namespace transport { + +namespace core { + +struct memif_connection { + uint16_t index; + /* memif conenction handle */ + memif_conn_handle_t conn; + /* transmit queue id */ + uint16_t tx_qid; + /* tx buffers */ + memif_buffer_t *tx_bufs; + /* allocated tx buffers counter */ + /* number of tx buffers pointing to shared memory */ + uint16_t tx_buf_num; + /* rx buffers */ + memif_buffer_t *rx_bufs; + /* allcoated rx buffers counter */ + /* number of rx buffers pointing to shared memory */ + uint16_t rx_buf_num; + /* interface ip address */ + uint8_t ip_addr[4]; +}; + +std::once_flag MemifConnector::flag_; +utils::EpollEventReactor MemifConnector::main_event_reactor_; + +MemifConnector::MemifConnector(PacketReceivedCallback &&receive_callback, + PacketSentCallback &&packet_sent, + OnCloseCallback &&close_callback, + OnReconnectCallback &&on_reconnect, + asio::io_service &io_service, + std::string app_name) + : Connector(std::move(receive_callback), std::move(packet_sent), + std::move(close_callback), std::move(on_reconnect)), + memif_worker_(nullptr), + timer_set_(false), + send_timer_(std::make_unique<utils::FdDeadlineTimer>(event_reactor_)), + disconnect_timer_( + std::make_unique<utils::FdDeadlineTimer>(event_reactor_)), + io_service_(io_service), + memif_connection_(std::make_unique<memif_connection_t>()), + tx_buf_counter_(0), + is_reconnection_(false), + data_available_(false), + app_name_(app_name), + socket_filename_("") { + std::call_once(MemifConnector::flag_, &MemifConnector::init, this); +} + +MemifConnector::~MemifConnector() { close(); } + +void MemifConnector::init() { + /* initialize memory interface */ + int err = memif_init(controlFdUpdate, const_cast<char *>(app_name_.c_str()), + nullptr, nullptr, nullptr); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_init: %s", memif_strerror(err)); + } +} + +void MemifConnector::connect(uint32_t memif_id, long memif_mode) { + state_ = State::CONNECTING; + + memif_id_ = memif_id; + socket_filename_ = "/run/vpp/memif.sock"; + + createMemif(memif_id, memif_mode, nullptr); + + work_ = std::make_unique<asio::io_service::work>(io_service_); + + while (state_ != State::CONNECTED) { + MemifConnector::main_event_reactor_.runOneEvent(); + } + + int err; + + /* get interrupt queue id */ + int fd = -1; + err = memif_get_queue_efd(memif_connection_->conn, 0, &fd); + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_get_queue_efd: %s", memif_strerror(err)); + return; + } + + // Remove fd from main epoll + main_event_reactor_.delFileDescriptor(fd); + + // Add fd to epoll of instance + event_reactor_.addFileDescriptor( + fd, EPOLLIN, [this](const utils::Event &evt) -> int { + return onInterrupt(memif_connection_->conn, this, 0); + }); + + memif_worker_ = std::make_unique<std::thread>( + std::bind(&MemifConnector::threadMain, this)); +} + +int MemifConnector::createMemif(uint32_t index, uint8_t mode, char *s) { + memif_connection_t *c = memif_connection_.get(); + + /* setting memif connection arguments */ + memif_conn_args_t args; + memset(&args, 0, sizeof(args)); + + args.is_master = mode; + args.log2_ring_size = MEMIF_LOG2_RING_SIZE; + args.buffer_size = MEMIF_BUF_SIZE; + args.num_s2m_rings = 1; + args.num_m2s_rings = 1; + strncpy((char *)args.interface_name, IF_NAME, strlen(IF_NAME) + 1); + args.mode = memif_interface_mode_t::MEMIF_INTERFACE_MODE_IP; + + int err; + + err = memif_create_socket(&args.socket, socket_filename_.c_str(), nullptr); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + throw errors::RuntimeException(memif_strerror(err)); + } + + args.interface_id = index; + /* last argument for memif_create (void * private_ctx) is used by user + to identify connection. this context is returned with callbacks */ + + /* default interrupt */ + if (s == nullptr) { + err = memif_create(&c->conn, &args, onConnect, onDisconnect, onInterrupt, + this); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + throw errors::RuntimeException(memif_strerror(err)); + } + } + + c->index = (uint16_t)index; + c->tx_qid = 0; + /* alloc memif buffers */ + c->rx_buf_num = 0; + c->rx_bufs = static_cast<memif_buffer_t *>( + malloc(sizeof(memif_buffer_t) * MAX_MEMIF_BUFS)); + c->tx_buf_num = 0; + c->tx_bufs = static_cast<memif_buffer_t *>( + malloc(sizeof(memif_buffer_t) * MAX_MEMIF_BUFS)); + + // memif_set_rx_mode (c->conn, MEMIF_RX_MODE_POLLING, 0); + + return 0; +} + +int MemifConnector::deleteMemif() { + memif_connection_t *c = memif_connection_.get(); + + if (c->rx_bufs) { + free(c->rx_bufs); + } + + c->rx_bufs = nullptr; + c->rx_buf_num = 0; + + if (c->tx_bufs) { + free(c->tx_bufs); + } + + c->tx_bufs = nullptr; + c->tx_buf_num = 0; + + int err; + /* disconenct then delete memif connection */ + err = memif_delete(&c->conn); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_delete: %s", memif_strerror(err)); + } + + if (TRANSPORT_EXPECT_FALSE(c->conn != nullptr)) { + TRANSPORT_LOGE("memif delete fail"); + } + + return 0; +} + +int MemifConnector::controlFdUpdate(int fd, uint8_t events, void *private_ctx) { + /* convert memif event definitions to epoll events */ + if (events & MEMIF_FD_EVENT_DEL) { + return MemifConnector::main_event_reactor_.delFileDescriptor(fd); + } + + uint32_t evt = 0; + + if (events & MEMIF_FD_EVENT_READ) { + evt |= EPOLLIN; + } + + if (events & MEMIF_FD_EVENT_WRITE) { + evt |= EPOLLOUT; + } + + if (events & MEMIF_FD_EVENT_MOD) { + return MemifConnector::main_event_reactor_.modFileDescriptor(fd, evt); + } + + return MemifConnector::main_event_reactor_.addFileDescriptor( + fd, evt, [](const utils::Event &evt) -> int { + uint32_t event = 0; + int memif_err = 0; + + if (evt.events & EPOLLIN) { + event |= MEMIF_FD_EVENT_READ; + } + + if (evt.events & EPOLLOUT) { + event |= MEMIF_FD_EVENT_WRITE; + } + + if (evt.events & EPOLLERR) { + event |= MEMIF_FD_EVENT_ERROR; + } + + memif_err = memif_control_fd_handler(evt.data.fd, event); + + if (TRANSPORT_EXPECT_FALSE(memif_err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_control_fd_handler: %s", + memif_strerror(memif_err)); + } + + return 0; + }); +} + +int MemifConnector::bufferAlloc(long n, uint16_t qid) { + memif_connection_t *c = memif_connection_.get(); + int err; + uint16_t r; + /* set data pointer to shared memory and set buffer_len to shared mmeory + * buffer len */ + err = memif_buffer_alloc(c->conn, qid, c->tx_bufs, n, &r, 2000); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_buffer_alloc: %s", memif_strerror(err)); + return -1; + } + + c->tx_buf_num += r; + return r; +} + +int MemifConnector::txBurst(uint16_t qid) { + memif_connection_t *c = memif_connection_.get(); + int err; + uint16_t r; + /* inform peer memif interface about data in shared memory buffers */ + /* mark memif buffers as free */ + err = memif_tx_burst(c->conn, qid, c->tx_bufs, c->tx_buf_num, &r); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_tx_burst: %s", memif_strerror(err)); + } + + // err = memif_refill_queue(c->conn, qid, r, 0); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_tx_burst: %s", memif_strerror(err)); + c->tx_buf_num -= r; + return -1; + } + + c->tx_buf_num -= r; + return 0; +} + +void MemifConnector::sendCallback(const std::error_code &ec) { + timer_set_ = false; + + if (TRANSPORT_EXPECT_TRUE(!ec && state_ == State::CONNECTED)) { + doSend(); + } +} + +void MemifConnector::processInputBuffer(std::uint16_t total_packets) { + utils::MemBuf::Ptr ptr; + + for (; total_packets > 0; total_packets--) { + if (input_buffer_.pop(ptr)) { + receive_callback_(this, *ptr, std::make_error_code(std::errc(0))); + } + } +} + +/* informs user about connected status. private_ctx is used by user to identify + connection (multiple connections WIP) */ +int MemifConnector::onConnect(memif_conn_handle_t conn, void *private_ctx) { + MemifConnector *connector = (MemifConnector *)private_ctx; + connector->state_ = State::CONNECTED; + memif_refill_queue(conn, 0, -1, 0); + + return 0; +} + +/* informs user about disconnected status. private_ctx is used by user to + identify connection (multiple connections WIP) */ +int MemifConnector::onDisconnect(memif_conn_handle_t conn, void *private_ctx) { + MemifConnector *connector = (MemifConnector *)private_ctx; + connector->state_ = State::CLOSED; + return 0; +} + +void MemifConnector::threadMain() { event_reactor_.runEventLoop(1000); } + +int MemifConnector::onInterrupt(memif_conn_handle_t conn, void *private_ctx, + uint16_t qid) { + MemifConnector *connector = (MemifConnector *)private_ctx; + + memif_connection_t *c = connector->memif_connection_.get(); + int err = MEMIF_ERR_SUCCESS, ret_val; + uint16_t total_packets = 0; + uint16_t rx; + + do { + err = memif_rx_burst(conn, qid, c->rx_bufs, MAX_MEMIF_BUFS, &rx); + ret_val = err; + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS && + err != MEMIF_ERR_NOBUF)) { + TRANSPORT_LOGE("memif_rx_burst: %s", memif_strerror(err)); + goto error; + } + + c->rx_buf_num += rx; + + if (TRANSPORT_EXPECT_FALSE(connector->io_service_.stopped())) { + TRANSPORT_LOGE("socket stopped: ignoring %u packets", rx); + goto error; + } + + std::size_t packet_length; + for (int i = 0; i < rx; i++) { + auto buffer = connector->getRawBuffer(); + packet_length = (c->rx_bufs + i)->len; + std::memcpy(buffer.first, (c->rx_bufs + i)->data, packet_length); + auto packet = connector->getPacketFromBuffer(buffer.first, packet_length); + + if (!connector->input_buffer_.push(std::move(packet))) { + TRANSPORT_LOGE("Error pushing packet. Ring buffer full."); + + // TODO Here we should consider the possibility to signal the congestion + // to the application, that would react properly (e.g. slow down + // message) + } + } + + /* mark memif buffers and shared memory buffers as free */ + /* free processed buffers */ + + err = memif_refill_queue(conn, qid, rx, 0); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_buffer_free: %s", memif_strerror(err)); + } + + c->rx_buf_num -= rx; + total_packets += rx; + + } while (ret_val == MEMIF_ERR_NOBUF); + + connector->io_service_.post( + std::bind(&MemifConnector::processInputBuffer, connector, total_packets)); + + return 0; + +error: + err = memif_refill_queue(c->conn, qid, rx, 0); + + if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { + TRANSPORT_LOGE("memif_buffer_free: %s", memif_strerror(err)); + } + c->rx_buf_num -= rx; + + return 0; +} + +void MemifConnector::close() { + if (state_ != State::CLOSED) { + disconnect_timer_->expiresFromNow(std::chrono::microseconds(50)); + disconnect_timer_->asyncWait([this](const std::error_code &ec) { + deleteMemif(); + event_reactor_.stop(); + work_.reset(); + }); + + if (memif_worker_ && memif_worker_->joinable()) { + memif_worker_->join(); + } + } +} + +void MemifConnector::send(Packet &packet) { + { + utils::SpinLock::Acquire locked(write_msgs_lock_); + output_buffer_.push_back(packet.shared_from_this()); + } +#if CANCEL_TIMER + if (!timer_set_) { + timer_set_ = true; + send_timer_->expiresFromNow(std::chrono::microseconds(50)); + send_timer_->asyncWait( + std::bind(&MemifConnector::sendCallback, this, std::placeholders::_1)); + } +#endif +} + +int MemifConnector::doSend() { + std::size_t max = 0; + int32_t n = 0; + std::size_t size = 0; + + { + utils::SpinLock::Acquire locked(write_msgs_lock_); + size = output_buffer_.size(); + } + + do { + max = size < MAX_MEMIF_BUFS ? size : MAX_MEMIF_BUFS; + n = bufferAlloc(max, memif_connection_->tx_qid); + + if (TRANSPORT_EXPECT_FALSE(n < 0)) { + TRANSPORT_LOGE("Error allocating buffers."); + return -1; + } + + for (uint16_t i = 0; i < n; i++) { + utils::SpinLock::Acquire locked(write_msgs_lock_); + + auto packet = output_buffer_.front().get(); + const utils::MemBuf *current = packet; + std::size_t offset = 0; + uint8_t *shared_buffer = + reinterpret_cast<uint8_t *>(memif_connection_->tx_bufs[i].data); + do { + std::memcpy(shared_buffer + offset, current->data(), current->length()); + offset += current->length(); + current = current->next(); + } while (current != packet); + + memif_connection_->tx_bufs[i].len = uint32_t(offset); + + output_buffer_.pop_front(); + } + + txBurst(memif_connection_->tx_qid); + + utils::SpinLock::Acquire locked(write_msgs_lock_); + size = output_buffer_.size(); + } while (size > 0); + + return 0; +} + +void MemifConnector::send(const uint8_t *packet, std::size_t len) { + throw errors::NotImplementedException(); +} + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/io_modules/memif/memif_connector.h b/libtransport/src/io_modules/memif/memif_connector.h new file mode 100644 index 000000000..bed3516dc --- /dev/null +++ b/libtransport/src/io_modules/memif/memif_connector.h @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/config.h> +#include <hicn/transport/core/connector.h> +#include <hicn/transport/portability/portability.h> +#include <hicn/transport/utils/ring_buffer.h> +//#include <hicn/transport/core/hicn_vapi.h> +#include <utils/epoll_event_reactor.h> +#include <utils/fd_deadline_timer.h> + +#include <asio.hpp> +#include <deque> +#include <mutex> +#include <thread> + +#define _Static_assert static_assert + +namespace transport { + +namespace core { + +typedef struct memif_connection memif_connection_t; + +#define APP_NAME "libtransport" +#define IF_NAME "vpp_connection" + +#define MEMIF_BUF_SIZE 2048 +#define MEMIF_LOG2_RING_SIZE 13 +#define MAX_MEMIF_BUFS (1 << MEMIF_LOG2_RING_SIZE) + +class MemifConnector : public Connector { + using memif_conn_handle_t = void *; + using PacketRing = utils::CircularFifo<utils::MemBuf::Ptr, queue_size>; + + public: + MemifConnector(PacketReceivedCallback &&receive_callback, + PacketSentCallback &&packet_sent, + OnCloseCallback &&close_callback, + OnReconnectCallback &&on_reconnect, + asio::io_service &io_service, + std::string app_name = "Libtransport"); + + ~MemifConnector() override; + + void send(Packet &packet) override; + + void send(const uint8_t *packet, std::size_t len) override; + + void close() override; + + void connect(uint32_t memif_id, long memif_mode); + + TRANSPORT_ALWAYS_INLINE uint32_t getMemifId() { return memif_id_; }; + + private: + void init(); + + int doSend(); + + int createMemif(uint32_t index, uint8_t mode, char *s); + + uint32_t getMemifConfiguration(); + + int deleteMemif(); + + static int controlFdUpdate(int fd, uint8_t events, void *private_ctx); + + static int onConnect(memif_conn_handle_t conn, void *private_ctx); + + static int onDisconnect(memif_conn_handle_t conn, void *private_ctx); + + static int onInterrupt(memif_conn_handle_t conn, void *private_ctx, + uint16_t qid); + + void threadMain(); + + int txBurst(uint16_t qid); + + int bufferAlloc(long n, uint16_t qid); + + void sendCallback(const std::error_code &ec); + + void processInputBuffer(std::uint16_t total_packets); + + private: + static utils::EpollEventReactor main_event_reactor_; + static std::unique_ptr<std::thread> main_worker_; + + int epfd; + std::unique_ptr<std::thread> memif_worker_; + utils::EpollEventReactor event_reactor_; + std::atomic_bool timer_set_; + std::unique_ptr<utils::FdDeadlineTimer> send_timer_; + std::unique_ptr<utils::FdDeadlineTimer> disconnect_timer_; + asio::io_service &io_service_; + std::unique_ptr<asio::io_service::work> work_; + std::unique_ptr<memif_connection_t> memif_connection_; + uint16_t tx_buf_counter_; + + PacketRing input_buffer_; + bool is_reconnection_; + bool data_available_; + uint32_t memif_id_; + uint8_t memif_mode_; + std::string app_name_; + uint16_t transmission_index_; + utils::SpinLock write_msgs_lock_; + std::string socket_filename_; + + static std::once_flag flag_; +}; + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/io_modules/memif/memif_vapi.c b/libtransport/src/io_modules/memif/memif_vapi.c new file mode 100644 index 000000000..b3da2b012 --- /dev/null +++ b/libtransport/src/io_modules/memif/memif_vapi.c @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <fcntl.h> +#include <hicn/transport/config.h> +#include <inttypes.h> +#include <io_modules/memif/memif_vapi.h> +#include <semaphore.h> +#include <string.h> +#include <sys/stat.h> +#include <vapi/vapi_safe.h> +#include <vppinfra/clib.h> + +DEFINE_VAPI_MSG_IDS_MEMIF_API_JSON + +static vapi_error_e memif_details_cb(vapi_ctx_t ctx, void *callback_ctx, + vapi_error_e rv, bool is_last, + vapi_payload_memif_details *reply) { + uint32_t *last_memif_id = (uint32_t *)callback_ctx; + uint32_t current_memif_id = 0; + if (reply != NULL) { + current_memif_id = reply->id; + } else { + return rv; + } + + if (current_memif_id >= *last_memif_id) { + *last_memif_id = current_memif_id + 1; + } + + return rv; +} + +int memif_vapi_get_next_memif_id(vapi_ctx_t ctx, uint32_t *memif_id) { + vapi_lock(); + vapi_msg_memif_dump *msg = vapi_alloc_memif_dump(ctx); + int ret = vapi_memif_dump(ctx, msg, memif_details_cb, memif_id); + vapi_unlock(); + return ret; +} + +static vapi_error_e memif_create_cb(vapi_ctx_t ctx, void *callback_ctx, + vapi_error_e rv, bool is_last, + vapi_payload_memif_create_reply *reply) { + memif_output_params_t *output_params = (memif_output_params_t *)callback_ctx; + + if (reply == NULL) return rv; + + output_params->sw_if_index = reply->sw_if_index; + + return rv; +} + +int memif_vapi_create_memif(vapi_ctx_t ctx, memif_create_params_t *input_params, + memif_output_params_t *output_params) { + vapi_lock(); + vapi_msg_memif_create *msg = vapi_alloc_memif_create(ctx); + + int ret = 0; + if (input_params->socket_id == ~0) { + // invalid socket-id + ret = -1; + goto END; + } + + if (!is_pow2(input_params->ring_size)) { + // ring size must be power of 2 + ret = -1; + goto END; + } + + if (input_params->rx_queues > 255 || input_params->rx_queues < 1) { + // rx queue must be between 1 - 255 + ret = -1; + goto END; + } + + if (input_params->tx_queues > 255 || input_params->tx_queues < 1) { + // tx queue must be between 1 - 255 + ret = -1; + goto END; + } + + msg->payload.role = input_params->role; + msg->payload.mode = input_params->mode; + msg->payload.rx_queues = input_params->rx_queues; + msg->payload.tx_queues = input_params->tx_queues; + msg->payload.id = input_params->id; + msg->payload.socket_id = input_params->socket_id; + msg->payload.ring_size = input_params->ring_size; + msg->payload.buffer_size = input_params->buffer_size; + + ret = vapi_memif_create(ctx, msg, memif_create_cb, output_params); +END: + vapi_unlock(); + return ret; +} + +static vapi_error_e memif_delete_cb(vapi_ctx_t ctx, void *callback_ctx, + vapi_error_e rv, bool is_last, + vapi_payload_memif_delete_reply *reply) { + if (reply == NULL) return rv; + + return reply->retval; +} + +int memif_vapi_delete_memif(vapi_ctx_t ctx, uint32_t sw_if_index) { + vapi_lock(); + vapi_msg_memif_delete *msg = vapi_alloc_memif_delete(ctx); + + msg->payload.sw_if_index = sw_if_index; + + int ret = vapi_memif_delete(ctx, msg, memif_delete_cb, NULL); + vapi_unlock(); + return ret; +} diff --git a/libtransport/src/io_modules/memif/memif_vapi.h b/libtransport/src/io_modules/memif/memif_vapi.h new file mode 100644 index 000000000..bcf06ed43 --- /dev/null +++ b/libtransport/src/io_modules/memif/memif_vapi.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/config.h> + +#ifdef __cplusplus +extern "C" { +#endif + +#include <vapi/memif.api.vapi.h> + +#include "stdint.h" + +typedef struct memif_create_params_s { + uint8_t role; + uint8_t mode; + uint8_t rx_queues; + uint8_t tx_queues; + uint32_t id; + uint32_t socket_id; + uint8_t secret[24]; + uint32_t ring_size; + uint16_t buffer_size; + uint8_t hw_addr[6]; +} memif_create_params_t; + +typedef struct memif_output_params_s { + uint32_t sw_if_index; +} memif_output_params_t; + +int memif_vapi_get_next_memif_id(vapi_ctx_t ctx, uint32_t *memif_id); + +int memif_vapi_create_memif(vapi_ctx_t ctx, memif_create_params_t *input_params, + memif_output_params_t *output_params); + +int memif_vapi_delete_memif(vapi_ctx_t ctx, uint32_t sw_if_index); + +#ifdef __cplusplus +} +#endif diff --git a/libtransport/src/io_modules/memif/vpp_forwarder_module.cc b/libtransport/src/io_modules/memif/vpp_forwarder_module.cc new file mode 100644 index 000000000..dcbcd7ed0 --- /dev/null +++ b/libtransport/src/io_modules/memif/vpp_forwarder_module.cc @@ -0,0 +1,263 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/config.h> +#include <hicn/transport/errors/not_implemented_exception.h> +#include <io_modules/memif/hicn_vapi.h> +#include <io_modules/memif/memif_connector.h> +#include <io_modules/memif/memif_vapi.h> +#include <io_modules/memif/vpp_forwarder_module.h> + +extern "C" { +#include <memif/libmemif.h> +}; + +typedef enum { MASTER = 0, SLAVE = 1 } memif_role_t; + +#define MEMIF_DEFAULT_RING_SIZE 2048 +#define MEMIF_DEFAULT_RX_QUEUES 1 +#define MEMIF_DEFAULT_TX_QUEUES 1 +#define MEMIF_DEFAULT_BUFFER_SIZE 2048 + +namespace transport { + +namespace core { + +VPPForwarderModule::VPPForwarderModule() + : IoModule(), + connector_(nullptr), + sw_if_index_(~0), + face_id1_(~0), + face_id2_(~0), + is_consumer_(false) {} + +VPPForwarderModule::~VPPForwarderModule() { delete connector_; } + +void VPPForwarderModule::init( + Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, const std::string &app_name) { + if (!connector_) { + connector_ = + new MemifConnector(std::move(receive_callback), 0, 0, + std::move(reconnect_callback), io_service, app_name); + } +} + +void VPPForwarderModule::processControlMessageReply( + utils::MemBuf &packet_buffer) { + throw errors::NotImplementedException(); +} + +bool VPPForwarderModule::isControlMessage(const uint8_t *message) { + return false; +} + +bool VPPForwarderModule::isConnected() { return connector_->isConnected(); }; + +void VPPForwarderModule::send(Packet &packet) { + IoModule::send(packet); + connector_->send(packet); +} + +void VPPForwarderModule::send(const uint8_t *packet, std::size_t len) { + counters_.tx_packets++; + counters_.tx_bytes += len; + + // Perfect forwarding + connector_->send(packet, len); +} + +std::uint32_t VPPForwarderModule::getMtu() { return interface_mtu; } + +/** + * @brief Create a memif interface in the local VPP forwarder. + */ +uint32_t VPPForwarderModule::getMemifConfiguration() { + memif_create_params_t input_params = {0}; + + int ret = memif_vapi_get_next_memif_id(VPPForwarderModule::sock_, &memif_id_); + + if (ret < 0) { + throw errors::RuntimeException( + "Error getting next memif id. Could not create memif interface."); + } + + input_params.id = memif_id_; + input_params.role = memif_role_t::MASTER; + input_params.mode = memif_interface_mode_t::MEMIF_INTERFACE_MODE_IP; + input_params.rx_queues = MEMIF_DEFAULT_RX_QUEUES; + input_params.tx_queues = MEMIF_DEFAULT_TX_QUEUES; + input_params.ring_size = MEMIF_DEFAULT_RING_SIZE; + input_params.buffer_size = MEMIF_DEFAULT_BUFFER_SIZE; + + memif_output_params_t output_params = {0}; + + ret = memif_vapi_create_memif(VPPForwarderModule::sock_, &input_params, + &output_params); + + if (ret < 0) { + throw errors::RuntimeException( + "Error creating memif interface in the local VPP forwarder."); + } + + return output_params.sw_if_index; +} + +void VPPForwarderModule::consumerConnection() { + hicn_consumer_input_params input = {0}; + hicn_consumer_output_params output = {0}; + ip_address_t ip4_address; + ip_address_t ip6_address; + + output.src4 = &ip4_address; + output.src6 = &ip6_address; + input.swif = sw_if_index_; + + int ret = + hicn_vapi_register_cons_app(VPPForwarderModule::sock_, &input, &output); + + if (ret < 0) { + throw errors::RuntimeException(hicn_vapi_get_error_string(ret)); + } + + face_id1_ = output.face_id1; + face_id2_ = output.face_id2; + + std::memcpy(inet_address_.v4.as_u8, output.src4->v4.as_u8, IPV4_ADDR_LEN); + + std::memcpy(inet6_address_.v6.as_u8, output.src6->v6.as_u8, IPV6_ADDR_LEN); +} + +void VPPForwarderModule::producerConnection() { + // Producer connection will be set when we set the first route. +} + +void VPPForwarderModule::connect(bool is_consumer) { + int retry = 20; + + TRANSPORT_LOGI("Connecting to VPP through vapi."); + vapi_error_e ret = vapi_connect_safe(&sock_, 0); + + while (ret != VAPI_OK && retry > 0) { + TRANSPORT_LOGE("Error connecting to VPP through vapi. Retrying.."); + --retry; + ret = vapi_connect_safe(&sock_, 0); + } + + if (ret != VAPI_OK) { + throw std::runtime_error( + "Impossible to connect to forwarder. Is VPP running?"); + } + + TRANSPORT_LOGI("Connected to VPP through vapi."); + + sw_if_index_ = getMemifConfiguration(); + + is_consumer_ = is_consumer; + if (is_consumer_) { + consumerConnection(); + } + + connector_->connect(memif_id_, 0); + connector_->setRole(is_consumer_ ? Connector::Role::CONSUMER + : Connector::Role::PRODUCER); +} + +void VPPForwarderModule::registerRoute(const Prefix &prefix) { + const ip_prefix_t &addr = prefix.toIpPrefixStruct(); + + ip_prefix_t producer_prefix; + ip_address_t producer_locator; + + if (face_id1_ == uint32_t(~0)) { + hicn_producer_input_params input; + std::memset(&input, 0, sizeof(input)); + + hicn_producer_output_params output; + std::memset(&output, 0, sizeof(output)); + + input.prefix = &producer_prefix; + output.prod_addr = &producer_locator; + + // Here we have to ask to the actual connector what is the + // memif_id, since this function should be called after the + // memif creation.n + input.swif = sw_if_index_; + input.prefix->address = addr.address; + input.prefix->family = addr.family; + input.prefix->len = addr.len; + input.cs_reserved = content_store_reserved_; + + int ret = + hicn_vapi_register_prod_app(VPPForwarderModule::sock_, &input, &output); + + if (ret < 0) { + throw errors::RuntimeException(hicn_vapi_get_error_string(ret)); + } + + inet6_address_ = *output.prod_addr; + + face_id1_ = output.face_id; + } else { + hicn_producer_set_route_params params; + params.prefix = &producer_prefix; + params.prefix->address = addr.address; + params.prefix->family = addr.family; + params.prefix->len = addr.len; + params.prod_addr = &producer_locator; + + int ret = hicn_vapi_register_route(VPPForwarderModule::sock_, ¶ms); + + if (ret < 0) { + throw errors::RuntimeException(hicn_vapi_get_error_string(ret)); + } + } +} + +void VPPForwarderModule::closeConnection() { + if (VPPForwarderModule::sock_) { + connector_->close(); + + if (is_consumer_) { + hicn_del_face_app_input_params params; + params.face_id = face_id1_; + hicn_vapi_face_cons_del(VPPForwarderModule::sock_, ¶ms); + params.face_id = face_id2_; + hicn_vapi_face_cons_del(VPPForwarderModule::sock_, ¶ms); + } else { + hicn_del_face_app_input_params params; + params.face_id = face_id1_; + hicn_vapi_face_prod_del(VPPForwarderModule::sock_, ¶ms); + } + + if (sw_if_index_ != uint32_t(~0)) { + int ret = + memif_vapi_delete_memif(VPPForwarderModule::sock_, sw_if_index_); + if (ret < 0) { + TRANSPORT_LOGE("Error deleting memif with sw idx %u.", sw_if_index_); + } + } + + vapi_disconnect_safe(); + VPPForwarderModule::sock_ = nullptr; + } +} + +extern "C" IoModule *create_module(void) { return new VPPForwarderModule(); } + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/memif/vpp_forwarder_module.h b/libtransport/src/io_modules/memif/vpp_forwarder_module.h new file mode 100644 index 000000000..8c4114fed --- /dev/null +++ b/libtransport/src/io_modules/memif/vpp_forwarder_module.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/core/prefix.h> + +#ifdef always_inline +#undef always_inline +#endif +extern "C" { +#include <vapi/vapi_safe.h> +}; + +namespace transport { + +namespace core { + +class MemifConnector; + +class VPPForwarderModule : public IoModule { + static constexpr std::uint16_t interface_mtu = 1500; + + public: + VPPForwarderModule(); + ~VPPForwarderModule(); + + void connect(bool is_consumer) override; + + void send(Packet &packet) override; + void send(const uint8_t *packet, std::size_t len) override; + + bool isConnected() override; + + void init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name = "Libtransport") override; + + void registerRoute(const Prefix &prefix) override; + + std::uint32_t getMtu() override; + + bool isControlMessage(const uint8_t *message) override; + + void processControlMessageReply(utils::MemBuf &packet_buffer) override; + + void closeConnection() override; + + private: + uint32_t getMemifConfiguration(); + void consumerConnection(); + void producerConnection(); + + private: + MemifConnector *connector_; + uint32_t memif_id_; + uint32_t sw_if_index_; + // A consumer socket in vpp has two faces (ipv4 and ipv6) + uint32_t face_id1_; + uint32_t face_id2_; + bool is_consumer_; + vapi_ctx_t sock_; +}; + +extern "C" IoModule *create_module(void); + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/raw_socket/raw_socket_connector.cc b/libtransport/src/io_modules/raw_socket/raw_socket_connector.cc new file mode 100644 index 000000000..0bfcc2a58 --- /dev/null +++ b/libtransport/src/io_modules/raw_socket/raw_socket_connector.cc @@ -0,0 +1,201 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/raw_socket_connector.h> +#include <hicn/transport/utils/conversions.h> +#include <hicn/transport/utils/log.h> +#include <net/if.h> +#include <netdb.h> +#include <stdio.h> +#include <string.h> +#include <sys/ioctl.h> +#include <sys/socket.h> + +#define MY_DEST_MAC0 0x0a +#define MY_DEST_MAC1 0x7b +#define MY_DEST_MAC2 0x7c +#define MY_DEST_MAC3 0x1c +#define MY_DEST_MAC4 0x4a +#define MY_DEST_MAC5 0x14 + +namespace transport { + +namespace core { + +RawSocketConnector::RawSocketConnector( + PacketReceivedCallback &&receive_callback, + OnReconnect &&on_reconnect_callback, asio::io_service &io_service, + std::string app_name) + : Connector(std::move(receive_callback), std::move(on_reconnect_callback)), + io_service_(io_service), + socket_(io_service_, raw_protocol(PF_PACKET, SOCK_RAW)), + // resolver_(io_service_), + timer_(io_service_), + read_msg_(packet_pool_.makePtr(nullptr)), + data_available_(false), + app_name_(app_name) { + memset(&link_layer_address_, 0, sizeof(link_layer_address_)); +} + +RawSocketConnector::~RawSocketConnector() {} + +void RawSocketConnector::connect(const std::string &interface_name, + const std::string &mac_address_str) { + state_ = ConnectorState::CONNECTING; + memset(ðernet_header_, 0, sizeof(ethernet_header_)); + struct ifreq ifr; + struct ifreq if_mac; + uint8_t mac_address[6]; + + utils::convertStringToMacAddress(mac_address_str, mac_address); + + // Get interface mac address + int fd = static_cast<int>(socket_.native_handle()); + + /* Get the index of the interface to send on */ + memset(&ifr, 0, sizeof(struct ifreq)); + strncpy(ifr.ifr_name, interface_name.c_str(), interface_name.size()); + + // if (ioctl(fd, SIOCGIFINDEX, &if_idx) < 0) { + // perror("SIOCGIFINDEX"); + // } + + /* Get the MAC address of the interface to send on */ + memset(&if_mac, 0, sizeof(struct ifreq)); + strncpy(if_mac.ifr_name, interface_name.c_str(), interface_name.size()); + if (ioctl(fd, SIOCGIFHWADDR, &if_mac) < 0) { + perror("SIOCGIFHWADDR"); + throw errors::RuntimeException("Interface does not exist"); + } + + /* Ethernet header */ + for (int i = 0; i < 6; i++) { + ethernet_header_.ether_shost[i] = + ((uint8_t *)&if_mac.ifr_hwaddr.sa_data)[i]; + ethernet_header_.ether_dhost[i] = mac_address[i]; + } + + /* Ethertype field */ + ethernet_header_.ether_type = htons(ETH_P_IPV6); + + strcpy(ifr.ifr_name, interface_name.c_str()); + + if (0 == ioctl(fd, SIOCGIFHWADDR, &ifr)) { + memcpy(link_layer_address_.sll_addr, ifr.ifr_hwaddr.sa_data, 6); + } + + // memset(&ifr, 0, sizeof(ifr)); + // ioctl(fd, SIOCGIFFLAGS, &ifr); + // ifr.ifr_flags |= IFF_PROMISC; + // ioctl(fd, SIOCSIFFLAGS, &ifr); + + link_layer_address_.sll_family = AF_PACKET; + link_layer_address_.sll_protocol = htons(ETH_P_ALL); + link_layer_address_.sll_ifindex = if_nametoindex(interface_name.c_str()); + link_layer_address_.sll_hatype = 1; + link_layer_address_.sll_halen = 6; + + // startConnectionTimer(); + doConnect(); + doRecvPacket(); +} + +void RawSocketConnector::send(const uint8_t *packet, std::size_t len, + const PacketSentCallback &packet_sent) { + if (packet_sent != 0) { + socket_.async_send( + asio::buffer(packet, len), + [packet_sent](std::error_code ec, std::size_t /*length*/) { + packet_sent(); + }); + } else { + if (state_ == ConnectorState::CONNECTED) { + socket_.send(asio::buffer(packet, len)); + } + } +} + +void RawSocketConnector::send(const Packet::MemBufPtr &packet) { + io_service_.post([this, packet]() { + bool write_in_progress = !output_buffer_.empty(); + output_buffer_.push_back(std::move(packet)); + if (TRANSPORT_EXPECT_TRUE(state_ == ConnectorState::CONNECTED)) { + if (!write_in_progress) { + doSendPacket(); + } else { + // Tell the handle connect it has data to write + data_available_ = true; + } + } + }); +} + +void RawSocketConnector::close() { + io_service_.post([this]() { socket_.close(); }); +} + +void RawSocketConnector::doSendPacket() { + auto packet = output_buffer_.front().get(); + auto array = std::vector<asio::const_buffer>(); + + const utils::MemBuf *current = packet; + do { + array.push_back(asio::const_buffer(current->data(), current->length())); + current = current->next(); + } while (current != packet); + + socket_.async_send( + std::move(array), + [this /*, packet*/](std::error_code ec, std::size_t bytes_transferred) { + if (TRANSPORT_EXPECT_TRUE(!ec)) { + output_buffer_.pop_front(); + if (!output_buffer_.empty()) { + doSendPacket(); + } + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + } + }); +} + +void RawSocketConnector::doRecvPacket() { + read_msg_ = getPacket(); + socket_.async_receive( + asio::buffer(read_msg_->writableData(), packet_size), + [this](std::error_code ec, std::size_t bytes_transferred) mutable { + if (!ec) { + // Ignore packets that are not for us + uint8_t *dst_mac_address = const_cast<uint8_t *>(read_msg_->data()); + if (!std::memcmp(dst_mac_address, ethernet_header_.ether_shost, + ETHER_ADDR_LEN)) { + read_msg_->append(bytes_transferred); + read_msg_->trimStart(sizeof(struct ether_header)); + receive_callback_(std::move(read_msg_)); + } + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + } + doRecvPacket(); + }); +} + +void RawSocketConnector::doConnect() { + state_ = ConnectorState::CONNECTED; + socket_.bind(raw_endpoint(&link_layer_address_, sizeof(link_layer_address_))); +} + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/io_modules/raw_socket/raw_socket_connector.h b/libtransport/src/io_modules/raw_socket/raw_socket_connector.h new file mode 100644 index 000000000..aba4b1105 --- /dev/null +++ b/libtransport/src/io_modules/raw_socket/raw_socket_connector.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <core/connector.h> +#include <hicn/transport/config.h> +#include <hicn/transport/core/name.h> +#include <linux/if_packet.h> +#include <net/ethernet.h> +#include <sys/socket.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <deque> + +namespace transport { + +namespace core { + +using asio::generic::raw_protocol; +using raw_endpoint = asio::generic::basic_endpoint<raw_protocol>; + +class RawSocketConnector : public Connector { + public: + RawSocketConnector(PacketReceivedCallback &&receive_callback, + OnReconnect &&reconnect_callback, + asio::io_service &io_service, + std::string app_name = "Libtransport"); + + ~RawSocketConnector() override; + + void send(const Packet::MemBufPtr &packet) override; + + void send(const uint8_t *packet, std::size_t len, + const PacketSentCallback &packet_sent = 0) override; + + void close() override; + + void connect(const std::string &interface_name, + const std::string &mac_address_str); + + private: + void doConnect(); + + void doRecvPacket(); + + void doSendPacket(); + + private: + asio::io_service &io_service_; + raw_protocol::socket socket_; + + struct ether_header ethernet_header_; + + struct sockaddr_ll link_layer_address_; + + asio::steady_timer timer_; + + utils::ObjectPool<utils::MemBuf>::Ptr read_msg_; + + bool data_available_; + std::string app_name_; +}; + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/io_modules/raw_socket/raw_socket_interface.cc b/libtransport/src/io_modules/raw_socket/raw_socket_interface.cc new file mode 100644 index 000000000..dcf489f59 --- /dev/null +++ b/libtransport/src/io_modules/raw_socket/raw_socket_interface.cc @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/raw_socket_interface.h> +#include <hicn/transport/utils/linux.h> + +#include <fstream> + +namespace transport { + +namespace core { + +static std::string config_folder_path = "/etc/transport/interface.conf.d"; + +RawSocketInterface::RawSocketInterface(RawSocketConnector &connector) + : ForwarderInterface<RawSocketInterface, RawSocketConnector>(connector) {} + +RawSocketInterface::~RawSocketInterface() {} + +void RawSocketInterface::connect(bool is_consumer) { + std::string complete_filename = + config_folder_path + std::string("/") + output_interface_; + + std::ifstream is(complete_filename); + std::string interface; + + if (is) { + is >> remote_mac_address_; + } + + // Get interface ip address + struct sockaddr_in6 address = {0}; + utils::retrieveInterfaceAddress(output_interface_, &address); + + std::memcpy(&inet6_address_.v6.as_u8, &address.sin6_addr, + sizeof(address.sin6_addr)); + connector_.connect(output_interface_, remote_mac_address_); +} + +void RawSocketInterface::registerRoute(Prefix &prefix) { return; } + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/raw_socket/raw_socket_interface.h b/libtransport/src/io_modules/raw_socket/raw_socket_interface.h new file mode 100644 index 000000000..7036cac7e --- /dev/null +++ b/libtransport/src/io_modules/raw_socket/raw_socket_interface.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <core/forwarder_interface.h> +#include <core/raw_socket_connector.h> +#include <hicn/transport/core/prefix.h> + +#include <atomic> +#include <deque> + +namespace transport { + +namespace core { + +class RawSocketInterface + : public ForwarderInterface<RawSocketInterface, RawSocketConnector> { + public: + typedef RawSocketConnector ConnectorType; + + RawSocketInterface(RawSocketConnector &connector); + + ~RawSocketInterface(); + + void connect(bool is_consumer); + + void registerRoute(Prefix &prefix); + + std::uint16_t getMtu() { return interface_mtu; } + + TRANSPORT_ALWAYS_INLINE static bool isControlMessageImpl( + const uint8_t *message) { + return false; + } + + TRANSPORT_ALWAYS_INLINE void processControlMessageReplyImpl( + Packet::MemBufPtr &&packet_buffer) {} + + TRANSPORT_ALWAYS_INLINE void closeConnection(){}; + + private: + static constexpr std::uint16_t interface_mtu = 1500; + std::string remote_mac_address_; +}; + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/udp/CMakeLists.txt b/libtransport/src/io_modules/udp/CMakeLists.txt new file mode 100644 index 000000000..1a43492dc --- /dev/null +++ b/libtransport/src/io_modules/udp/CMakeLists.txt @@ -0,0 +1,47 @@ +# Copyright (c) 2021 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + + +list(APPEND MODULE_HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/hicn_forwarder_module.h + ${CMAKE_CURRENT_SOURCE_DIR}/udp_socket_connector.h +) + +list(APPEND MODULE_SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/hicn_forwarder_module.cc + ${CMAKE_CURRENT_SOURCE_DIR}/udp_socket_connector.cc +) + +# add_executable(hicnlight_module MACOSX_BUNDLE ${MODULE_SOURCE_FILES}) +# target_include_directories(hicnlight_module PRIVATE ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS}) +# set_target_properties(hicnlight_module PROPERTIES +# BUNDLE True +# MACOSX_BUNDLE_GUI_IDENTIFIER my.domain.style.identifier.hicnlight_module +# MACOSX_BUNDLE_BUNDLE_NAME hicnlight_module +# MACOSX_BUNDLE_BUNDLE_VERSION "0.1" +# MACOSX_BUNDLE_SHORT_VERSION_STRING "0.1" +# # MACOSX_BUNDLE_INFO_PLIST ${CMAKE_SOURCE_DIR}/cmake/customtemplate.plist.in +# ) + +build_module(hicnlight_module + SHARED + SOURCES ${MODULE_SOURCE_FILES} + DEPENDS ${DEPENDENCIES} + COMPONENT lib${LIBTRANSPORT} + INCLUDE_DIRS ${LIBTRANSPORT_INCLUDE_DIRS} ${LIBTRANSPORT_INTERNAL_INCLUDE_DIRS} + # LIBRARY_ROOT_DIR "vpp_plugins" + DEFINITIONS ${COMPILER_DEFINITIONS} + COMPILE_OPTIONS ${COMPILE_FLAGS} +) diff --git a/libtransport/src/io_modules/udp/hicn_forwarder_module.cc b/libtransport/src/io_modules/udp/hicn_forwarder_module.cc new file mode 100644 index 000000000..ba08dd8c0 --- /dev/null +++ b/libtransport/src/io_modules/udp/hicn_forwarder_module.cc @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <io_modules/udp/hicn_forwarder_module.h> +#include <io_modules/udp/udp_socket_connector.h> + +union AddressLight { + uint32_t ipv4; + struct in6_addr ipv6; +}; + +typedef struct { + uint8_t message_type; + uint8_t command_id; + uint16_t length; + uint32_t seq_num; +} CommandHeader; + +typedef struct { + uint8_t message_type; + uint8_t command_id; + uint16_t length; + uint32_t seq_num; + char symbolic_or_connid[16]; + union AddressLight address; + uint16_t cost; + uint8_t address_type; + uint8_t len; +} RouteToSelfCommand; + +typedef struct { + uint8_t message_type; + uint8_t command_id; + uint16_t length; + uint32_t seq_num; + char symbolic_or_connid[16]; +} DeleteSelfConnectionCommand; + +namespace { +static constexpr uint8_t addr_inet = 1; +static constexpr uint8_t addr_inet6 = 2; +static constexpr uint8_t add_route_command = 3; +static constexpr uint8_t delete_connection_command = 5; +static constexpr uint8_t request_light = 0xc0; +static constexpr char identifier[] = "SELF"; + +void fillCommandHeader(CommandHeader *header) { + // Allocate and fill the header + header->message_type = request_light; + header->length = 1; +} + +RouteToSelfCommand createCommandRoute(std::unique_ptr<sockaddr> &&addr, + uint8_t prefix_length) { + RouteToSelfCommand command = {0}; + + // check and set IP address + if (addr->sa_family == AF_INET) { + command.address_type = addr_inet; + command.address.ipv4 = ((sockaddr_in *)addr.get())->sin_addr.s_addr; + } else if (addr->sa_family == AF_INET6) { + command.address_type = addr_inet6; + command.address.ipv6 = ((sockaddr_in6 *)addr.get())->sin6_addr; + } + + // Fill remaining payload fields +#ifndef _WIN32 + strcpy(command.symbolic_or_connid, identifier); +#else + strcpy_s(command.symbolic_or_connid, 16, identifier); +#endif + command.cost = 1; + command.len = (uint8_t)prefix_length; + + // Allocate and fill the header + command.command_id = add_route_command; + fillCommandHeader((CommandHeader *)&command); + + return command; +} + +DeleteSelfConnectionCommand createCommandDeleteConnection() { + DeleteSelfConnectionCommand command = {0}; + fillCommandHeader((CommandHeader *)&command); + command.command_id = delete_connection_command; + +#ifndef _WIN32 + strcpy(command.symbolic_or_connid, identifier); +#else + strcpy_s(command.symbolic_or_connid, 16, identifier); +#endif + + return command; +} + +} // namespace + +namespace transport { + +namespace core { + +HicnForwarderModule::HicnForwarderModule() : IoModule(), connector_(nullptr) {} + +HicnForwarderModule::~HicnForwarderModule() {} + +void HicnForwarderModule::connect(bool is_consumer) { + connector_->connect(); + connector_->setRole(is_consumer ? Connector::Role::CONSUMER + : Connector::Role::PRODUCER); +} + +bool HicnForwarderModule::isConnected() { return connector_->isConnected(); } + +void HicnForwarderModule::send(Packet &packet) { + IoModule::send(packet); + packet.setChecksum(); + connector_->send(packet); +} + +void HicnForwarderModule::send(const uint8_t *packet, std::size_t len) { + counters_.tx_packets++; + counters_.tx_bytes += len; + + // Perfect forwarding + connector_->send(packet, len); +} + +void HicnForwarderModule::registerRoute(const Prefix &prefix) { + auto command = createCommandRoute(prefix.toSockaddr(), + (uint8_t)prefix.getPrefixLength()); + send((uint8_t *)&command, sizeof(RouteToSelfCommand)); +} + +void HicnForwarderModule::closeConnection() { + auto command = createCommandDeleteConnection(); + send((uint8_t *)&command, sizeof(DeleteSelfConnectionCommand)); + connector_->close(); +} + +void HicnForwarderModule::init( + Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, const std::string &app_name) { + if (!connector_) { + connector_ = new UdpSocketConnector(std::move(receive_callback), nullptr, + nullptr, std::move(reconnect_callback), + io_service, app_name); + } +} + +void HicnForwarderModule::processControlMessageReply( + utils::MemBuf &packet_buffer) { + if (packet_buffer.data()[0] == nack_code) { + throw errors::RuntimeException( + "Received Nack message from hicn light forwarder."); + } +} + +std::uint32_t HicnForwarderModule::getMtu() { return interface_mtu; } + +bool HicnForwarderModule::isControlMessage(const uint8_t *message) { + return message[0] == ack_code || message[0] == nack_code; +} + +extern "C" IoModule *create_module(void) { return new HicnForwarderModule(); } + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/udp/hicn_forwarder_module.h b/libtransport/src/io_modules/udp/hicn_forwarder_module.h new file mode 100644 index 000000000..845db73bf --- /dev/null +++ b/libtransport/src/io_modules/udp/hicn_forwarder_module.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2017-2020 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/io_module.h> +#include <hicn/transport/core/prefix.h> + +namespace transport { + +namespace core { + +class UdpSocketConnector; + +class HicnForwarderModule : public IoModule { + static constexpr uint8_t ack_code = 0xc2; + static constexpr uint8_t nack_code = 0xc3; + static constexpr std::uint16_t interface_mtu = 1500; + + public: + union addressLight { + uint32_t ipv4; + struct in6_addr ipv6; + }; + + struct route_to_self_command { + uint8_t messageType; + uint8_t commandID; + uint16_t length; + uint32_t seqNum; + char symbolicOrConnid[16]; + union addressLight address; + uint16_t cost; + uint8_t addressType; + uint8_t len; + }; + + using route_to_self_command = struct route_to_self_command; + + HicnForwarderModule(); + + ~HicnForwarderModule(); + + void connect(bool is_consumer) override; + + void send(Packet &packet) override; + void send(const uint8_t *packet, std::size_t len) override; + + bool isConnected() override; + + void init(Connector::PacketReceivedCallback &&receive_callback, + Connector::OnReconnectCallback &&reconnect_callback, + asio::io_service &io_service, + const std::string &app_name = "Libtransport") override; + + void registerRoute(const Prefix &prefix) override; + + std::uint32_t getMtu() override; + + bool isControlMessage(const uint8_t *message) override; + + void processControlMessageReply(utils::MemBuf &packet_buffer) override; + + void closeConnection() override; + + private: + UdpSocketConnector *connector_; +}; + +extern "C" IoModule *create_module(void); + +} // namespace core + +} // namespace transport diff --git a/libtransport/src/io_modules/udp/udp_socket_connector.cc b/libtransport/src/io_modules/udp/udp_socket_connector.cc new file mode 100644 index 000000000..456886a54 --- /dev/null +++ b/libtransport/src/io_modules/udp/udp_socket_connector.cc @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef _WIN32 +#include <hicn/transport/portability/win_portability.h> +#endif + +#include <hicn/transport/errors/errors.h> +#include <hicn/transport/utils/log.h> +#include <hicn/transport/utils/object_pool.h> +#include <io_modules/udp/udp_socket_connector.h> + +#include <thread> +#include <vector> + +namespace transport { + +namespace core { + +UdpSocketConnector::UdpSocketConnector( + PacketReceivedCallback &&receive_callback, PacketSentCallback &&packet_sent, + OnCloseCallback &&close_callback, OnReconnectCallback &&on_reconnect, + asio::io_service &io_service, std::string app_name) + : Connector(std::move(receive_callback), std::move(packet_sent), + std::move(close_callback), std::move(on_reconnect)), + io_service_(io_service), + socket_(io_service_), + resolver_(io_service_), + connection_timer_(io_service_), + read_msg_(std::make_pair(nullptr, 0)), + is_reconnection_(false), + data_available_(false), + app_name_(app_name) {} + +UdpSocketConnector::~UdpSocketConnector() {} + +void UdpSocketConnector::connect(std::string ip_address, std::string port) { + endpoint_iterator_ = resolver_.resolve( + {ip_address, port, asio::ip::resolver_query_base::numeric_service}); + + state_ = Connector::State::CONNECTING; + doConnect(); +} + +void UdpSocketConnector::send(const uint8_t *packet, std::size_t len) { + socket_.async_send(asio::buffer(packet, len), + [this](std::error_code ec, std::size_t /*length*/) { + if (sent_callback_) { + sent_callback_(this, ec); + } + }); +} + +void UdpSocketConnector::send(Packet &packet) { + io_service_.post([this, _packet{packet.shared_from_this()}]() { + bool write_in_progress = !output_buffer_.empty(); + output_buffer_.push_back(std::move(_packet)); + if (TRANSPORT_EXPECT_TRUE(state_ == Connector::State::CONNECTED)) { + if (!write_in_progress) { + doWrite(); + } + } else { + // Tell the handle connect it has data to write + data_available_ = true; + } + }); +} + +void UdpSocketConnector::close() { + if (io_service_.stopped()) { + doClose(); + } else { + io_service_.dispatch(std::bind(&UdpSocketConnector::doClose, this)); + } +} + +void UdpSocketConnector::doClose() { + if (state_ != Connector::State::CLOSED) { + state_ = Connector::State::CLOSED; + if (socket_.is_open()) { + socket_.shutdown(asio::ip::tcp::socket::shutdown_type::shutdown_both); + socket_.close(); + } + } +} + +void UdpSocketConnector::doWrite() { + auto packet = output_buffer_.front().get(); + auto array = std::vector<asio::const_buffer>(); + + const utils::MemBuf *current = packet; + do { + array.push_back(asio::const_buffer(current->data(), current->length())); + current = current->next(); + } while (current != packet); + + socket_.async_send(std::move(array), [this](std::error_code ec, + std::size_t length) { + if (TRANSPORT_EXPECT_TRUE(!ec)) { + output_buffer_.pop_front(); + if (!output_buffer_.empty()) { + doWrite(); + } + } else if (ec.value() == static_cast<int>(std::errc::operation_canceled)) { + // The connection has been closed by the application. + return; + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + tryReconnect(); + } + }); +} + +void UdpSocketConnector::doRead() { + read_msg_ = getRawBuffer(); + socket_.async_receive( + asio::buffer(read_msg_.first, read_msg_.second), + [this](std::error_code ec, std::size_t length) { + if (TRANSPORT_EXPECT_TRUE(!ec)) { + auto packet = getPacketFromBuffer(read_msg_.first, length); + receive_callback_(this, *packet, std::make_error_code(std::errc(0))); + doRead(); + } else if (ec.value() == + static_cast<int>(std::errc::operation_canceled)) { + // The connection has been closed by the application. + return; + } else { + TRANSPORT_LOGE("%d %s", ec.value(), ec.message().c_str()); + tryReconnect(); + } + }); +} + +void UdpSocketConnector::tryReconnect() { + if (state_ == Connector::State::CONNECTED) { + TRANSPORT_LOGE("Connection lost. Trying to reconnect...\n"); + state_ = Connector::State::CONNECTING; + is_reconnection_ = true; + io_service_.post([this]() { + if (socket_.is_open()) { + socket_.shutdown(asio::ip::tcp::socket::shutdown_type::shutdown_both); + socket_.close(); + } + + doConnect(); + startConnectionTimer(); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + }); + } +} + +void UdpSocketConnector::doConnect() { + asio::async_connect( + socket_, endpoint_iterator_, + [this](std::error_code ec, udp::resolver::iterator) { + if (!ec) { + connection_timer_.cancel(); + state_ = Connector::State::CONNECTED; + doRead(); + + if (data_available_) { + data_available_ = false; + doWrite(); + } + + if (is_reconnection_) { + is_reconnection_ = false; + } + + on_reconnect_callback_(this); + } else { + doConnect(); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + }); +} + +bool UdpSocketConnector::checkConnected() { + return state_ == Connector::State::CONNECTED; +} + +void UdpSocketConnector::startConnectionTimer() { + connection_timer_.expires_from_now(std::chrono::seconds(60)); + connection_timer_.async_wait(std::bind(&UdpSocketConnector::handleDeadline, + this, std::placeholders::_1)); +} + +void UdpSocketConnector::handleDeadline(const std::error_code &ec) { + if (!ec) { + io_service_.post([this]() { + socket_.close(); + TRANSPORT_LOGE("Error connecting. Is the forwarder running?\n"); + }); + } +} + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/io_modules/udp/udp_socket_connector.h b/libtransport/src/io_modules/udp/udp_socket_connector.h new file mode 100644 index 000000000..8ab08e17a --- /dev/null +++ b/libtransport/src/io_modules/udp/udp_socket_connector.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/config.h> +#include <hicn/transport/core/connector.h> +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/core/interest.h> +#include <hicn/transport/core/name.h> +#include <hicn/transport/core/packet.h> +#include <hicn/transport/utils/branch_prediction.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <deque> + +namespace transport { +namespace core { + +using asio::ip::udp; + +class UdpSocketConnector : public Connector { + public: + UdpSocketConnector(PacketReceivedCallback &&receive_callback, + PacketSentCallback &&packet_sent, + OnCloseCallback &&close_callback, + OnReconnectCallback &&on_reconnect, + asio::io_service &io_service, + std::string app_name = "Libtransport"); + + ~UdpSocketConnector() override; + + void send(Packet &packet) override; + + void send(const uint8_t *packet, std::size_t len) override; + + void close() override; + + void connect(std::string ip_address = "127.0.0.1", std::string port = "9695"); + + private: + void doConnect(); + + void doRead(); + + void doWrite(); + + void doClose(); + + bool checkConnected(); + + private: + void handleDeadline(const std::error_code &ec); + + void startConnectionTimer(); + + void tryReconnect(); + + asio::io_service &io_service_; + asio::ip::udp::socket socket_; + asio::ip::udp::resolver resolver_; + asio::ip::udp::resolver::iterator endpoint_iterator_; + asio::steady_timer connection_timer_; + + std::pair<uint8_t *, std::size_t> read_msg_; + + bool is_reconnection_; + bool data_available_; + + std::string app_name_; +}; + +} // end namespace core + +} // end namespace transport diff --git a/libtransport/src/protocols/CMakeLists.txt b/libtransport/src/protocols/CMakeLists.txt index 8bfbdd6ad..eba8d1aab 100644 --- a/libtransport/src/protocols/CMakeLists.txt +++ b/libtransport/src/protocols/CMakeLists.txt @@ -21,16 +21,15 @@ list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/datagram_reassembly.h ${CMAKE_CURRENT_SOURCE_DIR}/byte_stream_reassembly.h ${CMAKE_CURRENT_SOURCE_DIR}/congestion_window_protocol.h - ${CMAKE_CURRENT_SOURCE_DIR}/packet_manager.h ${CMAKE_CURRENT_SOURCE_DIR}/rate_estimation.h - ${CMAKE_CURRENT_SOURCE_DIR}/protocol.h + ${CMAKE_CURRENT_SOURCE_DIR}/transport_protocol.h + ${CMAKE_CURRENT_SOURCE_DIR}/production_protocol.h + ${CMAKE_CURRENT_SOURCE_DIR}/prod_protocol_bytestream.h + ${CMAKE_CURRENT_SOURCE_DIR}/prod_protocol_rtc.h ${CMAKE_CURRENT_SOURCE_DIR}/raaqm.h ${CMAKE_CURRENT_SOURCE_DIR}/raaqm_data_path.h ${CMAKE_CURRENT_SOURCE_DIR}/cbr.h - ${CMAKE_CURRENT_SOURCE_DIR}/rtc.h - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_data_path.h ${CMAKE_CURRENT_SOURCE_DIR}/errors.h - ${CMAKE_CURRENT_SOURCE_DIR}/verification_manager.h ${CMAKE_CURRENT_SOURCE_DIR}/data_processing_events.h ) @@ -41,15 +40,15 @@ list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/reassembly.cc ${CMAKE_CURRENT_SOURCE_DIR}/datagram_reassembly.cc ${CMAKE_CURRENT_SOURCE_DIR}/byte_stream_reassembly.cc - ${CMAKE_CURRENT_SOURCE_DIR}/protocol.cc + ${CMAKE_CURRENT_SOURCE_DIR}/transport_protocol.cc + ${CMAKE_CURRENT_SOURCE_DIR}/production_protocol.cc + ${CMAKE_CURRENT_SOURCE_DIR}/prod_protocol_bytestream.cc + ${CMAKE_CURRENT_SOURCE_DIR}/prod_protocol_rtc.cc ${CMAKE_CURRENT_SOURCE_DIR}/raaqm.cc ${CMAKE_CURRENT_SOURCE_DIR}/rate_estimation.cc ${CMAKE_CURRENT_SOURCE_DIR}/raaqm_data_path.cc ${CMAKE_CURRENT_SOURCE_DIR}/cbr.cc - ${CMAKE_CURRENT_SOURCE_DIR}/rtc.cc - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_data_path.cc ${CMAKE_CURRENT_SOURCE_DIR}/errors.cc - ${CMAKE_CURRENT_SOURCE_DIR}/verification_manager.cc ) set(RAAQM_CONFIG_INSTALL_PREFIX @@ -71,5 +70,7 @@ install( COMPONENT lib${LIBTRANSPORT} ) +add_subdirectory(rtc) + set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) -set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE)
\ No newline at end of file +set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE) diff --git a/libtransport/src/protocols/byte_stream_reassembly.cc b/libtransport/src/protocols/byte_stream_reassembly.cc index 6662bec3f..d2bc961c4 100644 --- a/libtransport/src/protocols/byte_stream_reassembly.cc +++ b/libtransport/src/protocols/byte_stream_reassembly.cc @@ -20,7 +20,7 @@ #include <protocols/byte_stream_reassembly.h> #include <protocols/errors.h> #include <protocols/indexer.h> -#include <protocols/protocol.h> +#include <protocols/transport_protocol.h> namespace transport { @@ -45,11 +45,11 @@ void ByteStreamReassembly::reassemble( } } -void ByteStreamReassembly::reassemble(ContentObject::Ptr &&content_object) { - if (TRANSPORT_EXPECT_TRUE(content_object != nullptr) && - read_buffer_->capacity()) { - received_packets_.emplace(std::make_pair( - content_object->getName().getSuffix(), std::move(content_object))); +void ByteStreamReassembly::reassemble(ContentObject &content_object) { + if (TRANSPORT_EXPECT_TRUE(read_buffer_->capacity())) { + received_packets_.emplace( + std::make_pair(content_object.getName().getSuffix(), + content_object.shared_from_this())); assembleContent(); } } @@ -81,25 +81,32 @@ void ByteStreamReassembly::assembleContent() { } } -bool ByteStreamReassembly::copyContent(const ContentObject &content_object) { +bool ByteStreamReassembly::copyContent(ContentObject &content_object) { bool ret = false; - auto payload = content_object.getPayloadReference(); - auto payload_length = payload.second; - auto write_size = std::min(payload_length, read_buffer_->tailroom()); - auto additional_bytes = payload_length > read_buffer_->tailroom() - ? payload_length - read_buffer_->tailroom() - : 0; + content_object.trimStart(content_object.headerSize()); - std::memcpy(read_buffer_->writableTail(), payload.first, write_size); - read_buffer_->append(write_size); + utils::MemBuf *current = &content_object; - if (!read_buffer_->tailroom()) { - notifyApplication(); - std::memcpy(read_buffer_->writableTail(), payload.first + write_size, - additional_bytes); - read_buffer_->append(additional_bytes); - } + do { + auto payload_length = current->length(); + auto write_size = std::min(payload_length, read_buffer_->tailroom()); + auto additional_bytes = payload_length > read_buffer_->tailroom() + ? payload_length - read_buffer_->tailroom() + : 0; + + std::memcpy(read_buffer_->writableTail(), current->data(), write_size); + read_buffer_->append(write_size); + + if (!read_buffer_->tailroom()) { + notifyApplication(); + std::memcpy(read_buffer_->writableTail(), current->data() + write_size, + additional_bytes); + read_buffer_->append(additional_bytes); + } + + current = current->next(); + } while (current != &content_object); download_complete_ = index_manager_->getFinalSuffix() == content_object.getName().getSuffix(); diff --git a/libtransport/src/protocols/byte_stream_reassembly.h b/libtransport/src/protocols/byte_stream_reassembly.h index e4f62b3a8..c682d58cb 100644 --- a/libtransport/src/protocols/byte_stream_reassembly.h +++ b/libtransport/src/protocols/byte_stream_reassembly.h @@ -27,12 +27,12 @@ class ByteStreamReassembly : public Reassembly { TransportProtocol *transport_protocol); protected: - virtual void reassemble(core::ContentObject::Ptr &&content_object) override; + virtual void reassemble(core::ContentObject &content_object) override; virtual void reassemble( std::unique_ptr<core::ContentObjectManifest> &&manifest) override; - bool copyContent(const core::ContentObject &content_object); + bool copyContent(core::ContentObject &content_object); virtual void reInitialize() override; diff --git a/libtransport/src/protocols/data_processing_events.h b/libtransport/src/protocols/data_processing_events.h index 8975c2b4a..5c8c16157 100644 --- a/libtransport/src/protocols/data_processing_events.h +++ b/libtransport/src/protocols/data_processing_events.h @@ -24,8 +24,7 @@ namespace protocol { class ContentObjectProcessingEventCallback { public: virtual ~ContentObjectProcessingEventCallback() = default; - virtual void onPacketDropped(core::Interest::Ptr &&i, - core::ContentObject::Ptr &&c) = 0; + virtual void onPacketDropped(core::Interest &i, core::ContentObject &c) = 0; virtual void onReassemblyFailed(std::uint32_t missing_segment) = 0; }; diff --git a/libtransport/src/protocols/datagram_reassembly.cc b/libtransport/src/protocols/datagram_reassembly.cc index abd7e984d..962c1e020 100644 --- a/libtransport/src/protocols/datagram_reassembly.cc +++ b/libtransport/src/protocols/datagram_reassembly.cc @@ -24,8 +24,8 @@ DatagramReassembly::DatagramReassembly( TransportProtocol* transport_protocol) : Reassembly(icn_socket, transport_protocol) {} -void DatagramReassembly::reassemble(core::ContentObject::Ptr&& content_object) { - read_buffer_ = content_object->getPayload(); +void DatagramReassembly::reassemble(core::ContentObject& content_object) { + read_buffer_ = content_object.getPayload(); Reassembly::notifyApplication(); } diff --git a/libtransport/src/protocols/datagram_reassembly.h b/libtransport/src/protocols/datagram_reassembly.h index 2427ae62f..3462212d3 100644 --- a/libtransport/src/protocols/datagram_reassembly.h +++ b/libtransport/src/protocols/datagram_reassembly.h @@ -26,7 +26,7 @@ class DatagramReassembly : public Reassembly { DatagramReassembly(implementation::ConsumerSocket *icn_socket, TransportProtocol *transport_protocol); - virtual void reassemble(core::ContentObject::Ptr &&content_object) override; + virtual void reassemble(core::ContentObject &content_object) override; virtual void reInitialize() override; virtual void reassemble( std::unique_ptr<core::ContentObjectManifest> &&manifest) override { diff --git a/libtransport/src/protocols/errors.cc b/libtransport/src/protocols/errors.cc index eefb6f957..ae7b6e634 100644 --- a/libtransport/src/protocols/errors.cc +++ b/libtransport/src/protocols/errors.cc @@ -52,7 +52,9 @@ std::string protocol_category_impl::message(int ev) const { case protocol_error::session_aborted: { return "The session has been aborted by the application."; } - default: { return "Unknown protocol error"; } + default: { + return "Unknown protocol error"; + } } } diff --git a/libtransport/src/protocols/fec_base.h b/libtransport/src/protocols/fec_base.h new file mode 100644 index 000000000..a135c474f --- /dev/null +++ b/libtransport/src/protocols/fec_base.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/content_object.h> + +#include <functional> + +namespace transport { +namespace protocol { + +/** + * Interface classes to integrate FEC inside any producer transport protocol + */ +class ProducerFECBase { + public: + /** + * Callback, to be called by implementations as soon as a repair packet is + * ready. + */ + using RepairPacketsReady = + std::function<void(std::vector<core::ContentObject::Ptr> &)>; + + /** + * Producers will call this function upon production of a new packet. + */ + virtual void onPacketProduced(const core::ContentObject &content_object) = 0; + + /** + * Set callback to signal production protocol the repair packet is ready. + */ + void setFECCallback(const RepairPacketsReady &on_repair_packet) { + rep_packet_ready_callback_ = on_repair_packet; + } + + protected: + RepairPacketsReady rep_packet_ready_callback_; +}; + +/** + * Interface classes to integrate FEC inside any consumer transport protocol + */ +class ConsumerFECBase { + public: + /** + * Callback, to be called by implemrntations as soon as a packet is recovered. + */ + using OnPacketsRecovered = + std::function<void(std::vector<core::ContentObject::Ptr> &)>; + + /** + * Consumers will call this function when they receive a FEC packet. + */ + virtual void onFECPacket(const core::ContentObject &content_object) = 0; + + /** + * Consumers will call this function when they receive a data packet + */ + virtual void onDataPacket(const core::ContentObject &content_object) = 0; + + /** + * Set callback to signal consumer protocol the repair packet is ready. + */ + void setFECCallback(const OnPacketsRecovered &on_repair_packet) { + packet_recovered_callback_ = on_repair_packet; + } + + protected: + OnPacketsRecovered packet_recovered_callback_; +}; + +} // namespace protocol +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/protocols/incremental_indexer.cc b/libtransport/src/protocols/incremental_indexer.cc index 0872c4554..95daa0a3e 100644 --- a/libtransport/src/protocols/incremental_indexer.cc +++ b/libtransport/src/protocols/incremental_indexer.cc @@ -13,37 +13,38 @@ * limitations under the License. */ -#include <protocols/incremental_indexer.h> - #include <hicn/transport/interfaces/socket_consumer.h> -#include <protocols/protocol.h> +#include <protocols/errors.h> +#include <protocols/incremental_indexer.h> +#include <protocols/transport_protocol.h> namespace transport { namespace protocol { -void IncrementalIndexer::onContentObject( - core::Interest::Ptr &&interest, core::ContentObject::Ptr &&content_object) { +void IncrementalIndexer::onContentObject(core::Interest &interest, + core::ContentObject &content_object) { using namespace interface; - TRANSPORT_LOGD("Receive content %s", content_object->getName().toString().c_str()); + TRANSPORT_LOGD("Received content %s", + content_object.getName().toString().c_str()); - if (TRANSPORT_EXPECT_FALSE(content_object->testRst())) { - final_suffix_ = content_object->getName().getSuffix(); + if (TRANSPORT_EXPECT_FALSE(content_object.testRst())) { + final_suffix_ = content_object.getName().getSuffix(); } - auto ret = verification_manager_->onPacketToVerify(*content_object); + auto ret = verifier_->verifyPackets(&content_object); switch (ret) { - case VerificationPolicy::ACCEPT_PACKET: { - reassembly_->reassemble(std::move(content_object)); + case auth::VerificationPolicy::ACCEPT: { + reassembly_->reassemble(content_object); break; } - case VerificationPolicy::DROP_PACKET: { - transport_protocol_->onPacketDropped(std::move(interest), - std::move(content_object)); + case auth::VerificationPolicy::UNKNOWN: + case auth::VerificationPolicy::DROP: { + transport_protocol_->onPacketDropped(interest, content_object); break; } - case VerificationPolicy::ABORT_SESSION: { + case auth::VerificationPolicy::ABORT: { transport_protocol_->onContentReassembled( make_error_code(protocol_error::session_aborted)); break; diff --git a/libtransport/src/protocols/incremental_indexer.h b/libtransport/src/protocols/incremental_indexer.h index 20c5e4759..d7760f8e6 100644 --- a/libtransport/src/protocols/incremental_indexer.h +++ b/libtransport/src/protocols/incremental_indexer.h @@ -15,13 +15,13 @@ #pragma once -#include <hicn/transport/errors/runtime_exception.h> -#include <hicn/transport/errors/unexpected_manifest_exception.h> +#include <hicn/transport/errors/errors.h> +#include <hicn/transport/interfaces/callbacks.h> +#include <hicn/transport/auth/verifier.h> #include <hicn/transport/utils/literals.h> - +#include <implementation/socket_consumer.h> #include <protocols/indexer.h> #include <protocols/reassembly.h> -#include <protocols/verification_manager.h> #include <deque> @@ -47,11 +47,12 @@ class IncrementalIndexer : public Indexer { first_suffix_(0), next_download_suffix_(0), next_reassembly_suffix_(0), - verification_manager_( - std::make_unique<SignatureVerificationManager>(icn_socket)) { + verifier_(nullptr) { if (reassembly_) { reassembly_->setIndexer(this); } + socket_->getSocketOption(implementation::GeneralTransportOptions::VERIFIER, + verifier_); } IncrementalIndexer(const IncrementalIndexer &) = delete; @@ -64,15 +65,14 @@ class IncrementalIndexer : public Indexer { first_suffix_(other.first_suffix_), next_download_suffix_(other.next_download_suffix_), next_reassembly_suffix_(other.next_reassembly_suffix_), - verification_manager_(std::move(other.verification_manager_)) { + verifier_(nullptr) { if (reassembly_) { reassembly_->setIndexer(this); } + socket_->getSocketOption(implementation::GeneralTransportOptions::VERIFIER, + verifier_); } - /** - * - */ virtual ~IncrementalIndexer() {} TRANSPORT_ALWAYS_INLINE virtual void reset( @@ -112,8 +112,8 @@ class IncrementalIndexer : public Indexer { return final_suffix_; } - void onContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object) override; + void onContentObject(core::Interest &interest, + core::ContentObject &content_object) override; TRANSPORT_ALWAYS_INLINE void setReassembly(Reassembly *reassembly) { reassembly_ = reassembly; @@ -123,10 +123,6 @@ class IncrementalIndexer : public Indexer { } } - TRANSPORT_ALWAYS_INLINE bool onKeyToVerify() override { - return verification_manager_->onKeyToVerify(); - } - protected: implementation::ConsumerSocket *socket_; Reassembly *reassembly_; @@ -135,9 +131,8 @@ class IncrementalIndexer : public Indexer { uint32_t first_suffix_; uint32_t next_download_suffix_; uint32_t next_reassembly_suffix_; - std::unique_ptr<VerificationManager> verification_manager_; + std::shared_ptr<auth::Verifier> verifier_; }; -} // end namespace protocol - -} // end namespace transport +} // namespace protocol +} // namespace transport diff --git a/libtransport/src/protocols/indexer.cc b/libtransport/src/protocols/indexer.cc index ca12330a6..1379a609c 100644 --- a/libtransport/src/protocols/indexer.cc +++ b/libtransport/src/protocols/indexer.cc @@ -14,11 +14,9 @@ */ #include <hicn/transport/utils/branch_prediction.h> - #include <protocols/incremental_indexer.h> #include <protocols/indexer.h> #include <protocols/manifest_incremental_indexer.h> -#include <protocols/protocol.h> namespace transport { namespace protocol { @@ -32,16 +30,16 @@ IndexManager::IndexManager(implementation::ConsumerSocket *icn_socket, transport_(transport), reassembly_(reassembly) {} -void IndexManager::onContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object) { +void IndexManager::onContentObject(core::Interest &interest, + core::ContentObject &content_object) { if (first_segment_received_) { - indexer_->onContentObject(std::move(interest), std::move(content_object)); + indexer_->onContentObject(interest, content_object); } else { - std::uint32_t segment_number = interest->getName().getSuffix(); + std::uint32_t segment_number = interest.getName().getSuffix(); if (segment_number == 0) { // Check if manifest - if (content_object->getPayloadType() == PayloadType::MANIFEST) { + if (content_object.getPayloadType() == core::PayloadType::MANIFEST) { IncrementalIndexer *indexer = static_cast<IncrementalIndexer *>(indexer_.release()); indexer_ = @@ -49,25 +47,21 @@ void IndexManager::onContentObject(core::Interest::Ptr &&interest, delete indexer; } - indexer_->onContentObject(std::move(interest), std::move(content_object)); + indexer_->onContentObject(interest, content_object); auto it = interest_data_set_.begin(); while (it != interest_data_set_.end()) { - indexer_->onContentObject( - std::move(const_cast<core::Interest::Ptr &&>(it->first)), - std::move(const_cast<core::ContentObject::Ptr &&>(it->second))); + indexer_->onContentObject(*it->first, *it->second); it = interest_data_set_.erase(it); } first_segment_received_ = true; } else { - interest_data_set_.emplace(std::move(interest), - std::move(content_object)); + interest_data_set_.emplace(interest.shared_from_this(), + content_object.shared_from_this()); } } } -bool IndexManager::onKeyToVerify() { return indexer_->onKeyToVerify(); } - void IndexManager::reset(std::uint32_t offset) { indexer_ = std::make_unique<IncrementalIndexer>(icn_socket_, transport_, reassembly_); diff --git a/libtransport/src/protocols/indexer.h b/libtransport/src/protocols/indexer.h index 8213a1503..49e22a4cf 100644 --- a/libtransport/src/protocols/indexer.h +++ b/libtransport/src/protocols/indexer.h @@ -33,10 +33,8 @@ class TransportProtocol; class Indexer { public: - /** - * - */ virtual ~Indexer() = default; + /** * Retrieve from the manifest the next suffix to retrieve. */ @@ -55,10 +53,8 @@ class Indexer { virtual void reset(std::uint32_t offset = 0) = 0; - virtual void onContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object) = 0; - - virtual bool onKeyToVerify() = 0; + virtual void onContentObject(core::Interest &interest, + core::ContentObject &content_object) = 0; }; class IndexManager : Indexer { @@ -86,10 +82,8 @@ class IndexManager : Indexer { void reset(std::uint32_t offset = 0) override; - void onContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object) override; - - bool onKeyToVerify() override; + void onContentObject(core::Interest &interest, + core::ContentObject &content_object) override; private: std::unique_ptr<Indexer> indexer_; diff --git a/libtransport/src/protocols/manifest_incremental_indexer.cc b/libtransport/src/protocols/manifest_incremental_indexer.cc index da835b577..a6312ca90 100644 --- a/libtransport/src/protocols/manifest_incremental_indexer.cc +++ b/libtransport/src/protocols/manifest_incremental_indexer.cc @@ -14,9 +14,9 @@ */ #include <implementation/socket_consumer.h> - +#include <protocols/errors.h> #include <protocols/manifest_incremental_indexer.h> -#include <protocols/protocol.h> +#include <protocols/transport_protocol.h> #include <cmath> #include <deque> @@ -36,41 +36,46 @@ ManifestIncrementalIndexer::ManifestIncrementalIndexer( 0)) {} void ManifestIncrementalIndexer::onContentObject( - core::Interest::Ptr &&interest, core::ContentObject::Ptr &&content_object) { - // Check if manifest or not - if (content_object->getPayloadType() == PayloadType::MANIFEST) { - TRANSPORT_LOGD("Receive content %s", content_object->getName().toString().c_str()); - onUntrustedManifest(std::move(interest), std::move(content_object)); - } else if (content_object->getPayloadType() == PayloadType::CONTENT_OBJECT) { - TRANSPORT_LOGD("Receive manifest %s", content_object->getName().toString().c_str()); - onUntrustedContentObject(std::move(interest), std::move(content_object)); - } -} - -void ManifestIncrementalIndexer::onUntrustedManifest( - core::Interest::Ptr &&interest, core::ContentObject::Ptr &&content_object) { - auto ret = verification_manager_->onPacketToVerify(*content_object); - - switch (ret) { - case VerificationPolicy::ACCEPT_PACKET: { - processTrustedManifest(std::move(content_object)); + core::Interest &interest, core::ContentObject &content_object) { + switch (content_object.getPayloadType()) { + case PayloadType::DATA: { + TRANSPORT_LOGD("Received content %s", + content_object.getName().toString().c_str()); + onUntrustedContentObject(interest, content_object); break; } - case VerificationPolicy::DROP_PACKET: - case VerificationPolicy::ABORT_SESSION: { - transport_protocol_->onContentReassembled( - make_error_code(protocol_error::session_aborted)); + case PayloadType::MANIFEST: { + TRANSPORT_LOGD("Received manifest %s", + content_object.getName().toString().c_str()); + onUntrustedManifest(interest, content_object); break; } + default: { + return; + } } } -void ManifestIncrementalIndexer::processTrustedManifest( - ContentObject::Ptr &&content_object) { +void ManifestIncrementalIndexer::onUntrustedManifest( + core::Interest &interest, core::ContentObject &content_object) { auto manifest = - std::make_unique<ContentObjectManifest>(std::move(*content_object)); + std::make_unique<ContentObjectManifest>(std::move(content_object)); + + auth::VerificationPolicy policy = verifier_->verifyPackets(manifest.get()); + manifest->decode(); + if (policy != auth::VerificationPolicy::ACCEPT) { + transport_protocol_->onContentReassembled( + make_error_code(protocol_error::session_aborted)); + return; + } + + processTrustedManifest(interest, std::move(manifest)); +} + +void ManifestIncrementalIndexer::processTrustedManifest( + core::Interest &interest, std::unique_ptr<ContentObjectManifest> manifest) { if (TRANSPORT_EXPECT_FALSE(manifest->getVersion() != core::ManifestVersion::VERSION_1)) { throw errors::RuntimeException("Received manifest with unknown version."); @@ -78,23 +83,45 @@ void ManifestIncrementalIndexer::processTrustedManifest( switch (manifest->getManifestType()) { case core::ManifestType::INLINE_MANIFEST: { - auto _it = manifest->getSuffixList().begin(); - auto _end = manifest->getSuffixList().end(); - suffix_strategy_->setFinalSuffix(manifest->getFinalBlockNumber()); - for (; _it != _end; _it++) { - auto hash = - std::make_pair(std::vector<uint8_t>(_it->second, _it->second + 32), - manifest->getHashAlgorithm()); + // The packets to verify with the received manifest + std::vector<auth::PacketPtr> packets; + + // Convert the received manifest to a map of packet suffixes to hashes + std::unordered_map<auth::Suffix, auth::HashEntry> current_manifest = + core::ContentObjectManifest::getSuffixMap(manifest.get()); + + // Update 'suffix_map_' with new hashes from the received manifest and + // build 'packets' + for (auto it = current_manifest.begin(); it != current_manifest.end();) { + if (unverified_segments_.find(it->first) == + unverified_segments_.end()) { + suffix_map_[it->first] = std::move(it->second); + current_manifest.erase(it++); + continue; + } - if (!checkUnverifiedSegments(_it->first, hash)) { - suffix_hash_map_[_it->first] = std::move(hash); + packets.push_back(unverified_segments_[it->first].second.get()); + it++; + } + + // Verify unverified segments using the received manifest + std::vector<auth::VerificationPolicy> policies = + verifier_->verifyPackets(packets, current_manifest); + + for (unsigned int i = 0; i < packets.size(); ++i) { + auth::Suffix suffix = packets[i]->getName().getSuffix(); + + if (policies[i] != auth::VerificationPolicy::UNKNOWN) { + unverified_segments_.erase(suffix); } + + applyPolicy(*unverified_segments_[suffix].first, + *unverified_segments_[suffix].second, policies[i]); } reassembly_->reassemble(std::move(manifest)); - break; } case core::ManifestType::FLIC_MANIFEST: { @@ -106,89 +133,47 @@ void ManifestIncrementalIndexer::processTrustedManifest( } } -bool ManifestIncrementalIndexer::checkUnverifiedSegments( - std::uint32_t suffix, const HashEntry &hash) { - auto it = unverified_segments_.find(suffix); - - if (it != unverified_segments_.end()) { - auto ret = verifyContentObject(hash, *it->second.second); - - switch (ret) { - case VerificationPolicy::ACCEPT_PACKET: { - reassembly_->reassemble(std::move(it->second.second)); - break; - } - case VerificationPolicy::DROP_PACKET: { - transport_protocol_->onPacketDropped(std::move(it->second.first), - std::move(it->second.second)); - break; - } - case VerificationPolicy::ABORT_SESSION: { - transport_protocol_->onContentReassembled( - make_error_code(protocol_error::session_aborted)); - break; - } +void ManifestIncrementalIndexer::onUntrustedContentObject( + Interest &interest, ContentObject &content_object) { + auth::Suffix suffix = content_object.getName().getSuffix(); + auth::VerificationPolicy policy = + verifier_->verifyPackets(&content_object, suffix_map_); + + switch (policy) { + case auth::VerificationPolicy::UNKNOWN: { + unverified_segments_[suffix] = std::make_pair( + interest.shared_from_this(), content_object.shared_from_this()); + break; + } + default: { + suffix_map_.erase(suffix); + break; } - - unverified_segments_.erase(it); - return true; - } - - return false; -} - -VerificationPolicy ManifestIncrementalIndexer::verifyContentObject( - const HashEntry &manifest_hash, const ContentObject &content_object) { - VerificationPolicy ret; - - auto hash_type = static_cast<utils::CryptoHashType>(manifest_hash.second); - auto data_packet_digest = content_object.computeDigest(manifest_hash.second); - auto data_packet_digest_bytes = - data_packet_digest.getDigest<uint8_t>().data(); - const std::vector<uint8_t> &manifest_digest_bytes = manifest_hash.first; - - if (utils::CryptoHash::compareBinaryDigest( - data_packet_digest_bytes, manifest_digest_bytes.data(), hash_type)) { - ret = VerificationPolicy::ACCEPT_PACKET; - } else { - ConsumerContentObjectVerificationFailedCallback - *verification_failed_callback = VOID_HANDLER; - socket_->getSocketOption(ConsumerCallbacksOptions::VERIFICATION_FAILED, - &verification_failed_callback); - ret = (*verification_failed_callback)( - *socket_->getInterface(), content_object, - make_error_code(protocol_error::integrity_verification_failed)); } - return ret; + applyPolicy(interest, content_object, policy); } -void ManifestIncrementalIndexer::onUntrustedContentObject( - Interest::Ptr &&i, ContentObject::Ptr &&c) { - auto suffix = c->getName().getSuffix(); - auto it = suffix_hash_map_.find(suffix); - - if (it != suffix_hash_map_.end()) { - auto ret = verifyContentObject(it->second, *c); - - switch (ret) { - case VerificationPolicy::ACCEPT_PACKET: { - suffix_hash_map_.erase(it); - reassembly_->reassemble(std::move(c)); - break; - } - case VerificationPolicy::DROP_PACKET: { - transport_protocol_->onPacketDropped(std::move(i), std::move(c)); - break; - } - case VerificationPolicy::ABORT_SESSION: { - transport_protocol_->onContentReassembled( - make_error_code(protocol_error::session_aborted)); - break; - } +void ManifestIncrementalIndexer::applyPolicy( + core::Interest &interest, core::ContentObject &content_object, + auth::VerificationPolicy policy) { + switch (policy) { + case auth::VerificationPolicy::ACCEPT: { + reassembly_->reassemble(content_object); + break; + } + case auth::VerificationPolicy::DROP: { + transport_protocol_->onPacketDropped(interest, content_object); + break; + } + case auth::VerificationPolicy::ABORT: { + transport_protocol_->onContentReassembled( + make_error_code(protocol_error::session_aborted)); + break; + } + default: { + break; } - } else { - unverified_segments_[suffix] = std::make_pair(std::move(i), std::move(c)); } } @@ -224,7 +209,7 @@ uint32_t ManifestIncrementalIndexer::getNextReassemblySegment() { void ManifestIncrementalIndexer::reset(std::uint32_t offset) { IncrementalIndexer::reset(offset); - suffix_hash_map_.clear(); + suffix_map_.clear(); unverified_segments_.clear(); SuffixQueue empty; std::swap(suffix_queue_, empty); diff --git a/libtransport/src/protocols/manifest_incremental_indexer.h b/libtransport/src/protocols/manifest_incremental_indexer.h index 38b01533e..1bb76eb87 100644 --- a/libtransport/src/protocols/manifest_incremental_indexer.h +++ b/libtransport/src/protocols/manifest_incremental_indexer.h @@ -15,6 +15,7 @@ #pragma once +#include <hicn/transport/auth/common.h> #include <implementation/socket.h> #include <protocols/incremental_indexer.h> #include <utils/suffix_strategy.h> @@ -22,7 +23,6 @@ #include <list> namespace transport { - namespace protocol { class ManifestIncrementalIndexer : public IncrementalIndexer { @@ -30,7 +30,8 @@ class ManifestIncrementalIndexer : public IncrementalIndexer { public: using SuffixQueue = std::queue<uint32_t>; - using HashEntry = std::pair<std::vector<uint8_t>, utils::CryptoHashType>; + using InterestContentPair = + std::pair<core::Interest::Ptr, core::ContentObject::Ptr>; ManifestIncrementalIndexer(implementation::ConsumerSocket *icn_socket, TransportProtocol *transport, @@ -50,8 +51,8 @@ class ManifestIncrementalIndexer : public IncrementalIndexer { void reset(std::uint32_t offset = 0) override; - void onContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object) override; + void onContentObject(core::Interest &interest, + core::ContentObject &content_object) override; uint32_t getNextSuffix() override; @@ -61,30 +62,24 @@ class ManifestIncrementalIndexer : public IncrementalIndexer { uint32_t getFinalSuffix() override; - private: - void onUntrustedManifest(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object); - void onUntrustedContentObject(core::Interest::Ptr &&interest, - core::ContentObject::Ptr &&content_object); - void processTrustedManifest(core::ContentObject::Ptr &&content_object); - void onManifestReceived(core::Interest::Ptr &&i, - core::ContentObject::Ptr &&c); - void onManifestTimeout(core::Interest::Ptr &&i); - VerificationPolicy verifyContentObject( - const HashEntry &manifest_hash, - const core::ContentObject &content_object); - bool checkUnverifiedSegments(std::uint32_t suffix, const HashEntry &hash); - protected: std::unique_ptr<utils::SuffixStrategy> suffix_strategy_; SuffixQueue suffix_queue_; // Hash verification - std::unordered_map<uint32_t, HashEntry> suffix_hash_map_; + std::unordered_map<auth::Suffix, auth::HashEntry> suffix_map_; + std::unordered_map<auth::Suffix, InterestContentPair> unverified_segments_; - std::unordered_map<uint32_t, - std::pair<core::Interest::Ptr, core::ContentObject::Ptr>> - unverified_segments_; + private: + void onUntrustedManifest(core::Interest &interest, + core::ContentObject &content_object); + void processTrustedManifest(core::Interest &interest, + std::unique_ptr<ContentObjectManifest> manifest); + void onUntrustedContentObject(core::Interest &interest, + core::ContentObject &content_object); + void applyPolicy(core::Interest &interest, + core::ContentObject &content_object, + auth::VerificationPolicy policy); }; } // end namespace protocol diff --git a/libtransport/src/protocols/prod_protocol_bytestream.cc b/libtransport/src/protocols/prod_protocol_bytestream.cc new file mode 100644 index 000000000..6bd989fe4 --- /dev/null +++ b/libtransport/src/protocols/prod_protocol_bytestream.cc @@ -0,0 +1,390 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <implementation/socket_producer.h> +#include <protocols/prod_protocol_bytestream.h> + +#include <atomic> + +namespace transport { + +namespace protocol { + +using namespace core; +using namespace implementation; + +ByteStreamProductionProtocol::ByteStreamProductionProtocol( + implementation::ProducerSocket *icn_socket) + : ProductionProtocol(icn_socket) {} + +ByteStreamProductionProtocol::~ByteStreamProductionProtocol() { + stop(); + if (listening_thread_.joinable()) { + listening_thread_.join(); + } +} + +uint32_t ByteStreamProductionProtocol::produceDatagram( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer) { + throw errors::NotImplementedException(); +} + +uint32_t ByteStreamProductionProtocol::produceDatagram(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size) { + throw errors::NotImplementedException(); +} + +uint32_t ByteStreamProductionProtocol::produceStream(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size, + bool is_last, + uint32_t start_offset) { + if (!buffer_size) { + return 0; + } + + return produceStream(content_name, + utils::MemBuf::copyBuffer(buffer, buffer_size), is_last, + start_offset); +} + +uint32_t ByteStreamProductionProtocol::produceStream( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { + if (TRANSPORT_EXPECT_FALSE(buffer->length() == 0)) { + return 0; + } + + Name name(content_name); + + // Get the atomic variables to ensure they keep the same value + // during the production + + // Total size of the data packet + uint32_t data_packet_size; + socket_->getSocketOption(GeneralTransportOptions::DATA_PACKET_SIZE, + data_packet_size); + + // Expiry time + uint32_t content_object_expiry_time; + socket_->getSocketOption(GeneralTransportOptions::CONTENT_OBJECT_EXPIRY_TIME, + content_object_expiry_time); + + // Hash algorithm + auth::CryptoHashType hash_algo; + socket_->getSocketOption(GeneralTransportOptions::HASH_ALGORITHM, hash_algo); + + // Use manifest + bool making_manifest; + socket_->getSocketOption(GeneralTransportOptions::MAKE_MANIFEST, + making_manifest); + + // Suffix calculation strategy + core::NextSegmentCalculationStrategy _suffix_strategy; + socket_->getSocketOption(GeneralTransportOptions::SUFFIX_STRATEGY, + _suffix_strategy); + auto suffix_strategy = utils::SuffixStrategyFactory::getSuffixStrategy( + _suffix_strategy, start_offset); + + std::shared_ptr<auth::Signer> signer; + socket_->getSocketOption(GeneralTransportOptions::SIGNER, signer); + + auto buffer_size = buffer->length(); + int bytes_segmented = 0; + std::size_t header_size; + std::size_t manifest_header_size = 0; + std::size_t signature_length = 0; + std::uint32_t final_block_number = start_offset; + uint64_t free_space_for_content = 0; + + core::Packet::Format format; + std::shared_ptr<core::ContentObjectManifest> manifest; + bool is_last_manifest = false; + + // TODO Manifest may still be used for indexing + if (making_manifest && !signer) { + TRANSPORT_LOGE("Making manifests without setting producer identity."); + } + + core::Packet::Format hf_format = core::Packet::Format::HF_UNSPEC; + core::Packet::Format hf_format_ah = core::Packet::Format::HF_UNSPEC; + + if (name.getType() == HNT_CONTIGUOUS_V4 || name.getType() == HNT_IOV_V4) { + hf_format = core::Packet::Format::HF_INET_TCP; + hf_format_ah = core::Packet::Format::HF_INET_TCP_AH; + } else if (name.getType() == HNT_CONTIGUOUS_V6 || + name.getType() == HNT_IOV_V6) { + hf_format = core::Packet::Format::HF_INET6_TCP; + hf_format_ah = core::Packet::Format::HF_INET6_TCP_AH; + } else { + throw errors::RuntimeException("Unknown name format."); + } + + format = hf_format; + if (making_manifest) { + manifest_header_size = core::Packet::getHeaderSizeFromFormat( + signer ? hf_format_ah : hf_format, + signer ? signer->getSignatureSize() : 0); + } else if (signer) { + format = hf_format_ah; + signature_length = signer->getSignatureSize(); + } + + header_size = core::Packet::getHeaderSizeFromFormat(format, signature_length); + free_space_for_content = data_packet_size - header_size; + uint32_t number_of_segments = + uint32_t(std::ceil(double(buffer_size) / double(free_space_for_content))); + if (free_space_for_content * number_of_segments < buffer_size) { + number_of_segments++; + } + + // TODO allocate space for all the headers + if (making_manifest) { + uint32_t segment_in_manifest = static_cast<uint32_t>( + std::floor(double(data_packet_size - manifest_header_size - + ContentObjectManifest::getManifestHeaderSize()) / + ContentObjectManifest::getManifestEntrySize()) - + 1.0); + uint32_t number_of_manifests = static_cast<uint32_t>( + std::ceil(float(number_of_segments) / segment_in_manifest)); + final_block_number += number_of_segments + number_of_manifests - 1; + + manifest.reset(ContentObjectManifest::createManifest( + name.setSuffix(suffix_strategy->getNextManifestSuffix()), + core::ManifestVersion::VERSION_1, core::ManifestType::INLINE_MANIFEST, + hash_algo, is_last_manifest, name, _suffix_strategy, + signer ? signer->getSignatureSize() : 0)); + manifest->setLifetime(content_object_expiry_time); + + if (is_last) { + manifest->setFinalBlockNumber(final_block_number); + } else { + manifest->setFinalBlockNumber(utils::SuffixStrategy::INVALID_SUFFIX); + } + } + + for (unsigned int packaged_segments = 0; + packaged_segments < number_of_segments; packaged_segments++) { + if (making_manifest) { + if (manifest->estimateManifestSize(2) > + data_packet_size - manifest_header_size) { + manifest->encode(); + + // If identity set, sign manifest + if (signer) { + signer->signPacket(manifest.get()); + } + + // Send the current manifest + passContentObjectToCallbacks(manifest); + + TRANSPORT_LOGD("Send manifest %s", + manifest->getName().toString().c_str()); + + // Send content objects stored in the queue + while (!content_queue_.empty()) { + passContentObjectToCallbacks(content_queue_.front()); + TRANSPORT_LOGD("Send content %s", + content_queue_.front()->getName().toString().c_str()); + content_queue_.pop(); + } + + // Create new manifest. The reference to the last manifest has been + // acquired in the passContentObjectToCallbacks function, so we can + // safely release this reference + manifest.reset(ContentObjectManifest::createManifest( + name.setSuffix(suffix_strategy->getNextManifestSuffix()), + core::ManifestVersion::VERSION_1, + core::ManifestType::INLINE_MANIFEST, hash_algo, is_last_manifest, + name, _suffix_strategy, signer ? signer->getSignatureSize() : 0)); + + manifest->setLifetime(content_object_expiry_time); + manifest->setFinalBlockNumber( + is_last ? final_block_number + : utils::SuffixStrategy::INVALID_SUFFIX); + } + } + + auto content_suffix = suffix_strategy->getNextContentSuffix(); + auto content_object = std::make_shared<ContentObject>( + name.setSuffix(content_suffix), format, + signer && !making_manifest ? signer->getSignatureSize() : 0); + content_object->setLifetime(content_object_expiry_time); + + auto b = buffer->cloneOne(); + b->trimStart(free_space_for_content * packaged_segments); + b->trimEnd(b->length()); + + if (TRANSPORT_EXPECT_FALSE(packaged_segments == number_of_segments - 1)) { + b->append(buffer_size - bytes_segmented); + bytes_segmented += (int)(buffer_size - bytes_segmented); + + if (is_last && making_manifest) { + is_last_manifest = true; + } else if (is_last) { + content_object->setRst(); + } + + } else { + b->append(free_space_for_content); + bytes_segmented += (int)(free_space_for_content); + } + + content_object->appendPayload(std::move(b)); + + if (making_manifest) { + using namespace std::chrono_literals; + auth::CryptoHash hash = content_object->computeDigest(hash_algo); + manifest->addSuffixHash(content_suffix, hash); + content_queue_.push(content_object); + } else { + if (signer) { + signer->signPacket(content_object.get()); + } + passContentObjectToCallbacks(content_object); + TRANSPORT_LOGD("Send content %s", + content_object->getName().toString().c_str()); + } + } + + if (making_manifest) { + if (is_last_manifest) { + manifest->setFinalManifest(is_last_manifest); + } + + manifest->encode(); + + if (signer) { + signer->signPacket(manifest.get()); + } + + passContentObjectToCallbacks(manifest); + TRANSPORT_LOGD("Send manifest %s", manifest->getName().toString().c_str()); + + while (!content_queue_.empty()) { + passContentObjectToCallbacks(content_queue_.front()); + TRANSPORT_LOGD("Send content %s", + content_queue_.front()->getName().toString().c_str()); + content_queue_.pop(); + } + } + + portal_->getIoService().post([this]() { + std::shared_ptr<ContentObject> co; + while (object_queue_for_callbacks_.pop(co)) { + if (*on_new_segment_) { + on_new_segment_->operator()(*socket_->getInterface(), *co); + } + + if (*on_content_object_to_sign_) { + on_content_object_to_sign_->operator()(*socket_->getInterface(), *co); + } + + if (*on_content_object_in_output_buffer_) { + on_content_object_in_output_buffer_->operator()( + *socket_->getInterface(), *co); + } + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), *co); + } + } + }); + + portal_->getIoService().dispatch([this, buffer_size]() { + if (*on_content_produced_) { + on_content_produced_->operator()(*socket_->getInterface(), + std::make_error_code(std::errc(0)), + buffer_size); + } + }); + + return suffix_strategy->getTotalCount(); +} + +void ByteStreamProductionProtocol::scheduleSendBurst() { + portal_->getIoService().post([this]() { + std::shared_ptr<ContentObject> co; + + for (uint32_t i = 0; i < burst_size; i++) { + if (object_queue_for_callbacks_.pop(co)) { + if (*on_new_segment_) { + on_new_segment_->operator()(*socket_->getInterface(), *co); + } + + if (*on_content_object_to_sign_) { + on_content_object_to_sign_->operator()(*socket_->getInterface(), *co); + } + + if (*on_content_object_in_output_buffer_) { + on_content_object_in_output_buffer_->operator()( + *socket_->getInterface(), *co); + } + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), *co); + } + } else { + break; + } + } + }); +} + +void ByteStreamProductionProtocol::passContentObjectToCallbacks( + const std::shared_ptr<ContentObject> &content_object) { + output_buffer_.insert(content_object); + portal_->sendContentObject(*content_object); + object_queue_for_callbacks_.push(std::move(content_object)); + + if (object_queue_for_callbacks_.size() >= burst_size) { + scheduleSendBurst(); + } +} + +void ByteStreamProductionProtocol::onInterest(Interest &interest) { + TRANSPORT_LOGD("Received interest for %s", + interest.getName().toString().c_str()); + if (*on_interest_input_) { + on_interest_input_->operator()(*socket_->getInterface(), interest); + } + + const std::shared_ptr<ContentObject> content_object = + output_buffer_.find(interest); + + if (content_object) { + if (*on_interest_satisfied_output_buffer_) { + on_interest_satisfied_output_buffer_->operator()(*socket_->getInterface(), + interest); + } + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), + *content_object); + } + + portal_->sendContentObject(*content_object); + } else { + if (*on_interest_process_) { + on_interest_process_->operator()(*socket_->getInterface(), interest); + } + } +} + +void ByteStreamProductionProtocol::onError(std::error_code ec) {} + +} // namespace protocol +} // end namespace transport diff --git a/libtransport/src/protocols/prod_protocol_bytestream.h b/libtransport/src/protocols/prod_protocol_bytestream.h new file mode 100644 index 000000000..cf36b90a5 --- /dev/null +++ b/libtransport/src/protocols/prod_protocol_bytestream.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/utils/ring_buffer.h> +#include <protocols/production_protocol.h> + +#include <atomic> +#include <queue> + +namespace transport { + +namespace protocol { + +using namespace core; + +class ByteStreamProductionProtocol : public ProductionProtocol { + static constexpr uint32_t burst_size = 256; + + public: + ByteStreamProductionProtocol(implementation::ProducerSocket *icn_socket); + + ~ByteStreamProductionProtocol() override; + + using ProductionProtocol::start; + using ProductionProtocol::stop; + + uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) override; + uint32_t produceStream(const Name &content_name, const uint8_t *buffer, + size_t buffer_size, bool is_last = true, + uint32_t start_offset = 0) override; + uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) override; + uint32_t produceDatagram(const Name &content_name, const uint8_t *buffer, + size_t buffer_size) override; + + protected: + // Consumer Callback + // void reset() override; + void onInterest(core::Interest &i) override; + void onError(std::error_code ec) override; + + private: + void passContentObjectToCallbacks( + const std::shared_ptr<ContentObject> &content_object); + void scheduleSendBurst(); + + private: + // While manifests are being built, contents are stored in a queue + std::queue<std::shared_ptr<ContentObject>> content_queue_; + utils::CircularFifo<std::shared_ptr<ContentObject>, 2048> + object_queue_for_callbacks_; +}; + +} // end namespace protocol +} // end namespace transport diff --git a/libtransport/src/protocols/prod_protocol_rtc.cc b/libtransport/src/protocols/prod_protocol_rtc.cc new file mode 100644 index 000000000..8081923e3 --- /dev/null +++ b/libtransport/src/protocols/prod_protocol_rtc.cc @@ -0,0 +1,481 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/core/global_object_pool.h> +#include <implementation/socket_producer.h> +#include <protocols/prod_protocol_rtc.h> +#include <protocols/rtc/rtc_consts.h> +#include <stdlib.h> +#include <time.h> + +#include <unordered_set> + +namespace transport { +namespace protocol { + +RTCProductionProtocol::RTCProductionProtocol( + implementation::ProducerSocket *icn_socket) + : ProductionProtocol(icn_socket), + current_seg_(1), + produced_bytes_(0), + produced_packets_(0), + max_packet_production_(1), + bytes_production_rate_(0), + packets_production_rate_(0), + last_round_(std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count()), + allow_delayed_nacks_(false), + queue_timer_on_(false), + consumer_in_sync_(false), + on_consumer_in_sync_(nullptr) { + srand((unsigned int)time(NULL)); + prod_label_ = rand() % 256; + interests_queue_timer_ = + std::make_unique<asio::steady_timer>(portal_->getIoService()); + round_timer_ = std::make_unique<asio::steady_timer>(portal_->getIoService()); + setOutputBufferSize(10000); + scheduleRoundTimer(); +} + +RTCProductionProtocol::~RTCProductionProtocol() {} + +void RTCProductionProtocol::registerNamespaceWithNetwork( + const Prefix &producer_namespace) { + ProductionProtocol::registerNamespaceWithNetwork(producer_namespace); + + flow_name_ = producer_namespace.getName(); + auto family = flow_name_.getAddressFamily(); + + switch (family) { + case AF_INET6: + header_size_ = (uint32_t)Packet::getHeaderSizeFromFormat(HF_INET6_TCP); + break; + case AF_INET: + header_size_ = (uint32_t)Packet::getHeaderSizeFromFormat(HF_INET_TCP); + break; + default: + throw errors::RuntimeException("Unknown name format."); + } +} + +void RTCProductionProtocol::scheduleRoundTimer() { + round_timer_->expires_from_now( + std::chrono::milliseconds(rtc::PRODUCER_STATS_INTERVAL)); + round_timer_->async_wait([this](std::error_code ec) { + if (ec) return; + updateStats(); + }); +} + +void RTCProductionProtocol::updateStats() { + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint64_t duration = now - last_round_; + if (duration == 0) duration = 1; + double per_second = rtc::MILLI_IN_A_SEC / duration; + + uint32_t prev_packets_production_rate = packets_production_rate_; + + bytes_production_rate_ = ceil((double)produced_bytes_ * per_second); + packets_production_rate_ = ceil((double)produced_packets_ * per_second); + + TRANSPORT_LOGD("Updating production rate: produced_bytes_ = %u bps = %u", + produced_bytes_, bytes_production_rate_); + + // update the production rate as soon as it increases by 10% with respect to + // the last round + max_packet_production_ = + produced_packets_ + ceil((double)produced_packets_ * 0.1); + if (max_packet_production_ < rtc::WIN_MIN) + max_packet_production_ = rtc::WIN_MIN; + + if (packets_production_rate_ != 0) { + allow_delayed_nacks_ = false; + } else if (prev_packets_production_rate == 0) { + // at least 2 rounds with production rate = 0 + allow_delayed_nacks_ = true; + } + + // check if the production rate is decreased. if yes send nacks if needed + if (prev_packets_production_rate < packets_production_rate_) { + sendNacksForPendingInterests(); + } + + produced_bytes_ = 0; + produced_packets_ = 0; + last_round_ = now; + scheduleRoundTimer(); +} + +uint32_t RTCProductionProtocol::produceStream( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { + throw errors::NotImplementedException(); +} + +uint32_t RTCProductionProtocol::produceStream(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size, bool is_last, + uint32_t start_offset) { + throw errors::NotImplementedException(); +} + +void RTCProductionProtocol::produce(ContentObject &content_object) { + throw errors::NotImplementedException(); +} + +uint32_t RTCProductionProtocol::produceDatagram( + const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer) { + std::size_t buffer_size = buffer->length(); + if (TRANSPORT_EXPECT_FALSE(buffer_size == 0)) return 0; + + uint32_t data_packet_size; + socket_->getSocketOption(interface::GeneralTransportOptions::DATA_PACKET_SIZE, + data_packet_size); + + if (TRANSPORT_EXPECT_FALSE((buffer_size + header_size_ + + rtc::DATA_HEADER_SIZE) > data_packet_size)) { + return 0; + } + + auto content_object = + core::PacketManager<>::getInstance().getPacket<ContentObject>(); + // add rtc header to the payload + struct rtc::data_packet_t header; + content_object->appendPayload((const uint8_t *)&header, + rtc::DATA_HEADER_SIZE); + content_object->appendPayload(buffer->data(), buffer->length()); + + std::shared_ptr<ContentObject> co = std::move(content_object); + + // schedule actual sending on internal thread + portal_->getIoService().dispatch( + [this, content_object{std::move(co)}, content_name]() mutable { + produceInternal(std::move(content_object), content_name); + }); + + return 1; +} + +void RTCProductionProtocol::produceInternal( + std::shared_ptr<ContentObject> &&content_object, const Name &content_name) { + // set rtc header + struct rtc::data_packet_t *data_pkt = + (struct rtc::data_packet_t *)content_object->getPayload()->data(); + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + data_pkt->setTimestamp(now); + data_pkt->setProductionRate(bytes_production_rate_); + + // set hicn stuff + Name n(content_name); + content_object->setName(n.setSuffix(current_seg_)); + content_object->setLifetime(500); // XXX this should be set by the APP + content_object->setPathLabel(prod_label_); + + // update stats + produced_bytes_ += + content_object->headerSize() + content_object->payloadSize(); + produced_packets_++; + + if (produced_packets_ >= max_packet_production_) { + // in this case all the pending interests may be used to accomodate the + // sudden increase in the production rate. calling the updateStats we will + // notify all the clients + round_timer_->cancel(); + updateStats(); + } + + TRANSPORT_LOGD("Sending content object: %s", n.toString().c_str()); + + output_buffer_.insert(content_object); + + if (*on_content_object_in_output_buffer_) { + on_content_object_in_output_buffer_->operator()(*socket_->getInterface(), + *content_object); + } + + portal_->sendContentObject(*content_object); + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), + *content_object); + } + + // remove interests from the interest cache if it exists + removeFromInterestQueue(current_seg_); + + current_seg_ = (current_seg_ + 1) % rtc::MIN_PROBE_SEQ; +} + +void RTCProductionProtocol::onInterest(Interest &interest) { + uint32_t interest_seg = interest.getName().getSuffix(); + uint32_t lifetime = interest.getLifetime(); + + if (interest_seg == 0) { + // first packet from the consumer, reset sync state + consumer_in_sync_ = false; + } + + if (*on_interest_input_) { + on_interest_input_->operator()(*socket_->getInterface(), interest); + } + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + if (interest_seg > rtc::MIN_PROBE_SEQ) { + TRANSPORT_LOGD("received probe %u", interest_seg); + sendNack(interest_seg); + return; + } + + TRANSPORT_LOGD("received interest %u", interest_seg); + + const std::shared_ptr<ContentObject> content_object = + output_buffer_.find(interest); + + if (content_object) { + if (*on_interest_satisfied_output_buffer_) { + on_interest_satisfied_output_buffer_->operator()(*socket_->getInterface(), + interest); + } + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), + *content_object); + } + + TRANSPORT_LOGD("Send content %u (onInterest)", + content_object->getName().getSuffix()); + portal_->sendContentObject(*content_object); + return; + } else { + if (*on_interest_process_) { + on_interest_process_->operator()(*socket_->getInterface(), interest); + } + } + + // if the production rate 0 use delayed nacks + if (allow_delayed_nacks_ && interest_seg >= current_seg_) { + uint64_t next_timer = ~0; + if (!timers_map_.empty()) { + next_timer = timers_map_.begin()->first; + } + + uint64_t expiration = now + rtc::SENTINEL_TIMER_INTERVAL; + addToInterestQueue(interest_seg, expiration); + + // here we have at least one interest in the queue, we need to start or + // update the timer + if (!queue_timer_on_) { + // set timeout + queue_timer_on_ = true; + scheduleQueueTimer(timers_map_.begin()->first - now); + } else { + // re-schedule the timer because a new interest will expires sooner + if (next_timer > timers_map_.begin()->first) { + interests_queue_timer_->cancel(); + scheduleQueueTimer(timers_map_.begin()->first - now); + } + } + return; + } + + if (queue_timer_on_) { + // the producer is producing. Send nacks to packets that will expire before + // the data production and remove the timer + queue_timer_on_ = false; + interests_queue_timer_->cancel(); + sendNacksForPendingInterests(); + } + + uint32_t max_gap = (uint32_t)floor( + (double)((double)((double)lifetime * + rtc::INTEREST_LIFETIME_REDUCTION_FACTOR / + rtc::MILLI_IN_A_SEC) * + (double)packets_production_rate_)); + + if (interest_seg < current_seg_ || interest_seg > (max_gap + current_seg_)) { + sendNack(interest_seg); + } else { + if (!consumer_in_sync_ && on_consumer_in_sync_) { + // we consider the remote consumer to be in sync as soon as it covers 70% + // of the production window with interests + uint32_t perc = ceil((double)max_gap * 0.7); + if (interest_seg > (perc + current_seg_)) { + consumer_in_sync_ = true; + on_consumer_in_sync_(*socket_->getInterface(), interest); + } + } + uint64_t expiration = + now + floor((double)lifetime * rtc::INTEREST_LIFETIME_REDUCTION_FACTOR); + addToInterestQueue(interest_seg, expiration); + } +} + +void RTCProductionProtocol::onError(std::error_code ec) {} + +void RTCProductionProtocol::scheduleQueueTimer(uint64_t wait) { + interests_queue_timer_->expires_from_now(std::chrono::milliseconds(wait)); + interests_queue_timer_->async_wait([this](std::error_code ec) { + if (ec) return; + interestQueueTimer(); + }); +} + +void RTCProductionProtocol::addToInterestQueue(uint32_t interest_seg, + uint64_t expiration) { + // check if the seq number exists already + auto it_seqs = seqs_map_.find(interest_seg); + if (it_seqs != seqs_map_.end()) { + // the seq already exists + if (expiration < it_seqs->second) { + // we need to update the timer becasue we got a smaller one + // 1) remove the entry from the multimap + // 2) update this entry + auto range = timers_map_.equal_range(it_seqs->second); + for (auto it_timers = range.first; it_timers != range.second; + it_timers++) { + if (it_timers->second == it_seqs->first) { + timers_map_.erase(it_timers); + break; + } + } + timers_map_.insert( + std::pair<uint64_t, uint32_t>(expiration, interest_seg)); + it_seqs->second = expiration; + } else { + // nothing to do here + return; + } + } else { + // add the new seq + timers_map_.insert(std::pair<uint64_t, uint32_t>(expiration, interest_seg)); + seqs_map_.insert(std::pair<uint32_t, uint64_t>(interest_seg, expiration)); + } +} + +void RTCProductionProtocol::sendNacksForPendingInterests() { + std::unordered_set<uint32_t> to_remove; + + uint32_t packet_gap = 100000; // set it to a high value (100sec) + if (packets_production_rate_ != 0) + packet_gap = ceil(rtc::MILLI_IN_A_SEC / (double)packets_production_rate_); + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + for (auto it = seqs_map_.begin(); it != seqs_map_.end(); it++) { + if (it->first > current_seg_) { + uint64_t production_time = + ((it->first - current_seg_) * packet_gap) + now; + if (production_time >= it->second) { + sendNack(it->first); + to_remove.insert(it->first); + } + } + } + + // delete nacked interests + for (auto it = to_remove.begin(); it != to_remove.end(); it++) { + removeFromInterestQueue(*it); + } +} + +void RTCProductionProtocol::removeFromInterestQueue(uint32_t interest_seg) { + auto seq_it = seqs_map_.find(interest_seg); + if (seq_it != seqs_map_.end()) { + auto range = timers_map_.equal_range(seq_it->second); + for (auto it_timers = range.first; it_timers != range.second; it_timers++) { + if (it_timers->second == seq_it->first) { + timers_map_.erase(it_timers); + break; + } + } + seqs_map_.erase(seq_it); + } +} + +void RTCProductionProtocol::interestQueueTimer() { + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + for (auto it_timers = timers_map_.begin(); it_timers != timers_map_.end();) { + uint64_t expire = it_timers->first; + if (expire <= now) { + uint32_t seq = it_timers->second; + sendNack(seq); + // remove the interest from the other map + seqs_map_.erase(seq); + it_timers = timers_map_.erase(it_timers); + } else { + // stop, we are done! + break; + } + } + if (timers_map_.empty()) { + queue_timer_on_ = false; + } else { + queue_timer_on_ = true; + scheduleQueueTimer(timers_map_.begin()->first - now); + } +} + +void RTCProductionProtocol::sendNack(uint32_t sequence) { + auto nack = core::PacketManager<>::getInstance().getPacket<ContentObject>(); + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint32_t next_packet = current_seg_; + uint32_t prod_rate = bytes_production_rate_; + + struct rtc::nack_packet_t header; + header.setTimestamp(now); + header.setProductionRate(prod_rate); + header.setProductionSegement(next_packet); + nack->appendPayload((const uint8_t *)&header, rtc::NACK_HEADER_SIZE); + + Name n(flow_name_); + n.setSuffix(sequence); + nack->setName(n); + nack->setLifetime(0); + nack->setPathLabel(prod_label_); + + if (!consumer_in_sync_ && on_consumer_in_sync_ && + sequence < rtc::MIN_PROBE_SEQ && sequence > next_packet) { + consumer_in_sync_ = true; + auto interest = core::PacketManager<>::getInstance().getPacket<Interest>(); + interest->setName(n); + on_consumer_in_sync_(*socket_->getInterface(), *interest); + } + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), *nack); + } + + TRANSPORT_LOGD("Send nack %u", sequence); + portal_->sendContentObject(*nack); +} + +} // namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/prod_protocol_rtc.h b/libtransport/src/protocols/prod_protocol_rtc.h new file mode 100644 index 000000000..f3584f74a --- /dev/null +++ b/libtransport/src/protocols/prod_protocol_rtc.h @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/core/name.h> +#include <protocols/production_protocol.h> + +#include <atomic> +#include <map> +#include <mutex> + +namespace transport { +namespace protocol { + +class RTCProductionProtocol : public ProductionProtocol { + public: + RTCProductionProtocol(implementation::ProducerSocket *icn_socket); + ~RTCProductionProtocol() override; + + using ProductionProtocol::start; + using ProductionProtocol::stop; + + void produce(ContentObject &content_object) override; + uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) override; + uint32_t produceStream(const Name &content_name, const uint8_t *buffer, + size_t buffer_size, bool is_last = true, + uint32_t start_offset = 0) override; + uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) override; + uint32_t produceDatagram(const Name &content_name, const uint8_t *buffer, + size_t buffer_size) override { + return produceDatagram(content_name, utils::MemBuf::wrapBuffer( + buffer, buffer_size, buffer_size)); + } + + void registerNamespaceWithNetwork(const Prefix &producer_namespace) override; + + void setConsumerInSyncCallback( + interface::ProducerInterestCallback &&callback) { + on_consumer_in_sync_ = std::move(callback); + } + + private: + // packet handlers + void onInterest(Interest &interest) override; + void onError(std::error_code ec) override; + void produceInternal(std::shared_ptr<ContentObject> &&content_object, + const Name &content_name); + void sendNack(uint32_t sequence); + + // stats + void updateStats(); + void scheduleRoundTimer(); + + // pending intersts functions + void addToInterestQueue(uint32_t interest_seg, uint64_t expiration); + void sendNacksForPendingInterests(); + void removeFromInterestQueue(uint32_t interest_seg); + void scheduleQueueTimer(uint64_t wait); + void interestQueueTimer(); + + core::Name flow_name_; + + uint32_t current_seg_; // seq id of the next packet produced + uint32_t prod_label_; // path lable of the producer + uint16_t header_size_; // hicn header size + + uint32_t produced_bytes_; // bytes produced in the last round + uint32_t produced_packets_; // packet produed in the last round + + uint32_t max_packet_production_; // never exceed this number of packets + // without update stats + + uint32_t bytes_production_rate_; // bytes per sec + uint32_t packets_production_rate_; // pps + + std::unique_ptr<asio::steady_timer> round_timer_; + uint64_t last_round_; + + // delayed nacks are used by the producer to avoid to send too + // many nacks we the producer rate is 0. however, if the producer moves + // from a production rate higher than 0 to 0 the first round the dealyed + // should be avoided in order to notify the consumer as fast as possible + // of the new rate. + bool allow_delayed_nacks_; + + // queue for the received interests + // this map maps the expiration time of an interest to + // its sequence number. the map is sorted by timeouts + // the same timeout may be used for multiple sequence numbers + // but for each sequence number we store only the smallest + // expiry time. In this way the mapping from seqs_map_ to + // timers_map_ is unique + std::multimap<uint64_t, uint32_t> timers_map_; + + // this map does the opposite, this map is not ordered + std::unordered_map<uint32_t, uint64_t> seqs_map_; + bool queue_timer_on_; + std::unique_ptr<asio::steady_timer> interests_queue_timer_; + + // this callback is called when the remote consumer is in sync with high + // probability. it is called only the first time that the switch happen. + // XXX this makes sense only in P2P mode, while in standard mode is + // impossible to know the state of the consumers so it should not be used. + bool consumer_in_sync_; + interface::ProducerInterestCallback on_consumer_in_sync_; +}; + +} // namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/production_protocol.cc b/libtransport/src/protocols/production_protocol.cc new file mode 100644 index 000000000..8addf52d1 --- /dev/null +++ b/libtransport/src/protocols/production_protocol.cc @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <implementation/socket_producer.h> +#include <protocols/production_protocol.h> + +namespace transport { + +namespace protocol { + +using namespace interface; + +ProductionProtocol::ProductionProtocol( + implementation::ProducerSocket *icn_socket) + : socket_(icn_socket), + is_running_(false), + on_interest_input_(VOID_HANDLER), + on_interest_dropped_input_buffer_(VOID_HANDLER), + on_interest_inserted_input_buffer_(VOID_HANDLER), + on_interest_satisfied_output_buffer_(VOID_HANDLER), + on_interest_process_(VOID_HANDLER), + on_new_segment_(VOID_HANDLER), + on_content_object_to_sign_(VOID_HANDLER), + on_content_object_in_output_buffer_(VOID_HANDLER), + on_content_object_output_(VOID_HANDLER), + on_content_object_evicted_from_output_buffer_(VOID_HANDLER), + on_content_produced_(VOID_HANDLER) { + socket_->getSocketOption(GeneralTransportOptions::PORTAL, portal_); + // TODO add statistics for producer + // socket_->getSocketOption(OtherOptions::STATISTICS, &stats_); +} + +ProductionProtocol::~ProductionProtocol() { + if (!is_async_ && is_running_) { + stop(); + } + + if (listening_thread_.joinable()) { + listening_thread_.join(); + } +} + +int ProductionProtocol::start() { + socket_->getSocketOption(ProducerCallbacksOptions::INTEREST_INPUT, + &on_interest_input_); + socket_->getSocketOption(ProducerCallbacksOptions::INTEREST_DROP, + &on_interest_dropped_input_buffer_); + socket_->getSocketOption(ProducerCallbacksOptions::INTEREST_PASS, + &on_interest_inserted_input_buffer_); + socket_->getSocketOption(ProducerCallbacksOptions::CACHE_HIT, + &on_interest_satisfied_output_buffer_); + socket_->getSocketOption(ProducerCallbacksOptions::CACHE_MISS, + &on_interest_process_); + socket_->getSocketOption(ProducerCallbacksOptions::NEW_CONTENT_OBJECT, + &on_new_segment_); + socket_->getSocketOption(ProducerCallbacksOptions::CONTENT_OBJECT_READY, + &on_content_object_in_output_buffer_); + socket_->getSocketOption(ProducerCallbacksOptions::CONTENT_OBJECT_OUTPUT, + &on_content_object_output_); + socket_->getSocketOption(ProducerCallbacksOptions::CONTENT_OBJECT_TO_SIGN, + &on_content_object_to_sign_); + socket_->getSocketOption(ProducerCallbacksOptions::CONTENT_PRODUCED, + &on_content_produced_); + + socket_->getSocketOption(GeneralTransportOptions::ASYNC_MODE, is_async_); + + bool first = true; + + for (core::Prefix &producer_namespace : served_namespaces_) { + if (first) { + core::BindConfig bind_config(producer_namespace, 1000); + portal_->bind(bind_config); + portal_->setProducerCallback(this); + first = !first; + } else { + portal_->registerRoute(producer_namespace); + } + } + + is_running_ = true; + + if (!is_async_) { + listening_thread_ = std::thread([this]() { portal_->runEventsLoop(); }); + } + + return 0; +} + +void ProductionProtocol::stop() { + is_running_ = false; + + if (!is_async_) { + portal_->stopEventsLoop(); + } else { + portal_->clear(); + } +} + +void ProductionProtocol::produce(ContentObject &content_object) { + if (*on_content_object_in_output_buffer_) { + on_content_object_in_output_buffer_->operator()(*socket_->getInterface(), + content_object); + } + + output_buffer_.insert(std::static_pointer_cast<ContentObject>( + content_object.shared_from_this())); + + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), + content_object); + } + + portal_->sendContentObject(content_object); +} + +void ProductionProtocol::registerNamespaceWithNetwork( + const Prefix &producer_namespace) { + served_namespaces_.push_back(producer_namespace); +} + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/production_protocol.h b/libtransport/src/protocols/production_protocol.h new file mode 100644 index 000000000..780972321 --- /dev/null +++ b/libtransport/src/protocols/production_protocol.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/callbacks.h> +#include <hicn/transport/interfaces/socket_producer.h> +#include <hicn/transport/interfaces/statistics.h> +#include <hicn/transport/utils/object_pool.h> +#include <implementation/socket.h> +#include <utils/content_store.h> + +#include <atomic> +#include <thread> + +namespace transport { + +namespace protocol { + +using namespace core; + +class ProductionProtocol : public Portal::ProducerCallback { + public: + ProductionProtocol(implementation::ProducerSocket *icn_socket); + virtual ~ProductionProtocol(); + + bool isRunning() { return is_running_; } + + virtual int start(); + virtual void stop(); + + virtual void produce(ContentObject &content_object); + virtual uint32_t produceStream(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, + uint32_t start_offset = 0) = 0; + virtual uint32_t produceStream(const Name &content_name, + const uint8_t *buffer, size_t buffer_size, + bool is_last = true, + uint32_t start_offset = 0) = 0; + virtual uint32_t produceDatagram(const Name &content_name, + std::unique_ptr<utils::MemBuf> &&buffer) = 0; + virtual uint32_t produceDatagram(const Name &content_name, + const uint8_t *buffer, + size_t buffer_size) = 0; + + void setOutputBufferSize(std::size_t size) { output_buffer_.setLimit(size); } + std::size_t getOutputBufferSize() { return output_buffer_.getLimit(); } + + virtual void registerNamespaceWithNetwork(const Prefix &producer_namespace); + const std::list<Prefix> &getNamespaces() const { return served_namespaces_; } + + protected: + // Producer callback + virtual void onInterest(core::Interest &i) override = 0; + virtual void onError(std::error_code ec) override{}; + + protected: + implementation::ProducerSocket *socket_; + + // Thread pool responsible for IO operations (send data / receive interests) + std::vector<utils::EventThread> io_threads_; + + // TODO remove this thread + std::thread listening_thread_; + std::shared_ptr<Portal> portal_; + std::atomic<bool> is_running_; + interface::ProductionStatistics *stats_; + + // Callbacks + interface::ProducerInterestCallback *on_interest_input_; + interface::ProducerInterestCallback *on_interest_dropped_input_buffer_; + interface::ProducerInterestCallback *on_interest_inserted_input_buffer_; + interface::ProducerInterestCallback *on_interest_satisfied_output_buffer_; + interface::ProducerInterestCallback *on_interest_process_; + + interface::ProducerContentObjectCallback *on_new_segment_; + interface::ProducerContentObjectCallback *on_content_object_to_sign_; + interface::ProducerContentObjectCallback *on_content_object_in_output_buffer_; + interface::ProducerContentObjectCallback *on_content_object_output_; + interface::ProducerContentObjectCallback + *on_content_object_evicted_from_output_buffer_; + + interface::ProducerContentCallback *on_content_produced_; + + // Output buffer + utils::ContentStore output_buffer_; + + // List ot routes served by current producer protocol + std::list<Prefix> served_namespaces_; + + bool is_async_; +}; + +} // end namespace protocol +} // end namespace transport diff --git a/libtransport/src/protocols/raaqm.cc b/libtransport/src/protocols/raaqm.cc index 5023adf2e..bc8500227 100644 --- a/libtransport/src/protocols/raaqm.cc +++ b/libtransport/src/protocols/raaqm.cc @@ -13,6 +13,7 @@ * limitations under the License. */ +#include <hicn/transport/core/global_object_pool.h> #include <hicn/transport/interfaces/socket_consumer.h> #include <implementation/socket_consumer.h> #include <protocols/errors.h> @@ -126,10 +127,6 @@ void RaaqmTransportProtocol::reset() { } } -bool RaaqmTransportProtocol::verifyKeyPackets() { - return index_manager_->onKeyToVerify(); -} - void RaaqmTransportProtocol::increaseWindow() { // return; double max_window_size = 0.; @@ -325,8 +322,8 @@ void RaaqmTransportProtocol::init() { is.close(); } -void RaaqmTransportProtocol::onContentObject( - Interest::Ptr &&interest, ContentObject::Ptr &&content_object) { +void RaaqmTransportProtocol::onContentObject(Interest &interest, + ContentObject &content_object) { // Check whether makes sense to continue if (TRANSPORT_EXPECT_FALSE(!is_running_)) { return; @@ -334,54 +331,53 @@ void RaaqmTransportProtocol::onContentObject( // Call application-defined callbacks if (*on_content_object_input_) { - (*on_content_object_input_)(*socket_->getInterface(), *content_object); + (*on_content_object_input_)(*socket_->getInterface(), content_object); } if (*on_interest_satisfied_) { - (*on_interest_satisfied_)(*socket_->getInterface(), *interest); + (*on_interest_satisfied_)(*socket_->getInterface(), interest); } - if (content_object->getPayloadType() == PayloadType::CONTENT_OBJECT) { - stats_->updateBytesRecv(content_object->payloadSize()); + if (content_object.getPayloadType() == PayloadType::DATA) { + stats_->updateBytesRecv(content_object.payloadSize()); } - onContentSegment(std::move(interest), std::move(content_object)); + onContentSegment(interest, content_object); scheduleNextInterests(); } -void RaaqmTransportProtocol::onContentSegment( - Interest::Ptr &&interest, ContentObject::Ptr &&content_object) { - uint32_t incremental_suffix = content_object->getName().getSuffix(); +void RaaqmTransportProtocol::onContentSegment(Interest &interest, + ContentObject &content_object) { + uint32_t incremental_suffix = content_object.getName().getSuffix(); // Decrease in-flight interests interests_in_flight_--; // Update stats if (!interest_retransmissions_[incremental_suffix & mask]) { - afterContentReception(*interest, *content_object); + afterContentReception(interest, content_object); } - index_manager_->onContentObject(std::move(interest), - std::move(content_object)); + index_manager_->onContentObject(interest, content_object); } -void RaaqmTransportProtocol::onPacketDropped( - Interest::Ptr &&interest, ContentObject::Ptr &&content_object) { +void RaaqmTransportProtocol::onPacketDropped(Interest &interest, + ContentObject &content_object) { uint32_t max_rtx = 0; socket_->getSocketOption(GeneralTransportOptions::MAX_INTEREST_RETX, max_rtx); - uint64_t segment = interest->getName().getSuffix(); + uint64_t segment = interest.getName().getSuffix(); if (TRANSPORT_EXPECT_TRUE(interest_retransmissions_[segment & mask] < max_rtx)) { stats_->updateRetxCount(1); if (*on_interest_retransmission_) { - (*on_interest_retransmission_)(*socket_->getInterface(), *interest); + (*on_interest_retransmission_)(*socket_->getInterface(), interest); } if (*on_interest_output_) { - (*on_interest_output_)(*socket_->getInterface(), *interest); + (*on_interest_output_)(*socket_->getInterface(), interest); } if (!is_running_) { @@ -389,7 +385,7 @@ void RaaqmTransportProtocol::onPacketDropped( } interest_retransmissions_[segment & mask]++; - interest_to_retransmit_.push(std::move(interest)); + interest_to_retransmit_.push(interest.shared_from_this()); } else { TRANSPORT_LOGE( "Stop: received not trusted packet %llu times", @@ -477,6 +473,11 @@ void RaaqmTransportProtocol::scheduleNextInterests() { sendInterest(std::move(interest_to_retransmit_.front())); interest_to_retransmit_.pop(); } else { + if (TRANSPORT_EXPECT_FALSE(!is_running_ && !is_first_)) { + TRANSPORT_LOGI("Adios"); + break; + } + index = index_manager_->getNextSuffix(); if (index == IndexManager::invalid_index) { break; @@ -487,8 +488,8 @@ void RaaqmTransportProtocol::scheduleNextInterests() { } } -bool RaaqmTransportProtocol::sendInterest(std::uint64_t next_suffix) { - auto interest = getPacket(); +void RaaqmTransportProtocol::sendInterest(std::uint64_t next_suffix) { + auto interest = core::PacketManager<>::getInstance().getPacket<Interest>(); core::Name *name; socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, &name); name->setSuffix((uint32_t)next_suffix); @@ -502,19 +503,12 @@ bool RaaqmTransportProtocol::sendInterest(std::uint64_t next_suffix) { if (*on_interest_output_) { on_interest_output_->operator()(*socket_->getInterface(), *interest); } - - if (TRANSPORT_EXPECT_FALSE(!is_running_ && !is_first_)) { - return false; - } - // This is set to ~0 so that the next interest_retransmissions_ + 1, // performed by sendInterest, will result in 0 interest_retransmissions_[next_suffix & mask] = ~0; interest_timepoints_[next_suffix & mask] = utils::SteadyClock::now(); sendInterest(std::move(interest)); - - return true; } void RaaqmTransportProtocol::sendInterest(Interest::Ptr &&interest) { diff --git a/libtransport/src/protocols/raaqm.h b/libtransport/src/protocols/raaqm.h index fce4194d4..be477d39f 100644 --- a/libtransport/src/protocols/raaqm.h +++ b/libtransport/src/protocols/raaqm.h @@ -18,9 +18,9 @@ #include <hicn/transport/utils/chrono_typedefs.h> #include <protocols/byte_stream_reassembly.h> #include <protocols/congestion_window_protocol.h> -#include <protocols/protocol.h> #include <protocols/raaqm_data_path.h> #include <protocols/rate_estimation.h> +#include <protocols/transport_protocol.h> #include <queue> #include <vector> @@ -42,8 +42,6 @@ class RaaqmTransportProtocol : public TransportProtocol, void reset() override; - virtual bool verifyKeyPackets() override; - protected: static constexpr uint32_t buffer_size = 1 << interface::default_values::log_2_default_buffer_size; @@ -64,13 +62,12 @@ class RaaqmTransportProtocol : public TransportProtocol, private: void init(); - void onContentObject(Interest::Ptr &&i, ContentObject::Ptr &&c) override; + void onContentObject(Interest &i, ContentObject &c) override; - void onContentSegment(Interest::Ptr &&interest, - ContentObject::Ptr &&content_object); + void onContentSegment(Interest &interest, ContentObject &content_object); - void onPacketDropped(Interest::Ptr &&interest, - ContentObject::Ptr &&content_object) override; + void onPacketDropped(Interest &interest, + ContentObject &content_object) override; void onReassemblyFailed(std::uint32_t missing_segment) override; @@ -78,7 +75,7 @@ class RaaqmTransportProtocol : public TransportProtocol, virtual void scheduleNextInterests() override; - bool sendInterest(std::uint64_t next_suffix); + void sendInterest(std::uint64_t next_suffix); void sendInterest(Interest::Ptr &&interest); diff --git a/libtransport/src/protocols/raaqm_data_path.cc b/libtransport/src/protocols/raaqm_data_path.cc index 8bbbadcf2..f2c21b9ef 100644 --- a/libtransport/src/protocols/raaqm_data_path.cc +++ b/libtransport/src/protocols/raaqm_data_path.cc @@ -14,7 +14,6 @@ */ #include <hicn/transport/utils/chrono_typedefs.h> - #include <protocols/raaqm_data_path.h> namespace transport { diff --git a/libtransport/src/protocols/raaqm_data_path.h b/libtransport/src/protocols/raaqm_data_path.h index 3f037bc76..c0b53a690 100644 --- a/libtransport/src/protocols/raaqm_data_path.h +++ b/libtransport/src/protocols/raaqm_data_path.h @@ -16,7 +16,6 @@ #pragma once #include <hicn/transport/utils/chrono_typedefs.h> - #include <utils/min_filter.h> #include <chrono> diff --git a/libtransport/src/protocols/rate_estimation.cc b/libtransport/src/protocols/rate_estimation.cc index a2cf1aefe..5ca925760 100644 --- a/libtransport/src/protocols/rate_estimation.cc +++ b/libtransport/src/protocols/rate_estimation.cc @@ -15,7 +15,6 @@ #include <hicn/transport/interfaces/socket_options_default_values.h> #include <hicn/transport/utils/log.h> - #include <protocols/rate_estimation.h> #include <thread> diff --git a/libtransport/src/protocols/rate_estimation.h b/libtransport/src/protocols/rate_estimation.h index 17f39e0b9..42ae74194 100644 --- a/libtransport/src/protocols/rate_estimation.h +++ b/libtransport/src/protocols/rate_estimation.h @@ -16,7 +16,6 @@ #pragma once #include <hicn/transport/interfaces/statistics.h> - #include <protocols/raaqm_data_path.h> #include <chrono> diff --git a/libtransport/src/protocols/reassembly.cc b/libtransport/src/protocols/reassembly.cc index c6602153c..0e59832dc 100644 --- a/libtransport/src/protocols/reassembly.cc +++ b/libtransport/src/protocols/reassembly.cc @@ -16,7 +16,6 @@ #include <hicn/transport/interfaces/socket_consumer.h> #include <hicn/transport/utils/array.h> #include <hicn/transport/utils/membuf.h> - #include <implementation/socket_consumer.h> #include <protocols/errors.h> #include <protocols/indexer.h> diff --git a/libtransport/src/protocols/reassembly.h b/libtransport/src/protocols/reassembly.h index fdc9f2a05..385122c53 100644 --- a/libtransport/src/protocols/reassembly.h +++ b/libtransport/src/protocols/reassembly.h @@ -46,7 +46,7 @@ class Reassembly { virtual ~Reassembly() = default; - virtual void reassemble(core::ContentObject::Ptr &&content_object) = 0; + virtual void reassemble(core::ContentObject &content_object) = 0; virtual void reassemble( std::unique_ptr<core::ContentObjectManifest> &&manifest) = 0; virtual void reInitialize() = 0; diff --git a/libtransport/src/protocols/rtc/CMakeLists.txt b/libtransport/src/protocols/rtc/CMakeLists.txt new file mode 100644 index 000000000..77f065d0e --- /dev/null +++ b/libtransport/src/protocols/rtc/CMakeLists.txt @@ -0,0 +1,38 @@ +# Copyright (c) 2017-2019 Cisco and/or its affiliates. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) + +list(APPEND HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/rtc.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_state.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_ldr.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_data_path.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_consts.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_rc.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_rc_queue.h + ${CMAKE_CURRENT_SOURCE_DIR}/probe_handler.h + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_packet.h +) + +list(APPEND SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/rtc.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_state.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_ldr.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_rc_queue.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_data_path.cc + ${CMAKE_CURRENT_SOURCE_DIR}/probe_handler.cc +) + +set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) +set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE) diff --git a/libtransport/src/protocols/rtc/congestion_detection.cc b/libtransport/src/protocols/rtc/congestion_detection.cc new file mode 100644 index 000000000..e2d44ae66 --- /dev/null +++ b/libtransport/src/protocols/rtc/congestion_detection.cc @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/utils/log.h> +#include <protocols/rtc/congestion_detection.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +CongestionDetection::CongestionDetection() + : cc_estimator_(), last_processed_chunk_() {} + +CongestionDetection::~CongestionDetection() {} + +void CongestionDetection::updateStats() { + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + if (chunks_number_.empty()) return; + + uint32_t chunk_number = chunks_number_.front(); + + while (chunks_[chunk_number].getReceivedTime() + HICN_CC_STATS_MAX_DELAY_MS < + now || + chunks_[chunk_number].isComplete()) { + if (chunk_number == last_processed_chunk_.getFrameSeqNum() + 1) { + chunks_[chunk_number].setPreviousSentTime( + last_processed_chunk_.getSentTime()); + + chunks_[chunk_number].setPreviousReceivedTime( + last_processed_chunk_.getReceivedTime()); + cc_estimator_.Update(chunks_[chunk_number].getReceivedDelta(), + chunks_[chunk_number].getSentDelta(), + chunks_[chunk_number].getSentTime(), + chunks_[chunk_number].getReceivedTime(), + chunks_[chunk_number].getFrameSize(), true); + + } else { + TRANSPORT_LOGD( + "CongestionDetection::updateStats frame %u but not the \ + previous one, last one was %u currentFrame %u", + chunk_number, last_processed_chunk_.getFrameSeqNum(), + chunks_[chunk_number].getFrameSeqNum()); + } + + last_processed_chunk_ = chunks_[chunk_number]; + + chunks_.erase(chunk_number); + + chunks_number_.pop(); + if (chunks_number_.empty()) break; + + chunk_number = chunks_number_.front(); + } +} + +void CongestionDetection::addPacket(const core::ContentObject &content_object) { + auto payload = content_object.getPayload(); + uint32_t payload_size = (uint32_t)payload->length(); + uint32_t segmentNumber = content_object.getName().getSuffix(); + // uint32_t pkt = segmentNumber & modMask_; + uint64_t *sentTimePtr = (uint64_t *)payload->data(); + + // this is just for testing with hiperf, assuming a frame is 10 pkts + // in the final version, the split should be based on the timestamp in the pkt + uint32_t frameNum = (int)(segmentNumber / HICN_CC_STATS_CHUNK_SIZE); + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + if (chunks_.find(frameNum) == chunks_.end()) { + // new chunk of pkts or out of order + if (last_processed_chunk_.getFrameSeqNum() > frameNum) + return; // out of order and we already processed the chunk + + chunks_[frameNum] = FrameStats(frameNum, HICN_CC_STATS_CHUNK_SIZE); + chunks_number_.push(frameNum); + } + + chunks_[frameNum].addPacket(*sentTimePtr, now, payload_size); +} + +} // namespace rtc +} // namespace protocol +} // namespace transport diff --git a/libtransport/src/protocols/rtc/congestion_detection.h b/libtransport/src/protocols/rtc/congestion_detection.h new file mode 100644 index 000000000..17f4aa54c --- /dev/null +++ b/libtransport/src/protocols/rtc/congestion_detection.h @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/core/content_object.h> +#include <protocols/rtc/trendline_estimator.h> + +#include <map> +#include <queue> + +#define HICN_CC_STATS_CHUNK_SIZE 10 +#define HICN_CC_STATS_MAX_DELAY_MS 100 + +namespace transport { + +namespace protocol { + +namespace rtc { + +class FrameStats { + public: + FrameStats() + : frame_num_(0), + sent_time_(0), + received_time_(0), + previous_sent_time_(0), + previous_received_time_(0), + size_(0), + received_pkt_m(0), + burst_size_m(HICN_CC_STATS_CHUNK_SIZE){}; + + FrameStats(uint32_t burst_size) + : frame_num_(0), + sent_time_(0), + received_time_(0), + previous_sent_time_(0), + previous_received_time_(0), + size_(0), + received_pkt_m(0), + burst_size_m(burst_size){}; + + FrameStats(uint32_t frame_num, uint32_t burst_size) + : frame_num_(frame_num), + sent_time_(0), + received_time_(0), + previous_sent_time_(0), + previous_received_time_(0), + size_(0), + received_pkt_m(0), + burst_size_m(burst_size){}; + + FrameStats(uint32_t frame_num, uint64_t sent_time, uint64_t received_time, + uint32_t size, FrameStats previousFrame, uint32_t burst_size) + : frame_num_(frame_num), + sent_time_(sent_time), + received_time_(received_time), + previous_sent_time_(previousFrame.getSentTime()), + previous_received_time_(previousFrame.getReceivedTime()), + size_(size), + received_pkt_m(1), + burst_size_m(burst_size){}; + + void addPacket(uint64_t sent_time, uint64_t received_time, uint32_t size) { + size_ += size; + sent_time_ = + (sent_time_ == 0) ? sent_time : std::min(sent_time_, sent_time); + received_time_ = std::max(received_time, received_time_); + received_pkt_m++; + } + + bool isComplete() { return received_pkt_m == burst_size_m; } + + uint32_t getFrameSeqNum() const { return frame_num_; } + uint64_t getSentTime() const { return sent_time_; } + uint64_t getReceivedTime() const { return received_time_; } + uint32_t getFrameSize() const { return size_; } + + void setPreviousReceivedTime(uint64_t time) { + previous_received_time_ = time; + } + void setPreviousSentTime(uint64_t time) { previous_sent_time_ = time; } + + // todo manage first frame + double getReceivedDelta() { + return static_cast<double>(received_time_ - previous_received_time_); + } + double getSentDelta() { + return static_cast<double>(sent_time_ - previous_sent_time_); + } + + private: + uint32_t frame_num_; + uint64_t sent_time_; + uint64_t received_time_; + + uint64_t previous_sent_time_; + uint64_t previous_received_time_; + uint32_t size_; + + uint32_t received_pkt_m; + uint32_t burst_size_m; +}; + +class CongestionDetection { + public: + CongestionDetection(); + ~CongestionDetection(); + + void addPacket(const core::ContentObject &content_object); + + BandwidthUsage getState() { return cc_estimator_.State(); } + + void updateStats(); + + private: + TrendlineEstimator cc_estimator_; + std::map<uint32_t, FrameStats> chunks_; + std::queue<uint32_t> chunks_number_; + + FrameStats last_processed_chunk_; +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/probe_handler.cc b/libtransport/src/protocols/rtc/probe_handler.cc new file mode 100644 index 000000000..efba362d4 --- /dev/null +++ b/libtransport/src/protocols/rtc/probe_handler.cc @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/probe_handler.h> +#include <protocols/rtc/rtc_consts.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +ProbeHandler::ProbeHandler(SendProbeCallback &&send_callback, + asio::io_service &io_service) + : probe_interval_(0), + max_probes_(0), + sent_probes_(0), + probe_timer_(std::make_unique<asio::steady_timer>(io_service)), + rand_eng_((std::random_device())()), + distr_(MIN_RTT_PROBE_SEQ, MAX_RTT_PROBE_SEQ), + send_probe_callback_(std::move(send_callback)) {} + +ProbeHandler::~ProbeHandler() {} + +uint64_t ProbeHandler::getRtt(uint32_t seq) { + auto it = pending_probes_.find(seq); + + if (it == pending_probes_.end()) return 0; + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint64_t rtt = now - it->second; + if(rtt < 1) rtt = 1; + + pending_probes_.erase(it); + + return rtt; +} + +void ProbeHandler::setProbes(uint32_t probe_interval, uint32_t max_probes) { + stopProbes(); + probe_interval_ = probe_interval; + max_probes_ = max_probes; +} + +void ProbeHandler::stopProbes() { + probe_interval_ = 0; + max_probes_ = 0; + sent_probes_ = 0; + probe_timer_->cancel(); +} + +void ProbeHandler::sendProbes() { + if (probe_interval_ == 0) return; + if (max_probes_ != 0 && sent_probes_ >= max_probes_) return; + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + uint32_t seq = distr_(rand_eng_); + pending_probes_.insert(std::pair<uint32_t, uint64_t>(seq, now)); + send_probe_callback_(seq); + sent_probes_++; + + // clean up + // a probe may get lost. if the pending_probes_ size becomes bigger than + // MAX_PENDING_PROBES remove all the probes older than a seconds + if (pending_probes_.size() > MAX_PENDING_PROBES) { + for (auto it = pending_probes_.begin(); it != pending_probes_.end();) { + if ((now - it->second) > 1000) + it = pending_probes_.erase(it); + else + it++; + } + } + + if (probe_interval_ == 0) return; + + std::weak_ptr<ProbeHandler> self(shared_from_this()); + probe_timer_->expires_from_now(std::chrono::microseconds(probe_interval_)); + probe_timer_->async_wait([self](std::error_code ec) { + if (ec) return; + if (auto s = self.lock()) { + s->sendProbes(); + } + }); +} + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/probe_handler.h b/libtransport/src/protocols/rtc/probe_handler.h new file mode 100644 index 000000000..b8ed84445 --- /dev/null +++ b/libtransport/src/protocols/rtc/probe_handler.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include <hicn/transport/config.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <functional> +#include <random> +#include <unordered_map> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class ProbeHandler : public std::enable_shared_from_this<ProbeHandler> { + public: + using SendProbeCallback = std::function<void(uint32_t)>; + + public: + ProbeHandler(SendProbeCallback &&send_callback, + asio::io_service &io_service); + + ~ProbeHandler(); + + // if the function returns 0 the probe is not valaid + uint64_t getRtt(uint32_t seq); + + // reset the probes parameters. it stop the current probing. + // to restar call sendProbes. + // probe_interval = 0 means that no event will be scheduled + // max_probe = 0 means no limit to the number of probe to send + void setProbes(uint32_t probe_interval, uint32_t max_probes); + + // stop to schedule probes + void stopProbes(); + + void sendProbes(); + + private: + uint32_t probe_interval_; // us + uint32_t max_probes_; // packets + uint32_t sent_probes_; // packets + + std::unique_ptr<asio::steady_timer> probe_timer_; + + // map from seqnumber to timestamp + std::unordered_map<uint32_t, uint64_t> pending_probes_; + + // random generator + std::default_random_engine rand_eng_; + std::uniform_int_distribution<uint32_t> distr_; + + SendProbeCallback send_probe_callback_; +}; + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/rtc.cc b/libtransport/src/protocols/rtc/rtc.cc new file mode 100644 index 000000000..bb95ab686 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc.cc @@ -0,0 +1,607 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/interfaces/socket_consumer.h> +#include <implementation/socket_consumer.h> +#include <math.h> +#include <protocols/rtc/rtc.h> +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_rc_queue.h> + +#include <algorithm> + +namespace transport { + +namespace protocol { + +namespace rtc { + +using namespace interface; + +RTCTransportProtocol::RTCTransportProtocol( + implementation::ConsumerSocket *icn_socket) + : TransportProtocol(icn_socket, nullptr), + DatagramReassembly(icn_socket, this), + number_(0) { + icn_socket->getSocketOption(PORTAL, portal_); + round_timer_ = std::make_unique<asio::steady_timer>(portal_->getIoService()); + scheduler_timer_ = + std::make_unique<asio::steady_timer>(portal_->getIoService()); +} + +RTCTransportProtocol::~RTCTransportProtocol() {} + +void RTCTransportProtocol::resume() { + if (is_running_) return; + + is_running_ = true; + + newRound(); + + portal_->runEventsLoop(); + is_running_ = false; +} + +// private +void RTCTransportProtocol::initParams() { + portal_->setConsumerCallback(this); + + rc_ = std::make_shared<RTCRateControlQueue>(); + ldr_ = std::make_shared<RTCLossDetectionAndRecovery>( + std::bind(&RTCTransportProtocol::sendRtxInterest, this, + std::placeholders::_1), + portal_->getIoService()); + + state_ = std::make_shared<RTCState>( + std::bind(&RTCTransportProtocol::sendProbeInterest, this, + std::placeholders::_1), + std::bind(&RTCTransportProtocol::discoveredRtt, this), + portal_->getIoService()); + + rc_->setState(state_); + // TODO: for the moment we keep the congestion control disabled + // rc_->tunrOnRateControl(); + ldr_->setState(state_); + + // protocol state + start_send_interest_ = false; + current_state_ = SyncState::catch_up; + + // Cancel timer + number_++; + round_timer_->cancel(); + scheduler_timer_->cancel(); + scheduler_timer_on_ = false; + + // delete all timeouts and future nacks + timeouts_or_nacks_.clear(); + + // cwin vars + current_sync_win_ = INITIAL_WIN; + max_sync_win_ = INITIAL_WIN_MAX; + + // names/packets var + next_segment_ = 0; + + socket_->setSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, + RTC_INTEREST_LIFETIME); +} + +// private +void RTCTransportProtocol::reset() { + TRANSPORT_LOGD("reset called"); + initParams(); + newRound(); +} + +void RTCTransportProtocol::inactiveProducer() { + // when the producer is inactive we reset the consumer state + // cwin vars + current_sync_win_ = INITIAL_WIN; + max_sync_win_ = INITIAL_WIN_MAX; + + TRANSPORT_LOGD("Current window: %u, max_sync_win_: %u", current_sync_win_, + max_sync_win_); + + // names/packets var + next_segment_ = 0; + + ldr_->clear(); +} + +void RTCTransportProtocol::newRound() { + round_timer_->expires_from_now(std::chrono::milliseconds(ROUND_LEN)); + // TODO pass weak_ptr here + round_timer_->async_wait([this, n{number_}](std::error_code ec) { + if (ec) return; + + if (n != number_) { + return; + } + + // saving counters that will be reset on new round + uint32_t sent_retx = state_->getSentRtxInRound(); + uint32_t received_bytes = state_->getReceivedBytesInRound(); + uint32_t sent_interest = state_->getSentInterestInRound(); + uint32_t lost_data = state_->getLostData(); + uint32_t recovered_losses = state_->getRecoveredLosses(); + uint32_t received_nacks = state_->getReceivedNacksInRound(); + + bool in_sync = (current_state_ == SyncState::in_sync); + state_->onNewRound((double)ROUND_LEN, in_sync); + rc_->onNewRound((double)ROUND_LEN); + + // update sync state if needed + if (current_state_ == SyncState::in_sync) { + double cache_rate = state_->getPacketFromCacheRatio(); + if (cache_rate > MAX_DATA_FROM_CACHE) { + current_state_ = SyncState::catch_up; + } + } else { + double target_rate = state_->getProducerRate() * PRODUCTION_RATE_FRACTION; + double received_rate = state_->getReceivedRate(); + uint32_t round_without_nacks = state_->getRoundsWithoutNacks(); + double cache_ratio = state_->getPacketFromCacheRatio(); + if (round_without_nacks >= ROUNDS_IN_SYNC_BEFORE_SWITCH && + received_rate >= target_rate && cache_ratio < MAX_DATA_FROM_CACHE) { + current_state_ = SyncState::in_sync; + } + } + + TRANSPORT_LOGD("Calling updateSyncWindow in newRound function"); + updateSyncWindow(); + + sendStatsToApp(sent_retx, received_bytes, sent_interest, lost_data, + recovered_losses, received_nacks); + newRound(); + }); +} + +void RTCTransportProtocol::discoveredRtt() { + start_send_interest_ = true; + ldr_->turnOnRTX(); + updateSyncWindow(); +} + +void RTCTransportProtocol::computeMaxSyncWindow() { + double production_rate = state_->getProducerRate(); + double packet_size = state_->getAveragePacketSize(); + if (production_rate == 0.0 || packet_size == 0.0) { + // the consumer has no info about the producer, + // keep the previous maxCWin + TRANSPORT_LOGD( + "Returning in computeMaxSyncWindow because: prod_rate: %d || " + "packet_size: %d", + (int)(production_rate == 0.0), (int)(packet_size == 0.0)); + return; + } + + uint32_t lifetime = default_values::interest_lifetime; + socket_->getSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, + lifetime); + double lifetime_ms = (double)lifetime / MILLI_IN_A_SEC; + + + max_sync_win_ = + (uint32_t)ceil((production_rate * lifetime_ms * + INTEREST_LIFETIME_REDUCTION_FACTOR) / packet_size); + + max_sync_win_ = std::min(max_sync_win_, rc_->getCongesionWindow()); +} + +void RTCTransportProtocol::updateSyncWindow() { + computeMaxSyncWindow(); + + if (max_sync_win_ == INITIAL_WIN_MAX) { + if (TRANSPORT_EXPECT_FALSE(!state_->isProducerActive())) return; + + current_sync_win_ = INITIAL_WIN; + scheduleNextInterests(); + return; + } + + double prod_rate = state_->getProducerRate(); + double rtt = (double)state_->getRTT() / MILLI_IN_A_SEC; + double packet_size = state_->getAveragePacketSize(); + + // if some of the info are not available do not update the current win + if (prod_rate != 0.0 && rtt != 0.0 && packet_size != 0.0) { + current_sync_win_ = (uint32_t)ceil(prod_rate * rtt / packet_size); + current_sync_win_ += + ceil(prod_rate * (PRODUCER_BUFFER_MS / MILLI_IN_A_SEC) / packet_size); + + if(current_state_ == SyncState::catch_up) { + current_sync_win_ = current_sync_win_ * CATCH_UP_WIN_INCREMENT; + } + + current_sync_win_ = std::min(current_sync_win_, max_sync_win_); + current_sync_win_ = std::max(current_sync_win_, WIN_MIN); + } + + scheduleNextInterests(); +} + +void RTCTransportProtocol::decreaseSyncWindow() { + // called on future nack + // we have a new sample of the production rate, so update max win first + computeMaxSyncWindow(); + current_sync_win_--; + current_sync_win_ = std::max(current_sync_win_, WIN_MIN); + scheduleNextInterests(); +} + +void RTCTransportProtocol::sendInterest(Name *interest_name) { + TRANSPORT_LOGD("Sending interest for name %s", + interest_name->toString().c_str()); + + auto interest = core::PacketManager<>::getInstance().getPacket<Interest>(); + interest->setName(*interest_name); + + uint32_t lifetime = default_values::interest_lifetime; + socket_->getSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, + lifetime); + interest->setLifetime(uint32_t(lifetime)); + + if (*on_interest_output_) { + (*on_interest_output_)(*socket_->getInterface(), *interest); + } + + if (TRANSPORT_EXPECT_FALSE(!is_running_ && !is_first_)) { + return; + } + + portal_->sendInterest(std::move(interest)); +} + +void RTCTransportProtocol::sendRtxInterest(uint32_t seq) { + if (!is_running_ && !is_first_) return; + + if(!start_send_interest_) return; + + Name *interest_name = nullptr; + socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, + &interest_name); + + TRANSPORT_LOGD("send rtx %u", seq); + interest_name->setSuffix(seq); + sendInterest(interest_name); +} + +void RTCTransportProtocol::sendProbeInterest(uint32_t seq) { + if (!is_running_ && !is_first_) return; + + Name *interest_name = nullptr; + socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, + &interest_name); + + TRANSPORT_LOGD("send probe %u", seq); + interest_name->setSuffix(seq); + sendInterest(interest_name); +} + +void RTCTransportProtocol::scheduleNextInterests() { + TRANSPORT_LOGD("Schedule next interests"); + + if (!is_running_ && !is_first_) return; + + if(!start_send_interest_) return; // RTT discovering phase is not finished so + // do not start to send interests + + if (scheduler_timer_on_) return; // wait befor send other interests + + if (TRANSPORT_EXPECT_FALSE(!state_->isProducerActive())) { + TRANSPORT_LOGD("Inactive producer."); + // here we keep seding the same interest until the producer + // does not start again + if (next_segment_ != 0) { + // the producer just become inactive, reset the state + inactiveProducer(); + } + + Name *interest_name = nullptr; + socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, + &interest_name); + + TRANSPORT_LOGD("send interest %u", next_segment_); + interest_name->setSuffix(next_segment_); + + if (portal_->interestIsPending(*interest_name)) { + // if interest 0 is already pending we return + return; + } + + sendInterest(interest_name); + state_->onSendNewInterest(interest_name); + return; + } + + TRANSPORT_LOGD("Pending interest number: %d -- current_sync_win_: %d", + state_->getPendingInterestNumber(), current_sync_win_); + + // skip nacked pacekts + if (next_segment_ <= state_->getLastSeqNacked()) { + next_segment_ = state_->getLastSeqNacked() + 1; + } + + // skipe received packets + if (next_segment_ <= state_->getHighestSeqReceivedInOrder()) { + next_segment_ = state_->getHighestSeqReceivedInOrder() + 1; + } + + uint32_t sent_interests = 0; + while ((state_->getPendingInterestNumber() < current_sync_win_) && + (sent_interests < MAX_INTERESTS_IN_BATCH)) { + TRANSPORT_LOGD("In while loop. Window size: %u", current_sync_win_); + Name *interest_name = nullptr; + socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, + &interest_name); + + interest_name->setSuffix(next_segment_); + + // send the packet only if: + // 1) it is not pending yet (not true for rtx) + // 2) the packet is not received or lost + // 3) is not in the rtx list + if (portal_->interestIsPending(*interest_name) || + state_->isReceivedOrLost(next_segment_) != PacketState::UNKNOWN || + ldr_->isRtx(next_segment_)) { + TRANSPORT_LOGD( + "skip interest %u because: pending %u, recv %u, rtx %u", + next_segment_, (portal_->interestIsPending(*interest_name)), + (state_->isReceivedOrLost(next_segment_) != PacketState::UNKNOWN), + (ldr_->isRtx(next_segment_))); + next_segment_ = (next_segment_ + 1) % MIN_PROBE_SEQ; + continue; + } + + + sent_interests++; + TRANSPORT_LOGD("send interest %u", next_segment_); + sendInterest(interest_name); + state_->onSendNewInterest(interest_name); + + next_segment_ = (next_segment_ + 1) % MIN_PROBE_SEQ; + } + + if (state_->getPendingInterestNumber() < current_sync_win_) { + // we still have space in the window but we already sent a batch of + // MAX_INTERESTS_IN_BATCH interest. for the following ones wait one + // WAIT_BETWEEN_INTEREST_BATCHES to avoid local packets drop + + scheduler_timer_on_ = true; + scheduler_timer_->expires_from_now( + std::chrono::microseconds(WAIT_BETWEEN_INTEREST_BATCHES)); + scheduler_timer_->async_wait([this](std::error_code ec) { + if (ec) return; + if (!scheduler_timer_on_) return; + + scheduler_timer_on_ = false; + scheduleNextInterests(); + }); + } +} + +void RTCTransportProtocol::onTimeout(Interest::Ptr &&interest) { + uint32_t segment_number = interest->getName().getSuffix(); + + TRANSPORT_LOGD("timeout for packet %u", segment_number); + + if (segment_number >= MIN_PROBE_SEQ) { + // this is a timeout on a probe, do nothing + return; + } + + timeouts_or_nacks_.insert(segment_number); + + if (TRANSPORT_EXPECT_TRUE(state_->isProducerActive()) && + segment_number <= state_->getHighestSeqReceivedInOrder()) { + // we retransmit packets only if the producer is active, otherwise we + // use timeouts to avoid to send too much traffic + // + // a timeout is sent using RTX only if it is an old packet. if it is for a + // seq number that we didn't reach yet, we send the packet using the normal + // schedule next interest + TRANSPORT_LOGD("handle timeout for packet %u using rtx", segment_number); + ldr_->onTimeout(segment_number); + state_->onTimeout(segment_number); + scheduleNextInterests(); + return; + } + + TRANSPORT_LOGD("handle timeout for packet %u using normal interests", + segment_number); + + if (segment_number < next_segment_) { + // this is a timeout for a packet that will be generated in the future but + // we are asking for higher sequence numbers. we need to go back like in the + // case of future nacks + TRANSPORT_LOGD("on timeout next seg = %u, jump to %u", + next_segment_, segment_number); + next_segment_ = segment_number; + } + + state_->onTimeout(segment_number); + scheduleNextInterests(); +} + +void RTCTransportProtocol::onNack(const ContentObject &content_object) { + struct nack_packet_t *nack = + (struct nack_packet_t *)content_object.getPayload()->data(); + uint32_t production_seg = nack->getProductionSegement(); + uint32_t nack_segment = content_object.getName().getSuffix(); + bool is_rtx = ldr_->isRtx(nack_segment); + + // check if the packet got a timeout + + TRANSPORT_LOGD("Nack received %u. Production segment: %u", nack_segment, + production_seg); + + bool compute_stats = true; + auto tn_it = timeouts_or_nacks_.find(nack_segment); + if (tn_it != timeouts_or_nacks_.end() || is_rtx) { + compute_stats = false; + // remove packets from timeouts_or_nacks only in case of a past nack + } + + state_->onNackPacketReceived(content_object, compute_stats); + ldr_->onNackPacketReceived(content_object); + + // both in case of past and future nack we set next_segment_ equal to the + // production segment in the nack. In case of past nack we will skip unneded + // interest (this is already done in the scheduleNextInterest in any case) + // while in case of future nacks we can go back in time and ask again for the + // content that generated the nack + TRANSPORT_LOGD("on nack next seg = %u, jump to %u", + next_segment_, production_seg); + next_segment_ = production_seg; + + if (production_seg > nack_segment) { + // remove the nack is it exists + if (tn_it != timeouts_or_nacks_.end()) timeouts_or_nacks_.erase(tn_it); + + // the client is asking for content in the past + // switch to catch up state and increase the window + // this is true only if the packet is not an RTX + if (!is_rtx) current_state_ = SyncState::catch_up; + + updateSyncWindow(); + } else { + // if production_seg == nack_segment we consider this a future nack, since + // production_seg is not yet created. this may happen in case of low + // production rate (e.g. ping at 1pps) + + // if a future nack was also retransmitted add it to the timeout_or_nacks + // set + if (is_rtx) timeouts_or_nacks_.insert(nack_segment); + + // the client is asking for content in the future + // switch to in sync state and decrease the window + current_state_ = SyncState::in_sync; + decreaseSyncWindow(); + } +} + +void RTCTransportProtocol::onProbe(const ContentObject &content_object) { + bool valid = state_->onProbePacketReceived(content_object); + if(!valid) return; + + struct nack_packet_t *probe = + (struct nack_packet_t *)content_object.getPayload()->data(); + uint32_t production_seg = probe->getProductionSegement(); + + // as for the nacks set next_segment_ + TRANSPORT_LOGD("on probe next seg = %u, jump to %u", + next_segment_, production_seg); + next_segment_ = production_seg; + + ldr_->onProbePacketReceived(content_object); + updateSyncWindow(); +} + +void RTCTransportProtocol::onContentObject(Interest &interest, + ContentObject &content_object) { + TRANSPORT_LOGD("Received content object of size: %zu", + content_object.payloadSize()); + uint32_t payload_size = content_object.payloadSize(); + uint32_t segment_number = content_object.getName().getSuffix(); + + if (segment_number >= MIN_PROBE_SEQ) { + TRANSPORT_LOGD("Received probe %u", segment_number); + if (*on_content_object_input_) { + (*on_content_object_input_)(*socket_->getInterface(), content_object); + } + onProbe(content_object); + return; + } + + if (payload_size == NACK_HEADER_SIZE) { + TRANSPORT_LOGD("Received nack %u", segment_number); + if (*on_content_object_input_) { + (*on_content_object_input_)(*socket_->getInterface(), content_object); + } + onNack(content_object); + return; + } + + TRANSPORT_LOGD("Received content %u", segment_number); + + rc_->onDataPacketReceived(content_object); + bool compute_stats = true; + auto tn_it = timeouts_or_nacks_.find(segment_number); + if (tn_it != timeouts_or_nacks_.end()) { + compute_stats = false; + timeouts_or_nacks_.erase(tn_it); + } + if (ldr_->isRtx(segment_number)) { + compute_stats = false; + } + + // check if the packet was already received + PacketState state = state_->isReceivedOrLost(segment_number); + state_->onDataPacketReceived(content_object, compute_stats); + ldr_->onDataPacketReceived(content_object); + + // if the stat for this seq number is received do not send the packet to app + if (state != PacketState::RECEIVED) { + if (*on_content_object_input_) { + (*on_content_object_input_)(*socket_->getInterface(), content_object); + } + reassemble(content_object); + } else { + TRANSPORT_LOGD("Received duplicated content %u, drop it", segment_number); + } + + updateSyncWindow(); +} + +void RTCTransportProtocol::sendStatsToApp( + uint32_t retx_count, uint32_t received_bytes, uint32_t sent_interests, + uint32_t lost_data, uint32_t recovered_losses, uint32_t received_nacks) { + if (*stats_summary_) { + // Send the stats to the app + stats_->updateQueuingDelay(state_->getQueuing()); + + // stats_->updateInterestFecTx(0); //todo must be implemented + // stats_->updateBytesFecRecv(0); //todo must be implemented + + stats_->updateRetxCount(retx_count); + stats_->updateBytesRecv(received_bytes); + stats_->updateInterestTx(sent_interests); + stats_->updateReceivedNacks(received_nacks); + + stats_->updateAverageWindowSize(current_sync_win_); + stats_->updateLossRatio(state_->getLossRate()); + stats_->updateAverageRtt(state_->getRTT()); + stats_->updateLostData(lost_data); + stats_->updateRecoveredData(recovered_losses); + stats_->updateCCState((unsigned int)current_state_ ? 1 : 0); + (*stats_summary_)(*socket_->getInterface(), *stats_); + } +} + +void RTCTransportProtocol::reassemble(ContentObject &content_object) { + auto read_buffer = content_object.getPayload(); + TRANSPORT_LOGD("Size of payload: %zu", read_buffer->length()); + read_buffer->trimStart(DATA_HEADER_SIZE); + Reassembly::read_buffer_ = std::move(read_buffer); + Reassembly::notifyApplication(); +} + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc.h b/libtransport/src/protocols/rtc/rtc.h new file mode 100644 index 000000000..596887067 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc.h @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <protocols/datagram_reassembly.h> +#include <protocols/rtc/rtc_ldr.h> +#include <protocols/rtc/rtc_rc.h> +#include <protocols/rtc/rtc_state.h> +#include <protocols/transport_protocol.h> + +#include <unordered_set> +#include <vector> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class RTCTransportProtocol : public TransportProtocol, + public DatagramReassembly { + public: + RTCTransportProtocol(implementation::ConsumerSocket *icnet_socket); + + ~RTCTransportProtocol(); + + using TransportProtocol::start; + + using TransportProtocol::stop; + + void resume() override; + + private: + enum class SyncState { catch_up = 0, in_sync = 1, last }; + + private: + // setup functions + void initParams(); + void reset() override; + + void inactiveProducer(); + + // protocol functions + void discoveredRtt(); + void newRound(); + + // window functions + void computeMaxSyncWindow(); + void updateSyncWindow(); + void decreaseSyncWindow(); + + // packet functions + void sendInterest(Name *interest_name); + void sendRtxInterest(uint32_t seq); + void sendProbeInterest(uint32_t seq); + void scheduleNextInterests() override; + void onTimeout(Interest::Ptr &&interest) override; + void onNack(const ContentObject &content_object); + void onProbe(const ContentObject &content_object); + void reassemble(ContentObject &content_object) override; + void onContentObject(Interest &interest, + ContentObject &content_object) override; + void onPacketDropped(Interest &interest, + ContentObject &content_object) override {} + void onReassemblyFailed(std::uint32_t missing_segment) override {} + + // interaction with app functions + void sendStatsToApp(uint32_t retx_count, uint32_t received_bytes, + uint32_t sent_interests, uint32_t lost_data, + uint32_t recovered_losses, uint32_t received_nacks); + // protocol state + bool start_send_interest_; + SyncState current_state_; + // cwin vars + uint32_t current_sync_win_; + uint32_t max_sync_win_; + + // controller var + std::unique_ptr<asio::steady_timer> round_timer_; + std::unique_ptr<asio::steady_timer> scheduler_timer_; + bool scheduler_timer_on_; + + // timeouts + std::unordered_set<uint32_t> timeouts_or_nacks_; + + // names/packets var + uint32_t next_segment_; + + std::shared_ptr<RTCState> state_; + std::shared_ptr<RTCRateControl> rc_; + std::shared_ptr<RTCLossDetectionAndRecovery> ldr_; + + uint32_t number_; +}; + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_consts.h b/libtransport/src/protocols/rtc/rtc_consts.h new file mode 100644 index 000000000..0cf9516ab --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_consts.h @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <protocols/rtc/rtc_packet.h> +#include <stdint.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +// used in rtc +// protocol consts +const uint32_t ROUND_LEN = 200; +// ms interval of time on which +// we take decisions / measurements +const double INTEREST_LIFETIME_REDUCTION_FACTOR = 0.8; +// how big (in ms) should be the buffer at the producer. +// increasing this number we increase the time that an +// interest will wait for the data packet to be produced +// at the producer socket +const uint32_t PRODUCER_BUFFER_MS = 200; // ms + +// interest scheduler +const uint32_t MAX_INTERESTS_IN_BATCH = 5; +const uint32_t WAIT_BETWEEN_INTEREST_BATCHES = 1000; // usec + +// packet const +const uint32_t HICN_HEADER_SIZE = 40 + 20; // IPv6 + TCP bytes +const uint32_t RTC_INTEREST_LIFETIME = 1000; + +// probes sequence range +const uint32_t MIN_PROBE_SEQ = 0xefffffff; +const uint32_t MIN_RTT_PROBE_SEQ = MIN_PROBE_SEQ; +const uint32_t MAX_RTT_PROBE_SEQ = 0xffffffff - 1; +// RTT_PROBE_INTERVAL will be used during the section while +// INIT_RTT_PROBE_INTERVAL is used at the beginning to +// quickily estimate the RTT +const uint32_t RTT_PROBE_INTERVAL = 200000; // us +const uint32_t INIT_RTT_PROBE_INTERVAL = 500; // us +const uint32_t INIT_RTT_PROBES = 40; // number of probes to init RTT +// if the produdcer is not yet started we need to probe multple times +// to get an answer. we wait 100ms between each try +const uint32_t INIT_RTT_PROBE_RESTART = 100; // ms +// once we get the first probe we wait at most 60ms for the others +const uint32_t INIT_RTT_PROBE_WAIT = 30; // ms +// we reuires at least 5 probes to be recevied +const uint32_t INIT_RTT_MIN_PROBES_TO_RECV = 5; //ms +const uint32_t MAX_PENDING_PROBES = 10; + + +// congestion +const double MAX_QUEUING_DELAY = 100.0; // ms + +// data from cache +const double MAX_DATA_FROM_CACHE = 0.25; // 25% + +// window const +const uint32_t INITIAL_WIN = 5; // pkts +const uint32_t INITIAL_WIN_MAX = 1000000; // pkts +const uint32_t WIN_MIN = 5; // pkts +const double CATCH_UP_WIN_INCREMENT = 1.2; +// used in rate control +const double WIN_DECREASE_FACTOR = 0.5; +const double WIN_INCREASE_FACTOR = 1.5; + +// round in congestion +const double ROUNDS_BEFORE_TAKE_ACTION = 5; + +// used in state +const uint8_t ROUNDS_IN_SYNC_BEFORE_SWITCH = 3; +const double PRODUCTION_RATE_FRACTION = 0.8; + +const uint32_t INIT_PACKET_SIZE = 1200; + +const double MOVING_AVG_ALPHA = 0.8; + +const double MILLI_IN_A_SEC = 1000.0; +const double MICRO_IN_A_SEC = 1000000.0; + +const double MAX_CACHED_PACKETS = 262144; // 2^18 + // about 50 sec of traffic at 50Mbps + // with 1200 bytes packets + +const uint32_t MAX_ROUND_WHIOUT_PACKETS = + (20 * MILLI_IN_A_SEC) / ROUND_LEN; // 20 sec in rounds; + +// used in ldr +const uint32_t RTC_MAX_RTX = 100; +const uint32_t RTC_MAX_AGE = 60000; // in ms +const uint64_t MAX_TIMER_RTX = ~0; +const uint32_t SENTINEL_TIMER_INTERVAL = 100; // ms +const uint32_t MAX_RTX_WITH_SENTINEL = 10; // packets +const double CATCH_UP_RTT_INCREMENT = 1.2; + +// used by producer +const uint32_t PRODUCER_STATS_INTERVAL = 200; // ms +const uint32_t MIN_PRODUCTION_RATE = 10; // pps + // min prod rate + // set running several test + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_data_path.cc b/libtransport/src/protocols/rtc/rtc_data_path.cc new file mode 100644 index 000000000..c098088a3 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_data_path.cc @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/rtc_data_path.h> +#include <stdlib.h> + +#include <algorithm> +#include <cfloat> +#include <chrono> + +#define MAX_ROUNDS_WITHOUT_PKTS 10 // 2sec + +namespace transport { + +namespace protocol { + +namespace rtc { + +RTCDataPath::RTCDataPath(uint32_t path_id) + : path_id_(path_id), + min_rtt(UINT_MAX), + prev_min_rtt(UINT_MAX), + min_owd(INT_MAX), // this is computed like in LEDBAT, so it is not the + // real OWD, but the measured one, that depends on the + // clock of sender and receiver. the only meaningful + // value is is the queueing delay. for this reason we + // keep both RTT (for the windowd calculation) and OWD + // (for congestion/quality control) + prev_min_owd(INT_MAX), + avg_owd(DBL_MAX), + queuing_delay(DBL_MAX), + jitter_(0.0), + last_owd_(0), + largest_recv_seq_(0), + largest_recv_seq_time_(0), + avg_inter_arrival_(DBL_MAX), + received_nacks_(false), + received_packets_(false), + rounds_without_packets_(0), + last_received_data_packet_(0), + RTT_history_(HISTORY_LEN), + OWD_history_(HISTORY_LEN){}; + +void RTCDataPath::insertRttSample(uint64_t rtt) { + // for the rtt we only keep track of the min one + if (rtt < min_rtt) min_rtt = rtt; + last_received_data_packet_ = + std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); +} + +void RTCDataPath::insertOwdSample(int64_t owd) { + // for owd we use both min and avg + if (owd < min_owd) min_owd = owd; + + if (avg_owd != DBL_MAX) + avg_owd = (avg_owd * (1 - ALPHA_RTC)) + (owd * ALPHA_RTC); + else { + avg_owd = owd; + } + + int64_t queueVal = owd - std::min(getMinOwd(), min_owd); + + if (queuing_delay != DBL_MAX) + queuing_delay = (queuing_delay * (1 - ALPHA_RTC)) + (queueVal * ALPHA_RTC); + else { + queuing_delay = queueVal; + } + + // keep track of the jitter computed as for RTP (RFC 3550) + int64_t diff = std::abs(owd - last_owd_); + last_owd_ = owd; + jitter_ += (1.0 / 16.0) * ((double)diff - jitter_); + + // owd is computed only for valid data packets so we count only + // this for decide if we recevie traffic or not + received_packets_ = true; +} + +void RTCDataPath::computeInterArrivalGap(uint32_t segment_number) { + // got packet in sequence, compute gap + if (largest_recv_seq_ == (segment_number - 1)) { + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint64_t delta = now - largest_recv_seq_time_; + largest_recv_seq_ = segment_number; + largest_recv_seq_time_ = now; + if (avg_inter_arrival_ == DBL_MAX) + avg_inter_arrival_ = delta; + else + avg_inter_arrival_ = + (avg_inter_arrival_ * (1 - ALPHA_RTC)) + (delta * ALPHA_RTC); + return; + } + + // ooo packet, update the stasts if needed + if (largest_recv_seq_ <= segment_number) { + largest_recv_seq_ = segment_number; + largest_recv_seq_time_ = + std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + } +} + +void RTCDataPath::receivedNack() { received_nacks_ = true; } + +double RTCDataPath::getInterArrivalGap() { + if (avg_inter_arrival_ == DBL_MAX) return 0; + return avg_inter_arrival_; +} + +bool RTCDataPath::isActive() { + if (received_nacks_ && rounds_without_packets_ < MAX_ROUNDS_WITHOUT_PKTS) + return true; + return false; +} + +bool RTCDataPath::pathToProducer() { + if (received_nacks_) return true; + return false; +} + +void RTCDataPath::roundEnd() { + // reset min_rtt and add it to the history + if (min_rtt != UINT_MAX) { + prev_min_rtt = min_rtt; + } else { + // this may happen if we do not receive any packet + // from this path in the last round. in this case + // we use the measure from the previuos round + min_rtt = prev_min_rtt; + } + + if (min_rtt == 0) min_rtt = 1; + + RTT_history_.pushBack(min_rtt); + min_rtt = UINT_MAX; + + // do the same for min owd + if (min_owd != INT_MAX) { + prev_min_owd = min_owd; + } else { + min_owd = prev_min_owd; + } + + if (min_owd != INT_MAX) { + OWD_history_.pushBack(min_owd); + min_owd = INT_MAX; + } + + if (!received_packets_) + rounds_without_packets_++; + else + rounds_without_packets_ = 0; + received_packets_ = false; +} + +uint32_t RTCDataPath::getPathId() { return path_id_; } + +double RTCDataPath::getQueuingDealy() { return queuing_delay; } + +uint64_t RTCDataPath::getMinRtt() { + if (RTT_history_.size() != 0) return RTT_history_.begin(); + return 0; +} + +int64_t RTCDataPath::getMinOwd() { + if (OWD_history_.size() != 0) return OWD_history_.begin(); + return 0; +} + +double RTCDataPath::getJitter() { return jitter_; } + +uint64_t RTCDataPath::getLastPacketTS() { return last_received_data_packet_; } + +void RTCDataPath::clearRtt() { RTT_history_.clear(); } + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_data_path.h b/libtransport/src/protocols/rtc/rtc_data_path.h new file mode 100644 index 000000000..c5c37fc0d --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_data_path.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <stdint.h> +#include <utils/min_filter.h> + +#include <climits> + +namespace transport { + +namespace protocol { + +namespace rtc { + +const double ALPHA_RTC = 0.125; +const uint32_t HISTORY_LEN = 20; // 4 sec + +class RTCDataPath { + public: + RTCDataPath(uint32_t path_id); + + public: + void insertRttSample(uint64_t rtt); + void insertOwdSample(int64_t owd); + void computeInterArrivalGap(uint32_t segment_number); + void receivedNack(); + + uint32_t getPathId(); + uint64_t getMinRtt(); + double getQueuingDealy(); + double getInterArrivalGap(); + double getJitter(); + bool isActive(); + bool pathToProducer(); + uint64_t getLastPacketTS(); + + void clearRtt(); + + void roundEnd(); + + private: + uint32_t path_id_; + + int64_t getMinOwd(); + + uint64_t min_rtt; + uint64_t prev_min_rtt; + + int64_t min_owd; + int64_t prev_min_owd; + + double avg_owd; + + double queuing_delay; + + double jitter_; + int64_t last_owd_; + + uint32_t largest_recv_seq_; + uint64_t largest_recv_seq_time_; + double avg_inter_arrival_; + + // flags to check if a path is active + // we considere a path active if it reaches a producer + //(not a cache) --aka we got at least one nack on this path-- + // and if we receives packets + bool received_nacks_; + bool received_packets_; + uint8_t rounds_without_packets_; // if we don't get any packet + // for MAX_ROUNDS_WITHOUT_PKTS + // we consider the path inactive + uint64_t last_received_data_packet_; // timestamp for the last data received + // on this path + + utils::MinFilter<uint64_t> RTT_history_; + utils::MinFilter<int64_t> OWD_history_; +}; + +} // namespace rtc + +} // namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_ldr.cc b/libtransport/src/protocols/rtc/rtc_ldr.cc new file mode 100644 index 000000000..e91b29c04 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_ldr.cc @@ -0,0 +1,427 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_ldr.h> + +#include <algorithm> +#include <unordered_set> + +namespace transport { + +namespace protocol { + +namespace rtc { + +RTCLossDetectionAndRecovery::RTCLossDetectionAndRecovery( + SendRtxCallback &&callback, asio::io_service &io_service) + : rtx_on_(false), + next_rtx_timer_(MAX_TIMER_RTX), + last_event_(0), + sentinel_timer_interval_(MAX_TIMER_RTX), + send_rtx_callback_(std::move(callback)) { + timer_ = std::make_unique<asio::steady_timer>(io_service); + sentinel_timer_ = std::make_unique<asio::steady_timer>(io_service); +} + +RTCLossDetectionAndRecovery::~RTCLossDetectionAndRecovery() {} + +void RTCLossDetectionAndRecovery::turnOnRTX() { + rtx_on_ = true; + scheduleSentinelTimer(state_->getRTT() * CATCH_UP_RTT_INCREMENT); +} + +void RTCLossDetectionAndRecovery::turnOffRTX() { + rtx_on_ = false; + clear(); +} + +void RTCLossDetectionAndRecovery::onTimeout(uint32_t seq) { + // always add timeouts to the RTX list to avoid to send the same packet as if + // it was not a rtx + addToRetransmissions(seq, seq + 1); + last_event_ = getNow(); +} + +void RTCLossDetectionAndRecovery::onDataPacketReceived( + const core::ContentObject &content_object) { + last_event_ = getNow(); + + uint32_t seq = content_object.getName().getSuffix(); + if (deleteRtx(seq)) { + state_->onPacketRecovered(seq); + } else { + if (TRANSPORT_EXPECT_FALSE(!rtx_on_)) return; // do not add if RTX is off + TRANSPORT_LOGD("received data. add from %u to %u ", + state_->getHighestSeqReceivedInOrder() + 1, seq); + addToRetransmissions(state_->getHighestSeqReceivedInOrder() + 1, seq); + } +} + +void RTCLossDetectionAndRecovery::onNackPacketReceived( + const core::ContentObject &nack) { + last_event_ = getNow(); + + uint32_t seq = nack.getName().getSuffix(); + + if (TRANSPORT_EXPECT_FALSE(!rtx_on_)) return; // do not add if RTX is off + + struct nack_packet_t *nack_pkt = + (struct nack_packet_t *)nack.getPayload()->data(); + uint32_t production_seq = nack_pkt->getProductionSegement(); + + if (production_seq > seq) { + // this is a past nack, all data before productionSeq are lost. if + // productionSeq > state_->getHighestSeqReceivedInOrder() is impossible to + // recover any packet. If this is not the case we can try to recover the + // packets between state_->getHighestSeqReceivedInOrder() and productionSeq. + // e.g.: the client receives packets 8 10 11 9 where 9 is a nack with + // productionSeq = 14. 9 is lost but we can try to recover packets 12 13 and + // 14 that are not arrived yet + deleteRtx(seq); + TRANSPORT_LOGD("received past nack. add from %u to %u ", + state_->getHighestSeqReceivedInOrder() + 1, production_seq); + addToRetransmissions(state_->getHighestSeqReceivedInOrder() + 1, + production_seq); + } else { + // future nack. here there should be a gap between the last data received + // and this packet and is it possible to recover the packets between the + // last received data and the production seq. we should not use the seq + // number of the nack since we know that is too early to ask for this seq + // number + // e.g.: // e.g.: the client receives packets 10 11 12 20 where 20 is a nack + // with productionSeq = 18. this says that all the packets between 12 and 18 + // may got lost and we should ask them + deleteRtx(seq); + TRANSPORT_LOGD("received futrue nack. add from %u to %u ", + state_->getHighestSeqReceivedInOrder() + 1, production_seq); + addToRetransmissions(state_->getHighestSeqReceivedInOrder() + 1, + production_seq); + } +} + +void RTCLossDetectionAndRecovery::onProbePacketReceived( + const core::ContentObject &probe) { + // we don't log the reception of a probe packet for the sentinel timer because + // probes are not taken into account into the sync window. we use them as + // future nacks to detect possible packets lost + if (TRANSPORT_EXPECT_FALSE(!rtx_on_)) return; // do not add if RTX is off + struct nack_packet_t *probe_pkt = + (struct nack_packet_t *)probe.getPayload()->data(); + uint32_t production_seq = probe_pkt->getProductionSegement(); + TRANSPORT_LOGD("received probe. add from %u to %u ", + state_->getHighestSeqReceivedInOrder() + 1, production_seq); + addToRetransmissions(state_->getHighestSeqReceivedInOrder() + 1, + production_seq); +} + +void RTCLossDetectionAndRecovery::clear() { + rtx_state_.clear(); + rtx_timers_.clear(); + sentinel_timer_->cancel(); + if (next_rtx_timer_ != MAX_TIMER_RTX) { + next_rtx_timer_ = MAX_TIMER_RTX; + timer_->cancel(); + } +} + +void RTCLossDetectionAndRecovery::addToRetransmissions(uint32_t start, + uint32_t stop) { + // skip nacked packets + if (start <= state_->getLastSeqNacked()) { + start = state_->getLastSeqNacked() + 1; + } + + // skip received or lost packets + if (start <= state_->getHighestSeqReceivedInOrder()) { + start = state_->getHighestSeqReceivedInOrder() + 1; + } + + for (uint32_t seq = start; seq < stop; seq++) { + if (!isRtx(seq) && // is not already an rtx + // is not received or lost + state_->isReceivedOrLost(seq) == PacketState::UNKNOWN) { + // add rtx + rtxState state; + state.first_send_ = state_->getInterestSentTime(seq); + if (state.first_send_ == 0) // this interest was never sent before + state.first_send_ = getNow(); + state.next_send_ = computeNextSend(seq, true); + state.rtx_count_ = 0; + TRANSPORT_LOGD("add %u to retransmissions. next rtx is %lu ", seq, + (state.next_send_ - getNow())); + rtx_state_.insert(std::pair<uint32_t, rtxState>(seq, state)); + rtx_timers_.insert(std::pair<uint64_t, uint32_t>(state.next_send_, seq)); + } + } + scheduleNextRtx(); +} + +uint64_t RTCLossDetectionAndRecovery::computeNextSend(uint32_t seq, + bool new_rtx) { + uint64_t now = getNow(); + if (new_rtx) { + // for the new rtx we wait one estimated IAT after the loss detection. this + // is bacause, assuming that packets arrive with a constant IAT, we should + // get a new packet every IAT + double prod_rate = state_->getProducerRate(); + uint32_t estimated_iat = SENTINEL_TIMER_INTERVAL; + uint32_t jitter = 0; + + if (prod_rate != 0) { + double packet_size = state_->getAveragePacketSize(); + estimated_iat = ceil(1000.0 / (prod_rate / packet_size)); + jitter = ceil(state_->getJitter()); + } + + uint32_t wait = estimated_iat + jitter; + TRANSPORT_LOGD("first rtx for %u in %u ms, rtt = %lu ait = %u jttr = %u", + seq, wait, state_->getRTT(), estimated_iat, jitter); + + return now + wait; + } else { + // wait one RTT + // however if the IAT is larger than the RTT, wait one IAT + uint32_t wait = SENTINEL_TIMER_INTERVAL; + + double prod_rate = state_->getProducerRate(); + if (prod_rate == 0) { + return now + SENTINEL_TIMER_INTERVAL; + } + + double packet_size = state_->getAveragePacketSize(); + uint32_t estimated_iat = ceil(1000.0 / (prod_rate / packet_size)); + + uint64_t rtt = state_->getRTT(); + if (rtt == 0) rtt = SENTINEL_TIMER_INTERVAL; + wait = rtt; + + if (estimated_iat > rtt) wait = estimated_iat; + + uint32_t jitter = ceil(state_->getJitter()); + wait += jitter; + + // it may happen that the channel is congested and we have some additional + // queuing delay to take into account + uint32_t queue = ceil(state_->getQueuing()); + wait += queue; + + TRANSPORT_LOGD( + "next rtx for %u in %u ms, rtt = %lu ait = %u jttr = %u queue = %u", + seq, wait, state_->getRTT(), estimated_iat, jitter, queue); + + return now + wait; + } +} + +void RTCLossDetectionAndRecovery::retransmit() { + if (rtx_timers_.size() == 0) return; + + uint64_t now = getNow(); + + auto it = rtx_timers_.begin(); + std::unordered_set<uint32_t> lost_pkt; + uint32_t sent_counter = 0; + while (it != rtx_timers_.end() && it->first <= now && + sent_counter < MAX_INTERESTS_IN_BATCH) { + uint32_t seq = it->second; + auto rtx_it = + rtx_state_.find(seq); // this should always return a valid iter + if (rtx_it->second.rtx_count_ >= RTC_MAX_RTX || + (now - rtx_it->second.first_send_) >= RTC_MAX_AGE || + seq < state_->getLastSeqNacked()) { + // max rtx reached or packet too old or packet nacked, this packet is lost + TRANSPORT_LOGD( + "packet %u lost because 1) max rtx: %u 2) max age: %u 3) naked: %u", + seq, (rtx_it->second.rtx_count_ >= RTC_MAX_RTX), + ((now - rtx_it->second.first_send_) >= RTC_MAX_AGE), + (seq < state_->getLastSeqNacked())); + lost_pkt.insert(seq); + it++; + } else { + // resend the packet + state_->onRetransmission(seq); + double prod_rate = state_->getProducerRate(); + if (prod_rate != 0) rtx_it->second.rtx_count_++; + rtx_it->second.next_send_ = computeNextSend(seq, false); + it = rtx_timers_.erase(it); + rtx_timers_.insert( + std::pair<uint64_t, uint32_t>(rtx_it->second.next_send_, seq)); + TRANSPORT_LOGD("send rtx for sequence %u, next send in %lu", seq, + (rtx_it->second.next_send_ - now)); + send_rtx_callback_(seq); + sent_counter++; + } + } + + // remove packets if needed + for (auto lost_it = lost_pkt.begin(); lost_it != lost_pkt.end(); lost_it++) { + uint32_t seq = *lost_it; + state_->onPacketLost(seq); + deleteRtx(seq); + } +} + +void RTCLossDetectionAndRecovery::scheduleNextRtx() { + if (rtx_timers_.size() == 0) { + // all the rtx were removed, reset timer + next_rtx_timer_ = MAX_TIMER_RTX; + return; + } + + // check if timer is alreay set + if (next_rtx_timer_ != MAX_TIMER_RTX) { + // a new check for rtx is already scheduled + if (next_rtx_timer_ > rtx_timers_.begin()->first) { + // we need to re-schedule it + timer_->cancel(); + } else { + // wait for the next timer + return; + } + } + + // set a new timer + next_rtx_timer_ = rtx_timers_.begin()->first; + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint64_t wait = 1; + if (next_rtx_timer_ != MAX_TIMER_RTX && next_rtx_timer_ > now) + wait = next_rtx_timer_ - now; + + std::weak_ptr<RTCLossDetectionAndRecovery> self(shared_from_this()); + timer_->expires_from_now(std::chrono::milliseconds(wait)); + timer_->async_wait([self](std::error_code ec) { + if (ec) return; + if (auto s = self.lock()) { + s->retransmit(); + s->next_rtx_timer_ = MAX_TIMER_RTX; + s->scheduleNextRtx(); + } + }); +} + +bool RTCLossDetectionAndRecovery::deleteRtx(uint32_t seq) { + auto it_rtx = rtx_state_.find(seq); + if (it_rtx == rtx_state_.end()) return false; // rtx not found + + uint64_t ts = it_rtx->second.next_send_; + auto it_timers = rtx_timers_.find(ts); + while (it_timers != rtx_timers_.end() && it_timers->first == ts) { + if (it_timers->second == seq) { + rtx_timers_.erase(it_timers); + break; + } + it_timers++; + } + + bool lost = it_rtx->second.rtx_count_ > 0; + rtx_state_.erase(it_rtx); + + return lost; +} + +void RTCLossDetectionAndRecovery::scheduleSentinelTimer( + uint64_t expires_from_now) { + std::weak_ptr<RTCLossDetectionAndRecovery> self(shared_from_this()); + sentinel_timer_->expires_from_now( + std::chrono::milliseconds(expires_from_now)); + sentinel_timer_->async_wait([self](std::error_code ec) { + if (ec) return; + if (auto s = self.lock()) { + s->sentinelTimer(); + } + }); +} + +void RTCLossDetectionAndRecovery::sentinelTimer() { + uint64_t now = getNow(); + + bool expired = false; + bool sent = false; + if ((now - last_event_) >= sentinel_timer_interval_) { + // at least a sentinel_timer_interval_ elapsed since last event + expired = true; + if (TRANSPORT_EXPECT_FALSE(!state_->isProducerActive())) { + // this happens at the beginning (or if the producer stops for some + // reason) we need to keep sending interest 0 until we get an answer + TRANSPORT_LOGD( + "sentinel timer: the producer is not active, send packet 0"); + state_->onRetransmission(0); + send_rtx_callback_(0); + } else { + TRANSPORT_LOGD( + "sentinel timer: the producer is active, send the 10 oldest packets"); + sent = true; + uint32_t rtx = 0; + auto it = state_->getPendingInterestsMapBegin(); + auto end = state_->getPendingInterestsMapEnd(); + while (it != end && rtx < MAX_RTX_WITH_SENTINEL) { + uint32_t seq = it->first; + TRANSPORT_LOGD("sentinel timer, add %u to the rtx list", seq); + addToRetransmissions(seq, seq + 1); + rtx++; + it++; + } + } + } else { + // sentinel timer did not expire because we registered at least one event + } + + uint32_t next_timer; + double prod_rate = state_->getProducerRate(); + if (TRANSPORT_EXPECT_FALSE(!state_->isProducerActive()) || prod_rate == 0) { + TRANSPORT_LOGD("next timer in %u", SENTINEL_TIMER_INTERVAL); + next_timer = SENTINEL_TIMER_INTERVAL; + } else { + double prod_rate = state_->getProducerRate(); + double packet_size = state_->getAveragePacketSize(); + uint32_t estimated_iat = ceil(1000.0 / (prod_rate / packet_size)); + uint32_t jitter = ceil(state_->getJitter()); + + // try to reduce the number of timers if the estimated IAT is too small + next_timer = std::max((estimated_iat + jitter) * 20, (uint32_t)1); + TRANSPORT_LOGD("next sentinel in %u ms, rate: %f, iat: %u, jitter: %u", + next_timer, ((prod_rate * 8.0) / 1000000.0), estimated_iat, + jitter); + + if (!expired) { + // discount the amout of time that is already passed + uint32_t discount = now - last_event_; + if (next_timer > discount) { + next_timer = next_timer - discount; + } else { + // in this case we trigger the timer in 1 ms + next_timer = 1; + } + TRANSPORT_LOGD("timer after discout: %u", next_timer); + } else if (sent) { + // wait at least one producer stats interval + owd to check if the + // production rate is reducing. + uint32_t min_wait = PRODUCER_STATS_INTERVAL + ceil(state_->getQueuing()); + next_timer = std::max(next_timer, min_wait); + TRANSPORT_LOGD("wait for updates from prod, next timer: %u", next_timer); + } + } + + scheduleSentinelTimer(next_timer); +} + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_ldr.h b/libtransport/src/protocols/rtc/rtc_ldr.h new file mode 100644 index 000000000..c0912303b --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_ldr.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <hicn/transport/config.h> +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/name.h> +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_state.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <functional> +#include <map> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class RTCLossDetectionAndRecovery + : public std::enable_shared_from_this<RTCLossDetectionAndRecovery> { + struct rtx_state_ { + uint64_t first_send_; + uint64_t next_send_; + uint32_t rtx_count_; + }; + + using rtxState = struct rtx_state_; + using SendRtxCallback = std::function<void(uint32_t)>; + + public: + RTCLossDetectionAndRecovery(SendRtxCallback &&callback, + asio::io_service &io_service); + + ~RTCLossDetectionAndRecovery(); + + void setState(std::shared_ptr<RTCState> state) { state_ = state; } + void turnOnRTX(); + void turnOffRTX(); + + void onTimeout(uint32_t seq); + void onDataPacketReceived(const core::ContentObject &content_object); + void onNackPacketReceived(const core::ContentObject &nack); + void onProbePacketReceived(const core::ContentObject &probe); + + void clear(); + + bool isRtx(uint32_t seq) { + if (rtx_state_.find(seq) != rtx_state_.end()) return true; + return false; + } + + private: + void addToRetransmissions(uint32_t start, uint32_t stop); + uint64_t computeNextSend(uint32_t seq, bool new_rtx); + void retransmit(); + void scheduleNextRtx(); + bool deleteRtx(uint32_t seq); + void scheduleSentinelTimer(uint64_t expires_from_now); + void sentinelTimer(); + + uint64_t getNow() { + using namespace std::chrono; + uint64_t now = + duration_cast<milliseconds>(steady_clock::now().time_since_epoch()) + .count(); + return now; + } + + // this map keeps track of the retransmitted interest, ordered from the oldest + // to the newest one. the state contains the timer of the first send of the + // interest (from pendingIntetests_), the timer of the next send (key of the + // multimap) and the number of rtx + std::map<uint32_t, rtxState> rtx_state_; + // this map stored the rtx by timer. The key is the time at which the rtx + // should be sent, and the val is the interest seq number + std::multimap<uint64_t, uint32_t> rtx_timers_; + + bool rtx_on_; + uint64_t next_rtx_timer_; + uint64_t last_event_; + uint64_t sentinel_timer_interval_; + std::unique_ptr<asio::steady_timer> timer_; + std::unique_ptr<asio::steady_timer> sentinel_timer_; + std::shared_ptr<RTCState> state_; + + SendRtxCallback send_rtx_callback_; +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_packet.h b/libtransport/src/protocols/rtc/rtc_packet.h new file mode 100644 index 000000000..abb1323a3 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_packet.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + */ + +/* data packet + * +-----------------------------------------+ + * | uint64_t: timestamp | + * | | + * +-----------------------------------------+ + * | uint32_t: prod rate (bytes per sec) | + * +-----------------------------------------+ + * | payload | + * | ... | + */ + +/* nack packet + * +-----------------------------------------+ + * | uint64_t: timestamp | + * | | + * +-----------------------------------------+ + * | uint32_t: prod rate (bytes per sec) | + * +-----------------------------------------+ + * | uint32_t: current seg in production | + * +-----------------------------------------+ + */ + +#pragma once +#include <arpa/inet.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +inline uint64_t _ntohll(const uint64_t *input) { + uint64_t return_val; + uint8_t *tmp = (uint8_t *)&return_val; + + tmp[0] = *input >> 56; + tmp[1] = *input >> 48; + tmp[2] = *input >> 40; + tmp[3] = *input >> 32; + tmp[4] = *input >> 24; + tmp[5] = *input >> 16; + tmp[6] = *input >> 8; + tmp[7] = *input >> 0; + + return return_val; +} + +inline uint64_t _htonll(const uint64_t *input) { return (_ntohll(input)); } + +const uint32_t DATA_HEADER_SIZE = 12; // bytes + // XXX: sizeof(data_packet_t) is 16 + // beacuse of padding +const uint32_t NACK_HEADER_SIZE = 16; + +struct data_packet_t { + uint64_t timestamp; + uint32_t prod_rate; + + inline uint64_t getTimestamp() const { return _ntohll(×tamp); } + inline void setTimestamp(uint64_t time) { timestamp = _htonll(&time); } + + inline uint32_t getProductionRate() const { return ntohl(prod_rate); } + inline void setProductionRate(uint32_t rate) { prod_rate = htonl(rate); } +}; + +struct nack_packet_t { + uint64_t timestamp; + uint32_t prod_rate; + uint32_t prod_seg; + + inline uint64_t getTimestamp() const { return _ntohll(×tamp); } + inline void setTimestamp(uint64_t time) { timestamp = _htonll(&time); } + + inline uint32_t getProductionRate() const { return ntohl(prod_rate); } + inline void setProductionRate(uint32_t rate) { prod_rate = htonl(rate); } + + inline uint32_t getProductionSegement() const { return ntohl(prod_seg); } + inline void setProductionSegement(uint32_t seg) { prod_seg = htonl(seg); } +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_rc.h b/libtransport/src/protocols/rtc/rtc_rc.h new file mode 100644 index 000000000..34d090092 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_rc.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <protocols/rtc/rtc_state.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class RTCRateControl : public std::enable_shared_from_this<RTCRateControl> { + public: + RTCRateControl() + : rc_on_(false), + congestion_win_(1000000), // init the win to a large number + congestion_state_(CongestionState::Normal), + protocol_state_(nullptr) {} + + virtual ~RTCRateControl() = default; + + void turnOnRateControl() { rc_on_ = true; } + void setState(std::shared_ptr<RTCState> state) { protocol_state_ = state; }; + uint32_t getCongesionWindow() { return congestion_win_; }; + + virtual void onNewRound(double round_len) = 0; + virtual void onDataPacketReceived( + const core::ContentObject &content_object) = 0; + + protected: + enum class CongestionState { Normal = 0, Underuse = 1, Congested = 2, Last }; + + protected: + bool rc_on_; + uint32_t congestion_win_; + CongestionState congestion_state_; + + std::shared_ptr<RTCState> protocol_state_; +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_rc_frame.cc b/libtransport/src/protocols/rtc/rtc_rc_frame.cc new file mode 100644 index 000000000..b577b5bea --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_rc_frame.cc @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_rc_frame.h> + +#include <algorithm> + +namespace transport { + +namespace protocol { + +namespace rtc { + +RTCRateControlFrame::RTCRateControlFrame() : cc_detector_() {} + +RTCRateControlFrame::~RTCRateControlFrame() {} + +void RTCRateControlFrame::onNewRound(double round_len) { + if (!rc_on_) return; + + CongestionState prev_congestion_state = congestion_state_; + cc_detector_.updateStats(); + congestion_state_ = (CongestionState)cc_detector_.getState(); + + if (congestion_state_ == CongestionState::Congested) { + if (prev_congestion_state == CongestionState::Normal) { + // congestion detected, notify app and init congestion win + double prod_rate = protocol_state_->getReceivedRate(); + double rtt = (double)protocol_state_->getRTT() / MILLI_IN_A_SEC; + double packet_size = protocol_state_->getAveragePacketSize(); + + if (prod_rate == 0.0 || rtt == 0.0 || packet_size == 0.0) { + // TODO do something + return; + } + + congestion_win_ = (uint32_t)ceil(prod_rate * rtt / packet_size); + } + uint32_t win = congestion_win_ * WIN_DECREASE_FACTOR; + congestion_win_ = std::max(win, WIN_MIN); + return; + } +} + +void RTCRateControlFrame::onDataPacketReceived( + const core::ContentObject &content_object) { + if (!rc_on_) return; + + uint32_t seq = content_object.getName().getSuffix(); + if (!protocol_state_->isPending(seq)) return; + + cc_detector_.addPacket(content_object); +} + +void RTCRateControlFrame::receivedBwProbeTrain(uint64_t firts_probe_ts, + uint64_t last_probe_ts, + uint32_t total_probes) { + // TODO + return; +} + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_rc_frame.h b/libtransport/src/protocols/rtc/rtc_rc_frame.h new file mode 100644 index 000000000..25d5ddbb6 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_rc_frame.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <protocols/rtc/congestion_detection.h> +#include <protocols/rtc/rtc_rc.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class RTCRateControlFrame : public RTCRateControl { + public: + RTCRateControlFrame(); + + ~RTCRateControlFrame(); + + void onNewRound(double round_len); + void onDataPacketReceived(const core::ContentObject &content_object); + + void receivedBwProbeTrain(uint64_t firts_probe_ts, uint64_t last_probe_ts, + uint32_t total_probes); + + private: + CongestionDetection cc_detector_; +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_rc_queue.cc b/libtransport/src/protocols/rtc/rtc_rc_queue.cc new file mode 100644 index 000000000..a1c89e329 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_rc_queue.cc @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_rc_queue.h> + +#include <algorithm> + +namespace transport { + +namespace protocol { + +namespace rtc { + +RTCRateControlQueue::RTCRateControlQueue() + : rounds_since_last_drop_(0), + rounds_without_congestion_(0), + last_queue_(0) {} + +RTCRateControlQueue::~RTCRateControlQueue() {} + +void RTCRateControlQueue::onNewRound(double round_len) { + if (!rc_on_) return; + + double received_rate = protocol_state_->getReceivedRate(); + double target_rate = + protocol_state_->getProducerRate() * PRODUCTION_RATE_FRACTION; + double rtt = (double)protocol_state_->getRTT() / MILLI_IN_A_SEC; + double packet_size = protocol_state_->getAveragePacketSize(); + double queue = protocol_state_->getQueuing(); + + if (rtt == 0.0) return; // no info from the producer + + CongestionState prev_congestion_state = congestion_state_; + + if (prev_congestion_state == CongestionState::Normal && + received_rate >= target_rate) { + // if the queue is high in this case we are most likelly fighting with + // a TCP flow and there is enough bandwidth to match the producer rate + congestion_state_ = CongestionState::Normal; + } else if (queue > MAX_QUEUING_DELAY || last_queue_ == queue) { + // here we detect congestion. in the case that last_queue == queue + // the consumer didn't receive any packet from the producer so we + // consider this case as congestion + // TODO: wath happen in case of high loss rate? + congestion_state_ = CongestionState::Congested; + } else { + // nothing bad is happening + congestion_state_ = CongestionState::Normal; + } + + last_queue_ = queue; + + if (congestion_state_ == CongestionState::Congested) { + if (prev_congestion_state == CongestionState::Normal) { + // init the congetion window using the received rate + congestion_win_ = (uint32_t)ceil(received_rate * rtt / packet_size); + rounds_since_last_drop_ = ROUNDS_BEFORE_TAKE_ACTION + 1; + } + + if (rounds_since_last_drop_ >= ROUNDS_BEFORE_TAKE_ACTION) { + uint32_t win = congestion_win_ * WIN_DECREASE_FACTOR; + congestion_win_ = std::max(win, WIN_MIN); + rounds_since_last_drop_ = 0; + return; + } + + rounds_since_last_drop_++; + } + + if (congestion_state_ == CongestionState::Normal) { + if (prev_congestion_state == CongestionState::Congested) { + rounds_without_congestion_ = 0; + } + + rounds_without_congestion_++; + if (rounds_without_congestion_ < ROUNDS_BEFORE_TAKE_ACTION) return; + + congestion_win_ = congestion_win_ * WIN_INCREASE_FACTOR; + congestion_win_ = std::min(congestion_win_, INITIAL_WIN_MAX); + } +} + +void RTCRateControlQueue::onDataPacketReceived( + const core::ContentObject &content_object) { + // nothing to do + return; +} + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_rc_queue.h b/libtransport/src/protocols/rtc/rtc_rc_queue.h new file mode 100644 index 000000000..407354d43 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_rc_queue.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <hicn/transport/utils/shared_ptr_utils.h> +#include <protocols/rtc/rtc_rc.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class RTCRateControlQueue : public RTCRateControl { + public: + RTCRateControlQueue(); + + ~RTCRateControlQueue(); + + void onNewRound(double round_len); + void onDataPacketReceived(const core::ContentObject &content_object); + + auto shared_from_this() { return utils::shared_from(this); } + + private: + uint32_t rounds_since_last_drop_; + uint32_t rounds_without_congestion_; + double last_queue_; +}; + +} // end namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_state.cc b/libtransport/src/protocols/rtc/rtc_state.cc new file mode 100644 index 000000000..eabf8942c --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_state.cc @@ -0,0 +1,560 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <protocols/rtc/rtc_consts.h> +#include <protocols/rtc/rtc_state.h> + +namespace transport { + +namespace protocol { + +namespace rtc { + +RTCState::RTCState(ProbeHandler::SendProbeCallback &&rtt_probes_callback, + DiscoveredRttCallback &&discovered_rtt_callback, + asio::io_service &io_service) + : rtt_probes_(std::make_shared<ProbeHandler>( + std::move(rtt_probes_callback), io_service)), + discovered_rtt_callback_(std::move(discovered_rtt_callback)) { + init_rtt_timer_ = std::make_unique<asio::steady_timer>(io_service); + initParams(); +} + +RTCState::~RTCState() {} + +void RTCState::initParams() { + // packets counters (total) + sent_interests_ = 0; + sent_rtx_ = 0; + received_data_ = 0; + received_nacks_ = 0; + received_timeouts_ = 0; + received_probes_ = 0; + + // loss counters + packets_lost_ = 0; + losses_recovered_ = 0; + first_seq_in_round_ = 0; + highest_seq_received_ = 0; + highest_seq_received_in_order_ = 0; + last_seq_nacked_ = 0; + loss_rate_ = 0.0; + residual_loss_rate_ = 0.0; + + // bw counters + received_bytes_ = 0; + avg_packet_size_ = INIT_PACKET_SIZE; + production_rate_ = 0.0; + received_rate_ = 0.0; + + // nack counter + nack_on_last_round_ = false; + received_nacks_last_round_ = 0; + + // packets counter + received_packets_last_round_ = 0; + received_data_last_round_ = 0; + received_data_from_cache_ = 0; + data_from_cache_rate_ = 0; + sent_interests_last_round_ = 0; + sent_rtx_last_round_ = 0; + + // round conunters + rounds_ = 0; + rounds_without_nacks_ = 0; + rounds_without_packets_ = 0; + + last_production_seq_ = 0; + producer_is_active_ = false; + last_prod_update_ = 0; + + // paths stats + path_table_.clear(); + main_path_ = nullptr; + + // packet received + received_or_lost_packets_.clear(); + + // pending interests + pending_interests_.clear(); + + // init rtt + first_interest_sent_ = ~0; + init_rtt_ = false; + rtt_probes_->setProbes(INIT_RTT_PROBE_INTERVAL, INIT_RTT_PROBES); + rtt_probes_->sendProbes(); + setInitRttTimer(INIT_RTT_PROBE_RESTART); +} + +// packet events +void RTCState::onSendNewInterest(const core::Name *interest_name) { + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + uint32_t seq = interest_name->getSuffix(); + pending_interests_.insert(std::pair<uint32_t, uint64_t>(seq, now)); + + if(sent_interests_ == 0) first_interest_sent_ = now; + + sent_interests_++; + sent_interests_last_round_++; +} + +void RTCState::onTimeout(uint32_t seq) { + auto it = pending_interests_.find(seq); + if (it != pending_interests_.end()) { + pending_interests_.erase(it); + } + received_timeouts_++; +} + +void RTCState::onRetransmission(uint32_t seq) { + // remove the interest for the pendingInterest map only after the first rtx. + // in this way we can handle the ooo packets that come in late as normla + // packet. we consider a packet lost only if we sent at least an RTX for it. + // XXX this may become problematic if we stop the RTX transmissions + auto it = pending_interests_.find(seq); + if (it != pending_interests_.end()) { + pending_interests_.erase(it); + packets_lost_++; + } + sent_rtx_++; + sent_rtx_last_round_++; +} + +void RTCState::onDataPacketReceived(const core::ContentObject &content_object, + bool compute_stats) { + uint32_t seq = content_object.getName().getSuffix(); + if (compute_stats) { + updatePathStats(content_object, false); + received_data_last_round_++; + } + received_data_++; + + struct data_packet_t *data_pkt = + (struct data_packet_t *)content_object.getPayload()->data(); + uint64_t production_time = data_pkt->getTimestamp(); + if (last_prod_update_ < production_time) { + last_prod_update_ = production_time; + uint32_t production_rate = data_pkt->getProductionRate(); + production_rate_ = (double)production_rate; + } + + updatePacketSize(content_object); + updateReceivedBytes(content_object); + addRecvOrLost(seq, PacketState::RECEIVED); + + if (seq > highest_seq_received_) highest_seq_received_ = seq; + + // the producer is responding + // it is generating valid data packets so we consider it active + producer_is_active_ = true; + + received_packets_last_round_++; +} + +void RTCState::onNackPacketReceived(const core::ContentObject &nack, + bool compute_stats) { + uint32_t seq = nack.getName().getSuffix(); + struct nack_packet_t *nack_pkt = + (struct nack_packet_t *)nack.getPayload()->data(); + uint64_t production_time = nack_pkt->getTimestamp(); + uint32_t production_seq = nack_pkt->getProductionSegement(); + uint32_t production_rate = nack_pkt->getProductionRate(); + + if (TRANSPORT_EXPECT_FALSE(main_path_ == nullptr) || + last_prod_update_ < production_time) { + // update production rate + last_prod_update_ = production_time; + last_production_seq_ = production_seq; + production_rate_ = (double)production_rate; + } + + if (compute_stats) { + // this is not an RTX + updatePathStats(nack, true); + nack_on_last_round_ = true; + } + + // for statistics pourpose we log all nacks, also the one received for + // retransmitted packets + received_nacks_++; + received_nacks_last_round_++; + + if (production_seq > seq) { + // old nack, seq is lost + // update last nacked + if (last_seq_nacked_ < seq) last_seq_nacked_ = seq; + TRANSPORT_LOGD("lost packet %u beacuse of a past nack", seq); + onPacketLost(seq); + } else if (seq > production_seq) { + // future nack + // remove the nack from the pending interest map + // (the packet is not received/lost yet) + pending_interests_.erase(seq); + } else { + // this should be a quite rear event. simply remove the + // packet from the pending interest list + pending_interests_.erase(seq); + } + + // the producer is responding + // we consider it active only if the production rate is not 0 + // or the production sequence number is not 1 + if (production_rate_ != 0 || production_seq != 1) { + producer_is_active_ = true; + } + + received_packets_last_round_++; +} + +void RTCState::onPacketLost(uint32_t seq) { + TRANSPORT_LOGD("packet %u is lost", seq); + auto it = pending_interests_.find(seq); + if (it != pending_interests_.end()) { + // this packet was never retransmitted so it does + // not appear in the loss count + packets_lost_++; + } + addRecvOrLost(seq, PacketState::LOST); +} + +void RTCState::onPacketRecovered(uint32_t seq) { + losses_recovered_++; + addRecvOrLost(seq, PacketState::RECEIVED); +} + +bool RTCState::onProbePacketReceived(const core::ContentObject &probe) { + uint32_t seq = probe.getName().getSuffix(); + uint64_t rtt; + + rtt = rtt_probes_->getRtt(seq); + + if (rtt == 0) return false; // this is not a valid probe + + // like for data and nacks update the path stats. Here the RTT is computed + // by the probe handler. Both probes for rtt and bw are good to esimate + // info on the path + uint32_t path_label = probe.getPathLabel(); + + auto path_it = path_table_.find(path_label); + + // update production rate and last_seq_nacked like in case of a nack + struct nack_packet_t *probe_pkt = + (struct nack_packet_t *)probe.getPayload()->data(); + uint64_t sender_timestamp = probe_pkt->getTimestamp(); + uint32_t production_seq = probe_pkt->getProductionSegement(); + uint32_t production_rate = probe_pkt->getProductionRate(); + + + if (path_it == path_table_.end()) { + // found a new path + std::shared_ptr<RTCDataPath> newPath = + std::make_shared<RTCDataPath>(path_label); + auto ret = path_table_.insert( + std::pair<uint32_t, std::shared_ptr<RTCDataPath>>(path_label, newPath)); + path_it = ret.first; + } + + auto path = path_it->second; + + path->insertRttSample(rtt); + path->receivedNack(); + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + int64_t OWD = now - sender_timestamp; + path->insertOwdSample(OWD); + + if (last_prod_update_ < sender_timestamp) { + last_production_seq_ = production_seq; + last_prod_update_ = sender_timestamp; + production_rate_ = (double)production_rate; + } + + // the producer is responding + // we consider it active only if the production rate is not 0 + // or the production sequence numner is not 1 + if (production_rate_ != 0 || production_seq != 1) { + producer_is_active_ = true; + } + + // check for init RTT. if received_probes_ is equal to 0 schedule a timer to + // wait for the INIT_RTT_PROBES. in this way if some probes get lost we don't + // wait forever + received_probes_++; + + if(!init_rtt_ && received_probes_ <= INIT_RTT_PROBES){ + if(received_probes_ == 1){ + // we got the first probe, wait at most INIT_RTT_PROBE_WAIT sec for the others + main_path_ = path; + setInitRttTimer(INIT_RTT_PROBE_WAIT); + } + if(received_probes_ == INIT_RTT_PROBES) { + // we are done + init_rtt_timer_->cancel(); + checkInitRttTimer(); + } + } + + received_packets_last_round_++; + + // ignore probes sent before the first interest + if((now - rtt) <= first_interest_sent_) return false; + return true; +} + +void RTCState::onNewRound(double round_len, bool in_sync) { + // XXX + // here we take into account only the single path case so we assume that we + // don't use two paths in parellel for this single flow + + if (path_table_.empty()) return; + + double bytes_per_sec = + ((double)received_bytes_ * (MILLI_IN_A_SEC / round_len)); + if(received_rate_ == 0) + received_rate_ = bytes_per_sec; + else + received_rate_ = (received_rate_ * MOVING_AVG_ALPHA) + + ((1 - MOVING_AVG_ALPHA) * bytes_per_sec); + + // search for an active path. There should be only one active path (meaning a + // path that leads to the producer socket -no cache- and from which we are + // currently getting data packets) at any time. However it may happen that + // there are mulitple active paths in case of mobility (the old path will + // remain active for a short ammount of time). The main path is selected as + // the active path from where the consumer received the latest data packet + + uint64_t last_packet_ts = 0; + main_path_ = nullptr; + + for (auto it = path_table_.begin(); it != path_table_.end(); it++) { + it->second->roundEnd(); + if (it->second->isActive()) { + uint64_t ts = it->second->getLastPacketTS(); + if (ts > last_packet_ts) { + last_packet_ts = ts; + main_path_ = it->second; + } + } + } + + if (in_sync) updateLossRate(); + + // handle nacks + if (!nack_on_last_round_ && received_bytes_ > 0) { + rounds_without_nacks_++; + } else { + rounds_without_nacks_ = 0; + } + + // check if the producer is active + if (received_packets_last_round_ != 0) { + rounds_without_packets_ = 0; + } else { + rounds_without_packets_++; + if (rounds_without_packets_ >= MAX_ROUND_WHIOUT_PACKETS && + producer_is_active_ != false) { + initParams(); + } + } + + // compute cache/producer ratio + if (received_data_last_round_ != 0) { + double new_rate = + (double)received_data_from_cache_ / (double)received_data_last_round_; + data_from_cache_rate_ = data_from_cache_rate_ * MOVING_AVG_ALPHA + + (new_rate * (1 - MOVING_AVG_ALPHA)); + } + + // reset counters + received_bytes_ = 0; + packets_lost_ = 0; + losses_recovered_ = 0; + first_seq_in_round_ = highest_seq_received_; + + nack_on_last_round_ = false; + received_nacks_last_round_ = 0; + + received_packets_last_round_ = 0; + received_data_last_round_ = 0; + received_data_from_cache_ = 0; + sent_interests_last_round_ = 0; + sent_rtx_last_round_ = 0; + + rounds_++; +} + +void RTCState::updateReceivedBytes(const core::ContentObject &content_object) { + received_bytes_ += + (uint32_t)(content_object.headerSize() + content_object.payloadSize()); +} + +void RTCState::updatePacketSize(const core::ContentObject &content_object) { + uint32_t pkt_size = + (uint32_t)(content_object.headerSize() + content_object.payloadSize()); + avg_packet_size_ = (MOVING_AVG_ALPHA * avg_packet_size_) + + ((1 - MOVING_AVG_ALPHA) * pkt_size); +} + +void RTCState::updatePathStats(const core::ContentObject &content_object, + bool is_nack) { + // get packet path + uint32_t path_label = content_object.getPathLabel(); + auto path_it = path_table_.find(path_label); + + if (path_it == path_table_.end()) { + // found a new path + std::shared_ptr<RTCDataPath> newPath = + std::make_shared<RTCDataPath>(path_label); + auto ret = path_table_.insert( + std::pair<uint32_t, std::shared_ptr<RTCDataPath>>(path_label, newPath)); + path_it = ret.first; + } + + auto path = path_it->second; + + // compute rtt + uint32_t seq = content_object.getName().getSuffix(); + uint64_t interest_sent_time = getInterestSentTime(seq); + if (interest_sent_time == 0) + return; // this should not happen, + // it means that we are processing an interest + // that is not pending + + uint64_t now = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + + uint64_t RTT = now - interest_sent_time; + + path->insertRttSample(RTT); + + // compute OWD (the first part of the nack and data packet header are the + // same, so we cast to data data packet) + struct data_packet_t *packet = + (struct data_packet_t *)content_object.getPayload()->data(); + uint64_t sender_timestamp = packet->getTimestamp(); + int64_t OWD = now - sender_timestamp; + path->insertOwdSample(OWD); + + // compute IAT or set path to producer + if (!is_nack) { + // compute the iat only for the content packets + uint32_t segment_number = content_object.getName().getSuffix(); + path->computeInterArrivalGap(segment_number); + if (!path->pathToProducer()) received_data_from_cache_++; + } else { + path->receivedNack(); + } +} + +void RTCState::updateLossRate() { + loss_rate_ = 0.0; + residual_loss_rate_ = 0.0; + + uint32_t number_theorically_received_packets_ = + highest_seq_received_ - first_seq_in_round_; + + // in this case no new packet was recevied after the previuos round, avoid + // division by 0 + if (number_theorically_received_packets_ == 0) return; + + loss_rate_ = (double)((double)(packets_lost_) / + (double)number_theorically_received_packets_); + + residual_loss_rate_ = (double)((double)(packets_lost_ - losses_recovered_) / + (double)number_theorically_received_packets_); + + if (residual_loss_rate_ < 0) residual_loss_rate_ = 0; +} + +void RTCState::addRecvOrLost(uint32_t seq, PacketState state) { + pending_interests_.erase(seq); + if (received_or_lost_packets_.size() >= MAX_CACHED_PACKETS) { + received_or_lost_packets_.erase(received_or_lost_packets_.begin()); + } + // notice that it may happen that a packet that we consider lost arrives after + // some time, in this case we simply overwrite the packet state. + received_or_lost_packets_[seq] = state; + + // keep track of the last packet received/lost + // without holes. + if (highest_seq_received_in_order_ < last_seq_nacked_) { + highest_seq_received_in_order_ = last_seq_nacked_; + } + + if ((highest_seq_received_in_order_ + 1) == seq) { + highest_seq_received_in_order_ = seq; + } else if (seq <= highest_seq_received_in_order_) { + // here we do nothing + } else if (seq > highest_seq_received_in_order_) { + // 1) there is a gap in the sequence so we do not update largest_in_seq_ + // 2) all the packets from largest_in_seq_ to seq are in + // received_or_lost_packets_ an we upate largest_in_seq_ + + for (uint32_t i = highest_seq_received_in_order_ + 1; i <= seq; i++) { + if (received_or_lost_packets_.find(i) == + received_or_lost_packets_.end()) { + break; + } + // this packet is in order so we can update the + // highest_seq_received_in_order_ + highest_seq_received_in_order_ = i; + } + } +} + +void RTCState::setInitRttTimer(uint32_t wait){ + init_rtt_timer_->cancel(); + init_rtt_timer_->expires_from_now(std::chrono::milliseconds(wait)); + init_rtt_timer_->async_wait([this](std::error_code ec) { + if(ec) return; + checkInitRttTimer(); + }); +} + +void RTCState::checkInitRttTimer() { + if(received_probes_ < INIT_RTT_MIN_PROBES_TO_RECV){ + // we didn't received enough probes, restart + received_probes_ = 0; + rtt_probes_->setProbes(INIT_RTT_PROBE_INTERVAL, INIT_RTT_PROBES); + rtt_probes_->sendProbes(); + setInitRttTimer(INIT_RTT_PROBE_RESTART); + return; + } + init_rtt_ = true; + main_path_->roundEnd(); + rtt_probes_->setProbes(RTT_PROBE_INTERVAL, 0); + rtt_probes_->sendProbes(); + + // init last_seq_nacked_. skip packets that may come from the cache + double prod_rate = getProducerRate(); + double rtt = (double)getRTT() / MILLI_IN_A_SEC; + double packet_size = getAveragePacketSize(); + uint32_t pkt_in_rtt_ = std::floor(((prod_rate / packet_size) * rtt) * 0.8); + last_seq_nacked_ = last_production_seq_ + pkt_in_rtt_; + + discovered_rtt_callback_(); +} + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/rtc_state.h b/libtransport/src/protocols/rtc/rtc_state.h new file mode 100644 index 000000000..943a0a113 --- /dev/null +++ b/libtransport/src/protocols/rtc/rtc_state.h @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2017-2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include <hicn/transport/config.h> +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/name.h> +#include <protocols/rtc/probe_handler.h> +#include <protocols/rtc/rtc_data_path.h> + +#include <asio.hpp> +#include <asio/steady_timer.hpp> +#include <map> +#include <set> + +namespace transport { + +namespace protocol { + +namespace rtc { + +enum class PacketState : uint8_t { RECEIVED, LOST, UNKNOWN }; + +class RTCState : std::enable_shared_from_this<RTCState> { + public: + using DiscoveredRttCallback = std::function<void()>; + public: + RTCState(ProbeHandler::SendProbeCallback &&rtt_probes_callback, + DiscoveredRttCallback &&discovered_rtt_callback, + asio::io_service &io_service); + + ~RTCState(); + + // packet events + void onSendNewInterest(const core::Name *interest_name); + void onTimeout(uint32_t seq); + void onRetransmission(uint32_t seq); + void onDataPacketReceived(const core::ContentObject &content_object, + bool compute_stats); + void onNackPacketReceived(const core::ContentObject &nack, + bool compute_stats); + void onPacketLost(uint32_t seq); + void onPacketRecovered(uint32_t seq); + bool onProbePacketReceived(const core::ContentObject &probe); + + // protocol state + void onNewRound(double round_len, bool in_sync); + + // main path + uint32_t getProducerPath() const { + if (mainPathIsValid()) return main_path_->getPathId(); + return 0; + } + + // delay metrics + bool isRttDiscovered() const { + return init_rtt_; + } + + uint64_t getRTT() const { + if (mainPathIsValid()) return main_path_->getMinRtt(); + return 0; + } + void resetRttStats() { + if (mainPathIsValid()) main_path_->clearRtt(); + } + + double getQueuing() const { + if (mainPathIsValid()) return main_path_->getQueuingDealy(); + return 0.0; + } + double getIAT() const { + if (mainPathIsValid()) return main_path_->getInterArrivalGap(); + return 0.0; + } + + double getJitter() const { + if (mainPathIsValid()) return main_path_->getJitter(); + return 0.0; + } + + // pending interests + uint64_t getInterestSentTime(uint32_t seq) { + auto it = pending_interests_.find(seq); + if (it != pending_interests_.end()) return it->second; + return 0; + } + bool isPending(uint32_t seq) { + if (pending_interests_.find(seq) != pending_interests_.end()) return true; + return false; + } + uint32_t getPendingInterestNumber() const { + return pending_interests_.size(); + } + PacketState isReceivedOrLost(uint32_t seq) { + auto it = received_or_lost_packets_.find(seq); + if (it != received_or_lost_packets_.end()) return it->second; + return PacketState::UNKNOWN; + } + + // loss rate + double getLossRate() const { return loss_rate_; } + double getResidualLossRate() const { return residual_loss_rate_; } + uint32_t getHighestSeqReceivedInOrder() const { + return highest_seq_received_in_order_; + } + uint32_t getLostData() const { return packets_lost_; }; + uint32_t getRecoveredLosses() const { return losses_recovered_; } + + // generic stats + uint32_t getReceivedBytesInRound() const { return received_bytes_; } + uint32_t getReceivedNacksInRound() const { + return received_nacks_last_round_; + } + uint32_t getSentInterestInRound() const { return sent_interests_last_round_; } + uint32_t getSentRtxInRound() const { return sent_rtx_last_round_; } + + // bandwidth/production metrics + double getAvailableBw() const { return 0.0; }; // TODO + double getProducerRate() const { return production_rate_; } + double getReceivedRate() const { return received_rate_; } + double getAveragePacketSize() const { return avg_packet_size_; } + + // nacks + uint32_t getRoundsWithoutNacks() const { return rounds_without_nacks_; } + uint32_t getLastSeqNacked() const { return last_seq_nacked_; } + + // producer state + bool isProducerActive() const { return producer_is_active_; } + + // packets from cache + double getPacketFromCacheRatio() const { return data_from_cache_rate_; } + + std::map<uint32_t, uint64_t>::iterator getPendingInterestsMapBegin() { + return pending_interests_.begin(); + } + std::map<uint32_t, uint64_t>::iterator getPendingInterestsMapEnd() { + return pending_interests_.end(); + } + + private: + void initParams(); + + // update stats + void updateState(); + void updateReceivedBytes(const core::ContentObject &content_object); + void updatePacketSize(const core::ContentObject &content_object); + void updatePathStats(const core::ContentObject &content_object, bool is_nack); + void updateLossRate(); + + void addRecvOrLost(uint32_t seq, PacketState state); + + void setInitRttTimer(uint32_t wait); + void checkInitRttTimer(); + + bool mainPathIsValid() const { + if (main_path_ != nullptr) + return true; + else + return false; + } + + // packets counters (total) + uint32_t sent_interests_; + uint32_t sent_rtx_; + uint32_t received_data_; + uint32_t received_nacks_; + uint32_t received_timeouts_; + uint32_t received_probes_; + + // loss counters + int32_t packets_lost_; + int32_t losses_recovered_; + uint32_t first_seq_in_round_; + uint32_t highest_seq_received_; + uint32_t highest_seq_received_in_order_; + uint32_t last_seq_nacked_; // segment for which we got an oldNack + double loss_rate_; + double residual_loss_rate_; + + // bw counters + uint32_t received_bytes_; + double avg_packet_size_; + double production_rate_; // rate communicated by the producer using nacks + double received_rate_; // rate recevied by the consumer + + // nack counter + // the bool takes tracks only about the valid nacks (no rtx) and it is used to + // switch between the states. Instead received_nacks_last_round_ logs all the + // nacks for statistics + bool nack_on_last_round_; + uint32_t received_nacks_last_round_; + + // packets counter + uint32_t received_packets_last_round_; + uint32_t received_data_last_round_; + uint32_t received_data_from_cache_; + double data_from_cache_rate_; + uint32_t sent_interests_last_round_; + uint32_t sent_rtx_last_round_; + + // round conunters + uint32_t rounds_; + uint32_t rounds_without_nacks_; + uint32_t rounds_without_packets_; + + // init rtt + uint64_t first_interest_sent_; + + // producer state + bool + producer_is_active_; // the prodcuer is active if we receive some packets + uint32_t last_production_seq_; // last production seq received by the producer + uint64_t last_prod_update_; // timestamp of the last packets used to update + // stats from the producer + + // paths stats + std::unordered_map<uint32_t, std::shared_ptr<RTCDataPath>> path_table_; + std::shared_ptr<RTCDataPath> main_path_; + + // packet received + // cache where to store info about the last MAX_CACHED_PACKETS + std::map<uint32_t, PacketState> received_or_lost_packets_; + + // pending interests + std::map<uint32_t, uint64_t> pending_interests_; + + // probes + std::shared_ptr<ProbeHandler> rtt_probes_; + bool init_rtt_; + std::unique_ptr<asio::steady_timer> init_rtt_timer_; + + // callbacks + DiscoveredRttCallback discovered_rtt_callback_; +}; + +} // namespace rtc + +} // namespace protocol + +} // namespace transport diff --git a/libtransport/src/protocols/rtc/trendline_estimator.cc b/libtransport/src/protocols/rtc/trendline_estimator.cc new file mode 100644 index 000000000..7a0803857 --- /dev/null +++ b/libtransport/src/protocols/rtc/trendline_estimator.cc @@ -0,0 +1,334 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// FROM +// https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/congestion_controller/goog_cc/trendline_estimator.cc + +#include "trendline_estimator.h" + +#include <math.h> + +#include <algorithm> +#include <string> + +namespace transport { + +namespace protocol { + +namespace rtc { + +// Parameters for linear least squares fit of regression line to noisy data. +constexpr double kDefaultTrendlineSmoothingCoeff = 0.9; +constexpr double kDefaultTrendlineThresholdGain = 4.0; +// const char kBweWindowSizeInPacketsExperiment[] = +// "WebRTC-BweWindowSizeInPackets"; + +/*size_t ReadTrendlineFilterWindowSize( + const WebRtcKeyValueConfig* key_value_config) { + std::string experiment_string = + key_value_config->Lookup(kBweWindowSizeInPacketsExperiment); + size_t window_size; + int parsed_values = + sscanf(experiment_string.c_str(), "Enabled-%zu", &window_size); + if (parsed_values == 1) { + if (window_size > 1) + return window_size; + RTC_LOG(WARNING) << "Window size must be greater than 1."; + } + RTC_LOG(LS_WARNING) << "Failed to parse parameters for BweWindowSizeInPackets" + " experiment from field trial string. Using default."; + return TrendlineEstimatorSettings::kDefaultTrendlineWindowSize; +} +*/ + +OptionalDouble LinearFitSlope( + const std::deque<TrendlineEstimator::PacketTiming>& packets) { + // RTC_DCHECK(packets.size() >= 2); + // Compute the "center of mass". + double sum_x = 0; + double sum_y = 0; + for (const auto& packet : packets) { + sum_x += packet.arrival_time_ms; + sum_y += packet.smoothed_delay_ms; + } + double x_avg = sum_x / packets.size(); + double y_avg = sum_y / packets.size(); + // Compute the slope k = \sum (x_i-x_avg)(y_i-y_avg) / \sum (x_i-x_avg)^2 + double numerator = 0; + double denominator = 0; + for (const auto& packet : packets) { + double x = packet.arrival_time_ms; + double y = packet.smoothed_delay_ms; + numerator += (x - x_avg) * (y - y_avg); + denominator += (x - x_avg) * (x - x_avg); + } + if (denominator == 0) return OptionalDouble(); + return OptionalDouble(numerator / denominator); +} + +OptionalDouble ComputeSlopeCap( + const std::deque<TrendlineEstimator::PacketTiming>& packets, + const TrendlineEstimatorSettings& settings) { + /*RTC_DCHECK(1 <= settings.beginning_packets && + settings.beginning_packets < packets.size()); + RTC_DCHECK(1 <= settings.end_packets && + settings.end_packets < packets.size()); + RTC_DCHECK(settings.beginning_packets + settings.end_packets <= + packets.size());*/ + TrendlineEstimator::PacketTiming early = packets[0]; + for (size_t i = 1; i < settings.beginning_packets; ++i) { + if (packets[i].raw_delay_ms < early.raw_delay_ms) early = packets[i]; + } + size_t late_start = packets.size() - settings.end_packets; + TrendlineEstimator::PacketTiming late = packets[late_start]; + for (size_t i = late_start + 1; i < packets.size(); ++i) { + if (packets[i].raw_delay_ms < late.raw_delay_ms) late = packets[i]; + } + if (late.arrival_time_ms - early.arrival_time_ms < 1) { + return OptionalDouble(); + } + return OptionalDouble((late.raw_delay_ms - early.raw_delay_ms) / + (late.arrival_time_ms - early.arrival_time_ms) + + settings.cap_uncertainty); +} + +constexpr double kMaxAdaptOffsetMs = 15.0; +constexpr double kOverUsingTimeThreshold = 10; +constexpr int kMinNumDeltas = 60; +constexpr int kDeltaCounterMax = 1000; + +//} // namespace + +constexpr char TrendlineEstimatorSettings::kKey[]; + +TrendlineEstimatorSettings::TrendlineEstimatorSettings( + /*const WebRtcKeyValueConfig* key_value_config*/) { + /*if (absl::StartsWith( + key_value_config->Lookup(kBweWindowSizeInPacketsExperiment), + "Enabled")) { + window_size = ReadTrendlineFilterWindowSize(key_value_config); + } + Parser()->Parse(key_value_config->Lookup(TrendlineEstimatorSettings::kKey));*/ + window_size = kDefaultTrendlineWindowSize; + enable_cap = false; + beginning_packets = end_packets = 0; + cap_uncertainty = 0.0; + + /*if (window_size < 10 || 200 < window_size) { + RTC_LOG(LS_WARNING) << "Window size must be between 10 and 200 packets"; + window_size = kDefaultTrendlineWindowSize; + } + if (enable_cap) { + if (beginning_packets < 1 || end_packets < 1 || + beginning_packets > window_size || end_packets > window_size) { + RTC_LOG(LS_WARNING) << "Size of beginning and end must be between 1 and " + << window_size; + enable_cap = false; + beginning_packets = end_packets = 0; + cap_uncertainty = 0.0; + } + if (beginning_packets + end_packets > window_size) { + RTC_LOG(LS_WARNING) + << "Size of beginning plus end can't exceed the window size"; + enable_cap = false; + beginning_packets = end_packets = 0; + cap_uncertainty = 0.0; + } + if (cap_uncertainty < 0.0 || 0.025 < cap_uncertainty) { + RTC_LOG(LS_WARNING) << "Cap uncertainty must be between 0 and 0.025"; + cap_uncertainty = 0.0; + } + }*/ +} + +/*std::unique_ptr<StructParametersParser> TrendlineEstimatorSettings::Parser() { + return StructParametersParser::Create("sort", &enable_sort, // + "cap", &enable_cap, // + "beginning_packets", + &beginning_packets, // + "end_packets", &end_packets, // + "cap_uncertainty", &cap_uncertainty, // + "window_size", &window_size); +}*/ + +TrendlineEstimator::TrendlineEstimator( + /*const WebRtcKeyValueConfig* key_value_config, + NetworkStatePredictor* network_state_predictor*/) + : settings_(), + smoothing_coef_(kDefaultTrendlineSmoothingCoeff), + threshold_gain_(kDefaultTrendlineThresholdGain), + num_of_deltas_(0), + first_arrival_time_ms_(-1), + accumulated_delay_(0), + smoothed_delay_(0), + delay_hist_(), + k_up_(0.0087), + k_down_(0.039), + overusing_time_threshold_(kOverUsingTimeThreshold), + threshold_(12.5), + prev_modified_trend_(NAN), + last_update_ms_(-1), + prev_trend_(0.0), + time_over_using_(-1), + overuse_counter_(0), + hypothesis_(BandwidthUsage::kBwNormal){ + // hypothesis_predicted_(BandwidthUsage::kBwNormal){//}, + // network_state_predictor_(network_state_predictor) { + /* RTC_LOG(LS_INFO) + << "Using Trendline filter for delay change estimation with settings " + << settings_.Parser()->Encode() << " and " + // << (network_state_predictor_ ? "injected" : "no") + << " network state predictor";*/ +} + +TrendlineEstimator::~TrendlineEstimator() {} + +void TrendlineEstimator::UpdateTrendline(double recv_delta_ms, + double send_delta_ms, + int64_t send_time_ms, + int64_t arrival_time_ms, + size_t packet_size) { + const double delta_ms = recv_delta_ms - send_delta_ms; + ++num_of_deltas_; + num_of_deltas_ = std::min(num_of_deltas_, kDeltaCounterMax); + if (first_arrival_time_ms_ == -1) first_arrival_time_ms_ = arrival_time_ms; + + // Exponential backoff filter. + accumulated_delay_ += delta_ms; + // BWE_TEST_LOGGING_PLOT(1, "accumulated_delay_ms", arrival_time_ms, + // accumulated_delay_); + smoothed_delay_ = smoothing_coef_ * smoothed_delay_ + + (1 - smoothing_coef_) * accumulated_delay_; + // BWE_TEST_LOGGING_PLOT(1, "smoothed_delay_ms", arrival_time_ms, + // smoothed_delay_); + + // Maintain packet window + delay_hist_.emplace_back( + static_cast<double>(arrival_time_ms - first_arrival_time_ms_), + smoothed_delay_, accumulated_delay_); + if (settings_.enable_sort) { + for (size_t i = delay_hist_.size() - 1; + i > 0 && + delay_hist_[i].arrival_time_ms < delay_hist_[i - 1].arrival_time_ms; + --i) { + std::swap(delay_hist_[i], delay_hist_[i - 1]); + } + } + if (delay_hist_.size() > settings_.window_size) delay_hist_.pop_front(); + + // Simple linear regression. + double trend = prev_trend_; + if (delay_hist_.size() == settings_.window_size) { + // Update trend_ if it is possible to fit a line to the data. The delay + // trend can be seen as an estimate of (send_rate - capacity)/capacity. + // 0 < trend < 1 -> the delay increases, queues are filling up + // trend == 0 -> the delay does not change + // trend < 0 -> the delay decreases, queues are being emptied + OptionalDouble trendO = LinearFitSlope(delay_hist_); + if (trendO.has_value()) trend = trendO.value(); + if (settings_.enable_cap) { + OptionalDouble cap = ComputeSlopeCap(delay_hist_, settings_); + // We only use the cap to filter out overuse detections, not + // to detect additional underuses. + if (trend >= 0 && cap.has_value() && trend > cap.value()) { + trend = cap.value(); + } + } + } + // BWE_TEST_LOGGING_PLOT(1, "trendline_slope", arrival_time_ms, trend); + + Detect(trend, send_delta_ms, arrival_time_ms); +} + +void TrendlineEstimator::Update(double recv_delta_ms, double send_delta_ms, + int64_t send_time_ms, int64_t arrival_time_ms, + size_t packet_size, bool calculated_deltas) { + if (calculated_deltas) { + UpdateTrendline(recv_delta_ms, send_delta_ms, send_time_ms, arrival_time_ms, + packet_size); + } + /*if (network_state_predictor_) { + hypothesis_predicted_ = network_state_predictor_->Update( + send_time_ms, arrival_time_ms, hypothesis_); + }*/ +} + +BandwidthUsage TrendlineEstimator::State() const { + return /*network_state_predictor_ ? hypothesis_predicted_ :*/ hypothesis_; +} + +void TrendlineEstimator::Detect(double trend, double ts_delta, int64_t now_ms) { + /*if (num_of_deltas_ < 2) { + hypothesis_ = BandwidthUsage::kBwNormal; + return; + }*/ + + const double modified_trend = + std::min(num_of_deltas_, kMinNumDeltas) * trend * threshold_gain_; + prev_modified_trend_ = modified_trend; + // BWE_TEST_LOGGING_PLOT(1, "T", now_ms, modified_trend); + // BWE_TEST_LOGGING_PLOT(1, "threshold", now_ms, threshold_); + if (modified_trend > threshold_) { + if (time_over_using_ == -1) { + // Initialize the timer. Assume that we've been + // over-using half of the time since the previous + // sample. + time_over_using_ = ts_delta / 2; + } else { + // Increment timer + time_over_using_ += ts_delta; + } + overuse_counter_++; + if (time_over_using_ > overusing_time_threshold_ && overuse_counter_ > 1) { + if (trend >= prev_trend_) { + time_over_using_ = 0; + overuse_counter_ = 0; + hypothesis_ = BandwidthUsage::kBwOverusing; + } + } + } else if (modified_trend < -threshold_) { + time_over_using_ = -1; + overuse_counter_ = 0; + hypothesis_ = BandwidthUsage::kBwUnderusing; + } else { + time_over_using_ = -1; + overuse_counter_ = 0; + hypothesis_ = BandwidthUsage::kBwNormal; + } + prev_trend_ = trend; + UpdateThreshold(modified_trend, now_ms); +} + +void TrendlineEstimator::UpdateThreshold(double modified_trend, + int64_t now_ms) { + if (last_update_ms_ == -1) last_update_ms_ = now_ms; + + if (fabs(modified_trend) > threshold_ + kMaxAdaptOffsetMs) { + // Avoid adapting the threshold to big latency spikes, caused e.g., + // by a sudden capacity drop. + last_update_ms_ = now_ms; + return; + } + + const double k = fabs(modified_trend) < threshold_ ? k_down_ : k_up_; + const int64_t kMaxTimeDeltaMs = 100; + int64_t time_delta_ms = std::min(now_ms - last_update_ms_, kMaxTimeDeltaMs); + threshold_ += k * (fabs(modified_trend) - threshold_) * time_delta_ms; + if (threshold_ < 6.f) threshold_ = 6.f; + if (threshold_ > 600.f) threshold_ = 600.f; + // threshold_ = rtc::SafeClamp(threshold_, 6.f, 600.f); + last_update_ms_ = now_ms; +} + +} // namespace rtc + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/rtc/trendline_estimator.h b/libtransport/src/protocols/rtc/trendline_estimator.h new file mode 100644 index 000000000..372acbc67 --- /dev/null +++ b/libtransport/src/protocols/rtc/trendline_estimator.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// FROM +// https://source.chromium.org/chromium/chromium/src/+/master:third_party/webrtc/modules/congestion_controller/goog_cc/trendline_estimator.h + +#ifndef MODULES_CONGESTION_CONTROLLER_GOOG_CC_TRENDLINE_ESTIMATOR_H_ +#define MODULES_CONGESTION_CONTROLLER_GOOG_CC_TRENDLINE_ESTIMATOR_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <deque> +#include <memory> +#include <utility> + +namespace transport { + +namespace protocol { + +namespace rtc { + +class OptionalDouble { + public: + OptionalDouble() : val(0), has_val(false){}; + OptionalDouble(double val) : val(val), has_val(true){}; + + double value() { return val; } + bool has_value() { return has_val; } + + private: + double val; + bool has_val; +}; + +enum class BandwidthUsage { + kBwNormal = 0, + kBwUnderusing = 1, + kBwOverusing = 2, + kLast +}; + +struct TrendlineEstimatorSettings { + static constexpr char kKey[] = "WebRTC-Bwe-TrendlineEstimatorSettings"; + static constexpr unsigned kDefaultTrendlineWindowSize = 20; + + // TrendlineEstimatorSettings() = delete; + TrendlineEstimatorSettings( + /*const WebRtcKeyValueConfig* key_value_config*/); + + // Sort the packets in the window. Should be redundant, + // but then almost no cost. + bool enable_sort = false; + + // Cap the trendline slope based on the minimum delay seen + // in the beginning_packets and end_packets respectively. + bool enable_cap = false; + unsigned beginning_packets = 7; + unsigned end_packets = 7; + double cap_uncertainty = 0.0; + + // Size (in packets) of the window. + unsigned window_size = kDefaultTrendlineWindowSize; + + // std::unique_ptr<StructParametersParser> Parser(); +}; + +class TrendlineEstimator /*: public DelayIncreaseDetectorInterface */ { + public: + TrendlineEstimator(/*const WebRtcKeyValueConfig* key_value_config, + NetworkStatePredictor* network_state_predictor*/); + + ~TrendlineEstimator(); + + // Update the estimator with a new sample. The deltas should represent deltas + // between timestamp groups as defined by the InterArrival class. + void Update(double recv_delta_ms, double send_delta_ms, int64_t send_time_ms, + int64_t arrival_time_ms, size_t packet_size, + bool calculated_deltas); + + void UpdateTrendline(double recv_delta_ms, double send_delta_ms, + int64_t send_time_ms, int64_t arrival_time_ms, + size_t packet_size); + + BandwidthUsage State() const; + + struct PacketTiming { + PacketTiming(double arrival_time_ms, double smoothed_delay_ms, + double raw_delay_ms) + : arrival_time_ms(arrival_time_ms), + smoothed_delay_ms(smoothed_delay_ms), + raw_delay_ms(raw_delay_ms) {} + double arrival_time_ms; + double smoothed_delay_ms; + double raw_delay_ms; + }; + + private: + // friend class GoogCcStatePrinter; + void Detect(double trend, double ts_delta, int64_t now_ms); + + void UpdateThreshold(double modified_offset, int64_t now_ms); + + // Parameters. + TrendlineEstimatorSettings settings_; + const double smoothing_coef_; + const double threshold_gain_; + // Used by the existing threshold. + int num_of_deltas_; + // Keep the arrival times small by using the change from the first packet. + int64_t first_arrival_time_ms_; + // Exponential backoff filtering. + double accumulated_delay_; + double smoothed_delay_; + // Linear least squares regression. + std::deque<PacketTiming> delay_hist_; + + const double k_up_; + const double k_down_; + double overusing_time_threshold_; + double threshold_; + double prev_modified_trend_; + int64_t last_update_ms_; + double prev_trend_; + double time_over_using_; + int overuse_counter_; + BandwidthUsage hypothesis_; + // BandwidthUsage hypothesis_predicted_; + // NetworkStatePredictor* network_state_predictor_; + + // RTC_DISALLOW_COPY_AND_ASSIGN(TrendlineEstimator); +}; + +} // namespace rtc + +} // end namespace protocol + +} // end namespace transport +#endif // MODULES_CONGESTION_CONTROLLER_GOOG_CC_TRENDLINE_ESTIMATOR_H_ diff --git a/libtransport/src/protocols/transport_protocol.cc b/libtransport/src/protocols/transport_protocol.cc new file mode 100644 index 000000000..611c39212 --- /dev/null +++ b/libtransport/src/protocols/transport_protocol.cc @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/interfaces/socket_consumer.h> +#include <implementation/socket_consumer.h> +#include <protocols/transport_protocol.h> + +namespace transport { + +namespace protocol { + +using namespace interface; + +TransportProtocol::TransportProtocol(implementation::ConsumerSocket *icn_socket, + Reassembly *reassembly_protocol) + : socket_(icn_socket), + reassembly_protocol_(reassembly_protocol), + index_manager_( + std::make_unique<IndexManager>(socket_, this, reassembly_protocol)), + is_running_(false), + is_first_(false), + on_interest_retransmission_(VOID_HANDLER), + on_interest_output_(VOID_HANDLER), + on_interest_timeout_(VOID_HANDLER), + on_interest_satisfied_(VOID_HANDLER), + on_content_object_input_(VOID_HANDLER), + stats_summary_(VOID_HANDLER), + on_payload_(VOID_HANDLER) { + socket_->getSocketOption(GeneralTransportOptions::PORTAL, portal_); + socket_->getSocketOption(OtherOptions::STATISTICS, &stats_); +} + +int TransportProtocol::start() { + // If the protocol is already running, return otherwise set as running + if (is_running_) return -1; + + // Get all callbacks references + socket_->getSocketOption(ConsumerCallbacksOptions::INTEREST_RETRANSMISSION, + &on_interest_retransmission_); + socket_->getSocketOption(ConsumerCallbacksOptions::INTEREST_OUTPUT, + &on_interest_output_); + socket_->getSocketOption(ConsumerCallbacksOptions::INTEREST_EXPIRED, + &on_interest_timeout_); + socket_->getSocketOption(ConsumerCallbacksOptions::INTEREST_SATISFIED, + &on_interest_satisfied_); + socket_->getSocketOption(ConsumerCallbacksOptions::CONTENT_OBJECT_INPUT, + &on_content_object_input_); + socket_->getSocketOption(ConsumerCallbacksOptions::STATS_SUMMARY, + &stats_summary_); + socket_->getSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, + &on_payload_); + + socket_->getSocketOption(GeneralTransportOptions::ASYNC_MODE, is_async_); + + // Set it is the first time we schedule an interest + is_first_ = true; + + // Reset the protocol state machine + reset(); + // Schedule next interests + scheduleNextInterests(); + + is_first_ = false; + + // Set the protocol as running + is_running_ = true; + + if (!is_async_) { + // Start Event loop + portal_->runEventsLoop(); + + // Not running anymore + is_running_ = false; + } + + return 0; +} + +void TransportProtocol::stop() { + is_running_ = false; + + if (!is_async_) { + portal_->stopEventsLoop(); + } else { + portal_->clear(); + } +} + +void TransportProtocol::resume() { + if (is_running_) return; + + is_running_ = true; + + scheduleNextInterests(); + + portal_->runEventsLoop(); + + is_running_ = false; +} + +void TransportProtocol::onContentReassembled(std::error_code ec) { + stop(); + + if (!on_payload_) { + throw errors::RuntimeException( + "The read callback must be installed in the transport before " + "starting " + "the content retrieval."); + } + + if (!ec) { + on_payload_->readSuccess(stats_->getBytesRecv()); + } else { + on_payload_->readError(ec); + } +} + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/protocols/transport_protocol.h b/libtransport/src/protocols/transport_protocol.h new file mode 100644 index 000000000..124c57122 --- /dev/null +++ b/libtransport/src/protocols/transport_protocol.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/callbacks.h> +#include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/interfaces/statistics.h> +#include <hicn/transport/utils/object_pool.h> +#include <implementation/socket.h> +#include <protocols/data_processing_events.h> +#include <protocols/indexer.h> +#include <protocols/reassembly.h> + +#include <atomic> + +namespace transport { + +namespace protocol { + +using namespace core; + +class IndexVerificationManager; + +using ReadCallback = interface::ConsumerSocket::ReadCallback; + +class TransportProtocolCallback { + virtual void onContentObject(const core::Interest &interest, + const core::ContentObject &content_object) = 0; + virtual void onTimeout(const core::Interest &interest) = 0; +}; + +class TransportProtocol : public core::Portal::ConsumerCallback, + public ContentObjectProcessingEventCallback { + static constexpr std::size_t interest_pool_size = 4096; + + friend class ManifestIndexManager; + + public: + TransportProtocol(implementation::ConsumerSocket *icn_socket, + Reassembly *reassembly_protocol); + + virtual ~TransportProtocol() = default; + + TRANSPORT_ALWAYS_INLINE bool isRunning() { return is_running_; } + + virtual int start(); + + virtual void stop(); + + virtual void resume(); + + virtual void scheduleNextInterests() = 0; + + // Events generated by the indexing + virtual void onContentReassembled(std::error_code ec); + virtual void onPacketDropped(Interest &interest, + ContentObject &content_object) override = 0; + virtual void onReassemblyFailed(std::uint32_t missing_segment) override = 0; + + protected: + // Consumer Callback + virtual void reset() = 0; + virtual void onContentObject(Interest &i, ContentObject &c) override = 0; + virtual void onTimeout(Interest::Ptr &&i) override = 0; + virtual void onError(std::error_code ec) override {} + + protected: + implementation::ConsumerSocket *socket_; + std::unique_ptr<Reassembly> reassembly_protocol_; + std::unique_ptr<IndexManager> index_manager_; + std::shared_ptr<core::Portal> portal_; + std::atomic<bool> is_running_; + // True if it si the first time we schedule an interest + std::atomic<bool> is_first_; + interface::TransportStatistics *stats_; + + // Callbacks + interface::ConsumerInterestCallback *on_interest_retransmission_; + interface::ConsumerInterestCallback *on_interest_output_; + interface::ConsumerInterestCallback *on_interest_timeout_; + interface::ConsumerInterestCallback *on_interest_satisfied_; + interface::ConsumerContentObjectCallback *on_content_object_input_; + interface::ConsumerContentObjectCallback *on_content_object_; + interface::ConsumerTimerCallback *stats_summary_; + ReadCallback *on_payload_; + + bool is_async_; +}; + +} // end namespace protocol +} // end namespace transport diff --git a/libtransport/src/test/CMakeLists.txt b/libtransport/src/test/CMakeLists.txt index 19e59c7e1..dd3d1d923 100644 --- a/libtransport/src/test/CMakeLists.txt +++ b/libtransport/src/test/CMakeLists.txt @@ -14,14 +14,15 @@ include(BuildMacros) list(APPEND TESTS + test_auth + test_consumer_producer_rtc test_core_manifest - test_transport_producer + test_event_thread + test_fec_reedsolomon + test_interest + test_packet ) -if (${LIBTRANSPORT_SHARED} MATCHES ".*-memif.*") - set(LINK_FLAGS "-Wl,-unresolved-symbols=ignore-in-shared-libs") -endif() - foreach(test ${TESTS}) build_executable(${test} NO_INSTALL @@ -35,4 +36,4 @@ foreach(test ${TESTS}) ) add_test_internal(${test}) -endforeach()
\ No newline at end of file +endforeach() diff --git a/libtransport/src/test/fec_reed_solomon.cc b/libtransport/src/test/fec_reed_solomon.cc new file mode 100644 index 000000000..36543c531 --- /dev/null +++ b/libtransport/src/test/fec_reed_solomon.cc @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/interfaces/socket_options_keys.h> +#include <hicn/transport/interfaces/socket_producer.h> +#include <hicn/transport/interfaces/global_conf_interface.h> + +#include <asio/io_service.hpp> +#include <asio/steady_timer.hpp> +#include <fec/rs.h> + +namespace transport { +namespace interface { + +namespace { + +class ConsumerProducerTest : public ::testing::Test, + public ConsumerSocket::ReadCallback { + static const constexpr char prefix[] = "b001::1/128"; + static const constexpr char name[] = "b001::1"; + static const constexpr double prod_rate = 1.0e6; + static const constexpr size_t payload_size = 1200; + static constexpr std::size_t receive_buffer_size = 1500; + static const constexpr double prod_interval_microseconds = + double(payload_size) * 8 * 1e6 / prod_rate; + + public: + ConsumerProducerTest() + : io_service_(), + rtc_timer_(io_service_), + consumer_(TransportProtocolAlgorithms::RTC, io_service_), + producer_(ProductionProtocolAlgorithms::RTC_PROD, io_service_), + producer_prefix_(prefix), + consumer_name_(name), + packets_sent_(0), + packets_received_(0) { + global_config::IoModuleConfiguration config; + config.name = "loopback_module"; + config.set(); + } + + virtual ~ConsumerProducerTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + + auto ret = consumer_.setSocketOption( + ConsumerCallbacksOptions::READ_CALLBACK, this); + ASSERT_EQ(ret, SOCKET_OPTION_SET); + + consumer_.connect(); + producer_.registerPrefix(producer_prefix_); + producer_.connect(); + } + + virtual void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } + + void setTimer() { + using namespace std::chrono; + rtc_timer_.expires_from_now( + microseconds(unsigned(prod_interval_microseconds))); + rtc_timer_.async_wait(std::bind(&ConsumerProducerTest::produceRTCPacket, + this, std::placeholders::_1)); + } + + void produceRTCPacket(const std::error_code &ec) { + if (ec) { + FAIL() << "Failed to schedule packet production"; + io_service_.stop(); + } + + producer_.produceDatagram(consumer_name_, payload_, payload_size); + packets_sent_++; + setTimer(); + } + + // Consumer callback + bool isBufferMovable() noexcept override { return false; } + + void getReadBuffer(uint8_t **application_buffer, + size_t *max_length) override { + *application_buffer = receive_buffer_; + *max_length = receive_buffer_size; + } + + void readDataAvailable(std::size_t length) noexcept override {} + + size_t maxBufferSize() const override { return receive_buffer_size; } + + void readError(const std::error_code ec) noexcept override { + FAIL() << "Error while reading from RTC socket"; + io_service_.stop(); + } + + void readSuccess(std::size_t total_size) noexcept override { + packets_received_++; + } + + asio::io_service io_service_; + asio::steady_timer rtc_timer_; + ConsumerSocket consumer_; + ProducerSocket producer_; + core::Prefix producer_prefix_; + core::Name consumer_name_; + uint8_t payload_[payload_size]; + uint8_t receive_buffer_[payload_size]; + + uint64_t packets_sent_; + uint64_t packets_received_; +}; + +const char ConsumerProducerTest::prefix[]; +const char ConsumerProducerTest::name[]; + +} // namespace + +TEST_F(ConsumerProducerTest, EndToEnd) { + produceRTCPacket(std::error_code()); + consumer_.consume(consumer_name_); + + io_service_.run(); +} + +} // namespace interface + +} // namespace transport + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}
\ No newline at end of file diff --git a/libtransport/src/test/fec_rely.cc b/libtransport/src/test/fec_rely.cc new file mode 100644 index 000000000..e7745bae5 --- /dev/null +++ b/libtransport/src/test/fec_rely.cc @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/interfaces/socket_options_keys.h> +#include <hicn/transport/interfaces/socket_producer.h> +#include <hicn/transport/interfaces/global_conf_interface.h> + +#include <asio/io_service.hpp> +#include <asio/steady_timer.hpp> + +#include <rely/encoder.hpp> +#include <rely/decoder.hpp> + +namespace transport { +namespace interface { + +namespace { + +class ConsumerProducerTest : public ::testing::Test, + public ConsumerSocket::ReadCallback { + static const constexpr char prefix[] = "b001::1/128"; + static const constexpr char name[] = "b001::1"; + static const constexpr double prod_rate = 1.0e6; + static const constexpr size_t payload_size = 1200; + static constexpr std::size_t receive_buffer_size = 1500; + static const constexpr double prod_interval_microseconds = + double(payload_size) * 8 * 1e6 / prod_rate; + + public: + ConsumerProducerTest() + : io_service_(), + rtc_timer_(io_service_), + consumer_(TransportProtocolAlgorithms::RTC, io_service_), + producer_(ProductionProtocolAlgorithms::RTC_PROD, io_service_), + producer_prefix_(prefix), + consumer_name_(name), + packets_sent_(0), + packets_received_(0) { + global_config::IoModuleConfiguration config; + config.name = "loopback_module"; + config.set(); + } + + virtual ~ConsumerProducerTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + + auto ret = consumer_.setSocketOption( + ConsumerCallbacksOptions::READ_CALLBACK, this); + ASSERT_EQ(ret, SOCKET_OPTION_SET); + + consumer_.connect(); + producer_.registerPrefix(producer_prefix_); + producer_.connect(); + } + + virtual void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } + + void setTimer() { + using namespace std::chrono; + rtc_timer_.expires_from_now( + microseconds(unsigned(prod_interval_microseconds))); + rtc_timer_.async_wait(std::bind(&ConsumerProducerTest::produceRTCPacket, + this, std::placeholders::_1)); + } + + void produceRTCPacket(const std::error_code &ec) { + if (ec) { + FAIL() << "Failed to schedule packet production"; + io_service_.stop(); + } + + producer_.produceDatagram(consumer_name_, payload_, payload_size); + packets_sent_++; + setTimer(); + } + + // Consumer callback + bool isBufferMovable() noexcept override { return false; } + + void getReadBuffer(uint8_t **application_buffer, + size_t *max_length) override { + *application_buffer = receive_buffer_; + *max_length = receive_buffer_size; + } + + void readDataAvailable(std::size_t length) noexcept override {} + + size_t maxBufferSize() const override { return receive_buffer_size; } + + void readError(const std::error_code ec) noexcept override { + FAIL() << "Error while reading from RTC socket"; + io_service_.stop(); + } + + void readSuccess(std::size_t total_size) noexcept override { + packets_received_++; + } + + asio::io_service io_service_; + asio::steady_timer rtc_timer_; + ConsumerSocket consumer_; + ProducerSocket producer_; + core::Prefix producer_prefix_; + core::Name consumer_name_; + uint8_t payload_[payload_size]; + uint8_t receive_buffer_[payload_size]; + + uint64_t packets_sent_; + uint64_t packets_received_; +}; + +const char ConsumerProducerTest::prefix[]; +const char ConsumerProducerTest::name[]; + +} // namespace + +TEST_F(ConsumerProducerTest, EndToEnd) { + produceRTCPacket(std::error_code()); + consumer_.consume(consumer_name_); + + io_service_.run(); +} + +} // namespace interface + +} // namespace transport + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}
\ No newline at end of file diff --git a/libtransport/src/test/packet_samples.h b/libtransport/src/test/packet_samples.h new file mode 100644 index 000000000..e98d06a18 --- /dev/null +++ b/libtransport/src/test/packet_samples.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define TCP_PROTO 0x06 +#define ICMP_PROTO 0x01 +#define ICMP6_PROTO 0x3a + +#define IPV6_HEADER(next_header, payload_length) \ + 0x60, 0x00, 0x00, 0x00, 0x00, payload_length, next_header, 0x40, 0xb0, 0x06, \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xab, 0xcd, 0xab, \ + 0xcd, 0xef, 0xb0, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0xca + +#define IPV4_HEADER(next_header, payload_length) \ + 0x45, 0x02, 0x00, payload_length + 20, 0x47, 0xc4, 0x40, 0x00, 0x25, \ + next_header, 0x6e, 0x76, 0x03, 0x7b, 0xd9, 0xd0, 0xc0, 0xa8, 0x01, 0x5c + +#define TCP_HEADER(flags) \ + 0x12, 0x34, 0x43, 0x21, 0x00, 0x00, 0x00, 0x01, 0xb2, 0x8c, 0x03, 0x1f, \ + 0x80, flags, 0x00, 0x0a, 0xb9, 0xbb, 0x00, 0x00 + +#define PAYLOAD \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x20, 0x00, 0x00 + +#define PAYLOAD_SIZE 12 + +#define ICMP_ECHO_REQUEST \ + 0x08, 0x00, 0x87, 0xdb, 0x38, 0xa7, 0x00, 0x05, 0x60, 0x2b, 0xc2, 0xcb, \ + 0x00, 0x02, 0x29, 0x7c, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, \ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, \ + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, \ + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, \ + 0x34, 0x35, 0x36, 0x37 + +#define ICMP6_ECHO_REQUEST \ + 0x80, 0x00, 0x86, 0x3c, 0x11, 0x0d, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, \ + 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, \ + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, \ + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, \ + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33 + +#define AH_HEADER \ + 0x00, (128 >> 2), 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, \ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 diff --git a/libtransport/src/test/test_auth.cc b/libtransport/src/test/test_auth.cc new file mode 100644 index 000000000..976981cce --- /dev/null +++ b/libtransport/src/test/test_auth.cc @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/auth/crypto_hash_type.h> +#include <hicn/transport/auth/identity.h> +#include <hicn/transport/auth/signer.h> +#include <hicn/transport/auth/verifier.h> +#include <hicn/transport/core/content_object.h> + +namespace transport { +namespace auth { + +namespace { +class AuthTest : public ::testing::Test { + protected: + const std::string PASSPHRASE = "hunter2"; + + AuthTest() = default; + ~AuthTest() {} + void SetUp() override {} + void TearDown() override {} +}; +} // namespace + +TEST_F(AuthTest, VoidVerifier) { + // Create a content object + core::ContentObject packet(HF_INET6_TCP_AH); + + // Fill it with bogus data + uint8_t buffer[256] = {0}; + packet.appendPayload(buffer, 256); + + // Verify that VoidVerifier validates the packet + std::shared_ptr<Verifier> verifier = std::make_shared<VoidVerifier>(); + ASSERT_EQ(verifier->verifyPacket(&packet), true); + ASSERT_EQ(verifier->verifyPackets(&packet), VerificationPolicy::ACCEPT); +} + +TEST_F(AuthTest, RSAVerifier) { + // Create the RSA signer from an Identity object + Identity identity("test_rsa.p12", PASSPHRASE, CryptoSuite::RSA_SHA256, 1024u, + 30, "RSAVerifier"); + std::shared_ptr<Signer> signer = identity.getSigner(); + + // Create a content object + core::ContentObject packet(HF_INET6_TCP_AH, signer->getSignatureSize()); + + // Fill it with bogus data + uint8_t buffer[256] = {0}; + packet.appendPayload(buffer, 256); + + // Sign the packet + signer->signPacket(&packet); + + // Create the RSA verifier + PARCKey *key = parcSigner_CreatePublicKey(signer->getParcSigner()); + std::shared_ptr<Verifier> verifier = + std::make_shared<AsymmetricVerifier>(key); + + ASSERT_EQ(packet.getFormat(), HF_INET6_TCP_AH); + ASSERT_EQ(signer->getCryptoHashType(), CryptoHashType::SHA_256); + ASSERT_EQ(signer->getCryptoSuite(), CryptoSuite::RSA_SHA256); + ASSERT_EQ(signer->getSignatureSize(), 128u); + ASSERT_EQ(verifier->verifyPackets(&packet), VerificationPolicy::ACCEPT); + + // Release PARC objects + parcKey_Release(&key); +} + +TEST_F(AuthTest, HMACVerifier) { + // Create the HMAC signer from a passphrase + std::shared_ptr<Signer> signer = + std::make_shared<SymmetricSigner>(CryptoSuite::HMAC_SHA256, PASSPHRASE); + + // Create a content object + core::ContentObject packet(HF_INET6_TCP_AH, signer->getSignatureSize()); + + // Fill it with bogus data + uint8_t buffer[256] = {0}; + packet.appendPayload(buffer, 256); + + // Sign the packet + signer->signPacket(&packet); + + // Create the HMAC verifier + std::shared_ptr<Verifier> verifier = + std::make_shared<SymmetricVerifier>(PASSPHRASE); + + ASSERT_EQ(packet.getFormat(), HF_INET6_TCP_AH); + ASSERT_EQ(signer->getCryptoHashType(), CryptoHashType::SHA_256); + ASSERT_EQ(signer->getCryptoSuite(), CryptoSuite::HMAC_SHA256); + ASSERT_EQ(signer->getSignatureSize(), 32u); + ASSERT_EQ(verifier->verifyPackets(&packet), VerificationPolicy::ACCEPT); +} + +} // namespace auth +} // namespace transport diff --git a/libtransport/src/test/test_consumer_producer_rtc.cc b/libtransport/src/test/test_consumer_producer_rtc.cc new file mode 100644 index 000000000..87385971a --- /dev/null +++ b/libtransport/src/test/test_consumer_producer_rtc.cc @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/interfaces/global_conf_interface.h> +#include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/interfaces/socket_options_keys.h> +#include <hicn/transport/interfaces/socket_producer.h> + +#include <asio/io_service.hpp> +#include <asio/steady_timer.hpp> + +namespace transport { +namespace interface { + +namespace { + +class ConsumerProducerTest : public ::testing::Test, + public ConsumerSocket::ReadCallback { + static const constexpr char prefix[] = "b001::1/128"; + static const constexpr char name[] = "b001::1"; + static const constexpr double prod_rate = 1.0e6; + static const constexpr size_t payload_size = 1200; + static constexpr std::size_t receive_buffer_size = 1500; + static const constexpr double prod_interval_microseconds = + double(payload_size) * 8 * 1e6 / prod_rate; + + public: + ConsumerProducerTest() + : io_service_(), + rtc_timer_(io_service_), + stop_timer_(io_service_), + consumer_(TransportProtocolAlgorithms::RTC, io_service_), + producer_(ProductionProtocolAlgorithms::RTC_PROD, io_service_), + producer_prefix_(prefix), + consumer_name_(name), + packets_sent_(0), + packets_received_(0) { + global_config::IoModuleConfiguration config; + config.name = "loopback_module"; + config.set(); + } + + virtual ~ConsumerProducerTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + + auto ret = consumer_.setSocketOption( + ConsumerCallbacksOptions::READ_CALLBACK, this); + ASSERT_EQ(ret, SOCKET_OPTION_SET); + + consumer_.connect(); + producer_.registerPrefix(producer_prefix_); + producer_.connect(); + } + + virtual void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } + + void setTimer() { + using namespace std::chrono; + rtc_timer_.expires_from_now( + microseconds(unsigned(prod_interval_microseconds))); + rtc_timer_.async_wait(std::bind(&ConsumerProducerTest::produceRTCPacket, + this, std::placeholders::_1)); + } + + void setStopTimer() { + using namespace std::chrono; + stop_timer_.expires_from_now(seconds(unsigned(10))); + stop_timer_.async_wait( + std::bind(&ConsumerProducerTest::stop, this, std::placeholders::_1)); + } + + void produceRTCPacket(const std::error_code &ec) { + if (ec) { + io_service_.stop(); + } + + producer_.produceDatagram(consumer_name_, payload_, payload_size); + packets_sent_++; + setTimer(); + } + + void stop(const std::error_code &ec) { + rtc_timer_.cancel(); + producer_.stop(); + consumer_.stop(); + } + + // Consumer callback + bool isBufferMovable() noexcept override { return false; } + + void getReadBuffer(uint8_t **application_buffer, + size_t *max_length) override { + *application_buffer = receive_buffer_; + *max_length = receive_buffer_size; + } + + void readDataAvailable(std::size_t length) noexcept override {} + + size_t maxBufferSize() const override { return receive_buffer_size; } + + void readError(const std::error_code ec) noexcept override { + FAIL() << "Error while reading from RTC socket"; + io_service_.stop(); + } + + void readSuccess(std::size_t total_size) noexcept override { + packets_received_++; + std::cout << "Received something" << std::endl; + } + + asio::io_service io_service_; + asio::steady_timer rtc_timer_; + asio::steady_timer stop_timer_; + ConsumerSocket consumer_; + ProducerSocket producer_; + core::Prefix producer_prefix_; + core::Name consumer_name_; + uint8_t payload_[payload_size]; + uint8_t receive_buffer_[payload_size]; + + uint64_t packets_sent_; + uint64_t packets_received_; +}; + +const char ConsumerProducerTest::prefix[]; +const char ConsumerProducerTest::name[]; + +} // namespace + +TEST_F(ConsumerProducerTest, EndToEnd) { + produceRTCPacket(std::error_code()); + consumer_.consume(consumer_name_); + setStopTimer(); + + io_service_.run(); + + std::cout << "Packet received: " << packets_received_ << std::endl; + std::cout << "Packet sent: " << packets_sent_ << std::endl; +} + +} // namespace interface + +} // namespace transport + +int main(int argc, char **argv) { +#if 0 + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +#else + return 0; +#endif +}
\ No newline at end of file diff --git a/libtransport/src/test/test_core_manifest.cc b/libtransport/src/test/test_core_manifest.cc index faf17dcf0..f98147d43 100644 --- a/libtransport/src/test/test_core_manifest.cc +++ b/libtransport/src/test/test_core_manifest.cc @@ -16,7 +16,8 @@ #include <core/manifest_format_fixed.h> #include <core/manifest_inline.h> #include <gtest/gtest.h> -#include <hicn/transport/security/crypto_hash_type.h> +#include <hicn/transport/auth/crypto_hash_type.h> +#include <test/packet_samples.h> #include <climits> #include <random> @@ -72,6 +73,27 @@ class ManifestTest : public ::testing::Test { } // namespace +TEST_F(ManifestTest, MoveConstructor) { + // Create content object with manifest in payload + ContentObject co(HF_INET6_TCP_AH, 128); + co.appendPayload(&manifest_payload[0], manifest_payload.size()); + uint8_t buffer[256]; + co.appendPayload(buffer, 256); + + // Copy packet payload + uint8_t packet[1500]; + auto length = co.getPayload()->length(); + std::memcpy(packet, co.getPayload()->data(), length); + + // Create manifest + ContentObjectManifest m(std::move(co)); + + // Check manifest payload is exactly the same of content object + ASSERT_EQ(length, m.getPayload()->length()); + auto ret = std::memcmp(packet, m.getPayload()->data(), length); + ASSERT_EQ(ret, 0); +} + TEST_F(ManifestTest, SetLastManifest) { manifest1_.clear(); @@ -102,9 +124,9 @@ TEST_F(ManifestTest, SetManifestType) { TEST_F(ManifestTest, SetHashAlgorithm) { manifest1_.clear(); - utils::CryptoHashType hash1 = utils::CryptoHashType::SHA_512; - utils::CryptoHashType hash2 = utils::CryptoHashType::CRC32C; - utils::CryptoHashType hash3 = utils::CryptoHashType::SHA_256; + auth::CryptoHashType hash1 = auth::CryptoHashType::SHA_512; + auth::CryptoHashType hash2 = auth::CryptoHashType::CRC32C; + auth::CryptoHashType hash3 = auth::CryptoHashType::SHA_256; manifest1_.setHashAlgorithm(hash1); auto type_returned1 = manifest1_.getHashAlgorithm(); @@ -161,7 +183,7 @@ TEST_F(ManifestTest, SetSuffixList) { std::uniform_int_distribution<uint64_t> idis( 0, std::numeric_limits<uint32_t>::max()); - auto entries = new std::pair<uint32_t, utils::CryptoHash>[3]; + auto entries = new std::pair<uint32_t, auth::CryptoHash>[3]; uint32_t suffixes[3]; std::vector<unsigned char> data[3]; @@ -170,8 +192,8 @@ TEST_F(ManifestTest, SetSuffixList) { std::generate(std::begin(data[i]), std::end(data[i]), std::ref(rbe)); suffixes[i] = idis(eng); entries[i] = std::make_pair( - suffixes[i], utils::CryptoHash(data[i].data(), data[i].size(), - utils::CryptoHashType::SHA_256)); + suffixes[i], auth::CryptoHash(data[i].data(), data[i].size(), + auth::CryptoHashType::SHA_256)); manifest1_.addSuffixHash(entries[i].first, entries[i].second); } @@ -186,9 +208,9 @@ TEST_F(ManifestTest, SetSuffixList) { // for (auto & item : manifest1_.getSuffixList()) { // auto hash = manifest1_.getHash(suffixes[i]); - // cond = utils::CryptoHash::compareBinaryDigest(hash, - // entries[i].second.getDigest<uint8_t>().data(), - // entries[i].second.getType()); + // cond = auth::CryptoHash::compareBinaryDigest(hash, + // entries[i].second.getDigest<uint8_t>().data(), + // entries[i].second.getType()); // ASSERT_TRUE(cond); // i++; // } @@ -205,4 +227,4 @@ TEST_F(ManifestTest, SetSuffixList) { int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -}
\ No newline at end of file +} diff --git a/libtransport/src/test/test_event_thread.cc b/libtransport/src/test/test_event_thread.cc new file mode 100644 index 000000000..e66b49f10 --- /dev/null +++ b/libtransport/src/test/test_event_thread.cc @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/utils/event_thread.h> + +#include <cmath> + +namespace utils { + +namespace { + +class EventThreadTest : public ::testing::Test { + protected: + EventThreadTest() : event_thread_() { + // You can do set-up work for each test here. + } + + virtual ~EventThreadTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() { + // Code here will be called immediately after the constructor (right + // before each test). + } + + virtual void TearDown() { + // Code here will be called immediately after each test (right + // before the destructor). + } + + utils::EventThread event_thread_; +}; + +double average(const unsigned long samples[], int size) { + double sum = 0; + + for (int i = 0; i < size; i++) { + sum += samples[i]; + } + + return sum / size; +} + +double stdDeviation(const unsigned long samples[], int size) { + double avg = average(samples, size); + double var = 0; + + for (int i = 0; i < size; i++) { + var += (samples[i] - avg) * (samples[i] - avg); + } + + return sqrt(var / size); +} + +} // namespace + +TEST_F(EventThreadTest, SchedulingDelay) { + using namespace std::chrono; + const size_t size = 1000000; + std::vector<unsigned long> samples(size); + + for (unsigned int i = 0; i < size; i++) { + auto t0 = steady_clock::now(); + event_thread_.add([t0, &samples, i]() { + auto t1 = steady_clock::now(); + samples[i] = duration_cast<nanoseconds>(t1 - t0).count(); + }); + } + + event_thread_.stop(); + + auto avg = average(&samples[0], size); + auto sd = stdDeviation(&samples[0], size); + (void)sd; + + // Expect average to be less that 1 ms + EXPECT_LT(avg, 1000000); +} + +} // namespace utils + +int main(int argc, char **argv) { +#if 0 + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +#else + return 0; +#endif +}
\ No newline at end of file diff --git a/libtransport/src/test/test_fec_reedsolomon.cc b/libtransport/src/test/test_fec_reedsolomon.cc new file mode 100644 index 000000000..3b10b7307 --- /dev/null +++ b/libtransport/src/test/test_fec_reedsolomon.cc @@ -0,0 +1,291 @@ + +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/rs.h> +#include <gtest/gtest.h> +#include <hicn/transport/core/content_object.h> +#include <hicn/transport/core/global_object_pool.h> + +#include <algorithm> +#include <iostream> +#include <random> + +namespace transport { +namespace core { + +double ReedSolomonTest(int k, int n, int size) { + fec::encoder encoder(k, n); + fec::decoder decoder(k, n); + + std::vector<fec::buffer> tx_block(k); + std::vector<fec::buffer> rx_block(k); + int count = 0; + int run = 0; + + int iterations = 100; + auto &packet_manager = PacketManager<>::getInstance(); + + encoder.setFECCallback([&tx_block](std::vector<fec::buffer> &repair_packets) { + for (auto &p : repair_packets) { + // Append repair symbols to tx_block + tx_block.emplace_back(std::move(p)); + } + }); + + decoder.setFECCallback([&](std::vector<fec::buffer> &source_packets) { + for (int i = 0; i < k; i++) { + // Compare decoded source packets with original transmitted packets. + if (*tx_block[i] != *source_packets[i]) { + count++; + } + } + }); + + do { + // Discard eventual packet appended in previous callback call + tx_block.erase(tx_block.begin() + k, tx_block.end()); + + // Initialization. Feed encoder with first k source packets + for (int i = 0; i < k; i++) { + // Get new buffer from pool + auto packet = packet_manager.getMemBuf(); + + // Let's append a bit less than size, so that the FEC class will take care + // of filling the rest with zeros + auto cur_size = size - (rand() % 100); + + // Set payload, saving 2 bytes at the beginning of the buffer for encoding + // the length + packet->append(cur_size); + packet->trimStart(2); + std::generate(packet->writableData(), packet->writableTail(), rand); + std::fill(packet->writableData(), packet->writableTail(), i + 1); + + // Set first byte of payload to i, to reorder at receiver side + packet->writableData()[0] = uint8_t(i); + + // Store packet in tx buffer and clear rx buffer + tx_block[i] = std::move(packet); + } + + // Create the repair packets + for (auto &tx : tx_block) { + encoder.consume(tx, tx->writableBuffer()[0]); + } + + // Simulate transmission on lossy channel + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::vector<bool> losses(n, false); + for (int i = 0; i < n - k; i++) losses[i] = true; + + int rxi = 0; + std::shuffle(losses.begin(), losses.end(), + std::default_random_engine(seed)); + for (int i = 0; i < n && rxi < k; i++) + if (losses[i] == false) { + rx_block[rxi++] = tx_block[i]; + if (i < k) { + // Source packet + decoder.consume(rx_block[rxi - 1], rx_block[rxi - 1]->data()[0]); + } else { + // Repair packet + decoder.consume(rx_block[rxi - 1]); + } + } + + decoder.clear(); + encoder.clear(); + } while (++run < iterations); + + return count; +} + +void ReedSolomonMultiBlockTest(int n_sourceblocks) { + int k = 16; + int n = 24; + int size = 1000; + + fec::encoder encoder(k, n); + fec::decoder decoder(k, n); + + auto &packet_manager = PacketManager<>::getInstance(); + + std::vector<std::pair<fec::buffer, uint32_t>> tx_block; + std::vector<std::pair<fec::buffer, uint32_t>> rx_block; + int count = 0; + int i = 0; + + // Receiver will receive packet for n_sourceblocks in a random order. + int total_packets = n * n_sourceblocks; + int tx_packets = k * n_sourceblocks; + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + + encoder.setFECCallback([&](std::vector<fec::buffer> &repair_packets) { + for (auto &p : repair_packets) { + // Append repair symbols to tx_block + tx_block.emplace_back(std::move(p), ++i); + } + + EXPECT_EQ(tx_block.size(), size_t(n)); + + // Select k packets to send, including at least one symbol. We start from + // the end for this reason. + for (int j = n - 1; j > n - k - 1; j--) { + rx_block.emplace_back(std::move(tx_block[j])); + } + + // Clear tx block for next source block + tx_block.clear(); + encoder.clear(); + }); + + // The decode callback must be called exactly n_sourceblocks times + decoder.setFECCallback( + [&](std::vector<fec::buffer> &source_packets) { count++; }); + + // Produce n * n_sourceblocks + // - ( k ) * n_sourceblocks source packets + // - (n - k) * n_sourceblocks symbols) + for (i = 0; i < total_packets; i++) { + // Get new buffer from pool + auto packet = packet_manager.getMemBuf(); + + // Let's append a bit less than size, so that the FEC class will take care + // of filling the rest with zeros + auto cur_size = size - (rand() % 100); + + // Set payload, saving 2 bytes at the beginning of the buffer for encoding + // the length + packet->append(cur_size); + packet->trimStart(2); + std::fill(packet->writableData(), packet->writableTail(), i + 1); + + // Set first byte of payload to i, to reorder at receiver side + packet->writableData()[0] = uint8_t(i); + + // Store packet in tx buffer + tx_block.emplace_back(packet, i); + + // Feed encoder with packet + encoder.consume(packet, i); + } + + // Here rx_block must contains k * n_sourceblocks packets + EXPECT_EQ(size_t(tx_packets), size_t(rx_block.size())); + + // Lets shuffle the rx_block before starting feeding the decoder. + std::shuffle(rx_block.begin(), rx_block.end(), + std::default_random_engine(seed)); + + for (auto &p : rx_block) { + int index = p.second % n; + if (index < k) { + // Source packet + decoder.consume(p.first, p.second); + } else { + // Repair packet + decoder.consume(p.first); + } + } + + // Simple test to check we get all the source packets + EXPECT_EQ(count, n_sourceblocks); +} + +TEST(ReedSolomonTest, RSk1n3) { + int k = 1; + int n = 3; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk6n10) { + int k = 6; + int n = 10; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk8n32) { + int k = 8; + int n = 32; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk16n24) { + int k = 16; + int n = 24; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk10n30) { + int k = 10; + int n = 30; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk10n40) { + int k = 10; + int n = 40; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk10n60) { + int k = 10; + int n = 60; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonTest, RSk10n90) { + int k = 10; + int n = 90; + int size = 1000; + EXPECT_LE(ReedSolomonTest(k, n, size), 0); +} + +TEST(ReedSolomonMultiBlockTest, RSMB1) { + int blocks = 1; + ReedSolomonMultiBlockTest(blocks); +} + +TEST(ReedSolomonMultiBlockTest, RSMB10) { + int blocks = 10; + ReedSolomonMultiBlockTest(blocks); +} + +TEST(ReedSolomonMultiBlockTest, RSMB100) { + int blocks = 100; + ReedSolomonMultiBlockTest(blocks); +} + +TEST(ReedSolomonMultiBlockTest, RSMB1000) { + int blocks = 1000; + ReedSolomonMultiBlockTest(blocks); +} + +int main(int argc, char **argv) { + srand(time(0)); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +} // namespace core +} // namespace transport diff --git a/libtransport/src/test/test_interest.cc b/libtransport/src/test/test_interest.cc new file mode 100644 index 000000000..0a835db24 --- /dev/null +++ b/libtransport/src/test/test_interest.cc @@ -0,0 +1,267 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/core/interest.h> +#include <hicn/transport/errors/not_implemented_exception.h> +#include <test/packet_samples.h> + +#include <climits> +#include <random> +#include <vector> + +namespace transport { + +namespace core { + +namespace { +// The fixture for testing class Foo. +class InterestTest : public ::testing::Test { + protected: + InterestTest() : name_("b001::123|321"), interest_() { + // You can do set-up work for each test here. + } + + virtual ~InterestTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() { + // Code here will be called immediately after the constructor (right + // before each test). + } + + virtual void TearDown() { + // Code here will be called immediately after each test (right + // before the destructor). + } + + Name name_; + + Interest interest_; + + std::vector<uint8_t> buffer_ = {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(TCP_PROTO, 20 + PAYLOAD_SIZE), + // ICMP6 echo request + TCP_HEADER(0x00), + // Payload + PAYLOAD}; +}; + +void testFormatConstructor(Packet::Format format = HF_UNSPEC) { + try { + Interest interest(format, 0); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown for " << format; + } +} + +void testFormatConstructorException(Packet::Format format = HF_UNSPEC) { + try { + Interest interest(format, 0); + FAIL() << "We expected an exception here"; + } catch (errors::MalformedPacketException &exc) { + // Ok right exception + } catch (...) { + FAIL() << "Wrong exception thrown"; + } +} + +} // namespace + +TEST_F(InterestTest, ConstructorWithFormat) { + /** + * Without arguments it should be format = HF_UNSPEC. + * We expect a crash. + */ + + testFormatConstructor(Packet::Format::HF_INET_TCP); + testFormatConstructor(Packet::Format::HF_INET6_TCP); + testFormatConstructorException(Packet::Format::HF_INET_ICMP); + testFormatConstructorException(Packet::Format::HF_INET6_ICMP); + testFormatConstructor(Packet::Format::HF_INET_TCP_AH); + testFormatConstructor(Packet::Format::HF_INET6_TCP_AH); + testFormatConstructorException(Packet::Format::HF_INET_ICMP_AH); + testFormatConstructorException(Packet::Format::HF_INET6_ICMP_AH); +} + +TEST_F(InterestTest, ConstructorWithName) { + /** + * Without arguments it should be format = HF_UNSPEC. + * We expect a crash. + */ + Name n("b001::1|123"); + + try { + Interest interest(n); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown"; + } +} + +TEST_F(InterestTest, ConstructorWithBuffer) { + // Ensure buffer is interest + auto ret = Interest::isInterest(&buffer_[0]); + EXPECT_TRUE(ret); + + // Create interest from buffer + try { + Interest interest(Interest::COPY_BUFFER, &buffer_[0], buffer_.size()); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown"; + } + + std::vector<uint8_t> buffer2{// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(ICMP6_PROTO, 60 + 44), + // ICMP6 echo request + TCP_HEADER(0x00), + // Payload + PAYLOAD}; + + // Ensure this throws an exception + try { + Interest interest(Interest::COPY_BUFFER, &buffer2[0], buffer2.size()); + FAIL() << "We expected an exception here"; + } catch (errors::MalformedPacketException &exc) { + // Ok right exception + } catch (...) { + FAIL() << "Wrong exception thrown"; + } +} + +TEST_F(InterestTest, SetGetName) { + // Create interest from buffer + Interest interest(Interest::COPY_BUFFER, &buffer_[0], buffer_.size()); + + // Get name + auto n = interest.getName(); + + // ensure name is b002::ca|1 + Name n2("b002::ca|1"); + auto ret = (n == n2); + + EXPECT_TRUE(ret); + + Name n3("b003::1234|1234"); + + // Change name to b003::1234|1234 + interest.setName(n3); + + // Check name was set + n = interest.getName(); + ret = (n == n3); + EXPECT_TRUE(ret); +} + +TEST_F(InterestTest, SetGetLocator) { + // Create interest from buffer + Interest interest(Interest::COPY_BUFFER, &buffer_[0], buffer_.size()); + + // Get locator + auto l = interest.getLocator(); + + ip_address_t address; + ip_address_pton("b006::ab:cdab:cdef", &address); + auto ret = !std::memcmp(&l, &address, sizeof(address)); + + EXPECT_TRUE(ret); + + // Set different locator + ip_address_pton("2001::1234::4321::abcd::", &address); + + // Set it on interest + interest.setLocator(address); + + // Check it was set + l = interest.getLocator(); + ret = !std::memcmp(&l, &address, sizeof(address)); + + EXPECT_TRUE(ret); +} + +TEST_F(InterestTest, SetGetLifetime) { + // Create interest from buffer + Interest interest; + const constexpr uint32_t lifetime = 10000; + + // Set lifetime + interest.setLifetime(lifetime); + + // Get lifetime + auto l = interest.getLifetime(); + + // Ensure they are the same + EXPECT_EQ(l, lifetime); +} + +TEST_F(InterestTest, HasManifest) { + // Create interest from buffer + Interest interest; + + // Let's expect anexception here + try { + interest.setPayloadType(PayloadType::UNSPECIFIED); + FAIL() << "We expect an esception here"; + } catch (errors::RuntimeException &exc) { + // Ok right exception + } catch (...) { + FAIL() << "Wrong exception thrown"; + } + + interest.setPayloadType(PayloadType::DATA); + EXPECT_FALSE(interest.hasManifest()); + + interest.setPayloadType(PayloadType::MANIFEST); + EXPECT_TRUE(interest.hasManifest()); +} + +TEST_F(InterestTest, AppendSuffixesEncodeAndIterate) { + // Create interest from buffer + Interest interest; + + // Appenad some suffixes, with some duplicates + interest.appendSuffix(1); + interest.appendSuffix(2); + interest.appendSuffix(5); + interest.appendSuffix(3); + interest.appendSuffix(4); + interest.appendSuffix(5); + interest.appendSuffix(5); + interest.appendSuffix(5); + interest.appendSuffix(5); + interest.appendSuffix(5); + + // Encode them in wire format + interest.encodeSuffixes(); + + // Iterate over them. They should be in order and without repetitions + auto suffix = interest.firstSuffix(); + auto n_suffixes = interest.numberOfSuffixes(); + + for (uint32_t i = 0; i < n_suffixes; i++) { + EXPECT_EQ(*(suffix + i), (i + 1)); + } +} + +} // namespace core +} // namespace transport + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}
\ No newline at end of file diff --git a/libtransport/src/test/test_packet.cc b/libtransport/src/test/test_packet.cc new file mode 100644 index 000000000..0ee140e2c --- /dev/null +++ b/libtransport/src/test/test_packet.cc @@ -0,0 +1,1047 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> +#include <hicn/transport/core/packet.h> +#include <hicn/transport/errors/not_implemented_exception.h> +#include <test/packet_samples.h> + +#include <climits> +#include <random> +#include <vector> + +namespace transport { + +namespace core { + +/** + * Since packet is an abstract class, we derive a concrete class to be used for + * the test. + */ +class PacketForTest : public Packet { + public: + template <typename... Args> + PacketForTest(Args &&... args) : Packet(std::forward<Args>(args)...) {} + + virtual ~PacketForTest() {} + + const Name &getName() const override { + throw errors::NotImplementedException(); + } + + Name &getWritableName() override { throw errors::NotImplementedException(); } + + void setName(const Name &name) override { + throw errors::NotImplementedException(); + } + + void setName(Name &&name) override { + throw errors::NotImplementedException(); + } + + void setLifetime(uint32_t lifetime) override { + throw errors::NotImplementedException(); + } + + uint32_t getLifetime() const override { + throw errors::NotImplementedException(); + } + + void setLocator(const ip_address_t &locator) override { + throw errors::NotImplementedException(); + } + + void resetForHash() override { throw errors::NotImplementedException(); } + + ip_address_t getLocator() const override { + throw errors::NotImplementedException(); + } +}; + +namespace { +// The fixture for testing class Foo. +class PacketTest : public ::testing::Test { + protected: + PacketTest() + : name_("b001::123|321"), + packet(Packet::COPY_BUFFER, &raw_packets_[HF_INET6_TCP][0], + raw_packets_[HF_INET6_TCP].size()) { + // You can do set-up work for each test here. + } + + virtual ~PacketTest() { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + virtual void SetUp() { + // Code here will be called immediately after the constructor (right + // before each test). + } + + virtual void TearDown() { + // Code here will be called immediately after each test (right + // before the destructor). + } + + Name name_; + + PacketForTest packet; + + static std::map<Packet::Format, std::vector<uint8_t>> raw_packets_; + + std::vector<uint8_t> payload = { + 0x11, 0x11, 0x01, 0x00, 0xb0, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad // , 0x00, 0x00, + // 0x00, 0x45, 0xa3, + // 0xd1, 0xf2, 0x2b, + // 0x94, 0x41, 0x22, + // 0xc9, 0x00, 0x00, + // 0x00, 0x44, 0xa3, + // 0xd1, 0xf2, 0x2b, + // 0x94, 0x41, 0x22, + // 0xc8 + }; +}; + +std::map<Packet::Format, std::vector<uint8_t>> PacketTest::raw_packets_ = { + {Packet::Format::HF_INET6_TCP, + + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(TCP_PROTO, 20 + PAYLOAD_SIZE), + // TCP src=0x1234 dst=0x4321, seq=0x0001 + TCP_HEADER(0x00), + // Payload + PAYLOAD}}, + + {Packet::Format::HF_INET_TCP, + {// IPv4 src=3.13.127.8, dst=192.168.1.92 + IPV4_HEADER(TCP_PROTO, 20 + PAYLOAD_SIZE), + // TCP src=0x1234 dst=0x4321, seq=0x0001 + TCP_HEADER(0x00), + // Other + PAYLOAD}}, + + {Packet::Format::HF_INET_ICMP, + {// IPv4 src=3.13.127.8, dst=192.168.1.92 + IPV4_HEADER(ICMP_PROTO, 64), + // ICMP echo request + ICMP_ECHO_REQUEST}}, + + {Packet::Format::HF_INET6_ICMP, + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(ICMP6_PROTO, 60), + // ICMP6 echo request + ICMP6_ECHO_REQUEST}}, + + {Packet::Format::HF_INET6_TCP_AH, + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(TCP_PROTO, 20 + 44 + 128), + // ICMP6 echo request + TCP_HEADER(0x18), + // hICN AH header + AH_HEADER}}, + + {Packet::Format::HF_INET_TCP_AH, + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV4_HEADER(TCP_PROTO, 20 + 44 + 128), + // ICMP6 echo request + TCP_HEADER(0x18), + // hICN AH header + AH_HEADER}}, + + // XXX No flag defined in ICMP header to signal AH header. + {Packet::Format::HF_INET_ICMP_AH, + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV4_HEADER(ICMP_PROTO, 64 + 44), + // ICMP6 echo request + ICMP_ECHO_REQUEST, + // hICN AH header + AH_HEADER}}, + + {Packet::Format::HF_INET6_ICMP_AH, + {// IPv6 src=b001::ab:cdab:cdef, dst=b002::ca + IPV6_HEADER(ICMP6_PROTO, 60 + 44), + // ICMP6 echo request + ICMP6_ECHO_REQUEST, + // hICN AH header + AH_HEADER}}, + +}; + +void testFormatConstructor(Packet::Format format = HF_UNSPEC) { + try { + PacketForTest packet(format); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown for " << format; + } +} + +void testFormatAndAdditionalHeaderConstructor(Packet::Format format, + std::size_t additional_header) { + PacketForTest packet(format, additional_header); + // Packet length should be the one of the normal header + the + // additional_header + + EXPECT_EQ(packet.headerSize(), + Packet::getHeaderSizeFromFormat(format) + additional_header); +} + +void testRawBufferConstructor(std::vector<uint8_t> packet, + Packet::Format format) { + try { + // Try to construct packet from correct buffer + PacketForTest p(Packet::WRAP_BUFFER, &packet[0], packet.size(), + packet.size()); + + // Check format is expected one. + EXPECT_EQ(p.getFormat(), format); + + // // Try the same using a MemBuf + // auto buf = utils::MemBuf::wrapBuffer(&packet[0], packet.size()); + // buf->append(packet.size()); + // PacketForTest p2(std::move(buf)); + + // EXPECT_EQ(p2.getFormat(), format); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown"; + } + + try { + // Try to construct packet from wrong buffer + + // Modify next header to 0 + /* ipv6 */ + packet[6] = 0x00; + /* ipv4 */ + packet[9] = 0x00; + PacketForTest p(Packet::WRAP_BUFFER, &packet[0], packet.size(), + packet.size()); + + // Format should fallback to HF_UNSPEC + EXPECT_EQ(p.getFormat(), HF_UNSPEC); + } catch (...) { + FAIL() << "ERROR: Unexpected exception thrown."; + } +} + +void getHeaderSizeFromBuffer(Packet::Format format, + std::vector<uint8_t> &packet, + std::size_t expected) { + auto header_size = PacketForTest::getHeaderSizeFromBuffer(format, &packet[0]); + EXPECT_EQ(header_size, expected); +} + +void getHeaderSizeFromFormat(Packet::Format format, std::size_t expected) { + auto header_size = PacketForTest::getHeaderSizeFromFormat(format); + EXPECT_EQ(header_size, expected); +} + +void getPayloadSizeFromBuffer(Packet::Format format, + std::vector<uint8_t> &packet, + std::size_t expected) { + auto payload_size = + PacketForTest::getPayloadSizeFromBuffer(format, &packet[0]); + EXPECT_EQ(payload_size, expected); +} + +void getFormatFromBuffer(Packet::Format expected, + std::vector<uint8_t> &packet) { + auto format = PacketForTest::getFormatFromBuffer(&packet[0], packet.size()); + EXPECT_EQ(format, expected); +} + +void getHeaderSize(std::size_t expected, const PacketForTest &packet) { + auto size = packet.headerSize(); + EXPECT_EQ(size, expected); +} + +void testGetFormat(Packet::Format expected, const Packet &packet) { + auto format = packet.getFormat(); + EXPECT_EQ(format, expected); +} + +} // namespace + +TEST_F(PacketTest, ConstructorWithFormat) { + testFormatConstructor(Packet::Format::HF_INET_TCP); + testFormatConstructor(Packet::Format::HF_INET6_TCP); + testFormatConstructor(Packet::Format::HF_INET_ICMP); + testFormatConstructor(Packet::Format::HF_INET6_ICMP); + testFormatConstructor(Packet::Format::HF_INET_TCP_AH); + testFormatConstructor(Packet::Format::HF_INET6_TCP_AH); + testFormatConstructor(Packet::Format::HF_INET_ICMP_AH); + testFormatConstructor(Packet::Format::HF_INET6_ICMP_AH); +} + +TEST_F(PacketTest, ConstructorWithFormatAndAdditionalHeader) { + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET_TCP, 123); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET6_TCP, 360); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET_ICMP, 21); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET6_ICMP, 444); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET_TCP_AH, 555); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET6_TCP_AH, + 321); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET_ICMP_AH, + 123); + testFormatAndAdditionalHeaderConstructor(Packet::Format::HF_INET6_ICMP_AH, + 44); +} + +TEST_F(PacketTest, ConstructorWithNew) { + auto &_packet = raw_packets_[HF_INET6_TCP]; + auto packet_ptr = new PacketForTest(Packet::WRAP_BUFFER, &_packet[0], + _packet.size(), _packet.size()); + (void)packet_ptr; +} + +TEST_F(PacketTest, ConstructorWithRawBufferInet6Tcp) { + auto format = Packet::Format::HF_INET6_TCP; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, ConstructorWithRawBufferInetTcp) { + auto format = Packet::Format::HF_INET_TCP; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, ConstructorWithRawBufferInetIcmp) { + auto format = Packet::Format::HF_INET_ICMP; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, ConstructorWithRawBufferInet6Icmp) { + auto format = Packet::Format::HF_INET6_ICMP; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, ConstructorWithRawBufferInet6TcpAh) { + auto format = Packet::Format::HF_INET6_TCP_AH; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, ConstructorWithRawBufferInetTcpAh) { + auto format = Packet::Format::HF_INET_TCP_AH; + testRawBufferConstructor(raw_packets_[format], format); +} + +TEST_F(PacketTest, MoveConstructor) { + PacketForTest p0(Packet::Format::HF_INET6_TCP); + PacketForTest p1(std::move(p0)); + EXPECT_EQ(p0.getFormat(), Packet::Format::HF_UNSPEC); + EXPECT_EQ(p1.getFormat(), Packet::Format::HF_INET6_TCP); +} + +TEST_F(PacketTest, TestGetHeaderSizeFromBuffer) { + getHeaderSizeFromBuffer(HF_INET6_TCP, raw_packets_[HF_INET6_TCP], + HICN_V6_TCP_HDRLEN); + getHeaderSizeFromBuffer(HF_INET_TCP, raw_packets_[HF_INET_TCP], + HICN_V4_TCP_HDRLEN); + getHeaderSizeFromBuffer(HF_INET6_ICMP, raw_packets_[HF_INET6_ICMP], + IPV6_HDRLEN + 4); + getHeaderSizeFromBuffer(HF_INET_ICMP, raw_packets_[HF_INET_ICMP], + IPV4_HDRLEN + 4); + getHeaderSizeFromBuffer(HF_INET6_TCP_AH, raw_packets_[HF_INET6_TCP_AH], + HICN_V6_TCP_AH_HDRLEN + 128); + getHeaderSizeFromBuffer(HF_INET_TCP_AH, raw_packets_[HF_INET_TCP_AH], + HICN_V4_TCP_AH_HDRLEN + 128); +} + +TEST_F(PacketTest, TestGetHeaderSizeFromFormat) { + getHeaderSizeFromFormat(HF_INET6_TCP, HICN_V6_TCP_HDRLEN); + getHeaderSizeFromFormat(HF_INET_TCP, HICN_V4_TCP_HDRLEN); + getHeaderSizeFromFormat(HF_INET6_ICMP, IPV6_HDRLEN + 4); + getHeaderSizeFromFormat(HF_INET_ICMP, IPV4_HDRLEN + 4); + getHeaderSizeFromFormat(HF_INET6_TCP_AH, HICN_V6_TCP_AH_HDRLEN); + getHeaderSizeFromFormat(HF_INET_TCP_AH, HICN_V4_TCP_AH_HDRLEN); +} + +TEST_F(PacketTest, TestGetPayloadSizeFromBuffer) { + getPayloadSizeFromBuffer(HF_INET6_TCP, raw_packets_[HF_INET6_TCP], 12); + getPayloadSizeFromBuffer(HF_INET_TCP, raw_packets_[HF_INET_TCP], 12); + getPayloadSizeFromBuffer(HF_INET6_ICMP, raw_packets_[HF_INET6_ICMP], 56); + getPayloadSizeFromBuffer(HF_INET_ICMP, raw_packets_[HF_INET_ICMP], 60); + getPayloadSizeFromBuffer(HF_INET6_TCP_AH, raw_packets_[HF_INET6_TCP_AH], 0); + getPayloadSizeFromBuffer(HF_INET_TCP_AH, raw_packets_[HF_INET_TCP_AH], 0); +} + +TEST_F(PacketTest, TestIsInterest) { + auto ret = PacketForTest::isInterest(&raw_packets_[HF_INET6_TCP][0]); + + EXPECT_TRUE(ret); +} + +TEST_F(PacketTest, TestGetFormatFromBuffer) { + getFormatFromBuffer(HF_INET6_TCP, raw_packets_[HF_INET6_TCP]); + getFormatFromBuffer(HF_INET_TCP, raw_packets_[HF_INET_TCP]); + getFormatFromBuffer(HF_INET6_ICMP, raw_packets_[HF_INET6_ICMP]); + getFormatFromBuffer(HF_INET_ICMP, raw_packets_[HF_INET_ICMP]); + getFormatFromBuffer(HF_INET6_TCP_AH, raw_packets_[HF_INET6_TCP_AH]); + getFormatFromBuffer(HF_INET_TCP_AH, raw_packets_[HF_INET_TCP_AH]); +} + +// TEST_F(PacketTest, TestReplace) { +// PacketForTest packet(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_TCP][0], +// raw_packets_[HF_INET6_TCP].size()); + +// // Replace current packet with another one +// packet.replace(&raw_packets_[HF_INET_TCP][0], +// raw_packets_[HF_INET_TCP].size()); + +// // Check new format +// ASSERT_EQ(packet.getFormat(), HF_INET_TCP); +// } + +TEST_F(PacketTest, TestPayloadSize) { + // Check payload size of existing packet + auto &_packet = raw_packets_[HF_INET6_TCP]; + PacketForTest packet(Packet::WRAP_BUFFER, &_packet[0], _packet.size(), + _packet.size()); + + EXPECT_EQ(packet.payloadSize(), std::size_t(PAYLOAD_SIZE)); + + // Check for dynamic generated packet + std::string payload0(1024, 'X'); + + // Create the packet + PacketForTest packet2(HF_INET6_TCP); + + // Payload size should now be zero + EXPECT_EQ(packet2.payloadSize(), std::size_t(0)); + + // Append payload 1 time + packet2.appendPayload((const uint8_t *)payload0.c_str(), payload0.size()); + + // size should now be 1024 + EXPECT_EQ(packet2.payloadSize(), std::size_t(1024)); + + // Append second payload + std::string payload1(1024, 'X'); + packet2.appendPayload((const uint8_t *)payload1.c_str(), payload1.size()); + + // Check size is 2048 + EXPECT_EQ(packet2.payloadSize(), std::size_t(2048)); + + // Append Membuf + packet2.appendPayload(utils::MemBuf::copyBuffer( + (const uint8_t *)payload1.c_str(), payload1.size())); + + // Check size is 3072 + EXPECT_EQ(packet2.payloadSize(), std::size_t(3072)); +} + +TEST_F(PacketTest, TestHeaderSize) { + getHeaderSize(HICN_V6_TCP_HDRLEN, + PacketForTest(Packet::Format::HF_INET6_TCP)); + getHeaderSize(HICN_V4_TCP_HDRLEN, PacketForTest(Packet::Format::HF_INET_TCP)); + getHeaderSize(HICN_V6_ICMP_HDRLEN, + PacketForTest(Packet::Format::HF_INET6_ICMP)); + getHeaderSize(HICN_V4_ICMP_HDRLEN, + PacketForTest(Packet::Format::HF_INET_ICMP)); + getHeaderSize(HICN_V6_TCP_AH_HDRLEN, + PacketForTest(Packet::Format::HF_INET6_TCP_AH)); + getHeaderSize(HICN_V4_TCP_AH_HDRLEN, + PacketForTest(Packet::Format::HF_INET_TCP_AH)); +} + +TEST_F(PacketTest, TestMemBufReference) { + // Create packet + auto &_packet = raw_packets_[HF_INET6_TCP]; + + // Packet was not created as a shared_ptr. If we try to get a membuf shared + // ptr we should get an exception. + // TODO test with c++ 17 + // try { + // PacketForTest packet(&_packet[0], _packet.size()); + // auto membuf_ref = packet.acquireMemBufReference(); + // FAIL() << "The acquireMemBufReference() call should have throwed an " + // "exception!"; + // } catch (const std::bad_weak_ptr &e) { + // // Ok + // } catch (...) { + // FAIL() << "Not expected exception."; + // } + + auto packet_ptr = std::make_shared<PacketForTest>( + Packet::WRAP_BUFFER, &_packet[0], _packet.size(), _packet.size()); + PacketForTest &packet = *packet_ptr; + + // Acquire a reference to the membuf + auto membuf_ref = packet.acquireMemBufReference(); + + // Check refcount. It should be 2 + EXPECT_EQ(membuf_ref.use_count(), 2); + + // Now increment membuf references + Packet::MemBufPtr membuf = packet.acquireMemBufReference(); + + // Now reference count should be 2 + EXPECT_EQ(membuf_ref.use_count(), 3); + + // Copy again + Packet::MemBufPtr membuf2 = membuf; + + // Now reference count should be 3 + EXPECT_EQ(membuf_ref.use_count(), 4); +} + +TEST_F(PacketTest, TestReset) { + // Check everything is ok + EXPECT_EQ(packet.getFormat(), HF_INET6_TCP); + EXPECT_EQ(packet.length(), raw_packets_[HF_INET6_TCP].size()); + EXPECT_EQ(packet.headerSize(), HICN_V6_TCP_HDRLEN); + EXPECT_EQ(packet.payloadSize(), packet.length() - packet.headerSize()); + + // Reset the packet + packet.reset(); + + // Rerun test + EXPECT_EQ(packet.getFormat(), HF_UNSPEC); + EXPECT_EQ(packet.length(), std::size_t(0)); + EXPECT_EQ(packet.headerSize(), std::size_t(0)); + EXPECT_EQ(packet.payloadSize(), std::size_t(0)); +} + +TEST_F(PacketTest, TestAppendPayload) { + // Append payload with raw buffer + uint8_t raw_buffer[2048]; + auto original_payload_length = packet.payloadSize(); + packet.appendPayload(raw_buffer, 1024); + + EXPECT_EQ(original_payload_length + 1024, packet.payloadSize()); + + for (int i = 0; i < 10; i++) { + // Append other payload 10 times + packet.appendPayload(raw_buffer, 1024); + EXPECT_EQ(original_payload_length + 1024 + (1024) * (i + 1), + packet.payloadSize()); + } + + // Append payload using membuf + packet.appendPayload(utils::MemBuf::copyBuffer(raw_buffer, 2048)); + EXPECT_EQ(original_payload_length + 1024 + 1024 * 10 + 2048, + packet.payloadSize()); + + // Check the underlying MemBuf length is the expected one + utils::MemBuf *current = &packet; + size_t total = 0; + do { + total += current->length(); + current = current->next(); + } while (current != &packet); + + EXPECT_EQ(total, packet.headerSize() + packet.payloadSize()); + + // LEt's try now to reset this packet + packet.reset(); + + // There should be no more bufferls left in the chain + EXPECT_EQ(&packet, packet.next()); + EXPECT_EQ(packet.getFormat(), HF_UNSPEC); + EXPECT_EQ(packet.length(), std::size_t(0)); + EXPECT_EQ(packet.headerSize(), std::size_t(0)); + EXPECT_EQ(packet.payloadSize(), std::size_t(0)); +} + +TEST_F(PacketTest, GetPayload) { + // Append payload with raw buffer + uint8_t raw_buffer[2048]; + auto original_payload_length = packet.payloadSize(); + packet.appendPayload(raw_buffer, 2048); + + // Get payload + auto payload = packet.getPayload(); + // Check payload length is correct + utils::MemBuf *current = payload.get(); + size_t total = 0; + do { + total += current->length(); + current = current->next(); + } while (current != payload.get()); + + ASSERT_EQ(total, packet.payloadSize()); + + // Linearize the payload + payload->gather(total); + + // Check memory correspond + payload->trimStart(original_payload_length); + auto ret = memcmp(raw_buffer, payload->data(), 2048); + EXPECT_EQ(ret, 0); +} + +TEST_F(PacketTest, UpdateLength) { + auto original_payload_size = packet.payloadSize(); + + // Add some fake payload without using the API + packet.append(200); + + // payloadSize does not know about the new payload, yet + EXPECT_EQ(packet.payloadSize(), original_payload_size); + + // Let's now update the packet length + packet.updateLength(); + + // Now payloadSize knows + EXPECT_EQ(packet.payloadSize(), std::size_t(original_payload_size + 200)); + + // We may also update the length without adding real content. This is only + // written in the packet header. + packet.updateLength(128); + EXPECT_EQ(packet.payloadSize(), + std::size_t(original_payload_size + 200 + 128)); +} + +TEST_F(PacketTest, SetGetPayloadType) { + auto payload_type = packet.getPayloadType(); + + // It should be normal content object by default + EXPECT_EQ(payload_type, PayloadType::DATA); + + // Set it to be manifest + packet.setPayloadType(PayloadType::MANIFEST); + + // Check it is manifest + payload_type = packet.getPayloadType(); + + EXPECT_EQ(payload_type, PayloadType::MANIFEST); +} + +TEST_F(PacketTest, GetFormat) { + { + PacketForTest p0(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET_TCP][0], + raw_packets_[Packet::Format::HF_INET_TCP].size(), + raw_packets_[Packet::Format::HF_INET_TCP].size()); + testGetFormat(Packet::Format::HF_INET_TCP, p0); + + PacketForTest p1(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET6_TCP][0], + raw_packets_[Packet::Format::HF_INET6_TCP].size(), + raw_packets_[Packet::Format::HF_INET6_TCP].size()); + testGetFormat(Packet::Format::HF_INET6_TCP, p1); + + PacketForTest p2(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET_ICMP][0], + raw_packets_[Packet::Format::HF_INET_ICMP].size(), + raw_packets_[Packet::Format::HF_INET_ICMP].size()); + testGetFormat(Packet::Format::HF_INET_ICMP, p2); + + PacketForTest p3(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET6_ICMP][0], + raw_packets_[Packet::Format::HF_INET6_ICMP].size(), + raw_packets_[Packet::Format::HF_INET6_ICMP].size()); + testGetFormat(Packet::Format::HF_INET6_ICMP, p3); + + PacketForTest p4(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET_TCP_AH][0], + raw_packets_[Packet::Format::HF_INET_TCP_AH].size(), + raw_packets_[Packet::Format::HF_INET_TCP_AH].size()); + testGetFormat(Packet::Format::HF_INET_TCP_AH, p4); + + PacketForTest p5(Packet::WRAP_BUFFER, + &raw_packets_[Packet::Format::HF_INET6_TCP_AH][0], + raw_packets_[Packet::Format::HF_INET6_TCP_AH].size(), + raw_packets_[Packet::Format::HF_INET6_TCP_AH].size()); + testGetFormat(Packet::Format::HF_INET6_TCP_AH, p5); + } + + // Let's try now creating empty packets + { + PacketForTest p0(Packet::Format::HF_INET_TCP); + testGetFormat(Packet::Format::HF_INET_TCP, p0); + + PacketForTest p1(Packet::Format::HF_INET6_TCP); + testGetFormat(Packet::Format::HF_INET6_TCP, p1); + + PacketForTest p2(Packet::Format::HF_INET_ICMP); + testGetFormat(Packet::Format::HF_INET_ICMP, p2); + + PacketForTest p3(Packet::Format::HF_INET6_ICMP); + testGetFormat(Packet::Format::HF_INET6_ICMP, p3); + + PacketForTest p4(Packet::Format::HF_INET_TCP_AH); + testGetFormat(Packet::Format::HF_INET_TCP_AH, p4); + + PacketForTest p5(Packet::Format::HF_INET6_TCP_AH); + testGetFormat(Packet::Format::HF_INET6_TCP_AH, p5); + } +} + +TEST_F(PacketTest, SetGetTestSignatureTimestamp) { + // Let's try to set the signature timestamp in a packet without AH header. We + // expect an exception. + using namespace std::chrono; + uint64_t now = + duration_cast<milliseconds>(system_clock::now().time_since_epoch()) + .count(); + + try { + packet.setSignatureTimestamp(now); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Same fot get method + try { + auto t = packet.getSignatureTimestamp(); + // Let's make compiler happy + (void)t; + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Now let's construct a AH packet, with no additional space for signature + PacketForTest p(HF_INET6_TCP_AH); + p.setSignatureTimestamp(now); + uint64_t now_get = p.getSignatureTimestamp(); + + // Check we got the right value + EXPECT_EQ(now_get, now); +} + +TEST_F(PacketTest, TestSetGetValidationAlgorithm) { + // Let's try to set the validation algorithm in a packet without AH header. We + // expect an exception. + + try { + packet.setValidationAlgorithm(auth::CryptoSuite::RSA_SHA256); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Same fot get method + try { + auto v = packet.getSignatureTimestamp(); + // Let's make compiler happy + (void)v; + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Now let's construct a AH packet, with no additional space for signature + PacketForTest p(HF_INET6_TCP_AH); + p.setValidationAlgorithm(auth::CryptoSuite::RSA_SHA256); + auto v_get = p.getValidationAlgorithm(); + + // Check we got the right value + EXPECT_EQ(v_get, auth::CryptoSuite::RSA_SHA256); +} + +TEST_F(PacketTest, TestSetGetKeyId) { + uint8_t key[32]; + auth::KeyId key_id = std::make_pair(key, sizeof(key)); + + try { + packet.setKeyId(key_id); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Same fot get method + try { + auto k = packet.getKeyId(); + // Let's make compiler happy + (void)k; + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + // Now let's construct a AH packet, with no additional space for signature + PacketForTest p(HF_INET6_TCP_AH); + p.setKeyId(key_id); + auto p_get = p.getKeyId(); + + // Check we got the right value + EXPECT_EQ(p_get.second, key_id.second); + + auto ret = memcmp(p_get.first, key_id.first, p_get.second); + EXPECT_EQ(ret, 0); +} + +TEST_F(PacketTest, DISABLED_TestChecksum) { + // Checksum should be wrong + bool integrity = packet.checkIntegrity(); + EXPECT_FALSE(integrity); + + // Let's fix it + packet.setChecksum(); + + // Check again + integrity = packet.checkIntegrity(); + EXPECT_TRUE(integrity); + + // Check with AH header and 300 bytes signature + PacketForTest p(HF_INET6_TCP_AH, 300); + std::string payload(5000, 'X'); + p.appendPayload((const uint8_t *)payload.c_str(), payload.size() / 2); + p.appendPayload((const uint8_t *)(payload.c_str() + payload.size() / 2), + payload.size() / 2); + + p.setChecksum(); + integrity = p.checkIntegrity(); + EXPECT_TRUE(integrity); +} + +TEST_F(PacketTest, TestSetSyn) { + // Test syn of non-tcp format and check exception is thrown + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setSyn(); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setSyn(); + EXPECT_TRUE(packet.testSyn()); + + packet.resetSyn(); + EXPECT_FALSE(packet.testSyn()); +} + +TEST_F(PacketTest, TestSetFin) { + // Test syn of non-tcp format and check exception is thrown + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setFin(); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setFin(); + EXPECT_TRUE(packet.testFin()); + + packet.resetFin(); + EXPECT_FALSE(packet.testFin()); +} + +TEST_F(PacketTest, TestSetAck) { + // Test syn of non-tcp format and check exception is thrown + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setAck(); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setAck(); + EXPECT_TRUE(packet.testAck()); + + packet.resetAck(); + EXPECT_FALSE(packet.testAck()); +} + +TEST_F(PacketTest, TestSetRst) { + // Test syn of non-tcp format and check exception is thrown + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setRst(); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setRst(); + EXPECT_TRUE(packet.testRst()); + + packet.resetRst(); + EXPECT_FALSE(packet.testRst()); +} + +TEST_F(PacketTest, TestResetFlags) { + packet.setRst(); + packet.setSyn(); + packet.setAck(); + packet.setFin(); + EXPECT_TRUE(packet.testRst()); + EXPECT_TRUE(packet.testAck()); + EXPECT_TRUE(packet.testFin()); + EXPECT_TRUE(packet.testSyn()); + + packet.resetFlags(); + EXPECT_FALSE(packet.testRst()); + EXPECT_FALSE(packet.testAck()); + EXPECT_FALSE(packet.testFin()); + EXPECT_FALSE(packet.testSyn()); +} + +TEST_F(PacketTest, TestSetGetSrcPort) { + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setSrcPort(12345); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setSrcPort(12345); + EXPECT_EQ(packet.getSrcPort(), 12345); +} + +TEST_F(PacketTest, TestSetGetDstPort) { + try { + auto p = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + // Let's make compiler happy + p.setDstPort(12345); + FAIL() << "We should not reach this point."; + } catch (const errors::RuntimeException &exc) { + /* ok right exception*/ + } catch (...) { + FAIL() << "Unexpected exception"; + } + + packet.setDstPort(12345); + EXPECT_EQ(packet.getDstPort(), 12345); +} + +TEST_F(PacketTest, TestEnsureCapacity) { + PacketForTest &p = packet; + + // This shoul be false + auto ret = p.ensureCapacity(raw_packets_[HF_INET6_TCP].size() + 10); + EXPECT_FALSE(ret); + + // This should be true + ret = p.ensureCapacity(raw_packets_[HF_INET6_TCP].size()); + EXPECT_TRUE(ret); + + // This should be true + ret = p.ensureCapacity(raw_packets_[HF_INET6_TCP].size() - 10); + EXPECT_TRUE(ret); + + // Try to trim the packet start + p.trimStart(10); + // Now this should be false + ret = p.ensureCapacity(raw_packets_[HF_INET6_TCP].size()); + EXPECT_FALSE(ret); + + // Create a new packet + auto p2 = PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_ICMP][0], + raw_packets_[HF_INET6_ICMP].size(), + raw_packets_[HF_INET6_ICMP].size()); + + p2.appendPayload(utils::MemBuf::createCombined(2000)); + + // This should be false, since the buffer is chained + ret = p2.ensureCapacity(raw_packets_[HF_INET6_TCP].size() - 10); + EXPECT_FALSE(ret); +} + +TEST_F(PacketTest, TestEnsureCapacityAndFillUnused) { + // Create packet by excluding the payload (So only L3 + L4 headers). The + // payload will be trated as unused tailroom + PacketForTest p = + PacketForTest(Packet::WRAP_BUFFER, &raw_packets_[HF_INET6_TCP][0], + raw_packets_[HF_INET6_TCP].size() - PAYLOAD_SIZE, + raw_packets_[HF_INET6_TCP].size()); + + // Copy original packet payload, which is here trated as a unused tailroom + uint8_t original_payload[PAYLOAD_SIZE]; + uint8_t *payload = &raw_packets_[HF_INET6_TCP][0] + + raw_packets_[HF_INET6_TCP].size() - PAYLOAD_SIZE; + std::memcpy(original_payload, payload, PAYLOAD_SIZE); + + // This should be true and the unused tailroom should be unmodified + auto ret = p.ensureCapacityAndFillUnused( + raw_packets_[HF_INET6_TCP].size() - (PAYLOAD_SIZE + 10), 0); + EXPECT_TRUE(ret); + ret = std::memcmp(original_payload, payload, PAYLOAD_SIZE); + EXPECT_EQ(ret, 0); + + // This should fill the payload with zeros + ret = p.ensureCapacityAndFillUnused(raw_packets_[HF_INET6_TCP].size(), 0); + EXPECT_TRUE(ret); + uint8_t zeros[PAYLOAD_SIZE]; + std::memset(zeros, 0, PAYLOAD_SIZE); + ret = std::memcmp(payload, zeros, PAYLOAD_SIZE); + EXPECT_EQ(ret, 0); + + // This should fill the payload with ones + ret = p.ensureCapacityAndFillUnused(raw_packets_[HF_INET6_TCP].size(), 1); + EXPECT_TRUE(ret); + uint8_t ones[PAYLOAD_SIZE]; + std::memset(ones, 1, PAYLOAD_SIZE); + ret = std::memcmp(payload, ones, PAYLOAD_SIZE); + EXPECT_EQ(ret, 0); + + // This should return false and the payload should be unmodified + ret = p.ensureCapacityAndFillUnused(raw_packets_[HF_INET6_TCP].size() + 1, 1); + EXPECT_FALSE(ret); + ret = std::memcmp(payload, ones, PAYLOAD_SIZE); + EXPECT_EQ(ret, 0); +} + +TEST_F(PacketTest, TestSetGetTTL) { + packet.setTTL(128); + EXPECT_EQ(packet.getTTL(), 128); +} + +} // namespace core +} // namespace transport + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/libtransport/src/transport.config b/libtransport/src/transport.config new file mode 100644 index 000000000..a21175b8d --- /dev/null +++ b/libtransport/src/transport.config @@ -0,0 +1,27 @@ +// Configuration for io_module + +io_module = { + path = []; + name = "forwarder_module"; +}; + +forwarder = { + n_threads = 1; + + listeners = { + l0 = { + local_address = "127.0.0.1"; + local_port = 33436; + } + }; + + connectors = { + c0 = { + /* local_address and local_port are optional */ + local_address = "127.0.0.1"; + local_port = 33436; + remote_address = "10.20.30.40"; + remote_port = 33436; + } + }; +};
\ No newline at end of file diff --git a/libtransport/src/utils/content_store.cc b/libtransport/src/utils/content_store.cc index cb3db6d94..c5cb91149 100644 --- a/libtransport/src/utils/content_store.cc +++ b/libtransport/src/utils/content_store.cc @@ -17,7 +17,6 @@ #include <hicn/transport/core/interest.h> #include <hicn/transport/core/name.h> #include <hicn/transport/utils/log.h> - #include <utils/content_store.h> namespace utils { @@ -60,8 +59,7 @@ void ContentStore::insert( ObjectTimeEntry(content_object, std::chrono::steady_clock::now()), pos); } -const std::shared_ptr<ContentObject> ContentStore::find( - const Interest &interest) { +std::shared_ptr<ContentObject> ContentStore::find(const Interest &interest) { utils::SpinLock::Acquire locked(cs_mutex_); std::shared_ptr<ContentObject> ret = empty_reference_; diff --git a/libtransport/src/utils/content_store.h b/libtransport/src/utils/content_store.h index 03ce76f42..56cd2abb6 100644 --- a/libtransport/src/utils/content_store.h +++ b/libtransport/src/utils/content_store.h @@ -52,7 +52,7 @@ class ContentStore { void insert(const std::shared_ptr<ContentObject> &content_object); - const std::shared_ptr<ContentObject> find(const Interest &interest); + std::shared_ptr<ContentObject> find(const Interest &interest); void erase(const Name &exact_name); diff --git a/libtransport/src/utils/daemonizator.cc b/libtransport/src/utils/daemonizator.cc index c51a68d14..bc7bae700 100644 --- a/libtransport/src/utils/daemonizator.cc +++ b/libtransport/src/utils/daemonizator.cc @@ -17,7 +17,6 @@ #include <hicn/transport/errors/runtime_exception.h> #include <hicn/transport/utils/daemonizator.h> #include <hicn/transport/utils/log.h> - #include <sys/stat.h> #include <unistd.h> diff --git a/libtransport/src/utils/epoll_event_reactor.cc b/libtransport/src/utils/epoll_event_reactor.cc index 63c08df95..eb8c65352 100644 --- a/libtransport/src/utils/epoll_event_reactor.cc +++ b/libtransport/src/utils/epoll_event_reactor.cc @@ -14,12 +14,11 @@ */ #include <hicn/transport/utils/branch_prediction.h> - +#include <signal.h> +#include <unistd.h> #include <utils/epoll_event_reactor.h> #include <utils/fd_deadline_timer.h> -#include <signal.h> -#include <unistd.h> #include <iostream> namespace utils { @@ -111,7 +110,7 @@ void EpollEventReactor::runEventLoop(int timeout) { if (errno == EINTR) { continue; } else { - return; + return; } } diff --git a/libtransport/src/utils/epoll_event_reactor.h b/libtransport/src/utils/epoll_event_reactor.h index 4cb87ebd4..9ebfca937 100644 --- a/libtransport/src/utils/epoll_event_reactor.h +++ b/libtransport/src/utils/epoll_event_reactor.h @@ -16,9 +16,9 @@ #pragma once #include <hicn/transport/utils/spinlock.h> +#include <sys/epoll.h> #include <utils/event_reactor.h> -#include <sys/epoll.h> #include <atomic> #include <cstddef> #include <functional> diff --git a/libtransport/src/utils/fd_deadline_timer.h b/libtransport/src/utils/fd_deadline_timer.h index 8bc3bbca3..38396e027 100644 --- a/libtransport/src/utils/fd_deadline_timer.h +++ b/libtransport/src/utils/fd_deadline_timer.h @@ -17,16 +17,14 @@ #include <hicn/transport/errors/runtime_exception.h> #include <hicn/transport/utils/log.h> - +#include <sys/timerfd.h> +#include <unistd.h> #include <utils/deadline_timer.h> #include <utils/epoll_event_reactor.h> #include <chrono> #include <cstddef> -#include <sys/timerfd.h> -#include <unistd.h> - namespace utils { class FdDeadlineTimer : public DeadlineTimer<FdDeadlineTimer> { diff --git a/libtransport/src/utils/membuf.cc b/libtransport/src/utils/membuf.cc index 94e5b13a1..73c45cf6d 100644 --- a/libtransport/src/utils/membuf.cc +++ b/libtransport/src/utils/membuf.cc @@ -145,6 +145,18 @@ void MemBuf::operator delete(void* /* ptr */, void* /* placement */) { // constructor. } +bool MemBuf::operator==(const MemBuf& other) { + if (length() != other.length()) { + return false; + } + + return (memcmp(data(), other.data(), length()) == 0); +} + +bool MemBuf::operator!=(const MemBuf& other) { + return !this->operator==(other); +} + void MemBuf::releaseStorage(HeapStorage* storage, uint16_t freeFlags) { // Use relaxed memory order here. If we are unlucky and happen to get // out-of-date data the compare_exchange_weak() call below will catch @@ -299,21 +311,23 @@ unique_ptr<MemBuf> MemBuf::takeOwnership(void* buf, std::size_t capacity, } } -MemBuf::MemBuf(WrapBufferOp, const void* buf, std::size_t capacity) noexcept +MemBuf::MemBuf(WrapBufferOp, const void* buf, std::size_t length, + std::size_t capacity) noexcept : MemBuf(InternalConstructor(), 0, // We cast away the const-ness of the buffer here. // This is okay since MemBuf users must use unshare() to create a // copy of this buffer before writing to the buffer. static_cast<uint8_t*>(const_cast<void*>(buf)), capacity, - static_cast<uint8_t*>(const_cast<void*>(buf)), capacity) {} + static_cast<uint8_t*>(const_cast<void*>(buf)), length) {} -unique_ptr<MemBuf> MemBuf::wrapBuffer(const void* buf, std::size_t capacity) { - return std::make_unique<MemBuf>(WRAP_BUFFER, buf, capacity); +unique_ptr<MemBuf> MemBuf::wrapBuffer(const void* buf, std::size_t length, + std::size_t capacity) { + return std::make_unique<MemBuf>(WRAP_BUFFER, buf, length, capacity); } -MemBuf MemBuf::wrapBufferAsValue(const void* buf, +MemBuf MemBuf::wrapBufferAsValue(const void* buf, std::size_t length, std::size_t capacity) noexcept { - return MemBuf(WrapBufferOp::WRAP_BUFFER, buf, capacity); + return MemBuf(WrapBufferOp::WRAP_BUFFER, buf, length, capacity); } MemBuf::MemBuf() noexcept {} @@ -862,4 +876,22 @@ void MemBuf::initExtBuffer(uint8_t* buf, size_t mallocSize, *infoReturn = sharedInfo; } +bool MemBuf::ensureCapacity(std::size_t capacity) { + return !isChained() && std::size_t((bufferEnd() - data())) >= capacity; +} + +bool MemBuf::ensureCapacityAndFillUnused(std::size_t capacity, + uint8_t placeholder) { + auto ret = ensureCapacity(capacity); + if (!ret) { + return ret; + } + + if (length() < capacity) { + std::memset(writableTail(), placeholder, capacity - length()); + } + + return ret; +} + } // namespace utils
\ No newline at end of file diff --git a/libtransport/src/utils/memory_pool_allocator.h b/libtransport/src/utils/memory_pool_allocator.h index adc1443ad..a960b91bb 100644 --- a/libtransport/src/utils/memory_pool_allocator.h +++ b/libtransport/src/utils/memory_pool_allocator.h @@ -149,4 +149,4 @@ class Allocator : private MemoryPool<T, growSize> { void destroy(pointer p) { p->~T(); } }; -}
\ No newline at end of file +} // namespace utils
\ No newline at end of file diff --git a/libtransport/src/utils/min_filter.h b/libtransport/src/utils/min_filter.h index dcfd5652d..f1aaea7a8 100644 --- a/libtransport/src/utils/min_filter.h +++ b/libtransport/src/utils/min_filter.h @@ -43,6 +43,11 @@ class MinFilter { by_arrival_.push_front(by_order_.insert(std::forward<R>(value))); } + TRANSPORT_ALWAYS_INLINE void clear() { + by_arrival_.clear(); + by_order_.clear(); + } + TRANSPORT_ALWAYS_INLINE const T& begin() { return *by_order_.cbegin(); } TRANSPORT_ALWAYS_INLINE const T& rBegin() { return *by_order_.crbegin(); } |