ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
-/// Pattern to fold a generic op with a splat constant.
-class FoldSplatConstants : public OpRewritePattern<GenericOp> {
+/// Pattern to fold a generic op with a splat constant or a scalar constant.
+/// Constants that are not single-valued are not handled.
+class FoldConstants : public OpRewritePattern<GenericOp> {
public:
- FoldSplatConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
- PatternBenefit benefit = 1)
+ FoldConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
+ PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
return failure();
for (OpOperand *opOperand : genericOp.getInputOperands()) {
Operation *def = opOperand->get().getDefiningOp();
- DenseElementsAttr constantAttr;
- if (!def ||
- !matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
- !constantAttr.isSplat() || !controlFn(def->getResult(0), *opOperand))
+ Attribute constantAttr;
+ auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
+ {
+ DenseElementsAttr splatAttr;
+ if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
+ splatAttr.isSplat() &&
+ splatAttr.getType().getElementType().isIntOrFloat()) {
+ constantAttr = splatAttr.getSplatValue();
+ return true;
+ }
+ }
+ {
+ IntegerAttr intAttr;
+ if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
+ constantAttr = intAttr;
+ return true;
+ }
+ }
+ {
+ FloatAttr floatAttr;
+ if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
+ constantAttr = floatAttr;
+ return true;
+ }
+ }
+ return false;
+ };
+
+ auto resultValue = opOperand->get().dyn_cast<OpResult>();
+ if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
+ !controlFn(resultValue, *opOperand))
continue;
// The operands and the indexing_maps of the fused operation the same as
// Create a constant scalar value from the splat constant.
Value scalarConstant = rewriter.create<ConstantOp>(
- def->getLoc(), constantAttr.getSplatValue(),
- constantAttr.getType().getElementType());
+ def->getLoc(), constantAttr, constantAttr.getType());
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
auto fusedOp = rewriter.create<GenericOp>(
void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
auto *context = patterns.getContext();
- patterns.add<FuseElementwiseOps, FoldSplatConstants>(
+ patterns.add<FuseElementwiseOps, FoldConstants>(
context, options.controlElementwiseOpsFusionFn);
patterns.add<RemoveOutsDependency>(context);
populateFoldReshapeOpsByExpansionPatterns(patterns,
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
+
+// -----
+
+// Regression test for scalar-constant fusion: %cst (f32) and %c42 (i32) feed
+// the generic op through rank-0 indexing maps `(d0, d1) -> ()`. The fusion
+// pattern should inline their values into the region so the fused generic op
+// reads only the tensor input (see the CHECK lines below).
+func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi32>) {
+  %cst = constant 4.0 : f32
+  %c42 = constant 42 : i32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %0 = linalg.init_tensor[%d0, %d1] : tensor<?x?xf32>
+  %1 = linalg.init_tensor[%d0, %d1] : tensor<?x?xi32>
+  // Two scalar (broadcast) inputs alongside the tensor input; two results to
+  // exercise both a float and an integer constant in one op.
+  %2:2 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> ()>,
+                     affine_map<(d0, d1) -> ()>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %cst, %c42 : tensor<?x?xf32>, f32, i32)
+    outs(%0, %1 : tensor<?x?xf32>, tensor<?x?xi32>) {
+    ^bb0(%arg1 : f32, %arg2 : f32, %arg3 : i32, %arg4 : f32, %arg5 : i32) :
+      // %arg2/%arg3 are the scalar operands; after fusion they become
+      // constants materialized inside the region.
+      %3 = addf %arg1, %arg2 : f32
+      linalg.yield %3, %arg3 : f32, i32
+  } -> (tensor<?x?xf32>, tensor<?x?xi32>)
+  return %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xi32>
+}
+// CHECK-LABEL: func @fuse_scalar_constant
+// CHECK-DAG: %[[CST:.+]] = constant 4.000000e+00 : f32
+// CHECK-DAG: %[[C42:.+]] = constant 42 : i32
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%{{.+}} : tensor<?x?xf32>)
+// CHECK: %[[YIELD:.+]] = addf %{{.+}}, %[[CST]] : f32
+// CHECK: linalg.yield %[[YIELD]], %[[C42]] : f32, i32