[Arm64] Support table-driven code generation for scalar intrinsics (#447)
authorEgor Chesakov <Egor.Chesakov@microsoft.com>
Tue, 3 Dec 2019 00:16:37 +0000 (16:16 -0800)
committerGitHub <noreply@github.com>
Tue, 3 Dec 2019 00:16:37 +0000 (16:16 -0800)
* Define HWIntrinsic class to incapsulate all the initialization shared between table-driven and special intrinsics in jit/hwintrinsiccodegenarm64.cpp

src/coreclr/src/jit/hwintrinsiccodegenarm64.cpp

index 3e9082b..f793bb7 100644 (file)
 #include "gcinfo.h"
 #include "gcinfoencoder.h"
 
-//------------------------------------------------------------------------
-// genIsTableDrivenHWIntrinsic:
-//
-// Arguments:
-//    category - category of a HW intrinsic
-//
-// Return Value:
-//    returns true if this category can be table-driven in CodeGen
-//
-static bool genIsTableDrivenHWIntrinsic(NamedIntrinsic intrinsicId, HWIntrinsicCategory category)
+struct HWIntrinsic final
 {
-    // TODO-Arm64-Cleanup - make more categories to the table-driven framework
-    const bool tableDrivenCategory =
-        (category != HW_Category_Special) && (category != HW_Category_Scalar) && (category != HW_Category_Helper);
-    const bool tableDrivenFlag =
-        !HWIntrinsicInfo::GeneratesMultipleIns(intrinsicId) && !HWIntrinsicInfo::HasSpecialCodegen(intrinsicId);
-    return tableDrivenCategory && tableDrivenFlag;
-}
+    HWIntrinsic(const GenTreeHWIntrinsic* node)
+        : op1(nullptr), op2(nullptr), op3(nullptr), numOperands(0), baseType(TYP_UNDEF)
+    {
+        assert(node != nullptr);
 
-//------------------------------------------------------------------------
-// genHWIntrinsic: Generates the code for a given hardware intrinsic node.
-//
-// Arguments:
-//    node - The hardware intrinsic node
-//
-void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
-{
-    NamedIntrinsic      intrinsicId = node->gtHWIntrinsicId;
-    HWIntrinsicCategory category    = HWIntrinsicInfo::lookupCategory(intrinsicId);
+        id       = node->gtHWIntrinsicId;
+        category = HWIntrinsicInfo::lookupCategory(id);
+
+        assert(HWIntrinsicInfo::RequiresCodegen(id));
+
+        InitializeOperands(node);
+        InitializeBaseType(node);
+    }
 
-    assert(HWIntrinsicInfo::RequiresCodegen(intrinsicId));
+    bool IsTableDriven() const
+    {
+        // TODO-Arm64-Cleanup - make more categories to the table-driven framework
+        bool isTableDrivenCategory = (category != HW_Category_Special) && (category != HW_Category_Helper);
+        bool isTableDrivenFlag = !HWIntrinsicInfo::GeneratesMultipleIns(id) && !HWIntrinsicInfo::HasSpecialCodegen(id);
+
+        return isTableDrivenCategory && isTableDrivenFlag;
+    }
 
-    if (genIsTableDrivenHWIntrinsic(intrinsicId, category))
+    NamedIntrinsic      id;
+    HWIntrinsicCategory category;
+    GenTree*            op1;
+    GenTree*            op2;
+    GenTree*            op3;
+    int                 numOperands;
+    var_types           baseType;
+
+private:
+    void InitializeOperands(const GenTreeHWIntrinsic* node)
     {
-        InstructionSet isa     = HWIntrinsicInfo::lookupIsa(intrinsicId);
-        int            ival    = HWIntrinsicInfo::lookupIval(intrinsicId);
-        int            numArgs = HWIntrinsicInfo::lookupNumArgs(node);
+        op1 = node->gtGetOp1();
+        op2 = node->gtGetOp2();
 
-        assert(numArgs >= 0);
+        assert(op1 != nullptr);
 
-        GenTree*  op1        = node->gtGetOp1();
-        GenTree*  op2        = node->gtGetOp2();
-        regNumber targetReg  = node->GetRegNum();
-        var_types targetType = node->TypeGet();
-        var_types baseType   = node->gtSIMDBaseType;
+        if (op1->OperIsList())
+        {
+            assert(op2 == nullptr);
 
-        instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType);
-        assert(ins != INS_invalid);
+            GenTreeArgList* list = op1->AsArgList();
+            op1                  = list->Current();
+            list                 = list->Rest();
+            op2                  = list->Current();
+            list                 = list->Rest();
+            op3                  = list->Current();
 
-        regNumber op1Reg   = REG_NA;
-        regNumber op2Reg   = REG_NA;
-        emitter*  emit     = GetEmitter();
-        emitAttr  emitSize = EA_ATTR(node->gtSIMDSize);
-        insOpts   opt      = INS_OPTS_NONE;
+            assert(list->Rest() == nullptr);
 
-        if (category == HW_Category_SIMDScalar)
+            numOperands = 3;
+        }
+        else if (op2 != nullptr)
         {
-            emitSize = emitActualTypeSize(baseType);
+            numOperands = 2;
         }
         else
         {
-            opt = genGetSimdInsOpt(emitSize, baseType);
+            numOperands = 1;
         }
+    }
 
-        assert(emitSize != 0);
-        genConsumeOperands(node);
+    void InitializeBaseType(const GenTreeHWIntrinsic* node)
+    {
+        baseType = node->gtSIMDBaseType;
 
-        switch (numArgs)
+        if (baseType == TYP_UNKNOWN)
         {
-            case 1:
-            {
-                assert(op1 != nullptr);
-                assert(op2 == nullptr);
+            assert(category == HW_Category_Scalar);
 
-                op1Reg = op1->GetRegNum();
-                emit->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
-                break;
-            }
-
-            case 2:
+            if (HWIntrinsicInfo::BaseTypeFromFirstArg(id))
             {
                 assert(op1 != nullptr);
-                assert(op2 != nullptr);
-
-                op1Reg = op1->GetRegNum();
-                op2Reg = op2->GetRegNum();
-
-                emit->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
-                break;
+                baseType = op1->TypeGet();
             }
-
-            case 3:
+            else if (HWIntrinsicInfo::BaseTypeFromSecondArg(id))
             {
-                assert(op1 != nullptr);
-                assert(op2 == nullptr);
-
-                GenTreeArgList* argList = op1->AsArgList();
-                op1                     = argList->Current();
-                op1Reg                  = op1->GetRegNum();
-
-                argList = argList->Rest();
-                op2     = argList->Current();
-                op2Reg  = op2->GetRegNum();
-
-                argList          = argList->Rest();
-                GenTree*  op3    = argList->Current();
-                regNumber op3Reg = op3->GetRegNum();
-
-                if (targetReg != op1Reg)
-                {
-                    emit->emitIns_R_R(INS_mov, emitSize, targetReg, op1Reg);
-                }
-                emit->emitIns_R_R_R(ins, emitSize, targetReg, op2Reg, op3Reg, opt);
-                break;
+                assert(op2 != nullptr);
+                baseType = op2->TypeGet();
             }
-
-            default:
+            else
             {
-                unreached();
+                baseType = node->TypeGet();
             }
         }
-        genProduceReg(node);
     }
-    else
-    {
-        genSpecialIntrinsic(node);
-    }
-}
+};
 
-void CodeGen::genSpecialIntrinsic(GenTreeHWIntrinsic* node)
+//------------------------------------------------------------------------
+// genHWIntrinsic: Generates the code for a given hardware intrinsic node.
+//
+// Arguments:
+//    node - The hardware intrinsic node
+//
+void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
 {
-    NamedIntrinsic      intrinsicId = node->gtHWIntrinsicId;
-    HWIntrinsicCategory category    = HWIntrinsicInfo::lookupCategory(intrinsicId);
-
-    assert(HWIntrinsicInfo::RequiresCodegen(intrinsicId));
-
-    InstructionSet isa     = HWIntrinsicInfo::lookupIsa(intrinsicId);
-    int            ival    = HWIntrinsicInfo::lookupIval(intrinsicId);
-    int            numArgs = HWIntrinsicInfo::lookupNumArgs(node);
+    const HWIntrinsic intrin(node);
 
-    assert(numArgs >= 0);
+    regNumber targetReg = node->GetRegNum();
 
-    GenTree*  op1        = node->gtGetOp1();
-    GenTree*  op2        = node->gtGetOp2();
-    regNumber targetReg  = node->GetRegNum();
-    var_types targetType = node->TypeGet();
-    var_types baseType   = node->gtSIMDBaseType;
-
-    if (baseType == TYP_UNKNOWN)
-    {
-        assert(category == HW_Category_Scalar);
-
-        if (HWIntrinsicInfo::BaseTypeFromFirstArg(intrinsicId))
-        {
-            assert(op1 != nullptr);
-            baseType = op1->TypeGet();
-        }
-        else if (HWIntrinsicInfo::BaseTypeFromSecondArg(intrinsicId))
-        {
-            assert(op2 != nullptr);
-            baseType = op2->TypeGet();
-        }
-        else
-        {
-            baseType = targetType;
-        }
-    }
+    regNumber op1Reg = REG_NA;
+    regNumber op2Reg = REG_NA;
+    regNumber op3Reg = REG_NA;
 
-    switch (intrinsicId)
+    switch (intrin.numOperands)
     {
-        case NI_Crc32_ComputeCrc32:
-        case NI_Crc32_ComputeCrc32C:
-        {
-            if (baseType == TYP_INT)
-            {
-                baseType = TYP_UINT;
-            }
+        case 3:
+            assert(intrin.op3 != nullptr);
+            op3Reg = intrin.op3->GetRegNum();
+            __fallthrough;
+
+        case 2:
+            assert(intrin.op2 != nullptr);
+            op2Reg = intrin.op2->GetRegNum();
+            __fallthrough;
+
+        case 1:
+            assert(intrin.op1 != nullptr);
+            op1Reg = intrin.op1->GetRegNum();
             break;
-        }
-
-        case NI_Crc32_Arm64_ComputeCrc32:
-        case NI_Crc32_Arm64_ComputeCrc32C:
-        {
-            assert(baseType == TYP_LONG);
-            baseType = TYP_ULONG;
-            break;
-        }
 
         default:
-            break;
+            unreached();
     }
 
-    instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType);
-    assert(ins != INS_invalid);
-
-    regNumber op1Reg   = REG_NA;
-    regNumber op2Reg   = REG_NA;
-    emitter*  emit     = GetEmitter();
-    emitAttr  emitSize = EA_ATTR(node->gtSIMDSize);
-    insOpts   opt      = INS_OPTS_NONE;
+    emitAttr emitSize;
+    insOpts  opt = INS_OPTS_NONE;
 
-    if ((category == HW_Category_SIMDScalar) || (category == HW_Category_Scalar))
+    if ((intrin.category == HW_Category_SIMDScalar) || (intrin.category == HW_Category_Scalar))
     {
-        emitSize = emitActualTypeSize(baseType);
+        emitSize = emitActualTypeSize(intrin.baseType);
     }
     else
     {
-        opt = genGetSimdInsOpt(emitSize, baseType);
+        emitSize = EA_SIZE(node->gtSIMDSize);
+        opt      = genGetSimdInsOpt(emitSize, intrin.baseType);
     }
 
-    genConsumeOperands(node);
+    genConsumeHWIntrinsicOperands(node);
 
-    switch (intrinsicId)
+    if (intrin.IsTableDriven())
     {
-        case NI_Aes_Decrypt:
-        case NI_Aes_Encrypt:
+        instruction ins = HWIntrinsicInfo::lookupIns(intrin.id, intrin.baseType);
+        assert(ins != INS_invalid);
+
+        switch (intrin.numOperands)
         {
-            assert(op1 != nullptr);
-            assert(op2 != nullptr);
+            case 1:
+                GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
+                break;
 
-            op1Reg = op1->GetRegNum();
-            op2Reg = op2->GetRegNum();
+            case 2:
+                GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
+                break;
 
-            if (op1Reg != targetReg)
-            {
-                emit->emitIns_R_R(INS_mov, emitSize, targetReg, op1Reg);
-            }
-            emit->emitIns_R_R(ins, emitSize, targetReg, op2Reg, opt);
-            break;
+            case 3:
+                if (targetReg != op1Reg)
+                {
+                    GetEmitter()->emitIns_R_R(INS_mov, emitSize, targetReg, op1Reg);
+                }
+                GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op2Reg, op3Reg, opt);
+                break;
+
+            default:
+                unreached();
         }
+    }
+    else
+    {
+        instruction ins = INS_invalid;
 
-        case NI_ArmBase_LeadingZeroCount:
-        case NI_ArmBase_ReverseElementBits:
-        case NI_ArmBase_Arm64_LeadingSignCount:
-        case NI_ArmBase_Arm64_LeadingZeroCount:
-        case NI_ArmBase_Arm64_ReverseElementBits:
+        switch (intrin.id)
         {
-            assert(op1 != nullptr);
-            assert(op2 == nullptr);
+            case NI_Crc32_ComputeCrc32:
+                if (intrin.baseType == TYP_INT)
+                {
+                    ins = INS_crc32w;
+                }
+                else
+                {
+                    ins = HWIntrinsicInfo::lookupIns(intrin.id, intrin.baseType);
+                }
+                break;
 
-            op1Reg = op1->GetRegNum();
-            emit->emitIns_R_R(ins, emitSize, targetReg, op1Reg);
-            break;
-        }
+            case NI_Crc32_ComputeCrc32C:
+                if (intrin.baseType == TYP_INT)
+                {
+                    ins = INS_crc32cw;
+                }
+                else
+                {
+                    ins = HWIntrinsicInfo::lookupIns(intrin.id, intrin.baseType);
+                }
+                break;
 
-        case NI_Crc32_ComputeCrc32:
-        case NI_Crc32_ComputeCrc32C:
-        case NI_Crc32_Arm64_ComputeCrc32:
-        case NI_Crc32_Arm64_ComputeCrc32C:
-        {
-            assert(op1 != nullptr);
-            assert(op2 != nullptr);
+            case NI_Crc32_Arm64_ComputeCrc32:
+                assert(intrin.baseType == TYP_LONG);
+                ins = INS_crc32x;
+                break;
 
-            op1Reg = op1->GetRegNum();
-            op2Reg = op2->GetRegNum();
-            emit->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg);
-            break;
+            case NI_Crc32_Arm64_ComputeCrc32C:
+                assert(intrin.baseType == TYP_LONG);
+                ins = INS_crc32cx;
+                break;
+
+            default:
+                ins = HWIntrinsicInfo::lookupIns(intrin.id, intrin.baseType);
+                break;
         }
 
-        default:
+        assert(ins != INS_invalid);
+
+        switch (intrin.id)
         {
-            unreached();
+            case NI_Aes_Decrypt:
+            case NI_Aes_Encrypt:
+                if (targetReg != op1Reg)
+                {
+                    GetEmitter()->emitIns_R_R(INS_mov, emitSize, targetReg, op1Reg);
+                }
+                GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op2Reg, opt);
+                break;
+
+            case NI_Crc32_ComputeCrc32:
+            case NI_Crc32_ComputeCrc32C:
+            case NI_Crc32_Arm64_ComputeCrc32:
+            case NI_Crc32_Arm64_ComputeCrc32C:
+                GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg);
+                break;
+
+            default:
+                unreached();
         }
     }