diff options
-rw-r--r-- | contrib/epee/include/net/abstract_tcp_server2.h | 82 | ||||
-rw-r--r-- | contrib/epee/include/net/abstract_tcp_server2.inl | 672 | ||||
-rw-r--r-- | contrib/epee/src/net_ssl.cpp | 21 | ||||
-rw-r--r-- | tests/unit_tests/epee_boosted_tcp_server.cpp | 10 |
4 files changed, 390 insertions, 395 deletions
diff --git a/contrib/epee/include/net/abstract_tcp_server2.h b/contrib/epee/include/net/abstract_tcp_server2.h index 0684573f2..bc0da66e2 100644 --- a/contrib/epee/include/net/abstract_tcp_server2.h +++ b/contrib/epee/include/net/abstract_tcp_server2.h @@ -89,20 +89,14 @@ namespace net_utils public i_service_endpoint, public connection_basic { + public: + typedef typename t_protocol_handler::connection_context t_connection_context; private: - using string_t = std::string; - using handler_t = t_protocol_handler; - using context_t = typename handler_t::connection_context; - using connection_t = connection<handler_t>; + using connection_t = connection<t_protocol_handler>; using connection_ptr = boost::shared_ptr<connection_t>; using ssl_support_t = epee::net_utils::ssl_support_t; using timer_t = boost::asio::steady_timer; using duration_t = timer_t::duration; - using lock_t = std::mutex; - using condition_t = std::condition_variable_any; - using lock_guard_t = std::lock_guard<lock_t>; - using unique_lock_t = std::unique_lock<lock_t>; - using byte_slice_t = epee::byte_slice; using ec_t = boost::system::error_code; using handshake_t = boost::asio::ssl::stream_base::handshake_type; @@ -110,8 +104,6 @@ namespace net_utils using strand_t = boost::asio::io_service::strand; using socket_t = boost::asio::ip::tcp::socket; - using write_queue_t = std::deque<byte_slice_t>; - using read_buffer_t = std::array<uint8_t, 0x2000>; using network_throttle_t = epee::net_utils::network_throttle; using network_throttle_manager_t = epee::net_utils::network_throttle_manager; @@ -119,6 +111,8 @@ namespace net_utils duration_t get_default_timeout(); duration_t get_timeout_from_bytes_read(size_t bytes) const; + void state_status_check(); + void start_timer(duration_t duration, bool add = {}); void async_wait_timer(); void cancel_timer(); @@ -137,13 +131,21 @@ namespace net_utils void terminate(); void on_terminating(); - bool send(byte_slice_t message); + bool send(epee::byte_slice message); bool start_internal( bool is_income, bool is_multithreaded, boost::optional<network_address> real_remote ); + enum status_t { + TERMINATED, + RUNNING, + INTERRUPTED, + TERMINATING, + WASTED, + }; + struct state_t { struct stat_t { struct { @@ -156,10 +158,10 @@ namespace net_utils struct data_t { struct { - read_buffer_t buffer; + std::array<uint8_t, 0x2000> buffer; } read; struct { - write_queue_t queue; + std::deque<epee::byte_slice> queue; bool wait_consume; } write; }; @@ -171,7 +173,7 @@ namespace net_utils bool handshaked; }; - struct socket_t { + struct socket_status_t { bool connected; bool wait_handshake; @@ -189,30 +191,22 @@ namespace net_utils bool cancel_shutdown; }; - struct timer_t { + struct timer_status_t { bool wait_expire; bool cancel_expire; bool reset_expire; }; - struct timers_t { + struct timers_status_t { struct throttle_t { - timer_t in; - timer_t out; + timer_status_t in; + timer_status_t out; }; - timer_t general; + timer_status_t general; throttle_t throttle; }; - enum status_t { - TERMINATED, - RUNNING, - INTERRUPTED, - TERMINATING, - WASTED, - }; - struct protocol_t { size_t reference_counter; bool released; @@ -223,19 +217,17 @@ namespace net_utils size_t wait_callback; }; - lock_t lock; - condition_t condition; + std::mutex lock; + std::condition_variable_any condition; status_t status; - socket_t socket; + socket_status_t socket; ssl_t ssl; - timers_t timers; + timers_status_t timers; protocol_t protocol; stat_t stat; data_t data; }; - using status_t = typename state_t::status_t; - struct timers_t { timers_t(io_context_t &io_context): general(io_context), @@ -254,19 +246,17 @@ namespace net_utils throttle_t throttle; }; - io_context_t &io_context; - t_connection_type connection_type; - context_t context{}; - strand_t strand; - timers_t timers; + io_context_t &m_io_context; + t_connection_type m_connection_type; + t_connection_context m_conn_context{}; + strand_t m_strand; + timers_t m_timers; connection_ptr self{}; - bool local{}; - string_t host{}; - state_t state{}; - handler_t handler; + bool m_local{}; + std::string m_host{}; + state_t m_state{}; + t_protocol_handler m_handler; public: - typedef typename t_protocol_handler::connection_context t_connection_context; - struct shared_state : connection_basic_shared_state, t_protocol_handler::config_type { shared_state() @@ -298,7 +288,7 @@ namespace net_utils // `real_remote` is the actual endpoint (if connection is to proxy, etc.) bool start(bool is_income, bool is_multithreaded, network_address real_remote); - void get_context(t_connection_context& context_){context_ = context;} + void get_context(t_connection_context& context_){context_ = m_conn_context;} void call_back_starter(); diff --git a/contrib/epee/include/net/abstract_tcp_server2.inl b/contrib/epee/include/net/abstract_tcp_server2.inl index 0fc9228b1..81aa725d1 100644 --- a/contrib/epee/include/net/abstract_tcp_server2.inl +++ b/contrib/epee/include/net/abstract_tcp_server2.inl @@ -79,14 +79,14 @@ namespace net_utils template<typename T> unsigned int connection<T>::host_count(int delta) { - static lock_t hosts_mutex; - lock_guard_t guard(hosts_mutex); - static std::map<string_t, unsigned int> hosts; - unsigned int &val = hosts[host]; + static std::mutex hosts_mutex; + std::lock_guard<std::mutex> guard(hosts_mutex); + static std::map<std::string, unsigned int> hosts; + unsigned int &val = hosts[m_host]; if (delta > 0) - MTRACE("New connection from host " << host << ": " << val); + MTRACE("New connection from host " << m_host << ": " << val); else if (delta < 0) - MTRACE("Closed connection from host " << host << ": " << val); + MTRACE("Closed connection from host " << m_host << ": " << val); CHECK_AND_ASSERT_THROW_MES(delta >= 0 || val >= (unsigned)-delta, "Count would go negative"); CHECK_AND_ASSERT_THROW_MES(delta <= 0 || val <= std::numeric_limits<unsigned int>::max() - (unsigned)delta, "Count would wrap"); val += delta; @@ -104,7 +104,7 @@ namespace net_utils 0 ); return ( - local ? + m_local ? std::chrono::milliseconds(DEFAULT_TIMEOUT_MS_LOCAL >> shift) : std::chrono::milliseconds(DEFAULT_TIMEOUT_MS_REMOTE >> shift) ); @@ -121,15 +121,34 @@ namespace net_utils } template<typename T> + void connection<T>::state_status_check() + { + switch (m_state.status) + { + case status_t::RUNNING: + interrupt(); + break; + case status_t::INTERRUPTED: + on_interrupted(); + break; + case status_t::TERMINATING: + on_terminating(); + break; + default: + break; + } + } + + template<typename T> void connection<T>::start_timer(duration_t duration, bool add) { - if (state.timers.general.wait_expire) { - state.timers.general.cancel_expire = true; - state.timers.general.reset_expire = true; + if (m_state.timers.general.wait_expire) { + m_state.timers.general.cancel_expire = true; + m_state.timers.general.reset_expire = true; ec_t ec; - timers.general.expires_from_now( + m_timers.general.expires_from_now( std::min( - duration + (add ? timers.general.expires_from_now() : duration_t{}), + duration + (add ? m_timers.general.expires_from_now() : duration_t{}), get_default_timeout() ), ec @@ -137,9 +156,9 @@ namespace net_utils } else { ec_t ec; - timers.general.expires_from_now( + m_timers.general.expires_from_now( std::min( - duration + (add ? timers.general.expires_from_now() : duration_t{}), + duration + (add ? m_timers.general.expires_from_now() : duration_t{}), get_default_timeout() ), ec @@ -151,27 +170,27 @@ namespace net_utils template<typename T> void connection<T>::async_wait_timer() { - if (state.timers.general.wait_expire) + if (m_state.timers.general.wait_expire) return; - state.timers.general.wait_expire = true; + m_state.timers.general.wait_expire = true; auto self = connection<T>::shared_from_this(); - timers.general.async_wait([this, self](const ec_t & ec){ - lock_guard_t guard(state.lock); - state.timers.general.wait_expire = false; - if (state.timers.general.cancel_expire) { - state.timers.general.cancel_expire = false; - if (state.timers.general.reset_expire) { - state.timers.general.reset_expire = false; + m_timers.general.async_wait([this, self](const ec_t & ec){ + std::lock_guard<std::mutex> guard(m_state.lock); + m_state.timers.general.wait_expire = false; + if (m_state.timers.general.cancel_expire) { + m_state.timers.general.cancel_expire = false; + if (m_state.timers.general.reset_expire) { + m_state.timers.general.reset_expire = false; async_wait_timer(); } - else if (state.status == status_t::INTERRUPTED) + else if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); } - else if (state.status == status_t::RUNNING) + else if (m_state.status == status_t::RUNNING) interrupt(); - else if (state.status == status_t::INTERRUPTED) + else if (m_state.status == status_t::INTERRUPTED) terminate(); }); } @@ -179,72 +198,67 @@ namespace net_utils template<typename T> void connection<T>::cancel_timer() { - if (not state.timers.general.wait_expire) + if (!m_state.timers.general.wait_expire) return; - state.timers.general.cancel_expire = true; - state.timers.general.reset_expire = false; + m_state.timers.general.cancel_expire = true; + m_state.timers.general.reset_expire = false; ec_t ec; - timers.general.cancel(ec); + m_timers.general.cancel(ec); } template<typename T> void connection<T>::start_handshake() { - if (state.socket.wait_handshake) + if (m_state.socket.wait_handshake) return; static_assert( - epee::net_utils::get_ssl_magic_size() <= sizeof(state.data.read.buffer), + epee::net_utils::get_ssl_magic_size() <= sizeof(m_state.data.read.buffer), "" ); auto self = connection<T>::shared_from_this(); - if (not state.ssl.forced and not state.ssl.detected) { - state.socket.wait_read = true; + if (!m_state.ssl.forced && !m_state.ssl.detected) { + m_state.socket.wait_read = true; boost::asio::async_read( connection_basic::socket_.next_layer(), boost::asio::buffer( - state.data.read.buffer.data(), - state.data.read.buffer.size() + m_state.data.read.buffer.data(), + m_state.data.read.buffer.size() ), boost::asio::transfer_exactly(epee::net_utils::get_ssl_magic_size()), - strand.wrap( + m_strand.wrap( [this, self](const ec_t &ec, size_t bytes_transferred){ - lock_guard_t guard(state.lock); - state.socket.wait_read = false; - if (state.socket.cancel_read) { - state.socket.cancel_read = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + std::lock_guard<std::mutex> guard(m_state.lock); + m_state.socket.wait_read = false; + if (m_state.socket.cancel_read) { + m_state.socket.cancel_read = false; + state_status_check(); } else if (ec.value()) { terminate(); } else if ( - not epee::net_utils::is_ssl( + !epee::net_utils::is_ssl( static_cast<const unsigned char *>( - state.data.read.buffer.data() + m_state.data.read.buffer.data() ), bytes_transferred ) ) { - state.ssl.enabled = false; - state.socket.handle_read = true; + m_state.ssl.enabled = false; + m_state.socket.handle_read = true; connection_basic::strand_.post( [this, self, bytes_transferred]{ - bool success = handler.handle_recv( - reinterpret_cast<char *>(state.data.read.buffer.data()), + bool success = m_handler.handle_recv( + reinterpret_cast<char *>(m_state.data.read.buffer.data()), bytes_transferred ); - lock_guard_t guard(state.lock); - state.socket.handle_read = false; - if (state.status == status_t::INTERRUPTED) + std::lock_guard<std::mutex> guard(m_state.lock); + m_state.socket.handle_read = false; + if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); - else if (not success) + else if (!success) interrupt(); else { start_read(); @@ -253,7 +267,7 @@ namespace net_utils ); } else { - state.ssl.detected = true; + m_state.ssl.detected = true; start_handshake(); } } @@ -262,18 +276,13 @@ namespace net_utils return; } - state.socket.wait_handshake = true; + m_state.socket.wait_handshake = true; auto on_handshake = [this, self](const ec_t &ec, size_t bytes_transferred){ - lock_guard_t guard(state.lock); - state.socket.wait_handshake = false; - if (state.socket.cancel_handshake) { - state.socket.cancel_handshake = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + std::lock_guard<std::mutex> guard(m_state.lock); + m_state.socket.wait_handshake = false; + if (m_state.socket.cancel_handshake) { + m_state.socket.cancel_handshake = false; + state_status_check(); } else if (ec.value()) { ec_t ec; @@ -282,11 +291,11 @@ namespace net_utils ec ); connection_basic::socket_.next_layer().close(ec); - state.socket.connected = false; + m_state.socket.connected = false; interrupt(); } else { - state.ssl.handshaked = true; + m_state.ssl.handshaked = true; start_write(); start_read(); } @@ -295,16 +304,16 @@ namespace net_utils static_cast<shared_state&>( connection_basic::get_state() ).ssl_options().configure(connection_basic::socket_, handshake); - strand.post( + m_strand.post( [this, self, on_handshake]{ connection_basic::socket_.async_handshake( handshake, boost::asio::buffer( - state.data.read.buffer.data(), - state.ssl.forced ? 0 : + m_state.data.read.buffer.data(), + m_state.ssl.forced ? 0 : epee::net_utils::get_ssl_magic_size() ), - strand.wrap(on_handshake) + m_strand.wrap(on_handshake) ); } ); @@ -313,12 +322,13 @@ namespace net_utils template<typename T> void connection<T>::start_read() { - if (state.timers.throttle.in.wait_expire || state.socket.wait_read || - state.socket.handle_read - ) + if (m_state.timers.throttle.in.wait_expire || m_state.socket.wait_read || + m_state.socket.handle_read + ) { return; + } auto self = connection<T>::shared_from_this(); - if (connection_type != e_connection_type_RPC) { + if (m_connection_type != e_connection_type_RPC) { auto calc_duration = []{ CRITICAL_REGION_LOCAL( network_throttle_manager_t::m_lock_get_global_throttle_in @@ -336,19 +346,14 @@ namespace net_utils const auto duration = calc_duration(); if (duration > duration_t{}) { ec_t ec; - timers.throttle.in.expires_from_now(duration, ec); - state.timers.throttle.in.wait_expire = true; - timers.throttle.in.async_wait([this, self](const ec_t &ec){ - lock_guard_t guard(state.lock); - state.timers.throttle.in.wait_expire = false; - if (state.timers.throttle.in.cancel_expire) { - state.timers.throttle.in.cancel_expire = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + m_timers.throttle.in.expires_from_now(duration, ec); + m_state.timers.throttle.in.wait_expire = true; + m_timers.throttle.in.async_wait([this, self](const ec_t &ec){ + std::lock_guard<std::mutex> guard(m_state.lock); + m_state.timers.throttle.in.wait_expire = false; + if (m_state.timers.throttle.in.cancel_expire) { + m_state.timers.throttle.in.cancel_expire = false; + state_status_check(); } else if (ec.value()) interrupt(); @@ -358,28 +363,23 @@ namespace net_utils return; } } - state.socket.wait_read = true; + m_state.socket.wait_read = true; auto on_read = [this, self](const ec_t &ec, size_t bytes_transferred){ - lock_guard_t guard(state.lock); - state.socket.wait_read = false; - if (state.socket.cancel_read) { - state.socket.cancel_read = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + std::lock_guard<std::mutex> guard(m_state.lock); + m_state.socket.wait_read = false; + if (m_state.socket.cancel_read) { + m_state.socket.cancel_read = false; + state_status_check(); } else if (ec.value()) terminate(); else { { - state.stat.in.throttle.handle_trafic_exact(bytes_transferred); - const auto speed = state.stat.in.throttle.get_current_speed(); - context.m_current_speed_down = speed; - context.m_max_speed_down = std::max( - context.m_max_speed_down, + m_state.stat.in.throttle.handle_trafic_exact(bytes_transferred); + const auto speed = m_state.stat.in.throttle.get_current_speed(); + m_conn_context.m_current_speed_down = speed; + m_conn_context.m_max_speed_down = std::max( + m_conn_context.m_max_speed_down, speed ); { @@ -390,24 +390,30 @@ namespace net_utils ).handle_trafic_exact(bytes_transferred); } connection_basic::logger_handle_net_read(bytes_transferred); - context.m_last_recv = time(NULL); - context.m_recv_cnt += bytes_transferred; + m_conn_context.m_last_recv = time(NULL); + m_conn_context.m_recv_cnt += bytes_transferred; start_timer(get_timeout_from_bytes_read(bytes_transferred), true); } - state.socket.handle_read = true; + + // Post handle_recv to a separate `strand_`, distinct from `m_strand` + // which is listening for reads/writes. This avoids a circular dep. + // handle_recv can queue many writes, and `m_strand` will process those + // writes until the connection terminates without deadlocking waiting + // for handle_recv. + m_state.socket.handle_read = true; connection_basic::strand_.post( [this, self, bytes_transferred]{ - bool success = handler.handle_recv( - reinterpret_cast<char *>(state.data.read.buffer.data()), + bool success = m_handler.handle_recv( + reinterpret_cast<char *>(m_state.data.read.buffer.data()), bytes_transferred ); - lock_guard_t guard(state.lock); - state.socket.handle_read = false; - if (state.status == status_t::INTERRUPTED) + std::lock_guard<std::mutex> guard(m_state.lock); + m_state.socket.handle_read = false; + if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); - else if (not success) + else if (!success) interrupt(); else { start_read(); @@ -416,23 +422,23 @@ namespace net_utils ); } }; - if (not state.ssl.enabled) + if (!m_state.ssl.enabled) connection_basic::socket_.next_layer().async_read_some( boost::asio::buffer( - state.data.read.buffer.data(), - state.data.read.buffer.size() + m_state.data.read.buffer.data(), + m_state.data.read.buffer.size() ), - strand.wrap(on_read) + m_strand.wrap(on_read) ); else - strand.post( + m_strand.post( [this, self, on_read]{ connection_basic::socket_.async_read_some( boost::asio::buffer( - state.data.read.buffer.data(), - state.data.read.buffer.size() + m_state.data.read.buffer.data(), + m_state.data.read.buffer.size() ), - strand.wrap(on_read) + m_strand.wrap(on_read) ); } ); @@ -441,13 +447,14 @@ namespace net_utils template<typename T> void connection<T>::start_write() { - if (state.timers.throttle.out.wait_expire || state.socket.wait_write || - state.data.write.queue.empty() || - (state.ssl.enabled && not state.ssl.handshaked) - ) + if (m_state.timers.throttle.out.wait_expire || m_state.socket.wait_write || + m_state.data.write.queue.empty() || + (m_state.ssl.enabled && !m_state.ssl.handshaked) + ) { return; + } auto self = connection<T>::shared_from_this(); - if (connection_type != e_connection_type_RPC) { + if (m_connection_type != e_connection_type_RPC) { auto calc_duration = [this]{ CRITICAL_REGION_LOCAL( network_throttle_manager_t::m_lock_get_global_throttle_out @@ -457,7 +464,7 @@ namespace net_utils std::min( network_throttle_manager_t::get_global_throttle_out( ).get_sleep_time_after_tick( - state.data.write.queue.back().size() + m_state.data.write.queue.back().size() ), 1.0 ) @@ -467,19 +474,14 @@ namespace net_utils const auto duration = calc_duration(); if (duration > duration_t{}) { ec_t ec; - timers.throttle.out.expires_from_now(duration, ec); - state.timers.throttle.out.wait_expire = true; - timers.throttle.out.async_wait([this, self](const ec_t &ec){ - lock_guard_t guard(state.lock); - state.timers.throttle.out.wait_expire = false; - if (state.timers.throttle.out.cancel_expire) { - state.timers.throttle.out.cancel_expire = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + m_timers.throttle.out.expires_from_now(duration, ec); + m_state.timers.throttle.out.wait_expire = true; + m_timers.throttle.out.async_wait([this, self](const ec_t &ec){ + std::lock_guard<std::mutex> guard(m_state.lock); + m_state.timers.throttle.out.wait_expire = false; + if (m_state.timers.throttle.out.cancel_expire) { + m_state.timers.throttle.out.cancel_expire = false; + state_status_check(); } else if (ec.value()) interrupt(); @@ -489,31 +491,26 @@ namespace net_utils } } - state.socket.wait_write = true; + m_state.socket.wait_write = true; auto on_write = [this, self](const ec_t &ec, size_t bytes_transferred){ - lock_guard_t guard(state.lock); - state.socket.wait_write = false; - if (state.socket.cancel_write) { - state.socket.cancel_write = false; - state.data.write.queue.clear(); - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - on_interrupted(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + std::lock_guard<std::mutex> guard(m_state.lock); + m_state.socket.wait_write = false; + if (m_state.socket.cancel_write) { + m_state.socket.cancel_write = false; + m_state.data.write.queue.clear(); + state_status_check(); } else if (ec.value()) { - state.data.write.queue.clear(); + m_state.data.write.queue.clear(); interrupt(); } else { { - state.stat.out.throttle.handle_trafic_exact(bytes_transferred); - const auto speed = state.stat.out.throttle.get_current_speed(); - context.m_current_speed_up = speed; - context.m_max_speed_down = std::max( - context.m_max_speed_down, + m_state.stat.out.throttle.handle_trafic_exact(bytes_transferred); + const auto speed = m_state.stat.out.throttle.get_current_speed(); + m_conn_context.m_current_speed_up = speed; + m_conn_context.m_max_speed_down = std::max( + m_conn_context.m_max_speed_down, speed ); { @@ -524,36 +521,36 @@ namespace net_utils ).handle_trafic_exact(bytes_transferred); } connection_basic::logger_handle_net_write(bytes_transferred); - context.m_last_send = time(NULL); - context.m_send_cnt += bytes_transferred; + m_conn_context.m_last_send = time(NULL); + m_conn_context.m_send_cnt += bytes_transferred; start_timer(get_default_timeout(), true); } - assert(bytes_transferred == state.data.write.queue.back().size()); - state.data.write.queue.pop_back(); - state.condition.notify_all(); + assert(bytes_transferred == m_state.data.write.queue.back().size()); + m_state.data.write.queue.pop_back(); + m_state.condition.notify_all(); start_write(); } }; - if (not state.ssl.enabled) + if (!m_state.ssl.enabled) boost::asio::async_write( connection_basic::socket_.next_layer(), boost::asio::buffer( - state.data.write.queue.back().data(), - state.data.write.queue.back().size() + m_state.data.write.queue.back().data(), + m_state.data.write.queue.back().size() ), - strand.wrap(on_write) + m_strand.wrap(on_write) ); else - strand.post( + m_strand.post( [this, self, on_write]{ boost::asio::async_write( connection_basic::socket_, boost::asio::buffer( - state.data.write.queue.back().data(), - state.data.write.queue.back().size() + m_state.data.write.queue.back().data(), + m_state.data.write.queue.back().size() ), - strand.wrap(on_write) + m_strand.wrap(on_write) ); } ); @@ -562,21 +559,29 @@ namespace net_utils template<typename T> void connection<T>::start_shutdown() { - if (state.socket.wait_shutdown) + if (m_state.socket.wait_shutdown) return; auto self = connection<T>::shared_from_this(); - state.socket.wait_shutdown = true; + m_state.socket.wait_shutdown = true; auto on_shutdown = [this, self](const ec_t &ec){ - lock_guard_t guard(state.lock); - state.socket.wait_shutdown = false; - if (state.socket.cancel_shutdown) { - state.socket.cancel_shutdown = false; - if (state.status == status_t::RUNNING) - interrupt(); - else if (state.status == status_t::INTERRUPTED) - terminate(); - else if (state.status == status_t::TERMINATING) - on_terminating(); + std::lock_guard<std::mutex> guard(m_state.lock); + m_state.socket.wait_shutdown = false; + if (m_state.socket.cancel_shutdown) { + m_state.socket.cancel_shutdown = false; + switch (m_state.status) + { + case status_t::RUNNING: + interrupt(); + break; + case status_t::INTERRUPTED: + terminate(); + break; + case status_t::TERMINATING: + on_terminating(); + break; + default: + break; + } } else if (ec.value()) terminate(); @@ -585,10 +590,10 @@ namespace net_utils on_interrupted(); } }; - strand.post( + m_strand.post( [this, self, on_shutdown]{ connection_basic::socket_.async_shutdown( - strand.wrap(on_shutdown) + m_strand.wrap(on_shutdown) ); } ); @@ -599,24 +604,24 @@ namespace net_utils void connection<T>::cancel_socket() { bool wait_socket = false; - if (state.socket.wait_handshake) - wait_socket = state.socket.cancel_handshake = true; - if (state.timers.throttle.in.wait_expire) { - state.timers.throttle.in.cancel_expire = true; + if (m_state.socket.wait_handshake) + wait_socket = m_state.socket.cancel_handshake = true; + if (m_state.timers.throttle.in.wait_expire) { + m_state.timers.throttle.in.cancel_expire = true; ec_t ec; - timers.throttle.in.cancel(ec); + m_timers.throttle.in.cancel(ec); } - if (state.socket.wait_read) - wait_socket = state.socket.cancel_read = true; - if (state.timers.throttle.out.wait_expire) { - state.timers.throttle.out.cancel_expire = true; + if (m_state.socket.wait_read) + wait_socket = m_state.socket.cancel_read = true; + if (m_state.timers.throttle.out.wait_expire) { + m_state.timers.throttle.out.cancel_expire = true; ec_t ec; - timers.throttle.out.cancel(ec); + m_timers.throttle.out.cancel(ec); } - if (state.socket.wait_write) - wait_socket = state.socket.cancel_write = true; - if (state.socket.wait_shutdown) - wait_socket = state.socket.cancel_shutdown = true; + if (m_state.socket.wait_write) + wait_socket = m_state.socket.cancel_write = true; + if (m_state.socket.wait_shutdown) + wait_socket = m_state.socket.cancel_shutdown = true; if (wait_socket) { ec_t ec; connection_basic::socket_.next_layer().cancel(ec); @@ -626,136 +631,139 @@ namespace net_utils template<typename T> void connection<T>::cancel_handler() { - if (state.protocol.released || state.protocol.wait_release) + if (m_state.protocol.released || m_state.protocol.wait_release) return; - state.protocol.wait_release = true; - state.lock.unlock(); - handler.release_protocol(); - state.lock.lock(); - state.protocol.wait_release = false; - state.protocol.released = true; - if (state.status == status_t::INTERRUPTED) + m_state.protocol.wait_release = true; + m_state.lock.unlock(); + m_handler.release_protocol(); + m_state.lock.lock(); + m_state.protocol.wait_release = false; + m_state.protocol.released = true; + if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); } template<typename T> void connection<T>::interrupt() { - if (state.status != status_t::RUNNING) + if (m_state.status != status_t::RUNNING) return; - state.status = status_t::INTERRUPTED; + m_state.status = status_t::INTERRUPTED; cancel_timer(); cancel_socket(); on_interrupted(); - state.condition.notify_all(); + m_state.condition.notify_all(); cancel_handler(); } template<typename T> void connection<T>::on_interrupted() { - assert(state.status == status_t::INTERRUPTED); - if (state.timers.general.wait_expire) + assert(m_state.status == status_t::INTERRUPTED); + if (m_state.timers.general.wait_expire) return; - if (state.socket.wait_handshake) + if (m_state.socket.wait_handshake) return; - if (state.timers.throttle.in.wait_expire) + if (m_state.timers.throttle.in.wait_expire) return; - if (state.socket.wait_read) + if (m_state.socket.wait_read) return; - if (state.socket.handle_read) + if (m_state.socket.handle_read) return; - if (state.timers.throttle.out.wait_expire) + if (m_state.timers.throttle.out.wait_expire) return; - if (state.socket.wait_write) + if (m_state.socket.wait_write) return; - if (state.socket.wait_shutdown) + if (m_state.socket.wait_shutdown) return; - if (state.protocol.wait_init) + if (m_state.protocol.wait_init) return; - if (state.protocol.wait_callback) + if (m_state.protocol.wait_callback) return; - if (state.protocol.wait_release) + if (m_state.protocol.wait_release) return; - if (state.socket.connected) { - if (not state.ssl.enabled) { + if (m_state.socket.connected) { + if (!m_state.ssl.enabled) { ec_t ec; connection_basic::socket_.next_layer().shutdown( socket_t::shutdown_both, ec ); connection_basic::socket_.next_layer().close(ec); - state.socket.connected = false; - state.status = status_t::WASTED; + m_state.socket.connected = false; + m_state.status = status_t::WASTED; } else start_shutdown(); } else - state.status = status_t::WASTED; + m_state.status = status_t::WASTED; } template<typename T> void connection<T>::terminate() { - if (state.status != status_t::RUNNING && - state.status != status_t::INTERRUPTED + if (m_state.status != status_t::RUNNING && + m_state.status != status_t::INTERRUPTED ) return; - state.status = status_t::TERMINATING; + m_state.status = status_t::TERMINATING; cancel_timer(); cancel_socket(); on_terminating(); - state.condition.notify_all(); + m_state.condition.notify_all(); cancel_handler(); } template<typename T> void connection<T>::on_terminating() { - assert(state.status == status_t::TERMINATING); - if (state.timers.general.wait_expire) + assert(m_state.status == status_t::TERMINATING); + if (m_state.timers.general.wait_expire) return; - if (state.socket.wait_handshake) + if (m_state.socket.wait_handshake) return; - if (state.timers.throttle.in.wait_expire) + if (m_state.timers.throttle.in.wait_expire) return; - if (state.socket.wait_read) + if (m_state.socket.wait_read) return; - if (state.socket.handle_read) + if (m_state.socket.handle_read) return; - if (state.timers.throttle.out.wait_expire) + if (m_state.timers.throttle.out.wait_expire) return; - if (state.socket.wait_write) + if (m_state.socket.wait_write) return; - if (state.socket.wait_shutdown) + if (m_state.socket.wait_shutdown) return; - if (state.protocol.wait_init) + if (m_state.protocol.wait_init) return; - if (state.protocol.wait_callback) + if (m_state.protocol.wait_callback) return; - if (state.protocol.wait_release) + if (m_state.protocol.wait_release) return; - if (state.socket.connected) { + if (m_state.socket.connected) { ec_t ec; connection_basic::socket_.next_layer().shutdown( socket_t::shutdown_both, ec ); connection_basic::socket_.next_layer().close(ec); - state.socket.connected = false; + m_state.socket.connected = false; } - state.status = status_t::WASTED; + m_state.status = status_t::WASTED; } template<typename T> - bool connection<T>::send(byte_slice_t message) + bool connection<T>::send(epee::byte_slice message) { - lock_guard_t guard(state.lock); - if (state.status != status_t::RUNNING || state.socket.wait_handshake) + std::lock_guard<std::mutex> guard(m_state.lock); + if (m_state.status != status_t::RUNNING || m_state.socket.wait_handshake) return false; + + // Wait for the write queue to fall below the max. If it doesn't after a + // randomized delay, drop the connection. auto wait_consume = [this] { auto random_delay = []{ using engine = std::mt19937; @@ -770,62 +778,62 @@ namespace net_utils std::uniform_int_distribution<>(5000, 6000)(rng) ); }; - if (state.data.write.queue.size() <= ABSTRACT_SERVER_SEND_QUE_MAX_COUNT) + if (m_state.data.write.queue.size() <= ABSTRACT_SERVER_SEND_QUE_MAX_COUNT) return true; - state.data.write.wait_consume = true; - bool success = state.condition.wait_for( - state.lock, + m_state.data.write.wait_consume = true; + bool success = m_state.condition.wait_for( + m_state.lock, random_delay(), [this]{ return ( - state.status != status_t::RUNNING || - state.data.write.queue.size() <= + m_state.status != status_t::RUNNING || + m_state.data.write.queue.size() <= ABSTRACT_SERVER_SEND_QUE_MAX_COUNT ); } ); - state.data.write.wait_consume = false; - if (not success) { + m_state.data.write.wait_consume = false; + if (!success) { terminate(); return false; } else - return state.status == status_t::RUNNING; + return m_state.status == status_t::RUNNING; }; auto wait_sender = [this] { - state.condition.wait( - state.lock, + m_state.condition.wait( + m_state.lock, [this] { return ( - state.status != status_t::RUNNING || - not state.data.write.wait_consume + m_state.status != status_t::RUNNING || + !m_state.data.write.wait_consume ); } ); - return state.status == status_t::RUNNING; + return m_state.status == status_t::RUNNING; }; - if (not wait_sender()) + if (!wait_sender()) return false; constexpr size_t CHUNK_SIZE = 32 * 1024; - if (connection_type == e_connection_type_RPC || + if (m_connection_type == e_connection_type_RPC || message.size() <= 2 * CHUNK_SIZE ) { - if (not wait_consume()) + if (!wait_consume()) return false; - state.data.write.queue.emplace_front(std::move(message)); + m_state.data.write.queue.emplace_front(std::move(message)); start_write(); } else { while (!message.empty()) { - if (not wait_consume()) + if (!wait_consume()) return false; - state.data.write.queue.emplace_front( + m_state.data.write.queue.emplace_front( message.take_slice(CHUNK_SIZE) ); start_write(); } } - state.condition.notify_all(); + m_state.condition.notify_all(); return true; } @@ -836,10 +844,10 @@ namespace net_utils boost::optional<network_address> real_remote ) { - unique_lock_t guard(state.lock); - if (state.status != status_t::TERMINATED) + std::unique_lock<std::mutex> guard(m_state.lock); + if (m_state.status != status_t::TERMINATED) return false; - if (not real_remote) { + if (!real_remote) { ec_t ec; auto endpoint = connection_basic::socket_.next_layer().remote_endpoint( ec @@ -866,7 +874,7 @@ namespace net_utils auto *filter = static_cast<shared_state&>( connection_basic::get_state() ).pfilter; - if (filter and not filter->is_remote_host_allowed(*real_remote)) + if (filter && !filter->is_remote_host_allowed(*real_remote)) return false; ec_t ec; #if !defined(_WIN32) || !defined(__i686) @@ -886,39 +894,39 @@ namespace net_utils if (ec.value()) return false; connection_basic::m_is_multithreaded = is_multithreaded; - context.set_details( + m_conn_context.set_details( boost::uuids::random_generator()(), *real_remote, is_income, connection_basic::m_ssl_support == ssl_support_t::e_ssl_support_enabled ); - host = real_remote->host_str(); + m_host = real_remote->host_str(); try { host_count(1); } catch(...) { /* ignore */ } - local = real_remote->is_loopback() || real_remote->is_local(); - state.ssl.enabled = ( + m_local = real_remote->is_loopback() || real_remote->is_local(); + m_state.ssl.enabled = ( connection_basic::m_ssl_support != ssl_support_t::e_ssl_support_disabled ); - state.ssl.forced = ( + m_state.ssl.forced = ( connection_basic::m_ssl_support == ssl_support_t::e_ssl_support_enabled ); - state.socket.connected = true; - state.status = status_t::RUNNING; + m_state.socket.connected = true; + m_state.status = status_t::RUNNING; start_timer( std::chrono::milliseconds( - local ? NEW_CONNECTION_TIMEOUT_LOCAL : NEW_CONNECTION_TIMEOUT_REMOTE + m_local ? NEW_CONNECTION_TIMEOUT_LOCAL : NEW_CONNECTION_TIMEOUT_REMOTE ) ); - state.protocol.wait_init = true; + m_state.protocol.wait_init = true; guard.unlock(); - handler.after_init_connection(); + m_handler.after_init_connection(); guard.lock(); - state.protocol.wait_init = false; - state.protocol.initialized = true; - if (state.status == status_t::INTERRUPTED) + m_state.protocol.wait_init = false; + m_state.protocol.initialized = true; + if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); - else if (not is_income || not state.ssl.enabled) + else if (!is_income || !m_state.ssl.enabled) start_read(); else start_handshake(); @@ -949,23 +957,23 @@ namespace net_utils ssl_support_t ssl_support ): connection_basic(std::move(socket), shared_state, ssl_support), - handler(this, *shared_state, context), - connection_type(connection_type), - io_context{GET_IO_SERVICE(connection_basic::socket_)}, - strand{io_context}, - timers{io_context} + m_handler(this, *shared_state, m_conn_context), + m_connection_type(connection_type), + m_io_context{GET_IO_SERVICE(connection_basic::socket_)}, + m_strand{m_io_context}, + m_timers{m_io_context} { } template<typename T> connection<T>::~connection() noexcept(false) { - lock_guard_t guard(state.lock); - assert(state.status == status_t::TERMINATED || - state.status == status_t::WASTED || - io_context.stopped() + std::lock_guard<std::mutex> guard(m_state.lock); + assert(m_state.status == status_t::TERMINATED || + m_state.status == status_t::WASTED || + m_io_context.stopped() ); - if (state.status != status_t::WASTED) + if (m_state.status != status_t::WASTED) return; try { host_count(-1); } catch (...) { /* ignore */ } } @@ -992,9 +1000,9 @@ namespace net_utils template<typename T> void connection<T>::save_dbg_log() { - lock_guard_t guard(state.lock); - string_t address; - string_t port; + std::lock_guard<std::mutex> guard(m_state.lock); + std::string address; + std::string port; ec_t ec; auto endpoint = connection_basic::socket().remote_endpoint(ec); if (ec.value()) { @@ -1006,10 +1014,10 @@ namespace net_utils port = std::to_string(endpoint.port()); } MDEBUG( - " connection type " << std::to_string(connection_type) << + " connection type " << std::to_string(m_connection_type) << " " << connection_basic::socket().local_endpoint().address().to_string() << ":" << connection_basic::socket().local_endpoint().port() << - " <--> " << context.m_remote_address.str() << + " <--> " << m_conn_context.m_remote_address.str() << " (via " << address << ":" << port << ")" ); } @@ -1017,7 +1025,7 @@ namespace net_utils template<typename T> bool connection<T>::speed_limit_is_enabled() const { - return connection_type != e_connection_type_RPC; + return m_connection_type != e_connection_type_RPC; } template<typename T> @@ -1041,8 +1049,8 @@ namespace net_utils template<typename T> bool connection<T>::close() { - lock_guard_t guard(state.lock); - if (state.status != status_t::RUNNING) + std::lock_guard<std::mutex> guard(m_state.lock); + if (m_state.status != status_t::RUNNING) return false; terminate(); return true; @@ -1052,11 +1060,11 @@ namespace net_utils bool connection<T>::call_run_once_service_io() { if(connection_basic::m_is_multithreaded) { - if (not io_context.poll_one()) + if (!m_io_context.poll_one()) misc_utils::sleep_no_w(1); } else { - if (!io_context.run_one()) + if (!m_io_context.run_one()) return false; } return true; @@ -1065,18 +1073,18 @@ namespace net_utils template<typename T> bool connection<T>::request_callback() { - lock_guard_t guard(state.lock); - if (state.status != status_t::RUNNING) + std::lock_guard<std::mutex> guard(m_state.lock); + if (m_state.status != status_t::RUNNING) return false; auto self = connection<T>::shared_from_this(); - ++state.protocol.wait_callback; + ++m_state.protocol.wait_callback; connection_basic::strand_.post([this, self]{ - handler.handle_qued_callback(); - lock_guard_t guard(state.lock); - --state.protocol.wait_callback; - if (state.status == status_t::INTERRUPTED) + m_handler.handle_qued_callback(); + std::lock_guard<std::mutex> guard(m_state.lock); + --m_state.protocol.wait_callback; + if (m_state.status == status_t::INTERRUPTED) on_interrupted(); - else if (state.status == status_t::TERMINATING) + else if (m_state.status == status_t::TERMINATING) on_terminating(); }); return true; @@ -1085,7 +1093,7 @@ namespace net_utils template<typename T> typename connection<T>::io_context_t &connection<T>::get_io_service() { - return io_context; + return m_io_context; } template<typename T> @@ -1093,9 +1101,9 @@ namespace net_utils { try { auto self = connection<T>::shared_from_this(); - lock_guard_t guard(state.lock); + std::lock_guard<std::mutex> guard(m_state.lock); this->self = std::move(self); - ++state.protocol.reference_counter; + ++m_state.protocol.reference_counter; return true; } catch (boost::bad_weak_ptr &exception) { @@ -1107,8 +1115,8 @@ namespace net_utils bool connection<T>::release() { connection_ptr self; - lock_guard_t guard(state.lock); - if (not --state.protocol.reference_counter) + std::lock_guard<std::mutex> guard(m_state.lock); + if (!(--m_state.protocol.reference_counter)) self = std::move(this->self); return true; } @@ -1116,8 +1124,8 @@ namespace net_utils template<typename T> void connection<T>::setRpcStation() { - lock_guard_t guard(state.lock); - connection_type = e_connection_type_RPC; + std::lock_guard<std::mutex> guard(m_state.lock); + m_connection_type = e_connection_type_RPC; } template<class t_protocol_handler> diff --git a/contrib/epee/src/net_ssl.cpp b/contrib/epee/src/net_ssl.cpp index 7dda65bb5..2d0b7d791 100644 --- a/contrib/epee/src/net_ssl.cpp +++ b/contrib/epee/src/net_ssl.cpp @@ -553,9 +553,6 @@ bool ssl_options_t::handshake( using ec_t = boost::system::error_code; using timer_t = boost::asio::steady_timer; using strand_t = boost::asio::io_service::strand; - using lock_t = std::mutex; - using lock_guard_t = std::lock_guard<lock_t>; - using condition_t = std::condition_variable_any; using socket_t = boost::asio::ip::tcp::socket; auto &io_context = GET_IO_SERVICE(socket); @@ -565,8 +562,8 @@ bool ssl_options_t::handshake( timer_t deadline(io_context, timeout); struct state_t { - lock_t lock; - condition_t condition; + std::mutex lock; + std::condition_variable_any condition; ec_t result; bool wait_timer; bool wait_handshake; @@ -577,10 +574,10 @@ bool ssl_options_t::handshake( state.wait_timer = true; auto on_timer = [&](const ec_t &ec){ - lock_guard_t guard(state.lock); + std::lock_guard<std::mutex> guard(state.lock); state.wait_timer = false; state.condition.notify_all(); - if (not state.cancel_timer) { + if (!state.cancel_timer) { state.cancel_handshake = true; ec_t ec; socket.next_layer().cancel(ec); @@ -589,11 +586,11 @@ bool ssl_options_t::handshake( state.wait_handshake = true; auto on_handshake = [&](const ec_t &ec, size_t bytes_transferred){ - lock_guard_t guard(state.lock); + std::lock_guard<std::mutex> guard(state.lock); state.wait_handshake = false; state.condition.notify_all(); state.result = ec; - if (not state.cancel_handshake) { + if (!state.cancel_handshake) { state.cancel_timer = true; ec_t ec; deadline.cancel(ec); @@ -614,15 +611,15 @@ bool ssl_options_t::handshake( while (!io_context.stopped()) { io_context.poll_one(); - lock_guard_t guard(state.lock); + std::lock_guard<std::mutex> guard(state.lock); state.condition.wait_for( state.lock, std::chrono::milliseconds(30), [&]{ - return not state.wait_timer and not state.wait_handshake; + return !state.wait_timer && !state.wait_handshake; } ); - if (not state.wait_timer and not state.wait_handshake) + if (!state.wait_timer && !state.wait_handshake) break; } if (state.result.value()) { diff --git a/tests/unit_tests/epee_boosted_tcp_server.cpp b/tests/unit_tests/epee_boosted_tcp_server.cpp index d64431edf..c08a86a5e 100644 --- a/tests/unit_tests/epee_boosted_tcp_server.cpp +++ b/tests/unit_tests/epee_boosted_tcp_server.cpp @@ -617,7 +617,7 @@ TEST(boosted_tcp_server, strand_deadlock) void after_init_connection() { unique_lock_t guard(lock); - if (not context.m_is_income) { + if (!context.m_is_income) { guard.unlock(); socket->do_send(byte_slice_t{"."}); } @@ -628,7 +628,7 @@ TEST(boosted_tcp_server, strand_deadlock) bool handle_recv(const char *data, size_t bytes_transferred) { unique_lock_t guard(lock); - if (not context.m_is_income) { + if (!context.m_is_income) { if (context.m_recv_cnt == 1024) { guard.unlock(); socket->do_send(byte_slice_t{"."}); @@ -652,9 +652,9 @@ TEST(boosted_tcp_server, strand_deadlock) void release_protocol() { unique_lock_t guard(lock); - if(not context.m_is_income - and context.m_recv_cnt == 1024 - and context.m_send_cnt == 2 + if(!context.m_is_income + && context.m_recv_cnt == 1024 + && context.m_send_cnt == 2 ) { guard.unlock(); config.notify_success(); |