Skip to content
Snippets Groups Projects
Commit 6866000b authored by Konstantinos Sideris's avatar Konstantinos Sideris
Browse files

Add pickle/unpickle support on the e2ee example

parent a2e7e43f
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,7 @@
#include "spdlog/spdlog.h"
#include <atomic>
#include <fstream>
#include <iostream>
#include <json.hpp>
#include <stdexcept>
......@@ -34,7 +35,8 @@ using namespace mtx::identifiers;
using TimelineEvent = mtx::events::collections::TimelineEvents;
constexpr auto OLM_ALGO = "m.olm.v1.curve25519-aes-sha2";
constexpr auto OLM_ALGO = "m.olm.v1.curve25519-aes-sha2";
constexpr auto STORAGE_KEY = "secret";
struct OlmCipherContent
{
......@@ -93,17 +95,6 @@ mark_encrypted_room(const RoomId &id);
void
handle_to_device_msgs(const std::vector<nlohmann::json> &to_device);
//! Metadata associated with each active megolm session.
struct GroupSessionMsgData
{
std::string session_id;
std::string room_id;
std::string event_id;
uint64_t origin_server_ts;
uint64_t message_index;
};
struct OutboundSessionData
{
std::string session_id;
......@@ -111,6 +102,22 @@ struct OutboundSessionData
uint64_t message_index = 0;
};
inline void
to_json(nlohmann::json &obj, const OutboundSessionData &msg)
{
obj["session_id"] = msg.session_id;
obj["session_key"] = msg.session_key;
obj["message_index"] = msg.message_index;
}
inline void
from_json(const nlohmann::json &obj, OutboundSessionData &msg)
{
msg.session_id = obj.at("session_id");
msg.session_key = obj.at("session_key");
msg.message_index = obj.at("message_index");
}
struct OutboundSessionDataRef
{
OlmOutboundGroupSession *session;
......@@ -123,14 +130,33 @@ struct DevKeys
std::string curve25519;
};
inline void
to_json(nlohmann::json &obj, const DevKeys &msg)
{
obj["ed25519"] = msg.ed25519;
obj["curve25519"] = msg.curve25519;
}
inline void
from_json(const nlohmann::json &obj, DevKeys &msg)
{
msg.ed25519 = obj.at("ed25519");
msg.curve25519 = obj.at("curve25519");
}
auto console = spdlog::stdout_color_mt("console");
std::shared_ptr<Client> client = nullptr;
std::shared_ptr<OlmClient> olm_client = nullptr;
struct Storage
{
//! Storage for the user_id -> list of devices mapping.
std::map<std::string, std::vector<std::string>> devices_;
std::map<std::string, std::vector<std::string>> devices;
//! Storage for the identity key for a device.
std::map<std::string, DevKeys> device_keys;
//! Flag that indicate if a specific room has encryption enabled.
std::map<std::string, bool> encrypted_rooms_;
std::map<std::string, bool> encrypted_rooms;
//! Keep track of members per room.
std::map<std::string, std::map<std::string, bool>> members;
......@@ -195,14 +221,105 @@ struct Storage
const auto key = room_id + session_id + sender_key;
return inbound_group_sessions[key].get();
}
void load()
{
console->info("restoring storage");
ifstream db("db.json");
string db_data((istreambuf_iterator<char>(db)), istreambuf_iterator<char>());
if (db_data.empty())
return;
json obj = json::parse(db_data);
devices = obj.at("devices").get<map<string, vector<string>>>();
device_keys = obj.at("device_keys").get<map<string, DevKeys>>();
encrypted_rooms = obj.at("encrypted_rooms").get<map<string, bool>>();
members = obj.at("members").get<map<string, map<string, bool>>>();
if (obj.count("olm_inbound_sessions") != 0) {
auto sessions = obj.at("olm_inbound_sessions").get<map<string, string>>();
for (const auto &s : sessions)
olm_inbound_sessions[s.first] =
unpickle<SessionObject>(s.second, STORAGE_KEY);
}
if (obj.count("olm_outbound_sessions") != 0) {
auto sessions = obj.at("olm_outbound_sessions").get<map<string, string>>();
for (const auto &s : sessions)
olm_outbound_sessions[s.first] =
unpickle<SessionObject>(s.second, STORAGE_KEY);
}
if (obj.count("inbound_group_sessions") != 0) {
auto sessions = obj.at("inbound_group_sessions").get<map<string, string>>();
for (const auto &s : sessions)
inbound_group_sessions[s.first] =
unpickle<InboundSessionObject>(s.second, STORAGE_KEY);
}
if (obj.count("outbound_group_sessions") != 0) {
auto sessions =
obj.at("outbound_group_sessions").get<map<string, string>>();
for (const auto &s : sessions)
outbound_group_sessions[s.first] =
unpickle<OutboundSessionObject>(s.second, STORAGE_KEY);
}
if (obj.count("outbound_group_session_data") != 0) {
auto sessions = obj.at("outbound_group_session_data")
.get<map<string, OutboundSessionData>>();
for (const auto &s : sessions)
outbound_group_session_data[s.first] = s.second;
}
}
void save()
{
console->info("saving storage");
std::ofstream db("db.json");
if (!db.is_open()) {
console->error("couldn't open file to save keys");
return;
}
json data;
data["devices"] = devices;
data["device_keys"] = device_keys;
data["encrypted_rooms"] = encrypted_rooms;
data["members"] = members;
// Save inbound sessions
for (const auto &s : olm_inbound_sessions)
data["olm_inbound_sessions"][s.first] =
mtx::crypto::pickle<SessionObject>(s.second.get(), STORAGE_KEY);
for (const auto &s : olm_outbound_sessions)
data["olm_outbound_sessions"][s.first] =
mtx::crypto::pickle<SessionObject>(s.second.get(), STORAGE_KEY);
for (const auto &s : inbound_group_sessions)
data["inbound_group_sessions"][s.first] =
mtx::crypto::pickle<InboundSessionObject>(s.second.get(), STORAGE_KEY);
for (const auto &s : outbound_group_sessions)
data["outbound_group_sessions"][s.first] =
mtx::crypto::pickle<OutboundSessionObject>(s.second.get(), STORAGE_KEY);
for (const auto &s : outbound_group_session_data)
data["outbound_group_session_data"][s.first] = s.second;
// Save to file
db << data.dump(2);
db.close();
}
};
namespace {
std::shared_ptr<Client> client = nullptr;
std::shared_ptr<OlmClient> olm_client = nullptr;
Storage storage;
auto console = spdlog::stdout_color_mt("console");
}
void
......@@ -279,7 +396,7 @@ create_outbound_megolm_session(const std::string &room_id, const std::string &re
const auto members = storage.members[room_id];
for (const auto &member : members) {
const auto devices = storage.devices_[member.first];
const auto devices = storage.devices[member.first];
// TODO: Figure out for which devices we don't have olm sessions.
for (const auto &dev : devices) {
......@@ -454,7 +571,7 @@ void
mark_encrypted_room(const RoomId &id)
{
console->info("encryption is enabled for room: {}", id.get());
storage.encrypted_rooms_[id.get()] = true;
storage.encrypted_rooms[id.get()] = true;
}
void
......@@ -709,7 +826,7 @@ save_device_keys(const mtx::responses::QueryKeys &res)
for (const auto &entry : res.device_keys) {
const auto user_id = entry.first;
if (!exists(storage.devices_, user_id))
if (!exists(storage.devices, user_id))
console->info("keys for {}", user_id);
std::vector<std::string> device_list;
......@@ -734,8 +851,8 @@ save_device_keys(const mtx::responses::QueryKeys &res)
device_list.push_back(device_id);
}
if (!exists(storage.devices_, user_id)) {
storage.devices_[user_id] = device_list;
if (!exists(storage.devices, user_id)) {
storage.devices[user_id] = device_list;
}
}
}
......@@ -810,6 +927,8 @@ login_cb(const mtx::responses::Login &, RequestErr err)
console->info("User ID: {}", client->user_id().to_string());
console->info("Device ID: {}", client->device_id());
console->info("ed25519: {}", olm_client->identity_keys().ed25519);
console->info("curve25519: {}", olm_client->identity_keys().curve25519);
// Upload one time keys.
olm_client->set_user_id(client->user_id().to_string());
......@@ -850,8 +969,19 @@ void
shutdown_handler(int sig)
{
console->warn("received {} signal", sig);
console->info("saving storage");
console->info("shutting down");
storage.save();
std::ofstream db("account.json");
if (!db.is_open()) {
console->error("couldn't open file to save account keys");
return;
}
json data;
data["account"] = olm_client->save(STORAGE_KEY);
db << data.dump(2);
db.close();
// The sync calls will stop.
client->shutdown();
......@@ -871,10 +1001,17 @@ main()
client = std::make_shared<Client>(server);
olm_client = make_shared<OlmClient>();
olm_client->create_new_account();
console->info("ed25519: {}", olm_client->identity_keys().ed25519);
console->info("curve25519: {}", olm_client->identity_keys().curve25519);
ifstream db("account.json");
string db_data((istreambuf_iterator<char>(db)), istreambuf_iterator<char>());
if (db_data.empty())
olm_client->create_new_account();
else
olm_client->load(json::parse(db_data).at("account").get<std::string>(),
STORAGE_KEY);
storage.load();
client->login(username, password, login_cb);
client->close();
......
......@@ -187,6 +187,9 @@ public:
const std::string &room_key_event,
const std::string &recipient_key);
std::string save(const std::string &key);
void load(const std::string &data, const std::string &key);
OlmAccount *account() { return account_.get(); }
OlmUtility *utility() { return utility_.get(); }
......
......@@ -387,6 +387,21 @@ OlmClient::create_olm_encrypted_content(OlmSession *session,
{"ciphertext", {{recipient_key, {{"body", encrypted_str}, {"type", msg_type}}}}}};
}
std::string
OlmClient::save(const std::string &key)
{
if (!account_)
return std::string();
return pickle<AccountObject>(account(), key);
}
void
OlmClient::load(const std::string &data, const std::string &key)
{
account_ = unpickle<AccountObject>(data, key);
}
BinaryBuf
mtx::crypto::decode_base64(const std::string &msg)
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment