Add adding optional extra http header with Tunnel::Connect()
authorKrzysztof Malysa <k.malysa@samsung.com>
Tue, 19 Mar 2024 13:28:57 +0000 (14:28 +0100)
committerKrzysztof Jackiewicz <k.jackiewicz@samsung.com>
Fri, 12 Apr 2024 10:05:06 +0000 (12:05 +0200)
Change-Id: I928dbb4bb92dc55eda067e51bf2a3ac88dcb4625

srcs/tunnel.cpp
srcs/tunnel.h
srcs/websockets.cpp
srcs/websockets.h
tests/encrypted_tunnel_tests.cpp
tests/handshake_tests.cpp
tests/tunnel/auto_tests.cpp
tests/tunnel/auto_tests.h

index 0c28a4a403e7ca2878c5fb4b9dcc40715e3bc540..289ef07b5edc2bb66dbd81a08ff95350a4c97fc8 100644 (file)
@@ -105,7 +105,7 @@ Tunnel::~Tunnel()
     }
 }
 
-void Tunnel::Connect(const std::string &url)
+void Tunnel::Connect(const std::string &url, std::optional<ExtraHttpHeader> extraHttpHeader)
 {
     LogDebug("Connecting to " << url);
 
@@ -124,6 +124,7 @@ void Tunnel::Connect(const std::string &url)
             THROW_UNKNOWN("Creating libwebsocket context failed");
     }
 
+    m_extraHttpHeader = std::move(extraHttpHeader);
     m_connection = m_ws->ClientConnect(m_context, url);
     if (!m_connection)
         DisconnectOnError();
@@ -241,6 +242,30 @@ bool Tunnel::HandleEvent(Lws *wsi, enum lws_callback_reasons reason, void *in, s
             return true;
         }
 
+        case LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER:
+            if (m_state != State::DISCONNECTED || in == nullptr) {
+                LogError("Unexpected event");
+                m_state = State::FAILED;
+                return true;
+            }
+
+            if (m_extraHttpHeader) {
+                auto p = static_cast<unsigned char **>(in);
+                auto end = *p + len;
+                if (!m_ws->AddHttpHeaderByName(
+                        wsi,
+                        reinterpret_cast<const unsigned char *>(m_extraHttpHeader->name.c_str()),
+                        reinterpret_cast<const unsigned char *>(m_extraHttpHeader->value.c_str()),
+                        m_extraHttpHeader->value.size(),
+                        p,
+                        end)) {
+                    LogError("Failed to add extra HTTP header");
+                    m_state = State::FAILED;
+                    return true;
+                }
+            }
+            break;
+
         case LWS_CALLBACK_CLIENT_ESTABLISHED:
             if (m_state != State::DISCONNECTED) {
                 LogError("Unexpected event");
index a709ebb38c313cb6e4f5df79ed8815bdb09e1ea6..a585bfe658169d006640b5e3a286262e25b08c56 100644 (file)
 #include <libwebsockets.h>
 #include <memory>
 #include <mutex>
+#include <optional>
 #include <string>
 #include <vector>
 
+struct ExtraHttpHeader {
+    std::string name;
+    std::string value;
+};
+
 class ITunnel {
 protected:
     ITunnel() = default;
@@ -33,7 +39,8 @@ protected:
 public:
     virtual ~ITunnel() = default;
 
-    virtual void Connect(const std::string &url) = 0;
+    virtual void Connect(const std::string &url,
+                         std::optional<ExtraHttpHeader> extraHttpHeader = std::nullopt) = 0;
     virtual void WriteBinary(const std::vector<uint8_t> &msg) = 0;
     virtual std::vector<uint8_t> ReadBinary() = 0;
     virtual void Disconnect() = 0;
@@ -48,7 +55,8 @@ public:
     Tunnel(const Tunnel &) = delete;
     Tunnel &operator=(const Tunnel &) = delete;
 
-    void Connect(const std::string &url) override;
+    void Connect(const std::string &url,
+                 std::optional<ExtraHttpHeader> extraHttpHeader = std::nullopt) override;
     void WriteBinary(const std::vector<uint8_t> &msg) override;
     std::vector<uint8_t> ReadBinary() override;
     void Disconnect() override;
@@ -76,6 +84,7 @@ protected:
     std::shared_ptr<IWebsockets> m_ws;
     LwsContext *m_context;
     Lws *m_connection;
+    std::optional<ExtraHttpHeader> m_extraHttpHeader;
     std::vector<uint8_t> m_in;
     std::vector<uint8_t> m_out;
     std::deque<std::vector<uint8_t>> m_recvMsgs;
index 9ffffe241afff735de7eb4e8b3eb593fd95c2a9c..2d5445aafe8dc1b8545e393aec58a8f9f1a6d6d7 100644 (file)
@@ -139,6 +139,16 @@ bool Websockets::Service(LwsContext *context) noexcept
     return lws_service(Context2Native(context), 0) != 0;
 }
 
+bool Websockets::AddHttpHeaderByName(Lws *wsi,
+                                     const unsigned char *name,
+                                     const unsigned char *value,
+                                     int length,
+                                     unsigned char **p,
+                                     unsigned char *end) noexcept
+{
+    return lws_add_http_header_by_name(Lws2Native(wsi), name, value, length, p, end) == 0;
+}
+
 bool Websockets::CallbackOnWritable(Lws *wsi) noexcept
 {
     return lws_callback_on_writable(Lws2Native(wsi)) != 0;
index 8535b11af506b3efba48774c12a437c881606da1..6a5ed326a073b579ae5222f79620872b70160d1a 100644 (file)
@@ -81,6 +81,25 @@ public:
      */
     virtual bool Service(LwsContext *context) noexcept = 0;
 
+    /*
+     *
+     * @brief Add custom HTTP header to the handshake.
+     *
+     * @param[in] name  Name of the HTTP header.
+     * @param[in] value  Value of the HTTP header.
+     * @param[in] length  Length of the value of the HTTP header.
+     * @param[in,out] p  Pointer to current position in the buffer.
+     * @param[in] end  Pointer to the end of the buffer.
+     *
+     * @return true if it succeeds, false otherwise
+     */
+    virtual bool AddHttpHeaderByName(Lws *wsi,
+                                     const unsigned char *name,
+                                     const unsigned char *value,
+                                     int length,
+                                     unsigned char **p,
+                                     unsigned char *end) = 0;
+
     /*
      * @brief Asks to trigger LWS_CALLBACK_CLIENT_WRITEABLE event as soon as the connection is
      *        writeable.
@@ -161,6 +180,12 @@ public:
     LwsContext *CreateContext() noexcept override;
     Lws *ClientConnect(LwsContext *context, const std::string &url) noexcept override;
     bool Service(LwsContext *context) noexcept override;
+    bool AddHttpHeaderByName(Lws *wsi,
+                             const unsigned char *name,
+                             const unsigned char *value,
+                             int length,
+                             unsigned char **p,
+                             unsigned char *end) noexcept override;
     bool CallbackOnWritable(Lws *wsi) noexcept override;
     void CancelService(LwsContext *context) noexcept override;
     bool FrameIsBinary(Lws *wsi) const noexcept override;
index d3e0517308eee99b25b41e0bb0c84f6b2c80c123..0fbb7e9df0af9207fd97a4f82d3721357499d2fe 100644 (file)
@@ -27,7 +27,7 @@ namespace {
 
 class MTunnel : public ITunnel {
 public:
-    void Connect(const std::string & /*url*/) override
+    void Connect(const std::string &, std::optional<ExtraHttpHeader>) override
     {
         throw std::runtime_error{"Connect() should not be called"};
     }
@@ -144,7 +144,7 @@ namespace {
 
 class OTMEchoTunnel : public ITunnel {
 public:
-    void Connect(const std::string & /*url*/) override
+    void Connect(const std::string &, std::optional<ExtraHttpHeader>) override
     {
         throw std::runtime_error{"Connect() should not be called"};
     }
index aca3048c3b4aa49a93e1405cb648fd8b24b440f0..c65967956a631b89bd75e463dcacee22d6ea0d3d 100644 (file)
@@ -54,7 +54,7 @@ public:
     {
     }
 
-    void Connect(const std::string &url) override
+    void Connect(const std::string &url, std::optional<ExtraHttpHeader>) override
     {
         auto unpackedBleAdvert = UnpackDecryptedAdvert(m_decryptedBleAdvert);
         auto tunnelId = DeriveKey(m_qrSecret, {}, KeyPurpose::TunnelID, TUNNEL_ID_LEN);
@@ -190,9 +190,10 @@ class OTMTunnel : public MTunnel {
 public:
     using MTunnel::MTunnel;
 
-    void Connect(const std::string &url) override
+    void Connect(const std::string &url, std::optional<ExtraHttpHeader> extraHttpHeader) override
     {
-        m_cancelFacilitator.WithCancelCheck([&] { MTunnel::Connect(url); });
+        m_cancelFacilitator.WithCancelCheck(
+            [&] { MTunnel::Connect(url, std::move(extraHttpHeader)); });
     }
 
     void WriteBinary(const std::vector<uint8_t> &msg) override
index 2b383fbedd870d5d876d95632599cdac51e2c59d..b73219fd945bf8ac34ffd62f4b49e3386b8d8766 100644 (file)
@@ -145,6 +145,19 @@ public:
         return res;
     }
 
+    bool AddHttpHeaderByName(Lws *lws,
+                             const unsigned char *name,
+                             const unsigned char *value,
+                             int length,
+                             unsigned char **p,
+                             unsigned char *end) noexcept override
+    {
+        RandomDelay();
+        auto res = MockedSockets::AddHttpHeaderByName(lws, name, value, length, p, end);
+        RandomDelay();
+        return res;
+    }
+
     bool CallbackOnWritable(Lws *lws) noexcept override
     {
         RandomDelay();
@@ -236,7 +249,8 @@ Lws *MockedSockets::ClientConnect(LwsContext *lwsContext, const std::string &url
         // schedule events
         mockedLws->PushEvent(LWS_CALLBACK_WSI_CREATE);
         mockedLws->PushEvent(LWS_CALLBACK_OPENSSL_PERFORM_SERVER_CERT_VERIFICATION);
-        mockedLws->PushEvent(LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER);
+        mockedLws->PushEvent(LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER,
+                             std::vector<uint8_t>(sizeof(char **), '\0'));
         mockedLws->PushEvent(LWS_CALLBACK_ESTABLISHED_CLIENT_HTTP);
         mockedLws->PushEvent(LWS_CALLBACK_CLIENT_FILTER_PRE_ESTABLISH);
         mockedLws->PushEvent(LWS_CALLBACK_CLIENT_ESTABLISHED);
@@ -278,6 +292,16 @@ bool MockedSockets::Service(LwsContext *lwsContext) noexcept
     }
 }
 
+bool MockedSockets::AddHttpHeaderByName(Lws *,
+                                        const unsigned char *,
+                                        const unsigned char *,
+                                        int,
+                                        unsigned char **,
+                                        unsigned char *) noexcept
+{
+    return true;
+}
+
 bool MockedSockets::CallbackOnWritable(Lws *lws) noexcept
 {
     auto mockedLws = Lws2Mocked(lws);
index f240b97c91562373030d9a03ba7a6b58a5f94067..bb1f7500d46963379b5a0bd197242fd7ddd24ee6 100644 (file)
@@ -33,6 +33,12 @@ protected:
     LwsContext *CreateContext() noexcept override;
     Lws *ClientConnect(LwsContext *lwsContext, const std::string &url) noexcept override;
     bool Service(LwsContext *lwsContext) noexcept override;
+    bool AddHttpHeaderByName(Lws *lws,
+                             const unsigned char *name,
+                             const unsigned char *value,
+                             int length,
+                             unsigned char **p,
+                             unsigned char *end) noexcept override;
     bool CallbackOnWritable(Lws *lws) noexcept override;
 
     // only this method can be called from a different thread