From b98bf5760bc72424f17ed231d41ea9f7b84bed0d Mon Sep 17 00:00:00 2001
From: Konstantinos Sideris <sideris.konstantin@gmail.com>
Date: Sun, 13 May 2018 13:23:06 +0300
Subject: [PATCH] Fix the OlmSessions test case

Pass copies of input buffer to methods that delete them
---
 src/crypto.cpp | 56 ++++++++++++++++++++++++++++++++++++++++----------
 src/crypto.hpp | 10 ++++++++-
 tests/e2ee.cpp | 31 ++++++++++++++--------------
 3 files changed, 70 insertions(+), 27 deletions(-)

diff --git a/src/crypto.cpp b/src/crypto.cpp
index 7f0d9f878..2967aaf70 100644
--- a/src/crypto.cpp
+++ b/src/crypto.cpp
@@ -181,19 +181,30 @@ OlmClient::init_outbound_group_session()
 }
 
 BinaryBuf
-OlmClient::decrypt_message(OlmSession *session, size_t msgtype, const std::string &msg)
+OlmClient::decrypt_message(OlmSession *session,
+                           size_t msgtype,
+                           const std::string &one_time_key_message)
 {
+        auto tmp = create_buffer(one_time_key_message.size());
+        std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin());
+
         auto declen =
-          olm_decrypt_max_plaintext_length(session, msgtype, (void *)msg.data(), msg.size());
+          olm_decrypt_max_plaintext_length(session, msgtype, (void *)tmp.data(), tmp.size());
 
         auto decrypted = create_buffer(declen);
-        const int ret  = olm_decrypt(
-          session, msgtype, (void *)msg.data(), msg.size(), decrypted.data(), decrypted.size());
+        std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin());
 
-        if (ret == -1)
+        const int nbytes = olm_decrypt(
+          session, msgtype, (void *)tmp.data(), tmp.size(), decrypted.data(), decrypted.size());
+
+        if (nbytes == -1)
                 throw olm_exception("olm_decrypt", session);
 
-        return decrypted;
+        // Removing the extra padding from the origial buffer.
+        auto output = create_buffer(nbytes);
+        std::copy(decrypted.begin(), decrypted.end(), output.begin());
+
+        return output;
 }
 
 BinaryBuf
@@ -216,14 +227,15 @@ OlmClient::encrypt_message(OlmSession *session, const std::string &msg)
 }
 
 std::unique_ptr<OlmSession, OlmDeleter>
-OlmClient::create_inbound_session(const std::string &one_time_key_message)
+OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message)
 {
         auto session = create_olm_object<OlmSession>();
 
-        const int ret = olm_create_inbound_session(session.get(),
-                                                   account(),
-                                                   (void *)one_time_key_message.data(),
-                                                   one_time_key_message.size());
+        auto tmp = create_buffer(one_time_key_message.size());
+        std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin());
+
+        const int ret =
+          olm_create_inbound_session(session.get(), account(), (void *)tmp.data(), tmp.size());
 
         if (ret == -1)
                 throw olm_exception("create_inbound_session", session.get());
@@ -299,3 +311,25 @@ mtx::client::crypto::session_key(OlmOutboundGroupSession *s)
 
         return std::string(tmp.begin(), tmp.end());
 }
+
+bool
+mtx::client::crypto::matches_inbound_session(OlmSession *session,
+                                             const BinaryBuf &one_time_key_message)
+{
+        auto tmp = create_buffer(one_time_key_message.size());
+        std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin());
+
+        return olm_matches_inbound_session(session, (void *)tmp.data(), tmp.size());
+}
+
+bool
+mtx::client::crypto::matches_inbound_session_from(OlmSession *session,
+                                                  const std::string &id_key,
+                                                  const BinaryBuf &one_time_key_message)
+{
+        auto tmp = create_buffer(one_time_key_message.size());
+        std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin());
+
+        return olm_matches_inbound_session_from(
+          session, id_key.data(), id_key.size(), (void *)tmp.data(), tmp.size());
+}
diff --git a/src/crypto.hpp b/src/crypto.hpp
index 8be233bff..925d77d6a 100644
--- a/src/crypto.hpp
+++ b/src/crypto.hpp
@@ -225,7 +225,7 @@ public:
           const std::string &identity_key,
           const std::string &one_time_key);
         std::unique_ptr<OlmSession, OlmDeleter> create_inbound_session(
-          const std::string &one_time_key_message);
+          const BinaryBuf &one_time_key_message);
 
         OlmAccount *account() { return account_.get(); }
         OlmUtility *utility() { return utility_.get(); }
@@ -253,6 +253,14 @@ session_id(OlmOutboundGroupSession *s);
 std::string
 session_key(OlmOutboundGroupSession *s);
 
+bool
+matches_inbound_session(OlmSession *session, const BinaryBuf &one_time_key_message);
+
+bool
+matches_inbound_session_from(OlmSession *session,
+                             const std::string &id_key,
+                             const BinaryBuf &one_time_key_message);
+
 } // namespace crypto
 } // namespace client
 } // namespace mtx
diff --git a/tests/e2ee.cpp b/tests/e2ee.cpp
index 5821d3784..46361b455 100644
--- a/tests/e2ee.cpp
+++ b/tests/e2ee.cpp
@@ -546,28 +546,29 @@ TEST(Encryption, OlmSessions)
         auto alice_outbound_session = alice->create_outbound_session(bob_key, bob_one_time_key);
 
         // Alice encrypts the message using the current session.
-        auto plaintext = "Hello, Bob!";
-        // size_t msgtype  = olm_encrypt_message_type(alice_outbound_session.get());
+        auto plaintext  = "Hello, Bob!";
+        size_t msgtype  = olm_encrypt_message_type(alice_outbound_session.get());
         auto ciphertext = alice->encrypt_message(alice_outbound_session.get(), plaintext);
 
+        EXPECT_EQ(msgtype, 0);
+
         // Bob creates an inbound session to receive Alice's message.
-        auto bob_inbound_session =
-          bob->create_inbound_session(std::string((char *)ciphertext.data(), ciphertext.size()));
+        auto bob_inbound_session = bob->create_inbound_session(ciphertext);
 
         // Bob validates that the message was meant for him.
-        auto matches = olm_matches_inbound_session(
-          bob_inbound_session.get(), (void *)ciphertext.data(), ciphertext.size());
+        ASSERT_EQ(1, matches_inbound_session(bob_inbound_session.get(), ciphertext));
 
-        ASSERT_EQ(matches, 1);
+        // Bob validates that the message was sent from Alice.
+        ASSERT_EQ(1,
+                  matches_inbound_session_from(bob_inbound_session.get(), alice_key, ciphertext));
 
-        // Bob decrypts the message
-        // auto decrypted =
-        //   bob->decrypt_message(bob_inbound_session.get(),
-        //                        msgtype,
-        //                        std::string((char *)ciphertext.data(), ciphertext.size()));
+        // Bob validates that the message wasn't sent by someone else.
+        ASSERT_EQ(0, matches_inbound_session_from(bob_inbound_session.get(), bob_key, ciphertext));
 
-        // auto body = std::string((char *)decrypted.data(), decrypted.size());
-        // std::cout << body << std::endl;
+        // Bob decrypts the message
+        auto ciphertext_str = std::string((char *)ciphertext.data(), ciphertext.size());
+        auto decrypted = bob->decrypt_message(bob_inbound_session.get(), msgtype, ciphertext_str);
 
-        // ASSERT_EQ(body, "Hello, Bob!");
+        auto body_str = std::string((char *)decrypted.data(), decrypted.size());
+        ASSERT_EQ(body_str, "Hello, Bob!");
 }
-- 
GitLab