From f719236b08d373d9508f2467bbfc6dfa953b1f8d Mon Sep 17 00:00:00 2001
From: Nicolas Werner <nicolas.werner@hotmail.de>
Date: Mon, 2 Dec 2019 22:53:15 +0100
Subject: [PATCH] Implement file encryption and decryption

No idea, if that is secure yet.
---
 include/mtxclient/crypto/client.hpp |  19 ---
 include/mtxclient/crypto/utils.hpp  |  65 +++++++-
 lib/crypto/client.cpp               |  47 +-----
 lib/crypto/utils.cpp                | 223 +++++++++++++++++++++++++++-
 tests/e2ee.cpp                      |  56 +++++++
 tests/messages.cpp                  |   2 +-
 6 files changed, 340 insertions(+), 72 deletions(-)

diff --git a/include/mtxclient/crypto/client.hpp b/include/mtxclient/crypto/client.hpp
index 7b7197299..cb626241a 100644
--- a/include/mtxclient/crypto/client.hpp
+++ b/include/mtxclient/crypto/client.hpp
@@ -55,19 +55,6 @@ private:
         std::string msg_;
 };
 
-class sodium_exception : public std::exception
-{
-public:
-        sodium_exception(std::string func, const char *msg)
-          : msg_(func + ": " + std::string(msg))
-        {}
-
-        virtual const char *what() const throw() { return msg_.c_str(); }
-
-private:
-        std::string msg_;
-};
-
 template<class T>
 std::string
 pickle(typename T::olm_type *object, const std::string &key)
@@ -244,12 +231,6 @@ encrypt_exported_sessions(const mtx::crypto::ExportedSessionKeys &keys, std::str
 mtx::crypto::ExportedSessionKeys
 decrypt_exported_sessions(const std::string &data, std::string pass);
 
-std::string
-base642bin(const std::string &b64);
-
-std::string
-bin2base64(const std::string &b64);
-
 BinaryBuf
 derive_key(const std::string &pass, const BinaryBuf &salt);
 
diff --git a/include/mtxclient/crypto/utils.hpp b/include/mtxclient/crypto/utils.hpp
index f8af46f75..fc81969c4 100644
--- a/include/mtxclient/crypto/utils.hpp
+++ b/include/mtxclient/crypto/utils.hpp
@@ -1,20 +1,28 @@
 #pragma once
 
-#include <algorithm>
 #include <string>
 #include <vector>
 
-#include <openssl/aes.h>
-#include <openssl/evp.h>
-#include <openssl/hmac.h>
-#include <openssl/sha.h>
-
 #include <sodium.h>
 
 #include <boost/algorithm/string.hpp>
 
+#include "mtx/common.hpp"
+
 namespace mtx {
 namespace crypto {
+class sodium_exception : public std::exception
+{
+public:
+        sodium_exception(std::string func, const char *msg)
+          : msg_(func + ": " + std::string(msg))
+        {}
+
+        virtual const char *what() const throw() { return msg_.c_str(); }
+
+private:
+        std::string msg_;
+};
 
 //! Data representation used to interact with libolm.
 using BinaryBuf = std::vector<uint8_t>;
@@ -32,6 +40,19 @@ create_buffer(std::size_t nbytes)
         return buf;
 }
 
+inline BinaryBuf
+to_binary_buf(const std::string &str)
+{
+        return BinaryBuf(reinterpret_cast<const uint8_t *>(str.data()),
+                         reinterpret_cast<const uint8_t *>(str.data()) + str.size());
+}
+
+inline std::string
+to_string(const BinaryBuf &buf)
+{
+        return std::string(reinterpret_cast<const char *>(buf.data()), buf.size());
+}
+
 //! Simple wrapper around the OpenSSL PKCS5_PBKDF2_HMAC function
 BinaryBuf
 PBKDF2_HMAC_SHA_512(const std::string pass, const BinaryBuf salt, uint32_t iterations);
@@ -45,6 +66,18 @@ AES_CTR_256_Decrypt(const std::string ciphertext, const BinaryBuf aes256Key, Bin
 BinaryBuf
 HMAC_SHA256(const BinaryBuf hmacKey, const BinaryBuf data);
 
+std::string
+sha256(const std::string &data);
+
+//! Decrypt matrix EncryptedFile
+BinaryBuf
+decrypt_file(const std::string &ciphertext, const mtx::crypto::EncryptedFile &encryption_info);
+
+//! Encrypt matrix EncryptedFile
+// Remember to set the url member of the EncryptedFile struct!
+std::pair<BinaryBuf, mtx::crypto::EncryptedFile>
+encrypt_file(const std::string &plaintext);
+
 //! Translates the data back into the binary buffer, taking care
 //! to remove the header and footer elements.
 std::string
@@ -63,5 +96,23 @@ uint32_to_uint8(uint8_t b[4], uint32_t u32);
 void
 print_binary_buf(const BinaryBuf buf);
 
+std::string
+base642bin(const std::string &b64);
+
+std::string
+bin2base64(const std::string &bin);
+
+std::string
+base642bin_unpadded(const std::string &b64);
+
+std::string
+bin2base64_unpadded(const std::string &bin);
+
+std::string
+base642bin_urlsafe_unpadded(const std::string &b64);
+
+std::string
+bin2base64_urlsafe_unpadded(const std::string &bin);
+
 } // namespace crypto
-} // namespace mtx
\ No newline at end of file
+} // namespace mtx
diff --git a/lib/crypto/client.cpp b/lib/crypto/client.cpp
index ac180273a..6f60d5c31 100644
--- a/lib/crypto/client.cpp
+++ b/lib/crypto/client.cpp
@@ -1,5 +1,8 @@
 #include <iostream>
 
+#include <openssl/aes.h>
+#include <openssl/sha.h>
+
 #include "mtxclient/crypto/client.hpp"
 #include "mtxclient/crypto/types.hpp"
 #include "mtxclient/crypto/utils.hpp"
@@ -669,50 +672,6 @@ mtx::crypto::decrypt_exported_sessions(const std::string &data, std::string pass
         return json::parse(plaintext);
 }
 
-std::string
-mtx::crypto::base642bin(const std::string &b64)
-{
-        std::size_t bin_maxlen = b64.size();
-        std::size_t bin_len;
-
-        const char *max_end;
-
-        auto ciphertext = create_buffer(bin_maxlen);
-
-        const int rc = sodium_base642bin(reinterpret_cast<unsigned char *>(ciphertext.data()),
-                                         ciphertext.size(),
-                                         b64.data(),
-                                         b64.size(),
-                                         nullptr,
-                                         &bin_len,
-                                         &max_end,
-                                         sodium_base64_VARIANT_ORIGINAL);
-        if (rc != 0)
-                throw sodium_exception{"sodium_base642bin", "encoding failed"};
-
-        if (bin_len != bin_maxlen)
-                ciphertext.resize(bin_len);
-
-        return std::string(std::make_move_iterator(ciphertext.begin()),
-                           std::make_move_iterator(ciphertext.end()));
-}
-
-std::string
-mtx::crypto::bin2base64(const std::string &bin)
-{
-        auto base64buf =
-          create_buffer(sodium_base64_encoded_len(bin.size(), sodium_base64_VARIANT_ORIGINAL));
-
-        sodium_bin2base64(reinterpret_cast<char *>(base64buf.data()),
-                          base64buf.size(),
-                          reinterpret_cast<const unsigned char *>(bin.data()),
-                          bin.size(),
-                          sodium_base64_VARIANT_ORIGINAL);
-
-        // Removing the null byte.
-        return std::string(base64buf.begin(), base64buf.end() - 1);
-}
-
 BinaryBuf
 mtx::crypto::derive_key(const std::string &pass, const BinaryBuf &salt)
 {
diff --git a/lib/crypto/utils.cpp b/lib/crypto/utils.cpp
index 6dcb8fead..82cd77434 100644
--- a/lib/crypto/utils.cpp
+++ b/lib/crypto/utils.cpp
@@ -1,5 +1,12 @@
 #include "mtxclient/crypto/utils.hpp"
 
+#include <openssl/aes.h>
+#include <openssl/evp.h>
+#include <openssl/hmac.h>
+#include <openssl/sha.h>
+
+#include <algorithm>
+#include <iomanip>
 #include <iostream>
 
 namespace mtx {
@@ -33,7 +40,7 @@ AES_CTR_256_Encrypt(const std::string plaintext, const BinaryBuf aes256Key, Bina
         int ciphertext_len;
 
         // The ciphertext expand up to block size, which is 128 for AES256
-        BinaryBuf encrypted = create_buffer(plaintext.size() + 128);
+        BinaryBuf encrypted = create_buffer(plaintext.size() + AES_BLOCK_SIZE);
 
         uint8_t *iv_data = iv.data();
         // need to set bit 63 to 0
@@ -128,6 +135,90 @@ AES_CTR_256_Decrypt(const std::string ciphertext, const BinaryBuf aes256Key, Bin
         return decrypted;
 }
 
+std::string
+sha256(const std::string &data)
+{
+        bool success = false;
+        std::string hashed;
+
+        EVP_MD_CTX *context = EVP_MD_CTX_new();
+
+        if (context != NULL) {
+                if (EVP_DigestInit_ex(context, EVP_sha256(), NULL)) {
+                        if (EVP_DigestUpdate(context, data.c_str(), data.length())) {
+                                unsigned char hash[EVP_MAX_MD_SIZE];
+                                unsigned int lengthOfHash = 0;
+
+                                if (EVP_DigestFinal_ex(context, hash, &lengthOfHash)) {
+                                        hashed  = std::string(hash, hash + lengthOfHash);
+                                        success = true;
+                                }
+                        }
+                }
+
+                EVP_MD_CTX_free(context);
+        }
+
+        if (success)
+                return hashed;
+        throw std::runtime_error("sha256 failed!");
+}
+
+BinaryBuf
+decrypt_file(const std::string &ciphertext, const mtx::crypto::EncryptedFile &encryption_info)
+{
+        if (encryption_info.v != "v2")
+                throw std::invalid_argument("Unsupported encrypted file version");
+
+        if (encryption_info.key.kty != "oct")
+                throw std::invalid_argument("Unsupported key type");
+
+        if (encryption_info.key.alg != "A256CTR")
+                throw std::invalid_argument("Unsupported algorithm");
+
+        // Be careful, the key should be urlsafe and unpadded, the iv and sha only need to
+        // be unpadded
+        if (bin2base64_unpadded(sha256(ciphertext)) != encryption_info.hashes.at("sha256"))
+                throw std::invalid_argument(
+                  "sha256 of encrypted file does not match the ciphertext, expected '" +
+                  bin2base64_unpadded(sha256(ciphertext)) + "', got '" +
+                  encryption_info.hashes.at("sha256") + "'");
+
+        return AES_CTR_256_Decrypt(
+          ciphertext,
+          to_binary_buf(base642bin_urlsafe_unpadded(encryption_info.key.k)),
+          to_binary_buf(base642bin_unpadded(encryption_info.iv)));
+}
+
+std::pair<BinaryBuf, mtx::crypto::EncryptedFile>
+encrypt_file(const std::string &plaintext)
+{
+        mtx::crypto::EncryptedFile encryption_info;
+
+        // not sure if 16 bytes would be enough, 32 seems to be safe though
+        BinaryBuf key = create_buffer(32);
+        BinaryBuf iv  = create_buffer(32);
+
+        BinaryBuf cyphertext = AES_CTR_256_Encrypt(plaintext, key, iv);
+
+        // Be careful, the key should be urlsafe and unpadded, the iv and sha only need to
+        // be unpadded
+        JWK web_key;
+        web_key.ext     = true;
+        web_key.kty     = "oct";
+        web_key.key_ops = {"encrypt", "decrypt"};
+        web_key.alg     = "A256CTR";
+        web_key.k       = bin2base64_urlsafe_unpadded(to_string(key));
+        web_key.ext     = true;
+
+        encryption_info.key              = web_key;
+        encryption_info.iv               = bin2base64_unpadded(to_string(iv));
+        encryption_info.hashes["sha256"] = bin2base64_unpadded(sha256(to_string(cyphertext)));
+        encryption_info.v                = "v2";
+
+        return std::make_pair(cyphertext, encryption_info);
+}
+
 template<typename T>
 void
 remove_substrs(std::basic_string<T> &s, const std::basic_string<T> &p)
@@ -185,5 +276,135 @@ uint32_to_uint8(uint8_t b[4], uint32_t u32)
         b[0] = (uint8_t)(u32 >>= 8);
 }
 
+std::string
+base642bin(const std::string &b64)
+{
+        std::size_t bin_maxlen = b64.size();
+        std::size_t bin_len;
+
+        const char *max_end;
+
+        auto ciphertext = create_buffer(bin_maxlen);
+
+        const int rc = sodium_base642bin(reinterpret_cast<unsigned char *>(ciphertext.data()),
+                                         ciphertext.size(),
+                                         b64.data(),
+                                         b64.size(),
+                                         nullptr,
+                                         &bin_len,
+                                         &max_end,
+                                         sodium_base64_VARIANT_ORIGINAL);
+        if (rc != 0)
+                throw sodium_exception{"sodium_base642bin", "encoding failed"};
+
+        if (bin_len != bin_maxlen)
+                ciphertext.resize(bin_len);
+
+        return std::string(std::make_move_iterator(ciphertext.begin()),
+                           std::make_move_iterator(ciphertext.end()));
+}
+
+std::string
+bin2base64(const std::string &bin)
+{
+        auto base64buf =
+          create_buffer(sodium_base64_encoded_len(bin.size(), sodium_base64_VARIANT_ORIGINAL));
+
+        sodium_bin2base64(reinterpret_cast<char *>(base64buf.data()),
+                          base64buf.size(),
+                          reinterpret_cast<const unsigned char *>(bin.data()),
+                          bin.size(),
+                          sodium_base64_VARIANT_ORIGINAL);
+
+        // Removing the null byte.
+        return std::string(base64buf.begin(), base64buf.end() - 1);
+}
+std::string
+base642bin_unpadded(const std::string &b64)
+{
+        std::size_t bin_maxlen = b64.size();
+        std::size_t bin_len;
+
+        const char *max_end;
+
+        auto ciphertext = create_buffer(bin_maxlen);
+
+        const int rc = sodium_base642bin(reinterpret_cast<unsigned char *>(ciphertext.data()),
+                                         ciphertext.size(),
+                                         b64.data(),
+                                         b64.size(),
+                                         nullptr,
+                                         &bin_len,
+                                         &max_end,
+                                         sodium_base64_VARIANT_ORIGINAL_NO_PADDING);
+        if (rc != 0)
+                throw sodium_exception{"sodium_base642bin", "encoding failed"};
+
+        if (bin_len != bin_maxlen)
+                ciphertext.resize(bin_len);
+
+        return std::string(std::make_move_iterator(ciphertext.begin()),
+                           std::make_move_iterator(ciphertext.end()));
+}
+
+std::string
+bin2base64_unpadded(const std::string &bin)
+{
+        auto base64buf = create_buffer(
+          sodium_base64_encoded_len(bin.size(), sodium_base64_VARIANT_ORIGINAL_NO_PADDING));
+
+        sodium_bin2base64(reinterpret_cast<char *>(base64buf.data()),
+                          base64buf.size(),
+                          reinterpret_cast<const unsigned char *>(bin.data()),
+                          bin.size(),
+                          sodium_base64_VARIANT_ORIGINAL_NO_PADDING);
+
+        // Removing the null byte.
+        return std::string(base64buf.begin(), base64buf.end() - 1);
+}
+std::string
+base642bin_urlsafe_unpadded(const std::string &b64)
+{
+        std::size_t bin_maxlen = b64.size();
+        std::size_t bin_len;
+
+        const char *max_end;
+
+        auto ciphertext = create_buffer(bin_maxlen);
+
+        const int rc = sodium_base642bin(reinterpret_cast<unsigned char *>(ciphertext.data()),
+                                         ciphertext.size(),
+                                         b64.data(),
+                                         b64.size(),
+                                         nullptr,
+                                         &bin_len,
+                                         &max_end,
+                                         sodium_base64_VARIANT_URLSAFE_NO_PADDING);
+        if (rc != 0)
+                throw sodium_exception{"sodium_base642bin", "encoding failed"};
+
+        if (bin_len != bin_maxlen)
+                ciphertext.resize(bin_len);
+
+        return std::string(std::make_move_iterator(ciphertext.begin()),
+                           std::make_move_iterator(ciphertext.end()));
+}
+
+std::string
+bin2base64_urlsafe_unpadded(const std::string &bin)
+{
+        auto base64buf = create_buffer(
+          sodium_base64_encoded_len(bin.size(), sodium_base64_VARIANT_URLSAFE_NO_PADDING));
+
+        sodium_bin2base64(reinterpret_cast<char *>(base64buf.data()),
+                          base64buf.size(),
+                          reinterpret_cast<const unsigned char *>(bin.data()),
+                          bin.size(),
+                          sodium_base64_VARIANT_URLSAFE_NO_PADDING);
+
+        // Removing the null byte.
+        return std::string(base64buf.begin(), base64buf.end() - 1);
+}
+
 } // namespace crypto
 } // namespace mtx
diff --git a/tests/e2ee.cpp b/tests/e2ee.cpp
index dd0aa9ed8..885789b09 100644
--- a/tests/e2ee.cpp
+++ b/tests/e2ee.cpp
@@ -24,6 +24,8 @@ using namespace mtx::responses;
 
 using namespace std;
 
+using namespace nlohmann;
+
 struct OlmCipherContent
 {
         std::string body;
@@ -1130,6 +1132,60 @@ TEST(ExportSessions, InboundMegolmSessions)
         ASSERT_EQ(restored_output_str, secret_message);
 }
 
+TEST(Encryption, EncryptedFile)
+{
+        std::string plaintext = "This is some plain text payload";
+        auto encryption_data  = mtx::crypto::encrypt_file(plaintext);
+        ASSERT_NE(plaintext, mtx::crypto::to_string(encryption_data.first));
+        ASSERT_EQ(plaintext,
+                  mtx::crypto::to_string(mtx::crypto::decrypt_file(
+                    mtx::crypto::to_string(encryption_data.first), encryption_data.second)));
+
+        json j                                            = R"({
+  "type": "m.room.message",
+  "content": {
+    "body": "test.txt",
+    "info": {
+      "size": 8,
+      "mimetype": "text/plain"
+    },
+    "msgtype": "m.file",
+    "file": {
+      "v": "v2",
+      "key": {
+        "alg": "A256CTR",
+        "ext": true,
+        "k": "6osKLzUKV1YZ06WEX0b77D784Te8oAj5eNU-gAgkjs4",
+        "key_ops": [
+          "encrypt",
+          "decrypt"
+        ],
+        "kty": "oct"
+      },
+      "iv": "7zRP/t89YWcAAAAAAAAAAA",
+      "hashes": {
+        "sha256": "5g41hn7n10sCw3+2j7CQ9SJl6R/v5EBT4MshdFgHhzo"
+      },
+      "url": "mxc://neko.dev/WPKoOAPfPlcHiZZTEoaIoZhN",
+      "mimetype": "text/plain"
+    }
+ },
+ "event_id": "$1575320135447DEPky:neko.dev",
+  "origin_server_ts": 1575320135324,
+  "sender": "@test:neko.dev",
+  "unsigned": {
+    "age": 1081,
+    "transaction_id": "m1575320142400.8"
+  },
+  "room_id": "!YnUlhwgbBaGcAFsJOJ:neko.dev"
+})"_json;
+        mtx::events::RoomEvent<mtx::events::msg::File> ev = j;
+
+        ASSERT_EQ("abcdefg\n",
+                  mtx::crypto::to_string(mtx::crypto::decrypt_file("=\xFDX\xAB\xCA\xEB\x8F\xFF",
+                                                                   ev.content.file.value())));
+}
+
 TEST(Encryption, DISABLED_HandleRoomKeyEvent) {}
 TEST(Encryption, DISABLED_HandleRoomKeyRequestEvent) {}
 TEST(Encryption, DISABLED_HandleNewDevices) {}
diff --git a/tests/messages.cpp b/tests/messages.cpp
index 3c05b5e53..23b5f334e 100644
--- a/tests/messages.cpp
+++ b/tests/messages.cpp
@@ -144,7 +144,7 @@ TEST(RoomEvents, FileMessage)
 
 TEST(RoomEvents, EncryptedImageMessage)
 {
-        json data = R"(
+        json data                   = R"(
 {
   "content": {
     "body": "something-important.jpg",
-- 
GitLab