diff options
author | Riccardo Spagni <ric@spagni.net> | 2016-11-08 22:37:43 +0200 |
---|---|---|
committer | Riccardo Spagni <ric@spagni.net> | 2016-11-08 22:37:43 +0200 |
commit | dce47d52af614b339c1068fe76d815cb93faca30 (patch) | |
tree | 497a2476daca11e9a4b4532740a73b218e74ea0d | |
parent | Merge pull request #1287 (diff) | |
parent | adding thread_group for managing async tasks (diff) | |
download | monero-dce47d52af614b339c1068fe76d815cb93faca30.tar.xz |
Merge pull request #1291
64094e5 adding thread_group for managing async tasks (Lee Clagett)
-rw-r--r-- | src/common/CMakeLists.txt | 6 | ||||
-rw-r--r-- | src/common/thread_group.cpp | 164 | ||||
-rw-r--r-- | src/common/thread_group.h | 133 | ||||
-rw-r--r-- | src/ringct/rctSigs.cpp | 148 |
4 files changed, 351 insertions, 100 deletions
diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 4b6149cbd..d5d22bca6 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -32,7 +32,8 @@ set(common_sources dns_utils.cpp util.cpp i18n.cpp - perf_timer.cpp) + perf_timer.cpp + thread_group.cpp) if (STACK_TRACE) list(APPEND common_sources stack_trace.cpp) @@ -55,7 +56,8 @@ set(common_private_headers varint.h i18n.h perf_timer.h - stack_trace.h) + stack_trace.h + thread_group.h) monero_private_headers(common ${common_private_headers}) diff --git a/src/common/thread_group.cpp b/src/common/thread_group.cpp new file mode 100644 index 000000000..ece268ab7 --- /dev/null +++ b/src/common/thread_group.cpp @@ -0,0 +1,164 @@ +// Copyright (c) 2014-2016, The Monero Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#include "common/thread_group.h" + +#include <cassert> +#include <limits> +#include <stdexcept> + +#include "common/util.h" + +namespace tools +{ +thread_group::thread_group(std::size_t count) : internal() { + static_assert( + std::numeric_limits<unsigned>::max() <= std::numeric_limits<std::size_t>::max(), + "unexpected truncation" + ); + count = std::min<std::size_t>(count, get_max_concurrency()); + count = count ? count - 1 : 0; + + if (count) { + internal.emplace(count); + } +} + +thread_group::data::data(std::size_t count) + : threads() + , head{nullptr} + , last(std::addressof(head)) + , pending(count) + , mutex() + , has_work() + , finished_work() + , stop(false) { + threads.reserve(count); + while (count--) { + threads.push_back(std::thread(&thread_group::data::run, this)); + } +} + +thread_group::data::~data() noexcept { + { + const std::unique_lock<std::mutex> lock(mutex); + stop = true; + } + has_work.notify_all(); + finished_work.notify_all(); + for (auto& worker : threads) { + try { + worker.join(); + } + catch(...) {} + } +} + + +void thread_group::data::sync() noexcept { + /* This function and `run()` can both throw when acquiring the lock, or in + the dispatched function. It is tough to recover from either, particularly the + lock case. These functions are marked as noexcept so that if either call + throws, the entire process is terminated. Users of the `dispatch` call are + expected to make their functions noexcept, or use std::packaged_task to copy + exceptions so that the process will continue in all but the most pessimistic + cases (std::bad_alloc). This was the existing behavior; + `asio::io_service::run` propogates errors from dispatched calls, and uncaught + exceptions on threads result in process termination. */ + assert(!threads.empty()); + bool not_first = false; + while (true) { + std::unique_ptr<work> next = nullptr; + { + std::unique_lock<std::mutex> lock(mutex); + pending -= std::size_t(not_first); + not_first = true; + finished_work.notify_all(); + + if (stop) { + return; + } + + next = get_next(); + if (next == nullptr) { + finished_work.wait(lock, [this] { return pending == 0 || stop; }); + return; + } + } + assert(next->f); + next->f(); + } +} + +std::unique_ptr<thread_group::data::work> thread_group::data::get_next() noexcept { + std::unique_ptr<work> rc = std::move(head.ptr); + if (rc != nullptr) { + head.ptr = std::move(rc->next.ptr); + if (head.ptr == nullptr) { + last = std::addressof(head); + } + } + return rc; +} + +void thread_group::data::run() noexcept { + // see `sync()` source for additional information + while (true) { + std::unique_ptr<work> next = nullptr; + { + std::unique_lock<std::mutex> lock(mutex); + --pending; + finished_work.notify_all(); + has_work.wait(lock, [this] { return head.ptr != nullptr || stop; }); + if (stop) { + return; + } + next = get_next(); + } + assert(next != nullptr); + assert(next->f); + next->f(); + } +} + +void thread_group::data::dispatch(std::function<void()> f) { + std::unique_ptr<work> latest(new work{std::move(f), node{nullptr}}); + node* const latest_node = std::addressof(latest->next); + { + const std::unique_lock<std::mutex> lock(mutex); + assert(last != nullptr); + assert(last->next == nullptr); + if (pending == std::numeric_limits<std::size_t>::max()) { + throw std::overflow_error("thread_group exceeded max queue depth"); + } + last->ptr = std::move(latest); + last = latest_node; + ++pending; + } + has_work.notify_one(); +} +} diff --git a/src/common/thread_group.h b/src/common/thread_group.h new file mode 100644 index 000000000..d8461d49a --- /dev/null +++ b/src/common/thread_group.h @@ -0,0 +1,133 @@ +// Copyright (c) 2014-2016, The Monero Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#include <boost/optional/optional.hpp> +#include <condition_variable> +#include <cstddef> +#include <functional> +#include <thread> +#include <utility> +#include <vector> + +namespace tools +{ +//! Manages zero or more threads for work dispatching +class thread_group +{ +public: + //! Create `min(count, get_max_concurrency()) - 1` threads + explicit thread_group(std::size_t count); + + thread_group(thread_group const&) = delete; + thread_group(thread_group&&) = delete; + + //! Joins threads, but does not necessarily run all dispatched functions. + ~thread_group() = default; + + thread_group& operator=(thread_group const&) = delete; + thread_group& operator=(thread_group&&) = delete; + + /*! Blocks until all functions provided to `dispatch` complete. Does not + destroy threads. If a dispatched function calls `this->dispatch(...)`, + `this->sync()` will continue to block until that new function completes. */ + void sync() noexcept { + if (internal) { + internal->sync(); + } + } + + /*! Example usage: + std::unique_ptr<thread_group, thread_group::lazy_sync> sync(std::addressof(group)); + which guarantees synchronization before the unique_ptr destructor returns. */ + struct lazy_sync { + void operator()(thread_group* group) const noexcept { + if (group != nullptr) { + group->sync(); + } + } + }; + + /*! `f` is invoked immediately if the thread_group is empty, otherwise + execution of `f` is queued for next available thread. If `f` is queued, any + exception leaving that function will result in process termination. Use + std::packaged_task if exceptions need to be handled. */ + template<typename F> + void dispatch(F&& f) { + if (internal) { + internal->dispatch(std::forward<F>(f)); + } + else { + f(); + } + } + +private: + class data { + public: + data(std::size_t count); + ~data() noexcept; + + void sync() noexcept; + + void dispatch(std::function<void()> f); + + private: + struct work; + + struct node { + node() = delete; + std::unique_ptr<work> ptr; + }; + + struct work { + work() = delete; + std::function<void()> f; + node next; + }; + + //! Requires lock on `mutex`. + std::unique_ptr<work> get_next() noexcept; + + //! Blocks until destructor is invoked, only call from thread. + void run() noexcept; + + private: + std::vector<std::thread> threads; + node head; + node* last; + std::size_t pending; + std::condition_variable has_work; + std::condition_variable finished_work; + std::mutex mutex; + bool stop; + }; + +private: + // optionally construct elements, without separate heap allocation + boost::optional<data> internal; +}; +} diff --git a/src/ringct/rctSigs.cpp b/src/ringct/rctSigs.cpp index 19e9d291e..df33c26b2 100644 --- a/src/ringct/rctSigs.cpp +++ b/src/ringct/rctSigs.cpp @@ -28,9 +28,9 @@ // STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF // THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#include <boost/asio.hpp> #include "misc_log_ex.h" #include "common/perf_timer.h" +#include "common/thread_group.h" #include "common/util.h" #include "rctSigs.h" #include "cryptonote_core/cryptonote_format_utils.h" @@ -38,17 +38,22 @@ using namespace crypto; using namespace std; -#define KILL_IOSERVICE() \ - if(ioservice_active) \ - { \ - work.reset(); \ - while (!ioservice.stopped()) ioservice.poll(); \ - threadpool.join_all(); \ - ioservice.stop(); \ - ioservice_active = false; \ +namespace rct { + namespace { + struct verRangeWrapper_ { + void operator()(const key & C, const rangeSig & as, bool &result) const { + result = verRange(C, as); } + }; + constexpr const verRangeWrapper_ verRangeWrapper{}; -namespace rct { + struct verRctMGSimpleWrapper_ { + void operator()(const key &message, const mgSig &mg, const ctkeyV & pubs, const key & C, bool &result) const { + result = verRctMGSimple(message, mg, pubs, C); + } + }; + constexpr const verRctMGSimpleWrapper_ verRctMGSimpleWrapper{}; + } //Schnorr Non-linkable //Gen Gives a signature (L1, s1, s2) proving that the sender knows "x" such that xG = one of P1 or P2 @@ -360,10 +365,6 @@ namespace rct { return true; } - void verRangeWrapper(const key & C, const rangeSig & as, bool &result) { - result = verRange(C, as); - } - key get_pre_mlsag_hash(const rctSig &rv) { keyV hashes; @@ -544,9 +545,6 @@ namespace rct { return MLSAG_Ver(message, M, mg, rows); } - void verRctMGSimpleWrapper(const key &message, const mgSig &mg, const ctkeyV & pubs, const key & C, bool &result) { - result = verRctMGSimple(message, mg, pubs, C); - } //These functions get keys from blockchain //replace these when connecting blockchain @@ -767,38 +765,20 @@ namespace rct { // some rct ops can throw try { - boost::asio::io_service ioservice; - boost::thread_group threadpool; - std::unique_ptr<boost::asio::io_service::work> work(new boost::asio::io_service::work(ioservice)); - size_t threads = tools::get_max_concurrency(); - threads = std::min(threads, rv.outPk.size()); - for (size_t i = 0; i < threads; ++i) - threadpool.create_thread(boost::bind(&boost::asio::io_service::run, &ioservice)); - bool ioservice_active = true; std::deque<bool> results(rv.outPk.size(), false); - epee::misc_utils::auto_scope_leave_caller ioservice_killer = epee::misc_utils::create_scope_leave_handler([&]() { KILL_IOSERVICE(); }); + tools::thread_group threadpool(rv.outPk.size()); // this must destruct before results DP("range proofs verified?"); for (size_t i = 0; i < rv.outPk.size(); i++) { - if (threads > 1) { - ioservice.dispatch(boost::bind(&verRangeWrapper, std::cref(rv.outPk[i].mask), std::cref(rv.p.rangeSigs[i]), std::ref(results[i]))); - } - else { - bool tmp = verRange(rv.outPk[i].mask, rv.p.rangeSigs[i]); - DP(tmp); - if (!tmp) { - LOG_ERROR("Range proof verification failed for input " << i); - return false; - } - } + threadpool.dispatch( + std::bind(verRangeWrapper, std::cref(rv.outPk[i].mask), std::cref(rv.p.rangeSigs[i]), std::ref(results[i])) + ); } - KILL_IOSERVICE(); - if (threads > 1) { - for (size_t i = 0; i < rv.outPk.size(); ++i) { - if (!results[i]) { - LOG_ERROR("Range proof verified failed for input " << i); - return false; - } + threadpool.sync(); + for (size_t i = 0; i < rv.outPk.size(); ++i) { + if (!results[i]) { + LOG_ERROR("Range proof verified failed for input " << i); + return false; } } @@ -832,34 +812,23 @@ namespace rct { CHECK_AND_ASSERT_MES(rv.pseudoOuts.size() == rv.p.MGs.size(), false, "Mismatched sizes of rv.pseudoOuts and rv.p.MGs"); CHECK_AND_ASSERT_MES(rv.pseudoOuts.size() == rv.mixRing.size(), false, "Mismatched sizes of rv.pseudoOuts and mixRing"); + const size_t threads = std::max(rv.outPk.size(), rv.mixRing.size()); + tools::thread_group threadpool(threads); { - boost::asio::io_service ioservice; - boost::thread_group threadpool; - std::unique_ptr<boost::asio::io_service::work> work(new boost::asio::io_service::work(ioservice)); - size_t threads = tools::get_max_concurrency(); - threads = std::min(threads, rv.outPk.size()); - for (size_t i = 0; i < threads; ++i) - threadpool.create_thread(boost::bind(&boost::asio::io_service::run, &ioservice)); - bool ioservice_active = true; std::deque<bool> results(rv.outPk.size(), false); - epee::misc_utils::auto_scope_leave_caller ioservice_killer = epee::misc_utils::create_scope_leave_handler([&]() { KILL_IOSERVICE(); }); - - for (i = 0; i < rv.outPk.size(); i++) { - if (threads > 1) { - ioservice.dispatch(boost::bind(&verRangeWrapper, std::cref(rv.outPk[i].mask), std::cref(rv.p.rangeSigs[i]), std::ref(results[i]))); - } - else if (!verRange(rv.outPk[i].mask, rv.p.rangeSigs[i])) { - LOG_ERROR("Range proof verified failed for input " << i); - return false; + { + const std::unique_ptr<tools::thread_group, tools::thread_group::lazy_sync> + sync(std::addressof(threadpool)); + for (i = 0; i < rv.outPk.size(); i++) { + threadpool.dispatch( + std::bind(verRangeWrapper, std::cref(rv.outPk[i].mask), std::cref(rv.p.rangeSigs[i]), std::ref(results[i])) + ); } - } - KILL_IOSERVICE(); - if (threads > 1) { - for (size_t i = 0; i < rv.outPk.size(); ++i) { - if (!results[i]) { - LOG_ERROR("Range proof verified failed for input " << i); - return false; - } + } // threadpool.sync(); + for (size_t i = 0; i < rv.outPk.size(); ++i) { + if (!results[i]) { + LOG_ERROR("Range proof verified failed for input " << i); + return false; } } } @@ -875,37 +844,20 @@ namespace rct { key message = get_pre_mlsag_hash(rv); { - boost::asio::io_service ioservice; - boost::thread_group threadpool; - std::unique_ptr<boost::asio::io_service::work> work(new boost::asio::io_service::work(ioservice)); - size_t threads = tools::get_max_concurrency(); - threads = std::min(threads, rv.mixRing.size()); - for (size_t i = 0; i < threads; ++i) - threadpool.create_thread(boost::bind(&boost::asio::io_service::run, &ioservice)); - bool ioservice_active = true; std::deque<bool> results(rv.mixRing.size(), false); - epee::misc_utils::auto_scope_leave_caller ioservice_killer = epee::misc_utils::create_scope_leave_handler([&]() { KILL_IOSERVICE(); }); - - for (i = 0 ; i < rv.mixRing.size() ; i++) { - if (threads > 1) { - ioservice.dispatch(boost::bind(&verRctMGSimpleWrapper, std::cref(message), std::cref(rv.p.MGs[i]), std::cref(rv.mixRing[i]), std::cref(rv.pseudoOuts[i]), std::ref(results[i]))); + { + const std::unique_ptr<tools::thread_group, tools::thread_group::lazy_sync> + sync(std::addressof(threadpool)); + for (i = 0 ; i < rv.mixRing.size() ; i++) { + threadpool.dispatch( + std::bind(verRctMGSimpleWrapper, std::cref(message), std::cref(rv.p.MGs[i]), std::cref(rv.mixRing[i]), std::cref(rv.pseudoOuts[i]), std::ref(results[i])) + ); } - else { - bool tmpb = verRctMGSimple(message, rv.p.MGs[i], rv.mixRing[i], rv.pseudoOuts[i]); - DP(tmpb); - if (!tmpb) { - LOG_ERROR("verRctMGSimple failed for input " << i); - return false; - } - } - } - KILL_IOSERVICE(); - if (threads > 1) { - for (size_t i = 0; i < results.size(); ++i) { - if (!results[i]) { - LOG_ERROR("verRctMGSimple failed for input " << i); - return false; - } + } // threadpool.sync(); + for (size_t i = 0; i < results.size(); ++i) { + if (!results[i]) { + LOG_ERROR("verRctMGSimple failed for input " << i); + return false; } } } |