Skip to content
Snippets Groups Projects
Commit 5291ec78 authored by Mark Haines's avatar Mark Haines
Browse files

Send the public part of the one time key rather than passing an identifier

parent 974e0984
No related branches found
No related tags found
No related merge requests found
......@@ -100,7 +100,7 @@ struct Account {
);
OneTimeKey const * lookup_key(
std::uint32_t id
Curve25519PublicKey const & public_key
);
std::size_t remove_key(
......
......@@ -73,16 +73,16 @@ void decode_message(
struct PreKeyMessageWriter {
std::uint8_t * identity_key;
std::uint8_t * base_key;
std::uint8_t * one_time_key;
std::uint8_t * message;
};
struct PreKeyMessageReader {
std::uint8_t version;
bool has_one_time_key_id;
std::uint32_t one_time_key_id;
std::uint8_t const * identity_key; std::size_t identity_key_length;
std::uint8_t const * base_key; std::size_t base_key_length;
std::uint8_t const * one_time_key; std::size_t one_time_key_length;
std::uint8_t const * message; std::size_t message_length;
};
......@@ -91,9 +91,9 @@ struct PreKeyMessageReader {
* The length of the buffer needed to hold a message.
*/
std::size_t encode_one_time_key_message_length(
std::uint32_t one_time_key_id,
std::size_t identity_key_length,
std::size_t base_key_length,
std::size_t one_time_key_length,
std::size_t message_length
);
......@@ -105,9 +105,9 @@ std::size_t encode_one_time_key_message_length(
void encode_one_time_key_message(
PreKeyMessageWriter & writer,
std::uint8_t version,
std::uint32_t one_time_key_id,
std::size_t identity_key_length,
std::size_t base_key_length,
std::size_t one_time_key_length,
std::size_t message_length,
std::uint8_t * output
);
......
......@@ -187,7 +187,6 @@ size_t olm_create_outbound_session(
OlmSession * session,
OlmAccount * account,
void const * their_identity_key, size_t their_identity_key_length,
unsigned their_one_time_key_id,
void const * their_one_time_key, size_t their_one_time_key_length,
void const * random, size_t random_length
);
......
......@@ -44,9 +44,9 @@ struct Session {
RemoteKey alice_identity_key;
Curve25519PublicKey alice_base_key;
Curve25519PublicKey bob_one_time_key;
std::uint32_t bob_one_time_key_id;
std::size_t new_outbound_session_random_length();
std::size_t new_outbound_session(
......
......@@ -18,10 +18,12 @@
olm::OneTimeKey const * olm::Account::lookup_key(
std::uint32_t id
olm::Curve25519PublicKey const & public_key
) {
for (olm::OneTimeKey const & key : one_time_keys) {
if (key.id == id) return &key;
if (0 == memcmp(key.key.public_key, public_key.public_key, 32)) {
return &key;
}
}
return 0;
}
......
......@@ -232,7 +232,7 @@ void olm::decode_message(
namespace {
static std::uint8_t const ONE_TIME_KEY_ID_TAG = 010;
static std::uint8_t const ONE_TIME_KEY_ID_TAG = 012;
static std::uint8_t const BASE_KEY_TAG = 022;
static std::uint8_t const IDENTITY_KEY_TAG = 032;
static std::uint8_t const MESSAGE_TAG = 042;
......@@ -241,13 +241,13 @@ static std::uint8_t const MESSAGE_TAG = 042;
std::size_t olm::encode_one_time_key_message_length(
std::uint32_t one_time_key_id,
std::size_t one_time_key_length,
std::size_t identity_key_length,
std::size_t base_key_length,
std::size_t message_length
) {
std::size_t length = VERSION_LENGTH;
length += 1 + varint_length(one_time_key_id);
length += 1 + varstring_length(one_time_key_length);
length += 1 + varstring_length(identity_key_length);
length += 1 + varstring_length(base_key_length);
length += 1 + varstring_length(message_length);
......@@ -258,15 +258,15 @@ std::size_t olm::encode_one_time_key_message_length(
void olm::encode_one_time_key_message(
olm::PreKeyMessageWriter & writer,
std::uint8_t version,
std::uint32_t one_time_key_id,
std::size_t identity_key_length,
std::size_t base_key_length,
std::size_t one_time_key_length,
std::size_t message_length,
std::uint8_t * output
) {
std::uint8_t * pos = output;
*(pos++) = version;
pos = encode(pos, ONE_TIME_KEY_ID_TAG, one_time_key_id);
pos = encode(pos, ONE_TIME_KEY_ID_TAG, writer.one_time_key, one_time_key_length);
pos = encode(pos, BASE_KEY_TAG, writer.base_key, base_key_length);
pos = encode(pos, IDENTITY_KEY_TAG, writer.identity_key, identity_key_length);
pos = encode(pos, MESSAGE_TAG, writer.message, message_length);
......@@ -283,7 +283,7 @@ void olm::decode_one_time_key_message(
if (pos == end) return;
reader.version = *(pos++);
reader.has_one_time_key_id = false;
reader.one_time_key = nullptr;
reader.identity_key = nullptr;
reader.base_key = nullptr;
reader.message = nullptr;
......@@ -291,7 +291,7 @@ void olm::decode_one_time_key_message(
while (pos != end) {
pos = decode(
pos, end, ONE_TIME_KEY_ID_TAG,
reader.one_time_key_id, reader.has_one_time_key_id
reader.one_time_key, reader.one_time_key_length
);
pos = decode(
pos, end, BASE_KEY_TAG,
......
......@@ -425,7 +425,6 @@ size_t olm_create_outbound_session(
OlmSession * session,
OlmAccount * account,
void const * their_identity_key, size_t their_identity_key_length,
unsigned their_one_time_key_id,
void const * their_one_time_key, size_t their_one_time_key_length,
void const * random, size_t random_length
) {
......@@ -442,7 +441,6 @@ size_t olm_create_outbound_session(
from_c(their_identity_key), their_identity_key_length,
identity_key.public_key
);
one_time_key.id = their_one_time_key_id;
olm::decode_base64(
from_c(their_one_time_key), their_one_time_key_length,
one_time_key.key.public_key
......
......@@ -77,7 +77,7 @@ std::size_t olm::Session::new_outbound_session(
alice_identity_key.id = 0;
alice_identity_key.key = local_account.identity_keys.curve25519_key;
alice_base_key = base_key;
bob_one_time_key_id = one_time_key.id;
bob_one_time_key = one_time_key.key;
std::uint8_t shared_secret[96];
......@@ -112,7 +112,8 @@ bool check_message_fields(
ok = ok && reader.message;
ok = ok && reader.base_key;
ok = ok && reader.base_key_length == KEY_LENGTH;
ok = ok && reader.has_one_time_key_id;
ok = ok && reader.one_time_key;
ok = ok && reader.one_time_key_length == KEY_LENGTH;
return ok;
}
......@@ -145,15 +146,15 @@ std::size_t olm::Session::new_inbound_session(
std::memcpy(alice_identity_key.key.public_key, reader.identity_key, 32);
std::memcpy(alice_base_key.public_key, reader.base_key, 32);
bob_one_time_key_id = reader.one_time_key_id;
std::memcpy(bob_one_time_key.public_key, reader.one_time_key, 32);
olm::Curve25519PublicKey ratchet_key;
std::memcpy(ratchet_key.public_key, message_reader.ratchet_key, 32);
olm::OneTimeKey const * bob_one_time_key = local_account.lookup_key(
bob_one_time_key_id
olm::OneTimeKey const * our_one_time_key = local_account.lookup_key(
bob_one_time_key
);
if (!bob_one_time_key) {
if (!our_one_time_key) {
last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
return std::size_t(-1);
}
......@@ -161,14 +162,14 @@ std::size_t olm::Session::new_inbound_session(
std::uint8_t shared_secret[96];
olm::curve25519_shared_secret(
bob_one_time_key->key, alice_identity_key.key, shared_secret
our_one_time_key->key, alice_identity_key.key, shared_secret
);
olm::curve25519_shared_secret(
local_account.identity_keys.curve25519_key,
alice_base_key, shared_secret + 32
);
olm::curve25519_shared_secret(
bob_one_time_key->key, alice_base_key, shared_secret + 64
our_one_time_key->key, alice_base_key, shared_secret + 64
);
ratchet.initialise_as_bob(shared_secret, 96, ratchet_key);
......@@ -194,7 +195,9 @@ bool olm::Session::matches_inbound_session(
same = same && 0 == std::memcmp(
reader.base_key, alice_base_key.public_key, KEY_LENGTH
);
same = same && reader.one_time_key_id == bob_one_time_key_id;
same = same && 0 == std::memcmp(
reader.one_time_key, bob_one_time_key.public_key, KEY_LENGTH
);
return same;
}
......@@ -220,7 +223,7 @@ std::size_t olm::Session::encrypt_message_length(
}
return encode_one_time_key_message_length(
bob_one_time_key_id,
KEY_LENGTH,
KEY_LENGTH,
KEY_LENGTH,
message_length
......@@ -254,12 +257,15 @@ std::size_t olm::Session::encrypt(
encode_one_time_key_message(
writer,
PROTOCOL_VERSION,
bob_one_time_key_id,
KEY_LENGTH,
KEY_LENGTH,
KEY_LENGTH,
message_body_length,
message
);
std::memcpy(
writer.one_time_key, bob_one_time_key.public_key, KEY_LENGTH
);
std::memcpy(
writer.identity_key, alice_identity_key.key.public_key, KEY_LENGTH
);
......@@ -358,6 +364,7 @@ std::size_t olm::pickle_length(
length += olm::pickle_length(value.alice_identity_key.id);
length += olm::pickle_length(value.alice_identity_key.key);
length += olm::pickle_length(value.alice_base_key);
length += olm::pickle_length(value.bob_one_time_key);
length += olm::pickle_length(value.bob_one_time_key_id);
length += olm::pickle_length(value.ratchet);
return length;
......@@ -372,6 +379,7 @@ std::uint8_t * olm::pickle(
pos = olm::pickle(pos, value.alice_identity_key.id);
pos = olm::pickle(pos, value.alice_identity_key.key);
pos = olm::pickle(pos, value.alice_base_key);
pos = olm::pickle(pos, value.bob_one_time_key);
pos = olm::pickle(pos, value.bob_one_time_key_id);
pos = olm::pickle(pos, value.ratchet);
return pos;
......@@ -386,6 +394,7 @@ std::uint8_t const * olm::unpickle(
pos = olm::unpickle(pos, end, value.alice_identity_key.id);
pos = olm::unpickle(pos, end, value.alice_identity_key.key);
pos = olm::unpickle(pos, end, value.alice_base_key);
pos = olm::unpickle(pos, end, value.bob_one_time_key);
pos = olm::unpickle(pos, end, value.bob_one_time_key_id);
pos = olm::unpickle(pos, end, value.ratchet);
return pos;
......
......@@ -89,7 +89,7 @@ mock_random_a(a_rand, sizeof(a_rand));
assert_not_equals(std::size_t(-1), ::olm_create_outbound_session(
a_session, a_account,
b_id_keys + 88, 43,
::atol((char *)(b_ot_keys + 62)), b_ot_keys + 74, 43,
b_ot_keys + 74, 43,
a_rand, sizeof(a_rand)
));
......@@ -193,7 +193,7 @@ mock_random_a(a_rand, sizeof(a_rand));
assert_not_equals(std::size_t(-1), ::olm_create_outbound_session(
a_session, a_account,
b_id_keys + 88, 43,
::atol((char *)(b_ot_keys + 62)), b_ot_keys + 74, 43,
b_ot_keys + 74, 43,
a_rand, sizeof(a_rand)
));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment