[Mono] Pass SIMD type variables using SIMD registers (#86634)
authorFan Yang <52458914+fanyang-mono@users.noreply.github.com>
Thu, 25 May 2023 15:35:43 +0000 (11:35 -0400)
committerGitHub <noreply@github.com>
Thu, 25 May 2023 15:35:43 +0000 (11:35 -0400)
* Pass SIMD type variables in SIMD registers

* Disable support for 256/512 bit vectors.

---------

Co-authored-by: Zoltan Varga <vargaz@gmail.com>
src/mono/mono/mini/mini-arm64.c
src/mono/mono/mini/mini-arm64.h

index 69714fc..de3a390 100644 (file)
@@ -1530,15 +1530,26 @@ is_hfa (MonoType *t, int *out_nfields, int *out_esize, int *field_offsets)
 }
 
 static void
-add_valuetype (CallInfo *cinfo, ArgInfo *ainfo, MonoType *t)
+add_valuetype (CallInfo *cinfo, ArgInfo *ainfo, MonoType *t, gboolean is_return)
 {
        int i, size, align_size, nregs, nfields, esize;
        int field_offsets [16];
        guint32 align;
+       MonoClass *klass;
 
+       klass = mono_class_from_mono_type_internal (t);
        size = mini_type_stack_size_full (t, &align, cinfo->pinvoke);
        align_size = ALIGN_TO (size, 8);
 
+       if (m_class_is_simd_type (klass) && size <= 16 && !cinfo->pinvoke && !is_return && cinfo->fr < FP_PARAM_REGS) {
+               ainfo->storage = ArgInSIMDReg;
+               ainfo->reg = cinfo->fr;
+               ainfo->nregs = 1;
+               ainfo->size = size;
+               cinfo->fr ++;
+               return;
+       }
+
        nregs = align_size / 8;
        if (is_hfa (t, &nfields, &esize, field_offsets)) {
                /*
@@ -1594,7 +1605,7 @@ add_valuetype (CallInfo *cinfo, ArgInfo *ainfo, MonoType *t)
 }
 
 static void
-add_param (CallInfo *cinfo, ArgInfo *ainfo, MonoType *t)
+add_param (CallInfo *cinfo, ArgInfo *ainfo, MonoType *t, gboolean is_return)
 {
        MonoType *ptype;
 
@@ -1646,7 +1657,7 @@ add_param (CallInfo *cinfo, ArgInfo *ainfo, MonoType *t)
                break;
        case MONO_TYPE_VALUETYPE:
        case MONO_TYPE_TYPEDBYREF:
-               add_valuetype (cinfo, ainfo, ptype);
+               add_valuetype (cinfo, ainfo, ptype, is_return);
                break;
        case MONO_TYPE_VOID:
                ainfo->storage = ArgNone;
@@ -1661,7 +1672,7 @@ add_param (CallInfo *cinfo, ArgInfo *ainfo, MonoType *t)
                        ainfo->storage = ArgVtypeByRef;
                        ainfo->gsharedvt = TRUE;
                } else {
-                       add_valuetype (cinfo, ainfo, ptype);
+                       add_valuetype (cinfo, ainfo, ptype, is_return);
                }
                break;
        case MONO_TYPE_VAR:
@@ -1703,7 +1714,7 @@ get_call_info (MonoMemPool *mp, MonoMethodSignature *sig)
 #endif
 
        /* Return value */
-       add_param (cinfo, &cinfo->ret, sig->ret);
+       add_param (cinfo, &cinfo->ret, sig->ret, TRUE);
        if (cinfo->ret.storage == ArgVtypeByRef)
                cinfo->ret.reg = ARMREG_R8;
        /* Reset state */
@@ -1724,10 +1735,10 @@ get_call_info (MonoMemPool *mp, MonoMethodSignature *sig)
                        cinfo->gr = PARAM_REGS;
                        cinfo->fr = FP_PARAM_REGS;
                        /* Emit the signature cookie just before the implicit arguments */
-                       add_param (cinfo, &cinfo->sig_cookie, mono_get_int_type ());
+                       add_param (cinfo, &cinfo->sig_cookie, mono_get_int_type (), FALSE);
                }
 
-               add_param (cinfo, ainfo, sig->params [pindex]);
+               add_param (cinfo, ainfo, sig->params [pindex], FALSE);
                if (ainfo->storage == ArgVtypeByRef) {
                        /* Pass the argument address in the next register */
                        if (cinfo->gr >= PARAM_REGS) {
@@ -1749,7 +1760,7 @@ get_call_info (MonoMemPool *mp, MonoMethodSignature *sig)
                cinfo->gr = PARAM_REGS;
                cinfo->fr = FP_PARAM_REGS;
                /* Emit the signature cookie just before the implicit arguments */
-               add_param (cinfo, &cinfo->sig_cookie, mono_get_int_type ());
+               add_param (cinfo, &cinfo->sig_cookie, mono_get_int_type (), FALSE);
        }
 
        cinfo->stack_usage = ALIGN_TO (cinfo->stack_usage, MONO_ARCH_FRAME_ALIGNMENT);
@@ -2601,6 +2612,7 @@ mono_arch_allocate_vars (MonoCompile *cfg)
                        break;
                case ArgVtypeInIRegs:
                case ArgHFA:
+               case ArgInSIMDReg:
                        ins->opcode = OP_REGOFFSET;
                        ins->inst_basereg = cfg->frame_reg;
                        /* These arguments are saved to the stack in the prolog */
@@ -2806,21 +2818,6 @@ mono_arch_get_llvm_call_info (MonoCompile *cfg, MonoMethodSignature *sig)
                        break;
                }
                case ArgVtypeInIRegs:
-#if 0
-                       /* FIXME: the non-LLVM codegen should also pass arguments in registers or
-                        * else there could a mismatch when LLVM code calls non-LLVM code
-                        *
-                        * See https://github.com/dotnet/runtime/issues/73454
-                        */
-                       if ((t->type == MONO_TYPE_GENERICINST) && !cfg->full_aot && !sig->pinvoke) {
-                               MonoClass *klass = mono_class_from_mono_type_internal (t);
-                               if (mini_class_is_simd (cfg, klass)) {
-                                       lainfo->storage = LLVMArgVtypeInSIMDReg;
-                                       break;
-                               }
-                       }
-#endif
-
                        lainfo->storage = LLVMArgAsIArgs;
                        lainfo->nslots = ainfo->nregs;
                        break;
@@ -2839,6 +2836,9 @@ mono_arch_get_llvm_call_info (MonoCompile *cfg, MonoMethodSignature *sig)
                                lainfo->nslots = ainfo->size / 8;
                        }
                        break;
+               case ArgInSIMDReg:
+                       lainfo->storage = LLVMArgVtypeInSIMDReg;
+                       break;
                default:
                        g_assert_not_reached ();
                        break;
@@ -3001,6 +3001,7 @@ mono_arch_emit_call (MonoCompile *cfg, MonoCallInst *call)
                case ArgVtypeByRef:
                case ArgVtypeByRefOnStack:
                case ArgVtypeOnStack:
+               case ArgInSIMDReg:
                case ArgHFA: {
                        MonoInst *ins;
                        guint32 align;
@@ -3110,6 +3111,15 @@ mono_arch_emit_outarg_vt (MonoCompile *cfg, MonoInst *ins, MonoInst *src)
                        MONO_EMIT_NEW_STORE_MEMBASE (cfg, OP_STOREI8_MEMBASE_REG, ARMREG_SP, ainfo->offset + (i * 8), load->dreg);
                }
                break;
+       case ArgInSIMDReg:
+               MONO_INST_NEW (cfg, load, OP_LOADX_MEMBASE);
+               load->dreg = mono_alloc_ireg (cfg);
+               load->inst_basereg = src->dreg;
+               load->inst_offset = 0;
+               load->klass = src->klass;
+               MONO_ADD_INS (cfg->cbb, load);
+               add_outarg_reg (cfg, call, ArgInFReg, ainfo->reg, load);
+               break;
        default:
                g_assert_not_reached ();
                break;
@@ -5584,6 +5594,9 @@ emit_move_args (MonoCompile *cfg, guint8 *code)
                                                code = emit_strfpx (code, ainfo->reg + part, ins->inst_basereg, ins->inst_offset + ainfo->foffsets [part]);
                                }
                                break;
+                       case ArgInSIMDReg:
+                               code = emit_strfpq (code, ainfo->reg, ins->inst_basereg, ins->inst_offset);
+                               break;
                        default:
                                g_assert_not_reached ();
                                break;
index 1719588..f0bba65 100644 (file)
@@ -222,11 +222,13 @@ typedef enum {
        ArgOnStackR4,
        /*
         * Vtype passed in consecutive int registers.
-        * ainfo->reg is the firs register,
+        * ainfo->reg is the first register,
         * ainfo->nregs is the number of registers,
         * ainfo->size is the size of the structure.
         */
        ArgVtypeInIRegs,
+       /* SIMD arg in NEON register */
+       ArgInSIMDReg,
        ArgVtypeByRef,
        ArgVtypeByRefOnStack,
        ArgVtypeOnStack,