diff --git a/src/crypto.cpp b/src/crypto.cpp index f4812d8ad7a76d90d8cc1e6f7d969ea262e08131..1aa77ad13c2dbd6ef54bcf2bc00420e02551697b 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -1,5 +1,4 @@ #include <iostream> -#include <sodium.h> #include "crypto.hpp" #include <olm/base64.hh> @@ -9,15 +8,6 @@ using namespace mtx::client::crypto; constexpr std::size_t SIGNATURE_SIZE = 64; -std::unique_ptr<BinaryBuf> -mtx::client::crypto::create_buffer(std::size_t nbytes) -{ - auto buf = std::make_unique<BinaryBuf>(nbytes); - randombytes_buf(buf->data(), buf->size()); - - return buf; -} - std::shared_ptr<olm::Account> mtx::client::crypto::olm_new_account() { @@ -90,9 +80,7 @@ mtx::client::crypto::one_time_keys(std::shared_ptr<olm::Account> account) if (result == -1) throw olm_exception("one_time_keys", account->last_error); - std::string data(buf->begin(), buf->end()); - - return json::parse(data); + return json::parse(std::string(buf->begin(), buf->end())); } std::string @@ -225,3 +213,34 @@ mtx::client::crypto::json_to_buffer(const nlohmann::json &obj) { return str_to_buffer(obj.dump()); } + +_olm_curve25519_public_key +mtx::client::crypto::str_to_curve25519_pk(const std::string &data) +{ + auto decoded = decode_base64(data); + + if (decoded->size() != CURVE25519_KEY_LENGTH) + throw olm_exception("str_to_curve25519_pk: invalid input size"); + + _olm_curve25519_public_key pk; + std::copy(decoded->begin(), decoded->end(), pk.public_key); + + return pk; +} + +olm::Session +mtx::client::crypto::init_outbound_group_session(std::shared_ptr<olm::Account> account, + const std::string &peer_identity_key, + const std::string &peer_one_time_key) +{ + olm::Session session; + + auto buf = create_buffer(session.new_outbound_session_random_length()); + session.new_outbound_session(*account, + str_to_curve25519_pk(peer_identity_key), + str_to_curve25519_pk(peer_one_time_key), + buf->data(), + buf->size()); + + return session; +} diff --git a/src/crypto.hpp b/src/crypto.hpp index d8073282177e79173ba62b1e6154353da95bce96..32ba11ed8180e89e4a24d6e22e9ca656db8a437f 100644 --- a/src/crypto.hpp +++ b/src/crypto.hpp @@ -4,10 +4,14 @@ #include <memory> #include <json.hpp> +#include <sodium.h> + #include <mtx/identifiers.hpp> #include <mtx/requests.hpp> + #include <olm/account.hh> #include <olm/error.h> +#include <olm/session.hh> namespace mtx { namespace client { @@ -67,6 +71,10 @@ public: , msg_(msg + ": " + std::string(_olm_error_to_string(errcode))) {} + olm_exception(std::string msg) + : msg_(msg) + {} + OlmErrorCode get_errcode() const { return errcode_; } const char *get_error() const { return _olm_error_to_string(errcode_); } @@ -94,8 +102,15 @@ nlohmann::json one_time_keys(std::shared_ptr<olm::Account> user); //! Create a uint8_t buffer which is initialized with random bytes. -std::unique_ptr<BinaryBuf> -create_buffer(std::size_t nbytes); +template<class T = BinaryBuf> +std::unique_ptr<T> +create_buffer(std::size_t nbytes) +{ + auto buf = std::make_unique<T>(nbytes); + randombytes_buf(buf->data(), buf->size()); + + return buf; +} //! Sign the given one time keys and encode it to base64. std::string @@ -149,6 +164,16 @@ str_to_buffer(const std::string &data); std::unique_ptr<BinaryBuf> json_to_buffer(const nlohmann::json &obj); +//! Convert from base64 encoded public key. +_olm_curve25519_public_key +str_to_curve25519_pk(const std::string &data); + +//! Create an outbount megolm session. +olm::Session +init_outbound_group_session(std::shared_ptr<olm::Account> account, + const std::string &peer_id_key, + const std::string &peer_one_time_key); + } // namespace crypto } // namespace client } // namespace mtx diff --git a/tests/utils.cpp b/tests/utils.cpp index ac6774345e64fe86a2ae97d86254104396ec0f12..db182c64466d51fb4fc2a7f66a89f8d36316bb82 100644 --- a/tests/utils.cpp +++ b/tests/utils.cpp @@ -3,10 +3,13 @@ #include "crypto.hpp" #include "json.hpp" -#include "olm/utility.hh" +#include <olm/olm.h> +#include <olm/utility.hh> using json = nlohmann::json; + using namespace mtx::client::crypto; +using namespace std; constexpr int SIGNATURE_SIZE = 64; @@ -69,3 +72,34 @@ TEST(Utilities, VerifySignedIdentityKeys) EXPECT_EQ(utillity.last_error, 0); EXPECT_EQ(res, 0); } + +TEST(Utilities, OutboundGroupSession) +{ + auto alice = olm_new_account(); + auto bob = olm_new_account(); + auto carl = olm_new_account(); + + generate_one_time_keys(bob, 1); + generate_one_time_keys(carl, 1); + + OneTimeKeys bob_otk = one_time_keys(bob); + IdentityKeys bob_ik = identity_keys(bob); + + OneTimeKeys carl_otk = one_time_keys(carl); + IdentityKeys carl_ik = identity_keys(carl); + + auto bob_session = + init_outbound_group_session(alice, bob_ik.curve25519, bob_otk.curve25519.begin()->second); + auto carl_session = init_outbound_group_session( + alice, carl_ik.curve25519, carl_otk.curve25519.begin()->second); + + auto sid_1 = create_buffer(bob_session.session_id_length()); + bob_session.session_id(sid_1->data(), sid_1->size()); + + EXPECT_EQ(sid_1->size(), 32); + + auto sid_2 = create_buffer(carl_session.session_id_length()); + carl_session.session_id(sid_2->data(), sid_2->size()); + + EXPECT_EQ(sid_2->size(), 32); +}