libbpf: Support XDP_SHARED_UMEM with external XDP program
[platform/kernel/linux-starfive.git] / tools / lib / bpf / xsk.c
index 86c1b61..8ebd810 100644 (file)
@@ -586,15 +586,21 @@ int xsk_socket__create(struct xsk_socket **xsk_ptr, const char *ifname,
        if (!umem || !xsk_ptr || !rx || !tx)
                return -EFAULT;
 
-       if (umem->refcount) {
-               pr_warn("Error: shared umems not supported by libbpf.\n");
-               return -EBUSY;
-       }
-
        xsk = calloc(1, sizeof(*xsk));
        if (!xsk)
                return -ENOMEM;
 
+       err = xsk_set_xdp_socket_config(&xsk->config, usr_config);
+       if (err)
+               goto out_xsk_alloc;
+
+       if (umem->refcount &&
+           !(xsk->config.libbpf_flags & XSK_LIBBPF_FLAGS__INHIBIT_PROG_LOAD)) {
+               pr_warn("Error: shared umems not supported by libbpf supplied XDP program.\n");
+               err = -EBUSY;
+               goto out_xsk_alloc;
+       }
+
        if (umem->refcount++ > 0) {
                xsk->fd = socket(AF_XDP, SOCK_RAW, 0);
                if (xsk->fd < 0) {
@@ -616,10 +622,6 @@ int xsk_socket__create(struct xsk_socket **xsk_ptr, const char *ifname,
        memcpy(xsk->ifname, ifname, IFNAMSIZ - 1);
        xsk->ifname[IFNAMSIZ - 1] = '\0';
 
-       err = xsk_set_xdp_socket_config(&xsk->config, usr_config);
-       if (err)
-               goto out_socket;
-
        if (rx) {
                err = setsockopt(xsk->fd, SOL_XDP, XDP_RX_RING,
                                 &xsk->config.rx_size,
@@ -687,7 +689,12 @@ int xsk_socket__create(struct xsk_socket **xsk_ptr, const char *ifname,
        sxdp.sxdp_family = PF_XDP;
        sxdp.sxdp_ifindex = xsk->ifindex;
        sxdp.sxdp_queue_id = xsk->queue_id;
-       sxdp.sxdp_flags = xsk->config.bind_flags;
+       if (umem->refcount > 1) {
+               sxdp.sxdp_flags = XDP_SHARED_UMEM;
+               sxdp.sxdp_shared_umem_fd = umem->fd;
+       } else {
+               sxdp.sxdp_flags = xsk->config.bind_flags;
+       }
 
        err = bind(xsk->fd, (struct sockaddr *)&sxdp, sizeof(sxdp));
        if (err) {