Skip to content
Snippets Groups Projects
Commit 08a7e44a authored by Mark Haines's avatar Mark Haines
Browse files

Pass the message body to decrypt_max_plaintext_length so we can get a more...

Pass the message body to decrypt_max_plaintext_length so we can get a more accurate estimate, rename encrypt_max_output_length to encrypt_output_length and change the api to return the exact number of bytes needed to hold the message
parent 793b9b91
No related branches found
No related tags found
No related merge requests found
......@@ -129,9 +129,9 @@ struct Ratchet {
std::uint8_t * input, std::size_t input_length
);
/** The maximum number of bytes of output the encrypt method will write for
/** The number of bytes of output the encrypt method will write for
* a given message length. */
std::size_t encrypt_max_output_length(
std::size_t encrypt_output_length(
std::size_t plaintext_length
);
......@@ -154,7 +154,7 @@ struct Ratchet {
/** An upper bound on the number of bytes of plain-text the decrypt method
* will write for a given input message length. */
std::size_t decrypt_max_plaintext_length(
std::size_t input_length
std::uint8_t const * input, std::size_t input_length
);
/** Decrypt a message. Returns the length of the decrypted plain-text or
......
......@@ -348,7 +348,7 @@ std::size_t axolotl::Ratchet::unpickle(
}
std::size_t axolotl::Ratchet::encrypt_max_output_length(
std::size_t axolotl::Ratchet::encrypt_output_length(
std::size_t plaintext_length
) {
std::size_t counter = 0;
......@@ -374,7 +374,7 @@ std::size_t axolotl::Ratchet::encrypt(
std::uint8_t const * random, std::size_t random_length,
std::uint8_t * output, std::size_t max_output_length
) {
std::size_t output_length = encrypt_max_output_length(plaintext_length);
std::size_t output_length = encrypt_output_length(plaintext_length);
if (random_length < encrypt_random_length()) {
last_error = axolotl::ErrorCode::NOT_ENOUGH_RANDOM;
......@@ -428,9 +428,19 @@ std::size_t axolotl::Ratchet::encrypt(
std::size_t axolotl::Ratchet::decrypt_max_plaintext_length(
std::size_t input_length
std::uint8_t const * input, std::size_t input_length
) {
return input_length;
axolotl::MessageReader reader;
axolotl::decode_message(
reader, input, input_length, ratchet_cipher.mac_length()
);
if (!reader.ciphertext) {
last_error = axolotl::ErrorCode::BAD_MESSAGE_FORMAT;
return std::size_t(-1);
}
return ratchet_cipher.decrypt_max_plaintext_length(reader.ciphertext_length);
}
......@@ -438,11 +448,6 @@ std::size_t axolotl::Ratchet::decrypt(
std::uint8_t const * input, std::size_t input_length,
std::uint8_t * plaintext, std::size_t max_plaintext_length
) {
if (max_plaintext_length < decrypt_max_plaintext_length(input_length)) {
last_error = axolotl::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
return std::size_t(-1);
}
axolotl::MessageReader reader;
axolotl::decode_message(
reader, input, input_length, ratchet_cipher.mac_length()
......@@ -458,6 +463,15 @@ std::size_t axolotl::Ratchet::decrypt(
return std::size_t(-1);
}
std::size_t max_length = ratchet_cipher.decrypt_max_plaintext_length(
reader.ciphertext_length
);
if (max_plaintext_length < max_length) {
last_error = axolotl::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
return std::size_t(-1);
}
if (reader.ratchet_key_length != KEY_LENGTH) {
last_error = axolotl::ErrorCode::BAD_MESSAGE_FORMAT;
return std::size_t(-1);
......
......@@ -54,13 +54,11 @@ std::size_t message_length, random_length, output_length;
std::size_t encrypt_length, decrypt_length;
{
/* Bob sends Alice a message */
message_length = bob.encrypt_max_output_length(plaintext_length);
message_length = bob.encrypt_output_length(plaintext_length);
random_length = bob.encrypt_random_length();
assert_equals(std::size_t(0), random_length);
output_length = alice.decrypt_max_plaintext_length(message_length);
std::uint8_t message[message_length];
std::uint8_t output[output_length];
encrypt_length = bob.encrypt(
plaintext, plaintext_length,
......@@ -69,6 +67,8 @@ std::size_t encrypt_length, decrypt_length;
);
assert_equals(message_length, encrypt_length);
output_length = alice.decrypt_max_plaintext_length(message, message_length);
std::uint8_t output[output_length];
decrypt_length = alice.decrypt(
message, message_length,
output, output_length
......@@ -80,13 +80,11 @@ std::size_t encrypt_length, decrypt_length;
{
/* Alice sends Bob a message */
message_length = alice.encrypt_max_output_length(plaintext_length);
message_length = alice.encrypt_output_length(plaintext_length);
random_length = alice.encrypt_random_length();
assert_equals(std::size_t(32), random_length);
output_length = bob.decrypt_max_plaintext_length(message_length);
std::uint8_t message[message_length];
std::uint8_t output[output_length];
std::uint8_t random[] = "This is a random 32 byte string.";
encrypt_length = alice.encrypt(
......@@ -96,6 +94,8 @@ std::size_t encrypt_length, decrypt_length;
);
assert_equals(message_length, encrypt_length);
output_length = bob.decrypt_max_plaintext_length(message, message_length);
std::uint8_t output[output_length];
decrypt_length = bob.decrypt(
message, message_length,
output, output_length
......@@ -127,7 +127,7 @@ std::size_t encrypt_length, decrypt_length;
{
/* Alice sends Bob two messages and they arrive out of order */
message_1_length = alice.encrypt_max_output_length(plaintext_1_length);
message_1_length = alice.encrypt_output_length(plaintext_1_length);
random_length = alice.encrypt_random_length();
assert_equals(std::size_t(32), random_length);
......@@ -140,7 +140,7 @@ std::size_t encrypt_length, decrypt_length;
);
assert_equals(message_1_length, encrypt_length);
message_2_length = alice.encrypt_max_output_length(plaintext_2_length);
message_2_length = alice.encrypt_output_length(plaintext_2_length);
random_length = alice.encrypt_random_length();
assert_equals(std::size_t(0), random_length);
......@@ -152,7 +152,9 @@ std::size_t encrypt_length, decrypt_length;
);
assert_equals(message_2_length, encrypt_length);
output_length = bob.decrypt_max_plaintext_length(message_2_length);
output_length = bob.decrypt_max_plaintext_length(
message_2, message_2_length
);
std::uint8_t output_1[output_length];
decrypt_length = bob.decrypt(
message_2, message_2_length,
......@@ -161,7 +163,9 @@ std::size_t encrypt_length, decrypt_length;
assert_equals(plaintext_2_length, decrypt_length);
assert_equals(plaintext_2, output_1, decrypt_length);
output_length = bob.decrypt_max_plaintext_length(message_1_length);
output_length = bob.decrypt_max_plaintext_length(
message_1, message_1_length
);
std::uint8_t output_2[output_length];
decrypt_length = bob.decrypt(
message_1, message_1_length,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment