drm/amdgpu/psp: move shared buffer frees into single function
authorAlex Deucher <alexander.deucher@amd.com>
Fri, 22 Apr 2022 21:19:29 +0000 (17:19 -0400)
committerAlex Deucher <alexander.deucher@amd.com>
Wed, 4 May 2022 13:55:05 +0000 (09:55 -0400)
So we can properly clean up if any of the TAs or TMR fails
to properly initialize or terminate.  This avoids any
memory leaks in the error case.

Reviewed-by: Hawking Zhang <Hawking.Zhang@amd.com>
Signed-off-by: Alex Deucher <alexander.deucher@amd.com>
drivers/gpu/drm/amd/amdgpu/amdgpu_psp.c

index 1ef2aba..b1b6f5d 100644 (file)
@@ -153,6 +153,36 @@ static int psp_early_init(void *handle)
        return 0;
 }
 
+static void psp_free_shared_bufs(struct psp_context *psp)
+{
+       void *tmr_buf;
+       void **pptr;
+
+       /* free TMR memory buffer */
+       pptr = amdgpu_sriov_vf(psp->adev) ? &tmr_buf : NULL;
+       amdgpu_bo_free_kernel(&psp->tmr_bo, &psp->tmr_mc_addr, pptr);
+
+       /* free xgmi shared memory */
+       psp_ta_free_shared_buf(&psp->xgmi_context.context.mem_context);
+
+       /* free ras shared memory */
+       psp_ta_free_shared_buf(&psp->ras_context.context.mem_context);
+
+       /* free hdcp shared memory */
+       psp_ta_free_shared_buf(&psp->hdcp_context.context.mem_context);
+
+       /* free dtm shared memory */
+       psp_ta_free_shared_buf(&psp->dtm_context.context.mem_context);
+
+       /* free rap shared memory */
+       psp_ta_free_shared_buf(&psp->rap_context.context.mem_context);
+
+       /* free securedisplay shared memory */
+       psp_ta_free_shared_buf(&psp->securedisplay_context.context.mem_context);
+
+
+}
+
 static void psp_memory_training_fini(struct psp_context *psp)
 {
        struct psp_memory_training_context *ctx = &psp->mem_train_ctx;
@@ -747,17 +777,7 @@ static int psp_tmr_unload(struct psp_context *psp)
 
 static int psp_tmr_terminate(struct psp_context *psp)
 {
-       int ret;
-       void *tmr_buf;
-       void **pptr;
-
-       ret = psp_tmr_unload(psp);
-
-       /* free TMR memory buffer */
-       pptr = amdgpu_sriov_vf(psp->adev) ? &tmr_buf : NULL;
-       amdgpu_bo_free_kernel(&psp->tmr_bo, &psp->tmr_mc_addr, pptr);
-
-       return ret;
+       return psp_tmr_unload(psp);
 }
 
 int psp_get_fw_attestation_records_addr(struct psp_context *psp,
@@ -1102,9 +1122,6 @@ int psp_xgmi_terminate(struct psp_context *psp)
 
        psp->xgmi_context.context.initialized = false;
 
-       /* free xgmi shared memory */
-       psp_ta_free_shared_buf(&psp->xgmi_context.context.mem_context);
-
        return ret;
 }
 
@@ -1465,9 +1482,6 @@ int psp_ras_terminate(struct psp_context *psp)
 
        psp->ras_context.context.initialized = false;
 
-       /* free ras shared memory */
-       psp_ta_free_shared_buf(&psp->ras_context.context.mem_context);
-
        return ret;
 }
 
@@ -1650,23 +1664,13 @@ static int psp_hdcp_terminate(struct psp_context *psp)
        if (amdgpu_sriov_vf(psp->adev))
                return 0;
 
-       if (!psp->hdcp_context.context.initialized) {
-               if (psp->hdcp_context.context.mem_context.shared_buf) {
-                       ret = 0;
-                       goto out;
-               } else {
-                       return 0;
-               }
-       }
+       if (!psp->hdcp_context.context.initialized)
+               return 0;
 
        ret = psp_ta_unload(psp, &psp->hdcp_context.context);
 
        psp->hdcp_context.context.initialized = false;
 
-out:
-       /* free hdcp shared memory */
-       psp_ta_free_shared_buf(&psp->hdcp_context.context.mem_context);
-
        return ret;
 }
 // HDCP end
@@ -1727,23 +1731,13 @@ static int psp_dtm_terminate(struct psp_context *psp)
        if (amdgpu_sriov_vf(psp->adev))
                return 0;
 
-       if (!psp->dtm_context.context.initialized) {
-               if (psp->dtm_context.context.mem_context.shared_buf) {
-                       ret = 0;
-                       goto out;
-               } else {
-                       return 0;
-               }
-       }
+       if (!psp->dtm_context.context.initialized)
+               return 0;
 
        ret = psp_ta_unload(psp, &psp->dtm_context.context);
 
        psp->dtm_context.context.initialized = false;
 
-out:
-       /* free dtm shared memory */
-       psp_ta_free_shared_buf(&psp->dtm_context.context.mem_context);
-
        return ret;
 }
 // DTM end
@@ -1785,6 +1779,8 @@ static int psp_rap_initialize(struct psp_context *psp)
        ret = psp_rap_invoke(psp, TA_CMD_RAP__INITIALIZE, &status);
        if (ret || status != TA_RAP_STATUS__SUCCESS) {
                psp_rap_terminate(psp);
+               /* free rap shared memory */
+               psp_ta_free_shared_buf(&psp->rap_context.context.mem_context);
 
                dev_warn(psp->adev->dev, "RAP TA initialize fail (%d) status %d.\n",
                         ret, status);
@@ -1806,9 +1802,6 @@ static int psp_rap_terminate(struct psp_context *psp)
 
        psp->rap_context.context.initialized = false;
 
-       /* free rap shared memory */
-       psp_ta_free_shared_buf(&psp->rap_context.context.mem_context);
-
        return ret;
 }
 
@@ -1889,6 +1882,8 @@ static int psp_securedisplay_initialize(struct psp_context *psp)
        ret = psp_securedisplay_invoke(psp, TA_SECUREDISPLAY_COMMAND__QUERY_TA);
        if (ret) {
                psp_securedisplay_terminate(psp);
+               /* free securedisplay shared memory */
+               psp_ta_free_shared_buf(&psp->securedisplay_context.context.mem_context);
                dev_err(psp->adev->dev, "SECUREDISPLAY TA initialize fail.\n");
                return -EINVAL;
        }
@@ -1919,9 +1914,6 @@ static int psp_securedisplay_terminate(struct psp_context *psp)
 
        psp->securedisplay_context.context.initialized = false;
 
-       /* free securedisplay shared memory */
-       psp_ta_free_shared_buf(&psp->securedisplay_context.context.mem_context);
-
        return ret;
 }
 
@@ -2524,16 +2516,18 @@ static int psp_hw_fini(void *handle)
        }
 
        psp_asd_terminate(psp);
-
        psp_tmr_terminate(psp);
+
        psp_ring_destroy(psp, PSP_RING_TYPE__KM);
 
+       psp_free_shared_bufs(psp);
+
        return 0;
 }
 
 static int psp_suspend(void *handle)
 {
-       int ret;
+       int ret = 0;
        struct amdgpu_device *adev = (struct amdgpu_device *)handle;
        struct psp_context *psp = &adev->psp;
 
@@ -2542,7 +2536,7 @@ static int psp_suspend(void *handle)
                ret = psp_xgmi_terminate(psp);
                if (ret) {
                        DRM_ERROR("Failed to terminate xgmi ta\n");
-                       return ret;
+                       goto out;
                }
        }
 
@@ -2550,49 +2544,51 @@ static int psp_suspend(void *handle)
                ret = psp_ras_terminate(psp);
                if (ret) {
                        DRM_ERROR("Failed to terminate ras ta\n");
-                       return ret;
+                       goto out;
                }
                ret = psp_hdcp_terminate(psp);
                if (ret) {
                        DRM_ERROR("Failed to terminate hdcp ta\n");
-                       return ret;
+                       goto out;
                }
                ret = psp_dtm_terminate(psp);
                if (ret) {
                        DRM_ERROR("Failed to terminate dtm ta\n");
-                       return ret;
+                       goto out;
                }
                ret = psp_rap_terminate(psp);
                if (ret) {
                        DRM_ERROR("Failed to terminate rap ta\n");
-                       return ret;
+                       goto out;
                }
                ret = psp_securedisplay_terminate(psp);
                if (ret) {
                        DRM_ERROR("Failed to terminate securedisplay ta\n");
-                       return ret;
+                       goto out;
                }
        }
 
        ret = psp_asd_terminate(psp);
        if (ret) {
                DRM_ERROR("Failed to terminate asd\n");
-               return ret;
+               goto out;
        }
 
        ret = psp_tmr_terminate(psp);
        if (ret) {
                DRM_ERROR("Failed to terminate tmr\n");
-               return ret;
+               goto out;
        }
 
        ret = psp_ring_stop(psp, PSP_RING_TYPE__KM);
        if (ret) {
                DRM_ERROR("PSP ring stop failed\n");
-               return ret;
        }
 
-       return 0;
+out:
+       psp_free_shared_bufs(psp);
+
+       return ret;
 }
 
 static int psp_resume(void *handle)