diff --git a/src/client.hpp b/src/client.hpp index dc023f8cd69370364f80a394748340a2ff1f119d..3d3ca86b80afa356e85c78017a3d10b7dd3efb31 100644 --- a/src/client.hpp +++ b/src/client.hpp @@ -281,6 +281,8 @@ private: boost::thread_group thread_group_; //! Used to resolve DNS names. boost::asio::ip::tcp::resolver resolver_; + //! SSL context for requests. + boost::asio::ssl::context ssl_ctx_{boost::asio::ssl::context::sslv23_client}; //! The homeserver to connect to. std::string server_; //! The access token that would be used for authentication. @@ -333,7 +335,7 @@ mtx::client::Client::post(const std::string &endpoint, bool requires_auth, const std::string &content_type) { - std::shared_ptr<Session> session = create_session<Response>( + auto session = create_session<Response>( [callback](const Response &res, HeaderFields, RequestErr err) { callback(res, err); }); if (!session) @@ -361,7 +363,7 @@ mtx::client::Client::put(const std::string &endpoint, Callback<Response> callback, bool requires_auth) { - std::shared_ptr<Session> session = create_session<Response>( + auto session = create_session<Response>( [callback](const Response &res, HeaderFields, RequestErr err) { callback(res, err); }); if (!session) @@ -402,7 +404,7 @@ mtx::client::Client::get(const std::string &endpoint, HeadersCallback<Response> callback, bool requires_auth) { - std::shared_ptr<Session> session = create_session<Response>(callback); + auto session = create_session<Response>(callback); if (!session) return; @@ -423,68 +425,61 @@ template<class Response> std::shared_ptr<mtx::client::Session> mtx::client::Client::create_session(HeadersCallback<Response> callback) { - boost::asio::ssl::context ssl_ctx{boost::asio::ssl::context::sslv23_client}; - std::shared_ptr<Session> session = std::make_shared<Session>( ios_, - ssl_ctx, + ssl_ctx_, server_, utils::random_token(), - [callback, _this = shared_from_this()]( - RequestID, - const boost::beast::http::response<boost::beast::http::string_body> &response, - const boost::system::error_code &err_code) { - _this->ios_.post([callback, response, err_code]() { - Response response_data; - mtx::client::errors::ClientError client_error; - - const auto header = response.base(); - - if (err_code) { - client_error.error_code = err_code; - return callback(response_data, header, client_error); - } + [callback](RequestID, + const boost::beast::http::response<boost::beast::http::string_body> &response, + const boost::system::error_code &err_code) { + Response response_data; + mtx::client::errors::ClientError client_error; + + const auto header = response.base(); + + if (err_code) { + client_error.error_code = err_code; + return callback(response_data, header, client_error); + } + + // Decompress the response. + const auto body = utils::decompress( + boost::iostreams::array_source{response.body().data(), response.body().size()}, + header["Content-Encoding"].to_string()); - // Decompress the response. - const auto body = utils::decompress( - boost::iostreams::array_source{response.body().data(), - response.body().size()}, - header["Content-Encoding"].to_string()); - - if (response.result() != boost::beast::http::status::ok) { - client_error.status_code = response.result(); - - // Try to parse the response in case we have an endpoint that - // doesn't return an error struct for non 200 requests. - try { - response_data = deserialize<Response>(body); - } catch (const nlohmann::json::exception &e) { - } - - // The homeserver should return an error struct. - try { - nlohmann::json json_error = json::parse(body); - mtx::errors::Error matrix_error = json_error; - - client_error.matrix_error = matrix_error; - return callback(response_data, header, client_error); - } catch (const nlohmann::json::exception &e) { - client_error.parse_error = - std::string(e.what()) + ": " + body; - - return callback(response_data, header, client_error); - } + if (response.result() != boost::beast::http::status::ok) { + client_error.status_code = response.result(); + + // Try to parse the response in case we have an endpoint that + // doesn't return an error struct for non 200 requests. + try { + response_data = deserialize<Response>(body); + } catch (const nlohmann::json::exception &e) { } - // If we reach that point we most likely have a valid output from the - // homeserver. + // The homeserver should return an error struct. try { - callback(deserialize<Response>(body), header, {}); + nlohmann::json json_error = json::parse(body); + mtx::errors::Error matrix_error = json_error; + + client_error.matrix_error = matrix_error; + return callback(response_data, header, client_error); } catch (const nlohmann::json::exception &e) { client_error.parse_error = std::string(e.what()) + ": " + body; - callback(response_data, header, client_error); + + return callback(response_data, header, client_error); } - }); + } + + // If we reach that point we most likely have a valid output from the + // homeserver. + try { + callback(deserialize<Response>(body), header, {}); + } catch (const nlohmann::json::exception &e) { + client_error.parse_error = std::string(e.what()) + ": " + body; + callback(response_data, header, client_error); + } }, [callback](RequestID, const boost::system::error_code ec) { Response response_data;