From 816435a86097a6609cb6e5ad422083bc49b19632 Mon Sep 17 00:00:00 2001
From: Mark Haines <mark.haines@matrix.org>
Date: Thu, 11 Jun 2015 14:20:35 +0100
Subject: [PATCH] Move AES specific details behind a cipher interface

---
 include/axolotl/cipher.hh  | 127 +++++++++++++++++++++
 include/axolotl/list.hh    |   2 +-
 include/axolotl/message.hh |  22 ++--
 include/axolotl/ratchet.hh |  32 +++---
 src/cipher.cpp             | 125 +++++++++++++++++++++
 src/message.cpp            |  39 ++++---
 src/ratchet.cpp            | 218 ++++++++++++++++---------------------
 tests/test_message.cpp     |  11 +-
 tests/test_ratchet.cpp     |  16 ++-
 9 files changed, 406 insertions(+), 186 deletions(-)
 create mode 100644 include/axolotl/cipher.hh
 create mode 100644 src/cipher.cpp

diff --git a/include/axolotl/cipher.hh b/include/axolotl/cipher.hh
new file mode 100644
index 0000000..93974fd
--- /dev/null
+++ b/include/axolotl/cipher.hh
@@ -0,0 +1,127 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef AXOLOTL_CIPHER_HH_
+#define AXOLOTL_CIPHER_HH_
+
+#include <cstdint>
+
+namespace axolotl {
+
+class Cipher {
+public:
+    virtual ~Cipher();
+
+    /**
+     * Returns the length of the message authentication code that will be
+     * appended to the output.
+     */
+    virtual std::size_t mac_length() const = 0;
+
+    /**
+     * Returns the length of cipher-text for a given length of plain-text.
+     */
+    virtual std::size_t encrypt_ciphertext_length(
+        std::size_t plaintext_length
+    ) const = 0;
+
+    /*
+     * Encrypts the plain-text into the output buffer and authenticates the
+     * contents of the output buffer covering both cipher-text and any other
+     * associated data in the output buffer.
+     *
+     *  |---------------------------------------output_length-->|
+     *  output  |--ciphertext_length-->|       |---mac_length-->|
+     *          ciphertext
+     *
+     * Returns std::size_t(-1) if the length of the cipher-text or the output
+     * buffer is too small. Otherwise returns the length of the output buffer.
+     */
+    virtual std::size_t encrypt(
+        std::uint8_t const * key, std::size_t key_length,
+        std::uint8_t const * plaintext, std::size_t plaintext_length,
+        std::uint8_t * ciphertext, std::size_t ciphertext_length,
+        std::uint8_t * output, std::size_t output_length
+    ) const = 0;
+
+    /**
+     * Returns the maximum length of plain-text that a given length of
+     * cipher-text can contain.
+     */
+    virtual std::size_t decrypt_max_plaintext_length(
+        std::size_t ciphertext_length
+    ) const = 0;
+
+    /**
+     * Authenticates the input and decrypts the cipher-text into the plain-text
+     * buffer.
+     *
+     *  |----------------------------------------input_length-->|
+     *  input   |--ciphertext_length-->|       |---mac_length-->|
+     *          ciphertext
+     *
+     *  Returns std::size_t(-1) if the length of the plain-text buffer is too
+     *  small or if the authentication check fails. Otherwise returns the length
+     *  of the plain text.
+     */
+    virtual std::size_t decrypt(
+       std::uint8_t const * key, std::size_t key_length,
+       std::uint8_t const * input, std::size_t input_length,
+       std::uint8_t const * ciphertext, std::size_t ciphertext_length,
+       std::uint8_t * plaintext, std::size_t max_plaintext_length
+    ) const = 0;
+};
+
+
+class CipherAesSha256 : public Cipher {
+public:
+    CipherAesSha256(
+        std::uint8_t const * kdf_info, std::size_t kdf_info_length
+    );
+
+    virtual std::size_t mac_length() const;
+
+    virtual std::size_t encrypt_ciphertext_length(
+        std::size_t plaintext_length
+    ) const;
+
+    virtual std::size_t encrypt(
+        std::uint8_t const * key, std::size_t key_length,
+        std::uint8_t const * plaintext, std::size_t plaintext_length,
+        std::uint8_t * ciphertext, std::size_t ciphertext_length,
+        std::uint8_t * output, std::size_t output_length
+    ) const;
+
+    virtual std::size_t decrypt_max_plaintext_length(
+        std::size_t ciphertext_length
+    ) const;
+
+    virtual std::size_t decrypt(
+        std::uint8_t const * key, std::size_t key_length,
+        std::uint8_t const * input, std::size_t input_length,
+        std::uint8_t const * ciphertext, std::size_t ciphertext_length,
+        std::uint8_t * plaintext, std::size_t max_plaintext_length
+    ) const;
+
+private:
+    std::uint8_t const * kdf_info;
+    std::size_t kdf_info_length;
+};
+
+
+} // namespace
+
+
+#endif /* AXOLOTL_CIPHER_HH_ */
diff --git a/include/axolotl/list.hh b/include/axolotl/list.hh
index d1407b8..ae8900c 100644
--- a/include/axolotl/list.hh
+++ b/include/axolotl/list.hh
@@ -92,7 +92,7 @@ public:
     }
 
     List<T, max_size> & operator=(List<T, max_size> const & other) {
-        if (this = &other) {
+        if (this == &other) {
             return *this;
         }
         T * this_pos = _data;
diff --git a/include/axolotl/message.hh b/include/axolotl/message.hh
index 5cd4211..cfbb715 100644
--- a/include/axolotl/message.hh
+++ b/include/axolotl/message.hh
@@ -30,22 +30,17 @@ std::size_t encode_message_length(
 
 
 struct MessageWriter {
-    std::size_t body_length;
     std::uint8_t * ratchet_key;
     std::uint8_t * ciphertext;
-    std::uint8_t * mac;
 };
 
 
 struct MessageReader {
-    std::size_t body_length;
     std::uint8_t version;
     std::uint32_t counter;
-    std::size_t ratchet_key_length;
-    std::size_t ciphertext_length;
-    std::uint8_t const * ratchet_key;
-    std::uint8_t const * ciphertext;
-    std::uint8_t const * mac;
+    std::uint8_t const * input; std::size_t input_length;
+    std::uint8_t const * ratchet_key; std::size_t ratchet_key_length;
+    std::uint8_t const * ciphertext; std::size_t ciphertext_length;
 };
 
 
@@ -53,7 +48,9 @@ struct MessageReader {
  * Writes the message headers into the output buffer.
  * Returns a writer struct populated with pointers into the output buffer.
  */
-MessageWriter encode_message(
+
+void encode_message(
+    MessageWriter & writer,
     std::uint8_t version,
     std::uint32_t counter,
     std::size_t ratchet_key_length,
@@ -64,10 +61,11 @@ MessageWriter encode_message(
 
 /**
  * Reads the message headers from the input buffer.
- * Returns a reader struct populated with pointers into the input buffer.
- * On failure the returned body_length will be 0.
+ * Populates the reader struct with pointers into the input buffer.
+ * On failure returns std::size_t(-1).
  */
-MessageReader decode_message(
+std::size_t decode_message(
+    MessageReader & reader,
     std::uint8_t const * input, std::size_t input_length,
     std::size_t mac_length
 );
diff --git a/include/axolotl/ratchet.hh b/include/axolotl/ratchet.hh
index cf41359..f4eeafa 100644
--- a/include/axolotl/ratchet.hh
+++ b/include/axolotl/ratchet.hh
@@ -18,6 +18,8 @@
 
 namespace axolotl {
 
+class Cipher;
+
 typedef std::uint8_t SharedKey[32];
 
 
@@ -29,9 +31,7 @@ struct ChainKey {
 
 struct MessageKey {
     std::uint32_t index;
-    Aes256Key cipher_key;
-    SharedKey mac_key;
-    Aes256Iv iv;
+    SharedKey key;
 };
 
 
@@ -72,21 +72,23 @@ struct KdfInfo {
     std::size_t root_info_length;
     std::uint8_t const * ratchet_info;
     std::size_t ratchet_info_length;
-    std::uint8_t const * message_info;
-    std::size_t message_info_length;
 };
 
 
 struct Session {
 
     Session(
-        KdfInfo const & kdf_info
+        KdfInfo const & kdf_info,
+        Cipher const & ratchet_cipher
     );
 
-    /** A some strings identifing the application to feed into the KDF. */
-    const KdfInfo &kdf_info;
+    /** A some strings identifying the application to feed into the KDF. */
+    KdfInfo const & kdf_info;
+
+    /** The AEAD cipher to use for encrypting messages. */
+    Cipher const & ratchet_cipher;
 
-    /** The last error that happened encypting or decrypting a message. */
+    /** The last error that happened encrypting or decrypting a message. */
     ErrorCode last_error;
 
     /** The root key is used to generate chain keys from the ephemeral keys.
@@ -98,7 +100,7 @@ struct Session {
      * with a new empheral key when we next send a message. */
     List<SenderChain, 1> sender_chain;
 
-    /** The receiver chain is used to decrypt recieved messages. We store the
+    /** The receiver chain is used to decrypt received messages. We store the
      * last few chains so we can decrypt any out of order messages we haven't
      * received yet. */
     List<ReceiverChain, MAX_RECEIVER_CHAINS> receiver_chains;
@@ -114,7 +116,7 @@ struct Session {
         Curve25519PublicKey const & their_ratchet_key
     );
 
-    /** Intialise the session using a shared secret and the public/private key
+    /** Initialise the session using a shared secret and the public/private key
      * pair for the first ratchet key */
     void initialise_as_alice(
         std::uint8_t const * shared_secret, std::size_t shared_secret_length,
@@ -150,7 +152,7 @@ struct Session {
      * generate a new ephemeral key, or will be 0 bytes otherwise.*/
     std::size_t encrypt_random_length();
 
-    /** Encrypt some plaintext. Returns the length of the encrypted message
+    /** Encrypt some plain-text. Returns the length of the encrypted message
      * or std::size_t(-1) on failure. On failure last_error will be set with
      * an error code. The last_error will be NOT_ENOUGH_RANDOM if the number
      * of random bytes is too small. The last_error will be
@@ -161,16 +163,16 @@ struct Session {
         std::uint8_t * output, std::size_t max_output_length
     );
 
-    /** An upper bound on the number of bytes of plaintext the decrypt method
+    /** An upper bound on the number of bytes of plain-text the decrypt method
      * will write for a given input message length. */
     std::size_t decrypt_max_plaintext_length(
         std::size_t input_length
     );
 
-    /** Decrypt a message. Returns the length of the decrypted plaintext or
+    /** Decrypt a message. Returns the length of the decrypted plain-text or
      * std::size_t(-1) on failure. On failure last_error will be set with an
      * error code. The last_error will be OUTPUT_BUFFER_TOO_SMALL if the
-     * plaintext buffer is too small. The last_error will be
+     * plain-text buffer is too small. The last_error will be
      * BAD_MESSAGE_VERSION if the message was encrypted with an unsupported
      * version of the protocol. The last_error will be BAD_MESSAGE_FORMAT if
      * the message headers could not be decoded. The last_error will be
diff --git a/src/cipher.cpp b/src/cipher.cpp
new file mode 100644
index 0000000..86cde88
--- /dev/null
+++ b/src/cipher.cpp
@@ -0,0 +1,125 @@
+#include "axolotl/cipher.hh"
+#include "axolotl/crypto.hh"
+#include "axolotl/memory.hh"
+#include <cstring>
+
+axolotl::Cipher::~Cipher() {
+
+}
+
+namespace {
+
+static const std::size_t SHA256_LENGTH = 32;
+
+struct DerivedKeys {
+    axolotl::Aes256Key aes_key;
+    std::uint8_t mac_key[SHA256_LENGTH];
+    axolotl::Aes256Iv aes_iv;
+};
+
+
+static void derive_keys(
+    std::uint8_t const * kdf_info, std::size_t kdf_info_length,
+    std::uint8_t const * key, std::size_t key_length,
+    DerivedKeys & keys
+) {
+    std::uint8_t derived_secrets[80];
+    axolotl::hkdf_sha256(
+        key, key_length,
+        NULL, 0,
+        kdf_info, kdf_info_length,
+        derived_secrets, sizeof(derived_secrets)
+    );
+    std::memcpy(keys.aes_key.key, derived_secrets, 32);
+    std::memcpy(keys.mac_key, derived_secrets + 32, 32);
+    std::memcpy(keys.aes_iv.iv, derived_secrets + 64, 16);
+    axolotl::unset(derived_secrets);
+}
+
+static const std::size_t MAC_LENGTH = 8;
+
+} // namespace
+
+
+axolotl::CipherAesSha256::CipherAesSha256(
+    std::uint8_t const * kdf_info, std::size_t kdf_info_length
+) : kdf_info(kdf_info), kdf_info_length(kdf_info_length) {
+
+}
+
+
+std::size_t axolotl::CipherAesSha256::mac_length() const {
+    return MAC_LENGTH;
+}
+
+
+std::size_t axolotl::CipherAesSha256::encrypt_ciphertext_length(
+    std::size_t plaintext_length
+) const {
+    return axolotl::aes_encrypt_cbc_length(plaintext_length);
+}
+
+
+std::size_t axolotl::CipherAesSha256::encrypt(
+    std::uint8_t const * key, std::size_t key_length,
+    std::uint8_t const * plaintext, std::size_t plaintext_length,
+    std::uint8_t * ciphertext, std::size_t ciphertext_length,
+    std::uint8_t * output, std::size_t output_length
+) const {
+    if (encrypt_ciphertext_length(plaintext_length) < ciphertext_length) {
+        return std::size_t(-1);
+    }
+    struct DerivedKeys keys;
+    std::uint8_t mac[SHA256_LENGTH];
+
+    derive_keys(kdf_info, kdf_info_length, key, key_length, keys);
+
+    axolotl::aes_encrypt_cbc(
+        keys.aes_key, keys.aes_iv, plaintext, plaintext_length, ciphertext
+    );
+
+    axolotl::hmac_sha256(
+        keys.mac_key, SHA256_LENGTH, output, output_length - MAC_LENGTH, mac
+    );
+
+    std::memcpy(output + output_length - MAC_LENGTH, mac, MAC_LENGTH);
+
+    axolotl::unset(keys);
+    return output_length;
+}
+
+
+std::size_t axolotl::CipherAesSha256::decrypt_max_plaintext_length(
+    std::size_t ciphertext_length
+) const {
+    return ciphertext_length;
+}
+
+std::size_t axolotl::CipherAesSha256::decrypt(
+     std::uint8_t const * key, std::size_t key_length,
+     std::uint8_t const * input, std::size_t input_length,
+     std::uint8_t const * ciphertext, std::size_t ciphertext_length,
+     std::uint8_t * plaintext, std::size_t max_plaintext_length
+) const {
+    DerivedKeys keys;
+    std::uint8_t mac[SHA256_LENGTH];
+
+    derive_keys(kdf_info, kdf_info_length, key, key_length, keys);
+
+    axolotl::hmac_sha256(
+        keys.mac_key, SHA256_LENGTH, input, input_length - MAC_LENGTH, mac
+    );
+
+    std::uint8_t const * input_mac = input + input_length - MAC_LENGTH;
+    if (!axolotl::is_equal(input_mac, mac, MAC_LENGTH)) {
+        axolotl::unset(keys);
+        return std::size_t(-1);
+    }
+
+    std::size_t plaintext_length = axolotl::aes_decrypt_cbc(
+        keys.aes_key, keys.aes_iv, ciphertext, ciphertext_length, plaintext
+    );
+
+    axolotl::unset(keys);
+    return plaintext_length;
+}
diff --git a/src/message.cpp b/src/message.cpp
index 289faa3..46cadd8 100644
--- a/src/message.cpp
+++ b/src/message.cpp
@@ -94,55 +94,52 @@ std::size_t axolotl::encode_message_length(
     length += 1 + varstring_length(ratchet_key_length);
     length += 1 + varint_length(counter);
     length += 1 + varstring_length(ciphertext_length);
-    return length + mac_length;
+    length += mac_length;
+    return length;
 }
 
 
-axolotl::MessageWriter axolotl::encode_message(
+void axolotl::encode_message(
+    axolotl::MessageWriter & writer,
     std::uint8_t version,
     std::uint32_t counter,
     std::size_t ratchet_key_length,
     std::size_t ciphertext_length,
     std::uint8_t * output
 ) {
-    axolotl::MessageWriter result;
     std::uint8_t * pos = output;
     *(pos++) = version;
     *(pos++) = COUNTER_TAG;
     pos = varint_encode(pos, counter);
     *(pos++) = RATCHET_KEY_TAG;
     pos = varint_encode(pos, ratchet_key_length);
-    result.ratchet_key = pos;
+    writer.ratchet_key = pos;
     pos += ratchet_key_length;
     *(pos++) = CIPHERTEXT_TAG;
     pos = varint_encode(pos, ciphertext_length);
-    result.ciphertext = pos;
+    writer.ciphertext = pos;
     pos += ciphertext_length;
-    result.body_length = pos - output;
-    result.mac = pos;
-    return result;
 }
 
 
-axolotl::MessageReader axolotl::decode_message(
+std::size_t axolotl::decode_message(
+    axolotl::MessageReader & reader,
     std::uint8_t const * input, std::size_t input_length,
     std::size_t mac_length
 ) {
-    axolotl::MessageReader result;
-    result.body_length = 0;
     std::uint8_t const * pos = input;
     std::uint8_t const * end = input + input_length - mac_length;
     std::uint8_t flags = 0;
-    result.mac = end;
+    std::size_t result = std::size_t(-1);
     if (pos == end) return result;
-    result.version = *(pos++);
+    reader.version = *(pos++);
     while (pos != end) {
         uint8_t tag = *(pos);
         if (tag == COUNTER_TAG) {
             ++pos;
             std::uint8_t const * counter_start = pos;
             pos = varint_skip(pos, end);
-            result.counter = varint_decode<std::uint32_t>(counter_start, pos);
+            reader.counter = varint_decode<std::uint32_t>(counter_start, pos);
             flags |= 1;
         } else if (tag == RATCHET_KEY_TAG) {
             ++pos;
@@ -150,8 +147,8 @@ axolotl::MessageReader axolotl::decode_message(
             pos = varint_skip(pos, end);
             std::size_t len = varint_decode<std::size_t>(len_start, pos);
             if (len > end - pos) return result;
-            result.ratchet_key_length = len;
-            result.ratchet_key = pos;
+            reader.ratchet_key_length = len;
+            reader.ratchet_key = pos;
             pos += len;
             flags |= 2;
         } else if (tag == CIPHERTEXT_TAG) {
@@ -160,8 +157,8 @@ axolotl::MessageReader axolotl::decode_message(
             pos = varint_skip(pos, end);
             std::size_t len = varint_decode<std::size_t>(len_start, pos);
             if (len > end - pos) return result;
-            result.ciphertext_length = len;
-            result.ciphertext = pos;
+            reader.ciphertext_length = len;
+            reader.ciphertext = pos;
             pos += len;
             flags |= 4;
         } else if (tag & 0x7 == 0) {
@@ -174,11 +171,13 @@ axolotl::MessageReader axolotl::decode_message(
             if (len > end - pos) return result;
             pos += len;
         } else {
-            return result;
+            return std::size_t(-1);
         }
     }
     if (flags == 0x7) {
-        result.body_length = end - input;
+        reader.input = input;
+        reader.input_length = input_length;
+        return std::size_t(pos - input);
     }
     return result;
 }
diff --git a/src/ratchet.cpp b/src/ratchet.cpp
index b17e162..cd4f8f7 100644
--- a/src/ratchet.cpp
+++ b/src/ratchet.cpp
@@ -15,13 +15,13 @@
 #include "axolotl/ratchet.hh"
 #include "axolotl/message.hh"
 #include "axolotl/memory.hh"
+#include "axolotl/cipher.hh"
 
 #include <cstring>
 
 namespace {
 
 std::uint8_t PROTOCOL_VERSION = 3;
-std::size_t MAC_LENGTH = 8;
 std::size_t KEY_LENGTH = axolotl::Curve25519PublicKey::LENGTH;
 std::uint8_t MESSAGE_KEY_SEED[1] = {0x01};
 std::uint8_t CHAIN_KEY_SEED[1] = {0x02};
@@ -70,59 +70,43 @@ void create_message_keys(
     axolotl::KdfInfo const & info,
     axolotl::MessageKey & message_key
 ) {
-    axolotl::SharedKey secret;
     axolotl::hmac_sha256(
         chain_key.key, sizeof(chain_key.key),
         MESSAGE_KEY_SEED, sizeof(MESSAGE_KEY_SEED),
-        secret
+        message_key.key
     );
-    std::uint8_t derived_secrets[80];
-    axolotl::hkdf_sha256(
-        secret, sizeof(secret),
-        NULL, 0,
-        info.message_info, info.message_info_length,
-        derived_secrets, sizeof(derived_secrets)
-    );
-    std::memcpy(message_key.cipher_key.key, derived_secrets, 32);
-    std::memcpy(message_key.mac_key, derived_secrets + 32, 32);
-    std::memcpy(message_key.iv.iv, derived_secrets + 64, 16);
     message_key.index = chain_key.index;
-    axolotl::unset(derived_secrets);
-    axolotl::unset(secret);
 }
 
 
-bool verify_mac(
+std::size_t verify_mac_and_decrypt(
+    axolotl::Cipher const & cipher,
     axolotl::MessageKey const & message_key,
-    std::uint8_t const * input,
-    axolotl::MessageReader const & reader
+    axolotl::MessageReader const & reader,
+    std::uint8_t * plaintext, std::size_t max_plaintext_length
 ) {
-    std::uint8_t mac[axolotl::HMAC_SHA256_OUTPUT_LENGTH];
-    axolotl::hmac_sha256(
-        message_key.mac_key, sizeof(message_key.mac_key),
-        input, reader.body_length,
-        mac
+    return cipher.decrypt(
+        message_key.key, sizeof(message_key.key),
+        reader.input, reader.input_length,
+        reader.ciphertext, reader.ciphertext_length,
+        plaintext, max_plaintext_length
     );
-
-    bool result = axolotl::is_equal(mac, reader.mac, MAC_LENGTH);
-    axolotl::unset(mac);
-    return result;
 }
 
 
-bool verify_mac_for_existing_chain(
+std::size_t verify_mac_and_decrypt_for_existing_chain(
     axolotl::Session const & session,
     axolotl::ChainKey const & chain,
-    std::uint8_t const * input,
-    axolotl::MessageReader const & reader
+    axolotl::MessageReader const & reader,
+    std::uint8_t * plaintext, std::size_t max_plaintext_length
 ) {
     if (reader.counter < chain.index) {
-        return false;
+        return std::size_t(-1);
     }
 
     /* Limit the number of hashes we're prepared to compute */
     if (reader.counter - chain.index > MAX_MESSAGE_GAP) {
-        return false;
+        return std::size_t(-1);
     }
 
     axolotl::ChainKey new_chain = chain;
@@ -134,16 +118,20 @@ bool verify_mac_for_existing_chain(
     axolotl::MessageKey message_key;
     create_message_keys(new_chain, session.kdf_info, message_key);
 
-    bool result = verify_mac(message_key, input, reader);
+    std::size_t result = verify_mac_and_decrypt(
+        session.ratchet_cipher, message_key, reader,
+        plaintext, max_plaintext_length
+    );
+
     axolotl::unset(new_chain);
     return result;
 }
 
 
-bool verify_mac_for_new_chain(
+std::size_t verify_mac_and_decrypt_for_new_chain(
     axolotl::Session const & session,
-    std::uint8_t const * input,
-    axolotl::MessageReader const & reader
+    axolotl::MessageReader const & reader,
+    std::uint8_t * plaintext, std::size_t max_plaintext_length
 ) {
     axolotl::SharedKey new_root_key;
     axolotl::ReceiverChain new_chain;
@@ -168,8 +156,9 @@ bool verify_mac_for_new_chain(
         new_root_key, new_chain.chain_key
     );
 
-    bool result = verify_mac_for_existing_chain(
-        session, new_chain.chain_key, input, reader
+    std::size_t result = verify_mac_and_decrypt_for_existing_chain(
+        session, new_chain.chain_key, reader,
+        plaintext, max_plaintext_length
     );
     axolotl::unset(new_root_key);
     axolotl::unset(new_chain);
@@ -180,8 +169,11 @@ bool verify_mac_for_new_chain(
 
 
 axolotl::Session::Session(
-    axolotl::KdfInfo const & kdf_info
-) : kdf_info(kdf_info), last_error(axolotl::ErrorCode::SUCCESS) {
+    axolotl::KdfInfo const & kdf_info,
+    Cipher const & ratchet_cipher
+) : kdf_info(kdf_info),
+    ratchet_cipher(ratchet_cipher),
+    last_error(axolotl::ErrorCode::SUCCESS) {
 }
 
 
@@ -232,7 +224,7 @@ std::size_t axolotl::Session::pickle_max_output_length() {
     pickle_length += sender_chain.size() * send_chain_length;
     pickle_length += receiver_chains.size() * recv_chain_length;
     pickle_length += skipped_message_keys.size() * skip_key_length;
-    return axolotl::aes_encrypt_cbc_length(pickle_length) + MAC_LENGTH;
+    return pickle_length;
 }
 
 namespace {
@@ -299,11 +291,10 @@ std::size_t axolotl::Session::pickle(
     }
     for (const axolotl::SkippedMessageKey &key : skipped_message_keys) {
         pos = pickle_counter(pos, key.message_key.index);
-        pos = pickle_bytes(pos, 32, key.message_key.cipher_key.key);
-        pos = pickle_bytes(pos, 32, key.message_key.mac_key);
-        pos = pickle_bytes(pos, 16, key.message_key.iv.iv);
+        pos = pickle_bytes(pos, 32, key.message_key.key);
         pos = pickle_bytes(pos, 32, key.ratchet_key.public_key);
     }
+    return pos - output;
 }
 
 std::size_t axolotl::Session::unpickle(
@@ -352,11 +343,10 @@ std::size_t axolotl::Session::unpickle(
             skipped_message_keys.end()
         );
         pos = unpickle_counter(pos, key.message_key.index);
-        pos = unpickle_bytes(pos, 32, key.message_key.cipher_key.key);
-        pos = unpickle_bytes(pos, 32, key.message_key.mac_key);
-        pos = unpickle_bytes(pos, 16, key.message_key.iv.iv);
+        pos = unpickle_bytes(pos, 32, key.message_key.key);
         pos = unpickle_bytes(pos, 32, key.ratchet_key.public_key);
     }
+    return pos - input;
 }
 
 
@@ -369,7 +359,7 @@ std::size_t axolotl::Session::encrypt_max_output_length(
     }
     std::size_t padded = axolotl::aes_encrypt_cbc_length(plaintext_length);
     return axolotl::encode_message_length(
-        counter, KEY_LENGTH, padded, MAC_LENGTH
+        counter, KEY_LENGTH, padded, ratchet_cipher.mac_length()
     );
 }
 
@@ -384,11 +374,13 @@ std::size_t axolotl::Session::encrypt(
     std::uint8_t const * random, std::size_t random_length,
     std::uint8_t * output, std::size_t max_output_length
 ) {
+    std::size_t output_length = encrypt_max_output_length(plaintext_length);
+
     if (random_length < encrypt_random_length()) {
         last_error = axolotl::ErrorCode::NOT_ENOUGH_RANDOM;
         return std::size_t(-1);
     }
-    if (max_output_length < encrypt_max_output_length(plaintext_length)) {
+    if (max_output_length < output_length) {
         last_error = axolotl::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
         return std::size_t(-1);
     }
@@ -409,32 +401,29 @@ std::size_t axolotl::Session::encrypt(
     create_message_keys(sender_chain[0].chain_key, kdf_info, keys);
     advance_chain_key(sender_chain[0].chain_key, sender_chain[0].chain_key);
 
-    std::size_t padded = axolotl::aes_encrypt_cbc_length(plaintext_length);
+    std::size_t ciphertext_length = ratchet_cipher.encrypt_ciphertext_length(
+        plaintext_length
+    );
     std::uint32_t counter = keys.index;
     Curve25519PublicKey const & ratchet_key = sender_chain[0].ratchet_key;
 
-    axolotl::MessageWriter writer(axolotl::encode_message(
-        PROTOCOL_VERSION, counter, KEY_LENGTH, padded, output
-    ));
+    axolotl::MessageWriter writer;
+
+    axolotl::encode_message(
+        writer, PROTOCOL_VERSION, counter, KEY_LENGTH, ciphertext_length, output
+    );
 
     std::memcpy(writer.ratchet_key, ratchet_key.public_key, KEY_LENGTH);
 
-    axolotl::aes_encrypt_cbc(
-        keys.cipher_key, keys.iv,
+    ratchet_cipher.encrypt(
+        keys.key, sizeof(keys.key),
         plaintext, plaintext_length,
-        writer.ciphertext
+        writer.ciphertext, ciphertext_length,
+        output, output_length
     );
 
-    std::uint8_t mac[axolotl::HMAC_SHA256_OUTPUT_LENGTH];
-    axolotl::hmac_sha256(
-        keys.mac_key, sizeof(keys.mac_key),
-        output, writer.body_length,
-        mac
-    );
-    std::memcpy(writer.mac, mac, MAC_LENGTH);
-
     axolotl::unset(keys);
-    return writer.body_length + MAC_LENGTH;
+    return output_length;
 }
 
 
@@ -454,16 +443,17 @@ std::size_t axolotl::Session::decrypt(
         return std::size_t(-1);
     }
 
-    axolotl::MessageReader reader(axolotl::decode_message(
-        input, input_length, MAC_LENGTH
-    ));
+    axolotl::MessageReader reader;
+    std::size_t body_length = axolotl::decode_message(
+        reader, input, input_length, ratchet_cipher.mac_length()
+    );
 
     if (reader.version != PROTOCOL_VERSION) {
         last_error = axolotl::ErrorCode::BAD_MESSAGE_VERSION;
         return std::size_t(-1);
     }
 
-    if (reader.body_length == 0 || reader.ratchet_key_length != KEY_LENGTH) {
+    if (body_length == size_t(-1) || reader.ratchet_key_length != KEY_LENGTH) {
         last_error = axolotl::ErrorCode::BAD_MESSAGE_FORMAT;
         return std::size_t(-1);
     }
@@ -479,40 +469,30 @@ std::size_t axolotl::Session::decrypt(
         }
     }
 
-    if (!chain) {
-        if (!verify_mac_for_new_chain(*this, input, reader)) {
-            last_error = axolotl::ErrorCode::BAD_MESSAGE_MAC;
-            return std::size_t(-1);
-        }
-    } else {
-        if (chain->chain_key.index > reader.counter) {
-            /* Chain already advanced beyond the key for this message
-             * Check if the message keys are in the skipped key list. */
-            for (axolotl::SkippedMessageKey & skipped : skipped_message_keys) {
-                if (reader.counter == skipped.message_key.index
-                        && 0 == std::memcmp(
-                            skipped.ratchet_key.public_key, reader.ratchet_key,
-                            KEY_LENGTH
-                        )
-                ) {
-                    /* Found the key for this message. Check the MAC. */
-                    if (!verify_mac(skipped.message_key, input, reader)) {
-                        last_error = axolotl::ErrorCode::BAD_MESSAGE_MAC;
-                        return std::size_t(-1);
-                    }
-
-                    std::size_t result = axolotl::aes_decrypt_cbc(
-                        skipped.message_key.cipher_key,
-                        skipped.message_key.iv,
-                        reader.ciphertext, reader.ciphertext_length,
-                        plaintext
-                    );
-
-                    if (result == std::size_t(-1)) {
-                        last_error = axolotl::ErrorCode::BAD_MESSAGE_MAC;
-                        return result;
-                    }
+    std::size_t result = std::size_t(-1);
 
+    if (!chain) {
+        result = verify_mac_and_decrypt_for_new_chain(
+            *this, reader, plaintext, max_plaintext_length
+        );
+    } else if (chain->chain_key.index > reader.counter) {
+        /* Chain already advanced beyond the key for this message
+         * Check if the message keys are in the skipped key list. */
+        for (axolotl::SkippedMessageKey & skipped : skipped_message_keys) {
+            if (reader.counter == skipped.message_key.index
+                    && 0 == std::memcmp(
+                        skipped.ratchet_key.public_key, reader.ratchet_key,
+                        KEY_LENGTH
+                    )
+            ) {
+                /* Found the key for this message. Check the MAC. */
+
+                result = verify_mac_and_decrypt(
+                    ratchet_cipher, skipped.message_key, reader,
+                    plaintext, max_plaintext_length
+                );
+
+                if (result != std::size_t(-1)) {
                     /* Remove the key from the skipped keys now that we've
                      * decoded the message it corresponds to. */
                     axolotl::unset(skipped);
@@ -520,15 +500,16 @@ std::size_t axolotl::Session::decrypt(
                     return result;
                 }
             }
-            /* No matching keys for the message, fail with bad mac */
-            last_error = axolotl::ErrorCode::BAD_MESSAGE_MAC;
-            return std::size_t(-1);
-        } else if (!verify_mac_for_existing_chain(
-               *this, chain->chain_key, input, reader
-        )) {
-            last_error = axolotl::ErrorCode::BAD_MESSAGE_MAC;
-            return std::size_t(-1);
         }
+    } else {
+        result = verify_mac_and_decrypt_for_existing_chain(
+            *this, chain->chain_key, reader, plaintext, max_plaintext_length
+        );
+    }
+
+    if (result == std::size_t(-1)) {
+        last_error = axolotl::ErrorCode::BAD_MESSAGE_MAC;
+        return std::size_t(-1);
     }
 
     if (!chain) {
@@ -555,22 +536,7 @@ std::size_t axolotl::Session::decrypt(
         advance_chain_key(chain->chain_key, chain->chain_key);
     }
 
-    axolotl::MessageKey message_key;
-    create_message_keys(chain->chain_key, kdf_info, message_key);
-    std::size_t result = axolotl::aes_decrypt_cbc(
-        message_key.cipher_key,
-        message_key.iv,
-        reader.ciphertext, reader.ciphertext_length,
-        plaintext
-    );
-    axolotl::unset(message_key);
-
     advance_chain_key(chain->chain_key, chain->chain_key);
 
-    if (result == std::size_t(-1)) {
-        last_error = axolotl::ErrorCode::BAD_MESSAGE_MAC;
-        return std::size_t(-1);
-    } else {
-        return result;
-    }
+    return result;
 }
diff --git a/tests/test_message.cpp b/tests/test_message.cpp
index ca36ff5..9c0ab4a 100644
--- a/tests/test_message.cpp
+++ b/tests/test_message.cpp
@@ -27,9 +27,9 @@ std::uint8_t hmacsha2[9] = "hmacsha2";
 
 TestCase test_case("Message decode test");
 
-axolotl::MessageReader reader(axolotl::decode_message(message1, 35, 8));
+axolotl::MessageReader reader;
+axolotl::decode_message(reader, message1, 35, 8);
 
-assert_equals(std::size_t(27), reader.body_length);
 assert_equals(std::uint8_t(3), reader.version);
 assert_equals(std::uint32_t(1), reader.counter);
 assert_equals(std::size_t(10), reader.ratchet_key_length);
@@ -37,7 +37,6 @@ assert_equals(std::size_t(10), reader.ciphertext_length);
 
 assert_equals(ratchetkey, reader.ratchet_key, 10);
 assert_equals(ciphertext, reader.ciphertext, 10);
-assert_equals(hmacsha2, reader.mac, 8);
 
 
 } /* Message decode test */
@@ -51,12 +50,12 @@ assert_equals(std::size_t(35), length);
 
 std::uint8_t output[length];
 
-axolotl::MessageWriter writer(axolotl::encode_message(3, 1, 10, 10, output));
-assert_equals(std::size_t(27), writer.body_length);
+axolotl::MessageWriter writer;
+axolotl::encode_message(writer, 3, 1, 10, 10, output);
 
 std::memcpy(writer.ratchet_key, ratchetkey, 10);
 std::memcpy(writer.ciphertext, ciphertext, 10);
-std::memcpy(writer.mac, hmacsha2, 8);
+std::memcpy(output + length - 8, hmacsha2, 8);
 
 assert_equals(message2, output, 35);
 
diff --git a/tests/test_ratchet.cpp b/tests/test_ratchet.cpp
index 95391e3..18c22e3 100644
--- a/tests/test_ratchet.cpp
+++ b/tests/test_ratchet.cpp
@@ -13,6 +13,7 @@
  * limitations under the License.
  */
 #include "axolotl/ratchet.hh"
+#include "axolotl/cipher.hh"
 #include "unittest.hh"
 
 
@@ -24,10 +25,13 @@ std::uint8_t message_info[] = "AxolotlMessageKeys";
 
 axolotl::KdfInfo kdf_info = {
     root_info, sizeof(root_info) - 1,
-    ratchet_info, sizeof(ratchet_info - 1),
-    message_info, sizeof(ratchet_info - 1)
+    ratchet_info, sizeof(ratchet_info) - 1
 };
 
+axolotl::CipherAesSha256 cipher(
+    message_info, sizeof(message_info) - 1
+);
+
 std::uint8_t random_bytes[] = "0123456789ABDEF0123456789ABCDEF";
 axolotl::Curve25519KeyPair bob_key;
 axolotl::generate_key(random_bytes, bob_key);
@@ -37,8 +41,8 @@ std::uint8_t shared_secret[] = "A secret";
 { /* Send/Receive test case */
 TestCase test_case("Axolotl Send/Receive");
 
-axolotl::Session alice(kdf_info);
-axolotl::Session bob(kdf_info);
+axolotl::Session alice(kdf_info, cipher);
+axolotl::Session bob(kdf_info, cipher);
 
 alice.initialise_as_bob(shared_secret, sizeof(shared_secret) - 1, bob_key);
 bob.initialise_as_alice(shared_secret, sizeof(shared_secret) - 1, bob_key);
@@ -106,8 +110,8 @@ std::size_t encrypt_length, decrypt_length;
 
 TestCase test_case("Axolotl Out of Order");
 
-axolotl::Session alice(kdf_info);
-axolotl::Session bob(kdf_info);
+axolotl::Session alice(kdf_info, cipher);
+axolotl::Session bob(kdf_info, cipher);
 
 alice.initialise_as_bob(shared_secret, sizeof(shared_secret) - 1, bob_key);
 bob.initialise_as_alice(shared_secret, sizeof(shared_secret) - 1, bob_key);
-- 
GitLab