bpf: Call __bpf_sk_lookup()/__bpf_skc_lookup() directly via TC hookpoint
[platform/kernel/linux-starfive.git] / net / core / filter.c
index e0f73ed..1b60a6c 100644 (file)
@@ -6649,8 +6649,12 @@ static const struct bpf_func_proto bpf_sk_lookup_udp_proto = {
 BPF_CALL_5(bpf_tc_skc_lookup_tcp, struct sk_buff *, skb,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return (unsigned long)bpf_skc_lookup(skb, tuple, len, IPPROTO_TCP,
-                                            netns_id, flags);
+       struct net *caller_net = dev_net(skb->dev);
+       int ifindex = skb->dev->ifindex;
+
+       return (unsigned long)__bpf_skc_lookup(skb, tuple, len, caller_net,
+                                              ifindex, IPPROTO_TCP, netns_id,
+                                              flags);
 }
 
 static const struct bpf_func_proto bpf_tc_skc_lookup_tcp_proto = {
@@ -6668,8 +6672,12 @@ static const struct bpf_func_proto bpf_tc_skc_lookup_tcp_proto = {
 BPF_CALL_5(bpf_tc_sk_lookup_tcp, struct sk_buff *, skb,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return (unsigned long)bpf_sk_lookup(skb, tuple, len, IPPROTO_TCP,
-                                           netns_id, flags);
+       struct net *caller_net = dev_net(skb->dev);
+       int ifindex = skb->dev->ifindex;
+
+       return (unsigned long)__bpf_sk_lookup(skb, tuple, len, caller_net,
+                                             ifindex, IPPROTO_TCP, netns_id,
+                                             flags);
 }
 
 static const struct bpf_func_proto bpf_tc_sk_lookup_tcp_proto = {
@@ -6687,8 +6695,12 @@ static const struct bpf_func_proto bpf_tc_sk_lookup_tcp_proto = {
 BPF_CALL_5(bpf_tc_sk_lookup_udp, struct sk_buff *, skb,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return (unsigned long)bpf_sk_lookup(skb, tuple, len, IPPROTO_UDP,
-                                           netns_id, flags);
+       struct net *caller_net = dev_net(skb->dev);
+       int ifindex = skb->dev->ifindex;
+
+       return (unsigned long)__bpf_sk_lookup(skb, tuple, len, caller_net,
+                                             ifindex, IPPROTO_UDP, netns_id,
+                                             flags);
 }
 
 static const struct bpf_func_proto bpf_tc_sk_lookup_udp_proto = {