[mlir][linalg] update fusion on tensors to support linalg index operations.
authorTobias Gysi <gysit@google.com>
Tue, 20 Apr 2021 05:28:26 +0000 (05:28 +0000)
committerTobias Gysi <gysit@google.com>
Tue, 20 Apr 2021 06:13:04 +0000 (06:13 +0000)
The patch replaces the index operations in the body of fused producers and linearizes the indices after expansion.

Differential Revision: https://reviews.llvm.org/D100479

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir
mlir/test/Dialect/Linalg/reshape_fusion.mlir

index a404cbd..4d6045a 100644 (file)
@@ -28,10 +28,6 @@ using namespace mlir::linalg;
 /// Implementation of fusion of generic ops and indexed_generic ops.
 static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
                                      unsigned consumerIdx) {
-  // TODO: remove once index ops are supported.
-  if (producer.hasIndexSemantics() || consumer.hasIndexSemantics())
-    return false;
-
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;
@@ -138,7 +134,7 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
   // 1. Map consumer indices to fusedBlock indices 1-1.
   mapper.map(consumerBlock.getArguments().take_front(numConsumerIndices),
              fusedBlock->getArguments().take_front(numConsumerIndices));
-  // 2. Embed producer indices into fusedBlock index space 1-1.
+  // 2a. Embed producer indices into fusedBlock index space 1-1.
   for (auto it :
        llvm::zip(producerBlock.getArguments().take_front(numProducerIndices),
                  fusedBlock->getArguments().take_front(numProducerIndices))) {
@@ -148,6 +144,28 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
         fusedBlock->getArguments().take_front(numFusedOpIndices));
     mapper.map(std::get<0>(it), newIndex);
   }
+  // 2b. Replace the producer index operations by index operations placed in the
+  // fused block using the `consumerToProducerLoopsMap` to map the index spaces.
+  unsigned numFusedOpLoops =
+      std::max(producer.getNumLoops(), consumer.getNumLoops());
+  if (producer.hasIndexSemantics()) {
+    SmallVector<Value> fusedIndices;
+    fusedIndices.reserve(numFusedOpLoops);
+    llvm::transform(llvm::seq<int64_t>(0, numFusedOpLoops),
+                    std::back_inserter(fusedIndices), [&](int64_t dim) {
+                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
+                    });
+    for (IndexOp indexOp :
+         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
+      Value newIndex = rewriter.create<mlir::AffineApplyOp>(
+          producer.getLoc(),
+          consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
+      // Replace the producer index operation by the index value computed in the
+      // fused block. All remaining operations in the producer block are later
+      // on cloned to the fused block.
+      rewriter.replaceOp(indexOp, newIndex);
+    }
+  }
   // TODO: allow fusing the producer of an output operand.
   assert(consumerIdx < consumer.getNumInputs() &&
          "expected producer of input operand");
@@ -329,8 +347,8 @@ fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
       invProducerResultIndexMap.compose(consumerResultIndexMap);
 
   generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer,
-                              consumerToProducerLoopsMap, consumerIdx,
-                              consumer.getNumLoops());
+                                   consumerToProducerLoopsMap, consumerIdx,
+                                   consumer.getNumLoops());
   return SmallVector<Value, 1>(fusedOp->getResults());
 }
 
@@ -602,17 +620,16 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   return success();
 }
 
-/// To expand an indexed_generic operation, the body of the indexed generic op
-/// need to be modified appropriately. Specifically, uses of arguments for
-/// induction variables in the original operation need to be replaced with
-/// linearization of the corresponding arguments in the expanded op. That
-/// requires the shape of the expanded dimensions (at least all but the most
-/// significant. For now check that these are all statically sized. Note that
-/// this could be extended to handle dynamic case, but the implementation below
-/// uses `affine.apply` which seems to have issues when the shapes are not
-/// static.
-LogicalResult isIndexedGenericOpExpandable(LinalgOp linalgOp,
-                                           const ExpansionInfo &expansionInfo) {
+/// Epanding the body of a linalg operation requires adaptations of the accessed
+/// loop indices. Specifically, access of indices in the original operation need
+/// to be replaced with linearizations of indices in the expanded op. That
+/// requires the shape of the expanded dimensions to be static (at least all but
+/// the most significant). For now check that these are all statically sized.
+/// Note that this could be extended to handle dynamic case, but the
+/// implementation below uses `affine.apply` which seems to have issues when the
+/// shapes are not static.
+LogicalResult isIndexedOpExpandable(LinalgOp linalgOp,
+                                    const ExpansionInfo &expansionInfo) {
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
     if (expandedShape.size() == 1)
@@ -734,6 +751,49 @@ static void buildExpandedIndexedGenericOpRegion(
                        argReplacements);
 }
 
+/// Update the body of an expanded linalg operation having index semantics. The
+/// indices of the original operation need to be recovered by linearizing the
+/// indices of the correspoding dimensions of the expanded operation. For now it
+/// is assumed that the shapes of the expanded operation needed for
+/// linearization are static.
+static void updateExpandedIndexOpRegion(PatternRewriter &rewriter, Location loc,
+                                        Region &fusedRegion,
+                                        const ExpansionInfo &expansionInfo) {
+  // Replace the original indices by the linearization of the expanded indices.
+  for (IndexOp indexOp :
+       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
+    ArrayRef<int64_t> expandedDims =
+        expansionInfo.getExpandedDims(indexOp.dim());
+    assert(!expandedDims.empty() && "expected valid expansion info");
+
+    // Skip index operations that are not affected by the expansion.
+    if (expandedDims.size() == 1 &&
+        expandedDims.front() == (int64_t)indexOp.dim())
+      continue;
+
+    // Linearize the expanded indices of the original index dimension.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointAfter(indexOp);
+    ArrayRef<int64_t> expandedDimsShape =
+        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
+    SmallVector<Value> expandedIndices;
+    expandedIndices.reserve(expandedDims.size() - 1);
+    llvm::transform(
+        expandedDims.drop_front(), std::back_inserter(expandedIndices),
+        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
+    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
+      assert(!ShapedType::isDynamic(std::get<0>(it)));
+      AffineExpr idx, acc;
+      bindDims(rewriter.getContext(), idx, acc);
+      newIndex = rewriter.create<AffineApplyOp>(
+          indexOp.getLoc(), idx + acc * std::get<0>(it),
+          ValueRange{std::get<1>(it), newIndex});
+    }
+    rewriter.replaceOp(indexOp, newIndex);
+  }
+}
+
 /// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
 /// op as explained in `isFusableWithReshapeByExpansion`. Assumes that those
 /// conditions have been satisfied.
@@ -748,6 +808,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
   RankedTensorType expandedType =
       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
+  bool hasIndexSemantics = linalgOp.hasIndexSemantics() ||
+                           isa<IndexedGenericOp>(linalgOp.getOperation());
 
   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(linalgOp, fusedTensorIndex,
@@ -755,8 +817,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
                                    expandedType.getShape())))
     return llvm::None;
 
-  if (isa<IndexedGenericOp>(linalgOp.getOperation()) &&
-      failed(isIndexedGenericOpExpandable(linalgOp, expansionInfo)))
+  if (hasIndexSemantics &&
+      failed(isIndexedOpExpandable(linalgOp, expansionInfo)))
     return llvm::None;
 
   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -823,6 +885,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
                                         fusedRegion, expansionInfo);
   }
 
+  // Update the index accesses after the expansion.
+  if (linalgOp.hasIndexSemantics())
+    updateExpandedIndexOpRegion(rewriter, loc, fusedRegion, expansionInfo);
+
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value, 1> resultVals;
@@ -1261,6 +1327,7 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
           context, options.controlElementwiseOpsFusionFn);
   populateFoldReshapeOpsByExpansionPatterns(
       patterns, options.allowFoldingUnitDimReshapes);
+  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
index 00d0995..40c5265 100644 (file)
@@ -360,6 +360,58 @@ func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
 // -----
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>,
+                                       %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
+  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+  %3 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%arg0, %arg1  : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):       // no predecessors
+      %10 = addi %arg2, %arg3 : i32
+      linalg.yield %10 : i32
+    } -> tensor<?x?xi32>
+  %4 = linalg.generic {
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%3 : tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg2: i32, %arg3: i32):       // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %5 = index_cast %idx0 : index to i32
+      %6 = index_cast %idx1 : index to i32
+      %7 = addi %arg2, %5 : i32
+      %8 = subi %7, %6 : i32
+      linalg.yield %8 : i32
+    } -> tensor<?x?xi32>
+  return %4 : tensor<?x?xi32>
+}
+//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @producer_indexed_consumer_fusion
+//      CHECK: linalg.generic
+// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
+//      CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
+//      CHECK:   %[[VAL1:.+]] = addi %[[ARG0]], %[[ARG1]] : i32
+//      CHECK:   %[[IDX0:.+]] = linalg.index 0 : index
+//      CHECK:   %[[IDX1:.+]] = linalg.index 1 : index
+//      CHECK:   %[[ADD_OPERAND:.+]] = index_cast %[[IDX0]] : index to i32
+//      CHECK:   %[[SUB_OPERAND:.+]] = index_cast %[[IDX1]] : index to i32
+//      CHECK:   %[[VAL2:.+]] = addi %[[VAL1]], %[[ADD_OPERAND]] : i32
+//      CHECK:   %[[VAL3:.+]] = subi %[[VAL2]], %[[SUB_OPERAND]] : i32
+//      CHECK:   linalg.yield %[[VAL3]] : i32
+//  CHECK-NOT: linalg.generic
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
                                            %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
   %c0 = constant 0 : index
@@ -409,6 +461,58 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
+                                       %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
+  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+  %3 = linalg.generic {
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%arg0 : tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg4: i32, %arg5: i32):       // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %4 = index_cast %idx0 : index to i32
+      %5 = index_cast %idx1 : index to i32
+      %6 = addi %arg4, %4 : i32
+      %7 = subi %6, %5 : i32
+      linalg.yield %7 : i32
+    } -> tensor<?x?xi32>
+  %4 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%3, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):       // no predecessors
+      %10 = addi %arg2, %arg3 : i32
+      linalg.yield %10 : i32
+    } -> tensor<?x?xi32>
+  return %4 : tensor<?x?xi32>
+}
+//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_producer_consumer_fusion
+//       CHECK: linalg.generic
+// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
+//      CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
+//      CHECK:   %[[IDX0:.+]] = linalg.index 0 : index
+//      CHECK:   %[[IDX1:.+]] = linalg.index 1 : index
+//      CHECK:   %[[ADD_OPERAND:.+]] = index_cast %[[IDX0]] : index to i32
+//      CHECK:   %[[SUB_OPERAND:.+]] = index_cast %[[IDX1]] : index to i32
+//      CHECK:   %[[VAL1:.+]] = addi %[[ARG0]], %[[ADD_OPERAND]] : i32
+//      CHECK:   %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
+//      CHECK:   %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG1]] : i32
+//      CHECK:   linalg.yield %[[VAL3]] : i32
+//   CHECK-NOT: linalg.generic
+
+// -----
+
 // The indices of the first indexed_generic op are swapped after fusion.
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
@@ -465,6 +569,69 @@ func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
 
 // -----
 
+// The indices of the first indexed_generic op are swapped after fusion.
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>)
+                                               -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
+  %1 = memref.dim %arg0, %c1 : tensor<?x?xi32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32>
+  %3 = linalg.generic {
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%arg0 : tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg2: i32, %arg3: i32):       // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %4 = index_cast %idx0 : index to i32
+      %5 = index_cast %idx1 : index to i32
+      %6 = addi %arg2, %4 : i32
+      %7 = subi %5, %6 : i32
+      linalg.yield %7 : i32
+    } -> tensor<?x?xi32>
+  %4= linalg.generic {
+    indexing_maps = [#map1, #map1],
+    iterator_types = ["parallel", "parallel"] }
+    ins(%3 : tensor<?x?xi32>)
+    outs(%2 : tensor<?x?xi32>) {
+    ^bb0(%arg2: i32, %arg3: i32):       // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %5 = index_cast %idx0 : index to i32
+      %6 = index_cast %idx1 : index to i32
+      %7 = addi %arg2, %5 : i32
+      %8 = subi %7, %6 : i32
+      linalg.yield %8 : i32
+    } -> tensor<?x?xi32>
+  return %4 : tensor<?x?xi32>
+}
+//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_producer_indexed_consumer_fusion
+//       CHECK: linalg.generic
+// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+//      CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
+//      CHECK:   %[[IDX0:.+]] = linalg.index 0 : index
+//      CHECK:   %[[IDX1:.+]] = linalg.index 1 : index
+//      CHECK:   %[[ADD_OPERAND1:.+]] = index_cast %[[IDX1]] : index to i32
+//      CHECK:   %[[SUB_OPERAND1:.+]] = index_cast %[[IDX0]] : index to i32
+//      CHECK:   %[[VAL1:.+]] = addi %[[ARG0]], %[[ADD_OPERAND1]] : i32
+//      CHECK:   %[[VAL2:.+]] = subi %[[SUB_OPERAND1]], %[[VAL1]] : i32
+//      CHECK:   %[[IDX2:.+]] = linalg.index 0 : index
+//      CHECK:   %[[IDX3:.+]] = linalg.index 1 : index
+//      CHECK:   %[[ADD_OPERAND2:.+]] = index_cast %[[IDX2]] : index to i32
+//      CHECK:   %[[SUB_OPERAND2:.+]] = index_cast %[[IDX3]] : index to i32
+//      CHECK:   %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
+//      CHECK:   %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
+//      CHECK:   linalg.yield %[[VAL4]] : i32
+//   CHECK-NOT: linalg.generic
+
+// -----
+
 func @scalar_indexed_generic_fusion
   (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
 {
@@ -507,6 +674,48 @@ func @scalar_indexed_generic_fusion
 
 // -----
 
+func @scalar_generic_fusion
+  (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
+{
+  %c0 = constant 0 : index
+  %cst = constant dense<1.000000e+00> : tensor<10xf32>
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+     iterator_types = []}
+    ins(%arg1 : tensor<i32>) outs(%0 : tensor<f32>) {
+    ^bb0(%arg2: i32, %arg3: f32):  // no predecessors
+      %3 = index_cast %arg2 : i32 to index
+      %4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
+      linalg.yield %4 : f32
+    } -> tensor<f32>
+  %2 = linalg.init_tensor [10] : tensor<10xf32>
+  %3 = linalg.generic
+   {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
+                     affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]}
+    ins(%1, %cst : tensor<f32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+      %4 = mulf %arg2, %arg3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<10xf32>
+  return %3 : tensor<10xf32>
+}
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+//       CHECK: func @scalar_generic_fusion
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<i32>
+//       CHECK:   %[[T0:.+]] = linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:     iterator_types = ["parallel"]
+//  CHECK-SAME:     ins(%[[ARG1]] : tensor<i32>)
+//       CHECK:     tensor.extract %[[ARG0]]
+//       CHECK:     linalg.yield
+//       CHECK   return %[[T0]]
+
+// -----
+
 func @constant_fusion(%arg0 : tensor<4xf32>) -> (tensor<4xf32>) {
   %cst = constant dense<1.0> : tensor<4xf32>
   %1 = linalg.init_tensor [4] : tensor<4xf32>
@@ -655,32 +864,6 @@ func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) -> tenso
 
 // -----
 
-// CHECK-LABEL: func @index_op(
-// CHECK-COUNT-2: linalg.generic
-func @index_op(%arg0: tensor<1x8xindex>, %arg1: tensor<1x8xindex>) -> tensor<1x8xindex> {
-  %0 = linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  outs(%arg0 : tensor<1x8xindex>) {
-  ^bb0(%a: index):   // no predecessors
-    %2 = linalg.index 1 : index
-    linalg.yield %2 : index
-  } -> tensor<1x8xindex>
-  %1 = linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  ins(%0 : tensor<1x8xindex>)
-  outs(%arg1 : tensor<1x8xindex>) {
-  ^bb0(%a: index, %b: index):   // no predecessors
-    %2 = linalg.index 0 : index
-    %3 = addi %2, %a : index
-    linalg.yield %3 : index
-  } -> tensor<1x8xindex>
-  return %1 : tensor<1x8xindex>
-}
-
-// -----
-
 // CHECK-LABEL: func @no_fuse_constant_with_reduction
 func @no_fuse_constant_with_reduction() -> tensor<3xf32>
 {
index 9c0fe41..0e7239e 100644 (file)
@@ -237,6 +237,60 @@ func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
+                                         %arg1 : tensor<?x?x?xi32>) ->
+                                         tensor<?x?x?xi32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
+  %1 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map1],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+       ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
+      outs(%0 : tensor<?x?x?xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32, %s: i32):
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %idx2 = linalg.index 2 : index
+      %1 = muli %arg3, %arg4 : i32
+      %2 = index_cast %idx0 : index to i32
+      %3 = addi %1, %2 : i32
+      %4 = index_cast %idx1 : index to i32
+      %5 = addi %3, %4 : i32
+      %6 = index_cast %idx2 : index to i32
+      %7 = addi %5, %6 : i32
+      linalg.yield %7 : i32
+  } -> tensor<?x?x?xi32>
+  return %1 : tensor<?x?x?xi32>
+}
+
+// Only check the body in the indexed version of the test.
+//       CHECK: #[[MAP:.+]] =  affine_map<(d0, d1) -> (d0 + d1 * 4)>
+//       CHECK: func @indexed_consumer_reshape_producer_fusion
+//       CHECK:   linalg.generic
+//       CHECK:   ^{{.*}}(
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: i32, %[[ARG4:[a-zA-Z0-9]+]]: i32,
+//  CHECK-SAME:     %[[ARG8:[a-zA-Z0-9]+]]: i32)
+//   CHECK-DAG:     %[[IDX0:.+]] = linalg.index 0 : index
+//   CHECK-DAG:     %[[IDX1:.+]] = linalg.index 1 : index
+//   CHECK-DAG:     %[[IDX2:.+]] = linalg.index 2 : index
+//   CHECK-DAG:     %[[IDX3:.+]] = linalg.index 3 : index
+//   CHECK-DAG:     %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX1]], %[[IDX0]])
+//       CHECK:     %[[T4:.+]] = muli %[[ARG3]], %[[ARG4]]
+//       CHECK:     %[[T5:.+]] = index_cast %[[T3]]
+//       CHECK:     %[[T6:.+]] = addi %[[T4]], %[[T5]]
+//       CHECK:     %[[T7:.+]] = index_cast %[[IDX2]]
+//       CHECK:     %[[T8:.+]] = addi %[[T6]], %[[T7]]
+//       CHECK:     %[[T9:.+]] = index_cast %[[IDX3]]
+//       CHECK:     %[[T10:.+]] = addi %[[T8]], %[[T9]]
+//       CHECK:     linalg.yield %[[T10]]
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
                                          %arg1 : tensor<?x?xi32>) ->
@@ -280,6 +334,53 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
+                                         %arg1 : tensor<?x?xi32>) ->
+                                         tensor<?x?x4x5xi32>
+{
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map0, #map0],
+     iterator_types = ["parallel", "parallel"]}
+       ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+      outs(%arg0 : tensor<?x?xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32, %s: i32):       // no predecessors
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %1 = muli %arg3, %arg4 : i32
+      %2 = index_cast %idx0 : index to i32
+      %3 = addi %1, %2 : i32
+      %4 = index_cast %idx1 : index to i32
+      %5 = addi %3, %4 : i32
+      linalg.yield %5 : i32
+  } -> tensor<?x?xi32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?xi32> into tensor<?x?x4x5xi32>
+  return %1 : tensor<?x?x4x5xi32>
+}
+
+// Only check the body in the indexed version of the test.
+//       CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)>
+//       CHECK: func @indexed_producer_reshape_consumer_fusion
+//       CHECK:   linalg.generic
+//       CHECK:   ^{{.*}}(
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: i32, %[[ARG4:[a-zA-Z0-9]+]]: i32,
+//  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9]+]]: i32)
+//   CHECK-DAG:     %[[IDX0:.+]] = linalg.index 0 : index
+//   CHECK-DAG:     %[[IDX1:.+]] = linalg.index 1 : index
+//   CHECK-DAG:     %[[IDX2:.+]] = linalg.index 2 : index
+//   CHECK-DAG:     %[[IDX3:.+]] = linalg.index 3 : index
+//   CHECK-DAG:     %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX3]], %[[IDX2]], %[[IDX1]])
+//       CHECK:     %[[T4:.+]] = muli %[[ARG3]], %[[ARG4]]
+//       CHECK:     %[[T5:.+]] = index_cast %[[IDX0]]
+//       CHECK:     %[[T6:.+]] = addi %[[T4]], %[[T5]]
+//       CHECK:     %[[T7:.+]] = index_cast %[[T3]]
+//       CHECK:     %[[T8:.+]] = addi %[[T6]], %[[T7]]
+//       CHECK:     linalg.yield %[[T8]]
+
+// -----
+
 func @reshape_as_consumer_permutation
   (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
     -> tensor<2x3x4x5x6x7xi32> {
@@ -350,6 +451,82 @@ func @reshape_as_consumer_permutation
 
 // -----
 
+func @reshape_as_consumer_permutation
+  (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
+    -> tensor<2x3x4x5x6x7xi32> {
+  %shape = linalg.init_tensor [6, 4, 210] : tensor<6x4x210xi32>
+  %c = linalg.generic {
+         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
+                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+          ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
+          outs(%shape : tensor<6x4x210xi32>) {
+       ^bb0(%arg3 : i32, %arg4: i32, %s: i32):
+         %idx0 = linalg.index 0 : index
+         %idx1 = linalg.index 1 : index
+         %idx2 = linalg.index 2 : index
+         %1 = addi %arg3, %arg4 : i32
+         %2 = index_cast %idx0 : index to i32
+         %3 = addi %1, %2 : i32
+         %4 = index_cast %idx1 : index to i32
+         %5 = addi %3, %4 : i32
+         %6 = index_cast %idx2 : index to i32
+         %7 = addi %5, %6 : i32
+         linalg.yield %7 : i32
+       } -> tensor<6x4x210xi32>
+  %d = linalg.tensor_reshape %c
+         [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
+       : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+  return %d : tensor<2x3x4x5x6x7xi32>
+}
+
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
+//   CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+//   CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
+//   CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+//   CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)>
+//       CHECK: func @reshape_as_consumer_permutation
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<210x6x4xi32>
+//  CHECK-SAME:   %[[ARG1:.+]]: tensor<210x4xi32>
+//   CHECK-DAG:   %[[T1:.+]] = linalg.tensor_reshape %[[ARG0]]
+//  CHECK-SAME:     [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+//   CHECK-DAG:   %[[T2:.+]] = linalg.tensor_reshape %[[ARG1]]
+//  CHECK-SAME:     [#[[MAP3]], #[[MAP4]]]
+//   CHECK-DAG:   %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7]
+//       CHECK:   %[[T4:.+]] = linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]]
+//  CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
+//  CHECK-SAME:     outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>)
+//       CHECK:   ^{{.+}}(
+//  CHECK-SAME:     %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32,
+//  CHECK-SAME:     %[[ARG10:[a-zA-Z0-9]+]]: i32)
+//   CHECK-DAG:       %[[IDX0:.+]] = linalg.index 0 : index
+//   CHECK-DAG:       %[[IDX1:.+]] = linalg.index 1 : index
+//   CHECK-DAG:       %[[IDX2:.+]] = linalg.index 2 : index
+//   CHECK-DAG:       %[[IDX3:.+]] = linalg.index 3 : index
+//   CHECK-DAG:       %[[IDX4:.+]] = linalg.index 4 : index
+//   CHECK-DAG:       %[[IDX5:.+]] = linalg.index 5 : index
+//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP8]](%[[IDX1]], %[[IDX0]])
+//   CHECK-DAG:       %[[T6:.+]] = affine.apply #[[MAP9]](%[[IDX4]], %[[IDX3]], %[[IDX2]])
+//   CHECK-DAG:       %[[T7:.+]] = addi %[[ARG8]], %[[ARG9]]
+//       CHECK:       %[[T8:.+]] = index_cast %[[T5]]
+//       CHECK:       %[[T9:.+]] = addi %[[T7]], %[[T8]]
+//       CHECK:       %[[T10:.+]] = index_cast %[[T6]]
+//       CHECK:       %[[T11:.+]] = addi %[[T9]], %[[T10]]
+//       CHECK:       %[[T12:.+]] = index_cast %[[IDX5]]
+//       CHECK:       %[[T13:.+]] = addi %[[T11]], %[[T12]]
+
+// -----
+
 func @reshape_as_producer_projected_permutation(
     %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
 {
@@ -407,6 +584,66 @@ func @reshape_as_producer_projected_permutation(
 
 // -----
 
+func @reshape_as_producer_projected_permutation(
+    %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                                    affine_map<(d0, d1, d2) -> (d2)>]
+    : tensor<33x8x?xi32> into tensor<264x?xi32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+     ins(%0 : tensor<264x?xi32>)
+    outs(%shape : tensor<264x?x4xi32>) {
+  ^bb0(%arg1: i32, %s: i32):  // no predecessors
+    %idx0 = linalg.index 0 : index
+    %idx1 = linalg.index 1 : index
+    %idx2 = linalg.index 2 : index
+    %2 = index_cast %idx0 : index to i32
+    %3 = addi %arg1, %2 : i32
+    %4 = index_cast %idx1 : index to i32
+    %5 = addi %3, %4 : i32
+    %6 = index_cast %idx2 : index to i32
+    %7 = addi %5, %6 : i32
+    linalg.yield %7 : i32
+  } -> tensor<264x?x4xi32>
+  return %1 : tensor<264x?x4xi32>
+}
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
+//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)>
+//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//       CHECK: @reshape_as_producer_projected_permutation
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<33x8x?xi32>
+//       CHECK:   %[[RES:.+]] = linalg.generic
+//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:     ins(%[[ARG0]] : tensor<33x8x?xi32>)
+//       CHECK:   ^{{.+}}(
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: i32,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: i32)
+//   CHECK-DAG:       %[[IDX0:.+]] = linalg.index 0 : index
+//   CHECK-DAG:       %[[IDX1:.+]] = linalg.index 1 : index
+//   CHECK-DAG:       %[[IDX2:.+]] = linalg.index 2 : index
+//   CHECK-DAG:       %[[IDX3:.+]] = linalg.index 3 : index
+//   CHECK-DAG:       %[[T0:.+]] = affine.apply #[[MAP2]](%[[IDX1]], %[[IDX0]])
+//       CHECK:       %[[T1:.+]] = index_cast %[[T0]] : index to i32
+//       CHECK:       %[[T2:.+]] = addi %[[ARG1]], %[[T1]] : i32
+//       CHECK:       %[[T3:.+]] = index_cast %[[IDX2]] : index to i32
+//       CHECK:       %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32
+//       CHECK:       %[[T5:.+]] = index_cast %[[IDX3]] : index to i32
+//       CHECK:       %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32
+//       CHECK:       linalg.yield %[[T6]] : i32
+//       CHECK:    %[[RES2:.+]] = linalg.tensor_reshape %[[RES]]
+//  CHECK-SAME:      [#[[MAP3]], #[[MAP4]], #[[MAP5]]]
+//  CHECK-SAME:    : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
+//       CHECK:  return %[[RES2]] : tensor<264x?x4xi32>
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d1, d0)>
 func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,