Make SPIR-V spv.EntryPoint and spv.ExecutionMode consistent with SPIR-V spec
authorMahesh Ravishankar <ravishankarm@google.com>
Fri, 19 Jul 2019 14:30:15 +0000 (07:30 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 19 Jul 2019 18:40:58 +0000 (11:40 -0700)
This CL changes the Op definition of spirv::EntryPointOp and
spirv::ExecutionModeOp to be consistent with the SPIR-V spec.
1) The EntryPointOp doesn't return a value
2) The ExecutionModeOp takes as argument, the SymbolRefAttr to refer
to the function, instead of the result of the EntryPointOp.

Following this, the spirv::EntryPointType is no longer necessary, and
is removed.

PiperOrigin-RevId: 258964027

mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/test/SPIRV/ops.mlir

index b638bc8..7fd4d64 100644 (file)
@@ -130,8 +130,8 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
     For example:
 
     ```
-    %4 = spv.EntryPoint "GLCompute" "foo"
-    %3 = spv.EntryPoint "Kernel" "foo", %1, %2 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
+    spv.EntryPoint "GLCompute" @foo
+    spv.EntryPoint "Kernel" @foo, %1, %2 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
 
     ```
   }];
@@ -142,7 +142,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
     Variadic<SPV_AnyPtr>:$interface
   );
 
-  let results = (outs SPV_EntryPoint:$id);
+  let results = (outs);
   let autogenSerialization = 0;
 }
 
@@ -174,13 +174,13 @@ def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
     For example:
 
     ```
-    spv.ExecutionMode %7 "ContractionOff"
-    spv.ExecutionMode %8 "LocalSizeHint", 3, 4, 5
+    spv.ExecutionMode @foo "ContractionOff"
+    spv.ExecutionMode @bar "LocalSizeHint", 3, 4, 5
     ```
   }];
 
   let arguments = (ins
-    SPV_EntryPoint:$entry_point,
+    SymbolRefAttr:$fn,
     SPV_ExecutionModeAttr:$execution_mode,
     OptionalAttr<I32ArrayAttr>:$values
   );
@@ -188,6 +188,8 @@ def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
   let results = (outs);
 
   let verifier = [{ return success(); }];
+
+  let autogenSerialization = 0;
 }
 
 // -----
index 8144912..264fed3 100644 (file)
@@ -45,7 +45,6 @@ struct StructTypeStorage;
 namespace TypeKind {
 enum Kind {
   Array = Type::FIRST_SPIRV_TYPE,
-  EntryPoint,
   Image,
   Pointer,
   RuntimeArray,
@@ -84,19 +83,6 @@ public:
   Type getElementType() const;
 };
 
-// SPIR-V type for return of EntryPointOp. The EntryPointOp returns a value that
-// can be used in other ops (like ExecutionModeOp) to refer to the
-// EntryPointOp. The type of the return value contains no other information
-class EntryPointType
-    : public Type::TypeBase<EntryPointType, Type, DefaultTypeStorage> {
-public:
-  using Base::Base;
-
-  static bool kindof(unsigned kind) { return kind == TypeKind::EntryPoint; }
-
-  static Type get(MLIRContext *context);
-};
-
 // SPIR-V image type
 class ImageType
     : public Type::TypeBase<ImageType, Type, detail::ImageTypeStorage> {
index 70f6593..f9ddc47 100644 (file)
@@ -38,8 +38,7 @@ using namespace mlir::spirv;
 
 SPIRVDialect::SPIRVDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
-  addTypes<ArrayType, EntryPointType, ImageType, PointerType, RuntimeArrayType,
-           StructType>();
+  addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
 
   addOperations<
 #define GET_OP_LIST
@@ -533,18 +532,11 @@ static void print(StructType type, llvm::raw_ostream &os) {
   os << ">";
 }
 
-static void print(EntryPointType type, llvm::raw_ostream &os) {
-  os << "entrypoint";
-}
-
 void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
   switch (type.getKind()) {
   case TypeKind::Array:
     print(type.cast<ArrayType>(), os);
     return;
-  case TypeKind::EntryPoint:
-    print(type.cast<EntryPointType>(), os);
-    return;
   case TypeKind::Pointer:
     print(type.cast<PointerType>(), os);
     return;
index a9c682a..a4d801f 100644 (file)
@@ -385,8 +385,6 @@ static ParseResult parseEntryPointOp(OpAsmParser *parser,
   if (!fn.isa<SymbolRefAttr>()) {
     return parser->emitError(loc, "expected symbol reference attribute");
   }
-  state->addTypes(
-      spirv::EntryPointType::get(parser->getBuilder().getContext()));
   return success();
 }
 
@@ -436,12 +434,9 @@ static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
 
 static ParseResult parseExecutionModeOp(OpAsmParser *parser,
                                         OperationState *state) {
-  OpAsmParser::OperandType entryPointInfo;
   spirv::ExecutionMode execMode;
-  if (parser->parseOperand(entryPointInfo) ||
-      parser->resolveOperand(entryPointInfo,
-                             spirv::EntryPointType::get(state->getContext()),
-                             state->operands) ||
+  Attribute fn;
+  if (parser->parseAttribute(fn, kFnNameAttrName, state->attributes) ||
       parseEnumAttribute(execMode, parser, state)) {
     return failure();
   }
@@ -462,10 +457,9 @@ static ParseResult parseExecutionModeOp(OpAsmParser *parser,
 }
 
 static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
-  *printer << spirv::ExecutionModeOp::getOperationName() << " ";
-  printer->printOperand(execModeOp.entry_point());
-  *printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode())
-           << "\"";
+  *printer << spirv::ExecutionModeOp::getOperationName() << " @"
+           << execModeOp.fn() << " \""
+           << stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
   auto values = execModeOp.values();
   if (!values) {
     return;
index 4c54e6f..345d13d 100644 (file)
@@ -93,14 +93,6 @@ unsigned CompositeType::getNumElements() const {
 }
 
 //===----------------------------------------------------------------------===//
-// EntryPointType
-//===----------------------------------------------------------------------===//
-
-Type EntryPointType::get(MLIRContext *context) {
-  return Base::get(context, TypeKind::EntryPoint);
-}
-
-//===----------------------------------------------------------------------===//
 // ImageType
 //===----------------------------------------------------------------------===//
 
index 183f426..e3056c9 100644 (file)
@@ -152,8 +152,8 @@ spv.module "Logical" "VulkanKHR" {
    func @do_nothing() -> () {
      spv.Return
    }
-   // CHECK: {{%.*}} = spv.EntryPoint "GLCompute" @do_nothing
-   %2 = spv.EntryPoint "GLCompute" @do_nothing
+   // CHECK: spv.EntryPoint "GLCompute" @do_nothing
+   spv.EntryPoint "GLCompute" @do_nothing
 }
 
 spv.module "Logical" "VulkanKHR" {
@@ -164,8 +164,8 @@ spv.module "Logical" "VulkanKHR" {
      spv.Store "Output" %arg1, %1 : f32
      spv.Return
    }
-   // CHECK: {{%.*}} = spv.EntryPoint "GLCompute" @do_something, {{%.*}}, {{%.*}} : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
-   %4 = spv.EntryPoint "GLCompute" @do_something, %2, %3 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
+   // CHECK: spv.EntryPoint "GLCompute" @do_something, {{%.*}}, {{%.*}} : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
+   spv.EntryPoint "GLCompute" @do_something, %2, %3 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
 }
 
 // -----
@@ -175,7 +175,7 @@ spv.module "Logical" "VulkanKHR" {
      spv.Return
    }
    // expected-error @+1 {{custom op 'spv.EntryPoint' expected symbol reference attribute}}
-   %4 = spv.EntryPoint "GLCompute" "do_nothing"
+   spv.EntryPoint "GLCompute" "do_nothing"
 }
 
 // -----
@@ -185,7 +185,7 @@ spv.module "Logical" "VulkanKHR" {
      spv.Return
    }
    // expected-error @+1 {{function 'do_something' not found in 'spv.module'}}
-   %4 = spv.EntryPoint "GLCompute" @do_something
+   spv.EntryPoint "GLCompute" @do_something
 }
 
 /// TODO(ravishankarm) : Add a test that verifies an error is thrown
@@ -198,7 +198,7 @@ spv.module "Logical" "VulkanKHR" {
 spv.module "Logical" "VulkanKHR" {
    func @do_nothing() -> () {
      // expected-error @+1 {{'spv.EntryPoint' op failed to verify that op can only be used in a 'spv.module' block}}
-     %2 = spv.EntryPoint "GLCompute" @do_something
+     spv.EntryPoint "GLCompute" @do_something
    }
 }
 
@@ -208,9 +208,9 @@ spv.module "Logical" "VulkanKHR" {
    func @do_nothing() -> () {
      spv.Return
    }
-   %5 = spv.EntryPoint "GLCompute" @do_nothing
+   spv.EntryPoint "GLCompute" @do_nothing
    // expected-error @+1 {{duplicate of a previous EntryPointOp}}
-   %6 = spv.EntryPoint "GLCompute" @do_nothing
+   spv.EntryPoint "GLCompute" @do_nothing
 }
 
 // -----
@@ -219,9 +219,9 @@ spv.module "Logical" "VulkanKHR" {
    func @do_nothing() -> () {
      spv.Return
    }
-   %5 = spv.EntryPoint "GLCompute" @do_nothing
+   spv.EntryPoint "GLCompute" @do_nothing
    // expected-error @+1 {{custom op 'spv.EntryPoint' invalid execution_model attribute specification: "ContractionOff"}}
-   %6 = spv.EntryPoint "ContractionOff" @do_nothing
+   spv.EntryPoint "ContractionOff" @do_nothing
 }
 
 // -----
@@ -232,7 +232,7 @@ spv.module "Logical" "VulkanKHR" {
      spv.Return
    }
    // expected-error @+1 {{'spv.EntryPoint' op invalid storage class 'Workgroup'}}
-   %6 = spv.EntryPoint "GLCompute" @do_nothing, %2 : !spv.ptr<f32, Workgroup>
+   spv.EntryPoint "GLCompute" @do_nothing, %2 : !spv.ptr<f32, Workgroup>
 }
 
 // -----
@@ -245,31 +245,18 @@ spv.module "Logical" "VulkanKHR" {
    func @do_nothing() -> () {
      spv.Return
    }
-   %7 = spv.EntryPoint "GLCompute" @do_nothing
-   // CHECK: spv.ExecutionMode {{%.*}} "ContractionOff"
-   spv.ExecutionMode %7 "ContractionOff"
+   spv.EntryPoint "GLCompute" @do_nothing
+   // CHECK: spv.ExecutionMode {{@.*}} "ContractionOff"
+   spv.ExecutionMode @do_nothing "ContractionOff"
 }
 
 spv.module "Logical" "VulkanKHR" {
    func @do_nothing() -> () {
      spv.Return
    }
-   %8 = spv.EntryPoint "GLCompute" @do_nothing
-   // CHECK: spv.ExecutionMode {{%.*}} "LocalSizeHint", 3, 4, 5
-   spv.ExecutionMode %8 "LocalSizeHint", 3, 4, 5
-}
-
-// -----
-
-spv.module "Logical" "VulkanKHR" {
-   // expected-note @+1{{prior use here}}
-   %2 = spv.Variable : !spv.ptr<f32, Input>
-   func @do_nothing() -> () {
-     spv.Return
-   }
-   %8 = spv.EntryPoint "GLCompute" @do_nothing
-   // expected-error @+1 {{use of value '%2' expects different type than prior uses: '!spv.entrypoint' vs '!spv.ptr<f32, Input>'}}
-   spv.ExecutionMode %2 "LocalSizeHint", 3, 4, 5
+   spv.EntryPoint "GLCompute" @do_nothing
+   // CHECK: spv.ExecutionMode {{@.*}} "LocalSizeHint", 3, 4, 5
+   spv.ExecutionMode @do_nothing "LocalSizeHint", 3, 4, 5
 }
 
 // -----
@@ -278,9 +265,9 @@ spv.module "Logical" "VulkanKHR" {
    func @do_nothing() -> () {
      spv.Return
    }
-   %8 = spv.EntryPoint "GLCompute" @do_nothing
+   spv.EntryPoint "GLCompute" @do_nothing
    // expected-error @+1 {{custom op 'spv.ExecutionMode' invalid execution_mode attribute specification: "GLCompute"}}
-   spv.ExecutionMode %8 "GLCompute", 3, 4, 5
+   spv.ExecutionMode @do_nothing "GLCompute", 3, 4, 5
 }
 
 // -----