DimOp folding for alloc/view dynamic dimensions
authorUday Bondhugula <uday@polymagelabs.com>
Fri, 6 Dec 2019 13:59:06 +0000 (05:59 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Dec 2019 14:00:54 +0000 (06:00 -0800)
Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Closes tensorflow/mlir#253

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/253 from bondhugula:dimop a4b464f24ae63fd259114558d87e11b8ee4dae86
PiperOrigin-RevId: 284169689

mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/Transforms/canonicalize.mlir

index 0e2bee0..a9e9364 100644 (file)
@@ -1364,11 +1364,26 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   else if (auto memrefType = opType.dyn_cast<MemRefType>())
     indexSize = memrefType.getShape()[getIndex()];
 
-  if (indexSize >= 0)
+  if (!ShapedType::isDynamic(indexSize))
     return IntegerAttr::get(IndexType::get(getContext()), indexSize);
 
-  // Fold dim to the size argument of a SubViewOp.
+  // Fold dim to the size argument for an AllocOp/ViewOp/SubViewOp.
+  auto memrefType = opType.dyn_cast<MemRefType>();
+  if (!memrefType)
+    return {};
+
+  // The size at getIndex() is now a dynamic size of a memref.
+
   auto memref = memrefOrTensor()->getDefiningOp();
+  if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
+    return *(alloc.getDynamicSizes().begin() +
+             memrefType.getDynamicDimIndex(getIndex()));
+
+  if (auto view = dyn_cast_or_null<ViewOp>(memref))
+    return *(view.getDynamicSizes().begin() +
+             memrefType.getDynamicDimIndex(getIndex()));
+
+  // The subview op here is expected to have rank dynamic sizes now.
   if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
     auto sizes = subview.sizes();
     if (!sizes.empty())
index 20637e6..e73e658 100644 (file)
@@ -46,10 +46,7 @@ func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %d
   }
   // CHECK: %[[tensor:[0-9]+]] = alloc
   // CHECK-NOT: {{.*}} dim %[[tensor]], 0
-  // CHECK: {{.*}} dim %[[tensor]], 1
-  // CHECK: {{.*}} dim %[[tensor]], 2
   // CHECK-NOT: {{.*}} dim %[[tensor]], 3
-  // CHECK: {{.*}} dim %[[tensor]], 4
   return
 }
 
@@ -66,36 +63,32 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
   // CHECK-NEXT:    affine.for %[[I1:.*]] = 0 to %{{.*}} {
   // CHECK-NEXT:      affine.for %[[I2:.*]] = 0 to %{{.*}} {
   // CHECK-NEXT:        affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
-  //      CHECK:          %[[D0:.*]] = dim %{{.*}}, 0 : memref<?x?x?x?xf32>
-  // CHECK-NEXT:          %[[D1:.*]] = dim %{{.*}}, 1 : memref<?x?x?x?xf32>
-  // CHECK-NEXT:          %[[D2:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32>
-  // CHECK-NEXT:          %[[D3:.*]] = dim %{{.*}}, 3 : memref<?x?x?x?xf32>
   //      CHECK:          %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32>
   // CHECK-NEXT:          %[[VECTOR_VIEW:.*]] = vector.type_cast %[[ALLOC]] : memref<5x4x3xf32>
   // CHECK-NEXT:          loop.for %[[I4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
   // CHECK-NEXT:            loop.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
   // CHECK-NEXT:              loop.for %[[I6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] {
   // CHECK-NEXT:                {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
-  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%[[D0]]]
+  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}} : index
   // CHECK-NEXT:                {{.*}} = select
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
   // CHECK-NEXT:                %[[L0:.*]] = select
   //
-  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%[[D1]]]
+  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}} : index
   // CHECK-NEXT:                {{.*}} = select
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
   // CHECK-NEXT:                %[[L1:.*]] = select
   //
-  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%[[D2]]]
+  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}} : index
   // CHECK-NEXT:                {{.*}} = select
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
   // CHECK-NEXT:                %[[L2:.*]] = select
   //
   // CHECK-NEXT:                {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
-  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%[[D3]]]
+  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}} : index
   // CHECK-NEXT:                {{.*}} = select
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
@@ -144,10 +137,6 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
   // CHECK-NEXT:    affine.for %[[I1:.*]] = 0 to %{{.*}} step 4 {
   // CHECK-NEXT:      affine.for %[[I2:.*]] = 0 to %{{.*}} {
   // CHECK-NEXT:        affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
-  //      CHECK:          %[[D0:.*]] = dim %{{.*}}, 0 : memref<?x?x?x?xf32>
-  // CHECK-NEXT:          %[[D1:.*]] = dim %{{.*}}, 1 : memref<?x?x?x?xf32>
-  // CHECK-NEXT:          %[[D2:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32>
-  // CHECK-NEXT:          %[[D3:.*]] = dim %{{.*}}, 3 : memref<?x?x?x?xf32>
   // CHECK:               %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32>
   // CHECK-NEXT:          %[[VECTOR_VIEW:.*]] = vector.type_cast {{.*}} : memref<5x4x3xf32>
   //      CHECK:          store %{{.*}}, {{.*}} : memref<vector<5x4x3xf32>>
@@ -155,27 +144,27 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
   // CHECK-NEXT:            loop.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
   // CHECK-NEXT:              loop.for %[[I6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] {
   // CHECK-NEXT:                {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
-  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%[[D0]]]
+  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
   // CHECK-NEXT:                {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
   // CHECK-NEXT:                %[[S0:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
   //
   // CHECK-NEXT:                {{.*}} = affine.apply #[[ADD]](%[[I1]], %[[I5]])
-  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%[[D1]]]
+  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
   // CHECK-NEXT:                {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
   // CHECK-NEXT:                %[[S1:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
   //
-  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%[[D2]]]
+  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
   // CHECK-NEXT:                {{.*}} = cmpi "slt", %[[I2]], %{{.*}} : index
   // CHECK-NEXT:                {{.*}} = select {{.*}}, %[[I2]], {{.*}} : index
   // CHECK-NEXT:                {{.*}} = cmpi "slt", %[[I2]], %[[C0]] : index
   // CHECK-NEXT:                %[[S2:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
   //
   // CHECK-NEXT:                {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
-  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%[[D3]]]
+  // CHECK-NEXT:                {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
   // CHECK-NEXT:                {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
   // CHECK-NEXT:                {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
index 7fa5594..933280b 100644 (file)
@@ -22,13 +22,13 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
-// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
+// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>,
+// CHECK-SAME: [[M:arg[0-9]+]]: index
+// CHECK-SAME: [[N:arg[0-9]+]]: index
+// CHECK-SAME: [[K:arg[0-9]+]]: index
 //       CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
 //       CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
 //       CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
-//       CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
-//       CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
-//       CHECK: %[[N:.*]] = dim %[[B]], 1 : memref<?x?xf32, #[[strided2D]]>
 //       CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} {
 //       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
 //       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
@@ -48,12 +48,12 @@ func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
   linalg.matvec(%2, %3, %4) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>
   return
 }
-// CHECK-LABEL: func @matvec(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index) {
+// CHECK-LABEL: func @matvec(%{{.*}}: memref<?xi8>,
+// CHECK-SAME: [[M:arg[0-9]+]]: index
+// CHECK-SAME: [[K:arg[0-9]+]]: index
 //       CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
 //       CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
 //       CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
-//       CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
-//       CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
 //       CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} {
 //       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
 //   CHECK-DAG:     %[[a:.*]] = load %[[A]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
@@ -72,11 +72,11 @@ func @dot(%arg0: memref<?xi8>, %M: index) {
   linalg.dot(%1, %2, %3) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
   return
 }
-// CHECK-LABEL: func @dot(%{{.*}}: memref<?xi8>, %{{.*}}: index) {
+// CHECK-LABEL: func @dot(%{{.*}}: memref<?xi8>,
+// CHECK-SAME: [[K:arg[0-9]+]]: index
 //       CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
 //       CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
 //       CHECK: %[[C:.*]] = std.view %{{.*}}[][] : memref<?xi8> to memref<f32>
-//       CHECK: %[[K:.*]] = dim %[[A]], 0 : memref<?xf32, #[[strided1D]]>
 //       CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
 //   CHECK-DAG:   %[[a:.*]] = load %[[A]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
 //   CHECK-DAG:   %[[b:.*]] = load %[[B]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
index f6840ce..07a7e7c 100644 (file)
@@ -418,6 +418,62 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x
   return %c, %d : memref<? x ? x i32>, memref<? x ? x f32>
 }
 
+#map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)
+#map2 = (d0, d1, d2)[s0, s1, s2] -> (d0 * s2 + d1 * s1 + d2 + s0)
+
+// CHECK-LABEL: func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index,
+func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index, %BUF: memref<?xi8>, %M : index, %N : index, %K : index) {
+// CHECK-SAME: [[M:arg[0-9]+]]: index
+// CHECK-SAME: [[N:arg[0-9]+]]: index
+// CHECK-SAME: [[K:arg[0-9]+]]: index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = alloc(%arg0, %arg1) : memref<?x?xf32>
+  %1 = alloc(%arg1, %arg2) : memref<?x8x?xf32>
+  %2 = dim %1, 2 : memref<?x8x?xf32>
+  affine.for %arg3 = 0 to %2 {
+    %3 = alloc(%arg0) : memref<?xi8>
+    %ub = dim %3, 0 : memref<?xi8>
+    affine.for %arg4 = 0 to %ub {
+      %s = dim %0, 0 : memref<?x?xf32>
+      %v = std.view %3[%c0][%arg4, %s] : memref<?xi8> to memref<?x?xf32, #map1>
+      %sv = std.subview %0[%c0, %c0][%s,%arg4][%c1,%c1] : memref<?x?xf32> to memref<?x?xf32, #map1>
+      %l = dim %v, 1 : memref<?x?xf32, #map1>
+      %u = dim %sv, 0 : memref<?x?xf32, #map1>
+      affine.for %arg5 = %l to %u {
+        "foo"() : () -> ()
+      }
+    }
+  }
+  // CHECK-NEXT: %c0 = constant 0 : index
+  // CHECK-NEXT: %c1 = constant 1 : index
+  // CHECK-NEXT: affine.for %arg7 = 0 to %arg2 {
+  // CHECK-NEXT:   affine.for %arg8 = 0 to %arg0 {
+  // CHECK-NEXT:     affine.for %arg9 = %arg0 to %arg0 {
+  // CHECK-NEXT:       "foo"() : () -> ()
+  // CHECK-NEXT:     }
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: }
+
+  %A = view %BUF[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %B = view %BUF[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %C = view %BUF[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+
+  %M_ = dim %A, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %K_ = dim %A, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %N_ = dim %C, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+  loop.for %i = %c0 to %M_ step %c1 {
+    loop.for %j = %c0 to %N_ step %c1 {
+      loop.for %k = %c0 to %K_ step %c1 {
+      }
+    }
+  }
+  // CHECK: loop.for %{{.*}} = %c0 to %[[M]] step %c1 {
+  // CHECK:   loop.for %arg8 = %c0 to %[[N]] step %c1 {
+  // CHECK:     loop.for %arg9 = %c0 to %[[K]] step %c1 {
+  return
+}
+
 // CHECK-LABEL: func @merge_constants
 func @merge_constants() -> (index, index) {
   // CHECK-NEXT: %c42 = constant 42 : index
@@ -743,7 +799,7 @@ func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
   load %4[%c0, %c0, %c0] : memref<?x?x?xf32,
        (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
 
-  // Test: subview offset operands are folded correctly w.r.t. base strides. 
+  // Test: subview offset operands are folded correctly w.r.t. base strides.
   // CHECK: std.subview %[[ALLOC0]][][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP1]]>
   %5 = subview %0[%c1, %c2, %c7][%c7, %c11, %c2][%c1, %c1, %c1]
     : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to