diff --git a/src/crypto.cpp b/src/crypto.cpp index 53a0220b2855402f7fd6c1c2a116080a27d585e6..7f0d9f87823e7091dcee4d284eb997d809523ced 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -16,7 +16,7 @@ OlmClient::create_new_account() account_ = create_olm_object<OlmAccount>(); auto tmp_buf = create_buffer(olm_create_account_random_length(account_.get())); - const int ret = olm_create_account(account_.get(), tmp_buf->data(), tmp_buf->size()); + const int ret = olm_create_account(account_.get(), tmp_buf.data(), tmp_buf.size()); if (ret == -1) { account_.reset(); @@ -39,12 +39,12 @@ OlmClient::identity_keys() { auto tmp_buf = create_buffer(olm_account_identity_keys_length(account_.get())); int result = - olm_account_identity_keys(account_.get(), (void *)tmp_buf->data(), tmp_buf->size()); + olm_account_identity_keys(account_.get(), (void *)tmp_buf.data(), tmp_buf.size()); if (result == -1) throw olm_exception("identity_keys", account_.get()); - return json::parse(std::string(tmp_buf->begin(), tmp_buf->end())); + return json::parse(std::string(tmp_buf.begin(), tmp_buf.end())); } std::string @@ -52,9 +52,9 @@ OlmClient::sign_message(const std::string &msg) { auto signature_buf = create_buffer(olm_account_signature_length(account_.get())); olm_account_sign( - account_.get(), msg.data(), msg.size(), signature_buf->data(), signature_buf->size()); + account_.get(), msg.data(), msg.size(), signature_buf.data(), signature_buf.size()); - return std::string(signature_buf->begin(), signature_buf->end()); + return std::string(signature_buf.begin(), signature_buf.end()); } std::string @@ -83,7 +83,7 @@ OlmClient::generate_one_time_keys(std::size_t number_of_keys) auto buf = create_buffer(nbytes); const int ret = olm_account_generate_one_time_keys( - account_.get(), number_of_keys, buf->data(), buf->size()); + account_.get(), number_of_keys, buf.data(), buf.size()); if (ret == -1) throw olm_exception("generate_one_time_keys", account_.get()); @@ -96,12 +96,12 @@ OlmClient::one_time_keys() { auto buf = create_buffer(olm_account_one_time_keys_length(account_.get())); - const int ret = olm_account_one_time_keys(account_.get(), buf->data(), buf->size()); + const int ret = olm_account_one_time_keys(account_.get(), buf.data(), buf.size()); if (ret == -1) throw olm_exception("one_time_keys", account_.get()); - return json::parse(std::string(buf->begin(), buf->end())); + return json::parse(std::string(buf.begin(), buf.end())); } std::string @@ -172,7 +172,7 @@ OlmClient::init_outbound_group_session() auto tmp_buf = create_buffer(olm_init_outbound_group_session_random_length(session.get())); const int ret = - olm_init_outbound_group_session(session.get(), tmp_buf->data(), tmp_buf->size()); + olm_init_outbound_group_session(session.get(), tmp_buf.data(), tmp_buf.size()); if (ret == -1) throw olm_exception("init_outbound_group_session", session.get()); @@ -180,7 +180,79 @@ OlmClient::init_outbound_group_session() return session; } -std::unique_ptr<BinaryBuf> +BinaryBuf +OlmClient::decrypt_message(OlmSession *session, size_t msgtype, const std::string &msg) +{ + auto declen = + olm_decrypt_max_plaintext_length(session, msgtype, (void *)msg.data(), msg.size()); + + auto decrypted = create_buffer(declen); + const int ret = olm_decrypt( + session, msgtype, (void *)msg.data(), msg.size(), decrypted.data(), decrypted.size()); + + if (ret == -1) + throw olm_exception("olm_decrypt", session); + + return decrypted; +} + +BinaryBuf +OlmClient::encrypt_message(OlmSession *session, const std::string &msg) +{ + auto ciphertext = create_buffer(olm_encrypt_message_length(session, msg.size())); + auto random_buf = create_buffer(olm_encrypt_random_length(session)); + + const int ret = olm_encrypt(session, + msg.data(), + msg.size(), + random_buf.data(), + random_buf.size(), + ciphertext.data(), + ciphertext.size()); + if (ret == -1) + throw olm_exception("olm_encrypt", session); + + return ciphertext; +} + +std::unique_ptr<OlmSession, OlmDeleter> +OlmClient::create_inbound_session(const std::string &one_time_key_message) +{ + auto session = create_olm_object<OlmSession>(); + + const int ret = olm_create_inbound_session(session.get(), + account(), + (void *)one_time_key_message.data(), + one_time_key_message.size()); + + if (ret == -1) + throw olm_exception("create_inbound_session", session.get()); + + return session; +} + +std::unique_ptr<OlmSession, OlmDeleter> +OlmClient::create_outbound_session(const std::string &identity_key, const std::string &one_time_key) +{ + auto session = create_olm_object<OlmSession>(); + auto random_buf = create_buffer(olm_create_outbound_session_random_length(session.get())); + + const int ret = olm_create_outbound_session(session.get(), + account(), + identity_key.data(), + identity_key.size(), + one_time_key.data(), + one_time_key.size(), + random_buf.data(), + random_buf.size()); + + if (ret == -1) + throw olm_exception("create_outbound_session", session.get()); + + return session; +} + +BinaryBuf mtx::client::crypto::decode_base64(const std::string &msg) { const int output_nbytes = olm::decode_base64_length(msg.size()); @@ -191,7 +263,7 @@ mtx::client::crypto::decode_base64(const std::string &msg) auto output_buf = create_buffer(output_nbytes); olm::decode_base64( - reinterpret_cast<const uint8_t *>(msg.data()), msg.size(), output_buf->data()); + reinterpret_cast<const uint8_t *>(msg.data()), msg.size(), output_buf.data()); return output_buf; } @@ -205,25 +277,25 @@ mtx::client::crypto::encode_base64(const uint8_t *data, std::size_t len) throw std::runtime_error("invalid base64 input length"); auto output_buf = create_buffer(output_nbytes); - olm::encode_base64(data, len, output_buf->data()); + olm::encode_base64(data, len, output_buf.data()); - return std::string(output_buf->begin(), output_buf->end()); + return std::string(output_buf.begin(), output_buf.end()); } std::string mtx::client::crypto::session_id(OlmOutboundGroupSession *s) { auto tmp = create_buffer(olm_outbound_group_session_id_length(s)); - olm_outbound_group_session_id(s, tmp->data(), tmp->size()); + olm_outbound_group_session_id(s, tmp.data(), tmp.size()); - return std::string(tmp->begin(), tmp->end()); + return std::string(tmp.begin(), tmp.end()); } std::string mtx::client::crypto::session_key(OlmOutboundGroupSession *s) { auto tmp = create_buffer(olm_outbound_group_session_key_length(s)); - olm_outbound_group_session_key(s, tmp->data(), tmp->size()); + olm_outbound_group_session_key(s, tmp.data(), tmp.size()); - return std::string(tmp->begin(), tmp->end()); + return std::string(tmp.begin(), tmp.end()); } diff --git a/src/crypto.hpp b/src/crypto.hpp index a935eb6cde545b6fe8c1e22414c68c59b1c3daa6..8be233bff6d100595e6bc3c21626eb6ff3764dd0 100644 --- a/src/crypto.hpp +++ b/src/crypto.hpp @@ -96,12 +96,11 @@ private: }; //! Create a uint8_t buffer which is initialized with random bytes. -template<class T = BinaryBuf> -std::unique_ptr<T> +inline BinaryBuf create_buffer(std::size_t nbytes) { - auto buf = std::make_unique<T>(nbytes); - randombytes_buf(buf->data(), buf->size()); + auto buf = BinaryBuf(nbytes); + randombytes_buf(buf.data(), buf.size()); return buf; } @@ -213,8 +212,20 @@ public: mtx::requests::UploadKeys create_upload_keys_request(const OneTimeKeys &keys); mtx::requests::UploadKeys create_upload_keys_request(); + //! Encrypt a message using olm. + BinaryBuf encrypt_message(OlmSession *session, const std::string &msg); + //! Decrypt a message using olm. + BinaryBuf decrypt_message(OlmSession *session, + std::size_t msg_type, + const std::string &msg); + //! Create an outbount megolm session. std::unique_ptr<OlmOutboundGroupSession, OlmDeleter> init_outbound_group_session(); + std::unique_ptr<OlmSession, OlmDeleter> create_outbound_session( + const std::string &identity_key, + const std::string &one_time_key); + std::unique_ptr<OlmSession, OlmDeleter> create_inbound_session( + const std::string &one_time_key_message); OlmAccount *account() { return account_.get(); } OlmUtility *utility() { return utility_.get(); } @@ -231,7 +242,7 @@ std::string encode_base64(const uint8_t *data, std::size_t len); //! Decode the given base64 string -std::unique_ptr<BinaryBuf> +BinaryBuf decode_base64(const std::string &data); //! Retrieve the session id. diff --git a/tests/e2ee.cpp b/tests/e2ee.cpp index 265c7734e6748a87ed69ef99eeaf30577007868c..5821d3784b4e2c88c0f9623d8ee5f552c12432dd 100644 --- a/tests/e2ee.cpp +++ b/tests/e2ee.cpp @@ -522,3 +522,52 @@ TEST(Encryption, CreateOutboundGroupSession) auto session_id = mtx::client::crypto::session_id(outbound_session.get()); auto session_key = mtx::client::crypto::session_key(outbound_session.get()); } + +TEST(Encryption, OlmSessions) +{ + using namespace mtx::client::crypto; + + auto alice = std::make_shared<OlmClient>(); + alice->create_new_account(); + alice->generate_one_time_keys(1); + + auto bob = std::make_shared<OlmClient>(); + bob->create_new_account(); + bob->generate_one_time_keys(1); + + std::string alice_key = alice->identity_keys().curve25519; + std::string alice_one_time_key = alice->one_time_keys().curve25519.begin()->second; + + std::string bob_key = bob->identity_keys().curve25519; + std::string bob_one_time_key = bob->one_time_keys().curve25519.begin()->second; + + // Alice is preparing to send a pre-shared message to Bob by opening + // a new 1-1 outbound session. + auto alice_outbound_session = alice->create_outbound_session(bob_key, bob_one_time_key); + + // Alice encrypts the message using the current session. + auto plaintext = "Hello, Bob!"; + // size_t msgtype = olm_encrypt_message_type(alice_outbound_session.get()); + auto ciphertext = alice->encrypt_message(alice_outbound_session.get(), plaintext); + + // Bob creates an inbound session to receive Alice's message. + auto bob_inbound_session = + bob->create_inbound_session(std::string((char *)ciphertext.data(), ciphertext.size())); + + // Bob validates that the message was meant for him. + auto matches = olm_matches_inbound_session( + bob_inbound_session.get(), (void *)ciphertext.data(), ciphertext.size()); + + ASSERT_EQ(matches, 1); + + // Bob decrypts the message + // auto decrypted = + // bob->decrypt_message(bob_inbound_session.get(), + // msgtype, + // std::string((char *)ciphertext.data(), ciphertext.size())); + + // auto body = std::string((char *)decrypted.data(), decrypted.size()); + // std::cout << body << std::endl; + + // ASSERT_EQ(body, "Hello, Bob!"); +}