Linalg portion of the tutorial - part 3-2
authorNicolas Vasilache <ntv@google.com>
Wed, 3 Apr 2019 19:33:01 +0000 (12:33 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 4 Apr 2019 02:20:31 +0000 (19:20 -0700)
    This CL adds support for lowering tensor contractions to loops declaratively.
    This is done thanks to two properties of the such operations:
    1. the definition of an AffineMap getLoopsToOperandRangesMap for each op which maps iteration space dimensions to ranges of the view operands, in their order of occurrence;
    2. the definition of a scalar implementation for each op which creates the computation inside the loops given enclosing parallel and reduction loops,

    All the other properties are derived in a generic fashion from these 2 properties and a few analyses.

    A lowerToLoops transformation is added as well as a test that exercises it.

--

PiperOrigin-RevId: 241783992

14 files changed:
mlir/tutorial/Linalg2/include/linalg2/TensorOps.h
mlir/tutorial/Linalg3/Example.cpp
mlir/tutorial/Linalg3/include/linalg3/Analysis.h [new file with mode: 0644]
mlir/tutorial/Linalg3/include/linalg3/Intrinsics.h [new file with mode: 0644]
mlir/tutorial/Linalg3/include/linalg3/LoadStoreOps.h [new file with mode: 0644]
mlir/tutorial/Linalg3/include/linalg3/Ops.h
mlir/tutorial/Linalg3/include/linalg3/TensorOps-inl.h
mlir/tutorial/Linalg3/include/linalg3/TensorOps.h
mlir/tutorial/Linalg3/include/linalg3/Transforms.h
mlir/tutorial/Linalg3/lib/Analysis.cpp [new file with mode: 0644]
mlir/tutorial/Linalg3/lib/DialectRegistration.cpp [new file with mode: 0644]
mlir/tutorial/Linalg3/lib/LoadStoreOps.cpp [new file with mode: 0644]
mlir/tutorial/Linalg3/lib/TensorOps.cpp
mlir/tutorial/Linalg3/lib/Transforms.cpp

index c20a916..406bcaa 100644 (file)
@@ -41,6 +41,7 @@ protected:
   static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
   void print(mlir::OpAsmPrinter *p);
 
+public:
   //////////////////////////////////////////////////////////////////////////////
   // Op-specific functionality.
   //////////////////////////////////////////////////////////////////////////////
@@ -48,7 +49,6 @@ protected:
   mlir::Operation::operand_range getInputs();
   mlir::Operation::operand_range getOutputs();
 
-public:
   /// These are better as methods calling into the ConcreteOp instead of
   /// template parameters because methods allow more generic behavior and avoid
   /// specializing for number of arguments. All derived classes have
@@ -72,14 +72,24 @@ public:
   //////////////////////////////////////////////////////////////////////////////
   mlir::Value *getInputView(unsigned i);
   mlir::Value *getOutputView(unsigned i);
-  /// Computes a mapping from all the ranges of the operands to the enclosing
-  /// loops. In order to support "broadcast"-style semantics, we need to
-  /// consider all the operands (i.e. input operands are not sufficient).
+
+  /// Each op is responsible for declaring how it lowers itself to scalar form,
+  /// given the enclosing parallel and reduction induction variables.
+  /// `emitScalarImplementation` emits the scalar IR for the op in the nesting
+  /// context of the innermost enclosing loop(i.e. `reductionIvs.back()` or
+  /// `parallel.back()`).
+  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+                                llvm::ArrayRef<mlir::Value *> reductionIvs);
+
+  /// Represents a mapping from the loops to all the ranges of the operands.
   /// The operands and their ranges are in the order defined by the particular
   /// ConcreteOp implementation, the resulting map must match those.
-  /// This is currently computed but can also be specified explicitly in each
-  /// operator to generalize to cases where an analysis is not available.
-  mlir::AffineMap operandRangesToLoopsMap();
+  /// In favorable cases, this can be calculated by an analysis but specifying
+  /// it explicitly is not expensive and generalizes to cases where an analysis
+  /// is not available.
+  /// For details, see the description of loopsToOperandRangesMap in each
+  /// ConcreteOp
+  mlir::AffineMap loopsToOperandRangesMap();
 };
 
 /// Implements c = A * B where c is a scalar and A and B are 1-D vectors.
@@ -119,6 +129,27 @@ public:
   /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
   /// loop over matvec). Does nothing by default.
   void writeAsFinerGrainTensorContraction();
+
+  /// Inputs to this map will be (%k) coming from enclosing loops.
+  /// Therefore, the mapping to get back to A(K), B(K), C() is:
+  ///   (d0) -> (d0, d0)(%k)
+  /// And the operands ranges are:
+  ///   (%k, %k)
+  mlir::AffineMap loopsToOperandRangesMap();
+
+  ///  Given an enclosing reduction loop with iv `r_i`, emits MLIR corresponding
+  ///  to:
+  ///    1. conditionally assign scalarC to 0.0f on the first iteration or load
+  ///       C[] from memory (0-D tensor)
+  ///    2. multiply A[r_i] by B[r_i] and add to scalarC
+  ///    3. store back scalarC at C[]
+  ///
+  /// In some compact index notation this could be written:
+  ///  cond = (r_i == zero)
+  ///  scalarC = select(cond, zerof, C[]);
+  ///  C[] = scalarC + A[r_i] * B[r_i];
+  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+                                llvm::ArrayRef<mlir::Value *> reductionIvs);
 };
 
 /// Implements C = A * B where A is a 2-D matrix and X and Y are 1-D vectors.
@@ -158,6 +189,27 @@ public:
   /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
   /// loop over matvec). Does nothing by default.
   void writeAsFinerGrainTensorContraction();
+
+  /// Inputs to this map will be (%m, %k) coming from enclosing loops.
+  /// Therefore, the mapping to get back to A(M, K), B(K), C(M) is:
+  ///   (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
+  /// And the operands ranges are:
+  ///   (%m, %k, %k, %m)
+  mlir::AffineMap loopsToOperandRangesMap();
+
+  ///  Given an enclosing parallel loop with iv `i` and an enclosing parallel
+  ///  loop with iv `r_j`, emits MLIR corresponding to:
+  ///    1. conditionally assign scalarC to 0.0f on the first iteration or load
+  ///       C[i]
+  ///    2. multiply A[i, r_j] by B[r_j] and add to scalarC
+  ///    3. store back scalarC at C[i]
+  ///
+  /// In some compact index notation this could be written:
+  ///  cond = (r_j == zero)
+  ///  scalarC = select(cond, zerof, C(i));
+  ///  C(i) = scalarC + A(i, r_j) * B(r_j);
+  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+                                llvm::ArrayRef<mlir::Value *> reductionIvs);
 };
 
 /// Implements C = A * B on 2-D matrices.
@@ -197,6 +249,27 @@ public:
   /// Rewrites this op as a finer-grained tensor contraction (e.g. matmul is a
   /// loop over matvec). Does nothing by default.
   void writeAsFinerGrainTensorContraction();
+
+  /// Inputs to this map will be (%m, %n, %k) coming from enclosing loops.
+  /// Therefore, the mapping to get back to A(M, K), B(K, N), C(M, N) is:
+  ///   (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
+  /// And the operands ranges are:
+  ///   (%m, %k, %k, %n, %m, %n)
+  mlir::AffineMap loopsToOperandRangesMap();
+
+  ///  Given a enclosing parallel loops with ivs `i` and `j`, and an enclosing
+  ///  reduction loop with iv `r_k`, emits MLIR corresponding to:
+  ///    1. conditionally assign scalarC to 0.0f on the first iteration or load
+  ///       C[i, j]
+  ///    2. multiply A[i, r_k] by B[r_k, j] and add to scalarC
+  ///    3. store back scalarC at C[i, j]
+  ///
+  /// In some compact index notation this could be written:
+  ///  cond = (r_k == zero)
+  ///  scalarC = select(cond, zerof, C[i, j]);
+  ///  C[i, j] = scalarC + A[i, r_k] * B[r_k, j];
+  void emitScalarImplementation(llvm::ArrayRef<mlir::Value *> parallelIvs,
+                                llvm::ArrayRef<mlir::Value *> reductionIvs);
 };
 
 } // namespace linalg
index 1c10fd5..13eadd4 100644 (file)
@@ -64,13 +64,15 @@ TEST_FUNC(matmul_as_matvec) {
   Module module(&context);
   mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
   lowerToFinerGrainedTensorContraction(f);
+  composeSliceOps(f);
   // clang-format off
   // CHECK-LABEL: func @matmul_as_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
   //       CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
+  //       CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<f32xf32>">
   //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
-  //  CHECK-NEXT:   %[[vB:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
-  //  CHECK-NEXT:   %[[vC:.*]] = linalg.slice %{{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
-  //  CHECK-NEXT:   linalg.matvec {%{{.*}}, %[[vB]]} -> {%[[vC]]}
+  //       CHECK:   %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
+  //       CHECK:   %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
+  //       CHECK:   linalg.matvec {%[[vA]], %[[vB]]} -> {%[[vC]]}
   // clang-format on
   cleanupAndPrintFunction(f);
 }
@@ -81,21 +83,84 @@ TEST_FUNC(matmul_as_dot) {
   mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
   lowerToFinerGrainedTensorContraction(f);
   lowerToFinerGrainedTensorContraction(f);
+  composeSliceOps(f);
   // clang-format off
   // CHECK-LABEL: func @matmul_as_dot(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
   //       CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
   //       CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
   //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
-  //  CHECK-NEXT:   %[[vB:.*]] = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
-  //  CHECK-NEXT:   %[[sC:.*]]  = linalg.slice {{.*}}[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+  //       CHECK:   %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
   //  CHECK-NEXT:   affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
-  //  CHECK-NEXT:     %[[vA:.*]] = linalg.slice {{.*}}[%i1, *] { dim : 0 } : !linalg<"view<f32>">
-  //  CHECK-NEXT:     %[[vC:.*]] = linalg.slice %[[sC]][%i1] { dim : 0 } : !linalg<"view<0xf32>">
+  //       CHECK:     %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
+  //  CHECK-NEXT:     %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<0xf32>">
   //  CHECK-NEXT:     linalg.dot {%[[vA]], %[[vB]]} -> {%[[vC]]}
   // clang-format on
   cleanupAndPrintFunction(f);
 }
 
+TEST_FUNC(matmul_as_loops) {
+  MLIRContext context;
+  Module module(&context);
+  mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
+  lowerToLoops(f);
+  composeSliceOps(f);
+  // clang-format off
+  // CHECK-LABEL: func @matmul_as_loops(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  //       CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
+  //       CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
+  //       CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
+  //       CHECK: %[[rM:.*]] = linalg.range %c0:%[[M]]:%c1 : !linalg<"range">
+  //       CHECK: %[[rN:.*]] = linalg.range %c0:%[[N]]:%c1 : !linalg<"range">
+  //       CHECK: %[[rK:.*]] = linalg.range %c0:%[[K]]:%c1 : !linalg<"range">
+  //       CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg<"view<f32xf32>">
+  //       CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg<"view<f32xf32>">
+  //       CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg<"view<f32xf32>">
+  //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) {
+  //       CHECK:   affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) {
+  //       CHECK:     affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
+  //       CHECK:       %{{.*}} = cmpi "eq", %{{.*}} : index
+  //       CHECK:       %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg<"view<f32xf32>">
+  //       CHECK:       %{{.*}} = select {{.*}} : f32
+  //       CHECK:       %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg<"view<f32xf32>">
+  //       CHECK:       %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg<"view<f32xf32>">
+  //       CHECK:       %{{.*}} = mulf {{.*}} : f32
+  //       CHECK:       %{{.*}} = addf {{.*}} : f32
+  //       CHECK:       linalg.store {{.*}}[%i0, %i1] : !linalg<"view<f32xf32>">
+  // clang-format on
+  cleanupAndPrintFunction(f);
+}
+
+TEST_FUNC(matmul_as_matvec_as_loops) {
+  MLIRContext context;
+  Module module(&context);
+  mlir::Function *f =
+      makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
+  lowerToFinerGrainedTensorContraction(f);
+  lowerToLoops(f);
+  composeSliceOps(f);
+  // clang-format off
+  // CHECK-LABEL: func @matmul_as_matvec_as_loops(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  //       CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
+  //       CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
+  //       CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
+  //       CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg<"view<f32xf32>">
+  //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
+  //       CHECK:   %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg<"view<f32>">
+  //       CHECK:   %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg<"view<f32>">
+  //       CHECK:   affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
+  //       CHECK:     affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
+  //       CHECK:        %{{.*}} = cmpi "eq", %i2, %{{.*}} : index
+  //       CHECK:        %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg<"view<f32>">
+  //       CHECK:        %[[C2:.*]] = select %{{.*}}, %{{.*}}, %[[C]] : f32
+  //       CHECK:        %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg<"view<f32>">
+  //       CHECK:        %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg<"view<f32xf32>">
+  //       CHECK:        %{{.*}} = mulf %[[A]], %[[B]] : f32
+  //       CHECK:        %{{.*}} = addf %[[C2]], %{{.*}} : f32
+  //       CHECK:        linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg<"view<f32>">
+  // clang-format on
+  cleanupAndPrintFunction(f);
+}
+
 int main() {
   RUN_TESTS();
   return 0;
diff --git a/mlir/tutorial/Linalg3/include/linalg3/Analysis.h b/mlir/tutorial/Linalg3/include/linalg3/Analysis.h
new file mode 100644 (file)
index 0000000..813fc37
--- /dev/null
@@ -0,0 +1,37 @@
+//===- Analysis.h - Linalg dialect Analysis function definitions ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_ANALYSIS_H_
+#define LINALG3_ANALYSIS_H_
+
+#include "linalg2/Analysis.h"
+
+namespace mlir {
+class AffineMap;
+} // namespace mlir
+
+namespace linalg {
+
+/// Given a `map` specification and a subset of its results
+/// `[beginResult, endResult)`, returns the inverse map that maps result
+/// positions to dim positions.
+mlir::AffineMap inverseSubMap(mlir::AffineMap map, unsigned beginResult = 0,
+                              unsigned endResult = 0);
+
+} // namespace linalg
+
+#endif // LINALG3_ANALYSIS_H_
diff --git a/mlir/tutorial/Linalg3/include/linalg3/Intrinsics.h b/mlir/tutorial/Linalg3/include/linalg3/Intrinsics.h
new file mode 100644 (file)
index 0000000..75a0417
--- /dev/null
@@ -0,0 +1,31 @@
+//===- Intrinsics.h - Linalg intrinsics definitions -----------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_INTRINSICS_H_
+#define LINALG3_INTRINSICS_H_
+
+#include "linalg2/Intrinsics.h"
+#include "linalg3/Ops.h"
+
+namespace linalg {
+namespace intrinsics {
+using load = mlir::edsc::intrinsics::ValueBuilder<LoadOp>;
+using store = mlir::edsc::intrinsics::OperationBuilder<StoreOp>;
+} // namespace intrinsics
+} // namespace linalg
+
+#endif // LINALG3_INTRINSICS_H_
diff --git a/mlir/tutorial/Linalg3/include/linalg3/LoadStoreOps.h b/mlir/tutorial/Linalg3/include/linalg3/LoadStoreOps.h
new file mode 100644 (file)
index 0000000..b77e702
--- /dev/null
@@ -0,0 +1,89 @@
+//===- LoadStoreOps.h - Linalg dialect Load/Store operation definitions ---===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG3_LOADSTOREOP_H_
+#define LINALG3_LOADSTOREOP_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace linalg {
+
+class ViewType;
+
+/// A linalg.LoadOp is the counterpart of affine.load but operating on ViewType
+/// instead of MemRefType.
+class LoadOp : public mlir::Op<LoadOp, mlir::OpTrait::VariadicOperands,
+                               mlir::OpTrait::OneResult> {
+public:
+  using Op::Op;
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Hooks to customize the behavior of this op.
+  //////////////////////////////////////////////////////////////////////////////
+  static llvm::StringRef getOperationName() { return "linalg.load"; }
+  static void build(mlir::Builder *b, mlir::OperationState *result,
+                    mlir::Value *view,
+                    mlir::ArrayRef<mlir::Value *> indices = {});
+  mlir::LogicalResult verify();
+  static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+  void print(mlir::OpAsmPrinter *p);
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Op-specific functionality.
+  //////////////////////////////////////////////////////////////////////////////
+  unsigned getRank();
+  ViewType getViewType();
+  mlir::Value *getView() { return getOperand(0); }
+  mlir::Operation::operand_range getIndices() {
+    return {operand_begin() + 1, operand_end()};
+  }
+};
+
+/// A linalg.StoreOp is the counterpart of affine.store but operating on
+/// ViewType instead of MemRefType.
+class StoreOp : public mlir::Op<StoreOp, mlir::OpTrait::VariadicOperands,
+                                mlir::OpTrait::ZeroResult> {
+public:
+  using Op::Op;
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Hooks to customize the behavior of this op.
+  //////////////////////////////////////////////////////////////////////////////
+  static llvm::StringRef getOperationName() { return "linalg.store"; }
+  static void build(mlir::Builder *b, mlir::OperationState *result,
+                    mlir::Value *valueToStore, mlir::Value *view,
+                    mlir::ArrayRef<mlir::Value *> indices = {});
+  mlir::LogicalResult verify();
+  static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
+  void print(mlir::OpAsmPrinter *p);
+
+  //////////////////////////////////////////////////////////////////////////////
+  // Op-specific functionality.
+  //////////////////////////////////////////////////////////////////////////////
+  unsigned getRank();
+  ViewType getViewType();
+  mlir::Value *getValueToStore() { return getOperand(0); }
+  mlir::Value *getView() { return getOperand(1); }
+  mlir::Operation::operand_range getIndices() {
+    return {operand_begin() + 2, operand_end()};
+  }
+};
+
+} // namespace linalg
+
+#endif // LINALG3_LOADSTOREOP_H_
index f2d5ec4..813cbff 100644 (file)
@@ -19,6 +19,7 @@
 #define LINALG3_OPS_H_
 
 #include "linalg2/Ops.h"
+#include "linalg3/LoadStoreOps.h"
 #include "linalg3/TensorOps.h"
 
 #endif // LINALG3_OPS_H_
index c4082d5..60d99ab 100644 (file)
 #define LINALG3_TENSOROPS_INL_H_
 
 #include "linalg1/Common.h"
+#include "linalg1/Utils.h"
 #include "linalg2/TensorOps.h"
-
-namespace linalg {
+#include "linalg3/Analysis.h"
+#include "linalg3/Ops.h"
 
 template <class ConcreteOp>
 mlir::Value *
@@ -38,6 +39,90 @@ linalg::TensorContractionBase<ConcreteOp>::getOutputView(unsigned i) {
   return *(getOutputs().begin() + i);
 }
 
-} // namespace linalg
+template <class ConcreteOp>
+mlir::AffineMap
+linalg::TensorContractionBase<ConcreteOp>::loopsToOperandRangesMap() {
+  return static_cast<ConcreteOp *>(this)->loopsToOperandRangesMap();
+}
+
+template <class ConcreteOp>
+void linalg::TensorContractionBase<ConcreteOp>::emitScalarImplementation(
+    llvm::ArrayRef<mlir::Value *> parallelIvs,
+    llvm::ArrayRef<mlir::Value *> reductionIvs) {
+  static_cast<ConcreteOp *>(this)->emitScalarImplementation(parallelIvs,
+                                                            reductionIvs);
+}
+
+template <class ConcreteOp>
+mlir::AffineMap linalg::operandRangesToLoopsMap(
+    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  return inverseSubMap(tensorContraction.loopsToOperandRangesMap());
+}
+
+// Extract the ranges from a given ViewOp or SliceOp.
+//
+// In the case of a ViewOp, things are simple: just traverse the indexings and
+// get all the ranges (i.e. drop the indices).
+//
+// In the case of a SliceOp, things are trickier because we need to handle a
+// potential rank-reduction:
+//   1. Examine the indexing to determine if it is rank-reducing.
+//   2. If it is rank-reducing, an offset of 1 is added to the dimensions such
+//      that `d >= slicingDim`. This is to account for the rank reduction.
+// `getRootIndex` is then called on the **parent** view
+static llvm::SmallVector<mlir::Value *, 8>
+extractRangesFromViewOrSliceOp(mlir::Value *view) {
+  // This expects a viewType which must come from either ViewOp or SliceOp.
+  assert(view->getType().isa<linalg::ViewType>() && "expected ViewType");
+  if (auto viewOp = view->getDefiningOp()->dyn_cast<linalg::ViewOp>())
+    return viewOp.getRanges();
+
+  auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
+  unsigned slicingDim = sliceOp.getSlicingDim();
+  auto *indexing = *(sliceOp.getIndexings().begin());
+  bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
+  unsigned offset = 0;
+  llvm::SmallVector<mlir::Value *, 8> res;
+  res.reserve(sliceOp.getRank());
+  for (unsigned d = 0, e = sliceOp.getRank(); d < e; ++d) {
+    if (d == slicingDim && isRankReducing)
+      offset = 1;
+    auto *parentView = sliceOp.getParentView();
+    auto indexingPosPair = linalg::getViewRootIndexing(parentView, d + offset);
+    res.push_back(indexingPosPair.first);
+  }
+  return res;
+}
+
+template <class ConcreteOp>
+static llvm::SmallVector<mlir::Value *, 8>
+getInputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res;
+  for (auto *in : tensorContraction.getInputs()) {
+    auto subres = extractRangesFromViewOrSliceOp(in);
+    res.append(subres.begin(), subres.end());
+  }
+  return res;
+}
+
+template <class ConcreteOp>
+static llvm::SmallVector<mlir::Value *, 8>
+getOutputRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res;
+  for (auto *out : tensorContraction.getOutputs()) {
+    auto subres = extractRangesFromViewOrSliceOp(out);
+    res.append(subres.begin(), subres.end());
+  }
+  return res;
+}
+
+template <class ConcreteOp>
+llvm::SmallVector<mlir::Value *, 8> linalg::getRanges(
+    linalg::TensorContractionBase<ConcreteOp> &tensorContraction) {
+  llvm::SmallVector<mlir::Value *, 8> res = getInputRanges(tensorContraction);
+  llvm::SmallVector<mlir::Value *, 8> tmp = getOutputRanges(tensorContraction);
+  res.append(tmp.begin(), tmp.end());
+  return res;
+}
 
-#endif // LINALG3_TENSOROPS-INL_H_
+#endif // LINALG3_TENSOROPS_INL_H_
index 3dffd6e..4ade192 100644 (file)
 
 #include "linalg2/TensorOps.h"
 
+namespace linalg {
+
+///
+/// Ideally all these functions would go in an Analysis but until
+/// TensorContractionBase is templated, they need to remain close enough.
+///
+
+/// Takes a `tensorContraction` and a returns an AffineMap that can be used to
+/// map ranges to enclosing loops for all the operands' ranges.
+template <class ConcreteOp>
+mlir::AffineMap operandRangesToLoopsMap(
+    linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
+
+/// Takes a `tensorContraction` and returns the ranges of all its operands.
+/// When an operand comes from a ViewOp, things are simple:
+///   just traverse the indexings and get all the ranges
+///     (i.e. drop the rank-reducing indices).
+/// In the case of a SliceOp, things are more involved because we need to handle
+/// potential rank-reductions.
+/// This function abstracts this complexity away and returns all the ranges.
+template <class ConcreteOp>
+llvm::SmallVector<mlir::Value *, 8>
+getRanges(linalg::TensorContractionBase<ConcreteOp> &tensorContraction);
+
+} // namespace linalg
+
 /// The TensorOp-inl.h inclusion pattern is chosen to allow gradual extension of
 /// TensorOps by adding implementations as they are needed in the appropriate
 /// step in the tutorial.
index b5e11dd..5cc7692 100644 (file)
@@ -30,10 +30,17 @@ namespace linalg {
 /// to only use linalg.view operations.
 void composeSliceOps(mlir::Function *f);
 
-/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec, linalg.dot)
-/// as linalg.matvec (resp. linalg.dot, loop form).
+/// Traverses `f` and rewrites linalg.load and linalg.store to affine.load and
+/// affine.store operations.
+void lowerLinalgLoadStores(mlir::Function *f);
+
+/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec)
+/// as linalg.matvec (resp. linalg.dot).
 void lowerToFinerGrainedTensorContraction(mlir::Function *f);
 
+/// Traverses `f` and rewrites linalg operations in loop form.
+void lowerToLoops(mlir::Function *f);
+
 } // namespace linalg
 
 #endif // LINALG3_TRANSFORMS_H_
diff --git a/mlir/tutorial/Linalg3/lib/Analysis.cpp b/mlir/tutorial/Linalg3/lib/Analysis.cpp
new file mode 100644 (file)
index 0000000..9e7c8ee
--- /dev/null
@@ -0,0 +1,62 @@
+//===- Analysis.cpp - Implementation of analysis functions for Linalg -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple IR operation to create a new RangeType in the
+// linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg3/Analysis.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/StandardTypes.h"
+
+using llvm::SmallVector;
+using namespace mlir;
+
+// Compute an inverse map (only works with permutations for now).
+// Note that the mapping is generally non-full rank, so this returns the first
+// seen entry for each dim.
+static AffineMap inversePermutationMap(AffineMap map) {
+  SmallVector<AffineExpr, 4> exprs(map.getNumDims());
+  for (auto en : llvm::enumerate(map.getResults())) {
+    auto expr = en.value();
+    auto d = expr.dyn_cast<AffineDimExpr>();
+    assert(d && "permutation map expected");
+    if (exprs[d.getPosition()])
+      continue;
+    exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
+  }
+  SmallVector<AffineExpr, 4> seenExprs;
+  seenExprs.reserve(map.getNumDims());
+  for (auto expr : exprs)
+    if (expr)
+      seenExprs.push_back(expr);
+  assert(map.getNumSymbols() == 0 && "expected map without symbols");
+  assert(seenExprs.size() == map.getNumInputs() && "map is not invertible");
+  return AffineMap::get(map.getNumResults(), 0, seenExprs, {});
+}
+
+mlir::AffineMap linalg::inverseSubMap(AffineMap map, unsigned beginResult,
+                                      unsigned endResult) {
+  if (beginResult == 0 && endResult == 0)
+    endResult = map.getNumResults();
+  auto subMap = AffineMap::get(
+      map.getNumDims(), map.getNumSymbols(),
+      map.getResults().slice(beginResult, endResult - beginResult), {});
+  return inversePermutationMap(subMap);
+}
diff --git a/mlir/tutorial/Linalg3/lib/DialectRegistration.cpp b/mlir/tutorial/Linalg3/lib/DialectRegistration.cpp
new file mode 100644 (file)
index 0000000..1ab2751
--- /dev/null
@@ -0,0 +1,39 @@
+//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file registers the Linalg dialect and should live in a standalone
+// library. Linking with this library will create a static global object that
+// performs dialect registration.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg1/Dialect.h"
+#include "linalg1/Types.h"
+#include "linalg3/Ops.h"
+
+using namespace linalg;
+
+LinalgDialect::LinalgDialect(mlir::MLIRContext *context)
+    : Dialect("linalg", context) {
+  addTypes<RangeType, ViewType>();
+  addOperations<DotOp, LoadOp, MatvecOp, MatmulOp, RangeOp, SliceOp, StoreOp,
+                ViewOp>();
+}
+
+// Dialect registration triggers the creation of a `LinalgDialect` object which
+// adds the proper types and operations to the dialect.
+static mlir::DialectRegistration<LinalgDialect> LinalgOps;
diff --git a/mlir/tutorial/Linalg3/lib/LoadStoreOps.cpp b/mlir/tutorial/Linalg3/lib/LoadStoreOps.cpp
new file mode 100644 (file)
index 0000000..340916f
--- /dev/null
@@ -0,0 +1,136 @@
+//===- LoadStoreOps.cpp - Implementation of linalg Load/Store operations --===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements linalg.load and linalg.store operations which allow
+// accessing memory through ViewType values.
+//
+//===----------------------------------------------------------------------===//
+
+#include "linalg3/LoadStoreOps.h"
+#include "linalg3/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+using llvm::ArrayRef;
+using namespace mlir;
+using namespace linalg;
+
+////////////////////////////////////////////////////////////////////////////////
+// LoadOp.
+////////////////////////////////////////////////////////////////////////////////
+void linalg::LoadOp::build(Builder *b, OperationState *result, Value *view,
+                           ArrayRef<Value *> indices) {
+  auto viewType = view->getType().cast<ViewType>();
+  result->addOperands(view);
+  result->addOperands(indices);
+  result->addTypes(viewType.getElementType());
+}
+
+void linalg::LoadOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getView() << '[';
+  p->printOperands(getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << getViewType();
+}
+
+bool linalg::LoadOp::parse(OpAsmParser *parser, OperationState *result) {
+  llvm_unreachable("Parsing linalg dialect is not supported in this tutorial");
+  return false;
+}
+
+LogicalResult linalg::LoadOp::verify() {
+  if (getNumOperands() == 0)
+    return emitOpError("expected a view to load from");
+
+  auto viewType = getView()->getType().dyn_cast<ViewType>();
+  if (!viewType)
+    return emitOpError("first operand must be a view");
+
+  if (getType() != viewType.getElementType())
+    return emitOpError("result type must match element type of the view");
+
+  if (getRank() != getNumOperands() - 1)
+    return emitOpError("incorrect number of indices for load");
+
+  for (auto *idx : getIndices())
+    if (!idx->getType().isIndex())
+      return emitOpError("index to load must have 'index' type");
+
+  return success();
+}
+
+ViewType linalg::LoadOp::getViewType() {
+  return getView()->getType().cast<ViewType>();
+}
+
+unsigned linalg::LoadOp::getRank() { return getViewType().getRank(); }
+
+////////////////////////////////////////////////////////////////////////////////
+// StoreOp.
+////////////////////////////////////////////////////////////////////////////////
+void linalg::StoreOp::build(Builder *b, OperationState *result,
+                            Value *valueToStore, Value *view,
+                            ArrayRef<Value *> indices) {
+  result->addOperands(valueToStore);
+  result->addOperands(view);
+  result->addOperands(indices);
+}
+
+void linalg::StoreOp::print(OpAsmPrinter *p) {
+  *p << getOperationName() << " " << *getValueToStore();
+  *p << ", " << *getView() << '[';
+  p->printOperands(getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << getViewType();
+}
+
+bool linalg::StoreOp::parse(OpAsmParser *parser, OperationState *result) {
+  assert(false && "NYI");
+  return false;
+}
+
+LogicalResult linalg::StoreOp::verify() {
+  if (getNumOperands() < 2)
+    return emitOpError("expected a value to store and a view");
+
+  // Second operand is a memref type.
+  auto viewType = getView()->getType().dyn_cast<ViewType>();
+  if (!viewType)
+    return emitOpError("second operand must be a view");
+
+  // First operand must have same type as memref element type.
+  if (getValueToStore()->getType() != viewType.getElementType())
+    return emitOpError("first operand must have same element type as the view");
+
+  if (getNumOperands() != 2 + viewType.getRank())
+    return emitOpError("store index operand count not equal to view rank");
+
+  for (auto *idx : getIndices())
+    if (!idx->getType().isIndex())
+      return emitOpError("index to store must have 'index' type");
+
+  return success();
+}
+
+unsigned linalg::StoreOp::getRank() { return getViewType().getRank(); }
+
+ViewType linalg::StoreOp::getViewType() {
+  return getView()->getType().cast<ViewType>();
+}
index a04d772..61eaa06 100644 (file)
@@ -22,7 +22,7 @@
 
 #include "linalg1/Analysis.h"
 #include "linalg1/Common.h"
-#include "linalg2/Intrinsics.h"
+#include "linalg3/Intrinsics.h"
 #include "linalg3/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpDefinition.h"
 
 using namespace mlir;
 using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
 using namespace linalg;
 using namespace linalg::intrinsics;
 
+//////////////////////////////////////////////////////////////////////////////
+// Implementation of DotOp.
+//////////////////////////////////////////////////////////////////////////////
+AffineMap linalg::DotOp::loopsToOperandRangesMap() {
+  // A(K), B(K), C()
+  assert(getRanges(*this).size() == 2);
+  auto *context = ScopedContext::getContext();
+  auto d0 = getAffineDimExpr(0, context); // K
+  // A(K), B(K), C()
+  //   (d0) -> (d0, d0)(%k)
+  return AffineMap::get(1, 0, {d0, d0}, {});
+}
+
+void linalg::DotOp::emitScalarImplementation(
+    llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
+  using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
+                                             linalg::intrinsics::store>;
+  assert(reductionIvs.size() == 1);
+  auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
+  auto *body = innermostLoop.getBody();
+  using edsc::op::operator+;
+  using edsc::op::operator*;
+  using edsc::op::operator==;
+  using edsc::intrinsics::select;
+  ScopedContext scope( // account for affine.terminator in loop.
+      FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+  auto f32 = ScopedContext::getBuilder()->getF32Type();
+  IndexHandle zero(constant_index(0));
+  ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
+  IndexHandle r_i(reductionIvs[0]);
+  IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
+  ValueHandle cond = (r_i == zero);
+  ValueHandle scalarC = select(cond, zerof, *C());
+  C() = scalarC + A(r_i) * B(r_i);
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// Implementation of MatvecOp.
+//////////////////////////////////////////////////////////////////////////////
+AffineMap linalg::MatvecOp::loopsToOperandRangesMap() {
+  // A(M, K), B(K), C(M)
+  assert(getRanges(*this).size() == 4);
+  auto *context = ScopedContext::getContext();
+  auto d0 = getAffineDimExpr(0, context); // M
+  auto d1 = getAffineDimExpr(1, context); // K
+  // A(M, K), B(K), C(M)
+  //   (d0, d1) -> (d0, d1, d1, d0)(%m, %k)
+  return AffineMap::get(2, 0, {d0, d1, d1, d0}, {});
+}
+
 // The body expression for matvec is: C(i) = scalarC + A(i, r_j) * B(r_j)
 // The body expression for dot is: C() = A(r_i) * B(r_i);
 // So we must drop the `i` loop from the matvec.
 void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
   auto *op = getOperation();
-  ScopedContext scope(FuncBuilder(op), op->getLoc());
-  IndexHandle i;
   auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
   auto indexingPosPair = getViewRootIndexing(vA, 0);
   assert(indexingPosPair.first->getDefiningOp() &&
          indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
-  linalg::common::LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
-      dot(slice(vA, i, 0), vB, slice(vC, i, 0)),
+  // clang-format off
+  ScopedContext scope(FuncBuilder(op), op->getLoc());
+  IndexHandle i;
+  using linalg::common::LoopNestRangeBuilder;
+  LoopNestRangeBuilder(&i, ValueHandle(indexingPosPair.first))({
+    [&i, &vA, &vB, &vC]() {
+      ValueHandle sliceA = slice(vA, i, 0);
+      ValueHandle sliceC = slice(vC, i, 0);
+      dot(sliceA, vB, sliceC);
+      /// NestedBuilders expect handles, we thus return an IndexHandle.
+      return IndexHandle();
+    }()
   });
+  // clang-format on
+}
+
+void linalg::MatvecOp::emitScalarImplementation(
+    llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
+  using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
+                                             linalg::intrinsics::store>;
+  assert(reductionIvs.size() == 1);
+  auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
+  auto *body = innermostLoop.getBody();
+  using edsc::op::operator+;
+  using edsc::op::operator*;
+  using edsc::op::operator==;
+  using edsc::intrinsics::select;
+  ScopedContext scope( // account for affine.terminator in loop.
+      FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+  auto f32 = ScopedContext::getBuilder()->getF32Type();
+  IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
+  IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
+  IndexHandle zero(constant_index(0));
+  ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
+  ValueHandle cond = (r_j == zero);
+  ValueHandle scalarC = select(cond, zerof, *C(i));
+  C(i) = scalarC + A(i, r_j) * B(r_j);
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// Op-specific Matmul.
+//////////////////////////////////////////////////////////////////////////////
+AffineMap linalg::MatmulOp::loopsToOperandRangesMap() {
+  // A(M, K), B(K, N), C(M, N)
+  assert(getRanges(*this).size() == 6);
+  auto *context = ScopedContext::getContext();
+  auto d0 = getAffineDimExpr(0, context); // M
+  auto d1 = getAffineDimExpr(1, context); // N
+  auto d2 = getAffineDimExpr(2, context); // K
+  // A(M, K), B(K, N), C(M, N):
+  //   (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1)(%m, %n, %k)
+  return AffineMap::get(3, 0, {d0, d2, d2, d1, d0, d1}, {});
 }
 
 // The body expression for matmul is: C(i, j) = scalarC + A(i, r_k) * B(r_k, j)
@@ -58,13 +156,45 @@ void linalg::MatvecOp::writeAsFinerGrainTensorContraction() {
 // declaratively.
 void linalg::MatmulOp::writeAsFinerGrainTensorContraction() {
   auto *op = getOperation();
-  ScopedContext scope(FuncBuilder(op), op->getLoc());
-  IndexHandle j;
   auto *vA(getInputView(0)), *vB(getInputView(1)), *vC(getOutputView(0));
   auto indexingPosPair = getViewRootIndexing(vB, 1);
   assert(indexingPosPair.first->getDefiningOp() &&
          indexingPosPair.first->getDefiningOp()->isa<RangeOp>());
-  linalg::common::LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
-      matvec(vA, slice(vB, j, 1), slice(vC, j, 1)),
+  using linalg::common::LoopNestRangeBuilder;
+  // clang-format off
+  ScopedContext scope(FuncBuilder(op), op->getLoc());
+  IndexHandle j;
+  LoopNestRangeBuilder(&j, ValueHandle(indexingPosPair.first))({
+    [&j, &vA, &vB, &vC]() {
+      ValueHandle sliceB = slice(vB, j, 1);
+      ValueHandle sliceC = slice(vC, j, 1);
+      matvec(vA, sliceB, sliceC);
+      /// NestedBuilders expect handles, we thus return an IndexHandle.
+      return IndexHandle();
+    }()
   });
+  // clang-format on
+}
+
+void linalg::MatmulOp::emitScalarImplementation(
+    llvm::ArrayRef<Value *> parallelIvs, llvm::ArrayRef<Value *> reductionIvs) {
+  using IndexedValue = TemplatedIndexedValue<linalg::intrinsics::load,
+                                             linalg::intrinsics::store>;
+  assert(reductionIvs.size() == 1);
+  auto innermostLoop = getForInductionVarOwner(reductionIvs.back());
+  auto *body = innermostLoop.getBody();
+  using edsc::op::operator+;
+  using edsc::op::operator*;
+  using edsc::op::operator==;
+  using edsc::intrinsics::select;
+  ScopedContext scope( // account for affine.terminator in loop.
+      FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
+  auto f32 = ScopedContext::getBuilder()->getF32Type();
+  IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
+  IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
+  IndexHandle zero(constant_index(0));
+  ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
+  ValueHandle cond = r_k == zero;
+  ValueHandle scalarC = select(cond, zerof, *C(i, j));
+  C(i, j) = scalarC + A(i, r_k) * B(r_k, j);
 }
index aa9fbd0..070ef5e 100644 (file)
@@ -53,3 +53,171 @@ void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) {
     op->erase();
   });
 }
+
+// Folding eagerly is necessary to abide by affine.for static step requirement.
+// Returns nullptr if folding is not trivially feasible.
+static Value *tryFold(AffineMap map, SmallVector<Value *, 4> operands) {
+  assert(map.getNumResults() == 1 && "single result map expected");
+  auto expr = map.getResult(0);
+  if (auto dim = expr.dyn_cast<AffineDimExpr>())
+    return operands[dim.getPosition()];
+  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
+    return operands[map.getNumDims() + sym.getPosition()];
+  if (auto cst = expr.dyn_cast<AffineConstantExpr>())
+    return constant_index(cst.getValue());
+  return nullptr;
+}
+
+static Value *makeFoldedComposedAffineApply(AffineMap map,
+                                            ArrayRef<Value *> operandsRef) {
+  SmallVector<Value *, 4> operands(operandsRef.begin(), operandsRef.end());
+  fullyComposeAffineMapAndOperands(&map, &operands);
+  if (auto *v = tryFold(map, operands)) {
+    return v;
+  }
+  auto *b = ScopedContext::getBuilder();
+  auto loc = ScopedContext::getLocation();
+  return b->create<AffineApplyOp>(loc, map, operands).getResult();
+}
+
+struct RangeParts {
+  explicit RangeParts(unsigned reserved);
+  RangeParts(ArrayRef<Value *> ranges);
+
+  SmallVector<Value *, 4> makeRanges();
+
+  SmallVector<Value *, 4> mins;
+  SmallVector<Value *, 4> maxes;
+  SmallVector<Value *, 4> steps;
+};
+
+RangeParts::RangeParts(unsigned reserved) {
+  mins.reserve(reserved);
+  maxes.reserve(reserved);
+  steps.reserve(reserved);
+}
+
+static SmallVector<Value *, 4>
+extractFromRanges(ArrayRef<Value *> ranges,
+                  std::function<Value *(RangeOp)> extract) {
+  SmallVector<Value *, 4> res;
+  res.reserve(ranges.size());
+  for (auto *v : ranges) {
+    auto r = v->getDefiningOp()->cast<RangeOp>();
+    res.push_back(extract(r));
+  }
+  return res;
+}
+
+RangeParts::RangeParts(ArrayRef<Value *> ranges)
+    : mins(extractFromRanges(ranges, [](RangeOp r) { return r.getMin(); })),
+      maxes(extractFromRanges(ranges, [](RangeOp r) { return r.getMax(); })),
+      steps(extractFromRanges(ranges, [](RangeOp r) { return r.getStep(); })) {}
+
+SmallVector<Value *, 4> RangeParts::makeRanges() {
+  SmallVector<Value *, 4> res;
+  res.reserve(mins.size());
+  for (auto z : llvm::zip(mins, maxes, steps)) {
+    res.push_back(range(std::get<0>(z), std::get<1>(z), std::get<2>(z)));
+  }
+  return res;
+}
+
+static RangeParts makeGenericRangeParts(AffineMap map,
+                                        ArrayRef<Value *> ranges) {
+  assert(map.getNumInputs() == ranges.size());
+  unsigned numDims = map.getNumDims();
+  assert(map.getNumSymbols() == 0);
+  assert(map.getRangeSizes().empty());
+
+  RangeParts res(map.getNumResults());
+  RangeParts rangeParts(ranges);
+  for (auto expr : map.getResults()) {
+    AffineMap map = AffineMap::get(numDims, 0, expr, {});
+    res.mins.push_back(makeFoldedComposedAffineApply(map, rangeParts.mins));
+    res.maxes.push_back(makeFoldedComposedAffineApply(map, rangeParts.maxes));
+    res.steps.push_back(makeFoldedComposedAffineApply(map, rangeParts.steps));
+  }
+  return res;
+}
+
+SmallVector<Value *, 4> makeGenericRanges(AffineMap map,
+                                          ArrayRef<Value *> ranges) {
+  return makeGenericRangeParts(map, ranges).makeRanges();
+}
+
+static SmallVector<Value *, 4> makeGenericLoopRanges(
+    AffineMap operandRangesToLoopsMap, ArrayRef<Value *> ranges,
+    llvm::Optional<ArrayRef<Value *>> tileSizes = llvm::None) {
+  RangeParts res = makeGenericRangeParts(operandRangesToLoopsMap, ranges);
+  if (!tileSizes.hasValue())
+    return res.makeRanges();
+  SmallVector<Value *, 4> tiledSteps;
+  for (auto z : llvm::zip(res.steps, *tileSizes)) {
+    auto *step = std::get<0>(z);
+    auto tileSize = std::get<1>(z);
+    auto stepValue = step->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+    auto tileSizeValue =
+        tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+    assert(stepValue > 0);
+    tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
+  }
+  res.steps = tiledSteps;
+  return res.makeRanges();
+}
+
+template <class ContractionOp>
+static SmallVector<mlir::AffineForOp, 4>
+writeAsLoops(ContractionOp contraction) {
+  ScopedContext scope(mlir::FuncBuilder(contraction.getOperation()),
+                      contraction.getLoc());
+  auto loopRanges = makeGenericLoopRanges(operandRangesToLoopsMap(contraction),
+                                          getRanges(contraction));
+
+  SmallVector<IndexHandle, 4> parallelIvs(contraction.getNumParallelDims());
+  SmallVector<IndexHandle, 4> reductionIvs(contraction.getNumReductionDims());
+  auto pivs = IndexHandle::makeIndexHandlePointers(parallelIvs);
+  auto rivs = IndexHandle::makeIndexHandlePointers(reductionIvs);
+  assert(loopRanges.size() == pivs.size() + rivs.size());
+
+  // clang-format off
+  using linalg::common::LoopNestRangeBuilder;
+  ArrayRef<Value *> ranges(loopRanges);
+  LoopNestRangeBuilder(pivs, ranges.take_front(pivs.size()))({
+    LoopNestRangeBuilder(rivs, ranges.take_back(rivs.size()))({
+      [&contraction, &parallelIvs, &reductionIvs]() {
+        SmallVector<mlir::Value *, 4> parallel(
+            parallelIvs.begin(), parallelIvs.end());
+        SmallVector<mlir::Value *, 4> reduction(
+            reductionIvs.begin(), reductionIvs.end());
+        contraction.emitScalarImplementation(parallel, reduction);
+        /// NestedBuilders expect handles, we thus return an IndexHandle.
+        return IndexHandle();
+      }()
+    })
+  });
+  // clang-format on
+
+  SmallVector<mlir::AffineForOp, 4> res;
+  res.reserve(pivs.size() + rivs.size());
+  for (auto iv : parallelIvs)
+    res.push_back(getForInductionVarOwner(iv.getValue()));
+  for (auto iv : reductionIvs)
+    res.push_back(getForInductionVarOwner(iv.getValue()));
+  return res;
+}
+
+void linalg::lowerToLoops(mlir::Function *f) {
+  f->walkPostOrder([](Operation *op) {
+    if (auto matmulOp = op->dyn_cast<linalg::MatmulOp>()) {
+      writeAsLoops(matmulOp);
+    } else if (auto matvecOp = op->dyn_cast<linalg::MatvecOp>()) {
+      writeAsLoops(matvecOp);
+    } else if (auto dotOp = op->dyn_cast<linalg::DotOp>()) {
+      writeAsLoops(dotOp);
+    } else {
+      return;
+    }
+    op->erase();
+  });
+}