[mlir][IR] Implement proper folder for `IsCommutative` trait
authorMatthias Springer <me@m-sp.org>
Thu, 20 Jul 2023 08:11:44 +0000 (10:11 +0200)
committerMatthias Springer <me@m-sp.org>
Thu, 20 Jul 2023 08:19:48 +0000 (10:19 +0200)
Commutative ops were previously folded with a special rule in `OperationFolder`. This change turns the folding into a proper `OpTrait` folder.

Differential Revision: https://reviews.llvm.org/D155687

mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/Operation.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir

index d895679..221c607 100644 (file)
@@ -314,6 +314,8 @@ namespace OpTrait {
 // corresponding trait classes.  This avoids them being template
 // instantiated/duplicated.
 namespace impl {
+LogicalResult foldCommutative(Operation *op, ArrayRef<Attribute> operands,
+                              SmallVectorImpl<OpFoldResult> &results);
 OpFoldResult foldIdempotent(Operation *op);
 OpFoldResult foldInvolution(Operation *op);
 LogicalResult verifyZeroOperands(Operation *op);
@@ -1148,7 +1150,13 @@ public:
 
 /// This class adds property that the operation is commutative.
 template <typename ConcreteType>
-class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
+class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
+public:
+  static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
+                                 SmallVectorImpl<OpFoldResult> &results) {
+    return impl::foldCommutative(op, operands, results);
+  }
+};
 
 /// This class adds property that the operation is an involution.
 /// This means a unary to unary operation "f" that satisfies f(f(x)) = x
index 449c97d..efce8d9 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
@@ -790,6 +791,24 @@ InFlightDiagnostic OpState::emitRemark(const Twine &message) {
 // Op Trait implementations
 //===----------------------------------------------------------------------===//
 
+LogicalResult
+OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands,
+                               SmallVectorImpl<OpFoldResult> &results) {
+  // Nothing to fold if there are not at least 2 operands.
+  if (op->getNumOperands() < 2)
+    return failure();
+  // Move all constant operands to the end.
+  OpOperand *operandsBegin = op->getOpOperands().begin();
+  auto isNonConstant = [&](OpOperand &o) {
+    return !static_cast<bool>(operands[std::distance(operandsBegin, &o)]);
+  };
+  auto *firstConstantIt = llvm::find_if_not(op->getOpOperands(), isNonConstant);
+  auto *newConstantIt = std::stable_partition(
+      firstConstantIt, op->getOpOperands().end(), isNonConstant);
+  // Return success if the op was modified.
+  return success(firstConstantIt != newConstantIt);
+}
+
 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
   if (op->getNumOperands() == 1) {
     auto *argumentOp = op->getOperand(0).getDefiningOp();
index e9e59cf..ad1e043 100644 (file)
@@ -217,21 +217,6 @@ LogicalResult OperationFolder::tryToFold(Operation *op,
                                          SmallVectorImpl<Value> &results) {
   SmallVector<Attribute, 8> operandConstants;
 
-  // If this is a commutative operation, move constants to be trailing operands.
-  bool updatedOpOperands = false;
-  if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
-    auto isNonConstant = [&](OpOperand &o) {
-      return !matchPattern(o.get(), m_Constant());
-    };
-    auto *firstConstantIt =
-        llvm::find_if_not(op->getOpOperands(), isNonConstant);
-    auto *newConstantIt = std::stable_partition(
-        firstConstantIt, op->getOpOperands().end(), isNonConstant);
-
-    // Remember if we actually moved anything.
-    updatedOpOperands = firstConstantIt != newConstantIt;
-  }
-
   // Check to see if any operands to the operation is constant and whether
   // the operation knows how to constant fold itself.
   operandConstants.assign(op->getNumOperands(), Attribute());
@@ -244,7 +229,7 @@ LogicalResult OperationFolder::tryToFold(Operation *op,
   SmallVector<OpFoldResult, 8> foldResults;
   if (failed(op->fold(operandConstants, foldResults)) ||
       failed(processFoldResults(op, results, foldResults)))
-    return success(updatedOpOperands);
+    return failure();
   return success();
 }
 
index 710f704..1ed6770 100644 (file)
@@ -127,7 +127,7 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
@@ -159,7 +159,7 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes
 
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
index f11abe4..238c0c5 100644 (file)
@@ -27,7 +27,7 @@ func.func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> t
 // CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
 // CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
-// CHECK:           %[[OUT_DIM2:.*]] = arith.addi %[[C2]], %[[OFFSET]] : index
+// CHECK:           %[[OUT_DIM2:.*]] = arith.addi %[[OFFSET]], %[[C2]] : index
 // CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
 // CHECK:           %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
 // CHECK:           %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index