sctp: only allow the asoc reset when the asoc outq is empty
[platform/kernel/linux-starfive.git] / net / sctp / stream.c
index b8c8cab..f3b7d27 100644 (file)
@@ -254,6 +254,30 @@ static int sctp_send_reconf(struct sctp_association *asoc,
        return retval;
 }
 
+static bool sctp_stream_outq_is_empty(struct sctp_stream *stream,
+                                     __u16 str_nums, __be16 *str_list)
+{
+       struct sctp_association *asoc;
+       __u16 i;
+
+       asoc = container_of(stream, struct sctp_association, stream);
+       if (!asoc->outqueue.out_qlen)
+               return true;
+
+       if (!str_nums)
+               return false;
+
+       for (i = 0; i < str_nums; i++) {
+               __u16 sid = ntohs(str_list[i]);
+
+               if (stream->out[sid].ext &&
+                   !list_empty(&stream->out[sid].ext->outq))
+                       return false;
+       }
+
+       return true;
+}
+
 int sctp_send_reset_streams(struct sctp_association *asoc,
                            struct sctp_reset_streams *params)
 {
@@ -282,15 +306,31 @@ int sctp_send_reset_streams(struct sctp_association *asoc,
 
        str_nums = params->srs_number_streams;
        str_list = params->srs_stream_list;
-       if (out && str_nums)
-               for (i = 0; i < str_nums; i++)
-                       if (str_list[i] >= stream->outcnt)
-                               goto out;
+       if (str_nums) {
+               int param_len = 0;
 
-       if (in && str_nums)
-               for (i = 0; i < str_nums; i++)
-                       if (str_list[i] >= stream->incnt)
-                               goto out;
+               if (out) {
+                       for (i = 0; i < str_nums; i++)
+                               if (str_list[i] >= stream->outcnt)
+                                       goto out;
+
+                       param_len = str_nums * sizeof(__u16) +
+                                   sizeof(struct sctp_strreset_outreq);
+               }
+
+               if (in) {
+                       for (i = 0; i < str_nums; i++)
+                               if (str_list[i] >= stream->incnt)
+                                       goto out;
+
+                       param_len += str_nums * sizeof(__u16) +
+                                    sizeof(struct sctp_strreset_inreq);
+               }
+
+               if (param_len > SCTP_MAX_CHUNK_LEN -
+                               sizeof(struct sctp_reconf_chunk))
+                       goto out;
+       }
 
        nstr_list = kcalloc(str_nums, sizeof(__be16), GFP_KERNEL);
        if (!nstr_list) {
@@ -301,6 +341,11 @@ int sctp_send_reset_streams(struct sctp_association *asoc,
        for (i = 0; i < str_nums; i++)
                nstr_list[i] = htons(str_list[i]);
 
+       if (out && !sctp_stream_outq_is_empty(stream, str_nums, nstr_list)) {
+               retval = -EAGAIN;
+               goto out;
+       }
+
        chunk = sctp_make_strreset_req(asoc, str_nums, nstr_list, out, in);
 
        kfree(nstr_list);
@@ -361,6 +406,9 @@ int sctp_send_reset_assoc(struct sctp_association *asoc)
        if (asoc->strreset_outstanding)
                return -EINPROGRESS;
 
+       if (!sctp_outq_is_empty(&asoc->outqueue))
+               return -EAGAIN;
+
        chunk = sctp_make_strreset_tsnreq(asoc);
        if (!chunk)
                return -ENOMEM;
@@ -547,7 +595,7 @@ struct sctp_chunk *sctp_process_strreset_outreq(
                flags = SCTP_STREAM_RESET_INCOMING_SSN;
        }
 
-       nums = (ntohs(param.p->length) - sizeof(*outreq)) / 2;
+       nums = (ntohs(param.p->length) - sizeof(*outreq)) / sizeof(__u16);
        if (nums) {
                str_p = outreq->list_of_streams;
                for (i = 0; i < nums; i++) {
@@ -611,7 +659,7 @@ struct sctp_chunk *sctp_process_strreset_inreq(
                goto out;
        }
 
-       nums = (ntohs(param.p->length) - sizeof(*inreq)) / 2;
+       nums = (ntohs(param.p->length) - sizeof(*inreq)) / sizeof(__u16);
        str_p = inreq->list_of_streams;
        for (i = 0; i < nums; i++) {
                if (ntohs(str_p[i]) >= stream->outcnt) {
@@ -620,6 +668,12 @@ struct sctp_chunk *sctp_process_strreset_inreq(
                }
        }
 
+       if (!sctp_stream_outq_is_empty(stream, nums, str_p)) {
+               result = SCTP_STRRESET_IN_PROGRESS;
+               asoc->strreset_inseq--;
+               goto err;
+       }
+
        chunk = sctp_make_strreset_req(asoc, nums, str_p, 1, 0);
        if (!chunk)
                goto out;
@@ -677,6 +731,12 @@ struct sctp_chunk *sctp_process_strreset_tsnreq(
                }
                goto err;
        }
+
+       if (!sctp_outq_is_empty(&asoc->outqueue)) {
+               result = SCTP_STRRESET_IN_PROGRESS;
+               goto err;
+       }
+
        asoc->strreset_inseq++;
 
        if (!(asoc->strreset_enable & SCTP_ENABLE_RESET_ASSOC_REQ))
@@ -911,7 +971,8 @@ struct sctp_chunk *sctp_process_strreset_resp(
 
                outreq = (struct sctp_strreset_outreq *)req;
                str_p = outreq->list_of_streams;
-               nums = (ntohs(outreq->param_hdr.length) - sizeof(*outreq)) / 2;
+               nums = (ntohs(outreq->param_hdr.length) - sizeof(*outreq)) /
+                      sizeof(__u16);
 
                if (result == SCTP_STRRESET_PERFORMED) {
                        if (nums) {
@@ -940,7 +1001,8 @@ struct sctp_chunk *sctp_process_strreset_resp(
 
                inreq = (struct sctp_strreset_inreq *)req;
                str_p = inreq->list_of_streams;
-               nums = (ntohs(inreq->param_hdr.length) - sizeof(*inreq)) / 2;
+               nums = (ntohs(inreq->param_hdr.length) - sizeof(*inreq)) /
+                      sizeof(__u16);
 
                *evp = sctp_ulpevent_make_stream_reset_event(asoc, flags,
                        nums, str_p, GFP_ATOMIC);