diff options
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rw-r--r-- | contrib/epee/include/net/levin_protocol_handler_async.h | 2 | ||||
-rw-r--r-- | src/device_trezor/device_trezor_base.cpp | 29 | ||||
-rw-r--r-- | src/device_trezor/device_trezor_base.hpp | 2 | ||||
-rw-r--r-- | tests/unit_tests/epee_boosted_tcp_server.cpp | 75 |
5 files changed, 85 insertions, 25 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 4f43f9d1a..49ac18c66 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -522,7 +522,7 @@ ExternalProject_Add(generate_translations_header BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/translations" STAMP_DIR ${LRELEASE_PATH} CMAKE_ARGS -DLRELEASE_PATH=${LRELEASE_PATH} - INSTALL_COMMAND cmake -E echo "") + INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "") include_directories("${CMAKE_CURRENT_BINARY_DIR}/translations") add_subdirectory(external) diff --git a/contrib/epee/include/net/levin_protocol_handler_async.h b/contrib/epee/include/net/levin_protocol_handler_async.h index 635876589..f6b73a2d5 100644 --- a/contrib/epee/include/net/levin_protocol_handler_async.h +++ b/contrib/epee/include/net/levin_protocol_handler_async.h @@ -787,7 +787,7 @@ void async_protocol_handler_config<t_connection_context>::delete_connections(siz { auto i = connections.end() - 1; async_protocol_handler<t_connection_context> *conn = m_connects.at(*i); - del_connection(conn); + m_connects.erase(*i); conn->close(); connections.erase(i); } diff --git a/src/device_trezor/device_trezor_base.cpp b/src/device_trezor/device_trezor_base.cpp index f59be1573..70dc7f539 100644 --- a/src/device_trezor/device_trezor_base.cpp +++ b/src/device_trezor/device_trezor_base.cpp @@ -365,15 +365,14 @@ namespace trezor { void device_trezor_base::device_state_initialize_unsafe() { require_connected(); - std::string tmp_session_id; auto initMsg = std::make_shared<messages::management::Initialize>(); const auto data_cleaner = epee::misc_utils::create_scope_leave_handler([&]() { - memwipe(&tmp_session_id[0], tmp_session_id.size()); + if (initMsg->has_session_id()) + memwipe(&(*initMsg->mutable_session_id())[0], initMsg->mutable_session_id()->size()); }); if(!m_device_session_id.empty()) { - tmp_session_id.assign(m_device_session_id.data(), m_device_session_id.size()); - initMsg->set_allocated_session_id(&tmp_session_id); + initMsg->set_allocated_session_id(new std::string(m_device_session_id.data(), m_device_session_id.size())); } m_features = this->client_exchange<messages::management::Features>(initMsg); @@ -382,8 +381,6 @@ namespace trezor { } else { m_device_session_id.clear(); } - - initMsg->release_session_id(); } void device_trezor_base::device_state_reset() @@ -453,18 +450,14 @@ namespace trezor { pin = m_pin; } - std::string pin_field; messages::common::PinMatrixAck m; if (pin) { - pin_field.assign(pin->data(), pin->size()); - m.set_allocated_pin(&pin_field); + m.set_allocated_pin(new std::string(pin->data(), pin->size())); } const auto data_cleaner = epee::misc_utils::create_scope_leave_handler([&]() { - m.release_pin(); - if (!pin_field.empty()){ - memwipe(&pin_field[0], pin_field.size()); - } + if (m.has_pin()) + memwipe(&(*m.mutable_pin())[0], m.mutable_pin()->size()); }); resp = call_raw(&m); @@ -499,7 +492,6 @@ namespace trezor { boost::optional<epee::wipeable_string> passphrase; TREZOR_CALLBACK_GET(passphrase, on_passphrase_request, on_device); - std::string passphrase_field; messages::common::PassphraseAck m; m.set_on_device(on_device); if (!on_device) { @@ -512,16 +504,13 @@ namespace trezor { } if (passphrase) { - passphrase_field.assign(passphrase->data(), passphrase->size()); - m.set_allocated_passphrase(&passphrase_field); + m.set_allocated_passphrase(new std::string(passphrase->data(), passphrase->size())); } } const auto data_cleaner = epee::misc_utils::create_scope_leave_handler([&]() { - m.release_passphrase(); - if (!passphrase_field.empty()){ - memwipe(&passphrase_field[0], passphrase_field.size()); - } + if (m.has_passphrase()) + memwipe(&(m.mutable_passphrase())[0], m.mutable_passphrase()->size()); }); resp = call_raw(&m); diff --git a/src/device_trezor/device_trezor_base.hpp b/src/device_trezor/device_trezor_base.hpp index 4db8f0c8e..0162b23df 100644 --- a/src/device_trezor/device_trezor_base.hpp +++ b/src/device_trezor/device_trezor_base.hpp @@ -165,7 +165,7 @@ namespace trezor { // Scoped session closer BOOST_SCOPE_EXIT_ALL(&, this) { - if (open_session){ + if (open_session && this->get_transport()){ this->get_transport()->close(); } }; diff --git a/tests/unit_tests/epee_boosted_tcp_server.cpp b/tests/unit_tests/epee_boosted_tcp_server.cpp index e99666eb1..06e076a3a 100644 --- a/tests/unit_tests/epee_boosted_tcp_server.cpp +++ b/tests/unit_tests/epee_boosted_tcp_server.cpp @@ -143,14 +143,24 @@ TEST(test_epee_connection, test_lifetime) static constexpr bool handshake_complete() noexcept { return true; } }; + using functional_obj_t = std::function<void ()>; struct command_handler_t: epee::levin::levin_commands_handler<context_t> { size_t delay; - command_handler_t(size_t delay = 0): delay(delay) {} + functional_obj_t on_connection_close_f; + command_handler_t(size_t delay = 0, + functional_obj_t on_connection_close_f = nullptr + ): + delay(delay), + on_connection_close_f(on_connection_close_f) + {} virtual int invoke(int, const epee::span<const uint8_t>, epee::byte_slice&, context_t&) override { epee::misc_utils::sleep_no_w(delay); return {}; } virtual int notify(int, const epee::span<const uint8_t>, context_t&) override { return {}; } virtual void callback(context_t&) override {} virtual void on_connection_new(context_t&) override {} - virtual void on_connection_close(context_t&) override {} + virtual void on_connection_close(context_t&) override { + if (on_connection_close_f) + on_connection_close_f(); + } virtual ~command_handler_t() override {} static void destroy(epee::levin::levin_commands_handler<context_t>* ptr) { delete ptr; } }; @@ -168,6 +178,14 @@ TEST(test_epee_connection, test_lifetime) using work_ptr = std::shared_ptr<work_t>; using workers_t = std::vector<std::thread>; using server_t = epee::net_utils::boosted_tcp_server<handler_t>; + using lock_t = std::mutex; + using lock_guard_t = std::lock_guard<lock_t>; + using connection_weak_ptr = boost::weak_ptr<connection_t>; + struct shared_conn_t { + lock_t lock; + connection_weak_ptr conn; + }; + using shared_conn_ptr = std::shared_ptr<shared_conn_t>; io_context_t io_context; work_ptr work(std::make_shared<work_t>(io_context)); @@ -255,6 +273,7 @@ TEST(test_epee_connection, test_lifetime) ASSERT_TRUE(index == N * N); ASSERT_TRUE(shared_state->get_connections_count() == 0); + while (shared_state->sock_count); ASSERT_TRUE(shared_state->get_connections_count() == 0); constexpr auto DELAY = 30; constexpr auto TIMEOUT = 1; @@ -273,6 +292,58 @@ TEST(test_epee_connection, test_lifetime) shared_state->close(tag); ASSERT_TRUE(shared_state->get_connections_count() == 0); } + + while (shared_state->sock_count); + constexpr auto ZERO_DELAY = 0; + size_t counter = 0; + shared_state->set_handler(new command_handler_t(ZERO_DELAY, + [&counter]{ + ASSERT_TRUE(counter++ == 0); + } + ), + &command_handler_t::destroy + ); + connection_ptr conn(new connection_t(io_context, shared_state, {}, {})); + conn->socket().connect(endpoint); + conn->start({}, {}); + ASSERT_TRUE(shared_state->get_connections_count() == 1); + shared_state->del_out_connections(1); + ASSERT_TRUE(shared_state->get_connections_count() == 0); + conn.reset(); + + while (shared_state->sock_count); + shared_conn_ptr shared_conn(std::make_shared<shared_conn_t>()); + shared_state->set_handler(new command_handler_t(ZERO_DELAY, + [shared_state, shared_conn]{ + { + connection_ptr conn; + { + lock_guard_t guard(shared_conn->lock); + conn = std::move(shared_conn->conn.lock()); + } + if (conn) + conn->cancel(); + } + const auto success = shared_state->foreach_connection([](context_t&){ + return true; + }); + ASSERT_TRUE(success); + } + ), + &command_handler_t::destroy + ); + for (auto i = 0; i < N; ++i) { + { + connection_ptr conn(new connection_t(io_context, shared_state, {}, {})); + conn->socket().connect(endpoint); + conn->start({}, {}); + lock_guard_t guard(shared_conn->lock); + shared_conn->conn = conn; + } + ASSERT_TRUE(shared_state->get_connections_count() == 1); + shared_state->del_out_connections(1); + ASSERT_TRUE(shared_state->get_connections_count() == 0); + } }); for (auto& w: workers) { |