Added tracking to types and validated certain instructions.
authorAndrew Woloszyn <awoloszyn@google.com>
Tue, 29 Sep 2015 15:28:34 +0000 (11:28 -0400)
committerDavid Neto <dneto@google.com>
Mon, 26 Oct 2015 16:55:33 +0000 (12:55 -0400)
We need to know how to generate correct SPIRV for cases like
OpConstant %int64 42 since the current parser will encode the 42 as a
32-bit value incorrectly.

This change is the first of a pair. This one tracks types, and makes
sure that OpConstant and OpSpecConstant are only ever called with
Integer or Float types, and OpSwitch is only called with integer
generating values.

15 files changed:
include/libspirv/libspirv.h
source/binary.cpp
source/instruction.h
source/opcode.cpp
source/opcode.h
source/text.cpp
source/text_handler.cpp
source/text_handler.h
test/AssemblyFormat.cpp
test/BinaryToText.cpp
test/ImmediateInt.cpp
test/TextToBinary.Constant.cpp
test/TextToBinary.ControlFlow.cpp
test/UnitSPIRV.h
test/ValidateID.cpp

index f4d821a..bbb2911 100644 (file)
@@ -152,6 +152,7 @@ typedef enum spv_endianness_t {
 typedef enum spv_operand_type_t {
   SPV_OPERAND_TYPE_NONE = 0,
   SPV_OPERAND_TYPE_ID,
+  SPV_OPERAND_TYPE_TYPE_ID,
   SPV_OPERAND_TYPE_RESULT_ID,
   SPV_OPERAND_TYPE_LITERAL_NUMBER,
   // A literal number that can (but is not required to) expand multiple words.
index b277acf..4ac1493 100644 (file)
@@ -177,6 +177,7 @@ spv_result_t spvBinaryDecodeOperand(
   switch (type) {
     case SPV_OPERAND_TYPE_EXECUTION_SCOPE:
     case SPV_OPERAND_TYPE_ID:
+    case SPV_OPERAND_TYPE_TYPE_ID:
     case SPV_OPERAND_TYPE_ID_IN_OPTIONAL_TUPLE:
     case SPV_OPERAND_TYPE_OPTIONAL_ID:
     case SPV_OPERAND_TYPE_MEMORY_SEMANTICS:
index b2b10b4..960fce3 100644 (file)
@@ -40,6 +40,9 @@ struct spv_instruction_t {
   Op opcode;
   spv_ext_inst_type_t extInstType;
 
+  // The Id of the result type, if this instruction has one.  Zero otherwise.
+  uint32_t resultTypeId;
+
   // The instruction, as a sequence of 32-bit words.
   // For a regular instruction the opcode and word count are combined
   // in words[0], as described in the SPIR-V spec.
index 957ddef..70f1ad7 100644 (file)
@@ -180,7 +180,7 @@ void spvOpcodeTableInitialize() {
     opcode.numTypes = 0;
     // Type ID always comes first, if present.
     if (opcode.hasType)
-      opcode.operandTypes[opcode.numTypes++] = SPV_OPERAND_TYPE_ID;
+      opcode.operandTypes[opcode.numTypes++] = SPV_OPERAND_TYPE_TYPE_ID;
     // Result ID always comes next, if present
     if (opcode.hasResult)
       opcode.operandTypes[opcode.numTypes++] = SPV_OPERAND_TYPE_RESULT_ID;
@@ -805,3 +805,32 @@ int32_t spvOpcodeIsValue(Op opcode) {
       return false;
   }
 }
+
+int32_t spvOpcodeGeneratesType(Op op) {
+  switch(op) {
+    case OpTypeVoid:
+    case OpTypeBool:
+    case OpTypeInt:
+    case OpTypeFloat:
+    case OpTypeVector:
+    case OpTypeMatrix:
+    case OpTypeImage:
+    case OpTypeSampler:
+    case OpTypeSampledImage:
+    case OpTypeArray:
+    case OpTypeRuntimeArray:
+    case OpTypeStruct:
+    case OpTypeOpaque:
+    case OpTypePointer:
+    case OpTypeFunction:
+    case OpTypeEvent:
+    case OpTypeDeviceEvent:
+    case OpTypeReserveId:
+    case OpTypeQueue:
+    case OpTypePipe:
+    case OpTypeForwardPointer:
+      return true;
+    default:;
+  }
+  return 0;
+}
index 8d17d10..a6d5e14 100644 (file)
@@ -189,4 +189,11 @@ int32_t spvInstructionIsInBasicBlock(const spv_instruction_t *pFirstInst,
 /// @return zero if false, non-zero otherwise
 int32_t spvOpcodeIsValue(Op opcode);
 
+/// @brief Determine if the Opcode generates a type
+///
+/// @param[in] opcode the opcode
+///
+/// @return zero if false, non-zero otherwise
+int32_t spvOpcodeGeneratesType(Op op);
+
 #endif
index 32bbf85..1a2ddca 100644 (file)
@@ -209,6 +209,7 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar,
   switch (type) {
     case SPV_OPERAND_TYPE_EXECUTION_SCOPE:
     case SPV_OPERAND_TYPE_ID:
+    case SPV_OPERAND_TYPE_TYPE_ID:
     case SPV_OPERAND_TYPE_ID_IN_OPTIONAL_TUPLE:
     case SPV_OPERAND_TYPE_OPTIONAL_ID:
     case SPV_OPERAND_TYPE_MEMORY_SEMANTICS:
@@ -224,6 +225,7 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar,
         return SPV_ERROR_INVALID_TEXT;
       }
       const uint32_t id = context->spvNamedIdAssignOrGet(textValue);
+      if (type == SPV_OPERAND_TYPE_TYPE_ID) pInst->resultTypeId = id;
       spvInstructionAddWord(pInst, id);
     } break;
     case SPV_OPERAND_TYPE_LITERAL_NUMBER: {
@@ -255,6 +257,40 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar,
                               << "'.";
         return SPV_ERROR_INVALID_TEXT;
       }
+
+      // The encoding for OpConstant, OpSpecConstant and OpSwitch all
+      // depend on either their own result-id or the result-id of
+      // one of their parameters.
+      if (OpConstant == pInst->opcode || OpSpecConstant == pInst->opcode) {
+        // Special cases for encoding possibly non-32-bit literals here.
+        libspirv::IdType type =
+            context->getTypeOfTypeGeneratingValue(pInst->resultTypeId);
+        if (!libspirv::isScalarFloating(type) &&
+            !libspirv::isScalarIntegral(type)) {
+          spv_opcode_desc d;
+          const char* opcode_name = "opcode";
+          if (SPV_SUCCESS == grammar.lookupOpcode(pInst->opcode, &d)) {
+            opcode_name = d->name;
+          }
+          context->diagnostic()
+              << "Type for " << opcode_name
+              << " must be a scalar floating point or integer type";
+          return SPV_ERROR_INVALID_TEXT;
+        }
+      } else if (pInst->opcode == OpSwitch) {
+        // We need to know the value of the selector.
+        libspirv::IdType type =
+            context->getTypeOfValueInstruction(pInst->words[1]);
+        if (type.type_class != libspirv::IdTypeClass::kScalarIntegerType) {
+          context->diagnostic()
+              << "The selector operand for OpSwitch must be the result"
+                 " of an instruction that generates an integer scalar";
+          return SPV_ERROR_INVALID_TEXT;
+        }
+      }
+      // TODO(awoloszyn): Generate the correct assembly for arbitrary
+      // bitwidths here instead of falling though.
+
       switch (literal.type) {
         // We do not have to print diagnostics here because spvBinaryEncode*
         // prints diagnostic messages on failure.
@@ -437,7 +473,6 @@ spv_result_t encodeInstructionStartingWithImmediate(
   }
   return SPV_SUCCESS;
 }
-
 }  // anonymous namespace
 
 /// @brief Translate single Opcode and operands to binary form
@@ -624,6 +659,18 @@ spv_result_t spvTextEncodeOpcode(const libspirv::AssemblyGrammar& grammar,
     }
   }
 
+  if (spvOpcodeGeneratesType(pInst->opcode)) {
+    if (context->recordTypeDefinition(pInst) != SPV_SUCCESS) {
+      return SPV_ERROR_INVALID_TEXT;
+    }
+  } else if (opcodeEntry->hasType) {
+    // SPIR-V dictates that if an instruction has both a return value and a
+    // type ID then the type id is first, and the return value is second.
+    assert(opcodeEntry->hasResult &&
+           "Unknown opcode: has a type but no result.");
+    context->recordTypeIdForValue(pInst->words[2], pInst->words[1]);
+  }
+
   if (pInst->words.size() > SPV_LIMIT_INSTRUCTION_WORD_COUNT_MAX) {
     context->diagnostic() << "Instruction too long: " << pInst->words.size()
                           << " words, but the limit is "
index a0e2512..0f6f135 100644 (file)
@@ -29,6 +29,7 @@
 #include <algorithm>
 #include <cassert>
 #include <cstring>
+#include <tuple>
 
 #include "binary.h"
 #include "ext_inst.h"
@@ -219,6 +220,12 @@ spv_result_t AssemblyGrammar::lookupOpcode(const char *name,
                                            spv_opcode_desc *desc) const {
   return spvOpcodeTableNameLookup(opcodeTable_, name, desc);
 }
+
+spv_result_t AssemblyGrammar::lookupOpcode(Op opcode,
+                                           spv_opcode_desc *desc) const {
+  return spvOpcodeTableValueLookup(opcodeTable_, opcode, desc);
+}
+
 spv_result_t AssemblyGrammar::lookupOperand(spv_operand_type_t type,
                                             const char *name, size_t name_len,
                                             spv_operand_desc *desc) const {
@@ -358,5 +365,61 @@ spv_result_t AssemblyContext::binaryEncodeString(
 
   return SPV_SUCCESS;
 }
+
+spv_result_t AssemblyContext::recordTypeDefinition(
+    const spv_instruction_t *pInst) {
+  uint32_t value = pInst->words[1];
+  if (types_.find(value) != types_.end()) {
+    diagnostic() << "Value " << value
+                 << " has already been used to generate a type";
+    return SPV_ERROR_INVALID_VALUE;
+  }
+
+  if (pInst->opcode == OpTypeInt) {
+    if (pInst->words.size() != 4) {
+      diagnostic() << "Invalid OpTypeInt instruction";
+      return SPV_ERROR_INVALID_VALUE;
+    }
+    types_[value] = { pInst->words[2], IdTypeClass::kScalarIntegerType };
+  } else if (pInst->opcode == OpTypeFloat) {
+    if (pInst->words.size() != 3) {
+      diagnostic() << "Invalid OpTypeFloat instruction";
+      return SPV_ERROR_INVALID_VALUE;
+    }
+    types_[value] = { pInst->words[2], IdTypeClass::kScalarFloatType };
+  } else {
+    types_[value] = { 0, IdTypeClass::kOtherType };
+  }
+  return SPV_SUCCESS;
+}
+
+IdType AssemblyContext::getTypeOfTypeGeneratingValue(uint32_t value) const {
+  auto type = types_.find(value);
+  if (type == types_.end()) {
+    return {0, IdTypeClass::kBottom};
+  }
+  return std::get<1>(*type);
+}
+
+IdType AssemblyContext::getTypeOfValueInstruction(uint32_t value) const {
+  auto type_value = value_types_.find(value);
+  if (type_value == value_types_.end()) {
+    return { 0, IdTypeClass::kBottom};
+  }
+  return getTypeOfTypeGeneratingValue(std::get<1>(*type_value));
+}
+
+spv_result_t AssemblyContext::recordTypeIdForValue(uint32_t value,
+                                                   uint32_t type) {
+  bool successfully_inserted = false;
+  std::tie(std::ignore, successfully_inserted) =
+      value_types_.insert(std::make_pair(value, type));
+  if (!successfully_inserted) {
+    diagnostic() << "Value is being defined a second time";
+    return SPV_ERROR_INVALID_VALUE;
+  }
+  return SPV_SUCCESS;
+}
+
 }
 
index d5ba7eb..e67bb82 100644 (file)
@@ -54,6 +54,16 @@ struct IdType {
   IdTypeClass type_class;
 };
 
+// Returns true if the type is a scalar integer type.
+inline bool isScalarIntegral(const IdType& type) {
+  return type.type_class == IdTypeClass::kScalarIntegerType;
+}
+
+// Returns true if the type is a scalar floating point type.
+inline bool isScalarFloating(const IdType& type) {
+  return type.type_class == IdTypeClass::kScalarFloatType;
+}
+
 // Encapsulates the grammar to use for SPIR-V assembly.
 // Contains methods to query for valid instructions and operands.
 class AssemblyGrammar {
@@ -73,6 +83,11 @@ class AssemblyGrammar {
   // SPV_ERROR_INVALID_LOOKUP if the opcode does not exist.
   spv_result_t lookupOpcode(const char *name, spv_opcode_desc *desc) const;
 
+  // Fills in the desc parameter with the information about the opcode
+  // of the valid. Returns SPV_SUCCESS if the opcode was found, and
+  // SPV_ERROR_INVALID_LOOKUP if the opcode does not exist.
+  spv_result_t lookupOpcode(Op opcode, spv_opcode_desc *desc) const;
+
   // Fills in the desc parameter with the information about the given
   // operand. Returns SPV_SUCCESS if the operand was found, and
   // SPV_ERROR_INVALID_LOOKUP otherwise.
@@ -189,6 +204,27 @@ class AssemblyContext {
   // instruction.
   spv_result_t binaryEncodeString(const char *value, spv_instruction_t *pInst);
 
+  // Returns the IdType associated with this type-generating value.
+  // If the type has not been previously recorded with recordTypeDefinition,
+  // { 0, IdTypeClass::kBottom } will be returned.
+  IdType getTypeOfTypeGeneratingValue(uint32_t value) const;
+
+  // Returns the IdType that represents the return value of this Value
+  // generating instruction.
+  // If the value has not been recorded with recordTypeIdForValue, or the type
+  // could not be determined { 0, IdTypeClass::kBottom } will be returned.
+  IdType getTypeOfValueInstruction(uint32_t value) const;
+
+  // Tracks the type-defining instruction. The result of the tracking can
+  // later be queried using getValueType.
+  // pInst is expected to be completely filled in by the time this instruction
+  // is called.
+  // Returns SPV_SUCCESS on success, or SPV_ERROR_INVALID_VALUE on error.
+  spv_result_t recordTypeDefinition(const spv_instruction_t* pInst);
+
+  // Tracks the relationship between the value and its type.
+  spv_result_t recordTypeIdForValue(uint32_t value, uint32_t type);
+
  private:
   // Maps ID names to their corresponding numerical ids.
   using spv_named_id_table = std::unordered_map<std::string, uint32_t>;
index 8cd51d7..3e4a890 100644 (file)
@@ -31,7 +31,7 @@ namespace {
 using spvtest::TextToBinaryTest;
 
 TEST_F(TextToBinaryTest, EncodeAAFTextAsAAF) {
-  SetText("%2 = OpConstant %1 1000");
+  SetText("%2 = OpTypeMatrix %1 1000");
   EXPECT_EQ(SPV_SUCCESS,
             spvTextWithFormatToBinary(
                 text.str, text.length, SPV_ASSEMBLY_SYNTAX_FORMAT_ASSIGNMENT,
@@ -42,7 +42,7 @@ TEST_F(TextToBinaryTest, EncodeAAFTextAsAAF) {
 }
 
 TEST_F(TextToBinaryTest, EncodeAAFTextAsCAF) {
-  SetText("%2 = OpConstant %1 1000");
+  SetText("%2 = OpTypeMatrix %1 1000");
   EXPECT_EQ(SPV_ERROR_INVALID_TEXT,
             spvTextWithFormatToBinary(
                 text.str, text.length, SPV_ASSEMBLY_SYNTAX_FORMAT_CANONICAL,
@@ -55,7 +55,7 @@ TEST_F(TextToBinaryTest, EncodeAAFTextAsCAF) {
 }
 
 TEST_F(TextToBinaryTest, EncodeCAFTextAsCAF) {
-  SetText("OpConstant %1 %2 1000");
+  SetText("OpTypeMatrix %1 %2 1000");
   EXPECT_EQ(SPV_SUCCESS,
             spvTextWithFormatToBinary(
                 text.str, text.length, SPV_ASSEMBLY_SYNTAX_FORMAT_CANONICAL,
@@ -66,7 +66,7 @@ TEST_F(TextToBinaryTest, EncodeCAFTextAsCAF) {
 }
 
 TEST_F(TextToBinaryTest, EncodeCAFTextAsAAF) {
-  SetText("OpConstant %1 %2 1000");
+  SetText("OpTypeMatrix %1 %2 1000");
   EXPECT_EQ(SPV_ERROR_INVALID_TEXT,
             spvTextWithFormatToBinary(
                 text.str, text.length, SPV_ASSEMBLY_SYNTAX_FORMAT_ASSIGNMENT,
@@ -74,13 +74,13 @@ TEST_F(TextToBinaryTest, EncodeCAFTextAsAAF) {
   ASSERT_NE(nullptr, diagnostic);
   EXPECT_STREQ(
       "Expected <result-id> at the beginning of an instruction, found "
-      "'OpConstant'.",
+      "'OpTypeMatrix'.",
       diagnostic->error);
   EXPECT_EQ(0, diagnostic->position.line);
 }
 
 TEST_F(TextToBinaryTest, EncodeMixedTextAsAAF) {
-  SetText("OpConstant %1 %2 1000\n%3 = OpConstant %1 2000");
+  SetText("OpTypeMatrix %1 %2 1000\n%3 = OpTypeMatrix %1 2000");
   EXPECT_EQ(SPV_ERROR_INVALID_TEXT,
             spvTextWithFormatToBinary(
                 text.str, text.length, SPV_ASSEMBLY_SYNTAX_FORMAT_ASSIGNMENT,
@@ -88,13 +88,13 @@ TEST_F(TextToBinaryTest, EncodeMixedTextAsAAF) {
   ASSERT_NE(nullptr, diagnostic);
   EXPECT_STREQ(
       "Expected <result-id> at the beginning of an instruction, found "
-      "'OpConstant'.",
+      "'OpTypeMatrix'.",
       diagnostic->error);
   EXPECT_EQ(0, diagnostic->position.line);
 }
 
 TEST_F(TextToBinaryTest, EncodeMixedTextAsCAF) {
-  SetText("OpConstant %1 %2 1000\n%3 = OpConstant %1 2000");
+  SetText("OpTypeMatrix %1 %2 1000\n%3 = OpTypeMatrix %1 2000");
   EXPECT_EQ(SPV_ERROR_INVALID_TEXT,
             spvTextWithFormatToBinary(
                 text.str, text.length, SPV_ASSEMBLY_SYNTAX_FORMAT_CANONICAL,
@@ -111,7 +111,7 @@ const char* kBound4Preamble =
 
 TEST_F(TextToBinaryTest, DecodeAAFAsAAF) {
   const std::string assembly =
-      "%2 = OpConstant %1 1000\n%3 = OpConstant %1 2000\n";
+      "%1 = OpTypeMatrix %2 1000\n%3 = OpTypeMatrix %2 2000\n";
   SetText(assembly);
   EXPECT_EQ(SPV_SUCCESS,
             spvTextWithFormatToBinary(
@@ -130,7 +130,7 @@ TEST_F(TextToBinaryTest, DecodeAAFAsAAF) {
 
 TEST_F(TextToBinaryTest, DecodeAAFAsCAF) {
   const std::string assembly =
-      "%2 = OpConstant %1 1000\n%3 = OpConstant %1 2000\n";
+      "%1 = OpTypeMatrix %2 1000\n%3 = OpTypeMatrix %2 2000\n";
   SetText(assembly);
   EXPECT_EQ(SPV_SUCCESS,
             spvTextWithFormatToBinary(
@@ -144,13 +144,14 @@ TEST_F(TextToBinaryTest, DecodeAAFAsCAF) {
                                       SPV_ASSEMBLY_SYNTAX_FORMAT_CANONICAL,
                                       &decoded_text, &diagnostic));
   EXPECT_EQ(std::string(kBound4Preamble) +
-                "OpConstant %1 %2 1000\nOpConstant %1 %3 2000\n",
+                "OpTypeMatrix %1 %2 1000\nOpTypeMatrix %3 %2 2000\n",
             decoded_text->str);
   spvTextDestroy(decoded_text);
 }
 
 TEST_F(TextToBinaryTest, DecodeCAFAsAAF) {
-  const std::string assembly = "OpConstant %1 %2 1000\nOpConstant %1 %3 2000\n";
+  const std::string assembly =
+      "OpTypeMatrix %1 %2 1000\nOpTypeMatrix %3 %2 2000\n";
   SetText(assembly);
   EXPECT_EQ(SPV_SUCCESS,
             spvTextWithFormatToBinary(
@@ -164,13 +165,14 @@ TEST_F(TextToBinaryTest, DecodeCAFAsAAF) {
                                       SPV_ASSEMBLY_SYNTAX_FORMAT_ASSIGNMENT,
                                       &decoded_text, &diagnostic));
   EXPECT_EQ(std::string(kBound4Preamble) +
-                "%2 = OpConstant %1 1000\n%3 = OpConstant %1 2000\n",
+                "%1 = OpTypeMatrix %2 1000\n%3 = OpTypeMatrix %2 2000\n",
             decoded_text->str);
   spvTextDestroy(decoded_text);
 }
 
 TEST_F(TextToBinaryTest, DecodeCAFAsCAF) {
-  const std::string assembly = "OpConstant %1 %2 1000\nOpConstant %1 %3 2000\n";
+  const std::string assembly =
+      "OpTypeMatrix %1 %2 1000\nOpTypeMatrix %3 %2 2000\n";
   SetText(assembly);
   EXPECT_EQ(SPV_SUCCESS,
             spvTextWithFormatToBinary(
@@ -184,7 +186,7 @@ TEST_F(TextToBinaryTest, DecodeCAFAsCAF) {
                                       SPV_ASSEMBLY_SYNTAX_FORMAT_CANONICAL,
                                       &decoded_text, &diagnostic));
   EXPECT_EQ(std::string(kBound4Preamble) +
-                "OpConstant %1 %2 1000\nOpConstant %1 %3 2000\n",
+                "OpTypeMatrix %1 %2 1000\nOpTypeMatrix %3 %2 2000\n",
             decoded_text->str);
   spvTextDestroy(decoded_text);
 }
index ad63ea9..8ca9faf 100644 (file)
@@ -219,11 +219,11 @@ TEST(BinaryToTextSmall, LiteralInt64) {
   error = spvBinaryToText(binary->code, binary->wordCount,
                           SPV_BINARY_TO_TEXT_OPTION_NONE, opcodeTable,
                           operandTable, extInstTable, &text, &diagnostic);
-  EXPECT_EQ(SPV_SUCCESS, error);
   if (error) {
     spvDiagnosticPrint(diagnostic);
     spvDiagnosticDestroy(diagnostic);
   }
+  ASSERT_EQ(SPV_SUCCESS, error);
   const std::string header =
       "; SPIR-V\n; Version: 99\n; Generator: Khronos\n; "
       "Bound: 3\n; Schema: 0\n";
@@ -252,11 +252,11 @@ TEST(BinaryToTextSmall, LiteralDouble) {
   error = spvBinaryToText(binary->code, binary->wordCount,
                           SPV_BINARY_TO_TEXT_OPTION_NONE, opcodeTable,
                           operandTable, extInstTable, &text, &diagnostic);
-  EXPECT_EQ(SPV_SUCCESS, error);
   if (error) {
     spvDiagnosticPrint(diagnostic);
     spvDiagnosticDestroy(diagnostic);
   }
+  ASSERT_EQ(SPV_SUCCESS, error);
   const std::string output =
       R"(; SPIR-V
 ; Version: 99
index c4649a9..200572c 100644 (file)
@@ -68,20 +68,20 @@ TEST_F(TextToBinaryTest, ImmediateIntOperand) {
 using ImmediateIntTest = TextToBinaryTest;
 
 TEST_F(ImmediateIntTest, AnyWordInSimpleStatement) {
-  EXPECT_THAT(CompiledInstructions("!0x0004002B %a %b 123", kCAF),
-              Eq(MakeInstruction(spv::OpConstant, {1, 2, 123})));
-  EXPECT_THAT(CompiledInstructions("OpConstant !1 %b 123", kCAF),
-              Eq(MakeInstruction(spv::OpConstant, {1, 1, 123})));
-  EXPECT_THAT(CompiledInstructions("OpConstant %1 !2 123", kCAF),
-              Eq(MakeInstruction(spv::OpConstant, {1, 2, 123})));
-  EXPECT_THAT(CompiledInstructions("OpConstant  %a %b !123", kCAF),
-              Eq(MakeInstruction(spv::OpConstant, {1, 2, 123})));
-  EXPECT_THAT(CompiledInstructions("!0x0004002B %1 !2 123", kCAF),
-              Eq(MakeInstruction(spv::OpConstant, {1, 2, 123})));
-  EXPECT_THAT(CompiledInstructions("OpConstant !1 %b !123", kCAF),
-              Eq(MakeInstruction(spv::OpConstant, {1, 1, 123})));
-  EXPECT_THAT(CompiledInstructions("!0x0004002B !1 !2 !123", kCAF),
-              Eq(MakeInstruction(spv::OpConstant, {1, 2, 123})));
+  EXPECT_THAT(CompiledInstructions("!0x00040018 %a %b %123", kCAF),
+              Eq(MakeInstruction(spv::OpTypeMatrix, {1, 2,3 })));
+  EXPECT_THAT(CompiledInstructions("OpTypeMatrix !1 %b %123", kCAF),
+              Eq(MakeInstruction(spv::OpTypeMatrix, {1, 1, 2})));
+  EXPECT_THAT(CompiledInstructions("OpTypeMatrix %1 !2 %123", kCAF),
+              Eq(MakeInstruction(spv::OpTypeMatrix, {1, 2, 2})));
+  EXPECT_THAT(CompiledInstructions("OpTypeMatrix  %a %b !123", kCAF),
+              Eq(MakeInstruction(spv::OpTypeMatrix, {1, 2, 123})));
+  EXPECT_THAT(CompiledInstructions("!0x00040018 %1 !2 %123", kCAF),
+              Eq(MakeInstruction(spv::OpTypeMatrix, {1, 2, 2})));
+  EXPECT_THAT(CompiledInstructions("OpTypeMatrix !1 %b !123", kCAF),
+              Eq(MakeInstruction(spv::OpTypeMatrix, {1, 1, 123})));
+  EXPECT_THAT(CompiledInstructions("!0x00040018 !1 !2 !123", kCAF),
+              Eq(MakeInstruction(spv::OpTypeMatrix, {1, 2, 123})));
 }
 
 TEST_F(ImmediateIntTest, AnyWordAfterEqualsAndOpCode) {
@@ -118,12 +118,16 @@ TEST_F(ImmediateIntTest, IntegerFollowingImmediate) {
   EXPECT_EQ(original, CompiledInstructions("OpTypeInt !1 8 1", kCAF));
 
   // 64-bit integer literal.
-  EXPECT_EQ(CompiledInstructions("OpConstant %10 %2 5000000000", kCAF),
-            CompiledInstructions("OpConstant %10 !2 5000000000", kCAF));
+  EXPECT_EQ(CompiledInstructions("OpTypeInt %i64 64 0\n"
+                                 "OpConstant %i64 %2 5000000000", kCAF),
+            CompiledInstructions("OpTypeInt %i64 64 0\n"
+                                 "OpConstant %i64 !2 5000000000", kCAF));
 
   // Negative integer.
-  EXPECT_EQ(CompiledInstructions("OpConstant %10 %2 -123", kCAF),
-            CompiledInstructions("OpConstant %10 !2 -123", kCAF));
+  EXPECT_EQ(CompiledInstructions("OpTypeInt %i64 32 1\n"
+                                 "OpConstant %i64 %2 -123", kCAF),
+            CompiledInstructions("OpTypeInt %i64 32 1\n"
+                                 "OpConstant %i64 !2 -123", kCAF));
 
   // TODO(deki): uncomment assertions below and make them pass.
   // Hex value(s).
@@ -136,16 +140,16 @@ TEST_F(ImmediateIntTest, IntegerFollowingImmediate) {
 
 // Literal floats after !<integer> are handled correctly.
 TEST_F(ImmediateIntTest, FloatFollowingImmediate) {
-  EXPECT_EQ(CompiledInstructions("OpConstant %10 %2 0.123", kCAF),
-            CompiledInstructions("OpConstant %10 !2 0.123", kCAF));
-  EXPECT_EQ(CompiledInstructions("OpConstant %10 %2 -0.5", kCAF),
-            CompiledInstructions("OpConstant %10 !2 -0.5", kCAF));
+  EXPECT_EQ(CompiledInstructions("OpTypeMatrix %10 %2 0.123", kCAF),
+            CompiledInstructions("OpTypeMatrix %10 !2 0.123", kCAF));
+  EXPECT_EQ(CompiledInstructions("OpTypeMatrix %10 %2 -0.5", kCAF),
+            CompiledInstructions("OpTypeMatrix %10 !2 -0.5", kCAF));
   // 64-bit float.
   EXPECT_EQ(
       CompiledInstructions(
-          "OpConstant %10 %2 9999999999999999999999999999999999999999.9", kCAF),
+          "OpTypeMatrix %10 %2 9999999999999999999999999999999999999999.9", kCAF),
       CompiledInstructions(
-          "OpConstant %10 !2 9999999999999999999999999999999999999999.9",
+          "OpTypeMatrix %10 !2 9999999999999999999999999999999999999999.9",
           kCAF));
 }
 
index 7233565..2650a0f 100644 (file)
@@ -36,6 +36,7 @@ namespace {
 
 using spvtest::EnumCase;
 using spvtest::MakeInstruction;
+using spvtest::Concatenate;
 using ::testing::Eq;
 
 // Test Sampler Addressing Mode enum values
@@ -89,9 +90,89 @@ INSTANTIATE_TEST_CASE_P(
 #undef CASE
 // clang-format on
 
+
+struct ConstantTestCase {
+  std::string constant_type;
+  std::string constant_value;
+  std::vector<uint32_t> expected_instructions;
+};
+
+using OpConstantValidTest = spvtest::TextToBinaryTestBase<
+  ::testing::TestWithParam<ConstantTestCase>>;
+
+TEST_P(OpConstantValidTest, ValidTypes)
+{
+  std::string input =
+      "%1 = " + GetParam().constant_type + "\n"
+      "%2 = OpConstant %1 " + GetParam().constant_value + "\n";
+  std::vector<uint32_t> instructions;
+  EXPECT_THAT(CompiledInstructions(input),
+              Eq(GetParam().expected_instructions));
+}
+
+// clang-format off
+INSTANTIATE_TEST_CASE_P(
+    TextToBinaryOpConstantValid, OpConstantValidTest,
+    ::testing::ValuesIn(std::vector<ConstantTestCase>{
+      {"OpTypeInt 32 0", "42",
+        Concatenate({MakeInstruction(spv::OpTypeInt, {1, 32, 0}),
+         MakeInstruction(spv::OpConstant, {1, 2, 42})})},
+      {"OpTypeInt 32 1", "-32",
+        Concatenate({MakeInstruction(spv::OpTypeInt, {1, 32, 1}),
+         MakeInstruction(spv::OpConstant, {1, 2, static_cast<uint32_t>(-32)})})},
+      {"OpTypeFloat 32", "1.0",
+        Concatenate({MakeInstruction(spv::OpTypeFloat, {1, 32}),
+         MakeInstruction(spv::OpConstant, {1, 2, 0x3f800000})})},
+      {"OpTypeFloat 32", "10.0",
+        Concatenate({MakeInstruction(spv::OpTypeFloat, {1, 32}),
+         MakeInstruction(spv::OpConstant, {1, 2, 0x41200000})})},
+    }));
+// clang-format on
+
+using OpConstantInvalidTypeTest = spvtest::TextToBinaryTestBase<
+  ::testing::TestWithParam<std::string>>;
+
+TEST_P(OpConstantInvalidTypeTest, InvalidTypes)
+{
+  std::string input =
+      "%1 = " + GetParam() + "\n"
+      "%2 = OpConstant %1 0\n";
+  EXPECT_THAT(
+      CompileFailure(input),
+      Eq("Type for Constant must be a scalar floating point or integer type"));
+}
+// clang-format off
+INSTANTIATE_TEST_CASE_P(
+    TextToBinaryOpConstantInvalidValidType, OpConstantInvalidTypeTest,
+    ::testing::ValuesIn(std::vector<std::string>{
+      {"OpTypeVoid",
+       "OpTypeBool",
+       "OpTypeVector %a 32",
+       "OpTypeMatrix %a 32",
+       "OpTypeImage %a 1D 0 0 0 0 Unknown",
+       "OpTypeSampler",
+       "OpTypeSampledImage %a",
+       "OpTypeArray %a %b",
+       "OpTypeRuntimeArray %a",
+       "OpTypeStruct %a",
+       "OpTypeOpaque \"Foo\"",
+       "OpTypePointer UniformConstant %a",
+       "OpTypeFunction %a %b",
+       "OpTypeEvent",
+       "OpTypeDeviceEvent",
+       "OpTypeReserveId",
+       "OpTypeQueue",
+       "OpTypePipe ReadOnly",
+       "OpTypeForwardPointer %a UniformConstant",
+        // At least one thing that isn't a type at all
+       "OpNot %a %b"
+      },
+    }));
+// clang-format on
+
+
 // TODO(dneto): OpConstantTrue
 // TODO(dneto): OpConstantFalse
-// TODO(dneto): OpConstant
 // TODO(dneto): OpConstantComposite
 // TODO(dneto): OpConstantSampler: other variations Param is 0 or 1
 // TODO(dneto): OpConstantNull
index e47c3b9..181f4b3 100644 (file)
@@ -36,6 +36,7 @@ namespace {
 
 using spvtest::EnumCase;
 using spvtest::MakeInstruction;
+using spvtest::Concatenate;
 using spvtest::TextToBinaryTest;
 using ::testing::Eq;
 
@@ -100,7 +101,6 @@ TEST_F(OpLoopMergeTest, CombinedLoopControlMask) {
   EXPECT_THAT(CompiledInstructions(input),
               Eq(MakeInstruction(spv::OpLoopMerge, {1, 2, expected_mask})));
 }
-
 // Test OpSwitch
 
 TEST_F(TextToBinaryTest, SwitchGoodZeroTargets) {
@@ -109,14 +109,25 @@ TEST_F(TextToBinaryTest, SwitchGoodZeroTargets) {
 }
 
 TEST_F(TextToBinaryTest, SwitchGoodOneTarget) {
-  EXPECT_THAT(CompiledInstructions("OpSwitch %selector %default 12 %target0"),
-              Eq(MakeInstruction(spv::OpSwitch, {1, 2, 12, 3})));
+  EXPECT_THAT(CompiledInstructions("%1 = OpTypeInt 32 0\n"
+                                   "%2 = OpConstant %1 52\n"
+                                   "OpSwitch %2 %default 12 %target0"),
+              Eq(Concatenate({
+                  MakeInstruction(spv::OpTypeInt, {1, 32, 0}),
+                  MakeInstruction(spv::OpConstant, {1, 2, 52}),
+                  MakeInstruction(spv::OpSwitch, {2, 3, 12, 4})})));
 }
 
 TEST_F(TextToBinaryTest, SwitchGoodTwoTargets) {
-  EXPECT_THAT(CompiledInstructions(
-                  "OpSwitch %selector %default 12 %target0 42 %target1"),
-              Eq(MakeInstruction(spv::OpSwitch, {1, 2, 12, 3, 42, 4})));
+  EXPECT_THAT(
+      CompiledInstructions("%1 = OpTypeInt 32 0\n"
+                           "%2 = OpConstant %1 52\n"
+                           "OpSwitch %2 %default 12 %target0 42 %target1"),
+      Eq(Concatenate({
+          MakeInstruction(spv::OpTypeInt, {1, 32, 0}),
+          MakeInstruction(spv::OpConstant, {1, 2, 52}),
+          MakeInstruction(spv::OpSwitch, {2, 3, 12, 4, 42, 5}),
+      })));
 }
 
 TEST_F(TextToBinaryTest, SwitchBadMissingSelector) {
@@ -156,11 +167,97 @@ TEST_F(TextToBinaryTest, SwitchBadInvalidLiteralCanonicalFormat) {
 }
 
 TEST_F(TextToBinaryTest, SwitchBadMissingTarget) {
-  EXPECT_THAT(CompileFailure("OpSwitch %selector %default 12"),
+  EXPECT_THAT(CompileFailure("%1 = OpTypeInt 32 0\n"
+                             "%2 = OpConstant %1 52\n"
+                             "OpSwitch %2 %default 12"),
               Eq("Expected operand, found end of stream."));
 }
 
+struct SwitchTestCase{
+  std::string constant_type_args;
+  std::vector<uint32_t> expected_instructions;
+};
+
+using OpSwitchValidTest = spvtest::TextToBinaryTestBase<
+  ::testing::TestWithParam<SwitchTestCase>>;
+
+TEST_P(OpSwitchValidTest, ValidTypes)
+{
+  std::string input =
+      "%1 = OpTypeInt " + GetParam().constant_type_args + "\n"
+      "%2 = OpConstant %1 0\n"
+      "OpSwitch %2 %default 32 %4\n";
+  std::vector<uint32_t> instructions;
+  EXPECT_THAT(CompiledInstructions(input),
+              Eq(GetParam().expected_instructions));
+}
+// clang-format off
+#define CASE(integer_width, integer_signedness) \
+  { #integer_width " " #integer_signedness, \
+    { Concatenate({ \
+       MakeInstruction(spv::OpTypeInt, {1, integer_width, integer_signedness}),\
+       MakeInstruction(spv::OpConstant, {1, 2, 0} ), \
+       MakeInstruction(spv::OpSwitch, {2, 3, 32, 4} )})}}
+INSTANTIATE_TEST_CASE_P(
+    TextToBinaryOpSwitchValid, OpSwitchValidTest,
+    ::testing::ValuesIn(std::vector<SwitchTestCase>{
+      CASE(32, 0),
+      CASE(16, 0),
+      // TODO(dneto): For a 64-bit selector, the literals should take up two
+      // words, in little-endian sequence.  In that case the OpSwitch operands
+      // would be {2, 3, 32, 0, 4}.
+      CASE(64, 0),
+      // TODO(dneto): Try signed cases also.
+    }));
+#undef CASE
+// clang-format on
+
+using OpSwitchInvalidTypeTestCase = spvtest::TextToBinaryTestBase<
+  ::testing::TestWithParam<std::string>>;
+
+TEST_P(OpSwitchInvalidTypeTestCase, InvalidTypes)
+{
+  std::string input =
+      "%1 = " + GetParam() + "\n"
+      "%3 = OpCopyObject %1 %2\n" // We only care the type of the expression
+      "%4 = OpSwitch %3 %default 32 %c\n";
+  EXPECT_THAT(CompileFailure(input),
+              Eq(std::string(
+                  "The selector operand for OpSwitch must be the result of an "
+                  "instruction that generates an integer scalar")));
+}
+// clang-format off
+INSTANTIATE_TEST_CASE_P(
+    TextToBinaryOpSwitchInvalidTests, OpSwitchInvalidTypeTestCase,
+    ::testing::ValuesIn(std::vector<std::string>{
+      {"OpTypeVoid",
+       "OpTypeBool",
+       "OpTypeFloat 32",
+       "OpTypeVector %a 32",
+       "OpTypeMatrix %a 32",
+       "OpTypeImage %a 1D 0 0 0 0 Unknown",
+       "OpTypeSampler",
+       "OpTypeSampledImage %a",
+       "OpTypeArray %a %b",
+       "OpTypeRuntimeArray %a",
+       "OpTypeStruct %a",
+       "OpTypeOpaque \"Foo\"",
+       "OpTypePointer UniformConstant %a",
+       "OpTypeFunction %a %b",
+       "OpTypeEvent",
+       "OpTypeDeviceEvent",
+       "OpTypeReserveId",
+       "OpTypeQueue",
+       "OpTypePipe ReadOnly",
+       "OpTypeForwardPointer %a UniformConstant",
+           // At least one thing that isn't a type at all
+       "OpNot %a %b"
+      },
+    }));
+// clang-format on
 
+//TODO(awoloszyn): Add tests for switch with different operand widths
+//                 once non-32-bit support is in.
 // TODO(dneto): OpPhi
 // TODO(dneto): OpLoopMerge
 // TODO(dneto): OpLabel
index acd4fd2..7180a3d 100644 (file)
@@ -119,6 +119,7 @@ inline std::vector<uint32_t> MakeInstruction(
   return result;
 }
 
+
 // Returns a vector of words representing a single instruction with the
 // given opcode and whose operands are the concatenation of the two given
 // argument lists.
index 7378205..cebdef1 100644 (file)
@@ -421,7 +421,7 @@ TEST_F(ValidateID, OpConstantGood) {
 TEST_F(ValidateID, OpConstantBad) {
   const char *spirv = R"(
 %1 = OpTypeVoid
-%2 = OpConstant %1 0)";
+%2 = OpConstant !1 !0)";
   CHECK(spirv, SPV_ERROR_INVALID_ID);
 }
 
@@ -643,7 +643,7 @@ TEST_F(ValidateID, OpSpecConstantGood) {
 TEST_F(ValidateID, OpSpecConstantBad) {
   const char *spirv = R"(
 %1 = OpTypeVoid
-%2 = OpSpecConstant %1 3.14)";
+%2 = OpSpecConstant !1 !4)";
   CHECK(spirv, SPV_ERROR_INVALID_ID);
 }