From 7ecd82ce19ae2b0e1abef39735ad5fa2e18e92fb Mon Sep 17 00:00:00 2001
From: Artem Belevich <tra@google.com>
Date: Thu, 25 Apr 2019 22:27:35 +0000
Subject: [PATCH] [NVPTX] Refactor generation of MMA intrinsics and
 instructions. NFC.

Generalized construction of the 'fragments' of MMA operations to provide
common primitives for constructing the ops. This will make it easier to
add new variants of the instructions that operate on integer types.

Use nested foreach loops, which makes it possible to better control
naming of the intrinsics.

This patch does not affect LLVM's output, so there are no test changes.

Differential Revision: https://reviews.llvm.org/D59389

llvm-svn: 359245
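To make the naming scheme concrete, a worked instantiation (illustrative only;
it follows directly from the WMMA_NAME_LDST helper this patch adds):

    // WMMA_NAME_LDST<"load", WMMA_REGS<"m16n16k16", "a", "f16">, "row", 1>
    // produces:
    //   intr   = "llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16"
    //   record = "int_nvvm_wmma_m16n16k16_load_a_f16_row_stride"
    // Note the field-order mismatch between the two names, which the
    // TODO inside WMMA_NAME_LDST calls out.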
---
 llvm/include/llvm/IR/IntrinsicsNVVM.td   | 258 +++++++---------
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 512 +++++++++++--------
 2 files changed, 295 insertions(+), 475 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index cf072c7..84499e6 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -37,6 +37,69 @@ def llvm_anyi64ptr_ty : LLVMAnyPointerType<llvm_i64_ty>; // (space)i64*
 // MISC
 //
 
+// Helper class for construction of n-element list<LLVMType> [t,t,...,t]
+class RepLLVMType<int N, LLVMType T> {
+  list<LLVMType> ret =
+      !if(N, !listconcat(RepLLVMType<!add(N, -1), T>.ret, [T]), []);
+}
+
+// Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
+// Geom: mnk. E.g. m8n32k16
+// Frag: [abcd]
+// PtxEltType: PTX type for the element.
+class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
+  string geom = Geom;
+  string frag = Frag;
+  string ptx_elt_type = PtxEltType;
+  string ft = frag#":"#ptx_elt_type;
+  list<LLVMType> regs = !cond(
+    // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
+    // All currently supported geometries use the same fragment format,
+    // so we only need to consider {fragment, type}.
+    !eq(ft,"a:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret,
+    !eq(ft,"b:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret,
+    !eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
+    !eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
+    !eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret,
+    !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret);
+}
+
+class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
+  string intr = "llvm.nvvm.wmma."
+                # Frag.geom
+                # "." # Op
+                # "." # Frag.frag
+                # "." # Layout
+                # !if(WithStride, ".stride", "")
+                # "." # Frag.ptx_elt_type
+                ;
+  // TODO(tra): record name should ideally use the same field order as the
+  //            intrinsic. E.g. string record = !subst("llvm", "int",
+  //                                            !subst(".", "_", llvm));
+  string record = "int_nvvm_wmma_"
+                # Frag.geom
+                # "_" # Op
+                # "_" # Frag.frag
+                # "_" # Frag.ptx_elt_type
+                # "_" # Layout
+                # !if(WithStride, "_stride", "");
+}
+
+class WMMA_NAME_MMA<string ALayout, string BLayout,
+                    WMMA_REGS C, WMMA_REGS D, int Satfinite> {
+  string llvm = "llvm.nvvm.wmma."
+                # C.geom
+                # ".mma"
+                # "." # ALayout
+                # "." # BLayout
+                # "." # D.ptx_elt_type // Intrinsic encodes 'd' first.
+                # "." # C.ptx_elt_type
+                # !if(Satfinite, ".satfinite", "");
+
+  string record = !subst(".", "_",
+                  !subst("llvm.", "int_", llvm));
+}
+
 let TargetPrefix = "nvvm" in {
   def int_nvvm_prmt : GCCBuiltin<"__nvvm_prmt">,
       Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
@@ -3889,166 +3952,69 @@ def int_nvvm_match_all_sync_i64p :
 //
 // WMMA instructions
 //
-
 // WMMA.LOAD
-class NVVM_WMMA_LD_GALSTS<string Geometry, string Abc, string Layout,
-                          string Type, LLVMType regty, int WithStride>
-  : Intrinsic<!if(!eq(Abc#Type,"cf16"),
-                  [regty, regty, regty, regty],
-                  [regty, regty, regty, regty,
-                   regty, regty, regty, regty]),
+class NVVM_WMMA_LD<WMMA_REGS Frag, string Layout, int WithStride>
+  : Intrinsic<Frag.regs,
               !if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
               [IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
-              "llvm.nvvm.wmma."
-                # Geometry
-                # ".load"
-                # "." # Abc
-                # "." # Layout
-                # !if(WithStride, ".stride", "")
-                # "." # Type>;
-
-multiclass NVVM_WMMA_LD_GALT<string Geometry, string Abc, string Layout,
-                             string Type, LLVMType regty> {
-  def _stride: NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 1>;
-  def NAME   : NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 0>;
-}
-
-multiclass NVVM_WMMA_LD_GAT<string Geometry, string Abc, string Type,
-                            LLVMType regty> {
-  defm _row: NVVM_WMMA_LD_GALT<Geometry, Abc, "row", Type, regty>;
-  defm _col: NVVM_WMMA_LD_GALT<Geometry, Abc, "col", Type, regty>;
-}
-
-multiclass NVVM_WMMA_LD_G<string Geometry> {
-  defm _a_f16: NVVM_WMMA_LD_GAT<Geometry, "a", "f16", llvm_v2f16_ty>;
-  defm _b_f16: NVVM_WMMA_LD_GAT<Geometry, "b", "f16", llvm_v2f16_ty>;
-  defm _c_f16: NVVM_WMMA_LD_GAT<Geometry, "c", "f16", llvm_v2f16_ty>;
-  defm _c_f32: NVVM_WMMA_LD_GAT<Geometry, "c", "f32", llvm_float_ty>;
-}
-
-multiclass NVVM_WMMA_LD {
-  defm _m32n8k16_load: NVVM_WMMA_LD_G<"m32n8k16">;
-  defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">;
-  defm _m8n32k16_load: NVVM_WMMA_LD_G<"m8n32k16">;
-}
-
-defm int_nvvm_wmma: NVVM_WMMA_LD;
+              WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.intr>;
 
 // WMMA.STORE.D
-class NVVM_WMMA_STD_GLSTS<string Geometry, string Layout, string Type,
-                          LLVMType regty, int WithStride,
-                          list<LLVMType> Empty=[]>
+class NVVM_WMMA_ST<WMMA_REGS Frag, string Layout, int WithStride>
   : Intrinsic<[],
               !listconcat(
                 [llvm_anyptr_ty],
-                !if(!eq(Type,"f16"),
-                    [regty, regty, regty, regty],
-                    [regty, regty, regty, regty,
-                     regty, regty, regty, regty]),
-                !if(WithStride, [llvm_i32_ty], Empty)),
+                Frag.regs,
+                !if(WithStride, [llvm_i32_ty], [])),
              [IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
-              "llvm.nvvm.wmma."
-                # Geometry
-                # ".store.d"
-                # "." # Layout
-                # !if(WithStride, ".stride", "")
-                # "." # Type>;
-
-multiclass NVVM_WMMA_STD_GLT<string Geometry, string Layout, string Type,
-                             LLVMType regty> {
-  def _stride: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 1>;
-  def NAME: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 0>;
-}
-
-multiclass NVVM_WMMA_STD_GT<string Geometry, string Type, LLVMType regty> {
-  defm _row: NVVM_WMMA_STD_GLT<Geometry, "row", Type, regty>;
-  defm _col: NVVM_WMMA_STD_GLT<Geometry, "col", Type, regty>;
-}
-multiclass NVVM_WMMA_STD_G<string Geometry> {
-  defm _d_f16: NVVM_WMMA_STD_GT<Geometry, "f16", llvm_v2f16_ty>;
-  defm _d_f32: NVVM_WMMA_STD_GT<Geometry, "f32", llvm_float_ty>;
-}
-
-multiclass NVVM_WMMA_STD {
-  defm _m32n8k16_store: NVVM_WMMA_STD_G<"m32n8k16">;
-  defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">;
-  defm _m8n32k16_store: NVVM_WMMA_STD_G<"m8n32k16">;
+              WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>;
+
+// Create all load/store variants
+foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
+  foreach layout = ["row", "col"] in {
+    foreach stride = [0, 1] in {
+      foreach frag = [WMMA_REGS<geom, "a", "f16">,
+                      WMMA_REGS<geom, "b", "f16">,
+                      WMMA_REGS<geom, "c", "f16">,
+                      WMMA_REGS<geom, "c", "f32">] in {
+        def WMMA_NAME_LDST<"load", frag, layout, stride>.record
+          : NVVM_WMMA_LD<frag, layout, stride>;
+      }
+      foreach frag = [WMMA_REGS<geom, "d", "f16">,
+                      WMMA_REGS<geom, "d", "f32">] in {
+        def WMMA_NAME_LDST<"store", frag, layout, stride>.record
+          : NVVM_WMMA_ST<frag, layout, stride>;
+      }
+    }
+  }
 }
 
-defm int_nvvm_wmma: NVVM_WMMA_STD;
-
 // WMMA.MMA
-class NVVM_WMMA_MMA_GABDCS<string Geometry,
-                           string ALayout, string BLayout,
-                           string DType, LLVMType d_ty,
-                           string CType, LLVMType c_ty,
-                           string Satfinite = "">
-  : Intrinsic<!if(!eq(DType,"f16"),
-                  [d_ty, d_ty, d_ty, d_ty],
-                  [d_ty, d_ty, d_ty, d_ty,
-                   d_ty, d_ty, d_ty, d_ty]),
-              !listconcat(
-                [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
-                 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
-                 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
-                 llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
-                !if(!eq(CType,"f16"),
-                    [c_ty, c_ty, c_ty, c_ty],
-                    [c_ty, c_ty, c_ty, c_ty,
-                     c_ty, c_ty, c_ty, c_ty])),
+class NVVM_WMMA_MMA<string ALayout, string BLayout,
+                    WMMA_REGS C, WMMA_REGS D, int Satfinite>
+  : Intrinsic<D.regs,
+              !listconcat(
+                WMMA_REGS<D.geom, "a", "f16">.regs,
+                WMMA_REGS<D.geom, "b", "f16">.regs,
+                C.regs),
              [IntrNoMem],
-              "llvm.nvvm.wmma."
-                # Geometry
-                # ".mma"
-                # "." # ALayout
-                # "." # BLayout
-                # "." # DType
-                # "." # CType
-                # Satfinite> {
-}
-
-multiclass NVVM_WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout,
-                               string DType, LLVMType d_ty,
-                               string CType, LLVMType c_ty> {
-  def NAME : NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
-                                  DType, d_ty, CType, c_ty>;
-  def _satfinite: NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
-                                       DType, d_ty, CType, c_ty, ".satfinite">;
-}
-
-multiclass NVVM_WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout,
-                              string DType, LLVMType d_ty> {
-  defm _f16: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_ty,
-                                 "f16", llvm_v2f16_ty>;
-  defm _f32: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_ty,
-                                 "f32", llvm_float_ty>;
-}
-
-multiclass NVVM_WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> {
-  defm _f16: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout,
-                                "f16", llvm_v2f16_ty>;
-  defm _f32: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout,
-                                "f32", llvm_float_ty>;
-}
-
-multiclass NVVM_WMMA_MMA_GA<string Geometry, string ALayout> {
-  defm _col: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "col">;
-  defm _row: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "row">;
-}
-
-multiclass NVVM_WMMA_MMA_G<string Geometry> {
-  defm _col: NVVM_WMMA_MMA_GA<Geometry, "col">;
-  defm _row: NVVM_WMMA_MMA_GA<Geometry, "row">;
-}
-
-multiclass NVVM_WMMA_MMA {
-  defm _m32n8k16_mma : NVVM_WMMA_MMA_G<"m32n8k16">;
-  defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">;
-  defm _m8n32k16_mma : NVVM_WMMA_MMA_G<"m8n32k16">;
+              WMMA_NAME_MMA<ALayout, BLayout, C, D, Satfinite>.llvm>;
+
+foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
+  foreach layout_a = ["row", "col"] in {
+    foreach layout_b = ["row", "col"] in {
+      foreach frag_c = [WMMA_REGS<geom, "c", "f16">,
+                        WMMA_REGS<geom, "c", "f32">] in {
+        foreach frag_d = [WMMA_REGS<geom, "d", "f16">,
+                          WMMA_REGS<geom, "d", "f32">] in {
+          foreach satf = [0, 1] in {
+            def WMMA_NAME_MMA<layout_a, layout_b, frag_c, frag_d, satf>.record
+              : NVVM_WMMA_MMA<layout_a, layout_b, frag_c, frag_d, satf>;
+          }
+        }
+      }
+    }
+  }
 }
 
-defm int_nvvm_wmma : NVVM_WMMA_MMA;
-
 } // let TargetPrefix = "nvvm"
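For a sense of what the loops above expand to (my arithmetic, not part of the
patch): the load/store block defines 3 geometries x 2 layouts x 2 stride
variants x (4 load + 2 store) fragments = 72 intrinsics, and the MMA block
3 x 2 x 2 x 2 x 2 x 2 = 96 intrinsics. These 168 records are the same set the
removed multiclass tower produced, which is why the change is NFC.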
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8f9251b..a5f0d79 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -26,7 +26,17 @@ def immDouble1 : PatLeaf<(fpimm), [{
   return (d==1.0);
 }]>;
 
-
+def AS_match {
+  code generic = [{
+    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
+  }];
+  code shared = [{
+    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
+  }];
+  code global = [{
+    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
+  }];
+}
 
 //-----------------------------------
 // Synchronization and shuffle functions
@@ -1006,17 +1016,11 @@ def INT_FNS_iii : INT_FNS_MBO<(ins i32imm:$mask, i32imm:$base, i32imm:$offset)>;
 //-----------------------------------
 
 class ATOMIC_GLOBAL_CHK <dag ops, dag frag>
- : PatFrag<ops, frag, [{
-   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
- }]>;
+ : PatFrag<ops, frag, AS_match.global>;
 class ATOMIC_SHARED_CHK <dag ops, dag frag>
- : PatFrag<ops, frag, [{
-   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
- }]>;
+ : PatFrag<ops, frag, AS_match.shared>;
 class ATOMIC_GENERIC_CHK <dag ops, dag frag>
- : PatFrag<ops, frag, [{
-   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
- }]>;
+ : PatFrag<ops, frag, AS_match.generic>;
 
 multiclass F_ATOMIC_2_imp<...>

@@ ... @@
-//
-// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
-//
-
 class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
 
+// Generates list of n sequential register names.
+class RegSeq<int n, string prefix> {
+  list<string> ret = !if(n, !listconcat(RegSeq<!add(n, -1), prefix>.ret,
+                                        [prefix # !add(n, -1)]),
+                         []);
+}
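A quick illustration of the recursion (mine, not the patch's): RegSeq peels one
element off per step and appends the name on the way back up, so

    // RegSeq<4, "ra">.ret  evaluates to  ["ra0", "ra1", "ra2", "ra3"]

and WMMA_REGINFO below feeds such lists to !dag to build the per-fragment
(outs ...) and (ins ...) operand lists.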
-class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
-                        string Space, string Type, NVPTXRegClass regclass,
-                        DAGOperand SrcOp, bit WithStride>
-  : EmptyNVPTXInst,
-    Requires<[!if(!eq(Geometry, "m16n16k16"),
-                  hasPTX60,
-                  hasPTX61),
-              hasSM70]> {
-  // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the
-  // intrinsic for this function.
-  PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_"
-                                       # Geometry # "_load_"
-                                       # !subst("c", "c_" # Type, Abc)
-                                       # "_" # Layout
-                                       # !subst(".", "_", Space)
-                                       # !if(WithStride,"_stride", "")
-                                       # "_Intr");
-  dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
-  dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
-  dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47));
-
-  dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
-  dag Ins = !con((ins SrcOp:$src), StrideArg);
+// Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
+// In addition to the target-independent fields provided by WMMA_REGS, it adds
+// the fields commonly used to implement a specific PTX instruction -- register
+// types and names, constraints, parts of assembly, etc.
+class WMMA_REGINFO<string Geom, string Frag, string PtxEltType>
+      : WMMA_REGS<Geom, Frag, PtxEltType> {
+  // NVPTX register types used to carry fragment data.
+  NVPTXRegClass regclass = !cond(
+    !eq(PtxEltType, "f16") : Float16x2Regs,
+    !eq(PtxEltType, "f32") : Float32Regs);
+
+  // Instruction input/output arguments for the fragment.
+  list<NVPTXRegClass> ptx_regs = !foreach(tmp, regs, regclass);
+
+  // List of register names for the fragment -- ["ra0", "ra1",...]
+  list<string> reg_names = RegSeq<!size(ptx_regs), "r"#frag>.ret;
+  // Generates "{{$r0, $r1,.... $rN-1}}" for use in asm string construction.
+  string regstring = "{{$" # !head(reg_names)
+                     # !foldl("", !tail(reg_names), a, b,
+                              !strconcat(a, ", $", b))
+                     # "}}";
+
+  // Predicates for a particular fragment variant. Technically these are
+  // per-instruction predicates, but currently all fragments that can be used
+  // in a given instruction are subject to the same constraints, so an
+  // instruction can use predicates from any of its fragments. If/when this is
+  // no longer the case, we can concat all per-fragment predicates to enforce
+  // that all fragments of the instruction are viable.
+  list<Predicate> Predicates = !cond(
+    // fp16 -> fp16/fp32 @ m16n16k16
+    !and(!eq(Geom, "m16n16k16"),
+         !or(!eq(PtxEltType, "f16"),
+             !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60],
+
+    // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
+    !and(!or(!eq(Geom, "m8n32k16"),
+             !eq(Geom, "m32n8k16")),
+         !or(!eq(PtxEltType, "f16"),
+             !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]);
+
+  // Template DAGs for instruction inputs/output.
+  dag Outs = !dag(outs, ptx_regs, reg_names);
+  dag Ins = !dag(ins, ptx_regs, reg_names);
+}
+
+class BuildPattern<dag Outs, PatFrag IntrMatcher, dag Ins> {
   // Build a dag pattern that matches the intrinsic call.
   // We want a dag that looks like this:
   // (set <output args>, (intrinsic <input arguments>)) where input and
@@ -7430,277 +7458,127 @@ class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
   // output arguments are named patterns that would match corresponding
   // DAG arguments.
   // First we construct (set <output arguments>) from the instruction's outs
   // dag by replacing the dag operator 'outs' with 'set'.
   dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp));
   // Similarly, construct (intrinsic <input arguments>) from the instruction's
   // input arguments, replacing operands with the patterns that match them and
   // the operator 'ins' with the intrinsic.
   dag PatArgs = !foreach(tmp, Ins,
                          !subst(imem, ADDRvar,
                          !subst(MEMri64, ADDRri64,
                          !subst(MEMri, ADDRri,
                          !subst(ins, IntrMatcher, tmp)))));
   // Finally, concatenate both parts. !con() requires both dags to have the
   // same operator, so we wrap PatArgs in a (set ...) dag.
-  let Pattern = [!con(PatOuts, (set PatArgs))];
-  let OutOperandList = Outs;
-  let InOperandList = Ins;
-  let AsmString = "wmma.load."
-                  # Abc
-                  # ".sync"
-                  # "." # Layout
-                  # "." # Geometry
-                  # Space
-                  # "." # Type # " \t"
-                  # !if(!eq(Abc#Type, "cf16"),
-                        "{{$r0, $r1, $r2, $r3}}",
-                        "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
-                  # ", [$src]"
-                  # !if(WithStride, ", $ldm", "")
-                  # ";";
+  dag ret = !con(PatOuts, (set PatArgs));
 }
 
+//
+// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
+//
+
-class WMMA_LOAD_INTR_HELPER<string Geometry, string Abc, string Layout,
-                            string Space, string Type, bit WithStride>
+class WMMA_LOAD_INTR_HELPER<WMMA_REGINFO Frag, string Layout, string Space,
+                            bit WithStride>
   : PatFrag <(ops),(ops)> {
   // Intrinsic that matches this instruction.
-  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma"
-                                    # "_" # Geometry # "_load_"
-                                    # Abc # "_" # Type # "_" # Layout
-                                    # !if(WithStride,"_stride", ""));
-  code match_generic = [{
-    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
-  }];
-  code match_shared = [{
-    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
-  }];
-  code match_global = [{
-    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
-  }];
-
+  Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_LDST<"load", Frag, Layout,
+                                                   WithStride>.record);
   let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
   let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
-  let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
-                          !if(!eq(Space, ".global"), match_global,
-                              match_generic));
+  let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared,
+                            !eq(Space, ".global"): AS_match.global,
+                            1: AS_match.generic);
 }
 
-multiclass WMMA_LOAD_GALSTS<string Geometry, string Abc, string Layout,
-                            string Space, string Type, NVPTXRegClass regclass,
-                            bit WithStride> {
-  def _avar: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
-                               imem, WithStride>;
-  def _areg: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
-                               Int32Regs, WithStride>;
-  def _areg64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
-                                 Int64Regs, WithStride>;
-  def _ari: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
-                              MEMri, WithStride>;
-  def _ari64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
-                                MEMri64, WithStride>;
-}
-
-multiclass WMMA_LOAD_GALSTSh<string Geometry, string Abc, string Layout,
-                             string Space, string Type, NVPTXRegClass regclass,
-                             bit WithStride> {
-  // Define a PatFrag that matches the appropriate intrinsic that loads from
-  // the given address space.
-  def _Intr: WMMA_LOAD_INTR_HELPER<Geometry, Abc, Layout, Space, Type,
-                                   WithStride>;
-  defm NAME: WMMA_LOAD_GALSTS<Geometry, Abc, Layout, Space, Type, regclass,
-                              WithStride>;
-}
-
-multiclass WMMA_LOAD_GALST<string Geometry, string Abc, string Layout,
-                           string Space, string Type, NVPTXRegClass regclass> {
-  defm _stride: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type,
-                                  regclass, 1>;
-  defm NAME: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type,
-                               regclass, 0>;
-}
-
-multiclass WMMA_LOAD_GALT<string Geometry, string Abc, string Layout,
-                          string Type, NVPTXRegClass regclass> {
-  defm _global: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".global",
-                                Type, regclass>;
-  defm _shared: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".shared",
-                                Type, regclass>;
-  defm NAME: WMMA_LOAD_GALST<Geometry, Abc, Layout, "", Type, regclass>;
-}
-
-multiclass WMMA_LOAD_GAT<string Geometry, string Abc,
-                         string Type, NVPTXRegClass regclass> {
-  defm _row: WMMA_LOAD_GALT<Geometry, Abc, "row", Type, regclass>;
-  defm _col: WMMA_LOAD_GALT<Geometry, Abc, "col", Type, regclass>;
-}
-
-multiclass WMMA_LOAD_G<string Geometry> {
-  defm _load_a: WMMA_LOAD_GAT<Geometry, "a", "f16", Float16x2Regs>;
-  defm _load_b: WMMA_LOAD_GAT<Geometry, "b", "f16", Float16x2Regs>;
-  defm _load_c_f16: WMMA_LOAD_GAT<Geometry, "c", "f16", Float16x2Regs>;
-  defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>;
-}
+class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride,
+                DAGOperand SrcOp>
+  : EmptyNVPTXInst,
+    Requires<Frag.Predicates> {
+  // Pattern that matches the intrinsic for this instruction variant.
+  PatFrag IntrMatcher = WMMA_LOAD_INTR_HELPER<Frag, Layout, Space, WithStride>;
+  dag Ins = !con((ins SrcOp:$src),
+                 !if(WithStride, (ins Int32Regs:$ldm), (ins)));
+
+  let Pattern = [BuildPattern<Frag.Outs, IntrMatcher, Ins>.ret];
+  let OutOperandList = Frag.Outs;
+  let InOperandList = Ins;
+  let AsmString = "wmma.load."
+                  # Frag.frag
+                  # ".sync"
+                  # "." # Layout
+                  # "." # Frag.geom
+                  # Space
+                  # "." # Frag.ptx_elt_type # " \t"
+                  # Frag.regstring
+                  # ", [$src]"
+                  # !if(WithStride, ", $ldm", "")
+                  # ";";
+}
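As a sketch of what one instantiation assembles to (derived from the AsmString
concatenation above; the "ra" prefix follows the reg_names comment in
WMMA_REGINFO): WMMA_LOAD<WMMA_REGINFO<"m16n16k16", "a", "f16">, "row",
".global", 1, imem> yields

    wmma.load.a.sync.row.m16n16k16.global.f16
        {{$ra0, $ra1, $ra2, $ra3, $ra4, $ra5, $ra6, $ra7}}, [$src], $ldm;

with eight Float16x2Regs outputs named by RegSeq and matched against the
corresponding load intrinsic by WMMA_LOAD_INTR_HELPER.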
-defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">;
-defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
-defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">;
-
 //
 // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
 //
-class WMMA_STORE_D_GLSTSO<string Geometry, string Layout, string Space,
-                          string Type, NVPTXRegClass regclass,
-                          bit WithStride, DAGOperand DstOp>
-  : EmptyNVPTXInst,
-    Requires<[!if(!eq(Geometry, "m16n16k16"),
-                  hasPTX60,
-                  hasPTX61),
-              hasSM70]> {
-  PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA"
-                                       # "_" # Geometry # "_store_d"
-                                       # "_" # Type
-                                       # "_" # Layout
-                                       # !subst(".", "_", Space)
-                                       # !if(WithStride,"_stride", "")
-                                       # "_Intr");
-  dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1,
-                    regclass:$r2, regclass:$r3);
-  dag InsR47 = (ins regclass:$r4, regclass:$r5,
-                    regclass:$r6, regclass:$r7);
-  dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47));
-  dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
-  dag Ins = !con(InsR, StrideArg);
-
-  // Construct the pattern to match corresponding intrinsic call. See the
-  // details in the comments in WMMA_LOAD_ALSTOS.
-  dag PatArgs = !foreach(tmp, Ins,
-                         !subst(imem, ADDRvar,
-                         !subst(MEMri64, ADDRri64,
-                         !subst(MEMri, ADDRri,
-                         !subst(ins, IntrMatcher, tmp)))));
-  let Pattern = [PatArgs];
-  let OutOperandList = (outs);
-  let InOperandList = Ins;
-  let AsmString = "wmma.store.d.sync."
-                  # Layout
-                  # "." # Geometry
-                  # Space
-                  # "." # Type
-                  # " \t[$src],"
-                  # !if(!eq(Type,"f16"),
-                        "{{$r0, $r1, $r2, $r3}}",
-                        "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
-                  # !if(WithStride, ", $ldm", "")
-                  # ";";
-}
-
-class WMMA_STORE_INTR_HELPER<string Geometry, string Layout, string Space,
-                             string Type, bit WithStride>
-  : PatFrag <(ops),(ops)> {
-  // Intrinsic that matches this instruction.
-  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
-                                    # Geometry
-                                    # "_store_d"
-                                    # "_" # Type
-                                    # "_" # Layout
-                                    # !if(WithStride, "_stride", ""));
-  code match_generic = [{
-    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
-  }];
-  code match_shared = [{
-    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
-  }];
-  code match_global = [{
-    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
-  }];
-
-  dag Args = !if(!eq(Type,"f16"),
-                 (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3),
-                 (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3,
-                      node:$r4, node:$r5, node:$r6, node:$r7));
-  dag StrideArg = !if(WithStride, (ops node:$ldm), (ops));
-  let Operands = !con(Args, StrideArg);
-  let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
-  let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
-                          !if(!eq(Space, ".global"), match_global,
-                              match_generic));
-}
-
-multiclass WMMA_STORE_D_GLSTS<string Geometry, string Layout, string Space,
-                              string Type, NVPTXRegClass regclass,
-                              bit WithStride> {
-  def _avar: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
-                                 WithStride, imem>;
-  def _areg: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
-                                 WithStride, Int32Regs>;
-  def _areg64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
-                                   WithStride, Int64Regs>;
-  def _ari: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
-                                WithStride, MEMri>;
-  def _ari64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
-                                  WithStride, MEMri64>;
-}
-
-multiclass WMMA_STORE_D_GLSTSh<string Geometry, string Layout, string Space,
-                               string Type, NVPTXRegClass regclass,
-                               bit WithStride> {
-  // Define a PatFrag that matches the appropriate intrinsic that stores to
-  // the given address space.
-  def _Intr: WMMA_STORE_INTR_HELPER<Geometry, Layout, Space, Type, WithStride>;
-  defm NAME: WMMA_STORE_D_GLSTS<Geometry, Layout, Space, Type, regclass,
-                                WithStride>;
-}
-
-multiclass WMMA_STORE_D_GLST<string Geometry, string Layout, string Space,
-                             string Type, NVPTXRegClass regclass> {
-  defm _stride: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 1>;
-  defm NAME: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 0>;
-}
-
-multiclass WMMA_STORE_D_GLT<string Geometry, string Layout, string Type,
-                            NVPTXRegClass regclass> {
-  defm _global: WMMA_STORE_D_GLST<Geometry, Layout, ".global", Type, regclass>;
-  defm _shared: WMMA_STORE_D_GLST<Geometry, Layout, ".shared", Type, regclass>;
-  defm NAME: WMMA_STORE_D_GLST<Geometry, Layout, "", Type, regclass>;
-}
-
-multiclass WMMA_STORE_D_GT<string Geometry, string Type,
-                           NVPTXRegClass regclass> {
-  defm _row: WMMA_STORE_D_GLT<Geometry, "row", Type, regclass>;
-  defm _col: WMMA_STORE_D_GLT<Geometry, "col", Type, regclass>;
-}
-
-multiclass WMMA_STORE_D_G<string Geometry> {
-  defm _store_d_f16: WMMA_STORE_D_GT<Geometry, "f16", Float16x2Regs>;
-  defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>;
-}
-
-defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">;
-defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
-defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">;
+class WMMA_STORE_INTR_HELPER<WMMA_REGINFO Frag, string Layout, string Space,
+                             bit WithStride>
+  : PatFrag <(ops),(ops)> {
+  // Intrinsic that matches this instruction.
+  Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_LDST<"store", Frag, Layout,
+                                                   WithStride>.record);
+  let Operands = !con((ops node:$dst),
+                      !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names),
+                      !if(WithStride, (ops node:$ldm), (ops)));
+  let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
+  let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared,
+                            !eq(Space, ".global"): AS_match.global,
+                            1: AS_match.generic);
+}
+
+class WMMA_STORE<WMMA_REGINFO Frag, string Layout, string Space,
+                 bit WithStride, DAGOperand DstOp>
+  : EmptyNVPTXInst,
+    Requires<Frag.Predicates> {
+  PatFrag IntrMatcher = WMMA_STORE_INTR_HELPER<Frag, Layout, Space, WithStride>;
+  dag Ins = !con((ins DstOp:$src),
+                 Frag.Ins,
+                 !if(WithStride, (ins Int32Regs:$ldm), (ins)));
+  let Pattern = [BuildPattern<(set), IntrMatcher, Ins>.ret];
+  let OutOperandList = (outs);
+  let InOperandList = Ins;
+  let AsmString = "wmma.store.d.sync."
+                  # Layout
+                  # "." # Frag.geom
+                  # Space
+                  # "." # Frag.ptx_elt_type
+                  # " \t[$src],"
+                  # Frag.regstring
+                  # !if(WithStride, ", $ldm", "")
+                  # ";";
+}
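For scale (my arithmetic, not stated in the patch): the loops that follow emit
3 geometries x 2 layouts x 2 stride variants x 3 address spaces x 5 address
modes = 180 variants per fragment, i.e. 720 load and 360 store instruction
records, the same set previously spelled out through the
_avar/_areg/_areg64/_ari/_ari64 multiclass stack.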
+
+// Create all load/store variants
+foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
+  foreach layout = ["row", "col"] in {
+    foreach stride = [0, 1] in {
+      foreach space = [".global", ".shared", ""] in {
+        foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
+          foreach frag = [WMMA_REGINFO<geom, "a", "f16">,
+                          WMMA_REGINFO<geom, "b", "f16">,
+                          WMMA_REGINFO<geom, "c", "f16">,
+                          WMMA_REGINFO<geom, "c", "f32">] in {
+            def : WMMA_LOAD<frag, layout, space, stride, addr>;
+          }
+          foreach frag = [WMMA_REGINFO<geom, "d", "f16">,
+                          WMMA_REGINFO<geom, "d", "f32">] in {
+            def : WMMA_STORE<frag, layout, space, stride, addr>;
+          }
+        } // addr
+      } // space
+    } // stride
+  } // layout
+} // geom
 
 // WMMA.MMA
-class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
-                      string DType, NVPTXRegClass d_reg,
-                      string CType, NVPTXRegClass c_reg,
-                      NVPTXRegClass ab_reg,
-                      string Satfinite = "">
+class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+               WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+               string ALayout, string BLayout, int Satfinite>
   : EmptyNVPTXInst,
-    Requires<[!if(!eq(Geometry, "m16n16k16"),
-                  hasPTX60,
-                  hasPTX61),
-              hasSM70]> {
-  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
-                                    # Geometry
-                                    # "_mma"
-                                    # "_" # ALayout
-                                    # "_" # BLayout
-                                    # "_" # DType
-                                    # "_" # CType
-                                    # !subst(".", "_", Satfinite));
-  dag Outs = !if(!eq(DType,"f16"),
-                 (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
-                 (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
-                       d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7));
-  dag InsExtraCArgs = !if(!eq(CType,"f16"),
-                          (ins),
-                          (ins c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7));
-  dag Ins = !con((ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
-                      ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
-                      ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
-                      ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
-                      c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3),
-                 InsExtraCArgs);
-
-  // Construct the pattern to match corresponding intrinsic call. See the
-  // details in the comments in WMMA_LOAD_ALSTOS.
+    Requires<FragA.Predicates> {
+  //Intrinsic Intr = int_nvvm_suld_1d_v4i32_zero;
+  Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_MMA<ALayout, BLayout,
+                                                  FragC, FragD,
+                                                  Satfinite>.record);
+  dag Outs = FragD.Outs;
+  dag Ins = !con(FragA.Ins,
+                 FragB.Ins,
+                 FragC.Ins);
+
+  // Construct the pattern to match the corresponding intrinsic call.
+  // mma does not load/store anything, so we don't need complex operand
+  // matching here.
+  dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp));
+  dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp));
+  let Pattern = [!con(PatOuts, (set PatArgs))];
@@ -7709,54 +7587,30 @@ class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
   let OutOperandList = Outs;
   let InOperandList = Ins;
   let AsmString = "wmma.mma.sync."
                   # ALayout
                   # "." # BLayout
-                  # "." # Geometry
-                  # "." # DType
-                  # "." # CType
-                  # Satfinite # "\n\t\t"
-                  # !if(!eq(DType,"f16"),
-                        "{{$d0, $d1, $d2, $d3}},\n\t\t",
-                        "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
-                  # "{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
-                  # "{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
-                  # !if(!eq(CType,"f16"),
-                        "{{$c0, $c1, $c2, $c3}};",
-                        "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};");
-}
-
-multiclass WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout,
-                          string DType, NVPTXRegClass d_reg,
-                          string CType, NVPTXRegClass c_reg> {
-  def _satfinite: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout, DType, d_reg,
-                                  CType, c_reg, Float16x2Regs, ".satfinite">;
-  def NAME: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout, DType, d_reg,
-                            CType, c_reg, Float16x2Regs>;
-}
-
-multiclass WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout,
-                         string DType, NVPTXRegClass d_reg> {
-  defm _f16: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg,
-                            "f16", Float16x2Regs>;
-  defm _f32: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg,
-                            "f32", Float32Regs>;
-}
-
-multiclass WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> {
-  defm _f16: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", Float16x2Regs>;
-  defm _f32: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", Float32Regs>;
-}
-
-multiclass WMMA_MMA_GA<string Geometry, string ALayout> {
-  defm _col: WMMA_MMA_GAB<Geometry, ALayout, "col">;
-  defm _row: WMMA_MMA_GAB<Geometry, ALayout, "row">;
-}
-
-multiclass WMMA_MMA_G<string Geometry> {
-  defm _col: WMMA_MMA_GA<Geometry, "col">;
-  defm _row: WMMA_MMA_GA<Geometry, "row">;
+                  # "." # FragA.geom
+                  # "." # FragD.ptx_elt_type
+                  # "." # FragC.ptx_elt_type
+                  # !if(Satfinite, ".satfinite", "") # "\n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ";";
 }
 
-defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">;
-defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;
-defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">;
+foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
+  foreach layout_a = ["row", "col"] in {
+    foreach layout_b = ["row", "col"] in {
+      foreach frag_c = [WMMA_REGINFO<geom, "c", "f16">,
+                        WMMA_REGINFO<geom, "c", "f32">] in {
+        foreach frag_d = [WMMA_REGINFO<geom, "d", "f16">,
+                          WMMA_REGINFO<geom, "d", "f32">] in {
+          foreach satf = [0, 1] in {
+            def : WMMA_MMA<WMMA_REGINFO<geom, "a", "f16">,
+                           WMMA_REGINFO<geom, "b", "f16">,
+                           frag_c, frag_d, layout_a, layout_b, satf>;
+          } // satf
+        } // frag_d
+      } // frag_c
+    } // layout_b
+  } // layout_a
+} // geom
-- 
2.7.4
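A closing sketch of one MMA instantiation (illustrative; derived from the
definitions above): for geom = "m16n16k16", layout_a = "row",
layout_b = "col", an f32 accumulator and f32 result, and satf = 0, the loop
defines a WMMA_MMA whose pattern matches
int_nvvm_wmma_m16n16k16_mma_row_col_f32_f32
(llvm.nvvm.wmma.m16n16k16.mma.row.col.f32.f32) and whose AsmString expands to

    wmma.mma.sync.row.col.m16n16k16.f32.f32
        {{$rd0, ..., $rd7}},
        {{$ra0, ..., $ra7}},
        {{$rb0, ..., $rb7}},
        {{$rc0, ..., $rc7}};

where each tuple is the fragment's regstring (elided here for brevity).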