From 2dd9e43579b341e5de238de924cc910042b0194e Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 16 Jul 2020 16:05:51 -0400 Subject: [PATCH] [spirv] Use owning module ref to avoid leaks and fix ASAN tests Differential Revision: https://reviews.llvm.org/D83982 --- .../Dialect/SPIRV/Serialization/Deserializer.cpp | 22 ++++++++++++---------- mlir/unittests/Dialect/SPIRV/SerializationTest.cpp | 9 +++++---- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 4e6b294..4ba3f16 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -103,7 +103,7 @@ public: LogicalResult deserialize(); /// Collects the final SPIR-V ModuleOp. - Optional collect(); + spirv::OwningSPIRVModuleRef collect(); private: //===--------------------------------------------------------------------===// @@ -111,7 +111,7 @@ private: //===--------------------------------------------------------------------===// /// Initializes the `module` ModuleOp in this deserializer instance. - spirv::ModuleOp createModuleOp(); + spirv::OwningSPIRVModuleRef createModuleOp(); /// Processes SPIR-V module header in `binary`. LogicalResult processHeader(); @@ -425,7 +425,7 @@ private: Location unknownLoc; /// The SPIR-V ModuleOp. - Optional module; + spirv::OwningSPIRVModuleRef module; /// The current function under construction. Optional curFunction; @@ -556,13 +556,15 @@ LogicalResult Deserializer::deserialize() { return success(); } -Optional Deserializer::collect() { return module; } +spirv::OwningSPIRVModuleRef Deserializer::collect() { + return std::move(module); +} //===----------------------------------------------------------------------===// // Module structure //===----------------------------------------------------------------------===// -spirv::ModuleOp Deserializer::createModuleOp() { +spirv::OwningSPIRVModuleRef Deserializer::createModuleOp() { OpBuilder builder(context); OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); spirv::ModuleOp::build(builder, state); @@ -1912,10 +1914,10 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() { // Go through all ops and remap the operands. auto remapOperands = [&](Operation *op) { for (auto &operand : op->getOpOperands()) - if (auto mappedOp = mapper.lookupOrNull(operand.get())) + if (Value mappedOp = mapper.lookupOrNull(operand.get())) operand.set(mappedOp); for (auto &succOp : op->getBlockOperands()) - if (auto mappedOp = mapper.lookupOrNull(succOp.get())) + if (Block *mappedOp = mapper.lookupOrNull(succOp.get())) succOp.set(mappedOp); }; for (auto &block : body) { @@ -2354,7 +2356,7 @@ Deserializer::processOp(ArrayRef words) { return emitError(unknownLoc, "missing Execution Model specification in OpEntryPoint"); } - auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]); + auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]); if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing in OpEntryPoint"); } @@ -2382,7 +2384,7 @@ Deserializer::processOp(ArrayRef words) { interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); wordIndex++; } - opBuilder.create(unknownLoc, exec_model, + opBuilder.create(unknownLoc, execModel, opBuilder.getSymbolRefAttr(fnName), opBuilder.getArrayAttr(interface)); return success(); @@ -2594,5 +2596,5 @@ spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef binary, if (failed(deserializer.deserialize())) return nullptr; - return deserializer.collect().getValueOr(nullptr); + return deserializer.collect(); } diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp index 340bfd9..3d57e55 100644 --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/SPIRV/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVModule.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" @@ -56,7 +57,7 @@ protected: } Type getFloatStructType() { - OpBuilder opBuilder(module.body()); + OpBuilder opBuilder(module->body()); llvm::SmallVector elementTypes{opBuilder.getF32Type()}; llvm::SmallVector offsetInfo{0}; auto structType = spirv::StructType::get(elementTypes, offsetInfo); @@ -64,7 +65,7 @@ protected: } void addGlobalVar(Type type, llvm::StringRef name) { - OpBuilder opBuilder(module.body()); + OpBuilder opBuilder(module->body()); auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); opBuilder.create( UnknownLoc::get(&context), TypeAttr::get(ptrType), @@ -98,7 +99,7 @@ protected: protected: MLIRContext context; - spirv::ModuleOp module; + spirv::OwningSPIRVModuleRef module; SmallVector binary; }; @@ -109,7 +110,7 @@ protected: TEST_F(SerializationTest, BlockDecorationTest) { auto structType = getFloatStructType(); addGlobalVar(structType, "var0"); - ASSERT_TRUE(succeeded(spirv::serialize(module, binary))); + ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); auto hasBlockDecoration = [](spirv::Opcode opcode, ArrayRef operands) -> bool { if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2) -- 2.7.4