Skip to content
Snippets Groups Projects
Commit 7e522df7 authored by Konstantinos Sideris's avatar Konstantinos Sideris
Browse files

Implement decryption for group events on the crypto_bot

parent 9bdd5216
No related branches found
No related tags found
No related merge requests found
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <atomic> #include <atomic>
#include <iostream> #include <iostream>
#include <json.hpp> #include <json.hpp>
#include <stdexcept>
#include <unistd.h> #include <unistd.h>
#include <variant.hpp> #include <variant.hpp>
...@@ -31,6 +32,46 @@ using namespace mtx::identifiers; ...@@ -31,6 +32,46 @@ using namespace mtx::identifiers;
using TimelineEvent = mtx::events::collections::TimelineEvents; using TimelineEvent = mtx::events::collections::TimelineEvents;
constexpr auto OLM_ALGO = "m.olm.v1.curve25519-aes-sha2";
struct OlmCipherContent
{
std::string body;
uint8_t type;
};
inline void
from_json(const nlohmann::json &obj, OlmCipherContent &msg)
{
msg.body = obj.at("body");
msg.type = obj.at("type");
}
struct OlmMessage
{
std::string sender_key;
std::string sender;
using RecipientKey = std::string;
std::map<RecipientKey, OlmCipherContent> ciphertext;
};
inline void
from_json(const nlohmann::json &obj, OlmMessage &msg)
{
if (obj.at("type") != "m.room.encrypted") {
throw std::invalid_argument("invalid type for olm message");
}
if (obj.at("content").at("algorithm") != OLM_ALGO)
throw std::invalid_argument("invalid algorithm for olm message");
msg.sender = obj.at("sender");
msg.sender_key = obj.at("content").at("sender_key");
msg.ciphertext =
obj.at("content").at("ciphertext").get<std::map<std::string, OlmCipherContent>>();
}
template<class Container, class Item> template<class Container, class Item>
bool bool
exists(const Container &container, const Item &item) exists(const Container &container, const Item &item)
...@@ -69,6 +110,36 @@ struct Storage ...@@ -69,6 +110,36 @@ struct Storage
std::map<std::string, std::string> device_keys_; std::map<std::string, std::string> device_keys_;
//! Flag that indicate if a specific room has encryption enabled. //! Flag that indicate if a specific room has encryption enabled.
std::map<std::string, bool> encrypted_rooms_; std::map<std::string, bool> encrypted_rooms_;
//! Mapping from curve25519 to session.
std::map<std::string, OlmSessionPtr> olm_sessions;
std::map<std::string, InboundGroupSessionPtr> inbound_group_sessions;
bool inbound_group_exists(const std::string &room_id,
const std::string &session_id,
const std::string &sender_key)
{
const auto key = room_id + session_id + sender_key;
return inbound_group_sessions.find(key) != inbound_group_sessions.end();
}
void set_inbound_group_session(const std::string &room_id,
const std::string &session_id,
const std::string &sender_key,
InboundGroupSessionPtr session)
{
const auto key = room_id + session_id + sender_key;
inbound_group_sessions[key] = std::move(session);
}
OlmInboundGroupSession *get_inbound_group_session(const std::string &room_id,
const std::string &session_id,
const std::string &sender_key)
{
const auto key = room_id + session_id + sender_key;
return inbound_group_sessions[key].get();
}
}; };
namespace { namespace {
...@@ -182,6 +253,73 @@ mark_encrypted_room(const RoomId &id) ...@@ -182,6 +253,73 @@ mark_encrypted_room(const RoomId &id)
storage.encrypted_rooms_[id.get()] = true; storage.encrypted_rooms_[id.get()] = true;
} }
void
decrypt_olm_message(const OlmMessage &olm_msg)
{
console->info("OLM message");
console->info("sender: {}", olm_msg.sender);
console->info("sender_key: {}", olm_msg.sender_key);
const auto my_id_key = olm_client->identity_keys().curve25519;
for (const auto &cipher : olm_msg.ciphertext) {
if (cipher.first == my_id_key) {
const auto msg_body = cipher.second.body;
const auto msg_type = cipher.second.type;
console->info("the message is meant for us");
console->info("body: {}", msg_body);
console->info("type: {}", msg_type);
if (msg_type == 0) {
console->info("opening session with {}", olm_msg.sender);
auto inbound_session = olm_client->create_inbound_session(msg_body);
auto ok = matches_inbound_session_from(
inbound_session.get(), olm_msg.sender_key, msg_body);
if (!ok) {
console->error("session could not be established");
} else {
auto output = olm_client->decrypt_message(
inbound_session.get(), msg_type, msg_body);
auto plaintext = json::parse(
std::string((char *)output.data(), output.size()));
console->info("decrypted message: \n {}",
plaintext.dump(2));
storage.olm_sessions.emplace(olm_msg.sender_key,
std::move(inbound_session));
std::string room_id = plaintext.at("content").at("room_id");
std::string session_id =
plaintext.at("content").at("session_id");
std::string session_key =
plaintext.at("content").at("session_key");
if (storage.inbound_group_exists(
room_id, session_id, olm_msg.sender_key)) {
console->warn("megolm session already exists");
} else {
auto megolm_session =
olm_client->init_inbound_group_session(
session_key);
storage.set_inbound_group_session(
room_id,
session_id,
olm_msg.sender_key,
std::move(megolm_session));
console->info("megolm_session saved");
}
}
}
}
}
}
void void
parse_messages(const mtx::responses::Sync &res) parse_messages(const mtx::responses::Sync &res)
{ {
...@@ -227,7 +365,28 @@ parse_messages(const mtx::responses::Sync &res) ...@@ -227,7 +365,28 @@ parse_messages(const mtx::responses::Sync &res)
console->debug("{}", get_json(e)); console->debug("{}", get_json(e));
} else if (is_encrypted(e)) { } else if (is_encrypted(e)) {
console->info("received an encrypted event: {}", room_id); console->info("received an encrypted event: {}", room_id);
console->debug("{}", get_json(e)); console->info("{}", get_json(e));
auto msg = mpark::get<EncryptedEvent<msg::Encrypted>>(e);
if (storage.inbound_group_exists(
room_id, msg.content.session_id, msg.content.sender_key)) {
auto res = olm_client->decrypt_group_message(
storage.get_inbound_group_session(room_id,
msg.content.session_id,
msg.content.sender_key),
msg.content.ciphertext);
auto msg_str =
std::string((char *)res.data.data(), res.data.size());
console->info("decrypted data: {}", msg_str);
console->info("decrypted message_index: {}",
res.message_index);
} else {
console->warn(
"no megolm session found to decrypt the event");
}
} }
} }
} }
...@@ -368,9 +527,23 @@ get_device_keys(const UserId &user) ...@@ -368,9 +527,23 @@ get_device_keys(const UserId &user)
} }
void void
handle_to_device_msgs(const std::vector<nlohmann::json> &to_device) handle_to_device_msgs(const std::vector<nlohmann::json> &msgs)
{ {
(void)to_device; if (!msgs.empty())
console->info("inspecting {} to_device messages", msgs.size());
for (const auto &msg : msgs) {
console->info(msg.dump(2));
try {
OlmMessage olm_msg = msg;
decrypt_olm_message(std::move(olm_msg));
} catch (const nlohmann::json::exception &e) {
console->warn("parsing error for olm message: {}", e.what());
} catch (const std::invalid_argument &e) {
console->warn("validation error for olm message: {}", e.what());
}
}
} }
void void
...@@ -424,21 +597,26 @@ main() ...@@ -424,21 +597,26 @@ main()
{ {
spdlog::set_pattern("[%H:%M:%S] [tid %t] [%^%l%$] %v"); spdlog::set_pattern("[%H:%M:%S] [tid %t] [%^%l%$] %v");
std::string username, server, password; std::string username("mtx_bot");
std::string server("matrix.org");
std::string password("dzyvrwB09GdyEqiyBnfAEvZI3");
cout << "username: "; // cout << "username: ";
std::getline(std::cin, username); // std::getline(std::cin, username);
cout << "server: "; // cout << "server: ";
std::getline(std::cin, server); // std::getline(std::cin, server);
password = getpass("password: "); // password = getpass("password: ");
client = std::make_shared<Client>(server); client = std::make_shared<Client>(server);
olm_client = make_shared<OlmClient>(); olm_client = make_shared<OlmClient>();
olm_client->create_new_account(); olm_client->create_new_account();
console->info("ed25519: {}", olm_client->identity_keys().ed25519);
console->info("curve25519: {}", olm_client->identity_keys().curve25519);
client->login(username, password, login_cb); client->login(username, password, login_cb);
client->close(); client->close();
......
...@@ -42,6 +42,10 @@ public: ...@@ -42,6 +42,10 @@ public:
: msg_(func + ": " + std::string(olm_outbound_group_session_last_error(s))) : msg_(func + ": " + std::string(olm_outbound_group_session_last_error(s)))
{} {}
olm_exception(std::string func, OlmInboundGroupSession *s)
: msg_(func + ": " + std::string(olm_inbound_group_session_last_error(s)))
{}
olm_exception(std::string msg) olm_exception(std::string msg)
: msg_(msg) : msg_(msg)
{} {}
...@@ -127,6 +131,16 @@ create_olm_object() ...@@ -127,6 +131,16 @@ create_olm_object()
return std::unique_ptr<T, OlmDeleter>(OlmAllocator<T>::allocate()); return std::unique_ptr<T, OlmDeleter>(OlmAllocator<T>::allocate());
} }
using OlmSessionPtr = std::unique_ptr<OlmSession, OlmDeleter>;
using OutboundGroupSessionPtr = std::unique_ptr<OlmOutboundGroupSession, OlmDeleter>;
using InboundGroupSessionPtr = std::unique_ptr<OlmInboundGroupSession, OlmDeleter>;
struct GroupPlaintext
{
BinaryBuf data;
uint32_t message_index;
};
class OlmClient : public std::enable_shared_from_this<OlmClient> class OlmClient : public std::enable_shared_from_this<OlmClient>
{ {
public: public:
...@@ -170,6 +184,10 @@ public: ...@@ -170,6 +184,10 @@ public:
mtx::requests::UploadKeys create_upload_keys_request(const OneTimeKeys &keys); mtx::requests::UploadKeys create_upload_keys_request(const OneTimeKeys &keys);
mtx::requests::UploadKeys create_upload_keys_request(); mtx::requests::UploadKeys create_upload_keys_request();
//! Decrypt a message using megolm.
GroupPlaintext decrypt_group_message(OlmInboundGroupSession *session,
const std::string &message,
uint32_t message_index = 0);
//! Encrypt a message using olm. //! Encrypt a message using olm.
BinaryBuf encrypt_message(OlmSession *session, const std::string &msg); BinaryBuf encrypt_message(OlmSession *session, const std::string &msg);
//! Decrypt a message using olm. //! Decrypt a message using olm.
...@@ -178,12 +196,12 @@ public: ...@@ -178,12 +196,12 @@ public:
const std::string &msg); const std::string &msg);
//! Create an outbount megolm session. //! Create an outbount megolm session.
std::unique_ptr<OlmOutboundGroupSession, OlmDeleter> init_outbound_group_session(); OutboundGroupSessionPtr init_outbound_group_session();
std::unique_ptr<OlmSession, OlmDeleter> create_outbound_session( InboundGroupSessionPtr init_inbound_group_session(const std::string &session_key);
const std::string &identity_key, OlmSessionPtr create_outbound_session(const std::string &identity_key,
const std::string &one_time_key); const std::string &one_time_key);
std::unique_ptr<OlmSession, OlmDeleter> create_inbound_session( OlmSessionPtr create_inbound_session(const BinaryBuf &one_time_key_message);
const BinaryBuf &one_time_key_message); OlmSessionPtr create_inbound_session(const std::string &one_time_key_message);
OlmAccount *account() { return account_.get(); } OlmAccount *account() { return account_.get(); }
OlmUtility *utility() { return utility_.get(); } OlmUtility *utility() { return utility_.get(); }
...@@ -212,12 +230,12 @@ std::string ...@@ -212,12 +230,12 @@ std::string
session_key(OlmOutboundGroupSession *s); session_key(OlmOutboundGroupSession *s);
bool bool
matches_inbound_session(OlmSession *session, const BinaryBuf &one_time_key_message); matches_inbound_session(OlmSession *session, const std::string &one_time_key_message);
bool bool
matches_inbound_session_from(OlmSession *session, matches_inbound_session_from(OlmSession *session,
const std::string &id_key, const std::string &id_key,
const BinaryBuf &one_time_key_message); const std::string &one_time_key_message);
//! Verify a signature object as obtained from the response of /keys/query endpoint //! Verify a signature object as obtained from the response of /keys/query endpoint
bool bool
......
...@@ -174,7 +174,7 @@ OlmClient::create_upload_keys_request(const mtx::crypto::OneTimeKeys &one_time_k ...@@ -174,7 +174,7 @@ OlmClient::create_upload_keys_request(const mtx::crypto::OneTimeKeys &one_time_k
return req; return req;
} }
std::unique_ptr<OlmOutboundGroupSession, OlmDeleter> OutboundGroupSessionPtr
OlmClient::init_outbound_group_session() OlmClient::init_outbound_group_session()
{ {
auto session = create_olm_object<OlmOutboundGroupSession>(); auto session = create_olm_object<OlmOutboundGroupSession>();
...@@ -189,6 +189,54 @@ OlmClient::init_outbound_group_session() ...@@ -189,6 +189,54 @@ OlmClient::init_outbound_group_session()
return session; return session;
} }
InboundGroupSessionPtr
OlmClient::init_inbound_group_session(const std::string &session_key)
{
auto session = create_olm_object<OlmInboundGroupSession>();
const int ret = olm_init_inbound_group_session(
session.get(), reinterpret_cast<const uint8_t *>(session_key.data()), session_key.size());
if (ret == -1)
throw olm_exception("init_inbound_group_session", session.get());
return session;
}
GroupPlaintext
OlmClient::decrypt_group_message(OlmInboundGroupSession *session,
const std::string &message,
uint32_t message_index)
{
// TODO handle errors
auto tmp_msg = create_buffer(message.size());
std::copy(message.begin(), message.end(), tmp_msg.begin());
auto plaintext_len =
olm_group_decrypt_max_plaintext_length(session, tmp_msg.data(), tmp_msg.size());
auto plaintext = create_buffer(plaintext_len);
tmp_msg = create_buffer(message.size());
std::copy(message.begin(), message.end(), tmp_msg.begin());
const int nbytes = olm_group_decrypt(session,
tmp_msg.data(),
tmp_msg.size(),
plaintext.data(),
plaintext.size(),
&message_index);
logger->info("new message_index: {}", message_index);
if (nbytes == -1)
throw olm_exception("olm_group_decrypt", session);
auto output = create_buffer(nbytes);
std::memcpy(output.data(), plaintext.data(), nbytes);
return GroupPlaintext{std::move(output), message_index};
}
BinaryBuf BinaryBuf
OlmClient::decrypt_message(OlmSession *session, OlmClient::decrypt_message(OlmSession *session,
size_t msgtype, size_t msgtype,
...@@ -235,7 +283,16 @@ OlmClient::encrypt_message(OlmSession *session, const std::string &msg) ...@@ -235,7 +283,16 @@ OlmClient::encrypt_message(OlmSession *session, const std::string &msg)
return ciphertext; return ciphertext;
} }
std::unique_ptr<OlmSession, OlmDeleter> OlmSessionPtr
OlmClient::create_inbound_session(const std::string &one_time_key_message)
{
BinaryBuf tmp(one_time_key_message.size());
memcpy(tmp.data(), one_time_key_message.data(), one_time_key_message.size());
return create_inbound_session(std::move(tmp));
}
OlmSessionPtr
OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message) OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message)
{ {
auto session = create_olm_object<OlmSession>(); auto session = create_olm_object<OlmSession>();
...@@ -252,7 +309,7 @@ OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message) ...@@ -252,7 +309,7 @@ OlmClient::create_inbound_session(const BinaryBuf &one_time_key_message)
return session; return session;
} }
std::unique_ptr<OlmSession, OlmDeleter> OlmSessionPtr
OlmClient::create_outbound_session(const std::string &identity_key, const std::string &one_time_key) OlmClient::create_outbound_session(const std::string &identity_key, const std::string &one_time_key)
{ {
auto session = create_olm_object<OlmSession>(); auto session = create_olm_object<OlmSession>();
...@@ -322,7 +379,7 @@ mtx::crypto::session_key(OlmOutboundGroupSession *s) ...@@ -322,7 +379,7 @@ mtx::crypto::session_key(OlmOutboundGroupSession *s)
} }
bool bool
mtx::crypto::matches_inbound_session(OlmSession *session, const BinaryBuf &one_time_key_message) mtx::crypto::matches_inbound_session(OlmSession *session, const std::string &one_time_key_message)
{ {
auto tmp = create_buffer(one_time_key_message.size()); auto tmp = create_buffer(one_time_key_message.size());
std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin()); std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin());
...@@ -333,7 +390,7 @@ mtx::crypto::matches_inbound_session(OlmSession *session, const BinaryBuf &one_t ...@@ -333,7 +390,7 @@ mtx::crypto::matches_inbound_session(OlmSession *session, const BinaryBuf &one_t
bool bool
mtx::crypto::matches_inbound_session_from(OlmSession *session, mtx::crypto::matches_inbound_session_from(OlmSession *session,
const std::string &id_key, const std::string &id_key,
const BinaryBuf &one_time_key_message) const std::string &one_time_key_message)
{ {
auto tmp = create_buffer(one_time_key_message.size()); auto tmp = create_buffer(one_time_key_message.size());
std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin()); std::copy(one_time_key_message.begin(), one_time_key_message.end(), tmp.begin());
......
...@@ -532,29 +532,30 @@ TEST(Encryption, OlmSessions) ...@@ -532,29 +532,30 @@ TEST(Encryption, OlmSessions)
auto alice_outbound_session = alice->create_outbound_session(bob_key, bob_one_time_key); auto alice_outbound_session = alice->create_outbound_session(bob_key, bob_one_time_key);
// Alice encrypts the message using the current session. // Alice encrypts the message using the current session.
auto plaintext = "Hello, Bob!"; auto plaintext = "Hello, Bob!";
size_t msgtype = olm_encrypt_message_type(alice_outbound_session.get()); size_t msgtype = olm_encrypt_message_type(alice_outbound_session.get());
auto ciphertext = alice->encrypt_message(alice_outbound_session.get(), plaintext); auto ciphertext = alice->encrypt_message(alice_outbound_session.get(), plaintext);
auto ciphertext_str = std::string((char *)ciphertext.data(), ciphertext.size());
EXPECT_EQ(msgtype, 0); EXPECT_EQ(msgtype, 0);
// Bob creates an inbound session to receive Alice's message. // Bob creates an inbound session to receive Alice's message.
auto bob_inbound_session = bob->create_inbound_session(ciphertext); auto bob_inbound_session = bob->create_inbound_session(ciphertext_str);
// Bob validates that the message was meant for him. // Bob validates that the message was meant for him.
ASSERT_EQ(1, matches_inbound_session(bob_inbound_session.get(), ciphertext)); ASSERT_EQ(1, matches_inbound_session(bob_inbound_session.get(), ciphertext_str));
// Bob validates that the message was sent from Alice. // Bob validates that the message was sent from Alice.
ASSERT_EQ(1, ASSERT_EQ(
matches_inbound_session_from(bob_inbound_session.get(), alice_key, ciphertext)); 1, matches_inbound_session_from(bob_inbound_session.get(), alice_key, ciphertext_str));
// Bob validates that the message wasn't sent by someone else. // Bob validates that the message wasn't sent by someone else.
ASSERT_EQ(0, matches_inbound_session_from(bob_inbound_session.get(), bob_key, ciphertext)); ASSERT_EQ(0,
matches_inbound_session_from(bob_inbound_session.get(), bob_key, ciphertext_str));
// Bob decrypts the message // Bob decrypts the message
auto ciphertext_str = std::string((char *)ciphertext.data(), ciphertext.size());
auto decrypted = bob->decrypt_message(bob_inbound_session.get(), msgtype, ciphertext_str); auto decrypted = bob->decrypt_message(bob_inbound_session.get(), msgtype, ciphertext_str);
auto body_str = std::string((char *)decrypted.data(), decrypted.size()); auto body_str = std::string((char *)decrypted.data(), decrypted.size());
ASSERT_EQ(body_str, "Hello, Bob!"); ASSERT_EQ(body_str, plaintext);
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment