From 3598605c0b3658dbb6cac634cb92a0a131f2fe0b Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Tue, 17 Nov 2020 13:39:08 +0100 Subject: [PATCH] [mlir][std] Fold dim(dynamic_tensor_from_elements, %cst) The shape of the result of a dynamic_tensor_from_elements is defined via its result type and operands. We already fold dim operations when they reference one of the statically sized dimensions. Now, also fold dim on the dynamically sized dimensions by picking the corresponding operand. Differential Revision: https://reviews.llvm.org/D91616 --- mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 19 +++++++++++++ mlir/test/Dialect/Standard/canonicalize.mlir | 41 +++++++++++++++++++--------- 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 629c8e4..4698893 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1484,6 +1484,25 @@ OpFoldResult DimOp::fold(ArrayRef operands) { return getResult(); } + // Fold dim to the operand of dynamic_tensor_from_elements. + if (auto fromElements = + dyn_cast_or_null(definingOp)) { + auto resultType = + fromElements.getResult().getType().cast(); + // The case where the type encodes the size of the dimension is handled + // above. + assert(resultType.getShape()[index.getInt()] == + RankedTensorType::kDynamicSize); + + // Find the operand of the fromElements that corresponds to this index. + auto dynExtents = fromElements.dynamicExtents().begin(); + for (auto dim : resultType.getShape().take_front(index.getInt())) + if (dim == RankedTensorType::kDynamicSize) + dynExtents++; + + return Value{*dynExtents}; + } + // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. auto memrefType = argTy.dyn_cast(); if (!memrefType) diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir index 1589dc1..1e2e4a5 100644 --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -5,9 +5,9 @@ // CHECK-SAME: %[[TENSOR:.*]]: tensor) -> tensor { // CHECK: return %[[TENSOR]] func @tensor_load_of_tensor_to_memref(%arg0: tensor) -> tensor { - %0 = tensor_to_memref %arg0 : memref - %1 = tensor_load %0 : memref - return %1 : tensor + %0 = tensor_to_memref %arg0 : memref + %1 = tensor_load %0 : memref + return %1 : tensor } // Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m @@ -15,9 +15,9 @@ func @tensor_load_of_tensor_to_memref(%arg0: tensor) -> tensor { // CHECK-SAME: %[[MEMREF:.*]]: memref) -> memref { // CHECK: return %[[MEMREF]] func @tensor_to_memref_of_tensor_load(%arg0: memref) -> memref { - %0 = tensor_load %arg0 : memref - %1 = tensor_to_memref %0 : memref - return %1 : memref + %0 = tensor_load %arg0 : memref + %1 = tensor_to_memref %0 : memref + return %1 : memref } // Test case: If the memrefs are not the same type, don't fold them. @@ -27,9 +27,9 @@ func @tensor_to_memref_of_tensor_load(%arg0: memref) -> memref { // CHECK: %[[MEMREF_ADDRSPACE7:.*]] = tensor_to_memref %[[TENSOR]] : memref // CHECK: return %[[MEMREF_ADDRSPACE7]] func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref) -> memref { - %0 = tensor_load %arg0 : memref - %1 = tensor_to_memref %0 : memref - return %1 : memref + %0 = tensor_load %arg0 : memref + %1 = tensor_to_memref %0 : memref + return %1 : memref } // Test case: Basic folding of dim(tensor_load(m)) -> dim(m). @@ -39,8 +39,23 @@ func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref) -> memref // CHECK: %[[D:.*]] = dim %[[MEMREF]], %[[C0]] // CHECK: return %[[D]] : index func @dim_of_tensor_load(%arg0: memref) -> index { - %c0 = constant 0 : index - %0 = tensor_load %arg0 : memref - %1 = dim %0, %c0 : tensor - return %1 : index + %c0 = constant 0 : index + %0 = tensor_load %arg0 : memref + %1 = dim %0, %c0 : tensor + return %1 : index +} + +// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx +// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements( +// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index +// CHECK-NOT: dim +// CHECK: return %[[IDX1]] : index +func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index { + %c3 = constant 3 : index + %0 = dynamic_tensor_from_elements %arg0, %arg1 { + ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index): + yield %c3 : index + } : tensor<2x?x4x?x5xindex> + %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex> + return %1 : index } -- 2.7.4