[spirv] Basic validity of SPV_ModuleOp
authorLei Zhang <antiagainst@google.com>
Tue, 4 Jun 2019 21:03:30 +0000 (14:03 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:17:34 +0000 (16:17 -0700)
This CL adds SPV_ModuleEndOp for terminating the only block inside a
SPV_ModuleOp's only region. Verification now enforces a spv.module only
contains func or spv.* ops and no external or nested functions are
present. Because of the structural requirement of a block, spv.Return
is also added in this CL.

PiperOrigin-RevId: 251510706

mlir/include/mlir/SPIRV/SPIRVOps.td
mlir/include/mlir/SPIRV/SPIRVStructureOps.td
mlir/lib/SPIRV/SPIRVOps.cpp
mlir/test/SPIRV/ops.mlir
mlir/test/SPIRV/structure-ops.mlir

index cb82268..d68a660 100644 (file)
@@ -53,4 +53,21 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
   let printer = [{ return impl::printBinaryOp(getOperation(), p); }];
 }
 
+def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
+  let summary = "Return with no value from a function with void return type";
+
+  let description = [{
+    This instruction must be the last instruction in a block.
+  }];
+
+  let arguments = (ins);
+
+  let results = (outs);
+
+  let parser = [{ return parseNoIOOp(parser, result); }];
+  let printer = [{ printNoIOOp(getOperation(), p); }];
+
+  let verifier = [{ return verifyReturn(*this); }];
+}
+
 #endif // SPIRV_OPS
index bcb485c..6e084b9 100644 (file)
@@ -60,6 +60,9 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
 
     This op takes no operands and generates no results. This op should not
     implicitly capture values from the enclosing environment.
+
+    This op has only one region, which only contains one block. The block
+    must be terminated via the `spv._module_end` op.
   }];
 
   let arguments = (ins
@@ -72,11 +75,33 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
 
   let results = (outs);
 
-  let regions = (region AnyRegion:$body);
+  let regions = (region SizedRegion<1>:$body);
 
   // Custom parser and printer implemented by static functions in SPVOps.cpp
   let parser = [{ return parseModule(parser, result); }];
   let printer = [{ printModule(getOperation(), p); }];
+
+  let verifier = [{ return verifyModule(*this); }];
+}
+
+def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator]> {
+  let summary = "The pseudo op that ends a SPIR-V module";
+
+  let description = [{
+    This op terminates the only block inside a `spv.module`'s only region.
+    This op does not have a corresponding SPIR-V instruction and thus will
+    not be serialized into the binary format; it is used solely to satisfy
+    the structual requirement that an block must be ended with a terminator.
+  }];
+
+  let arguments = (ins);
+
+  let results = (outs);
+
+  let parser = [{ return parseNoIOOp(parser, result); }];
+  let printer = [{ printNoIOOp(getOperation(), p); }];
+
+  let verifier = [{ return verifyModuleOnly(*this); }];
 }
 
 #endif // SPIRV_STRUCTURE_OPS
index 6b3449a..4f354b5 100644 (file)
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
+// Common utility functions
+//===----------------------------------------------------------------------===//
+
+// Parses an op that has no inputs and no outputs.
+static ParseResult parseNoIOOp(OpAsmParser *parser, OperationState *state) {
+  if (parser->parseOptionalAttributeDict(state->attributes))
+    return failure();
+  return success();
+}
+
+// Prints an op that has no inputs and no outputs.
+static ParseResult printNoIOOp(Operation *op, OpAsmPrinter *printer) {
+  *printer << op->getName();
+  printer->printOptionalAttrDict(op->getAttrs(),
+                                 /*elidedAttrs=*/{});
+  return success();
+}
+
+// Verifies that the given op can only be placed in a `spv.module`.
+static LogicalResult verifyModuleOnly(Operation *op) {
+  if (!llvm::isa_and_nonnull<spirv::ModuleOp>(op->getParentOp()))
+    return op->emitOpError("can only be used in a 'spv.module' block");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.module
 //===----------------------------------------------------------------------===//
 
+static void ensureModuleEnd(Region *region, Builder builder, Location loc) {
+  if (region->empty())
+    region->push_back(new Block);
+
+  Block &block = region->back();
+  if (!block.empty() && llvm::isa<spirv::ModuleEndOp>(block.back()))
+    return;
+
+  OperationState state(builder.getContext(), loc,
+                       spirv::ModuleEndOp::getOperationName());
+  spirv::ModuleEndOp::build(&builder, &state);
+  block.push_back(Operation::create(state));
+}
+
 static ParseResult parseModule(OpAsmParser *parser, OperationState *state) {
   Region *body = state->addRegion();
 
@@ -39,6 +80,8 @@ static ParseResult parseModule(OpAsmParser *parser, OperationState *state) {
       parser->parseOptionalAttributeDict(state->attributes))
     return failure();
 
+  ensureModuleEnd(body, parser->getBuilder(), state->location);
+
   return success();
 }
 
@@ -52,6 +95,57 @@ static ParseResult printModule(Operation *op, OpAsmPrinter *printer) {
   return success();
 }
 
+static LogicalResult verifyModule(spirv::ModuleOp moduleOp) {
+  auto &op = *moduleOp.getOperation();
+  auto *dialect = op.getDialect();
+  auto &body = op.getRegion(0).front();
+
+  for (auto &op : body) {
+    if (op.getDialect() == dialect)
+      continue;
+
+    auto funcOp = llvm::dyn_cast<FuncOp>(op);
+    if (!funcOp)
+      return op.emitError("'spv.module' can only contain func and spv.* ops");
+
+    if (funcOp.isExternal())
+      return op.emitError("'spv.module' cannot contain external functions");
+
+    for (auto &block : funcOp)
+      for (auto &op : block) {
+        // TODO(antiagainst): verify that return ops have the same type as the
+        // enclosing function
+        if (op.getDialect() == dialect)
+          continue;
+
+        if (llvm::isa<FuncOp>(op))
+          return op.emitError("'spv.module' cannot contain nested functions");
+
+        return op.emitError(
+            "functions in 'spv.module' can only contain spv.* ops");
+      }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.Return
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
+  auto funcOp =
+      llvm::dyn_cast_or_null<FuncOp>(returnOp.getOperation()->getParentOp());
+  if (!funcOp)
+    return returnOp.emitOpError("must appear in a 'func' op");
+
+  auto numOutputs = funcOp.getType().getNumResults();
+  if (numOutputs != 0)
+    return returnOp.emitOpError("cannot be used in functions returning value")
+           << (numOutputs > 1 ? "s" : "");
+
+  return success();
+}
+
 namespace mlir {
 namespace spirv {
 
index 7b471e3..225a76e 100644 (file)
@@ -32,3 +32,28 @@ func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
   return %0 : tensor<4xf32>
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.Return
+//===----------------------------------------------------------------------===//
+
+func @return_not_in_func() -> () {
+  // expected-error @+1 {{must appear in a 'func' op}}
+  spv.Return
+}
+
+// -----
+
+func @return_mismatch_func_signature() -> () {
+  spv.module {
+    func @work() -> (i32) {
+      // expected-error @+1 {{cannot be used in functions returning value}}
+      spv.Return
+    }
+  } attributes {
+    addressing_model: "Logical",
+    memory_model: "VulkanKHR"
+  }
+  return
+}
index 157ccd7..628656c 100644 (file)
@@ -7,6 +7,7 @@
 // spv.module
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: func @module_without_cap_ext
 func @module_without_cap_ext() -> () {
   // CHECK: spv.module
   spv.module { } attributes {
@@ -16,6 +17,7 @@ func @module_without_cap_ext() -> () {
   return
 }
 
+// CHECK-LABEL: func @module_with_cap_ext
 func @module_with_cap_ext() -> () {
   // CHECK: spv.module
   spv.module { } attributes {
@@ -27,6 +29,32 @@ func @module_with_cap_ext() -> () {
   return
 }
 
+// CHECK-LABEL: func @module_with_explict_module_end
+func @module_with_explict_module_end() -> () {
+  // CHECK: spv.module
+  spv.module {
+    spv._module_end
+  } attributes {
+    addressing_model: "Logical",
+    memory_model: "VulkanKHR"
+  }
+  return
+}
+
+// CHECK-LABEL: func @module_with_func
+func @module_with_func() -> () {
+  // CHECK: spv.module
+  spv.module {
+    func @do_nothing() -> () {
+      spv.Return
+    }
+  } attributes {
+    addressing_model: "Logical",
+    memory_model: "VulkanKHR"
+  }
+  return
+}
+
 // -----
 
 func @missing_addressing_model() -> () {
@@ -50,3 +78,89 @@ func @missing_memory_model() -> () {
   spv.module { } attributes {addressing_model: "Logical"}
   return
 }
+
+// -----
+
+func @module_with_multiple_blocks() -> () {
+  // expected-error @+1 {{failed to verify constraint: region with 1 blocks}}
+  spv.module {
+  ^first:
+    spv.Return
+  ^second:
+    spv.Return
+  } attributes {
+    addressing_model: "Logical",
+    memory_model: "VulkanKHR"
+  }
+  return
+}
+
+// -----
+
+func @use_non_spv_op_inside_module() -> () {
+  spv.module {
+    // expected-error @+1 {{'spv.module' can only contain func and spv.* ops}}
+    "dialect.op"() : () -> ()
+  } attributes {
+    addressing_model: "Logical",
+    memory_model: "VulkanKHR"
+  }
+  return
+}
+
+// -----
+
+func @use_non_spv_op_inside_func() -> () {
+  spv.module {
+    func @do_nothing() -> () {
+      // expected-error @+1 {{functions in 'spv.module' can only contain spv.* ops}}
+      "dialect.op"() : () -> ()
+    }
+  } attributes {
+    addressing_model: "Logical",
+    memory_model: "VulkanKHR"
+  }
+  return
+}
+
+// -----
+
+func @use_extern_func() -> () {
+  spv.module {
+    // expected-error @+1 {{'spv.module' cannot contain external functions}}
+    func @extern() -> ()
+  } attributes {
+    addressing_model: "Logical",
+    memory_model: "VulkanKHR"
+  }
+  return
+}
+
+// -----
+
+func @module_with_nested_func() -> () {
+  spv.module {
+    func @outer_func() -> () {
+      // expected-error @+1 {{'spv.module' cannot contain nested functions}}
+      func @inner_func() -> () {
+        spv.Return
+      }
+      spv.Return
+    }
+  } attributes {
+    addressing_model: "Logical",
+    memory_model: "VulkanKHR"
+  }
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv._module_end
+//===----------------------------------------------------------------------===//
+
+func @module_end_not_in_module() -> () {
+  // expected-error @+1 {{can only be used in a 'spv.module' block}}
+  spv._module_end
+}