In the current code, msg->data is set as sg_virt(&sg[i]) + start - offset
and msg->data_end relative to it as msg->data + bytes. Using iterator i
to point to the updated starting scatterlist element holds true for some
cases, however not for all where we'd end up pointing out of bounds. It
is /correct/ for these ones:
1) When first finding the starting scatterlist element (sge) where we
find that the page is already privately owned by the msg and where
the requested bytes and headroom fit into the sge's length.
However, it's /incorrect/ for the following ones:
2) After we made the requested area private and updated the newly allocated
page into first_sg slot of the scatterlist ring; when we find that no
shift repair of the ring is needed where we bail out updating msg->data
and msg->data_end. At that point i will point to last_sg, which in this
case is the next elem of first_sg in the ring. The sge at that point
might as well be invalid (e.g. i == msg->sg_end), which we use for
setting the range of sg_virt(&sg[i]). The correct one would have been
first_sg.
3) Similar as in 2) but when we find that a shift repair of the ring is
needed. In this case we fix up all sges and stop once we've reached the
end. In this case i will point to will point to the new msg->sg_end,
and the sge at that point will be invalid. Again here the requested
range sits in first_sg.
Fixes:
015632bb30da ("bpf: sk_msg program helper bpf_sk_msg_pull_data")
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
if (unlikely(start >= offset + len))
return -EINVAL;
+ first_sg = i;
/* The start may point into the sg element so we need to also
* account for the headroom.
*/
if (!msg->sg_copy[i] && bytes_sg_total <= len)
goto out;
- first_sg = i;
-
/* At this point we need to linearize multiple scatterlist
* elements or a single shared page. Either way we need to
* copy into a linear buffer exclusively owned by BPF. Then
if (msg->sg_end < 0)
msg->sg_end += MAX_SKB_FRAGS;
out:
- msg->data = sg_virt(&sg[i]) + start - offset;
+ msg->data = sg_virt(&sg[first_sg]) + start - offset;
msg->data_end = msg->data + bytes;
return 0;