Malloy
Loading...
Searching...
No Matches
connection.hpp
1#pragma once
2
3#include "types.hpp"
4#include "../error.hpp"
5#include "../type_traits.hpp"
6#include "../utils.hpp"
7#include "../detail/action_queue.hpp"
8#include "../http/request.hpp"
9#include "../websocket/stream.hpp"
10
11#include <boost/asio/io_context.hpp>
12#include <boost/asio/post.hpp>
13#include <boost/beast/core/error.hpp>
14#include <fmt/format.h>
15#include <spdlog/spdlog.h>
16
17#include <concepts>
18#include <functional>
19#include <memory>
20
22{
23
34 template<bool isClient>
35 class connection :
36 public std::enable_shared_from_this<connection<isClient>>
37 {
38 using ws_executor_t = std::invoke_result_t<decltype(&stream::get_executor), stream*>;
40
41 public:
42 using handler_t = std::function<void(const malloy::http::request<>&, const std::shared_ptr<connection>&)>;
43
/**
 * Lifecycle states of a connection.
 *
 * A connection starts out `inactive`. accept() sets `handshaking`, go_active() sets
 * `active`, do_disconnect() sets `closing` and on_close() sets `closed`.
 */
enum class state
{
    handshaking,    // Handshake in progress
    active,         // Handshake done, action queues running
    closing,        // Close operation started
    closed,         // Close completed successfully
    inactive,       // Initial state
};
55
59 virtual
60 ~connection() noexcept
61 {
62 m_logger->trace("destructor()");
63 }
64
72 [[nodiscard]]
73 std::shared_ptr<spdlog::logger>
74 logger() const noexcept
75 {
76 return m_logger;
77 }
78
82 void
83 set_binary(const bool enabled)
84 {
85 m_ws.set_binary(enabled);
86 }
87
91 [[nodiscard]]
92 bool
94 {
95 return m_ws.binary();
96 }
97
104 static
105 std::shared_ptr<connection>
106 make(const std::shared_ptr<spdlog::logger> logger, stream&& ws, const std::string& agent_string)
107 {
108 // We have to emulate make_shared here because the ctor is private
109 connection* me = nullptr;
110 try {
111 me = new connection{logger, std::move(ws), agent_string};
112 return std::shared_ptr<connection>{me};
113 }
114 catch (...) {
115 delete me;
116 throw;
117 }
118 }
119
128 template<concepts::accept_handler Callback>
129 void
130 connect(const boost::asio::ip::tcp::resolver::results_type& target, const std::string& resource, Callback&& done)
131 requires(isClient)
132 {
133 m_logger->trace("connect()");
134
135 if (m_state != state::inactive)
136 throw std::logic_error{"connect() called on already active websocket connection"};
137
138 // Set the timeout for the operation
139 m_ws.get_lowest_layer([&, me = this->shared_from_this(), this, done = std::forward<Callback>(done), resource](auto& sock) mutable {
140 sock.expires_after(std::chrono::seconds(30));
141
142 // Make the connection on the IP address we get from a lookup
143 sock.async_connect(
144 target,
145 [this, me, target, done = std::forward<Callback>(done), resource](auto ec, auto ep) mutable {
146 if (ec) {
147 done(ec);
148 } else {
149 me->on_connect(ec, ep, resource, [this, done = std::forward<Callback>(done)](auto ec) mutable {
150 go_active();
151 std::invoke(std::forward<decltype(done)>(done), ec);
152 });
153 }
154 });
155 });
156 }
157
169 template<class Body, class Fields, std::invocable<> Callback>
170 void
171 accept(const boost::beast::http::request<Body, Fields>& req, Callback&& done)
172 requires(!isClient)
173 {
174 m_logger->trace("accept()");
175
176 if (m_state != state::inactive)
177 throw std::logic_error{"accept() called on already active websocket connection"};
178
179 // Update state
180 m_state = state::handshaking;
181
182 setup_connection();
183
184 // Accept the websocket handshake
185 m_ws.async_accept(req, [this, me = this->shared_from_this(), done = std::forward<decltype(done)>(done)](malloy::error_code ec) mutable {
186 m_logger->trace("on_accept()");
187
188 // Check for errors
189 if (ec) {
190 m_logger->error("on_accept(): {}", ec.message());
191 return;
192 }
193
194 // We're good to go
195 go_active();
196
197 std::invoke(std::forward<decltype(done)>(done));
198 });
199 }
200
210 void
211 disconnect(boost::beast::websocket::close_reason why = boost::beast::websocket::normal)
212 {
213 m_logger->trace("disconnect()");
214
215 if (m_state == state::closed || m_state == state::closing)
216 return;
217
218 auto build_act = [this, why, me = this->shared_from_this()](const auto& on_done) mutable {
219 // Check we haven't been beaten to it
220 if (m_state == state::closed || m_state == state::closing) {
221 on_done();
222 return;
223 }
224
225 do_disconnect(why, on_done);
226 };
227
228 // We queue in both read and write, and whichever gets there first wins
229 m_write_queue.push(build_act);
230 m_read_queue.push(build_act);
231 }
232
240 void
241 force_disconnect(boost::beast::websocket::close_reason why = boost::beast::websocket::normal)
242 {
243 m_logger->trace("force_disconnect()");
244
245 if (m_state == state::inactive)
246 throw std::logic_error{"force_disconnect() called on inactive websocket connection"};
247
248 else if (m_state == state::closed || m_state == state::closing)
249 return; // Already disconnecting
250
251 do_disconnect(why, []{});
252 }
253
266 void
268 {
269 m_logger->trace("read()");
270
271 m_read_queue.push(
272 [
273 this,
274 me = this->shared_from_this(),
275 buff = &buff, // Capturing reference by value copies the object
276 done = std::forward<decltype(done)>(done)
277 ]
278 (const auto& on_done) mutable
279 {
280 assert(buff != nullptr);
281 m_ws.async_read(*buff, [this, me, on_done, done = std::forward<decltype(done)>(done)](auto ec, auto size) mutable {
282 std::invoke(std::forward<decltype(done)>(done), ec, size);
283 on_done();
284 });
285 }
286 );
287 }
288
299 template<concepts::async_read_handler Callback>
300 void
301 send(const concepts::const_buffer_sequence auto& payload, Callback&& done)
302 {
303 m_logger->trace("send(). payload size: {}", payload.size());
304
305 m_write_queue.push([buff = payload, done = std::forward<Callback>(done), this, me = this->shared_from_this()](const auto& on_done) mutable {
306 m_ws.async_write(buff, [this, me, on_done, done = std::forward<decltype(done)>(done)](auto ec, auto size) mutable {
307 on_write(ec, size);
308 std::invoke(std::forward<Callback>(done), ec, size);
309 on_done();
310 });
311 });
312 }
313
314 private:
315 enum class sending_state
316 {
317 idling,
318 sending
319 };
320
321 enum sending_state m_sending_state = sending_state::idling;
322 std::shared_ptr<spdlog::logger> m_logger;
323 stream m_ws;
324 std::string m_agent_string;
325 act_queue_t m_write_queue;
326 act_queue_t m_read_queue;
327 std::atomic<state> m_state{ state::inactive };
328
329 connection(
330 std::shared_ptr<spdlog::logger> logger, stream&& ws, std::string agent_str) :
331 m_logger(std::move(logger)),
332 m_ws{std::move(ws)},
333 m_agent_string{std::move(agent_str)},
334 m_write_queue{boost::asio::make_strand(m_ws.get_executor())},
335 m_read_queue{boost::asio::make_strand(m_ws.get_executor())}
336 {
337 // Sanity check logger
338 if (!m_logger)
339 throw std::invalid_argument("no valid logger provided.");
340 }
341
342 void
343 go_active()
344 {
345 m_logger->trace("go_active()");
346
347 // Update state
348 m_state = state::active;
349
350 // Start/run action queues
351 m_read_queue.run();
352 m_write_queue.run();
353 }
354
355 void
356 setup_connection()
357 {
358 m_logger->trace("setup_connection()");
359
360 // Set suggested timeout settings for the websocket
361 m_ws.set_option(
362 boost::beast::websocket::stream_base::timeout::suggested(
363 isClient ? boost::beast::role_type::client : boost::beast::role_type::server)
364 );
365
366 // Set agent string/field
367 const auto agent_field = isClient ? malloy::http::field::user_agent : malloy::http::field::server;
368 m_ws.set_option(
369 boost::beast::websocket::stream_base::decorator(
370 [this, agent_field](boost::beast::websocket::request_type& req) {
371 req.set(agent_field, m_agent_string);
372 }
373 )
374 );
375 }
376
377 void
378 do_disconnect(boost::beast::websocket::close_reason why, const std::invocable<> auto& on_done)
379 {
380 m_logger->trace("do_disconnect()");
381
382 // Update state
383 m_state = state::closing;
384
385 m_ws.async_close(why, [me = this->shared_from_this(), this, on_done](auto ec) {
386 if (ec)
387 m_logger->error("could not close websocket: '{}'", ec.message()); // TODO: See #40
388 else
389 on_close();
390
391 on_done();
392 });
393 }
394
395 void
396 on_connect(
397 boost::beast::error_code ec,
398 boost::asio::ip::tcp::resolver::results_type::endpoint_type ep,
399 const std::string& resource,
400 concepts::accept_handler auto&& on_handshake)
401 {
402 m_logger->trace("on_connect()");
403
404 if (ec) {
405 m_logger->error("on_connect(): {}", ec.message());
406 return;
407 }
408
409 m_ws.get_lowest_layer([](auto& s) { s.expires_never(); }); // websocket has its own timeout system that conflicts
410
411 // Update the m_host string. This will provide the value of the
412 // Host HTTP header during the WebSocket handshake.
413 // See https://tools.ietf.org/html/rfc7230#section-5.4
414 const std::string host = fmt::format("{}:{}", ep.address().to_string(), ep.port());
415
416#if MALLOY_FEATURE_TLS
417 if constexpr (isClient) {
418 if (m_ws.is_tls()) {
419 // TODO: Should this be a separate method?
420 m_ws.async_handshake_tls(
421 boost::asio::ssl::stream_base::handshake_type::client,
422 [on_handshake = std::forward<decltype(on_handshake)>(on_handshake), resource, host, me = this->shared_from_this()](auto ec) mutable
423 {
424 if (ec)
425 on_handshake(ec);
426
427 me->on_ready_for_handshake(host, resource, std::forward<decltype(on_handshake)>(on_handshake));
428 }
429 );
430 return;
431 }
432 }
433#endif
434 on_ready_for_handshake(host, resource, std::forward<decltype(on_handshake)>(on_handshake));
435 }
436
437 void
438 on_ready_for_handshake(const std::string& host, const std::string& resource, concepts::accept_handler auto&& on_handshake)
439 {
440 m_logger->trace("on_ready_for_handshake()");
441
442 // Turn off the timeout on the tcp_stream, because
443 // the websocket stream has its own timeout system.
444 m_ws.get_lowest_layer([](auto& s) { s.expires_never(); });
445 setup_connection();
446
447 // Perform the websocket handshake
448 m_ws.async_handshake(
449 host,
450 resource,
451 std::forward<decltype(on_handshake)>(on_handshake)
452 );
453 }
454
455 void
456 on_write(auto ec, auto size)
457 {
458 m_logger->trace("on_write() wrote: '{}' bytes", size);
459
460 if (ec) {
461 m_logger->error("on_write failed for websocket connection: '{}'", ec.message());
462 return;
463 }
464 }
465
466 void
467 on_close()
468 {
469 m_logger->trace("on_close()");
470
471 m_state = state::closed;
472 }
473 };
474
475} // namespace malloy::websocket
void push(act_t act)
Add an action to the queue.
Definition: action_queue.hpp:48
Definition: request.hpp:19
Represents a connection via the WebSocket protocol.
Definition: connection.hpp:37
static std::shared_ptr< connection > make(const std::shared_ptr< spdlog::logger > logger, stream &&ws, const std::string &agent_string)
Construct a new connection object.
Definition: connection.hpp:106
void send(const concepts::const_buffer_sequence auto &payload, Callback &&done)
Send the contents of a buffer to the remote peer.
Definition: connection.hpp:301
state
Definition: connection.hpp:48
void force_disconnect(boost::beast::websocket::close_reason why=boost::beast::websocket::normal)
Same as disconnect, but bypasses all queues and runs immediately.
Definition: connection.hpp:241
void set_binary(const bool enabled)
Definition: connection.hpp:83
void read(concepts::dynamic_buffer auto &buff, concepts::async_read_handler auto &&done)
Read a complete message into a buffer.
Definition: connection.hpp:267
bool binary()
Definition: connection.hpp:93
std::shared_ptr< spdlog::logger > logger() const noexcept
Definition: connection.hpp:74
void accept(const boost::beast::http::request< Body, Fields > &req, Callback &&done)
Accept an incoming connection.
Definition: connection.hpp:171
void connect(const boost::asio::ip::tcp::resolver::results_type &target, const std::string &resource, Callback &&done)
Connect to a remote (websocket) endpoint.
Definition: connection.hpp:130
virtual ~connection() noexcept
Definition: connection.hpp:60
void disconnect(boost::beast::websocket::close_reason why=boost::beast::websocket::normal)
Disconnect/stop/close the connection.
Definition: connection.hpp:211
Websocket stream. May use TLS.
Definition: stream.hpp:50
bool binary() const
Checks whether outgoing messages will be indicated as text or binary.
Definition: stream.hpp:179
void get_lowest_layer(Func &&visitor)
Invoke a visitor on the lowest layer of the wrapped stream.
Definition: stream.hpp:200
void set_binary(const bool enabled)
Controls whether outgoing messages will be indicated as text or binary.
Definition: stream.hpp:162
auto get_executor()
Get executor of the underlying stream.
Definition: stream.hpp:211
Definition: type_traits.hpp:44
Definition: type_traits.hpp:35
Definition: type_traits.hpp:41
Definition: connection.hpp:22
boost::beast::error_code error_code
Error code used to signify errors without throwing. It evaluates to true when it holds an error.
Definition: error.hpp:9