Improve MbedTLS implementation of AES-CBC

This commit is contained in:
Cameron Gutman
2021-04-22 17:08:35 -05:00
parent d7549cd953
commit 29d2cc6d5b
4 changed files with 59 additions and 16 deletions

View File

@@ -10,6 +10,7 @@ bool RandomStateInitialized = false;
#else
#include <openssl/evp.h>
#include <openssl/rand.h>
#endif
static int addPkcs7PaddingInPlace(unsigned char* plaintext, int plaintextLen) {
int paddedLength = ROUND_TO_PKCS7_PADDED_LEN(plaintextLen);
@@ -19,9 +20,10 @@ static int addPkcs7PaddingInPlace(unsigned char* plaintext, int plaintextLen) {
return paddedLength;
}
#endif
// For CBC modes, inputData buffer must be allocated with length rounded up to next multiple of 16 and inputData buffer may be modified!
// When CIPHER_FLAG_PAD_TO_BLOCK_SIZE is used, inputData buffer must be allocated such that
// the buffer length is at least ROUND_TO_PKCS7_PADDED_LEN(inputDataLength) and inputData
// buffer may be modified!
bool PltEncryptMessage(PPLT_CRYPTO_CONTEXT ctx, int algorithm, int flags,
unsigned char* key, int keyLength,
unsigned char* iv, int ivLength,
@@ -34,8 +36,6 @@ bool PltEncryptMessage(PPLT_CRYPTO_CONTEXT ctx, int algorithm, int flags,
switch (algorithm) {
case ALGORITHM_AES_CBC:
LC_ASSERT(flags & CIPHER_FLAG_RESET_IV);
LC_ASSERT(flags & CIPHER_FLAG_FINISH);
LC_ASSERT(tag == NULL);
LC_ASSERT(tagLength == 0);
cipherMode = MBEDTLS_MODE_CBC;
@@ -70,9 +70,31 @@ bool PltEncryptMessage(PPLT_CRYPTO_CONTEXT ctx, int algorithm, int flags,
}
}
else {
if (mbedtls_cipher_crypt(&ctx->ctx, iv, ivLength, inputData, inputDataLength, outputData, &outLength) != 0) {
if (flags & CIPHER_FLAG_RESET_IV) {
if (mbedtls_cipher_set_iv(&ctx->ctx, iv, ivLength) != 0) {
return false;
}
mbedtls_cipher_reset(&ctx->ctx);
}
if (flags & CIPHER_FLAG_PAD_TO_BLOCK_SIZE) {
inputDataLength = addPkcs7PaddingInPlace(inputData, inputDataLength);
}
if (mbedtls_cipher_update(&ctx->ctx, inputData, inputDataLength, outputData, &outLength) != 0) {
return false;
}
if (flags & CIPHER_FLAG_FINISH) {
size_t finishLength;
if (mbedtls_cipher_finish(&ctx->ctx, &outputData[outLength], &finishLength) != 0) {
return false;
}
outLength += finishLength;
}
}
*outputDataLength = outLength;
@@ -120,7 +142,9 @@ bool PltEncryptMessage(PPLT_CRYPTO_CONTEXT ctx, int algorithm, int flags,
ctx->initialized = true;
}
inputDataLength = addPkcs7PaddingInPlace(inputData, inputDataLength);
if (flags & CIPHER_FLAG_PAD_TO_BLOCK_SIZE) {
inputDataLength = addPkcs7PaddingInPlace(inputData, inputDataLength);
}
}
if (EVP_EncryptUpdate(ctx->ctx, outputData, outputDataLength, inputData, inputDataLength) != 1) {
@@ -154,7 +178,8 @@ bool PltEncryptMessage(PPLT_CRYPTO_CONTEXT ctx, int algorithm, int flags,
#endif
}
// For CBC modes, outputData buffer must be allocated with length rounded up to next multiple of 16!
// When CBC is used, outputData buffer must be allocated such that the buffer length is
// at least ROUND_TO_PKCS7_PADDED_LEN(inputDataLength) to allow room for PKCS7 padding.
bool PltDecryptMessage(PPLT_CRYPTO_CONTEXT ctx, int algorithm, int flags,
unsigned char* key, int keyLength,
unsigned char* iv, int ivLength,
@@ -167,8 +192,6 @@ bool PltDecryptMessage(PPLT_CRYPTO_CONTEXT ctx, int algorithm, int flags,
switch (algorithm) {
case ALGORITHM_AES_CBC:
LC_ASSERT(flags & CIPHER_FLAG_RESET_IV);
LC_ASSERT(flags & CIPHER_FLAG_FINISH);
LC_ASSERT(tag == NULL);
LC_ASSERT(tagLength == 0);
cipherMode = MBEDTLS_MODE_CBC;
@@ -203,9 +226,27 @@ bool PltDecryptMessage(PPLT_CRYPTO_CONTEXT ctx, int algorithm, int flags,
}
}
else {
if (mbedtls_cipher_crypt(&ctx->ctx, iv, ivLength, inputData, inputDataLength, outputData, &outLength) != 0) {
if (flags & CIPHER_FLAG_RESET_IV) {
if (mbedtls_cipher_set_iv(&ctx->ctx, iv, ivLength) != 0) {
return false;
}
mbedtls_cipher_reset(&ctx->ctx);
}
if (mbedtls_cipher_update(&ctx->ctx, inputData, inputDataLength, outputData, &outLength) != 0) {
return false;
}
if (flags & CIPHER_FLAG_FINISH) {
size_t finishLength;
if (mbedtls_cipher_finish(&ctx->ctx, &outputData[outLength], &finishLength) != 0) {
return false;
}
outLength += finishLength;
}
}
*outputDataLength = outLength;