[mlir][std] Fold dim(dynamic_tensor_from_elements, %cst)
authorStephan Herhut <herhut@google.com>
Tue, 17 Nov 2020 12:39:08 +0000 (13:39 +0100)
committerStephan Herhut <herhut@google.com>
Tue, 17 Nov 2020 13:39:59 +0000 (14:39 +0100)
The shape of the result of a dynamic_tensor_from_elements is defined via its
result type and operands. We already fold dim operations when they reference
one of the statically sized dimensions. Now, also fold dim on the dynamically
sized dimensions by picking the corresponding operand.

Differential Revision: https://reviews.llvm.org/D91616

mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir

index 629c8e4..4698893 100644 (file)
@@ -1484,6 +1484,25 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
     return getResult();
   }
 
+  // Fold dim to the operand of dynamic_tensor_from_elements.
+  if (auto fromElements =
+          dyn_cast_or_null<DynamicTensorFromElementsOp>(definingOp)) {
+    auto resultType =
+        fromElements.getResult().getType().cast<RankedTensorType>();
+    // The case where the type encodes the size of the dimension is handled
+    // above.
+    assert(resultType.getShape()[index.getInt()] ==
+           RankedTensorType::kDynamicSize);
+
+    // Find the operand of the fromElements that corresponds to this index.
+    auto dynExtents = fromElements.dynamicExtents().begin();
+    for (auto dim : resultType.getShape().take_front(index.getInt()))
+      if (dim == RankedTensorType::kDynamicSize)
+        dynExtents++;
+
+    return Value{*dynExtents};
+  }
+
   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
   auto memrefType = argTy.dyn_cast<MemRefType>();
   if (!memrefType)
index 1589dc1..1e2e4a5 100644 (file)
@@ -5,9 +5,9 @@
 // CHECK-SAME:                                          %[[TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
 // CHECK:           return %[[TENSOR]]
 func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-    %0 = tensor_to_memref %arg0 : memref<?xf32>
-    %1 = tensor_load %0 : memref<?xf32>
-    return %1 : tensor<?xf32>
+  %0 = tensor_to_memref %arg0 : memref<?xf32>
+  %1 = tensor_load %0 : memref<?xf32>
+  return %1 : tensor<?xf32>
 }
 
 // Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m
@@ -15,9 +15,9 @@ func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
 // CHECK-SAME:                                          %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
 // CHECK:           return %[[MEMREF]]
 func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
-    %0 = tensor_load %arg0 : memref<?xf32>
-    %1 = tensor_to_memref %0 : memref<?xf32>
-    return %1 : memref<?xf32>
+  %0 = tensor_load %arg0 : memref<?xf32>
+  %1 = tensor_to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
 }
 
 // Test case: If the memrefs are not the same type, don't fold them.
@@ -27,9 +27,9 @@ func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
 // CHECK:           %[[MEMREF_ADDRSPACE7:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32, 7>
 // CHECK:           return %[[MEMREF_ADDRSPACE7]]
 func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref<?xf32, 7> {
-    %0 = tensor_load %arg0 : memref<?xf32, 2>
-    %1 = tensor_to_memref %0 : memref<?xf32, 7>
-    return %1 : memref<?xf32, 7>
+  %0 = tensor_load %arg0 : memref<?xf32, 2>
+  %1 = tensor_to_memref %0 : memref<?xf32, 7>
+  return %1 : memref<?xf32, 7>
 }
 
 // Test case: Basic folding of dim(tensor_load(m)) -> dim(m).
@@ -39,8 +39,23 @@ func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref
 //       CHECK:   %[[D:.*]] = dim %[[MEMREF]], %[[C0]]
 //       CHECK:   return %[[D]] : index
 func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
-    %c0 = constant 0 : index
-    %0 = tensor_load %arg0 : memref<?xf32>
-    %1 = dim %0, %c0 : tensor<?xf32>
-    return %1 : index
+  %c0 = constant 0 : index
+  %0 = tensor_load %arg0 : memref<?xf32>
+  %1 = dim %0, %c0 : tensor<?xf32>
+  return %1 : index
+}
+
+// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
+// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
+//  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+//   CHECK-NOT:   dim
+//       CHECK:   return %[[IDX1]] : index
+func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index {
+  %c3 = constant 3 : index
+  %0 = dynamic_tensor_from_elements %arg0, %arg1 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    yield %c3 : index
+  } : tensor<2x?x4x?x5xindex>
+  %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
+  return %1 : index
 }