// MISC
//
+// Helper class for construction of n-element list<LLVMtype> [t,t,...,t]
+class RepLLVMType<int N, LLVMType T> {
+ list<LLVMType> ret = !if(N, !listconcat(RepLLVMType<!add(N,-1), T>.ret, [T]), []);
+}
+
+// Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
+// Geom: m<M>n<N>k<K>. E.g. m8n32k16
+// Frag: [abcd]
+// PtxEltType: PTX type for the element.
+class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
+ string geom = Geom;
+ string frag = Frag;
+ string ptx_elt_type = PtxEltType;
+ string ft = frag#":"#ptx_elt_type;
+ list<LLVMType> regs = !cond(
+ // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
+ // All currently supported geometries use the same fragment format,
+ // so we only need to consider {fragment, type}.
+ !eq(ft,"a:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret,
+ !eq(ft,"b:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret,
+ !eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
+ !eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
+ !eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret,
+ !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret);
+}
+
+class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
+ string intr = "llvm.nvvm.wmma."
+ # Frag.geom
+ # "." # Op
+ # "." # Frag.frag
+ # "." # Layout
+ # !if(WithStride, ".stride", "")
+ # "." # Frag.ptx_elt_type
+ ;
+ // TODO(tra): record name should ideally use the same field order as the intrinsic.
+ // E.g. string record = !subst("llvm", "int",
+ // !subst(".", "_", llvm));
+ string record = "int_nvvm_wmma_"
+ # Frag.geom
+ # "_" # Op
+ # "_" # Frag.frag
+ # "_" # Frag.ptx_elt_type
+ # "_" # Layout
+ # !if(WithStride, "_stride", "");
+}
+
+class WMMA_NAME_MMA<string ALayout, string BLayout,
+ WMMA_REGS C, WMMA_REGS D,
+ int Satfinite> {
+ string llvm = "llvm.nvvm.wmma."
+ # C.geom
+ # ".mma"
+ # "." # ALayout
+ # "." # BLayout
+ # "." # D.ptx_elt_type // Intrinsic encodes 'd' first.
+ # "." # C.ptx_elt_type
+ # !if(Satfinite, ".satfinite", "");
+
+ string record = !subst(".", "_",
+ !subst("llvm.", "int_", llvm));
+}
+
let TargetPrefix = "nvvm" in {
def int_nvvm_prmt : GCCBuiltin<"__nvvm_prmt">,
Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
//
// WMMA instructions
//
-
// WMMA.LOAD
-class NVVM_WMMA_LD_GALSTS<string Geometry, string Abc, string Layout,
- string Type, LLVMType regty, int WithStride>
- : Intrinsic<!if(!eq(Abc#Type,"cf16"),
- [regty, regty, regty, regty],
- [regty, regty, regty, regty,
- regty, regty, regty, regty]),
+class NVVM_WMMA_LD<WMMA_REGS Frag, string Layout, int WithStride>
+ : Intrinsic<Frag.regs,
!if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
[IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
- "llvm.nvvm.wmma."
- # Geometry
- # ".load"
- # "." # Abc
- # "." # Layout
- # !if(WithStride, ".stride", "")
- # "." # Type>;
-
-multiclass NVVM_WMMA_LD_GALT<string Geometry, string Abc, string Layout,
- string Type, LLVMType regty> {
- def _stride: NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 1>;
- def NAME : NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 0>;
-}
-
-multiclass NVVM_WMMA_LD_GAT<string Geometry, string Abc,
- string Type, LLVMType regty> {
- defm _row: NVVM_WMMA_LD_GALT<Geometry, Abc, "row", Type, regty>;
- defm _col: NVVM_WMMA_LD_GALT<Geometry, Abc, "col", Type, regty>;
-}
-
-multiclass NVVM_WMMA_LD_G<string Geometry> {
- defm _a_f16: NVVM_WMMA_LD_GAT<Geometry, "a", "f16", llvm_v2f16_ty>;
- defm _b_f16: NVVM_WMMA_LD_GAT<Geometry, "b", "f16", llvm_v2f16_ty>;
- defm _c_f16: NVVM_WMMA_LD_GAT<Geometry, "c", "f16", llvm_v2f16_ty>;
- defm _c_f32: NVVM_WMMA_LD_GAT<Geometry, "c", "f32", llvm_float_ty>;
-}
-
-multiclass NVVM_WMMA_LD {
- defm _m32n8k16_load: NVVM_WMMA_LD_G<"m32n8k16">;
- defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">;
- defm _m8n32k16_load: NVVM_WMMA_LD_G<"m8n32k16">;
-}
-
-defm int_nvvm_wmma: NVVM_WMMA_LD;
+ WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.intr>;
// WMMA.STORE.D
-class NVVM_WMMA_STD_GLSTS<string Geometry, string Layout,
- string Type, LLVMType regty, int WithStride,
- // This is only used to create a typed empty array we
- // need to pass to !if below.
- list<LLVMType>Empty=[]>
+class NVVM_WMMA_ST<WMMA_REGS Frag, string Layout, int WithStride>
: Intrinsic<[],
!listconcat(
[llvm_anyptr_ty],
- !if(!eq(Type,"f16"),
- [regty, regty, regty, regty],
- [regty, regty, regty, regty,
- regty, regty, regty, regty]),
- !if(WithStride, [llvm_i32_ty], Empty)),
+ Frag.regs,
+ !if(WithStride, [llvm_i32_ty], [])),
[IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
- "llvm.nvvm.wmma."
- # Geometry
- # ".store.d"
- # "." # Layout
- # !if(WithStride, ".stride", "")
- # "." # Type>;
-
-multiclass NVVM_WMMA_STD_GLT<string Geometry, string Layout,
- string Type, LLVMType regty> {
- def _stride: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 1>;
- def NAME: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 0>;
-}
-
-multiclass NVVM_WMMA_STD_GT<string Geometry, string Type, LLVMType regty> {
- defm _row: NVVM_WMMA_STD_GLT<Geometry, "row", Type, regty>;
- defm _col: NVVM_WMMA_STD_GLT<Geometry, "col", Type, regty>;
-}
-multiclass NVVM_WMMA_STD_G<string Geometry> {
- defm _d_f16: NVVM_WMMA_STD_GT<Geometry, "f16", llvm_v2f16_ty>;
- defm _d_f32: NVVM_WMMA_STD_GT<Geometry, "f32", llvm_float_ty>;
-}
-
-multiclass NVVM_WMMA_STD {
- defm _m32n8k16_store: NVVM_WMMA_STD_G<"m32n8k16">;
- defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">;
- defm _m8n32k16_store: NVVM_WMMA_STD_G<"m8n32k16">;
+ WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>;
+
+// Create all load/store variants
+foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
+ foreach layout = ["row", "col"] in {
+ foreach stride = [0, 1] in {
+ foreach frag = [WMMA_REGS<geom, "a", "f16">,
+ WMMA_REGS<geom, "b", "f16">,
+ WMMA_REGS<geom, "c", "f16">,
+ WMMA_REGS<geom, "c", "f32">] in {
+ def WMMA_NAME_LDST<"load", frag, layout, stride>.record
+ : NVVM_WMMA_LD<frag, layout, stride>;
+ }
+ foreach frag = [WMMA_REGS<geom, "d", "f16">,
+ WMMA_REGS<geom, "d", "f32">] in {
+ def WMMA_NAME_LDST<"store", frag, layout, stride>.record
+ : NVVM_WMMA_ST<frag, layout, stride>;
+ }
+ }
+ }
}
-defm int_nvvm_wmma: NVVM_WMMA_STD;
-
// WMMA.MMA
-class NVVM_WMMA_MMA_GABDCS<string Geometry,
- string ALayout, string BLayout,
- string DType, LLVMType d_regty,
- string CType, LLVMType c_regty,
- string Satfinite = "">
- : Intrinsic<!if(!eq(DType,"f16"),
- [d_regty, d_regty, d_regty, d_regty],
- [d_regty, d_regty, d_regty, d_regty,
- d_regty, d_regty, d_regty, d_regty]),
+class NVVM_WMMA_MMA<string ALayout, string BLayout,
+ WMMA_REGS C, WMMA_REGS D, int Satfinite>
+ : Intrinsic<D.regs,
!listconcat(
- [// A
- llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
- llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
- // B
- llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
- llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
- !if(!eq(CType,"f16"),
- [c_regty, c_regty, c_regty, c_regty],
- [c_regty, c_regty, c_regty, c_regty,
- c_regty, c_regty, c_regty, c_regty])),
+ WMMA_REGS<C.geom, "a", "f16">.regs,
+ WMMA_REGS<C.geom, "b", "f16">.regs,
+ C.regs),
[IntrNoMem],
- "llvm.nvvm.wmma."
- # Geometry
- # ".mma"
- # "." # ALayout
- # "." # BLayout
- # "." # DType
- # "." # CType
- # Satfinite> {
-}
-
-multiclass NVVM_WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout,
- string DType, LLVMType d_regty,
- string CType, LLVMType c_regty> {
- def NAME : NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
- DType, d_regty, CType, c_regty>;
- def _satfinite: NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
- DType, d_regty, CType, c_regty,".satfinite">;
-}
-
-multiclass NVVM_WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout,
- string DType, LLVMType d_regty> {
- defm _f16: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_regty,
- "f16", llvm_v2f16_ty>;
- defm _f32: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_regty,
- "f32", llvm_float_ty>;
-}
-
-multiclass NVVM_WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> {
- defm _f16: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", llvm_v2f16_ty>;
- defm _f32: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", llvm_float_ty>;
-}
-
-multiclass NVVM_WMMA_MMA_GA<string Geometry, string ALayout> {
- defm _col: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "col">;
- defm _row: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "row">;
-}
-
-multiclass NVVM_WMMA_MMA_G<string Geometry> {
- defm _col: NVVM_WMMA_MMA_GA<Geometry, "col">;
- defm _row: NVVM_WMMA_MMA_GA<Geometry, "row">;
-}
-
-multiclass NVVM_WMMA_MMA {
- defm _m32n8k16_mma : NVVM_WMMA_MMA_G<"m32n8k16">;
- defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">;
- defm _m8n32k16_mma : NVVM_WMMA_MMA_G<"m8n32k16">;
+ WMMA_NAME_MMA<ALayout, BLayout, C, D, Satfinite>.llvm>;
+
+foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
+ foreach layout_a = ["row", "col"] in {
+ foreach layout_b = ["row", "col"] in {
+ foreach frag_c = [WMMA_REGS<geom, "c", "f16">,
+ WMMA_REGS<geom, "c", "f32">] in {
+ foreach frag_d = [WMMA_REGS<geom, "d", "f16">,
+ WMMA_REGS<geom, "d", "f32">] in {
+ foreach satf = [0, 1] in {
+ def WMMA_NAME_MMA<layout_a, layout_b, frag_c, frag_d, satf>.record
+ : NVVM_WMMA_MMA<layout_a, layout_b, frag_c, frag_d, satf>;
+ }
+ }
+ }
+ }
+ }
}
-defm int_nvvm_wmma : NVVM_WMMA_MMA;
-
} // let TargetPrefix = "nvvm"
return (d==1.0);
}]>;
-
+def AS_match {
+ code generic = [{
+ return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
+ }];
+ code shared = [{
+ return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
+ }];
+ code global = [{
+ return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
+ }];
+}
//-----------------------------------
// Synchronization and shuffle functions
//-----------------------------------
class ATOMIC_GLOBAL_CHK <dag ops, dag frag>
- : PatFrag<ops, frag, [{
- return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
-}]>;
+ : PatFrag<ops, frag, AS_match.global>;
class ATOMIC_SHARED_CHK <dag ops, dag frag>
- : PatFrag<ops, frag, [{
- return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
-}]>;
+ : PatFrag<ops, frag, AS_match.shared>;
class ATOMIC_GENERIC_CHK <dag ops, dag frag>
- : PatFrag<ops, frag, [{
- return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
-}]>;
+ : PatFrag<ops, frag, AS_match.generic>;
multiclass F_ATOMIC_2_imp<NVPTXRegClass ptrclass, NVPTXRegClass regclass,
string SpaceStr, string TypeStr, string OpcStr, PatFrag IntOp,
NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;",
[(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>;
-//
-// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
-//
-
class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
+// Generates list of n sequential register names.
+class RegSeq<int n, string prefix> {
+ list<string> ret = !if(n, !listconcat(RegSeq<!add(n,-1), prefix>.ret,
+ [prefix # !add(n, -1)]),
+ []);
+}
-class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
- string Space, string Type, NVPTXRegClass regclass,
- DAGOperand SrcOp, bit WithStride>
- : EmptyNVPTXInst,
- Requires<[!if(!eq(Geometry, "m16n16k16"),
- hasPTX60,
- hasPTX61),
- hasSM70]> {
- // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
- // for this function.
- PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_"
- # Geometry # "_load_"
- # !subst("c", "c_" # Type, Abc)
- # "_" # Layout
- # !subst(".", "_", Space)
- # !if(WithStride,"_stride", "")
- # "_Intr");
- dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
- dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
- dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47));
-
- dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
- dag Ins = !con((ins SrcOp:$src), StrideArg);
+// Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
+// In addition to target-independent fields provided by WMMA_REGS, it adds
+// the fields commonly used to implement specific PTX instruction -- register
+// types and names, constraints, parts of assembly, etc.
+class WMMA_REGINFO<string Geom, string Frag, string PtxEltType>
+ : WMMA_REGS<Geom, Frag, PtxEltType> {
+ // NVPTX register types used to carry fragment data.
+ NVPTXRegClass regclass = !cond(
+ !eq(PtxEltType, "f16") : Float16x2Regs,
+ !eq(PtxEltType, "f32") : Float32Regs);
+
+ // Instruction input/output arguments for the fragment.
+ list<NVPTXRegClass> ptx_regs = !foreach(tmp, regs, regclass);
+
+ // List of register names for the fragment -- ["ra0", "ra1",...]
+ list<string> reg_names = RegSeq<!size(ptx_regs), "r"#frag>.ret;
+ // Generates "{{$r0, $r1,.... $rN-1}}" for use in asm string construction.
+ string regstring = "{{$" # !head(reg_names)
+ # !foldl("", !tail(reg_names), a, b,
+ !strconcat(a, ", $", b))
+ # "}}";
+
+ // Predicates for particular fragment variant. Technically those are
+ // per-instruction predicates, but currently all fragments that can be used in
+ // a given instruction are subject to the same constraints, so an instruction
+ // can use predicates from any of its fragments. If/when this is no
+ // longer the case, we can concat all per-fragment predicates to enforce that
+ // all fragments of the instruction are viable.
+ list<Predicate> Predicates = !cond(
+ // fp16 -> fp16/fp32 @ m16n16k16
+ !and(!eq(Geom, "m16n16k16"),
+ !or(!eq(PtxEltType, "f16"),
+ !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60],
+
+ // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
+ !and(!or(!eq(Geom, "m8n32k16"),
+ !eq(Geom, "m32n8k16")),
+ !or(!eq(PtxEltType, "f16"),
+ !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]);
+
+ // template DAGs for instruction inputs/output.
+ dag Outs = !dag(outs, ptx_regs, reg_names);
+ dag Ins = !dag(ins, ptx_regs, reg_names);
+}
+class BuildPattern<dag Outs, PatFrag IntrMatcher, dag Ins> {
// Build a dag pattern that matches the intrinsic call.
// We want a dag that looks like this:
// (set <output args>, (intrinsic <input arguments>)) where input and
!subst(ins, IntrMatcher, tmp)))));
// Finally, consatenate both parts together. !con() requires both dags to have
// the same operator, so we wrap PatArgs in a (set ...) dag.
- let Pattern = [!con(PatOuts, (set PatArgs))];
- let OutOperandList = Outs;
- let InOperandList = Ins;
- let AsmString = "wmma.load."
- # Abc
- # ".sync"
- # "." # Layout
- # "." # Geometry
- # Space
- # "." # Type # " \t"
- # !if(!eq(Abc#Type, "cf16"),
- "{{$r0, $r1, $r2, $r3}}",
- "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
- # ", [$src]"
- # !if(WithStride, ", $ldm", "")
- # ";";
+ dag ret = !con(PatOuts, (set PatArgs));
}
-class WMMA_LOAD_INTR_HELPER<string Geometry, string Abc, string Layout,
- string Space, string Type, bit WithStride>
+//
+// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
+//
+
+class WMMA_LOAD_INTR_HELPER<WMMA_REGINFO Frag, string Layout, string Space,
+ bit WithStride>
: PatFrag <(ops),(ops)> {
// Intrinsic that matches this instruction.
- Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma"
- # "_" # Geometry # "_load_"
- # Abc # "_" # Type # "_" # Layout
- # !if(WithStride,"_stride", ""));
- code match_generic = [{
- return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
- }];
- code match_shared = [{
- return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
- }];
- code match_global = [{
- return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
- }];
-
+ Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_LDST<"load", Frag, Layout,
+ WithStride>.record);
let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
- let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
- !if(!eq(Space, ".global"), match_global, match_generic));
-}
-
-multiclass WMMA_LOAD_GALSTS<string Geometry, string Abc, string Layout,
- string Space, string Type, NVPTXRegClass regclass,
- bit WithStride> {
- def _avar: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
- imem, WithStride>;
- def _areg: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
- Int32Regs, WithStride>;
- def _areg64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
- Int64Regs, WithStride>;
- def _ari: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
- MEMri, WithStride>;
- def _ari64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
- MEMri64, WithStride>;
+ let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared,
+ !eq(Space, ".global"): AS_match.global,
+ 1: AS_match.generic);
}
-multiclass WMMA_LOAD_GALSTSh<string Geometry, string Abc, string Layout,
- string Space, string Type, NVPTXRegClass regclass,
- bit WithStride> {
- // Define a PatFrag that matches appropriate intrinsic that loads from the
- // given address space.
- def _Intr: WMMA_LOAD_INTR_HELPER<Geometry, Abc, Layout, Space, Type,
- WithStride>;
- defm NAME: WMMA_LOAD_GALSTS<Geometry, Abc, Layout, Space, Type, regclass,
- WithStride>;
-}
-
-multiclass WMMA_LOAD_GALST<string Geometry, string Abc, string Layout,
- string Space, string Type, NVPTXRegClass regclass> {
- defm _stride: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 1>;
- defm NAME: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 0>;
-}
-
-multiclass WMMA_LOAD_GALT<string Geometry, string Abc, string Layout,
- string Type, NVPTXRegClass regclass> {
- defm _global: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".global",
- Type, regclass>;
- defm _shared: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".shared",
- Type, regclass>;
- defm NAME: WMMA_LOAD_GALST<Geometry, Abc, Layout, "",
- Type, regclass>;
-}
-
-multiclass WMMA_LOAD_GAT<string Geometry, string Abc,
- string Type, NVPTXRegClass regclass> {
- defm _row: WMMA_LOAD_GALT<Geometry, Abc, "row", Type, regclass>;
- defm _col: WMMA_LOAD_GALT<Geometry, Abc, "col", Type, regclass>;
-}
+class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride,
+ DAGOperand SrcOp>
+ : EmptyNVPTXInst,
+ Requires<Frag.Predicates> {
+ // Pattern that matches the intrinsic for this instruction variant.
+ PatFrag IntrMatcher = WMMA_LOAD_INTR_HELPER<Frag, Layout, Space, WithStride>;
+ dag Ins = !con((ins SrcOp:$src), !if(WithStride, (ins Int32Regs:$ldm), (ins)));
-multiclass WMMA_LOAD_G<string Geometry> {
- defm _load_a: WMMA_LOAD_GAT<Geometry, "a", "f16", Float16x2Regs>;
- defm _load_b: WMMA_LOAD_GAT<Geometry, "b", "f16", Float16x2Regs>;
- defm _load_c_f16: WMMA_LOAD_GAT<Geometry, "c", "f16", Float16x2Regs>;
- defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>;
+ let Pattern = [BuildPattern<Frag.Outs, IntrMatcher, Ins>.ret];
+ let OutOperandList = Frag.Outs;
+ let InOperandList = Ins;
+ let AsmString = "wmma.load."
+ # Frag.frag
+ # ".sync"
+ # "." # Layout
+ # "." # Frag.geom
+ # Space
+ # "." # Frag.ptx_elt_type # " \t"
+ # Frag.regstring
+ # ", [$src]"
+ # !if(WithStride, ", $ldm", "")
+ # ";";
}
-defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">;
-defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
-defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">;
-
//
// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
//
-class WMMA_STORE_D_GLSTSO<string Geometry, string Layout, string Space,
- string Type, NVPTXRegClass regclass,
- bit WithStride, DAGOperand DstOp>
+class WMMA_STORE_INTR_HELPER<WMMA_REGINFO Frag, string Layout, string Space,
+ bit WithStride>
+ : PatFrag <(ops),(ops)> {
+ // Intrinsic that matches this instruction.
+ Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_LDST<"store", Frag, Layout,
+ WithStride>.record);
+ let Operands = !con((ops node:$dst),
+ !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names),
+ !if(WithStride, (ops node:$ldm), (ops)));
+ let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
+ let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared,
+ !eq(Space, ".global"): AS_match.global,
+ 1: AS_match.generic);
+}
+
+class WMMA_STORE<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride,
+ DAGOperand DstOp>
: EmptyNVPTXInst,
- Requires<[!if(!eq(Geometry, "m16n16k16"),
- hasPTX60,
- hasPTX61),
- hasSM70]> {
- PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA"
- # "_" # Geometry # "_store_d"
- # "_" # Type
- # "_" # Layout
- # !subst(".", "_", Space)
- # !if(WithStride,"_stride", "")
- # "_Intr");
- dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1,
- regclass:$r2, regclass:$r3);
- dag InsR47 = (ins regclass:$r4, regclass:$r5,
- regclass:$r6, regclass:$r7);
- dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47));
- dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
- dag Ins = !con(InsR, StrideArg);
-
- // Construct the pattern to match corresponding intrinsic call. See the
- // details in the comments in WMMA_LOAD_ALSTOS.
- dag PatArgs = !foreach(tmp, Ins,
- !subst(imem, ADDRvar,
- !subst(MEMri64, ADDRri64,
- !subst(MEMri, ADDRri,
- !subst(ins, IntrMatcher, tmp)))));
- let Pattern = [PatArgs];
+ Requires<Frag.Predicates> {
+ PatFrag IntrMatcher = WMMA_STORE_INTR_HELPER<Frag, Layout, Space, WithStride>;
+ dag Ins = !con((ins DstOp:$src),
+ Frag.Ins,
+ !if(WithStride, (ins Int32Regs:$ldm), (ins)));
+ let Pattern = [BuildPattern<(set), IntrMatcher, Ins>.ret];
let OutOperandList = (outs);
let InOperandList = Ins;
let AsmString = "wmma.store.d.sync."
# Layout
- # "." # Geometry
+ # "." # Frag.geom
# Space
- # "." # Type
+ # "." # Frag.ptx_elt_type
# " \t[$src],"
- # !if(!eq(Type,"f16"),
- "{{$r0, $r1, $r2, $r3}}",
- "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
+ # Frag.regstring
# !if(WithStride, ", $ldm", "")
# ";";
-
}
-class WMMA_STORE_INTR_HELPER<string Geometry, string Layout, string Space,
- string Type, bit WithStride>
- : PatFrag <(ops),(ops)> {
- // Intrinsic that matches this instruction.
- Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
- # Geometry
- # "_store_d"
- # "_" # Type
- # "_" # Layout
- # !if(WithStride, "_stride", ""));
- code match_generic = [{
- return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
- }];
- code match_shared = [{
- return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
- }];
- code match_global = [{
- return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
- }];
-
- dag Args = !if(!eq(Type,"f16"),
- (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3),
- (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3,
- node:$r4, node:$r5, node:$r6, node:$r7));
- dag StrideArg = !if(WithStride, (ops node:$ldm), (ops));
- let Operands = !con(Args, StrideArg);
- let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
- let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
- !if(!eq(Space, ".global"), match_global, match_generic));
-}
-
-multiclass WMMA_STORE_D_GLSTS<string Geometry, string Layout, string Space,
- string Type, NVPTXRegClass regclass,
- bit WithStride> {
- def _avar: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
- WithStride, imem>;
- def _areg: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
- WithStride, Int32Regs>;
- def _areg64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
- WithStride, Int64Regs>;
- def _ari: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
- WithStride, MEMri>;
- def _ari64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
- WithStride, MEMri64>;
-}
-
-multiclass WMMA_STORE_D_GLSTSh<string Geometry, string Layout, string Space,
- string Type, NVPTXRegClass regclass,
- bit WithStride> {
- // Define a PatFrag that matches appropriate intrinsic that loads from the
- // given address space.
- def _Intr: WMMA_STORE_INTR_HELPER<Geometry, Layout, Space, Type,
- WithStride>;
- defm NAME: WMMA_STORE_D_GLSTS<Geometry, Layout, Space, Type, regclass,
- WithStride>;
-}
-
-multiclass WMMA_STORE_D_GLST<string Geometry, string Layout, string Space,
- string Type, NVPTXRegClass regclass > {
- defm _stride: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 1>;
- defm NAME: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 0>;
-}
-
-multiclass WMMA_STORE_D_GLT<string Geometry, string Layout,
- string Type, NVPTXRegClass regclass> {
- defm _global: WMMA_STORE_D_GLST<Geometry, Layout, ".global", Type, regclass>;
- defm _shared: WMMA_STORE_D_GLST<Geometry, Layout, ".shared", Type, regclass>;
- defm NAME: WMMA_STORE_D_GLST<Geometry, Layout, "", Type, regclass>;
-}
-
-multiclass WMMA_STORE_D_GT<string Geometry, string Type,
- NVPTXRegClass regclass> {
- defm _row: WMMA_STORE_D_GLT<Geometry, "row", Type, regclass>;
- defm _col: WMMA_STORE_D_GLT<Geometry, "col", Type, regclass>;
-}
-
-multiclass WMMA_STORE_D_G<string Geometry> {
- defm _store_d_f16: WMMA_STORE_D_GT<Geometry, "f16", Float16x2Regs>;
- defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>;
-}
-
-defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">;
-defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
-defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">;
+// Create all load/store variants
+foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
+ foreach layout = ["row", "col"] in {
+ foreach stride = [0, 1] in {
+ foreach space = [".global", ".shared", ""] in {
+ foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
+ foreach frag = [WMMA_REGINFO<geom, "a", "f16">,
+ WMMA_REGINFO<geom, "b", "f16">,
+ WMMA_REGINFO<geom, "c", "f16">,
+ WMMA_REGINFO<geom, "c", "f32">] in {
+ def : WMMA_LOAD<frag, layout, space, stride, addr>;
+ }
+ foreach frag = [WMMA_REGINFO<geom, "d", "f16">,
+ WMMA_REGINFO<geom, "d", "f32">] in {
+ def : WMMA_STORE<frag, layout, space, stride, addr>;
+ }
+ } // addr
+ } // space
+ } // stride
+ } // layout
+} // geom
// WMMA.MMA
-class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
- string DType, NVPTXRegClass d_reg,
- string CType, NVPTXRegClass c_reg,
- NVPTXRegClass ab_reg,
- string Satfinite = "">
+class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string ALayout, string BLayout, int Satfinite>
: EmptyNVPTXInst,
- Requires<[!if(!eq(Geometry, "m16n16k16"),
- hasPTX60,
- hasPTX61),
- hasSM70]> {
- Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
- # Geometry
- # "_mma"
- # "_" # ALayout
- # "_" # BLayout
- # "_" # DType
- # "_" # CType
- # !subst(".", "_", Satfinite));
- dag Outs = !if(!eq(DType,"f16"),
- (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
- (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
- d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7));
- dag InsExtraCArgs = !if(!eq(CType,"f16"),
- (ins),
- (ins c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7));
- dag Ins = !con((ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
- ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
- ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
- ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
- c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3),
- InsExtraCArgs);
-
- // Construct the pattern to match corresponding intrinsic call. See the
- // details in the comments in WMMA_LOAD_ALSTOS.
+ Requires<FragC.Predicates> {
+ //Intrinsic Intr = int_nvvm_suld_1d_v4i32_zero;
+ Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_MMA<ALayout, BLayout, FragC, FragD, Satfinite>.record);
+ dag Outs = FragD.Outs;
+ dag Ins = !con(FragA.Ins,
+ FragB.Ins,
+ FragC.Ins);
+
+ // Construct the pattern to match corresponding intrinsic call.
+ // mma does not load/store anything, so we don't need complex operand matching here.
dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp));
dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp));
let Pattern = [!con(PatOuts, (set PatArgs))];
let AsmString = "wmma.mma.sync."
# ALayout
# "." # BLayout
- # "." # Geometry
- # "." # DType
- # "." # CType
- # Satfinite # "\n\t\t"
- # !if(!eq(DType,"f16"),
- "{{$d0, $d1, $d2, $d3}}, \n\t\t",
- "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
- # "{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
- # "{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
- # !if(!eq(CType,"f16"),
- "{{$c0, $c1, $c2, $c3}};",
- "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};");
-}
-
-multiclass WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout,
- string DType, NVPTXRegClass d_reg,
- string CType, NVPTXRegClass c_reg> {
- def _satfinite: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
- DType, d_reg, CType, c_reg,
- Float16x2Regs, ".satfinite">;
- def NAME: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
- DType, d_reg, CType, c_reg,
- Float16x2Regs>;
-}
-
-multiclass WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout,
- string DType, NVPTXRegClass d_reg> {
- defm _f16: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg,
- "f16", Float16x2Regs>;
- defm _f32: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg,
- "f32", Float32Regs>;
-}
-
-multiclass WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> {
- defm _f16: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", Float16x2Regs>;
- defm _f32: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", Float32Regs>;
-}
-
-multiclass WMMA_MMA_GA<string Geometry, string ALayout> {
- defm _col: WMMA_MMA_GAB<Geometry, ALayout, "col">;
- defm _row: WMMA_MMA_GAB<Geometry, ALayout, "row">;
-}
-
-multiclass WMMA_MMA_G<string Geometry> {
- defm _col: WMMA_MMA_GA<Geometry, "col">;
- defm _row: WMMA_MMA_GA<Geometry, "row">;
+ # "." # FragA.geom
+ # "." # FragD.ptx_elt_type
+ # "." # FragC.ptx_elt_type
+ # !if(Satfinite, ".satfinite", "") # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ";";
}
-defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">;
-defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;
-defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">;
+foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
+ foreach layout_a = ["row", "col"] in {
+ foreach layout_b = ["row", "col"] in {
+ foreach frag_c = [WMMA_REGINFO<geom, "c", "f16">,
+ WMMA_REGINFO<geom, "c", "f32">] in {
+ foreach frag_d = [WMMA_REGINFO<geom, "d", "f16">,
+ WMMA_REGINFO<geom, "d", "f32">] in {
+ foreach satf = [0, 1] in {
+ def : WMMA_MMA<WMMA_REGINFO<geom, "a", "f16">,
+ WMMA_REGINFO<geom, "b", "f16">,
+ frag_c, frag_d, layout_a, layout_b, satf>;
+ } // satf
+ } // frag_d
+ } // frag_c
+ } // layout_b
+ } // layout_a
+} // geom