From: Kerry McLaughlin Date: Mon, 16 Jan 2023 11:36:37 +0000 (+0000) Subject: [AArch64][SME] Add an instruction mapping for SME pseudos X-Git-Tag: upstream/17.0.6~20821 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6387d3896629e225100d91b5827ea67882496eb4;p=platform%2Fupstream%2Fllvm.git [AArch64][SME] Add an instruction mapping for SME pseudos Adds an instruction mapping to SMEInstrFormats which matches SME pseudos with the real instructions they are transformed to. A new flag is also added to AArch64Inst (SMEMatrixType), which is used to indicate the base register required when emitting many of the SME instructions. This reduces the number of pseudos handled by the switch statement in EmitInstrWithCustomInserter. Reviewed By: david-arm Differential Revision: https://reviews.llvm.org/D136856 --- diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 93aac68..511c103 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -2676,35 +2676,16 @@ AArch64TargetLowering::EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const { } MachineBasicBlock * -AArch64TargetLowering::EmitMopa(unsigned Opc, unsigned BaseReg, - MachineInstr &MI, MachineBasicBlock *BB) const { +AArch64TargetLowering::EmitZAInstr(unsigned Opc, unsigned BaseReg, + MachineInstr &MI, + MachineBasicBlock *BB) const { const TargetInstrInfo *TII = Subtarget->getInstrInfo(); MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc)); MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define); MIB.addReg(BaseReg + MI.getOperand(0).getImm()); - MIB.add(MI.getOperand(1)); // pn - MIB.add(MI.getOperand(2)); // pm - MIB.add(MI.getOperand(3)); // zn - MIB.add(MI.getOperand(4)); // zm - - MI.eraseFromParent(); // The pseudo is gone now. 
- return BB; -} - -MachineBasicBlock * -AArch64TargetLowering::EmitInsertVectorToTile(unsigned Opc, unsigned BaseReg, - MachineInstr &MI, - MachineBasicBlock *BB) const { - const TargetInstrInfo *TII = Subtarget->getInstrInfo(); - MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc)); - - MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define); - MIB.addReg(BaseReg + MI.getOperand(0).getImm()); - MIB.add(MI.getOperand(1)); // Slice index register - MIB.add(MI.getOperand(2)); // Slice index offset - MIB.add(MI.getOperand(3)); // pg - MIB.add(MI.getOperand(4)); // zn + for (unsigned I = 1; I < MI.getNumOperands(); ++I) + MIB.add(MI.getOperand(I)); MI.eraseFromParent(); // The pseudo is gone now. return BB; @@ -2727,25 +2708,28 @@ AArch64TargetLowering::EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const { return BB; } -MachineBasicBlock * -AArch64TargetLowering::EmitAddVectorToTile(unsigned Opc, unsigned BaseReg, - MachineInstr &MI, - MachineBasicBlock *BB) const { - const TargetInstrInfo *TII = Subtarget->getInstrInfo(); - MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc)); - - MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define); - MIB.addReg(BaseReg + MI.getOperand(0).getImm()); - MIB.add(MI.getOperand(1)); // pn - MIB.add(MI.getOperand(2)); // pm - MIB.add(MI.getOperand(3)); // zn - - MI.eraseFromParent(); // The pseudo is gone now. 
- return BB; -} - MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( MachineInstr &MI, MachineBasicBlock *BB) const { + + int SMEOrigInstr = AArch64::getSMEPseudoMap(MI.getOpcode()); + if (SMEOrigInstr != -1) { + const TargetInstrInfo *TII = Subtarget->getInstrInfo(); + uint64_t SMEMatrixType = + TII->get(MI.getOpcode()).TSFlags & AArch64::SMEMatrixTypeMask; + switch (SMEMatrixType) { + case (AArch64::SMEMatrixTileB): + return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB); + case (AArch64::SMEMatrixTileH): + return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB); + case (AArch64::SMEMatrixTileS): + return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB); + case (AArch64::SMEMatrixTileD): + return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB); + case (AArch64::SMEMatrixTileQ): + return EmitZAInstr(SMEOrigInstr, AArch64::ZAQ0, MI, BB); + } + } + switch (MI.getOpcode()) { default: #ifndef NDEBUG @@ -2795,94 +2779,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( return EmitTileLoad(AArch64::LD1_MXIPXX_V_Q, AArch64::ZAQ0, MI, BB); case AArch64::LDR_ZA_PSEUDO: return EmitFill(MI, BB); - case AArch64::BFMOPA_MPPZZ_PSEUDO: - return EmitMopa(AArch64::BFMOPA_MPPZZ, AArch64::ZAS0, MI, BB); - case AArch64::BFMOPS_MPPZZ_PSEUDO: - return EmitMopa(AArch64::BFMOPS_MPPZZ, AArch64::ZAS0, MI, BB); - case AArch64::FMOPAL_MPPZZ_PSEUDO: - return EmitMopa(AArch64::FMOPAL_MPPZZ, AArch64::ZAS0, MI, BB); - case AArch64::FMOPSL_MPPZZ_PSEUDO: - return EmitMopa(AArch64::FMOPSL_MPPZZ, AArch64::ZAS0, MI, BB); - case AArch64::FMOPA_MPPZZ_S_PSEUDO: - return EmitMopa(AArch64::FMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB); - case AArch64::FMOPS_MPPZZ_S_PSEUDO: - return EmitMopa(AArch64::FMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB); - case AArch64::FMOPA_MPPZZ_D_PSEUDO: - return EmitMopa(AArch64::FMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB); - case AArch64::FMOPS_MPPZZ_D_PSEUDO: - return EmitMopa(AArch64::FMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB); - case 
AArch64::SMOPA_MPPZZ_S_PSEUDO: - return EmitMopa(AArch64::SMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB); - case AArch64::SMOPS_MPPZZ_S_PSEUDO: - return EmitMopa(AArch64::SMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB); - case AArch64::UMOPA_MPPZZ_S_PSEUDO: - return EmitMopa(AArch64::UMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB); - case AArch64::UMOPS_MPPZZ_S_PSEUDO: - return EmitMopa(AArch64::UMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB); - case AArch64::SUMOPA_MPPZZ_S_PSEUDO: - return EmitMopa(AArch64::SUMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB); - case AArch64::SUMOPS_MPPZZ_S_PSEUDO: - return EmitMopa(AArch64::SUMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB); - case AArch64::USMOPA_MPPZZ_S_PSEUDO: - return EmitMopa(AArch64::USMOPA_MPPZZ_S, AArch64::ZAS0, MI, BB); - case AArch64::USMOPS_MPPZZ_S_PSEUDO: - return EmitMopa(AArch64::USMOPS_MPPZZ_S, AArch64::ZAS0, MI, BB); - case AArch64::SMOPA_MPPZZ_D_PSEUDO: - return EmitMopa(AArch64::SMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB); - case AArch64::SMOPS_MPPZZ_D_PSEUDO: - return EmitMopa(AArch64::SMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB); - case AArch64::UMOPA_MPPZZ_D_PSEUDO: - return EmitMopa(AArch64::UMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB); - case AArch64::UMOPS_MPPZZ_D_PSEUDO: - return EmitMopa(AArch64::UMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB); - case AArch64::SUMOPA_MPPZZ_D_PSEUDO: - return EmitMopa(AArch64::SUMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB); - case AArch64::SUMOPS_MPPZZ_D_PSEUDO: - return EmitMopa(AArch64::SUMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB); - case AArch64::USMOPA_MPPZZ_D_PSEUDO: - return EmitMopa(AArch64::USMOPA_MPPZZ_D, AArch64::ZAD0, MI, BB); - case AArch64::USMOPS_MPPZZ_D_PSEUDO: - return EmitMopa(AArch64::USMOPS_MPPZZ_D, AArch64::ZAD0, MI, BB); - case AArch64::INSERT_MXIPZ_H_PSEUDO_B: - return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_B, AArch64::ZAB0, MI, - BB); - case AArch64::INSERT_MXIPZ_H_PSEUDO_H: - return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_H, AArch64::ZAH0, MI, - BB); - case AArch64::INSERT_MXIPZ_H_PSEUDO_S: - return 
EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_S, AArch64::ZAS0, MI, - BB); - case AArch64::INSERT_MXIPZ_H_PSEUDO_D: - return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_D, AArch64::ZAD0, MI, - BB); - case AArch64::INSERT_MXIPZ_H_PSEUDO_Q: - return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_Q, AArch64::ZAQ0, MI, - BB); - case AArch64::INSERT_MXIPZ_V_PSEUDO_B: - return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_B, AArch64::ZAB0, MI, - BB); - case AArch64::INSERT_MXIPZ_V_PSEUDO_H: - return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_H, AArch64::ZAH0, MI, - BB); - case AArch64::INSERT_MXIPZ_V_PSEUDO_S: - return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_S, AArch64::ZAS0, MI, - BB); - case AArch64::INSERT_MXIPZ_V_PSEUDO_D: - return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_D, AArch64::ZAD0, MI, - BB); - case AArch64::INSERT_MXIPZ_V_PSEUDO_Q: - return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_Q, AArch64::ZAQ0, MI, - BB); case AArch64::ZERO_M_PSEUDO: return EmitZero(MI, BB); - case AArch64::ADDHA_MPPZ_PSEUDO_S: - return EmitAddVectorToTile(AArch64::ADDHA_MPPZ_S, AArch64::ZAS0, MI, BB); - case AArch64::ADDVA_MPPZ_PSEUDO_S: - return EmitAddVectorToTile(AArch64::ADDVA_MPPZ_S, AArch64::ZAS0, MI, BB); - case AArch64::ADDHA_MPPZ_PSEUDO_D: - return EmitAddVectorToTile(AArch64::ADDHA_MPPZ_D, AArch64::ZAD0, MI, BB); - case AArch64::ADDVA_MPPZ_PSEUDO_D: - return EmitAddVectorToTile(AArch64::ADDVA_MPPZ_D, AArch64::ZAD0, MI, BB); } } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 9cf99b3..febb116 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -595,15 +595,9 @@ public: MachineInstr &MI, MachineBasicBlock *BB) const; MachineBasicBlock *EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const; - MachineBasicBlock *EmitMopa(unsigned Opc, unsigned BaseReg, MachineInstr &MI, - MachineBasicBlock *BB) const; - MachineBasicBlock 
*EmitInsertVectorToTile(unsigned Opc, unsigned BaseReg, - MachineInstr &MI, - MachineBasicBlock *BB) const; + MachineBasicBlock *EmitZAInstr(unsigned Opc, unsigned BaseReg, + MachineInstr &MI, MachineBasicBlock *BB) const; MachineBasicBlock *EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const; - MachineBasicBlock *EmitAddVectorToTile(unsigned Opc, unsigned BaseReg, - MachineInstr &MI, - MachineBasicBlock *BB) const; MachineBasicBlock * EmitInstrWithCustomInserter(MachineInstr &MI, diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index 0a24896..91179aa 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -45,6 +45,17 @@ def FalseLanesNone : FalseLanesEnum<0>; def FalseLanesZero : FalseLanesEnum<1>; def FalseLanesUndef : FalseLanesEnum<2>; +class SMEMatrixTypeEnum<bits<3> val> { + bits<3> Value = val; +} +def SMEMatrixNone : SMEMatrixTypeEnum<0>; +def SMEMatrixTileB : SMEMatrixTypeEnum<1>; +def SMEMatrixTileH : SMEMatrixTypeEnum<2>; +def SMEMatrixTileS : SMEMatrixTypeEnum<3>; +def SMEMatrixTileD : SMEMatrixTypeEnum<4>; +def SMEMatrixTileQ : SMEMatrixTypeEnum<5>; +def SMEMatrixArray : SMEMatrixTypeEnum<6>; + // AArch64 Instruction Format class AArch64Inst : Instruction { field bits<32> Inst; // Instruction encoding. 
@@ -65,16 +76,18 @@ class AArch64Inst : Instruction { bit isPTestLike = 0; FalseLanesEnum FalseLanes = FalseLanesNone; DestructiveInstTypeEnum DestructiveInstType = NotDestructive; + SMEMatrixTypeEnum SMEMatrixType = SMEMatrixNone; ElementSizeEnum ElementSize = ElementSizeNone; - let TSFlags{10} = isPTestLike; - let TSFlags{9} = isWhile; - let TSFlags{8-7} = FalseLanes.Value; - let TSFlags{6-3} = DestructiveInstType.Value; - let TSFlags{2-0} = ElementSize.Value; + let TSFlags{13-11} = SMEMatrixType.Value; + let TSFlags{10} = isPTestLike; + let TSFlags{9} = isWhile; + let TSFlags{8-7} = FalseLanes.Value; + let TSFlags{6-3} = DestructiveInstType.Value; + let TSFlags{2-0} = ElementSize.Value; - let Pattern = []; - let Constraints = cstr; + let Pattern = []; + let Constraints = cstr; } class InstSubst diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h index ec60927..caf9421 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h @@ -539,10 +539,11 @@ static inline unsigned getPACOpcodeForKey(AArch64PACKey::ID K, bool Zero) { } // struct TSFlags { -#define TSFLAG_ELEMENT_SIZE_TYPE(X) (X) // 3-bits -#define TSFLAG_DESTRUCTIVE_INST_TYPE(X) ((X) << 3) // 4-bits -#define TSFLAG_FALSE_LANE_TYPE(X) ((X) << 7) // 2-bits -#define TSFLAG_INSTR_FLAGS(X) ((X) << 9) // 2-bits +#define TSFLAG_ELEMENT_SIZE_TYPE(X) (X) // 3-bits +#define TSFLAG_DESTRUCTIVE_INST_TYPE(X) ((X) << 3) // 4-bits +#define TSFLAG_FALSE_LANE_TYPE(X) ((X) << 7) // 2-bits +#define TSFLAG_INSTR_FLAGS(X) ((X) << 9) // 2-bits +#define TSFLAG_SME_MATRIX_TYPE(X) ((X) << 11) // 3-bits // } namespace AArch64 { @@ -580,14 +581,28 @@ enum FalseLaneType { static const uint64_t InstrFlagIsWhile = TSFLAG_INSTR_FLAGS(0x1); static const uint64_t InstrFlagIsPTestLike = TSFLAG_INSTR_FLAGS(0x2); +enum SMEMatrixType { + SMEMatrixTypeMask = TSFLAG_SME_MATRIX_TYPE(0x7), + SMEMatrixNone = TSFLAG_SME_MATRIX_TYPE(0x0), + SMEMatrixTileB = 
TSFLAG_SME_MATRIX_TYPE(0x1), + SMEMatrixTileH = TSFLAG_SME_MATRIX_TYPE(0x2), + SMEMatrixTileS = TSFLAG_SME_MATRIX_TYPE(0x3), + SMEMatrixTileD = TSFLAG_SME_MATRIX_TYPE(0x4), + SMEMatrixTileQ = TSFLAG_SME_MATRIX_TYPE(0x5), + SMEMatrixArray = TSFLAG_SME_MATRIX_TYPE(0x6), +}; + #undef TSFLAG_ELEMENT_SIZE_TYPE #undef TSFLAG_DESTRUCTIVE_INST_TYPE #undef TSFLAG_FALSE_LANE_TYPE #undef TSFLAG_INSTR_FLAGS +#undef TSFLAG_SME_MATRIX_TYPE int getSVEPseudoMap(uint16_t Opcode); int getSVERevInstr(uint16_t Opcode); int getSVENonRevInstr(uint16_t Opcode); + +int getSMEPseudoMap(uint16_t Opcode); } } // end namespace llvm diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index 468db59..df08e5e 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -49,15 +49,15 @@ def RDSVLI_XI : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1> def ADDSPL_XXI : sve_int_arith_vl<0b1, "addspl", /*streaming_sve=*/0b1>; def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>; -def ADDHA_MPPZ_S : sme_add_vector_to_tile_u32<0b0, "addha">; -def ADDVA_MPPZ_S : sme_add_vector_to_tile_u32<0b1, "addva">; +defm ADDHA_MPPZ_S : sme_add_vector_to_tile_u32<0b0, "addha", int_aarch64_sme_addha>; +defm ADDVA_MPPZ_S : sme_add_vector_to_tile_u32<0b1, "addva", int_aarch64_sme_addva>; def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>; } let Predicates = [HasSMEI16I64] in { -def ADDHA_MPPZ_D : sme_add_vector_to_tile_u64<0b0, "addha">; -def ADDVA_MPPZ_D : sme_add_vector_to_tile_u64<0b1, "addva">; +defm ADDHA_MPPZ_D : sme_add_vector_to_tile_u64<0b0, "addha", int_aarch64_sme_addha>; +defm ADDVA_MPPZ_D : sme_add_vector_to_tile_u64<0b1, "addva", int_aarch64_sme_addva>; } let Predicates = [HasSME] in { diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td index e541603..3556d7f 100644 --- 
a/llvm/lib/Target/AArch64/SMEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td @@ -25,17 +25,35 @@ def tileslice128 : ComplexPattern", []>; // nop def am_sme_indexed_b4 :ComplexPattern", [], [SDNPWantRoot]>; //===----------------------------------------------------------------------===// -// SME Outer Products +// SME Pseudo Classes //===----------------------------------------------------------------------===// -class sme_outer_product_pseudo +def getSMEPseudoMap : InstrMapping { + let FilterClass = "SMEPseudo2Instr"; + let RowFields = ["PseudoName"]; + let ColFields = ["IsInstr"]; + let KeyCol = ["0"]; + let ValueCols = [["1"]]; +} + +class SMEPseudo2Instr { + string PseudoName = name; + bit IsInstr = instr; +} + +class sme_outer_product_pseudo : Pseudo<(outs), (ins i32imm:$tile, PPR3bAny:$pn, PPR3bAny:$pm, zpr_ty:$zn, zpr_ty:$zm), []>, Sched<[]> { // Translated to the actual instructions in AArch64ISelLowering.cpp + let SMEMatrixType = za_flag; let usesCustomInserter = 1; } +//===----------------------------------------------------------------------===// +// SME Outer Products +//===----------------------------------------------------------------------===// + class sme_fp_outer_product_inst sz, bit op, MatrixTileOperand za_ty, ZPRRegOp zpr_ty, string mnemonic> : I<(outs za_ty:$ZAda), @@ -62,13 +80,13 @@ class sme_fp_outer_product_inst sz, bit op, MatrixTileOperand za_ } multiclass sme_outer_product_fp32 { - def NAME : sme_fp_outer_product_inst { + def NAME : sme_fp_outer_product_inst, SMEPseudo2Instr { bits<2> ZAda; let Inst{1-0} = ZAda; let Inst{2} = 0b0; } - def NAME # _PSEUDO : sme_outer_product_pseudo; + def NAME # _PSEUDO : sme_outer_product_pseudo, SMEPseudo2Instr; def : Pat<(op timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm), (nxv4f32 ZPR32:$zn), (nxv4f32 ZPR32:$zm)), @@ -76,12 +94,12 @@ multiclass sme_outer_product_fp32 } multiclass sme_outer_product_fp64 { - def NAME : sme_fp_outer_product_inst { + def NAME : 
sme_fp_outer_product_inst, SMEPseudo2Instr { bits<3> ZAda; let Inst{2-0} = ZAda; } - def NAME # _PSEUDO : sme_outer_product_pseudo; + def NAME # _PSEUDO : sme_outer_product_pseudo, SMEPseudo2Instr; def : Pat<(op timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm), (nxv2f64 ZPR64:$zn), (nxv2f64 ZPR64:$zm)), @@ -126,13 +144,13 @@ class sme_int_outer_product_inst opc, bit sz, bit sme2, multiclass sme_int_outer_product_i32 opc, string mnemonic, SDPatternOperator op> { def NAME : sme_int_outer_product_inst { + ZPR8, mnemonic>, SMEPseudo2Instr { bits<2> ZAda; let Inst{1-0} = ZAda; let Inst{2} = 0b0; } - def NAME # _PSEUDO : sme_outer_product_pseudo; + def NAME # _PSEUDO : sme_outer_product_pseudo, SMEPseudo2Instr; def : Pat<(op timm32_0_3:$tile, (nxv16i1 PPR3bAny:$pn), (nxv16i1 PPR3bAny:$pm), (nxv16i8 ZPR8:$zn), (nxv16i8 ZPR8:$zm)), @@ -142,12 +160,12 @@ multiclass sme_int_outer_product_i32 opc, string mnemonic, multiclass sme_int_outer_product_i64 opc, string mnemonic, SDPatternOperator op> { def NAME : sme_int_outer_product_inst { + ZPR16, mnemonic>, SMEPseudo2Instr { bits<3> ZAda; let Inst{2-0} = ZAda; } - def NAME # _PSEUDO : sme_outer_product_pseudo; + def NAME # _PSEUDO : sme_outer_product_pseudo, SMEPseudo2Instr; def : Pat<(op timm32_0_7:$tile, (nxv8i1 PPR3bAny:$pn), (nxv8i1 PPR3bAny:$pm), (nxv8i16 ZPR16:$zn), (nxv8i16 ZPR16:$zm)), @@ -182,9 +200,9 @@ class sme_outer_product_widening_inst opc, ZPRRegOp zpr_ty, string mnemo } multiclass sme_bf16_outer_product opc, string mnemonic, SDPatternOperator op> { - def NAME : sme_outer_product_widening_inst; + def NAME : sme_outer_product_widening_inst, SMEPseudo2Instr; - def NAME # _PSEUDO : sme_outer_product_pseudo; + def NAME # _PSEUDO : sme_outer_product_pseudo, SMEPseudo2Instr; def : Pat<(op timm32_0_3:$tile, (nxv8i1 PPR3bAny:$pn), (nxv8i1 PPR3bAny:$pm), (nxv8bf16 ZPR16:$zn), (nxv8bf16 ZPR16:$zm)), @@ -192,9 +210,9 @@ multiclass sme_bf16_outer_product opc, string mnemonic, SDPatternOperato } multiclass 
sme_f16_outer_product opc, string mnemonic, SDPatternOperator op> { - def NAME : sme_outer_product_widening_inst; + def NAME : sme_outer_product_widening_inst, SMEPseudo2Instr; - def NAME # _PSEUDO : sme_outer_product_pseudo; + def NAME # _PSEUDO : sme_outer_product_pseudo, SMEPseudo2Instr; def : Pat<(op timm32_0_3:$tile, (nxv8i1 PPR3bAny:$pn), (nxv8i1 PPR3bAny:$pm), (nxv8f16 ZPR16:$zn), (nxv8f16 ZPR16:$zm)), @@ -226,51 +244,42 @@ class sme_add_vector_to_tile_inst - : sme_add_vector_to_tile_inst<0b0, V, TileOp32, ZPR32, mnemonic> { - bits<2> ZAda; - let Inst{2} = 0b0; - let Inst{1-0} = ZAda; -} - -class sme_add_vector_to_tile_u64 - : sme_add_vector_to_tile_inst<0b1, V, TileOp64, ZPR64, mnemonic> { - bits<3> ZAda; - let Inst{2-0} = ZAda; -} - -class sme_add_vector_to_tile_pseudo +class sme_add_vector_to_tile_pseudo : Pseudo<(outs), (ins i32imm:$tile, PPR3bAny:$Pn, PPR3bAny:$Pm, zpr_ty:$Zn), []>, Sched<[]> { // Translated to the actual instructions in AArch64ISelLowering.cpp + let SMEMatrixType = za_flag; let usesCustomInserter = 1; } -def ADDHA_MPPZ_PSEUDO_S : sme_add_vector_to_tile_pseudo; -def ADDVA_MPPZ_PSEUDO_S : sme_add_vector_to_tile_pseudo; +multiclass sme_add_vector_to_tile_u32 { + def NAME : sme_add_vector_to_tile_inst<0b0, V, TileOp32, ZPR32, mnemonic>, SMEPseudo2Instr { + bits<2> ZAda; + let Inst{2} = 0b0; + let Inst{1-0} = ZAda; + } + + def _PSEUDO_S : sme_add_vector_to_tile_pseudo, SMEPseudo2Instr; -def : Pat<(int_aarch64_sme_addha - timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm), - (nxv4i32 ZPR32:$zn)), - (ADDHA_MPPZ_PSEUDO_S timm32_0_3:$tile, $pn, $pm, $zn)>; -def : Pat<(int_aarch64_sme_addva - timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm), + def : Pat<(op timm32_0_3:$tile, (nxv4i1 PPR3bAny:$pn), (nxv4i1 PPR3bAny:$pm), (nxv4i32 ZPR32:$zn)), - (ADDVA_MPPZ_PSEUDO_S timm32_0_3:$tile, $pn, $pm, $zn)>; + (!cast(NAME # _PSEUDO_S) timm32_0_3:$tile, $pn, $pm, $zn)>; +} + +multiclass sme_add_vector_to_tile_u64 { + def NAME : 
sme_add_vector_to_tile_inst<0b1, V, TileOp64, ZPR64, mnemonic>, SMEPseudo2Instr { + bits<3> ZAda; + let Inst{2-0} = ZAda; + } -let Predicates = [HasSMEI16I64] in { -def ADDHA_MPPZ_PSEUDO_D : sme_add_vector_to_tile_pseudo; -def ADDVA_MPPZ_PSEUDO_D : sme_add_vector_to_tile_pseudo; + def _PSEUDO_D : sme_add_vector_to_tile_pseudo, SMEPseudo2Instr; -def : Pat<(int_aarch64_sme_addha - timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm), - (nxv2i64 ZPR64:$zn)), - (ADDHA_MPPZ_PSEUDO_D timm32_0_7:$tile, $pn, $pm, $zn)>; -def : Pat<(int_aarch64_sme_addva - timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm), - (nxv2i64 ZPR64:$zn)), - (ADDVA_MPPZ_PSEUDO_D timm32_0_7:$tile, $pn, $pm, $zn)>; + let Predicates = [HasSMEI16I64] in { + def : Pat<(op timm32_0_7:$tile, (nxv2i1 PPR3bAny:$pn), (nxv2i1 PPR3bAny:$pm), + (nxv2i64 ZPR64:$zn)), + (!cast(NAME # _PSEUDO_D) timm32_0_7:$tile, $pn, $pm, $zn)>; + } } //===----------------------------------------------------------------------===// @@ -711,24 +720,27 @@ multiclass sme_vector_to_tile_patterns : Pseudo<(outs), (ins i32imm:$tile, MatrixIndexGPR32Op12_15:$idx, i32imm:$imm, PPR3bAny:$pg, ZPRAny:$zn), []>, Sched<[]> { // Translated to the actual instructions in AArch64ISelLowering.cpp + let SMEMatrixType = za_flag; let usesCustomInserter = 1; } multiclass sme_vector_v_to_tile { def _B : sme_vector_to_tile_inst<0b0, 0b00, !if(is_col, TileVectorOpV8, TileVectorOpH8), - is_col, sme_elm_idx0_15, ZPR8, mnemonic> { + is_col, sme_elm_idx0_15, ZPR8, mnemonic>, + SMEPseudo2Instr { bits<4> imm; let Inst{3-0} = imm; } def _H : sme_vector_to_tile_inst<0b0, 0b01, !if(is_col, TileVectorOpV16, TileVectorOpH16), - is_col, sme_elm_idx0_7, ZPR16, mnemonic> { + is_col, sme_elm_idx0_7, ZPR16, mnemonic>, + SMEPseudo2Instr { bits<1> ZAd; bits<3> imm; let Inst{3} = ZAd; @@ -736,7 +748,8 @@ multiclass sme_vector_v_to_tile { } def _S : sme_vector_to_tile_inst<0b0, 0b10, !if(is_col, TileVectorOpV32, TileVectorOpH32), - is_col, 
sme_elm_idx0_3, ZPR32, mnemonic> { + is_col, sme_elm_idx0_3, ZPR32, mnemonic>, + SMEPseudo2Instr { bits<2> ZAd; bits<2> imm; let Inst{3-2} = ZAd; @@ -744,7 +757,8 @@ multiclass sme_vector_v_to_tile { } def _D : sme_vector_to_tile_inst<0b0, 0b11, !if(is_col, TileVectorOpV64, TileVectorOpH64), - is_col, sme_elm_idx0_1, ZPR64, mnemonic> { + is_col, sme_elm_idx0_1, ZPR64, mnemonic>, + SMEPseudo2Instr { bits<3> ZAd; bits<1> imm; let Inst{3-1} = ZAd; @@ -752,7 +766,8 @@ multiclass sme_vector_v_to_tile { } def _Q : sme_vector_to_tile_inst<0b1, 0b11, !if(is_col, TileVectorOpV128, TileVectorOpH128), - is_col, sme_elm_idx0_0, ZPR128, mnemonic> { + is_col, sme_elm_idx0_0, ZPR128, mnemonic>, + SMEPseudo2Instr { bits<4> ZAd; bits<1> imm; let Inst{3-0} = ZAd; @@ -760,11 +775,11 @@ multiclass sme_vector_v_to_tile { // Pseudo instructions for lowering intrinsics, using immediates instead of // tile registers. - def _PSEUDO_B : sme_mova_insert_pseudo; - def _PSEUDO_H : sme_mova_insert_pseudo; - def _PSEUDO_S : sme_mova_insert_pseudo; - def _PSEUDO_D : sme_mova_insert_pseudo; - def _PSEUDO_Q : sme_mova_insert_pseudo; + def _PSEUDO_B : sme_mova_insert_pseudo, SMEPseudo2Instr; + def _PSEUDO_H : sme_mova_insert_pseudo, SMEPseudo2Instr; + def _PSEUDO_S : sme_mova_insert_pseudo, SMEPseudo2Instr; + def _PSEUDO_D : sme_mova_insert_pseudo, SMEPseudo2Instr; + def _PSEUDO_Q : sme_mova_insert_pseudo, SMEPseudo2Instr; defm : sme_vector_to_tile_aliases(NAME # _B), !if(is_col, TileVectorOpV8,