From a2e7e43f88feca7449c3ae1bf7e994e748879832 Mon Sep 17 00:00:00 2001
From: Konstantinos Sideris <sideris.konstantin@gmail.com>
Date: Sat, 2 Jun 2018 12:19:34 +0300
Subject: [PATCH] Add method to force shutdown all active connections

---
 examples/crypto_bot.cpp            | 18 +++++++++++++++
 include/mtxclient/http/client.hpp  | 13 ++++++++++-
 include/mtxclient/http/session.hpp |  9 +++++++-
 lib/http/client.cpp                | 10 ++++-----
 lib/http/session.cpp               | 11 +++++++++
 tests/connection.cpp               | 36 ++++++++++++++++++++++++++++++
 6 files changed, 90 insertions(+), 7 deletions(-)

diff --git a/examples/crypto_bot.cpp b/examples/crypto_bot.cpp
index 7ad14af12..1b0398d24 100644
--- a/examples/crypto_bot.cpp
+++ b/examples/crypto_bot.cpp
@@ -1,6 +1,9 @@
 #include <boost/algorithm/string/predicate.hpp>
 #include <boost/beast.hpp>
 
+#include <csignal>
+#include <cstdlib>
+
 #include "spdlog/spdlog.h"
 #include <atomic>
 #include <iostream>
@@ -843,11 +846,24 @@ join_room_cb(const nlohmann::json &obj, RequestErr err)
         // Fetch device list for all users.
 }
 
+void
+shutdown_handler(int sig)
+{
+        console->warn("received {} signal", sig);
+        console->info("saving storage");
+        console->info("shutting down");
+
+        // The sync calls will stop.
+        client->shutdown();
+}
+
 int
 main()
 {
         spdlog::set_pattern("[%H:%M:%S] [tid %t] [%^%l%$] %v");
 
+        std::signal(SIGINT, shutdown_handler);
+
         std::string username("alice");
         std::string server("localhost");
         std::string password("secret");
@@ -863,5 +879,7 @@ main()
         client->login(username, password, login_cb);
         client->close();
 
+        console->info("exit");
+
         return 0;
 }
diff --git a/include/mtxclient/http/client.hpp b/include/mtxclient/http/client.hpp
index 40e6ab94d..79de93ddb 100644
--- a/include/mtxclient/http/client.hpp
+++ b/include/mtxclient/http/client.hpp
@@ -9,6 +9,8 @@
 #include <boost/beast.hpp>
 #include <boost/iostreams/stream.hpp>
 #include <boost/optional.hpp>
+#include <boost/signals2.hpp>
+#include <boost/signals2/signal_type.hpp>
 #include <boost/thread/thread.hpp>
 #include <json.hpp>
 
@@ -83,6 +85,8 @@ public:
         std::string device_id() const { return device_id_; }
         //! Generate a new transaction id.
         std::string generate_txn_id() { return client::utils::random_token(32, false); }
+        //! Abort all active pending requests.
+        void shutdown() { shutdown_signal(); }
 
         //! Perfom login.
         void login(const std::string &username,
@@ -133,7 +137,7 @@ public:
                          Callback<mtx::responses::RoomInvite> cb);
 
         //! Perform sync.
-        void sync(const SyncOpts &opts, Callback<nlohmann::json> cb);
+        void sync(const SyncOpts &opts, Callback<mtx::responses::Sync> cb);
 
         //! Paginate through room messages.
         void messages(const mtx::identifiers::Room &room_id,
@@ -286,6 +290,8 @@ private:
         std::string next_batch_token_;
         //! The homeserver port to connect.
         uint16_t port_ = 443;
+        //! All the active sessions will shutdown the connection.
+        boost::signals2::signal<void()> shutdown_signal;
 };
 }
 }
@@ -435,6 +441,11 @@ mtx::http::Client::create_session(HeadersCallback<Response> callback)
                   callback(response_data, {}, client_error);
           });
 
+        if (session)
+                shutdown_signal.connect(
+                  boost::signals2::signal<void()>::slot_type(&Session::terminate, session.get())
+                    .track_foreign(session));
+
         return std::move(session);
 }
 
diff --git a/include/mtxclient/http/session.hpp b/include/mtxclient/http/session.hpp
index d59a47ea5..c6d52cbee 100644
--- a/include/mtxclient/http/session.hpp
+++ b/include/mtxclient/http/session.hpp
@@ -77,7 +77,11 @@ struct Session : public std::enable_shared_from_this<Session>
                                                   std::placeholders::_2));
         }
 
+        //! Force shutdown all connections. Pending responses will not be processed.
+        void terminate();
+
 private:
+        void shutdown();
         void on_resolve(boost::system::error_code ec,
                         boost::asio::ip::tcp::resolver::results_type results);
         void on_close(boost::system::error_code ec);
@@ -86,7 +90,10 @@ private:
         void on_read(const boost::system::error_code &ec, std::size_t bytes_transferred);
         void on_request_complete();
         void on_write(const boost::system::error_code &ec, std::size_t bytes_transferred);
-        void shutdown();
+
+        //! Flag to indicate that the connection of this session is closing and no
+        //! response should be processed.
+        std::atomic_bool is_shutting_down_;
 };
 
 template<class Request, boost::beast::http::verb HttpVerb>
diff --git a/lib/http/client.cpp b/lib/http/client.cpp
index ad5f3bae6..1a1039a54 100644
--- a/lib/http/client.cpp
+++ b/lib/http/client.cpp
@@ -176,7 +176,7 @@ Client::invite_user(const mtx::identifiers::Room &room_id,
 }
 
 void
-Client::sync(const SyncOpts &opts, Callback<nlohmann::json> callback)
+Client::sync(const SyncOpts &opts, Callback<mtx::responses::Sync> callback)
 {
         std::map<std::string, std::string> params;
 
@@ -191,10 +191,10 @@ Client::sync(const SyncOpts &opts, Callback<nlohmann::json> callback)
 
         params.emplace("timeout", std::to_string(opts.timeout));
 
-        get<nlohmann::json>("/client/r0/sync?" + mtx::client::utils::query_params(params),
-                            [callback](const nlohmann::json &res, HeaderFields, RequestErr err) {
-                                    callback(res, err);
-                            });
+        get<mtx::responses::Sync>("/client/r0/sync?" + mtx::client::utils::query_params(params),
+                                  [callback](const mtx::responses::Sync &res,
+                                             HeaderFields,
+                                             RequestErr err) { callback(res, err); });
 }
 
 void
diff --git a/lib/http/session.cpp b/lib/http/session.cpp
index 4754afaa5..08891235f 100644
--- a/lib/http/session.cpp
+++ b/lib/http/session.cpp
@@ -16,6 +16,7 @@ Session::Session(boost::asio::io_service &ios,
   , id(std::move(id))
   , on_success(std::move(on_success))
   , on_failure(std::move(on_failure))
+  , is_shutting_down_(false)
 {
         parser.header_limit(8192);
         parser.body_limit(1 * 1024 * 1024 * 1024); // 1 GiB
@@ -75,6 +76,13 @@ Session::on_connect(const boost::system::error_code &ec)
           std::bind(&Session::on_handshake, shared_from_this(), std::placeholders::_1));
 }
 
+void
+Session::terminate()
+{
+        is_shutting_down_ = true;
+        shutdown();
+}
+
 void
 Session::shutdown()
 {
@@ -85,6 +93,9 @@ Session::shutdown()
 void
 Session::on_request_complete()
 {
+        if (is_shutting_down_)
+                return;
+
         boost::system::error_code ec(error_code);
         on_success(id, parser.get(), ec);
 
diff --git a/tests/connection.cpp b/tests/connection.cpp
index 8d3d18214..3d5168546 100644
--- a/tests/connection.cpp
+++ b/tests/connection.cpp
@@ -5,6 +5,8 @@
 #include "mtxclient/http/client.hpp"
 #include "mtxclient/http/errors.hpp"
 
+#include "test_helpers.hpp"
+
 using namespace mtx::http;
 using namespace mtx::client;
 
@@ -29,3 +31,37 @@ TEST(Basic, Failure)
         alice->versions([](const mtx::responses::Versions &, RequestErr err) { ASSERT_TRUE(err); });
         alice->close();
 }
+
+TEST(Basic, Shutdown)
+{
+        std::shared_ptr<Client> client = std::make_shared<Client>("localhost");
+
+        client->login("carl", "secret", [client](const mtx::responses::Login &, RequestErr err) {
+                check_error(err);
+        });
+
+        while (client->access_token().empty())
+                sleep();
+
+        std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
+
+        SyncOpts opts;
+        opts.timeout = 40'000; // milliseconds
+        client->sync(opts, [client, &opts](const mtx::responses::Sync &res, RequestErr err) {
+                check_error(err);
+
+                opts.since = res.next_batch;
+                client->sync(opts, [](const mtx::responses::Sync &, RequestErr) {});
+        });
+
+        std::this_thread::sleep_for(std::chrono::seconds(1));
+
+        // Force terminate all active connections.
+        client->shutdown();
+        client->close();
+
+        std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
+
+        auto diff = std::chrono::duration_cast<std::chrono::seconds>(end - begin).count();
+        ASSERT_TRUE(diff < 5);
+}
-- 
GitLab