Fix https://github.com/KhronosGroup/SPIRV-Tools/issues/1130
authorDiego Novillo <dnovillo@google.com>
Fri, 22 Dec 2017 17:38:02 +0000 (12:38 -0500)
committerDavid Neto <dneto@google.com>
Fri, 22 Dec 2017 18:33:17 +0000 (13:33 -0500)
This addresses review feedback for the CCP implementation (which fixes
https://github.com/KhronosGroup/SPIRV-Tools/issues/889).

This adds more protection around the folding of instructions that would
not be supported by the folder.

source/opt/ccp_pass.cpp
source/opt/constants.h
source/opt/fold.cpp
source/opt/fold.h
source/opt/propagator.h

index aea3e4b..13be49e 100644 (file)
 //      Constant propagation with conditional branches,
 //      Wegman and Zadeck, ACM TOPLAS 13(2):181-210.
 #include "ccp_pass.h"
-
 #include "fold.h"
 #include "function.h"
 #include "module.h"
 #include "propagator.h"
 
+#include <algorithm>
+
 namespace spvtools {
 namespace opt {
 
@@ -97,24 +98,36 @@ SSAPropagator::PropStatus CCPPass::VisitAssignment(ir::Instruction* instr) {
 
   // Otherwise, see if the RHS of the assignment folds into a constant value.
   std::vector<uint32_t> cst_val_ids;
-  for (uint32_t i = 0; i < instr->NumInOperands(); i++) {
-    uint32_t op_id = instr->GetSingleWordInOperand(i);
-    auto it = values_.find(op_id);
-    if (it != values_.end()) {
-      cst_val_ids.push_back(it->second);
-    } else {
-      break;
+  bool missing_constants = false;
+  instr->ForEachInId([this, &cst_val_ids, &missing_constants](uint32_t* op_id) {
+    auto it = values_.find(*op_id);
+    if (it == values_.end()) {
+      missing_constants = true;
+      return;
     }
-  }
+    cst_val_ids.push_back(it->second);
+  });
 
   // If we did not find a constant value for every operand in the instruction,
   // do not bother folding it.  Indicate that this instruction does not produce
   // an interesting value for now.
-  auto constants = const_mgr_->GetConstantsFromIds(cst_val_ids);
-  if (constants.size() == 0) {
+  if (missing_constants) {
     return SSAPropagator::kNotInteresting;
   }
 
+  auto constants = const_mgr_->GetConstantsFromIds(cst_val_ids);
+  assert(constants.size() != 0 && "Found undeclared constants");
+
+  // If any of the constants are not supported by the folder, we will not be
+  // able to produce a constant out of this instruction.  Consider it varying
+  // in that case.
+  if (!std::all_of(constants.begin(), constants.end(),
+                   [](const analysis::Constant* cst) {
+                     return IsFoldableConstant(cst);
+                   })) {
+    return SSAPropagator::kVarying;
+  }
+
   // Otherwise, fold the instruction with all the operands to produce a new
   // constant.
   uint32_t result_val = FoldScalars(instr->opcode(), constants);
@@ -129,8 +142,9 @@ SSAPropagator::PropStatus CCPPass::VisitAssignment(ir::Instruction* instr) {
 SSAPropagator::PropStatus CCPPass::VisitBranch(ir::Instruction* instr,
                                                ir::BasicBlock** dest_bb) const {
   assert(instr->IsBranch() && "Expected a branch instruction.");
-  uint32_t dest_label = 0;
 
+  *dest_bb = nullptr;
+  uint32_t dest_label = 0;
   if (instr->opcode() == SpvOpBranch) {
     // An unconditional jump always goes to its unique destination.
     dest_label = instr->GetSingleWordInOperand(0);
@@ -142,7 +156,6 @@ SSAPropagator::PropStatus CCPPass::VisitBranch(ir::Instruction* instr,
     auto it = values_.find(pred_id);
     if (it == values_.end()) {
       // The predicate has an unknown value, either branch could be taken.
-      *dest_bb = nullptr;
       return SSAPropagator::kVarying;
     }
 
@@ -159,11 +172,15 @@ SSAPropagator::PropStatus CCPPass::VisitBranch(ir::Instruction* instr,
     // which of the target literals it matches.  The branch associated with that
     // literal is the taken branch.
     assert(instr->opcode() == SpvOpSwitch);
+    if (instr->GetOperand(0).words.size() != 1) {
+      // If the selector is wider than 32-bits, return varying. TODO(dnovillo):
+      // Add support for wider constants.
+      return SSAPropagator::kVarying;
+    }
     uint32_t select_id = instr->GetSingleWordOperand(0);
     auto it = values_.find(select_id);
     if (it == values_.end()) {
       // The selector has an unknown value, any of the branches could be taken.
-      *dest_bb = nullptr;
       return SSAPropagator::kVarying;
     }
 
index ac5708f..70d4c5d 100644 (file)
@@ -356,8 +356,8 @@ class ConstantManager {
       const Type* type, const std::vector<uint32_t>& literal_words_or_ids);
 
   // Gets or creates a Constant instance to hold the constant value of the given
-  // instruction. It returns a pointer to the Constant's defining instruction or
-  // nullptr if it could not create the constant.
+  // instruction. It returns a pointer to a Constant instance or nullptr if it
+  // could not create the constant.
   const Constant* GetConstantFromInst(ir::Instruction* inst);
 
   // Gets or creates a constant defining instruction for the given Constant |c|.
index 77e0517..96f6640 100644 (file)
@@ -173,6 +173,8 @@ uint32_t OperateWords(SpvOp opcode,
 // not accepted in this function.
 uint32_t FoldScalars(SpvOp opcode,
                      const std::vector<const analysis::Constant*>& operands) {
+  assert(IsFoldableOpcode(opcode) &&
+         "Unhandled instruction opcode in FoldScalars");
   std::vector<uint32_t> operand_values_in_raw_words;
   for (const auto& operand : operands) {
     if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
@@ -192,14 +194,11 @@ uint32_t FoldScalars(SpvOp opcode,
   return OperateWords(opcode, operand_values_in_raw_words);
 }
 
-// Returns the result of performing an operation over constant vectors. This
-// function iterates through the given vector type constant operands and
-// calculates the result for each element of the result vector to return.
-// Vectors with longer than 32-bit scalar components are not accepted in this
-// function.
 std::vector<uint32_t> FoldVectors(
     SpvOp opcode, uint32_t num_dims,
     const std::vector<const analysis::Constant*>& operands) {
+  assert(IsFoldableOpcode(opcode) &&
+         "Unhandled instruction opcode in FoldVectors");
   std::vector<uint32_t> result;
   for (uint32_t d = 0; d < num_dims; d++) {
     std::vector<uint32_t> operand_values_for_one_dimension;
@@ -282,5 +281,13 @@ bool IsFoldableOpcode(SpvOp opcode) {
   }
 }
 
+bool IsFoldableConstant(const analysis::Constant* cst) {
+  // Currently supported constants are 32-bit values or null constants.
+  if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
+    return scalar->words().size() == 1;
+  else
+    return cst->AsNullConstant() != nullptr;
+}
+
 }  // namespace opt
 }  // namespace spvtools
index eb94cf2..52954a5 100644 (file)
 namespace spvtools {
 namespace opt {
 
+// Returns the result of folding a scalar instruction with the given |opcode|
+// and |operands|. Each entry in |operands| is a pointer to an
+// analysis::Constant instance, which should've been created with the constant
+// manager (See IRContext::get_constant_mgr).
+//
+// It is an error to call this function with an opcode that does not pass the
+// IsFoldableOpcode test. If any error occurs during folding, the folder will
+// faill with a call to assert.
 uint32_t FoldScalars(SpvOp opcode,
                      const std::vector<const analysis::Constant*>& operands);
 
+// Returns the result of performing an operation with the given |opcode| over
+// constant vectors with |num_dims| dimensions.  Each entry in |operands| is a
+// pointer to an analysis::Constant instance, which should've been created with
+// the constant manager (See IRContext::get_constant_mgr).
+//
+// This function iterates through the given vector type constant operands and
+// calculates the result for each element of the result vector to return.
+// Vectors with longer than 32-bit scalar components are not accepted in this
+// function.
+//
+// It is an error to call this function with an opcode that does not pass the
+// IsFoldableOpcode test. If any error occurs during folding, the folder will
+// faill with a call to assert.
 std::vector<uint32_t> FoldVectors(
     SpvOp opcode, uint32_t num_dims,
     const std::vector<const analysis::Constant*>& operands);
 
+// Returns true if |opcode| represents an operation handled by FoldScalars or
+// FoldVectors.
 bool IsFoldableOpcode(SpvOp opcode);
 
+// Returns true if |cst| is supported by FoldScalars and FoldVectors.
+bool IsFoldableConstant(const analysis::Constant* cst);
+
 }  // namespace opt
 }  // namespace spvtools
 
index be06ff0..0c9d18a 100644 (file)
@@ -192,7 +192,8 @@ class SSAPropagator {
   bool Run(ir::Function* fn);
 
   // Returns true if the |i|th argument for |phi| comes through a CFG edge that
-  // has been marked executable.
+  // has been marked executable. |i| should be an index value accepted by
+  // Instruction::GetSingleWordOperand.
   bool IsPhiArgExecutable(ir::Instruction* phi, uint32_t i) const;
 
  private: