Newer
Older
/* Copyright 2015-2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "olm/message.hh"
#include "olm/memory.hh"
static std::size_t varint_length(
++result;
value >>= 7;
}
return result;
}
template<typename T>
static std::uint8_t * varint_encode(
}
(*output++) = value;
return output;
}
template<typename T>
/**
 * Decode the 7-bit groups stored in [varint_start, varint_end) into a T.
 *
 * Bytes are consumed from the last one back to the first, so the first
 * byte of the range supplies the least-significant seven bits. The high
 * (continuation) bit of every byte is masked off. An empty range yields 0.
 */
static T varint_decode(
    std::uint8_t const * varint_start,
    std::uint8_t const * varint_end
) {
    T result = 0;
    for (std::uint8_t const * cursor = varint_end; cursor != varint_start;) {
        --cursor;
        result <<= 7;
        result |= 0x7F & *cursor;
    }
    return result;
}
/**
 * Advance past one varint starting at `input`.
 *
 * Returns a pointer just after the first byte whose high bit (0x80) is
 * clear. If no such byte occurs before `input_end`, returns `input_end`.
 */
static std::uint8_t const * varint_skip(
    std::uint8_t const * input,
    std::uint8_t const * input_end
) {
    for (; input != input_end; ++input) {
        if (!(*input & 0x80)) {
            // Last byte of the varint: step over it and stop.
            return input + 1;
        }
    }
    return input;
}
/**
 * Number of bytes needed for a length-prefixed string of `string_length`
 * bytes: the varint-encoded length followed by the payload itself.
 */
static std::size_t varstring_length(
    std::size_t string_length
) {
    std::size_t const prefix_length = varint_length(string_length);
    return prefix_length + string_length;
}
// Every encoded message starts with a single version byte.
static constexpr std::size_t VERSION_LENGTH = 1;
// Field tags: field number in the upper bits, wire type in the low three
// bits (skip_unknown treats wire type 0 as a varint, 2 as length-prefixed).
static constexpr std::uint8_t RATCHET_KEY_TAG = (1 << 3) | 2;  // 012
static constexpr std::uint8_t COUNTER_TAG = (2 << 3) | 0;      // 020
static constexpr std::uint8_t CIPHERTEXT_TAG = (4 << 3) | 2;   // 042
/**
 * Write `tag` followed by `value` as a varint at `pos`.
 *
 * @return pointer just past the written bytes.
 */
static std::uint8_t * encode(
    std::uint8_t * pos,
    std::uint8_t tag,
    std::uint32_t value
) {
    pos[0] = tag;
    return varint_encode(pos + 1, value);
}
/**
 * Write `tag` and the varint length `value_length`, and reserve
 * `value_length` bytes of payload.
 *
 * The payload is not copied. On return, `value` points at the reserved
 * region so the caller can fill it in.
 *
 * @return pointer just past the reserved payload.
 */
static std::uint8_t * encode(
    std::uint8_t * pos,
    std::uint8_t tag,
    std::uint8_t * & value, std::size_t value_length
) {
    *pos = tag;
    std::uint8_t * const payload = varint_encode(pos + 1, value_length);
    value = payload;
    return payload + value_length;
}
/**
 * If the byte at `pos` equals `tag`, read the varint that follows into
 * `value` and set `has_value`.
 *
 * If the tag does not match (or `pos == end`), nothing is written and
 * `pos` is returned unchanged.
 *
 * @return position after the consumed field, or `pos` if nothing matched.
 */
static std::uint8_t const * decode(
    std::uint8_t const * pos, std::uint8_t const * end,
    std::uint8_t tag,
    std::uint32_t & value, bool & has_value
) {
    if (pos == end || *pos != tag) {
        return pos;
    }
    std::uint8_t const * const digits = pos + 1;
    std::uint8_t const * const after = varint_skip(digits, end);
    value = varint_decode<std::uint32_t>(digits, after);
    has_value = true;
    return after;
}
/**
 * If the byte at `pos` equals `tag`, read a length-prefixed byte string.
 *
 * On success, `value` points at the payload inside the input buffer (no
 * copy is made) and `value_length` holds its size.
 *
 * If the declared length runs past `end`, returns `end` without touching
 * the outputs. If the tag does not match, returns `pos` unchanged.
 */
static std::uint8_t const * decode(
    std::uint8_t const * pos, std::uint8_t const * end,
    std::uint8_t tag,
    std::uint8_t const * & value, std::size_t & value_length
) {
    if (pos == end || *pos != tag) {
        return pos;
    }
    std::uint8_t const * const length_start = pos + 1;
    std::uint8_t const * const payload = varint_skip(length_start, end);
    std::size_t const payload_length =
        varint_decode<std::size_t>(length_start, payload);
    if (payload_length > std::size_t(end - payload)) {
        // Truncated field: stop parsing.
        return end;
    }
    value = payload;
    value_length = payload_length;
    return payload + payload_length;
}
/**
 * Skip one field whose tag was not recognised by the caller.
 *
 * The low three bits of the tag select how the field is skipped:
 *   - 0: a varint value follows the tag.
 *   - 2: a varint length and that many payload bytes follow the tag.
 *
 * Any other wire type, or a length running past `end`, returns `end`,
 * which stops further parsing.
 */
static std::uint8_t const * skip_unknown(
    std::uint8_t const * pos, std::uint8_t const * end
) {
    if (pos == end) {
        return pos;
    }
    switch (*pos & 0x7) {
    case 0: {
        std::uint8_t const * const value_start = varint_skip(pos, end);
        return varint_skip(value_start, end);
    }
    case 2: {
        std::uint8_t const * const length_start = varint_skip(pos, end);
        std::uint8_t const * const payload = varint_skip(length_start, end);
        std::size_t const length =
            varint_decode<std::size_t>(length_start, payload);
        if (length > std::size_t(end - payload)) {
            return end;
        }
        return payload + length;
    }
    default:
        return end;
    }
}
std::size_t olm::encode_message_length(
std::uint32_t counter,
std::size_t ratchet_key_length,
std::size_t ciphertext_length,
std::size_t mac_length
) {
std::size_t length = VERSION_LENGTH;
length += 1 + varstring_length(ratchet_key_length);
length += 1 + varint_length(counter);
length += 1 + varstring_length(ciphertext_length);
length += mac_length;
return length;
void olm::encode_message(
olm::MessageWriter & writer,
std::uint8_t version,
std::uint32_t counter,
std::size_t ratchet_key_length,
std::size_t ciphertext_length,
std::uint8_t * output
) {
std::uint8_t * pos = output;
*(pos++) = version;
pos = encode(pos, RATCHET_KEY_TAG, writer.ratchet_key, ratchet_key_length);
pos = encode(pos, COUNTER_TAG, counter);
pos = encode(pos, CIPHERTEXT_TAG, writer.ciphertext, ciphertext_length);
void olm::decode_message(
olm::MessageReader & reader,
std::uint8_t const * input, std::size_t input_length,
std::size_t mac_length
) {
std::uint8_t const * pos = input;
std::uint8_t const * end = input + input_length - mac_length;
std::uint8_t const * unknown = nullptr;
reader.input = input;
reader.input_length = input_length;
reader.has_counter = false;
Mark Haines
committed
reader.ratchet_key_length = 0;
Mark Haines
committed
reader.ciphertext_length = 0;
if (pos == end) return;
if (input_length < mac_length) return;
Mark Haines
committed
reader.version = *(pos++);
pos = decode(
pos, end, RATCHET_KEY_TAG,
reader.ratchet_key, reader.ratchet_key_length
);
pos = decode(
pos, end, COUNTER_TAG,
reader.counter, reader.has_counter
);
pos = decode(
pos, end, CIPHERTEXT_TAG,
reader.ciphertext, reader.ciphertext_length
);
if (unknown == pos) {
pos = skip_unknown(pos, end);
// Pre-key message field tags: field number in the upper bits, wire type 2
// (length-prefixed) in the low three bits.
static constexpr std::uint8_t ONE_TIME_KEY_ID_TAG = (1 << 3) | 2;  // 012
static constexpr std::uint8_t BASE_KEY_TAG = (2 << 3) | 2;         // 022
static constexpr std::uint8_t IDENTITY_KEY_TAG = (3 << 3) | 2;     // 032
static constexpr std::uint8_t MESSAGE_TAG = (4 << 3) | 2;          // 042
} // namespace
/**
 * Total encoded size of a pre-key message: the version byte plus, for each
 * of the four length-prefixed fields, one tag byte and its varstring.
 * No MAC length is added here.
 */
std::size_t olm::encode_one_time_key_message_length(
    std::size_t one_time_key_length,
    std::size_t identity_key_length,
    std::size_t base_key_length,
    std::size_t message_length
) {
    std::size_t const field_lengths[] = {
        one_time_key_length, identity_key_length,
        base_key_length, message_length
    };
    std::size_t total = VERSION_LENGTH;
    for (std::size_t field_length : field_lengths) {
        total += 1 + varstring_length(field_length);
    }
    return total;
}
/**
 * Lay out a pre-key message in `output`.
 *
 * Writes the version byte, then the one-time key, base key, identity key
 * and message fields, in that order. Payload bytes are not copied: each
 * `writer` member is pointed at its reserved region so the caller can
 * fill it in afterwards.
 */
void olm::encode_one_time_key_message(
    olm::PreKeyMessageWriter & writer,
    std::uint8_t version,
    std::size_t identity_key_length,
    std::size_t base_key_length,
    std::size_t one_time_key_length,
    std::size_t message_length,
    std::uint8_t * output
) {
    output[0] = version;
    std::uint8_t * cursor = output + 1;
    cursor = encode(cursor, ONE_TIME_KEY_ID_TAG, writer.one_time_key, one_time_key_length);
    cursor = encode(cursor, BASE_KEY_TAG, writer.base_key, base_key_length);
    cursor = encode(cursor, IDENTITY_KEY_TAG, writer.identity_key, identity_key_length);
    encode(cursor, MESSAGE_TAG, writer.message, message_length);
}
void olm::decode_one_time_key_message(
PreKeyMessageReader & reader,
std::uint8_t const * input, std::size_t input_length
) {
std::uint8_t const * pos = input;
std::uint8_t const * end = input + input_length;
std::uint8_t const * unknown = nullptr;
reader.one_time_key = nullptr;
Mark Haines
committed
reader.one_time_key_length = 0;
Mark Haines
committed
reader.identity_key_length = 0;
Mark Haines
committed
reader.base_key_length = 0;
Mark Haines
committed
reader.message_length = 0;
if (pos == end) return;
reader.version = *(pos++);
while (pos != end) {
pos = decode(
pos, end, ONE_TIME_KEY_ID_TAG,
reader.one_time_key, reader.one_time_key_length
);
pos = decode(
pos, end, BASE_KEY_TAG,
reader.base_key, reader.base_key_length
);
pos = decode(
pos, end, IDENTITY_KEY_TAG,
reader.identity_key, reader.identity_key_length
);
pos = decode(
pos, end, MESSAGE_TAG,
reader.message, reader.message_length
);
if (unknown == pos) {
pos = skip_unknown(pos, end);
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
static constexpr std::uint8_t GROUP_SESSION_ID_TAG = (5 << 3) | 2;  // 052: field 5, length-prefixed
/**
 * Total encoded size of a group message.
 *
 * The message consists of the version byte, a tagged session-id
 * varstring, a tagged chain-index varint, a tagged ciphertext varstring,
 * and `mac_length` trailing bytes.
 */
size_t _olm_encode_group_message_length(
    size_t group_session_id_length,
    uint32_t chain_index,
    size_t ciphertext_length,
    size_t mac_length
) {
    size_t const session_id_field = 1 + varstring_length(group_session_id_length);
    size_t const chain_index_field = 1 + varint_length(chain_index);
    size_t const ciphertext_field = 1 + varstring_length(ciphertext_length);
    return VERSION_LENGTH + session_id_field + chain_index_field
        + ciphertext_field + mac_length;
}
/**
 * Lay out a group message in `output`.
 *
 * Writes the version byte, the session id (copied in), the chain index,
 * and a reserved ciphertext region. The ciphertext is not written:
 * `*ciphertext_ptr` is set to the start of the reserved region so the
 * caller can fill it in.
 */
void _olm_encode_group_message(
    uint8_t version,
    const uint8_t *session_id,
    size_t session_id_length,
    uint32_t chain_index,
    size_t ciphertext_length,
    uint8_t *output,
    uint8_t **ciphertext_ptr
) {
    output[0] = version;
    std::uint8_t * session_id_dest;
    std::uint8_t * cursor = encode(
        output + 1, GROUP_SESSION_ID_TAG, session_id_dest, session_id_length
    );
    std::memcpy(session_id_dest, session_id, session_id_length);
    cursor = encode(cursor, COUNTER_TAG, chain_index);
    encode(cursor, CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length);
}
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
/**
 * Parse a group message laid out by _olm_encode_group_message.
 *
 * The last `mac_length` bytes of `input` are treated as a trailer and are
 * not parsed. Pointer outputs (`session_id`, `ciphertext`) point into
 * `input`; nothing is copied.
 *
 * Every field of `*results` is reset before any validation. Callers
 * therefore see well-defined values (version 0, no chain index, null
 * pointers) when the input is too short to parse.
 */
void _olm_decode_group_message(
    const uint8_t *input, size_t input_length,
    size_t mac_length,
    struct _OlmDecodeGroupMessageResults *results
) {
    results->version = 0;
    results->session_id = nullptr;
    results->session_id_length = 0;
    results->chain_index = 0;
    results->has_chain_index = 0;
    results->ciphertext = nullptr;
    results->ciphertext_length = 0;

    // Validate before computing `end`: forming a pointer before `input`
    // (input_length < mac_length) is undefined behaviour.
    if (input_length < mac_length) return;

    std::uint8_t const * pos = input;
    std::uint8_t const * const end = input + (input_length - mac_length);
    if (pos == end) return;

    results->version = *(pos++);

    bool has_chain_index = false;
    // Position at the end of the previous iteration. If a full pass of
    // decode() calls leaves `pos` there, the field is unrecognised and
    // must be skipped explicitly to guarantee progress.
    std::uint8_t const * unknown = nullptr;
    while (pos != end) {
        pos = decode(
            pos, end, GROUP_SESSION_ID_TAG,
            results->session_id, results->session_id_length
        );
        pos = decode(
            pos, end, COUNTER_TAG,
            results->chain_index, has_chain_index
        );
        pos = decode(
            pos, end, CIPHERTEXT_TAG,
            results->ciphertext, results->ciphertext_length
        );
        if (unknown == pos) {
            pos = skip_unknown(pos, end);
        }
        unknown = pos;
    }
    results->has_chain_index = has_chain_index ? 1 : 0;
}