Fix SubViewOp stride calculation in constant folding.
authorAndy Davis <andydavis@google.com>
Mon, 18 Nov 2019 23:00:34 +0000 (15:00 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 18 Nov 2019 23:01:08 +0000 (15:01 -0800)
Adds unit tests for subview offset and stride argument constant folding.

PiperOrigin-RevId: 281161041

mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/test/Transforms/canonicalize.mlir

index e38ce06..c2195ae 100644 (file)
@@ -2741,13 +2741,15 @@ struct SubViewOpShapeFolder : public OpRewritePattern<SubViewOp> {
       dynamicDimPos++;
     }
 
-    // Compute new strides based on 'newShapeConstants'.
+    // Compute new strides based on 'baseStrides' and SubViewOp stride args.
+    SmallVector<Value *, 4> viewStrides(subViewOp.getDynamicStrides().begin(),
+                                        subViewOp.getDynamicStrides().end());
+    assert(viewStrides.size() == baseStrides.size());
     SmallVector<int64_t, 4> newSubViewStrides(rank);
-    newSubViewStrides[rank - 1] = 1;
-    for (int i = rank - 2; i >= 0; --i) {
-      assert(!ShapedType::isDynamic(newShapeConstants[i + 1]));
-      newSubViewStrides[i] =
-          newShapeConstants[i + 1] * newSubViewStrides[i + 1];
+    for (unsigned i = 0, e = viewStrides.size(); i < e; ++i) {
+      int64_t viewStride =
+          cast<ConstantIndexOp>(viewStrides[i]->getDefiningOp()).getValue();
+      newSubViewStrides[i] = baseStrides[i] * viewStride;
     }
 
     // Regenerate strided layout map with 'newSubViewStrides' and
index dadecb5..f05810b 100644 (file)
@@ -685,8 +685,9 @@ func @view(%arg0 : index) {
 // -----
 
 // CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)
-// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = (d0, d1, d2) -> (d0 * 165 + d1 * 15 + d2)
+// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 79)
 // CHECK-DAG: #[[SUBVIEW_MAP1:map[0-9]+]] = (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)
+// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = (d0, d1, d2) -> (d0 * 128 + d1 * 28 + d2 * 11)
 
 // CHECK-LABEL: func @subview
 func @subview(%arg0 : index) -> (index, index) {
@@ -694,6 +695,7 @@ func @subview(%arg0 : index) -> (index, index) {
   %c0 = constant 0 : index
   // CHECK: %[[C1:.*]] = constant 1 : index
   %c1 = constant 1 : index
+  %c2 = constant 2 : index   
   // CHECK: %[[C7:.*]] = constant 7 : index
   %c7 = constant 7 : index
   // CHECK: %[[C11:.*]] = constant 11 : index
@@ -705,8 +707,10 @@ func @subview(%arg0 : index) -> (index, index) {
   %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
 
   // Test: subview with constant base memref and constant operands is folded.
-  // CHECK: std.subview %[[ALLOC0]][][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]>
-  %1 = subview %0[%c0, %c0, %c0][%c7, %c11, %c15][%c1, %c1, %c1]
+  // Note that the subview uses the base memrefs layout map because it used
+  // zero offset and unit stride arguments.
+  // CHECK: std.subview %[[ALLOC0]][][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[BASE_MAP0]]>
+  %1 = subview %0[%c0, %c0, %c0][%c7, %c11, %c2][%c1, %c1, %c1]
     : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
       memref<?x?x?xf32,
        (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
@@ -733,12 +737,30 @@ func @subview(%arg0 : index) -> (index, index) {
   load %4[%c0, %c0, %c0] : memref<?x?x?xf32,
        (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
 
+  // Test: subview offset operands are folded correctly w.r.t. base strides. 
+  // CHECK: std.subview %[[ALLOC0]][][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP0]]>
+  %5 = subview %0[%c1, %c2, %c7][%c7, %c11, %c2][%c1, %c1, %c1]
+    : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+      memref<?x?x?xf32,
+       (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+  load %5[%c0, %c0, %c0] : memref<?x?x?xf32,
+       (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+
+  // Test: subview stride operands are folded correctly w.r.t. base strides.
+  // CHECK: std.subview %[[ALLOC0]][][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP2]]>
+  %6 = subview %0[%c0, %c0, %c0][%c7, %c11, %c2][%c2, %c7, %c11]
+    : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
+      memref<?x?x?xf32,
+       (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+  load %6[%c0, %c0, %c0] : memref<?x?x?xf32,
+       (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
+
   // Test: dim on subview is rewritten to size operand.
-  %5 = dim %4, 0 : memref<?x?x?xf32,
+  %7 = dim %4, 0 : memref<?x?x?xf32,
        (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
-  %6 = dim %4, 1 : memref<?x?x?xf32,
+  %8 = dim %4, 1 : memref<?x?x?xf32,
        (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
 
   // CHECK: return %[[C7]], %[[C11]]
-  return %5, %6 : index, index
+  return %7, %8 : index, index
 }