aboutsummaryrefslogtreecommitdiff
path: root/contrib/epee/include
diff options
context:
space:
mode:
authorj-berman <justinberman@protonmail.com>2022-07-06 16:47:34 -0700
committerj-berman <justinberman@protonmail.com>2022-07-08 15:10:03 -0700
commita82fba4b7b944a54d2a14922f44d7eee367e4912 (patch)
tree6281ddd1008cbd92b421acc1c88cccdf3c8a60a7 /contrib/epee/include
parentconnection: fix implementation (diff)
downloadmonero-a82fba4b7b944a54d2a14922f44d7eee367e4912.tar.xz
address PR comments
Diffstat (limited to 'contrib/epee/include')
-rw-r--r--contrib/epee/include/net/abstract_tcp_server2.h82
-rw-r--r--contrib/epee/include/net/abstract_tcp_server2.inl672
2 files changed, 376 insertions, 378 deletions
diff --git a/contrib/epee/include/net/abstract_tcp_server2.h b/contrib/epee/include/net/abstract_tcp_server2.h
index 0684573f2..bc0da66e2 100644
--- a/contrib/epee/include/net/abstract_tcp_server2.h
+++ b/contrib/epee/include/net/abstract_tcp_server2.h
@@ -89,20 +89,14 @@ namespace net_utils
public i_service_endpoint,
public connection_basic
{
+ public:
+ typedef typename t_protocol_handler::connection_context t_connection_context;
private:
- using string_t = std::string;
- using handler_t = t_protocol_handler;
- using context_t = typename handler_t::connection_context;
- using connection_t = connection<handler_t>;
+ using connection_t = connection<t_protocol_handler>;
using connection_ptr = boost::shared_ptr<connection_t>;
using ssl_support_t = epee::net_utils::ssl_support_t;
using timer_t = boost::asio::steady_timer;
using duration_t = timer_t::duration;
- using lock_t = std::mutex;
- using condition_t = std::condition_variable_any;
- using lock_guard_t = std::lock_guard<lock_t>;
- using unique_lock_t = std::unique_lock<lock_t>;
- using byte_slice_t = epee::byte_slice;
using ec_t = boost::system::error_code;
using handshake_t = boost::asio::ssl::stream_base::handshake_type;
@@ -110,8 +104,6 @@ namespace net_utils
using strand_t = boost::asio::io_service::strand;
using socket_t = boost::asio::ip::tcp::socket;
- using write_queue_t = std::deque<byte_slice_t>;
- using read_buffer_t = std::array<uint8_t, 0x2000>;
using network_throttle_t = epee::net_utils::network_throttle;
using network_throttle_manager_t = epee::net_utils::network_throttle_manager;
@@ -119,6 +111,8 @@ namespace net_utils
duration_t get_default_timeout();
duration_t get_timeout_from_bytes_read(size_t bytes) const;
+ void state_status_check();
+
void start_timer(duration_t duration, bool add = {});
void async_wait_timer();
void cancel_timer();
@@ -137,13 +131,21 @@ namespace net_utils
void terminate();
void on_terminating();
- bool send(byte_slice_t message);
+ bool send(epee::byte_slice message);
bool start_internal(
bool is_income,
bool is_multithreaded,
boost::optional<network_address> real_remote
);
+ enum status_t {
+ TERMINATED,
+ RUNNING,
+ INTERRUPTED,
+ TERMINATING,
+ WASTED,
+ };
+
struct state_t {
struct stat_t {
struct {
@@ -156,10 +158,10 @@ namespace net_utils
struct data_t {
struct {
- read_buffer_t buffer;
+ std::array<uint8_t, 0x2000> buffer;
} read;
struct {
- write_queue_t queue;
+ std::deque<epee::byte_slice> queue;
bool wait_consume;
} write;
};
@@ -171,7 +173,7 @@ namespace net_utils
bool handshaked;
};
- struct socket_t {
+ struct socket_status_t {
bool connected;
bool wait_handshake;
@@ -189,30 +191,22 @@ namespace net_utils
bool cancel_shutdown;
};
- struct timer_t {
+ struct timer_status_t {
bool wait_expire;
bool cancel_expire;
bool reset_expire;
};
- struct timers_t {
+ struct timers_status_t {
struct throttle_t {
- timer_t in;
- timer_t out;
+ timer_status_t in;
+ timer_status_t out;
};
- timer_t general;
+ timer_status_t general;
throttle_t throttle;
};
- enum status_t {
- TERMINATED,
- RUNNING,
- INTERRUPTED,
- TERMINATING,
- WASTED,
- };
-
struct protocol_t {
size_t reference_counter;
bool released;
@@ -223,19 +217,17 @@ namespace net_utils
size_t wait_callback;
};
- lock_t lock;
- condition_t condition;
+ std::mutex lock;
+ std::condition_variable_any condition;
status_t status;
- socket_t socket;
+ socket_status_t socket;
ssl_t ssl;
- timers_t timers;
+ timers_status_t timers;
protocol_t protocol;
stat_t stat;
data_t data;
};
- using status_t = typename state_t::status_t;
-
struct timers_t {
timers_t(io_context_t &io_context):
general(io_context),
@@ -254,19 +246,17 @@ namespace net_utils
throttle_t throttle;
};
- io_context_t &io_context;
- t_connection_type connection_type;
- context_t context{};
- strand_t strand;
- timers_t timers;
+ io_context_t &m_io_context;
+ t_connection_type m_connection_type;
+ t_connection_context m_conn_context{};
+ strand_t m_strand;
+ timers_t m_timers;
connection_ptr self{};
- bool local{};
- string_t host{};
- state_t state{};
- handler_t handler;
+ bool m_local{};
+ std::string m_host{};
+ state_t m_state{};
+ t_protocol_handler m_handler;
public:
- typedef typename t_protocol_handler::connection_context t_connection_context;
-
struct shared_state : connection_basic_shared_state, t_protocol_handler::config_type
{
shared_state()
@@ -298,7 +288,7 @@ namespace net_utils
// `real_remote` is the actual endpoint (if connection is to proxy, etc.)
bool start(bool is_income, bool is_multithreaded, network_address real_remote);
- void get_context(t_connection_context& context_){context_ = context;}
+ void get_context(t_connection_context& context_){context_ = m_conn_context;}
void call_back_starter();
diff --git a/contrib/epee/include/net/abstract_tcp_server2.inl b/contrib/epee/include/net/abstract_tcp_server2.inl
index 0fc9228b1..81aa725d1 100644
--- a/contrib/epee/include/net/abstract_tcp_server2.inl
+++ b/contrib/epee/include/net/abstract_tcp_server2.inl
@@ -79,14 +79,14 @@ namespace net_utils
template<typename T>
unsigned int connection<T>::host_count(int delta)
{
- static lock_t hosts_mutex;
- lock_guard_t guard(hosts_mutex);
- static std::map<string_t, unsigned int> hosts;
- unsigned int &val = hosts[host];
+ static std::mutex hosts_mutex;
+ std::lock_guard<std::mutex> guard(hosts_mutex);
+ static std::map<std::string, unsigned int> hosts;
+ unsigned int &val = hosts[m_host];
if (delta > 0)
- MTRACE("New connection from host " << host << ": " << val);
+ MTRACE("New connection from host " << m_host << ": " << val);
else if (delta < 0)
- MTRACE("Closed connection from host " << host << ": " << val);
+ MTRACE("Closed connection from host " << m_host << ": " << val);
CHECK_AND_ASSERT_THROW_MES(delta >= 0 || val >= (unsigned)-delta, "Count would go negative");
CHECK_AND_ASSERT_THROW_MES(delta <= 0 || val <= std::numeric_limits<unsigned int>::max() - (unsigned)delta, "Count would wrap");
val += delta;
@@ -104,7 +104,7 @@ namespace net_utils
0
);
return (
- local ?
+ m_local ?
std::chrono::milliseconds(DEFAULT_TIMEOUT_MS_LOCAL >> shift) :
std::chrono::milliseconds(DEFAULT_TIMEOUT_MS_REMOTE >> shift)
);
@@ -121,15 +121,34 @@ namespace net_utils
}
template<typename T>
+ void connection<T>::state_status_check()
+ {
+ switch (m_state.status)
+ {
+ case status_t::RUNNING:
+ interrupt();
+ break;
+ case status_t::INTERRUPTED:
+ on_interrupted();
+ break;
+ case status_t::TERMINATING:
+ on_terminating();
+ break;
+ default:
+ break;
+ }
+ }
+
+ template<typename T>
void connection<T>::start_timer(duration_t duration, bool add)
{
- if (state.timers.general.wait_expire) {
- state.timers.general.cancel_expire = true;
- state.timers.general.reset_expire = true;
+ if (m_state.timers.general.wait_expire) {
+ m_state.timers.general.cancel_expire = true;
+ m_state.timers.general.reset_expire = true;
ec_t ec;
- timers.general.expires_from_now(
+ m_timers.general.expires_from_now(
std::min(
- duration + (add ? timers.general.expires_from_now() : duration_t{}),
+ duration + (add ? m_timers.general.expires_from_now() : duration_t{}),
get_default_timeout()
),
ec
@@ -137,9 +156,9 @@ namespace net_utils
}
else {
ec_t ec;
- timers.general.expires_from_now(
+ m_timers.general.expires_from_now(
std::min(
- duration + (add ? timers.general.expires_from_now() : duration_t{}),
+ duration + (add ? m_timers.general.expires_from_now() : duration_t{}),
get_default_timeout()
),
ec
@@ -151,27 +170,27 @@ namespace net_utils
template<typename T>
void connection<T>::async_wait_timer()
{
- if (state.timers.general.wait_expire)
+ if (m_state.timers.general.wait_expire)
return;
- state.timers.general.wait_expire = true;
+ m_state.timers.general.wait_expire = true;
auto self = connection<T>::shared_from_this();
- timers.general.async_wait([this, self](const ec_t & ec){
- lock_guard_t guard(state.lock);
- state.timers.general.wait_expire = false;
- if (state.timers.general.cancel_expire) {
- state.timers.general.cancel_expire = false;
- if (state.timers.general.reset_expire) {
- state.timers.general.reset_expire = false;
+ m_timers.general.async_wait([this, self](const ec_t & ec){
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_state.timers.general.wait_expire = false;
+ if (m_state.timers.general.cancel_expire) {
+ m_state.timers.general.cancel_expire = false;
+ if (m_state.timers.general.reset_expire) {
+ m_state.timers.general.reset_expire = false;
async_wait_timer();
}
- else if (state.status == status_t::INTERRUPTED)
+ else if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
- else if (state.status == status_t::TERMINATING)
+ else if (m_state.status == status_t::TERMINATING)
on_terminating();
}
- else if (state.status == status_t::RUNNING)
+ else if (m_state.status == status_t::RUNNING)
interrupt();
- else if (state.status == status_t::INTERRUPTED)
+ else if (m_state.status == status_t::INTERRUPTED)
terminate();
});
}
@@ -179,72 +198,67 @@ namespace net_utils
template<typename T>
void connection<T>::cancel_timer()
{
- if (not state.timers.general.wait_expire)
+ if (!m_state.timers.general.wait_expire)
return;
- state.timers.general.cancel_expire = true;
- state.timers.general.reset_expire = false;
+ m_state.timers.general.cancel_expire = true;
+ m_state.timers.general.reset_expire = false;
ec_t ec;
- timers.general.cancel(ec);
+ m_timers.general.cancel(ec);
}
template<typename T>
void connection<T>::start_handshake()
{
- if (state.socket.wait_handshake)
+ if (m_state.socket.wait_handshake)
return;
static_assert(
- epee::net_utils::get_ssl_magic_size() <= sizeof(state.data.read.buffer),
+ epee::net_utils::get_ssl_magic_size() <= sizeof(m_state.data.read.buffer),
""
);
auto self = connection<T>::shared_from_this();
- if (not state.ssl.forced and not state.ssl.detected) {
- state.socket.wait_read = true;
+ if (!m_state.ssl.forced && !m_state.ssl.detected) {
+ m_state.socket.wait_read = true;
boost::asio::async_read(
connection_basic::socket_.next_layer(),
boost::asio::buffer(
- state.data.read.buffer.data(),
- state.data.read.buffer.size()
+ m_state.data.read.buffer.data(),
+ m_state.data.read.buffer.size()
),
boost::asio::transfer_exactly(epee::net_utils::get_ssl_magic_size()),
- strand.wrap(
+ m_strand.wrap(
[this, self](const ec_t &ec, size_t bytes_transferred){
- lock_guard_t guard(state.lock);
- state.socket.wait_read = false;
- if (state.socket.cancel_read) {
- state.socket.cancel_read = false;
- if (state.status == status_t::RUNNING)
- interrupt();
- else if (state.status == status_t::INTERRUPTED)
- on_interrupted();
- else if (state.status == status_t::TERMINATING)
- on_terminating();
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_state.socket.wait_read = false;
+ if (m_state.socket.cancel_read) {
+ m_state.socket.cancel_read = false;
+ state_status_check();
}
else if (ec.value()) {
terminate();
}
else if (
- not epee::net_utils::is_ssl(
+ !epee::net_utils::is_ssl(
static_cast<const unsigned char *>(
- state.data.read.buffer.data()
+ m_state.data.read.buffer.data()
),
bytes_transferred
)
) {
- state.ssl.enabled = false;
- state.socket.handle_read = true;
+ m_state.ssl.enabled = false;
+ m_state.socket.handle_read = true;
connection_basic::strand_.post(
[this, self, bytes_transferred]{
- bool success = handler.handle_recv(
- reinterpret_cast<char *>(state.data.read.buffer.data()),
+ bool success = m_handler.handle_recv(
+ reinterpret_cast<char *>(m_state.data.read.buffer.data()),
bytes_transferred
);
- lock_guard_t guard(state.lock);
- state.socket.handle_read = false;
- if (state.status == status_t::INTERRUPTED)
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_state.socket.handle_read = false;
+ if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
- else if (state.status == status_t::TERMINATING)
+ else if (m_state.status == status_t::TERMINATING)
on_terminating();
- else if (not success)
+ else if (!success)
interrupt();
else {
start_read();
@@ -253,7 +267,7 @@ namespace net_utils
);
}
else {
- state.ssl.detected = true;
+ m_state.ssl.detected = true;
start_handshake();
}
}
@@ -262,18 +276,13 @@ namespace net_utils
return;
}
- state.socket.wait_handshake = true;
+ m_state.socket.wait_handshake = true;
auto on_handshake = [this, self](const ec_t &ec, size_t bytes_transferred){
- lock_guard_t guard(state.lock);
- state.socket.wait_handshake = false;
- if (state.socket.cancel_handshake) {
- state.socket.cancel_handshake = false;
- if (state.status == status_t::RUNNING)
- interrupt();
- else if (state.status == status_t::INTERRUPTED)
- on_interrupted();
- else if (state.status == status_t::TERMINATING)
- on_terminating();
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_state.socket.wait_handshake = false;
+ if (m_state.socket.cancel_handshake) {
+ m_state.socket.cancel_handshake = false;
+ state_status_check();
}
else if (ec.value()) {
ec_t ec;
@@ -282,11 +291,11 @@ namespace net_utils
ec
);
connection_basic::socket_.next_layer().close(ec);
- state.socket.connected = false;
+ m_state.socket.connected = false;
interrupt();
}
else {
- state.ssl.handshaked = true;
+ m_state.ssl.handshaked = true;
start_write();
start_read();
}
@@ -295,16 +304,16 @@ namespace net_utils
static_cast<shared_state&>(
connection_basic::get_state()
).ssl_options().configure(connection_basic::socket_, handshake);
- strand.post(
+ m_strand.post(
[this, self, on_handshake]{
connection_basic::socket_.async_handshake(
handshake,
boost::asio::buffer(
- state.data.read.buffer.data(),
- state.ssl.forced ? 0 :
+ m_state.data.read.buffer.data(),
+ m_state.ssl.forced ? 0 :
epee::net_utils::get_ssl_magic_size()
),
- strand.wrap(on_handshake)
+ m_strand.wrap(on_handshake)
);
}
);
@@ -313,12 +322,13 @@ namespace net_utils
template<typename T>
void connection<T>::start_read()
{
- if (state.timers.throttle.in.wait_expire || state.socket.wait_read ||
- state.socket.handle_read
- )
+ if (m_state.timers.throttle.in.wait_expire || m_state.socket.wait_read ||
+ m_state.socket.handle_read
+ ) {
return;
+ }
auto self = connection<T>::shared_from_this();
- if (connection_type != e_connection_type_RPC) {
+ if (m_connection_type != e_connection_type_RPC) {
auto calc_duration = []{
CRITICAL_REGION_LOCAL(
network_throttle_manager_t::m_lock_get_global_throttle_in
@@ -336,19 +346,14 @@ namespace net_utils
const auto duration = calc_duration();
if (duration > duration_t{}) {
ec_t ec;
- timers.throttle.in.expires_from_now(duration, ec);
- state.timers.throttle.in.wait_expire = true;
- timers.throttle.in.async_wait([this, self](const ec_t &ec){
- lock_guard_t guard(state.lock);
- state.timers.throttle.in.wait_expire = false;
- if (state.timers.throttle.in.cancel_expire) {
- state.timers.throttle.in.cancel_expire = false;
- if (state.status == status_t::RUNNING)
- interrupt();
- else if (state.status == status_t::INTERRUPTED)
- on_interrupted();
- else if (state.status == status_t::TERMINATING)
- on_terminating();
+ m_timers.throttle.in.expires_from_now(duration, ec);
+ m_state.timers.throttle.in.wait_expire = true;
+ m_timers.throttle.in.async_wait([this, self](const ec_t &ec){
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_state.timers.throttle.in.wait_expire = false;
+ if (m_state.timers.throttle.in.cancel_expire) {
+ m_state.timers.throttle.in.cancel_expire = false;
+ state_status_check();
}
else if (ec.value())
interrupt();
@@ -358,28 +363,23 @@ namespace net_utils
return;
}
}
- state.socket.wait_read = true;
+ m_state.socket.wait_read = true;
auto on_read = [this, self](const ec_t &ec, size_t bytes_transferred){
- lock_guard_t guard(state.lock);
- state.socket.wait_read = false;
- if (state.socket.cancel_read) {
- state.socket.cancel_read = false;
- if (state.status == status_t::RUNNING)
- interrupt();
- else if (state.status == status_t::INTERRUPTED)
- on_interrupted();
- else if (state.status == status_t::TERMINATING)
- on_terminating();
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_state.socket.wait_read = false;
+ if (m_state.socket.cancel_read) {
+ m_state.socket.cancel_read = false;
+ state_status_check();
}
else if (ec.value())
terminate();
else {
{
- state.stat.in.throttle.handle_trafic_exact(bytes_transferred);
- const auto speed = state.stat.in.throttle.get_current_speed();
- context.m_current_speed_down = speed;
- context.m_max_speed_down = std::max(
- context.m_max_speed_down,
+ m_state.stat.in.throttle.handle_trafic_exact(bytes_transferred);
+ const auto speed = m_state.stat.in.throttle.get_current_speed();
+ m_conn_context.m_current_speed_down = speed;
+ m_conn_context.m_max_speed_down = std::max(
+ m_conn_context.m_max_speed_down,
speed
);
{
@@ -390,24 +390,30 @@ namespace net_utils
).handle_trafic_exact(bytes_transferred);
}
connection_basic::logger_handle_net_read(bytes_transferred);
- context.m_last_recv = time(NULL);
- context.m_recv_cnt += bytes_transferred;
+ m_conn_context.m_last_recv = time(NULL);
+ m_conn_context.m_recv_cnt += bytes_transferred;
start_timer(get_timeout_from_bytes_read(bytes_transferred), true);
}
- state.socket.handle_read = true;
+
+ // Post handle_recv to a separate `strand_`, distinct from `m_strand`
+ // which is listening for reads/writes. This avoids a circular dep.
+ // handle_recv can queue many writes, and `m_strand` will process those
+ // writes until the connection terminates without deadlocking waiting
+ // for handle_recv.
+ m_state.socket.handle_read = true;
connection_basic::strand_.post(
[this, self, bytes_transferred]{
- bool success = handler.handle_recv(
- reinterpret_cast<char *>(state.data.read.buffer.data()),
+ bool success = m_handler.handle_recv(
+ reinterpret_cast<char *>(m_state.data.read.buffer.data()),
bytes_transferred
);
- lock_guard_t guard(state.lock);
- state.socket.handle_read = false;
- if (state.status == status_t::INTERRUPTED)
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_state.socket.handle_read = false;
+ if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
- else if (state.status == status_t::TERMINATING)
+ else if (m_state.status == status_t::TERMINATING)
on_terminating();
- else if (not success)
+ else if (!success)
interrupt();
else {
start_read();
@@ -416,23 +422,23 @@ namespace net_utils
);
}
};
- if (not state.ssl.enabled)
+ if (!m_state.ssl.enabled)
connection_basic::socket_.next_layer().async_read_some(
boost::asio::buffer(
- state.data.read.buffer.data(),
- state.data.read.buffer.size()
+ m_state.data.read.buffer.data(),
+ m_state.data.read.buffer.size()
),
- strand.wrap(on_read)
+ m_strand.wrap(on_read)
);
else
- strand.post(
+ m_strand.post(
[this, self, on_read]{
connection_basic::socket_.async_read_some(
boost::asio::buffer(
- state.data.read.buffer.data(),
- state.data.read.buffer.size()
+ m_state.data.read.buffer.data(),
+ m_state.data.read.buffer.size()
),
- strand.wrap(on_read)
+ m_strand.wrap(on_read)
);
}
);
@@ -441,13 +447,14 @@ namespace net_utils
template<typename T>
void connection<T>::start_write()
{
- if (state.timers.throttle.out.wait_expire || state.socket.wait_write ||
- state.data.write.queue.empty() ||
- (state.ssl.enabled && not state.ssl.handshaked)
- )
+ if (m_state.timers.throttle.out.wait_expire || m_state.socket.wait_write ||
+ m_state.data.write.queue.empty() ||
+ (m_state.ssl.enabled && !m_state.ssl.handshaked)
+ ) {
return;
+ }
auto self = connection<T>::shared_from_this();
- if (connection_type != e_connection_type_RPC) {
+ if (m_connection_type != e_connection_type_RPC) {
auto calc_duration = [this]{
CRITICAL_REGION_LOCAL(
network_throttle_manager_t::m_lock_get_global_throttle_out
@@ -457,7 +464,7 @@ namespace net_utils
std::min(
network_throttle_manager_t::get_global_throttle_out(
).get_sleep_time_after_tick(
- state.data.write.queue.back().size()
+ m_state.data.write.queue.back().size()
),
1.0
)
@@ -467,19 +474,14 @@ namespace net_utils
const auto duration = calc_duration();
if (duration > duration_t{}) {
ec_t ec;
- timers.throttle.out.expires_from_now(duration, ec);
- state.timers.throttle.out.wait_expire = true;
- timers.throttle.out.async_wait([this, self](const ec_t &ec){
- lock_guard_t guard(state.lock);
- state.timers.throttle.out.wait_expire = false;
- if (state.timers.throttle.out.cancel_expire) {
- state.timers.throttle.out.cancel_expire = false;
- if (state.status == status_t::RUNNING)
- interrupt();
- else if (state.status == status_t::INTERRUPTED)
- on_interrupted();
- else if (state.status == status_t::TERMINATING)
- on_terminating();
+ m_timers.throttle.out.expires_from_now(duration, ec);
+ m_state.timers.throttle.out.wait_expire = true;
+ m_timers.throttle.out.async_wait([this, self](const ec_t &ec){
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_state.timers.throttle.out.wait_expire = false;
+ if (m_state.timers.throttle.out.cancel_expire) {
+ m_state.timers.throttle.out.cancel_expire = false;
+ state_status_check();
}
else if (ec.value())
interrupt();
@@ -489,31 +491,26 @@ namespace net_utils
}
}
- state.socket.wait_write = true;
+ m_state.socket.wait_write = true;
auto on_write = [this, self](const ec_t &ec, size_t bytes_transferred){
- lock_guard_t guard(state.lock);
- state.socket.wait_write = false;
- if (state.socket.cancel_write) {
- state.socket.cancel_write = false;
- state.data.write.queue.clear();
- if (state.status == status_t::RUNNING)
- interrupt();
- else if (state.status == status_t::INTERRUPTED)
- on_interrupted();
- else if (state.status == status_t::TERMINATING)
- on_terminating();
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_state.socket.wait_write = false;
+ if (m_state.socket.cancel_write) {
+ m_state.socket.cancel_write = false;
+ m_state.data.write.queue.clear();
+ state_status_check();
}
else if (ec.value()) {
- state.data.write.queue.clear();
+ m_state.data.write.queue.clear();
interrupt();
}
else {
{
- state.stat.out.throttle.handle_trafic_exact(bytes_transferred);
- const auto speed = state.stat.out.throttle.get_current_speed();
- context.m_current_speed_up = speed;
- context.m_max_speed_down = std::max(
- context.m_max_speed_down,
+ m_state.stat.out.throttle.handle_trafic_exact(bytes_transferred);
+ const auto speed = m_state.stat.out.throttle.get_current_speed();
+ m_conn_context.m_current_speed_up = speed;
+ m_conn_context.m_max_speed_down = std::max(
+ m_conn_context.m_max_speed_down,
speed
);
{
@@ -524,36 +521,36 @@ namespace net_utils
).handle_trafic_exact(bytes_transferred);
}
connection_basic::logger_handle_net_write(bytes_transferred);
- context.m_last_send = time(NULL);
- context.m_send_cnt += bytes_transferred;
+ m_conn_context.m_last_send = time(NULL);
+ m_conn_context.m_send_cnt += bytes_transferred;
start_timer(get_default_timeout(), true);
}
- assert(bytes_transferred == state.data.write.queue.back().size());
- state.data.write.queue.pop_back();
- state.condition.notify_all();
+ assert(bytes_transferred == m_state.data.write.queue.back().size());
+ m_state.data.write.queue.pop_back();
+ m_state.condition.notify_all();
start_write();
}
};
- if (not state.ssl.enabled)
+ if (!m_state.ssl.enabled)
boost::asio::async_write(
connection_basic::socket_.next_layer(),
boost::asio::buffer(
- state.data.write.queue.back().data(),
- state.data.write.queue.back().size()
+ m_state.data.write.queue.back().data(),
+ m_state.data.write.queue.back().size()
),
- strand.wrap(on_write)
+ m_strand.wrap(on_write)
);
else
- strand.post(
+ m_strand.post(
[this, self, on_write]{
boost::asio::async_write(
connection_basic::socket_,
boost::asio::buffer(
- state.data.write.queue.back().data(),
- state.data.write.queue.back().size()
+ m_state.data.write.queue.back().data(),
+ m_state.data.write.queue.back().size()
),
- strand.wrap(on_write)
+ m_strand.wrap(on_write)
);
}
);
@@ -562,21 +559,29 @@ namespace net_utils
template<typename T>
void connection<T>::start_shutdown()
{
- if (state.socket.wait_shutdown)
+ if (m_state.socket.wait_shutdown)
return;
auto self = connection<T>::shared_from_this();
- state.socket.wait_shutdown = true;
+ m_state.socket.wait_shutdown = true;
auto on_shutdown = [this, self](const ec_t &ec){
- lock_guard_t guard(state.lock);
- state.socket.wait_shutdown = false;
- if (state.socket.cancel_shutdown) {
- state.socket.cancel_shutdown = false;
- if (state.status == status_t::RUNNING)
- interrupt();
- else if (state.status == status_t::INTERRUPTED)
- terminate();
- else if (state.status == status_t::TERMINATING)
- on_terminating();
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_state.socket.wait_shutdown = false;
+ if (m_state.socket.cancel_shutdown) {
+ m_state.socket.cancel_shutdown = false;
+ switch (m_state.status)
+ {
+ case status_t::RUNNING:
+ interrupt();
+ break;
+ case status_t::INTERRUPTED:
+ terminate();
+ break;
+ case status_t::TERMINATING:
+ on_terminating();
+ break;
+ default:
+ break;
+ }
}
else if (ec.value())
terminate();
@@ -585,10 +590,10 @@ namespace net_utils
on_interrupted();
}
};
- strand.post(
+ m_strand.post(
[this, self, on_shutdown]{
connection_basic::socket_.async_shutdown(
- strand.wrap(on_shutdown)
+ m_strand.wrap(on_shutdown)
);
}
);
@@ -599,24 +604,24 @@ namespace net_utils
void connection<T>::cancel_socket()
{
bool wait_socket = false;
- if (state.socket.wait_handshake)
- wait_socket = state.socket.cancel_handshake = true;
- if (state.timers.throttle.in.wait_expire) {
- state.timers.throttle.in.cancel_expire = true;
+ if (m_state.socket.wait_handshake)
+ wait_socket = m_state.socket.cancel_handshake = true;
+ if (m_state.timers.throttle.in.wait_expire) {
+ m_state.timers.throttle.in.cancel_expire = true;
ec_t ec;
- timers.throttle.in.cancel(ec);
+ m_timers.throttle.in.cancel(ec);
}
- if (state.socket.wait_read)
- wait_socket = state.socket.cancel_read = true;
- if (state.timers.throttle.out.wait_expire) {
- state.timers.throttle.out.cancel_expire = true;
+ if (m_state.socket.wait_read)
+ wait_socket = m_state.socket.cancel_read = true;
+ if (m_state.timers.throttle.out.wait_expire) {
+ m_state.timers.throttle.out.cancel_expire = true;
ec_t ec;
- timers.throttle.out.cancel(ec);
+ m_timers.throttle.out.cancel(ec);
}
- if (state.socket.wait_write)
- wait_socket = state.socket.cancel_write = true;
- if (state.socket.wait_shutdown)
- wait_socket = state.socket.cancel_shutdown = true;
+ if (m_state.socket.wait_write)
+ wait_socket = m_state.socket.cancel_write = true;
+ if (m_state.socket.wait_shutdown)
+ wait_socket = m_state.socket.cancel_shutdown = true;
if (wait_socket) {
ec_t ec;
connection_basic::socket_.next_layer().cancel(ec);
@@ -626,136 +631,139 @@ namespace net_utils
template<typename T>
void connection<T>::cancel_handler()
{
- if (state.protocol.released || state.protocol.wait_release)
+ if (m_state.protocol.released || m_state.protocol.wait_release)
return;
- state.protocol.wait_release = true;
- state.lock.unlock();
- handler.release_protocol();
- state.lock.lock();
- state.protocol.wait_release = false;
- state.protocol.released = true;
- if (state.status == status_t::INTERRUPTED)
+ m_state.protocol.wait_release = true;
+ m_state.lock.unlock();
+ m_handler.release_protocol();
+ m_state.lock.lock();
+ m_state.protocol.wait_release = false;
+ m_state.protocol.released = true;
+ if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
- else if (state.status == status_t::TERMINATING)
+ else if (m_state.status == status_t::TERMINATING)
on_terminating();
}
template<typename T>
void connection<T>::interrupt()
{
- if (state.status != status_t::RUNNING)
+ if (m_state.status != status_t::RUNNING)
return;
- state.status = status_t::INTERRUPTED;
+ m_state.status = status_t::INTERRUPTED;
cancel_timer();
cancel_socket();
on_interrupted();
- state.condition.notify_all();
+ m_state.condition.notify_all();
cancel_handler();
}
template<typename T>
void connection<T>::on_interrupted()
{
- assert(state.status == status_t::INTERRUPTED);
- if (state.timers.general.wait_expire)
+ assert(m_state.status == status_t::INTERRUPTED);
+ if (m_state.timers.general.wait_expire)
return;
- if (state.socket.wait_handshake)
+ if (m_state.socket.wait_handshake)
return;
- if (state.timers.throttle.in.wait_expire)
+ if (m_state.timers.throttle.in.wait_expire)
return;
- if (state.socket.wait_read)
+ if (m_state.socket.wait_read)
return;
- if (state.socket.handle_read)
+ if (m_state.socket.handle_read)
return;
- if (state.timers.throttle.out.wait_expire)
+ if (m_state.timers.throttle.out.wait_expire)
return;
- if (state.socket.wait_write)
+ if (m_state.socket.wait_write)
return;
- if (state.socket.wait_shutdown)
+ if (m_state.socket.wait_shutdown)
return;
- if (state.protocol.wait_init)
+ if (m_state.protocol.wait_init)
return;
- if (state.protocol.wait_callback)
+ if (m_state.protocol.wait_callback)
return;
- if (state.protocol.wait_release)
+ if (m_state.protocol.wait_release)
return;
- if (state.socket.connected) {
- if (not state.ssl.enabled) {
+ if (m_state.socket.connected) {
+ if (!m_state.ssl.enabled) {
ec_t ec;
connection_basic::socket_.next_layer().shutdown(
socket_t::shutdown_both,
ec
);
connection_basic::socket_.next_layer().close(ec);
- state.socket.connected = false;
- state.status = status_t::WASTED;
+ m_state.socket.connected = false;
+ m_state.status = status_t::WASTED;
}
else
start_shutdown();
}
else
- state.status = status_t::WASTED;
+ m_state.status = status_t::WASTED;
}
template<typename T>
void connection<T>::terminate()
{
- if (state.status != status_t::RUNNING &&
- state.status != status_t::INTERRUPTED
+ if (m_state.status != status_t::RUNNING &&
+ m_state.status != status_t::INTERRUPTED
)
return;
- state.status = status_t::TERMINATING;
+ m_state.status = status_t::TERMINATING;
cancel_timer();
cancel_socket();
on_terminating();
- state.condition.notify_all();
+ m_state.condition.notify_all();
cancel_handler();
}
template<typename T>
void connection<T>::on_terminating()
{
- assert(state.status == status_t::TERMINATING);
- if (state.timers.general.wait_expire)
+ assert(m_state.status == status_t::TERMINATING);
+ if (m_state.timers.general.wait_expire)
return;
- if (state.socket.wait_handshake)
+ if (m_state.socket.wait_handshake)
return;
- if (state.timers.throttle.in.wait_expire)
+ if (m_state.timers.throttle.in.wait_expire)
return;
- if (state.socket.wait_read)
+ if (m_state.socket.wait_read)
return;
- if (state.socket.handle_read)
+ if (m_state.socket.handle_read)
return;
- if (state.timers.throttle.out.wait_expire)
+ if (m_state.timers.throttle.out.wait_expire)
return;
- if (state.socket.wait_write)
+ if (m_state.socket.wait_write)
return;
- if (state.socket.wait_shutdown)
+ if (m_state.socket.wait_shutdown)
return;
- if (state.protocol.wait_init)
+ if (m_state.protocol.wait_init)
return;
- if (state.protocol.wait_callback)
+ if (m_state.protocol.wait_callback)
return;
- if (state.protocol.wait_release)
+ if (m_state.protocol.wait_release)
return;
- if (state.socket.connected) {
+ if (m_state.socket.connected) {
ec_t ec;
connection_basic::socket_.next_layer().shutdown(
socket_t::shutdown_both,
ec
);
connection_basic::socket_.next_layer().close(ec);
- state.socket.connected = false;
+ m_state.socket.connected = false;
}
- state.status = status_t::WASTED;
+ m_state.status = status_t::WASTED;
}
template<typename T>
- bool connection<T>::send(byte_slice_t message)
+ bool connection<T>::send(epee::byte_slice message)
{
- lock_guard_t guard(state.lock);
- if (state.status != status_t::RUNNING || state.socket.wait_handshake)
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ if (m_state.status != status_t::RUNNING || m_state.socket.wait_handshake)
return false;
+
+ // Wait for the write queue to fall below the max. If it doesn't after a
+ // randomized delay, drop the connection.
auto wait_consume = [this] {
auto random_delay = []{
using engine = std::mt19937;
@@ -770,62 +778,62 @@ namespace net_utils
std::uniform_int_distribution<>(5000, 6000)(rng)
);
};
- if (state.data.write.queue.size() <= ABSTRACT_SERVER_SEND_QUE_MAX_COUNT)
+ if (m_state.data.write.queue.size() <= ABSTRACT_SERVER_SEND_QUE_MAX_COUNT)
return true;
- state.data.write.wait_consume = true;
- bool success = state.condition.wait_for(
- state.lock,
+ m_state.data.write.wait_consume = true;
+ bool success = m_state.condition.wait_for(
+ m_state.lock,
random_delay(),
[this]{
return (
- state.status != status_t::RUNNING ||
- state.data.write.queue.size() <=
+ m_state.status != status_t::RUNNING ||
+ m_state.data.write.queue.size() <=
ABSTRACT_SERVER_SEND_QUE_MAX_COUNT
);
}
);
- state.data.write.wait_consume = false;
- if (not success) {
+ m_state.data.write.wait_consume = false;
+ if (!success) {
terminate();
return false;
}
else
- return state.status == status_t::RUNNING;
+ return m_state.status == status_t::RUNNING;
};
auto wait_sender = [this] {
- state.condition.wait(
- state.lock,
+ m_state.condition.wait(
+ m_state.lock,
[this] {
return (
- state.status != status_t::RUNNING ||
- not state.data.write.wait_consume
+ m_state.status != status_t::RUNNING ||
+ !m_state.data.write.wait_consume
);
}
);
- return state.status == status_t::RUNNING;
+ return m_state.status == status_t::RUNNING;
};
- if (not wait_sender())
+ if (!wait_sender())
return false;
constexpr size_t CHUNK_SIZE = 32 * 1024;
- if (connection_type == e_connection_type_RPC ||
+ if (m_connection_type == e_connection_type_RPC ||
message.size() <= 2 * CHUNK_SIZE
) {
- if (not wait_consume())
+ if (!wait_consume())
return false;
- state.data.write.queue.emplace_front(std::move(message));
+ m_state.data.write.queue.emplace_front(std::move(message));
start_write();
}
else {
while (!message.empty()) {
- if (not wait_consume())
+ if (!wait_consume())
return false;
- state.data.write.queue.emplace_front(
+ m_state.data.write.queue.emplace_front(
message.take_slice(CHUNK_SIZE)
);
start_write();
}
}
- state.condition.notify_all();
+ m_state.condition.notify_all();
return true;
}
@@ -836,10 +844,10 @@ namespace net_utils
boost::optional<network_address> real_remote
)
{
- unique_lock_t guard(state.lock);
- if (state.status != status_t::TERMINATED)
+ std::unique_lock<std::mutex> guard(m_state.lock);
+ if (m_state.status != status_t::TERMINATED)
return false;
- if (not real_remote) {
+ if (!real_remote) {
ec_t ec;
auto endpoint = connection_basic::socket_.next_layer().remote_endpoint(
ec
@@ -866,7 +874,7 @@ namespace net_utils
auto *filter = static_cast<shared_state&>(
connection_basic::get_state()
).pfilter;
- if (filter and not filter->is_remote_host_allowed(*real_remote))
+ if (filter && !filter->is_remote_host_allowed(*real_remote))
return false;
ec_t ec;
#if !defined(_WIN32) || !defined(__i686)
@@ -886,39 +894,39 @@ namespace net_utils
if (ec.value())
return false;
connection_basic::m_is_multithreaded = is_multithreaded;
- context.set_details(
+ m_conn_context.set_details(
boost::uuids::random_generator()(),
*real_remote,
is_income,
connection_basic::m_ssl_support == ssl_support_t::e_ssl_support_enabled
);
- host = real_remote->host_str();
+ m_host = real_remote->host_str();
try { host_count(1); } catch(...) { /* ignore */ }
- local = real_remote->is_loopback() || real_remote->is_local();
- state.ssl.enabled = (
+ m_local = real_remote->is_loopback() || real_remote->is_local();
+ m_state.ssl.enabled = (
connection_basic::m_ssl_support != ssl_support_t::e_ssl_support_disabled
);
- state.ssl.forced = (
+ m_state.ssl.forced = (
connection_basic::m_ssl_support == ssl_support_t::e_ssl_support_enabled
);
- state.socket.connected = true;
- state.status = status_t::RUNNING;
+ m_state.socket.connected = true;
+ m_state.status = status_t::RUNNING;
start_timer(
std::chrono::milliseconds(
- local ? NEW_CONNECTION_TIMEOUT_LOCAL : NEW_CONNECTION_TIMEOUT_REMOTE
+ m_local ? NEW_CONNECTION_TIMEOUT_LOCAL : NEW_CONNECTION_TIMEOUT_REMOTE
)
);
- state.protocol.wait_init = true;
+ m_state.protocol.wait_init = true;
guard.unlock();
- handler.after_init_connection();
+ m_handler.after_init_connection();
guard.lock();
- state.protocol.wait_init = false;
- state.protocol.initialized = true;
- if (state.status == status_t::INTERRUPTED)
+ m_state.protocol.wait_init = false;
+ m_state.protocol.initialized = true;
+ if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
- else if (state.status == status_t::TERMINATING)
+ else if (m_state.status == status_t::TERMINATING)
on_terminating();
- else if (not is_income || not state.ssl.enabled)
+ else if (!is_income || !m_state.ssl.enabled)
start_read();
else
start_handshake();
@@ -949,23 +957,23 @@ namespace net_utils
ssl_support_t ssl_support
):
connection_basic(std::move(socket), shared_state, ssl_support),
- handler(this, *shared_state, context),
- connection_type(connection_type),
- io_context{GET_IO_SERVICE(connection_basic::socket_)},
- strand{io_context},
- timers{io_context}
+ m_handler(this, *shared_state, m_conn_context),
+ m_connection_type(connection_type),
+ m_io_context{GET_IO_SERVICE(connection_basic::socket_)},
+ m_strand{m_io_context},
+ m_timers{m_io_context}
{
}
template<typename T>
connection<T>::~connection() noexcept(false)
{
- lock_guard_t guard(state.lock);
- assert(state.status == status_t::TERMINATED ||
- state.status == status_t::WASTED ||
- io_context.stopped()
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ assert(m_state.status == status_t::TERMINATED ||
+ m_state.status == status_t::WASTED ||
+ m_io_context.stopped()
);
- if (state.status != status_t::WASTED)
+ if (m_state.status != status_t::WASTED)
return;
try { host_count(-1); } catch (...) { /* ignore */ }
}
@@ -992,9 +1000,9 @@ namespace net_utils
template<typename T>
void connection<T>::save_dbg_log()
{
- lock_guard_t guard(state.lock);
- string_t address;
- string_t port;
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ std::string address;
+ std::string port;
ec_t ec;
auto endpoint = connection_basic::socket().remote_endpoint(ec);
if (ec.value()) {
@@ -1006,10 +1014,10 @@ namespace net_utils
port = std::to_string(endpoint.port());
}
MDEBUG(
- " connection type " << std::to_string(connection_type) <<
+ " connection type " << std::to_string(m_connection_type) <<
" " << connection_basic::socket().local_endpoint().address().to_string() <<
":" << connection_basic::socket().local_endpoint().port() <<
- " <--> " << context.m_remote_address.str() <<
+ " <--> " << m_conn_context.m_remote_address.str() <<
" (via " << address << ":" << port << ")"
);
}
@@ -1017,7 +1025,7 @@ namespace net_utils
template<typename T>
bool connection<T>::speed_limit_is_enabled() const
{
- return connection_type != e_connection_type_RPC;
+ return m_connection_type != e_connection_type_RPC;
}
template<typename T>
@@ -1041,8 +1049,8 @@ namespace net_utils
template<typename T>
bool connection<T>::close()
{
- lock_guard_t guard(state.lock);
- if (state.status != status_t::RUNNING)
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ if (m_state.status != status_t::RUNNING)
return false;
terminate();
return true;
@@ -1052,11 +1060,11 @@ namespace net_utils
bool connection<T>::call_run_once_service_io()
{
if(connection_basic::m_is_multithreaded) {
- if (not io_context.poll_one())
+ if (!m_io_context.poll_one())
misc_utils::sleep_no_w(1);
}
else {
- if (!io_context.run_one())
+ if (!m_io_context.run_one())
return false;
}
return true;
@@ -1065,18 +1073,18 @@ namespace net_utils
template<typename T>
bool connection<T>::request_callback()
{
- lock_guard_t guard(state.lock);
- if (state.status != status_t::RUNNING)
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ if (m_state.status != status_t::RUNNING)
return false;
auto self = connection<T>::shared_from_this();
- ++state.protocol.wait_callback;
+ ++m_state.protocol.wait_callback;
connection_basic::strand_.post([this, self]{
- handler.handle_qued_callback();
- lock_guard_t guard(state.lock);
- --state.protocol.wait_callback;
- if (state.status == status_t::INTERRUPTED)
+ m_handler.handle_qued_callback();
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ --m_state.protocol.wait_callback;
+ if (m_state.status == status_t::INTERRUPTED)
on_interrupted();
- else if (state.status == status_t::TERMINATING)
+ else if (m_state.status == status_t::TERMINATING)
on_terminating();
});
return true;
@@ -1085,7 +1093,7 @@ namespace net_utils
template<typename T>
typename connection<T>::io_context_t &connection<T>::get_io_service()
{
- return io_context;
+ return m_io_context;
}
template<typename T>
@@ -1093,9 +1101,9 @@ namespace net_utils
{
try {
auto self = connection<T>::shared_from_this();
- lock_guard_t guard(state.lock);
+ std::lock_guard<std::mutex> guard(m_state.lock);
this->self = std::move(self);
- ++state.protocol.reference_counter;
+ ++m_state.protocol.reference_counter;
return true;
}
catch (boost::bad_weak_ptr &exception) {
@@ -1107,8 +1115,8 @@ namespace net_utils
bool connection<T>::release()
{
connection_ptr self;
- lock_guard_t guard(state.lock);
- if (not --state.protocol.reference_counter)
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ if (!(--m_state.protocol.reference_counter))
self = std::move(this->self);
return true;
}
@@ -1116,8 +1124,8 @@ namespace net_utils
template<typename T>
void connection<T>::setRpcStation()
{
- lock_guard_t guard(state.lock);
- connection_type = e_connection_type_RPC;
+ std::lock_guard<std::mutex> guard(m_state.lock);
+ m_connection_type = e_connection_type_RPC;
}
template<class t_protocol_handler>