diff options
Diffstat (limited to 'tests/unit_tests')
-rw-r--r-- | tests/unit_tests/CMakeLists.txt | 1 | ||||
-rw-r--r-- | tests/unit_tests/ban.cpp | 10 | ||||
-rw-r--r-- | tests/unit_tests/device.cpp | 131 | ||||
-rw-r--r-- | tests/unit_tests/epee_levin_protocol_handler_async.cpp | 1 | ||||
-rw-r--r-- | tests/unit_tests/hardfork.cpp | 40 | ||||
-rw-r--r-- | tests/unit_tests/threadpool.cpp | 57 |
6 files changed, 211 insertions, 29 deletions
diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index 6d79ba74b..3105eccfa 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -41,6 +41,7 @@ set(unit_tests_sources command_line.cpp crypto.cpp decompose_amount_into_digits.cpp + device.cpp dns_resolver.cpp epee_boosted_tcp_server.cpp epee_levin_protocol_handler_async.cpp diff --git a/tests/unit_tests/ban.cpp b/tests/unit_tests/ban.cpp index 15bc0bce3..e3dbdaef1 100644 --- a/tests/unit_tests/ban.cpp +++ b/tests/unit_tests/ban.cpp @@ -55,7 +55,7 @@ public: bool have_block(const crypto::hash& id) const {return true;} void get_blockchain_top(uint64_t& height, crypto::hash& top_id)const{height=0;top_id=crypto::null_hash;} bool handle_incoming_tx(const cryptonote::blobdata& tx_blob, cryptonote::tx_verification_context& tvc, bool keeped_by_block, bool relayed, bool do_not_relay) { return true; } - bool handle_incoming_txs(const std::list<cryptonote::blobdata>& tx_blob, std::vector<cryptonote::tx_verification_context>& tvc, bool keeped_by_block, bool relayed, bool do_not_relay) { return true; } + bool handle_incoming_txs(const std::vector<cryptonote::blobdata>& tx_blob, std::vector<cryptonote::tx_verification_context>& tvc, bool keeped_by_block, bool relayed, bool do_not_relay) { return true; } bool handle_incoming_block(const cryptonote::blobdata& block_blob, cryptonote::block_verification_context& bvc, bool update_miner_blocktemplate = true) { return true; } void pause_mine(){} void resume_mine(){} @@ -65,7 +65,7 @@ public: cryptonote::blockchain_storage &get_blockchain_storage() { throw std::runtime_error("Called invalid member function: please never call get_blockchain_storage on the TESTING class test_core."); } bool get_test_drop_download() const {return true;} bool get_test_drop_download_height() const {return true;} - bool prepare_handle_incoming_blocks(const std::list<cryptonote::block_complete_entry> &blocks) { return true; } + bool prepare_handle_incoming_blocks(const std::vector<cryptonote::block_complete_entry> &blocks) { return true; } bool cleanup_handle_incoming_blocks(bool force_sync = false) { return true; } uint64_t get_target_blockchain_height() const { return 1; } size_t get_block_sync_size(uint64_t height) const { return BLOCKS_SYNCHRONIZING_DEFAULT_COUNT; } @@ -73,8 +73,8 @@ public: cryptonote::network_type get_nettype() const { return cryptonote::MAINNET; } bool get_pool_transaction(const crypto::hash& id, cryptonote::blobdata& tx_blob) const { return false; } bool pool_has_tx(const crypto::hash &txid) const { return false; } - bool get_blocks(uint64_t start_offset, size_t count, std::list<std::pair<cryptonote::blobdata, cryptonote::block>>& blocks, std::list<cryptonote::blobdata>& txs) const { return false; } - bool get_transactions(const std::vector<crypto::hash>& txs_ids, std::list<cryptonote::transaction>& txs, std::list<crypto::hash>& missed_txs) const { return false; } + bool get_blocks(uint64_t start_offset, size_t count, std::vector<std::pair<cryptonote::blobdata, cryptonote::block>>& blocks, std::vector<cryptonote::blobdata>& txs) const { return false; } + bool get_transactions(const std::vector<crypto::hash>& txs_ids, std::vector<cryptonote::transaction>& txs, std::vector<crypto::hash>& missed_txs) const { return false; } bool get_block_by_hash(const crypto::hash &h, cryptonote::block &blk, bool *orphan = NULL) const { return false; } uint8_t get_ideal_hard_fork_version() const { return 0; } uint8_t get_ideal_hard_fork_version(uint64_t height) const { return 0; } @@ -82,7 +82,7 @@ public: uint64_t get_earliest_ideal_height_for_version(uint8_t version) const { return 0; } cryptonote::difficulty_type get_block_cumulative_difficulty(uint64_t height) const { return 0; } bool fluffy_blocks_enabled() const { return false; } - uint64_t prevalidate_block_hashes(uint64_t height, const std::list<crypto::hash> &hashes) { return 0; } + uint64_t prevalidate_block_hashes(uint64_t height, const std::vector<crypto::hash> &hashes) { return 0; } void stop() {} }; diff --git a/tests/unit_tests/device.cpp b/tests/unit_tests/device.cpp new file mode 100644 index 000000000..50ccec9fa --- /dev/null +++ b/tests/unit_tests/device.cpp @@ -0,0 +1,131 @@ +// Copyright (c) 2018, The Monero Project +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, are +// permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list +// of conditions and the following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be +// used to endorse or promote products derived from this software without specific +// prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF +// THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "gtest/gtest.h" +#include "ringct/rctOps.h" +#include "device/device_default.hpp" + +TEST(device, name) +{ + hw::core::device_default dev; + ASSERT_TRUE(dev.set_name("test")); + ASSERT_EQ(dev.get_name(), "test"); +} + +/* +TEST(device, locking) +{ + hw::core::device_default dev; + ASSERT_TRUE(dev.try_lock()); + ASSERT_FALSE(dev.try_lock()); + dev.unlock(); + ASSERT_TRUE(dev.try_lock()); + dev.unlock(); + dev.lock(); + ASSERT_FALSE(dev.try_lock()); + dev.unlock(); + ASSERT_TRUE(dev.try_lock()); + dev.unlock(); +} +*/ + +TEST(device, open_close) +{ + hw::core::device_default dev; + crypto::secret_key key; + ASSERT_TRUE(dev.open_tx(key)); + ASSERT_TRUE(dev.close_tx()); +} + +TEST(device, ops) +{ + hw::core::device_default dev; + rct::key resd, res; + crypto::key_derivation derd, der; + rct::key sk, pk; + crypto::secret_key sk0, sk1; + crypto::public_key pk0, pk1; + crypto::ec_scalar ressc0, ressc1; + crypto::key_image ki0, ki1; + + rct::skpkGen(sk, pk); + rct::scalarmultBase((rct::key&)pk0, (rct::key&)sk0); + rct::scalarmultBase((rct::key&)pk1, (rct::key&)sk1); + + dev.scalarmultKey(resd, pk, sk); + rct::scalarmultKey(res, pk, sk); + ASSERT_EQ(resd, res); + + dev.scalarmultBase(resd, sk); + rct::scalarmultBase(res, sk); + ASSERT_EQ(resd, res); + + dev.sc_secret_add((crypto::secret_key&)resd, sk0, sk1); + sc_add((unsigned char*)&res, (unsigned char*)&sk0, (unsigned char*)&sk1); + ASSERT_EQ(resd, res); + + dev.generate_key_derivation(pk0, sk0, derd); + crypto::generate_key_derivation(pk0, sk0, der); + ASSERT_FALSE(memcmp(&derd, &der, sizeof(der))); + + dev.derivation_to_scalar(der, 0, ressc0); + crypto::derivation_to_scalar(der, 0, ressc1); + ASSERT_FALSE(memcmp(&ressc0, &ressc1, sizeof(ressc1))); + + dev.derive_secret_key(der, 0, rct::rct2sk(sk), sk0); + crypto::derive_secret_key(der, 0, rct::rct2sk(sk), sk1); + ASSERT_EQ(sk0, sk1); + + dev.derive_public_key(der, 0, rct::rct2pk(pk), pk0); + crypto::derive_public_key(der, 0, rct::rct2pk(pk), pk1); + ASSERT_EQ(pk0, pk1); + + dev.secret_key_to_public_key(rct::rct2sk(sk), pk0); + crypto::secret_key_to_public_key(rct::rct2sk(sk), pk1); + ASSERT_EQ(pk0, pk1); + + dev.generate_key_image(pk0, sk0, ki0); + crypto::generate_key_image(pk0, sk0, ki1); + ASSERT_EQ(ki0, ki1); +} + +TEST(device, ecdh) +{ + hw::core::device_default dev; + rct::ecdhTuple tuple, tuple2; + rct::key key = rct::skGen(); + tuple.mask = rct::skGen(); + tuple.amount = rct::skGen(); + tuple.senderPk = rct::pkGen(); + tuple2 = tuple; + dev.ecdhEncode(tuple, key); + dev.ecdhDecode(tuple, key); + ASSERT_EQ(tuple2.mask, tuple.mask); + ASSERT_EQ(tuple2.amount, tuple.amount); + ASSERT_EQ(tuple2.senderPk, tuple.senderPk); +} diff --git a/tests/unit_tests/epee_levin_protocol_handler_async.cpp b/tests/unit_tests/epee_levin_protocol_handler_async.cpp index 38a8360d7..72d8f3205 100644 --- a/tests/unit_tests/epee_levin_protocol_handler_async.cpp +++ b/tests/unit_tests/epee_levin_protocol_handler_async.cpp @@ -150,6 +150,7 @@ namespace } virtual bool close() { /*std::cout << "test_connection::close()" << std::endl; */return true; } + virtual bool send_done() { /*std::cout << "test_connection::send_done()" << std::endl; */return true; } virtual bool call_run_once_service_io() { std::cout << "test_connection::call_run_once_service_io()" << std::endl; return true; } virtual bool request_callback() { std::cout << "test_connection::request_callback()" << std::endl; return true; } virtual boost::asio::io_service& get_io_service() { std::cout << "test_connection::get_io_service()" << std::endl; return m_io_service; } diff --git a/tests/unit_tests/hardfork.cpp b/tests/unit_tests/hardfork.cpp index 913ebe84a..930aeb782 100644 --- a/tests/unit_tests/hardfork.cpp +++ b/tests/unit_tests/hardfork.cpp @@ -50,6 +50,7 @@ public: virtual void safesyncmode(const bool onoff) {} virtual void reset() {} virtual std::vector<std::string> get_filenames() const { return std::vector<std::string>(); } + virtual bool remove_data_file(const std::string& folder) const { return true; } virtual std::string get_db_name() const { return std::string(); } virtual bool lock() { return true; } virtual void unlock() { } @@ -69,6 +70,7 @@ public: virtual uint64_t get_block_height(const crypto::hash& h) const { return 0; } virtual block_header get_block_header(const crypto::hash& h) const { return block_header(); } virtual uint64_t get_block_timestamp(const uint64_t& height) const { return 0; } + virtual std::vector<uint64_t> get_block_cumulative_rct_outputs(const std::vector<uint64_t> &heights) const { return {}; } virtual uint64_t get_top_block_timestamp() const { return 0; } virtual size_t get_block_size(const uint64_t& height) const { return 128; } virtual difficulty_type get_block_cumulative_difficulty(const uint64_t& height) const { return 10; } @@ -124,6 +126,7 @@ public: virtual void remove_txpool_tx(const crypto::hash& txid) {} virtual bool get_txpool_tx_meta(const crypto::hash& txid, txpool_tx_meta_t &meta) const { return false; } virtual bool get_txpool_tx_blob(const crypto::hash& txid, cryptonote::blobdata &bd) const { return false; } + virtual uint64_t get_database_size() const { return 0; } virtual cryptonote::blobdata get_txpool_tx_blob(const crypto::hash& txid) const { return ""; } virtual bool for_all_txpool_txes(std::function<bool(const crypto::hash&, const txpool_tx_meta_t&, const cryptonote::blobdata*)>, bool include_blob = false, bool include_unrelayed_txes = false) const { return false; } @@ -131,6 +134,7 @@ public: , const size_t& block_size , const difficulty_type& cumulative_difficulty , const uint64_t& coins_generated + , uint64_t num_rct_outs , const crypto::hash& blk_hash ) { blocks.push_back(blk); @@ -183,20 +187,20 @@ TEST(major, Only) ASSERT_FALSE(hf.add(mkblock(0, 2), 0)); ASSERT_FALSE(hf.add(mkblock(2, 2), 0)); ASSERT_TRUE(hf.add(mkblock(1, 2), 0)); - db.add_block(mkblock(1, 1), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(1, 1), 0, 0, 0, 0, crypto::hash()); // block height 1, only version 1 is accepted ASSERT_FALSE(hf.add(mkblock(0, 2), 1)); ASSERT_FALSE(hf.add(mkblock(2, 2), 1)); ASSERT_TRUE(hf.add(mkblock(1, 2), 1)); - db.add_block(mkblock(1, 1), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(1, 1), 0, 0, 0, 0, crypto::hash()); // block height 2, only version 2 is accepted ASSERT_FALSE(hf.add(mkblock(0, 2), 2)); ASSERT_FALSE(hf.add(mkblock(1, 2), 2)); ASSERT_FALSE(hf.add(mkblock(3, 2), 2)); ASSERT_TRUE(hf.add(mkblock(2, 2), 2)); - db.add_block(mkblock(2, 1), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(2, 1), 0, 0, 0, 0, crypto::hash()); } TEST(empty_hardforks, Success) @@ -210,7 +214,7 @@ TEST(empty_hardforks, Success) ASSERT_TRUE(hf.get_state(time(NULL) + 3600*24*400) == HardFork::Ready); for (uint64_t h = 0; h <= 10; ++h) { - db.add_block(mkblock(hf, h, 1), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, 1), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE(hf.add(db.get_block_from_height(h), h)); } ASSERT_EQ(hf.get(0), 1); @@ -244,14 +248,14 @@ TEST(check_for_height, Success) for (uint64_t h = 0; h <= 4; ++h) { ASSERT_TRUE(hf.check_for_height(mkblock(1, 1), h)); ASSERT_FALSE(hf.check_for_height(mkblock(2, 2), h)); // block version is too high - db.add_block(mkblock(hf, h, 1), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, 1), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE(hf.add(db.get_block_from_height(h), h)); } for (uint64_t h = 5; h <= 10; ++h) { ASSERT_FALSE(hf.check_for_height(mkblock(1, 1), h)); // block version is too low ASSERT_TRUE(hf.check_for_height(mkblock(2, 2), h)); - db.add_block(mkblock(hf, h, 2), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, 2), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE(hf.add(db.get_block_from_height(h), h)); } } @@ -268,19 +272,19 @@ TEST(get, next_version) for (uint64_t h = 0; h <= 4; ++h) { ASSERT_EQ(2, hf.get_next_version()); - db.add_block(mkblock(hf, h, 1), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, 1), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE(hf.add(db.get_block_from_height(h), h)); } for (uint64_t h = 5; h <= 9; ++h) { ASSERT_EQ(4, hf.get_next_version()); - db.add_block(mkblock(hf, h, 2), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, 2), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE(hf.add(db.get_block_from_height(h), h)); } for (uint64_t h = 10; h <= 15; ++h) { ASSERT_EQ(4, hf.get_next_version()); - db.add_block(mkblock(hf, h, 4), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, 4), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE(hf.add(db.get_block_from_height(h), h)); } } @@ -321,7 +325,7 @@ TEST(steps_asap, Success) hf.init(); for (uint64_t h = 0; h < 10; ++h) { - db.add_block(mkblock(hf, h, 9), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, 9), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE(hf.add(db.get_block_from_height(h), h)); } @@ -348,7 +352,7 @@ TEST(steps_1, Success) hf.init(); for (uint64_t h = 0 ; h < 10; ++h) { - db.add_block(mkblock(hf, h, h+1), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, h+1), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE(hf.add(db.get_block_from_height(h), h)); } @@ -373,7 +377,7 @@ TEST(reorganize, Same) // index 0 1 2 3 4 5 6 7 8 9 static const uint8_t block_versions[] = { 1, 1, 4, 4, 7, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9 }; for (uint64_t h = 0; h < 20; ++h) { - db.add_block(mkblock(hf, h, block_versions[h]), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, block_versions[h]), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE(hf.add(db.get_block_from_height(h), h)); } @@ -404,7 +408,7 @@ TEST(reorganize, Changed) static const uint8_t block_versions[] = { 1, 1, 4, 4, 7, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9 }; static const uint8_t expected_versions[] = { 1, 1, 1, 1, 1, 1, 4, 4, 7, 7, 9, 9, 9, 9, 9, 9 }; for (uint64_t h = 0; h < 16; ++h) { - db.add_block(mkblock(hf, h, block_versions[h]), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, block_versions[h]), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE (hf.add(db.get_block_from_height(h), h)); } @@ -424,7 +428,7 @@ TEST(reorganize, Changed) ASSERT_EQ(db.height(), 3); hf.reorganize_from_block_height(2); for (uint64_t h = 3; h < 16; ++h) { - db.add_block(mkblock(hf, h, block_versions_new[h]), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, block_versions_new[h]), 0, 0, 0, 0, crypto::hash()); bool ret = hf.add(db.get_block_from_height(h), h); ASSERT_EQ (ret, h < 15); } @@ -448,7 +452,7 @@ TEST(voting, threshold) for (uint64_t h = 0; h <= 8; ++h) { uint8_t v = 1 + !!(h % 8); - db.add_block(mkblock(hf, h, v), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, v), 0, 0, 0, 0, crypto::hash()); bool ret = hf.add(db.get_block_from_height(h), h); if (h >= 8 && threshold == 87) { // for threshold 87, we reach the treshold at height 7, so from height 8, hard fork to version 2, but 8 tries to add 1 @@ -482,7 +486,7 @@ TEST(voting, different_thresholds) static const uint8_t expected_versions[] = { 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4 }; for (uint64_t h = 0; h < sizeof(block_versions) / sizeof(block_versions[0]); ++h) { - db.add_block(mkblock(hf, h, block_versions[h]), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, block_versions[h]), 0, 0, 0, 0, crypto::hash()); bool ret = hf.add(db.get_block_from_height(h), h); ASSERT_EQ(ret, true); } @@ -536,7 +540,7 @@ TEST(voting, info) ASSERT_EQ(expected_thresholds[h], threshold); ASSERT_EQ(4, voting); - db.add_block(mkblock(hf, h, block_versions[h]), 0, 0, 0, crypto::hash()); + db.add_block(mkblock(hf, h, block_versions[h]), 0, 0, 0, 0, crypto::hash()); ASSERT_TRUE(hf.add(db.get_block_from_height(h), h)); } } @@ -599,7 +603,7 @@ TEST(reorganize, changed) #define ADD(v, h, a) \ do { \ cryptonote::block b = mkblock(hf, h, v); \ - db.add_block(b, 0, 0, 0, crypto::hash()); \ + db.add_block(b, 0, 0, 0, 0, crypto::hash()); \ ASSERT_##a(hf.add(b, h)); \ } while(0) #define ADD_TRUE(v, h) ADD(v, h, TRUE) diff --git a/tests/unit_tests/threadpool.cpp b/tests/unit_tests/threadpool.cpp index 34be1417a..1307cd738 100644 --- a/tests/unit_tests/threadpool.cpp +++ b/tests/unit_tests/threadpool.cpp @@ -35,7 +35,7 @@ TEST(threadpool, wait_nothing) { std::shared_ptr<tools::threadpool> tpool(tools::threadpool::getNewForUnitTests()); tools::threadpool::waiter waiter; - waiter.wait(); + waiter.wait(tpool.get()); } TEST(threadpool, wait_waits) @@ -45,7 +45,7 @@ TEST(threadpool, wait_waits) std::atomic<bool> b(false); tpool->submit(&waiter, [&b](){ epee::misc_utils::sleep_no_w(1000); b = true; }); ASSERT_FALSE(b); - waiter.wait(); + waiter.wait(tpool.get()); ASSERT_TRUE(b); } @@ -59,7 +59,7 @@ TEST(threadpool, one_thread) { tpool->submit(&waiter, [&counter](){++counter;}); } - waiter.wait(); + waiter.wait(tpool.get()); ASSERT_EQ(counter, 4096); } @@ -73,7 +73,7 @@ TEST(threadpool, many_threads) { tpool->submit(&waiter, [&counter](){++counter;}); } - waiter.wait(); + waiter.wait(tpool.get()); ASSERT_EQ(counter, 4096); } @@ -85,7 +85,7 @@ static uint64_t fibonacci(std::shared_ptr<tools::threadpool> tpool, uint64_t n) tools::threadpool::waiter waiter; tpool->submit(&waiter, [&tpool, &f1, n](){ f1 = fibonacci(tpool, n-1); }); tpool->submit(&waiter, [&tpool, &f2, n](){ f2 = fibonacci(tpool, n-2); }); - waiter.wait(); + waiter.wait(tpool.get()); return f1 + f2; } @@ -95,7 +95,52 @@ TEST(threadpool, reentrency) tools::threadpool::waiter waiter; uint64_t f = fibonacci(tpool, 13); - waiter.wait(); + waiter.wait(tpool.get()); ASSERT_EQ(f, 233); } +TEST(threadpool, reentrancy) +{ + std::shared_ptr<tools::threadpool> tpool(tools::threadpool::getNewForUnitTests(4)); + tools::threadpool::waiter waiter; + + uint64_t f = fibonacci(tpool, 13); + waiter.wait(tpool.get()); + ASSERT_EQ(f, 233); +} + +TEST(threadpool, leaf_throws) +{ + std::shared_ptr<tools::threadpool> tpool(tools::threadpool::getNewForUnitTests()); + tools::threadpool::waiter waiter; + + bool thrown = false, executed = false; + tpool->submit(&waiter, [&](){ + try { tpool->submit(&waiter, [&](){ executed = true; }); } + catch(const std::exception &e) { thrown = true; } + }, true); + waiter.wait(tpool.get()); + ASSERT_TRUE(thrown); + ASSERT_FALSE(executed); +} + +TEST(threadpool, leaf_reentrancy) +{ + std::shared_ptr<tools::threadpool> tpool(tools::threadpool::getNewForUnitTests(4)); + tools::threadpool::waiter waiter; + + std::atomic<int> counter(0); + for (int i = 0; i < 1000; ++i) + { + tpool->submit(&waiter, [&](){ + tools::threadpool::waiter waiter; + for (int j = 0; j < 500; ++j) + { + tpool->submit(&waiter, [&](){ ++counter; }, true); + } + waiter.wait(tpool.get()); + }); + } + waiter.wait(tpool.get()); + ASSERT_EQ(counter, 500000); +} |