[mlir][spirv] Add parsing and printing support for SpecConstantOperation
authorergawy <kareem.ergawy@gmail.com>
Wed, 16 Dec 2020 13:20:24 +0000 (08:20 -0500)
committerLei Zhang <antiagainst@google.com>
Wed, 16 Dec 2020 13:26:48 +0000 (08:26 -0500)
Adds more support for `SpecConstantOperation` by defining a custom
syntax for the op and implementing its parsing and printing.

Reviewed By: mravishankar, antiagainst

Differential Revision: https://reviews.llvm.org/D92919

mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/structure-ops.mlir

index b8e76c3..1ae7d28 100644 (file)
@@ -608,9 +608,12 @@ def SPV_SpecConstantCompositeOp : SPV_Op<"specConstantComposite", [InModuleScope
   let autogenSerialization = 0;
 }
 
-def SPV_YieldOp : SPV_Op<"mlir.yield", [NoSideEffect, Terminator]> {
-  let summary = "Yields the result computed in `spv.SpecConstantOperation`'s"
-                "region back to the parent op.";
+def SPV_YieldOp : SPV_Op<"mlir.yield", [
+    HasParent<"SpecConstantOperationOp">, NoSideEffect, Terminator]> {
+  let summary = [{
+    Yields the result computed in `spv.SpecConstantOperation`'s
+    region back to the parent op.
+  }];
 
   let description = [{
     This op is a special terminator whose only purpose is to terminate
@@ -639,12 +642,16 @@ def SPV_YieldOp : SPV_Op<"mlir.yield", [NoSideEffect, Terminator]> {
   let autogenSerialization = 0;
 
   let assemblyFormat = "attr-dict $operand `:` type($operand)";
+
+  let verifier = [{ return success(); }];
 }
 
 def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [
-                                         InFunctionScope, NoSideEffect,
-                                         IsolatedFromAbove]> {
-  let summary = "Declare a new specialization constant that results from doing an operation.";
+       NoSideEffect, InFunctionScope,
+       SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = [{
+    Declare a new specialization constant that results from doing an operation.
+  }];
 
   let description = [{
     This op declares a SPIR-V specialization constant that results from
@@ -653,12 +660,8 @@ def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [
     In the `spv` dialect, this op is modelled as follows:
 
     ```
-    spv-spec-constant-operation-op ::= `"spv.SpecConstantOperation"`
-                                         `(`ssa-id (`, ` ssa-id)`)`
-                                       `({`
-                                         ssa-id = spirv-op
-                                         `spv.mlir.yield` ssa-id
-                                       `})` `:` function-type
+    spv-spec-constant-operation-op ::= `spv.SpecConstantOperation` `wraps`
+                                         generic-spirv-op `:` function-type
     ```
 
     In particular, an `spv.SpecConstantOperation` contains exactly one
@@ -712,17 +715,15 @@ def SPV_SpecConstantOperationOp : SPV_Op<"SpecConstantOperation", [
     #### Example:
     ```mlir
     %0 = spv.constant 1: i32
+    %1 = spv.constant 1: i32
 
-    %1 = "spv.SpecConstantOperation"(%0) ({
-      %ret = spv.IAdd %0, %0 : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32) -> i32
+    %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32
     ```
   }];
 
-  let arguments = (ins Variadic<AnyType>:$operands);
+  let arguments = (ins);
 
-  let results = (outs AnyType:$results);
+  let results = (outs AnyType:$result);
 
   let regions = (region SizedRegion<1>:$body);
 
index 03e416e..43b3c51 100644 (file)
@@ -3396,35 +3396,39 @@ static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
 }
 
 //===----------------------------------------------------------------------===//
-// spv.mlir.yield
+// spv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(spirv::YieldOp yieldOp) {
-  Operation *parentOp = yieldOp->getParentOp();
+static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
+                                                OperationState &state) {
+  Region *body = state.addRegion();
 
-  if (!parentOp || !isa<spirv::SpecConstantOperationOp>(parentOp))
-    return yieldOp.emitOpError(
-        "expected parent op to be 'spv.SpecConstantOperation'");
+  if (parser.parseKeyword("wraps"))
+    return failure();
 
-  Block &block = parentOp->getRegion(0).getBlocks().front();
-  Operation &enclosedOp = block.getOperations().front();
+  body->push_back(new Block);
+  Block &block = body->back();
+  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
 
-  if (yieldOp.getOperand().getDefiningOp() != &enclosedOp)
-    return yieldOp.emitOpError(
-        "expected operand to be defined by preceeding op");
+  if (!wrappedOp)
+    return failure();
 
-  return success();
-}
+  OpBuilder builder(parser.getBuilder().getContext());
+  builder.setInsertionPointToEnd(&block);
+  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
+  state.location = wrappedOp->getLoc();
 
-static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
-                                                OperationState &state) {
-  // TODO: For now, only generic form is supported.
-  return failure();
+  state.addTypes(wrappedOp->getResult(0).getType());
+
+  if (parser.parseOptionalAttrDict(state.attributes))
+    return failure();
+
+  return success();
 }
 
 static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) {
-  // TODO
-  printer.printGenericOp(op);
+  printer << op.getOperationName() << " wraps ";
+  printer.printGenericOp(&op.body().front().front());
 }
 
 static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
@@ -3433,11 +3437,6 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
   if (block.getOperations().size() != 2)
     return constOp.emitOpError("expected exactly 2 nested ops");
 
-  Operation &yieldOp = block.getOperations().back();
-
-  if (!isa<spirv::YieldOp>(yieldOp))
-    return constOp.emitOpError("expected terminator to be a yield op");
-
   Operation &enclosedOp = block.getOperations().front();
 
   // TODO Add a `UsableInSpecConstantOp` trait and mark ops from the list below
@@ -3457,21 +3456,12 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
            spirv::UGreaterThanEqualOp, spirv::SGreaterThanEqualOp>(enclosedOp))
     return constOp.emitOpError("invalid enclosed op");
 
-  if (enclosedOp.getNumOperands() != constOp.getOperands().size())
-    return constOp.emitOpError("invalid number of operands; expected ")
-           << enclosedOp.getNumOperands() << ", actual "
-           << constOp.getOperands().size();
-
-  if (enclosedOp.getNumOperands() != constOp.getRegion().getNumArguments())
-    return constOp.emitOpError("invalid number of region arguments; expected ")
-           << enclosedOp.getNumOperands() << ", actual "
-           << constOp.getRegion().getNumArguments();
-
-  for (auto operand : constOp.getOperands())
+  for (auto operand : enclosedOp.getOperands())
     if (!isa<spirv::ConstantOp, spirv::SpecConstantOp,
              spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>(
             operand.getDefiningOp()))
-      return constOp.emitOpError("invalid operand");
+      return constOp.emitOpError(
+          "invalid operand, must be defined by a constant operation");
 
   return success();
 }
index 89a30e2..c0b4951 100644 (file)
@@ -757,6 +757,7 @@ spv.module Logical GLSL450 {
   // expected-error @+1 {{unsupported composite type}}
   spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device>
 }
+
 //===----------------------------------------------------------------------===//
 // spv.SpecConstantOperation
 //===----------------------------------------------------------------------===//
@@ -765,34 +766,15 @@ spv.module Logical GLSL450 {
 
 spv.module Logical GLSL450 {
   spv.func @foo() -> i32 "None" {
+    // CHECK: [[LHS:%.*]] = spv.constant
     %0 = spv.constant 1: i32
-    %2 = spv.constant 1: i32
-
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-    ^bb(%lhs : i32, %rhs : i32):
-      %ret = spv.IAdd %lhs, %rhs : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-    %2 = spv.constant 1: i32
+    // CHECK: [[RHS:%.*]] = spv.constant
+    %1 = spv.constant 1: i32
 
-    // expected-error @+1 {{invalid number of operands; expected 2, actual 1}}
-    %1 = "spv.SpecConstantOperation"(%0) ({
-    ^bb(%lhs : i32, %rhs : i32):
-      %ret = spv.IAdd %lhs, %rhs : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32) -> i32
+    // CHECK: spv.SpecConstantOperation wraps "spv.IAdd"([[LHS]], [[RHS]]) : (i32, i32) -> i32
+    %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%0, %1) : (i32, i32) -> i32
 
-    spv.ReturnValue %1 : i32
+    spv.ReturnValue %2 : i32
   }
 }
 
@@ -801,25 +783,7 @@ spv.module Logical GLSL450 {
 spv.module Logical GLSL450 {
   spv.func @foo() -> i32 "None" {
     %0 = spv.constant 1: i32
-    %2 = spv.constant 1: i32
-
-    // expected-error @+1 {{invalid number of region arguments; expected 2, actual 1}}
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-    ^bb(%lhs : i32):
-      %ret = spv.IAdd %lhs, %lhs : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-    // expected-error @+1 {{expected parent op to be 'spv.SpecConstantOperation'}}
+    // expected-error @+1 {{op expects parent op 'spv.SpecConstantOperation'}}
     spv.mlir.yield %0 : i32
   }
 }
@@ -827,67 +791,12 @@ spv.module Logical GLSL450 {
 // -----
 
 spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-    ^bb(%lhs : i32, %rhs : i32):
-      %ret = spv.ISub %lhs, %rhs : i32
-      // expected-error @+1 {{expected operand to be defined by preceeding op}}
-      spv.mlir.yield %lhs : i32
-    }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-
-    // expected-error @+1 {{expected exactly 2 nested ops}}
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-    ^bb(%lhs : i32, %rhs : i32):
-      %ret = spv.IAdd %lhs, %rhs : i32
-      %ret2 = spv.IAdd %lhs, %rhs : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
-  spv.func @foo() -> i32 "None" {
-    %0 = spv.constant 1: i32
-
-    // expected-error @+1 {{expected terminator to be a yield op}}
-    %1 = "spv.SpecConstantOperation"(%0, %0) ({
-    ^bb(%lhs : i32, %rhs : i32):
-      %ret = spv.IAdd %lhs, %rhs : i32
-      spv.ReturnValue %ret : i32
-    }) : (i32, i32) -> i32
-
-    spv.ReturnValue %1 : i32
-  }
-}
-
-// -----
-
-spv.module Logical GLSL450 {
   spv.func @foo() -> () "None" {
     %0 = spv.Variable : !spv.ptr<i32, Function>
 
     // expected-error @+1 {{invalid enclosed op}}
-    %2 = "spv.SpecConstantOperation"(%0) ({
-    ^bb(%arg0 : !spv.ptr<i32, Function>):
-      %ret = spv.Load "Function" %arg0 : i32
-      spv.mlir.yield %ret : i32
-    }) : (!spv.ptr<i32, Function>) -> i32
+    %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<i32, Function>) -> i32
+    spv.Return
   }
 }
 
@@ -898,11 +807,9 @@ spv.module Logical GLSL450 {
     %0 = spv.Variable : !spv.ptr<i32, Function>
     %1 = spv.Load "Function" %0 : i32
 
-    // expected-error @+1 {{invalid operand}}
-    %2 = "spv.SpecConstantOperation"(%1, %1) ({
-    ^bb(%lhs: i32, %rhs: i32):
-      %ret = spv.IAdd %lhs, %lhs : i32
-      spv.mlir.yield %ret : i32
-    }) : (i32, i32) -> i32
+    // expected-error @+1 {{invalid operand, must be defined by a constant operation}}
+    %2 = spv.SpecConstantOperation wraps "spv.IAdd"(%1, %1) : (i32, i32) -> i32
+
+    spv.Return
   }
 }