diff --git a/src/crypto.cpp b/src/crypto.cpp index 7f0d9f87823e7091dcee4d284eb997d809523ced..2967aaf704e4319d63929ead3ce9dfd743c709f0 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -181,19 +181,30 @@ OlmClient::init_outbound_group_session() } BinaryBuf -OlmClient::decrypt_message(OlmSession *session, size_t msgtype, const std::string &msg) +OlmClient::decrypt_message(OlmSession *session, + size_t msgtype, + const std::string &one_time_key_message) { + auto tmp = create_buffer(one_time_key_message.size()); + std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin()); + auto declen = - olm_decrypt_max_plaintext_length(session, msgtype, (void *)msg.data(), msg.size()); + olm_decrypt_max_plaintext_length(session, msgtype, (void *)tmp.data(), tmp.size()); auto decrypted = create_buffer(declen); - const int ret = olm_decrypt( - session, msgtype, (void *)msg.data(), msg.size(), decrypted.data(), decrypted.size()); + std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin()); - if (ret == -1) + const int nbytes = olm_decrypt( + session, msgtype, (void *)tmp.data(), tmp.size(), decrypted.data(), decrypted.size()); + + if (nbytes == -1) throw olm_exception("olm_decrypt", session); - return decrypted; + // Removing the extra padding from the origial buffer. + auto output = create_buffer(nbytes); + std::copy(decrypted.begin(), decrypted.end(), output.begin()); + + return output; } BinaryBuf @@ -216,14 +227,15 @@ OlmClient::encrypt_message(OlmSession *session, const std::string &msg) } std::unique_ptr<OlmSession, OlmDeleter> -OlmClient::create_inbound_session(const std::string &one_time_key_message) +OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message) { auto session = create_olm_object<OlmSession>(); - const int ret = olm_create_inbound_session(session.get(), - account(), - (void *)one_time_key_message.data(), - one_time_key_message.size()); + auto tmp = create_buffer(one_time_key_message.size()); + std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin()); + + const int ret = + olm_create_inbound_session(session.get(), account(), (void *)tmp.data(), tmp.size()); if (ret == -1) throw olm_exception("create_inbound_session", session.get()); @@ -299,3 +311,25 @@ mtx::client::crypto::session_key(OlmOutboundGroupSession *s) return std::string(tmp.begin(), tmp.end()); } + +bool +mtx::client::crypto::matches_inbound_session(OlmSession *session, + const BinaryBuf &one_time_key_message) +{ + auto tmp = create_buffer(one_time_key_message.size()); + std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin()); + + return olm_matches_inbound_session(session, (void *)tmp.data(), tmp.size()); +} + +bool +mtx::client::crypto::matches_inbound_session_from(OlmSession *session, + const std::string &id_key, + const BinaryBuf &one_time_key_message) +{ + auto tmp = create_buffer(one_time_key_message.size()); + std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin()); + + return olm_matches_inbound_session_from( + session, id_key.data(), id_key.size(), (void *)tmp.data(), tmp.size()); +} diff --git a/src/crypto.hpp b/src/crypto.hpp index 8be233bff6d100595e6bc3c21626eb6ff3764dd0..925d77d6a1bc873bc4825762bf339be85664a9b1 100644 --- a/src/crypto.hpp +++ b/src/crypto.hpp @@ -225,7 +225,7 @@ public: const std::string &identity_key, const std::string &one_time_key); std::unique_ptr<OlmSession, OlmDeleter> create_inbound_session( - const std::string &one_time_key_message); + const BinaryBuf &one_time_key_message); OlmAccount *account() { return account_.get(); } OlmUtility *utility() { return utility_.get(); } @@ -253,6 +253,14 @@ session_id(OlmOutboundGroupSession *s); std::string session_key(OlmOutboundGroupSession *s); +bool +matches_inbound_session(OlmSession *session, const BinaryBuf &one_time_key_message); + +bool +matches_inbound_session_from(OlmSession *session, + const std::string &id_key, + const BinaryBuf &one_time_key_message); + } // namespace crypto } // namespace client } // namespace mtx diff --git a/tests/e2ee.cpp b/tests/e2ee.cpp index 5821d3784b4e2c88c0f9623d8ee5f552c12432dd..46361b455884532d3b7065c682213016fac2d1bd 100644 --- a/tests/e2ee.cpp +++ b/tests/e2ee.cpp @@ -546,28 +546,29 @@ TEST(Encryption, OlmSessions) auto alice_outbound_session = alice->create_outbound_session(bob_key, bob_one_time_key); // Alice encrypts the message using the current session. - auto plaintext = "Hello, Bob!"; - // size_t msgtype = olm_encrypt_message_type(alice_outbound_session.get()); + auto plaintext = "Hello, Bob!"; + size_t msgtype = olm_encrypt_message_type(alice_outbound_session.get()); auto ciphertext = alice->encrypt_message(alice_outbound_session.get(), plaintext); + EXPECT_EQ(msgtype, 0); + // Bob creates an inbound session to receive Alice's message. - auto bob_inbound_session = - bob->create_inbound_session(std::string((char *)ciphertext.data(), ciphertext.size())); + auto bob_inbound_session = bob->create_inbound_session(ciphertext); // Bob validates that the message was meant for him. - auto matches = olm_matches_inbound_session( - bob_inbound_session.get(), (void *)ciphertext.data(), ciphertext.size()); + ASSERT_EQ(1, matches_inbound_session(bob_inbound_session.get(), ciphertext)); - ASSERT_EQ(matches, 1); + // Bob validates that the message was sent from Alice. + ASSERT_EQ(1, + matches_inbound_session_from(bob_inbound_session.get(), alice_key, ciphertext)); - // Bob decrypts the message - // auto decrypted = - // bob->decrypt_message(bob_inbound_session.get(), - // msgtype, - // std::string((char *)ciphertext.data(), ciphertext.size())); + // Bob validates that the message wasn't sent by someone else. + ASSERT_EQ(0, matches_inbound_session_from(bob_inbound_session.get(), bob_key, ciphertext)); - // auto body = std::string((char *)decrypted.data(), decrypted.size()); - // std::cout << body << std::endl; + // Bob decrypts the message + auto ciphertext_str = std::string((char *)ciphertext.data(), ciphertext.size()); + auto decrypted = bob->decrypt_message(bob_inbound_session.get(), msgtype, ciphertext_str); - // ASSERT_EQ(body, "Hello, Bob!"); + auto body_str = std::string((char *)decrypted.data(), decrypted.size()); + ASSERT_EQ(body_str, "Hello, Bob!"); }