diff options
Diffstat (limited to 'src/device_trezor/device_trezor_base.hpp')
-rw-r--r-- | src/device_trezor/device_trezor_base.hpp | 112 |
1 files changed, 37 insertions, 75 deletions
diff --git a/src/device_trezor/device_trezor_base.hpp b/src/device_trezor/device_trezor_base.hpp index 644a49332..88d419494 100644 --- a/src/device_trezor/device_trezor_base.hpp +++ b/src/device_trezor/device_trezor_base.hpp @@ -54,7 +54,7 @@ namespace hw { namespace trezor { -#if WITH_DEVICE_TREZOR +#ifdef WITH_DEVICE_TREZOR class device_trezor_base; /** @@ -69,41 +69,6 @@ namespace trezor { }; /** - * Default Trezor protocol client callback - */ - class trezor_protocol_callback { - protected: - device_trezor_base & device; - - public: - explicit trezor_protocol_callback(device_trezor_base & device): device(device) {} - - std::shared_ptr<google::protobuf::Message> on_button_request(const messages::common::ButtonRequest * msg); - std::shared_ptr<google::protobuf::Message> on_pin_matrix_request(const messages::common::PinMatrixRequest * msg); - std::shared_ptr<google::protobuf::Message> on_passphrase_request(const messages::common::PassphraseRequest * msg); - std::shared_ptr<google::protobuf::Message> on_passphrase_state_request(const messages::common::PassphraseStateRequest * msg); - - std::shared_ptr<google::protobuf::Message> on_message(const google::protobuf::Message * msg, messages::MessageType message_type){ - MDEBUG("on_general_message"); - return on_message_dispatch(msg, message_type); - } - - std::shared_ptr<google::protobuf::Message> on_message_dispatch(const google::protobuf::Message * msg, messages::MessageType message_type){ - if (message_type == messages::MessageType_ButtonRequest){ - return on_button_request(dynamic_cast<const messages::common::ButtonRequest*>(msg)); - } else if (message_type == messages::MessageType_PassphraseRequest) { - return on_passphrase_request(dynamic_cast<const messages::common::PassphraseRequest*>(msg)); - } else if (message_type == messages::MessageType_PassphraseStateRequest) { - return on_passphrase_state_request(dynamic_cast<const messages::common::PassphraseStateRequest*>(msg)); - } else if (message_type == messages::MessageType_PinMatrixRequest) { - return on_pin_matrix_request(dynamic_cast<const messages::common::PinMatrixRequest*>(msg)); - } else { - return nullptr; - } - } - }; - - /** * TREZOR device template with basic functions */ class device_trezor_base : public hw::core::device_default { @@ -114,7 +79,6 @@ namespace trezor { mutable boost::mutex command_locker; std::shared_ptr<Transport> m_transport; - std::shared_ptr<trezor_protocol_callback> m_protocol_callback; std::shared_ptr<trezor_callback> m_callback; std::string full_name; @@ -129,6 +93,15 @@ namespace trezor { void call_ping_unsafe(); void test_ping(); + // Communication methods + + void write_raw(const google::protobuf::Message * msg); + GenericMessage read_raw(); + GenericMessage call_raw(const google::protobuf::Message * msg); + + // Trezor message protocol handler. Handles specific signalling messages. + bool message_handler(GenericMessage & input); + /** * Client communication wrapper, handles specific Trezor protocol. * @@ -141,8 +114,7 @@ namespace trezor { const boost::optional<messages::MessageType> & resp_type = boost::none, const boost::optional<std::vector<messages::MessageType>> & resp_types = boost::none, const boost::optional<messages::MessageType*> & resp_type_ptr = boost::none, - bool open_session = false, - unsigned depth=0) + bool open_session = false) { // Require strictly protocol buffers response in the template. BOOST_STATIC_ASSERT(boost::is_base_of<google::protobuf::Message, t_message>::value); @@ -151,8 +123,12 @@ namespace trezor { throw std::invalid_argument("Cannot specify list of accepted types and not using generic response"); } + // Determine type of expected message response + const messages::MessageType required_type = accepting_base ? messages::MessageType_Success : + (resp_type ? resp_type.get() : MessageMapper::get_message_wire_number<t_message>()); + // Open session if required - if (open_session && depth == 0){ + if (open_session){ try { m_transport->open(); } catch (const std::exception& e) { @@ -162,47 +138,37 @@ namespace trezor { // Scoped session closer BOOST_SCOPE_EXIT_ALL(&, this) { - if (open_session && depth == 0){ + if (open_session){ this->getTransport()->close(); } }; - // Write the request + // Write/read the request CHECK_AND_ASSERT_THROW_MES(req, "Request is null"); - this->getTransport()->write(*req); + auto msg_resp = call_raw(req.get()); - // Read the response - std::shared_ptr<google::protobuf::Message> msg_resp; - hw::trezor::messages::MessageType msg_resp_type; + bool processed = false; + do { + processed = message_handler(msg_resp); + } while(processed); - // We may have several roundtrips with the handler - this->getTransport()->read(msg_resp, &msg_resp_type); + // Response section if (resp_type_ptr){ - *(resp_type_ptr.get()) = msg_resp_type; + *(resp_type_ptr.get()) = msg_resp.m_type; } - // Determine type of expected message response - messages::MessageType required_type = accepting_base ? messages::MessageType_Success : - (resp_type ? resp_type.get() : MessageMapper::get_message_wire_number<t_message>()); + if (msg_resp.m_type == messages::MessageType_Failure) { + throw_failure_exception(dynamic_cast<messages::common::Failure *>(msg_resp.m_msg.get())); - if (msg_resp_type == messages::MessageType_Failure) { - throw_failure_exception(dynamic_cast<messages::common::Failure *>(msg_resp.get())); + } else if (!accepting_base && msg_resp.m_type == required_type) { + return message_ptr_retype<t_message>(msg_resp.m_msg); - } else if (!accepting_base && msg_resp_type == required_type) { - return message_ptr_retype<t_message>(msg_resp); + } else if (accepting_base && (!resp_types || + std::find(resp_types.get().begin(), resp_types.get().end(), msg_resp.m_type) != resp_types.get().end())) { + return message_ptr_retype<t_message>(msg_resp.m_msg); } else { - auto resp = this->getProtocolCallback()->on_message(msg_resp.get(), msg_resp_type); - if (resp) { - return this->client_exchange<t_message>(resp, boost::none, resp_types, resp_type_ptr, false, depth + 1); - - } else if (accepting_base && (!resp_types || - std::find(resp_types.get().begin(), resp_types.get().end(), msg_resp_type) != resp_types.get().end())) { - return message_ptr_retype<t_message>(msg_resp); - - } else { - throw exc::UnexpectedMessageException(msg_resp_type, msg_resp); - } + throw exc::UnexpectedMessageException(msg_resp.m_type, msg_resp.m_msg); } } @@ -252,10 +218,6 @@ namespace trezor { return m_transport; } - std::shared_ptr<trezor_protocol_callback> getProtocolCallback(){ - return m_protocol_callback; - } - std::shared_ptr<trezor_callback> getCallback(){ return m_callback; } @@ -288,10 +250,10 @@ namespace trezor { bool ping(); // Protocol callbacks - void on_button_request(); - void on_pin_request(epee::wipeable_string & pin); - void on_passphrase_request(bool on_device, epee::wipeable_string & passphrase); - void on_passphrase_state_request(const std::string & state); + void on_button_request(GenericMessage & resp, const messages::common::ButtonRequest * msg); + void on_pin_request(GenericMessage & resp, const messages::common::PinMatrixRequest * msg); + void on_passphrase_request(GenericMessage & resp, const messages::common::PassphraseRequest * msg); + void on_passphrase_state_request(GenericMessage & resp, const messages::common::PassphraseStateRequest * msg); }; #endif |