--- /dev/null
+//===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to generate ROCDLIR operations for higher-level
+// Vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// Wraps the already-converted `operands` of a vector transfer read in its
+/// operand adaptor. `xferOp` is not inspected; it only selects this overload.
+static TransferReadOpOperandAdaptor
+getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
+  (void)xferOp;
+  return TransferReadOpOperandAdaptor(operands);
+}
+
+/// Wraps the already-converted `operands` of a vector transfer write in its
+/// operand adaptor. `xferOp` is not inspected; it only selects this overload.
+static TransferWriteOpOperandAdaptor
+getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
+  (void)xferOp;
+  return TransferWriteOpOperandAdaptor(operands);
+}
+
+/// Replaces the transfer read `xferOp` with a `ROCDL::MubufLoadOp` whose
+/// result type is `vecTy` and whose operands are, in order, `dwordConfig`,
+/// `vindex`, `offsetSizeInBytes`, `glc` and `slc`. Always returns success.
+static LogicalResult replaceTransferOpWithMubuf(
+    ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
+    LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
+    LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
+    Value &offsetSizeInBytes, Value &glc, Value &slc) {
+  // The read overload does not need these; they are accepted so that both
+  // overloads can be called with the same argument list.
+  (void)operands;
+  (void)typeConverter;
+  (void)loc;
+
+  rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(
+      xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
+  return success();
+}
+
+/// Replaces the transfer write `xferOp` with a `ROCDL::MubufStoreOp`. The
+/// value to store is the converted `vector` operand taken from `operands`;
+/// the remaining operands are, in order, `dwordConfig`, `vindex`,
+/// `offsetSizeInBytes`, `glc` and `slc`. Always returns success.
+static LogicalResult replaceTransferOpWithMubuf(
+    ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
+    LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
+    LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
+    Value &offsetSizeInBytes, Value &glc, Value &slc) {
+  // The write overload does not need these; they are accepted so that both
+  // overloads can be called with the same argument list.
+  (void)typeConverter;
+  (void)loc;
+  (void)vecTy;
+
+  TransferWriteOpOperandAdaptor writeAdaptor(operands);
+  Value storedVector = writeAdaptor.vector();
+  rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(
+      xferOp, storedVector, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
+  return success();
+}
+
+namespace {
+/// Conversion pattern lowering a 1-D vector transfer read or write (selected
+/// through `ConcreteOp`) to a ROCDL MUBUF buffer load or store.
+///
+/// The pattern only applies when all of the following hold; otherwise it
+/// reports failure so that another lowering (e.g. the vector to llvm
+/// conversion pattern) can take over:
+///   * the vector has rank <= 1 and the op has at least one index,
+///   * the permutation map is a minor identity,
+///   * dimension 0 of the transfer is masked,
+///   * the converted vector type has exactly 2 or 4 elements,
+///   * the memref is in memory space 0 or 1.
+/// NOTE(review): this pattern is described as handling x2 / x4 f32 vectors,
+/// but the element type itself is not checked below -- confirm.
+template <typename ConcreteOp>
+class VectorTransferConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorTransferConversion(MLIRContext *context,
+                                    LLVMTypeConverter &typeConv)
+      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
+                             typeConv) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto transferOp = cast<ConcreteOp>(op);
+    auto transferAdaptor = getTransferOpAdapter(transferOp, operands);
+
+    // Reject multi-dimensional vectors and transfers without indices.
+    if (transferOp.getVectorType().getRank() > 1 ||
+        llvm::size(transferOp.indices()) == 0)
+      return failure();
+
+    // Only minor identity permutation maps are handled.
+    if (!AffineMap::isMinorIdentity(transferOp.permutation_map()))
+      return failure();
+
+    // Have unmasked transfers handled in the vector->llvm conversion pass.
+    if (!transferOp.isMaskedDim(0))
+      return failure();
+
+    auto convertToLLVM = [&](Type type) {
+      return typeConverter.convertType(type);
+    };
+    LLVM::LLVMType vecTy = convertToLLVM(transferOp.getVectorType())
+                               .template cast<LLVM::LLVMType>();
+    Location loc = op->getLoc();
+
+    // The backend result vector scalarization has trouble scalarizing a
+    // <1 x ty> result, so the x1 width is excluded from this lowering.
+    const unsigned numElements = vecTy.getVectorNumElements();
+    if (!(numElements == 2 || numElements == 4))
+      return failure();
+
+    // MUBUF instructions operate only on address space 0 (unified) or
+    // 1 (global).
+    // In case of 3 (LDS): fall back to the vector->llvm pass.
+    // In case of 5 (VGPR): wrong.
+    MemRefType memRefType = transferOp.getMemRefType();
+    const auto memorySpace = memRefType.getMemorySpace();
+    if (memorySpace != 0 && memorySpace != 1)
+      return failure();
+
+    // Note that dataPtr already points at the address selected by the
+    // indices, so there is no need to calculate an offset size in bytes
+    // again in the MUBUF instruction.
+    Value dataPtr =
+        getDataPtr(loc, memRefType, transferAdaptor.memref(),
+                   transferAdaptor.indices(), rewriter, getModule());
+
+    // Builtin and converted types used while assembling the descriptor.
+    Type i1Ty = rewriter.getIntegerType(1);
+    Type i32Ty = rewriter.getIntegerType(32);
+    Type i64Ty = rewriter.getIntegerType(64);
+    VectorType i32x4Ty = VectorType::get(4, i32Ty);
+    Type llvmI32x4Ty = convertToLLVM(i32x4Ty);
+    LLVM::LLVMType llvmI64Ty =
+        convertToLLVM(i64Ty).template cast<LLVM::LLVMType>();
+    LLVM::LLVMType llvmI64x2Ty = LLVM::LLVMType::getVectorTy(llvmI64Ty, 2);
+
+    // 1. Create and fill a <4 x i32> dwordConfig with:
+    //    1st two elements holding the address of dataPtr.
+    //    3rd element: -1.
+    //    4th element: 0x27000.
+    const int32_t configWords[] = {0, 0, -1, 0x27000};
+    Value constConfig = rewriter.create<LLVM::ConstantOp>(
+        loc, llvmI32x4Ty,
+        DenseElementsAttr::get(i32x4Ty, ArrayRef<int32_t>(configWords)));
+
+    // View the <4 x i32> as <2 x i64> and store dataPtr, converted to an
+    // i64, into element 0 (i.e. the first two i32 elements).
+    Value configAsI64x2 =
+        rewriter.create<LLVM::BitcastOp>(loc, llvmI64x2Ty, constConfig);
+    Value dataPtrAsI64 =
+        rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64Ty, dataPtr);
+    Value elementZero = createIndexConstant(rewriter, loc, 0);
+    Value configWithPtr = rewriter.create<LLVM::InsertElementOp>(
+        loc, llvmI64x2Ty, configAsI64x2, dataPtrAsI64, elementZero);
+    Value dwordConfig =
+        rewriter.create<LLVM::BitcastOp>(loc, llvmI32x4Ty, configWithPtr);
+
+    // 2. Rewrite op as a buffer read or write. The i32 zero is passed as
+    //    both vindex and offset, the i1 false as both glc and slc.
+    Value int1False = rewriter.create<LLVM::ConstantOp>(
+        loc, convertToLLVM(i1Ty), rewriter.getIntegerAttr(i1Ty, 0));
+    Value int32Zero = rewriter.create<LLVM::ConstantOp>(
+        loc, convertToLLVM(i32Ty), rewriter.getIntegerAttr(i32Ty, 0));
+    return replaceTransferOpWithMubuf(rewriter, operands, typeConverter, loc,
+                                      transferOp, vecTy, dwordConfig,
+                                      int32Zero, int32Zero, int1False,
+                                      int1False);
+  }
+};
+} // end anonymous namespace
+
+/// Appends to `patterns` the vector transfer read and transfer write
+/// conversion patterns defined in this file, constructed with `converter`
+/// and the MLIR context of its dialect.
+void mlir::populateVectorToROCDLConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  MLIRContext *context = converter.getDialect()->getContext();
+  patterns.insert<VectorTransferConversion<TransferReadOp>>(context,
+                                                            converter);
+  patterns.insert<VectorTransferConversion<TransferWriteOp>>(context,
+                                                             converter);
+}
+
+namespace {
+/// Pass wrapper around the vector to ROCDL conversion. All of the work is
+/// done in `runOnOperation`, which is defined out of line below.
+struct LowerVectorToROCDLPass
+ : public ConvertVectorToROCDLBase<LowerVectorToROCDLPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+/// Runs a partial conversion over the current operation using the vector to
+/// ROCDL patterns followed by the standard to LLVM patterns. The ROCDL
+/// dialect is marked legal on top of the LLVM conversion target. The pass is
+/// signalled as failed when the conversion fails.
+void LowerVectorToROCDLPass::runOnOperation() {
+  LLVMTypeConverter typeConverter(&getContext());
+
+  // Collect the rewrite patterns; the ROCDL specific ones are added first.
+  OwningRewritePatternList conversionPatterns;
+  populateVectorToROCDLConversionPatterns(typeConverter, conversionPatterns);
+  populateStdToLLVMConversionPatterns(typeConverter, conversionPatterns);
+
+  LLVMConversionTarget conversionTarget(getContext());
+  conversionTarget.addLegalDialect<ROCDL::ROCDLDialect>();
+
+  const LogicalResult conversionStatus = applyPartialConversion(
+      getOperation(), conversionTarget, conversionPatterns, &typeConverter);
+  if (failed(conversionStatus))
+    signalPassFailure();
+}
+
+/// Creates a new instance of the vector to ROCDL lowering pass.
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertVectorToROCDLPass() {
+  std::unique_ptr<OperationPass<ModuleOp>> pass =
+      std::make_unique<LowerVectorToROCDLPass>();
+  return pass;
+}
--- /dev/null
+// RUN: mlir-opt %s -convert-vector-to-rocdl | FileCheck %s
+
+// Transfer reads from a dynamically sized 1-D f32 memref; each function is
+// expected to be lowered to a rocdl.buffer.load.
+gpu.module @test_read{
+// Read of a 2-element vector: the load must produce a <2 x float>.
+func @transfer_readx2(%A : memref<?xf32>, %base: index) -> vector<2xf32> {
+ %f0 = constant 0.0: f32
+ %f = vector.transfer_read %A[%base], %f0
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<2xf32>
+ return %f: vector<2xf32>
+}
+// CHECK-LABEL: @transfer_readx2
+// CHECK: rocdl.buffer.load {{.*}} !llvm<"<2 x float>">
+
+// Read of a 4-element vector: the load must produce a <4 x float>.
+func @transfer_readx4(%A : memref<?xf32>, %base: index) -> vector<4xf32> {
+ %f0 = constant 0.0: f32
+ %f = vector.transfer_read %A[%base], %f0
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<4xf32>
+ return %f: vector<4xf32>
+}
+// CHECK-LABEL: @transfer_readx4
+// CHECK: rocdl.buffer.load {{.*}} !llvm<"<4 x float>">
+
+// Verifies how the buffer descriptor is assembled for a read: a
+// getelementptr result, the constant [0, 0, -1, 159744] (159744 is 0x27000),
+// a ptrtoint of that getelementptr result and an insertelement of the
+// resulting integer, in this order.
+func @transfer_read_dwordConfig(%A : memref<?xf32>, %base: index) -> vector<4xf32> {
+ %f0 = constant 0.0: f32
+ %f = vector.transfer_read %A[%base], %f0
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<4xf32>
+ return %f: vector<4xf32>
+}
+// CHECK-LABEL: @transfer_read_dwordConfig
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}
+// CHECK: [0, 0, -1, 159744]
+// CHECK: %[[i64:.*]] = llvm.ptrtoint %[[gep]]
+// CHECK: llvm.insertelement %[[i64]]
+}
+
+// Transfer writes to a dynamically sized 1-D f32 memref; each function is
+// expected to be lowered to a rocdl.buffer.store.
+gpu.module @test_write{
+// Write of a 2-element vector: the store must take a <2 x float>.
+func @transfer_writex2(%A : memref<?xf32>, %B : vector<2xf32>, %base: index) {
+ vector.transfer_write %B, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<2xf32>, memref<?xf32>
+ return
+}
+// CHECK-LABEL: @transfer_writex2
+// CHECK: rocdl.buffer.store {{.*}} !llvm<"<2 x float>">
+
+// Write of a 4-element vector: the store must take a <4 x float>.
+func @transfer_writex4(%A : memref<?xf32>, %B : vector<4xf32>, %base: index) {
+ vector.transfer_write %B, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<4xf32>, memref<?xf32>
+ return
+}
+// CHECK-LABEL: @transfer_writex4
+// CHECK: rocdl.buffer.store {{.*}} !llvm<"<4 x float>">
+
+// Verifies how the buffer descriptor is assembled for a write: a
+// getelementptr result, the constant [0, 0, -1, 159744] (159744 is 0x27000),
+// a ptrtoint of that getelementptr result and an insertelement of the
+// resulting integer, in this order.
+func @transfer_write_dwordConfig(%A : memref<?xf32>, %B : vector<2xf32>, %base: index) {
+ vector.transfer_write %B, %A[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<2xf32>, memref<?xf32>
+ return
+}
+// CHECK-LABEL: @transfer_write_dwordConfig
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}
+// CHECK: [0, 0, -1, 159744]
+// CHECK: %[[i64:.*]] = llvm.ptrtoint %[[gep]]
+// CHECK: llvm.insertelement %[[i64]]
+}
--- /dev/null
+// RUN: mlir-rocm-runner %s --shared-libs=%rocm_wrapper_library_dir/librocm-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// Launches a GPU kernel with every grid and block dimension set to 1. The
+// kernel reads two f32 values from %arg0 starting at index 0, adds the
+// vector to itself, and writes the result to %arg1 starting at index 1.
+func @vectransferx2(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>) {
+ %cst = constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
+ threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, %block_z = %cst) {
+ %f0 = constant 0.0: f32
+ %base = constant 0 : index
+ %f = vector.transfer_read %arg0[%base], %f0
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<2xf32>
+
+ %c = addf %f, %f : vector<2xf32>
+
+ // The write starts at index 1, so elements 1 and 2 of %arg1 are written.
+ %base1 = constant 1 : index
+ vector.transfer_write %c, %arg1[%base1]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<2xf32>, memref<?xf32>
+
+ gpu.terminator
+ }
+ return
+}
+
+// Launches a GPU kernel with every grid and block dimension set to 1. The
+// kernel reads four f32 values from %arg0 starting at index 0, adds the
+// vector to itself, and writes the result to %arg1 starting at index 0.
+func @vectransferx4(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>) {
+ %cst = constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, %grid_z = %cst)
+ threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, %block_z = %cst) {
+ %f0 = constant 0.0: f32
+ %base = constant 0 : index
+ %f = vector.transfer_read %arg0[%base], %f0
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<?xf32>, vector<4xf32>
+
+ %c = addf %f, %f : vector<4xf32>
+
+ // Read and write share the same base index 0.
+ vector.transfer_write %c, %arg1[%base]
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ vector<4xf32>, memref<?xf32>
+
+ gpu.terminator
+ }
+ return
+}
+
+// Test driver: fills two 4-element f32 buffers with 1.23, registers them
+// with the runtime, obtains device memrefs for them, and runs both kernels,
+// printing the destination buffer after each one.
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %cf1dot23 = constant 1.23 : f32
+
+  %arg0 = alloc() : memref<4xf32>
+  %arg1 = alloc() : memref<4xf32>
+
+  // Give every element of both buffers a known value. The expected outputs
+  // below (1.23 where nothing was written, 2.46 = 1.23 + 1.23 where a kernel
+  // wrote) are only meaningful if the buffers start out holding 1.23.
+  store %cf1dot23, %arg0[%c0] : memref<4xf32>
+  store %cf1dot23, %arg0[%c1] : memref<4xf32>
+  store %cf1dot23, %arg0[%c2] : memref<4xf32>
+  store %cf1dot23, %arg0[%c3] : memref<4xf32>
+  store %cf1dot23, %arg1[%c0] : memref<4xf32>
+  store %cf1dot23, %arg1[%c1] : memref<4xf32>
+  store %cf1dot23, %arg1[%c2] : memref<4xf32>
+  store %cf1dot23, %arg1[%c3] : memref<4xf32>
+
+  %22 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
+  %23 = memref_cast %arg1 : memref<4xf32> to memref<?xf32>
+
+  %cast0 = memref_cast %22 : memref<?xf32> to memref<*xf32>
+  %cast1 = memref_cast %23 : memref<?xf32> to memref<*xf32>
+
+  call @mgpuMemHostRegisterFloat(%cast0) : (memref<*xf32>) -> ()
+  call @mgpuMemHostRegisterFloat(%cast1) : (memref<*xf32>) -> ()
+
+  %24 = call @mgpuMemGetDeviceMemRef1dFloat(%22) : (memref<?xf32>) -> (memref<?xf32>)
+  %26 = call @mgpuMemGetDeviceMemRef1dFloat(%23) : (memref<?xf32>) -> (memref<?xf32>)
+
+  // The x2 kernel writes elements 1 and 2 of the destination only.
+  // CHECK: [1.23, 2.46, 2.46, 1.23]
+  call @vectransferx2(%24, %26) : (memref<?xf32>, memref<?xf32>) -> ()
+  call @print_memref_f32(%cast1) : (memref<*xf32>) -> ()
+
+  // The x4 kernel writes all four elements of the destination.
+  // CHECK: [2.46, 2.46, 2.46, 2.46]
+  call @vectransferx4(%24, %26) : (memref<?xf32>, memref<?xf32>) -> ()
+  call @print_memref_f32(%cast1) : (memref<*xf32>) -> ()
+  return
+}
+
+// Declarations of functions that have no body in this file and are called
+// from @main. Presumably they are resolved from the libraries passed via
+// --shared-libs on the run line -- verify against the runner setup.
+func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
+func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
+func @print_memref_f32(%ptr : memref<*xf32>)