diff options
author | Riccardo Spagni <ric@spagni.net> | 2018-11-04 20:46:41 +0200 |
---|---|---|
committer | Riccardo Spagni <ric@spagni.net> | 2018-11-04 20:46:42 +0200 |
commit | 6d3d8635bea2ca04937521a2b03215c7fe8e3262 (patch) | |
tree | ca455c8c2c75ca1dc533f43a7b30825fbc02c590 /src/ringct/bulletproofs.cc | |
parent | Merge pull request #4692 (diff) | |
parent | multiexp: some minor speedups (diff) | |
download | monero-6d3d8635bea2ca04937521a2b03215c7fe8e3262.tar.xz |
Merge pull request #4693
74fb3d88 multiexp: some minor speedups (moneromooo-monero)
a6d2e246 bulletproofs: only enable profiling on request (moneromooo-monero)
a110e6aa multiexp: tune which variants to use for which number of points (moneromooo-monero)
8b476722 bulletproofs: speedup prover (moneromooo-monero)
6f9ae5b6 multiexp: handle pippenger multiexps with part precalc (moneromooo-monero)
10e5a927 bulletproofs: maintain -z4, -z5, and -y0 to avoid subtractions (moneromooo-monero)
8629a42c bulletproofs: rework flow to use sarang's fast batch inversion code (moneromooo-monero)
fc9f7d9c bulletproofs: merge multiexps as per sarang's new python code (moneromooo-monero)
4061960a multiexp: pack the digits table when STRAUS_C is 4 (moneromooo-monero)
bf8e4b98 bulletproofs: some more minor speedup (moneromooo-monero)
c415df97 performance_tests: sc_check and ge_dsm_precomp (moneromooo-monero)
a281b950 bulletproofs: remove single value prover (moneromooo-monero)
484155d0 bulletproofs: some more speedup (moneromooo-monero)
a621d6c8 bulletproofs: random minor speedups (moneromooo-monero)
a49a1761 bulletproofs: shave off a lot of scalar muls from the g/h construction (moneromooo-monero)
4564a5d1 bulletproofs: speedup PROVE (moneromooo-monero)
Diffstat (limited to 'src/ringct/bulletproofs.cc')
-rw-r--r-- | src/ringct/bulletproofs.cc | 924 |
1 files changed, 364 insertions, 560 deletions
diff --git a/src/ringct/bulletproofs.cc b/src/ringct/bulletproofs.cc index 381f50872..bed48769a 100644 --- a/src/ringct/bulletproofs.cc +++ b/src/ringct/bulletproofs.cc @@ -29,8 +29,6 @@ // Adapted from Java code by Sarang Noether #include <stdlib.h> -#include <openssl/ssl.h> -#include <openssl/bn.h> #include <boost/thread/mutex.hpp> #include "misc_log_ex.h" #include "common/perf_timer.h" @@ -48,9 +46,15 @@ extern "C" //#define DEBUG_BP +#if 1 #define PERF_TIMER_START_BP(x) PERF_TIMER_START_UNIT(x, 1000000) +#define PERF_TIMER_STOP_BP(x) PERF_TIMER_STOP(x) +#else +#define PERF_TIMER_START_BP(x) ((void*)0) +#define PERF_TIMER_STOP_BP(x) ((void*)0) +#endif -#define STRAUS_SIZE_LIMIT 128 +#define STRAUS_SIZE_LIMIT 232 #define PIPPENGER_SIZE_LIMIT 0 namespace rct @@ -75,65 +79,20 @@ static const rct::keyV twoN = vector_powers(TWO, maxN); static const rct::key ip12 = inner_product(oneN, twoN); static boost::mutex init_mutex; -static inline rct::key multiexp(const std::vector<MultiexpData> &data, bool HiGi) +static inline rct::key multiexp(const std::vector<MultiexpData> &data, size_t HiGi_size) { - if (HiGi) + if (HiGi_size > 0) { - static_assert(128 <= STRAUS_SIZE_LIMIT, "Straus in precalc mode can only be calculated till STRAUS_SIZE_LIMIT"); - return data.size() <= 128 ? straus(data, straus_HiGi_cache, 0) : pippenger(data, pippenger_HiGi_cache, get_pippenger_c(data.size())); + static_assert(232 <= STRAUS_SIZE_LIMIT, "Straus in precalc mode can only be calculated till STRAUS_SIZE_LIMIT"); + return HiGi_size <= 232 && data.size() == HiGi_size ? straus(data, straus_HiGi_cache, 0) : pippenger(data, pippenger_HiGi_cache, HiGi_size, get_pippenger_c(data.size())); } else - return data.size() <= 64 ? straus(data, NULL, 0) : pippenger(data, NULL, get_pippenger_c(data.size())); -} - -static bool is_reduced(const rct::key &scalar) -{ - rct::key reduced = scalar; - sc_reduce32(reduced.bytes); - return scalar == reduced; -} - -static void addKeys_acc_p3(ge_p3 *acc_p3, const rct::key &a, const rct::key &point) -{ - ge_p3 p3; - CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&p3, point.bytes) == 0, "ge_frombytes_vartime failed"); - ge_scalarmult_p3(&p3, a.bytes, &p3); - ge_cached cached; - ge_p3_to_cached(&cached, acc_p3); - ge_p1p1 p1; - ge_add(&p1, &p3, &cached); - ge_p1p1_to_p3(acc_p3, &p1); -} - -static void add_acc_p3(ge_p3 *acc_p3, const rct::key &point) -{ - ge_p3 p3; - CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&p3, point.bytes) == 0, "ge_frombytes_vartime failed"); - ge_cached cached; - ge_p3_to_cached(&cached, &p3); - ge_p1p1 p1; - ge_add(&p1, acc_p3, &cached); - ge_p1p1_to_p3(acc_p3, &p1); + return data.size() <= 95 ? straus(data, NULL, 0) : pippenger(data, NULL, 0, get_pippenger_c(data.size())); } -static void sub_acc_p3(ge_p3 *acc_p3, const rct::key &point) +static inline bool is_reduced(const rct::key &scalar) { - ge_p3 p3; - CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&p3, point.bytes) == 0, "ge_frombytes_vartime failed"); - ge_cached cached; - ge_p3_to_cached(&cached, &p3); - ge_p1p1 p1; - ge_sub(&p1, acc_p3, &cached); - ge_p1p1_to_p3(acc_p3, &p1); -} - -static rct::key scalarmultKey(const ge_p3 &P, const rct::key &a) -{ - ge_p2 R; - ge_scalarmult(&R, a.bytes, &P); - rct::key aP; - ge_tobytes(aP.bytes, &R); - return aP; + return sc_check(scalar.bytes) == 0; } static rct::key get_exponent(const rct::key &base, size_t idx) @@ -160,12 +119,12 @@ static void init_exponents() Gi[i] = get_exponent(rct::H, i * 2 + 1); CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&Gi_p3[i], Gi[i].bytes) == 0, "ge_frombytes_vartime failed"); - data.push_back({rct::zero(), Gi[i]}); - data.push_back({rct::zero(), Hi[i]}); + data.push_back({rct::zero(), Gi_p3[i]}); + data.push_back({rct::zero(), Hi_p3[i]}); } straus_HiGi_cache = straus_init_cache(data, STRAUS_SIZE_LIMIT); - pippenger_HiGi_cache = pippenger_init_cache(data, PIPPENGER_SIZE_LIMIT); + pippenger_HiGi_cache = pippenger_init_cache(data, 0, PIPPENGER_SIZE_LIMIT); MINFO("Hi/Gi cache size: " << (sizeof(Hi)+sizeof(Gi))/1024 << " kB"); MINFO("Hi_p3/Gi_p3 cache size: " << (sizeof(Hi_p3)+sizeof(Gi_p3))/1024 << " kB"); @@ -189,29 +148,37 @@ static rct::key vector_exponent(const rct::keyV &a, const rct::keyV &b) multiexp_data.emplace_back(a[i], Gi_p3[i]); multiexp_data.emplace_back(b[i], Hi_p3[i]); } - return multiexp(multiexp_data, true); + return multiexp(multiexp_data, 2 * a.size()); } /* Compute a custom vector-scalar commitment */ -static rct::key vector_exponent_custom(const rct::keyV &A, const rct::keyV &B, const rct::keyV &a, const rct::keyV &b) +static rct::key cross_vector_exponent8(size_t size, const std::vector<ge_p3> &A, size_t Ao, const std::vector<ge_p3> &B, size_t Bo, const rct::keyV &a, size_t ao, const rct::keyV &b, size_t bo, const rct::keyV *scale, const ge_p3 *extra_point, const rct::key *extra_scalar) { - CHECK_AND_ASSERT_THROW_MES(A.size() == B.size(), "Incompatible sizes of A and B"); - CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b"); - CHECK_AND_ASSERT_THROW_MES(a.size() == A.size(), "Incompatible sizes of a and A"); - CHECK_AND_ASSERT_THROW_MES(a.size() <= maxN*maxM, "Incompatible sizes of a and maxN"); + CHECK_AND_ASSERT_THROW_MES(size + Ao <= A.size(), "Incompatible size for A"); + CHECK_AND_ASSERT_THROW_MES(size + Bo <= B.size(), "Incompatible size for B"); + CHECK_AND_ASSERT_THROW_MES(size + ao <= a.size(), "Incompatible size for a"); + CHECK_AND_ASSERT_THROW_MES(size + bo <= b.size(), "Incompatible size for b"); + CHECK_AND_ASSERT_THROW_MES(size <= maxN*maxM, "size is too large"); + CHECK_AND_ASSERT_THROW_MES(!scale || size == scale->size() / 2, "Incompatible size for scale"); + CHECK_AND_ASSERT_THROW_MES(!!extra_point == !!extra_scalar, "only one of extra point/scalar present"); std::vector<MultiexpData> multiexp_data; - multiexp_data.reserve(a.size()*2); - for (size_t i = 0; i < a.size(); ++i) + multiexp_data.resize(size*2 + (!!extra_point)); + for (size_t i = 0; i < size; ++i) { - multiexp_data.resize(multiexp_data.size() + 1); - multiexp_data.back().scalar = a[i]; - CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&multiexp_data.back().point, A[i].bytes) == 0, "ge_frombytes_vartime failed"); - multiexp_data.resize(multiexp_data.size() + 1); - multiexp_data.back().scalar = b[i]; - CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&multiexp_data.back().point, B[i].bytes) == 0, "ge_frombytes_vartime failed"); + sc_mul(multiexp_data[i*2].scalar.bytes, a[ao+i].bytes, INV_EIGHT.bytes);; + multiexp_data[i*2].point = A[Ao+i]; + sc_mul(multiexp_data[i*2+1].scalar.bytes, b[bo+i].bytes, INV_EIGHT.bytes); + if (scale) + sc_mul(multiexp_data[i*2+1].scalar.bytes, multiexp_data[i*2+1].scalar.bytes, (*scale)[Bo+i].bytes); + multiexp_data[i*2+1].point = B[Bo+i]; } - return multiexp(multiexp_data, false); + if (extra_point) + { + sc_mul(multiexp_data.back().scalar.bytes, extra_scalar->bytes, INV_EIGHT.bytes); + multiexp_data.back().point = *extra_point; + } + return multiexp(multiexp_data, 0); } /* Given a scalar, construct a vector of powers */ @@ -273,16 +240,22 @@ static rct::keyV hadamard(const rct::keyV &a, const rct::keyV &b) return res; } -/* Given two curvepoint arrays, construct the Hadamard product */ -static rct::keyV hadamard2(const rct::keyV &a, const rct::keyV &b) +/* folds a curvepoint array using a two way scaled Hadamard product */ +static void hadamard_fold(std::vector<ge_p3> &v, const rct::keyV *scale, const rct::key &a, const rct::key &b) { - CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b"); - rct::keyV res(a.size()); - for (size_t i = 0; i < a.size(); ++i) + CHECK_AND_ASSERT_THROW_MES((v.size() & 1) == 0, "Vector size should be even"); + const size_t sz = v.size() / 2; + for (size_t n = 0; n < sz; ++n) { - rct::addKeys(res[i], a[i], b[i]); + ge_dsmp c[2]; + ge_dsm_precomp(c[0], &v[n]); + ge_dsm_precomp(c[1], &v[sz + n]); + rct::key sa, sb; + if (scale) sc_mul(sa.bytes, a.bytes, (*scale)[n].bytes); else sa = a; + if (scale) sc_mul(sb.bytes, b.bytes, (*scale)[sz + n].bytes); else sb = b; + ge_double_scalarmult_precomp_vartime2_p3(&v[n], sa.bytes, c[0], sb.bytes, c[1]); } - return res; + v.resize(sz); } /* Add two vectors */ @@ -297,88 +270,98 @@ static rct::keyV vector_add(const rct::keyV &a, const rct::keyV &b) return res; } -/* Subtract two vectors */ -static rct::keyV vector_subtract(const rct::keyV &a, const rct::keyV &b) +/* Add a scalar to all elements of a vector */ +static rct::keyV vector_add(const rct::keyV &a, const rct::key &b) { - CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b"); rct::keyV res(a.size()); for (size_t i = 0; i < a.size(); ++i) { - sc_sub(res[i].bytes, a[i].bytes, b[i].bytes); + sc_add(res[i].bytes, a[i].bytes, b.bytes); } return res; } -/* Multiply a scalar and a vector */ -static rct::keyV vector_scalar(const rct::keyV &a, const rct::key &x) +/* Subtract a scalar from all elements of a vector */ +static rct::keyV vector_subtract(const rct::keyV &a, const rct::key &b) { rct::keyV res(a.size()); for (size_t i = 0; i < a.size(); ++i) { - sc_mul(res[i].bytes, a[i].bytes, x.bytes); + sc_sub(res[i].bytes, a[i].bytes, b.bytes); } return res; } -/* Create a vector from copies of a single value */ -static rct::keyV vector_dup(const rct::key &x, size_t N) -{ - return rct::keyV(N, x); -} - -/* Exponentiate a curve vector by a scalar */ -static rct::keyV vector_scalar2(const rct::keyV &a, const rct::key &x) +/* Multiply a scalar and a vector */ +static rct::keyV vector_scalar(const rct::keyV &a, const rct::key &x) { rct::keyV res(a.size()); for (size_t i = 0; i < a.size(); ++i) { - rct::scalarmultKey(res[i], a[i], x); + sc_mul(res[i].bytes, a[i].bytes, x.bytes); } return res; } -/* Get the sum of a vector's elements */ -static rct::key vector_sum(const rct::keyV &a) +/* Create a vector from copies of a single value */ +static rct::keyV vector_dup(const rct::key &x, size_t N) { - rct::key res = rct::zero(); - for (size_t i = 0; i < a.size(); ++i) - { - sc_add(res.bytes, res.bytes, a[i].bytes); - } - return res; + return rct::keyV(N, x); } -static rct::key switch_endianness(rct::key k) +static rct::key sm(rct::key y, int n, const rct::key &x) { - std::reverse(k.bytes, k.bytes + sizeof(k)); - return k; + while (n--) + sc_mul(y.bytes, y.bytes, y.bytes); + sc_mul(y.bytes, y.bytes, x.bytes); + return y; } -/* Compute the inverse of a scalar, the stupid way */ +/* Compute the inverse of a scalar, the clever way */ static rct::key invert(const rct::key &x) { - rct::key inv; - - BN_CTX *ctx = BN_CTX_new(); - BIGNUM *X = BN_new(); - BIGNUM *L = BN_new(); - BIGNUM *I = BN_new(); - - BN_bin2bn(switch_endianness(x).bytes, sizeof(rct::key), X); - BN_bin2bn(switch_endianness(rct::curveOrder()).bytes, sizeof(rct::key), L); - - CHECK_AND_ASSERT_THROW_MES(BN_mod_inverse(I, X, L, ctx), "Failed to invert"); - - const int len = BN_num_bytes(I); - CHECK_AND_ASSERT_THROW_MES((size_t)len <= sizeof(rct::key), "Invalid number length"); - inv = rct::zero(); - BN_bn2bin(I, inv.bytes); - std::reverse(inv.bytes, inv.bytes + len); + rct::key _1, _10, _100, _11, _101, _111, _1001, _1011, _1111; + + _1 = x; + sc_mul(_10.bytes, _1.bytes, _1.bytes); + sc_mul(_100.bytes, _10.bytes, _10.bytes); + sc_mul(_11.bytes, _10.bytes, _1.bytes); + sc_mul(_101.bytes, _10.bytes, _11.bytes); + sc_mul(_111.bytes, _10.bytes, _101.bytes); + sc_mul(_1001.bytes, _10.bytes, _111.bytes); + sc_mul(_1011.bytes, _10.bytes, _1001.bytes); + sc_mul(_1111.bytes, _100.bytes, _1011.bytes); - BN_free(I); - BN_free(L); - BN_free(X); - BN_CTX_free(ctx); + rct::key inv; + sc_mul(inv.bytes, _1111.bytes, _1.bytes); + + inv = sm(inv, 123 + 3, _101); + inv = sm(inv, 2 + 2, _11); + inv = sm(inv, 1 + 4, _1111); + inv = sm(inv, 1 + 4, _1111); + inv = sm(inv, 4, _1001); + inv = sm(inv, 2, _11); + inv = sm(inv, 1 + 4, _1111); + inv = sm(inv, 1 + 3, _101); + inv = sm(inv, 3 + 3, _101); + inv = sm(inv, 3, _111); + inv = sm(inv, 1 + 4, _1111); + inv = sm(inv, 2 + 3, _111); + inv = sm(inv, 2 + 2, _11); + inv = sm(inv, 1 + 4, _1011); + inv = sm(inv, 2 + 4, _1011); + inv = sm(inv, 6 + 4, _1001); + inv = sm(inv, 2 + 2, _11); + inv = sm(inv, 3 + 2, _11); + inv = sm(inv, 3 + 2, _11); + inv = sm(inv, 1 + 4, _1001); + inv = sm(inv, 1 + 3, _111); + inv = sm(inv, 2 + 4, _1111); + inv = sm(inv, 1 + 4, _1011); + inv = sm(inv, 3, _101); + inv = sm(inv, 2 + 4, _1111); + inv = sm(inv, 3, _101); + inv = sm(inv, 1 + 2, _11); #ifdef DEBUG_BP rct::key tmp; @@ -388,6 +371,34 @@ static rct::key invert(const rct::key &x) return inv; } +static rct::keyV invert(rct::keyV x) +{ + rct::keyV scratch; + scratch.reserve(x.size()); + + rct::key acc = rct::identity(); + for (size_t n = 0; n < x.size(); ++n) + { + scratch.push_back(acc); + if (n == 0) + acc = x[0]; + else + sc_mul(acc.bytes, acc.bytes, x[n].bytes); + } + + acc = invert(acc); + + rct::key tmp; + for (int i = x.size(); i-- > 0; ) + { + sc_mul(tmp.bytes, acc.bytes, x[i].bytes); + sc_mul(x[i].bytes, acc.bytes, scratch[i].bytes); + acc = tmp; + } + + return x; +} + /* Compute the slice of a vector */ static rct::keyV slice(const rct::keyV &a, size_t start, size_t stop) { @@ -438,270 +449,12 @@ static rct::key hash_cache_mash(rct::key &hash_cache, const rct::key &mash0, con /* Given a value v (0..2^N-1) and a mask gamma, construct a range proof */ Bulletproof bulletproof_PROVE(const rct::key &sv, const rct::key &gamma) { - init_exponents(); - - PERF_TIMER_UNIT(PROVE, 1000000); - - constexpr size_t logN = 6; // log2(64) - constexpr size_t N = 1<<logN; - - rct::key V; - rct::keyV aL(N), aR(N); - - PERF_TIMER_START_BP(PROVE_v); - rct::addKeys2(V, gamma, sv, rct::H); - V = rct::scalarmultKey(V, INV_EIGHT); - PERF_TIMER_STOP(PROVE_v); - - PERF_TIMER_START_BP(PROVE_aLaR); - for (size_t i = N; i-- > 0; ) - { - if (sv[i/8] & (((uint64_t)1)<<(i%8))) - { - aL[i] = rct::identity(); - } - else - { - aL[i] = rct::zero(); - } - sc_sub(aR[i].bytes, aL[i].bytes, rct::identity().bytes); - } - PERF_TIMER_STOP(PROVE_aLaR); - - rct::key hash_cache = rct::hash_to_scalar(V); - - // DEBUG: Test to ensure this recovers the value -#ifdef DEBUG_BP - uint64_t test_aL = 0, test_aR = 0; - for (size_t i = 0; i < N; ++i) - { - if (aL[i] == rct::identity()) - test_aL += ((uint64_t)1)<<i; - if (aR[i] == rct::zero()) - test_aR += ((uint64_t)1)<<i; - } - uint64_t v_test = 0; - for (int n = 0; n < 8; ++n) v_test |= (((uint64_t)sv[n]) << (8*n)); - CHECK_AND_ASSERT_THROW_MES(test_aL == v_test, "test_aL failed"); - CHECK_AND_ASSERT_THROW_MES(test_aR == v_test, "test_aR failed"); -#endif - -try_again: - PERF_TIMER_START_BP(PROVE_step1); - // PAPER LINES 38-39 - rct::key alpha = rct::skGen(); - rct::key ve = vector_exponent(aL, aR); - rct::key A; - rct::addKeys(A, ve, rct::scalarmultBase(alpha)); - A = rct::scalarmultKey(A, INV_EIGHT); - - // PAPER LINES 40-42 - rct::keyV sL = rct::skvGen(N), sR = rct::skvGen(N); - rct::key rho = rct::skGen(); - ve = vector_exponent(sL, sR); - rct::key S; - rct::addKeys(S, ve, rct::scalarmultBase(rho)); - S = rct::scalarmultKey(S, INV_EIGHT); - - // PAPER LINES 43-45 - rct::key y = hash_cache_mash(hash_cache, A, S); - if (y == rct::zero()) - { - PERF_TIMER_STOP(PROVE_step1); - MINFO("y is 0, trying again"); - goto try_again; - } - rct::key z = hash_cache = rct::hash_to_scalar(y); - if (z == rct::zero()) - { - PERF_TIMER_STOP(PROVE_step1); - MINFO("z is 0, trying again"); - goto try_again; - } - - // Polynomial construction before PAPER LINE 46 - rct::key t0 = rct::zero(); - rct::key t1 = rct::zero(); - rct::key t2 = rct::zero(); - - const auto yN = vector_powers(y, N); - - rct::key ip1y = vector_sum(yN); - rct::key tmp; - sc_muladd(t0.bytes, z.bytes, ip1y.bytes, t0.bytes); - - rct::key zsq; - sc_mul(zsq.bytes, z.bytes, z.bytes); - sc_muladd(t0.bytes, zsq.bytes, sv.bytes, t0.bytes); - - rct::key k = rct::zero(); - sc_mulsub(k.bytes, zsq.bytes, ip1y.bytes, k.bytes); - - rct::key zcu; - sc_mul(zcu.bytes, zsq.bytes, z.bytes); - sc_mulsub(k.bytes, zcu.bytes, ip12.bytes, k.bytes); - sc_add(t0.bytes, t0.bytes, k.bytes); - - // DEBUG: Test the value of t0 has the correct form -#ifdef DEBUG_BP - rct::key test_t0 = rct::zero(); - rct::key iph = inner_product(aL, hadamard(aR, yN)); - sc_add(test_t0.bytes, test_t0.bytes, iph.bytes); - rct::key ips = inner_product(vector_subtract(aL, aR), yN); - sc_muladd(test_t0.bytes, z.bytes, ips.bytes, test_t0.bytes); - rct::key ipt = inner_product(twoN, aL); - sc_muladd(test_t0.bytes, zsq.bytes, ipt.bytes, test_t0.bytes); - sc_add(test_t0.bytes, test_t0.bytes, k.bytes); - CHECK_AND_ASSERT_THROW_MES(t0 == test_t0, "t0 check failed"); -#endif - PERF_TIMER_STOP(PROVE_step1); - - PERF_TIMER_START_BP(PROVE_step2); - const auto HyNsR = hadamard(yN, sR); - const auto vpIz = vector_dup(z, N); - const auto vp2zsq = vector_scalar(twoN, zsq); - const auto aL_vpIz = vector_subtract(aL, vpIz); - const auto aR_vpIz = vector_add(aR, vpIz); - - rct::key ip1 = inner_product(aL_vpIz, HyNsR); - sc_add(t1.bytes, t1.bytes, ip1.bytes); - - rct::key ip2 = inner_product(sL, vector_add(hadamard(yN, aR_vpIz), vp2zsq)); - sc_add(t1.bytes, t1.bytes, ip2.bytes); - - rct::key ip3 = inner_product(sL, HyNsR); - sc_add(t2.bytes, t2.bytes, ip3.bytes); - - // PAPER LINES 47-48 - rct::key tau1 = rct::skGen(), tau2 = rct::skGen(); - - rct::key T1 = rct::addKeys(rct::scalarmultH(t1), rct::scalarmultBase(tau1)); - T1 = rct::scalarmultKey(T1, INV_EIGHT); - rct::key T2 = rct::addKeys(rct::scalarmultH(t2), rct::scalarmultBase(tau2)); - T2 = rct::scalarmultKey(T2, INV_EIGHT); - - // PAPER LINES 49-51 - rct::key x = hash_cache_mash(hash_cache, z, T1, T2); - if (x == rct::zero()) - { - PERF_TIMER_STOP(PROVE_step2); - MINFO("x is 0, trying again"); - goto try_again; - } - - // PAPER LINES 52-53 - rct::key taux = rct::zero(); - sc_mul(taux.bytes, tau1.bytes, x.bytes); - rct::key xsq; - sc_mul(xsq.bytes, x.bytes, x.bytes); - sc_muladd(taux.bytes, tau2.bytes, xsq.bytes, taux.bytes); - sc_muladd(taux.bytes, gamma.bytes, zsq.bytes, taux.bytes); - rct::key mu; - sc_muladd(mu.bytes, x.bytes, rho.bytes, alpha.bytes); - - // PAPER LINES 54-57 - rct::keyV l = vector_add(aL_vpIz, vector_scalar(sL, x)); - rct::keyV r = vector_add(hadamard(yN, vector_add(aR_vpIz, vector_scalar(sR, x))), vp2zsq); - PERF_TIMER_STOP(PROVE_step2); - - PERF_TIMER_START_BP(PROVE_step3); - rct::key t = inner_product(l, r); - - // DEBUG: Test if the l and r vectors match the polynomial forms -#ifdef DEBUG_BP - rct::key test_t; - sc_muladd(test_t.bytes, t1.bytes, x.bytes, t0.bytes); - sc_muladd(test_t.bytes, t2.bytes, xsq.bytes, test_t.bytes); - CHECK_AND_ASSERT_THROW_MES(test_t == t, "test_t check failed"); -#endif - - // PAPER LINES 32-33 - rct::key x_ip = hash_cache_mash(hash_cache, x, taux, mu, t); - - // These are used in the inner product rounds - size_t nprime = N; - rct::keyV Gprime(N); - rct::keyV Hprime(N); - rct::keyV aprime(N); - rct::keyV bprime(N); - const rct::key yinv = invert(y); - rct::key yinvpow = rct::identity(); - for (size_t i = 0; i < N; ++i) - { - Gprime[i] = Gi[i]; - Hprime[i] = scalarmultKey(Hi_p3[i], yinvpow); - sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); - aprime[i] = l[i]; - bprime[i] = r[i]; - } - rct::keyV L(logN); - rct::keyV R(logN); - int round = 0; - rct::keyV w(logN); // this is the challenge x in the inner product protocol - PERF_TIMER_STOP(PROVE_step3); - - PERF_TIMER_START_BP(PROVE_step4); - // PAPER LINE 13 - while (nprime > 1) - { - // PAPER LINE 15 - nprime /= 2; - - // PAPER LINES 16-17 - rct::key cL = inner_product(slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size())); - rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime)); - - // PAPER LINES 18-19 - L[round] = vector_exponent_custom(slice(Gprime, nprime, Gprime.size()), slice(Hprime, 0, nprime), slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size())); - sc_mul(tmp.bytes, cL.bytes, x_ip.bytes); - rct::addKeys(L[round], L[round], rct::scalarmultH(tmp)); - L[round] = rct::scalarmultKey(L[round], INV_EIGHT); - R[round] = vector_exponent_custom(slice(Gprime, 0, nprime), slice(Hprime, nprime, Hprime.size()), slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime)); - sc_mul(tmp.bytes, cR.bytes, x_ip.bytes); - rct::addKeys(R[round], R[round], rct::scalarmultH(tmp)); - R[round] = rct::scalarmultKey(R[round], INV_EIGHT); - - // PAPER LINES 21-22 - w[round] = hash_cache_mash(hash_cache, L[round], R[round]); - if (w[round] == rct::zero()) - { - PERF_TIMER_STOP(PROVE_step4); - MINFO("w[round] is 0, trying again"); - goto try_again; - } - - // PAPER LINES 24-25 - const rct::key winv = invert(w[round]); - Gprime = hadamard2(vector_scalar2(slice(Gprime, 0, nprime), winv), vector_scalar2(slice(Gprime, nprime, Gprime.size()), w[round])); - Hprime = hadamard2(vector_scalar2(slice(Hprime, 0, nprime), w[round]), vector_scalar2(slice(Hprime, nprime, Hprime.size()), winv)); - - // PAPER LINES 28-29 - aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv)); - bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round])); - - ++round; - } - PERF_TIMER_STOP(PROVE_step4); - - // PAPER LINE 58 (with inclusions from PAPER LINE 8 and PAPER LINE 20) - return Bulletproof(V, A, S, T1, T2, taux, mu, L, R, aprime[0], bprime[0], t); + return bulletproof_PROVE(rct::keyV(1, sv), rct::keyV(1, gamma)); } Bulletproof bulletproof_PROVE(uint64_t v, const rct::key &gamma) { - // vG + gammaH - PERF_TIMER_START_BP(PROVE_v); - rct::key sv = rct::zero(); - sv.bytes[0] = v & 255; - sv.bytes[1] = (v >> 8) & 255; - sv.bytes[2] = (v >> 16) & 255; - sv.bytes[3] = (v >> 24) & 255; - sv.bytes[4] = (v >> 32) & 255; - sv.bytes[5] = (v >> 40) & 255; - sv.bytes[6] = (v >> 48) & 255; - sv.bytes[7] = (v >> 56) & 255; - PERF_TIMER_STOP(PROVE_v); - return bulletproof_PROVE(sv, gamma); + return bulletproof_PROVE(std::vector<uint64_t>(1, v), rct::keyV(1, gamma)); } /* Given a set of values v (0..2^N-1) and masks gamma, construct a range proof */ @@ -728,37 +481,39 @@ Bulletproof bulletproof_PROVE(const rct::keyV &sv, const rct::keyV &gamma) rct::keyV V(sv.size()); rct::keyV aL(MN), aR(MN); - rct::key tmp; + rct::keyV aL8(MN), aR8(MN); + rct::key tmp, tmp2; PERF_TIMER_START_BP(PROVE_v); for (size_t i = 0; i < sv.size(); ++i) { - rct::addKeys2(V[i], gamma[i], sv[i], rct::H); - V[i] = rct::scalarmultKey(V[i], INV_EIGHT); + rct::key gamma8, sv8; + sc_mul(gamma8.bytes, gamma[i].bytes, INV_EIGHT.bytes); + sc_mul(sv8.bytes, sv[i].bytes, INV_EIGHT.bytes); + rct::addKeys2(V[i], gamma8, sv8, rct::H); } - PERF_TIMER_STOP(PROVE_v); + PERF_TIMER_STOP_BP(PROVE_v); PERF_TIMER_START_BP(PROVE_aLaR); for (size_t j = 0; j < M; ++j) { for (size_t i = N; i-- > 0; ) { - if (j >= sv.size()) - { - aL[j*N+i] = rct::zero(); - } - else if (sv[j][i/8] & (((uint64_t)1)<<(i%8))) + if (j < sv.size() && (sv[j][i/8] & (((uint64_t)1)<<(i%8)))) { aL[j*N+i] = rct::identity(); + aL8[j*N+i] = INV_EIGHT; + aR[j*N+i] = aR8[j*N+i] = rct::zero(); } else { - aL[j*N+i] = rct::zero(); + aL[j*N+i] = aL8[j*N+i] = rct::zero(); + aR[j*N+i] = MINUS_ONE; + aR8[j*N+i] = MINUS_INV_EIGHT; } - sc_sub(aR[j*N+i].bytes, aL[j*N+i].bytes, rct::identity().bytes); } } - PERF_TIMER_STOP(PROVE_aLaR); + PERF_TIMER_STOP_BP(PROVE_aLaR); // DEBUG: Test to ensure this recovers the value #ifdef DEBUG_BP @@ -786,10 +541,10 @@ try_again: PERF_TIMER_START_BP(PROVE_step1); // PAPER LINES 38-39 rct::key alpha = rct::skGen(); - rct::key ve = vector_exponent(aL, aR); + rct::key ve = vector_exponent(aL8, aR8); rct::key A; - rct::addKeys(A, ve, rct::scalarmultBase(alpha)); - A = rct::scalarmultKey(A, INV_EIGHT); + sc_mul(tmp.bytes, alpha.bytes, INV_EIGHT.bytes); + rct::addKeys(A, ve, rct::scalarmultBase(tmp)); // PAPER LINES 40-42 rct::keyV sL = rct::skvGen(MN), sR = rct::skvGen(MN); @@ -803,21 +558,20 @@ try_again: rct::key y = hash_cache_mash(hash_cache, A, S); if (y == rct::zero()) { - PERF_TIMER_STOP(PROVE_step1); + PERF_TIMER_STOP_BP(PROVE_step1); MINFO("y is 0, trying again"); goto try_again; } rct::key z = hash_cache = rct::hash_to_scalar(y); if (z == rct::zero()) { - PERF_TIMER_STOP(PROVE_step1); + PERF_TIMER_STOP_BP(PROVE_step1); MINFO("z is 0, trying again"); goto try_again; } // Polynomial construction by coefficients - const auto zMN = vector_dup(z, MN); - rct::keyV l0 = vector_subtract(aL, zMN); + rct::keyV l0 = vector_subtract(aL, z); const rct::keyV &l1 = sL; // This computes the ugly sum/concatenation from PAPER LINE 65 @@ -837,7 +591,7 @@ try_again: } } - rct::keyV r0 = vector_add(aR, zMN); + rct::keyV r0 = vector_add(aR, z); const auto yMN = vector_powers(y, MN); r0 = hadamard(r0, yMN); r0 = vector_add(r0, zero_twos); @@ -850,22 +604,28 @@ try_again: sc_add(t1.bytes, t1_1.bytes, t1_2.bytes); rct::key t2 = inner_product(l1, r1); - PERF_TIMER_STOP(PROVE_step1); + PERF_TIMER_STOP_BP(PROVE_step1); PERF_TIMER_START_BP(PROVE_step2); // PAPER LINES 47-48 rct::key tau1 = rct::skGen(), tau2 = rct::skGen(); - rct::key T1 = rct::addKeys(rct::scalarmultH(t1), rct::scalarmultBase(tau1)); - T1 = rct::scalarmultKey(T1, INV_EIGHT); - rct::key T2 = rct::addKeys(rct::scalarmultH(t2), rct::scalarmultBase(tau2)); - T2 = rct::scalarmultKey(T2, INV_EIGHT); + rct::key T1, T2; + ge_p3 p3; + sc_mul(tmp.bytes, t1.bytes, INV_EIGHT.bytes); + sc_mul(tmp2.bytes, tau1.bytes, INV_EIGHT.bytes); + ge_double_scalarmult_base_vartime_p3(&p3, tmp.bytes, &ge_p3_H, tmp2.bytes); + ge_p3_tobytes(T1.bytes, &p3); + sc_mul(tmp.bytes, t2.bytes, INV_EIGHT.bytes); + sc_mul(tmp2.bytes, tau2.bytes, INV_EIGHT.bytes); + ge_double_scalarmult_base_vartime_p3(&p3, tmp.bytes, &ge_p3_H, tmp2.bytes); + ge_p3_tobytes(T2.bytes, &p3); // PAPER LINES 49-51 rct::key x = hash_cache_mash(hash_cache, z, T1, T2); if (x == rct::zero()) { - PERF_TIMER_STOP(PROVE_step2); + PERF_TIMER_STOP_BP(PROVE_step2); MINFO("x is 0, trying again"); goto try_again; } @@ -889,7 +649,7 @@ try_again: l = vector_add(l, vector_scalar(l1, x)); rct::keyV r = r0; r = vector_add(r, vector_scalar(r1, x)); - PERF_TIMER_STOP(PROVE_step2); + PERF_TIMER_STOP_BP(PROVE_step2); PERF_TIMER_START_BP(PROVE_step3); rct::key t = inner_product(l, r); @@ -907,24 +667,27 @@ try_again: rct::key x_ip = hash_cache_mash(hash_cache, x, taux, mu, t); if (x_ip == rct::zero()) { - PERF_TIMER_STOP(PROVE_step3); + PERF_TIMER_STOP_BP(PROVE_step3); MINFO("x_ip is 0, trying again"); goto try_again; } // These are used in the inner product rounds size_t nprime = MN; - rct::keyV Gprime(MN); - rct::keyV Hprime(MN); + std::vector<ge_p3> Gprime(MN); + std::vector<ge_p3> Hprime(MN); rct::keyV aprime(MN); rct::keyV bprime(MN); const rct::key yinv = invert(y); - rct::key yinvpow = rct::identity(); + rct::keyV yinvpow(MN); + yinvpow[0] = rct::identity(); + yinvpow[1] = yinv; for (size_t i = 0; i < MN; ++i) { - Gprime[i] = Gi[i]; - Hprime[i] = scalarmultKey(Hi_p3[i], yinvpow); - sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); + Gprime[i] = Gi_p3[i]; + Hprime[i] = Hi_p3[i]; + if (i > 1) + sc_mul(yinvpow[i].bytes, yinvpow[i-1].bytes, yinv.bytes); aprime[i] = l[i]; bprime[i] = r[i]; } @@ -932,53 +695,62 @@ try_again: rct::keyV R(logMN); int round = 0; rct::keyV w(logMN); // this is the challenge x in the inner product protocol - PERF_TIMER_STOP(PROVE_step3); + PERF_TIMER_STOP_BP(PROVE_step3); PERF_TIMER_START_BP(PROVE_step4); // PAPER LINE 13 + const rct::keyV *scale = &yinvpow; while (nprime > 1) { // PAPER LINE 15 nprime /= 2; // PAPER LINES 16-17 + PERF_TIMER_START_BP(PROVE_inner_product); rct::key cL = inner_product(slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size())); rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime)); + PERF_TIMER_STOP_BP(PROVE_inner_product); // PAPER LINES 18-19 - L[round] = vector_exponent_custom(slice(Gprime, nprime, Gprime.size()), slice(Hprime, 0, nprime), slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size())); + PERF_TIMER_START_BP(PROVE_LR); sc_mul(tmp.bytes, cL.bytes, x_ip.bytes); - rct::addKeys(L[round], L[round], rct::scalarmultH(tmp)); - L[round] = rct::scalarmultKey(L[round], INV_EIGHT); - R[round] = vector_exponent_custom(slice(Gprime, 0, nprime), slice(Hprime, nprime, Hprime.size()), slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime)); + L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, scale, &ge_p3_H, &tmp); sc_mul(tmp.bytes, cR.bytes, x_ip.bytes); - rct::addKeys(R[round], R[round], rct::scalarmultH(tmp)); - R[round] = rct::scalarmultKey(R[round], INV_EIGHT); + R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, scale, &ge_p3_H, &tmp); + PERF_TIMER_STOP_BP(PROVE_LR); // PAPER LINES 21-22 w[round] = hash_cache_mash(hash_cache, L[round], R[round]); if (w[round] == rct::zero()) { - PERF_TIMER_STOP(PROVE_step4); + PERF_TIMER_STOP_BP(PROVE_step4); MINFO("w[round] is 0, trying again"); goto try_again; } // PAPER LINES 24-25 const rct::key winv = invert(w[round]); - Gprime = hadamard2(vector_scalar2(slice(Gprime, 0, nprime), winv), vector_scalar2(slice(Gprime, nprime, Gprime.size()), w[round])); - Hprime = hadamard2(vector_scalar2(slice(Hprime, 0, nprime), w[round]), vector_scalar2(slice(Hprime, nprime, Hprime.size()), winv)); + if (nprime > 1) + { + PERF_TIMER_START_BP(PROVE_hadamard2); + hadamard_fold(Gprime, NULL, winv, w[round]); + hadamard_fold(Hprime, scale, w[round], winv); + PERF_TIMER_STOP_BP(PROVE_hadamard2); + } // PAPER LINES 28-29 + PERF_TIMER_START_BP(PROVE_prime); aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv)); bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round])); + PERF_TIMER_STOP_BP(PROVE_prime); + scale = NULL; ++round; } - PERF_TIMER_STOP(PROVE_step4); + PERF_TIMER_STOP_BP(PROVE_step4); // PAPER LINE 58 (with inclusions from PAPER LINE 8 and PAPER LINE 20) - return Bulletproof(V, A, S, T1, T2, taux, mu, L, R, aprime[0], bprime[0], t); + return Bulletproof(std::move(V), A, S, T1, T2, taux, mu, std::move(L), std::move(R), aprime[0], bprime[0], t); } Bulletproof bulletproof_PROVE(const std::vector<uint64_t> &v, const rct::keyV &gamma) @@ -1000,10 +772,17 @@ Bulletproof bulletproof_PROVE(const std::vector<uint64_t> &v, const rct::keyV &g sv[i].bytes[6] = (v[i] >> 48) & 255; sv[i].bytes[7] = (v[i] >> 56) & 255; } - PERF_TIMER_STOP(PROVE_v); + PERF_TIMER_STOP_BP(PROVE_v); return bulletproof_PROVE(sv, gamma); } +struct proof_data_t +{ + rct::key x, y, z, x_ip; + std::vector<rct::key> w; + size_t logM, inv_offset; +}; + /* Given a range proof, determine if it is valid */ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs) { @@ -1011,8 +790,17 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs) PERF_TIMER_START_BP(VERIFY); + const size_t logN = 6; + const size_t N = 1 << logN; + // sanity and figure out which proof is longest size_t max_length = 0; + size_t nV = 0; + std::vector<proof_data_t> proof_data; + proof_data.reserve(proofs.size()); + size_t inv_offset = 0; + std::vector<rct::key> to_invert; + to_invert.reserve(11 * sizeof(proofs)); for (const Bulletproof *p: proofs) { const Bulletproof &proof = *p; @@ -1029,44 +817,76 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs) CHECK_AND_ASSERT_MES(proof.L.size() > 0, false, "Empty proof"); max_length = std::max(max_length, proof.L.size()); + nV += proof.V.size(); + + // Reconstruct the challenges + PERF_TIMER_START_BP(VERIFY_start); + proof_data.resize(proof_data.size() + 1); + proof_data_t &pd = proof_data.back(); + rct::key hash_cache = rct::hash_to_scalar(proof.V); + pd.y = hash_cache_mash(hash_cache, proof.A, proof.S); + CHECK_AND_ASSERT_MES(!(pd.y == rct::zero()), false, "y == 0"); + pd.z = hash_cache = rct::hash_to_scalar(pd.y); + CHECK_AND_ASSERT_MES(!(pd.z == rct::zero()), false, "z == 0"); + pd.x = hash_cache_mash(hash_cache, pd.z, proof.T1, proof.T2); + CHECK_AND_ASSERT_MES(!(pd.x == rct::zero()), false, "x == 0"); + pd.x_ip = hash_cache_mash(hash_cache, pd.x, proof.taux, proof.mu, proof.t); + CHECK_AND_ASSERT_MES(!(pd.x_ip == rct::zero()), false, "x_ip == 0"); + PERF_TIMER_STOP_BP(VERIFY_start); + + size_t M; + for (pd.logM = 0; (M = 1<<pd.logM) <= maxM && M < proof.V.size(); ++pd.logM); + CHECK_AND_ASSERT_MES(proof.L.size() == 6+pd.logM, false, "Proof is not the expected size"); + + const size_t rounds = pd.logM+logN; + CHECK_AND_ASSERT_MES(rounds > 0, false, "Zero rounds"); + + PERF_TIMER_START_BP(VERIFY_line_21_22); + // PAPER LINES 21-22 + // The inner product challenges are computed per round + pd.w.resize(rounds); + for (size_t i = 0; i < rounds; ++i) + { + pd.w[i] = hash_cache_mash(hash_cache, proof.L[i], proof.R[i]); + CHECK_AND_ASSERT_MES(!(pd.w[i] == rct::zero()), false, "w[i] == 0"); + } + PERF_TIMER_STOP_BP(VERIFY_line_21_22); + + pd.inv_offset = inv_offset; + for (size_t i = 0; i < rounds; ++i) + to_invert.push_back(pd.w[i]); + to_invert.push_back(pd.y); + inv_offset += rounds + 1; } CHECK_AND_ASSERT_MES(max_length < 32, false, "At least one proof is too large"); size_t maxMN = 1u << max_length; - const size_t logN = 6; - const size_t N = 1 << logN; rct::key tmp; + std::vector<MultiexpData> multiexp_data; + multiexp_data.reserve(nV + (2 * (10/*logM*/ + logN) + 4) * proofs.size() + 2 * maxMN); + multiexp_data.resize(2 * maxMN); + + PERF_TIMER_START_BP(VERIFY_line_24_25_invert); + const std::vector<rct::key> inverses = invert(to_invert); + PERF_TIMER_STOP_BP(VERIFY_line_24_25_invert); + // setup weighted aggregates - rct::key Z0 = rct::identity(); rct::key z1 = rct::zero(); - rct::key Z2 = rct::identity(); rct::key z3 = rct::zero(); - rct::keyV z4(maxMN, rct::zero()), z5(maxMN, rct::zero()); - rct::key Y2 = rct::identity(), Y3 = rct::identity(), Y4 = rct::identity(); - rct::key y0 = rct::zero(), y1 = rct::zero(); + rct::keyV m_z4(maxMN, rct::zero()), m_z5(maxMN, rct::zero()); + rct::key m_y0 = rct::zero(), y1 = rct::zero(); + int proof_data_index = 0; for (const Bulletproof *p: proofs) { const Bulletproof &proof = *p; + const proof_data_t &pd = proof_data[proof_data_index++]; - size_t M, logM; - for (logM = 0; (M = 1<<logM) <= maxM && M < proof.V.size(); ++logM); - CHECK_AND_ASSERT_MES(proof.L.size() == 6+logM, false, "Proof is not the expected size"); + CHECK_AND_ASSERT_MES(proof.L.size() == 6+pd.logM, false, "Proof is not the expected size"); + const size_t M = 1 << pd.logM; const size_t MN = M*N; - rct::key weight = rct::skGen(); - - // Reconstruct the challenges - PERF_TIMER_START_BP(VERIFY_start); - rct::key hash_cache = rct::hash_to_scalar(proof.V); - rct::key y = hash_cache_mash(hash_cache, proof.A, proof.S); - CHECK_AND_ASSERT_MES(!(y == rct::zero()), false, "y == 0"); - rct::key z = hash_cache = rct::hash_to_scalar(y); - CHECK_AND_ASSERT_MES(!(z == rct::zero()), false, "z == 0"); - rct::key x = hash_cache_mash(hash_cache, z, proof.T1, proof.T2); - CHECK_AND_ASSERT_MES(!(x == rct::zero()), false, "x == 0"); - rct::key x_ip = hash_cache_mash(hash_cache, x, proof.taux, proof.mu, proof.t); - CHECK_AND_ASSERT_MES(!(x_ip == rct::zero()), false, "x_ip == 0"); - PERF_TIMER_STOP(VERIFY_start); + const rct::key weight_y = rct::skGen(); + const rct::key weight_z = rct::skGen(); // pre-multiply some points by 8 rct::keyV proof8_V = proof.V; for (rct::key &k: proof8_V) k = rct::scalarmult8(k); @@ -1075,177 +895,161 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs) rct::key proof8_T1 = rct::scalarmult8(proof.T1); rct::key proof8_T2 = rct::scalarmult8(proof.T2); rct::key proof8_S = rct::scalarmult8(proof.S); + rct::key proof8_A = rct::scalarmult8(proof.A); PERF_TIMER_START_BP(VERIFY_line_61); // PAPER LINE 61 - sc_muladd(y0.bytes, proof.taux.bytes, weight.bytes, y0.bytes); + sc_mulsub(m_y0.bytes, proof.taux.bytes, weight_y.bytes, m_y0.bytes); - const rct::keyV zpow = vector_powers(z, M+3); + const rct::keyV zpow = vector_powers(pd.z, M+3); rct::key k; - const rct::key ip1y = vector_power_sum(y, MN); + const rct::key ip1y = vector_power_sum(pd.y, MN); sc_mulsub(k.bytes, zpow[2].bytes, ip1y.bytes, rct::zero().bytes); for (size_t j = 1; j <= M; ++j) { CHECK_AND_ASSERT_MES(j+2 < zpow.size(), false, "invalid zpow index"); sc_mulsub(k.bytes, zpow[j+2].bytes, ip12.bytes, k.bytes); } - PERF_TIMER_STOP(VERIFY_line_61); + PERF_TIMER_STOP_BP(VERIFY_line_61); PERF_TIMER_START_BP(VERIFY_line_61rl_new); - sc_muladd(tmp.bytes, z.bytes, ip1y.bytes, k.bytes); - std::vector<MultiexpData> multiexp_data; - multiexp_data.reserve(proof.V.size()); + sc_muladd(tmp.bytes, pd.z.bytes, ip1y.bytes, k.bytes); sc_sub(tmp.bytes, proof.t.bytes, tmp.bytes); - sc_muladd(y1.bytes, tmp.bytes, weight.bytes, y1.bytes); + sc_muladd(y1.bytes, tmp.bytes, weight_y.bytes, y1.bytes); for (size_t j = 0; j < proof8_V.size(); j++) { - multiexp_data.emplace_back(zpow[j+2], proof8_V[j]); + sc_mul(tmp.bytes, zpow[j+2].bytes, weight_y.bytes); + multiexp_data.emplace_back(tmp, proof8_V[j]); } - rct::addKeys(Y2, Y2, rct::scalarmultKey(multiexp(multiexp_data, false), weight)); - sc_mul(tmp.bytes, x.bytes, weight.bytes); - rct::addKeys(Y3, Y3, rct::scalarmultKey(proof8_T1, tmp)); + sc_mul(tmp.bytes, pd.x.bytes, weight_y.bytes); + multiexp_data.emplace_back(tmp, proof8_T1); rct::key xsq; - sc_mul(xsq.bytes, x.bytes, x.bytes); - sc_mul(tmp.bytes, xsq.bytes, weight.bytes); - rct::addKeys(Y4, Y4, rct::scalarmultKey(proof8_T2, tmp)); - PERF_TIMER_STOP(VERIFY_line_61rl_new); + sc_mul(xsq.bytes, pd.x.bytes, pd.x.bytes); + sc_mul(tmp.bytes, xsq.bytes, weight_y.bytes); + multiexp_data.emplace_back(tmp, proof8_T2); + PERF_TIMER_STOP_BP(VERIFY_line_61rl_new); PERF_TIMER_START_BP(VERIFY_line_62); // PAPER LINE 62 - rct::addKeys(Z0, Z0, rct::scalarmultKey(rct::addKeys(rct::scalarmult8(proof.A), rct::scalarmultKey(proof8_S, x)), weight)); - PERF_TIMER_STOP(VERIFY_line_62); + multiexp_data.emplace_back(weight_z, proof8_A); + sc_mul(tmp.bytes, pd.x.bytes, weight_z.bytes); + multiexp_data.emplace_back(tmp, proof8_S); + PERF_TIMER_STOP_BP(VERIFY_line_62); // Compute the number of rounds for the inner product - const size_t rounds = logM+logN; + const size_t rounds = pd.logM+logN; CHECK_AND_ASSERT_MES(rounds > 0, false, "Zero rounds"); - PERF_TIMER_START_BP(VERIFY_line_21_22); - // PAPER LINES 21-22 - // The inner product challenges are computed per round - rct::keyV w(rounds); - for (size_t i = 0; i < rounds; ++i) - { - w[i] = hash_cache_mash(hash_cache, proof.L[i], proof.R[i]); - CHECK_AND_ASSERT_MES(!(w[i] == rct::zero()), false, "w[i] == 0"); - } - PERF_TIMER_STOP(VERIFY_line_21_22); - PERF_TIMER_START_BP(VERIFY_line_24_25); // Basically PAPER LINES 24-25 // Compute the curvepoints from G[i] and H[i] rct::key yinvpow = rct::identity(); rct::key ypow = rct::identity(); - PERF_TIMER_START_BP(VERIFY_line_24_25_invert); - const rct::key yinv = invert(y); - rct::keyV winv(rounds); - for (size_t i = 0; i < rounds; ++i) - winv[i] = invert(w[i]); - PERF_TIMER_STOP(VERIFY_line_24_25_invert); + const rct::key *winv = &inverses[pd.inv_offset]; + const rct::key yinv = inverses[pd.inv_offset + rounds]; + + // precalc + PERF_TIMER_START_BP(VERIFY_line_24_25_precalc); + rct::keyV w_cache(1<<rounds); + w_cache[0] = winv[0]; + w_cache[1] = pd.w[0]; + for (size_t j = 1; j < rounds; ++j) + { + const size_t slots = 1<<(j+1); + for (size_t s = slots; s-- > 0; --s) + { + sc_mul(w_cache[s].bytes, w_cache[s/2].bytes, pd.w[j].bytes); + sc_mul(w_cache[s-1].bytes, w_cache[s/2].bytes, winv[j].bytes); + } + } + PERF_TIMER_STOP_BP(VERIFY_line_24_25_precalc); for (size_t i = 0; i < MN; ++i) { - // Convert the index to binary IN REVERSE and construct the scalar exponent rct::key g_scalar = proof.a; rct::key h_scalar; - sc_mul(h_scalar.bytes, proof.b.bytes, yinvpow.bytes); + if (i == 0) + h_scalar = proof.b; + else + sc_mul(h_scalar.bytes, proof.b.bytes, yinvpow.bytes); - for (size_t j = rounds; j-- > 0; ) - { - size_t J = w.size() - j - 1; - - if ((i & (((size_t)1)<<j)) == 0) - { - sc_mul(g_scalar.bytes, g_scalar.bytes, winv[J].bytes); - sc_mul(h_scalar.bytes, h_scalar.bytes, w[J].bytes); - } - else - { - sc_mul(g_scalar.bytes, g_scalar.bytes, w[J].bytes); - sc_mul(h_scalar.bytes, h_scalar.bytes, winv[J].bytes); - } - } + // Convert the index to binary IN REVERSE and construct the scalar exponent + sc_mul(g_scalar.bytes, g_scalar.bytes, w_cache[i].bytes); + sc_mul(h_scalar.bytes, h_scalar.bytes, w_cache[(~i) & (MN-1)].bytes); // Adjust the scalars using the exponents from PAPER LINE 62 - sc_add(g_scalar.bytes, g_scalar.bytes, z.bytes); + sc_add(g_scalar.bytes, g_scalar.bytes, pd.z.bytes); CHECK_AND_ASSERT_MES(2+i/N < zpow.size(), false, "invalid zpow index"); CHECK_AND_ASSERT_MES(i%N < twoN.size(), false, "invalid twoN index"); sc_mul(tmp.bytes, zpow[2+i/N].bytes, twoN[i%N].bytes); - sc_muladd(tmp.bytes, z.bytes, ypow.bytes, tmp.bytes); - sc_mulsub(h_scalar.bytes, tmp.bytes, yinvpow.bytes, h_scalar.bytes); + if (i == 0) + { + sc_add(tmp.bytes, tmp.bytes, pd.z.bytes); + sc_sub(h_scalar.bytes, h_scalar.bytes, tmp.bytes); + } + else + { + sc_muladd(tmp.bytes, pd.z.bytes, ypow.bytes, tmp.bytes); + sc_mulsub(h_scalar.bytes, tmp.bytes, yinvpow.bytes, h_scalar.bytes); + } - sc_muladd(z4[i].bytes, g_scalar.bytes, weight.bytes, z4[i].bytes); - sc_muladd(z5[i].bytes, h_scalar.bytes, weight.bytes, z5[i].bytes); + sc_mulsub(m_z4[i].bytes, g_scalar.bytes, weight_z.bytes, m_z4[i].bytes); + sc_mulsub(m_z5[i].bytes, h_scalar.bytes, weight_z.bytes, m_z5[i].bytes); - if (i != MN-1) + if (i == 0) + { + yinvpow = yinv; + ypow = pd.y; + } + else if (i != MN-1) { sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); - sc_mul(ypow.bytes, ypow.bytes, y.bytes); + sc_mul(ypow.bytes, ypow.bytes, pd.y.bytes); } } - PERF_TIMER_STOP(VERIFY_line_24_25); + PERF_TIMER_STOP_BP(VERIFY_line_24_25); // PAPER LINE 26 PERF_TIMER_START_BP(VERIFY_line_26_new); - multiexp_data.clear(); - multiexp_data.reserve(2*rounds); - - sc_muladd(z1.bytes, proof.mu.bytes, weight.bytes, z1.bytes); + sc_muladd(z1.bytes, proof.mu.bytes, weight_z.bytes, z1.bytes); for (size_t i = 0; i < rounds; ++i) { - sc_mul(tmp.bytes, w[i].bytes, w[i].bytes); + sc_mul(tmp.bytes, pd.w[i].bytes, pd.w[i].bytes); + sc_mul(tmp.bytes, tmp.bytes, weight_z.bytes); multiexp_data.emplace_back(tmp, proof8_L[i]); sc_mul(tmp.bytes, winv[i].bytes, winv[i].bytes); + sc_mul(tmp.bytes, tmp.bytes, weight_z.bytes); multiexp_data.emplace_back(tmp, proof8_R[i]); } - rct::key acc = multiexp(multiexp_data, false); - rct::addKeys(Z2, Z2, rct::scalarmultKey(acc, weight)); sc_mulsub(tmp.bytes, proof.a.bytes, proof.b.bytes, proof.t.bytes); - sc_mul(tmp.bytes, tmp.bytes, x_ip.bytes); - sc_muladd(z3.bytes, tmp.bytes, weight.bytes, z3.bytes); - PERF_TIMER_STOP(VERIFY_line_26_new); + sc_mul(tmp.bytes, tmp.bytes, pd.x_ip.bytes); + sc_muladd(z3.bytes, tmp.bytes, weight_z.bytes, z3.bytes); + PERF_TIMER_STOP_BP(VERIFY_line_26_new); } // now check all proofs at once PERF_TIMER_START_BP(VERIFY_step2_check); - ge_p3 check1; - ge_scalarmult_base(&check1, y0.bytes); - addKeys_acc_p3(&check1, y1, rct::H); - sub_acc_p3(&check1, Y2); - sub_acc_p3(&check1, Y3); - sub_acc_p3(&check1, Y4); - if (!ge_p3_is_point_at_infinity(&check1)) - { - MERROR("Verification failure at step 1"); - return false; - } - ge_p3 check2; - sc_sub(tmp.bytes, rct::zero().bytes, z1.bytes); - ge_double_scalarmult_base_vartime_p3(&check2, z3.bytes, &ge_p3_H, tmp.bytes); - add_acc_p3(&check2, Z0); - add_acc_p3(&check2, Z2); - - std::vector<MultiexpData> multiexp_data; - multiexp_data.reserve(2 * maxMN); + sc_sub(tmp.bytes, m_y0.bytes, z1.bytes); + multiexp_data.emplace_back(tmp, rct::G); + sc_sub(tmp.bytes, z3.bytes, y1.bytes); + multiexp_data.emplace_back(tmp, rct::H); for (size_t i = 0; i < maxMN; ++i) { - sc_sub(tmp.bytes, rct::zero().bytes, z4[i].bytes); - multiexp_data.emplace_back(tmp, Gi_p3[i]); - sc_sub(tmp.bytes, rct::zero().bytes, z5[i].bytes); - multiexp_data.emplace_back(tmp, Hi_p3[i]); + multiexp_data[i * 2] = {m_z4[i], Gi_p3[i]}; + multiexp_data[i * 2 + 1] = {m_z5[i], Hi_p3[i]}; } - add_acc_p3(&check2, multiexp(multiexp_data, true)); - PERF_TIMER_STOP(VERIFY_step2_check); - - if (!ge_p3_is_point_at_infinity(&check2)) + if (!(multiexp(multiexp_data, 2 * maxMN) == rct::identity())) { - MERROR("Verification failure at step 2"); + PERF_TIMER_STOP_BP(VERIFY_step2_check); + MERROR("Verification failure"); return false; } + PERF_TIMER_STOP_BP(VERIFY_step2_check); - PERF_TIMER_STOP(VERIFY); + PERF_TIMER_STOP_BP(VERIFY); return true; } |