Add support for Logical Ops in SPIR-V dialect
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 30 Sep 2019 17:40:07 +0000 (10:40 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 30 Sep 2019 17:40:36 +0000 (10:40 -0700)
Add operations corresponding to OpLogicalAnd, OpLogicalNot,
OpLogicalEqual, OpLogicalNotEqual and OpLogicalOr instructions in
SPIR-V dialect. This needs changes to class hierarchy in SPIR-V
TableGen files to split SPIRVLogicalOp into SPIRVLogicalUnaryOp and
SPIRVLogicalBinaryOp. All derived classes of SPIRVLogicalOp are
updated accordingly.

Update the spirv dialect generation script to
1) Allow specifying base class to use for instruction spec generation
and file name to generate the specification in separately.
2) Use the existing descriptions for operations.
3) Update define_inst.sh to also invoke define_opcode.sh to also
define the corresponding SPIR-V instruction opcode enum.

PiperOrigin-RevId: 272014876

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/ops.mlir
mlir/utils/spirv/define_inst.sh
mlir/utils/spirv/gen_spirv_dialect.py

index 1440f75..6b1d20d 100644 (file)
@@ -132,6 +132,11 @@ def SPV_OC_OpSRem                   : I32EnumAttrCase<"OpSRem", 138>;
 def SPV_OC_OpSMod                   : I32EnumAttrCase<"OpSMod", 139>;
 def SPV_OC_OpFRem                   : I32EnumAttrCase<"OpFRem", 140>;
 def SPV_OC_OpFMod                   : I32EnumAttrCase<"OpFMod", 141>;
+def SPV_OC_OpLogicalEqual           : I32EnumAttrCase<"OpLogicalEqual", 164>;
+def SPV_OC_OpLogicalNotEqual        : I32EnumAttrCase<"OpLogicalNotEqual", 165>;
+def SPV_OC_OpLogicalOr              : I32EnumAttrCase<"OpLogicalOr", 166>;
+def SPV_OC_OpLogicalAnd             : I32EnumAttrCase<"OpLogicalAnd", 167>;
+def SPV_OC_OpLogicalNot             : I32EnumAttrCase<"OpLogicalNot", 168>;
 def SPV_OC_OpSelect                 : I32EnumAttrCase<"OpSelect", 169>;
 def SPV_OC_OpIEqual                 : I32EnumAttrCase<"OpIEqual", 170>;
 def SPV_OC_OpINotEqual              : I32EnumAttrCase<"OpINotEqual", 171>;
@@ -184,12 +189,14 @@ def SPV_OpcodeAttr :
       SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
       SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
       SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
-      SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
-      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
-      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
-      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
-      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
-      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
+      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
+      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
+      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
+      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
+      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
+      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
+      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
       SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
       SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge,
index 1e9a547..b3b7df6 100644 (file)
 include "mlir/SPIRV/SPIRVBase.td"
 #endif // SPIRV_BASE
 
-class SPV_LogicalOp<string mnemonic, Type operandsType,
+class SPV_LogicalBinaryOp<string mnemonic, Type operandsType,
                     list<OpTrait> traits = []> :
       // Result type is SPV_Bool.
       SPV_BinaryOp<mnemonic, SPV_Bool, operandsType,
                    !listconcat(traits,
                                [NoSideEffect, SameTypeOperands,
                                 SameOperandsAndResultShape])> {
-  let parser = [{ return ::parseBinaryLogicalOp(parser, result); }];
-  let printer = [{ return ::printBinaryLogicalOp(getOperation(), p); }];
+  let parser = [{ return ::parseLogicalBinaryOp(parser, result); }];
+  let printer = [{ return ::printLogicalOp(getOperation(), p); }];
+}
+
+class SPV_LogicalUnaryOp<string mnemonic, Type operandType,
+                         list<OpTrait> traits = []> :
+      // Result type is SPV_Bool.
+      SPV_UnaryOp<mnemonic, SPV_Bool, operandType,
+                  !listconcat(traits, [NoSideEffect, SameTypeOperands,
+                                       SameOperandsAndResultShape])> {
+  let parser = [{ return ::parseLogicalUnaryOp(parser, result); }];
+  let printer = [{ return ::printLogicalOp(getOperation(), p); }];
 }
 
 // -----
 
-def SPV_FOrdEqualOp : SPV_LogicalOp<"FOrdEqual", SPV_Float, [Commutative]> {
+def SPV_FOrdEqualOp : SPV_LogicalBinaryOp<"FOrdEqual", SPV_Float, [Commutative]> {
   let summary = "Floating-point comparison for being ordered and equal.";
 
   let description = [{
@@ -73,7 +83,7 @@ def SPV_FOrdEqualOp : SPV_LogicalOp<"FOrdEqual", SPV_Float, [Commutative]> {
 
 // -----
 
-def SPV_FOrdGreaterThanOp : SPV_LogicalOp<"FOrdGreaterThan", SPV_Float, []> {
+def SPV_FOrdGreaterThanOp : SPV_LogicalBinaryOp<"FOrdGreaterThan", SPV_Float, []> {
   let summary = [{
     Floating-point comparison if operands are ordered and Operand 1 is
     greater than  Operand 2.
@@ -107,7 +117,7 @@ def SPV_FOrdGreaterThanOp : SPV_LogicalOp<"FOrdGreaterThan", SPV_Float, []> {
 
 // -----
 
-def SPV_FOrdGreaterThanEqualOp : SPV_LogicalOp<"FOrdGreaterThanEqual", SPV_Float, []> {
+def SPV_FOrdGreaterThanEqualOp : SPV_LogicalBinaryOp<"FOrdGreaterThanEqual", SPV_Float, []> {
   let summary = [{
     Floating-point comparison if operands are ordered and Operand 1 is
     greater than or equal to Operand 2.
@@ -141,7 +151,7 @@ def SPV_FOrdGreaterThanEqualOp : SPV_LogicalOp<"FOrdGreaterThanEqual", SPV_Float
 
 // -----
 
-def SPV_FOrdLessThanOp : SPV_LogicalOp<"FOrdLessThan", SPV_Float, []> {
+def SPV_FOrdLessThanOp : SPV_LogicalBinaryOp<"FOrdLessThan", SPV_Float, []> {
   let summary = [{
     Floating-point comparison if operands are ordered and Operand 1 is less
     than Operand 2.
@@ -175,7 +185,7 @@ def SPV_FOrdLessThanOp : SPV_LogicalOp<"FOrdLessThan", SPV_Float, []> {
 
 // -----
 
-def SPV_FOrdLessThanEqualOp : SPV_LogicalOp<"FOrdLessThanEqual", SPV_Float, []> {
+def SPV_FOrdLessThanEqualOp : SPV_LogicalBinaryOp<"FOrdLessThanEqual", SPV_Float, []> {
   let summary = [{
     Floating-point comparison if operands are ordered and Operand 1 is less
     than or equal to Operand 2.
@@ -209,7 +219,7 @@ def SPV_FOrdLessThanEqualOp : SPV_LogicalOp<"FOrdLessThanEqual", SPV_Float, []>
 
 // -----
 
-def SPV_FOrdNotEqualOp : SPV_LogicalOp<"FOrdNotEqual", SPV_Float, [Commutative]> {
+def SPV_FOrdNotEqualOp : SPV_LogicalBinaryOp<"FOrdNotEqual", SPV_Float, [Commutative]> {
   let summary = "Floating-point comparison for being ordered and not equal.";
 
   let description = [{
@@ -240,7 +250,7 @@ def SPV_FOrdNotEqualOp : SPV_LogicalOp<"FOrdNotEqual", SPV_Float, [Commutative]>
 
 // -----
 
-def SPV_FUnordEqualOp : SPV_LogicalOp<"FUnordEqual", SPV_Float, [Commutative]> {
+def SPV_FUnordEqualOp : SPV_LogicalBinaryOp<"FUnordEqual", SPV_Float, [Commutative]> {
   let summary = "Floating-point comparison for being unordered or equal.";
 
   let description = [{
@@ -271,7 +281,7 @@ def SPV_FUnordEqualOp : SPV_LogicalOp<"FUnordEqual", SPV_Float, [Commutative]> {
 
 // -----
 
-def SPV_FUnordGreaterThanOp : SPV_LogicalOp<"FUnordGreaterThan", SPV_Float, []> {
+def SPV_FUnordGreaterThanOp : SPV_LogicalBinaryOp<"FUnordGreaterThan", SPV_Float, []> {
   let summary = [{
     Floating-point comparison if operands are unordered or Operand 1 is
     greater than  Operand 2.
@@ -305,7 +315,7 @@ def SPV_FUnordGreaterThanOp : SPV_LogicalOp<"FUnordGreaterThan", SPV_Float, []>
 
 // -----
 
-def SPV_FUnordGreaterThanEqualOp : SPV_LogicalOp<"FUnordGreaterThanEqual", SPV_Float, []> {
+def SPV_FUnordGreaterThanEqualOp : SPV_LogicalBinaryOp<"FUnordGreaterThanEqual", SPV_Float, []> {
   let summary = [{
     Floating-point comparison if operands are unordered or Operand 1 is
     greater than or equal to Operand 2.
@@ -339,7 +349,7 @@ def SPV_FUnordGreaterThanEqualOp : SPV_LogicalOp<"FUnordGreaterThanEqual", SPV_F
 
 // -----
 
-def SPV_FUnordLessThanOp : SPV_LogicalOp<"FUnordLessThan", SPV_Float, []> {
+def SPV_FUnordLessThanOp : SPV_LogicalBinaryOp<"FUnordLessThan", SPV_Float, []> {
   let summary = [{
     Floating-point comparison if operands are unordered or Operand 1 is less
     than Operand 2.
@@ -373,7 +383,7 @@ def SPV_FUnordLessThanOp : SPV_LogicalOp<"FUnordLessThan", SPV_Float, []> {
 
 // -----
 
-def SPV_FUnordLessThanEqualOp : SPV_LogicalOp<"FUnordLessThanEqual", SPV_Float, []> {
+def SPV_FUnordLessThanEqualOp : SPV_LogicalBinaryOp<"FUnordLessThanEqual", SPV_Float, []> {
   let summary = [{
     Floating-point comparison if operands are unordered or Operand 1 is less
     than or equal to Operand 2.
@@ -407,7 +417,7 @@ def SPV_FUnordLessThanEqualOp : SPV_LogicalOp<"FUnordLessThanEqual", SPV_Float,
 
 // -----
 
-def SPV_FUnordNotEqualOp : SPV_LogicalOp<"FUnordNotEqual", SPV_Float, [Commutative]> {
+def SPV_FUnordNotEqualOp : SPV_LogicalBinaryOp<"FUnordNotEqual", SPV_Float, [Commutative]> {
   let summary = "Floating-point comparison for being unordered or not equal.";
 
   let description = [{
@@ -438,7 +448,7 @@ def SPV_FUnordNotEqualOp : SPV_LogicalOp<"FUnordNotEqual", SPV_Float, [Commutati
 
 // -----
 
-def SPV_IEqualOp : SPV_LogicalOp<"IEqual", SPV_Integer, [Commutative]> {
+def SPV_IEqualOp : SPV_LogicalBinaryOp<"IEqual", SPV_Integer, [Commutative]> {
   let summary = "Integer comparison for equality.";
 
   let description = [{
@@ -469,7 +479,7 @@ def SPV_IEqualOp : SPV_LogicalOp<"IEqual", SPV_Integer, [Commutative]> {
 
 // -----
 
-def SPV_INotEqualOp : SPV_LogicalOp<"INotEqual", SPV_Integer, [Commutative]> {
+def SPV_INotEqualOp : SPV_LogicalBinaryOp<"INotEqual", SPV_Integer, [Commutative]> {
   let summary = "Integer comparison for inequality.";
 
   let description = [{
@@ -500,7 +510,168 @@ def SPV_INotEqualOp : SPV_LogicalOp<"INotEqual", SPV_Integer, [Commutative]> {
 
 // -----
 
-def SPV_SGreaterThanOp : SPV_LogicalOp<"SGreaterThan", SPV_Integer, []> {
+def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd", SPV_Bool, [Commutative]> {
+  let summary = [{
+    Result is true if both Operand 1 and Operand 2 are true. Result is false
+    if either Operand 1 or Operand 2 are false.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 must be the same as Result Type.
+
+     The type of Operand 2 must be the same as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    logical-and ::= `spv.LogicalAnd` ssa-use `,` ssa-use
+                    `:` operand-type
+    ```
+
+    For example:
+
+    ```
+    %2 = spv.LogicalAnd %0, %1 : i1
+    %2 = spv.LogicalAnd %0, %1 : vector<4xi1>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_LogicalEqualOp : SPV_LogicalBinaryOp<"LogicalEqual", SPV_Bool, [Commutative]> {
+  let summary = [{
+    Result is true if Operand 1 and Operand 2 have the same value. Result is
+    false if Operand 1 and Operand 2 have different values.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 must be the same as Result Type.
+
+     The type of Operand 2 must be the same as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    logical-equal ::= `spv.LogicalEqual` ssa-use `,` ssa-use
+                      `:` operand-type
+    ```
+
+    For example:
+
+    ```
+    %2 = spv.LogicalEqual %0, %1 : i1
+    %2 = spv.LogicalEqual %0, %1 : vector<4xi1>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_LogicalNotOp : SPV_LogicalUnaryOp<"LogicalNot", SPV_Bool, []> {
+  let summary = [{
+    Result is true if Operand is false.  Result is false if Operand is true.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand must be the same as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    logical-not ::= `spv.LogicalNot` ssa-use `:` operand-type
+    ```
+
+    For example:
+
+    ```
+    %2 = spv.LogicalNot %0 : i1
+    %2 = spv.LogicalNot %0 : vector<4xi1>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_LogicalNotEqualOp : SPV_LogicalBinaryOp<"LogicalNotEqual", SPV_Bool, [Commutative]> {
+  let summary = [{
+    Result is true if Operand 1 and Operand 2 have different values. Result
+    is false if Operand 1 and Operand 2 have the same value.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 must be the same as Result Type.
+
+     The type of Operand 2 must be the same as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    logical-not-equal ::= `spv.LogicalNotEqual` ssa-use `,` ssa-use
+                          `:` operand-type
+    ```
+
+    For example:
+
+    ```
+    %2 = spv.LogicalNotEqual %0, %1 : i1
+    %2 = spv.LogicalNotEqual %0, %1 : vector<4xi1>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_LogicalOrOp : SPV_LogicalBinaryOp<"LogicalOr", SPV_Bool, [Commutative]> {
+  let summary = [{
+    Result is true if either Operand 1 or Operand 2 is true. Result is false
+    if both Operand 1 and Operand 2 are false.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+     The type of Operand 1 must be the same as Result Type.
+
+     The type of Operand 2 must be the same as Result Type.
+
+     Results are computed per component.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    logical-or ::= `spv.LogicalOr` ssa-use `,` ssa-use
+                    `:` operand-type
+    ```
+
+    For example:
+
+    ```
+    %2 = spv.LogicalOr %0, %1 : i1
+    %2 = spv.LogicalOr %0, %1 : vector<4xi1>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_SGreaterThanOp : SPV_LogicalBinaryOp<"SGreaterThan", SPV_Integer, []> {
   let summary = [{
     Signed-integer comparison if Operand 1 is greater than  Operand 2.
   }];
@@ -533,7 +704,7 @@ def SPV_SGreaterThanOp : SPV_LogicalOp<"SGreaterThan", SPV_Integer, []> {
 
 // -----
 
-def SPV_SGreaterThanEqualOp : SPV_LogicalOp<"SGreaterThanEqual", SPV_Integer, []> {
+def SPV_SGreaterThanEqualOp : SPV_LogicalBinaryOp<"SGreaterThanEqual", SPV_Integer, []> {
   let summary = [{
     Signed-integer comparison if Operand 1 is greater than or equal to
     Operand 2.
@@ -567,7 +738,7 @@ def SPV_SGreaterThanEqualOp : SPV_LogicalOp<"SGreaterThanEqual", SPV_Integer, []
 
 // -----
 
-def SPV_SLessThanOp : SPV_LogicalOp<"SLessThan", SPV_Integer, []> {
+def SPV_SLessThanOp : SPV_LogicalBinaryOp<"SLessThan", SPV_Integer, []> {
   let summary = [{
     Signed-integer comparison if Operand 1 is less than Operand 2.
   }];
@@ -600,7 +771,7 @@ def SPV_SLessThanOp : SPV_LogicalOp<"SLessThan", SPV_Integer, []> {
 
 // -----
 
-def SPV_SLessThanEqualOp : SPV_LogicalOp<"SLessThanEqual", SPV_Integer, []> {
+def SPV_SLessThanEqualOp : SPV_LogicalBinaryOp<"SLessThanEqual", SPV_Integer, []> {
   let summary = [{
     Signed-integer comparison if Operand 1 is less than or equal to Operand
     2.
@@ -634,7 +805,7 @@ def SPV_SLessThanEqualOp : SPV_LogicalOp<"SLessThanEqual", SPV_Integer, []> {
 
 // -----
 
-def SPV_SelectOp : SPV_Op<"Select", []> {
+def SPV_SelectOp : SPV_Op<"Select", [NoSideEffect]> {
   let summary = [{
     Select between two objects. Before version 1.4, results are only
     computed per component.
@@ -691,7 +862,7 @@ def SPV_SelectOp : SPV_Op<"Select", []> {
 
 // -----
 
-def SPV_UGreaterThanOp : SPV_LogicalOp<"UGreaterThan", SPV_Integer, []> {
+def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", SPV_Integer, []> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is greater than  Operand 2.
   }];
@@ -724,7 +895,7 @@ def SPV_UGreaterThanOp : SPV_LogicalOp<"UGreaterThan", SPV_Integer, []> {
 
 // -----
 
-def SPV_UGreaterThanEqualOp : SPV_LogicalOp<"UGreaterThanEqual", SPV_Integer, []> {
+def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", SPV_Integer, []> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is greater than or equal to
     Operand 2.
@@ -758,7 +929,7 @@ def SPV_UGreaterThanEqualOp : SPV_LogicalOp<"UGreaterThanEqual", SPV_Integer, []
 
 // -----
 
-def SPV_ULessThanOp : SPV_LogicalOp<"ULessThan", SPV_Integer, []> {
+def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", SPV_Integer, []> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is less than Operand 2.
   }];
@@ -791,7 +962,8 @@ def SPV_ULessThanOp : SPV_LogicalOp<"ULessThan", SPV_Integer, []> {
 
 // -----
 
-def SPV_ULessThanEqualOp : SPV_LogicalOp<"ULessThanEqual", SPV_Integer, []> {
+def SPV_ULessThanEqualOp :
+  SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, []> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is less than or equal to
     Operand 2.
index 408d365..f6ae3e4 100644 (file)
@@ -367,7 +367,28 @@ static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
           << unaryOp->getOperand(0)->getType();
 }
 
-static ParseResult parseBinaryLogicalOp(OpAsmParser &parser,
+/// Result of a logical op must be a scalar or vector of boolean type.
+static Type getUnaryOpResultType(Builder &builder, Type operandType) {
+  Type resultType = builder.getIntegerType(1);
+  if (auto vecType = operandType.dyn_cast<VectorType>()) {
+    return VectorType::get(vecType.getNumElements(), resultType);
+  }
+  return resultType;
+}
+
+static ParseResult parseLogicalUnaryOp(OpAsmParser &parser,
+                                       OperationState &state) {
+  OpAsmParser::OperandType operandInfo;
+  Type type;
+  if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
+      parser.resolveOperand(operandInfo, type, state.operands)) {
+    return failure();
+  }
+  state.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
+  return success();
+}
+
+static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
                                         OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 2> ops;
   Type type;
@@ -375,18 +396,13 @@ static ParseResult parseBinaryLogicalOp(OpAsmParser &parser,
       parser.resolveOperands(ops, type, result.operands)) {
     return failure();
   }
-  // Result must be a scalar or vector of boolean type.
-  Type resultType = parser.getBuilder().getIntegerType(1);
-  if (auto opsType = type.dyn_cast<VectorType>()) {
-    resultType = VectorType::get(opsType.getNumElements(), resultType);
-  }
-  result.addTypes(resultType);
+  result.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
   return success();
 }
 
-static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
-  printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", "
-          << *logicalOp->getOperand(1);
+static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
+  printer << logicalOp->getName() << ' ';
+  printer.printOperands(logicalOp->getOperands());
   printer << " : " << logicalOp->getOperand(0)->getType();
 }
 
index 8c4b0fa..d24015a 100644 (file)
@@ -607,6 +607,113 @@ spv.module "Logical" "GLSL450" {
 // -----
 
 //===----------------------------------------------------------------------===//
+// spv.LogicalAnd
+//===----------------------------------------------------------------------===//
+
+func @logicalBinary(%arg0 : i1, %arg1 : i1, %arg2 : i1)
+{
+  // CHECK: [[TMP:%.*]] = spv.LogicalAnd {{%.*}}, {{%.*}} : i1
+  %0 = spv.LogicalAnd %arg0, %arg1 : i1
+  // CHECK: {{%.*}} = spv.LogicalAnd [[TMP]], {{%.*}} : i1
+  %1 = spv.LogicalAnd %0, %arg2 : i1
+  return
+}
+
+func @logicalBinary2(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>)
+{
+  // CHECK: {{%.*}} = spv.LogicalAnd {{%.*}}, {{%.*}} : vector<4xi1>
+  %0 = spv.LogicalAnd %arg0, %arg1 : vector<4xi1>
+  return
+}
+
+// -----
+
+func @logicalBinary(%arg0 : i1, %arg1 : i1)
+{
+  // expected-error @+2 {{expected ':'}}
+  %0 = spv.LogicalAnd %arg0, %arg1
+  return
+}
+
+// -----
+
+func @logicalBinary(%arg0 : i1, %arg1 : i1)
+{
+  // expected-error @+2 {{expected non-function type}}
+  %0 = spv.LogicalAnd %arg0, %arg1 :
+  return
+}
+
+// -----
+
+func @logicalBinary(%arg0 : i1, %arg1 : i1)
+{
+  // expected-error @+1 {{custom op 'spv.LogicalAnd' expected 2 operands}}
+  %0 = spv.LogicalAnd %arg0 : i1
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalNot
+//===----------------------------------------------------------------------===//
+
+func @logicalUnary(%arg0 : i1, %arg1 : i1)
+{
+  // CHECK: [[TMP:%.*]] = spv.LogicalNot {{%.*}} : i1
+  %0 = spv.LogicalNot %arg0 : i1
+  // CHECK: {{%.*}} = spv.LogicalNot [[TMP]] : i1
+  %1 = spv.LogicalNot %0 : i1
+  return
+}
+
+func @logicalUnary2(%arg0 : vector<4xi1>)
+{
+  // CHECK: {{%.*}} = spv.LogicalNot {{%.*}} : vector<4xi1>
+  %0 = spv.LogicalNot %arg0 : vector<4xi1>
+  return
+}
+
+// -----
+
+func @logicalUnary(%arg0 : i1)
+{
+  // expected-error @+2 {{expected ':'}}
+  %0 = spv.LogicalNot %arg0
+  return
+}
+
+// -----
+
+func @logicalUnary(%arg0 : i1)
+{
+  // expected-error @+2 {{expected non-function type}}
+  %0 = spv.LogicalNot %arg0 :
+  return
+}
+
+// -----
+
+func @logicalUnary(%arg0 : i1)
+{
+  // expected-error @+1 {{expected SSA operand}}
+  %0 = spv.LogicalNot : i1
+  return
+}
+
+// -----
+
+func @logicalUnary(%arg0 : i32)
+{
+  // expected-error @+1 {{'spv.LogicalNot' op operand #0 must be 1-bit integer or vector of 1-bit integer values of length 2/3/4, but got 'i32'}}
+  %0 = spv.LogicalNot %arg0 : i32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // spv.MemoryBarrier
 //===----------------------------------------------------------------------===//
 
index 55e2fa0..328862b 100755 (executable)
@@ -1,5 +1,4 @@
 #!/bin/bash
-
 # Copyright 2019 The MLIR Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # ./define_inst.sh LogicalOp OpFOrdEqual
 set -e
 
-inst_category=$1
+file_name=$1
+inst_category=$2
 
 case $inst_category in
   Op | ArithmeticOp | LogicalOp | ControlFlowOp | StructureOp)
   ;;
   *)
-    echo "Usage : " $0 " <inst_category> (<opname>)*"
+    echo "Usage : " $0 "<filename> <inst_category> (<opname>)*"
     echo "<inst_category> must be one of " \
       "(Op|ArithmeticOp|LogicalOp|ControlFlowOp|StructureOp)"
     exit 1;
@@ -44,11 +44,15 @@ case $inst_category in
 esac
 
 shift
+shift
 
 current_file="$(readlink -f "$0")"
 current_dir="$(dirname "$current_file")"
 
 python3 ${current_dir}/gen_spirv_dialect.py \
   --op-td-path \
-  ${current_dir}/../../include/mlir/Dialect/SPIRV/SPIRV${inst_category}s.td \
+  ${current_dir}/../../include/mlir/Dialect/SPIRV/${file_name} \
   --inst-category $inst_category --new-inst "$@"
+
+${current_dir}/define_opcodes.sh "$@"
+
index e505b40..ff096df 100755 (executable)
@@ -364,7 +364,22 @@ def map_spec_operand_to_ods_argument(operand):
   return '{}:${}'.format(arg_type, name)
 
 
-def get_op_definition(instruction, doc, existing_info, inst_category):
+def get_description(text, assembly):
+  """Generates the description for the given SPIR-V instruction.
+
+  Arguments:
+    - text: Textual description of the operation as string.
+    - assembly: Custom Assembly format with example as string.
+
+  Returns:
+    - A string that corresponds to the description of the Tablegen op.
+  """
+  fmt_str = ('{text}\n\n    ### Custom assembly ' 'form\n{assembly}}}];\n')
+  return fmt_str.format(
+      text=text, assembly=assembly)
+
+
+def get_op_definition(instruction, doc, existing_info):
   """Generates the TableGen op definition for the given SPIR-V instruction.
 
   Arguments:
@@ -379,8 +394,8 @@ def get_op_definition(instruction, doc, existing_info, inst_category):
   fmt_str = ('def SPV_{opname}Op : '
              'SPV_{inst_category}<"{opname}"{category_args}[{traits}]> '
              '{{\n  let summary = {summary};\n\n  let description = '
-             '[{{\n{description}\n\n    ### Custom assembly '
-             'form\n{assembly}}}];\n')
+             '[{{\n{description}}}];\n')
+  inst_category = existing_info.get('inst_category', 'Op')
   if inst_category == 'Op':
     fmt_str +='\n  let arguments = (ins{args});\n\n'\
               '  let results = (outs{results});\n'
@@ -393,7 +408,7 @@ def get_op_definition(instruction, doc, existing_info, inst_category):
   # Make sure we have ', ' to separate the category arguments from traits
   category_args = category_args.rstrip(', ') + ', '
 
-  summary, description = doc.split('\n', 1)
+  summary, text = doc.split('\n', 1)
   wrapper = textwrap.TextWrapper(
       width=76, initial_indent='    ', subsequent_indent='    ')
 
@@ -405,10 +420,10 @@ def get_op_definition(instruction, doc, existing_info, inst_category):
   else:
     summary = '[{{\n{}\n  }}]'.format(wrapper.fill(summary))
 
-  # Wrap description
-  description = description.split('\n')
-  description = [wrapper.fill(line) for line in description if line]
-  description = '\n\n'.join(description)
+  # Wrap text
+  text = text.split('\n')
+  text = [wrapper.fill(line) for line in text if line]
+  text = '\n\n'.join(text)
 
   operands = instruction.get('operands', [])
 
@@ -433,8 +448,8 @@ def get_op_definition(instruction, doc, existing_info, inst_category):
       # Prepend and append whitespace for formatting
       arguments = '\n    {}\n  '.format(arguments)
 
-  assembly = existing_info.get('assembly', None)
-  if assembly is None:
+  description = existing_info.get('description', None)
+  if description is None:
     assembly = '\n    ``` {.ebnf}\n'\
                '    [TODO]\n'\
                '    ```\n\n'\
@@ -442,6 +457,7 @@ def get_op_definition(instruction, doc, existing_info, inst_category):
                '    ```\n'\
                '    [TODO]\n'\
                '    ```\n  '
+    description = get_description(text, assembly)
 
   return fmt_str.format(
       opname=opname,
@@ -450,7 +466,6 @@ def get_op_definition(instruction, doc, existing_info, inst_category):
       traits=existing_info.get('traits', ''),
       summary=summary,
       description=description,
-      assembly=assembly,
       args=arguments,
       results=results,
       extras=existing_info.get('extras', ''))
@@ -493,6 +508,14 @@ def extract_td_op_info(op_def):
   assert len(opname) == 1, 'more than one ops in the same section!'
   opname = opname[0]
 
+  # Get instruction category
+  inst_category = [
+      o[4:] for o in re.findall('SPV_\w+Op',
+                                op_def.split(':', 1)[1])
+  ]
+  assert len(inst_category) <= 1, 'more than one ops in the same section!'
+  inst_category = inst_category[0] if len(inst_category) == 1 else 'Op'
+
   # Get category_args
   op_tmpl_params = op_def.split('<', 1)[1].split('>', 1)[0]
   opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
@@ -501,9 +524,9 @@ def extract_td_op_info(op_def):
   # Get traits
   traits, _ = get_string_between(rest, '[', ']')
 
-  # Get custom assembly form
-  assembly, rest = get_string_between(op_def, '### Custom assembly form\n',
-                                      '}];\n')
+  # Get description
+  description, rest = get_string_between(op_def, 'let description = [{\n',
+                                         '}];\n')
 
   # Get arguments
   args, rest = get_string_between(rest, '  let arguments = (ins', ');\n')
@@ -518,9 +541,10 @@ def extract_td_op_info(op_def):
   return {
       # Prefix with 'Op' to make it consistent with SPIR-V spec
       'opname': 'Op{}'.format(opname),
+      'inst_category': inst_category,
       'category_args': category_args,
       'traits': traits,
-      'assembly': assembly,
+      'description': description,
       'arguments': args,
       'results': results,
       'extras': extras
@@ -567,7 +591,7 @@ def update_td_op_definitions(path, instructions, docs, filter_list,
         inst for inst in instructions if inst['opname'] == opname)
     op_defs.append(
         get_op_definition(instruction, docs[opname],
-                          op_info_dict.get(opname, {}), inst_category))
+                          op_info_dict.get(opname, {})))
 
   # Substitute the old op definitions
   op_defs = [header] + op_defs + [footer]
@@ -622,8 +646,7 @@ if __name__ == '__main__':
       type=str,
       default='Op',
       help='SPIR-V instruction category used for choosing '\
-           'a suitable .td file and TableGen common base '\
-           'class to define this op')
+           'the TableGen base class to define this op')
 
   args = cli_parser.parse_args()