return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
reshapeOpProducer, genericOpConsumer, consumerIdx, rewriter,
folder);
+ } else if (auto indexedGenericOpConsumer =
+ dyn_cast<IndexedGenericOp>(consumer)) {
+ return FuseTensorReshapeOpAsProducer<IndexedGenericOp>::fuse(
+ reshapeOpProducer, indexedGenericOpConsumer, consumerIdx, rewriter,
+ folder);
}
} else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
if (genericOpProducer.hasTensorSemantics())
return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
+ } else if (auto indexedGenericOpProducer =
+ dyn_cast<IndexedGenericOp>(producer)) {
+ if (indexedGenericOpProducer.hasTensorSemantics())
+ return FuseTensorReshapeOpAsConsumer<IndexedGenericOp>::fuse(
+ indexedGenericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
}
return nullptr;
}
// CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
// CHECK: linalg.yield %[[VAL4]] : i32
// CHECK-NOT: linalg.indexed_generic
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// Producer-fusion test: a tensor_reshape expanding tensor<?x?x?xi32> into
+// tensor<?x?x4x?xi32> feeds a linalg.indexed_generic consumer. After fusion
+// the reshape is expected to be folded away and the indexed_generic is
+// expected to read %arg0 directly through an indexing map that linearizes the
+// split middle dimension (d1 * 4 + d2), per the map patterns above.
+func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
+  -> tensor<?x?x4x?xi32> {
+  // Expand the middle dimension into (dynamic, 4).
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
+  // The payload adds the d0 iteration index (cast to i32) to each element, so
+  // the op genuinely depends on the loop indices — this is what distinguishes
+  // the indexed_generic fusion path from the plain generic one.
+  %1 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = addi %arg6, %2 : i32
+    linalg.yield %3 : i32
+  }: tensor<?x?x4x?xi32> -> tensor<?x?x4x?xi32>
+  return %1 : tensor<?x?x4x?xi32>
+}
+
+// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// Consumer-fusion test: a linalg.indexed_generic produces tensor<?x?x4x5xi32>
+// and its result is collapsed by a tensor_reshape into tensor<?x?xi32>.
+// After fusion the reshape is expected to disappear and the indexed_generic's
+// output map is expected to linearize the collapsed (j, k, l) dims using the
+// static sizes 4 and 5 (d1 * 20 + d2 * 5 + d3), per the map patterns above.
+func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
+  -> tensor<?x?xi32> {
+  // Payload adds the d0 iteration index (cast to i32) to each element, so the
+  // op truly uses its index arguments (indexed_generic, not plain generic).
+  %0 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %arg0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = addi %arg6, %2 : i32
+    linalg.yield %3 : i32
+  }: tensor<?x?x4x5xi32> -> tensor<?x?x4x5xi32>
+  // Collapse the trailing three dims (?, 4, 5) into one.
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x4x5xi32> into tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+
+// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape