From 63a2536f77a4037902be517399cb16b39fb732e7 Mon Sep 17 00:00:00 2001 From: Quentin Colombet Date: Thu, 1 Sep 2022 22:11:14 +0000 Subject: [PATCH] [mlir][MemRef] Simplify extract_strided_metadata(subview) Add a dedicated pass to simplify extract_strided_metadata(other_op(memref)). Currently the pass features only one pattern: extract_strided_metadata(subview). The goal is to get rid of the subview while materializing its effects on the offset, sizes, and strides with respect to the base object. In other words, this simplification replaces: ``` baseBuffer, offset, sizes, strides = extract_strided_metadata( subview(memref, subOffset, subSizes, subStrides)) ``` With ``` baseBuffer, baseOffset, baseSizes, baseStrides = extract_strided_metadata(memref) strides#i = baseStrides#i * subStrides#i offset = baseOffset + sum(subOffset#i * strides#i) sizes = subSizes ``` Differential Revision: https://reviews.llvm.org/D133166 --- .../mlir/Dialect/MemRef/Transforms/Passes.h | 10 + .../mlir/Dialect/MemRef/Transforms/Passes.td | 13 + mlir/include/mlir/IR/AffineExpr.h | 8 + mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt | 1 + .../Transforms/SimplifyExtractStridedMetadata.cpp | 199 +++++++++++++++ .../MemRef/simplify-extract-strided-metadata.mlir | 283 +++++++++++++++++++++ 6 files changed, 514 insertions(+) create mode 100644 mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp create mode 100644 mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h index b33ce0e..a5309dd 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -55,6 +55,11 @@ void populateResolveRankedShapeTypeResultDimsPatterns( /// terms of shapes of its input operands. 
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns); +/// Appends patterns for simplifying extract_strided_metadata(other_op) into +/// easier to analyze constructs. +void populateSimplifyExtractStridedMetadataOpPatterns( + RewritePatternSet &patterns); + /// Transformation to do multi-buffering/array expansion to remove dependencies /// on the temporary allocation between consecutive loop iterations. /// It return success if the allocation was multi-buffered and returns failure() @@ -118,6 +123,11 @@ std::unique_ptr createResolveRankedShapeTypeResultDimsPass(); /// in terms of shapes of its input operands. std::unique_ptr createResolveShapedTypeResultDimsPass(); +/// Creates an operation pass to simplify +/// `extract_strided_metadata(other_op(memref))` into +/// `extract_strided_metadata(memref)`. +std::unique_ptr createSimplifyExtractStridedMetadataPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td index 5ac124a..6404503 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -173,5 +173,18 @@ def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> { ]; } +def SimplifyExtractStridedMetadata : Pass<"simplify-extract-strided-metadata"> { + let summary = "Simplify extract_strided_metadata ops"; + let description = [{ + The pass simplifies extract_strided_metadata(other_op(memref)) to + extract_strided_metadata(memref) when it is possible to model the effect + of other_op directly with affine maps applied to the result of + extract_strided_metadata. 
+ }]; + let constructor = "mlir::memref::createSimplifyExtractStridedMetadataPass()"; + let dependentDialects = [ + "AffineDialect", "memref::MemRefDialect" + ]; +} #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index 9627978..e72e299 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -320,6 +320,14 @@ void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) { e = getAffineSymbolExpr(N, ctx); bindSymbols(ctx, exprs...); } + +template +void bindSymbolsList(MLIRContext *ctx, SmallVectorImpl &exprs) { + int idx = 0; + for (AffineExprTy &e : exprs) + e = getAffineSymbolExpr(idx++, ctx); +} + } // namespace detail /// Bind a list of AffineExpr references to DimExpr at positions: diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt index f85b6e5..d64bbef 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms MultiBuffer.cpp NormalizeMemRefs.cpp ResolveShapedTypeResultDims.cpp + SimplifyExtractStridedMetadata.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp new file mode 100644 index 0000000..3cad0af --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp @@ -0,0 +1,199 @@ +//===- SimplifyExtractStridedMetadata.cpp - Simplify this operation -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// This pass simplifies extract_strided_metadata(other_op(memref)) to +/// extract_strided_metadata(memref) when it is possible to express the effect +// of other_op using affine apply on the results of +// extract_strided_metadata(memref). +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallBitVector.h" + +namespace mlir { +namespace memref { +#define GEN_PASS_DEF_SIMPLIFYEXTRACTSTRIDEDMETADATA +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" +} // namespace memref +} // namespace mlir +using namespace mlir; + +namespace { +/// Replace `baseBuffer, offset, sizes, strides = +/// extract_strided_metadata(subview(memref, subOffset, +/// subSizes, subStrides))` +/// With +/// +/// \verbatim +/// baseBuffer, baseOffset, baseSizes, baseStrides = +/// extract_strided_metadata(memref) +/// strides#i = baseStrides#i * subStrides#i +/// offset = baseOffset + sum(subOffset#i * strides#i) +/// sizes = subSizes +/// \endverbatim +/// +/// In other words, get rid of the subview in that expression and canonicalize +/// on its effects on the offset, the sizes, and the strides using affine apply. 
+struct ExtractStridedMetadataOpSubviewFolder + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, + PatternRewriter &rewriter) const override { + auto subview = op.getSource().getDefiningOp(); + if (!subview) + return failure(); + + // Build a plain extract_strided_metadata(memref) from + // extract_strided_metadata(subview(memref)). + Location origLoc = op.getLoc(); + IndexType indexType = rewriter.getIndexType(); + Value source = subview.getSource(); + auto sourceType = source.getType().cast(); + unsigned sourceRank = sourceType.getRank(); + SmallVector sizeStrideTypes(sourceRank, indexType); + + auto newExtractStridedMetadata = + rewriter.create( + origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes, + sizeStrideTypes, source); + + SmallVector sourceStrides; + int64_t sourceOffset; + + bool hasKnownStridesAndOffset = + succeeded(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)); + (void)hasKnownStridesAndOffset; + assert(hasKnownStridesAndOffset && + "getStridesAndOffset must work on valid subviews"); + + // Compute the new strides and offset from the base strides and offset: + // newStride#i = baseStride#i * subStride#i + // offset = baseOffset + sum(subOffsets#i * newStrides#i) + SmallVector strides; + SmallVector subStrides = subview.getMixedStrides(); + auto origStrides = newExtractStridedMetadata.getStrides(); + + // Hold the affine symbols and values for the computation of the offset. Slot 0 is the base offset; each source dimension i then owns the three slots 1 + 3 * i .. 3 + 3 * i (subOffset, subStride, origStride). + SmallVector values(3 * sourceRank + 1); + SmallVector symbols(3 * sourceRank + 1); + + detail::bindSymbolsList(rewriter.getContext(), symbols); + AffineExpr expr = symbols.front(); + values[0] = ShapedType::isDynamicStrideOrOffset(sourceOffset) + ? 
getAsOpFoldResult(newExtractStridedMetadata.getOffset()) + : rewriter.getIndexAttr(sourceOffset); + SmallVector subOffsets = subview.getMixedOffsets(); + + AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + AffineExpr s1 = rewriter.getAffineSymbolExpr(1); + for (unsigned i = 0; i < sourceRank; ++i) { + // Compute the stride. + OpFoldResult origStride = + ShapedType::isDynamicStrideOrOffset(sourceStrides[i]) + ? origStrides[i] + : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i])); + strides.push_back(makeComposedFoldedAffineApply( + rewriter, origLoc, s0 * s1, {subStrides[i], origStride})); + + // Build up the computation of the offset. NOTE(review): each term below is subOffset * subStride * origStride, so the result depends on the subview strides, whereas the subview result types in the accompanying tests keep a static `offset: 210` even when a sub-stride is dynamic. Confirm against SubViewOp::inferResultType whether subStride belongs in this product. + unsigned baseIdxForDim = 1 + 3 * i; + unsigned subOffsetForDim = baseIdxForDim; + unsigned subStrideForDim = baseIdxForDim + 1; + unsigned origStrideForDim = baseIdxForDim + 2; + expr = expr + symbols[subOffsetForDim] * symbols[subStrideForDim] * + symbols[origStrideForDim]; + values[subOffsetForDim] = subOffsets[i]; + values[subStrideForDim] = subStrides[i]; + values[origStrideForDim] = origStride; + } + + // Compute the offset. + OpFoldResult finalOffset = + makeComposedFoldedAffineApply(rewriter, origLoc, expr, values); + + SmallVector results; + // The final result is (baseBuffer, offset, sizes, strides). + // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all + // the values. + auto subType = subview.getType().cast(); + unsigned subRank = subType.getRank(); + // Properly size the array so that we can do random insertions + // at the right indices. + // We do that to populate the non-dropped sizes and strides in one go. 
+ results.resize_for_overwrite(subRank * 2 + 2); + + results[0] = newExtractStridedMetadata.getBaseBuffer(); + results[1] = + getValueOrCreateConstantIndexOp(rewriter, origLoc, finalOffset); + + // The sizes of the final type are defined directly by the input sizes of + // the subview. + // Moreover, since subviews can drop dimensions, some strides and sizes may + // not end up in the final value that we are + // replacing. 
+ // Do the filtering here. + SmallVector subSizes = subview.getMixedSizes(); + const unsigned sizeStartIdx = 2; + const unsigned strideStartIdx = sizeStartIdx + subRank; + unsigned insertedDims = 0; + llvm::SmallBitVector droppedDims = subview.getDroppedDims(); + for (unsigned i = 0; i < sourceRank; ++i) { + if (droppedDims.test(i)) + continue; + + results[sizeStartIdx + insertedDims] = + getValueOrCreateConstantIndexOp(rewriter, origLoc, subSizes[i]); + results[strideStartIdx + insertedDims] = + getValueOrCreateConstantIndexOp(rewriter, origLoc, strides[i]); + ++insertedDims; + } + assert(insertedDims == subRank && + "Should have populated all the values at this point"); + + rewriter.replaceOp(op, results); + return success(); + } +}; +} // namespace + +void memref::populateSimplifyExtractStridedMetadataOpPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { + +struct SimplifyExtractStridedMetadataPass final + : public memref::impl::SimplifyExtractStridedMetadataBase< + SimplifyExtractStridedMetadataPass> { + void runOnOperation() override; +}; + +} // namespace + +void SimplifyExtractStridedMetadataPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + memref::populateSimplifyExtractStridedMetadataOpPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), + std::move(patterns)); +} + +std::unique_ptr memref::createSimplifyExtractStridedMetadataPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir new file mode 100644 index 0000000..8ef1729 --- /dev/null +++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir @@ -0,0 +1,283 @@ +// RUN: mlir-opt 
--simplify-extract-strided-metadata -split-input-file %s -o - | FileCheck %s + +// ----- + +// Check that we simplify extract_strided_metadata of subview to +// base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata +// strides = base_stride_i * subview_stride_i +// offset = base_offset + sum(subview_offsets_i * strides_i). +// +// This test also checks that we don't create useless arith operations +// when subview_offsets_i is 0. +// +// CHECK-LABEL: func @extract_strided_metadata_of_subview +// CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32>) +// +// Materialize the offset for dimension 1. +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// +// Plain extract_strided_metadata. +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// +// Final offset is: +// origOffset + (== 0) +// base_stride0 * subview_stride0 * subview_offset0 + (== 4 * 1 * 0 == 0) +// base_stride1 * subview_stride1 * subview_offset1 (== 1 * 1 * 2) +// == 2 +// +// Return the new tuple. +// CHECK: return %[[BASE]], %[[C2]], %[[C2]], %[[C2]], %[[C4]], %[[C1]] +func.func @extract_strided_metadata_of_subview(%base: memref<5x4xf32>) + -> (memref, index, index, index, index, index) { + + %subview = memref.subview %base[0, 2][2, 2][1, 1] : + memref<5x4xf32> to memref<2x2xf32, strided<[4, 1], offset: 2>> + + %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : + memref<2x2xf32, strided<[4,1], offset:2>> + -> memref, index, index, index, index, index + + return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : + memref, index, index, index, index, index +} + +// ----- + +// Check that we simplify extract_strided_metadata of subview properly +// when dynamic sizes are involved. 
+// See extract_strided_metadata_of_subview for an explanation of the actual +// expansion. +// Orig strides: [64, 4, 1] +// Sub strides: [1, 1, 1] +// => New strides: [64, 4, 1] +// +// Orig offset: 0 +// Sub offsets: [3, 4, 2] +// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210 +// +// Final sizes == subview sizes == [%size, 6, 3] +// +// CHECK-LABEL: func @extract_strided_metadata_of_subview_with_dynamic_size +// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>, +// CHECK-SAME: %[[DYN_SIZE:.*]]: index) +// +// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK: return %[[BASE]], %[[C210]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[C64]], %[[C4]], %[[C1]] +func.func @extract_strided_metadata_of_subview_with_dynamic_size( + %base: memref<8x16x4xf32>, %size: index) + -> (memref, index, index, index, index, index, index, index) { + + %subview = memref.subview %base[3, 4, 2][%size, 6, 3][1, 1, 1] : + memref<8x16x4xf32> to memref> + + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : + memref> + -> memref, index, index, index, index, index, index, index + + return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 : + memref, index, index, index, index, index, index, index +} + +// ----- + +// Check that we simplify extract_strided_metadata of subview properly +// when the subview reduces the ranks. +// In particular the returned strides must come from #1 and #2 of the %strides +// value of the new extract_strided_metadata_of_subview, not #0 and #1. 
+// See extract_strided_metadata_of_subview for an explanation of the actual +// expansion. +// +// Orig strides: [64, 4, 1] +// Sub strides: [1, 1, 1] +// => New strides: [64, 4, 1] +// Final strides == filterOutReducedDim(new strides, 0) == [4 , 1] +// +// Orig offset: 0 +// Sub offsets: [3, 4, 2] +// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210 +// +// Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3] +// +// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview +// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>) +// +// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index +// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[C4]], %[[C1]] +func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4xf32>) + -> (memref, index, index, index, index, index) { + + %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, 1, 1] : + memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> + + %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : + memref<6x3xf32, strided<[4,1], offset: 210>> + -> memref, index, index, index, index, index + + return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : + memref, index, index, index, index, index +} + +// ----- + +// Check that we simplify extract_strided_metadata of subview properly +// when the subview reduces the rank and some of the strides are variable. +// In particular, we check that: +// A. The dynamic stride is multiplied with the base stride to create the new +// stride for dimension 1. +// B. The first returned stride is the value computed in #A. 
+// See extract_strided_metadata_of_subview for an explanation of the actual +// expansion. +// +// Orig strides: [64, 4, 1] +// Sub strides: [1, %stride, 1] +// => New strides: [64, 4 * %stride, 1] +// Final strides == filterOutReducedDim(new strides, 0) == [4 * %stride , 1] +// +// Orig offset: 0 +// Sub offsets: [3, 4, 2] +// => Final offset: 3 * 64 + 4 * 4 * %stride + 2 * 1 + 0 == 16 * %stride + 194 +// +// CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)> +// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0] -> (s0 * 16 + 194)> +// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides +// CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>, +// CHECK-SAME: %[[DYN_STRIDE:.*]]: index) +// +// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]] +// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[DYN_STRIDE]]] +// +// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]] +func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides( + %base: memref<8x16x4xf32>, %stride: index) + -> (memref, index, index, index, index, index) { + + %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] : + memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> + + %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : + memref<6x3xf32, strided<[4, 1], offset: 210>> + -> memref, index, index, index, index, index + + return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : + memref, index, index, index, index, index +} + +// ----- + +// Check that we simplify 
extract_strided_metadata of subview properly +// when the subview uses variable offsets. +// See extract_strided_metadata_of_subview for an explanation of the actual +// expansion. +// +// Orig strides: [128, 1] +// Sub strides: [1, 1] +// => New strides: [128, 1] +// +// Orig offset: 0 +// Sub offsets: [%arg1, %arg2] +// => Final offset: 128 * arg1 + 1 * %arg2 + 0 +// +// CHECK-DAG: #[[$OFFSETS_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 128 + s1)> +// CHECK-LABEL: func @extract_strided_metadata_of_subview_w_variable_offset +// CHECK-SAME: (%[[ARG:.*]]: memref<384x128xf32>, +// CHECK-SAME: %[[DYN_OFFSET0:.*]]: index, +// CHECK-SAME: %[[DYN_OFFSET1:.*]]: index) +// +// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSETS_MAP]]()[%[[DYN_OFFSET0]], %[[DYN_OFFSET1]]] +// +// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C64]], %[[C64]], %[[C128]], %[[C1]] +#map0 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)> +func.func @extract_strided_metadata_of_subview_w_variable_offset( + %arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index) + -> (memref, index, index, index, index, index) { + + %subview = memref.subview %arg0[%arg1, %arg2] [64, 64] [1, 1] : + memref<384x128xf32> to memref<64x64xf32, #map0> + + %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview : + memref<64x64xf32, #map0> -> memref, index, index, index, index, index + + return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : + memref, index, index, index, index, index +} + +// ----- + +// Check that all the math is correct for all types of computations. 
+// We achieve that by using dynamic values for all the different types: +// - Offsets +// - Sizes +// - Strides +// +// Orig strides: [s0, s1, s2] +// Sub strides: [subS0, subS1, subS2] +// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2] +// ==> 1 affine map (used for each stride) with two values. +// +// Orig offset: origOff +// Sub offsets: [subO0, subO1, subO2] +// => Final offset: s0 * subS0 * subO0 + ... + s2 * subS2 * subO2 + origOff +// ==> 1 affine map with (rank * 3 + 1) symbols +// +// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0 + (s1 * s2) * s3 + (s4 * s5) * s6 + (s7 * s8) * s9)> +// CHECK-LABEL: func @extract_strided_metadata_of_subview_all_dynamic +// CHECK-SAME: (%[[ARG:.*]]: memref>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index) +// +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0] +// CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1] +// CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2] +// +// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[DYN_STRIDE0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[DYN_STRIDE1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[DYN_STRIDE2]], %[[STRIDES]]#2] +// +// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]], %[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]] +func.func @extract_strided_metadata_of_subview_all_dynamic( + 
%base: memref>, + %offset0: index, %offset1: index, %offset2: index, + %size0: index, %size1: index, %size2: index, + %stride0: index, %stride1: index, %stride2: index) + -> (memref, index, index, index, index, index, index, index) { + + %subview = memref.subview %base[%offset0, %offset1, %offset2] + [%size0, %size1, %size2] + [%stride0, %stride1, %stride2] : + memref> to + memref> + + %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview : + memref> + -> memref, index, index, index, index, index, index, index + + return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 : + memref, index, index, index, index, index, index, index +} -- 2.7.4