diff --git a/Makefile b/Makefile index 724792f26f10bf423c8e972a25b9ce9fa230b8d0..3f6d56f949945016f9d771cb9cdc0429d5cce541 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,6 @@ asan: -DCMAKE_BUILD_TYPE=Debug \ -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ -DOPENSSL_ROOT_DIR=/usr/local/opt/openssl \ - -DBUILD_OLM=1 \ -DASAN=1 @cmake --build build diff --git a/src/crypto.cpp b/src/crypto.cpp index 61947de854d5c39d8498460d92c97315572ed11b..270695dafb5ed43fd029cb6569bf7cb016ec73df 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -13,8 +13,7 @@ OlmClient::create_new_account() if (account_) return; - account_ = - std::unique_ptr<OlmAccount, OlmDeleter>(olm_account(new uint8_t[olm_account_size()])); + account_ = create_olm_object<OlmAccount>(); auto tmp_buf = create_buffer(olm_create_account_random_length(account_.get())); const int ret = olm_create_account(account_.get(), tmp_buf->data(), tmp_buf->size()); @@ -32,15 +31,7 @@ OlmClient::create_new_utility() if (utility_) return; - utility_ = - std::unique_ptr<OlmUtility, OlmDeleter>(olm_utility(new uint8_t[olm_utility_size()])); -} - -std::unique_ptr<OlmSession, OlmDeleter> -OlmClient::create_new_session() -{ - return std::unique_ptr<OlmSession, OlmDeleter>( - olm_session(new uint8_t[olm_session_size()])); + utility_ = create_olm_object<OlmUtility>(); } IdentityKeys @@ -182,24 +173,14 @@ OlmClient::create_upload_keys_request(const mtx::client::crypto::OneTimeKeys &on return req; } -std::unique_ptr<OlmSession, OlmDeleter> -OlmClient::create_outbound_group_session(const std::string &peer_identity_key, - const std::string &peer_one_time_key) +std::unique_ptr<OlmOutboundGroupSession, OlmDeleter> +OlmClient::init_outbound_group_session() { - auto session = create_new_session(); - auto tmp_buf = create_buffer(olm_create_outbound_session_random_length(session.get())); - - auto idk_buf = str_to_buffer(peer_identity_key); - auto otk_buf = str_to_buffer(peer_one_time_key); + auto session = create_olm_object<OlmOutboundGroupSession>(); + auto tmp_buf = create_buffer(olm_init_outbound_group_session_random_length(session.get())); - const int ret = olm_create_outbound_session(session.get(), - account_.get(), - idk_buf->data(), - idk_buf->size(), - otk_buf->data(), - otk_buf->size(), - tmp_buf->data(), - tmp_buf->size()); + const int ret = + olm_init_outbound_group_session(session.get(), tmp_buf->data(), tmp_buf->size()); if (ret == -1) throw olm_exception("init_outbound_group_session", session.get()); @@ -255,3 +236,21 @@ mtx::client::crypto::json_to_buffer(const nlohmann::json &obj) { return str_to_buffer(obj.dump()); } + +std::string +mtx::client::crypto::session_id(OlmOutboundGroupSession *s) +{ + auto tmp = create_buffer(olm_outbound_group_session_id_length(s)); + olm_outbound_group_session_id(s, tmp->data(), tmp->size()); + + return encode_base64(tmp->data(), tmp->size()); +} + +std::string +mtx::client::crypto::session_key(OlmOutboundGroupSession *s) +{ + auto tmp = create_buffer(olm_outbound_group_session_key_length(s)); + olm_outbound_group_session_key(s, tmp->data(), tmp->size()); + + return encode_base64(tmp->data(), tmp->size()); +} diff --git a/src/crypto.hpp b/src/crypto.hpp index de803e29005136b5b807856277a41861757c584b..d5b9a59034aab815cc65bb322385ec7d4598d726 100644 --- a/src/crypto.hpp +++ b/src/crypto.hpp @@ -81,6 +81,10 @@ public: : msg_(func + ": " + std::string(olm_utility_last_error(util))) {} + olm_exception(std::string func, OlmOutboundGroupSession *s) + : msg_(func + ": " + std::string(olm_outbound_group_session_last_error(s))) + {} + olm_exception(std::string msg) : msg_(msg) {} @@ -105,8 +109,59 @@ create_buffer(std::size_t nbytes) struct OlmDeleter { void operator()(OlmAccount *ptr) { operator delete(ptr, olm_account_size()); } - void operator()(OlmSession *ptr) { operator delete(ptr, olm_session_size()); } void operator()(OlmUtility *ptr) { operator delete(ptr, olm_utility_size()); } + + void operator()(OlmSession *ptr) { operator delete(ptr, olm_session_size()); } + void operator()(OlmOutboundGroupSession *ptr) + { + operator delete(ptr, olm_outbound_group_session_size()); + } + void operator()(OlmInboundGroupSession *ptr) + { + operator delete(ptr, olm_inbound_group_session_size()); + } +}; + +template<class T> +struct OlmAllocator +{ + static T allocate() = delete; +}; + +template<> +struct OlmAllocator<OlmAccount> +{ + static OlmAccount *allocate() { return olm_account(new uint8_t[olm_account_size()]); } +}; + +template<> +struct OlmAllocator<OlmSession> +{ + static OlmSession *allocate() { return olm_session(new uint8_t[olm_session_size()]); } +}; + +template<> +struct OlmAllocator<OlmUtility> +{ + static OlmUtility *allocate() { return olm_utility(new uint8_t[olm_utility_size()]); } +}; + +template<> +struct OlmAllocator<OlmOutboundGroupSession> +{ + static OlmOutboundGroupSession *allocate() + { + return olm_outbound_group_session(new uint8_t[olm_outbound_group_session_size()]); + } +}; + +template<> +struct OlmAllocator<OlmInboundGroupSession> +{ + static OlmInboundGroupSession *allocate() + { + return olm_inbound_group_session(new uint8_t[olm_inbound_group_session_size()]); + } }; class OlmClient : public std::enable_shared_from_this<OlmClient> @@ -127,10 +182,15 @@ public: //! Sign the given message. std::unique_ptr<BinaryBuf> sign_message(const std::string &msg); + template<class T> + std::unique_ptr<T, OlmDeleter> create_olm_object() + { + return std::unique_ptr<T, OlmDeleter>(OlmAllocator<T>::allocate()); + } + //! Create a new olm Account. Must be called before any other operation. void create_new_account(); void create_new_utility(); - std::unique_ptr<OlmSession, OlmDeleter> create_new_session(); //! Retrieve the json representation of the identity keys for the given account. IdentityKeys identity_keys(); @@ -154,9 +214,7 @@ public: mtx::requests::UploadKeys create_upload_keys_request(); //! Create an outbount megolm session. - std::unique_ptr<OlmSession, OlmDeleter> create_outbound_group_session( - const std::string &peer_identity_key, - const std::string &peer_one_time_key); + std::unique_ptr<OlmOutboundGroupSession, OlmDeleter> init_outbound_group_session(); OlmAccount *account() { return account_.get(); } OlmUtility *utility() { return utility_.get(); } @@ -184,6 +242,14 @@ json_to_buffer(const nlohmann::json &obj); std::unique_ptr<BinaryBuf> str_to_buffer(const std::string &data); +//! Retrieve the session id. +std::string +session_id(OlmOutboundGroupSession *s); + +//! Retrieve the session key. +std::string +session_key(OlmOutboundGroupSession *s); + } // namespace crypto } // namespace client } // namespace mtx diff --git a/tests/e2ee.cpp b/tests/e2ee.cpp index f591e6bf5240820d75090aab37f0799d45cf7d09..e688ffc6a740d54dcc64df23584d6130ae75d18b 100644 --- a/tests/e2ee.cpp +++ b/tests/e2ee.cpp @@ -396,3 +396,20 @@ TEST(Encryption, EnableEncryption) bob->close(); carl->close(); } + +TEST(Encryption, CreateOutboundGroupSession) +{ + auto alice = make_shared<mtx::client::crypto::OlmClient>(); + auto bob = make_shared<mtx::client::crypto::OlmClient>(); + + alice->create_new_account(); + bob->create_new_account(); + + bob->generate_one_time_keys(1); + alice->generate_one_time_keys(1); + + auto outbound_session = alice->init_outbound_group_session(); + + auto session_id = mtx::client::crypto::session_id(outbound_session.get()); + auto session_key = mtx::client::crypto::session_key(outbound_session.get()); +}