vp8[loongarch]: Optimize vp8 encoding partial function
[platform/upstream/libvpx.git] / vpx_dsp / loongarch / subtract_lsx.c
1 /*
2  *  Copyright (c) 2022 The WebM project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10
11 #include "./vpx_dsp_rtcd.h"
12 #include "vpx_util/loongson_intrinsics.h"
13
14 static void sub_blk_4x4_lsx(const uint8_t *src_ptr, int32_t src_stride,
15                             const uint8_t *pred_ptr, int32_t pred_stride,
16                             int16_t *diff_ptr, int32_t diff_stride) {
17   __m128i src0, src1, src2, src3;
18   __m128i pred0, pred1, pred2, pred3;
19   __m128i diff0, diff1;
20   __m128i reg0, reg1;
21   int32_t src_stride2 = src_stride << 1;
22   int32_t pred_stride2 = pred_stride << 1;
23   int32_t diff_stride2 = diff_stride << 1;
24   int32_t src_stride3 = src_stride2 + src_stride;
25   int32_t pred_stride3 = pred_stride2 + pred_stride;
26   int32_t diff_stride3 = diff_stride2 + diff_stride;
27
28   DUP4_ARG2(__lsx_vldrepl_w, src_ptr, 0, src_ptr + src_stride, 0,
29             src_ptr + src_stride2, 0, src_ptr + src_stride3, 0, src0, src1,
30             src2, src3);
31   DUP4_ARG2(__lsx_vldrepl_w, pred_ptr, 0, pred_ptr + pred_stride, 0,
32             pred_ptr + pred_stride2, 0, pred_ptr + pred_stride3, 0, pred0,
33             pred1, pred2, pred3);
34   DUP4_ARG2(__lsx_vilvl_w, src1, src0, src3, src2, pred1, pred0, pred3, pred2,
35             src0, src2, pred0, pred2);
36   DUP2_ARG2(__lsx_vilvl_d, src2, src0, pred2, pred0, src0, pred0);
37   reg0 = __lsx_vilvl_b(src0, pred0);
38   reg1 = __lsx_vilvh_b(src0, pred0);
39   DUP2_ARG2(__lsx_vhsubw_hu_bu, reg0, reg0, reg1, reg1, diff0, diff1);
40   __lsx_vstelm_d(diff0, diff_ptr, 0, 0);
41   __lsx_vstelm_d(diff0, diff_ptr + diff_stride, 0, 1);
42   __lsx_vstelm_d(diff1, diff_ptr + diff_stride2, 0, 0);
43   __lsx_vstelm_d(diff1, diff_ptr + diff_stride3, 0, 1);
44 }
45
46 static void sub_blk_8x8_lsx(const uint8_t *src_ptr, int32_t src_stride,
47                             const uint8_t *pred_ptr, int32_t pred_stride,
48                             int16_t *diff_ptr, int32_t diff_stride) {
49   __m128i src0, src1, src2, src3, src4, src5, src6, src7;
50   __m128i pred0, pred1, pred2, pred3, pred4, pred5, pred6, pred7;
51   __m128i reg0, reg1, reg2, reg3, reg4, reg5, reg6, reg7;
52   int32_t src_stride2 = src_stride << 1;
53   int32_t pred_stride2 = pred_stride << 1;
54   int32_t dst_stride = diff_stride << 1;
55   int32_t src_stride3 = src_stride2 + src_stride;
56   int32_t pred_stride3 = pred_stride2 + pred_stride;
57   int32_t dst_stride2 = dst_stride << 1;
58   int32_t src_stride4 = src_stride2 << 1;
59   int32_t pred_stride4 = pred_stride2 << 1;
60   int32_t dst_stride3 = dst_stride + dst_stride2;
61
62   DUP4_ARG2(__lsx_vldrepl_d, src_ptr, 0, src_ptr + src_stride, 0,
63             src_ptr + src_stride2, 0, src_ptr + src_stride3, 0, src0, src1,
64             src2, src3);
65   DUP4_ARG2(__lsx_vldrepl_d, pred_ptr, 0, pred_ptr + pred_stride, 0,
66             pred_ptr + pred_stride2, 0, pred_ptr + pred_stride3, 0, pred0,
67             pred1, pred2, pred3);
68   src_ptr += src_stride4;
69   pred_ptr += pred_stride4;
70
71   DUP4_ARG2(__lsx_vldrepl_d, src_ptr, 0, src_ptr + src_stride, 0,
72             src_ptr + src_stride2, 0, src_ptr + src_stride3, 0, src4, src5,
73             src6, src7);
74   DUP4_ARG2(__lsx_vldrepl_d, pred_ptr, 0, pred_ptr + pred_stride, 0,
75             pred_ptr + pred_stride2, 0, pred_ptr + pred_stride3, 0, pred4,
76             pred5, pred6, pred7);
77
78   DUP4_ARG2(__lsx_vilvl_b, src0, pred0, src1, pred1, src2, pred2, src3, pred3,
79             reg0, reg1, reg2, reg3);
80   DUP4_ARG2(__lsx_vilvl_b, src4, pred4, src5, pred5, src6, pred6, src7, pred7,
81             reg4, reg5, reg6, reg7);
82   DUP4_ARG2(__lsx_vhsubw_hu_bu, reg0, reg0, reg1, reg1, reg2, reg2, reg3, reg3,
83             src0, src1, src2, src3);
84   DUP4_ARG2(__lsx_vhsubw_hu_bu, reg4, reg4, reg5, reg5, reg6, reg6, reg7, reg7,
85             src4, src5, src6, src7);
86   __lsx_vst(src0, diff_ptr, 0);
87   __lsx_vstx(src1, diff_ptr, dst_stride);
88   __lsx_vstx(src2, diff_ptr, dst_stride2);
89   __lsx_vstx(src3, diff_ptr, dst_stride3);
90   diff_ptr += dst_stride2;
91   __lsx_vst(src4, diff_ptr, 0);
92   __lsx_vstx(src5, diff_ptr, dst_stride);
93   __lsx_vstx(src6, diff_ptr, dst_stride2);
94   __lsx_vstx(src7, diff_ptr, dst_stride3);
95 }
96
97 static void sub_blk_16x16_lsx(const uint8_t *src, int32_t src_stride,
98                               const uint8_t *pred, int32_t pred_stride,
99                               int16_t *diff, int32_t diff_stride) {
100   __m128i src0, src1, src2, src3, src4, src5, src6, src7;
101   __m128i pred0, pred1, pred2, pred3, pred4, pred5, pred6, pred7;
102   __m128i reg0, reg1, reg2, reg3, reg4, reg5, reg6, reg7;
103   __m128i tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
104   int32_t src_stride2 = src_stride << 1;
105   int32_t pred_stride2 = pred_stride << 1;
106   int32_t dst_stride = diff_stride << 1;
107   int32_t src_stride3 = src_stride2 + src_stride;
108   int32_t pred_stride3 = pred_stride2 + pred_stride;
109   int32_t dst_stride2 = dst_stride << 1;
110   int32_t src_stride4 = src_stride2 << 1;
111   int32_t pred_stride4 = pred_stride2 << 1;
112   int32_t dst_stride3 = dst_stride + dst_stride2;
113   int16_t *diff_tmp = diff + 8;
114
115   DUP2_ARG2(__lsx_vld, src, 0, pred, 0, src0, pred0);
116   DUP4_ARG2(__lsx_vldx, src, src_stride, src, src_stride2, src, src_stride3,
117             src, src_stride4, src1, src2, src3, src4);
118   DUP4_ARG2(__lsx_vldx, pred, pred_stride, pred, pred_stride2, pred,
119             pred_stride3, pred, pred_stride4, pred1, pred2, pred3, pred4);
120   src += src_stride4;
121   pred += pred_stride4;
122   DUP4_ARG2(__lsx_vldx, src, src_stride, src, src_stride2, src, src_stride3,
123             pred, pred_stride, src5, src6, src7, pred5);
124   DUP2_ARG2(__lsx_vldx, pred, pred_stride2, pred, pred_stride3, pred6, pred7);
125   src += src_stride4;
126   pred += pred_stride4;
127   DUP4_ARG2(__lsx_vilvl_b, src0, pred0, src1, pred1, src2, pred2, src3, pred3,
128             reg0, reg2, reg4, reg6);
129   DUP4_ARG2(__lsx_vilvh_b, src0, pred0, src1, pred1, src2, pred2, src3, pred3,
130             reg1, reg3, reg5, reg7);
131   DUP4_ARG2(__lsx_vilvl_b, src4, pred4, src5, pred5, src6, pred6, src7, pred7,
132             tmp0, tmp2, tmp4, tmp6);
133   DUP4_ARG2(__lsx_vilvh_b, src4, pred4, src5, pred5, src6, pred6, src7, pred7,
134             tmp1, tmp3, tmp5, tmp7);
135   DUP4_ARG2(__lsx_vhsubw_hu_bu, reg0, reg0, reg1, reg1, reg2, reg2, reg3, reg3,
136             src0, src1, src2, src3);
137   DUP4_ARG2(__lsx_vhsubw_hu_bu, reg4, reg4, reg5, reg5, reg6, reg6, reg7, reg7,
138             src4, src5, src6, src7);
139   DUP4_ARG2(__lsx_vhsubw_hu_bu, tmp0, tmp0, tmp1, tmp1, tmp2, tmp2, tmp3, tmp3,
140             pred0, pred1, pred2, pred3);
141   DUP4_ARG2(__lsx_vhsubw_hu_bu, tmp4, tmp4, tmp5, tmp5, tmp6, tmp6, tmp7, tmp7,
142             pred4, pred5, pred6, pred7);
143   __lsx_vst(src0, diff, 0);
144   __lsx_vstx(src2, diff, dst_stride);
145   __lsx_vstx(src4, diff, dst_stride2);
146   __lsx_vstx(src6, diff, dst_stride3);
147   __lsx_vst(src1, diff_tmp, 0);
148   __lsx_vstx(src3, diff_tmp, dst_stride);
149   __lsx_vstx(src5, diff_tmp, dst_stride2);
150   __lsx_vstx(src7, diff_tmp, dst_stride3);
151   diff += dst_stride2;
152   diff_tmp += dst_stride2;
153   __lsx_vst(pred0, diff, 0);
154   __lsx_vstx(pred2, diff, dst_stride);
155   __lsx_vstx(pred4, diff, dst_stride2);
156   __lsx_vstx(pred6, diff, dst_stride3);
157   __lsx_vst(pred1, diff_tmp, 0);
158   __lsx_vstx(pred3, diff_tmp, dst_stride);
159   __lsx_vstx(pred5, diff_tmp, dst_stride2);
160   __lsx_vstx(pred7, diff_tmp, dst_stride3);
161   diff += dst_stride2;
162   diff_tmp += dst_stride2;
163   DUP2_ARG2(__lsx_vld, src, 0, pred, 0, src0, pred0);
164   DUP4_ARG2(__lsx_vldx, src, src_stride, src, src_stride2, src, src_stride3,
165             src, src_stride4, src1, src2, src3, src4);
166   DUP4_ARG2(__lsx_vldx, pred, pred_stride, pred, pred_stride2, pred,
167             pred_stride3, pred, pred_stride4, pred1, pred2, pred3, pred4);
168   src += src_stride4;
169   pred += pred_stride4;
170   DUP4_ARG2(__lsx_vldx, src, src_stride, src, src_stride2, src, src_stride3,
171             pred, pred_stride, src5, src6, src7, pred5);
172   DUP2_ARG2(__lsx_vldx, pred, pred_stride2, pred, pred_stride3, pred6, pred7);
173   DUP4_ARG2(__lsx_vilvl_b, src0, pred0, src1, pred1, src2, pred2, src3, pred3,
174             reg0, reg2, reg4, reg6);
175   DUP4_ARG2(__lsx_vilvh_b, src0, pred0, src1, pred1, src2, pred2, src3, pred3,
176             reg1, reg3, reg5, reg7);
177   DUP4_ARG2(__lsx_vilvl_b, src4, pred4, src5, pred5, src6, pred6, src7, pred7,
178             tmp0, tmp2, tmp4, tmp6);
179   DUP4_ARG2(__lsx_vilvh_b, src4, pred4, src5, pred5, src6, pred6, src7, pred7,
180             tmp1, tmp3, tmp5, tmp7);
181   DUP4_ARG2(__lsx_vhsubw_hu_bu, reg0, reg0, reg1, reg1, reg2, reg2, reg3, reg3,
182             src0, src1, src2, src3);
183   DUP4_ARG2(__lsx_vhsubw_hu_bu, reg4, reg4, reg5, reg5, reg6, reg6, reg7, reg7,
184             src4, src5, src6, src7);
185   DUP4_ARG2(__lsx_vhsubw_hu_bu, tmp0, tmp0, tmp1, tmp1, tmp2, tmp2, tmp3, tmp3,
186             pred0, pred1, pred2, pred3);
187   DUP4_ARG2(__lsx_vhsubw_hu_bu, tmp4, tmp4, tmp5, tmp5, tmp6, tmp6, tmp7, tmp7,
188             pred4, pred5, pred6, pred7);
189   __lsx_vst(src0, diff, 0);
190   __lsx_vstx(src2, diff, dst_stride);
191   __lsx_vstx(src4, diff, dst_stride2);
192   __lsx_vstx(src6, diff, dst_stride3);
193   __lsx_vst(src1, diff_tmp, 0);
194   __lsx_vstx(src3, diff_tmp, dst_stride);
195   __lsx_vstx(src5, diff_tmp, dst_stride2);
196   __lsx_vstx(src7, diff_tmp, dst_stride3);
197   diff += dst_stride2;
198   diff_tmp += dst_stride2;
199   __lsx_vst(pred0, diff, 0);
200   __lsx_vstx(pred2, diff, dst_stride);
201   __lsx_vstx(pred4, diff, dst_stride2);
202   __lsx_vstx(pred6, diff, dst_stride3);
203   __lsx_vst(pred1, diff_tmp, 0);
204   __lsx_vstx(pred3, diff_tmp, dst_stride);
205   __lsx_vstx(pred5, diff_tmp, dst_stride2);
206   __lsx_vstx(pred7, diff_tmp, dst_stride3);
207 }
208
209 static void sub_blk_32x32_lsx(const uint8_t *src, int32_t src_stride,
210                               const uint8_t *pred, int32_t pred_stride,
211                               int16_t *diff, int32_t diff_stride) {
212   __m128i src0, src1, src2, src3, src4, src5, src6, src7;
213   __m128i pred0, pred1, pred2, pred3, pred4, pred5, pred6, pred7;
214   __m128i reg0, reg1, reg2, reg3, reg4, reg5, reg6, reg7;
215   __m128i tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
216   uint32_t loop_cnt;
217   int32_t src_stride2 = src_stride << 1;
218   int32_t pred_stride2 = pred_stride << 1;
219   int32_t src_stride3 = src_stride2 + src_stride;
220   int32_t pred_stride3 = pred_stride2 + pred_stride;
221   int32_t src_stride4 = src_stride2 << 1;
222   int32_t pred_stride4 = pred_stride2 << 1;
223
224   for (loop_cnt = 8; loop_cnt--;) {
225     const uint8_t *src_tmp = src + 16;
226     const uint8_t *pred_tmp = pred + 16;
227     DUP4_ARG2(__lsx_vld, src, 0, src_tmp, 0, pred, 0, pred_tmp, 0, src0, src1,
228               pred0, pred1);
229     DUP4_ARG2(__lsx_vldx, src, src_stride, src_tmp, src_stride, src,
230               src_stride2, src_tmp, src_stride2, src2, src3, src4, src5);
231     DUP4_ARG2(__lsx_vldx, src, src_stride3, src_tmp, src_stride3, pred,
232               pred_stride, pred_tmp, pred_stride, src6, src7, pred2, pred3);
233     DUP4_ARG2(__lsx_vldx, pred, pred_stride2, pred_tmp, pred_stride2, pred,
234               pred_stride3, pred_tmp, pred_stride3, pred4, pred5, pred6, pred7);
235     DUP4_ARG2(__lsx_vilvl_b, src0, pred0, src1, pred1, src2, pred2, src3, pred3,
236               reg0, reg2, reg4, reg6);
237     DUP4_ARG2(__lsx_vilvh_b, src0, pred0, src1, pred1, src2, pred2, src3, pred3,
238               reg1, reg3, reg5, reg7);
239     DUP4_ARG2(__lsx_vilvl_b, src4, pred4, src5, pred5, src6, pred6, src7, pred7,
240               tmp0, tmp2, tmp4, tmp6);
241     DUP4_ARG2(__lsx_vilvh_b, src4, pred4, src5, pred5, src6, pred6, src7, pred7,
242               tmp1, tmp3, tmp5, tmp7);
243     DUP4_ARG2(__lsx_vhsubw_hu_bu, reg0, reg0, reg1, reg1, reg2, reg2, reg3,
244               reg3, src0, src1, src2, src3);
245     DUP4_ARG2(__lsx_vhsubw_hu_bu, reg4, reg4, reg5, reg5, reg6, reg6, reg7,
246               reg7, src4, src5, src6, src7);
247     DUP4_ARG2(__lsx_vhsubw_hu_bu, tmp0, tmp0, tmp1, tmp1, tmp2, tmp2, tmp3,
248               tmp3, pred0, pred1, pred2, pred3);
249     DUP4_ARG2(__lsx_vhsubw_hu_bu, tmp4, tmp4, tmp5, tmp5, tmp6, tmp6, tmp7,
250               tmp7, pred4, pred5, pred6, pred7);
251     src += src_stride4;
252     pred += pred_stride4;
253     __lsx_vst(src0, diff, 0);
254     __lsx_vst(src1, diff, 16);
255     __lsx_vst(src2, diff, 32);
256     __lsx_vst(src3, diff, 48);
257     diff += diff_stride;
258     __lsx_vst(src4, diff, 0);
259     __lsx_vst(src5, diff, 16);
260     __lsx_vst(src6, diff, 32);
261     __lsx_vst(src7, diff, 48);
262     diff += diff_stride;
263     __lsx_vst(pred0, diff, 0);
264     __lsx_vst(pred1, diff, 16);
265     __lsx_vst(pred2, diff, 32);
266     __lsx_vst(pred3, diff, 48);
267     diff += diff_stride;
268     __lsx_vst(pred4, diff, 0);
269     __lsx_vst(pred5, diff, 16);
270     __lsx_vst(pred6, diff, 32);
271     __lsx_vst(pred7, diff, 48);
272     diff += diff_stride;
273   }
274 }
275
276 static void sub_blk_64x64_lsx(const uint8_t *src, int32_t src_stride,
277                               const uint8_t *pred, int32_t pred_stride,
278                               int16_t *diff, int32_t diff_stride) {
279   __m128i src0, src1, src2, src3, src4, src5, src6, src7;
280   __m128i pred0, pred1, pred2, pred3, pred4, pred5, pred6, pred7;
281   __m128i reg0, reg1, reg2, reg3, reg4, reg5, reg6, reg7;
282   __m128i tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
283   uint32_t loop_cnt;
284
285   for (loop_cnt = 32; loop_cnt--;) {
286     DUP4_ARG2(__lsx_vld, src, 0, src, 16, src, 32, src, 48, src0, src1, src2,
287               src3);
288     DUP4_ARG2(__lsx_vld, pred, 0, pred, 16, pred, 32, pred, 48, pred0, pred1,
289               pred2, pred3);
290     src += src_stride;
291     pred += pred_stride;
292     DUP4_ARG2(__lsx_vld, src, 0, src, 16, src, 32, src, 48, src4, src5, src6,
293               src7);
294     DUP4_ARG2(__lsx_vld, pred, 0, pred, 16, pred, 32, pred, 48, pred4, pred5,
295               pred6, pred7);
296     src += src_stride;
297     pred += pred_stride;
298
299     DUP4_ARG2(__lsx_vilvl_b, src0, pred0, src1, pred1, src2, pred2, src3, pred3,
300               reg0, reg2, reg4, reg6);
301     DUP4_ARG2(__lsx_vilvh_b, src0, pred0, src1, pred1, src2, pred2, src3, pred3,
302               reg1, reg3, reg5, reg7);
303     DUP4_ARG2(__lsx_vilvl_b, src4, pred4, src5, pred5, src6, pred6, src7, pred7,
304               tmp0, tmp2, tmp4, tmp6);
305     DUP4_ARG2(__lsx_vilvh_b, src4, pred4, src5, pred5, src6, pred6, src7, pred7,
306               tmp1, tmp3, tmp5, tmp7);
307     DUP4_ARG2(__lsx_vhsubw_hu_bu, reg0, reg0, reg1, reg1, reg2, reg2, reg3,
308               reg3, src0, src1, src2, src3);
309     DUP4_ARG2(__lsx_vhsubw_hu_bu, reg4, reg4, reg5, reg5, reg6, reg6, reg7,
310               reg7, src4, src5, src6, src7);
311     DUP4_ARG2(__lsx_vhsubw_hu_bu, tmp0, tmp0, tmp1, tmp1, tmp2, tmp2, tmp3,
312               tmp3, pred0, pred1, pred2, pred3);
313     DUP4_ARG2(__lsx_vhsubw_hu_bu, tmp4, tmp4, tmp5, tmp5, tmp6, tmp6, tmp7,
314               tmp7, pred4, pred5, pred6, pred7);
315     __lsx_vst(src0, diff, 0);
316     __lsx_vst(src1, diff, 16);
317     __lsx_vst(src2, diff, 32);
318     __lsx_vst(src3, diff, 48);
319     __lsx_vst(src4, diff, 64);
320     __lsx_vst(src5, diff, 80);
321     __lsx_vst(src6, diff, 96);
322     __lsx_vst(src7, diff, 112);
323     diff += diff_stride;
324     __lsx_vst(pred0, diff, 0);
325     __lsx_vst(pred1, diff, 16);
326     __lsx_vst(pred2, diff, 32);
327     __lsx_vst(pred3, diff, 48);
328     __lsx_vst(pred4, diff, 64);
329     __lsx_vst(pred5, diff, 80);
330     __lsx_vst(pred6, diff, 96);
331     __lsx_vst(pred7, diff, 112);
332     diff += diff_stride;
333   }
334 }
335
336 void vpx_subtract_block_lsx(int32_t rows, int32_t cols, int16_t *diff_ptr,
337                             ptrdiff_t diff_stride, const uint8_t *src_ptr,
338                             ptrdiff_t src_stride, const uint8_t *pred_ptr,
339                             ptrdiff_t pred_stride) {
340   if (rows == cols) {
341     switch (rows) {
342       case 4:
343         sub_blk_4x4_lsx(src_ptr, src_stride, pred_ptr, pred_stride, diff_ptr,
344                         diff_stride);
345         break;
346       case 8:
347         sub_blk_8x8_lsx(src_ptr, src_stride, pred_ptr, pred_stride, diff_ptr,
348                         diff_stride);
349         break;
350       case 16:
351         sub_blk_16x16_lsx(src_ptr, src_stride, pred_ptr, pred_stride, diff_ptr,
352                           diff_stride);
353         break;
354       case 32:
355         sub_blk_32x32_lsx(src_ptr, src_stride, pred_ptr, pred_stride, diff_ptr,
356                           diff_stride);
357         break;
358       case 64:
359         sub_blk_64x64_lsx(src_ptr, src_stride, pred_ptr, pred_stride, diff_ptr,
360                           diff_stride);
361         break;
362       default:
363         vpx_subtract_block_c(rows, cols, diff_ptr, diff_stride, src_ptr,
364                              src_stride, pred_ptr, pred_stride);
365         break;
366     }
367   } else {
368     vpx_subtract_block_c(rows, cols, diff_ptr, diff_stride, src_ptr, src_stride,
369                          pred_ptr, pred_stride);
370   }
371 }