diff options
Diffstat (limited to 'apps/higet/higet.cc')
-rw-r--r-- | apps/higet/higet.cc | 212 |
1 files changed, 166 insertions, 46 deletions
diff --git a/apps/higet/higet.cc b/apps/higet/higet.cc index fa19528f8..df34d5c14 100644 --- a/apps/higet/higet.cc +++ b/apps/higet/higet.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Cisco and/or its affiliates. + * Copyright (c) 2020 Cisco and/or its affiliates. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at: @@ -8,14 +8,19 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * WITHout_ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include <hicn/transport/http/client_connection.h> - #include <fstream> +#include <map> + +#ifndef ASIO_STANDALONE +#define ASIO_STANDALONE +#include <asio.hpp> +#endif typedef std::chrono::time_point<std::chrono::system_clock> Time; typedef std::chrono::milliseconds TimeDuration; @@ -34,61 +39,160 @@ typedef struct { std::string ipv6_first_word; } Configuration; -void processResponse(Configuration &conf, - transport::http::HTTPResponse &&response) { - auto &payload = response.getPayload(); +class ReadBytesCallbackImplementation + : public transport::http::HTTPClientConnection::ReadBytesCallback { + public: + ReadBytesCallbackImplementation(std::string file_name, long yet_downloaded) + : file_name_(file_name), + temp_file_name_(file_name_ + ".temp"), + yet_downloaded_(yet_downloaded), + byte_downloaded_(yet_downloaded), + work_(std::make_unique<asio::io_service::work>(io_service_)), + thread_( + std::make_unique<std::thread>([this]() { io_service_.run(); })) { + std::streambuf *buf; + if (file_name_ != "-") { + of_.open(temp_file_name_, std::ofstream::binary | std::ofstream::app); + buf = of_.rdbuf(); + } else { + buf = std::cout.rdbuf(); + } - if (conf.file_name != "-") { - std::cerr << "Saving to: " << conf.file_name << " " << payload.size() - << "kB" << std::endl; + out_ = new std::ostream(buf); } - Time t3 = std::chrono::system_clock::now(); + ~ReadBytesCallbackImplementation() { + if (thread_->joinable()) { + thread_->join(); + } + } - std::streambuf *buf; - std::ofstream of; + void onBytesReceived(std::unique_ptr<utils::MemBuf> &&buffer) { + auto buffer_ptr = buffer.release(); + io_service_.post([this, buffer_ptr]() { + auto buffer = std::unique_ptr<utils::MemBuf>(buffer_ptr); + if (!first_chunk_read_) { + transport::http::HTTPResponse http_response(std::move(buffer)); + auto payload = http_response.getPayload(); + auto header = http_response.getHeaders(); + std::map<std::string, std::string>::iterator it = + header.find("Content-Length"); + if (it != header.end()) { + content_size_ = yet_downloaded_ + std::stol(it->second); + } + out_->write((char *)payload->data(), payload->length()); + first_chunk_read_ = true; + byte_downloaded_ += payload->length(); + } else { + out_->write((char *)buffer->data(), buffer->length()); + byte_downloaded_ += buffer->length(); + } + + if (file_name_ != "-") { + print_bar(byte_downloaded_, content_size_, false); + } + }); + } - if (conf.file_name != "-") { - of.open(conf.file_name, std::ofstream::binary); - buf = of.rdbuf(); - } else { - buf = std::cout.rdbuf(); + void onSuccess(std::size_t bytes) { + io_service_.post([this, bytes]() { + if (file_name_ != "-") { + of_.close(); + delete out_; + std::size_t found = file_name_.find_last_of("."); + std::string name = file_name_.substr(0, found); + std::string extension = file_name_.substr(found + 1); + if (!exists_file(file_name_)) { + std::rename(temp_file_name_.c_str(), file_name_.c_str()); + } else { + int i = 1; + std::ostringstream sstream; + sstream << name << "(" << i << ")." << extension; + std::string final_name = sstream.str(); + while (exists_file(final_name)) { + i++; + sstream.str(""); + sstream << name << "(" << i << ")." << extension; + final_name = sstream.str(); + } + std::rename(temp_file_name_.c_str(), final_name.c_str()); + } + + print_bar(100, 100, true); + std::cout << "\nDownloaded " << bytes << " bytes" << std::endl; + } + work_.reset(); + }); } - std::ostream out(buf); + void onError(const std::error_code ec) { + io_service_.post([this]() { + of_.close(); + delete out_; + work_.reset(); + }); + } + + private: + bool exists_file(const std::string &name) { + std::ifstream f(name.c_str()); + return f.good(); + } - if (conf.print_headers) { - auto &headers = response.getHeaders(); - out << "HTTP/" << response.getHttpVersion() << " " - << response.getStatusCode() << " " << response.getStatusString() - << "\n"; - for (auto &h : headers) { - out << h.first << ": " << h.second << "\n"; + void print_bar(long value, long max_value, bool last) { + float progress = (float)value / max_value; + struct winsize size; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &size); + int barWidth = size.ws_col - 8; + + std::cout << "["; + int pos = barWidth * progress; + for (int i = 0; i < barWidth; ++i) { + if (i < pos) { + std::cout << "="; + } else if (i == pos) { + std::cout << ">"; + } else { + std::cout << " "; + } + } + if (last) { + std::cout << "] " << int(progress * 100.0) << " %" << std::endl + << std::endl; + } else { + std::cout << "] " << int(progress * 100.0) << " %\r"; + std::cout.flush(); } - out << "\n"; } - out.write((char *)payload.data(), payload.size()); - of.close(); - - Time t2 = std::chrono::system_clock::now(); - TimeDuration dt = - std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1); - TimeDuration dt3 = - std::chrono::duration_cast<std::chrono::milliseconds>(t3 - t1); - long msec = (long)dt.count(); - long msec3 = (long)dt3.count(); - std::cerr << "Elapsed Time: " << msec / 1000.0 << " seconds -- " - << payload.size() * 8 / msec / 1000.0 << "[Mbps] -- " - << payload.size() * 8 / msec3 / 1000.0 << "[Mbps]" << std::endl; + private: + std::string file_name_; + std::string temp_file_name_; + std::ostream *out_; + std::ofstream of_; + long yet_downloaded_; + long content_size_; + bool first_chunk_read_ = false; + long byte_downloaded_ = 0; + asio::io_service io_service_; + std::unique_ptr<asio::io_service::work> work_; + std::unique_ptr<std::thread> thread_; +}; + +long checkFileStatus(std::string file_name) { + struct stat stat_buf; + std::string temp_file_name_ = file_name + ".temp"; + int rc = stat(temp_file_name_.c_str(), &stat_buf); + return rc == 0 ? stat_buf.st_size : -1; } void usage(char *program_name) { std::cerr << "usage:" << std::endl; std::cerr << program_name << " [option]... [url]..." << std::endl; std::cerr << program_name << "options:" << std::endl; - std::cerr << "-O <output_path> = write documents to <output_file>" - << std::endl; + std::cerr + << "-O <out_put_path> = write documents to <out_put_file>" + << std::endl; std::cerr << "-S = print server response" << std::endl; std::cerr << "-P = first word of the ipv6 name of " @@ -145,10 +249,23 @@ int main(int argc, char **argv) { conf.file_name = name.substr(1 + name.find_last_of("/")); } - std::map<std::string, std::string> headers = {{"Host", "localhost"}, - {"User-Agent", "higet/1.0"}, - {"Connection", "Keep-Alive"}}; + long yetDownloaded = checkFileStatus(conf.file_name); + std::map<std::string, std::string> headers; + if (yetDownloaded == -1) { + headers = {{"Host", "localhost"}, + {"User-Agent", "higet/1.0"}, + {"Connection", "Keep-Alive"}}; + } else { + std::string range; + range.append("bytes="); + range.append(std::to_string(yetDownloaded)); + range.append("-"); + headers = {{"Host", "localhost"}, + {"User-Agent", "higet/1.0"}, + {"Connection", "Keep-Alive"}, + {"Range", range}}; + } transport::http::HTTPClientConnection connection; if (!conf.producer_certificate.empty()) { connection.setCertificate(conf.producer_certificate); @@ -156,8 +273,11 @@ int main(int argc, char **argv) { t1 = std::chrono::system_clock::now(); - connection.get(name, headers, {}, nullptr, nullptr, conf.ipv6_first_word); - processResponse(conf, connection.response()); + http::ReadBytesCallbackImplementation readBytesCallback(conf.file_name, + yetDownloaded); + + connection.get(name, headers, {}, nullptr, &readBytesCallback, + conf.ipv6_first_word); #ifdef _WIN32 WSACleanup(); |