misc: fastrpc: Fix use-after-free and race in fastrpc_map_find
[platform/kernel/linux-starfive.git] / drivers / misc / fastrpc.c
index 7ff0b63..fc2173c 100644 (file)
@@ -333,30 +333,31 @@ static void fastrpc_map_get(struct fastrpc_map *map)
 
 
 static int fastrpc_map_lookup(struct fastrpc_user *fl, int fd,
-                           struct fastrpc_map **ppmap)
+                           struct fastrpc_map **ppmap, bool take_ref)
 {
+       struct fastrpc_session_ctx *sess = fl->sctx;
        struct fastrpc_map *map = NULL;
+       int ret = -ENOENT;
 
-       mutex_lock(&fl->mutex);
+       spin_lock(&fl->lock);
        list_for_each_entry(map, &fl->maps, node) {
-               if (map->fd == fd) {
-                       *ppmap = map;
-                       mutex_unlock(&fl->mutex);
-                       return 0;
-               }
-       }
-       mutex_unlock(&fl->mutex);
-
-       return -ENOENT;
-}
+               if (map->fd != fd)
+                       continue;
 
-static int fastrpc_map_find(struct fastrpc_user *fl, int fd,
-                           struct fastrpc_map **ppmap)
-{
-       int ret = fastrpc_map_lookup(fl, fd, ppmap);
+               if (take_ref) {
+                       ret = fastrpc_map_get(map);
+                       if (ret) {
+                               dev_dbg(sess->dev, "%s: Failed to get map fd=%d ret=%d\n",
+                                       __func__, fd, ret);
+                               break;
+                       }
+               }
 
-       if (!ret)
-               fastrpc_map_get(*ppmap);
+               *ppmap = map;
+               ret = 0;
+               break;
+       }
+       spin_unlock(&fl->lock);
 
        return ret;
 }
@@ -703,7 +704,7 @@ static int fastrpc_map_create(struct fastrpc_user *fl, int fd,
        struct fastrpc_map *map = NULL;
        int err = 0;
 
-       if (!fastrpc_map_find(fl, fd, ppmap))
+       if (!fastrpc_map_lookup(fl, fd, ppmap, true))
                return 0;
 
        map = kzalloc(sizeof(*map), GFP_KERNEL);
@@ -1026,7 +1027,7 @@ static int fastrpc_put_args(struct fastrpc_invoke_ctx *ctx,
        for (i = 0; i < FASTRPC_MAX_FDLIST; i++) {
                if (!fdlist[i])
                        break;
-               if (!fastrpc_map_lookup(fl, (int)fdlist[i], &mmap))
+               if (!fastrpc_map_lookup(fl, (int)fdlist[i], &mmap, false))
                        fastrpc_map_put(mmap);
        }