From 11cf2d5f62f94b14644fab3478f17af9cb015706 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 10 Jun 2022 18:01:31 -0400 Subject: [PATCH] [mlir][spirv] Unify aliases of different bitwidth scalar types This commit extends the UnifyAliasedResourcePass to handle scalar types of different bitwidths. It requires to get the smaller bitwidth resource as the canonical resource so that we can avoid subcomponent load/store. Instead we load/store multiple smaller bitwidth ones. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D127266 --- .../SPIRV/Transforms/UnifyAliasedResourcePass.cpp | 176 ++++++++++++++++----- .../SPIRV/Transforms/unify-aliased-resource.mlir | 67 +++++++- 2 files changed, 200 insertions(+), 43 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp index e2b83a8..2dc4d73 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include +#include #define DEBUG_TYPE "spirv-unify-aliased-resource" @@ -72,20 +73,65 @@ static Type getRuntimeArrayElementType(Type type) { return rtArrayType.getElementType(); } -/// Returns true if all `types`, which can either be scalar or vector types, -/// have the same bitwidth base scalar type. -static bool hasSameBitwidthScalarType(ArrayRef types) { - SmallVector scalarTypes; - scalarTypes.reserve(types.size()); +/// Given a list of resource element `types`, returns the index of the canonical +/// resource that all resources should be unified into. Returns llvm::None if +/// unable to unify. +static Optional deduceCanonicalResource(ArrayRef types) { + SmallVector scalarNumBits, totalNumBits; + scalarNumBits.reserve(types.size()); + totalNumBits.reserve(types.size()); + bool hasVector = false; + for (spirv::SPIRVType type : types) { assert(type.isScalarOrVector()); - if (auto vectorType = type.dyn_cast()) - scalarTypes.push_back( + if (auto vectorType = type.dyn_cast()) { + if (vectorType.getNumElements() % 2 != 0) + return llvm::None; // Odd-sized vector has special layout requirements. + + Optional numBytes = type.getSizeInBytes(); + if (!numBytes) + return llvm::None; + + scalarNumBits.push_back( vectorType.getElementType().getIntOrFloatBitWidth()); - else - scalarTypes.push_back(type.getIntOrFloatBitWidth()); + totalNumBits.push_back(*numBytes * 8); + hasVector = true; + } else { + scalarNumBits.push_back(type.getIntOrFloatBitWidth()); + totalNumBits.push_back(scalarNumBits.back()); + } } - return llvm::is_splat(scalarTypes); + + if (hasVector) { + // If there are vector types, require all element types to be the same for + // now to simplify the transformation. + if (!llvm::is_splat(scalarNumBits)) + return llvm::None; + + // Choose the one with the largest bitwidth as the canonical resource, so + // that we can still keep vectorized load/store. + auto *maxVal = std::max_element(totalNumBits.begin(), totalNumBits.end()); + // Make sure that the canonical resource's bitwidth is divisible by others. + // With out this, we cannot properly adjust the index later. + if (llvm::any_of(totalNumBits, + [maxVal](int64_t bits) { return *maxVal % bits != 0; })) + return llvm::None; + + return std::distance(totalNumBits.begin(), maxVal); + } + + // All element types are scalars. Then choose the smallest bitwidth as the + // cannonical resource to avoid subcomponent load/store. + auto *minVal = std::min_element(scalarNumBits.begin(), scalarNumBits.end()); + if (llvm::any_of(scalarNumBits, + [minVal](int64_t bit) { return bit % *minVal != 0; })) + return llvm::None; + return std::distance(scalarNumBits.begin(), minVal); +} + +static bool areSameBitwidthScalarType(Type a, Type b) { + return a.isIntOrFloat() && b.isIntOrFloat() && + a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth(); } //===----------------------------------------------------------------------===// @@ -203,11 +249,8 @@ ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const { void ResourceAliasAnalysis::recordIfUnifiable( const Descriptor &descriptor, ArrayRef resources) { - // Collect the element types and byte counts for all resources in the - // current set. + // Collect the element types for all resources in the current set. SmallVector elementTypes; - SmallVector numBytes; - for (spirv::GlobalVariableOp resource : resources) { Type elementType = getRuntimeArrayElementType(resource.type()); if (!elementType) @@ -217,37 +260,16 @@ void ResourceAliasAnalysis::recordIfUnifiable( if (!type.isScalarOrVector()) return; // Unexpected resource element type. - if (auto vectorType = type.dyn_cast()) - if (vectorType.getNumElements() % 2 != 0) - return; // Odd-sized vector has special layout requirements. - - Optional count = type.getSizeInBytes(); - if (!count) - return; - elementTypes.push_back(type); - numBytes.push_back(*count); } - // Make sure base scalar types have the same bitwdith, so that we don't need - // to handle extracting components for now. - if (!hasSameBitwidthScalarType(elementTypes)) - return; - - // Make sure that the canonical resource's bitwidth is divisible by others. - // With out this, we cannot properly adjust the index later. - auto *maxCount = std::max_element(numBytes.begin(), numBytes.end()); - if (llvm::any_of(numBytes, [maxCount](int64_t count) { - return *maxCount % count != 0; - })) + Optional index = deduceCanonicalResource(elementTypes); + if (!index) return; - spirv::GlobalVariableOp canonicalResource = - resources[std::distance(numBytes.begin(), maxCount)]; - // Update internal data structures for later use. resourceMap[descriptor].assign(resources.begin(), resources.end()); - canonicalResourceMap[descriptor] = canonicalResource; + canonicalResourceMap[descriptor] = resources[*index]; for (const auto &resource : llvm::enumerate(resources)) { descriptorMap[resource.value()] = descriptor; elementTypeMap[resource.value()] = elementTypes[resource.index()]; @@ -316,8 +338,8 @@ struct ConvertAccessChain : public ConvertAliasResource { spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp); spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp); - if ((srcElemType == dstElemType) || - (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) { + if (srcElemType == dstElemType || + areSameBitwidthScalarType(srcElemType, dstElemType)) { // We have the same bitwidth for source and destination element types. // Thie indices keep the same. rewriter.replaceOpWithNewOp( @@ -333,7 +355,10 @@ struct ConvertAccessChain : public ConvertAliasResource { // them into a buffer with vector element types. We need to scale the last // index for the vector as a whole, then add one level of index for inside // the vector. - int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes(); + int srcNumBits = *srcElemType.getSizeInBytes(); + int dstNumBits = *dstElemType.getSizeInBytes(); + assert(dstNumBits > srcNumBits && dstNumBits % srcNumBits == 0); + int ratio = dstNumBits / srcNumBits; auto ratioValue = rewriter.create( loc, i32Type, rewriter.getI32IntegerAttr(ratio)); @@ -349,6 +374,27 @@ struct ConvertAccessChain : public ConvertAliasResource { return success(); } + if (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) { + // The source indices are for a buffer with larger bitwidth scalar element + // types. Rewrite them into a buffer with smaller bitwidth element types. + // We only need to scale the last index. + int srcNumBits = *srcElemType.getSizeInBytes(); + int dstNumBits = *dstElemType.getSizeInBytes(); + assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0); + int ratio = srcNumBits / dstNumBits; + auto ratioValue = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(ratio)); + + auto indices = llvm::to_vector<4>(acOp.indices()); + Value oldIndex = indices.back(); + indices.back() = + rewriter.create(loc, i32Type, oldIndex, ratioValue); + + rewriter.replaceOpWithNewOp( + acOp, adaptor.base_ptr(), indices); + return success(); + } + return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types"); } }; @@ -370,12 +416,56 @@ struct ConvertLoad : public ConvertAliasResource { auto newLoadOp = rewriter.create(loc, adaptor.ptr()); if (srcElemType == dstElemType) { rewriter.replaceOp(loadOp, newLoadOp->getResults()); - } else { + return success(); + } + + if (areSameBitwidthScalarType(srcElemType, dstElemType)) { auto castOp = rewriter.create(loc, srcElemType, newLoadOp.value()); rewriter.replaceOp(loadOp, castOp->getResults()); + + return success(); } + // The source and destination have scalar types of different bitwidths. + // For such cases, we need to load multiple smaller bitwidth values and + // construct a larger bitwidth one. + + int srcNumBits = srcElemType.getIntOrFloatBitWidth(); + int dstNumBits = dstElemType.getIntOrFloatBitWidth(); + assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0); + int ratio = srcNumBits / dstNumBits; + if (ratio > 4) + return rewriter.notifyMatchFailure(loadOp, "more than 4 components"); + + SmallVector components; + components.reserve(ratio); + components.push_back(newLoadOp); + + auto acOp = adaptor.ptr().getDefiningOp(); + if (!acOp) + return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain"); + + auto i32Type = rewriter.getI32Type(); + Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter); + auto indices = llvm::to_vector<4>(acOp.indices()); + for (int i = 1; i < ratio; ++i) { + // Load all subsequent components belonging to this element. + indices.back() = rewriter.create(loc, i32Type, + indices.back(), oneValue); + auto componentAcOp = + rewriter.create(loc, acOp.base_ptr(), indices); + components.push_back(rewriter.create(loc, componentAcOp)); + } + std::reverse(components.begin(), components.end()); // For little endian.. + + // Create a vector of the components and then cast back to the larger + // bitwidth element type. + auto vectorType = VectorType::get({ratio}, dstElemType); + Value vectorValue = rewriter.create( + loc, vectorType, components); + rewriter.replaceOpWithNewOp(loadOp, srcElemType, + vectorValue); return success(); } }; @@ -392,6 +482,8 @@ struct ConvertStore : public ConvertAliasResource { adaptor.ptr().getType().cast().getPointeeType(); if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) return rewriter.notifyMatchFailure(storeOp, "not scalar type"); + if (!areSameBitwidthScalarType(srcElemType, dstElemType)) + return rewriter.notifyMatchFailure(storeOp, "different bitwidth"); Location loc = storeOp.getLoc(); Value value = adaptor.value(); diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir index 546fc1f..0b36178 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource %s -o - | FileCheck %s +// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource -verify-diagnostics %s | FileCheck %s spv.module Logical GLSL450 { spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> @@ -213,3 +213,68 @@ spv.module Logical GLSL450 { // CHECK: %[[CAST2:.+]] = spv.Bitcast %[[VAL0]] : i32 to f32 // CHECK: spv.Store "StorageBuffer" %[[AC]], %[[CAST2]] : f32 // CHECK: spv.ReturnValue %[[CAST1]] : i32 + +// ----- + +spv.module Logical GLSL450 { + spv.GlobalVariable @var01s_i64 bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> + spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> + + spv.func @load_different_scalar_bitwidth(%index: i32) -> i64 "None" { + %c0 = spv.Constant 0 : i32 + + %addr0 = spv.mlir.addressof @var01s_i64 : !spv.ptr [0])>, StorageBuffer> + %ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr [0])>, StorageBuffer>, i32, i32 + %val0 = spv.Load "StorageBuffer" %ac0 : i64 + + spv.ReturnValue %val0 : i64 + } +} + +// CHECK-LABEL: spv.module + +// CHECK-NOT: @var01s_i64 +// CHECK: spv.GlobalVariable @var01s_f32 bind(0, 1) : !spv.ptr [0])>, StorageBuffer> +// CHECK-NOT: @var01s_i64 + +// CHECK: spv.func @load_different_scalar_bitwidth(%[[INDEX:.+]]: i32) +// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 +// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01s_f32 + +// CHECK: %[[TWO:.+]] = spv.Constant 2 : i32 +// CHECK: %[[BASE:.+]] = spv.IMul %[[INDEX]], %[[TWO]] : i32 +// CHECK: %[[AC0:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[BASE]]] +// CHECK: %[[LOAD0:.+]] = spv.Load "StorageBuffer" %[[AC0]] : f32 + +// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 +// CHECK: %[[ADD:.+]] = spv.IAdd %[[BASE]], %[[ONE]] : i32 +// CHECK: %[[AC1:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[ADD]]] +// CHECK: %[[LOAD1:.+]] = spv.Load "StorageBuffer" %[[AC1]] : f32 + +// CHECK: %[[CC:.+]] = spv.CompositeConstruct %[[LOAD1]], %[[LOAD0]] +// CHECK: %[[CAST:.+]] = spv.Bitcast %[[CC]] : vector<2xf32> to i64 +// CHECK: spv.ReturnValue %[[CAST]] + +// ----- + +spv.module Logical GLSL450 { + spv.GlobalVariable @var01s_i64 bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> + spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr [0])>, StorageBuffer> + + spv.func @store_different_scalar_bitwidth(%i0: i32, %i1: i32) "None" { + %c0 = spv.Constant 0 : i32 + + %addr0 = spv.mlir.addressof @var01s_f32 : !spv.ptr [0])>, StorageBuffer> + %ac0 = spv.AccessChain %addr0[%c0, %i0] : !spv.ptr [0])>, StorageBuffer>, i32, i32 + %f32val = spv.Load "StorageBuffer" %ac0 : f32 + %f64val = spv.FConvert %f32val : f32 to f64 + %i64val = spv.Bitcast %f64val : f64 to i64 + + %addr1 = spv.mlir.addressof @var01s_i64 : !spv.ptr [0])>, StorageBuffer> + %ac1 = spv.AccessChain %addr1[%c0, %i1] : !spv.ptr [0])>, StorageBuffer>, i32, i32 + // expected-error@+1 {{failed to legalize operation 'spv.Store'}} + spv.Store "StorageBuffer" %ac1, %i64val : i64 + + spv.Return + } +} -- 2.7.4