diff options
Diffstat (limited to 'libtransport/src/core')
25 files changed, 752 insertions, 785 deletions
diff --git a/libtransport/src/core/CMakeLists.txt b/libtransport/src/core/CMakeLists.txt index b9b024d60..777772a04 100644 --- a/libtransport/src/core/CMakeLists.txt +++ b/libtransport/src/core/CMakeLists.txt @@ -14,17 +14,18 @@ list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/facade.h ${CMAKE_CURRENT_SOURCE_DIR}/manifest.h - ${CMAKE_CURRENT_SOURCE_DIR}/manifest_inline.h ${CMAKE_CURRENT_SOURCE_DIR}/manifest_format_fixed.h ${CMAKE_CURRENT_SOURCE_DIR}/manifest_format.h ${CMAKE_CURRENT_SOURCE_DIR}/pending_interest.h ${CMAKE_CURRENT_SOURCE_DIR}/portal.h ${CMAKE_CURRENT_SOURCE_DIR}/errors.h ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.h + ${CMAKE_CURRENT_SOURCE_DIR}/global_id_counter.h ${CMAKE_CURRENT_SOURCE_DIR}/local_connector.h ${CMAKE_CURRENT_SOURCE_DIR}/global_workers.h ${CMAKE_CURRENT_SOURCE_DIR}/udp_connector.h ${CMAKE_CURRENT_SOURCE_DIR}/udp_listener.h + ${CMAKE_CURRENT_SOURCE_DIR}/global_module_manager.h ) list(APPEND SOURCE_FILES @@ -38,9 +39,9 @@ list(APPEND SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/portal.cc ${CMAKE_CURRENT_SOURCE_DIR}/global_configuration.cc ${CMAKE_CURRENT_SOURCE_DIR}/io_module.cc - ${CMAKE_CURRENT_SOURCE_DIR}/local_connector.cc ${CMAKE_CURRENT_SOURCE_DIR}/udp_connector.cc ${CMAKE_CURRENT_SOURCE_DIR}/udp_listener.cc + ${CMAKE_CURRENT_SOURCE_DIR}/constructor.cc ) if (NOT ${CMAKE_SYSTEM_NAME} MATCHES Android) @@ -56,4 +57,4 @@ if (NOT ${CMAKE_SYSTEM_NAME} MATCHES Android) endif() set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) -set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE)
\ No newline at end of file +set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE) diff --git a/libtransport/src/core/constructor.cc b/libtransport/src/core/constructor.cc new file mode 100644 index 000000000..0c7f0dfa8 --- /dev/null +++ b/libtransport/src/core/constructor.cc @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <core/global_configuration.h> +#include <core/global_module_manager.h> +#include <core/global_workers.h> +#include <hicn/transport/core/global_object_pool.h> + +namespace transport { +namespace core { + +void __attribute__((constructor)) libtransportInit() { + // First the global module manager is initialized + GlobalModuleManager::getInstance(); + // Then the packet allocator is initialized + PacketManager<>::getInstance(); + // Then the global configuration is initialized + GlobalConfiguration::getInstance(); + // Then the global workers are initialized + GlobalWorkers::getInstance(); +} + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/content_object.cc b/libtransport/src/core/content_object.cc index 643e0388e..e66b2a6cd 100644 --- a/libtransport/src/core/content_object.cc +++ b/libtransport/src/core/content_object.cc @@ -15,6 +15,7 @@ #include <hicn/transport/core/content_object.h> #include <hicn/transport/errors/errors.h> +#include <hicn/transport/portability/endianess.h> #include <hicn/transport/utils/branch_prediction.h> extern "C" { @@ -46,7 +47,7 @@ ContentObject::ContentObject(const Name &name, Packet::Format format, } if (TRANSPORT_EXPECT_FALSE(hicn_data_get_name(format_, packet_start_, - name_.getStructReference()) < + &name_.getStructReference()) < 0)) { throw errors::MalformedPacketException(); } @@ -91,7 +92,7 @@ ContentObject::~ContentObject() {} const Name &ContentObject::getName() const { if (!name_) { if (hicn_data_get_name(format_, packet_start_, - (hicn_name_t *)name_.getConstStructReference()) < + (hicn_name_t *)&name_.getConstStructReference()) < 0) { throw errors::MalformedPacketException(); } @@ -104,11 +105,11 @@ Name &ContentObject::getWritableName() { return const_cast<Name &>(getName()); } void ContentObject::setName(const Name &name) { if (hicn_data_set_name(format_, packet_start_, - name.getConstStructReference()) < 0) { + &name.getConstStructReference()) < 0) { throw errors::RuntimeException("Error setting content object name."); } - if (hicn_data_get_name(format_, packet_start_, name_.getStructReference()) < + if (hicn_data_get_name(format_, packet_start_, &name_.getStructReference()) < 0) { throw errors::MalformedPacketException(); } @@ -121,11 +122,11 @@ uint32_t ContentObject::getPathLabel() const { "Error retrieving the path label from content object"); } - return ntohl(path_label); + return portability::net_to_host(path_label); } ContentObject &ContentObject::setPathLabel(uint32_t path_label) { - path_label = htonl(path_label); + path_label = portability::host_to_net(path_label); if (hicn_data_set_path_label((hicn_header_t *)packet_start_, path_label) < 0) { throw errors::RuntimeException( diff --git a/libtransport/src/core/facade.h b/libtransport/src/core/facade.h index 1ad4437e2..77c1d16d2 100644 --- a/libtransport/src/core/facade.h +++ b/libtransport/src/core/facade.h @@ -15,16 +15,16 @@ #pragma once +#include <core/manifest.h> #include <core/manifest_format_fixed.h> -#include <core/manifest_inline.h> #include <core/portal.h> namespace transport { namespace core { -using ContentObjectManifest = core::ManifestInline<ContentObject, Fixed>; -using InterestManifest = core::ManifestInline<Interest, Fixed>; +using ContentObjectManifest = core::Manifest<Fixed>; +using InterestManifest = core::Manifest<Fixed>; } // namespace core diff --git a/libtransport/src/core/manifest.cc b/libtransport/src/core/global_id_counter.h index da2689426..0a67b76d5 100644 --- a/libtransport/src/core/manifest.cc +++ b/libtransport/src/core/global_id_counter.h @@ -13,21 +13,27 @@ * limitations under the License. */ -#include <hicn/transport/core/manifest.h> +#pragma once -namespace transport { +#include <hicn/transport/utils/singleton.h> -namespace core { +#include <atomic> +#include <mutex> -std::string ManifestEncoding::manifest_type = std::string("manifest_type"); +namespace transport { -std::map<ManifestType, std::string> ManifestEncoding::manifest_types = { - {FINAL_CHUNK_NUMBER, "FinalChunkNumber"}, {NAME_LIST, "NameList"}}; +namespace core { -std::string ManifestEncoding::final_chunk_number = - std::string("final_chunk_number"); -std::string ManifestEncoding::content_name = std::string("content_name"); +template <typename T = uint64_t> +class GlobalCounter : public utils::Singleton<GlobalCounter<T>> { + public: + friend class utils::Singleton<GlobalCounter>; + T getNext() { return counter_++; } -} // end namespace core + private: + GlobalCounter() : counter_(0) {} + std::atomic<T> counter_; +}; -} // end namespace transport
\ No newline at end of file +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/global_module_manager.h b/libtransport/src/core/global_module_manager.h new file mode 100644 index 000000000..c9d272cdb --- /dev/null +++ b/libtransport/src/core/global_module_manager.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2022 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <glog/logging.h> +#include <hicn/transport/utils/singleton.h> + +#ifndef _WIN32 +#include <dlfcn.h> +#endif + +#include <atomic> +#include <iostream> +#include <mutex> +#include <unordered_map> + +namespace transport { +namespace core { + +class GlobalModuleManager : public utils::Singleton<GlobalModuleManager> { + public: + friend class utils::Singleton<GlobalModuleManager>; + + ~GlobalModuleManager() { + for (const auto &[key, value] : modules_) { + unload(value); + } + } + + void *loadModule(const std::string &module_name) { + void *handle = nullptr; + const char *error = nullptr; + + // Lock + std::unique_lock lck(mtx_); + + auto it = modules_.find(module_name); + if (it != modules_.end()) { + return it->second; + } + + // open module + handle = dlopen(module_name.c_str(), RTLD_NOW); + if (!handle) { + if ((error = dlerror()) != nullptr) { + LOG(ERROR) << error; + } + return nullptr; + } + + auto ret = modules_.try_emplace(module_name, handle); + DCHECK(ret.second); + + return handle; + } + + void unload(void *handle) { + // destroy object and close module + dlclose(handle); + } + + bool unloadModule(const std::string &module_name) { + // Lock + std::unique_lock lck(mtx_); + auto it = modules_.find(module_name); + if (it != modules_.end()) { + unload(it->second); + return true; + } + + return false; + } + + private: + GlobalModuleManager() = default; + std::mutex mtx_; + std::unordered_map<std::string, void *> modules_; +}; + +} // namespace core +} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/global_workers.h b/libtransport/src/core/global_workers.h index 1ac254188..c5d794ef2 100644 --- a/libtransport/src/core/global_workers.h +++ b/libtransport/src/core/global_workers.h @@ -32,6 +32,8 @@ class GlobalWorkers : public utils::Singleton<GlobalWorkers> { return thread_pool_.getWorker(counter_++ % thread_pool_.getNThreads()); } + auto& getWorkers() { return thread_pool_.getWorkers(); } + private: GlobalWorkers() : counter_(0), thread_pool_() {} diff --git a/libtransport/src/core/interest.cc b/libtransport/src/core/interest.cc index b7719b3ed..8b9dcf256 100644 --- a/libtransport/src/core/interest.cc +++ b/libtransport/src/core/interest.cc @@ -21,6 +21,7 @@ extern "C" { #ifndef _WIN32 TRANSPORT_CLANG_DISABLE_WARNING("-Wextern-c-compat") #endif +#include <hicn/base.h> #include <hicn/hicn.h> } @@ -39,12 +40,12 @@ Interest::Interest(const Name &interest_name, Packet::Format format, } if (hicn_interest_set_name(format_, packet_start_, - interest_name.getConstStructReference()) < 0) { + &interest_name.getConstStructReference()) < 0) { throw errors::MalformedPacketException(); } if (hicn_interest_get_name(format_, packet_start_, - name_.getStructReference()) < 0) { + &name_.getStructReference()) < 0) { throw errors::MalformedPacketException(); } } @@ -64,7 +65,7 @@ Interest::Interest(hicn_format_t format, std::size_t additional_header_size) Interest::Interest(MemBuf &&buffer) : Packet(std::move(buffer)) { if (hicn_interest_get_name(format_, packet_start_, - name_.getStructReference()) < 0) { + &name_.getStructReference()) < 0) { throw errors::MalformedPacketException(); } } @@ -86,9 +87,9 @@ Interest::~Interest() {} const Name &Interest::getName() const { if (!name_) { - if (hicn_interest_get_name(format_, packet_start_, - (hicn_name_t *)name_.getConstStructReference()) < - 0) { + if (hicn_interest_get_name( + format_, packet_start_, + (hicn_name_t *)&name_.getConstStructReference()) < 0) { throw errors::MalformedPacketException(); } } @@ -100,12 +101,12 @@ Name &Interest::getWritableName() { return const_cast<Name &>(getName()); } void Interest::setName(const Name &name) { if (hicn_interest_set_name(format_, packet_start_, - name.getConstStructReference()) < 0) { + &name.getConstStructReference()) < 0) { throw errors::RuntimeException("Error setting interest name."); } if (hicn_interest_get_name(format_, packet_start_, - name_.getStructReference()) < 0) { + &name_.getStructReference()) < 0) { throw errors::MalformedPacketException(); } } @@ -150,6 +151,13 @@ void Interest::resetForHash() { throw errors::RuntimeException( "Error resetting interest fields for hash computation."); } + + // Reset request bitmap in manifest + if (hasManifest()) { + auto int_manifest_header = + (interest_manifest_header_t *)(writableData() + headerSize()); + memset(int_manifest_header->request_bitmap, 0, BITMAP_SIZE * sizeof(u32)); + } } bool Interest::hasManifest() { @@ -171,19 +179,21 @@ void Interest::encodeSuffixes() { // We assume interest does not hold signature for the moment. auto int_manifest_header = - (InterestManifestHeader *)(writableData() + headerSize()); - int_manifest_header->n_suffixes = suffix_set_.size(); - std::size_t additional_length = - sizeof(InterestManifestHeader) + - int_manifest_header->n_suffixes * sizeof(uint32_t); + (interest_manifest_header_t *)(writableData() + headerSize()); + int_manifest_header->n_suffixes = (uint32_t)suffix_set_.size(); + memset(int_manifest_header->request_bitmap, 0xFFFFFFFF, + BITMAP_SIZE * sizeof(u32)); uint32_t *suffix = (uint32_t *)(int_manifest_header + 1); for (auto it = suffix_set_.begin(); it != suffix_set_.end(); it++, suffix++) { *suffix = *it; } + std::size_t additional_length = + sizeof(interest_manifest_header_t) + + int_manifest_header->n_suffixes * sizeof(uint32_t); append(additional_length); - updateLength(additional_length); + updateLength(); } uint32_t *Interest::firstSuffix() { @@ -191,7 +201,7 @@ uint32_t *Interest::firstSuffix() { return nullptr; } - auto ret = (InterestManifestHeader *)(writableData() + headerSize()); + auto ret = (interest_manifest_header_t *)(writableData() + headerSize()); ret += 1; return (uint32_t *)ret; @@ -202,11 +212,48 @@ uint32_t Interest::numberOfSuffixes() { return 0; } - auto header = (InterestManifestHeader *)(writableData() + headerSize()); + auto header = (interest_manifest_header_t *)(writableData() + headerSize()); return header->n_suffixes; } +uint32_t *Interest::getRequestBitmap() { + if (!hasManifest()) return nullptr; + + auto header = (interest_manifest_header_t *)(writableData() + headerSize()); + return header->request_bitmap; +} + +void Interest::setRequestBitmap(const uint32_t *request_bitmap) { + if (!hasManifest()) return; + + auto header = (interest_manifest_header_t *)(writableData() + headerSize()); + memcpy(header->request_bitmap, request_bitmap, + BITMAP_SIZE * sizeof(uint32_t)); +} + +bool Interest::isValid() { + if (!hasManifest()) return true; + + auto header = (interest_manifest_header_t *)(writableData() + headerSize()); + + if (header->n_suffixes == 0 || + header->n_suffixes > MAX_SUFFIXES_IN_MANIFEST) { + std::cerr << "Manifest with invalid number of suffixes " + << header->n_suffixes; + return false; + } + + uint32_t empty_bitmap[BITMAP_SIZE]; + memset(empty_bitmap, 0, sizeof(empty_bitmap)); + if (memcmp(empty_bitmap, header->request_bitmap, sizeof(empty_bitmap)) == 0) { + std::cerr << "Manifest with empty bitmap"; + return false; + } + + return true; +} + } // end namespace core } // end namespace transport diff --git a/libtransport/src/core/io_module.cc b/libtransport/src/core/io_module.cc index 69e4e8bcf..0f92cc47c 100644 --- a/libtransport/src/core/io_module.cc +++ b/libtransport/src/core/io_module.cc @@ -16,6 +16,7 @@ #ifndef _WIN32 #include <dlfcn.h> #endif +#include <core/global_module_manager.h> #include <glog/logging.h> #include <hicn/transport/core/io_module.h> @@ -38,54 +39,28 @@ IoModule *IoModule::load(const char *module_name) { #ifdef ANDROID return new HicnForwarderModule(); #else - void *handle = 0; - IoModule *module = 0; - IoModule *(*creator)(void) = 0; - const char *error = 0; + IoModule *iomodule = nullptr; + IoModule *(*creator)(void) = nullptr; + const char *error = nullptr; - // open module - handle = dlopen(module_name, RTLD_NOW); - if (!handle) { - if ((error = dlerror()) != 0) { - LOG(ERROR) << error; - } - return 0; - } + auto handle = GlobalModuleManager::getInstance().loadModule(module_name); // get factory method creator = (IoModule * (*)(void)) dlsym(handle, "create_module"); if (!creator) { - if ((error = dlerror()) != 0) { + if ((error = dlerror()) != nullptr) { LOG(ERROR) << error; } - return 0; + return nullptr; } // create object and return it - module = (*creator)(); - module->handle_ = handle; + iomodule = (*creator)(); - return module; + return iomodule; #endif } -bool IoModule::unload(IoModule *module) { - if (!module) { - return false; - } - -#ifdef ANDROID - delete module; -#else - // destroy object and close module - void *handle = module->handle_; - delete module; - dlclose(handle); -#endif - - return true; -} - } // namespace core } // namespace transport diff --git a/libtransport/src/core/local_connector.cc b/libtransport/src/core/local_connector.cc deleted file mode 100644 index f27be2e5c..000000000 --- a/libtransport/src/core/local_connector.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <core/local_connector.h> -#include <glog/logging.h> -#include <hicn/transport/core/asio_wrapper.h> -#include <hicn/transport/core/content_object.h> -#include <hicn/transport/core/interest.h> -#include <hicn/transport/errors/not_implemented_exception.h> - -namespace transport { -namespace core { - -LocalConnector::~LocalConnector() {} - -void LocalConnector::close() { state_ = State::CLOSED; } - -void LocalConnector::send(Packet &packet) { - if (!isConnected()) { - return; - } - - auto buffer = - std::static_pointer_cast<utils::MemBuf>(packet.shared_from_this()); - - DLOG_IF(INFO, VLOG_IS_ON(3)) << "Sending packet to local socket."; - io_service_.get().post([this, buffer]() mutable { - std::vector<utils::MemBuf::Ptr> v{std::move(buffer)}; - receive_callback_(this, v, std::make_error_code(std::errc(0))); - }); -} - -void LocalConnector::send(const utils::MemBuf::Ptr &buffer) { - throw errors::NotImplementedException(); -} - -} // namespace core -} // namespace transport
\ No newline at end of file diff --git a/libtransport/src/core/local_connector.h b/libtransport/src/core/local_connector.h index eede89e74..963f455e6 100644 --- a/libtransport/src/core/local_connector.h +++ b/libtransport/src/core/local_connector.h @@ -15,9 +15,11 @@ #pragma once +#include <core/errors.h> #include <hicn/transport/core/asio_wrapper.h> #include <hicn/transport/core/connector.h> #include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/errors/not_implemented_exception.h> #include <hicn/transport/utils/move_wrapper.h> #include <hicn/transport/utils/shared_ptr_utils.h> #include <io_modules/forwarder/errors.h> @@ -34,19 +36,48 @@ class LocalConnector : public Connector { OnClose &&close_callback, OnReconnect &&on_reconnect) : Connector(receive_callback, packet_sent, close_callback, on_reconnect), io_service_(io_service), - io_service_work_(io_service_.get()) { - state_ = State::CONNECTED; - } + io_service_work_(io_service_.get()) {} - ~LocalConnector() override; + ~LocalConnector() override = default; - void send(Packet &packet) override; + auto shared_from_this() { return utils::shared_from(this); } - void send(const utils::MemBuf::Ptr &buffer) override; + void send(Packet &packet) override { send(packet.shared_from_this()); } - void close() override; + void send(const utils::MemBuf::Ptr &buffer) override { + throw errors::NotImplementedException(); + } - auto shared_from_this() { return utils::shared_from(this); } + void receive(const std::vector<utils::MemBuf::Ptr> &buffers) override { + DLOG_IF(INFO, VLOG_IS_ON(3)) << "Sending packet to local socket."; + std::weak_ptr<LocalConnector> self = shared_from_this(); + io_service_.get().post([self, _buffers{std::move(buffers)}]() mutable { + if (auto ptr = self.lock()) { + ptr->receive_callback_(ptr.get(), _buffers, + make_error_code(core_error::success)); + } + }); + } + + void reconnect() override { + state_ = State::CONNECTED; + std::weak_ptr<LocalConnector> self = shared_from_this(); + io_service_.get().post([self]() { + if (auto ptr = self.lock()) { + ptr->on_reconnect_callback_(ptr.get(), + make_error_code(core_error::success)); + } + }); + } + + void close() override { + std::weak_ptr<LocalConnector> self = shared_from_this(); + io_service_.get().post([self]() mutable { + if (auto ptr = self.lock()) { + ptr->on_close_callback_(ptr.get()); + } + }); + } private: std::reference_wrapper<asio::io_service> io_service_; diff --git a/libtransport/src/core/manifest.h b/libtransport/src/core/manifest.h index 5bdbfc6ff..40832bb6b 100644 --- a/libtransport/src/core/manifest.h +++ b/libtransport/src/core/manifest.h @@ -17,165 +17,72 @@ #include <core/manifest_format.h> #include <glog/logging.h> -#include <hicn/transport/core/content_object.h> -#include <hicn/transport/core/name.h> - -#include <set> +#include <hicn/transport/auth/verifier.h> +#include <hicn/transport/core/global_object_pool.h> +#include <hicn/transport/core/packet.h> namespace transport { - namespace core { -using typename core::Name; -using typename core::Packet; -using typename core::PayloadType; - -template <typename Base, typename FormatTraits, typename ManifestImpl> -class Manifest : public Base { - static_assert(std::is_base_of<Packet, Base>::value, - "Base must inherit from packet!"); - +template <typename FormatTraits> +class Manifest : public FormatTraits::Encoder, public FormatTraits::Decoder { public: - // core::ContentObjectManifest::Ptr + using Ptr = std::shared_ptr<Manifest>; using Encoder = typename FormatTraits::Encoder; using Decoder = typename FormatTraits::Decoder; - Manifest(Packet::Format format, std::size_t signature_size = 0) - : Base(format, signature_size), - encoder_(*this, signature_size), - decoder_(*this) { - DCHECK(_is_ah(format)); - Base::setPayloadType(PayloadType::MANIFEST); - } + using Hash = typename FormatTraits::Hash; + using HashType = typename FormatTraits::HashType; + using Suffix = typename FormatTraits::Suffix; + using SuffixList = typename FormatTraits::SuffixList; + using HashEntry = std::pair<auth::CryptoHashType, std::vector<uint8_t>>; - Manifest(Packet::Format format, const core::Name &name, - std::size_t signature_size = 0) - : Base(name, format, signature_size), - encoder_(*this, signature_size), - decoder_(*this) { - DCHECK(_is_ah(format)); - Base::setPayloadType(PayloadType::MANIFEST); + Manifest(Packet::Ptr packet, bool clear = false) + : Encoder(packet, clear), Decoder(packet), packet_(packet) { + packet->setPayloadType(PayloadType::MANIFEST); } - template <typename T> - Manifest(T &&base) - : Base(std::forward<T &&>(base)), - encoder_(*this, 0, false), - decoder_(*this) { - Base::setPayloadType(PayloadType::MANIFEST); - } - - // Useful for decoding manifests while avoiding packet copy - template <typename T> - Manifest(T &base) - : Base(base.getFormat()), encoder_(base, 0, false), decoder_(base) {} - virtual ~Manifest() = default; - std::size_t estimateManifestSize(std::size_t additional_entries = 0) { - return static_cast<ManifestImpl &>(*this).estimateManifestSizeImpl( - additional_entries); - } - - /* - * After the call to encode, users MUST call clear before adding data - * to the manifest. - */ - Manifest &encode() { return static_cast<ManifestImpl &>(*this).encodeImpl(); } - - Manifest &decode() { - Manifest::decoder_.decode(); - - manifest_type_ = decoder_.getType(); - manifest_transport_type_ = decoder_.getTransportType(); - hash_algorithm_ = decoder_.getHashAlgorithm(); - is_last_ = decoder_.getIsLast(); + Packet::Ptr getPacket() const { return packet_; } - return static_cast<ManifestImpl &>(*this).decodeImpl(); + void setHeaders(ManifestType type, uint8_t max_capacity, HashType hash_algo, + bool is_last, const Name &base_name) { + Encoder::setType(type); + Encoder::setMaxCapacity(max_capacity); + Encoder::setHashAlgorithm(hash_algo); + Encoder::setIsLast(is_last); + Encoder::setBaseName(base_name); } - static std::size_t manifestHeaderSize( - interface::ProductionProtocolAlgorithms transport_type = - interface::ProductionProtocolAlgorithms::UNKNOWN) { - return Encoder::manifestHeaderSize(transport_type); - } + auth::Verifier::SuffixMap getSuffixMap() const { + auth::Verifier::SuffixMap suffix_map; - static std::size_t manifestEntrySize() { - return Encoder::manifestEntrySize(); - } + HashType hash_algo = Decoder::getHashAlgorithm(); + SuffixList suffix_list = Decoder::getEntries(); - Manifest &setType(ManifestType type) { - manifest_type_ = type; - encoder_.setType(manifest_type_); - return *this; - } + for (auto it = suffix_list.begin(); it != suffix_list.end(); ++it) { + Hash hash(it->second, Hash::getSize(hash_algo), hash_algo); + suffix_map[it->first] = hash; + } - Manifest &setHashAlgorithm(auth::CryptoHashType hash_algorithm) { - hash_algorithm_ = hash_algorithm; - encoder_.setHashAlgorithm(hash_algorithm_); - return *this; + return suffix_map; } - auth::CryptoHashType getHashAlgorithm() const { return hash_algorithm_; } - - ManifestType getType() const { return manifest_type_; } - - interface::ProductionProtocolAlgorithms getTransportType() const { - return manifest_transport_type_; - } - - bool getIsLast() const { return is_last_; } - - Manifest &setVersion(ManifestVersion version) { - encoder_.setVersion(version); - return *this; - } - - Manifest &setParamsBytestream(const ParamsBytestream ¶ms) { - manifest_transport_type_ = - interface::ProductionProtocolAlgorithms::BYTE_STREAM; - encoder_.setParamsBytestream(params); - return *this; - } - - Manifest &setParamsRTC(const ParamsRTC ¶ms) { - manifest_transport_type_ = - interface::ProductionProtocolAlgorithms::RTC_PROD; - encoder_.setParamsRTC(params); - return *this; - } - - ParamsBytestream getParamsBytestream() const { - return decoder_.getParamsBytestream(); - } - - ParamsRTC getParamsRTC() const { return decoder_.getParamsRTC(); } - - ManifestVersion getVersion() const { return decoder_.getVersion(); } - - Manifest &setIsLast(bool is_last) { - encoder_.setIsLast(is_last); - is_last_ = is_last; - return *this; - } - - Manifest &clear() { - encoder_.clear(); - decoder_.clear(); - return *this; - } + static Manifest::Ptr createContentManifest(Packet::Format format, + const core::Name &manifest_name, + std::size_t signature_size) { + ContentObject::Ptr content_object = + core::PacketManager<>::getInstance().getPacket<ContentObject>( + format, signature_size); + content_object->setName(manifest_name); + return std::make_shared<Manifest>(content_object, true); + }; protected: - ManifestType manifest_type_; - interface::ProductionProtocolAlgorithms manifest_transport_type_; - auth::CryptoHashType hash_algorithm_; - bool is_last_; - - Encoder encoder_; - Decoder decoder_; + Packet::Ptr packet_; }; } // end namespace core - } // end namespace transport diff --git a/libtransport/src/core/manifest_format.h b/libtransport/src/core/manifest_format.h index caee210cd..89412316a 100644 --- a/libtransport/src/core/manifest_format.h +++ b/libtransport/src/core/manifest_format.h @@ -25,13 +25,8 @@ #include <unordered_map> namespace transport { - namespace core { -enum class ManifestVersion : uint8_t { - VERSION_1 = 1, -}; - enum class ManifestType : uint8_t { INLINE_MANIFEST = 1, FINAL_CHUNK_NUMBER = 2, @@ -83,14 +78,27 @@ class ManifestEncoder { return static_cast<Implementation &>(*this).clearImpl(); } + bool isEncoded() const { + return static_cast<const Implementation &>(*this).isEncodedImpl(); + } + ManifestEncoder &setType(ManifestType type) { return static_cast<Implementation &>(*this).setTypeImpl(type); } + ManifestEncoder &setMaxCapacity(uint8_t max_capacity) { + return static_cast<Implementation &>(*this).setMaxCapacityImpl( + max_capacity); + } + ManifestEncoder &setHashAlgorithm(auth::CryptoHashType hash) { return static_cast<Implementation &>(*this).setHashAlgorithmImpl(hash); } + ManifestEncoder &setIsLast(bool is_last) { + return static_cast<Implementation &>(*this).setIsLastImpl(is_last); + } + template < typename T, typename = std::enable_if_t<std::is_same< @@ -99,45 +107,36 @@ class ManifestEncoder { return static_cast<Implementation &>(*this).setBaseNameImpl(name); } - template <typename Hash> - ManifestEncoder &addSuffixAndHash(uint32_t suffix, Hash &&hash) { - return static_cast<Implementation &>(*this).addSuffixAndHashImpl( - suffix, std::forward<Hash &&>(hash)); + ManifestEncoder &setParamsBytestream(const ParamsBytestream ¶ms) { + return static_cast<Implementation &>(*this).setParamsBytestreamImpl(params); } - ManifestEncoder &setIsLast(bool is_last) { - return static_cast<Implementation &>(*this).setIsLastImpl(is_last); + ManifestEncoder &setParamsRTC(const ParamsRTC ¶ms) { + return static_cast<Implementation &>(*this).setParamsRTCImpl(params); } - ManifestEncoder &setVersion(ManifestVersion version) { - return static_cast<Implementation &>(*this).setVersionImpl(version); + template <typename Hash> + ManifestEncoder &addEntry(uint32_t suffix, Hash &&hash) { + return static_cast<Implementation &>(*this).addEntryImpl( + suffix, std::forward<Hash>(hash)); } - std::size_t estimateSerializedLength(std::size_t number_of_entries) { - return static_cast<Implementation &>(*this).estimateSerializedLengthImpl( - number_of_entries); + ManifestEncoder &removeEntry(uint32_t suffix) { + return static_cast<Implementation &>(*this).removeEntryImpl(suffix); } - ManifestEncoder &update() { - return static_cast<Implementation &>(*this).updateImpl(); + std::size_t manifestHeaderSize() const { + return static_cast<const Implementation &>(*this).manifestHeaderSizeImpl(); } - ManifestEncoder &setParamsBytestream(const ParamsBytestream ¶ms) { - return static_cast<Implementation &>(*this).setParamsBytestreamImpl(params); + std::size_t manifestPayloadSize(size_t additional_entries = 0) const { + return static_cast<const Implementation &>(*this).manifestPayloadSizeImpl( + additional_entries); } - ManifestEncoder &setParamsRTC(const ParamsRTC ¶ms) { - return static_cast<Implementation &>(*this).setParamsRTCImpl(params); - } - - static std::size_t manifestHeaderSize( - interface::ProductionProtocolAlgorithms transport_type = - interface::ProductionProtocolAlgorithms::UNKNOWN) { - return Implementation::manifestHeaderSizeImpl(transport_type); - } - - static std::size_t manifestEntrySize() { - return Implementation::manifestEntrySizeImpl(); + std::size_t manifestSize(size_t additional_entries = 0) const { + return static_cast<const Implementation &>(*this).manifestSizeImpl( + additional_entries); } }; @@ -146,11 +145,17 @@ class ManifestDecoder { public: virtual ~ManifestDecoder() = default; + ManifestDecoder &decode() { + return static_cast<Implementation &>(*this).decodeImpl(); + } + ManifestDecoder &clear() { return static_cast<Implementation &>(*this).clearImpl(); } - void decode() { static_cast<Implementation &>(*this).decodeImpl(); } + bool isDecoded() const { + return static_cast<const Implementation &>(*this).isDecodedImpl(); + } ManifestType getType() const { return static_cast<const Implementation &>(*this).getTypeImpl(); @@ -160,40 +165,48 @@ class ManifestDecoder { return static_cast<const Implementation &>(*this).getTransportTypeImpl(); } + uint8_t getMaxCapacity() const { + return static_cast<const Implementation &>(*this).getMaxCapacityImpl(); + } + auth::CryptoHashType getHashAlgorithm() const { return static_cast<const Implementation &>(*this).getHashAlgorithmImpl(); } + bool getIsLast() const { + return static_cast<const Implementation &>(*this).getIsLastImpl(); + } + core::Name getBaseName() const { return static_cast<const Implementation &>(*this).getBaseNameImpl(); } - auto getSuffixHashList() { - return static_cast<Implementation &>(*this).getSuffixHashListImpl(); + ParamsBytestream getParamsBytestream() const { + return static_cast<const Implementation &>(*this).getParamsBytestreamImpl(); } - bool getIsLast() const { - return static_cast<const Implementation &>(*this).getIsLastImpl(); + ParamsRTC getParamsRTC() const { + return static_cast<const Implementation &>(*this).getParamsRTCImpl(); } - ManifestVersion getVersion() const { - return static_cast<const Implementation &>(*this).getVersionImpl(); + auto getEntries() const { + return static_cast<const Implementation &>(*this).getEntriesImpl(); } - std::size_t estimateSerializedLength(std::size_t number_of_entries) const { - return static_cast<const Implementation &>(*this) - .estimateSerializedLengthImpl(number_of_entries); + std::size_t manifestHeaderSize() const { + return static_cast<const Implementation &>(*this).manifestHeaderSizeImpl(); } - ParamsBytestream getParamsBytestream() const { - return static_cast<const Implementation &>(*this).getParamsBytestreamImpl(); + std::size_t manifestPayloadSize(size_t additional_entries = 0) const { + return static_cast<const Implementation &>(*this).manifestPayloadSizeImpl( + additional_entries); } - ParamsRTC getParamsRTC() const { - return static_cast<const Implementation &>(*this).getParamsRTCImpl(); + std::size_t manifestSize(size_t additional_entries = 0) const { + return static_cast<const Implementation &>(*this).manifestSizeImpl( + additional_entries); } }; } // namespace core - } // namespace transport diff --git a/libtransport/src/core/manifest_format_fixed.cc b/libtransport/src/core/manifest_format_fixed.cc index 428d6ad12..668169642 100644 --- a/libtransport/src/core/manifest_format_fixed.cc +++ b/libtransport/src/core/manifest_format_fixed.cc @@ -18,22 +18,42 @@ #include <hicn/transport/utils/literals.h> namespace transport { - namespace core { -// TODO use preallocated pool of membufs -FixedManifestEncoder::FixedManifestEncoder(Packet &packet, - std::size_t signature_size, - bool clear) +// --------------------------------------------------------- +// FixedManifest +// --------------------------------------------------------- +size_t FixedManifest::manifestHeaderSize( + interface::ProductionProtocolAlgorithms transport_type) { + uint32_t params_size = 0; + + switch (transport_type) { + case interface::ProductionProtocolAlgorithms::BYTE_STREAM: + params_size = MANIFEST_PARAMS_BYTESTREAM_SIZE; + break; + case interface::ProductionProtocolAlgorithms::RTC_PROD: + params_size = MANIFEST_PARAMS_RTC_SIZE; + break; + default: + break; + } + + return MANIFEST_META_SIZE + MANIFEST_ENTRY_META_SIZE + params_size; +} + +size_t FixedManifest::manifestPayloadSize(size_t nb_entries) { + return nb_entries * MANIFEST_ENTRY_SIZE; +} + +// --------------------------------------------------------- +// FixedManifestEncoder +// --------------------------------------------------------- +FixedManifestEncoder::FixedManifestEncoder(Packet::Ptr packet, bool clear) : packet_(packet), - max_size_(Packet::default_mtu - packet_.headerSize()), - signature_size_(signature_size), transport_type_(interface::ProductionProtocolAlgorithms::UNKNOWN), - encoded_(false), - params_bytestream_({0}), - params_rtc_({0}) { - manifest_meta_ = reinterpret_cast<ManifestMeta *>(packet_.writableData() + - packet_.headerSize()); + encoded_(false) { + manifest_meta_ = reinterpret_cast<ManifestMeta *>(packet_->writableData() + + packet_->headerSize()); manifest_entry_meta_ = reinterpret_cast<ManifestEntryMeta *>(manifest_meta_ + 1); @@ -50,32 +70,34 @@ FixedManifestEncoder &FixedManifestEncoder::encodeImpl() { return *this; } + // Copy manifest header manifest_meta_->transport_type = static_cast<uint8_t>(transport_type_); manifest_entry_meta_->nb_entries = manifest_entries_.size(); - packet_.append(FixedManifestEncoder::manifestHeaderSizeImpl()); - packet_.updateLength(); + packet_->append(manifestHeaderSizeImpl()); + packet_->updateLength(); + auto params = reinterpret_cast<uint8_t *>(manifest_entry_meta_ + 1); switch (transport_type_) { - case interface::ProductionProtocolAlgorithms::BYTE_STREAM: - packet_.appendPayload( - reinterpret_cast<const uint8_t *>(¶ms_bytestream_), - MANIFEST_PARAMS_BYTESTREAM_SIZE); + case interface::ProductionProtocolAlgorithms::BYTE_STREAM: { + auto bytestream = reinterpret_cast<const uint8_t *>(¶ms_bytestream_); + std::memcpy(params, bytestream, MANIFEST_PARAMS_BYTESTREAM_SIZE); break; - case interface::ProductionProtocolAlgorithms::RTC_PROD: - packet_.appendPayload(reinterpret_cast<const uint8_t *>(¶ms_rtc_), - MANIFEST_PARAMS_RTC_SIZE); + } + case interface::ProductionProtocolAlgorithms::RTC_PROD: { + auto rtc = reinterpret_cast<const uint8_t *>(¶ms_rtc_); + std::memcpy(params, rtc, MANIFEST_PARAMS_RTC_SIZE); break; + } default: break; } - packet_.appendPayload( - reinterpret_cast<const uint8_t *>(manifest_entries_.data()), - manifest_entries_.size() * FixedManifestEncoder::manifestEntrySizeImpl()); + // Copy manifest entries + auto payload = reinterpret_cast<const uint8_t *>(manifest_entries_.data()); + packet_->appendPayload(payload, manifestPayloadSizeImpl()); - if (TRANSPORT_EXPECT_FALSE(packet_.payloadSize() < - estimateSerializedLengthImpl())) { + if (TRANSPORT_EXPECT_FALSE(packet_->payloadSize() < manifestSizeImpl())) { throw errors::RuntimeException("Error encoding the manifest"); } @@ -85,32 +107,21 @@ FixedManifestEncoder &FixedManifestEncoder::encodeImpl() { FixedManifestEncoder &FixedManifestEncoder::clearImpl() { if (encoded_) { - packet_.trimEnd(FixedManifestEncoder::manifestHeaderSizeImpl() + - manifest_entries_.size() * - FixedManifestEncoder::manifestEntrySizeImpl()); + packet_->trimEnd(manifestSizeImpl()); } transport_type_ = interface::ProductionProtocolAlgorithms::UNKNOWN; encoded_ = false; - params_bytestream_ = {0}; - params_rtc_ = {0}; *manifest_meta_ = {0}; *manifest_entry_meta_ = {0}; + params_bytestream_ = {0}; + params_rtc_ = {0}; manifest_entries_.clear(); return *this; } -FixedManifestEncoder &FixedManifestEncoder::updateImpl() { - max_size_ = Packet::default_mtu - packet_.headerSize() - signature_size_; - return *this; -} - -FixedManifestEncoder &FixedManifestEncoder::setVersionImpl( - ManifestVersion version) { - manifest_meta_->version = static_cast<uint8_t>(version); - return *this; -} +bool FixedManifestEncoder::isEncodedImpl() const { return encoded_; } FixedManifestEncoder &FixedManifestEncoder::setTypeImpl( ManifestType manifest_type) { @@ -118,6 +129,12 @@ FixedManifestEncoder &FixedManifestEncoder::setTypeImpl( return *this; } +FixedManifestEncoder &FixedManifestEncoder::setMaxCapacityImpl( + uint8_t max_capacity) { + manifest_meta_->max_capacity = max_capacity; + return *this; +} + FixedManifestEncoder &FixedManifestEncoder::setHashAlgorithmImpl( auth::CryptoHashType algorithm) { manifest_meta_->hash_algorithm = static_cast<uint8_t>(algorithm); @@ -159,61 +176,68 @@ FixedManifestEncoder &FixedManifestEncoder::setParamsRTCImpl( return *this; } -FixedManifestEncoder &FixedManifestEncoder::addSuffixAndHashImpl( +FixedManifestEncoder &FixedManifestEncoder::addEntryImpl( uint32_t suffix, const auth::CryptoHash &hash) { - manifest_entries_.push_back(ManifestEntry{ - .suffix = htonl(suffix), + ManifestEntry last_entry = { + .suffix = portability::host_to_net(suffix), .hash = {0}, - }); + }; - std::memcpy(reinterpret_cast<uint8_t *>(manifest_entries_.back().hash), - hash.getDigest()->data(), hash.getSize()); - - if (TRANSPORT_EXPECT_FALSE(estimateSerializedLengthImpl() > max_size_)) { - throw errors::RuntimeException("Manifest size exceeded the packet MTU!"); - } + auto last_hash = reinterpret_cast<uint8_t *>(last_entry.hash); + std::memcpy(last_hash, hash.getDigest()->data(), hash.getSize()); + manifest_entries_.push_back(last_entry); return *this; } -std::size_t FixedManifestEncoder::estimateSerializedLengthImpl( - std::size_t additional_entries) { - return FixedManifestEncoder::manifestHeaderSizeImpl(transport_type_) + - (manifest_entries_.size() + additional_entries) * - FixedManifestEncoder::manifestEntrySizeImpl(); +FixedManifestEncoder &FixedManifestEncoder::removeEntryImpl(uint32_t suffix) { + for (auto it = manifest_entries_.begin(); it != manifest_entries_.end();) { + if (it->suffix == suffix) + it = manifest_entries_.erase(it); + else + ++it; + } + return *this; } -std::size_t FixedManifestEncoder::manifestHeaderSizeImpl( - interface::ProductionProtocolAlgorithms transport_type) { - uint32_t params_size = 0; - - switch (transport_type) { - case interface::ProductionProtocolAlgorithms::BYTE_STREAM: - params_size = MANIFEST_PARAMS_BYTESTREAM_SIZE; - break; - case interface::ProductionProtocolAlgorithms::RTC_PROD: - params_size = MANIFEST_PARAMS_RTC_SIZE; - break; - default: - break; - } +size_t FixedManifestEncoder::manifestHeaderSizeImpl() const { + return FixedManifest::manifestHeaderSize(transport_type_); +} - return MANIFEST_META_SIZE + MANIFEST_ENTRY_META_SIZE + params_size; +size_t FixedManifestEncoder::manifestPayloadSizeImpl( + size_t additional_entries) const { + return FixedManifest::manifestPayloadSize(manifest_entries_.size() + + additional_entries); } -std::size_t FixedManifestEncoder::manifestEntrySizeImpl() { - return MANIFEST_ENTRY_SIZE; +size_t FixedManifestEncoder::manifestSizeImpl(size_t additional_entries) const { + return manifestHeaderSizeImpl() + manifestPayloadSizeImpl(additional_entries); } -FixedManifestDecoder::FixedManifestDecoder(Packet &packet) +// --------------------------------------------------------- +// FixedManifestDecoder +// --------------------------------------------------------- +FixedManifestDecoder::FixedManifestDecoder(Packet::Ptr packet) : packet_(packet), decoded_(false) { manifest_meta_ = - reinterpret_cast<ManifestMeta *>(packet_.getPayload()->writableData()); + reinterpret_cast<ManifestMeta *>(packet_->getPayload()->writableData()); manifest_entry_meta_ = reinterpret_cast<ManifestEntryMeta *>(manifest_meta_ + 1); - transport_type_ = getTransportTypeImpl(); +} - switch (transport_type_) { +FixedManifestDecoder::~FixedManifestDecoder() {} + +FixedManifestDecoder &FixedManifestDecoder::decodeImpl() { + if (decoded_) { + return *this; + } + + if (packet_->payloadSize() < manifestSizeImpl()) { + throw errors::RuntimeException( + "The packet payload size does not match expected manifest size"); + } + + switch (getTransportTypeImpl()) { case interface::ProductionProtocolAlgorithms::BYTE_STREAM: params_bytestream_ = reinterpret_cast<TransportParamsBytestream *>( manifest_entry_meta_ + 1); @@ -230,25 +254,9 @@ FixedManifestDecoder::FixedManifestDecoder(Packet &packet) reinterpret_cast<ManifestEntry *>(manifest_entry_meta_ + 1); break; } -} - -FixedManifestDecoder::~FixedManifestDecoder() {} - -void FixedManifestDecoder::decodeImpl() { - if (decoded_) { - return; - } - - std::size_t packet_size = packet_.payloadSize(); - - if (packet_size < - FixedManifestEncoder::manifestHeaderSizeImpl(transport_type_) || - packet_size < estimateSerializedLengthImpl()) { - throw errors::RuntimeException( - "The packet does not match expected manifest size."); - } decoded_ = true; + return *this; } FixedManifestDecoder &FixedManifestDecoder::clearImpl() { @@ -256,20 +264,22 @@ FixedManifestDecoder &FixedManifestDecoder::clearImpl() { return *this; } +bool FixedManifestDecoder::isDecodedImpl() const { return decoded_; } + ManifestType FixedManifestDecoder::getTypeImpl() const { return static_cast<ManifestType>(manifest_meta_->type); } -ManifestVersion FixedManifestDecoder::getVersionImpl() const { - return static_cast<ManifestVersion>(manifest_meta_->version); -} - interface::ProductionProtocolAlgorithms FixedManifestDecoder::getTransportTypeImpl() const { return static_cast<interface::ProductionProtocolAlgorithms>( manifest_meta_->transport_type); } +uint8_t FixedManifestDecoder::getMaxCapacityImpl() const { + return manifest_meta_->max_capacity; +} + auth::CryptoHashType FixedManifestDecoder::getHashAlgorithmImpl() const { return static_cast<auth::CryptoHashType>(manifest_meta_->hash_algorithm); } @@ -303,26 +313,34 @@ ParamsRTC FixedManifestDecoder::getParamsRTCImpl() const { }; } -typename Fixed::SuffixList FixedManifestDecoder::getSuffixHashListImpl() { +typename Fixed::SuffixList FixedManifestDecoder::getEntriesImpl() const { typename Fixed::SuffixList hash_list; for (int i = 0; i < manifest_entry_meta_->nb_entries; i++) { - hash_list.insert(hash_list.end(), - std::make_pair(ntohl(manifest_entries_[i].suffix), - reinterpret_cast<uint8_t *>( - &manifest_entries_[i].hash[0]))); + hash_list.insert( + hash_list.end(), + std::make_pair( + portability::net_to_host(manifest_entries_[i].suffix), + reinterpret_cast<uint8_t *>(&manifest_entries_[i].hash[0]))); } return hash_list; } -std::size_t FixedManifestDecoder::estimateSerializedLengthImpl( - std::size_t additional_entries) const { - return FixedManifestEncoder::manifestHeaderSizeImpl(transport_type_) + - (manifest_entry_meta_->nb_entries + additional_entries) * - FixedManifestEncoder::manifestEntrySizeImpl(); +size_t FixedManifestDecoder::manifestHeaderSizeImpl() const { + interface::ProductionProtocolAlgorithms type = getTransportTypeImpl(); + return FixedManifest::manifestHeaderSize(type); } -} // end namespace core +size_t FixedManifestDecoder::manifestPayloadSizeImpl( + size_t additional_entries) const { + size_t nb_entries = manifest_entry_meta_->nb_entries + additional_entries; + return FixedManifest::manifestPayloadSize(nb_entries); +} +size_t FixedManifestDecoder::manifestSizeImpl(size_t additional_entries) const { + return manifestHeaderSizeImpl() + manifestPayloadSizeImpl(additional_entries); +} + +} // end namespace core } // end namespace transport diff --git a/libtransport/src/core/manifest_format_fixed.h b/libtransport/src/core/manifest_format_fixed.h index 5fd2a673d..7ab371974 100644 --- a/libtransport/src/core/manifest_format_fixed.h +++ b/libtransport/src/core/manifest_format_fixed.h @@ -28,7 +28,7 @@ namespace core { // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// |Version| Type | Transport Type| Hash Algorithm|L| Reserved | +// | Type | TTYpe | Max Capacity | Hash Algo |L| Reserved | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // Manifest Entry Metadata: @@ -106,9 +106,9 @@ struct Fixed { const size_t MANIFEST_META_SIZE = 4; struct __attribute__((__packed__)) ManifestMeta { - std::uint8_t version : 4; std::uint8_t type : 4; - std::uint8_t transport_type; + std::uint8_t transport_type : 4; + std::uint8_t max_capacity; std::uint8_t hash_algorithm; std::uint8_t is_last; }; @@ -146,22 +146,26 @@ struct __attribute__((__packed__)) ManifestEntry { }; static_assert(sizeof(ManifestEntry) == MANIFEST_ENTRY_SIZE); -static const constexpr std::uint8_t manifest_version = 1; +class FixedManifest { + public: + static size_t manifestHeaderSize( + interface::ProductionProtocolAlgorithms transport_type); + static size_t manifestPayloadSize(size_t nb_entries); +}; class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { public: - FixedManifestEncoder(Packet &packet, std::size_t signature_size = 0, - bool clear = true); + FixedManifestEncoder(Packet::Ptr packet, bool clear = false); ~FixedManifestEncoder(); FixedManifestEncoder &encodeImpl(); FixedManifestEncoder &clearImpl(); - FixedManifestEncoder &updateImpl(); + bool isEncodedImpl() const; // ManifestMeta - FixedManifestEncoder &setVersionImpl(ManifestVersion version); FixedManifestEncoder &setTypeImpl(ManifestType manifest_type); + FixedManifestEncoder &setMaxCapacityImpl(uint8_t max_capacity); FixedManifestEncoder &setHashAlgorithmImpl(Fixed::HashType algorithm); FixedManifestEncoder &setIsLastImpl(bool is_last); @@ -173,20 +177,15 @@ class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { FixedManifestEncoder &setParamsRTCImpl(const ParamsRTC ¶ms); // ManifestEntry - FixedManifestEncoder &addSuffixAndHashImpl(uint32_t suffix, - const Fixed::Hash &hash); + FixedManifestEncoder &addEntryImpl(uint32_t suffix, const Fixed::Hash &hash); + FixedManifestEncoder &removeEntryImpl(uint32_t suffix); - std::size_t estimateSerializedLengthImpl(std::size_t additional_entries = 0); - - static std::size_t manifestHeaderSizeImpl( - interface::ProductionProtocolAlgorithms transport_type = - interface::ProductionProtocolAlgorithms::UNKNOWN); - static std::size_t manifestEntrySizeImpl(); + size_t manifestHeaderSizeImpl() const; + size_t manifestPayloadSizeImpl(size_t additional_entries = 0) const; + size_t manifestSizeImpl(size_t additional_entries = 0) const; private: - Packet &packet_; - std::size_t max_size_; - std::size_t signature_size_; + Packet::Ptr packet_; interface::ProductionProtocolAlgorithms transport_type_; bool encoded_; @@ -202,17 +201,18 @@ class FixedManifestEncoder : public ManifestEncoder<FixedManifestEncoder> { class FixedManifestDecoder : public ManifestDecoder<FixedManifestDecoder> { public: - FixedManifestDecoder(Packet &packet); + FixedManifestDecoder(Packet::Ptr packet); ~FixedManifestDecoder(); - void decodeImpl(); + FixedManifestDecoder &decodeImpl(); FixedManifestDecoder &clearImpl(); + bool isDecodedImpl() const; // ManifestMeta - ManifestVersion getVersionImpl() const; ManifestType getTypeImpl() const; interface::ProductionProtocolAlgorithms getTransportTypeImpl() const; + uint8_t getMaxCapacityImpl() const; Fixed::HashType getHashAlgorithmImpl() const; bool getIsLastImpl() const; @@ -224,14 +224,14 @@ class FixedManifestDecoder : public ManifestDecoder<FixedManifestDecoder> { ParamsRTC getParamsRTCImpl() const; // ManifestEntry - typename Fixed::SuffixList getSuffixHashListImpl(); + typename Fixed::SuffixList getEntriesImpl() const; - std::size_t estimateSerializedLengthImpl( - std::size_t additional_entries = 0) const; + size_t manifestHeaderSizeImpl() const; + size_t manifestPayloadSizeImpl(size_t additional_entries = 0) const; + size_t manifestSizeImpl(size_t additional_entries = 0) const; private: - Packet &packet_; - interface::ProductionProtocolAlgorithms transport_type_; + Packet::Ptr packet_; bool decoded_; // Manifest Header diff --git a/libtransport/src/core/manifest_inline.h b/libtransport/src/core/manifest_inline.h deleted file mode 100644 index ca48a4a79..000000000 --- a/libtransport/src/core/manifest_inline.h +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright (c) 2021 Cisco and/or its affiliates. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include <core/manifest.h> -#include <core/manifest_format.h> -#include <hicn/transport/portability/portability.h> - -#include <set> - -namespace transport { - -namespace core { - -template <typename Base, typename FormatTraits> -class ManifestInline - : public Manifest<Base, FormatTraits, ManifestInline<Base, FormatTraits>> { - using ManifestBase = - Manifest<Base, FormatTraits, ManifestInline<Base, FormatTraits>>; - - using Hash = typename FormatTraits::Hash; - using HashType = typename FormatTraits::HashType; - using Suffix = typename FormatTraits::Suffix; - using SuffixList = typename FormatTraits::SuffixList; - using HashEntry = std::pair<auth::CryptoHashType, std::vector<uint8_t>>; - - public: - ManifestInline() : ManifestBase() {} - - ManifestInline(Packet::Format format, const core::Name &name, - std::size_t signature_size = 0) - : ManifestBase(format, name, signature_size) {} - - template <typename T> - ManifestInline(T &&base) : ManifestBase(std::forward<T &&>(base)) {} - - template <typename T> - ManifestInline(T &base) : ManifestBase(base) {} - - static TRANSPORT_ALWAYS_INLINE ManifestInline *createManifest( - Packet::Format format, const core::Name &manifest_name, - ManifestVersion version, ManifestType type, bool is_last, - const Name &base_name, HashType hash_algo, std::size_t signature_size) { - auto manifest = new ManifestInline(format, manifest_name, signature_size); - manifest->setVersion(version); - manifest->setType(type); - manifest->setHashAlgorithm(hash_algo); - manifest->setIsLast(is_last); - manifest->setBaseName(base_name); - return manifest; - } - - ManifestInline &encodeImpl() { - ManifestBase::encoder_.encode(); - return *this; - } - - ManifestInline &decodeImpl() { - base_name_ = ManifestBase::decoder_.getBaseName(); - suffix_hash_map_ = ManifestBase::decoder_.getSuffixHashList(); - - return *this; - } - - std::size_t estimateManifestSizeImpl(std::size_t additional_entries = 0) { - return ManifestBase::encoder_.estimateSerializedLength(additional_entries); - } - - ManifestInline &setBaseName(const Name &name) { - base_name_ = name; - ManifestBase::encoder_.setBaseName(base_name_); - return *this; - } - - const Name &getBaseName() { return base_name_; } - - ManifestInline &addSuffixHash(Suffix suffix, const Hash &hash) { - ManifestBase::encoder_.addSuffixAndHash(suffix, hash); - return *this; - } - - // Call this function only after the decode function! - const SuffixList &getSuffixList() { return suffix_hash_map_; } - - // Convert several manifests into a single map from suffixes to packet hashes. - // All manifests must have been decoded beforehand. - static std::unordered_map<Suffix, Hash> getSuffixMap( - const std::vector<ManifestInline *> &manifests) { - std::unordered_map<Suffix, Hash> suffix_map; - - for (auto manifest_ptr : manifests) { - HashType hash_type = manifest_ptr->getHashAlgorithm(); - SuffixList suffix_list = manifest_ptr->getSuffixList(); - - for (auto it = suffix_list.begin(); it != suffix_list.end(); ++it) { - Hash hash(it->second, Hash::getSize(hash_type), hash_type); - suffix_map[it->first] = hash; - } - } - - return suffix_map; - } - - static std::unordered_map<Suffix, Hash> getSuffixMap( - ManifestInline *manifest) { - return getSuffixMap(std::vector<ManifestInline *>{manifest}); - } - - private: - core::Name base_name_; - SuffixList suffix_hash_map_; -}; - -} // namespace core -} // namespace transport diff --git a/libtransport/src/core/name.cc b/libtransport/src/core/name.cc index 98091eea5..960947cb9 100644 --- a/libtransport/src/core/name.cc +++ b/libtransport/src/core/name.cc @@ -24,7 +24,7 @@ namespace transport { namespace core { -Name::Name() { name_ = {}; } +Name::Name() { std::memset(&name_, 0, sizeof(name_)); } /** * XXX This function does not use the name API provided by libhicn @@ -47,6 +47,7 @@ Name::Name(int family, const uint8_t *ip_address, std::uint32_t suffix) std::memcpy(dst, ip_address, length); name_.suffix = suffix; } + Name::Name(const char *name, uint32_t segment) { if (hicn_name_create(name, segment, &name_) < 0) { throw errors::InvalidIpAddressException(); diff --git a/libtransport/src/core/pending_interest.h b/libtransport/src/core/pending_interest.h index f8a4ba10e..fb10405d3 100644 --- a/libtransport/src/core/pending_interest.h +++ b/libtransport/src/core/pending_interest.h @@ -42,17 +42,9 @@ class PendingInterest { public: using Ptr = utils::ObjectPool<PendingInterest>::Ptr; - // PendingInterest() - // : interest_(nullptr, nullptr), - // timer_(), - // on_content_object_callback_(), - // on_interest_timeout_callback_() {} PendingInterest(asio::io_service &io_service, const Interest::Ptr &interest) - : interest_(interest), - timer_(io_service), - on_content_object_callback_(), - on_interest_timeout_callback_() {} + : interest_(interest), timer_(io_service) {} PendingInterest(asio::io_service &io_service, const Interest::Ptr &interest, OnContentObjectCallback &&on_content_object, @@ -65,10 +57,9 @@ class PendingInterest { ~PendingInterest() = default; template <typename Handler> - TRANSPORT_ALWAYS_INLINE void startCountdown(Handler &&cb) { - timer_.expires_from_now( - std::chrono::milliseconds(interest_->getLifetime())); - timer_.async_wait(std::forward<Handler &&>(cb)); + TRANSPORT_ALWAYS_INLINE void startCountdown(uint32_t lifetime, Handler &&cb) { + timer_.expires_from_now(std::chrono::milliseconds(lifetime)); + timer_.async_wait(std::forward<Handler>(cb)); } TRANSPORT_ALWAYS_INLINE void cancelTimer() { timer_.cancel(); } @@ -77,7 +68,7 @@ class PendingInterest { return std::move(interest_); } - TRANSPORT_ALWAYS_INLINE void setInterest(Interest::Ptr &interest) { + TRANSPORT_ALWAYS_INLINE void setInterest(const Interest::Ptr &interest) { interest_ = interest; } @@ -88,7 +79,7 @@ class PendingInterest { TRANSPORT_ALWAYS_INLINE void setOnContentObjectCallback( OnContentObjectCallback &&on_content_object) { - PendingInterest::on_content_object_callback_ = on_content_object; + PendingInterest::on_content_object_callback_ = std::move(on_content_object); } TRANSPORT_ALWAYS_INLINE const OnInterestTimeoutCallback & @@ -98,7 +89,8 @@ class PendingInterest { TRANSPORT_ALWAYS_INLINE void setOnTimeoutCallback( OnInterestTimeoutCallback &&on_interest_timeout) { - PendingInterest::on_interest_timeout_callback_ = on_interest_timeout; + PendingInterest::on_interest_timeout_callback_ = + std::move(on_interest_timeout); } private: diff --git a/libtransport/src/core/portal.cc b/libtransport/src/core/portal.cc index d8e8d78ea..c06969f19 100644 --- a/libtransport/src/core/portal.cc +++ b/libtransport/src/core/portal.cc @@ -43,12 +43,14 @@ std::string Portal::io_module_path_ = defaultIoModule(); std::string Portal::defaultIoModule() { using namespace std::placeholders; GlobalConfiguration::getInstance().registerConfigurationParser( - io_module_section, + IoModuleConfiguration::section, std::bind(&Portal::parseIoModuleConfiguration, _1, _2)); GlobalConfiguration::getInstance().registerConfigurationGetter( - io_module_section, std::bind(&Portal::getModuleConfiguration, _1, _2)); + IoModuleConfiguration::section, + std::bind(&Portal::getModuleConfiguration, _1, _2)); GlobalConfiguration::getInstance().registerConfigurationSetter( - io_module_section, std::bind(&Portal::setModuleConfiguration, _1, _2)); + IoModuleConfiguration::section, + std::bind(&Portal::setModuleConfiguration, _1, _2)); // return default conf_.name = default_module; @@ -57,7 +59,7 @@ std::string Portal::defaultIoModule() { void Portal::getModuleConfiguration(ConfigurationObject& object, std::error_code& ec) { - DCHECK(object.getKey() == io_module_section); + DCHECK(object.getKey() == IoModuleConfiguration::section); auto conf = dynamic_cast<const IoModuleConfiguration&>(object); conf = conf_; @@ -103,7 +105,7 @@ std::string getIoModulePath(const std::string& name, void Portal::setModuleConfiguration(const ConfigurationObject& object, std::error_code& ec) { - DCHECK(object.getKey() == io_module_section); + DCHECK(object.getKey() == IoModuleConfiguration::section); const IoModuleConfiguration& conf = dynamic_cast<const IoModuleConfiguration&>(object); diff --git a/libtransport/src/core/portal.h b/libtransport/src/core/portal.h index aae4c573e..6f3a48e83 100644 --- a/libtransport/src/core/portal.h +++ b/libtransport/src/core/portal.h @@ -32,6 +32,10 @@ #include <hicn/transport/utils/event_thread.h> #include <hicn/transport/utils/fixed_block_allocator.h> +extern "C" { +#include <hicn/header.h> +} + #include <future> #include <memory> #include <queue> @@ -179,19 +183,11 @@ class Portal : public ::utils::NonCopyable, Portal() : Portal(GlobalWorkers::getInstance().getWorker()) {} Portal(::utils::EventThread &worker) - : io_module_(nullptr, [](IoModule *module) { IoModule::unload(module); }), + : io_module_(nullptr), worker_(worker), app_name_("libtransport_application"), transport_callback_(nullptr), - is_consumer_(false) { - /** - * This workaroung allows to initialize memory for packet buffers *before* - * any static variables that may be initialized in the io_modules. In this - * way static variables in modules will be destroyed before the packet - * memory. - */ - PacketManager<>::getInstance(); - } + is_consumer_(false) {} public: using TransportCallback = interface::Portal::TransportCallback; @@ -275,6 +271,7 @@ class Portal : public ::utils::NonCopyable, ptr->transport_callback_->onError(ec); } }, + [self]([[maybe_unused]] Connector *c) { /* Nothing to do here */ }, [self](Connector *c, const std::error_code &ec) { auto ptr = self.lock(); if (ptr) { @@ -315,76 +312,113 @@ class Portal : public ::utils::NonCopyable, } /** - * Send an interest through to the local forwarder. - * - * @param interest - The pointer to the interest. The ownership of the - * interest is transferred by the caller to portal. - * - * @param on_content_object_callback - If the caller wishes to use a - * different callback to be called for this interest, it can set this - * parameter. Otherwise ConsumerCallback::onContentObject will be used. + * @brief Add interest to PIT * - * @param on_interest_timeout_callback - If the caller wishes to use a - * different callback to be called for this interest, it can set this - * parameter. Otherwise ConsumerCallback::onTimeout will be used. */ - void sendInterest( - Interest::Ptr &&interest, + void addInterestToPIT( + const Interest::Ptr &interest, uint32_t lifetime, OnContentObjectCallback &&on_content_object_callback = UNSET_CALLBACK, OnInterestTimeoutCallback &&on_interest_timeout_callback = UNSET_CALLBACK) { - DCHECK(std::this_thread::get_id() == worker_.getThreadId()); - - // Send it - interest->encodeSuffixes(); - io_module_->send(*interest); - uint32_t initial_hash = interest->getName().getHash32(false); auto hash = initial_hash + interest->getName().getSuffix(); uint32_t seq = interest->getName().getSuffix(); - uint32_t *suffix = interest->firstSuffix(); + const uint32_t *suffix = interest->firstSuffix(); auto n_suffixes = interest->numberOfSuffixes(); uint32_t counter = 0; // Set timers do { + auto pend_int = pending_interest_hash_table_.try_emplace( + hash, worker_.getIoService(), interest); + PendingInterest &pending_interest = pend_int.first->second; + if (!pend_int.second) { + // element was already in map + pend_int.first->second.cancelTimer(); + pending_interest.setInterest(interest); + } + + pending_interest.setOnContentObjectCallback( + std::move(on_content_object_callback)); + pending_interest.setOnTimeoutCallback( + std::move(on_interest_timeout_callback)); + + if (is_consumer_) { + auto self = weak_from_this(); + pending_interest.startCountdown( + lifetime, portal_details::makeCustomAllocatorHandler( + async_callback_memory_, + [self, hash, seq](const std::error_code &ec) { + if (TRANSPORT_EXPECT_FALSE(ec.operator bool())) { + return; + } + + if (auto ptr = self.lock()) { + ptr->timerHandler(hash, seq); + } + })); + } + if (suffix) { hash = initial_hash + *suffix; seq = *suffix; suffix++; } + } while (counter++ < n_suffixes); + } - auto it = pending_interest_hash_table_.find(hash); - PendingInterest *pending_interest = nullptr; - if (it != pending_interest_hash_table_.end()) { - it->second.cancelTimer(); - pending_interest = &it->second; - pending_interest->setInterest(interest); - } else { - auto pend_int = pending_interest_hash_table_.try_emplace( - hash, worker_.getIoService(), interest); - pending_interest = &pend_int.first->second; - } + void matchContentObjectInPIT(ContentObject &content_object) { + uint32_t hash = getHash(content_object.getName()); + auto it = pending_interest_hash_table_.find(hash); + if (it != pending_interest_hash_table_.end()) { + DLOG_IF(INFO, VLOG_IS_ON(3)) << "Found pending interest."; - pending_interest->setOnContentObjectCallback( - std::move(on_content_object_callback)); - pending_interest->setOnTimeoutCallback( - std::move(on_interest_timeout_callback)); + PendingInterest &pend_interest = it->second; + pend_interest.cancelTimer(); + auto _int = pend_interest.getInterest(); + auto callback = pend_interest.getOnDataCallback(); + pending_interest_hash_table_.erase(it); - auto self = weak_from_this(); - pending_interest->startCountdown( - portal_details::makeCustomAllocatorHandler( - async_callback_memory_, - [self, hash, seq](const std::error_code &ec) { - if (TRANSPORT_EXPECT_FALSE(ec.operator bool())) { - return; - } + if (is_consumer_) { + // Send object is for the app + if (callback != UNSET_CALLBACK) { + callback(*_int, content_object); + } else if (transport_callback_) { + transport_callback_->onContentObject(*_int, content_object); + } + } else { + // Send content object to the network + io_module_->send(content_object); + } + } else if (is_consumer_) { + DLOG_IF(INFO, VLOG_IS_ON(3)) + << "No interest pending for received content object."; + } + } - if (auto ptr = self.lock()) { - ptr->timerHandler(hash, seq); - } - })); + /** + * Send an interest through to the local forwarder. + * + * @param interest - The pointer to the interest. The ownership of the + * interest is transferred by the caller to portal. + * + * @param on_content_object_callback - If the caller wishes to use a + * different callback to be called for this interest, it can set this + * parameter. Otherwise ConsumerCallback::onContentObject will be used. + * + * @param on_interest_timeout_callback - If the caller wishes to use a + * different callback to be called for this interest, it can set this + * parameter. Otherwise ConsumerCallback::onTimeout will be used. + */ + void sendInterest( + Interest::Ptr &interest, uint32_t lifetime, + OnContentObjectCallback &&on_content_object_callback = UNSET_CALLBACK, + OnInterestTimeoutCallback &&on_interest_timeout_callback = + UNSET_CALLBACK) { + DCHECK(std::this_thread::get_id() == worker_.getThreadId()); - } while (counter++ < n_suffixes); + io_module_->send(*interest); + addInterestToPIT(interest, lifetime, std::move(on_content_object_callback), + std::move(on_interest_timeout_callback)); } /** @@ -423,8 +457,7 @@ class Portal : public ::utils::NonCopyable, void sendContentObject(ContentObject &content_object) { DCHECK(io_module_); DCHECK(std::this_thread::get_id() == worker_.getThreadId()); - - io_module_->send(content_object); + matchContentObjectInPIT(content_object); } /** @@ -582,6 +615,9 @@ class Portal : public ::utils::NonCopyable, void processInterest(Interest &interest) { // Interest for a producer DLOG_IF(INFO, VLOG_IS_ON(3)) << "processInterest " << interest.getName(); + + // Save interest in PIT + addInterestToPIT(interest.shared_from_this(), interest.getLifetime()); if (TRANSPORT_EXPECT_TRUE(transport_callback_ != nullptr)) { transport_callback_->onInterest(interest); } @@ -598,27 +634,7 @@ class Portal : public ::utils::NonCopyable, void processContentObject(ContentObject &content_object) { DLOG_IF(INFO, VLOG_IS_ON(3)) << "processContentObject " << content_object.getName(); - uint32_t hash = getHash(content_object.getName()); - - auto it = pending_interest_hash_table_.find(hash); - if (it != pending_interest_hash_table_.end()) { - DLOG_IF(INFO, VLOG_IS_ON(3)) << "Found pending interest."; - - PendingInterest &pend_interest = it->second; - pend_interest.cancelTimer(); - auto _int = pend_interest.getInterest(); - auto callback = pend_interest.getOnDataCallback(); - pending_interest_hash_table_.erase(it); - - if (callback != UNSET_CALLBACK) { - callback(*_int, content_object); - } else if (transport_callback_) { - transport_callback_->onContentObject(*_int, content_object); - } - } else { - DLOG_IF(INFO, VLOG_IS_ON(3)) - << "No interest pending for received content object."; - } + matchContentObjectInPIT(content_object); } /** @@ -632,7 +648,7 @@ class Portal : public ::utils::NonCopyable, private: portal_details::HandlerMemory async_callback_memory_; - std::unique_ptr<IoModule, void (*)(IoModule *)> io_module_; + std::unique_ptr<IoModule> io_module_; ::utils::EventThread &worker_; diff --git a/libtransport/src/core/prefix.cc b/libtransport/src/core/prefix.cc index 4c1e191e9..00748148f 100644 --- a/libtransport/src/core/prefix.cc +++ b/libtransport/src/core/prefix.cc @@ -13,8 +13,10 @@ * limitations under the License. */ +#include <glog/logging.h> #include <hicn/transport/core/prefix.h> #include <hicn/transport/errors/errors.h> +#include <hicn/transport/portability/endianess.h> #include <hicn/transport/utils/string_tokenizer.h> #ifndef _WIN32 @@ -37,10 +39,6 @@ namespace core { Prefix::Prefix() { std::memset(&ip_prefix_, 0, sizeof(ip_prefix_t)); } -Prefix::Prefix(const char *prefix) : Prefix(std::string(prefix)) {} - -Prefix::Prefix(std::string &&prefix) : Prefix(prefix) {} - Prefix::Prefix(const std::string &prefix) { utils::StringTokenizer st(prefix, "/"); @@ -56,7 +54,7 @@ Prefix::Prefix(const std::string &prefix) { buildPrefix(ip_address, uint16_t(atoi(prefix_length.c_str())), family); } -Prefix::Prefix(std::string &prefix, uint16_t prefix_length) { +Prefix::Prefix(const std::string &prefix, uint16_t prefix_length) { int family = get_addr_family(prefix.c_str()); buildPrefix(prefix, prefix_length, family); } @@ -73,12 +71,14 @@ Prefix::Prefix(const core::Name &content_name, uint16_t prefix_length) { ip_prefix_.family = family; } -void Prefix::buildPrefix(std::string &prefix, uint16_t prefix_length, +void Prefix::buildPrefix(const std::string &prefix, uint16_t prefix_length, int family) { if (!checkPrefixLengthAndAddressFamily(prefix_length, family)) { throw errors::InvalidIpAddressException(); } + std::memset(&ip_prefix_, 0, sizeof(ip_prefix_t)); + int ret; switch (family) { case AF_INET: @@ -131,62 +131,67 @@ std::unique_ptr<Sockaddr> Prefix::toSockaddr() const { uint16_t Prefix::getPrefixLength() const { return ip_prefix_.len; } Prefix &Prefix::setPrefixLength(uint16_t prefix_length) { + if (!checkPrefixLengthAndAddressFamily(prefix_length, ip_prefix_.family)) { + throw errors::InvalidIpAddressException(); + } + ip_prefix_.len = (u8)prefix_length; return *this; } int Prefix::getAddressFamily() const { return ip_prefix_.family; } -Prefix &Prefix::setAddressFamily(int address_family) { - ip_prefix_.family = address_family; - return *this; -} - std::string Prefix::getNetwork() const { if (!checkPrefixLengthAndAddressFamily(ip_prefix_.len, ip_prefix_.family)) { throw errors::InvalidIpAddressException(); } - std::size_t size = - ip_prefix_.family == 4 + AF_INET ? INET_ADDRSTRLEN : INET6_ADDRSTRLEN; - - std::string network(size, 0); + char buffer[INET6_ADDRSTRLEN]; - if (ip_prefix_ntop_short(&ip_prefix_, (char *)network.c_str(), size) < 0) { + if (ip_prefix_ntop_short(&ip_prefix_, buffer, INET6_ADDRSTRLEN) < 0) { throw errors::RuntimeException( "Impossible to retrieve network from ip address."); } - return network; + return buffer; } -int Prefix::contains(const ip_address_t &content_name) const { - int res = - ip_address_cmp(&content_name, &(ip_prefix_.address), ip_prefix_.family); +bool Prefix::contains(const ip_address_t &content_name) const { + uint64_t mask[2] = {0, 0}; + auto content_name_copy = content_name; + auto network_copy = ip_prefix_.address; - if (ip_prefix_.len != (ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN_BITS - : IPV4_ADDR_LEN_BITS)) { - const u8 *ip_prefix_buffer = - ip_address_get_buffer(&(ip_prefix_.address), ip_prefix_.family); - const u8 *content_name_buffer = - ip_address_get_buffer(&content_name, ip_prefix_.family); - uint8_t mask = 0xFF >> (ip_prefix_.len % 8); - mask = ~mask; + auto prefix_length = getPrefixLength(); + if (ip_prefix_.family == AF_INET) { + prefix_length += 3 * IPV4_ADDR_LEN_BITS; + } - res += (ip_prefix_buffer[ip_prefix_.len] & mask) == - (content_name_buffer[ip_prefix_.len] & mask); + if (prefix_length == 0) { + mask[0] = mask[1] = 0; + } else if (prefix_length <= 64) { + mask[0] = portability::host_to_net((uint64_t)(~0) << (64 - prefix_length)); + mask[1] = 0; + } else if (prefix_length == 128) { + mask[0] = mask[1] = 0xffffffffffffffff; + } else { + prefix_length -= 64; + mask[0] = 0xffffffffffffffff; + mask[1] = portability::host_to_net((uint64_t)(~0) << (64 - prefix_length)); } - return res; -} + // Apply mask + content_name_copy.v6.as_u64[0] &= mask[0]; + content_name_copy.v6.as_u64[1] &= mask[1]; -int Prefix::contains(const core::Name &content_name) const { - return contains(content_name.toIpAddress().address); + network_copy.v6.as_u64[0] &= mask[0]; + network_copy.v6.as_u64[1] &= mask[1]; + + return ip_address_cmp(&network_copy, &content_name_copy, ip_prefix_.family) == + 0; } -Name Prefix::getName() const { - std::string s(getNetwork()); - return Name(s); +bool Prefix::contains(const core::Name &content_name) const { + return contains(content_name.toIpAddress().address); } /* @@ -199,8 +204,8 @@ Name Prefix::getName(const core::Name &mask, const core::Name &components, ip_prefix_.family != components.getAddressFamily() || ip_prefix_.family != content_name.getAddressFamily()) throw errors::RuntimeException( - "Prefix, mask, components and content name are not of the same address " - "family"); + "Prefix, mask, components and content name are not of the same" + "address family"); ip_address_t mask_ip = mask.toIpAddress().address; ip_address_t component_ip = components.toIpAddress().address; @@ -222,32 +227,6 @@ Name Prefix::getName(const core::Name &mask, const core::Name &components, return Name(ip_prefix_.family, (uint8_t *)&name_ip); } -Name Prefix::getRandomName() const { - ip_address_t name_ip = ip_prefix_.address; - u8 *name_ip_buffer = - const_cast<u8 *>(ip_address_get_buffer(&name_ip, ip_prefix_.family)); - - int addr_len = - (ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN * 8 : IPV4_ADDR_LEN * 8) - - ip_prefix_.len; - - size_t size = (size_t)ceil((float)addr_len / 8.0); - uint8_t *buffer = (uint8_t *)malloc(sizeof(uint8_t) * size); - - RAND_bytes(buffer, (int)size); - - int j = 0; - for (uint8_t i = (uint8_t)ceil((float)ip_prefix_.len / 8.0); - i < (ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN : IPV4_ADDR_LEN); - i++) { - name_ip_buffer[i] = buffer[j]; - j++; - } - free(buffer); - - return Name(ip_prefix_.family, (uint8_t *)&name_ip); -} - /* * Map a name in a different name prefix to this name prefix */ @@ -276,47 +255,66 @@ Name Prefix::mapName(const core::Name &content_name) const { return Name(ip_prefix_.family, (uint8_t *)&name_ip); } -Prefix &Prefix::setNetwork(std::string &network) { - if (!inet_pton(AF_INET6, network.c_str(), ip_prefix_.address.v6.buffer)) { +Prefix &Prefix::setNetwork(const std::string &network) { + if (!ip_address_pton(network.c_str(), &ip_prefix_.address)) { throw errors::RuntimeException("The network name is not valid."); } return *this; } +Name Prefix::makeName() const { return makeNameWithIndex(0); } + Name Prefix::makeRandomName() const { - if (ip_prefix_.family == AF_INET6) { - std::default_random_engine eng((std::random_device())()); - std::uniform_int_distribution<uint32_t> idis( - 0, std::numeric_limits<uint32_t>::max()); - uint64_t random_number = idis(eng); - - uint32_t hash_size_bits = IPV6_ADDR_LEN_BITS - ip_prefix_.len; - uint64_t ip_address[2]; - memcpy(ip_address, ip_prefix_.address.v6.buffer, sizeof(uint64_t)); - memcpy(ip_address + 1, ip_prefix_.address.v6.buffer + 8, sizeof(uint64_t)); - std::string network(IPV6_ADDR_LEN * 3, 0); - - // Let's do the magic ;) - int shift_size = hash_size_bits > sizeof(random_number) * 8 - ? sizeof(random_number) * 8 - : hash_size_bits; - - ip_address[1] >>= shift_size; - ip_address[1] <<= shift_size; - - ip_address[1] |= random_number >> (sizeof(uint64_t) * 8 - shift_size); - - if (!inet_ntop(ip_prefix_.family, ip_address, (char *)network.c_str(), - IPV6_ADDR_LEN * 3)) { - throw errors::RuntimeException( - "Impossible to retrieve network from ip address."); - } + std::default_random_engine eng((std::random_device())()); + std::uniform_int_distribution<uint32_t> idis( + 0, std::numeric_limits<uint32_t>::max()); + uint64_t random_number = idis(eng); + + return makeNameWithIndex(random_number); +} + +Name Prefix::makeNameWithIndex(std::uint64_t index) const { + uint16_t prefix_length = getPrefixLength(); - return Name(network); + Name ret; + + // Adjust prefix length depending on the address family + if (getAddressFamily() == AF_INET) { + // Sanity check + DCHECK(prefix_length <= 32); + // Convert prefix length to ip46_address_t prefix length + prefix_length += IPV4_ADDR_LEN_BITS * 3; + } + + std::memcpy(ret.getStructReference().prefix.v6.as_u8, + ip_prefix_.address.v6.as_u8, sizeof(ip_address_t)); + + // Convert index in network byte order + index = portability::host_to_net(index); + + // Apply mask + uint64_t mask; + if (prefix_length == 0) { + mask = 0; + } else if (prefix_length <= 64) { + mask = 0; + } else if (prefix_length == 128) { + mask = 0xffffffffffffffff; + } else { + prefix_length -= 64; + mask = portability::host_to_net((uint64_t)(~0) << (64 - prefix_length)); } - return Name(); + ret.getStructReference().prefix.v6.as_u64[1] &= mask; + // Eventually truncate index if too big + index &= ~mask; + + // Apply index + ret.getStructReference().prefix.v6.as_u64[1] |= index; + + // Done + return ret; } bool Prefix::checkPrefixLengthAndAddressFamily(uint16_t prefix_length, diff --git a/libtransport/src/core/udp_connector.cc b/libtransport/src/core/udp_connector.cc index ee0c7ea9c..5d8e76bb1 100644 --- a/libtransport/src/core/udp_connector.cc +++ b/libtransport/src/core/udp_connector.cc @@ -56,9 +56,9 @@ void UdpTunnelConnector::send(Packet &packet) { void UdpTunnelConnector::send(const utils::MemBuf::Ptr &buffer) { auto self = shared_from_this(); - io_service_.post([self, pkt{buffer}]() { + io_service_.post([self, buffer]() { bool write_in_progress = !self->output_buffer_.empty(); - self->output_buffer_.push_back(std::move(pkt)); + self->output_buffer_.push_back(std::move(buffer)); if (TRANSPORT_EXPECT_TRUE(self->state_ == State::CONNECTED)) { if (!write_in_progress) { self->doSendPacket(self); @@ -201,6 +201,8 @@ void UdpTunnelConnector::writeHandler() { ptr->writeHandler(); } }); + } else { + sent_callback_(this, make_error_code(core_error::success)); } } diff --git a/libtransport/src/core/udp_connector.h b/libtransport/src/core/udp_connector.h index 65821852d..002f4ca9f 100644 --- a/libtransport/src/core/udp_connector.h +++ b/libtransport/src/core/udp_connector.h @@ -62,7 +62,7 @@ class UdpTunnelConnector : public Connector { #endif socket_(socket), resolver_(io_service_), - remote_endpoint_send_(std::forward<EndpointType &&>(remote_endpoint)), + remote_endpoint_send_(std::forward<EndpointType>(remote_endpoint)), timer_(io_service_), #ifdef LINUX send_timer_(io_service_), diff --git a/libtransport/src/core/udp_listener.cc b/libtransport/src/core/udp_listener.cc index c67673392..caa97e0ee 100644 --- a/libtransport/src/core/udp_listener.cc +++ b/libtransport/src/core/udp_listener.cc @@ -5,6 +5,7 @@ #include <core/udp_connector.h> #include <core/udp_listener.h> #include <glog/logging.h> +#include <hicn/transport/portability/endianess.h> #include <hicn/transport/utils/hash.h> #ifndef LINUX @@ -16,7 +17,7 @@ size_t hash<asio::ip::udp::endpoint>::operator()( : utils::hash::fnv32_buf( endpoint.address().to_v6().to_bytes().data(), 16); uint16_t port = endpoint.port(); - return utils::hash::fnv32_buf(&port, 2, hash_ip); + return utils::hash::fnv32_buf(&port, 2, (unsigned int)hash_ip); } } // namespace std #endif @@ -83,7 +84,8 @@ void UdpTunnelListener::readHandler(const std::error_code &ec) { std::copy_n(reinterpret_cast<uint8_t *>(&addr->sin_addr), address_bytes.size(), address_bytes.begin()); address_v4 address(address_bytes); - remote_endpoint_ = udp::endpoint(address, ntohs(addr->sin_port)); + remote_endpoint_ = + udp::endpoint(address, portability::net_to_host(addr->sin_port)); } else { auto addr = reinterpret_cast<struct sockaddr_in6 *>( &remote_endpoints_[current_position_]); @@ -91,7 +93,8 @@ void UdpTunnelListener::readHandler(const std::error_code &ec) { std::copy_n(reinterpret_cast<uint8_t *>(&addr->sin6_addr), address_bytes.size(), address_bytes.begin()); address_v6 address(address_bytes); - remote_endpoint_ = udp::endpoint(address, ntohs(addr->sin6_port)); + remote_endpoint_ = + udp::endpoint(address, portability::net_to_host(addr->sin6_port)); } /** diff --git a/libtransport/src/core/udp_listener.h b/libtransport/src/core/udp_listener.h index 813520309..d8095a262 100644 --- a/libtransport/src/core/udp_listener.h +++ b/libtransport/src/core/udp_listener.h @@ -40,7 +40,7 @@ class UdpTunnelListener socket_(std::make_shared<asio::ip::udp::socket>(io_service_, endpoint.protocol())), local_endpoint_(endpoint), - receive_callback_(std::forward<ReceiveCallback &&>(receive_callback)), + receive_callback_(std::forward<ReceiveCallback>(receive_callback)), #ifndef LINUX read_msg_(nullptr, 0) #else @@ -63,12 +63,12 @@ class UdpTunnelListener void close(); int deleteConnector(Connector *connector) { - return connectors_.erase(connector->getConnectorId()); + return (int)connectors_.erase(connector->getConnectorId()); } template <typename ReceiveCallback> void setReceiveCallback(ReceiveCallback &&callback) { - receive_callback_ = std::forward<ReceiveCallback &&>(callback); + receive_callback_ = std::forward<ReceiveCallback>(callback); } Connector *findConnector(Connector::Id connId) { |