[mlir][Linalg] Drop check for output indexing maps.
authorMahesh Ravishankar <ravishankarm@google.com>
Fri, 26 Aug 2022 16:15:08 +0000 (16:15 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Fri, 26 Aug 2022 16:15:55 +0000 (16:15 +0000)
The current check for form of the output indexing maps disallows
generic ops that return both a reduced and unreduced value. Such an op
seems like it could fall within the scope of a Strucutred op. Drop the
check. The only load-bearing place this was found to cause isseus was
during vectorization, but the fix for that seems to be simple.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D132716

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/vectorization.mlir

index b940e41..d2d8003 100644 (file)
@@ -683,27 +683,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   SmallVector<unsigned> redDims;
   linalgOp.getReductionDims(redDims);
 
-  // Output tensor indexing map may not depend on reduction indices.
-  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
-    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
-    for (AffineExpr expr : indexingMap.getResults()) {
-      for (unsigned pos : redDims) {
-        if (expr.isFunctionOfDim(pos)) {
-          std::string exprStr;
-          {
-            llvm::raw_string_ostream os(exprStr);
-            os << expr;
-          }
-          return op->emitOpError(
-                     "unexpected output tensor expression in indexing map #")
-                 << (opOperand->getOperandNumber() - linalgOp.getNumInputs())
-                 << " a.k.a '" << exprStr
-                 << "' is function of reduction iterator 'd" << pos << "'";
-        }
-      }
-    }
-  }
-
   if (!linalgOp.getShapesToLoopsMap())
     return op->emitOpError("expected the shape-to-loops map to be non-null");
 
index 6975937..94f338a 100644 (file)
@@ -539,6 +539,10 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
     return failure();
   }
   for (OpOperand *opOperand : op.getOutputOperands()) {
+    AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
+    if (indexingMap.isPermutation())
+      continue;
+
     Operation *reduceOp = matchLinalgReduction(opOperand);
     if (!reduceOp || !getCombinerOpKind(reduceOp)) {
       LDBG("reduction precondition failed: reduction detection failed");
index 6881dce..a7fe2f0 100644 (file)
@@ -269,21 +269,6 @@ func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->
 
 // -----
 
-func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
-                                 %arg1: tensor<?xf32>) {
-  // expected-error @+1 {{unexpected output tensor expression in indexing map #0 a.k.a 'd0' is function of reduction iterator 'd0'}}
-  %0 = linalg.generic {
-    indexing_maps = [ affine_map<(i) -> (i)> , affine_map<(i) -> (i)> ],
-    iterator_types = ["reduction"]}
-       ins(%arg0 : memref<?xf32, affine_map<(i)[off]->(off + i)>>)
-      outs(%arg1 : tensor<?xf32>) {
-    ^bb(%i: f32, %j: f32):
-      linalg.yield %i: f32
-  } -> tensor<?xf32>
-}
-
-// -----
-
 func.func @generic(%arg0: memref<?x?xf32>) {
   // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32}}
   linalg.generic  {
index ba21a14..70857b7 100644 (file)
@@ -350,3 +350,24 @@ func.func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?x
   return %1 : tensor<?x?xf32>
 }
 // CHECK: %{{.+}} = linalg.fill ins(%{{.+}} : f32) outs(%{{.+}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+// -----
+
+func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
+    %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?x?xf32>, %arg3 : tensor<?x?xf32>) ->
+    (tensor<?x?x?xf32>, tensor<?x?xf32>) {
+  %0:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2, %arg3 : tensor<?x?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32):
+      %1 = arith.mulf %b0, %b1 : f32
+      %2 = arith.addf %1, %b3 : f32
+      linalg.yield %1, %2 : f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @mixed_parallel_reduced_results
+//       CHECK:     linalg.generic
index bbc36b1..2295305 100644 (file)
@@ -1102,3 +1102,35 @@ func.func @not_projected_permutation(%arg0: tensor<8x8xf32>) -> tensor<6x6x3x3xf
     } -> tensor<6x6x3x3xf32>
   return %result : tensor<6x6x3x3xf32>
 }
+
+// -----
+
+// Check vectorization can handle cases where outputs are a mix of reduced and non-reduced values.
+func.func @mixed_parallel_reduced_results(%arg0 : tensor<2x4x8xf32>,
+    %arg1 : tensor<2x4xf32>, %arg2 : tensor<2x4x8xf32>, %arg3 : tensor<2x4xf32>) ->
+    (tensor<2x4x8xf32>, tensor<2x4xf32>) {
+  %0:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%arg0, %arg1 : tensor<2x4x8xf32>, tensor<2x4xf32>)
+      outs(%arg2, %arg3 : tensor<2x4x8xf32>, tensor<2x4xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32):
+      %1 = arith.mulf %b0, %b1 : f32
+      %2 = arith.addf %1, %b3 : f32
+      linalg.yield %1, %2 : f32, f32
+  } -> (tensor<2x4x8xf32>, tensor<2x4xf32>)
+  return %0#0, %0#1 : tensor<2x4x8xf32>, tensor<2x4xf32>
+}
+// CHECK-LABEL: func @mixed_parallel_reduced_results(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x4x8xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<2x4x8xf32>
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+//   CHECK-DAG:   %[[V0:.+]] = vector.transfer_read %[[ARG0]]
+//   CHECK-DAG:   %[[V1:.+]] = vector.transfer_read %[[ARG1]]
+//   CHECK-DAG:   %[[V2:.+]] = vector.transfer_read %[[ARG3]]
+//   CHECK-DAG:   %[[MUL:.+]] = arith.mulf %[[V0]], %[[V1]]
+//   CHECK-DAG:   %[[ADD:.+]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]]
+//   CHECK-DAG:   vector.transfer_write %[[MUL]], %[[ARG2]]
+//   CHECK-DAG:   vector.transfer_write %[[ADD]], %[[ARG3]]