#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
// TODO(ezhulenev): remove this method from ArithmeticOptimizer once all
// optimizations have been migrated to stages.
- void AddFrameControlDeps(const NodeDef* old_node,
- const std::vector<NodeDef*>& new_nodes,
- const string& source_for_ctrl_dep,
- const std::vector<NodeDef*>& sinks_for_control_dep) {
- const auto frame_it = ctx_.frame_map->find(old_node);
- if (frame_it != ctx_.frame_map->end()) {
- for (auto node : new_nodes) {
- ctx_.frame_map->emplace(node, frame_it->second);
- }
- if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
- const string ctrl_dep = ConstantFolding::AddControlDependency(
- source_for_ctrl_dep, ctx_.optimized_graph, ctx_.node_map);
- for (auto node : sinks_for_control_dep) {
- MaybeAddControlInput(ctrl_dep, node, ctx_.optimized_graph,
- ctx_.node_map);
+ void ForwardControlDependencies(
+ NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
+ for (const auto& src : src_nodes) {
+ for (int i = src->input_size() - 1; i >= 0; --i) {
+ if (IsControlInput(src->input(i))) {
+ *target_node->add_input() = src->input(i);
+ ctx_.node_map->AddOutput(NodeName(src->input(i)),
+ target_node->name());
+ } else {
+ break;
}
}
}
CHECK(IsSupported(node));
std::set<string> common_factors;
- TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors));
+ std::vector<string> ctrl_deps;
+ TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors, &ctrl_deps));
if (common_factors.size() == 1) {
const string& common_factor = *common_factors.begin();
new_add_node->set_input(i, unique_factors[i]);
}
- // Add frame dependencies that the original node might have had.
- AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
- {new_add_node});
+ // Attach the collected control dependencies to the new inner Add node.
+ for (const string& ctrl_dep : ctrl_deps) {
+ *new_add_node->add_input() = ctrl_dep;
+ ctx_.node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
+ }
// optimize new inner aggregation node
AddToOptimizationQueue(new_add_node);
}
// Determine the set of common factors if the input nodes are all Mul nodes.
- Status GetCommonFactors(const NodeDef* node,
- std::set<string>* common_factors) const {
+ Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
+ std::vector<string>* ctrl_deps) const {
CHECK(common_factors->empty());
for (int i = 0; i < node->input_size(); ++i) {
if (i > 0 && common_factors->empty()) break;
- if (IsControlInput(node->input(i))) break;
-
+ if (IsControlInput(node->input(i))) {
+ ctrl_deps->push_back(node->input(i));
+ continue;
+ }
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
std::inserter(intersection, intersection.begin()));
std::swap(*common_factors, intersection);
}
+ for (int i = 2; i < input->input_size(); ++i) {
+ ctrl_deps->push_back(input->input(i));
+ }
}
return Status::OK();
}
}
}
-void ArithmeticOptimizer::AddFrameControlDeps(
- const NodeDef* old_node, const std::vector<NodeDef*>& new_nodes,
- const string& source_for_ctrl_dep,
- const std::vector<NodeDef*>& sinks_for_control_dep) {
- const auto frame_it = frame_map_.find(old_node);
- if (frame_it != frame_map_.end()) {
- for (auto node : new_nodes) {
- frame_map_.emplace(node, frame_it->second);
- }
- if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
- const string ctrl_dep = ConstantFolding::AddControlDependency(
- source_for_ctrl_dep, optimized_graph_, node_map_.get());
- for (auto node : sinks_for_control_dep) {
- MaybeAddControlInput(ctrl_dep, node, optimized_graph_, node_map_.get());
+void ArithmeticOptimizer::ForwardControlDependencies(
+ NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
+ for (const auto& src : src_nodes) {
+ for (int i = src->input_size() - 1; i >= 0; --i) {
+ if (IsControlInput(src->input(i))) {
+ *target_node->add_input() = src->input(i);
+ node_map_->AddOutput(NodeName(src->input(i)), target_node->name());
+ } else {
+ break;
}
}
}
node_map_->AddOutput(new_transpose->name(), new_cast->name());
nodes_to_simplify->PushBack(new_transpose);
- // Add frame dependencies that the original node might have had.
- AddFrameControlDeps(node, {new_transpose, new_cast},
- new_transpose->input(0), {new_transpose});
-
+ ForwardControlDependencies(new_transpose, {cast, node});
return new_cast->name();
}
}
node_map_->AddOutput(weights->name(), scaled_weights->name());
scaled_weights->add_input(mul->input(1));
node_map_->AddOutput(scale->name(), scaled_weights->name());
- AddFrameControlDeps(node, {scaled_weights}, "", {});
+ ForwardControlDependencies(scaled_weights, {source});
// Update `conv`'s weights to `scaled_weights`.
conv->set_input(1, scaled_weights->name());
}
if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) {
- // Discard aggregate nodes with a single input.
+ // Discard aggregate nodes with a single input and no control dependencies.
if (node->input_size() == 1) {
return node->input(0);
}
return "";
}
new_const_node->set_device(node->device());
+ MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
+ optimized_graph_, node_map_.get());
nodes_to_simplify->PushBack(new_const_node);
// 2. Replace the aggregate node with Mul(Const(N), x).
new_mul_node->add_input(node->input(0));
node_map_->AddOutput(node->input(0), new_mul_node->name());
- CopyControlInputs(*node, new_mul_node, optimized_graph_, node_map_.get());
- AddFrameControlDeps(node, {new_const_node, new_mul_node}, node->input(0),
- {new_const_node});
+ ForwardControlDependencies(new_mul_node, {node});
return new_mul_node->name();
}
}
FlipBooleanAttr(attr_a, new_op);
new_op->set_input(0, a->input(0));
node_map_->UpdateInput(new_op->name(), a->name(), a->input(0));
- AddFrameControlDeps(node, {new_op}, a->input(0), {new_op});
}
if (b_is_foldable) {
const string attr_b =
FlipBooleanAttr(attr_b, new_op);
new_op->set_input(1, b->input(0));
node_map_->UpdateInput(new_op->name(), b->name(), b->input(0));
- if (!a_is_foldable) {
- AddFrameControlDeps(node, {new_op}, b->input(0), {new_op});
- }
}
+ std::vector<const NodeDef*> deps_to_forward({node});
+ if (a_is_foldable) {
+ deps_to_forward.push_back(a);
+ }
+ if (b_is_foldable) {
+ deps_to_forward.push_back(b);
+ }
+ ForwardControlDependencies(new_op, deps_to_forward);
}
}
: "Transpose");
new_op->set_input(0, input->input(0));
node_map_->UpdateInput(new_op->name(), node->name(), input->input(0));
- AddFrameControlDeps(node, {new_op}, "", {});
+ ForwardControlDependencies(new_op, {node, input});
return new_op->name();
}
}
}
const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
- graph_properties_.get(), node_map_.get(),
- &frame_map_);
+ graph_properties_.get(), node_map_.get());
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
// Stop the pipeline after the first stage that returns a non-empty
// simplified tensor name.
graph_properties_.reset(new GraphProperties(item));
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
- // Identify loop frames
- int num_frames;
- TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
- &frame_map_, &num_frames));
-
// Perform the optimizations.
TF_RETURN_IF_ERROR(SimplifyArithmeticOps());