Revert "[mlir][arith] Add expansion pattern for ext/trunc of bf16"
authorBenjamin Kramer <benny.kra@googlemail.com>
Tue, 4 Apr 2023 13:53:04 +0000 (15:53 +0200)
committerBenjamin Kramer <benny.kra@googlemail.com>
Tue, 4 Apr 2023 13:58:38 +0000 (15:58 +0200)
This reverts commit 5bff523793ee8c30c260cc77b23c61dcbb606486. The
bf16->f32 conversion is incorrect. This can't be on by default, if you
want this behavior make it a separate pass.

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
mlir/test/Dialect/Arith/expand-ops.mlir

index 6d60f8a..257a62a 100644 (file)
@@ -38,9 +38,6 @@ void populateArithWideIntEmulationPatterns(
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 
-/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
-void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
-
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
index 7a62469..8f34531 100644 (file)
@@ -10,7 +10,6 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -26,13 +25,15 @@ using namespace mlir;
 /// Create an integer or index constant.
 static Value createConst(Location loc, Type type, int value,
                          PatternRewriter &rewriter) {
-  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
-  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+
+  auto elTy = getElementTypeOrSelf(type);
+  auto constantAttr = rewriter.getIntegerAttr(elTy, value);
+
+  if (auto vecTy = llvm::dyn_cast<ShapedType>(type))
     return rewriter.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(shapedTy, attr));
-  }
+        loc, vecTy, DenseElementsAttr::get(vecTy, constantAttr));
 
-  return rewriter.create<arith::ConstantOp>(loc, attr);
+  return rewriter.create<arith::ConstantOp>(loc, constantAttr);
 }
 
 namespace {
@@ -186,73 +187,6 @@ public:
   }
 };
 
-struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(arith::ExtFOp op,
-                                PatternRewriter &rewriter) const final {
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto operand = op.getOperand();
-    Type operandTy = operand.getType();
-    Type resultTy = op.getType();
-    Type operandETy = getElementTypeOrSelf(operandTy);
-    Type resultETy = getElementTypeOrSelf(resultTy);
-
-    if (!operandETy.isBF16() || !resultETy.isF32()) {
-      return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
-    }
-
-    Type i16Ty = b.getI16Type();
-    Type i32Ty = b.getI32Type();
-    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i16Ty = shapedTy.clone(i16Ty);
-      i32Ty = shapedTy.clone(i32Ty);
-    }
-
-    Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
-    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
-
-    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
-    Value shl = b.create<arith::ShLIOp>(exti, c16);
-    Value result = b.create<arith::BitcastOp>(resultTy, shl);
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(arith::TruncFOp op,
-                                PatternRewriter &rewriter) const final {
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    auto operand = op.getOperand();
-    Type operandTy = operand.getType();
-    Type resultTy = op.getType();
-    Type operandETy = getElementTypeOrSelf(operandTy);
-    Type resultETy = getElementTypeOrSelf(resultTy);
-
-    if (!operandETy.isF32() || !resultETy.isBF16()) {
-      return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
-    }
-
-    Type i16Ty = b.getI16Type();
-    Type i32Ty = b.getI32Type();
-    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i16Ty = shapedTy.clone(i16Ty);
-      i32Ty = shapedTy.clone(i32Ty);
-    }
-
-    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
-    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
-    Value shl = b.create<arith::ShRUIOp>(bitcast, c16);
-    Value trunc = b.create<arith::TruncIOp>(i16Ty, shl);
-    Value result = b.create<arith::BitcastOp>(resultTy, trunc);
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsBase<ArithExpandOpsPass> {
   void runOnOperation() override {
@@ -270,21 +204,6 @@ struct ArithExpandOpsPass
       arith::MaxFOp,
       arith::MinFOp
     >();
-
-    target.addDynamicallyLegalOp<arith::ExtFOp>(
-      [](arith::ExtFOp op) {
-        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
-        Type outETy = getElementTypeOrSelf(op.getType());
-        return !(inETy.isBF16() && outETy.isF32());
-      });
-
-    target.addDynamicallyLegalOp<arith::TruncFOp>(
-      [](arith::TruncFOp op)  {
-        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
-        Type outETy = getElementTypeOrSelf(op.getType());
-        return !(inETy.isF32() && outETy.isBF16());
-      });
-
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -301,19 +220,12 @@ void mlir::arith::populateCeilFloorDivExpandOpsPatterns(
           patterns.getContext());
 }
 
-void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
-  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
-      patterns.getContext());
-}
-
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off
   patterns.add<
     MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
-    MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>,
-    BFloat16ExtFOpConverter,
-    BFloat16TruncFOpConverter
+    MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>
    >(patterns.getContext());
   // clang-format on
 }
index ba87e29..7b7eb40 100644 (file)
@@ -215,67 +215,3 @@ func.func @minf(%a: f32, %b: f32) -> f32 {
 // CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
 // CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
-func.func @extf_bf16(%arg0 : bf16) -> f32 {
-    %0 = arith.extf %arg0 : bf16 to f32
-    return %0 : f32
-}
-
-// CHECK-LABEL: @extf_bf16
-// CHECK-SAME: %[[ARG0:.+]]: bf16
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : bf16 to i16
-// CHECK-DAG: %[[EXT:.+]] = arith.extui %[[BITCAST]] : i16 to i32
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16
-// CHECK-DAG: %[[SHLI:.+]] = arith.shli %[[EXT]], %[[C16]]
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[SHLI]] : i32 to f32
-// CHECK: return %[[BITCAST]]
-
-// -----
-
-func.func @extf_vector_bf16(%arg0 : vector<4xbf16>) -> vector<4xf32> {
-    %0 = arith.extf %arg0 : vector<4xbf16> to vector<4xf32>
-    return %0 : vector<4xf32>
-}
-
-// CHECK-LABEL: @extf_vector_bf16
-// CHECK-SAME: %[[ARG0:.+]]: vector<4xbf16>
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : vector<4xbf16> to vector<4xi16>
-// CHECK-DAG: %[[EXT:.+]] = arith.extui %[[BITCAST]] : vector<4xi16> to vector<4xi32>
-// CHECK-DAG: %[[C16:.+]] = arith.constant dense<16>
-// CHECK-DAG: %[[SHLI:.+]] = arith.shli %[[EXT]], %[[C16]]
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[SHLI]] : vector<4xi32> to vector<4xf32>
-// CHECK: return %[[BITCAST]]
-
-// -----
-
-func.func @truncf_f32(%arg0 : f32) -> bf16 {
-    %0 = arith.truncf %arg0 : f32 to bf16
-    return %0 : bf16
-}
-
-// CHECK-LABEL: @truncf_f32
-// CHECK-SAME: %[[ARG0:.+]]: f32
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : f32 to i32
-// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[BITCAST]], %[[C16]]
-// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i16
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[TRUNC]] : i16 to bf16
-// CHECK: return %[[BITCAST]] : bf16
-
-// -----
-
-func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
-    %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xbf16>
-    return %0 : vector<4xbf16>
-}
-
-// CHECK-LABEL: @truncf_vector_f32
-// CHECK-SAME: %[[ARG0:.+]]: vector<4xf32>
-// CHECK-DAG: %[[C16:.+]] = arith.constant dense<16>
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : vector<4xf32> to vector<4xi32>
-// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[BITCAST]], %[[C16]]
-// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : vector<4xi32> to vector<4xi16>
-// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[TRUNC]] : vector<4xi16> to vector<4xbf16>
-// CHECK: return %[[BITCAST]] : vector<4xbf16>