Add scalar replacement
authorAlan Baker <alanbaker@google.com>
Thu, 30 Nov 2017 22:03:06 +0000 (17:03 -0500)
committerSteven Perron <stevenperron@google.com>
Mon, 11 Dec 2017 15:51:13 +0000 (10:51 -0500)
Adds a scalar replacement pass. The pass considers all function scope
variables of composite type. If there are accesses to individual
elements (and it is legal) the pass replaces the variable with a
variable for each composite element and updates all the uses.

Added the pass to -O
Added NumUses and NumUsers to DefUseManager
Added some helper methods for the inst to block mapping in context
Added some helper methods for specific constant types

No longer generate duplicate pointer types.

* Now searches for an existing pointer of the appropriate type instead
of failing validation
* Fixed spec constant extracts
* Addressed changes for review
* Changed RunSinglePassAndMatch to be able to run validation
 * current users do not enable it

Added handling of acceptable decorations.

* Decorations are also transfered where appropriate

Refactored extension checking into FeatureManager

* Context now owns a feature manager
 * consciously NOT an analysis
 * added some test
* fixed some minor issues related to decorates
* added some decorate related tests for scalar replacement

25 files changed:
Android.mk
include/spirv-tools/optimizer.hpp
source/opt/CMakeLists.txt
source/opt/def_use_manager.cpp
source/opt/def_use_manager.h
source/opt/feature_manager.cpp [new file with mode: 0644]
source/opt/feature_manager.h [new file with mode: 0644]
source/opt/ir_context.h
source/opt/optimizer.cpp
source/opt/passes.h
source/opt/reflect.h
source/opt/scalar_replacement_pass.cpp [new file with mode: 0644]
source/opt/scalar_replacement_pass.h [new file with mode: 0644]
test/opt/CMakeLists.txt
test/opt/common_uniform_elim_test.cpp
test/opt/eliminate_dead_functions_test.cpp
test/opt/feature_manager_test.cpp [new file with mode: 0644]
test/opt/fold_spec_const_op_composite_test.cpp
test/opt/local_redundancy_elimination_test.cpp
test/opt/pass_fixture.h
test/opt/redundancy_elimination_test.cpp
test/opt/scalar_replacement_test.cpp [new file with mode: 0644]
test/opt/strength_reduction_test.cpp
test/opt/unify_const_test.cpp
tools/opt/opt.cpp

index 527bce1..0554e90 100644 (file)
@@ -66,6 +66,7 @@ SPVTOOLS_OPT_SRC_FILES := \
                source/opt/dominator_tree.cpp \
                source/opt/eliminate_dead_constant_pass.cpp \
                source/opt/eliminate_dead_functions_pass.cpp \
+               source/opt/feature_manager.cpp \
                source/opt/flatten_decoration_pass.cpp \
                source/opt/fold.cpp \
                source/opt/fold_spec_constant_op_and_composite_pass.cpp \
@@ -93,6 +94,7 @@ SPVTOOLS_OPT_SRC_FILES := \
                source/opt/propagator.cpp \
                source/opt/redundancy_elimination.cpp \
                source/opt/remove_duplicates_pass.cpp \
+               source/opt/scalar_replacement_pass.cpp \
                source/opt/set_spec_constant_default_value_pass.cpp \
                source/opt/strength_reduction_pass.cpp \
                source/opt/strip_debug_info_pass.cpp \
index eea50a8..754d2b9 100644 (file)
@@ -437,6 +437,11 @@ Optimizer::PassToken CreateLocalRedundancyEliminationPass();
 // This pass will look for instructions where the same value is computed on all
 // paths leading to the instruction.  Those instructions are deleted.
 Optimizer::PassToken CreateRedundancyEliminationPass();
+
+// Create scalar replacement pass.
+// This pass replaces composite function scope variables with variables for each
+// element if those elements are accessed individually.
+Optimizer::PassToken CreateScalarReplacementPass();
 }  // namespace spvtools
 
 #endif  // SPIRV_TOOLS_OPTIMIZER_HPP_
index aa96b2c..1d47ff0 100644 (file)
@@ -29,6 +29,7 @@ add_library(SPIRV-Tools-opt
   dominator_tree.h
   eliminate_dead_constant_pass.h
   eliminate_dead_functions_pass.h
+  feature_manager.h
   flatten_decoration_pass.h
   fold.h
   fold_spec_constant_op_and_composite_pass.h
@@ -58,6 +59,7 @@ add_library(SPIRV-Tools-opt
   redundancy_elimination.h
   reflect.h
   remove_duplicates_pass.h
+  scalar_replacement_pass.h
   set_spec_constant_default_value_pass.h
   strength_reduction_pass.h
   strip_debug_info_pass.h
@@ -83,6 +85,7 @@ add_library(SPIRV-Tools-opt
   dominator_tree.cpp
   eliminate_dead_constant_pass.cpp
   eliminate_dead_functions_pass.cpp
+  feature_manager.cpp
   flatten_decoration_pass.cpp
   fold.cpp
   fold_spec_constant_op_and_composite_pass.cpp
@@ -110,6 +113,7 @@ add_library(SPIRV-Tools-opt
   propagator.cpp
   redundancy_elimination.cpp
   remove_duplicates_pass.cpp
+  scalar_replacement_pass.cpp
   set_spec_constant_default_value_pass.cpp
   strength_reduction_pass.cpp
   strip_debug_info_pass.cpp
index be9dc6c..8d4433b 100644 (file)
@@ -136,6 +136,26 @@ void DefUseManager::ForEachUse(
   ForEachUse(GetDef(id), f);
 }
 
+uint32_t DefUseManager::NumUsers(const ir::Instruction* def) const {
+  uint32_t count = 0;
+  ForEachUser(def, [&count](ir::Instruction*) { ++count; });
+  return count;
+}
+
+uint32_t DefUseManager::NumUsers(uint32_t id) const {
+  return NumUsers(GetDef(id));
+}
+
+uint32_t DefUseManager::NumUses(const ir::Instruction* def) const {
+  uint32_t count = 0;
+  ForEachUse(def, [&count](ir::Instruction*, uint32_t) { ++count; });
+  return count;
+}
+
+uint32_t DefUseManager::NumUses(uint32_t id) const {
+  return NumUses(GetDef(id));
+}
+
 std::vector<ir::Instruction*> DefUseManager::GetAnnotations(uint32_t id) const {
   std::vector<ir::Instruction*> annos;
   const ir::Instruction* def = GetDef(id);
index c660927..1a8d989 100644 (file)
@@ -151,6 +151,14 @@ class DefUseManager {
                   const std::function<void(ir::Instruction*,
                                            uint32_t operand_index)>& f) const;
 
+  // Returns the number of users of |def| (or |id|).
+  uint32_t NumUsers(const ir::Instruction* def) const;
+  uint32_t NumUsers(uint32_t id) const;
+
+  // Returns the number of uses of |def| (or |id|).
+  uint32_t NumUses(const ir::Instruction* def) const;
+  uint32_t NumUses(uint32_t id) const;
+
   // Returns the annotation instrunctions which are a direct use of the given
   // |id|. This means when the decorations are applied through decoration
   // group(s), this function will just return the OpGroupDecorate
diff --git a/source/opt/feature_manager.cpp b/source/opt/feature_manager.cpp
new file mode 100644 (file)
index 0000000..5bd43e9
--- /dev/null
@@ -0,0 +1,34 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "feature_manager.h"
+
+#include "enum_string_mapping.h"
+
+namespace spvtools {
+namespace opt {
+
+void FeatureManager::Analyze(ir::Module* module) {
+  for (auto ext : module->extensions()) {
+    const std::string name =
+        reinterpret_cast<const char*>(ext.GetInOperand(0u).words.data());
+    libspirv::Extension extension;
+    if (libspirv::GetExtensionFromString(name, &extension)) {
+      extensions_.Add(extension);
+    }
+  }
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/feature_manager.h b/source/opt/feature_manager.h
new file mode 100644 (file)
index 0000000..874bf50
--- /dev/null
@@ -0,0 +1,45 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_FEATURE_MANAGER_H_
+#define LIBSPIRV_OPT_FEATURE_MANAGER_H_
+
+#include "extensions.h"
+#include "module.h"
+
+namespace spvtools {
+namespace opt {
+
+// Tracks features enabled by a module. The IRContext has a FeatureManager.
+class FeatureManager {
+ public:
+  FeatureManager() = default;
+
+  // Returns true if |ext| is an enabled extension in the module.
+  bool HasExtension(libspirv::Extension ext) const {
+    return extensions_.Contains(ext);
+  }
+
+  // Analyzes |module| and records enabled extensions.
+  void Analyze(ir::Module* module);
+
+ private:
+  // The enabled extensions.
+  libspirv::ExtensionSet extensions_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_FEATURE_MANAGER_H_
index 646a30f..c0d21f1 100644 (file)
@@ -20,6 +20,7 @@
 #include "decoration_manager.h"
 #include "def_use_manager.h"
 #include "dominator_analysis.h"
+#include "feature_manager.h"
 #include "module.h"
 #include "type_manager.h"
 
@@ -198,6 +199,25 @@ class IRContext {
     return (entry != instr_to_block_.end()) ? entry->second : nullptr;
   }
 
+  // Returns the basic block for |id|. Re-builds the instruction block map, if
+  // needed.
+  //
+  // |id| must be a registered definition.
+  ir::BasicBlock* get_instr_block(uint32_t id) {
+    ir::Instruction* def = get_def_use_mgr()->GetDef(id);
+    return get_instr_block(def);
+  }
+
+  // Sets the basic block for |inst|. Re-builds the mapping if it has become
+  // invalid.
+  void set_instr_block(ir::Instruction* inst, ir::BasicBlock* block) {
+    if (AreAnalysesValid(kAnalysisInstrToBlockMapping)) {
+      instr_to_block_[inst] = block;
+    } else {
+      BuildInstrToBlockMapping();
+    }
+  }
+
   // Returns a pointer the decoration manager.  If the decoration manger is
   // invalid, it is rebuilt first.
   opt::analysis::DecorationManager* get_decoration_mgr() {
@@ -351,6 +371,13 @@ class IRContext {
   // Return the next available SSA id and increment it.
   inline uint32_t TakeNextId() { return module()->TakeNextIdBound(); }
 
+  opt::FeatureManager* get_feature_mgr() {
+    if (!feature_mgr_.get()) {
+      AnalyzeFeatures();
+    }
+    return feature_mgr_.get();
+  }
+
  private:
   // Builds the def-use manager from scratch, even if it was already valid.
   void BuildDefUseManager() {
@@ -381,6 +408,12 @@ class IRContext {
     valid_analyses_ = valid_analyses_ | kAnalysisCFG;
   }
 
+  // Analyzes the features in the owned module. Builds the manager if required.
+  void AnalyzeFeatures() {
+    feature_mgr_.reset(new opt::FeatureManager());
+    feature_mgr_->Analyze(module());
+  }
+
   // Scans a module looking for it capabilities, and initializes combinator_ops_
   // accordingly.
   void InitializeCombinators();
@@ -409,6 +442,7 @@ class IRContext {
 
   // The instruction decoration manager for |module_|.
   std::unique_ptr<opt::analysis::DecorationManager> decoration_mgr_;
+  std::unique_ptr<opt::FeatureManager> feature_mgr_;
 
   // A map from instructions the the basic block they belong to. This mapping is
   // built on-demand when get_instr_block() is called.
index 2c9fb34..82e57f8 100644 (file)
@@ -85,6 +85,7 @@ Optimizer& Optimizer::RegisterPerformancePasses() {
   return RegisterPass(CreateMergeReturnPass())
       .RegisterPass(CreateInlineExhaustivePass())
       .RegisterPass(CreateEliminateDeadFunctionsPass())
+      .RegisterPass(CreateScalarReplacementPass())
       .RegisterPass(CreateLocalAccessChainConvertPass())
       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
       .RegisterPass(CreateLocalSingleStoreElimPass())
@@ -287,4 +288,9 @@ Optimizer::PassToken CreateRedundancyEliminationPass() {
   return MakeUnique<Optimizer::PassToken::Impl>(
       MakeUnique<opt::RedundancyEliminationPass>());
 }
+
+Optimizer::PassToken CreateScalarReplacementPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::ScalarReplacementPass>());
+}
 }  // namespace spvtools
index 9c1c580..ba3b270 100644 (file)
@@ -40,6 +40,7 @@
 #include "merge_return_pass.h"
 #include "null_pass.h"
 #include "redundancy_elimination.h"
+#include "scalar_replacement_pass.h"
 #include "set_spec_constant_default_value_pass.h"
 #include "strength_reduction_pass.h"
 #include "strip_debug_info_pass.h"
index 46d0049..ce5c331 100644 (file)
@@ -47,6 +47,12 @@ inline bool IsTypeInst(SpvOp opcode) {
 inline bool IsConstantInst(SpvOp opcode) {
   return opcode >= SpvOpConstantTrue && opcode <= SpvOpSpecConstantOp;
 }
+inline bool IsCompileTimeConstantInst(SpvOp opcode) {
+  return opcode >= SpvOpConstantTrue && opcode <= SpvOpConstantNull;
+}
+inline bool IsSpecConstantInst(SpvOp opcode) {
+  return opcode >= SpvOpSpecConstantTrue && opcode <= SpvOpSpecConstantOp;
+}
 inline bool IsTerminatorInst(SpvOp opcode) {
   return opcode >= SpvOpBranch && opcode <= SpvOpUnreachable;
 }
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp
new file mode 100644 (file)
index 0000000..fd146b2
--- /dev/null
@@ -0,0 +1,680 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "scalar_replacement_pass.h"
+
+#include "enum_string_mapping.h"
+#include "extensions.h"
+#include "make_unique.h"
+#include "reflect.h"
+#include "types.h"
+
+#include <queue>
+
+namespace spvtools {
+namespace opt {
+
+// Heuristic aggregate element limit.
+const uint32_t MAX_NUM_ELEMENTS = 100u;
+
+Pass::Status ScalarReplacementPass::Process(ir::IRContext* c) {
+  InitializeProcessing(c);
+
+  Status status = Status::SuccessWithoutChange;
+  for (auto& f : *get_module()) {
+    Status functionStatus = ProcessFunction(&f);
+    if (functionStatus == Status::Failure)
+      return functionStatus;
+    else if (functionStatus == Status::SuccessWithChange)
+      status = functionStatus;
+  }
+
+  return status;
+}
+
+Pass::Status ScalarReplacementPass::ProcessFunction(ir::Function* function) {
+  std::queue<ir::Instruction*> worklist;
+  ir::BasicBlock& entry = *function->begin();
+  for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
+    // Function storage class OpVariables must appear as the first instructions
+    // of the entry block.
+    if (iter->opcode() != SpvOpVariable) break;
+
+    ir::Instruction* varInst = &*iter;
+    if (CanReplaceVariable(varInst)) {
+      worklist.push(varInst);
+    }
+  }
+
+  Status status = Status::SuccessWithoutChange;
+  while (!worklist.empty()) {
+    ir::Instruction* varInst = worklist.front();
+    worklist.pop();
+
+    if (!ReplaceVariable(varInst, &worklist))
+      return Status::Failure;
+    else
+      status = Status::SuccessWithChange;
+  }
+
+  return status;
+}
+
+bool ScalarReplacementPass::ReplaceVariable(
+    ir::Instruction* inst, std::queue<ir::Instruction*>* worklist) {
+  std::vector<ir::Instruction*> replacements;
+  CreateReplacementVariables(inst, &replacements);
+
+  bool ok = true;
+  std::vector<ir::Instruction*> dead;
+  dead.push_back(inst);
+  get_def_use_mgr()->ForEachUser(
+      inst, [this, &ok, &replacements, &dead](ir::Instruction* user) {
+        if (!ir::IsAnnotationInst(user->opcode())) {
+          switch (user->opcode()) {
+            case SpvOpLoad:
+              ReplaceWholeLoad(user, replacements);
+              dead.push_back(user);
+              break;
+            case SpvOpStore:
+              ReplaceWholeStore(user, replacements);
+              dead.push_back(user);
+              break;
+            case SpvOpAccessChain:
+            case SpvOpInBoundsAccessChain:
+              ok &= ReplaceAccessChain(user, replacements);
+              dead.push_back(user);
+              break;
+            case SpvOpName:
+            case SpvOpMemberName:
+              break;
+            default:
+              assert(false && "Unexpected opcode");
+              break;
+          }
+        }
+      });
+
+  // There was an illegal access.
+  if (!ok) return false;
+
+  // Clean up some dead code.
+  while (!dead.empty()) {
+    ir::Instruction* toKill = dead.back();
+    dead.pop_back();
+
+    context()->KillInst(toKill);
+  }
+
+  // Attempt to further scalarize.
+  for (auto var : replacements) {
+    if (get_def_use_mgr()->NumUsers(var) == 0) {
+      context()->KillInst(var);
+    } else if (CanReplaceVariable(var)) {
+      worklist->push(var);
+    }
+  }
+
+  return ok;
+}
+
+void ScalarReplacementPass::ReplaceWholeLoad(
+    ir::Instruction* load, const std::vector<ir::Instruction*>& replacements) {
+  // Replaces the load of the entire composite with a load from each replacement
+  // variable followed by a composite construction.
+  ir::BasicBlock* block = context()->get_instr_block(load);
+  std::vector<ir::Instruction*> loads;
+  loads.reserve(replacements.size());
+  ir::BasicBlock::iterator where(load);
+  for (auto var : replacements) {
+    // Create a load of each replacement variable.
+    ir::Instruction* type = GetStorageType(var);
+    uint32_t loadId = TakeNextId();
+    std::unique_ptr<ir::Instruction> newLoad(
+        new ir::Instruction(context(), SpvOpLoad, type->result_id(), loadId,
+                            std::initializer_list<ir::Operand>{
+                                {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
+    // Copy memory access attributes which start at index 1. Index 0 is the
+    // pointer to load.
+    for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
+      ir::Operand copy(load->GetInOperand(i));
+      newLoad->AddOperand(std::move(copy));
+    }
+    where = where.InsertBefore(std::move(newLoad));
+    get_def_use_mgr()->AnalyzeInstDefUse(&*where);
+    context()->set_instr_block(&*where, block);
+    loads.push_back(&*where);
+  }
+
+  // Construct a new composite.
+  uint32_t compositeId = TakeNextId();
+  where = load;
+  std::unique_ptr<ir::Instruction> compositeConstruct(new ir::Instruction(
+      context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
+  for (auto l : loads) {
+    ir::Operand op(SPV_OPERAND_TYPE_ID,
+                   std::initializer_list<uint32_t>{l->result_id()});
+    compositeConstruct->AddOperand(std::move(op));
+  }
+  where = where.InsertBefore(std::move(compositeConstruct));
+  get_def_use_mgr()->AnalyzeInstDefUse(&*where);
+  context()->set_instr_block(&*where, block);
+  context()->ReplaceAllUsesWith(load->result_id(), compositeId);
+}
+
+void ScalarReplacementPass::ReplaceWholeStore(
+    ir::Instruction* store, const std::vector<ir::Instruction*>& replacements) {
+  // Replaces a store to the whole composite with a series of extract and stores
+  // to each element.
+  uint32_t storeInput = store->GetSingleWordInOperand(1u);
+  ir::BasicBlock* block = context()->get_instr_block(store);
+  ir::BasicBlock::iterator where(store);
+  uint32_t elementIndex = 0;
+  for (auto var : replacements) {
+    // Create the extract.
+    ir::Instruction* type = GetStorageType(var);
+    uint32_t extractId = TakeNextId();
+    std::unique_ptr<ir::Instruction> extract(new ir::Instruction(
+        context(), SpvOpCompositeExtract, type->result_id(), extractId,
+        std::initializer_list<ir::Operand>{
+            {SPV_OPERAND_TYPE_ID, {storeInput}},
+            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
+    auto iter = where.InsertBefore(std::move(extract));
+    get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
+    context()->set_instr_block(&*iter, block);
+
+    // Create the store.
+    std::unique_ptr<ir::Instruction> newStore(
+        new ir::Instruction(context(), SpvOpStore, 0, 0,
+                            std::initializer_list<ir::Operand>{
+                                {SPV_OPERAND_TYPE_ID, {var->result_id()}},
+                                {SPV_OPERAND_TYPE_ID, {extractId}}}));
+    // Copy memory access attributes which start at index 2. Index 0 is the
+    // pointer and index 1 is the data.
+    for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
+      ir::Operand copy(store->GetInOperand(i));
+      newStore->AddOperand(std::move(copy));
+    }
+    iter = where.InsertBefore(std::move(newStore));
+    get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
+    context()->set_instr_block(&*iter, block);
+  }
+}
+
+bool ScalarReplacementPass::ReplaceAccessChain(
+    ir::Instruction* chain, const std::vector<ir::Instruction*>& replacements) {
+  // Replaces the access chain with either another access chain (with one fewer
+  // indexes) or a direct use of the replacement variable.
+  uint32_t indexId = chain->GetSingleWordInOperand(1u);
+  const ir::Instruction* index = get_def_use_mgr()->GetDef(indexId);
+  size_t indexValue = GetConstantInteger(index);
+  if (indexValue > replacements.size()) {
+    // Out of bounds access, this is illegal IR.
+    return false;
+  } else {
+    const ir::Instruction* var = replacements[indexValue];
+    if (chain->NumInOperands() > 2) {
+      // Replace input access chain with another access chain.
+      ir::BasicBlock::iterator chainIter(chain);
+      uint32_t replacementId = TakeNextId();
+      std::unique_ptr<ir::Instruction> replacementChain(new ir::Instruction(
+          context(), chain->opcode(), chain->type_id(), replacementId,
+          std::initializer_list<ir::Operand>{
+              {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
+      // Add the remaining indexes.
+      for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
+        ir::Operand copy(chain->GetInOperand(i));
+        replacementChain->AddOperand(std::move(copy));
+      }
+      auto iter = chainIter.InsertBefore(std::move(replacementChain));
+      get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
+      context()->set_instr_block(&*iter, context()->get_instr_block(chain));
+      context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
+    } else {
+      // Replace with a use of the variable.
+      context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
+    }
+  }
+
+  return true;
+}
+
+void ScalarReplacementPass::CreateReplacementVariables(
+    ir::Instruction* inst, std::vector<ir::Instruction*>* replacements) {
+  ir::Instruction* type = GetStorageType(inst);
+  uint32_t elem = 0;
+  switch (type->opcode()) {
+    case SpvOpTypeStruct:
+      type->ForEachInOperand([this, inst, &elem, replacements](uint32_t* id) {
+        CreateVariable(*id, inst, elem++, replacements);
+      });
+      break;
+    case SpvOpTypeArray:
+      for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
+        CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
+      }
+      break;
+
+    case SpvOpTypeMatrix:
+    case SpvOpTypeVector:
+      for (uint32_t i = 0; i != GetNumElements(type); ++i) {
+        CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
+      }
+      break;
+
+    default:
+      assert(false && "Unexpected type.");
+      break;
+  }
+
+  TransferAnnotations(inst, replacements);
+}
+
+void ScalarReplacementPass::TransferAnnotations(
+    const ir::Instruction* source,
+    std::vector<ir::Instruction*>* replacements) {
+  // Only transfer invariant and restrict decorations on the variable. There are
+  // no type or member decorations that are necessary to transfer.
+  for (auto inst :
+       get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
+    assert(inst->opcode() == SpvOpDecorate);
+    uint32_t decoration = inst->GetSingleWordInOperand(1u);
+    if (decoration == SpvDecorationInvariant ||
+        decoration == SpvDecorationRestrict) {
+      for (auto var : *replacements) {
+        std::unique_ptr<ir::Instruction> annotation(new ir::Instruction(
+            context(), SpvOpDecorate, 0, 0,
+            std::initializer_list<ir::Operand>{
+                {SPV_OPERAND_TYPE_ID, {var->result_id()}},
+                {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
+        for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
+          ir::Operand copy(inst->GetInOperand(i));
+          annotation->AddOperand(std::move(copy));
+        }
+        context()->AddAnnotationInst(std::move(annotation));
+        get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
+      }
+    }
+  }
+}
+
+void ScalarReplacementPass::CreateVariable(
+    uint32_t typeId, ir::Instruction* varInst, uint32_t index,
+    std::vector<ir::Instruction*>* replacements) {
+  uint32_t ptrId = GetOrCreatePointerType(typeId);
+  uint32_t id = TakeNextId();
+  std::unique_ptr<ir::Instruction> variable(new ir::Instruction(
+      context(), SpvOpVariable, ptrId, id,
+      std::initializer_list<ir::Operand>{
+          {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
+
+  ir::BasicBlock* block = context()->get_instr_block(varInst);
+  block->begin().InsertBefore(std::move(variable));
+  ir::Instruction* inst = &*block->begin();
+
+  // If varInst was initialized, make sure to initialize its replacement.
+  GetOrCreateInitialValue(varInst, index, inst);
+  get_def_use_mgr()->AnalyzeInstDefUse(inst);
+  context()->set_instr_block(inst, block);
+
+  replacements->push_back(inst);
+}
+
+uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
+  auto iter = pointee_to_pointer_.find(id);
+  if (iter != pointee_to_pointer_.end()) return iter->second;
+
+  // TODO(alanbaker): Make the type manager useful and then replace this code.
+  uint32_t ptrId = 0;
+  for (auto global : context()->types_values()) {
+    if (global.opcode() == SpvOpTypePointer &&
+        global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
+        global.GetSingleWordInOperand(1u) == id) {
+      if (!context()->get_feature_mgr()->HasExtension(
+              libspirv::Extension::kSPV_KHR_variable_pointers) ||
+          get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
+        // If variable pointers is enabled, only reuse a decoration-less
+        // pointer of the correct type
+        ptrId = global.result_id();
+        break;
+      }
+    }
+  }
+
+  if (ptrId != 0) {
+    pointee_to_pointer_[id] = ptrId;
+    return ptrId;
+  }
+
+  ptrId = TakeNextId();
+  context()->AddType(MakeUnique<ir::Instruction>(
+      context(), SpvOpTypePointer, 0, ptrId,
+      std::initializer_list<ir::Operand>{
+          {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
+          {SPV_OPERAND_TYPE_ID, {id}}}));
+  ir::Instruction* ptr = &*--context()->types_values_end();
+  get_def_use_mgr()->AnalyzeInstDefUse(ptr);
+  pointee_to_pointer_[id] = ptrId;
+
+  return ptrId;
+}
+
+void ScalarReplacementPass::GetOrCreateInitialValue(ir::Instruction* source,
+                                                    uint32_t index,
+                                                    ir::Instruction* newVar) {
+  assert(source->opcode() == SpvOpVariable);
+  if (source->NumInOperands() < 2) return;
+
+  uint32_t initId = source->GetSingleWordInOperand(1u);
+  uint32_t storageId = GetStorageType(newVar)->result_id();
+  ir::Instruction* init = get_def_use_mgr()->GetDef(initId);
+  uint32_t newInitId = 0;
+  // TODO(dnovillo): Refactor this with constant propagation.
+  if (init->opcode() == SpvOpConstantNull) {
+    // Initialize to appropriate NULL.
+    auto iter = type_to_null_.find(storageId);
+    if (iter == type_to_null_.end()) {
+      newInitId = TakeNextId();
+      type_to_null_[storageId] = newInitId;
+      context()->AddGlobalValue(MakeUnique<ir::Instruction>(
+          context(), SpvOpConstantNull, storageId, newInitId,
+          std::initializer_list<ir::Operand>{}));
+      ir::Instruction* newNull = &*--context()->types_values_end();
+      get_def_use_mgr()->AnalyzeInstDefUse(newNull);
+    } else {
+      newInitId = iter->second;
+    }
+  } else if (ir::IsSpecConstantInst(init->opcode())) {
+    // Create a new constant extract.
+    newInitId = TakeNextId();
+    context()->AddGlobalValue(MakeUnique<ir::Instruction>(
+        context(), SpvOpSpecConstantOp, storageId, newInitId,
+        std::initializer_list<ir::Operand>{
+            {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
+            {SPV_OPERAND_TYPE_ID, {init->result_id()}},
+            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
+    ir::Instruction* newSpecConst = &*--context()->types_values_end();
+    get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
+  } else if (init->opcode() == SpvOpConstantComposite) {
+    // Get the appropriate index constant.
+    newInitId = init->GetSingleWordInOperand(index);
+    ir::Instruction* element = get_def_use_mgr()->GetDef(newInitId);
+    if (element->opcode() == SpvOpUndef) {
+      // Undef is not a valid initializer for a variable.
+      newInitId = 0;
+    }
+  } else {
+    assert(false);
+  }
+
+  if (newInitId != 0) {
+    newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
+  }
+}
+
+size_t ScalarReplacementPass::GetIntegerLiteral(const ir::Operand& op) const {
+  assert(op.words.size() <= 2);
+  size_t len = 0;
+  for (uint32_t i = 0; i != op.words.size(); ++i) {
+    len |= (op.words[i] << (32 * i));
+  }
+  return len;
+}
+
+size_t ScalarReplacementPass::GetConstantInteger(
+    const ir::Instruction* constant) const {
+  assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
+         SpvOpTypeInt);
+  assert(constant->opcode() == SpvOpConstant ||
+         constant->opcode() == SpvOpConstantNull);
+  if (constant->opcode() == SpvOpConstantNull) {
+    return 0;
+  }
+
+  const ir::Operand& op = constant->GetInOperand(0u);
+  return GetIntegerLiteral(op);
+}
+
+size_t ScalarReplacementPass::GetArrayLength(
+    const ir::Instruction* arrayType) const {
+  assert(arrayType->opcode() == SpvOpTypeArray);
+  const ir::Instruction* length =
+      get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
+  return GetConstantInteger(length);
+}
+
+size_t ScalarReplacementPass::GetNumElements(
+    const ir::Instruction* type) const {
+  assert(type->opcode() == SpvOpTypeVector ||
+         type->opcode() == SpvOpTypeMatrix);
+  const ir::Operand& op = type->GetInOperand(1u);
+  assert(op.words.size() <= 2);
+  size_t len = 0;
+  for (uint32_t i = 0; i != op.words.size(); ++i) {
+    len |= (op.words[i] << (32 * i));
+  }
+  return len;
+}
+
+ir::Instruction* ScalarReplacementPass::GetStorageType(
+    const ir::Instruction* inst) const {
+  assert(inst->opcode() == SpvOpVariable);
+
+  uint32_t ptrTypeId = inst->type_id();
+  uint32_t typeId =
+      get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u);
+  return get_def_use_mgr()->GetDef(typeId);
+}
+
+bool ScalarReplacementPass::CanReplaceVariable(
+    const ir::Instruction* varInst) const {
+  assert(varInst->opcode() == SpvOpVariable);
+
+  // Can only replace function scope variables.
+  if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction)
+    return false;
+
+  if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id())))
+    return false;
+
+  const ir::Instruction* typeInst = GetStorageType(varInst);
+  return CheckType(typeInst) && CheckAnnotations(varInst) && CheckUses(varInst);
+}
+
+bool ScalarReplacementPass::CheckType(const ir::Instruction* typeInst) const {
+  if (!CheckTypeAnnotations(typeInst)) return false;
+
+  switch (typeInst->opcode()) {
+    case SpvOpTypeStruct:
+      // Don't bother with empty structs or very large structs.
+      if (typeInst->NumInOperands() == 0 ||
+          typeInst->NumInOperands() > MAX_NUM_ELEMENTS)
+        return false;
+      return true;
+    case SpvOpTypeArray:
+      if (GetArrayLength(typeInst) > MAX_NUM_ELEMENTS) return false;
+      return true;
+    // TODO(alanbaker): Develop some heuristics for when this should be
+    // re-enabled.
+    //// Specifically including matrix and vector in an attempt to reduce the
+    //// number of vector registers required.
+    // case SpvOpTypeMatrix:
+    // case SpvOpTypeVector:
+    //  if (GetNumElements(typeInst) > MAX_NUM_ELEMENTS) return false;
+    //  return true;
+
+    case SpvOpTypeRuntimeArray:
+    default:
+      return false;
+  }
+}
+
+bool ScalarReplacementPass::CheckTypeAnnotations(
+    const ir::Instruction* typeInst) const {
+  for (auto inst :
+       get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
+    uint32_t decoration;
+    if (inst->opcode() == SpvOpDecorate) {
+      decoration = inst->GetSingleWordInOperand(1u);
+    } else {
+      assert(inst->opcode() == SpvOpMemberDecorate);
+      decoration = inst->GetSingleWordInOperand(2u);
+    }
+
+    switch (decoration) {
+      case SpvDecorationRowMajor:
+      case SpvDecorationColMajor:
+      case SpvDecorationArrayStride:
+      case SpvDecorationMatrixStride:
+      case SpvDecorationCPacked:
+      case SpvDecorationInvariant:
+      case SpvDecorationRestrict:
+      case SpvDecorationOffset:
+      case SpvDecorationAlignment:
+      case SpvDecorationAlignmentId:
+      case SpvDecorationMaxByteOffset:
+        break;
+      default:
+        return false;
+    }
+  }
+
+  return true;
+}
+
+bool ScalarReplacementPass::CheckAnnotations(
+    const ir::Instruction* varInst) const {
+  for (auto inst :
+       get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
+    assert(inst->opcode() == SpvOpDecorate);
+    uint32_t decoration = inst->GetSingleWordInOperand(1u);
+    switch (decoration) {
+      case SpvDecorationInvariant:
+      case SpvDecorationRestrict:
+      case SpvDecorationAlignment:
+      case SpvDecorationAlignmentId:
+      case SpvDecorationMaxByteOffset:
+        break;
+      default:
+        return false;
+    }
+  }
+
+  return true;
+}
+
+bool ScalarReplacementPass::CheckUses(const ir::Instruction* inst) const {
+  VariableStats stats = {0, 0};
+  bool ok = CheckUses(inst, &stats);
+
+  // TODO(alanbaker): Extend this to some meaningful heuristics about when
+  // SRoA is valuable.
+  if (stats.num_partial_accesses == 0) ok = false;
+
+  return ok;
+}
+
+bool ScalarReplacementPass::CheckUses(const ir::Instruction* inst,
+                                      VariableStats* stats) const {
+  bool ok = true;
+  get_def_use_mgr()->ForEachUse(
+      inst, [this, stats, &ok](const ir::Instruction* user, uint32_t index) {
+        // Annotations are check as a group separately.
+        if (!ir::IsAnnotationInst(user->opcode())) {
+          switch (user->opcode()) {
+            case SpvOpAccessChain:
+            case SpvOpInBoundsAccessChain:
+              if (index == 2u) {
+                uint32_t id = user->GetSingleWordOperand(3u);
+                const ir::Instruction* opInst = get_def_use_mgr()->GetDef(id);
+                if (!ir::IsCompileTimeConstantInst(opInst->opcode())) {
+                  ok = false;
+                } else {
+                  if (!CheckUsesRelaxed(user)) ok = false;
+                }
+                stats->num_partial_accesses++;
+              } else {
+                ok = false;
+              }
+              break;
+            case SpvOpLoad:
+              if (!CheckLoad(user, index)) ok = false;
+              stats->num_full_accesses++;
+              break;
+            case SpvOpStore:
+              if (!CheckStore(user, index)) ok = false;
+              stats->num_full_accesses++;
+              break;
+            case SpvOpName:
+            case SpvOpMemberName:
+              break;
+            default:
+              ok = false;
+              break;
+          }
+        }
+      });
+
+  return ok;
+}
+
+bool ScalarReplacementPass::CheckUsesRelaxed(
+    const ir::Instruction* inst) const {
+  bool ok = true;
+  get_def_use_mgr()->ForEachUse(
+      inst, [this, &ok](const ir::Instruction* user, uint32_t index) {
+        switch (user->opcode()) {
+          case SpvOpAccessChain:
+          case SpvOpInBoundsAccessChain:
+            if (index != 2u) {
+              ok = false;
+            } else {
+              if (!CheckUsesRelaxed(user)) ok = false;
+            }
+            break;
+          case SpvOpLoad:
+            if (!CheckLoad(user, index)) ok = false;
+            break;
+          case SpvOpStore:
+            if (!CheckStore(user, index)) ok = false;
+            break;
+          default:
+            ok = false;
+            break;
+        }
+      });
+
+  return ok;
+}
+
+bool ScalarReplacementPass::CheckLoad(const ir::Instruction* inst,
+                                      uint32_t index) const {
+  if (index != 2u) return false;
+  if (inst->NumInOperands() >= 2 &&
+      inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
+    return false;
+  return true;
+}
+
+bool ScalarReplacementPass::CheckStore(const ir::Instruction* inst,
+                                       uint32_t index) const {
+  if (index != 0u) return false;
+  if (inst->NumInOperands() >= 3 &&
+      inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
+    return false;
+  return true;
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h
new file mode 100644 (file)
index 0000000..47542f9
--- /dev/null
@@ -0,0 +1,204 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_SCALAR_REPLACEMENT_PASS_H_
+#define LIBSPIRV_OPT_SCALAR_REPLACEMENT_PASS_H_
+
+#include "function.h"
+#include "pass.h"
+#include "type_manager.h"
+
+#include <queue>
+
+namespace spvtools {
+namespace opt {
+
+// Documented in optimizer.hpp
+class ScalarReplacementPass : public Pass {
+ public:
+  ScalarReplacementPass() = default;
+
+  const char* name() const override { return "scalar-replacement"; }
+
+  // Attempts to scalarize all appropriate function scope variables. Returns
+  // SuccessWithChange if any change is made.
+  Status Process(ir::IRContext* c) override;
+
+  ir::IRContext::Analysis GetPreservedAnalyses() override {
+    return ir::IRContext::kAnalysisDefUse |
+           ir::IRContext::kAnalysisInstrToBlockMapping |
+           ir::IRContext::kAnalysisDecorations |
+           ir::IRContext::kAnalysisCombinators | ir::IRContext::kAnalysisCFG;
+  }
+
+ private:
+  // Small container for tracking statistics about variables.
+  //
+  // TODO(alanbaker): Develop some useful heuristics to tune this pass.
+  struct VariableStats {
+    uint32_t num_partial_accesses;
+    uint32_t num_full_accesses;
+  };
+
+  // Attempts to scalarize all appropriate function scope variables in
+  // |function|. Returns SuccessWithChange if any changes are mode.
+  Status ProcessFunction(ir::Function* function);
+
+  // Returns true if |varInst| can be scalarized.
+  //
+  // Examines the use chain of |varInst| to verify all uses are valid for
+  // scalarization.
+  bool CanReplaceVariable(const ir::Instruction* varInst) const;
+
+  // Returns true if |typeInst| is an acceptable type to scalarize.
+  //
+  // Allows all aggregate types except runtime arrays. Additionally, checks the
+  // that the number of elements that would be scalarized is within bounds.
+  bool CheckType(const ir::Instruction* typeInst) const;
+
+  // Returns true if all the decorations for |varInst| are acceptable for
+  // scalarization.
+  bool CheckAnnotations(const ir::Instruction* varInst) const;
+
+  // Returns true if all the decorations for |typeInst| are acceptable for
+  // scalarization.
+  bool CheckTypeAnnotations(const ir::Instruction* typeInst) const;
+
+  // Returns true if the uses of |inst| are acceptable for scalarization.
+  //
+  // Recursively checks all the uses of |inst|. For |inst| specifically, only
+  // allows SpvOpAccessChain, SpvOpInBoundsAccessChain, SpvOpLoad and
+  // SpvOpStore. Access chains must have the first index be a compile-time
+  // constant. Subsequent uses of access chains (including other access chains)
+  // are checked in a more relaxed manner.
+  bool CheckUses(const ir::Instruction* inst) const;
+
+  // Helper function for the above |CheckUses|.
+  //
+  // This version tracks some stats about the current OpVariable. These stats
+  // are used to drive heuristics about when to scalarize.
+  bool CheckUses(const ir::Instruction* inst, VariableStats* stats) const;
+
+  // Relaxed helper function for |CheckUses|.
+  bool CheckUsesRelaxed(const ir::Instruction* inst) const;
+
+  // Transfers appropriate decorations from |source| to |replacements|.
+  void TransferAnnotations(const ir::Instruction* source,
+                           std::vector<ir::Instruction*>* replacements);
+
+  // Scalarizes |inst| and updates its uses.
+  //
+  // |inst| must be an OpVariable. It is replaced with an OpVariable for each
+  // for element of the composite type. Uses of |inst| are updated as
+  // appropriate. If the replacement variables are themselves scalarizable, they
+  // get added to |worklist| for further processing. If any replacement
+  // variable ends up with no uses it is erased. Returns false if any
+  // subsequent access chain is out of bounds.
+  bool ReplaceVariable(ir::Instruction* inst,
+                       std::queue<ir::Instruction*>* worklist);
+
+  // Returns the underlying storage type for |inst|.
+  //
+  // |inst| must be an OpVariable. Returns the type that is pointed to by
+  // |inst|.
+  ir::Instruction* GetStorageType(const ir::Instruction* inst) const;
+
+  // Returns true if the load can be scalarized.
+  //
+  // |inst| must be an OpLoad. Returns true if |index| is the pointer operand of
+  // |inst| and the load is not from volatile memory.
+  bool CheckLoad(const ir::Instruction* inst, uint32_t index) const;
+
+  // Returns true if the store can be scalarized.
+  //
+  // |inst| must be an OpStore. Returns true if |index| is the pointer operand
+  // of |inst| and the store is not to volatile memory.
+  bool CheckStore(const ir::Instruction* inst, uint32_t index) const;
+
+  // Creates a variable of type |typeId| from the |index|'th element of
+  // |varInst|. The new variable is added to |replacements|.
+  void CreateVariable(uint32_t typeId, ir::Instruction* varInst, uint32_t index,
+                      std::vector<ir::Instruction*>* replacements);
+
+  // Populates |replacements| with a new OpVariable for each element of |inst|.
+  //
+  // |inst| must be an OpVariable of a composite type. New variables are
+  // initialized the same as the corresponding index in |inst|. |replacements|
+  // will contain a variable for each element of the composite with matching
+  // indexes (i.e. the 0'th element of |inst| is the 0'th entry of
+  // |replacements|).
+  void CreateReplacementVariables(ir::Instruction* inst,
+                                  std::vector<ir::Instruction*>* replacements);
+
+  // Returns the value of an OpConstant of integer type.
+  //
+  // |constant| must use two or fewer words to generate the value.
+  size_t GetConstantInteger(const ir::Instruction* constant) const;
+
+  // Returns the integer literal for |op|.
+  size_t GetIntegerLiteral(const ir::Operand& op) const;
+
+  // Returns the array length for |arrayInst|.
+  size_t GetArrayLength(const ir::Instruction* arrayInst) const;
+
+  // Returns the number of elements in |type|.
+  //
+  // |type| must be a vector or matrix type.
+  size_t GetNumElements(const ir::Instruction* type) const;
+
+  // Returns an id for a pointer to |id|.
+  uint32_t GetOrCreatePointerType(uint32_t id);
+
+  // Creates the initial value for the |index| element of |source| in |newVar|.
+  //
+  // If there is an initial value for |source| for element |index|, it is
+  // appended as an operand on |newVar|. If the initial value is OpUndef, no
+  // initial value is added to |newVar|.
+  void GetOrCreateInitialValue(ir::Instruction* source, uint32_t index,
+                               ir::Instruction* newVar);
+
+  // Replaces the load to the entire composite.
+  //
+  // Generates a load for each replacement variable and then creates a new
+  // composite by combining all of the loads.
+  //
+  // |load| must be a load.
+  void ReplaceWholeLoad(ir::Instruction* load,
+                        const std::vector<ir::Instruction*>& replacements);
+
+  // Replaces the store to the entire composite.
+  //
+  // Generates a composite extract and store for each element in the scalarized
+  // variable from the original store data input.
+  void ReplaceWholeStore(ir::Instruction* store,
+                         const std::vector<ir::Instruction*>& replacements);
+
+  // Replaces an access chain to the composite variable with either a direct use
+  // of the appropriate replacement variable or another access chain with the
+  // replacement variable as the base and one fewer indexes. Returns false if
+  // the chain has an out of bounds access.
+  bool ReplaceAccessChain(ir::Instruction* chain,
+                          const std::vector<ir::Instruction*>& replacements);
+
+  // Maps storage type to a pointer type enclosing that type.
+  std::unordered_map<uint32_t, uint32_t> pointee_to_pointer_;
+
+  // Maps type id to OpConstantNull for that type.
+  std::unordered_map<uint32_t, uint32_t> type_to_null_;
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // LIBSPIRV_OPT_SCALAR_REPLACEMENT_PASS_H_
index 3273e85..3cb9ec0 100644 (file)
@@ -196,6 +196,11 @@ add_spvtools_unittest(TARGET pass_strength_reduction
   LIBS SPIRV-Tools-opt
 )
 
+add_spvtools_unittest(TARGET pass_scalar_replacement
+  SRCS scalar_replacement_test.cpp pass_utils.cpp
+  LIBS SPIRV-Tools-opt
+)
+
 add_spvtools_unittest(TARGET cfg_cleanup
   SRCS cfg_cleanup_test.cpp pass_utils.cpp
   LIBS SPIRV-Tools-opt
@@ -206,6 +211,11 @@ add_spvtools_unittest(TARGET ir_context
   LIBS SPIRV-Tools-opt
 )
 
+add_spvtools_unittest(TARGET feature_manager
+  SRCS feature_manager_test.cpp
+  LIBS SPIRV-Tools-opt
+)
+
 add_spvtools_unittest(TARGET pass_merge_return
   SRCS pass_merge_return_test.cpp pass_utils.cpp
   LIBS SPIRV-Tools-opt
index b2b9409..60d2643 100644 (file)
@@ -918,8 +918,9 @@ OpReturn
 OpFunctionEnd
 )";
 
-  opt::Pass::Status res = std::get<1>(
-      SinglePassRunAndDisassemble<opt::CommonUniformElimPass>(text, true));
+  opt::Pass::Status res =
+      std::get<1>(SinglePassRunAndDisassemble<opt::CommonUniformElimPass>(
+          text, true, false));
   EXPECT_EQ(res, opt::Pass::Status::SuccessWithoutChange);
 }
 
@@ -1035,8 +1036,9 @@ OpReturn
 OpFunctionEnd
 )";
 
-  opt::Pass::Status res = std::get<1>(
-      SinglePassRunAndDisassemble<opt::CommonUniformElimPass>(text, true));
+  opt::Pass::Status res =
+      std::get<1>(SinglePassRunAndDisassemble<opt::CommonUniformElimPass>(
+          text, true, false));
   EXPECT_EQ(res, opt::Pass::Status::SuccessWithoutChange);
 }
 // TODO(greg-lunarg): Add tests to verify handling of these cases:
index 6f40a91..c780717 100644 (file)
@@ -99,7 +99,7 @@ TEST_F(EliminateDeadFunctionsBasicTest, BasicKeepLiveFunction) {
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
   std::string assembly = JoinAllInsts(text);
   auto result = SinglePassRunAndDisassemble<opt::EliminateDeadFunctionsPass>(
-      assembly, /* skip_nop = */ true);
+      assembly, /* skip_nop = */ true, /* do_validation = */ false);
   EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
   EXPECT_EQ(assembly, std::get<0>(result));
 }
@@ -138,7 +138,7 @@ TEST_F(EliminateDeadFunctionsBasicTest, BasicKeepExportFunctions) {
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
   std::string assembly = JoinAllInsts(text);
   auto result = SinglePassRunAndDisassemble<opt::EliminateDeadFunctionsPass>(
-      assembly, /* skip_nop = */ true);
+      assembly, /* skip_nop = */ true, /* do_validation = */ false);
   EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
   EXPECT_EQ(assembly, std::get<0>(result));
 }
diff --git a/test/opt/feature_manager_test.cpp b/test/opt/feature_manager_test.cpp
new file mode 100644 (file)
index 0000000..264df1f
--- /dev/null
@@ -0,0 +1,86 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <algorithm>
+
+#include "opt/build_module.h"
+#include "opt/ir_context.h"
+
+using namespace spvtools;
+
+using FeatureManagerTest = ::testing::Test;
+
+TEST_F(FeatureManagerTest, MissingExtension) {
+  const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
+  ASSERT_NE(context, nullptr);
+
+  ASSERT_FALSE(context->get_feature_mgr()->HasExtension(
+      libspirv::Extension::kSPV_KHR_variable_pointers));
+}
+
+TEST_F(FeatureManagerTest, OneExtension) {
+  const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpExtension "SPV_KHR_variable_pointers"
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
+  ASSERT_NE(context, nullptr);
+
+  ASSERT_TRUE(context->get_feature_mgr()->HasExtension(
+      libspirv::Extension::kSPV_KHR_variable_pointers));
+}
+
+TEST_F(FeatureManagerTest, NotADifferentExtension) {
+  const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpExtension "SPV_KHR_variable_pointers"
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
+  ASSERT_NE(context, nullptr);
+
+  ASSERT_FALSE(context->get_feature_mgr()->HasExtension(
+      libspirv::Extension::kSPV_KHR_storage_buffer_storage_class));
+}
+
+TEST_F(FeatureManagerTest, TwoExtensions) {
+  const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpExtension "SPV_KHR_variable_pointers"
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+  )";
+
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
+  ASSERT_NE(context, nullptr);
+
+  ASSERT_TRUE(context->get_feature_mgr()->HasExtension(
+      libspirv::Extension::kSPV_KHR_variable_pointers));
+  ASSERT_TRUE(context->get_feature_mgr()->HasExtension(
+      libspirv::Extension::kSPV_KHR_storage_buffer_storage_class));
+}
index f5fa376..4e8064b 100644 (file)
@@ -209,7 +209,7 @@ TEST_P(FoldSpecConstantOpAndCompositePassTest, ParamTestCase) {
   auto status = opt::Pass::Status::SuccessWithoutChange;
   std::tie(optimized, status) =
       SinglePassRunAndDisassemble<opt::FoldSpecConstantOpAndCompositePass>(
-          original, /* skip_nop = */ true);
+          original, /* skip_nop = */ true, /* do_validation = */ false);
 
   // Check the optimized code, but ignore the OpName instructions.
   EXPECT_NE(opt::Pass::Status::Failure, status);
index 07a8c06..70ccf7b 100644 (file)
@@ -54,7 +54,7 @@ TEST_F(LocalRedundancyEliminationTest, RemoveRedundantAdd) {
                OpReturn
                OpFunctionEnd
   )";
-  SinglePassRunAndMatch<opt::LocalRedundancyEliminationPass>(text);
+  SinglePassRunAndMatch<opt::LocalRedundancyEliminationPass>(text, false);
 }
 
 // Make sure we keep instruction that are different, but look similar.
@@ -85,7 +85,7 @@ TEST_F(LocalRedundancyEliminationTest, KeepDifferentAdd) {
                OpFunctionEnd
   )";
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<opt::LocalRedundancyEliminationPass>(text);
+  SinglePassRunAndMatch<opt::LocalRedundancyEliminationPass>(text, false);
 }
 
 // This test is check that the values are being propagated properly, and that
@@ -123,7 +123,7 @@ TEST_F(LocalRedundancyEliminationTest, RemoveMultipleInstructions) {
                OpFunctionEnd
   )";
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
-  SinglePassRunAndMatch<opt::LocalRedundancyEliminationPass>(text);
+  SinglePassRunAndMatch<opt::LocalRedundancyEliminationPass>(text, false);
 }
 
 // Redundant instructions in different blocks should be kept.
@@ -152,7 +152,7 @@ TEST_F(LocalRedundancyEliminationTest, KeepInstructionsInDifferentBlocks) {
                OpReturn
                OpFunctionEnd
   )";
-  SinglePassRunAndMatch<opt::LocalRedundancyEliminationPass>(text);
+  SinglePassRunAndMatch<opt::LocalRedundancyEliminationPass>(text, false);
 }
 #endif
 }  // anonymous namespace
index 0e87c7b..33010ec 100644 (file)
@@ -89,11 +89,23 @@ class PassTest : public TestT {
   // disassembly string and the boolean value from the pass Process() function.
   template <typename PassT, typename... Args>
   std::tuple<std::string, opt::Pass::Status> SinglePassRunAndDisassemble(
-      const std::string& assembly, bool skip_nop, Args&&... args) {
+      const std::string& assembly, bool skip_nop, bool do_validation,
+      Args&&... args) {
     std::vector<uint32_t> optimized_bin;
     auto status = opt::Pass::Status::SuccessWithoutChange;
     std::tie(optimized_bin, status) = SinglePassRunToBinary<PassT>(
         assembly, skip_nop, std::forward<Args>(args)...);
+    if (do_validation) {
+      spv_target_env target_env = SPV_ENV_UNIVERSAL_1_1;
+      spv_context spvContext = spvContextCreate(target_env);
+      spv_diagnostic diagnostic = nullptr;
+      spv_const_binary_t binary = {optimized_bin.data(), optimized_bin.size()};
+      spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
+      EXPECT_EQ(error, 0);
+      if (error != 0) spvDiagnosticPrint(diagnostic);
+      spvDiagnosticDestroy(diagnostic);
+      spvContextDestroy(spvContext);
+    }
     std::string optimized_asm;
     EXPECT_TRUE(
         tools_.Disassemble(optimized_bin, &optimized_asm, disassemble_options_))
@@ -157,10 +169,11 @@ class PassTest : public TestT {
   // This does *not* involve pass manager.  Callers are suggested to use
   // SCOPED_TRACE() for better messages.
   template <typename PassT, typename... Args>
-  void SinglePassRunAndMatch(const std::string& original, Args&&... args) {
+  void SinglePassRunAndMatch(const std::string& original, bool do_validation,
+                             Args&&... args) {
     const bool skip_nop = true;
     auto pass_result = SinglePassRunAndDisassemble<PassT>(
-        original, skip_nop, std::forward<Args>(args)...);
+        original, skip_nop, do_validation, std::forward<Args>(args)...);
     auto disassembly = std::get<0>(pass_result);
     auto match_result = effcee::Match(disassembly, original);
     EXPECT_EQ(effcee::Result::Status::Ok, match_result.status())
index a257313..c5f3868 100644 (file)
@@ -55,7 +55,7 @@ TEST_F(RedundancyEliminationTest, RemoveRedundantLocalAdd) {
                OpReturn
                OpFunctionEnd
   )";
-  SinglePassRunAndMatch<opt::RedundancyEliminationPass>(text);
+  SinglePassRunAndMatch<opt::RedundancyEliminationPass>(text, false);
 }
 
 // Remove a redundant add across basic blocks.
@@ -84,7 +84,7 @@ TEST_F(RedundancyEliminationTest, RemoveRedundantAdd) {
                OpReturn
                OpFunctionEnd
   )";
-  SinglePassRunAndMatch<opt::RedundancyEliminationPass>(text);
+  SinglePassRunAndMatch<opt::RedundancyEliminationPass>(text, false);
 }
 
 // Remove a redundant add going through a multiple basic blocks.
@@ -120,7 +120,7 @@ TEST_F(RedundancyEliminationTest, RemoveRedundantAddDiamond) {
                OpFunctionEnd
 
   )";
-  SinglePassRunAndMatch<opt::RedundancyEliminationPass>(text);
+  SinglePassRunAndMatch<opt::RedundancyEliminationPass>(text, false);
 }
 
 // Remove a redundant add in a side node.
@@ -156,7 +156,7 @@ TEST_F(RedundancyEliminationTest, RemoveRedundantAddInSideNode) {
                OpFunctionEnd
 
   )";
-  SinglePassRunAndMatch<opt::RedundancyEliminationPass>(text);
+  SinglePassRunAndMatch<opt::RedundancyEliminationPass>(text, false);
 }
 
 // Remove a redundant add whose value is in the result of a phi node.
@@ -196,7 +196,7 @@ TEST_F(RedundancyEliminationTest, RemoveRedundantAddWithPhi) {
                OpFunctionEnd
 
   )";
-  SinglePassRunAndMatch<opt::RedundancyEliminationPass>(text);
+  SinglePassRunAndMatch<opt::RedundancyEliminationPass>(text, false);
 }
 
 // Keep the add because it is redundant on some paths, but not all paths.
@@ -231,7 +231,7 @@ TEST_F(RedundancyEliminationTest, KeepPartiallyRedundantAdd) {
 
   )";
   auto result = SinglePassRunAndDisassemble<opt::RedundancyEliminationPass>(
-      text, /* skip_nop = */ true);
+      text, /* skip_nop = */ true, /* do_validation = */ false);
   EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
 }
 
@@ -269,7 +269,7 @@ TEST_F(RedundancyEliminationTest, KeepRedundantAddWithoutPhi) {
 
   )";
   auto result = SinglePassRunAndDisassemble<opt::RedundancyEliminationPass>(
-      text, /* skip_nop = */ true);
+      text, /* skip_nop = */ true, /* do_validation = */ false);
   EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
 }
 #endif
diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp
new file mode 100644 (file)
index 0000000..32b72b0
--- /dev/null
@@ -0,0 +1,1141 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "assembly_builder.h"
+#include "gmock/gmock.h"
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+using ScalarReplacementTest = PassTest<::testing::Test>;
+
+// TODO(dneto): Add Effcee as required dependency, and make this unconditional.
+#ifdef SPIRV_EFFCEE
+TEST_F(ScalarReplacementTest, SimpleStruct) {
+  const std::string text = R"(
+;
+; CHECK: [[struct:%\w+]] = OpTypeStruct [[elem:%\w+]]
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: [[elem_ptr:%\w+]] = OpTypePointer Function [[elem]]
+; CHECK: OpConstantNull [[struct]]
+; CHECK: [[null:%\w+]] = OpConstantNull [[elem]]
+; CHECK-NOT: OpVariable [[struct_ptr]]
+; CHECK: [[one:%\w+]] = OpVariable [[elem_ptr]] Function [[null]]
+; CHECK-NEXT: [[two:%\w+]] = OpVariable [[elem_ptr]] Function [[null]]
+; CHECK-NOT: OpVariable [[elem_ptr]] Function [[null]]
+; CHECK-NOT: OpVariable [[struct_ptr]]
+; CHECK-NOT: OpInBoundsAccessChain
+; CHECK: [[l1:%\w+]] = OpLoad [[elem]] [[two]]
+; CHECK-NOT: OpAccessChain
+; CHECK: [[l2:%\w+]] = OpLoad [[elem]] [[one]]
+; CHECK: OpIAdd [[elem]] [[l1]] [[l2]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %6 "simple_struct"
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeStruct %2 %2 %2 %2
+%4 = OpTypePointer Function %3
+%5 = OpTypePointer Function %2
+%6 = OpTypeFunction %2
+%7 = OpConstantNull %3
+%8 = OpConstant %2 0
+%9 = OpConstant %2 1
+%10 = OpConstant %2 2
+%11 = OpConstant %2 3
+%12 = OpFunction %2 None %6
+%13 = OpLabel
+%14 = OpVariable %4 Function %7
+%15 = OpInBoundsAccessChain %5 %14 %8
+%16 = OpLoad %2 %15
+%17 = OpAccessChain %5 %14 %10
+%18 = OpLoad %2 %17
+%19 = OpIAdd %2 %16 %18
+OpReturnValue %19
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, StructInitialization) {
+  const std::string text = R"(
+;
+; CHECK: [[elem:%\w+]] = OpTypeInt 32 0
+; CHECK: [[struct:%\w+]] = OpTypeStruct [[elem]] [[elem]] [[elem]] [[elem]]
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: [[elem_ptr:%\w+]] = OpTypePointer Function [[elem]]
+; CHECK: [[zero:%\w+]] = OpConstant [[elem]] 0
+; CHECK: [[undef:%\w+]] = OpUndef [[elem]]
+; CHECK: [[two:%\w+]] = OpConstant [[elem]] 2
+; CHECK: [[null:%\w+]] = OpConstantNull [[elem]]
+; CHECK-NOT: OpVariable [[struct_ptr]]
+; CHECK: OpVariable [[elem_ptr]] Function [[null]]
+; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[two]]
+; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]]
+; CHECK-NEXT: OpVariable [[elem_ptr]] Function
+; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[zero]]
+; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %6 "struct_init"
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeStruct %2 %2 %2 %2
+%4 = OpTypePointer Function %3
+%20 = OpTypePointer Function %2
+%6 = OpTypeFunction %1
+%7 = OpConstant %2 0
+%8 = OpUndef %2
+%9 = OpConstant %2 2
+%30 = OpConstant %2 1
+%31 = OpConstant %2 3
+%10 = OpConstantNull %2
+%11 = OpConstantComposite %3 %7 %8 %9 %10
+%12 = OpFunction %1 None %6
+%13 = OpLabel
+%14 = OpVariable %4 Function %11
+%15 = OpAccessChain %20 %14 %7
+OpStore %15 %10
+%16 = OpAccessChain %20 %14 %9
+OpStore %16 %10
+%17 = OpAccessChain %20 %14 %30
+OpStore %17 %10
+%18 = OpAccessChain %20 %14 %31
+OpStore %18 %10
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, SpecConstantInitialization) {
+  const std::string text = R"(
+;
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[struct:%\w+]] = OpTypeStruct [[int]] [[int]]
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: [[int_ptr:%\w+]] = OpTypePointer Function [[int]]
+; CHECK: [[spec_comp:%\w+]] = OpSpecConstantComposite [[struct]]
+; CHECK: [[ex0:%\w+]] = OpSpecConstantOp [[int]] CompositeExtract [[spec_comp]] 0
+; CHECK: [[ex1:%\w+]] = OpSpecConstantOp [[int]] CompositeExtract [[spec_comp]] 1
+; CHECK-NOT: OpVariable [[struct]]
+; CHECK: OpVariable [[int_ptr]] Function [[ex1]]
+; CHECK-NEXT: OpVariable [[int_ptr]] Function [[ex0]]
+; CHECK-NOT: OpVariable [[struct]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %6 "spec_const"
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeStruct %2 %2
+%4 = OpTypePointer Function %3
+%20 = OpTypePointer Function %2
+%5 = OpTypeFunction %1
+%6 = OpConstant %2 0
+%30 = OpConstant %2 1
+%7 = OpSpecConstant %2 0
+%8 = OpSpecConstantOp %2 IAdd %7 %7
+%9 = OpSpecConstantComposite %3 %7 %8
+%10 = OpFunction %1 None %5
+%11 = OpLabel
+%12 = OpVariable %4 Function %9
+%13 = OpAccessChain %20 %12 %6
+%14 = OpLoad %2 %13
+%15 = OpAccessChain %20 %12 %30
+%16 = OpLoad %2 %15
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+// TODO(alanbaker): Re-enable when vector and matrix scalarization is supported.
+// TEST_F(ScalarReplacementTest, VectorInitialization) {
+//  const std::string text = R"(
+//;
+//; CHECK: [[elem:%\w+]] = OpTypeInt 32 0
+//; CHECK: [[vector:%\w+]] = OpTypeVector [[elem]] 4
+//; CHECK: [[vector_ptr:%\w+]] = OpTypePointer Function [[vector]]
+//; CHECK: [[elem_ptr:%\w+]] = OpTypePointer Function [[elem]]
+//; CHECK: [[zero:%\w+]] = OpConstant [[elem]] 0
+//; CHECK: [[undef:%\w+]] = OpUndef [[elem]]
+//; CHECK: [[two:%\w+]] = OpConstant [[elem]] 2
+//; CHECK: [[null:%\w+]] = OpConstantNull [[elem]]
+//; CHECK-NOT: OpVariable [[vector_ptr]]
+//; CHECK: OpVariable [[elem_ptr]] Function [[zero]]
+//; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]]
+//; CHECK-NEXT: OpVariable [[elem_ptr]] Function
+//; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[two]]
+//; CHECK-NEXT: OpVariable [[elem_ptr]] Function [[null]]
+//; CHECK-NOT: OpVariable [[elem_ptr]] Function [[undef]]
+//;
+// OpCapability Shader
+// OpCapability Linkage
+// OpMemoryModel Logical GLSL450
+// OpName %6 "vector_init"
+//%1 = OpTypeVoid
+//%2 = OpTypeInt 32 0
+//%3 = OpTypeVector %2 4
+//%4 = OpTypePointer Function %3
+//%20 = OpTypePointer Function %2
+//%6 = OpTypeFunction %1
+//%7 = OpConstant %2 0
+//%8 = OpUndef %2
+//%9 = OpConstant %2 2
+//%30 = OpConstant %2 1
+//%31 = OpConstant %2 3
+//%10 = OpConstantNull %2
+//%11 = OpConstantComposite %3 %10 %9 %8 %7
+//%12 = OpFunction %1 None %6
+//%13 = OpLabel
+//%14 = OpVariable %4 Function %11
+//%15 = OpAccessChain %20 %14 %7
+// OpStore %15 %10
+//%16 = OpAccessChain %20 %14 %9
+// OpStore %16 %10
+//%17 = OpAccessChain %20 %14 %30
+// OpStore %17 %10
+//%18 = OpAccessChain %20 %14 %31
+// OpStore %18 %10
+// OpReturn
+// OpFunctionEnd
+//  )";
+//
+//  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+//}
+//
+// TEST_F(ScalarReplacementTest, MatrixInitialization) {
+//  const std::string text = R"(
+//;
+//; CHECK: [[float:%\w+]] = OpTypeFloat 32
+//; CHECK: [[vector:%\w+]] = OpTypeVector [[float]] 2
+//; CHECK: [[matrix:%\w+]] = OpTypeMatrix [[vector]] 2
+//; CHECK: [[matrix_ptr:%\w+]] = OpTypePointer Function [[matrix]]
+//; CHECK: [[float_ptr:%\w+]] = OpTypePointer Function [[float]]
+//; CHECK: [[vec_ptr:%\w+]] = OpTypePointer Function [[vector]]
+//; CHECK: [[zerof:%\w+]] = OpConstant [[float]] 0
+//; CHECK: [[onef:%\w+]] = OpConstant [[float]] 1
+//; CHECK: [[one_zero:%\w+]] = OpConstantComposite [[vector]] [[onef]] [[zerof]]
+//; CHECK: [[zero_one:%\w+]] = OpConstantComposite [[vector]] [[zerof]] [[onef]]
+//; CHECK: [[const_mat:%\w+]] = OpConstantComposite [[matrix]] [[one_zero]]
+//[[zero_one]] ; CHECK-NOT: OpVariable [[matrix]] ; CHECK-NOT: OpVariable
+//[[vector]] Function [[one_zero]] ; CHECK: [[f1:%\w+]] = OpVariable
+//[[float_ptr]] Function [[zerof]] ; CHECK-NEXT: [[f2:%\w+]] = OpVariable
+//[[float_ptr]] Function [[onef]] ; CHECK-NEXT: [[vec_var:%\w+]] = OpVariable
+//[[vec_ptr]] Function [[zero_one]] ; CHECK-NOT: OpVariable [[matrix]] ;
+// CHECK-NOT: OpVariable [[vector]] Function [[one_zero]]
+//;
+// OpCapability Shader
+// OpCapability Linkage
+// OpMemoryModel Logical GLSL450
+// OpName %7 "matrix_init"
+//%1 = OpTypeVoid
+//%2 = OpTypeFloat 32
+//%3 = OpTypeVector %2 2
+//%4 = OpTypeMatrix %3 2
+//%5 = OpTypePointer Function %4
+//%6 = OpTypePointer Function %2
+//%30 = OpTypePointer Function %3
+//%10 = OpTypeInt 32 0
+//%7 = OpTypeFunction %1 %10
+//%8 = OpConstant %2 0.0
+//%9 = OpConstant %2 1.0
+//%11 = OpConstant %10 0
+//%12 = OpConstant %10 1
+//%13 = OpConstantComposite %3 %9 %8
+//%14 = OpConstantComposite %3 %8 %9
+//%15 = OpConstantComposite %4 %13 %14
+//%16 = OpFunction %1 None %7
+//%31 = OpFunctionParameter %10
+//%17 = OpLabel
+//%18 = OpVariable %5 Function %15
+//%19 = OpAccessChain %6 %18 %11 %12
+// OpStore %19 %8
+//%20 = OpAccessChain %6 %18 %11 %11
+// OpStore %20 %8
+//%21 = OpAccessChain %30 %18 %12
+// OpStore %21 %14
+// OpReturn
+// OpFunctionEnd
+//  )";
+//
+//  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+//}
+
+TEST_F(ScalarReplacementTest, ElideAccessChain) {
+  const std::string text = R"(
+;
+; CHECK: [[var:%\w+]] = OpVariable
+; CHECK-NOT: OpAccessChain
+; CHECK: OpStore [[var]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %6 "elide_access_chain"
+%1 = OpTypeVoid
+%2 = OpTypeInt 32 0
+%3 = OpTypeStruct %2 %2 %2 %2
+%4 = OpTypePointer Function %3
+%20 = OpTypePointer Function %2
+%6 = OpTypeFunction %1
+%7 = OpConstant %2 0
+%8 = OpUndef %2
+%9 = OpConstant %2 2
+%10 = OpConstantNull %2
+%11 = OpConstantComposite %3 %7 %8 %9 %10
+%12 = OpFunction %1 None %6
+%13 = OpLabel
+%14 = OpVariable %4 Function %11
+%15 = OpAccessChain %20 %14 %7
+OpStore %15 %10
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, ElideMultipleAccessChains) {
+  const std::string text = R"(
+;
+; CHECK: [[var:%\w+]] = OpVariable
+; CHECK-NOT: OpInBoundsAccessChain
+; CHECK OpStore [[var]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %6 "elide_two_access_chains"
+%1 = OpTypeVoid
+%2 = OpTypeFloat 32
+%3 = OpTypeStruct %2 %2
+%4 = OpTypeStruct %3 %3
+%5 = OpTypePointer Function %4
+%6 = OpTypePointer Function %2
+%7 = OpTypeFunction %1
+%8 = OpConstant %2 0.0
+%9 = OpConstant %2 1.0
+%10 = OpTypeInt 32 0
+%11 = OpConstant %10 0
+%12 = OpConstant %10 1
+%13 = OpConstantComposite %3 %9 %8
+%14 = OpConstantComposite %3 %8 %9
+%15 = OpConstantComposite %4 %13 %14
+%16 = OpFunction %1 None %7
+%17 = OpLabel
+%18 = OpVariable %5 Function %15
+%19 = OpInBoundsAccessChain %6 %18 %11 %12
+OpStore %19 %8
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, ReplaceAccessChain) {
+  const std::string text = R"(
+;
+; CHECK: [[param:%\w+]] = OpFunctionParameter
+; CHECK: [[var:%\w+]] = OpVariable
+; CHECK: [[access:%\w+]] = OpAccessChain {{%\w+}} [[var]] [[param]]
+; CHECK: OpStore [[access]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %7 "replace_access_chain"
+%1 = OpTypeVoid
+%2 = OpTypeFloat 32
+%10 = OpTypeInt 32 0
+%uint_2 = OpConstant %10 2
+%3 = OpTypeArray %2 %uint_2
+%4 = OpTypeStruct %3 %3
+%5 = OpTypePointer Function %4
+%20 = OpTypePointer Function %3
+%6 = OpTypePointer Function %2
+%7 = OpTypeFunction %1 %10
+%8 = OpConstant %2 0.0
+%9 = OpConstant %2 1.0
+%11 = OpConstant %10 0
+%12 = OpConstant %10 1
+%13 = OpConstantComposite %3 %9 %8
+%14 = OpConstantComposite %3 %8 %9
+%15 = OpConstantComposite %4 %13 %14
+%16 = OpFunction %1 None %7
+%32 = OpFunctionParameter %10
+%17 = OpLabel
+%18 = OpVariable %5 Function %15
+%19 = OpAccessChain %6 %18 %11 %32
+OpStore %19 %8
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, ArrayInitialization) {
+  const std::string text = R"(
+;
+; CHECK: [[float:%\w+]] = OpTypeFloat 32
+; CHECK: [[array:%\w+]] = OpTypeArray
+; CHECK: [[array_ptr:%\w+]] = OpTypePointer Function [[array]]
+; CHECK: [[float_ptr:%\w+]] = OpTypePointer Function [[float]]
+; CHECK: [[float0:%\w+]] = OpConstant [[float]] 0
+; CHECK: [[float1:%\w+]] = OpConstant [[float]] 1
+; CHECK: [[float2:%\w+]] = OpConstant [[float]] 2
+; CHECK-NOT: OpVariable [[array_ptr]]
+; CHECK: [[var0:%\w+]] = OpVariable [[float_ptr]] Function [[float0]]
+; CHECK-NEXT: [[var1:%\w+]] = OpVariable [[float_ptr]] Function [[float1]]
+; CHECK-NEXT: [[var2:%\w+]] = OpVariable [[float_ptr]] Function [[float2]]
+; CHECK-NOT: OpVariable [[array_ptr]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "array_init"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%float = OpTypeFloat 32
+%uint_0 = OpConstant %uint 0
+%uint_1 = OpConstant %uint 1
+%uint_2 = OpConstant %uint 2
+%uint_3 = OpConstant %uint 3
+%float_array = OpTypeArray %float %uint_3
+%array_ptr = OpTypePointer Function %float_array
+%float_ptr = OpTypePointer Function %float
+%float_0 = OpConstant %float 0
+%float_1 = OpConstant %float 1
+%float_2 = OpConstant %float 2
+%const_array = OpConstantComposite %float_array %float_2 %float_1 %float_0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%3 = OpVariable %array_ptr Function %const_array
+%4 = OpInBoundsAccessChain %float_ptr %3 %uint_0
+OpStore %4 %float_0
+%5 = OpInBoundsAccessChain %float_ptr %3 %uint_1
+OpStore %5 %float_0
+%6 = OpInBoundsAccessChain %float_ptr %3 %uint_2
+OpStore %6 %float_0
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+  ;
+}
+
+TEST_F(ScalarReplacementTest, NonUniformCompositeInitialization) {
+  const std::string text = R"(
+;
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[array:%\w+]] = OpTypeArray
+; CHECK: [[matrix:%\w+]] = OpTypeMatrix
+; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]]
+; CHECK: [[struct2:%\w+]] = OpTypeStruct [[struct1]] [[matrix]] [[array]] [[uint]]
+; CHECK: [[struct1_ptr:%\w+]] = OpTypePointer Function [[struct1]]
+; CHECK: [[matrix_ptr:%\w+]] = OpTypePointer Function [[matrix]]
+; CHECK: [[array_ptr:%\w+]] = OpTypePointer Function [[array]]
+; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
+; CHECK: [[struct2_ptr:%\w+]] = OpTypePointer Function [[struct2]]
+; CHECK: [[const_uint:%\w+]] = OpConstant [[uint]]
+; CHECK: [[const_array:%\w+]] = OpConstantComposite [[array]]
+; CHECK: [[const_matrix:%\w+]] = OpConstantNull [[matrix]]
+; CHECK: [[const_struct1:%\w+]] = OpConstantComposite [[struct1]]
+; CHECK-NOT: OpVariable [[struct2_ptr]] Function
+; CHECK: OpVariable [[uint_ptr]] Function [[const_uint]]
+; CHECK-NEXT: OpVariable [[array_ptr]] Function [[const_array]]
+; CHECK-NEXT: OpVariable [[matrix_ptr]] Function [[const_matrix]]
+; CHECK-NEXT: OpVariable [[struct1_ptr]] Function [[const_struct1]]
+; CHECK-NOT: OpVariable [[struct2_ptr]] Function
+;
+OpCapability Shader
+OpCapability Linkage
+OpCapability Int64
+OpCapability Float64
+OpMemoryModel Logical GLSL450
+OpName %func "non_uniform_composite_init"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%int64 = OpTypeInt 64 1
+%float = OpTypeFloat 32
+%double = OpTypeFloat 64
+%double2 = OpTypeVector %double 2
+%float4 = OpTypeVector %float 4
+%int64_0 = OpConstant %int64 0
+%int64_1 = OpConstant %int64 1
+%int64_2 = OpConstant %int64 2
+%int64_3 = OpConstant %int64 3
+%int64_array3 = OpTypeArray %int64 %int64_3
+%matrix_double2 = OpTypeMatrix %double2 2
+%struct1 = OpTypeStruct %uint %float4
+%struct2 = OpTypeStruct %struct1 %matrix_double2 %int64_array3 %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%matrix_double2_ptr = OpTypePointer Function %matrix_double2
+%int64_array_ptr = OpTypePointer Function %int64_array3
+%uint_ptr = OpTypePointer Function %uint
+%struct2_ptr = OpTypePointer Function %struct2
+%const_uint = OpConstant %uint 0
+%const_int64_array = OpConstantComposite %int64_array3 %int64_0 %int64_1 %int64_2
+%const_double2 = OpConstantNull %double2
+%const_matrix_double2 = OpConstantNull %matrix_double2
+%undef_float4 = OpUndef %float4
+%const_struct1 = OpConstantComposite %struct1 %const_uint %undef_float4
+%const_struct2 = OpConstantComposite %struct2 %const_struct1 %const_matrix_double2 %const_int64_array %const_uint
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct2_ptr Function %const_struct2
+%3 = OpAccessChain %struct1_ptr %var %int64_0
+OpStore %3 %const_struct1
+%4 = OpAccessChain %matrix_double2_ptr %var %int64_1
+OpStore %4 %const_matrix_double2
+%5 = OpAccessChain %int64_array_ptr %var %int64_2
+OpStore %5 %const_int64_array
+%6 = OpAccessChain %uint_ptr %var %int64_3
+OpStore %6 %const_uint
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+  ;
+}
+
+TEST_F(ScalarReplacementTest, ElideUncombinedAccessChains) {
+  const std::string text = R"(
+;
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
+; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
+; CHECK: [[var:%\w+]] = OpVariable [[uint_ptr]] Function
+; CHECK-NOT: OpAccessChain
+; CHECK: OpStore [[var]] [[const]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "elide_uncombined_access_chains"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint
+%struct2 = OpTypeStruct %struct1
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%struct2_ptr = OpTypePointer Function %struct2
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct2_ptr Function
+%3 = OpAccessChain %struct1_ptr %var %uint_0
+%4 = OpAccessChain %uint_ptr %3 %uint_0
+OpStore %4 %uint_0
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, ElideSingleUncombinedAccessChains) {
+  const std::string text = R"(
+;
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[array:%\w+]] = OpTypeArray [[uint]]
+; CHECK: [[array_ptr:%\w+]] = OpTypePointer Function [[array]]
+; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
+; CHECK: [[param:%\w+]] = OpFunctionParameter [[uint]]
+; CHECK: [[var:%\w+]] = OpVariable [[array_ptr]] Function
+; CHECK: [[access:%\w+]] = OpAccessChain {{.*}} [[var]] [[param]]
+; CHECK: OpStore [[access]] [[const]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "elide_single_uncombined_access_chains"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%array = OpTypeArray %uint %uint_1
+%struct2 = OpTypeStruct %array
+%uint_ptr = OpTypePointer Function %uint
+%array_ptr = OpTypePointer Function %array
+%struct2_ptr = OpTypePointer Function %struct2
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void %uint
+%1 = OpFunction %void None %func
+%param = OpFunctionParameter %uint
+%2 = OpLabel
+%var = OpVariable %struct2_ptr Function
+%3 = OpAccessChain %array_ptr %var %uint_0
+%4 = OpAccessChain %uint_ptr %3 %param
+OpStore %4 %uint_0
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, ReplaceWholeLoad) {
+  const std::string text = R"(
+;
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]]
+; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
+; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
+; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function
+; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function
+; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]]
+; CHECK: [[l0:%\w+]] = OpLoad [[uint]] [[var0]]
+; CHECK: OpCompositeConstruct [[struct1]] [[l0]] [[l1]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "replace_whole_load"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+%load = OpLoad %struct1 %var
+%3 = OpAccessChain %uint_ptr %var %uint_0
+OpStore %3 %uint_0
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, ReplaceWholeLoadCopyMemoryAccess) {
+  const std::string text = R"(
+;
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]]
+; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
+; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
+; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function
+; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function
+; CHECK: [[l1:%\w+]] = OpLoad [[uint]] [[var1]] Nontemporal
+; CHECK: [[l0:%\w+]] = OpLoad [[uint]] [[var0]] Nontemporal
+; CHECK: OpCompositeConstruct [[struct1]] [[l0]] [[l1]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "replace_whole_load_copy_memory_access"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+%load = OpLoad %struct1 %var Nontemporal
+%3 = OpAccessChain %uint_ptr %var %uint_0
+OpStore %3 %uint_0
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, ReplaceWholeStore) {
+  const std::string text = R"(
+;
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]]
+; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
+; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
+; CHECK: [[const_struct:%\w+]] = OpConstantComposite [[struct1]] [[const]] [[const]]
+; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function
+; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function
+; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[uint]] [[const_struct]] 0
+; CHECK: OpStore [[var0]] [[ex0]]
+; CHECK: [[ex1:%\w+]] = OpCompositeExtract [[uint]] [[const_struct]] 1
+; CHECK: OpStore [[var1]] [[ex1]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "replace_whole_store"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%const_struct = OpConstantComposite %struct1 %uint_0 %uint_0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+OpStore %var %const_struct
+%3 = OpAccessChain %uint_ptr %var %uint_0
+%4 = OpLoad %uint %3
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, ReplaceWholeStoreCopyMemoryAccess) {
+  const std::string text = R"(
+;
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[struct1:%\w+]] = OpTypeStruct [[uint]] [[uint]]
+; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
+; CHECK: [[const:%\w+]] = OpConstant [[uint]] 0
+; CHECK: [[const_struct:%\w+]] = OpConstantComposite [[struct1]] [[const]] [[const]]
+; CHECK: [[var1:%\w+]] = OpVariable [[uint_ptr]] Function
+; CHECK: [[var0:%\w+]] = OpVariable [[uint_ptr]] Function
+; CHECK: [[ex0:%\w+]] = OpCompositeExtract [[uint]] [[const_struct]] 0
+; CHECK: OpStore [[var0]] [[ex0]] Aligned 4
+; CHECK: [[ex1:%\w+]] = OpCompositeExtract [[uint]] [[const_struct]] 1
+; CHECK: OpStore [[var1]] [[ex1]] Aligned 4
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "replace_whole_store_copy_memory_access"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%const_struct = OpConstantComposite %struct1 %uint_0 %uint_0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+OpStore %var %const_struct Aligned 4
+%3 = OpAccessChain %uint_ptr %var %uint_0
+%4 = OpLoad %uint %3
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, DontTouchVolatileLoad) {
+  const std::string text = R"(
+;
+; CHECK: [[struct:%\w+]] = OpTypeStruct
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: OpLabel
+; CHECK-NEXT: OpVariable [[struct_ptr]]
+; CHECK-NOT: OpVariable
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "dont_touch_volatile_load"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+%3 = OpAccessChain %uint_ptr %var %uint_0
+%4 = OpLoad %uint %3 Volatile
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, DontTouchVolatileStore) {
+  const std::string text = R"(
+;
+; CHECK: [[struct:%\w+]] = OpTypeStruct
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: OpLabel
+; CHECK-NEXT: OpVariable [[struct_ptr]]
+; CHECK-NOT: OpVariable
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "dont_touch_volatile_store"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+%3 = OpAccessChain %uint_ptr %var %uint_0
+OpStore %3 %uint_0 Volatile
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, DontTouchSpecNonFunctionVariable) {
+  const std::string text = R"(
+;
+; CHECK: [[struct:%\w+]] = OpTypeStruct
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Uniform [[struct]]
+; CHECK: OpConstant
+; CHECK-NEXT: OpVariable [[struct_ptr]]
+; CHECK-NOT: OpVariable
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "dont_touch_spec_constant_access_chain"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint
+%uint_ptr = OpTypePointer Uniform %uint
+%struct1_ptr = OpTypePointer Uniform %struct1
+%uint_0 = OpConstant %uint 0
+%var = OpVariable %struct1_ptr Uniform
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%3 = OpAccessChain %uint_ptr %var %uint_0
+OpStore %3 %uint_0 Volatile
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, DontTouchSpecConstantAccessChain) {
+  const std::string text = R"(
+;
+; CHECK: [[array:%\w+]] = OpTypeArray
+; CHECK: [[array_ptr:%\w+]] = OpTypePointer Function [[array]]
+; CHECK: OpLabel
+; CHECK-NEXT: OpVariable [[array_ptr]]
+; CHECK-NOT: OpVariable
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "dont_touch_spec_constant_access_chain"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%array = OpTypeArray %uint %uint_1
+%uint_ptr = OpTypePointer Function %uint
+%array_ptr = OpTypePointer Function %array
+%uint_0 = OpConstant %uint 0
+%spec_const = OpSpecConstant %uint 0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %array_ptr Function
+%3 = OpAccessChain %uint_ptr %var %spec_const
+OpStore %3 %uint_0 Volatile
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, NoPartialAccesses) {
+  const std::string text = R"(
+;
+; CHECK: [[struct:%\w+]] = OpTypeStruct
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: OpLabel
+; CHECK-NEXT: OpVariable [[struct_ptr]]
+; CHECK-NOT: OpVariable
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "no_partial_accesses"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%const = OpConstantNull %struct1
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+OpStore %var %const
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, DontTouchPtrAccessChain) {
+  const std::string text = R"(
+;
+; CHECK: [[struct:%\w+]] = OpTypeStruct
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: OpLabel
+; CHECK-NEXT: OpVariable [[struct_ptr]]
+; CHECK-NOT: OpVariable
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "dont_touch_ptr_access_chain"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+%3 = OpPtrAccessChain %uint_ptr %var %uint_0 %uint_0
+OpStore %3 %uint_0
+%4 = OpAccessChain %uint_ptr %var %uint_0
+OpStore %4 %uint_0
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, false);
+}
+
+TEST_F(ScalarReplacementTest, DontTouchInBoundsPtrAccessChain) {
+  const std::string text = R"(
+;
+; CHECK: [[struct:%\w+]] = OpTypeStruct
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: OpLabel
+; CHECK-NEXT: OpVariable [[struct_ptr]]
+; CHECK-NOT: OpVariable
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "dont_touch_in_bounds_ptr_access_chain"
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+%3 = OpInBoundsPtrAccessChain %uint_ptr %var %uint_0 %uint_0
+OpStore %3 %uint_0
+%4 = OpInBoundsAccessChain %uint_ptr %var %uint_0
+OpStore %4 %uint_0
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, false);
+}
+
+TEST_F(ScalarReplacementTest, DonTouchAliasedDecoration) {
+  const std::string text = R"(
+;
+; CHECK: [[struct:%\w+]] = OpTypeStruct
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: OpLabel
+; CHECK-NEXT: OpVariable [[struct_ptr]]
+; CHECK-NOT: OpVariable
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "aliased"
+OpDecorate %var Aliased
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+%3 = OpAccessChain %uint_ptr %var %uint_0
+%4 = OpLoad %uint %3
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, CopyRestrictDecoration) {
+  const std::string text = R"(
+;
+; CHECK: OpName
+; CHECK-NEXT: OpDecorate [[var0:%\w+]] Restrict
+; CHECK-NEXT: OpDecorate [[var1:%\w+]] Restrict
+; CHECK: [[int:%\w+]] = OpTypeInt
+; CHECK: [[struct:%\w+]] = OpTypeStruct
+; CHECK: [[int_ptr:%\w+]] = OpTypePointer Function [[int]]
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: OpLabel
+; CHECK-NEXT: [[var1]] = OpVariable [[int_ptr]]
+; CHECK-NEXT: [[var0]] = OpVariable [[int_ptr]]
+; CHECK-NOT: OpVariable [[struct_ptr]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "restrict"
+OpDecorate %var Restrict
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%struct1 = OpTypeStruct %uint %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%uint_1 = OpConstant %uint 1
+%func = OpTypeFunction %void
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+%3 = OpAccessChain %uint_ptr %var %uint_0
+%4 = OpLoad %uint %3
+%5 = OpAccessChain %uint_ptr %var %uint_1
+%6 = OpLoad %uint %5
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, DontClobberDecoratesOnSubtypes) {
+  const std::string text = R"(
+;
+; CHECK: OpDecorate [[array:%\w+]] ArrayStride 1
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[array]] = OpTypeArray [[uint]]
+; CHECK: [[array_ptr:%\w+]] = OpTypePointer Function [[array]]
+; CHECK: OpLabel
+; CHECK-NEXT: OpVariable [[array_ptr]] Function
+; CHECK-NOT: OpVariable
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "array_stride"
+OpDecorate %array ArrayStride 1
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%array = OpTypeArray %uint %uint_1
+%struct1 = OpTypeStruct %array
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void %uint
+%1 = OpFunction %void None %func
+%param = OpFunctionParameter %uint
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+%3 = OpAccessChain %uint_ptr %var %uint_0 %param
+%4 = OpLoad %uint %3
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+
+TEST_F(ScalarReplacementTest, DontCopyMemberDecorate) {
+  const std::string text = R"(
+;
+; CHECK-NOT: OpDecorate
+; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
+; CHECK: [[struct:%\w+]] = OpTypeStruct [[uint]]
+; CHECK: [[uint_ptr:%\w+]] = OpTypePointer Function [[uint]]
+; CHECK: [[struct_ptr:%\w+]] = OpTypePointer Function [[struct]]
+; CHECK: OpLabel
+; CHECK-NEXT: OpVariable [[uint_ptr]] Function
+; CHECK-NOT: OpVariable
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpName %func "member_decorate"
+OpMemberDecorate %struct1 0 Offset 1
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%struct1 = OpTypeStruct %uint
+%uint_ptr = OpTypePointer Function %uint
+%struct1_ptr = OpTypePointer Function %struct1
+%uint_0 = OpConstant %uint 0
+%func = OpTypeFunction %void %uint
+%1 = OpFunction %void None %func
+%2 = OpLabel
+%var = OpVariable %struct1_ptr Function
+%3 = OpAccessChain %uint_ptr %var %uint_0
+%4 = OpLoad %uint %3
+OpReturn
+OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<opt::ScalarReplacementPass>(text, true);
+}
+#endif  // SPIRV_EFFCEE
+
+}  // namespace
index 32eb4c3..4439aee 100644 (file)
@@ -55,7 +55,7 @@ TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy8) {
   };
 
   auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
-      JoinAllInsts(text), /* skip_nop = */ true);
+      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
 
   EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
   const std::string& output = std::get<0>(result);
@@ -99,7 +99,7 @@ TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy16) {
 ; CHECK: OpFunctionEnd
                OpFunctionEnd)";
 
-  SinglePassRunAndMatch<opt::StrengthReductionPass>(text);
+  SinglePassRunAndMatch<opt::StrengthReductionPass>(text, false);
 }
 #endif
 
@@ -127,7 +127,7 @@ TEST_F(StrengthReductionBasicTest, BasicTwoPowersOf2) {
 )";
   // clang-format on
   auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
-      text, /* skip_nop = */ true);
+      text, /* skip_nop = */ true, /* do_validation = */ false);
 
   EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
   const std::string& output = std::get<0>(result);
@@ -158,7 +158,7 @@ TEST_F(StrengthReductionBasicTest, BasicDontReplace0) {
   };
 
   auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
-      JoinAllInsts(text), /* skip_nop = */ true);
+      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
 
   EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
 }
@@ -187,7 +187,7 @@ TEST_F(StrengthReductionBasicTest, BasicNoChange) {
   };
 
   auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
-      JoinAllInsts(text), /* skip_nop = */ true);
+      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
 
   EXPECT_EQ(opt::Pass::Status::SuccessWithoutChange, std::get<1>(result));
 }
@@ -215,7 +215,7 @@ TEST_F(StrengthReductionBasicTest, NoDuplicateConstantsAndTypes) {
   };
 
   auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
-      JoinAllInsts(text), /* skip_nop = */ true);
+      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
 
   EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
   const std::string& output = std::get<0>(result);
@@ -249,7 +249,7 @@ TEST_F(StrengthReductionBasicTest, BasicCreateOneConst) {
   };
 
   auto result = SinglePassRunAndDisassemble<opt::StrengthReductionPass>(
-      JoinAllInsts(text), /* skip_nop = */ true);
+      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
 
   EXPECT_EQ(opt::Pass::Status::SuccessWithChange, std::get<1>(result));
   const std::string& output = std::get<0>(result);
index 8fc262c..5c29fa7 100644 (file)
@@ -118,7 +118,7 @@ class UnifyConstantTest : public PassTest<T> {
     std::tie(optimized_before_strip, status) =
         this->template SinglePassRunAndDisassemble<opt::UnifyConstantPass>(
             test_builder.GetCode(),
-            /* skip_nop = */ true);
+            /* skip_nop = */ true, /* do_validation = */ false);
     std::string optimized_without_opnames;
     std::unordered_set<std::string> optimized_opnames;
     std::tie(optimized_without_opnames, optimized_opnames) =
index dc912c3..c0cb48a 100644 (file)
@@ -152,6 +152,9 @@ Options (in lexicographical order):
                a new basic block containing an unified return.
                This pass does not currently support structured control flow. It
                makes no changes if the shader capability is detected.
+  --local-redundancy-elimination
+               Looks for instructions in the same basic block that compute the
+               same value, and deletes the redundant ones.
   -O
                Optimize for performance. Apply a sequence of transformations
                in an attempt to improve the performance of the generated
@@ -198,10 +201,14 @@ Options (in lexicographical order):
   --redundancy-elimination
                Looks for instructions in the same function that compute the
                same value, and deletes the redundant ones.
-  --relax-store-struct
+  --relax-struct-store
                Allow store from one struct type to a different type with
                compatible layout and members. This option is forwarded to the
                validator.
+  --scalar-replacement
+               Replace aggregate function scope variables that are only accessed
+               via their elements with new function variables representing each
+               element.
   --set-spec-const-default-value "<spec id>:<default value> ..."
                Set the default values of the specialization constants with
                <spec id>:<default value> pairs specified in a double-quoted
@@ -385,6 +392,8 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer,
         optimizer->RegisterPass(CreateDeadVariableEliminationPass());
       } else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) {
         optimizer->RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
+      } else if (0 == strcmp(cur_arg, "--scalar-replacement")) {
+        optimizer->RegisterPass(CreateScalarReplacementPass());
       } else if (0 == strcmp(cur_arg, "--strength-reduction")) {
         optimizer->RegisterPass(CreateStrengthReductionPass());
       } else if (0 == strcmp(cur_arg, "--unify-const")) {
@@ -399,7 +408,7 @@ OptStatus ParseFlags(int argc, const char** argv, Optimizer* optimizer,
         optimizer->RegisterPass(CreateLocalRedundancyEliminationPass());
       } else if (0 == strcmp(cur_arg, "--redundancy-elimination")) {
         optimizer->RegisterPass(CreateRedundancyEliminationPass());
-      } else if (0 == strcmp(cur_arg, "--relax-store-struct")) {
+      } else if (0 == strcmp(cur_arg, "--relax-struct-store")) {
         options->relax_struct_store = true;
       } else if (0 == strcmp(cur_arg, "--skip-validation")) {
         *skip_validator = true;