De-templatize TensorContractionBase (Linalg example/tutorial)
authorNicolas Vasilache <ntv@google.com>
Mon, 8 Apr 2019 22:06:34 +0000 (15:06 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 9 Apr 2019 02:17:56 +0000 (19:17 -0700)
    TensorContractionBase has become too unwieldy with all the CRTP manipulation once less trivial transformations are implemented.
    This CL drops CRTP for inheritance and uses the same name comparison trick to figure out what to cast into.
    As a byproduct, all the -inl.h files disappear.
    To maintain the separation between directories, a LINALG_STEP variable is introduced

--

PiperOrigin-RevId: 242546977

mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h [deleted file]
mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h
mlir/examples/Linalg/Linalg2/lib/TensorOps.cpp
mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h [deleted file]
mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps.h
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp
mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
mlir/examples/Linalg/Linalg4/lib/Transforms.cpp

diff --git a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h
deleted file mode 100644 (file)
index 940f8d7..0000000
+++ /dev/null
@@ -1,120 +0,0 @@
-//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
-/// TensorOps by adding implementations as they are needed in the appropriate
-/// step in the tutorial.
-#ifndef LINALG2_TENSOROPS_INL_H_
-#define LINALG2_TENSOROPS_INL_H_
-
-#include "linalg2/Ops.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/StandardTypes.h"
-
-namespace linalg {
-
-template <class ConcreteOp>
-mlir::Operation::operand_range
-linalg::TensorContractionBase<ConcreteOp>::getInputs() {
-  auto *op = static_cast<ConcreteOp *>(this)->getOperation();
-  return {op->operand_begin(), op->operand_begin() + getNumInputs()};
-}
-
-template <class ConcreteOp>
-mlir::Operation::operand_range
-linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
-  auto *op = static_cast<ConcreteOp *>(this)->getOperation();
-  return {op->operand_begin() + getNumInputs(),
-          op->operand_begin() + getNumInputs() + getNumOutputs()};
-}
-
-template <class ConcreteOp>
-mlir::Operation::operand_range
-linalg::TensorContractionBase<ConcreteOp>::getInputsAndOutputs() {
-  return {getInputs().begin(), getOutputs().end()};
-}
-
-template <class ConcreteOp>
-mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
-  auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
-  if (getNumInputs() <= 0)
-    concreteOp->emitOpError("expected at least one input");
-  if (getNumOutputs() <= 0)
-    concreteOp->emitOpError("expected at least one output");
-  if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
-    concreteOp->emitOpError("expected " +
-                            llvm::Twine(getNumInputs() + getNumOutputs()) +
-                            " operands");
-  }
-  for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
-    if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
-      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
-                                     " not a ViewType");
-  }
-  for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
-       ++i) {
-    auto viewType =
-        concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
-    if (!viewType)
-      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
-                                     " not a ViewType");
-    if (viewType.getRank() != getNumParallelDims())
-      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
-                                     " must be of rank " +
-                                     llvm::Twine(getNumParallelDims()));
-  }
-  return mlir::success();
-}
-
-template <class ConcreteOp>
-bool linalg::TensorContractionBase<ConcreteOp>::parse(
-    mlir::OpAsmParser *parser, mlir::OperationState *result) {
-  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
-}
-
-// A TensorContraction prints as:
-//
-// ```{.mlir}
-//   concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
-// ```
-//
-// for example:
-//
-// ```
-//   linalg.matmul(%0, %1, %2) : view<?x?xf32>
-// ```
-//
-// Where %0, %1 and %2 are ssa-values of type ViewType.
-template <class ConcreteOp>
-void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) {
-  *p << static_cast<ConcreteOp *>(this)->getOperationName() << "(";
-  auto *last = *std::prev(getInputsAndOutputs().end());
-  for (auto *i : getInputsAndOutputs()) {
-    *p << *i << ((i == last) ? "" : ", ");
-  }
-  *p << ") : ";
-  auto *lastOutput = *std::prev(getOutputs().end());
-  for (auto *o : getOutputs()) {
-    *p << o->getType() << ((o == lastOutput) ? "" : ",");
-  }
-}
-
-} // namespace linalg
-
-#endif // LINALG2_TENSOROPS_INL_H_
index 39e51f0..cac813e 100644 (file)
@@ -29,44 +29,29 @@ namespace linalg {
 
 /// A generic TensorContraction base class which captures the generic behavior
 /// of tensor contraction operations (with broadcast).
-template <class ConcreteOp> class TensorContractionBase {
-protected:
-  using TensorContractionBaseType = TensorContractionBase<ConcreteOp>;
-
-  //////////////////////////////////////////////////////////////////////////////
-  // Hooks to customize the behavior of this op.
-  //////////////////////////////////////////////////////////////////////////////
-  /// Generic implementation of hooks that should be called from `ConcreteType`s
-  mlir::LogicalResult verify();
-  static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
-  void print(mlir::OpAsmPrinter *p);
-
+class TensorContractionBase {
 public:
+  virtual ~TensorContractionBase() {}
+
   //////////////////////////////////////////////////////////////////////////////
   // Op-specific functionality.
   //////////////////////////////////////////////////////////////////////////////
-  TensorContractionBase() = default;
+  virtual llvm::StringRef getTensorContractionName() = 0;
   mlir::Operation::operand_range getInputs();
   mlir::Operation::operand_range getOutputs();
-  mlir::Operation::operand_range getInputsAndOutputs();
+  mlir::Operation::operand_range getInputsAndOutputs() {
+    return {getInputs().begin(), getOutputs().end()};
+  }
 
   /// These are better as methods calling into the ConcreteOp instead of
   /// template parameters because methods allow more generic behavior and avoid
   /// specializing for number of arguments. All derived classes have
   /// `VariadicOperands` and a build method from both an ArrayRef<mlirValue*>
   /// and the proper number of mlir::Value*.
-  unsigned getNumInputs() {
-    return static_cast<ConcreteOp *>(this)->numInputs;
-  };
-  unsigned getNumOutputs() {
-    return static_cast<ConcreteOp *>(this)->numOutputs;
-  };
-  unsigned getNumParallelDims() {
-    return static_cast<ConcreteOp *>(this)->numParallelDims;
-  };
-  unsigned getNumReductionDims() {
-    return static_cast<ConcreteOp *>(this)->numReductionDims;
-  };
+  virtual unsigned getNumInputs() = 0;
+  virtual unsigned getNumOutputs() = 0;
+  virtual unsigned getNumParallelDims() = 0;
+  virtual unsigned getNumReductionDims() = 0;
 
   //////////////////////////////////////////////////////////////////////////////
   // Used in Linalg3 and later.
@@ -79,13 +64,18 @@ public:
                : getOutputView(viewIndex - getNumInputs());
   }
 
+  /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+  /// loop over matvec). Does nothing by default.
+  virtual void writeAsFinerGrainTensorContraction() {}
+
   /// Each op is responsible for declaring how it lowers itself to scalar form,
   /// given the enclosing parallel and reduction induction variables.
   /// `emitScalarImplementation` emits the scalar IR for the op in the nesting
   /// context of the innermost enclosing loop(i.e. `reductionIvs.back()` or
   /// `parallel.back()`).
-  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
-                                llvm::ArrayRef<mlir::Value *> reductionIvs);
+  virtual void
+  emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+                           llvm::ArrayRef<mlir::Value *> reductionIvs) {}
 
   /// Represents a mapping from the loops to all the ranges of the operands.
   /// The operands and their ranges are in the order defined by the particular
@@ -94,17 +84,17 @@ public:
   /// it explicitly is not expensive and generalizes to cases where an analysis
   /// is not available. For details, see the description of
   /// loopsToOperandRangeMaps in each ConcreteOp.
-  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+  virtual llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() {
+    return llvm::SmallVector<mlir::AffineMap, 8>();
+  }
 };
 
 /// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
-class DotOp : public TensorContractionBase<DotOp>,
+class DotOp : public TensorContractionBase,
               public mlir::Op<DotOp, mlir::OpTrait::VariadicOperands,
                               mlir::OpTrait::ZeroResult> {
 public:
   using Op::Op;
-  using TensorContractionBaseType =
-      TensorContractionBase::TensorContractionBaseType;
 
   //////////////////////////////////////////////////////////////////////////////
   // Hooks to customize the behavior of this op.
@@ -123,24 +113,28 @@ public:
   //////////////////////////////////////////////////////////////////////////////
   // Op-specific functionality.
   //////////////////////////////////////////////////////////////////////////////
-  static constexpr unsigned numInputs = 2;
-  static constexpr unsigned numOutputs = 1;
-  static constexpr unsigned numParallelDims = 0;
-  static constexpr unsigned numReductionDims = 1;
+  llvm::StringRef getTensorContractionName() override {
+    return getOperationName();
+  }
+  unsigned getNumInputs() override { return 2; }
+  unsigned getNumOutputs() override { return 1; }
+  unsigned getNumParallelDims() override { return 0; }
+  unsigned getNumReductionDims() override { return 1; }
 
+#if LINALG_STEP > 2
   //////////////////////////////////////////////////////////////////////////////
   // Used in Linalg3 and later.
   //////////////////////////////////////////////////////////////////////////////
   /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
   /// loop over matvec). Does nothing by default.
-  void writeAsFinerGrainTensorContraction();
+  void writeAsFinerGrainTensorContraction() override;
 
   /// Inputs to this map will be (%k) coming from enclosing loops.
   /// Therefore, the mapping to get back to A(K), B(K), C() is:
   ///   (d0) -> (d0, d0)(%k)
   /// And the operands ranges are:
   ///   (%k, %k)
-  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
 
   ///  Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
   ///  to:
@@ -153,18 +147,18 @@ public:
   ///  cond = (r_i == zero)
   ///  scalarC = select(cond, zerof, C[]);
   ///  C[] = scalarC + A[r_i] * B[r_i];
-  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
-                                llvm::ArrayRef<mlir::Value *> reductionIvs);
+  void
+  emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+                           llvm::ArrayRef<mlir::Value *> reductionIvs) override;
+#endif // LINALG_STEP
 };
 
 /// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors.
-class MatvecOp : public TensorContractionBase<MatvecOp>,
+class MatvecOp : public TensorContractionBase,
                  public mlir::Op<MatvecOp, mlir::OpTrait::VariadicOperands,
                                  mlir::OpTrait::ZeroResult> {
 public:
   using Op::Op;
-  using TensorContractionBaseType =
-      TensorContractionBase::TensorContractionBaseType;
 
   //////////////////////////////////////////////////////////////////////////////
   // Hooks to customize the behavior of this op.
@@ -183,24 +177,28 @@ public:
   //////////////////////////////////////////////////////////////////////////////
   // Op-specific functionality.
   //////////////////////////////////////////////////////////////////////////////
-  static constexpr unsigned numInputs = 2;
-  static constexpr unsigned numOutputs = 1;
-  static constexpr unsigned numParallelDims = 1;
-  static constexpr unsigned numReductionDims = 1;
+  llvm::StringRef getTensorContractionName() override {
+    return getOperationName();
+  }
+  unsigned getNumInputs() override { return 2; }
+  unsigned getNumOutputs() override { return 1; }
+  unsigned getNumParallelDims() override { return 1; }
+  unsigned getNumReductionDims() override { return 1; }
 
+#if LINALG_STEP > 2
   //////////////////////////////////////////////////////////////////////////////
   // Used in Linalg3 and later.
   //////////////////////////////////////////////////////////////////////////////
   /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
   /// loop over matvec). Does nothing by default.
-  void writeAsFinerGrainTensorContraction();
+  void writeAsFinerGrainTensorContraction() override;
 
   /// Inputs to this map will be (%m, %k) coming from enclosing loops.
   /// Therefore, the mapping to get back to A(M, K), B(K), C(M) is:
   ///   (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
   /// And the operands ranges are:
   ///   (%m, %k, %k, %m)
-  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
 
   ///  Given an enclosing parallel loop with iv `i` and an enclosing parallel
   ///  loop with iv `r_j`, emits MLIR corresponding to:
@@ -213,18 +211,18 @@ public:
   ///  cond = (r_j == zero)
   ///  scalarC = select(cond, zerof, C(i));
   ///  C(i) = scalarC + A(i, r_j) * B(r_j);
-  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
-                                llvm::ArrayRef<mlir::Value *> reductionIvs);
+  void
+  emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+                           llvm::ArrayRef<mlir::Value *> reductionIvs) override;
+#endif // LINALG_STEP
 };
 
 /// Implements C = A * B on 2-D matrices.
-class MatmulOp : public TensorContractionBase<MatmulOp>,
+class MatmulOp : public TensorContractionBase,
                  public mlir::Op<MatmulOp, mlir::OpTrait::VariadicOperands,
                                  mlir::OpTrait::ZeroResult> {
 public:
   using Op::Op;
-  using TensorContractionBaseType =
-      TensorContractionBase::TensorContractionBaseType;
 
   //////////////////////////////////////////////////////////////////////////////
   // Hooks to customize the behavior of this op.
@@ -243,24 +241,28 @@ public:
   //////////////////////////////////////////////////////////////////////////////
   // Op-specific functionality.
   //////////////////////////////////////////////////////////////////////////////
-  static constexpr unsigned numInputs = 2;
-  static constexpr unsigned numOutputs = 1;
-  static constexpr unsigned numParallelDims = 2;
-  static constexpr unsigned numReductionDims = 1;
+  llvm::StringRef getTensorContractionName() override {
+    return getOperationName();
+  }
+  unsigned getNumInputs() override { return 2; }
+  unsigned getNumOutputs() override { return 1; }
+  unsigned getNumParallelDims() override { return 2; }
+  unsigned getNumReductionDims() override { return 1; }
 
+#if LINALG_STEP > 2
   //////////////////////////////////////////////////////////////////////////////
   // Used in Linalg3 and later.
   //////////////////////////////////////////////////////////////////////////////
   /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
   /// loop over matvec). Does nothing by default.
-  void writeAsFinerGrainTensorContraction();
+  void writeAsFinerGrainTensorContraction() override;
 
   /// Inputs to this map will be (%m, %n, %k) coming from enclosing loops.
   /// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is:
   ///   (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
   /// And the operands ranges are:
   ///   (%m, %k, %k, %n, %m, %n)
-  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps();
+  llvm::SmallVector<mlir::AffineMap, 8> loopsToOperandRangeMaps() override;
 
   ///  Given a enclosing parallel loops with ivs `i` and `j`, and an enclosing
   ///  reduction loop with iv `r_k`, emits MLIR corresponding to:
@@ -273,15 +275,12 @@ public:
   ///  cond = (r_k == zero)
   ///  scalarC = select(cond, zerof, C[i, j]);
   ///  C[i, j] = scalarC + A[i, r_k] * B[r_k, j];
-  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
-                                llvm::ArrayRef<mlir::Value *> reductionIvs);
+  void
+  emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+                           llvm::ArrayRef<mlir::Value *> reductionIvs) override;
+#endif // LINALG_STEP
 };
 
 } // namespace linalg
 
-/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
-/// TensorOps by adding implementations as they are needed in the appropriate
-/// step in the tutorial.
-#include "linalg2/TensorOps-inl.h"
-
 #endif // LINALG2_TENSOROPS_H_
index 8a47e5d..6aeefc8 100644 (file)
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 
-using llvm::ArrayRef;
-using llvm::Twine;
-
 using namespace mlir;
 using namespace linalg;
 
+#define TENSOR_CONTRACTION_DISPATCH(FUNCTION_NAME)                             \
+  if (getTensorContractionName() == MatmulOp::getOperationName()) {            \
+    return FUNCTION_NAME(static_cast<MatmulOp &>(*this));                      \
+  }                                                                            \
+  if (getTensorContractionName() == MatvecOp::getOperationName()) {            \
+    return FUNCTION_NAME(static_cast<MatvecOp &>(*this));                      \
+  }                                                                            \
+  if (getTensorContractionName() == DotOp::getOperationName()) {               \
+    return FUNCTION_NAME(static_cast<DotOp &>(*this));                         \
+  }                                                                            \
+  llvm_unreachable("Missing linalg op");
+
+template <typename ConcreteOp>
+static mlir::Operation::operand_range getInputs(ConcreteOp &concreteOp) {
+  return {concreteOp.operand_begin(),
+          concreteOp.operand_begin() + concreteOp.getNumInputs()};
+}
+
+mlir::Operation::operand_range linalg::TensorContractionBase::getInputs() {
+  TENSOR_CONTRACTION_DISPATCH(::getInputs);
+}
+
+template <typename ConcreteOp>
+static mlir::Operation::operand_range getOutputs(ConcreteOp &concreteOp) {
+  return {concreteOp.operand_begin() + concreteOp.getNumInputs(),
+          concreteOp.operand_begin() + concreteOp.getNumInputs() +
+              concreteOp.getNumOutputs()};
+}
+
+mlir::Operation::operand_range linalg::TensorContractionBase::getOutputs() {
+  TENSOR_CONTRACTION_DISPATCH(::getOutputs);
+}
+
+template <typename LinalgOp>
+static mlir::LogicalResult verifyLinalgOp(LinalgOp op) {
+  if (op.getNumInputs() <= 0)
+    op.emitOpError("expected at least one input");
+  if (op.getNumOutputs() <= 0)
+    op.emitOpError("expected at least one output");
+  if (op.getNumOperands() != op.getNumInputs() + op.getNumOutputs()) {
+    op.emitOpError("expected " +
+                   llvm::Twine(op.getNumInputs() + op.getNumOutputs()) +
+                   " operands");
+  }
+  for (unsigned i = 0, e = op.getNumInputs(); i < e; ++i) {
+    if (!op.getOperand(i)->getType().template isa<ViewType>())
+      return op.emitOpError("operand " + llvm::Twine(i) + " not a ViewType");
+  }
+  for (unsigned i = op.getNumInputs(),
+                e = op.getNumInputs() + op.getNumOutputs();
+       i < e; ++i) {
+    auto viewType = op.getOperand(i)->getType().template dyn_cast<ViewType>();
+    if (!viewType)
+      return op.emitOpError("operand " + llvm::Twine(i) + " not a ViewType");
+    if (viewType.getRank() != op.getNumParallelDims())
+      return op.emitOpError("operand " + llvm::Twine(i) + " must be of rank " +
+                            llvm::Twine(op.getNumParallelDims()));
+  }
+  return mlir::success();
+}
+
+// A TensorContraction prints as:
+//
+// ```{.mlir}
+//   concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
+// ```
+//
+// for example:
+//
+// ```
+//   linalg.matmul(%0, %1, %2) : view<?x?xf32>
+// ```
+//
+// Where %0, %1 and %2 are ssa-values of type ViewType.
+template <typename LinalgOp>
+static void printLinalgOp(mlir::OpAsmPrinter *p, LinalgOp op) {
+  *p << op.getOperationName() << "(";
+  auto *last = *std::prev(op.getInputsAndOutputs().end());
+  for (auto *i : op.getInputsAndOutputs()) {
+    *p << *i << ((i == last) ? "" : ", ");
+  }
+  *p << ") : ";
+  auto *lastOutput = *std::prev(op.getOutputs().end());
+  for (auto *o : op.getOutputs()) {
+    *p << o->getType() << ((o == lastOutput) ? "" : ",");
+  }
+}
+
 //////////////////////////////////////////////////////////////////////////////
 // Op-specific Dot.
 //////////////////////////////////////////////////////////////////////////////
@@ -43,7 +129,7 @@ void linalg::DotOp::build(Builder *b, OperationState *result,
 }
 
 LogicalResult linalg::DotOp::verify() {
-  if (failed(TensorContractionBaseType::verify()))
+  if (failed(verifyLinalgOp(*this)))
     return failure();
   auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
   unsigned index = 0;
@@ -60,12 +146,10 @@ LogicalResult linalg::DotOp::verify() {
 // Parsing of the linalg dialect is not supported in this tutorial.
 bool linalg::DotOp::parse(mlir::OpAsmParser *parser,
                           mlir::OperationState *result) {
-  return TensorContractionBaseType::parse(parser, result);
+  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
 }
 
-void linalg::DotOp::print(mlir::OpAsmPrinter *p) {
-  TensorContractionBaseType::print(p);
-}
+void linalg::DotOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
 
 //////////////////////////////////////////////////////////////////////////////
 // Op-specific Matvec.
@@ -76,7 +160,7 @@ void linalg::MatvecOp::build(Builder *b, OperationState *result,
 }
 
 LogicalResult linalg::MatvecOp::verify() {
-  if (failed(TensorContractionBaseType::verify()))
+  if (failed(verifyLinalgOp(*this)))
     return failure();
   auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
   if (getViewRank(A) != 2)
@@ -94,12 +178,10 @@ LogicalResult linalg::MatvecOp::verify() {
 // Parsing of the linalg dialect is not supported in this tutorial.
 bool linalg::MatvecOp::parse(mlir::OpAsmParser *parser,
                              mlir::OperationState *result) {
-  return TensorContractionBaseType::parse(parser, result);
+  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
 }
 
-void linalg::MatvecOp::print(mlir::OpAsmPrinter *p) {
-  TensorContractionBaseType::print(p);
-}
+void linalg::MatvecOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
 
 //////////////////////////////////////////////////////////////////////////////
 // Op-specific Matmul.
@@ -110,7 +192,7 @@ void linalg::MatmulOp::build(Builder *b, OperationState *result,
 }
 
 LogicalResult linalg::MatmulOp::verify() {
-  if (failed(TensorContractionBaseType::verify()))
+  if (failed(verifyLinalgOp(*this)))
     return failure();
   auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
   unsigned index = 0;
@@ -125,9 +207,7 @@ LogicalResult linalg::MatmulOp::verify() {
 // Parsing of the linalg dialect is not supported in this tutorial.
 bool linalg::MatmulOp::parse(mlir::OpAsmParser *parser,
                              mlir::OperationState *result) {
-  return TensorContractionBaseType::parse(parser, result);
+  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
 }
 
-void linalg::MatmulOp::print(mlir::OpAsmPrinter *p) {
-  TensorContractionBaseType::print(p);
-}
+void linalg::MatmulOp::print(mlir::OpAsmPrinter *p) { printLinalgOp(p, *this); }
diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h
deleted file mode 100644 (file)
index b651053..0000000
+++ /dev/null
@@ -1,145 +0,0 @@
-//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-
-/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
-/// TensorOps by adding implementations as they are needed in the appropriate
-/// step in the tutorial.
-#ifndef LINALG3_TENSOROPS_INL_H_
-#define LINALG3_TENSOROPS_INL_H_
-
-#include "linalg1/Common.h"
-#include "linalg1/Utils.h"
-#include "linalg2/TensorOps.h"
-#include "linalg3/Analysis.h"
-#include "linalg3/Ops.h"
-
-template <class ConcreteOp>
-mlir::Value *
-linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned viewIndex) {
-  return *(getInputs().begin() + viewIndex);
-}
-
-template <class ConcreteOp>
-mlir::Value *
-linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned viewIndex) {
-  return *(getOutputs().begin() + viewIndex);
-}
-
-template <class ConcreteOp>
-llvm::SmallVector<mlir::AffineMap, 8>
-linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangeMaps() {
-  return static_cast<ConcreteOp *>(this)->loopsToOperandRangeMaps();
-}
-
-template <class ConcreteOp>
-void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
-    llvm::ArrayRef<mlir::Value *> parallelIvs,
-    llvm::ArrayRef<mlir::Value *> reductionIvs) {
-  static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
-                                                            reductionIvs);
-}
-
-template <class ConcreteOp>
-mlir::AffineMap linalg::operandRangesToLoopsMap(
-    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
-  mlir::AffineMap current;
-  // Individual submaps may not be invertible but their union must be invertible
-  // by construction.
-  for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
-    if (!m)
-      continue;
-    if (!current) {
-      current = m;
-      continue;
-    }
-    llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
-                                                   current.getResults().end());
-    results.append(m.getResults().begin(), m.getResults().end());
-    current = mlir::AffineMap::get(
-        std::max(current.getNumDims(), m.getNumDims()),
-        current.getNumSymbols() + m.getNumSymbols(), results, {});
-  }
-  return inverseSubMap(current);
-}
-
-// Extract the ranges from a given ViewOp or SliceOp.
-//
-// In the case of a ViewOp, things are simple: just traverse the indexings and
-// get all the ranges (i.e. drop the indices).
-//
-// In the case of a SliceOp, things are trickier because we need to handle a
-// potential rank-reduction:
-//   1. Examine the indexing to determine if it is rank-reducing.
-//   2. If it is rank-reducing, an offset of 1 is added to the dimensions such
-//      that `d >= slicingDim`. This is to account for the rank reduction.
-// `getRootIndex` is then called on the **parent** view
-static llvm::SmallVector<mlir::Value *, 8>
-extractRangesFromViewOrSliceOp(mlir::Value *view) {
-  // This expects a viewType which must come from either ViewOp or SliceOp.
-  assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
-  if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
-    return viewOp.getRanges();
-
-  auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
-  unsigned slicingDim = sliceOp.getSlicingDim();
-  auto *indexing = *(sliceOp.getIndexings().begin());
-  bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
-  unsigned offset = 0;
-  llvm::SmallVector<mlir::Value *, 8> res;
-  res.reserve(sliceOp.getRank());
-  for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
-    if (d == slicingDim && isRankReducing)
-      offset = 1;
-    auto *parentView = sliceOp.getParentView();
-    auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
-    res.push_back(indexingPosPair.first);
-  }
-  return res;
-}
-
-template <class ConcreteOp>
-static llvm::SmallVector<mlir::Value *, 8>
-getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
-  llvm::SmallVector<mlir::Value *, 8> res;
-  for (auto *in : tensorContraction.getInputs()) {
-    auto subres = extractRangesFromViewOrSliceOp(in);
-    res.append(subres.begin(), subres.end());
-  }
-  return res;
-}
-
-template <class ConcreteOp>
-static llvm::SmallVector<mlir::Value *, 8>
-getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
-  llvm::SmallVector<mlir::Value *, 8> res;
-  for (auto *out : tensorContraction.getOutputs()) {
-    auto subres = extractRangesFromViewOrSliceOp(out);
-    res.append(subres.begin(), subres.end());
-  }
-  return res;
-}
-
-template <class ConcreteOp>
-llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
-    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
-  llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
-  llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
-  res.append(tmp.begin(), tmp.end());
-  return res;
-}
-
-#endif // LINALG3_TENSOROPS_INL_H_
index bf5a377..cbb247c 100644 (file)
@@ -29,9 +29,8 @@ namespace linalg {
 
 /// Takes a `tensorContraction` and a returns an AffineMap that can be used to
 /// map ranges to enclosing loops for all the operands' ranges.
-template <class ConcreteOp>
-mlir::AffineMap operandRangesToLoopsMap(
-    linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
+mlir::AffineMap
+operandRangesToLoopsMap(linalg::TensorContractionBase &tensorContraction);
 
 /// Takes a `tensorContraction` and returns the ranges of all its operands.
 /// When an operand comes from a ViewOp, things are simple:
@@ -40,15 +39,9 @@ mlir::AffineMap operandRangesToLoopsMap(
 /// In the case of a SliceOp, things are more involved because we need to handle
 /// potential rank-reductions.
 /// This function abstracts this complexity away and returns all the ranges.
-template <class ConcreteOp>
 llvm::SmallVector<mlir::Value *, 8>
-getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
+getRanges(linalg::TensorContractionBase &tensorContraction);
 
 } // namespace linalg
 
-/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
-/// TensorOps by adding implementations as they are needed in the appropriate
-/// step in the tutorial.
-#include "linalg3/TensorOps-inl.h"
-
 #endif // LINALG3_TENSOROPS_H_
index fd2afd9..1acdf7a 100644 (file)
@@ -27,7 +27,9 @@
 #include "mlir/IR/Types.h"
 #include "mlir/LLVMIR/LLVMDialect.h"
 #include "mlir/LLVMIR/Transforms.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
 
 #include "linalg1/ConvertToLLVMDialect.h"
 #include "linalg1/LLVMIntrinsics.h"
index a5b094c..673c686 100644 (file)
@@ -20,8 +20,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "linalg1/Analysis.h"
 #include "linalg1/Common.h"
+#include "linalg3/Analysis.h"
 #include "linalg3/Intrinsics.h"
 #include "linalg3/Ops.h"
 #include "mlir/IR/Builders.h"
@@ -36,6 +36,99 @@ using namespace mlir::edsc::intrinsics;
 using namespace linalg;
 using namespace linalg::intrinsics;
 
+mlir::Value *linalg::TensorContractionBase::getInputView(unsigned viewIndex) {
+  return *(getInputs().begin() + viewIndex);
+}
+
+mlir::Value *linalg::TensorContractionBase::getOutputView(unsigned viewIndex) {
+  return *(getOutputs().begin() + viewIndex);
+}
+
+mlir::AffineMap linalg::operandRangesToLoopsMap(
+    linalg::TensorContractionBase &tensorContraction) {
+  mlir::AffineMap current;
+  // Individual submaps may not be invertible but their union must be invertible
+  // by construction.
+  for (auto m : tensorContraction.loopsToOperandRangeMaps()) {
+    if (!m)
+      continue;
+    if (!current) {
+      current = m;
+      continue;
+    }
+    llvm::SmallVector<mlir::AffineExpr, 8> results(current.getResults().begin(),
+                                                   current.getResults().end());
+    results.append(m.getResults().begin(), m.getResults().end());
+    current = mlir::AffineMap::get(
+        std::max(current.getNumDims(), m.getNumDims()),
+        current.getNumSymbols() + m.getNumSymbols(), results, {});
+  }
+  return inverseSubMap(current);
+}
+
+// Extract the ranges from a given ViewOp or SliceOp.
+//
+// In the case of a ViewOp, things are simple: just traverse the indexings and
+// get all the ranges (i.e. drop the indices).
+//
+// In the case of a SliceOp, things are trickier because we need to handle a
+// potential rank-reduction:
+//   1. Examine the indexing to determine if it is rank-reducing.
+//   2. If it is rank-reducing, an offset of 1 is added to the dimensions such
+//      that `d >= slicingDim`. This is to account for the rank reduction.
+// `getRootIndex` is then called on the **parent** view
+static llvm::SmallVector<mlir::Value *, 8>
+extractRangesFromViewOrSliceOp(mlir::Value *view) {
+  // This expects a viewType which must come from either ViewOp or SliceOp.
+  assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
+  if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
+    return viewOp.getRanges();
+
+  auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
+  unsigned slicingDim = sliceOp.getSlicingDim();
+  auto *indexing = *(sliceOp.getIndexings().begin());
+  bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
+  unsigned offset = 0;
+  llvm::SmallVector<mlir::Value *, 8> res;
+  res.reserve(sliceOp.getRank());
+  for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
+    if (d == slicingDim && isRankReducing)
+      offset = 1;
+    auto *parentView = sliceOp.getParentView();
+    auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
+    res.push_back(indexingPosPair.first);
+  }
+  return res;
+}
+
+static llvm::SmallVector<mlir::Value *, 8>
+getInputRanges(linalg::TensorContractionBase &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res;
+  for (auto *in : tensorContraction.getInputs()) {
+    auto subres = extractRangesFromViewOrSliceOp(in);
+    res.append(subres.begin(), subres.end());
+  }
+  return res;
+}
+
+static llvm::SmallVector<mlir::Value *, 8>
+getOutputRanges(linalg::TensorContractionBase &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res;
+  for (auto *out : tensorContraction.getOutputs()) {
+    auto subres = extractRangesFromViewOrSliceOp(out);
+    res.append(subres.begin(), subres.end());
+  }
+  return res;
+}
+
+llvm::SmallVector<mlir::Value *, 8>
+linalg::getRanges(linalg::TensorContractionBase &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
+  llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
+  res.append(tmp.begin(), tmp.end());
+  return res;
+}
+
 //////////////////////////////////////////////////////////////////////////////
 // Implementation of DotOp.
 //////////////////////////////////////////////////////////////////////////////
index d9a56c6..3f1c36d 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "linalg3/Transforms.h"
+#include "linalg1/Common.h"
 #include "linalg2/Intrinsics.h"
 #include "linalg3/Ops.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
index 05865e9..7a16089 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "linalg4/Transforms.h"
+#include "linalg1/Common.h"
+#include "linalg1/Utils.h"
 #include "linalg3/Intrinsics.h"
 #include "linalg3/TensorOps.h"
 
 #include "mlir/AffineOps/AffineOps.h"
 #include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Transforms/LoopUtils.h"
 
@@ -56,27 +59,25 @@ static bool isZeroIndex(Value *v) {
          v->getDefiningOp()->dyn_cast<ConstantIndexOp>().getValue() == 0;
 }
 
-template <typename ConcreteOp>
 static llvm::SmallVector<Value *, 4>
-makeTiledRanges(TensorContractionBase<ConcreteOp> &contraction,
-                ArrayRef<Value *> allRanges, llvm::ArrayRef<Value *> ivs,
+makeTiledRanges(TensorContractionBase &contraction, ArrayRef<Value *> allRanges,
+                llvm::ArrayRef<Value *> ivs,
                 llvm::ArrayRef<Value *> tileSizes) {
   assert(ivs.size() == tileSizes.size());
   if (ivs.empty())
     return RangeParts(allRanges).makeRanges();
 
-  auto *op = static_cast<ConcreteOp *>(&contraction);
   RangeParts result(allRanges.size());
   RangeParts rangeParts(allRanges);
 
-  for (auto map : op->loopsToOperandRangeMaps()) {
+  for (auto map : contraction.loopsToOperandRangeMaps()) {
     // 1. Take the first ivs results of the map, the other ones are not composed
     // but merely copied over.
     assert(map.getNumSymbols() == 0);
     assert(map.getRangeSizes().empty());
     MLIRContext *context = ScopedContext::getContext();
-    unsigned numParallel = op->getNumParallelDims();
-    unsigned numReduction = op->getNumReductionDims();
+    unsigned numParallel = contraction.getNumParallelDims();
+    unsigned numReduction = contraction.getNumReductionDims();
     if (ivs.size() < numParallel + numReduction) {
       // Inject zeros in positions that are not tiled.
       SmallVector<AffineExpr, 4> dimReplacements(numParallel + numReduction);
@@ -120,9 +121,8 @@ makeTiledRanges(TensorContractionBase<ConcreteOp> &contraction,
   return result.makeRanges();
 }
 
-template <class ConcreteOp>
 static SmallVector<Value *, 4>
-makeTiledViews(linalg::TensorContractionBase<ConcreteOp> &contraction,
+makeTiledViews(linalg::TensorContractionBase &contraction,
                ArrayRef<Value *> ivs, ArrayRef<Value *> tileSizes) {
   auto tiledRanges =
       makeTiledRanges(contraction, getRanges(contraction), ivs, tileSizes);
@@ -141,15 +141,15 @@ makeTiledViews(linalg::TensorContractionBase<ConcreteOp> &contraction,
   return res;
 }
 
-template <class ConcreteOp>
+template <typename ConcreteOp>
 static SmallVector<mlir::AffineForOp, 8>
-writeContractionAsTiledViews(TensorContractionBase<ConcreteOp> &contraction,
+writeContractionAsTiledViews(ConcreteOp &contraction,
                              ArrayRef<Value *> tileSizes) {
   assert(tileSizes.size() <=
          contraction.getNumParallelDims() + contraction.getNumReductionDims());
 
-  auto *op = static_cast<ConcreteOp *>(&contraction);
-  ScopedContext scope(mlir::FuncBuilder(op->getOperation()), op->getLoc());
+  ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()),
+                      contraction.getLoc());
   SmallVector<IndexHandle, 4> ivs(tileSizes.size());
   auto pivs = IndexHandle::makeIndexHandlePointers(ivs);