#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StandardTypes.h"
+#include "toy/ShapeInferenceInterface.h"
namespace mlir {
namespace toy {
let verifier = [{ return ::verify(*this); }];
}
-def AddOp : Toy_Op<"add", [NoSideEffect]> {
+def AddOp : Toy_Op<"add",
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise addition operation";
let description = [{
The "add" operation performs element-wise addition between two tensors.
buildAddOp(b, result, lhs, rhs);
}]
>];
- let extraClassDeclaration = [{
- void inferShapes() {
- getResult()->setType(getOperand(0)->getType());
- return;
- }
- }];
}
def GenericCallOp : Toy_Op<"generic_call"> {
];
}
-def MulOp : Toy_Op<"mul", [NoSideEffect]> {
+def MulOp : Toy_Op<"mul",
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "element-wise multiplication operation";
let description = [{
The "mul" operation performs element-wise multiplication between two
buildMulOp(b, result, lhs, rhs);
}]
>];
- let extraClassDeclaration = [{
- void inferShapes() {
- auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
- auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
- auto lhsRank = lhs.getShape().size();
- auto rhsRank = rhs.getShape().size();
- if (lhsRank != rhsRank) {
- return;
- }
- SmallVector<int64_t, 2> dims;
- if (lhsRank == 1) {
- // dot product, result shape is <1>
- dims.push_back(1);
- } else {
- if (lhsRank != 2) {
- return;
- }
- dims.push_back(lhs.getShape()[0]);
- dims.push_back(rhs.getShape()[1]);
- }
- getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
- return;
- }
- }];
}
def PrintOp : Toy_Op<"print"> {
let verifier = [{ return ::verify(*this); }];
}
-def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
+def TransposeOp : Toy_Op<"transpose",
+ [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "transpose operation";
let arguments = (ins F64Tensor:$input);
buildTransposeOp(b, result, input);
}]
>];
- let extraClassDeclaration = [{
- void inferShapes() {
- SmallVector<int64_t, 2> dims;
- auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
- dims.insert(dims.end(), arrayTy.getShape().begin(),
- arrayTy.getShape().end());
- if (dims.size() == 2)
- std::swap(dims[0], dims[1]);
- getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
- return;
- }
- }];
}
#endif // TOY_OPS
--- /dev/null
+//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains the declarations of the shape inference interfaces defined
+// in ShapeInferenceInterface.td.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
+#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace toy {
+
+/// Include the auto-generated declarations.
+#include "toy/ShapeInferenceOpInterfaces.h.inc"
+
+} // end namespace toy
+} // end namespace mlir
+
+#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let methods = [
- InterfaceMethod<"Infer output shape for the current operation.",
- "void", "inferShapes", (ins), [{}]>
+ InterfaceMethod<"Infer and set the output shape for the current operation.",
+ "void", "inferShapes">
];
}
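// For reference, running `mlir-tblgen --gen-op-interface-decls` on the
// interface above produces roughly the following declaration in
// ShapeInferenceOpInterfaces.h.inc (a sketch, not the verbatim generated
// code):
//
//   class ShapeInference
//       : public OpInterface<ShapeInference,
//                            detail::ShapeInferenceInterfaceTraits> {
//   public:
//     /// Infer and set the output shape for the current operation.
//     void inferShapes();
//   };
//
// Ops that list DeclareOpInterfaceMethods<ShapeInferenceOpInterface> in their
// trait list additionally get a matching inferShapes() declaration emitted
// into their generated op class.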
#include <algorithm>
namespace {
+/// This is a simple function DCE pass that deletes all non-main functions after
+/// inlining.
+/// TODO(riverriddle) This is only necessary because MLIR currently does not
+/// have generic DCE support for functions.
class DeadFunctionEliminationPass
: public mlir::ModulePass<DeadFunctionEliminationPass> {
public:
void runOnModule() override {
- std::string str = "main";
- auto module = getModule();
- for (auto &f : module) {
- // eliminate dead functions that are not main
- if (str.find(f.getName().getStringRef()) == std::string::npos)
- f.erase();
+ mlir::ModuleOp module = getModule();
+ mlir::SymbolTable moduleSymTable(module);
+
+ // Eliminate non-main functions.
+ auto mainFn = moduleSymTable.lookup<mlir::FuncOp>("main");
+ for (mlir::FuncOp func :
+ llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
+ if (func != mainFn)
+ func.erase();
}
}
};
-} // namespace
+} // end anonymous namespace
/// Create a pass that eliminates inlined functions in toy.
std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
state.addOperands({lhs, rhs});
}
+/// Infer the output shape of AddOp; this is required by the shape inference
+/// interface.
+void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
+
static void buildGenericCallOp(mlir::Builder *builder,
mlir::OperationState &state, StringRef callee,
ArrayRef<mlir::Value *> arguments) {
state.addOperands({lhs, rhs});
}
+/// Infer the output shape of MulOp; this is required by the shape inference
+/// interface.
+void MulOp::inferShapes() {
+ auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
+ auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
+ auto lhsRank = lhs.getShape().size();
+ auto rhsRank = rhs.getShape().size();
+ if (lhsRank != rhsRank)
+ return;
+
+ SmallVector<int64_t, 2> dims;
+ if (lhsRank == 1) {
+ // dot product, result shape is <1>
+ dims.push_back(1);
+ } else if (lhsRank == 2) {
+ dims.push_back(lhs.getShape()[0]);
+ dims.push_back(rhs.getShape()[1]);
+ } else {
+ return;
+ }
+ getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
+}
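// For example (a sketch of the effect on the IR, assuming ranked operands):
//
//   %2 = "toy.mul"(%0, %1)
//       : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<*xf64>
//
// becomes, after MulOp::inferShapes() runs:
//
//   %2 = "toy.mul"(%0, %1)
//       : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<2x2xf64>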
+
static mlir::LogicalResult verify(ReturnOp op) {
// We know that the parent operation is a function, because of the 'HasParent'
// trait attached to the operation definition.
state.addOperands(value);
}
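+/// Infer the output shape of TransposeOp; this is required by the shape
+/// inference interface.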
+void TransposeOp::inferShapes() {
+ SmallVector<int64_t, 2> dims;
+ auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
+ dims.insert(dims.end(), arrayTy.getShape().begin(), arrayTy.getShape().end());
+ if (dims.size() == 2)
+ std::swap(dims[0], dims[1]);
+ getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
+}
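// For example (a sketch): a tensor<2x3xf64> operand yields a tensor<3x2xf64>
// result, while rank-1 shapes are copied through unchanged.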
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#include "mlir/Pass/Pass.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringSet.h"
+#include "toy/ShapeInferenceInterface.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
-#include <algorithm>
#define DEBUG_TYPE "shape-inference"
-using llvm::MutableArrayRef;
-using llvm::raw_ostream;
-using llvm::SmallVector;
-using llvm::SmallVectorImpl;
-using llvm::StringRef;
-using llvm::Twine;
using namespace mlir;
+using namespace toy;
-namespace {
-
-// clang-format off
-#include "toy/ShapeInferenceOpInterfaces.h.inc"
+/// Include the auto-generated definitions for the shape inference interfaces.
#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
+namespace {
/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
/// shape inference.
///
/// Algorithm:
///
-/// 1) Build a worklist containing all the operations that are returning
-/// a generic Toy array: these are the operations that need shape
+/// 1) Build a worklist containing all the operations that return a
+/// dynamically shaped tensor: these are the operations that need shape
/// inference.
/// 2) Iterate on the worklist:
///    a) find an operation to process: the next ready operation in the
///       worklist, i.e. one whose arguments are all non-generic,
/// b) if no operation is found, break out of the loop,
/// c) remove the operation from the worklist,
-/// d) infer the shape of its output from the arguments type.
-/// 3) If the worklist is empty, the algorithm succeeded and we infer the
-/// return type for the function from the return operation.
+/// d) infer the shape of its output from the argument types.
+/// 3) If the worklist is empty, the algorithm succeeded.
///
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
public:
- bool returnsGenericArray(Operation *op) {
- if (op->getNumResults() == 1) {
- if (!op->getResult(0)->getType().isa<ShapedType>())
- return true;
- }
- return false;
- }
-
void runOnFunction() override {
auto f = getFunction();
// Populate the worklist with the operations that need shape inference:
- // these are operations that return a generic array.
+ // these are operations that return a dynamic shape.
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
f.walk([&](mlir::Operation *op) {
- if (returnsGenericArray(op)) {
+ if (returnsDynamicShape(op))
opWorklist.insert(op);
- }
});
// Iterate on the operations in the worklist until all operations have been
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
- auto nextop = llvm::find_if(opWorklist, [this](Operation *op) {
- return this->returnsGenericArray(op);
- });
-
+ auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
if (nextop == opWorklist.end())
- break; // failure: no operations can be inferred.
+ break;
Operation *op = *nextop;
opWorklist.erase(op);
+
+ // Ask the operation to infer its output shapes.
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
      auto shapeOp = dyn_cast<ShapeInference>(op);
      if (!shapeOp) {
        op->emitError("unable to infer shape of operation without shape "
                      "inference interface");
        return signalPassFailure();
      }
      shapeOp.inferShapes();
// If the operation worklist isn't empty, this indicates a failure.
if (!opWorklist.empty()) {
+ f.emitError("Shape inference failed, ")
+ << opWorklist.size() << " operations couldn't be inferred\n";
signalPassFailure();
- auto diag = f.emitError("Shape inference failed, ")
- << opWorklist.size() << " operations couldn't be inferred\n";
}
}
+
+  /// A utility method that returns whether the given operation has a
+  /// dynamically shaped result.
+ static bool returnsDynamicShape(Operation *op) {
+ return llvm::any_of(op->getResultTypes(), [](Type resultType) {
+ return !resultType.isa<RankedTensorType>();
+ });
+ }
};
} // end anonymous namespace
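/// Create a shape inference pass; a minimal sketch mirroring
/// createDeadFunctionEliminationPass() above (the declaration is assumed to
/// live in toy/Passes.h).
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
  return std::make_unique<ShapeInferencePass>();
}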