[Linalg] Add a primitive tiling pass
authorNicolas Vasilache <ntv@google.com>
Wed, 1 May 2019 13:47:32 +0000 (06:47 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:23:43 +0000 (08:23 -0700)
    This CL adds a primitive tiling pass for Linalg.
    The tiling pass uses the loopToOperandRangesMaps property which should be ideally Tablegen'd and in-class.

    The tiling specification uses 0 as a convention to skip loops that should not be tiled.

    Tiling proceeds in 3 steps, for each op:
    1. Pad tile sizes with 0 to match the number of loops, this simplifies the implementation and avoids affine map manipulations to align dimensions.
    2. Create loop ranges that represent the min/max/step by which to iterate. This should be later complemented by a range intersection to avoid the out-of-bounds case.
    3. Map the loop ranges to view ranges in order to create subviews on which the op can be called.

    Relevant utility and helper functions are added separately that support writing the transformation in a declarative fashion.
    Simplifying assumptions are made for now on the views and the ranges that are constructed
    in the function and are not passed as function arguments. This restriction will be lifted
    in the future.

--

PiperOrigin-RevId: 246124419

17 files changed:
mlir/include/mlir/IR/AffineMap.h
mlir/include/mlir/Linalg/CMakeLists.txt
mlir/include/mlir/Linalg/IR/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/Linalg/IR/LinalgOps.h [moved from mlir/include/mlir/Linalg/LinalgOps.h with 92% similarity]
mlir/include/mlir/Linalg/IR/LinalgOps.td [moved from mlir/include/mlir/Linalg/LinalgOps.td with 86% similarity]
mlir/include/mlir/Linalg/IR/LinalgTraits.h [moved from mlir/include/mlir/Linalg/LinalgTraits.h with 99% similarity]
mlir/include/mlir/Linalg/IR/LinalgTypes.h [moved from mlir/include/mlir/Linalg/LinalgTypes.h with 100% similarity]
mlir/include/mlir/Linalg/Passes.h [new file with mode: 0644]
mlir/include/mlir/Linalg/Utils/Utils.h [new file with mode: 0644]
mlir/lib/IR/AffineMap.cpp
mlir/lib/Linalg/CMakeLists.txt
mlir/lib/Linalg/IR/LinalgOps.cpp [moved from mlir/lib/Linalg/LinalgOps.cpp with 88% similarity]
mlir/lib/Linalg/IR/LinalgTypes.cpp [moved from mlir/lib/Linalg/LinalgTypes.cpp with 97% similarity]
mlir/lib/Linalg/LinalgRegistration.cpp
mlir/lib/Linalg/Transforms/Tiling.cpp [new file with mode: 0644]
mlir/lib/Linalg/Utils/Utils.cpp [new file with mode: 0644]
mlir/test/Linalg/tile.mlir [new file with mode: 0644]

index 41aefba..1df6ac7 100644 (file)
@@ -154,6 +154,52 @@ inline ::llvm::hash_code hash_value(AffineMap arg) {
 /// sizes.
 AffineMap simplifyAffineMap(AffineMap map);
 
+/// Returns a map of codomain to domain dimensions such that the first codomain
+/// dimension for a particular domain dimension is selected.
+///
+/// Prerequisites:
+///   1. `map` is a permutation of full rank.
+///   2. `map` has no symbols.
+///   3. `map` has empty `rangeSizes`.
+///
+/// Example:
+///
+/// ```{.mlir}
+///    (d0, d1, d2) -> (d1, d1, d0, d2, d1, d2, d1, d0)
+///                      0       2   3
+/// ```
+///
+/// returns:
+///
+/// ```{.mlir}
+///    (d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d0, d3)
+/// ```
+AffineMap inversePermutation(AffineMap map);
+
+/// Concatenates a list of `maps` into a single AffineMap, stepping over
+/// potentially empty maps. Assumes each of the underlying map has 0 symbols and
+/// empty `rangeSizes`.
+/// The resulting map has a number of dims equal to the max of `maps`' dims and
+/// the concatenated results as its results.
+///
+/// Example:
+/// When applied to the following list of 3 affine maps,
+///
+/// ```{.mlir}
+///    {
+///      (i, j, k) -> (i, k),
+///      (i, j, k) -> (k, j),
+///      (i, j, k) -> (i, j)
+///    }
+/// ```
+///
+/// Returns the map:
+///
+/// ```{.mlir}
+///     (i, j, k) -> (i, k, k, j, i, j)
+/// ```
+AffineMap concatAffineMaps(llvm::ArrayRef<AffineMap> maps);
+
 } // end namespace mlir
 
 namespace llvm {
index d3ed75c..f33061b 100644 (file)
@@ -1,4 +1 @@
-set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
-mlir_tablegen(LinalgOps.h.inc -gen-op-decls)
-mlir_tablegen(LinalgOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRLinalgOpsIncGen)
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Linalg/IR/CMakeLists.txt
new file mode 100644 (file)
index 0000000..d3ed75c
--- /dev/null
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
+mlir_tablegen(LinalgOps.h.inc -gen-op-decls)
+mlir_tablegen(LinalgOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRLinalgOpsIncGen)
similarity index 92%
rename from mlir/include/mlir/Linalg/LinalgOps.h
rename to mlir/include/mlir/Linalg/IR/LinalgOps.h
index 9406feb..5503b65 100644 (file)
@@ -19,8 +19,8 @@
 #define MLIR_LINALG_LINALGOPS_H_
 
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/Linalg/LinalgTraits.h"
-#include "mlir/Linalg/LinalgTypes.h"
+#include "mlir/Linalg/IR/LinalgTraits.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
@@ -221,7 +221,25 @@ public:
 };
 
 #define GET_OP_CLASSES
-#include "mlir/Linalg/LinalgOps.h.inc"
+#include "mlir/Linalg/IR/LinalgOps.h.inc"
+
+/// Returns the list of maps that map loops to operands of a Linalg op.
+/// The i-th affine map identifies loop indices to subscripts that are used when
+/// accessing the i-th operand.
+/// For instance, a matmul that can be written in index notation as:
+/// `A(i, k) * B(k, j) -> C(i, j)` will have the following, ordered, list of
+/// affine maps:
+///
+/// ```{.mlir}
+///    (
+///      (i, j, k) -> (i, k),
+///      (i, j, k) -> (k, j),
+///      (i, j, k) -> (i, j)
+///    )
+/// ```
+///
+/// Only permutation maps are currently supported. 
+SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
 
 } // namespace mlir
 
similarity index 86%
rename from mlir/include/mlir/Linalg/LinalgOps.td
rename to mlir/include/mlir/Linalg/IR/LinalgOps.td
index d6673f2..958d879 100644 (file)
@@ -36,7 +36,8 @@ def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
 def View : Type<LinalgIsViewTypePred, "view">;
 
 class ParametricNativeOpTrait<string prop, string parameters> :
-  NativeOpTrait<prop # parameters>;
+  NativeOpTrait<prop # parameters>
+{}
 
 class ParametricIntNativeOpTrait<string prop, list<int> parameters> :
   ParametricNativeOpTrait<
@@ -48,32 +49,35 @@ class ParametricIntNativeOpTrait<string prop, list<int> parameters> :
                       sum,
                       param,
                       sum # "," # !cast<string>(param)),
-               ">::Impl")>;
+               ">::Impl")>
+{}
 
 // The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
 // to have a specified number of inputs and outputs, all passed as operands.
 // See Linalg/LinalgTraits.h for implementation details an usage.
 class NInputsAndOutputs<int n_ins, int n_outs> :
-  ParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>;
+  ParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
+{}
 
 // The linalg `NLoopTypes` trait provides the API for ops that are known to have
 // a specified number of parallel (n_par), reduction (n_red) and window (n_win)
 // loops.
 // See Linalg/LinalgTraits.h for implementation details an usage.
 class NLoopTypes<int n_par, int n_red, int n_win> :
-  ParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>;
+ParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
+{}
 
 // The linalg `ViewRanks` trait the API for ops that are known to have a
 // specified list of view ranks.
 // See Linalg/LinalgTraits.h for implementation details an usage.
 class ViewRanks<list<int> ranks> :
-  ParametricIntNativeOpTrait<"ViewRanks", ranks>;
+ParametricIntNativeOpTrait<"ViewRanks", ranks>
+{}
 
 // Base Tablegen class for Linalg ops.
 class LinalgOp<string mnemonic, list<OpTrait> props> :
-    Op<Linalg_Dialect, mnemonic, props> {
-  // The default variadic builder.
-  let arguments = (ins Variadic<View>);
+Op<Linalg_Dialect, mnemonic, props> {
+  let arguments = (ins Variadic<View>); // default variadic builder
 
   let parser = [{ return impl::parseLinalgLibraryOp(parser, result); }];
 
@@ -85,12 +89,12 @@ class LinalgOp<string mnemonic, list<OpTrait> props> :
 ////////////////////////////////////////////////////////////////////////////////
 def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>,
                              NLoopTypes<0, 1, 0>,
-                             ViewRanks<[1, 1, 0]>]>;
+                             ViewRanks<[1, 1, 0]>]> {}
 def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>,
                                    NLoopTypes<1, 1, 0>,
-                                   ViewRanks<[2, 1, 1]>]>;
+                                   ViewRanks<[2, 1, 1]>]> {}
 def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>,
                                    NLoopTypes<2, 1, 0>,
-                                   ViewRanks<[2, 2, 2]>]>;
+                                   ViewRanks<[2, 2, 2]>]> {}
 
-#endif // LINALG_OPS
+#endif // LINALG_OPS
\ No newline at end of file
similarity index 99%
rename from mlir/include/mlir/Linalg/LinalgTraits.h
rename to mlir/include/mlir/Linalg/IR/LinalgTraits.h
index 94620a6..4a7428b 100644 (file)
@@ -19,7 +19,7 @@
 #define MLIR_LINALG_LINALGTRAITS_H_
 
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/Linalg/LinalgTypes.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Linalg/Passes.h b/mlir/include/mlir/Linalg/Passes.h
new file mode 100644 (file)
index 0000000..7ccb788
--- /dev/null
@@ -0,0 +1,35 @@
+//===- Passes.h - Linalg pass entry points ----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LINALG_PASSES_H_
+#define MLIR_LINALG_PASSES_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+class ModulePassBase;
+
+mlir::ModulePassBase *
+createLinalgTilingPass(llvm::ArrayRef<int64_t> tileSizes = {});
+} // namespace mlir
+
+#endif // MLIR_LINALG_PASSES_H_
diff --git a/mlir/include/mlir/Linalg/Utils/Utils.h b/mlir/include/mlir/Linalg/Utils/Utils.h
new file mode 100644 (file)
index 0000000..63bb2b3
--- /dev/null
@@ -0,0 +1,82 @@
+//===- Utils.h - Utilities to support the Linalg dialect --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_LINALG_UTILS_H_
+#define MLIR_LINALG_UTILS_H_
+
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+namespace edsc {
+/// Helper class to sugar building loop nests from ranges.
+/// This is similar to edsc::LoopNestBuilder except it works on ranges directly.
+/// In the current implementation it produces affine.for operations and thus
+/// only admits ranges with constant steps.
+class LoopNestRangeBuilder {
+public:
+  LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
+                       llvm::ArrayRef<edsc::ValueHandle> ranges);
+  LoopNestRangeBuilder(llvm::ArrayRef<edsc::ValueHandle *> ivs,
+                       llvm::ArrayRef<Value *> ranges);
+  edsc::ValueHandle operator()(llvm::ArrayRef<edsc::CapturableHandle> stmts);
+
+private:
+  llvm::SmallVector<edsc::LoopBuilder, 4> loops;
+};
+
+} // namespace edsc
+
+/// Abstracts away the extraction of values of RangeType from the actual op
+/// implementation. For each operand of `op`:
+///   1. If it is of RangeType, appends it to the result.
+///   2. If it is of ViewType, further differentiates between:
+///      a. Views that have a defining op, in which cases it appends the ranges
+///         of the defining op.
+///      b. Views that do not have a defining op, in which case it materializes
+///         new range extraction ops to retrieve the range. This is not yet
+///         implemented and depends on future operations (e.g. extract_range).
+/// Precedence is given to a. over b. because it allows propagating existing
+/// values instead of creating new, duplicate, values.
+// TODO(ntv): Implement range extraction ops.
+SmallVector<Value *, 8> getRanges(Operation *op);
+
+/// Returns a value of ViewType at `b`, `loc` by applying the `ranges` to
+/// `viewDefiningOp`. This creates a new op unless `viewDefiningOp` already has
+/// the same exact `ranges`, in which case its (unique) result is returned.
+Value *createOrReturnView(FuncBuilder *b, Location loc,
+                          Operation *viewDefiningOp,
+                          llvm::ArrayRef<Value *> ranges);
+
+/// Returns the min/max/step from a RangeType value, depending on `part`:
+///   1. If `range` comes from a range defining op, this just returns the proper
+///      operand.
+///   2. Otherwise (e.g. if range is a function parameter), it materializes new
+///      part extraction ops to retrieve the min/max/step. This is not yet
+///      implemented and depends on future operations (e.g. extract_min, ...).
+/// Precedence is given to 1. over 2. because it allows propagating existing
+/// values instead of creating new, duplicate, values.
+/// This is used to abstract away the extraction of the min/max/step from a
+/// RangeType value.
+// TODO(ntv): Implement range extraction ops.
+enum class RangePart { Min = 0, Max, Step };
+Value *extractRangePart(Value *range, RangePart part);
+
+} // namespace mlir
+
+#endif // MLIR_LINALG_UTILS_H_
index 420dfb3..f62a1c9 100644 (file)
@@ -268,3 +268,41 @@ AffineMap mlir::simplifyAffineMap(AffineMap map) {
   }
   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, sizes);
 }
+
+AffineMap mlir::inversePermutation(AffineMap map) {
+  assert(map.getNumSymbols() == 0 && "expected map without symbols");
+  assert(map.getRangeSizes().empty() && "expected map without range sizes");
+  SmallVector<AffineExpr, 4> exprs(map.getNumDims());
+  for (auto en : llvm::enumerate(map.getResults())) {
+    auto expr = en.value();
+    auto d = expr.cast<AffineDimExpr>(); // permutation map expected;
+    if (exprs[d.getPosition()])
+      continue;
+    exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
+  }
+  SmallVector<AffineExpr, 4> seenExprs;
+  seenExprs.reserve(map.getNumDims());
+  for (auto expr : exprs)
+    if (expr)
+      seenExprs.push_back(expr);
+  assert(seenExprs.size() == map.getNumInputs() && "map is not full rank");
+  return AffineMap::get(map.getNumResults(), 0, seenExprs, {});
+}
+
+AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
+  unsigned numResults = 0;
+  for (auto m : maps)
+    numResults += m ? m.getNumResults() : 0;
+  unsigned numDims = 0;
+  llvm::SmallVector<AffineExpr, 8> results;
+  results.reserve(numResults);
+  for (auto m : maps) {
+    if (!m)
+      continue;
+    assert(m.getNumSymbols() == 0 && "expected map without symbols");
+    assert(m.getRangeSizes().empty() && "expected map without range sizes");
+    results.append(m.getResults().begin(), m.getResults().end());
+    numDims = std::max(m.getNumDims(), numDims);
+  }
+  return AffineMap::get(numDims, 0, results, {});
+}
index 50af3cc..e048d50 100644 (file)
@@ -1,7 +1,9 @@
 add_llvm_library(MLIRLinalg
-  LinalgOps.cpp
   LinalgRegistration.cpp
-  LinalgTypes.cpp
+  IR/LinalgOps.cpp
+  IR/LinalgTypes.cpp
+  Transforms/Tiling.cpp
+  Utils/Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Linalg
similarity index 88%
rename from mlir/lib/Linalg/LinalgOps.cpp
rename to mlir/lib/Linalg/IR/LinalgOps.cpp
index 423fed0..84e1d44 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Linalg/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
-#include "mlir/Linalg/LinalgTypes.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
 
@@ -156,20 +156,14 @@ void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base,
     if (!i->getType().isa<RangeType>())
       rank--;
   Type elementType = viewType.getElementType();
-  result->addTypes(
-      {ViewType::get(b->getContext(), elementType, indexings.size())});
+  result->addTypes({ViewType::get(b->getContext(), elementType, rank)});
 }
 
 LogicalResult mlir::SliceOp::verify() {
   if (llvm::empty(getOperands()))
     return emitOpError(
         "requires at least a view operand followed by 'rank' indices");
-  if (!getOperand(0)->getDefiningOp()->isa<ViewOp>())
-    return emitOpError(
-        "requires at least a view operand followed by 'rank' indices");
-
-  auto viewOp = getOperand(0)->getDefiningOp()->dyn_cast<ViewOp>();
-  if (!viewOp)
+  if (!dyn_cast_or_null<ViewOp>(getOperand(0)->getDefiningOp()))
     return emitOpError("first operand must come from a ViewOp");
   unsigned rank = getBaseViewRank();
   if (llvm::size(getIndexings()) != rank) {
@@ -189,8 +183,8 @@ LogicalResult mlir::SliceOp::verify() {
   }
   if (getRank() != rank) {
     return emitOpError("the rank of the view must be the number of its range "
-                       "indices" +
-                       Twine(rank));
+                       "indices (" +
+                       Twine(rank) + ") but got: " + Twine(getRank()));
   }
   return success();
 }
@@ -359,6 +353,14 @@ void mlir::ViewOp::print(OpAsmPrinter *p) {
 
 namespace mlir {
 namespace impl {
+void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op);
+bool parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result);
+} // namespace impl
+
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
+
+} // namespace mlir
 
 // A LinalgLibraryOp prints as:
 //
@@ -374,7 +376,7 @@ namespace impl {
 // ```
 //
 // Where %0, %1 and %2 are ssa-values of type ViewType.
-void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op) {
+void mlir::impl::printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op) {
   assert(op->getAbstractOperation() && "unregistered operation");
   *p << op->getName().getStringRef() << "(";
   interleave(
@@ -386,7 +388,8 @@ void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op) {
       [&](mlir::Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
 }
 
-bool parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result) {
+bool mlir::impl::parseLinalgLibraryOp(OpAsmParser *parser,
+                                      OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 3> ops;
   SmallVector<Type, 3> types;
   return parser->parseOperandList(ops, -1, OpAsmParser::Delimiter::Paren) ||
@@ -395,9 +398,28 @@ bool parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result) {
          parser->resolveOperands(ops, types, parser->getNameLoc(),
                                  result->operands);
 }
-} // namespace impl
 
-#define GET_OP_CLASSES
-#include "mlir/Linalg/LinalgOps.cpp.inc"
-
-} // namespace mlir
+// Ideally this should all be Tablegen'd but there is no good story for
+// AffineMap for now.
+SmallVector<AffineMap, 4> mlir::loopToOperandRangesMaps(Operation *op) {
+  MLIRContext *context = op->getContext();
+  auto i = getAffineDimExpr(0, context);
+  auto j = getAffineDimExpr(1, context);
+  auto k = getAffineDimExpr(2, context);
+  if (op->isa<DotOp>())
+    // A(r_i) * B(r_i) -> C()
+    return SmallVector<AffineMap, 4>{AffineMap::get(1, 0, {i}, {}),
+                                     AffineMap::get(1, 0, {i}, {}),
+                                     AffineMap()};
+  if (op->isa<MatvecOp>())
+    //   A(i, r_j) * B(r_j) -> C(i)
+    return SmallVector<AffineMap, 4>{AffineMap::get(2, 0, {i, j}, {}),
+                                     AffineMap::get(2, 0, {j}, {}),
+                                     AffineMap::get(2, 0, {i}, {})};
+  if (op->isa<MatmulOp>())
+    //   A(i, r_j) * B(r_j) -> C(i)
+    return SmallVector<AffineMap, 4>{AffineMap::get(3, 0, {i, k}, {}),
+                                     AffineMap::get(3, 0, {k, j}, {}),
+                                     AffineMap::get(3, 0, {i, j}, {})};
+  llvm_unreachable("Missing loopToOperandRangesMaps for op");
+}
similarity index 97%
rename from mlir/lib/Linalg/LinalgTypes.cpp
rename to mlir/lib/Linalg/IR/LinalgTypes.cpp
index a507fa8..556d5d1 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Linalg/LinalgTypes.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/StandardTypes.h"
-#include "mlir/Linalg/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
 #include "mlir/Support/LLVM.h"
 
 using namespace mlir;
@@ -33,7 +33,7 @@ mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
   addOperations<BufferAllocOp, BufferDeallocOp, RangeOp, SliceOp, ViewOp>();
   addOperations<
 #define GET_OP_LIST
-#include "mlir/Linalg/LinalgOps.cpp.inc"
+#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
       >();
 }
 
index 3637037..816b565 100644 (file)
@@ -15,8 +15,8 @@
 // limitations under the License.
 // =============================================================================
 
-#include "mlir/Linalg/LinalgOps.h"
-#include "mlir/Linalg/LinalgTypes.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
 
 using namespace mlir;
 
diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp
new file mode 100644 (file)
index 0000000..52f29eb
--- /dev/null
@@ -0,0 +1,367 @@
+//===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the linalg dialect Tiling pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Linalg/Passes.h"
+#include "mlir/Linalg/Utils/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/STLExtras.h"
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace llvm;
+
+static llvm::cl::OptionCategory clOptionsCategory("linalg options");
+static llvm::cl::list<unsigned>
+    clTileSizes("linalg-tile-sizes",
+                llvm::cl::desc("Tile sizes by which to tile linalg operations"),
+                llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
+                llvm::cl::cat(clOptionsCategory));
+
+namespace {
+class PerFunctionState {
+public:
+  PerFunctionState(Function &f) : f(f) {}
+
+  Value *getOrCreate(int64_t v) {
+    auto it = map.find(v);
+    if (it != map.end())
+      return it->second;
+    edsc::ScopedContext s(&f);
+    return map.insert(make_pair(v, edsc::intrinsics::constant_index(v)))
+        .first->getSecond();
+  }
+
+private:
+  Function &f;
+  SmallDenseMap<int64_t, Value *> map;
+};
+} // namespace
+
+// Folding eagerly is necessary to abide by affine.for static step requirement.
+// We must propagate constants on the steps as aggressively as possible.
+// Returns nullptr if folding is not trivially feasible.
+static Value *tryFold(AffineMap map, ArrayRef<Value *> operands,
+                      PerFunctionState &state) {
+  assert(map.getNumResults() == 1 && "single result map expected");
+  auto expr = map.getResult(0);
+  if (auto dim = expr.dyn_cast<AffineDimExpr>())
+    return operands[dim.getPosition()];
+  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
+    return operands[map.getNumDims() + sym.getPosition()];
+  if (auto cst = expr.dyn_cast<AffineConstantExpr>())
+    return state.getOrCreate(cst.getValue());
+  return nullptr;
+}
+
+static Value *emitOrFoldComposedAffineApply(FuncBuilder *b, Location loc,
+                                            AffineMap map,
+                                            ArrayRef<Value *> operandsRef,
+                                            PerFunctionState &state) {
+  SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
+  fullyComposeAffineMapAndOperands(&map, &operands);
+  if (auto *v = tryFold(map, operands, state))
+    return v;
+  return b->create<AffineApplyOp>(loc, map, operands);
+}
+
+static SmallVector<Value *, 4> applyMapToRangePart(FuncBuilder *b, Location loc,
+                                                   AffineMap map,
+                                                   ArrayRef<Value *> ranges,
+                                                   RangePart part,
+                                                   PerFunctionState &state) {
+  SmallVector<Value *, 4> rangeParts(ranges.size());
+  transform(llvm::make_range(ranges.begin(), ranges.end()), rangeParts.begin(),
+            [&](Value *range) { return extractRangePart(range, part); });
+
+  SmallVector<Value *, 4> res;
+  res.reserve(map.getNumResults());
+  unsigned numDims = map.getNumDims();
+  for (auto expr : map.getResults()) {
+    AffineMap map = AffineMap::get(numDims, 0, expr, {});
+    res.push_back(
+        emitOrFoldComposedAffineApply(b, loc, map, rangeParts, state));
+  }
+  return res;
+}
+
+static bool isZero(Value *v) {
+  return v->getDefiningOp() && v->getDefiningOp()->isa<ConstantIndexOp>() &&
+         v->getDefiningOp()->cast<ConstantIndexOp>().getValue() == 0;
+}
+
+/// Returns a map that can be used to filter the zero values out of tileSizes.
+/// For example, if tileSizes contains `{v1, 0, v2}`, the returned map is:
+///
+/// ```{.mlir}
+///    (d0, d1, d2) -> (d0, d2)
+/// ```
+static AffineMap nonZeroMap(ArrayRef<Value *> tileSizes) {
+  SmallVector<AffineExpr, 4> exprs;
+  for (auto en : llvm::enumerate(tileSizes))
+    if (!isZero(en.value()))
+      exprs.push_back(getAffineDimExpr(en.index(), en.value()->getContext()));
+  assert(!exprs.empty() &&
+         "unexpected zero-only tile sizes, should have been handled earlier");
+  return AffineMap::get(tileSizes.size(), 0, exprs, {});
+}
+
+// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
+// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument has
+// one entry per surrounding loop. It uses zero as the convention that a
+// particular loop is not tiled. This convention simplifies implementations by
+// avoiding affine map manipulations.
+// The returned ranges correspond to the loop ranges, in the proper order, that
+// are tiled and for which new loops will be created.
+static SmallVector<Value *, 4>
+makeTiledLoopRanges(FuncBuilder *b, Location loc, AffineMap map,
+                    ArrayRef<Value *> allOpRanges, ArrayRef<Value *> tileSizes,
+                    PerFunctionState &state) {
+  assert(tileSizes.size() == map.getNumResults());
+  // Tile sizes are in loop order by construction, apply `map` to
+  // get mins/maxes/steps in loop order.
+  auto mins =
+      applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Min, state);
+  auto maxes =
+      applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Max, state);
+  auto steps =
+      applyMapToRangePart(b, loc, map, allOpRanges, RangePart::Step, state);
+  SmallVector<Value *, 4> sizes(tileSizes.begin(), tileSizes.end());
+
+  // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
+  for (int idx = mins.size() - 1; idx >= 0; --idx) {
+    if (isZero(tileSizes[idx])) {
+      mins.erase(mins.begin() + idx);
+      maxes.erase(maxes.begin() + idx);
+      steps.erase(steps.begin() + idx);
+      sizes.erase(sizes.begin() + idx);
+    }
+  }
+
+  // Create a new range with the applied tile sizes.
+  SmallVector<Value *, 4> res;
+  for (unsigned idx = 0, e = steps.size(); idx < e; ++idx) {
+    auto *step = steps[idx];
+    auto *tileSize = sizes[idx];
+    // clang-format off
+    // Steps must be constant for now to abide by affine.for semantics.
+    auto *newStep =
+        state.getOrCreate(
+            step->getDefiningOp()->cast<ConstantIndexOp>().getValue() *
+            tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue());
+    res.push_back(b->create<RangeOp>(loc, mins[idx], maxes[idx], newStep));
+    // clang-format on
+  }
+  return res;
+}
+
+static SmallVector<Value *, 4> makeTiledViews(FuncBuilder *b, Location loc,
+                                              Operation *op,
+                                              ArrayRef<Value *> ivs,
+                                              ArrayRef<Value *> tileSizes,
+                                              PerFunctionState &state) {
+  assert(ivs.size() == llvm::count_if(
+                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
+                           [](Value *v) { return !isZero(v); }) &&
+         "expected as many ivs as non-zero sizes");
+  auto *context = op->getContext();
+
+  SmallVector<Value *, 4> res;
+  res.reserve(op->getNumOperands());
+  for (unsigned i = 0, ei = op->getNumOperands(); i < ei; ++i) {
+    auto *viewDefiningOp = op->getOperand(i)->getDefiningOp();
+    assert(viewDefiningOp && "Need operations to extract ranges from views");
+    auto ranges = getRanges(viewDefiningOp);
+    // E.g. for A in A(i, k) * B(k, j) -> C(i, j) returns the map:
+    //   (i, j, k) -> (i, k)
+    auto map = loopToOperandRangesMaps(op)[i];
+    if (!map) {
+      assert(ranges.empty() && "scalar should have empty ranges");
+      res.push_back(op->getOperand(i));
+      continue;
+    }
+    assert(ranges.size() == map.getNumResults());
+    // E.g. for {0, 0, v2} returns the map:
+    //   (i, j, k) -> (k)
+    auto nzMap = nonZeroMap(tileSizes);
+
+    SmallVector<Value *, 4> newRanges;
+    newRanges.reserve(ranges.size());
+    for (unsigned j = 0, ej = ranges.size(); j < ej; ++j) {
+      // Loop position for the range dimension.
+      // E.g. for A in A(i, k) * B(k, j) -> C(i, j) and map: (i, j, k) -> (i, k)
+      //   and for j == 1 (i.e. result `k`)
+      //   returns loopPos = 2 (i.e. `k` on the map domain).
+      auto pos = map.getResult(j).template cast<AffineDimExpr>().getPosition();
+      if (isZero(tileSizes[pos])) {
+        newRanges.push_back(ranges[j]);
+        continue;
+      }
+      auto it = llvm::find_if(nzMap.getResults(), [pos, context](AffineExpr e) {
+        return e == getAffineDimExpr(pos, context);
+      });
+      assert(it != nzMap.getResults().end() &&
+             "position does not correspond to a valid induction variable");
+      unsigned pos2 = it - nzMap.getResults().begin();
+      using edsc::op::operator+;
+      using range = ValueBuilder<RangeOp>;
+      ScopedContext scope(*b, loc);
+      ValueHandle iv(ivs[pos2]), step(tileSizes[pos]);
+      auto min = ValueHandle(extractRangePart(ranges[j], RangePart::Min));
+      // zero case is important enough to fold away by special-casing.
+      auto newMin = isZero(min) ? iv : min + iv;
+      // TODO(ntv): intersect with current range once the operation exists.
+      Value *r = range(newMin, newMin + step, step);
+      newRanges.push_back(r);
+    }
+    res.push_back(createOrReturnView(b, loc, viewDefiningOp, newRanges));
+  }
+  return res;
+}
+
+template <class LinalgOp>
+static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
+                                  PerFunctionState &state) {
+  // Enforce the convention that "tiling by zero" skips tiling a particular
+  // dimension. This convention is significantly simpler to handle instead of
+  // adjusting affine maps to account for missing dimensions.
+  assert(op.getNumParallelLoops() + op.getNumReductionLoops() +
+                 op.getNumWindowLoops() ==
+             tileSizes.size() &&
+         "expected matching number of tile sizes and loops");
+
+  ScopedContext scope(FuncBuilder(op.getOperation()), op.getLoc());
+  auto loopRanges = makeTiledLoopRanges(
+      scope.getBuilder(), scope.getLocation(),
+      // The flattened loopToOperandRangesMaps is expected to be an invertible
+      // permutation map (which is asserted in the inverse calculation).
+      inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op))),
+      getRanges(op.getOperation()), tileSizes, state);
+
+  SmallVector<IndexHandle, 4> ivs(loopRanges.size());
+  auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
+  LoopNestRangeBuilder(pivs, loopRanges)({[&op, &tileSizes, &ivs, &state]() {
+    auto *b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
+    // If/when the assertion below becomes false, we will have to templatize
+    // `makeTiledViews`.
+    assert(op.getNumInputsAndOutputs() == op.getOperation()->getNumOperands());
+    auto views =
+        makeTiledViews(b, loc, op.getOperation(), ivValues, tileSizes, state);
+    b->create<LinalgOp>(loc, views);
+    /// NestedBuilders expect handles, we thus return an IndexHandle.
+    return IndexHandle();
+  }()});
+
+  return success();
+}
+
+template <class LinalgOp>
+static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<int64_t> tileSizes,
+                                  PerFunctionState &state) {
+  if (tileSizes.empty())
+    return failure();
+
+  // The following uses the convention that "tiling by zero" skips tiling a
+  // particular dimension. This convention is significantly simpler to handle
+  // instead of adjusting affine maps to account for missing dimensions.
+  auto nLoops = op.getNumParallelLoops() + op.getNumReductionLoops() +
+                op.getNumWindowLoops();
+  tileSizes = tileSizes.take_front(nLoops);
+  // If only 0 tilings are left, then return.
+  if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0; }))
+    return failure();
+
+  // Materialize concrete tile size values to pass the generic tiling function.
+  SmallVector<Value *, 8> tileSizeValues;
+  tileSizeValues.reserve(tileSizes.size());
+  for (auto ts : tileSizes)
+    tileSizeValues.push_back(state.getOrCreate(ts));
+  // Pad tile sizes with zero values to enforce our convention.
+  if (tileSizeValues.size() < nLoops) {
+    for (unsigned i = tileSizeValues.size(); i < nLoops; ++i)
+      tileSizeValues.push_back(state.getOrCreate(0));
+  }
+
+  return tileLinalgOp(op, tileSizeValues, state);
+}
+
+// TODO(ntv) expose as a primitive for other passes.
+static LogicalResult tileLinalgOp(Operation *op, ArrayRef<int64_t> tileSizes,
+                                  PerFunctionState &state) {
+  if (auto matmulOp = op->dyn_cast<MatmulOp>()) {
+    return tileLinalgOp(matmulOp, tileSizes, state);
+  } else if (auto matvecOp = op->dyn_cast<MatvecOp>()) {
+    return tileLinalgOp(matvecOp, tileSizes, state);
+  } else if (auto dotOp = op->dyn_cast<DotOp>()) {
+    return tileLinalgOp(dotOp, tileSizes, state);
+  }
+  return failure();
+}
+
+static void tileLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) {
+  PerFunctionState state(f);
+  f.walk([tileSizes, &state](Operation *op) {
+    if (succeeded(tileLinalgOp(op, tileSizes, state)))
+      op->erase();
+  });
+}
+
+namespace {
+struct LinalgTilingPass : public ModulePass<LinalgTilingPass> {
+  LinalgTilingPass();
+  LinalgTilingPass(ArrayRef<int64_t> sizes);
+
+  void runOnModule() {
+    for (auto &f : getModule())
+      tileLinalgOps(f, tileSizes);
+  }
+
+  SmallVector<int64_t, 8> tileSizes;
+};
+} // namespace
+
+LinalgTilingPass::LinalgTilingPass()
+    : tileSizes(clTileSizes.begin(), clTileSizes.end()) {}
+
+LinalgTilingPass::LinalgTilingPass(ArrayRef<int64_t> sizes)
+    : LinalgTilingPass() {
+  if (!sizes.empty())
+    this->tileSizes.assign(sizes.begin(), sizes.end());
+}
+
+ModulePassBase *mlir::createLinalgTilingPass(ArrayRef<int64_t> tileSizes) {
+  return new LinalgTilingPass(tileSizes);
+}
+
+static PassRegistration<LinalgTilingPass>
+    pass("linalg-tile", "Tile operations in the linalg dialect");
diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp
new file mode 100644 (file)
index 0000000..0052ef0
--- /dev/null
@@ -0,0 +1,139 @@
+//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements utilities for the Linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Linalg/Utils/Utils.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Linalg/IR/LinalgTypes.h"
+#include "mlir/Linalg/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace llvm;
+
+mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
+    ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> ranges) {
+  for (unsigned i = 0, e = ranges.size(); i < e; ++i) {
+    assert(ranges[i].getType() && "expected !linalg.range type");
+    assert(ranges[i].getValue()->getDefiningOp() &&
+           "need operations to extract range parts");
+    auto rangeOp = ranges[i].getValue()->getDefiningOp()->cast<RangeOp>();
+    auto lb = rangeOp.min();
+    auto ub = rangeOp.max();
+    // This must be a constexpr index until we relax the affine.for constraint
+    auto step =
+        rangeOp.step()->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+    loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step);
+  }
+  assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
+}
+
+mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
+    ArrayRef<ValueHandle *> ivs, ArrayRef<Value *> ranges)
+    : LoopNestRangeBuilder(
+          ivs, SmallVector<ValueHandle, 4>(ranges.begin(), ranges.end())) {}
+
+ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()(
+    ArrayRef<CapturableHandle> stmts) {
+  for (auto &lit : reverse(loops)) {
+    lit({});
+  }
+  return ValueHandle::null();
+}
+
+SmallVector<Value *, 8> mlir::getRanges(Operation *op) {
+  SmallVector<Value *, 8> res;
+  if (auto view = op->dyn_cast<ViewOp>()) {
+    res.append(view.getIndexings().begin(), view.getIndexings().end());
+  } else if (auto slice = op->dyn_cast<SliceOp>()) {
+    for (auto *i : slice.getIndexings())
+      if (i->getType().isa<RangeType>())
+        res.push_back(i);
+  } else {
+    for (auto *v : op->getOperands()) {
+      if (v->getType().isa<ViewType>()) {
+        if (auto *vOp = v->getDefiningOp()) {
+          auto tmp = getRanges(vOp);
+          res.append(tmp.begin(), tmp.end());
+        } else {
+          llvm_unreachable("Needs an operation to extract ranges from a view");
+        }
+      }
+    }
+  }
+  return res;
+}
+
+// Implementation details:
+//   1. Checks whether `ranges` define a new View by performing an equality
+//      check between the range ssa-values and the operands of
+//      `viewDefiningOp`.
+//   2. If all ranges happen to be equal, op creation is elided and the
+//      original result is returned instead.
+//   3. Otherwise, creates a SliceOp with the new `ranges`.
+// This is used to abstract away the creation of a SliceOp.
+Value *mlir::createOrReturnView(FuncBuilder *b, Location loc,
+                                Operation *viewDefiningOp,
+                                ArrayRef<Value *> ranges) {
+  if (auto view = viewDefiningOp->dyn_cast<ViewOp>()) {
+    auto indexings = view.getIndexings();
+    if (std::equal(indexings.begin(), indexings.end(), ranges.begin()))
+      return view.getResult();
+    return b->create<SliceOp>(loc, view.getResult(), ranges);
+  }
+  auto slice = viewDefiningOp->cast<SliceOp>();
+  unsigned idxRange = 0;
+  SmallVector<Value *, 4> newIndexings;
+  bool elide = true;
+  for (auto indexing : slice.getIndexings()) {
+    if (indexing->getType().isa<RangeType>()) {
+      elide &= (indexing != ranges[idxRange]);
+      newIndexings.push_back(ranges[idxRange++]);
+    } else
+      newIndexings.push_back(indexing);
+  }
+  if (elide)
+    return slice.getResult();
+  return b->create<SliceOp>(loc, slice.getBaseView(), newIndexings);
+}
+
+Value *mlir::extractRangePart(Value *range, RangePart part) {
+  assert(range->getType().isa<RangeType>() && "expected range type");
+  if (range->getDefiningOp()) {
+    if (auto r = dyn_cast_or_null<RangeOp>(range->getDefiningOp())) {
+      switch (part) {
+      case RangePart::Min:
+        return r.min();
+      case RangePart::Max:
+        return r.max();
+      case RangePart::Step:
+        return r.step();
+      }
+    }
+  }
+  llvm_unreachable("need operations to extract range parts");
+}
diff --git a/mlir/test/Linalg/tile.mlir b/mlir/test/Linalg/tile.mlir
new file mode 100644 (file)
index 0000000..aeebb70
--- /dev/null
@@ -0,0 +1,191 @@
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2 | FileCheck %s -check-prefix=TILE-2
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=0,2 | FileCheck %s -check-prefix=TILE-02
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=0,0,2 | FileCheck %s -check-prefix=TILE-002
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 | FileCheck %s -check-prefix=TILE-234
+
+//   TILE-2-DAG: #[[ID:.*]] = (d0) -> (d0)
+//   TILE-2-DAG: #[[UB0:.*]] = (d0) -> (d0 + 2)
+//  TILE-02-DAG: #[[ID:.*]] = (d0) -> (d0)
+//  TILE-02-DAG: #[[UB0:.*]] = (d0) -> (d0 + 2)
+// TILE-002-DAG: #[[ID:.*]] = (d0) -> (d0)
+// TILE-002-DAG: #[[UB0:.*]] = (d0) -> (d0 + 2)
+// TILE-234-DAG: #[[ID:.*]] = (d0) -> (d0)
+// TILE-234-DAG: #[[UB0:.*]] = (d0) -> (d0 + 2)
+// TILE-234-DAG: #[[UB1:.*]] = (d0) -> (d0 + 3)
+// TILE-234-DAG: #[[UB2:.*]] = (d0) -> (d0 + 4)
+
+func @matmul(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %I = linalg.range %c0:%arg1:%c1 : !linalg.range
+  %J = linalg.range %c0:%arg2:%c1 : !linalg.range
+  %K = linalg.range %c0:%arg3:%c1 : !linalg.range
+  %A = linalg.view %arg0[%I, %K] : !linalg.view<?x?xf32>
+  %B = linalg.view %arg0[%K, %J] : !linalg.view<?x?xf32>
+  %C = linalg.view %arg0[%I, %J] : !linalg.view<?x?xf32>
+  linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  return
+}
+// TILE-2-LABEL: func @matmul(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//       TILE-2: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-2-NEXT: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-2-NEXT: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//       TILE-2: affine.for %i0 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg1) step 2 {
+//  TILE-2-NEXT:   %[[a:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-2-NEXT:   %[[ra:.*]] = linalg.range %i0:%[[a]]:%c2 : !linalg.range
+//  TILE-2-NEXT:   %[[sAi:.*]] = linalg.slice %[[A]][%[[ra]], %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-2-NEXT:   %[[c:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-2-NEXT:   %[[rc:.*]] = linalg.range %i0:%[[c]]:%c2 : !linalg.range
+//  TILE-2-NEXT:   %[[sCi:.*]] = linalg.slice %[[C]][%[[rc]], %1] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-2-NEXT:   linalg.matmul(%[[sAi]], %[[B]], %[[sCi]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+
+// TILE-02-LABEL: func @matmul(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//       TILE-02: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-02-NEXT: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-02-NEXT: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//       TILE-02: affine.for %i0 = #[[ID]](%c0_0) to #[[ID]](%arg2) step 2 {
+//  TILE-02-NEXT:   %[[b:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-02-NEXT:   %[[rb:.*]] = linalg.range %i0:%[[b]]:%c2 : !linalg.range
+//  TILE-02-NEXT:   %[[sBj:.*]] = linalg.slice %[[B]][%{{.*}}, %[[rb]]] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-02-NEXT:   %[[c:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-02-NEXT:   %[[rc:.*]] = linalg.range %i0:%[[c]]:%c2 : !linalg.range
+//  TILE-02-NEXT:   %[[sCj:.*]] = linalg.slice %[[C]][%{{.*}}, %[[rc]]] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-02-NEXT:   linalg.matmul(%[[A]], %[[sBj]], %[[sCj]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+
+// TILE-002-LABEL: func @matmul(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//       TILE-002: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-002-NEXT: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-002-NEXT: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//       TILE-002: affine.for %i0 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg3) step 2 {
+//  TILE-002-NEXT:   %[[a:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-002-NEXT:   %[[ra:.*]] = linalg.range %i0:%[[a]]:%c2 : !linalg.range
+//  TILE-002-NEXT:   %[[sAj:.*]] = linalg.slice %[[A]][%{{.*}}, %[[ra]]] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-002-NEXT:   %[[b:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-002-NEXT:   %[[rb:.*]] = linalg.range %i0:%[[b]]:%c2 : !linalg.range
+//  TILE-002-NEXT:   %[[sBj:.*]] = linalg.slice %[[B]][%[[rb]], %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-002-NEXT:   linalg.matmul(%[[sAj]], %[[sBj]], %[[C]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+
+// TILE-234-LABEL: func @matmul(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//       TILE-234: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-234-NEXT: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-234-NEXT: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//       TILE-234:  affine.for %i0 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg1) step 2 {
+//  TILE-234-NEXT:    affine.for %i1 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg2) step 3 {
+//  TILE-234-NEXT:      affine.for %i2 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg3) step 4 {
+//  TILE-234-NEXT:        %[[ai:.*]]  = affine.apply #[[UB0]](%i0)
+//  TILE-234-NEXT:        %[[rai:.*]] = linalg.range %i0:%[[ai]]:%c2{{.*}} : !linalg.range
+//  TILE-234-NEXT:        %[[ak:.*]] = affine.apply #[[UB2]](%i2)
+//  TILE-234-NEXT:        %[[rak:.*]] = linalg.range %i2:%[[ak]]:%c4{{.*}} : !linalg.range
+//  TILE-234-NEXT:        %[[sAik:.*]] = linalg.slice %[[A]][%[[rai]], %[[rak]]] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-234-NEXT:        %[[bk:.*]] = affine.apply #[[UB2]](%i2)
+//  TILE-234-NEXT:        %[[rbk:.*]] = linalg.range %i2:%[[bk]]:%c4{{.*}} : !linalg.range
+//  TILE-234-NEXT:        %[[bj:.*]] = affine.apply #[[UB1]](%i1)
+//  TILE-234-NEXT:        %[[rbj:.*]] = linalg.range %i1:%[[bj]]:%c3{{.*}} : !linalg.range
+//  TILE-234-NEXT:        %[[sBkj:.*]] = linalg.slice %[[B]][%[[rbk]], %[[rbj]]] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-234-NEXT:        %[[ci:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-234-NEXT:        %[[rci:.*]] = linalg.range %i0:%[[ci]]:%c2{{.*}} : !linalg.range
+//  TILE-234-NEXT:        %[[cj:.*]] = affine.apply #[[UB1]](%i1)
+//  TILE-234-NEXT:        %[[rcj:.*]] = linalg.range %i1:%[[cj]]:%c3{{.*}} : !linalg.range
+//  TILE-234-NEXT:        %[[sCij:.*]] = linalg.slice %[[C]][%[[rci]], %[[rcj]]] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-234-NEXT:        linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+
+func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %I = linalg.range %c0:%arg1:%c1 : !linalg.range
+  %J = linalg.range %c0:%arg2:%c1 : !linalg.range
+  %2 = linalg.view %arg0[%I, %J] : !linalg.view<?x?xf32>
+  %3 = linalg.view %arg0[%J] : !linalg.view<?xf32>
+  %4 = linalg.view %arg0[%I] : !linalg.view<?xf32>
+  linalg.matvec(%2, %3, %4) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
+  return
+}
+// TILE-2-LABEL: func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//       TILE-2: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-2-NEXT: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
+//  TILE-2-NEXT: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
+//       TILE-2: affine.for %i0 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg1) step 2 {
+//  TILE-2-NEXT:   %[[a:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-2-NEXT:   %[[ra:.*]] = linalg.range %i0:%[[a]]:%c2 : !linalg.range
+//  TILE-2-NEXT:   %[[sAi:.*]] = linalg.slice %[[A]][%[[ra]], %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-2-NEXT:   %[[c:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-2-NEXT:   %[[rc:.*]] = linalg.range %i0:%[[c]]:%c2 : !linalg.range
+//  TILE-2-NEXT:   %[[sCi:.*]] = linalg.slice %[[C]][%[[rc]]] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
+//  TILE-2-NEXT:   linalg.matvec(%[[sAi]], %[[B]], %[[sCi]]) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
+
+// TILE-02-LABEL: func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//       TILE-02: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-02-NEXT: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
+//  TILE-02-NEXT: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
+//       TILE-02: affine.for %i0 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg2) step 2 {
+//  TILE-02-NEXT:   %[[a:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-02-NEXT:   %[[ra:.*]] = linalg.range %i0:%[[a]]:%c2{{.*}} : !linalg.range
+//  TILE-02-NEXT:   %[[sAj:.*]] = linalg.slice %[[A]][%{{.*}}, %[[ra]]] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-02-NEXT:   %[[b:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-02-NEXT:   %[[rb:.*]] = linalg.range %i0:%[[b]]:%c2{{.*}} : !linalg.range
+//  TILE-02-NEXT:   %[[sBj:.*]] = linalg.slice %[[B]][%[[rb]]] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
+//  TILE-02-NEXT:   linalg.matvec(%[[sAj]], %[[sBj]], %[[C]]) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
+
+// TILE-002-LABEL: func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//   TILE-002-NOT: affine.for
+
+// TILE-234-LABEL: func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//       TILE-234: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
+//  TILE-234-NEXT: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
+//  TILE-234-NEXT: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
+//       TILE-234:  affine.for %i0 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg1) step 2 {
+//  TILE-234-NEXT:    affine.for %i1 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg2) step 3 {
+//  TILE-234-NEXT:      %[[ai:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-234-NEXT:      %[[rai:.*]] = linalg.range %i0:%[[ai]]:%c2 : !linalg.range
+//  TILE-234-NEXT:      %[[aj:.*]] = affine.apply #[[UB1]](%i1)
+//  TILE-234-NEXT:      %[[raj:.*]] = linalg.range %i1:%[[aj]]:%c3 : !linalg.range
+//  TILE-234-NEXT:      %[[sAij:.*]] = linalg.slice %[[A]][%[[rai]], %[[raj]]] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+//  TILE-234-NEXT:      %[[b:.*]] = affine.apply #[[UB1]](%i1)
+//  TILE-234-NEXT:      %[[rb:.*]] = linalg.range %i1:%[[b]]:%c3 : !linalg.range
+//  TILE-234-NEXT:      %[[sB:.*]] = linalg.slice %[[B]][%[[rb]]] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
+//  TILE-234-NEXT:      %[[c:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-234-NEXT:      %[[rc:.*]] = linalg.range %i0:%[[c]]:%c2 : !linalg.range
+//  TILE-234-NEXT:      %[[sC:.*]] = linalg.slice %[[C]][%[[rc]]] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
+//  TILE-234-NEXT:      linalg.matvec(%[[sAij]], %[[sB]], %[[sC]]) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
+
+func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %I = linalg.range %c0:%arg1:%c1 : !linalg.range
+  %1 = linalg.view %arg0[%I] : !linalg.view<?xf32>
+  %2 = linalg.view %arg0[%I] : !linalg.view<?xf32>
+  %3 = linalg.view %arg0[] : !linalg.view<f32>
+  linalg.dot(%1, %2, %3) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+  return
+}
+// TILE-2-LABEL: func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//       TILE-2: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
+//  TILE-2-NEXT: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
+//  TILE-2-NEXT: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<f32>
+//       TILE-2: affine.for %i0 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg1) step 2 {
+//  TILE-2-NEXT:   %[[a:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-2-NEXT:   %[[ra:.*]] = linalg.range %i0:%[[a]]:%c2 : !linalg.range
+//  TILE-2-NEXT:   %[[sAi:.*]] = linalg.slice %[[A]][%[[ra]]] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
+//  TILE-2-NEXT:   %[[b:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-2-NEXT:   %[[rb:.*]] = linalg.range %i0:%[[b]]:%c2 : !linalg.range
+//  TILE-2-NEXT:   %[[sBi:.*]] = linalg.slice %[[B]][%[[rb]]] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
+//  TILE-2-NEXT:   linalg.dot(%[[sAi]], %[[sBi]], %[[C]]) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+
+// TILE-02-LABEL: func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//   TILE-02-NOT: affine.for
+
+// TILE-002-LABEL: func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//   TILE-002-NOT: affine.for
+
+// TILE-234-LABEL: func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+//       TILE-234: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
+//  TILE-234-NEXT: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
+//  TILE-234-NEXT: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<f32>
+//       TILE-234:  affine.for %i0 = #[[ID]](%c0{{.*}}) to #[[ID]](%arg1) step 2 {
+//  TILE-234-NEXT:    %[[a:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-234-NEXT:    %[[ra:.*]] = linalg.range %i0:%[[a]]:%c2 : !linalg.range
+//  TILE-234-NEXT:    %[[sA:.*]] = linalg.slice %[[A]][%[[ra]]] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
+//  TILE-234-NEXT:    %[[b:.*]] = affine.apply #[[UB0]](%i0)
+//  TILE-234-NEXT:    %[[rb:.*]] = linalg.range %i0:%[[b]]:%c2 : !linalg.range
+//  TILE-234-NEXT:    %[[sB:.*]] = linalg.slice %[[B]][%[[rb]]] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
+//  TILE-234-NEXT:    linalg.dot(%[[sA]], %[[sB]], %[[C]]) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>