aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt2
-rw-r--r--contrib/epee/include/net/levin_protocol_handler_async.h2
-rw-r--r--src/device_trezor/device_trezor_base.cpp29
-rw-r--r--src/device_trezor/device_trezor_base.hpp2
-rw-r--r--tests/unit_tests/epee_boosted_tcp_server.cpp75
5 files changed, 85 insertions, 25 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4f43f9d1a..49ac18c66 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -522,7 +522,7 @@ ExternalProject_Add(generate_translations_header
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/translations"
STAMP_DIR ${LRELEASE_PATH}
CMAKE_ARGS -DLRELEASE_PATH=${LRELEASE_PATH}
- INSTALL_COMMAND cmake -E echo "")
+ INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/translations")
add_subdirectory(external)
diff --git a/contrib/epee/include/net/levin_protocol_handler_async.h b/contrib/epee/include/net/levin_protocol_handler_async.h
index 635876589..f6b73a2d5 100644
--- a/contrib/epee/include/net/levin_protocol_handler_async.h
+++ b/contrib/epee/include/net/levin_protocol_handler_async.h
@@ -787,7 +787,7 @@ void async_protocol_handler_config<t_connection_context>::delete_connections(siz
{
auto i = connections.end() - 1;
async_protocol_handler<t_connection_context> *conn = m_connects.at(*i);
- del_connection(conn);
+ m_connects.erase(*i);
conn->close();
connections.erase(i);
}
diff --git a/src/device_trezor/device_trezor_base.cpp b/src/device_trezor/device_trezor_base.cpp
index f59be1573..70dc7f539 100644
--- a/src/device_trezor/device_trezor_base.cpp
+++ b/src/device_trezor/device_trezor_base.cpp
@@ -365,15 +365,14 @@ namespace trezor {
void device_trezor_base::device_state_initialize_unsafe()
{
require_connected();
- std::string tmp_session_id;
auto initMsg = std::make_shared<messages::management::Initialize>();
const auto data_cleaner = epee::misc_utils::create_scope_leave_handler([&]() {
- memwipe(&tmp_session_id[0], tmp_session_id.size());
+ if (initMsg->has_session_id())
+ memwipe(&(*initMsg->mutable_session_id())[0], initMsg->mutable_session_id()->size());
});
if(!m_device_session_id.empty()) {
- tmp_session_id.assign(m_device_session_id.data(), m_device_session_id.size());
- initMsg->set_allocated_session_id(&tmp_session_id);
+ initMsg->set_allocated_session_id(new std::string(m_device_session_id.data(), m_device_session_id.size()));
}
m_features = this->client_exchange<messages::management::Features>(initMsg);
@@ -382,8 +381,6 @@ namespace trezor {
} else {
m_device_session_id.clear();
}
-
- initMsg->release_session_id();
}
void device_trezor_base::device_state_reset()
@@ -453,18 +450,14 @@ namespace trezor {
pin = m_pin;
}
- std::string pin_field;
messages::common::PinMatrixAck m;
if (pin) {
- pin_field.assign(pin->data(), pin->size());
- m.set_allocated_pin(&pin_field);
+ m.set_allocated_pin(new std::string(pin->data(), pin->size()));
}
const auto data_cleaner = epee::misc_utils::create_scope_leave_handler([&]() {
- m.release_pin();
- if (!pin_field.empty()){
- memwipe(&pin_field[0], pin_field.size());
- }
+ if (m.has_pin())
+ memwipe(&(*m.mutable_pin())[0], m.mutable_pin()->size());
});
resp = call_raw(&m);
@@ -499,7 +492,6 @@ namespace trezor {
boost::optional<epee::wipeable_string> passphrase;
TREZOR_CALLBACK_GET(passphrase, on_passphrase_request, on_device);
- std::string passphrase_field;
messages::common::PassphraseAck m;
m.set_on_device(on_device);
if (!on_device) {
@@ -512,16 +504,13 @@ namespace trezor {
}
if (passphrase) {
- passphrase_field.assign(passphrase->data(), passphrase->size());
- m.set_allocated_passphrase(&passphrase_field);
+ m.set_allocated_passphrase(new std::string(passphrase->data(), passphrase->size()));
}
}
const auto data_cleaner = epee::misc_utils::create_scope_leave_handler([&]() {
- m.release_passphrase();
- if (!passphrase_field.empty()){
- memwipe(&passphrase_field[0], passphrase_field.size());
- }
+ if (m.has_passphrase())
+ memwipe(&(m.mutable_passphrase())[0], m.mutable_passphrase()->size());
});
resp = call_raw(&m);
diff --git a/src/device_trezor/device_trezor_base.hpp b/src/device_trezor/device_trezor_base.hpp
index 4db8f0c8e..0162b23df 100644
--- a/src/device_trezor/device_trezor_base.hpp
+++ b/src/device_trezor/device_trezor_base.hpp
@@ -165,7 +165,7 @@ namespace trezor {
// Scoped session closer
BOOST_SCOPE_EXIT_ALL(&, this) {
- if (open_session){
+ if (open_session && this->get_transport()){
this->get_transport()->close();
}
};
diff --git a/tests/unit_tests/epee_boosted_tcp_server.cpp b/tests/unit_tests/epee_boosted_tcp_server.cpp
index e99666eb1..06e076a3a 100644
--- a/tests/unit_tests/epee_boosted_tcp_server.cpp
+++ b/tests/unit_tests/epee_boosted_tcp_server.cpp
@@ -143,14 +143,24 @@ TEST(test_epee_connection, test_lifetime)
static constexpr bool handshake_complete() noexcept { return true; }
};
+ using functional_obj_t = std::function<void ()>;
struct command_handler_t: epee::levin::levin_commands_handler<context_t> {
size_t delay;
- command_handler_t(size_t delay = 0): delay(delay) {}
+ functional_obj_t on_connection_close_f;
+ command_handler_t(size_t delay = 0,
+ functional_obj_t on_connection_close_f = nullptr
+ ):
+ delay(delay),
+ on_connection_close_f(on_connection_close_f)
+ {}
virtual int invoke(int, const epee::span<const uint8_t>, epee::byte_slice&, context_t&) override { epee::misc_utils::sleep_no_w(delay); return {}; }
virtual int notify(int, const epee::span<const uint8_t>, context_t&) override { return {}; }
virtual void callback(context_t&) override {}
virtual void on_connection_new(context_t&) override {}
- virtual void on_connection_close(context_t&) override {}
+ virtual void on_connection_close(context_t&) override {
+ if (on_connection_close_f)
+ on_connection_close_f();
+ }
virtual ~command_handler_t() override {}
static void destroy(epee::levin::levin_commands_handler<context_t>* ptr) { delete ptr; }
};
@@ -168,6 +178,14 @@ TEST(test_epee_connection, test_lifetime)
using work_ptr = std::shared_ptr<work_t>;
using workers_t = std::vector<std::thread>;
using server_t = epee::net_utils::boosted_tcp_server<handler_t>;
+ using lock_t = std::mutex;
+ using lock_guard_t = std::lock_guard<lock_t>;
+ using connection_weak_ptr = boost::weak_ptr<connection_t>;
+ struct shared_conn_t {
+ lock_t lock;
+ connection_weak_ptr conn;
+ };
+ using shared_conn_ptr = std::shared_ptr<shared_conn_t>;
io_context_t io_context;
work_ptr work(std::make_shared<work_t>(io_context));
@@ -255,6 +273,7 @@ TEST(test_epee_connection, test_lifetime)
ASSERT_TRUE(index == N * N);
ASSERT_TRUE(shared_state->get_connections_count() == 0);
+ while (shared_state->sock_count);
ASSERT_TRUE(shared_state->get_connections_count() == 0);
constexpr auto DELAY = 30;
constexpr auto TIMEOUT = 1;
@@ -273,6 +292,58 @@ TEST(test_epee_connection, test_lifetime)
shared_state->close(tag);
ASSERT_TRUE(shared_state->get_connections_count() == 0);
}
+
+ while (shared_state->sock_count);
+ constexpr auto ZERO_DELAY = 0;
+ size_t counter = 0;
+ shared_state->set_handler(new command_handler_t(ZERO_DELAY,
+ [&counter]{
+ ASSERT_TRUE(counter++ == 0);
+ }
+ ),
+ &command_handler_t::destroy
+ );
+ connection_ptr conn(new connection_t(io_context, shared_state, {}, {}));
+ conn->socket().connect(endpoint);
+ conn->start({}, {});
+ ASSERT_TRUE(shared_state->get_connections_count() == 1);
+ shared_state->del_out_connections(1);
+ ASSERT_TRUE(shared_state->get_connections_count() == 0);
+ conn.reset();
+
+ while (shared_state->sock_count);
+ shared_conn_ptr shared_conn(std::make_shared<shared_conn_t>());
+ shared_state->set_handler(new command_handler_t(ZERO_DELAY,
+ [shared_state, shared_conn]{
+ {
+ connection_ptr conn;
+ {
+ lock_guard_t guard(shared_conn->lock);
+ conn = std::move(shared_conn->conn.lock());
+ }
+ if (conn)
+ conn->cancel();
+ }
+ const auto success = shared_state->foreach_connection([](context_t&){
+ return true;
+ });
+ ASSERT_TRUE(success);
+ }
+ ),
+ &command_handler_t::destroy
+ );
+ for (auto i = 0; i < N; ++i) {
+ {
+ connection_ptr conn(new connection_t(io_context, shared_state, {}, {}));
+ conn->socket().connect(endpoint);
+ conn->start({}, {});
+ lock_guard_t guard(shared_conn->lock);
+ shared_conn->conn = conn;
+ }
+ ASSERT_TRUE(shared_state->get_connections_count() == 1);
+ shared_state->del_out_connections(1);
+ ASSERT_TRUE(shared_state->get_connections_count() == 0);
+ }
});
for (auto& w: workers) {