#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
#include <cassert>
#include <cstdint>
using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
+//===----------------------------------------------------------------------===//
+// Patterns to Commute Extension Ops
+//===----------------------------------------------------------------------===//
+
+struct ExtensionOverExtract final : OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *def = op.getVector().getDefiningOp();
+ if (!def)
+ return failure();
+
+ return TypeSwitch<Operation *, LogicalResult>(def)
+ .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+ Value newExtract = rewriter.create<vector::ExtractOp>(
+ op.getLoc(), extOp.getIn(), op.getPosition());
+ rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+ newExtract);
+ return success();
+ })
+ .Default(failure());
+ }
+};
+
+struct ExtensionOverExtractElement final
+ : OpRewritePattern<vector::ExtractElementOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractElementOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *def = op.getVector().getDefiningOp();
+ if (!def)
+ return failure();
+
+ return TypeSwitch<Operation *, LogicalResult>(def)
+ .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+ Value newExtract = rewriter.create<vector::ExtractElementOp>(
+ op.getLoc(), extOp.getIn(), op.getPosition());
+ rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+ newExtract);
+ return success();
+ })
+ .Default(failure());
+ }
+};
+
+struct ExtensionOverExtractStridedSlice final
+ : OpRewritePattern<vector::ExtractStridedSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *def = op.getVector().getDefiningOp();
+ if (!def)
+ return failure();
+
+ return TypeSwitch<Operation *, LogicalResult>(def)
+ .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+ VectorType origTy = op.getType();
+ Type inElemTy =
+ cast<VectorType>(extOp.getIn().getType()).getElementType();
+ VectorType extractTy = origTy.cloneWith(origTy.getShape(), inElemTy);
+ Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+ op.getLoc(), extractTy, extOp.getIn(), op.getOffsets(),
+ op.getSizes(), op.getStrides());
+ rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+ newExtract);
+ return success();
+ })
+ .Default(failure());
+ }
+};
+
//===----------------------------------------------------------------------===//
// Pass Definitions
//===----------------------------------------------------------------------===//
void populateArithIntNarrowingPatterns(
RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
+ // Add commute patterns with a higher benefit. This is to expose more
+ // optimization opportunities to narrowing patterns.
+ patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
+ ExtensionOverExtractStridedSlice>(patterns.getContext(),
+ PatternBenefit(2));
+
patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
}
// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
// RUN: --verify-diagnostics %s | FileCheck %s
+//===----------------------------------------------------------------------===//
+// arith.*itofp
+//===----------------------------------------------------------------------===//
+
// CHECK-LABEL: func.func @sitofp_extsi_i16
// CHECK-SAME: (%[[ARG:.+]]: i16)
// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[ARG]] : i16 to f16
%f = arith.uitofp %b : i32 to f16
return %f : f16
}
+
+//===----------------------------------------------------------------------===//
+// Commute Extension over Vector Ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @extsi_over_extract_3xi16
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract %[[ARG]][1] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT: return %[[RET]] : f16
+func.func @extsi_over_extract_3xi16(%a: vector<3xi16>) -> f16 {
+ %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extract %b[1] : vector<3xi32>
+ %f = arith.sitofp %c : i32 to f16
+ return %f : f16
+}
+
+// CHECK-LABEL: func.func @extui_over_extract_3xi16
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract %[[ARG]][1] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT: return %[[RET]] : f16
+func.func @extui_over_extract_3xi16(%a: vector<3xi16>) -> f16 {
+ %b = arith.extui %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extract %b[1] : vector<3xi32>
+ %f = arith.uitofp %c : i32 to f16
+ return %f : f16
+}
+
+// CHECK-LABEL: func.func @extsi_over_extractelement_3xi16
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.sitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT: return %[[RET]] : f16
+func.func @extsi_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 {
+ %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extractelement %b[%pos : i32] : vector<3xi32>
+ %f = arith.sitofp %c : i32 to f16
+ return %f : f16
+}
+
+// CHECK-LABEL: func.func @extui_over_extractelement_3xi16
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.uitofp %[[EXTR]] : i16 to f16
+// CHECK-NEXT: return %[[RET]] : f16
+func.func @extui_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 {
+ %b = arith.extui %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extractelement %b[%pos : i32] : vector<3xi32>
+ %f = arith.uitofp %c : i32 to f16
+ return %f : f16
+}
+
+// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_1d
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[EXTR]] : vector<2xi16> to vector<2xi32>
+// CHECK-NEXT: return %[[RET]] : vector<2xi32>
+func.func @extsi_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> {
+ %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extract_strided_slice %b
+ {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
+ return %c : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_extract_strided_slice_1d
+// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[EXTR]] : vector<2xi16> to vector<2xi32>
+// CHECK-NEXT: return %[[RET]] : vector<2xi32>
+func.func @extui_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> {
+ %b = arith.extui %a : vector<3xi16> to vector<3xi32>
+ %c = vector.extract_strided_slice %b
+ {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
+ return %c : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_2d
+// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32>
+// CHECK-NEXT: return %[[RET]] : vector<1x2xi32>
+func.func @extsi_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> {
+ %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
+ %c = vector.extract_strided_slice %b
+ {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
+ return %c : vector<1x2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_extract_strided_slice_2d
+// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>)
+// CHECK-NEXT: %[[EXTR:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32>
+// CHECK-NEXT: return %[[RET]] : vector<1x2xi32>
+func.func @extui_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> {
+ %b = arith.extui %a : vector<2x3xi16> to vector<2x3xi32>
+ %c = vector.extract_strided_slice %b
+ {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
+ return %c : vector<1x2xi32>
+}