#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <algorithm>
+#include <iterator>
#define DEBUG_TYPE "spirv-unify-aliased-resource"
return rtArrayType.getElementType();
}
-/// Returns true if all `types`, which can either be scalar or vector types,
-/// have the same bitwidth base scalar type.
-static bool hasSameBitwidthScalarType(ArrayRef<spirv::SPIRVType> types) {
- SmallVector<int64_t> scalarTypes;
- scalarTypes.reserve(types.size());
+/// Given a list of resource element `types`, returns the index of the canonical
+/// resource that all resources should be unified into. Returns llvm::None if
+/// unable to unify.
+static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
+ SmallVector<int> scalarNumBits, totalNumBits;
+ scalarNumBits.reserve(types.size());
+ totalNumBits.reserve(types.size());
+ bool hasVector = false;
+
for (spirv::SPIRVType type : types) {
assert(type.isScalarOrVector());
- if (auto vectorType = type.dyn_cast<VectorType>())
- scalarTypes.push_back(
+ if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (vectorType.getNumElements() % 2 != 0)
+ return llvm::None; // Odd-sized vector has special layout requirements.
+
+ Optional<int64_t> numBytes = type.getSizeInBytes();
+ if (!numBytes)
+ return llvm::None;
+
+ scalarNumBits.push_back(
vectorType.getElementType().getIntOrFloatBitWidth());
- else
- scalarTypes.push_back(type.getIntOrFloatBitWidth());
+ totalNumBits.push_back(*numBytes * 8);
+ hasVector = true;
+ } else {
+ scalarNumBits.push_back(type.getIntOrFloatBitWidth());
+ totalNumBits.push_back(scalarNumBits.back());
+ }
}
- return llvm::is_splat(scalarTypes);
+
+ if (hasVector) {
+ // If there are vector types, require all scalar element bitwidths to be
+ // the same for now to simplify the transformation.
+ if (!llvm::is_splat(scalarNumBits))
+ return llvm::None;
+
+ // Choose the one with the largest bitwidth as the canonical resource, so
+ // that we can still keep vectorized load/store.
+ auto *maxVal = std::max_element(totalNumBits.begin(), totalNumBits.end());
+ // Make sure that the canonical resource's bitwidth is divisible by others.
+ // Without this, we cannot properly adjust the index later.
+ if (llvm::any_of(totalNumBits,
+ [maxVal](int64_t bits) { return *maxVal % bits != 0; }))
+ return llvm::None;
+
+ return std::distance(totalNumBits.begin(), maxVal);
+ }
+
+ // All element types are scalars. Then choose the smallest bitwidth as the
+ // canonical resource to avoid subcomponent load/store.
+ auto *minVal = std::min_element(scalarNumBits.begin(), scalarNumBits.end());
+ if (llvm::any_of(scalarNumBits,
+ [minVal](int64_t bit) { return bit % *minVal != 0; }))
+ return llvm::None;
+ return std::distance(scalarNumBits.begin(), minVal);
+}
+
+static bool areSameBitwidthScalarType(Type a, Type b) {
+ return a.isIntOrFloat() && b.isIntOrFloat() &&
+ a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
}
//===----------------------------------------------------------------------===//
void ResourceAliasAnalysis::recordIfUnifiable(
const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
- // Collect the element types and byte counts for all resources in the
- // current set.
+ // Collect the element types for all resources in the current set.
SmallVector<spirv::SPIRVType> elementTypes;
- SmallVector<int64_t> numBytes;
-
for (spirv::GlobalVariableOp resource : resources) {
Type elementType = getRuntimeArrayElementType(resource.type());
if (!elementType)
if (!type.isScalarOrVector())
return; // Unexpected resource element type.
- if (auto vectorType = type.dyn_cast<VectorType>())
- if (vectorType.getNumElements() % 2 != 0)
- return; // Odd-sized vector has special layout requirements.
-
- Optional<int64_t> count = type.getSizeInBytes();
- if (!count)
- return;
-
elementTypes.push_back(type);
- numBytes.push_back(*count);
}
- // Make sure base scalar types have the same bitwdith, so that we don't need
- // to handle extracting components for now.
- if (!hasSameBitwidthScalarType(elementTypes))
- return;
-
- // Make sure that the canonical resource's bitwidth is divisible by others.
- // With out this, we cannot properly adjust the index later.
- auto *maxCount = std::max_element(numBytes.begin(), numBytes.end());
- if (llvm::any_of(numBytes, [maxCount](int64_t count) {
- return *maxCount % count != 0;
- }))
+ Optional<int> index = deduceCanonicalResource(elementTypes);
+ if (!index)
return;
- spirv::GlobalVariableOp canonicalResource =
- resources[std::distance(numBytes.begin(), maxCount)];
-
// Update internal data structures for later use.
resourceMap[descriptor].assign(resources.begin(), resources.end());
- canonicalResourceMap[descriptor] = canonicalResource;
+ canonicalResourceMap[descriptor] = resources[*index];
for (const auto &resource : llvm::enumerate(resources)) {
descriptorMap[resource.value()] = descriptor;
elementTypeMap[resource.value()] = elementTypes[resource.index()];
spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
- if ((srcElemType == dstElemType) ||
- (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) {
+ if (srcElemType == dstElemType ||
+ areSameBitwidthScalarType(srcElemType, dstElemType)) {
// We have the same bitwidth for source and destination element types.
// The indices stay the same.
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
// them into a buffer with vector element types. We need to scale the last
// index for the vector as a whole, then add one level of index for inside
// the vector.
- int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes();
+ int srcNumBits = *srcElemType.getSizeInBytes();
+ int dstNumBits = *dstElemType.getSizeInBytes();
+ assert(dstNumBits > srcNumBits && dstNumBits % srcNumBits == 0);
+ int ratio = dstNumBits / srcNumBits;
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
return success();
}
+ if (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) {
+ // The source indices are for a buffer with larger bitwidth scalar element
+ // types. Rewrite them into a buffer with smaller bitwidth element types.
+ // We only need to scale the last index.
+ int srcNumBits = *srcElemType.getSizeInBytes();
+ int dstNumBits = *dstElemType.getSizeInBytes();
+ assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
+ int ratio = srcNumBits / dstNumBits;
+ auto ratioValue = rewriter.create<spirv::ConstantOp>(
+ loc, i32Type, rewriter.getI32IntegerAttr(ratio));
+
+ auto indices = llvm::to_vector<4>(acOp.indices());
+ Value oldIndex = indices.back();
+ indices.back() =
+ rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
+
+ rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+ acOp, adaptor.base_ptr(), indices);
+ return success();
+ }
+
return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
}
};
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
if (srcElemType == dstElemType) {
rewriter.replaceOp(loadOp, newLoadOp->getResults());
- } else {
+ return success();
+ }
+
+ if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
newLoadOp.value());
rewriter.replaceOp(loadOp, castOp->getResults());
+
+ return success();
}
+ // The source and destination have scalar types of different bitwidths.
+ // For such cases, we need to load multiple smaller bitwidth values and
+ // construct a larger bitwidth one.
+
+ int srcNumBits = srcElemType.getIntOrFloatBitWidth();
+ int dstNumBits = dstElemType.getIntOrFloatBitWidth();
+ assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
+ int ratio = srcNumBits / dstNumBits;
+ if (ratio > 4)
+ return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
+
+ SmallVector<Value> components;
+ components.reserve(ratio);
+ components.push_back(newLoadOp);
+
+ auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
+ if (!acOp)
+ return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");
+
+ auto i32Type = rewriter.getI32Type();
+ Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
+ auto indices = llvm::to_vector<4>(acOp.indices());
+ for (int i = 1; i < ratio; ++i) {
+ // Load all subsequent components belonging to this element.
+ indices.back() = rewriter.create<spirv::IAddOp>(loc, i32Type,
+ indices.back(), oneValue);
+ auto componentAcOp =
+ rewriter.create<spirv::AccessChainOp>(loc, acOp.base_ptr(), indices);
+ components.push_back(rewriter.create<spirv::LoadOp>(loc, componentAcOp));
+ }
+ std::reverse(components.begin(), components.end()); // For little endian.
+
+ // Create a vector of the components and then cast back to the larger
+ // bitwidth element type.
+ auto vectorType = VectorType::get({ratio}, dstElemType);
+ Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
+ loc, vectorType, components);
+ rewriter.replaceOpWithNewOp<spirv::BitcastOp>(loadOp, srcElemType,
+ vectorValue);
return success();
}
};
adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
return rewriter.notifyMatchFailure(storeOp, "not scalar type");
+ if (!areSameBitwidthScalarType(srcElemType, dstElemType))
+ return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
Location loc = storeOp.getLoc();
Value value = adaptor.value();
-// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource -verify-diagnostics %s | FileCheck %s
spv.module Logical GLSL450 {
spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
// CHECK: %[[CAST2:.+]] = spv.Bitcast %[[VAL0]] : i32 to f32
// CHECK: spv.Store "StorageBuffer" %[[AC]], %[[CAST2]] : f32
// CHECK: spv.ReturnValue %[[CAST1]] : i32
+
+// -----
+
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01s_i64 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+
+ spv.func @load_different_scalar_bitwidth(%index: i32) -> i64 "None" {
+ %c0 = spv.Constant 0 : i32
+
+ %addr0 = spv.mlir.addressof @var01s_i64 : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>
+ %ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>, i32, i32
+ %val0 = spv.Load "StorageBuffer" %ac0 : i64
+
+ spv.ReturnValue %val0 : i64
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s_i64
+// CHECK: spv.GlobalVariable @var01s_f32 bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s_i64
+
+// CHECK: spv.func @load_different_scalar_bitwidth(%[[INDEX:.+]]: i32)
+// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
+// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01s_f32
+
+// CHECK: %[[TWO:.+]] = spv.Constant 2 : i32
+// CHECK: %[[BASE:.+]] = spv.IMul %[[INDEX]], %[[TWO]] : i32
+// CHECK: %[[AC0:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[BASE]]]
+// CHECK: %[[LOAD0:.+]] = spv.Load "StorageBuffer" %[[AC0]] : f32
+
+// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+// CHECK: %[[ADD:.+]] = spv.IAdd %[[BASE]], %[[ONE]] : i32
+// CHECK: %[[AC1:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[ADD]]]
+// CHECK: %[[LOAD1:.+]] = spv.Load "StorageBuffer" %[[AC1]] : f32
+
+// CHECK: %[[CC:.+]] = spv.CompositeConstruct %[[LOAD1]], %[[LOAD0]]
+// CHECK: %[[CAST:.+]] = spv.Bitcast %[[CC]] : vector<2xf32> to i64
+// CHECK: spv.ReturnValue %[[CAST]]
+
+// -----
+
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01s_i64 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+
+ spv.func @store_different_scalar_bitwidth(%i0: i32, %i1: i32) "None" {
+ %c0 = spv.Constant 0 : i32
+
+ %addr0 = spv.mlir.addressof @var01s_f32 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac0 = spv.AccessChain %addr0[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %f32val = spv.Load "StorageBuffer" %ac0 : f32
+ %f64val = spv.FConvert %f32val : f32 to f64
+ %i64val = spv.Bitcast %f64val : f64 to i64
+
+ %addr1 = spv.mlir.addressof @var01s_i64 : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>
+ %ac1 = spv.AccessChain %addr1[%c0, %i1] : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>, i32, i32
+ // expected-error@+1 {{failed to legalize operation 'spv.Store'}}
+ spv.Store "StorageBuffer" %ac1, %i64val : i64
+
+ spv.Return
+ }
+}