diff options
Diffstat (limited to 'libtransport/src')
52 files changed, 3640 insertions, 347 deletions
diff --git a/libtransport/src/hicn/transport/config.h.in b/libtransport/src/hicn/transport/config.h.in index ef47affda..4e9a0f262 100644 --- a/libtransport/src/hicn/transport/config.h.in +++ b/libtransport/src/hicn/transport/config.h.in @@ -25,6 +25,10 @@ #cmakedefine ASIO_STANDALONE #endif +#ifndef SECURE_HICNTRANSPORT +#cmakedefine SECURE_HICNTRANSPORT +#endif + #define RAAQM_CONFIG_PATH "@raaqm_config_path@" #cmakedefine __vpp__ diff --git a/libtransport/src/hicn/transport/core/forwarder_interface.h b/libtransport/src/hicn/transport/core/forwarder_interface.h index 380ce76bd..63b4a2eda 100644 --- a/libtransport/src/hicn/transport/core/forwarder_interface.h +++ b/libtransport/src/hicn/transport/core/forwarder_interface.h @@ -19,6 +19,8 @@ #include <hicn/transport/core/udp_socket_connector.h> #include <hicn/transport/portability/portability.h> #include <hicn/transport/utils/chrono_typedefs.h> +#include <hicn/transport/utils/log.h> + #include <deque> namespace transport { @@ -95,6 +97,9 @@ class ForwarderInterface { packet.setLocator(inet6_address_); } + // TRANSPORT_LOGI("Sending packet %s at %lu", + // packet.getName().toString().c_str(), + // utils::SteadyClock::now().time_since_epoch().count()); packet.setChecksum(); connector_.send(packet.acquireMemBufReference()); } diff --git a/libtransport/src/hicn/transport/core/memif_connector.cc b/libtransport/src/hicn/transport/core/memif_connector.cc index 43dfab345..5e37c882a 100644 --- a/libtransport/src/hicn/transport/core/memif_connector.cc +++ b/libtransport/src/hicn/transport/core/memif_connector.cc @@ -83,12 +83,11 @@ void MemifConnector::init() { nullptr, nullptr, nullptr); if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { - TRANSPORT_LOGI("memif_init: %s", memif_strerror(err)); + TRANSPORT_LOGE("memif_init: %s", memif_strerror(err)); } } void MemifConnector::connect(uint32_t memif_id, long memif_mode) { - TRANSPORT_LOGI("Creating memif"); state_ = ConnectorState::CONNECTING; memif_id_ = memif_id; @@ -108,7 +107,7 @@ void MemifConnector::connect(uint32_t memif_id, long memif_mode) { int fd = -1; err = memif_get_queue_efd(memif_connection_->conn, 0, &fd); if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { - TRANSPORT_LOGI("memif_get_queue_efd: %s", memif_strerror(err)); + TRANSPORT_LOGE("memif_get_queue_efd: %s", memif_strerror(err)); return; } @@ -142,15 +141,12 @@ int MemifConnector::createMemif(uint32_t index, uint8_t mode, char *s) { int err; - err= memif_create_socket (&args.socket, socket_filename_.c_str(), - nullptr); + err = memif_create_socket(&args.socket, socket_filename_.c_str(), nullptr); if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { - throw errors::RuntimeException(memif_strerror(err)); + throw errors::RuntimeException(memif_strerror(err)); } - TRANSPORT_LOGD("Socket filename: %s", socket_filename_.c_str()); - args.interface_id = index; /* last argument for memif_create (void * private_ctx) is used by user to identify connection. this context is returned with callbacks */ @@ -202,11 +198,11 @@ int MemifConnector::deleteMemif() { err = memif_delete(&c->conn); if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { - TRANSPORT_LOGI("memif_delete: %s", memif_strerror(err)); + TRANSPORT_LOGE("memif_delete: %s", memif_strerror(err)); } if (TRANSPORT_EXPECT_FALSE(c->conn != nullptr)) { - TRANSPORT_LOGI("memif delete fail"); + TRANSPORT_LOGE("memif delete fail"); } return 0; @@ -252,7 +248,7 @@ int MemifConnector::controlFdUpdate(int fd, uint8_t events, void *private_ctx) { memif_err = memif_control_fd_handler(evt.data.fd, event); if (TRANSPORT_EXPECT_FALSE(memif_err != MEMIF_ERR_SUCCESS)) { - TRANSPORT_LOGI("memif_control_fd_handler: %s", + TRANSPORT_LOGE("memif_control_fd_handler: %s", memif_strerror(memif_err)); } @@ -269,12 +265,10 @@ int MemifConnector::bufferAlloc(long n, uint16_t qid) { err = memif_buffer_alloc(c->conn, qid, c->tx_bufs, n, &r, 2000); if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { - TRANSPORT_LOGD("memif_buffer_alloc: %s", memif_strerror(err)); + TRANSPORT_LOGE("memif_buffer_alloc: %s", memif_strerror(err)); } c->tx_buf_num += r; - TRANSPORT_LOGD("allocated %d/%ld buffers, %u free buffers", r, n, - MAX_MEMIF_BUFS - c->tx_buf_num); return r; } @@ -287,18 +281,17 @@ int MemifConnector::txBurst(uint16_t qid) { err = memif_tx_burst(c->conn, qid, c->tx_bufs, c->tx_buf_num, &r); if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { - TRANSPORT_LOGI("memif_tx_burst: %s", memif_strerror(err)); + TRANSPORT_LOGE("memif_tx_burst: %s", memif_strerror(err)); } // err = memif_refill_queue(c->conn, qid, r, 0); if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { - TRANSPORT_LOGI("memif_tx_burst: %s", memif_strerror(err)); + TRANSPORT_LOGE("memif_tx_burst: %s", memif_strerror(err)); c->tx_buf_num -= r; return -1; } - TRANSPORT_LOGD("tx: %d/%u", r, c->tx_buf_num); c->tx_buf_num -= r; return 0; } @@ -322,7 +315,6 @@ void MemifConnector::processInputBuffer() { /* informs user about connected status. private_ctx is used by user to identify connection (multiple connections WIP) */ int MemifConnector::onConnect(memif_conn_handle_t conn, void *private_ctx) { - TRANSPORT_LOGI("memif connected!\n"); MemifConnector *connector = (MemifConnector *)private_ctx; connector->state_ = ConnectorState::CONNECTED; memif_refill_queue(conn, 0, -1, 0); @@ -333,11 +325,8 @@ int MemifConnector::onConnect(memif_conn_handle_t conn, void *private_ctx) { /* informs user about disconnected status. private_ctx is used by user to identify connection (multiple connections WIP) */ int MemifConnector::onDisconnect(memif_conn_handle_t conn, void *private_ctx) { - TRANSPORT_LOGI("memif disconnected!"); MemifConnector *connector = (MemifConnector *)private_ctx; connector->state_ = ConnectorState::CLOSED; - TRANSPORT_LOGI("Packet to process: %u", - connector->memif_connection_->tx_buf_num); return 0; } @@ -357,14 +346,14 @@ int MemifConnector::onInterrupt(memif_conn_handle_t conn, void *private_ctx, if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS && err != MEMIF_ERR_NOBUF)) { - TRANSPORT_LOGI("memif_rx_burst: %s", memif_strerror(err)); + TRANSPORT_LOGE("memif_rx_burst: %s", memif_strerror(err)); goto error; } c->rx_buf_num += rx; if (TRANSPORT_EXPECT_TRUE(connector->io_service_.stopped())) { - TRANSPORT_LOGD("socket stopped: ignoring %u packets", rx); + TRANSPORT_LOGE("socket stopped: ignoring %u packets", rx); goto error; } @@ -378,7 +367,7 @@ int MemifConnector::onInterrupt(memif_conn_handle_t conn, void *private_ctx, packet->append(packet_length); if (!connector->input_buffer_.push(std::move(packet))) { - TRANSPORT_LOGI("Error pushing packet. Ring buffer full."); + TRANSPORT_LOGE("Error pushing packet. Ring buffer full."); // TODO Here we should consider the possibility to signal the congestion // to the application, that would react properly (e.g. slow down @@ -392,13 +381,11 @@ int MemifConnector::onInterrupt(memif_conn_handle_t conn, void *private_ctx, err = memif_refill_queue(conn, qid, rx, 0); if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { - TRANSPORT_LOGI("memif_buffer_free: %s", memif_strerror(err)); + TRANSPORT_LOGE("memif_buffer_free: %s", memif_strerror(err)); } c->rx_buf_num -= rx; - TRANSPORT_LOGD("freed %d buffers. %u/%u alloc/free buffers", rx, rx, - MAX_MEMIF_BUFS - rx); } while (ret_val == MEMIF_ERR_NOBUF); connector->io_service_.post( @@ -410,12 +397,10 @@ error: err = memif_refill_queue(c->conn, qid, rx, 0); if (TRANSPORT_EXPECT_FALSE(err != MEMIF_ERR_SUCCESS)) { - TRANSPORT_LOGI("memif_buffer_free: %s", memif_strerror(err)); + TRANSPORT_LOGE("memif_buffer_free: %s", memif_strerror(err)); } c->rx_buf_num -= rx; - TRANSPORT_LOGD("freed %d buffers. %u/%u alloc/free buffers", rx, - c->rx_buf_num, MAX_MEMIF_BUFS - c->rx_buf_num); return 0; } @@ -430,9 +415,6 @@ void MemifConnector::close() { if (memif_worker_ && memif_worker_->joinable()) { memif_worker_->join(); - TRANSPORT_LOGD("Memif worker joined"); - } else { - TRANSPORT_LOGD("Memif worker not joined"); } } } @@ -467,7 +449,7 @@ int MemifConnector::doSend() { if (TRANSPORT_EXPECT_FALSE( (n = bufferAlloc(max, memif_connection_->tx_qid)) < 0)) { - TRANSPORT_LOGI("Error allocating buffers."); + TRANSPORT_LOGE("Error allocating buffers."); return -1; } @@ -487,8 +469,6 @@ int MemifConnector::doSend() { memif_connection_->tx_bufs[i].len = uint32_t(offset); - TRANSPORT_LOGD("Packet size : %zu", offset); - output_buffer_.pop_front(); } diff --git a/libtransport/src/hicn/transport/core/name.cc b/libtransport/src/hicn/transport/core/name.cc index 46ef98948..85e2b8565 100644 --- a/libtransport/src/hicn/transport/core/name.cc +++ b/libtransport/src/hicn/transport/core/name.cc @@ -116,9 +116,9 @@ std::string Name::toString() const { return name_string; } -uint32_t Name::getHash32() const { +uint32_t Name::getHash32(bool consider_suffix) const { uint32_t hash; - if (hicn_name_hash((hicn_name_t *)&name_, &hash) < 0) { + if (hicn_name_hash(&name_, &hash, consider_suffix) < 0) { throw errors::RuntimeException("Error computing the hash of the name!"); } return hash; @@ -206,6 +206,17 @@ std::ostream &operator<<(std::ostream &os, const Name &name) { return os; } +size_t hash<transport::core::Name>::operator()( + const transport::core::Name &name) const { + return name.getHash32(false); +} + +size_t compare2<transport::core::Name>::operator()( + const transport::core::Name &name1, + const transport::core::Name &name2) const { + return name1.equals(name2, false); +} + } // end namespace core } // end namespace transport diff --git a/libtransport/src/hicn/transport/core/name.h b/libtransport/src/hicn/transport/core/name.h index 35625ddd1..ea72797ad 100644 --- a/libtransport/src/hicn/transport/core/name.h +++ b/libtransport/src/hicn/transport/core/name.h @@ -81,7 +81,7 @@ class Name { bool equals(const Name &name, bool consider_segment = true) const; - uint32_t getHash32() const; + uint32_t getHash32(bool consider_suffix = true) const; void clear(); @@ -112,10 +112,27 @@ class Name { std::ostream &operator<<(std::ostream &os, const Name &name); +template <typename T> +struct hash {}; + +template <> +struct hash<transport::core::Name> { + size_t operator()(const transport::core::Name &name) const; +}; + +template <typename T> +struct compare2 {}; + +template <> +struct compare2<transport::core::Name> { + size_t operator()(const transport::core::Name &name1, const transport::core::Name &name2) const; +}; + } // end namespace core } // end namespace transport + namespace std { template <> struct hash<transport::core::Name> { diff --git a/libtransport/src/hicn/transport/core/packet.cc b/libtransport/src/hicn/transport/core/packet.cc index 954266664..817f8de66 100644 --- a/libtransport/src/hicn/transport/core/packet.cc +++ b/libtransport/src/hicn/transport/core/packet.cc @@ -230,7 +230,7 @@ Packet::Format Packet::getFormat() const { return format_; } -const std::shared_ptr<utils::MemBuf> Packet::acquireMemBufReference() { +const std::shared_ptr<utils::MemBuf> Packet::acquireMemBufReference() const { return packet_; } diff --git a/libtransport/src/hicn/transport/core/packet.h b/libtransport/src/hicn/transport/core/packet.h index 4ec93205a..35c8606c9 100644 --- a/libtransport/src/hicn/transport/core/packet.h +++ b/libtransport/src/hicn/transport/core/packet.h @@ -99,7 +99,7 @@ class Packet : public std::enable_shared_from_this<Packet> { std::size_t headerSize() const; - const std::shared_ptr<utils::MemBuf> acquireMemBufReference(); + const std::shared_ptr<utils::MemBuf> acquireMemBufReference() const; virtual const Name &getName() const = 0; diff --git a/libtransport/src/hicn/transport/core/portal.h b/libtransport/src/hicn/transport/core/portal.h index 4f161e4c8..c6e11ada6 100644 --- a/libtransport/src/hicn/transport/core/portal.h +++ b/libtransport/src/hicn/transport/core/portal.h @@ -658,9 +658,6 @@ class Portal { consumer_callback_->onContentObject(std::move(_int), std::move(content_object)); } - } else { - TRANSPORT_LOGD("No pending interests for current content (%s)", - content_object->getName().toString().c_str()); } } diff --git a/libtransport/src/hicn/transport/core/prefix.cc b/libtransport/src/hicn/transport/core/prefix.cc index 648c0a67b..59898ab70 100644 --- a/libtransport/src/hicn/transport/core/prefix.cc +++ b/libtransport/src/hicn/transport/core/prefix.cc @@ -29,6 +29,8 @@ extern "C" { #include <memory> #include <random> +#include <openssl/rand.h> + namespace transport { namespace core { @@ -80,11 +82,11 @@ void Prefix::buildPrefix(std::string &prefix, uint16_t prefix_length, int ret; switch (family) { case AF_INET: - ret = inet_pton(AF_INET, prefix.c_str(), ip_prefix_.address.v4.buffer); - break; + ret = inet_pton(AF_INET, prefix.c_str(), ip_prefix_.address.v4.buffer); + break; case AF_INET6: - ret = inet_pton(AF_INET6, prefix.c_str(), ip_prefix_.address.v6.buffer); - break; + ret = inet_pton(AF_INET6, prefix.c_str(), ip_prefix_.address.v6.buffer); + break; default: throw errors::InvalidIpAddressException(); } @@ -133,8 +135,7 @@ Prefix &Prefix::setAddressFamily(int address_family) { } std::string Prefix::getNetwork() const { - if (!checkPrefixLengthAndAddressFamily(ip_prefix_.len, - ip_prefix_.family)) { + if (!checkPrefixLengthAndAddressFamily(ip_prefix_.len, ip_prefix_.family)) { throw errors::InvalidIpAddressException(); } @@ -151,11 +152,123 @@ std::string Prefix::getNetwork() const { return network; } +int Prefix::contains(const ip_address_t &content_name) const { + int res = + ip_address_cmp(&content_name, &(ip_prefix_.address), ip_prefix_.family); + + if (ip_prefix_.len != (ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN_BITS + : IPV4_ADDR_LEN_BITS)) { + const u8 *ip_prefix_buffer = + ip_address_get_buffer(&(ip_prefix_.address), ip_prefix_.family); + const u8 *content_name_buffer = + ip_address_get_buffer(&content_name, ip_prefix_.family); + uint8_t mask = 0xFF >> (ip_prefix_.len % 8); + mask = ~mask; + + res += (ip_prefix_buffer[ip_prefix_.len] & mask) == + (content_name_buffer[ip_prefix_.len] & mask); + } + + return res; +} + +int Prefix::contains(const core::Name &content_name) const { + return contains(content_name.toIpAddress().address); +} + Name Prefix::getName() const { std::string s(getNetwork()); return Name(s); } +/* + * Mask is used to apply the components to a content name that belong to this + * prefix + */ +Name Prefix::getName(const core::Name &mask, const core::Name &components, + const core::Name &content_name) const { + if (ip_prefix_.family != mask.getAddressFamily() || + ip_prefix_.family != components.getAddressFamily() || + ip_prefix_.family != content_name.getAddressFamily()) + throw errors::RuntimeException( + "Prefix, mask, components and content name are not of the same address " + "family"); + + ip_address_t mask_ip = mask.toIpAddress().address; + ip_address_t component_ip = components.toIpAddress().address; + ip_address_t name_ip = content_name.toIpAddress().address; + const u8 *mask_ip_buffer = ip_address_get_buffer(&mask_ip, ip_prefix_.family); + const u8 *component_ip_buffer = + ip_address_get_buffer(&component_ip, ip_prefix_.family); + u8 *name_ip_buffer = + const_cast<u8 *>(ip_address_get_buffer(&name_ip, ip_prefix_.family)); + + int addr_len = ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN : IPV4_ADDR_LEN; + + for (int i = 0; i < addr_len; i++) { + if (mask_ip_buffer[i]) { + name_ip_buffer[i] = component_ip_buffer[i] & mask_ip_buffer[i]; + } + } + + if (this->contains(name_ip)) + throw errors::RuntimeException("Mask overrides the prefix"); + return Name(ip_prefix_.family, (uint8_t *)&name_ip); +} + +Name Prefix::getRandomName() const { + ip_address_t name_ip = ip_prefix_.address; + u8 *name_ip_buffer = + const_cast<u8 *>(ip_address_get_buffer(&name_ip, ip_prefix_.family)); + + int addr_len = + (ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN * 8 : IPV4_ADDR_LEN * 8) - + ip_prefix_.len; + + size_t size = (size_t)ceil((float)addr_len / 8.0); + uint8_t buffer[size]; + + RAND_bytes(buffer, size); + + int j = 0; + for (uint8_t i = (uint8_t)ceil((float)ip_prefix_.len / 8.0); + i < (ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN : IPV4_ADDR_LEN); + i++) { + name_ip_buffer[i] = buffer[j]; + j++; + } + + return Name(ip_prefix_.family, (uint8_t *)&name_ip); +} + +/* + * Map a name in a different name prefix to this name prefix + */ +Name Prefix::mapName(const core::Name &content_name) const { + if (ip_prefix_.family != content_name.getAddressFamily()) + throw errors::RuntimeException( + "Prefix content name are not of the same address " + "family"); + + ip_address_t name_ip = content_name.toIpAddress().address; + const u8 *ip_prefix_buffer = + ip_address_get_buffer(&(ip_prefix_.address), ip_prefix_.family); + u8 *name_ip_buffer = + const_cast<u8 *>(ip_address_get_buffer(&name_ip, ip_prefix_.family)); + + memcpy(name_ip_buffer, ip_prefix_buffer, ip_prefix_.len / 8); + + if (ip_prefix_.len != (ip_prefix_.family == AF_INET6 ? IPV6_ADDR_LEN_BITS + : IPV4_ADDR_LEN_BITS)) { + uint8_t mask = 0xFF >> (ip_prefix_.len % 8); + name_ip_buffer[ip_prefix_.len / 8 + 1] = + (name_ip_buffer[ip_prefix_.len / 8 + 1] & mask) | + (ip_prefix_buffer[ip_prefix_.len / 8 + 1] & ~mask); + } + + return Name(ip_prefix_.family, (uint8_t *)&name_ip); +} + Prefix &Prefix::setNetwork(std::string &network) { if (!inet_pton(AF_INET6, network.c_str(), ip_prefix_.address.v6.buffer)) { throw errors::RuntimeException("The network name is not valid."); diff --git a/libtransport/src/hicn/transport/core/prefix.h b/libtransport/src/hicn/transport/core/prefix.h index af7c705cf..47971acaf 100644 --- a/libtransport/src/hicn/transport/core/prefix.h +++ b/libtransport/src/hicn/transport/core/prefix.h @@ -42,8 +42,19 @@ class Prefix { std::string getNetwork() const; + int contains(const ip_address_t &content_name) const; + + int contains(const core::Name &content_name) const; + Name getName() const; + Name getRandomName() const; + + Name getName(const core::Name &mask, const core::Name &components, + const core::Name &content_name) const; + + Name mapName(const core::Name &content_name) const; + Prefix &setNetwork(std::string &network); int getAddressFamily(); diff --git a/libtransport/src/hicn/transport/interfaces/CMakeLists.txt b/libtransport/src/hicn/transport/interfaces/CMakeLists.txt index 1f3c29b1f..88b83e9d4 100644 --- a/libtransport/src/hicn/transport/interfaces/CMakeLists.txt +++ b/libtransport/src/hicn/transport/interfaces/CMakeLists.txt @@ -17,6 +17,11 @@ list(APPEND HEADER_FILES ${CMAKE_CURRENT_SOURCE_DIR}/socket.h ${CMAKE_CURRENT_SOURCE_DIR}/socket_consumer.h ${CMAKE_CURRENT_SOURCE_DIR}/socket_producer.h + ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.h + ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h + ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.h + ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.h + ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.h ${CMAKE_CURRENT_SOURCE_DIR}/rtc_socket_producer.h ${CMAKE_CURRENT_SOURCE_DIR}/publication_options.h ${CMAKE_CURRENT_SOURCE_DIR}/socket_options_default_values.h @@ -26,11 +31,29 @@ list(APPEND HEADER_FILES ) list(APPEND SOURCE_FILES - ${CMAKE_CURRENT_SOURCE_DIR}/rtc_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/socket_producer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rtc_socket_producer.cc ${CMAKE_CURRENT_SOURCE_DIR}/socket_consumer.cc ${CMAKE_CURRENT_SOURCE_DIR}/callbacks.cc ) +if (${OPENSSL_VERSION} VERSION_EQUAL "1.1.1a" OR ${OPENSSL_VERSION} VERSION_GREATER "1.1.1a") + list(APPEND SOURCE_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.cc + ) + + list(APPEND HEADER_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_producer.h + ${CMAKE_CURRENT_SOURCE_DIR}/tls_rtc_socket_producer.h + ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_producer.h + ${CMAKE_CURRENT_SOURCE_DIR}/tls_socket_consumer.h + ${CMAKE_CURRENT_SOURCE_DIR}/p2psecure_socket_consumer.h + ) +endif() + set(SOURCE_FILES ${SOURCE_FILES} PARENT_SCOPE) set(HEADER_FILES ${HEADER_FILES} PARENT_SCOPE) diff --git a/libtransport/src/hicn/transport/interfaces/callbacks.h b/libtransport/src/hicn/transport/interfaces/callbacks.h index 6de48d14b..9d3d57992 100644 --- a/libtransport/src/hicn/transport/interfaces/callbacks.h +++ b/libtransport/src/hicn/transport/interfaces/callbacks.h @@ -122,4 +122,4 @@ extern std::nullptr_t VOID_HANDLER; } // namespace interface -} // namespace transport
\ No newline at end of file +} // namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/p2psecure_socket_consumer.cc b/libtransport/src/hicn/transport/interfaces/p2psecure_socket_consumer.cc new file mode 100644 index 000000000..ec966e509 --- /dev/null +++ b/libtransport/src/hicn/transport/interfaces/p2psecure_socket_consumer.cc @@ -0,0 +1,382 @@ +#include <hicn/transport/interfaces/p2psecure_socket_consumer.h> +#include <openssl/bio.h> +#include <openssl/ssl.h> +#include <openssl/tls1.h> + +#include <random> + +namespace transport { + +namespace interface { + +void P2PSecureConsumerSocket::setInterestPayload( + ConsumerSocket &c, const core::Interest &interest) { + Interest &int2 = const_cast<Interest &>(interest); + random_suffix_ = int2.getName().getSuffix(); + + if (payload_ != NULL) int2.appendPayload(std::move(payload_)); +} + +// implement void readBufferAvailable(), size_t maxBufferSize() const override, +// void readError(), void readSuccess(). getReadBuffer() and readDataAvailable() +// must be implemented even if empty. + +/* Return the number of read bytes in the return param */ +int readOld(BIO *b, char *buf, int size) { + if (size < 0) return size; + + P2PSecureConsumerSocket *socket; + socket = (P2PSecureConsumerSocket *)BIO_get_data(b); + + std::unique_lock<std::mutex> lck(socket->mtx_); + + if (!socket->something_to_read_) { + if (!socket->transport_protocol_->isRunning()) { + socket->network_name_.setSuffix(socket->random_suffix_); + socket->ConsumerSocket::asyncConsume(socket->network_name_); + } + if (!socket->something_to_read_) socket->cv_.wait(lck); + } + + size_t size_to_read, read; + size_t chain_size = socket->head_->length(); + if (socket->head_->isChained()) + chain_size = socket->head_->computeChainDataLength(); + + if (chain_size > (size_t)size) { + read = size_to_read = (size_t)size; + } else { + read = size_to_read = chain_size; + socket->something_to_read_ = false; + } + + while (size_to_read) { + if (socket->head_->length() < size_to_read) { + std::memcpy(buf, socket->head_->data(), socket->head_->length()); + size_to_read -= socket->head_->length(); + buf += socket->head_->length(); + socket->head_ = socket->head_->pop(); + } else { + std::memcpy(buf, socket->head_->data(), size_to_read); + socket->head_->trimStart(size_to_read); + size_to_read = 0; + } + } + + return read; +} + +/* Return the number of read bytes in readbytes */ +int read(BIO *b, char *buf, size_t size, size_t *readbytes) { + int ret; + + if (size > INT_MAX) size = INT_MAX; + + ret = transport::interface::readOld(b, buf, (int)size); + + if (ret <= 0) { + *readbytes = 0; + return ret; + } + + *readbytes = (size_t)ret; + + return 1; +} + +/* Return the number of written bytes in the return param */ +int writeOld(BIO *b, const char *buf, int num) { + P2PSecureConsumerSocket *socket; + socket = (P2PSecureConsumerSocket *)BIO_get_data(b); + + socket->payload_ = utils::MemBuf::copyBuffer(buf, num); + socket->ConsumerSocket::setSocketOption( + ConsumerCallbacksOptions::INTEREST_OUTPUT, + (ConsumerInterestCallback)std::bind( + &P2PSecureConsumerSocket::setInterestPayload, socket, + std::placeholders::_1, std::placeholders::_2)); + + return num; +} + +/* Return the number of written bytes in written */ +int write(BIO *b, const char *buf, size_t size, size_t *written) { + int ret; + + if (size > INT_MAX) size = INT_MAX; + + ret = transport::interface::writeOld(b, buf, (int)size); + + if (ret <= 0) { + *written = 0; + return ret; + } + + *written = (size_t)ret; + + return 1; +} + +long ctrl(BIO *b, int cmd, long num, void *ptr) { return 1; } + +int P2PSecureConsumerSocket::addHicnKeyIdCb(SSL *s, unsigned int ext_type, + unsigned int context, + const unsigned char **out, + size_t *outlen, X509 *x, + size_t chainidx, int *al, + void *add_arg) { + if (ext_type == 100) { + *out = (unsigned char *)malloc(4); + *(uint32_t *)*out = 10; + *outlen = 4; + } + return 1; +} + +void P2PSecureConsumerSocket::freeHicnKeyIdCb(SSL *s, unsigned int ext_type, + unsigned int context, + const unsigned char *out, + void *add_arg) { + free(const_cast<unsigned char *>(out)); +} + +int P2PSecureConsumerSocket::parseHicnKeyIdCb(SSL *s, unsigned int ext_type, + unsigned int context, + const unsigned char *in, + size_t inlen, X509 *x, + size_t chainidx, int *al, + void *add_arg) { + P2PSecureConsumerSocket *socket = + reinterpret_cast<P2PSecureConsumerSocket *>(add_arg); + if (ext_type == 100) { + memcpy(&socket->secure_prefix_, in, sizeof(ip_prefix_t)); + } + return 1; +} + +P2PSecureConsumerSocket::P2PSecureConsumerSocket(int handshake_protocol, int transport_protocol) + : ConsumerSocket(handshake_protocol), + name_(), + tls_consumer_(), + buf_pool_(), + decrypted_content_(), + payload_(), + head_(), + something_to_read_(false), + content_downloaded_(false), + random_suffix_(), + secure_prefix_(), + producer_namespace_(), + read_callback_decrypted_(), + mtx_(), + cv_(), + protocol_(transport_protocol) { + /* Create the (d)TLS state */ + const SSL_METHOD *meth = TLS_client_method(); + ctx_ = SSL_CTX_new(meth); + + int result = + SSL_CTX_set_ciphersuites(ctx_, + "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_" + "SHA256:TLS_AES_128_GCM_SHA256"); + if (result != 1) { + throw errors::RuntimeException( + "Unable to set cipher list on TLS subsystem. Aborting."); + } + + SSL_CTX_set_min_proto_version(ctx_, TLS1_3_VERSION); + SSL_CTX_set_max_proto_version(ctx_, TLS1_3_VERSION); + SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, NULL); + SSL_CTX_set_ssl_version(ctx_, meth); + + result = SSL_CTX_add_custom_ext( + ctx_, 100, SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS, + P2PSecureConsumerSocket::addHicnKeyIdCb, + P2PSecureConsumerSocket::freeHicnKeyIdCb, NULL, + P2PSecureConsumerSocket::parseHicnKeyIdCb, this); + + ssl_ = SSL_new(ctx_); + + bio_meth_ = BIO_meth_new(BIO_TYPE_CONNECT, "secure consumer socket"); + BIO_meth_set_read(bio_meth_, transport::interface::readOld); + BIO_meth_set_write(bio_meth_, transport::interface::writeOld); + BIO_meth_set_ctrl(bio_meth_, transport::interface::ctrl); + BIO *bio = BIO_new(bio_meth_); + BIO_set_init(bio, 1); + BIO_set_data(bio, this); + SSL_set_bio(ssl_, bio, bio); + + ConsumerSocket::getSocketOption(MAX_WINDOW_SIZE, old_max_win_); + ConsumerSocket::setSocketOption(MAX_WINDOW_SIZE, (double)1.0); + + ConsumerSocket::getSocketOption(CURRENT_WINDOW_SIZE, old_current_win_); + ConsumerSocket::setSocketOption(CURRENT_WINDOW_SIZE, (double)1.0); + + std::default_random_engine generator; + std::uniform_int_distribution<int> distribution( + 1, std::numeric_limits<uint32_t>::max()); + random_suffix_ = 0; + + this->ConsumerSocket::setSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, + this); +}; + +P2PSecureConsumerSocket::~P2PSecureConsumerSocket() { + BIO_meth_free(bio_meth_); + SSL_shutdown(ssl_); +} + +int P2PSecureConsumerSocket::consume(const Name &name) { + if (transport_protocol_->isRunning()) { + return CONSUMER_BUSY; + } + + if ((SSL_in_before(this->ssl_) || SSL_in_init(this->ssl_))) { + ConsumerSocket::setSocketOption(MAX_WINDOW_SIZE, (double)1.0); + network_name_ = producer_namespace_.getRandomName(); + network_name_.setSuffix(0); + int result = SSL_connect(this->ssl_); + ConsumerSocket::setSocketOption(MAX_WINDOW_SIZE, old_max_win_); + ConsumerSocket::setSocketOption(CURRENT_WINDOW_SIZE, old_current_win_); + if (result != 1) + throw errors::RuntimeException("Unable to perform client handshake"); + } + std::shared_ptr<Name> prefix_name = std::make_shared<Name>( + secure_prefix_.family, + ip_address_get_buffer(&(secure_prefix_.address), secure_prefix_.family)); + std::shared_ptr<Prefix> prefix = + std::make_shared<Prefix>(*prefix_name, secure_prefix_.len); + TLSConsumerSocket tls_consumer(this->protocol_, this->ssl_); + + ConsumerTimerCallback *stats_summary_callback = nullptr; + this->getSocketOption(ConsumerCallbacksOptions::STATS_SUMMARY, + &stats_summary_callback); + + uint32_t lifetime; + this->getSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, lifetime); + tls_consumer.setSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, + lifetime); + tls_consumer.setSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, + read_callback_decrypted_); + tls_consumer.setSocketOption(ConsumerCallbacksOptions::STATS_SUMMARY, + *stats_summary_callback); + tls_consumer.setSocketOption(GeneralTransportOptions::STATS_INTERVAL, + this->timer_interval_milliseconds_); + tls_consumer.setSocketOption(MAX_WINDOW_SIZE, old_max_win_); + tls_consumer.setSocketOption(CURRENT_WINDOW_SIZE, old_current_win_); + tls_consumer.connect(); + + if (payload_ != NULL) + return tls_consumer.consume((prefix->mapName(name)), std::move(payload_)); + else + return tls_consumer.consume((prefix->mapName(name))); +} + +int P2PSecureConsumerSocket::asyncConsume(const Name &name) { + if ((SSL_in_before(this->ssl_) || SSL_in_init(this->ssl_))) { + ConsumerSocket::setSocketOption(CURRENT_WINDOW_SIZE, (double)1.0); + ConsumerSocket::setSocketOption(MAX_WINDOW_SIZE, (double)1.0); + network_name_ = producer_namespace_.getRandomName(); + network_name_.setSuffix(0); + TRANSPORT_LOGD("Start handshake at %s", network_name_.toString().c_str()); + interface::ConsumerSocket::ReadCallback *on_payload = VOID_HANDLER; + this->getSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, &on_payload); + int result = SSL_connect(this->ssl_); + ConsumerSocket::setSocketOption(MAX_WINDOW_SIZE, old_max_win_); + ConsumerSocket::setSocketOption(CURRENT_WINDOW_SIZE, old_current_win_); + if (result != 1) + throw errors::RuntimeException("Unable to perform client handshake"); + TRANSPORT_LOGD("Handshake performed!"); + } + + std::shared_ptr<Name> prefix_name = std::make_shared<Name>( + secure_prefix_.family, + ip_address_get_buffer(&(secure_prefix_.address), secure_prefix_.family)); + std::shared_ptr<Prefix> prefix = + std::make_shared<Prefix>(*prefix_name, secure_prefix_.len); + tls_consumer_ = + std::make_shared<TLSConsumerSocket>(this->protocol_, this->ssl_); + + ConsumerTimerCallback *stats_summary_callback = nullptr; + this->getSocketOption(ConsumerCallbacksOptions::STATS_SUMMARY, + &stats_summary_callback); + + uint32_t lifetime; + this->getSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, lifetime); + tls_consumer_->setSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, + lifetime); + tls_consumer_->setSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, + read_callback_decrypted_); + tls_consumer_->setSocketOption(ConsumerCallbacksOptions::STATS_SUMMARY, + *stats_summary_callback); + tls_consumer_->setSocketOption(GeneralTransportOptions::STATS_INTERVAL, + this->timer_interval_milliseconds_); + tls_consumer_->setSocketOption(MAX_WINDOW_SIZE, old_max_win_); + tls_consumer_->setSocketOption(CURRENT_WINDOW_SIZE, old_current_win_); + tls_consumer_->connect(); + + if (payload_ != NULL) + return tls_consumer_->asyncConsume((prefix->mapName(name)), + std::move(payload_)); + else + return tls_consumer_->asyncConsume((prefix->mapName(name))); +} + +void P2PSecureConsumerSocket::registerPrefix(const Prefix &producer_namespace) { + producer_namespace_ = producer_namespace; +} + +int P2PSecureConsumerSocket::setSocketOption( + int socket_option_key, ConsumerSocket::ReadCallback *socket_option_value) { + return rescheduleOnIOService( + socket_option_key, socket_option_value, + [this](int socket_option_key, + ConsumerSocket::ReadCallback *socket_option_value) -> int { + switch (socket_option_key) { + case ConsumerCallbacksOptions::READ_CALLBACK: + read_callback_decrypted_ = socket_option_value; + break; + default: + return SOCKET_OPTION_NOT_SET; + } + + return SOCKET_OPTION_SET; + }); +} + +void P2PSecureConsumerSocket::getReadBuffer(uint8_t **application_buffer, + size_t *max_length){}; + +void P2PSecureConsumerSocket::readDataAvailable(size_t length) noexcept {}; + +size_t P2PSecureConsumerSocket::maxBufferSize() const { + return SSL3_RT_MAX_PLAIN_LENGTH; +} + +void P2PSecureConsumerSocket::readBufferAvailable( + std::unique_ptr<utils::MemBuf> &&buffer) noexcept { + std::unique_lock<std::mutex> lck(this->mtx_); + if (head_) { + head_->prependChain(std::move(buffer)); + } else { + head_ = std::move(buffer); + } + + something_to_read_ = true; + cv_.notify_one(); +} + +void P2PSecureConsumerSocket::readError(const std::error_code ec) noexcept {}; + +void P2PSecureConsumerSocket::readSuccess(std::size_t total_size) noexcept { + std::unique_lock<std::mutex> lck(this->mtx_); + content_downloaded_ = true; + something_to_read_ = true; + cv_.notify_one(); +} + +bool P2PSecureConsumerSocket::isBufferMovable() noexcept { return true; } + +} // namespace interface + +} // namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/p2psecure_socket_consumer.h b/libtransport/src/hicn/transport/interfaces/p2psecure_socket_consumer.h new file mode 100644 index 000000000..ff867f07b --- /dev/null +++ b/libtransport/src/hicn/transport/interfaces/p2psecure_socket_consumer.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/interfaces/tls_socket_consumer.h> +#include <openssl/bio.h> +#include <openssl/ssl.h> + +namespace transport { + +namespace interface { + +class P2PSecureConsumerSocket : public ConsumerSocket, + public ConsumerSocket::ReadCallback { + /* Return the number of read bytes in readbytes */ + friend int read(BIO *b, char *buf, size_t size, size_t *readbytes); + + /* Return the number of read bytes in the return param */ + friend int readOld(BIO *h, char *buf, int size); + + /* Return the number of written bytes in written */ + friend int write(BIO *b, const char *buf, size_t size, size_t *written); + + /* Return the number of written bytes in the return param */ + friend int writeOld(BIO *h, const char *buf, int num); + + friend long ctrl(BIO *b, int cmd, long num, void *ptr); + + public: + explicit P2PSecureConsumerSocket(int handshake_protocol, int transport_protocol); + + ~P2PSecureConsumerSocket(); + + int consume(const Name &name) override; + + int asyncConsume(const Name &name) override; + + void registerPrefix(const Prefix &producer_namespace); + + int setSocketOption( + int socket_option_key, + ConsumerSocket::ReadCallback *socket_option_value) override; + + using ConsumerSocket::getSocketOption; + using ConsumerSocket::setSocketOption; + + protected: + /* Callback invoked once an interest has been received and its payload + * decrypted */ + ConsumerInterestCallback on_interest_input_decrypted_; + ConsumerInterestCallback on_interest_process_decrypted_; + + private: + Name name_; + std::shared_ptr<TLSConsumerSocket> tls_consumer_; + + /* SSL handle */ + SSL *ssl_; + SSL_CTX *ctx_; + BIO_METHOD *bio_meth_; + + /* Chain of MemBuf to be used as a temporary buffer to pass descypted data + * from the underlying layer to the application */ + utils::ObjectPool<utils::MemBuf> buf_pool_; + std::unique_ptr<utils::MemBuf> decrypted_content_; + + /* Chain of MemBuf holding the payload to be written into interest or data */ + std::unique_ptr<utils::MemBuf> payload_; + + /* Chain of MemBuf holding the data retrieved from the underlying layer */ + std::unique_ptr<utils::MemBuf> head_; + + bool something_to_read_; + + bool content_downloaded_; + + double old_max_win_; + + double old_current_win_; + + uint32_t random_suffix_; + + ip_prefix_t secure_prefix_; + + Prefix producer_namespace_; + + ConsumerSocket::ReadCallback *read_callback_decrypted_; + + std::mutex mtx_; + + /* Condition variable for the wait */ + std::condition_variable cv_; + + int protocol_; + + void setInterestPayload(ConsumerSocket &c, const core::Interest &interest); + void processPayload(ConsumerSocket &c, std::size_t bytes_transferred, + const std::error_code &ec); + + static int addHicnKeyIdCb(SSL *s, unsigned int ext_type, unsigned int context, + const unsigned char **out, size_t *outlen, X509 *x, + size_t chainidx, int *al, void *add_arg); + + static void freeHicnKeyIdCb(SSL *s, unsigned int ext_type, + unsigned int context, const unsigned char *out, + void *add_arg); + + static int parseHicnKeyIdCb(SSL *s, unsigned int ext_type, + unsigned int context, const unsigned char *in, + size_t inlen, X509 *x, size_t chainidx, int *al, + void *add_arg); + + virtual void getReadBuffer(uint8_t **application_buffer, + size_t *max_length) override; + + virtual void readDataAvailable(size_t length) noexcept override; + + virtual size_t maxBufferSize() const override; + + virtual void readBufferAvailable( + std::unique_ptr<utils::MemBuf> &&buffer) noexcept override; + + virtual void readError(const std::error_code ec) noexcept override; + + virtual void readSuccess(std::size_t total_size) noexcept override; + virtual bool isBufferMovable() noexcept override; + + int download_content(const Name &name); +}; + +} // namespace interface + +} // end namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/p2psecure_socket_producer.cc b/libtransport/src/hicn/transport/interfaces/p2psecure_socket_producer.cc new file mode 100644 index 000000000..8850bde8a --- /dev/null +++ b/libtransport/src/hicn/transport/interfaces/p2psecure_socket_producer.cc @@ -0,0 +1,380 @@ +#include <hicn/transport/core/interest.h> +#include <hicn/transport/interfaces/p2psecure_socket_producer.h> +#include <hicn/transport/interfaces/tls_rtc_socket_producer.h> +#include <hicn/transport/interfaces/tls_socket_producer.h> + +#include <openssl/bio.h> +#include <openssl/rand.h> +#include <openssl/ssl.h> + +namespace transport { + +namespace interface { + +/* Workaround to prevent content with expiry time equal to 0 to be lost when + * pushed in the forwarder */ +#define HICN_HANDSHAKE_CONTENT_EXPIRY_TIME 100; + +P2PSecureProducerSocket::P2PSecureProducerSocket() + : ProducerSocket(), + mtx_(), + cv_(), + map_secure_producers(), + map_secure_rtc_producers(), + list_secure_producers() {} + +P2PSecureProducerSocket::P2PSecureProducerSocket( + bool rtc, const std::shared_ptr<utils::Identity> &identity) + : ProducerSocket(), + rtc_(rtc), + mtx_(), + cv_(), + map_secure_producers(), + map_secure_rtc_producers(), + list_secure_producers() { + /* + * Setup SSL context (identity and parameter to use TLS 1.3) + */ + der_cert_ = parcKeyStore_GetDEREncodedCertificate( + (identity->getSigner()->getKeyStore())); + der_prk_ = parcKeyStore_GetDEREncodedPrivateKey( + (identity->getSigner()->getKeyStore())); + + int cert_size = parcBuffer_Limit(der_cert_); + int prk_size = parcBuffer_Limit(der_prk_); + const uint8_t *cert = + reinterpret_cast<uint8_t *>(parcBuffer_Overlay(der_cert_, cert_size)); + const uint8_t *prk = + reinterpret_cast<uint8_t *>(parcBuffer_Overlay(der_prk_, prk_size)); + cert_509_ = d2i_X509(NULL, &cert, cert_size); + pkey_rsa_ = d2i_AutoPrivateKey(NULL, &prk, prk_size); + + /* + * Set the callback so that when an interest is received we catch it and we + * decrypt the payload before passing it to the application. + */ + ProducerSocket::setSocketOption( + ProducerCallbacksOptions::INTEREST_INPUT, + (ProducerInterestCallback)std::bind( + &P2PSecureProducerSocket::onInterestCallback, this, + std::placeholders::_1, std::placeholders::_2)); +} + +P2PSecureProducerSocket::~P2PSecureProducerSocket() { + if (der_cert_) parcBuffer_Release(&der_cert_); + if (der_prk_) parcBuffer_Release(&der_prk_); +} + +void P2PSecureProducerSocket::onInterestCallback(ProducerSocket &p, + Interest &interest) { + std::unique_lock<std::mutex> lck(mtx_); + + TRANSPORT_LOGD("Start handshake at %s", interest.getName().toString().c_str()); + if (!rtc_) { + auto it = map_secure_producers.find(interest.getName()); + if (it != map_secure_producers.end()) return; + TLSProducerSocket *tls_producer = + new TLSProducerSocket(this, interest.getName()); + tls_producer->on_content_produced_application_ = + this->on_content_produced_application_; + tls_producer->setSocketOption(CONTENT_OBJECT_EXPIRY_TIME, + this->content_object_expiry_time_); + tls_producer->setSocketOption(SIGNER, this->signer_); + tls_producer->setSocketOption(MAKE_MANIFEST, this->making_manifest_); + tls_producer->setSocketOption(DATA_PACKET_SIZE, + (uint32_t)(this->data_packet_size_)); + tls_producer->output_buffer_.setLimit(this->output_buffer_.getLimit()); + map_secure_producers.insert( + {interest.getName(), std::unique_ptr<TLSProducerSocket>(tls_producer)}); + tls_producer->onInterest(*tls_producer, interest); + tls_producer->async_accept(); + } else { + auto it = map_secure_rtc_producers.find(interest.getName()); + if (it != map_secure_rtc_producers.end()) return; + TLSRTCProducerSocket *tls_producer = + new TLSRTCProducerSocket(this, interest.getName()); + tls_producer->on_content_produced_application_ = + this->on_content_produced_application_; + tls_producer->setSocketOption(CONTENT_OBJECT_EXPIRY_TIME, + this->content_object_expiry_time_); + tls_producer->setSocketOption(SIGNER, this->signer_); + tls_producer->setSocketOption(MAKE_MANIFEST, this->making_manifest_); + tls_producer->setSocketOption(DATA_PACKET_SIZE, + (uint32_t)(this->data_packet_size_)); + tls_producer->output_buffer_.setLimit(this->output_buffer_.getLimit()); + map_secure_rtc_producers.insert( + {interest.getName(), + std::unique_ptr<TLSRTCProducerSocket>(tls_producer)}); + tls_producer->onInterest(*tls_producer, interest); + tls_producer->async_accept(); + } +} + +void P2PSecureProducerSocket::produce(const uint8_t *buffer, + size_t buffer_size) { + if (!rtc_) { + throw errors::RuntimeException( + "RTC must be the transport protocol to start the production of current " + "data. Aborting."); + } + + std::unique_lock<std::mutex> lck(mtx_); + if (list_secure_rtc_producers.empty()) cv_.wait(lck); + + for (auto it = list_secure_rtc_producers.cbegin(); + it != list_secure_rtc_producers.cend(); it++) { + (*it)->produce(utils::MemBuf::copyBuffer(buffer, buffer_size)); + } +} + +uint32_t P2PSecureProducerSocket::produce( + Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, bool is_last, + uint32_t start_offset) { + if (rtc_) { + throw errors::RuntimeException( + "RTC transport protocol is not compatible with the production of " + "current data. Aborting."); + } + + std::unique_lock<std::mutex> lck(mtx_); + uint32_t segments = 0; + if (list_secure_producers.empty()) cv_.wait(lck); + + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + segments += + (*it)->produce(content_name, buffer->clone(), is_last, start_offset); + return segments; +} + +uint32_t P2PSecureProducerSocket::produce(Name content_name, + const uint8_t *buffer, + size_t buffer_size, bool is_last, + uint32_t start_offset) { + if (rtc_) { + throw errors::RuntimeException( + "RTC transport protocol is not compatible with the production of " + "current data. Aborting."); + } + + std::unique_lock<std::mutex> lck(mtx_); + uint32_t segments = 0; + if (list_secure_producers.empty()) cv_.wait(lck); + + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + segments += (*it)->produce(content_name, buffer, buffer_size, is_last, + start_offset); + return segments; +} + +void P2PSecureProducerSocket::asyncProduce(const Name &content_name, + const uint8_t *buf, + size_t buffer_size, bool is_last, + uint32_t *start_offset) { + if (rtc_) { + throw errors::RuntimeException( + "RTC transport protocol is not compatible with the production of " + "current data. Aborting."); + } + + std::unique_lock<std::mutex> lck(mtx_); + if (list_secure_producers.empty()) cv_.wait(lck); + + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) { + (*it)->asyncProduce(content_name, buf, buffer_size, is_last, start_offset); + } +} + +void P2PSecureProducerSocket::asyncProduce( + Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, bool is_last, + uint32_t offset, uint32_t **last_segment) { + if (rtc_) { + throw errors::RuntimeException( + "RTC transport protocol is not compatible with the production of " + "current data. Aborting."); + } + + std::unique_lock<std::mutex> lck(mtx_); + if (list_secure_producers.empty()) cv_.wait(lck); + + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) { + (*it)->asyncProduce(content_name, buffer->clone(), is_last, offset, + last_segment); + } +} + +// Socket Option Redefinition to avoid name hiding + +int P2PSecureProducerSocket::setSocketOption( + int socket_option_key, ProducerInterestCallback socket_option_value) { + if (!list_secure_producers.empty()) { + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + } + + switch (socket_option_key) { + case ProducerCallbacksOptions::INTEREST_INPUT: + on_interest_input_decrypted_ = socket_option_value; + return SOCKET_OPTION_SET; + + case ProducerCallbacksOptions::INTEREST_DROP: + on_interest_dropped_input_buffer_ = socket_option_value; + return SOCKET_OPTION_SET; + + case ProducerCallbacksOptions::INTEREST_PASS: + on_interest_inserted_input_buffer_ = socket_option_value; + return SOCKET_OPTION_SET; + + case ProducerCallbacksOptions::CACHE_HIT: + on_interest_satisfied_output_buffer_ = socket_option_value; + return SOCKET_OPTION_SET; + + case ProducerCallbacksOptions::CACHE_MISS: + on_interest_process_decrypted_ = socket_option_value; + return SOCKET_OPTION_SET; + + default: + return SOCKET_OPTION_NOT_SET; + } +} + +int P2PSecureProducerSocket::setSocketOption( + int socket_option_key, + const std::shared_ptr<utils::Signer> &socket_option_value) { + if (!list_secure_producers.empty()) + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + + switch (socket_option_key) { + case GeneralTransportOptions::SIGNER: { + signer_.reset(); + signer_ = socket_option_value; + + return SOCKET_OPTION_SET; + } + default: + return SOCKET_OPTION_NOT_SET; + } +} + +int P2PSecureProducerSocket::setSocketOption(int socket_option_key, + uint32_t socket_option_value) { + if (!list_secure_producers.empty()) { + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + } + switch (socket_option_key) { + case GeneralTransportOptions::CONTENT_OBJECT_EXPIRY_TIME: + content_object_expiry_time_ = + socket_option_value; // HICN_HANDSHAKE_CONTENT_EXPIRY_TIME; + return SOCKET_OPTION_SET; + } + return ProducerSocket::setSocketOption(socket_option_key, + socket_option_value); +} + +int P2PSecureProducerSocket::setSocketOption(int socket_option_key, + bool socket_option_value) { + if (!list_secure_producers.empty()) + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + + return ProducerSocket::setSocketOption(socket_option_key, + socket_option_value); +} + +int P2PSecureProducerSocket::setSocketOption(int socket_option_key, + Name *socket_option_value) { + if (!list_secure_producers.empty()) + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + + return ProducerSocket::setSocketOption(socket_option_key, + socket_option_value); +} + +int P2PSecureProducerSocket::setSocketOption( + int socket_option_key, std::list<Prefix> socket_option_value) { + if (!list_secure_producers.empty()) + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + + return ProducerSocket::setSocketOption(socket_option_key, + socket_option_value); +} + +int P2PSecureProducerSocket::setSocketOption( + int socket_option_key, ProducerContentObjectCallback socket_option_value) { + if (!list_secure_producers.empty()) + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + + return ProducerSocket::setSocketOption(socket_option_key, + socket_option_value); +} + +int P2PSecureProducerSocket::setSocketOption( + int socket_option_key, ProducerContentCallback socket_option_value) { + if (!list_secure_producers.empty()) + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + + switch (socket_option_key) { + case ProducerCallbacksOptions::CONTENT_PRODUCED: + on_content_produced_application_ = socket_option_value; + break; + + default: + return SOCKET_OPTION_NOT_SET; + } + + return SOCKET_OPTION_SET; +} + +int P2PSecureProducerSocket::setSocketOption( + int socket_option_key, HashAlgorithm socket_option_value) { + if (!list_secure_producers.empty()) + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + + return ProducerSocket::setSocketOption(socket_option_key, + socket_option_value); +} + +int P2PSecureProducerSocket::setSocketOption( + int socket_option_key, utils::CryptoSuite socket_option_value) { + if (!list_secure_producers.empty()) + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + + return ProducerSocket::setSocketOption(socket_option_key, + socket_option_value); +} + +int P2PSecureProducerSocket::setSocketOption( + int socket_option_key, const std::string &socket_option_value) { + if (!list_secure_producers.empty()) + for (auto it = list_secure_producers.cbegin(); + it != list_secure_producers.cend(); it++) + (*it)->setSocketOption(socket_option_key, socket_option_value); + + return ProducerSocket::setSocketOption(socket_option_key, + socket_option_value); +} + +} // namespace interface + +} // namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/p2psecure_socket_producer.h b/libtransport/src/hicn/transport/interfaces/p2psecure_socket_producer.h new file mode 100644 index 000000000..ba3fa0189 --- /dev/null +++ b/libtransport/src/hicn/transport/interfaces/p2psecure_socket_producer.h @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/socket_producer.h> +#include <hicn/transport/interfaces/tls_socket_producer.h> +#include <hicn/transport/interfaces/tls_rtc_socket_producer.h> +#include <hicn/transport/utils/content_store.h> +#include <hicn/transport/utils/identity.h> +#include <hicn/transport/utils/signer.h> + +#include <openssl/ssl.h> +#include <condition_variable> +#include <forward_list> +#include <mutex> + +namespace transport { + +namespace interface { + +class P2PSecureProducerSocket : public ProducerSocket { + friend class TLSProducerSocket; + friend class TLSRTCProducerSocket; + + public: + explicit P2PSecureProducerSocket(); + explicit P2PSecureProducerSocket( + bool rtc, const std::shared_ptr<utils::Identity> &identity); + ~P2PSecureProducerSocket(); + + void produce(const uint8_t *buffer, size_t buffer_size) override; + + uint32_t produce(Name content_name, const uint8_t *buffer, size_t buffer_size, + bool is_last = true, uint32_t start_offset = 0) override; + + uint32_t produce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, uint32_t start_offset = 0) override; + + void asyncProduce(const Name &suffix, const uint8_t *buf, size_t buffer_size, + bool is_last = true, + uint32_t *start_offset = nullptr) override; + + void asyncProduce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t offset, + uint32_t **last_segment = nullptr) override; + + int setSocketOption(int socket_option_key, + ProducerInterestCallback socket_option_value) override; + + int setSocketOption( + int socket_option_key, + const std::shared_ptr<utils::Signer> &socket_option_value) override; + + int setSocketOption(int socket_option_key, + uint32_t socket_option_value) override; + + int setSocketOption(int socket_option_key, bool socket_option_value) override; + + int setSocketOption(int socket_option_key, + Name *socket_option_value) override; + + int setSocketOption(int socket_option_key, + std::list<Prefix> socket_option_value) override; + + int setSocketOption( + int socket_option_key, + ProducerContentObjectCallback socket_option_value) override; + + int setSocketOption(int socket_option_key, + ProducerContentCallback socket_option_value) override; + + int setSocketOption(int socket_option_key, + HashAlgorithm socket_option_value) override; + + int setSocketOption(int socket_option_key, + utils::CryptoSuite socket_option_value) override; + + int setSocketOption(int socket_option_key, + const std::string &socket_option_value) override; + + using ProducerSocket::getSocketOption; + using ProducerSocket::onInterest; + + protected: + bool rtc_; + /* Callback invoked once an interest has been received and its payload + * decrypted */ + ProducerInterestCallback on_interest_input_decrypted_; + ProducerInterestCallback on_interest_process_decrypted_; + ProducerContentCallback on_content_produced_application_; + + private: + std::mutex mtx_; + + /* Condition variable for the wait */ + std::condition_variable cv_; + + PARCBuffer *der_cert_; + PARCBuffer *der_prk_; + X509 *cert_509_; + EVP_PKEY *pkey_rsa_; + std::unordered_map<core::Name, std::unique_ptr<TLSProducerSocket>, + core::hash<core::Name>, core::compare2<core::Name>> + map_secure_producers; + std::unordered_map<core::Name, std::unique_ptr<TLSRTCProducerSocket>, + core::hash<core::Name>, core::compare2<core::Name>> + map_secure_rtc_producers; + std::list<std::unique_ptr<TLSProducerSocket>> list_secure_producers; + std::list<std::unique_ptr<TLSRTCProducerSocket>> list_secure_rtc_producers; + + void onInterestCallback(ProducerSocket &p, Interest &interest); +}; + +} // namespace interface + +} // namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/rtc_socket_producer.cc b/libtransport/src/hicn/transport/interfaces/rtc_socket_producer.cc index bb93e0535..fefa419a9 100644 --- a/libtransport/src/hicn/transport/interfaces/rtc_socket_producer.cc +++ b/libtransport/src/hicn/transport/interfaces/rtc_socket_producer.cc @@ -171,6 +171,7 @@ void RTCProducerSocket::produce(std::unique_ptr<utils::MemBuf> &&buffer) { on_content_object_in_output_buffer_(*this, *content_object); } + TRANSPORT_LOGD("Send content %u (produce)", content_object->getName().getSuffix()); portal_->sendContentObject(*content_object); if (on_content_object_output_) { @@ -222,6 +223,7 @@ void RTCProducerSocket::onInterest(Interest::Ptr &&interest) { on_content_object_output_(*this, *content_object); } + TRANSPORT_LOGD("Send content %u (onInterest)", content_object->getName().getSuffix()); portal_->sendContentObject(*content_object); return; } else { @@ -357,6 +359,7 @@ void RTCProducerSocket::sendNack(uint32_t sequence) { on_content_object_output_(*this, nack); } + TRANSPORT_LOGD("Send nack %u", sequence); portal_->sendContentObject(nack); } diff --git a/libtransport/src/hicn/transport/interfaces/rtc_socket_producer.h b/libtransport/src/hicn/transport/interfaces/rtc_socket_producer.h index 37ba88d8a..d7917a8c0 100644 --- a/libtransport/src/hicn/transport/interfaces/rtc_socket_producer.h +++ b/libtransport/src/hicn/transport/interfaces/rtc_socket_producer.h @@ -26,7 +26,7 @@ namespace transport { namespace interface { -class RTCProducerSocket : public ProducerSocket { +class RTCProducerSocket : virtual public ProducerSocket { public: RTCProducerSocket(asio::io_service &io_service); @@ -36,7 +36,7 @@ class RTCProducerSocket : public ProducerSocket { void registerPrefix(const Prefix &producer_namespace) override; - void produce(std::unique_ptr<utils::MemBuf> &&buffer) override; + virtual void produce(std::unique_ptr<utils::MemBuf> &&buffer) override; void onInterest(Interest::Ptr &&interest) override; diff --git a/libtransport/src/hicn/transport/interfaces/socket_consumer.cc b/libtransport/src/hicn/transport/interfaces/socket_consumer.cc index fba972fe5..b2c054947 100644 --- a/libtransport/src/hicn/transport/interfaces/socket_consumer.cc +++ b/libtransport/src/hicn/transport/interfaces/socket_consumer.cc @@ -48,6 +48,7 @@ ConsumerSocket::ConsumerSocket(int protocol, asio::io_service &io_service) rate_estimation_choice_(0), verifier_(std::make_shared<utils::Verifier>()), verify_signature_(false), + key_content_(false), on_interest_output_(VOID_HANDLER), on_interest_timeout_(VOID_HANDLER), on_interest_satisfied_(VOID_HANDLER), @@ -106,9 +107,13 @@ int ConsumerSocket::asyncConsume(const Name &name) { return CONSUMER_RUNNING; } +bool ConsumerSocket::verifyKeyPackets() { + return transport_protocol_->verifyKeyPackets(); +} + void ConsumerSocket::stop() { - if (transport_protocol_->isRunning()) { - transport_protocol_->stop(); + if (transport_protocol_) { + if (transport_protocol_->isRunning()) transport_protocol_->stop(); } } @@ -312,6 +317,11 @@ int ConsumerSocket::setSocketOption(int socket_option_key, result = SOCKET_OPTION_SET; break; + case GeneralTransportOptions::KEY_CONTENT: + key_content_ = socket_option_value; + result = SOCKET_OPTION_SET; + break; + default: return result; } @@ -461,6 +471,7 @@ int ConsumerSocket::setSocketOption( if (!transport_protocol_->isRunning()) { switch (socket_option_key) { case GeneralTransportOptions::VERIFIER: + verifier_.reset(); verifier_ = socket_option_value; result = SOCKET_OPTION_SET; break; @@ -479,10 +490,7 @@ int ConsumerSocket::setSocketOption(int socket_option_key, switch (socket_option_key) { case GeneralTransportOptions::CERTIFICATE: key_id_ = verifier_->addKeyFromCertificate(socket_option_value); - - if (key_id_ != nullptr) { - result = SOCKET_OPTION_SET; - } + if (key_id_ != nullptr) result = SOCKET_OPTION_SET; break; case DataLinkOptions::OUTPUT_INTERFACE: @@ -614,6 +622,10 @@ int ConsumerSocket::getSocketOption(int socket_option_key, socket_option_value = verify_signature_; break; + case GeneralTransportOptions::KEY_CONTENT: + socket_option_value = key_content_; + break; + default: return SOCKET_OPTION_NOT_GET; } diff --git a/libtransport/src/hicn/transport/interfaces/socket_consumer.h b/libtransport/src/hicn/transport/interfaces/socket_consumer.h index acce28c1d..48a594adf 100644 --- a/libtransport/src/hicn/transport/interfaces/socket_consumer.h +++ b/libtransport/src/hicn/transport/interfaces/socket_consumer.h @@ -132,6 +132,8 @@ class ConsumerSocket : public BaseSocket { * the application when the transfer is done. */ virtual void readSuccess(std::size_t total_size) noexcept = 0; + + virtual void afterRead() {} }; /** @@ -181,8 +183,16 @@ class ConsumerSocket : public BaseSocket { * content retrieval succeeded. This information can be obtained from the * error code in CONTENT_RETRIEVED callback. */ - int consume(const Name &name); - int asyncConsume(const Name &name); + virtual int consume(const Name &name); + virtual int asyncConsume(const Name &name); + + /** + * Verify the packets containing a key after the origin of the key has been + * validated by the client. + * + * @return true if all packets are valid, false otherwise + */ + virtual bool verifyKeyPackets(); /** * Stops the consumer socket. If several downloads are queued (using @@ -330,16 +340,14 @@ class ConsumerSocket : public BaseSocket { return result; } - private: + // context inner state variables asio::io_service internal_io_service_; asio::io_service &io_service_; std::shared_ptr<Portal> portal_; + utils::EventThread async_downloader_; - // No need to protect from multiple accesses in the async consumer - // The parameter is accessible only with a getSocketOption and - // set from the consume Name network_name_; int interest_lifetime_; @@ -362,17 +370,22 @@ class ConsumerSocket : public BaseSocket { int rate_estimation_batching_parameter_; int rate_estimation_choice_; + bool is_async_; + // Verification parameters std::shared_ptr<utils::Verifier> verifier_; PARCKeyId *key_id_; std::atomic_bool verify_signature_; + std::atomic_bool key_content_; ConsumerInterestCallback on_interest_retransmission_; ConsumerInterestCallback on_interest_output_; ConsumerInterestCallback on_interest_timeout_; ConsumerInterestCallback on_interest_satisfied_; + ConsumerContentObjectCallback on_content_object_input_; ConsumerContentObjectVerificationCallback on_content_object_verification_; + ConsumerContentObjectCallback on_content_object_; ConsumerManifestCallback on_manifest_; ConsumerTimerCallback stats_summary_; @@ -396,4 +409,4 @@ class ConsumerSocket : public BaseSocket { } // namespace interface -} // end namespace transport
\ No newline at end of file +} // end namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/socket_options_keys.h b/libtransport/src/hicn/transport/interfaces/socket_options_keys.h index b25bacbb9..a38131271 100644 --- a/libtransport/src/hicn/transport/interfaces/socket_options_keys.h +++ b/libtransport/src/hicn/transport/interfaces/socket_options_keys.h @@ -34,7 +34,7 @@ typedef enum { DATA_PACKET_SIZE = 106, INTEREST_LIFETIME = 107, CONTENT_OBJECT_EXPIRY_TIME = 108, - KEY_LOCATOR = 110, + KEY_CONTENT = 110, MIN_WINDOW_SIZE = 111, MAX_WINDOW_SIZE = 112, CURRENT_WINDOW_SIZE = 113, @@ -45,11 +45,11 @@ typedef enum { APPLICATION_BUFFER = 118, HASH_ALGORITHM = 119, CRYPTO_SUITE = 120, - IDENTITY = 121, + SIGNER = 121, VERIFIER = 122, CERTIFICATE = 123, VERIFY_SIGNATURE = 124, - STATS_INTERVAL = 125 + STATS_INTERVAL = 125, } GeneralTransportOptions; typedef enum { @@ -92,7 +92,7 @@ typedef enum { CONTENT_OBJECT_SIGN = 513, CONTENT_OBJECT_READY = 510, CONTENT_OBJECT_OUTPUT = 511, - CONTENT_PRODUCED = 512 + CONTENT_PRODUCED = 512, } ProducerCallbacksOptions; typedef enum { OUTPUT_INTERFACE = 601 } DataLinkOptions; @@ -110,4 +110,4 @@ typedef enum { } // namespace interface -} // end namespace transport
\ No newline at end of file +} // end namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/socket_producer.cc b/libtransport/src/hicn/transport/interfaces/socket_producer.cc index 4fef5d1e2..26a7208b6 100644 --- a/libtransport/src/hicn/transport/interfaces/socket_producer.cc +++ b/libtransport/src/hicn/transport/interfaces/socket_producer.cc @@ -14,7 +14,7 @@ */ #include <hicn/transport/interfaces/socket_producer.h> -#include <hicn/transport/utils/identity.h> +#include <hicn/transport/utils/signer.h> #include <functional> @@ -35,6 +35,7 @@ ProducerSocket::ProducerSocket(asio::io_service &io_service) data_packet_size_(default_values::content_object_packet_size), content_object_expiry_time_(default_values::content_object_expiry_time), output_buffer_(default_values::producer_socket_output_buffer_size), + async_thread_(), registration_status_(REGISTRATION_NOT_ATTEMPTED), making_manifest_(false), hash_algorithm_(HashAlgorithm::SHA_256), @@ -96,51 +97,47 @@ void ProducerSocket::listen() { void ProducerSocket::passContentObjectToCallbacks( const std::shared_ptr<ContentObject> &content_object) { if (content_object) { - if (on_new_segment_) { - io_service_.dispatch([this, content_object]() { + io_service_.dispatch([this, content_object]() { + if (on_new_segment_) { on_new_segment_(*this, *content_object); - }); - } + } - if (on_content_object_to_sign_) { - io_service_.dispatch([this, content_object]() { + if (on_content_object_to_sign_) { on_content_object_to_sign_(*this, *content_object); - }); - } + } - if (on_content_object_in_output_buffer_) { - io_service_.dispatch([this, content_object]() { + if (on_content_object_in_output_buffer_) { on_content_object_in_output_buffer_(*this, *content_object); - }); - } + } + }); output_buffer_.insert(content_object); - if (on_content_object_output_) { - io_service_.dispatch([this, content_object]() { + io_service_.dispatch([this, content_object]() { + if (on_content_object_output_) { on_content_object_output_(*this, *content_object); - }); - } + } + }); portal_->sendContentObject(*content_object); } } void ProducerSocket::produce(ContentObject &content_object) { - if (on_content_object_in_output_buffer_) { - io_service_.dispatch([this, &content_object]() { + io_service_.dispatch([this, &content_object]() { + if (on_content_object_in_output_buffer_) { on_content_object_in_output_buffer_(*this, content_object); - }); - } + } + }); output_buffer_.insert(std::static_pointer_cast<ContentObject>( content_object.shared_from_this())); - if (on_content_object_output_) { - io_service_.dispatch([this, &content_object]() { + io_service_.dispatch([this, &content_object]() { + if (on_content_object_output_) { on_content_object_output_(*this, content_object); - }); - } + } + }); portal_->sendContentObject(content_object); } @@ -160,8 +157,8 @@ uint32_t ProducerSocket::produce(Name content_name, bool making_manifest = making_manifest_; auto suffix_strategy = utils::SuffixStrategyFactory::getSuffixStrategy( suffix_strategy_, start_offset); - std::shared_ptr<utils::Identity> identity; - getSocketOption(GeneralTransportOptions::IDENTITY, identity); + std::shared_ptr<utils::Signer> signer; + getSocketOption(GeneralTransportOptions::SIGNER, signer); auto buffer_size = buffer->length(); int bytes_segmented = 0; @@ -176,7 +173,7 @@ uint32_t ProducerSocket::produce(Name content_name, bool is_last_manifest = false; // TODO Manifest may still be used for indexing - if (making_manifest && !identity) { + if (making_manifest && !signer) { TRANSPORT_LOGD("Making manifests without setting producer identity."); } @@ -197,11 +194,11 @@ uint32_t ProducerSocket::produce(Name content_name, format = hf_format; if (making_manifest) { manifest_header_size = core::Packet::getHeaderSizeFromFormat( - identity ? hf_format_ah : hf_format, - identity ? identity->getSignatureLength() : 0); - } else if (identity) { + signer ? hf_format_ah : hf_format, + signer ? signer->getSignatureLength() : 0); + } else if (signer) { format = hf_format_ah; - signature_length = identity->getSignatureLength(); + signature_length = signer->getSignatureLength(); } header_size = core::Packet::getHeaderSizeFromFormat(format, signature_length); @@ -227,7 +224,7 @@ uint32_t ProducerSocket::produce(Name content_name, content_name.setSuffix(suffix_strategy->getNextManifestSuffix()), core::ManifestVersion::VERSION_1, core::ManifestType::INLINE_MANIFEST, hash_algo, is_last_manifest, content_name, suffix_strategy_, - identity ? identity->getSignatureLength() : 0)); + signer ? signer->getSignatureLength() : 0)); manifest->setLifetime(content_object_expiry_time); if (is_last) { @@ -237,6 +234,7 @@ uint32_t ProducerSocket::produce(Name content_name, } } + TRANSPORT_LOGD("--------- START PRODUCE ----------"); for (unsigned int packaged_segments = 0; packaged_segments < number_of_segments; packaged_segments++) { if (making_manifest) { @@ -246,15 +244,18 @@ uint32_t ProducerSocket::produce(Name content_name, manifest->encode(); // If identity set, sign manifest - if (identity) { - identity->getSigner().sign(*manifest); + if (signer) { + signer->sign(*manifest); } passContentObjectToCallbacks(manifest); + TRANSPORT_LOGD("Send manifest %u", manifest->getName().getSuffix()); // Send content objects stored in the queue while (!content_queue_.empty()) { passContentObjectToCallbacks(content_queue_.front()); + TRANSPORT_LOGD("Send content %u", + content_queue_.front()->getName().getSuffix()); content_queue_.pop(); } @@ -266,7 +267,7 @@ uint32_t ProducerSocket::produce(Name content_name, core::ManifestVersion::VERSION_1, core::ManifestType::INLINE_MANIFEST, hash_algo, is_last_manifest, content_name, suffix_strategy_, - identity ? identity->getSignatureLength() : 0)); + signer ? signer->getSignatureLength() : 0)); manifest->setLifetime(content_object_expiry_time); manifest->setFinalBlockNumber( @@ -307,10 +308,11 @@ uint32_t ProducerSocket::produce(Name content_name, manifest->addSuffixHash(content_suffix, hash); content_queue_.push(content_object); } else { - if (identity) { - identity->getSigner().sign(*content_object); + if (signer) { + signer->sign(*content_object); } passContentObjectToCallbacks(content_object); + TRANSPORT_LOGD("Send content %u", content_object->getName().getSuffix()); } } @@ -320,24 +322,28 @@ uint32_t ProducerSocket::produce(Name content_name, } manifest->encode(); - if (identity) { - identity->getSigner().sign(*manifest); + if (signer) { + signer->sign(*manifest); } passContentObjectToCallbacks(manifest); + TRANSPORT_LOGD("Send manifest %u", manifest->getName().getSuffix()); while (!content_queue_.empty()) { passContentObjectToCallbacks(content_queue_.front()); + TRANSPORT_LOGD("Send content %u", + content_queue_.front()->getName().getSuffix()); content_queue_.pop(); } } - if (on_content_produced_) { - io_service_.dispatch([this, buffer_size]() { + io_service_.dispatch([this, buffer_size]() { + if (on_content_produced_) { on_content_produced_(*this, std::make_error_code(std::errc(0)), buffer_size); - }); - } + } + }); + TRANSPORT_LOGD("--------- END PRODUCE ------------"); return suffix_strategy->getTotalCount(); } @@ -346,16 +352,42 @@ void ProducerSocket::asyncProduce(ContentObject &content_object) { auto co_ptr = std::static_pointer_cast<ContentObject>( content_object.shared_from_this()); async_thread_.add([this, content_object = std::move(co_ptr)]() { - produce(*content_object); + ProducerSocket::produce(*content_object); }); } } void ProducerSocket::asyncProduce(const Name &suffix, const uint8_t *buf, - size_t buffer_size) { + size_t buffer_size, bool is_last, + uint32_t *start_offset) { if (!async_thread_.stopped()) { - async_thread_.add([this, suffix, buffer = buf, size = buffer_size]() { - produce(suffix, buffer, size, 0, false); + async_thread_.add([this, suffix, buffer = buf, size = buffer_size, is_last, + start_offset]() { + if (start_offset != NULL) { + *start_offset = ProducerSocket::produce(suffix, buffer, size, is_last, + *start_offset); + } else { + ProducerSocket::produce(suffix, buffer, size, is_last, 0); + } + }); + } +} + +void ProducerSocket::asyncProduce(Name content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t offset, + uint32_t **last_segment) { + if (!async_thread_.stopped()) { + auto a = buffer.release(); + async_thread_.add([this, content_name, a, is_last, offset, last_segment]() { + auto buf = std::unique_ptr<utils::MemBuf>(a); + if (last_segment != NULL) { + **last_segment = + offset + ProducerSocket::produce(content_name, std::move(buf), + is_last, offset); + } else { + ProducerSocket::produce(content_name, std::move(buf), is_last, offset); + } }); } } @@ -392,8 +424,8 @@ int ProducerSocket::setSocketOption(int socket_option_key, if (socket_option_value < default_values::max_content_object_size && socket_option_value > 0) { data_packet_size_ = socket_option_value; - break; } + break; case GeneralTransportOptions::OUTPUT_BUFFER_SIZE: output_buffer_.setLimit(socket_option_value); @@ -632,12 +664,12 @@ int ProducerSocket::setSocketOption(int socket_option_key, int ProducerSocket::setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Identity> &socket_option_value) { + const std::shared_ptr<utils::Signer> &socket_option_value) { switch (socket_option_key) { - case GeneralTransportOptions::IDENTITY: { - utils::SpinLock::Acquire locked(identity_lock_); - identity_.reset(); - identity_ = socket_option_value; + case GeneralTransportOptions::SIGNER: { + utils::SpinLock::Acquire locked(signer_lock_); + signer_.reset(); + signer_ = socket_option_value; } break; default: return SOCKET_OPTION_NOT_SET; @@ -844,11 +876,11 @@ int ProducerSocket::getSocketOption(int socket_option_key, int ProducerSocket::getSocketOption( int socket_option_key, - std::shared_ptr<utils::Identity> &socket_option_value) { + std::shared_ptr<utils::Signer> &socket_option_value) { switch (socket_option_key) { - case GeneralTransportOptions::IDENTITY: { - utils::SpinLock::Acquire locked(identity_lock_); - socket_option_value = identity_; + case GeneralTransportOptions::SIGNER: { + utils::SpinLock::Acquire locked(signer_lock_); + socket_option_value = signer_; } break; default: return SOCKET_OPTION_NOT_GET; diff --git a/libtransport/src/hicn/transport/interfaces/socket_producer.h b/libtransport/src/hicn/transport/interfaces/socket_producer.h index ff6f49723..5f360f2ce 100644 --- a/libtransport/src/hicn/transport/interfaces/socket_producer.h +++ b/libtransport/src/hicn/transport/interfaces/socket_producer.h @@ -18,10 +18,12 @@ #include <hicn/transport/interfaces/socket.h> #include <hicn/transport/utils/content_store.h> #include <hicn/transport/utils/event_thread.h> +#include <hicn/transport/utils/signer.h> #include <hicn/transport/utils/suffix_strategy.h> #include <atomic> #include <cmath> +#include <condition_variable> #include <mutex> #include <queue> #include <thread> @@ -31,10 +33,6 @@ #define REGISTRATION_FAILURE 2 #define REGISTRATION_IN_PROGRESS 3 -namespace utils { -class Identity; -} - namespace transport { namespace interface { @@ -53,16 +51,19 @@ class ProducerSocket : public Socket<BasePortal>, bool isRunning() override { return !io_service_.stopped(); }; - uint32_t produce(Name content_name, const uint8_t *buffer, size_t buffer_size, - bool is_last = true, uint32_t start_offset = 0) { - return produce(content_name, utils::MemBuf::copyBuffer(buffer, buffer_size), - is_last, start_offset); + virtual uint32_t produce(Name content_name, const uint8_t *buffer, + size_t buffer_size, bool is_last = true, + uint32_t start_offset = 0) { + return ProducerSocket::produce( + content_name, utils::MemBuf::copyBuffer(buffer, buffer_size), is_last, + start_offset); } - uint32_t produce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, - bool is_last = true, uint32_t start_offset = 0); + virtual uint32_t produce(Name content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, uint32_t start_offset = 0); - void produce(ContentObject &content_object); + virtual void produce(ContentObject &content_object); virtual void produce(const uint8_t *buffer, size_t buffer_size) { produce(utils::MemBuf::copyBuffer(buffer, buffer_size)); @@ -74,11 +75,18 @@ class ProducerSocket : public Socket<BasePortal>, throw errors::NotImplementedException(); } - void asyncProduce(const Name &suffix, const uint8_t *buf, size_t buffer_size); + virtual void asyncProduce(const Name &suffix, const uint8_t *buf, + size_t buffer_size, bool is_last = true, + uint32_t *start_offset = nullptr); void asyncProduce(const Name &suffix); - void asyncProduce(ContentObject &content_object); + virtual void asyncProduce(Name content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t offset, + uint32_t **last_segment = nullptr); + + virtual void asyncProduce(ContentObject &content_object); virtual void registerPrefix(const Prefix &producer_namespace); @@ -124,7 +132,7 @@ class ProducerSocket : public Socket<BasePortal>, virtual int setSocketOption( int socket_option_key, - const std::shared_ptr<utils::Identity> &socket_option_value); + const std::shared_ptr<utils::Signer> &socket_option_value); virtual int setSocketOption(int socket_option_key, const std::string &socket_option_value); @@ -158,15 +166,77 @@ class ProducerSocket : public Socket<BasePortal>, virtual int getSocketOption( int socket_option_key, - std::shared_ptr<utils::Identity> &socket_option_value); + std::shared_ptr<utils::Signer> &socket_option_value); virtual int getSocketOption(int socket_option_key, std::string &socket_option_value); - protected: + // If the thread calling lambda_func is not the same of io_service, this + // function reschedule the function on it + template <typename Lambda, typename arg2> + int rescheduleOnIOService(int socket_option_key, arg2 socket_option_value, + Lambda lambda_func) { + // To enforce type check + std::function<int(int, arg2)> func = lambda_func; + int result = SOCKET_OPTION_SET; + if (listening_thread_.joinable() && + std::this_thread::get_id() != listening_thread_.get_id()) { + std::mutex mtx; + /* Condition variable for the wait */ + std::condition_variable cv; + bool done = false; + io_service_.dispatch([&socket_option_key, &socket_option_value, &mtx, &cv, + &result, &done, &func]() { + std::unique_lock<std::mutex> lck(mtx); + done = true; + result = func(socket_option_key, socket_option_value); + cv.notify_all(); + }); + std::unique_lock<std::mutex> lck(mtx); + if (!done) { + cv.wait(lck); + } + } else { + result = func(socket_option_key, socket_option_value); + } + + return result; + } + + template <typename Lambda, typename arg2> + int rescheduleOnIOServiceWithReference(int socket_option_key, + arg2 &socket_option_value, + Lambda lambda_func) { + // To enforce type check + std::function<int(int, arg2 &)> func = lambda_func; + int result = SOCKET_OPTION_SET; + if (listening_thread_.joinable() && + std::this_thread::get_id() != this->listening_thread_.get_id()) { + std::mutex mtx; + /* Condition variable for the wait */ + std::condition_variable cv; + + bool done = false; + io_service_.dispatch([&socket_option_key, &socket_option_value, &mtx, &cv, + &result, &done, &func]() { + std::unique_lock<std::mutex> lck(mtx); + done = true; + result = func(socket_option_key, socket_option_value); + cv.notify_all(); + }); + std::unique_lock<std::mutex> lck(mtx); + if (!done) { + cv.wait(lck); + } + } else { + result = func(socket_option_key, socket_option_value); + } + + return result; + } // Threads + protected: std::thread listening_thread_; - asio::io_service internal_io_service_; asio::io_service &io_service_; std::shared_ptr<Portal> portal_; @@ -191,8 +261,8 @@ class ProducerSocket : public Socket<BasePortal>, std::atomic<HashAlgorithm> hash_algorithm_; std::atomic<utils::CryptoSuite> crypto_suite_; - utils::SpinLock identity_lock_; - std::shared_ptr<utils::Identity> identity_; + utils::SpinLock signer_lock_; + std::shared_ptr<utils::Signer> signer_; core::NextSegmentCalculationStrategy suffix_strategy_; // While manifests are being built, contents are stored in a queue @@ -213,38 +283,6 @@ class ProducerSocket : public Socket<BasePortal>, ProducerContentCallback on_content_produced_; - // If the thread calling lambda_func is not the same of io_service, this - // function reschedule the function on it - template <typename Lambda, typename arg2> - int rescheduleOnIOService(int socket_option_key, arg2 socket_option_value, - Lambda lambda_func) { - // To enforce type check - std::function<int(int, arg2)> func = lambda_func; - int result = SOCKET_OPTION_SET; - if (listening_thread_.joinable() && - std::this_thread::get_id() != listening_thread_.get_id()) { - std::mutex mtx; - /* Condition variable for the wait */ - std::condition_variable cv; - bool done = false; - io_service_.dispatch([&socket_option_key, &socket_option_value, &mtx, &cv, - &result, &done, &func]() { - std::unique_lock<std::mutex> lck(mtx); - done = true; - result = func(socket_option_key, socket_option_value); - cv.notify_all(); - }); - std::unique_lock<std::mutex> lck(mtx); - if (!done) { - cv.wait(lck); - } - } else { - result = func(socket_option_key, socket_option_value); - } - - return result; - } - private: void listen(); diff --git a/libtransport/src/hicn/transport/interfaces/tls_rtc_socket_producer.cc b/libtransport/src/hicn/transport/interfaces/tls_rtc_socket_producer.cc new file mode 100644 index 000000000..27c1e54bd --- /dev/null +++ b/libtransport/src/hicn/transport/interfaces/tls_rtc_socket_producer.cc @@ -0,0 +1,178 @@ +#include <hicn/transport/core/interest.h> +#include <hicn/transport/interfaces/p2psecure_socket_producer.h> +#include <hicn/transport/interfaces/tls_rtc_socket_producer.h> + +#include <openssl/bio.h> +#include <openssl/rand.h> +#include <openssl/ssl.h> + +namespace transport { + +namespace interface { +int TLSRTCProducerSocket::read(BIO *b, char *buf, size_t size, + size_t *readbytes) { + int ret; + + if (size > INT_MAX) size = INT_MAX; + + ret = TLSRTCProducerSocket::readOld(b, buf, (int)size); + + if (ret <= 0) { + *readbytes = 0; + return ret; + } + + *readbytes = (size_t)ret; + + return 1; +} + +int TLSRTCProducerSocket::readOld(BIO *b, char *buf, int size) { + TLSRTCProducerSocket *socket; + socket = (TLSRTCProducerSocket *)BIO_get_data(b); + + std::unique_lock<std::mutex> lck(socket->mtx_); + if (!socket->something_to_read_) { + (socket->cv_).wait(lck); + } + + utils::MemBuf *membuf = socket->packet_->next(); + int size_to_read; + + if ((int)membuf->length() > size) { + size_to_read = size; + } else { + size_to_read = membuf->length(); + socket->something_to_read_ = false; + } + + std::memcpy(buf, membuf->data(), size_to_read); + membuf->trimStart(size_to_read); + + return size_to_read; +} + +int TLSRTCProducerSocket::write(BIO *b, const char *buf, size_t size, + size_t *written) { + int ret; + + if (size > INT_MAX) size = INT_MAX; + + ret = TLSRTCProducerSocket::writeOld(b, buf, (int)size); + + if (ret <= 0) { + *written = 0; + return ret; + } + + *written = (size_t)ret; + + return 1; +} + +int TLSRTCProducerSocket::writeOld(BIO *b, const char *buf, int num) { + TLSRTCProducerSocket *socket; + socket = (TLSRTCProducerSocket *)BIO_get_data(b); + + if ((SSL_in_before(socket->ssl_) || SSL_in_init(socket->ssl_)) && + socket->first_) { + socket->tls_chunks_--; + bool making_manifest = socket->parent_->making_manifest_; + socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, + false); + socket->parent_->ProducerSocket::produce( + socket->name_, (const uint8_t *)buf, num, socket->tls_chunks_ == 0, 0); + socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, + making_manifest); + socket->first_ = false; + + } else { + std::unique_ptr<utils::MemBuf> mbuf = + utils::MemBuf::copyBuffer(buf, (std::size_t)num, 0, 0); + auto a = mbuf.release(); + socket->async_thread_.add([socket = socket, a]() { + socket->to_call_oncontentproduced_--; + auto mbuf = std::unique_ptr<utils::MemBuf>(a); + socket->RTCProducerSocket::produce(std::move(mbuf)); + ProducerContentCallback on_content_produced_application; + socket->getSocketOption(ProducerCallbacksOptions::CONTENT_PRODUCED, + on_content_produced_application); + if (socket->to_call_oncontentproduced_ == 0 && + on_content_produced_application) { + on_content_produced_application(*socket, std::error_code(), 0); + } + }); + } + + return num; +} + +TLSRTCProducerSocket::TLSRTCProducerSocket(P2PSecureProducerSocket *parent, + const Name &handshake_name) + : RTCProducerSocket(), TLSProducerSocket(parent, handshake_name) { + BIO_METHOD *bio_meth = + BIO_meth_new(BIO_TYPE_ACCEPT, "secure rtc producer socket"); + BIO_meth_set_read(bio_meth, TLSRTCProducerSocket::readOld); + BIO_meth_set_write(bio_meth, TLSRTCProducerSocket::writeOld); + BIO_meth_set_ctrl(bio_meth, TLSProducerSocket::ctrl); + BIO *bio = BIO_new(bio_meth); + BIO_set_init(bio, 1); + BIO_set_data(bio, this); + SSL_set_bio(ssl_, bio, bio); +} + +void TLSRTCProducerSocket::accept() { + if (SSL_in_before(ssl_) || SSL_in_init(ssl_)) { + tls_chunks_ = 1; + int result = SSL_accept(ssl_); + if (result != 1) + throw errors::RuntimeException("Unable to perform client handshake"); + } + + TRANSPORT_LOGD("Handshake performed!"); + parent_->list_secure_rtc_producers.push_front( + std::move(parent_->map_secure_rtc_producers[handshake_name_])); + parent_->map_secure_rtc_producers.erase(handshake_name_); + + ProducerInterestCallback on_interest_process_decrypted; + getSocketOption(ProducerCallbacksOptions::CACHE_MISS, + on_interest_process_decrypted); + + if (on_interest_process_decrypted) { + Interest inter(std::move(packet_)); + on_interest_process_decrypted(*this, inter); + } + + parent_->cv_.notify_one(); +} + +int TLSRTCProducerSocket::async_accept() { + if (!async_thread_.stopped()) { + async_thread_.add([this]() { this->TLSRTCProducerSocket::accept(); }); + } else { + throw errors::RuntimeException( + "Async thread not running, impossible to perform handshake"); + } + + return 1; +} + +void TLSRTCProducerSocket::produce(std::unique_ptr<utils::MemBuf> &&buffer) { + if (SSL_in_before(ssl_) || SSL_in_init(ssl_)) { + throw errors::RuntimeException( + "New handshake on the same P2P secure producer socket not supported"); + } + + size_t buf_size = buffer->length(); + tls_chunks_ = ceil((float)buf_size / (float)SSL3_RT_MAX_PLAIN_LENGTH); + to_call_oncontentproduced_ = tls_chunks_; + + SSL_write(ssl_, buffer->data(), buf_size); + BIO *wbio = SSL_get_wbio(ssl_); + int i = BIO_flush(wbio); + (void)i; // To shut up gcc 5 +} + +} // namespace interface + +} // namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/tls_rtc_socket_producer.h b/libtransport/src/hicn/transport/interfaces/tls_rtc_socket_producer.h new file mode 100644 index 000000000..16125f889 --- /dev/null +++ b/libtransport/src/hicn/transport/interfaces/tls_rtc_socket_producer.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/rtc_socket_producer.h> +#include <hicn/transport/interfaces/tls_socket_producer.h> + +namespace transport { + +namespace interface { + +class P2PSecureProducerSocket; + +class TLSRTCProducerSocket : public RTCProducerSocket, + public TLSProducerSocket { + friend class P2PSecureProducerSocket; + + public: + explicit TLSRTCProducerSocket(P2PSecureProducerSocket *parent, + const Name &handshake_name); + + ~TLSRTCProducerSocket() = default; + + void produce(std::unique_ptr<utils::MemBuf> &&buffer) override; + + void accept() override; + + int async_accept() override; + + using TLSProducerSocket::produce; + using TLSProducerSocket::onInterest; + + protected: + static int read(BIO *b, char *buf, size_t size, size_t *readbytes); + + static int readOld(BIO *h, char *buf, int size); + + static int write(BIO *b, const char *buf, size_t size, size_t *written); + + static int writeOld(BIO *h, const char *buf, int num); +}; + +} // namespace interface + +} // end namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/tls_socket_consumer.cc b/libtransport/src/hicn/transport/interfaces/tls_socket_consumer.cc new file mode 100644 index 000000000..58b3c1d7d --- /dev/null +++ b/libtransport/src/hicn/transport/interfaces/tls_socket_consumer.cc @@ -0,0 +1,364 @@ +#include <hicn/transport/interfaces/tls_socket_consumer.h> +#include <openssl/bio.h> +#include <openssl/ssl.h> +#include <openssl/tls1.h> + +#include <random> + +namespace transport { + +namespace interface { + +void TLSConsumerSocket::setInterestPayload(ConsumerSocket &c, + const core::Interest &interest) { + Interest &int2 = const_cast<Interest &>(interest); + random_suffix_ = int2.getName().getSuffix(); + + if (payload_ != NULL) int2.appendPayload(std::move(payload_)); +} + +/* Return the number of read bytes in the return param */ +int readOldTLS(BIO *b, char *buf, int size) { + if (size < 0) return size; + + TLSConsumerSocket *socket; + socket = (TLSConsumerSocket *)BIO_get_data(b); + + std::unique_lock<std::mutex> lck(socket->mtx_); + + if (!socket->something_to_read_) { + if (!socket->transport_protocol_->isRunning()) { + socket->network_name_.setSuffix(socket->random_suffix_); + socket->ConsumerSocket::asyncConsume(socket->network_name_); + } + if (!socket->something_to_read_) socket->cv_.wait(lck); + } + + size_t size_to_read, read; + size_t chain_size = socket->head_->length(); + if (socket->head_->isChained()) + chain_size = socket->head_->computeChainDataLength(); + + if (chain_size > (size_t)size) { + read = size_to_read = (size_t)size; + } else { + read = size_to_read = chain_size; + socket->something_to_read_ = false; + } + + while (size_to_read) { + if (socket->head_->length() < size_to_read) { + std::memcpy(buf, socket->head_->data(), socket->head_->length()); + size_to_read -= socket->head_->length(); + buf += socket->head_->length(); + socket->head_ = socket->head_->pop(); + } else { + std::memcpy(buf, socket->head_->data(), size_to_read); + socket->head_->trimStart(size_to_read); + size_to_read = 0; + } + } + + return read; +} + +/* Return the number of read bytes in readbytes */ +int readTLS(BIO *b, char *buf, size_t size, size_t *readbytes) { + int ret; + + if (size > INT_MAX) size = INT_MAX; + + ret = transport::interface::readOldTLS(b, buf, (int)size); + + if (ret <= 0) { + *readbytes = 0; + return ret; + } + + *readbytes = (size_t)ret; + + return 1; +} + +/* Return the number of written bytes in the return param */ +int writeOldTLS(BIO *b, const char *buf, int num) { + TLSConsumerSocket *socket; + socket = (TLSConsumerSocket *)BIO_get_data(b); + + socket->payload_ = utils::MemBuf::copyBuffer(buf, num); + socket->ConsumerSocket::setSocketOption( + ConsumerCallbacksOptions::INTEREST_OUTPUT, + (ConsumerInterestCallback)std::bind( + &TLSConsumerSocket::setInterestPayload, socket, std::placeholders::_1, + std::placeholders::_2)); + + return num; +} + +/* Return the number of written bytes in written */ +int writeTLS(BIO *b, const char *buf, size_t size, size_t *written) { + int ret; + + if (size > INT_MAX) size = INT_MAX; + + ret = transport::interface::writeOldTLS(b, buf, (int)size); + + if (ret <= 0) { + *written = 0; + return ret; + } + + *written = (size_t)ret; + + return 1; +} + +long ctrlTLS(BIO *b, int cmd, long num, void *ptr) { return 1; } + +TLSConsumerSocket::TLSConsumerSocket(int protocol, SSL *ssl) + : ConsumerSocket(protocol), + name_(), + buf_pool_(), + decrypted_content_(), + payload_(), + head_(), + something_to_read_(false), + content_downloaded_(false), + random_suffix_(), + secure_prefix_(), + producer_namespace_(), + read_callback_decrypted_(), + mtx_(), + cv_(), + async_downloader_tls_() { + /* Create the (d)TLS state */ + const SSL_METHOD *meth = TLS_client_method(); + ctx_ = SSL_CTX_new(meth); + + int result = + SSL_CTX_set_ciphersuites(ctx_, + "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_" + "SHA256:TLS_AES_128_GCM_SHA256"); + if (result != 1) { + throw errors::RuntimeException( + "Unable to set cipher list on TLS subsystem. Aborting."); + } + + SSL_CTX_set_min_proto_version(ctx_, TLS1_3_VERSION); + SSL_CTX_set_max_proto_version(ctx_, TLS1_3_VERSION); + SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, NULL); + SSL_CTX_set_ssl_version(ctx_, meth); + + ssl_ = ssl; + + BIO_METHOD *bio_meth = + BIO_meth_new(BIO_TYPE_CONNECT, "secure consumer socket"); + BIO_meth_set_read(bio_meth, transport::interface::readOldTLS); + BIO_meth_set_write(bio_meth, transport::interface::writeOldTLS); + BIO_meth_set_ctrl(bio_meth, transport::interface::ctrlTLS); + BIO *bio = BIO_new(bio_meth); + BIO_set_init(bio, 1); + BIO_set_data(bio, this); + SSL_set_bio(ssl_, bio, bio); + + ConsumerSocket::getSocketOption(MAX_WINDOW_SIZE, old_max_win_); + ConsumerSocket::setSocketOption(MAX_WINDOW_SIZE, (double)1.0); + + ConsumerSocket::getSocketOption(CURRENT_WINDOW_SIZE, old_current_win_); + ConsumerSocket::setSocketOption(CURRENT_WINDOW_SIZE, (double)1.0); + + std::default_random_engine generator; + std::uniform_int_distribution<int> distribution( + 1, std::numeric_limits<uint32_t>::max()); + random_suffix_ = 0; + + this->ConsumerSocket::setSocketOption(ConsumerCallbacksOptions::READ_CALLBACK, + this); +}; + +int TLSConsumerSocket::consume(const Name &name, + std::unique_ptr<utils::MemBuf> &&buffer) { + this->payload_ = std::move(buffer); + + this->ConsumerSocket::setSocketOption( + ConsumerCallbacksOptions::INTEREST_OUTPUT, + (ConsumerInterestCallback)std::bind( + &TLSConsumerSocket::setInterestPayload, this, std::placeholders::_1, + std::placeholders::_2)); + + return consume(name); +} + +int TLSConsumerSocket::consume(const Name &name) { + if (transport_protocol_->isRunning()) { + return CONSUMER_BUSY; + } + + if ((SSL_in_before(this->ssl_) || SSL_in_init(this->ssl_))) { + throw errors::RuntimeException("Handshake not performed"); + } + + return download_content(name); +} + +int TLSConsumerSocket::download_content(const Name &name) { + network_name_ = name; + network_name_.setSuffix(0); + something_to_read_ = false; + content_downloaded_ = false; + + decrypted_content_ = utils::MemBuf::createCombined(SSL3_RT_MAX_PLAIN_LENGTH); + uint8_t *buf = decrypted_content_->writableData(); + size_t size = 0; + int result = -1; + + while (!content_downloaded_ || something_to_read_) { + if (decrypted_content_->tailroom() < SSL3_RT_MAX_PLAIN_LENGTH) { + decrypted_content_->appendChain( + utils::MemBuf::createCombined(SSL3_RT_MAX_PLAIN_LENGTH)); + // decrypted_content_->computeChainDataLength(); + buf = decrypted_content_->prev()->writableData(); + } else { + buf = decrypted_content_->writableTail(); + } + + result = SSL_read(this->ssl_, buf, SSL3_RT_MAX_PLAIN_LENGTH); + + /* SSL_read returns the data only if there were SSL3_RT_MAX_PLAIN_LENGTH of + * the data has been fully downloaded */ + + /* ASSERT((result < SSL3_RT_MAX_PLAIN_LENGTH && content_downloaded_) || */ + /* result == SSL3_RT_MAX_PLAIN_LENGTH); */ + + if (result >= 0) { + size += result; + decrypted_content_->prepend(result); + } else + throw errors::RuntimeException("Unable to download content"); + + if (size >= read_callback_decrypted_->maxBufferSize()) { + if (read_callback_decrypted_->isBufferMovable()) { + // No need to perform an additional copy. The whole buffer will be + // tranferred to the application. + + read_callback_decrypted_->readBufferAvailable( + std::move(decrypted_content_)); + decrypted_content_ = utils::MemBuf::create(SSL3_RT_MAX_PLAIN_LENGTH); + } else { + // The buffer will be copied into the application-provided buffer + uint8_t *buffer; + std::size_t length; + std::size_t total_length = decrypted_content_->length(); + + while (decrypted_content_->length()) { + buffer = nullptr; + length = 0; + read_callback_decrypted_->getReadBuffer(&buffer, &length); + + if (!buffer || !length) { + throw errors::RuntimeException( + "Invalid buffer provided by the application."); + } + + auto to_copy = std::min(decrypted_content_->length(), length); + std::memcpy(buffer, decrypted_content_->data(), to_copy); + decrypted_content_->trimStart(to_copy); + } + + read_callback_decrypted_->readDataAvailable(total_length); + decrypted_content_->clear(); + } + } + } + + read_callback_decrypted_->readSuccess(size); + + return CONSUMER_FINISHED; +} + +int TLSConsumerSocket::asyncConsume(const Name &name, + std::unique_ptr<utils::MemBuf> &&buffer) { + this->payload_ = std::move(buffer); + + this->ConsumerSocket::setSocketOption( + ConsumerCallbacksOptions::INTEREST_OUTPUT, + (ConsumerInterestCallback)std::bind( + &TLSConsumerSocket::setInterestPayload, this, std::placeholders::_1, + std::placeholders::_2)); + + return asyncConsume(name); +} + +int TLSConsumerSocket::asyncConsume(const Name &name) { + if ((SSL_in_before(this->ssl_) || SSL_in_init(this->ssl_))) { + throw errors::RuntimeException("Handshake not performed"); + } + + if (!async_downloader_tls_.stopped()) { + async_downloader_tls_.add([this, name]() { + is_async_ = true; + download_content(name); + }); + } + + return CONSUMER_RUNNING; +} + +void TLSConsumerSocket::registerPrefix(const Prefix &producer_namespace) { + producer_namespace_ = producer_namespace; +} + +int TLSConsumerSocket::setSocketOption( + int socket_option_key, ConsumerSocket::ReadCallback *socket_option_value) { + return rescheduleOnIOService( + socket_option_key, socket_option_value, + [this](int socket_option_key, + ConsumerSocket::ReadCallback *socket_option_value) -> int { + switch (socket_option_key) { + case ConsumerCallbacksOptions::READ_CALLBACK: + read_callback_decrypted_ = socket_option_value; + break; + default: + return SOCKET_OPTION_NOT_SET; + } + + return SOCKET_OPTION_SET; + }); +} + +void TLSConsumerSocket::getReadBuffer(uint8_t **application_buffer, + size_t *max_length) {} + +void TLSConsumerSocket::readDataAvailable(size_t length) noexcept {} + +size_t TLSConsumerSocket::maxBufferSize() const { + return SSL3_RT_MAX_PLAIN_LENGTH; +} + +void TLSConsumerSocket::readBufferAvailable( + std::unique_ptr<utils::MemBuf> &&buffer) noexcept { + std::unique_lock<std::mutex> lck(this->mtx_); + if (head_) { + head_->prependChain(std::move(buffer)); + } else { + head_ = std::move(buffer); + } + + something_to_read_ = true; + cv_.notify_one(); +} + +void TLSConsumerSocket::readError(const std::error_code ec) noexcept {} + +void TLSConsumerSocket::readSuccess(std::size_t total_size) noexcept { + std::unique_lock<std::mutex> lck(this->mtx_); + content_downloaded_ = true; + something_to_read_ = true; + cv_.notify_one(); +} + +bool TLSConsumerSocket::isBufferMovable() noexcept { return true; } + +} // namespace interface + +} // namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/tls_socket_consumer.h b/libtransport/src/hicn/transport/interfaces/tls_socket_consumer.h new file mode 100644 index 000000000..05f7fe6a5 --- /dev/null +++ b/libtransport/src/hicn/transport/interfaces/tls_socket_consumer.h @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/socket_consumer.h> +#include <openssl/ssl.h> + +namespace transport { + +namespace interface { + +class TLSConsumerSocket : public ConsumerSocket, + public ConsumerSocket::ReadCallback { + /* Return the number of read bytes in readbytes */ + friend int readTLS(BIO *b, char *buf, size_t size, size_t *readbytes); + + /* Return the number of read bytes in the return param */ + friend int readOldTLS(BIO *h, char *buf, int size); + + /* Return the number of written bytes in written */ + friend int writeTLS(BIO *b, const char *buf, size_t size, size_t *written); + + /* Return the number of written bytes in the return param */ + friend int writeOldTLS(BIO *h, const char *buf, int num); + + friend long ctrlTLS(BIO *b, int cmd, long num, void *ptr); + + public: + explicit TLSConsumerSocket(int protocol, SSL *ssl_); + + ~TLSConsumerSocket() = default; + + int consume(const Name &name, std::unique_ptr<utils::MemBuf> &&buffer); + int consume(const Name &name) override; + + int asyncConsume(const Name &name, std::unique_ptr<utils::MemBuf> &&buffer); + int asyncConsume(const Name &name) override; + + void registerPrefix(const Prefix &producer_namespace); + + int setSocketOption( + int socket_option_key, + ConsumerSocket::ReadCallback *socket_option_value) override; + + using ConsumerSocket::getSocketOption; + using ConsumerSocket::setSocketOption; + + protected: + /* Callback invoked once an interest has been received and its payload + * decrypted */ + ConsumerInterestCallback on_interest_input_decrypted_; + ConsumerInterestCallback on_interest_process_decrypted_; + + private: + Name name_; + + /* SSL handle */ + SSL *ssl_; + SSL_CTX *ctx_; + + /* Chain of MemBuf to be used as a temporary buffer to pass descypted data + * from the underlying layer to the application */ + utils::ObjectPool<utils::MemBuf> buf_pool_; + std::unique_ptr<utils::MemBuf> decrypted_content_; + + /* Chain of MemBuf holding the payload to be written into interest or data */ + std::unique_ptr<utils::MemBuf> payload_; + + /* Chain of MemBuf holding the data retrieved from the underlying layer */ + std::unique_ptr<utils::MemBuf> head_; + + bool something_to_read_; + + bool content_downloaded_; + + double old_max_win_; + + double old_current_win_; + + uint32_t random_suffix_; + + ip_address_t secure_prefix_; + + Prefix producer_namespace_; + + ConsumerSocket::ReadCallback *read_callback_decrypted_; + + std::mutex mtx_; + + /* Condition variable for the wait */ + std::condition_variable cv_; + + utils::EventThread async_downloader_tls_; + + void setInterestPayload(ConsumerSocket &c, const core::Interest &interest); + void processPayload(ConsumerSocket &c, std::size_t bytes_transferred, + const std::error_code &ec); + + virtual void getReadBuffer(uint8_t **application_buffer, + size_t *max_length) override; + + virtual void readDataAvailable(size_t length) noexcept override; + + virtual size_t maxBufferSize() const override; + + virtual void readBufferAvailable( + std::unique_ptr<utils::MemBuf> &&buffer) noexcept override; + + virtual void readError(const std::error_code ec) noexcept override; + + virtual void readSuccess(std::size_t total_size) noexcept override; + virtual bool isBufferMovable() noexcept override; + + int download_content(const Name &name); +}; + +} // namespace interface + +} // end namespace transport
\ No newline at end of file diff --git a/libtransport/src/hicn/transport/interfaces/tls_socket_producer.cc b/libtransport/src/hicn/transport/interfaces/tls_socket_producer.cc new file mode 100644 index 000000000..ad85ec6ea --- /dev/null +++ b/libtransport/src/hicn/transport/interfaces/tls_socket_producer.cc @@ -0,0 +1,587 @@ +#include <hicn/transport/core/interest.h> +#include <hicn/transport/interfaces/p2psecure_socket_producer.h> +#include <hicn/transport/interfaces/tls_socket_producer.h> + +#include <openssl/bio.h> +#include <openssl/rand.h> +#include <openssl/ssl.h> + +namespace transport { + +namespace interface { + +/* Return the number of read bytes in readbytes */ +int TLSProducerSocket::read(BIO *b, char *buf, size_t size, size_t *readbytes) { + int ret; + + if (size > INT_MAX) size = INT_MAX; + + ret = TLSProducerSocket::readOld(b, buf, (int)size); + + if (ret <= 0) { + *readbytes = 0; + return ret; + } + + *readbytes = (size_t)ret; + + return 1; +} + +/* Return the number of read bytes in the return param */ +int TLSProducerSocket::readOld(BIO *b, char *buf, int size) { + TLSProducerSocket *socket; + socket = (TLSProducerSocket *)BIO_get_data(b); + + /* take a lock on the mutex. It will be unlocked by */ + std::unique_lock<std::mutex> lck(socket->mtx_); + if (!socket->something_to_read_) { + (socket->cv_).wait(lck); + } + + /* Either there already is something to read, or the thread has been waken up + */ + /* must return the payload in the interest */ + + utils::MemBuf *membuf = socket->packet_->next(); + int size_to_read; + if ((int)membuf->length() > size) { + size_to_read = size; + } else { + size_to_read = membuf->length(); + socket->something_to_read_ = false; + } + + std::memcpy(buf, membuf->data(), size_to_read); + membuf->trimStart(size_to_read); + + return size_to_read; +} + +/* Return the number of written bytes in written */ +int TLSProducerSocket::write(BIO *b, const char *buf, size_t size, + size_t *written) { + int ret; + + if (size > INT_MAX) size = INT_MAX; + + ret = TLSProducerSocket::writeOld(b, buf, (int)size); + + if (ret <= 0) { + *written = 0; + return ret; + } + + *written = (size_t)ret; + + return 1; +} + +/* Return the number of written bytes in the return param */ +int TLSProducerSocket::writeOld(BIO *b, const char *buf, int num) { + TLSProducerSocket *socket; + socket = (TLSProducerSocket *)BIO_get_data(b); + + if ((SSL_in_before(socket->ssl_) || SSL_in_init(socket->ssl_)) && + socket->first_) { + //! socket->tls_chunks_ corresponds to is_last + socket->tls_chunks_--; + bool making_manifest = socket->parent_->making_manifest_; + socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, + false); + socket->parent_->ProducerSocket::produce( + socket->name_, (const uint8_t *)buf, num, socket->tls_chunks_ == 0, + socket->last_segment_); + socket->parent_->setSocketOption(GeneralTransportOptions::MAKE_MANIFEST, + making_manifest); + socket->first_ = false; + } else { + socket->still_writing_ = true; + + std::unique_ptr<utils::MemBuf> mbuf = + utils::MemBuf::copyBuffer(buf, (std::size_t)num, 0, 0); + auto a = mbuf.release(); + socket->async_thread_.add([socket = socket, a]() { + socket->tls_chunks_--; + socket->to_call_oncontentproduced_--; + auto mbuf = std::unique_ptr<utils::MemBuf>(a); + socket->last_segment_ += socket->ProducerSocket::produce( + socket->name_, std::move(mbuf), socket->tls_chunks_ == 0, + socket->last_segment_); + ProducerContentCallback on_content_produced_application; + socket->getSocketOption(ProducerCallbacksOptions::CONTENT_PRODUCED, + on_content_produced_application); + if (socket->to_call_oncontentproduced_ == 0 && + on_content_produced_application) { + on_content_produced_application(*socket, std::error_code(), 0); + } + }); + } + + return num; +} + +TLSProducerSocket::TLSProducerSocket(P2PSecureProducerSocket *parent, + const Name &handshake_name) + : ProducerSocket(), + on_content_produced_application_(), + mtx_(), + cv_(), + something_to_read_(), + name_(), + last_segment_(0), + parent_(parent), + first_(true), + handshake_name_(handshake_name), + tls_chunks_(0), + to_call_oncontentproduced_(0), + still_writing_(false), + encryption_thread_() { + const SSL_METHOD *meth = TLS_server_method(); + ctx_ = SSL_CTX_new(meth); + + /* + * Setup SSL context (identity and parameter to use TLS 1.3) + */ + SSL_CTX_use_certificate(ctx_, parent->cert_509_); + SSL_CTX_use_PrivateKey(ctx_, parent->pkey_rsa_); + + int result = + SSL_CTX_set_ciphersuites(ctx_, + "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_" + "SHA256:TLS_AES_128_GCM_SHA256"); + if (result != 1) { + throw errors::RuntimeException( + "Unable to set cipher list on TLS subsystem. Aborting."); + } + + // We force it to be TLS 1.3 + SSL_CTX_set_min_proto_version(ctx_, TLS1_3_VERSION); + SSL_CTX_set_max_proto_version(ctx_, TLS1_3_VERSION); + SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, NULL); + SSL_CTX_set_num_tickets(ctx_, 0); + + result = SSL_CTX_add_custom_ext( + ctx_, 100, SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_3_ENCRYPTED_EXTENSIONS, + TLSProducerSocket::addHicnKeyIdCb, TLSProducerSocket::freeHicnKeyIdCb, + this, TLSProducerSocket::parseHicnKeyIdCb, NULL); + + ssl_ = SSL_new(ctx_); + /* + * Setup this producer socker as the bio that TLS will use to write and read + * data (in stream mode) + */ + BIO_METHOD *bio_meth = + BIO_meth_new(BIO_TYPE_ACCEPT, "secure producer socket"); + BIO_meth_set_read(bio_meth, TLSProducerSocket::readOld); + BIO_meth_set_write(bio_meth, TLSProducerSocket::writeOld); + BIO_meth_set_ctrl(bio_meth, TLSProducerSocket::ctrl); + BIO *bio = BIO_new(bio_meth); + BIO_set_init(bio, 1); + BIO_set_data(bio, this); + SSL_set_bio(ssl_, bio, bio); + /* + * Set the callback so that when an interest is received we catch it and we + * decrypt the payload before passing it to the application. + */ + this->ProducerSocket::setSocketOption( + ProducerCallbacksOptions::CACHE_MISS, + (ProducerInterestCallback)std::bind(&TLSProducerSocket::cacheMiss, this, + std::placeholders::_1, + std::placeholders::_2)); + this->ProducerSocket::setSocketOption( + ProducerCallbacksOptions::CONTENT_PRODUCED, + (ProducerContentCallback)bind( + &TLSProducerSocket::onContentProduced, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3)); +} + +void TLSProducerSocket::accept() { + if (SSL_in_before(ssl_) || SSL_in_init(ssl_)) { + tls_chunks_ = 1; + int result = SSL_accept(ssl_); + if (result != 1) + throw errors::RuntimeException("Unable to perform client handshake"); + } + TRANSPORT_LOGD("Handshake performed!"); + parent_->list_secure_producers.push_front( + std::move(parent_->map_secure_producers[handshake_name_])); + parent_->map_secure_producers.erase(handshake_name_); + + ProducerInterestCallback on_interest_process_decrypted; + getSocketOption(ProducerCallbacksOptions::CACHE_MISS, + on_interest_process_decrypted); + + if (on_interest_process_decrypted) { + Interest inter(std::move(packet_)); + on_interest_process_decrypted(*this, inter); + } else { + throw errors::RuntimeException( + "On interest process unset. Unable to perform handshake"); + } +} + +int TLSProducerSocket::async_accept() { + if (!async_thread_.stopped()) { + async_thread_.add([this]() { this->accept(); }); + } else { + throw errors::RuntimeException( + "Async thread not running, impossible to perform handshake"); + } + + return 1; +} + +void TLSProducerSocket::onInterest(ProducerSocket &p, Interest &interest) { + /* Based on the state machine of (D)TLS, we know what action to do */ + if (SSL_in_before(ssl_) || SSL_in_init(ssl_)) { + std::unique_lock<std::mutex> lck(mtx_); + name_ = interest.getName(); + something_to_read_ = true; + packet_ = interest.acquireMemBufReference(); + if (head_) { + payload_->prependChain(interest.getPayload()); + } else { + payload_ = interest.getPayload(); // std::move(interest.getPayload()); + } + cv_.notify_one(); + } else { + name_ = interest.getName(); + packet_ = interest.acquireMemBufReference(); + payload_ = interest.getPayload(); + something_to_read_ = true; + + if (interest.getPayload()->length() > 0) + SSL_read( + ssl_, + const_cast<unsigned char *>(interest.getPayload()->writableData()), + interest.getPayload()->length()); + } + + ProducerInterestCallback on_interest_input_decrypted; + getSocketOption(ProducerCallbacksOptions::INTEREST_INPUT, + on_interest_input_decrypted); + if (on_interest_input_decrypted) + (on_interest_input_decrypted)(*this, interest); +} + +void TLSProducerSocket::cacheMiss(ProducerSocket &p, Interest &interest) { + if (SSL_in_before(ssl_) || SSL_in_init(ssl_)) { + std::unique_lock<std::mutex> lck(mtx_); + name_ = interest.getName(); + something_to_read_ = true; + packet_ = interest.acquireMemBufReference(); + payload_ = interest.getPayload(); + cv_.notify_one(); + } else { + name_ = interest.getName(); + packet_ = interest.acquireMemBufReference(); + payload_ = interest.getPayload(); + something_to_read_ = true; + + if (interest.getPayload()->length() > 0) + SSL_read( + ssl_, + const_cast<unsigned char *>(interest.getPayload()->writableData()), + interest.getPayload()->length()); + + if (on_interest_process_decrypted_ != VOID_HANDLER) + on_interest_process_decrypted_(*this, interest); + } +} + +void TLSProducerSocket::onContentProduced(ProducerSocket &p, + const std::error_code &err, + uint64_t bytes_written) {} + +uint32_t TLSProducerSocket::produce(Name content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t start_offset) { + if (SSL_in_before(ssl_) || SSL_in_init(ssl_)) { + throw errors::RuntimeException( + "New handshake on the same P2P secure producer socket not supported"); + } + size_t buf_size = buffer->length(); + name_ = served_namespaces_.front().mapName(content_name); + + tls_chunks_ = to_call_oncontentproduced_ = + ceil((float)buf_size / (float)SSL3_RT_MAX_PLAIN_LENGTH); + + if (!is_last) { + tls_chunks_++; + } + + last_segment_ = start_offset; + + SSL_write(ssl_, buffer->data(), buf_size); + BIO *wbio = SSL_get_wbio(ssl_); + int i = BIO_flush(wbio); + (void)i; // To shut up gcc 5 + + return 0; +} + +void TLSProducerSocket::asyncProduce(const Name &content_name, + const uint8_t *buf, size_t buffer_size, + bool is_last, uint32_t *start_offset) { + if (!encryption_thread_.stopped()) { + encryption_thread_.add([this, content_name, buffer = buf, + size = buffer_size, is_last, start_offset]() { + if (start_offset != NULL) { + produce(content_name, buffer, size, is_last, *start_offset); + } else { + produce(content_name, buffer, size, is_last, 0); + } + }); + } +} + +void TLSProducerSocket::asyncProduce(Name content_name, + std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t offset, + uint32_t **last_segment) { + if (!encryption_thread_.stopped()) { + auto a = buffer.release(); + encryption_thread_.add( + [this, content_name, a, is_last, offset, last_segment]() { + auto buf = std::unique_ptr<utils::MemBuf>(a); + if (last_segment != NULL) { + *last_segment = &last_segment_; + } + produce(content_name, std::move(buf), is_last, offset); + }); + } +} + +void TLSProducerSocket::asyncProduce(ContentObject &content_object) { + throw errors::RuntimeException("API not supported"); +} + +void TLSProducerSocket::produce(ContentObject &content_object) { + throw errors::RuntimeException("API not supported"); +} + +long TLSProducerSocket::ctrl(BIO *b, int cmd, long num, void *ptr) { + if (cmd == BIO_CTRL_FLUSH) { + } + return 1; +} + +int TLSProducerSocket::addHicnKeyIdCb(SSL *s, unsigned int ext_type, + unsigned int context, + const unsigned char **out, size_t *outlen, + X509 *x, size_t chainidx, int *al, + void *add_arg) { + TLSProducerSocket *socket = reinterpret_cast<TLSProducerSocket *>(add_arg); + if (ext_type == 100) { + ip_prefix_t ip_prefix = + socket->parent_->served_namespaces_.front().toIpPrefixStruct(); + int inet_family = + socket->parent_->served_namespaces_.front().getAddressFamily(); + uint16_t prefix_len_bits = + socket->parent_->served_namespaces_.front().getPrefixLength(); + uint8_t prefix_len_bytes = prefix_len_bits / 8; + uint8_t prefix_len_u32 = prefix_len_bits / 32; + + ip_prefix_t *out_ip = (ip_prefix_t *)malloc(sizeof(ip_prefix_t)); + out_ip->family = inet_family; + out_ip->len = prefix_len_bits + 32; + u8 *out_ip_buf = const_cast<u8 *>( + ip_address_get_buffer(&(out_ip->address), inet_family)); + *out = reinterpret_cast<unsigned char *>(out_ip); + + RAND_bytes((unsigned char *)&socket->key_id_, 4); + + memcpy(out_ip_buf, ip_address_get_buffer(&(ip_prefix.address), inet_family), + prefix_len_bytes); + memcpy((out_ip_buf + prefix_len_bytes), &socket->key_id_, 4); + *outlen = sizeof(ip_prefix_t); + + ip_address_t mask = {}; + ip_address_t keyId_component = {}; + u32 *mask_buf; + u32 *keyId_component_buf; + switch (inet_family) { + case AF_INET: + mask_buf = &(mask.v4.as_u32); + keyId_component_buf = &(keyId_component.v4.as_u32); + break; + case AF_INET6: + mask_buf = mask.v6.as_u32; + keyId_component_buf = keyId_component.v6.as_u32; + break; + default: + throw errors::RuntimeException("Unknown protocol"); + } + + if (prefix_len_bits > (inet_family == AF_INET6 ? IPV6_ADDR_LEN_BITS - 32 + : IPV4_ADDR_LEN_BITS - 32)) + throw errors::RuntimeException( + "Not enough space in the content name to add key_id"); + + mask_buf[prefix_len_u32] = 0xffffffff; + keyId_component_buf[prefix_len_u32] = socket->key_id_; + socket->last_segment_ = 0; + + socket->on_interest_process_decrypted_ = + socket->parent_->on_interest_process_decrypted_; + + socket->registerPrefix( + Prefix(socket->parent_->served_namespaces_.front().getName( + Name(inet_family, (uint8_t *)&mask), + Name(inet_family, (uint8_t *)&keyId_component), + socket->parent_->served_namespaces_.front().getName()), + out_ip->len)); + socket->connect(); + } + return 1; +} + +void TLSProducerSocket::freeHicnKeyIdCb(SSL *s, unsigned int ext_type, + unsigned int context, + const unsigned char *out, + void *add_arg) { + free(const_cast<unsigned char *>(out)); +} + +int TLSProducerSocket::parseHicnKeyIdCb(SSL *s, unsigned int ext_type, + unsigned int context, + const unsigned char *in, size_t inlen, + X509 *x, size_t chainidx, int *al, + void *add_arg) { + return 1; +} + +int TLSProducerSocket::setSocketOption( + int socket_option_key, ProducerInterestCallback socket_option_value) { + return rescheduleOnIOService( + socket_option_key, socket_option_value, + [this](int socket_option_key, + ProducerInterestCallback socket_option_value) -> int { + int result = SOCKET_OPTION_SET; + switch (socket_option_key) { + case ProducerCallbacksOptions::INTEREST_INPUT: + on_interest_input_decrypted_ = socket_option_value; + break; + + case ProducerCallbacksOptions::INTEREST_DROP: + on_interest_dropped_input_buffer_ = socket_option_value; + break; + + case ProducerCallbacksOptions::INTEREST_PASS: + on_interest_inserted_input_buffer_ = socket_option_value; + break; + + case ProducerCallbacksOptions::CACHE_HIT: + on_interest_satisfied_output_buffer_ = socket_option_value; + break; + + case ProducerCallbacksOptions::CACHE_MISS: + on_interest_process_decrypted_ = socket_option_value; + break; + + default: + result = SOCKET_OPTION_NOT_SET; + break; + } + return result; + }); +} + +int TLSProducerSocket::setSocketOption( + int socket_option_key, ProducerContentCallback socket_option_value) { + return rescheduleOnIOService( + socket_option_key, socket_option_value, + [this](int socket_option_key, + ProducerContentCallback socket_option_value) -> int { + switch (socket_option_key) { + case ProducerCallbacksOptions::CONTENT_PRODUCED: + on_content_produced_application_ = socket_option_value; + break; + + default: + return SOCKET_OPTION_NOT_SET; + } + + return SOCKET_OPTION_SET; + }); +} + +int TLSProducerSocket::getSocketOption( + int socket_option_key, ProducerContentCallback **socket_option_value) { + return rescheduleOnIOService( + socket_option_key, socket_option_value, + [this](int socket_option_key, + ProducerContentCallback **socket_option_value) -> int { + switch (socket_option_key) { + case ProducerCallbacksOptions::CONTENT_PRODUCED: + *socket_option_value = &on_content_produced_application_; + break; + + default: + return SOCKET_OPTION_NOT_GET; + } + + return SOCKET_OPTION_GET; + }); +} + +int TLSProducerSocket::getSocketOption( + int socket_option_key, ProducerContentCallback &socket_option_value) { + return rescheduleOnIOServiceWithReference( + socket_option_key, socket_option_value, + [this](int socket_option_key, + ProducerContentCallback &socket_option_value) -> int { + switch (socket_option_key) { + case ProducerCallbacksOptions::CONTENT_PRODUCED: + socket_option_value = on_content_produced_application_; + break; + + default: + return SOCKET_OPTION_NOT_GET; + } + + return SOCKET_OPTION_GET; + }); +} + +int TLSProducerSocket::getSocketOption( + int socket_option_key, ProducerInterestCallback &socket_option_value) { + // Reschedule the function on the io_service to avoid race condition in case + // setSocketOption is called while the io_service is running. + return rescheduleOnIOServiceWithReference( + socket_option_key, socket_option_value, + [this](int socket_option_key, + ProducerInterestCallback &socket_option_value) -> int { + switch (socket_option_key) { + case ProducerCallbacksOptions::INTEREST_INPUT: + socket_option_value = on_interest_input_decrypted_; + break; + + case ProducerCallbacksOptions::INTEREST_DROP: + socket_option_value = on_interest_dropped_input_buffer_; + break; + + case ProducerCallbacksOptions::INTEREST_PASS: + socket_option_value = on_interest_inserted_input_buffer_; + break; + + case ProducerCallbacksOptions::CACHE_HIT: + socket_option_value = on_interest_satisfied_output_buffer_; + break; + + case ProducerCallbacksOptions::CACHE_MISS: + socket_option_value = on_interest_process_decrypted_; + break; + + default: + return SOCKET_OPTION_NOT_GET; + } + + return SOCKET_OPTION_GET; + }); +} + +} // namespace interface + +} // namespace transport diff --git a/libtransport/src/hicn/transport/interfaces/tls_socket_producer.h b/libtransport/src/hicn/transport/interfaces/tls_socket_producer.h new file mode 100644 index 000000000..4c09ddaa5 --- /dev/null +++ b/libtransport/src/hicn/transport/interfaces/tls_socket_producer.h @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include <hicn/transport/interfaces/socket_producer.h> +#include <hicn/transport/utils/content_store.h> + +#include <openssl/ssl.h> +#include <condition_variable> +#include <mutex> + +namespace transport { + +namespace interface { + +class P2PSecureProducerSocket; + +class TLSProducerSocket : virtual public ProducerSocket { + friend class P2PSecureProducerSocket; + + public: + explicit TLSProducerSocket(P2PSecureProducerSocket *parent, + const Name &handshake_name); + ~TLSProducerSocket() = default; + + uint32_t produce(Name content_name, const uint8_t *buffer, size_t buffer_size, + bool is_last = true, uint32_t start_offset = 0) override { + return produce(content_name, utils::MemBuf::copyBuffer(buffer, buffer_size), + is_last, start_offset); + } + + uint32_t produce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last = true, uint32_t start_offset = 0) override; + + void produce(ContentObject &content_object) override; + + void asyncProduce(const Name &suffix, const uint8_t *buf, size_t buffer_size, + bool is_last = true, + uint32_t *start_offset = nullptr) override; + + void asyncProduce(Name content_name, std::unique_ptr<utils::MemBuf> &&buffer, + bool is_last, uint32_t offset, + uint32_t **last_segment = nullptr) override; + + void asyncProduce(ContentObject &content_object) override; + + virtual void accept(); + + virtual int async_accept(); + + virtual int setSocketOption( + int socket_option_key, + ProducerInterestCallback socket_option_value) override; + + virtual int setSocketOption( + int socket_option_key, + ProducerContentCallback socket_option_value) override; + + virtual int getSocketOption( + int socket_option_key, + ProducerContentCallback **socket_option_value) override; + + int getSocketOption(int socket_option_key, + ProducerContentCallback &socket_option_value); + + int getSocketOption(int socket_option_key, + ProducerInterestCallback &socket_option_value); + + using ProducerSocket::getSocketOption; + using ProducerSocket::onInterest; + using ProducerSocket::setSocketOption; + + protected: + /* Callback invoked once an interest has been received and its payload + * decrypted */ + ProducerInterestCallback on_interest_input_decrypted_; + ProducerInterestCallback on_interest_process_decrypted_; + ProducerContentCallback on_content_produced_application_; + + std::mutex mtx_; + + /* Condition variable for the wait */ + std::condition_variable cv_; + + /* Bool variable, true if there is something to read (an interest arrived) */ + bool something_to_read_; + + /* First interest that open a secure connection */ + transport::core::Name name_; + + /* SSL handle */ + SSL *ssl_; + SSL_CTX *ctx_; + + Packet::MemBufPtr packet_; + + std::unique_ptr<utils::MemBuf> head_; + std::uint32_t last_segment_; + std::shared_ptr<utils::MemBuf> payload_; + std::uint32_t key_id_; + + std::thread *handshake; + P2PSecureProducerSocket *parent_; + + bool first_; + Name handshake_name_; + int tls_chunks_; + int to_call_oncontentproduced_; + + bool still_writing_; + + utils::EventThread encryption_thread_; + + void onInterest(ProducerSocket &p, Interest &interest); + void cacheMiss(ProducerSocket &p, Interest &interest); + + /* Return the number of read bytes in readbytes */ + static int read(BIO *b, char *buf, size_t size, size_t *readbytes); + + /* Return the number of read bytes in the return param */ + static int readOld(BIO *h, char *buf, int size); + + /* Return the number of written bytes in written */ + static int write(BIO *b, const char *buf, size_t size, size_t *written); + + /* Return the number of written bytes in the return param */ + static int writeOld(BIO *h, const char *buf, int num); + + static long ctrl(BIO *b, int cmd, long num, void *ptr); + + static int addHicnKeyIdCb(SSL *s, unsigned int ext_type, unsigned int context, + const unsigned char **out, size_t *outlen, X509 *x, + size_t chainidx, int *al, void *add_arg); + + static void freeHicnKeyIdCb(SSL *s, unsigned int ext_type, + unsigned int context, const unsigned char *out, + void *add_arg); + + static int parseHicnKeyIdCb(SSL *s, unsigned int ext_type, + unsigned int context, const unsigned char *in, + size_t inlen, X509 *x, size_t chainidx, int *al, + void *add_arg); + + void onContentProduced(ProducerSocket &p, const std::error_code &err, + uint64_t bytes_written); +}; + +} // namespace interface + +} // end namespace transport diff --git a/libtransport/src/hicn/transport/protocols/incremental_indexer.h b/libtransport/src/hicn/transport/protocols/incremental_indexer.h index ea84d645a..b587a8332 100644 --- a/libtransport/src/hicn/transport/protocols/incremental_indexer.h +++ b/libtransport/src/hicn/transport/protocols/incremental_indexer.h @@ -123,6 +123,10 @@ class IncrementalIndexer : public Indexer { } } + TRANSPORT_ALWAYS_INLINE bool onKeyToVerify() override { + return verification_manager_->onKeyToVerify(); + } + protected: interface::ConsumerSocket *socket_; Reassembly *reassembly_; diff --git a/libtransport/src/hicn/transport/protocols/indexer.cc b/libtransport/src/hicn/transport/protocols/indexer.cc index c50c4236b..d95c10ff9 100644 --- a/libtransport/src/hicn/transport/protocols/indexer.cc +++ b/libtransport/src/hicn/transport/protocols/indexer.cc @@ -63,6 +63,10 @@ void IndexManager::onContentObject(core::Interest::Ptr &&interest, } } +bool IndexManager::onKeyToVerify() { + return indexer_->onKeyToVerify(); +} + void IndexManager::reset(std::uint32_t offset) { indexer_ = std::make_unique<IncrementalIndexer>(icn_socket_, transport_, reassembly_); @@ -71,4 +75,4 @@ void IndexManager::reset(std::uint32_t offset) { } } // namespace protocol -} // namespace transport
\ No newline at end of file +} // namespace transport diff --git a/libtransport/src/hicn/transport/protocols/indexer.h b/libtransport/src/hicn/transport/protocols/indexer.h index 89751095e..87cf9b307 100644 --- a/libtransport/src/hicn/transport/protocols/indexer.h +++ b/libtransport/src/hicn/transport/protocols/indexer.h @@ -57,6 +57,8 @@ class Indexer { virtual void onContentObject(core::Interest::Ptr &&interest, core::ContentObject::Ptr &&content_object) = 0; + + virtual bool onKeyToVerify() = 0; }; class IndexManager : Indexer { @@ -87,6 +89,8 @@ class IndexManager : Indexer { void onContentObject(core::Interest::Ptr &&interest, core::ContentObject::Ptr &&content_object) override; + bool onKeyToVerify() override; + private: std::unique_ptr<Indexer> indexer_; bool first_segment_received_; diff --git a/libtransport/src/hicn/transport/protocols/manifest_indexing_manager.cc b/libtransport/src/hicn/transport/protocols/manifest_indexing_manager.cc new file mode 100644 index 000000000..5bf9c89f7 --- /dev/null +++ b/libtransport/src/hicn/transport/protocols/manifest_indexing_manager.cc @@ -0,0 +1,297 @@ +/* + * Copyright (c) 2017-2019 Cisco and/or its affiliates. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/protocols/manifest_indexing_manager.h> + +#include <cmath> +#include <deque> + +namespace transport { + +namespace protocol { + +using namespace interface; + +ManifestIndexManager::ManifestIndexManager( + interface::ConsumerSocket *icn_socket, TransportProtocol *next_interest) + : IncrementalIndexManager(icn_socket), + PacketManager<Interest>(1024), + next_to_retrieve_segment_(suffix_queue_.end()), + suffix_manifest_(core::NextSegmentCalculationStrategy::INCREMENTAL, 0), + next_reassembly_segment_( + core::NextSegmentCalculationStrategy::INCREMENTAL, 1, true), + ignored_segments_(), + next_interest_(next_interest) {} + +bool ManifestIndexManager::onManifest( + core::ContentObject::Ptr &&content_object) { + auto manifest = + std::make_unique<ContentObjectManifest>(std::move(*content_object)); + bool manifest_verified = verification_manager_->onPacketToVerify(*manifest); + + if (manifest_verified) { + manifest->decode(); + + if (TRANSPORT_EXPECT_FALSE(manifest->getVersion() != + core::ManifestVersion::VERSION_1)) { + throw errors::RuntimeException("Received manifest with unknown version."); + } + + switch (manifest->getManifestType()) { + case core::ManifestType::INLINE_MANIFEST: { + auto _it = manifest->getSuffixList().begin(); + auto _end = manifest->getSuffixList().end(); + size_t nb_segments = std::distance(_it, _end); + final_suffix_ = manifest->getFinalBlockNumber(); // final block number + + TRANSPORT_LOGD("Received manifest %u", + manifest->getWritableName().getSuffix()); + suffix_hash_map_[_it->first] = + std::make_pair(std::vector<uint8_t>(_it->second, _it->second + 32), + manifest->getHashAlgorithm()); + suffix_queue_.push_back(_it->first); + + // If the transport protocol finished the list of segments to retrieve, + // reset the next_to_retrieve_segment_ iterator to the next segment + // provided by this manifest. + if (TRANSPORT_EXPECT_FALSE(next_to_retrieve_segment_ == + suffix_queue_.end())) { + next_to_retrieve_segment_ = --suffix_queue_.end(); + } + + std::advance(_it, 1); + for (; _it != _end; _it++) { + suffix_hash_map_[_it->first] = std::make_pair( + std::vector<uint8_t>(_it->second, _it->second + 32), + manifest->getHashAlgorithm()); + suffix_queue_.push_back(_it->first); + } + + if (TRANSPORT_EXPECT_FALSE(manifest->getName().getSuffix()) == 0) { + core::NextSegmentCalculationStrategy strategy = + manifest->getNextSegmentCalculationStrategy(); + + suffix_manifest_.reset(0); + suffix_manifest_.setNbSegments(nb_segments); + suffix_manifest_.setSuffixStrategy(strategy); + TRANSPORT_LOGD("Capacity of 1st manifest %zu", + suffix_manifest_.getNbSegments()); + + next_reassembly_segment_.reset(*suffix_queue_.begin()); + next_reassembly_segment_.setNbSegments(nb_segments); + suffix_manifest_.setSuffixStrategy(strategy); + } + + // If the manifest is not full, we add the suffixes of missing segments + // to the list of segments to ignore when computing the next reassembly + // index. + if (TRANSPORT_EXPECT_FALSE( + suffix_manifest_.getNbSegments() - nb_segments > 0)) { + auto start = manifest->getSuffixList().begin(); + auto last = --_end; + for (uint32_t i = last->first + 1; + i < start->first + suffix_manifest_.getNbSegments(); i++) { + ignored_segments_.push_back(i); + } + } + + if (TRANSPORT_EXPECT_FALSE(manifest->isFinalManifest()) == 0) { + fillWindow(manifest->getWritableName(), + manifest->getName().getSuffix()); + } + + break; + } + case core::ManifestType::FLIC_MANIFEST: { + throw errors::NotImplementedException(); + } + case core::ManifestType::FINAL_CHUNK_NUMBER: { + throw errors::NotImplementedException(); + } + } + } + + return manifest_verified; +} + +void ManifestIndexManager::onManifestReceived(Interest::Ptr &&i, + ContentObject::Ptr &&c) { + onManifest(std::move(c)); + if (next_interest_) { + next_interest_->scheduleNextInterests(); + } +} + +void ManifestIndexManager::onManifestTimeout(Interest::Ptr &&i) { + const Name &n = i->getName(); + uint32_t segment = n.getSuffix(); + + if (segment > final_suffix_) { + return; + } + + TRANSPORT_LOGD("Timeout on manifest %u", segment); + // Get portal + std::shared_ptr<interface::BasePortal> portal; + socket_->getSocketOption(GeneralTransportOptions::PORTAL, portal); + + // Send requests for manifest out of the congestion window (no + // in_flight_interests++) + portal->sendInterest( + std::move(i), + std::bind(&ManifestIndexManager::onManifestReceived, this, + std::placeholders::_1, std::placeholders::_2), + std::bind(&ManifestIndexManager::onManifestTimeout, this, + std::placeholders::_1)); +} + +void ManifestIndexManager::fillWindow(Name &name, uint32_t current_manifest) { + /* Send as many manifest as required for filling window. */ + uint32_t interest_lifetime; + double window_size; + std::shared_ptr<interface::BasePortal> portal; + Interest::Ptr interest; + uint32_t current_segment = *next_to_retrieve_segment_; + // suffix_manifest_ now points to the next manifest to request + uint32_t last_requested_manifest = (suffix_manifest_++).getSuffix(); + + socket_->getSocketOption(GeneralTransportOptions::PORTAL, portal); + socket_->getSocketOption(GeneralTransportOptions::INTEREST_LIFETIME, + interest_lifetime); + socket_->getSocketOption(GeneralTransportOptions::CURRENT_WINDOW_SIZE, + window_size); + + if (TRANSPORT_EXPECT_FALSE(suffix_manifest_.getSuffix() >= final_suffix_)) { + suffix_manifest_.updateSuffix(last_requested_manifest); + return; + } + + if (current_segment + window_size < suffix_manifest_.getSuffix() && + current_manifest != last_requested_manifest) { + suffix_manifest_.updateSuffix(last_requested_manifest); + return; + } + + do { + interest = getPacket(); + name.setSuffix(suffix_manifest_.getSuffix()); + interest->setName(name); + interest->setLifetime(interest_lifetime); + + // Send interests for manifest out of the congestion window (no + // in_flight_interests++) + portal->sendInterest( + std::move(interest), + std::bind(&ManifestIndexManager::onManifestReceived, this, + std::placeholders::_1, std::placeholders::_2), + std::bind(&ManifestIndexManager::onManifestTimeout, this, + std::placeholders::_1)); + TRANSPORT_LOGD("Send manifest interest %u", name.getSuffix()); + + last_requested_manifest = (suffix_manifest_++).getSuffix(); + } while (current_segment + window_size >= suffix_manifest_.getSuffix() && + suffix_manifest_.getSuffix() < final_suffix_); + + // suffix_manifest_ now points to the last requested manifest + suffix_manifest_.updateSuffix(last_requested_manifest); +} + +bool ManifestIndexManager::onContentObject( + const core::ContentObject &content_object) { + bool verify_signature; + socket_->getSocketOption(GeneralTransportOptions::VERIFY_SIGNATURE, + verify_signature); + + if (!verify_signature) { + return true; + } + + uint64_t segment = content_object.getName().getSuffix(); + + bool ret = false; + + auto it = suffix_hash_map_.find((const unsigned int)segment); + if (it != suffix_hash_map_.end()) { + auto hash_type = static_cast<utils::CryptoHashType>(it->second.second); + auto data_packet_digest = content_object.computeDigest(it->second.second); + auto data_packet_digest_bytes = + data_packet_digest.getDigest<uint8_t>().data(); + std::vector<uint8_t> &manifest_digest_bytes = it->second.first; + + if (utils::CryptoHash::compareBinaryDigest(data_packet_digest_bytes, + manifest_digest_bytes.data(), + hash_type)) { + suffix_hash_map_.erase(it); + ret = true; + } else { + throw errors::RuntimeException( + "Verification failure policy has to be implemented."); + } + } + + return ret; +} + +uint32_t ManifestIndexManager::getNextSuffix() { + if (TRANSPORT_EXPECT_FALSE(next_to_retrieve_segment_ == + suffix_queue_.end())) { + return invalid_index; + } + + return *next_to_retrieve_segment_++; +} + +uint32_t ManifestIndexManager::getFinalSuffix() { return final_suffix_; } + +bool ManifestIndexManager::isFinalSuffixDiscovered() { + return IncrementalIndexManager::isFinalSuffixDiscovered(); +} + +uint32_t ManifestIndexManager::getNextReassemblySegment() { + uint32_t current_reassembly_segment; + + while (true) { + current_reassembly_segment = next_reassembly_segment_.getSuffix(); + next_reassembly_segment_++; + + if (TRANSPORT_EXPECT_FALSE(current_reassembly_segment > final_suffix_)) { + return invalid_index; + } + + if (ignored_segments_.empty()) break; + + auto is_ignored = + std::find(ignored_segments_.begin(), ignored_segments_.end(), + current_reassembly_segment); + + if (is_ignored == ignored_segments_.end()) break; + + ignored_segments_.erase(is_ignored); + } + + return current_reassembly_segment; +} + +void ManifestIndexManager::reset() { + IncrementalIndexManager::reset(); + suffix_manifest_.reset(0); + suffix_queue_.clear(); + suffix_hash_map_.clear(); +} + +} // end namespace protocol + +} // end namespace transport diff --git a/libtransport/src/hicn/transport/protocols/protocol.cc b/libtransport/src/hicn/transport/protocols/protocol.cc index a0f847453..db461a66f 100644 --- a/libtransport/src/hicn/transport/protocols/protocol.cc +++ b/libtransport/src/hicn/transport/protocols/protocol.cc @@ -96,6 +96,8 @@ void TransportProtocol::onContentReassembled(std::error_code ec) { } stop(); + + on_payload->afterRead(); } } // end namespace protocol diff --git a/libtransport/src/hicn/transport/protocols/protocol.h b/libtransport/src/hicn/transport/protocols/protocol.h index 87fab588b..4897da902 100644 --- a/libtransport/src/hicn/transport/protocols/protocol.h +++ b/libtransport/src/hicn/transport/protocols/protocol.h @@ -60,6 +60,8 @@ class TransportProtocol : public interface::BasePortal::ConsumerCallback, virtual void resume(); + virtual bool verifyKeyPackets() = 0; + virtual void scheduleNextInterests() = 0; // Events generated by the indexing diff --git a/libtransport/src/hicn/transport/protocols/raaqm.cc b/libtransport/src/hicn/transport/protocols/raaqm.cc index 641ae45c3..984470edb 100644 --- a/libtransport/src/hicn/transport/protocols/raaqm.cc +++ b/libtransport/src/hicn/transport/protocols/raaqm.cc @@ -115,7 +115,12 @@ void RaaqmTransportProtocol::reset() { t0_ = utils::SteadyClock::now(); } +bool RaaqmTransportProtocol::verifyKeyPackets() { + return index_manager_->onKeyToVerify(); +} + void RaaqmTransportProtocol::increaseWindow() { + // return; double max_window_size = 0.; socket_->getSocketOption(GeneralTransportOptions::MAX_WINDOW_SIZE, max_window_size); @@ -131,6 +136,7 @@ void RaaqmTransportProtocol::increaseWindow() { } void RaaqmTransportProtocol::decreaseWindow() { + // return; double min_window_size = 0.; socket_->getSocketOption(GeneralTransportOptions::MIN_WINDOW_SIZE, min_window_size); @@ -404,7 +410,7 @@ void RaaqmTransportProtocol::onTimeout(Interest::Ptr &&interest) { const Name &n = interest->getName(); - TRANSPORT_LOGW("Timeout on %s", n.toString().c_str()); + TRANSPORT_LOGW("Timeout on content %s", n.toString().c_str()); if (TRANSPORT_EXPECT_FALSE(!is_running_)) { return; @@ -474,21 +480,27 @@ void RaaqmTransportProtocol::scheduleNextInterests() { // send at least one interest if there are retransmissions to perform and // there is no space left in the window sendInterest(std::move(interest_to_retransmit_.front())); + TRANSPORT_LOGD("Window full, retransmit one content interest"); interest_to_retransmit_.pop(); } uint32_t index = IndexManager::invalid_index; + // Send the interest needed for filling the window while (interests_in_flight_ < current_window_size_) { if (interest_to_retransmit_.size() > 0) { sendInterest(std::move(interest_to_retransmit_.front())); + TRANSPORT_LOGD("Retransmit content interest"); interest_to_retransmit_.pop(); } else { index = index_manager_->getNextSuffix(); if (index == IndexManager::invalid_index) { + TRANSPORT_LOGE("INVALID INDEX %d", index); break; } + sendInterest(index); + TRANSPORT_LOGD("Send content interest %u", index); } } } diff --git a/libtransport/src/hicn/transport/protocols/raaqm.h b/libtransport/src/hicn/transport/protocols/raaqm.h index 7fc540c9f..f2d819ec5 100644 --- a/libtransport/src/hicn/transport/protocols/raaqm.h +++ b/libtransport/src/hicn/transport/protocols/raaqm.h @@ -42,6 +42,8 @@ class RaaqmTransportProtocol : public TransportProtocol, void reset() override; + virtual bool verifyKeyPackets() override; + protected: static constexpr uint32_t buffer_size = 1 << interface::default_values::log_2_default_buffer_size; @@ -136,4 +138,4 @@ class RaaqmTransportProtocol : public TransportProtocol, } // end namespace protocol -} // end namespace transport
\ No newline at end of file +} // end namespace transport diff --git a/libtransport/src/hicn/transport/protocols/rate_estimation.cc b/libtransport/src/hicn/transport/protocols/rate_estimation.cc index 99fc78c3e..50306e6e5 100644 --- a/libtransport/src/hicn/transport/protocols/rate_estimation.cc +++ b/libtransport/src/hicn/transport/protocols/rate_estimation.cc @@ -186,7 +186,7 @@ void ALaTcpEstimator::onDataReceived(int packet_size) { SimpleEstimator::SimpleEstimator(double alphaArg, int batching_param) { this->estimation_ = 0.0; this->estimated_ = false; - this->observer_ = NULL; + this->observer_ = nullptr; this->batching_param_ = batching_param; this->total_size_ = 0.0; this->number_of_packets_ = 0; @@ -260,7 +260,7 @@ void SimpleEstimator::onDataReceived(int packet_size) { void SimpleEstimator::onRttUpdate(double rtt) { this->number_of_packets_++; - if (number_of_packets_ == this->batching_param_) { + if (this->number_of_packets_ == this->batching_param_) { TimePoint end = std::chrono::steady_clock::now(); auto delay = std::chrono::duration_cast<Microseconds>(end - this->begin_batch_) diff --git a/libtransport/src/hicn/transport/protocols/rtc.cc b/libtransport/src/hicn/transport/protocols/rtc.cc index e371217f8..fece95d03 100644 --- a/libtransport/src/hicn/transport/protocols/rtc.cc +++ b/libtransport/src/hicn/transport/protocols/rtc.cc @@ -48,10 +48,22 @@ RTCTransportProtocol::~RTCTransportProtocol() { } int RTCTransportProtocol::start() { + if (is_running_) return -1; + + reset(); + is_first_ = true; + probeRtt(); sentinelTimer(); newRound(); - return TransportProtocol::start(); + scheduleNextInterests(); + + is_first_ = false; + is_running_ = true; + portal_->runEventsLoop(); + is_running_ = false; + + return 0; } void RTCTransportProtocol::stop() { @@ -65,7 +77,6 @@ void RTCTransportProtocol::resume() { if (is_running_) return; is_running_ = true; - inflightInterestsCount_ = 0; probeRtt(); @@ -74,7 +85,6 @@ void RTCTransportProtocol::resume() { scheduleNextInterests(); portal_->runEventsLoop(); - is_running_ = false; } @@ -190,7 +200,6 @@ void RTCTransportProtocol::updateDelayStats( // we collect OWD only for datapackets if (payload->length() != HICN_NACK_HEADER_SIZE) { uint64_t *senderTimeStamp = (uint64_t *)payload->data(); - int64_t OWD = std::chrono::duration_cast<std::chrono::milliseconds>( std::chrono::steady_clock::now().time_since_epoch()) .count() - @@ -288,7 +297,6 @@ void RTCTransportProtocol::updateStats(uint32_t round_duration) { stats_->updateAverageRtt(pathTable_[producerPathLabels_[1]]->getMinRtt()); (*stats_callback)(*socket_, *stats_); } - // bound also by interest lifitime* production rate if (!gotNack_) { roundsWithoutNacks_++; @@ -403,6 +411,15 @@ void RTCTransportProtocol::increaseWindow() { } void RTCTransportProtocol::probeRtt() { + probe_timer_->expires_from_now(std::chrono::milliseconds(1000)); + probe_timer_->async_wait([this](std::error_code ec) { + if (ec) return; + probeRtt(); + }); + + // To avoid sending the first probe, because the transport is not running yet + if (is_first_ && !is_running_) return; + time_sent_probe_ = std::chrono::duration_cast<std::chrono::milliseconds>( std::chrono::steady_clock::now().time_since_epoch()) .count(); @@ -419,13 +436,9 @@ void RTCTransportProtocol::probeRtt() { // we considere the probe as a rtx so that we do not incresea inFlightInt received_probe_ = false; + TRANSPORT_LOGD("Send content interest %u (probeRtt)", + interest_name->getSuffix()); sendInterest(interest_name, true); - - probe_timer_->expires_from_now(std::chrono::milliseconds(1000)); - probe_timer_->async_wait([this](std::error_code ec) { - if (ec) return; - probeRtt(); - }); } void RTCTransportProtocol::sendInterest(Name *interest_name, bool rtx) { @@ -463,6 +476,10 @@ void RTCTransportProtocol::sendInterest(Name *interest_name, bool rtx) { void RTCTransportProtocol::scheduleNextInterests() { if (!is_running_ && !is_first_) return; + TRANSPORT_LOGD("----- [window %u - inflight_interests %u = %d] -----", + currentCWin_, inflightInterestsCount_, + currentCWin_ - inflightInterestsCount_); + while (inflightInterestsCount_ < currentCWin_) { Name *interest_name = nullptr; socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, @@ -479,6 +496,9 @@ void RTCTransportProtocol::scheduleNextInterests() { inflightInterests_[pkt].state = sent_; inflightInterests_[pkt].sequence = actualSegment_; actualSegment_ = (actualSegment_ + 1) % HICN_MIN_PROBE_SEQ; + TRANSPORT_LOGD( + "Send content interest %u (scheduleNextInterests no replies)", + interest_name->getSuffix()); sendInterest(interest_name, false); return; } @@ -515,8 +535,17 @@ void RTCTransportProtocol::scheduleNextInterests() { inflightInterests_[pkt].sequence = actualSegment_; actualSegment_ = (actualSegment_ + 1) % HICN_MIN_PROBE_SEQ; + TRANSPORT_LOGD("Send content interest %u (scheduleNextInterests)", + interest_name->getSuffix()); sendInterest(interest_name, false); } + + TRANSPORT_LOGD("----- end of scheduleNextInterest -----"); +} + +bool RTCTransportProtocol::verifyKeyPackets() { + // Not yet implemented + return false; } void RTCTransportProtocol::sentinelTimer() { @@ -703,6 +732,8 @@ uint64_t RTCTransportProtocol::retransmit() { socket_->getSocketOption(GeneralTransportOptions::NETWORK_NAME, &interest_name); interest_name->setSuffix(it->first); + TRANSPORT_LOGD("Send content interest %u (retransmit)", + interest_name->getSuffix()); sendInterest(interest_name, true); } else if (rtx_time < smallest_timeout) { smallest_timeout = rtx_time; @@ -866,6 +897,7 @@ void RTCTransportProtocol::onContentObject( } if (segmentNumber >= HICN_MIN_PROBE_SEQ) { + TRANSPORT_LOGD("Received probe %u", segmentNumber); if (segmentNumber == probe_seq_number_ && !received_probe_) { received_probe_ = true; @@ -899,6 +931,8 @@ void RTCTransportProtocol::onContentObject( } if (payload_size == HICN_NACK_HEADER_SIZE) { + TRANSPORT_LOGD("Received nack %u", segmentNumber); + if (inflightInterests_[pkt].state == sent_) { lastEvent_ = std::chrono::duration_cast<std::chrono::milliseconds>( std::chrono::steady_clock::now().time_since_epoch()) @@ -930,6 +964,8 @@ void RTCTransportProtocol::onContentObject( } } else { + TRANSPORT_LOGD("Received content %u", segmentNumber); + avgPacketSize_ = (HICN_ESTIMATED_PACKET_SIZE * avgPacketSize_) + ((1 - HICN_ESTIMATED_PACKET_SIZE) * payload->length()); diff --git a/libtransport/src/hicn/transport/protocols/rtc.h b/libtransport/src/hicn/transport/protocols/rtc.h index 9e1731e96..f34afbb5f 100644 --- a/libtransport/src/hicn/transport/protocols/rtc.h +++ b/libtransport/src/hicn/transport/protocols/rtc.h @@ -81,8 +81,8 @@ typedef enum packetState packetState_t; struct sentInterest { uint64_t transmissionTime; - uint32_t sequence; // sequence number of the interest sent - // to handle seq % buffer_size + uint32_t sequence; // sequence number of the interest sent + // to handle seq % buffer_size packetState_t state; // see packet state }; @@ -99,6 +99,8 @@ class RTCTransportProtocol : public TransportProtocol, void resume() override; + bool verifyKeyPackets() override; + private: // algo functions void reset() override; diff --git a/libtransport/src/hicn/transport/protocols/statistics.h b/libtransport/src/hicn/transport/protocols/statistics.h index d5e89b96d..c92940ab4 100644 --- a/libtransport/src/hicn/transport/protocols/statistics.h +++ b/libtransport/src/hicn/transport/protocols/statistics.h @@ -82,9 +82,7 @@ class TransportStatistics { return interest_tx_; } - TRANSPORT_ALWAYS_INLINE double getLossRatio() const { - return loss_ratio_; - } + TRANSPORT_ALWAYS_INLINE double getLossRatio() const { return loss_ratio_; } TRANSPORT_ALWAYS_INLINE double getQueuingDelay() const { return queuing_delay_; diff --git a/libtransport/src/hicn/transport/protocols/verification_manager.cc b/libtransport/src/hicn/transport/protocols/verification_manager.cc index f45cab743..74faf0521 100644 --- a/libtransport/src/hicn/transport/protocols/verification_manager.cc +++ b/libtransport/src/hicn/transport/protocols/verification_manager.cc @@ -13,9 +13,8 @@ * limitations under the License. */ -#include <hicn/transport/protocols/verification_manager.h> - #include <hicn/transport/interfaces/socket_consumer.h> +#include <hicn/transport/protocols/verification_manager.h> namespace transport { @@ -25,20 +24,30 @@ interface::VerificationPolicy SignatureVerificationManager::onPacketToVerify( const Packet& packet) { using namespace interface; - bool verify_signature; + bool verify_signature = false, key_content = false; VerificationPolicy ret = VerificationPolicy::DROP_PACKET; - ConsumerContentObjectVerificationFailedCallback* - verification_failed_callback = VOID_HANDLER; icn_socket_->getSocketOption(GeneralTransportOptions::VERIFY_SIGNATURE, verify_signature); + icn_socket_->getSocketOption(GeneralTransportOptions::KEY_CONTENT, + key_content); if (!verify_signature) { return VerificationPolicy::ACCEPT_PACKET; } + if (key_content) { + key_packets_.push(copyPacket(packet)); + return VerificationPolicy::ACCEPT_PACKET; + } else if (!key_packets_.empty()) { + std::queue<ContentObjectPtr>().swap(key_packets_); + } + + ConsumerContentObjectVerificationFailedCallback* + verification_failed_callback = VOID_HANDLER; icn_socket_->getSocketOption(ConsumerCallbacksOptions::VERIFICATION_FAILED, &verification_failed_callback); + if (!verification_failed_callback) { throw errors::RuntimeException( "No verification failed callback provided by application. " @@ -66,6 +75,22 @@ interface::VerificationPolicy SignatureVerificationManager::onPacketToVerify( return ret; } +bool SignatureVerificationManager::onKeyToVerify() { + if (TRANSPORT_EXPECT_FALSE(key_packets_.empty())) { + throw errors::RuntimeException("No key to verify."); + } + + while (!key_packets_.empty()) { + ContentObjectPtr packet_to_verify = key_packets_.front(); + key_packets_.pop(); + if (onPacketToVerify(*packet_to_verify) != + VerificationPolicy::ACCEPT_PACKET) + return false; + } + + return true; +} + } // end namespace protocol } // end namespace transport diff --git a/libtransport/src/hicn/transport/protocols/verification_manager.h b/libtransport/src/hicn/transport/protocols/verification_manager.h index 6e5d32127..293e8103a 100644 --- a/libtransport/src/hicn/transport/protocols/verification_manager.h +++ b/libtransport/src/hicn/transport/protocols/verification_manager.h @@ -30,22 +30,36 @@ namespace protocol { using Packet = core::Packet; using interface::ConsumerSocket; using interface::VerificationPolicy; +using ContentObjectPtr = std::shared_ptr<core::ContentObject>; class VerificationManager { public: virtual ~VerificationManager() = default; virtual VerificationPolicy onPacketToVerify(const Packet& packet) = 0; + virtual bool onKeyToVerify() { return false; } }; class SignatureVerificationManager : public VerificationManager { public: - SignatureVerificationManager(ConsumerSocket* icn_socket) - : icn_socket_(icn_socket) {} + SignatureVerificationManager(interface::ConsumerSocket* icn_socket) + : icn_socket_(icn_socket), key_packets_() {} interface::VerificationPolicy onPacketToVerify(const Packet& packet) override; + bool onKeyToVerify() override; private: ConsumerSocket* icn_socket_; + std::queue<ContentObjectPtr> key_packets_; + + ContentObjectPtr copyPacket(const Packet& packet) { + std::shared_ptr<utils::MemBuf> packet_copy = + packet.acquireMemBufReference(); + ContentObjectPtr content_object_copy = + std::make_shared<core::ContentObject>(std::move(packet_copy)); + std::unique_ptr<utils::MemBuf> payload_copy = packet.getPayload(); + content_object_copy->appendPayload(std::move(payload_copy)); + return content_object_copy; + } }; } // end namespace protocol diff --git a/libtransport/src/hicn/transport/utils/content_store.cc b/libtransport/src/hicn/transport/utils/content_store.cc index 8e3435507..ba13bef41 100644 --- a/libtransport/src/hicn/transport/utils/content_store.cc +++ b/libtransport/src/hicn/transport/utils/content_store.cc @@ -40,41 +40,48 @@ void ContentStore::insert( content_store_hash_table_.size(), fifo_list_.size()); } - // Check if the content can be cached - if (content_object->getLifetime() > 0) { - if (content_store_hash_table_.size() >= max_content_store_size_) { - content_store_hash_table_.erase(fifo_list_.back()); - fifo_list_.pop_back(); - } - - // Insert new item - - auto it = content_store_hash_table_.find(content_object->getName()); - if (it != content_store_hash_table_.end()) { - fifo_list_.erase(it->second.second); - content_store_hash_table_.erase(content_object->getName()); - } + if (content_store_hash_table_.size() >= max_content_store_size_) { + content_store_hash_table_.erase(fifo_list_.back()); + fifo_list_.pop_back(); + } - fifo_list_.push_front(std::cref(content_object->getName())); - auto pos = fifo_list_.begin(); - content_store_hash_table_[content_object->getName()] = ContentStoreEntry( - ObjectTimeEntry(content_object, std::chrono::steady_clock::now()), pos); + // Insert new item + auto it = content_store_hash_table_.find(content_object->getName()); + if (it != content_store_hash_table_.end()) { + fifo_list_.erase(it->second.second); + content_store_hash_table_.erase(content_object->getName()); } + + fifo_list_.push_front(std::cref(content_object->getName())); + auto pos = fifo_list_.begin(); + content_store_hash_table_[content_object->getName()] = ContentStoreEntry( + ObjectTimeEntry(content_object, std::chrono::steady_clock::now()), pos); } const std::shared_ptr<ContentObject> ContentStore::find( const Interest &interest) { utils::SpinLock::Acquire locked(cs_mutex_); + + std::shared_ptr<ContentObject> ret = empty_reference_; auto it = content_store_hash_table_.find(interest.getName()); if (it != content_store_hash_table_.end()) { - if (std::chrono::duration_cast<std::chrono::milliseconds>( + + auto content_lifetime = it->second.first.first->getLifetime(); + auto time_passed_since_creation = + std::chrono::duration_cast<std::chrono::milliseconds>( std::chrono::steady_clock::now() - it->second.first.second) - .count() < it->second.first.first->getLifetime()) { - return it->second.first.first; + .count(); + + if (time_passed_since_creation > content_lifetime) { + fifo_list_.erase(it->second.second); + content_store_hash_table_.erase(it); + } + else { + ret = it->second.first.first; } } - return empty_reference_; + return ret; } void ContentStore::erase(const Name &exact_name) { @@ -103,7 +110,10 @@ void ContentStore::printContent() { for (auto &item : content_store_hash_table_) { if (item.second.first.first->getPayloadType() == transport::core::PayloadType::MANIFEST) { - TRANSPORT_LOGI("Manifest: %s\n", + TRANSPORT_LOGI("Manifest: %s", + item.second.first.first->getName().toString().c_str()); + } else { + TRANSPORT_LOGI("Data Packet: %s", item.second.first.first->getName().toString().c_str()); } } diff --git a/libtransport/src/hicn/transport/utils/content_store.h b/libtransport/src/hicn/transport/utils/content_store.h index f7dc41835..613ffcbc2 100644 --- a/libtransport/src/hicn/transport/utils/content_store.h +++ b/libtransport/src/hicn/transport/utils/content_store.h @@ -16,6 +16,7 @@ #pragma once #include <hicn/transport/interfaces/socket.h> +#include <hicn/transport/utils/spinlock.h> #include <mutex> diff --git a/libtransport/src/hicn/transport/utils/identity.cc b/libtransport/src/hicn/transport/utils/identity.cc index d84129537..c5ab03e44 100644 --- a/libtransport/src/hicn/transport/utils/identity.cc +++ b/libtransport/src/hicn/transport/utils/identity.cc @@ -49,18 +49,14 @@ Identity::Identity(const std::string &keystore_name, identity_, parcCryptoSuite_GetCryptoHash(static_cast<PARCCryptoSuite>(suite))); - signer_ = std::make_shared<Signer>(signer); - - signature_length_ = (unsigned int)parcSigner_GetSignatureSize(signer); + signer_ = std::make_shared<Signer>(signer, suite); parcSigner_Release(&signer); parcIdentityFile_Release(&identity_file); } Identity::Identity(const Identity &other) - : signer_(other.signer_), - hash_algorithm_(other.hash_algorithm_), - signature_length_(other.signature_length_) { + : signer_(other.signer_), hash_algorithm_(other.hash_algorithm_) { parcSecurity_Init(); identity_ = parcIdentity_Acquire(other.identity_); } @@ -90,27 +86,13 @@ Identity::Identity(std::string &file_name, std::string &password, PARCSigner *signer = parcIdentity_CreateSigner( identity_, static_cast<PARCCryptoHashType>(hash_algorithm)); - signer_ = std::make_shared<Signer>(signer); - - signature_length_ = (unsigned int)parcSigner_GetSignatureSize(signer); + signer_ = std::make_shared<Signer>( + signer, CryptoSuite(parcSigner_GetCryptoSuite(signer))); parcSigner_Release(&signer); parcIdentityFile_Release(&identity_file); } -// Identity::Identity(Identity &&other) { -// identity_ = parcIdentity_Acquire(other.identity_); -//} - -// Identity& Identity::operator=(const Identity& other) { -// signer_ = other.signer_; -// hash_algorithm_ = other.hash_algorithm_; -// signature_length_ = other.signature_length_; -// identity_ = parcIdentity_Acquire(other.identity_); - -// parcSecurity_Init(); -// } - Identity::~Identity() { parcIdentity_Release(&identity_); parcSecurity_Fini(); @@ -124,8 +106,10 @@ std::string Identity::getPassword() { return std::string(parcIdentity_GetPassWord(identity_)); } -Signer &Identity::getSigner() { return *signer_; } +std::shared_ptr<Signer> Identity::getSigner() { return signer_; } -unsigned int Identity::getSignatureLength() const { return signature_length_; } +size_t Identity::getSignatureLength() const { + return signer_->getSignatureLength(); +} } // namespace utils diff --git a/libtransport/src/hicn/transport/utils/identity.h b/libtransport/src/hicn/transport/utils/identity.h index 349b38914..e9801a005 100644 --- a/libtransport/src/hicn/transport/utils/identity.h +++ b/libtransport/src/hicn/transport/utils/identity.h @@ -49,15 +49,14 @@ class Identity { std::string getPassword(); - Signer &getSigner(); + std::shared_ptr<Signer> getSigner(); - unsigned int getSignatureLength() const; + size_t getSignatureLength() const; private: PARCIdentity *identity_; std::shared_ptr<Signer> signer_; transport::core::HashAlgorithm hash_algorithm_; - unsigned int signature_length_; }; } // namespace utils diff --git a/libtransport/src/hicn/transport/utils/signer.cc b/libtransport/src/hicn/transport/utils/signer.cc index 981e5f02b..9ac9a2c45 100644 --- a/libtransport/src/hicn/transport/utils/signer.cc +++ b/libtransport/src/hicn/transport/utils/signer.cc @@ -39,54 +39,96 @@ namespace utils { uint8_t Signer::zeros[200] = {0}; /*One signer_ per Private Key*/ -Signer::Signer(PARCKeyStore *keyStore, PARCCryptoSuite suite) { +Signer::Signer(PARCKeyStore *keyStore, CryptoSuite suite) { + parcSecurity_Init(); + switch (suite) { - case PARCCryptoSuite_NULL_CRC32C: - break; - case PARCCryptoSuite_ECDSA_SHA256: - case PARCCryptoSuite_RSA_SHA256: - case PARCCryptoSuite_DSA_SHA256: - case PARCCryptoSuite_RSA_SHA512: + case CryptoSuite::DSA_SHA256: + case CryptoSuite::RSA_SHA256: + case CryptoSuite::RSA_SHA512: + case CryptoSuite::ECDSA_256K1: { this->signer_ = - parcSigner_Create(parcPublicKeySigner_Create(keyStore, suite), + parcSigner_Create(parcPublicKeySigner_Create( + keyStore, static_cast<PARCCryptoSuite>(suite)), PARCPublicKeySignerAsSigner); - this->key_id_ = parcSigner_CreateKeyId(this->signer_); break; + } + case CryptoSuite::HMAC_SHA256: + case CryptoSuite::HMAC_SHA512: { + this->signer_ = + parcSigner_Create(parcSymmetricKeySigner_Create( + (PARCSymmetricKeyStore *)keyStore, + parcCryptoSuite_GetCryptoHash( + static_cast<PARCCryptoSuite>(suite))), + PARCSymmetricKeySignerAsSigner); + break; + } + default: { return; } + } + + suite_ = suite; + key_id_ = parcSigner_CreateKeyId(this->signer_); + signature_length_ = parcSigner_GetSignatureSize(this->signer_); +} + +Signer::Signer(const std::string &passphrase, CryptoSuite suite) { + parcSecurity_Init(); - case PARCCryptoSuite_HMAC_SHA512: - case PARCCryptoSuite_HMAC_SHA256: - default: + switch (suite) { + case CryptoSuite::HMAC_SHA256: + case CryptoSuite::HMAC_SHA512: { + composer_ = parcBufferComposer_Create(); + parcBufferComposer_PutString(composer_, passphrase.c_str()); + key_buffer_ = parcBufferComposer_ProduceBuffer(composer_); + symmetricKeyStore_ = parcSymmetricKeyStore_Create(key_buffer_); this->signer_ = parcSigner_Create( - parcSymmetricKeySigner_Create((PARCSymmetricKeyStore *)keyStore, - parcCryptoSuite_GetCryptoHash(suite)), + parcSymmetricKeySigner_Create( + symmetricKeyStore_, parcCryptoSuite_GetCryptoHash( + static_cast<PARCCryptoSuite>(suite))), PARCSymmetricKeySignerAsSigner); - this->key_id_ = parcSigner_CreateKeyId(this->signer_); break; + } + default: { return; } } + + suite_ = suite; + key_id_ = parcSigner_CreateKeyId(this->signer_); + signature_length_ = parcSigner_GetSignatureSize(this->signer_); } -Signer::Signer(const PARCSigner *signer) +Signer::Signer(const PARCSigner *signer, CryptoSuite suite) : signer_(parcSigner_Acquire(signer)), - key_id_(parcSigner_CreateKeyId(this->signer_)) {} + key_id_(parcSigner_CreateKeyId(this->signer_)), + suite_(suite), + signature_length_(parcSigner_GetSignatureSize(this->signer_)) { + parcSecurity_Init(); +} + +Signer::Signer(const PARCSigner *signer) + : Signer(signer, CryptoSuite::UNKNOWN) {} Signer::~Signer() { - parcSigner_Release(&signer_); - parcKeyId_Release(&key_id_); + if (signature_) parcSignature_Release(&signature_); + if (symmetricKeyStore_) parcSymmetricKeyStore_Release(&symmetricKeyStore_); + if (key_buffer_) parcBuffer_Release(&key_buffer_); + if (composer_) parcBufferComposer_Release(&composer_); + if (signer_) parcSigner_Release(&signer_); + if (key_id_) parcKeyId_Release(&key_id_); + parcSecurity_Fini(); } void Signer::sign(Packet &packet) { // header chain points to the IP + TCP hicn header + AH Header - utils::MemBuf *header_chain = packet.header_head_; - utils::MemBuf *payload_chain = packet.payload_head_; + MemBuf *header_chain = packet.header_head_; + MemBuf *payload_chain = packet.payload_head_; uint8_t *hicn_packet = (uint8_t *)header_chain->writableData(); Packet::Format format = packet.getFormat(); - std::size_t sign_len_bytes = parcSigner_GetSignatureSize(signer_); if (!(format & HFO_AH)) { throw errors::MalformedAHPacketException(); } - packet.setSignatureSize(sign_len_bytes); + packet.setSignatureSize(signature_length_); // Copy IP+TCP/ICMP header before zeroing them hicn_header_t header_copy; @@ -102,8 +144,7 @@ void Signer::sign(Packet &packet) { auto now = duration_cast<milliseconds>(system_clock::now().time_since_epoch()) .count(); packet.setSignatureTimestamp(now); - packet.setValidationAlgorithm( - CryptoSuite(parcSigner_GetCryptoSuite(this->signer_))); + packet.setValidationAlgorithm(suite_); KeyId key_id; key_id.first = (uint8_t *)parcBuffer_Overlay( @@ -111,25 +152,25 @@ void Signer::sign(Packet &packet) { packet.setKeyId(key_id); // Calculate hash - utils::CryptoHasher hasher(parcSigner_GetCryptoHasher(signer_)); + CryptoHasher hasher(parcSigner_GetCryptoHasher(signer_)); hasher.init(); - hasher.updateBytes(hicn_packet, header_len + sign_len_bytes); + hasher.updateBytes(hicn_packet, header_len + signature_length_); - for (utils::MemBuf *current = payload_chain; current != header_chain; + for (MemBuf *current = payload_chain; current != header_chain; current = current->next()) { hasher.updateBytes(current->data(), current->length()); } - utils::CryptoHash hash = hasher.finalize(); + CryptoHash hash = hasher.finalize(); - PARCSignature *signature = parcSigner_SignDigestNoAlloc( - this->signer_, hash.hash_, packet.getSignature(), - (uint32_t)sign_len_bytes); - PARCBuffer *buffer = parcSignature_GetSignature(signature); + signature_ = parcSigner_SignDigestNoAlloc(this->signer_, hash.hash_, + packet.getSignature(), + (uint32_t)signature_length_); + PARCBuffer *buffer = parcSignature_GetSignature(signature_); size_t bytes_len = parcBuffer_Remaining(buffer); - if (bytes_len > sign_len_bytes) { + if (bytes_len > signature_length_) { throw errors::MalformedAHPacketException(); } @@ -137,6 +178,8 @@ void Signer::sign(Packet &packet) { (hicn_header_t *)packet.packet_start_, false); } +size_t Signer::getSignatureLength() { return signature_length_; } + PARCKeyStore *Signer::getKeyStore() { return parcSigner_GetKeyStore(this->signer_); } diff --git a/libtransport/src/hicn/transport/utils/signer.h b/libtransport/src/hicn/transport/utils/signer.h index 6afb9544c..31b21462b 100644 --- a/libtransport/src/hicn/transport/utils/signer.h +++ b/libtransport/src/hicn/transport/utils/signer.h @@ -22,6 +22,7 @@ extern "C" { #include <parc/security/parc_CryptoSuite.h> #include <parc/security/parc_KeyStore.h> #include <parc/security/parc_Signer.h> +#include <parc/security/parc_SymmetricKeySigner.h> } namespace utils { @@ -42,7 +43,17 @@ class Signer { * use to sign packet with this Signer. * @param suite CryptoSuite to use to verify the signature */ - Signer(PARCKeyStore *keyStore, PARCCryptoSuite suite); + Signer(PARCKeyStore *keyStore, CryptoSuite suite); + + /** + * Create a Signer + * + * @param passphrase A string from which the symmetric key will be derived + * @param suite CryptoSuite to use to verify the signature + */ + Signer(const std::string &passphrase, CryptoSuite suite); + + Signer(const PARCSigner *signer, CryptoSuite suite); Signer(const PARCSigner *signer); @@ -60,11 +71,19 @@ class Signer { */ void sign(Packet &packet); + size_t getSignatureLength(); + PARCKeyStore *getKeyStore(); private: - PARCSigner *signer_; - PARCKeyId *key_id_; + PARCBufferComposer *composer_ = nullptr; + PARCBuffer *key_buffer_ = nullptr; + PARCSymmetricKeyStore *symmetricKeyStore_ = nullptr; + PARCSigner *signer_ = nullptr; + PARCSignature *signature_ = nullptr; + PARCKeyId *key_id_ = nullptr; + CryptoSuite suite_; + size_t signature_length_; static uint8_t zeros[200]; }; diff --git a/libtransport/src/hicn/transport/utils/suffix_strategy.h b/libtransport/src/hicn/transport/utils/suffix_strategy.h index 0ed3c5b0e..ab9b1eff6 100644 --- a/libtransport/src/hicn/transport/utils/suffix_strategy.h +++ b/libtransport/src/hicn/transport/utils/suffix_strategy.h @@ -16,6 +16,7 @@ #pragma once #include <hicn/transport/core/manifest_format.h> +#include <limits> namespace utils { diff --git a/libtransport/src/hicn/transport/utils/verifier.cc b/libtransport/src/hicn/transport/utils/verifier.cc index 69b2101da..281ee21dc 100644 --- a/libtransport/src/hicn/transport/utils/verifier.cc +++ b/libtransport/src/hicn/transport/utils/verifier.cc @@ -25,9 +25,6 @@ extern "C" { TRANSPORT_CLANG_DISABLE_WARNING("-Wextern-c-compat") #endif #include <hicn/hicn.h> -#include <parc/security/parc_CertificateFactory.h> -#include <parc/security/parc_InMemoryVerifier.h> -#include <parc/security/parc_Security.h> } #include <sys/stat.h> @@ -46,11 +43,18 @@ Verifier::Verifier() { PARCInMemoryVerifier *in_memory_verifier = parcInMemoryVerifier_Create(); this->verifier_ = parcVerifier_Create(in_memory_verifier, PARCInMemoryVerifierAsVerifier); - parcInMemoryVerifier_Release(&in_memory_verifier); } Verifier::~Verifier() { - parcVerifier_Release(&verifier_); + if (key_) parcKey_Release(&key_); + if (keyId_) parcKeyId_Release(&keyId_); + if (signer_) parcSigner_Release(&signer_); + if (symmetricKeyStore_) parcSymmetricKeyStore_Release(&symmetricKeyStore_); + if (key_buffer_) parcBuffer_Release(&key_buffer_); + if (composer_) parcBufferComposer_Release(&composer_); + if (certificate_) parcCertificate_Release(&certificate_); + if (factory_) parcCertificateFactory_Release(&factory_); + if (verifier_) parcVerifier_Release(&verifier_); parcSecurity_Fini(); } @@ -67,10 +71,30 @@ bool Verifier::addKey(PARCKey *key) { return true; } +PARCKeyId *Verifier::addKeyFromPassphrase(const std::string &passphrase, + CryptoSuite suite) { + composer_ = parcBufferComposer_Create(); + parcBufferComposer_PutString(composer_, passphrase.c_str()); + key_buffer_ = parcBufferComposer_ProduceBuffer(composer_); + + symmetricKeyStore_ = parcSymmetricKeyStore_Create(key_buffer_); + signer_ = parcSigner_Create( + parcSymmetricKeySigner_Create( + symmetricKeyStore_, + parcCryptoSuite_GetCryptoHash(static_cast<PARCCryptoSuite>(suite))), + PARCSymmetricKeySignerAsSigner); + keyId_ = parcSigner_CreateKeyId(signer_); + key_ = parcKey_CreateFromSymmetricKey( + keyId_, parcSigner_GetSigningAlgorithm(signer_), key_buffer_); + + addKey(key_); + return keyId_; +} + PARCKeyId *Verifier::addKeyFromCertificate(const std::string &file_name) { - PARCCertificateFactory *factory = parcCertificateFactory_Create( - PARCCertificateType_X509, PARCContainerEncoding_PEM); - parcAssertNotNull(factory, "Expected non-NULL factory"); + factory_ = parcCertificateFactory_Create(PARCCertificateType_X509, + PARCContainerEncoding_PEM); + parcAssertNotNull(factory_, "Expected non-NULL factory"); if (!file_exists(file_name)) { TRANSPORT_LOGW("Warning! The certificate %s file does not exist", @@ -78,31 +102,23 @@ PARCKeyId *Verifier::addKeyFromCertificate(const std::string &file_name) { return nullptr; } - PARCCertificate *certificate = - parcCertificateFactory_CreateCertificateFromFile( - factory, (char *)file_name.c_str(), NULL); - - PARCKey *key = parcCertificate_GetPublicKey(certificate); - addKey(key); - - PARCKeyId *ret = parcKeyId_Acquire(parcKey_GetKeyId(key)); - - // parcKey_Release(&key); - // parcCertificate_Release(&certificate); - // parcCertificateFactory_Release(&factory); - - return ret; + certificate_ = parcCertificateFactory_CreateCertificateFromFile( + factory_, (char *)file_name.c_str(), NULL); + PARCBuffer *derEncodedVersion = + parcCertificate_GetDEREncodedPublicKey(certificate_); + PARCCryptoHash *keyDigest = parcCertificate_GetPublicKeyDigest(certificate_); + keyId_ = parcKeyId_Create(parcCryptoHash_GetDigest(keyDigest)); + key_ = parcKey_CreateFromDerEncodedPublicKey(keyId_, PARCSigningAlgorithm_RSA, + derEncodedVersion); + + addKey(key_); + return keyId_; } int Verifier::verify(const Packet &packet) { + // to initialize packet.payload_head_ + const_cast<Packet *>(&packet)->separateHeaderPayload(); bool valid = false; - - // initialize packet.payload_head_ - const_cast<Packet*>(&packet)->separateHeaderPayload(); - // header chain points to the IP + TCP hicn header - utils::MemBuf *header_chain = packet.header_head_; - utils::MemBuf *payload_chain = packet.payload_head_; - uint8_t *hicn_packet = header_chain->writableData(); Packet::Format format = packet.getFormat(); if (!(packet.format_ & HFO_AH)) { @@ -114,10 +130,9 @@ int Verifier::verify(const Packet &packet) { hicn_packet_copy_header(format, (const hicn_header_t *)packet.packet_start_, &header_copy, false); - std::size_t header_len = Packet::getHeaderSizeFromFormat(format); - PARCCryptoSuite suite = static_cast<PARCCryptoSuite>(packet.getValidationAlgorithm()); + PARCCryptoHashType hashtype = parcCryptoSuite_GetCryptoHash(suite); KeyId _key_id = packet.getKeyId(); PARCBuffer *buffer = parcBuffer_Wrap(_key_id.first, _key_id.second, 0, _key_id.second); @@ -127,27 +142,30 @@ int Verifier::verify(const Packet &packet) { int ah_payload_len = (int)packet.getSignatureSize(); uint8_t *_signature = packet.getSignature(); uint8_t *signature = new uint8_t[ah_payload_len]; - // TODO Remove signature copy at this point, by not setting to zero // the validation payload. std::memcpy(signature, _signature, ah_payload_len); - // Reset fields that should not appear in the signature - const_cast<Packet &>(packet).resetForHash(); - - PARCCryptoHashType hashtype = parcCryptoSuite_GetCryptoHash(suite); - utils::CryptoHasher hasher( - parcVerifier_GetCryptoHasher(verifier_, key_id, hashtype)); - - hasher.init().updateBytes(hicn_packet, header_len + ah_payload_len); - - for (utils::MemBuf *current = payload_chain; current != header_chain; - current = current->next()) { - hasher.updateBytes(current->data(), current->length()); + std::shared_ptr<CryptoHasher> hasher; + switch (CryptoSuite(suite)) { + case CryptoSuite::DSA_SHA256: + case CryptoSuite::RSA_SHA256: + case CryptoSuite::RSA_SHA512: + case CryptoSuite::ECDSA_256K1: { + hasher = std::make_shared<CryptoHasher>( + parcVerifier_GetCryptoHasher(verifier_, key_id, hashtype)); + break; + } + case CryptoSuite::HMAC_SHA256: + case CryptoSuite::HMAC_SHA512: { + if (!signer_) return false; + hasher = + std::make_shared<CryptoHasher>(parcSigner_GetCryptoHasher(signer_)); + break; + } + default: { return false; } } - - utils::CryptoHash hash = hasher.finalize(); - PARCCryptoHash *hash_computed_locally = hash.hash_; + CryptoHash hash_computed_locally = getPacketHash(packet, hasher); PARCBuffer *bits = parcBuffer_Wrap(signature, ah_payload_len, 0, ah_payload_len); @@ -178,20 +196,39 @@ int Verifier::verify(const Packet &packet) { } valid = parcVerifier_VerifyDigestSignature( - verifier_, key_id, hash_computed_locally, suite, signatureToVerify); + verifier_, key_id, hash_computed_locally.hash_, suite, signatureToVerify); /* Restore the resetted fields */ hicn_packet_copy_header(format, &header_copy, (hicn_header_t *)packet.packet_start_, false); delete[] signature; - parcKeyId_Release(&key_id); - parcBuffer_Release(&bits); parcSignature_Release(&signatureToVerify); return valid; } +CryptoHash Verifier::getPacketHash(const Packet &packet, + std::shared_ptr<CryptoHasher> hasher) { + MemBuf *header_chain = packet.header_head_; + MemBuf *payload_chain = packet.payload_head_; + Packet::Format format = packet.getFormat(); + int ah_payload_len = (int)packet.getSignatureSize(); + uint8_t *hicn_packet = header_chain->writableData(); + std::size_t header_len = Packet::getHeaderSizeFromFormat(format); + + // Reset fields that should not appear in the signature + const_cast<Packet &>(packet).resetForHash(); + hasher->init().updateBytes(hicn_packet, header_len + ah_payload_len); + + for (MemBuf *current = payload_chain; current != header_chain; + current = current->next()) { + hasher->updateBytes(current->data(), current->length()); + } + + return hasher->finalize(); +} + } // namespace utils diff --git a/libtransport/src/hicn/transport/utils/verifier.h b/libtransport/src/hicn/transport/utils/verifier.h index 6313a7240..7ec6e7eda 100644 --- a/libtransport/src/hicn/transport/utils/verifier.h +++ b/libtransport/src/hicn/transport/utils/verifier.h @@ -18,7 +18,11 @@ #include <hicn/transport/core/packet.h> extern "C" { +#include <parc/security/parc_CertificateFactory.h> +#include <parc/security/parc_InMemoryVerifier.h> #include <parc/security/parc_KeyId.h> +#include <parc/security/parc_Security.h> +#include <parc/security/parc_SymmetricKeySigner.h> #include <parc/security/parc_Verifier.h> } @@ -56,6 +60,9 @@ class Verifier { */ bool addKey(PARCKey *key); + PARCKeyId *addKeyFromPassphrase(const std::string &passphrase, + CryptoSuite suite); + PARCKeyId *addKeyFromCertificate(const std::string &file_name); /** @@ -77,8 +84,19 @@ class Verifier { */ int verify(const Packet &packet); + CryptoHash getPacketHash(const Packet &packet, + std::shared_ptr<CryptoHasher> hasher); + private: - PARCVerifier *verifier_; + PARCVerifier *verifier_ = nullptr; + PARCCertificateFactory *factory_ = nullptr; + PARCCertificate *certificate_ = nullptr; + PARCKeyId *keyId_ = nullptr; + PARCKey *key_ = nullptr; + PARCBuffer *key_buffer_ = nullptr; + PARCSymmetricKeyStore *symmetricKeyStore_ = nullptr; + PARCSigner *signer_ = nullptr; + PARCBufferComposer *composer_ = nullptr; static uint8_t zeros[200]; }; |