From 7e522df71c804cbc04f38007c3958ebbb56815d3 Mon Sep 17 00:00:00 2001 From: Konstantinos Sideris <sideris.konstantin@gmail.com> Date: Sun, 20 May 2018 15:55:38 +0300 Subject: [PATCH] Implement decryption for group events on the crypto_bot --- examples/crypto_bot.cpp | 196 ++++++++++++++++++++++++++-- include/mtxclient/crypto/client.hpp | 34 +++-- lib/crypto/client.cpp | 67 +++++++++- tests/e2ee.cpp | 21 +-- 4 files changed, 286 insertions(+), 32 deletions(-) diff --git a/examples/crypto_bot.cpp b/examples/crypto_bot.cpp index 8171af80b..70915fd0d 100644 --- a/examples/crypto_bot.cpp +++ b/examples/crypto_bot.cpp @@ -5,6 +5,7 @@ #include <atomic> #include <iostream> #include <json.hpp> +#include <stdexcept> #include <unistd.h> #include <variant.hpp> @@ -31,6 +32,46 @@ using namespace mtx::identifiers; using TimelineEvent = mtx::events::collections::TimelineEvents; +constexpr auto OLM_ALGO = "m.olm.v1.curve25519-aes-sha2"; + +struct OlmCipherContent +{ + std::string body; + uint8_t type; +}; + +inline void +from_json(const nlohmann::json &obj, OlmCipherContent &msg) +{ + msg.body = obj.at("body"); + msg.type = obj.at("type"); +} + +struct OlmMessage +{ + std::string sender_key; + std::string sender; + + using RecipientKey = std::string; + std::map<RecipientKey, OlmCipherContent> ciphertext; +}; + +inline void +from_json(const nlohmann::json &obj, OlmMessage &msg) +{ + if (obj.at("type") != "m.room.encrypted") { + throw std::invalid_argument("invalid type for olm message"); + } + + if (obj.at("content").at("algorithm") != OLM_ALGO) + throw std::invalid_argument("invalid algorithm for olm message"); + + msg.sender = obj.at("sender"); + msg.sender_key = obj.at("content").at("sender_key"); + msg.ciphertext = + obj.at("content").at("ciphertext").get<std::map<std::string, OlmCipherContent>>(); +} + template<class Container, class Item> bool exists(const Container &container, const Item &item) @@ -69,6 +110,36 @@ struct Storage std::map<std::string, std::string> device_keys_; //! Flag that indicate if a specific room has encryption enabled. std::map<std::string, bool> encrypted_rooms_; + + //! Mapping from curve25519 to session. + std::map<std::string, OlmSessionPtr> olm_sessions; + + std::map<std::string, InboundGroupSessionPtr> inbound_group_sessions; + + bool inbound_group_exists(const std::string &room_id, + const std::string &session_id, + const std::string &sender_key) + { + const auto key = room_id + session_id + sender_key; + return inbound_group_sessions.find(key) != inbound_group_sessions.end(); + } + + void set_inbound_group_session(const std::string &room_id, + const std::string &session_id, + const std::string &sender_key, + InboundGroupSessionPtr session) + { + const auto key = room_id + session_id + sender_key; + inbound_group_sessions[key] = std::move(session); + } + + OlmInboundGroupSession *get_inbound_group_session(const std::string &room_id, + const std::string &session_id, + const std::string &sender_key) + { + const auto key = room_id + session_id + sender_key; + return inbound_group_sessions[key].get(); + } }; namespace { @@ -182,6 +253,73 @@ mark_encrypted_room(const RoomId &id) storage.encrypted_rooms_[id.get()] = true; } +void +decrypt_olm_message(const OlmMessage &olm_msg) +{ + console->info("OLM message"); + console->info("sender: {}", olm_msg.sender); + console->info("sender_key: {}", olm_msg.sender_key); + + const auto my_id_key = olm_client->identity_keys().curve25519; + for (const auto &cipher : olm_msg.ciphertext) { + if (cipher.first == my_id_key) { + const auto msg_body = cipher.second.body; + const auto msg_type = cipher.second.type; + + console->info("the message is meant for us"); + console->info("body: {}", msg_body); + console->info("type: {}", msg_type); + + if (msg_type == 0) { + console->info("opening session with {}", olm_msg.sender); + auto inbound_session = olm_client->create_inbound_session(msg_body); + + auto ok = matches_inbound_session_from( + inbound_session.get(), olm_msg.sender_key, msg_body); + + if (!ok) { + console->error("session could not be established"); + + } else { + auto output = olm_client->decrypt_message( + inbound_session.get(), msg_type, msg_body); + + auto plaintext = json::parse( + std::string((char *)output.data(), output.size())); + console->info("decrypted message: \n {}", + plaintext.dump(2)); + + storage.olm_sessions.emplace(olm_msg.sender_key, + std::move(inbound_session)); + + std::string room_id = plaintext.at("content").at("room_id"); + std::string session_id = + plaintext.at("content").at("session_id"); + std::string session_key = + plaintext.at("content").at("session_key"); + + if (storage.inbound_group_exists( + room_id, session_id, olm_msg.sender_key)) { + console->warn("megolm session already exists"); + } else { + auto megolm_session = + olm_client->init_inbound_group_session( + session_key); + + storage.set_inbound_group_session( + room_id, + session_id, + olm_msg.sender_key, + std::move(megolm_session)); + + console->info("megolm_session saved"); + } + } + } + } + } +} + void parse_messages(const mtx::responses::Sync &res) { @@ -227,7 +365,28 @@ parse_messages(const mtx::responses::Sync &res) console->debug("{}", get_json(e)); } else if (is_encrypted(e)) { console->info("received an encrypted event: {}", room_id); - console->debug("{}", get_json(e)); + console->info("{}", get_json(e)); + + auto msg = mpark::get<EncryptedEvent<msg::Encrypted>>(e); + + if (storage.inbound_group_exists( + room_id, msg.content.session_id, msg.content.sender_key)) { + auto res = olm_client->decrypt_group_message( + storage.get_inbound_group_session(room_id, + msg.content.session_id, + msg.content.sender_key), + msg.content.ciphertext); + + auto msg_str = + std::string((char *)res.data.data(), res.data.size()); + + console->info("decrypted data: {}", msg_str); + console->info("decrypted message_index: {}", + res.message_index); + } else { + console->warn( + "no megolm session found to decrypt the event"); + } } } } @@ -368,9 +527,23 @@ get_device_keys(const UserId &user) } void -handle_to_device_msgs(const std::vector<nlohmann::json> &to_device) +handle_to_device_msgs(const std::vector<nlohmann::json> &msgs) { - (void)to_device; + if (!msgs.empty()) + console->info("inspecting {} to_device messages", msgs.size()); + + for (const auto &msg : msgs) { + console->info(msg.dump(2)); + + try { + OlmMessage olm_msg = msg; + decrypt_olm_message(std::move(olm_msg)); + } catch (const nlohmann::json::exception &e) { + console->warn("parsing error for olm message: {}", e.what()); + } catch (const std::invalid_argument &e) { + console->warn("validation error for olm message: {}", e.what()); + } + } } void @@ -424,21 +597,26 @@ main() { spdlog::set_pattern("[%H:%M:%S] [tid %t] [%^%l%$] %v"); - std::string username, server, password; + std::string username("mtx_bot"); + std::string server("matrix.org"); + std::string password("dzyvrwB09GdyEqiyBnfAEvZI3"); - cout << "username: "; - std::getline(std::cin, username); + // cout << "username: "; + // std::getline(std::cin, username); - cout << "server: "; - std::getline(std::cin, server); + // cout << "server: "; + // std::getline(std::cin, server); - password = getpass("password: "); + // password = getpass("password: "); client = std::make_shared<Client>(server); olm_client = make_shared<OlmClient>(); olm_client->create_new_account(); + console->info("ed25519: {}", olm_client->identity_keys().ed25519); + console->info("curve25519: {}", olm_client->identity_keys().curve25519); + client->login(username, password, login_cb); client->close(); diff --git a/include/mtxclient/crypto/client.hpp b/include/mtxclient/crypto/client.hpp index b100f1c55..1b64e450b 100644 --- a/include/mtxclient/crypto/client.hpp +++ b/include/mtxclient/crypto/client.hpp @@ -42,6 +42,10 @@ public: : msg_(func + ": " + std::string(olm_outbound_group_session_last_error(s))) {} + olm_exception(std::string func, OlmInboundGroupSession *s) + : msg_(func + ": " + std::string(olm_inbound_group_session_last_error(s))) + {} + olm_exception(std::string msg) : msg_(msg) {} @@ -127,6 +131,16 @@ create_olm_object() return std::unique_ptr<T, OlmDeleter>(OlmAllocator<T>::allocate()); } +using OlmSessionPtr = std::unique_ptr<OlmSession, OlmDeleter>; +using OutboundGroupSessionPtr = std::unique_ptr<OlmOutboundGroupSession, OlmDeleter>; +using InboundGroupSessionPtr = std::unique_ptr<OlmInboundGroupSession, OlmDeleter>; + +struct GroupPlaintext +{ + BinaryBuf data; + uint32_t message_index; +}; + class OlmClient : public std::enable_shared_from_this<OlmClient> { public: @@ -170,6 +184,10 @@ public: mtx::requests::UploadKeys create_upload_keys_request(const OneTimeKeys &keys); mtx::requests::UploadKeys create_upload_keys_request(); + //! Decrypt a message using megolm. + GroupPlaintext decrypt_group_message(OlmInboundGroupSession *session, + const std::string &message, + uint32_t message_index = 0); //! Encrypt a message using olm. BinaryBuf encrypt_message(OlmSession *session, const std::string &msg); //! Decrypt a message using olm. @@ -178,12 +196,12 @@ public: const std::string &msg); //! Create an outbount megolm session. - std::unique_ptr<OlmOutboundGroupSession, OlmDeleter> init_outbound_group_session(); - std::unique_ptr<OlmSession, OlmDeleter> create_outbound_session( - const std::string &identity_key, - const std::string &one_time_key); - std::unique_ptr<OlmSession, OlmDeleter> create_inbound_session( - const BinaryBuf &one_time_key_message); + OutboundGroupSessionPtr init_outbound_group_session(); + InboundGroupSessionPtr init_inbound_group_session(const std::string &session_key); + OlmSessionPtr create_outbound_session(const std::string &identity_key, + const std::string &one_time_key); + OlmSessionPtr create_inbound_session(const BinaryBuf &one_time_key_message); + OlmSessionPtr create_inbound_session(const std::string &one_time_key_message); OlmAccount *account() { return account_.get(); } OlmUtility *utility() { return utility_.get(); } @@ -212,12 +230,12 @@ std::string session_key(OlmOutboundGroupSession *s); bool -matches_inbound_session(OlmSession *session, const BinaryBuf &one_time_key_message); +matches_inbound_session(OlmSession *session, const std::string &one_time_key_message); bool matches_inbound_session_from(OlmSession *session, const std::string &id_key, - const BinaryBuf &one_time_key_message); + const std::string &one_time_key_message); //! Verify a signature object as obtained from the response of /keys/query endpoint bool diff --git a/lib/crypto/client.cpp b/lib/crypto/client.cpp index c28e58b32..7637127b6 100644 --- a/lib/crypto/client.cpp +++ b/lib/crypto/client.cpp @@ -174,7 +174,7 @@ OlmClient::create_upload_keys_request(const mtx::crypto::OneTimeKeys &one_time_k return req; } -std::unique_ptr<OlmOutboundGroupSession, OlmDeleter> +OutboundGroupSessionPtr OlmClient::init_outbound_group_session() { auto session = create_olm_object<OlmOutboundGroupSession>(); @@ -189,6 +189,54 @@ OlmClient::init_outbound_group_session() return session; } +InboundGroupSessionPtr +OlmClient::init_inbound_group_session(const std::string &session_key) +{ + auto session = create_olm_object<OlmInboundGroupSession>(); + + const int ret = olm_init_inbound_group_session( + session.get(), reinterpret_cast<const uint8_t *>(session_key.data()), session_key.size()); + + if (ret == -1) + throw olm_exception("init_inbound_group_session", session.get()); + + return session; +} + +GroupPlaintext +OlmClient::decrypt_group_message(OlmInboundGroupSession *session, + const std::string &message, + uint32_t message_index) +{ + // TODO handle errors + auto tmp_msg = create_buffer(message.size()); + std::copy(message.begin(), message.end(), tmp_msg.begin()); + + auto plaintext_len = + olm_group_decrypt_max_plaintext_length(session, tmp_msg.data(), tmp_msg.size()); + auto plaintext = create_buffer(plaintext_len); + + tmp_msg = create_buffer(message.size()); + std::copy(message.begin(), message.end(), tmp_msg.begin()); + + const int nbytes = olm_group_decrypt(session, + tmp_msg.data(), + tmp_msg.size(), + plaintext.data(), + plaintext.size(), + &message_index); + + logger->info("new message_index: {}", message_index); + + if (nbytes == -1) + throw olm_exception("olm_group_decrypt", session); + + auto output = create_buffer(nbytes); + std::memcpy(output.data(), plaintext.data(), nbytes); + + return GroupPlaintext{std::move(output), message_index}; +} + BinaryBuf OlmClient::decrypt_message(OlmSession *session, size_t msgtype, @@ -235,7 +283,16 @@ OlmClient::encrypt_message(OlmSession *session, const std::string &msg) return ciphertext; } -std::unique_ptr<OlmSession, OlmDeleter> +OlmSessionPtr +OlmClient::create_inbound_session(const std::string &one_time_key_message) +{ + BinaryBuf tmp(one_time_key_message.size()); + memcpy(tmp.data(), one_time_key_message.data(), one_time_key_message.size()); + + return create_inbound_session(std::move(tmp)); +} + +OlmSessionPtr OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message) { auto session = create_olm_object<OlmSession>(); @@ -252,7 +309,7 @@ OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message) return session; } -std::unique_ptr<OlmSession, OlmDeleter> +OlmSessionPtr OlmClient::create_outbound_session(const std::string &identity_key, const std::string &one_time_key) { auto session = create_olm_object<OlmSession>(); @@ -322,7 +379,7 @@ mtx::crypto::session_key(OlmOutboundGroupSession *s) } bool -mtx::crypto::matches_inbound_session(OlmSession *session, const BinaryBuf &one_time_key_message) +mtx::crypto::matches_inbound_session(OlmSession *session, const std::string &one_time_key_message) { auto tmp = create_buffer(one_time_key_message.size()); std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin()); @@ -333,7 +390,7 @@ mtx::crypto::matches_inbound_session(OlmSession *session, const BinaryBuf &one_t bool mtx::crypto::matches_inbound_session_from(OlmSession *session, const std::string &id_key, - const BinaryBuf &one_time_key_message) + const std::string &one_time_key_message) { auto tmp = create_buffer(one_time_key_message.size()); std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin()); diff --git a/tests/e2ee.cpp b/tests/e2ee.cpp index 8fd0b2514..24b608468 100644 --- a/tests/e2ee.cpp +++ b/tests/e2ee.cpp @@ -532,29 +532,30 @@ TEST(Encryption, OlmSessions) auto alice_outbound_session = alice->create_outbound_session(bob_key, bob_one_time_key); // Alice encrypts the message using the current session. - auto plaintext = "Hello, Bob!"; - size_t msgtype = olm_encrypt_message_type(alice_outbound_session.get()); - auto ciphertext = alice->encrypt_message(alice_outbound_session.get(), plaintext); + auto plaintext = "Hello, Bob!"; + size_t msgtype = olm_encrypt_message_type(alice_outbound_session.get()); + auto ciphertext = alice->encrypt_message(alice_outbound_session.get(), plaintext); + auto ciphertext_str = std::string((char *)ciphertext.data(), ciphertext.size()); EXPECT_EQ(msgtype, 0); // Bob creates an inbound session to receive Alice's message. - auto bob_inbound_session = bob->create_inbound_session(ciphertext); + auto bob_inbound_session = bob->create_inbound_session(ciphertext_str); // Bob validates that the message was meant for him. - ASSERT_EQ(1, matches_inbound_session(bob_inbound_session.get(), ciphertext)); + ASSERT_EQ(1, matches_inbound_session(bob_inbound_session.get(), ciphertext_str)); // Bob validates that the message was sent from Alice. - ASSERT_EQ(1, - matches_inbound_session_from(bob_inbound_session.get(), alice_key, ciphertext)); + ASSERT_EQ( + 1, matches_inbound_session_from(bob_inbound_session.get(), alice_key, ciphertext_str)); // Bob validates that the message wasn't sent by someone else. - ASSERT_EQ(0, matches_inbound_session_from(bob_inbound_session.get(), bob_key, ciphertext)); + ASSERT_EQ(0, + matches_inbound_session_from(bob_inbound_session.get(), bob_key, ciphertext_str)); // Bob decrypts the message - auto ciphertext_str = std::string((char *)ciphertext.data(), ciphertext.size()); auto decrypted = bob->decrypt_message(bob_inbound_session.get(), msgtype, ciphertext_str); auto body_str = std::string((char *)decrypted.data(), decrypted.size()); - ASSERT_EQ(body_str, "Hello, Bob!"); + ASSERT_EQ(body_str, plaintext); } -- GitLab