From 325426d72ce50c35e52ce801dcbfabc4a5a2afb3 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 19 Aug 2022 08:33:09 -0700 Subject: [PATCH] [mlir][MemRef] Introduce a memref.extract_metadata op. This is the counterpart of `memref.reinterpret_cast` and is useful to lift strided memref manipulation out of the LLVM dialect. Discussion: https://discourse.llvm.org/t/extracting-dynamic-offsets-strides-from-memref/64170 Differential Revision: https://reviews.llvm.org/D132243 --- mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td | 65 ++++++++++++++++++++++++ mlir/test/Dialect/MemRef/ops.mlir | 17 +++++++ 2 files changed, 82 insertions(+) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 082f4bfa..5ef8d8f 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -694,6 +694,71 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> { } //===----------------------------------------------------------------------===// +// ExtractMetadataOp +//===----------------------------------------------------------------------===// + +def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", + [SameVariadicResultSize]> { + let summary = "Extracts a buffer base with offset and strides"; + let description = [{ + Extracts a base buffer, offset and strides. This op allows additional layers + of transformations and foldings to be added as lowering progresses from + higher-level dialect to lower-level dialects such as the LLVM dialect. + + The op requires a strided memref source operand. If the source operand is not + a strided memref, then verification fails. + + This operation is also useful for completeness to the existing memref.dim op. + While accessing strides, offsets and the base pointer independently is not + available, this is useful for composing with its natural complement op: + `memref.reinterpret_cast`. + + Intended Use Cases: + + The main use case is to expose the logic for manipulate memref metadata at a + higher level than the LLVM dialect. + This makes lowering more progressive and brings the following benefits: + - not all users of MLIR want to lower to LLVM and the information to e.g. + lower to library calls---like libxsmm---or to SPIR-V was not available. + - foldings and canonicalizations can happen at a higher level in MLIR: + before this op existed, lowering to LLVM would create large amounts of + LLVMIR. Even when LLVM does a good job at folding the low-level IR from + a performance perspective, it is unnecessarily opaque and inefficient to + send unkempt IR to LLVM. + + Example: + + ```mlir + %base, %offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %memref : + memref<10x?xf32>, index, index, index, index, index + + // After folding, the type of %m2 can be memref<10x?xf32> and further + // folded to %memref. + %m2 = memref.reinterpret_cast %base to + offset: [%offset], + sizes: [%sizes#0, %sizes#1], + strides: [%strides#0, %strides#1] + : memref to memref + ``` + }]; + + let arguments = (ins + AnyStridedMemRef:$source + ); + let results = (outs + AnyStridedMemRefOfRank<0>:$base_buffer, + Index:$offset, + Variadic:$sizes, + Variadic:$strides + ); + + let assemblyFormat = [{ + $source `:` type($source) `->` type(results) attr-dict + }]; +} + +//===----------------------------------------------------------------------===// // GenericAtomicRMWOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 39a15b1..eda7c4c 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -336,3 +336,20 @@ func.func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) { } { index_attr = 8 : index } return } + +// ----- + +func.func @extract_strided_metadata(%memref : memref<10x?xf32>) + -> memref { + + %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %memref + : memref<10x?xf32> -> memref, index, index, index, index, index + + %m2 = memref.reinterpret_cast %base to + offset: [%offset], + sizes: [%sizes#0, %sizes#1], + strides: [%strides#0, %strides#1] + : memref to memref + + return %m2: memref +} -- 2.7.4