From 29ad9d6b26ee92c7843c06392625d894d58658c2 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Fri, 21 Feb 2020 14:39:32 -0500 Subject: [PATCH] [mlir][spirv] Add lowering for load/store zero-rank memref from std to SPIR-V. Differential Revision: https://reviews.llvm.org/D74874 --- mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 11 +++++-- .../Conversion/StandardToSPIRV/std-to-spirv.mlir | 38 ++++++++++++++++++++++ .../Dialect/SPIRV/Serialization/memory-ops.mlir | 34 +++++++++++++++++++ 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index 7730661..3cf5046 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -69,6 +69,9 @@ static Optional getTypeNumBytes(Type t) { if (!elementSize) { return llvm::None; } + if (memRefType.getRank() == 0) { + return elementSize; + } auto dims = memRefType.getShape(); if (llvm::is_contained(dims, ShapedType::kDynamicSize) || offset == MemRefType::getDynamicStrideOrOffset() || @@ -325,8 +328,12 @@ spirv::AccessChainOp mlir::spirv::getElementPtr( } SmallVector linearizedIndices; // Add a '0' at the start to index into the struct. - linearizedIndices.push_back(builder.create( - loc, indexType, IntegerAttr::get(indexType, 0))); + auto zero = spirv::ConstantOp::getZero(indexType, loc, &builder); + linearizedIndices.push_back(zero); + // If it is a zero-rank memref type, extract the element directly. + if (!ptrLoc) { + ptrLoc = zero; + } linearizedIndices.push_back(ptrLoc); return builder.create(loc, basePtr, linearizedIndices); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir index 1ec8b4c..341df27 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -312,3 +312,41 @@ func @sitofp(%arg0 : i32) { func @memref_type(%arg0: memref<3xi1>) { return } + +// CHECK-LABEL: @load_store_zero_rank_float +// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { + // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG0]][ + // CHECK-SAME: [[ZERO1]], [[ZERO1]] + // CHECK-SAME: ] : + // CHECK: spv.Load "StorageBuffer" %{{.*}} : f32 + %0 = load %arg0[] : memref + // CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG1]][ + // CHECK-SAME: [[ZERO2]], [[ZERO2]] + // CHECK-SAME: ] : + // CHECK: spv.Store "StorageBuffer" %{{.*}} : f32 + store %0, %arg1[] : memref + return +} + +// CHECK-LABEL: @load_store_zero_rank_int +// CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, +// CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) +func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { + // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG0]][ + // CHECK-SAME: [[ZERO1]], [[ZERO1]] + // CHECK-SAME: ] : + // CHECK: spv.Load "StorageBuffer" %{{.*}} : i32 + %0 = load %arg0[] : memref + // CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32 + // CHECK: spv.AccessChain [[ARG1]][ + // CHECK-SAME: [[ZERO2]], [[ZERO2]] + // CHECK-SAME: ] : + // CHECK: spv.Store "StorageBuffer" %{{.*}} : i32 + store %0, %arg1[] : memref + return +} diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir index 9c8be4e..d89f1ff 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -23,3 +23,37 @@ spv.module "Logical" "GLSL450" { spv.Return } } + +// ----- + +spv.module "Logical" "GLSL450" { + spv.func @load_store_zero_rank_float(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32 + %0 = spv.constant 0 : i32 + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> + %2 = spv.Load "StorageBuffer" %1 : f32 + + // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32 + %3 = spv.constant 0 : i32 + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> + spv.Store "StorageBuffer" %4, %2 : f32 + spv.Return + } + + spv.func @load_store_zero_rank_int(%arg0: !spv.ptr [0]>, StorageBuffer>, %arg1: !spv.ptr [0]>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32 + %0 = spv.constant 0 : i32 + %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr [0]>, StorageBuffer> + %2 = spv.Load "StorageBuffer" %1 : i32 + + // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [0]> + // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32 + %3 = spv.constant 0 : i32 + %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr [0]>, StorageBuffer> + spv.Store "StorageBuffer" %4, %2 : i32 + spv.Return + } +} -- 2.7.4