From 7e522df71c804cbc04f38007c3958ebbb56815d3 Mon Sep 17 00:00:00 2001
From: Konstantinos Sideris <sideris.konstantin@gmail.com>
Date: Sun, 20 May 2018 15:55:38 +0300
Subject: [PATCH] Implement decryption for group events on the crypto_bot

---
 examples/crypto_bot.cpp             | 196 ++++++++++++++++++++++++++--
 include/mtxclient/crypto/client.hpp |  34 +++--
 lib/crypto/client.cpp               |  67 +++++++++-
 tests/e2ee.cpp                      |  21 +--
 4 files changed, 286 insertions(+), 32 deletions(-)

diff --git a/examples/crypto_bot.cpp b/examples/crypto_bot.cpp
index 8171af80b..70915fd0d 100644
--- a/examples/crypto_bot.cpp
+++ b/examples/crypto_bot.cpp
@@ -5,6 +5,7 @@
 #include <atomic>
 #include <iostream>
 #include <json.hpp>
+#include <stdexcept>
 #include <unistd.h>
 #include <variant.hpp>
 
@@ -31,6 +32,46 @@ using namespace mtx::identifiers;
 
 using TimelineEvent = mtx::events::collections::TimelineEvents;
 
+constexpr auto OLM_ALGO = "m.olm.v1.curve25519-aes-sha2";
+
+struct OlmCipherContent
+{
+        std::string body;
+        uint8_t type;
+};
+
+inline void
+from_json(const nlohmann::json &obj, OlmCipherContent &msg)
+{
+        msg.body = obj.at("body");
+        msg.type = obj.at("type");
+}
+
+struct OlmMessage
+{
+        std::string sender_key;
+        std::string sender;
+
+        using RecipientKey = std::string;
+        std::map<RecipientKey, OlmCipherContent> ciphertext;
+};
+
+inline void
+from_json(const nlohmann::json &obj, OlmMessage &msg)
+{
+        if (obj.at("type") != "m.room.encrypted") {
+                throw std::invalid_argument("invalid type for olm message");
+        }
+
+        if (obj.at("content").at("algorithm") != OLM_ALGO)
+                throw std::invalid_argument("invalid algorithm for olm message");
+
+        msg.sender     = obj.at("sender");
+        msg.sender_key = obj.at("content").at("sender_key");
+        msg.ciphertext =
+          obj.at("content").at("ciphertext").get<std::map<std::string, OlmCipherContent>>();
+}
+
 template<class Container, class Item>
 bool
 exists(const Container &container, const Item &item)
@@ -69,6 +110,36 @@ struct Storage
         std::map<std::string, std::string> device_keys_;
         //! Flag that indicate if a specific room has encryption enabled.
         std::map<std::string, bool> encrypted_rooms_;
+
+        //! Mapping from curve25519 to session.
+        std::map<std::string, OlmSessionPtr> olm_sessions;
+
+        std::map<std::string, InboundGroupSessionPtr> inbound_group_sessions;
+
+        bool inbound_group_exists(const std::string &room_id,
+                                  const std::string &session_id,
+                                  const std::string &sender_key)
+        {
+                const auto key = room_id + session_id + sender_key;
+                return inbound_group_sessions.find(key) != inbound_group_sessions.end();
+        }
+
+        void set_inbound_group_session(const std::string &room_id,
+                                       const std::string &session_id,
+                                       const std::string &sender_key,
+                                       InboundGroupSessionPtr session)
+        {
+                const auto key              = room_id + session_id + sender_key;
+                inbound_group_sessions[key] = std::move(session);
+        }
+
+        OlmInboundGroupSession *get_inbound_group_session(const std::string &room_id,
+                                                          const std::string &session_id,
+                                                          const std::string &sender_key)
+        {
+                const auto key = room_id + session_id + sender_key;
+                return inbound_group_sessions[key].get();
+        }
 };
 
 namespace {
@@ -182,6 +253,73 @@ mark_encrypted_room(const RoomId &id)
         storage.encrypted_rooms_[id.get()] = true;
 }
 
+void
+decrypt_olm_message(const OlmMessage &olm_msg)
+{
+        console->info("OLM message");
+        console->info("sender: {}", olm_msg.sender);
+        console->info("sender_key: {}", olm_msg.sender_key);
+
+        const auto my_id_key = olm_client->identity_keys().curve25519;
+        for (const auto &cipher : olm_msg.ciphertext) {
+                if (cipher.first == my_id_key) {
+                        const auto msg_body = cipher.second.body;
+                        const auto msg_type = cipher.second.type;
+
+                        console->info("the message is meant for us");
+                        console->info("body: {}", msg_body);
+                        console->info("type: {}", msg_type);
+
+                        if (msg_type == 0) {
+                                console->info("opening session with {}", olm_msg.sender);
+                                auto inbound_session = olm_client->create_inbound_session(msg_body);
+
+                                auto ok = matches_inbound_session_from(
+                                  inbound_session.get(), olm_msg.sender_key, msg_body);
+
+                                if (!ok) {
+                                        console->error("session could not be established");
+
+                                } else {
+                                        auto output = olm_client->decrypt_message(
+                                          inbound_session.get(), msg_type, msg_body);
+
+                                        auto plaintext = json::parse(
+                                          std::string((char *)output.data(), output.size()));
+                                        console->info("decrypted message: \n {}",
+                                                      plaintext.dump(2));
+
+                                        storage.olm_sessions.emplace(olm_msg.sender_key,
+                                                                     std::move(inbound_session));
+
+                                        std::string room_id = plaintext.at("content").at("room_id");
+                                        std::string session_id =
+                                          plaintext.at("content").at("session_id");
+                                        std::string session_key =
+                                          plaintext.at("content").at("session_key");
+
+                                        if (storage.inbound_group_exists(
+                                              room_id, session_id, olm_msg.sender_key)) {
+                                                console->warn("megolm session already exists");
+                                        } else {
+                                                auto megolm_session =
+                                                  olm_client->init_inbound_group_session(
+                                                    session_key);
+
+                                                storage.set_inbound_group_session(
+                                                  room_id,
+                                                  session_id,
+                                                  olm_msg.sender_key,
+                                                  std::move(megolm_session));
+
+                                                console->info("megolm_session saved");
+                                        }
+                                }
+                        }
+                }
+        }
+}
+
 void
 parse_messages(const mtx::responses::Sync &res)
 {
@@ -227,7 +365,28 @@ parse_messages(const mtx::responses::Sync &res)
                                 console->debug("{}", get_json(e));
                         } else if (is_encrypted(e)) {
                                 console->info("received an encrypted event: {}", room_id);
-                                console->debug("{}", get_json(e));
+                                console->info("{}", get_json(e));
+
+                                auto msg = mpark::get<EncryptedEvent<msg::Encrypted>>(e);
+
+                                if (storage.inbound_group_exists(
+                                      room_id, msg.content.session_id, msg.content.sender_key)) {
+                                        auto res = olm_client->decrypt_group_message(
+                                          storage.get_inbound_group_session(room_id,
+                                                                            msg.content.session_id,
+                                                                            msg.content.sender_key),
+                                          msg.content.ciphertext);
+
+                                        auto msg_str =
+                                          std::string((char *)res.data.data(), res.data.size());
+
+                                        console->info("decrypted data: {}", msg_str);
+                                        console->info("decrypted message_index: {}",
+                                                      res.message_index);
+                                } else {
+                                        console->warn(
+                                          "no megolm session found to decrypt the event");
+                                }
                         }
                 }
         }
@@ -368,9 +527,23 @@ get_device_keys(const UserId &user)
 }
 
 void
-handle_to_device_msgs(const std::vector<nlohmann::json> &to_device)
+handle_to_device_msgs(const std::vector<nlohmann::json> &msgs)
 {
-        (void)to_device;
+        if (!msgs.empty())
+                console->info("inspecting {} to_device messages", msgs.size());
+
+        for (const auto &msg : msgs) {
+                console->info(msg.dump(2));
+
+                try {
+                        OlmMessage olm_msg = msg;
+                        decrypt_olm_message(std::move(olm_msg));
+                } catch (const nlohmann::json::exception &e) {
+                        console->warn("parsing error for olm message: {}", e.what());
+                } catch (const std::invalid_argument &e) {
+                        console->warn("validation error for olm message: {}", e.what());
+                }
+        }
 }
 
 void
@@ -424,21 +597,26 @@ main()
 {
         spdlog::set_pattern("[%H:%M:%S] [tid %t] [%^%l%$] %v");
 
-        std::string username, server, password;
+        std::string username("mtx_bot");
+        std::string server("matrix.org");
+        std::string password("dzyvrwB09GdyEqiyBnfAEvZI3");
 
-        cout << "username: ";
-        std::getline(std::cin, username);
+        // cout << "username: ";
+        // std::getline(std::cin, username);
 
-        cout << "server: ";
-        std::getline(std::cin, server);
+        // cout << "server: ";
+        // std::getline(std::cin, server);
 
-        password = getpass("password: ");
+        // password = getpass("password: ");
 
         client = std::make_shared<Client>(server);
 
         olm_client = make_shared<OlmClient>();
         olm_client->create_new_account();
 
+        console->info("ed25519: {}", olm_client->identity_keys().ed25519);
+        console->info("curve25519: {}", olm_client->identity_keys().curve25519);
+
         client->login(username, password, login_cb);
         client->close();
 
diff --git a/include/mtxclient/crypto/client.hpp b/include/mtxclient/crypto/client.hpp
index b100f1c55..1b64e450b 100644
--- a/include/mtxclient/crypto/client.hpp
+++ b/include/mtxclient/crypto/client.hpp
@@ -42,6 +42,10 @@ public:
           : msg_(func + ": " + std::string(olm_outbound_group_session_last_error(s)))
         {}
 
+        olm_exception(std::string func, OlmInboundGroupSession *s)
+          : msg_(func + ": " + std::string(olm_inbound_group_session_last_error(s)))
+        {}
+
         olm_exception(std::string msg)
           : msg_(msg)
         {}
@@ -127,6 +131,16 @@ create_olm_object()
         return std::unique_ptr<T, OlmDeleter>(OlmAllocator<T>::allocate());
 }
 
+using OlmSessionPtr           = std::unique_ptr<OlmSession, OlmDeleter>;
+using OutboundGroupSessionPtr = std::unique_ptr<OlmOutboundGroupSession, OlmDeleter>;
+using InboundGroupSessionPtr  = std::unique_ptr<OlmInboundGroupSession, OlmDeleter>;
+
+struct GroupPlaintext
+{
+        BinaryBuf data;
+        uint32_t message_index;
+};
+
 class OlmClient : public std::enable_shared_from_this<OlmClient>
 {
 public:
@@ -170,6 +184,10 @@ public:
         mtx::requests::UploadKeys create_upload_keys_request(const OneTimeKeys &keys);
         mtx::requests::UploadKeys create_upload_keys_request();
 
+        //! Decrypt a message using megolm.
+        GroupPlaintext decrypt_group_message(OlmInboundGroupSession *session,
+                                             const std::string &message,
+                                             uint32_t message_index = 0);
         //! Encrypt a message using olm.
         BinaryBuf encrypt_message(OlmSession *session, const std::string &msg);
         //! Decrypt a message using olm.
@@ -178,12 +196,12 @@ public:
                                   const std::string &msg);
 
         //! Create an outbount megolm session.
-        std::unique_ptr<OlmOutboundGroupSession, OlmDeleter> init_outbound_group_session();
-        std::unique_ptr<OlmSession, OlmDeleter> create_outbound_session(
-          const std::string &identity_key,
-          const std::string &one_time_key);
-        std::unique_ptr<OlmSession, OlmDeleter> create_inbound_session(
-          const BinaryBuf &one_time_key_message);
+        OutboundGroupSessionPtr init_outbound_group_session();
+        InboundGroupSessionPtr init_inbound_group_session(const std::string &session_key);
+        OlmSessionPtr create_outbound_session(const std::string &identity_key,
+                                              const std::string &one_time_key);
+        OlmSessionPtr create_inbound_session(const BinaryBuf &one_time_key_message);
+        OlmSessionPtr create_inbound_session(const std::string &one_time_key_message);
 
         OlmAccount *account() { return account_.get(); }
         OlmUtility *utility() { return utility_.get(); }
@@ -212,12 +230,12 @@ std::string
 session_key(OlmOutboundGroupSession *s);
 
 bool
-matches_inbound_session(OlmSession *session, const BinaryBuf &one_time_key_message);
+matches_inbound_session(OlmSession *session, const std::string &one_time_key_message);
 
 bool
 matches_inbound_session_from(OlmSession *session,
                              const std::string &id_key,
-                             const BinaryBuf &one_time_key_message);
+                             const std::string &one_time_key_message);
 
 //! Verify a signature object as obtained from the response of /keys/query endpoint
 bool
diff --git a/lib/crypto/client.cpp b/lib/crypto/client.cpp
index c28e58b32..7637127b6 100644
--- a/lib/crypto/client.cpp
+++ b/lib/crypto/client.cpp
@@ -174,7 +174,7 @@ OlmClient::create_upload_keys_request(const mtx::crypto::OneTimeKeys &one_time_k
         return req;
 }
 
-std::unique_ptr<OlmOutboundGroupSession, OlmDeleter>
+OutboundGroupSessionPtr
 OlmClient::init_outbound_group_session()
 {
         auto session = create_olm_object<OlmOutboundGroupSession>();
@@ -189,6 +189,54 @@ OlmClient::init_outbound_group_session()
         return session;
 }
 
+InboundGroupSessionPtr
+OlmClient::init_inbound_group_session(const std::string &session_key)
+{
+        auto session = create_olm_object<OlmInboundGroupSession>();
+
+        const int ret = olm_init_inbound_group_session(
+          session.get(), reinterpret_cast<const uint8_t *>(session_key.data()), session_key.size());
+
+        if (ret == -1)
+                throw olm_exception("init_inbound_group_session", session.get());
+
+        return session;
+}
+
+GroupPlaintext
+OlmClient::decrypt_group_message(OlmInboundGroupSession *session,
+                                 const std::string &message,
+                                 uint32_t message_index)
+{
+        // TODO handle errors
+        auto tmp_msg = create_buffer(message.size());
+        std::copy(message.begin(), message.end(), tmp_msg.begin());
+
+        auto plaintext_len =
+          olm_group_decrypt_max_plaintext_length(session, tmp_msg.data(), tmp_msg.size());
+        auto plaintext = create_buffer(plaintext_len);
+
+        tmp_msg = create_buffer(message.size());
+        std::copy(message.begin(), message.end(), tmp_msg.begin());
+
+        const int nbytes = olm_group_decrypt(session,
+                                             tmp_msg.data(),
+                                             tmp_msg.size(),
+                                             plaintext.data(),
+                                             plaintext.size(),
+                                             &message_index);
+
+        logger->info("new message_index: {}", message_index);
+
+        if (nbytes == -1)
+                throw olm_exception("olm_group_decrypt", session);
+
+        auto output = create_buffer(nbytes);
+        std::memcpy(output.data(), plaintext.data(), nbytes);
+
+        return GroupPlaintext{std::move(output), message_index};
+}
+
 BinaryBuf
 OlmClient::decrypt_message(OlmSession *session,
                            size_t msgtype,
@@ -235,7 +283,16 @@ OlmClient::encrypt_message(OlmSession *session, const std::string &msg)
         return ciphertext;
 }
 
-std::unique_ptr<OlmSession, OlmDeleter>
+OlmSessionPtr
+OlmClient::create_inbound_session(const std::string &one_time_key_message)
+{
+        BinaryBuf tmp(one_time_key_message.size());
+        memcpy(tmp.data(), one_time_key_message.data(), one_time_key_message.size());
+
+        return create_inbound_session(std::move(tmp));
+}
+
+OlmSessionPtr
 OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message)
 {
         auto session = create_olm_object<OlmSession>();
@@ -252,7 +309,7 @@ OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message)
         return session;
 }
 
-std::unique_ptr<OlmSession, OlmDeleter>
+OlmSessionPtr
 OlmClient::create_outbound_session(const std::string &identity_key, const std::string &one_time_key)
 {
         auto session    = create_olm_object<OlmSession>();
@@ -322,7 +379,7 @@ mtx::crypto::session_key(OlmOutboundGroupSession *s)
 }
 
 bool
-mtx::crypto::matches_inbound_session(OlmSession *session, const BinaryBuf &one_time_key_message)
+mtx::crypto::matches_inbound_session(OlmSession *session, const std::string &one_time_key_message)
 {
         auto tmp = create_buffer(one_time_key_message.size());
         std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin());
@@ -333,7 +390,7 @@ mtx::crypto::matches_inbound_session(OlmSession *session, const BinaryBuf &one_t
 bool
 mtx::crypto::matches_inbound_session_from(OlmSession *session,
                                           const std::string &id_key,
-                                          const BinaryBuf &one_time_key_message)
+                                          const std::string &one_time_key_message)
 {
         auto tmp = create_buffer(one_time_key_message.size());
         std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin());
diff --git a/tests/e2ee.cpp b/tests/e2ee.cpp
index 8fd0b2514..24b608468 100644
--- a/tests/e2ee.cpp
+++ b/tests/e2ee.cpp
@@ -532,29 +532,30 @@ TEST(Encryption, OlmSessions)
         auto alice_outbound_session = alice->create_outbound_session(bob_key, bob_one_time_key);
 
         // Alice encrypts the message using the current session.
-        auto plaintext  = "Hello, Bob!";
-        size_t msgtype  = olm_encrypt_message_type(alice_outbound_session.get());
-        auto ciphertext = alice->encrypt_message(alice_outbound_session.get(), plaintext);
+        auto plaintext      = "Hello, Bob!";
+        size_t msgtype      = olm_encrypt_message_type(alice_outbound_session.get());
+        auto ciphertext     = alice->encrypt_message(alice_outbound_session.get(), plaintext);
+        auto ciphertext_str = std::string((char *)ciphertext.data(), ciphertext.size());
 
         EXPECT_EQ(msgtype, 0);
 
         // Bob creates an inbound session to receive Alice's message.
-        auto bob_inbound_session = bob->create_inbound_session(ciphertext);
+        auto bob_inbound_session = bob->create_inbound_session(ciphertext_str);
 
         // Bob validates that the message was meant for him.
-        ASSERT_EQ(1, matches_inbound_session(bob_inbound_session.get(), ciphertext));
+        ASSERT_EQ(1, matches_inbound_session(bob_inbound_session.get(), ciphertext_str));
 
         // Bob validates that the message was sent from Alice.
-        ASSERT_EQ(1,
-                  matches_inbound_session_from(bob_inbound_session.get(), alice_key, ciphertext));
+        ASSERT_EQ(
+          1, matches_inbound_session_from(bob_inbound_session.get(), alice_key, ciphertext_str));
 
         // Bob validates that the message wasn't sent by someone else.
-        ASSERT_EQ(0, matches_inbound_session_from(bob_inbound_session.get(), bob_key, ciphertext));
+        ASSERT_EQ(0,
+                  matches_inbound_session_from(bob_inbound_session.get(), bob_key, ciphertext_str));
 
         // Bob decrypts the message
-        auto ciphertext_str = std::string((char *)ciphertext.data(), ciphertext.size());
         auto decrypted = bob->decrypt_message(bob_inbound_session.get(), msgtype, ciphertext_str);
 
         auto body_str = std::string((char *)decrypted.data(), decrypted.size());
-        ASSERT_EQ(body_str, "Hello, Bob!");
+        ASSERT_EQ(body_str, plaintext);
 }
-- 
GitLab