From 0e988237f6fcb826afc42719adc335dcc7ca0e2e Mon Sep 17 00:00:00 2001
From: Mark Haines <mark.haines@matrix.org>
Date: Wed, 8 Jul 2015 16:00:08 +0100
Subject: [PATCH] Don't pass a key id when creating a new outbound session

---
 include/olm/account.hh |  5 +++--
 include/olm/session.hh |  1 -
 javascript/build.py    |  1 +
 javascript/demo.html   | 18 +++++++++++-------
 javascript/olm_post.js |  3 +--
 olm.py                 | 20 +++++++-------------
 src/account.cpp        |  7 ++++---
 src/olm.cpp            |  2 +-
 src/session.cpp        |  7 +------
 9 files changed, 29 insertions(+), 35 deletions(-)

diff --git a/include/olm/account.hh b/include/olm/account.hh
index 552f069..cf886d1 100644
--- a/include/olm/account.hh
+++ b/include/olm/account.hh
@@ -112,13 +112,14 @@ struct Account {
         std::uint8_t * one_time_json, std::size_t one_time_json_length
     );
 
-    /** Lookup a one_time key with the given key-id */
+    /** Lookup a one time key with the given public key */
     OneTimeKey const * lookup_key(
         Curve25519PublicKey const & public_key
     );
 
+    /** Remove a one time key with the given public key */
     std::size_t remove_key(
-        std::uint32_t id
+        Curve25519PublicKey const & public_key
     );
 };
 
diff --git a/include/olm/session.hh b/include/olm/session.hh
index 1c04108..125df68 100644
--- a/include/olm/session.hh
+++ b/include/olm/session.hh
@@ -38,7 +38,6 @@ struct Session {
     Curve25519PublicKey alice_identity_key;
     Curve25519PublicKey alice_base_key;
     Curve25519PublicKey bob_one_time_key;
-    std::uint32_t bob_one_time_key_id;
 
     std::size_t new_outbound_session_random_length();
 
diff --git a/javascript/build.py b/javascript/build.py
index 68b7e45..9766906 100755
--- a/javascript/build.py
+++ b/javascript/build.py
@@ -59,6 +59,7 @@ compile_args += source_files
 compile_args += ("--pre-js", pre_js)
 compile_args += ("--post-js", post_js)
 compile_args += ("-s", "EXPORTED_FUNCTIONS=@" + exported_functions)
+compile_args += sys.argv[1:]
 
 library = "build/olm.js"
 
diff --git a/javascript/demo.html b/javascript/demo.html
index 5a32e96..c9cad8b 100644
--- a/javascript/demo.html
+++ b/javascript/demo.html
@@ -30,16 +30,20 @@ document.addEventListener("DOMContentLoaded", function (event) {
     tasks.push(["bob", "Creating account", function() { bob.create() }]);
     tasks.push(["alice", "Create outbound session", function() {
         var bobs_id_keys = JSON.parse(bob.identity_keys("bob", "bob_device", 0, 0));
-        var bobs_curve25519_key;
+        var bobs_id_key;
         for (key in bobs_id_keys.keys) {
             if (key.startsWith("curve25519:")) {
-                bobs_curve25519_key = bobs_id_keys.keys[key];
+                bobs_id_key = bobs_id_keys.keys[key];
             }
         }
-        var bobs_keys_2 = JSON.parse(bob.one_time_keys())[1];
-        a_session.create_outbound(
-            alice, bobs_curve25519_key, bobs_keys_2[0], bobs_keys_2[1]
-        );
+        var bobs_ot_keys = JSON.parse(bob.one_time_keys());
+        var bobs_ot_key;
+        for (key in bobs_ot_keys) {
+            if (key.startsWith("curve25519:")) {
+                bobs_ot_key = bobs_ot_keys[key];
+            }
+        }
+        a_session.create_outbound(alice, bobs_id_key, bobs_ot_key);
     }]);
     tasks.push(["alice", "Encrypt first message", function() {
         message_1 = a_session.encrypt("");
@@ -96,7 +100,7 @@ document.addEventListener("DOMContentLoaded", function (event) {
             window.setTimeout(function() {
                 task[2]();
                 p.done();
-                window.setTimeout(do_tasks, 0, next);
+                window.setTimeout(do_tasks, 50, next);
             }, 0)
         } else {
             next();
diff --git a/javascript/olm_post.js b/javascript/olm_post.js
index 0494460..7bcc580 100644
--- a/javascript/olm_post.js
+++ b/javascript/olm_post.js
@@ -172,7 +172,7 @@ Session.prototype['unpickle'] = restore_stack(function(key, pickle) {
 });
 
 Session.prototype['create_outbound'] = restore_stack(function(
-    account, their_identity_key, their_one_time_key_id, their_one_time_key
+    account, their_identity_key, their_one_time_key
 ) {
     var random_length = session_method(
         Module['_olm_create_outbound_session_random_length']
@@ -185,7 +185,6 @@ Session.prototype['create_outbound'] = restore_stack(function(
     session_method(Module['_olm_create_outbound_session'])(
         this.ptr, account.ptr,
         identity_key_buffer, identity_key_array.length,
-        their_one_time_key_id,
         one_time_key_buffer, one_time_key_array.length,
         random, random_length
     );
diff --git a/olm.py b/olm.py
index 0666a93..81f9e25 100755
--- a/olm.py
+++ b/olm.py
@@ -1,8 +1,11 @@
 #! /usr/bin/python
 from ctypes import *
 import json
+import os
 
-lib = cdll.LoadLibrary("build/libolm.so")
+lib = cdll.LoadLibrary(os.path.join(
+    os.path.dirname(__file__), "build", "libolm.so")
+)
 
 
 lib.olm_error.argtypes = []
@@ -149,7 +152,6 @@ session_function(
     lib.olm_create_outbound_session,
     c_void_p,  # Account
     c_void_p, c_size_t,  # Identity Key
-    c_uint,  # One Time Key Id
     c_void_p, c_size_t,  # One Time Key
     c_void_p, c_size_t,  # Random
 )
@@ -201,8 +203,7 @@ class Session(object):
             self.ptr, key_buffer, len(key), pickle_buffer, len(pickle)
         )
 
-    def create_outbound(self, account, identity_key, one_time_key_id,
-                        one_time_key):
+    def create_outbound(self, account, identity_key, one_time_key):
         r_length = lib.olm_create_outbound_session_random_length(self.ptr)
         random = read_random(r_length)
         random_buffer = create_string_buffer(random)
@@ -212,7 +213,6 @@ class Session(object):
             self.ptr,
             account.ptr,
             identity_key_buffer, len(identity_key),
-            one_time_key_id,
             one_time_key_buffer, len(one_time_key),
             random_buffer, r_length
         )
@@ -325,11 +325,6 @@ if __name__ == '__main__':
     outbound.add_argument("account_file", help="Local account file")
     outbound.add_argument("session_file", help="Local session file")
     outbound.add_argument("identity_key", help="Remote identity key")
-    outbound.add_argument("signed_key_id", help="Remote signed key id",
-                          type=int)
-    outbound.add_argument("signed_key", help="Remote signed key")
-    outbound.add_argument("one_time_key_id", help="Remote one time key id",
-                          type=int)
     outbound.add_argument("one_time_key", help="Remote one time key")
 
     def do_outbound(args):
@@ -343,8 +338,7 @@ if __name__ == '__main__':
             account.unpickle(args.key, f.read())
         session = Session()
         session.create_outbound(
-            account, args.identity_key, args.signed_key_id, args.signed_key,
-            args.one_time_key_id, args.one_time_key
+            account, args.identity_key, args.one_time_key
         )
         with open(args.session_file, "wb") as f:
             f.write(session.pickle(args.key))
@@ -416,8 +410,8 @@ if __name__ == '__main__':
 
     decrypt = commands.add_parser("decrypt", help="Decrypt a message")
     decrypt.add_argument("session_file", help="Local session file")
-    decrypt.add_argument("plaintext_file", help="Plaintext", default="-")
     decrypt.add_argument("message_file", help="Message", default="-")
+    decrypt.add_argument("plaintext_file", help="Plaintext", default="-")
 
     def do_decrypt(args):
         session = Session()
diff --git a/src/account.cpp b/src/account.cpp
index a171f5c..5bbd6a6 100644
--- a/src/account.cpp
+++ b/src/account.cpp
@@ -29,11 +29,12 @@ olm::OneTimeKey const * olm::Account::lookup_key(
 }
 
 std::size_t olm::Account::remove_key(
-    std::uint32_t id
+    olm::Curve25519PublicKey const & public_key
 ) {
     OneTimeKey * i;
     for (i = one_time_keys.begin(); i != one_time_keys.end(); ++i) {
-        if (i->id == id) {
+        if (0 == memcmp(i->key.public_key, public_key.public_key, 32)) {
+            std::uint32_t id = i->id;
             one_time_keys.erase(i);
             return id;
         }
@@ -42,7 +43,7 @@ std::size_t olm::Account::remove_key(
 }
 
 std::size_t olm::Account::new_account_random_length() {
-    return 103 * 32;
+    return 12 * 32;
 }
 
 std::size_t olm::Account::new_account(
diff --git a/src/olm.cpp b/src/olm.cpp
index ede9c26..65d0648 100644
--- a/src/olm.cpp
+++ b/src/olm.cpp
@@ -447,7 +447,7 @@ size_t olm_remove_one_time_keys(
     OlmSession * session
 ) {
     size_t result = from_c(account)->remove_key(
-        from_c(session)->bob_one_time_key_id
+        from_c(session)->bob_one_time_key
     );
     if (result == std::size_t(-1)) {
         from_c(account)->last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
diff --git a/src/session.cpp b/src/session.cpp
index f3b7637..4abf6cf 100644
--- a/src/session.cpp
+++ b/src/session.cpp
@@ -45,8 +45,7 @@ static const olm::KdfInfo OLM_KDF_INFO = {
 olm::Session::Session(
 ) : ratchet(OLM_KDF_INFO, OLM_CIPHER),
     last_error(olm::ErrorCode::SUCCESS),
-    received_message(false),
-    bob_one_time_key_id(0) {
+    received_message(false) {
 
 }
 
@@ -157,7 +156,6 @@ std::size_t olm::Session::new_inbound_session(
         last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
         return std::size_t(-1);
     }
-    bob_one_time_key_id = our_one_time_key->id;
 
     std::uint8_t shared_secret[96];
 
@@ -364,7 +362,6 @@ std::size_t olm::pickle_length(
     length += olm::pickle_length(value.alice_identity_key);
     length += olm::pickle_length(value.alice_base_key);
     length += olm::pickle_length(value.bob_one_time_key);
-    length += olm::pickle_length(value.bob_one_time_key_id);
     length += olm::pickle_length(value.ratchet);
     return length;
 }
@@ -378,7 +375,6 @@ std::uint8_t * olm::pickle(
     pos = olm::pickle(pos, value.alice_identity_key);
     pos = olm::pickle(pos, value.alice_base_key);
     pos = olm::pickle(pos, value.bob_one_time_key);
-    pos = olm::pickle(pos, value.bob_one_time_key_id);
     pos = olm::pickle(pos, value.ratchet);
     return pos;
 }
@@ -392,7 +388,6 @@ std::uint8_t const * olm::unpickle(
     pos = olm::unpickle(pos, end, value.alice_identity_key);
     pos = olm::unpickle(pos, end, value.alice_base_key);
     pos = olm::unpickle(pos, end, value.bob_one_time_key);
-    pos = olm::unpickle(pos, end, value.bob_one_time_key_id);
     pos = olm::unpickle(pos, end, value.ratchet);
     return pos;
 }
-- 
GitLab