--- /dev/null
+//===- SPIRVModule.h - SPIR-V Module Utilities ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+#define MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/OwningOpRefBase.h"
+
+namespace mlir {
+namespace spirv {
+
+/// This class acts as an owning reference to a SPIR-V module, and will
+/// automatically destroy the held module on destruction if the held module
+/// is valid.
+class OwningSPIRVModuleRef : public OwningOpRefBase<spirv::ModuleOp> {
+public:
+ using OwningOpRefBase<spirv::ModuleOp>::OwningOpRefBase;
+};
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVMODULE_H
namespace spirv {
class ModuleOp;
+class OwningSPIRVModuleRef;
/// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
/// reports errors to the error handler registered with the MLIR context for
/// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp
/// in the given `context`. Returns the ModuleOp on success; otherwise, reports
-/// errors to the error handler registered with `context` and returns
-/// llvm::None.
-Optional<ModuleOp> deserialize(ArrayRef<uint32_t> binary, MLIRContext *context);
+/// errors to the error handler registered with `context` and returns a null
+/// module.
+OwningSPIRVModuleRef deserialize(ArrayRef<uint32_t> binary,
+ MLIRContext *context);
} // end namespace spirv
} // end namespace mlir
#ifndef MLIR_IR_MODULE_H
#define MLIR_IR_MODULE_H
+#include "mlir/IR/OwningOpRefBase.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
};
/// This class acts as an owning reference to a module, and will automatically
-/// destroy the held module if valid.
-class OwningModuleRef {
+/// destroy the held module on destruction if the held module is valid.
+class OwningModuleRef : public OwningOpRefBase<ModuleOp> {
public:
- OwningModuleRef(std::nullptr_t = nullptr) {}
- OwningModuleRef(ModuleOp module) : module(module) {}
- OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
- ~OwningModuleRef() {
- if (module)
- module.erase();
- }
-
- // Assign from another module reference.
- OwningModuleRef &operator=(OwningModuleRef &&other) {
- if (module)
- module.erase();
- module = other.release();
- return *this;
- }
-
- /// Allow accessing the internal module.
- ModuleOp get() const { return module; }
- ModuleOp operator*() const { return module; }
- ModuleOp *operator->() { return &module; }
- explicit operator bool() const { return module; }
-
- /// Release the referenced module.
- ModuleOp release() {
- ModuleOp released;
- std::swap(released, module);
- return released;
- }
-
-private:
- ModuleOp module;
+ using OwningOpRefBase<ModuleOp>::OwningOpRefBase;
};
} // end namespace mlir
--- /dev/null
+//===- OwningOpRefBase.h - MLIR OwningOpRefBase -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a base class for owning op refs.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OWNINGOPREFBASE_H
+#define MLIR_IR_OWNINGOPREFBASE_H
+
+#include <utility>
+
+namespace mlir {
+
+/// This class acts as an owning reference to an op, and will automatically
+/// destroy the held op on destruction if the held op is valid.
+///
+/// Note that OpBuilder and related functionality should be highly preferred
+/// instead, and this should only be used in situations where existing solutions
+/// are not viable.
+template <typename OpTy>
+class OwningOpRefBase {
+public:
+ OwningOpRefBase(std::nullptr_t = nullptr) {}
+ OwningOpRefBase(OpTy op) : op(op) {}
+ OwningOpRefBase(OwningOpRefBase &&other) : op(other.release()) {}
+ ~OwningOpRefBase() {
+ if (op)
+ op.erase();
+ }
+
+ // Assign from another op reference.
+ OwningOpRefBase &operator=(OwningOpRefBase &&other) {
+ if (op)
+ op.erase();
+ op = other.release();
+ return *this;
+ }
+
+ /// Allow accessing the internal op.
+ OpTy get() const { return op; }
+ OpTy operator*() const { return op; }
+ OpTy *operator->() { return &op; }
+ explicit operator bool() const { return op; }
+
+ /// Release the referenced op.
+ OpTy release() {
+ OpTy released;
+ std::swap(released, op);
+ return released;
+ }
+
+private:
+ OpTy op;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_OWNINGOPREFBASE_H
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
} // namespace
-Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
- MLIRContext *context) {
+spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
+ MLIRContext *context) {
Deserializer deserializer(binary, context);
if (failed(deserializer.deserialize()))
- return llvm::None;
+ return nullptr;
- return deserializer.collect();
+ return deserializer.collect().getValueOr(nullptr);
}
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/IR/Builders.h"
auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
size / sizeof(uint32_t));
- auto spirvModule = spirv::deserialize(binary, context);
+ spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
if (!spirvModule)
return {};
OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
- module->getBody()->push_front(spirvModule->getOperation());
+ module->getBody()->push_front(spirvModule.release());
return module;
}
return failure();
// Then deserialize to get back a SPIR-V module.
- auto spirvModule = spirv::deserialize(binary, context);
+ spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
if (!spirvModule)
return failure();
// Wrap around in a new MLIR module.
OwningModuleRef dstModule(ModuleOp::create(FileLineColLoc::get(
/*filename=*/"", /*line=*/0, /*column=*/0, context)));
- dstModule->getBody()->push_front(spirvModule->getOperation());
+ dstModule->getBody()->push_front(spirvModule.release());
dstModule->print(output);
return mlir::success();
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/IR/Diagnostics.h"
}
/// Performs deserialization and returns the constructed spv.module op.
- Optional<spirv::ModuleOp> deserialize() {
+ spirv::OwningSPIRVModuleRef deserialize() {
return spirv::deserialize(binary, &context);
}
//===----------------------------------------------------------------------===//
TEST_F(DeserializationTest, EmptyModuleFailure) {
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("SPIR-V binary module must have a 5-word header");
}
TEST_F(DeserializationTest, WrongMagicNumberFailure) {
addHeader();
binary.front() = 0xdeadbeef; // Change to a wrong magic number
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("incorrect magic number");
}
TEST_F(DeserializationTest, OnlyHeaderSuccess) {
addHeader();
- EXPECT_NE(llvm::None, deserialize());
+ EXPECT_TRUE(deserialize());
}
TEST_F(DeserializationTest, ZeroWordCountFailure) {
addHeader();
binary.push_back(0); // OpNop with zero word count
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("word count cannot be zero");
}
static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
// Missing word for type <id>
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("insufficient words for the last instruction");
}
addHeader();
addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}
addInstruction(spirv::Opcode::OpMemberName, operands2);
binary.append(typeDecl.begin(), typeDecl.end());
- EXPECT_NE(llvm::None, deserialize());
+ EXPECT_TRUE(deserialize());
}
TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
addInstruction(spirv::Opcode::OpMemberName, operands1);
binary.append(typeDecl.begin(), typeDecl.end());
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("OpMemberName must have at least 3 operands");
}
addInstruction(spirv::Opcode::OpMemberName, operands);
binary.append(typeDecl.begin(), typeDecl.end());
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("unexpected trailing words in OpMemberName instruction");
}
addFunction(voidType, fnType);
// Missing OpFunctionEnd
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("expected OpFunctionEnd instruction");
}
addFunction(voidType, fnType);
// Missing OpFunctionParameter
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("expected OpFunctionParameter instruction");
}
addReturn();
addFunctionEnd();
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("a basic block must start with OpLabel");
}
addReturn();
addFunctionEnd();
- ASSERT_EQ(llvm::None, deserialize());
+ ASSERT_FALSE(deserialize());
expectDiagnostic("OpLabel should only have result <id>");
}