ARM64: Switch Expansion Using Jump Table
authorKyungwoo Lee <kyulee@microsoft.com>
Thu, 12 May 2016 04:48:40 +0000 (21:48 -0700)
committerKyungwoo Lee <kyulee@microsoft.com>
Thu, 12 May 2016 17:23:07 +0000 (10:23 -0700)
Fixes dotnet/coreclr#3332
To validate various addressing in dotnet/coreclr#4896, I just enable this.
Previously, we only allow a load operation to JIT data (`ldr` or
`IF_LARGELDC`).
For switch expansion, jump table is also recorded into JIT data.
In this case, we only get the address of jump table head, and
load the right entry after computing offset. So, basically `adr` or
`IF_LARGEADR` is used to not only load label within code but also refer to
the location of JIT data.
The typical code sequence for switch expansion is like this:

```
  adr     x8, [@RWD00]          // load address of jump table head
  ldr     w8, [x8, x0, LSL dotnet/coreclr#2]  // load jump entry from table addr + x0 * 4
  adr     x9, [G_M56320_IG02]   // load address of current baisc block
  add     x8, x8, x9            // Add them to compute the final target
  br      x8                    // Indirectly jump to the target
```

Commit migrated from https://github.com/dotnet/coreclr/commit/a0c6144d406f29d70005fbf7ebd8ac3bdfe3cc0d

src/coreclr/src/jit/codegenarm64.cpp
src/coreclr/src/jit/emit.cpp
src/coreclr/src/jit/emitarm64.cpp
src/coreclr/src/jit/emitarm64.h
src/coreclr/src/jit/lower.cpp

index dc25607..8b78390 100644 (file)
@@ -4048,7 +4048,6 @@ void CodeGen::genCodeForCpBlk(GenTreeCpBlk* cpBlkNode)
 void
 CodeGen::genTableBasedSwitch(GenTree* treeNode)
 {
-    NYI("Emit table based switch");
     genConsumeOperands(treeNode->AsOp());
     regNumber idxReg = treeNode->gtOp.gtOp1->gtRegNum;
     regNumber baseReg = treeNode->gtOp.gtOp2->gtRegNum;
@@ -4056,21 +4055,21 @@ CodeGen::genTableBasedSwitch(GenTree* treeNode)
     regNumber tmpReg = genRegNumFromMask(treeNode->gtRsvdRegs);
 
     // load the ip-relative offset (which is relative to start of fgFirstBB)
-    //getEmitter()->emitIns_R_ARX(INS_mov, EA_4BYTE, baseReg, baseReg, idxReg, 4, 0);
+    getEmitter()->emitIns_R_R_R(INS_ldr, EA_4BYTE, baseReg, baseReg, idxReg, INS_OPTS_LSL);
 
     // add it to the absolute address of fgFirstBB
     compiler->fgFirstBB->bbFlags |= BBF_JMP_TARGET;
-    //getEmitter()->emitIns_R_L(INS_lea, EA_PTRSIZE, compiler->fgFirstBB, tmpReg);
-    //getEmitter()->emitIns_R_R(INS_add, EA_PTRSIZE, baseReg, tmpReg);
+    getEmitter()->emitIns_R_L(INS_adr, EA_PTRSIZE, compiler->fgFirstBB, tmpReg);
+    getEmitter()->emitIns_R_R_R(INS_add, EA_PTRSIZE, baseReg, baseReg, tmpReg);
+
     // jmp baseReg
-    // getEmitter()->emitIns_R(INS_i_jmp, emitTypeSize(TYP_I_IMPL), baseReg);
+    getEmitter()->emitIns_R(INS_br, emitTypeSize(TYP_I_IMPL), baseReg);
 }
 
 // emits the table and an instruction to get the address of the first element
 void
 CodeGen::genJumpTable(GenTree* treeNode)
 {
-    NYI("Emit Jump table");
     noway_assert(compiler->compCurBB->bbJumpKind == BBJ_SWITCH);
     assert(treeNode->OperGet() == GT_JMPTABLE);
 
@@ -4100,7 +4099,7 @@ CodeGen::genJumpTable(GenTree* treeNode)
     // Access to inline data is 'abstracted' by a special type of static member
     // (produced by eeFindJitDataOffs) which the emitter recognizes as being a reference
     // to constant data, not a real static field.
-    getEmitter()->emitIns_R_C(INS_lea,
+    getEmitter()->emitIns_R_C(INS_adr,
         emitTypeSize(TYP_I_IMPL),
         treeNode->gtRegNum,
         REG_NA,
index cea1509..005f6a4 100644 (file)
@@ -3804,10 +3804,9 @@ AGAIN:
 #if defined(_TARGET_ARM64_)
         // JIT code and data will be allocated together for arm64 so the relative offset to JIT data is known.
         // In case such offset can be encodeable for `ldr` (+-1MB), shorten it.
-        if (emitIsLoadConstant(jmp))
+        if (jmp->idAddr()->iiaIsJitDataOffset())
         {
             // Reference to JIT data
-            assert(jmp->idAddr()->iiaIsJitDataOffset());
             assert(jmp->idIsBound());
             UNATIVE_OFFSET srcOffs = jmpIG->igOffs + jmp->idjOffs;
 
@@ -4289,6 +4288,14 @@ void                emitter::emitCheckFuncletBranch(instrDesc * jmp, insGroup *
     }
 #endif // _TARGET_ARMARCH_
 
+#ifdef _TARGET_ARM64_
+    // No interest if it's not jmp.
+    if (emitIsLoadLabel(jmp) || emitIsLoadConstant(jmp))
+    {
+        return;
+    }
+#endif // _TARGET_ARM64_
+
     insGroup * tgtIG = jmp->idAddr()->iiaIGlabel;
     assert(tgtIG);
     if (tgtIG->igFuncIdx != jmpIG->igFuncIdx)
index 3eec0ea..c1e6f93 100644 (file)
@@ -5097,7 +5097,7 @@ void                emitter::emitIns_R_R_R(instruction ins,
     case INS_str:
     case INS_strb:
     case INS_strh:
-        emitIns_R_R_R_Ext(ins, attr, reg1, reg2, reg3);
+        emitIns_R_R_R_Ext(ins, attr, reg1, reg2, reg3, opt);
         return;
 
     case INS_ldp:
@@ -5473,6 +5473,7 @@ void                emitter::emitIns_R_R_R_Ext(instruction ins,
             assert(isValidGeneralDatasize(size));
             scale = (size == EA_8BYTE) ? 3 : 2;
         }
+
         break;
 
     default:
@@ -6359,6 +6360,13 @@ void                emitter::emitIns_R_C (instruction  ins,
 
     switch (ins)
     {
+    case INS_adr:
+        // This is case to get address to the constant data.
+        fmt = IF_LARGEADR;
+        assert(isGeneralRegister(reg));
+        assert(isValidGeneralDatasize(size));
+        break;
+
     case INS_ldr:
         fmt = IF_LARGELDC;
         if (isVectorRegister(reg))
@@ -6391,7 +6399,10 @@ void                emitter::emitIns_R_C (instruction  ins,
     id->idSetIsBound();    // We won't patch address since we will know the exact distance once JIT code and data are allocated together.
 
     id->idReg1(reg);       // destination register that will get the constant value.
-    id->idReg2(addrReg);   // integer register to compute long address (used for vector dest when we end up with long address)
+    if (addrReg != REG_NA)
+    {
+        id->idReg2(addrReg);   // integer register to compute long address (used for vector dest when we end up with long address)
+    }
     id->idjShort = false;  // Assume loading constant from long address
 
     // Keep it long if it's in cold code.
@@ -7880,6 +7891,39 @@ void                emitter::emitIns_Call(EmitCallType  callType,
     }
 }
 
+BYTE*               emitter::emitOutputLoadLabel(BYTE* dst, BYTE* srcAddr, BYTE* dstAddr, instrDescJmp *id)
+{
+    instruction  ins = id->idIns();
+    insFormat    fmt = id->idInsFmt();
+    regNumber dstReg = id->idReg1();
+    if (id->idjShort)
+    {
+        // adr x, [rel addr] --  compute address: current addr(ip) + rel addr.
+        assert(ins == INS_adr);
+        assert(fmt == IF_DI_1E);
+        ssize_t distVal = (ssize_t)(dstAddr - srcAddr);
+        dst = emitOutputShortAddress(dst, ins, fmt, distVal, dstReg);
+    }
+    else
+    {
+        // adrp x, [rel page addr] -- compute page address: current page addr + rel page addr
+        assert(fmt == IF_LARGEADR);
+        ssize_t relPageAddr = (((ssize_t)dstAddr & 0xFFFFFFFFFFFFF000LL) - ((ssize_t)srcAddr & 0xFFFFFFFFFFFFF000LL)) >> 12;
+        dst = emitOutputShortAddress(dst, INS_adrp, IF_DI_1E, relPageAddr, dstReg);
+
+        // add x, x, page offs -- compute address = page addr + page offs
+        ssize_t imm12 = (ssize_t)dstAddr & 0xFFF;      // 12 bits
+        assert(isValidUimm12(imm12));
+        code_t code = emitInsCode(INS_add, IF_DI_2A);  // DI_2A  X0010001shiiiiii iiiiiinnnnnddddd   1100 0000   imm(i12, sh)
+        code |= insEncodeDatasize(EA_8BYTE);           // X
+        code |= ((code_t)imm12 << 10);                 // iiiiiiiiiiii
+        code |= insEncodeReg_Rd(dstReg);               // ddddd
+        code |= insEncodeReg_Rn(dstReg);               // nnnnn
+        dst += emitOutput_Instr(dst, code);
+    }
+    return dst;
+}
+
 /*****************************************************************************
  *
  *  Output a local jump or other instruction with a pc-relative immediate.
@@ -7921,17 +7965,12 @@ BYTE*               emitter::emitOutputLJ(insGroup  *ig, BYTE *dst, instrDesc *i
 
     case INS_ldr:
     case INS_ldrsw:
+        loadConstant = true;
+        break;
+
     case INS_adr:
     case INS_adrp:
-        // Any reference to JIT data is assumed to load constant.
-        if (id->idAddr()->iiaIsJitDataOffset())
-        {
-            loadConstant = true;
-        }
-        else
-        {
-            loadLabel = true;
-        }
+        loadLabel = true;
         break;
     }
 
@@ -7940,9 +7979,9 @@ BYTE*               emitter::emitOutputLJ(insGroup  *ig, BYTE *dst, instrDesc *i
     srcOffs = emitCurCodeOffs(dst);
     srcAddr = emitOffsetToPtr(srcOffs);
 
-    if (loadConstant)
+    if (id->idAddr()->iiaIsJitDataOffset())
     {
-        /* This is actually a reference to the JIT data section */
+        assert(loadConstant || loadLabel);
         int doff = id->idAddr()->iiaGetJitDataOffset();
         assert(doff >= 0);
         ssize_t imm = emitGetInsSC(id);
@@ -7956,56 +7995,64 @@ BYTE*               emitter::emitOutputLJ(insGroup  *ig, BYTE *dst, instrDesc *i
         regNumber addrReg = dstReg; // an integer register to compute long address.
         emitAttr opSize = id->idOpSize();
 
-        if (id->idjShort)
-        {
-            // ldr x/v, [rel addr] -- load constant from current addr(ip) + rel addr.
-            assert(ins == INS_ldr);
-            assert(fmt == IF_LS_1A);
-            distVal = (ssize_t)(dstAddr - srcAddr);
-            dst = emitOutputShortConstant(dst, ins, fmt, distVal, dstReg, opSize);
-        }
-        else
+        if (loadConstant)
         {
-            // adrp x, [rel page addr] -- compute page address: current page addr + rel page addr
-            assert(fmt == IF_LARGELDC);
-            ssize_t relPageAddr = (((ssize_t)dstAddr & 0xFFFFFFFFFFFFF000LL) - ((ssize_t)srcAddr & 0xFFFFFFFFFFFFF000LL)) >> 12;
-            if (isVectorRegister(dstReg))
+            if (id->idjShort)
             {
-                // Update addrReg with the reserved integer register
-                // since we cannot use dstReg (vector) to load constant directly from memory.
-                addrReg = id->idReg2();
-                assert(isGeneralRegister(addrReg));
+                // ldr x/v, [rel addr] -- load constant from current addr(ip) + rel addr.
+                assert(ins == INS_ldr);
+                assert(fmt == IF_LS_1A);
+                distVal = (ssize_t)(dstAddr - srcAddr);
+                dst = emitOutputShortConstant(dst, ins, fmt, distVal, dstReg, opSize);
             }
-            ins = INS_adrp;
-            fmt = IF_DI_1E;
-            dst = emitOutputShortAddress(dst, ins, fmt, relPageAddr, addrReg);
-
-            // ldr x, [x, page offs] -- load constant from page address + page offset into integer register.
-            ssize_t imm12 = (ssize_t)dstAddr & 0xFFF; // 12 bits
-            assert(isValidUimm12(imm12));
-            ins = INS_ldr;
-            fmt = IF_LS_2B;
-            dst = emitOutputShortConstant(dst, ins, fmt, imm12, addrReg, opSize);
-
-            // fmov v, d -- copy constant in integer register to vector register.
-            // This is needed only for vector constant.
-            if (addrReg != dstReg)
+            else
             {
-                //  fmov    Vd,Rn                DV_2I  X00111100X100111 000000nnnnnddddd   1E27 0000   Vd,Rn    (scalar, from general)
-                assert(isVectorRegister(dstReg) && isGeneralRegister(addrReg));
-                ins = INS_fmov;
-                fmt = IF_DV_2I;
-                code_t code = emitInsCode(ins, fmt);
+                // adrp x, [rel page addr] -- compute page address: current page addr + rel page addr
+                assert(fmt == IF_LARGELDC);
+                ssize_t relPageAddr = (((ssize_t)dstAddr & 0xFFFFFFFFFFFFF000LL) - ((ssize_t)srcAddr & 0xFFFFFFFFFFFFF000LL)) >> 12;
+                if (isVectorRegister(dstReg))
+                {
+                    // Update addrReg with the reserved integer register
+                    // since we cannot use dstReg (vector) to load constant directly from memory.
+                    addrReg = id->idReg2();
+                    assert(isGeneralRegister(addrReg));
+                }
+                ins = INS_adrp;
+                fmt = IF_DI_1E;
+                dst = emitOutputShortAddress(dst, ins, fmt, relPageAddr, addrReg);
+
+                // ldr x, [x, page offs] -- load constant from page address + page offset into integer register.
+                ssize_t imm12 = (ssize_t)dstAddr & 0xFFF; // 12 bits
+                assert(isValidUimm12(imm12));
+                ins = INS_ldr;
+                fmt = IF_LS_2B;
+                dst = emitOutputShortConstant(dst, ins, fmt, imm12, addrReg, opSize);
 
-                code |= insEncodeReg_Vd(dstReg);             // ddddd
-                code |= insEncodeReg_Rn(addrReg);            // nnnnn
-                if (id->idOpSize() == EA_8BYTE)
+                // fmov v, d -- copy constant in integer register to vector register.
+                // This is needed only for vector constant.
+                if (addrReg != dstReg)
                 {
-                    code |= 0x80400000;                      // X ... X
+                    //  fmov    Vd,Rn                DV_2I  X00111100X100111 000000nnnnnddddd   1E27 0000   Vd,Rn    (scalar, from general)
+                    assert(isVectorRegister(dstReg) && isGeneralRegister(addrReg));
+                    ins = INS_fmov;
+                    fmt = IF_DV_2I;
+                    code_t code = emitInsCode(ins, fmt);
+
+                    code |= insEncodeReg_Vd(dstReg);             // ddddd
+                    code |= insEncodeReg_Rn(addrReg);            // nnnnn
+                    if (id->idOpSize() == EA_8BYTE)
+                    {
+                        code |= 0x80400000;                      // X ... X
+                    }
+                    dst += emitOutput_Instr(dst, code);
                 }
-                dst += emitOutput_Instr(dst, code);
             }
         }
+        else
+        {
+            assert(loadLabel);
+            dst = emitOutputLoadLabel(dst, srcAddr, dstAddr, id);
+        }
 
         return dst;
     }
@@ -8154,7 +8201,7 @@ BYTE*               emitter::emitOutputLJ(insGroup  *ig, BYTE *dst, instrDesc *i
             ins = INS_b;
             fmt = IF_BI_0A;
 
-            // The distVal was computed based on the beginning of the pseudo-instruction.
+            // The distVal was computed based on the beginning of the pseudo-instruction,
             // So subtract the size of the conditional branch so that it is relative to the
             // unconditional branch.
             distVal -= 4;
@@ -8164,31 +8211,7 @@ BYTE*               emitter::emitOutputLJ(insGroup  *ig, BYTE *dst, instrDesc *i
     }
     else if (loadLabel)
     {
-        regNumber dstReg = id->idReg1();
-        if (id->idjShort)
-        {
-            // adr x, [rel addr] --  compute address: current addr(ip) + rel addr.
-            assert(ins == INS_adr);
-            assert(fmt == IF_DI_1E);
-            dst = emitOutputShortAddress(dst, ins, fmt, distVal, dstReg);
-        }
-        else
-        {
-            // adrp x, [rel page addr] -- compute page address: current page addr + rel page addr
-            assert(fmt == IF_LARGEADR);
-            ssize_t relPageAddr = (((ssize_t)dstAddr & 0xFFFFFFFFFFFFF000LL) - ((ssize_t)srcAddr & 0xFFFFFFFFFFFFF000LL)) >> 12;
-            dst = emitOutputShortAddress(dst, INS_adrp, IF_DI_1E, relPageAddr, dstReg);
-
-            // add x, x, page offs -- compute address = page addr + page offs
-            ssize_t imm12 = (ssize_t)dstAddr & 0xFFF;      // 12 bits
-            assert(isValidUimm12(imm12));
-            code_t code = emitInsCode(INS_add, IF_DI_2A);  // DI_2A  X0010001shiiiiii iiiiiinnnnnddddd   1100 0000   imm(i12, sh)
-            code |= insEncodeDatasize(EA_8BYTE);           // X
-            code |= ((code_t)imm12 << 10);                 // iiiiiiiiiiii
-            code |= insEncodeReg_Rd(dstReg);               // ddddd
-            code |= insEncodeReg_Rn(dstReg);               // nnnnn
-            dst += emitOutput_Instr(dst, code);
-        }
+        dst = emitOutputLoadLabel(dst, srcAddr, dstAddr, id);
     }
 
     return  dst;
@@ -10288,29 +10311,39 @@ void                emitter::emitDispIns(instrDesc *  id,
         break;
 
     case IF_LS_1A:    // LS_1A   XX...V..iiiiiiii iiiiiiiiiiittttt      Rt    PC imm(1MB)
+    case IF_DI_1E:    // DI_1E   .ii.....iiiiiiii iiiiiiiiiiiddddd      Rd       simm21
     case IF_LARGELDC:
+    case IF_LARGEADR:
         assert(insOptsNone(id->idInsOpt()));
         emitDispReg(id->idReg1(), size, true);
         imm = emitGetInsSC(id);
 
         /* Is this actually a reference to a data section? */
-        doffs = Compiler::eeGetJitDataOffs(id->idAddr()->iiaFieldHnd);
+        if (fmt == IF_LARGEADR)
+        {
+            printf("(LARGEADR)");
+        }
+        else if (fmt == IF_LARGELDC)
+        {
+            printf("(LARGELDC)");
+        }
 
         if (fmt == IF_LARGELDC)
         {
             printf("(LARGELDC)");
         }
         printf("[");
-        if  (doffs >= 0)
+        if (id->idAddr()->iiaIsJitDataOffset())
         {
+            doffs = Compiler::eeGetJitDataOffs(id->idAddr()->iiaFieldHnd);
             /* Display a data section reference */
 
-            if  (doffs & 1)
-                printf("@CNS%02u", doffs-1);
+            if (doffs & 1)
+                printf("@CNS%02u", doffs - 1);
             else
                 printf("@RWD%02u", doffs);
 
-            if  (imm != 0)
+            if (imm != 0)
                 printf("%+Id", imm);
         }
         else
@@ -10417,24 +10450,6 @@ void                emitter::emitDispIns(instrDesc *  id,
         emitDispImm(emitDecodeBitMaskImm(bmi, size), false);
         break;
 
-    case IF_DI_1E:    // DI_1E   .ii.....iiiiiiii iiiiiiiiiiiddddd      Rd       simm21
-    case IF_LARGEADR:
-        assert(insOptsNone(id->idInsOpt()));
-        emitDispReg(id->idReg1(), size, true);
-        if (fmt == IF_LARGEADR)
-        {
-            printf("(LARGEADR)");
-        }
-        if (id->idIsBound())
-        {
-            printf("G_M%03u_IG%02u", Compiler::s_compMethodsCount, id->idAddr()->iiaIGlabel->igNum);
-        }
-        else
-        {
-            printf("L_M%03u_BB%02u", Compiler::s_compMethodsCount, id->idAddr()->iiaBBlabel->bbNum);
-        }
-        break;
-
     case IF_DI_2A:    // DI_2A   X.......shiiiiii iiiiiinnnnnddddd      Rd Rn    imm(i12,sh)
         if ((ins == INS_add) || (ins == INS_sub))
         {
index a3f1845..6538146 100644 (file)
@@ -900,6 +900,7 @@ public:
 
     BYTE*           emitOutputLJ  (insGroup  *ig, BYTE *dst, instrDesc *i);
     unsigned        emitOutputCall(insGroup  *ig, BYTE *dst, instrDesc *i, code_t code);
+    BYTE*           emitOutputLoadLabel(BYTE* dst, BYTE* srcAddr, BYTE* dstAddr, instrDescJmp* id);
     BYTE*           emitOutputShortBranch(BYTE *dst, instruction ins, insFormat fmt, ssize_t distVal, instrDescJmp* id);
     BYTE*           emitOutputShortAddress(BYTE *dst, instruction ins, insFormat fmt, ssize_t distVal, regNumber reg);
     BYTE*           emitOutputShortConstant(BYTE *dst, instruction ins, insFormat fmt, ssize_t distVal, regNumber reg, emitAttr opSize);
index 64b4fee..060032c 100644 (file)
@@ -771,6 +771,7 @@ void Lowering::LowerNode(GenTreePtr* ppTree, Compiler::fgWalkData* data)
  *     internal temporaries to maintain the index we're evaluating plus we're using existing code from LinearCodeGen
  *     to implement this instead of implement all the control flow constructs using InstrDscs and InstrGroups downstream.
  */
+
 void Lowering::LowerSwitch(GenTreePtr* pTree)
 {
     unsigned     jumpCnt;
@@ -876,14 +877,7 @@ void Lowering::LowerSwitch(GenTreePtr* pTree)
     // because the code to load the base of the switch
     // table is huge and hideous due to the relocation... :(
     minSwitchTabJumpCnt += 2;
-#elif defined(_TARGET_ARM64_) // _TARGET_ARM_
-    // In the case of ARM64 we'll stick to generate a sequence of
-    // compare and branch for now to get switch working and revisit
-    // to implement jump tables in the future.
-    //
-    // TODO-AMD64-NYI: Implement Jump Tables.
-    minSwitchTabJumpCnt = -1; 
-#endif // _TARGET_ARM64_
+#endif // _TARGET_ARM_
     // Once we have the temporary variable, we construct the conditional branch for
     // the default case.  As stated above, this conditional is being shared between
     // both GT_SWITCH lowering code paths.