Linalg portion of the tutorial - part 3
authorNicolas Vasilache <ntv@google.com>
Tue, 2 Apr 2019 21:35:09 +0000 (14:35 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 2 Apr 2019 22:16:14 +0000 (15:16 -0700)
    This CL starts the third part of the Linalg tutorial by adding support for ops to declare how they lower themselves to other ops.
    Tests are added that demonstrate matmul lowering to a loop over matvec and matvec lowering to a loop over dot.

    This is part of a list of CLs that add new Transforms and Analyses to Linalg3: it iseasier to integrate in small chunks.

    As part of working with the TensorContractionBase template class and in an effort to add pieces incrementally without copying code, it is easiest to define operations ahead of time in Linalg2/TensorOps.h and gradually implement them as needed. This CL performs the necessary refactoring for this to happen.

--

PiperOrigin-RevId: 241605869

16 files changed:
mlir/include/mlir/EDSC/Intrinsics.h
mlir/tutorial/Linalg1/include/linalg1/Analysis.h
mlir/tutorial/Linalg1/include/linalg1/Utils.h [new file with mode: 0644]
mlir/tutorial/Linalg1/lib/Analysis.cpp
mlir/tutorial/Linalg1/lib/SliceOp.cpp
mlir/tutorial/Linalg1/lib/Utils.cpp [new file with mode: 0644]
mlir/tutorial/Linalg2/include/linalg2/TensorOps-inl.h [new file with mode: 0644]
mlir/tutorial/Linalg2/include/linalg2/TensorOps.h
mlir/tutorial/Linalg2/lib/TensorOps.cpp
mlir/tutorial/Linalg3/Example.cpp [new file with mode: 0644]
mlir/tutorial/Linalg3/include/linalg3/Ops.h [new file with mode: 0644]
mlir/tutorial/Linalg3/include/linalg3/TensorOps-inl.h [new file with mode: 0644]
mlir/tutorial/Linalg3/include/linalg3/TensorOps.h [new file with mode: 0644]
mlir/tutorial/Linalg3/include/linalg3/Transforms.h [new file with mode: 0644]
mlir/tutorial/Linalg3/lib/TensorOps.cpp [new file with mode: 0644]
mlir/tutorial/Linalg3/lib/Transforms.cpp [new file with mode: 0644]

index 3858dff..05d6943 100644 (file)
@@ -166,6 +166,7 @@ using constant_float = ValueBuilder<ConstantFloatOp>;
 using constant_index = ValueBuilder<ConstantIndexOp>;
 using constant_int = ValueBuilder<ConstantIntOp>;
 using dealloc = OperationBuilder<DeallocOp>;
+using dim = ValueBuilder<DimOp>;
 using load = ValueBuilder<LoadOp>;
 using ret = OperationBuilder<ReturnOp>;
 using select = ValueBuilder<SelectOp>;
index 4c2f4ba..ef8fb98 100644 (file)
@@ -44,12 +44,6 @@ mlir::Value *getViewSupportingMemRef(mlir::Value *view);
 std::pair<mlir::Value *, unsigned> getViewRootIndexing(mlir::Value *view,
                                                        unsigned dim);
 
-////////////////////////////////////////////////////////////////////////////////
-/// Helper functions to avoid dispatching at all client sites.
-////////////////////////////////////////////////////////////////////////////////
-/// Asserts `view` is of ViewType and returns its rank.
-unsigned getViewRank(mlir::Value *view);
-
 } // namespace linalg
 
 #endif // LINALG1_ANALYSIS_H_
diff --git a/mlir/tutorial/Linalg1/include/linalg1/Utils.h b/mlir/tutorial/Linalg1/include/linalg1/Utils.h
new file mode 100644 (file)
index 0000000..cb6b285
--- /dev/null
@@ -0,0 +1,32 @@
+//===- Utils.h - Linalg dialect utility functions definitions -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG1_UTILS_H_
+#define LINALG1_UTILS_H_
+
+namespace mlir {
+class Value;
+} // namespace mlir
+
+namespace linalg {
+
+/// Asserts `view` is of ViewType and returns its rank.
+unsigned getViewRank(mlir::Value *view);
+
+} // namespace linalg
+
+#endif // LINALG1_UTILS_H_
index d7ed249..7a11a85 100644 (file)
@@ -73,13 +73,3 @@ std::pair<mlir::Value *, unsigned> linalg::getViewRootIndexing(Value *view,
   unsigned parentDim = dim > sliceDim ? dim + 1 : dim;
   return getViewRootIndexing(parentView, parentDim);
 }
-
-////////////////////////////////////////////////////////////////////////////////
-/// Helper functions to avoid dispatch at all client sites.
-////////////////////////////////////////////////////////////////////////////////
-unsigned linalg::getViewRank(Value *view) {
-  assert(view->getType().isa<ViewType>() && "expected a ViewType");
-  if (auto viewOp = view->getDefiningOp()->dyn_cast<ViewOp>())
-    return viewOp.getRank();
-  return view->getDefiningOp()->dyn_cast<SliceOp>().getRank();
-}
index c2ea98a..4383743 100644 (file)
@@ -23,6 +23,7 @@
 #include "linalg1/Analysis.h"
 #include "linalg1/Ops.h"
 #include "linalg1/Types.h"
+#include "linalg1/Utils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/tutorial/Linalg1/lib/Utils.cpp b/mlir/tutorial/Linalg1/lib/Utils.cpp
new file mode 100644 (file)
index 0000000..46be325
--- /dev/null
@@ -0,0 +1,34 @@
+//===- Utils.cpp - Implementation of utiliy functions for Linalg ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements utility functions for the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg1/Utils.h"
+#include "linalg1/Ops.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace linalg;
+
+unsigned linalg::getViewRank(Value *view) {
+  assert(view->getType().isa<ViewType>() && "expected a ViewType");
+  if (auto viewOp = view->getDefiningOp()->dyn_cast<ViewOp>())
+    return viewOp.getRank();
+  return view->getDefiningOp()->dyn_cast<SliceOp>().getRank();
+}
diff --git a/mlir/tutorial/Linalg2/include/linalg2/TensorOps-inl.h b/mlir/tutorial/Linalg2/include/linalg2/TensorOps-inl.h
new file mode 100644 (file)
index 0000000..1a08e9e
--- /dev/null
@@ -0,0 +1,108 @@
+//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#ifndef LINALG2_TENSOROPS_INL_H_
+#define LINALG2_TENSOROPS_INL_H_
+
+#include "linalg2/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+namespace linalg {
+
+template <class ConcreteOp>
+mlir::Operation::operand_range
+linalg::TensorContractionBase<ConcreteOp>::getInputs() {
+  auto *op = static_cast<ConcreteOp *>(this)->getOperation();
+  return {op->operand_begin(), op->operand_begin() + getNumInputs()};
+}
+
+template <class ConcreteOp>
+mlir::Operation::operand_range
+linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
+  auto *op = static_cast<ConcreteOp *>(this)->getOperation();
+  return {op->operand_begin() + getNumInputs(),
+          op->operand_begin() + getNumInputs() + getNumOutputs()};
+}
+
+template <class ConcreteOp>
+mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
+  auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
+  if (getNumInputs() <= 0)
+    concreteOp->emitOpError("expected at least one input");
+  if (getNumOutputs() <= 0)
+    concreteOp->emitOpError("expected at least one output");
+  if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
+    concreteOp->emitOpError("expected " +
+                            llvm::Twine(getNumInputs() + getNumOutputs()) +
+                            " operands");
+  }
+  for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
+    if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
+      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
+                                     " not a ViewType");
+  }
+  for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
+       ++i) {
+    auto viewType =
+        concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
+    if (!viewType)
+      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
+                                     " not a ViewType");
+    if (viewType.getRank() != getNumParallelDims())
+      return concreteOp->emitOpError("operand " + llvm::Twine(i) +
+                                     " must be of rank " +
+                                     llvm::Twine(getNumParallelDims()));
+  }
+  return mlir::success();
+}
+
+template <class ConcreteOp>
+bool linalg::TensorContractionBase<ConcreteOp>::parse(
+    mlir::OpAsmParser *parser, mlir::OperationState *result) {
+  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
+}
+
+// A TensorContraction prints as:
+//
+// ```{.mlir}
+//   concrete_op_name {%0, %1} -> {%2}
+// ```
+//
+// Where %0, %1 is an ssa-value holding a View, %2 is an ssa-value holding a
+// view.
+template <class ConcreteOp>
+void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) {
+  *p << static_cast<ConcreteOp *>(this)->getOperationName() << " {";
+  auto *lastInput = *std::prev(getInputs().end());
+  for (auto *i : getInputs()) {
+    *p << *i << ((i == lastInput) ? "} -> {" : ", ");
+  }
+  auto *lastOutput = *std::prev(getOutputs().end());
+  for (auto *o : getOutputs()) {
+    *p << *o << ((o == lastOutput) ? "}" : ",");
+  }
+}
+
+} // namespace linalg
+
+#endif // LINALG2_TENSOROPS_INL_H_
index cefd81a..c20a916 100644 (file)
 // limitations under the License.
 // =============================================================================
 
-#ifndef LINALG2_MATMULOP_H_
-#define LINALG2_MATMULOP_H_
+#ifndef LINALG2_TENSOROPS_H_
+#define LINALG2_TENSOROPS_H_
 
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
 
+namespace mlir {
+class AffineForOp;
+} // namespace mlir
+
 namespace linalg {
 
 /// A generic TensorContraction base class which captures the generic behavior
@@ -41,13 +45,6 @@ protected:
   // Op-specific functionality.
   //////////////////////////////////////////////////////////////////////////////
   TensorContractionBase() = default;
-
-  mlir::Type getInputElementType(unsigned i);
-  mlir::Type getOutputElementType(unsigned i);
-  mlir::Value *getInputView(unsigned i);
-  mlir::Value *getOutputView(unsigned i);
-  mlir::Value *getInputMemRef(unsigned i);
-  mlir::Value *getOutputMemRef(unsigned i);
   mlir::Operation::operand_range getInputs();
   mlir::Operation::operand_range getOutputs();
 
@@ -69,6 +66,20 @@ public:
   unsigned getNumReductionDims() {
     return static_cast<ConcreteOp *>(this)->numReductionDims;
   };
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Used in Linalg3 and later.
+  //////////////////////////////////////////////////////////////////////////////
+  mlir::Value *getInputView(unsigned i);
+  mlir::Value *getOutputView(unsigned i);
+  /// Computes a mapping from all the ranges of the operands to the enclosing
+  /// loops. In order to support "broadcast"-style semantics, we need to
+  /// consider all the operands (i.e. input operands are not sufficient).
+  /// The operands and their ranges are in the order defined by the particular
+  /// ConcreteOp implementation, the resulting map must match those.
+  /// This is currently computed but can also be specified explicitly in each
+  /// operator to generalize to cases where an analysis is not available.
+  mlir::AffineMap operandRangesToLoopsMap();
 };
 
 /// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
@@ -101,6 +112,13 @@ public:
   static constexpr unsigned numOutputs = 1;
   static constexpr unsigned numParallelDims = 0;
   static constexpr unsigned numReductionDims = 1;
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Used in Linalg3 and later.
+  //////////////////////////////////////////////////////////////////////////////
+  /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+  /// loop over matvec). Does nothing by default.
+  void writeAsFinerGrainTensorContraction();
 };
 
 /// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors.
@@ -133,6 +151,13 @@ public:
   static constexpr unsigned numOutputs = 1;
   static constexpr unsigned numParallelDims = 1;
   static constexpr unsigned numReductionDims = 1;
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Used in Linalg3 and later.
+  //////////////////////////////////////////////////////////////////////////////
+  /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+  /// loop over matvec). Does nothing by default.
+  void writeAsFinerGrainTensorContraction();
 };
 
 /// Implements C = A * B on 2-D matrices.
@@ -165,7 +190,20 @@ public:
   static constexpr unsigned numOutputs = 1;
   static constexpr unsigned numParallelDims = 2;
   static constexpr unsigned numReductionDims = 1;
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Used in Linalg3 and later.
+  //////////////////////////////////////////////////////////////////////////////
+  /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
+  /// loop over matvec). Does nothing by default.
+  void writeAsFinerGrainTensorContraction();
 };
+
 } // namespace linalg
 
-#endif // LINALG2_MATMULOP_H_
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#include "linalg2/TensorOps-inl.h"
+
+#endif // LINALG2_TENSOROPS_H_
index 8223629..8a47e5d 100644 (file)
@@ -20,7 +20,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "linalg2/Analysis.h"
+#include "linalg1/Utils.h"
 #include "linalg2/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpDefinition.h"
@@ -34,116 +34,6 @@ using llvm::Twine;
 using namespace mlir;
 using namespace linalg;
 
-template <class ConcreteOp>
-Type linalg::TensorContractionBase<ConcreteOp>::getInputElementType(
-    unsigned idx) {
-  return getInputView(idx)
-      ->getType()
-      .template cast<ViewType>()
-      .getElementType();
-}
-
-template <class ConcreteOp>
-Type linalg::TensorContractionBase<ConcreteOp>::getOutputElementType(
-    unsigned idx) {
-  return getOutputView(idx)
-      ->getType()
-      .template cast<ViewType>()
-      .getElementType();
-}
-
-template <class ConcreteOp>
-Value *linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned idx) {
-  return *(getInputs().begin() + idx);
-}
-
-template <class ConcreteOp>
-Value *linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned idx) {
-  return *(getOutputs().begin() + idx);
-}
-
-template <class ConcreteOp>
-Value *linalg::TensorContractionBase<ConcreteOp>::getInputMemRef(unsigned idx) {
-  return getViewSupportingMemRef(*(getInputs().begin() + idx));
-}
-
-template <class ConcreteOp>
-Value *
-linalg::TensorContractionBase<ConcreteOp>::getOutputMemRef(unsigned idx) {
-  return getViewSupportingMemRef(*(getOutputs().begin() + idx));
-}
-
-template <class ConcreteOp>
-mlir::Operation::operand_range
-linalg::TensorContractionBase<ConcreteOp>::getInputs() {
-  auto *op = static_cast<ConcreteOp *>(this)->getOperation();
-  return {op->operand_begin(), op->operand_begin() + getNumInputs()};
-}
-
-template <class ConcreteOp>
-mlir::Operation::operand_range
-linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
-  auto *op = static_cast<ConcreteOp *>(this)->getOperation();
-  return {op->operand_begin() + getNumInputs(),
-          op->operand_begin() + getNumInputs() + getNumOutputs()};
-}
-
-template <class ConcreteOp>
-LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
-  auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
-  if (getNumInputs() <= 0)
-    concreteOp->emitOpError("expected at least one input");
-  if (getNumOutputs() <= 0)
-    concreteOp->emitOpError("expected at least one output");
-  if (concreteOp->getNumOperands() != getNumInputs() + getNumOutputs()) {
-    concreteOp->emitOpError(
-        "expected " + Twine(getNumInputs() + getNumOutputs()) + " operands");
-  }
-  for (unsigned i = 0, e = getNumInputs(); i < e; ++i) {
-    if (!concreteOp->getOperand(i)->getType().template isa<ViewType>())
-      return concreteOp->emitOpError("operand " + Twine(i) + " not a ViewType");
-  }
-  for (unsigned i = getNumInputs(), e = getNumInputs() + getNumOutputs(); i < e;
-       ++i) {
-    auto viewType =
-        concreteOp->getOperand(i)->getType().template dyn_cast<ViewType>();
-    if (!viewType)
-      return concreteOp->emitOpError("operand " + Twine(i) + " not a ViewType");
-    if (viewType.getRank() != getNumParallelDims())
-      return concreteOp->emitOpError("operand " + Twine(i) +
-                                     " must be of rank " +
-                                     Twine(getNumParallelDims()));
-  }
-  return success();
-}
-
-template <class ConcreteOp>
-bool linalg::TensorContractionBase<ConcreteOp>::parse(OpAsmParser *parser,
-                                                      OperationState *result) {
-  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
-}
-
-// A TensorContraction prints as:
-//
-// ```{.mlir}
-//   concrete_op_name {%0, %1} -> {%2}
-// ```
-//
-// Where %0, %1 is an ssa-value holding a View, %2 is an ssa-value holding a
-// view.
-template <class ConcreteOp>
-void linalg::TensorContractionBase<ConcreteOp>::print(OpAsmPrinter *p) {
-  *p << static_cast<ConcreteOp *>(this)->getOperationName() << " {";
-  auto *lastInput = *std::prev(getInputs().end());
-  for (auto *i : getInputs()) {
-    *p << *i << ((i == lastInput) ? "} -> {" : ", ");
-  }
-  auto *lastOutput = *std::prev(getOutputs().end());
-  for (auto *o : getOutputs()) {
-    *p << *o << ((o == lastOutput) ? "}" : ",");
-  }
-}
-
 //////////////////////////////////////////////////////////////////////////////
 // Op-specific Dot.
 //////////////////////////////////////////////////////////////////////////////
diff --git a/mlir/tutorial/Linalg3/Example.cpp b/mlir/tutorial/Linalg3/Example.cpp
new file mode 100644 (file)
index 0000000..1c10fd5
--- /dev/null
@@ -0,0 +1,102 @@
+//===- Example.cpp - Our running example ----------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+// RUN: %p/test | FileCheck %s
+
+#include "TestHarness.h"
+#include "linalg1/Common.h"
+#include "linalg2/Intrinsics.h"
+#include "linalg3/Ops.h"
+#include "linalg3/Transforms.h"
+#include "mlir/IR/OpImplementation.h"
+
+using llvm::StringRef;
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace linalg;
+using namespace linalg::common;
+using namespace linalg::intrinsics;
+
+Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+  MLIRContext *context = module.getContext();
+  auto dynamic2DMemRefType = floatMemRefType<2>(context);
+  mlir::Function *f = linalg::common::makeFunction(
+      module, name,
+      {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {});
+
+  ScopedContext scope(f);
+  // clang-format off
+  ValueHandle
+    M = dim(f->getArgument(0), 0),
+    N = dim(f->getArgument(2), 1),
+    K = dim(f->getArgument(0), 1),
+    rM = range(constant_index(0), M, constant_index(1)),
+    rN = range(constant_index(0), N, constant_index(1)),
+    rK = range(constant_index(0), K, constant_index(1)),
+    vA = view(f->getArgument(0), {rM, rK}),
+    vB = view(f->getArgument(1), {rK, rN}),
+    vC = view(f->getArgument(2), {rM, rN});
+  matmul(vA, vB, vC);
+  ret();
+  // clang-format on
+
+  return f;
+}
+
+TEST_FUNC(matmul_as_matvec) {
+  MLIRContext context;
+  Module module(&context);
+  mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
+  lowerToFinerGrainedTensorContraction(f);
+  // clang-format off
+  // CHECK-LABEL: func @matmul_as_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  //       CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
+  //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
+  //  CHECK-NEXT:   %[[vB:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+  //  CHECK-NEXT:   %[[vC:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+  //  CHECK-NEXT:   linalg.matvec {%{{.*}}, %[[vB]]} -> {%[[vC]]}
+  // clang-format on
+  cleanupAndPrintFunction(f);
+}
+
+TEST_FUNC(matmul_as_dot) {
+  MLIRContext context;
+  Module module(&context);
+  mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
+  lowerToFinerGrainedTensorContraction(f);
+  lowerToFinerGrainedTensorContraction(f);
+  // clang-format off
+  // CHECK-LABEL: func @matmul_as_dot(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  //       CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
+  //       CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
+  //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
+  //  CHECK-NEXT:   %[[vB:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+  //  CHECK-NEXT:   %[[sC:.*]]  = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+  //  CHECK-NEXT:   affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
+  //  CHECK-NEXT:     %[[vA:.*]] = linalg.slice {{.*}}[%i1, *] { dim : 0 } : !linalg<"view<f32>">
+  //  CHECK-NEXT:     %[[vC:.*]] = linalg.slice %[[sC]][%i1] { dim : 0 } : !linalg<"view<0xf32>">
+  //  CHECK-NEXT:     linalg.dot {%[[vA]], %[[vB]]} -> {%[[vC]]}
+  // clang-format on
+  cleanupAndPrintFunction(f);
+}
+
+int main() {
+  RUN_TESTS();
+  return 0;
+}
diff --git a/mlir/tutorial/Linalg3/include/linalg3/Ops.h b/mlir/tutorial/Linalg3/include/linalg3/Ops.h
new file mode 100644 (file)
index 0000000..f2d5ec4
--- /dev/null
@@ -0,0 +1,24 @@
+//===- Ops.h - Linalg Ops single entry point ------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_OPS_H_
+#define LINALG3_OPS_H_
+
+#include "linalg2/Ops.h"
+#include "linalg3/TensorOps.h"
+
+#endif // LINALG3_OPS_H_
diff --git a/mlir/tutorial/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/tutorial/Linalg3/include/linalg3/TensorOps-inl.h
new file mode 100644 (file)
index 0000000..c4082d5
--- /dev/null
@@ -0,0 +1,43 @@
+//===- TensorOps-inl.h - Linalg dialect TensorOps operation implementation ===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#ifndef LINALG3_TENSOROPS_INL_H_
+#define LINALG3_TENSOROPS_INL_H_
+
+#include "linalg1/Common.h"
+#include "linalg2/TensorOps.h"
+
+namespace linalg {
+
+template <class ConcreteOp>
+mlir::Value *
+linalg::TensorContractionBase<ConcreteOp>::getInputView(unsigned i) {
+  return *(getInputs().begin() + i);
+}
+
+template <class ConcreteOp>
+mlir::Value *
+linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned i) {
+  return *(getOutputs().begin() + i);
+}
+
+} // namespace linalg
+
+#endif // LINALG3_TENSOROPS-INL_H_
diff --git a/mlir/tutorial/Linalg3/include/linalg3/TensorOps.h b/mlir/tutorial/Linalg3/include/linalg3/TensorOps.h
new file mode 100644 (file)
index 0000000..3dffd6e
--- /dev/null
@@ -0,0 +1,28 @@
+//===- TensorOps.h - Linalg dialect TensorOps operation definition --------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_TENSOROPS_H_
+#define LINALG3_TENSOROPS_H_
+
+#include "linalg2/TensorOps.h"
+
+/// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
+/// TensorOps by adding implementations as they are needed in the appropriate
+/// step in the tutorial.
+#include "linalg3/TensorOps-inl.h"
+
+#endif // LINALG3_TENSOROPS_H_
diff --git a/mlir/tutorial/Linalg3/include/linalg3/Transforms.h b/mlir/tutorial/Linalg3/include/linalg3/Transforms.h
new file mode 100644 (file)
index 0000000..b5e11dd
--- /dev/null
@@ -0,0 +1,39 @@
+//===- Transforms.h - Linalg dialect Transformations definition -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_TRANSFORMS_H_
+#define LINALG3_TRANSFORMS_H_
+
+#include "linalg2/Transforms.h"
+
+namespace mlir {
+class Function;
+} // namespace mlir
+
+namespace linalg {
+
+/// Traverses `f` and rewrites linalg.slice, and the operations it depends on,
+/// to only use linalg.view operations.
+void composeSliceOps(mlir::Function *f);
+
+/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec, linalg.dot)
+/// as linalg.matvec (resp. linalg.dot, loop form).
+void lowerToFinerGrainedTensorContraction(mlir::Function *f);
+
+} // namespace linalg
+
+#endif // LINALG3_TRANSFORMS_H_
diff --git a/mlir/tutorial/Linalg3/lib/TensorOps.cpp b/mlir/tutorial/Linalg3/lib/TensorOps.cpp
new file mode 100644 (file)
index 0000000..a04d772
--- /dev/null
@@ -0,0 +1,70 @@
+//===- TensorOps.cpp - Implementation of the linalg TensorOps operation ---===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple IR operation to create new tensor computation
+// operations in the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg1/Analysis.h"
+#include "linalg1/Common.h"
+#include "linalg2/Intrinsics.h"
+#include "linalg3/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace linalg;
+using namespace linalg::intrinsics;
+
+// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
+// The body expression for dot is: C() = A(r_i) * B(r_i);
+// So we must drop the `i` loop from the matvec.
+void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
+  auto *op = getOperation();
+  ScopedContext scope(FuncBuilder(op), op->getLoc());
+  IndexHandle i;
+  auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
+  auto indexingPosPair = getViewRootIndexing(vA, 0);
+  assert(indexingPosPair.first->getDefiningOp() &&
+         indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
+  linalg::common::LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
+      dot(slice(vA, i, 0), vB, slice(vC, i, 0)),
+  });
+}
+
+// The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j)
+// The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
+// So we must drop the `j` loop from the matmul.
+// This is fine because parallel dimensions permute: we can just do it
+// declaratively.
+void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
+  auto *op = getOperation();
+  ScopedContext scope(FuncBuilder(op), op->getLoc());
+  IndexHandle j;
+  auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
+  auto indexingPosPair = getViewRootIndexing(vB, 1);
+  assert(indexingPosPair.first->getDefiningOp() &&
+         indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
+  linalg::common::LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
+      matvec(vA, slice(vB, j, 1), slice(vC, j, 1)),
+  });
+}
diff --git a/mlir/tutorial/Linalg3/lib/Transforms.cpp b/mlir/tutorial/Linalg3/lib/Transforms.cpp
new file mode 100644 (file)
index 0000000..aa9fbd0
--- /dev/null
@@ -0,0 +1,55 @@
+//===- Transforms.cpp - Implementation of the linalg Transformations ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements analyses and transformations for the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg3/Transforms.h"
+#include "linalg2/Intrinsics.h"
+#include "linalg3/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace linalg;
+using namespace linalg::intrinsics;
+
+void linalg::composeSliceOps(mlir::Function *f) {
+  f->walkPostOrder<SliceOp>([](SliceOp sliceOp) {
+    auto *sliceResult = sliceOp.getResult();
+    auto viewOp = createFullyComposedView(sliceResult);
+    sliceResult->replaceAllUsesWith(viewOp.getResult());
+    sliceOp.erase();
+  });
+}
+
+void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
+  f->walkPostOrder([](Operation *op) {
+    if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
+      matmulOp.writeAsFinerGrainTensorContraction();
+    } else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
+      matvecOp.writeAsFinerGrainTensorContraction();
+    } else {
+      return;
+    }
+    op->erase();
+  });
+}