From adfd3c7083f9808d145239153c10f72eece485d8 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Tue, 16 Feb 2021 11:03:58 -0800 Subject: [PATCH] [mlir] Fix memref_cast + subview folder when reducing rank When the destination of the subview has a lower rank than its source we need to fix the result type of the new subview op. Differential Revision: https://reviews.llvm.org/D96804 --- mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 46 ++++++++++++++++++++++++---- mlir/test/Dialect/Standard/canonicalize.mlir | 14 +++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 4908291..5582c0b 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -3058,7 +3058,23 @@ isRankReducedType(Type originalType, Type candidateReducedType, candidateLayout = getStridedLinearLayoutMap(candidateReduced); else candidateLayout = candidateReduced.getAffineMaps().front(); - if (inferredType != candidateLayout) { + assert(inferredType.getNumResults() == 1 && + candidateLayout.getNumResults() == 1); + if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() || + inferredType.getNumDims() != candidateLayout.getNumDims()) { + if (errMsg) { + llvm::raw_string_ostream os(*errMsg); + os << "inferred type: " << inferredType; + } + return SubViewVerificationResult::AffineMapMismatch; + } + // Check that the difference of the affine maps simplifies to 0. + AffineExpr diffExpr = + inferredType.getResult(0) - candidateLayout.getResult(0); + diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(), + inferredType.getNumSymbols()); + auto cst = diffExpr.dyn_cast(); + if (!(cst && cst.getValue() == 0)) { if (errMsg) { llvm::raw_string_ostream os(*errMsg); os << "inferred type: " << inferredType; @@ -3344,11 +3360,29 @@ public: /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on /// the cast source operand type and the SubViewOp static information. This /// is the resulting type if the MemRefCastOp were folded. - Type resultType = SubViewOp::inferResultType( - castOp.source().getType().cast(), - extractFromI64ArrayAttr(subViewOp.static_offsets()), - extractFromI64ArrayAttr(subViewOp.static_sizes()), - extractFromI64ArrayAttr(subViewOp.static_strides())); + auto resultType = SubViewOp::inferResultType( + castOp.source().getType().cast(), + extractFromI64ArrayAttr(subViewOp.static_offsets()), + extractFromI64ArrayAttr(subViewOp.static_sizes()), + extractFromI64ArrayAttr(subViewOp.static_strides())) + .cast(); + uint32_t rankDiff = + subViewOp.getSourceType().getRank() - subViewOp.getType().getRank(); + if (rankDiff > 0) { + auto shape = resultType.getShape(); + auto projectedShape = shape.drop_front(rankDiff); + AffineMap map; + auto maps = resultType.getAffineMaps(); + if (!maps.empty() && maps.front()) { + auto optionalUnusedDimsMask = + computeRankReductionMask(shape, projectedShape); + llvm::SmallDenseSet dimsToProject = + optionalUnusedDimsMask.getValue(); + map = getProjectedMap(maps.front(), dimsToProject); + } + resultType = MemRefType::get(projectedShape, resultType.getElementType(), + map, resultType.getMemorySpace()); + } Value newSubView = rewriter.create( subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir index 7b54938..c864af8 100644 --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -143,3 +143,17 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) -> %1 = tensor_to_memref %0 : memref return %1 : memref } + +// CHECK-LABEL: func @subview_of_memcast +// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8> +// CHECK: %[[S:.+]] = subview %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}> +// CHECK: %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}> +// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}> +func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) -> + memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{ + %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref + %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : + memref to + memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>> + return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>> +} -- 2.7.4