diff options
Diffstat (limited to 'src/ringct/multiexp.cc')
-rw-r--r-- | src/ringct/multiexp.cc | 80 |
1 files changed, 52 insertions, 28 deletions
diff --git a/src/ringct/multiexp.cc b/src/ringct/multiexp.cc index 7ed9672f2..4f16bd588 100644 --- a/src/ringct/multiexp.cc +++ b/src/ringct/multiexp.cc @@ -259,42 +259,66 @@ rct::key bos_coster_heap_conv_robust(std::vector<MultiexpData> data) return res; } -rct::key straus(const std::vector<MultiexpData> &data, bool HiGi) +struct straus_cached_data { - MULTIEXP_PERF(PERF_TIMER_UNIT(straus, 1000000)); + std::vector<std::vector<ge_cached>> multiples; +}; - MULTIEXP_PERF(PERF_TIMER_START_UNIT(setup, 1000000)); - static constexpr unsigned int c = 4; - static constexpr unsigned int mask = (1<<c)-1; - static std::vector<std::vector<ge_cached>> HiGi_multiples; - std::vector<std::vector<ge_cached>> local_multiples, &multiples = HiGi ? HiGi_multiples : local_multiples; +static constexpr unsigned int STRAUS_C = 4; + +std::shared_ptr<straus_cached_data> straus_init_cache(const std::vector<MultiexpData> &data) +{ + MULTIEXP_PERF(PERF_TIMER_START_UNIT(multiples, 1000000)); ge_cached cached; ge_p1p1 p1; ge_p3 p3; + std::shared_ptr<straus_cached_data> cache(new straus_cached_data()); - std::vector<uint8_t> skip(data.size()); - for (size_t i = 0; i < data.size(); ++i) - skip[i] = data[i].scalar == rct::zero() || !memcmp(&data[i].point, &ge_p3_identity, sizeof(ge_p3)); - - MULTIEXP_PERF(PERF_TIMER_START_UNIT(multiples, 1000000)); - multiples.resize(1<<c); - size_t offset = multiples[1].size(); - multiples[1].resize(std::max(offset, data.size())); + cache->multiples.resize(1<<STRAUS_C); + size_t offset = cache->multiples[1].size(); + cache->multiples[1].resize(std::max(offset, data.size())); for (size_t i = offset; i < data.size(); ++i) - ge_p3_to_cached(&multiples[1][i], &data[i].point); - for (size_t i=2;i<1<<c;++i) - multiples[i].resize(std::max(offset, data.size())); + ge_p3_to_cached(&cache->multiples[1][i], &data[i].point); + for (size_t i=2;i<1<<STRAUS_C;++i) + cache->multiples[i].resize(std::max(offset, data.size())); for (size_t j=offset;j<data.size();++j) { - for (size_t i=2;i<1<<c;++i) + for (size_t i=2;i<1<<STRAUS_C;++i) { - ge_add(&p1, &data[j].point, &multiples[i-1][j]); + ge_add(&p1, &data[j].point, &cache->multiples[i-1][j]); ge_p1p1_to_p3(&p3, &p1); - ge_p3_to_cached(&multiples[i][j], &p3); + ge_p3_to_cached(&cache->multiples[i][j], &p3); } } MULTIEXP_PERF(PERF_TIMER_STOP(multiples)); + return cache; +} + +size_t straus_get_cache_size(const std::shared_ptr<straus_cached_data> &cache) +{ + size_t sz = 0; + for (const auto &e0: cache->multiples) + sz += e0.size() * sizeof(ge_p3); + return sz; +} + +rct::key straus(const std::vector<MultiexpData> &data, const std::shared_ptr<straus_cached_data> &cache) +{ + MULTIEXP_PERF(PERF_TIMER_UNIT(straus, 1000000)); + bool HiGi = cache != NULL; + + MULTIEXP_PERF(PERF_TIMER_START_UNIT(setup, 1000000)); + static constexpr unsigned int mask = (1<<STRAUS_C)-1; + std::shared_ptr<straus_cached_data> local_cache = cache == NULL ? straus_init_cache(data) : cache; + ge_cached cached; + ge_p1p1 p1; + ge_p3 p3; + + std::vector<uint8_t> skip(data.size()); + for (size_t i = 0; i < data.size(); ++i) + skip[i] = data[i].scalar == rct::zero() || !memcmp(&data[i].point, &ge_p3_identity, sizeof(ge_p3)); + MULTIEXP_PERF(PERF_TIMER_START_UNIT(digits, 1000000)); std::vector<std::vector<uint8_t>> digits; digits.resize(data.size()); @@ -305,7 +329,7 @@ rct::key straus(const std::vector<MultiexpData> &data, bool HiGi) memcpy(bytes33, data[j].scalar.bytes, 32); bytes33[32] = 0; #if 1 - static_assert(c == 4, "optimized version needs c == 4"); + static_assert(STRAUS_C == 4, "optimized version needs STRAUS_C == 4"); const unsigned char *bytes = bytes33; unsigned int i; for (i = 0; i < 256; i += 8, bytes++) @@ -339,22 +363,22 @@ rct::key straus(const std::vector<MultiexpData> &data, bool HiGi) maxscalar = data[i].scalar; size_t i = 0; while (i < 256 && !(maxscalar < pow2(i))) - i += c; + i += STRAUS_C; MULTIEXP_PERF(PERF_TIMER_STOP(setup)); ge_p3 res_p3 = ge_p3_identity; - if (!(i < c)) + if (!(i < STRAUS_C)) goto skipfirst; - while (!(i < c)) + while (!(i < STRAUS_C)) { - for (size_t j = 0; j < c; ++j) + for (size_t j = 0; j < STRAUS_C; ++j) { ge_p3_to_cached(&cached, &res_p3); ge_add(&p1, &res_p3, &cached); ge_p1p1_to_p3(&res_p3, &p1); } skipfirst: - i -= c; + i -= STRAUS_C; for (size_t j = 0; j < data.size(); ++j) { if (skip[j]) @@ -362,7 +386,7 @@ skipfirst: int digit = digits[j][i]; if (digit) { - ge_add(&p1, &res_p3, &multiples[digit][j]); + ge_add(&p1, &res_p3, &local_cache->multiples[digit][j]); ge_p1p1_to_p3(&res_p3, &p1); } } |