Detect overflow in !<integer>.
author Dejan Mircevski <deki@google.com>
Tue, 29 Sep 2015 21:07:21 +0000 (17:07 -0400)
committer David Neto <dneto@google.com>
Mon, 26 Oct 2015 16:55:33 +0000 (12:55 -0400)
source/text.cpp
test/ImmediateInt.cpp

index 5bc0b5d..ccec524 100644 (file)
 
 using spvutils::BitwiseCast;
 
-
 bool spvIsValidIDCharacter(const char value) {
   return value == '_' || 0 != ::isalnum(value);
 }
 
 // Returns true if the given string represents a valid ID name.
-bool spvIsValidID(const char *textValue) {
-  const char *c = textValue;
+bool spvIsValidID(const char* textValue) {
+  const char* c = textValue;
   for (; *c != '\0'; ++c) {
     if (!spvIsValidIDCharacter(*c)) {
       return false;
@@ -152,6 +151,34 @@ spv_result_t spvTextToLiteral(const char* textValue, spv_literal_t* pLiteral) {
   return SPV_SUCCESS;
 }
 
+namespace {
+
+/// Parses an immediate integer from text, guarding against overflow.  If
+/// successful, adds the parsed value to pInst, advances the context past it,
+/// and returns SPV_SUCCESS.  Otherwise, leaves pInst alone, emits diagnostics,
+/// and returns SPV_ERROR_INVALID_TEXT.
+spv_result_t encodeImmediate(libspirv::AssemblyContext* context,
+                             const char* text, spv_instruction_t* pInst) {
+  assert(*text == '!');
+  const char* begin = text + 1;
+  char* end = nullptr;
+  const uint64_t parseResult = std::strtoull(begin, &end, 0);
+  size_t length = end - begin;
+  if (length != strlen(begin)) {
+    context->diagnostic() << "Invalid immediate integer '" << text << "'.";
+    return SPV_ERROR_INVALID_TEXT;
+  } else if (length > 10 || parseResult > UINT32_MAX) {
+    context->diagnostic() << "Immediate integer '" << text
+                          << "' is over 32 bits.";
+    return SPV_ERROR_INVALID_TEXT;
+  }
+  context->binaryEncodeU32(parseResult, pInst);
+  context->seekForward(strlen(text));
+  return SPV_SUCCESS;
+}
+
+}  // anonymous namespace
+
 /// @brief Translate an Opcode operand to binary form
 ///
 /// @param[in] grammar the grammar to use for compilation
@@ -162,7 +189,7 @@ spv_result_t spvTextToLiteral(const char* textValue, spv_literal_t* pLiteral) {
 /// @param[in,out] pExpectedOperands the operand types expected
 ///
 /// @return result code
-spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar &grammar,
+spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar,
                                   libspirv::AssemblyContext* context,
                                   const spv_operand_type_t type,
                                   const char* textValue,
@@ -170,18 +197,9 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar &grammar,
                                   spv_operand_pattern_t* pExpectedOperands) {
   // NOTE: Handle immediate int in the stream
   if ('!' == textValue[0]) {
-    const char* begin = textValue + 1;
-    char* end = nullptr;
-    uint32_t immediateInt = strtoul(begin, &end, 0);
-    size_t size = strlen(textValue);
-    size_t length = (end - begin);
-    if (size - 1 != length) {
-      context->diagnostic() << "Invalid immediate integer '" << textValue << "'.";
-      return SPV_ERROR_INVALID_TEXT;
+    if (auto error = encodeImmediate(context, textValue, pInst)) {
+      return error;
     }
-    context->seekForward(size);
-    pInst->words[pInst->wordCount] = immediateInt;
-    pInst->wordCount += 1;
     *pExpectedOperands =
         spvAlternatePatternFollowingImmediate(*pExpectedOperands);
     return SPV_SUCCESS;
@@ -211,10 +229,9 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar &grammar,
       // NOTE: Special case for extension instruction lookup
       if (OpExtInst == pInst->opcode) {
         spv_ext_inst_desc extInst;
-        if (grammar.lookupExtInst(pInst->extInstType, textValue,
-                                      &extInst)) {
+        if (grammar.lookupExtInst(pInst->extInstType, textValue, &extInst)) {
           context->diagnostic() << "Invalid extended instruction name '"
-                            << textValue << "'.";
+                                << textValue << "'.";
           return SPV_ERROR_INVALID_TEXT;
         }
         pInst->words[pInst->wordCount++] = extInst->ext_inst;
@@ -233,7 +250,8 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar &grammar,
       if (error != SPV_SUCCESS) {
         if (error == SPV_ERROR_OUT_OF_MEMORY) return error;
         if (spvOperandIsOptional(type)) return SPV_FAILED_MATCH;
-        context->diagnostic() << "Invalid literal number '" << textValue << "'.";
+        context->diagnostic() << "Invalid literal number '" << textValue
+                              << "'.";
         return SPV_ERROR_INVALID_TEXT;
       }
       switch (literal.type) {
@@ -269,12 +287,14 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar &grammar,
             return SPV_ERROR_INVALID_TEXT;
         } break;
         case SPV_LITERAL_TYPE_STRING: {
-          context->diagnostic() << "Expected literal number, found literal string '"
-                            << textValue << "'.";
+          context->diagnostic()
+              << "Expected literal number, found literal string '" << textValue
+              << "'.";
           return SPV_FAILED_MATCH;
         } break;
         default:
-          context->diagnostic() << "Invalid literal number '" << textValue << "'";
+          context->diagnostic() << "Invalid literal number '" << textValue
+                                << "'";
           return SPV_ERROR_INVALID_TEXT;
       }
     } break;
@@ -285,12 +305,14 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar &grammar,
       if (error != SPV_SUCCESS) {
         if (error == SPV_ERROR_OUT_OF_MEMORY) return error;
         if (spvOperandIsOptional(type)) return SPV_FAILED_MATCH;
-        context->diagnostic() << "Invalid literal string '" << textValue << "'.";
+        context->diagnostic() << "Invalid literal string '" << textValue
+                              << "'.";
         return SPV_ERROR_INVALID_TEXT;
       }
       if (literal.type != SPV_LITERAL_TYPE_STRING) {
-        context->diagnostic() << "Expected literal string, found literal number '"
-                          << textValue << "'.";
+        context->diagnostic()
+            << "Expected literal string, found literal number '" << textValue
+            << "'.";
         return SPV_FAILED_MATCH;
       }
 
@@ -311,7 +333,7 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar &grammar,
       uint32_t value;
       if (grammar.parseMaskOperand(type, textValue, &value)) {
         context->diagnostic() << "Invalid " << spvOperandTypeStr(type) << " '"
-                          << textValue << "'.";
+                              << textValue << "'.";
         return SPV_ERROR_INVALID_TEXT;
       }
       if (auto error = context->binaryEncodeU32(value, pInst)) return error;
@@ -335,7 +357,8 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar &grammar,
                                  textValue, pInst, pExpectedOperands);
       }
       if (error) {
-        context->diagnostic() << "Invalid word following !<integer>: " << textValue;
+        context->diagnostic() << "Invalid word following !<integer>: "
+                              << textValue;
         return error;
       }
       if (pExpectedOperands->empty()) {
@@ -346,15 +369,14 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar &grammar,
       // NOTE: All non literal operands are handled here using the operand
       // table.
       spv_operand_desc entry;
-      if (grammar.lookupOperand(type, textValue, strlen(textValue),
-                                    &entry)) {
+      if (grammar.lookupOperand(type, textValue, strlen(textValue), &entry)) {
         context->diagnostic() << "Invalid " << spvOperandTypeStr(type) << " '"
-                          << textValue << "'.";
+                              << textValue << "'.";
         return SPV_ERROR_INVALID_TEXT;
       }
       if (context->binaryEncodeU32(entry->value, pInst)) {
         context->diagnostic() << "Invalid " << spvOperandTypeStr(type) << " '"
-                          << textValue << "'.";
+                              << textValue << "'.";
         return SPV_ERROR_INVALID_TEXT;
       }
 
@@ -373,7 +395,7 @@ namespace {
 /// instruction and returns SPV_SUCCESS.  Otherwise, returns an error code and
 /// leaves position pointing to the error in text.
 spv_result_t encodeInstructionStartingWithImmediate(
-    const libspirv::AssemblyGrammar &grammar,
+    const libspirv::AssemblyGrammar& grammar,
     libspirv::AssemblyContext* context, spv_instruction_t* pInst) {
   std::string firstWord;
   spv_position_t nextPosition = {};
@@ -383,17 +405,9 @@ spv_result_t encodeInstructionStartingWithImmediate(
     return error;
   }
 
-  assert(firstWord[0] == '!');
-  const char* begin = firstWord.data() + 1;
-  char* end = nullptr;
-  uint32_t immediateInt = strtoul(begin, &end, 0);
-  if ((begin + firstWord.size() - 1) != end) {
-    context->diagnostic() << "Invalid immediate integer '" << firstWord << "'.";
-    return SPV_ERROR_INVALID_TEXT;
+  if ((error = encodeImmediate(context, firstWord.c_str(), pInst))) {
+    return error;
   }
-  context->seekForward(firstWord.size());
-  pInst->words[0] = immediateInt;
-  pInst->wordCount = 1;
   while (context->advance() != SPV_END_OF_STREAM) {
     // A beginning of a new instruction means we're done.
     if (context->isStartOfNewInst()) return SPV_SUCCESS;
@@ -435,7 +449,7 @@ spv_result_t encodeInstructionStartingWithImmediate(
 /// @param[in,out] pPosition in the text stream
 ///
 /// @return result code
-spv_result_t spvTextEncodeOpcode(const libspirv::AssemblyGrammar &grammar,
+spv_result_t spvTextEncodeOpcode(const libspirv::AssemblyGrammar& grammar,
                                  libspirv::AssemblyContext* context,
                                  spv_assembly_syntax_format_t format,
                                  spv_instruction_t* pInst) {
@@ -473,9 +487,10 @@ spv_result_t spvTextEncodeOpcode(const libspirv::AssemblyGrammar &grammar,
 
     result_id = firstWord;
     if ('%' != result_id.front()) {
-      context->diagnostic() << "Expected <opcode> or <result-id> at the beginning "
-                           "of an instruction, found '"
-                        << result_id << "'.";
+      context->diagnostic()
+          << "Expected <opcode> or <result-id> at the beginning "
+             "of an instruction, found '"
+          << result_id << "'.";
       return SPV_ERROR_INVALID_TEXT;
     }
     result_id_position = context->position();
@@ -516,15 +531,16 @@ spv_result_t spvTextEncodeOpcode(const libspirv::AssemblyGrammar &grammar,
   spv_opcode_desc opcodeEntry;
   error = grammar.lookupOpcode(pInstName, &opcodeEntry);
   if (error) {
-    context->diagnostic() << "Invalid Opcode name '" << context->getWord() << "'";
+    context->diagnostic() << "Invalid Opcode name '" << context->getWord()
+                          << "'";
     return error;
   }
   if (SPV_ASSEMBLY_SYNTAX_FORMAT_ASSIGNMENT == format) {
     // If this instruction has <result-id>, check it follows AAF.
     if (opcodeEntry->hasResult && result_id.empty()) {
       context->diagnostic() << "Expected <result-id> at the beginning of an "
-                           "instruction, found '"
-                        << firstWord << "'.";
+                               "instruction, found '"
+                            << firstWord << "'.";
       return SPV_ERROR_INVALID_TEXT;
     }
   }
@@ -616,10 +632,11 @@ namespace {
 // Translates a given assembly language module into binary form.
 // If a diagnostic is generated, it is not yet marked as being
 // for a text-based input.
-spv_result_t spvTextToBinaryInternal(
-    const libspirv::AssemblyGrammar &grammar, const spv_text text,
-    spv_assembly_syntax_format_t format, spv_binary* pBinary,
-    spv_diagnostic* pDiagnostic) {
+spv_result_t spvTextToBinaryInternal(const libspirv::AssemblyGrammar& grammar,
+                                     const spv_text text,
+                                     spv_assembly_syntax_format_t format,
+                                     spv_binary* pBinary,
+                                     spv_diagnostic* pDiagnostic) {
   if (!pDiagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC;
   libspirv::AssemblyContext context(text, pDiagnostic);
   if (!text->str || !text->length) {
@@ -708,8 +725,7 @@ spv_result_t spvTextWithFormatToBinary(
     spv_binary* pBinary, spv_diagnostic* pDiagnostic) {
   spv_text_t text = {input_text, input_text_size};
 
-  libspirv::AssemblyGrammar grammar(operandTable, opcodeTable,
-                                             extInstTable);
+  libspirv::AssemblyGrammar grammar(operandTable, opcodeTable, extInstTable);
   spv_result_t result =
       spvTextToBinaryInternal(grammar, &text, format, pBinary, pDiagnostic);
   if (pDiagnostic && *pDiagnostic) (*pDiagnostic)->isTextSource = true;
index c0afd06..6e6d333 100644 (file)
@@ -68,7 +68,6 @@ TEST_F(TextToBinaryTest, ImmediateIntOperand) {
 using ImmediateIntTest = TextToBinaryTest;
 
 TEST_F(ImmediateIntTest, AnyWordInSimpleStatement) {
-  // TODO(deki): uncomment assertions below and make them pass.
   EXPECT_THAT(CompiledInstructions("!0x0004002B %a %b 123", kCAF),
               Eq(MakeInstruction(spv::OpConstant, {1, 2, 123})));
   EXPECT_THAT(CompiledInstructions("OpConstant !1 %b 123", kCAF),
@@ -115,7 +114,6 @@ TEST_F(ImmediateIntTest, OpCodeInAssignment) {
 // Literal integers after !<integer> are handled correctly.
 TEST_F(ImmediateIntTest, IntegerFollowingImmediate) {
   const SpirvVector original = CompiledInstructions("OpTypeInt %1 8 1", kCAF);
-  // TODO(deki): uncomment assertions below and make them pass.
   EXPECT_EQ(original, CompiledInstructions("!0x00040015 1 8 1", kCAF));
   EXPECT_EQ(original, CompiledInstructions("OpTypeInt !1 8 1", kCAF));
 
@@ -127,6 +125,7 @@ TEST_F(ImmediateIntTest, IntegerFollowingImmediate) {
   EXPECT_EQ(CompiledInstructions("OpConstant %10 %2 -123", kCAF),
             CompiledInstructions("OpConstant %10 !2 -123", kCAF));
 
+  // TODO(deki): uncomment assertions below and make them pass.
   // Hex value(s).
   // EXPECT_EQ(CompileSuccessfully("OpConstant %10 %1 0x12345678", kCAF),
   //           CompileSuccessfully("OpConstant %10 !1 0x12345678", kCAF));
@@ -261,14 +260,14 @@ TEST_F(ImmediateIntTest, ConsecutiveImmediateOpcodes) {
 
 // !<integer> followed by, eg, an enum or '=' or a random bareword.
 TEST_F(ImmediateIntTest, ForbiddenOperands) {
-// TODO(deki): uncomment assertions below and make them pass.
   EXPECT_THAT(CompileFailure("OpMemoryModel !0 OpenCL"), HasSubstr("OpenCL"));
   EXPECT_THAT(CompileFailure("!1 %0 = !2"), HasSubstr("="));
-#if 0
   // Immediate integers longer than one 32-bit word.
   EXPECT_THAT(CompileFailure("!5000000000"), HasSubstr("5000000000"));
-  EXPECT_THAT(CompileFailure("!0x00020049 !5000000000"), HasSubstr("5000000000"));
-#endif
+  EXPECT_THAT(CompileFailure("!999999999999999999"),
+              HasSubstr("999999999999999999"));
+  EXPECT_THAT(CompileFailure("!0x00020049 !5000000000"),
+              HasSubstr("5000000000"));
   EXPECT_THAT(CompileFailure("OpMemoryModel !0 random_bareword"),
               HasSubstr("random_bareword"));
 }