[mlir][Linalg] Add support to lower named ops to loops.
authorNicolas Vasilache <ntv@google.com>
Thu, 30 Apr 2020 13:44:55 +0000 (09:44 -0400)
committerNicolas Vasilache <ntv@google.com>
Thu, 30 Apr 2020 17:45:17 +0000 (13:45 -0400)
This revision adds support to allow named ops to lower to loops.
Linalg.batch_matmul successfully lowers to loops and to LLVM.

In the process, this test also activates linalg to affine loops.
However padded convolutions to not lower to affine.load atm so this revision overrides the type of underlying load / store operation.

Differential Revision: https://reviews.llvm.org/D79135

mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
mlir/test/Dialect/Linalg/affine.mlir
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

index 1c427fa..b7bba5a 100644 (file)
@@ -351,10 +351,11 @@ template <typename ConcreteType>
 class NamedStructuredOpTraits
     : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTraits> {
 public:
-  llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
-  llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
-  std::function<void(OpBuilder &, Location, ArrayRef<Value>)>
-  emitScalarImplementation();
+  static SmallVector<StringRef, 8> referenceIterators(TypeRange inputTypes,
+                                                      TypeRange outputTypes);
+
+  static SmallVector<AffineMap, 8> referenceIndexingMaps(TypeRange inputTypes,
+                                                         TypeRange outputTypes);
 };
 
 } // namespace linalg
index 974bff5..82ae6de 100644 (file)
@@ -33,10 +33,9 @@ using namespace mlir::linalg;
 
 /// Forward declarations.
 template <typename NamedStructuredOpType>
-static void buildNamedStructuredOpRegion(Builder &builder,
-                                         OperationState &result,
-                                         TypeRange operandTypes,
-                                         TypeRange tensorResultTypes);
+static void buildNamedStructuredOpRegionAndAttributes(
+    Builder &builder, OperationState &result, TypeRange operandTypes,
+    TypeRange tensorResultTypes);
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
 template <typename NamedStructuredOpType>
@@ -1085,9 +1084,10 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
 //===----------------------------------------------------------------------===//
 
 template <typename NamedStructuredOpType>
-void buildNamedStructuredOpRegion(Builder &builder, OperationState &result,
-                                  TypeRange operandTypes,
-                                  TypeRange tensorResultTypes) {
+void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
+                                               OperationState &result,
+                                               TypeRange operandTypes,
+                                               TypeRange tensorResultTypes) {
   Region &region = *result.addRegion();
   Block *body = new Block();
   // TODO: atm all operands go through getElementTypeOrSelf,
@@ -1102,12 +1102,24 @@ void buildNamedStructuredOpRegion(Builder &builder, OperationState &result,
   opBuilder.setInsertionPointToStart(&region.front());
   mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
   NamedStructuredOpType::regionBuilder(*body);
+
+  auto indexingMaps = builder.getAffineMapArrayAttr(
+      NamedStructuredOpType::referenceIndexingMaps(operandTypes,
+                                                   tensorResultTypes));
+  result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
+
+  auto iterators =
+      builder.getStrArrayAttr(NamedStructuredOpType::referenceIterators(
+          operandTypes, tensorResultTypes));
+  result.addAttribute(getIteratorTypesAttrName(), iterators);
 }
 
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
+  std::array<StringRef, 2> silentAttrNames{getIndexingMapsAttrName(),
+                                           getIteratorTypesAttrName()};
   p << op.getOperationName() << ' ';
-  p.printOptionalAttrDict(op.getAttrs());
+  p.printOptionalAttrDict(op.getAttrs(), silentAttrNames);
   p << ' ' << op.getOperands();
   p << ": (" << op.getOperandTypes() << ")";
   auto outputTensorTypes = op.getResultTypes();
@@ -1139,7 +1151,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
   if (!tensorResultTypes.empty())
     result.addTypes(tensorResultTypes);
 
-  buildNamedStructuredOpRegion<NamedStructuredOpType>(
+  buildNamedStructuredOpRegionAndAttributes<NamedStructuredOpType>(
       parser.getBuilder(), result, operandTypes, tensorResultTypes);
 
   return parser.resolveOperands(operandsInfo, operandTypes,
index 4a6d54c..62a5a02 100644 (file)
@@ -78,11 +78,10 @@ SmallVector<Value, 4> emitLoopRanges(OpBuilder &b, Location loc, AffineMap map,
   return res;
 }
 
-template <typename OpType>
-static void
-inlineRegionAndEmitStdStore(OpType op, ArrayRef<Value> indexedValues,
-                            ArrayRef<SmallVector<Value, 8>> indexing,
-                            ArrayRef<Value> outputBuffers) {
+template <typename IndexedValueType, typename OpType>
+static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
+                                     ArrayRef<SmallVector<Value, 8>> indexing,
+                                     ArrayRef<Value> outputBuffers) {
   auto &b = ScopedContext::getBuilder();
   auto &block = op.region().front();
   BlockAndValueMapping map;
@@ -95,10 +94,10 @@ inlineRegionAndEmitStdStore(OpType op, ArrayRef<Value> indexedValues,
 
   Operation &terminator = block.back();
   assert(isa<YieldOp>(terminator) &&
-         "expected an yield op in the end of the region");
+         "expected a yield op in the end of the region");
   for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) {
-    std_store(map.lookupOrDefault(terminator.getOperand(i)), outputBuffers[i],
-              ArrayRef<Value>{indexing[i].begin(), indexing[i].end()});
+    IndexedValueType O(outputBuffers[i]);
+    O(indexing[i]) = map.lookupOrDefault(terminator.getOperand(i));
   }
 }
 
@@ -123,9 +122,36 @@ static InputAndOutputIndices getInputAndOutputIndices(ArrayRef<Value> allIvs,
 
 namespace {
 
-// Generic loop emitter, to be specialized on an op-per op basis.
-// TODO: Hook up to named ops interface and, later, retire when all named ops
-// are auto-generated.
+/// Emits the MLIR for the scalar part of the generic op by:
+///   1. Emitting load ops for each input and output view in order. This is
+///      achieved by applying the appropriate input or output map to the
+///      enclosing induction variables.
+///   2. Emitting a call to `op.fun()` that takes as arguments the scalars
+///      from point 1. above.
+///   3. Emitting store ops to store the results of 2. to the output
+///      views.
+///
+/// An example output may resemble:
+///
+/// ```
+///    loop.for %i = %c0 to %0 step %c1 {
+///      loop.for %j = %c0 to %1 step %c1 {
+///        loop.for %k = %c0 to %4 step %c1 {
+///          %11 = load %arg0[%i, %j] :
+///            memref<?x?xf32, stride_specification>
+///          %12 = load %arg1[%i, %j, %k] :
+///            memref<?x?x?xf32, stride_specification>
+///          %13 = load %arg2[%i, %k, %j] :
+///            memref<?x?x?xf32, stride_specification>
+///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
+///          store %14#0, %arg1[%i, %j, %k] :
+///            memref<?x?x?Xf32, stride_specification>
+///          store %14#1, %arg2[%i, %k, %j] :
+///            memref<?x?x?Xf32, stride_specification>
+///       }
+///      }
+///    }
+/// ```
 template <typename IndexedValueType, typename LinalgOpType>
 class LinalgScopedEmitter {
 public:
@@ -133,9 +159,43 @@ public:
                                        LinalgOpType linalgOp) {
     assert(linalgOp.hasBufferSemantics() &&
            "expected linalg op with buffer semantics");
-    llvm_unreachable("NYI");
-    linalgOp.emitScalarImplementation()(ScopedContext::getBuilder(),
-                                        ScopedContext::getLocation(), allIvs);
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    unsigned nInputs = linalgOp.getNumInputs();
+    unsigned nOutputs = linalgOp.getNumOutputs();
+    SmallVector<Value, 4> indexedValues;
+    indexedValues.reserve(nInputs + nOutputs);
+
+    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
+    // region has no uses.
+    // 1.a. Emit load from input views.
+    for (unsigned i = 0; i < nInputs; ++i) {
+      auto indexing = makeCanonicalAffineApplies(
+          b, loc, linalgOp.getInputIndexingMap(i), allIvs);
+      // Passing through IndexedValueType emits the proper load operation.
+      indexedValues.push_back(IndexedValueType(linalgOp.getInput(i))(indexing));
+    }
+    // 1.b. Emit load from output views.
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      auto indexing = makeCanonicalAffineApplies(
+          b, loc, linalgOp.getOutputIndexingMap(i), allIvs);
+      // Passing through IndexedValueType emits the proper load operation.
+      indexedValues.push_back(
+          IndexedValueType(linalgOp.getOutputBuffer(i))(indexing));
+    }
+
+    // TODO(ntv): When a region inliner exists, use it.
+    // 2. Inline region, currently only works for a single basic block.
+    // 3. Emit store.
+    SmallVector<SmallVector<Value, 8>, 8> indexing;
+    SmallVector<Value, 8> outputBuffers;
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      indexing.push_back(makeCanonicalAffineApplies(
+          b, loc, linalgOp.getOutputIndexingMap(i), allIvs));
+      outputBuffers.push_back(linalgOp.getOutputBuffer(i));
+    }
+    inlineRegionAndEmitStore<IndexedValueType>(linalgOp, indexedValues,
+                                               indexing, outputBuffers);
   }
 };
 
@@ -231,7 +291,7 @@ class LinalgScopedEmitter<IndexedValueType, ConvOp> {
 public:
   /// Returns the input value of convOp. If the indices in `imIdx` is out of
   /// boundary, returns 0 instead.
-  static Value getConvOpInput(ConvOp convOp, IndexedValueType im,
+  static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
                               MutableArrayRef<Value> imIdx) {
     // TODO(ntv): add a level of indirection to linalg.generic.
     if (!convOp.padding())
@@ -293,7 +353,11 @@ public:
         makeCanonicalAffineApplies(b, loc, maps[1], allIvs));
     SmallVector<Value, 8> oIdx(
         makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
-    IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output());
+
+    // Padded conv involves an affine.max in the memory access which is not
+    // allowed by affine.load. Override to always use an StdIndexedValue.
+    StdIndexedValue I(convOp.input());
+    IndexedValueType F(convOp.filter()), O(convOp.output());
 
     // Emit scalar form.
     Value paddedInput = getConvOpInput(convOp, I, imIdx);
@@ -344,111 +408,36 @@ public:
   }
 };
 
-// Emits the MLIR for the scalar part of the generic op by:
-//   1. Emitting std_load and std_store ops for each input and output
-//      view in order. This is achieved by applying the appropriate input or
-//      output map to the enclosing induction variables.
-//   2. Emitting a call to `op.fun()` that takes as arguments the scalars
-//      from point 1. above.
-//   3. Emitting std_store to store the results of 2. to the output
-//      views.
-//
-// An example output may resemble:
-//
-// ```
-//    loop.for %i = %c0 to %0 step %c1 {
-//      loop.for %j = %c0 to %1 step %c1 {
-//        loop.for %k = %c0 to %4 step %c1 {
-//          %11 = load %arg0[%i, %j] :
-//            memref<?x?xf32, stride_specification>
-//          %12 = load %arg1[%i, %j, %k] :
-//            memref<?x?x?xf32, stride_specification>
-//          %13 = load %arg2[%i, %k, %j] :
-//            memref<?x?x?xf32, stride_specification>
-//          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
-//          store %14#0, %arg1[%i, %j, %k] :
-//            memref<?x?x?Xf32, stride_specification>
-//          store %14#1, %arg2[%i, %k, %j] :
-//            memref<?x?x?Xf32, stride_specification>
-//       }
-//      }
-//    }
-// ```
-template <typename IndexedValueType>
-class LinalgScopedEmitter<IndexedValueType, GenericOp> {
-public:
-  static void emitScalarImplementation(ArrayRef<Value> allIvs,
-                                       GenericOp genericOp) {
-    assert(genericOp.hasBufferSemantics() &&
-           "expected linalg op with buffer semantics");
-    auto b = ScopedContext::getBuilder();
-    auto loc = ScopedContext::getLocation();
-    unsigned nInputs = genericOp.getNumInputs();
-    unsigned nOutputs = genericOp.getNumOutputs();
-    SmallVector<Value, 4> indexedValues(nInputs + nOutputs);
-
-    // 1.a. Emit std_load from input views.
-    for (unsigned i = 0; i < nInputs; ++i) {
-      auto indexing = makeCanonicalAffineApplies(
-          b, loc, genericOp.getInputIndexingMap(i), allIvs);
-      indexedValues[i] = std_load(genericOp.getInput(i), indexing);
-    }
-
-    // 1.b. Emit std_load from output views.
-    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
-    // region has no uses.
-    for (unsigned i = 0; i < nOutputs; ++i) {
-      Value output = genericOp.getOutputBuffer(i);
-      auto indexing = makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs);
-      indexedValues[nInputs + i] = std_load(output, indexing);
-    }
-
-    // TODO(ntv): When a region inliner exists, use it.
-    // 2. Inline region, currently only works for a single basic block.
-    // 3. Emit std_store.
-    SmallVector<SmallVector<Value, 8>, 8> indexing;
-    SmallVector<Value, 8> outputBuffers;
-    for (unsigned i = 0; i < nOutputs; ++i) {
-      indexing.push_back(makeCanonicalAffineApplies(
-          b, loc, genericOp.getOutputIndexingMap(i), allIvs));
-      outputBuffers.push_back(genericOp.getOutputBuffer(i));
-    }
-    inlineRegionAndEmitStdStore(genericOp, indexedValues, indexing,
-                                outputBuffers);
-  }
-};
-
-// Emits the MLIR for the scalar part of the indexed generic op by:
-//   1. Emitting std_load and std_store ops for each input and output view in
-//      order. This is achieved by applying the appropriate input or output map
-//      to the enclosing induction variables.
-//   2. Emitting a call to `op.fun()` that takes as arguments the induction
-//      variables and the scalars from point 1. above.
-//   3. Emitting std_store to store the results of 2. to the output views.
-//
-// An example output may resemble:
-//
-// ```
-//    loop.for %i = %c0 to %0 step %c1 {
-//      loop.for %j = %c0 to %1 step %c1 {
-//        loop.for %k = %c0 to %4 step %c1 {
-//          %11 = load %arg0[%i, %j] :
-//            memref<?x?xf32, stride_specification>
-//          %12 = load %arg1[%i, %j, %k] :
-//            memref<?x?x?xf32, stride_specification>
-//          %13 = load %arg2[%i, %k, %j] :
-//            memref<?x?x?xf32, stride_specification>
-//          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
-//            (index, index, index, f32, f32, f32) -> (f32, f32)
-//          store %14#0, %arg1[%i, %j, %k] :
-//            memref<?x?x?Xf32, stride_specification>
-//          store %14#1, %arg2[%i, %k, %j] :
-//            memref<?x?x?Xf32, stride_specification>
-//       }
-//      }
-//    }
-// ```
+/// Emits the MLIR for the scalar part of the indexed generic op by:
+///   1. Emitting load ops for each input and output view in order. This is
+///      achieved by applying the appropriate input or output map to the
+///      enclosing induction variables.
+///   2. Emitting a call to `op.fun()` that takes as arguments the induction
+///      variables and the scalars from point 1. above.
+///   3. Emitting store ops to store the results of 2. to the output views.
+///
+/// An example output may resemble:
+///
+/// ```
+///    loop.for %i = %c0 to %0 step %c1 {
+///      loop.for %j = %c0 to %1 step %c1 {
+///        loop.for %k = %c0 to %4 step %c1 {
+///          %11 = load %arg0[%i, %j] :
+///            memref<?x?xf32, stride_specification>
+///          %12 = load %arg1[%i, %j, %k] :
+///            memref<?x?x?xf32, stride_specification>
+///          %13 = load %arg2[%i, %k, %j] :
+///            memref<?x?x?xf32, stride_specification>
+///          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
+///            (index, index, index, f32, f32, f32) -> (f32, f32)
+///          store %14#0, %arg1[%i, %j, %k] :
+///            memref<?x?x?Xf32, stride_specification>
+///          store %14#1, %arg2[%i, %k, %j] :
+///            memref<?x?x?Xf32, stride_specification>
+///       }
+///      }
+///    }
+/// ```
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 public:
@@ -461,31 +450,33 @@ public:
     unsigned nInputs = indexedGenericOp.getNumInputs();
     unsigned nOutputs = indexedGenericOp.getNumOutputs();
     unsigned nLoops = allIvs.size();
-    SmallVector<Value, 4> indexedValues(nLoops + nInputs + nOutputs);
-
-    for (unsigned i = 0; i < nLoops; ++i) {
-      indexedValues[i] = allIvs[i];
-    }
+    SmallVector<Value, 4> indexedValues;
+    indexedValues.reserve(nLoops + nInputs + nOutputs);
+    for (unsigned i = 0; i < nLoops; ++i)
+      indexedValues.push_back(allIvs[i]);
 
-    // 1.a. Emit std_load from input views.
+    // TODO(mravishankar): Avoid the loads if the corresponding argument of the
+    // region has no uses.
+    // 1.a. Emit load from input views.
     for (unsigned i = 0; i < nInputs; ++i) {
-      Value input = indexedGenericOp.getInput(i);
       auto indexing = makeCanonicalAffineApplies(
           b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs);
-      indexedValues[nLoops + i] = std_load(input, indexing);
+      // Pass input i through IndexedValueType emits the proper load operation.
+      indexedValues.push_back(
+          IndexedValueType(indexedGenericOp.getInput(i))(indexing));
     }
-
-    // 1.b. Emit std_load from output views.
+    // 1.b. Emit load from output views.
     for (unsigned i = 0; i < nOutputs; ++i) {
-      Value output = indexedGenericOp.getOutputBuffer(i);
       auto indexing = makeCanonicalAffineApplies(
           b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs);
-      indexedValues[nLoops + nInputs + i] = std_load(output, indexing);
+      // Pass output i through IndexedValueType emits the proper load operation.
+      indexedValues.push_back(
+          IndexedValueType(indexedGenericOp.getOutputBuffer(i))(indexing));
     }
 
     // TODO(ntv): When a region inliner exists, use it.
     // 2. Inline region, currently only works for a single basic block.
-    // 3. Emit std_store.
+    // 3. Emit store.
     SmallVector<SmallVector<Value, 8>, 8> indexing;
     SmallVector<Value, 8> outputBuffers;
     for (unsigned i = 0; i < nOutputs; ++i) {
@@ -493,19 +484,19 @@ public:
           b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs));
       outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i));
     }
-    inlineRegionAndEmitStdStore(indexedGenericOp, indexedValues, indexing,
-                                outputBuffers);
+    inlineRegionAndEmitStore<IndexedValueType>(indexedGenericOp, indexedValues,
+                                               indexing, outputBuffers);
   }
 };
 
-// This struct is for factoring out the implementation and support template
-// instantiations in the following 2 cases:
-//   1. Appending to a list of patterns via RewritePatternList.
-//   2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
-// The implementation must work both in DRR and inside a RewritePattern. As a
-// consequence, (1) it is only allowed to emit new ops if the match is
-// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an
-// encompassing pattern must take care of the erasure logic.
+/// This struct is for factoring out the implementation and support template
+/// instantiations in the following 2 cases:
+///   1. Appending to a list of patterns via RewritePatternList.
+///   2. Direct invocation via `linalgOpToLoops` and `linalgOpToAffineLoops`.
+/// The implementation must work both in DRR and inside a RewritePattern. As a
+/// consequence, (1) it is only allowed to emit new ops if the match is
+/// guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an
+/// encompassing pattern must take care of the erasure logic.
 template <typename LoopTy, typename ConcreteOpTy>
 class LinalgOpToLoopsImpl {
 public:
@@ -532,7 +523,7 @@ public:
   }
 };
 
-/// Generates loops nest using loop.parallel. loop.parallel is only used for the
+/// Generates loop nest using loop.parallel. loop.parallel is only used for the
 /// outer parallel loops. All other loops are generated using loop.for
 /// operation.
 template <typename ConcreteOpTy>
@@ -652,7 +643,7 @@ public:
   }
 };
 
-// Helper classes for type list expansion.
+/// Helper classes for type list expansion.
 template <typename LoopType, typename... LinalgOps>
 class RewritePatternList;
 
@@ -680,16 +671,16 @@ void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
                      >::build(patterns, ctx);
 }
 
-// Local folding pattern for AffineApplyOp that we can apply greedily.
-// This replaces AffineApplyOp by the proper value in cases where the associated
-// map is trivial. A trivial map here is defined as a map with a single result
-// and either:
-//   1. Zero operand + returns a single AffineConstantExpr
-//   2. One operand + returns a single AffineDimExpr
-//   3. One operands + returns a single AffineSymbolExpr
+/// Local folding pattern for AffineApplyOp that we can apply greedily.
+/// This replaces AffineApplyOp by the proper value in cases where the
+/// associated map is trivial.
+/// A trivial map here is defined as a map with a single result and either:
+///   1. Zero operand + returns a single AffineConstantExpr
+///   2. One operand + returns a single AffineDimExpr
+///   3. One operand + returns a single AffineSymbolExpr
 //
-// In the first case, the AffineApplyOp is replaced by a new constant. In the
-// other cases, it is replaced by its unique operand.
+/// In the first case, the AffineApplyOp is replaced by a new constant. In the
+/// other cases, it is replaced by its unique operand.
 struct FoldAffineOp : public RewritePattern {
   FoldAffineOp(MLIRContext *context)
       : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {}
index 7045782..dfe130a 100644 (file)
@@ -1,13 +1,15 @@
 // RUN: mlir-opt %s -convert-linalg-to-affine-loops | FileCheck %s
 
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
-// RUN: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
+// RUN: mlir-opt %s -convert-linalg-to-affine-loops -convert-linalg-to-llvm -o=/dev/null 2>&1
 
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
 
 // CHECK-DAG: #[[stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
 
+// CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)>
+
 func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -53,3 +55,69 @@ func @conv_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1:
 //       CHECK:         affine.for %{{.*}} = 0 to %[[Q]] {
 //       CHECK:           affine.for %{{.*}} = 0 to %[[Z0]] {
 //       CHECK:            %[[SUM:.*]] = affine.apply #[[stride2Dilation1]](%{{.*}}, %{{.*}})
+
+func @conv_padding(%arg0: memref<?x?x?x?xf32>,
+                   %arg1: memref<?x?x?x?xf32>,
+                   %arg2: memref<?x?x?x?xf32>) {
+  linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1],
+                                    padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
+                                    strides = [1, 1]} :
+    memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @conv_padding
+//       CHECK: %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>, %{{.*}}: memref<?x?x?x?xf32>) {
+//       CHECK:   %[[ZERO:.*]] = constant 0.000000e+00 : f32
+//       CHECK:   %[[Z0:.*]] = dim %arg0, 0 : memref<?x?x?x?xf32>
+//       CHECK:   %[[Z1:.*]] = dim %arg0, 1 : memref<?x?x?x?xf32>
+//       CHECK:   %[[Q:.*]] =  dim %arg0, 2 : memref<?x?x?x?xf32>
+//       CHECK:   %[[K:.*]] =  dim %arg0, 3 : memref<?x?x?x?xf32>
+//       CHECK:   %[[B:.*]] =  dim %arg1, 0 : memref<?x?x?x?xf32>
+//       CHECK:   %[[X0:.*]] = dim %arg2, 1 : memref<?x?x?x?xf32>
+//       CHECK:   %[[X1:.*]] = dim %arg2, 2 : memref<?x?x?x?xf32>
+//       CHECK:   affine.for %{{.*}} = 0 to %[[B]] {
+//       CHECK:     affine.for %{{.*}} = 0 to %[[X0]] {
+//       CHECK:       affine.for %{{.*}} = 0 to %[[X1]] {
+//       CHECK:         affine.for %{{.*}} = 0 to %[[K]] {
+//       CHECK:           affine.for %{{.*}} = 0 to %[[Q]] {
+//       CHECK:             affine.for %{{.*}} = 0 to %[[Z0]] {
+//       CHECK:               affine.for %{{.*}} = 0 to %[[Z1]] {
+//       CHECK:                 %[[SUM0:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}})
+//       CHECK:                 %[[SUM1:.*]] = affine.apply #{{.*}}(%{{.*}}, %{{.*}})
+//       CHECK:                 %[[IDX:.*]] = affine.max #[[clampMinMap]](%[[SUM0]])
+//       CHECK:                 %[[IDY:.*]] = affine.max #[[clampMinMap]](%[[SUM1]])
+// Padded conv involves an affine.max in the memory access which is not
+// allowed by affine.load. Override to always use an std.load.
+//       CHECK:                 %{{.*}} = load %{{.*}}[%{{.*}}, %[[IDX]], %[[IDY]], %{{.*}}] : memref<?x?x?x?xf32>
+//       CHECK:                 %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
+//       CHECK:                 %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+//       CHECK:                 %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
+//       CHECK:                 %{{.*}} = affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+//       CHECK:                 %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+//       CHECK:                 affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
+
+//----------------------------------------------------------------------------//
+// Named ops to loops.
+//----------------------------------------------------------------------------//
+func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+  linalg.batch_matmul %A, %B, %C : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  return
+}
+// CHECK-LABEL: @named_batch_matmul
+//  CHECK-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECK-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECK-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//       CHECK: %[[B:.*]] = dim %[[mA]], 0 : memref<?x?x?xf32>
+//       CHECK: %[[M:.*]] = dim %[[mA]], 1 : memref<?x?x?xf32>
+//       CHECK: %[[K:.*]] = dim %[[mA]], 2 : memref<?x?x?xf32>
+//       CHECK: %[[N:.*]] = dim %[[mB]], 2 : memref<?x?x?xf32>
+//       CHECK: affine.for %[[b:.*]] = 0 to %[[B]] {
+//       CHECK:   affine.for %[[m:.*]] = 0 to %[[M]] {
+//       CHECK:     affine.for %[[n:.*]] = 0 to %[[N]] {
+//       CHECK:       affine.for %[[k:.*]] = 0 to %[[K]] {
+//       CHECK:       %[[va:.*]] = affine.load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECK:       %[[vb:.*]] = affine.load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECK:       %[[vc:.*]] = affine.load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+//       CHECK:       %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+//       CHECK:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECK:       affine.store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
index 7c71dbf..6075b98 100644 (file)
@@ -2,7 +2,7 @@
 // RUN: mlir-opt %s -convert-linalg-to-parallel-loops | FileCheck --check-prefix=CHECKPARALLEL %s
 
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
-// RUN: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -o=/dev/null 2>&1
 
 // CHECKLOOP-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 // CHECKLOOP-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
@@ -354,7 +354,6 @@ func @conv_view4(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %
 //       CHECKPARALLEL:           %{{.*}} = addf %{{.*}}, %{{.*}} : f32
 //       CHECKPARALLEL:           store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
 
-
 func @conv_padding(%arg0: memref<?x?x?x?xf32>,
                    %arg1: memref<?x?x?x?xf32>,
                    %arg2: memref<?x?x?x?xf32>) {
@@ -854,8 +853,8 @@ func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
 //  CHECKLOOP-SAME: %[[ARG1]]: memref<f32>
 //  CHECKLOOP-SAME: %[[ARG2]]: memref<f32>
 //   CHECKLOOP-NOT: loop.for
-//   CHECKLOOP-DAG: load %[[ARG0]][]
-//   CHECKLOOP-DAG: load %[[ARG1]][]
+//       CHECKLOOP: load %[[ARG0]][]
+//       CHECKLOOP: load %[[ARG1]][]
 //       CHECKLOOP: addf
 //       CHECKLOOP: store %{{.*}}, %[[ARG2]][]
 
@@ -864,7 +863,50 @@ func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
 //  CHECKPARALLEL-SAME: %[[ARG1]]: memref<f32>
 //  CHECKPARALLEL-SAME: %[[ARG2]]: memref<f32>
 //   CHECKPARALLEL-NOT: loop.for
-//   CHECKPARALLEL-DAG: load %[[ARG0]][]
-//   CHECKPARALLEL-DAG: load %[[ARG1]][]
+//       CHECKPARALLEL: load %[[ARG0]][]
+//       CHECKPARALLEL: load %[[ARG1]][]
 //       CHECKPARALLEL: addf
 //       CHECKPARALLEL: store %{{.*}}, %[[ARG2]][]
+
+//----------------------------------------------------------------------------//
+// Named ops to loops.
+//----------------------------------------------------------------------------//
+func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+  linalg.batch_matmul %A, %B, %C : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  return
+}
+// CHECKLOOP-LABEL: @named_batch_matmul
+//  CHECKLOOP-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKLOOP-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKLOOP-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//       CHECKLOOP: %[[B:.*]] = dim %[[mA]], 0 : memref<?x?x?xf32>
+//       CHECKLOOP: %[[M:.*]] = dim %[[mA]], 1 : memref<?x?x?xf32>
+//       CHECKLOOP: %[[K:.*]] = dim %[[mA]], 2 : memref<?x?x?xf32>
+//       CHECKLOOP: %[[N:.*]] = dim %[[mB]], 2 : memref<?x?x?xf32>
+//       CHECKLOOP: loop.for %[[b:.*]] = %{{.*}} to %[[B]] step %{{.*}} {
+//       CHECKLOOP:   loop.for %[[m:.*]] = %{{.*}} to %[[M]] step %{{.*}} {
+//       CHECKLOOP:     loop.for %[[n:.*]] = %{{.*}} to %[[N]] step %{{.*}} {
+//       CHECKLOOP:       loop.for %[[k:.*]] = %{{.*}} to %[[K]] step %{{.*}} {
+//       CHECKLOOP:       %[[va:.*]] = load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECKLOOP:       %[[vb:.*]] = load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKLOOP:       %[[vc:.*]] = load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKLOOP:       %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+//       CHECKLOOP:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKLOOP:       store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+
+// CHECKPARALLEL-LABEL: @named_batch_matmul
+//  CHECKPARALLEL-SAME: %[[mA:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[mB:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//  CHECKPARALLEL-SAME: %[[mC:[a-zA-Z0-9]+]]: memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[B:.*]] = dim %[[mA]], 0 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[M:.*]] = dim %[[mA]], 1 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[K:.*]] = dim %[[mA]], 2 : memref<?x?x?xf32>
+//       CHECKPARALLEL: %[[N:.*]] = dim %[[mB]], 2 : memref<?x?x?xf32>
+//       CHECKPARALLEL: loop.parallel (%[[b:.*]], %[[m:.*]], %[[n:.*]]) = ({{.*}}) to (%[[B]], %[[M]], %[[N]]) step ({{.*}}) {
+//       CHECKPARALLEL:   loop.for %[[k:.*]] = %{{.*}} to %[[K]] step %{{.*}} {
+//       CHECKPARALLEL:       %[[va:.*]] = load %[[mA]][%[[b]], %[[m]], %[[k]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:       %[[vb:.*]] = load %[[mB]][%[[b]], %[[k]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:       %[[vc:.*]] = load %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
+//       CHECKPARALLEL:       %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32
+//       CHECKPARALLEL:       %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
+//       CHECKPARALLEL:       store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32>
index 0b88f2a..d796d19 100644 (file)
@@ -7,15 +7,15 @@
 //  ODS-NEXT:   NamedStructuredOpTraits
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL:  Test1Op::referenceIterators() {
-//  IMPL-NEXT:  { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+// IMPL-LABEL:  SmallVector<StringRef, 8> Test1Op::referenceIterators
+//       IMPL:  { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-//       IMPL:  Test1Op::referenceIndexingMaps() {
+//       IMPL:  SmallVector<AffineMap, 8> Test1Op::referenceIndexingMaps
 //       IMPL:  AffineMap::get(2, 0, {d0, d1}, context),
 //  IMPL-NEXT:  AffineMap::get(2, 0, {d1}, context),
 //  IMPL-NEXT:  AffineMap::get(2, 0, {d0}, context) };
 //
-//       IMPL:  Test1Op::regionBuilder(Block &block) {
+//       IMPL:  void Test1Op::regionBuilder(Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = std_addf([[c]], [[d]]);
@@ -32,10 +32,10 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
 //  ODS-NEXT:   NamedStructuredOpTraits
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL:  Test2Op::referenceIterators() {
-//  IMPL-NEXT:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+// IMPL-LABEL:  SmallVector<StringRef, 8> Test2Op::referenceIterators
+//       IMPL:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-//       IMPL:  Test2Op::referenceIndexingMaps() {
+//       IMPL:  SmallVector<AffineMap, 8> Test2Op::referenceIndexingMaps
 //       IMPL:  AffineMap::get(3, 0, {d0, d2}, context),
 //  IMPL-NEXT:  AffineMap::get(3, 0, {d2, d1}, context),
 //  IMPL-NEXT:  AffineMap::get(3, 0, {d0, d1}, context) };
@@ -57,10 +57,10 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
 //  ODS-NEXT:   NamedStructuredOpTraits
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
-// IMPL-LABEL:  Test3Op::referenceIterators() {
-//  IMPL-NEXT:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
+// IMPL-LABEL:  SmallVector<StringRef, 8> Test3Op::referenceIterators
+//       IMPL:  { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
 //
-//       IMPL:  Test3Op::referenceIndexingMaps() {
+//       IMPL:  SmallVector<AffineMap, 8> Test3Op::referenceIndexingMaps
 //       IMPL:  AffineMap::get(4, 0, {d0, d1, d3}, context),
 //  IMPL-NEXT:  AffineMap::get(4, 0, {d3, d2}, context),
 //  IMPL-NEXT:  AffineMap::get(4, 0, {d0, d1, d2}, context) };
index 424a297..d2dd1f5 100644 (file)
@@ -1472,7 +1472,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         [{{
           result.addOperands(views);
           result.addTypes(outputTypes);
-          buildNamedStructuredOpRegion<{0}>(
+          buildNamedStructuredOpRegionAndAttributes<{0}>(
             b, result, TypeRange(views), outputTypes);
         }]>
       ];
@@ -1481,7 +1481,13 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
       }];
       let extraClassDeclaration = [{{
         llvm::Optional<SmallVector<StringRef, 8>> referenceIterators();
+        static SmallVector<StringRef, 8> referenceIterators(
+          TypeRange inputTypes, TypeRange outputTypes);
+
         llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps();
+        static SmallVector<AffineMap, 8> referenceIndexingMaps(
+          TypeRange inputTypes, TypeRange outputTypes);
+
         static void regionBuilder(Block &block);
       }];
   })FMT";
@@ -1503,7 +1509,13 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os,
                                        ComprehensionParsingState &state) {
   const char *referenceReferenceIteratorsFmt =
       R"FMT(
-    llvm::Optional<SmallVector<StringRef, 8>> {0}::referenceIterators() {
+    // This is temporary until we transition out of manually specified ops
+    // that should be auto-generated with linalg-ods-gen.
+    llvm::Optional<SmallVector<StringRef, 8>> {0}::referenceIterators() {{
+      llvm_unreachable("Unexpected missing `iterator_types` attribute.");
+    }
+    SmallVector<StringRef, 8> {0}::referenceIterators(
+      TypeRange inputTypes, TypeRange outputTypes) {
       return SmallVector<StringRef, 8>{{ {1} };
     })FMT";
 
@@ -1536,15 +1548,27 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os,
 void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
                                           StringRef cppOpName,
                                           ComprehensionParsingState &state) {
+  // 1. Generic string template for specifying reference indexing maps.
   const char *referenceIndexingMapsFmt =
       R"FMT(
-  llvm::Optional<SmallVector<AffineMap, 8>> {0}::referenceIndexingMaps() {
-    MLIRContext *context = getContext();
+  // This is temporary until we transition out of manually specified ops that
+  // should be auto-generated with linalg-ods-gen.
+  llvm::Optional<SmallVector<AffineMap, 8>> {0}::referenceIndexingMaps() {{
+    llvm_unreachable("Unexpected missing `indexing_maps` attribute.");
+  }
+  SmallVector<AffineMap, 8> {0}::referenceIndexingMaps(
+    TypeRange inputTypes, TypeRange outputTypes) {
+    assert(!inputTypes.empty() && "At least one input expected");
+    MLIRContext *context = (*inputTypes.begin()).getContext();
     AffineExpr {1};
     bindDims(context, {1});
     return SmallVector<AffineMap, 8>{{ {2} };
   })FMT";
 
+  // 2. Print a comma-separated list of identifiers for the AffineExpr in
+  // `state.dims`. These will replace the `{1}` placeholder in both
+  // `AffineExpr {1}` and `bindDims(context, {1})` ensuring the AffineExpr
+  // identifiers are bound in the right order to the proper AffineDimExpr.
   std::string dimsStr;
   llvm::raw_string_ostream ss(dimsStr);
   llvm::interleaveComma(
@@ -1552,10 +1576,14 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
       [&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
   ss.flush();
 
+  // 3. Print a comma-separated list of AffineMap constructors that use the
+  // identifiers from 1. The AffineExpr use the common arithmetic operators on
+  // AffineExpr. These AffineMap constructors will replace the `{2}` placeholder
+  // in return `SmallVector<AffineMap, 8>{{ {2} };`.
   std::string mapsStr;
   llvm::raw_string_ostream mapsStringStream(mapsStr);
   SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
-  for (auto it : state.orderedTensorArgs)
+  for (const auto &it : state.orderedTensorArgs)
     orderedUses[it.second] = it.first;
   llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
     assert(u.indexingMap);
@@ -1576,6 +1604,7 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
   });
   mapsStringStream.flush();
 
+  // 4. Apply format to 1. using 2. and 3.
   os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr);
 }