Commutative ops were previously folded with a special rule in `OperationFolder`. This change turns the folding into a proper `OpTrait` folder.
Differential Revision: https://reviews.llvm.org/D155687
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
+LogicalResult foldCommutative(Operation *op, ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results);
OpFoldResult foldIdempotent(Operation *op);
OpFoldResult foldInvolution(Operation *op);
LogicalResult verifyZeroOperands(Operation *op);
/// This class adds the property that the operation is commutative.
template <typename ConcreteType>
-class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
+class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
+public:
+  // Trait fold hook: delegate to the out-of-line impl::foldCommutative so the
+  // folding logic is not template-instantiated once per ConcreteType.
+  static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
+                                 SmallVectorImpl<OpFoldResult> &results) {
+    return impl::foldCommutative(op, operands, results);
+  }
+};
/// This class adds property that the operation is an involution.
/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
// Op Trait implementations
//===----------------------------------------------------------------------===//
+/// Trait folder for commutative ops: canonicalizes operand order by moving all
+/// constant operands (operands with a non-null entry in `operands`) after the
+/// non-constant ones. std::stable_partition preserves the relative order
+/// within each group. No fold results are produced; returning success with
+/// `results` empty signals an in-place update of the op.
+LogicalResult
+OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands,
+                               SmallVectorImpl<OpFoldResult> &results) {
+  // Nothing to fold if there are not at least 2 operands.
+  if (op->getNumOperands() < 2)
+    return failure();
+  // Move all constant operands to the end.
+  OpOperand *operandsBegin = op->getOpOperands().begin();
+  // An operand is constant iff its positional entry in `operands` is non-null;
+  // std::distance maps the OpOperand back to that position.
+  auto isNonConstant = [&](OpOperand &o) {
+    return !static_cast<bool>(operands[std::distance(operandsBegin, &o)]);
+  };
+  // Skip the leading run of non-constant operands — it is already in place —
+  // and partition only from the first constant onward.
+  auto *firstConstantIt = llvm::find_if_not(op->getOpOperands(), isNonConstant);
+  auto *newConstantIt = std::stable_partition(
+      firstConstantIt, op->getOpOperands().end(), isNonConstant);
+  // Return success if the op was modified.
+  return success(firstConstantIt != newConstantIt);
+}
+
OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
if (op->getNumOperands() == 1) {
auto *argumentOp = op->getOperand(0).getDefiningOp();
SmallVectorImpl<Value> &results) {
SmallVector<Attribute, 8> operandConstants;
- // If this is a commutative operation, move constants to be trailing operands.
- bool updatedOpOperands = false;
- if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
- auto isNonConstant = [&](OpOperand &o) {
- return !matchPattern(o.get(), m_Constant());
- };
- auto *firstConstantIt =
- llvm::find_if_not(op->getOpOperands(), isNonConstant);
- auto *newConstantIt = std::stable_partition(
- firstConstantIt, op->getOpOperands().end(), isNonConstant);
-
- // Remember if we actually moved anything.
- updatedOpOperands = firstConstantIt != newConstantIt;
- }
-
// Check to see if any operands to the operation is constant and whether
// the operation knows how to constant fold itself.
operandConstants.assign(op->getNumOperands(), Attribute());
SmallVector<OpFoldResult, 8> foldResults;
if (failed(op->fold(operandConstants, foldResults)) ||
failed(processFoldResults(op, results, foldResults)))
- return success(updatedOpOperands);
+ return failure();
return success();
}
// CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
// CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
- // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+ // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
// CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
// CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
// CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
// CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
- // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+ // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
// CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
// CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
// CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[OUT_DIM2:.*]] = arith.addi %[[C2]], %[[OFFSET]] : index
+// CHECK: %[[OUT_DIM2:.*]] = arith.addi %[[OFFSET]], %[[C2]] : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
// CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index