Add support to AffineApplyOp::fold for folding dim and symbol expression results.
authorRiver Riddle <riverriddle@google.com>
Tue, 4 Jun 2019 21:12:40 +0000 (14:12 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:17:46 +0000 (16:17 -0700)
PiperOrigin-RevId: 251512700

mlir/examples/Linalg/Linalg3/Example.cpp
mlir/examples/Linalg/Linalg4/Example.cpp
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/AffineOps/AffineOps.cpp
mlir/test/AffineOps/canonicalize.mlir

index 69717e8..cf77785 100644 (file)
@@ -185,16 +185,16 @@ TEST_FUNC(matmul_as_matvec_as_affine) {
   //   CHECK-NOT: {{.*}} = linalg.
   //       CHECK:   affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
   //       CHECK:     affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
-  //       CHECK:       %4 = cmpi "eq", %i2, %c0 : index
-  //       CHECK:       %6 = load %arg2[%5, %3] : memref<?x?xf32>
-  //       CHECK:       %7 = select %4, %cst, %6 : f32
+  //       CHECK:       %3 = cmpi "eq", %i2, %c0 : index
+  //       CHECK:       %4 = load %arg2[%i1, %i0] : memref<?x?xf32>
+  //       CHECK:       %5 = select %3, %cst, %4 : f32
   //   CHECK-NOT: {{.*}} = linalg.
-  //       CHECK:       %9 = load %arg1[%8, %3] : memref<?x?xf32>
-  //       CHECK:       %10 = load %arg0[%5, %8] : memref<?x?xf32>
-  //       CHECK:       %11 = mulf %10, %9 : f32
-  //       CHECK:       %12 = addf %7, %11 : f32
+  //       CHECK:       %6 = load %arg1[%i2, %i0] : memref<?x?xf32>
+  //       CHECK:       %7 = load %arg0[%i1, %i2] : memref<?x?xf32>
+  //       CHECK:       %8 = mulf %7, %6 : f32
+  //       CHECK:       %9 = addf %5, %8 : f32
   //   CHECK-NOT: {{.*}} = linalg.
-  //       CHECK:       store %12, %arg2[%5, %3] : memref<?x?xf32>
+  //       CHECK:       store %9, %arg2[%i1, %i0] : memref<?x?xf32>
   // clang-format on
 }
 
index bb32758..73e7570 100644 (file)
@@ -83,16 +83,13 @@ TEST_FUNC(matmul_tiled_loops) {
   //       CHECK:       affine.for %i3 = max (d0)[s0] -> (s0, d0)(%i0)[%{{.*}}] to min (d0)[s0] -> (s0, d0 + 8)(%i0)[%[[M]]] {
   //       CHECK:         affine.for %i4 = max (d0)[s0] -> (s0, d0)(%i1)[%{{.*}}] to min (d0)[s0] -> (s0, d0 + 9)(%i1)[%[[N]]] {
   //  CHECK-NEXT:           %{{.*}} = cmpi "eq", %i2, %{{.*}} : index
-  //  CHECK-NEXT:           %[[I3:.*]] = affine.apply (d0) -> (d0)(%i3)
-  //  CHECK-NEXT:           %[[I4:.*]] = affine.apply (d0) -> (d0)(%i4)
-  //  CHECK-NEXT:           %{{.*}} = load %arg2[%[[I3]], %[[I4]]] : memref<?x?xf32>
+  //  CHECK-NEXT:           %{{.*}} = load %arg2[%i3, %i4] : memref<?x?xf32>
   //  CHECK-NEXT:           %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : f32
-  //  CHECK-NEXT:           %[[I2:.*]] = affine.apply (d0) -> (d0)(%i2)
-  //  CHECK-NEXT:           %{{.*}} = load %arg1[%[[I2]], %[[I4]]] : memref<?x?xf32>
-  //  CHECK-NEXT:           %{{.*}} = load %arg0[%[[I3]], %[[I2]]] : memref<?x?xf32>
-  //  CHECK-NEXT:           %{{.*}} = mulf %10, %9 : f32
-  //  CHECK-NEXT:           %{{.*}} = addf %7, %11 : f32
-  //  CHECK-NEXT:           store %{{.*}}, %arg2[%[[I3]], %[[I4]]] : memref<?x?xf32>
+  //  CHECK-NEXT:           %{{.*}} = load %arg1[%i2, %i4] : memref<?x?xf32>
+  //  CHECK-NEXT:           %{{.*}} = load %arg0[%i3, %i2] : memref<?x?xf32>
+  //  CHECK-NEXT:           %{{.*}} = mulf %7, %6 : f32
+  //  CHECK-NEXT:           %{{.*}} = addf %5, %8 : f32
+  //  CHECK-NEXT:           store %{{.*}}, %arg2[%i3, %i4] : memref<?x?xf32>
   // clang-format on
 }
 
@@ -112,16 +109,14 @@ TEST_FUNC(matmul_tiled_views) {
   //       CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
   //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) step 8 {
   //  CHECK-NEXT:   affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 9 {
-  //  CHECK-NEXT:     %[[i0min:.*]] = affine.apply (d0) -> (d0)(%i0)
   //  CHECK-NEXT:     %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0)
-  //  CHECK-NEXT:     %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg.range
+  //  CHECK-NEXT:     %[[ri0:.*]] = linalg.range %i0:%[[i0max]]:{{.*}} : !linalg.range
   //       CHECK:     %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
   //       CHECK:     %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
-  //       CHECK:     %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1)
-  //  CHECK-NEXT:     %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
-  //  CHECK-NEXT:     %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg.range
-  //  CHECK-NEXT:     %[[vB:.*]]  = linalg.view %arg1[%10, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
-  //  CHECK-NEXT:     %[[vC:.*]]  = linalg.view %arg2[%5, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+  //       CHECK:     %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
+  //  CHECK-NEXT:     %[[ri1:.*]] = linalg.range %i1:%[[i1max]]:%{{.*}} : !linalg.range
+  //  CHECK-NEXT:     %[[vB:.*]]  = linalg.view %arg1[%7, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+  //  CHECK-NEXT:     %[[vC:.*]]  = linalg.view %arg2[%4, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
   //  CHECK-NEXT:     linalg.matmul(%[[vA]], %[[vB]], %[[vC]]) : !linalg.view<?x?xf32>
   // clang-format on
   cleanupAndPrintFunction(f);
@@ -148,16 +143,14 @@ TEST_FUNC(matmul_tiled_views_as_loops) {
   //       CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
   //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) step 8 {
   //  CHECK-NEXT:   affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) step 9 {
-  //  CHECK-NEXT:     %[[i0min:.*]] = affine.apply (d0) -> (d0)(%i0)
   //  CHECK-NEXT:     %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0)
-  //  CHECK-NEXT:     %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg.range
+  //  CHECK-NEXT:     %[[ri0:.*]] = linalg.range %i0:%[[i0max]]:{{.*}} : !linalg.range
   //       CHECK:     %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
   //       CHECK:     %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
-  //       CHECK:     %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1)
-  //  CHECK-NEXT:     %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
-  //  CHECK-NEXT:     %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg.range
-  //  CHECK-NEXT:     %[[vB:.*]]  = linalg.view %arg1[%10, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
-  //  CHECK-NEXT:     %[[vC:.*]]  = linalg.view %arg2[%5, %13] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+  //       CHECK:     %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
+  //  CHECK-NEXT:     %[[ri1:.*]] = linalg.range %i1:%[[i1max]]:%{{.*}} : !linalg.range
+  //  CHECK-NEXT:     %[[vB:.*]]  = linalg.view %arg1[%7, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+  //  CHECK-NEXT:     %[[vC:.*]]  = linalg.view %arg2[%4, %9] : memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
   //  CHECK-NEXT:     affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0)(%[[i0max]]) {
   //  CHECK-NEXT:       affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0)(%[[i1max]]) {
   //  CHECK-NEXT:         affine.for %i4 = 0 to (d0) -> (d0)(%[[K]]) {
index 7a6b861..ea448cc 100644 (file)
@@ -241,7 +241,10 @@ public:
     if (!result)
       return failure();
 
-    results.push_back(result);
+    // Check if the operation was folded in place. In this case, the operation
+    // returns itself.
+    if (result.template dyn_cast<Value *>() != op->getResult(0))
+      results.push_back(result);
     return success();
   }
 
index f6c0441..28594a3 100644 (file)
@@ -203,6 +203,15 @@ bool AffineApplyOp::isValidSymbol() {
 
 OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
   auto map = getAffineMap();
+
+  // Fold dims and symbols to existing values.
+  auto expr = map.getResult(0);
+  if (auto dim = expr.dyn_cast<AffineDimExpr>())
+    return getOperand(dim.getPosition());
+  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
+    return getOperand(map.getNumDims() + sym.getPosition());
+
+  // Otherwise, default to folding the map.
   SmallVector<Attribute, 1> result;
   if (failed(map.constantFold(operands, result)))
     return {};
index 90f6aed..f6d15a7 100644 (file)
@@ -22,9 +22,6 @@
 // CHECK-DAG: [[MAP13A:#map[0-9]+]] = (d0) -> ((d0 + 6) ceildiv 8)
 // CHECK-DAG: [[MAP13B:#map[0-9]+]] = (d0) -> ((d0 * 4 - 4) floordiv 3)
 
-// Affine maps for test case: arg_used_as_dim_and_symbol
-// CHECK-DAG: [[MAP14:#map[0-9]+]] = (d0) -> (d0)
-
 // Affine maps for test case: partial_fold_map
 // CHECK-DAG: [[MAP15:#map[0-9]+]] = ()[s0, s1] -> (s0 - s1)
 
@@ -55,8 +52,7 @@ func @compose_affine_maps_1dto2d_no_symbols() {
     %x1_1 = affine.apply (d0, d1) -> (d1) (%x0, %x0)
 
     // CHECK: [[I0A:%[0-9]+]] = affine.apply [[MAP0]](%i0)
-    // CHECK-NEXT: [[I0B:%[0-9]+]] = affine.apply [[MAP0]](%i0)
-    // CHECK-NEXT: load %0{{\[}}[[I0A]], [[I0B]]{{\]}}
+    // CHECK-NEXT: load %0{{\[}}[[I0A]], [[I0A]]{{\]}}
     %v0 = load %0[%x1_0, %x1_1] : memref<4x4xf32>
 
     // Test load[%y, %y]
@@ -65,25 +61,20 @@ func @compose_affine_maps_1dto2d_no_symbols() {
     %y1_1 = affine.apply (d0, d1) -> (d1) (%y0, %y0)
 
     // CHECK-NEXT: [[I1A:%[0-9]+]] = affine.apply [[MAP1]](%i0)
-    // CHECK-NEXT: [[I1B:%[0-9]+]] = affine.apply [[MAP1]](%i0)
-    // CHECK-NEXT: load %0{{\[}}[[I1A]], [[I1B]]{{\]}}
+    // CHECK-NEXT: load %0{{\[}}[[I1A]], [[I1A]]{{\]}}
     %v1 = load %0[%y1_0, %y1_1] : memref<4x4xf32>
 
     // Test load[%x, %y]
     %xy_0 = affine.apply (d0, d1) -> (d0) (%x0, %y0)
     %xy_1 = affine.apply (d0, d1) -> (d1) (%x0, %y0)
 
-    // CHECK-NEXT: [[I2A:%[0-9]+]] = affine.apply [[MAP0]](%i0)
-    // CHECK-NEXT: [[I2B:%[0-9]+]] = affine.apply [[MAP1]](%i0)
-    // CHECK-NEXT: load %0{{\[}}[[I2A]], [[I2B]]{{\]}}
+    // CHECK-NEXT: load %0{{\[}}[[I0A]], [[I1A]]{{\]}}
     %v2 = load %0[%xy_0, %xy_1] : memref<4x4xf32>
 
     // Test load[%y, %x]
     %yx_0 = affine.apply (d0, d1) -> (d0) (%y0, %x0)
     %yx_1 = affine.apply (d0, d1) -> (d1) (%y0, %x0)
-    // CHECK-NEXT: [[I3A:%[0-9]+]] = affine.apply [[MAP1]](%i0)
-    // CHECK-NEXT: [[I3B:%[0-9]+]] = affine.apply [[MAP0]](%i0)
-    // CHECK-NEXT: load %0{{\[}}[[I3A]], [[I3B]]{{\]}}
+    // CHECK-NEXT: load %0{{\[}}[[I1A]], [[I0A]]{{\]}}
     %v3 = load %0[%yx_0, %yx_1] : memref<4x4xf32>
   }
   return
@@ -238,8 +229,7 @@ func @arg_used_as_dim_and_symbol(%arg0: memref<100x100xf32>, %arg1: index) {
         (%i0, %i1)[%arg1, %c9]
       %4 = affine.apply (d0, d1, d3) -> (d3 - (d0 + d1))
         (%arg1, %c9, %3)
-      // CHECK: [[I0:%[0-9]+]] = affine.apply [[MAP14]](%i1)
-      // CHECK-NEXT: load %{{[0-9]+}}{{\[}}[[I0]], %arg1{{\]}}
+      // CHECK: load %{{[0-9]+}}{{\[}}%i1, %arg1{{\]}}
       %5 = load %1[%4, %arg1] : memref<100x100xf32, 1>
     }
   }