[mlir][vector] Fix crash in vector.reduction canonicalization
authorThomas Raoux <thomasraoux@google.com>
Tue, 12 Jul 2022 22:44:39 +0000 (22:44 +0000)
committerThomas Raoux <thomasraoux@google.com>
Tue, 12 Jul 2022 23:15:30 +0000 (23:15 +0000)
since vector.reduce support accumulator in all the cases remove the
assert assuming old definition.

Differential Revision: https://reviews.llvm.org/D129602

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index 24c2ff5f636d9a20c4248c7bcd22e3b38a6bb1cd..d51c5592ee3bc52f4a40719dfa7fcb4700377973 100644 (file)
@@ -182,6 +182,11 @@ bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
 /// memory.
 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
                            VectorTransferOpInterface transferB);
+
+/// Return the result value of reducing two scalar/vector values with the
+/// corresponding arith operation.
+Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
+                         Value v1, Value v2);
 } // namespace vector
 } // namespace mlir
 
index f6b84f1e28cda6e0364a8069a1e92233333df443..b5e6bc1ae5747f55e6a0083dd244b3ed4f4e1947 100644 (file)
@@ -34,11 +34,6 @@ namespace vector {
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
-
-/// Return the result value of reducing two scalar/vector values with the
-/// corresponding arith operation.
-Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value v2);
 } // namespace vector
 
 /// Return the number of elements of basis, `0` if empty.
index f803868c2150d6f589460a05e4d636747c75b411..c50359af87b0609fadcc3a3319b866cf4d49055e 100644 (file)
@@ -501,19 +501,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
                                               reductionOp.getVector(),
                                               rewriter.getI64ArrayAttr(0));
 
-    if (Value acc = reductionOp.getAcc()) {
-      assert(reductionOp.getType().isa<FloatType>());
-      switch (reductionOp.getKind()) {
-      case CombiningKind::ADD:
-        result = rewriter.create<arith::AddFOp>(loc, result, acc);
-        break;
-      case CombiningKind::MUL:
-        result = rewriter.create<arith::MulFOp>(loc, result, acc);
-        break;
-      default:
-        assert(false && "invalid op!");
-      }
-    }
+    if (Value acc = reductionOp.getAcc())
+      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
+                                          result, acc);
 
     rewriter.replaceOp(reductionOp, result);
     return success();
@@ -5007,6 +4997,56 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
       verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
 }
 
+Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
+                                       CombiningKind kind, Value v1, Value v2) {
+  Type t1 = getElementTypeOrSelf(v1.getType());
+  Type t2 = getElementTypeOrSelf(v2.getType());
+  switch (kind) {
+  case CombiningKind::ADD:
+    if (t1.isIntOrIndex() && t2.isIntOrIndex())
+      return b.createOrFold<arith::AddIOp>(loc, v1, v2);
+    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
+      return b.createOrFold<arith::AddFOp>(loc, v1, v2);
+    llvm_unreachable("invalid value types for ADD reduction");
+  case CombiningKind::AND:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+  case CombiningKind::MAXF:
+    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+           "expected float values");
+    return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+  case CombiningKind::MINF:
+    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+           "expected float values");
+    return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+  case CombiningKind::MAXSI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+  case CombiningKind::MINSI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+  case CombiningKind::MAXUI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+  case CombiningKind::MINUI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+  case CombiningKind::MUL:
+    if (t1.isIntOrIndex() && t2.isIntOrIndex())
+      return b.createOrFold<arith::MulIOp>(loc, v1, v2);
+    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
+      return b.createOrFold<arith::MulFOp>(loc, v1, v2);
+    llvm_unreachable("invalid value types for MUL reduction");
+  case CombiningKind::OR:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+  case CombiningKind::XOR:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+  };
+  llvm_unreachable("unknown CombiningKind");
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
index 7e6d56aa622e7be625e61a98f516de374ebed768..b979033ab47167c285d5077d25780f00f3229f96 100644 (file)
@@ -43,56 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
   llvm_unreachable("Expected MemRefType or TensorType");
 }
 
-Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
-                                       CombiningKind kind, Value v1, Value v2) {
-  Type t1 = getElementTypeOrSelf(v1.getType());
-  Type t2 = getElementTypeOrSelf(v2.getType());
-  switch (kind) {
-  case CombiningKind::ADD:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold<arith::AddIOp>(loc, v1, v2);
-    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
-      return b.createOrFold<arith::AddFOp>(loc, v1, v2);
-    llvm_unreachable("invalid value types for ADD reduction");
-  case CombiningKind::AND:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::AndIOp>(loc, v1, v2);
-  case CombiningKind::MAXF:
-    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
-           "expected float values");
-    return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
-  case CombiningKind::MINF:
-    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
-           "expected float values");
-    return b.createOrFold<arith::MinFOp>(loc, v1, v2);
-  case CombiningKind::MAXSI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
-  case CombiningKind::MINSI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
-  case CombiningKind::MAXUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
-  case CombiningKind::MINUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
-  case CombiningKind::MUL:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold<arith::MulIOp>(loc, v1, v2);
-    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
-      return b.createOrFold<arith::MulFOp>(loc, v1, v2);
-    llvm_unreachable("invalid value types for MUL reduction");
-  case CombiningKind::OR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::OrIOp>(loc, v1, v2);
-  case CombiningKind::XOR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
-  };
-  llvm_unreachable("unknown CombiningKind");
-}
-
 /// Return the number of elements of basis, `0` if empty.
 int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
   if (basis.empty())
index 702670095c8d5e39957f60a449765dcbd0eba3fa..54025a626f002f41293bc60cdb340b7c71d079fe 100644 (file)
@@ -1619,6 +1619,18 @@ func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @reduce_one_element_vector_maxf
+//  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
+//       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
+//       CHECK:   %[[S:.+]] = arith.maxf %[[A]], %[[B]] : f32
+//       CHECK:   return %[[S]]
+func.func @reduce_one_element_vector_maxf(%a : vector<1xf32>, %b: f32) -> f32 {
+  %s = vector.reduction <maxf>, %a, %b : vector<1xf32> into f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @bitcast(
 //  CHECK-SAME:               %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> {
 //       CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16>