#include <net/sctp/sm.h>
#include <net/sctp/stream_sched.h>
-/* Fair Capacity handling
- * RFC 8260 section 3.5
+/* Fair Capacity and Weighted Fair Queueing handling
+ * RFC 8260 section 3.5 and 3.6
*/
static void sctp_sched_fc_unsched_all(struct sctp_stream *stream);
+static int sctp_sched_wfq_set(struct sctp_stream *stream, __u16 sid,
+ __u16 weight, gfp_t gfp)
+{
+ struct sctp_stream_out_ext *soute = SCTP_SO(stream, sid)->ext;
+
+ if (!weight)
+ return -EINVAL;
+
+ soute->fc_weight = weight;
+ return 0;
+}
+
+static int sctp_sched_wfq_get(struct sctp_stream *stream, __u16 sid,
+ __u16 *value)
+{
+ struct sctp_stream_out_ext *soute = SCTP_SO(stream, sid)->ext;
+
+ *value = soute->fc_weight;
+ return 0;
+}
+
static int sctp_sched_fc_set(struct sctp_stream *stream, __u16 sid,
__u16 weight, gfp_t gfp)
{
INIT_LIST_HEAD(&soute->fc_list);
soute->fc_length = 0;
+ soute->fc_weight = 1;
return 0;
}
return;
list_for_each_entry(pos, &stream->fc_list, fc_list)
- if (pos->fc_length >= soute->fc_length)
+ if ((__u64)pos->fc_length * soute->fc_weight >=
+ (__u64)soute->fc_length * pos->fc_weight)
break;
list_add_tail(&soute->fc_list, &pos->fc_list);
}
pos = soute;
list_for_each_entry_continue(pos, &stream->fc_list, fc_list)
- if (pos->fc_length >= soute->fc_length)
+ if ((__u64)pos->fc_length * soute->fc_weight >=
+ (__u64)soute->fc_length * pos->fc_weight)
break;
list_move_tail(&soute->fc_list, &pos->fc_list);
}
{
sctp_sched_ops_register(SCTP_SS_FC, &sctp_sched_fc);
}
+
+static struct sctp_sched_ops sctp_sched_wfq = {
+ .set = sctp_sched_wfq_set,
+ .get = sctp_sched_wfq_get,
+ .init = sctp_sched_fc_init,
+ .init_sid = sctp_sched_fc_init_sid,
+ .free_sid = sctp_sched_fc_free_sid,
+ .enqueue = sctp_sched_fc_enqueue,
+ .dequeue = sctp_sched_fc_dequeue,
+ .dequeue_done = sctp_sched_fc_dequeue_done,
+ .sched_all = sctp_sched_fc_sched_all,
+ .unsched_all = sctp_sched_fc_unsched_all,
+};
+
+void sctp_sched_ops_wfq_init(void)
+{
+ sctp_sched_ops_register(SCTP_SS_WFQ, &sctp_sched_wfq);
+}