bignum: release bignum_bn2dec()/bignum_bn2hex() results with the crypto backend's deallocator

Adds an ssh_crypto_free() macro to each backend header (OPENSSL_free, gcry_free,
mbedtls_free). ssh_print_bignum() now uses it instead of its per-backend #ifdef
block. ssh_gcry_bn2dec() and ssh_mbedcry_bn2num() now allocate with
gcry_malloc()/mbedtls_calloc(), so their buffers match that deallocator. The
patch also adds the torture_bignum unit test.

NOTE(review): any caller outside this patch that releases a bignum_bn2dec() or
bignum_bn2hex() result with free()/SAFE_FREE() should be switched to
ssh_crypto_free(). Such callers are not visible here; verify before applying.

diff --git a/include/libssh/libcrypto.h b/include/libssh/libcrypto.h index 2a7343fc52e04a57ca43b6eaa3732f907bea8c81..79a5fd5cf1845b5cf248f21ca4df6981b6c83628 100644 --- a/include/libssh/libcrypto.h +++ b/include/libssh/libcrypto.h @@ -59,8 +59,15 @@ typedef void *EVPCTX; #define EVP_DIGEST_LEN EVP_MAX_MD_SIZE #endif +/* Use ssh_crypto_free() to release memory allocated by bignum_bn2dec(), + bignum_bn2hex() and other functions that use crypto-library functions that + are documented to allocate memory that needs to be deallocated with + OPENSSL_free. */ +#define ssh_crypto_free(x) OPENSSL_free(x) + #include #include + typedef BIGNUM* bignum; typedef const BIGNUM* const_bignum; typedef BN_CTX* bignum_CTX; diff --git a/include/libssh/libgcrypt.h b/include/libssh/libgcrypt.h index e4087fd229b07bdaccc555aec7a8fe8f1fcdac59..966fb0447038913baeeda15e48fedc0c16c6da11 100644 --- a/include/libssh/libgcrypt.h +++ b/include/libssh/libgcrypt.h @@ -49,6 +49,8 @@ typedef gcry_md_hd_t EVPCTX; #define EVP_DIGEST_LEN EVP_MAX_MD_SIZE +#define ssh_crypto_free(x) gcry_free(x) /* pairs with the gcry_malloc() in ssh_gcry_bn2dec() */ + typedef gcry_mpi_t bignum; typedef const struct gcry_mpi *const_bignum; typedef void* bignum_CTX; diff --git a/include/libssh/libmbedcrypto.h b/include/libssh/libmbedcrypto.h index 6cf186261d1be88b70d98309fb4d6496a6ad0a59..a4ee010b7941e4cb4d24879a3733d6f9e5a0f277 100644 --- a/include/libssh/libmbedcrypto.h +++ b/include/libssh/libmbedcrypto.h @@ -34,6 +34,7 @@ #include #include #include +#include typedef mbedtls_md_context_t *SHACTX; typedef mbedtls_md_context_t *SHA256CTX; @@ -59,6 +60,8 @@ typedef mbedtls_md_context_t *EVPCTX; #define EVP_DIGEST_LEN EVP_MAX_MD_SIZE +#define ssh_crypto_free(x) mbedtls_free(x) /* pairs with the mbedtls_calloc() in ssh_mbedcry_bn2num() */ + typedef mbedtls_mpi *bignum; typedef const mbedtls_mpi *const_bignum; typedef void* bignum_CTX; diff --git a/src/bignum.c b/src/bignum.c index d812b4127f16930da4643741ca66a6bc9b5b8084..bee55d6717602a13c3a5492b81cb2e878f58984c 100644 --- a/src/bignum.c +++ b/src/bignum.c @@ -88,11 +88,5 @@ void
ssh_print_bignum(const char *name, const_bignum num) } SSH_LOG(SSH_LOG_DEBUG, "%s value: %s", name, (hex == NULL) ? "(null)" : (char *)hex); -#ifdef HAVE_LIBGCRYPT - SAFE_FREE(hex); -#elif defined HAVE_LIBCRYPTO - OPENSSL_free(hex); -#elif defined HAVE_LIBMBEDCRYPTO - SAFE_FREE(hex); -#endif + ssh_crypto_free(hex); /* use the crypto backend's deallocator rather than free() */ } diff --git a/src/gcrypt_missing.c b/src/gcrypt_missing.c index e931ec5bb1f1c0b50eeb9ab6815bcb4f5e878a29..21a63a9b252e3ec0fe9f4a1f1a9be0a727073f74 100644 --- a/src/gcrypt_missing.c +++ b/src/gcrypt_missing.c @@ -55,7 +55,7 @@ char *ssh_gcry_bn2dec(bignum bn) { size = gcry_mpi_get_nbits(bn) * 3; rsize = size / 10 + size / 1000 + 2; - ret = malloc(rsize + 1); + ret = gcry_malloc(rsize + 1); /* to be released with ssh_crypto_free(), which maps to gcry_free() */ if (ret == NULL) { return NULL; } diff --git a/src/mbedcrypto_missing.c b/src/mbedcrypto_missing.c index fb35ca473ecd70a6e174917ac90a147c695da782..2c1a8d7ad148988054f478307f7e7fadc201c78f 100644 --- a/src/mbedcrypto_missing.c +++ b/src/mbedcrypto_missing.c @@ -56,7 +56,7 @@ char *ssh_mbedcry_bn2num(const_bignum num, int radix) return NULL; } - buf = malloc(olen); + buf = mbedtls_calloc(1, olen); /* to be released with ssh_crypto_free(), which maps to mbedtls_free() */ if (buf == NULL) { return NULL; } diff --git a/tests/unittests/CMakeLists.txt b/tests/unittests/CMakeLists.txt index f145986539b7448247313447c46ca97e122d1238..04fcba1120a6007a5ff3cbe55925769fbf927639 100644 --- a/tests/unittests/CMakeLists.txt +++ b/tests/unittests/CMakeLists.txt @@ -3,6 +3,7 @@ project(unittests C) include_directories(${OPENSSL_INCLUDE_DIR}) set(LIBSSH_UNIT_TESTS + torture_bignum torture_buffer torture_bytearray torture_callbacks diff --git a/tests/unittests/torture_bignum.c b/tests/unittests/torture_bignum.c new file mode 100644 index 0000000000000000000000000000000000000000..c36b81f857f2f93d384306496d223245fcef8fc9 --- /dev/null +++ b/tests/unittests/torture_bignum.c @@ -0,0 +1,106 @@ +#include "config.h" + +#define LIBSSH_STATIC + +#include "torture.h" +#include "libssh/bignum.h" +#include "libssh/string.h" + +static void check_str (int n, ssh_string str) /* Assert str holds n big-endian in the fewest bytes, with a leading 0x00 byte when the top bit of the first byte would otherwise be set. */
+{ + if (n > 0 && n <= 127) { + assert_int_equal(1, ntohl (str->size)); + assert_int_equal(n, str->data[0]); + } else if (n > 127 && n <= 255) { + assert_int_equal(2, ntohl (str->size)); + assert_int_equal(0, str->data[0]); + assert_int_equal(n, str->data[1]); + } else if (n > 255 && n <= 32767) { + assert_int_equal(2, ntohl (str->size)); + assert_int_equal(n >> 8, str->data[0]); + assert_int_equal(n & 0xFF, str->data[1]); + } else { /* NOTE(review): only correct for 32768 <= n <= 0x7FFFFF; n <= 0 and larger n also land here. The callers below stay within 1..65536. */ + assert_int_equal(3, ntohl (str->size)); + assert_int_equal(n >> 16, str->data[0]); + assert_int_equal((n >> 8) & 0xFF, str->data[1]); + assert_int_equal(n & 0xFF, str->data[2]); + } +} + +static void check_bignum(int n, const char *nstr) { /* Round-trip n through bignum_bn2dec() and through ssh_make_bignum_string()/ssh_make_string_bn(), comparing against nstr and the original bignum. */ + bignum num, num2; + ssh_string str; + char *dec; + + num = bignum_new(); + assert_non_null(num); + + assert_int_equal (1, bignum_set_word (num, n)); /* NOTE(review): assumes bignum_set_word() returns 1 on success for every backend; confirm for libgcrypt and mbedTLS. */ + + ssh_print_bignum("num", num); + + dec = bignum_bn2dec (num); + assert_non_null (dec); + assert_string_equal (nstr, dec); + ssh_crypto_free(dec); /* bignum_bn2dec() allocates with the crypto backend's allocator */ + + /* ssh_make_bignum_string */ + + str = ssh_make_bignum_string(num); + assert_non_null(str); + + check_str (n, str); + + /* ssh_make_string_bn */ + + num2 = ssh_make_string_bn(str); + ssh_string_free (str); + assert_non_null(num2); + + ssh_print_bignum("num2", num2); + + assert_int_equal (0, bignum_cmp (num, num2)); + + dec = bignum_bn2dec (num2); + assert_non_null (dec); + assert_string_equal (nstr, dec); + ssh_crypto_free(dec); + + bignum_safe_free(num); + bignum_safe_free(num2); +} + + +static void torture_bignum(void **state) { /* The values straddle the 1-, 2- and 3-byte encoding boundaries asserted by check_str(): 127/128, 255/256, 32767/32768, 65535/65536. */ + (void) state; /* unused */ + + ssh_set_log_level(SSH_LOG_TRACE); /* presumably so the SSH_LOG_DEBUG output of ssh_print_bignum() is emitted */ + + check_bignum (1, "1"); + check_bignum (17, "17"); + check_bignum (42, "42"); + check_bignum (127, "127"); + check_bignum (128, "128"); + check_bignum (254, "254"); + check_bignum (255, "255"); + check_bignum (256, "256"); + check_bignum (257, "257"); + check_bignum (300, "300"); + check_bignum (32767, "32767"); + check_bignum (32768, "32768"); + check_bignum (65535, "65535"); + check_bignum (65536,
"65536"); +} + +int torture_run_tests(void) { + int rc; + struct CMUnitTest tests[] = { + cmocka_unit_test(torture_bignum), + }; + + ssh_init(); + torture_filter_tests(tests); + rc = cmocka_run_group_tests(tests, NULL, NULL); + ssh_finalize(); + return rc; +}