From 54cda2ec976a89fcf5157d78479a576b09922df7 Mon Sep 17 00:00:00 2001 From: Quentin Colombet Date: Thu, 23 Mar 2023 15:41:14 +0100 Subject: [PATCH] [mlir][MemRef] Add patterns to extract address computations This patch adds patterns to rewrite memory accesses such that the resulting accesses are only using a base pointer. E.g., ```mlir memref.load %base[%off0, ...] ``` Will be rewritten in: ```mlir %new_base = memref.subview %base[%off0,...][1,...][1,...] memref.load %new_base[%c0,...] ``` The idea behind these patterns is to offer a way to more gradually lower address computations. These patterns are the exact opposite of FoldMemRefAliasOps. I've implemented the support of only five operations in this patch: - memref.load - memref.store - nvgpu.ldmatrix - vector.transfer_read - vector.transfer_write Going forward we may want to provide an interface for these rewritings (and the ones in FoldMemRefAliasOps.) One step at a time! Differential Revision: https://reviews.llvm.org/D146724 --- .../MemRef/TransformOps/MemRefTransformOps.td | 44 +++ .../mlir/Dialect/MemRef/Transforms/Transforms.h | 40 +++ .../lib/Dialect/MemRef/TransformOps/CMakeLists.txt | 2 + .../MemRef/TransformOps/MemRefTransformOps.cpp | 32 ++ mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt | 2 + .../Transforms/ExtractAddressComputations.cpp | 313 ++++++++++++++++ .../MemRef/extract-address-computations.mlir | 393 +++++++++++++++++++++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 4 + 8 files changed, 830 insertions(+) create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h create mode 100644 mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp create mode 100644 mlir/test/Dialect/MemRef/extract-address-computations.mlir diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td index ea7784e..a0b5a68 100644 --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -49,4 +49,48 @@ def MemRefMultiBufferOp : Op { + let summary = "Extract address computations from memory accesses"; + let description = [{ + Transformation that extracts address computations from instructions + with memory accesses such that these memory accesses use only a base + pointer. + + For instance, + ```mlir + memref.load %base[%off0, ...] + ``` + + Will be rewritten in: + ```mlir + %new_base = memref.subview %base[%off0,...][1,...][1,...] + memref.load %new_base[%c0,...] + ``` + + Note: The current implementation requires that the input operation + is "isolated from above". + + #### Return modes + + This operation produces `definiteFailure` if the extraction fails for any + reason. + The operation always returns the handle to the target op that is expected + to be isolated from above. 
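For illustration, a minimal way to drive this transform op, mirroring the tests added in this patch (handle names are illustrative):

```mlir
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
  %func = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
  %transformed = transform.memref.extract_address_computations %func : (!pdl.operation) -> !pdl.operation
}
```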
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &transformResults,
+        ::mlir::transform::TransformState &state);
+  }];
+}
 
 #endif // MEMREF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
new file mode 100644
index 0000000..18b12d6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -0,0 +1,40 @@
+//===- Transforms.h - MemRef Dialect transformations ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This header declares functions that assist transformations in the MemRef
+/// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace memref {
+/// Appends patterns for extracting address computations from the instructions
+/// with memory accesses such that these memory accesses use only a base
+/// pointer.
+///
+/// For instance,
+/// ```mlir
+/// memref.load %base[%off0, ...]
+/// ```
+///
+/// Will be rewritten in:
+/// ```mlir
+/// %new_base = memref.subview %base[%off0,...][1,...][1,...]
+/// memref.load %new_base[%c0,...]
+/// ```
+void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
+
+} // namespace memref
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
index b98db40..b32e06a 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -15,5 +15,7 @@ add_mlir_dialect_library(MLIRMemRefTransformOps
   MLIRLoopLikeInterface
   MLIRMemRefDialect
   MLIRMemRefTransforms
+  MLIRNVGPUDialect
   MLIRTransformDialect
+  MLIRVectorDialect
 )
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index ae721fe..3209b1b 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -11,10 +11,14 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
@@ -69,6 +73,31 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
 }
 
 //===----------------------------------------------------------------------===//
+// MemRefExtractAddressComputationsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::MemRefExtractAddressComputationsOp::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+    auto diag = this->emitOpError("requires isolated-from-above targets");
+    diag.attachNote(target->getLoc()) << "non-isolated target";
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  MLIRContext *ctx = getContext();
+  RewritePatternSet patterns(ctx);
+  memref::populateExtractAddressComputationsPatterns(patterns);
+
+  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
+    return emitDefaultDefiniteFailure(target);
+
+  results.push_back(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
 
@@ -83,6 +112,9 @@ public:
     declareDependentDialect<pdl::PDLDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<memref::MemRefDialect>();
+    declareGeneratedDialect<AffineDialect>();
+    declareGeneratedDialect<nvgpu::NVGPUDialect>();
+    declareGeneratedDialect<vector::VectorDialect>();
 
     registerTransformOps<
 #define GET_OP_LIST
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 744f5c6..0b01a1c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   ExpandOps.cpp
   ExpandStridedMetadata.cpp
   EmulateWideInt.cpp
+  ExtractAddressComputations.cpp
   FoldMemRefAliasOps.cpp
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
@@ -27,6 +28,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface
   MLIRMemRefDialect
+  MLIRNVGPUDialect
   MLIRPass
   MLIRTensorDialect
   MLIRTransforms
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
new file mode 100644
index 0000000..5ef977f
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -0,0 +1,313 @@
+//===- ExtractAddressComputations.cpp - Extract address computations -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This transformation pass rewrites loading/storing from/to a memref with
+/// offsets into loading/storing from/to a subview without any offset on
+/// the instruction itself.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `load base[off0...]`
+// => `load (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// Matches getFailureOrSrcMemRef specs for LoadOp.
+// \see LoadStoreLikeOpRewriter.
+static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
+  return loadOp.getMemRef();
+}
+
+// Matches rebuildOpFromAddressAndIndices specs for LoadOp.
+// \see LoadStoreLikeOpRewriter.
+static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
+                                    memref::LoadOp loadOp, Value srcMemRef,
+                                    ArrayRef<Value> indices) {
+  Location loc = loadOp.getLoc();
+  return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
+                                         loadOp.getNontemporal());
+}
+
+// Matches getViewSizeForEachDim specs for LoadOp.
+// \see LoadStoreLikeOpRewriter.
+static SmallVector<OpFoldResult>
+getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
+  MemRefType ldTy = loadOp.getMemRefType();
+  unsigned loadRank = ldTy.getRank();
+  return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `store val, base[off0...]`
+// => `store val, (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// Matches getFailureOrSrcMemRef specs for StoreOp.
+// \see LoadStoreLikeOpRewriter.
+static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
+  return storeOp.getMemRef();
+}
+
+// Matches rebuildOpFromAddressAndIndices specs for StoreOp.
+// \see LoadStoreLikeOpRewriter.
+static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
+                                      memref::StoreOp storeOp, Value srcMemRef,
+                                      ArrayRef<Value> indices) {
+  Location loc = storeOp.getLoc();
+  return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
+                                          srcMemRef, indices,
+                                          storeOp.getNontemporal());
+}
+
+// Matches getViewSizeForEachDim specs for StoreOp.
+// \see LoadStoreLikeOpRewriter.
+static SmallVector<OpFoldResult>
+getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
+  MemRefType ldTy = storeOp.getMemRefType();
+  unsigned loadRank = ldTy.getRank();
+  return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `ldmatrix base[off0...]`
+// => `ldmatrix (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// Matches getFailureOrSrcMemRef specs for LdMatrixOp.
+// \see LoadStoreLikeOpRewriter.
+static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
+  return ldMatrixOp.getSrcMemref();
+}
+
+// Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp.
+// \see LoadStoreLikeOpRewriter.
+static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
+                                           nvgpu::LdMatrixOp ldMatrixOp,
+                                           Value srcMemRef,
+                                           ArrayRef<Value> indices) {
+  Location loc = ldMatrixOp.getLoc();
+  return rewriter.create<nvgpu::LdMatrixOp>(
+      loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
+      ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `transfer_read base[off0...]`
+// => `transfer_read (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// Matches getFailureOrSrcMemRef specs for TransferReadOp.
+// \see LoadStoreLikeOpRewriter.
+template <typename TransferLikeOp>
+static FailureOr<Value>
+getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
+  Value src = transferLikeOp.getSource();
+  if (src.getType().isa<MemRefType>())
+    return src;
+  return failure();
+}
+
+// Matches rebuildOpFromAddressAndIndices specs for TransferReadOp.
+// \see LoadStoreLikeOpRewriter.
+static vector::TransferReadOp
+rebuildTransferReadOp(RewriterBase &rewriter,
+                      vector::TransferReadOp transferReadOp, Value srcMemRef,
+                      ArrayRef<Value> indices) {
+  Location loc = transferReadOp.getLoc();
+  return rewriter.create<vector::TransferReadOp>(
+      loc, transferReadOp.getResult().getType(), srcMemRef, indices,
+      transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
+      transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
+}
+
+//===----------------------------------------------------------------------===//
+// Helper functions for the `transfer_write base[off0...]`
+// => `transfer_write (subview base[off0...])[0...]` pattern.
+//===----------------------------------------------------------------------===//
+
+// Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp.
+// \see LoadStoreLikeOpRewriter.
+static vector::TransferWriteOp
+rebuildTransferWriteOp(RewriterBase &rewriter,
+                       vector::TransferWriteOp transferWriteOp, Value srcMemRef,
+                       ArrayRef<Value> indices) {
+  Location loc = transferWriteOp.getLoc();
+  return rewriter.create<vector::TransferWriteOp>(
+      loc, transferWriteOp.getValue(), srcMemRef, indices,
+      transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
+      transferWriteOp.getInBoundsAttr());
+}
+
+//===----------------------------------------------------------------------===//
+// Generic helper functions used as default implementation in
+// LoadStoreLikeOpRewriter.
+//===----------------------------------------------------------------------===//
+
+/// Helper function to get the src memref.
+/// It uses the already defined getFailureOrSrcMemRef but asserts
+/// that the source is a memref.
+template <typename LoadStoreLikeOp,
+          FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
+static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
+  FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
+  assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
+  return *failureOrSrcMemRef;
+}
+
+/// Helper function to get the sizes of the resulting view.
+/// This function gets the sizes of the source memref then subtracts the
+/// offsets used within \p loadStoreLikeOp. This gives the maximal (in-bounds)
+/// sizes for the view.
+/// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
+template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
+static SmallVector<OpFoldResult>
+getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
+                               LoadStoreLikeOp loadStoreLikeOp) {
+  Location loc = loadStoreLikeOp.getLoc();
+  auto extractStridedMetadataOp =
+      rewriter.create<memref::ExtractStridedMetadataOp>(
+          loc, getSrcMemRef(loadStoreLikeOp));
+  SmallVector<OpFoldResult> srcSizes =
+      extractStridedMetadataOp.getConstifiedMixedSizes();
+  SmallVector<OpFoldResult> indices =
+      getAsOpFoldResult(loadStoreLikeOp.getIndices());
+  SmallVector<OpFoldResult> finalSizes;
+
+  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+
+  for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
+    finalSizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, s0 - s1,
+                                                       {srcSize, indice}));
+  }
+  return finalSizes;
+}
+
+/// Rewrite a store/load-like op so that all its indices are zeros.
+/// E.g., %ld = memref.load %base[%off0]...[%offN]
+/// =>
+/// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
+/// %ld = memref.load %new_base[0,..,0] :
+///    memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
+///
+/// `getSrcMemRef` returns the source memref for the given load-like operation.
+///
+/// `getViewSizeForEachDim` returns the sizes of the view that is going to feed
+/// the new operation. This must return one size per dimension of the view.
+/// The sizes of the view need to be at least as big as what is actually
+/// going to be accessed. Use the provided `loadStoreOp` to get the right
+/// sizes.
+///
+/// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
+/// LoadStoreLikeOp that reads from srcMemRef[indices].
+/// The returned operation will be used to replace loadStoreOp.
+template <typename LoadStoreLikeOp,
+          FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
+          LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
+              RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
+              Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
+          SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
+              RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
+              getGenericOpViewSizeForEachDim<
+                  LoadStoreLikeOp,
+                  getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
+struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
+  using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> failureOrSrcMemRef =
+        getFailureOrSrcMemRef(loadStoreLikeOp);
+    if (failed(failureOrSrcMemRef))
+      return rewriter.notifyMatchFailure(loadStoreLikeOp,
+                                         "source is not a memref");
+    Value srcMemRef = *failureOrSrcMemRef;
+    auto ldStTy = srcMemRef.getType().cast<MemRefType>();
+    unsigned loadStoreRank = ldStTy.getRank();
+    // Don't waste compile time if there is nothing to rewrite.
+    if (loadStoreRank == 0)
+      return rewriter.notifyMatchFailure(loadStoreLikeOp,
+                                         "0-D accesses don't need rewriting");
+
+    // If our load already has only zeros as indices there is nothing
+    // to do.
+    SmallVector<OpFoldResult> indices =
+        getAsOpFoldResult(loadStoreLikeOp.getIndices());
+    if (std::all_of(indices.begin(), indices.end(),
+                    [](const OpFoldResult &opFold) {
+                      return isConstantIntValue(opFold, 0);
+                    })) {
+      return rewriter.notifyMatchFailure(
+          loadStoreLikeOp, "no computation to extract: offsets are 0s");
+    }
+
+    // Create the array of ones of the right size.
+    SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes =
+        getViewSizeForEachDim(rewriter, loadStoreLikeOp);
+    assert(sizes.size() == loadStoreRank &&
+           "Expected one size per load dimension");
+    Location loc = loadStoreLikeOp.getLoc();
+    // The subview inherits its strides from the original memref and will
+    // apply them properly to the input indices.
+    // Therefore the strides multipliers are simply ones.
+    auto subview =
+        rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef,
+                                           /*offsets=*/indices,
+                                           /*sizes=*/sizes, /*strides=*/ones);
+    // Rewrite the load/store with the subview as the base pointer.
+    SmallVector<Value> zeros(loadStoreRank,
+                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
+    LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
+        rewriter, loadStoreLikeOp, subview.getResult(), zeros);
+    rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
+    return success();
+  }
+};
+} // namespace
+
+void memref::populateExtractAddressComputationsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<
+      LoadStoreLikeOpRewriter<
+          memref::LoadOp,
+          /*getSrcMemRef=*/getLoadOpSrcMemRef,
+          /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
+          /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
+      LoadStoreLikeOpRewriter<
+          memref::StoreOp,
+          /*getSrcMemRef=*/getStoreOpSrcMemRef,
+          /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
+          /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
+      LoadStoreLikeOpRewriter<
+          nvgpu::LdMatrixOp,
+          /*getSrcMemRef=*/getLdMatrixOpSrcMemRef,
+          /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
+      LoadStoreLikeOpRewriter<
+          vector::TransferReadOp,
+          /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
+          /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
+      LoadStoreLikeOpRewriter<
+          vector::TransferWriteOp,
+          /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
+          /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
+      patterns.getContext());
+}
diff --git a/mlir/test/Dialect/MemRef/extract-address-computations.mlir b/mlir/test/Dialect/MemRef/extract-address-computations.mlir
new file mode 100644
index 0000000..17e2ac3
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/extract-address-computations.mlir
@@ -0,0 +1,393 @@
+// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file --verify-diagnostics | FileCheck %s
+
+// Simple test: check that we extract the address computation of a load into
+// a dedicated subview.
+// The resulting load will be loading from the subview and have only indices
+// set to zero.
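A note on the expected subview type in the test below: the source `memref<2x16x16xf32>` has canonical strides `[16*16, 16, 1] = [256, 16, 1]`, and the rewrite keeps those strides (it always passes unit stride multipliers to the subview); only the offset becomes dynamic because `%offset` is a runtime value. The checked rewrite is essentially:

```mlir
%new_base = memref.subview %base[%offset, 0, 8] [1, 1, 1] [1, 1, 1]
    : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
%loaded_val = memref.load %new_base[%c0, %c0, %c0]
    : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>>
```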
+ +// CHECK-LABEL: @test_load( +// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}}, +// CHECK-SAME: %[[DYN_OFFSET:.*]]: index) +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>> +// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>> +// CHECK: return %[[LOADED_VAL]] : f32 + +// expected-remark @below {{transformed}} +func.func @test_load(%base : memref<2x16x16xf32>, %offset : index) -> f32 { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %loaded_val = memref.load %base[%offset, %c0, %c8] : memref<2x16x16xf32> + return %loaded_val : f32 +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation + // Verify that the returned handle is usable. + transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation +} + +// ----- + +// Same as previous @test_load but with the nontemporal flag. + +// CHECK-LABEL: @test_load_nontemporal( +// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}}, +// CHECK-SAME: %[[DYN_OFFSET:.*]]: index) +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>> +// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {nontemporal = true} : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>> +// CHECK: return %[[LOADED_VAL]] : f32 +func.func @test_load_nontemporal(%base : memref<2x16x16xf32>, %offset : index) -> f32 { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %loaded_val = memref.load %base[%offset, %c0, %c8] {nontemporal = true } : memref<2x16x16xf32> + return %loaded_val : f32 +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation +} + +// ----- + +// Simple test: check that we extract the address computation of a store into +// a dedicated subview. +// The resulting store will use the address from the subview and have only +// indices set to zero. 
+ +// CHECK-LABEL: @test_store( +// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}}, +// CHECK-SAME: %[[DYN_OFFSET:.*]]: index) +// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f32 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>> +// CHECK: memref.store %[[CF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>> +// CHECK: return +func.func @test_store(%base : memref<2x16x16xf32>, %offset : index) -> () { + %cf0 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + memref.store %cf0, %base[%offset, %c0, %c8] : memref<2x16x16xf32> + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation +} + +// ----- + +// Same as @test_store but check that the nontemporal flag is preserved. + +// CHECK-LABEL: @test_store_nontemporal( +// CHECK-SAME: %[[BASE:[^:]*]]: memref{{[^,]*}}, +// CHECK-SAME: %[[DYN_OFFSET:.*]]: index) +// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f32 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET]], 0, 8] [1, 1, 1] [1, 1, 1] : memref<2x16x16xf32> to memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>> +// CHECK: memref.store %[[CF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {nontemporal = true} : memref<1x1x1xf32, strided<[256, 16, 1], offset: ?>> +// CHECK: return +func.func @test_store_nontemporal(%base : memref<2x16x16xf32>, %offset : index) -> () { + %cf0 = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + memref.store %cf0, %base[%offset, %c0, %c8] { nontemporal = true } : memref<2x16x16xf32> + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation +} + +// ----- +// For this test, we made the source memref fully dynamic. +// The gist of the check remains the same as the simple test: +// The address computation is extracted into its own subview. 
+// CHECK-LABEL: @testWithLoop(
+// CHECK-SAME: %[[BASE:[^:]*]]: memref
+// CHECK: %[[SUM_ALL:.*]] = arith.constant 0.0{{0*e\+00}} : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[UPPER_BOUND0:.*]] = memref.dim %[[BASE]], %[[C0]] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[UPPER_BOUND1:.*]] = memref.dim %[[BASE]], %[[C1]] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[UPPER_BOUND2:.*]] = memref.dim %[[BASE]], %[[C2]] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[SUM_RES2:.*]] = scf.for %[[IV2:.*]] = %[[C0]] to %[[UPPER_BOUND2]] step %[[C1]] iter_args(%[[SUM_ITER2:.*]] = %[[SUM_ALL]]) -> (f32) {
+// CHECK: %[[SUM_RES1:.*]] = scf.for %[[IV1:.*]] = %[[C0]] to %[[UPPER_BOUND1]] step %[[C1]] iter_args(%[[SUM_ITER1:.*]] = %[[SUM_ITER2]]) -> (f32) {
+// CHECK: %[[SUM_RES0:.*]] = scf.for %[[IV0:.*]] = %[[C0]] to %[[UPPER_BOUND0]] step %[[C1]] iter_args(%[[SUM_ITER0:.*]] = %[[SUM_ITER1]]) -> (f32) {
+// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[IV0]], %[[IV1]], %[[IV2]]] [1, 1, 1] [1, 1, 1] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> to memref<1x1x1xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[LOADED_VAL:.*]] = memref.load %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %[[RES:.*]] = arith.addf %[[LOADED_VAL]], %[[SUM_ITER2]] : f32
+// CHECK: scf.yield %[[RES]] : f32
+// CHECK: }
+// CHECK: scf.yield %[[SUM_RES0]] : f32
+// CHECK: }
+// CHECK: scf.yield %[[SUM_RES1]] : f32
+// CHECK: }
+// CHECK: return %[[SUM_RES2]] : f32
+func.func @testWithLoop(%base : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> f32 {
+  %sum_all = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %upper_bound0 = memref.dim %base, %c0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %upper_bound1 = memref.dim %base, %c1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %upper_bound2 = memref.dim %base, %c2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %sum_res2 = scf.for %iv2 = %c0 to %upper_bound2 step %c1 iter_args(%sum_iter2 = %sum_all) -> (f32) {
+    %sum_res1 = scf.for %iv1 = %c0 to %upper_bound1 step %c1 iter_args(%sum_iter1 = %sum_iter2) -> (f32) {
+      %sum_res0 = scf.for %iv0 = %c0 to %upper_bound0 step %c1 iter_args(%sum_iter0 = %sum_iter1) -> (f32) {
+        %loaded_val = memref.load %base[%iv0, %iv1, %iv2] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+        %res = arith.addf %loaded_val, %sum_iter2 : f32
+        scf.yield %res : f32
+      }
+      scf.yield %sum_res0 : f32
+    }
+    scf.yield %sum_res1 : f32
+  }
+  return %sum_res2 : f32
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// Simple test: check that we extract the address computation of an ldmatrix into
+// a dedicated subview.
+// The resulting ldmatrix will load from the subview and have only indices set
+// to zero.
+// Also the sizes of the view are adjusted to `original size - offset`.
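Unlike `memref.load`/`memref.store`, which access a single element and therefore get a `1x..x1` view, `nvgpu.ldmatrix` (and the vector transfer ops below) read a multi-element region, so the generic view-size computation is used: each view size is `source size - offset`, the maximal in-bounds extent. For this test the source is `memref<4x32x32xf16, 3>`, so the sizes become `4 - %offset0`, `32 - %offset1`, and `32 - %offset2`, i.e. roughly:

```mlir
%size0 = affine.apply affine_map<()[s0] -> (-s0 + 4)>()[%offset0]
%size1 = affine.apply affine_map<()[s0] -> (-s0 + 32)>()[%offset1]
%size2 = affine.apply affine_map<()[s0] -> (-s0 + 32)>()[%offset2]
```

(`%size0`/`%size1`/`%size2` are illustrative names; the test only checks these applies through the captured `DYN_SIZE*` values.)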
+ +// CHECK-DAG: #[[$FOUR_MINUS_OFF_MAP:.*]] = affine_map<()[s0] -> (-s0 + 4)> +// CHECK-DAG: #[[$THIRTY_TWO_MINUS_OFF_MAP:.*]] = affine_map<()[s0] -> (-s0 + 32)> +// CHECK-LABEL: @test_ldmatrix( +// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}, 3>, +// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index) +// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$FOUR_MINUS_OFF_MAP]]()[%[[DYN_OFFSET0]]] +// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$THIRTY_TWO_MINUS_OFF_MAP]]()[%[[DYN_OFFSET1]]] +// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$THIRTY_TWO_MINUS_OFF_MAP]]()[%[[DYN_OFFSET2]]] +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref<4x32x32xf16, 3> to memref, 3> +// CHECK: %[[LOADED_VAL:.*]] = nvgpu.ldmatrix %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {numTiles = 4 : i32, transpose = false} : memref, 3> -> vector<4x2xf16> +// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16> +func.func @test_ldmatrix(%base : memref<4x32x32xf16, 3>, + %offset0 : index, %offset1: index, %offset2: index) + -> vector<4x2xf16> { + %loaded_val = nvgpu.ldmatrix + %base[%offset0, %offset1, %offset2] + {numTiles = 4 : i32, transpose = false} + : memref<4x32x32xf16, 3> -> vector<4x2xf16> + return %loaded_val : vector<4x2xf16> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation +} + +// ----- + +// Same as test_ldmatrix but with fully dynamic memref. 
+ +// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)> +// CHECK-LABEL: @test_ldmatrix( +// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}, 3>, +// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index) +// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]] +// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]] +// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]] +// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]] +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref to memref, 3> +// CHECK: %[[LOADED_VAL:.*]] = nvgpu.ldmatrix %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {numTiles = 4 : i32, transpose = false} : memref, 3> -> vector<4x2xf16> +// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16> +func.func @test_ldmatrix(%base : memref, + %offset0 : index, %offset1: index, %offset2: index) + -> vector<4x2xf16> { + %loaded_val = nvgpu.ldmatrix + %base[%offset0, %offset1, %offset2] + {numTiles = 4 : i32, transpose = false} + : memref -> vector<4x2xf16> + return %loaded_val : vector<4x2xf16> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation +} + +// ----- + +// Simple test for vector.transfer_read with fully dynamic memref. +// We also set a permutation map to make sure it is properly preserved. 
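Note that the rebuilt `vector.transfer_read` only changes the source and indices: the permutation map, padding value, mask, and `in_bounds` attribute are forwarded unchanged from the original op, which is what the `permutation_map` check below exercises.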
+ +// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)> +// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)> +// CHECK-LABEL: @test_transfer_read_op( +// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}>, +// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index) +// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]] +// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]] +// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]] +// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]] +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f16 +// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref to memref> +// CHECK: %[[LOADED_VAL:.*]] = vector.transfer_read %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]], %[[CF0]] {permutation_map = #[[$PERMUTATION_MAP]]} : memref>, vector<4x2xf16> +// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16> +func.func @test_transfer_read_op(%base : memref, + %offset0 : index, %offset1: index, %offset2: index) + -> vector<4x2xf16> { + %cf0 = arith.constant 0.0 : f16 + %loaded_val = vector.transfer_read %base[%offset0, %offset1, %offset2], %cf0 { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : memref, vector<4x2xf16> + return %loaded_val : vector<4x2xf16> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation +} + +// ----- + +// Same as test_transfer_read_op but with tensors. +// Right now this rewrite is not supported but we still shouldn't choke on it. + +// CHECK: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)> +// CHECK-LABEL: @test_transfer_read_op_with_tensor( +// CHECK-SAME: %[[BASE:[^:]*]]: tensor<{{[^,]*}}>, +// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index) +// CHECK: %[[CF0:.*]] = arith.constant 0.0{{0*e\+00}} : f16 +// CHECK: %[[LOADED_VAL:.*]] = vector.transfer_read %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]], %[[CF0]] {permutation_map = #[[$PERMUTATION_MAP]]} : tensor, vector<4x2xf16> +// CHECK: return %[[LOADED_VAL]] : vector<4x2xf16> +func.func @test_transfer_read_op_with_tensor(%base : tensor, + %offset0 : index, %offset1: index, %offset2: index) + -> vector<4x2xf16> { + %cf0 = arith.constant 0.0 : f16 + %loaded_val = vector.transfer_read %base[%offset0, %offset1, %offset2], %cf0 { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : tensor, vector<4x2xf16> + return %loaded_val : vector<4x2xf16> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation +} + +// ----- + +// Simple test for vector.transfer_write with fully dynamic memref. 
+// We also set a permutation map to make sure it is properly preserved. + +// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)> +// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)> +// CHECK-LABEL: @test_transfer_write_op( +// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^,]*}}>, +// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index) +// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]] +// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]] +// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]] +// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]] +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16> +// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref to memref> +// CHECK: vector.transfer_write %[[VCF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, memref> +// CHECK: return +func.func @test_transfer_write_op(%base : memref, + %offset0 : index, %offset1: index, %offset2: index) { + %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16> + vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, memref + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation +} + +// ----- + +// Check that the strides of the original memref are kept. +// Moreover even with non-1 strides the subview should still issue [1,...] +// strides, since this is a multiplication factor. 
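The reasoning behind the all-one subview strides: subview strides are multipliers composed on top of the source strides, so `[1, 1, 1]` leaves the source's (possibly non-unit, possibly dynamic) strides intact; e.g. source strides `[?, ?, ?]` composed with `[1, 1, 1]` remain `[?, ?, ?]` in the result type checked below.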
+ +// CHECK-DAG: #[[$A_MINUS_B_MAP:.*]] = affine_map<()[s0, s1] -> (s0 - s1)> +// CHECK-DAG: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)> +// CHECK-LABEL: @test_transfer_write_op_with_strides( +// CHECK-SAME: %[[BASE:[^:]*]]: memref<{{[^>]*}}>>, +// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index) +// CHECK-DAG: {{.*}}, {{.*}}, %[[DYN_SIZES:.*]]:3, {{.*}} = memref.extract_strided_metadata %[[BASE]] +// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#0, %[[DYN_OFFSET0]]] +// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#1, %[[DYN_OFFSET1]]] +// CHECK-DAG: %[[DYN_SIZE2:.*]] = affine.apply #[[$A_MINUS_B_MAP]]()[%[[DYN_SIZES]]#2, %[[DYN_OFFSET2]]] +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16> +// CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]] [1, 1, 1] : memref> to memref> +// CHECK: vector.transfer_write %[[VCF0]], %[[SUBVIEW]][%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, memref> +// CHECK: return +func.func @test_transfer_write_op_with_strides(%base : memref>, + %offset0 : index, %offset1: index, %offset2: index) { + %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16> + vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, memref> + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation +} +// ----- + +// Same as test_transfer_write_op but with tensors. +// Right now this rewrite is not supported but we still shouldn't choke on it. 
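The tensor variant is expected to be left untouched: `getTransferLikeOpSrcMemRef` returns failure when the source is not a memref, so the pattern bails out with "source is not a memref" and the original transfer op is kept as-is (a `memref.subview` cannot be taken of a tensor).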
+ +// CHECK: #[[$PERMUTATION_MAP:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)> +// CHECK-LABEL: @test_transfer_write_op_with_tensor( +// CHECK-SAME: %[[BASE:[^:]*]]: tensor<{{[^,]*}}>, +// CHECK-SAME: %[[DYN_OFFSET0:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET1:[^:]*]]: index, +// CHECK-SAME: %[[DYN_OFFSET2:[^:]*]]: index) +// CHECK-DAG: %[[VCF0:.*]] = arith.constant dense<0.0{{0*e\+00}}> : vector<4x2xf16> +// CHECK: %[[RES:.*]] = vector.transfer_write %[[VCF0]], %[[BASE]][%[[DYN_OFFSET0]], %[[DYN_OFFSET1]], %[[DYN_OFFSET2]]] {permutation_map = #[[$PERMUTATION_MAP]]} : vector<4x2xf16>, tensor +// CHECK: return %[[RES]] : tensor +func.func @test_transfer_write_op_with_tensor(%base : tensor, + %offset0 : index, %offset1: index, %offset2: index) -> tensor { + %vcf0 = arith.constant dense<0.000000e+00> : vector<4x2xf16> + %res = vector.transfer_write %vcf0, %base[%offset0, %offset1, %offset2] { permutation_map = affine_map<(d0,d1,d2) -> (d2,d0)> } : vector<4x2xf16>, tensor + return %res : tensor +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.extract_address_computations %0 : (!pdl.operation) -> !pdl.operation +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index b80ddae..6d75536 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -10098,6 +10098,7 @@ cc_library( ":LoopLikeInterface", ":MemRefDialect", ":MemRefPassIncGen", + ":NVGPUDialect", ":Pass", ":RuntimeVerifiableOpInterface", ":TensorDialect", @@ -10152,8 +10153,11 @@ cc_library( ":MemRefDialect", ":MemRefTransformOpsIncGen", ":MemRefTransforms", + ":NVGPUDialect", ":PDLDialect", ":TransformDialect", + ":TransformUtils", + ":VectorDialect", "//llvm:Support", ], ) -- 2.7.4