diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d7039001e279f80d7cb38b3205c1ffdf6987786..e2465c58cbaa84cb176e19317657edcb64eb21fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -163,6 +163,7 @@ target_sources(matrix_client lib/structs/identifiers.cpp lib/structs/pushrules.cpp lib/structs/requests.cpp + lib/structs/user_interactive.cpp lib/structs/events/aliases.cpp lib/structs/events/avatar.cpp lib/structs/events/canonical_alias.cpp diff --git a/include/mtx.hpp b/include/mtx.hpp index b05e9ff3f0e8898efee0cb5c96792588c6e26430..3de62b61a655baa0b17950084c4fd9cf0c89b224 100644 --- a/include/mtx.hpp +++ b/include/mtx.hpp @@ -26,5 +26,7 @@ #include "mtx/events/messages/text.hpp" #include "mtx/events/messages/video.hpp" +#include "mtx/user_interactive.hpp" + #include "mtx/requests.hpp" #include "mtx/responses.hpp" diff --git a/include/mtx/errors.hpp b/include/mtx/errors.hpp index 267c02fee1821d32646e2df0ccfec58437e840b5..fb12ba87f1501e927368c4829774a4ba4560e897 100644 --- a/include/mtx/errors.hpp +++ b/include/mtx/errors.hpp @@ -7,6 +7,8 @@ #endif #include <string> +#include "user_interactive.hpp" + namespace mtx { namespace errors { @@ -62,6 +64,9 @@ struct Error ErrorCode errcode = {}; //! Human readable version of the error. std::string error; + + //! Auth flows in case of 401 + user_interactive::Unauthorized unauthorized; }; void diff --git a/include/mtx/user_interactive.hpp b/include/mtx/user_interactive.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e102ac630e6b3438f52b33b489777ea672d69f0b --- /dev/null +++ b/include/mtx/user_interactive.hpp @@ -0,0 +1,81 @@ +#pragma once + +#include <string> +#include <string_view> +#include <unordered_map> +#include <variant> +#include <vector> + +#include <nlohmann/json.hpp> + +namespace mtx::user_interactive { +using AuthType = std::string; +namespace auth_types { +constexpr std::string_view password = "m.login.password"; +constexpr std::string_view recaptcha = "m.login.recaptcha"; +constexpr std::string_view oauth2 = "m.login.oauth2"; +constexpr std::string_view email_identity = "m.login.email.identity"; +constexpr std::string_view msisdn = "m.login.msisdn"; +constexpr std::string_view token = "m.login.token"; +constexpr std::string_view dummy = "m.login.dummy"; +constexpr std::string_view terms = "m.login.terms"; // see MSC1692 +} + +using Stages = std::vector<AuthType>; +struct Flow +{ + Stages stages; +}; +void +from_json(const nlohmann::json &obj, Flow &flow); +struct OAuth2Params +{ + std::string uri; +}; +void +from_json(const nlohmann::json &obj, OAuth2Params ¶ms); + +struct PolicyDescription +{ + std::string name; // language specific name + std::string url; // language specific link +}; +void +from_json(const nlohmann::json &obj, PolicyDescription &desc); + +struct Policy +{ + std::string version; + // 2 letter language code to policy name and link, fallback to "en" + // recommended, when language not available. + std::unordered_map<std::string, PolicyDescription> langToPolicy; +}; +void +from_json(const nlohmann::json &obj, Policy &policy); + +struct TermsParams +{ + std::unordered_map<std::string, Policy> policies; +}; +void +from_json(const nlohmann::json &obj, TermsParams ¶ms); + +using Params = std::variant<OAuth2Params, TermsParams, nlohmann::json>; + +struct Unauthorized +{ + // completed stages + Stages completed; + + // session key to provide to further auth stages + std::string session; + + // list of flows, which can be used to complete the UI auth + std::vector<Flow> flows; + + // AuthType may be an undocumented string, not defined in auth_types + std::unordered_map<AuthType, Params> params; +}; +void +from_json(const nlohmann::json &obj, Unauthorized &unauthorized); +} diff --git a/lib/structs/errors.cpp b/lib/structs/errors.cpp index fc034b5ffab5a96365cc77fc358be8e97d754ec8..42940151a4f413c7390be672b95eae47bb5b04c3 100644 --- a/lib/structs/errors.cpp +++ b/lib/structs/errors.cpp @@ -2,8 +2,7 @@ #include <nlohmann/json.hpp> -namespace mtx { -namespace errors { +namespace mtx::errors { std::string to_string(ErrorCode code) { @@ -86,7 +85,9 @@ void from_json(const nlohmann::json &obj, Error &error) { error.errcode = from_string(obj.at("errcode").get<std::string>()); - error.error = obj.at("error").get<std::string>(); -} + error.error = obj.value("error", ""); + + if (obj.contains("session")) + error.unauthorized = obj.get<user_interactive::Unauthorized>(); } } diff --git a/lib/structs/user_interactive.cpp b/lib/structs/user_interactive.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b525bae042eca81b8109542e5e6ae2b80bf3c1ce --- /dev/null +++ b/lib/structs/user_interactive.cpp @@ -0,0 +1,57 @@ +#include "mtx/user_interactive.hpp" + +namespace mtx::user_interactive { +void +from_json(const nlohmann::json &obj, OAuth2Params ¶ms) +{ + params.uri = obj.value("uri", ""); +} + +void +from_json(const nlohmann::json &obj, PolicyDescription &d) +{ + d.name = obj.value("name", ""); + d.url = obj.value("url", ""); +} +void +from_json(const nlohmann::json &obj, Policy &policy) +{ + policy.version = obj.at("version"); + + for (const auto &e : obj.items()) + if (e.key() != "version") + policy.langToPolicy.emplace(e.key(), e.value().get<PolicyDescription>()); +} + +void +from_json(const nlohmann::json &obj, TermsParams &terms) +{ + terms.policies = obj["policies"].get<std::unordered_map<std::string, Policy>>(); +} + +void +from_json(const nlohmann::json &obj, Flow &flow) +{ + flow.stages = obj["stages"].get<Stages>(); +} +void +from_json(const nlohmann::json &obj, Unauthorized &u) +{ + if (obj.contains("completed")) + u.completed = obj.at("completed").get<Stages>(); + + u.session = obj.at("session"); + u.flows = obj.at("flows").get<std::vector<Flow>>(); + + if (obj.contains("params")) { + for (const auto &e : obj["params"].items()) { + if (e.key() == auth_types::terms) + u.params.emplace(e.key(), e.value().get<TermsParams>()); + else if (e.key() == auth_types::oauth2) + u.params.emplace(e.key(), e.value().get<OAuth2Params>()); + else + u.params.emplace(e.key(), e.value()); + } + } +} +} diff --git a/tests/responses.cpp b/tests/responses.cpp index 917426f57994bab5839796897f55d8e834825ff7..7c73f7d1ab2e25b9bb7e442a84dceecd88dbe130 100644 --- a/tests/responses.cpp +++ b/tests/responses.cpp @@ -959,3 +959,88 @@ TEST(Responses, Notifications) EXPECT_EQ(event.content.body, "I am a fish"); EXPECT_EQ(event.sender, "@alice:example.com"); } + +TEST(Responses, Userinteractive) +{ + json data = + R"( +{ + "completed": [ "example.type.foo" ], + "session": "YQVPFRiztSYtmsjLNQmsxTCg", + "flows": [ + { + "stages": [ + "m.login.recaptcha", + "m.login.terms", + "m.login.dummy" + ] + }, + { + "stages": [ + "m.login.recaptcha", + "m.login.terms", + "m.login.email.identity" + ] + } + ], + "params": { + "m.login.recaptcha": { + "public_key": "6LcgI54UAAAAABGdGmruw 6DdOocFpYVdjYBRe4zb" + }, + "m.login.terms": { + "policies": { + "privacy_policy": { + "version": "1.0", + "en": { + "name": "Terms and Conditions", + "url": "https://matrix-client.matrix.org/_matrix/consent?v=1.0" + } + } + } + } + } +})"_json; + mtx::user_interactive::Unauthorized unauthorized = data; + + EXPECT_EQ(unauthorized.completed[0], "example.type.foo"); + EXPECT_EQ(unauthorized.session, "YQVPFRiztSYtmsjLNQmsxTCg"); + EXPECT_EQ(unauthorized.flows.size(), 2); + EXPECT_EQ(unauthorized.flows[0].stages[0], "m.login.recaptcha"); + EXPECT_EQ(unauthorized.flows[0].stages[1], "m.login.terms"); + EXPECT_EQ(unauthorized.flows[0].stages[2], "m.login.dummy"); + EXPECT_EQ(unauthorized.flows[1].stages[0], "m.login.recaptcha"); + EXPECT_EQ(unauthorized.flows[1].stages[1], "m.login.terms"); + EXPECT_EQ(unauthorized.flows[1].stages[2], "m.login.email.identity"); + + EXPECT_EQ(std::get<mtx::user_interactive::TermsParams>( + unauthorized.params[std::string{mtx::user_interactive::auth_types::terms}]) + .policies.size(), + 1); + EXPECT_EQ(std::get<mtx::user_interactive::TermsParams>( + unauthorized.params[std::string{mtx::user_interactive::auth_types::terms}]) + .policies["privacy_policy"] + .version, + "1.0"); + EXPECT_EQ(std::get<mtx::user_interactive::TermsParams>( + unauthorized.params[std::string{mtx::user_interactive::auth_types::terms}]) + .policies["privacy_policy"] + .langToPolicy["en"] + .name, + "Terms and Conditions"); + + json data2 = R"( +{ + "session": "CFNYzCbLYyGTpURjdmkIXMHc", + "flows": [ + { + "stages": [ + "m.login.password" + ] + } + ], + "params": {} +})"_json; + + unauthorized = data2; + EXPECT_EQ(unauthorized.flows[0].stages[0], mtx::user_interactive::auth_types::password); +}