[IOT-2205] g_sslContextMutex synchronization improvements
authorDan Mihai <Daniel.Mihai@microsoft.com>
Thu, 4 May 2017 13:24:20 +0000 (06:24 -0700)
committerRandeep Singh <randeep.s@samsung.com>
Mon, 22 May 2017 11:48:37 +0000 (11:48 +0000)
Change-Id: Ib5e6df4e02da583e086a07591ef1e35f367ae4f0
Signed-off-by: Dan Mihai <Daniel.Mihai@microsoft.com>
Reviewed-on: https://gerrit.iotivity.org/gerrit/19633
Tested-by: jenkins-iotivity <jenkins@iotivity.org>
Reviewed-by: Kevin Kane <kkane@microsoft.com>
Reviewed-by: Dmitriy Zhuravlev <d.zhuravlev@samsung.com>
Reviewed-by: Oleksii Beketov <ol.beketov@samsung.com>
Reviewed-by: Randeep Singh <randeep.s@samsung.com>
resource/csdk/connectivity/src/adapter_util/ca_adapter_net_ssl.c

index a57f66a..ec685ca 100644 (file)
@@ -196,12 +196,16 @@ do
  * @param[in] peer remote peer
  */
 #define SSL_RES(peer, status)                                                                      \
-if (g_sslCallback)                                                                                 \
+do                                                                                                 \
 {                                                                                                  \
-    CAErrorInfo_t errorInfo;                                                                       \
-    errorInfo.result = (status);                                                                   \
-    g_sslCallback(&(peer)->sep.endpoint, &errorInfo);                                              \
-}
+    oc_mutex_assert_owner(g_sslContextMutex, true);                                                \
+    if (g_sslCallback)                                                                             \
+    {                                                                                              \
+        CAErrorInfo_t errorInfo;                                                                   \
+        errorInfo.result = (status);                                                               \
+        g_sslCallback(&(peer)->sep.endpoint, &errorInfo);                                          \
+    }                                                                                              \
+} while(false)
 
 /* OCF-defined EKU value indicating an identity certificate, that can be used for
  * TLS client and server authentication.  This is the DER encoding of the OID
@@ -424,7 +428,7 @@ static CAgetPkixInfoHandler g_getPkixInfoCallback = NULL;
 
 /**
  * @var g_dtlsContextMutex
- * @brief Mutex to synchronize access to g_caSslContext.
+ * @brief Mutex to synchronize access to g_caSslContext and g_sslCallback.
  */
 static oc_mutex g_sslContextMutex = NULL;
 
@@ -784,6 +788,9 @@ static SslEndPoint_t *GetSslPeer(const CAEndpoint_t *peer)
     size_t listIndex = 0;
     size_t listLength = 0;
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "In %s", __func__);
+
+    oc_mutex_assert_owner(g_sslContextMutex, true);
+
     VERIFY_NON_NULL_RET(peer, NET_SSL_TAG, "TLS peer is NULL", NULL);
     VERIFY_NON_NULL_RET(g_caSslContext, NET_SSL_TAG, "SSL Context is NULL", NULL);
 
@@ -826,8 +833,8 @@ CAResult_t GetCASecureEndpointData(const CAEndpoint_t* peer, CASecureEndpoint_t*
 {
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "In %s", __func__);
 
-    oc_mutex_assert_owner(g_sslContextMutex, false);
     oc_mutex_lock(g_sslContextMutex);
+
     if (NULL == g_caSslContext)
     {
         OIC_LOG(ERROR, NET_SSL_TAG, "Context is NULL");
@@ -836,17 +843,21 @@ CAResult_t GetCASecureEndpointData(const CAEndpoint_t* peer, CASecureEndpoint_t*
     }
 
     SslEndPoint_t* sslPeer = GetSslPeer(peer);
-    if(sslPeer)
+    if (sslPeer)
     {
-        OIC_LOG_V(DEBUG, NET_SSL_TAG, "Out %s", __func__);
+        // sslPeer could be destroyed after releasing the lock, so make a copy
+        // of the endpoint information before releasing the lock.
         memcpy(sep, &sslPeer->sep, sizeof(sslPeer->sep));
         oc_mutex_unlock(g_sslContextMutex);
+
+        OIC_LOG_V(DEBUG, NET_SSL_TAG, "Out %s", __func__);
         return CA_STATUS_OK;
     }
 
-    OIC_LOG(DEBUG, NET_SSL_TAG, "Return NULL");
-    OIC_LOG_V(DEBUG, NET_SSL_TAG, "Out %s", __func__);
     oc_mutex_unlock(g_sslContextMutex);
+
+    OIC_LOG(DEBUG, NET_SSL_TAG, "GetSslPeer returned NULL");
+    OIC_LOG_V(DEBUG, NET_SSL_TAG, "Out %s", __func__);
     return CA_STATUS_INVALID_PARAM;
 }
 
@@ -865,10 +876,7 @@ bool SetCASecureEndpointAttribute(const CAEndpoint_t* peer, uint32_t newAttribut
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "In %s(peer = %s:%u, attribute = %#x)", __func__,
         peer->addr, (uint32_t)peer->port, newAttribute);
 
-    // Acquiring g_sslContextMutex recursively here is not supported, so assert
-    // that the caller already owns this mutex. IOT-1876 tracks a possible
-    // refactoring of the code that is using g_sslContextMutex, to address these
-    // API quirks.
+    // In the current implementation, the caller already owns g_sslContextMutex.
     oc_mutex_assert_owner(g_sslContextMutex, true);
 
     if (NULL == g_caSslContext)
@@ -909,6 +917,8 @@ bool GetCASecureEndpointAttributes(const CAEndpoint_t* peer, uint32_t* allAttrib
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "In %s(peer = %s:%u)", __func__,
         peer->addr, (uint32_t)peer->port);
 
+    // In the current implementation, the caller doesn't own g_sslContextMutex.
+    oc_mutex_assert_owner(g_sslContextMutex, false);
     oc_mutex_lock(g_sslContextMutex);
 
     if (NULL == g_caSslContext)
@@ -999,8 +1009,11 @@ static void DeleteSslEndPoint(SslEndPoint_t * tep)
  */
 static void RemovePeerFromList(CAEndpoint_t * endpoint)
 {
+    oc_mutex_assert_owner(g_sslContextMutex, true);
+
     VERIFY_NON_NULL_VOID(g_caSslContext, NET_SSL_TAG, "SSL Context is NULL");
     VERIFY_NON_NULL_VOID(endpoint, NET_SSL_TAG, "endpoint");
+
     size_t listLength = u_arraylist_length(g_caSslContext->peerList);
     for (size_t listIndex = 0; listIndex < listLength; listIndex++)
     {
@@ -1037,6 +1050,7 @@ static bool checkSslOperation(SslEndPoint_t*  peer,
                               unsigned char msg)
 {
     OC_UNUSED(str);
+    OC_UNUSED(msg);
 
     if ((0 != ret) &&
         (MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY != ret) &&
@@ -1060,7 +1074,7 @@ static bool checkSslOperation(SslEndPoint_t*  peer,
         (MBEDTLS_SSL_ALERT_MSG_NO_APPLICATION_PROTOCOL != ret))
     {
         OIC_LOG_V(ERROR, NET_SSL_TAG, "%s: -0x%x", (str), -ret);
-        (void)msg;
+        oc_mutex_lock(g_sslContextMutex);
 
         if (MBEDTLS_ERR_SSL_BAD_HS_CLIENT_HELLO != ret)
         {
@@ -1068,6 +1082,8 @@ static bool checkSslOperation(SslEndPoint_t*  peer,
         }
 
         RemovePeerFromList(&(peer)->sep.endpoint);
+
+        oc_mutex_unlock(g_sslContextMutex);
         return false;
     }
 
@@ -1079,6 +1095,8 @@ static bool checkSslOperation(SslEndPoint_t*  peer,
  */
 static void DeletePeerList()
 {
+    oc_mutex_assert_owner(g_sslContextMutex, true);
+
     VERIFY_NON_NULL_VOID(g_caSslContext, NET_SSL_TAG, "SSL Context is NULL");
 
     size_t listLength = u_arraylist_length(g_caSslContext->peerList);
@@ -1391,9 +1409,11 @@ static SslEndPoint_t * InitiateTlsHandshake(const CAEndpoint_t *endpoint)
     //Load allowed SVR suites from SVR DB
     SetupCipher(config, endpoint->adapter, endpoint->remoteId);
 
+    oc_mutex_lock(g_sslContextMutex);
     ret = u_arraylist_add(g_caSslContext->peerList, (void *) tep);
     if (!ret)
     {
+        oc_mutex_unlock(g_sslContextMutex);
         OIC_LOG(ERROR, NET_SSL_TAG, "u_arraylist_add failed!");
         DeleteSslEndPoint(tep);
         return NULL;
@@ -1410,6 +1430,7 @@ static SslEndPoint_t * InitiateTlsHandshake(const CAEndpoint_t *endpoint)
         {
             OIC_LOG(ERROR, NET_SSL_TAG, "Handshake failed due to socket error");
             RemovePeerFromList(&tep->sep.endpoint);
+            oc_mutex_unlock(g_sslContextMutex);
             return NULL;
         }
         if (!checkSslOperation(tep,
@@ -1417,11 +1438,14 @@ static SslEndPoint_t * InitiateTlsHandshake(const CAEndpoint_t *endpoint)
                                "Handshake error",
                                MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE))
         {
+            oc_mutex_unlock(g_sslContextMutex);
             OIC_LOG_V(DEBUG, NET_SSL_TAG, "Out %s", __func__);
             DeleteSslEndPoint(tep);
             return NULL;
         }
     }
+
+    oc_mutex_unlock(g_sslContextMutex);
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "Out %s", __func__);
     return tep;
 }
@@ -1585,8 +1609,9 @@ CAResult_t CAinitSslAdapter()
     // Initialize mutex for tlsContext
     if (NULL == g_sslContextMutex)
     {
-        g_sslContextMutex = oc_mutex_new();
-        VERIFY_NON_NULL_RET(g_sslContextMutex, NET_SSL_TAG, "malloc failed", CA_MEMORY_ALLOC_FAILED);
+        g_sslContextMutex = oc_mutex_new_recursive();
+        VERIFY_NON_NULL_RET(g_sslContextMutex, NET_SSL_TAG, "oc_mutex_new_recursive failed",
+            CA_MEMORY_ALLOC_FAILED);
     }
     else
     {
@@ -1725,7 +1750,7 @@ CAResult_t CAinitSslAdapter()
     return CA_STATUS_OK;
 }
 
-SslCacheMessage_t *  NewCacheMessage(uint8_t * data, size_t dataLen)
+SslCacheMessage_t *NewCacheMessage(uint8_t * data, size_t dataLen)
 {
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "In %s", __func__);
     VERIFY_NON_NULL_RET(data, NET_SSL_TAG, "Param data is NULL" , NULL);
@@ -1848,6 +1873,10 @@ CAResult_t CAencryptSsl(const CAEndpoint_t *endpoint,
 static void SendCacheMessages(SslEndPoint_t * tep)
 {
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "In %s", __func__);
+
+    // The mutex protects the access to tep.
+    oc_mutex_assert_owner(g_sslContextMutex, true);
+
     VERIFY_NON_NULL_VOID(tep, NET_SSL_TAG, "Param tep is NULL");
 
     size_t listIndex = 0;
@@ -1904,7 +1933,11 @@ static void SendCacheMessages(SslEndPoint_t * tep)
 void CAsetSslHandshakeCallback(CAErrorCallback tlsHandshakeCallback)
 {
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "In %s(%p)", __func__, tlsHandshakeCallback);
+
+    oc_mutex_lock(g_sslContextMutex);
     g_sslCallback = tlsHandshakeCallback;
+    oc_mutex_unlock(g_sslContextMutex);
+
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "Out %s(%p)", __func__, tlsHandshakeCallback);
 }
 
@@ -2297,7 +2330,14 @@ static SslCipher_t GetCipherIndex(const uint32_t cipher)
 CAResult_t CAsetTlsCipherSuite(const uint32_t cipher)
 {
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "In %s", __func__);
-    VERIFY_NON_NULL_RET(g_caSslContext, NET_SSL_TAG, "SSL context is not initialized." , CA_STATUS_NOT_INITIALIZED);
+    oc_mutex_lock(g_sslContextMutex);
+
+    if (NULL == g_caSslContext)
+    {
+        OIC_LOG(ERROR, NET_SSL_TAG, "SSL context is not initialized.");
+        oc_mutex_unlock(g_sslContextMutex);
+        return CA_STATUS_NOT_INITIALIZED;
+    }
 
     SslCipher_t index = GetCipherIndex(cipher);
     if (SSL_CIPHER_MAX == index)
@@ -2318,6 +2358,7 @@ CAResult_t CAsetTlsCipherSuite(const uint32_t cipher)
     }
     g_caSslContext->cipher = index;
 
+    oc_mutex_unlock(g_sslContextMutex);
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "Out %s", __func__);
     return CA_STATUS_OK;
 }
@@ -2327,6 +2368,7 @@ CAResult_t CAinitiateSslHandshake(const CAEndpoint_t *endpoint)
     CAResult_t res = CA_STATUS_OK;
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "In %s", __func__);
     VERIFY_NON_NULL_RET(endpoint, NET_SSL_TAG, "Param endpoint is NULL" , CA_STATUS_INVALID_PARAM);
+    oc_mutex_lock(g_sslContextMutex);
 
     if (NULL != GetSslPeer(endpoint))
     {
@@ -2337,12 +2379,12 @@ CAResult_t CAinitiateSslHandshake(const CAEndpoint_t *endpoint)
         }
     }
 
-    oc_mutex_lock(g_sslContextMutex);
     if (NULL == InitiateTlsHandshake(endpoint))
     {
         OIC_LOG(ERROR, NET_SSL_TAG, "TLS handshake failed");
         res = CA_STATUS_FAILED;
     }
+
     oc_mutex_unlock(g_sslContextMutex);
     OIC_LOG_V(DEBUG, NET_SSL_TAG, "Out %s", __func__);
     return res;