From 23cf3b39e0a88975039fb173eb3befd46e13fe60 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 4 Jun 2019 14:03:30 -0700 Subject: [PATCH] [spirv] Basic validity of SPV_ModuleOp This CL adds SPV_ModuleEndOp for terminating the only block inside a SPV_ModuleOp's only region. Verification now enforces a spv.module only contains func or spv.* ops and no external or nested functions are present. Because of the structural requirement of a block, spv.Return is also added in this CL. PiperOrigin-RevId: 251510706 --- mlir/include/mlir/SPIRV/SPIRVOps.td | 17 ++++ mlir/include/mlir/SPIRV/SPIRVStructureOps.td | 27 ++++++- mlir/lib/SPIRV/SPIRVOps.cpp | 94 ++++++++++++++++++++++ mlir/test/SPIRV/ops.mlir | 25 ++++++ mlir/test/SPIRV/structure-ops.mlir | 114 +++++++++++++++++++++++++++ 5 files changed, 276 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/SPIRV/SPIRVOps.td b/mlir/include/mlir/SPIRV/SPIRVOps.td index cb82268..d68a660 100644 --- a/mlir/include/mlir/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVOps.td @@ -53,4 +53,21 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> { let printer = [{ return impl::printBinaryOp(getOperation(), p); }]; } +def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { + let summary = "Return with no value from a function with void return type"; + + let description = [{ + This instruction must be the last instruction in a block. + }]; + + let arguments = (ins); + + let results = (outs); + + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; + + let verifier = [{ return verifyReturn(*this); }]; +} + #endif // SPIRV_OPS diff --git a/mlir/include/mlir/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/SPIRV/SPIRVStructureOps.td index bcb485c..6e084b9 100644 --- a/mlir/include/mlir/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVStructureOps.td @@ -60,6 +60,9 @@ def SPV_ModuleOp : SPV_Op<"module", []> { This op takes no operands and generates no results. This op should not implicitly capture values from the enclosing environment. + + This op has only one region, which only contains one block. The block + must be terminated via the `spv._module_end` op. }]; let arguments = (ins @@ -72,11 +75,33 @@ def SPV_ModuleOp : SPV_Op<"module", []> { let results = (outs); - let regions = (region AnyRegion:$body); + let regions = (region SizedRegion<1>:$body); // Custom parser and printer implemented by static functions in SPVOps.cpp let parser = [{ return parseModule(parser, result); }]; let printer = [{ printModule(getOperation(), p); }]; + + let verifier = [{ return verifyModule(*this); }]; +} + +def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator]> { + let summary = "The pseudo op that ends a SPIR-V module"; + + let description = [{ + This op terminates the only block inside a `spv.module`'s only region. + This op does not have a corresponding SPIR-V instruction and thus will + not be serialized into the binary format; it is used solely to satisfy + the structual requirement that an block must be ended with a terminator. + }]; + + let arguments = (ins); + + let results = (outs); + + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; + + let verifier = [{ return verifyModuleOnly(*this); }]; } #endif // SPIRV_STRUCTURE_OPS diff --git a/mlir/lib/SPIRV/SPIRVOps.cpp b/mlir/lib/SPIRV/SPIRVOps.cpp index 6b3449a..4f354b5 100644 --- a/mlir/lib/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/SPIRV/SPIRVOps.cpp @@ -28,9 +28,50 @@ using namespace mlir; //===----------------------------------------------------------------------===// +// Common utility functions +//===----------------------------------------------------------------------===// + +// Parses an op that has no inputs and no outputs. +static ParseResult parseNoIOOp(OpAsmParser *parser, OperationState *state) { + if (parser->parseOptionalAttributeDict(state->attributes)) + return failure(); + return success(); +} + +// Prints an op that has no inputs and no outputs. +static ParseResult printNoIOOp(Operation *op, OpAsmPrinter *printer) { + *printer << op->getName(); + printer->printOptionalAttrDict(op->getAttrs(), + /*elidedAttrs=*/{}); + return success(); +} + +// Verifies that the given op can only be placed in a `spv.module`. +static LogicalResult verifyModuleOnly(Operation *op) { + if (!llvm::isa_and_nonnull(op->getParentOp())) + return op->emitOpError("can only be used in a 'spv.module' block"); + + return success(); +} + +//===----------------------------------------------------------------------===// // spv.module //===----------------------------------------------------------------------===// +static void ensureModuleEnd(Region *region, Builder builder, Location loc) { + if (region->empty()) + region->push_back(new Block); + + Block &block = region->back(); + if (!block.empty() && llvm::isa(block.back())) + return; + + OperationState state(builder.getContext(), loc, + spirv::ModuleEndOp::getOperationName()); + spirv::ModuleEndOp::build(&builder, &state); + block.push_back(Operation::create(state)); +} + static ParseResult parseModule(OpAsmParser *parser, OperationState *state) { Region *body = state->addRegion(); @@ -39,6 +80,8 @@ static ParseResult parseModule(OpAsmParser *parser, OperationState *state) { parser->parseOptionalAttributeDict(state->attributes)) return failure(); + ensureModuleEnd(body, parser->getBuilder(), state->location); + return success(); } @@ -52,6 +95,57 @@ static ParseResult printModule(Operation *op, OpAsmPrinter *printer) { return success(); } +static LogicalResult verifyModule(spirv::ModuleOp moduleOp) { + auto &op = *moduleOp.getOperation(); + auto *dialect = op.getDialect(); + auto &body = op.getRegion(0).front(); + + for (auto &op : body) { + if (op.getDialect() == dialect) + continue; + + auto funcOp = llvm::dyn_cast(op); + if (!funcOp) + return op.emitError("'spv.module' can only contain func and spv.* ops"); + + if (funcOp.isExternal()) + return op.emitError("'spv.module' cannot contain external functions"); + + for (auto &block : funcOp) + for (auto &op : block) { + // TODO(antiagainst): verify that return ops have the same type as the + // enclosing function + if (op.getDialect() == dialect) + continue; + + if (llvm::isa(op)) + return op.emitError("'spv.module' cannot contain nested functions"); + + return op.emitError( + "functions in 'spv.module' can only contain spv.* ops"); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// spv.Return +//===----------------------------------------------------------------------===// + +static LogicalResult verifyReturn(spirv::ReturnOp returnOp) { + auto funcOp = + llvm::dyn_cast_or_null(returnOp.getOperation()->getParentOp()); + if (!funcOp) + return returnOp.emitOpError("must appear in a 'func' op"); + + auto numOutputs = funcOp.getType().getNumResults(); + if (numOutputs != 0) + return returnOp.emitOpError("cannot be used in functions returning value") + << (numOutputs > 1 ? "s" : ""); + + return success(); +} + namespace mlir { namespace spirv { diff --git a/mlir/test/SPIRV/ops.mlir b/mlir/test/SPIRV/ops.mlir index 7b471e3..225a76e 100644 --- a/mlir/test/SPIRV/ops.mlir +++ b/mlir/test/SPIRV/ops.mlir @@ -32,3 +32,28 @@ func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> { return %0 : tensor<4xf32> } +// ----- + +//===----------------------------------------------------------------------===// +// spv.Return +//===----------------------------------------------------------------------===// + +func @return_not_in_func() -> () { + // expected-error @+1 {{must appear in a 'func' op}} + spv.Return +} + +// ----- + +func @return_mismatch_func_signature() -> () { + spv.module { + func @work() -> (i32) { + // expected-error @+1 {{cannot be used in functions returning value}} + spv.Return + } + } attributes { + addressing_model: "Logical", + memory_model: "VulkanKHR" + } + return +} diff --git a/mlir/test/SPIRV/structure-ops.mlir b/mlir/test/SPIRV/structure-ops.mlir index 157ccd7..628656c 100644 --- a/mlir/test/SPIRV/structure-ops.mlir +++ b/mlir/test/SPIRV/structure-ops.mlir @@ -7,6 +7,7 @@ // spv.module //===----------------------------------------------------------------------===// +// CHECK-LABEL: func @module_without_cap_ext func @module_without_cap_ext() -> () { // CHECK: spv.module spv.module { } attributes { @@ -16,6 +17,7 @@ func @module_without_cap_ext() -> () { return } +// CHECK-LABEL: func @module_with_cap_ext func @module_with_cap_ext() -> () { // CHECK: spv.module spv.module { } attributes { @@ -27,6 +29,32 @@ func @module_with_cap_ext() -> () { return } +// CHECK-LABEL: func @module_with_explict_module_end +func @module_with_explict_module_end() -> () { + // CHECK: spv.module + spv.module { + spv._module_end + } attributes { + addressing_model: "Logical", + memory_model: "VulkanKHR" + } + return +} + +// CHECK-LABEL: func @module_with_func +func @module_with_func() -> () { + // CHECK: spv.module + spv.module { + func @do_nothing() -> () { + spv.Return + } + } attributes { + addressing_model: "Logical", + memory_model: "VulkanKHR" + } + return +} + // ----- func @missing_addressing_model() -> () { @@ -50,3 +78,89 @@ func @missing_memory_model() -> () { spv.module { } attributes {addressing_model: "Logical"} return } + +// ----- + +func @module_with_multiple_blocks() -> () { + // expected-error @+1 {{failed to verify constraint: region with 1 blocks}} + spv.module { + ^first: + spv.Return + ^second: + spv.Return + } attributes { + addressing_model: "Logical", + memory_model: "VulkanKHR" + } + return +} + +// ----- + +func @use_non_spv_op_inside_module() -> () { + spv.module { + // expected-error @+1 {{'spv.module' can only contain func and spv.* ops}} + "dialect.op"() : () -> () + } attributes { + addressing_model: "Logical", + memory_model: "VulkanKHR" + } + return +} + +// ----- + +func @use_non_spv_op_inside_func() -> () { + spv.module { + func @do_nothing() -> () { + // expected-error @+1 {{functions in 'spv.module' can only contain spv.* ops}} + "dialect.op"() : () -> () + } + } attributes { + addressing_model: "Logical", + memory_model: "VulkanKHR" + } + return +} + +// ----- + +func @use_extern_func() -> () { + spv.module { + // expected-error @+1 {{'spv.module' cannot contain external functions}} + func @extern() -> () + } attributes { + addressing_model: "Logical", + memory_model: "VulkanKHR" + } + return +} + +// ----- + +func @module_with_nested_func() -> () { + spv.module { + func @outer_func() -> () { + // expected-error @+1 {{'spv.module' cannot contain nested functions}} + func @inner_func() -> () { + spv.Return + } + spv.Return + } + } attributes { + addressing_model: "Logical", + memory_model: "VulkanKHR" + } + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv._module_end +//===----------------------------------------------------------------------===// + +func @module_end_not_in_module() -> () { + // expected-error @+1 {{can only be used in a 'spv.module' block}} + spv._module_end +} -- 2.7.4