[mlir][Linalg] Disable const -> linalg.generic when fused op is illegal.
authorMaheshRavishankar <ravishankarm@google.com>
Mon, 12 Apr 2021 15:49:45 +0000 (08:49 -0700)
committerMaheshRavishankar <ravishankarm@google.com>
Mon, 12 Apr 2021 17:15:54 +0000 (10:15 -0700)
Fusing a constant with a linalg.generic operation can result in the
fused operation being illegal since the loop bound computation
fails. Avoid such fusions.

Differential Revision: https://reviews.llvm.org/D100272

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir

index 713de7b..a404cbd 100644 (file)
@@ -1103,6 +1103,12 @@ public:
           linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
       fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
 
+      // Check if the operation shapes to loops map is computable.
+      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
+        return rewriter.notifyMatchFailure(
+            linalgOp, "fused op loop bound computation failed");
+      }
+
       // The operands list is same as the linalgOp with the argument for
       // constant index dropped.
       SmallVector<Value, 4> fusedOperands(linalgOp.getInputs());
index 7983fe1..00d0995 100644 (file)
@@ -678,3 +678,26 @@ func @index_op(%arg0: tensor<1x8xindex>, %arg1: tensor<1x8xindex>) -> tensor<1x8
   } -> tensor<1x8xindex>
   return %1 : tensor<1x8xindex>
 }
+
+// -----
+
+// CHECK-LABEL: func @no_fuse_constant_with_reduction
+func @no_fuse_constant_with_reduction() -> tensor<3xf32>
+{
+  //      CHECK: %[[CONST:.+]] = constant {{.+}} : tensor<3x2xf32>
+  //      CHECK: %[[RESULT:.+]] = linalg.generic
+  // CHECK-SAME:   ins(%[[CONST]] : tensor<3x2xf32>)
+  //      CHECK: return %[[RESULT]]
+  %three = constant dense<3.0> : tensor<3x2xf32>
+  %init = linalg.init_tensor [3] : tensor<3xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+     ins(%three : tensor<3x2xf32>) outs(%init : tensor<3xf32>) {
+     ^bb0(%arg0 : f32, %arg1 : f32):
+        %0 = addf %arg0, %arg1 : f32
+        linalg.yield %0 : f32
+  } -> tensor<3xf32>
+  return %result : tensor<3xf32>
+}