From 42a300fc62a2d10fc14868ac6135d3da3857469f Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <richard@matrix.org>
Date: Tue, 17 May 2016 18:48:16 +0100
Subject: [PATCH] Factor out pickle_encoding from olm.cpp

We don't need to have all of the top-level pickling functions in olm.cpp;
factor out the utilities to support it to pickle_encoding.cpp (and make sure
that they have plain-C bindings).
---
 include/olm/pickle_encoding.h | 76 +++++++++++++++++++++++++++
 src/olm.cpp                   | 97 +++++------------------------------
 src/pickle_encoding.c         | 92 +++++++++++++++++++++++++++++++++
 3 files changed, 181 insertions(+), 84 deletions(-)
 create mode 100644 include/olm/pickle_encoding.h
 create mode 100644 src/pickle_encoding.c

diff --git a/include/olm/pickle_encoding.h b/include/olm/pickle_encoding.h
new file mode 100644
index 0000000..03611df
--- /dev/null
+++ b/include/olm/pickle_encoding.h
@@ -0,0 +1,76 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* functions for encrypting and decrypting pickled representations of objects */
+
+#ifndef OLM_PICKLE_ENCODING_H_
+#define OLM_PICKLE_ENCODING_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "olm/error.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+/**
+ * Get the number of bytes needed to encode a pickle of the length given
+ */
+size_t _olm_enc_output_length(size_t raw_length);
+
+/**
+ * Get the point in the output buffer that the raw pickle should be written to.
+ *
+ * In order that we can use the same buffer for the raw pickle, and the encoded
+ * pickle, the raw pickle needs to be written at the end of the buffer. (The
+ * base-64 encoding would otherwise overwrite the end of the input before it
+ * was encoded.)
+ */
+ uint8_t *_olm_enc_output_pos(uint8_t * output, size_t raw_length);
+
+/**
+ * Encrypt and encode the given pickle in-situ.
+ *
+ * The raw pickle should have been written to enc_output_pos(pickle,
+ * raw_length).
+ *
+ * Returns the number of bytes in the encoded pickle.
+ */
+size_t _olm_enc_output(
+    uint8_t const * key, size_t key_length,
+    uint8_t *pickle, size_t raw_length
+);
+
+/**
+ * Decode and decrypt the given pickle in-situ.
+ *
+ * Returns the number of bytes in the decoded pickle, or olm_error() on error,
+ * in which case *last_error will be updated, if last_error is non-NULL.
+ */
+size_t _olm_enc_input(
+    uint8_t const * key, size_t key_length,
+    uint8_t * input, size_t b64_length,
+    enum OlmErrorCode * last_error
+);
+
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif /* OLM_PICKLE_ENCODING_H_ */
diff --git a/src/olm.cpp b/src/olm.cpp
index babe7eb..0a4a734 100644
--- a/src/olm.cpp
+++ b/src/olm.cpp
@@ -16,6 +16,7 @@
 #include "olm/session.hh"
 #include "olm/account.hh"
 #include "olm/cipher.h"
+#include "olm/pickle_encoding.h"
 #include "olm/utility.hh"
 #include "olm/base64.hh"
 #include "olm/memory.hh"
@@ -57,78 +58,6 @@ static std::uint8_t const * from_c(void const * bytes) {
     return reinterpret_cast<std::uint8_t const *>(bytes);
 }
 
-static const struct _olm_cipher_aes_sha_256 PICKLE_CIPHER =
-    OLM_CIPHER_INIT_AES_SHA_256("Pickle");
-
-std::size_t enc_output_length(
-    size_t raw_length
-) {
-    auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
-    std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
-    length += cipher->ops->mac_length(cipher);
-    return olm::encode_base64_length(length);
-}
-
-
-std::uint8_t * enc_output_pos(
-    std::uint8_t * output,
-    size_t raw_length
-) {
-    auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
-    std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
-    length += cipher->ops->mac_length(cipher);
-    return output + olm::encode_base64_length(length) - length;
-}
-
-std::size_t enc_output(
-    std::uint8_t const * key, std::size_t key_length,
-    std::uint8_t * output, size_t raw_length
-) {
-    auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
-    std::size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length(
-        cipher, raw_length
-    );
-    std::size_t length = ciphertext_length + cipher->ops->mac_length(cipher);
-    std::size_t base64_length = olm::encode_base64_length(length);
-    std::uint8_t * raw_output = output + base64_length - length;
-    cipher->ops->encrypt(
-        cipher,
-        key, key_length,
-        raw_output, raw_length,
-        raw_output, ciphertext_length,
-        raw_output, length
-    );
-    olm::encode_base64(raw_output, length, output);
-    return raw_length;
-}
-
-std::size_t enc_input(
-    std::uint8_t const * key, std::size_t key_length,
-    std::uint8_t * input, size_t b64_length,
-    OlmErrorCode & last_error
-) {
-    std::size_t enc_length = olm::decode_base64_length(b64_length);
-    if (enc_length == std::size_t(-1)) {
-        last_error = OlmErrorCode::OLM_INVALID_BASE64;
-        return std::size_t(-1);
-    }
-    olm::decode_base64(input, b64_length, input);
-    auto *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
-    std::size_t raw_length = enc_length - cipher->ops->mac_length(cipher);
-    std::size_t result = cipher->ops->decrypt(
-        cipher,
-        key, key_length,
-        input, enc_length,
-        input, raw_length,
-        input, raw_length
-    );
-    if (result == std::size_t(-1)) {
-        last_error = OlmErrorCode::OLM_BAD_ACCOUNT_KEY;
-    }
-    return result;
-}
-
-
 std::size_t b64_output_length(
     size_t raw_length
 ) {
@@ -270,14 +199,14 @@ size_t olm_clear_utility(
 size_t olm_pickle_account_length(
     OlmAccount * account
 ) {
-    return enc_output_length(pickle_length(*from_c(account)));
+    return _olm_enc_output_length(pickle_length(*from_c(account)));
 }
 
 
 size_t olm_pickle_session_length(
     OlmSession * session
 ) {
-    return enc_output_length(pickle_length(*from_c(session)));
+    return _olm_enc_output_length(pickle_length(*from_c(session)));
 }
 
 
@@ -288,12 +217,12 @@ size_t olm_pickle_account(
 ) {
     olm::Account & object = *from_c(account);
     std::size_t raw_length = pickle_length(object);
-    if (pickled_length < enc_output_length(raw_length)) {
+    if (pickled_length < _olm_enc_output_length(raw_length)) {
         object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
         return size_t(-1);
     }
-    pickle(enc_output_pos(from_c(pickled), raw_length), object);
-    return enc_output(from_c(key), key_length, from_c(pickled), raw_length);
+    pickle(_olm_enc_output_pos(from_c(pickled), raw_length), object);
+    return _olm_enc_output(from_c(key), key_length, from_c(pickled), raw_length);
 }
 
 
@@ -304,12 +233,12 @@ size_t olm_pickle_session(
 ) {
     olm::Session & object = *from_c(session);
     std::size_t raw_length = pickle_length(object);
-    if (pickled_length < enc_output_length(raw_length)) {
+    if (pickled_length < _olm_enc_output_length(raw_length)) {
         object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
         return size_t(-1);
     }
-    pickle(enc_output_pos(from_c(pickled), raw_length), object);
-    return enc_output(from_c(key), key_length, from_c(pickled), raw_length);
+    pickle(_olm_enc_output_pos(from_c(pickled), raw_length), object);
+    return _olm_enc_output(from_c(key), key_length, from_c(pickled), raw_length);
 }
 
 
@@ -320,8 +249,8 @@ size_t olm_unpickle_account(
 ) {
     olm::Account & object = *from_c(account);
     std::uint8_t * const pos = from_c(pickled);
-    std::size_t raw_length = enc_input(
-        from_c(key), key_length, pos, pickled_length, object.last_error
+    std::size_t raw_length = _olm_enc_input(
+        from_c(key), key_length, pos, pickled_length, &object.last_error
     );
     if (raw_length == std::size_t(-1)) {
         return std::size_t(-1);
@@ -348,8 +277,8 @@ size_t olm_unpickle_session(
 ) {
     olm::Session & object = *from_c(session);
     std::uint8_t * const pos = from_c(pickled);
-    std::size_t raw_length = enc_input(
-        from_c(key), key_length, pos, pickled_length, object.last_error
+    std::size_t raw_length = _olm_enc_input(
+        from_c(key), key_length, pos, pickled_length, &object.last_error
     );
     if (raw_length == std::size_t(-1)) {
         return std::size_t(-1);
diff --git a/src/pickle_encoding.c b/src/pickle_encoding.c
new file mode 100644
index 0000000..5d5f8d7
--- /dev/null
+++ b/src/pickle_encoding.c
@@ -0,0 +1,92 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "olm/pickle_encoding.h"
+
+#include "olm/base64.h"
+#include "olm/cipher.h"
+#include "olm/olm.h"
+
+static const struct _olm_cipher_aes_sha_256 PICKLE_CIPHER =
+    OLM_CIPHER_INIT_AES_SHA_256("Pickle");
+
+size_t _olm_enc_output_length(
+    size_t raw_length
+) {
+    const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
+    size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
+    length += cipher->ops->mac_length(cipher);
+    return _olm_encode_base64_length(length);
+}
+
+uint8_t * _olm_enc_output_pos(
+    uint8_t * output,
+    size_t raw_length
+) {
+    const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
+    size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
+    length += cipher->ops->mac_length(cipher);
+    return output + _olm_encode_base64_length(length) - length;
+}
+
+size_t _olm_enc_output(
+    uint8_t const * key, size_t key_length,
+    uint8_t * output, size_t raw_length
+) {
+    const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
+    size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length(
+        cipher, raw_length
+    );
+    size_t length = ciphertext_length + cipher->ops->mac_length(cipher);
+    size_t base64_length = _olm_encode_base64_length(length);
+    uint8_t * raw_output = output + base64_length - length;
+    cipher->ops->encrypt(
+        cipher,
+        key, key_length,
+        raw_output, raw_length,
+        raw_output, ciphertext_length,
+        raw_output, length
+    );
+    _olm_encode_base64(raw_output, length, output);
+    return raw_length;
+}
+
+
+size_t _olm_enc_input(uint8_t const * key, size_t key_length,
+                      uint8_t * input, size_t b64_length,
+                      enum OlmErrorCode * last_error
+) {
+    size_t enc_length = _olm_decode_base64_length(b64_length);
+    if (enc_length == (size_t)-1) {
+        if (last_error) {
+            *last_error = OLM_INVALID_BASE64;
+        }
+        return (size_t)-1;
+    }
+    _olm_decode_base64(input, b64_length, input);
+    const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
+    size_t raw_length = enc_length - cipher->ops->mac_length(cipher);
+    size_t result = cipher->ops->decrypt(
+        cipher,
+        key, key_length,
+        input, enc_length,
+        input, raw_length,
+        input, raw_length
+    );
+    if (result == (size_t)-1 && last_error) {
+        *last_error = OLM_BAD_ACCOUNT_KEY;
+    }
+    return result;
+}
-- 
GitLab