[Clang][NVPTX]Add NVPTX intrinsics and builtins for CUDA PTX cvt sm80 instructions
authorJack Kirk <jack.kirk@codeplay.com>
Thu, 13 Jan 2022 20:01:20 +0000 (12:01 -0800)
committerArtem Belevich <tra@google.com>
Thu, 13 Jan 2022 21:29:48 +0000 (13:29 -0800)
Adds NVPTX intrinsics and builtins for CUDA PTX cvt instructions for sm80
architectures and above. Requires ptx 7.0.

PTX ISA description of cvt instructions :
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt

Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
Differential Revision: https://reviews.llvm.org/D116673

clang/include/clang/Basic/BuiltinsNVPTX.def
clang/test/CodeGen/builtins-nvptx.c
llvm/include/llvm/IR/IntrinsicsNVVM.td
llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
llvm/lib/Target/NVPTX/NVPTX.h
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/test/CodeGen/NVPTX/convert-sm80.ll [new file with mode: 0644]

index 025fef0..6b94dd8 100644 (file)
@@ -402,6 +402,23 @@ BUILTIN(__nvvm_ull2d_rp, "dULLi", "")
 BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "")
 BUILTIN(__nvvm_f2h_rn, "Usf", "")
 
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
+
 // Bitcast
 
 BUILTIN(__nvvm_bitcast_f2i, "if", "")
index ec0f742..1e31aaa 100644 (file)
@@ -754,4 +754,40 @@ __device__ void nvvm_async_copy(__attribute__((address_space(3))) void* dst, __a
   __nvvm_cp_async_wait_all();
   #endif
   // CHECK: ret void
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: nvvm_cvt_sm80
+__device__ void nvvm_cvt_sm80() {
+#if __CUDA_ARCH__ >= 800
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rn(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rn_relu(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rz(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rz_relu(1, 1);
+
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rn(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rn_relu(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rz(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rz_relu(1, 1);
+
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
+  __nvvm_f2bf16_rn(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
+  __nvvm_f2bf16_rn_relu(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
+  __nvvm_f2bf16_rz(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
+  __nvvm_f2bf16_rz_relu(1);
+
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
+  __nvvm_f2tf32_rna(1);
+#endif
+  // CHECK: ret void
+}
index 6f55d1e..41b28db 100644 (file)
@@ -1185,6 +1185,36 @@ let TargetPrefix = "nvvm" in {
   def int_nvvm_f2h_rn : GCCBuiltin<"__nvvm_f2h_rn">,
       DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
 
+  def int_nvvm_ff2bf16x2_rn : GCCBuiltin<"__nvvm_ff2bf16x2_rn">,
+       Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2bf16x2_rn_relu : GCCBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2bf16x2_rz : GCCBuiltin<"__nvvm_ff2bf16x2_rz">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2bf16x2_rz_relu : GCCBuiltin<"__nvvm_ff2bf16x2_rz_relu">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+
+  def int_nvvm_ff2f16x2_rn : GCCBuiltin<"__nvvm_ff2f16x2_rn">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2f16x2_rn_relu : GCCBuiltin<"__nvvm_ff2f16x2_rn_relu">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2f16x2_rz : GCCBuiltin<"__nvvm_ff2f16x2_rz">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2f16x2_rz_relu : GCCBuiltin<"__nvvm_ff2f16x2_rz_relu">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+
+  def int_nvvm_f2bf16_rn : GCCBuiltin<"__nvvm_f2bf16_rn">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_f2bf16_rn_relu : GCCBuiltin<"__nvvm_f2bf16_rn_relu">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_f2bf16_rz : GCCBuiltin<"__nvvm_f2bf16_rz">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_f2bf16_rz_relu : GCCBuiltin<"__nvvm_f2bf16_rz_relu">,
+       Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+
+  def int_nvvm_f2tf32_rna : GCCBuiltin<"__nvvm_f2tf32_rna">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem]>;
+
 //
 // Bitcast
 //
index 82d332a..da0cbb3 100644 (file)
@@ -108,6 +108,10 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
     // SAT flag
     if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
       O << ".sat";
+  } else if (strcmp(Modifier, "relu") == 0) {
+    // RELU flag
+    if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
+      O << ".relu";
   } else if (strcmp(Modifier, "base") == 0) {
     // Default operand
     switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
@@ -139,6 +143,9 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
     case NVPTX::PTXCvtMode::RP:
       O << ".rp";
       break;
+    case NVPTX::PTXCvtMode::RNA:
+      O << ".rna";
+      break;
     }
   } else {
     llvm_unreachable("Invalid conversion modifier");
index c2fd090..41e9f37 100644 (file)
@@ -137,10 +137,12 @@ enum CvtMode {
   RZ,
   RM,
   RP,
+  RNA,
 
   BASE_MASK = 0x0F,
   FTZ_FLAG = 0x10,
-  SAT_FLAG = 0x20
+  SAT_FLAG = 0x20,
+  RELU_FLAG = 0x40
 };
 }
 
index 360731c..22e200e 100644 (file)
@@ -48,6 +48,7 @@ def CvtRN   : PatLeaf<(i32 0x5)>;
 def CvtRZ   : PatLeaf<(i32 0x6)>;
 def CvtRM   : PatLeaf<(i32 0x7)>;
 def CvtRP   : PatLeaf<(i32 0x8)>;
+def CvtRNA   : PatLeaf<(i32 0x9)>;
 
 def CvtNONE_FTZ : PatLeaf<(i32 0x10)>;
 def CvtRNI_FTZ  : PatLeaf<(i32 0x11)>;
@@ -62,6 +63,10 @@ def CvtRP_FTZ   : PatLeaf<(i32 0x18)>;
 def CvtSAT      : PatLeaf<(i32 0x20)>;
 def CvtSAT_FTZ  : PatLeaf<(i32 0x30)>;
 
+def CvtNONE_RELU   : PatLeaf<(i32 0x40)>;
+def CvtRN_RELU   : PatLeaf<(i32 0x45)>;
+def CvtRZ_RELU   : PatLeaf<(i32 0x46)>;
+
 def CvtMode : Operand<i32> {
   let PrintMethod = "printCvtMode";
 }
@@ -526,6 +531,29 @@ let hasSideEffects = false in {
                                     "cvt.s64.s16 \t$dst, $src;", []>;
   def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
                                     "cvt.s64.s32 \t$dst, $src;", []>;
+
+multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
+    def _f32 :
+      NVPTXInst<(outs RC:$dst),
+                (ins Float32Regs:$src, CvtMode:$mode),
+                !strconcat("cvt${mode:base}${mode:relu}.",
+                FromName, ".f32 \t$dst, $src;"), []>,
+                Requires<[hasPTX70, hasSM80]>;
+  }
+
+  defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>;
+
+    multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
+    def _f32 :
+      NVPTXInst<(outs RC:$dst),
+                (ins Float32Regs:$src1, Float32Regs:$src2,  CvtMode:$mode),
+                !strconcat("cvt${mode:base}${mode:relu}.",
+                FromName, ".f32 \t$dst, $src1, $src2;"), []>,
+    Requires<[hasPTX70, hasSM80]>;
+  }
+
+  defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Float16x2Regs>;
+  defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
 }
 
 //-----------------------------------
index 511cd87..ec069a0 100644 (file)
@@ -1046,6 +1046,38 @@ def : Pat<(int_nvvm_ui2f_rm Int32Regs:$a),
 def : Pat<(int_nvvm_ui2f_rp Int32Regs:$a),
           (CVT_f32_u32 Int32Regs:$a, CvtRP)>;
 
+def : Pat<(int_nvvm_ff2bf16x2_rn Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff2bf16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+def : Pat<(int_nvvm_ff2bf16x2_rz Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
+def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
+
+def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
+def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
+
+def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>;
+def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>;
+def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>;
+
+def CVT_tf32_f32 :
+   NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
+                   "cvt.rna.tf32.f32 \t$dest, $a;",
+       [(set Int32Regs:$dest, (int_nvvm_f2tf32_rna Float32Regs:$a))]>;
+
 def INT_NVVM_LOHI_I2D : F_MATH_2<"mov.b64 \t$dst, {{$src0, $src1}};",
   Float64Regs, Int32Regs, Int32Regs, int_nvvm_lohi_i2d>;
 
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
new file mode 100644 (file)
index 0000000..81893fd
--- /dev/null
@@ -0,0 +1,136 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
+
+
+; CHECK-LABEL: cvt_rn_bf16x2_f32
+define i32 @cvt_rn_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.bf16x2.f32
+  %val = call i32 @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2);
+
+ret i32 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_bf16x2_f32
+define i32 @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.relu.bf16x2.f32
+%val = call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2);
+
+ret i32 %val
+}
+
+; CHECK-LABEL: cvt_rz_bf16x2_f32
+define i32 @cvt_rz_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.bf16x2.f32
+  %val = call i32 @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2);
+
+ret i32 %val
+}
+
+; CHECK-LABEL: cvt_rz_relu_bf16x2_f32
+define i32 @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.relu.bf16x2.f32
+%val = call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2);
+
+ret i32 %val
+}
+
+declare i32 @llvm.nvvm.ff2bf16x2.rn(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rn.relu(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rz(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rz.relu(float, float)
+
+; CHECK-LABEL: cvt_rn_f16x2_f32
+define <2 x half> @cvt_rn_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.f16x2.f32
+  %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_f16x2_f32
+define <2 x half> @cvt_rn_relu_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.relu.f16x2.f32
+%val = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rz_f16x2_f32
+define <2 x half> @cvt_rz_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.f16x2.f32
+  %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rz_relu_f16x2_f32
+define <2 x half> @cvt_rz_relu_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.relu.f16x2.f32
+%val = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+declare <2 x half> @llvm.nvvm.ff2f16x2.rn(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rz(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float, float)
+
+; CHECK-LABEL: cvt_rn_bf16_f32
+define i16 @cvt_rn_bf16_f32(float %f1) {
+
+; CHECK: cvt.rn.bf16.f32
+  %val = call i16 @llvm.nvvm.f2bf16.rn(float %f1);
+
+ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_bf16_f32
+define i16 @cvt_rn_relu_bf16_f32(float %f1) {
+
+; CHECK: cvt.rn.relu.bf16.f32
+%val = call i16 @llvm.nvvm.f2bf16.rn.relu(float %f1);
+
+ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rz_bf16_f32
+define i16 @cvt_rz_bf16_f32(float %f1) {
+
+; CHECK: cvt.rz.bf16.f32
+  %val = call i16 @llvm.nvvm.f2bf16.rz(float %f1);
+
+ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rz_relu_bf16_f32
+define i16 @cvt_rz_relu_bf16_f32(float %f1) {
+
+; CHECK: cvt.rz.relu.bf16.f32
+%val = call i16 @llvm.nvvm.f2bf16.rz.relu(float %f1);
+
+ret i16 %val
+}
+
+declare i16 @llvm.nvvm.f2bf16.rn(float)
+declare i16 @llvm.nvvm.f2bf16.rn.relu(float)
+declare i16 @llvm.nvvm.f2bf16.rz(float)
+declare i16 @llvm.nvvm.f2bf16.rz.relu(float)
+
+; CHECK-LABEL: cvt_rna_tf32_f32
+define i32 @cvt_rna_tf32_f32(float %f1) {
+
+; CHECK: cvt.rna.tf32.f32
+  %val = call i32 @llvm.nvvm.f2tf32.rna(float %f1);
+
+ret i32 %val
+}
+
+declare i32 @llvm.nvvm.f2tf32.rna(float)