From 307309c69b99087489a5d2d2eccecee7751fead6 Mon Sep 17 00:00:00 2001
From: Hubert Chathi <hubert@uhoreg.ca>
Date: Wed, 27 Jun 2018 17:38:14 -0400
Subject: [PATCH] add initial version of Android wrapper for public key API

---
 .../java/org/matrix/olm/OlmPkTest.java        |  94 +++
 .../java/org/matrix/olm/OlmException.java     |   8 +
 .../java/org/matrix/olm/OlmPkDecryption.java  |  78 +++
 .../java/org/matrix/olm/OlmPkEncryption.java  |  89 +++
 .../java/org/matrix/olm/OlmPkMessage.java     |  23 +
 android/olm-sdk/src/main/jni/Android.mk       |   4 +-
 android/olm-sdk/src/main/jni/olm_jni.h        |   2 +
 .../olm-sdk/src/main/jni/olm_jni_helper.cpp   |  10 +
 android/olm-sdk/src/main/jni/olm_jni_helper.h |   2 +
 android/olm-sdk/src/main/jni/olm_pk.cpp       | 564 ++++++++++++++++++
 android/olm-sdk/src/main/jni/olm_pk.h         |  45 ++
 11 files changed, 918 insertions(+), 1 deletion(-)
 create mode 100644 android/olm-sdk/src/androidTest/java/org/matrix/olm/OlmPkTest.java
 create mode 100644 android/olm-sdk/src/main/java/org/matrix/olm/OlmPkDecryption.java
 create mode 100644 android/olm-sdk/src/main/java/org/matrix/olm/OlmPkEncryption.java
 create mode 100644 android/olm-sdk/src/main/java/org/matrix/olm/OlmPkMessage.java
 create mode 100644 android/olm-sdk/src/main/jni/olm_pk.cpp
 create mode 100644 android/olm-sdk/src/main/jni/olm_pk.h

diff --git a/android/olm-sdk/src/androidTest/java/org/matrix/olm/OlmPkTest.java b/android/olm-sdk/src/androidTest/java/org/matrix/olm/OlmPkTest.java
new file mode 100644
index 0000000..04b2217
--- /dev/null
+++ b/android/olm-sdk/src/androidTest/java/org/matrix/olm/OlmPkTest.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.matrix.olm;
+
+import android.support.test.runner.AndroidJUnit4;
+import android.util.Log;
+
+import org.junit.BeforeClass;
+import org.junit.FixMethodOrder;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.MethodSorters;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+@RunWith(AndroidJUnit4.class)
+@FixMethodOrder(MethodSorters.NAME_ASCENDING)
+public class OlmPkTest {
+    private static final String LOG_TAG = "OlmPkEncryptionTest";
+
+    private static OlmPkEncryption mOlmPkEncryption;
+    private static OlmPkDecryption mOlmPkDecryption;
+
+    @Test
+    public void test01EncryptAndDecrypt() {
+        try {
+            mOlmPkEncryption = new OlmPkEncryption();
+        } catch (OlmException e) {
+            e.printStackTrace();
+            assertTrue("OlmPkEncryption failed " + e.getMessage(), false);
+        }
+        try {
+            mOlmPkDecryption = new OlmPkDecryption();
+        } catch (OlmException e) {
+            e.printStackTrace();
+            assertTrue("OlmPkEncryption failed " + e.getMessage(), false);
+        }
+
+        assertNotNull(mOlmPkEncryption);
+        assertNotNull(mOlmPkDecryption);
+
+        String key = null;
+        try {
+            key = mOlmPkDecryption.generateKey();
+        } catch (OlmException e) {
+            assertTrue("Exception in generateKey, Exception code=" + e.getExceptionCode(), false);
+        }
+        Log.d(LOG_TAG, "Ephemeral Key: " + key);
+        try {
+            mOlmPkEncryption.setRecipientKey(key);
+        } catch (OlmException e) {
+            assertTrue("Exception in setRecipientKey, Exception code=" + e.getExceptionCode(), false);
+        }
+
+        String clearMessage = "Public key test";
+        OlmPkMessage message = null;
+        try {
+            message = mOlmPkEncryption.encrypt(clearMessage);
+        } catch (OlmException e) {
+            assertTrue("Exception in encrypt, Exception code=" + e.getExceptionCode(), false);
+        }
+        Log.d(LOG_TAG, "message: " + message.mCipherText + " " + message.mMac + " " + message.mEphemeralKey);
+
+        String decryptedMessage = null;
+        try {
+            decryptedMessage = mOlmPkDecryption.decrypt(message);
+        } catch (OlmException e) {
+            assertTrue("Exception in decrypt, Exception code=" + e.getExceptionCode(), false);
+        }
+        assertTrue(clearMessage.equals(decryptedMessage));
+
+        mOlmPkEncryption.releaseEncryption();
+        mOlmPkDecryption.releaseDecryption();
+        assertTrue(mOlmPkEncryption.isReleased());
+        assertTrue(mOlmPkDecryption.isReleased());
+    }
+}
diff --git a/android/olm-sdk/src/main/java/org/matrix/olm/OlmException.java b/android/olm-sdk/src/main/java/org/matrix/olm/OlmException.java
index 33ac49e..31b5729 100644
--- a/android/olm-sdk/src/main/java/org/matrix/olm/OlmException.java
+++ b/android/olm-sdk/src/main/java/org/matrix/olm/OlmException.java
@@ -61,6 +61,14 @@ public class OlmException extends IOException {
     public static final int EXCEPTION_CODE_UTILITY_CREATION = 500;
     public static final int EXCEPTION_CODE_UTILITY_VERIFY_SIGNATURE = 501;
 
+    public static final int EXCEPTION_CODE_PK_ENCRYPTION_CREATION = 600;
+    public static final int EXCEPTION_CODE_PK_ENCRYPTION_SET_RECIPIENT_KEY = 601;
+    public static final int EXCEPTION_CODE_PK_ENCRYPTION_ENCRYPT = 602;
+
+    public static final int EXCEPTION_CODE_PK_DECRYPTION_CREATION = 700;
+    public static final int EXCEPTION_CODE_PK_DECRYPTION_GENERATE_KEY = 701;
+    public static final int EXCEPTION_CODE_PK_DECRYPTION_DECRYPT = 702;
+
     // exception human readable messages
     public static final String EXCEPTION_MSG_INVALID_PARAMS_DESERIALIZATION = "invalid de-serialized parameters";
 
diff --git a/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkDecryption.java b/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkDecryption.java
new file mode 100644
index 0000000..03d055a
--- /dev/null
+++ b/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkDecryption.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.matrix.olm;
+
+import android.util.Log;
+
+public class OlmPkDecryption {
+    private static final String LOG_TAG = "OlmPkDecryption";
+
+    /** Session Id returned by JNI.
+     * This value uniquely identifies the native session instance.
+     **/
+    private transient long mNativeId;
+
+    public OlmPkDecryption() throws OlmException {
+        try {
+            mNativeId = createNewPkDecryptionJni();
+        } catch (Exception e) {
+            throw new OlmException(OlmException.EXCEPTION_CODE_PK_DECRYPTION_CREATION, e.getMessage());
+        }
+    }
+
+    private native long createNewPkDecryptionJni();
+
+    private native void releasePkDecryptionJni();
+
+    public void releaseDecryption() {
+        if (0 != mNativeId) {
+            releasePkDecryptionJni();
+        }
+        mNativeId = 0;
+    }
+
+    public boolean isReleased() {
+        return (0 == mNativeId);
+    }
+
+    public String generateKey() throws OlmException {
+        try {
+            byte[] key = generateKeyJni();
+            return new String(key, "UTF-8");
+        } catch (Exception e) {
+            Log.e(LOG_TAG, "## setRecipientKey(): failed " + e.getMessage());
+            throw new OlmException(OlmException.EXCEPTION_CODE_PK_DECRYPTION_GENERATE_KEY, e.getMessage());
+        }
+    }
+
+    private native byte[] generateKeyJni();
+
+    public String decrypt(OlmPkMessage aMessage) throws OlmException {
+        if (null == aMessage) {
+            return null;
+        }
+
+        try {
+            return new String(decryptJni(aMessage), "UTF-8");
+        } catch (Exception e) {
+            Log.e(LOG_TAG, "## pkDecrypt(): failed " + e.getMessage());
+            throw new OlmException(OlmException.EXCEPTION_CODE_PK_DECRYPTION_DECRYPT, e.getMessage());
+        }
+    }
+
+    private native byte[] decryptJni(OlmPkMessage aMessage);
+}
diff --git a/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkEncryption.java b/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkEncryption.java
new file mode 100644
index 0000000..9bd429d
--- /dev/null
+++ b/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkEncryption.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.matrix.olm;
+
+import android.util.Log;
+
+public class OlmPkEncryption {
+    private static final String LOG_TAG = "OlmPkEncryption";
+
+    /** Session Id returned by JNI.
+     * This value uniquely identifies the native session instance.
+     **/
+    private transient long mNativeId;
+
+    public OlmPkEncryption() throws OlmException {
+        try {
+            mNativeId = createNewPkEncryptionJni();
+        } catch (Exception e) {
+            throw new OlmException(OlmException.EXCEPTION_CODE_PK_ENCRYPTION_CREATION, e.getMessage());
+        }
+    }
+
+    private native long createNewPkEncryptionJni();
+
+    private native void releasePkEncryptionJni();
+
+    public void releaseEncryption() {
+        if (0 != mNativeId) {
+            releasePkEncryptionJni();
+        }
+        mNativeId = 0;
+    }
+
+    public boolean isReleased() {
+        return (0 == mNativeId);
+    }
+
+    public void setRecipientKey(String aKey) throws OlmException {
+        if (null == aKey) {
+            return;
+        }
+
+        try {
+            setRecipientKeyJni(aKey.getBytes("UTF-8"));
+        } catch (Exception e) {
+            Log.e(LOG_TAG, "## setRecipientKey(): failed " + e.getMessage());
+            throw new OlmException(OlmException.EXCEPTION_CODE_PK_ENCRYPTION_SET_RECIPIENT_KEY, e.getMessage());
+        }
+    }
+
+    private native void setRecipientKeyJni(byte[] aKey);
+
+    public OlmPkMessage encrypt(String aPlaintext) throws OlmException {
+        if (null == aPlaintext) {
+            return null;
+        }
+
+        OlmPkMessage encryptedMsgRetValue = new OlmPkMessage();
+
+        try {
+            byte[] ciphertextBuffer = encryptJni(aPlaintext.getBytes("UTF-8"), encryptedMsgRetValue);
+
+            if (null != ciphertextBuffer) {
+                encryptedMsgRetValue.mCipherText = new String(ciphertextBuffer, "UTF-8");
+            }
+        } catch (Exception e) {
+            Log.e(LOG_TAG, "## pkEncrypt(): failed " + e.getMessage());
+            throw new OlmException(OlmException.EXCEPTION_CODE_PK_ENCRYPTION_ENCRYPT, e.getMessage());
+        }
+
+        return encryptedMsgRetValue;
+    }
+
+    private native byte[] encryptJni(byte[] plaintext, OlmPkMessage aMessage);
+}
diff --git a/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkMessage.java b/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkMessage.java
new file mode 100644
index 0000000..f8a065a
--- /dev/null
+++ b/android/olm-sdk/src/main/java/org/matrix/olm/OlmPkMessage.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.matrix.olm;
+
+public class OlmPkMessage {
+    public String mCipherText;
+    public String mMac;
+    public String mEphemeralKey;
+}
diff --git a/android/olm-sdk/src/main/jni/Android.mk b/android/olm-sdk/src/main/jni/Android.mk
index 44a2787..0d98f69 100644
--- a/android/olm-sdk/src/main/jni/Android.mk
+++ b/android/olm-sdk/src/main/jni/Android.mk
@@ -40,6 +40,7 @@ $(SRC_ROOT_DIR)/src/pickle.cpp \
 $(SRC_ROOT_DIR)/src/ratchet.cpp \
 $(SRC_ROOT_DIR)/src/session.cpp \
 $(SRC_ROOT_DIR)/src/utility.cpp \
+$(SRC_ROOT_DIR)/src/pk.cpp \
 $(SRC_ROOT_DIR)/src/ed25519.c \
 $(SRC_ROOT_DIR)/src/error.c \
 $(SRC_ROOT_DIR)/src/inbound_group_session.c \
@@ -55,7 +56,8 @@ olm_jni_helper.cpp \
 olm_inbound_group_session.cpp \
 olm_outbound_group_session.cpp \
 olm_utility.cpp \
-olm_manager.cpp
+olm_manager.cpp \
+olm_pk.cpp
 
 LOCAL_LDLIBS := -llog
 
diff --git a/android/olm-sdk/src/main/jni/olm_jni.h b/android/olm-sdk/src/main/jni/olm_jni.h
index 6a5eb1d..80c0a97 100644
--- a/android/olm-sdk/src/main/jni/olm_jni.h
+++ b/android/olm-sdk/src/main/jni/olm_jni.h
@@ -70,6 +70,8 @@ struct OlmAccount* getAccountInstanceId(JNIEnv* aJniEnv, jobject aJavaObject);
 struct OlmInboundGroupSession* getInboundGroupSessionInstanceId(JNIEnv* aJniEnv, jobject aJavaObject);
 struct OlmOutboundGroupSession* getOutboundGroupSessionInstanceId(JNIEnv* aJniEnv, jobject aJavaObject);
 struct OlmUtility* getUtilityInstanceId(JNIEnv* aJniEnv, jobject aJavaObject);
+struct OlmPkDecryption* getPkDecryptionInstanceId(JNIEnv* aJniEnv, jobject aJavaObject);
+struct OlmPkEncryption* getPkEncryptionInstanceId(JNIEnv* aJniEnv, jobject aJavaObject);
 
 #ifdef __cplusplus
 }
diff --git a/android/olm-sdk/src/main/jni/olm_jni_helper.cpp b/android/olm-sdk/src/main/jni/olm_jni_helper.cpp
index a1f5c59..1997334 100644
--- a/android/olm-sdk/src/main/jni/olm_jni_helper.cpp
+++ b/android/olm-sdk/src/main/jni/olm_jni_helper.cpp
@@ -212,3 +212,13 @@ struct OlmUtility* getUtilityInstanceId(JNIEnv* aJniEnv, jobject aJavaObject)
 {
     return (struct OlmUtility*)getInstanceId(aJniEnv, aJavaObject, CLASS_OLM_UTILITY);
 }
+
+struct OlmPkDecryption* getPkDecryptionInstanceId(JNIEnv* aJniEnv, jobject aJavaObject)
+{
+    return (struct OlmPkDecryption*)getInstanceId(aJniEnv, aJavaObject, CLASS_OLM_PK_DECRYPTION);
+}
+
+struct OlmPkEncryption* getPkEncryptionInstanceId(JNIEnv* aJniEnv, jobject aJavaObject)
+{
+    return (struct OlmPkEncryption*)getInstanceId(aJniEnv, aJavaObject, CLASS_OLM_PK_ENCRYPTION);
+}
diff --git a/android/olm-sdk/src/main/jni/olm_jni_helper.h b/android/olm-sdk/src/main/jni/olm_jni_helper.h
index a181b32..9a23532 100644
--- a/android/olm-sdk/src/main/jni/olm_jni_helper.h
+++ b/android/olm-sdk/src/main/jni/olm_jni_helper.h
@@ -25,4 +25,6 @@ namespace AndroidOlmSdk
     static const char *CLASS_OLM_SESSION = "org/matrix/olm/OlmSession";
     static const char *CLASS_OLM_ACCOUNT = "org/matrix/olm/OlmAccount";
     static const char *CLASS_OLM_UTILITY = "org/matrix/olm/OlmUtility";
+    static const char *CLASS_OLM_PK_ENCRYPTION = "org/matrix/olm/OlmPkEncryption";
+    static const char *CLASS_OLM_PK_DECRYPTION = "org/matrix/olm/OlmPkDecryption";
 }
diff --git a/android/olm-sdk/src/main/jni/olm_pk.cpp b/android/olm-sdk/src/main/jni/olm_pk.cpp
new file mode 100644
index 0000000..2e936c6
--- /dev/null
+++ b/android/olm-sdk/src/main/jni/olm_pk.cpp
@@ -0,0 +1,564 @@
+/*
+ * Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "olm_pk.h"
+
+#include "olm/olm.h"
+
+using namespace AndroidOlmSdk;
+
+OlmPkEncryption * initializePkEncryptionMemory()
+{
+    size_t encryptionSize = olm_pk_encryption_size();
+    OlmPkEncryption *encryptionPtr = (OlmPkEncryption *)malloc(encryptionSize);
+
+    if (encryptionPtr)
+    {
+        // init encryption object
+        encryptionPtr = olm_pk_encryption(encryptionPtr);
+        LOGD("## initializePkEncryptionMemory(): success - OLM encryption size=%lu",static_cast<long unsigned int>(encryptionSize));
+    }
+    else
+    {
+        LOGE("## initializePkEncryptionMemory(): failure - OOM");
+    }
+
+    return encryptionPtr;
+}
+
+JNIEXPORT jlong OLM_PK_ENCRYPTION_FUNC_DEF(createNewPkEncryptionJni)(JNIEnv *env, jobject thiz)
+{
+    const char* errorMessage = NULL;
+    OlmPkEncryption *encryptionPtr = initializePkEncryptionMemory();
+
+    // init encryption memory allocation
+    if (!encryptionPtr)
+    {
+        LOGE("## createNewPkEncryptionJni(): failure - init encryption OOM");
+        errorMessage = "init encryption OOM";
+    }
+    else
+    {
+        LOGD("## createNewPkEncryptionJni(): success - OLM encryption created");
+        LOGD("## createNewPkEncryptionJni(): encryptionPtr=%p (jlong)(intptr_t)encryptionPtr=%lld", encryptionPtr, (jlong)(intptr_t)encryptionPtr);
+    }
+
+    if (errorMessage)
+    {
+        // release the allocated data
+        if (encryptionPtr)
+        {
+            olm_clear_pk_encryption(encryptionPtr);
+            free(encryptionPtr);
+        }
+        env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage);
+    }
+
+    return (jlong)(intptr_t)encryptionPtr;
+}
+
+JNIEXPORT void OLM_PK_ENCRYPTION_FUNC_DEF(releasePkEncryptionJni)(JNIEnv *env, jobject thiz)
+{
+    LOGD("## releasePkEncryptionJni(): IN");
+
+    OlmPkEncryption* encryptionPtr = getPkEncryptionInstanceId(env, thiz);
+
+    if (!encryptionPtr)
+    {
+        LOGE(" ## releasePkEncryptionJni(): failure - invalid Encryption ptr=NULL");
+    }
+    else
+    {
+        LOGD(" ## releasePkEncryptionJni(): encryptionPtr=%p", encryptionPtr);
+        olm_clear_pk_encryption(encryptionPtr);
+
+        LOGD(" ## releasePkEncryptionJni(): IN");
+        // even if free(NULL) does not crash, logs are performed for debug
+        // purpose
+        free(encryptionPtr);
+        LOGD(" ## releasePkEncryptionJni(): OUT");
+    }
+}
+
+JNIEXPORT void OLM_PK_ENCRYPTION_FUNC_DEF(setRecipientKeyJni)(JNIEnv *env, jobject thiz, jbyteArray aKeyBuffer)
+{
+    const char *errorMessage = NULL;
+    jbyte *keyPtr = NULL;
+
+    OlmPkEncryption *encryptionPtr = getPkEncryptionInstanceId(env, thiz);
+
+    if (!encryptionPtr)
+    {
+        LOGE(" ## pkSetRecipientKeyJni(): failure - invalid Encryption ptr=NULL");
+    }
+    else if (!aKeyBuffer)
+    {
+        LOGE(" ## pkSetRecipientKeyJni(): failure - invalid key");
+        errorMessage = "invalid key";
+    }
+    else if (!(keyPtr = env->GetByteArrayElements(aKeyBuffer, 0)))
+    {
+        LOGE(" ## pkSetRecipientKeyJni(): failure - key JNI allocation OOM");
+        errorMessage = "key JNI allocation OOM";
+    }
+    else
+    {
+        if(olm_pk_encryption_set_recipient_key(encryptionPtr, keyPtr, (size_t)env->GetArrayLength(aKeyBuffer)) == olm_error())
+        {
+            errorMessage = olm_pk_encryption_last_error(encryptionPtr);
+            LOGE(" ## pkSetRecipientKeyJni(): failure - olm_pk_encryption_set_recipient_key Msg=%s", errorMessage);
+        }
+    }
+
+    if (keyPtr)
+    {
+        env->ReleaseByteArrayElements(aKeyBuffer, keyPtr, JNI_ABORT);
+    }
+
+    if (errorMessage)
+    {
+        env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage);
+    }
+}
+
+JNIEXPORT jbyteArray OLM_PK_ENCRYPTION_FUNC_DEF(encryptJni)(JNIEnv *env, jobject thiz, jbyteArray aPlaintextBuffer, jobject aEncryptedMsg)
+{
+    jbyteArray encryptedMsgRet = 0;
+    const char* errorMessage = NULL;
+    jbyte *plaintextPtr = NULL;
+
+    OlmPkEncryption *encryptionPtr = getPkEncryptionInstanceId(env, thiz);
+    jclass encryptedMsgJClass = 0;
+    jfieldID macFieldId;
+    jfieldID ephemeralFieldId;
+
+    if (!encryptionPtr)
+    {
+        LOGE(" ## pkEncryptJni(): failure - invalid Encryption ptr=NULL");
+    }
+    else if (!aPlaintextBuffer)
+    {
+        LOGE(" ## pkEncryptJni(): failure - invalid clear message");
+        errorMessage = "invalid clear message";
+    }
+    else if (!(plaintextPtr = env->GetByteArrayElements(aPlaintextBuffer, 0)))
+    {
+        LOGE(" ## pkEncryptJni(): failure - plaintext JNI allocation OOM");
+        errorMessage = "plaintext JNI allocation OOM";
+    }
+    else if (!(encryptedMsgJClass = env->GetObjectClass(aEncryptedMsg)))
+    {
+        LOGE(" ## pkEncryptJni(): failure - unable to get crypted message class");
+        errorMessage = "unable to get crypted message class";
+    }
+    else if (!(macFieldId = env->GetFieldID(encryptedMsgJClass, "mMac", "Ljava/lang/String;")))
+    {
+        LOGE("## pkEncryptJni(): failure - unable to get MAC field");
+        errorMessage = "unable to get MAC field";
+    }
+    else if (!(ephemeralFieldId = env->GetFieldID(encryptedMsgJClass, "mEphemeralKey", "Ljava/lang/String;")))
+    {
+        LOGE("## pkEncryptJni(): failure - unable to get ephemeral key field");
+        errorMessage = "unable to get ephemeral key field";
+    }
+    else
+    {
+        size_t plaintextLength = (size_t)env->GetArrayLength(aPlaintextBuffer);
+        size_t ciphertextLength = olm_pk_ciphertext_length(encryptionPtr, plaintextLength);
+        size_t macLength = olm_pk_mac_length(encryptionPtr);
+        size_t ephemeralLength = olm_pk_key_length();
+        uint8_t *ciphertextPtr = NULL, *macPtr = NULL, *ephemeralPtr = NULL;
+        size_t randomLength = olm_pk_encrypt_random_length(encryptionPtr);
+        uint8_t *randomBuffPtr = NULL;
+        LOGD("## pkEncryptJni(): randomLength=%lu",static_cast<long unsigned int>(randomLength));
+        if (!(ciphertextPtr = (uint8_t*)malloc(ciphertextLength)))
+        {
+            LOGE("## pkEncryptJni(): failure - ciphertext JNI allocation OOM");
+            errorMessage = "ciphertext JNI allocation OOM";
+        }
+        else if (!(macPtr = (uint8_t*)malloc(macLength + 1)))
+        {
+            LOGE("## pkEncryptJni(): failure - MAC JNI allocation OOM");
+            errorMessage = "MAC JNI allocation OOM";
+        }
+        else if (!(ephemeralPtr = (uint8_t*)malloc(ephemeralLength + 1)))
+        {
+            LOGE("## pkEncryptJni(): failure: ephemeral key JNI allocation OOM");
+            errorMessage = "ephemeral JNI allocation OOM";
+        }
+        else if (!setRandomInBuffer(env, &randomBuffPtr, randomLength))
+        {
+            LOGE("## pkEncryptJni(): failure - random buffer init");
+            errorMessage = "random buffer init";
+        }
+        else
+        {
+            macPtr[macLength] = '\0';
+            ephemeralPtr[ephemeralLength] = '\0';
+
+            size_t returnValue = olm_pk_encrypt(
+                encryptionPtr,
+                plaintextPtr, plaintextLength,
+                ciphertextPtr, ciphertextLength,
+                macPtr, macLength,
+                ephemeralPtr, ephemeralLength,
+                randomBuffPtr, randomLength
+            );
+
+            if (returnValue == olm_error())
+            {
+                errorMessage = olm_pk_encryption_last_error(encryptionPtr);
+                LOGE("## pkEncryptJni(): failure - olm_pk_encrypt Msg=%s", errorMessage);
+            }
+            else
+            {
+                encryptedMsgRet = env->NewByteArray(ciphertextLength);
+                env->SetByteArrayRegion(encryptedMsgRet, 0, ciphertextLength, (jbyte*)ciphertextPtr);
+
+                jstring macStr = env->NewStringUTF((char*)macPtr);
+                env->SetObjectField(aEncryptedMsg, macFieldId, macStr);
+                jstring ephemeralStr = env->NewStringUTF((char*)ephemeralPtr);
+                env->SetObjectField(aEncryptedMsg, ephemeralFieldId, ephemeralStr);
+            }
+        }
+
+        if (randomBuffPtr)
+        {
+            memset(randomBuffPtr, 0, randomLength);
+            free(randomBuffPtr);
+        }
+        if (ephemeralPtr)
+        {
+            free(ephemeralPtr);
+        }
+        if (macPtr)
+        {
+            free(macPtr);
+        }
+        if (ciphertextPtr)
+        {
+            free(ciphertextPtr);
+        }
+    }
+
+    if (plaintextPtr)
+    {
+        env->ReleaseByteArrayElements(aPlaintextBuffer, plaintextPtr, JNI_ABORT);
+    }
+
+    if (errorMessage)
+    {
+        env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage);
+    }
+
+    return encryptedMsgRet;
+}
+
+OlmPkDecryption * initializePkDecryptionMemory()
+{
+    size_t decryptionSize = olm_pk_decryption_size();
+    OlmPkDecryption *decryptionPtr = (OlmPkDecryption *)malloc(decryptionSize);
+
+    if (decryptionPtr)
+    {
+        // init decryption object
+        decryptionPtr = olm_pk_decryption(decryptionPtr);
+        LOGD("## initializePkDecryptionMemory(): success - OLM decryption size=%lu",static_cast<long unsigned int>(decryptionSize));
+    }
+    else
+    {
+        LOGE("## initializePkDecryptionMemory(): failure - OOM");
+    }
+
+    return decryptionPtr;
+}
+
+JNIEXPORT jlong OLM_PK_DECRYPTION_FUNC_DEF(createNewPkDecryptionJni)(JNIEnv *env, jobject thiz)
+{
+    const char* errorMessage = NULL;
+    OlmPkDecryption *decryptionPtr = initializePkDecryptionMemory();
+
+    // init encryption memory allocation
+    if (!decryptionPtr)
+    {
+        LOGE("## createNewPkDecryptionJni(): failure - init decryption OOM");
+        errorMessage = "init decryption OOM";
+    }
+    else
+    {
+        LOGD("## createNewPkDecryptionJni(): success - OLM decryption created");
+        LOGD("## createNewPkDecryptionJni(): decryptionPtr=%p (jlong)(intptr_t)decryptionPtr=%lld", decryptionPtr, (jlong)(intptr_t)decryptionPtr);
+    }
+
+    if (errorMessage)
+    {
+        // release the allocated data
+        if (decryptionPtr)
+        {
+            olm_clear_pk_decryption(decryptionPtr);
+            free(decryptionPtr);
+        }
+        env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage);
+    }
+
+    return (jlong)(intptr_t)decryptionPtr;
+}
+
+JNIEXPORT void OLM_PK_DECRYPTION_FUNC_DEF(releasePkDecryptionJni)(JNIEnv *env, jobject thiz)
+{
+    LOGD("## releasePkDecryptionJni(): IN");
+
+    OlmPkDecryption* decryptionPtr = getPkDecryptionInstanceId(env, thiz);
+
+    if (!decryptionPtr)
+    {
+        LOGE(" ## releasePkDecryptionJni(): failure - invalid Decryption ptr=NULL");
+    }
+    else
+    {
+        LOGD(" ## releasePkDecryptionJni(): decryptionPtr=%p", encryptionPtr);
+        olm_clear_pk_decryption(decryptionPtr);
+
+        LOGD(" ## releasePkDecryptionJni(): IN");
+        // even if free(NULL) does not crash, logs are performed for debug
+        // purpose
+        free(decryptionPtr);
+        LOGD(" ## releasePkDecryptionJni(): OUT");
+    }
+}
+
+JNIEXPORT jbyteArray OLM_PK_DECRYPTION_FUNC_DEF(generateKeyJni)(JNIEnv *env, jobject thiz)
+{
+    size_t randomLength = olm_pk_generate_key_random_length();
+    uint8_t *randomBuffPtr = NULL;
+
+    jbyteArray publicKeyRet = 0;
+    uint8_t *publicKeyPtr = NULL;
+    size_t publicKeyLength = olm_pk_key_length();
+    const char* errorMessage = NULL;
+
+    OlmPkDecryption *decryptionPtr = getPkDecryptionInstanceId(env, thiz);
+
+    if (!decryptionPtr)
+    {
+        LOGE(" ## pkGenerateKeyJni(): failure - invalid Decryption ptr=NULL");
+        errorMessage = "invalid Decryption ptr=NULL";
+    }
+    else if (!setRandomInBuffer(env, &randomBuffPtr, randomLength))
+    {
+        LOGE("## pkGenerateKeyJni(): failure - random buffer init");
+        errorMessage = "random buffer init";
+    }
+    else if (!(publicKeyPtr = static_cast<uint8_t*>(malloc(publicKeyLength))))
+    {
+        LOGE("## pkGenerateKeyJni(): failure - public key allocation OOM");
+        errorMessage = "public key allocation OOM";
+    }
+    else
+    {
+        if (olm_pk_generate_key(decryptionPtr, publicKeyPtr, publicKeyLength, randomBuffPtr, randomLength) == olm_error())
+        {
+            errorMessage = olm_pk_decryption_last_error(decryptionPtr);
+            LOGE("## pkGenerateKeyJni(): failure - olm_pk_generate_key Msg=%s", errorMessage);
+        }
+        else
+        {
+            publicKeyRet = env->NewByteArray(publicKeyLength);
+            env->SetByteArrayRegion(publicKeyRet, 0, publicKeyLength, (jbyte*)publicKeyPtr);
+            LOGD("## pkGenerateKeyJni(): public key generated");
+        }
+    }
+
+    if (randomBuffPtr)
+    {
+        memset(randomBuffPtr, 0, randomLength);
+        free(randomBuffPtr);
+    }
+
+    if (errorMessage)
+    {
+        // release the allocated data
+        if (decryptionPtr)
+        {
+            olm_clear_pk_decryption(decryptionPtr);
+            free(decryptionPtr);
+        }
+        env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage);
+    }
+
+    return publicKeyRet;
+}
+
+JNIEXPORT jbyteArray OLM_PK_DECRYPTION_FUNC_DEF(decryptJni)(JNIEnv *env, jobject thiz, jobject aEncryptedMsg)
+{
+    const char* errorMessage = NULL;
+    OlmPkDecryption *decryptionPtr = getPkDecryptionInstanceId(env, thiz);
+
+    jclass encryptedMsgJClass = 0;
+    jstring ciphertextJstring = 0;
+    jstring macJstring = 0;
+    jstring ephemeralKeyJstring = 0;
+    jfieldID ciphertextFieldId;
+    jfieldID macFieldId;
+    jfieldID ephemeralKeyFieldId;
+
+    const char *ciphertextPtr = NULL;
+    const char *macPtr = NULL;
+    const char *ephemeralKeyPtr = NULL;
+
+    jbyteArray decryptedMsgRet = 0;
+
+    if (!decryptionPtr)
+    {
+        LOGE(" ## pkDecryptJni(): failure - invalid Decryption ptr=NULL");
+        errorMessage = "invalid Decryption ptr=NULL";
+    }
+    else if (!aEncryptedMsg)
+    {
+        LOGE(" ## pkDecryptJni(): failure - invalid encrypted message");
+        errorMessage = "invalid encrypted message";
+    }
+    else if (!(encryptedMsgJClass = env->GetObjectClass(aEncryptedMsg)))
+    {
+        LOGE("## pkDecryptJni(): failure - unable to get encrypted message class");
+        errorMessage = "unable to get encrypted message class";
+    }
+    else if (!(ciphertextFieldId = env->GetFieldID(encryptedMsgJClass,"mCipherText","Ljava/lang/String;")))
+    {
+        LOGE("## pkDecryptJni(): failure - unable to get message field");
+        errorMessage = "unable to get message field";
+    }
+    else if (!(ciphertextJstring = (jstring)env->GetObjectField(aEncryptedMsg, ciphertextFieldId)))
+    {
+        LOGE("## pkDecryptJni(): failure - no ciphertext");
+        errorMessage = "no ciphertext";
+    }
+    else if (!(ciphertextPtr = env->GetStringUTFChars(ciphertextJstring, 0)))
+    {
+        LOGE("## pkDecryptJni(): failure - ciphertext JNI allocation OOM");
+        errorMessage = "ciphertext JNI allocation OOM";
+    }
+    else if (!(ciphertextJstring = (jstring)env->GetObjectField(aEncryptedMsg, ciphertextFieldId)))
+    {
+        LOGE("## pkDecryptJni(): failure - no ciphertext");
+        errorMessage = "no ciphertext";
+    }
+    else if (!(ciphertextPtr = env->GetStringUTFChars(ciphertextJstring, 0)))
+    {
+        LOGE("## decryptMessageJni(): failure - ciphertext JNI allocation OOM");
+        errorMessage = "ciphertext JNI allocation OOM";
+    }
+    else if (!(macFieldId = env->GetFieldID(encryptedMsgJClass,"mMac","Ljava/lang/String;")))
+    {
+        LOGE("## pkDecryptJni(): failure - unable to get MAC field");
+        errorMessage = "unable to get MAC field";
+    }
+    else if (!(macJstring = (jstring)env->GetObjectField(aEncryptedMsg, macFieldId)))
+    {
+        LOGE("## pkDecryptJni(): failure - no MAC");
+        errorMessage = "no MAC";
+    }
+    else if (!(macPtr = env->GetStringUTFChars(macJstring, 0)))
+    {
+        LOGE("## pkDecryptJni(): failure - MAC JNI allocation OOM");
+        errorMessage = "ciphertext JNI allocation OOM";
+    }
+    else if (!(ephemeralKeyFieldId = env->GetFieldID(encryptedMsgJClass,"mEphemeralKey","Ljava/lang/String;")))
+    {
+        LOGE("## pkDecryptJni(): failure - unable to get ephemeral key field");
+        errorMessage = "unable to get ephemeral key field";
+    }
+    else if (!(ephemeralKeyJstring = (jstring)env->GetObjectField(aEncryptedMsg, ephemeralKeyFieldId)))
+    {
+        LOGE("## pkDecryptJni(): failure - no ephemeral key");
+        errorMessage = "no ephemeral key";
+    }
+    else if (!(ephemeralKeyPtr = env->GetStringUTFChars(ephemeralKeyJstring, 0)))
+    {
+        LOGE("## pkDecryptJni(): failure - ephemeral key JNI allocation OOM");
+        errorMessage = "ephemeral key JNI allocation OOM";
+    }
+    else
+    {
+        size_t maxPlaintextLength = olm_pk_max_plaintext_length(
+            decryptionPtr,
+            (size_t)env->GetStringUTFLength(ciphertextJstring)
+        );
+        uint8_t *plaintextPtr = NULL;
+        uint8_t *tempCiphertextPtr = NULL;
+        size_t ciphertextLength = (size_t)env->GetStringUTFLength(ciphertextJstring);
+        if (!(plaintextPtr = (uint8_t*)malloc(maxPlaintextLength)))
+        {
+            LOGE("## pkDecryptJni(): failure - plaintext JNI allocation OOM");
+            errorMessage = "plaintext JNI allocation OOM";
+        }
+        else if (!(tempCiphertextPtr = (uint8_t*)malloc(ciphertextLength)))
+        {
+            LOGE("## pkDecryptJni(): failure - temp ciphertext JNI allocation OOM");
+        }
+        else
+        {
+            memcpy(tempCiphertextPtr, ciphertextPtr, ciphertextLength);
+            size_t plaintextLength = olm_pk_decrypt(
+                decryptionPtr,
+                ephemeralKeyPtr, (size_t)env->GetStringUTFLength(ephemeralKeyJstring),
+                macPtr, (size_t)env->GetStringUTFLength(macJstring),
+                tempCiphertextPtr, ciphertextLength,
+                plaintextPtr, maxPlaintextLength
+            );
+            if (plaintextLength == olm_error())
+            {
+                errorMessage = olm_pk_decryption_last_error(decryptionPtr);
+                LOGE("## pkDecryptJni(): failure - olm_pk_decrypt Msg=%s", errorMessage);
+            }
+            else
+            {
+                decryptedMsgRet = env->NewByteArray(plaintextLength);
+                env->SetByteArrayRegion(decryptedMsgRet, 0, plaintextLength, (jbyte*)plaintextPtr);
+                LOGD("## pkDecryptJni(): success returnedLg=%lu OK", static_cast<long unsigned int>(plaintextLength));
+            }
+        }
+
+        if (tempCiphertextPtr)
+        {
+          free(tempCiphertextPtr);
+        }
+        if (plaintextPtr)
+        {
+          free(plaintextPtr);
+        }
+    }
+
+    if (ciphertextPtr)
+    {
+        env->ReleaseStringUTFChars(ciphertextJstring, ciphertextPtr);
+    }
+    if (macPtr)
+    {
+        env->ReleaseStringUTFChars(macJstring, macPtr);
+    }
+    if (ephemeralKeyPtr)
+    {
+        env->ReleaseStringUTFChars(ephemeralKeyJstring, ephemeralKeyPtr);
+    }
+
+    if (errorMessage)
+    {
+        env->ThrowNew(env->FindClass("java/lang/Exception"), errorMessage);
+    }
+
+    return decryptedMsgRet;
+}
diff --git a/android/olm-sdk/src/main/jni/olm_pk.h b/android/olm-sdk/src/main/jni/olm_pk.h
new file mode 100644
index 0000000..984c5f8
--- /dev/null
+++ b/android/olm-sdk/src/main/jni/olm_pk.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _OMLPK_H
+#define _OMLPK_H
+
+#include "olm_jni.h"
+#include "olm/pk.h"
+
+#define OLM_PK_ENCRYPTION_FUNC_DEF(func_name) FUNC_DEF(OlmPkEncryption,func_name)
+#define OLM_PK_DECRYPTION_FUNC_DEF(func_name) FUNC_DEF(OlmPkDecryption,func_name)
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+JNIEXPORT jlong OLM_PK_ENCRYPTION_FUNC_DEF(createNewPkEncryptionJni)(JNIEnv *env, jobject thiz);
+JNIEXPORT void OLM_PK_ENCRYPTION_FUNC_DEF(releasePkEncryptionJni)(JNIEnv *env, jobject thiz);
+JNIEXPORT void OLM_PK_ENCRYPTION_FUNC_DEF(setRecipientKeyJni)(JNIEnv *env, jobject thiz, jbyteArray aKeyBuffer);
+
+JNIEXPORT jbyteArray OLM_PK_ENCRYPTION_FUNC_DEF(encryptJni)(JNIEnv *env, jobject thiz, jbyteArray aPlaintextBuffer, jobject aEncryptedMsg);
+
+JNIEXPORT jlong OLM_PK_DECRYPTION_FUNC_DEF(createNewPkDecryptionJni)(JNIEnv *env, jobject thiz);
+JNIEXPORT void OLM_PK_DECRYPTION_FUNC_DEF(releasePkDecryptionJni)(JNIEnv *env, jobject thiz);
+JNIEXPORT jbyteArray OLM_PK_DECRYPTION_FUNC_DEF(generateKeyJni)(JNIEnv *env, jobject thiz);
+JNIEXPORT jbyteArray OLM_PK_DECRYPTION_FUNC_DEF(decryptJni)(JNIEnv *env, jobject thiz, jobject aEncryptedMsg);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
-- 
GitLab