diff options
Diffstat (limited to 'libtransport/src')
132 files changed, 2568 insertions, 4629 deletions
diff --git a/libtransport/src/auth/signer.cc b/libtransport/src/auth/signer.cc index 918e271f5..f13df53eb 100644 --- a/libtransport/src/auth/signer.cc +++ b/libtransport/src/auth/signer.cc @@ -15,6 +15,7 @@ #include <glog/logging.h> #include <hicn/transport/auth/signer.h> +#include <hicn/transport/core/interest.h> #include <hicn/transport/utils/chrono_typedefs.h> #include "hicn/transport/core/global_object_pool.h" @@ -50,6 +51,15 @@ void Signer::signPacket(PacketPtr packet) { hicn_header_t header_copy; hicn_packet_copy_header(format, packet->packet_start_, &header_copy, false); + // Copy bitmap from interest manifest + uint32_t request_bitmap[BITMAP_SIZE] = {0}; + if (packet->isInterest()) { + core::Interest *interest = dynamic_cast<core::Interest *>(packet); + if (interest->hasManifest()) + memcpy(request_bitmap, interest->getRequestBitmap(), + BITMAP_SIZE * sizeof(uint32_t)); + } + // Fill in the hICN AH header auto now = utils::SteadyTime::nowMs().count(); packet->setSignatureTimestamp(now); @@ -69,6 +79,12 @@ void Signer::signPacket(PacketPtr packet) { // Restore header hicn_packet_copy_header(format, &header_copy, packet->packet_start_, false); + + // Restore bitmap in interest manifest + if (packet->isInterest()) { + core::Interest *interest = dynamic_cast<core::Interest *>(packet); + interest->setRequestBitmap(request_bitmap); + } } void Signer::signBuffer(const std::vector<uint8_t> &buffer) { @@ -241,16 +257,23 @@ void AsymmetricSigner::setKey(CryptoSuite suite, std::shared_ptr<EVP_PKEY> key, std::shared_ptr<EVP_PKEY> pub_key) { suite_ = suite; key_ = key; - signature_len_ = EVP_PKEY_size(key.get()); + + signature_len_ = EVP_PKEY_size(key_.get()); DCHECK(signature_len_ <= signature_->tailroom()); + signature_->setLength(signature_len_); - std::vector<uint8_t> pbk(i2d_PublicKey(pub_key.get(), nullptr)); - uint8_t *pbk_ptr = pbk.data(); - int len = i2d_PublicKey(pub_key.get(), &pbk_ptr); + size_t enc_pbk_len = i2d_PublicKey(pub_key.get(), nullptr); + DCHECK(enc_pbk_len >= 0); + + uint8_t *enc_pbkey_raw = nullptr; + i2d_PublicKey(pub_key.get(), &enc_pbkey_raw); + DCHECK(enc_pbkey_raw != nullptr); key_id_ = CryptoHash(getHashType()); - key_id_.computeDigest(pbk_ptr, len); + key_id_.computeDigest(enc_pbkey_raw, enc_pbk_len); + + OPENSSL_free(enc_pbkey_raw); } size_t AsymmetricSigner::getSignatureFieldSize() const { diff --git a/libtransport/src/auth/verifier.cc b/libtransport/src/auth/verifier.cc index 5d5f01711..e257582f6 100644 --- a/libtransport/src/auth/verifier.cc +++ b/libtransport/src/auth/verifier.cc @@ -15,6 +15,7 @@ #include <hicn/transport/auth/verifier.h> #include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/core/interest.h> #include <protocols/errors.h> #include "glog/logging.h" @@ -51,6 +52,14 @@ bool Verifier::verifyPacket(PacketPtr packet) { hicn_header_t header_copy; hicn_packet_copy_header(format, packet->packet_start_, &header_copy, false); + // Copy bitmap from interest manifest + uint32_t request_bitmap[BITMAP_SIZE] = {0}; + if (packet->isInterest()) { + core::Interest *interest = dynamic_cast<core::Interest *>(packet); + memcpy(request_bitmap, interest->getRequestBitmap(), + BITMAP_SIZE * sizeof(uint32_t)); + } + // Retrieve packet signature utils::MemBuf::Ptr signature_raw = packet->getSignature(); std::size_t signature_len = packet->getSignatureSize(); @@ -69,6 +78,12 @@ bool Verifier::verifyPacket(PacketPtr packet) { packet->setSignature(signature_raw); packet->setSignatureSize(signature_raw->length()); + // Restore bitmap in interest manifest + if (packet->isInterest()) { + core::Interest *interest = dynamic_cast<core::Interest *>(packet); + interest->setRequestBitmap(request_bitmap); + } + return valid_packet; } diff --git a/libtransport/src/core/CMakeLists.txt b/libtransport/src/core/CMakeLists.txt index b9b024d60..777772a04 100644 --- a/libtransport/src/core/CMakeLists.txt +++ b/libtransport/src/core/CMakeLists.txt @@ -14,17 +14,18 @@ list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/facade.h ${CMAKE_CURRENT_SOURCE_DIR}/manifest.h - ${CMAKE_CURRENT_SOURCE_DIR}/manifest_inline.h ${CMAKE_CURRENT_SOURCE_DIR}/manifest_format_fixed.h ${CMAKE_CURRENT_SOURCE_DIR}/manifest_format.h ${CMAKE_CURRENT_SOURCE_DIR}/pending_interest.h ${CMAKE_CURRENT_SOURCE_DIR}/portal.h ${CMAKE_CURRENT_SOURCE_DIR}/errors.h ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.h + ${CMAKE_CURRENT_SOURCE_DIR}/global_id_counter.h ${CMAKE_CURRENT_SOURCE_DIR}/local_connector.h ${CMAKE_CURRENT_SOURCE_DIR}/global_workers.h ${CMAKE_CURRENT_SOURCE_DIR}/udp_connector.h ${CMAKE_CURRENT_SOURCE_DIR}/udp_listener.h + ${CMAKE_CURRENT_SOURCE_DIR}/global_module_manager.h ) list(APPEND SOURCE_FILES @@ -38,9 +39,9 @@ list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/portal.cc ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.cc ${CMAKE_CURRENT_SOURCE_DIR}/io_module.cc - ${CMAKE_CURRENT_SOURCE_DIR}/local_connector.cc ${CMAKE_CURRENT_SOURCE_DIR}/udp_connector.cc ${CMAKE_CURRENT_SOURCE_DIR}/udp_listener.cc + ${CMAKE_CURRENT_SOURCE_DIR}/constructor.cc ) if (NOT ${CMAKE_SYSTEM_NAME} MATCHES Android) @@ -56,4 +57,4 @@ if (NOT ${CMAKE_SYSTEM_NAME} MATCHES Android) endif() set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) -set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE)
\ No newline at end of file +set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE) diff --git a/libtransport/src/core/constructor.cc b/libtransport/src/core/constructor.cc new file mode 100644 index 000000000..0c7f0dfa8 --- /dev/null +++ b/libtransport/src/core/constructor.cc @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/global_configuration.h> +#include <core/global_module_manager.h> +#include <core/global_workers.h> +#include <hicn/transport/core/global_object_pool.h> + +namespace transport { +namespace core { + +void __attribute__((constructor)) libtransportInit() { + // First the global module manager is initialized + GlobalModuleManager::getInstance(); + // Then the packet allocator is initialized + PacketManager<>::getInstance(); + // Then the global configuration is initialized + GlobalConfiguration::getInstance(); + // Then the global workers are initialized + GlobalWorkers::getInstance(); +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/content_object.cc b/libtransport/src/core/content_object.cc index 643e0388e..e66b2a6cd 100644 --- a/libtransport/src/core/content_object.cc +++ b/libtransport/src/core/content_object.cc @@ -15,6 +15,7 @@ #include <hicn/transport/core/content_object.h> #include <hicn/transport/errors/errors.h> +#include <hicn/transport/portability/endianess.h> #include <hicn/transport/utils/branch_prediction.h> extern "C" { @@ -46,7 +47,7 @@ ContentObject::ContentObject(const Name &name, Packet::Format format, } if (TRANSPORT_EXPECT_FALSE(hicn_data_get_name(format_, packet_start_, - name_.getStructReference()) < + &name_.getStructReference()) < 0)) { throw errors::MalformedPacketException(); } @@ -91,7 +92,7 @@ ContentObject::~ContentObject() {} const Name &ContentObject::getName() const { if (!name_) { if (hicn_data_get_name(format_, packet_start_, - (hicn_name_t *)name_.getConstStructReference()) < + (hicn_name_t *)&name_.getConstStructReference()) < 0) { throw errors::MalformedPacketException(); } @@ -104,11 +105,11 @@ Name &ContentObject::getWritableName() { return const_cast<Name &>(getName()); } void ContentObject::setName(const Name &name) { if (hicn_data_set_name(format_, packet_start_, - name.getConstStructReference()) < 0) { + &name.getConstStructReference()) < 0) { throw errors::RuntimeException("Error setting content object name."); } - if (hicn_data_get_name(format_, packet_start_, name_.getStructReference()) < + if (hicn_data_get_name(format_, packet_start_, &name_.getStructReference()) < 0) { throw errors::MalformedPacketException(); } @@ -121,11 +122,11 @@ uint32_t ContentObject::getPathLabel() const { "Error retrieving the path label from content object"); } - return ntohl(path_label); + return portability::net_to_host(path_label); } ContentObject &ContentObject::setPathLabel(uint32_t path_label) { - path_label = htonl(path_label); + path_label = portability::host_to_net(path_label); if (hicn_data_set_path_label((hicn_header_t *)packet_start_, path_label) < 0) { throw errors::RuntimeException( diff --git a/libtransport/src/core/facade.h b/libtransport/src/core/facade.h index 1ad4437e2..77c1d16d2 100644 --- a/libtransport/src/core/facade.h +++ b/libtransport/src/core/facade.h @@ -15,16 +15,16 @@ #pragma once +#include <core/manifest.h> #include <core/manifest_format_fixed.h> -#include <core/manifest_inline.h> #include <core/portal.h> namespace transport { namespace core { -using ContentObjectManifest = core::ManifestInline<ContentObject, Fixed>; -using InterestManifest = core::ManifestInline<Interest, Fixed>; +using ContentObjectManifest = core::Manifest<Fixed>; +using InterestManifest = core::Manifest<Fixed>; } // namespace core diff --git a/libtransport/src/io_modules/forwarder/global_id_counter.h b/libtransport/src/core/global_id_counter.h index 0a67b76d5..0a67b76d5 100644 --- a/libtransport/src/io_modules/forwarder/global_id_counter.h +++ b/libtransport/src/core/global_id_counter.h diff --git a/libtransport/src/core/global_module_manager.h b/libtransport/src/core/global_module_manager.h new file mode 100644 index 000000000..c9d272cdb --- /dev/null +++ b/libtransport/src/core/global_module_manager.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2022 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <glog/logging.h> +#include <hicn/transport/utils/singleton.h> + +#ifndef _WIN32 +#include <dlfcn.h> +#endif + +#include <atomic> +#include <iostream> +#include <mutex> +#include <unordered_map> + +namespace transport { +namespace core { + +class GlobalModuleManager : public utils::Singleton<GlobalModuleManager> { + public: + friend class utils::Singleton<GlobalModuleManager>; + + ~GlobalModuleManager() { + for (const auto &[key, value] : modules_) { + unload(value); + } + } + + void *loadModule(const std::string &module_name) { + void *handle = nullptr; + const char *error = nullptr; + + // Lock + std::unique_lock lck(mtx_); + + auto it = modules_.find(module_name); + if (it != modules_.end()) { + return it->second; + } + + // open module + handle = dlopen(module_name.c_str(), RTLD_NOW); + if (!handle) { + if ((error = dlerror()) != nullptr) { + LOG(ERROR) << error; + } + return nullptr; + } + + auto ret = modules_.try_emplace(module_name, handle); + DCHECK(ret.second); + + return handle; + } + + void unload(void *handle) { + // destroy object and close module + dlclose(handle); + } + + bool unloadModule(const std::string &module_name) { + // Lock + std::unique_lock lck(mtx_); + auto it = modules_.find(module_name); + if (it != modules_.end()) { + unload(it->second); + return true; + } + + return false; + } + + private: + GlobalModuleManager() = default; + std::mutex mtx_; + std::unordered_map<std::string, void *> modules_; +}; + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/global_workers.h b/libtransport/src/core/global_workers.h index 1ac254188..c5d794ef2 100644 --- a/libtransport/src/core/global_workers.h +++ b/libtransport/src/core/global_workers.h @@ -32,6 +32,8 @@ class GlobalWorkers : public utils::Singleton<GlobalWorkers> { return thread_pool_.getWorker(counter_++ % thread_pool_.getNThreads()); } + auto& getWorkers() { return thread_pool_.getWorkers(); } + private: GlobalWorkers() : counter_(0), thread_pool_() {} diff --git a/libtransport/src/core/interest.cc b/libtransport/src/core/interest.cc index b7719b3ed..8b9dcf256 100644 --- a/libtransport/src/core/interest.cc +++ b/libtransport/src/core/interest.cc @@ -21,6 +21,7 @@ extern "C" { #ifndef _WIN32 TRANSPORT_CLANG_DISABLE_WARNING("-Wextern-c-compat") #endif +#include <hicn/base.h> #include <hicn/hicn.h> } @@ -39,12 +40,12 @@ Interest::Interest(const Name &interest_name, Packet::Format format, } if (hicn_interest_set_name(format_, packet_start_, - interest_name.getConstStructReference()) < 0) { + &interest_name.getConstStructReference()) < 0) { throw errors::MalformedPacketException(); } if (hicn_interest_get_name(format_, packet_start_, - name_.getStructReference()) < 0) { + &name_.getStructReference()) < 0) { throw errors::MalformedPacketException(); } } @@ -64,7 +65,7 @@ Interest::Interest(hicn_format_t format, std::size_t additional_header_size) Interest::Interest(MemBuf &&buffer) : Packet(std::move(buffer)) { if (hicn_interest_get_name(format_, packet_start_, - name_.getStructReference()) < 0) { + &name_.getStructReference()) < 0) { throw errors::MalformedPacketException(); } } @@ -86,9 +87,9 @@ Interest::~Interest() {} const Name &Interest::getName() const { if (!name_) { - if (hicn_interest_get_name(format_, packet_start_, - (hicn_name_t *)name_.getConstStructReference()) < - 0) { + if (hicn_interest_get_name( + format_, packet_start_, + (hicn_name_t *)&name_.getConstStructReference()) < 0) { throw errors::MalformedPacketException(); } } @@ -100,12 +101,12 @@ Name &Interest::getWritableName() { return const_cast<Name &>(getName()); } void Interest::setName(const Name &name) { if (hicn_interest_set_name(format_, packet_start_, - name.getConstStructReference()) < 0) { + &name.getConstStructReference()) < 0) { throw errors::RuntimeException("Error setting interest name."); } if (hicn_interest_get_name(format_, packet_start_, - name_.getStructReference()) < 0) { + &name_.getStructReference()) < 0) { throw errors::MalformedPacketException(); } } @@ -150,6 +151,13 @@ void Interest::resetForHash() { throw errors::RuntimeException( "Error resetting interest fields for hash computation."); } + + // Reset request bitmap in manifest + if (hasManifest()) { + auto int_manifest_header = + (interest_manifest_header_t *)(writableData() + headerSize()); + memset(int_manifest_header->request_bitmap, 0, BITMAP_SIZE * sizeof(u32)); + } } bool Interest::hasManifest() { @@ -171,19 +179,21 @@ void Interest::encodeSuffixes() { // We assume interest does not hold signature for the moment. auto int_manifest_header = - (InterestManifestHeader *)(writableData() + headerSize()); - int_manifest_header->n_suffixes = suffix_set_.size(); - std::size_t additional_length = - sizeof(InterestManifestHeader) + - int_manifest_header->n_suffixes * sizeof(uint32_t); + (interest_manifest_header_t *)(writableData() + headerSize()); + int_manifest_header->n_suffixes = (uint32_t)suffix_set_.size(); + memset(int_manifest_header->request_bitmap, 0xFFFFFFFF, + BITMAP_SIZE * sizeof(u32)); uint32_t *suffix = (uint32_t *)(int_manifest_header + 1); for (auto it = suffix_set_.begin(); it != suffix_set_.end(); it++, suffix++) { *suffix = *it; } + std::size_t additional_length = + sizeof(interest_manifest_header_t) + + int_manifest_header->n_suffixes * sizeof(uint32_t); append(additional_length); - updateLength(additional_length); + updateLength(); } uint32_t *Interest::firstSuffix() { @@ -191,7 +201,7 @@ uint32_t *Interest::firstSuffix() { return nullptr; } - auto ret = (InterestManifestHeader *)(writableData() + headerSize()); + auto ret = (interest_manifest_header_t *)(writableData() + headerSize()); ret += 1; return (uint32_t *)ret; @@ -202,11 +212,48 @@ uint32_t Interest::numberOfSuffixes() { return 0; } - auto header = (InterestManifestHeader *)(writableData() + headerSize()); + auto header = (interest_manifest_header_t *)(writableData() + headerSize()); return header->n_suffixes; } +uint32_t *Interest::getRequestBitmap() { + if (!hasManifest()) return nullptr; + + auto header = (interest_manifest_header_t *)(writableData() + headerSize()); + return header->request_bitmap; +} + +void Interest::setRequestBitmap(const uint32_t *request_bitmap) { + if (!hasManifest()) return; + + auto header = (interest_manifest_header_t *)(writableData() + headerSize()); + memcpy(header->request_bitmap, request_bitmap, + BITMAP_SIZE * sizeof(uint32_t)); +} + +bool Interest::isValid() { + if (!hasManifest()) return true; + + auto header = (interest_manifest_header_t *)(writableData() + headerSize()); + + if (header->n_suffixes == 0 || + header->n_suffixes > MAX_SUFFIXES_IN_MANIFEST) { + std::cerr << "Manifest with invalid number of suffixes " + << header->n_suffixes; + return false; + } + + uint32_t empty_bitmap[BITMAP_SIZE]; + memset(empty_bitmap, 0, sizeof(empty_bitmap)); + if (memcmp(empty_bitmap, header->request_bitmap, sizeof(empty_bitmap)) == 0) { + std::cerr << "Manifest with empty bitmap"; + return false; + } + + return true; +} + } // end namespace core } // end namespace transport diff --git a/libtransport/src/core/io_module.cc b/libtransport/src/core/io_module.cc index 69e4e8bcf..0f92cc47c 100644 --- a/libtransport/src/core/io_module.cc +++ b/libtransport/src/core/io_module.cc @@ -16,6 +16,7 @@ #ifndef _WIN32 #include <dlfcn.h> #endif +#include <core/global_module_manager.h> #include <glog/logging.h> #include <hicn/transport/core/io_module.h> @@ -38,54 +39,28 @@ IoModule *IoModule::load(const char *module_name) { #ifdef ANDROID return new HicnForwarderModule(); #else - void *handle = 0; - IoModule *module = 0; - IoModule *(*creator)(void) = 0; - const char *error = 0; + IoModule *iomodule = nullptr; + IoModule *(*creator)(void) = nullptr; + const char *error = nullptr; - // open module - handle = dlopen(module_name, RTLD_NOW); - if (!handle) { - if ((error = dlerror()) != 0) { - LOG(ERROR) << error; - } - return 0; - } + auto handle = GlobalModuleManager::getInstance().loadModule(module_name); // get factory method creator = (IoModule * (*)(void)) dlsym(handle, "create_module"); if (!creator) { - if ((error = dlerror()) != 0) { + if ((error = dlerror()) != nullptr) { LOG(ERROR) << error; } - return 0; + return nullptr; } // create object and return it - module = (*creator)(); - module->handle_ = handle; + iomodule = (*creator)(); - return module; + return iomodule; #endif } -bool IoModule::unload(IoModule *module) { - if (!module) { - return false; - } - -#ifdef ANDROID - delete module; -#else - // destroy object and close module - void *handle = module->handle_; - delete module; - dlclose(handle); -#endif - - return true; -} - } // namespace core } // namespace transport diff --git a/libtransport/src/core/local_connector.cc b/libtransport/src/core/local_connector.cc deleted file mode 100644 index f27be2e5c..000000000 --- a/libtransport/src/core/local_connector.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <core/local_connector.h> -#include <glog/logging.h> -#include <hicn/transport/core/asio_wrapper.h> -#include <hicn/transport/core/content_object.h> -#include <hicn/transport/core/interest.h> -#include <hicn/transport/errors/not_implemented_exception.h> - -namespace transport { -namespace core { - -LocalConnector::~LocalConnector() {} - -void LocalConnector::close() { state_ = State::CLOSED; } - -void LocalConnector::send(Packet &packet) { - if (!isConnected()) { - return; - } - - auto buffer = - std::static_pointer_cast<utils::MemBuf>(packet.shared_from_this()); - - DLOG_IF(INFO, VLOG_IS_ON(3)) << "Sending packet to local socket."; - io_service_.get().post([this, buffer]() mutable { - std::vector<utils::MemBuf::Ptr> v{std::move(buffer)}; - receive_callback_(this, v, std::make_error_code(std::errc(0))); - }); -} - -void LocalConnector::send(const utils::MemBuf::Ptr &buffer) { - throw errors::NotImplementedException(); -} - -} // namespace core -} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/local_connector.h b/libtransport/src/core/local_connector.h index eede89e74..963f455e6 100644 --- a/libtransport/src/core/local_connector.h +++ b/libtransport/src/core/local_connector.h @@ -15,9 +15,11 @@ #pragma once +#include <core/errors.h> #include <hicn/transport/core/asio_wrapper.h> #include <hicn/transport/core/connector.h> #include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/errors/not_implemented_exception.h> #include <hicn/transport/utils/move_wrapper.h> #include <hicn/transport/utils/shared_ptr_utils.h> #include <io_modules/forwarder/errors.h> @@ -34,19 +36,48 @@ class LocalConnector : public Connector { OnClose &&close_callback, OnReconnect &&on_reconnect) : Connector(receive_callback, packet_sent, close_callback, on_reconnect), io_service_(io_service), - io_service_work_(io_service_.get()) { - state_ = State::CONNECTED; - } + io_service_work_(io_service_.get()) {} - ~LocalConnector() override; + ~LocalConnector() override = default; - void send(Packet &packet) override; + auto shared_from_this() { return utils::shared_from(this); } - void send(const utils::MemBuf::Ptr &buffer) override; + void send(Packet &packet) override { send(packet.shared_from_this()); } - void close() override; + void send(const utils::MemBuf::Ptr &buffer) override { + throw errors::NotImplementedException(); + } - auto shared_from_this() { return utils::shared_from(this); } + void receive(const std::vector<utils::MemBuf::Ptr> &buffers) override { + DLOG_IF(INFO, VLOG_IS_ON(3)) << "Sending packet to local socket."; + std::weak_ptr<LocalConnector> self = shared_from_this(); + io_service_.get().post([self, _buffers{std::move(buffers)}]() mutable { + if (auto ptr = self.lock()) { + ptr->receive_callback_(ptr.get(), _buffers, + make_error_code(core_error::success)); + } + }); + } + + void reconnect() override { + state_ = State::CONNECTED; + std::weak_ptr<LocalConnector> self = shared_from_this(); + io_service_.get().post([self]() { + if (auto ptr = self.lock()) { + ptr->on_reconnect_callback_(ptr.get(), + make_error_code(core_error::success)); + } + }); + } + + void close() override { + std::weak_ptr<LocalConnector> self = shared_from_this(); + io_service_.get().post([self]() mutable { + if (auto ptr = self.lock()) { + ptr->on_close_callback_(ptr.get()); + } + }); + } private: std::reference_wrapper<asio::io_service> io_service_; diff --git a/libtransport/src/core/manifest.cc b/libtransport/src/core/manifest.cc deleted file mode 100644 index da2689426..000000000 --- a/libtransport/src/core/manifest.cc +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <hicn/transport/core/manifest.h> - -namespace transport { - -namespace core { - -std::string ManifestEncoding::manifest_type = std::string("manifest_type"); - -std::map<ManifestType, std::string> ManifestEncoding::manifest_types = { - {FINAL_CHUNK_NUMBER, "FinalChunkNumber"}, {NAME_LIST, "NameList"}}; - -std::string ManifestEncoding::final_chunk_number = - std::string("final_chunk_number"); -std::string ManifestEncoding::content_name = std::string("content_name"); - -} // end namespace core - -} // end namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/manifest.h b/libtransport/src/core/manifest.h index 5bdbfc6ff..40832bb6b 100644 --- a/libtransport/src/core/manifest.h +++ b/libtransport/src/core/manifest.h @@ -17,165 +17,72 @@ #include <core/manifest_format.h> #include <glog/logging.h> -#include <hicn/transport/core/content_object.h> -#include <hicn/transport/core/name.h> - -#include <set> +#include <hicn/transport/auth/verifier.h> +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/core/packet.h> namespace transport { - namespace core { -using typename core::Name; -using typename core::Packet; -using typename core::PayloadType; - -template <typename Base, typename FormatTraits, typename ManifestImpl> -class Manifest : public Base { - static_assert(std::is_base_of<Packet, Base>::value, - "Base must inherit from packet!"); - +template <typename FormatTraits> +class Manifest : public FormatTraits::Encoder, public FormatTraits::Decoder { public: - // core::ContentObjectManifest::Ptr + using Ptr = std::shared_ptr<Manifest>; using Encoder = typename FormatTraits::Encoder; using Decoder = typename FormatTraits::Decoder; - Manifest(Packet::Format format, std::size_t signature_size = 0) - : Base(format, signature_size), - encoder_(*this, signature_size), - decoder_(*this) { - DCHECK(_is_ah(format)); - Base::setPayloadType(PayloadType::MANIFEST); - } + using Hash = typename FormatTraits::Hash; + using HashType = typename FormatTraits::HashType; + using Suffix = typename FormatTraits::Suffix; + using SuffixList = typename FormatTraits::SuffixList; + using HashEntry = std::pair<auth::CryptoHashType, std::vector<uint8_t>>; - Manifest(Packet::Format format, const core::Name &name, - std::size_t signature_size = 0) - : Base(name, format, signature_size), - encoder_(*this, signature_size), - decoder_(*this) { - DCHECK(_is_ah(format)); - Base::setPayloadType(PayloadType::MANIFEST); + Manifest(Packet::Ptr packet, bool clear = false) + : Encoder(packet, clear), Decoder(packet), packet_(packet) { + packet->setPayloadType(PayloadType::MANIFEST); } - template <typename T> - Manifest(T &&base) - : Base(std::forward<T &&>(base)), - encoder_(*this, 0, false), - decoder_(*this) { - Base::setPayloadType(PayloadType::MANIFEST); - } - - // Useful for decoding manifests while avoiding packet copy - template <typename T> - Manifest(T &base) - : Base(base.getFormat()), encoder_(base, 0, false), decoder_(base) {} - virtual ~Manifest() = default; - std::size_t estimateManifestSize(std::size_t additional_entries = 0) { - return static_cast<ManifestImpl &>(*this).estimateManifestSizeImpl( - additional_entries); - } - - /* - * After the call to encode, users MUST call clear before adding data - * to the manifest. - */ - Manifest &encode() { return static_cast<ManifestImpl &>(*this).encodeImpl(); } - - Manifest &decode() { - Manifest::decoder_.decode(); - - manifest_type_ = decoder_.getType(); - manifest_transport_type_ = decoder_.getTransportType(); - hash_algorithm_ = decoder_.getHashAlgorithm(); - is_last_ = decoder_.getIsLast(); + Packet::Ptr getPacket() const { return packet_; } - return static_cast<ManifestImpl &>(*this).decodeImpl(); + void setHeaders(ManifestType type, uint8_t max_capacity, HashType hash_algo, + bool is_last, const Name &base_name) { + Encoder::setType(type); + Encoder::setMaxCapacity(max_capacity); + Encoder::setHashAlgorithm(hash_algo); + Encoder::setIsLast(is_last); + Encoder::setBaseName(base_name); } - static std::size_t manifestHeaderSize( - interface::ProductionProtocolAlgorithms transport_type = - interface::ProductionProtocolAlgorithms::UNKNOWN) { - return Encoder::manifestHeaderSize(transport_type); - } + auth::Verifier::SuffixMap getSuffixMap() const { + auth::Verifier::SuffixMap suffix_map; - static std::size_t manifestEntrySize() { - return Encoder::manifestEntrySize(); - } + HashType hash_algo = Decoder::getHashAlgorithm(); + SuffixList suffix_list = Decoder::getEntries(); - Manifest &setType(ManifestType type) { - manifest_type_ = type; - encoder_.setType(manifest_type_); - return *this; - } + for (auto it = suffix_list.begin(); it != suffix_list.end(); ++it) { + Hash hash(it->second, Hash::getSize(hash_algo), hash_algo); + suffix_map[it->first] = hash; + } - Manifest &setHashAlgorithm(auth::CryptoHashType hash_algorithm) { - hash_algorithm_ = hash_algorithm; - encoder_.setHashAlgorithm(hash_algorithm_); - return *this; + return suffix_map; } - auth::CryptoHashType getHashAlgorithm() const { return hash_algorithm_; } - - ManifestType getType() const { return manifest_type_; } - - interface::ProductionProtocolAlgorithms getTransportType() const { - return manifest_transport_type_; - } - - bool getIsLast() const { return is_last_; } - - Manifest &setVersion(ManifestVersion version) { - encoder_.setVersion(version); - return *this; - } - - Manifest &setParamsBytestream(const ParamsBytestream ¶ms) { - manifest_transport_type_ = - interface::ProductionProtocolAlgorithms::BYTE_STREAM; - encoder_.setParamsBytestream(params); - return *this; - } - - Manifest &setParamsRTC(const ParamsRTC ¶ms) { - manifest_transport_type_ = - interface::ProductionProtocolAlgorithms::RTC_PROD; - encoder_.setParamsRTC(params); - return *this; - } - - ParamsBytestream getParamsBytestream() const { - return decoder_.getParamsBytestream(); - } - - ParamsRTC getParamsRTC() const { return decoder_.getParamsRTC(); } - - ManifestVersion getVersion() const { return decoder_.getVersion(); } - - Manifest &setIsLast(bool is_last) { - encoder_.setIsLast(is_last); - is_last_ = is_last; - return *this; - } - - Manifest &clear() { - encoder_.clear(); - decoder_.clear(); - return *this; - } + static Manifest::Ptr createContentManifest(Packet::Format format, + const core::Name &manifest_name, + std::size_t signature_size) { + ContentObject::Ptr content_object = + core::PacketManager<>::getInstance().getPacket<ContentObject>( + format, signature_size); + content_object->setName(manifest_name); + return std::make_shared<Manifest>(content_object, true); + }; protected: - ManifestType manifest_type_; - interface::ProductionProtocolAlgorithms manifest_transport_type_; - auth::CryptoHashType hash_algorithm_; - bool is_last_; - - Encoder encoder_; - Decoder decoder_; + Packet::Ptr packet_; }; } // end namespace core - } // end namespace transport diff --git a/libtransport/src/core/manifest_format.h b/libtransport/src/core/manifest_format.h index caee210cd..89412316a 100644 --- a/libtransport/src/core/manifest_format.h +++ b/libtransport/src/core/manifest_format.h @@ -25,13 +25,8 @@ #include <unordered_map> namespace transport { - namespace core { -enum class ManifestVersion : uint8_t { - VERSION_1 = 1, -}; - enum class ManifestType : uint8_t { INLINE_MANIFEST = 1, FINAL_CHUNK_NUMBER = 2, @@ -83,14 +78,27 @@ class ManifestEncoder { return static_cast<Implementation &>(*this).clearImpl(); } + bool isEncoded() const { + return static_cast<const Implementation &>(*this).isEncodedImpl(); + } + ManifestEncoder &setType(ManifestType type) { return static_cast<Implementation &>(*this).setTypeImpl(type); } + ManifestEncoder &setMaxCapacity(uint8_t max_capacity) { + return static_cast<Implementation &>(*this).setMaxCapacityImpl( + max_capacity); + } + ManifestEncoder &setHashAlgorithm(auth::CryptoHashType hash) { return static_cast<Implementation &>(*this).setHashAlgorithmImpl(hash); } + ManifestEncoder &setIsLast(bool is_last) { + return static_cast<Implementation &>(*this).setIsLastImpl(is_last); + } + template < typename T, typename = std::enable_if_t<std::is_same< @@ -99,45 +107,36 @@ class ManifestEncoder { return static_cast<Implementation &>(*this).setBaseNameImpl(name); } - template <typename Hash> - ManifestEncoder &addSuffixAndHash(uint32_t suffix, Hash &&hash) { - return static_cast<Implementation &>(*this).addSuffixAndHashImpl( - suffix, std::forward<Hash &&>(hash)); + ManifestEncoder &setParamsBytestream(const ParamsBytestream ¶ms) { + return static_cast<Implementation &>(*this).setParamsBytestreamImpl(params); } - ManifestEncoder &setIsLast(bool is_last) { - return static_cast<Implementation &>(*this).setIsLastImpl(is_last); + ManifestEncoder &setParamsRTC(const ParamsRTC ¶ms) { + return static_cast<Implementation &>(*this).setParamsRTCImpl(params); } - ManifestEncoder &setVersion(ManifestVersion version) { - return static_cast<Implementation &>(*this).setVersionImpl(version); + template <typename Hash> + ManifestEncoder &addEntry(uint32_t suffix, Hash &&hash) { + return static_cast<Implementation &>(*this).addEntryImpl( + suffix, std::forward<Hash>(hash)); } - std::size_t estimateSerializedLength(std::size_t number_of_entries) { - return static_cast<Implementation &>(*this).estimateSerializedLengthImpl( - number_of_entries); + ManifestEncoder &removeEntry(uint32_t suffix) { + return static_cast<Implementation &>(*this).removeEntryImpl(suffix); } - ManifestEncoder &update() { - return static_cast<Implementation &>(*this).updateImpl(); + std::size_t manifestHeaderSize() const { + return static_cast<const Implementation &>(*this).manifestHeaderSizeImpl(); } - ManifestEncoder &setParamsBytestream(const ParamsBytestream ¶ms) { - return static_cast<Implementation &>(*this).setParamsBytestreamImpl(params); + std::size_t manifestPayloadSize(size_t additional_entries = 0) const { + return static_cast<const Implementation &>(*this).manifestPayloadSizeImpl( + additional_entries); } - ManifestEncoder &setParamsRTC(const ParamsRTC ¶ms) { - return static_cast<Implementation &>(*this).setParamsRTCImpl(params); - } - - static std::size_t manifestHeaderSize( - interface::ProductionProtocolAlgorithms transport_type = - interface::ProductionProtocolAlgorithms::UNKNOWN) { - return Implementation::manifestHeaderSizeImpl(transport_type); - } - - static std::size_t manifestEntrySize() { - return Implementation::manifestEntrySizeImpl(); + std::size_t manifestSize(size_t additional_entries = 0) const { + return static_cast<const Implementation &>(*this).manifestSizeImpl( + additional_entries); } }; @@ -146,11 +145,17 @@ class ManifestDecoder { public: virtual ~ManifestDecoder() = default; + ManifestDecoder &decode() { + return static_cast<Implementation &>(*this).decodeImpl(); + } + ManifestDecoder &clear() { return static_cast<Implementation &>(*this).clearImpl(); } - void decode() { static_cast<Implementation &>(*this).decodeImpl(); } + bool isDecoded() const { + return static_cast<const Implementation &>(*this).isDecodedImpl(); + } ManifestType getType() const { return static_cast<const Implementation &>(*this).getTypeImpl(); @@ -160,40 +165,48 @@ class ManifestDecoder { return static_cast<const Implementation &>(*this).getTransportTypeImpl(); } + uint8_t getMaxCapacity() const { + return static_cast<const Implementation &>(*this).getMaxCapacityImpl(); + } + auth::CryptoHashType getHashAlgorithm() const { return static_cast<const Implementation &>(*this).getHashAlgorithmImpl(); } + bool getIsLast() const { + return static_cast<const Implementation &>(*this).getIsLastImpl(); + } + core::Name getBaseName() const { return static_cast<const Implementation &>(*this).getBaseNameImpl(); } - auto getSuffixHashList() { - return static_cast<Implementation &>(*this).getSuffixHashListImpl(); + ParamsBytestream getParamsBytestream() const { + return static_cast<const Implementation &>(*this).getParamsBytestreamImpl(); } - bool getIsLast() const { - return static_cast<const Implementation &>(*this).getIsLastImpl(); + ParamsRTC getParamsRTC() const { + return static_cast<const Implementation &>(*this).getParamsRTCImpl(); } - ManifestVersion getVersion() const { - return static_cast<const Implementation &>(*this).getVersionImpl(); + auto getEntries() const { + return static_cast<const Implementation &>(*this).getEntriesImpl(); } - std::size_t estimateSerializedLength(std::size_t number_of_entries) const { - return static_cast<const Implementation &>(*this) - .estimateSerializedLengthImpl(number_of_entries); + std::size_t manifestHeaderSize() const { + return static_cast<const Implementation &>(*this).manifestHeaderSizeImpl(); } - ParamsBytestream getParamsBytestream() const { - return static_cast<const Implementation &>(*this).getParamsBytestreamImpl(); + std::size_t manifestPayloadSize(size_t additional_entries = 0) const { + return static_cast<const Implementation &>(*this).manifestPayloadSizeImpl( + additional_entries); } - ParamsRTC getParamsRTC() const { - return static_cast<const Implementation &>(*this).getParamsRTCImpl(); + std::size_t manifestSize(size_t additional_entries = 0) const { + return static_cast<const Implementation &>(*this).manifestSizeImpl( + additional_entries); } }; } // namespace core - } // namespace transport diff --git a/libtransport/src/core/manifest_format_fixed.cc b/libtransport/src/core/manifest_format_fixed.cc index 428d6ad12..668169642 100644 --- a/libtransport/src/core/manifest_format_fixed.cc +++ b/libtransport/src/core/manifest_format_fixed.cc @@ -18,22 +18,42 @@ #include <hicn/transport/utils/literals.h> namespace transport { - namespace core { -// TODO use preallocated pool of membufs -FixedManifestEncoder::FixedManifestEncoder(Packet &packet, - std::size_t signature_size, - bool clear) +// --------------------------------------------------------- +// FixedManifest +// --------------------------------------------------------- +size_t FixedManifest::manifestHeaderSize( + interface::ProductionProtocolAlgorithms transport_type) { + uint32_t params_size = 0; + + switch (transport_type) { + case interface::ProductionProtocolAlgorithms::BYTE_STREAM: + params_size = MANIFEST_PARAMS_BYTESTREAM_SIZE; + break; + case interface::ProductionProtocolAlgorithms::RTC_PROD: + params_size = MANIFEST_PARAMS_RTC_SIZE; + break; + default: + break; + } + + return MANIFEST_META_SIZE + MANIFEST_ENTRY_META_SIZE + params_size; +} + +size_t FixedManifest::manifestPayloadSize(size_t nb_entries) { + return nb_entries * MANIFEST_ENTRY_SIZE; +} + +// --------------------------------------------------------- +// FixedManifestEncoder +// --------------------------------------------------------- +FixedManifestEncoder::FixedManifestEncoder(Packet::Ptr packet, bool clear) : packet_(packet), - max_size_(Packet::default_mtu - packet_.headerSize()), - signature_size_(signature_size), transport_type_(interface::ProductionProtocolAlgorithms::UNKNOWN), - encoded_(false), - params_bytestream_({0}), - params_rtc_({0}) { - manifest_meta_ = reinterpret_cast<ManifestMeta *>(packet_.writableData() + - packet_.headerSize()); + encoded_(false) { + manifest_meta_ = reinterpret_cast<ManifestMeta *>(packet_->writableData() + + packet_->headerSize()); manifest_entry_meta_ = reinterpret_cast<ManifestEntryMeta *>(manifest_meta_ + 1); @@ -50,32 +70,34 @@ FixedManifestEncoder &FixedManifestEncoder::encodeImpl() { return *this; } + // Copy manifest header manifest_meta_->transport_type = static_cast<uint8_t>(transport_type_); manifest_entry_meta_->nb_entries = manifest_entries_.size(); - packet_.append(FixedManifestEncoder::manifestHeaderSizeImpl()); - packet_.updateLength(); + packet_->append(manifestHeaderSizeImpl()); + packet_->updateLength(); + auto params = reinterpret_cast<uint8_t *>(manifest_entry_meta_ + 1); switch (transport_type_) { - case interface::ProductionProtocolAlgorithms::BYTE_STREAM: - packet_.appendPayload( - reinterpret_cast<const uint8_t *>(¶ms_bytestream_), - MANIFEST_PARAMS_BYTESTREAM_SIZE); + case interface::ProductionProtocolAlgorithms::BYTE_STREAM: { + auto bytestream = reinterpret_cast<const uint8_t *>(¶ms_bytestream_); + std::memcpy(params, bytestream, MANIFEST_PARAMS_BYTESTREAM_SIZE); break; - case interface::ProductionProtocolAlgorithms::RTC_PROD: - packet_.appendPayload(reinterpret_cast<const uint8_t *>(¶ms_rtc_), - MANIFEST_PARAMS_RTC_SIZE); + } + case interface::ProductionProtocolAlgorithms::RTC_PROD: { + auto rtc = reinterpret_cast<const uint8_t *>(¶ms_rtc_); + std::memcpy(params, rtc, MANIFEST_PARAMS_RTC_SIZE); break; + } default: break; } - packet_.appendPayload( - reinterpret_cast<const uint8_t *>(manifest_entries_.data()), - manifest_entries_.size() * FixedManifestEncoder::manifestEntrySizeImpl()); + // Copy manifest entries + auto payload = reinterpret_cast<const uint8_t *>(manifest_entries_.data()); + packet_->appendPayload(payload, manifestPayloadSizeImpl()); - if (TRANSPORT_EXPECT_FALSE(packet_.payloadSize() < - estimateSerializedLengthImpl())) { + if (TRANSPORT_EXPECT_FALSE(packet_->payloadSize() < manifestSizeImpl())) { throw errors::RuntimeException("Error encoding the manifest"); } @@ -85,32 +107,21 @@ FixedManifestEncoder &FixedManifestEncoder::encodeImpl() { FixedManifestEncoder &FixedManifestEncoder::clearImpl() { if (encoded_) { - packet_.trimEnd(FixedManifestEncoder::manifestHeaderSizeImpl() + - manifest_entries_.size() * - FixedManifestEncoder::manifestEntrySizeImpl()); + packet_->trimEnd(manifestSizeImpl()); } transport_type_ = interface::ProductionProtocolAlgorithms::UNKNOWN; encoded_ = false; - params_bytestream_ = {0}; - params_rtc_ = {0}; *manifest_meta_ = {0}; *manifest_entry_meta_ = {0}; + params_bytestream_ = {0}; + params_rtc_ = {0}; manifest_entries_.clear(); return *this; } -FixedManifestEncoder &FixedManifestEncoder::updateImpl() { - max_size_ = Packet::default_mtu - packet_.headerSize() - signature_size_; - return *this; -} - -FixedManifestEncoder &FixedManifestEncoder::setVersionImpl( - ManifestVersion version) { - manifest_meta_->version = static_cast<uint8_t>(version); - return *this; -} +bool FixedManifestEncoder::isEncodedImpl() const { return encoded_; } FixedManifestEncoder &FixedManifestEncoder::setTypeImpl( ManifestType manifest_type) { @@ -118,6 +129,12 @@ FixedManifestEncoder &FixedManifestEncoder::setTypeImpl( return *this; } +FixedManifestEncoder &FixedManifestEncoder::setMaxCapacityImpl( + uint8_t max_capacity) { + manifest_meta_->max_capacity = max_capacity; + return *this; +} + FixedManifestEncoder &FixedManifestEncoder::setHashAlgorithmImpl( auth::CryptoHashType algorithm) { manifest_meta_->hash_algorithm = static_cast<uint8_t>(algorithm); @@ -159,61 +176,68 @@ FixedManifestEncoder &FixedManifestEncoder::setParamsRTCImpl( return *this; } -FixedManifestEncoder &FixedManifestEncoder::addSuffixAndHashImpl( +FixedManifestEncoder &FixedManifestEncoder::addEntryImpl( uint32_t suffix, const auth::CryptoHash &hash) { - manifest_entries_.push_back(ManifestEntry{ - .suffix = htonl(suffix), + ManifestEntry last_entry = { + .suffix = portability::host_to_net(suffix), .hash = {0}, - }); + }; - std::memcpy(reinterpret_cast<uint8_t *>(manifest_entries_.back().hash), - hash.getDigest()->data(), hash.getSize()); - - if (TRANSPORT_EXPECT_FALSE(estimateSerializedLengthImpl() > max_size_)) { - throw errors::RuntimeException("Manifest size exceeded the packet MTU!"); - } + auto last_hash = reinterpret_cast<uint8_t *>(last_entry.hash); + std::memcpy(last_hash, hash.getDigest()->data(), hash.getSize()); + manifest_entries_.push_back(last_entry); return *this; } -std::size_t FixedManifestEncoder::estimateSerializedLengthImpl( - std::size_t additional_entries) { - return FixedManifestEncoder::manifestHeaderSizeImpl(transport_type_) + - (manifest_entries_.size() + additional_entries) * - FixedManifestEncoder::manifestEntrySizeImpl(); +FixedManifestEncoder &FixedManifestEncoder::removeEntryImpl(uint32_t suffix) { + for (auto it = manifest_entries_.begin(); it != manifest_entries_.end();) { + if (it->suffix == suffix) + it = manifest_entries_.erase(it); + else + ++it; + } + return *this; } -std::size_t FixedManifestEncoder::manifestHeaderSizeImpl( - interface::ProductionProtocolAlgorithms transport_type) { - uint32_t params_size = 0; - - switch (transport_type) { - case interface::ProductionProtocolAlgorithms::BYTE_STREAM: - params_size = MANIFEST_PARAMS_BYTESTREAM_SIZE; - break; - case interface::ProductionProtocolAlgorithms::RTC_PROD: - params_size = MANIFEST_PARAMS_RTC_SIZE; - break; - default: - break; - } +size_t FixedManifestEncoder::manifestHeaderSizeImpl() const { + return FixedManifest::manifestHeaderSize(transport_type_); +} - return MANIFEST_META_SIZE + MANIFEST_ENTRY_META_SIZE + params_size; +size_t FixedManifestEncoder::manifestPayloadSizeImpl( + size_t additional_entries) const { + return FixedManifest::manifestPayloadSize(manifest_entries_.size() + + additional_entries); } -std::size_t FixedManifestEncoder::manifestEntrySizeImpl() { - return MANIFEST_ENTRY_SIZE; +size_t FixedManifestEncoder::manifestSizeImpl(size_t additional_entries) const { + return manifestHeaderSizeImpl() + manifestPayloadSizeImpl(additional_entries); } -FixedManifestDecoder::FixedManifestDecoder(Packet &packet) +// --------------------------------------------------------- +// FixedManifestDecoder +// --------------------------------------------------------- +FixedManifestDecoder::FixedManifestDecoder(Packet::Ptr packet) : packet_(packet), decoded_(false) { manifest_meta_ = - reinterpret_cast<ManifestMeta *>(packet_.getPayload()->writableData()); + reinterpret_cast<ManifestMeta *>(packet_->getPayload()->writableData()); manifest_entry_meta_ = reinterpret_cast<ManifestEntryMeta *>(manifest_meta_ + 1); - transport_type_ = getTransportTypeImpl(); +} - switch (transport_type_) { +FixedManifestDecoder::~FixedManifestDecoder() {} + +FixedManifestDecoder &FixedManifestDecoder::decodeImpl() { + if (decoded_) { + return *this; + } + + if (packet_->payloadSize() < manifestSizeImpl()) { + throw errors::RuntimeException( + "The packet payload size does not match expected manifest size"); + } + + switch (getTransportTypeImpl()) { case interface::ProductionProtocolAlgorithms::BYTE_STREAM: params_bytestream_ = reinterpret_cast<TransportParamsBytestream *>( manifest_entry_meta_ + 1); @@ -230,25 +254,9 @@ FixedManifestDecoder::FixedManifestDecoder(Packet &packet) reinterpret_cast<ManifestEntry *>(manifest_entry_meta_ + 1); break; } -} - -FixedManifestDecoder::~FixedManifestDecoder() {} - -void FixedManifestDecoder::decodeImpl() { - if (decoded_) { - return; - } - - std::size_t packet_size = packet_.payloadSize(); - - if (packet_size < - FixedManifestEncoder::manifestHeaderSizeImpl(transport_type_) || - packet_size < estimateSerializedLengthImpl()) { - throw errors::RuntimeException( - "The packet does not match expected manifest size."); - } decoded_ = true; + return *this; } FixedManifestDecoder &FixedManifestDecoder::clearImpl() { @@ -256,20 +264,22 @@ FixedManifestDecoder &FixedManifestDecoder::clearImpl() { return *this; } +bool FixedManifestDecoder::isDecodedImpl() const { return decoded_; } + ManifestType FixedManifestDecoder::getTypeImpl() const { return static_cast<ManifestType>(manifest_meta_->type); } -ManifestVersion FixedManifestDecoder::getVersionImpl() const { - return static_cast<ManifestVersion>(manifest_meta_->version); -} - interface::ProductionProtocolAlgorithms FixedManifestDecoder::getTransportTypeImpl() const { return static_cast<interface::ProductionProtocolAlgorithms>( manifest_meta_->transport_type); } +uint8_t FixedManifestDecoder::getMaxCapacityImpl() const { + return manifest_meta_->max_capacity; +} + auth::CryptoHashType FixedManifestDecoder::getHashAlgorithmImpl() const { return static_cast<auth::CryptoHashType>(manifest_meta_->hash_algorithm); } @@ -303,26 +313,34 @@ ParamsRTC FixedManifestDecoder::getParamsRTCImpl() const { }; } -typename Fixed::SuffixList FixedManifestDecoder::getSuffixHashListImpl() { +typename Fixed::SuffixList FixedManifestDecoder::getEntriesImpl() const { typename Fixed::SuffixList hash_list; for (int i = 0; i < manifest_entry_meta_->nb_entries; i++) { - hash_list.insert(hash_list.end(), - std::make_pair(ntohl(manifest_entries_[i].suffix), - reinterpret_cast<uint8_t *>( - &manifest_entries_[i].hash[0]))); + hash_list.insert( + hash_list.end(), + std::make_pair( + portability::net_to_host(manifest_entries_[i].suffix), + reinterpret_cast<uint8_t *>(&manifest_entries_[i].hash[0]))); } return hash_list; } -std::size_t FixedManifestDecoder::estimateSerializedLengthImpl( - std::size_t additional_entries) const { - return FixedManifestEncoder::manifestHeaderSizeImpl(transport_type_) + - (manifest_entry_meta_->nb_entries + additional_entries) * - FixedManifestEncoder::manifestEntrySizeImpl(); +size_t FixedManifestDecoder::manifestHeaderSizeImpl() const { + interface::ProductionProtocolAlgorithms type = getTransportTypeImpl(); + return FixedManifest::manifestHeaderSize(type); } -} // end namespace core +size_t FixedManifestDecoder::manifestPayloadSizeImpl( + size_t additional_entries) const { + size_t nb_entries = manifest_entry_meta_->nb_entries + additional_entries; + return FixedManifest::manifestPayloadSize(nb_entries); +} +size_t FixedManifestDecoder::manifestSizeImpl(size_t additional_entries) const { + return manifestHeaderSizeImpl() + manifestPayloadSizeImpl(additional_entries); +} + +} // end namespace core } // end namespace transport diff --git a/libtransport/src/core/manifest_format_fixed.h b/libtransport/src/core/manifest_format_fixed.h index 5fd2a673d..7ab371974 100644 --- a/libtransport/src/core/manifest_format_fixed.h +++ b/libtransport/src/core/manifest_format_fixed.h @@ -28,7 +28,7 @@ namespace core { // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// |Version| Type | Transport Type| Hash Algorithm|L| Reserved | +// | Type | TTYpe | Max Capacity | Hash Algo |L| Reserved | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // Manifest Entry Metadata: @@ -106,9 +106,9 @@ struct Fixed { const size_t MANIFEST_META_SIZE = 4; struct __attribute__((__packed__)) ManifestMeta { - std::uint8_t version : 4; std::uint8_t type : 4; - std::uint8_t transport_type; + std::uint8_t transport_type : 4; + std::uint8_t max_capacity; std::uint8_t hash_algorithm; std::uint8_t is_last; }; @@ -146,22 +146,26 @@ struct __attribute__((__packed__)) ManifestEntry { }; static_assert(sizeof(ManifestEntry) == MANIFEST_ENTRY_SIZE); -static const constexpr std::uint8_t manifest_version = 1; +class FixedManifest { + public: + static size_t manifestHeaderSize( + interface::ProductionProtocolAlgorithms transport_type); + static size_t manifestPayloadSize(size_t nb_entries); +}; class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { public: - FixedManifestEncoder(Packet &packet, std::size_t signature_size = 0, - bool clear = true); + FixedManifestEncoder(Packet::Ptr packet, bool clear = false); ~FixedManifestEncoder(); FixedManifestEncoder &encodeImpl(); FixedManifestEncoder &clearImpl(); - FixedManifestEncoder &updateImpl(); + bool isEncodedImpl() const; // ManifestMeta - FixedManifestEncoder &setVersionImpl(ManifestVersion version); FixedManifestEncoder &setTypeImpl(ManifestType manifest_type); + FixedManifestEncoder &setMaxCapacityImpl(uint8_t max_capacity); FixedManifestEncoder &setHashAlgorithmImpl(Fixed::HashType algorithm); FixedManifestEncoder &setIsLastImpl(bool is_last); @@ -173,20 +177,15 @@ class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { FixedManifestEncoder &setParamsRTCImpl(const ParamsRTC ¶ms); // ManifestEntry - FixedManifestEncoder &addSuffixAndHashImpl(uint32_t suffix, - const Fixed::Hash &hash); + FixedManifestEncoder &addEntryImpl(uint32_t suffix, const Fixed::Hash &hash); + FixedManifestEncoder &removeEntryImpl(uint32_t suffix); - std::size_t estimateSerializedLengthImpl(std::size_t additional_entries = 0); - - static std::size_t manifestHeaderSizeImpl( - interface::ProductionProtocolAlgorithms transport_type = - interface::ProductionProtocolAlgorithms::UNKNOWN); - static std::size_t manifestEntrySizeImpl(); + size_t manifestHeaderSizeImpl() const; + size_t manifestPayloadSizeImpl(size_t additional_entries = 0) const; + size_t manifestSizeImpl(size_t additional_entries = 0) const; private: - Packet &packet_; - std::size_t max_size_; - std::size_t signature_size_; + Packet::Ptr packet_; interface::ProductionProtocolAlgorithms transport_type_; bool encoded_; @@ -202,17 +201,18 @@ class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { class FixedManifestDecoder : public ManifestDecoder<FixedManifestDecoder> { public: - FixedManifestDecoder(Packet &packet); + FixedManifestDecoder(Packet::Ptr packet); ~FixedManifestDecoder(); - void decodeImpl(); + FixedManifestDecoder &decodeImpl(); FixedManifestDecoder &clearImpl(); + bool isDecodedImpl() const; // ManifestMeta - ManifestVersion getVersionImpl() const; ManifestType getTypeImpl() const; interface::ProductionProtocolAlgorithms getTransportTypeImpl() const; + uint8_t getMaxCapacityImpl() const; Fixed::HashType getHashAlgorithmImpl() const; bool getIsLastImpl() const; @@ -224,14 +224,14 @@ class FixedManifestDecoder : public ManifestDecoder<FixedManifestDecoder> { ParamsRTC getParamsRTCImpl() const; // ManifestEntry - typename Fixed::SuffixList getSuffixHashListImpl(); + typename Fixed::SuffixList getEntriesImpl() const; - std::size_t estimateSerializedLengthImpl( - std::size_t additional_entries = 0) const; + size_t manifestHeaderSizeImpl() const; + size_t manifestPayloadSizeImpl(size_t additional_entries = 0) const; + size_t manifestSizeImpl(size_t additional_entries = 0) const; private: - Packet &packet_; - interface::ProductionProtocolAlgorithms transport_type_; + Packet::Ptr packet_; bool decoded_; // Manifest Header diff --git a/libtransport/src/core/manifest_inline.h b/libtransport/src/core/manifest_inline.h deleted file mode 100644 index ca48a4a79..000000000 --- a/libtransport/src/core/manifest_inline.h +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include <core/manifest.h> -#include <core/manifest_format.h> -#include <hicn/transport/portability/portability.h> - -#include <set> - -namespace transport { - -namespace core { - -template <typename Base, typename FormatTraits> -class ManifestInline - : public Manifest<Base, FormatTraits, ManifestInline<Base, FormatTraits>> { - using ManifestBase = - Manifest<Base, FormatTraits, ManifestInline<Base, FormatTraits>>; - - using Hash = typename FormatTraits::Hash; - using HashType = typename FormatTraits::HashType; - using Suffix = typename FormatTraits::Suffix; - using SuffixList = typename FormatTraits::SuffixList; - using HashEntry = std::pair<auth::CryptoHashType, std::vector<uint8_t>>; - - public: - ManifestInline() : ManifestBase() {} - - ManifestInline(Packet::Format format, const core::Name &name, - std::size_t signature_size = 0) - : ManifestBase(format, name, signature_size) {} - - template <typename T> - ManifestInline(T &&base) : ManifestBase(std::forward<T &&>(base)) {} - - template <typename T> - ManifestInline(T &base) : ManifestBase(base) {} - - static TRANSPORT_ALWAYS_INLINE ManifestInline *createManifest( - Packet::Format format, const core::Name &manifest_name, - ManifestVersion version, ManifestType type, bool is_last, - const Name &base_name, HashType hash_algo, std::size_t signature_size) { - auto manifest = new ManifestInline(format, manifest_name, signature_size); - manifest->setVersion(version); - manifest->setType(type); - manifest->setHashAlgorithm(hash_algo); - manifest->setIsLast(is_last); - manifest->setBaseName(base_name); - return manifest; - } - - ManifestInline &encodeImpl() { - ManifestBase::encoder_.encode(); - return *this; - } - - ManifestInline &decodeImpl() { - base_name_ = ManifestBase::decoder_.getBaseName(); - suffix_hash_map_ = ManifestBase::decoder_.getSuffixHashList(); - - return *this; - } - - std::size_t estimateManifestSizeImpl(std::size_t additional_entries = 0) { - return ManifestBase::encoder_.estimateSerializedLength(additional_entries); - } - - ManifestInline &setBaseName(const Name &name) { - base_name_ = name; - ManifestBase::encoder_.setBaseName(base_name_); - return *this; - } - - const Name &getBaseName() { return base_name_; } - - ManifestInline &addSuffixHash(Suffix suffix, const Hash &hash) { - ManifestBase::encoder_.addSuffixAndHash(suffix, hash); - return *this; - } - - // Call this function only after the decode function! - const SuffixList &getSuffixList() { return suffix_hash_map_; } - - // Convert several manifests into a single map from suffixes to packet hashes. - // All manifests must have been decoded beforehand. - static std::unordered_map<Suffix, Hash> getSuffixMap( - const std::vector<ManifestInline *> &manifests) { - std::unordered_map<Suffix, Hash> suffix_map; - - for (auto manifest_ptr : manifests) { - HashType hash_type = manifest_ptr->getHashAlgorithm(); - SuffixList suffix_list = manifest_ptr->getSuffixList(); - - for (auto it = suffix_list.begin(); it != suffix_list.end(); ++it) { - Hash hash(it->second, Hash::getSize(hash_type), hash_type); - suffix_map[it->first] = hash; - } - } - - return suffix_map; - } - - static std::unordered_map<Suffix, Hash> getSuffixMap( - ManifestInline *manifest) { - return getSuffixMap(std::vector<ManifestInline *>{manifest}); - } - - private: - core::Name base_name_; - SuffixList suffix_hash_map_; -}; - -} // namespace core -} // namespace transport diff --git a/libtransport/src/core/name.cc b/libtransport/src/core/name.cc index 98091eea5..960947cb9 100644 --- a/libtransport/src/core/name.cc +++ b/libtransport/src/core/name.cc @@ -24,7 +24,7 @@ namespace transport { namespace core { -Name::Name() { name_ = {}; } +Name::Name() { std::memset(&name_, 0, sizeof(name_)); } /** * XXX This function does not use the name API provided by libhicn @@ -47,6 +47,7 @@ Name::Name(int family, const uint8_t *ip_address, std::uint32_t suffix) std::memcpy(dst, ip_address, length); name_.suffix = suffix; } + Name::Name(const char *name, uint32_t segment) { if (hicn_name_create(name, segment, &name_) < 0) { throw errors::InvalidIpAddressException(); diff --git a/libtransport/src/core/pending_interest.h b/libtransport/src/core/pending_interest.h index f8a4ba10e..fb10405d3 100644 --- a/libtransport/src/core/pending_interest.h +++ b/libtransport/src/core/pending_interest.h @@ -42,17 +42,9 @@ class PendingInterest { public: using Ptr = utils::ObjectPool<PendingInterest>::Ptr; - // PendingInterest() - // : interest_(nullptr, nullptr), - // timer_(), - // on_content_object_callback_(), - // on_interest_timeout_callback_() {} PendingInterest(asio::io_service &io_service, const Interest::Ptr &interest) - : interest_(interest), - timer_(io_service), - on_content_object_callback_(), - on_interest_timeout_callback_() {} + : interest_(interest), timer_(io_service) {} PendingInterest(asio::io_service &io_service, const Interest::Ptr &interest, OnContentObjectCallback &&on_content_object, @@ -65,10 +57,9 @@ class PendingInterest { ~PendingInterest() = default; template <typename Handler> - TRANSPORT_ALWAYS_INLINE void startCountdown(Handler &&cb) { - timer_.expires_from_now( - std::chrono::milliseconds(interest_->getLifetime())); - timer_.async_wait(std::forward<Handler &&>(cb)); + TRANSPORT_ALWAYS_INLINE void startCountdown(uint32_t lifetime, Handler &&cb) { + timer_.expires_from_now(std::chrono::milliseconds(lifetime)); + timer_.async_wait(std::forward<Handler>(cb)); } TRANSPORT_ALWAYS_INLINE void cancelTimer() { timer_.cancel(); } @@ -77,7 +68,7 @@ class PendingInterest { return std::move(interest_); } - TRANSPORT_ALWAYS_INLINE void setInterest(Interest::Ptr &interest) { + TRANSPORT_ALWAYS_INLINE void setInterest(const Interest::Ptr &interest) { interest_ = interest; } @@ -88,7 +79,7 @@ class PendingInterest { TRANSPORT_ALWAYS_INLINE void setOnContentObjectCallback( OnContentObjectCallback &&on_content_object) { - PendingInterest::on_content_object_callback_ = on_content_object; + PendingInterest::on_content_object_callback_ = std::move(on_content_object); } TRANSPORT_ALWAYS_INLINE const OnInterestTimeoutCallback & @@ -98,7 +89,8 @@ class PendingInterest { TRANSPORT_ALWAYS_INLINE void setOnTimeoutCallback( OnInterestTimeoutCallback &&on_interest_timeout) { - PendingInterest::on_interest_timeout_callback_ = on_interest_timeout; + PendingInterest::on_interest_timeout_callback_ = + std::move(on_interest_timeout); } private: diff --git a/libtransport/src/core/portal.cc b/libtransport/src/core/portal.cc index d8e8d78ea..c06969f19 100644 --- a/libtransport/src/core/portal.cc +++ b/libtransport/src/core/portal.cc @@ -43,12 +43,14 @@ std::string Portal::io_module_path_ = defaultIoModule(); std::string Portal::defaultIoModule() { using namespace std::placeholders; GlobalConfiguration::getInstance().registerConfigurationParser( - io_module_section, + IoModuleConfiguration::section, std::bind(&Portal::parseIoModuleConfiguration, _1, _2)); GlobalConfiguration::getInstance().registerConfigurationGetter( - io_module_section, std::bind(&Portal::getModuleConfiguration, _1, _2)); + IoModuleConfiguration::section, + std::bind(&Portal::getModuleConfiguration, _1, _2)); GlobalConfiguration::getInstance().registerConfigurationSetter( - io_module_section, std::bind(&Portal::setModuleConfiguration, _1, _2)); + IoModuleConfiguration::section, + std::bind(&Portal::setModuleConfiguration, _1, _2)); // return default conf_.name = default_module; @@ -57,7 +59,7 @@ std::string Portal::defaultIoModule() { void Portal::getModuleConfiguration(ConfigurationObject& object, std::error_code& ec) { - DCHECK(object.getKey() == io_module_section); + DCHECK(object.getKey() == IoModuleConfiguration::section); auto conf = dynamic_cast<const IoModuleConfiguration&>(object); conf = conf_; @@ -103,7 +105,7 @@ std::string getIoModulePath(const std::string& name, void Portal::setModuleConfiguration(const ConfigurationObject& object, std::error_code& ec) { - DCHECK(object.getKey() == io_module_section); + DCHECK(object.getKey() == IoModuleConfiguration::section); const IoModuleConfiguration& conf = dynamic_cast<const IoModuleConfiguration&>(object); diff --git a/libtransport/src/core/portal.h b/libtransport/src/core/portal.h index aae4c573e..6f3a48e83 100644 --- a/libtransport/src/core/portal.h +++ b/libtransport/src/core/portal.h @@ -32,6 +32,10 @@ #include <hicn/transport/utils/event_thread.h> #include <hicn/transport/utils/fixed_block_allocator.h> +extern "C" { +#include <hicn/header.h> +} + #include <future> #include <memory> #include <queue> @@ -179,19 +183,11 @@ class Portal : public ::utils::NonCopyable, Portal() : Portal(GlobalWorkers::getInstance().getWorker()) {} Portal(::utils::EventThread &worker) - : io_module_(nullptr, [](IoModule *module) { IoModule::unload(module); }), + : io_module_(nullptr), worker_(worker), app_name_("libtransport_application"), transport_callback_(nullptr), - is_consumer_(false) { - /** - * This workaroung allows to initialize memory for packet buffers *before* - * any static variables that may be initialized in the io_modules. In this - * way static variables in modules will be destroyed before the packet - * memory. - */ - PacketManager<>::getInstance(); - } + is_consumer_(false) {} public: using TransportCallback = interface::Portal::TransportCallback; @@ -275,6 +271,7 @@ class Portal : public ::utils::NonCopyable, ptr->transport_callback_->onError(ec); } }, + [self]([[maybe_unused]] Connector *c) { /* Nothing to do here */ }, [self](Connector *c, const std::error_code &ec) { auto ptr = self.lock(); if (ptr) { @@ -315,76 +312,113 @@ class Portal : public ::utils::NonCopyable, } /** - * Send an interest through to the local forwarder. - * - * @param interest - The pointer to the interest. The ownership of the - * interest is transferred by the caller to portal. - * - * @param on_content_object_callback - If the caller wishes to use a - * different callback to be called for this interest, it can set this - * parameter. Otherwise ConsumerCallback::onContentObject will be used. + * @brief Add interest to PIT * - * @param on_interest_timeout_callback - If the caller wishes to use a - * different callback to be called for this interest, it can set this - * parameter. Otherwise ConsumerCallback::onTimeout will be used. */ - void sendInterest( - Interest::Ptr &&interest, + void addInterestToPIT( + const Interest::Ptr &interest, uint32_t lifetime, OnContentObjectCallback &&on_content_object_callback = UNSET_CALLBACK, OnInterestTimeoutCallback &&on_interest_timeout_callback = UNSET_CALLBACK) { - DCHECK(std::this_thread::get_id() == worker_.getThreadId()); - - // Send it - interest->encodeSuffixes(); - io_module_->send(*interest); - uint32_t initial_hash = interest->getName().getHash32(false); auto hash = initial_hash + interest->getName().getSuffix(); uint32_t seq = interest->getName().getSuffix(); - uint32_t *suffix = interest->firstSuffix(); + const uint32_t *suffix = interest->firstSuffix(); auto n_suffixes = interest->numberOfSuffixes(); uint32_t counter = 0; // Set timers do { + auto pend_int = pending_interest_hash_table_.try_emplace( + hash, worker_.getIoService(), interest); + PendingInterest &pending_interest = pend_int.first->second; + if (!pend_int.second) { + // element was already in map + pend_int.first->second.cancelTimer(); + pending_interest.setInterest(interest); + } + + pending_interest.setOnContentObjectCallback( + std::move(on_content_object_callback)); + pending_interest.setOnTimeoutCallback( + std::move(on_interest_timeout_callback)); + + if (is_consumer_) { + auto self = weak_from_this(); + pending_interest.startCountdown( + lifetime, portal_details::makeCustomAllocatorHandler( + async_callback_memory_, + [self, hash, seq](const std::error_code &ec) { + if (TRANSPORT_EXPECT_FALSE(ec.operator bool())) { + return; + } + + if (auto ptr = self.lock()) { + ptr->timerHandler(hash, seq); + } + })); + } + if (suffix) { hash = initial_hash + *suffix; seq = *suffix; suffix++; } + } while (counter++ < n_suffixes); + } - auto it = pending_interest_hash_table_.find(hash); - PendingInterest *pending_interest = nullptr; - if (it != pending_interest_hash_table_.end()) { - it->second.cancelTimer(); - pending_interest = &it->second; - pending_interest->setInterest(interest); - } else { - auto pend_int = pending_interest_hash_table_.try_emplace( - hash, worker_.getIoService(), interest); - pending_interest = &pend_int.first->second; - } + void matchContentObjectInPIT(ContentObject &content_object) { + uint32_t hash = getHash(content_object.getName()); + auto it = pending_interest_hash_table_.find(hash); + if (it != pending_interest_hash_table_.end()) { + DLOG_IF(INFO, VLOG_IS_ON(3)) << "Found pending interest."; - pending_interest->setOnContentObjectCallback( - std::move(on_content_object_callback)); - pending_interest->setOnTimeoutCallback( - std::move(on_interest_timeout_callback)); + PendingInterest &pend_interest = it->second; + pend_interest.cancelTimer(); + auto _int = pend_interest.getInterest(); + auto callback = pend_interest.getOnDataCallback(); + pending_interest_hash_table_.erase(it); - auto self = weak_from_this(); - pending_interest->startCountdown( - portal_details::makeCustomAllocatorHandler( - async_callback_memory_, - [self, hash, seq](const std::error_code &ec) { - if (TRANSPORT_EXPECT_FALSE(ec.operator bool())) { - return; - } + if (is_consumer_) { + // Send object is for the app + if (callback != UNSET_CALLBACK) { + callback(*_int, content_object); + } else if (transport_callback_) { + transport_callback_->onContentObject(*_int, content_object); + } + } else { + // Send content object to the network + io_module_->send(content_object); + } + } else if (is_consumer_) { + DLOG_IF(INFO, VLOG_IS_ON(3)) + << "No interest pending for received content object."; + } + } - if (auto ptr = self.lock()) { - ptr->timerHandler(hash, seq); - } - })); + /** + * Send an interest through to the local forwarder. + * + * @param interest - The pointer to the interest. The ownership of the + * interest is transferred by the caller to portal. + * + * @param on_content_object_callback - If the caller wishes to use a + * different callback to be called for this interest, it can set this + * parameter. Otherwise ConsumerCallback::onContentObject will be used. + * + * @param on_interest_timeout_callback - If the caller wishes to use a + * different callback to be called for this interest, it can set this + * parameter. Otherwise ConsumerCallback::onTimeout will be used. + */ + void sendInterest( + Interest::Ptr &interest, uint32_t lifetime, + OnContentObjectCallback &&on_content_object_callback = UNSET_CALLBACK, + OnInterestTimeoutCallback &&on_interest_timeout_callback = + UNSET_CALLBACK) { + DCHECK(std::this_thread::get_id() == worker_.getThreadId()); - } while (counter++ < n_suffixes); + io_module_->send(*interest); + addInterestToPIT(interest, lifetime, std::move(on_content_object_callback), + std::move(on_interest_timeout_callback)); } /** @@ -423,8 +457,7 @@ class Portal : public ::utils::NonCopyable, void sendContentObject(ContentObject &content_object) { DCHECK(io_module_); DCHECK(std::this_thread::get_id() == worker_.getThreadId()); - - io_module_->send(content_object); + matchContentObjectInPIT(content_object); } /** @@ -582,6 +615,9 @@ class Portal : public ::utils::NonCopyable, void processInterest(Interest &interest) { // Interest for a producer DLOG_IF(INFO, VLOG_IS_ON(3)) << "processInterest " << interest.getName(); + + // Save interest in PIT + addInterestToPIT(interest.shared_from_this(), interest.getLifetime()); if (TRANSPORT_EXPECT_TRUE(transport_callback_ != nullptr)) { transport_callback_->onInterest(interest); } @@ -598,27 +634,7 @@ class Portal : public ::utils::NonCopyable, void processContentObject(ContentObject &content_object) { DLOG_IF(INFO, VLOG_IS_ON(3)) << "processContentObject " << content_object.getName(); - uint32_t hash = getHash(content_object.getName()); - - auto it = pending_interest_hash_table_.find(hash); - if (it != pending_interest_hash_table_.end()) { - DLOG_IF(INFO, VLOG_IS_ON(3)) << "Found pending interest."; - - PendingInterest &pend_interest = it->second; - pend_interest.cancelTimer(); - auto _int = pend_interest.getInterest(); - auto callback = pend_interest.getOnDataCallback(); - pending_interest_hash_table_.erase(it); - - if (callback != UNSET_CALLBACK) { - callback(*_int, content_object); - } else if (transport_callback_) { - transport_callback_->onContentObject(*_int, content_object); - } - } else { - DLOG_IF(INFO, VLOG_IS_ON(3)) - << "No interest pending for received content object."; - } + matchContentObjectInPIT(content_object); } /** @@ -632,7 +648,7 @@ class Portal : public ::utils::NonCopyable, private: portal_details::HandlerMemory async_callback_memory_; - std::unique_ptr<IoModule, void (*)(IoModule *)> io_module_; + std::unique_ptr<IoModule> io_module_; ::utils::EventThread &worker_; diff --git a/libtransport/src/core/prefix.cc b/libtransport/src/core/prefix.cc index 4c1e191e9..00748148f 100644 --- a/libtransport/src/core/prefix.cc +++ b/libtransport/src/core/prefix.cc @@ -13,8 +13,10 @@ * limitations under the License. */ +#include <glog/logging.h> #include <hicn/transport/core/prefix.h> #include <hicn/transport/errors/errors.h> +#include <hicn/transport/portability/endianess.h> #include <hicn/transport/utils/string_tokenizer.h> #ifndef _WIN32 @@ -37,10 +39,6 @@ namespace core { Prefix::Prefix() { std::memset(&ip_prefix_, 0, sizeof(ip_prefix_t)); } -Prefix::Prefix(const char *prefix) : Prefix(std::string(prefix)) {} - -Prefix::Prefix(std::string &&prefix) : Prefix(prefix) {} - Prefix::Prefix(const std::string &prefix) { utils::StringTokenizer st(prefix, "/"); @@ -56,7 +54,7 @@ Prefix::Prefix(const std::string &prefix) { buildPrefix(ip_address, uint16_t(atoi(prefix_length.c_str())), family); } -Prefix::Prefix(std::string &prefix, uint16_t prefix_length) { +Prefix::Prefix(const std::string &prefix, uint16_t prefix_length) { int family = get_addr_family(prefix.c_str()); buildPrefix(prefix, prefix_length, family); } @@ -73,12 +71,14 @@ Prefix::Prefix(const core::Name &content_name, uint16_t prefix_length) { ip_prefix_.family = family; } -void Prefix::buildPrefix(std::string &prefix, uint16_t prefix_length, +void Prefix::buildPrefix(const std::string &prefix, uint16_t prefix_length, int family) { if (!checkPrefixLengthAndAddressFamily(prefix_length, family)) { throw errors::InvalidIpAddressException(); } + std::memset(&ip_prefix_, 0, sizeof(ip_prefix_t)); + int ret; switch (family) { case AF_INET: @@ -131,62 +131,67 @@ std::unique_ptr<Sockaddr> Prefix::toSockaddr() const { uint16_t Prefix::getPrefixLength() const { return ip_prefix_.len; } Prefix &Prefix::setPrefixLength(uint16_t prefix_length) { + if (!checkPrefixLengthAndAddressFamily(prefix_length, ip_prefix_.family)) { + throw errors::InvalidIpAddressException(); + } + ip_prefix_.len = (u8)prefix_length; return *this; } int Prefix::getAddressFamily() const { return ip_prefix_.family; } -Prefix &Prefix::setAddressFamily(int address_family) { - ip_prefix_.family = address_family; - return *this; -} - std::string Prefix::getNetwork() const { if (!checkPrefixLengthAndAddressFamily(ip_prefix_.len, ip_prefix_.family)) { throw errors::InvalidIpAddressException(); } - std::size_t size = - ip_prefix_.family == 4 + AF_INET ? INET_ADDRSTRLEN : INET6_ADDRSTRLEN; - - std::string network(size, 0); + char buffer[INET6_ADDRSTRLEN]; - if (ip_prefix_ntop_short(&ip_prefix_, (char *)network.c_str(), size) < 0) { + if (ip_prefix_ntop_short(&ip_prefix_, buffer, INET6_ADDRSTRLEN) < 0) { throw errors::RuntimeException( "Impossible to retrieve network from ip address."); } - return network; + return buffer; } -int Prefix::contains(const ip_address_t &content_name) const { - int res = - ip_address_cmp(&content_name, &(ip_prefix_.address), ip_prefix_.family); +bool Prefix::contains(const ip_address_t &content_name) const { + uint64_t mask[2] = {0, 0}; + auto content_name_copy = content_name; + auto network_copy = ip_prefix_.address; - if (ip_prefix_.len != (ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN_BITS - : IPV4_ADDR_LEN_BITS)) { - const u8 *ip_prefix_buffer = - ip_address_get_buffer(&(ip_prefix_.address), ip_prefix_.family); - const u8 *content_name_buffer = - ip_address_get_buffer(&content_name, ip_prefix_.family); - uint8_t mask = 0xFF >> (ip_prefix_.len % 8); - mask = ~mask; + auto prefix_length = getPrefixLength(); + if (ip_prefix_.family == AF_INET) { + prefix_length += 3 * IPV4_ADDR_LEN_BITS; + } - res += (ip_prefix_buffer[ip_prefix_.len] & mask) == - (content_name_buffer[ip_prefix_.len] & mask); + if (prefix_length == 0) { + mask[0] = mask[1] = 0; + } else if (prefix_length <= 64) { + mask[0] = portability::host_to_net((uint64_t)(~0) << (64 - prefix_length)); + mask[1] = 0; + } else if (prefix_length == 128) { + mask[0] = mask[1] = 0xffffffffffffffff; + } else { + prefix_length -= 64; + mask[0] = 0xffffffffffffffff; + mask[1] = portability::host_to_net((uint64_t)(~0) << (64 - prefix_length)); } - return res; -} + // Apply mask + content_name_copy.v6.as_u64[0] &= mask[0]; + content_name_copy.v6.as_u64[1] &= mask[1]; -int Prefix::contains(const core::Name &content_name) const { - return contains(content_name.toIpAddress().address); + network_copy.v6.as_u64[0] &= mask[0]; + network_copy.v6.as_u64[1] &= mask[1]; + + return ip_address_cmp(&network_copy, &content_name_copy, ip_prefix_.family) == + 0; } -Name Prefix::getName() const { - std::string s(getNetwork()); - return Name(s); +bool Prefix::contains(const core::Name &content_name) const { + return contains(content_name.toIpAddress().address); } /* @@ -199,8 +204,8 @@ Name Prefix::getName(const core::Name &mask, const core::Name &components, ip_prefix_.family != components.getAddressFamily() || ip_prefix_.family != content_name.getAddressFamily()) throw errors::RuntimeException( - "Prefix, mask, components and content name are not of the same address " - "family"); + "Prefix, mask, components and content name are not of the same" + "address family"); ip_address_t mask_ip = mask.toIpAddress().address; ip_address_t component_ip = components.toIpAddress().address; @@ -222,32 +227,6 @@ Name Prefix::getName(const core::Name &mask, const core::Name &components, return Name(ip_prefix_.family, (uint8_t *)&name_ip); } -Name Prefix::getRandomName() const { - ip_address_t name_ip = ip_prefix_.address; - u8 *name_ip_buffer = - const_cast<u8 *>(ip_address_get_buffer(&name_ip, ip_prefix_.family)); - - int addr_len = - (ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN * 8 : IPV4_ADDR_LEN * 8) - - ip_prefix_.len; - - size_t size = (size_t)ceil((float)addr_len / 8.0); - uint8_t *buffer = (uint8_t *)malloc(sizeof(uint8_t) * size); - - RAND_bytes(buffer, (int)size); - - int j = 0; - for (uint8_t i = (uint8_t)ceil((float)ip_prefix_.len / 8.0); - i < (ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN : IPV4_ADDR_LEN); - i++) { - name_ip_buffer[i] = buffer[j]; - j++; - } - free(buffer); - - return Name(ip_prefix_.family, (uint8_t *)&name_ip); -} - /* * Map a name in a different name prefix to this name prefix */ @@ -276,47 +255,66 @@ Name Prefix::mapName(const core::Name &content_name) const { return Name(ip_prefix_.family, (uint8_t *)&name_ip); } -Prefix &Prefix::setNetwork(std::string &network) { - if (!inet_pton(AF_INET6, network.c_str(), ip_prefix_.address.v6.buffer)) { +Prefix &Prefix::setNetwork(const std::string &network) { + if (!ip_address_pton(network.c_str(), &ip_prefix_.address)) { throw errors::RuntimeException("The network name is not valid."); } return *this; } +Name Prefix::makeName() const { return makeNameWithIndex(0); } + Name Prefix::makeRandomName() const { - if (ip_prefix_.family == AF_INET6) { - std::default_random_engine eng((std::random_device())()); - std::uniform_int_distribution<uint32_t> idis( - 0, std::numeric_limits<uint32_t>::max()); - uint64_t random_number = idis(eng); - - uint32_t hash_size_bits = IPV6_ADDR_LEN_BITS - ip_prefix_.len; - uint64_t ip_address[2]; - memcpy(ip_address, ip_prefix_.address.v6.buffer, sizeof(uint64_t)); - memcpy(ip_address + 1, ip_prefix_.address.v6.buffer + 8, sizeof(uint64_t)); - std::string network(IPV6_ADDR_LEN * 3, 0); - - // Let's do the magic ;) - int shift_size = hash_size_bits > sizeof(random_number) * 8 - ? sizeof(random_number) * 8 - : hash_size_bits; - - ip_address[1] >>= shift_size; - ip_address[1] <<= shift_size; - - ip_address[1] |= random_number >> (sizeof(uint64_t) * 8 - shift_size); - - if (!inet_ntop(ip_prefix_.family, ip_address, (char *)network.c_str(), - IPV6_ADDR_LEN * 3)) { - throw errors::RuntimeException( - "Impossible to retrieve network from ip address."); - } + std::default_random_engine eng((std::random_device())()); + std::uniform_int_distribution<uint32_t> idis( + 0, std::numeric_limits<uint32_t>::max()); + uint64_t random_number = idis(eng); + + return makeNameWithIndex(random_number); +} + +Name Prefix::makeNameWithIndex(std::uint64_t index) const { + uint16_t prefix_length = getPrefixLength(); - return Name(network); + Name ret; + + // Adjust prefix length depending on the address family + if (getAddressFamily() == AF_INET) { + // Sanity check + DCHECK(prefix_length <= 32); + // Convert prefix length to ip46_address_t prefix length + prefix_length += IPV4_ADDR_LEN_BITS * 3; + } + + std::memcpy(ret.getStructReference().prefix.v6.as_u8, + ip_prefix_.address.v6.as_u8, sizeof(ip_address_t)); + + // Convert index in network byte order + index = portability::host_to_net(index); + + // Apply mask + uint64_t mask; + if (prefix_length == 0) { + mask = 0; + } else if (prefix_length <= 64) { + mask = 0; + } else if (prefix_length == 128) { + mask = 0xffffffffffffffff; + } else { + prefix_length -= 64; + mask = portability::host_to_net((uint64_t)(~0) << (64 - prefix_length)); } - return Name(); + ret.getStructReference().prefix.v6.as_u64[1] &= mask; + // Eventually truncate index if too big + index &= ~mask; + + // Apply index + ret.getStructReference().prefix.v6.as_u64[1] |= index; + + // Done + return ret; } bool Prefix::checkPrefixLengthAndAddressFamily(uint16_t prefix_length, diff --git a/libtransport/src/core/udp_connector.cc b/libtransport/src/core/udp_connector.cc index ee0c7ea9c..5d8e76bb1 100644 --- a/libtransport/src/core/udp_connector.cc +++ b/libtransport/src/core/udp_connector.cc @@ -56,9 +56,9 @@ void UdpTunnelConnector::send(Packet &packet) { void UdpTunnelConnector::send(const utils::MemBuf::Ptr &buffer) { auto self = shared_from_this(); - io_service_.post([self, pkt{buffer}]() { + io_service_.post([self, buffer]() { bool write_in_progress = !self->output_buffer_.empty(); - self->output_buffer_.push_back(std::move(pkt)); + self->output_buffer_.push_back(std::move(buffer)); if (TRANSPORT_EXPECT_TRUE(self->state_ == State::CONNECTED)) { if (!write_in_progress) { self->doSendPacket(self); @@ -201,6 +201,8 @@ void UdpTunnelConnector::writeHandler() { ptr->writeHandler(); } }); + } else { + sent_callback_(this, make_error_code(core_error::success)); } } diff --git a/libtransport/src/core/udp_connector.h b/libtransport/src/core/udp_connector.h index 65821852d..002f4ca9f 100644 --- a/libtransport/src/core/udp_connector.h +++ b/libtransport/src/core/udp_connector.h @@ -62,7 +62,7 @@ class UdpTunnelConnector : public Connector { #endif socket_(socket), resolver_(io_service_), - remote_endpoint_send_(std::forward<EndpointType &&>(remote_endpoint)), + remote_endpoint_send_(std::forward<EndpointType>(remote_endpoint)), timer_(io_service_), #ifdef LINUX send_timer_(io_service_), diff --git a/libtransport/src/core/udp_listener.cc b/libtransport/src/core/udp_listener.cc index c67673392..caa97e0ee 100644 --- a/libtransport/src/core/udp_listener.cc +++ b/libtransport/src/core/udp_listener.cc @@ -5,6 +5,7 @@ #include <core/udp_connector.h> #include <core/udp_listener.h> #include <glog/logging.h> +#include <hicn/transport/portability/endianess.h> #include <hicn/transport/utils/hash.h> #ifndef LINUX @@ -16,7 +17,7 @@ size_t hash<asio::ip::udp::endpoint>::operator()( : utils::hash::fnv32_buf( endpoint.address().to_v6().to_bytes().data(), 16); uint16_t port = endpoint.port(); - return utils::hash::fnv32_buf(&port, 2, hash_ip); + return utils::hash::fnv32_buf(&port, 2, (unsigned int)hash_ip); } } // namespace std #endif @@ -83,7 +84,8 @@ void UdpTunnelListener::readHandler(const std::error_code &ec) { std::copy_n(reinterpret_cast<uint8_t *>(&addr->sin_addr), address_bytes.size(), address_bytes.begin()); address_v4 address(address_bytes); - remote_endpoint_ = udp::endpoint(address, ntohs(addr->sin_port)); + remote_endpoint_ = + udp::endpoint(address, portability::net_to_host(addr->sin_port)); } else { auto addr = reinterpret_cast<struct sockaddr_in6 *>( &remote_endpoints_[current_position_]); @@ -91,7 +93,8 @@ void UdpTunnelListener::readHandler(const std::error_code &ec) { std::copy_n(reinterpret_cast<uint8_t *>(&addr->sin6_addr), address_bytes.size(), address_bytes.begin()); address_v6 address(address_bytes); - remote_endpoint_ = udp::endpoint(address, ntohs(addr->sin6_port)); + remote_endpoint_ = + udp::endpoint(address, portability::net_to_host(addr->sin6_port)); } /** diff --git a/libtransport/src/core/udp_listener.h b/libtransport/src/core/udp_listener.h index 813520309..d8095a262 100644 --- a/libtransport/src/core/udp_listener.h +++ b/libtransport/src/core/udp_listener.h @@ -40,7 +40,7 @@ class UdpTunnelListener socket_(std::make_shared<asio::ip::udp::socket>(io_service_, endpoint.protocol())), local_endpoint_(endpoint), - receive_callback_(std::forward<ReceiveCallback &&>(receive_callback)), + receive_callback_(std::forward<ReceiveCallback>(receive_callback)), #ifndef LINUX read_msg_(nullptr, 0) #else @@ -63,12 +63,12 @@ class UdpTunnelListener void close(); int deleteConnector(Connector *connector) { - return connectors_.erase(connector->getConnectorId()); + return (int)connectors_.erase(connector->getConnectorId()); } template <typename ReceiveCallback> void setReceiveCallback(ReceiveCallback &&callback) { - receive_callback_ = std::forward<ReceiveCallback &&>(callback); + receive_callback_ = std::forward<ReceiveCallback>(callback); } Connector *findConnector(Connector::Id connId) { diff --git a/libtransport/src/implementation/CMakeLists.txt b/libtransport/src/implementation/CMakeLists.txt index 1f2a33a4c..c759dd964 100644 --- a/libtransport/src/implementation/CMakeLists.txt +++ b/libtransport/src/implementation/CMakeLists.txt @@ -19,21 +19,8 @@ list(APPEND HEADER_FILES if (${OPENSSL_VERSION} VERSION_EQUAL "1.1.1a" OR ${OPENSSL_VERSION} VERSION_GREATER "1.1.1a") list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.cc - # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.cc ${CMAKE_CURRENT_SOURCE_DIR}/socket.cc ) - - list(APPEND HEADER_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.h - # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h - ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.h - ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.h - ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.h - ) endif() set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE) diff --git a/libtransport/src/implementation/p2psecure_socket_consumer.cc b/libtransport/src/implementation/p2psecure_socket_consumer.cc deleted file mode 100644 index 6b67a5487..000000000 --- a/libtransport/src/implementation/p2psecure_socket_consumer.cc +++ /dev/null @@ -1,370 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <implementation/p2psecure_socket_consumer.h> -#include <interfaces/tls_socket_consumer.h> -#include <openssl/bio.h> -#include <openssl/ssl.h> -#include <openssl/tls1.h> - -#include <random> - -namespace transport { -namespace implementation { - -void P2PSecureConsumerSocket::setInterestPayload( - interface::ConsumerSocket &c, const core::Interest &interest) { - Interest &int2 = const_cast<Interest &>(interest); - random_suffix_ = int2.getName().getSuffix(); - - if (payload_ != NULL) int2.appendPayload(std::move(payload_)); -} - -/* Return the number of read bytes in the return param */ -int readOld(BIO *b, char *buf, int size) { - if (size < 0) return size; - - P2PSecureConsumerSocket *socket; - socket = (P2PSecureConsumerSocket *)BIO_get_data(b); - - std::unique_lock<std::mutex> lck(socket->mtx_); - - if (!socket->something_to_read_) { - if (!socket->transport_protocol_->isRunning()) { - socket->network_name_.setSuffix(socket->random_suffix_); - socket->ConsumerSocket::consume(socket->network_name_); - } - - if (!socket->something_to_read_) socket->cv_.wait(lck); - } - - size_t size_to_read, read; - size_t chain_size = socket->head_->length(); - - if (socket->head_->isChained()) - chain_size = socket->head_->computeChainDataLength(); - - if (chain_size > (size_t)size) { - read = size_to_read = (size_t)size; - } else { - read = size_to_read = chain_size; - socket->something_to_read_ = false; - } - - while (size_to_read) { - if (socket->head_->length() < size_to_read) { - std::memcpy(buf, socket->head_->data(), socket->head_->length()); - size_to_read -= socket->head_->length(); - buf += socket->head_->length(); - socket->head_ = socket->head_->pop(); - } else { - std::memcpy(buf, socket->head_->data(), size_to_read); - socket->head_->trimStart(size_to_read); - size_to_read = 0; - } - } - - return (int)read; -} - -/* Return the number of read bytes in readbytes */ -int read(BIO *b, char *buf, size_t size, size_t *readbytes) { - int ret; - - if (size > INT_MAX) size = INT_MAX; - - ret = readOld(b, buf, (int)size); - - if (ret <= 0) { - *readbytes = 0; - return ret; - } - - *readbytes = (size_t)ret; - - return 1; -} - -/* Return the number of written bytes in the return param */ -int writeOld(BIO *b, const char *buf, int num) { - P2PSecureConsumerSocket *socket; - socket = (P2PSecureConsumerSocket *)BIO_get_data(b); - - socket->payload_ = utils::MemBuf::copyBuffer(buf, num); - - socket->ConsumerSocket::setSocketOption( - ConsumerCallbacksOptions::INTEREST_OUTPUT, - (ConsumerInterestCallback)std::bind( - &P2PSecureConsumerSocket::setInterestPayload, socket, - std::placeholders::_1, std::placeholders::_2)); - - return num; -} - -/* Return the number of written bytes in written */ -int write(BIO *b, const char *buf, size_t size, size_t *written) { - int ret; - - if (size > INT_MAX) size = INT_MAX; - - ret = writeOld(b, buf, (int)size); - - if (ret <= 0) { - *written = 0; - return ret; - } - - *written = (size_t)ret; - - return 1; -} - -long ctrl(BIO *b, int cmd, long num, void *ptr) { return 1; } - -int P2PSecureConsumerSocket::addHicnKeyIdCb(SSL *s, unsigned int ext_type, - unsigned int context, - const unsigned char **out, - size_t *outlen, X509 *x, - size_t chainidx, int *al, - void *add_arg) { - if (ext_type == 100) { - *out = (unsigned char *)malloc(4); - *(uint32_t *)*out = 10; - *outlen = 4; - } - return 1; -} - -void P2PSecureConsumerSocket::freeHicnKeyIdCb(SSL *s, unsigned int ext_type, - unsigned int context, - const unsigned char *out, - void *add_arg) { - free(const_cast<unsigned char *>(out)); -} - -int P2PSecureConsumerSocket::parseHicnKeyIdCb(SSL *s, unsigned int ext_type, - unsigned int context, - const unsigned char *in, - size_t inlen, X509 *x, - size_t chainidx, int *al, - void *add_arg) { - P2PSecureConsumerSocket *socket = - reinterpret_cast<P2PSecureConsumerSocket *>(add_arg); - if (ext_type == 100) { - memcpy(&socket->secure_prefix_, in, sizeof(ip_prefix_t)); - } - return 1; -} - -P2PSecureConsumerSocket::P2PSecureConsumerSocket( - interface::ConsumerSocket *consumer, int handshake_protocol, - int transport_protocol) - : ConsumerSocket(consumer, handshake_protocol), - name_(), - tls_consumer_(nullptr), - decrypted_content_(), - payload_(), - head_(), - something_to_read_(false), - content_downloaded_(false), - random_suffix_(), - secure_prefix_(), - producer_namespace_(), - read_callback_decrypted_(), - mtx_(), - cv_(), - protocol_(transport_protocol) { - /* Create the (d)TLS state */ - const SSL_METHOD *meth = TLS_client_method(); - ctx_ = SSL_CTX_new(meth); - - int result = - SSL_CTX_set_ciphersuites(ctx_, - "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_" - "SHA256:TLS_AES_128_GCM_SHA256"); - if (result != 1) { - throw errors::RuntimeException( - "Unable to set cipher list on TLS subsystem. Aborting."); - } - - SSL_CTX_set_min_proto_version(ctx_, TLS1_3_VERSION); - SSL_CTX_set_max_proto_version(ctx_, TLS1_3_VERSION); - SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, NULL); - SSL_CTX_set_ssl_version(ctx_, meth); - - result = SSL_CTX_add_custom_ext( - ctx_, 100, SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS, - P2PSecureConsumerSocket::addHicnKeyIdCb, - P2PSecureConsumerSocket::freeHicnKeyIdCb, NULL, - P2PSecureConsumerSocket::parseHicnKeyIdCb, this); - - ssl_ = SSL_new(ctx_); - - bio_meth_ = BIO_meth_new(BIO_TYPE_CONNECT, "secure consumer socket"); - BIO_meth_set_read(bio_meth_, readOld); - BIO_meth_set_write(bio_meth_, writeOld); - BIO_meth_set_ctrl(bio_meth_, ctrl); - BIO *bio = BIO_new(bio_meth_); - BIO_set_init(bio, 1); - BIO_set_data(bio, this); - SSL_set_bio(ssl_, bio, bio); - - std::default_random_engine generator; - std::uniform_int_distribution<int> distribution( - 1, std::numeric_limits<uint32_t>::max()); - random_suffix_ = 0; - - this->ConsumerSocket::setSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, - this); -}; - -P2PSecureConsumerSocket::~P2PSecureConsumerSocket() { - BIO_meth_free(bio_meth_); - SSL_shutdown(ssl_); -} - -int P2PSecureConsumerSocket::handshake() { - int result = 1; - - if (!(SSL_in_before(this->ssl_) || SSL_in_init(this->ssl_))) { - return 1; - } - - ConsumerSocket::getSocketOption(MAX_WINDOW_SIZE, old_max_win_); - ConsumerSocket::getSocketOption(CURRENT_WINDOW_SIZE, old_current_win_); - - ConsumerSocket::setSocketOption(MAX_WINDOW_SIZE, (double)1.0); - ConsumerSocket::setSocketOption(CURRENT_WINDOW_SIZE, (double)1.0); - - network_name_ = producer_namespace_.getRandomName(); - network_name_.setSuffix(0); - - DLOG_IF(INFO, VLOG_IS_ON(2)) << "Start handshake at " << network_name_; - result = SSL_connect(this->ssl_); - - return result; -} - -void P2PSecureConsumerSocket::initSessionSocket() { - tls_consumer_ = - std::make_shared<TLSConsumerSocket>(nullptr, this->protocol_, this->ssl_); - tls_consumer_->setInterface( - new interface::TLSConsumerSocket(tls_consumer_.get())); - - ConsumerTimerCallback *stats_summary_callback = nullptr; - this->getSocketOption(ConsumerCallbacksOptions::STATS_SUMMARY, - &stats_summary_callback); - - uint32_t lifetime; - this->getSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, lifetime); - - tls_consumer_->setSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, - lifetime); - tls_consumer_->setSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, - read_callback_decrypted_); - tls_consumer_->setSocketOption(ConsumerCallbacksOptions::STATS_SUMMARY, - *stats_summary_callback); - tls_consumer_->setSocketOption(GeneralTransportOptions::STATS_INTERVAL, - this->timer_interval_milliseconds_); - tls_consumer_->setSocketOption(MAX_WINDOW_SIZE, old_max_win_); - tls_consumer_->setSocketOption(CURRENT_WINDOW_SIZE, old_current_win_); - tls_consumer_->connect(); -} - -int P2PSecureConsumerSocket::consume(const Name &name) { - if (transport_protocol_->isRunning()) { - return CONSUMER_BUSY; - } - - if (handshake() != 1) { - throw errors::RuntimeException("Unable to perform client handshake"); - } else { - DLOG_IF(INFO, VLOG_IS_ON(2)) << "Handshake performed!"; - } - - initSessionSocket(); - - if (tls_consumer_ == nullptr) { - throw errors::RuntimeException("TLS socket does not exist"); - } - - std::shared_ptr<Name> prefix_name = std::make_shared<Name>( - secure_prefix_.family, - ip_address_get_buffer(&(secure_prefix_.address), secure_prefix_.family)); - std::shared_ptr<Prefix> prefix = - std::make_shared<Prefix>(*prefix_name, secure_prefix_.len); - - if (payload_ != nullptr) - return tls_consumer_->consume((prefix->mapName(name)), std::move(payload_)); - else - return tls_consumer_->consume((prefix->mapName(name))); -} - -void P2PSecureConsumerSocket::registerPrefix(const Prefix &producer_namespace) { - producer_namespace_ = producer_namespace; -} - -int P2PSecureConsumerSocket::setSocketOption( - int socket_option_key, ReadCallback *socket_option_value) { - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, ReadCallback *socket_option_value) -> int { - switch (socket_option_key) { - case ConsumerCallbacksOptions::READ_CALLBACK: - read_callback_decrypted_ = socket_option_value; - break; - default: - return SOCKET_OPTION_NOT_SET; - } - - return SOCKET_OPTION_SET; - }); -} - -void P2PSecureConsumerSocket::getReadBuffer(uint8_t **application_buffer, - size_t *max_length){}; - -void P2PSecureConsumerSocket::readDataAvailable(size_t length) noexcept {}; - -size_t P2PSecureConsumerSocket::maxBufferSize() const { - return SSL3_RT_MAX_PLAIN_LENGTH; -} - -void P2PSecureConsumerSocket::readBufferAvailable( - std::unique_ptr<utils::MemBuf> &&buffer) noexcept { - std::unique_lock<std::mutex> lck(this->mtx_); - if (head_) { - head_->prependChain(std::move(buffer)); - } else { - head_ = std::move(buffer); - } - - something_to_read_ = true; - cv_.notify_one(); -} - -void P2PSecureConsumerSocket::readError(const std::error_code &ec) noexcept {}; - -void P2PSecureConsumerSocket::readSuccess(std::size_t total_size) noexcept { - std::unique_lock<std::mutex> lck(this->mtx_); - content_downloaded_ = true; - something_to_read_ = true; - cv_.notify_one(); -} - -bool P2PSecureConsumerSocket::isBufferMovable() noexcept { return true; } - -} // namespace implementation -} // namespace transport diff --git a/libtransport/src/implementation/p2psecure_socket_consumer.h b/libtransport/src/implementation/p2psecure_socket_consumer.h deleted file mode 100644 index a5e69f611..000000000 --- a/libtransport/src/implementation/p2psecure_socket_consumer.h +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include <hicn/transport/interfaces/socket_consumer.h> -#include <implementation/tls_socket_consumer.h> -#include <openssl/bio.h> -#include <openssl/ssl.h> - -namespace transport { -namespace implementation { - -class P2PSecureConsumerSocket : public ConsumerSocket, - public interface::ConsumerSocket::ReadCallback { - /* Return the number of read bytes in readbytes */ - friend int read(BIO *b, char *buf, size_t size, size_t *readbytes); - - /* Return the number of read bytes in the return param */ - friend int readOld(BIO *h, char *buf, int size); - - /* Return the number of written bytes in written */ - friend int write(BIO *b, const char *buf, size_t size, size_t *written); - - /* Return the number of written bytes in the return param */ - friend int writeOld(BIO *h, const char *buf, int num); - - friend long ctrl(BIO *b, int cmd, long num, void *ptr); - - public: - explicit P2PSecureConsumerSocket(interface::ConsumerSocket *consumer, - int handshake_protocol, - int transport_protocol); - - ~P2PSecureConsumerSocket(); - - int consume(const Name &name) override; - - void registerPrefix(const Prefix &producer_namespace); - - int setSocketOption( - int socket_option_key, - interface::ConsumerSocket::ReadCallback *socket_option_value) override; - - using ConsumerSocket::getSocketOption; - using ConsumerSocket::setSocketOption; - - protected: - /* Callback invoked once an interest has been received and its payload - * decrypted */ - ConsumerInterestCallback on_interest_input_decrypted_; - ConsumerInterestCallback on_interest_process_decrypted_; - - private: - Name name_; - std::shared_ptr<TLSConsumerSocket> tls_consumer_; - /* SSL handle */ - SSL *ssl_; - SSL_CTX *ctx_; - BIO_METHOD *bio_meth_; - /* Chain of MemBuf to be used as a temporary buffer to pass descypted data - * from the underlying layer to the application */ - std::unique_ptr<utils::MemBuf> decrypted_content_; - /* Chain of MemBuf holding the payload to be written into interest or data */ - std::unique_ptr<utils::MemBuf> payload_; - /* Chain of MemBuf holding the data retrieved from the underlying layer */ - std::unique_ptr<utils::MemBuf> head_; - bool something_to_read_; - bool content_downloaded_; - double old_max_win_; - double old_current_win_; - uint32_t random_suffix_; - ip_prefix_t secure_prefix_; - Prefix producer_namespace_; - interface::ConsumerSocket::ReadCallback *read_callback_decrypted_; - std::mutex mtx_; - - /* Condition variable for the wait */ - std::condition_variable cv_; - - int protocol_; - - void setInterestPayload(interface::ConsumerSocket &c, - const core::Interest &interest); - - static int addHicnKeyIdCb(SSL *s, unsigned int ext_type, unsigned int context, - const unsigned char **out, size_t *outlen, X509 *x, - size_t chainidx, int *al, void *add_arg); - - static void freeHicnKeyIdCb(SSL *s, unsigned int ext_type, - unsigned int context, const unsigned char *out, - void *add_arg); - - static int parseHicnKeyIdCb(SSL *s, unsigned int ext_type, - unsigned int context, const unsigned char *in, - size_t inlen, X509 *x, size_t chainidx, int *al, - void *add_arg); - - virtual void getReadBuffer(uint8_t **application_buffer, - size_t *max_length) override; - - virtual void readDataAvailable(size_t length) noexcept override; - - virtual size_t maxBufferSize() const override; - - virtual void readBufferAvailable( - std::unique_ptr<utils::MemBuf> &&buffer) noexcept override; - - virtual void readError(const std::error_code &ec) noexcept override; - - virtual void readSuccess(std::size_t total_size) noexcept override; - - virtual bool isBufferMovable() noexcept override; - - int handshake(); - - void initSessionSocket(); -}; - -} // namespace implementation - -} // end namespace transport diff --git a/libtransport/src/implementation/p2psecure_socket_producer.cc b/libtransport/src/implementation/p2psecure_socket_producer.cc deleted file mode 100644 index ee78ea53b..000000000 --- a/libtransport/src/implementation/p2psecure_socket_producer.cc +++ /dev/null @@ -1,347 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <hicn/transport/core/interest.h> -#include <implementation/p2psecure_socket_producer.h> -// #include <implementation/tls_rtc_socket_producer.h> -#include <implementation/tls_socket_producer.h> -#include <interfaces/tls_rtc_socket_producer.h> -#include <interfaces/tls_socket_producer.h> -#include <openssl/bio.h> -#include <openssl/pkcs12.h> -#include <openssl/rand.h> -#include <openssl/ssl.h> - -namespace transport { -namespace implementation { - -/* Workaround to prevent content with expiry time equal to 0 to be lost when - * pushed in the forwarder */ -#define HICN_HANDSHAKE_CONTENT_EXPIRY_TIME 100; - -P2PSecureProducerSocket::P2PSecureProducerSocket( - interface::ProducerSocket *producer_socket) - : ProducerSocket(producer_socket, - ProductionProtocolAlgorithms::BYTE_STREAM), - mtx_(), - cv_(), - map_producers(), - list_producers() {} - -P2PSecureProducerSocket::P2PSecureProducerSocket( - interface::ProducerSocket *producer_socket, bool rtc, - std::string &keystore_path, std::string &keystore_pwd) - : ProducerSocket(producer_socket, - ProductionProtocolAlgorithms::BYTE_STREAM), - rtc_(rtc), - mtx_(), - cv_(), - map_producers(), - list_producers() { - /* Setup SSL context (identity and parameter to use TLS 1.3) */ - FILE *p12file = fopen(keystore_path.c_str(), "r"); - if (p12file == NULL) - throw errors::RuntimeException("impossible open keystore"); - std::unique_ptr<PKCS12, decltype(&::PKCS12_free)> p12( - d2i_PKCS12_fp(p12file, NULL), ::PKCS12_free); - // now we parse the file to get the first key and certificate - if (1 != PKCS12_parse(p12.get(), keystore_pwd.c_str(), &pkey_rsa_, &cert_509_, - NULL)) - throw errors::RuntimeException("impossible to get the private key"); - fclose(p12file); - - /* Set the callback so that when an interest is received we catch it and we - * decrypt the payload before passing it to the application. */ - ProducerSocket::setSocketOption( - ProducerCallbacksOptions::INTEREST_INPUT, - (ProducerInterestCallback)std::bind( - &P2PSecureProducerSocket::onInterestCallback, this, - std::placeholders::_1, std::placeholders::_2)); -} - -P2PSecureProducerSocket::~P2PSecureProducerSocket() {} - -void P2PSecureProducerSocket::initSessionSocket( - std::unique_ptr<TLSProducerSocket> &producer) { - producer->on_content_produced_application_ = - this->on_content_produced_application_; - producer->setSocketOption(CONTENT_OBJECT_EXPIRY_TIME, - this->content_object_expiry_time_); - producer->setSocketOption(SIGNER, this->signer_); - producer->setSocketOption(MAKE_MANIFEST, this->making_manifest_); - producer->setSocketOption(DATA_PACKET_SIZE, - (uint32_t)(this->data_packet_size_)); - uint32_t output_buffer_size = 0; - this->getSocketOption(GeneralTransportOptions::OUTPUT_BUFFER_SIZE, - output_buffer_size); - producer->setSocketOption(GeneralTransportOptions::OUTPUT_BUFFER_SIZE, - output_buffer_size); - - if (!rtc_) { - producer->setInterface(new interface::TLSProducerSocket(producer.get())); - } else { - // TODO - // TLSRTCProducerSocket *rtc_producer = - // dynamic_cast<TLSRTCProducerSocket *>(producer.get()); - // rtc_producer->setInterface( - // new interface::TLSRTCProducerSocket(rtc_producer)); - } -} - -void P2PSecureProducerSocket::onInterestCallback(interface::ProducerSocket &p, - Interest &interest) { - std::unique_lock<std::mutex> lck(mtx_); - std::unique_ptr<TLSProducerSocket> tls_producer; - auto it = map_producers.find(interest.getName()); - - if (it != map_producers.end()) { - return; - } - - if (!rtc_) { - tls_producer = - std::make_unique<TLSProducerSocket>(nullptr, this, interest.getName()); - } else { - // TODO - // tls_producer = std::make_unique<TLSRTCProducerSocket>(nullptr, this, - // interest.getName()); - } - - initSessionSocket(tls_producer); - TLSProducerSocket *tls_producer_ptr = tls_producer.get(); - map_producers.insert({interest.getName(), move(tls_producer)}); - - DLOG_IF(INFO, VLOG_IS_ON(3)) << "Start handshake at " << interest.getName(); - - if (!rtc_) { - tls_producer_ptr->onInterest(*tls_producer_ptr, interest); - tls_producer_ptr->async_accept(); - } else { - // TODO - // TLSRTCProducerSocket *rtc_producer_ptr = - // dynamic_cast<TLSRTCProducerSocket *>(tls_producer_ptr); - // rtc_producer_ptr->onInterest(*rtc_producer_ptr, interest); - // rtc_producer_ptr->async_accept(); - } -} - -uint32_t P2PSecureProducerSocket::produceDatagram( - const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer) { - // TODO - throw errors::NotImplementedException(); - - // if (!rtc_) { - // throw errors::RuntimeException( - // "RTC must be the transport protocol to start the production of - // current " "data. Aborting."); - // } - - // std::unique_lock<std::mutex> lck(mtx_); - - // if (list_producers.empty()) cv_.wait(lck); - - // TODO - // for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - // { - // TLSRTCProducerSocket *rtc_producer = - // dynamic_cast<TLSRTCProducerSocket *>(it->get()); - // rtc_producer->produce(utils::MemBuf::copyBuffer(buffer, buffer_size)); - // } - - // return 0; -} - -uint32_t P2PSecureProducerSocket::produceStream( - const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last, uint32_t start_offset) { - if (rtc_) { - throw errors::RuntimeException( - "RTC transport protocol is not compatible with the production of " - "current data. Aborting."); - } - - std::unique_lock<std::mutex> lck(mtx_); - uint32_t segments = 0; - - if (list_producers.empty()) cv_.wait(lck); - - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - segments += (*it)->produceStream(content_name, buffer->clone(), is_last, - start_offset); - - return segments; -} - -uint32_t P2PSecureProducerSocket::produceStream(const Name &content_name, - const uint8_t *buffer, - size_t buffer_size, - bool is_last, - uint32_t start_offset) { - if (rtc_) { - throw errors::RuntimeException( - "RTC transport protocol is not compatible with the production of " - "current data. Aborting."); - } - - std::unique_lock<std::mutex> lck(mtx_); - uint32_t segments = 0; - if (list_producers.empty()) cv_.wait(lck); - - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - segments += (*it)->produceStream(content_name, buffer, buffer_size, is_last, - start_offset); - - return segments; -} - -/* Redefinition of socket options to avoid name hiding */ -int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, ProducerInterestCallback socket_option_value) { - if (!list_producers.empty()) { - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - } - - switch (socket_option_key) { - case ProducerCallbacksOptions::INTEREST_INPUT: - on_interest_input_decrypted_ = socket_option_value; - return SOCKET_OPTION_SET; - - case ProducerCallbacksOptions::INTEREST_DROP: - on_interest_dropped_input_buffer_ = socket_option_value; - return SOCKET_OPTION_SET; - - case ProducerCallbacksOptions::INTEREST_PASS: - on_interest_inserted_input_buffer_ = socket_option_value; - return SOCKET_OPTION_SET; - - case ProducerCallbacksOptions::CACHE_HIT: - on_interest_satisfied_output_buffer_ = socket_option_value; - return SOCKET_OPTION_SET; - - case ProducerCallbacksOptions::CACHE_MISS: - on_interest_process_decrypted_ = socket_option_value; - return SOCKET_OPTION_SET; - - default: - return SOCKET_OPTION_NOT_SET; - } -} - -int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, - const std::shared_ptr<auth::Signer> &socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - switch (socket_option_key) { - case GeneralTransportOptions::SIGNER: { - signer_.reset(); - signer_ = socket_option_value; - - return SOCKET_OPTION_SET; - } - default: - return SOCKET_OPTION_NOT_SET; - } -} - -int P2PSecureProducerSocket::setSocketOption(int socket_option_key, - uint32_t socket_option_value) { - if (!list_producers.empty()) { - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - } - switch (socket_option_key) { - case GeneralTransportOptions::CONTENT_OBJECT_EXPIRY_TIME: - content_object_expiry_time_ = - socket_option_value; // HICN_HANDSHAKE_CONTENT_EXPIRY_TIME; - return SOCKET_OPTION_SET; - } - return ProducerSocket::setSocketOption(socket_option_key, - socket_option_value); -} - -int P2PSecureProducerSocket::setSocketOption(int socket_option_key, - bool socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - return ProducerSocket::setSocketOption(socket_option_key, - socket_option_value); -} - -int P2PSecureProducerSocket::setSocketOption(int socket_option_key, - Name *socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - return ProducerSocket::setSocketOption(socket_option_key, - socket_option_value); -} - -int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, ProducerContentObjectCallback socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - return ProducerSocket::setSocketOption(socket_option_key, - socket_option_value); -} - -int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, ProducerContentCallback socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - switch (socket_option_key) { - case ProducerCallbacksOptions::CONTENT_PRODUCED: - on_content_produced_application_ = socket_option_value; - break; - - default: - return SOCKET_OPTION_NOT_SET; - } - - return SOCKET_OPTION_SET; -} - -int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, auth::CryptoHashType socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - return ProducerSocket::setSocketOption(socket_option_key, - socket_option_value); -} - -int P2PSecureProducerSocket::setSocketOption( - int socket_option_key, const std::string &socket_option_value) { - if (!list_producers.empty()) - for (auto it = list_producers.cbegin(); it != list_producers.cend(); it++) - (*it)->setSocketOption(socket_option_key, socket_option_value); - - return ProducerSocket::setSocketOption(socket_option_key, - socket_option_value); -} - -} // namespace implementation -} // namespace transport diff --git a/libtransport/src/implementation/p2psecure_socket_producer.h b/libtransport/src/implementation/p2psecure_socket_producer.h deleted file mode 100644 index 00f407a75..000000000 --- a/libtransport/src/implementation/p2psecure_socket_producer.h +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include <hicn/transport/auth/signer.h> -#include <implementation/socket_producer.h> -// #include <implementation/tls_rtc_socket_producer.h> -#include <implementation/tls_socket_producer.h> -#include <openssl/ssl.h> -#include <utils/content_store.h> - -#include <condition_variable> -#include <forward_list> -#include <mutex> - -namespace transport { -namespace implementation { - -class P2PSecureProducerSocket : public ProducerSocket { - friend class TLSProducerSocket; - // TODO - // friend class TLSRTCProducerSocket; - - public: - explicit P2PSecureProducerSocket(interface::ProducerSocket *producer_socket); - - explicit P2PSecureProducerSocket(interface::ProducerSocket *producer_socket, - bool rtc, std::string &keystore_path, - std::string &keystore_pwd); - - ~P2PSecureProducerSocket(); - - uint32_t produceDatagram(const Name &content_name, - std::unique_ptr<utils::MemBuf> &&buffer) override; - - uint32_t produceStream(const Name &content_name, const uint8_t *buffer, - size_t buffer_size, bool is_last = true, - uint32_t start_offset = 0) override; - - uint32_t produceStream(const Name &content_name, - std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last = true, - uint32_t start_offset = 0) override; - - int setSocketOption(int socket_option_key, - ProducerInterestCallback socket_option_value) override; - - int setSocketOption( - int socket_option_key, - const std::shared_ptr<auth::Signer> &socket_option_value) override; - - int setSocketOption(int socket_option_key, - uint32_t socket_option_value) override; - - int setSocketOption(int socket_option_key, bool socket_option_value) override; - - int setSocketOption(int socket_option_key, - Name *socket_option_value) override; - - int setSocketOption( - int socket_option_key, - ProducerContentObjectCallback socket_option_value) override; - - int setSocketOption(int socket_option_key, - ProducerContentCallback socket_option_value) override; - - int setSocketOption(int socket_option_key, - auth::CryptoHashType socket_option_value) override; - - int setSocketOption(int socket_option_key, - const std::string &socket_option_value) override; - - using ProducerSocket::getSocketOption; - // using ProducerSocket::onInterest; - - protected: - /* Callback invoked once an interest has been received and its payload - * decrypted */ - ProducerInterestCallback on_interest_input_decrypted_; - ProducerInterestCallback on_interest_process_decrypted_; - ProducerContentCallback on_content_produced_application_; - - private: - bool rtc_; - std::mutex mtx_; - /* Condition variable for the wait */ - std::condition_variable cv_; - X509 *cert_509_; - EVP_PKEY *pkey_rsa_; - std::unordered_map<core::Name, std::unique_ptr<TLSProducerSocket>, - core::hash<core::Name>, core::compare2<core::Name>> - map_producers; - std::list<std::unique_ptr<TLSProducerSocket>> list_producers; - - void onInterestCallback(interface::ProducerSocket &p, Interest &interest); - - void initSessionSocket(std::unique_ptr<TLSProducerSocket> &producer); -}; - -} // namespace implementation -} // namespace transport diff --git a/libtransport/src/implementation/socket.cc b/libtransport/src/implementation/socket.cc index 95941da07..b80fbb58c 100644 --- a/libtransport/src/implementation/socket.cc +++ b/libtransport/src/implementation/socket.cc @@ -23,7 +23,9 @@ namespace implementation { Socket::Socket(std::shared_ptr<core::Portal> &&portal) : portal_(std::move(portal)), is_async_(false), - packet_format_(interface::default_values::packet_format) {} + packet_format_(interface::default_values::packet_format), + signer_(std::make_shared<auth::VoidSigner>()), + verifier_(std::make_shared<auth::VoidVerifier>()) {} int Socket::setSocketOption(int socket_option_key, hicn_format_t packet_format) { diff --git a/libtransport/src/implementation/socket.h b/libtransport/src/implementation/socket.h index 11c9a704d..3eb93cff6 100644 --- a/libtransport/src/implementation/socket.h +++ b/libtransport/src/implementation/socket.h @@ -16,6 +16,8 @@ #pragma once #include <core/facade.h> +#include <hicn/transport/auth/signer.h> +#include <hicn/transport/auth/verifier.h> #include <hicn/transport/config.h> #include <hicn/transport/interfaces/callbacks.h> #include <hicn/transport/interfaces/socket_options_default_values.h> @@ -68,6 +70,8 @@ class Socket { std::shared_ptr<core::Portal> portal_; bool is_async_; hicn_format_t packet_format_; + std::shared_ptr<auth::Signer> signer_; + std::shared_ptr<auth::Verifier> verifier_; }; } // namespace implementation diff --git a/libtransport/src/implementation/socket_consumer.h b/libtransport/src/implementation/socket_consumer.h index 33e70888f..4721f426c 100644 --- a/libtransport/src/implementation/socket_consumer.h +++ b/libtransport/src/implementation/socket_consumer.h @@ -56,8 +56,8 @@ class ConsumerSocket : public Socket { rate_estimation_observer_(nullptr), rate_estimation_batching_parameter_(default_values::batch), rate_estimation_choice_(0), - unverified_interval_(default_values::unverified_interval), - unverified_ratio_(default_values::unverified_ratio), + manifest_factor_relevant_(default_values::manifest_factor_relevant), + manifest_factor_alert_(default_values::manifest_factor_alert), verifier_(std::make_shared<auth::VoidVerifier>()), verify_signature_(false), reset_window_(false), @@ -72,6 +72,8 @@ class ConsumerSocket : public Socket { timer_interval_milliseconds_(0), recovery_strategy_(RtcTransportRecoveryStrategies::RTX_ONLY), aggregated_data_(false), + content_sharing_mode_(false), + aggregated_interests_(false), guard_raaqm_params_() { switch (protocol) { case TransportProtocolAlgorithms::CBR: @@ -197,10 +199,6 @@ class ConsumerSocket : public Socket { current_window_size_ = socket_option_value; break; - case UNVERIFIED_RATIO: - unverified_ratio_ = socket_option_value; - break; - case GAMMA_VALUE: gamma_ = socket_option_value; break; @@ -242,10 +240,6 @@ class ConsumerSocket : public Socket { interest_lifetime_ = socket_option_value; break; - case GeneralTransportOptions::UNVERIFIED_INTERVAL: - unverified_interval_ = socket_option_value; - break; - case RateEstimationOptions::RATE_ESTIMATION_BATCH_PARAMETER: if (socket_option_value > 0) { rate_estimation_batching_parameter_ = socket_option_value; @@ -271,6 +265,14 @@ class ConsumerSocket : public Socket { (RtcTransportRecoveryStrategies)socket_option_value; break; + case MANIFEST_FACTOR_RELEVANT: + manifest_factor_relevant_ = socket_option_value; + break; + + case MANIFEST_FACTOR_ALERT: + manifest_factor_alert_ = socket_option_value; + break; + default: return SOCKET_OPTION_NOT_SET; } @@ -339,6 +341,16 @@ class ConsumerSocket : public Socket { result = SOCKET_OPTION_SET; break; + case RtcTransportOptions::CONTENT_SHARING_MODE: + content_sharing_mode_ = socket_option_value; + result = SOCKET_OPTION_SET; + break; + + case RtcTransportOptions::AGGREGATED_INTERESTS: + aggregated_interests_ = socket_option_value; + result = SOCKET_OPTION_SET; + break; + default: return result; } @@ -416,6 +428,22 @@ class ConsumerSocket : public Socket { int setSocketOption( int socket_option_key, + const std::shared_ptr<auth::Signer> &socket_option_value) { + if (!transport_protocol_->isRunning()) { + switch (socket_option_key) { + case GeneralTransportOptions::SIGNER: + signer_.reset(); + signer_ = socket_option_value; + break; + default: + return SOCKET_OPTION_NOT_SET; + } + } + return SOCKET_OPTION_SET; + } + + int setSocketOption( + int socket_option_key, const std::shared_ptr<auth::Verifier> &socket_option_value) { if (!transport_protocol_->isRunning()) { switch (socket_option_key) { @@ -506,10 +534,6 @@ class ConsumerSocket : public Socket { socket_option_value = current_window_size_; break; - case GeneralTransportOptions::UNVERIFIED_RATIO: - socket_option_value = unverified_ratio_; - break; - // RAAQM parameters case RaaqmTransportOptions::GAMMA_VALUE: @@ -550,10 +574,6 @@ class ConsumerSocket : public Socket { socket_option_value = interest_lifetime_; break; - case GeneralTransportOptions::UNVERIFIED_INTERVAL: - socket_option_value = unverified_interval_; - break; - case RaaqmTransportOptions::SAMPLE_NUMBER: socket_option_value = sample_number_; break; @@ -574,6 +594,14 @@ class ConsumerSocket : public Socket { socket_option_value = recovery_strategy_; break; + case GeneralTransportOptions::MANIFEST_FACTOR_RELEVANT: + socket_option_value = manifest_factor_relevant_; + break; + + case GeneralTransportOptions::MANIFEST_FACTOR_ALERT: + socket_option_value = manifest_factor_alert_; + break; + default: return SOCKET_OPTION_NOT_GET; } @@ -599,6 +627,14 @@ class ConsumerSocket : public Socket { socket_option_value = aggregated_data_; break; + case RtcTransportOptions::CONTENT_SHARING_MODE: + socket_option_value = content_sharing_mode_; + break; + + case RtcTransportOptions::AGGREGATED_INTERESTS: + socket_option_value = aggregated_interests_; + break; + default: return SOCKET_OPTION_NOT_GET; } @@ -689,6 +725,18 @@ class ConsumerSocket : public Socket { } int getSocketOption(int socket_option_key, + std::shared_ptr<auth::Signer> &socket_option_value) { + switch (socket_option_key) { + case GeneralTransportOptions::SIGNER: + socket_option_value = signer_; + return SOCKET_OPTION_GET; + + default: + return SOCKET_OPTION_NOT_GET; + } + } + + int getSocketOption(int socket_option_key, std::shared_ptr<auth::Verifier> &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::VERIFIER: @@ -827,8 +875,8 @@ class ConsumerSocket : public Socket { int rate_estimation_choice_; // Verification parameters - uint32_t unverified_interval_; - double unverified_ratio_; + uint32_t manifest_factor_relevant_; + uint32_t manifest_factor_alert_; std::shared_ptr<auth::Verifier> verifier_; transport::auth::KeyId *key_id_; std::atomic_bool verify_signature_; @@ -856,6 +904,8 @@ class ConsumerSocket : public Socket { // RTC protocol RtcTransportRecoveryStrategies recovery_strategy_; bool aggregated_data_; + bool content_sharing_mode_; + bool aggregated_interests_; utils::SpinLock guard_raaqm_params_; std::string output_interface_; diff --git a/libtransport/src/implementation/socket_producer.h b/libtransport/src/implementation/socket_producer.h index 37151d497..53ce28766 100644 --- a/libtransport/src/implementation/socket_producer.h +++ b/libtransport/src/implementation/socket_producer.h @@ -51,9 +51,8 @@ class ProducerSocket : public Socket { data_packet_size_(default_values::content_object_packet_size), max_segment_size_(default_values::content_object_packet_size), content_object_expiry_time_(default_values::content_object_expiry_time), - making_manifest_(default_values::manifest_capacity), + manifest_max_capacity_(default_values::manifest_max_capacity), hash_algorithm_(auth::CryptoHashType::SHA256), - signer_(std::make_shared<auth::VoidSigner>()), suffix_strategy_(std::make_shared<utils::IncrementalSuffixStrategy>(0)), aggregated_data_(false), fec_setting_(""), @@ -181,8 +180,8 @@ class ProducerSocket : public Socket { } break; - case GeneralTransportOptions::MAKE_MANIFEST: - making_manifest_ = socket_option_value; + case GeneralTransportOptions::MANIFEST_MAX_CAPACITY: + manifest_max_capacity_ = socket_option_value; break; case GeneralTransportOptions::MAX_SEGMENT_SIZE: @@ -433,6 +432,20 @@ class ProducerSocket : public Socket { return SOCKET_OPTION_SET; } + virtual int setSocketOption( + int socket_option_key, + const std::shared_ptr<auth::Verifier> &socket_option_value) { + switch (socket_option_key) { + case GeneralTransportOptions::VERIFIER: + verifier_.reset(); + verifier_ = socket_option_value; + return SOCKET_OPTION_SET; + + default: + return SOCKET_OPTION_NOT_SET; + } + } + int getSocketOption(int socket_option_key, ProducerCallback **socket_option_value) { // Reschedule the function on the io_service to avoid race condition in @@ -456,12 +469,13 @@ class ProducerSocket : public Socket { virtual int getSocketOption(int socket_option_key, uint32_t &socket_option_value) { switch (socket_option_key) { - case GeneralTransportOptions::MAKE_MANIFEST: - socket_option_value = making_manifest_; + case GeneralTransportOptions::MANIFEST_MAX_CAPACITY: + socket_option_value = (uint32_t)manifest_max_capacity_; break; case GeneralTransportOptions::OUTPUT_BUFFER_SIZE: - socket_option_value = production_protocol_->getOutputBufferSize(); + socket_option_value = + (uint32_t)production_protocol_->getOutputBufferSize(); break; case GeneralTransportOptions::DATA_PACKET_SIZE: @@ -636,6 +650,18 @@ class ProducerSocket : public Socket { return SOCKET_OPTION_GET; } + int getSocketOption(int socket_option_key, + std::shared_ptr<auth::Verifier> &socket_option_value) { + switch (socket_option_key) { + case GeneralTransportOptions::VERIFIER: + socket_option_value = verifier_; + return SOCKET_OPTION_GET; + + default: + return SOCKET_OPTION_NOT_GET; + } + } + int getSocketOption(int socket_option_key, std::string &socket_option_value) { switch (socket_option_key) { case GeneralTransportOptions::FEC_TYPE: @@ -736,11 +762,10 @@ class ProducerSocket : public Socket { std::atomic<size_t> max_segment_size_; std::atomic<uint32_t> content_object_expiry_time_; - std::atomic<uint32_t> making_manifest_; + std::atomic<uint32_t> manifest_max_capacity_; std::atomic<auth::CryptoHashType> hash_algorithm_; std::atomic<auth::CryptoSuite> crypto_suite_; utils::SpinLock signer_lock_; - std::shared_ptr<auth::Signer> signer_; std::shared_ptr<utils::SuffixStrategy> suffix_strategy_; std::shared_ptr<protocol::ProductionProtocol> production_protocol_; diff --git a/libtransport/src/implementation/tls_rtc_socket_producer.cc b/libtransport/src/implementation/tls_rtc_socket_producer.cc deleted file mode 100644 index 06d613ef0..000000000 --- a/libtransport/src/implementation/tls_rtc_socket_producer.cc +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <hicn/transport/core/interest.h> -#include <hicn/transport/interfaces/p2psecure_socket_producer.h> -#include <implementation/p2psecure_socket_producer.h> -#include <implementation/tls_rtc_socket_producer.h> -#include <openssl/bio.h> -#include <openssl/rand.h> -#include <openssl/ssl.h> - -namespace transport { -namespace implementation { - -int TLSRTCProducerSocket::read(BIO *b, char *buf, size_t size, - size_t *readbytes) { - int ret; - - if (size > INT_MAX) size = INT_MAX; - - ret = TLSRTCProducerSocket::readOld(b, buf, (int)size); - - if (ret <= 0) { - *readbytes = 0; - return ret; - } - - *readbytes = (size_t)ret; - - return 1; -} - -int TLSRTCProducerSocket::readOld(BIO *b, char *buf, int size) { - TLSRTCProducerSocket *socket; - socket = (TLSRTCProducerSocket *)BIO_get_data(b); - - std::unique_lock<std::mutex> lck(socket->mtx_); - if (!socket->something_to_read_) { - (socket->cv_).wait(lck); - } - - utils::MemBuf *membuf = socket->handshake_packet_->next(); - int size_to_read; - - if ((int)membuf->length() > size) { - size_to_read = size; - } else { - size_to_read = membuf->length(); - socket->something_to_read_ = false; - } - - std::memcpy(buf, membuf->data(), size_to_read); - membuf->trimStart(size_to_read); - - return size_to_read; -} - -int TLSRTCProducerSocket::write(BIO *b, const char *buf, size_t size, - size_t *written) { - int ret; - - if (size > INT_MAX) size = INT_MAX; - - ret = TLSRTCProducerSocket::writeOld(b, buf, (int)size); - - if (ret <= 0) { - *written = 0; - return ret; - } - - *written = (size_t)ret; - - return 1; -} - -int TLSRTCProducerSocket::writeOld(BIO *b, const char *buf, int num) { - TLSRTCProducerSocket *socket; - socket = (TLSRTCProducerSocket *)BIO_get_data(b); - - if (socket->getHandshakeState() != SERVER_FINISHED && socket->first_) { - uint32_t making_manifest = socket->parent_->making_manifest_; - - socket->tls_chunks_--; - socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, - 0U); - socket->parent_->ProducerSocket::produce( - socket->name_, (const uint8_t *)buf, num, socket->tls_chunks_ == 0, 0); - socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, - making_manifest); - socket->first_ = false; - - } else { - std::unique_ptr<utils::MemBuf> mbuf = - utils::MemBuf::copyBuffer(buf, (std::size_t)num, 0, 0); - auto a = mbuf.release(); - - socket->async_thread_.add([socket = socket, a]() { - socket->to_call_oncontentproduced_--; - auto mbuf = std::unique_ptr<utils::MemBuf>(a); - - socket->RTCProducerSocket::produce(std::move(mbuf)); - - ProducerContentCallback on_content_produced_application; - socket->getSocketOption(ProducerCallbacksOptions::CONTENT_PRODUCED, - on_content_produced_application); - - if (socket->to_call_oncontentproduced_ == 0 && - on_content_produced_application) { - on_content_produced_application( - (transport::interface::ProducerSocket &)(*socket->getInterface()), - std::error_code(), 0); - } - }); - } - - return num; -} - -TLSRTCProducerSocket::TLSRTCProducerSocket( - interface::ProducerSocket *producer_socket, P2PSecureProducerSocket *parent, - const Name &handshake_name) - : ProducerSocket(producer_socket), - RTCProducerSocket(producer_socket), - TLSProducerSocket(producer_socket, parent, handshake_name) { - BIO_METHOD *bio_meth = - BIO_meth_new(BIO_TYPE_ACCEPT, "secure rtc producer socket"); - BIO_meth_set_read(bio_meth, TLSRTCProducerSocket::readOld); - BIO_meth_set_write(bio_meth, TLSRTCProducerSocket::writeOld); - BIO_meth_set_ctrl(bio_meth, TLSProducerSocket::ctrl); - BIO *bio = BIO_new(bio_meth); - BIO_set_init(bio, 1); - BIO_set_data(bio, this); - SSL_set_bio(ssl_, bio, bio); -} - -void TLSRTCProducerSocket::accept() { - HandshakeState handshake_state = getHandshakeState(); - - if (handshake_state == UNINITIATED || handshake_state == CLIENT_HELLO) { - tls_chunks_ = 1; - int result = SSL_accept(ssl_); - - if (result != 1) - throw errors::RuntimeException("Unable to perform client handshake"); - } - - DLOG_IF(INFO, VLOG_IS_ON(2)) << "Handshake performed!"; - - parent_->list_producers.push_front( - std::move(parent_->map_producers[handshake_name_])); - parent_->map_producers.erase(handshake_name_); - - ProducerInterestCallback on_interest_process_decrypted; - getSocketOption(ProducerCallbacksOptions::CACHE_MISS, - on_interest_process_decrypted); - - if (on_interest_process_decrypted) { - Interest inter(std::move(handshake_packet_)); - on_interest_process_decrypted( - (transport::interface::ProducerSocket &)(*getInterface()), inter); - } - - parent_->cv_.notify_one(); -} - -int TLSRTCProducerSocket::async_accept() { - if (!async_thread_.stopped()) { - async_thread_.add([this]() { this->TLSRTCProducerSocket::accept(); }); - } else { - throw errors::RuntimeException( - "Async thread not running, impossible to perform handshake"); - } - - return 1; -} - -void TLSRTCProducerSocket::produce(std::unique_ptr<utils::MemBuf> &&buffer) { - HandshakeState handshake_state = getHandshakeState(); - - if (handshake_state != SERVER_FINISHED) { - throw errors::RuntimeException( - "New handshake on the same P2P secure producer socket not supported"); - } - - size_t buf_size = buffer->length(); - tls_chunks_ = ceil((float)buf_size / (float)SSL3_RT_MAX_PLAIN_LENGTH); - to_call_oncontentproduced_ = tls_chunks_; - - SSL_write(ssl_, buffer->data(), buf_size); - BIO *wbio = SSL_get_wbio(ssl_); - int i = BIO_flush(wbio); - (void)i; // To shut up gcc 5 -} - -} // namespace implementation -} // namespace transport diff --git a/libtransport/src/implementation/tls_rtc_socket_producer.h b/libtransport/src/implementation/tls_rtc_socket_producer.h deleted file mode 100644 index f6dc425e4..000000000 --- a/libtransport/src/implementation/tls_rtc_socket_producer.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include <implementation/tls_socket_producer.h> - -namespace transport { -namespace implementation { - -class P2PSecureProducerSocket; - -class TLSRTCProducerSocket : public TLSProducerSocket { - friend class P2PSecureProducerSocket; - - public: - explicit TLSRTCProducerSocket(interface::ProducerSocket *producer_socket, - P2PSecureProducerSocket *parent, - const Name &handshake_name); - - ~TLSRTCProducerSocket() = default; - - uint32_t produceDatagram(const Name &content_name, - std::unique_ptr<utils::MemBuf> &&buffer) override; - - void accept() override; - - int async_accept() override; - - using TLSProducerSocket::onInterest; - using TLSProducerSocket::produce; - - protected: - static int read(BIO *b, char *buf, size_t size, size_t *readbytes); - - static int readOld(BIO *h, char *buf, int size); - - static int write(BIO *b, const char *buf, size_t size, size_t *written); - - static int writeOld(BIO *h, const char *buf, int num); -}; - -} // namespace implementation - -} // end namespace transport diff --git a/libtransport/src/implementation/tls_socket_consumer.cc b/libtransport/src/implementation/tls_socket_consumer.cc deleted file mode 100644 index b368c4b88..000000000 --- a/libtransport/src/implementation/tls_socket_consumer.cc +++ /dev/null @@ -1,343 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <implementation/tls_socket_consumer.h> -#include <openssl/bio.h> -#include <openssl/ssl.h> -#include <openssl/tls1.h> - -#include <random> - -namespace transport { -namespace implementation { - -void TLSConsumerSocket::setInterestPayload(interface::ConsumerSocket &c, - const core::Interest &interest) { - Interest &int2 = const_cast<Interest &>(interest); - random_suffix_ = int2.getName().getSuffix(); - - if (payload_ != NULL) int2.appendPayload(std::move(payload_)); -} - -/* Return the number of read bytes in the return param */ -int readOldTLS(BIO *b, char *buf, int size) { - if (size < 0) return size; - - TLSConsumerSocket *socket; - socket = (TLSConsumerSocket *)BIO_get_data(b); - - std::unique_lock<std::mutex> lck(socket->mtx_); - - if (!socket->something_to_read_) { - if (!socket->transport_protocol_->isRunning()) { - socket->network_name_.setSuffix(socket->random_suffix_); - socket->ConsumerSocket::consume(socket->network_name_); - } - - if (!socket->something_to_read_) socket->cv_.wait(lck); - } - - size_t size_to_read, read; - size_t chain_size = socket->head_->length(); - - if (socket->head_->isChained()) - chain_size = socket->head_->computeChainDataLength(); - - if (chain_size > (size_t)size) { - read = size_to_read = (size_t)size; - } else { - read = size_to_read = chain_size; - socket->something_to_read_ = false; - } - - while (size_to_read) { - if (socket->head_->length() < size_to_read) { - std::memcpy(buf, socket->head_->data(), socket->head_->length()); - size_to_read -= socket->head_->length(); - buf += socket->head_->length(); - socket->head_ = socket->head_->pop(); - } else { - std::memcpy(buf, socket->head_->data(), size_to_read); - socket->head_->trimStart(size_to_read); - size_to_read = 0; - } - } - - return (int)read; -} - -/* Return the number of read bytes in readbytes */ -int readTLS(BIO *b, char *buf, size_t size, size_t *readbytes) { - int ret; - - if (size > INT_MAX) size = INT_MAX; - - ret = readOldTLS(b, buf, (int)size); - - if (ret <= 0) { - *readbytes = 0; - return ret; - } - - *readbytes = (size_t)ret; - - return 1; -} - -/* Return the number of written bytes in the return param */ -int writeOldTLS(BIO *b, const char *buf, int num) { - TLSConsumerSocket *socket; - socket = (TLSConsumerSocket *)BIO_get_data(b); - - socket->payload_ = utils::MemBuf::copyBuffer(buf, num); - - socket->ConsumerSocket::setSocketOption( - ConsumerCallbacksOptions::INTEREST_OUTPUT, - (ConsumerInterestCallback)std::bind( - &TLSConsumerSocket::setInterestPayload, socket, std::placeholders::_1, - std::placeholders::_2)); - - return num; -} - -/* Return the number of written bytes in written */ -int writeTLS(BIO *b, const char *buf, size_t size, size_t *written) { - int ret; - - if (size > INT_MAX) size = INT_MAX; - - ret = writeOldTLS(b, buf, (int)size); - - if (ret <= 0) { - *written = 0; - return ret; - } - - *written = (size_t)ret; - - return 1; -} - -long ctrlTLS(BIO *b, int cmd, long num, void *ptr) { return 1; } - -TLSConsumerSocket::TLSConsumerSocket(interface::ConsumerSocket *consumer_socket, - int protocol, SSL *ssl) - : ConsumerSocket(consumer_socket, protocol), - name_(), - decrypted_content_(), - payload_(), - head_(), - something_to_read_(false), - content_downloaded_(false), - random_suffix_(), - producer_namespace_(), - read_callback_decrypted_(), - mtx_(), - cv_(), - async_downloader_tls_() { - /* Create the (d)TLS state */ - const SSL_METHOD *meth = TLS_client_method(); - ctx_ = SSL_CTX_new(meth); - - int result = - SSL_CTX_set_ciphersuites(ctx_, - "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_" - "SHA256:TLS_AES_128_GCM_SHA256"); - if (result != 1) { - throw errors::RuntimeException( - "Unable to set cipher list on TLS subsystem. Aborting."); - } - - SSL_CTX_set_min_proto_version(ctx_, TLS1_3_VERSION); - SSL_CTX_set_max_proto_version(ctx_, TLS1_3_VERSION); - SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, NULL); - SSL_CTX_set_ssl_version(ctx_, meth); - - ssl_ = ssl; - - BIO_METHOD *bio_meth = - BIO_meth_new(BIO_TYPE_CONNECT, "secure consumer socket"); - BIO_meth_set_read(bio_meth, readOldTLS); - BIO_meth_set_write(bio_meth, writeOldTLS); - BIO_meth_set_ctrl(bio_meth, ctrlTLS); - BIO *bio = BIO_new(bio_meth); - BIO_set_init(bio, 1); - BIO_set_data(bio, this); - SSL_set_bio(ssl_, bio, bio); - - std::default_random_engine generator; - std::uniform_int_distribution<int> distribution( - 1, std::numeric_limits<uint32_t>::max()); - random_suffix_ = 0; - - this->ConsumerSocket::setSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, - this); -}; - -/* The producer interface is not owned by the application, so is TLSSocket task - * to deallocate the memory */ -TLSConsumerSocket::~TLSConsumerSocket() { delete consumer_interface_; } - -int TLSConsumerSocket::consume(const Name &name, - std::unique_ptr<utils::MemBuf> &&buffer) { - this->payload_ = std::move(buffer); - - this->ConsumerSocket::setSocketOption( - ConsumerCallbacksOptions::INTEREST_OUTPUT, - (ConsumerInterestCallback)std::bind( - &TLSConsumerSocket::setInterestPayload, this, std::placeholders::_1, - std::placeholders::_2)); - - return consume(name); -} - -int TLSConsumerSocket::consume(const Name &name) { - if (transport_protocol_->isRunning()) { - return CONSUMER_BUSY; - } - - if ((SSL_in_before(this->ssl_) || SSL_in_init(this->ssl_))) { - throw errors::RuntimeException("Handshake not performed"); - } - - return download_content(name); -} - -int TLSConsumerSocket::download_content(const Name &name) { - network_name_ = name; - network_name_.setSuffix(0); - something_to_read_ = false; - content_downloaded_ = false; - - std::size_t max_buffer_size = read_callback_decrypted_->maxBufferSize(); - std::size_t buffer_size = - read_callback_decrypted_->maxBufferSize() + SSL3_RT_MAX_PLAIN_LENGTH; - decrypted_content_ = utils::MemBuf::createCombined(buffer_size); - int result = -1; - std::size_t size = 0; - - while (!content_downloaded_ || something_to_read_) { - result = SSL_read(this->ssl_, decrypted_content_->writableTail(), - SSL3_RT_MAX_PLAIN_LENGTH); - - /* SSL_read returns the data only if there were SSL3_RT_MAX_PLAIN_LENGTH of - * the data has been fully downloaded */ - - /* ASSERT((result < SSL3_RT_MAX_PLAIN_LENGTH && content_downloaded_) || */ - /* result == SSL3_RT_MAX_PLAIN_LENGTH); */ - - if (result >= 0) { - size += result; - decrypted_content_->append(result); - } else { - throw errors::RuntimeException("Unable to download content"); - } - - if (decrypted_content_->length() >= max_buffer_size) { - if (read_callback_decrypted_->isBufferMovable()) { - /* No need to perform an additional copy. The whole buffer will be - * tranferred to the application. */ - read_callback_decrypted_->readBufferAvailable( - std::move(decrypted_content_)); - decrypted_content_ = utils::MemBuf::create(buffer_size); - } else { - /* The buffer will be copied into the application-provided buffer */ - uint8_t *buffer; - std::size_t length; - std::size_t total_length = decrypted_content_->length(); - - while (decrypted_content_->length()) { - buffer = nullptr; - length = 0; - read_callback_decrypted_->getReadBuffer(&buffer, &length); - - if (!buffer || !length) { - throw errors::RuntimeException( - "Invalid buffer provided by the application."); - } - - auto to_copy = std::min(decrypted_content_->length(), length); - std::memcpy(buffer, decrypted_content_->data(), to_copy); - decrypted_content_->trimStart(to_copy); - } - - read_callback_decrypted_->readDataAvailable(total_length); - decrypted_content_->clear(); - } - } - } - - read_callback_decrypted_->readSuccess(size); - - return CONSUMER_FINISHED; -} - -void TLSConsumerSocket::registerPrefix(const Prefix &producer_namespace) { - producer_namespace_ = producer_namespace; -} - -int TLSConsumerSocket::setSocketOption(int socket_option_key, - ReadCallback *socket_option_value) { - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, ReadCallback *socket_option_value) -> int { - switch (socket_option_key) { - case ConsumerCallbacksOptions::READ_CALLBACK: - read_callback_decrypted_ = socket_option_value; - break; - default: - return SOCKET_OPTION_NOT_SET; - } - - return SOCKET_OPTION_SET; - }); -} - -void TLSConsumerSocket::getReadBuffer(uint8_t **application_buffer, - size_t *max_length) {} - -void TLSConsumerSocket::readDataAvailable(size_t length) noexcept {} - -size_t TLSConsumerSocket::maxBufferSize() const { - return SSL3_RT_MAX_PLAIN_LENGTH; -} - -void TLSConsumerSocket::readBufferAvailable( - std::unique_ptr<utils::MemBuf> &&buffer) noexcept { - std::unique_lock<std::mutex> lck(this->mtx_); - - if (head_) { - head_->prependChain(std::move(buffer)); - } else { - head_ = std::move(buffer); - } - - something_to_read_ = true; - cv_.notify_one(); -} - -void TLSConsumerSocket::readError(const std::error_code &ec) noexcept {} - -void TLSConsumerSocket::readSuccess(std::size_t total_size) noexcept { - std::unique_lock<std::mutex> lck(this->mtx_); - content_downloaded_ = true; - something_to_read_ = true; - cv_.notify_one(); -} - -bool TLSConsumerSocket::isBufferMovable() noexcept { return true; } - -} // namespace implementation -} // namespace transport diff --git a/libtransport/src/implementation/tls_socket_consumer.h b/libtransport/src/implementation/tls_socket_consumer.h deleted file mode 100644 index a74f1ee10..000000000 --- a/libtransport/src/implementation/tls_socket_consumer.h +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include <hicn/transport/interfaces/socket_consumer.h> -#include <implementation/socket_consumer.h> -#include <openssl/ssl.h> - -namespace transport { -namespace implementation { - -class TLSConsumerSocket : public ConsumerSocket, - public interface::ConsumerSocket::ReadCallback { - /* Return the number of read bytes in readbytes */ - friend int readTLS(BIO *b, char *buf, size_t size, size_t *readbytes); - - /* Return the number of read bytes in the return param */ - friend int readOldTLS(BIO *h, char *buf, int size); - - /* Return the number of written bytes in written */ - friend int writeTLS(BIO *b, const char *buf, size_t size, size_t *written); - - /* Return the number of written bytes in the return param */ - friend int writeOldTLS(BIO *h, const char *buf, int num); - - friend long ctrlTLS(BIO *b, int cmd, long num, void *ptr); - - public: - explicit TLSConsumerSocket(interface::ConsumerSocket *consumer_socket, - int protocol, SSL *ssl_); - - ~TLSConsumerSocket(); - - int consume(const Name &name, std::unique_ptr<utils::MemBuf> &&buffer); - int consume(const Name &name) override; - - void registerPrefix(const Prefix &producer_namespace); - - int setSocketOption( - int socket_option_key, - interface::ConsumerSocket::ReadCallback *socket_option_value) override; - - using ConsumerSocket::getSocketOption; - using ConsumerSocket::setSocketOption; - - protected: - /* Callback invoked once an interest has been received and its payload - * decrypted */ - ConsumerInterestCallback on_interest_input_decrypted_; - ConsumerInterestCallback on_interest_process_decrypted_; - - private: - Name name_; - /* SSL handle */ - SSL *ssl_; - SSL_CTX *ctx_; - /* Chain of MemBuf to be used as a temporary buffer to pass descypted data - * from the underlying layer to the application */ - std::unique_ptr<utils::MemBuf> decrypted_content_; - /* Chain of MemBuf holding the payload to be written into interest or data */ - std::unique_ptr<utils::MemBuf> payload_; - /* Chain of MemBuf holding the data retrieved from the underlying layer */ - std::unique_ptr<utils::MemBuf> head_; - bool something_to_read_; - bool content_downloaded_; - uint32_t random_suffix_; - Prefix producer_namespace_; - interface::ConsumerSocket::ReadCallback *read_callback_decrypted_; - std::mutex mtx_; - /* Condition variable for the wait */ - std::condition_variable cv_; - utils::EventThread async_downloader_tls_; - - void setInterestPayload(interface::ConsumerSocket &c, - const core::Interest &interest); - - virtual void getReadBuffer(uint8_t **application_buffer, - size_t *max_length) override; - - virtual void readDataAvailable(size_t length) noexcept override; - - virtual size_t maxBufferSize() const override; - - virtual void readBufferAvailable( - std::unique_ptr<utils::MemBuf> &&buffer) noexcept override; - - virtual void readError(const std::error_code &ec) noexcept override; - - virtual void readSuccess(std::size_t total_size) noexcept override; - - virtual bool isBufferMovable() noexcept override; - - int download_content(const Name &name); -}; - -} // namespace implementation -} // end namespace transport diff --git a/libtransport/src/implementation/tls_socket_producer.cc b/libtransport/src/implementation/tls_socket_producer.cc deleted file mode 100644 index 47f3b43a6..000000000 --- a/libtransport/src/implementation/tls_socket_producer.cc +++ /dev/null @@ -1,550 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <hicn/transport/interfaces/socket_producer.h> -#include <implementation/p2psecure_socket_producer.h> -#include <implementation/tls_socket_producer.h> -#include <openssl/bio.h> -#include <openssl/rand.h> -#include <openssl/ssl.h> - -namespace transport { -namespace implementation { - -/* Return the number of read bytes in readbytes */ -int TLSProducerSocket::read(BIO *b, char *buf, size_t size, size_t *readbytes) { - int ret; - - if (size > INT_MAX) size = INT_MAX; - - ret = TLSProducerSocket::readOld(b, buf, (int)size); - - if (ret <= 0) { - *readbytes = 0; - return ret; - } - - *readbytes = (size_t)ret; - - return 1; -} - -/* Return the number of read bytes in the return param */ -int TLSProducerSocket::readOld(BIO *b, char *buf, int size) { - TLSProducerSocket *socket; - socket = (TLSProducerSocket *)BIO_get_data(b); - - std::unique_lock<std::mutex> lck(socket->mtx_); - - DLOG_IF(INFO, VLOG_IS_ON(4)) << "Start wait on the CV."; - - if (!socket->something_to_read_) { - (socket->cv_).wait(lck); - } - - DLOG_IF(INFO, VLOG_IS_ON(4)) << "CV unlocked."; - - /* Either there already is something to read, or the thread has been waken up. - * We must return the payload in the interest anyway */ - utils::MemBuf *membuf = socket->handshake_packet_->next(); - int size_to_read; - - if ((int)membuf->length() > size) { - size_to_read = size; - } else { - size_to_read = (int)membuf->length(); - socket->something_to_read_ = false; - } - - std::memcpy(buf, membuf->data(), size_to_read); - membuf->trimStart(size_to_read); - - return size_to_read; -} - -/* Return the number of written bytes in written */ -int TLSProducerSocket::write(BIO *b, const char *buf, size_t size, - size_t *written) { - int ret; - - if (size > INT_MAX) size = INT_MAX; - - ret = TLSProducerSocket::writeOld(b, buf, (int)size); - - if (ret <= 0) { - *written = 0; - return ret; - } - - *written = (size_t)ret; - - return 1; -} - -/* Return the number of written bytes in the return param */ -int TLSProducerSocket::writeOld(BIO *b, const char *buf, int num) { - TLSProducerSocket *socket; - socket = (TLSProducerSocket *)BIO_get_data(b); - - if (socket->getHandshakeState() != SERVER_FINISHED && socket->first_) { - uint32_t making_manifest = socket->parent_->making_manifest_; - - //! socket->tls_chunks_ corresponds to is_last - socket->tls_chunks_--; - socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, - 0U); - socket->parent_->ProducerSocket::produceStream( - socket->name_, (const uint8_t *)buf, num, socket->tls_chunks_ == 0, - socket->last_segment_); - socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, - making_manifest); - socket->first_ = false; - } else { - socket->still_writing_ = true; - - std::unique_ptr<utils::MemBuf> mbuf = - utils::MemBuf::copyBuffer(buf, (std::size_t)num, 0, 0); - auto a = mbuf.release(); - - socket->async_thread_.add([socket = socket, a]() { - auto mbuf = std::unique_ptr<utils::MemBuf>(a); - - socket->tls_chunks_--; - socket->to_call_oncontentproduced_--; - - socket->last_segment_ += socket->ProducerSocket::produceStream( - socket->name_, std::move(mbuf), socket->tls_chunks_ == 0, - socket->last_segment_); - - ProducerContentCallback *on_content_produced_application; - socket->getSocketOption(ProducerCallbacksOptions::CONTENT_PRODUCED, - &on_content_produced_application); - - if (socket->to_call_oncontentproduced_ == 0 && - on_content_produced_application) { - on_content_produced_application->operator()(*socket->getInterface(), - std::error_code(), 0); - } - }); - } - - return num; -} - -TLSProducerSocket::TLSProducerSocket(interface::ProducerSocket *producer_socket, - P2PSecureProducerSocket *parent, - const Name &handshake_name) - : ProducerSocket(producer_socket, - ProductionProtocolAlgorithms::BYTE_STREAM), - on_content_produced_application_(), - mtx_(), - cv_(), - something_to_read_(false), - handshake_state_(UNINITIATED), - name_(), - handshake_packet_(), - last_segment_(0), - parent_(parent), - first_(true), - handshake_name_(handshake_name), - tls_chunks_(0), - to_call_oncontentproduced_(0), - still_writing_(false), - encryption_thread_() { - const SSL_METHOD *meth = TLS_server_method(); - ctx_ = SSL_CTX_new(meth); - - /* Setup SSL context (identity and parameter to use TLS 1.3) */ - SSL_CTX_use_certificate(ctx_, parent->cert_509_); - SSL_CTX_use_PrivateKey(ctx_, parent->pkey_rsa_); - - int result = - SSL_CTX_set_ciphersuites(ctx_, - "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_" - "SHA256:TLS_AES_128_GCM_SHA256"); - - if (result != 1) { - throw errors::RuntimeException( - "Unable to set cipher list on TLS subsystem. Aborting."); - } - - // We force it to be TLS 1.3 - SSL_CTX_set_min_proto_version(ctx_, TLS1_3_VERSION); - SSL_CTX_set_max_proto_version(ctx_, TLS1_3_VERSION); - SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, NULL); - SSL_CTX_set_num_tickets(ctx_, 0); - - result = SSL_CTX_add_custom_ext( - ctx_, 100, SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS, - TLSProducerSocket::addHicnKeyIdCb, TLSProducerSocket::freeHicnKeyIdCb, - this, TLSProducerSocket::parseHicnKeyIdCb, NULL); - - ssl_ = SSL_new(ctx_); - - /* Setup this producer socker as the bio that TLS will use to write and read - * data (in stream mode) */ - BIO_METHOD *bio_meth = - BIO_meth_new(BIO_TYPE_ACCEPT, "secure producer socket"); - BIO_meth_set_read(bio_meth, TLSProducerSocket::readOld); - BIO_meth_set_write(bio_meth, TLSProducerSocket::writeOld); - BIO_meth_set_ctrl(bio_meth, TLSProducerSocket::ctrl); - BIO *bio = BIO_new(bio_meth); - BIO_set_init(bio, 1); - BIO_set_data(bio, this); - SSL_set_bio(ssl_, bio, bio); - - /* Set the callback so that when an interest is received we catch it and we - * decrypt the payload before passing it to the application. */ - this->ProducerSocket::setSocketOption( - ProducerCallbacksOptions::CACHE_MISS, - (ProducerInterestCallback)std::bind(&TLSProducerSocket::cacheMiss, this, - std::placeholders::_1, - std::placeholders::_2)); - - this->ProducerSocket::setSocketOption( - ProducerCallbacksOptions::CONTENT_PRODUCED, - (ProducerContentCallback)bind( - &TLSProducerSocket::onContentProduced, this, std::placeholders::_1, - std::placeholders::_2, std::placeholders::_3)); -} - -/* The producer interface is not owned by the application, so is TLSSocket task - * to deallocate the memory */ -TLSProducerSocket::~TLSProducerSocket() { delete producer_interface_; } - -void TLSProducerSocket::accept() { - HandshakeState handshake_state = getHandshakeState(); - - if (handshake_state == UNINITIATED || handshake_state == CLIENT_HELLO) { - tls_chunks_ = 1; - int result = SSL_accept(ssl_); - - if (result != 1) - throw errors::RuntimeException("Unable to perform client handshake"); - } - - parent_->list_producers.push_front( - std::move(parent_->map_producers[handshake_name_])); - parent_->map_producers.erase(handshake_name_); - - ProducerInterestCallback *on_interest_process_decrypted; - getSocketOption(ProducerCallbacksOptions::CACHE_MISS, - &on_interest_process_decrypted); - - if (*on_interest_process_decrypted) { - Interest inter(std::move(*handshake_packet_)); - handshake_packet_.reset(); - on_interest_process_decrypted->operator()(*getInterface(), inter); - } else { - throw errors::RuntimeException( - "On interest process unset: unable to perform handshake"); - } - - handshake_state_ = SERVER_FINISHED; - DLOG_IF(INFO, VLOG_IS_ON(2)) << "Handshake performed!"; -} - -int TLSProducerSocket::async_accept() { - if (!async_thread_.stopped()) { - async_thread_.add([this]() { this->accept(); }); - } else { - throw errors::RuntimeException( - "Async thread not running: unable to perform handshake"); - } - - return 1; -} - -void TLSProducerSocket::onInterest(ProducerSocket &p, Interest &interest) { - HandshakeState handshake_state = getHandshakeState(); - - if (handshake_state == UNINITIATED || handshake_state == CLIENT_HELLO) { - std::unique_lock<std::mutex> lck(mtx_); - - name_ = interest.getName(); - // interest.separateHeaderPayload(); - handshake_packet_ = interest.acquireMemBufReference(); - something_to_read_ = true; - - cv_.notify_one(); - return; - } else if (handshake_state == SERVER_FINISHED) { - // interest.separateHeaderPayload(); - handshake_packet_ = interest.acquireMemBufReference(); - something_to_read_ = true; - - if (interest.getPayload()->length() > 0) { - SSL_read( - ssl_, - const_cast<unsigned char *>(interest.getPayload()->writableData()), - (int)interest.getPayload()->length()); - } - - ProducerInterestCallback *on_interest_input_decrypted; - getSocketOption(ProducerCallbacksOptions::INTEREST_INPUT, - &on_interest_input_decrypted); - - if (*on_interest_input_decrypted) - (*on_interest_input_decrypted)(*getInterface(), interest); - } -} - -void TLSProducerSocket::cacheMiss(interface::ProducerSocket &p, - Interest &interest) { - HandshakeState handshake_state = getHandshakeState(); - - DLOG_IF(INFO, VLOG_IS_ON(3)) << "On cache miss in TLS socket producer."; - - if (handshake_state == CLIENT_HELLO) { - std::unique_lock<std::mutex> lck(mtx_); - - // interest.separateHeaderPayload(); - handshake_packet_ = interest.acquireMemBufReference(); - something_to_read_ = true; - handshake_state_ = CLIENT_FINISHED; - - cv_.notify_one(); - } else if (handshake_state == SERVER_FINISHED) { - // interest.separateHeaderPayload(); - handshake_packet_ = interest.acquireMemBufReference(); - something_to_read_ = true; - - if (interest.getPayload()->length() > 0) { - SSL_read( - ssl_, - const_cast<unsigned char *>(interest.getPayload()->writableData()), - (int)interest.getPayload()->length()); - } - - if (on_interest_process_decrypted_ != VOID_HANDLER) - on_interest_process_decrypted_(*getInterface(), interest); - } -} - -TLSProducerSocket::HandshakeState TLSProducerSocket::getHandshakeState() { - if (SSL_in_before(ssl_)) { - handshake_state_ = UNINITIATED; - } - - if (SSL_in_init(ssl_) && handshake_state_ == UNINITIATED) { - handshake_state_ = CLIENT_HELLO; - } - - return handshake_state_; -} - -void TLSProducerSocket::onContentProduced(interface::ProducerSocket &p, - const std::error_code &err, - uint64_t bytes_written) {} - -uint32_t TLSProducerSocket::produceStream( - const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last, uint32_t start_offset) { - if (getHandshakeState() != SERVER_FINISHED) { - throw errors::RuntimeException( - "New handshake on the same P2P secure producer socket not supported"); - } - - size_t buf_size = buffer->length(); - name_ = portal_->getServedNamespaces().begin()->mapName(content_name); - tls_chunks_ = to_call_oncontentproduced_ = - (int)ceil((float)buf_size / (float)SSL3_RT_MAX_PLAIN_LENGTH); - - if (!is_last) { - tls_chunks_++; - } - - last_segment_ = start_offset; - - SSL_write(ssl_, buffer->data(), (int)buf_size); - BIO *wbio = SSL_get_wbio(ssl_); - int i = BIO_flush(wbio); - (void)i; // To shut up gcc 5 - - return 0; -} - -long TLSProducerSocket::ctrl(BIO *b, int cmd, long num, void *ptr) { - if (cmd == BIO_CTRL_FLUSH) { - } - - return 1; -} - -int TLSProducerSocket::addHicnKeyIdCb(SSL *s, unsigned int ext_type, - unsigned int context, - const unsigned char **out, size_t *outlen, - X509 *x, size_t chainidx, int *al, - void *add_arg) { - TLSProducerSocket *socket = reinterpret_cast<TLSProducerSocket *>(add_arg); - - DLOG_IF(INFO, VLOG_IS_ON(3)) - << "On addHicnKeyIdCb, for the prefix registration."; - - if (ext_type == 100) { - auto &prefix = *socket->parent_->portal_->getServedNamespaces().begin(); - const ip_prefix_t &ip_prefix = prefix.toIpPrefixStruct(); - int inet_family = prefix.getAddressFamily(); - uint16_t prefix_len_bits = prefix.getPrefixLength(); - uint8_t prefix_len_bytes = prefix_len_bits / 8; - uint8_t prefix_len_u32 = prefix_len_bits / 32; - - ip_prefix_t *out_ip = (ip_prefix_t *)malloc(sizeof(ip_prefix_t)); - out_ip->family = inet_family; - out_ip->len = prefix_len_bits + 32; - u8 *out_ip_buf = const_cast<u8 *>( - ip_address_get_buffer(&(out_ip->address), inet_family)); - *out = reinterpret_cast<unsigned char *>(out_ip); - - RAND_bytes((unsigned char *)&socket->key_id_, 4); - - memcpy(out_ip_buf, ip_address_get_buffer(&(ip_prefix.address), inet_family), - prefix_len_bytes); - memcpy((out_ip_buf + prefix_len_bytes), &socket->key_id_, 4); - *outlen = sizeof(ip_prefix_t); - - ip_address_t mask = {}; - ip_address_t keyId_component = {}; - u32 *mask_buf; - u32 *keyId_component_buf; - - switch (inet_family) { - case AF_INET: - mask_buf = &(mask.v4.as_u32); - keyId_component_buf = &(keyId_component.v4.as_u32); - break; - case AF_INET6: - mask_buf = mask.v6.as_u32; - keyId_component_buf = keyId_component.v6.as_u32; - break; - default: - throw errors::RuntimeException("Unknown protocol"); - } - - if (prefix_len_bits > (inet_family == AF_INET6 ? IPV6_ADDR_LEN_BITS - 32 - : IPV4_ADDR_LEN_BITS - 32)) - throw errors::RuntimeException( - "Not enough space in the content name to add key_id"); - - mask_buf[prefix_len_u32] = 0xffffffff; - keyId_component_buf[prefix_len_u32] = socket->key_id_; - socket->last_segment_ = 0; - - socket->on_interest_process_decrypted_ = - socket->parent_->on_interest_process_decrypted_; - - socket->registerPrefix( - Prefix(prefix.getName(Name(inet_family, (uint8_t *)&mask), - Name(inet_family, (uint8_t *)&keyId_component), - prefix.getName()), - out_ip->len)); - socket->connect(); - } - return 1; -} - -void TLSProducerSocket::freeHicnKeyIdCb(SSL *s, unsigned int ext_type, - unsigned int context, - const unsigned char *out, - void *add_arg) { - free(const_cast<unsigned char *>(out)); -} - -int TLSProducerSocket::parseHicnKeyIdCb(SSL *s, unsigned int ext_type, - unsigned int context, - const unsigned char *in, size_t inlen, - X509 *x, size_t chainidx, int *al, - void *add_arg) { - return 1; -} - -int TLSProducerSocket::setSocketOption( - int socket_option_key, ProducerInterestCallback socket_option_value) { - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ProducerInterestCallback socket_option_value) -> int { - int result = SOCKET_OPTION_SET; - - switch (socket_option_key) { - case ProducerCallbacksOptions::INTEREST_INPUT: - on_interest_input_decrypted_ = socket_option_value; - break; - - case ProducerCallbacksOptions::INTEREST_DROP: - on_interest_dropped_input_buffer_ = socket_option_value; - break; - - case ProducerCallbacksOptions::INTEREST_PASS: - on_interest_inserted_input_buffer_ = socket_option_value; - break; - - case ProducerCallbacksOptions::CACHE_HIT: - on_interest_satisfied_output_buffer_ = socket_option_value; - break; - - case ProducerCallbacksOptions::CACHE_MISS: - on_interest_process_decrypted_ = socket_option_value; - break; - - default: - result = SOCKET_OPTION_NOT_SET; - break; - } - - return result; - }); -} - -int TLSProducerSocket::setSocketOption( - int socket_option_key, ProducerContentCallback socket_option_value) { - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ProducerContentCallback socket_option_value) -> int { - switch (socket_option_key) { - case ProducerCallbacksOptions::CONTENT_PRODUCED: - on_content_produced_application_ = socket_option_value; - break; - - default: - return SOCKET_OPTION_NOT_SET; - } - - return SOCKET_OPTION_SET; - }); -} - -int TLSProducerSocket::getSocketOption( - int socket_option_key, ProducerContentCallback **socket_option_value) { - return rescheduleOnIOService( - socket_option_key, socket_option_value, - [this](int socket_option_key, - ProducerContentCallback **socket_option_value) -> int { - switch (socket_option_key) { - case ProducerCallbacksOptions::CONTENT_PRODUCED: - *socket_option_value = &on_content_produced_application_; - break; - - default: - return SOCKET_OPTION_NOT_GET; - } - - return SOCKET_OPTION_GET; - }); -} - -} // namespace implementation -} // namespace transport diff --git a/libtransport/src/implementation/tls_socket_producer.h b/libtransport/src/implementation/tls_socket_producer.h deleted file mode 100644 index 0e958b321..000000000 --- a/libtransport/src/implementation/tls_socket_producer.h +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include <implementation/socket_producer.h> -#include <openssl/ssl.h> - -#include <condition_variable> -#include <mutex> - -namespace transport { -namespace implementation { - -class P2PSecureProducerSocket; - -class TLSProducerSocket : virtual public ProducerSocket { - friend class P2PSecureProducerSocket; - - public: - explicit TLSProducerSocket(interface::ProducerSocket *producer_socket, - P2PSecureProducerSocket *parent, - const Name &handshake_name); - - ~TLSProducerSocket(); - - uint32_t produceStream(const Name &content_name, const uint8_t *buffer, - size_t buffer_size, bool is_last = true, - uint32_t start_offset = 0) override { - return produceStream(content_name, - utils::MemBuf::copyBuffer(buffer, buffer_size), - is_last, start_offset); - } - - uint32_t produceStream(const Name &content_name, - std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last = true, - uint32_t start_offset = 0) override; - - virtual void accept(); - - virtual int async_accept(); - - virtual int setSocketOption( - int socket_option_key, - ProducerInterestCallback socket_option_value) override; - - virtual int setSocketOption( - int socket_option_key, - ProducerContentCallback socket_option_value) override; - - virtual int getSocketOption( - int socket_option_key, - ProducerContentCallback **socket_option_value) override; - - int getSocketOption(int socket_option_key, - ProducerContentCallback &socket_option_value); - - int getSocketOption(int socket_option_key, - ProducerInterestCallback &socket_option_value); - - using ProducerSocket::getSocketOption; - // using ProducerSocket::onInterest; - using ProducerSocket::setSocketOption; - - protected: - enum HandshakeState { - UNINITIATED, - CLIENT_HELLO, // when CLIENT_HELLO interest has been received - CLIENT_FINISHED, // when CLIENT_FINISHED interest has been received - SERVER_FINISHED, // when handshake is done - }; - /* Callback invoked once an interest has been received and its payload - * decrypted */ - ProducerInterestCallback on_interest_input_decrypted_; - ProducerInterestCallback on_interest_process_decrypted_; - ProducerContentCallback on_content_produced_application_; - std::mutex mtx_; - /* Condition variable for the wait */ - std::condition_variable cv_; - /* Bool variable, true if there is something to read (an interest arrived) */ - bool something_to_read_; - /* Bool variable, true if CLIENT_FINISHED interest has been received */ - HandshakeState handshake_state_; - /* First interest that open a secure connection */ - transport::core::Name name_; - /* SSL handle */ - SSL *ssl_; - SSL_CTX *ctx_; - Packet::MemBufPtr handshake_packet_; - std::unique_ptr<utils::MemBuf> head_; - std::uint32_t last_segment_; - std::uint32_t key_id_; - std::thread *handshake; - P2PSecureProducerSocket *parent_; - bool first_; - Name handshake_name_; - int tls_chunks_; - int to_call_oncontentproduced_; - bool still_writing_; - utils::EventThread encryption_thread_; - utils::EventThread async_thread_; - - void onInterest(ProducerSocket &p, Interest &interest); - - void cacheMiss(interface::ProducerSocket &p, Interest &interest); - - /* Return the number of read bytes in readbytes */ - static int read(BIO *b, char *buf, size_t size, size_t *readbytes); - - /* Return the number of read bytes in the return param */ - static int readOld(BIO *h, char *buf, int size); - - /* Return the number of written bytes in written */ - static int write(BIO *b, const char *buf, size_t size, size_t *written); - - /* Return the number of written bytes in the return param */ - static int writeOld(BIO *h, const char *buf, int num); - - static long ctrl(BIO *b, int cmd, long num, void *ptr); - - static int addHicnKeyIdCb(SSL *s, unsigned int ext_type, unsigned int context, - const unsigned char **out, size_t *outlen, X509 *x, - size_t chainidx, int *al, void *add_arg); - - static void freeHicnKeyIdCb(SSL *s, unsigned int ext_type, - unsigned int context, const unsigned char *out, - void *add_arg); - - static int parseHicnKeyIdCb(SSL *s, unsigned int ext_type, - unsigned int context, const unsigned char *in, - size_t inlen, X509 *x, size_t chainidx, int *al, - void *add_arg); - - void onContentProduced(interface::ProducerSocket &p, - const std::error_code &err, uint64_t bytes_written); - - HandshakeState getHandshakeState(); -}; - -} // namespace implementation -} // end namespace transport diff --git a/libtransport/src/interfaces/CMakeLists.txt b/libtransport/src/interfaces/CMakeLists.txt index 0a0603ac8..bf8f8dcf8 100644 --- a/libtransport/src/interfaces/CMakeLists.txt +++ b/libtransport/src/interfaces/CMakeLists.txt @@ -19,20 +19,4 @@ list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.cc ) -if (${OPENSSL_VERSION} VERSION_EQUAL "1.1.1a" OR ${OPENSSL_VERSION} VERSION_GREATER "1.1.1a") - list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.cc - # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.cc - ) - - list(APPEND HEADER_FILES - # ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h - ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.h - ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.h - ) -endif() - set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) diff --git a/libtransport/src/interfaces/global_configuration.cc b/libtransport/src/interfaces/global_configuration.cc index cecdacc07..f8e7a90c3 100644 --- a/libtransport/src/interfaces/global_configuration.cc +++ b/libtransport/src/interfaces/global_configuration.cc @@ -14,6 +14,7 @@ */ #include <core/global_configuration.h> +#include <core/global_workers.h> #include <glog/logging.h> #include <hicn/transport/interfaces/global_conf_interface.h> @@ -23,10 +24,29 @@ namespace transport { namespace interface { namespace global_config { -void parseConfigurationFile(const std::string& path) { +GlobalConfigInterface::GlobalConfigInterface() { libtransportConfigInit(); } + +GlobalConfigInterface::~GlobalConfigInterface() { + libtransportConfigTerminate(); +} + +void GlobalConfigInterface::parseConfigurationFile( + const std::string &path) const { core::GlobalConfiguration::getInstance().parseConfiguration(path); } +void GlobalConfigInterface::libtransportConfigInit() const { + // nothing to do +} + +void GlobalConfigInterface::libtransportConfigTerminate() const { + // cleanup workers + auto &workers = core::GlobalWorkers::getInstance().getWorkers(); + for (auto &worker : workers) { + worker.stop(); + } +} + void ConfigurationObject::get() { std::error_code ec; core::GlobalConfiguration::getInstance().getConfiguration(*this, ec); diff --git a/libtransport/src/interfaces/p2psecure_socket_consumer.cc b/libtransport/src/interfaces/p2psecure_socket_consumer.cc deleted file mode 100644 index e329a50f1..000000000 --- a/libtransport/src/interfaces/p2psecure_socket_consumer.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <hicn/transport/interfaces/p2psecure_socket_consumer.h> -#include <implementation/p2psecure_socket_consumer.h> - -namespace transport { -namespace interface { - -P2PSecureConsumerSocket::P2PSecureConsumerSocket(int handshake_protocol, - int transport_protocol) - : ConsumerSocket() { - socket_ = std::unique_ptr<implementation::ConsumerSocket>( - new implementation::P2PSecureConsumerSocket(this, handshake_protocol, - transport_protocol)); -} - -void P2PSecureConsumerSocket::registerPrefix(const Prefix &producer_namespace) { - implementation::P2PSecureConsumerSocket &secure_consumer_socket = - *(static_cast<implementation::P2PSecureConsumerSocket *>(socket_.get())); - secure_consumer_socket.registerPrefix(producer_namespace); -} - -} // namespace interface -} // namespace transport diff --git a/libtransport/src/interfaces/p2psecure_socket_producer.cc b/libtransport/src/interfaces/p2psecure_socket_producer.cc deleted file mode 100644 index 5f98302d0..000000000 --- a/libtransport/src/interfaces/p2psecure_socket_producer.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <hicn/transport/interfaces/p2psecure_socket_producer.h> -#include <implementation/p2psecure_socket_producer.h> - -namespace transport { -namespace interface { - -P2PSecureProducerSocket::P2PSecureProducerSocket() { - socket_ = std::make_unique<implementation::P2PSecureProducerSocket>(this); -} - -P2PSecureProducerSocket::P2PSecureProducerSocket(bool rtc, - std::string &keystore_path, - std::string &keystore_pwd) { - socket_ = std::make_unique<implementation::P2PSecureProducerSocket>( - this, rtc, keystore_path, keystore_pwd); -} - -} // namespace interface -} // namespace transport diff --git a/libtransport/src/interfaces/portal.cc b/libtransport/src/interfaces/portal.cc index 84634a282..898766c50 100644 --- a/libtransport/src/interfaces/portal.cc +++ b/libtransport/src/interfaces/portal.cc @@ -35,10 +35,10 @@ class Portal::Impl { return portal_->interestIsPending(name); } - void sendInterest(core::Interest::Ptr &&interest, + void sendInterest(core::Interest::Ptr &interest, uint32_t lifetime, OnContentObjectCallback &&on_content_object_callback, OnInterestTimeoutCallback &&on_interest_timeout_callback) { - portal_->sendInterest(std::move(interest), + portal_->sendInterest(interest, lifetime, std::move(on_content_object_callback), std::move(on_interest_timeout_callback)); } @@ -86,10 +86,10 @@ bool Portal::interestIsPending(const core::Name &name) { } void Portal::sendInterest( - core::Interest::Ptr &&interest, + core::Interest::Ptr &interest, uint32_t lifetime, OnContentObjectCallback &&on_content_object_callback, OnInterestTimeoutCallback &&on_interest_timeout_callback) { - implementation_->sendInterest(std::move(interest), + implementation_->sendInterest(interest, lifetime, std::move(on_content_object_callback), std::move(on_interest_timeout_callback)); } diff --git a/libtransport/src/interfaces/socket_consumer.cc b/libtransport/src/interfaces/socket_consumer.cc index 747dc0974..cc496c8a5 100644 --- a/libtransport/src/interfaces/socket_consumer.cc +++ b/libtransport/src/interfaces/socket_consumer.cc @@ -102,6 +102,12 @@ int ConsumerSocket::setSocketOption(int socket_option_key, int ConsumerSocket::setSocketOption( int socket_option_key, + const std::shared_ptr<auth::Signer> &socket_option_value) { + return socket_->setSocketOption(socket_option_key, socket_option_value); +} + +int ConsumerSocket::setSocketOption( + int socket_option_key, const std::shared_ptr<auth::Verifier> &socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); } @@ -163,6 +169,11 @@ int ConsumerSocket::getSocketOption(int socket_option_key, } int ConsumerSocket::getSocketOption( + int socket_option_key, std::shared_ptr<auth::Signer> &socket_option_value) { + return socket_->getSocketOption(socket_option_key, socket_option_value); +} + +int ConsumerSocket::getSocketOption( int socket_option_key, std::shared_ptr<auth::Verifier> &socket_option_value) { return socket_->getSocketOption(socket_option_key, socket_option_value); diff --git a/libtransport/src/interfaces/socket_producer.cc b/libtransport/src/interfaces/socket_producer.cc index 10613c0e1..2155ebd78 100644 --- a/libtransport/src/interfaces/socket_producer.cc +++ b/libtransport/src/interfaces/socket_producer.cc @@ -148,6 +148,12 @@ int ProducerSocket::setSocketOption( return socket_->setSocketOption(socket_option_key, socket_option_value); } +int ProducerSocket::setSocketOption( + int socket_option_key, + const std::shared_ptr<auth::Verifier> &socket_option_value) { + return socket_->setSocketOption(socket_option_key, socket_option_value); +} + int ProducerSocket::setSocketOption(int socket_option_key, Packet::Format socket_option_value) { return socket_->setSocketOption(socket_option_key, socket_option_value); @@ -194,6 +200,12 @@ int ProducerSocket::getSocketOption( return socket_->getSocketOption(socket_option_key, socket_option_value); } +int ProducerSocket::getSocketOption( + int socket_option_key, + std::shared_ptr<auth::Verifier> &socket_option_value) { + return socket_->getSocketOption(socket_option_key, socket_option_value); +} + int ProducerSocket::getSocketOption(int socket_option_key, Packet::Format &socket_option_value) { return socket_->getSocketOption(socket_option_key, socket_option_value); diff --git a/libtransport/src/interfaces/tls_rtc_socket_producer.cc b/libtransport/src/interfaces/tls_rtc_socket_producer.cc deleted file mode 100644 index 6bf1b011c..000000000 --- a/libtransport/src/interfaces/tls_rtc_socket_producer.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <implementation/tls_rtc_socket_producer.h> -#include <interfaces/tls_rtc_socket_producer.h> - -namespace transport { -namespace interface { - -TLSRTCProducerSocket::TLSRTCProducerSocket( - implementation::TLSRTCProducerSocket *implementation) { - socket_ = - std::unique_ptr<implementation::TLSRTCProducerSocket>(implementation); -} - -TLSRTCProducerSocket::~TLSRTCProducerSocket() { socket_.release(); } - -} // namespace interface -} // namespace transport diff --git a/libtransport/src/interfaces/tls_rtc_socket_producer.h b/libtransport/src/interfaces/tls_rtc_socket_producer.h deleted file mode 100644 index b8b6ec298..000000000 --- a/libtransport/src/interfaces/tls_rtc_socket_producer.h +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include <hicn/transport/interfaces/socket_producer.h> - -namespace transport { - -namespace implementation { -class TLSRTCProducerSocket; -} - -namespace interface { - -class TLSRTCProducerSocket : public ProducerSocket { - public: - TLSRTCProducerSocket(implementation::TLSRTCProducerSocket *implementation); - - ~TLSRTCProducerSocket(); -}; - -} // namespace interface -} // end namespace transport diff --git a/libtransport/src/interfaces/tls_socket_consumer.cc b/libtransport/src/interfaces/tls_socket_consumer.cc deleted file mode 100644 index 24060d1d8..000000000 --- a/libtransport/src/interfaces/tls_socket_consumer.cc +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <implementation/tls_socket_consumer.h> -#include <interfaces/tls_socket_consumer.h> - -namespace transport { -namespace interface { - -TLSConsumerSocket::TLSConsumerSocket( - implementation::TLSConsumerSocket *implementation) { - socket_ = std::unique_ptr<implementation::TLSConsumerSocket>(implementation); -} - -TLSConsumerSocket::~TLSConsumerSocket() { socket_.release(); } - -} // namespace interface -} // namespace transport diff --git a/libtransport/src/interfaces/tls_socket_consumer.h b/libtransport/src/interfaces/tls_socket_consumer.h deleted file mode 100644 index 242dc91a5..000000000 --- a/libtransport/src/interfaces/tls_socket_consumer.h +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include <hicn/transport/interfaces/socket_consumer.h> - -namespace transport { - -namespace implementation { -class TLSConsumerSocket; -} - -namespace interface { - -class TLSConsumerSocket : public ConsumerSocket { - public: - TLSConsumerSocket(implementation::TLSConsumerSocket *implementation); - ~TLSConsumerSocket(); -}; - -} // namespace interface - -} // end namespace transport diff --git a/libtransport/src/interfaces/tls_socket_producer.cc b/libtransport/src/interfaces/tls_socket_producer.cc deleted file mode 100644 index b2b9e723a..000000000 --- a/libtransport/src/interfaces/tls_socket_producer.cc +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <implementation/tls_socket_producer.h> -#include <interfaces/tls_socket_producer.h> - -namespace transport { -namespace interface { - -TLSProducerSocket::TLSProducerSocket( - implementation::TLSProducerSocket *implementation) { - socket_ = std::unique_ptr<implementation::TLSProducerSocket>(implementation); -} - -TLSProducerSocket::~TLSProducerSocket() { socket_.release(); } - -} // namespace interface -} // namespace transport diff --git a/libtransport/src/interfaces/tls_socket_producer.h b/libtransport/src/interfaces/tls_socket_producer.h deleted file mode 100644 index 9b31cb483..000000000 --- a/libtransport/src/interfaces/tls_socket_producer.h +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include <hicn/transport/interfaces/socket_producer.h> - -namespace transport { - -namespace implementation { -class TLSProducerSocket; -} - -namespace interface { - -class TLSProducerSocket : public ProducerSocket { - public: - TLSProducerSocket(implementation::TLSProducerSocket *implementation); - ~TLSProducerSocket(); -}; - -} // namespace interface - -} // end namespace transport diff --git a/libtransport/src/io_modules/CMakeLists.txt b/libtransport/src/io_modules/CMakeLists.txt index f4143de04..f1a27d3cb 100644 --- a/libtransport/src/io_modules/CMakeLists.txt +++ b/libtransport/src/io_modules/CMakeLists.txt @@ -15,7 +15,7 @@ ############################################################## # Android case: no submodules ############################################################## -if (${CMAKE_SYSTEM_NAME} MATCHES Android) +if (${CMAKE_SYSTEM_NAME} MATCHES Android OR ${CMAKE_SYSTEM_NAME} MATCHES iOS) list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/hicn-light-ng/hicn_forwarder_module.cc ) diff --git a/libtransport/src/io_modules/forwarder/CMakeLists.txt b/libtransport/src/io_modules/forwarder/CMakeLists.txt index 3922316d3..2235d842e 100644 --- a/libtransport/src/io_modules/forwarder/CMakeLists.txt +++ b/libtransport/src/io_modules/forwarder/CMakeLists.txt @@ -17,7 +17,6 @@ list(APPEND MODULE_HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/errors.h ${CMAKE_CURRENT_SOURCE_DIR}/forwarder_module.h ${CMAKE_CURRENT_SOURCE_DIR}/forwarder.h - ${CMAKE_CURRENT_SOURCE_DIR}/global_counter.h ) list(APPEND MODULE_SOURCE_FILES diff --git a/libtransport/src/io_modules/forwarder/forwarder.cc b/libtransport/src/io_modules/forwarder/forwarder.cc index 3ae5bf397..bfe4dd5de 100644 --- a/libtransport/src/io_modules/forwarder/forwarder.cc +++ b/libtransport/src/io_modules/forwarder/forwarder.cc @@ -14,12 +14,12 @@ */ #include <core/global_configuration.h> +#include <core/global_id_counter.h> #include <core/local_connector.h> #include <core/udp_connector.h> #include <core/udp_listener.h> #include <glog/logging.h> #include <io_modules/forwarder/forwarder.h> -#include <io_modules/forwarder/global_id_counter.h> namespace transport { @@ -90,11 +90,13 @@ Connector::Id Forwarder::registerLocalConnector( asio::io_service &io_service, Connector::PacketReceivedCallback &&receive_callback, Connector::PacketSentCallback &&sent_callback, + Connector::OnCloseCallback &&close_callback, Connector::OnReconnectCallback &&reconnect_callback) { utils::SpinLock::Acquire locked(connector_lock_); auto id = GlobalCounter<Connector::Id>::getInstance().getNext(); auto connector = std::make_shared<LocalConnector>( - io_service, receive_callback, sent_callback, nullptr, reconnect_callback); + io_service, std::move(receive_callback), std::move(sent_callback), + std::move(close_callback), std::move(reconnect_callback)); connector->setConnectorId(id); local_connectors_.emplace(id, std::move(connector)); return id; @@ -150,34 +152,13 @@ void Forwarder::onPacketReceived(Connector *connector, return; } - for (auto &packet_buffer_ptr : packets) { - auto &packet_buffer = *packet_buffer_ptr; - - // Figure out the type of packet we received - bool is_interest = Packet::isInterest(packet_buffer.data()); - - Packet *packet = nullptr; - if (is_interest) { - packet = static_cast<Interest *>(&packet_buffer); - } else { - packet = static_cast<ContentObject *>(&packet_buffer); - } - - for (auto &c : local_connectors_) { - auto role = c.second->getRole(); - auto is_producer = role == Connector::Role::PRODUCER; - if ((is_producer && is_interest) || (!is_producer && !is_interest)) { - c.second->send(*packet); - } else { - LOG(ERROR) << "Error sending packet to local connector. is_interest = " - << is_interest << " - is_producer = " << is_producer; - } - } + for (auto &c : local_connectors_) { + c.second->receive(packets); + } - // PCS Lookup + FIB lookup. Skip for now + // PCS Lookup + FIB lookup. Skip for now - // Forward packet to local connectors - } + // Forward packet to local connectors } void Forwarder::send(Packet &packet) { @@ -304,4 +285,4 @@ void Forwarder::parseForwarderConfiguration( } } // namespace core -} // namespace transport
\ No newline at end of file +} // namespace transport diff --git a/libtransport/src/io_modules/forwarder/forwarder.h b/libtransport/src/io_modules/forwarder/forwarder.h index 38b4260b3..9ad989fcd 100644 --- a/libtransport/src/io_modules/forwarder/forwarder.h +++ b/libtransport/src/io_modules/forwarder/forwarder.h @@ -47,6 +47,7 @@ class Forwarder { asio::io_service &io_service, Connector::PacketReceivedCallback &&receive_callback, Connector::PacketSentCallback &&sent_callback, + Connector::OnCloseCallback &&close_callback, Connector::OnReconnectCallback &&reconnect_callback); Forwarder &deleteConnector(Connector::Id id); diff --git a/libtransport/src/io_modules/forwarder/forwarder_module.cc b/libtransport/src/io_modules/forwarder/forwarder_module.cc index 0ced84ab4..77d2b5e6a 100644 --- a/libtransport/src/io_modules/forwarder/forwarder_module.cc +++ b/libtransport/src/io_modules/forwarder/forwarder_module.cc @@ -37,8 +37,6 @@ void ForwarderModule::send(Packet &packet) { forwarder_.send(packet); DLOG_IF(INFO, VLOG_IS_ON(3)) << "Sending from " << connector_id_ << " to " << 1 - connector_id_; - - // local_faces_.at(1 - local_id_).onPacket(packet); } void ForwarderModule::send(const utils::MemBuf::Ptr &buffer) { @@ -58,12 +56,13 @@ void ForwarderModule::closeConnection() { void ForwarderModule::init(Connector::PacketReceivedCallback &&receive_callback, Connector::PacketSentCallback &&sent_callback, + Connector::OnCloseCallback &&close_callback, Connector::OnReconnectCallback &&reconnect_callback, asio::io_service &io_service, const std::string &app_name) { connector_id_ = forwarder_.registerLocalConnector( io_service, std::move(receive_callback), std::move(sent_callback), - std::move(reconnect_callback)); + std::move(close_callback), std::move(reconnect_callback)); name_ = app_name; } diff --git a/libtransport/src/io_modules/forwarder/forwarder_module.h b/libtransport/src/io_modules/forwarder/forwarder_module.h index 52a12b67e..a48701161 100644 --- a/libtransport/src/io_modules/forwarder/forwarder_module.h +++ b/libtransport/src/io_modules/forwarder/forwarder_module.h @@ -44,6 +44,7 @@ class ForwarderModule : public IoModule { void init(Connector::PacketReceivedCallback &&receive_callback, Connector::PacketSentCallback &&sent_callback, + Connector::OnCloseCallback &&close_callback, Connector::OnReconnectCallback &&reconnect_callback, asio::io_service &io_service, const std::string &app_name = "Libtransport") override; diff --git a/libtransport/src/io_modules/hicn-light-ng/hicn_forwarder_module.cc b/libtransport/src/io_modules/hicn-light-ng/hicn_forwarder_module.cc index f67bd9447..95f04822f 100644 --- a/libtransport/src/io_modules/hicn-light-ng/hicn_forwarder_module.cc +++ b/libtransport/src/io_modules/hicn-light-ng/hicn_forwarder_module.cc @@ -100,12 +100,13 @@ void HicnForwarderModule::closeConnection() { void HicnForwarderModule::init( Connector::PacketReceivedCallback &&receive_callback, Connector::PacketSentCallback &&sent_callback, + Connector::OnCloseCallback &&close_callback, Connector::OnReconnectCallback &&reconnect_callback, asio::io_service &io_service, const std::string &app_name) { if (!connector_) { connector_.reset(new UdpTunnelConnector( io_service, std::move(receive_callback), std::move(sent_callback), - nullptr, std::move(reconnect_callback))); + std::move(close_callback), std::move(reconnect_callback))); } } @@ -130,7 +131,7 @@ bool HicnForwarderModule::isControlMessage(utils::MemBuf &packet_buffer) { */ utils::MemBuf::Ptr HicnForwarderModule::createCommandRoute( std::unique_ptr<sockaddr> &&addr, uint8_t prefix_length) { - utils::MemBuf::Ptr ret = utils::MemBuf::create(sizeof(msg_route_add_t)); + auto ret = PacketManager<>::getInstance().getMemBuf(); auto command = reinterpret_cast<msg_route_add_t *>(ret->writableData()); ret->append(sizeof(msg_route_add_t)); std::memset(command, 0, sizeof(*command)); @@ -170,8 +171,7 @@ utils::MemBuf::Ptr HicnForwarderModule::createCommandRoute( } utils::MemBuf::Ptr HicnForwarderModule::createCommandDeleteConnection() { - utils::MemBuf::Ptr ret = - utils::MemBuf::create(sizeof(msg_connection_remove_t)); + auto ret = PacketManager<>::getInstance().getMemBuf(); auto command = reinterpret_cast<msg_connection_remove_t *>(ret->writableData()); ret->append(sizeof(msg_connection_remove_t)); @@ -194,8 +194,7 @@ utils::MemBuf::Ptr HicnForwarderModule::createCommandDeleteConnection() { } utils::MemBuf::Ptr HicnForwarderModule::createCommandMapmeSendUpdate() { - utils::MemBuf::Ptr ret = - utils::MemBuf::create(sizeof(msg_mapme_send_update_t)); + auto ret = PacketManager<>::getInstance().getMemBuf(); auto command = reinterpret_cast<msg_mapme_send_update_t *>(ret->writableData()); ret->append(sizeof(msg_mapme_send_update_t)); @@ -217,7 +216,7 @@ utils::MemBuf::Ptr HicnForwarderModule::createCommandMapmeSendUpdate() { utils::MemBuf::Ptr HicnForwarderModule::createCommandSetForwardingStrategy( std::unique_ptr<sockaddr> &&addr, uint32_t prefix_len, std::string strategy) { - utils::MemBuf::Ptr ret = utils::MemBuf::create(sizeof(msg_strategy_set_t)); + auto ret = PacketManager<>::getInstance().getMemBuf(); auto command = reinterpret_cast<msg_strategy_set_t *>(ret->writableData()); ret->append(sizeof(msg_strategy_set_t)); std::memset(command, 0, sizeof(*command)); diff --git a/libtransport/src/io_modules/hicn-light-ng/hicn_forwarder_module.h b/libtransport/src/io_modules/hicn-light-ng/hicn_forwarder_module.h index 0bf82757d..7f0e7aca8 100644 --- a/libtransport/src/io_modules/hicn-light-ng/hicn_forwarder_module.h +++ b/libtransport/src/io_modules/hicn-light-ng/hicn_forwarder_module.h @@ -64,6 +64,7 @@ class HicnForwarderModule : public IoModule { void init(Connector::PacketReceivedCallback &&receive_callback, Connector::PacketSentCallback &&sent_callback, + Connector::OnCloseCallback &&close_callback, Connector::OnReconnectCallback &&reconnect_callback, asio::io_service &io_service, const std::string &app_name = "Libtransport") override; diff --git a/libtransport/src/io_modules/loopback/loopback_module.cc b/libtransport/src/io_modules/loopback/loopback_module.cc index 5b7ed5f61..a7f30eb26 100644 --- a/libtransport/src/io_modules/loopback/loopback_module.cc +++ b/libtransport/src/io_modules/loopback/loopback_module.cc @@ -58,6 +58,7 @@ void LoopbackModule::closeConnection() { void LoopbackModule::init(Connector::PacketReceivedCallback &&receive_callback, Connector::PacketSentCallback &&sent_callback, + Connector::OnCloseCallback &&close_callback, Connector::OnReconnectCallback &&reconnect_callback, asio::io_service &io_service, const std::string &app_name) { @@ -66,7 +67,7 @@ void LoopbackModule::init(Connector::PacketReceivedCallback &&receive_callback, local_faces_.emplace( local_faces_.begin() + local_id_, new LocalConnector(io_service, std::move(receive_callback), - std::move(sent_callback), nullptr, + std::move(sent_callback), std::move(close_callback), std::move(reconnect_callback))); } } diff --git a/libtransport/src/io_modules/loopback/loopback_module.h b/libtransport/src/io_modules/loopback/loopback_module.h index 2779ae7e3..d51f237f4 100644 --- a/libtransport/src/io_modules/loopback/loopback_module.h +++ b/libtransport/src/io_modules/loopback/loopback_module.h @@ -42,6 +42,7 @@ class LoopbackModule : public IoModule { void init(Connector::PacketReceivedCallback &&receive_callback, Connector::PacketSentCallback &&sent_callback, + Connector::OnCloseCallback &&close_callback, Connector::OnReconnectCallback &&reconnect_callback, asio::io_service &io_service, const std::string &app_name = "Libtransport") override; diff --git a/libtransport/src/io_modules/memif/vpp_forwarder_module.cc b/libtransport/src/io_modules/memif/vpp_forwarder_module.cc index 65260077a..c096a71b8 100644 --- a/libtransport/src/io_modules/memif/vpp_forwarder_module.cc +++ b/libtransport/src/io_modules/memif/vpp_forwarder_module.cc @@ -50,13 +50,14 @@ VPPForwarderModule::~VPPForwarderModule() {} void VPPForwarderModule::init( Connector::PacketReceivedCallback &&receive_callback, Connector::PacketSentCallback &&sent_callback, + Connector::OnCloseCallback &&close_callback, Connector::OnReconnectCallback &&reconnect_callback, asio::io_service &io_service, const std::string &app_name) { if (!connector_) { connector_ = std::make_unique<MemifConnector>( std::move(receive_callback), std::move(sent_callback), - Connector::OnCloseCallback(0), std::move(reconnect_callback), - io_service, app_name); + std::move(close_callback), std::move(reconnect_callback), io_service, + app_name); } } diff --git a/libtransport/src/io_modules/memif/vpp_forwarder_module.h b/libtransport/src/io_modules/memif/vpp_forwarder_module.h index 162ee0ca5..5a5358078 100644 --- a/libtransport/src/io_modules/memif/vpp_forwarder_module.h +++ b/libtransport/src/io_modules/memif/vpp_forwarder_module.h @@ -48,6 +48,7 @@ class VPPForwarderModule : public IoModule { void init(Connector::PacketReceivedCallback &&receive_callback, Connector::PacketSentCallback &&sent_callback, + Connector::OnCloseCallback &&close_callback, Connector::OnReconnectCallback &&reconnect_callback, asio::io_service &io_service, const std::string &app_name = "Libtransport") override; diff --git a/libtransport/src/protocols/byte_stream_reassembly.cc b/libtransport/src/protocols/byte_stream_reassembly.cc index 3278595b7..b9eaf3bec 100644 --- a/libtransport/src/protocols/byte_stream_reassembly.cc +++ b/libtransport/src/protocols/byte_stream_reassembly.cc @@ -36,15 +36,6 @@ ByteStreamReassembly::ByteStreamReassembly( index_(Indexer::invalid_index), download_complete_(false) {} -void ByteStreamReassembly::reassemble( - std::unique_ptr<ContentObjectManifest> &&manifest) { - if (TRANSPORT_EXPECT_TRUE(manifest != nullptr) && read_buffer_->capacity()) { - received_packets_.emplace( - std::make_pair(manifest->getName().getSuffix(), nullptr)); - assembleContent(); - } -} - void ByteStreamReassembly::reassemble(ContentObject &content_object) { if (TRANSPORT_EXPECT_TRUE(read_buffer_->capacity())) { received_packets_.emplace( diff --git a/libtransport/src/protocols/byte_stream_reassembly.h b/libtransport/src/protocols/byte_stream_reassembly.h index bfcac3181..a1f965d5c 100644 --- a/libtransport/src/protocols/byte_stream_reassembly.h +++ b/libtransport/src/protocols/byte_stream_reassembly.h @@ -29,9 +29,6 @@ class ByteStreamReassembly : public Reassembly { protected: void reassemble(core::ContentObject &content_object) override; - void reassemble( - std::unique_ptr<core::ContentObjectManifest> &&manifest) override; - void reassemble(utils::MemBuf &buffer, uint32_t suffix) override; bool copyContent(core::ContentObject &content_object); diff --git a/libtransport/src/protocols/cbr.cc b/libtransport/src/protocols/cbr.cc index 446ea8b99..e3f0f1336 100644 --- a/libtransport/src/protocols/cbr.cc +++ b/libtransport/src/protocols/cbr.cc @@ -38,7 +38,7 @@ void CbrTransportProtocol::afterContentReception( const Interest &interest, const ContentObject &content_object) { auto segment = content_object.getName().getSuffix(); auto now = utils::SteadyTime::Clock::now(); - auto rtt = utils::SteadyTime::getDurationMs( + auto rtt = utils::SteadyTime::getDurationUs( interest_timepoints_[segment & mask], now); // Update stats updateStats(segment, rtt, now); diff --git a/libtransport/src/protocols/datagram_reassembly.cc b/libtransport/src/protocols/datagram_reassembly.cc index 3a32c81f5..a04b0eecf 100644 --- a/libtransport/src/protocols/datagram_reassembly.cc +++ b/libtransport/src/protocols/datagram_reassembly.cc @@ -29,8 +29,9 @@ void DatagramReassembly::reassemble(core::ContentObject& content_object) { auto read_buffer = content_object.getPayload(); DLOG_IF(INFO, VLOG_IS_ON(4)) << "Size of payload: " << read_buffer->length() << ". Trimming " - << transport_protocol_->transportHeaderLength(); - read_buffer->trimStart(transport_protocol_->transportHeaderLength()); + << transport_protocol_->transportHeaderLength(false); + // here we have only src data packet + read_buffer->trimStart(transport_protocol_->transportHeaderLength(false)); Reassembly::read_buffer_ = std::move(read_buffer); Reassembly::notifyApplication(); } diff --git a/libtransport/src/protocols/datagram_reassembly.h b/libtransport/src/protocols/datagram_reassembly.h index 0def32dd2..cefdca93b 100644 --- a/libtransport/src/protocols/datagram_reassembly.h +++ b/libtransport/src/protocols/datagram_reassembly.h @@ -29,10 +29,6 @@ class DatagramReassembly : public Reassembly { virtual void reassemble(core::ContentObject &content_object) override; void reassemble(utils::MemBuf &buffer, uint32_t suffix) override; virtual void reInitialize() override; - virtual void reassemble( - std::unique_ptr<core::ContentObjectManifest> &&manifest) override { - return; - } bool reassembleUnverified() override { return true; } }; diff --git a/libtransport/src/protocols/fec/rely.cc b/libtransport/src/protocols/fec/rely.cc index d4d98a90b..9e0a06dd8 100644 --- a/libtransport/src/protocols/fec/rely.cc +++ b/libtransport/src/protocols/fec/rely.cc @@ -79,7 +79,7 @@ void RelyEncoder::onPacketProduced(core::ContentObject &content_object, // Check new payload size and make sure it fits in packet buffer auto new_payload_size = produce_bytes(); - int difference = new_payload_size - length; + int difference = (int)(new_payload_size - length); DCHECK(difference > 0); DCHECK(content_object.ensureCapacity(difference)); diff --git a/libtransport/src/protocols/fec/rely.h b/libtransport/src/protocols/fec/rely.h index 001a26002..cc81222b2 100644 --- a/libtransport/src/protocols/fec/rely.h +++ b/libtransport/src/protocols/fec/rely.h @@ -15,6 +15,7 @@ #pragma once +#include <hicn/transport/portability/endianess.h> #include <hicn/transport/utils/chrono_typedefs.h> #include <hicn/transport/utils/membuf.h> #include <protocols/fec/fec_info.h> @@ -80,11 +81,19 @@ class RelyBase : public virtual FECBase { */ class fec_metadata { public: - void setSeqNumberBase(uint32_t suffix) { seq_number = htonl(suffix); } - uint32_t getSeqNumberBase() const { return ntohl(seq_number); } - - void setMetadataBase(uint32_t value) { metadata = htonl(value); } - uint32_t getMetadataBase() const { return ntohl(metadata); } + void setSeqNumberBase(uint32_t suffix) { + seq_number = portability::host_to_net(suffix); + } + uint32_t getSeqNumberBase() const { + return portability::net_to_host(seq_number); + } + + void setMetadataBase(uint32_t value) { + metadata = portability::host_to_net(value); + } + uint32_t getMetadataBase() const { + return portability::net_to_host(metadata); + } private: uint32_t seq_number; @@ -162,8 +171,9 @@ class RelyEncoder : RelyBase, rely::encoder, public ProducerFEC { /** * @brief Get the fec header size, if added to source packets + * there is not need to distinguish between source and FEC packets here */ - std::size_t getFecHeaderSize() override { + std::size_t getFecHeaderSize(bool isFEC) override { return header_bytes() + sizeof(fec_metadata) + 4; } @@ -184,8 +194,9 @@ class RelyDecoder : RelyBase, rely::decoder, public ConsumerFEC { /** * @brief Get the fec header size, if added to source packets + * there is not need to distinguish between source and FEC packets here */ - std::size_t getFecHeaderSize() override { + std::size_t getFecHeaderSize(bool isFEC) override { return header_bytes() + sizeof(fec_metadata); } diff --git a/libtransport/src/protocols/fec/rs.cc b/libtransport/src/protocols/fec/rs.cc index 9c0a3d4fb..d42740c32 100644 --- a/libtransport/src/protocols/fec/rs.cc +++ b/libtransport/src/protocols/fec/rs.cc @@ -146,7 +146,8 @@ void BlockCode::encode() { DLOG_IF(INFO, VLOG_IS_ON(4)) << "Calling encode with max_buffer_size_ = " << max_buffer_size_; for (uint32_t i = k_; i < n_; i++) { - fec_encode(code_, data, data[i], i, max_buffer_size_ + METADATA_BYTES); + fec_encode(code_, data, data[i], i, + (int)(max_buffer_size_ + METADATA_BYTES)); } // Re-include header in repair packets @@ -213,7 +214,8 @@ void BlockCode::decode() { DLOG_IF(INFO, VLOG_IS_ON(4)) << "Calling decode with max_buffer_size_ = " << max_buffer_size_; - fec_decode(code_, data, reinterpret_cast<int *>(index), max_buffer_size_); + fec_decode(code_, data, reinterpret_cast<int *>(index), + (int)max_buffer_size_); // Find the index in the block for recovered packets for (uint32_t i = 0, j = 0; i < k_; i++) { @@ -228,6 +230,7 @@ void BlockCode::decode() { auto &packet = operator[](i).getBuffer(); fec_metadata *metadata = reinterpret_cast<fec_metadata *>( packet->writableData() + max_buffer_size_ - METADATA_BYTES); + DCHECK(metadata->getPacketLength() <= packet->capacity()); // Adjust buffer length packet->setLength(metadata->getPacketLength()); // Adjust metadata diff --git a/libtransport/src/protocols/fec/rs.h b/libtransport/src/protocols/fec/rs.h index 034c32bdc..6672eaa6b 100644 --- a/libtransport/src/protocols/fec/rs.h +++ b/libtransport/src/protocols/fec/rs.h @@ -18,6 +18,7 @@ #include <arpa/inet.h> #include <hicn/transport/portability/c_portability.h> +#include <hicn/transport/portability/endianess.h> #include <hicn/transport/utils/membuf.h> #include <protocols/fec/fec_info.h> #include <protocols/fec_base.h> @@ -153,8 +154,10 @@ struct fec_header { */ uint8_t padding; - void setSeqNumberBase(uint32_t suffix) { seq_number = htonl(suffix); } - uint32_t getSeqNumberBase() { return ntohl(seq_number); } + void setSeqNumberBase(uint32_t suffix) { + seq_number = portability::host_to_net(suffix); + } + uint32_t getSeqNumberBase() { return portability::net_to_host(seq_number); } void setEncodedSymbolId(uint8_t esi) { encoded_symbol_id = esi; } uint8_t getEncodedSymbolId() { return encoded_symbol_id; } void setSourceBlockLen(uint8_t k) { source_block_len = k; } @@ -163,6 +166,8 @@ struct fec_header { uint8_t getNFecSymbols() { return n_fec_symbols; } }; +static_assert(sizeof(fec_header) <= 8, "fec_header is too large"); + class rs; /** @@ -177,11 +182,17 @@ class BlockCode : public Packets { */ class __attribute__((__packed__)) fec_metadata { public: - void setPacketLength(uint16_t length) { packet_length = htons(length); } - uint32_t getPacketLength() { return ntohs(packet_length); } + void setPacketLength(uint16_t length) { + packet_length = portability::host_to_net(length); + } + uint32_t getPacketLength() { + return portability::net_to_host(packet_length); + } - void setMetadataBase(uint32_t value) { metadata = htonl(value); } - uint32_t getMetadataBase() { return ntohl(metadata); } + void setMetadataBase(uint32_t value) { + metadata = portability::host_to_net(value); + } + uint32_t getMetadataBase() { return portability::net_to_host(metadata); } private: uint16_t packet_length; /* Used to get the real size of the packet after we @@ -388,8 +399,11 @@ class RSEncoder : public rs, public ProducerFEC { /** * @brief Get the fec header size, if added to source packets + * in RS the source packets do not transport any FEC header */ - std::size_t getFecHeaderSize() override { return 0; } + std::size_t getFecHeaderSize(bool isFEC) override { + return isFEC ? sizeof(fec_header) : 0; + } void clear() override { rs::clear(); @@ -435,8 +449,11 @@ class RSDecoder : public rs, public ConsumerFEC { /** * @brief Get the fec header size, if added to source packets + * in RS the source packets do not transport any FEC header */ - std::size_t getFecHeaderSize() override { return 0; } + std::size_t getFecHeaderSize(bool isFEC) override { + return isFEC ? sizeof(fec_header) : 0; + } /** * Clear decoder to reuse diff --git a/libtransport/src/protocols/fec_base.h b/libtransport/src/protocols/fec_base.h index bda3ee756..28f6a820a 100644 --- a/libtransport/src/protocols/fec_base.h +++ b/libtransport/src/protocols/fec_base.h @@ -101,8 +101,10 @@ class FECBase { /** * @brief Get size of FEC header. + * the fec header size may be different if a packet is a data packet or a FEC + * packet */ - virtual std::size_t getFecHeaderSize() = 0; + virtual std::size_t getFecHeaderSize(bool isFEC) = 0; /** * Set callback to call after packet encoding / decoding diff --git a/libtransport/src/protocols/manifest_incremental_indexer_bytestream.cc b/libtransport/src/protocols/manifest_incremental_indexer_bytestream.cc index b5ab8184f..0b15559a4 100644 --- a/libtransport/src/protocols/manifest_incremental_indexer_bytestream.cc +++ b/libtransport/src/protocols/manifest_incremental_indexer_bytestream.cc @@ -66,40 +66,33 @@ void ManifestIncrementalIndexer::onUntrustedManifest( return; } - auto manifest = - std::make_unique<ContentObjectManifest>(std::move(content_object)); - manifest->decode(); + core::ContentObjectManifest manifest(content_object.shared_from_this()); + manifest.decode(); - processTrustedManifest(interest, std::move(manifest), reassembly); + processTrustedManifest(interest, manifest, reassembly); } void ManifestIncrementalIndexer::processTrustedManifest( - core::Interest &interest, std::unique_ptr<ContentObjectManifest> manifest, + core::Interest &interest, core::ContentObjectManifest &manifest, bool reassembly) { - if (TRANSPORT_EXPECT_FALSE(manifest->getVersion() != - core::ManifestVersion::VERSION_1)) { - throw errors::RuntimeException("Received manifest with unknown version."); - } - - switch (manifest->getType()) { + switch (manifest.getType()) { case core::ManifestType::INLINE_MANIFEST: { suffix_strategy_->setFinalSuffix( - manifest->getParamsBytestream().final_segment); + manifest.getParamsBytestream().final_segment); // The packets to verify with the received manifest std::vector<auth::PacketPtr> packets; // Convert the received manifest to a map of packet suffixes to hashes - auth::Verifier::SuffixMap current_manifest = - core::ContentObjectManifest::getSuffixMap(manifest.get()); + auth::Verifier::SuffixMap suffix_map = manifest.getSuffixMap(); // Update 'suffix_map_' with new hashes from the received manifest and // build 'packets' - for (auto it = current_manifest.begin(); it != current_manifest.end();) { + for (auto it = suffix_map.begin(); it != suffix_map.end();) { if (unverified_segments_.find(it->first) == unverified_segments_.end()) { suffix_map_[it->first] = std::move(it->second); - current_manifest.erase(it++); + suffix_map.erase(it++); continue; } @@ -109,7 +102,7 @@ void ManifestIncrementalIndexer::processTrustedManifest( // Verify unverified segments using the received manifest auth::Verifier::PolicyMap policies = - verifier_->verifyPackets(packets, current_manifest); + verifier_->verifyPackets(packets, suffix_map); for (unsigned int i = 0; i < packets.size(); ++i) { auth::Suffix suffix = packets[i]->getName().getSuffix(); @@ -126,7 +119,9 @@ void ManifestIncrementalIndexer::processTrustedManifest( } if (reassembly) { - reassembly_->reassemble(std::move(manifest)); + auto manifest_co = + std::dynamic_pointer_cast<ContentObject>(manifest.getPacket()); + reassembly_->reassemble(*manifest_co); } break; } diff --git a/libtransport/src/protocols/manifest_incremental_indexer_bytestream.h b/libtransport/src/protocols/manifest_incremental_indexer_bytestream.h index 12876f35c..8527b55c1 100644 --- a/libtransport/src/protocols/manifest_incremental_indexer_bytestream.h +++ b/libtransport/src/protocols/manifest_incremental_indexer_bytestream.h @@ -76,7 +76,7 @@ class ManifestIncrementalIndexer : public IncrementalIndexer { core::ContentObject &content_object, bool reassembly); void processTrustedManifest(core::Interest &interest, - std::unique_ptr<ContentObjectManifest> manifest, + core::ContentObjectManifest &manifest, bool reassembly); void onUntrustedContentObject(core::Interest &interest, core::ContentObject &content_object, diff --git a/libtransport/src/protocols/prod_protocol_bytestream.cc b/libtransport/src/protocols/prod_protocol_bytestream.cc index 2a3ec07e1..7f103e12b 100644 --- a/libtransport/src/protocols/prod_protocol_bytestream.cc +++ b/libtransport/src/protocols/prod_protocol_bytestream.cc @@ -111,18 +111,18 @@ uint32_t ByteStreamProductionProtocol::produceStream( uint64_t manifest_free_space; uint32_t nb_manifests; std::shared_ptr<core::ContentObjectManifest> manifest; - uint32_t manifest_capacity = making_manifest_; + uint32_t manifest_capacity = manifest_max_capacity_; bool is_last_manifest = false; ParamsBytestream transport_params; manifest_format = Packet::toAHFormat(default_format); - content_format = - !making_manifest_ ? Packet::toAHFormat(default_format) : default_format; + content_format = !manifest_max_capacity_ ? Packet::toAHFormat(default_format) + : default_format; - content_header_size = - core::Packet::getHeaderSizeFromFormat(content_format, signature_length); - manifest_header_size = - core::Packet::getHeaderSizeFromFormat(manifest_format, signature_length); + content_header_size = (uint32_t)core::Packet::getHeaderSizeFromFormat( + content_format, signature_length); + manifest_header_size = (uint32_t)core::Packet::getHeaderSizeFromFormat( + manifest_format, signature_length); content_free_space = std::min(max_segment_size, data_packet_size - content_header_size); manifest_free_space = @@ -135,34 +135,39 @@ uint32_t ByteStreamProductionProtocol::produceStream( nb_segments++; } - if (making_manifest_) { + if (manifest_max_capacity_) { nb_manifests = static_cast<uint32_t>( std::ceil(float(nb_segments) / manifest_capacity)); final_block_number += nb_segments + nb_manifests - 1; transport_params.final_segment = is_last ? final_block_number : utils::SuffixStrategy::MAX_SUFFIX; - manifest.reset(ContentObjectManifest::createManifest( + manifest = ContentObjectManifest::createContentManifest( manifest_format, name.setSuffix(suffix_strategy->getNextManifestSuffix()), - core::ManifestVersion::VERSION_1, core::ManifestType::INLINE_MANIFEST, - is_last_manifest, name, hash_algo, signature_length)); - - manifest->setLifetime(content_object_expiry_time); + signature_length); + manifest->setHeaders(core::ManifestType::INLINE_MANIFEST, + manifest_max_capacity_, hash_algo, is_last_manifest, + name); manifest->setParamsBytestream(transport_params); + manifest->getPacket()->setLifetime(content_object_expiry_time); } auto self = shared_from_this(); for (unsigned int packaged_segments = 0; packaged_segments < nb_segments; packaged_segments++) { - if (making_manifest_) { - if (manifest->estimateManifestSize(1) > manifest_free_space) { + if (manifest_max_capacity_) { + if (manifest->Encoder::manifestSize(1) > manifest_free_space) { manifest->encode(); - signer_->signPacket(manifest.get()); + auto manifest_co = + std::dynamic_pointer_cast<ContentObject>(manifest->getPacket()); + + signer_->signPacket(manifest_co.get()); // Send the current manifest - passContentObjectToCallbacks(manifest, self); - DLOG_IF(INFO, VLOG_IS_ON(3)) << "Send manifest " << manifest->getName(); + passContentObjectToCallbacks(manifest_co, self); + DLOG_IF(INFO, VLOG_IS_ON(3)) + << "Send manifest " << manifest_co->getName(); // Send content objects stored in the queue while (!content_queue_.empty()) { @@ -175,15 +180,15 @@ uint32_t ByteStreamProductionProtocol::produceStream( // Create new manifest. The reference to the last manifest has been // acquired in the passContentObjectToCallbacks function, so we can // safely release this reference. - manifest.reset(ContentObjectManifest::createManifest( + manifest = ContentObjectManifest::createContentManifest( manifest_format, name.setSuffix(suffix_strategy->getNextManifestSuffix()), - core::ManifestVersion::VERSION_1, - core::ManifestType::INLINE_MANIFEST, is_last_manifest, name, - hash_algo, signature_length)); - - manifest->setLifetime(content_object_expiry_time); + signature_length); + manifest->setHeaders(core::ManifestType::INLINE_MANIFEST, + manifest_max_capacity_, hash_algo, + is_last_manifest, name); manifest->setParamsBytestream(transport_params); + manifest->getPacket()->setLifetime(content_object_expiry_time); } } @@ -191,7 +196,7 @@ uint32_t ByteStreamProductionProtocol::produceStream( uint32_t content_suffix = suffix_strategy->getNextContentSuffix(); auto content_object = std::make_shared<ContentObject>( name.setSuffix(content_suffix), content_format, - !making_manifest_ ? signature_length : 0); + !manifest_max_capacity_ ? signature_length : 0); content_object->setLifetime(content_object_expiry_time); auto b = buffer->cloneOne(); @@ -203,7 +208,7 @@ uint32_t ByteStreamProductionProtocol::produceStream( b->append(buffer_size - bytes_segmented); bytes_segmented += (int)(buffer_size - bytes_segmented); - if (is_last && making_manifest_) { + if (is_last && manifest_max_capacity_) { is_last_manifest = true; } else if (is_last) { content_object->setLast(); @@ -219,9 +224,9 @@ uint32_t ByteStreamProductionProtocol::produceStream( // Either we sign the content object or we save its hash into the current // manifest - if (making_manifest_) { + if (manifest_max_capacity_) { auth::CryptoHash hash = content_object->computeDigest(hash_algo); - manifest->addSuffixHash(content_suffix, hash); + manifest->addEntry(content_suffix, hash); content_queue_.push(content_object); } else { signer_->signPacket(content_object.get()); @@ -232,16 +237,19 @@ uint32_t ByteStreamProductionProtocol::produceStream( } // We send the manifest that hasn't been fully filled yet - if (making_manifest_) { + if (manifest_max_capacity_) { if (is_last_manifest) { manifest->setIsLast(is_last_manifest); } manifest->encode(); - signer_->signPacket(manifest.get()); + auto manifest_co = + std::dynamic_pointer_cast<ContentObject>(manifest->getPacket()); + + signer_->signPacket(manifest_co.get()); - passContentObjectToCallbacks(manifest, self); - DLOG_IF(INFO, VLOG_IS_ON(3)) << "Send manifest " << manifest->getName(); + passContentObjectToCallbacks(manifest_co, self); + DLOG_IF(INFO, VLOG_IS_ON(3)) << "Send manifest " << manifest_co->getName(); while (!content_queue_.empty()) { passContentObjectToCallbacks(content_queue_.front(), self); diff --git a/libtransport/src/protocols/prod_protocol_rtc.cc b/libtransport/src/protocols/prod_protocol_rtc.cc index e49f58167..cb8dff6e4 100644 --- a/libtransport/src/protocols/prod_protocol_rtc.cc +++ b/libtransport/src/protocols/prod_protocol_rtc.cc @@ -43,9 +43,6 @@ RTCProductionProtocol::RTCProductionProtocol( last_produced_data_ts_(0), last_round_(utils::SteadyTime::nowMs().count()), allow_delayed_nacks_(false), - queue_timer_on_(false), - consumer_in_sync_(false), - on_consumer_in_sync_(nullptr), pending_fec_pace_(false), max_len_(0), queue_len_(0), @@ -54,8 +51,6 @@ RTCProductionProtocol::RTCProductionProtocol( std::uniform_int_distribution<> dis(0, 255); prod_label_ = dis(gen_); cache_label_ = (prod_label_ + 1) % 256; - interests_queue_timer_ = - std::make_unique<asio::steady_timer>(portal_->getThread().getIoService()); round_timer_ = std::make_unique<asio::steady_timer>(portal_->getThread().getIoService()); fec_pacing_timer_ = @@ -69,16 +64,7 @@ RTCProductionProtocol::~RTCProductionProtocol() {} void RTCProductionProtocol::setProducerParam() { // Flow name: here we assume there is only one prefix registered in the portal - flow_name_ = portal_->getServedNamespaces().begin()->getName(); - - // Manifest - uint32_t making_manifest; - socket_->getSocketOption(interface::GeneralTransportOptions::MAKE_MANIFEST, - making_manifest); - - // Signer - std::shared_ptr<auth::Signer> signer; - socket_->getSocketOption(interface::GeneralTransportOptions::SIGNER, signer); + flow_name_ = portal_->getServedNamespaces().begin()->makeName(); // Default format core::Packet::Format default_format; @@ -94,15 +80,22 @@ void RTCProductionProtocol::setProducerParam() { socket_->getSocketOption(interface::RtcTransportOptions::AGGREGATED_DATA, data_aggregation_); - size_t signature_size = signer->getSignatureFieldSize(); - data_header_format_ = { - !making_manifest ? Packet::toAHFormat(default_format) : default_format, - !making_manifest ? signature_size : 0}; + size_t signature_size = signer_->getSignatureFieldSize(); + data_header_format_ = {!manifest_max_capacity_ + ? Packet::toAHFormat(default_format) + : default_format, + !manifest_max_capacity_ ? signature_size : 0}; manifest_header_format_ = {Packet::toAHFormat(default_format), signature_size}; nack_header_format_ = {Packet::toAHFormat(default_format), signature_size}; fec_header_format_ = {Packet::toAHFormat(default_format), signature_size}; + // Initialize verifier for aggregated interests + std::shared_ptr<auth::Verifier> verifier; + socket_->getSocketOption(implementation::GeneralTransportOptions::VERIFIER, + verifier); + verifier_ = std::make_shared<rtc::RTCVerifier>(verifier, 0, 0); + // Schedule round timer scheduleRoundTimer(); } @@ -143,15 +136,17 @@ void RTCProductionProtocol::updateStats(bool new_round) { packets_production_rate_ = ceil((double)(produced_packets_ + prev_produced_packets_) * per_second); - // add fec packets looking at the fec code. we don't use directly the number - // of fec packets produced in 1 round because it may happen that different - // numbers of blocks are generated during the rounds and this creates - // inconsistencies in the estimation of the production rate - uint32_t k = fec::FECUtils::getSourceSymbols(fec_type_); - uint32_t n = fec::FECUtils::getBlockSymbols(fec_type_); + if (fec_encoder_ && fec_type_ != fec::FECType::UNKNOWN) { + // add fec packets looking at the fec code. we don't use directly the number + // of fec packets produced in 1 round because it may happen that different + // numbers of blocks are generated during the rounds and this creates + // inconsistencies in the estimation of the production rate + uint32_t k = fec::FECUtils::getSourceSymbols(fec_type_); + uint32_t n = fec::FECUtils::getBlockSymbols(fec_type_); - packets_production_rate_ += - ceil((double)packets_production_rate_ / (double)k) * (n - k); + packets_production_rate_ += + ceil((double)packets_production_rate_ / (double)k) * (n - k); + } // update the production rate as soon as it increases by 10% with respect to // the last round @@ -168,11 +163,6 @@ void RTCProductionProtocol::updateStats(bool new_round) { allow_delayed_nacks_ = false; } - // check if the production rate is decreased. if yes send nacks if needed - if (prev_packets_production_rate < packets_production_rate_) { - sendNacksForPendingInterests(); - } - if (new_round) { prev_produced_bytes_ = produced_bytes_; prev_produced_packets_ = produced_packets_; @@ -203,16 +193,25 @@ void RTCProductionProtocol::produce(ContentObject &content_object) { uint32_t RTCProductionProtocol::produceDatagram( const Name &content_name, std::unique_ptr<utils::MemBuf> &&buffer) { std::size_t buffer_size = buffer->length(); + DLOG_IF(INFO, VLOG_IS_ON(3)) + << "Maybe Sending content object: " << content_name; + if (TRANSPORT_EXPECT_FALSE(buffer_size == 0)) return 0; + DLOG_IF(INFO, VLOG_IS_ON(3)) << "Sending content object: " << content_name; + uint32_t data_packet_size; socket_->getSocketOption(interface::GeneralTransportOptions::DATA_PACKET_SIZE, data_packet_size); - - if (TRANSPORT_EXPECT_FALSE( - (Packet::getHeaderSizeFromFormat(data_header_format_.first, - data_header_format_.second) + - rtc::DATA_HEADER_SIZE + buffer_size) > data_packet_size)) { + // this is a source packet but we check the fec header size of FEC packet in + // order to leave room for the header when FEC packets will be generated + uint32_t fec_header = 0; + if (fec_encoder_) fec_encoder_->getFecHeaderSize(true); + uint32_t headers_size = + (uint32_t)Packet::getHeaderSizeFromFormat(data_header_format_.first, + data_header_format_.second) + + rtc::DATA_HEADER_SIZE + fec_header; + if (TRANSPORT_EXPECT_FALSE((headers_size + buffer_size) > data_packet_size)) { return 0; } @@ -338,47 +337,42 @@ void RTCProductionProtocol::emptyQueue() { } void RTCProductionProtocol::sendManifest(const Name &name) { - if (!making_manifest_) { + if (!manifest_max_capacity_) { return; } - Name manifest_name(name); - - uint32_t data_packet_size; - socket_->getSocketOption(interface::GeneralTransportOptions::DATA_PACKET_SIZE, - data_packet_size); - - // The maximum number of entries a manifest can hold - uint32_t manifest_capacity = making_manifest_; + Name manifest_name = name; // If there is not enough hashes to fill a manifest, return early - if (manifest_entries_.size() < manifest_capacity) { + if (manifest_entries_.size() < manifest_max_capacity_) { return; } // Create a new manifest std::shared_ptr<core::ContentObjectManifest> manifest = createManifest(manifest_name.setSuffix(current_seg_)); + auto manifest_co = + std::dynamic_pointer_cast<ContentObject>(manifest->getPacket()); // Fill the manifest with packet hashes that were previously saved uint32_t nb_entries; - for (nb_entries = 0; nb_entries < manifest_capacity; ++nb_entries) { + for (nb_entries = 0; nb_entries < manifest_max_capacity_; ++nb_entries) { if (manifest_entries_.empty()) { break; } std::pair<uint32_t, auth::CryptoHash> front = manifest_entries_.front(); - manifest->addSuffixHash(front.first, front.second); + manifest->addEntry(front.first, front.second); manifest_entries_.pop(); } DLOG_IF(INFO, VLOG_IS_ON(3)) - << "Sending manifest " << manifest->getName().getSuffix() << " of size " - << nb_entries; + << "Sending manifest " << manifest_co->getName().getSuffix() + << " of size " << nb_entries; // Encode and send the manifest manifest->encode(); portal_->getThread().tryRunHandlerNow( - [this, content_object{std::move(manifest)}, manifest_name]() mutable { + [this, content_object{std::move(manifest_co)}, manifest_name]() mutable { produceInternal(std::move(content_object), manifest_name); }); } @@ -394,11 +388,12 @@ RTCProductionProtocol::createManifest(const Name &content_name) const { uint64_t now = utils::SteadyTime::nowMs().count(); // Create a new manifest - std::shared_ptr<core::ContentObjectManifest> manifest( - ContentObjectManifest::createManifest( - manifest_header_format_.first, name, core::ManifestVersion::VERSION_1, - core::ManifestType::INLINE_MANIFEST, false, name, hash_algo, - manifest_header_format_.second)); + std::shared_ptr<core::ContentObjectManifest> manifest = + ContentObjectManifest::createContentManifest( + manifest_header_format_.first, name, manifest_header_format_.second); + manifest->setHeaders(core::ManifestType::INLINE_MANIFEST, + manifest_max_capacity_, hash_algo, false /* is_last */, + name); // Set connection parameters manifest->setParamsRTC(ParamsRTC{ @@ -444,7 +439,15 @@ void RTCProductionProtocol::producePktInternal( // set hicn stuff Name n(content_name); content_object->setName(n.setSuffix(current_seg_)); - content_object->setLifetime(500); // XXX this should be set by the APP + + uint32_t expiry_time = 0; + socket_->getSocketOption( + interface::GeneralTransportOptions::CONTENT_OBJECT_EXPIRY_TIME, + expiry_time); + if (expiry_time == interface::default_values::content_object_expiry_time) + expiry_time = 500; // the data expiration time should be set by the App. if + // the App does not specify it the default is 500ms + content_object->setLifetime(expiry_time); content_object->setPathLabel(prod_label_); // update stats @@ -466,9 +469,9 @@ void RTCProductionProtocol::producePktInternal( // pass packet to FEC encoder if (fec_encoder_ && !fec) { - uint32_t offset = - is_manifest ? content_object->headerSize() - : content_object->headerSize() + rtc::DATA_HEADER_SIZE; + uint32_t offset = is_manifest ? (uint32_t)content_object->headerSize() + : (uint32_t)content_object->headerSize() + + rtc::DATA_HEADER_SIZE; uint32_t metadata = static_cast<uint32_t>(content_object->getPayloadType()); fec_encoder_->onPacketProduced(*content_object, offset, metadata); @@ -481,19 +484,14 @@ void RTCProductionProtocol::producePktInternal( *content_object); } - auto seq_it = seqs_map_.find(current_seg_); - if (seq_it != seqs_map_.end()) { - sendContentObject(content_object, false, fec); - } + // TODO we may want to send FEC only if an interest is pending in the pit in + sendContentObject(content_object, false, fec); if (*on_content_object_output_) { on_content_object_output_->operator()(*socket_->getInterface(), *content_object); } - // remove interests from the interest cache if it exists - removeFromInterestQueue(current_seg_); - if (!fec) last_produced_data_ts_ = now; // Update current segment @@ -563,59 +561,65 @@ void RTCProductionProtocol::onInterest(Interest &interest) { on_interest_input_->operator()(*socket_->getInterface(), interest); } - auto suffix = interest.firstSuffix(); - // numberOfSuffixes returns only the prefixes in the payalod - // we add + 1 to count also the seq in the name - auto n_suffixes = interest.numberOfSuffixes() + 1; - Name name = interest.getName(); - bool prev_consumer_state = consumer_in_sync_; - - for (uint32_t i = 0; i < n_suffixes; i++) { - if (i > 0) { - name.setSuffix(*(suffix + (i - 1))); - } + if (!interest.isValid()) throw std::runtime_error("Bad interest format"); + if (interest.hasManifest() && + verifier_->verify(interest) != auth::VerificationPolicy::ACCEPT) + throw std::runtime_error("Interset manifest verification failed"); - DLOG_IF(INFO, VLOG_IS_ON(3)) << "Received interest " << name; + uint32_t *suffix = interest.firstSuffix(); + uint32_t n_suffixes_in_manifest = interest.numberOfSuffixes(); + uint32_t *request_bitmap = interest.getRequestBitmap(); - const std::shared_ptr<ContentObject> content_object = - output_buffer_.find(name); + Name name = interest.getName(); + uint32_t pos = 0; // Position of current suffix in manifest - if (content_object) { - if (*on_interest_satisfied_output_buffer_) { - on_interest_satisfied_output_buffer_->operator()( - *socket_->getInterface(), interest); - } + DLOG_IF(INFO, VLOG_IS_ON(3)) + << "Received interest " << name << " (" << n_suffixes_in_manifest + << " suffixes in manifest)"; + + // Process the suffix in the interest header + // (first loop iteration), then suffixes in the manifest + do { + if (!interest.hasManifest() || is_bit_set(request_bitmap, pos)) { + const std::shared_ptr<ContentObject> content_object = + output_buffer_.find(name); + + if (content_object) { + if (*on_interest_satisfied_output_buffer_) { + on_interest_satisfied_output_buffer_->operator()( + *socket_->getInterface(), interest); + } - if (*on_content_object_output_) { - on_content_object_output_->operator()(*socket_->getInterface(), - *content_object); - } + if (*on_content_object_output_) { + on_content_object_output_->operator()(*socket_->getInterface(), + *content_object); + } - DLOG_IF(INFO, VLOG_IS_ON(3)) - << "Send content %u (onInterest) " << content_object->getName(); - content_object->setPathLabel(cache_label_); - sendContentObject(content_object); - } else { - if (*on_interest_process_) { - on_interest_process_->operator()(*socket_->getInterface(), interest); + DLOG_IF(INFO, VLOG_IS_ON(3)) + << "Send content %u (onInterest) " << content_object->getName(); + content_object->setPathLabel(cache_label_); + sendContentObject(content_object); + } else { + if (*on_interest_process_) { + on_interest_process_->operator()(*socket_->getInterface(), interest); + } + processInterest(name.getSuffix(), interest.getLifetime()); } - processInterest(name.getSuffix(), interest.getLifetime()); } - } - if (prev_consumer_state != consumer_in_sync_ && consumer_in_sync_) - on_consumer_in_sync_(*socket_->getInterface(), interest); + // Retrieve next suffix in the manifest + if (interest.hasManifest()) { + uint32_t seq = *suffix; + suffix++; + + name.setSuffix(seq); + interest.setName(name); + } + } while (pos++ < n_suffixes_in_manifest); } void RTCProductionProtocol::processInterest(uint32_t interest_seg, uint32_t lifetime) { - if (interest_seg == 0) { - // first packet from the consumer, reset sync state - consumer_in_sync_ = false; - } - - uint64_t now = utils::SteadyTime::nowMs().count(); - switch (rtc::ProbeHandler::getProbeType(interest_seg)) { case rtc::ProbeType::INIT: DLOG_IF(INFO, VLOG_IS_ON(3)) << "Received init probe " << interest_seg; @@ -629,183 +633,7 @@ void RTCProductionProtocol::processInterest(uint32_t interest_seg, break; } - // if the production rate 0 use delayed nacks - if (allow_delayed_nacks_ && interest_seg >= current_seg_) { - uint64_t next_timer = UINT64_MAX; - if (!timers_map_.empty()) { - next_timer = timers_map_.begin()->first; - } - - uint64_t expiration = now + rtc::NACK_DELAY; - addToInterestQueue(interest_seg, expiration); - - // here we have at least one interest in the queue, we need to start or - // update the timer - if (!queue_timer_on_) { - // set timeout - queue_timer_on_ = true; - scheduleQueueTimer(timers_map_.begin()->first - now); - } else { - // re-schedule the timer because a new interest will expires sooner - if (next_timer > timers_map_.begin()->first) { - interests_queue_timer_->cancel(); - scheduleQueueTimer(timers_map_.begin()->first - now); - } - } - return; - } - - if (queue_timer_on_) { - // the producer is producing. Send nacks to packets that will expire - // before the data production and remove the timer - queue_timer_on_ = false; - interests_queue_timer_->cancel(); - sendNacksForPendingInterests(); - } - - uint32_t max_gap = (uint32_t)floor( - (double)((double)((double)lifetime * - rtc::INTEREST_LIFETIME_REDUCTION_FACTOR / - rtc::MILLI_IN_A_SEC) * - (double)(packets_production_rate_))); - - if (interest_seg < current_seg_ || interest_seg > (max_gap + current_seg_)) { - sendNack(interest_seg); - } else { - if (!consumer_in_sync_ && on_consumer_in_sync_) { - // we consider the remote consumer to be in sync as soon as it covers - // 70% of the production window with interests - uint32_t perc = ceil((double)max_gap * 0.7); - if (interest_seg > (perc + current_seg_)) { - consumer_in_sync_ = true; - // on_consumer_in_sync_(*socket_->getInterface(), interest); - } - } - uint64_t expiration = - now + floor((double)lifetime * rtc::INTEREST_LIFETIME_REDUCTION_FACTOR); - addToInterestQueue(interest_seg, expiration); - } -} - -void RTCProductionProtocol::scheduleQueueTimer(uint64_t wait) { - interests_queue_timer_->expires_from_now(std::chrono::milliseconds(wait)); - std::weak_ptr<RTCProductionProtocol> self = shared_from_this(); - interests_queue_timer_->async_wait([self](const std::error_code &ec) { - if (ec) { - return; - } - - auto sp = self.lock(); - if (sp && sp->isRunning()) { - sp->interestQueueTimer(); - } - }); -} - -void RTCProductionProtocol::addToInterestQueue(uint32_t interest_seg, - uint64_t expiration) { - // check if the seq number exists already - auto it_seqs = seqs_map_.find(interest_seg); - if (it_seqs != seqs_map_.end()) { - // the seq already exists - if (expiration < it_seqs->second) { - // we need to update the timer becasue we got a smaller one - // 1) remove the entry from the multimap - // 2) update this entry - auto range = timers_map_.equal_range(it_seqs->second); - for (auto it_timers = range.first; it_timers != range.second; - it_timers++) { - if (it_timers->second == it_seqs->first) { - timers_map_.erase(it_timers); - break; - } - } - timers_map_.insert( - std::pair<uint64_t, uint32_t>(expiration, interest_seg)); - it_seqs->second = expiration; - } else { - // nothing to do here - return; - } - } else { - // add the new seq - timers_map_.insert(std::pair<uint64_t, uint32_t>(expiration, interest_seg)); - seqs_map_.insert(std::pair<uint32_t, uint64_t>(interest_seg, expiration)); - } -} - -void RTCProductionProtocol::sendNacksForPendingInterests() { - std::unordered_set<uint32_t> to_remove; - - uint32_t pps = ceil((double)(packets_production_rate_)*rtc:: - INTEREST_LIFETIME_REDUCTION_FACTOR); - - uint64_t now = utils::SteadyTime::nowMs().count(); - for (auto it = seqs_map_.begin(); it != seqs_map_.end(); it++) { - if (it->first > current_seg_ && it->second > now) { - double exp_time_in_sec = - (double)(it->second - now) / (double)rtc::MILLI_IN_A_SEC; - uint32_t packets_prod_before_expire = ceil((double)pps * exp_time_in_sec); - - if (it->first > (current_seg_ + packets_prod_before_expire)) { - sendNack(it->first); - to_remove.insert(it->first); - } - } else if (TRANSPORT_EXPECT_FALSE(it->first < current_seg_ || - it->second <= now)) { - // this branch should never be execcuted - // first condition: the packet was already prdocued and we have and old - // interest pending. send a nack to notify the consumer if needed. the - // case it->first = current_seg_ is not handled because - // the interest will be satified by the next data packet. - // second condition: the interest is expired. - sendNack(it->first); - to_remove.insert(it->first); - } - } - - // delete nacked interests - for (auto it = to_remove.begin(); it != to_remove.end(); it++) { - removeFromInterestQueue(*it); - } -} - -void RTCProductionProtocol::removeFromInterestQueue(uint32_t interest_seg) { - auto seq_it = seqs_map_.find(interest_seg); - if (seq_it != seqs_map_.end()) { - auto range = timers_map_.equal_range(seq_it->second); - for (auto it_timers = range.first; it_timers != range.second; it_timers++) { - if (it_timers->second == seq_it->first) { - timers_map_.erase(it_timers); - break; - } - } - seqs_map_.erase(seq_it); - } -} - -void RTCProductionProtocol::interestQueueTimer() { - uint64_t now = utils::SteadyTime::nowMs().count(); - - for (auto it_timers = timers_map_.begin(); it_timers != timers_map_.end();) { - uint64_t expire = it_timers->first; - if (expire <= now) { - uint32_t seq = it_timers->second; - sendNack(seq); - // remove the interest from the other map - seqs_map_.erase(seq); - it_timers = timers_map_.erase(it_timers); - } else { - // stop, we are done! - break; - } - } - if (timers_map_.empty()) { - queue_timer_on_ = false; - } else { - queue_timer_on_ = true; - scheduleQueueTimer(timers_map_.begin()->first - now); - } + if (interest_seg < current_seg_) sendNack(interest_seg); } void RTCProductionProtocol::sendManifestProbe(uint32_t sequence) { @@ -814,18 +642,20 @@ void RTCProductionProtocol::sendManifestProbe(uint32_t sequence) { std::shared_ptr<core::ContentObjectManifest> manifest_probe = createManifest(manifest_name); + auto manifest_probe_co = + std::dynamic_pointer_cast<ContentObject>(manifest_probe->getPacket()); - manifest_probe->setLifetime(0); - manifest_probe->setPathLabel(prod_label_); + manifest_probe_co->setLifetime(0); + manifest_probe_co->setPathLabel(prod_label_); manifest_probe->encode(); if (*on_content_object_output_) { on_content_object_output_->operator()(*socket_->getInterface(), - *manifest_probe); + *manifest_probe_co); } DLOG_IF(INFO, VLOG_IS_ON(3)) << "Send init probe " << sequence; - sendContentObject(manifest_probe, true, false); + sendContentObject(manifest_probe_co, true, false); } void RTCProductionProtocol::sendNack(uint32_t sequence) { @@ -847,20 +677,6 @@ void RTCProductionProtocol::sendNack(uint32_t sequence) { nack->setLifetime(0); nack->setPathLabel(prod_label_); - if (!consumer_in_sync_ && on_consumer_in_sync_ && - rtc::ProbeHandler::getProbeType(sequence) == rtc::ProbeType::NOT_PROBE && - sequence > next_packet) { - consumer_in_sync_ = true; - Packet::Format format; - socket_->getSocketOption(interface::GeneralTransportOptions::PACKET_FORMAT, - format); - - auto interest = - core::PacketManager<>::getInstance().getPacket<Interest>(format); - interest->setName(n); - on_consumer_in_sync_(*socket_->getInterface(), *interest); - } - if (*on_content_object_output_) { on_content_object_output_->operator()(*socket_->getInterface(), *nack); } @@ -881,7 +697,7 @@ void RTCProductionProtocol::sendContentObject( portal_->sendContentObject(*content_object); // Compute and save data packet digest - if (making_manifest_ && !is_ah) { + if (manifest_max_capacity_ && !is_ah) { auth::CryptoHashType hash_algo; socket_->getSocketOption(interface::GeneralTransportOptions::HASH_ALGORITHM, hash_algo); diff --git a/libtransport/src/protocols/prod_protocol_rtc.h b/libtransport/src/protocols/prod_protocol_rtc.h index c0424a39c..285ccb646 100644 --- a/libtransport/src/protocols/prod_protocol_rtc.h +++ b/libtransport/src/protocols/prod_protocol_rtc.h @@ -17,6 +17,7 @@ #include <hicn/transport/core/name.h> #include <protocols/production_protocol.h> +#include <protocols/rtc/rtc_verifier.h> #include <atomic> #include <map> @@ -50,11 +51,6 @@ class RTCProductionProtocol : public ProductionProtocol { buffer, buffer_size, buffer_size)); } - void setConsumerInSyncCallback( - interface::ProducerInterestCallback &&callback) { - on_consumer_in_sync_ = std::move(callback); - } - auto shared_from_this() { return utils::shared_from(this); } private: @@ -80,13 +76,6 @@ class RTCProductionProtocol : public ProductionProtocol { void updateStats(bool new_round); void scheduleRoundTimer(); - // pending intersts functions - void addToInterestQueue(uint32_t interest_seg, uint64_t expiration); - void sendNacksForPendingInterests(); - void removeFromInterestQueue(uint32_t interest_seg); - void scheduleQueueTimer(uint64_t wait); - void interestQueueTimer(); - // FEC functions void onFecPackets(fec::BufferArray &packets); fec::buffer getBuffer(std::size_t size); @@ -111,14 +100,14 @@ class RTCProductionProtocol : public ProductionProtocol { uint32_t prev_produced_bytes_; // XXX clearly explain all these new vars uint32_t prev_produced_packets_; - uint32_t produced_bytes_; // bytes produced in the last round - uint32_t produced_packets_; // packet produed in the last round + uint32_t produced_bytes_; // bytes produced in the last round + uint32_t produced_packets_; // packet produed in the last round uint32_t max_packet_production_; // never exceed this number of packets // without update stats - uint32_t bytes_production_rate_; // bytes per sec - uint32_t packets_production_rate_; // pps + uint32_t bytes_production_rate_; // bytes per sec + uint32_t packets_production_rate_; // pps uint64_t last_produced_data_ts_; // ms @@ -134,27 +123,6 @@ class RTCProductionProtocol : public ProductionProtocol { // of the new rate. bool allow_delayed_nacks_; - // queue for the received interests - // this map maps the expiration time of an interest to - // its sequence number. the map is sorted by timeouts - // the same timeout may be used for multiple sequence numbers - // but for each sequence number we store only the smallest - // expiry time. In this way the mapping from seqs_map_ to - // timers_map_ is unique - std::multimap<uint64_t, uint32_t> timers_map_; - - // this map does the opposite, this map is not ordered - std::unordered_map<uint32_t, uint64_t> seqs_map_; - bool queue_timer_on_; - std::unique_ptr<asio::steady_timer> interests_queue_timer_; - - // this callback is called when the remote consumer is in sync with high - // probability. it is called only the first time that the switch happen. - // XXX this makes sense only in P2P mode, while in standard mode is - // impossible to know the state of the consumers so it should not be used. - bool consumer_in_sync_; - interface::ProducerInterestCallback on_consumer_in_sync_; - // Save FEC packets here before sending them std::queue<ContentObject::Ptr> pending_fec_packets_; std::queue<std::pair<uint64_t, ContentObject::Ptr>> paced_fec_packets_; @@ -172,6 +140,9 @@ class RTCProductionProtocol : public ProductionProtocol { // Manifest std::queue<std::pair<uint32_t, auth::CryptoHash>> manifest_entries_; // map a packet suffix to a packet hash + + // Verifier for aggregated interests + std::shared_ptr<rtc::RTCVerifier> verifier_; }; } // namespace protocol diff --git a/libtransport/src/protocols/production_protocol.cc b/libtransport/src/protocols/production_protocol.cc index 8b781e38a..039a6a55a 100644 --- a/libtransport/src/protocols/production_protocol.cc +++ b/libtransport/src/protocols/production_protocol.cc @@ -78,8 +78,8 @@ int ProductionProtocol::start() { socket_->getSocketOption(GeneralTransportOptions::ASYNC_MODE, is_async_); socket_->getSocketOption(GeneralTransportOptions::SIGNER, signer_); - socket_->getSocketOption(GeneralTransportOptions::MAKE_MANIFEST, - making_manifest_); + socket_->getSocketOption(GeneralTransportOptions::MANIFEST_MAX_CAPACITY, + manifest_max_capacity_); std::string fec_type_str = ""; socket_->getSocketOption(GeneralTransportOptions::FEC_TYPE, fec_type_str); diff --git a/libtransport/src/protocols/production_protocol.h b/libtransport/src/protocols/production_protocol.h index 8e10d2f40..09718631f 100644 --- a/libtransport/src/protocols/production_protocol.h +++ b/libtransport/src/protocols/production_protocol.h @@ -79,6 +79,7 @@ class ProductionProtocol if (fec_str && (fec_type_ == fec::FECType::UNKNOWN)) { LOG(INFO) << "Using FEC " << fec_str; fec_type_ = fec::FECUtils::fecTypeFromString(fec_str); + CHECK(fec_type_ != fec::FECType::UNKNOWN); } if (fec_type_ == fec::FECType::UNKNOWN) { @@ -123,7 +124,7 @@ class ProductionProtocol // Signature and manifest std::shared_ptr<auth::Signer> signer_; - uint32_t making_manifest_; + uint32_t manifest_max_capacity_; bool is_async_; fec::FECType fec_type_; diff --git a/libtransport/src/protocols/raaqm.cc b/libtransport/src/protocols/raaqm.cc index 131367d78..bcbc15aef 100644 --- a/libtransport/src/protocols/raaqm.cc +++ b/libtransport/src/protocols/raaqm.cc @@ -371,7 +371,7 @@ void RaaqmTransportProtocol::onPacketDropped(Interest &interest, } interest_retransmissions_[segment & mask]++; - interest_to_retransmit_.push(segment); + interest_to_retransmit_.push((unsigned int)segment); } else { LOG(ERROR) << "Stop: received not trusted packet " << interest_retransmissions_[segment & mask] << " times"; @@ -429,7 +429,7 @@ void RaaqmTransportProtocol::onInterestTimeout(Interest::Ptr &interest, return; } - interest_to_retransmit_.push(segment); + interest_to_retransmit_.push((unsigned int)segment); scheduleNextInterests(); } else { LOG(ERROR) << "Stop: reached max retx limit."; @@ -491,7 +491,7 @@ void RaaqmTransportProtocol::updateRtt(uint64_t segment) { throw std::runtime_error("RAAQM ERROR: no current path found, exit"); } else { auto now = utils::SteadyTime::Clock::now(); - utils::SteadyTime::Milliseconds rtt = utils::SteadyTime::getDurationMs( + auto rtt = utils::SteadyTime::getDurationUs( interest_timepoints_[segment & mask], now); // Update stats @@ -525,7 +525,7 @@ void RaaqmTransportProtocol::RAAQM() { } void RaaqmTransportProtocol::updateStats( - uint32_t suffix, const utils::SteadyTime::Milliseconds &rtt, + uint32_t suffix, const utils::SteadyTime::Microseconds &rtt, utils::SteadyTime::TimePoint &now) { // Update RTT statistics stats_->updateAverageRtt(rtt); diff --git a/libtransport/src/protocols/raaqm.h b/libtransport/src/protocols/raaqm.h index ec344c23a..a7ef23b68 100644 --- a/libtransport/src/protocols/raaqm.h +++ b/libtransport/src/protocols/raaqm.h @@ -57,7 +57,7 @@ class RaaqmTransportProtocol : public TransportProtocol, virtual void afterDataUnsatisfied(uint64_t segment); virtual void updateStats(uint32_t suffix, - const utils::SteadyTime::Milliseconds &rtt, + const utils::SteadyTime::Microseconds &rtt, utils::SteadyTime::TimePoint &now); private: diff --git a/libtransport/src/protocols/raaqm_data_path.cc b/libtransport/src/protocols/raaqm_data_path.cc index d06fee918..b8e6e6285 100644 --- a/libtransport/src/protocols/raaqm_data_path.cc +++ b/libtransport/src/protocols/raaqm_data_path.cc @@ -50,9 +50,9 @@ RaaqmDataPath::RaaqmDataPath(double drop_factor, alpha_(ALPHA) {} RaaqmDataPath &RaaqmDataPath::insertNewRtt( - const utils::SteadyTime::Milliseconds &new_rtt, + const utils::SteadyTime::Microseconds &new_rtt, const utils::SteadyTime::TimePoint &now) { - rtt_ = new_rtt.count(); + rtt_ = new_rtt.count() / 1000; rtt_samples_.pushBack(rtt_); rtt_max_ = rtt_samples_.rBegin(); diff --git a/libtransport/src/protocols/raaqm_data_path.h b/libtransport/src/protocols/raaqm_data_path.h index b6f7c5ac1..dd24dad51 100644 --- a/libtransport/src/protocols/raaqm_data_path.h +++ b/libtransport/src/protocols/raaqm_data_path.h @@ -49,7 +49,7 @@ class RaaqmDataPath { * max of RTT. * @param new_rtt is the value of the new RTT */ - RaaqmDataPath &insertNewRtt(const utils::SteadyTime::Milliseconds &new_rtt, + RaaqmDataPath &insertNewRtt(const utils::SteadyTime::Microseconds &new_rtt, const utils::SteadyTime::TimePoint &now); /** diff --git a/libtransport/src/protocols/rate_estimation.cc b/libtransport/src/protocols/rate_estimation.cc index d834b53e6..01c18c6cb 100644 --- a/libtransport/src/protocols/rate_estimation.cc +++ b/libtransport/src/protocols/rate_estimation.cc @@ -107,9 +107,9 @@ InterRttEstimator::~InterRttEstimator() { } void InterRttEstimator::onRttUpdate( - const utils::SteadyTime::Milliseconds &rtt) { + const utils::SteadyTime::Microseconds &rtt) { pthread_mutex_lock(&(this->mutex_)); - this->rtt_ = rtt.count(); + this->rtt_ = rtt.count() / 1000.0; this->number_of_packets_++; this->avg_rtt_ += this->rtt_; pthread_mutex_unlock(&(this->mutex_)); @@ -256,7 +256,7 @@ void SimpleEstimator::onDataReceived(int packet_size) { this->total_size_ += packet_size; } -void SimpleEstimator::onRttUpdate(const utils::SteadyTime::Milliseconds &rtt) { +void SimpleEstimator::onRttUpdate(const utils::SteadyTime::Microseconds &rtt) { this->number_of_packets_++; if (this->number_of_packets_ == this->batching_param_) { @@ -300,9 +300,9 @@ BatchingPacketsEstimator::BatchingPacketsEstimator(double alpha_arg, } void BatchingPacketsEstimator::onRttUpdate( - const utils::SteadyTime::Milliseconds &rtt) { + const utils::SteadyTime::Microseconds &rtt) { this->number_of_packets_++; - this->avg_rtt_ += rtt.count(); + this->avg_rtt_ += rtt.count() / 1000.0; if (number_of_packets_ == this->batching_param_) { if (estimation_ == 0) { diff --git a/libtransport/src/protocols/rate_estimation.h b/libtransport/src/protocols/rate_estimation.h index b71de12e4..d809b2b7c 100644 --- a/libtransport/src/protocols/rate_estimation.h +++ b/libtransport/src/protocols/rate_estimation.h @@ -31,7 +31,7 @@ class IcnRateEstimator : utils::NonCopyable { virtual ~IcnRateEstimator(){}; - virtual void onRttUpdate(const utils::SteadyTime::Milliseconds &rtt){}; + virtual void onRttUpdate(const utils::SteadyTime::Microseconds &rtt){}; virtual void onDataReceived(int packetSize){}; @@ -66,7 +66,7 @@ class InterRttEstimator : public IcnRateEstimator { ~InterRttEstimator(); - void onRttUpdate(const utils::SteadyTime::Milliseconds &rtt); + void onRttUpdate(const utils::SteadyTime::Microseconds &rtt); void onDataReceived(int packet_size) { if (packet_size > this->max_packet_size_) { @@ -101,7 +101,7 @@ class BatchingPacketsEstimator : public IcnRateEstimator { public: BatchingPacketsEstimator(double alpha_arg, int batchingParam); - void onRttUpdate(const utils::SteadyTime::Milliseconds &rtt); + void onRttUpdate(const utils::SteadyTime::Microseconds &rtt); void onDataReceived(int packet_size) { if (packet_size > this->max_packet_size_) { @@ -148,7 +148,7 @@ class SimpleEstimator : public IcnRateEstimator { public: SimpleEstimator(double alpha, int batching_param); - void onRttUpdate(const utils::SteadyTime::Milliseconds &rtt); + void onRttUpdate(const utils::SteadyTime::Microseconds &rtt); void onDataReceived(int packet_size); diff --git a/libtransport/src/protocols/reassembly.h b/libtransport/src/protocols/reassembly.h index b0879201d..c0c4de3d8 100644 --- a/libtransport/src/protocols/reassembly.h +++ b/libtransport/src/protocols/reassembly.h @@ -57,12 +57,6 @@ class Reassembly { virtual void reassemble(utils::MemBuf &buffer, uint32_t suffix) = 0; /** - * Handle reassembly of manifest - */ - virtual void reassemble( - std::unique_ptr<core::ContentObjectManifest> &&manifest) = 0; - - /** * Reset reassembler for new round */ virtual void reInitialize() = 0; diff --git a/libtransport/src/protocols/rtc/probe_handler.cc b/libtransport/src/protocols/rtc/probe_handler.cc index 6a84914ab..60eceeb19 100644 --- a/libtransport/src/protocols/rtc/probe_handler.cc +++ b/libtransport/src/protocols/rtc/probe_handler.cc @@ -13,6 +13,7 @@ * limitations under the License. */ +#include <glog/logging.h> #include <hicn/transport/utils/chrono_typedefs.h> #include <protocols/rtc/probe_handler.h> #include <protocols/rtc/rtc_consts.h> @@ -64,7 +65,7 @@ double ProbeHandler::getProbeLossRate() { } void ProbeHandler::setSuffixRange(uint32_t min, uint32_t max) { - assert(min <= max && min >= MIN_PROBE_SEQ); + DCHECK(min <= max && min >= MIN_PROBE_SEQ); distr_ = std::uniform_int_distribution<uint32_t>(min, max); } diff --git a/libtransport/src/protocols/rtc/rtc.cc b/libtransport/src/protocols/rtc/rtc.cc index d2682edfa..9a56269f3 100644 --- a/libtransport/src/protocols/rtc/rtc.cc +++ b/libtransport/src/protocols/rtc/rtc.cc @@ -38,6 +38,7 @@ RTCTransportProtocol::RTCTransportProtocol( implementation::ConsumerSocket *icn_socket) : TransportProtocol(icn_socket, new RtcIndexer<>(icn_socket, this), new RtcReassembly(icn_socket, this)), + max_aggregated_interest_(1), number_(0) { icn_socket->getSocketOption(PORTAL, portal_); round_timer_ = @@ -55,9 +56,9 @@ void RTCTransportProtocol::resume() { TransportProtocol::resume(); } -std::size_t RTCTransportProtocol::transportHeaderLength() { +std::size_t RTCTransportProtocol::transportHeaderLength(bool isFEC) { return DATA_HEADER_SIZE + - (fec_decoder_ != nullptr ? fec_decoder_->getFecHeaderSize() : 0); + (fec_decoder_ != nullptr ? fec_decoder_->getFecHeaderSize(isFEC) : 0); } // private @@ -75,13 +76,13 @@ void RTCTransportProtocol::initParams() { std::shared_ptr<auth::Verifier> verifier; socket_->getSocketOption(GeneralTransportOptions::VERIFIER, verifier); - uint32_t unverified_interval; - socket_->getSocketOption(GeneralTransportOptions::UNVERIFIED_INTERVAL, - unverified_interval); + uint32_t factor_relevant; + socket_->getSocketOption(GeneralTransportOptions::MANIFEST_FACTOR_RELEVANT, + factor_relevant); - double unverified_ratio; - socket_->getSocketOption(GeneralTransportOptions::UNVERIFIED_RATIO, - unverified_ratio); + uint32_t factor_alert; + socket_->getSocketOption(GeneralTransportOptions::MANIFEST_FACTOR_ALERT, + factor_alert); rc_ = std::make_shared<RTCRateControlCongestionDetection>(); ldr_ = std::make_shared<RTCLossDetectionAndRecovery>( @@ -100,8 +101,8 @@ void RTCTransportProtocol::initParams() { } }); - verifier_ = std::make_shared<RTCVerifier>(verifier, unverified_interval, - unverified_ratio); + verifier_ = + std::make_shared<RTCVerifier>(verifier, factor_relevant, factor_alert); state_ = std::make_shared<RTCState>( indexer_verifier_.get(), @@ -138,19 +139,20 @@ void RTCTransportProtocol::initParams() { last_interest_sent_time_ = 0; last_interest_sent_seq_ = 0; -#if 0 - if(portal_->isConnectedToFwd()){ - max_aggregated_interest_ = 1; - }else{ - max_aggregated_interest_ = MAX_INTERESTS_IN_BATCH; - } -#else - max_aggregated_interest_ = 1; - if (const char *max_aggr = std::getenv("MAX_AGGREGATED_INTERESTS")) { - LOG(INFO) << "Max Aggregated: " << max_aggr; - max_aggregated_interest_ = std::stoul(std::string(max_aggr)); + // Aggregated interests setup + bool aggregated_interests_on; + socket_->getSocketOption(RtcTransportOptions::AGGREGATED_INTERESTS, + aggregated_interests_on); + if (aggregated_interests_on) { + if (const char *max_aggr = std::getenv("MAX_AGGREGATED_INTERESTS")) + max_aggregated_interest_ = (uint32_t)std::stoul(std::string(max_aggr)); + else + max_aggregated_interest_ = MAX_INTERESTS_IN_BATCH; + + max_aggregated_interest_ = std::min<uint32_t>(max_aggregated_interest_, + 1 + MAX_SUFFIXES_IN_MANIFEST); } -#endif + LOG(INFO) << "Max Aggregated: " << max_aggregated_interest_; max_sent_int_ = std::ceil((double)MAX_PACING_BATCH / (double)max_aggregated_interest_); @@ -263,6 +265,11 @@ void RTCTransportProtocol::discoveredRtt() { socket_->getSocketOption(RtcTransportOptions::RECOVERY_STRATEGY, strategy); ldr_->changeRecoveryStrategy( (interface::RtcTransportRecoveryStrategies)strategy); + + bool content_sharing_mode; + socket_->getSocketOption(RtcTransportOptions::CONTENT_SHARING_MODE, + content_sharing_mode); + if (content_sharing_mode) ldr_->setContentSharingMode(); ldr_->turnOnRecovery(); ldr_->onNewRound(false); @@ -270,22 +277,9 @@ void RTCTransportProtocol::discoveredRtt() { Name *name = nullptr; socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, &name); Prefix prefix(*name, 128); - if ((interface::RtcTransportRecoveryStrategies)strategy == - interface::RtcTransportRecoveryStrategies::LOW_RATE_AND_BESTPATH) { - fwd_strategy_.initFwdStrategy(portal_, prefix, state_.get(), - RTCForwardingStrategy::BEST_PATH); - } else if ((interface::RtcTransportRecoveryStrategies)strategy == - interface::RtcTransportRecoveryStrategies:: - LOW_RATE_AND_REPLICATION) { - fwd_strategy_.initFwdStrategy(portal_, prefix, state_.get(), - RTCForwardingStrategy::REPLICATION); - } else if ((interface::RtcTransportRecoveryStrategies)strategy == - interface::RtcTransportRecoveryStrategies:: - LOW_RATE_AND_ALL_FWD_STRATEGIES) { - fwd_strategy_.initFwdStrategy(portal_, prefix, state_.get(), - RTCForwardingStrategy::BOTH); - } - + fwd_strategy_.initFwdStrategy( + portal_, prefix, state_.get(), + (interface::RtcTransportRecoveryStrategies)strategy); updateSyncWindow(); } @@ -302,6 +296,12 @@ void RTCTransportProtocol::computeMaxSyncWindow() { return; } + bool content_sharing_mode; + socket_->getSocketOption(RtcTransportOptions::CONTENT_SHARING_MODE, + content_sharing_mode); + if (content_sharing_mode && (production_rate < MIN_PROD_RATE_SHARING_MODE)) + production_rate = MIN_PROD_RATE_SHARING_MODE; + production_rate += (production_rate * indexer_verifier_->getMaxFecOverhead()); uint32_t lifetime = default_values::interest_lifetime; @@ -330,6 +330,11 @@ void RTCTransportProtocol::updateSyncWindow() { double prod_rate = state_->getProducerRate(); double rtt = (double)state_->getMinRTT() / MILLI_IN_A_SEC; double packet_size = state_->getAveragePacketSize(); + bool content_sharing_mode; + socket_->getSocketOption(RtcTransportOptions::CONTENT_SHARING_MODE, + content_sharing_mode); + if (content_sharing_mode && (prod_rate < MIN_PROD_RATE_SHARING_MODE)) + prod_rate = MIN_PROD_RATE_SHARING_MODE; // if some of the info are not available do not update the current win if (prod_rate != 0.0 && rtt != 0.0 && packet_size != 0.0) { @@ -385,6 +390,19 @@ void RTCTransportProtocol::sendProbeInterest(uint32_t seq) { sendInterest(*interest_name); } +void RTCTransportProtocol::sendInterestForTimeout(uint32_t seq) { + if (!isRunning() && !is_first_) return; + + Name *interest_name = nullptr; + socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, + &interest_name); + + // we got a timeout for this packet so it is not pending anymore + interest_name->setSuffix(seq); + state_->onSendNewInterest(interest_name); + sendInterest(*interest_name); +} + void RTCTransportProtocol::scheduleNextInterests() { DLOG_IF(INFO, VLOG_IS_ON(3)) << "Schedule next interests"; @@ -475,9 +493,9 @@ void RTCTransportProtocol::scheduleNextInterests() { } // skip received packets - if (indexer_verifier_->checkNextSuffix() <= - state_->getHighestSeqReceivedInOrder()) { - indexer_verifier_->jumpToIndex(state_->getHighestSeqReceivedInOrder() + 1); + uint32_t max_received = state_->getHighestSeqReceivedInOrder(); + if (indexer_verifier_->checkNextSuffix() <= max_received) { + indexer_verifier_->jumpToIndex(max_received + 1); } uint32_t sent_interests = 0; @@ -495,7 +513,6 @@ void RTCTransportProtocol::scheduleNextInterests() { << "In while loop. Window size: " << current_sync_win_; uint32_t next_seg = indexer_verifier_->getNextSuffix(); - name->setSuffix(next_seg); // send the packet only if: @@ -586,7 +603,6 @@ void RTCTransportProtocol::onInterestTimeout(Interest::Ptr &interest, } timeouts_or_nacks_.insert(segment_number); - if (TRANSPORT_EXPECT_TRUE(state_->isProducerActive()) && segment_number <= state_->getHighestSeqReceived()) { // we retransmit packets only if the producer is active, otherwise we @@ -627,11 +643,11 @@ void RTCTransportProtocol::onInterestTimeout(Interest::Ptr &interest, << "On timeout next seg = " << indexer_verifier_->checkNextSuffix() << ", jump to " << segment_number; // add an extra space in the window - current_sync_win_++; indexer_verifier_->jumpToIndex(segment_number); } state_->onTimeout(segment_number, false); + sendInterestForTimeout(segment_number); scheduleNextInterests(); } @@ -672,7 +688,6 @@ void RTCTransportProtocol::onNack(const ContentObject &content_object) { if (tn_it != timeouts_or_nacks_.end()) timeouts_or_nacks_.erase(tn_it); state_->onJumpForward(production_seg); - verifier_->onJumpForward(production_seg); // the client is asking for content in the past // switch to catch up state and increase the window // this is true only if the packet is not an RTX @@ -821,7 +836,8 @@ void RTCTransportProtocol::onContentObjectReceived( // Check if the packet is a retransmission if (ldr_->isRtx(segment_number) && state != PacketState::RECEIVED) { if (is_data || is_manifest) { - state_->onPacketRecoveredRtx(segment_number); + uint64_t rtt = ldr_->getRtxRtt(segment_number); + state_->onPacketRecoveredRtx(content_object, rtt); if (*on_content_object_input_) { (*on_content_object_input_)(*socket_->getInterface(), content_object); @@ -842,7 +858,7 @@ void RTCTransportProtocol::onContentObjectReceived( } if (is_fec) { - state_->onFecPacketRecoveredRtx(segment_number); + state_->onFecPacketRecoveredRtx(content_object); } } @@ -920,7 +936,7 @@ void RTCTransportProtocol::sendStatsToApp( stats_->updateAverageWindowSize(state_->getPendingInterestNumber()); stats_->updateLossRatio(state_->getPerSecondLossRate()); uint64_t rtt = state_->getAvgRTT(); - stats_->updateAverageRtt(utils::SteadyTime::Milliseconds(rtt)); + stats_->updateAverageRtt(utils::SteadyTime::Microseconds(rtt * 1000)); stats_->updateQueuingDelay(state_->getQueuing()); stats_->updateLostData(lost_data); @@ -960,9 +976,10 @@ void RTCTransportProtocol::decodePacket(ContentObject &content_object, DLOG_IF(INFO, VLOG_IS_ON(4)) << "Send packet " << content_object.getName() << " to FEC decoder"; - uint32_t offset = is_manifest - ? content_object.headerSize() - : content_object.headerSize() + rtc::DATA_HEADER_SIZE; + uint32_t offset = + is_manifest + ? (uint32_t)content_object.headerSize() + : (uint32_t)(content_object.headerSize() + rtc::DATA_HEADER_SIZE); uint32_t metadata = static_cast<uint32_t>(content_object.getPayloadType()); fec_decoder_->onDataPacket(content_object, offset, metadata); @@ -1016,7 +1033,7 @@ void RTCTransportProtocol::onFecPackets(fec::BufferArray &packets) { processManifest(*interest, *content_object); } - state_->onPacketRecoveredFec(seq_number, buffer->length()); + state_->onPacketRecoveredFec(seq_number, (uint32_t)buffer->length()); ldr_->onPacketRecoveredFec(seq_number); if (payload_type == PayloadType::DATA) { @@ -1038,11 +1055,11 @@ void RTCTransportProtocol::processManifest(Interest &interest, ContentObject::Ptr RTCTransportProtocol::removeFecHeader( const ContentObject &content_object) { - if (!fec_decoder_ || !fec_decoder_->getFecHeaderSize()) { + if (!fec_decoder_ || !fec_decoder_->getFecHeaderSize(false)) { return nullptr; } - size_t fec_header_size = fec_decoder_->getFecHeaderSize(); + size_t fec_header_size = fec_decoder_->getFecHeaderSize(false); const uint8_t *payload = content_object.data() + content_object.headerSize() + fec_header_size; size_t payload_size = content_object.payloadSize() - fec_header_size; diff --git a/libtransport/src/protocols/rtc/rtc.h b/libtransport/src/protocols/rtc/rtc.h index 3763f33c7..a8a474216 100644 --- a/libtransport/src/protocols/rtc/rtc.h +++ b/libtransport/src/protocols/rtc/rtc.h @@ -44,7 +44,7 @@ class RTCTransportProtocol : public TransportProtocol { void resume() override; - std::size_t transportHeaderLength() override; + std::size_t transportHeaderLength(bool isFEC) override; auto shared_from_this() { return utils::shared_from(this); } @@ -69,6 +69,7 @@ class RTCTransportProtocol : public TransportProtocol { // packet functions void sendRtxInterest(uint32_t seq); void sendProbeInterest(uint32_t seq); + void sendInterestForTimeout(uint32_t seq); void scheduleNextInterests() override; void onInterestTimeout(Interest::Ptr &interest, const Name &name) override; void onNack(const ContentObject &content_object); diff --git a/libtransport/src/protocols/rtc/rtc_consts.h b/libtransport/src/protocols/rtc/rtc_consts.h index 96e39d07e..29b5a3a12 100644 --- a/libtransport/src/protocols/rtc/rtc_consts.h +++ b/libtransport/src/protocols/rtc/rtc_consts.h @@ -54,7 +54,7 @@ const uint32_t PACING_WAIT = 1000; // usec to wait betwing two pacing batch. As const uint32_t MAX_RTX_IN_BATCH = 10; // max rtx to send in loop // packet const -const uint32_t RTC_INTEREST_LIFETIME = 2000; +const uint32_t RTC_INTEREST_LIFETIME = 4000; // probes sequence range const uint32_t MIN_PROBE_SEQ = 0xefffffff; @@ -93,6 +93,7 @@ const double CATCH_UP_WIN_INCREMENT = 1.2; // used in rate control const double WIN_DECREASE_FACTOR = 0.5; const double WIN_INCREASE_FACTOR = 1.5; +const uint32_t MIN_PROD_RATE_SHARING_MODE = 125000; // 1Mbps in bytes // round in congestion const double ROUNDS_BEFORE_TAKE_ACTION = 5; @@ -120,14 +121,14 @@ const uint64_t MAX_TIMER_RTX = ~0; const uint32_t SENTINEL_TIMER_INTERVAL = 100; // ms const uint32_t MAX_RTX_WITH_SENTINEL = 10; // packets const double CATCH_UP_RTT_INCREMENT = 1.2; -const double MAX_RESIDUAL_LOSS_RATE = 2.0; // % -const uint32_t WAIT_BEFORE_FEC_UPDATE = ROUNDS_PER_SEC * 5; +const double MAX_RESIDUAL_LOSS_RATE = 1.0; // % +const uint32_t WAIT_BEFORE_FEC_UPDATE = ROUNDS_PER_SEC; +const uint32_t MAX_RTT_BEFORE_FEC = 60; // ms // used by producer const uint32_t PRODUCER_STATS_INTERVAL = 200; // ms const uint32_t MIN_PRODUCTION_RATE = 25; // pps, equal to min window * // rounds in a second -const uint32_t NACK_DELAY = 1500; // ms const uint32_t FEC_PACING_TIME = 5; // ms // aggregated data consts @@ -139,6 +140,73 @@ const uint32_t AGGREGATED_PACKETS_TIMER = 2; // ms const uint32_t MAX_RTT = 200; // ms const double MAX_RESIDUAL_LOSSES = 0.05; // % +const uint8_t FEC_MATRIX[64][10] = { + {1, 2, 2, 2, 3, 3, 4, 5, 5, 6}, // k = 1 + {1, 2, 3, 3, 4, 5, 5, 6, 7, 9}, + {2, 2, 3, 4, 5, 6, 7, 8, 9, 11}, + {2, 3, 4, 5, 5, 7, 8, 9, 11, 13}, + {2, 3, 4, 5, 6, 7, 9, 10, 12, 14}, // k = 5 + {2, 3, 4, 6, 7, 8, 10, 12, 14, 16}, + {2, 4, 5, 6, 8, 9, 11, 13, 15, 18}, + {3, 4, 5, 7, 8, 10, 12, 14, 16, 19}, + {3, 4, 6, 7, 9, 11, 13, 15, 18, 21}, + {3, 4, 6, 8, 9, 11, 14, 16, 19, 23}, // k = 10 + {3, 5, 6, 8, 10, 12, 14, 17, 20, 24}, + {3, 5, 7, 8, 10, 13, 15, 18, 21, 26}, + {3, 5, 7, 9, 11, 13, 16, 19, 23, 27}, + {3, 5, 7, 9, 12, 14, 17, 20, 24, 28}, + {4, 6, 8, 10, 12, 15, 18, 21, 25, 30}, // k = 15 + {4, 6, 8, 10, 13, 15, 19, 22, 26, 31}, + {4, 6, 8, 11, 13, 16, 19, 23, 27, 33}, + {4, 6, 9, 11, 14, 17, 20, 24, 29, 34}, + {4, 6, 9, 11, 14, 17, 21, 25, 30, 35}, + {4, 7, 9, 12, 15, 18, 22, 26, 31, 37}, // k = 20 + {4, 7, 9, 12, 15, 19, 22, 27, 32, 38}, + {4, 7, 10, 13, 16, 19, 23, 28, 33, 40}, + {5, 7, 10, 13, 16, 20, 24, 29, 34, 41}, + {5, 7, 10, 13, 17, 20, 25, 30, 35, 42}, + {5, 8, 11, 14, 17, 21, 26, 31, 37, 44}, // k = 25 + {5, 8, 11, 14, 18, 22, 26, 31, 38, 45}, + {5, 8, 11, 15, 18, 22, 27, 32, 39, 46}, + {5, 8, 11, 15, 19, 23, 28, 33, 40, 48}, + {5, 8, 12, 15, 19, 24, 28, 34, 41, 49}, + {5, 9, 12, 16, 20, 24, 29, 35, 42, 50}, // k = 30 + {5, 9, 12, 16, 20, 25, 30, 36, 43, 51}, + {5, 9, 13, 16, 21, 25, 31, 37, 44, 53}, + {6, 9, 13, 17, 21, 26, 31, 38, 45, 54}, + {6, 9, 13, 17, 22, 26, 32, 39, 46, 55}, + {6, 10, 13, 17, 22, 27, 33, 40, 47, 57}, // k = 35 + {6, 10, 14, 18, 22, 28, 34, 40, 48, 58}, + {6, 10, 14, 18, 23, 28, 34, 41, 49, 59}, + {6, 10, 14, 19, 23, 29, 35, 42, 50, 60}, + {6, 10, 14, 19, 24, 29, 36, 43, 52, 62}, + {6, 10, 15, 19, 24, 30, 36, 44, 53, 63}, // k = 40 + {6, 11, 15, 20, 25, 31, 37, 45, 54, 64}, + {6, 11, 15, 20, 25, 31, 38, 46, 55, 65}, + {7, 11, 15, 20, 26, 32, 39, 46, 56, 67}, + {7, 11, 16, 21, 26, 32, 39, 47, 57, 68}, + {7, 11, 16, 21, 27, 33, 40, 48, 58, 69}, // k = 45 + {7, 11, 16, 21, 27, 33, 41, 49, 59, 70}, + {7, 12, 16, 22, 27, 34, 41, 50, 60, 72}, + {7, 12, 17, 22, 28, 34, 42, 51, 61, 73}, + {7, 12, 17, 22, 28, 35, 43, 52, 62, 74}, + {7, 12, 17, 23, 29, 36, 43, 52, 63, 75}, // k = 50 + {7, 12, 17, 23, 29, 36, 44, 53, 64, 77}, + {7, 12, 18, 23, 30, 37, 45, 54, 65, 78}, + {7, 13, 18, 24, 30, 37, 45, 55, 66, 79}, + {8, 13, 18, 24, 31, 38, 46, 56, 67, 80}, + {8, 13, 18, 24, 31, 38, 47, 57, 68, 82}, // k = 55 + {8, 13, 19, 25, 31, 39, 47, 57, 69, 83}, + {8, 13, 19, 25, 32, 39, 48, 58, 70, 84}, + {8, 13, 19, 25, 32, 40, 49, 59, 71, 85}, + {8, 14, 19, 26, 33, 41, 50, 60, 72, 86}, + {8, 14, 20, 26, 33, 41, 50, 61, 73, 88}, // k = 60 + {8, 14, 20, 26, 34, 42, 51, 61, 74, 89}, + {8, 14, 20, 27, 34, 42, 52, 62, 75, 90}, + {8, 14, 20, 27, 34, 43, 52, 63, 76, 91}, + {8, 14, 21, 27, 35, 43, 53, 64, 77, 92}, // k = 64 +}; + } // namespace rtc } // namespace protocol diff --git a/libtransport/src/protocols/rtc/rtc_data_path.cc b/libtransport/src/protocols/rtc/rtc_data_path.cc index b3abf5ea8..a421396b1 100644 --- a/libtransport/src/protocols/rtc/rtc_data_path.cc +++ b/libtransport/src/protocols/rtc/rtc_data_path.cc @@ -91,6 +91,8 @@ void RTCDataPath::insertRttSample( rtt_samples_ = 0; last_avg_rtt_compute_ = now; } + + received_packets_++; } void RTCDataPath::insertOwdSample(int64_t owd) { @@ -115,10 +117,6 @@ void RTCDataPath::insertOwdSample(int64_t owd) { int64_t diff = std::abs(owd - last_owd_); last_owd_ = owd; jitter_ += (1.0 / 16.0) * ((double)diff - jitter_); - - // owd is computed only for valid data packets so we count only - // this for decide if we recevie traffic or not - received_packets_++; } void RTCDataPath::computeInterArrivalGap(uint32_t segment_number) { @@ -150,12 +148,17 @@ double RTCDataPath::getInterArrivalGap() { return avg_inter_arrival_; } -bool RTCDataPath::isActive() { +bool RTCDataPath::isValidProducer() { if (received_nacks_ && rounds_without_packets_ < MAX_ROUNDS_WITHOUT_PKTS) return true; return false; } +bool RTCDataPath::isActive() { + if (rounds_without_packets_ < MAX_ROUNDS_WITHOUT_PKTS) return true; + return false; +} + bool RTCDataPath::pathToProducer() { if (received_nacks_) return true; return false; diff --git a/libtransport/src/protocols/rtc/rtc_data_path.h b/libtransport/src/protocols/rtc/rtc_data_path.h index 5afbbb87f..ba5201fe8 100644 --- a/libtransport/src/protocols/rtc/rtc_data_path.h +++ b/libtransport/src/protocols/rtc/rtc_data_path.h @@ -49,8 +49,9 @@ class RTCDataPath { double getQueuingDealy(); double getInterArrivalGap(); double getJitter(); - bool isActive(); - bool pathToProducer(); + bool isActive(); // pakets recevied from this path in the last rounds + bool pathToProducer(); // path from a producer + bool isValidProducer(); // path from a producer that is also active uint64_t getLastPacketTS(); uint32_t getPacketsLastRound(); diff --git a/libtransport/src/protocols/rtc/rtc_forwarding_strategy.cc b/libtransport/src/protocols/rtc/rtc_forwarding_strategy.cc index c6bc751e6..4bbd7eac0 100644 --- a/libtransport/src/protocols/rtc/rtc_forwarding_strategy.cc +++ b/libtransport/src/protocols/rtc/rtc_forwarding_strategy.cc @@ -14,6 +14,7 @@ */ #include <hicn/transport/interfaces/notification.h> +#include <protocols/rtc/rtc_consts.h> #include <protocols/rtc/rtc_forwarding_strategy.h> namespace transport { @@ -24,8 +25,13 @@ namespace rtc { using namespace transport::interface; +const double FWD_MAX_QUEUE = 30.0; // ms +const double FWD_MAX_RTT = MAX_RTT_BEFORE_FEC; // ms +const double FWD_MAX_LOSS_RATE = 0.1; + RTCForwardingStrategy::RTCForwardingStrategy() - : init_(false), + : low_rate_app_(false), + init_(false), forwarder_set_(false), selected_strategy_(NONE), current_strategy_(NONE), @@ -42,17 +48,56 @@ void RTCForwardingStrategy::setCallback( void RTCForwardingStrategy::initFwdStrategy( std::shared_ptr<core::Portal> portal, core::Prefix& prefix, RTCState* state, - strategy_t strategy) { - init_ = true; - selected_strategy_ = strategy; - if (strategy == BOTH) - current_strategy_ = BEST_PATH; - else - current_strategy_ = strategy; - rounds_since_last_set_ = 0; - prefix_ = prefix; - portal_ = portal; - state_ = state; + interface::RtcTransportRecoveryStrategies strategy) { + switch (strategy) { + case interface::RtcTransportRecoveryStrategies::LOW_RATE_AND_BESTPATH: + init_ = true; + low_rate_app_ = true; + selected_strategy_ = BEST_PATH; + current_strategy_ = BEST_PATH; + break; + case interface::RtcTransportRecoveryStrategies::LOW_RATE_AND_REPLICATION: + init_ = true; + low_rate_app_ = true; + selected_strategy_ = REPLICATION; + current_strategy_ = REPLICATION; + break; + case interface::RtcTransportRecoveryStrategies:: + LOW_RATE_AND_ALL_FWD_STRATEGIES: + init_ = true; + low_rate_app_ = true; + selected_strategy_ = BEST_PATH; + current_strategy_ = BEST_PATH; + break; + case interface::RtcTransportRecoveryStrategies::DELAY_AND_BESTPATH: + init_ = true; + low_rate_app_ = false; + selected_strategy_ = BEST_PATH; + current_strategy_ = BEST_PATH; + break; + case interface::RtcTransportRecoveryStrategies::DELAY_AND_REPLICATION: + init_ = true; + low_rate_app_ = false; + selected_strategy_ = REPLICATION; + current_strategy_ = REPLICATION; + break; + case interface::RtcTransportRecoveryStrategies::RECOVERY_OFF: + case interface::RtcTransportRecoveryStrategies::RTX_ONLY: + case interface::RtcTransportRecoveryStrategies::FEC_ONLY: + case interface::RtcTransportRecoveryStrategies::DELAY_BASED: + case interface::RtcTransportRecoveryStrategies::LOW_RATE: + case interface::RtcTransportRecoveryStrategies::FEC_ONLY_LOW_RES_LOSSES: + default: + // fwd strategies are not used + init_ = false; + } + + if (init_) { + rounds_since_last_set_ = 0; + prefix_ = prefix; + portal_ = portal; + state_ = state; + } } void RTCForwardingStrategy::checkStrategy() { @@ -99,16 +144,35 @@ void RTCForwardingStrategy::checkStrategyBestPath() { return; } - uint8_t qs = state_->getQualityScore(); + if (low_rate_app_) { + // this is used for gaming + uint8_t qs = state_->getQualityScore(); - if (qs >= 4 || rounds_since_last_set_ < 25) { // wait a least 5 sec - // between each switch - rounds_since_last_set_++; - return; - } + if (qs >= 4 || rounds_since_last_set_ < 25) { // wait a least 5 sec + // between each switch + rounds_since_last_set_++; + return; + } - // try to switch path - setStrategy(BEST_PATH); + // try to switch path + setStrategy(BEST_PATH); + } else { + if (rounds_since_last_set_ < 25) { // wait a least 5 sec + // between each switch + rounds_since_last_set_++; + return; + } + + double queue = state_->getQueuing(); + double rtt = state_->getAvgRTT(); + double loss_rate = state_->getPerSecondLossRate(); + + if (queue >= FWD_MAX_QUEUE || rtt >= FWD_MAX_RTT || + loss_rate > FWD_MAX_LOSS_RATE) { + // try to switch path + setStrategy(BEST_PATH); + } + } } void RTCForwardingStrategy::checkStrategyReplication() { @@ -133,7 +197,7 @@ void RTCForwardingStrategy::checkStrategyBoth() { // TODO // for the moment we use only best path. - // but later: + // for later: // 1. if both paths are bad use replication // 2. while using replication compute the effectiveness. if the majority of // the packets are coming from a single path, try to use bestpath diff --git a/libtransport/src/protocols/rtc/rtc_forwarding_strategy.h b/libtransport/src/protocols/rtc/rtc_forwarding_strategy.h index 9825877fd..c2227e09f 100644 --- a/libtransport/src/protocols/rtc/rtc_forwarding_strategy.h +++ b/libtransport/src/protocols/rtc/rtc_forwarding_strategy.h @@ -41,7 +41,7 @@ class RTCForwardingStrategy { void initFwdStrategy(std::shared_ptr<core::Portal> portal, core::Prefix& prefix, RTCState* state, - strategy_t strategy); + interface::RtcTransportRecoveryStrategies strategy); void checkStrategy(); void setCallback(interface::StrategyCallback&& callback); @@ -56,6 +56,10 @@ class RTCForwardingStrategy { std::array<std::string, 4> string_strategies_ = {"bestpath", "replication", "both", "none"}; + bool low_rate_app_; // if set to true the best path strategy will + // trigger a path switch based on the quality + // score, otherwise it will use the RTT, + // queuing delay and loss rate bool init_; // true if all val are initializes bool forwarder_set_; // true if the strategy is been set at least // once diff --git a/libtransport/src/protocols/rtc/rtc_ldr.cc b/libtransport/src/protocols/rtc/rtc_ldr.cc index abf6cda2c..6e88a8636 100644 --- a/libtransport/src/protocols/rtc/rtc_ldr.cc +++ b/libtransport/src/protocols/rtc/rtc_ldr.cc @@ -37,16 +37,24 @@ RTCLossDetectionAndRecovery::RTCLossDetectionAndRecovery( interface::RtcTransportRecoveryStrategies type, RecoveryStrategy::SendRtxCallback &&callback, interface::StrategyCallback &&external_callback) { - rs_type_ = type; if (type == interface::RtcTransportRecoveryStrategies::RECOVERY_OFF) { rs_ = std::make_shared<RecoveryStrategyRecoveryOff>( - indexer, std::move(callback), io_service, std::move(external_callback)); - } else if (type == interface::RtcTransportRecoveryStrategies::DELAY_BASED) { + indexer, std::move(callback), io_service, type, + std::move(external_callback)); + } else if (type == interface::RtcTransportRecoveryStrategies::DELAY_BASED || + type == interface::RtcTransportRecoveryStrategies:: + DELAY_AND_BESTPATH || + type == interface::RtcTransportRecoveryStrategies:: + DELAY_AND_REPLICATION) { rs_ = std::make_shared<RecoveryStrategyDelayBased>( - indexer, std::move(callback), io_service, std::move(external_callback)); - } else if (type == interface::RtcTransportRecoveryStrategies::FEC_ONLY) { + indexer, std::move(callback), io_service, type, + std::move(external_callback)); + } else if (type == interface::RtcTransportRecoveryStrategies::FEC_ONLY || + type == interface::RtcTransportRecoveryStrategies:: + FEC_ONLY_LOW_RES_LOSSES) { rs_ = std::make_shared<RecoveryStrategyFecOnly>( - indexer, std::move(callback), io_service, std::move(external_callback)); + indexer, std::move(callback), io_service, type, + std::move(external_callback)); } else if (type == interface::RtcTransportRecoveryStrategies::LOW_RATE || type == interface::RtcTransportRecoveryStrategies:: LOW_RATE_AND_BESTPATH || @@ -55,12 +63,14 @@ RTCLossDetectionAndRecovery::RTCLossDetectionAndRecovery( type == interface::RtcTransportRecoveryStrategies:: LOW_RATE_AND_ALL_FWD_STRATEGIES) { rs_ = std::make_shared<RecoveryStrategyLowRate>( - indexer, std::move(callback), io_service, std::move(external_callback)); + indexer, std::move(callback), io_service, type, + std::move(external_callback)); } else { // default - rs_type_ = interface::RtcTransportRecoveryStrategies::RTX_ONLY; + type = interface::RtcTransportRecoveryStrategies::RTX_ONLY; rs_ = std::make_shared<RecoveryStrategyRtxOnly>( - indexer, std::move(callback), io_service, std::move(external_callback)); + indexer, std::move(callback), io_service, type, + std::move(external_callback)); } } @@ -68,15 +78,21 @@ RTCLossDetectionAndRecovery::~RTCLossDetectionAndRecovery() {} void RTCLossDetectionAndRecovery::changeRecoveryStrategy( interface::RtcTransportRecoveryStrategies type) { - if (type == rs_type_) return; + if (type == rs_->getType()) return; - rs_type_ = type; + rs_->updateType(type); if (type == interface::RtcTransportRecoveryStrategies::RECOVERY_OFF) { rs_ = std::make_shared<RecoveryStrategyRecoveryOff>(std::move(*(rs_.get()))); - } else if (type == interface::RtcTransportRecoveryStrategies::DELAY_BASED) { + } else if (type == interface::RtcTransportRecoveryStrategies::DELAY_BASED || + type == interface::RtcTransportRecoveryStrategies:: + DELAY_AND_BESTPATH || + type == interface::RtcTransportRecoveryStrategies:: + DELAY_AND_REPLICATION) { rs_ = std::make_shared<RecoveryStrategyDelayBased>(std::move(*(rs_.get()))); - } else if (type == interface::RtcTransportRecoveryStrategies::FEC_ONLY) { + } else if (type == interface::RtcTransportRecoveryStrategies::FEC_ONLY || + type == interface::RtcTransportRecoveryStrategies:: + FEC_ONLY_LOW_RES_LOSSES) { rs_ = std::make_shared<RecoveryStrategyFecOnly>(std::move(*(rs_.get()))); } else if (type == interface::RtcTransportRecoveryStrategies::LOW_RATE || type == interface::RtcTransportRecoveryStrategies:: @@ -116,14 +132,15 @@ bool RTCLossDetectionAndRecovery::onDataPacketReceived( uint32_t seq = content_object.getName().getSuffix(); bool is_rtx = rs_->isRtx(seq); rs_->receivedPacket(seq); + bool ret = false; DLOG_IF(INFO, VLOG_IS_ON(3)) << "received data. add from " - << rs_->getState()->getHighestSeqReceivedInOrder() + 1 << " to " << seq; + << rs_->getState()->getHighestSeqReceived() + 1 << " to " << seq; if (!is_rtx) - return detectLoss(rs_->getState()->getHighestSeqReceivedInOrder() + 1, seq, - false); + ret = detectLoss(rs_->getState()->getHighestSeqReceived() + 1, seq, false); - return false; + rs_->getState()->updateHighestSeqReceived(seq); + return ret; } bool RTCLossDetectionAndRecovery::onNackPacketReceived( @@ -141,10 +158,9 @@ bool RTCLossDetectionAndRecovery::onNackPacketReceived( // may got lost and we should ask them rs_->receivedPacket(seq); - DLOG_IF(INFO, VLOG_IS_ON(3)) - << "received nack. add from " - << rs_->getState()->getHighestSeqReceivedInOrder() + 1 << " to " - << production_seq; + DLOG_IF(INFO, VLOG_IS_ON(3)) << "received nack. add from " + << rs_->getState()->getHighestSeqReceived() + 1 + << " to " << production_seq; // if it is a future nack store it in the list set of nacked seq if (production_seq <= seq) rs_->receivedFutureNack(seq); @@ -152,7 +168,7 @@ bool RTCLossDetectionAndRecovery::onNackPacketReceived( // call the detectLoss function using the probe flag = true. in fact the // losses detected using nacks are the same as the one detected using probes, // we should not increase the loss counter - return detectLoss(rs_->getState()->getHighestSeqReceivedInOrder() + 1, + return detectLoss(rs_->getState()->getHighestSeqReceived() + 1, production_seq, true); } @@ -164,12 +180,11 @@ bool RTCLossDetectionAndRecovery::onProbePacketReceived( uint32_t production_seq = RTCState::getProbeParams(probe).prod_seg; - DLOG_IF(INFO, VLOG_IS_ON(3)) - << "received probe. add from " - << rs_->getState()->getHighestSeqReceivedInOrder() + 1 << " to " - << production_seq; + DLOG_IF(INFO, VLOG_IS_ON(3)) << "received probe. add from " + << rs_->getState()->getHighestSeqReceived() + 1 + << " to " << production_seq; - return detectLoss(rs_->getState()->getHighestSeqReceivedInOrder() + 1, + return detectLoss(rs_->getState()->getHighestSeqReceived() + 1, production_seq, true); } @@ -183,8 +198,8 @@ bool RTCLossDetectionAndRecovery::detectLoss(uint32_t start, uint32_t stop, } // skip received or lost packets - if (start <= rs_->getState()->getHighestSeqReceivedInOrder()) { - start = rs_->getState()->getHighestSeqReceivedInOrder() + 1; + if (start <= rs_->getState()->getHighestSeqReceived()) { + start = rs_->getState()->getHighestSeqReceived() + 1; } bool loss_detected = false; diff --git a/libtransport/src/protocols/rtc/rtc_ldr.h b/libtransport/src/protocols/rtc/rtc_ldr.h index 7f683eaa6..24f22ffed 100644 --- a/libtransport/src/protocols/rtc/rtc_ldr.h +++ b/libtransport/src/protocols/rtc/rtc_ldr.h @@ -47,6 +47,7 @@ class RTCLossDetectionAndRecovery void setFecParams(uint32_t n, uint32_t k) { rs_->setFecParams(n, k); } + void setContentSharingMode() { rs_->setContentSharingMode(); } void turnOnRecovery() { rs_->turnOnRecovery(); } bool isRtxOn() { return rs_->isRtxOn(); } @@ -68,11 +69,12 @@ class RTCLossDetectionAndRecovery return rs_->isPossibleLossWithNoRtx(seq); } + uint64_t getRtxRtt(uint32_t seq) { return rs_->getRtxRtt(seq); } + private: // returns true if a loss is detected, false otherwise bool detectLoss(uint32_t start, uint32_t stop, bool recv_probe); - interface::RtcTransportRecoveryStrategies rs_type_; std::shared_ptr<RecoveryStrategy> rs_; }; diff --git a/libtransport/src/protocols/rtc/rtc_packet.h b/libtransport/src/protocols/rtc/rtc_packet.h index 391aedfc6..ffbbd78fd 100644 --- a/libtransport/src/protocols/rtc/rtc_packet.h +++ b/libtransport/src/protocols/rtc/rtc_packet.h @@ -52,6 +52,8 @@ #include <hicn/transport/portability/win_portability.h> #endif +#include <hicn/transport/portability/endianess.h> + #include <cstring> namespace transport { @@ -60,24 +62,6 @@ namespace protocol { namespace rtc { -inline uint64_t _ntohll(const uint64_t *input) { - uint64_t return_val; - uint8_t *tmp = (uint8_t *)&return_val; - - tmp[0] = (uint8_t)(*input >> 56); - tmp[1] = (uint8_t)(*input >> 48); - tmp[2] = (uint8_t)(*input >> 40); - tmp[3] = (uint8_t)(*input >> 32); - tmp[4] = (uint8_t)(*input >> 24); - tmp[5] = (uint8_t)(*input >> 16); - tmp[6] = (uint8_t)(*input >> 8); - tmp[7] = (uint8_t)(*input >> 0); - - return return_val; -} - -inline uint64_t _htonll(const uint64_t *input) { return (_ntohll(input)); } - const uint32_t DATA_HEADER_SIZE = 12; // bytes // XXX: sizeof(data_packet_t) is 16 // beacuse of padding @@ -87,11 +71,19 @@ struct data_packet_t { uint64_t timestamp; uint32_t prod_rate; - inline uint64_t getTimestamp() const { return _ntohll(×tamp); } - inline void setTimestamp(uint64_t time) { timestamp = _htonll(&time); } + inline uint64_t getTimestamp() const { + return portability::net_to_host(timestamp); + } + inline void setTimestamp(uint64_t time) { + timestamp = portability::host_to_net(time); + } - inline uint32_t getProductionRate() const { return ntohl(prod_rate); } - inline void setProductionRate(uint32_t rate) { prod_rate = htonl(rate); } + inline uint32_t getProductionRate() const { + return portability::net_to_host(prod_rate); + } + inline void setProductionRate(uint32_t rate) { + prod_rate = portability::host_to_net(rate); + } }; struct nack_packet_t { @@ -99,14 +91,26 @@ struct nack_packet_t { uint32_t prod_rate; uint32_t prod_seg; - inline uint64_t getTimestamp() const { return _ntohll(×tamp); } - inline void setTimestamp(uint64_t time) { timestamp = _htonll(&time); } + inline uint64_t getTimestamp() const { + return portability::net_to_host(timestamp); + } + inline void setTimestamp(uint64_t time) { + timestamp = portability::host_to_net(time); + } - inline uint32_t getProductionRate() const { return ntohl(prod_rate); } - inline void setProductionRate(uint32_t rate) { prod_rate = htonl(rate); } + inline uint32_t getProductionRate() const { + return portability::net_to_host(prod_rate); + } + inline void setProductionRate(uint32_t rate) { + prod_rate = portability::host_to_net(rate); + } - inline uint32_t getProductionSegment() const { return ntohl(prod_seg); } - inline void setProductionSegment(uint32_t seg) { prod_seg = htonl(seg); } + inline uint32_t getProductionSegment() const { + return portability::net_to_host(prod_seg); + } + inline void setProductionSegment(uint32_t seg) { + prod_seg = portability::host_to_net(seg); + } }; class AggrPktHeader { @@ -225,7 +229,7 @@ class AggrPktHeader { return (uint16_t) * (buf_ + pkt_index); } else { // 16 bits uint16_t *buf_16 = (uint16_t *)buf_; - return ntohs(*(buf_16 + pkt_index)); + return portability::net_to_host(*(buf_16 + pkt_index)); } } @@ -235,7 +239,7 @@ class AggrPktHeader { *(buf_ + pkt_index) = (uint8_t)len; } else { // 16 bits uint16_t *buf_16 = (uint16_t *)buf_; - *(buf_16 + pkt_index) = htons(len); + *(buf_16 + pkt_index) = portability::host_to_net(len); } } diff --git a/libtransport/src/protocols/rtc/rtc_reassembly.cc b/libtransport/src/protocols/rtc/rtc_reassembly.cc index 992bab50e..b1b0fcaba 100644 --- a/libtransport/src/protocols/rtc/rtc_reassembly.cc +++ b/libtransport/src/protocols/rtc/rtc_reassembly.cc @@ -40,7 +40,7 @@ void RtcReassembly::reassemble(core::ContentObject& content_object) { auto read_buffer = content_object.getPayload(); DLOG_IF(INFO, VLOG_IS_ON(3)) << "Size of payload: " << read_buffer->length(); - read_buffer->trimStart(transport_protocol_->transportHeaderLength()); + read_buffer->trimStart(transport_protocol_->transportHeaderLength(false)); if (data_aggregation_) { rtc::AggrPktHeader hdr((uint8_t*)read_buffer->data()); diff --git a/libtransport/src/protocols/rtc/rtc_recovery_strategy.cc b/libtransport/src/protocols/rtc/rtc_recovery_strategy.cc index 66ae5086c..257fdd09b 100644 --- a/libtransport/src/protocols/rtc/rtc_recovery_strategy.cc +++ b/libtransport/src/protocols/rtc/rtc_recovery_strategy.cc @@ -29,8 +29,12 @@ using namespace transport::interface; RecoveryStrategy::RecoveryStrategy( Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, - bool use_rtx, bool use_fec, interface::StrategyCallback &&external_callback) - : recovery_on_(false), + bool use_rtx, bool use_fec, + interface::RtcTransportRecoveryStrategies rs_type, + interface::StrategyCallback &&external_callback) + : rs_type_(rs_type), + recovery_on_(false), + content_sharing_mode_(false), rtx_during_fec_(0), next_rtx_timer_(MAX_TIMER_RTX), send_rtx_callback_(std::move(callback)), @@ -43,7 +47,9 @@ RecoveryStrategy::RecoveryStrategy( } RecoveryStrategy::RecoveryStrategy(RecoveryStrategy &&rs) - : rtx_during_fec_(0), + : rs_type_(rs.rs_type_), + content_sharing_mode_(rs.content_sharing_mode_), + rtx_during_fec_(0), rtx_state_(std::move(rs.rtx_state_)), rtx_timers_(std::move(rs.rtx_timers_)), recover_with_fec_(std::move(rs.recover_with_fec_)), @@ -64,25 +70,52 @@ RecoveryStrategy::RecoveryStrategy(RecoveryStrategy &&rs) RecoveryStrategy::~RecoveryStrategy() {} void RecoveryStrategy::setFecParams(uint32_t n, uint32_t k) { + // if rs_type == FEC_ONLY_LOW_RES_LOSSES max k == 64 n_ = n; k_ = k; // XXX for the moment we go in steps of 5% loss rate. - // max loss rate = 95% + uint32_t i = 0; for (uint32_t loss_rate = 5; loss_rate < 100; loss_rate += 5) { - double dec_loss_rate = (double)(loss_rate + 5) / 100.0; - double exp_losses = (double)k_ * dec_loss_rate; - uint32_t fec_to_ask = ceil(exp_losses / (1 - dec_loss_rate)); - - fec_state_ f; - f.fec_to_ask = std::min(fec_to_ask, (n_ - k_)); - f.last_update = round_id_; - f.avg_residual_losses = 0.0; - f.consecutive_use = 0; - fec_per_loss_rate_.push_back(f); + uint32_t fec_to_ask = 0; + if (n_ != 0 && k_ != 0) { + if (rs_type_ == + interface::RtcTransportRecoveryStrategies::FEC_ONLY_LOW_RES_LOSSES) { + // the max loss rate in the matrix is 50% + uint32_t index = i; + if (i > 9) index = 9; + fec_to_ask = FEC_MATRIX[k_ - 1][index]; + } else { + double dec_loss_rate = (double)(loss_rate + 5); + if (dec_loss_rate == 100.0) dec_loss_rate = 95.0; + dec_loss_rate = dec_loss_rate / 100.0; + double exp_losses = ceil((double)k_ * dec_loss_rate); + fec_to_ask = ceil((exp_losses / (1 - dec_loss_rate)) * 1.25); + } + } + fec_to_ask = std::min(fec_to_ask, (n_ - k_)); + fec_per_loss_rate_.push_back(fec_to_ask); + + i++; } } +uint64_t RecoveryStrategy::getRtxRtt(uint32_t seq) { + auto it = rtx_state_.find(seq); + + if (it == rtx_state_.end()) return 0; + + // we can compute the RTT of an RTX only if it was send once. Infact if the + // RTX was sent twice or more the data may be alredy in flight and the RTT + // will be underestimated. This may happen also for packets that we + // retransmitted too soon. in that case the RTT will be filtered out by + // checking the path label + if (it->second.rtx_count_ != 1) return 0; + + // this a potentialy valid packet, compute the RTT + return (utils::SteadyTime::nowMs().count() - it->second.last_send_); +} + bool RecoveryStrategy::lossDetected(uint32_t seq) { if (isRtx(seq)) { // this packet is already in the list of rtx @@ -141,8 +174,10 @@ void RecoveryStrategy::addNewRtx(uint32_t seq, bool force) { state.first_send_ = state_->getInterestSentTime(seq); if (state.first_send_ == 0) // this interest was never sent before state.first_send_ = getNow(); - state.next_send_ = computeNextSend(seq, true); + state.last_send_ = state.first_send_; // we didn't send an RTX for this + // packet yet state.rtx_count_ = 0; + state.next_send_ = computeNextSend(seq, state.rtx_count_); DLOG_IF(INFO, VLOG_IS_ON(4)) << "Add " << seq << " to retransmissions. next rtx is in " << state.next_send_ - getNow() << " ms"; @@ -158,66 +193,50 @@ void RecoveryStrategy::addNewRtx(uint32_t seq, bool force) { } } -uint64_t RecoveryStrategy::computeNextSend(uint32_t seq, bool new_rtx) { +uint64_t RecoveryStrategy::computeNextSend(uint32_t seq, uint32_t rtx_counter) { uint64_t now = getNow(); - if (new_rtx) { - // for the new rtx we wait one estimated IAT after the loss detection. this - // is bacause, assuming that packets arrive with a constant IAT, we should - // get a new packet every IAT - double prod_rate = state_->getProducerRate(); - uint32_t estimated_iat = SENTINEL_TIMER_INTERVAL; - uint32_t jitter = 0; + if (rtx_counter == 0) { + uint32_t wait = 1; + if (content_sharing_mode_) return now + wait; - if (prod_rate != 0) { - double packet_size = state_->getAveragePacketSize(); - estimated_iat = ceil(1000.0 / (prod_rate / packet_size)); - jitter = ceil(state_->getJitter()); - } + uint32_t jitter = SENTINEL_TIMER_INTERVAL; + double prod_rate = state_->getProducerRate(); + if (prod_rate != 0) jitter = ceil(state_->getJitter()); - uint32_t wait = 1; - if (estimated_iat < 18) { - // for low rate app we do not wait to send a RTX - // we consider low rate stream with less than 50pps (iat >= 20ms) - // (e.g. audio in videoconf, mobile games). - // in the check we use 18ms to accomodate for measurements errors - // for flows with higher rate wait 1 ait + jitter - wait = estimated_iat + jitter; - } + wait += jitter; - DLOG_IF(INFO, VLOG_IS_ON(3)) - << "first rtx for " << seq << " in " << wait - << " ms, rtt = " << state_->getMinRTT() << " ait = " << estimated_iat - << " jttr = " << jitter; + DLOG_IF(INFO, VLOG_IS_ON(3)) << "first rtx for " << seq << " in " << wait + << " ms, jitter = " << jitter; return now + wait; } else { - // wait one RTT - uint32_t wait = SENTINEL_TIMER_INTERVAL; - + // wait one RTT. if an edge is known use the edge RTT for the first 5 rtx double prod_rate = state_->getProducerRate(); if (prod_rate == 0) { return now + SENTINEL_TIMER_INTERVAL; } - double packet_size = state_->getAveragePacketSize(); - uint32_t estimated_iat = ceil(1000.0 / (prod_rate / packet_size)); + uint64_t rtt = 0; + // if the transport detects an edge we try first to get the RTX from the + // edge. if no interest get a reply we move to the full RTT + if (rtx_counter < 5 && (state_->getEdgeRtt() != 0)) { + rtt = state_->getEdgeRtt(); + } else { + rtt = state_->getAvgRTT(); + } - uint64_t rtt = state_->getMinRTT(); if (rtt == 0) rtt = SENTINEL_TIMER_INTERVAL; - wait = rtt; + + if (content_sharing_mode_) return now + rtt; + + uint32_t wait = (uint32_t)rtt; uint32_t jitter = ceil(state_->getJitter()); wait += jitter; - // it may happen that the channel is congested and we have some additional - // queuing delay to take into account - uint32_t queue = ceil(state_->getQueuing()); - wait += queue; - DLOG_IF(INFO, VLOG_IS_ON(3)) - << "next rtx for " << seq << " in " << wait - << " ms, rtt = " << state_->getMinRTT() << " ait = " << estimated_iat - << " jttr = " << jitter << " queue = " << queue; + << "next rtx for " << seq << " in " << wait << " ms, rtt = " << rtt + << " jtter = " << jitter; return now + wait; } @@ -252,7 +271,9 @@ void RecoveryStrategy::retransmit() { state_->onRetransmission(seq); double prod_rate = state_->getProducerRate(); if (prod_rate != 0) rtx_it->second.rtx_count_++; - rtx_it->second.next_send_ = computeNextSend(seq, false); + rtx_it->second.last_send_ = now; + rtx_it->second.next_send_ = + computeNextSend(seq, rtx_it->second.rtx_count_); it = rtx_timers_.erase(it); rtx_timers_.insert( std::pair<uint64_t, uint32_t>(rtx_it->second.next_send_, seq)); @@ -327,6 +348,7 @@ void RecoveryStrategy::deleteRtx(uint32_t seq) { } it_timers++; } + // remove rtx rtx_state_.erase(it_rtx); } @@ -339,53 +361,13 @@ uint32_t RecoveryStrategy::computeFecPacketsToAsk() { if (loss_rate == 0) return 0; - // once per minute try to reduce the fec rate. it may happen that for some bin - // we ask too many fec packet. here we try to reduce this values gently - if (round_id_ % ROUNDS_PER_MIN == 0) { - reduceFec(); - } - // keep track of the last used fec. if we use a new bin on this round reset // consecutive use and avg loss in the prev bin uint32_t bin = ceil(loss_rate / 5.0) - 1; - if (bin > fec_per_loss_rate_.size() - 1) bin = fec_per_loss_rate_.size() - 1; + if (bin > fec_per_loss_rate_.size() - 1) + bin = (uint32_t)fec_per_loss_rate_.size() - 1; - if (bin != last_fec_used_) { - fec_per_loss_rate_[last_fec_used_].consecutive_use = 0; - fec_per_loss_rate_[last_fec_used_].avg_residual_losses = 0.0; - } - last_fec_used_ = bin; - fec_per_loss_rate_[last_fec_used_].consecutive_use++; - - // we update the stats only once very 5 rounds (1sec) that is the rate at - // which we compute residual losses - if (round_id_ % ROUNDS_PER_SEC == 0) { - double residual_losses = state_->getResidualLossRate() * 100; - // update residual loss rate - fec_per_loss_rate_[bin].avg_residual_losses = - (fec_per_loss_rate_[bin].avg_residual_losses * MOVING_AVG_ALPHA) + - (1 - MOVING_AVG_ALPHA) * residual_losses; - - if ((fec_per_loss_rate_[bin].last_update - round_id_) < - WAIT_BEFORE_FEC_UPDATE) { - // this bin is been updated recently so don't modify it and - // return the current state - return fec_per_loss_rate_[bin].fec_to_ask; - } - - // if the residual loss rate is too high and we can ask more fec packets and - // we are using this configuration since at least 5 sec update fec - if (fec_per_loss_rate_[bin].avg_residual_losses > MAX_RESIDUAL_LOSS_RATE && - fec_per_loss_rate_[bin].fec_to_ask < (n_ - k_) && - fec_per_loss_rate_[bin].consecutive_use > WAIT_BEFORE_FEC_UPDATE) { - // so increase the number of fec packets to ask - fec_per_loss_rate_[bin].fec_to_ask++; - fec_per_loss_rate_[bin].last_update = round_id_; - fec_per_loss_rate_[bin].avg_residual_losses = 0.0; - } - } - - return fec_per_loss_rate_[bin].fec_to_ask; + return fec_per_loss_rate_[bin]; } void RecoveryStrategy::setRtxFec(std::optional<bool> rtx_on, @@ -431,21 +413,6 @@ void RecoveryStrategy::removePacketState(uint32_t seq) { deleteRtx(seq); } -// private methods - -void RecoveryStrategy::reduceFec() { - for (uint32_t loss_rate = 5; loss_rate < 100; loss_rate += 5) { - double dec_loss_rate = (double)loss_rate / 100.0; - double exp_losses = (double)k_ * dec_loss_rate; - uint32_t fec_to_ask = ceil(exp_losses / (1 - dec_loss_rate)); - - uint32_t bin = ceil(loss_rate / 5.0) - 1; - if (fec_per_loss_rate_[bin].fec_to_ask > fec_to_ask) { - fec_per_loss_rate_[bin].fec_to_ask--; - } - } -} - } // end namespace rtc } // end namespace protocol diff --git a/libtransport/src/protocols/rtc/rtc_recovery_strategy.h b/libtransport/src/protocols/rtc/rtc_recovery_strategy.h index 482aedc9d..aceb85888 100644 --- a/libtransport/src/protocols/rtc/rtc_recovery_strategy.h +++ b/libtransport/src/protocols/rtc/rtc_recovery_strategy.h @@ -32,9 +32,10 @@ namespace rtc { class RecoveryStrategy : public std::enable_shared_from_this<RecoveryStrategy> { protected: struct rtx_state_ { - uint64_t first_send_; - uint64_t next_send_; - uint32_t rtx_count_; + uint64_t first_send_; // first time this interest was sent + uint64_t last_send_; // last time this rtx was sent + uint64_t next_send_; // next retransmission time + uint32_t rtx_count_; // number or rtx }; using rtxState = struct rtx_state_; @@ -44,6 +45,7 @@ class RecoveryStrategy : public std::enable_shared_from_this<RecoveryStrategy> { RecoveryStrategy(Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, bool use_rtx, bool use_fec, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback); RecoveryStrategy(RecoveryStrategy &&rs); @@ -55,6 +57,7 @@ class RecoveryStrategy : public std::enable_shared_from_this<RecoveryStrategy> { void setState(RTCState *state) { state_ = state; } void setRateControl(RTCRateControl *rateControl) { rc_ = rateControl; } void setFecParams(uint32_t n, uint32_t k); + void setContentSharingMode() { content_sharing_mode_ = true; } bool isRtx(uint32_t seq) { if (rtx_state_.find(seq) != rtx_state_.end()) return true; @@ -71,10 +74,20 @@ class RecoveryStrategy : public std::enable_shared_from_this<RecoveryStrategy> { return false; } + interface::RtcTransportRecoveryStrategies getType() { + return rs_type_; + } + void updateType(interface::RtcTransportRecoveryStrategies type) { + rs_type_ = type; + } bool isRtxOn() { return rtx_on_; } bool isFecOn() { return fec_on_; } RTCState *getState() { return state_; } + + // if the function returns 0 it means that the packet is not an RTX or it is + // not a valid packet to safely compute the RTT + uint64_t getRtxRtt(uint32_t seq); bool lossDetected(uint32_t seq); void notifyNewLossDetedcted(uint32_t seq); void requestPossibleLostPacket(uint32_t seq); @@ -98,7 +111,7 @@ class RecoveryStrategy : public std::enable_shared_from_this<RecoveryStrategy> { protected: // rtx functions void addNewRtx(uint32_t seq, bool force); - uint64_t computeNextSend(uint32_t seq, bool new_rtx); + uint64_t computeNextSend(uint32_t seq, uint32_t rtx_counter); void retransmit(); void scheduleNextRtx(); void deleteRtx(uint32_t seq); @@ -109,9 +122,11 @@ class RecoveryStrategy : public std::enable_shared_from_this<RecoveryStrategy> { // common functons void removePacketState(uint32_t seq); + interface::RtcTransportRecoveryStrategies rs_type_; bool recovery_on_; bool rtx_on_; bool fec_on_; + bool content_sharing_mode_; // number of RTX sent after fec turned on // this is used to take into account jitter and out of order packets @@ -152,19 +167,9 @@ class RecoveryStrategy : public std::enable_shared_from_this<RecoveryStrategy> { RTCRateControl *rc_; private: - struct fec_state_ { - uint32_t fec_to_ask; - uint32_t last_update; // round id of the last update - // (wait 10 ruonds (2sec) between updates) - uint32_t consecutive_use; // consecutive ruonds where this fec was used - double avg_residual_losses; - }; - - void reduceFec(); - uint32_t round_id_; // number of rounds uint32_t last_fec_used_; - std::vector<fec_state_> fec_per_loss_rate_; + std::vector<uint32_t> fec_per_loss_rate_; interface::StrategyCallback callback_; }; diff --git a/libtransport/src/protocols/rtc/rtc_rs_delay.cc b/libtransport/src/protocols/rtc/rtc_rs_delay.cc index 4be751ec9..7d7a01133 100644 --- a/libtransport/src/protocols/rtc/rtc_rs_delay.cc +++ b/libtransport/src/protocols/rtc/rtc_rs_delay.cc @@ -25,8 +25,10 @@ namespace rtc { RecoveryStrategyDelayBased::RecoveryStrategyDelayBased( Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback) : RecoveryStrategy(indexer, std::move(callback), io_service, true, false, + rs_type, std::move(external_callback)), // start with rtx congestion_state_(false), probing_state_(false), @@ -48,7 +50,7 @@ void RecoveryStrategyDelayBased::turnOnRecovery() { recovery_on_ = true; uint64_t rtt = state_->getMinRTT(); uint32_t fec_to_ask = computeFecPacketsToAsk(); - if (rtt > 80 && fec_to_ask != 0) { + if (rtt > MAX_RTT_BEFORE_FEC && fec_to_ask > 0) { // we need to start FEC (see fec only strategy for more details) setRtxFec(true, true); rtx_during_fec_ = 1; // avoid to stop fec @@ -84,16 +86,16 @@ void RecoveryStrategyDelayBased::onNewRound(bool in_sync) { return; } - uint64_t rtt = state_->getMinRTT(); + uint64_t rtt = state_->getAvgRTT(); - bool congestion = false; // XXX at the moment we are not looking at congestion events - // congestion = rc_->inCongestionState(); + // bool congestion = rc_->inCongestionState(); - if ((!fec_on_ && rtt >= 100) || (fec_on_ && rtt > 80) || congestion) { + if ((!fec_on_ && rtt >= MAX_RTT_BEFORE_FEC) || + (fec_on_ && rtt > (MAX_RTT_BEFORE_FEC - 10))) { // switch from rtx to fec or keep use fec. Notice that if some rtx are // waiting to be scheduled, they will be sent normally, but no new rtx will - // be created If the loss rate is 0 keep to use RTX. + // be created if the loss rate is 0 keep to use RTX. uint32_t fec_to_ask = computeFecPacketsToAsk(); softSwitchToFec(fec_to_ask); if (rtx_during_fec_ == 0) // if we do not send any RTX the losses @@ -104,7 +106,8 @@ void RecoveryStrategyDelayBased::onNewRound(bool in_sync) { return; } - if ((fec_on_ && rtt <= 80) || (!rtx_on_ && rtt <= 100)) { + if ((fec_on_ && rtt <= (MAX_RTT_BEFORE_FEC - 10)) || + (!rtx_on_ && rtt <= MAX_RTT_BEFORE_FEC)) { // turn on rtx softSwitchToFec(0); indexer_->setNFec(0); diff --git a/libtransport/src/protocols/rtc/rtc_rs_delay.h b/libtransport/src/protocols/rtc/rtc_rs_delay.h index 5ca90f4cb..9e1c41388 100644 --- a/libtransport/src/protocols/rtc/rtc_rs_delay.h +++ b/libtransport/src/protocols/rtc/rtc_rs_delay.h @@ -26,6 +26,7 @@ class RecoveryStrategyDelayBased : public RecoveryStrategy { public: RecoveryStrategyDelayBased(Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback); RecoveryStrategyDelayBased(RecoveryStrategy &&rs); diff --git a/libtransport/src/protocols/rtc/rtc_rs_fec_only.cc b/libtransport/src/protocols/rtc/rtc_rs_fec_only.cc index c44212bda..5b10823ec 100644 --- a/libtransport/src/protocols/rtc/rtc_rs_fec_only.cc +++ b/libtransport/src/protocols/rtc/rtc_rs_fec_only.cc @@ -25,9 +25,10 @@ namespace rtc { RecoveryStrategyFecOnly::RecoveryStrategyFecOnly( Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback) : RecoveryStrategy(indexer, std::move(callback), io_service, true, false, - std::move(external_callback)), + rs_type, std::move(external_callback)), congestion_state_(false), probing_state_(false), switch_rounds_(0) {} diff --git a/libtransport/src/protocols/rtc/rtc_rs_fec_only.h b/libtransport/src/protocols/rtc/rtc_rs_fec_only.h index 1ab78b842..42df25bd9 100644 --- a/libtransport/src/protocols/rtc/rtc_rs_fec_only.h +++ b/libtransport/src/protocols/rtc/rtc_rs_fec_only.h @@ -26,6 +26,7 @@ class RecoveryStrategyFecOnly : public RecoveryStrategy { public: RecoveryStrategyFecOnly(Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback); RecoveryStrategyFecOnly(RecoveryStrategy &&rs); diff --git a/libtransport/src/protocols/rtc/rtc_rs_low_rate.cc b/libtransport/src/protocols/rtc/rtc_rs_low_rate.cc index 48dd3e34f..dbad563cd 100644 --- a/libtransport/src/protocols/rtc/rtc_rs_low_rate.cc +++ b/libtransport/src/protocols/rtc/rtc_rs_low_rate.cc @@ -25,8 +25,10 @@ namespace rtc { RecoveryStrategyLowRate::RecoveryStrategyLowRate( Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback) : RecoveryStrategy(indexer, std::move(callback), io_service, false, true, + rs_type, std::move(external_callback)), // start with fec fec_consecutive_rounds_((MILLI_IN_A_SEC / ROUND_LEN) * 5), // 5 sec rtx_allowed_consecutive_rounds_(0) { @@ -75,7 +77,7 @@ void RecoveryStrategyLowRate::selectRecoveryStrategy(bool in_sync) { } uint32_t loss_rate = std::round(state_->getPerSecondLossRate() * 100); - uint32_t rtt = state_->getAvgRTT(); + uint32_t rtt = (uint32_t)state_->getAvgRTT(); bool use_rtx = false; for (size_t i = 0; i < switch_vector.size(); i++) { diff --git a/libtransport/src/protocols/rtc/rtc_rs_low_rate.h b/libtransport/src/protocols/rtc/rtc_rs_low_rate.h index d66b197e2..0e76efaca 100644 --- a/libtransport/src/protocols/rtc/rtc_rs_low_rate.h +++ b/libtransport/src/protocols/rtc/rtc_rs_low_rate.h @@ -34,6 +34,7 @@ class RecoveryStrategyLowRate : public RecoveryStrategy { public: RecoveryStrategyLowRate(Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback); RecoveryStrategyLowRate(RecoveryStrategy &&rs); diff --git a/libtransport/src/protocols/rtc/rtc_rs_recovery_off.cc b/libtransport/src/protocols/rtc/rtc_rs_recovery_off.cc index 16b14eff6..00c6a0504 100644 --- a/libtransport/src/protocols/rtc/rtc_rs_recovery_off.cc +++ b/libtransport/src/protocols/rtc/rtc_rs_recovery_off.cc @@ -25,9 +25,10 @@ namespace rtc { RecoveryStrategyRecoveryOff::RecoveryStrategyRecoveryOff( Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback) : RecoveryStrategy(indexer, std::move(callback), io_service, false, false, - std::move(external_callback)) {} + rs_type, std::move(external_callback)) {} RecoveryStrategyRecoveryOff::RecoveryStrategyRecoveryOff(RecoveryStrategy &&rs) : RecoveryStrategy(std::move(rs)) { diff --git a/libtransport/src/protocols/rtc/rtc_rs_recovery_off.h b/libtransport/src/protocols/rtc/rtc_rs_recovery_off.h index 3a9e71e7d..3d59cc473 100644 --- a/libtransport/src/protocols/rtc/rtc_rs_recovery_off.h +++ b/libtransport/src/protocols/rtc/rtc_rs_recovery_off.h @@ -26,6 +26,7 @@ class RecoveryStrategyRecoveryOff : public RecoveryStrategy { public: RecoveryStrategyRecoveryOff(Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback); RecoveryStrategyRecoveryOff(RecoveryStrategy &&rs); diff --git a/libtransport/src/protocols/rtc/rtc_rs_rtx_only.cc b/libtransport/src/protocols/rtc/rtc_rs_rtx_only.cc index 8e5db5439..4d7cf7a82 100644 --- a/libtransport/src/protocols/rtc/rtc_rs_rtx_only.cc +++ b/libtransport/src/protocols/rtc/rtc_rs_rtx_only.cc @@ -25,9 +25,10 @@ namespace rtc { RecoveryStrategyRtxOnly::RecoveryStrategyRtxOnly( Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback) : RecoveryStrategy(indexer, std::move(callback), io_service, true, false, - std::move(external_callback)) {} + rs_type, std::move(external_callback)) {} RecoveryStrategyRtxOnly::RecoveryStrategyRtxOnly(RecoveryStrategy &&rs) : RecoveryStrategy(std::move(rs)) { diff --git a/libtransport/src/protocols/rtc/rtc_rs_rtx_only.h b/libtransport/src/protocols/rtc/rtc_rs_rtx_only.h index e90e5ba13..03dbed1c7 100644 --- a/libtransport/src/protocols/rtc/rtc_rs_rtx_only.h +++ b/libtransport/src/protocols/rtc/rtc_rs_rtx_only.h @@ -26,6 +26,7 @@ class RecoveryStrategyRtxOnly : public RecoveryStrategy { public: RecoveryStrategyRtxOnly(Indexer *indexer, SendRtxCallback &&callback, asio::io_service &io_service, + interface::RtcTransportRecoveryStrategies rs_type, interface::StrategyCallback &&external_callback); RecoveryStrategyRtxOnly(RecoveryStrategy &&rs); diff --git a/libtransport/src/protocols/rtc/rtc_state.cc b/libtransport/src/protocols/rtc/rtc_state.cc index 5b3b5e4c3..82ac0b9c1 100644 --- a/libtransport/src/protocols/rtc/rtc_state.cc +++ b/libtransport/src/protocols/rtc/rtc_state.cc @@ -106,6 +106,7 @@ void RTCState::initParams() { // paths stats path_table_.clear(); main_path_ = nullptr; + edge_path_ = nullptr; // packet cache (not pending anymore) packet_cache_.clear(); @@ -231,11 +232,9 @@ void RTCState::onDataPacketReceived(const core::ContentObject &content_object, } updatePacketSize(content_object); - updateReceivedBytes(content_object); + updateReceivedBytes(content_object, false); addRecvOrLost(seq, PacketState::RECEIVED); - if (seq > highest_seq_received_) highest_seq_received_ = seq; - // the producer is responding // it is generating valid data packets so we consider it active producer_is_active_ = true; @@ -245,11 +244,7 @@ void RTCState::onDataPacketReceived(const core::ContentObject &content_object, void RTCState::onFecPacketReceived(const core::ContentObject &content_object) { uint32_t seq = content_object.getName().getSuffix(); - // updateReceivedBytes(content_object); - received_fec_bytes_ += - (uint32_t)(content_object.headerSize() + content_object.payloadSize()); - - if (seq > highest_seq_received_) highest_seq_received_ = seq; + updateReceivedBytes(content_object, true); PacketState state = getPacketState(seq); if (state != PacketState::LOST) { @@ -328,12 +323,14 @@ void RTCState::onPacketLost(uint32_t seq) { DLOG_IF(INFO, VLOG_IS_ON(4)) << "packet " << seq << " is lost"; } } + addRecvOrLost(seq, PacketState::DEFINITELY_LOST); } -void RTCState::onPacketRecoveredRtx(uint32_t seq) { +void RTCState::onPacketRecoveredRtx(const core::ContentObject &content_object, + uint64_t rtt) { + uint32_t seq = content_object.getName().getSuffix(); packets_sent_to_app_++; - if (seq > highest_seq_received_) highest_seq_received_ = seq; // increase the recovered packet counter only if the packet was marked as LOST // before. @@ -341,13 +338,37 @@ void RTCState::onPacketRecoveredRtx(uint32_t seq) { if (state == PacketState::LOST) losses_recovered_++; addRecvOrLost(seq, PacketState::RECEIVED); + updateReceivedBytes(content_object, false); + + if (rtt == 0) return; // nothing to do + + uint32_t path_label = content_object.getPathLabel(); + auto path_it = path_table_.find(path_label); + if (path_it == path_table_.end()) { + // this is a new path and it must be a cache + std::shared_ptr<RTCDataPath> newPath = + std::make_shared<RTCDataPath>(path_label); + auto ret = path_table_.insert( + std::pair<uint32_t, std::shared_ptr<RTCDataPath>>(path_label, newPath)); + path_it = ret.first; + } + + auto path = path_it->second; + if (path->pathToProducer()) + return; // this packet is coming from a producer + // even if we sent an RTX. this may happen + // for RTX that are sent too fast or in + // case of multipath + + path->insertRttSample(utils::SteadyTime::Milliseconds(rtt), true); } -void RTCState::onFecPacketRecoveredRtx(uint32_t seq) { +void RTCState::onFecPacketRecoveredRtx( + const core::ContentObject &content_object) { // This is the same as onPacketRecoveredRtx, but in this is case the // pkt is also a FEC pkt, the addRecvOrLost will be called afterwards - if (seq > highest_seq_received_) highest_seq_received_ = seq; losses_recovered_++; + updateReceivedBytes(content_object, true); } void RTCState::onPacketRecoveredFec(uint32_t seq, uint32_t size) { @@ -355,8 +376,6 @@ void RTCState::onPacketRecoveredFec(uint32_t seq, uint32_t size) { packets_sent_to_app_++; recovered_bytes_with_fec_ += size; - if (seq > highest_seq_received_) highest_seq_received_ = seq; - // adding header to the count recovered_bytes_with_fec_ += 60; // XXX get header size some where @@ -487,21 +506,32 @@ void RTCState::onNewRound(double round_len, bool in_sync) { // channel losses uint32_t last_round_packets = 0; + uint64_t min_edge_rtt = UINT_MAX; std::shared_ptr<RTCDataPath> old_main_path = main_path_; main_path_ = nullptr; + edge_path_ = nullptr; for (auto it = path_table_.begin(); it != path_table_.end(); it++) { - if (it->second->isActive()) { + if (it->second->isValidProducer()) { uint32_t pkt = it->second->getPacketsLastRound(); if (pkt > last_round_packets) { last_round_packets = pkt; main_path_ = it->second; } + } else if (it->second->isActive() && !it->second->pathToProducer()) { + // this is a path to a cache from where we are receiving content + if (it->second->getMinRtt() < min_edge_rtt) { + min_edge_rtt = it->second->getMinRtt(); + edge_path_ = it->second; + } } it->second->roundEnd(); } if (main_path_ == nullptr) main_path_ = old_main_path; + if (edge_path_ == nullptr) edge_path_ = main_path_; + if (edge_path_->getMinRtt() >= main_path_->getMinRtt()) + edge_path_ = main_path_; // in case we get a new main path we reset the stats of the old one. this is // beacuse, in case we need to switch back we don't what to take decisions on @@ -551,9 +581,15 @@ void RTCState::onNewRound(double round_len, bool in_sync) { rounds_++; } -void RTCState::updateReceivedBytes(const core::ContentObject &content_object) { - received_bytes_ += - (uint32_t)(content_object.headerSize() + content_object.payloadSize()); +void RTCState::updateReceivedBytes(const core::ContentObject &content_object, + bool isFec) { + if (isFec) { + received_fec_bytes_ += + (uint32_t)(content_object.headerSize() + content_object.payloadSize()); + } else { + received_bytes_ += + (uint32_t)(content_object.headerSize() + content_object.payloadSize()); + } } void RTCState::updatePacketSize(const core::ContentObject &content_object) { @@ -703,6 +739,10 @@ void RTCState::dataToBeReceived(uint32_t seq) { addToPacketCache(seq, PacketState::TO_BE_RECEIVED); } +void RTCState::updateHighestSeqReceived(uint32_t seq) { + if (seq > highest_seq_received_) highest_seq_received_ = seq; +} + void RTCState::addRecvOrLost(uint32_t seq, PacketState state) { auto it = pending_interests_.find(seq); if (it != pending_interests_.end()) { @@ -803,7 +843,7 @@ core::ParamsRTC RTCState::getProbeParams(const core::ContentObject &probe) { switch (ProbeHandler::getProbeType(seq)) { case ProbeType::INIT: { core::ContentObjectManifest manifest( - const_cast<core::ContentObject &>(probe)); + const_cast<core::ContentObject &>(probe).shared_from_this()); manifest.decode(); params = manifest.getParamsRTC(); break; @@ -841,7 +881,7 @@ core::ParamsRTC RTCState::getDataParams(const core::ContentObject &data) { } case core::PayloadType::MANIFEST: { core::ContentObjectManifest manifest( - const_cast<core::ContentObject &>(data)); + const_cast<core::ContentObject &>(data).shared_from_this()); manifest.decode(); params = manifest.getParamsRTC(); break; diff --git a/libtransport/src/protocols/rtc/rtc_state.h b/libtransport/src/protocols/rtc/rtc_state.h index 4bd2f76a0..ac3cc621f 100644 --- a/libtransport/src/protocols/rtc/rtc_state.h +++ b/libtransport/src/protocols/rtc/rtc_state.h @@ -84,8 +84,9 @@ class RTCState : public std::enable_shared_from_this<RTCState> { void onNackPacketReceived(const core::ContentObject &nack, bool compute_stats); void onPacketLost(uint32_t seq); - void onPacketRecoveredRtx(uint32_t seq); - void onFecPacketRecoveredRtx(uint32_t seq); + void onPacketRecoveredRtx(const core::ContentObject &content_object, + uint64_t rtt); + void onFecPacketRecoveredRtx(const core::ContentObject &content_object); void onPacketRecoveredFec(uint32_t seq, uint32_t size); bool onProbePacketReceived(const core::ContentObject &probe); void onJumpForward(uint32_t next_seq); @@ -117,6 +118,11 @@ class RTCState : public std::enable_shared_from_this<RTCState> { return 0; } + uint64_t getEdgeRtt() const { + if (edge_path_ != nullptr) return edge_path_->getMinRtt(); + return 0; + } + void resetRttStats() { if (mainPathIsValid()) main_path_->clearRtt(); } @@ -149,7 +155,7 @@ class RTCState : public std::enable_shared_from_this<RTCState> { } uint32_t getPendingInterestNumber() const { - return pending_interests_.size(); + return (uint32_t)pending_interests_.size(); } PacketState getPacketState(uint32_t seq) { @@ -242,6 +248,8 @@ class RTCState : public std::enable_shared_from_this<RTCState> { // set it as TO_BE_RECEIVED. void dataToBeReceived(uint32_t seq); + void updateHighestSeqReceived(uint32_t seq); + // Extract RTC parameters from probes (init or RTT probes) and data packets. static core::ParamsRTC getProbeParams(const core::ContentObject &probe); static core::ParamsRTC getDataParams(const core::ContentObject &data); @@ -259,7 +267,8 @@ class RTCState : public std::enable_shared_from_this<RTCState> { // update stats void updateState(); - void updateReceivedBytes(const core::ContentObject &content_object); + void updateReceivedBytes(const core::ContentObject &content_object, + bool isFec); void updatePacketSize(const core::ContentObject &content_object); void updatePathStats(const core::ContentObject &content_object, bool is_nack); void updateLossRate(bool in_sycn); @@ -360,7 +369,12 @@ class RTCState : public std::enable_shared_from_this<RTCState> { // paths stats std::unordered_map<uint32_t, std::shared_ptr<RTCDataPath>> path_table_; - std::shared_ptr<RTCDataPath> main_path_; + std::shared_ptr<RTCDataPath> main_path_; // this is the path that connects + // the consumer to the producer. in + // case of multipath the trasnport + // uses the most active path + std::shared_ptr<RTCDataPath> edge_path_; // path to the closest cache if it + // exists // packet received // cache where to store info about the last MAX_CACHED_PACKETS diff --git a/libtransport/src/protocols/rtc/rtc_verifier.cc b/libtransport/src/protocols/rtc/rtc_verifier.cc index 7b6330a1f..861ceee89 100644 --- a/libtransport/src/protocols/rtc/rtc_verifier.cc +++ b/libtransport/src/protocols/rtc/rtc_verifier.cc @@ -22,11 +22,11 @@ namespace protocol { namespace rtc { RTCVerifier::RTCVerifier(std::shared_ptr<auth::Verifier> verifier, - uint32_t max_unverified_interval, - double max_unverified_ratio) + uint32_t factor_relevant, uint32_t factor_alert) : verifier_(verifier), - max_unverified_interval_(max_unverified_interval), - max_unverified_ratio_(max_unverified_ratio) {} + factor_relevant_(factor_relevant), + factor_alert_(factor_alert), + manifest_max_capacity_(std::numeric_limits<uint8_t>::max()) {} void RTCVerifier::setState(std::shared_ptr<RTCState> rtc_state) { rtc_state_ = rtc_state; @@ -36,12 +36,16 @@ void RTCVerifier::setVerifier(std::shared_ptr<auth::Verifier> verifier) { verifier_ = verifier; } -void RTCVerifier::setMaxUnverifiedInterval(uint32_t max_unverified_interval) { - max_unverified_interval_ = max_unverified_interval; +void RTCVerifier::setFactorRelevant(uint32_t factor_relevant) { + factor_relevant_ = factor_relevant; } -void RTCVerifier::setMaxUnverifiedRatio(double max_unverified_ratio) { - max_unverified_ratio_ = max_unverified_ratio; +void RTCVerifier::setFactorAlert(uint32_t factor_alert) { + factor_alert_ = factor_alert; +} + +auth::VerificationPolicy RTCVerifier::verify(core::Interest &interest) { + return verifier_->verifyPackets(&interest); } auth::VerificationPolicy RTCVerifier::verify( @@ -108,19 +112,27 @@ auth::VerificationPolicy RTCVerifier::verifyData( auth::Suffix suffix = content_object.getName().getSuffix(); auth::VerificationPolicy policy = auth::VerificationPolicy::ABORT; - Timestamp now = utils::SteadyTime::nowMs().count(); - // Flush old packets - Timestamp oldest = flush_packets(now); + uint32_t threshold_relevant = factor_relevant_ * manifest_max_capacity_; + uint32_t threshold_alert = factor_alert_ * manifest_max_capacity_; - // Add packet to map of unverified packets - packets_unverif_.add( - {.suffix = suffix, .timestamp = now, .size = content_object.length()}, - content_object.computeDigest(manifest_hash_algo_)); + // Flush packets outside relevance window + for (auto it = packets_unverif_.set().begin(); + it != packets_unverif_.set().end();) { + if (it->first > current_index_ - threshold_relevant) { + break; + } + packets_unverif_erased_.insert((unsigned int)it->first); + it = packets_unverif_.remove(it); + } + + // Add packet to set of unverified packets + packets_unverif_.add({current_index_, suffix}, + content_object.computeDigest(manifest_hash_algo_)); + current_index_++; - // Check that the ratio of unverified packets stays below the limit - if (now - oldest < max_unverified_interval_ || - getBufferRatio() < max_unverified_ratio_) { + // Check that the number of unverified packets is below the alert threshold + if (packets_unverif_.set().size() <= threshold_alert) { policy = auth::VerificationPolicy::ACCEPT; } @@ -139,18 +151,13 @@ auth::VerificationPolicy RTCVerifier::processManifest( auth::VerificationPolicy accept_policy = auth::VerificationPolicy::ACCEPT; // Decode manifest - core::ContentObjectManifest manifest(content_object); + core::ContentObjectManifest manifest(content_object.shared_from_this()); manifest.decode(); - // Update last manifest - if (suffix > last_manifest_) { - last_manifest_ = suffix; - } - - // Extract hash algorithm and hashes + // Extract manifest data + manifest_max_capacity_ = manifest.getMaxCapacity(); manifest_hash_algo_ = manifest.getHashAlgorithm(); - auth::Verifier::SuffixMap suffix_map = - core::ContentObjectManifest::getSuffixMap(&manifest); + auth::Verifier::SuffixMap suffix_map = manifest.getSuffixMap(); // Return early if the manifest is empty if (suffix_map.empty()) { @@ -186,10 +193,7 @@ auth::VerificationPolicy RTCVerifier::processManifest( for (const auto &p : policies) { switch (p.second) { case auth::VerificationPolicy::ACCEPT: { - auto packet_unverif_it = packets_unverif_.packetIt(p.first); - Packet packet_verif = *packet_unverif_it; - packets_unverif_.remove(packet_unverif_it); - packets_verif_.add(packet_verif); + packets_unverif_.remove(packets_unverif_.packet(p.first)); manifest_digests_.erase(p.first); break; } @@ -209,69 +213,20 @@ void RTCVerifier::onDataRecoveredFec(uint32_t suffix) { manifest_digests_.erase(suffix); } -void RTCVerifier::onJumpForward(uint32_t next_suffix) { - if (next_suffix <= last_manifest_ + 1) { - return; - } - - // When we jump forward in the suffix sequence, we remove packets that won't - // be verified. Those packets have a suffix in the range [last_manifest_ + 1, - // next_suffix[. - for (auth::Suffix suffix = last_manifest_ + 1; suffix < next_suffix; - ++suffix) { - auto packet_it = packets_unverif_.packetIt(suffix); - if (packet_it != packets_unverif_.set().end()) { - packets_unverif_.remove(packet_it); - } - } -} - -double RTCVerifier::getBufferRatio() const { - size_t total = packets_verif_.size() + packets_unverif_.size(); - double total_unverified = static_cast<double>(packets_unverif_.size()); - return total ? total_unverified / total : 0.0; -} - -RTCVerifier::Timestamp RTCVerifier::flush_packets(Timestamp now) { - Timestamp oldest_verified = packets_verif_.set().empty() - ? now - : packets_verif_.set().begin()->timestamp; - Timestamp oldest_unverified = packets_unverif_.set().empty() - ? now - : packets_unverif_.set().begin()->timestamp; - - // Prune verified packets older than the unverified interval - for (auto it = packets_verif_.set().begin(); - it != packets_verif_.set().end();) { - if (now - it->timestamp < max_unverified_interval_) { - break; - } - it = packets_verif_.remove(it); - } - - // Prune unverified packets older than the unverified interval - for (auto it = packets_unverif_.set().begin(); - it != packets_unverif_.set().end();) { - if (now - it->timestamp < max_unverified_interval_) { - break; - } - packets_unverif_erased_.insert(it->suffix); - it = packets_unverif_.remove(it); - } - - return std::min(oldest_verified, oldest_unverified); -} - std::pair<RTCVerifier::PacketSet::iterator, bool> RTCVerifier::Packets::add( - const Packet &packet) { + const Packet &packet, const auth::CryptoHash &digest) { auto inserted = packets_.insert(packet); - size_ += inserted.second ? packet.size : 0; + if (inserted.second) { + packets_map_[packet.second] = inserted.first; + suffix_map_[packet.second] = digest; + } return inserted; } RTCVerifier::PacketSet::iterator RTCVerifier::Packets::remove( PacketSet::iterator packet_it) { - size_ -= packet_it->size; + packets_map_.erase(packet_it->second); + suffix_map_.erase(packet_it->second); return packets_.erase(packet_it); } @@ -279,35 +234,13 @@ const std::set<RTCVerifier::Packet> &RTCVerifier::Packets::set() const { return packets_; }; -size_t RTCVerifier::Packets::size() const { return size_; }; - -std::pair<RTCVerifier::PacketSet::iterator, bool> -RTCVerifier::PacketsUnverif::add(const Packet &packet, - const auth::CryptoHash &digest) { - auto inserted = add(packet); - if (inserted.second) { - packets_map_[packet.suffix] = inserted.first; - digests_map_[packet.suffix] = digest; - } - return inserted; -} - -RTCVerifier::PacketSet::iterator RTCVerifier::PacketsUnverif::remove( - PacketSet::iterator packet_it) { - size_ -= packet_it->size; - packets_map_.erase(packet_it->suffix); - digests_map_.erase(packet_it->suffix); - return packets_.erase(packet_it); -} - -RTCVerifier::PacketSet::iterator RTCVerifier::PacketsUnverif::packetIt( +RTCVerifier::PacketSet::iterator RTCVerifier::Packets::packet( auth::Suffix suffix) { return packets_map_.at(suffix); }; -const auth::Verifier::SuffixMap &RTCVerifier::PacketsUnverif::suffixMap() - const { - return digests_map_; +const auth::Verifier::SuffixMap &RTCVerifier::Packets::suffixMap() const { + return suffix_map_; } } // end namespace rtc diff --git a/libtransport/src/protocols/rtc/rtc_verifier.h b/libtransport/src/protocols/rtc/rtc_verifier.h index 098984057..c83faf08a 100644 --- a/libtransport/src/protocols/rtc/rtc_verifier.h +++ b/libtransport/src/protocols/rtc/rtc_verifier.h @@ -27,19 +27,16 @@ namespace rtc { class RTCVerifier { public: explicit RTCVerifier(std::shared_ptr<auth::Verifier> verifier, - uint32_t max_unverified_interval, - double max_unverified_ratio); + uint32_t factor_relevant, uint32_t factor_alert); virtual ~RTCVerifier() = default; void setState(std::shared_ptr<RTCState> rtc_state); - void setVerifier(std::shared_ptr<auth::Verifier> verifier); + void setFactorRelevant(uint32_t factor_relevant); + void setFactorAlert(uint32_t factor_alert); - void setMaxUnverifiedInterval(uint32_t max_unverified_interval); - - void setMaxUnverifiedRatio(double max_unverified_ratio); - + auth::VerificationPolicy verify(core::Interest &interest); auth::VerificationPolicy verify(core::ContentObject &content_object, bool is_fec = false); auth::VerificationPolicy verifyProbe(core::ContentObject &content_object); @@ -51,81 +48,47 @@ class RTCVerifier { auth::VerificationPolicy processManifest(core::ContentObject &content_object); void onDataRecoveredFec(uint32_t suffix); - void onJumpForward(uint32_t next_suffix); - - double getBufferRatio() const; protected: - struct Packet; - using Timestamp = uint64_t; + using Index = uint64_t; + using Packet = std::pair<Index, auth::Suffix>; using PacketSet = std::set<Packet>; - struct Packet { - auth::Suffix suffix; - Timestamp timestamp; - size_t size; - - bool operator==(const Packet &b) const { - return timestamp == b.timestamp && suffix == b.suffix; - } - bool operator<(const Packet &b) const { - return timestamp == b.timestamp ? suffix < b.suffix - : timestamp < b.timestamp; - } - }; - class Packets { public: - virtual std::pair<PacketSet::iterator, bool> add(const Packet &packet); - virtual PacketSet::iterator remove(PacketSet::iterator packet_it); - const PacketSet &set() const; - size_t size() const; - - protected: - PacketSet packets_; - size_t size_; - }; - - class PacketsVerif : public Packets {}; - - class PacketsUnverif : public Packets { - public: - using Packets::add; std::pair<PacketSet::iterator, bool> add(const Packet &packet, const auth::CryptoHash &digest); - PacketSet::iterator remove(PacketSet::iterator packet_it) override; - PacketSet::iterator packetIt(auth::Suffix suffix); + PacketSet::iterator remove(PacketSet::iterator packet_it); + const PacketSet &set() const; + PacketSet::iterator packet(auth::Suffix suffix); const auth::Verifier::SuffixMap &suffixMap() const; private: + PacketSet packets_; std::unordered_map<auth::Suffix, PacketSet::iterator> packets_map_; - auth::Verifier::SuffixMap digests_map_; + auth::Verifier::SuffixMap suffix_map_; }; // The RTC state. std::shared_ptr<RTCState> rtc_state_; // The verifier instance. std::shared_ptr<auth::Verifier> verifier_; - // Window to consider when verifying packets. - uint32_t max_unverified_interval_; - // Ratio of unverified packets over which an alert is triggered. - double max_unverified_ratio_; - // The suffix of the last processed manifest. - auth::Suffix last_manifest_; + // Used to compute the relevance windows size (in packets). + uint32_t factor_relevant_; + // Used to compute the alert threshold (in packets). + uint32_t factor_alert_; + // The maximum number of entries a manifest can contain. + uint8_t manifest_max_capacity_; // Hash algorithm used by manifests. auth::CryptoHashType manifest_hash_algo_; // Digests extracted from all manifests received. auth::Verifier::SuffixMap manifest_digests_; - // Verified packets with timestamp >= now - max_unverified_interval_. - PacketsVerif packets_verif_; - // Unverified packets with timestamp >= now - max_unverified_interval_. - PacketsUnverif packets_unverif_; - // Unverified erased packets with timestamp < now - max_unverified_interval_. + // The number of data packets processed. + Index current_index_; + // Unverified packets with index in relevance window. + Packets packets_unverif_; + // Unverified erased packets with index outside relevance window. std::unordered_set<auth::Suffix> packets_unverif_erased_; - - // Flushes all packets with timestamp < now - max_unverified_interval_. - // Returns the timestamp of the oldest packet, verified or not. - Timestamp flush_packets(Timestamp now); }; } // namespace rtc diff --git a/libtransport/src/protocols/transport_protocol.cc b/libtransport/src/protocols/transport_protocol.cc index a73b9fb7b..b1803709b 100644 --- a/libtransport/src/protocols/transport_protocol.cc +++ b/libtransport/src/protocols/transport_protocol.cc @@ -79,6 +79,7 @@ int TransportProtocol::start() { &on_payload_); socket_->getSocketOption(GeneralTransportOptions::ASYNC_MODE, is_async_); + socket_->getSocketOption(GeneralTransportOptions::SIGNER, signer_); // Set it is the first time we schedule an interest is_first_ = true; @@ -143,14 +144,22 @@ void TransportProtocol::sendInterest( Packet::Format format; socket_->getSocketOption(interface::GeneralTransportOptions::PACKET_FORMAT, format); + size_t signature_size = 0; - auto interest = - core::PacketManager<>::getInstance().getPacket<Interest>(format); + // If aggregated interest, add spapce for signature + if (len > 0) { + format = Packet::toAHFormat(format); + signature_size = signer_->getSignatureFieldSize(); + } + + auto interest = core::PacketManager<>::getInstance().getPacket<Interest>( + format, signature_size); interest->setName(interest_name); for (uint32_t i = 0; i < len; i++) { interest->appendSuffix(additional_suffixes->at(i)); } + interest->encodeSuffixes(); uint32_t lifetime = default_values::interest_lifetime; socket_->getSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, @@ -165,7 +174,16 @@ void TransportProtocol::sendInterest( return; } - portal_->sendInterest(std::move(interest)); + bool content_sharing_mode; + socket_->getSocketOption(RtcTransportOptions::CONTENT_SHARING_MODE, + content_sharing_mode); + if (content_sharing_mode) lifetime = ceil((double)lifetime * 0.9); + + // Compute signature + bool is_ah = _is_ah(interest->getFormat()); + if (is_ah) signer_->signPacket(interest.get()); + + portal_->sendInterest(interest, lifetime); } void TransportProtocol::onError(const std::error_code &ec) { diff --git a/libtransport/src/protocols/transport_protocol.h b/libtransport/src/protocols/transport_protocol.h index ad8cf0346..e71992561 100644 --- a/libtransport/src/protocols/transport_protocol.h +++ b/libtransport/src/protocols/transport_protocol.h @@ -64,7 +64,7 @@ class TransportProtocol * * @return The header length in bytes. */ - virtual std::size_t transportHeaderLength() { return 0; } + virtual std::size_t transportHeaderLength(bool isFEC) { return 0; } virtual void scheduleNextInterests() = 0; @@ -141,6 +141,9 @@ class TransportProtocol bool is_async_; fec::FECType fec_type_; + + // Signer for aggregated interests + std::shared_ptr<auth::Signer> signer_; }; } // end namespace protocol diff --git a/libtransport/src/test/CMakeLists.txt b/libtransport/src/test/CMakeLists.txt index e7018ceed..b7f14766e 100644 --- a/libtransport/src/test/CMakeLists.txt +++ b/libtransport/src/test/CMakeLists.txt @@ -31,6 +31,8 @@ list(APPEND TESTS_SRC test_quality_score.cc test_sessions.cc test_thread_pool.cc + test_quadloop.cc + test_prefix.cc ) if (ENABLE_RELY) diff --git a/libtransport/src/test/test_core_manifest.cc b/libtransport/src/test/test_core_manifest.cc index b998ce96b..e3d66c1cd 100644 --- a/libtransport/src/test/test_core_manifest.cc +++ b/libtransport/src/test/test_core_manifest.cc @@ -13,8 +13,8 @@ * limitations under the License. */ +#include <core/manifest.h> #include <core/manifest_format_fixed.h> -#include <core/manifest_inline.h> #include <gtest/gtest.h> #include <hicn/transport/auth/crypto_hash.h> #include <hicn/transport/auth/signer.h> @@ -33,10 +33,12 @@ namespace { // The fixture for testing class Foo. class ManifestTest : public ::testing::Test { protected: - using ContentObjectManifest = ManifestInline<ContentObject, Fixed>; + using ContentObjectManifest = Manifest<Fixed>; - ManifestTest() : name_("b001::123|321"), manifest1_(HF_INET6_TCP_AH, name_) { - // You can do set-up work for each test here. + ManifestTest() + : format_(HF_INET6_TCP_AH), name_("b001::123|321"), signature_size_(0) { + manifest_ = ContentObjectManifest::createContentManifest(format_, name_, + signature_size_); } virtual ~ManifestTest() { @@ -56,10 +58,11 @@ class ManifestTest : public ::testing::Test { // before the destructor). } + Packet::Format format_; Name name_; - ContentObjectManifest manifest1_; - - std::vector<uint8_t> manifest_payload = { + std::size_t signature_size_; + std::shared_ptr<ContentObjectManifest> manifest_; + std::vector<uint8_t> manifest_payload_ = { 0x11, 0x11, 0x01, 0x00, 0xb0, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad // , 0x00, 0x00, // 0x00, 0x45, 0xa3, @@ -75,169 +78,200 @@ class ManifestTest : public ::testing::Test { } // namespace -TEST_F(ManifestTest, MoveConstructor) { +TEST_F(ManifestTest, ManifestConstructor) { // Create content object with manifest in payload - ContentObject co(HF_INET6_TCP_AH, 128); - co.appendPayload(&manifest_payload[0], manifest_payload.size()); - uint8_t buffer[256]; - co.appendPayload(buffer, 256); + ContentObject::Ptr co = + core::PacketManager<>::getInstance().getPacket<ContentObject>( + format_, signature_size_); + co->setName(name_); + co->appendPayload(manifest_payload_.data(), manifest_payload_.size()); + + uint8_t buffer[256] = {0}; + co->appendPayload(buffer, 256); // Copy packet payload uint8_t packet[1500]; - auto length = co.getPayload()->length(); - std::memcpy(packet, co.getPayload()->data(), length); + auto length = co->getPayload()->length(); + std::memcpy(packet, co->getPayload()->data(), length); // Create manifest - ContentObjectManifest m(std::move(co)); + ContentObjectManifest manifest(co); // Check manifest payload is exactly the same of content object - ASSERT_EQ(length, m.getPayload()->length()); - auto ret = std::memcmp(packet, m.getPayload()->data(), length); + ASSERT_EQ(length, manifest.getPacket()->getPayload()->length()); + auto ret = + std::memcmp(packet, manifest.getPacket()->getPayload()->data(), length); ASSERT_EQ(ret, 0); } -TEST_F(ManifestTest, SetLastManifest) { - manifest1_.clear(); - - manifest1_.setIsLast(true); - bool fcn = manifest1_.getIsLast(); - - ASSERT_TRUE(fcn == true); -} - TEST_F(ManifestTest, SetManifestType) { - manifest1_.clear(); + manifest_->Encoder::clear(); ManifestType type1 = ManifestType::INLINE_MANIFEST; ManifestType type2 = ManifestType::FLIC_MANIFEST; - manifest1_.setType(type1); - ManifestType type_returned1 = manifest1_.getType(); + manifest_->setType(type1); + ManifestType type_returned1 = manifest_->getType(); - manifest1_.clear(); + manifest_->Encoder::clear(); - manifest1_.setType(type2); - ManifestType type_returned2 = manifest1_.getType(); + manifest_->setType(type2); + ManifestType type_returned2 = manifest_->getType(); ASSERT_EQ(type1, type_returned1); ASSERT_EQ(type2, type_returned2); } +TEST_F(ManifestTest, SetMaxCapacity) { + manifest_->Encoder::clear(); + + uint8_t max_capacity1 = 0; + uint8_t max_capacity2 = 20; + + manifest_->setMaxCapacity(max_capacity1); + uint8_t max_capacity_returned1 = manifest_->getMaxCapacity(); + + manifest_->Encoder::clear(); + + manifest_->setMaxCapacity(max_capacity2); + uint8_t max_capacity_returned2 = manifest_->getMaxCapacity(); + + ASSERT_EQ(max_capacity1, max_capacity_returned1); + ASSERT_EQ(max_capacity2, max_capacity_returned2); +} + TEST_F(ManifestTest, SetHashAlgorithm) { - manifest1_.clear(); + manifest_->Encoder::clear(); - auth::CryptoHashType hash1 = auth::CryptoHashType::SHA512; - auth::CryptoHashType hash2 = auth::CryptoHashType::BLAKE2B512; - auth::CryptoHashType hash3 = auth::CryptoHashType::SHA256; + auth::CryptoHashType hash1 = auth::CryptoHashType::SHA256; + auth::CryptoHashType hash2 = auth::CryptoHashType::SHA512; + auth::CryptoHashType hash3 = auth::CryptoHashType::BLAKE2B512; - manifest1_.setHashAlgorithm(hash1); - auto type_returned1 = manifest1_.getHashAlgorithm(); + manifest_->setHashAlgorithm(hash1); + auto type_returned1 = manifest_->getHashAlgorithm(); - manifest1_.clear(); + manifest_->Encoder::clear(); - manifest1_.setHashAlgorithm(hash2); - auto type_returned2 = manifest1_.getHashAlgorithm(); + manifest_->setHashAlgorithm(hash2); + auto type_returned2 = manifest_->getHashAlgorithm(); - manifest1_.clear(); + manifest_->Encoder::clear(); - manifest1_.setHashAlgorithm(hash3); - auto type_returned3 = manifest1_.getHashAlgorithm(); + manifest_->setHashAlgorithm(hash3); + auto type_returned3 = manifest_->getHashAlgorithm(); ASSERT_EQ(hash1, type_returned1); ASSERT_EQ(hash2, type_returned2); ASSERT_EQ(hash3, type_returned3); } +TEST_F(ManifestTest, SetLastManifest) { + manifest_->Encoder::clear(); + + manifest_->setIsLast(true); + bool is_last = manifest_->getIsLast(); + + ASSERT_TRUE(is_last); +} + +TEST_F(ManifestTest, SetBaseName) { + manifest_->Encoder::clear(); + + core::Name base_name("b001::dead"); + + manifest_->setBaseName(base_name); + core::Name ret_name = manifest_->getBaseName(); + + ASSERT_EQ(base_name, ret_name); +} + TEST_F(ManifestTest, setParamsBytestream) { - manifest1_.clear(); + manifest_->Encoder::clear(); ParamsBytestream params{ - .final_segment = 1, + .final_segment = 0x0a, }; - manifest1_.setParamsBytestream(params); - manifest1_.encode(); + manifest_->setParamsBytestream(params); + auth::CryptoHash hash(auth::CryptoHashType::SHA256); + hash.computeDigest({0x01, 0x02, 0x03, 0x04}); + manifest_->addEntry(1, hash); + + manifest_->encode(); + manifest_->decode(); - ContentObjectManifest manifest(manifest1_); - manifest.decode(); + auto transport_type_returned = manifest_->getTransportType(); + auto params_returned = manifest_->getParamsBytestream(); ASSERT_EQ(interface::ProductionProtocolAlgorithms::BYTE_STREAM, - manifest.getTransportType()); - ASSERT_EQ(params, manifest.getParamsBytestream()); + transport_type_returned); + ASSERT_EQ(params, params_returned); } TEST_F(ManifestTest, SetParamsRTC) { - manifest1_.clear(); + manifest_->Encoder::clear(); ParamsRTC params{ - .timestamp = 1, - .prod_rate = 2, - .prod_seg = 3, + .timestamp = 0x0a, + .prod_rate = 0x0b, + .prod_seg = 0x0c, .fec_type = protocol::fec::FECType::UNKNOWN, }; - manifest1_.setParamsRTC(params); - manifest1_.encode(); + manifest_->setParamsRTC(params); + auth::CryptoHash hash(auth::CryptoHashType::SHA256); + hash.computeDigest({0x01, 0x02, 0x03, 0x04}); + manifest_->addEntry(1, hash); - ContentObjectManifest manifest(manifest1_); - manifest.decode(); + manifest_->encode(); + manifest_->decode(); + + auto transport_type_returned = manifest_->getTransportType(); + auto params_returned = manifest_->getParamsRTC(); ASSERT_EQ(interface::ProductionProtocolAlgorithms::RTC_PROD, - manifest.getTransportType()); - ASSERT_EQ(params, manifest.getParamsRTC()); + transport_type_returned); + ASSERT_EQ(params, params_returned); } TEST_F(ManifestTest, SignManifest) { - Name name("b001::", 0); auto signer = std::make_shared<auth::SymmetricSigner>( auth::CryptoSuite::HMAC_SHA256, "hunter2"); auto verifier = std::make_shared<auth::SymmetricVerifier>("hunter2"); - std::shared_ptr<ContentObjectManifest> manifest; - // Instantiate Manifest - manifest.reset(ContentObjectManifest::createManifest( - HF_INET6_TCP_AH, name, ManifestVersion::VERSION_1, - ManifestType::INLINE_MANIFEST, false, name, signer->getHashType(), - signer->getSignatureFieldSize())); + // Instantiate manifest + uint8_t max_capacity = 30; + std::shared_ptr<ContentObjectManifest> manifest = + ContentObjectManifest::createContentManifest( + format_, name_, signer->getSignatureFieldSize()); + manifest->setHeaders(ManifestType::INLINE_MANIFEST, max_capacity, + signer->getHashType(), false /* is_last */, name_); - // Add Manifest entry + // Add manifest entry auth::CryptoHash hash(signer->getHashType()); - hash.computeDigest(std::vector<uint8_t>{0x01, 0x02, 0x03, 0x04}); - manifest->addSuffixHash(1, hash); + hash.computeDigest({0x01, 0x02, 0x03, 0x04}); + manifest->addEntry(1, hash); // Encode manifest manifest->encode(); + auto manifest_co = + std::dynamic_pointer_cast<ContentObject>(manifest->getPacket()); // Sign manifest - signer->signPacket(manifest.get()); + signer->signPacket(manifest_co.get()); // Check size - ASSERT_EQ(manifest->payloadSize(), manifest->estimateManifestSize()); - ASSERT_EQ(manifest->length(), - manifest->headerSize() + manifest->payloadSize()); - ASSERT_EQ(ContentObjectManifest::manifestHeaderSize( - interface::ProductionProtocolAlgorithms::UNKNOWN), - manifest->manifestHeaderSize()); + ASSERT_EQ(manifest_co->payloadSize(), manifest->Encoder::manifestSize()); + ASSERT_EQ(manifest_co->length(), + manifest_co->headerSize() + manifest_co->payloadSize()); // Verify manifest - auth::VerificationPolicy policy = verifier->verifyPackets(manifest.get()); + auth::VerificationPolicy policy = verifier->verifyPackets(manifest_co.get()); ASSERT_EQ(auth::VerificationPolicy::ACCEPT, policy); } -TEST_F(ManifestTest, SetBaseName) { - manifest1_.clear(); - - core::Name base_name("b001::dead"); - manifest1_.setBaseName(base_name); - core::Name ret_name = manifest1_.getBaseName(); - - ASSERT_EQ(base_name, ret_name); -} - TEST_F(ManifestTest, SetSuffixList) { - manifest1_.clear(); - - core::Name base_name("b001::dead"); + manifest_->Encoder::clear(); using random_bytes_engine = std::independent_bits_engine<std::default_random_engine, CHAR_BIT, @@ -259,12 +293,13 @@ TEST_F(ManifestTest, SetSuffixList) { entries[i] = std::make_pair(suffixes[i], auth::CryptoHash(data[i].data(), data[i].size(), auth::CryptoHashType::SHA256)); - manifest1_.addSuffixHash(entries[i].first, entries[i].second); + manifest_->addEntry(entries[i].first, entries[i].second); } - manifest1_.setBaseName(base_name); - core::Name ret_name = manifest1_.getBaseName(); + core::Name base_name("b001::dead"); + manifest_->setBaseName(base_name); + core::Name ret_name = manifest_->getBaseName(); ASSERT_EQ(base_name, ret_name); delete[] entries; diff --git a/libtransport/src/test/test_interest.cc b/libtransport/src/test/test_interest.cc index d9c535881..e36ca0f93 100644 --- a/libtransport/src/test/test_interest.cc +++ b/libtransport/src/test/test_interest.cc @@ -258,5 +258,44 @@ TEST_F(InterestTest, AppendSuffixesEncodeAndIterate) { } } +TEST_F(InterestTest, AppendSuffixesWithGaps) { + // Create interest from buffer + Interest interest(HF_INET6_TCP); + + // Appenad some suffixes, out of order and with gaps + interest.appendSuffix(6); + interest.appendSuffix(2); + interest.appendSuffix(5); + interest.appendSuffix(1); + + // Encode them in wire format + interest.encodeSuffixes(); + EXPECT_TRUE(interest.hasManifest()); + + // Check first suffix correctness + auto suffix = interest.firstSuffix(); + EXPECT_NE(suffix, nullptr); + EXPECT_EQ(*suffix, 1U); + + // Iterate over them. They should be in order and without repetitions + std::vector<uint32_t> expected = {1, 2, 5, 6}; + EXPECT_EQ(interest.numberOfSuffixes(), expected.size()); + + for (uint32_t seq : expected) { + EXPECT_EQ(*suffix, seq); + suffix++; + } +} + +TEST_F(InterestTest, InterestWithoutManifest) { + // Create interest without manifest + Interest interest(HF_INET6_TCP); + auto suffix = interest.firstSuffix(); + + EXPECT_FALSE(interest.hasManifest()); + EXPECT_EQ(interest.numberOfSuffixes(), 0U); + EXPECT_EQ(suffix, nullptr); +} + } // namespace core } // namespace transport diff --git a/libtransport/src/test/test_memif_connector.cc b/libtransport/src/test/test_memif_connector.cc index 562a12c88..40f4df927 100644 --- a/libtransport/src/test/test_memif_connector.cc +++ b/libtransport/src/test/test_memif_connector.cc @@ -83,8 +83,8 @@ class Memif { recv_counter_ += buffers.size(); if (recv_counter_ == total_packets) { auto t1 = utils::SteadyTime::now(); - auto delta = utils::SteadyTime::getDurationS(t0_, t1); - auto rate = recv_counter_ / delta.count(); + auto delta = utils::SteadyTime::getDurationUs(t0_, t1); + double rate = double(recv_counter_) * 1.0e6 / double(delta.count()); LOG(INFO) << "rate: " << rate << " packets/s"; io_service_.stop(); } diff --git a/libtransport/src/test/test_packet_allocator.cc b/libtransport/src/test/test_packet_allocator.cc index b63ddde8d..744f1bd24 100644 --- a/libtransport/src/test/test_packet_allocator.cc +++ b/libtransport/src/test/test_packet_allocator.cc @@ -21,6 +21,7 @@ #define ALLOCATION_CHECKS #include <hicn/transport/core/global_object_pool.h> #undef ALLOCATION_CHECKS +#include <hicn/transport/utils/chrono_typedefs.h> #include <hicn/transport/utils/event_thread.h> namespace transport { @@ -30,6 +31,8 @@ class PacketAllocatorTest : public ::testing::Test { protected: static inline const std::size_t default_size = 2048; static inline const std::size_t default_n_buffer = 1024; + static inline const std::size_t counter = 1024; + static inline const std::size_t total_packets = 1024 * counter; // Get fixed block allocator_ of 1024 buffers of size 2048 bytes PacketAllocatorTest() : allocator_(PacketManager<>::getInstance()) { @@ -102,5 +105,27 @@ TEST_F(PacketAllocatorTest, CheckAllocationIsCorrect) { PacketManager<>::PacketStorage::packet_and_shared_ptr))); } +TEST_F(PacketAllocatorTest, CheckAllocationSpeed) { + // Check time needed to allocate 1 million packeauto &packet_manager = + auto &packet_manager = core::PacketManager<>::getInstance(); + + // Send 1 million packets + std::array<utils::MemBuf::Ptr, counter> packets; + auto t0 = utils::SteadyTime::now(); + std::size_t sum = 0; + for (std::size_t j = 0; j < counter; j++) { + for (std::size_t i = 0; i < counter; i++) { + packets[i] = packet_manager.getMemBuf(); + sum++; + } + } + auto t1 = utils::SteadyTime::now(); + + auto delta = utils::SteadyTime::getDurationUs(t0, t1); + auto rate = double(sum) * 1000000.0 / double(delta.count()); + + LOG(INFO) << "rate: " << rate << " packets/s"; +} + } // namespace core } // namespace transport
\ No newline at end of file diff --git a/libtransport/src/test/test_prefix.cc b/libtransport/src/test/test_prefix.cc new file mode 100644 index 000000000..5de737566 --- /dev/null +++ b/libtransport/src/test/test_prefix.cc @@ -0,0 +1,334 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <glog/logging.h> +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include <hicn/transport/core/prefix.h> +#include <hicn/transport/errors/invalid_ip_address_exception.h> +#include <hicn/transport/portability/endianess.h> + +#include <cstring> +#include <memory> +#include <vector> + +namespace transport { +namespace core { + +namespace { +class PrefixTest : public ::testing::Test { + protected: + static inline const char prefix_str0[] = "2001:db8:1::/64"; + static inline const char prefix_str1[] = "10.11.12.0/24"; + static inline const char prefix_str2[] = "2001:db8:1::abcd/64"; + static inline const char prefix_str3[] = "10.11.12.245/27"; + static inline const char wrong_prefix_str0[] = "10.11.12.245/45"; + static inline const char wrong_prefix_str1[] = "10.400.12.13/8"; + static inline const char wrong_prefix_str2[] = "2001:db8:1::/640"; + static inline const char wrong_prefix_str3[] = "20011::db8:1::/16"; + static inline const char wrong_prefix_str4[] = "2001::db8:1::fffff/96"; + + PrefixTest() = default; + + ~PrefixTest() override = default; + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + } + + void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } +}; + +TEST_F(PrefixTest, ConstructorRightString) { + // Create empty prefix + Prefix p; + + // Create prefix from string + Prefix p0(prefix_str0); + // Reconstruct string and check it is equal to original address + std::string network = p0.getNetwork(); + std::uint16_t prefix_length = p0.getPrefixLength(); + EXPECT_THAT(network + "/" + std::to_string(prefix_length), + ::testing::StrEq(prefix_str0)); + + // Create prefix from string + Prefix p1(prefix_str1); + // Reconstruct string and check it is equal to original address + network = p1.getNetwork(); + prefix_length = p1.getPrefixLength(); + EXPECT_THAT(network + "/" + std::to_string(prefix_length), + ::testing::StrEq(prefix_str1)); + + // Create prefix from string + Prefix p2(prefix_str2); + // Reconstruct string and check it is equal to original address + network = p2.getNetwork(); + prefix_length = p2.getPrefixLength(); + EXPECT_THAT(network + "/" + std::to_string(prefix_length), + ::testing::StrEq(prefix_str2)); + + // Create prefix from string + Prefix p3(prefix_str3); + // Reconstruct string and check it is equal to original address + network = p3.getNetwork(); + prefix_length = p3.getPrefixLength(); + EXPECT_THAT(network + "/" + std::to_string(prefix_length), + ::testing::StrEq(prefix_str3)); + + // Create prefix from string and prefix length + Prefix p4("2001::1234", 66); + // Reconstruct string and check it is equal to original address + network = p4.getNetwork(); + prefix_length = p4.getPrefixLength(); + auto af = p4.getAddressFamily(); + EXPECT_THAT(network, ::testing::StrEq("2001::1234")); + EXPECT_THAT(prefix_length, ::testing::Eq(66)); + EXPECT_THAT(af, ::testing::Eq(AF_INET6)); +} + +TEST_F(PrefixTest, ConstructorWrongString) { + try { + Prefix p0(wrong_prefix_str0); + FAIL() << "Expected exception"; + } catch (const errors::InvalidIpAddressException &) { + // Expected exception + } + + try { + Prefix p1(wrong_prefix_str1); + FAIL() << "Expected exception"; + } catch (const errors::InvalidIpAddressException &) { + // Expected exception + } + + try { + Prefix p2(wrong_prefix_str2); + FAIL() << "Expected exception"; + } catch (const errors::InvalidIpAddressException &) { + // Expected exception + } + + try { + Prefix p3(wrong_prefix_str3); + FAIL() << "Expected exception"; + } catch (const errors::InvalidIpAddressException &) { + // Expected exception + } + + try { + Prefix p4(wrong_prefix_str4); + FAIL() << "Expected exception"; + } catch (const errors::InvalidIpAddressException &) { + // Expected exception + } +} + +TEST_F(PrefixTest, Comparison) { + Prefix p0(prefix_str0); + Prefix p1(prefix_str1); + + // Expect they are different + EXPECT_THAT(p0, ::testing::Ne(p1)); + + auto p2 = p1; + // Expect they are equal + EXPECT_THAT(p1, ::testing::Eq(p2)); +} + +TEST_F(PrefixTest, ToSockAddress) { + Prefix p0(prefix_str3); + + auto ret = p0.toSockaddr(); + auto sockaddr = reinterpret_cast<sockaddr_in *>(ret.get()); + + EXPECT_THAT(sockaddr->sin_family, ::testing::Eq(AF_INET)); + EXPECT_THAT(sockaddr->sin_addr.s_addr, portability::host_to_net(0x0a0b0cf5)); +} + +TEST_F(PrefixTest, GetPrefixLength) { + Prefix p0(prefix_str3); + EXPECT_THAT(p0.getPrefixLength(), ::testing::Eq(27)); +} + +TEST_F(PrefixTest, SetPrefixLength) { + Prefix p0(prefix_str3); + EXPECT_THAT(p0.getPrefixLength(), ::testing::Eq(27)); + p0.setPrefixLength(20); + EXPECT_THAT(p0.getPrefixLength(), ::testing::Eq(20)); + + try { + p0.setPrefixLength(33); + FAIL() << "Expected exception"; + } catch ([[maybe_unused]] const errors::InvalidIpAddressException &) { + // Expected exception + } +} + +TEST_F(PrefixTest, SetGetNetwork) { + Prefix p0(prefix_str0); + EXPECT_THAT(p0.getPrefixLength(), ::testing::Eq(64)); + p0.setNetwork("b001::1234"); + EXPECT_THAT(p0.getNetwork(), ::testing::StrEq("b001::1234")); + EXPECT_THAT(p0.getPrefixLength(), ::testing::Eq(64)); +} + +TEST_F(PrefixTest, Contains) { + // IPv6 prefix + Prefix p0(prefix_str0); + ip_address_t ip0, ip1; + + ip_address_pton("2001:db8:1::1234", &ip0); + ip_address_pton("2001:db9:1::1234", &ip1); + + EXPECT_TRUE(p0.contains(ip0)); + EXPECT_FALSE(p0.contains(ip1)); + + Prefix p1(prefix_str1); + ip_address_pton("10.11.12.12", &ip0); + ip_address_pton("10.12.12.13", &ip1); + + EXPECT_TRUE(p1.contains(ip0)); + EXPECT_FALSE(p1.contains(ip1)); + + Prefix p2(prefix_str2); + ip_address_pton("2001:db8:1::dbca", &ip0); + ip_address_pton("10.12.12.12", &ip1); + + EXPECT_TRUE(p2.contains(ip0)); + EXPECT_FALSE(p2.contains(ip1)); + + Prefix p3(prefix_str3); + ip_address_pton("10.11.12.245", &ip0); + ip_address_pton("10.11.12.1", &ip1); + + EXPECT_TRUE(p3.contains(ip0)); + EXPECT_FALSE(p3.contains(ip1)); + + // Corner cases + Prefix p4("::/0"); + ip_address_pton("7001:db8:1::1234", &ip0); + ip_address_pton("8001:db8:1::1234", &ip1); + + EXPECT_TRUE(p4.contains(ip0)); + EXPECT_TRUE(p4.contains(ip1)); + + // Corner cases + Prefix p5("b001:a:b:c:d:e:f:1/128"); + ip_address_pton("b001:a:b:c:d:e:f:1", &ip0); + ip_address_pton("b001:a:b:c:d:e:f:2", &ip1); + + EXPECT_TRUE(p5.contains(ip0)); + EXPECT_FALSE(p5.contains(ip1)); +} + +TEST_F(PrefixTest, GetAddressFamily) { + Prefix p0(prefix_str0); + auto af = p0.getAddressFamily(); + EXPECT_THAT(af, ::testing::Eq(AF_INET6)); + + Prefix p1(prefix_str1); + af = p1.getAddressFamily(); + EXPECT_THAT(af, ::testing::Eq(AF_INET)); +} + +TEST_F(PrefixTest, MakeName) { + Prefix p0(prefix_str0); + auto name0 = p0.makeName(); + EXPECT_THAT(name0.toString(), ::testing::StrEq("2001:db8:1::|0")); + + Prefix p1(prefix_str1); + auto name1 = p1.makeName(); + EXPECT_THAT(name1.toString(), ::testing::StrEq("10.11.12.0|0")); + + Prefix p2(prefix_str2); + auto name2 = p2.makeName(); + EXPECT_THAT(name2.toString(), ::testing::StrEq("2001:db8:1::|0")); + + Prefix p3(prefix_str3); + auto name3 = p3.makeName(); + EXPECT_THAT(name3.toString(), ::testing::StrEq("10.11.12.224|0")); + + Prefix p4("b001:a:b:c:d:e:f:1/128"); + auto name4 = p4.makeName(); + EXPECT_THAT(name4.toString(), ::testing::StrEq("b001:a:b:c:d:e:f:1|0")); +} + +TEST_F(PrefixTest, MakeRandomName) { + Prefix p0(prefix_str0); + auto name0 = p0.makeRandomName(); + auto name1 = p0.makeRandomName(); + auto name2 = p0.makeRandomName(); + auto name3 = p0.makeRandomName(); + + EXPECT_THAT(name0, ::testing::Not(::testing::Eq(name1))); + EXPECT_THAT(name0, ::testing::Not(::testing::Eq(name2))); + EXPECT_THAT(name0, ::testing::Not(::testing::Eq(name3))); + EXPECT_THAT(name1, ::testing::Not(::testing::Eq(name2))); + EXPECT_THAT(name1, ::testing::Not(::testing::Eq(name3))); + EXPECT_THAT(name2, ::testing::Not(::testing::Eq(name3))); + + // Corner case + Prefix p2("b001:a:b:c:d:e:f:1/128"); + name0 = p2.makeRandomName(); + name1 = p2.makeRandomName(); + name2 = p2.makeRandomName(); + name3 = p2.makeRandomName(); + + EXPECT_THAT(name0, ::testing::Eq(name1)); + EXPECT_THAT(name0, ::testing::Eq(name2)); + EXPECT_THAT(name0, ::testing::Eq(name3)); + EXPECT_THAT(name1, ::testing::Eq(name2)); + EXPECT_THAT(name1, ::testing::Eq(name3)); + EXPECT_THAT(name2, ::testing::Eq(name3)); +} + +TEST_F(PrefixTest, MakeNameWithIndex) { + Prefix p0(prefix_str0); + auto name0 = p0.makeNameWithIndex(0); + EXPECT_THAT(name0.toString(), ::testing::StrEq("2001:db8:1::|0")); + auto name1 = p0.makeNameWithIndex(1); + EXPECT_THAT(name1.toString(), ::testing::StrEq("2001:db8:1::1|0")); + auto name2 = p0.makeNameWithIndex(2); + EXPECT_THAT(name2.toString(), ::testing::StrEq("2001:db8:1::2|0")); + auto name3 = p0.makeNameWithIndex(3); + EXPECT_THAT(name3.toString(), ::testing::StrEq("2001:db8:1::3|0")); + + Prefix p1(prefix_str1); + name0 = p1.makeNameWithIndex(0); + EXPECT_THAT(name0.toString(), ::testing::StrEq("10.11.12.0|0")); + name1 = p1.makeNameWithIndex(1); + EXPECT_THAT(name1.toString(), ::testing::StrEq("10.11.12.1|0")); + name2 = p1.makeNameWithIndex(2); + EXPECT_THAT(name2.toString(), ::testing::StrEq("10.11.12.2|0")); + name3 = p1.makeNameWithIndex(3); + EXPECT_THAT(name3.toString(), ::testing::StrEq("10.11.12.3|0")); + + // Test truncation + Prefix p2("b001::/96"); + name0 = p2.makeNameWithIndex(0xffffffffffffffff); + EXPECT_THAT(name0.toString(), ::testing::StrEq("b001::ffff:ffff|0")); +} + +} // namespace + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/test/test_quadloop.cc b/libtransport/src/test/test_quadloop.cc new file mode 100644 index 000000000..6a08033aa --- /dev/null +++ b/libtransport/src/test/test_quadloop.cc @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2021 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <glog/logging.h> +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include <hicn/transport/portability/cache.h> + +#include <array> +#include <cstring> +#include <memory> +#include <vector> + +namespace utils { + +class LoopTest : public ::testing::Test { + protected: + static inline const std::size_t size = 256; + + LoopTest() = default; + + ~LoopTest() override = default; + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + } + + void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } +}; + +// 1 cache line struct (64 bytes) +struct Data { + std::array<uint64_t, 8> data; +}; + +TEST_F(LoopTest, QuadLoopTest) { + // Create 2 arrays of 256 elements + std::vector<std::unique_ptr<Data>> _from; + std::vector<std::unique_ptr<Data>> _to_next; + _from.reserve(size); + _to_next.reserve(size); + + int n_left_from = size; + int n_left_to_next = size; + + // Initialize the arrays + for (std::size_t i = 0; i < size; i++) { + _from.push_back(std::make_unique<Data>()); + _to_next.push_back(std::make_unique<Data>()); + + for (int j = 0; j < 8; j++) { + _from[i]->data[j] = j; + _to_next[i]->data[j] = 0; + } + } + + const std::unique_ptr<Data> *from = &_from[0]; + const std::unique_ptr<Data> *to_next = &_to_next[0]; + + clock_t start; + clock_t end; + double clocks; + + start = clock(); + // Create a quad loop + while (n_left_from > 0) { + while (n_left_from >= 4 && n_left_to_next >= 4) { + { + using namespace transport::portability::cache; + Data *d2; + Data *d3; + + d2 = from[2].get(); + d3 = from[3].get(); + + prefetch<Data, READ>(d2, sizeof(Data)); + prefetch<Data, READ>(d3, sizeof(Data)); + + d2 = to_next[2].get(); + d3 = to_next[3].get(); + + prefetch<Data, WRITE>(d2, sizeof(Data)); + prefetch<Data, WRITE>(d3, sizeof(Data)); + } + + // Do 4 iterations + std::memcpy(to_next[0].get()->data.data(), from[0].get()->data.data(), + sizeof(Data)); + std::memcpy(to_next[1].get()->data.data(), from[1].get()->data.data(), + sizeof(Data)); + n_left_from -= 2; + n_left_to_next -= 2; + from += 2; + to_next += 2; + } + + while (n_left_from > 0 && n_left_to_next > 0) { + std::memcpy(to_next[0].get()->data.data(), from[0].get()->data.data(), + sizeof(Data)); + n_left_from -= 1; + n_left_to_next -= 1; + from += 1; + to_next += 1; + } + } + end = clock(); + clocks = (double)(end - start); + + LOG(INFO) << "Time with quad loop: " << clocks << std::endl; +} + +TEST_F(LoopTest, NormalLoopTest) { + // Create 2 arrays of 256 elements + std::vector<std::unique_ptr<Data>> _from; + std::vector<std::unique_ptr<Data>> _to_next; + _from.reserve(size); + _to_next.reserve(size); + + int n_left_from = size; + int n_left_to_next = size; + + // Initialize the arrays + for (std::size_t i = 0; i < size; i++) { + _from.push_back(std::make_unique<Data>()); + _to_next.push_back(std::make_unique<Data>()); + + for (int j = 0; j < 8; j++) { + _from[i]->data[j] = j; + _to_next[i]->data[j] = 0; + } + } + + const std::unique_ptr<Data> *from = &_from[0]; + const std::unique_ptr<Data> *to_next = &_to_next[0]; + + clock_t start; + clock_t end; + double clocks; + + start = clock(); + while (n_left_from > 0) { + while (n_left_from > 0 && n_left_to_next > 0) { + std::memcpy(to_next[0].get()->data.data(), from[0].get()->data.data(), + sizeof(Data)); + n_left_from -= 1; + n_left_to_next -= 1; + from += 1; + to_next += 1; + } + } + end = clock(); + clocks = ((double)(end - start)); + + LOG(INFO) << "Time with normal loop: " << clocks << std::endl; +} + +} // namespace utils
\ No newline at end of file diff --git a/libtransport/src/utils/epoll_event_reactor.h b/libtransport/src/utils/epoll_event_reactor.h index 8e7681c20..32d99c837 100644 --- a/libtransport/src/utils/epoll_event_reactor.h +++ b/libtransport/src/utils/epoll_event_reactor.h @@ -49,7 +49,7 @@ class EpollEventReactor : public EventReactor { if (it == event_callback_map_.end()) { { utils::SpinLock::Acquire locked(event_callback_map_lock_); - event_callback_map_[fd] = std::forward<EventHandler &&>(callback); + event_callback_map_[fd] = std::forward<EventHandler>(callback); } ret = addFileDescriptor(fd, events); diff --git a/libtransport/src/utils/fd_deadline_timer.h b/libtransport/src/utils/fd_deadline_timer.h index e15cd4d2a..cf0cde112 100644 --- a/libtransport/src/utils/fd_deadline_timer.h +++ b/libtransport/src/utils/fd_deadline_timer.h @@ -57,8 +57,8 @@ class FdDeadlineTimer : public DeadlineTimer<FdDeadlineTimer> { reactor_.addFileDescriptor( timer_fd_, events, - [callback = std::forward<WaitHandler &&>(callback)]( - const Event &event) -> int { + [callback = + std::forward<WaitHandler>(callback)](const Event &event) -> int { uint64_t s = 0; std::error_code ec; diff --git a/libtransport/src/utils/suffix_strategy.h b/libtransport/src/utils/suffix_strategy.h index 96eaed662..4b3ddbc74 100644 --- a/libtransport/src/utils/suffix_strategy.h +++ b/libtransport/src/utils/suffix_strategy.h @@ -36,11 +36,11 @@ enum class NextSuffixStrategy : uint8_t { class SuffixStrategy { public: static constexpr uint32_t MAX_SUFFIX = std::numeric_limits<uint32_t>::max(); - static constexpr uint8_t MAX_MANIFEST_CAPACITY = + static constexpr uint8_t MANIFEST_MAX_CAPACITY = std::numeric_limits<uint8_t>::max(); SuffixStrategy(NextSuffixStrategy strategy, uint32_t offset = 0, - uint32_t manifest_capacity = MAX_MANIFEST_CAPACITY) + uint32_t manifest_capacity = MANIFEST_MAX_CAPACITY) : suffix_stragegy_(strategy), next_suffix_(offset), manifest_capacity_(manifest_capacity), @@ -130,7 +130,7 @@ class SuffixStrategyFactory { public: static std::unique_ptr<SuffixStrategy> getSuffixStrategy( NextSuffixStrategy strategy, uint32_t start_offset = 0, - uint32_t manifest_capacity = SuffixStrategy::MAX_MANIFEST_CAPACITY) { + uint32_t manifest_capacity = SuffixStrategy::MANIFEST_MAX_CAPACITY) { switch (strategy) { case NextSuffixStrategy::INCREMENTAL: return std::make_unique<IncrementalSuffixStrategy>(start_offset); |