From 75ffe6cb2a5e88f49206abb366ddd0e90a2954cc Mon Sep 17 00:00:00 2001
From: Joseph Donofry <joedonofry@gmail.com>
Date: Thu, 25 Mar 2021 20:11:42 -0400
Subject: [PATCH] Initial commit of MacOS ssl verification with code from MS
 cpprestsdk

---
 CMakeLists.txt                    |   2 +
 include/mtxclient/http/client.hpp |   9 ++
 lib/http/client.cpp               | 172 +++++++++++++++++++++++++++++-
 3 files changed, 181 insertions(+), 2 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7f0255065..87dba0b92 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -220,6 +220,8 @@ if(NOT MSVC AND NOT APPLE)
 	target_link_libraries(matrix_client PUBLIC Threads::Threads)
 elseif(MSVC)
 	target_compile_options(matrix_client PUBLIC /Zc:__cplusplus /utf-8 /MP /Gm- /EHsc)
+elseif(APPLE)
+	target_link_libraries(matrix_client PUBLIC "-framework CoreFoundation" "-framework Security")
 endif()
 
 if(COVERAGE)
diff --git a/include/mtxclient/http/client.hpp b/include/mtxclient/http/client.hpp
index 14824fb06..a7f414eab 100644
--- a/include/mtxclient/http/client.hpp
+++ b/include/mtxclient/http/client.hpp
@@ -22,6 +22,9 @@
 #include "mtxclient/utils.hpp"       // for random_token, url_encode, des...
 // #include "mtx/common.hpp"
 
+#if __APPLE__
+#include <boost/asio/ssl.hpp>
+#endif 
 #include <boost/beast/http/fields.hpp> // for fields
 #include <boost/beast/http/status.hpp> // for status
 #include <boost/system/error_code.hpp> // for error_code
@@ -165,6 +168,12 @@ struct ThumbOpts
 struct ClientPrivate;
 struct Session;
 
+#if __APPLE__
+bool handle_cert_verification(const std::string &server, bool preverified, boost::asio::ssl::verify_context &ctx);
+bool verify_cert_chain(boost::asio::ssl::verify_context &ctx, const std::string &hostName);
+bool verify_X509_cert_chain(const std::vector<std::string>& certChain, const std::string& hostName);
+#endif
+
 //! The main object that the user will interact.
 class Client : public std::enable_shared_from_this<Client>
 {
diff --git a/lib/http/client.cpp b/lib/http/client.cpp
index 44b9d3908..127834cf6 100644
--- a/lib/http/client.cpp
+++ b/lib/http/client.cpp
@@ -3,14 +3,23 @@
 #include "mtxclient/http/client.hpp"
 #include "mtxclient/http/client_impl.hpp"
 
+#if defined(__APPLE__)
+#include <CoreFoundation/CFData.h>
+#include <Security/SecBase.h>
+#include <Security/SecCertificate.h>
+#include <Security/SecPolicy.h>
+#include <Security/SecTrust.h>
+#endif
+
 #include <mutex>
 #include <thread>
+#include <iostream>
 
 #include <nlohmann/json.hpp>
 
 #include <boost/algorithm/string.hpp>
 #include <boost/utility/typed_in_place_factory.hpp>
-
+#include <boost/asio/ssl.hpp>
 #include <boost/asio/ssl/context.hpp>
 #include <boost/beast/http/message.hpp>
 #include <boost/iostreams/stream.hpp>
@@ -79,6 +88,159 @@ struct ClientPrivate
         //! All the active sessions will shutdown the connection.
         boost::signals2::signal<void()> shutdown_signal;
 };
+
+#if defined(__APPLE__)
+// This whole if defined section is basically taken verbatim from MS cpprestsdk:
+// https://github.com/microsoft/cpprestsdk
+// It will be changed once the correct fixes are identified (I just want it to work first...)
+bool
+handle_cert_verification(const std::string &server, bool preverified, boost::asio::ssl::verify_context &ctx)
+{
+        std::cout << "Handling Cert Verification for " << server << std::endl;
+        std::cout << "preverified: " << std::to_string(preverified) << std::endl;
+        if (!preverified)
+        {
+                return false;
+        }
+
+        return verify_cert_chain(ctx, server);
+}
+
+bool
+verify_cert_chain(boost::asio::ssl::verify_context &verifyCtx, const std::string &hostName)
+{
+        X509_STORE_CTX *storeContext = verifyCtx.native_handle();
+        int currentDepth             = X509_STORE_CTX_get_error_depth(storeContext);
+        if (currentDepth != 0) {
+                return true;
+        }
+
+#if (OPENSSL_VERSION_NUMBER < 0x10100000L)
+        STACK_OF(X509) *certStack = X509_STORE_CTX_get_chain(storeContext);
+#else
+        STACK_OF(X509) *certStack = X509_STORE_CTX_get0_chain(storeContext);
+#endif
+
+        const int numCerts = sk_X509_num(certStack);
+        if (numCerts < 0) {
+                std::cout << "numCerts == 0!" << std::endl;
+                return false;
+        }
+
+        std::vector<std::string> certChain;
+        certChain.reserve(numCerts);
+        for (int i = 0; i < numCerts; ++i) {
+                X509 *cert = sk_X509_value(certStack, i);
+
+                // Encode into DER format into raw memory.
+                int len = i2d_X509(cert, nullptr);
+                if (len < 0) {
+                        return false;
+                }
+
+                std::string certData;
+                certData.resize(len);
+                unsigned char *buffer = reinterpret_cast<unsigned char *>(&certData[0]);
+                len                   = i2d_X509(cert, &buffer);
+                if (len < 0) {
+                        return false;
+                }
+
+                certChain.push_back(std::move(certData));
+        }
+
+        auto verify_result = verify_X509_cert_chain(certChain, hostName);
+
+        return verify_result;
+}
+// Simple RAII pattern wrapper to perform CFRelease on objects.
+template<typename T>
+class cf_ref
+{
+public:
+    cf_ref(T v) : value(v)
+    {
+        static_assert(sizeof(cf_ref<T>) == sizeof(T), "Code assumes just a wrapper, see usage in CFArrayCreate below.");
+    }
+    cf_ref() : value(nullptr) {}
+    cf_ref(cf_ref&& other) : value(other.value) { other.value = nullptr; }
+
+    ~cf_ref()
+    {
+        if (value != nullptr)
+        {
+            CFRelease(value);
+        }
+    }
+
+    T& get() { return value; }
+
+private:
+    cf_ref(const cf_ref&);
+    cf_ref& operator=(const cf_ref&);
+    T value;
+};
+
+bool verify_X509_cert_chain(const std::vector<std::string>& certChain, const std::string& hostName)
+{
+    // Build up CFArrayRef with all the certificates.
+    // All this code is basically just to get into the correct structures for the Apple APIs.
+    // Copies are avoided whenever possible.
+    std::vector<cf_ref<SecCertificateRef>> certs;
+    for (const auto& certBuf : certChain)
+    {
+        cf_ref<CFDataRef> certDataRef =
+            CFDataCreateWithBytesNoCopy(kCFAllocatorDefault,
+                                        reinterpret_cast<const unsigned char*>(certBuf.c_str()),
+                                        certBuf.size(),
+                                        kCFAllocatorNull);
+        if (certDataRef.get() == nullptr)
+        {
+                std::cout << "certDataRef is null!" << std::endl;
+            return false;
+        }
+
+        cf_ref<SecCertificateRef> certObj = SecCertificateCreateWithData(nullptr, certDataRef.get());
+        if (certObj.get() == nullptr)
+        {
+                std::cout << "certObj is null!" << std::endl;
+            return false;
+        }
+        certs.push_back(std::move(certObj));
+    }
+    cf_ref<CFArrayRef> certsArray = CFArrayCreate(
+        kCFAllocatorDefault, const_cast<const void**>(reinterpret_cast<void**>(&certs[0])), certs.size(), nullptr);
+    if (certsArray.get() == nullptr)
+    {
+            std::cout << "certObj is null!" << std::endl;
+        return false;
+    }
+
+    // Create trust management object with certificates and SSL policy.
+    // Note: SecTrustCreateWithCertificates expects the certificate to be
+    // verified is the first element.
+    cf_ref<CFStringRef> cfHostName = CFStringCreateWithCStringNoCopy(
+        kCFAllocatorDefault, hostName.c_str(), kCFStringEncodingASCII, kCFAllocatorNull);
+    if (cfHostName.get() == nullptr)
+    {
+        return false;
+    }
+    cf_ref<SecPolicyRef> policy = SecPolicyCreateSSL(true /* client side */, cfHostName.get());
+    cf_ref<SecTrustRef> trust;
+    OSStatus status = SecTrustCreateWithCertificates(certsArray.get(), policy.get(), &trust.get());
+    std::cout << status << std::endl;
+    if (status == noErr)
+    {
+        // Perform actual certificate verification.
+        status = SecTrustEvaluateWithError(trust.get(), nullptr);
+        std::cout << status << std::endl;
+        return status;
+    }
+
+    return false;
+}
+#endif
+
 }
 
 Client::Client(const std::string &server, uint16_t port)
@@ -99,8 +261,14 @@ Client::Client(const std::string &server, uint16_t port)
 #endif
 
         verify_certificates(true);
-        p->ssl_ctx_.set_verify_callback(ssl::rfc2818_verification(server));
 
+#ifdef __APPLE__
+        p->ssl_ctx_.set_verify_callback([server](bool preverified, boost::asio::ssl::verify_context &ctx) {
+                return handle_cert_verification(server, preverified, ctx);
+        });
+#else
+        p->ssl_ctx_.set_verify_callback(ssl::rfc2818_verification(server));
+#endif
         for (unsigned int i = 0; i < threads_num; ++i)
                 p->thread_group_.add_thread(new boost::thread([this]() { p->ioc_.run(); }));
 }
-- 
GitLab