}
// Move constants past Enter.
- if (IsEnter(*node) && node->input_size() > 0) {
+ // TODO(rmlarsen): Reenable when we fix the root cause of b/76008022
+ if (opt_level_ == RewriterConfig::AGGRESSIVE && IsEnter(*node) &&
+ node->input_size() > 0) {
const string& node_name = node->name();
const NodeDef* input = node_map_->GetNode(node->input(0));
if (input != nullptr && IsReallyConstant(*input) &&
NodeDef* new_node = optimized_graph->add_node();
*new_node = *input;
new_node->set_name(OptimizedNodeName(*input, "_enter"));
+ new_node->set_device(node->device());
new_node->clear_input();
new_node->add_input(AsControlDependency(node_name));
node_map_->AddNode(new_node->name(), new_node);
static string AddControlDependency(const string& input_name, GraphDef* graph,
NodeMap* node_map);
- ConstantFolding(DeviceBase* cpu_device);
+ explicit ConstantFolding(DeviceBase* cpu_device);
ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device);
~ConstantFolding() override {}
item.fetch.push_back("id2");
item.fetch.push_back("id3");
- ConstantFolding optimizer(nullptr /* cpu_device */);
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);