aboutsummaryrefslogtreecommitdiff
path: root/src/ringct/bulletproofs.cc
diff options
context:
space:
mode:
authormoneromooo-monero <moneromooo-monero@users.noreply.github.com>2018-08-22 22:30:14 +0000
committermoneromooo-monero <moneromooo-monero@users.noreply.github.com>2018-10-22 16:07:44 +0000
commit8629a42cf6e4650b552925f7637761b8e7ee66e3 (patch)
tree5dc8195e33c2cd67f1643edee12653704d913890 /src/ringct/bulletproofs.cc
parentbulletproofs: merge multiexps as per sarang's new python code (diff)
downloadmonero-8629a42cf6e4650b552925f7637761b8e7ee66e3.tar.xz
bulletproofs: rework flow to use sarang's fast batch inversion code
Diffstat (limited to 'src/ringct/bulletproofs.cc')
-rw-r--r--src/ringct/bulletproofs.cc233
1 files changed, 155 insertions, 78 deletions
diff --git a/src/ringct/bulletproofs.cc b/src/ringct/bulletproofs.cc
index 549e52296..d9961cb20 100644
--- a/src/ringct/bulletproofs.cc
+++ b/src/ringct/bulletproofs.cc
@@ -29,8 +29,6 @@
// Adapted from Java code by Sarang Noether
#include <stdlib.h>
-#include <openssl/ssl.h>
-#include <openssl/bn.h>
#include <boost/thread/mutex.hpp>
#include "misc_log_ex.h"
#include "common/perf_timer.h"
@@ -289,37 +287,59 @@ static rct::keyV vector_dup(const rct::key &x, size_t N)
return rct::keyV(N, x);
}
-static rct::key switch_endianness(rct::key k)
+static rct::key sm(rct::key y, int n, const rct::key &x)
{
- std::reverse(k.bytes, k.bytes + sizeof(k));
- return k;
+ while (n--)
+ sc_mul(y.bytes, y.bytes, y.bytes);
+ sc_mul(y.bytes, y.bytes, x.bytes);
+ return y;
}
-/* Compute the inverse of a scalar, the stupid way */
+/* Compute the inverse of a scalar, the clever way */
static rct::key invert(const rct::key &x)
{
- rct::key inv;
-
- BN_CTX *ctx = BN_CTX_new();
- BIGNUM *X = BN_new();
- BIGNUM *L = BN_new();
- BIGNUM *I = BN_new();
-
- BN_bin2bn(switch_endianness(x).bytes, sizeof(rct::key), X);
- BN_bin2bn(switch_endianness(rct::curveOrder()).bytes, sizeof(rct::key), L);
-
- CHECK_AND_ASSERT_THROW_MES(BN_mod_inverse(I, X, L, ctx), "Failed to invert");
+ rct::key _1, _10, _100, _11, _101, _111, _1001, _1011, _1111;
+
+ _1 = x;
+ sc_mul(_10.bytes, _1.bytes, _1.bytes);
+ sc_mul(_100.bytes, _10.bytes, _10.bytes);
+ sc_mul(_11.bytes, _10.bytes, _1.bytes);
+ sc_mul(_101.bytes, _10.bytes, _11.bytes);
+ sc_mul(_111.bytes, _10.bytes, _101.bytes);
+ sc_mul(_1001.bytes, _10.bytes, _111.bytes);
+ sc_mul(_1011.bytes, _10.bytes, _1001.bytes);
+ sc_mul(_1111.bytes, _100.bytes, _1011.bytes);
- const int len = BN_num_bytes(I);
- CHECK_AND_ASSERT_THROW_MES((size_t)len <= sizeof(rct::key), "Invalid number length");
- inv = rct::zero();
- BN_bn2bin(I, inv.bytes);
- std::reverse(inv.bytes, inv.bytes + len);
-
- BN_free(I);
- BN_free(L);
- BN_free(X);
- BN_CTX_free(ctx);
+ rct::key inv;
+ sc_mul(inv.bytes, _1111.bytes, _1.bytes);
+
+ inv = sm(inv, 123 + 3, _101);
+ inv = sm(inv, 2 + 2, _11);
+ inv = sm(inv, 1 + 4, _1111);
+ inv = sm(inv, 1 + 4, _1111);
+ inv = sm(inv, 4, _1001);
+ inv = sm(inv, 2, _11);
+ inv = sm(inv, 1 + 4, _1111);
+ inv = sm(inv, 1 + 3, _101);
+ inv = sm(inv, 3 + 3, _101);
+ inv = sm(inv, 3, _111);
+ inv = sm(inv, 1 + 4, _1111);
+ inv = sm(inv, 2 + 3, _111);
+ inv = sm(inv, 2 + 2, _11);
+ inv = sm(inv, 1 + 4, _1011);
+ inv = sm(inv, 2 + 4, _1011);
+ inv = sm(inv, 6 + 4, _1001);
+ inv = sm(inv, 2 + 2, _11);
+ inv = sm(inv, 3 + 2, _11);
+ inv = sm(inv, 3 + 2, _11);
+ inv = sm(inv, 1 + 4, _1001);
+ inv = sm(inv, 1 + 3, _111);
+ inv = sm(inv, 2 + 4, _1111);
+ inv = sm(inv, 1 + 4, _1011);
+ inv = sm(inv, 3, _101);
+ inv = sm(inv, 2 + 4, _1111);
+ inv = sm(inv, 3, _101);
+ inv = sm(inv, 1 + 2, _11);
#ifdef DEBUG_BP
rct::key tmp;
@@ -329,6 +349,34 @@ static rct::key invert(const rct::key &x)
return inv;
}
+static rct::keyV invert(rct::keyV x)
+{
+ rct::keyV scratch;
+ scratch.reserve(x.size());
+
+ rct::key acc = rct::identity();
+ for (size_t n = 0; n < x.size(); ++n)
+ {
+ scratch.push_back(acc);
+ if (n == 0)
+ acc = x[0];
+ else
+ sc_mul(acc.bytes, acc.bytes, x[n].bytes);
+ }
+
+ acc = invert(acc);
+
+ rct::key tmp;
+ for (int i = x.size(); i-- > 0; )
+ {
+ sc_mul(tmp.bytes, acc.bytes, x[i].bytes);
+ sc_mul(x[i].bytes, acc.bytes, scratch[i].bytes);
+ acc = tmp;
+ }
+
+ return x;
+}
+
/* Compute the slice of a vector */
static rct::keyV slice(const rct::keyV &a, size_t start, size_t stop)
{
@@ -702,6 +750,13 @@ Bulletproof bulletproof_PROVE(const std::vector<uint64_t> &v, const rct::keyV &g
return bulletproof_PROVE(sv, gamma);
}
+struct proof_data_t
+{
+ rct::key x, y, z, x_ip;
+ std::vector<rct::key> w;
+ size_t logM, inv_offset;
+};
+
/* Given a range proof, determine if it is valid */
bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
{
@@ -709,9 +764,17 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
PERF_TIMER_START_BP(VERIFY);
+ const size_t logN = 6;
+ const size_t N = 1 << logN;
+
// sanity and figure out which proof is longest
size_t max_length = 0;
size_t nV = 0;
+ std::vector<proof_data_t> proof_data;
+ proof_data.reserve(proofs.size());
+ size_t inv_offset = 0;
+ std::vector<rct::key> to_invert;
+ to_invert.reserve(11 * sizeof(proofs));
for (const Bulletproof *p: proofs)
{
const Bulletproof &proof = *p;
@@ -729,46 +792,75 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
max_length = std::max(max_length, proof.L.size());
nV += proof.V.size();
+
+ // Reconstruct the challenges
+ PERF_TIMER_START_BP(VERIFY_start);
+ proof_data.resize(proof_data.size() + 1);
+ proof_data_t &pd = proof_data.back();
+ rct::key hash_cache = rct::hash_to_scalar(proof.V);
+ pd.y = hash_cache_mash(hash_cache, proof.A, proof.S);
+ CHECK_AND_ASSERT_MES(!(pd.y == rct::zero()), false, "y == 0");
+ pd.z = hash_cache = rct::hash_to_scalar(pd.y);
+ CHECK_AND_ASSERT_MES(!(pd.z == rct::zero()), false, "z == 0");
+ pd.x = hash_cache_mash(hash_cache, pd.z, proof.T1, proof.T2);
+ CHECK_AND_ASSERT_MES(!(pd.x == rct::zero()), false, "x == 0");
+ pd.x_ip = hash_cache_mash(hash_cache, pd.x, proof.taux, proof.mu, proof.t);
+ CHECK_AND_ASSERT_MES(!(pd.x_ip == rct::zero()), false, "x_ip == 0");
+ PERF_TIMER_STOP(VERIFY_start);
+
+ size_t M;
+ for (pd.logM = 0; (M = 1<<pd.logM) <= maxM && M < proof.V.size(); ++pd.logM);
+ CHECK_AND_ASSERT_MES(proof.L.size() == 6+pd.logM, false, "Proof is not the expected size");
+
+ const size_t rounds = pd.logM+logN;
+ CHECK_AND_ASSERT_MES(rounds > 0, false, "Zero rounds");
+
+ PERF_TIMER_START_BP(VERIFY_line_21_22);
+ // PAPER LINES 21-22
+ // The inner product challenges are computed per round
+ pd.w.resize(rounds);
+ for (size_t i = 0; i < rounds; ++i)
+ {
+ pd.w[i] = hash_cache_mash(hash_cache, proof.L[i], proof.R[i]);
+ CHECK_AND_ASSERT_MES(!(pd.w[i] == rct::zero()), false, "w[i] == 0");
+ }
+ PERF_TIMER_STOP(VERIFY_line_21_22);
+
+ pd.inv_offset = inv_offset;
+ for (size_t i = 0; i < rounds; ++i)
+ to_invert.push_back(pd.w[i]);
+ to_invert.push_back(pd.y);
+ inv_offset += rounds + 1;
}
CHECK_AND_ASSERT_MES(max_length < 32, false, "At least one proof is too large");
size_t maxMN = 1u << max_length;
- const size_t logN = 6;
- const size_t N = 1 << logN;
rct::key tmp;
std::vector<MultiexpData> multiexp_data;
multiexp_data.reserve(nV + (2 * (10/*logM*/ + logN) + 4) * proofs.size() + 2 * maxMN);
+ PERF_TIMER_START_BP(VERIFY_line_24_25_invert);
+ const std::vector<rct::key> inverses = invert(to_invert);
+ PERF_TIMER_STOP(VERIFY_line_24_25_invert);
+
// setup weighted aggregates
rct::key z1 = rct::zero();
rct::key z3 = rct::zero();
rct::keyV z4(maxMN, rct::zero()), z5(maxMN, rct::zero());
rct::key y0 = rct::zero(), y1 = rct::zero();
+ int proof_data_index = 0;
for (const Bulletproof *p: proofs)
{
const Bulletproof &proof = *p;
+ const proof_data_t &pd = proof_data[proof_data_index++];
- size_t M, logM;
- for (logM = 0; (M = 1<<logM) <= maxM && M < proof.V.size(); ++logM);
- CHECK_AND_ASSERT_MES(proof.L.size() == 6+logM, false, "Proof is not the expected size");
+ CHECK_AND_ASSERT_MES(proof.L.size() == 6+pd.logM, false, "Proof is not the expected size");
+ const size_t M = 1 << pd.logM;
const size_t MN = M*N;
const rct::key weight_y = rct::skGen();
const rct::key weight_z = rct::skGen();
- // Reconstruct the challenges
- PERF_TIMER_START_BP(VERIFY_start);
- rct::key hash_cache = rct::hash_to_scalar(proof.V);
- rct::key y = hash_cache_mash(hash_cache, proof.A, proof.S);
- CHECK_AND_ASSERT_MES(!(y == rct::zero()), false, "y == 0");
- rct::key z = hash_cache = rct::hash_to_scalar(y);
- CHECK_AND_ASSERT_MES(!(z == rct::zero()), false, "z == 0");
- rct::key x = hash_cache_mash(hash_cache, z, proof.T1, proof.T2);
- CHECK_AND_ASSERT_MES(!(x == rct::zero()), false, "x == 0");
- rct::key x_ip = hash_cache_mash(hash_cache, x, proof.taux, proof.mu, proof.t);
- CHECK_AND_ASSERT_MES(!(x_ip == rct::zero()), false, "x_ip == 0");
- PERF_TIMER_STOP(VERIFY_start);
-
// pre-multiply some points by 8
rct::keyV proof8_V = proof.V; for (rct::key &k: proof8_V) k = rct::scalarmult8(k);
rct::keyV proof8_L = proof.L; for (rct::key &k: proof8_L) k = rct::scalarmult8(k);
@@ -782,10 +874,10 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
// PAPER LINE 61
sc_muladd(y0.bytes, proof.taux.bytes, weight_y.bytes, y0.bytes);
- const rct::keyV zpow = vector_powers(z, M+3);
+ const rct::keyV zpow = vector_powers(pd.z, M+3);
rct::key k;
- const rct::key ip1y = vector_power_sum(y, MN);
+ const rct::key ip1y = vector_power_sum(pd.y, MN);
sc_mulsub(k.bytes, zpow[2].bytes, ip1y.bytes, rct::zero().bytes);
for (size_t j = 1; j <= M; ++j)
{
@@ -795,7 +887,7 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
PERF_TIMER_STOP(VERIFY_line_61);
PERF_TIMER_START_BP(VERIFY_line_61rl_new);
- sc_muladd(tmp.bytes, z.bytes, ip1y.bytes, k.bytes);
+ sc_muladd(tmp.bytes, pd.z.bytes, ip1y.bytes, k.bytes);
sc_sub(tmp.bytes, proof.t.bytes, tmp.bytes);
sc_muladd(y1.bytes, tmp.bytes, weight_y.bytes, y1.bytes);
for (size_t j = 0; j < proof8_V.size(); j++)
@@ -803,10 +895,10 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
sc_mul(tmp.bytes, zpow[j+2].bytes, weight_y.bytes);
multiexp_data.emplace_back(tmp, proof8_V[j]);
}
- sc_mul(tmp.bytes, x.bytes, weight_y.bytes);
+ sc_mul(tmp.bytes, pd.x.bytes, weight_y.bytes);
multiexp_data.emplace_back(tmp, proof8_T1);
rct::key xsq;
- sc_mul(xsq.bytes, x.bytes, x.bytes);
+ sc_mul(xsq.bytes, pd.x.bytes, pd.x.bytes);
sc_mul(tmp.bytes, xsq.bytes, weight_y.bytes);
multiexp_data.emplace_back(tmp, proof8_T2);
PERF_TIMER_STOP(VERIFY_line_61rl_new);
@@ -814,49 +906,34 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
PERF_TIMER_START_BP(VERIFY_line_62);
// PAPER LINE 62
multiexp_data.emplace_back(weight_z, proof8_A);
- sc_mul(tmp.bytes, x.bytes, weight_z.bytes);
+ sc_mul(tmp.bytes, pd.x.bytes, weight_z.bytes);
multiexp_data.emplace_back(tmp, proof8_S);
PERF_TIMER_STOP(VERIFY_line_62);
// Compute the number of rounds for the inner product
- const size_t rounds = logM+logN;
+ const size_t rounds = pd.logM+logN;
CHECK_AND_ASSERT_MES(rounds > 0, false, "Zero rounds");
- PERF_TIMER_START_BP(VERIFY_line_21_22);
- // PAPER LINES 21-22
- // The inner product challenges are computed per round
- rct::keyV w(rounds);
- for (size_t i = 0; i < rounds; ++i)
- {
- w[i] = hash_cache_mash(hash_cache, proof.L[i], proof.R[i]);
- CHECK_AND_ASSERT_MES(!(w[i] == rct::zero()), false, "w[i] == 0");
- }
- PERF_TIMER_STOP(VERIFY_line_21_22);
-
PERF_TIMER_START_BP(VERIFY_line_24_25);
// Basically PAPER LINES 24-25
// Compute the curvepoints from G[i] and H[i]
rct::key yinvpow = rct::identity();
rct::key ypow = rct::identity();
- PERF_TIMER_START_BP(VERIFY_line_24_25_invert);
- const rct::key yinv = invert(y);
- rct::keyV winv(rounds);
- for (size_t i = 0; i < rounds; ++i)
- winv[i] = invert(w[i]);
- PERF_TIMER_STOP(VERIFY_line_24_25_invert);
+ const rct::key *winv = &inverses[pd.inv_offset];
+ const rct::key yinv = inverses[pd.inv_offset + rounds];
// precalc
PERF_TIMER_START_BP(VERIFY_line_24_25_precalc);
rct::keyV w_cache(1<<rounds);
w_cache[0] = winv[0];
- w_cache[1] = w[0];
+ w_cache[1] = pd.w[0];
for (size_t j = 1; j < rounds; ++j)
{
const size_t slots = 1<<(j+1);
for (size_t s = slots; s-- > 0; --s)
{
- sc_mul(w_cache[s].bytes, w_cache[s/2].bytes, w[j].bytes);
+ sc_mul(w_cache[s].bytes, w_cache[s/2].bytes, pd.w[j].bytes);
sc_mul(w_cache[s-1].bytes, w_cache[s/2].bytes, winv[j].bytes);
}
}
@@ -876,18 +953,18 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
sc_mul(h_scalar.bytes, h_scalar.bytes, w_cache[(~i) & (MN-1)].bytes);
// Adjust the scalars using the exponents from PAPER LINE 62
- sc_add(g_scalar.bytes, g_scalar.bytes, z.bytes);
+ sc_add(g_scalar.bytes, g_scalar.bytes, pd.z.bytes);
CHECK_AND_ASSERT_MES(2+i/N < zpow.size(), false, "invalid zpow index");
CHECK_AND_ASSERT_MES(i%N < twoN.size(), false, "invalid twoN index");
sc_mul(tmp.bytes, zpow[2+i/N].bytes, twoN[i%N].bytes);
if (i == 0)
{
- sc_add(tmp.bytes, tmp.bytes, z.bytes);
+ sc_add(tmp.bytes, tmp.bytes, pd.z.bytes);
sc_sub(h_scalar.bytes, h_scalar.bytes, tmp.bytes);
}
else
{
- sc_muladd(tmp.bytes, z.bytes, ypow.bytes, tmp.bytes);
+ sc_muladd(tmp.bytes, pd.z.bytes, ypow.bytes, tmp.bytes);
sc_mulsub(h_scalar.bytes, tmp.bytes, yinvpow.bytes, h_scalar.bytes);
}
@@ -897,12 +974,12 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
if (i == 0)
{
yinvpow = yinv;
- ypow = y;
+ ypow = pd.y;
}
else if (i != MN-1)
{
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
- sc_mul(ypow.bytes, ypow.bytes, y.bytes);
+ sc_mul(ypow.bytes, ypow.bytes, pd.y.bytes);
}
}
@@ -913,7 +990,7 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
sc_muladd(z1.bytes, proof.mu.bytes, weight_z.bytes, z1.bytes);
for (size_t i = 0; i < rounds; ++i)
{
- sc_mul(tmp.bytes, w[i].bytes, w[i].bytes);
+ sc_mul(tmp.bytes, pd.w[i].bytes, pd.w[i].bytes);
sc_mul(tmp.bytes, tmp.bytes, weight_z.bytes);
multiexp_data.emplace_back(tmp, proof8_L[i]);
sc_mul(tmp.bytes, winv[i].bytes, winv[i].bytes);
@@ -921,7 +998,7 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
multiexp_data.emplace_back(tmp, proof8_R[i]);
}
sc_mulsub(tmp.bytes, proof.a.bytes, proof.b.bytes, proof.t.bytes);
- sc_mul(tmp.bytes, tmp.bytes, x_ip.bytes);
+ sc_mul(tmp.bytes, tmp.bytes, pd.x_ip.bytes);
sc_muladd(z3.bytes, tmp.bytes, weight_z.bytes, z3.bytes);
PERF_TIMER_STOP(VERIFY_line_26_new);
}