The fillOp's value needs to casted
authorPrashant Kumar <prashant@nod-labs.com>
Fri, 4 Nov 2022 17:09:22 +0000 (17:09 +0000)
committerPrashant Kumar <prashant@nod-labs.com>
Thu, 10 Nov 2022 03:43:22 +0000 (03:43 +0000)
During elementwise fusion the fillOp's value was directly
referred without casting which can create mismatching dtypes.

Reviewed By: mravishankar, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D137447

mlir/include/mlir/Dialect/Arith/Utils/Utils.h
mlir/lib/Dialect/Arith/Utils/Utils.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

index d307330..d7aa7db 100644 (file)
@@ -92,6 +92,12 @@ SmallVector<Value>
 getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                 ArrayRef<OpFoldResult> valueOrAttrVec);
 
+/// Converts a scalar value `operand` to type `toType`. If the value doesn't
+/// convert, a warning will be issued and the operand is returned as is (which
+/// will presumably yield a verification issue downstream).
+Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+                           Type toType, bool isUnsignedCast);
+
 /// Helper struct to build simple arithmetic quantities with minimal type
 /// inference support.
 struct ArithBuilder {
index 210c580..cf9fdc2 100644 (file)
@@ -80,6 +80,50 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
   return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
 }
 
+Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+                                 Type toType, bool isUnsignedCast) {
+  if (operand.getType() == toType)
+    return operand;
+  if (auto toIntType = toType.dyn_cast<IntegerType>()) {
+    // If operand is floating point, cast directly to the int type.
+    if (operand.getType().isa<FloatType>()) {
+      if (isUnsignedCast)
+        return b.create<arith::FPToUIOp>(loc, toType, operand);
+      return b.create<arith::FPToSIOp>(loc, toType, operand);
+    }
+    // Cast index operands directly to the int type.
+    if (operand.getType().isIndex())
+      return b.create<arith::IndexCastOp>(loc, toType, operand);
+    if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
+      // Either extend or truncate.
+      if (toIntType.getWidth() > fromIntType.getWidth()) {
+        if (isUnsignedCast)
+          return b.create<arith::ExtUIOp>(loc, toType, operand);
+        return b.create<arith::ExtSIOp>(loc, toType, operand);
+      }
+      if (toIntType.getWidth() < fromIntType.getWidth())
+        return b.create<arith::TruncIOp>(loc, toType, operand);
+    }
+  } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
+    // If operand is integer, cast directly to the float type.
+    // Note that it is unclear how to cast from BF16<->FP16.
+    if (operand.getType().isa<IntegerType>()) {
+      if (isUnsignedCast)
+        return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
+      return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
+    }
+    if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
+      if (toFloatType.getWidth() > fromFloatType.getWidth())
+        return b.create<arith::ExtFOp>(loc, toFloatType, operand);
+      if (toFloatType.getWidth() < fromFloatType.getWidth())
+        return b.create<arith::TruncFOp>(loc, toFloatType, operand);
+    }
+  }
+  emitWarning(loc) << "could not cast operand of type " << operand.getType()
+                   << " to " << toType;
+  return operand;
+}
+
 SmallVector<Value>
 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
index 8ce1ad0..fc62407 100644 (file)
@@ -423,48 +423,7 @@ private:
   Value cast(Type toType, Value operand, bool isUnsignedCast) {
     OpBuilder builder = getBuilder();
     auto loc = operand.getLoc();
-
-    if (operand.getType() == toType)
-      return operand;
-    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
-      // If operand is floating point, cast directly to the int type.
-      if (operand.getType().isa<FloatType>()) {
-        if (isUnsignedCast)
-          return builder.create<arith::FPToUIOp>(loc, toType, operand);
-        return builder.create<arith::FPToSIOp>(loc, toType, operand);
-      }
-      // Cast index operands directly to the int type.
-      if (operand.getType().isIndex())
-        return builder.create<arith::IndexCastOp>(loc, toType, operand);
-      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
-        // Either extend or truncate.
-        if (toIntType.getWidth() > fromIntType.getWidth()) {
-          if (isUnsignedCast)
-            return builder.create<arith::ExtUIOp>(loc, toType, operand);
-          return builder.create<arith::ExtSIOp>(loc, toType, operand);
-        }
-        if (toIntType.getWidth() < fromIntType.getWidth())
-          return builder.create<arith::TruncIOp>(loc, toType, operand);
-      }
-    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
-      // If operand is integer, cast directly to the float type.
-      // Note that it is unclear how to cast from BF16<->FP16.
-      if (operand.getType().isa<IntegerType>()) {
-        if (isUnsignedCast)
-          return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
-        return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
-      }
-      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
-        if (toFloatType.getWidth() > fromFloatType.getWidth())
-          return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
-        if (toFloatType.getWidth() < fromFloatType.getWidth())
-          return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
-      }
-    }
-
-    emitWarning(operand.getLoc()) << "could not cast operand of type "
-                                  << operand.getType() << " to " << toType;
-    return operand;
+    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
   }
 
   bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
index 6a9c4e3..e639158 100644 (file)
@@ -1744,8 +1744,14 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
       if (!fillOp)
         continue;
       fillFound = true;
+      Value fillVal = fillOp.value();
+      auto resultType =
+          fillOp.result().getType().cast<RankedTensorType>().getElementType();
+      Value convertedVal =
+          convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
+                               /*isUnsignedCast =*/false);
       payload.getArgument(opOperand->getOperandNumber())
-          .replaceAllUsesWith(fillOp.value());
+          .replaceAllUsesWith(convertedVal);
     }
     return success(fillFound);
   }
index ca142e3..0de109e 100644 (file)
@@ -1017,6 +1017,30 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @fold_fill_generic_different_dtype
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> { 
+//   CHECK-NOT: linalg.fill
+//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
+//  CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
+#map0 = affine_map<(d0) -> (d0)>
+func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
+  %1 = tensor.empty(%0) : tensor<?xf16>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
+  %3 = tensor.empty(%0) : tensor<?xf16>
+  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
+  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
+    %5 = arith.addf  %arg1, %arg2 : f16
+        linalg.yield %5 : f16
+  } -> tensor<?xf16>
+  return %4 : tensor<?xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_generic_mixedaccess
 //   CHECK-NOT: linalg.fill
 //       CHECK: %[[GENERIC_OP:.*]] = linalg.generic