[MLIR] Add `and`, `or`, `xor`, `min`, `max` too gpu.all_reduce and the nvvm lowering
authorValentin Clement <clementval@gmail.com>
Wed, 11 Mar 2020 12:56:31 +0000 (13:56 +0100)
committerStephan Herhut <herhut@google.com>
Wed, 11 Mar 2020 13:07:04 +0000 (14:07 +0100)
Summary:
This patch add some builtin operation for the gpu.all_reduce ops.
- for Integer only: `and`, `or`, `xor`
- for Float and Integer: `min`, `max`

This is useful for higher level dialect like OpenACC or OpenMP that can lower to the GPU dialect.

Differential Revision: https://reviews.llvm.org/D75766

13 files changed:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/include/mlir/ExecutionEngine/RunnerUtils.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/lib/ExecutionEngine/RunnerUtils.cpp
mlir/test/Dialect/GPU/all-reduce-max.mlir [new file with mode: 0644]
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/mlir-cuda-runner/all-reduce-and.mlir [new file with mode: 0644]
mlir/test/mlir-cuda-runner/all-reduce-max.mlir [new file with mode: 0644]
mlir/test/mlir-cuda-runner/all-reduce-min.mlir [new file with mode: 0644]
mlir/test/mlir-cuda-runner/all-reduce-or.mlir [new file with mode: 0644]
mlir/test/mlir-cuda-runner/all-reduce-xor.mlir [new file with mode: 0644]

index a979ac9..d557ad5 100644 (file)
@@ -482,15 +482,25 @@ def GPU_YieldOp : GPU_Op<"yield", [Terminator]>,
   }];
 }
 
-// These mirror the XLA ComparisonDirection enum.
+// add, mul mirror the XLA ComparisonDirection enum.
 def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">;
+def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">;
+def GPU_AllReduceOpMax : StrEnumAttrCase<"max">;
+def GPU_AllReduceOpMin : StrEnumAttrCase<"min">;
 def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">;
+def GPU_AllReduceOpOr : StrEnumAttrCase<"or">;
+def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">;
 
 def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr",
     "built-in reduction operations supported by gpu.allreduce.",
     [
       GPU_AllReduceOpAdd,
+      GPU_AllReduceOpAnd,
+      GPU_AllReduceOpMax,
+      GPU_AllReduceOpMin,
       GPU_AllReduceOpMul,
+      GPU_AllReduceOpOr,
+      GPU_AllReduceOpXor
     ]>;
 
 def GPU_AllReduceOp : GPU_Op<"all_reduce",
@@ -514,8 +524,8 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
     ```
     compute the sum of each work item's %0 value. The first version specifies
     the accumulation as operation, whereas the second version specifies the
-    accumulation as code region. The accumulation operation must either be
-    `add` or `mul`.
+    accumulation as code region. The accumulation operation must be one of:
+    `add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
 
     Either none or all work items of a workgroup need to execute this op
     in convergence.
index 16dc54f..5f239a4 100644 (file)
@@ -211,6 +211,8 @@ _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M);
 extern "C" MLIR_RUNNERUTILS_EXPORT void
 _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);
 
+extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_i32(int64_t rank,
+                                                         void *ptr);
 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_f32(int64_t rank,
                                                          void *ptr);
 
index b14968a..7af62df 100644 (file)
@@ -123,18 +123,51 @@ private:
       return isFloatingPoint ? getFactory<LLVM::FMulOp>()
                              : getFactory<LLVM::MulOp>();
     }
+    if (opName == "and") {
+      return getFactory<LLVM::AndOp>();
+    }
+    if (opName == "or") {
+      return getFactory<LLVM::OrOp>();
+    }
+    if (opName == "xor") {
+      return getFactory<LLVM::XOrOp>();
+    }
+    if (opName == "max") {
+      return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
+                                             LLVM::FCmpPredicate::ugt>()
+                             : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
+                                             LLVM::ICmpPredicate::ugt>();
+    }
+    if (opName == "min") {
+      return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
+                                             LLVM::FCmpPredicate::ult>()
+                             : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
+                                             LLVM::ICmpPredicate::ult>();
+    }
 
     return AccumulatorFactory();
   }
 
   /// Returns an accumulator factory that creates an op of type T.
-  template <typename T> AccumulatorFactory getFactory() const {
+  template <typename T>
+  AccumulatorFactory getFactory() const {
     return [](Location loc, Value lhs, Value rhs,
               ConversionPatternRewriter &rewriter) {
       return rewriter.create<T>(loc, lhs.getType(), lhs, rhs);
     };
   }
 
+  /// Returns an accumulator for comparaison such as min, max. T is the type
+  /// of the compare op.
+  template <typename T, typename PredicateEnum, PredicateEnum predicate>
+  AccumulatorFactory getCmpFactory() const {
+    return [](Location loc, Value lhs, Value rhs,
+              ConversionPatternRewriter &rewriter) {
+      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
+      return rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+    };
+  }
+
   /// Creates an all_reduce across the block.
   ///
   /// First reduce the elements within a warp. The first thread of each warp
@@ -705,9 +738,9 @@ void mlir::populateGpuToNVVMConversionPatterns(
               GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
               GPUReturnOpLowering>(converter);
   patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf",
-                                               "__nv_fabs");
+                                                "__nv_fabs");
   patterns.insert<OpToFuncCallLowering<CeilFOp>>(converter, "__nv_ceilf",
-                                               "__nv_ceil");
+                                                 "__nv_ceil");
   patterns.insert<OpToFuncCallLowering<CosOp>>(converter, "__nv_cosf",
                                                "__nv_cos");
   patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
index bd4083a..547c06d 100644 (file)
@@ -148,6 +148,14 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
     }
     if (yieldCount == 0)
       return allReduce.emitError("expected gpu.yield op in region");
+  } else {
+    StringRef opName = *allReduce.op();
+    if ((opName == "and" || opName == "or" || opName == "xor") &&
+        !allReduce.getType().isa<IntegerType>()) {
+      return allReduce.emitError()
+             << '`' << opName << '`'
+             << " accumulator is only compatible with Integer type";
+    }
   }
   return success();
 }
index 2982daa..ae5c981 100644 (file)
@@ -212,6 +212,25 @@ private:
       return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
     if (opName == "mul")
       return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
+    if (opName == "and") {
+      return getFactory<AndOp>();
+    }
+    if (opName == "or") {
+      return getFactory<OrOp>();
+    }
+    if (opName == "xor") {
+      return getFactory<XOrOp>();
+    }
+    if (opName == "max") {
+      return isFloatingPoint
+                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
+                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
+    }
+    if (opName == "min") {
+      return isFloatingPoint
+                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
+                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
+    }
     return AccumulatorFactory();
   }
 
@@ -222,6 +241,16 @@ private:
     };
   }
 
+  /// Returns an accumulator for comparaison such as min, max. T is the type
+  /// of the compare op.
+  template <typename T, typename PredicateEnum, PredicateEnum predicate>
+  AccumulatorFactory getCmpFactory() const {
+    return [&](Value lhs, Value rhs) {
+      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
+      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
+    };
+  }
+
   /// Creates an if-block skeleton and calls the two factories to generate the
   /// ops in the `then` and `else` block..
   ///
index 3f345e0..af75b10 100644 (file)
@@ -27,7 +27,7 @@ extern "C" void _mlir_ciface_print_memref_vector_4x4xf32(
 
 extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M) {
   printUnrankedMemRefMetaData(std::cout, *M);
-  int rank = M->rank;
+  int64_t rank = M->rank;
   void *ptr = M->descriptor;
 
   switch (rank) {
@@ -41,9 +41,25 @@ extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M) {
   }
 }
 
+extern "C" void _mlir_ciface_print_memref_i32(UnrankedMemRefType<int32_t> *M) {
+  printUnrankedMemRefMetaData(std::cout, *M);
+  int64_t rank = M->rank;
+  void *ptr = M->descriptor;
+
+  switch (rank) {
+    MEMREF_CASE(int32_t, 0);
+    MEMREF_CASE(int32_t, 1);
+    MEMREF_CASE(int32_t, 2);
+    MEMREF_CASE(int32_t, 3);
+    MEMREF_CASE(int32_t, 4);
+  default:
+    assert(0 && "Unsupported rank to print");
+  }
+}
+
 extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M) {
   printUnrankedMemRefMetaData(std::cout, *M);
-  int rank = M->rank;
+  int64_t rank = M->rank;
   void *ptr = M->descriptor;
 
   switch (rank) {
@@ -57,10 +73,13 @@ extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M) {
   }
 }
 
+extern "C" void print_memref_i32(int64_t rank, void *ptr) {
+  UnrankedMemRefType<int32_t> descriptor = {rank, ptr};
+  _mlir_ciface_print_memref_i32(&descriptor);
+}
+
 extern "C" void print_memref_f32(int64_t rank, void *ptr) {
-  UnrankedMemRefType<float> descriptor;
-  descriptor.rank = rank;
-  descriptor.descriptor = ptr;
+  UnrankedMemRefType<float> descriptor = {rank, ptr};
   _mlir_ciface_print_memref_f32(&descriptor);
 }
 
diff --git a/mlir/test/Dialect/GPU/all-reduce-max.mlir b/mlir/test/Dialect/GPU/all-reduce-max.mlir
new file mode 100644 (file)
index 0000000..ffd2447
--- /dev/null
@@ -0,0 +1,203 @@
+// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK: module @kernels attributes {gpu.kernel_module} {
+module @kernels attributes {gpu.kernel_module} {
+
+  // CHECK-LABEL: gpu.func @kernel(
+  // CHECK-SAME: [[VAL_0:%.*]]: f32) workgroup([[VAL_1:%.*]] : memref<32xf32, 3>) kernel {
+  gpu.func @kernel(%arg0 : f32) attributes { gpu.kernel } {
+    // CHECK:   [[VAL_2:%.*]] = constant 31 : i32
+    // CHECK:   [[VAL_3:%.*]] = constant 0 : i32
+    // CHECK:   [[VAL_4:%.*]] = constant 0 : index
+    // CHECK:   [[VAL_5:%.*]] = constant 32 : i32
+    // CHECK:   [[VAL_6:%.*]] = constant 1 : i32
+    // CHECK:   [[VAL_7:%.*]] = constant 2 : i32
+    // CHECK:   [[VAL_8:%.*]] = constant 4 : i32
+    // CHECK:   [[VAL_9:%.*]] = constant 8 : i32
+    // CHECK:   [[VAL_10:%.*]] = constant 16 : i32
+    // CHECK:   [[VAL_11:%.*]] = "gpu.block_dim"() {dimension = "x"} : () -> index
+    // CHECK:   [[VAL_12:%.*]] = index_cast [[VAL_11]] : index to i32
+    // CHECK:   [[VAL_13:%.*]] = "gpu.block_dim"() {dimension = "y"} : () -> index
+    // CHECK:   [[VAL_14:%.*]] = index_cast [[VAL_13]] : index to i32
+    // CHECK:   [[VAL_15:%.*]] = "gpu.block_dim"() {dimension = "z"} : () -> index
+    // CHECK:   [[VAL_16:%.*]] = index_cast [[VAL_15]] : index to i32
+    // CHECK:   [[VAL_17:%.*]] = "gpu.thread_id"() {dimension = "x"} : () -> index
+    // CHECK:   [[VAL_18:%.*]] = index_cast [[VAL_17]] : index to i32
+    // CHECK:   [[VAL_19:%.*]] = "gpu.thread_id"() {dimension = "y"} : () -> index
+    // CHECK:   [[VAL_20:%.*]] = index_cast [[VAL_19]] : index to i32
+    // CHECK:   [[VAL_21:%.*]] = "gpu.thread_id"() {dimension = "z"} : () -> index
+    // CHECK:   [[VAL_22:%.*]] = index_cast [[VAL_21]] : index to i32
+    // CHECK:   [[VAL_23:%.*]] = muli [[VAL_22]], [[VAL_14]] : i32
+    // CHECK:   [[VAL_24:%.*]] = addi [[VAL_23]], [[VAL_20]] : i32
+    // CHECK:   [[VAL_25:%.*]] = muli [[VAL_24]], [[VAL_12]] : i32
+    // CHECK:   [[VAL_26:%.*]] = muli [[VAL_12]], [[VAL_14]] : i32
+    // CHECK:   [[VAL_27:%.*]] = addi [[VAL_25]], [[VAL_18]] : i32
+    // CHECK:   [[VAL_28:%.*]] = muli [[VAL_26]], [[VAL_16]] : i32
+    // CHECK:   [[VAL_29:%.*]] = and [[VAL_27]], [[VAL_2]] : i32
+    // CHECK:   [[VAL_30:%.*]] = cmpi "eq", [[VAL_29]], [[VAL_3]] : i32
+    // CHECK:   [[VAL_31:%.*]] = subi [[VAL_27]], [[VAL_29]] : i32
+    // CHECK:   [[VAL_32:%.*]] = subi [[VAL_28]], [[VAL_31]] : i32
+    // CHECK:   [[VAL_33:%.*]] = cmpi "slt", [[VAL_32]], [[VAL_5]] : i32
+    // CHECK:   cond_br [[VAL_33]], ^bb1, ^bb17
+    // CHECK: ^bb1:
+    // CHECK:   [[VAL_34:%.*]], [[VAL_35:%.*]] = gpu.shuffle [[VAL_0]], [[VAL_6]], [[VAL_32]] xor : f32
+    // CHECK:   cond_br [[VAL_35]], ^bb2, ^bb3
+    // CHECK: ^bb2:
+    // CHECK:   [[VAL_36:%.*]] = cmpf "ugt", [[VAL_0]], [[VAL_34]] : f32
+    // CHECK:   [[VAL_37:%.*]] = select [[VAL_36]], [[VAL_0]], [[VAL_34]] : f32
+    // CHECK:   br ^bb4([[VAL_37]] : f32)
+    // CHECK: ^bb3:
+    // CHECK:   br ^bb4([[VAL_0]] : f32)
+    // CHECK: ^bb4([[VAL_38:%.*]]: f32):
+    // CHECK:   [[VAL_39:%.*]], [[VAL_40:%.*]] = gpu.shuffle [[VAL_38]], [[VAL_7]], [[VAL_32]] xor : f32
+    // CHECK:   cond_br [[VAL_40]], ^bb5, ^bb6
+    // CHECK: ^bb5:
+    // CHECK:   [[VAL_41:%.*]] = cmpf "ugt", [[VAL_38]], [[VAL_39]] : f32
+    // CHECK:   [[VAL_42:%.*]] = select [[VAL_41]], [[VAL_38]], [[VAL_39]] : f32
+    // CHECK:   br ^bb7([[VAL_42]] : f32)
+    // CHECK: ^bb6:
+    // CHECK:   br ^bb7([[VAL_38]] : f32)
+    // CHECK: ^bb7([[VAL_43:%.*]]: f32):
+    // CHECK:   [[VAL_44:%.*]], [[VAL_45:%.*]] = gpu.shuffle [[VAL_43]], [[VAL_8]], [[VAL_32]] xor : f32
+    // CHECK:   cond_br [[VAL_45]], ^bb8, ^bb9
+    // CHECK: ^bb8:
+    // CHECK:   [[VAL_46:%.*]] = cmpf "ugt", [[VAL_43]], [[VAL_44]] : f32
+    // CHECK:   [[VAL_47:%.*]] = select [[VAL_46]], [[VAL_43]], [[VAL_44]] : f32
+    // CHECK:   br ^bb10([[VAL_47]] : f32)
+    // CHECK: ^bb9:
+    // CHECK:   br ^bb10([[VAL_43]] : f32)
+    // CHECK: ^bb10([[VAL_48:%.*]]: f32):
+    // CHECK:   [[VAL_49:%.*]], [[VAL_50:%.*]] = gpu.shuffle [[VAL_48]], [[VAL_9]], [[VAL_32]] xor : f32
+    // CHECK:   cond_br [[VAL_50]], ^bb11, ^bb12
+    // CHECK: ^bb11:
+    // CHECK:   [[VAL_51:%.*]] = cmpf "ugt", [[VAL_48]], [[VAL_49]] : f32
+    // CHECK:   [[VAL_52:%.*]] = select [[VAL_51]], [[VAL_48]], [[VAL_49]] : f32
+    // CHECK:   br ^bb13([[VAL_52]] : f32)
+    // CHECK: ^bb12:
+    // CHECK:   br ^bb13([[VAL_48]] : f32)
+    // CHECK: ^bb13([[VAL_53:%.*]]: f32):
+    // CHECK:   [[VAL_54:%.*]], [[VAL_55:%.*]] = gpu.shuffle [[VAL_53]], [[VAL_10]], [[VAL_32]] xor : f32
+    // CHECK:   cond_br [[VAL_55]], ^bb14, ^bb15
+    // CHECK: ^bb14:
+    // CHECK:   [[VAL_56:%.*]] = cmpf "ugt", [[VAL_53]], [[VAL_54]] : f32
+    // CHECK:   [[VAL_57:%.*]] = select [[VAL_56]], [[VAL_53]], [[VAL_54]] : f32
+    // CHECK:   br ^bb16([[VAL_57]] : f32)
+    // CHECK: ^bb15:
+    // CHECK:   br ^bb16([[VAL_53]] : f32)
+    // CHECK: ^bb16([[VAL_58:%.*]]: f32):
+    // CHECK:   br ^bb18([[VAL_58]] : f32)
+    // CHECK: ^bb17:
+    // CHECK:   [[VAL_59:%.*]], [[VAL_60:%.*]] = gpu.shuffle [[VAL_0]], [[VAL_6]], [[VAL_5]] xor : f32
+    // CHECK:   [[VAL_61:%.*]] = cmpf "ugt", [[VAL_0]], [[VAL_59]] : f32
+    // CHECK:   [[VAL_62:%.*]] = select [[VAL_61]], [[VAL_0]], [[VAL_59]] : f32
+    // CHECK:   [[VAL_63:%.*]], [[VAL_64:%.*]] = gpu.shuffle [[VAL_62]], [[VAL_7]], [[VAL_5]] xor : f32
+    // CHECK:   [[VAL_65:%.*]] = cmpf "ugt", [[VAL_62]], [[VAL_63]] : f32
+    // CHECK:   [[VAL_66:%.*]] = select [[VAL_65]], [[VAL_62]], [[VAL_63]] : f32
+    // CHECK:   [[VAL_67:%.*]], [[VAL_68:%.*]] = gpu.shuffle [[VAL_66]], [[VAL_8]], [[VAL_5]] xor : f32
+    // CHECK:   [[VAL_69:%.*]] = cmpf "ugt", [[VAL_66]], [[VAL_67]] : f32
+    // CHECK:   [[VAL_70:%.*]] = select [[VAL_69]], [[VAL_66]], [[VAL_67]] : f32
+    // CHECK:   [[VAL_71:%.*]], [[VAL_72:%.*]] = gpu.shuffle [[VAL_70]], [[VAL_9]], [[VAL_5]] xor : f32
+    // CHECK:   [[VAL_73:%.*]] = cmpf "ugt", [[VAL_70]], [[VAL_71]] : f32
+    // CHECK:   [[VAL_74:%.*]] = select [[VAL_73]], [[VAL_70]], [[VAL_71]] : f32
+    // CHECK:   [[VAL_75:%.*]], [[VAL_76:%.*]] = gpu.shuffle [[VAL_74]], [[VAL_10]], [[VAL_5]] xor : f32
+    // CHECK:   [[VAL_77:%.*]] = cmpf "ugt", [[VAL_74]], [[VAL_75]] : f32
+    // CHECK:   [[VAL_78:%.*]] = select [[VAL_77]], [[VAL_74]], [[VAL_75]] : f32
+    // CHECK:   br ^bb18([[VAL_78]] : f32)
+    // CHECK: ^bb18([[VAL_79:%.*]]: f32):
+    // CHECK:   cond_br [[VAL_30]], ^bb19, ^bb20
+    // CHECK: ^bb19:
+    // CHECK:   [[VAL_80:%.*]] = divi_signed [[VAL_27]], [[VAL_5]] : i32
+    // CHECK:   [[VAL_81:%.*]] = index_cast [[VAL_80]] : i32 to index
+    // CHECK:   store [[VAL_79]], [[VAL_1]]{{\[}}[[VAL_81]]] : memref<32xf32, 3>
+    // CHECK:   br ^bb21
+    // CHECK: ^bb20:
+    // CHECK:   br ^bb21
+    // CHECK: ^bb21:
+    // CHECK:   gpu.barrier
+    // CHECK:   [[VAL_82:%.*]] = addi [[VAL_28]], [[VAL_2]] : i32
+    // CHECK:   [[VAL_83:%.*]] = divi_signed [[VAL_82]], [[VAL_5]] : i32
+    // CHECK:   [[VAL_84:%.*]] = cmpi "slt", [[VAL_27]], [[VAL_83]] : i32
+    // CHECK:   cond_br [[VAL_84]], ^bb22, ^bb41
+    // CHECK: ^bb22:
+    // CHECK:   [[VAL_85:%.*]] = index_cast [[VAL_27]] : i32 to index
+    // CHECK:   [[VAL_86:%.*]] = load [[VAL_1]]{{\[}}[[VAL_85]]] : memref<32xf32, 3>
+    // CHECK:   [[VAL_87:%.*]] = cmpi "slt", [[VAL_83]], [[VAL_5]] : i32
+    // CHECK:   cond_br [[VAL_87]], ^bb23, ^bb39
+    // CHECK: ^bb23:
+    // CHECK:   [[VAL_88:%.*]], [[VAL_89:%.*]] = gpu.shuffle [[VAL_86]], [[VAL_6]], [[VAL_83]] xor : f32
+    // CHECK:   cond_br [[VAL_89]], ^bb24, ^bb25
+    // CHECK: ^bb24:
+    // CHECK:   [[VAL_90:%.*]] = cmpf "ugt", [[VAL_86]], [[VAL_88]] : f32
+    // CHECK:   [[VAL_91:%.*]] = select [[VAL_90]], [[VAL_86]], [[VAL_88]] : f32
+    // CHECK:   br ^bb26([[VAL_91]] : f32)
+    // CHECK: ^bb25:
+    // CHECK:   br ^bb26([[VAL_86]] : f32)
+    // CHECK: ^bb26([[VAL_92:%.*]]: f32):
+    // CHECK:   [[VAL_93:%.*]], [[VAL_94:%.*]] = gpu.shuffle [[VAL_92]], [[VAL_7]], [[VAL_83]] xor : f32
+    // CHECK:   cond_br [[VAL_94]], ^bb27, ^bb28
+    // CHECK: ^bb27:
+    // CHECK:   [[VAL_95:%.*]] = cmpf "ugt", [[VAL_92]], [[VAL_93]] : f32
+    // CHECK:   [[VAL_96:%.*]] = select [[VAL_95]], [[VAL_92]], [[VAL_93]] : f32
+    // CHECK:   br ^bb29([[VAL_96]] : f32)
+    // CHECK: ^bb28:
+    // CHECK:   br ^bb29([[VAL_92]] : f32)
+    // CHECK: ^bb29([[VAL_97:%.*]]: f32):
+    // CHECK:   [[VAL_98:%.*]], [[VAL_99:%.*]] = gpu.shuffle [[VAL_97]], [[VAL_8]], [[VAL_83]] xor : f32
+    // CHECK:   cond_br [[VAL_99]], ^bb30, ^bb31
+    // CHECK: ^bb30:
+    // CHECK:   [[VAL_100:%.*]] = cmpf "ugt", [[VAL_97]], [[VAL_98]] : f32
+    // CHECK:   [[VAL_101:%.*]] = select [[VAL_100]], [[VAL_97]], [[VAL_98]] : f32
+    // CHECK:   br ^bb32([[VAL_101]] : f32)
+    // CHECK: ^bb31:
+    // CHECK:   br ^bb32([[VAL_97]] : f32)
+    // CHECK: ^bb32([[VAL_102:%.*]]: f32):
+    // CHECK:   [[VAL_103:%.*]], [[VAL_104:%.*]] = gpu.shuffle [[VAL_102]], [[VAL_9]], [[VAL_83]] xor : f32
+    // CHECK:   cond_br [[VAL_104]], ^bb33, ^bb34
+    // CHECK: ^bb33:
+    // CHECK:   [[VAL_105:%.*]] = cmpf "ugt", [[VAL_102]], [[VAL_103]] : f32
+    // CHECK:   [[VAL_106:%.*]] = select [[VAL_105]], [[VAL_102]], [[VAL_103]] : f32
+    // CHECK:   br ^bb35([[VAL_106]] : f32)
+    // CHECK: ^bb34:
+    // CHECK:   br ^bb35([[VAL_102]] : f32)
+    // CHECK: ^bb35([[VAL_107:%.*]]: f32):
+    // CHECK:   [[VAL_108:%.*]], [[VAL_109:%.*]] = gpu.shuffle [[VAL_107]], [[VAL_10]], [[VAL_83]] xor : f32
+    // CHECK:   cond_br [[VAL_109]], ^bb36, ^bb37
+    // CHECK: ^bb36:
+    // CHECK:   [[VAL_110:%.*]] = cmpf "ugt", [[VAL_107]], [[VAL_108]] : f32
+    // CHECK:   [[VAL_111:%.*]] = select [[VAL_110]], [[VAL_107]], [[VAL_108]] : f32
+    // CHECK:   br ^bb38([[VAL_111]] : f32)
+    // CHECK: ^bb37:
+    // CHECK:   br ^bb38([[VAL_107]] : f32)
+    // CHECK: ^bb38([[VAL_112:%.*]]: f32):
+    // CHECK:   br ^bb40([[VAL_112]] : f32)
+    // CHECK: ^bb39:
+    // CHECK:   [[VAL_113:%.*]], [[VAL_114:%.*]] = gpu.shuffle [[VAL_86]], [[VAL_6]], [[VAL_5]] xor : f32
+    // CHECK:   [[VAL_115:%.*]] = cmpf "ugt", [[VAL_86]], [[VAL_113]] : f32
+    // CHECK:   [[VAL_116:%.*]] = select [[VAL_115]], [[VAL_86]], [[VAL_113]] : f32
+    // CHECK:   [[VAL_117:%.*]], [[VAL_118:%.*]] = gpu.shuffle [[VAL_116]], [[VAL_7]], [[VAL_5]] xor : f32
+    // CHECK:   [[VAL_119:%.*]] = cmpf "ugt", [[VAL_116]], [[VAL_117]] : f32
+    // CHECK:   [[VAL_120:%.*]] = select [[VAL_119]], [[VAL_116]], [[VAL_117]] : f32
+    // CHECK:   [[VAL_121:%.*]], [[VAL_122:%.*]] = gpu.shuffle [[VAL_120]], [[VAL_8]], [[VAL_5]] xor : f32
+    // CHECK:   [[VAL_123:%.*]] = cmpf "ugt", [[VAL_120]], [[VAL_121]] : f32
+    // CHECK:   [[VAL_124:%.*]] = select [[VAL_123]], [[VAL_120]], [[VAL_121]] : f32
+    // CHECK:   [[VAL_125:%.*]], [[VAL_126:%.*]] = gpu.shuffle [[VAL_124]], [[VAL_9]], [[VAL_5]] xor : f32
+    // CHECK:   [[VAL_127:%.*]] = cmpf "ugt", [[VAL_124]], [[VAL_125]] : f32
+    // CHECK:   [[VAL_128:%.*]] = select [[VAL_127]], [[VAL_124]], [[VAL_125]] : f32
+    // CHECK:   [[VAL_129:%.*]], [[VAL_130:%.*]] = gpu.shuffle [[VAL_128]], [[VAL_10]], [[VAL_5]] xor : f32
+    // CHECK:   [[VAL_131:%.*]] = cmpf "ugt", [[VAL_128]], [[VAL_129]] : f32
+    // CHECK:   [[VAL_132:%.*]] = select [[VAL_131]], [[VAL_128]], [[VAL_129]] : f32
+    // CHECK:   br ^bb40([[VAL_132]] : f32)
+    // CHECK: ^bb40([[VAL_133:%.*]]: f32):
+    // CHECK:   store [[VAL_133]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, 3>
+    // CHECK:   br ^bb42
+    // CHECK: ^bb41:
+    // CHECK:   br ^bb42
+    // CHECK: ^bb42:
+    // CHECK:   gpu.barrier
+    // CHECK:   [[VAL_134:%.*]] = load [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, 3>
+    %sum = "gpu.all_reduce"(%arg0) ({}) {op = "max"} : (f32) -> (f32)
+    gpu.return
+  }
+
+}
index 341af2d..f2a9e11 100644 (file)
@@ -255,6 +255,14 @@ func @reduce_invalid_op(%arg0 : f32) {
 
 // -----
 
+func @reduce_invalid_op_type(%arg0 : f32) {
+  // expected-error@+1 {{`and` accumulator is only compatible with Integer type}}
+  %res = "gpu.all_reduce"(%arg0) ({}) {op = "and"} : (f32) -> (f32)
+  return
+}
+
+// -----
+
 func @reduce_incorrect_region_arguments(%arg0 : f32) {
   // expected-error@+1 {{expected two region arguments}}
   %res = "gpu.all_reduce"(%arg0) ({
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
new file mode 100644 (file)
index 0000000..edf3e02
--- /dev/null
@@ -0,0 +1,60 @@
+// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+func @main() {
+  %data = alloc() : memref<2x6xi32>
+  %sum_and = alloc() : memref<2xi32>
+  %sum_or = alloc() : memref<2xi32>
+  %sum_min = alloc() : memref<2xi32>
+  %cst0 = constant 0 : i32
+  %cst1 = constant 1 : i32
+  %cst2 = constant 2 : i32
+  %cst4 = constant 4 : i32
+  %cst8 = constant 8 : i32
+  %cst16 = constant 16 : i32
+
+  %cst3 = constant 3 : i32
+  %cst6 = constant 6 : i32
+  %cst7 = constant 7 : i32
+  %cst10 = constant 10 : i32
+  %cst11 = constant 11 : i32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %c5 = constant 5 : index
+  %c6 = constant 6 : index
+    
+  store %cst0, %data[%c0, %c0] : memref<2x6xi32>
+  store %cst1, %data[%c0, %c1] : memref<2x6xi32>
+  store %cst2, %data[%c0, %c2] : memref<2x6xi32>
+  store %cst4, %data[%c0, %c3] : memref<2x6xi32>
+  store %cst8, %data[%c0, %c4] : memref<2x6xi32>
+  store %cst16, %data[%c0, %c5] : memref<2x6xi32>
+
+  store %cst2, %data[%c1, %c0] : memref<2x6xi32>
+  store %cst3, %data[%c1, %c1] : memref<2x6xi32>
+  store %cst6, %data[%c1, %c2] : memref<2x6xi32>
+  store %cst7, %data[%c1, %c3] : memref<2x6xi32>
+  store %cst10, %data[%c1, %c4] : memref<2x6xi32>
+  store %cst11, %data[%c1, %c5] : memref<2x6xi32>
+
+  // AND
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {
+    %val = load %data[%bx, %tx] : memref<2x6xi32>
+    %reduced_and = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
+    store %reduced_and, %sum_and[%bx] : memref<2xi32>
+    gpu.terminator
+  }
+
+  %ptr_and = memref_cast %sum_and : memref<2xi32> to memref<*xi32>
+  call @print_memref_i32(%ptr_and) : (memref<*xi32>) -> ()
+  // CHECK: [0, 2]
+
+  return
+}
+
+func @print_memref_i32(memref<*xi32>)
+
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
new file mode 100644 (file)
index 0000000..6ed27cc
--- /dev/null
@@ -0,0 +1,58 @@
+// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+func @main() {
+  %data = alloc() : memref<2x6xi32>
+  %sum = alloc() : memref<2xi32>
+  %cst0 = constant 0 : i32
+  %cst1 = constant 1 : i32
+  %cst2 = constant 2 : i32
+  %cst4 = constant 4 : i32
+  %cst8 = constant 8 : i32
+  %cst16 = constant 16 : i32
+
+  %cst3 = constant 3 : i32
+  %cst6 = constant 6 : i32
+  %cst7 = constant 7 : i32
+  %cst10 = constant 10 : i32
+  %cst11 = constant 11 : i32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %c5 = constant 5 : index
+  %c6 = constant 6 : index
+    
+  store %cst0, %data[%c0, %c0] : memref<2x6xi32>
+  store %cst1, %data[%c0, %c1] : memref<2x6xi32>
+  store %cst2, %data[%c0, %c2] : memref<2x6xi32>
+  store %cst4, %data[%c0, %c3] : memref<2x6xi32>
+  store %cst8, %data[%c0, %c4] : memref<2x6xi32>
+  store %cst16, %data[%c0, %c5] : memref<2x6xi32>
+
+  store %cst2, %data[%c1, %c0] : memref<2x6xi32>
+  store %cst3, %data[%c1, %c1] : memref<2x6xi32>
+  store %cst6, %data[%c1, %c2] : memref<2x6xi32>
+  store %cst7, %data[%c1, %c3] : memref<2x6xi32>
+  store %cst10, %data[%c1, %c4] : memref<2x6xi32>
+  store %cst11, %data[%c1, %c5] : memref<2x6xi32>
+
+  // MAX
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {
+    %val = load %data[%bx, %tx] : memref<2x6xi32>
+    %reduced = "gpu.all_reduce"(%val) ({}) { op = "max" } : (i32) -> (i32)
+    store %reduced, %sum[%bx] : memref<2xi32>
+    gpu.terminator
+  }
+
+  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+  // CHECK: [16, 11]
+
+  return
+}
+
+func @print_memref_i32(memref<*xi32>)
+
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
new file mode 100644 (file)
index 0000000..2165fe5
--- /dev/null
@@ -0,0 +1,58 @@
+// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+func @main() {
+  %data = alloc() : memref<2x6xi32>
+  %sum = alloc() : memref<2xi32>
+  %cst0 = constant 0 : i32
+  %cst1 = constant 1 : i32
+  %cst2 = constant 2 : i32
+  %cst4 = constant 4 : i32
+  %cst8 = constant 8 : i32
+  %cst16 = constant 16 : i32
+
+  %cst3 = constant 3 : i32
+  %cst6 = constant 6 : i32
+  %cst7 = constant 7 : i32
+  %cst10 = constant 10 : i32
+  %cst11 = constant 11 : i32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %c5 = constant 5 : index
+  %c6 = constant 6 : index
+    
+  store %cst0, %data[%c0, %c0] : memref<2x6xi32>
+  store %cst1, %data[%c0, %c1] : memref<2x6xi32>
+  store %cst2, %data[%c0, %c2] : memref<2x6xi32>
+  store %cst4, %data[%c0, %c3] : memref<2x6xi32>
+  store %cst8, %data[%c0, %c4] : memref<2x6xi32>
+  store %cst16, %data[%c0, %c5] : memref<2x6xi32>
+
+  store %cst2, %data[%c1, %c0] : memref<2x6xi32>
+  store %cst3, %data[%c1, %c1] : memref<2x6xi32>
+  store %cst6, %data[%c1, %c2] : memref<2x6xi32>
+  store %cst7, %data[%c1, %c3] : memref<2x6xi32>
+  store %cst10, %data[%c1, %c4] : memref<2x6xi32>
+  store %cst11, %data[%c1, %c5] : memref<2x6xi32>
+
+  // MIN
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {
+    %val = load %data[%bx, %tx] : memref<2x6xi32>
+    %reduced = "gpu.all_reduce"(%val) ({}) { op = "min" } : (i32) -> (i32)
+    store %reduced, %sum[%bx] : memref<2xi32>
+    gpu.terminator
+  }
+
+  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+  // CHECK: [0, 2]
+
+  return
+}
+
+func @print_memref_i32(memref<*xi32>)
+
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
new file mode 100644 (file)
index 0000000..2091c22
--- /dev/null
@@ -0,0 +1,58 @@
+// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+func @main() {
+  %data = alloc() : memref<2x6xi32>
+  %sum = alloc() : memref<2xi32>
+  %cst0 = constant 0 : i32
+  %cst1 = constant 1 : i32
+  %cst2 = constant 2 : i32
+  %cst4 = constant 4 : i32
+  %cst8 = constant 8 : i32
+  %cst16 = constant 16 : i32
+
+  %cst3 = constant 3 : i32
+  %cst6 = constant 6 : i32
+  %cst7 = constant 7 : i32
+  %cst10 = constant 10 : i32
+  %cst11 = constant 11 : i32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %c5 = constant 5 : index
+  %c6 = constant 6 : index
+    
+  store %cst0, %data[%c0, %c0] : memref<2x6xi32>
+  store %cst1, %data[%c0, %c1] : memref<2x6xi32>
+  store %cst2, %data[%c0, %c2] : memref<2x6xi32>
+  store %cst4, %data[%c0, %c3] : memref<2x6xi32>
+  store %cst8, %data[%c0, %c4] : memref<2x6xi32>
+  store %cst16, %data[%c0, %c5] : memref<2x6xi32>
+
+  store %cst2, %data[%c1, %c0] : memref<2x6xi32>
+  store %cst3, %data[%c1, %c1] : memref<2x6xi32>
+  store %cst6, %data[%c1, %c2] : memref<2x6xi32>
+  store %cst7, %data[%c1, %c3] : memref<2x6xi32>
+  store %cst10, %data[%c1, %c4] : memref<2x6xi32>
+  store %cst11, %data[%c1, %c5] : memref<2x6xi32>
+
+  // OR
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {
+    %val = load %data[%bx, %tx] : memref<2x6xi32>
+    %reduced = "gpu.all_reduce"(%val) ({}) { op = "or" } : (i32) -> (i32)
+    store %reduced, %sum[%bx] : memref<2xi32>
+    gpu.terminator
+  }
+
+  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+  // CHECK: [31, 15]
+
+  return
+}
+
+func @print_memref_i32(memref<*xi32>)
+
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
new file mode 100644 (file)
index 0000000..1531641
--- /dev/null
@@ -0,0 +1,58 @@
+// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+func @main() {
+  %data = alloc() : memref<2x6xi32>
+  %sum = alloc() : memref<2xi32>
+  %cst0 = constant 0 : i32
+  %cst1 = constant 1 : i32
+  %cst2 = constant 2 : i32
+  %cst4 = constant 4 : i32
+  %cst8 = constant 8 : i32
+  %cst16 = constant 16 : i32
+
+  %cst3 = constant 3 : i32
+  %cst6 = constant 6 : i32
+  %cst7 = constant 7 : i32
+  %cst10 = constant 10 : i32
+  %cst11 = constant 11 : i32
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %c5 = constant 5 : index
+  %c6 = constant 6 : index
+    
+  store %cst0, %data[%c0, %c0] : memref<2x6xi32>
+  store %cst1, %data[%c0, %c1] : memref<2x6xi32>
+  store %cst2, %data[%c0, %c2] : memref<2x6xi32>
+  store %cst4, %data[%c0, %c3] : memref<2x6xi32>
+  store %cst8, %data[%c0, %c4] : memref<2x6xi32>
+  store %cst16, %data[%c0, %c5] : memref<2x6xi32>
+
+  store %cst2, %data[%c1, %c0] : memref<2x6xi32>
+  store %cst3, %data[%c1, %c1] : memref<2x6xi32>
+  store %cst6, %data[%c1, %c2] : memref<2x6xi32>
+  store %cst7, %data[%c1, %c3] : memref<2x6xi32>
+  store %cst10, %data[%c1, %c4] : memref<2x6xi32>
+  store %cst11, %data[%c1, %c5] : memref<2x6xi32>
+
+  // XOR
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {
+    %val = load %data[%bx, %tx] : memref<2x6xi32>
+    %reduced = "gpu.all_reduce"(%val) ({}) { op = "xor" } : (i32) -> (i32)
+    store %reduced, %sum[%bx] : memref<2xi32>
+    gpu.terminator
+  }
+
+  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
+  // CHECK: [31, 1]
+
+  return
+}
+
+func @print_memref_i32(memref<*xi32>)
+