From 54747fe24a9c1133b526843e1ba734652c7d5555 Mon Sep 17 00:00:00 2001 From: Shivraj Patil Date: Thu, 22 Sep 2016 17:25:46 +0530 Subject: [PATCH] DGEMM function split and data prefech Signed-off-by: Shivraj Patil --- kernel/mips/dgemm_kernel_8x4_msa.c | 710 +++++++------------------------------ 1 file changed, 137 insertions(+), 573 deletions(-) diff --git a/kernel/mips/dgemm_kernel_8x4_msa.c b/kernel/mips/dgemm_kernel_8x4_msa.c index 9286e74..7636827 100644 --- a/kernel/mips/dgemm_kernel_8x4_msa.c +++ b/kernel/mips/dgemm_kernel_8x4_msa.c @@ -28,30 +28,26 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "common.h" #include "macros_msa.h" -int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, - FLOAT *C, BLASLONG ldc -#ifdef TRMMKERNEL - , BLASLONG offset +#define ENABLE_PREFETCH + +#ifdef ENABLE_PREFETCH +inline static void prefetch_load_lf(unsigned char *src) { + __asm__ __volatile__("pref 0, 0(%[src]) \n\t" : : [src] "r"(src)); +} #endif - ) + +static void __attribute__ ((noinline)) +dgemmkernel_8x4_core_msa(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, + FLOAT *C, BLASLONG ldc) { - BLASLONG i, j, l, temp; -#if defined(TRMMKERNEL) - BLASLONG off; -#endif + BLASLONG i, j, l; FLOAT *pc0, *pc1, *pc2, *pc3, *pa0, *pb0; - FLOAT tmp0, tmp1, tmp2, tmp3; - FLOAT a0, b0, b1, b2, b3; v2f64 v_alpha = {alpha, alpha}; v2f64 src_a0, src_a1, src_a2, src_a3, src_b, src_b0, src_b1; v2f64 dst0, dst1, dst2, dst3, dst4, dst5, dst6, dst7; v2f64 res0, res1, res2, res3, res4, res5, res6, res7; v2f64 res8, res9, res10, res11, res12, res13, res14, res15; -#if defined(TRMMKERNEL) && !defined(LEFT) - off = -offset; -#endif - for (j = (n >> 2); j--;) { pc0 = C; @@ -61,30 +57,19 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pa0 = A; -#if defined(TRMMKERNEL) && defined(LEFT) - off = offset; -#endif - for (i = (m >> 3); i--;) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) pb0 = B; -#else - pa0 += off * 8; - pb0 = B + off * 4; -#endif -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 8; // number of values in A -#else - temp = off + 4; // number of values in B -#endif -#else - pb0 = B; - temp = k; +#ifdef ENABLE_PREFETCH + prefetch_load_lf((unsigned char *)(pa0 + 8)); + prefetch_load_lf((unsigned char *)(pa0 + 12)); + prefetch_load_lf((unsigned char *)(pa0 + 16)); + prefetch_load_lf((unsigned char *)(pa0 + 20)); + + prefetch_load_lf((unsigned char *)(pb0 + 4)); + prefetch_load_lf((unsigned char *)(pb0 + 8)); + prefetch_load_lf((unsigned char *)(pb0 + 12)); #endif LD_DP4_INC(pa0, 2, src_a0, src_a1, src_a2, src_a3); @@ -114,8 +99,14 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res14 = src_a2 * src_b; res15 = src_a3 * src_b; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { +#ifdef ENABLE_PREFETCH + prefetch_load_lf((unsigned char *)(pa0 + 16)); + prefetch_load_lf((unsigned char *)(pa0 + 20)); + prefetch_load_lf((unsigned char *)(pa0 + 24)); + prefetch_load_lf((unsigned char *)(pa0 + 28)); +#endif LD_DP4_INC(pa0, 2, src_a0, src_a1, src_a2, src_a3); LD_DP2_INC(pb0, 2, src_b0, src_b1); @@ -144,6 +135,10 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res15 += src_a3 * src_b; LD_DP4_INC(pa0, 2, src_a0, src_a1, src_a2, src_a3); +#ifdef ENABLE_PREFETCH + prefetch_load_lf((unsigned char *)(pb0 + 8)); + prefetch_load_lf((unsigned char *)(pb0 + 12)); +#endif LD_DP2_INC(pb0, 2, src_b0, src_b1); src_b = (v2f64) __msa_ilvr_d((v2i64) src_b0, (v2i64) src_b0); @@ -171,7 +166,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res15 += src_a3 * src_b; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { LD_DP4_INC(pa0, 2, src_a0, src_a1, src_a2, src_a3); LD_DP2_INC(pb0, 2, src_b0, src_b1); @@ -201,16 +196,12 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res15 += src_a3 * src_b; } -#if defined(TRMMKERNEL) - dst0 = res0 * v_alpha; - dst1 = res1 * v_alpha; - dst2 = res2 * v_alpha; - dst3 = res3 * v_alpha; - dst4 = res4 * v_alpha; - dst5 = res5 * v_alpha; - dst6 = res6 * v_alpha; - dst7 = res7 * v_alpha; -#else +#ifdef ENABLE_PREFETCH + prefetch_load_lf((unsigned char *)(pc0 + 8)); + prefetch_load_lf((unsigned char *)(pc1 + 8)); + prefetch_load_lf((unsigned char *)(pc2 + 8)); + prefetch_load_lf((unsigned char *)(pc3 + 8)); +#endif LD_DP4(pc0, 2, dst0, dst1, dst2, dst3); LD_DP4(pc1, 2, dst4, dst5, dst6, dst7); @@ -222,20 +213,10 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, dst5 += res5 * v_alpha; dst6 += res6 * v_alpha; dst7 += res7 * v_alpha; -#endif + ST_DP4_INC(dst0, dst1, dst2, dst3, pc0, 2); ST_DP4_INC(dst4, dst5, dst6, dst7, pc1, 2); -#if defined(TRMMKERNEL) - dst0 = res8 * v_alpha; - dst1 = res9 * v_alpha; - dst2 = res10 * v_alpha; - dst3 = res11 * v_alpha; - dst4 = res12 * v_alpha; - dst5 = res13 * v_alpha; - dst6 = res14 * v_alpha; - dst7 = res15 * v_alpha; -#else LD_DP4(pc2, 2, dst0, dst1, dst2, dst3); LD_DP4(pc3, 2, dst4, dst5, dst6, dst7); @@ -247,50 +228,44 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, dst5 += res13 * v_alpha; dst6 += res14 * v_alpha; dst7 += res15 * v_alpha; -#endif ST_DP4_INC(dst0, dst1, dst2, dst3, pc2, 2); ST_DP4_INC(dst4, dst5, dst6, dst7, pc3, 2); + } -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 8; // number of values in A -#else - temp -= 4; // number of values in B -#endif - pa0 += temp * 8; - pb0 += temp * 4; -#endif + l = (k << 2); + B = B + l; + i = (ldc << 2); + C = C + i; + } +} -#ifdef LEFT - off += 8; // number of values in A -#endif -#endif - } +static void __attribute__ ((noinline)) +dgemmkernel_7x4_core_msa(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, + FLOAT *C, BLASLONG ldc) +{ + BLASLONG i, j, l; + FLOAT *pc0, *pc1, *pc2, *pc3, *pa0, *pb0; + FLOAT tmp0, tmp1, tmp2, tmp3; + FLOAT a0, b0, b1, b2, b3; + v2f64 v_alpha = {alpha, alpha}; + v2f64 src_a0, src_a1, src_b, src_b0, src_b1; + v2f64 dst0, dst1, dst2, dst3, dst4, dst5, dst6, dst7; + v2f64 res0, res1, res2, res3, res4, res5, res6, res7; + + for (j = (n >> 2); j--;) + { + + pc0 = C + 8 * (m >> 3); + pc1 = pc0 + ldc; + pc2 = pc1 + ldc; + pc3 = pc2 + ldc; + + pa0 = A + k * 8 * (m >> 3); if (m & 4) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) pb0 = B; -#else - pa0 += off * 4; - pb0 = B + off * 4; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 4; // number of values in A -#else - temp = off + 4; // number of values in B -#endif -#else - pb0 = B; - temp = k; -#endif LD_DP2_INC(pa0, 2, src_a0, src_a1); LD_DP2_INC(pb0, 2, src_b0, src_b1); @@ -311,7 +286,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res6 = src_a0 * src_b; res7 = src_a1 * src_b; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { LD_DP2_INC(pa0, 2, src_a0, src_a1); LD_DP2_INC(pb0, 2, src_b0, src_b1); @@ -352,7 +327,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res7 += src_a1 * src_b; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { LD_DP2_INC(pa0, 2, src_a0, src_a1); LD_DP2_INC(pb0, 2, src_b0, src_b1); @@ -374,16 +349,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res7 += src_a1 * src_b; } -#if defined(TRMMKERNEL) - dst0 = res0 * v_alpha; - dst1 = res1 * v_alpha; - dst2 = res2 * v_alpha; - dst3 = res3 * v_alpha; - dst4 = res4 * v_alpha; - dst5 = res5 * v_alpha; - dst6 = res6 * v_alpha; - dst7 = res7 * v_alpha; -#else LD_DP2(pc0, 2, dst0, dst1); LD_DP2(pc1, 2, dst2, dst3); LD_DP2(pc2, 2, dst4, dst5); @@ -397,51 +362,16 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, dst5 += res5 * v_alpha; dst6 += res6 * v_alpha; dst7 += res7 * v_alpha; -#endif + ST_DP2_INC(dst0, dst1, pc0, 2); ST_DP2_INC(dst2, dst3, pc1, 2); ST_DP2_INC(dst4, dst5, pc2, 2); ST_DP2_INC(dst6, dst7, pc3, 2); - -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 4; // number of values in A -#else - temp -= 4; // number of values in B -#endif - pa0 += temp * 4; - pb0 += temp * 4; -#endif - -#ifdef LEFT - off += 4; // number of values in A -#endif -#endif } if (m & 2) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) pb0 = B; -#else - pa0 += off * 2; - pb0 = B + off * 4; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 2; // number of values in A -#else - temp = off + 4; // number of values in B -#endif -#else - pb0 = B; - temp = k; -#endif src_a0 = LD_DP(pa0); pa0 += 2; @@ -459,7 +389,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, src_b = (v2f64) __msa_ilvl_d((v2i64) src_b1, (v2i64) src_b1); res3 = src_a0 * src_b; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { src_a0 = LD_DP(pa0); pa0 += 2; @@ -494,7 +424,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res3 += src_a0 * src_b; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { src_a0 = LD_DP(pa0); pa0 += 2; @@ -513,12 +443,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res3 += src_a0 * src_b; } -#if defined(TRMMKERNEL) - dst0 = res0 * v_alpha; - dst1 = res1 * v_alpha; - dst2 = res2 * v_alpha; - dst3 = res3 * v_alpha; -#else dst0 = LD_DP(pc0); dst1 = LD_DP(pc1); dst2 = LD_DP(pc2); @@ -528,28 +452,12 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, dst1 += res1 * v_alpha; dst2 += res2 * v_alpha; dst3 += res3 * v_alpha; -#endif + ST_DP(dst0, pc0); ST_DP(dst1, pc1); ST_DP(dst2, pc2); ST_DP(dst3, pc3); -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 2; // number of values in A -#else - temp -= 4; // number of values in B -#endif - pa0 += temp * 2; - pb0 += temp * 4; -#endif - -#ifdef LEFT - off += 2; // number of values in A -#endif -#endif pc0 += 2; pc1 += 2; pc2 += 2; @@ -558,25 +466,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, if (m & 1) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) pb0 = B; -#else - pa0 += off * 1; - pb0 = B + off * 4; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 1; // number of values in A -#else - temp = off + 4; // number of values in B -#endif -#else - pb0 = B; - temp = k; -#endif a0 = pa0[0]; b0 = pb0[0]; @@ -594,7 +484,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pa0 += 1; pb0 += 4; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { a0 = pa0[0]; b0 = pb0[0]; @@ -629,7 +519,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 4; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { a0 = pa0[0]; b0 = pb0[0]; @@ -653,34 +543,10 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, tmp2 = alpha * tmp2; tmp3 = alpha * tmp3; -#if defined(TRMMKERNEL) - pc0[0] = tmp0; - pc1[0] = tmp1; - pc2[0] = tmp2; - pc3[0] = tmp3; -#else pc0[0] += tmp0; pc1[0] += tmp1; pc2[0] += tmp2; pc3[0] += tmp3; -#endif - -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 1; // number of values in A -#else - temp -= 4; // number of values in B -#endif - pa0 += temp * 1; - pb0 += temp * 4; -#endif - -#ifdef LEFT - off += 1; // number of values in A -#endif -#endif pc0 += 1; pc1 += 1; @@ -688,15 +554,25 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pc3 += 1; } -#if defined(TRMMKERNEL) && !defined(LEFT) - off += 4; // number of values in A -#endif - l = (k << 2); B = B + l; i = (ldc << 2); C = C + i; } +} + +static void __attribute__ ((noinline)) +dgemmkernel_8x4_non_core_msa(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, + FLOAT *C, BLASLONG ldc) +{ + BLASLONG i, l; + FLOAT *pc0, *pc1, *pa0, *pb0; + FLOAT tmp0, tmp1; + FLOAT a0, b0, b1; + v2f64 v_alpha = {alpha, alpha}; + v2f64 src_a0, src_a1, src_a2, src_a3, src_b, src_b0; + v2f64 dst0, dst1, dst2, dst3, dst4, dst5, dst6, dst7; + v2f64 res0, res1, res2, res3, res4, res5, res6, res7; if (n & 2) { @@ -705,32 +581,9 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pa0 = A; -#if defined(TRMMKERNEL) && defined(LEFT) - off = offset; -#endif - for (i = (m >> 3); i--;) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) pb0 = B; -#else - pa0 += off * 8; - pb0 = B + off * 2; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 8; // number of values in A -#else - temp = off + 2; // number of values in B -#endif -#else - pb0 = B; - temp = k; -#endif - LD_DP4_INC(pa0, 2, src_a0, src_a1, src_a2, src_a3); src_b0 = LD_DP(pb0); @@ -748,7 +601,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res6 = src_a2 * src_b; res7 = src_a3 * src_b; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { LD_DP4_INC(pa0, 2, src_a0, src_a1, src_a2, src_a3); src_b0 = LD_DP(pb0); @@ -783,7 +636,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res7 += src_a3 * src_b; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { LD_DP4_INC(pa0, 2, src_a0, src_a1, src_a2, src_a3); src_b0 = LD_DP(pb0); @@ -802,16 +655,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res7 += src_a3 * src_b; } -#if defined(TRMMKERNEL) - dst0 = res0 * v_alpha; - dst1 = res1 * v_alpha; - dst2 = res2 * v_alpha; - dst3 = res3 * v_alpha; - dst4 = res4 * v_alpha; - dst5 = res5 * v_alpha; - dst6 = res6 * v_alpha; - dst7 = res7 * v_alpha; -#else LD_DP4(pc0, 2, dst0, dst1, dst2, dst3); LD_DP4(pc1, 2, dst4, dst5, dst6, dst7); @@ -823,49 +666,14 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, dst5 += res5 * v_alpha; dst6 += res6 * v_alpha; dst7 += res7 * v_alpha; -#endif + ST_DP4_INC(dst0, dst1, dst2, dst3, pc0, 2); ST_DP4_INC(dst4, dst5, dst6, dst7, pc1, 2); - -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 8; // number of values in A -#else - temp -= 2; // number of values in B -#endif - pa0 += temp * 8; - pb0 += temp * 2; -#endif - -#ifdef LEFT - off += 8; // number of values in A -#endif -#endif } if (m & 4) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) pb0 = B; -#else - pa0 += off * 4; - pb0 = B + off * 2; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 4; // number of values in A -#else - temp = off + 2; // number of values in B -#endif -#else - pb0 = B; - temp = k; -#endif LD_DP2_INC(pa0, 2, src_a0, src_a1); src_b0 = LD_DP(pb0); @@ -879,7 +687,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res2 = src_a0 * src_b; res3 = src_a1 * src_b; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { LD_DP2_INC(pa0, 2, src_a0, src_a1); src_b0 = LD_DP(pb0); @@ -906,7 +714,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res3 += src_a1 * src_b; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { LD_DP2_INC(pa0, 2, src_a0, src_a1); src_b0 = LD_DP(pb0); @@ -921,12 +729,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res3 += src_a1 * src_b; } -#if defined(TRMMKERNEL) - dst0 = res0 * v_alpha; - dst1 = res1 * v_alpha; - dst2 = res2 * v_alpha; - dst3 = res3 * v_alpha; -#else LD_DP2(pc0, 2, dst0, dst1); LD_DP2(pc1, 2, dst2, dst3); @@ -934,49 +736,14 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, dst1 += res1 * v_alpha; dst2 += res2 * v_alpha; dst3 += res3 * v_alpha; -#endif + ST_DP2_INC(dst0, dst1, pc0, 2); ST_DP2_INC(dst2, dst3, pc1, 2); - -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 4; // number of values in A -#else - temp -= 2; // number of values in B -#endif - pa0 += temp * 4; - pb0 += temp * 2; -#endif - -#ifdef LEFT - off += 4; // number of values in A -#endif -#endif } if (m & 2) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) pb0 = B; -#else - pa0 += off * 2; - pb0 = B + off * 2; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 2; // number of values in A -#else - temp = off + 2; // number of values in B -#endif -#else - pb0 = B; - temp = k; -#endif src_a0 = LD_DP(pa0); pa0 += 2; @@ -989,7 +756,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, src_b = (v2f64) __msa_ilvl_d((v2i64) src_b0, (v2i64) src_b0); res1 = src_a0 * src_b; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { src_a0 = LD_DP(pa0); pa0 += 2; @@ -1014,7 +781,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res1 += src_a0 * src_b; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { src_a0 = LD_DP(pa0); pa0 += 2; @@ -1028,60 +795,22 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, res1 += src_a0 * src_b; } -#if defined(TRMMKERNEL) - dst0 = res0 * v_alpha; - dst1 = res1 * v_alpha; -#else dst0 = LD_DP(pc0); dst1 = LD_DP(pc1); dst0 += res0 * v_alpha; dst1 += res1 * v_alpha; -#endif + ST_DP(dst0, pc0); ST_DP(dst1, pc1); -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 2; // number of values in A -#else - temp -= 2; // number of values in B -#endif - pa0 += temp * 2; - pb0 += temp * 2; -#endif - -#ifdef LEFT - off += 2; // number of values in A -#endif -#endif pc0 += 2; pc1 += 2; } if (m & 1) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - pb0 = B; -#else - pa0 += off * 1; - pb0 = B + off * 2; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 1; // number of values in A -#else - temp = off + 2; // number of values in B -#endif -#else pb0 = B; - temp = k; -#endif a0 = pa0[0]; b0 = pb0[0]; @@ -1093,7 +822,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pa0 += 1; pb0 += 2; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { a0 = pa0[0]; b0 = pb0[0]; @@ -1116,7 +845,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 2; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { a0 = pa0[0]; b0 = pb0[0]; @@ -1132,39 +861,13 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, tmp0 = alpha * tmp0; tmp1 = alpha * tmp1; -#if defined(TRMMKERNEL) - pc0[0] = tmp0; - pc1[0] = tmp1; -#else pc0[0] += tmp0; pc1[0] += tmp1; -#endif - -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 1; // number of values in A -#else - temp -= 2; // number of values in B -#endif - pa0 += temp * 1; - pb0 += temp * 2; -#endif - -#ifdef LEFT - off += 1; // number of values in A -#endif -#endif pc0 += 1; pc1 += 1; } -#if defined(TRMMKERNEL) && !defined(LEFT) - off += 2; // number of values in A -#endif - l = (k << 1); B = B + l; i = (ldc << 1); @@ -1176,31 +879,9 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pc0 = C; pa0 = A; -#if defined(TRMMKERNEL) && defined(LEFT) - off = offset; -#endif - for (i = (m >> 3); i--;) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) pb0 = B; -#else - pa0 += off * 8; - pb0 = B + off * 1; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 8; // number of values in A -#else - temp = off + 1; // number of values in B -#endif -#else - pb0 = B; - temp = k; -#endif LD_DP4_INC(pa0, 2, src_a0, src_a1, src_a2, src_a3); src_b[0] = pb0[0]; @@ -1213,7 +894,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 1; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { LD_DP4_INC(pa0, 2, src_a0, src_a1, src_a2, src_a3); src_b[0] = pb0[0]; @@ -1238,7 +919,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 1; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { LD_DP4_INC(pa0, 2, src_a0, src_a1, src_a2, src_a3); src_b[0] = pb0[0]; @@ -1252,60 +933,19 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 1; } -#if defined(TRMMKERNEL) - dst0 = res0 * v_alpha; - dst1 = res1 * v_alpha; - dst2 = res2 * v_alpha; - dst3 = res3 * v_alpha; -#else LD_DP4(pc0, 2, dst0, dst1, dst2, dst3); dst0 += res0 * v_alpha; dst1 += res1 * v_alpha; dst2 += res2 * v_alpha; dst3 += res3 * v_alpha; -#endif - ST_DP4_INC(dst0, dst1, dst2, dst3, pc0, 2); - -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 8; // number of values in A -#else - temp -= 1; // number of values in B -#endif - pa0 += temp * 8; - pb0 += temp * 1; -#endif -#ifdef LEFT - off += 8; // number of values in A -#endif -#endif + ST_DP4_INC(dst0, dst1, dst2, dst3, pc0, 2); } if (m & 4) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) pb0 = B; -#else - pa0 += off * 4; - pb0 = B + off * 1; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 4; // number of values in A -#else - temp = off + 1; // number of values in B -#endif -#else - pb0 = B; - temp = k; -#endif LD_DP2_INC(pa0, 2, src_a0, src_a1); src_b[0] = pb0[0]; @@ -1316,7 +956,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 1; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { LD_DP2_INC(pa0, 2, src_a0, src_a1); src_b[0] = pb0[0]; @@ -1337,7 +977,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 1; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { LD_DP2_INC(pa0, 2, src_a0, src_a1); src_b[0] = pb0[0]; @@ -1349,56 +989,17 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 1; } -#if defined(TRMMKERNEL) - dst0 = res0 * v_alpha; - dst1 = res1 * v_alpha; -#else LD_DP2(pc0, 2, dst0, dst1); dst0 += res0 * v_alpha; dst1 += res1 * v_alpha; -#endif - ST_DP2_INC(dst0, dst1, pc0, 2); -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 4; // number of values in A -#else - temp -= 1; // number of values in B -#endif - pa0 += temp * 4; - pb0 += temp * 1; -#endif - -#ifdef LEFT - off += 4; // number of values in A -#endif -#endif + ST_DP2_INC(dst0, dst1, pc0, 2); } if (m & 2) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - pb0 = B; -#else - pa0 += off * 2; - pb0 = B + off * 1; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 2; // number of values in A -#else - temp = off + 1; // number of values in B -#endif -#else pb0 = B; - temp = k; -#endif src_a0 = LD_DP(pa0); src_b[0] = pb0[0]; @@ -1409,7 +1010,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pa0 += 2; pb0 += 1; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { src_a0 = LD_DP(pa0); src_b[0] = pb0[0]; @@ -1430,7 +1031,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 1; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { src_a0 = LD_DP(pa0); src_b[0] = pb0[0]; @@ -1442,55 +1043,18 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 1; } -#if defined(TRMMKERNEL) - dst0 = res0 * v_alpha; -#else dst0 = LD_DP(pc0); dst0 += res0 * v_alpha; -#endif - ST_DP(dst0, pc0); -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 2; // number of values in A -#else - temp -= 1; // number of values in B -#endif - pa0 += temp * 2; - pb0 += temp * 1; -#endif + ST_DP(dst0, pc0); -#ifdef LEFT - off += 2; // number of values in A -#endif -#endif pc0 += 2; } if (m & 1) { -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) pb0 = B; -#else - pa0 += off * 1; - pb0 = B + off * 1; -#endif - -#if (defined(LEFT) && !defined(TRANSA)) || (!defined(LEFT) && defined(TRANSA)) - temp = k - off; -#elif defined(LEFT) - temp = off + 1; // number of values in A -#else - temp = off + 1; // number of values in B -#endif -#else - pb0 = B; - temp = k; -#endif a0 = pa0[0]; b0 = pb0[0]; @@ -1499,7 +1063,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pa0 += 1; pb0 += 1; - for (l = ((temp - 1) >> 1); l--;) + for (l = ((k - 1) >> 1); l--;) { a0 = pa0[0]; b0 = pb0[0]; @@ -1516,7 +1080,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 1; } - if ((temp - 1) & 1) + if ((k - 1) & 1) { a0 = pa0[0]; b0 = pb0[0]; @@ -1526,41 +1090,41 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, pb0 += 1; } -#if defined(TRMMKERNEL) - pc0[0] = alpha * tmp0; -#else pc0[0] += alpha * tmp0; -#endif - -#if defined(TRMMKERNEL) -#if (defined(LEFT) && defined(TRANSA)) || (!defined(LEFT) && !defined(TRANSA)) - temp = k - off; -#ifdef LEFT - temp -= 1; // number of values in A -#else - temp -= 1; // number of values in B -#endif - pa0 += temp * 1; - pb0 += temp * 1; -#endif - -#ifdef LEFT - off += 1; // number of values in A -#endif -#endif pc0 += 1; } -#if defined(TRMMKERNEL) && !defined(LEFT) - off += 1; // number of values in A -#endif - l = (k << 0); B = B + l; i = (ldc << 0); C = C + i; } +} + +int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, FLOAT *A, FLOAT *B, + FLOAT *C, BLASLONG ldc +#ifdef TRMMKERNEL + , BLASLONG offset +#endif + ) +{ + if (n >> 2) + { + if (m >> 3) + dgemmkernel_8x4_core_msa(m, n, k, alpha, A, B, C, ldc); + + if (m & 7) + dgemmkernel_7x4_core_msa(m, n, k, alpha, A, B, C, ldc); + } + + if (n & 3) + { + B = B + (k << 2) * (n >> 2); + C = C + (ldc << 2) * (n >> 2); + + dgemmkernel_8x4_non_core_msa(m, n, k, alpha, A, B, C, ldc); + } return 0; } -- 2.7.4