aboutsummaryrefslogtreecommitdiffstats
path: root/libtransport/src/hicn/transport/interfaces/tls_socket_consumer.h
blob: 05f7fe6a55a7e80a7c23100ad8b25b0d40712807 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/*
 * Copyright (c) 2017-2019 Cisco and/or its affiliates.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <hicn/transport/interfaces/socket_consumer.h>
#include <openssl/ssl.h>

namespace transport {

namespace interface {

/**
 * Consumer socket that layers an OpenSSL TLS session on top of a regular
 * ConsumerSocket.
 *
 * The class also derives from ConsumerSocket::ReadCallback and overrides its
 * hooks (getReadBuffer(), readDataAvailable(), readBufferAvailable(),
 * readError(), readSuccess(), ...), so it can act as the read callback for
 * downloaded data. The friend functions below have the signatures of OpenSSL
 * custom-BIO callbacks and are granted access to the private state.
 *
 * NOTE(review): the class stores raw SSL / SSL_CTX handles and its destructor
 * is defaulted, so it never frees them; confirm which side owns them.
 */
class TLSConsumerSocket : public ConsumerSocket,
                          public ConsumerSocket::ReadCallback {
  /* Signature matches OpenSSL's BIO_meth_set_read_ex() callback.
   * Returns the number of read bytes in readbytes. */
  friend int readTLS(BIO *b, char *buf, size_t size, size_t *readbytes);

  /* Signature matches OpenSSL's BIO_meth_set_read() callback.
   * Returns the number of read bytes as the return value. */
  friend int readOldTLS(BIO *h, char *buf, int size);

  /* Signature matches OpenSSL's BIO_meth_set_write_ex() callback.
   * Returns the number of written bytes in written. */
  friend int writeTLS(BIO *b, const char *buf, size_t size, size_t *written);

  /* Signature matches OpenSSL's BIO_meth_set_write() callback.
   * Returns the number of written bytes as the return value. */
  friend int writeOldTLS(BIO *h, const char *buf, int num);

  /* Signature matches OpenSSL's BIO_meth_set_ctrl() callback. */
  friend long ctrlTLS(BIO *b, int cmd, long num, void *ptr);

 public:
  /**
   * @param protocol protocol identifier; presumably forwarded to the
   *        underlying ConsumerSocket (confirm in the implementation).
   * @param ssl_ OpenSSL session handle. It is not released by this class's
   *        (defaulted) destructor.
   */
  explicit TLSConsumerSocket(int protocol, SSL *ssl_);

  /* Defaulted: ssl_ and ctx_ are raw handles and are NOT freed here. */
  ~TLSConsumerSocket() = default;

  /* Synchronous consumption of `name`. The overload taking `buffer`
   * presumably sends it as interest payload (see payload_) - TODO confirm. */
  int consume(const Name &name, std::unique_ptr<utils::MemBuf> &&buffer);
  int consume(const Name &name) override;

  /* Asynchronous counterparts of consume(); presumably scheduled on
   * async_downloader_tls_ - TODO confirm. */
  int asyncConsume(const Name &name, std::unique_ptr<utils::MemBuf> &&buffer);
  int asyncConsume(const Name &name) override;

  /* NOTE(review): presumably records the producer prefix in
   * producer_namespace_; confirm in the implementation. */
  void registerPrefix(const Prefix &producer_namespace);

  /* Override for ReadCallback-valued socket options. NOTE(review): presumably
   * installs the application callback for decrypted data
   * (read_callback_decrypted_) instead of passing it to the base class. */
  int setSocketOption(
      int socket_option_key,
      ConsumerSocket::ReadCallback *socket_option_value) override;

  /* Re-expose the base-class overloads, which the override above would
   * otherwise hide. */
  using ConsumerSocket::getSocketOption;
  using ConsumerSocket::setSocketOption;

 protected:
  /* Callback invoked once an interest has been received and its payload
   * decrypted */
  ConsumerInterestCallback on_interest_input_decrypted_;
  ConsumerInterestCallback on_interest_process_decrypted_;

 private:
  Name name_;

  /* SSL handle and context (raw, not released by the destructor). Default
   * initializers ensure they are never read indeterminate; constructor
   * mem-initializers still take precedence. */
  SSL *ssl_{nullptr};
  SSL_CTX *ctx_{nullptr};

  /* Chain of MemBuf to be used as a temporary buffer to pass decrypted data
   * from the underlying layer to the application */
  utils::ObjectPool<utils::MemBuf> buf_pool_;
  std::unique_ptr<utils::MemBuf> decrypted_content_;

  /* Chain of MemBuf holding the payload to be written into interest or data */
  std::unique_ptr<utils::MemBuf> payload_;

  /* Chain of MemBuf holding the data retrieved from the underlying layer */
  std::unique_ptr<utils::MemBuf> head_;

  bool something_to_read_{false};

  bool content_downloaded_{false};

  double old_max_win_{0.0};

  double old_current_win_{0.0};

  uint32_t random_suffix_{0};

  ip_address_t secure_prefix_;

  Prefix producer_namespace_;

  /* Application read callback for decrypted data (non-owning). */
  ConsumerSocket::ReadCallback *read_callback_decrypted_{nullptr};

  std::mutex mtx_;

  /* Condition variable for the wait */
  std::condition_variable cv_;

  utils::EventThread async_downloader_tls_;

  void setInterestPayload(ConsumerSocket &c, const core::Interest &interest);
  void processPayload(ConsumerSocket &c, std::size_t bytes_transferred,
                      const std::error_code &ec);

  /* ConsumerSocket::ReadCallback overrides. */
  void getReadBuffer(uint8_t **application_buffer,
                     size_t *max_length) override;

  void readDataAvailable(size_t length) noexcept override;

  size_t maxBufferSize() const override;

  void readBufferAvailable(
      std::unique_ptr<utils::MemBuf> &&buffer) noexcept override;

  void readError(const std::error_code ec) noexcept override;

  void readSuccess(std::size_t total_size) noexcept override;
  bool isBufferMovable() noexcept override;

  int download_content(const Name &name);
};

}  // namespace interface

}  // end namespace transport