#include <openssl/bn.h>
#include <assert.h>
#include <string.h>
#include <openssl/cpu.h>
#include <openssl/err.h>
#include <openssl/mem.h>
#include "internal.h"
#if defined(OPENSSL_X86_64)
// On x86-64, enable the assembly-backed paths of BN_mod_exp_mont_consttime:
// the 5-bit-window Montgomery routines declared below (OPENSSL_BN_ASM_MONT5)
// and the RSAZ 1024-bit AVX2 exponentiation from rsaz_exp.h (RSAZ_ENABLED).
#define OPENSSL_BN_ASM_MONT5
#define RSAZ_ENABLED
#include "rsaz_exp.h"
// The routines below are defined outside this file (presumably in assembly --
// NOTE(review): confirm). The descriptions only reflect how
// BN_mod_exp_mont_consttime uses them; all operands are |num| limbs and
// |table| is the interleaved power table.

// Appears to Montgomery-multiply |ap| by entry |power| of |table| into |rp|.
void bn_mul_mont_gather5(BN_ULONG *rp, const BN_ULONG *ap, const void *table,
                         const BN_ULONG *np, const BN_ULONG *n0, int num,
                         int power);
// Stores |num| limbs of |inp| as entry |power| of |table|.
void bn_scatter5(const BN_ULONG *inp, size_t num, void *table, size_t power);
// Loads entry |power| of |table| into |out|.
void bn_gather5(BN_ULONG *out, size_t num, void *table, size_t power);
// Used in place of five Montgomery squarings followed by a multiplication by
// entry |power| of |table|.
void bn_power5(BN_ULONG *rp, const BN_ULONG *ap, const void *table,
               const BN_ULONG *np, const BN_ULONG *n0, int num, int power);
// Converts |ap| out of Montgomery form into |rp|. A zero return makes
// BN_mod_exp_mont_consttime fall back to BN_from_mont; |not_used| is passed
// as NULL there.
int bn_from_montgomery(BN_ULONG *rp, const BN_ULONG *ap,
                       const BN_ULONG *not_used, const BN_ULONG *np,
                       const BN_ULONG *n0, int num);
#endif
// Capacity of the odd-power table |val| in BN_mod_exp_mont_vartime. It must be
// at least 1 << (w - 1) for the largest window w that
// BN_window_bits_for_exponent_size can return (6, hence 32 entries).
#define TABLE_SIZE 32
// Sliding-window width, in bits, for an exponent that is |b| bits long (used
// by BN_mod_exp_mont_vartime). Longer exponents use wider windows.
#define BN_window_bits_for_exponent_size(b) \
  ((b) > 671 ? 6 : \
   (b) > 239 ? 5 : \
   (b) > 79 ? 4 : \
   (b) > 23 ? 3 : 1)
// BN_mod_exp_mont_vartime sets |rr| to |a|^|p| mod |m| using sliding-window
// Montgomery exponentiation. Branches and table indices depend on the bits of
// |p|, so both the running time and the memory access pattern vary with the
// exponent; it is not suitable for secret exponents.
//
// |m| must be odd (else BN_R_CALLED_WITH_EVEN_MODULUS) and |a| must satisfy
// 0 <= |a| < |m| (else BN_R_INPUT_NOT_REDUCED). If |mont| is NULL, a
// temporary Montgomery context is built from |m| and freed before returning.
// NOTE(review): a caller-supplied |mont| is assumed to be set up for |m|; this
// is not checked here.
//
// Returns 1 on success and 0 on failure.
int BN_mod_exp_mont_vartime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                            const BIGNUM *m, const BN_MONT_CTX *mont) {
  int j, bits, ret = 0, wstart, window;
  int start = 1;
  // val[i] holds a^(2*i + 1) in Montgomery form. Only the first |val_len|
  // entries have been allocated, and only those are freed at |err|.
  BIGNUM *val[TABLE_SIZE];
  size_t val_len = 0;
  BN_MONT_CTX *new_mont = NULL;
  if (!BN_is_odd(m)) {
    OPENSSL_PUT_ERROR(BN, BN_R_CALLED_WITH_EVEN_MODULUS);
    return 0;
  }
  bits = BN_num_bits(p);
  if (bits == 0) {
    // x^0 is 1, except modulo 1 where every value is 0.
    if (BN_is_one(m)) {
      BN_zero(rr);
      return 1;
    }
    return BN_one(rr);
  }
  if (a->neg || BN_ucmp(a, m) >= 0) {
    OPENSSL_PUT_ERROR(BN, BN_R_INPUT_NOT_REDUCED);
    return 0;
  }
  // |d| caches a^2 (Montgomery form); |r| is the running result.
  BIGNUM d;
  BN_init(&d);
  BIGNUM r;
  BN_init(&r);
  val[0] = BN_new();
  if (val[0] == NULL) {
    goto err;
  }
  ++val_len;
  if (mont == NULL) {
    new_mont = BN_MONT_CTX_new();
    if (new_mont == NULL || !BN_MONT_CTX_set(new_mont, m)) {
      goto err;
    }
    mont = new_mont;
  }
  // |p| is nonzero here, so 0^p = 0.
  if (BN_is_zero(a)) {
    BN_zero(rr);
    ret = 1;
    goto err;
  }
  if (!BN_to_mont(val[0], a, mont)) {
    goto err;
  }
  window = BN_window_bits_for_exponent_size(bits);
  if (window > 1) {
    // Fill val[1 .. 2^(window-1) - 1]: each odd power is the previous odd
    // power times d = a^2.
    if (!BN_mod_mul_mont(&d, val[0], val[0], mont)) {
      goto err;
    }
    j = 1 << (window - 1);
    for (int i = 1; i < j; i++) {
      val[i] = BN_new();
      if (val[i] == NULL) {
        goto err;
      }
      ++val_len;
      if (!BN_mod_mul_mont(val[i], val[i - 1], &d, mont)) {
        goto err;
      }
    }
  }
  start = 1;
  wstart = bits - 1;
  // Set |r| to 1 in Montgomery form, i.e. R mod m with R = 2^(j * BN_BITS2).
  j = m->top;
  if (m->d[j - 1] & (((BN_ULONG)1) << (BN_BITS2 - 1))) {
    // The top bit of |m| is set, so R - m < m and R mod m = R - m. That is the
    // two's complement of |m| over |j| words. |m| is odd, so m->d[0] != 0 and
    // the "+1" of the negation never carries out of word 0.
    if (bn_wexpand(&r, j) == NULL) {
      goto err;
    }
    r.d[0] = (0 - m->d[0]) & BN_MASK2;
    for (int i = 1; i < j; i++) {
      r.d[i] = (~m->d[i]) & BN_MASK2;
    }
    r.top = j;
    bn_correct_top(&r);
  } else if (!BN_to_mont(&r, BN_value_one(), mont)) {
    goto err;
  }
  // Scan |p| from its most significant bit. |wstart| is the next bit to
  // process. |start| stays set while |r| is still 1, so the squarings that
  // would precede the first set bit are skipped.
  for (;;) {
    int wvalue;
    int wend;
    if (BN_is_bit_set(p, wstart) == 0) {
      // A zero bit contributes only a squaring.
      if (!start && !BN_mod_mul_mont(&r, &r, &r, mont)) {
        goto err;
      }
      if (wstart == 0) {
        break;
      }
      wstart--;
      continue;
    }
    // Take the longest window of at most |window| bits that starts at
    // |wstart| and ends on a set bit. |wvalue| is therefore odd, and
    // |wend| + 1 is the window length.
    wvalue = 1;
    wend = 0;
    for (int i = 1; i < window; i++) {
      if (wstart - i < 0) {
        break;
      }
      if (BN_is_bit_set(p, wstart - i)) {
        wvalue <<= (i - wend);
        wvalue |= 1;
        wend = i;
      }
    }
    // Square once per window bit, then multiply by a^wvalue, which is stored
    // at val[(wvalue - 1) / 2] == val[wvalue >> 1].
    j = wend + 1;
    if (!start) {
      for (int i = 0; i < j; i++) {
        if (!BN_mod_mul_mont(&r, &r, &r, mont)) {
          goto err;
        }
      }
    }
    if (!BN_mod_mul_mont(&r, &r, val[wvalue >> 1], mont)) {
      goto err;
    }
    wstart -= wend + 1;
    start = 0;
    if (wstart < 0) {
      break;
    }
  }
  if (!BN_from_mont(rr, &r, mont)) {
    goto err;
  }
  ret = 1;
err:
  // Shared cleanup for success and failure paths.
  BN_MONT_CTX_free(new_mont);
  for (size_t i = 0; i < val_len; ++i) {
    BN_free(val[i]);
  }
  BN_free(&r);
  BN_free(&d);
  return ret;
}
// copy_to_prebuf writes |b| into column |idx| of the interleaved power table
// stored in |buf|. The table holds (1 << window) entries. Limb k of every
// entry is stored contiguously, so limb k of entry |idx| lives at word
// idx + k * (1 << window). At most min(|top|, b->top) limbs are written; the
// remaining words of the column are left untouched.
//
// Always returns 1.
static int copy_to_prebuf(const BIGNUM *b, int top, unsigned char *buf, int idx,
                          int window) {
  const int stride = 1 << window;
  const int limbs = (b->top < top) ? b->top : top;
  BN_ULONG *column = (BN_ULONG *)buf + idx;

  for (int k = 0; k < limbs; k++) {
    column[k * stride] = b->d[k];
  }
  return 1;
}
// copy_from_prebuf loads entry |idx| of the interleaved power table in |buf|
// into |b| as a |top|-limb value. The layout is the one copy_to_prebuf
// writes: limb i of entry k is word i * (1 << window) + k.
//
// For every limb row, all (1 << window) entries are read. The wanted one is
// selected with all-ones/all-zero masks built from constant_time_eq_int,
// rather than with a branch or an |idx|-dependent address, so the sequence of
// memory accesses does not depend on |idx|.
//
// Returns 1 on success and 0 if |b| cannot be expanded to |top| limbs.
static int copy_from_prebuf(BIGNUM *b, int top, unsigned char *buf, int idx,
                            int window) {
  int i, j;
  const int width = 1 << window;
  // NOTE(review): |volatile| presumably keeps the compiler from eliding or
  // short-circuiting the reads of unselected entries; confirm intent.
  volatile BN_ULONG *table = (volatile BN_ULONG *)buf;
  if (bn_wexpand(b, top) == NULL) {
    return 0;
  }
  if (window <= 3) {
    // Small tables: scan every entry of each row and keep the one at |idx|.
    for (i = 0; i < top; i++, table += width) {
      BN_ULONG acc = 0;
      for (j = 0; j < width; j++) {
        acc |= table[j] & ((BN_ULONG)0 - (constant_time_eq_int(j, idx) & 1));
      }
      b->d[i] = acc;
    }
  } else {
    // Larger tables: treat each row as four quarters of |xstride| words. The
    // top two bits of |idx| select the quarter via masks y0..y3, computed once.
    // The low bits select the word within a quarter. All four quarters are
    // still read on every inner iteration.
    int xstride = 1 << (window - 2);
    BN_ULONG y0, y1, y2, y3;
    i = idx >> (window - 2);
    idx &= xstride - 1;
    y0 = (BN_ULONG)0 - (constant_time_eq_int(i, 0) & 1);
    y1 = (BN_ULONG)0 - (constant_time_eq_int(i, 1) & 1);
    y2 = (BN_ULONG)0 - (constant_time_eq_int(i, 2) & 1);
    y3 = (BN_ULONG)0 - (constant_time_eq_int(i, 3) & 1);
    for (i = 0; i < top; i++, table += width) {
      BN_ULONG acc = 0;
      for (j = 0; j < xstride; j++) {
        acc |= ((table[j + 0 * xstride] & y0) | (table[j + 1 * xstride] & y1) |
                (table[j + 2 * xstride] & y2) | (table[j + 3 * xstride] & y3)) &
               ((BN_ULONG)0 - (constant_time_eq_int(j, idx) & 1));
      }
      b->d[i] = acc;
    }
  }
  b->top = top;
  bn_correct_top(b);
  return 1;
}
// Compile-time cache-line width, in bytes, assumed when laying out and
// aligning the power table of BN_mod_exp_mont_consttime.
#define MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH (64)
#define MOD_EXP_CTIME_MIN_CACHE_LINE_MASK \
  (MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH - 1)
// Fixed-window width, in bits, for a |b|-bit exponent in
// BN_mod_exp_mont_consttime. The largest window allowed depends on the
// assumed cache-line width. Only widths of 64 and 32 are supported.
#if MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH == 64
#define BN_window_bits_for_ctime_exponent_size(b) \
  ((b) > 937 ? 6 : (b) > 306 ? 5 : (b) > 89 ? 4 : (b) > 22 ? 3 : 1)
#define BN_MAX_WINDOW_BITS_FOR_CTIME_EXPONENT_SIZE (6)
#elif MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH == 32
#define BN_window_bits_for_ctime_exponent_size(b) \
  ((b) > 306 ? 5 : (b) > 89 ? 4 : (b) > 22 ? 3 : 1)
#define BN_MAX_WINDOW_BITS_FOR_CTIME_EXPONENT_SIZE (5)
#endif
// MOD_EXP_CTIME_ALIGN advances |x_| to the next multiple of
// MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH. An already-aligned pointer is advanced
// by a full line, so the underlying allocation needs
// MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH bytes of slack.
#define MOD_EXP_CTIME_ALIGN(x_) \
  ((unsigned char *)(x_) + \
   (MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH - \
    (((uintptr_t)(x_)) & (MOD_EXP_CTIME_MIN_CACHE_LINE_MASK))))
// BN_mod_exp_mont_consttime sets |rr| to |a|^|p| mod |mont->N| using
// fixed-window Montgomery exponentiation. Table lookups are done so that the
// memory access pattern does not depend on the exponent's bits.
//
// |mont| must be non-NULL; the modulus |m| is taken from |mont->N| and must
// be odd (else BN_R_CALLED_WITH_EVEN_MODULUS). |a| must be fully reduced,
// 0 <= |a| < |m| (else BN_R_INPUT_NOT_REDUCED). If |p| is zero the result is
// 1, or 0 when the modulus is 1.
//
// The powers a^0 .. a^(2^window - 1) are stored interleaved in a
// cache-line-aligned buffer. They are read back either with copy_from_prebuf,
// which reads every entry, or, on x86-64, with the bn_*gather5 routines.
//
// NOTE(review): |bits| is BN_num_bits(p), so the number of loop iterations
// still depends on the bit length of the exponent.
//
// Returns 1 on success and 0 on error.
int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
                              const BN_MONT_CTX *mont) {
  int i, bits, ret = 0, window, wvalue;
  int top;
  int numPowers;
  unsigned char *powerbufFree = NULL;
  int powerbufLen = 0;
  unsigned char *powerbuf = NULL;
  BIGNUM tmp, am;
  const BIGNUM *m = &mont->N;

  if (!BN_is_odd(m)) {
    OPENSSL_PUT_ERROR(BN, BN_R_CALLED_WITH_EVEN_MODULUS);
    return 0;
  }

  top = m->top;

  bits = BN_num_bits(p);
  if (bits == 0) {
    // x^0 is 1, except modulo 1 where every value is 0.
    if (BN_is_one(m)) {
      BN_zero(rr);
      return 1;
    }
    return BN_one(rr);
  }

  // Reject a negative or unreduced base before choosing an exponentiation
  // path. This check must come before the RSAZ fast path, which passes
  // |a->d| straight to RSAZ_1024_mod_exp_avx2.
  if (a->neg || BN_ucmp(a, m) >= 0) {
    OPENSSL_PUT_ERROR(BN, BN_R_INPUT_NOT_REDUCED);
    return 0;
  }

#ifdef RSAZ_ENABLED
  // 1024-bit modulus with 16-limb base and exponent: use the RSAZ AVX2
  // implementation when rsaz_avx2_eligible() allows it.
  if ((16 == a->top) && (16 == p->top) && (BN_num_bits(m) == 1024) &&
      rsaz_avx2_eligible()) {
    if (NULL == bn_wexpand(rr, 16)) {
      goto err;
    }
    RSAZ_1024_mod_exp_avx2(rr->d, a->d, p->d, m->d, mont->RR.d, mont->n0[0]);
    rr->top = 16;
    rr->neg = 0;
    bn_correct_top(rr);
    ret = 1;
    goto err;
  }
#endif

  window = BN_window_bits_for_ctime_exponent_size(bits);
#if defined(OPENSSL_BN_ASM_MONT5)
  if (window >= 5) {
    // The MONT5 path uses 5-bit windows. It needs |top| extra words of
    // scratch to hold a copy of the modulus.
    window = 5;
    powerbufLen += top * sizeof(mont->N.d[0]);
  }
#endif

  // |powerbuf| layout: the interleaved table of |numPowers| entries of |top|
  // words, then scratch for |tmp| and |am| (at least 2 * |top| words).
  numPowers = 1 << window;
  powerbufLen +=
      sizeof(m->d[0]) *
      (top * numPowers + ((2 * top) > numPowers ? (2 * top) : numPowers));

#ifdef alloca
  if (powerbufLen < 3072) {
    powerbufFree = alloca(powerbufLen + MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH);
  } else
#endif
  {
    if ((powerbufFree = OPENSSL_malloc(
             powerbufLen + MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH)) == NULL) {
      goto err;
    }
  }

  powerbuf = MOD_EXP_CTIME_ALIGN(powerbufFree);
  memset(powerbuf, 0, powerbufLen);

#ifdef alloca
  // Stack memory from alloca must not reach OPENSSL_free at |err|.
  if (powerbufLen < 3072) {
    powerbufFree = NULL;
  }
#endif

  // |tmp| and |am| are fixed-capacity views into the scratch area of
  // |powerbuf|. BN_FLG_STATIC_DATA marks their storage as not owned.
  tmp.d = (BN_ULONG *)(powerbuf + sizeof(m->d[0]) * top * numPowers);
  am.d = tmp.d + top;
  tmp.top = am.top = 0;
  tmp.dmax = am.dmax = top;
  tmp.neg = am.neg = 0;
  tmp.flags = am.flags = BN_FLG_STATIC_DATA;

  // Set |tmp| to 1 in Montgomery form (R mod m). When the top bit of |m| is
  // set, R mod m = R - m, i.e. the two's complement of |m| over |top| words.
  if (m->d[top - 1] & (((BN_ULONG)1) << (BN_BITS2 - 1))) {
    tmp.d[0] = (0 - m->d[0]) & BN_MASK2;
    for (i = 1; i < top; i++) {
      tmp.d[i] = (~m->d[i]) & BN_MASK2;
    }
    tmp.top = top;
  } else if (!BN_to_mont(&tmp, BN_value_one(), mont)) {
    goto err;
  }

  // |am| = |a| in Montgomery form (|a| was validated above).
  if (!BN_to_mont(&am, a, mont)) {
    goto err;
  }

#if defined(OPENSSL_BN_ASM_MONT5)
  if (window == 5 && top > 1) {
    const BN_ULONG *n0 = mont->n0;
    BN_ULONG *np;

    // The assembly routines take |top|-word operands. Zero-pad |am| and
    // |tmp|, and copy the modulus into the scratch words after |am|.
    for (i = am.top; i < top; i++) {
      am.d[i] = 0;
    }
    for (i = tmp.top; i < top; i++) {
      tmp.d[i] = 0;
    }
    for (np = am.d + top, i = 0; i < top; i++) {
      np[i] = mont->N.d[i];
    }

    // Build table entries 0..31 (entry i = a^i in Montgomery form). Entries 0
    // and 1 are stored directly. Even entries are formed by squaring entry
    // i/2; odd entries by multiplying entry i-1 by a.
    bn_scatter5(tmp.d, top, powerbuf, 0);
    bn_scatter5(am.d, am.top, powerbuf, 1);
    bn_mul_mont(tmp.d, am.d, am.d, np, n0, top);
    bn_scatter5(tmp.d, top, powerbuf, 2);

    for (i = 4; i < 32; i *= 2) {
      bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
      bn_scatter5(tmp.d, top, powerbuf, i);
    }
    for (i = 3; i < 8; i += 2) {
      int j;
      bn_mul_mont_gather5(tmp.d, am.d, powerbuf, np, n0, top, i - 1);
      bn_scatter5(tmp.d, top, powerbuf, i);
      for (j = 2 * i; j < 32; j *= 2) {
        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_scatter5(tmp.d, top, powerbuf, j);
      }
    }
    for (; i < 16; i += 2) {
      bn_mul_mont_gather5(tmp.d, am.d, powerbuf, np, n0, top, i - 1);
      bn_scatter5(tmp.d, top, powerbuf, i);
      bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
      bn_scatter5(tmp.d, top, powerbuf, 2 * i);
    }
    for (; i < 32; i += 2) {
      bn_mul_mont_gather5(tmp.d, am.d, powerbuf, np, n0, top, i - 1);
      bn_scatter5(tmp.d, top, powerbuf, i);
    }

    // Consume the top (bits % 5) + 1 exponent bits first, so the number of
    // remaining bits is a multiple of 5, and start from that table entry.
    bits--;
    for (wvalue = 0, i = bits % 5; i >= 0; i--, bits--) {
      wvalue = (wvalue << 1) + BN_is_bit_set(p, bits);
    }
    bn_gather5(tmp.d, top, powerbuf, wvalue);

    assert(bits >= -1 && (bits == -1 || bits % 5 == 4));

    if (top & 7) {
      // Each 5-bit window: five squarings, then multiply by the gathered
      // table entry.
      while (bits >= 0) {
        for (wvalue = 0, i = 0; i < 5; i++, bits--) {
          wvalue = (wvalue << 1) + BN_is_bit_set(p, bits);
        }

        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_mul_mont(tmp.d, tmp.d, tmp.d, np, n0, top);
        bn_mul_mont_gather5(tmp.d, tmp.d, powerbuf, np, n0, top, wvalue);
      }
    } else {
      // |top| is a multiple of 8: use bn_power5, extracting each 5-bit window
      // directly from the bytes of |p->d|. This code is only compiled when
      // OPENSSL_BN_ASM_MONT5 is defined (x86-64, little-endian).
      const uint8_t *p_bytes = (const uint8_t *)p->d;
      int max_bits = p->top * BN_BITS2;
      assert(bits < max_bits);
      assert(max_bits >= 64);

      // If a two-byte read for this window would run past the end of |p->d|,
      // the window lies entirely in the last byte; read only that byte.
      if (bits - 4 >= max_bits - 8) {
        wvalue = p_bytes[p->top * BN_BYTES - 1];
        wvalue >>= (bits - 4) & 7;
        wvalue &= 0x1f;
        bits -= 5;
        bn_power5(tmp.d, tmp.d, powerbuf, np, n0, top, wvalue);
      }
      while (bits >= 0) {
        // Window bits are [first_bit, first_bit + 5). They always fit in the
        // two bytes starting at byte first_bit / 8. Read those bytes with
        // memcpy: dereferencing a cast uint16_t pointer would be a
        // type-punned and possibly misaligned access.
        int first_bit = bits - 4;
        uint16_t two_bytes;
        memcpy(&two_bytes, p_bytes + (first_bit >> 3), sizeof(two_bytes));
        wvalue = two_bytes;
        wvalue >>= first_bit & 7;
        wvalue &= 0x1f;
        bits -= 5;
        bn_power5(tmp.d, tmp.d, powerbuf, np, n0, top, wvalue);
      }
    }

    // NOTE(review): a zero return from bn_from_montgomery appears to mean it
    // did not perform the conversion. |tmp| then falls through to
    // BN_from_mont below.
    ret = bn_from_montgomery(tmp.d, tmp.d, NULL, np, n0, top);
    tmp.top = top;
    bn_correct_top(&tmp);
    if (ret) {
      if (!BN_copy(rr, &tmp)) {
        ret = 0;
      }
      goto err;
    }
  } else
#endif
  {
    // Portable path: build the table with BN_mod_mul_mont and read it back
    // with copy_from_prebuf.
    if (!copy_to_prebuf(&tmp, top, powerbuf, 0, window) ||
        !copy_to_prebuf(&am, top, powerbuf, 1, window)) {
      goto err;
    }

    // Entry i = a^i: entry 2 by squaring, later entries by multiplying by a.
    if (window > 1) {
      if (!BN_mod_mul_mont(&tmp, &am, &am, mont) ||
          !copy_to_prebuf(&tmp, top, powerbuf, 2, window)) {
        goto err;
      }
      for (i = 3; i < numPowers; i++) {
        if (!BN_mod_mul_mont(&tmp, &am, &tmp, mont) ||
            !copy_to_prebuf(&tmp, top, powerbuf, i, window)) {
          goto err;
        }
      }
    }

    // Consume the top (bits % window) + 1 bits first, so the rest is a whole
    // number of windows.
    bits--;
    for (wvalue = 0, i = bits % window; i >= 0; i--, bits--) {
      wvalue = (wvalue << 1) + BN_is_bit_set(p, bits);
    }
    if (!copy_from_prebuf(&tmp, top, powerbuf, wvalue, window)) {
      goto err;
    }

    // Each window: |window| squarings, then one multiplication by the
    // gathered table entry.
    while (bits >= 0) {
      wvalue = 0;
      for (i = 0; i < window; i++, bits--) {
        if (!BN_mod_mul_mont(&tmp, &tmp, &tmp, mont)) {
          goto err;
        }
        wvalue = (wvalue << 1) + BN_is_bit_set(p, bits);
      }

      if (!copy_from_prebuf(&am, top, powerbuf, wvalue, window)) {
        goto err;
      }

      if (!BN_mod_mul_mont(&tmp, &tmp, &am, mont)) {
        goto err;
      }
    }
  }

  // Convert the result out of Montgomery form.
  if (!BN_from_mont(rr, &tmp, mont)) {
    goto err;
  }
  ret = 1;

err:
  // Shared cleanup. |powerbufFree| is NULL for stack or unallocated buffers.
  OPENSSL_free(powerbufFree);
  return ret;
}