[mlir][spirv] Introduce OwningSPIRVModuleRef for ownership
authorLei Zhang <antiagainst@google.com>
Tue, 7 Jul 2020 12:28:25 +0000 (08:28 -0400)
committerLei Zhang <antiagainst@google.com>
Tue, 7 Jul 2020 12:29:27 +0000 (08:29 -0400)
Similar to OwningModuleRef, OwningSPIRVModuleRef signals ownership
transfer clearly. This is useful for APIs like spirv::deserialize,
where a spirv::ModuleOp is returned by deserializing SPIR-V binary
module.

This addresses the ASAN error as reported in
https://bugs.llvm.org/show_bug.cgi?id=46272

Differential Revision: https://reviews.llvm.org/D81652

mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h [new file with mode: 0644]
mlir/include/mlir/Dialect/SPIRV/Serialization.h
mlir/include/mlir/IR/Module.h
mlir/include/mlir/IR/OwningOpRefBase.h [new file with mode: 0644]
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp

diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h
new file mode 100644 (file)
index 0000000..a53331e
--- /dev/null
@@ -0,0 +1,29 @@
+//===- SPIRVModule.h - SPIR-V Module Utilities ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+#define MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/OwningOpRefBase.h"
+
+namespace mlir {
+namespace spirv {
+
+/// This class acts as an owning reference to a SPIR-V module, and will
+/// automatically destroy the held module on destruction if the held module
+/// is valid.
+class OwningSPIRVModuleRef : public OwningOpRefBase<spirv::ModuleOp> {
+public:
+  using OwningOpRefBase<spirv::ModuleOp>::OwningOpRefBase;
+};
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVMODULE_H
index f6370a1..2c91286 100644 (file)
@@ -22,6 +22,7 @@ class MLIRContext;
 
 namespace spirv {
 class ModuleOp;
+class OwningSPIRVModuleRef;
 
 /// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
 /// reports errors to the error handler registered with the MLIR context for
@@ -31,9 +32,10 @@ LogicalResult serialize(ModuleOp module, SmallVectorImpl<uint32_t> &binary,
 
 /// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp
 /// in the given `context`. Returns the ModuleOp on success; otherwise, reports
-/// errors to the error handler registered with `context` and returns
-/// llvm::None.
-Optional<ModuleOp> deserialize(ArrayRef<uint32_t> binary, MLIRContext *context);
+/// errors to the error handler registered with `context` and returns a null
+/// module.
+OwningSPIRVModuleRef deserialize(ArrayRef<uint32_t> binary,
+                                 MLIRContext *context);
 
 } // end namespace spirv
 } // end namespace mlir
index 3c61574..8a51013 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef MLIR_IR_MODULE_H
 #define MLIR_IR_MODULE_H
 
+#include "mlir/IR/OwningOpRefBase.h"
 #include "mlir/IR/SymbolTable.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
@@ -122,40 +123,10 @@ public:
 };
 
 /// This class acts as an owning reference to a module, and will automatically
-/// destroy the held module if valid.
-class OwningModuleRef {
+/// destroy the held module on destruction if the held module is valid.
+class OwningModuleRef : public OwningOpRefBase<ModuleOp> {
 public:
-  OwningModuleRef(std::nullptr_t = nullptr) {}
-  OwningModuleRef(ModuleOp module) : module(module) {}
-  OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
-  ~OwningModuleRef() {
-    if (module)
-      module.erase();
-  }
-
-  // Assign from another module reference.
-  OwningModuleRef &operator=(OwningModuleRef &&other) {
-    if (module)
-      module.erase();
-    module = other.release();
-    return *this;
-  }
-
-  /// Allow accessing the internal module.
-  ModuleOp get() const { return module; }
-  ModuleOp operator*() const { return module; }
-  ModuleOp *operator->() { return &module; }
-  explicit operator bool() const { return module; }
-
-  /// Release the referenced module.
-  ModuleOp release() {
-    ModuleOp released;
-    std::swap(released, module);
-    return released;
-  }
-
-private:
-  ModuleOp module;
+  using OwningOpRefBase<ModuleOp>::OwningOpRefBase;
 };
 
 } // end namespace mlir
diff --git a/mlir/include/mlir/IR/OwningOpRefBase.h b/mlir/include/mlir/IR/OwningOpRefBase.h
new file mode 100644 (file)
index 0000000..bfdf98f
--- /dev/null
@@ -0,0 +1,64 @@
+//===- OwningOpRefBase.h - MLIR OwningOpRefBase -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a base class for owning op refs.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OWNINGOPREFBASE_H
+#define MLIR_IR_OWNINGOPREFBASE_H
+
+#include <utility>
+
+namespace mlir {
+
+/// This class acts as an owning reference to an op, and will automatically
+/// destroy the held op on destruction if the held op is valid.
+///
+/// Note that OpBuilder and related functionality should be highly preferred
+/// instead, and this should only be used in situations where existing solutions
+/// are not viable.
+template <typename OpTy>
+class OwningOpRefBase {
+public:
+  OwningOpRefBase(std::nullptr_t = nullptr) {}
+  OwningOpRefBase(OpTy op) : op(op) {}
+  OwningOpRefBase(OwningOpRefBase &&other) : op(other.release()) {}
+  ~OwningOpRefBase() {
+    if (op)
+      op.erase();
+  }
+
+  // Assign from another op reference.
+  OwningOpRefBase &operator=(OwningOpRefBase &&other) {
+    if (op)
+      op.erase();
+    op = other.release();
+    return *this;
+  }
+
+  /// Allow accessing the internal op.
+  OpTy get() const { return op; }
+  OpTy operator*() const { return op; }
+  OpTy *operator->() { return &op; }
+  explicit operator bool() const { return op; }
+
+  /// Release the referenced op.
+  OpTy release() {
+    OpTy released;
+    std::swap(released, op);
+    return released;
+  }
+
+private:
+  OpTy op;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_OWNINGOPREFBASE_H
index b5fef14..92f2a01 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -2516,12 +2517,12 @@ Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
 #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
 } // namespace
 
-Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
-                                             MLIRContext *context) {
+spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
+                                               MLIRContext *context) {
   Deserializer deserializer(binary, context);
 
   if (failed(deserializer.deserialize()))
-    return llvm::None;
+    return nullptr;
 
-  return deserializer.collect();
+  return deserializer.collect().getValueOr(nullptr);
 }
index 4c3fb1e..42b458d 100644 (file)
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Serialization.h"
 #include "mlir/IR/Builders.h"
@@ -49,13 +50,13 @@ static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
   auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
                                    size / sizeof(uint32_t));
 
-  auto spirvModule = spirv::deserialize(binary, context);
+  spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
   if (!spirvModule)
     return {};
 
   OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
       input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
-  module->getBody()->push_front(spirvModule->getOperation());
+  module->getBody()->push_front(spirvModule.release());
 
   return module;
 }
@@ -136,14 +137,14 @@ static LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
     return failure();
 
   // Then deserialize to get back a SPIR-V module.
-  auto spirvModule = spirv::deserialize(binary, context);
+  spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
   if (!spirvModule)
     return failure();
 
   // Wrap around in a new MLIR module.
   OwningModuleRef dstModule(ModuleOp::create(FileLineColLoc::get(
       /*filename=*/"", /*line=*/0, /*column=*/0, context)));
-  dstModule->getBody()->push_front(spirvModule->getOperation());
+  dstModule->getBody()->push_front(spirvModule.release());
   dstModule->print(output);
 
   return mlir::success();
index 31fc0e4..a81b774 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Serialization.h"
 #include "mlir/IR/Diagnostics.h"
@@ -46,7 +47,7 @@ protected:
   }
 
   /// Performs deserialization and returns the constructed spv.module op.
-  Optional<spirv::ModuleOp> deserialize() {
+  spirv::OwningSPIRVModuleRef deserialize() {
     return spirv::deserialize(binary, &context);
   }
 
@@ -130,27 +131,27 @@ protected:
 //===----------------------------------------------------------------------===//
 
 TEST_F(DeserializationTest, EmptyModuleFailure) {
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("SPIR-V binary module must have a 5-word header");
 }
 
 TEST_F(DeserializationTest, WrongMagicNumberFailure) {
   addHeader();
   binary.front() = 0xdeadbeef; // Change to a wrong magic number
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("incorrect magic number");
 }
 
 TEST_F(DeserializationTest, OnlyHeaderSuccess) {
   addHeader();
-  EXPECT_NE(llvm::None, deserialize());
+  EXPECT_TRUE(deserialize());
 }
 
 TEST_F(DeserializationTest, ZeroWordCountFailure) {
   addHeader();
   binary.push_back(0); // OpNop with zero word count
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("word count cannot be zero");
 }
 
@@ -160,7 +161,7 @@ TEST_F(DeserializationTest, InsufficientWordFailure) {
                    static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
   // Missing word for type <id>
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("insufficient words for the last instruction");
 }
 
@@ -172,7 +173,7 @@ TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
   addHeader();
   addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
 }
 
@@ -198,7 +199,7 @@ TEST_F(DeserializationTest, OpMemberNameSuccess) {
   addInstruction(spirv::Opcode::OpMemberName, operands2);
 
   binary.append(typeDecl.begin(), typeDecl.end());
-  EXPECT_NE(llvm::None, deserialize());
+  EXPECT_TRUE(deserialize());
 }
 
 TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
@@ -215,7 +216,7 @@ TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
   addInstruction(spirv::Opcode::OpMemberName, operands1);
 
   binary.append(typeDecl.begin(), typeDecl.end());
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("OpMemberName must have at least 3 operands");
 }
 
@@ -234,7 +235,7 @@ TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
   addInstruction(spirv::Opcode::OpMemberName, operands);
 
   binary.append(typeDecl.begin(), typeDecl.end());
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("unexpected trailing words in OpMemberName instruction");
 }
 
@@ -249,7 +250,7 @@ TEST_F(DeserializationTest, FunctionMissingEndFailure) {
   addFunction(voidType, fnType);
   // Missing OpFunctionEnd
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("expected OpFunctionEnd instruction");
 }
 
@@ -261,7 +262,7 @@ TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
   addFunction(voidType, fnType);
   // Missing OpFunctionParameter
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("expected OpFunctionParameter instruction");
 }
 
@@ -274,7 +275,7 @@ TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
   addReturn();
   addFunctionEnd();
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("a basic block must start with OpLabel");
 }
 
@@ -287,6 +288,6 @@ TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
   addReturn();
   addFunctionEnd();
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("OpLabel should only have result <id>");
 }