}
}
-void Tunnel::Connect(const std::string &url)
+void Tunnel::Connect(const std::string &url, std::optional<ExtraHttpHeader> extraHttpHeader)
{
LogDebug("Connecting to " << url);
THROW_UNKNOWN("Creating libwebsocket context failed");
}
+ m_extraHttpHeader = std::move(extraHttpHeader);
m_connection = m_ws->ClientConnect(m_context, url);
if (!m_connection)
DisconnectOnError();
return true;
}
+ case LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER:
+ if (m_state != State::DISCONNECTED || in == nullptr) {
+ LogError("Unexpected event");
+ m_state = State::FAILED;
+ return true;
+ }
+
+ if (m_extraHttpHeader) {
+ auto p = static_cast<unsigned char **>(in);
+ auto end = *p + len;
+ if (!m_ws->AddHttpHeaderByName(
+ wsi,
+ reinterpret_cast<const unsigned char *>(m_extraHttpHeader->name.c_str()),
+ reinterpret_cast<const unsigned char *>(m_extraHttpHeader->value.c_str()),
+ m_extraHttpHeader->value.size(),
+ p,
+ end)) {
+ LogError("Failed to add extra HTTP header");
+ m_state = State::FAILED;
+ return true;
+ }
+ }
+ break;
+
case LWS_CALLBACK_CLIENT_ESTABLISHED:
if (m_state != State::DISCONNECTED) {
LogError("Unexpected event");
#include <libwebsockets.h>
#include <memory>
#include <mutex>
+#include <optional>
#include <string>
#include <vector>
+struct ExtraHttpHeader {
+ std::string name;
+ std::string value;
+};
+
class ITunnel {
protected:
ITunnel() = default;
public:
virtual ~ITunnel() = default;
- virtual void Connect(const std::string &url) = 0;
+ virtual void Connect(const std::string &url,
+ std::optional<ExtraHttpHeader> extraHttpHeader = std::nullopt) = 0;
virtual void WriteBinary(const std::vector<uint8_t> &msg) = 0;
virtual std::vector<uint8_t> ReadBinary() = 0;
virtual void Disconnect() = 0;
Tunnel(const Tunnel &) = delete;
Tunnel &operator=(const Tunnel &) = delete;
- void Connect(const std::string &url) override;
+ void Connect(const std::string &url,
+ std::optional<ExtraHttpHeader> extraHttpHeader = std::nullopt) override;
void WriteBinary(const std::vector<uint8_t> &msg) override;
std::vector<uint8_t> ReadBinary() override;
void Disconnect() override;
std::shared_ptr<IWebsockets> m_ws;
LwsContext *m_context;
Lws *m_connection;
+ std::optional<ExtraHttpHeader> m_extraHttpHeader;
std::vector<uint8_t> m_in;
std::vector<uint8_t> m_out;
std::deque<std::vector<uint8_t>> m_recvMsgs;
return lws_service(Context2Native(context), 0) != 0;
}
+bool Websockets::AddHttpHeaderByName(Lws *wsi,
+ const unsigned char *name,
+ const unsigned char *value,
+ int length,
+ unsigned char **p,
+ unsigned char *end) noexcept
+{
+ return lws_add_http_header_by_name(Lws2Native(wsi), name, value, length, p, end) == 0;
+}
+
bool Websockets::CallbackOnWritable(Lws *wsi) noexcept
{
return lws_callback_on_writable(Lws2Native(wsi)) != 0;
*/
virtual bool Service(LwsContext *context) noexcept = 0;
+ /*
+ *
+ * @brief Add custom HTTP header to the handshake.
+ *
+ * @param[in] name Name of the HTTP header.
+ * @param[in] value Value of the HTTP header.
+ * @param[in] length Length of the value of the HTTP header.
+ * @param[in,out] p Pointer to current position in the buffer.
+ * @param[in] end Pointer to the end of the buffer.
+ *
+ * @return true if it succeeds, false otherwise
+ */
+ virtual bool AddHttpHeaderByName(Lws *wsi,
+ const unsigned char *name,
+ const unsigned char *value,
+ int length,
+ unsigned char **p,
+ unsigned char *end) = 0;
+
/*
* @brief Asks to trigger LWS_CALLBACK_CLIENT_WRITEABLE event as soon as the connection is
* writeable.
LwsContext *CreateContext() noexcept override;
Lws *ClientConnect(LwsContext *context, const std::string &url) noexcept override;
bool Service(LwsContext *context) noexcept override;
+ bool AddHttpHeaderByName(Lws *wsi,
+ const unsigned char *name,
+ const unsigned char *value,
+ int length,
+ unsigned char **p,
+ unsigned char *end) noexcept override;
bool CallbackOnWritable(Lws *wsi) noexcept override;
void CancelService(LwsContext *context) noexcept override;
bool FrameIsBinary(Lws *wsi) const noexcept override;
class MTunnel : public ITunnel {
public:
- void Connect(const std::string & /*url*/) override
+ void Connect(const std::string &, std::optional<ExtraHttpHeader>) override
{
throw std::runtime_error{"Connect() should not be called"};
}
class OTMEchoTunnel : public ITunnel {
public:
- void Connect(const std::string & /*url*/) override
+ void Connect(const std::string &, std::optional<ExtraHttpHeader>) override
{
throw std::runtime_error{"Connect() should not be called"};
}
{
}
- void Connect(const std::string &url) override
+ void Connect(const std::string &url, std::optional<ExtraHttpHeader>) override
{
auto unpackedBleAdvert = UnpackDecryptedAdvert(m_decryptedBleAdvert);
auto tunnelId = DeriveKey(m_qrSecret, {}, KeyPurpose::TunnelID, TUNNEL_ID_LEN);
public:
using MTunnel::MTunnel;
- void Connect(const std::string &url) override
+ void Connect(const std::string &url, std::optional<ExtraHttpHeader> extraHttpHeader) override
{
- m_cancelFacilitator.WithCancelCheck([&] { MTunnel::Connect(url); });
+ m_cancelFacilitator.WithCancelCheck(
+ [&] { MTunnel::Connect(url, std::move(extraHttpHeader)); });
}
void WriteBinary(const std::vector<uint8_t> &msg) override
return res;
}
+ bool AddHttpHeaderByName(Lws *lws,
+ const unsigned char *name,
+ const unsigned char *value,
+ int length,
+ unsigned char **p,
+ unsigned char *end) noexcept override
+ {
+ RandomDelay();
+ auto res = MockedSockets::AddHttpHeaderByName(lws, name, value, length, p, end);
+ RandomDelay();
+ return res;
+ }
+
bool CallbackOnWritable(Lws *lws) noexcept override
{
RandomDelay();
// schedule events
mockedLws->PushEvent(LWS_CALLBACK_WSI_CREATE);
mockedLws->PushEvent(LWS_CALLBACK_OPENSSL_PERFORM_SERVER_CERT_VERIFICATION);
- mockedLws->PushEvent(LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER);
+ mockedLws->PushEvent(LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER,
+ std::vector<uint8_t>(sizeof(char **), '\0'));
mockedLws->PushEvent(LWS_CALLBACK_ESTABLISHED_CLIENT_HTTP);
mockedLws->PushEvent(LWS_CALLBACK_CLIENT_FILTER_PRE_ESTABLISH);
mockedLws->PushEvent(LWS_CALLBACK_CLIENT_ESTABLISHED);
}
}
+bool MockedSockets::AddHttpHeaderByName(Lws *,
+ const unsigned char *,
+ const unsigned char *,
+ int,
+ unsigned char **,
+ unsigned char *) noexcept
+{
+ return true;
+}
+
bool MockedSockets::CallbackOnWritable(Lws *lws) noexcept
{
auto mockedLws = Lws2Mocked(lws);
LwsContext *CreateContext() noexcept override;
Lws *ClientConnect(LwsContext *lwsContext, const std::string &url) noexcept override;
bool Service(LwsContext *lwsContext) noexcept override;
+ bool AddHttpHeaderByName(Lws *lws,
+ const unsigned char *name,
+ const unsigned char *value,
+ int length,
+ unsigned char **p,
+ unsigned char *end) noexcept override;
bool CallbackOnWritable(Lws *lws) noexcept override;
// only this method can be called from a different thread