// crypto_bot.cpp
    #include <boost/algorithm/string/predicate.hpp>
    #include <boost/beast.hpp>

    #include "spdlog/spdlog.h"

    #include <atomic>
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <memory>
    #include <stdexcept>
    #include <string>
    #include <vector>

    #include <json.hpp>

    #include <unistd.h>
    #include <variant.hpp>

    #include <mtx.hpp>
    #include <mtx/identifiers.hpp>

    #include "mtxclient/crypto/client.hpp"
    #include "mtxclient/http/client.hpp"
    #include "mtxclient/http/errors.hpp"

    #include "mtxclient/utils.hpp"
    #include "test_helpers.hpp"
    
    //
    // Simple example bot that will accept any invite and try to decrypt the
    // olm and megolm messages it receives.
    //
    
    using namespace std;
    using namespace mtx::client;
    using namespace mtx::crypto;
    using namespace mtx::http;
    using namespace mtx::events;
    using namespace mtx::identifiers;
    
    using TimelineEvent = mtx::events::collections::TimelineEvents;
    
    
    //! Algorithm identifier that the `content.algorithm` field of an event must
    //! carry for from_json(OlmMessage) to accept it as an olm message.
    constexpr auto OLM_ALGO = "m.olm.v1.curve25519-aes-sha2";
    
    //! A single entry of the `ciphertext` map of an olm message.
    struct OlmCipherContent
    {
            //! Value of the entry's `body` field.
            std::string body;
            //! Value of the entry's `type` field. Zero-initialized so that a
            //! default-constructed object never exposes an indeterminate value.
            uint8_t type = 0;
    };
    
    //! Fills `msg` from a JSON object. Both the `body` and the `type` field are
    //! looked up with `at()`, so both are required.
    inline void
    from_json(const nlohmann::json &obj, OlmCipherContent &msg)
    {
            msg.body = obj.at("body").get<std::string>();
            msg.type = obj.at("type").get<uint8_t>();
    }
    
    //! Parsed form of an olm-encrypted `m.room.encrypted` event.
    struct OlmMessage
    {
            //! Key type of `ciphertext`.
            using RecipientKey = std::string;

            //! Value of the event's `content.sender_key` field.
            std::string sender_key;
            //! Value of the event's top-level `sender` field.
            std::string sender;

            //! Value of the event's `content.ciphertext` field: one encrypted
            //! payload per recipient key.
            std::map<RecipientKey, OlmCipherContent> ciphertext;
    };
    
    //! Builds an OlmMessage from the JSON of an event.
    //!
    //! Throws std::invalid_argument when the event `type` is not
    //! `m.room.encrypted` or when `content.algorithm` differs from OLM_ALGO.
    //! All fields are read with `at()`, so every one of them is required.
    inline void
    from_json(const nlohmann::json &obj, OlmMessage &msg)
    {
            using CiphertextMap = std::map<std::string, OlmCipherContent>;

            if (obj.at("type") != "m.room.encrypted")
                    throw std::invalid_argument("invalid type for olm message");

            const auto &content = obj.at("content");
            if (content.at("algorithm") != OLM_ALGO)
                    throw std::invalid_argument("invalid algorithm for olm message");

            msg.sender     = obj.at("sender").get<std::string>();
            msg.sender_key = content.at("sender_key").get<std::string>();
            msg.ciphertext = content.at("ciphertext").get<CiphertextMap>();
    }
    
    
    //! Returns true when the container's own `find()` locates `item`.
    template<class Container, class Item>
    bool
    exists(const Container &container, const Item &item)
    {
            const auto position = container.find(item);
            const auto not_found = container.end();

            return position != not_found;
    }
    
    void
    get_device_keys(const UserId &user);
    
    void
    save_device_keys(const mtx::responses::QueryKeys &res);
    
    void
    mark_encrypted_room(const RoomId &id);
    
    
    Konstantinos Sideris's avatar
    Konstantinos Sideris committed
    void
    handle_to_device_msgs(const std::vector<nlohmann::json> &to_device);
    
    
    //! Metadata associated with each active megolm session.
    struct GroupSessionMsgData
    {
            std::string session_id;
            std::string room_id;
            std::string event_id;

            // Zero-initialized so that a default-constructed object never
            // exposes indeterminate values.
            uint64_t origin_server_ts = 0;
            uint64_t message_index    = 0;
    };
    
    struct Storage
    {
            //! Storage for the user_id -> list of devices mapping.
            std::map<std::string, std::vector<std::string>> devices_;
            //! Storage for the identity key for a device.
            std::map<std::string, std::string> device_keys_;
            //! Flag that indicate if a specific room has encryption enabled.
            std::map<std::string, bool> encrypted_rooms_;
    
    
            //! Mapping from curve25519 to session.
            std::map<std::string, OlmSessionPtr> olm_sessions;
    
            std::map<std::string, InboundGroupSessionPtr> inbound_group_sessions;
    
            bool inbound_group_exists(const std::string &room_id,
                                      const std::string &session_id,
                                      const std::string &sender_key)
            {
                    const auto key = room_id + session_id + sender_key;
                    return inbound_group_sessions.find(key) != inbound_group_sessions.end();
            }
    
            void set_inbound_group_session(const std::string &room_id,
                                           const std::string &session_id,
                                           const std::string &sender_key,
                                           InboundGroupSessionPtr session)
            {
                    const auto key              = room_id + session_id + sender_key;
                    inbound_group_sessions[key] = std::move(session);
            }
    
            OlmInboundGroupSession *get_inbound_group_session(const std::string &room_id,
                                                              const std::string &session_id,
                                                              const std::string &sender_key)
            {
                    const auto key = room_id + session_id + sender_key;
                    return inbound_group_sessions[key].get();
            }
    
    };
    
    // File-local state shared by all the callbacks below.
    namespace {
    //! Client used for the requests; created in main().
    std::shared_ptr<Client> client;
    //! Olm account wrapper; created in main().
    std::shared_ptr<OlmClient> olm_client;
    //! Devices, encrypted rooms and crypto sessions seen so far.
    Storage storage;

    //! Logger used for all the output of the bot.
    auto console = spdlog::stdout_color_mt("console");
    } // namespace
    
    //! Logs the details carried by the failed request `err`: the HTTP status
    //! (unless it is `unknown`), the matrix error text (unless it is empty) and
    //! the message of the error code (when one is set).
    //! `err` is dereferenced unconditionally, so callers have to check it first.
    void
    print_errors(RequestErr err)
    {
            const auto &failure = *err;

            if (failure.status_code != boost::beast::http::status::unknown) {
                    const auto status = static_cast<uint16_t>(failure.status_code);
                    console->error("status code: {}", status);
            }

            const auto &matrix_msg = failure.matrix_error.error;
            if (!matrix_msg.empty()) {
                    console->error("matrix error: {}", matrix_msg);
            }

            if (failure.error_code) {
                    console->error("error code: {}", failure.error_code.message());
            }
    }
    
    template<class T>
    bool
    is_room_encryption(const T &event)
    {
            using namespace mtx::events;
            using namespace mtx::events::state;
            return mpark::holds_alternative<StateEvent<Encryption>>(event);
    }
    
    bool
    is_encrypted(const TimelineEvent &event)
    {
            using namespace mtx::events;
            return mpark::holds_alternative<EncryptedEvent<msg::Encrypted>>(event);
    }
    
    template<class T>
    bool
    is_member_event(const T &event)
    {
            using namespace mtx::events;
            using namespace mtx::events::state;
            return mpark::holds_alternative<StateEvent<Member>>(event);
    }
    
    // Check if the given event is one of the message kinds that carry a body:
    // audio, emote, file, image, notice, text or video.
    bool
    is_room_message(const TimelineEvent &event)
    {
            if (mpark::holds_alternative<mtx::events::RoomEvent<msg::Audio>>(event))
                    return true;
            if (mpark::holds_alternative<mtx::events::RoomEvent<msg::Emote>>(event))
                    return true;
            if (mpark::holds_alternative<mtx::events::RoomEvent<msg::File>>(event))
                    return true;
            if (mpark::holds_alternative<mtx::events::RoomEvent<msg::Image>>(event))
                    return true;
            if (mpark::holds_alternative<mtx::events::RoomEvent<msg::Notice>>(event))
                    return true;
            if (mpark::holds_alternative<mtx::events::RoomEvent<msg::Text>>(event))
                    return true;

            return mpark::holds_alternative<mtx::events::RoomEvent<msg::Video>>(event);
    }
    
    // Returns the `content.body` of the event when it holds one of the message
    // kinds (audio, emote, file, image, notice, text, video); otherwise an
    // empty string.
    std::string
    get_body(const TimelineEvent &event)
    {
            if (const auto *audio = mpark::get_if<RoomEvent<msg::Audio>>(&event))
                    return audio->content.body;
            if (const auto *emote = mpark::get_if<RoomEvent<msg::Emote>>(&event))
                    return emote->content.body;
            if (const auto *file = mpark::get_if<RoomEvent<msg::File>>(&event))
                    return file->content.body;
            if (const auto *image = mpark::get_if<RoomEvent<msg::Image>>(&event))
                    return image->content.body;
            if (const auto *notice = mpark::get_if<RoomEvent<msg::Notice>>(&event))
                    return notice->content.body;
            if (const auto *text = mpark::get_if<RoomEvent<msg::Text>>(&event))
                    return text->content.body;
            if (const auto *video = mpark::get_if<RoomEvent<msg::Video>>(&event))
                    return video->content.body;

            return "";
    }
    
    // Returns the `sender` member of whichever alternative the event holds.
    std::string
    get_sender(const TimelineEvent &event)
    {
            const auto read_sender = [](const auto &held) { return held.sender; };

            return mpark::visit(read_sender, event);
    }
    
    //! Serializes whichever alternative the variant `event` holds to JSON and
    //! returns it dumped with an indentation of two.
    template<class T>
    std::string
    get_json(const T &event)
    {
            const auto to_text = [](const auto &held) { return json(held).dump(2); };

            return mpark::visit(to_text, event);
    }
    
    //! Callback for a key upload: logs the error details on failure and a
    //! confirmation otherwise. The response itself is not used.
    void
    keys_uploaded_cb(const mtx::responses::UploadKeys &, RequestErr err)
    {
            if (!err) {
                    console->info("keys uploaded");
                    return;
            }

            print_errors(err);
    }
    
    //! Logs that encryption is enabled for the room and records the flag in
    //! `storage.encrypted_rooms_`.
    void
    mark_encrypted_room(const RoomId &id)
    {
            const auto &room = id.get();

            console->info("encryption is enabled for room: {}", room);
            storage.encrypted_rooms_[room] = true;
    }
    
    
    //! Handles an olm message that was sent to one of our devices.
    //!
    //! When the message has a ciphertext for our curve25519 identity key and
    //! that ciphertext is of type 0, an inbound olm session is created from it,
    //! the payload is decrypted and the megolm session described by its
    //! `content` (`room_id`, `session_id`, `session_key`) is stored, unless one
    //! already exists. Ciphertexts of any other type are only logged.
    void
    decrypt_olm_message(const OlmMessage &olm_msg)
    {
            console->info("OLM message");
            console->info("sender: {}", olm_msg.sender);
            console->info("sender_key: {}", olm_msg.sender_key);

            // The ciphertext map is keyed by recipient key, so at most one
            // entry can be addressed to us.
            const auto own_key = olm_client->identity_keys().curve25519;
            const auto entry   = olm_msg.ciphertext.find(own_key);
            if (entry == olm_msg.ciphertext.end())
                    return;

            const auto &body = entry->second.body;
            const auto type  = entry->second.type;

            console->info("the message is meant for us");
            console->info("body: {}", body);
            console->info("type: {}", type);

            if (type != 0)
                    return;

            console->info("opening session with {}", olm_msg.sender);
            auto session = olm_client->create_inbound_session(body);

            const auto matches =
              matches_inbound_session_from(session.get(), olm_msg.sender_key, body);
            if (!matches) {
                    console->error("session could not be established");
                    return;
            }

            auto decrypted = olm_client->decrypt_message(session.get(), type, body);

            const std::string decrypted_text(reinterpret_cast<const char *>(decrypted.data()),
                                             decrypted.size());
            auto plaintext = json::parse(decrypted_text);
            console->info("decrypted message: \n {}", plaintext.dump(2));

            storage.olm_sessions.emplace(olm_msg.sender_key, std::move(session));

            const auto &content           = plaintext.at("content");
            const std::string room_id     = content.at("room_id");
            const std::string session_id  = content.at("session_id");
            const std::string session_key = content.at("session_key");

            if (storage.inbound_group_exists(room_id, session_id, olm_msg.sender_key)) {
                    console->warn("megolm session already exists");
                    return;
            }

            auto megolm_session = olm_client->init_inbound_group_session(session_key);
            storage.set_inbound_group_session(
              room_id, session_id, olm_msg.sender_key, std::move(megolm_session));

            console->info("megolm_session saved");
    }
    
    
    //! Processes a /sync response:
    //!  - joins every room we were invited to,
    //!  - hands the to-device messages to handle_to_device_msgs(),
    //!  - uploads new one time keys when fewer than 50 are reported,
    //!  - flags rooms that contain an encryption event and tries to decrypt
    //!    the encrypted timeline events of the joined rooms.
    void
    parse_messages(const mtx::responses::Sync &res)
    {
            for (const auto &room : res.rooms.invite) {
                    auto room_id = parse<Room>(room.first);

                    console->info("joining room {}", room_id.to_string());
                    client->join_room(room_id, [room_id](const nlohmann::json &, RequestErr e) {
                            if (e) {
                                    print_errors(e);
                                    console->error("failed to join room {}", room_id.to_string());
                            }
                    });
            }

            // Check if we have any new m.room_key messages (i.e starting a new megolm session)
            handle_to_device_msgs(res.to_device);

            // Check if the uploaded one time keys are enough
            for (const auto &device : res.device_one_time_keys_count) {
                    if (device.second < 50) {
                            olm_client->generate_one_time_keys(50 - device.second);
                            // TODO: Mark keys as sent
                            client->upload_keys(olm_client->create_upload_keys_request(),
                                                &keys_uploaded_cb);
                    }
            }

            for (const auto &room : res.rooms.join) {
                    const std::string &room_id = room.first;

                    for (const auto &e : room.second.state.events) {
                            if (is_room_encryption(e)) {
                                    mark_encrypted_room(RoomId(room_id));
                                    console->debug("{}", get_json(e));
                            }
                    }

                    for (const auto &e : room.second.timeline.events) {
                            if (is_room_encryption(e)) {
                                    mark_encrypted_room(RoomId(room_id));
                                    console->debug("{}", get_json(e));
                                    continue;
                            }

                            if (!is_encrypted(e))
                                    continue;

                            console->info("received an encrypted event: {}", room_id);
                            console->info("{}", get_json(e));

                            // Bound by reference (no copy of the event) and named so
                            // that it hides neither the `msg` namespace nor `res`.
                            const auto &encrypted = mpark::get<EncryptedEvent<msg::Encrypted>>(e);
                            const auto &content   = encrypted.content;

                            if (!storage.inbound_group_exists(
                                  room_id, content.session_id, content.sender_key)) {
                                    console->warn("no megolm session found to decrypt the event");
                                    continue;
                            }

                            try {
                                    auto decrypted = olm_client->decrypt_group_message(
                                      storage.get_inbound_group_session(
                                        room_id, content.session_id, content.sender_key),
                                      content.ciphertext);

                                    const std::string msg_str(
                                      reinterpret_cast<const char *>(decrypted.data.data()),
                                      decrypted.data.size());

                                    console->info("decrypted data: {}", msg_str);
                                    console->info("decrypted message_index: {}",
                                                  decrypted.message_index);
                            } catch (const olm_exception &ex) {
                                    // A single undecryptable event must not take the
                                    // whole sync loop down.
                                    console->warn("failed to decrypt megolm message: {}",
                                                  ex.what());
                            }
                    }
            }
    }
    
    // Callback to be executed after a /sync request completes.
    //
    // On success the response is processed and its `next_batch` token is both
    // stored in the client and used for the next request; on failure the token
    // stored in the client is reused. In either case the next /sync is started.
    void
    sync_handler(const mtx::responses::Sync &res, RequestErr err)
    {
            SyncOpts opts;

            if (!err) {
                    parse_messages(res);

                    opts.since = res.next_batch;
                    client->set_next_batch_token(res.next_batch);
            } else {
                    console->error("error during sync");
                    print_errors(err);

                    opts.since = client->next_batch_token();
            }

            client->sync(opts, &sync_handler);
    }
    
    // Callback to be executed after the first (initial) /sync request completes.
    //
    // A failed request is retried (with a timeout of 0) unless the reported
    // status code is `ok`. On success the response is processed, the device
    // keys of every member found in the state of the joined rooms are
    // requested and the regular sync loop is started.
    void
    initial_sync_handler(const mtx::responses::Sync &res, RequestErr err)
    {
            using MemberEvent = mtx::events::StateEvent<mtx::events::state::Member>;

            SyncOpts opts;

            if (err) {
                    console->error("error during initial sync");
                    print_errors(err);

                    if (err->status_code == boost::beast::http::status::ok)
                            return;

                    console->error("retrying initial sync ..");
                    opts.timeout = 0;
                    client->sync(opts, &initial_sync_handler);
                    return;
            }

            parse_messages(res);

            for (const auto &room : res.rooms.join) {
                    for (const auto &e : room.second.state.events) {
                            if (!is_member_event(e))
                                    continue;

                            const auto &member = mpark::get<MemberEvent>(e);
                            get_device_keys(UserId(member.state_key));
                    }
            }

            opts.since = res.next_batch;
            client->set_next_batch_token(res.next_batch);
            client->sync(opts, &sync_handler);
    }
    
    //! Records the devices of a key query response in `storage`.
    //!
    //! For every device that has a `curve25519:<device_id>` key, the key is
    //! stored under the device id unless one is already stored. The resulting
    //! device list is saved only for users that have no list yet.
    void
    save_device_keys(const mtx::responses::QueryKeys &res)
    {
            for (const auto &entry : res.device_keys) {
                    const auto &user_id = entry.first;

                    // Nothing below touches `devices_` before the final store,
                    // so the lookup is done once.
                    const bool known_user =
                      storage.devices_.find(user_id) != storage.devices_.end();

                    if (!known_user)
                            console->info("keys for {}", user_id);

                    std::vector<std::string> device_list;
                    for (const auto &device : entry.second) {
                            const auto &key_struct       = device.second;
                            const std::string &device_id = key_struct.device_id;

                            const auto found = key_struct.keys.find("curve25519:" + device_id);
                            if (found == key_struct.keys.end())
                                    continue;

                            const auto &key = found->second;

                            const bool known_device = storage.device_keys_.find(device_id) !=
                                                      storage.device_keys_.end();
                            if (!known_device) {
                                    console->info("{} => {}", device_id, key);
                                    storage.device_keys_[device_id] = key;
                            }

                            device_list.push_back(device_id);
                    }

                    if (!known_user)
                            storage.devices_[user_id] = std::move(device_list);
            }
    }
    
    //! Queries the device keys of `user`. In the response callback the
    //! signature of every returned device is checked against its
    //! `ed25519:<device_id>` key and the response is then passed to
    //! save_device_keys().
    void
    get_device_keys(const UserId &user)
    {
            // Retrieve all devices keys.
            mtx::requests::QueryKeys query;
            query.device_keys[user.get()] = {};

            client->query_keys(query, [](const mtx::responses::QueryKeys &res, RequestErr err) {
                    if (err) {
                            print_errors(err);
                            return;
                    }

                    for (const auto &key : res.device_keys) {
                            const auto &user_id = key.first;

                            for (const auto &device : key.second) {
                                    const auto &id   = device.first;
                                    const auto &data = device.second;

                                    // A device without an ed25519 key cannot be
                                    // verified; skip it instead of letting at()
                                    // throw std::out_of_range out of the callback.
                                    const auto signing_key = data.keys.find("ed25519:" + id);
                                    if (signing_key == data.keys.end()) {
                                            console->warn("no ed25519 key found for device {}",
                                                          id);
                                            continue;
                                    }

                                    try {
                                            auto ok =
                                              verify_identity_signature(json(data),
                                                                        DeviceId(id),
                                                                        UserId(user_id),
                                                                        signing_key->second);

                                            if (!ok) {
                                                    console->warn("signature could not be verified");
                                                    console->warn("{}", json(data).dump(2));
                                            }
                                    } catch (const olm_exception &e) {
                                            console->warn("{}", e.what());
                                    }
                            }
                    }

                    // NOTE(review): the keys are stored even when a signature
                    // could not be verified - confirm that this is intended.
                    save_device_keys(res);
            });
    }
    
    
    Konstantinos Sideris's avatar
    Konstantinos Sideris committed
    void
    
    handle_to_device_msgs(const std::vector<nlohmann::json> &msgs)
    
    Konstantinos Sideris's avatar
    Konstantinos Sideris committed
    {
    
            if (!msgs.empty())
                    console->info("inspecting {} to_device messages", msgs.size());
    
            for (const auto &msg : msgs) {
                    console->info(msg.dump(2));
    
                    try {
                            OlmMessage olm_msg = msg;
                            decrypt_olm_message(std::move(olm_msg));
                    } catch (const nlohmann::json::exception &e) {
                            console->warn("parsing error for olm message: {}", e.what());
                    } catch (const std::invalid_argument &e) {
                            console->warn("validation error for olm message: {}", e.what());
                    }
            }
    
    //! Callback for the login request. After a successful login the olm client
    //! is given the user and device id, 50 one time keys are generated and the
    //! keys are uploaded; once that upload succeeds the initial sync is started.
    void
    login_cb(const mtx::responses::Login &, RequestErr err)
    {
            if (err) {
                    console->error("login error");
                    print_errors(err);
                    return;
            }

            const auto user_id   = client->user_id().to_string();
            const auto device_id = client->device_id();

            console->info("User ID: {}", user_id);
            console->info("Device ID: {}", device_id);

            // Upload one time keys.
            olm_client->set_user_id(user_id);
            olm_client->set_device_id(device_id);
            olm_client->generate_one_time_keys(50);

            const auto start_initial_sync = [](const mtx::responses::UploadKeys &,
                                               RequestErr upload_err) {
                    if (upload_err) {
                            print_errors(upload_err);
                            return;
                    }

                    console->info("keys uploaded");
                    console->debug("starting initial sync");

                    SyncOpts opts;
                    opts.timeout = 0;
                    client->sync(opts, &initial_sync_handler);
            };

            client->upload_keys(olm_client->create_upload_keys_request(), start_initial_sync);
    }
    
    //! Callback for a join request: logs the error details on failure. The
    //! response is not used.
    void
    join_room_cb(const nlohmann::json &obj, RequestErr err)
    {
            (void)obj;

            if (err) {
                    print_errors(err);
                    return;
            }

            // Fetch device list for all users.
    }
    
    int
    main()
    {
            spdlog::set_pattern("[%H:%M:%S] [tid %t] [%^%l%$] %v");
    
    
            std::string username("mtx_bot");
            std::string server("matrix.org");
            std::string password("dzyvrwB09GdyEqiyBnfAEvZI3");
    
            // cout << "username: ";
            // std::getline(std::cin, username);
    
            // cout << "server: ";
            // std::getline(std::cin, server);
    
            // password = getpass("password: ");
    
    
            client = std::make_shared<Client>(server);
    
            olm_client = make_shared<OlmClient>();
            olm_client->create_new_account();
    
    
            console->info("ed25519: {}", olm_client->identity_keys().ed25519);
            console->info("curve25519: {}", olm_client->identity_keys().curve25519);
    
    
            client->login(username, password, login_cb);
            client->close();
    
            return 0;
    }