Add user configurable mask generation
[contrib/qtwebsockets.git] / src / websockets / qwebsocket_p.cpp
index 5f59b54..bff060f 100644 (file)
@@ -44,6 +44,7 @@
 #include "qwebsocketprotocol_p.h"
 #include "qwebsockethandshakerequest_p.h"
 #include "qwebsockethandshakeresponse_p.h"
+#include "qdefaultmaskgenerator_p.h"
 
 #include <QtCore/QUrl>
 #include <QtNetwork/QAuthenticator>
@@ -109,7 +110,9 @@ QWebSocketPrivate::QWebSocketPrivate(const QString &origin, QWebSocketProtocol::
     m_closeReason(),
     m_pingTimer(),
     m_dataProcessor(),
-    m_configuration()
+    m_configuration(),
+    m_pMaskGenerator(&m_defaultMaskGenerator),
+    m_defaultMaskGenerator()
 {
 }
 
@@ -139,7 +142,9 @@ QWebSocketPrivate::QWebSocketPrivate(QTcpSocket *pTcpSocket, QWebSocketProtocol:
     m_closeReason(),
     m_pingTimer(),
     m_dataProcessor(),
-    m_configuration()
+    m_configuration(),
+    m_pMaskGenerator(&m_defaultMaskGenerator),
+    m_defaultMaskGenerator()
 {
 }
 
@@ -149,8 +154,9 @@ QWebSocketPrivate::QWebSocketPrivate(QTcpSocket *pTcpSocket, QWebSocketProtocol:
 void QWebSocketPrivate::init()
 {
     Q_ASSERT(q_ptr);
-    //TODO: need a better randomizer
-    qsrand(static_cast<uint>(QDateTime::currentMSecsSinceEpoch()));
+    Q_ASSERT(m_pMaskGenerator);
+
+    m_pMaskGenerator->seed();
 
     if (m_pSocket) {
         makeConnections(m_pSocket.data());
@@ -330,8 +336,14 @@ void QWebSocketPrivate::close(QWebSocketProtocol::CloseCode closeCode, QString r
 void QWebSocketPrivate::open(const QUrl &url, bool mask)
 {
     //just delete the old socket for the moment;
-    //later, we can add more 'intelligent' handling by looking at the url
+    //later, we can add more 'intelligent' handling by looking at the URL
     //m_pSocket.reset();
+    Q_Q(QWebSocket);
+    if (!url.isValid() || url.toString().contains(QStringLiteral("\r\n"))) {
+        setErrorString(QWebSocket::tr("Invalid URL."));
+        Q_EMIT q->error(QAbstractSocket::ConnectionRefusedError);
+        return;
+    }
     QTcpSocket *pTcpSocket = m_pSocket.take();
     if (pTcpSocket) {
         releaseConnections(pTcpSocket);
@@ -339,14 +351,18 @@ void QWebSocketPrivate::open(const QUrl &url, bool mask)
     }
     //if (m_url != url)
     if (Q_LIKELY(!m_pSocket)) {
-        Q_Q(QWebSocket);
-
         m_dataProcessor.clear();
         m_isClosingHandshakeReceived = false;
         m_isClosingHandshakeSent = false;
 
         setRequestUrl(url);
         QString resourceName = url.path();
+        if (resourceName.contains(QStringLiteral("\r\n"))) {
+            setRequestUrl(QUrl());  //clear requestUrl
+            setErrorString(QWebSocket::tr("Invalid resource name."));
+            Q_EMIT q->error(QAbstractSocket::ConnectionRefusedError);
+            return;
+        }
         if (!url.query().isEmpty()) {
             if (!resourceName.endsWith(QChar::fromLatin1('?'))) {
                 resourceName.append(QChar::fromLatin1('?'));
@@ -366,7 +382,7 @@ void QWebSocketPrivate::open(const QUrl &url, bool mask)
                 setErrorString(message);
                 Q_EMIT q->error(QAbstractSocket::UnsupportedSocketOperationError);
             } else {
-                QSslSocket *sslSocket = new QSslSocket(q);
+                QSslSocket *sslSocket = new QSslSocket;
                 m_pSocket.reset(sslSocket);
                 if (Q_LIKELY(m_pSocket)) {
                     m_pSocket->setSocketOption(QAbstractSocket::LowDelayOption, 1);
@@ -377,6 +393,10 @@ void QWebSocketPrivate::open(const QUrl &url, bool mask)
                     makeConnections(m_pSocket.data());
                     QObject::connect(sslSocket, &QSslSocket::encryptedBytesWritten, q,
                                      &QWebSocket::bytesWritten);
+                    typedef void (QSslSocket:: *sslErrorSignalType)(const QList<QSslError> &);
+                    QObject::connect(sslSocket,
+                                     static_cast<sslErrorSignalType>(&QSslSocket::sslErrors),
+                                     q, &QWebSocket::sslErrors);
                     setSocketState(QAbstractSocket::ConnectingState);
 
                     sslSocket->setSslConfiguration(m_configuration.m_sslConfiguration);
@@ -397,7 +417,7 @@ void QWebSocketPrivate::open(const QUrl &url, bool mask)
         } else
     #endif
         if (url.scheme() == QStringLiteral("ws")) {
-            m_pSocket.reset(new QTcpSocket(q));
+            m_pSocket.reset(new QTcpSocket);
             if (Q_LIKELY(m_pSocket)) {
                 m_pSocket->setSocketOption(QAbstractSocket::LowDelayOption, 1);
                 m_pSocket->setSocketOption(QAbstractSocket::KeepAliveOption, 1);
@@ -561,7 +581,7 @@ void QWebSocketPrivate::releaseConnections(const QTcpSocket *pTcpSocket)
 {
     if (Q_LIKELY(pTcpSocket))
         pTcpSocket->disconnect(pTcpSocket);
-    m_dataProcessor.disconnect(&m_dataProcessor);
+    m_dataProcessor.disconnect();
 }
 
 /*!
@@ -748,20 +768,11 @@ qint64 QWebSocketPrivate::doWriteFrames(const QByteArray &data, bool isBinary)
 }
 
 /*!
- * \internal
- */
-quint32 QWebSocketPrivate::generateRandomNumber() const
-{
-    //TODO: need a better randomizer
-    return quint32((double(qrand()) / RAND_MAX) * std::numeric_limits<quint32>::max());
-}
-
-/*!
     \internal
  */
 quint32 QWebSocketPrivate::generateMaskingKey() const
 {
-    return generateRandomNumber();
+    return m_pMaskGenerator->nextMask();
 }
 
 /*!
@@ -772,7 +783,7 @@ QByteArray QWebSocketPrivate::generateKey() const
     QByteArray key;
 
     for (int i = 0; i < 4; ++i) {
-        const quint32 tmp = generateRandomNumber();
+        const quint32 tmp = m_pMaskGenerator->nextMask();
         key.append(static_cast<const char *>(static_cast<const void *>(&tmp)), sizeof(quint32));
     }
 
@@ -969,6 +980,11 @@ void QWebSocketPrivate::processStateChanged(QAbstractSocket::SocketState socketS
                                            QString(),
                                            QString(),
                                            m_key);
+            if (handshake.isEmpty()) {
+                m_pSocket->abort();
+                Q_EMIT q->error(QAbstractSocket::ConnectionRefusedError);
+                return;
+            }
             m_pSocket->write(handshake.toLatin1());
         }
         break;
@@ -1058,6 +1074,31 @@ QString QWebSocketPrivate::createHandShakeRequest(QString resourceName,
                                                   QByteArray key)
 {
     QStringList handshakeRequest;
+    if (resourceName.contains(QStringLiteral("\r\n"))) {
+        setErrorString(QWebSocket::tr("The resource name contains newlines. " \
+                                      "Possible attack detected."));
+        return QString();
+    }
+    if (host.contains(QStringLiteral("\r\n"))) {
+        setErrorString(QWebSocket::tr("The hostname contains newlines. " \
+                                      "Possible attack detected."));
+        return QString();
+    }
+    if (origin.contains(QStringLiteral("\r\n"))) {
+        setErrorString(QWebSocket::tr("The origin contains newlines. " \
+                                      "Possible attack detected."));
+        return QString();
+    }
+    if (extensions.contains(QStringLiteral("\r\n"))) {
+        setErrorString(QWebSocket::tr("The extensions attribute contains newlines. " \
+                                      "Possible attack detected."));
+        return QString();
+    }
+    if (protocols.contains(QStringLiteral("\r\n"))) {
+        setErrorString(QWebSocket::tr("The protocols attribute contains newlines. " \
+                                      "Possible attack detected."));
+        return QString();
+    }
 
     handshakeRequest << QStringLiteral("GET ") % resourceName % QStringLiteral(" HTTP/1.1") <<
                         QStringLiteral("Host: ") % host <<
@@ -1191,6 +1232,26 @@ void QWebSocketPrivate::setProxy(const QNetworkProxy &networkProxy)
 /*!
     \internal
  */
+void QWebSocketPrivate::setMaskGenerator(const QMaskGenerator *maskGenerator)
+{
+    if (!maskGenerator)
+        m_pMaskGenerator = &m_defaultMaskGenerator;
+    else if (maskGenerator != m_pMaskGenerator)
+        m_pMaskGenerator = const_cast<QMaskGenerator *>(maskGenerator);
+}
+
+/*!
+    \internal
+ */
+const QMaskGenerator *QWebSocketPrivate::maskGenerator() const
+{
+    Q_ASSERT(m_pMaskGenerator);
+    return m_pMaskGenerator;
+}
+
+/*!
+    \internal
+ */
 qint64 QWebSocketPrivate::readBufferSize() const
 {
     return m_readBufferSize;