Skip to content
Snippets Groups Projects
Commit a08d7063 authored by Mark Haines's avatar Mark Haines
Browse files

Add methods for pickling and unpickling sessions

parent 8123ce62
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ struct Curve25519PublicKey {
struct Curve25519KeyPair : public Curve25519PublicKey {
static const int LENGTH = 64;
std::uint8_t private_key[32];
};
......
......@@ -84,7 +84,7 @@ struct Session {
);
/** A some strings identifing the application to feed into the KDF. */
KdfInfo kdf_info;
const KdfInfo &kdf_info;
/** The last error that happened encypting or decrypting a message. */
ErrorCode last_error;
......@@ -121,6 +121,24 @@ struct Session {
Curve25519KeyPair const & our_ratchet_key
);
/** The number of bytes needed to persist the current session. */
std::size_t pickle_max_output_length();
/** Persists a session as a sequence of bytes, encrypting using a key
* Returns the number of output bytes used. */
std::size_t pickle(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t * output, std::size_t max_output_length
);
/** Loads a session from a sequence of bytes, decrypting using a key.
* Returns 0 on success, or std::size_t(-1) on failure. The last_error
* will be BAD_SESSION_KEY if the supplied key is incorrect. */
std::size_t unpickle(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t * input, std::size_t input_length
);
/** The maximum number of bytes of output the encrypt method will write for
* a given message length. */
std::size_t encrypt_max_output_length(
......
......@@ -223,6 +223,143 @@ void axolotl::Session::initialise_as_alice(
}
std::size_t axolotl::Session::pickle_max_output_length() {
std::size_t counter_length = 4;
std::size_t send_chain_length = counter_length + 64 + 32;
std::size_t recv_chain_length = counter_length + 32 + 32;
std::size_t skip_key_length = counter_length + 32 + 32 + 32 + 16;
std::size_t pickle_length = 3 * counter_length + 32;
pickle_length += sender_chain.size() * send_chain_length;
pickle_length += receiver_chains.size() * recv_chain_length;
pickle_length += skipped_message_keys.size() * skip_key_length;
return axolotl::aes_encrypt_cbc_length(pickle_length) + MAC_LENGTH;
}
namespace {
std::uint8_t * pickle_counter(
std::uint8_t * output, std::uint32_t value
) {
unsigned i = 4;
output += 4;
while (i--) { *(--output) = value; value >>= 8; }
return output + 4;
}
std::uint8_t * unpickle_counter(
std::uint8_t *input, std::uint32_t &value
) {
unsigned i = 4;
value = 0;
while (i--) { value <<= 8; value |= *(input++); }
return input;
}
std::uint8_t * pickle_bytes(
std::uint8_t * output, std::size_t count, std::uint8_t const * bytes
) {
std::memcpy(output, bytes, count);
return output + count;
}
std::uint8_t * unpickle_bytes(
std::uint8_t * input, std::size_t count, std::uint8_t * bytes
) {
std::memcpy(bytes, input, count);
return input + count;
}
} // namespace
std::size_t axolotl::Session::pickle(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t * output, std::size_t max_output_length
) {
std::uint8_t * pos = output;
if (max_output_length < pickle_max_output_length()) {
last_error = axolotl::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
return std::size_t(-1);
}
pos = pickle_counter(pos, sender_chain.size());
pos = pickle_counter(pos, receiver_chains.size());
pos = pickle_counter(pos, skipped_message_keys.size());
pos = pickle_bytes(pos, 32, root_key);
for (const axolotl::SenderChain &chain : sender_chain) {
pos = pickle_counter(pos, chain.chain_key.index);
pos = pickle_bytes(pos, 32, chain.chain_key.key);
pos = pickle_bytes(pos, 32, chain.ratchet_key.public_key);
pos = pickle_bytes(pos, 32, chain.ratchet_key.private_key);
}
for (const axolotl::ReceiverChain &chain : receiver_chains) {
pos = pickle_counter(pos, chain.chain_key.index);
pos = pickle_bytes(pos, 32, chain.chain_key.key);
pos = pickle_bytes(pos, 32, chain.ratchet_key.public_key);
}
for (const axolotl::SkippedMessageKey &key : skipped_message_keys) {
pos = pickle_counter(pos, key.message_key.index);
pos = pickle_bytes(pos, 32, key.message_key.cipher_key.key);
pos = pickle_bytes(pos, 32, key.message_key.mac_key);
pos = pickle_bytes(pos, 16, key.message_key.iv.iv);
pos = pickle_bytes(pos, 32, key.ratchet_key.public_key);
}
}
std::size_t axolotl::Session::unpickle(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t * input, std::size_t input_length
) {
std::uint8_t * pos = input;
std::uint8_t * end = input + input_length;
std::uint32_t send_chain_num, recv_chain_num, skipped_num;
if (end - pos < 4 * 3 + 32) {} // input too small.
pos = unpickle_counter(pos, send_chain_num);
pos = unpickle_counter(pos, recv_chain_num);
pos = unpickle_counter(pos, skipped_num);
pos = unpickle_bytes(pos, 32, root_key);
if (end - pos < send_chain_num * (32 * 3 + 4)) {} // input too small.
while (send_chain_num--) {
axolotl::SenderChain & chain = *sender_chain.insert(
sender_chain.end()
);
pos = unpickle_counter(pos, chain.chain_key.index);
pos = unpickle_bytes(pos, 32, chain.chain_key.key);
pos = unpickle_bytes(pos, 32, chain.ratchet_key.public_key);
pos = unpickle_bytes(pos, 32, chain.ratchet_key.private_key);
}
if (end - pos < recv_chain_num * (32 * 2 + 4)) {} // input too small.
while (recv_chain_num--) {
axolotl::ReceiverChain & chain = *receiver_chains.insert(
receiver_chains.end()
);
pos = unpickle_counter(pos, chain.chain_key.index);
pos = unpickle_bytes(pos, 32, chain.chain_key.key);
pos = unpickle_bytes(pos, 32, chain.ratchet_key.public_key);
}
if (end - pos < skipped_num * (32 * 3 + 16 + 4)) {} // input too small.
while (skipped_num--) {
axolotl::SkippedMessageKey &key = *skipped_message_keys.insert(
skipped_message_keys.end()
);
pos = unpickle_counter(pos, key.message_key.index);
pos = unpickle_bytes(pos, 32, key.message_key.cipher_key.key);
pos = unpickle_bytes(pos, 32, key.message_key.mac_key);
pos = unpickle_bytes(pos, 16, key.message_key.iv.iv);
pos = unpickle_bytes(pos, 32, key.ratchet_key.public_key);
}
}
std::size_t axolotl::Session::encrypt_max_output_length(
std::size_t plaintext_length
) {
......
......@@ -26,7 +26,11 @@ source_files = glob.glob("src/*.cpp")
compile_args = "g++ -g -O0 -Itests/include -Iinclude -Ilib --std=c++11".split()
compile_args += source_files
def run(args):
print " ".join(args)
subprocess.check_call(args)
for test_file in test_files:
exe_file = "build/" + test_file[5:-4]
subprocess.check_call(compile_args + [test_file, "-o", exe_file])
subprocess.check_call([exe_file])
run(compile_args + [test_file, "-o", exe_file])
run([exe_file])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment