From 39ad75314b9e28053f568ed6a4109f5d3a9468fe Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <richard@matrix.org>
Date: Wed, 18 May 2016 17:23:09 +0100
Subject: [PATCH] Implement decrypting inbound group messages

Includes creation of inbound sessions, etc
---
 include/olm/error.h                 |   3 +
 include/olm/inbound_group_session.h | 153 +++++++++++++++++++++
 include/olm/message.h               |  24 ++++
 include/olm/olm.h                   |   1 +
 src/inbound_group_session.c         | 199 ++++++++++++++++++++++++++++
 src/message.cpp                     |  42 ++++++
 tests/test_group_session.cpp        |  42 +++++-
 tests/test_message.cpp              |  22 +++
 8 files changed, 480 insertions(+), 6 deletions(-)
 create mode 100644 include/olm/inbound_group_session.h
 create mode 100644 src/inbound_group_session.c

diff --git a/include/olm/error.h b/include/olm/error.h
index 87e019a..3f74992 100644
--- a/include/olm/error.h
+++ b/include/olm/error.h
@@ -32,6 +32,9 @@ enum OlmErrorCode {
     OLM_UNKNOWN_PICKLE_VERSION = 9, /*!< The pickled object is too new */
     OLM_CORRUPTED_PICKLE = 10, /*!< The pickled object couldn't be decoded */
 
+    OLM_BAD_RATCHET_KEY = 11,
+    OLM_BAD_CHAIN_INDEX = 12,
+
     /* remember to update the list of string constants in error.c when updating
      * this list. */
 };
diff --git a/include/olm/inbound_group_session.h b/include/olm/inbound_group_session.h
new file mode 100644
index 0000000..4cf4ac4
--- /dev/null
+++ b/include/olm/inbound_group_session.h
@@ -0,0 +1,153 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef OLM_INBOUND_GROUP_SESSION_H_
+#define OLM_INBOUND_GROUP_SESSION_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct OlmInboundGroupSession OlmInboundGroupSession;
+
+/** get the size of an inbound group session, in bytes. */
+size_t olm_inbound_group_session_size();
+
+/**
+ * Initialise an inbound group session object using the supplied memory
+ * The supplied memory should be at least olm_inbound_group_session_size()
+ * bytes.
+ */
+OlmInboundGroupSession * olm_inbound_group_session(
+    void *memory
+);
+
+/**
+ * A null terminated string describing the most recent error to happen to a
+ * group session */
+const char *olm_inbound_group_session_last_error(
+    const OlmInboundGroupSession *session
+);
+
+/** Clears the memory used to back this group session */
+size_t olm_clear_inbound_group_session(
+    OlmInboundGroupSession *session
+);
+
+/** Returns the number of bytes needed to store an inbound group session */
+size_t olm_pickle_inbound_group_session_length(
+    const OlmInboundGroupSession *session
+);
+
+/**
+ * Stores a group session as a base64 string. Encrypts the session using the
+ * supplied key. Returns the length of the session on success.
+ *
+ * Returns olm_error() on failure. If the pickle output buffer
+ * is smaller than olm_pickle_inbound_group_session_length() then
+ * olm_inbound_group_session_last_error() will be "OUTPUT_BUFFER_TOO_SMALL"
+ */
+size_t olm_pickle_inbound_group_session(
+    OlmInboundGroupSession *session,
+    void const * key, size_t key_length,
+    void * pickled, size_t pickled_length
+);
+
+/**
+ * Loads a group session from a pickled base64 string. Decrypts the session
+ * using the supplied key.
+ *
+ * Returns olm_error() on failure. If the key doesn't match the one used to
+ * encrypt the account then olm_inbound_group_session_last_error() will be
+ * "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then
+ * olm_inbound_group_session_last_error() will be "INVALID_BASE64". The input
+ * pickled buffer is destroyed
+ */
+size_t olm_unpickle_inbound_group_session(
+    OlmInboundGroupSession *session,
+    void const * key, size_t key_length,
+    void * pickled, size_t pickled_length
+);
+
+
+/**
+ * Start a new inbound group session, based on the parameters supplied.
+ *
+ * Returns olm_error() on failure. On failure last_error will be set with an
+ * error code. The last_error will be:
+ *
+ *  * OLM_INVALID_BASE64  if the session_key is not valid base64
+ *  * OLM_BAD_RATCHET_KEY if the session_key is invalid
+ */
+size_t olm_init_inbound_group_session(
+    OlmInboundGroupSession *session,
+    uint32_t message_index,
+
+    /* base64-encoded key */
+    uint8_t const * session_key, size_t session_key_length
+);
+
+/**
+ * Get an upper bound on the number of bytes of plain-text the decrypt method
+ * will write for a given input message length. The actual size could be
+ * different due to padding.
+ *
+ * The input message buffer is destroyed.
+ *
+ * Returns olm_error() on failure.
+ */
+size_t olm_group_decrypt_max_plaintext_length(
+    OlmInboundGroupSession *session,
+    uint8_t * message, size_t message_length
+);
+
+/**
+ * Decrypt a message.
+ *
+ * The input message buffer is destroyed.
+ *
+ * Returns the length of the decrypted plain-text, or olm_error() on failure.
+ *
+ * On failure last_error will be set with an error code. The last_error will
+ * be:
+ *   * OLM_OUTPUT_BUFFER_TOO_SMALL if the plain-text buffer is too small
+ *   * OLM_INVALID_BASE64 if the message is not valid base-64
+ *   * OLM_BAD_MESSAGE_VERSION if the message was encrypted with an unsupported
+ *     version of the protocol
+ *   * OLM_BAD_MESSAGE_FORMAT if the message headers could not be decoded
+ *   * OLM_BAD_MESSAGE_MAC if the message could not be verified
+ *   * OLM_BAD_CHAIN_INDEX if we do not have a ratchet key corresponding to the
+ *     message's index (ie, it was sent before the ratchet key was shared with
+ *     us)
+ */
+size_t olm_group_decrypt(
+    OlmInboundGroupSession *session,
+
+    /* input; note that it will be overwritten with the base64-decoded
+       message. */
+    uint8_t * message, size_t message_length,
+
+    /* output */
+    uint8_t * plaintext, size_t max_plaintext_length
+);
+
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif /* OLM_INBOUND_GROUP_SESSION_H_ */
diff --git a/include/olm/message.h b/include/olm/message.h
index 05fb56c..bd7aec3 100644
--- a/include/olm/message.h
+++ b/include/olm/message.h
@@ -65,6 +65,30 @@ void _olm_encode_group_message(
 );
 
 
+struct _OlmDecodeGroupMessageResults {
+    uint8_t version;
+    const uint8_t *session_id;
+    size_t session_id_length;
+    uint32_t chain_index;
+    int has_chain_index;
+    const uint8_t *ciphertext;
+    size_t ciphertext_length;
+};
+
+
+/**
+ * Reads the message headers from the input buffer.
+ */
+void _olm_decode_group_message(
+    const uint8_t *input, size_t input_length,
+    size_t mac_length,
+
+    /* output structure: updated with results */
+    struct _OlmDecodeGroupMessageResults *results
+);
+
+
+
 #ifdef __cplusplus
 } // extern "C"
 #endif
diff --git a/include/olm/olm.h b/include/olm/olm.h
index 00e1f63..dbaf71e 100644
--- a/include/olm/olm.h
+++ b/include/olm/olm.h
@@ -19,6 +19,7 @@
 #include <stddef.h>
 #include <stdint.h>
 
+#include "olm/inbound_group_session.h"
 #include "olm/outbound_group_session.h"
 
 #ifdef __cplusplus
diff --git a/src/inbound_group_session.c b/src/inbound_group_session.c
new file mode 100644
index 0000000..4796414
--- /dev/null
+++ b/src/inbound_group_session.c
@@ -0,0 +1,199 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "olm/inbound_group_session.h"
+
+#include <string.h>
+
+#include "olm/base64.h"
+#include "olm/cipher.h"
+#include "olm/error.h"
+#include "olm/megolm.h"
+#include "olm/message.h"
+
+#define OLM_PROTOCOL_VERSION     3
+
+struct OlmInboundGroupSession {
+    /** our earliest known ratchet value */
+    Megolm initial_ratchet;
+
+    /** The most recent ratchet value */
+    Megolm latest_ratchet;
+
+    enum OlmErrorCode last_error;
+};
+
+size_t olm_inbound_group_session_size() {
+    return sizeof(OlmInboundGroupSession);
+}
+
+OlmInboundGroupSession * olm_inbound_group_session(
+    void *memory
+) {
+    OlmInboundGroupSession *session = memory;
+    olm_clear_inbound_group_session(session);
+    return session;
+}
+
+const char *olm_inbound_group_session_last_error(
+    const OlmInboundGroupSession *session
+) {
+    return _olm_error_to_string(session->last_error);
+}
+
+size_t olm_clear_inbound_group_session(
+    OlmInboundGroupSession *session
+) {
+    memset(session, 0, sizeof(OlmInboundGroupSession));
+    return sizeof(OlmInboundGroupSession);
+}
+
+size_t olm_init_inbound_group_session(
+    OlmInboundGroupSession *session,
+    uint32_t message_index,
+    const uint8_t * session_key, size_t session_key_length
+) {
+    uint8_t key_buf[MEGOLM_RATCHET_LENGTH];
+    size_t raw_length = _olm_decode_base64_length(session_key_length);
+
+    if (raw_length == (size_t)-1) {
+        session->last_error = OLM_INVALID_BASE64;
+        return (size_t)-1;
+    }
+
+    if (raw_length != MEGOLM_RATCHET_LENGTH) {
+        session->last_error = OLM_BAD_RATCHET_KEY;
+        return (size_t)-1;
+    }
+
+    _olm_decode_base64(session_key, session_key_length, key_buf);
+    megolm_init(&session->initial_ratchet, key_buf, message_index);
+    megolm_init(&session->latest_ratchet, key_buf, message_index);
+    memset(key_buf, 0, MEGOLM_RATCHET_LENGTH);
+
+    return 0;
+}
+
+size_t olm_group_decrypt_max_plaintext_length(
+    OlmInboundGroupSession *session,
+    uint8_t * message, size_t message_length
+) {
+    size_t r;
+    const struct _olm_cipher *cipher = megolm_cipher();
+    struct _OlmDecodeGroupMessageResults decoded_results;
+
+    r = _olm_decode_base64(message, message_length, message);
+    if (r == (size_t)-1) {
+        session->last_error = OLM_INVALID_BASE64;
+        return r;
+    }
+
+    _olm_decode_group_message(
+        message, message_length,
+        cipher->ops->mac_length(cipher),
+        &decoded_results);
+
+    if (decoded_results.version != OLM_PROTOCOL_VERSION) {
+        session->last_error = OLM_BAD_MESSAGE_VERSION;
+        return (size_t)-1;
+    }
+
+    if (!decoded_results.ciphertext) {
+        session->last_error = OLM_BAD_MESSAGE_FORMAT;
+        return (size_t)-1;
+    }
+
+    return cipher->ops->decrypt_max_plaintext_length(
+        cipher, decoded_results.ciphertext_length);
+}
+
+
+size_t olm_group_decrypt(
+    OlmInboundGroupSession *session,
+    uint8_t * message, size_t message_length,
+    uint8_t * plaintext, size_t max_plaintext_length
+) {
+    struct _OlmDecodeGroupMessageResults decoded_results;
+    const struct _olm_cipher *cipher = megolm_cipher();
+    size_t max_length, raw_message_length, r;
+    Megolm *megolm;
+    Megolm tmp_megolm;
+
+    raw_message_length = _olm_decode_base64(message, message_length, message);
+    if (raw_message_length == (size_t)-1) {
+        session->last_error = OLM_INVALID_BASE64;
+        return (size_t)-1;
+    }
+
+    _olm_decode_group_message(
+        message, raw_message_length,
+        cipher->ops->mac_length(cipher),
+        &decoded_results);
+
+    if (decoded_results.version != OLM_PROTOCOL_VERSION) {
+        session->last_error = OLM_BAD_MESSAGE_VERSION;
+        return (size_t)-1;
+    }
+
+    if (!decoded_results.has_chain_index || !decoded_results.session_id
+        || !decoded_results.ciphertext
+    ) {
+        session->last_error = OLM_BAD_MESSAGE_FORMAT;
+        return (size_t)-1;
+    }
+
+    max_length = cipher->ops->decrypt_max_plaintext_length(
+        cipher,
+        decoded_results.ciphertext_length
+    );
+    if (max_plaintext_length < max_length) {
+        session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+        return (size_t)-1;
+    }
+
+    /* pick a megolm instance to use. If we're at or beyond the latest ratchet
+     * value, use that */
+    if ((int32_t)(decoded_results.chain_index - session->latest_ratchet.counter) >= 0) {
+        megolm = &session->latest_ratchet;
+    } else if ((int32_t)(decoded_results.chain_index - session->initial_ratchet.counter) < 0) {
+        /* the counter is before our intial ratchet - we can't decode this. */
+        session->last_error = OLM_BAD_CHAIN_INDEX;
+        return (size_t)-1;
+    } else {
+        /* otherwise, start from the initial megolm. Take a copy so that we
+         * don't overwrite the initial megolm */
+        tmp_megolm = session->initial_ratchet;
+        megolm = &tmp_megolm;
+    }
+
+    megolm_advance_to(megolm, decoded_results.chain_index);
+
+    /* now try checking the mac, and decrypting */
+    r = cipher->ops->decrypt(
+        cipher,
+        megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH,
+        message, raw_message_length,
+        decoded_results.ciphertext, decoded_results.ciphertext_length,
+        plaintext, max_plaintext_length
+    );
+
+    memset(&tmp_megolm, 0, sizeof(tmp_megolm));
+    if (r == (size_t)-1) {
+        session->last_error = OLM_BAD_MESSAGE_MAC;
+        return r;
+    }
+
+    return r;
+}
diff --git a/src/message.cpp b/src/message.cpp
index df0c7bb..ec44262 100644
--- a/src/message.cpp
+++ b/src/message.cpp
@@ -363,3 +363,45 @@ void _olm_encode_group_message(
     pos = encode(pos, COUNTER_TAG, chain_index);
     pos = encode(pos, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length);
 }
+
+void _olm_decode_group_message(
+    const uint8_t *input, size_t input_length,
+    size_t mac_length,
+    struct _OlmDecodeGroupMessageResults *results
+) {
+    std::uint8_t const * pos = input;
+    std::uint8_t const * end = input + input_length - mac_length;
+    std::uint8_t const * unknown = nullptr;
+
+    results->session_id = nullptr;
+    results->session_id_length = 0;
+    bool has_chain_index = false;
+    results->chain_index = 0;
+    results->ciphertext = nullptr;
+    results->ciphertext_length = 0;
+
+    if (pos == end) return;
+    if (input_length < mac_length) return;
+    results->version = *(pos++);
+
+    while (pos != end) {
+        pos = decode(
+            pos, end, GROUP_SESSION_ID_TAG,
+            results->session_id, results->session_id_length
+        );
+        pos = decode(
+            pos, end, COUNTER_TAG,
+            results->chain_index, has_chain_index
+        );
+        pos = decode(
+            pos, end, CIPHERTEXT_TAG,
+            results->ciphertext, results->ciphertext_length
+        );
+        if (unknown == pos) {
+            pos = skip_unknown(pos, end);
+        }
+        unknown = pos;
+    }
+
+    results->has_chain_index = (int)has_chain_index;
+}
diff --git a/tests/test_group_session.cpp b/tests/test_group_session.cpp
index b9fe1ef..5bbdc9d 100644
--- a/tests/test_group_session.cpp
+++ b/tests/test_group_session.cpp
@@ -12,6 +12,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "olm/inbound_group_session.h"
 #include "olm/outbound_group_session.h"
 #include "unittest.hh"
 
@@ -19,11 +20,10 @@
 int main() {
 
 {
-
     TestCase test_case("Pickle outbound group");
 
     size_t size = olm_outbound_group_session_size();
-    void *memory = alloca(size);
+    uint8_t memory[size];
     OlmOutboundGroupSession *session = olm_outbound_group_session(memory);
 
     size_t pickle_length = olm_pickle_outbound_group_session_length(session);
@@ -61,9 +61,9 @@ int main() {
         "0123456789ABDEF0123456789ABCDEF";
 
 
-
+    /* build the outbound session */
     size_t size = olm_outbound_group_session_size();
-    void *memory = alloca(size);
+    uint8_t memory[size];
     OlmOutboundGroupSession *session = olm_outbound_group_session(memory);
 
     assert_equals((size_t)132,
@@ -73,18 +73,48 @@ int main() {
         session, random_bytes, sizeof(random_bytes));
     assert_equals((size_t)0, res);
 
+    assert_equals(0U, olm_outbound_group_session_message_index(session));
+    size_t session_key_len = olm_outbound_group_session_key_length(session);
+    uint8_t session_key[session_key_len];
+    olm_outbound_group_session_key(session, session_key, session_key_len);
+
+
+    /* encode the message */
     uint8_t plaintext[] = "Message";
     size_t plaintext_length = sizeof(plaintext) - 1;
 
     size_t msglen = olm_group_encrypt_message_length(
         session, plaintext_length);
 
-    uint8_t *msg = (uint8_t *)alloca(msglen);
+    uint8_t msg[msglen];
     res = olm_group_encrypt(session, plaintext, plaintext_length,
                             msg, msglen);
     assert_equals(msglen, res);
+    assert_equals(1U, olm_outbound_group_session_message_index(session));
+
+
+    /* build the inbound session */
+    size = olm_inbound_group_session_size();
+    uint8_t inbound_session_memory[size];
+    OlmInboundGroupSession *inbound_session =
+        olm_inbound_group_session(inbound_session_memory);
+
+    res = olm_init_inbound_group_session(
+        inbound_session, 0U, session_key, session_key_len);
+    assert_equals((size_t)0, res);
 
-    // TODO: decode the message
+    /* decode the message */
+
+    /* olm_group_decrypt_max_plaintext_length destroys the input so we have to
+       copy it. */
+    uint8_t msgcopy[msglen];
+    memcpy(msgcopy, msg, msglen);
+    size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen);
+    uint8_t plaintext_buf[size];
+    res = olm_group_decrypt(inbound_session, msg, msglen,
+                            plaintext_buf, size);
+    assert_equals(plaintext_length, res);
+    assert_equals(plaintext, plaintext_buf, res);
 }
 
 }
diff --git a/tests/test_message.cpp b/tests/test_message.cpp
index e2385ea..5fec9e0 100644
--- a/tests/test_message.cpp
+++ b/tests/test_message.cpp
@@ -97,4 +97,26 @@ assert_equals(message2, output, 35);
     assert_equals(output+sizeof(expected)-1, ciphertext_ptr);
 } /* group message encode test */
 
+{
+    TestCase test_case("Group message decode test");
+
+    struct _OlmDecodeGroupMessageResults results;
+    std::uint8_t message[] =
+        "\x03"
+        "\x2A\x09sessionid"
+        "\x10\xc8\x01"
+        "\x22\x0A" "ciphertext"
+        "hmacsha2";
+
+    const uint8_t expected_session_id[] = "sessionid";
+
+    _olm_decode_group_message(message, sizeof(message)-1, 8, &results);
+    assert_equals(std::uint8_t(3), results.version);
+    assert_equals(std::size_t(9), results.session_id_length);
+    assert_equals(expected_session_id, results.session_id, 9);
+    assert_equals(1, results.has_chain_index);
+    assert_equals(std::uint32_t(200), results.chain_index);
+    assert_equals(std::size_t(10), results.ciphertext_length);
+    assert_equals(ciphertext, results.ciphertext, 10);
+} /* group message decode test */
 }
-- 
GitLab