}
ir::Instruction* ConstantManager::BuildInstructionAndAddToModule(
- std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos) {
+ std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos,
+ uint32_t type_id) {
analysis::Constant* new_const = c.get();
uint32_t new_id = context()->TakeNextId();
const_val_to_id_[new_const] = new_id;
id_to_const_val_[new_id] = std::move(c);
- auto new_inst = CreateInstruction(new_id, new_const);
+ auto new_inst = CreateInstruction(new_id, new_const, type_id);
if (!new_inst) return nullptr;
auto* new_inst_ptr = new_inst.get();
*pos = pos->InsertBefore(std::move(new_inst));
}
std::unique_ptr<ir::Instruction> ConstantManager::CreateInstruction(
- uint32_t id, analysis::Constant* c) const {
+ uint32_t id, analysis::Constant* c, uint32_t type_id) const {
+ uint32_t type =
+ (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id;
if (c->AsNullConstant()) {
- return MakeUnique<ir::Instruction>(
- context(), SpvOp::SpvOpConstantNull,
- context()->get_type_mgr()->GetId(c->type()), id,
- std::initializer_list<ir::Operand>{});
+ return MakeUnique<ir::Instruction>(context(), SpvOp::SpvOpConstantNull,
+ type, id,
+ std::initializer_list<ir::Operand>{});
} else if (analysis::BoolConstant* bc = c->AsBoolConstant()) {
return MakeUnique<ir::Instruction>(
context(),
bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
- context()->get_type_mgr()->GetId(c->type()), id,
- std::initializer_list<ir::Operand>{});
+ type, id, std::initializer_list<ir::Operand>{});
} else if (analysis::IntConstant* ic = c->AsIntConstant()) {
return MakeUnique<ir::Instruction>(
- context(), SpvOp::SpvOpConstant,
- context()->get_type_mgr()->GetId(c->type()), id,
+ context(), SpvOp::SpvOpConstant, type, id,
std::initializer_list<ir::Operand>{ir::Operand(
spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
ic->words())});
} else if (analysis::FloatConstant* fc = c->AsFloatConstant()) {
return MakeUnique<ir::Instruction>(
- context(), SpvOp::SpvOpConstant,
- context()->get_type_mgr()->GetId(c->type()), id,
+ context(), SpvOp::SpvOpConstant, type, id,
std::initializer_list<ir::Operand>{ir::Operand(
spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
fc->words())});
} else if (analysis::CompositeConstant* cc = c->AsCompositeConstant()) {
- return CreateCompositeInstruction(id, cc);
+ return CreateCompositeInstruction(id, cc, type_id);
} else {
return nullptr;
}
}
std::unique_ptr<ir::Instruction> ConstantManager::CreateCompositeInstruction(
- uint32_t result_id, analysis::CompositeConstant* cc) const {
+ uint32_t result_id, analysis::CompositeConstant* cc,
+ uint32_t type_id) const {
std::vector<ir::Operand> operands;
for (const analysis::Constant* component_const : cc->GetComponents()) {
uint32_t id = FindRecordedConstant(component_const);
operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
std::initializer_list<uint32_t>{id});
}
- return MakeUnique<ir::Instruction>(
- context(), SpvOp::SpvOpConstantComposite,
- context()->get_type_mgr()->GetId(cc->type()), result_id,
- std::move(operands));
+ uint32_t type =
+ (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;
+ return MakeUnique<ir::Instruction>(context(), SpvOp::SpvOpConstantComposite,
+ type, result_id, std::move(operands));
}
} // namespace analysis
// points to the same instruction before and after the insertion. This is the
// only method that actually manages id creation/assignment and instruction
// creation/insertion for a new Constant instance.
+ //
+ // |type_id| is an optional argument for disambiguating equivalent types. If
+ // |type_id| is specified, it is used as the type of the constant. Otherwise
+ // the type of the constant is derived by getting an id from the type manager
+ // for |c|.
ir::Instruction* BuildInstructionAndAddToModule(
- std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos);
+ std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos,
+ uint32_t type_id = 0);
// Creates an instruction with the given result id to declare a constant
// represented by the given Constant instance. Returns an unique pointer to
// the created instruction if the instruction can be created successfully.
// Otherwise, returns a null pointer.
+ //
+ // |type_id| is an optional argument for disambiguating equivalent types. If
+ // |type_id| is specified, it is used as the type of the constant. Otherwise
+ // the type of the constant is derived by getting an id from the type manager
+ // for |c|.
std::unique_ptr<ir::Instruction> CreateInstruction(
- uint32_t result_id, analysis::Constant* c) const;
+ uint32_t result_id, analysis::Constant* c, uint32_t type_id = 0) const;
// Creates an OpConstantComposite instruction with the given result id and
// the CompositeConst instance which represents a composite constant. Returns
// an unique pointer to the created instruction if succeeded. Otherwise
// returns a null pointer.
+ //
+ // |type_id| is an optional argument for disambiguating equivalent types. If
+ // |type_id| is specified, it is used as the type of the constant. Otherwise
+ // the type of the constant is derived by getting an id from the type manager
+ // for |c|.
std::unique_ptr<ir::Instruction> CreateCompositeInstruction(
- uint32_t result_id, analysis::CompositeConstant* cc) const;
+ uint32_t result_id, analysis::CompositeConstant* cc,
+ uint32_t type_id = 0) const;
// A helper function to get the result type of the given instruction. Returns
// nullptr if the instruction does not have a type id (type id is 0).
return true;
}
+uint32_t FoldSpecConstantOpAndCompositePass::GetTypeComponent(
+ uint32_t typeId, uint32_t element) const {
+ ir::Instruction* type = context()->get_def_use_mgr()->GetDef(typeId);
+ uint32_t subtype = type->GetTypeComponent(element);
+ assert(subtype != 0);
+
+ return subtype;
+}
+
ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
ir::Module::inst_iterator* pos) {
ir::Instruction* inst = &**pos;
"OpSpecConstantOp CompositeExtract requires at least two non-type "
"non-opcode operands.");
assert(inst->GetInOperand(1).type == SPV_OPERAND_TYPE_ID &&
- "The vector operand must have a SPV_OPERAND_TYPE_ID type");
+ "The composite operand must have a SPV_OPERAND_TYPE_ID type");
assert(
inst->GetInOperand(2).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
"The literal operand must have a SPV_OPERAND_TYPE_LITERAL_INTEGER type");
// Note that for OpSpecConstantOp, the second in-operand is the first id
// operand. The first in-operand is the spec opcode.
+ uint32_t source = inst->GetSingleWordInOperand(1);
+ uint32_t type = context()->get_def_use_mgr()->GetDef(source)->type_id();
analysis::Constant* first_operand_const =
- context()->get_constant_mgr()->FindRecordedConstant(
- inst->GetSingleWordInOperand(1));
+ context()->get_constant_mgr()->FindRecordedConstant(source);
if (!first_operand_const) return nullptr;
const analysis::Constant* current_const = first_operand_const;
for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
uint32_t literal = inst->GetSingleWordInOperand(i);
+ type = GetTypeComponent(type, literal);
+ }
+ for (uint32_t i = 2; i < inst->NumInOperands(); i++) {
+ uint32_t literal = inst->GetSingleWordInOperand(i);
if (const analysis::CompositeConstant* composite_const =
current_const->AsCompositeConstant()) {
// Case 1: current constant is a non-null composite type constant.
return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
context()->get_constant_mgr()->CreateConstant(
context()->get_constant_mgr()->GetType(inst), {}),
- pos);
+ pos, type);
} else {
// Dereferencing a non-composite constant. Invalid case.
return nullptr;
}
}
return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
- current_const->Copy(), pos);
+ current_const->Copy(), pos, type);
}
ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
// if succeeded, otherwise return nullptr.
ir::Instruction* DoComponentWiseOperation(
ir::Module::inst_iterator* inst_iter_ptr);
+
+ // Returns the |element|'th subtype of |type|.
+ //
+ // |type| must be a composite type.
+ uint32_t GetTypeComponent(uint32_t type, uint32_t element) const;
};
} // namespace opt
uint32_t InlinePass::FindPointerToType(uint32_t type_id,
SpvStorageClass storage_class) {
+ analysis::Type* pointeeTy;
+ std::unique_ptr<analysis::Pointer> pointerTy;
+ std::tie(pointeeTy, pointerTy) =
+ context()->get_type_mgr()->GetTypeAndPointerType(type_id,
+ SpvStorageClassFunction);
+ if (type_id == context()->get_type_mgr()->GetId(pointeeTy)) {
+ // Non-ambiguous type. Get the pointer type through the type manager.
+ return context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
+ }
+
+ // Ambiguous type, do a linear search.
ir::Module::inst_iterator type_itr = get_module()->types_values_begin();
for (; type_itr != get_module()->types_values_end(); ++type_itr) {
const ir::Instruction* type_inst = &*type_itr;
{uint32_t(storage_class)}},
{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}}));
context()->AddType(std::move(type_inst));
+ analysis::Type* pointeeTy;
+ std::unique_ptr<analysis::Pointer> pointerTy;
+ std::tie(pointeeTy, pointerTy) =
+ context()->get_type_mgr()->GetTypeAndPointerType(type_id,
+ SpvStorageClassFunction);
+ context()->get_type_mgr()->RegisterType(resultId, *pointerTy);
return resultId;
}
return storage_class == SpvStorageClassUniformConstant;
}
+uint32_t Instruction::GetTypeComponent(uint32_t element) const {
+ uint32_t subtype = 0;
+ switch (opcode()) {
+ case SpvOpTypeStruct:
+ subtype = GetSingleWordInOperand(element);
+ break;
+ case SpvOpTypeArray:
+ case SpvOpTypeRuntimeArray:
+ case SpvOpTypeVector:
+ case SpvOpTypeMatrix:
+ // These types all have uniform subtypes.
+ subtype = GetSingleWordInOperand(0u);
+ break;
+ default:
+ break;
+ }
+
+ return subtype;
+}
+
Instruction* Instruction::InsertBefore(
std::vector<std::unique_ptr<Instruction>>&& list) {
Instruction* first_node = list.front().get();
// and return to its caller
bool IsReturn() const { return spvOpcodeIsReturn(opcode()); }
+ // Returns the id for the |element|'th subtype. If the |this| is not a
+ // composite type, this function returns 0.
+ uint32_t GetTypeComponent(uint32_t element) const;
+
// Returns true if this instruction is a basic block terminator.
bool IsBlockTerminator() const {
return spvOpcodeIsBlockTerminator(opcode());
#include "latest_version_glsl_std_450_header.h"
#include "log.h"
#include "mem_pass.h"
+#include "reflect.h"
+#include "spirv/1.0/GLSL.std.450.h"
#include <cstring>
}
}
+ if (type_mgr_ && ir::IsTypeInst(inst->opcode())) {
+ type_mgr_->RemoveId(inst->result_id());
+ }
+
Instruction* next_instruction = nullptr;
if (inst->IsInAList()) {
next_instruction = inst->NextNode();
// is never re-built.
opt::analysis::TypeManager* get_type_mgr() {
if (!type_mgr_)
- type_mgr_.reset(new opt::analysis::TypeManager(consumer(), *module()));
+ type_mgr_.reset(new opt::analysis::TypeManager(consumer(), this));
return type_mgr_.get();
}
}
Optimizer& Optimizer::RegisterPerformancePasses() {
- return RegisterPass(CreateMergeReturnPass())
+ return RegisterPass(CreateRemoveDuplicatesPass())
+ .RegisterPass(CreateMergeReturnPass())
.RegisterPass(CreateInlineExhaustivePass())
.RegisterPass(CreateEliminateDeadFunctionsPass())
.RegisterPass(CreateScalarReplacementPass())
}
Optimizer& Optimizer::RegisterSizePasses() {
- return RegisterPass(CreateMergeReturnPass())
+ return RegisterPass(CreateRemoveDuplicatesPass())
+ .RegisterPass(CreateMergeReturnPass())
.RegisterPass(CreateInlineExhaustivePass())
.RegisterPass(CreateEliminateDeadFunctionsPass())
.RegisterPass(CreateLocalAccessChainConvertPass())
MakeUnique<opt::RedundancyEliminationPass>());
}
+Optimizer::PassToken CreateRemoveDuplicatesPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::RemoveDuplicatesPass>());
+}
+
Optimizer::PassToken CreateScalarReplacementPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::ScalarReplacementPass>());
#include "merge_return_pass.h"
#include "null_pass.h"
#include "redundancy_elimination.h"
+#include "remove_duplicates_pass.h"
#include "scalar_replacement_pass.h"
#include "set_spec_constant_default_value_pass.h"
#include "strength_reduction_pass.h"
#include "types.h"
#include <queue>
+#include <tuple>
namespace spvtools {
namespace opt {
auto iter = pointee_to_pointer_.find(id);
if (iter != pointee_to_pointer_.end()) return iter->second;
- // TODO(alanbaker): Make the type manager useful and then replace this code.
+ analysis::Type* pointeeTy;
+ std::unique_ptr<analysis::Pointer> pointerTy;
+ std::tie(pointeeTy, pointerTy) =
+ context()->get_type_mgr()->GetTypeAndPointerType(id,
+ SpvStorageClassFunction);
uint32_t ptrId = 0;
+ if (id == context()->get_type_mgr()->GetId(pointeeTy)) {
+ // Non-ambiguous type, just ask the type manager for an id.
+ ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
+ pointee_to_pointer_[id] = ptrId;
+ return ptrId;
+ }
+
+ // Ambiguous type. We must perform a linear search to try and find the right
+ // type.
for (auto global : context()->types_values()) {
if (global.opcode() == SpvOpTypePointer &&
global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
libspirv::Extension::kSPV_KHR_variable_pointers) ||
get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
// If variable pointers is enabled, only reuse a decoration-less
- // pointer of the correct type
+ // pointer of the correct type.
ptrId = global.result_id();
break;
}
ir::Instruction* ptr = &*--context()->types_values_end();
get_def_use_mgr()->AnalyzeInstDefUse(ptr);
pointee_to_pointer_[id] = ptrId;
+ // Register with the type manager if necessary.
+ context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
return ptrId;
}
Pass::Status SetSpecConstantDefaultValuePass::Process(
ir::IRContext* irContext) {
+ InitializeProcessing(irContext);
+
// The operand index of decoration target in an OpDecorate instruction.
const uint32_t kTargetIdOperandIndex = 0;
// The operand index of the decoration literal in an OpDecorate instruction.
const uint32_t kOpSpecConstantLiteralInOperandIndex = 0;
bool modified = false;
- analysis::DefUseManager def_use_mgr(irContext->module());
- analysis::TypeManager type_mgr(consumer(), *irContext->module());
// Scan through all the annotation instructions to find 'OpDecorate SpecId'
// instructions. Then extract the decoration target of those instructions.
// The decoration targets should be spec constant defining instructions with
// Find the spec constant defining instruction. Note that the
// target_id might be a decoration group id.
ir::Instruction* spec_inst = nullptr;
- if (ir::Instruction* target_inst = def_use_mgr.GetDef(target_id)) {
+ if (ir::Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) {
if (target_inst->opcode() == SpvOp::SpvOpDecorationGroup) {
spec_inst =
- GetSpecIdTargetFromDecorationGroup(*target_inst, &def_use_mgr);
+ GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr());
} else {
spec_inst = target_inst;
}
// with the type of the spec constant.
const std::string& default_value_str = iter->second;
bit_pattern = ParseDefaultValueStr(
- default_value_str.c_str(), type_mgr.GetType(spec_inst->type_id()));
+ default_value_str.c_str(),
+ context()->get_type_mgr()->GetType(spec_inst->type_id()));
} else {
// Search for the new bit-pattern-form default value for this spec id.
// Gets the bit-pattern of the default value from the map directly.
bit_pattern = ParseDefaultValueBitPattern(
- iter->second, type_mgr.GetType(spec_inst->type_id()));
+ iter->second,
+ context()->get_type_mgr()->GetType(spec_inst->type_id()));
}
if (bit_pattern.empty()) continue;
}
void StrengthReductionPass::FindIntTypesAndConstants() {
+ analysis::Integer int32(32, true);
+ int32_type_id_ = context()->get_type_mgr()->GetId(&int32);
+ analysis::Integer uint32(32, false);
+ uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32);
for (auto iter = get_module()->types_values_begin();
iter != get_module()->types_values_end(); ++iter) {
switch (iter->opcode()) {
- case SpvOp::SpvOpTypeInt:
- if (iter->GetSingleWordOperand(1) == 32) {
- if (iter->GetSingleWordOperand(2) == 1) {
- int32_type_id_ = iter->result_id();
- } else {
- uint32_type_id_ = iter->result_id();
- }
- }
- break;
case SpvOp::SpvOpConstant:
if (iter->type_id() == uint32_type_id_) {
uint32_t value = iter->GetSingleWordOperand(2);
if (constant_ids_[val] == 0) {
if (uint32_type_id_ == 0) {
- uint32_type_id_ = CreateUint32Type();
+ analysis::Integer uint(32, false);
+ uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint);
}
// Construct the constant.
return modified;
}
-uint32_t StrengthReductionPass::CreateUint32Type() {
- uint32_t type_id = TakeNextId();
- ir::Operand widthOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
- {32});
- ir::Operand signOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
- {0});
- std::unique_ptr<ir::Instruction> newType(new ir::Instruction(
- context(), SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand}));
- context()->AddType(std::move(newType));
- return type_id;
-}
-
} // namespace opt
} // namespace spvtools
// ones. Returns true if something changed.
bool ScanFunctions();
- // Will create the type for an unsigned 32-bit integer and return the id.
- // This functions assumes one does not already exist.
- uint32_t CreateUint32Type();
-
// Type ids for the types of interest, or 0 if they do not exist.
uint32_t int32_type_id_;
uint32_t uint32_type_id_;
#include "type_manager.h"
+#include <cassert>
+#include <cstring>
#include <utility>
+#include "ir_context.h"
#include "log.h"
+#include "make_unique.h"
#include "reflect.h"
namespace spvtools {
namespace opt {
namespace analysis {
+TypeManager::TypeManager(const MessageConsumer& consumer,
+ spvtools::ir::IRContext* c)
+ : consumer_(consumer), context_(c) {
+ AnalyzeTypes(*c->module());
+}
+
Type* TypeManager::GetType(uint32_t id) const {
auto iter = id_to_type_.find(id);
if (iter != id_to_type_.end()) return (*iter).second.get();
return nullptr;
}
+std::pair<Type*, std::unique_ptr<Pointer>> TypeManager::GetTypeAndPointerType(
+ uint32_t id, SpvStorageClass sc) const {
+ Type* type = GetType(id);
+ if (type) {
+ return std::make_pair(type, MakeUnique<analysis::Pointer>(type, sc));
+ } else {
+ return std::make_pair(type, std::unique_ptr<analysis::Pointer>());
+ }
+}
+
uint32_t TypeManager::GetId(const Type* type) const {
auto iter = type_to_id_.find(type);
if (iter != type_to_id_.end()) return (*iter).second;
for (const auto& inst : module.annotations()) AttachIfTypeDecoration(inst);
}
+void TypeManager::RemoveId(uint32_t id) {
+ auto iter = id_to_type_.find(id);
+ if (iter == id_to_type_.end()) return;
+
+ auto& type = iter->second;
+ if (!type->IsUniqueType(true)) {
+ // Search for an equivalent type to re-map.
+ bool found = false;
+ for (auto& pair : id_to_type_) {
+ if (pair.first != id && *pair.second == *type) {
+ // Equivalent ambiguous type, re-map type.
+ type_to_id_.erase(type.get());
+ type_to_id_[pair.second.get()] = pair.first;
+ found = true;
+ break;
+ }
+ // No equivalent ambiguous type, remove mapping.
+ if (!found) type_to_id_.erase(type.get());
+ }
+ } else {
+ // Unique type, so just erase the entry.
+ type_to_id_.erase(type.get());
+ }
+
+ // Erase the entry for |id|.
+ id_to_type_.erase(iter);
+}
+
+uint32_t TypeManager::GetTypeInstruction(const Type* type) {
+ uint32_t id = GetId(type);
+ if (id != 0) return id;
+
+ std::unique_ptr<ir::Instruction> typeInst;
+ id = context()->TakeNextId();
+ RegisterType(id, *type);
+ switch (type->kind()) {
+#define DefineParameterlessCase(kind) \
+ case Type::k##kind: \
+ typeInst.reset(new ir::Instruction(context(), SpvOpType##kind, 0, id, \
+ std::initializer_list<ir::Operand>{})); \
+ break;
+ DefineParameterlessCase(Void);
+ DefineParameterlessCase(Bool);
+ DefineParameterlessCase(Sampler);
+ DefineParameterlessCase(Event);
+ DefineParameterlessCase(DeviceEvent);
+ DefineParameterlessCase(ReserveId);
+ DefineParameterlessCase(Queue);
+ DefineParameterlessCase(PipeStorage);
+ DefineParameterlessCase(NamedBarrier);
+#undef DefineParameterlessCase
+ case Type::kInteger:
+ typeInst.reset(new ir::Instruction(
+ context(), SpvOpTypeInt, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsInteger()->width()}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER,
+ {(type->AsInteger()->IsSigned() ? 1u : 0u)}}}));
+ break;
+ case Type::kFloat:
+ typeInst.reset(new ir::Instruction(
+ context(), SpvOpTypeFloat, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsFloat()->width()}}}));
+ break;
+ case Type::kVector: {
+ uint32_t subtype = GetTypeInstruction(type->AsVector()->element_type());
+ typeInst.reset(
+ new ir::Instruction(context(), SpvOpTypeVector, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_ID, {subtype}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER,
+ {type->AsVector()->element_count()}}}));
+ break;
+ }
+ case Type::kMatrix: {
+ uint32_t subtype = GetTypeInstruction(type->AsMatrix()->element_type());
+ typeInst.reset(
+ new ir::Instruction(context(), SpvOpTypeMatrix, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_ID, {subtype}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER,
+ {type->AsMatrix()->element_count()}}}));
+ break;
+ }
+ case Type::kImage: {
+ const Image* image = type->AsImage();
+ uint32_t subtype = GetTypeInstruction(image->sampled_type());
+ typeInst.reset(new ir::Instruction(
+ context(), SpvOpTypeImage, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_ID, {subtype}},
+ {SPV_OPERAND_TYPE_DIMENSIONALITY,
+ {static_cast<uint32_t>(image->dim())}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {image->depth()}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER,
+ {(image->is_arrayed() ? 1u : 0u)}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER,
+ {(image->is_multisampled() ? 1u : 0u)}},
+ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {image->sampled()}},
+ {SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT,
+ {static_cast<uint32_t>(image->format())}},
+ {SPV_OPERAND_TYPE_ACCESS_QUALIFIER,
+ {static_cast<uint32_t>(image->access_qualifier())}}}));
+ break;
+ }
+ case Type::kSampledImage: {
+ uint32_t subtype =
+ GetTypeInstruction(type->AsSampledImage()->image_type());
+ typeInst.reset(
+ new ir::Instruction(context(), SpvOpTypeSampledImage, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_ID, {subtype}}}));
+ break;
+ }
+ case Type::kArray: {
+ uint32_t subtype = GetTypeInstruction(type->AsArray()->element_type());
+ typeInst.reset(new ir::Instruction(
+ context(), SpvOpTypeArray, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_ID, {subtype}},
+ {SPV_OPERAND_TYPE_ID, {type->AsArray()->LengthId()}}}));
+ break;
+ }
+ case Type::kRuntimeArray: {
+ uint32_t subtype =
+ GetTypeInstruction(type->AsRuntimeArray()->element_type());
+ typeInst.reset(
+ new ir::Instruction(context(), SpvOpTypeRuntimeArray, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_ID, {subtype}}}));
+ break;
+ }
+ case Type::kStruct: {
+ std::vector<ir::Operand> ops;
+ const Struct* structTy = type->AsStruct();
+ for (auto ty : structTy->element_types()) {
+ ops.push_back(
+ ir::Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)}));
+ }
+ typeInst.reset(
+ new ir::Instruction(context(), SpvOpTypeStruct, 0, id, ops));
+ break;
+ }
+ case Type::kOpaque: {
+ const Opaque* opaque = type->AsOpaque();
+ size_t size = opaque->name().size();
+ // Convert to null-terminated packed UTF-8 string.
+ std::vector<uint32_t> words(size / 4 + 1, 0);
+ char* dst = reinterpret_cast<char*>(words.data());
+ strncpy(dst, opaque->name().c_str(), size);
+ typeInst.reset(
+ new ir::Instruction(context(), SpvOpTypeOpaque, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_LITERAL_STRING, words}}));
+ break;
+ }
+ case Type::kPointer: {
+ const Pointer* pointer = type->AsPointer();
+ uint32_t subtype = GetTypeInstruction(pointer->pointee_type());
+ typeInst.reset(new ir::Instruction(
+ context(), SpvOpTypePointer, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_STORAGE_CLASS,
+ {static_cast<uint32_t>(pointer->storage_class())}},
+ {SPV_OPERAND_TYPE_ID, {subtype}}}));
+ break;
+ }
+ case Type::kFunction: {
+ std::vector<ir::Operand> ops;
+ const Function* function = type->AsFunction();
+ ops.push_back(ir::Operand(SPV_OPERAND_TYPE_ID,
+ {GetTypeInstruction(function->return_type())}));
+ for (auto ty : function->param_types()) {
+ ops.push_back(
+ ir::Operand(SPV_OPERAND_TYPE_ID, {GetTypeInstruction(ty)}));
+ }
+ typeInst.reset(
+ new ir::Instruction(context(), SpvOpTypeFunction, 0, id, ops));
+ break;
+ }
+ case Type::kPipe:
+ typeInst.reset(new ir::Instruction(
+ context(), SpvOpTypePipe, 0, id,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_ACCESS_QUALIFIER,
+ {static_cast<uint32_t>(type->AsPipe()->access_qualifier())}}}));
+ break;
+ case Type::kForwardPointer:
+ typeInst.reset(new ir::Instruction(
+ context(), SpvOpTypeForwardPointer, 0, 0,
+ std::initializer_list<ir::Operand>{
+ {SPV_OPERAND_TYPE_ID, {type->AsForwardPointer()->target_id()}},
+ {SPV_OPERAND_TYPE_STORAGE_CLASS,
+ {static_cast<uint32_t>(
+ type->AsForwardPointer()->storage_class())}}}));
+ break;
+ default:
+ assert(false && "Unexpected type");
+ break;
+ }
+ context()->AddType(std::move(typeInst));
+ context()->get_def_use_mgr()->AnalyzeInstDefUse(
+ &*--context()->types_values_end());
+ AttachDecorations(id, type);
+
+ return id;
+}
+
+void TypeManager::AttachDecorations(uint32_t id, const Type* type) {
+ for (auto vec : type->decorations()) {
+ CreateDecoration(id, vec);
+ }
+ if (const Struct* structTy = type->AsStruct()) {
+ for (auto pair : structTy->element_decorations()) {
+ uint32_t element = pair.first;
+ for (auto vec : pair.second) {
+ CreateDecoration(id, vec, element);
+ }
+ }
+ }
+}
+
+void TypeManager::CreateDecoration(uint32_t target,
+ const std::vector<uint32_t>& decoration,
+ uint32_t element) {
+ std::vector<ir::Operand> ops;
+ ops.push_back(ir::Operand(SPV_OPERAND_TYPE_ID, {target}));
+ if (element != 0) {
+ ops.push_back(ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {element}));
+ }
+ ops.push_back(ir::Operand(SPV_OPERAND_TYPE_DECORATION, {decoration[0]}));
+ for (size_t i = 1; i < decoration.size(); ++i) {
+ ops.push_back(
+ ir::Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {decoration[i]}));
+ }
+ context()->AddAnnotationInst(MakeUnique<ir::Instruction>(
+ context(), (element == 0 ? SpvOpDecorate : SpvOpMemberDecorate), 0, 0,
+ ops));
+ ir::Instruction* inst = &*--context()->annotation_end();
+ context()->get_def_use_mgr()->AnalyzeInstUse(inst);
+}
+
+void TypeManager::RegisterType(uint32_t id, const Type& type) {
+ auto& t = id_to_type_[id];
+ t.reset(type.Clone().release());
+ if (GetId(t.get()) == 0) {
+ type_to_id_[t.get()] = id;
+ }
+}
+
Type* TypeManager::RecordIfTypeDefinition(
const spvtools::ir::Instruction& inst) {
if (!spvtools::ir::IsTypeInst(inst.opcode())) return nullptr;
type = new Image(
GetType(inst.GetSingleWordInOperand(0)),
static_cast<SpvDim>(inst.GetSingleWordInOperand(1)),
- inst.GetSingleWordInOperand(2), inst.GetSingleWordInOperand(3),
- inst.GetSingleWordInOperand(4), inst.GetSingleWordInOperand(5),
+ inst.GetSingleWordInOperand(2), inst.GetSingleWordInOperand(3) == 1,
+ inst.GetSingleWordInOperand(4) == 1, inst.GetSingleWordInOperand(5),
static_cast<SpvImageFormat>(inst.GetSingleWordInOperand(6)), access);
} break;
case SpvOpTypeSampler:
#include "types.h"
namespace spvtools {
+namespace ir {
+class IRContext;
+} // namespace ir
namespace opt {
namespace analysis {
+// Hashing functor.
+//
+// All type pointers must be non-null.
+struct HashTypePointer {
+ size_t operator()(const Type* type) const {
+ assert(type);
+ return type->HashValue();
+ }
+};
+
+// Equality functor.
+//
+// Checks if two types pointers are the same type.
+//
+// All type pointers must be non-null.
+struct CompareTypePointers {
+ bool operator()(const Type* lhs, const Type* rhs) const {
+ assert(lhs && rhs);
+ return lhs->IsSame(rhs);
+ }
+};
+
// A class for managing the SPIR-V type hierarchy.
class TypeManager {
public:
// will be communicated to the outside via the given message |consumer|.
// This instance only keeps a reference to the |consumer|, so the |consumer|
// should outlive this instance.
- TypeManager(const MessageConsumer& consumer,
- const spvtools::ir::Module& module);
+ TypeManager(const MessageConsumer& consumer, spvtools::ir::IRContext* c);
TypeManager(const TypeManager&) = delete;
TypeManager(TypeManager&&) = delete;
// Returns the number of forward pointer types hold in this manager.
size_t NumForwardPointers() const { return forward_pointers_.size(); }
- // Analyzes the types and decorations on types in the given |module|.
- // TODO(dnovillo): This should be private and the type manager should know how
- // to update itself when new types are added
- // (https://github.com/KhronosGroup/SPIRV-Tools/issues/1071).
- void AnalyzeTypes(const spvtools::ir::Module& module);
+ // Returns a pair of the type and pointer to the type in |sc|.
+ //
+ // |id| must be a registered type.
+ std::pair<Type*, std::unique_ptr<Pointer>> GetTypeAndPointerType(
+ uint32_t id, SpvStorageClass sc) const;
+
+ // Returns an id for a declaration representing |type|.
+ //
+ // If |type| is registered, then the registered id is returned. Otherwise,
+ // this function recursively adds type and annotation instructions as
+ // necessary to fully define |type|.
+ uint32_t GetTypeInstruction(const Type* type);
+
+ // Registers |id| to |type|.
+ //
+ // If GetId(|type|) already returns a non-zero id, the return value will be
+ // unchanged.
+ void RegisterType(uint32_t id, const Type& type);
+
+ // Removes knowledge of |id| from the manager.
+ //
+ // If |id| is an ambiguous type the multiple ids may be registered to |id|'s
+ // type (e.g. %struct1 and %struct1 might hash to the same type). In that
+ // case, calling GetId() with |id|'s type will return another suitable id
+ // defining that type.
+ void RemoveId(uint32_t id);
private:
- using TypeToIdMap = std::unordered_map<const Type*, uint32_t>;
+ using TypeToIdMap = std::unordered_map<const Type*, uint32_t, HashTypePointer,
+ CompareTypePointers>;
using ForwardPointerVector = std::vector<std::unique_ptr<ForwardPointer>>;
+ // Analyzes the types and decorations on types in the given |module|.
+ void AnalyzeTypes(const spvtools::ir::Module& module);
+
+ spvtools::ir::IRContext* context() { return context_; }
+
+ // Attachs the decorations on |type| to |id|.
+ void AttachDecorations(uint32_t id, const Type* type);
+
+ // Create the annotation instruction.
+ //
+ // If |element| is zero, an OpDecorate is created, other an OpMemberDecorate
+ // is created. The annotation is registered with the DefUseManager and the
+ // DecorationManager.
+ void CreateDecoration(uint32_t id, const std::vector<uint32_t>& decoration,
+ uint32_t element = 0);
+
// Creates and returns a type from the given SPIR-V |inst|. Returns nullptr if
// the given instruction is not for defining a type.
Type* RecordIfTypeDefinition(const spvtools::ir::Instruction& inst);
void AttachIfTypeDecoration(const spvtools::ir::Instruction& inst);
const MessageConsumer& consumer_; // Message consumer.
+ spvtools::ir::IRContext* context_;
IdToTypeMap id_to_type_; // Mapping from ids to their type representations.
TypeToIdMap type_to_id_; // Mapping from types to their defining ids.
ForwardPointerVector forward_pointers_; // All forward pointer declarations.
std::unordered_set<ForwardPointer*> unresolved_forward_pointers_;
};
-inline TypeManager::TypeManager(const spvtools::MessageConsumer& consumer,
- const spvtools::ir::Module& module)
- : consumer_(consumer) {
- AnalyzeTypes(module);
-}
-
} // namespace analysis
} // namespace opt
} // namespace spvtools
#include <algorithm>
#include <cassert>
+#include <cstdint>
#include <sstream>
#include "types.h"
return CompareTwoVectors(decorations_, that->decorations_);
}
-bool Integer::IsSame(Type* that) const {
+bool Type::IsUniqueType(bool allowVariablePointers) const {
+ switch (kind_) {
+ case kPointer:
+ return !allowVariablePointers;
+ case kStruct:
+ case kArray:
+ case kRuntimeArray:
+ return false;
+ default:
+ return true;
+ }
+}
+
+std::unique_ptr<Type> Type::Clone() const {
+ std::unique_ptr<Type> type;
+ switch (kind_) {
+#define DeclareKindCase(kind) \
+ case k##kind: \
+ type.reset(new kind(*this->As##kind())); \
+ break;
+ DeclareKindCase(Void);
+ DeclareKindCase(Bool);
+ DeclareKindCase(Integer);
+ DeclareKindCase(Float);
+ DeclareKindCase(Vector);
+ DeclareKindCase(Matrix);
+ DeclareKindCase(Image);
+ DeclareKindCase(Sampler);
+ DeclareKindCase(SampledImage);
+ DeclareKindCase(Array);
+ DeclareKindCase(RuntimeArray);
+ DeclareKindCase(Struct);
+ DeclareKindCase(Opaque);
+ DeclareKindCase(Pointer);
+ DeclareKindCase(Function);
+ DeclareKindCase(Event);
+ DeclareKindCase(DeviceEvent);
+ DeclareKindCase(ReserveId);
+ DeclareKindCase(Queue);
+ DeclareKindCase(Pipe);
+ DeclareKindCase(ForwardPointer);
+ DeclareKindCase(PipeStorage);
+ DeclareKindCase(NamedBarrier);
+#undef DeclareKindCase
+ default:
+ assert(false && "Unhandled type");
+ }
+ return type;
+}
+
+std::unique_ptr<Type> Type::RemoveDecorations() const {
+ std::unique_ptr<Type> type(Clone());
+ type->ClearDecorations();
+ return type;
+}
+
+bool Type::operator==(const Type& other) const {
+ if (kind_ != other.kind_) return false;
+
+ switch (kind_) {
+#define DeclareKindCase(kind) \
+ case k##kind: \
+ return As##kind()->IsSame(&other);
+ DeclareKindCase(Void);
+ DeclareKindCase(Bool);
+ DeclareKindCase(Integer);
+ DeclareKindCase(Float);
+ DeclareKindCase(Vector);
+ DeclareKindCase(Matrix);
+ DeclareKindCase(Image);
+ DeclareKindCase(Sampler);
+ DeclareKindCase(SampledImage);
+ DeclareKindCase(Array);
+ DeclareKindCase(RuntimeArray);
+ DeclareKindCase(Struct);
+ DeclareKindCase(Opaque);
+ DeclareKindCase(Pointer);
+ DeclareKindCase(Function);
+ DeclareKindCase(Event);
+ DeclareKindCase(DeviceEvent);
+ DeclareKindCase(ReserveId);
+ DeclareKindCase(Queue);
+ DeclareKindCase(Pipe);
+ DeclareKindCase(ForwardPointer);
+ DeclareKindCase(PipeStorage);
+ DeclareKindCase(NamedBarrier);
+#undef DeclareKindCase
+ default:
+ assert(false && "Unhandled type");
+ return false;
+ }
+}
+
+void Type::GetHashWords(std::vector<uint32_t>* words) const {
+ words->push_back(kind_);
+ for (auto d : decorations_) {
+ for (auto w : d) {
+ words->push_back(w);
+ }
+ }
+
+ switch (kind_) {
+#define DeclareKindCase(type) \
+ case k##type: \
+ As##type()->GetExtraHashWords(words); \
+ break;
+ DeclareKindCase(Void);
+ DeclareKindCase(Bool);
+ DeclareKindCase(Integer);
+ DeclareKindCase(Float);
+ DeclareKindCase(Vector);
+ DeclareKindCase(Matrix);
+ DeclareKindCase(Image);
+ DeclareKindCase(Sampler);
+ DeclareKindCase(SampledImage);
+ DeclareKindCase(Array);
+ DeclareKindCase(RuntimeArray);
+ DeclareKindCase(Struct);
+ DeclareKindCase(Opaque);
+ DeclareKindCase(Pointer);
+ DeclareKindCase(Function);
+ DeclareKindCase(Event);
+ DeclareKindCase(DeviceEvent);
+ DeclareKindCase(ReserveId);
+ DeclareKindCase(Queue);
+ DeclareKindCase(Pipe);
+ DeclareKindCase(ForwardPointer);
+ DeclareKindCase(PipeStorage);
+ DeclareKindCase(NamedBarrier);
+#undef DeclareKindCase
+ default:
+ assert(false && "Unhandled type");
+ break;
+ }
+}
+
+size_t Type::HashValue() const {
+ std::u32string h;
+ std::vector<uint32_t> words;
+ GetHashWords(&words);
+ for (auto w : words) {
+ h.push_back(w);
+ }
+
+ return std::hash<std::u32string>()(h);
+}
+
+bool Integer::IsSame(const Type* that) const {
const Integer* it = that->AsInteger();
return it && width_ == it->width_ && signed_ == it->signed_ &&
HasSameDecorations(that);
return oss.str();
}
-bool Float::IsSame(Type* that) const {
+void Integer::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ words->push_back(width_);
+ words->push_back(signed_);
+}
+
+bool Float::IsSame(const Type* that) const {
const Float* ft = that->AsFloat();
return ft && width_ == ft->width_ && HasSameDecorations(that);
}
return oss.str();
}
+void Float::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ words->push_back(width_);
+}
+
Vector::Vector(Type* type, uint32_t count)
- : element_type_(type), count_(count) {
+ : Type(kVector), element_type_(type), count_(count) {
assert(type->AsBool() || type->AsInteger() || type->AsFloat());
}
-bool Vector::IsSame(Type* that) const {
+bool Vector::IsSame(const Type* that) const {
const Vector* vt = that->AsVector();
if (!vt) return false;
return count_ == vt->count_ && element_type_->IsSame(vt->element_type_) &&
return oss.str();
}
+void Vector::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ element_type_->GetHashWords(words);
+ words->push_back(count_);
+}
+
Matrix::Matrix(Type* type, uint32_t count)
- : element_type_(type), count_(count) {
+ : Type(kMatrix), element_type_(type), count_(count) {
assert(type->AsVector());
}
-bool Matrix::IsSame(Type* that) const {
+bool Matrix::IsSame(const Type* that) const {
const Matrix* mt = that->AsMatrix();
if (!mt) return false;
return count_ == mt->count_ && element_type_->IsSame(mt->element_type_) &&
return oss.str();
}
-Image::Image(Type* sampled_type, SpvDim dim, uint32_t depth, uint32_t arrayed,
- uint32_t ms, uint32_t sampled, SpvImageFormat format,
- SpvAccessQualifier access_qualifier)
- : sampled_type_(sampled_type),
- dim_(dim),
- depth_(depth),
- arrayed_(arrayed),
- ms_(ms),
- sampled_(sampled),
- format_(format),
- access_qualifier_(access_qualifier) {
+void Matrix::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ element_type_->GetHashWords(words);
+ words->push_back(count_);
+}
+
+Image::Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample,
+ uint32_t sampling, SpvImageFormat f, SpvAccessQualifier qualifier)
+ : Type(kImage),
+ sampled_type_(type),
+ dim_(dimen),
+ depth_(d),
+ arrayed_(array),
+ ms_(multisample),
+ sampled_(sampling),
+ format_(f),
+ access_qualifier_(qualifier) {
// TODO(antiagainst): check sampled_type
}
-bool Image::IsSame(Type* that) const {
+bool Image::IsSame(const Type* that) const {
const Image* it = that->AsImage();
if (!it) return false;
return dim_ == it->dim_ && depth_ == it->depth_ && arrayed_ == it->arrayed_ &&
return oss.str();
}
-bool SampledImage::IsSame(Type* that) const {
+void Image::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ sampled_type_->GetHashWords(words);
+ words->push_back(dim_);
+ words->push_back(depth_);
+ words->push_back(arrayed_);
+ words->push_back(ms_);
+ words->push_back(sampled_);
+ words->push_back(format_);
+ words->push_back(access_qualifier_);
+}
+
+bool SampledImage::IsSame(const Type* that) const {
const SampledImage* sit = that->AsSampledImage();
if (!sit) return false;
return image_type_->IsSame(sit->image_type_) && HasSameDecorations(that);
return oss.str();
}
+void SampledImage::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ image_type_->GetHashWords(words);
+}
+
Array::Array(Type* type, uint32_t length_id)
- : element_type_(type), length_id_(length_id) {
+ : Type(kArray), element_type_(type), length_id_(length_id) {
assert(!type->AsVoid());
}
-bool Array::IsSame(Type* that) const {
+bool Array::IsSame(const Type* that) const {
const Array* at = that->AsArray();
if (!at) return false;
return length_id_ == at->length_id_ &&
return oss.str();
}
-RuntimeArray::RuntimeArray(Type* type) : element_type_(type) {
+void Array::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ element_type_->GetHashWords(words);
+ words->push_back(length_id_);
+}
+
+RuntimeArray::RuntimeArray(Type* type)
+ : Type(kRuntimeArray), element_type_(type) {
assert(!type->AsVoid());
}
-bool RuntimeArray::IsSame(Type* that) const {
+bool RuntimeArray::IsSame(const Type* that) const {
const RuntimeArray* rat = that->AsRuntimeArray();
if (!rat) return false;
return element_type_->IsSame(rat->element_type_) && HasSameDecorations(that);
return oss.str();
}
-Struct::Struct(const std::vector<Type*>& types) : element_types_(types) {
+void RuntimeArray::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ element_type_->GetHashWords(words);
+}
+
+Struct::Struct(const std::vector<Type*>& types)
+ : Type(kStruct), element_types_(types) {
for (auto* t : types) {
(void)t;
assert(!t->AsVoid());
element_decorations_[index].push_back(std::move(decoration));
}
-bool Struct::IsSame(Type* that) const {
+bool Struct::IsSame(const Type* that) const {
const Struct* st = that->AsStruct();
if (!st) return false;
if (element_types_.size() != st->element_types_.size()) return false;
return oss.str();
}
-bool Opaque::IsSame(Type* that) const {
+void Struct::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ for (auto t : element_types_) {
+ t->GetHashWords(words);
+ }
+ for (auto pair : element_decorations_) {
+ words->push_back(pair.first);
+ for (auto d : pair.second) {
+ for (auto w : d) {
+ words->push_back(w);
+ }
+ }
+ }
+}
+
+bool Opaque::IsSame(const Type* that) const {
const Opaque* ot = that->AsOpaque();
if (!ot) return false;
return name_ == ot->name_ && HasSameDecorations(that);
return oss.str();
}
-Pointer::Pointer(Type* type, SpvStorageClass storage_class)
- : pointee_type_(type), storage_class_(storage_class) {
+void Opaque::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ for (auto c : name_) {
+ words->push_back(static_cast<char32_t>(c));
+ }
+}
+
+Pointer::Pointer(Type* type, SpvStorageClass sc)
+ : Type(kPointer), pointee_type_(type), storage_class_(sc) {
assert(!type->AsVoid());
}
-bool Pointer::IsSame(Type* that) const {
+bool Pointer::IsSame(const Type* that) const {
const Pointer* pt = that->AsPointer();
if (!pt) return false;
if (storage_class_ != pt->storage_class_) return false;
std::string Pointer::str() const { return pointee_type_->str() + "*"; }
-Function::Function(Type* return_type, const std::vector<Type*>& param_types)
- : return_type_(return_type), param_types_(param_types) {
- for (auto* t : param_types) {
+void Pointer::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ pointee_type_->GetHashWords(words);
+ words->push_back(storage_class_);
+}
+
+Function::Function(Type* ret_type, const std::vector<Type*>& params)
+ : Type(kFunction), return_type_(ret_type), param_types_(params) {
+ for (auto* t : params) {
(void)t;
assert(!t->AsVoid());
}
}
-bool Function::IsSame(Type* that) const {
+bool Function::IsSame(const Type* that) const {
const Function* ft = that->AsFunction();
if (!ft) return false;
if (!return_type_->IsSame(ft->return_type_)) return false;
return oss.str();
}
-bool Pipe::IsSame(Type* that) const {
+void Function::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ return_type_->GetHashWords(words);
+ for (auto t : param_types_) {
+ t->GetHashWords(words);
+ }
+}
+
+bool Pipe::IsSame(const Type* that) const {
const Pipe* pt = that->AsPipe();
if (!pt) return false;
return access_qualifier_ == pt->access_qualifier_ && HasSameDecorations(that);
return oss.str();
}
-bool ForwardPointer::IsSame(Type* that) const {
+void Pipe::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ words->push_back(access_qualifier_);
+}
+
+bool ForwardPointer::IsSame(const Type* that) const {
const ForwardPointer* fpt = that->AsForwardPointer();
if (!fpt) return false;
return target_id_ == fpt->target_id_ &&
return oss.str();
}
+void ForwardPointer::GetExtraHashWords(std::vector<uint32_t>* words) const {
+ words->push_back(target_id_);
+ words->push_back(storage_class_);
+ if (pointer_) pointer_->GetHashWords(words);
+}
+
} // namespace analysis
} // namespace opt
} // namespace spvtools
#ifndef LIBSPIRV_OPT_TYPES_H_
#define LIBSPIRV_OPT_TYPES_H_
+#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
// which is used as a way to probe the actual <subclass>.
class Type {
public:
+ // Available subtypes.
+ //
+ // When adding a new derived class of Type, please add an entry to the enum.
+ enum Kind {
+ kVoid,
+ kBool,
+ kInteger,
+ kFloat,
+ kVector,
+ kMatrix,
+ kImage,
+ kSampler,
+ kSampledImage,
+ kArray,
+ kRuntimeArray,
+ kStruct,
+ kOpaque,
+ kPointer,
+ kFunction,
+ kEvent,
+ kDeviceEvent,
+ kReserveId,
+ kQueue,
+ kPipe,
+ kForwardPointer,
+ kPipeStorage,
+ kNamedBarrier,
+ };
+
+ Type(Kind k) : kind_(k) {}
+
virtual ~Type() {}
// Attaches a decoration directly on this type.
bool HasSameDecorations(const Type* that) const;
// Returns true if this type is exactly the same as |that| type, including
// decorations.
- virtual bool IsSame(Type* that) const = 0;
+ virtual bool IsSame(const Type* that) const = 0;
// Returns a human-readable string to represent this type.
virtual std::string str() const = 0;
+ Kind kind() const { return kind_; }
+ const std::vector<std::vector<uint32_t>>& decorations() const {
+ return decorations_;
+ }
+
// Returns true if there is no decoration on this type. For struct types,
// returns true only when there is no decoration for both the struct type
// and the struct members.
virtual bool decoration_empty() const { return decorations_.empty(); }
+ // Creates a clone of |this|.
+ std::unique_ptr<Type> Clone() const;
+
+ // Returns a clone of |this| minus any decorations.
+ std::unique_ptr<Type> RemoveDecorations() const;
+
+ // Returns true if this type must be unique.
+ //
+ // If variable pointers are allowed, then pointers are not required to be
+ // unique.
+ // TODO(alanbaker): Update this if variable pointers become a core feature.
+ bool IsUniqueType(bool allowVariablePointers = false) const;
+
// A bunch of methods for casting this type to a given type. Returns this if the
// cast can be done, nullptr otherwise.
#define DeclareCastMethod(target) \
DeclareCastMethod(NamedBarrier);
#undef DeclareCastMethod
+ bool operator==(const Type& other) const;
+
+ // Returns the hash value of this type.
+ size_t HashValue() const;
+
+ // Adds the necessary words to compute a hash value of this type to |words|.
+ void GetHashWords(std::vector<uint32_t>* words) const;
+
+ // Adds necessary extra words for a subtype to calculate a hash value into
+ // |words|.
+ virtual void GetExtraHashWords(std::vector<uint32_t>* words) const = 0;
+
protected:
// Decorations attached to this type. Each decoration is encoded as a vector
// of uint32_t numbers. The first uint32_t number is the decoration value,
// and the rest are the parameters to the decoration (if exists).
std::vector<std::vector<uint32_t>> decorations_;
+
+ private:
+ // Removes decorations on this type. For struct types, also removes element
+ // decorations.
+ virtual void ClearDecorations() { decorations_.clear(); }
+
+ Kind kind_;
};
class Integer : public Type {
public:
- Integer(uint32_t w, bool is_signed) : width_(w), signed_(is_signed) {}
+ Integer(uint32_t w, bool is_signed)
+ : Type(kInteger), width_(w), signed_(is_signed) {}
Integer(const Integer&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
Integer* AsInteger() override { return this; }
uint32_t width() const { return width_; }
bool IsSigned() const { return signed_; }
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
uint32_t width_; // bit width
bool signed_; // true if this integer is signed
class Float : public Type {
public:
- Float(uint32_t w) : width_(w) {}
+ Float(uint32_t w) : Type(kFloat), width_(w) {}
Float(const Float&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
Float* AsFloat() override { return this; }
const Float* AsFloat() const override { return this; }
uint32_t width() const { return width_; }
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
uint32_t width_; // bit width
};
Vector(Type* element_type, uint32_t count);
Vector(const Vector&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
const Type* element_type() const { return element_type_; }
uint32_t element_count() const { return count_; }
Vector* AsVector() override { return this; }
const Vector* AsVector() const override { return this; }
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
Type* element_type_;
uint32_t count_;
Matrix(Type* element_type, uint32_t count);
Matrix(const Matrix&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
const Type* element_type() const { return element_type_; }
uint32_t element_count() const { return count_; }
Matrix* AsMatrix() override { return this; }
const Matrix* AsMatrix() const override { return this; }
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
Type* element_type_;
uint32_t count_;
class Image : public Type {
public:
- Image(Type* sampled_type, SpvDim dim, uint32_t depth, uint32_t arrayed,
- uint32_t ms, uint32_t sampled, SpvImageFormat format,
- SpvAccessQualifier access_qualifier = SpvAccessQualifierReadOnly);
+ Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample,
+ uint32_t sampling, SpvImageFormat f,
+ SpvAccessQualifier qualifier = SpvAccessQualifierReadOnly);
Image(const Image&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
Image* AsImage() override { return this; }
const Image* AsImage() const override { return this; }
+ const Type* sampled_type() const { return sampled_type_; }
+ SpvDim dim() const { return dim_; }
+ uint32_t depth() const { return depth_; }
+ bool is_arrayed() const { return arrayed_; }
+ bool is_multisampled() const { return ms_; }
+ uint32_t sampled() const { return sampled_; }
+ SpvImageFormat format() const { return format_; }
+ SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
+
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
Type* sampled_type_;
SpvDim dim_;
uint32_t depth_;
- uint32_t arrayed_;
- uint32_t ms_;
+ bool arrayed_;
+ bool ms_;
uint32_t sampled_;
SpvImageFormat format_;
SpvAccessQualifier access_qualifier_;
class SampledImage : public Type {
public:
- SampledImage(Type* image_type) : image_type_(image_type) {}
+ SampledImage(Type* image) : Type(kSampledImage), image_type_(image) {}
SampledImage(const SampledImage&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
SampledImage* AsSampledImage() override { return this; }
const SampledImage* AsSampledImage() const override { return this; }
+ const Type* image_type() const { return image_type_; }
+
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
Type* image_type_;
};
Array(Type* element_type, uint32_t length_id);
Array(const Array&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
const Type* element_type() const { return element_type_; }
uint32_t LengthId() const { return length_id_; }
Array* AsArray() override { return this; }
const Array* AsArray() const override { return this; }
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
Type* element_type_;
uint32_t length_id_;
RuntimeArray(Type* element_type);
RuntimeArray(const RuntimeArray&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
const Type* element_type() const { return element_type_; }
RuntimeArray* AsRuntimeArray() override { return this; }
const RuntimeArray* AsRuntimeArray() const override { return this; }
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
Type* element_type_;
};
// decoration enum, and the remaining words, if any, are its operands.
void AddMemberDecoration(uint32_t index, std::vector<uint32_t>&& decoration);
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
const std::vector<Type*>& element_types() const { return element_types_; }
bool decoration_empty() const override {
return decorations_.empty() && element_decorations_.empty();
}
+ const std::unordered_map<uint32_t, std::vector<std::vector<uint32_t>>>&
+ element_decorations() const {
+ return element_decorations_;
+ }
Struct* AsStruct() override { return this; }
const Struct* AsStruct() const override { return this; }
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
+ void ClearDecorations() override {
+ decorations_.clear();
+ element_decorations_.clear();
+ }
+
std::vector<Type*> element_types_;
// We can attach decorations to struct members and that should not affect the
// underlying element type. So we need an extra data structure here to keep
class Opaque : public Type {
public:
- Opaque(std::string name) : name_(std::move(name)) {}
+ Opaque(std::string n) : Type(kOpaque), name_(std::move(n)) {}
Opaque(const Opaque&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
Opaque* AsOpaque() override { return this; }
const Opaque* AsOpaque() const override { return this; }
+ const std::string& name() const { return name_; }
+
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
std::string name_;
};
class Pointer : public Type {
public:
- Pointer(Type* pointee_type, SpvStorageClass storage_class);
+ Pointer(Type* pointee, SpvStorageClass sc);
Pointer(const Pointer&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
const Type* pointee_type() const { return pointee_type_; }
+ SpvStorageClass storage_class() const { return storage_class_; }
Pointer* AsPointer() override { return this; }
const Pointer* AsPointer() const override { return this; }
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
Type* pointee_type_;
SpvStorageClass storage_class_;
class Function : public Type {
public:
- Function(Type* return_type, const std::vector<Type*>& param_types);
+ Function(Type* ret_type, const std::vector<Type*>& params);
Function(const Function&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
Function* AsFunction() override { return this; }
const Function* AsFunction() const override { return this; }
+ const Type* return_type() const { return return_type_; }
+ const std::vector<Type*>& param_types() const { return param_types_; }
+
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
Type* return_type_;
std::vector<Type*> param_types_;
class Pipe : public Type {
public:
- Pipe(SpvAccessQualifier access_qualifier)
- : access_qualifier_(access_qualifier) {}
+ Pipe(SpvAccessQualifier qualifier)
+ : Type(kPipe), access_qualifier_(qualifier) {}
Pipe(const Pipe&) = default;
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
Pipe* AsPipe() override { return this; }
const Pipe* AsPipe() const override { return this; }
+ SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
+
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
SpvAccessQualifier access_qualifier_;
};
class ForwardPointer : public Type {
public:
- ForwardPointer(uint32_t id, SpvStorageClass storage_class)
- : target_id_(id), storage_class_(storage_class), pointer_(nullptr) {}
+ ForwardPointer(uint32_t id, SpvStorageClass sc)
+ : Type(kForwardPointer),
+ target_id_(id),
+ storage_class_(sc),
+ pointer_(nullptr) {}
ForwardPointer(const ForwardPointer&) = default;
uint32_t target_id() const { return target_id_; }
void SetTargetPointer(Pointer* pointer) { pointer_ = pointer; }
+ SpvStorageClass storage_class() const { return storage_class_; }
- bool IsSame(Type* that) const override;
+ bool IsSame(const Type* that) const override;
std::string str() const override;
ForwardPointer* AsForwardPointer() override { return this; }
const ForwardPointer* AsForwardPointer() const override { return this; }
+ void GetExtraHashWords(std::vector<uint32_t>* words) const override;
+
private:
uint32_t target_id_;
SpvStorageClass storage_class_;
Pointer* pointer_;
};
-#define DefineParameterlessType(type, name) \
- class type : public Type { \
- public: \
- type() = default; \
- type(const type&) = default; \
- \
- bool IsSame(Type* that) const override { \
- return that->As##type() && HasSameDecorations(that); \
- } \
- std::string str() const override { return #name; } \
- \
- type* As##type() override { return this; } \
- const type* As##type() const override { return this; } \
- };
+#define DefineParameterlessType(type, name) \
+ class type : public Type { \
+ public: \
+ type() : Type(k##type) {} \
+ type(const type&) = default; \
+ \
+ bool IsSame(const Type* that) const override { \
+ return that->As##type() && HasSameDecorations(that); \
+ } \
+ std::string str() const override { return #name; } \
+ \
+ type* As##type() override { return this; } \
+ const type* As##type() const override { return this; } \
+ \
+ void GetExtraHashWords(std::vector<uint32_t>*) const override {} \
+ }; // namespace analysis
DefineParameterlessType(Void, void);
DefineParameterlessType(Bool, bool);
DefineParameterlessType(Sampler, sampler);
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#ifdef SPIRV_EFFCEE
+#include "effcee/effcee.h"
+#endif
+
#include "opt/build_module.h"
#include "opt/instruction.h"
#include "opt/type_manager.h"
+#include "spirv-tools/libspirv.hpp"
namespace {
using namespace spvtools;
+using namespace spvtools::opt::analysis;
+
+bool Validate(const std::vector<uint32_t>& bin) {
+ spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
+ spv_context spvContext = spvContextCreate(target_env);
+ spv_diagnostic diagnostic = nullptr;
+ spv_const_binary_t binary = {bin.data(), bin.size()};
+ spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
+ if (error != 0) spvDiagnosticPrint(diagnostic);
+ spvDiagnosticDestroy(diagnostic);
+ spvContextDestroy(spvContext);
+ return error == 0;
+}
+
+#ifdef SPIRV_EFFCEE
+void Match(const std::string& original, ir::IRContext* context,
+ bool do_validation = true) {
+ std::vector<uint32_t> bin;
+ context->module()->ToBinary(&bin, true);
+ if (do_validation) {
+ EXPECT_TRUE(Validate(bin));
+ }
+ std::string assembly;
+ SpirvTools tools(SPV_ENV_UNIVERSAL_1_2);
+ EXPECT_TRUE(
+ tools.Disassemble(bin, &assembly, SpirvTools::kDefaultDisassembleOption))
+ << "Disassembling failed for shader:\n"
+ << assembly << std::endl;
+ auto match_result = effcee::Match(assembly, original);
+ EXPECT_EQ(effcee::Result::Status::Ok, match_result.status())
+ << match_result.message() << "\nChecking result:\n"
+ << assembly;
+}
+#endif
+
+std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
+ // Types in this test case are only equal to themselves, nothing else.
+ std::vector<std::unique_ptr<Type>> types;
+
+ // Void, Bool
+ types.emplace_back(new Void());
+ auto* voidt = types.back().get();
+ types.emplace_back(new Bool());
+ auto* boolt = types.back().get();
+
+ // Integer
+ types.emplace_back(new Integer(32, true));
+ auto* s32 = types.back().get();
+ types.emplace_back(new Integer(32, false));
+ types.emplace_back(new Integer(64, true));
+ types.emplace_back(new Integer(64, false));
+ auto* u64 = types.back().get();
+
+ // Float
+ types.emplace_back(new Float(32));
+ auto* f32 = types.back().get();
+ types.emplace_back(new Float(64));
+
+ // Vector
+ types.emplace_back(new Vector(s32, 2));
+ types.emplace_back(new Vector(s32, 3));
+ auto* v3s32 = types.back().get();
+ types.emplace_back(new Vector(u64, 4));
+ types.emplace_back(new Vector(f32, 3));
+ auto* v3f32 = types.back().get();
+
+ // Matrix
+ types.emplace_back(new Matrix(v3s32, 3));
+ types.emplace_back(new Matrix(v3s32, 4));
+ types.emplace_back(new Matrix(v3f32, 4));
+
+ // Images
+ types.emplace_back(new Image(s32, SpvDim2D, 0, 0, 0, 0, SpvImageFormatRg8,
+ SpvAccessQualifierReadOnly));
+ auto* image1 = types.back().get();
+ types.emplace_back(new Image(s32, SpvDim2D, 0, 1, 0, 0, SpvImageFormatRg8,
+ SpvAccessQualifierReadOnly));
+ types.emplace_back(new Image(s32, SpvDim3D, 0, 1, 0, 0, SpvImageFormatRg8,
+ SpvAccessQualifierReadOnly));
+ types.emplace_back(new Image(voidt, SpvDim3D, 0, 1, 0, 1, SpvImageFormatRg8,
+ SpvAccessQualifierReadWrite));
+ auto* image2 = types.back().get();
+
+ // Sampler
+ types.emplace_back(new Sampler());
+
+ // Sampled Image
+ types.emplace_back(new SampledImage(image1));
+ types.emplace_back(new SampledImage(image2));
+
+ // Array
+ types.emplace_back(new Array(f32, 100));
+ types.emplace_back(new Array(f32, 42));
+ auto* a42f32 = types.back().get();
+ types.emplace_back(new Array(u64, 24));
+
+ // RuntimeArray
+ types.emplace_back(new RuntimeArray(v3f32));
+ types.emplace_back(new RuntimeArray(v3s32));
+ auto* rav3s32 = types.back().get();
+
+ // Struct
+ types.emplace_back(new Struct(std::vector<Type*>{s32}));
+ types.emplace_back(new Struct(std::vector<Type*>{s32, f32}));
+ auto* sts32f32 = types.back().get();
+ types.emplace_back(new Struct(std::vector<Type*>{u64, a42f32, rav3s32}));
+
+ // Opaque
+ types.emplace_back(new Opaque(""));
+ types.emplace_back(new Opaque("hello"));
+ types.emplace_back(new Opaque("world"));
+
+ // Pointer
+ types.emplace_back(new Pointer(f32, SpvStorageClassInput));
+ types.emplace_back(new Pointer(sts32f32, SpvStorageClassFunction));
+ types.emplace_back(new Pointer(a42f32, SpvStorageClassFunction));
+
+ // Function
+ types.emplace_back(new Function(voidt, {}));
+ types.emplace_back(new Function(voidt, {boolt}));
+ types.emplace_back(new Function(voidt, {boolt, s32}));
+ types.emplace_back(new Function(s32, {boolt, s32}));
+
+ // Event, Device Event, Reserve Id, Queue,
+ types.emplace_back(new Event());
+ types.emplace_back(new DeviceEvent());
+ types.emplace_back(new ReserveId());
+ types.emplace_back(new Queue());
+
+ // Pipe, Forward Pointer, PipeStorage, NamedBarrier
+ types.emplace_back(new Pipe(SpvAccessQualifierReadWrite));
+ types.emplace_back(new Pipe(SpvAccessQualifierReadOnly));
+ types.emplace_back(new ForwardPointer(1, SpvStorageClassInput));
+ types.emplace_back(new ForwardPointer(2, SpvStorageClassInput));
+ types.emplace_back(new ForwardPointer(2, SpvStorageClassUniform));
+ types.emplace_back(new PipeStorage());
+ types.emplace_back(new NamedBarrier());
+
+ return types;
+}
TEST(TypeManager, TypeStrings) {
const std::string text = R"(
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
- opt::analysis::TypeManager manager(nullptr, *context->module());
+ opt::analysis::TypeManager manager(nullptr, context.get());
EXPECT_EQ(type_id_strs.size(), manager.NumTypes());
EXPECT_EQ(2u, manager.NumForwardPointers());
)";
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
- opt::analysis::TypeManager manager(nullptr, *context->module());
+ opt::analysis::TypeManager manager(nullptr, context.get());
ASSERT_EQ(7u, manager.NumTypes());
ASSERT_EQ(0u, manager.NumForwardPointers());
)";
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
- opt::analysis::TypeManager manager(nullptr, *context->module());
+ opt::analysis::TypeManager manager(nullptr, context.get());
ASSERT_EQ(10u, manager.NumTypes());
ASSERT_EQ(0u, manager.NumForwardPointers());
)";
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
- opt::analysis::TypeManager manager(nullptr, *context->module());
+ opt::analysis::TypeManager manager(nullptr, context.get());
ASSERT_EQ(5u, manager.NumTypes());
ASSERT_EQ(0u, manager.NumForwardPointers());
const std::string text = "";
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
- opt::analysis::TypeManager manager(nullptr, *context->module());
+ opt::analysis::TypeManager manager(nullptr, context.get());
ASSERT_EQ(0u, manager.NumTypes());
ASSERT_EQ(0u, manager.NumForwardPointers());
)";
std::unique_ptr<ir::IRContext> context =
BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text);
- opt::analysis::TypeManager manager(nullptr, *context->module());
+ opt::analysis::TypeManager manager(nullptr, context.get());
ASSERT_EQ(5u, manager.NumTypes());
ASSERT_EQ(0u, manager.NumForwardPointers());
}
}
+TEST(TypeManager, LookupType) {
+ const std::string text = R"(
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%int = OpTypeInt 32 1
+%vec2 = OpTypeVector %int 2
+)";
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
+ EXPECT_NE(context, nullptr);
+ TypeManager manager(nullptr, context.get());
+
+ Void voidTy;
+ EXPECT_EQ(manager.GetId(&voidTy), 1u);
+
+ Integer uintTy(32, false);
+ EXPECT_EQ(manager.GetId(&uintTy), 2u);
+
+ Integer intTy(32, true);
+ EXPECT_EQ(manager.GetId(&intTy), 3u);
+
+ Integer intTy2(32, true);
+ Vector vecTy(&intTy2, 2u);
+ EXPECT_EQ(manager.GetId(&vecTy), 4u);
+}
+
+TEST(TypeManager, RemoveId) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeInt 32 0
+%2 = OpTypeInt 32 1
+ )";
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ EXPECT_NE(context, nullptr);
+
+ context->get_type_mgr()->RemoveId(1u);
+ ASSERT_EQ(context->get_type_mgr()->GetType(1u), nullptr);
+ ASSERT_NE(context->get_type_mgr()->GetType(2u), nullptr);
+
+ context->get_type_mgr()->RemoveId(2u);
+ ASSERT_EQ(context->get_type_mgr()->GetType(1u), nullptr);
+ ASSERT_EQ(context->get_type_mgr()->GetType(2u), nullptr);
+}
+
+TEST(TypeManager, RemoveIdNonDuplicateAmbiguousType) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeInt 32 0
+%2 = OpTypeStruct %1
+ )";
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ EXPECT_NE(context, nullptr);
+
+ Integer u32(32, false);
+ Struct st({&u32});
+ ASSERT_EQ(context->get_type_mgr()->GetId(&st), 2u);
+ context->get_type_mgr()->RemoveId(2u);
+ ASSERT_EQ(context->get_type_mgr()->GetType(2u), nullptr);
+ ASSERT_EQ(context->get_type_mgr()->GetId(&st), 0u);
+}
+
+TEST(TypeManager, RemoveIdDuplicateAmbiguousType) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeInt 32 0
+%2 = OpTypeStruct %1
+%3 = OpTypeStruct %1
+ )";
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ EXPECT_NE(context, nullptr);
+
+ Integer u32(32, false);
+ Struct st({&u32});
+ uint32_t id = context->get_type_mgr()->GetId(&st);
+ ASSERT_NE(id, 0u);
+ uint32_t toRemove = id == 2u ? 2u : 3u;
+ uint32_t toStay = id == 2u ? 3u : 2u;
+ context->get_type_mgr()->RemoveId(toRemove);
+ ASSERT_EQ(context->get_type_mgr()->GetType(toRemove), nullptr);
+ ASSERT_EQ(context->get_type_mgr()->GetId(&st), toStay);
+}
+
+TEST(TypeManager, GetTypeAndPointerType) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeInt 32 0
+%2 = OpTypeStruct %1
+ )";
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ EXPECT_NE(context, nullptr);
+
+ Integer u32(32, false);
+ Pointer u32Ptr(&u32, SpvStorageClassFunction);
+ Struct st({&u32});
+ Pointer stPtr(&st, SpvStorageClassInput);
+
+ auto pair = context->get_type_mgr()->GetTypeAndPointerType(
+ 3u, SpvStorageClassFunction);
+ ASSERT_EQ(nullptr, pair.first);
+ ASSERT_EQ(nullptr, pair.second);
+
+ pair = context->get_type_mgr()->GetTypeAndPointerType(
+ 1u, SpvStorageClassFunction);
+ ASSERT_TRUE(pair.first->IsSame(&u32));
+ ASSERT_TRUE(pair.second->IsSame(&u32Ptr));
+
+ pair =
+ context->get_type_mgr()->GetTypeAndPointerType(2u, SpvStorageClassInput);
+ ASSERT_TRUE(pair.first->IsSame(&st));
+ ASSERT_TRUE(pair.second->IsSame(&stPtr));
+}
+
+#ifdef SPIRV_EFFCEE
+TEST(TypeManager, GetTypeInstructionInt) {
+ const std::string text = R"(
+; CHECK: OpTypeInt 32 0
+; CHECK: OpTypeInt 16 1
+OpCapability Shader
+OpCapability Int16
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+ )";
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
+ EXPECT_NE(context, nullptr);
+
+ Integer uint_32(32, false);
+ context->get_type_mgr()->GetTypeInstruction(&uint_32);
+
+ Integer int_16(16, true);
+ context->get_type_mgr()->GetTypeInstruction(&int_16);
+
+ Match(text, context.get());
+}
+
+TEST(TypeManager, GetTypeInstructionDuplicateInts) {
+ const std::string text = R"(
+; CHECK: OpTypeInt 32 0
+; CHECK-NOT: OpType
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+ )";
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
+ EXPECT_NE(context, nullptr);
+
+ Integer uint_32(32, false);
+ uint32_t id = context->get_type_mgr()->GetTypeInstruction(&uint_32);
+
+ Integer other(32, false);
+ EXPECT_EQ(context->get_type_mgr()->GetTypeInstruction(&other), id);
+
+ Match(text, context.get());
+}
+
+TEST(TypeManager, GetTypeInstructionAllTypes) {
+ const std::string text = R"(
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[input_ptr:%\w+]] = OpTypePointer Input [[uint]]
+; CHECK: [[uniform_ptr:%\w+]] = OpTypePointer Uniform [[uint]]
+; CHECK: [[uint24:%\w+]] = OpConstant [[uint]] 24
+; CHECK: [[uint42:%\w+]] = OpConstant [[uint]] 42
+; CHECK: [[uint100:%\w+]] = OpConstant [[uint]] 100
+; CHECK: [[void:%\w+]] = OpTypeVoid
+; CHECK: [[bool:%\w+]] = OpTypeBool
+; CHECK: [[s32:%\w+]] = OpTypeInt 32 1
+; CHECK: OpTypeInt 64 1
+; CHECK: [[u64:%\w+]] = OpTypeInt 64 0
+; CHECK: [[f32:%\w+]] = OpTypeFloat 32
+; CHECK: OpTypeFloat 64
+; CHECK: OpTypeVector [[s32]] 2
+; CHECK: [[v3s32:%\w+]] = OpTypeVector [[s32]] 3
+; CHECK: OpTypeVector [[u64]] 4
+; CHECK: [[v3f32:%\w+]] = OpTypeVector [[f32]] 3
+; CHECK: OpTypeMatrix [[v3s32]] 3
+; CHECK: OpTypeMatrix [[v3s32]] 4
+; CHECK: OpTypeMatrix [[v3f32]] 4
+; CHECK: [[image1:%\w+]] = OpTypeImage [[s32]] 2D 0 0 0 0 Rg8 ReadOnly
+; CHECK: OpTypeImage [[s32]] 2D 0 1 0 0 Rg8 ReadOnly
+; CHECK: OpTypeImage [[s32]] 3D 0 1 0 0 Rg8 ReadOnly
+; CHECK: [[image2:%\w+]] = OpTypeImage [[void]] 3D 0 1 0 1 Rg8 ReadWrite
+; CHECK: OpTypeSampler
+; CHECK: OpTypeSampledImage [[image1]]
+; CHECK: OpTypeSampledImage [[image2]]
+; CHECK: OpTypeArray [[f32]] [[uint100]]
+; CHECK: [[a42f32:%\w+]] = OpTypeArray [[f32]] [[uint42]]
+; CHECK: OpTypeArray [[u64]] [[uint24]]
+; CHECK: OpTypeRuntimeArray [[v3f32]]
+; CHECK: [[rav3s32:%\w+]] = OpTypeRuntimeArray [[v3s32]]
+; CHECK: OpTypeStruct [[s32]]
+; CHECK: [[sts32f32:%\w+]] = OpTypeStruct [[s32]] [[f32]]
+; CHECK: OpTypeStruct [[u64]] [[a42f32]] [[rav3s32]]
+; CHECK: OpTypeOpaque ""
+; CHECK: OpTypeOpaque "hello"
+; CHECK: OpTypeOpaque "world"
+; CHECK: OpTypePointer Input [[f32]]
+; CHECK: OpTypePointer Function [[sts32f32]]
+; CHECK: OpTypePointer Function [[a42f32]]
+; CHECK: OpTypeFunction [[void]]
+; CHECK: OpTypeFunction [[void]] [[bool]]
+; CHECK: OpTypeFunction [[void]] [[bool]] [[s32]]
+; CHECK: OpTypeFunction [[s32]] [[bool]] [[s32]]
+; CHECK: OpTypeEvent
+; CHECK: OpTypeDeviceEvent
+; CHECK: OpTypeReserveId
+; CHECK: OpTypeQueue
+; CHECK: OpTypePipe ReadWrite
+; CHECK: OpTypePipe ReadOnly
+; CHECK: OpTypeForwardPointer [[input_ptr]] Input
+; CHECK: OpTypeForwardPointer [[uniform_ptr]] Input
+; CHECK: OpTypeForwardPointer [[uniform_ptr]] Uniform
+; CHECK: OpTypePipeStorage
+; CHECK: OpTypeNamedBarrier
+OpCapability Shader
+OpCapability Int64
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%uint = OpTypeInt 32 0
+%1 = OpTypePointer Input %uint
+%2 = OpTypePointer Uniform %uint
+%24 = OpConstant %uint 24
+%42 = OpConstant %uint 42
+%100 = OpConstant %uint 100
+ )";
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ EXPECT_NE(context, nullptr);
+
+ std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
+ for (auto& t : types) {
+ context->get_type_mgr()->GetTypeInstruction(t.get());
+ }
+
+ Match(text, context.get(), false);
+}
+
+TEST(TypeManager, GetTypeInstructionWithDecorations) {
+ const std::string text = R"(
+; CHECK: OpDecorate [[struct:%\w+]] CPacked
+; CHECK: OpMemberDecorate [[struct]] 1 Offset 4
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[struct]] = OpTypeStruct [[uint]] [[uint]]
+OpCapability Shader
+OpCapability Kernel
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%uint = OpTypeInt 32 0
+ )";
+
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ EXPECT_NE(context, nullptr);
+
+ Integer u32(32, false);
+ Struct st({&u32, &u32});
+ st.AddDecoration({10});
+ st.AddMemberDecoration(1, {{35, 4}});
+ (void)context->get_def_use_mgr();
+ context->get_type_mgr()->GetTypeInstruction(&st);
+
+ Match(text, context.get());
+}
+#endif // SPIRV_EFFCEE
+
} // anonymous namespace
EXPECT_TRUE(types[i]->IsSame(types[j].get())) \
<< "expected '" << types[i]->str() << "' is the same as '" \
<< types[j]->str() << "'"; \
+ EXPECT_TRUE(*types[i] == *types[j]) \
+ << "expected '" << types[i]->str() << "' is the same as '" \
+ << types[j]->str() << "'"; \
} \
} \
}
TestMultipleInstancesOfTheSameType(NamedBarrier);
#undef TestMultipleInstanceOfTheSameType
-TEST(Types, AllTypes) {
+std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
// Types in this test case are only equal to themselves, nothing else.
std::vector<std::unique_ptr<Type>> types;
types.emplace_back(new PipeStorage());
types.emplace_back(new NamedBarrier());
+ return types;
+}
+
+TEST(Types, AllTypes) {
+ // Types in this test case are only equal to themselves, nothing else.
+ std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
+
for (size_t i = 0; i < types.size(); ++i) {
for (size_t j = 0; j < types.size(); ++j) {
if (i == j) {
}
}
+TEST(Types, IsUniqueType) {
+ std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
+
+ for (auto& t : types) {
+ bool expectation = true;
+ // Disallowing variable pointers.
+ switch (t->kind()) {
+ case Type::kArray:
+ case Type::kRuntimeArray:
+ case Type::kStruct:
+ expectation = false;
+ break;
+ default:
+ break;
+ }
+ EXPECT_EQ(t->IsUniqueType(false), expectation)
+ << "expected '" << t->str() << "' to be a "
+ << (expectation ? "" : "non-") << "unique type";
+
+ // Allowing variables pointers.
+ if (t->AsPointer()) expectation = false;
+ EXPECT_EQ(t->IsUniqueType(true), expectation)
+ << "expected '" << t->str() << "' to be a "
+ << (expectation ? "" : "non-") << "unique type";
+ }
+}
+
+std::vector<std::unique_ptr<Type>> GenerateAllTypesWithDecorations() {
+ std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
+ uint32_t elems = 1;
+ uint32_t decs = 1;
+ for (auto& t : types) {
+ for (uint32_t i = 0; i < (decs % 10); ++i) {
+ std::vector<uint32_t> decoration;
+ for (uint32_t j = 0; j < (elems % 4) + 1; ++j) {
+ decoration.push_back(j);
+ }
+ t->AddDecoration(std::move(decoration));
+ ++elems;
+ ++decs;
+ }
+ }
+
+ return types;
+}
+
+TEST(Types, Clone) {
+ std::vector<std::unique_ptr<Type>> types = GenerateAllTypesWithDecorations();
+ for (auto& t : types) {
+ auto clone = t->Clone();
+ EXPECT_TRUE(*t == *clone);
+ EXPECT_TRUE(t->HasSameDecorations(clone.get()));
+ EXPECT_NE(clone.get(), t.get());
+ }
+}
+
+TEST(Types, RemoveDecorations) {
+ std::vector<std::unique_ptr<Type>> types = GenerateAllTypesWithDecorations();
+ for (auto& t : types) {
+ auto decorationless = t->RemoveDecorations();
+ EXPECT_EQ(*t == *decorationless, t->decoration_empty());
+ EXPECT_EQ(t->HasSameDecorations(decorationless.get()),
+ t->decoration_empty());
+ EXPECT_NE(t.get(), decorationless.get());
+ }
+}
+
} // anonymous namespace