public:
AffineOpsDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "affine"; }
+
+ /// Materialize a single constant operation from a given attribute value with
+ /// the desired resultant type.
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+ Location loc) override;
};
/// The "affine.apply" operation applies an affine map to a list of operands,
public:
StandardOpsDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "std"; }
+
+ /// Materialize a single constant operation from a given attribute value with
+ /// the desired resultant type.
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+ Location loc) override;
};
/// The predicate indicates the type of the comparison to perform:
public:
VectorOpsDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "vector"; }
+
+ /// Materialize a single constant operation from a given attribute value with
+ /// the desired resultant type.
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+ Location loc) override;
};
/// Collect a set of vector-to-vector canonicalization patterns.
template <typename OpTy, typename... Args>
void createOrFold(SmallVectorImpl<Value *> &results, Location location,
Args &&... args) {
- auto op = create<OpTy>(location, std::forward<Args>(args)...);
- tryFold(op.getOperation(), results);
+ // Create the operation without using 'createOperation' as we don't want to
+ // insert it yet.
+ OperationState state(location, OpTy::getOperationName());
+ OpTy::build(this, state, std::forward<Args>(args)...);
+ Operation *op = Operation::create(state);
+
+ // Fold the operation. If successful destroy it, otherwise insert it.
+ if (succeeded(tryFold(op, results)))
+ op->destroy();
+ else
+ insert(op);
}
/// Overload to create or fold a single result operation.
return op;
}
+ /// Attempts to fold the given operation and places new results within
+ /// 'results'. Returns success if the operation was folded, failure otherwise.
+ /// Note: This function does not erase the operation on a successful fold.
+ LogicalResult tryFold(Operation *op, SmallVectorImpl<Value *> &results);
+
/// Creates a deep copy of the specified operation, remapping any operands
/// that use values outside of the operation using the map that is provided
/// ( leaving them alone if no entry is present). Replaces references to
}
private:
- /// Attempts to fold the given operation and places new results within
- /// 'results'.
- void tryFold(Operation *op, SmallVectorImpl<Value *> &results);
-
Block *block = nullptr;
Block::iterator insertPoint;
};
addInterfaces<AffineInlinerInterface, AffineSideEffectsInterface>();
}
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *AffineOpsDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ return builder.create<ConstantOp>(loc, type, value);
+}
+
/// A utility function to check if a given region is attached to a function.
static bool isFunctionRegion(Region *region) {
return llvm::isa<FuncOp>(region->getParentOp());
addInterfaces<StdInlinerInterface>();
}
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ return builder.create<ConstantOp>(loc, type, value);
+}
+
void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end,
unsigned numDims, OpAsmPrinter &p) {
// VectorOpsDialect
//===----------------------------------------------------------------------===//
-mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
+VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
>();
}
+/// Materialize a single constant operation from a given attribute value with
+/// the desired resultant type.
+Operation *VectorOpsDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ return builder.create<ConstantOp>(loc, type, value);
+}
+
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/Functional.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
Builder::Builder(ModuleOp module) : context(module.getContext()) {}
}
/// Attempts to fold the given operation and places new results within
-/// 'results'.
-void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
+/// 'results'. Returns success if the operation was folded, failure otherwise.
+/// Note: This function does not erase the operation on a successful fold.
+LogicalResult OpBuilder::tryFold(Operation *op,
+ SmallVectorImpl<Value *> &results) {
results.reserve(op->getNumResults());
- SmallVector<OpFoldResult, 4> foldResults;
-
- // Returns if the given fold result corresponds to a valid existing value.
- auto isValidValue = [](OpFoldResult result) {
- return result.dyn_cast<Value *>();
+ auto cleanupFailure = [&] {
+ results.assign(op->result_begin(), op->result_end());
+ return failure();
};
- // Check if the fold failed, or did not result in only existing values.
+ // If this operation is already a constant, there is nothing to do.
+ Attribute unused;
+ if (matchPattern(op, m_Constant(&unused)))
+ return cleanupFailure();
+
+ // Check to see if any operands to the operation is constant and whether
+ // the operation knows how to constant fold itself.
SmallVector<Attribute, 4> constOperands(op->getNumOperands());
- if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() ||
- !llvm::all_of(foldResults, isValidValue)) {
- // Simply return the existing operation results.
- results.assign(op->result_begin(), op->result_end());
- return;
+ for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+ matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
+
+ // Try to fold the operation.
+ SmallVector<OpFoldResult, 4> foldResults;
+ if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
+ return cleanupFailure();
+
+ // A temporary builder used for creating constants during folding.
+ OpBuilder cstBuilder(context);
+ SmallVector<Operation *, 1> generatedConstants;
+
+ // Populate the results with the folded results.
+ Dialect *dialect = op->getDialect();
+ for (auto &it : llvm::enumerate(foldResults)) {
+ // Normal values get pushed back directly.
+ if (auto *value = it.value().dyn_cast<Value *>()) {
+ results.push_back(value);
+ continue;
+ }
+
+ // Otherwise, try to materialize a constant operation.
+ if (!dialect)
+ return cleanupFailure();
+
+ // Ask the dialect to materialize a constant operation for this value.
+ Attribute attr = it.value().get<Attribute>();
+ auto *constOp = dialect->materializeConstant(
+ cstBuilder, attr, op->getResult(it.index())->getType(), op->getLoc());
+ if (!constOp) {
+ // Erase any generated constants.
+ for (Operation *cst : generatedConstants)
+ cst->erase();
+ return cleanupFailure();
+ }
+ assert(matchPattern(constOp, m_Constant(&attr)));
+
+ generatedConstants.push_back(constOp);
+ results.push_back(constOp->getResult(0));
}
- // Populate the results with the folded results and remove the original op.
- llvm::transform(foldResults, std::back_inserter(results),
- [](OpFoldResult result) { return result.get<Value *>(); });
- op->erase();
+ // If we were successful, insert any generated constants.
+ for (Operation *cst : generatedConstants)
+ insert(cst);
+
+ return success();
}
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::detail;
ConversionTarget &getTarget() { return target; }
private:
+ /// Attempt to legalize the given operation by folding it.
+ LogicalResult legalizeWithFold(Operation *op,
+ ConversionPatternRewriter &rewriter);
+
/// Attempt to legalize the given operation by applying the provided pattern.
/// Returns success if the operation was legalized, failure otherwise.
LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
return success();
}
+ // If the operation isn't legal, try to fold it in-place.
+ // TODO(riverriddle) Should we always try to do this, even if the op is
+ // already legal?
+ if (succeeded(legalizeWithFold(op, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation was folded\n");
+ return success();
+ }
+
// Otherwise, we need to apply a legalization pattern to this operation.
auto it = legalizerPatterns.find(op->getName());
if (it == legalizerPatterns.end()) {
}
LogicalResult
+OperationLegalizer::legalizeWithFold(Operation *op,
+ ConversionPatternRewriter &rewriter) {
+ auto &rewriterImpl = rewriter.getImpl();
+ RewriterState curState = rewriterImpl.getCurrentState();
+
+ // Try to fold the operation.
+ SmallVector<Value *, 2> replacementValues;
+ rewriter.setInsertionPoint(op);
+ if (failed(rewriter.tryFold(op, replacementValues)))
+ return failure();
+
+ // Insert a replacement for 'op' with the folded replacement values.
+ rewriter.replaceOp(op, replacementValues);
+
+ // Recursively legalize any new constant operations.
+ for (unsigned i = curState.numCreatedOperations,
+ e = rewriterImpl.createdOps.size();
+ i != e; ++i) {
+ Operation *cstOp = rewriterImpl.createdOps[i];
+ if (failed(legalize(cstOp, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated folding constant '"
+ << cstOp->getName() << "' was illegal.\n");
+ rewriterImpl.resetState(curState);
+ return failure();
+ }
+ }
+ return success();
+}
+
+LogicalResult
OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
ConversionPatternRewriter &rewriter) {
LLVM_DEBUG({
return
}
-// CHECK-LABEL: func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !llvm<"<4 x i64>">) -> !llvm<"<4 x float>"> {
-func @vector_ops(vector<4xf32>, vector<4xi1>, vector<4xi64>) -> vector<4xf32> {
-^bb0(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>):
+// CHECK-LABEL: func @vector_ops(%arg0: !llvm<"<4 x float>">, %arg1: !llvm<"<4 x i1>">, %arg2: !llvm<"<4 x i64>">, %arg3: !llvm<"<4 x i64>">) -> !llvm<"<4 x float>"> {
+func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4xi64>, %arg3: vector<4xi64>) -> vector<4xf32> {
// CHECK-NEXT: %0 = llvm.mlir.constant(dense<4.200000e+01> : vector<4xf32>) : !llvm<"<4 x float>">
%0 = constant dense<42.> : vector<4xf32>
// CHECK-NEXT: %1 = llvm.fadd %arg0, %0 : !llvm<"<4 x float>">
%7 = divf %arg0, %0 : vector<4xf32>
// CHECK-NEXT: %7 = llvm.frem %arg0, %0 : !llvm<"<4 x float>">
%8 = remf %arg0, %0 : vector<4xf32>
-// CHECK-NEXT: %8 = llvm.and %arg2, %arg2 : !llvm<"<4 x i64>">
- %9 = and %arg2, %arg2 : vector<4xi64>
-// CHECK-NEXT: %9 = llvm.or %arg2, %arg2 : !llvm<"<4 x i64>">
- %10 = or %arg2, %arg2 : vector<4xi64>
-// CHECK-NEXT: %10 = llvm.xor %arg2, %arg2 : !llvm<"<4 x i64>">
- %11 = xor %arg2, %arg2 : vector<4xi64>
+// CHECK-NEXT: %8 = llvm.and %arg2, %arg3 : !llvm<"<4 x i64>">
+ %9 = and %arg2, %arg3 : vector<4xi64>
+// CHECK-NEXT: %9 = llvm.or %arg2, %arg3 : !llvm<"<4 x i64>">
+ %10 = or %arg2, %arg3 : vector<4xi64>
+// CHECK-NEXT: %10 = llvm.xor %arg2, %arg3 : !llvm<"<4 x i64>">
+ %11 = xor %arg2, %arg3 : vector<4xi64>
return %1 : vector<4xf32>
}
}
// CHECK-LABEL: @dfs_block_order
-func @dfs_block_order() -> (i32) {
-// CHECK-NEXT: %0 = llvm.mlir.constant(42 : i32) : !llvm.i32
+func @dfs_block_order(%arg0: i32) -> (i32) {
+// CHECK-NEXT: %[[CST:.*]] = llvm.mlir.constant(42 : i32) : !llvm.i32
%0 = constant 42 : i32
// CHECK-NEXT: llvm.br ^bb2
br ^bb2
// CHECK-NEXT: ^bb1:
-// CHECK-NEXT: %1 = llvm.add %0, %2 : !llvm.i32
-// CHECK-NEXT: llvm.return %1 : !llvm.i32
+// CHECK-NEXT: %[[ADD:.*]] = llvm.add %arg0, %[[CST]] : !llvm.i32
+// CHECK-NEXT: llvm.return %[[ADD]] : !llvm.i32
^bb1:
- %2 = addi %0, %1 : i32
+ %2 = addi %arg0, %0 : i32
return %2 : i32
// CHECK-NEXT: ^bb2:
^bb2:
-// CHECK-NEXT: %2 = llvm.mlir.constant(55 : i32) : !llvm.i32
- %1 = constant 55 : i32
// CHECK-NEXT: llvm.br ^bb1
br ^bb1
}
#map6 = (d0,d1,d2) -> (d0 + d1 + d2)
// CHECK-LABEL: func @affine_applies(
-func @affine_applies() {
-^bb0:
+func @affine_applies(%arg0 : index) {
// CHECK: %[[c0:.*]] = constant 0 : index
%zero = affine.apply #map0()
// CHECK-NEXT: %[[v1:.*]] = addi %[[v0]], %[[c1]] : index
%one = affine.apply #map3(%symbZero)[%zero]
-// CHECK-NEXT: %[[c103:.*]] = constant 103 : index
-// CHECK-NEXT: %[[c104:.*]] = constant 104 : index
-// CHECK-NEXT: %[[c105:.*]] = constant 105 : index
-// CHECK-NEXT: %[[c106:.*]] = constant 106 : index
-// CHECK-NEXT: %[[c107:.*]] = constant 107 : index
-// CHECK-NEXT: %[[c108:.*]] = constant 108 : index
-// CHECK-NEXT: %[[c109:.*]] = constant 109 : index
- %103 = constant 103 : index
- %104 = constant 104 : index
- %105 = constant 105 : index
- %106 = constant 106 : index
- %107 = constant 107 : index
- %108 = constant 108 : index
- %109 = constant 109 : index
// CHECK-NEXT: %[[c2:.*]] = constant 2 : index
-// CHECK-NEXT: %[[v2:.*]] = muli %[[c104]], %[[c2]] : index
-// CHECK-NEXT: %[[v3:.*]] = addi %[[c103]], %[[v2]] : index
+// CHECK-NEXT: %[[v2:.*]] = muli %arg0, %[[c2]] : index
+// CHECK-NEXT: %[[v3:.*]] = addi %arg0, %[[v2]] : index
// CHECK-NEXT: %[[c3:.*]] = constant 3 : index
-// CHECK-NEXT: %[[v4:.*]] = muli %[[c105]], %[[c3]] : index
+// CHECK-NEXT: %[[v4:.*]] = muli %arg0, %[[c3]] : index
// CHECK-NEXT: %[[v5:.*]] = addi %[[v3]], %[[v4]] : index
// CHECK-NEXT: %[[c4:.*]] = constant 4 : index
-// CHECK-NEXT: %[[v6:.*]] = muli %[[c106]], %[[c4]] : index
+// CHECK-NEXT: %[[v6:.*]] = muli %arg0, %[[c4]] : index
// CHECK-NEXT: %[[v7:.*]] = addi %[[v5]], %[[v6]] : index
// CHECK-NEXT: %[[c5:.*]] = constant 5 : index
-// CHECK-NEXT: %[[v8:.*]] = muli %[[c107]], %[[c5]] : index
+// CHECK-NEXT: %[[v8:.*]] = muli %arg0, %[[c5]] : index
// CHECK-NEXT: %[[v9:.*]] = addi %[[v7]], %[[v8]] : index
// CHECK-NEXT: %[[c6:.*]] = constant 6 : index
-// CHECK-NEXT: %[[v10:.*]] = muli %[[c108]], %[[c6]] : index
+// CHECK-NEXT: %[[v10:.*]] = muli %arg0, %[[c6]] : index
// CHECK-NEXT: %[[v11:.*]] = addi %[[v9]], %[[v10]] : index
// CHECK-NEXT: %[[c7:.*]] = constant 7 : index
-// CHECK-NEXT: %[[v12:.*]] = muli %[[c109]], %[[c7]] : index
+// CHECK-NEXT: %[[v12:.*]] = muli %arg0, %[[c7]] : index
// CHECK-NEXT: %[[v13:.*]] = addi %[[v11]], %[[v12]] : index
- %four = affine.apply #map4(%103,%104,%105,%106)[%107,%108,%109]
+ %four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
return
}
return %repl_2 : i8
}
+// CHECK-LABEL: func @remove_foldable_op
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: i32)
+func @remove_foldable_op(%arg0 : i32) -> (i32) {
+ // CHECK-NEXT: return %[[ARG_0]]
+ %0 = "test.op_with_region_fold"(%arg0) ({
+ "foo.op_with_region_terminator"() : () -> ()
+ }) : (i32) -> (i32)
+ return %0 : i32
+}
+
// -----
func @fail_to_convert_illegal_op() -> i32 {
ConversionTarget target(getContext());
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
- target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
+ target
+ .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
// Don't allow F32 operands.
return llvm::none_of(op.getOperandTypes(),