Implement SSA CCP (SSA Conditional Constant Propagation).
authorDiego Novillo <dnovillo@google.com>
Tue, 5 Dec 2017 16:39:25 +0000 (11:39 -0500)
committerDiego Novillo <dnovillo@google.com>
Thu, 21 Dec 2017 19:29:45 +0000 (14:29 -0500)
This implements the conditional constant propagation pass proposed in

Constant propagation with conditional branches,
Wegman and Zadeck, ACM TOPLAS 13(2):181-210.

The main logic resides in CCPPass::VisitInstruction.  Instruction that
may produce a constant value are evaluated with the constant folder. If
they produce a new constant, the instruction is considered interesting.
Otherwise, it's considered varying (for unfoldable instructions) or
just not interesting (when not enough operands have a constant value).

The other main piece of logic is in CCPPass::VisitBranch.  This
evaluates the selector of the branch.  When it's found to be a known
value, it computes the destination basic block and sets it.  This tells
the propagator which branches to follow.

The patch required extensions to the constant manager as well. Instead
of hashing the Constant pointers, this patch changes the constant pool
to hash the contents of the Constant.  This allows the lookups to be
done using the actual values of the Constant, preventing duplicate
definitions.

20 files changed:
Android.mk
include/spirv-tools/optimizer.hpp
source/opt/CMakeLists.txt
source/opt/ccp_pass.cpp [new file with mode: 0644]
source/opt/ccp_pass.h [new file with mode: 0644]
source/opt/constants.cpp
source/opt/constants.h
source/opt/fold.cpp
source/opt/fold.h
source/opt/fold_spec_constant_op_and_composite_pass.cpp
source/opt/instruction.cpp
source/opt/instruction.h
source/opt/optimizer.cpp
source/opt/passes.h
source/opt/propagator.cpp
source/opt/propagator.h
test/opt/CMakeLists.txt
test/opt/ccp_test.cpp [new file with mode: 0644]
test/opt/fold_spec_const_op_composite_test.cpp
tools/opt/opt.cpp

index a0da131..08b9080 100644 (file)
@@ -59,6 +59,7 @@ SPVTOOLS_OPT_SRC_FILES := \
                source/opt/build_module.cpp \
                source/opt/cfg.cpp \
                source/opt/cfg_cleanup_pass.cpp \
+               source/opt/ccp_pass.cpp \
                source/opt/compact_ids_pass.cpp \
                source/opt/common_uniform_elim_pass.cpp \
                source/opt/constants.cpp \
index 5d2e306..afc544a 100644 (file)
@@ -452,6 +452,18 @@ Optimizer::PassToken CreateScalarReplacementPass();
 // used in only one function.  Those variables are moved to the function storage
 // class in the function that they are used.
 Optimizer::PassToken CreatePrivateToLocalPass();
+
+// Creates a conditional constant propagation (CCP) pass.
+// This pass implements the SSA-CCP algorithm in
+//
+//      Constant propagation with conditional branches,
+//      Wegman and Zadeck, ACM TOPLAS 13(2):181-210.
+//
+// Constant values in expressions and conditional jumps are folded and
+// simplified. This may reduce code size by removing never executed jump targets
+// and computations with constant operands.
+Optimizer::PassToken CreateCCPPass();
+
 }  // namespace spvtools
 
 #endif  // SPIRV_TOOLS_OPTIMIZER_HPP_
index 0da8961..0b49842 100644 (file)
@@ -16,6 +16,7 @@ add_library(SPIRV-Tools-opt
   basic_block.h
   block_merge_pass.h
   build_module.h
+  ccp_pass.h
   cfg_cleanup_pass.h
   cfg.h
   common_uniform_elim_pass.h
@@ -74,6 +75,7 @@ add_library(SPIRV-Tools-opt
   basic_block.cpp
   block_merge_pass.cpp
   build_module.cpp
+  ccp_pass.cpp
   cfg_cleanup_pass.cpp
   cfg.cpp
   common_uniform_elim_pass.cpp
diff --git a/source/opt/ccp_pass.cpp b/source/opt/ccp_pass.cpp
new file mode 100644 (file)
index 0000000..aea3e4b
--- /dev/null
@@ -0,0 +1,264 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file implements conditional constant propagation as described in
+//
+//      Constant propagation with conditional branches,
+//      Wegman and Zadeck, ACM TOPLAS 13(2):181-210.
+#include "ccp_pass.h"
+
+#include "fold.h"
+#include "function.h"
+#include "module.h"
+#include "propagator.h"
+
+namespace spvtools {
+namespace opt {
+
+SSAPropagator::PropStatus CCPPass::VisitPhi(ir::Instruction* phi) {
+  uint32_t meet_val_id = 0;
+
+  // Implement the lattice meet operation. The result of this Phi instruction is
+  // interesting only if the meet operation over arguments coming through
+  // executable edges yields the same constant value.
+  for (uint32_t i = 2; i < phi->NumOperands(); i += 2) {
+    if (!propagator_->IsPhiArgExecutable(phi, i)) {
+      // Ignore arguments coming through non-executable edges.
+      continue;
+    }
+    uint32_t phi_arg_id = phi->GetSingleWordOperand(i);
+    auto it = values_.find(phi_arg_id);
+    if (it != values_.end()) {
+      // We found an argument with a constant value.  Apply the meet operation
+      // with the previous arguments.
+      if (meet_val_id == 0) {
+        // This is the first argument we find.  Initialize the result to its
+        // constant value id.
+        meet_val_id = it->second;
+      } else if (it->second == meet_val_id) {
+        // The argument is the same constant value already computed. Continue
+        // looking.
+        continue;
+      } else {
+        // We found another constant value, but it is different from the
+        // previous computed meet value.  This Phi will never be constant.
+        return SSAPropagator::kVarying;
+      }
+    } else {
+      // If any argument is not a constant, the Phi produces nothing
+      // interesting for now. The propagator will callback again, if needed.
+      return SSAPropagator::kNotInteresting;
+    }
+  }
+
+  // If there are no incoming executable edges, the meet ID will still be 0. In
+  // that case, return not interesting to evaluate the Phi node again.
+  if (meet_val_id == 0) {
+    return SSAPropagator::kNotInteresting;
+  }
+
+  // All the operands have the same constant value represented by |meet_val_id|.
+  // Set the Phi's result to that value and declare it interesting.
+  values_[phi->result_id()] = meet_val_id;
+  return SSAPropagator::kInteresting;
+}
+
+SSAPropagator::PropStatus CCPPass::VisitAssignment(ir::Instruction* instr) {
+  assert(instr->result_id() != 0 &&
+         "Expecting an instruction that produces a result");
+
+  // If this is a copy operation, and the RHS is a known constant, assign its
+  // value to the LHS.
+  if (instr->opcode() == SpvOpCopyObject) {
+    uint32_t rhs_id = instr->GetSingleWordInOperand(0);
+    auto it = values_.find(rhs_id);
+    if (it != values_.end()) {
+      values_[instr->result_id()] = it->second;
+      return SSAPropagator::kInteresting;
+    }
+    return SSAPropagator::kNotInteresting;
+  }
+
+  // Instructions with a RHS that cannot produce a constant are always varying.
+  if (!instr->IsFoldable()) {
+    return SSAPropagator::kVarying;
+  }
+
+  // Otherwise, see if the RHS of the assignment folds into a constant value.
+  std::vector<uint32_t> cst_val_ids;
+  for (uint32_t i = 0; i < instr->NumInOperands(); i++) {
+    uint32_t op_id = instr->GetSingleWordInOperand(i);
+    auto it = values_.find(op_id);
+    if (it != values_.end()) {
+      cst_val_ids.push_back(it->second);
+    } else {
+      break;
+    }
+  }
+
+  // If we did not find a constant value for every operand in the instruction,
+  // do not bother folding it.  Indicate that this instruction does not produce
+  // an interesting value for now.
+  auto constants = const_mgr_->GetConstantsFromIds(cst_val_ids);
+  if (constants.size() == 0) {
+    return SSAPropagator::kNotInteresting;
+  }
+
+  // Otherwise, fold the instruction with all the operands to produce a new
+  // constant.
+  uint32_t result_val = FoldScalars(instr->opcode(), constants);
+  const analysis::Constant* result_const =
+      const_mgr_->GetConstant(const_mgr_->GetType(instr), {result_val});
+  ir::Instruction* const_decl =
+      const_mgr_->GetDefiningInstruction(result_const);
+  values_[instr->result_id()] = const_decl->result_id();
+  return SSAPropagator::kInteresting;
+}
+
+SSAPropagator::PropStatus CCPPass::VisitBranch(ir::Instruction* instr,
+                                               ir::BasicBlock** dest_bb) const {
+  assert(instr->IsBranch() && "Expected a branch instruction.");
+  uint32_t dest_label = 0;
+
+  if (instr->opcode() == SpvOpBranch) {
+    // An unconditional jump always goes to its unique destination.
+    dest_label = instr->GetSingleWordInOperand(0);
+  } else if (instr->opcode() == SpvOpBranchConditional) {
+    // For a conditional branch, determine whether the predicate selector has a
+    // known value in |values_|.  If it does, set the destination block
+    // according to the selector's boolean value.
+    uint32_t pred_id = instr->GetSingleWordOperand(0);
+    auto it = values_.find(pred_id);
+    if (it == values_.end()) {
+      // The predicate has an unknown value, either branch could be taken.
+      *dest_bb = nullptr;
+      return SSAPropagator::kVarying;
+    }
+
+    // Get the constant value for the predicate selector from the value table.
+    // Use it to decide which branch will be taken.
+    uint32_t pred_val_id = it->second;
+    const analysis::Constant* c = const_mgr_->FindDeclaredConstant(pred_val_id);
+    assert(c && "Expected to find a constant declaration for a known value.");
+    const analysis::BoolConstant* val = c->AsBoolConstant();
+    dest_label = val->value() ? instr->GetSingleWordOperand(1)
+                              : instr->GetSingleWordOperand(2);
+  } else {
+    // For an OpSwitch, extract the value taken by the switch selector and check
+    // which of the target literals it matches.  The branch associated with that
+    // literal is the taken branch.
+    assert(instr->opcode() == SpvOpSwitch);
+    uint32_t select_id = instr->GetSingleWordOperand(0);
+    auto it = values_.find(select_id);
+    if (it == values_.end()) {
+      // The selector has an unknown value, any of the branches could be taken.
+      *dest_bb = nullptr;
+      return SSAPropagator::kVarying;
+    }
+
+    // Get the constant value for the selector from the value table. Use it to
+    // decide which branch will be taken.
+    uint32_t select_val_id = it->second;
+    const analysis::Constant* c =
+        const_mgr_->FindDeclaredConstant(select_val_id);
+    assert(c && "Expected to find a constant declaration for a known value.");
+    const analysis::IntConstant* val = c->AsIntConstant();
+
+    // Start assuming that the selector will take the default value;
+    dest_label = instr->GetSingleWordOperand(1);
+    for (uint32_t i = 2; i < instr->NumOperands(); i += 2) {
+      if (val->words()[0] == instr->GetSingleWordOperand(i)) {
+        dest_label = instr->GetSingleWordOperand(i + 1);
+        break;
+      }
+    }
+  }
+
+  assert(dest_label && "Destination label should be set at this point.");
+  *dest_bb = context()->cfg()->block(dest_label);
+  return SSAPropagator::kInteresting;
+}
+
+SSAPropagator::PropStatus CCPPass::VisitInstruction(ir::Instruction* instr,
+                                                    ir::BasicBlock** dest_bb) {
+  *dest_bb = nullptr;
+  if (instr->opcode() == SpvOpPhi) {
+    return VisitPhi(instr);
+  } else if (instr->IsBranch()) {
+    return VisitBranch(instr, dest_bb);
+  } else if (instr->result_id()) {
+    return VisitAssignment(instr);
+  }
+  return SSAPropagator::kVarying;
+}
+
+bool CCPPass::ReplaceValues() {
+  bool retval = false;
+  for (const auto& it : values_) {
+    uint32_t id = it.first;
+    uint32_t cst_id = it.second;
+    if (id != cst_id) {
+      retval |= context()->ReplaceAllUsesWith(id, cst_id);
+    }
+  }
+  return retval;
+}
+
+bool CCPPass::PropagateConstants(ir::Function* fp) {
+  const auto visit_fn = [this](ir::Instruction* instr,
+                               ir::BasicBlock** dest_bb) {
+    return VisitInstruction(instr, dest_bb);
+  };
+
+  InsertPhiInstructions(fp);
+  propagator_ =
+      std::unique_ptr<SSAPropagator>(new SSAPropagator(context(), visit_fn));
+  if (propagator_->Run(fp)) {
+    return ReplaceValues();
+  }
+
+  return false;
+}
+
+void CCPPass::Initialize(ir::IRContext* c) {
+  InitializeProcessing(c);
+
+  const_mgr_ = context()->get_constant_mgr();
+
+  // Populate the constant table with values from constant declarations in the
+  // module.  The values of each OpConstant declaration is the identity
+  // assignment (i.e., each constant is its own value).
+  for (const auto& inst : c->module()->GetConstants()) {
+    values_[inst->result_id()] = inst->result_id();
+    if (!const_mgr_->MapInst(inst)) {
+      assert(false &&
+             "Could not map a new constant value to its defining instruction");
+    }
+  }
+}
+
+Pass::Status CCPPass::Process(ir::IRContext* c) {
+  Initialize(c);
+
+  // Process all entry point functions.
+  ProcessFunction pfn = [this](ir::Function* fp) {
+    return PropagateConstants(fp);
+  };
+  bool modified = ProcessReachableCallTree(pfn, context());
+  return modified ? Pass::Status::SuccessWithChange
+                  : Pass::Status::SuccessWithoutChange;
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/ccp_pass.h b/source/opt/ccp_pass.h
new file mode 100644 (file)
index 0000000..ba54dc7
--- /dev/null
@@ -0,0 +1,87 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_CCP_PASS_H_
+#define LIBSPIRV_OPT_CCP_PASS_H_
+
+#include "constants.h"
+#include "function.h"
+#include "ir_context.h"
+#include "mem_pass.h"
+#include "module.h"
+#include "propagator.h"
+
+namespace spvtools {
+namespace opt {
+
+class CCPPass : public MemPass {
+ public:
+  CCPPass() = default;
+  const char* name() const override { return "ccp"; }
+  Status Process(ir::IRContext* c) override;
+
+ private:
+  // Initializes the pass.
+  void Initialize(ir::IRContext* c);
+
+  // Runs constant propagation on the given function |fp|. Returns true if any
+  // constants were propagated and the IR modified.
+  bool PropagateConstants(ir::Function* fp);
+
+  // Visits a single instruction |instr|.  If the instruction is a conditional
+  // branch that always jumps to the same basic block, it sets the destination
+  // block in |dest_bb|.
+  SSAPropagator::PropStatus VisitInstruction(ir::Instruction* instr,
+                                             ir::BasicBlock** dest_bb);
+
+  // Visits an OpPhi instruction |phi|. This applies the meet operator for the
+  // CCP lattice. Essentially, if all the operands in |phi| have the same
+  // constant value C, the result for |phi| gets assigned the value C.
+  SSAPropagator::PropStatus VisitPhi(ir::Instruction* phi);
+
+  // Visits an SSA assignment instruction |instr|.  If the RHS of |instr| folds
+  // into a constant value C, then the LHS of |instr| is assigned the value C in
+  // |values_|.
+  SSAPropagator::PropStatus VisitAssignment(ir::Instruction* instr);
+
+  // Visits a branch instruction |instr|. If the branch is conditional
+  // (OpBranchConditional or OpSwitch), and the value of its selector is known,
+  // |dest_bb| will be set to the corresponding destination block. Unconditional
+  // branches always set |dest_bb| to the single destination block.
+  SSAPropagator::PropStatus VisitBranch(ir::Instruction* instr,
+                                        ir::BasicBlock** dest_bb) const;
+
+  // Replaces all operands used in |fp| with the corresponding constant values
+  // in |values_|.  Returns true if any operands were replaced, and false
+  // otherwise.
+  bool ReplaceValues();
+
+  // Constant manager for the parent IR context.  Used to record new constants
+  // generated during propagation.
+  analysis::ConstantManager* const_mgr_;
+
+  // Constant value table.  Each entry <id, const_decl_id> in this map
+  // represents the compile-time constant value for |id| as declared by
+  // |const_decl_id|. Each |const_decl_id| in this table is an OpConstant
+  // declaration for the current module.
+  std::unordered_map<uint32_t, uint32_t> values_;
+
+  // Propagator engine used.
+  std::unique_ptr<SSAPropagator> propagator_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif
index c2b1a61..d4224d2 100644 (file)
@@ -22,25 +22,15 @@ namespace spvtools {
 namespace opt {
 namespace analysis {
 
-analysis::Type* ConstantManager::GetType(const ir::Instruction* inst) const {
+Type* ConstantManager::GetType(const ir::Instruction* inst) const {
   return context()->get_type_mgr()->GetType(inst->type_id());
 }
 
-uint32_t ConstantManager::FindRecordedConstant(
-    const analysis::Constant* c) const {
-  auto iter = const_val_to_id_.find(c);
-  if (iter == const_val_to_id_.end()) {
-    return 0;
-  } else {
-    return iter->second;
-  }
-}
-
-std::vector<const analysis::Constant*> ConstantManager::GetConstantsFromIds(
+std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
     const std::vector<uint32_t>& ids) const {
-  std::vector<const analysis::Constant*> constants;
+  std::vector<const Constant*> constants;
   for (uint32_t id : ids) {
-    if (analysis::Constant* c = FindRecordedConstant(id)) {
+    if (const Constant* c = FindDeclaredConstant(id)) {
       constants.push_back(c);
     } else {
       return {};
@@ -50,51 +40,52 @@ std::vector<const analysis::Constant*> ConstantManager::GetConstantsFromIds(
 }
 
 ir::Instruction* ConstantManager::BuildInstructionAndAddToModule(
-    std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos,
+    const Constant* new_const, ir::Module::inst_iterator* pos,
     uint32_t type_id) {
-  analysis::Constant* new_const = c.get();
   uint32_t new_id = context()->TakeNextId();
-  const_val_to_id_[new_const] = new_id;
-  id_to_const_val_[new_id] = std::move(c);
   auto new_inst = CreateInstruction(new_id, new_const, type_id);
-  if (!new_inst) return nullptr;
+  if (!new_inst) {
+    return nullptr;
+  }
   auto* new_inst_ptr = new_inst.get();
   *pos = pos->InsertBefore(std::move(new_inst));
   ++(*pos);
   context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr);
+  MapConstantToInst(new_const, new_inst_ptr);
   return new_inst_ptr;
 }
 
-analysis::Constant* ConstantManager::FindRecordedConstant(uint32_t id) const {
-  auto iter = id_to_const_val_.find(id);
-  if (iter == id_to_const_val_.end()) {
-    return nullptr;
+ir::Instruction* ConstantManager::GetDefiningInstruction(
+    const Constant* c, ir::Module::inst_iterator* pos) {
+  uint32_t decl_id = FindDeclaredConstant(c);
+  if (decl_id == 0) {
+    auto iter = context()->types_values_end();
+    if (pos == nullptr) pos = &iter;
+    return BuildInstructionAndAddToModule(c, pos);
   } else {
-    return iter->second.get();
+    return context()->get_def_use_mgr()->GetDef(decl_id);
   }
 }
 
-std::unique_ptr<analysis::Constant> ConstantManager::CreateConstant(
-    const analysis::Type* type,
-    const std::vector<uint32_t>& literal_words_or_ids) const {
-  std::unique_ptr<analysis::Constant> new_const;
+const Constant* ConstantManager::CreateConstant(
+    const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const {
   if (literal_words_or_ids.size() == 0) {
     // Constant declared with OpConstantNull
-    return MakeUnique<analysis::NullConstant>(type);
+    return new NullConstant(type);
   } else if (auto* bt = type->AsBool()) {
     assert(literal_words_or_ids.size() == 1 &&
            "Bool constant should be declared with one operand");
-    return MakeUnique<analysis::BoolConstant>(bt, literal_words_or_ids.front());
+    return new BoolConstant(bt, literal_words_or_ids.front());
   } else if (auto* it = type->AsInteger()) {
-    return MakeUnique<analysis::IntConstant>(it, literal_words_or_ids);
+    return new IntConstant(it, literal_words_or_ids);
   } else if (auto* ft = type->AsFloat()) {
-    return MakeUnique<analysis::FloatConstant>(ft, literal_words_or_ids);
+    return new FloatConstant(ft, literal_words_or_ids);
   } else if (auto* vt = type->AsVector()) {
     auto components = GetConstantsFromIds(literal_words_or_ids);
     if (components.empty()) return nullptr;
     // All components of VectorConstant must be of type Bool, Integer or Float.
     if (!std::all_of(components.begin(), components.end(),
-                     [](const analysis::Constant* c) {
+                     [](const Constant* c) {
                        if (c->type()->AsBool() || c->type()->AsInteger() ||
                            c->type()->AsFloat()) {
                          return true;
@@ -106,29 +97,27 @@ std::unique_ptr<analysis::Constant> ConstantManager::CreateConstant(
     // All components of VectorConstant must be in the same type.
     const auto* component_type = components.front()->type();
     if (!std::all_of(components.begin(), components.end(),
-                     [&component_type](const analysis::Constant* c) {
+                     [&component_type](const Constant* c) {
                        if (c->type() == component_type) return true;
                        return false;
                      }))
       return nullptr;
-    return MakeUnique<analysis::VectorConstant>(vt, components);
+    return new VectorConstant(vt, components);
   } else if (auto* st = type->AsStruct()) {
     auto components = GetConstantsFromIds(literal_words_or_ids);
     if (components.empty()) return nullptr;
-    return MakeUnique<analysis::StructConstant>(st, components);
+    return new StructConstant(st, components);
   } else if (auto* at = type->AsArray()) {
     auto components = GetConstantsFromIds(literal_words_or_ids);
     if (components.empty()) return nullptr;
-    return MakeUnique<analysis::ArrayConstant>(at, components);
+    return new ArrayConstant(at, components);
   } else {
     return nullptr;
   }
 }
 
-std::unique_ptr<analysis::Constant> ConstantManager::CreateConstantFromInst(
-    ir::Instruction* inst) const {
+const Constant* ConstantManager::GetConstantFromInst(ir::Instruction* inst) {
   std::vector<uint32_t> literal_words_or_ids;
-  std::unique_ptr<analysis::Constant> new_const;
 
   // Collect the constant defining literals or component ids.
   for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
@@ -138,7 +127,7 @@ std::unique_ptr<analysis::Constant> ConstantManager::CreateConstantFromInst(
   }
 
   switch (inst->opcode()) {
-    // OpConstant{True|Flase} have the value embedded in the opcode. So they
+    // OpConstant{True|False} have the value embedded in the opcode. So they
     // are not handled by the for-loop above. Here we add the value explicitly.
     case SpvOp::SpvOpConstantTrue:
       literal_words_or_ids.push_back(true);
@@ -154,35 +143,36 @@ std::unique_ptr<analysis::Constant> ConstantManager::CreateConstantFromInst(
     default:
       return nullptr;
   }
-  return CreateConstant(GetType(inst), literal_words_or_ids);
+
+  return GetConstant(GetType(inst), literal_words_or_ids);
 }
 
 std::unique_ptr<ir::Instruction> ConstantManager::CreateInstruction(
-    uint32_t id, analysis::Constant* c, uint32_t type_id) const {
+    uint32_t id, const Constant* c, uint32_t type_id) const {
   uint32_t type =
       (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id;
   if (c->AsNullConstant()) {
     return MakeUnique<ir::Instruction>(context(), SpvOp::SpvOpConstantNull,
                                        type, id,
                                        std::initializer_list<ir::Operand>{});
-  } else if (analysis::BoolConstant* bc = c->AsBoolConstant()) {
+  } else if (const BoolConstant* bc = c->AsBoolConstant()) {
     return MakeUnique<ir::Instruction>(
         context(),
         bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
         type, id, std::initializer_list<ir::Operand>{});
-  } else if (analysis::IntConstant* ic = c->AsIntConstant()) {
+  } else if (const IntConstant* ic = c->AsIntConstant()) {
     return MakeUnique<ir::Instruction>(
         context(), SpvOp::SpvOpConstant, type, id,
         std::initializer_list<ir::Operand>{ir::Operand(
             spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
             ic->words())});
-  } else if (analysis::FloatConstant* fc = c->AsFloatConstant()) {
+  } else if (const FloatConstant* fc = c->AsFloatConstant()) {
     return MakeUnique<ir::Instruction>(
         context(), SpvOp::SpvOpConstant, type, id,
         std::initializer_list<ir::Operand>{ir::Operand(
             spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
             fc->words())});
-  } else if (analysis::CompositeConstant* cc = c->AsCompositeConstant()) {
+  } else if (const CompositeConstant* cc = c->AsCompositeConstant()) {
     return CreateCompositeInstruction(id, cc, type_id);
   } else {
     return nullptr;
@@ -190,11 +180,10 @@ std::unique_ptr<ir::Instruction> ConstantManager::CreateInstruction(
 }
 
 std::unique_ptr<ir::Instruction> ConstantManager::CreateCompositeInstruction(
-    uint32_t result_id, analysis::CompositeConstant* cc,
-    uint32_t type_id) const {
+    uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const {
   std::vector<ir::Operand> operands;
-  for (const analysis::Constant* component_const : cc->GetComponents()) {
-    uint32_t id = FindRecordedConstant(component_const);
+  for (const Constant* component_const : cc->GetComponents()) {
+    uint32_t id = FindDeclaredConstant(component_const);
     if (id == 0) {
       // Cannot get the id of the component constant, while all components
       // should have been added to the module prior to the composite constant.
@@ -210,6 +199,12 @@ std::unique_ptr<ir::Instruction> ConstantManager::CreateCompositeInstruction(
                                      type, result_id, std::move(operands));
 }
 
+const Constant* ConstantManager::GetConstant(
+    const Type* type, const std::vector<uint32_t>& literal_words_or_ids) {
+  auto cst = CreateConstant(type, literal_words_or_ids);
+  return cst ? RegisterConstant(cst) : nullptr;
+}
+
 }  // namespace analysis
 }  // namespace opt
 }  // namespace spvtools
index eda26b4..ac5708f 100644 (file)
 #ifndef LIBSPIRV_OPT_CONSTANTS_H_
 #define LIBSPIRV_OPT_CONSTANTS_H_
 
+#include <cinttypes>
 #include <memory>
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -76,13 +79,13 @@ class Constant {
   virtual const ArrayConstant* AsArrayConstant() const { return nullptr; }
   virtual const NullConstant* AsNullConstant() const { return nullptr; }
 
-  const analysis::Type* type() const { return type_; }
+  const Type* type() const { return type_; }
 
  protected:
-  Constant(const analysis::Type* ty) : type_(ty) {}
+  Constant(const Type* ty) : type_(ty) {}
 
   // The type of this constant.
-  const analysis::Type* type_;
+  const Type* type_;
 };
 
 // Abstract class for scalar type constants.
@@ -96,9 +99,9 @@ class ScalarConstant : public Constant {
   virtual const std::vector<uint32_t>& words() const { return words_; }
 
  protected:
-  ScalarConstant(const analysis::Type* ty, const std::vector<uint32_t>& w)
+  ScalarConstant(const Type* ty, const std::vector<uint32_t>& w)
       : Constant(ty), words_(w) {}
-  ScalarConstant(const analysis::Type* ty, std::vector<uint32_t>&& w)
+  ScalarConstant(const Type* ty, std::vector<uint32_t>&& w)
       : Constant(ty), words_(std::move(w)) {}
   std::vector<uint32_t> words_;
 };
@@ -106,9 +109,9 @@ class ScalarConstant : public Constant {
 // Integer type constant.
 class IntConstant : public ScalarConstant {
  public:
-  IntConstant(const analysis::Integer* ty, const std::vector<uint32_t>& w)
+  IntConstant(const Integer* ty, const std::vector<uint32_t>& w)
       : ScalarConstant(ty, w) {}
-  IntConstant(const analysis::Integer* ty, std::vector<uint32_t>&& w)
+  IntConstant(const Integer* ty, std::vector<uint32_t>&& w)
       : ScalarConstant(ty, std::move(w)) {}
 
   IntConstant* AsIntConstant() override { return this; }
@@ -126,9 +129,9 @@ class IntConstant : public ScalarConstant {
 // Float type constant.
 class FloatConstant : public ScalarConstant {
  public:
-  FloatConstant(const analysis::Float* ty, const std::vector<uint32_t>& w)
+  FloatConstant(const Float* ty, const std::vector<uint32_t>& w)
       : ScalarConstant(ty, w) {}
-  FloatConstant(const analysis::Float* ty, std::vector<uint32_t>&& w)
+  FloatConstant(const Float* ty, std::vector<uint32_t>&& w)
       : ScalarConstant(ty, std::move(w)) {}
 
   FloatConstant* AsFloatConstant() override { return this; }
@@ -146,7 +149,7 @@ class FloatConstant : public ScalarConstant {
 // Bool type constant.
 class BoolConstant : public ScalarConstant {
  public:
-  BoolConstant(const analysis::Bool* ty, bool v)
+  BoolConstant(const Bool* ty, bool v)
       : ScalarConstant(ty, {static_cast<uint32_t>(v)}), value_(v) {}
 
   BoolConstant* AsBoolConstant() override { return this; }
@@ -180,12 +183,11 @@ class CompositeConstant : public Constant {
   }
 
  protected:
-  CompositeConstant(const analysis::Type* ty) : Constant(ty), components_() {}
-  CompositeConstant(const analysis::Type* ty,
+  CompositeConstant(const Type* ty) : Constant(ty), components_() {}
+  CompositeConstant(const Type* ty,
                     const std::vector<const Constant*>& components)
       : Constant(ty), components_(components) {}
-  CompositeConstant(const analysis::Type* ty,
-                    std::vector<const Constant*>&& components)
+  CompositeConstant(const Type* ty, std::vector<const Constant*>&& components)
       : Constant(ty), components_(std::move(components)) {}
   std::vector<const Constant*> components_;
 };
@@ -193,12 +195,11 @@ class CompositeConstant : public Constant {
 // Struct type constant.
 class StructConstant : public CompositeConstant {
  public:
-  StructConstant(const analysis::Struct* ty) : CompositeConstant(ty) {}
-  StructConstant(const analysis::Struct* ty,
+  StructConstant(const Struct* ty) : CompositeConstant(ty) {}
+  StructConstant(const Struct* ty,
                  const std::vector<const Constant*>& components)
       : CompositeConstant(ty, components) {}
-  StructConstant(const analysis::Struct* ty,
-                 std::vector<const Constant*>&& components)
+  StructConstant(const Struct* ty, std::vector<const Constant*>&& components)
       : CompositeConstant(ty, std::move(components)) {}
 
   StructConstant* AsStructConstant() override { return this; }
@@ -216,14 +217,13 @@ class StructConstant : public CompositeConstant {
 // Vector type constant.
 class VectorConstant : public CompositeConstant {
  public:
-  VectorConstant(const analysis::Vector* ty)
+  VectorConstant(const Vector* ty)
       : CompositeConstant(ty), component_type_(ty->element_type()) {}
-  VectorConstant(const analysis::Vector* ty,
+  VectorConstant(const Vector* ty,
                  const std::vector<const Constant*>& components)
       : CompositeConstant(ty, components),
         component_type_(ty->element_type()) {}
-  VectorConstant(const analysis::Vector* ty,
-                 std::vector<const Constant*>&& components)
+  VectorConstant(const Vector* ty, std::vector<const Constant*>&& components)
       : CompositeConstant(ty, std::move(components)),
         component_type_(ty->element_type()) {}
 
@@ -241,21 +241,19 @@ class VectorConstant : public CompositeConstant {
     return std::unique_ptr<Constant>(CopyVectorConstant().release());
   }
 
-  const analysis::Type* component_type() { return component_type_; }
+  const Type* component_type() { return component_type_; }
 
  private:
-  const analysis::Type* component_type_;
+  const Type* component_type_;
 };
 
 // Array type constant.
 class ArrayConstant : public CompositeConstant {
  public:
-  ArrayConstant(const analysis::Array* ty) : CompositeConstant(ty) {}
-  ArrayConstant(const analysis::Array* ty,
-                const std::vector<const Constant*>& components)
+  ArrayConstant(const Array* ty) : CompositeConstant(ty) {}
+  ArrayConstant(const Array* ty, const std::vector<const Constant*>& components)
       : CompositeConstant(ty, components) {}
-  ArrayConstant(const analysis::Array* ty,
-                std::vector<const Constant*>&& components)
+  ArrayConstant(const Array* ty, std::vector<const Constant*>&& components)
       : CompositeConstant(ty, std::move(components)) {}
 
   ArrayConstant* AsArrayConstant() override { return this; }
@@ -273,7 +271,7 @@ class ArrayConstant : public CompositeConstant {
 // Null type constant.
 class NullConstant : public Constant {
  public:
-  NullConstant(const analysis::Type* ty) : Constant(ty) {}
+  NullConstant(const Type* ty) : Constant(ty) {}
   NullConstant* AsNullConstant() override { return this; }
   const NullConstant* AsNullConstant() const override { return this; }
 
@@ -288,6 +286,60 @@ class NullConstant : public Constant {
 
 class IRContext;
 
+// Hash function for Constant instances. Use the structure of the constant as
+// the key.
+struct ConstantHash {
+  void add_pointer(std::u32string* h, const void* p) const {
+    uint64_t ptr_val = reinterpret_cast<uint64_t>(p);
+    h->push_back(static_cast<uint32_t>(ptr_val >> 32));
+    h->push_back(static_cast<uint32_t>(ptr_val));
+  }
+
+  size_t operator()(const Constant* const_val) const {
+    std::u32string h;
+    add_pointer(&h, const_val->type());
+    if (const auto scalar = const_val->AsScalarConstant()) {
+      for (const auto& w : scalar->words()) {
+        h.push_back(w);
+      }
+    } else if (const auto composite = const_val->AsCompositeConstant()) {
+      for (const auto& c : composite->GetComponents()) {
+        add_pointer(&h, c);
+      }
+    } else if (const_val->AsNullConstant()) {
+      h.push_back(0);
+    } else {
+      assert(
+          false &&
+          "Tried to compute the hash value of an invalid Constant instance.");
+    }
+
+    return std::hash<std::u32string>()(h);
+  }
+};
+
+// Equality comparison structure for two constants.
+struct ConstantEqual {
+  bool operator()(const Constant* c1, const Constant* c2) const {
+    if (c1->type() != c2->type()) {
+      return false;
+    }
+
+    if (const auto& s1 = c1->AsScalarConstant()) {
+      const auto& s2 = c2->AsScalarConstant();
+      return s2 && s1->words() == s2->words();
+    } else if (const auto& composite1 = c1->AsCompositeConstant()) {
+      const auto& composite2 = c2->AsCompositeConstant();
+      return composite2 &&
+             composite1->GetComponents() == composite2->GetComponents();
+    } else if (c1->AsNullConstant())
+      return c2->AsNullConstant() != nullptr;
+    else
+      assert(false && "Tried to compare two invalid Constant instances.");
+    return false;
+  }
+};
+
 // This class represents a pool of constants.
 class ConstantManager {
  public:
@@ -295,28 +347,27 @@ class ConstantManager {
 
   ir::IRContext* context() const { return ctx_; }
 
-  // Creates a Constant instance with the given type and a vector of constant
-  // defining words. Returns an unique pointer to the created Constant instance
-  // if the Constant instance can be created successfully. To create scalar
-  // type constants, the vector should contain the constant value in 32 bit
-  // words and the given type must be of type Bool, Integer or Float. To create
-  // composite type constants, the vector should contain the component ids, and
-  // those component ids should have been recorded before as Normal Constants.
-  // And the given type must be of type Struct, Vector or Array. When creating
-  // VectorType Constant instance, the components must be scalars of the same
-  // type, either Bool, Integer or Float. If any of the rules above failed, the
-  // creation will fail and nullptr will be returned. If the vector is empty,
-  // a NullConstant instance will be created with the given type.
-  std::unique_ptr<Constant> CreateConstant(
-      const Type* type,
-      const std::vector<uint32_t>& literal_words_or_ids) const;
-
-  // Creates a Constant instance to hold the constant value of the given
-  // instruction. If the given instruction defines a normal constants whose
-  // value is already known in the module, returns the unique pointer to the
-  // created Constant instance. Otherwise does not create anything and returns a
-  // nullptr.
-  std::unique_ptr<Constant> CreateConstantFromInst(ir::Instruction* inst) const;
+  // Gets or creates a unique Constant instance of type |type| and a vector of
+  // constant defining words |words|. If a Constant instance existed already in
+  // the constant pool, it returns a pointer to it.  Otherwise, it creates one
+  // using CreateConstant. If a new Constant instance cannot be created, it
+  // returns nullptr.
+  const Constant* GetConstant(
+      const Type* type, const std::vector<uint32_t>& literal_words_or_ids);
+
+  // Gets or creates a Constant instance to hold the constant value of the given
+  // instruction. It returns a pointer to the Constant's defining instruction or
+  // nullptr if it could not create the constant.
+  const Constant* GetConstantFromInst(ir::Instruction* inst);
+
+  // Gets or creates a constant defining instruction for the given Constant |c|.
+  // If |c| had already been defined, it returns a pointer to the existing
+  // declaration. Otherwise, it calls BuildInstructionAndAddToModule. If the
+  // optional |pos| is given, it will insert any newly created instructions at
+  // the given instruction iterator position. Otherwise, it inserts the new
+  // instruction at the end of the current module's types section.
+  ir::Instruction* GetDefiningInstruction(
+      const Constant* c, ir::Module::inst_iterator* pos = nullptr);
 
   // Creates a constant defining instruction for the given Constant instance
   // and inserts the instruction at the position specified by the given
@@ -331,8 +382,82 @@ class ConstantManager {
   // the type of the constant is derived by getting an id from the type manager
   // for |c|.
   ir::Instruction* BuildInstructionAndAddToModule(
-      std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos,
-      uint32_t type_id = 0);
+      const Constant* c, ir::Module::inst_iterator* pos, uint32_t type_id = 0);
+
+  // A helper function to get the result type of the given instruction. Returns
+  // nullptr if the instruction does not have a type id (type id is 0).
+  Type* GetType(const ir::Instruction* inst) const;
+
+  // A helper function to get the collected normal constant with the given id.
+  // Returns the pointer to the Constant instance in case it is found.
+  // Otherwise, it returns a null pointer.
+  const Constant* FindDeclaredConstant(uint32_t id) const {
+    auto iter = id_to_const_val_.find(id);
+    return (iter != id_to_const_val_.end()) ? iter->second : nullptr;
+  }
+
+  // A helper function to get the id of a collected constant with the pointer
+  // to the Constant instance. Returns 0 in case the constant is not found.
+  uint32_t FindDeclaredConstant(const Constant* c) const {
+    auto iter = const_val_to_id_.find(c);
+    return (iter != const_val_to_id_.end()) ? iter->second : 0;
+  }
+
+  // Returns the canonical constant that has the same structure and value as the
+  // given Constant |cst|. If none is found, it returns nullptr.
+  const Constant* FindConstant(const Constant* c) const {
+    auto it = const_pool_.find(c);
+    return (it != const_pool_.end()) ? *it : nullptr;
+  }
+
+  // Registers a new constant |cst| in the constant pool. If the constant
+  // existed already, it returns a pointer to the previously existing Constant
+  // in the pool. Otherwise, it returns |cst|.
+  const Constant* RegisterConstant(const Constant* cst) {
+    auto ret = const_pool_.insert(cst);
+    return *ret.first;
+  }
+
+  // A helper function to get a vector of Constant instances with the specified
+  // ids. If it can not find the Constant instance for any one of the ids,
+  // it returns an empty vector.
+  std::vector<const Constant*> GetConstantsFromIds(
+      const std::vector<uint32_t>& ids) const;
+
+  // Records a mapping between |inst| and the constant value generated by it.
+  // It returns true if a new Constant was successfully mapped, false if |inst|
+  // generates no constant values.
+  bool MapInst(ir::Instruction* inst) {
+    if (auto cst = GetConstantFromInst(inst)) {
+      MapConstantToInst(cst, inst);
+      return true;
+    }
+    return false;
+  }
+
+  // Records a new mapping between |inst| and |const_value|. This updates the
+  // two mappings |id_to_const_val_| and |const_val_to_id_|.
+  void MapConstantToInst(const Constant* const_value, ir::Instruction* inst) {
+    const_val_to_id_[const_value] = inst->result_id();
+    id_to_const_val_[inst->result_id()] = const_value;
+  }
+
+ private:
+  // Creates a Constant instance with the given type and a vector of constant
+  // defining words. Returns a unique pointer to the created Constant instance
+  // if the Constant instance can be created successfully. To create scalar
+  // type constants, the vector should contain the constant value in 32 bit
+  // words and the given type must be of type Bool, Integer or Float. To create
+  // composite type constants, the vector should contain the component ids, and
+  // those component ids should have been recorded before as Normal Constants.
+  // And the given type must be of type Struct, Vector or Array. When creating
+  // VectorType Constant instance, the components must be scalars of the same
+  // type, either Bool, Integer or Float. If any of the rules above failed, the
+  // creation will fail and nullptr will be returned. If the vector is empty,
+  // a NullConstant instance will be created with the given type.
+  const Constant* CreateConstant(
+      const Type* type,
+      const std::vector<uint32_t>& literal_words_or_ids) const;
 
   // Creates an instruction with the given result id to declare a constant
   // represented by the given Constant instance. Returns an unique pointer to
@@ -344,7 +469,7 @@ class ConstantManager {
   // the type of the constant is derived by getting an id from the type manager
   // for |c|.
   std::unique_ptr<ir::Instruction> CreateInstruction(
-      uint32_t result_id, analysis::Constant* c, uint32_t type_id = 0) const;
+      uint32_t result_id, const Constant* c, uint32_t type_id = 0) const;
 
   // Creates an OpConstantComposite instruction with the given result id and
   // the CompositeConst instance which represents a composite constant. Returns
@@ -356,52 +481,26 @@ class ConstantManager {
   // the type of the constant is derived by getting an id from the type manager
   // for |c|.
   std::unique_ptr<ir::Instruction> CreateCompositeInstruction(
-      uint32_t result_id, analysis::CompositeConstant* cc,
+      uint32_t result_id, const CompositeConstant* cc,
       uint32_t type_id = 0) const;
 
-  // A helper function to get the result type of the given instruction. Returns
-  // nullptr if the instruction does not have a type id (type id is 0).
-  analysis::Type* GetType(const ir::Instruction* inst) const;
-
-  // A helper function to get the collected normal constant with the given id.
-  // Returns the pointer to the Constant instance in case it is found.
-  // Otherwise, returns null pointer.
-  analysis::Constant* FindRecordedConstant(uint32_t id) const;
-
-  // A helper function to get the id of a collected constant with the pointer
-  // to the Constant instance. Returns 0 in case the constant is not found.
-  uint32_t FindRecordedConstant(const analysis::Constant* c) const;
-
-  // A helper function to get a vector of Constant instances with the specified
-  // ids. If can not find the Constant instance for any one of the ids, returns
-  // an empty vector.
-  std::vector<const analysis::Constant*> GetConstantsFromIds(
-      const std::vector<uint32_t>& ids) const;
-
-  // Records a new mapping between |inst| and |const_value|.
-  // This updates the two mappings |id_to_const_val_| and |const_val_to_id_|.
-  void MapConstantToInst(std::unique_ptr<analysis::Constant> const_value,
-                         ir::Instruction* inst) {
-    const_val_to_id_[const_value.get()] = inst->result_id();
-    id_to_const_val_[inst->result_id()] = std::move(const_value);
-  }
-
- private:
   // IR context that owns this constant manager.
   ir::IRContext* ctx_;
 
   // A mapping from the result ids of Normal Constants to their
-  // analysis::Constant instances. All Normal Constants in the module, either
+  // Constant instances. All Normal Constants in the module, either
   // existing ones before optimization or the newly generated ones, should have
   // their Constant instance stored and their result id registered in this map.
-  std::unordered_map<uint32_t, std::unique_ptr<analysis::Constant>>
-      id_to_const_val_;
+  std::unordered_map<uint32_t, const Constant*> id_to_const_val_;
 
-  // A mapping from the analsis::Constant instance of Normal Contants to their
-  // result id in the module. This is a mirror map of id_to_const_val_. All
+  // A mapping from the Constant instance of Normal Constants to their
+  // result id in the module. This is a mirror map of |id_to_const_val_|. All
   // Normal Constants that defining instructions in the module should have
-  // their analysis::Constant and their result id registered here.
-  std::unordered_map<const analysis::Constant*, uint32_t> const_val_to_id_;
+  // their Constant and their result id registered here.
+  std::unordered_map<const Constant*, uint32_t> const_val_to_id_;
+
+  // The constant pool.  All created constants are registered here.
+  std::unordered_set<const Constant*, ConstantHash, ConstantEqual> const_pool_;
 };
 
 }  // namespace analysis
index 005cb76..77e0517 100644 (file)
@@ -172,10 +172,10 @@ uint32_t OperateWords(SpvOp opcode,
 // result in 32 bit word. Scalar constants with longer than 32-bit width are
 // not accepted in this function.
 uint32_t FoldScalars(SpvOp opcode,
-                     const std::vector<analysis::Constant*>& operands) {
+                     const std::vector<const analysis::Constant*>& operands) {
   std::vector<uint32_t> operand_values_in_raw_words;
-  for (analysis::Constant* operand : operands) {
-    if (analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
+  for (const auto& operand : operands) {
+    if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
       const auto& scalar_words = scalar->words();
       assert(scalar_words.size() == 1 &&
              "Scalar constants with longer than 32-bit width are not allowed "
@@ -199,12 +199,12 @@ uint32_t FoldScalars(SpvOp opcode,
 // function.
 std::vector<uint32_t> FoldVectors(
     SpvOp opcode, uint32_t num_dims,
-    const std::vector<analysis::Constant*>& operands) {
+    const std::vector<const analysis::Constant*>& operands) {
   std::vector<uint32_t> result;
   for (uint32_t d = 0; d < num_dims; d++) {
     std::vector<uint32_t> operand_values_for_one_dimension;
-    for (analysis::Constant* operand : operands) {
-      if (analysis::VectorConstant* vector_operand =
+    for (const auto& operand : operands) {
+      if (const analysis::VectorConstant* vector_operand =
               operand->AsVectorConstant()) {
         // Extract the raw value of the scalar component constants
         // in 32-bit words here. The reason of not using FoldScalars() here
@@ -240,5 +240,47 @@ std::vector<uint32_t> FoldVectors(
   return result;
 }
 
+bool IsFoldableOpcode(SpvOp opcode) {
+  // NOTE: Extend to more opcodes as new cases are handled in the folder
+  // functions.
+  switch (opcode) {
+    case SpvOp::SpvOpBitwiseAnd:
+    case SpvOp::SpvOpBitwiseOr:
+    case SpvOp::SpvOpBitwiseXor:
+    case SpvOp::SpvOpIAdd:
+    case SpvOp::SpvOpIEqual:
+    case SpvOp::SpvOpIMul:
+    case SpvOp::SpvOpINotEqual:
+    case SpvOp::SpvOpISub:
+    case SpvOp::SpvOpLogicalAnd:
+    case SpvOp::SpvOpLogicalEqual:
+    case SpvOp::SpvOpLogicalNot:
+    case SpvOp::SpvOpLogicalNotEqual:
+    case SpvOp::SpvOpLogicalOr:
+    case SpvOp::SpvOpNot:
+    case SpvOp::SpvOpSDiv:
+    case SpvOp::SpvOpSelect:
+    case SpvOp::SpvOpSGreaterThan:
+    case SpvOp::SpvOpSGreaterThanEqual:
+    case SpvOp::SpvOpShiftLeftLogical:
+    case SpvOp::SpvOpShiftRightArithmetic:
+    case SpvOp::SpvOpShiftRightLogical:
+    case SpvOp::SpvOpSLessThan:
+    case SpvOp::SpvOpSLessThanEqual:
+    case SpvOp::SpvOpSMod:
+    case SpvOp::SpvOpSNegate:
+    case SpvOp::SpvOpSRem:
+    case SpvOp::SpvOpUDiv:
+    case SpvOp::SpvOpUGreaterThan:
+    case SpvOp::SpvOpUGreaterThanEqual:
+    case SpvOp::SpvOpULessThan:
+    case SpvOp::SpvOpULessThanEqual:
+    case SpvOp::SpvOpUMod:
+      return true;
+    default:
+      return false;
+  }
+}
+
 }  // namespace opt
 }  // namespace spvtools
index f1e5ea1..eb94cf2 100644 (file)
@@ -25,11 +25,13 @@ namespace spvtools {
 namespace opt {
 
 uint32_t FoldScalars(SpvOp opcode,
-                     const std::vector<analysis::Constant*>& operands);
+                     const std::vector<const analysis::Constant*>& operands);
 
 std::vector<uint32_t> FoldVectors(
     SpvOp opcode, uint32_t num_dims,
-    const std::vector<analysis::Constant*>& operands);
+    const std::vector<const analysis::Constant*>& operands);
+
+bool IsFoldableOpcode(SpvOp opcode);
 
 }  // namespace opt
 }  // namespace spvtools
index 831906f..79eaead 100644 (file)
@@ -92,15 +92,14 @@ Pass::Status FoldSpecConstantOpAndCompositePass::ProcessImpl(
         // Constant instance should also be created successfully and recorded
         // in the id_to_const_val_ and const_val_to_id_ mapps.
         if (auto const_value =
-                context()->get_constant_mgr()->CreateConstantFromInst(inst)) {
+                context()->get_constant_mgr()->GetConstantFromInst(inst)) {
           // Need to replace the OpSpecConstantComposite instruction with a
           // corresponding OpConstantComposite instruction.
           if (opcode == SpvOp::SpvOpSpecConstantComposite) {
             inst->SetOpcode(SpvOp::SpvOpConstantComposite);
             modified = true;
           }
-          context()->get_constant_mgr()->MapConstantToInst(
-              std::move(const_value), inst);
+          context()->get_constant_mgr()->MapConstantToInst(const_value, inst);
         }
         break;
       }
@@ -185,8 +184,8 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
   // operand. The first in-operand is the spec opcode.
   uint32_t source = inst->GetSingleWordInOperand(1);
   uint32_t type = context()->get_def_use_mgr()->GetDef(source)->type_id();
-  analysis::Constant* first_operand_const =
-      context()->get_constant_mgr()->FindRecordedConstant(source);
+  const analysis::Constant* first_operand_const =
+      context()->get_constant_mgr()->FindDeclaredConstant(source);
   if (!first_operand_const) return nullptr;
 
   const analysis::Constant* current_const = first_operand_const;
@@ -207,7 +206,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
       // Because components of a NullConstant are always NullConstants, we can
       // return early with a NullConstant in the result type.
       return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-          context()->get_constant_mgr()->CreateConstant(
+          context()->get_constant_mgr()->GetConstant(
               context()->get_constant_mgr()->GetType(inst), {}),
           pos, type);
     } else {
@@ -216,7 +215,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
     }
   }
   return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-      current_const->Copy(), pos, type);
+      current_const, pos);
 }
 
 ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
@@ -230,10 +229,10 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
   assert(result_vec_type &&
          "The result of VectorShuffle must be of type vector");
 
-  // A temporary null constants that can be used as the components fo the
-  // result vector. This is needed when any one of the vector operands are null
+  // A temporary null constants that can be used as the components of the result
+  // vector. This is needed when any one of the vector operands are null
   // constant.
-  std::unique_ptr<analysis::Constant> null_component_constants;
+  const analysis::Constant* null_component_constants = nullptr;
 
   // Get a concatenated vector of scalar constants. The vector should be built
   // with the components from the first and the second operand of VectorShuffle.
@@ -244,14 +243,13 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
     assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_ID &&
            "The vector operand must have a SPV_OPERAND_TYPE_ID type");
     uint32_t operand_id = inst->GetSingleWordInOperand(i);
-    analysis::Constant* operand_const =
-        context()->get_constant_mgr()->FindRecordedConstant(operand_id);
+    auto operand_const =
+        context()->get_constant_mgr()->FindDeclaredConstant(operand_id);
     if (!operand_const) return nullptr;
     const analysis::Type* operand_type = operand_const->type();
     assert(operand_type->AsVector() &&
            "The first two operand of VectorShuffle must be of vector type");
-    if (analysis::VectorConstant* vec_const =
-            operand_const->AsVectorConstant()) {
+    if (auto vec_const = operand_const->AsVectorConstant()) {
       // case 1: current operand is a non-null vector constant.
       concatenated_components.insert(concatenated_components.end(),
                                      vec_const->GetComponents().begin(),
@@ -263,13 +261,13 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
         const analysis::Type* component_type =
             operand_type->AsVector()->element_type();
         null_component_constants =
-            context()->get_constant_mgr()->CreateConstant(component_type, {});
+            context()->get_constant_mgr()->GetConstant(component_type, {});
       }
       // Append the null scalar consts to the concatenated components
       // vector.
       concatenated_components.insert(concatenated_components.end(),
                                      operand_type->AsVector()->element_count(),
-                                     null_component_constants.get());
+                                     null_component_constants);
     } else {
       // no other valid cases
       return nullptr;
@@ -280,7 +278,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
   // satisfy SSA def-use dominance.
   if (null_component_constants) {
     context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-        std::move(null_component_constants), pos);
+        null_component_constants, pos);
   }
   // Create the new vector constant with the selected components.
   std::vector<const analysis::Constant*> selected_components;
@@ -292,10 +290,13 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
            "Literal index out of bound of the concatenated vector");
     selected_components.push_back(concatenated_components[literal]);
   }
-  auto new_vec_const = MakeUnique<analysis::VectorConstant>(
-      result_vec_type, selected_components);
+  auto new_vec_const =
+      new analysis::VectorConstant(result_vec_type, selected_components);
+  auto reg_vec_const =
+      context()->get_constant_mgr()->RegisterConstant(new_vec_const);
+  if (reg_vec_const != new_vec_const) delete new_vec_const;
   return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-      std::move(new_vec_const), pos);
+      reg_vec_const, pos);
 }
 
 namespace {
@@ -329,7 +330,7 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
       context()->get_constant_mgr()->GetType(inst);
   SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0));
   // Check and collect operands.
-  std::vector<analysis::Constant*> operands;
+  std::vector<const analysis::Constant*> operands;
 
   if (!std::all_of(
           inst->cbegin(), inst->cend(),
@@ -337,8 +338,8 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
             // skip the operands that is not an id.
             if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID) return true;
             uint32_t id = o.words.front();
-            if (analysis::Constant* c =
-                    context()->get_constant_mgr()->FindRecordedConstant(id)) {
+            if (auto c =
+                    context()->get_constant_mgr()->FindDeclaredConstant(id)) {
               if (IsValidTypeForComponentWiseOperation(c->type())) {
                 operands.push_back(c);
                 return true;
@@ -351,10 +352,10 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
   if (result_type->AsInteger() || result_type->AsBool()) {
     // Scalar operation
     uint32_t result_val = FoldScalars(spec_opcode, operands);
-    auto result_const = context()->get_constant_mgr()->CreateConstant(
-        result_type, {result_val});
+    auto result_const =
+        context()->get_constant_mgr()->GetConstant(result_type, {result_val});
     return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-        std::move(result_const), pos);
+        result_const, pos);
   } else if (result_type->AsVector()) {
     // Vector operation
     const analysis::Type* element_type =
@@ -364,11 +365,11 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
         FoldVectors(spec_opcode, num_dims, operands);
     std::vector<const analysis::Constant*> result_vector_components;
     for (uint32_t r : result_vec) {
-      if (auto rc = context()->get_constant_mgr()->CreateConstant(element_type,
-                                                                  {r})) {
-        result_vector_components.push_back(rc.get());
+      if (auto rc =
+              context()->get_constant_mgr()->GetConstant(element_type, {r})) {
+        result_vector_components.push_back(rc);
         if (!context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-                std::move(rc), pos)) {
+                rc, pos)) {
           assert(false &&
                  "Failed to build and insert constant declaring instruction "
                  "for the given vector component constant");
@@ -377,10 +378,13 @@ ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
         assert(false && "Failed to create constants with 32-bit word");
       }
     }
-    auto new_vec_const = MakeUnique<analysis::VectorConstant>(
-        result_type->AsVector(), result_vector_components);
+    auto new_vec_const = new analysis::VectorConstant(result_type->AsVector(),
+                                                      result_vector_components);
+    auto reg_vec_const =
+        context()->get_constant_mgr()->RegisterConstant(new_vec_const);
+    if (reg_vec_const != new_vec_const) delete new_vec_const;
     return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-        std::move(new_vec_const), pos);
+        reg_vec_const, pos);
   } else {
     // Cannot process invalid component wise operation. The result of component
     // wise operation must be of integer or bool scalar or vector of
index 684514c..182e8b3 100644 (file)
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "instruction.h"
-#include "ir_context.h"
-
 #include <initializer_list>
 
+#include "fold.h"
+#include "instruction.h"
 #include "ir_context.h"
 #include "reflect.h"
 
@@ -444,5 +443,8 @@ bool Instruction::IsOpaqueType() const {
            spvOpcodeIsBaseOpaqueType(opcode());
   }
 }
+
+bool Instruction::IsFoldable() const { return opt::IsFoldableOpcode(opcode()); }
+
 }  // namespace ir
 }  // namespace spvtools
index 66fc82c..3c5d027 100644 (file)
@@ -326,11 +326,15 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
     return spvOpcodeIsBlockTerminator(opcode());
   }
 
-  // Return true if |this| is an instruction that define an opaque type.  Since
+  // Returns true if |this| is an instruction that define an opaque type.  Since
   // runtime array have similar characteristics they are included as opaque
   // types.
   bool IsOpaqueType() const;
 
+  // Returns true if |this| is an instruction which could be folded into a
+  // constant value.
+  bool IsFoldable() const;
+
   inline bool operator==(const Instruction&) const;
   inline bool operator!=(const Instruction&) const;
   inline bool operator<(const Instruction&) const;
index 57d5c72..2edd0de 100644 (file)
@@ -122,12 +122,14 @@ Optimizer& Optimizer::RegisterPerformancePasses() {
       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
       .RegisterPass(CreateLocalSingleStoreElimPass())
       .RegisterPass(CreateInsertExtractElimPass())
+      .RegisterPass(CreateCCPPass())
       .RegisterPass(CreateAggressiveDCEPass())
       .RegisterPass(CreateDeadBranchElimPass())
       .RegisterPass(CreateBlockMergePass())
       .RegisterPass(CreateLocalMultiStoreElimPass())
       .RegisterPass(CreateInsertExtractElimPass())
       .RegisterPass(CreateRedundancyEliminationPass())
+      .RegisterPass(CreateCFGCleanupPass())
       // Currently exposing driver bugs resulting in crashes (#946)
       // .RegisterPass(CreateCommonUniformElimPass())
       .RegisterPass(CreateDeadVariableEliminationPass());
@@ -142,12 +144,14 @@ Optimizer& Optimizer::RegisterSizePasses() {
       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
       .RegisterPass(CreateLocalSingleStoreElimPass())
       .RegisterPass(CreateInsertExtractElimPass())
+      .RegisterPass(CreateCCPPass())
       .RegisterPass(CreateAggressiveDCEPass())
       .RegisterPass(CreateDeadBranchElimPass())
       .RegisterPass(CreateBlockMergePass())
       .RegisterPass(CreateLocalMultiStoreElimPass())
       .RegisterPass(CreateInsertExtractElimPass())
       .RegisterPass(CreateRedundancyEliminationPass())
+      .RegisterPass(CreateCFGCleanupPass())
       // Currently exposing driver bugs resulting in crashes (#946)
       // .RegisterPass(CreateCommonUniformElimPass())
       .RegisterPass(CreateDeadVariableEliminationPass());
@@ -336,4 +340,9 @@ Optimizer::PassToken CreatePrivateToLocalPass() {
   return MakeUnique<Optimizer::PassToken::Impl>(
       MakeUnique<opt::PrivateToLocalPass>());
 }
+
+Optimizer::PassToken CreateCCPPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::CCPPass>());
+}
+
 }  // namespace spvtools
index 67cfe2e..1b1acdf 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "aggressive_dead_code_elim_pass.h"
 #include "block_merge_pass.h"
+#include "ccp_pass.h"
 #include "cfg_cleanup_pass.h"
 #include "common_uniform_elim_pass.h"
 #include "compact_ids_pass.h"
index 18c89cf..0ac1310 100644 (file)
@@ -49,6 +49,16 @@ void SSAPropagator::AddSSAEdges(uint32_t id) {
   });
 }
 
+bool SSAPropagator::IsPhiArgExecutable(ir::Instruction* phi, uint32_t i) const {
+  ir::BasicBlock* phi_bb = ctx_->get_instr_block(phi);
+
+  uint32_t in_label_id = phi->GetSingleWordOperand(i + 1);
+  ir::Instruction* in_label_instr = get_def_use_mgr()->GetDef(in_label_id);
+  ir::BasicBlock* in_bb = ctx_->get_instr_block(in_label_instr);
+
+  return IsEdgeExecutable(Edge(in_bb, phi_bb));
+}
+
 bool SSAPropagator::Simulate(ir::Instruction* instr) {
   bool changed = false;
 
@@ -98,7 +108,6 @@ bool SSAPropagator::Simulate(ir::Instruction* instr) {
   // defined at an instruction D that should be simulated again, then the output
   // of D might affect |instr|, so we should simulate |instr| again.
   bool has_operands_to_simulate = false;
-  ir::BasicBlock* instr_bb = ctx_->get_instr_block(instr);
   if (instr->opcode() == SpvOpPhi) {
     // For Phi instructions, an operand causes the Phi to be simulated again if
     // the operand comes from an edge that has not yet been traversed or if its
@@ -111,12 +120,7 @@ bool SSAPropagator::Simulate(ir::Instruction* instr) {
 
       uint32_t arg_id = instr->GetSingleWordOperand(i);
       ir::Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id);
-      uint32_t in_label_id = instr->GetSingleWordOperand(i + 1);
-      ir::Instruction* in_label_instr = get_def_use_mgr()->GetDef(in_label_id);
-      ir::BasicBlock* in_bb = ctx_->get_instr_block(in_label_instr);
-      Edge edge(in_bb, instr_bb);
-
-      if (!IsEdgeExecutable(edge) || ShouldSimulateAgain(arg_def_instr)) {
+      if (!IsPhiArgExecutable(instr, i) || ShouldSimulateAgain(arg_def_instr)) {
         has_operands_to_simulate = true;
         break;
       }
index 2e7eb59..be06ff0 100644 (file)
@@ -187,10 +187,14 @@ class SSAPropagator {
   SSAPropagator(ir::IRContext* context, const VisitFunction& visit_fn)
       : ctx_(context), visit_fn_(visit_fn) {}
 
-  // Run the propagator on function |fn|. Returns true if changes were made to
+  // Runs the propagator on function |fn|. Returns true if changes were made to
   // the function. Otherwise, it returns false.
   bool Run(ir::Function* fn);
 
+  // Returns true if the |i|th argument for |phi| comes through a CFG edge that
+  // has been marked executable.
+  bool IsPhiArgExecutable(ir::Instruction* phi, uint32_t i) const;
+
  private:
   // Initialize processing.
   void Initialize(ir::Function* fn);
@@ -216,7 +220,7 @@ class SSAPropagator {
   }
 
   // Returns true if |block| has been simulated already.
-  bool BlockHasBeenSimulated(ir::BasicBlock* block) {
+  bool BlockHasBeenSimulated(ir::BasicBlock* block) const {
     return simulated_blocks_.find(block) != simulated_blocks_.end();
   }
 
@@ -232,7 +236,7 @@ class SSAPropagator {
   }
 
   // Returns true if |edge| has been marked as executable.
-  bool IsEdgeExecutable(const Edge& edge) {
+  bool IsEdgeExecutable(const Edge& edge) const {
     return executable_edges_.find(edge) != executable_edges_.end();
   }
 
index bffb0bb..d8a5b53 100644 (file)
@@ -255,3 +255,8 @@ add_spvtools_unittest(TARGET pass_remove_duplicates
   SRCS pass_remove_duplicates_test.cpp
   LIBS SPIRV-Tools-opt
 )
+
+add_spvtools_unittest(TARGET ccp
+  SRCS ccp_test.cpp
+  LIBS SPIRV-Tools-opt
+)
diff --git a/test/opt/ccp_test.cpp b/test/opt/ccp_test.cpp
new file mode 100644 (file)
index 0000000..a274526
--- /dev/null
@@ -0,0 +1,360 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+#include "opt/ccp_pass.h"
+
+namespace {
+
+using namespace spvtools;
+
+using CCPTest = PassTest<::testing::Test>;
+
+// TODO(dneto): Add Effcee as required dependency, and make this unconditional.
+#ifdef SPIRV_EFFCEE
+TEST_F(CCPTest, PropagateThroughPhis) {
+  const std::string spv_asm = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %x %outparm
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %x "x"
+               OpName %outparm "outparm"
+               OpDecorate %x Flat
+               OpDecorate %x Location 0
+               OpDecorate %outparm Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+       %bool = OpTypeBool
+%_ptr_Function_int = OpTypePointer Function %int
+      %int_4 = OpConstant %int 4
+      %int_3 = OpConstant %int 3
+      %int_1 = OpConstant %int 1
+%_ptr_Input_int = OpTypePointer Input %int
+          %x = OpVariable %_ptr_Input_int Input
+%_ptr_Output_int = OpTypePointer Output %int
+    %outparm = OpVariable %_ptr_Output_int Output
+       %main = OpFunction %void None %3
+          %4 = OpLabel
+          %5 = OpLoad %int %x
+          %9 = OpIAdd %int %int_1 %int_3
+          %6 = OpSGreaterThan %bool %5 %int_3
+               OpSelectionMerge %25 None
+               OpBranchConditional %6 %22 %23
+         %22 = OpLabel
+
+; CHECK: OpCopyObject %int %int_4
+          %7 = OpCopyObject %int %9
+
+               OpBranch %25
+         %23 = OpLabel
+          %8 = OpCopyObject %int %int_4
+               OpBranch %25
+         %25 = OpLabel
+
+; %int_4 should have propagated to both OpPhi operands.
+; CHECK: OpPhi %int %int_4 {{%\d+}} %int_4 {{%\d+}}
+         %35 = OpPhi %int %7 %22 %8 %23
+
+; This function always returns 4. DCE should get rid of everything else.
+; CHECK OpStore %outparm %int_4
+               OpStore %outparm %35
+               OpReturn
+               OpFunctionEnd
+               )";
+
+  SinglePassRunAndMatch<opt::CCPPass>(spv_asm, true);
+}
+
+TEST_F(CCPTest, SimplifyConditionals) {
+  const std::string spv_asm = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %outparm
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %outparm "outparm"
+               OpDecorate %outparm Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+       %bool = OpTypeBool
+%_ptr_Function_int = OpTypePointer Function %int
+      %int_4 = OpConstant %int 4
+      %int_3 = OpConstant %int 3
+      %int_1 = OpConstant %int 1
+%_ptr_Output_int = OpTypePointer Output %int
+    %outparm = OpVariable %_ptr_Output_int Output
+       %main = OpFunction %void None %3
+          %4 = OpLabel
+          %9 = OpIAdd %int %int_4 %int_3
+          %6 = OpSGreaterThan %bool %9 %int_3
+               OpSelectionMerge %25 None
+; CHECK: OpBranchConditional %true [[bb_taken:%\d+]] [[bb_not_taken:%\d+]]
+               OpBranchConditional %6 %22 %23
+; CHECK: [[bb_taken]] = OpLabel
+         %22 = OpLabel
+; CHECK: OpCopyObject %int %int_7
+          %7 = OpCopyObject %int %9
+               OpBranch %25
+; CHECK: [[bb_not_taken]] = OpLabel
+         %23 = OpLabel
+; CHECK: [[id_not_evaluated:%\d+]] = OpCopyObject %int %int_4
+          %8 = OpCopyObject %int %int_4
+               OpBranch %25
+         %25 = OpLabel
+
+; %int_7 should have propagated to the first OpPhi operand. But the else branch
+; is not executable (conditional is always true), so no values should be
+; propagated there and the value of the OpPhi should always be %int_7.
+; CHECK: OpPhi %int %int_7 [[bb_taken]] [[id_not_evaluated]] [[bb_not_taken]]
+         %35 = OpPhi %int %7 %22 %8 %23
+
+; Only the true path of the conditional is ever executed. The output of this
+; function is always %int_7.
+; CHECK: OpStore %outparm %int_7
+               OpStore %outparm %35
+               OpReturn
+               OpFunctionEnd
+               )";
+
+  SinglePassRunAndMatch<opt::CCPPass>(spv_asm, true);
+}
+
+TEST_F(CCPTest, SimplifySwitches) {
+  const std::string spv_asm = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %outparm
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %x "x"
+               OpName %outparm "outparm"
+               OpDecorate %outparm Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+     %int_23 = OpConstant %int 23
+     %int_42 = OpConstant %int 42
+     %int_14 = OpConstant %int 14
+     %int_15 = OpConstant %int 15
+      %int_4 = OpConstant %int 4
+%_ptr_Output_int = OpTypePointer Output %int
+    %outparm = OpVariable %_ptr_Output_int Output
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+          %x = OpVariable %_ptr_Function_int Function
+               OpStore %x %int_23
+         %10 = OpLoad %int %x
+               OpSelectionMerge %14 None
+               OpSwitch %10 %14 10 %11 13 %12 23 %13
+         %11 = OpLabel
+               OpStore %x %int_42
+               OpBranch %14
+         %12 = OpLabel
+               OpStore %x %int_14
+               OpBranch %14
+         %13 = OpLabel
+               OpStore %x %int_15
+               OpBranch %14
+         %14 = OpLabel
+; CHECK: OpPhi %int %int_23 {{%\d+}} %int_42 {{%\d+}} %int_14 {{%\d+}} %int_15 {{%\d+}}
+; CHECK-NOT: OpLoad %int
+         %23 = OpLoad %int %x
+; CHECK: OpIAdd %int %int_15 %int_4
+         %24 = OpIAdd %int %23 %int_4
+; CHECK: OpStore %x %int_19
+               OpStore %x %24
+         %27 = OpLoad %int %x
+; CHECK: OpStore %outparm %int_19
+               OpStore %outparm %27
+               OpReturn
+               OpFunctionEnd
+               )";
+
+  SinglePassRunAndMatch<opt::CCPPass>(spv_asm, true);
+}
+
+TEST_F(CCPTest, SimplifySwitchesDefaultBranch) {
+  const std::string spv_asm = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %outparm
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %x "x"
+               OpName %outparm "outparm"
+               OpDecorate %outparm Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+     %int_42 = OpConstant %int 42
+      %int_4 = OpConstant %int 4
+      %int_1 = OpConstant %int 1
+%_ptr_Output_int = OpTypePointer Output %int
+    %outparm = OpVariable %_ptr_Output_int Output
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+          %x = OpVariable %_ptr_Function_int Function
+               OpStore %x %int_42
+         %10 = OpLoad %int %x
+         %15 = OpIAdd %int %10 %int_4
+               OpSelectionMerge %14 None
+               OpSwitch %15 %13 10 %11
+         %11 = OpLabel
+               OpStore %x %int_42
+               OpBranch %14
+         %13 = OpLabel
+               OpStore %x %int_1
+               OpBranch %14
+         %14 = OpLabel
+; CHECK: OpPhi %int %int_42 {{%\d+}} %int_1 {{%\d+}}
+; CHECK-NOT: OpLoad %int
+         %23 = OpLoad %int %x
+; CHECK: OpIAdd %int %int_1 %int_4
+         %24 = OpIAdd %int %23 %int_4
+; CHECK: OpStore %x %int_5
+               OpStore %x %24
+         %27 = OpLoad %int %x
+; CHECK: OpStore %outparm %int_5
+               OpStore %outparm %27
+               OpReturn
+               OpFunctionEnd
+               )";
+
+  SinglePassRunAndMatch<opt::CCPPass>(spv_asm, true);
+}
+
+TEST_F(CCPTest, SimplifyIntVector) {
+  const std::string spv_asm = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %OutColor
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %v "v"
+               OpName %OutColor "OutColor"
+               OpDecorate %OutColor Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+      %v4int = OpTypeVector %int 4
+%_ptr_Function_v4int = OpTypePointer Function %v4int
+      %int_1 = OpConstant %int 1
+      %int_2 = OpConstant %int 2
+      %int_3 = OpConstant %int 3
+      %int_4 = OpConstant %int 4
+         %14 = OpConstantComposite %v4int %int_1 %int_2 %int_3 %int_4
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+%_ptr_Function_int = OpTypePointer Function %int
+%_ptr_Output_v4int = OpTypePointer Output %v4int
+   %OutColor = OpVariable %_ptr_Output_v4int Output
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+          %v = OpVariable %_ptr_Function_v4int Function
+               OpStore %v %14
+         %18 = OpAccessChain %_ptr_Function_int %v %uint_0
+         %19 = OpLoad %int %18
+
+; The constant folder does not see through access chains. To get this, the
+; vector would have to be scalarized.
+; CHECK: [[result_id:%\d+]] = OpIAdd %int {{%\d+}} %int_1
+         %20 = OpIAdd %int %19 %int_1
+         %21 = OpAccessChain %_ptr_Function_int %v %uint_0
+
+; CHECK: OpStore {{%\d+}} [[result_id]]
+               OpStore %21 %20
+         %24 = OpLoad %v4int %v
+               OpStore %OutColor %24
+               OpReturn
+               OpFunctionEnd
+               )";
+
+  SinglePassRunAndMatch<opt::CCPPass>(spv_asm, true);
+}
+
+TEST_F(CCPTest, BadSimplifyFloatVector) {
+  const std::string spv_asm = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %OutColor
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %v "v"
+               OpName %OutColor "OutColor"
+               OpDecorate %OutColor Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+    %float_1 = OpConstant %float 1
+    %float_2 = OpConstant %float 2
+    %float_3 = OpConstant %float 3
+    %float_4 = OpConstant %float 4
+         %14 = OpConstantComposite %v4float %float_1 %float_2 %float_3 %float_4
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+   %OutColor = OpVariable %_ptr_Output_v4float Output
+       %main = OpFunction %void None %3
+          %5 = OpLabel
+          %v = OpVariable %_ptr_Function_v4float Function
+               OpStore %v %14
+         %18 = OpAccessChain %_ptr_Function_float %v %uint_0
+         %19 = OpLoad %float %18
+
+; NOTE: This test should start failing once floating point folding is
+;       implemented (https://github.com/KhronosGroup/SPIRV-Tools/issues/943).
+;       This should be checking that we are adding %float_1 + %float_1.
+; CHECK: [[result_id:%\d+]] = OpFAdd %float {{%\d+}} %float_1
+         %20 = OpFAdd %float %19 %float_1
+         %21 = OpAccessChain %_ptr_Function_float %v %uint_0
+
+; This should be checkint that we are storing %float_2 instead of result_it.
+; CHECK: OpStore {{%\d+}} [[result_id]]
+               OpStore %21 %20
+         %24 = OpLoad %v4float %v
+               OpStore %OutColor %24
+               OpReturn
+               OpFunctionEnd
+               )";
+
+  SinglePassRunAndMatch<opt::CCPPass>(spv_asm, true);
+}
+#endif
+
+}  // namespace
index 4e8064b..a8debdf 100644 (file)
@@ -330,9 +330,9 @@ INSTANTIATE_TEST_CASE_P(
 // Tests for operations that resulting in different types.
 INSTANTIATE_TEST_CASE_P(
     Cast, FoldSpecConstantOpAndCompositePassTest,
-    ::testing::ValuesIn(std::vector<
-                        FoldSpecConstantOpAndCompositePassTestCase>({
-        // clang-format off
+    ::testing::ValuesIn(
+        std::vector<FoldSpecConstantOpAndCompositePassTestCase>({
+            // clang-format off
             // int -> bool scalar
             {
               // original
@@ -441,13 +441,13 @@ INSTANTIATE_TEST_CASE_P(
               {
                 "%true = OpConstantTrue %bool",
                 "%true_0 = OpConstantTrue %bool",
-                "%spec_bool_t_vec = OpConstantComposite %v2bool %true %true_0",
+                "%spec_bool_t_vec = OpConstantComposite %v2bool %true_0 %true_0",
                 "%false = OpConstantFalse %bool",
                 "%false_0 = OpConstantFalse %bool",
-                "%spec_bool_f_vec = OpConstantComposite %v2bool %false %false_0",
+                "%spec_bool_f_vec = OpConstantComposite %v2bool %false_0 %false_0",
                 "%false_1 = OpConstantFalse %bool",
                 "%false_2 = OpConstantFalse %bool",
-                "%spec_bool_from_null = OpConstantComposite %v2bool %false_1 %false_2",
+                "%spec_bool_from_null = OpConstantComposite %v2bool %false_2 %false_2",
               },
             },
 
@@ -463,13 +463,13 @@ INSTANTIATE_TEST_CASE_P(
               {
                 "%true = OpConstantTrue %bool",
                 "%true_0 = OpConstantTrue %bool",
-                "%spec_bool_t_vec = OpConstantComposite %v2bool %true %true_0",
+                "%spec_bool_t_vec = OpConstantComposite %v2bool %true_0 %true_0",
                 "%false = OpConstantFalse %bool",
                 "%false_0 = OpConstantFalse %bool",
-                "%spec_bool_f_vec = OpConstantComposite %v2bool %false %false_0",
+                "%spec_bool_f_vec = OpConstantComposite %v2bool %false_0 %false_0",
                 "%false_1 = OpConstantFalse %bool",
                 "%false_2 = OpConstantFalse %bool",
-                "%spec_bool_from_null = OpConstantComposite %v2bool %false_1 %false_2",
+                "%spec_bool_from_null = OpConstantComposite %v2bool %false_2 %false_2",
               },
             },
 
@@ -485,13 +485,13 @@ INSTANTIATE_TEST_CASE_P(
               {
                 "%int_1 = OpConstant %int 1",
                 "%int_1_0 = OpConstant %int 1",
-                "%spec_int_one_vec = OpConstantComposite %v2int %int_1 %int_1_0",
+                "%spec_int_one_vec = OpConstantComposite %v2int %int_1_0 %int_1_0",
                 "%int_0 = OpConstant %int 0",
                 "%int_0_0 = OpConstant %int 0",
-                "%spec_int_zero_vec = OpConstantComposite %v2int %int_0 %int_0_0",
+                "%spec_int_zero_vec = OpConstantComposite %v2int %int_0_0 %int_0_0",
                 "%int_0_1 = OpConstant %int 0",
                 "%int_0_2 = OpConstant %int 0",
-                "%spec_int_from_null = OpConstantComposite %v2int %int_0_1 %int_0_2",
+                "%spec_int_from_null = OpConstantComposite %v2int %int_0_2 %int_0_2",
               },
             },
 
@@ -507,13 +507,13 @@ INSTANTIATE_TEST_CASE_P(
               {
                 "%int_1 = OpConstant %int 1",
                 "%int_1_0 = OpConstant %int 1",
-                "%spec_int_one_vec = OpConstantComposite %v2int %int_1 %int_1_0",
+                "%spec_int_one_vec = OpConstantComposite %v2int %int_1_0 %int_1_0",
                 "%int_0 = OpConstant %int 0",
                 "%int_0_0 = OpConstant %int 0",
-                "%spec_int_zero_vec = OpConstantComposite %v2int %int_0 %int_0_0",
+                "%spec_int_zero_vec = OpConstantComposite %v2int %int_0_0 %int_0_0",
                 "%int_0_1 = OpConstant %int 0",
                 "%int_0_2 = OpConstant %int 0",
-                "%spec_int_from_null = OpConstantComposite %v2int %int_0_1 %int_0_2",
+                "%spec_int_from_null = OpConstantComposite %v2int %int_0_2 %int_0_2",
               },
             },
 
@@ -529,13 +529,13 @@ INSTANTIATE_TEST_CASE_P(
               {
                 "%uint_1 = OpConstant %uint 1",
                 "%uint_1_0 = OpConstant %uint 1",
-                "%spec_uint_one_vec = OpConstantComposite %v2uint %uint_1 %uint_1_0",
+                "%spec_uint_one_vec = OpConstantComposite %v2uint %uint_1_0 %uint_1_0",
                 "%uint_0 = OpConstant %uint 0",
                 "%uint_0_0 = OpConstant %uint 0",
-                "%spec_uint_zero_vec = OpConstantComposite %v2uint %uint_0 %uint_0_0",
+                "%spec_uint_zero_vec = OpConstantComposite %v2uint %uint_0_0 %uint_0_0",
                 "%uint_0_1 = OpConstant %uint 0",
                 "%uint_0_2 = OpConstant %uint 0",
-                "%spec_uint_from_null = OpConstantComposite %v2uint %uint_0_1 %uint_0_2",
+                "%spec_uint_from_null = OpConstantComposite %v2uint %uint_0_2 %uint_0_2",
               },
             },
 
@@ -551,17 +551,17 @@ INSTANTIATE_TEST_CASE_P(
               {
                 "%uint_1 = OpConstant %uint 1",
                 "%uint_1_0 = OpConstant %uint 1",
-                "%spec_uint_one_vec = OpConstantComposite %v2uint %uint_1 %uint_1_0",
+                "%spec_uint_one_vec = OpConstantComposite %v2uint %uint_1_0 %uint_1_0",
                 "%uint_0 = OpConstant %uint 0",
                 "%uint_0_0 = OpConstant %uint 0",
-                "%spec_uint_zero_vec = OpConstantComposite %v2uint %uint_0 %uint_0_0",
+                "%spec_uint_zero_vec = OpConstantComposite %v2uint %uint_0_0 %uint_0_0",
                 "%uint_0_1 = OpConstant %uint 0",
                 "%uint_0_2 = OpConstant %uint 0",
-                "%spec_uint_from_null = OpConstantComposite %v2uint %uint_0_1 %uint_0_2",
+                "%spec_uint_from_null = OpConstantComposite %v2uint %uint_0_2 %uint_0_2",
               },
             },
-        // clang-format on
-    })));
+            // clang-format on
+        })));
 
 // Tests about boolean scalar logical operations and comparison operations with
 // scalar int/uint type.
@@ -836,13 +836,13 @@ INSTANTIATE_TEST_CASE_P(
               {
                 "%int_n1 = OpConstant %int -1",
                 "%int_n1_0 = OpConstant %int -1",
-                "%v2int_minus_1 = OpConstantComposite %v2int %int_n1 %int_n1_0",
+                "%v2int_minus_1 = OpConstantComposite %v2int %int_n1_0 %int_n1_0",
                 "%int_n2 = OpConstant %int -2",
                 "%int_n2_0 = OpConstant %int -2",
-                "%v2int_minus_2 = OpConstantComposite %v2int %int_n2 %int_n2_0",
+                "%v2int_minus_2 = OpConstantComposite %v2int %int_n2_0 %int_n2_0",
                 "%int_0 = OpConstant %int 0",
                 "%int_0_0 = OpConstant %int 0",
-                "%v2int_neg_null = OpConstantComposite %v2int %int_0 %int_0_0",
+                "%v2int_neg_null = OpConstantComposite %v2int %int_0_0 %int_0_0",
               },
             },
             // vector integer (including null vetors) add, sub, div, mul
@@ -865,35 +865,35 @@ INSTANTIATE_TEST_CASE_P(
               {
                 "%int_5 = OpConstant %int 5",
                 "%int_5_0 = OpConstant %int 5",
-                "%spec_v2int_iadd = OpConstantComposite %v2int %int_5 %int_5_0",
+                "%spec_v2int_iadd = OpConstantComposite %v2int %int_5_0 %int_5_0",
                 "%int_n4 = OpConstant %int -4",
                 "%int_n4_0 = OpConstant %int -4",
-                "%spec_v2int_isub = OpConstantComposite %v2int %int_n4 %int_n4_0",
+                "%spec_v2int_isub = OpConstantComposite %v2int %int_n4_0 %int_n4_0",
                 "%int_n2 = OpConstant %int -2",
                 "%int_n2_0 = OpConstant %int -2",
-                "%spec_v2int_sdiv = OpConstantComposite %v2int %int_n2 %int_n2_0",
+                "%spec_v2int_sdiv = OpConstantComposite %v2int %int_n2_0 %int_n2_0",
                 "%int_n6 = OpConstant %int -6",
                 "%int_n6_0 = OpConstant %int -6",
-                "%spec_v2int_imul = OpConstantComposite %v2int %int_n6 %int_n6_0",
+                "%spec_v2int_imul = OpConstantComposite %v2int %int_n6_0 %int_n6_0",
                 "%int_n6_1 = OpConstant %int -6",
                 "%int_n6_2 = OpConstant %int -6",
-                "%spec_v2int_iadd_null = OpConstantComposite %v2int %int_n6_1 %int_n6_2",
+                "%spec_v2int_iadd_null = OpConstantComposite %v2int %int_n6_2 %int_n6_2",
 
                 "%uint_5 = OpConstant %uint 5",
                 "%uint_5_0 = OpConstant %uint 5",
-                "%spec_v2uint_iadd = OpConstantComposite %v2uint %uint_5 %uint_5_0",
+                "%spec_v2uint_iadd = OpConstantComposite %v2uint %uint_5_0 %uint_5_0",
                 "%uint_4294967292 = OpConstant %uint 4294967292",
                 "%uint_4294967292_0 = OpConstant %uint 4294967292",
-                "%spec_v2uint_isub = OpConstantComposite %v2uint %uint_4294967292 %uint_4294967292_0",
+                "%spec_v2uint_isub = OpConstantComposite %v2uint %uint_4294967292_0 %uint_4294967292_0",
                 "%uint_1431655764 = OpConstant %uint 1431655764",
                 "%uint_1431655764_0 = OpConstant %uint 1431655764",
-                "%spec_v2uint_udiv = OpConstantComposite %v2uint %uint_1431655764 %uint_1431655764_0",
+                "%spec_v2uint_udiv = OpConstantComposite %v2uint %uint_1431655764_0 %uint_1431655764_0",
                 "%uint_2863311528 = OpConstant %uint 2863311528",
                 "%uint_2863311528_0 = OpConstant %uint 2863311528",
-                "%spec_v2uint_imul = OpConstantComposite %v2uint %uint_2863311528 %uint_2863311528_0",
+                "%spec_v2uint_imul = OpConstantComposite %v2uint %uint_2863311528_0 %uint_2863311528_0",
                 "%uint_2863311528_1 = OpConstant %uint 2863311528",
                 "%uint_2863311528_2 = OpConstant %uint 2863311528",
-                "%spec_v2uint_isub_null = OpConstantComposite %v2uint %uint_2863311528_1 %uint_2863311528_2",
+                "%spec_v2uint_isub_null = OpConstantComposite %v2uint %uint_2863311528_2 %uint_2863311528_2",
               },
             },
             // vector integer rem, mod
@@ -938,33 +938,33 @@ INSTANTIATE_TEST_CASE_P(
                 // srem
                 "%int_1 = OpConstant %int 1",
                 "%int_1_0 = OpConstant %int 1",
-                "%7_srem_3 = OpConstantComposite %v2int %int_1 %int_1_0",
+                "%7_srem_3 = OpConstantComposite %v2int %int_1_0 %int_1_0",
                 "%int_n1 = OpConstant %int -1",
                 "%int_n1_0 = OpConstant %int -1",
-                "%minus_7_srem_3 = OpConstantComposite %v2int %int_n1 %int_n1_0",
+                "%minus_7_srem_3 = OpConstantComposite %v2int %int_n1_0 %int_n1_0",
                 "%int_1_1 = OpConstant %int 1",
                 "%int_1_2 = OpConstant %int 1",
-                "%7_srem_minus_3 = OpConstantComposite %v2int %int_1_1 %int_1_2",
+                "%7_srem_minus_3 = OpConstantComposite %v2int %int_1_2 %int_1_2",
                 "%int_n1_1 = OpConstant %int -1",
                 "%int_n1_2 = OpConstant %int -1",
-                "%minus_7_srem_minus_3 = OpConstantComposite %v2int %int_n1_1 %int_n1_2",
+                "%minus_7_srem_minus_3 = OpConstantComposite %v2int %int_n1_2 %int_n1_2",
                 // smod
                 "%int_1_3 = OpConstant %int 1",
                 "%int_1_4 = OpConstant %int 1",
-                "%7_smod_3 = OpConstantComposite %v2int %int_1_3 %int_1_4",
+                "%7_smod_3 = OpConstantComposite %v2int %int_1_4 %int_1_4",
                 "%int_2 = OpConstant %int 2",
                 "%int_2_0 = OpConstant %int 2",
-                "%minus_7_smod_3 = OpConstantComposite %v2int %int_2 %int_2_0",
+                "%minus_7_smod_3 = OpConstantComposite %v2int %int_2_0 %int_2_0",
                 "%int_n2 = OpConstant %int -2",
                 "%int_n2_0 = OpConstant %int -2",
-                "%7_smod_minus_3 = OpConstantComposite %v2int %int_n2 %int_n2_0",
+                "%7_smod_minus_3 = OpConstantComposite %v2int %int_n2_0 %int_n2_0",
                 "%int_n1_3 = OpConstant %int -1",
                 "%int_n1_4 = OpConstant %int -1",
-                "%minus_7_smod_minus_3 = OpConstantComposite %v2int %int_n1_3 %int_n1_4",
+                "%minus_7_smod_minus_3 = OpConstantComposite %v2int %int_n1_4 %int_n1_4",
                 // umod
                 "%uint_1 = OpConstant %uint 1",
                 "%uint_1_0 = OpConstant %uint 1",
-                "%7_umod_3 = OpConstantComposite %v2uint %uint_1 %uint_1_0",
+                "%7_umod_3 = OpConstantComposite %v2uint %uint_1_0 %uint_1_0",
               },
             },
             // vector integer bitwise, shift
@@ -985,25 +985,25 @@ INSTANTIATE_TEST_CASE_P(
               {
                 "%int_2 = OpConstant %int 2",
                 "%int_2_0 = OpConstant %int 2",
-                "%xor_1_3 = OpConstantComposite %v2int %int_2 %int_2_0",
+                "%xor_1_3 = OpConstantComposite %v2int %int_2_0 %int_2_0",
                 "%int_0 = OpConstant %int 0",
                 "%int_0_0 = OpConstant %int 0",
-                "%and_1_2 = OpConstantComposite %v2int %int_0 %int_0_0",
+                "%and_1_2 = OpConstantComposite %v2int %int_0_0 %int_0_0",
                 "%int_3 = OpConstant %int 3",
                 "%int_3_0 = OpConstant %int 3",
-                "%or_1_2 = OpConstantComposite %v2int %int_3 %int_3_0",
+                "%or_1_2 = OpConstantComposite %v2int %int_3_0 %int_3_0",
 
                 "%unsigned_31 = OpConstant %uint 31",
                 "%v2unsigned_31 = OpConstantComposite %v2uint %unsigned_31 %unsigned_31",
                 "%uint_2147483648 = OpConstant %uint 2147483648",
                 "%uint_2147483648_0 = OpConstant %uint 2147483648",
-                "%unsigned_left_shift_max = OpConstantComposite %v2uint %uint_2147483648 %uint_2147483648_0",
+                "%unsigned_left_shift_max = OpConstantComposite %v2uint %uint_2147483648_0 %uint_2147483648_0",
                 "%uint_1 = OpConstant %uint 1",
                 "%uint_1_0 = OpConstant %uint 1",
-                "%unsigned_right_shift_logical = OpConstantComposite %v2uint %uint_1 %uint_1_0",
+                "%unsigned_right_shift_logical = OpConstantComposite %v2uint %uint_1_0 %uint_1_0",
                 "%int_n1 = OpConstant %int -1",
                 "%int_n1_0 = OpConstant %int -1",
-                "%signed_right_shift_arithmetic = OpConstantComposite %v2int %int_n1 %int_n1_0",
+                "%signed_right_shift_arithmetic = OpConstantComposite %v2int %int_n1_0 %int_n1_0",
               },
             },
             // Skip folding if any vector operands or components of the operands
@@ -1140,7 +1140,7 @@ INSTANTIATE_TEST_CASE_P(
                 "%float_1 = OpConstant %float 1",
                 "%inner = OpConstantComposite %inner_struct %bool_true %signed_null %float_1",
                 "%outer = OpConstantComposite %outer_struct %inner %signed_one",
-                "%extract_inner = OpConstantComposite %inner_struct %bool_true %signed_null %float_1",
+                "%extract_inner = OpConstantComposite %flat_struct %bool_true %signed_null %float_1",
                 "%extract_int = OpConstant %int 1",
                 "%extract_inner_float = OpConstant %float 1",
               },
index 6b94a3c..ead001b 100644 (file)
@@ -81,6 +81,12 @@ standard output.
 NOTE: The optimizer is a work in progress.
 
 Options (in lexicographical order):
+  --ccp
+               Apply the conditional constant propagation transform.  This will
+               propagate constant values throughout the program, and simplify
+               expressions and conditional jumps with known predicate
+               values.  Performed on entry point call tree functions and
+               exported functions.
   --cfg-cleanup
                Cleanup the control flow graph. This will remove any unnecessary
                code from the CFG like unreachable code. Performed on entry
@@ -449,6 +455,8 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer,
         if (status.action != OPT_CONTINUE) {
           return status;
         }
+      } else if (0 == strcmp(cur_arg, "--ccp")) {
+        optimizer->RegisterPass(CreateCCPPass());
       } else if ('\0' == cur_arg[1]) {
         // Setting a filename of "-" to indicate stdin.
         if (!*in_file) {