Code cleanups on Ch.4
authorRiver Riddle <riverriddle@google.com>
Wed, 16 Oct 2019 19:33:55 +0000 (12:33 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 16 Oct 2019 19:34:26 +0000 (12:34 -0700)
This change performs general cleanups of the implementation of ch.4 and fixes some bugs. For example, the operations currently don't inherit from the shape inference interface.

PiperOrigin-RevId: 275089914

mlir/examples/toy/Ch4/include/toy/Dialect.h
mlir/examples/toy/Ch4/include/toy/Ops.td
mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h [new file with mode: 0644]
mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td
mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp

index da61191..556ae97 100644 (file)
@@ -26,6 +26,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/StandardTypes.h"
+#include "toy/ShapeInferenceInterface.h"
 
 namespace mlir {
 namespace toy {
index f0140d7..a8c6759 100644 (file)
@@ -92,7 +92,8 @@ def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
   let verifier = [{ return ::verify(*this); }];
 }
 
-def AddOp : Toy_Op<"add", [NoSideEffect]> {
+def AddOp : Toy_Op<"add",
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "element-wise addition operation";
   let description = [{
     The "add" operation performs element-wise addition between two tensors.
@@ -108,12 +109,6 @@ def AddOp : Toy_Op<"add", [NoSideEffect]> {
       buildAddOp(b, result, lhs, rhs);
     }]
   >];
-  let extraClassDeclaration = [{
-  void inferShapes() {
-    getResult()->setType(getOperand(0)->getType());
-    return;
-  }
-  }];
 }
 
 def GenericCallOp : Toy_Op<"generic_call"> {
@@ -150,7 +145,8 @@ def GenericCallOp : Toy_Op<"generic_call"> {
   ];
 }
 
-def MulOp : Toy_Op<"mul", [NoSideEffect]> {
+def MulOp : Toy_Op<"mul",
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "element-wise multiplication operation";
   let description = [{
     The "mul" operation performs element-wise multiplication between two
@@ -166,30 +162,6 @@ def MulOp : Toy_Op<"mul", [NoSideEffect]> {
       buildMulOp(b, result, lhs, rhs);
     }]
   >];
-  let extraClassDeclaration = [{
-  void inferShapes() {
-    auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
-    auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
-    auto lhsRank = lhs.getShape().size();
-    auto rhsRank = rhs.getShape().size();
-    if (lhsRank != rhsRank) {
-      return;
-    }
-    SmallVector<int64_t, 2> dims;
-    if (lhsRank == 1) {
-      // dot product, result shape is <1>
-      dims.push_back(1);
-      } else {
-      if (lhsRank != 2) {
-        return;
-      }
-      dims.push_back(lhs.getShape()[0]);
-      dims.push_back(rhs.getShape()[1]);
-    }
-    getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
-    return;
-  }
-  }];
 }
 
 def PrintOp : Toy_Op<"print"> {
@@ -255,7 +227,8 @@ def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
   let verifier = [{ return ::verify(*this); }];
 }
 
-def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
+def TransposeOp : Toy_Op<"transpose",
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "transpose operation";
 
   let arguments = (ins F64Tensor:$input);
@@ -268,18 +241,6 @@ def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
       buildTransposeOp(b, result, input);
     }]
   >];
-  let extraClassDeclaration = [{
-  void inferShapes() {
-    SmallVector<int64_t, 2> dims;
-    auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
-    dims.insert(dims.end(), arrayTy.getShape().begin(),
-                arrayTy.getShape().end());
-    if (dims.size() == 2)
-      std::swap(dims[0], dims[1]);
-    getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
-    return;
-  }
-  }];
 }
 
 #endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h
new file mode 100644 (file)
index 0000000..fc36b5b
--- /dev/null
@@ -0,0 +1,37 @@
+//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains the declarations of the shape inference interfaces defined
+// in ShapeInferenceInterface.td.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
+#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace toy {
+
+/// Include the auto-generated declarations.
+#include "toy/ShapeInferenceOpInterfaces.h.inc"
+
+} // end namespace toy
+} // end namespace mlir
+
+#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
index 2040cc4..4b1240d 100644 (file)
@@ -30,8 +30,8 @@ include "mlir/IR/OpBase.td"
 
 def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
   let methods = [
-    InterfaceMethod<"Infer output shape for the current operation.",
-                    "void", "inferShapes", (ins), [{}]>
+    InterfaceMethod<"Infer and set the output shape for the current operation.",
+                    "void", "inferShapes">
   ];
 }
 
index e7e64ce..b58adb5 100644 (file)
 #include <algorithm>
 
 namespace {
+/// This is a simple function DCE pass that deletes all non-main functions after
+/// inlining.
+/// TODO(riverriddle) This is only necessary because MLIR currently does not
+/// have generic DCE support for functions.
 class DeadFunctionEliminationPass
     : public mlir::ModulePass<DeadFunctionEliminationPass> {
 public:
   void runOnModule() override {
-    std::string str = "main";
-    auto module = getModule();
-    for (auto &f : module) {
-      // eliminate dead functions that are not main
-      if (str.find(f.getName().getStringRef()) == std::string::npos)
-        f.erase();
+    mlir::ModuleOp module = getModule();
+    mlir::SymbolTable moduleSymTable(module);
+
+    // Eliminate non-main functions.
+    auto mainFn = moduleSymTable.lookup<mlir::FuncOp>("main");
+    for (mlir::FuncOp func :
+         llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
+      if (func != mainFn)
+        func.erase();
     }
   }
 };
-} // namespace
+} // end anonymous namespace
 
 /// Create a pass that eliminates inlined functions in toy.
 std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
index 63eee4e..e285fac 100644 (file)
@@ -126,6 +126,10 @@ static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
   state.addOperands({lhs, rhs});
 }
 
+/// Infer the output shape of the AddOp, this is required by the shape inference
+/// interface.
+void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
+
 static void buildGenericCallOp(mlir::Builder *builder,
                                mlir::OperationState &state, StringRef callee,
                                ArrayRef<mlir::Value *> arguments) {
@@ -141,6 +145,29 @@ static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
   state.addOperands({lhs, rhs});
 }
 
+/// Infer the output shape of the MulOp, this is required by the shape inference
+/// interface.
+void MulOp::inferShapes() {
+  auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
+  auto lhsRank = lhs.getShape().size();
+  auto rhsRank = rhs.getShape().size();
+  if (lhsRank != rhsRank)
+    return;
+
+  SmallVector<int64_t, 2> dims;
+  if (lhsRank == 1) {
+    // dot product, result shape is <1>
+    dims.push_back(1);
+  } else if (lhsRank == 2) {
+    dims.push_back(lhs.getShape()[0]);
+    dims.push_back(rhs.getShape()[1]);
+  } else {
+    return;
+  }
+  getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
+}
+
 static mlir::LogicalResult verify(ReturnOp op) {
   // We know that the parent operation is a function, because of the 'HasParent'
   // trait attached to the operation definition.
@@ -182,6 +209,15 @@ static void buildTransposeOp(mlir::Builder *builder,
   state.addOperands(value);
 }
 
+void TransposeOp::inferShapes() {
+  SmallVector<int64_t, 2> dims;
+  auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
+  dims.insert(dims.end(), arrayTy.getShape().begin(), arrayTy.getShape().end());
+  if (dims.size() == 2)
+    std::swap(dims[0], dims[1]);
+  getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
index b8b091a..5acf8f9 100644 (file)
 #include "mlir/Pass/Pass.h"
 #include "toy/Dialect.h"
 #include "toy/Passes.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringSet.h"
+#include "toy/ShapeInferenceInterface.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
-#include <algorithm>
 
 #define DEBUG_TYPE "shape-inference"
 
-using llvm::MutableArrayRef;
-using llvm::raw_ostream;
-using llvm::SmallVector;
-using llvm::SmallVectorImpl;
-using llvm::StringRef;
-using llvm::Twine;
 using namespace mlir;
+using namespace toy;
 
-namespace {
-
-// clang-format off
-#include "toy/ShapeInferenceOpInterfaces.h.inc"
+/// Include the auto-generated definitions for the shape inference interfaces.
 #include "toy/ShapeInferenceOpInterfaces.cpp.inc"
 
+namespace {
 /// The ShapeInferencePass is a FunctionPass that performs intra-procedural
 /// shape inference.
 ///
 ///    Algorithm:
 ///
-///   1) Build a worklist containing all the operations that are returning
-///      a generic Toy array: these are the operations that need shape
+///   1) Build a worklist containing all the operations that return a
+///      dynamically shaped tensor: these are the operations that need shape
 ///      inference.
 ///   2) Iterate on the worklist:
 ///     a) find an operation to process: the next ready operation in the
 ///        worklist has all of its arguments non-generic,
 ///     b) if no operation is found, break out of the loop,
 ///     c) remove the operation from the worklist,
-///     d) infer the shape of its output from the arguments type.
-///   3) If the worklist is empty, the algorithm succeeded and we infer the
-///      return type for the function from the return operation.
+///     d) infer the shape of its output from the argument types.
+///   3) If the worklist is empty, the algorithm succeeded.
 ///
 class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
 public:
-  bool returnsGenericArray(Operation *op) {
-    if (op->getNumResults() == 1) {
-      if (!op->getResult(0)->getType().isa<ShapedType>())
-        return true;
-    }
-    return false;
-  }
-
   void runOnFunction() override {
     auto f = getFunction();
 
     // Populate the worklist with the operations that need shape inference:
-    // these are operations that return a generic array.
+    // these are operations that return a dynamic shape.
     llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
     f.walk([&](mlir::Operation *op) {
-      if (returnsGenericArray(op)) {
+      if (returnsDynamicShape(op))
         opWorklist.insert(op);
-      }
     });
 
     // Iterate on the operations in the worklist until all operations have been
@@ -91,15 +71,14 @@ public:
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, [this](Operation *op) {
-        return this->returnsGenericArray(op);
-      });
-
+      auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
       if (nextop == opWorklist.end())
-        break; // failure: no operations can be inferred.
+        break;
 
       Operation *op = *nextop;
       opWorklist.erase(op);
+
+      // Ask the operation to infer its output shapes.
       LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
       auto shapeOp = dyn_cast<ShapeInference>(op);
       shapeOp.inferShapes();
@@ -107,11 +86,19 @@ public:
 
     // If the operation worklist isn't empty, this indicates a failure.
     if (!opWorklist.empty()) {
+      f.emitError("Shape inference failed, ")
+          << opWorklist.size() << " operations couldn't be inferred\n";
       signalPassFailure();
-      auto diag = f.emitError("Shape inference failed, ")
-                  << opWorklist.size() << " operations couldn't be inferred\n";
     }
   }
+
+  /// A utility method that returns if the given operation has a dynamically
+  /// shaped result.
+  static bool returnsDynamicShape(Operation *op) {
+    return llvm::any_of(op->getResultTypes(), [](Type resultType) {
+      return !resultType.isa<RankedTensorType>();
+    });
+  }
 };
 } // end anonymous namespace