From 2c46051aa9d30dc1740f2183ceb45a235b994cc3 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 1 Feb 2023 19:35:25 +0000 Subject: [PATCH] [mlir][spirv] Fix vector type mismatch in UnifyAliasedResourcePass For the cases where we have aliases of `vector<4xf16>` and `vector<4xf32>`, we need to do casting before composite construction. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D143042 --- .../SPIRV/Transforms/UnifyAliasedResourcePass.cpp | 17 ++++++++++++ .../SPIRV/Transforms/unify-aliased-resource.mlir | 31 ++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp index 3e5b934..3acfba2 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -485,11 +485,28 @@ struct ConvertLoad : public ConvertAliasResource { // bitwidth element type. For spirv.bitcast, the lower-numbered components // of the vector map to lower-ordered bits of the larger bitwidth element // type. + Type vectorType = srcElemType; if (!srcElemType.isa()) vectorType = VectorType::get({ratio}, dstElemType); + + // If both the source and destination are vector types, we need to make + // sure the scalar type is the same for composite construction later. + if (auto srcElemVecType = srcElemType.dyn_cast()) + if (auto dstElemVecType = dstElemType.dyn_cast()) { + if (srcElemVecType.getElementType() != + dstElemVecType.getElementType()) { + int64_t count = + dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8); + auto castType = + VectorType::get({count}, srcElemVecType.getElementType()); + for (auto &c : components) + c = rewriter.create(loc, castType, c); + } + } Value vectorValue = rewriter.create( loc, vectorType, components); + if (!srcElemType.isa()) vectorValue = rewriter.create(loc, srcElemType, vectorValue); diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir index a456016..1d532f3 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir @@ -448,3 +448,34 @@ spirv.module Logical GLSL450 { // CHECK: spirv.GlobalVariable @var01_i16 bind(0, 1) {aliased} // CHECK: spirv.func @scalar_type_bitwidth_smaller_than_vector + +// ----- + +spirv.module Logical GLSL450 { + spirv.GlobalVariable @var00_v4f32 bind(0, 0) {aliased} : !spirv.ptr, stride=16> [0])>, StorageBuffer> + spirv.GlobalVariable @var00_v4f16 bind(0, 0) {aliased} : !spirv.ptr, stride=8> [0])>, StorageBuffer> + + spirv.func @vector_type_same_size_different_element_type(%i0: i32) -> vector<4xf32> "None" { + %c0 = spirv.Constant 0 : i32 + + %addr = spirv.mlir.addressof @var00_v4f32 : !spirv.ptr, stride=16> [0])>, StorageBuffer> + %ac = spirv.AccessChain %addr[%c0, %i0] : !spirv.ptr, stride=16> [0])>, StorageBuffer>, i32, i32 + %val = spirv.Load "StorageBuffer" %ac : vector<4xf32> + + spirv.ReturnValue %val : vector<4xf32> + } +} + +// CHECK-LABEL: spirv.module + +// CHECK: spirv.GlobalVariable @var00_v4f16 bind(0, 0) : !spirv.ptr, stride=8> [0])>, StorageBuffer> + +// CHECK: spirv.func @vector_type_same_size_different_element_type + +// CHECK: %[[LD0:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf16> +// CHECK: %[[LD1:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xf16> +// CHECK: %[[BC0:.+]] = spirv.Bitcast %[[LD0]] : vector<4xf16> to vector<2xf32> +// CHECK: %[[BC1:.+]] = spirv.Bitcast %[[LD1]] : vector<4xf16> to vector<2xf32> +// CHECK: %[[CC:.+]] = spirv.CompositeConstruct %[[BC0]], %[[BC1]] : (vector<2xf32>, vector<2xf32>) -> vector<4xf32> +// CHECK: spirv.ReturnValue %[[CC]] + -- 2.7.4