From 4f526ecae972ff503fb777fc739246b09324e9d2 Mon Sep 17 00:00:00 2001
From: Konstantinos Sideris <sideris.konstantin@gmail.com>
Date: Fri, 25 May 2018 19:35:44 +0300
Subject: [PATCH] Add method for pickling/unpickling olm objects

---
 include/mtxclient/crypto/client.hpp  |  74 ++++--------
 include/mtxclient/crypto/objects.hpp | 169 +++++++++++++++++++++++++++
 lib/crypto/client.cpp                |  20 ++--
 tests/e2ee.cpp                       | 129 ++++++++++++++++++++
 tests/utils.cpp                      |   2 +-
 5 files changed, 332 insertions(+), 62 deletions(-)
 create mode 100644 include/mtxclient/crypto/objects.hpp

diff --git a/include/mtxclient/crypto/client.hpp b/include/mtxclient/crypto/client.hpp
index b55460ae1..96ff6953b 100644
--- a/include/mtxclient/crypto/client.hpp
+++ b/include/mtxclient/crypto/client.hpp
@@ -15,6 +15,7 @@
 #include <olm/olm.h>
 #include <olm/session.hh>
 
+#include "mtxclient/crypto/objects.hpp"
 #include "mtxclient/crypto/types.hpp"
 
 namespace mtx {
@@ -66,69 +67,32 @@ create_buffer(std::size_t nbytes)
         return buf;
 }
 
-struct OlmDeleter
-{
-        void operator()(OlmAccount *ptr) { operator delete(ptr, olm_account_size()); }
-        void operator()(OlmUtility *ptr) { operator delete(ptr, olm_utility_size()); }
-
-        void operator()(OlmSession *ptr) { operator delete(ptr, olm_session_size()); }
-        void operator()(OlmOutboundGroupSession *ptr)
-        {
-                operator delete(ptr, olm_outbound_group_session_size());
-        }
-        void operator()(OlmInboundGroupSession *ptr)
-        {
-                operator delete(ptr, olm_inbound_group_session_size());
-        }
-};
-
 template<class T>
-struct OlmAllocator
+std::string
+pickle(typename T::olm_type *object, const std::string &key)
 {
-        static T allocate() = delete;
-};
+        auto tmp      = create_buffer(T::pickle_length(object));
+        const int ret = T::pickle(object, key.data(), key.size(), tmp.data(), tmp.size());
 
-template<>
-struct OlmAllocator<OlmAccount>
-{
-        static OlmAccount *allocate() { return olm_account(new uint8_t[olm_account_size()]); }
-};
+        if (ret == -1)
+                throw olm_exception("pickle", object);
 
-template<>
-struct OlmAllocator<OlmSession>
-{
-        static OlmSession *allocate() { return olm_session(new uint8_t[olm_session_size()]); }
-};
+        return std::string((char *)tmp.data(), tmp.size());
+}
 
-template<>
-struct OlmAllocator<OlmUtility>
+template<class T>
+std::unique_ptr<typename T::olm_type, OlmDeleter>
+unpickle(const std::string &pickled, const std::string &key)
 {
-        static OlmUtility *allocate() { return olm_utility(new uint8_t[olm_utility_size()]); }
-};
+        auto object = create_olm_object<T>();
 
-template<>
-struct OlmAllocator<OlmOutboundGroupSession>
-{
-        static OlmOutboundGroupSession *allocate()
-        {
-                return olm_outbound_group_session(new uint8_t[olm_outbound_group_session_size()]);
-        }
-};
+        const int ret =
+          T::unpickle(object.get(), key.data(), key.size(), (void *)pickled.data(), pickled.size());
 
-template<>
-struct OlmAllocator<OlmInboundGroupSession>
-{
-        static OlmInboundGroupSession *allocate()
-        {
-                return olm_inbound_group_session(new uint8_t[olm_inbound_group_session_size()]);
-        }
-};
+        if (ret == -1)
+                throw olm_exception("unpickle", object.get());
 
-template<class T>
-std::unique_ptr<T, OlmDeleter>
-create_olm_object()
-{
-        return std::unique_ptr<T, OlmDeleter>(OlmAllocator<T>::allocate());
+        return std::move(object);
 }
 
 using OlmSessionPtr           = std::unique_ptr<OlmSession, OlmDeleter>;
@@ -163,6 +127,8 @@ public:
         void create_new_account();
         void create_new_utility();
 
+        void restore_account(const std::string &saved_data, const std::string &key);
+
         //! Retrieve the json representation of the identity keys for the given account.
         IdentityKeys identity_keys() const;
         //! Sign the identity keys.
diff --git a/include/mtxclient/crypto/objects.hpp b/include/mtxclient/crypto/objects.hpp
new file mode 100644
index 000000000..07f96190e
--- /dev/null
+++ b/include/mtxclient/crypto/objects.hpp
@@ -0,0 +1,169 @@
+#pragma once
+
+#include <memory>
+#include <olm/olm.h>
+
+namespace mtx {
+namespace crypto {
+
+struct OlmDeleter
+{
+        void operator()(OlmAccount *ptr) { operator delete(ptr, olm_account_size()); }
+        void operator()(OlmUtility *ptr) { operator delete(ptr, olm_utility_size()); }
+
+        void operator()(OlmSession *ptr) { operator delete(ptr, olm_session_size()); }
+        void operator()(OlmOutboundGroupSession *ptr)
+        {
+                operator delete(ptr, olm_outbound_group_session_size());
+        }
+        void operator()(OlmInboundGroupSession *ptr)
+        {
+                operator delete(ptr, olm_inbound_group_session_size());
+        }
+};
+
+struct UtilityObject
+{
+        using olm_type = OlmUtility;
+
+        static olm_type *allocate() { return olm_utility(new uint8_t[olm_utility_size()]); }
+};
+
+struct AccountObject
+{
+        using olm_type = OlmAccount;
+
+        static olm_type *allocate() { return olm_account(new uint8_t[olm_account_size()]); }
+
+        static size_t pickle_length(olm_type *account)
+        {
+                return olm_pickle_account_length(account);
+        }
+
+        static size_t pickle(olm_type *account,
+                             void const *key,
+                             size_t key_length,
+                             void *pickled,
+                             size_t pickled_length)
+        {
+                return olm_pickle_account(account, key, key_length, pickled, pickled_length);
+        }
+
+        static size_t unpickle(olm_type *account,
+                               void const *key,
+                               size_t key_length,
+                               void *pickled,
+                               size_t pickled_length)
+        {
+                return olm_unpickle_account(account, key, key_length, pickled, pickled_length);
+        }
+};
+
+struct SessionObject
+{
+        using olm_type = OlmSession;
+
+        static olm_type *allocate() { return olm_session(new uint8_t[olm_session_size()]); }
+
+        static size_t pickle_length(olm_type *session)
+        {
+                return olm_pickle_session_length(session);
+        }
+
+        static size_t pickle(olm_type *session,
+                             void const *key,
+                             size_t key_length,
+                             void *pickled,
+                             size_t pickled_length)
+        {
+                return olm_pickle_session(session, key, key_length, pickled, pickled_length);
+        }
+
+        static size_t unpickle(olm_type *session,
+                               void const *key,
+                               size_t key_length,
+                               void *pickled,
+                               size_t pickled_length)
+        {
+                return olm_unpickle_session(session, key, key_length, pickled, pickled_length);
+        }
+};
+
+struct InboundSessionObject
+{
+        using olm_type = OlmInboundGroupSession;
+
+        static olm_type *allocate()
+        {
+                return olm_inbound_group_session(new uint8_t[olm_inbound_group_session_size()]);
+        }
+
+        static size_t pickle_length(olm_type *session)
+        {
+                return olm_pickle_inbound_group_session_length(session);
+        }
+
+        static size_t pickle(olm_type *session,
+                             void const *key,
+                             size_t key_length,
+                             void *pickled,
+                             size_t pickled_length)
+        {
+                return olm_pickle_inbound_group_session(
+                  session, key, key_length, pickled, pickled_length);
+        }
+
+        static size_t unpickle(olm_type *session,
+                               void const *key,
+                               size_t key_length,
+                               void *pickled,
+                               size_t pickled_length)
+        {
+                return olm_unpickle_inbound_group_session(
+                  session, key, key_length, pickled, pickled_length);
+        }
+};
+
+struct OutboundSessionObject
+{
+        using olm_type = OlmOutboundGroupSession;
+
+        static olm_type *allocate()
+        {
+                return olm_outbound_group_session(new uint8_t[olm_outbound_group_session_size()]);
+        }
+
+        static size_t pickle_length(olm_type *session)
+        {
+                return olm_pickle_outbound_group_session_length(session);
+        }
+
+        static size_t pickle(olm_type *session,
+                             void const *key,
+                             size_t key_length,
+                             void *pickled,
+                             size_t pickled_length)
+        {
+                return olm_pickle_outbound_group_session(
+                  session, key, key_length, pickled, pickled_length);
+        }
+
+        static size_t unpickle(olm_type *session,
+                               void const *key,
+                               size_t key_length,
+                               void *pickled,
+                               size_t pickled_length)
+        {
+                return olm_unpickle_outbound_group_session(
+                  session, key, key_length, pickled, pickled_length);
+        }
+};
+
+template<class T>
+std::unique_ptr<typename T::olm_type, OlmDeleter>
+create_olm_object()
+{
+        return std::unique_ptr<typename T::olm_type, OlmDeleter>(T::allocate());
+}
+}
+}
diff --git a/lib/crypto/client.cpp b/lib/crypto/client.cpp
index b3a6f0fc2..eed9890a9 100644
--- a/lib/crypto/client.cpp
+++ b/lib/crypto/client.cpp
@@ -19,7 +19,7 @@ OlmClient::create_new_account()
         if (account_)
                 return;
 
-        account_ = create_olm_object<OlmAccount>();
+        account_ = create_olm_object<AccountObject>();
 
         auto tmp_buf  = create_buffer(olm_create_account_random_length(account_.get()));
         const int ret = olm_create_account(account_.get(), tmp_buf.data(), tmp_buf.size());
@@ -37,7 +37,13 @@ OlmClient::create_new_utility()
         if (utility_)
                 return;
 
-        utility_ = create_olm_object<OlmUtility>();
+        utility_ = create_olm_object<UtilityObject>();
+}
+
+void
+OlmClient::restore_account(const std::string &saved_data, const std::string &key)
+{
+        account_ = unpickle<AccountObject>(saved_data, key);
 }
 
 IdentityKeys
@@ -183,7 +189,7 @@ OlmClient::create_upload_keys_request(const mtx::crypto::OneTimeKeys &one_time_k
 OutboundGroupSessionPtr
 OlmClient::init_outbound_group_session()
 {
-        auto session = create_olm_object<OlmOutboundGroupSession>();
+        auto session = create_olm_object<OutboundSessionObject>();
         auto tmp_buf = create_buffer(olm_init_outbound_group_session_random_length(session.get()));
 
         const int ret =
@@ -198,7 +204,7 @@ OlmClient::init_outbound_group_session()
 InboundGroupSessionPtr
 OlmClient::init_inbound_group_session(const std::string &session_key)
 {
-        auto session = create_olm_object<OlmInboundGroupSession>();
+        auto session = create_olm_object<InboundSessionObject>();
 
         const int ret = olm_init_inbound_group_session(
           session.get(), reinterpret_cast<const uint8_t *>(session_key.data()), session_key.size());
@@ -317,7 +323,7 @@ OlmClient::create_inbound_session(const std::string &one_time_key_message)
 OlmSessionPtr
 OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message)
 {
-        auto session = create_olm_object<OlmSession>();
+        auto session = create_olm_object<SessionObject>();
 
         auto tmp = create_buffer(one_time_key_message.size());
         std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin());
@@ -334,7 +340,7 @@ OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message)
 OlmSessionPtr
 OlmClient::create_outbound_session(const std::string &identity_key, const std::string &one_time_key)
 {
-        auto session    = create_olm_object<OlmSession>();
+        auto session    = create_olm_object<SessionObject>();
         auto random_buf = create_buffer(olm_create_outbound_session_random_length(session.get()));
 
         const int ret = olm_create_outbound_session(session.get(),
@@ -471,7 +477,7 @@ mtx::crypto::verify_identity_signature(nlohmann::json obj,
 
                 const auto msg = obj.dump();
 
-                auto utility = create_olm_object<OlmUtility>();
+                auto utility = create_olm_object<UtilityObject>();
                 auto ret     = olm_ed25519_verify(utility.get(),
                                               signing_key.data(),
                                               signing_key.size(),
diff --git a/tests/e2ee.cpp b/tests/e2ee.cpp
index d9a153588..47f834662 100644
--- a/tests/e2ee.cpp
+++ b/tests/e2ee.cpp
@@ -832,6 +832,135 @@ TEST(Encryption, OlmRoomKeyEncryption)
         bob_http->close();
 }
 
+TEST(Encryption, PickleAccount)
+{
+        auto alice = std::make_shared<OlmClient>();
+        alice->create_new_account();
+        alice->generate_one_time_keys(10);
+
+        auto alice_pickled = pickle<AccountObject>(alice->account(), "secret");
+
+        auto bob = std::make_shared<OlmClient>();
+        bob->restore_account(alice_pickled, "secret");
+
+        EXPECT_EQ(json(bob->identity_keys()).dump(), json(alice->identity_keys()).dump());
+        EXPECT_EQ(json(bob->one_time_keys()).dump(), json(alice->one_time_keys()).dump());
+
+        auto carl = std::make_shared<OlmClient>();
+
+        // BAD_ACCOUNT_KEY
+        EXPECT_THROW(carl->restore_account(alice_pickled, "another_secret"), olm_exception);
+}
+
+TEST(Encryption, PickleOlmSessions)
+{
+        auto alice = std::make_shared<OlmClient>();
+        alice->create_new_account();
+
+        auto bob = std::make_shared<OlmClient>();
+        bob->create_new_account();
+        bob->generate_one_time_keys(1);
+
+        std::string bob_key          = bob->identity_keys().curve25519;
+        std::string bob_one_time_key = bob->one_time_keys().curve25519.begin()->second;
+
+        auto outbound_session = alice->create_outbound_session(bob_key, bob_one_time_key);
+
+        auto plaintext      = "Hello, Bob!";
+        size_t msgtype      = olm_encrypt_message_type(outbound_session.get());
+        auto ciphertext     = alice->encrypt_message(outbound_session.get(), plaintext);
+        auto ciphertext_str = std::string((char *)ciphertext.data(), ciphertext.size());
+
+        EXPECT_EQ(msgtype, 0);
+
+        auto saved_outbound_session    = pickle<SessionObject>(outbound_session.get(), "wat");
+        auto restored_outbound_session = unpickle<SessionObject>(saved_outbound_session, "wat");
+
+        EXPECT_THROW(unpickle<SessionObject>(saved_outbound_session, "another_secret"),
+                     olm_exception);
+
+        msgtype = olm_encrypt_message_type(restored_outbound_session.get());
+        EXPECT_EQ(msgtype, 0);
+
+        auto restored_ciphertext =
+          alice->encrypt_message(restored_outbound_session.get(), plaintext);
+        auto restored_ciphertext_str =
+          std::string((char *)restored_ciphertext.data(), restored_ciphertext.size());
+
+        auto inbound_session          = bob->create_inbound_session(ciphertext_str);
+        auto saved_inbound_session    = pickle<SessionObject>(inbound_session.get(), "woot");
+        auto restored_inbound_session = unpickle<SessionObject>(saved_inbound_session, "woot");
+
+        EXPECT_THROW(unpickle<SessionObject>(saved_inbound_session, "another_secret"),
+                     olm_exception);
+
+        ASSERT_EQ(1, matches_inbound_session(inbound_session.get(), ciphertext_str));
+        ASSERT_EQ(1, matches_inbound_session(inbound_session.get(), restored_ciphertext_str));
+        ASSERT_EQ(1,
+                  matches_inbound_session(restored_inbound_session.get(), restored_ciphertext_str));
+        ASSERT_EQ(1, matches_inbound_session(restored_inbound_session.get(), ciphertext_str));
+
+        auto d1 = bob->decrypt_message(inbound_session.get(), msgtype, ciphertext_str);
+        auto d2 = bob->decrypt_message(restored_inbound_session.get(), msgtype, ciphertext_str);
+        auto d3 = bob->decrypt_message(inbound_session.get(), msgtype, restored_ciphertext_str);
+        auto d4 =
+          bob->decrypt_message(restored_inbound_session.get(), msgtype, restored_ciphertext_str);
+
+        EXPECT_EQ(d1, d2);
+        EXPECT_EQ(d2, d3);
+        EXPECT_EQ(d3, d4);
+        EXPECT_EQ(d1, d4);
+        EXPECT_EQ(d2, d4);
+
+        EXPECT_EQ(std::string((char *)d1.data(), d1.size()), "Hello, Bob!");
+}
+
+TEST(Encryption, PickleMegolmSessions)
+{
+        // Outbound Session
+        auto alice = make_shared<mtx::crypto::OlmClient>();
+        alice->create_new_account();
+
+        auto outbound_session = alice->init_outbound_group_session();
+
+        const auto original_session_id  = mtx::crypto::session_id(outbound_session.get());
+        const auto original_session_key = mtx::crypto::session_key(outbound_session.get());
+
+        auto saved_session = pickle<OutboundSessionObject>(outbound_session.get(), "secret");
+        auto restored_outbound_session = unpickle<OutboundSessionObject>(saved_session, "secret");
+
+        const auto restored_session_id  = mtx::crypto::session_id(restored_outbound_session.get());
+        const auto restored_session_key = mtx::crypto::session_key(restored_outbound_session.get());
+
+        EXPECT_EQ(original_session_id, restored_session_id);
+        EXPECT_EQ(original_session_key, restored_session_key);
+
+        // BAD_ACCOUNT_KEY
+        EXPECT_THROW(unpickle<OutboundSessionObject>(saved_session, "another_secret"),
+                     olm_exception);
+
+        const auto SECRET = "Hello World!";
+
+        auto encrypted  = alice->encrypt_group_message(outbound_session.get(), SECRET);
+        auto ciphertext = std::string((char *)encrypted.data(), encrypted.size());
+
+        // Inbound Session
+        auto inbound_session = alice->init_inbound_group_session(original_session_key);
+        auto plaintext       = alice->decrypt_group_message(inbound_session.get(), ciphertext);
+
+        saved_session = pickle<InboundSessionObject>(inbound_session.get(), "secret");
+
+        auto restored_inbound_session = unpickle<InboundSessionObject>(saved_session, "secret");
+        auto restored_plaintext =
+          alice->decrypt_group_message(restored_inbound_session.get(), ciphertext);
+
+        EXPECT_EQ(
+          std::string((char *)plaintext.data.data(), plaintext.data.size()),
+          std::string((char *)restored_plaintext.data.data(), restored_plaintext.data.size()));
+
+        EXPECT_EQ(std::string((char *)plaintext.data.data(), plaintext.data.size()), SECRET);
+}
+
 TEST(Encryption, DISABLED_HandleRoomKeyEvent) {}
 TEST(Encryption, DISABLED_HandleRoomKeyRequestEvent) {}
 TEST(Encryption, DISABLED_HandleNewDevices) {}
diff --git a/tests/utils.cpp b/tests/utils.cpp
index 6ff6d407d..32f913d3b 100644
--- a/tests/utils.cpp
+++ b/tests/utils.cpp
@@ -177,7 +177,7 @@ TEST(Utilities, VerifyIdentityKeyJson)
 
         auto msg = tmp.dump();
 
-        auto utility = create_olm_object<OlmUtility>();
+        auto utility = create_olm_object<UtilityObject>();
         EXPECT_EQ(olm_ed25519_verify(utility.get(),
                                      signing_key.data(),
                                      signing_key.size(),
-- 
GitLab