Skip to content
Snippets Groups Projects
Commit 89d9b972 authored by Mark Haines's avatar Mark Haines
Browse files

Add versions of olm_session_create_inbound and olm_session_matches_inbound...

Add versions of olm_session_create_inbound and olm_session_matches_inbound which take the curve25519 identity key of the remote device we think the message is from as an additional argument
parent 7523b700
Branches
Tags
No related merge requests found
......@@ -242,6 +242,21 @@ size_t olm_create_inbound_session(
void * one_time_key_message, size_t message_length
);
/** Create a new in-bound session for sending/receiving messages from an
* incoming PRE_KEY message. Returns olm_error() on failure. If the base64
* couldn't be decoded then olm_session_last_error will be "INVALID_BASE64".
* If the message was for an unsupported protocol version then
* olm_session_last_error() will be "BAD_MESSAGE_VERSION". If the message
* couldn't be decoded then then olm_session_last_error() will be
* "BAD_MESSAGE_FORMAT". If the message refers to an unknown one time
* key then olm_session_last_error() will be "BAD_MESSAGE_KEY_ID". */
size_t olm_create_inbound_session_from(
OlmSession * session,
OlmAccount * account,
void const * their_identity_key, size_t their_identity_key_length,
void * one_time_key_message, size_t message_length
);
/** Checks if the PRE_KEY message is for this in-bound session. This can happen
* if multiple messages are sent to this account before this account sends a
* message in reply. Returns olm_error() on failure. If the base64
......@@ -255,6 +270,20 @@ size_t olm_matches_inbound_session(
void * one_time_key_message, size_t message_length
);
/** Checks if the PRE_KEY message is for this in-bound session. This can happen
* if multiple messages are sent to this account before this account sends a
* message in reply. Returns olm_error() on failure. If the base64
* couldn't be decoded then olm_session_last_error will be "INVALID_BASE64".
* If the message was for an unsupported protocol version then
* olm_session_last_error() will be "BAD_MESSAGE_VERSION". If the message
* couldn't be decoded then then olm_session_last_error() will be
* "BAD_MESSAGE_FORMAT". */
size_t olm_matches_inbound_session_from(
OlmSession * session,
void const * their_identity_key, size_t their_identity_key_length,
void * one_time_key_message, size_t message_length
);
/** Removes the one time keys that the session used from the account. Returns
* olm_error() on failure. If the account doesn't have any matching one time
* keys then olm_account_last_error() will be "BAD_MESSAGE_KEY_ID". */
......
......@@ -50,10 +50,12 @@ struct Session {
std::size_t new_inbound_session(
Account & local_account,
Curve25519PublicKey const * their_identity_key,
std::uint8_t const * one_time_key_message, std::size_t message_length
);
bool matches_inbound_session(
Curve25519PublicKey const * their_identity_key,
std::uint8_t const * one_time_key_message, std::size_t message_length
);
......
......@@ -518,7 +518,36 @@ size_t olm_create_inbound_session(
return std::size_t(-1);
}
return from_c(session)->new_inbound_session(
*from_c(account), from_c(one_time_key_message), raw_length
*from_c(account), nullptr, from_c(one_time_key_message), raw_length
);
}
size_t olm_create_inbound_session_from(
OlmSession * session,
OlmAccount * account,
void const * their_identity_key, size_t their_identity_key_length,
void * one_time_key_message, size_t message_length
) {
if (olm::decode_base64_length(their_identity_key_length) != 32) {
from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64;
return std::size_t(-1);
}
olm::Curve25519PublicKey identity_key;
olm::decode_base64(
from_c(their_identity_key), their_identity_key_length,
identity_key.public_key
);
std::size_t raw_length = b64_input(
from_c(one_time_key_message), message_length, from_c(session)->last_error
);
if (raw_length == std::size_t(-1)) {
return std::size_t(-1);
}
return from_c(session)->new_inbound_session(
*from_c(account), &identity_key,
from_c(one_time_key_message), raw_length
);
}
......@@ -534,7 +563,35 @@ size_t olm_matches_inbound_session(
return std::size_t(-1);
}
bool matches = from_c(session)->matches_inbound_session(
from_c(one_time_key_message), raw_length
nullptr, from_c(one_time_key_message), raw_length
);
return matches ? 1 : 0;
}
size_t olm_matches_inbound_session_from(
OlmSession * session,
void const * their_identity_key, size_t their_identity_key_length,
void * one_time_key_message, size_t message_length
) {
if (olm::decode_base64_length(their_identity_key_length) != 32) {
from_c(session)->last_error = olm::ErrorCode::INVALID_BASE64;
return std::size_t(-1);
}
olm::Curve25519PublicKey identity_key;
olm::decode_base64(
from_c(their_identity_key), their_identity_key_length,
identity_key.public_key
);
std::size_t raw_length = b64_input(
from_c(one_time_key_message), message_length, from_c(session)->last_error
);
if (raw_length == std::size_t(-1)) {
return std::size_t(-1);
}
bool matches = from_c(session)->matches_inbound_session(
&identity_key, from_c(one_time_key_message), raw_length
);
return matches ? 1 : 0;
}
......
......@@ -102,11 +102,13 @@ std::size_t olm::Session::new_outbound_session(
namespace {
bool check_message_fields(
olm::PreKeyMessageReader & reader
olm::PreKeyMessageReader & reader, bool have_their_identity_key
) {
bool ok = true;
ok = ok && reader.identity_key;
ok = ok && (have_their_identity_key || reader.identity_key);
if (reader.identity_key) {
ok = ok && reader.identity_key_length == KEY_LENGTH;
}
ok = ok && reader.message;
ok = ok && reader.base_key;
ok = ok && reader.base_key_length == KEY_LENGTH;
......@@ -120,16 +122,27 @@ bool check_message_fields(
std::size_t olm::Session::new_inbound_session(
olm::Account & local_account,
olm::Curve25519PublicKey const * their_identity_key,
std::uint8_t const * one_time_key_message, std::size_t message_length
) {
olm::PreKeyMessageReader reader;
decode_one_time_key_message(reader, one_time_key_message, message_length);
if (!check_message_fields(reader)) {
if (!check_message_fields(reader, their_identity_key)) {
last_error = olm::ErrorCode::BAD_MESSAGE_FORMAT;
return std::size_t(-1);
}
if (reader.identity_key && their_identity_key) {
bool same = 0 == std::memcmp(
their_identity_key->public_key, reader.identity_key, KEY_LENGTH
);
if (!same) {
last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
return std::size_t(-1);
}
}
olm::MessageReader message_reader;
decode_message(
message_reader, reader.message, reader.message_length,
......@@ -177,19 +190,28 @@ std::size_t olm::Session::new_inbound_session(
bool olm::Session::matches_inbound_session(
olm::Curve25519PublicKey const * their_identity_key,
std::uint8_t const * one_time_key_message, std::size_t message_length
) {
olm::PreKeyMessageReader reader;
decode_one_time_key_message(reader, one_time_key_message, message_length);
if (!check_message_fields(reader)) {
if (!check_message_fields(reader, their_identity_key)) {
return false;
}
bool same = true;
if (reader.identity_key) {
same = same && 0 == std::memcmp(
reader.identity_key, alice_identity_key.public_key, KEY_LENGTH
);
}
if (their_identity_key) {
same = same && 0 == std::memcmp(
their_identity_key->public_key, alice_identity_key.public_key,
KEY_LENGTH
);
}
same = same && 0 == std::memcmp(
reader.base_key, alice_base_key.public_key, KEY_LENGTH
);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment