From f34bd3554077f066dbf06a5845b90c815fc6749c Mon Sep 17 00:00:00 2001
From: Konstantinos Sideris <sideris.konstantin@gmail.com>
Date: Thu, 24 May 2018 19:53:28 +0300
Subject: [PATCH] Add test case for sending encrypted olm & m.room_key events
 through sync

---
 examples/crypto_bot.cpp             |  53 ++------
 include/mtxclient/crypto/client.hpp |  16 ++-
 lib/crypto/client.cpp               |  33 ++++-
 tests/e2ee.cpp                      | 200 ++++++++++++++++++++++++++++
 tests/test_helpers.hpp              |   5 +
 5 files changed, 261 insertions(+), 46 deletions(-)

diff --git a/examples/crypto_bot.cpp b/examples/crypto_bot.cpp
index 8ba2573e6..e029846a5 100644
--- a/examples/crypto_bot.cpp
+++ b/examples/crypto_bot.cpp
@@ -223,41 +223,6 @@ is_room_encryption(const T &event)
         return mpark::holds_alternative<StateEvent<Encryption>>(event);
 }
 
-std::string
-create_room_key_event(const json &megolm_payload,
-                      const std::string &recipient,
-                      const std::string &recipient_key)
-{
-        auto room_key = json{{"content", megolm_payload},
-                             {"keys", {{"ed25519", olm_client->identity_keys().ed25519}}},
-                             {"recipient", recipient},
-                             {"recipient_keys", {{"ed25519", recipient_key}}},
-                             {"sender", client->user_id().to_string()},
-                             {"sender_device", client->device_id()},
-                             {"type", "m.room_key"}};
-
-        return room_key.dump();
-}
-
-json
-encrypt_to_device_message(OlmSession *session,
-                          const std::string &room_key_event,
-                          const std::string &recipient_key)
-{
-        auto encrypted = olm_client->encrypt_message(session, room_key_event);
-
-        auto final_payload = json{
-          {"algorithm", "m.olm.v1.curve25519-aes-sha2"},
-          {"sender_key", olm_client->identity_keys().curve25519},
-          {"ciphertext",
-           {{recipient_key,
-             {{"body", std::string((char *)encrypted.data(), encrypted.size())}, {"type", 0}}}}}};
-
-        console->info("about to send: \n {}", final_payload.dump(2));
-
-        return final_payload;
-}
-
 void
 send_group_message(OlmOutboundGroupSession *session,
                    const std::string &session_id,
@@ -318,8 +283,11 @@ create_outbound_megolm_session(const std::string &room_id, const std::string &re
                 for (const auto &dev : devices) {
                         // TODO: check if we have downloaded the keys
                         const auto device_keys = storage.device_keys[dev];
-                        auto room_key =
-                          create_room_key_event(megolm_payload, member.first, device_keys.ed25519);
+                        auto room_key          = olm_client
+                                          ->create_room_key_event(UserId(member.first),
+                                                                  device_keys.ed25519,
+                                                                  megolm_payload)
+                                          .dump();
 
                         auto to_device_cb = [](RequestErr err) {
                                 if (err) {
@@ -334,7 +302,7 @@ create_outbound_megolm_session(const std::string &room_id, const std::string &re
                                 auto olm_session =
                                   storage.olm_outbound_sessions[device_keys.curve25519].get();
 
-                                auto device_msg = encrypt_to_device_message(
+                                auto device_msg = olm_client->create_olm_encrypted_content(
                                   olm_session, room_key, device_keys.curve25519);
 
                                 json body{{"messages", {{member, {{dev, device_msg}}}}}};
@@ -366,10 +334,11 @@ create_outbound_megolm_session(const std::string &room_id, const std::string &re
                                                 auto session =
                                                   olm_client->create_outbound_session(id_key, otk);
 
-                                                auto device_msg = encrypt_to_device_message(
-                                                  session.get(),
-                                                  room_key,
-                                                  storage.device_keys[dev].curve25519);
+                                                auto device_msg =
+                                                  olm_client->create_olm_encrypted_content(
+                                                    session.get(),
+                                                    room_key,
+                                                    storage.device_keys[dev].curve25519);
 
                                                 // TODO: saving should happen when the message is
                                                 // sent.
diff --git a/include/mtxclient/crypto/client.hpp b/include/mtxclient/crypto/client.hpp
index ad97eeb1f..b55460ae1 100644
--- a/include/mtxclient/crypto/client.hpp
+++ b/include/mtxclient/crypto/client.hpp
@@ -157,14 +157,14 @@ public:
         void set_user_id(std::string user_id) { user_id_ = std::move(user_id); }
 
         //! Sign the given message.
-        Base64String sign_message(const std::string &msg);
+        Base64String sign_message(const std::string &msg) const;
 
         //! Create a new olm Account. Must be called before any other operation.
         void create_new_account();
         void create_new_utility();
 
         //! Retrieve the json representation of the identity keys for the given account.
-        IdentityKeys identity_keys();
+        IdentityKeys identity_keys() const;
         //! Sign the identity keys.
         //! The result should be used as part of the /keys/upload/ request.
         Base64String sign_identity_keys();
@@ -209,6 +209,18 @@ public:
         OlmSessionPtr create_inbound_session(const BinaryBuf &one_time_key_message);
         OlmSessionPtr create_inbound_session(const std::string &one_time_key_message);
 
+        //! The `m.room_key` event is used to share the session_id & session_key
+        //! of an outbound megolm session.
+        nlohmann::json create_room_key_event(const UserId &user_id,
+                                             const std::string &ed25519_device_key,
+                                             const nlohmann::json &content) const noexcept;
+
+        //! Create the content for an m.room.encrypted event.
+        //! algorithm: m.olm.v1.curve25519-aes-sha2
+        nlohmann::json create_olm_encrypted_content(OlmSession *session,
+                                                    const std::string &room_key_event,
+                                                    const std::string &recipient_key);
+
         OlmAccount *account() { return account_.get(); }
         OlmUtility *utility() { return utility_.get(); }
 
diff --git a/lib/crypto/client.cpp b/lib/crypto/client.cpp
index d63e542dc..b3a6f0fc2 100644
--- a/lib/crypto/client.cpp
+++ b/lib/crypto/client.cpp
@@ -41,7 +41,7 @@ OlmClient::create_new_utility()
 }
 
 IdentityKeys
-OlmClient::identity_keys()
+OlmClient::identity_keys() const
 {
         auto tmp_buf = create_buffer(olm_account_identity_keys_length(account_.get()));
         int result =
@@ -54,7 +54,7 @@ OlmClient::identity_keys()
 }
 
 std::string
-OlmClient::sign_message(const std::string &msg)
+OlmClient::sign_message(const std::string &msg) const
 {
         auto signature_buf = create_buffer(olm_account_signature_length(account_.get()));
         olm_account_sign(
@@ -352,6 +352,35 @@ OlmClient::create_outbound_session(const std::string &identity_key, const std::s
         return session;
 }
 
+nlohmann::json
+OlmClient::create_room_key_event(const UserId &recipient,
+                                 const std::string &ed25519_recipient_key,
+                                 const nlohmann::json &content) const noexcept
+{
+        return json{{"content", content},
+                    {"keys", {{"ed25519", identity_keys().ed25519}}},
+                    {"recipient", recipient.get()},
+                    {"recipient_keys", {{"ed25519", ed25519_recipient_key}}},
+                    {"sender", user_id_},
+                    {"sender_device", device_id_},
+                    {"type", "m.room_key"}};
+}
+
+nlohmann::json
+OlmClient::create_olm_encrypted_content(OlmSession *session,
+                                        const std::string &room_key_event,
+                                        const std::string &recipient_key)
+{
+        size_t msg_type    = olm_encrypt_message_type(session);
+        auto encrypted     = encrypt_message(session, room_key_event);
+        auto encrypted_str = std::string((char *)encrypted.data(), encrypted.size());
+
+        return json{
+          {"algorithm", "m.olm.v1.curve25519-aes-sha2"},
+          {"sender_key", identity_keys().curve25519},
+          {"ciphertext", {{recipient_key, {{"body", encrypted_str}, {"type", msg_type}}}}}};
+}
+
 BinaryBuf
 mtx::crypto::decode_base64(const std::string &msg)
 {
diff --git a/tests/e2ee.cpp b/tests/e2ee.cpp
index 1ff78d88c..d9a153588 100644
--- a/tests/e2ee.cpp
+++ b/tests/e2ee.cpp
@@ -25,6 +25,46 @@ using namespace mtx::responses;
 
 using namespace std;
 
+struct OlmCipherContent
+{
+        std::string body;
+        uint8_t type;
+};
+
+inline void
+from_json(const nlohmann::json &obj, OlmCipherContent &msg)
+{
+        msg.body = obj.at("body");
+        msg.type = obj.at("type");
+}
+
+struct OlmMessage
+{
+        std::string sender_key;
+        std::string sender;
+
+        using RecipientKey = std::string;
+        std::map<RecipientKey, OlmCipherContent> ciphertext;
+};
+
+constexpr auto OLM_ALGO = "m.olm.v1.curve25519-aes-sha2";
+
+inline void
+from_json(const nlohmann::json &obj, OlmMessage &msg)
+{
+        if (obj.at("type") != "m.room.encrypted") {
+                throw std::invalid_argument("invalid type for olm message");
+        }
+
+        if (obj.at("content").at("algorithm") != OLM_ALGO)
+                throw std::invalid_argument("invalid algorithm for olm message");
+
+        msg.sender     = obj.at("sender");
+        msg.sender_key = obj.at("content").at("sender_key");
+        msg.ciphertext =
+          obj.at("content").at("ciphertext").get<std::map<std::string, OlmCipherContent>>();
+}
+
 mtx::requests::UploadKeys
 generate_keys(std::shared_ptr<mtx::crypto::OlmClient> account)
 {
@@ -639,3 +679,163 @@ TEST(Encryption, MegolmSessions)
         auto output_str = std::string((char *)bob_plaintext.data.data(), bob_plaintext.data.size());
         ASSERT_EQ(output_str, secret_message);
 }
+
+TEST(Encryption, OlmRoomKeyEncryption)
+{
+        // Alice wants to use olm to send data to Bob.
+        auto alice_olm  = std::make_shared<OlmClient>();
+        auto alice_http = std::make_shared<Client>("localhost");
+        alice_olm->create_new_account();
+        alice_olm->generate_one_time_keys(10);
+
+        auto bob_olm  = std::make_shared<OlmClient>();
+        auto bob_http = std::make_shared<Client>("localhost");
+        bob_olm->create_new_account();
+        bob_olm->generate_one_time_keys(10);
+
+        alice_http->login("alice", "secret", &check_login);
+        bob_http->login("bob", "secret", &check_login);
+
+        WAIT_UNTIL(!bob_http->access_token().empty() && !alice_http->access_token().empty())
+
+        bob_olm->set_user_id(bob_http->user_id().to_string());
+        bob_olm->set_device_id(bob_http->device_id());
+        alice_olm->set_user_id(alice_http->user_id().to_string());
+        alice_olm->set_device_id(alice_http->device_id());
+
+        // Both users upload their identity & one time keys
+        atomic_int uploads(0);
+        auto upload_cb = [&uploads](const mtx::responses::UploadKeys &res, RequestErr err) {
+                check_error(err);
+                EXPECT_EQ(res.one_time_key_counts.size(), 1);
+                EXPECT_EQ(res.one_time_key_counts.at("signed_curve25519"), 10);
+                uploads += 1;
+        };
+
+        alice_http->upload_keys(alice_olm->create_upload_keys_request(), upload_cb);
+        bob_http->upload_keys(bob_olm->create_upload_keys_request(), upload_cb);
+
+        WAIT_UNTIL(uploads == 2)
+
+        atomic_bool request_finished(false);
+        std::string bob_ed25519, bob_curve25519, bob_otk;
+
+        // Alice needs Bob's ed25519 device key.
+        mtx::requests::QueryKeys query;
+        query.device_keys[bob_http->user_id().to_string()] = {};
+        alice_http->query_keys(query,
+                               [&request_finished, &bob_ed25519, &bob_curve25519, bob = bob_http](
+                                 const mtx::responses::QueryKeys &res, RequestErr err) {
+                                       check_error(err);
+
+                                       const auto device_id = bob->device_id();
+                                       const auto user_id   = bob->user_id().to_string();
+                                       const auto devices   = res.device_keys.at(user_id);
+
+                                       assert(devices.find(device_id) != devices.end());
+
+                                       bob_ed25519 =
+                                         devices.at(device_id).keys.at("ed25519:" + device_id);
+                                       bob_curve25519 =
+                                         devices.at(device_id).keys.at("curve25519:" + device_id);
+
+                                       request_finished = true;
+                               });
+
+        WAIT_UNTIL(request_finished);
+
+        // Alice needs one of Bob's one time keys.
+        request_finished = false;
+        alice_http->claim_keys(bob_http->user_id(),
+                               {bob_http->device_id()},
+                               [&bob_otk, bob = bob_http, &request_finished](
+                                 const mtx::responses::ClaimKeys &res, RequestErr err) {
+                                       check_error(err);
+
+                                       const auto user_id   = bob->user_id().to_string();
+                                       const auto device_id = bob->device_id();
+
+                                       auto retrieved_devices = res.one_time_keys.at(user_id);
+                                       for (const auto &device : retrieved_devices) {
+                                               if (device.first == device_id) {
+                                                       bob_otk = device.second.begin()->at("key");
+                                                       break;
+                                               }
+                                       }
+
+                                       request_finished = true;
+                               });
+
+        WAIT_UNTIL(request_finished);
+
+        EXPECT_EQ(bob_ed25519, bob_olm->identity_keys().ed25519);
+        EXPECT_EQ(bob_curve25519, bob_olm->identity_keys().curve25519);
+        EXPECT_EQ(bob_otk, bob_olm->one_time_keys().curve25519.begin()->second);
+
+        constexpr auto SECRET_TEXT = "Hello Bob!";
+
+        // Alice create m.room.key request
+        json payload  = json{{"secret", SECRET_TEXT}};
+        auto room_key = alice_olm->create_room_key_event(
+          UserId("@bob:localhost"), bob_olm->identity_keys().ed25519, payload);
+
+        // Alice creates an outbound session.
+        auto out_session = alice_olm->create_outbound_session(bob_curve25519, bob_otk);
+        auto device_msg  = alice_olm->create_olm_encrypted_content(
+          out_session.get(), room_key.dump(), bob_curve25519);
+
+        // Finally sends the olm encrypted message to Bob's device.
+        atomic_bool is_sent(false);
+        json body{
+          {"messages", {{bob_http->user_id().to_string(), {{bob_http->device_id(), device_msg}}}}}};
+        alice_http->send_to_device("m.room.encrypted", body, [&is_sent](RequestErr err) {
+                check_error(err);
+                is_sent = true;
+        });
+
+        WAIT_UNTIL(is_sent)
+
+        SyncOpts opts;
+        opts.timeout = 0;
+        bob_http->sync(
+          opts, [bob = bob_olm, SECRET_TEXT](const mtx::responses::Sync &res, RequestErr err) {
+                  check_error(err);
+
+                  assert(!res.to_device.empty());
+                  assert(res.to_device.size() == 1);
+
+                  OlmMessage olm_msg = res.to_device[0];
+                  auto cipher        = olm_msg.ciphertext.begin();
+
+                  EXPECT_EQ(cipher->first, bob->identity_keys().curve25519);
+
+                  const auto msg_body = cipher->second.body;
+                  const auto msg_type = cipher->second.type;
+
+                  assert(msg_type == 0);
+
+                  auto inbound_session = bob->create_inbound_session(msg_body);
+                  auto ok              = matches_inbound_session_from(
+                    inbound_session.get(), olm_msg.sender_key, msg_body);
+                  assert(ok == true);
+
+                  auto output = bob->decrypt_message(inbound_session.get(), msg_type, msg_body);
+
+                  // Parsing the original plaintext json object.
+                  auto plaintext = json::parse(std::string((char *)output.data(), output.size()));
+                  std::string secret = plaintext.at("content").at("secret");
+
+                  ASSERT_EQ(secret, SECRET_TEXT);
+          });
+
+        alice_http->close();
+        bob_http->close();
+}
+
+TEST(Encryption, DISABLED_HandleRoomKeyEvent) {}
+TEST(Encryption, DISABLED_HandleRoomKeyRequestEvent) {}
+TEST(Encryption, DISABLED_HandleNewDevices) {}
+TEST(Encryption, DISABLED_HandleLeftDevices) {}
+
+TEST(Encryption, DISABLED_SendEncryptedMessageWithMegolm) {}
+TEST(Encryption, DISABLED_RotateMegolmSession) {}
diff --git a/tests/test_helpers.hpp b/tests/test_helpers.hpp
index 0b2e4d0d9..a7673dab0 100644
--- a/tests/test_helpers.hpp
+++ b/tests/test_helpers.hpp
@@ -14,6 +14,11 @@ sleep()
         std::this_thread::sleep_for(std::chrono::milliseconds(100));
 }
 
+#define WAIT_UNTIL(condition)                                                                      \
+        while (!(condition)) {                                                                     \
+                sleep();                                                                           \
+        };
+
 inline void
 check_error(mtx::http::RequestErr err)
 {
-- 
GitLab