Support AllocOp terminal in Linalg::AliasAnalysis.
authorNicolas Vasilache <ntv@google.com>
Mon, 7 Oct 2019 16:00:39 +0000 (09:00 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 7 Oct 2019 16:01:18 +0000 (09:01 -0700)
Now that linalg.view and strided memrefs are unified, there is no reason to
disallow AllocOp in alias analysis. This CLs adds support for AllocOp which allows writing shorter tests that do not require explicitly creating a view for
each operation.

PiperOrigin-RevId: 273303060

mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
mlir/lib/Dialect/Linalg/CMakeLists.txt
mlir/test/Dialect/Linalg/fusion.mlir

index 8e304be..14db309 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
 
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -64,6 +65,10 @@ Value *Aliases::find(Value *v) {
   while (true) {
     if (isa<BlockArgument>(v))
       return v;
+    if (auto alloc = dyn_cast_or_null<AllocOp>(v->getDefiningOp())) {
+      if (isStrided(alloc.getType()))
+        return alloc.getResult();
+    }
     if (auto slice = dyn_cast_or_null<SliceOp>(v->getDefiningOp())) {
       auto it = aliases.insert(std::make_pair(v, find(slice.view())));
       return it.first->second;
index bd86893..f50229b 100644 (file)
@@ -21,5 +21,6 @@ add_dependencies(MLIRLinalg
   MLIRAnalysis
   MLIRLinalgOpsIncGen
   MLIRLinalgLibraryOpsIncGen
+  MLIRStandardOps
   MLIRStandardToLLVM
   )
index 922417f..c2ef8eb 100644 (file)
@@ -366,3 +366,46 @@ func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?x
 //       CHECK:         addf
 //       CHECK:       linalg.generic
 //       CHECK:         mulf
+
+
+func @pointwise_no_view(%M: index, %N: index) {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %A = alloc (%M, %N): memref<?x?xf32>
+  %B = alloc (%M, %N): memref<?x?xf32>
+  %C = alloc (%M, %N): memref<?x?xf32>
+  %D = alloc (%M, %N): memref<?x?xf32>
+  %E = alloc (%M, %N): memref<?x?xf32>
+  linalg.generic #pointwise_2d_trait %A, %A, %B {
+  ^bb0(%e: f32, %arg5: f32, %arg6: f32):   // no predecessors
+    %2 = addf %e, %arg5 : f32
+    linalg.yield %2 : f32
+  }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  %0 = dim %B, 0 : memref<?x?xf32>
+  %1 = dim %B, 1 : memref<?x?xf32>
+  loop.for %e = %c0 to %0 step %c2 {
+    loop.for %arg5 = %c0 to %1 step %c3 {
+      %2 = affine.apply #map0(%e)
+      %3 = affine.apply #map1(%arg5)
+      %4 = linalg.subview %B[%e, %2, %c1, %arg5, %3, %c1] : memref<?x?xf32>
+      %5 = linalg.subview %C[%e, %2, %c1, %arg5, %3, %c1] : memref<?x?xf32>
+      %6 = linalg.subview %D[%e, %2, %c1, %arg5, %3, %c1] : memref<?x?xf32>
+      linalg.generic #pointwise_2d_trait %4, %5, %6 {
+      ^bb0(%arg6: f32, %arg7: f32, %arg8: f32):       // no predecessors
+        %7 = mulf %arg6, %arg7 : f32
+        linalg.yield %7 : f32
+      }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
+// CHECK-LABEL: func @pointwise_no_view
+//       CHECK:   loop.for
+//       CHECK:     loop.for
+//   CHECK-NOT:   loop.for
+//       CHECK:       linalg.generic
+//       CHECK:         addf
+//       CHECK:       linalg.generic
+//       CHECK:         mulf