: graph_def_version_(graph_def_version),
node_def_(CHECK_NOTNULL(node_def)) {
std::vector<ShapeHandle> input_tensors_as_shape_handles;
+ input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
for (const TensorShapeProto& p : input_tensors_as_shapes) {
ShapeHandle shape;
construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
}
PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
if (!construction_status_.ok()) return;
+ inputs_.reserve(input_shapes.size());
for (const TensorShapeProto& p : input_shapes) {
ShapeHandle shape;
construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
: graph_def_version_(graph_def_version),
node_def_(CHECK_NOTNULL(node_def)) {
std::vector<ShapeHandle> input_tensors_as_shape_handles;
+ input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
for (const PartialTensorShape& p : input_tensors_as_shapes) {
ShapeHandle shape;
construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
}
PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
if (!construction_status_.ok()) return;
+ inputs_.reserve(input_shapes.size());
for (const PartialTensorShape& p : input_shapes) {
ShapeHandle shape;
construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
for (const auto& e : output_name_map_) {
num_outputs = std::max(num_outputs, e.second.second);
}
- for (int i = 0; i < num_outputs; ++i) {
- outputs_.push_back(nullptr);
- }
+ outputs_.assign(num_outputs, nullptr);
output_handle_shapes_and_types_.resize(num_outputs);
}
TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
// Merge the prefix dims and create the new output shapes.
+ const int32 rank_s = Rank(s);
std::vector<DimensionHandle> dims;
+ dims.reserve(std::max(rank, rank_s));
dims.resize(rank);
for (int i = 0; i < rank; ++i) {
TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
}
*prefix_out = MakeShape(dims);
- for (int i = rank; i < Rank(s); ++i) dims.push_back(Dim(s, i));
+ for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
*s_out = MakeShape(dims);
return Status::OK();
}
Status InferenceContext::AttachContext(const Status& status) {
std::vector<string> input_shapes;
+ input_shapes.reserve(inputs_.size());
for (const ShapeHandle& input_shape : inputs_) {
input_shapes.emplace_back(DebugString(input_shape));
}
// Add information about the input tensors and partial tensor shapes used.
std::vector<string> input_from_tensors_str;
std::vector<string> input_from_tensors_as_shape_str;
+ input_from_tensors_as_shape_str.reserve(inputs_.size());
for (int i = 0; i < inputs_.size(); ++i) {
if (requested_input_tensor_as_partial_shape_[i] &&
i < input_tensors_as_shapes_.size() &&
if (!refined) {
return false;
}
- for (int i = 0; i < new_values.size(); ++i) {
- (*to_update)[i] = new_values[i];
- }
+ to_update->swap(new_values);
return true;
}
void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
std::vector<bool>* input_already_exists) {
// Remove 'inputs_to_remove' from 'node_def'
- // TODO(skyewm): is there a better way to do this?
- std::vector<string> inputs;
- inputs.reserve(node_def->input_size());
- for (int i = 0; i < node_def->input_size(); ++i) {
- inputs.push_back(node_def->input(i));
- }
- node_def->clear_input();
- for (int i = 0, j = 0; i < inputs.size(); ++i) {
+ NodeDef copy;
+ copy.mutable_input()->Reserve(node_def->input_size() -
+ inputs_to_remove.size());
+ for (int i = 0, j = 0; i < node_def->input_size(); ++i) {
if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
++j;
} else {
- node_def->add_input(inputs[i]);
+ copy.add_input()->swap(*node_def->mutable_input(i));
}
}
+ node_def->mutable_input()->Swap(copy.mutable_input());
// Remove 'inputs_to_remove' from 'input_already_exists'
for (int idx : inputs_to_remove) {
input_already_exists->erase(input_already_exists->begin() + idx);
// dependencies
for (const string& control_dep : opts_.control_dependencies) {
string input = TensorId(control_dep, Graph::kControlSlot).ToString();
- const protobuf::RepeatedPtrField<string>& inputs = node_def->input();
- if (std::find(inputs.begin(), inputs.end(), input) != inputs.end()) {
- // Control dependency already exists
+ bool found = false;
+ for (int i = node_def->input_size() - 1; i >= 0; --i) {
+ const string& node_input = node_def->input(i);
+ if (node_input[0] != '^') {
+ // Control inputs are at the end. Break when we reach the non-control
+ // inputs.
+ break;
+ }
+ if (node_input == input) {
+ // Control dependency already exists
+ found = true;
+ break;
+ }
+ }
+ if (found) {
continue;
}
node_def->add_input(input);
node_def->set_name(strings::StrCat(prefix_, node_def->name()));
// Update names of input nodes
for (int i = 0; i < node_def->input_size(); ++i) {
- StringPiece input(node_def->input(i));
// Skip remapped inputs (which already exist in g_ and are not being
// imported).
if (input_already_exists[i]) continue;
+ StringPiece input(node_def->input(i));
if (str_util::ConsumePrefix(&input, "^")) {
node_def->set_input(i, strings::StrCat("^", prefix_, input));
} else {
}
}
- // TODO(ashankar): The line below means an additional copy of the NodeDef,
- // which can be expensive if the NodeDef contains large tensors in it.
- // Might make sense to change the API for ImportGraphDef to take a mutable
- // GraphDef* and avoid the copying.
+ // TODO(ashankar): The line below means an additional copy of the
+ // NodeDef, which can be expensive if the NodeDef contains large tensors
+ // in it. Might make sense to change the API for ImportGraphDef to take
+ // a mutable GraphDef* and avoid the copying.
imported_node_def = original_node_def;
if (!opts_.input_map.empty()) {
// Note that input_already_exists can shrink here
src_node->num_outputs(), " outputs");
}
- inputs.push_back(InputInfo(id.first.ToString(), src_node, src_index));
+ inputs.emplace_back(id.first.ToString(), src_node, src_index);
}
if (has_data_back_edge && !IsMerge(*node_def)) {
if (inputs[i].node == nullptr) {
// Record this back edge, which will be added after all nodes
// are created.
- back_edges_.push_back(
- EdgeInfo(inputs[i].name, inputs[i].index, node, i));
+ back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i);
} else if (inputs[i].index == Graph::kControlSlot) {
g_->AddControlEdge(inputs[i].node, node);
} else {
==============================================================================*/
#include "tensorflow/core/grappler/costs/graph_memory.h"
-#include <list>
+
+#include <deque>
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
static GraphMemory::LiveTensor* FindOrCreateLiveTensor(
const string& node_name, int output_id,
std::unordered_map<string, GraphMemory::LiveTensor*>* live_tensors,
- std::list<GraphMemory::LiveTensor>* device_tensors) {
+ std::deque<GraphMemory::LiveTensor>* device_tensors) {
string name = strings::StrCat(node_name, ":", output_id);
GraphMemory::LiveTensor* live;
auto it = live_tensors->find(name);
namespace {
struct Event {
+ Event(int64 _timestamp, bool _allocated,
+ const GraphMemory::LiveTensor* _tensor)
+ : timestamp(_timestamp), allocated(_allocated), tensor(_tensor) {}
+
int64 timestamp;
bool allocated;
const GraphMemory::LiveTensor* tensor;
}
std::unordered_map<string, LiveTensor*> live_tensors;
- std::unordered_map<string, std::list<LiveTensor>> live_tensors_per_device;
-
- NodeMap node_map(&item_.graph);
+ std::unordered_map<string, std::deque<LiveTensor>> live_tensors_per_device;
+ std::unordered_map<string, const NodeDef*> node_map;
+ for (const NodeDef& node : item_.graph.node()) {
+ node_map[node.name()] = &node;
+ }
for (const auto& dev_stats : timeline.dev_stats()) {
const string& device_name = dev_stats.device();
const bool is_gpu = (device_name.find("GPU:") || device_name.find("gpu:"));
- std::list<LiveTensor>& device_tensors =
+ std::deque<LiveTensor>& device_tensors =
live_tensors_per_device[dev_stats.device()];
for (const auto& node_stats : dev_stats.node_stats()) {
for (int i = 0; i < node_stats.output_size(); ++i) {
node_stats.op_end_rel_micros()));
}
- const NodeDef* node = node_map.GetNode(node_stats.node_name());
- if (!node) {
+ auto it = node_map.find(node_stats.node_name());
+ if (it == node_map.end()) {
// Skip nodes inserted by TF since they don't exist in the original
// graph (e.g _Send/_Recv nodes).
continue;
}
+ const NodeDef* node = it->second;
std::unordered_set<int> swapped_inputs;
if (is_gpu) {
auto it = node->attr().find("_swap_to_host");
std::vector<Event> events;
events.reserve(2 * live_per_device.second.size());
for (const auto& live : live_per_device.second) {
- events.push_back(Event{live.allocation_time.count(), true, &live});
- events.push_back(Event{live.deallocation_time.count(), false, &live});
+ events.emplace_back(static_cast<int64>(live.allocation_time.count()),
+ true, &live);
+ events.emplace_back(static_cast<int64>(live.deallocation_time.count()),
+ false, &live);
}
std::stable_sort(events.begin(), events.end());
size_t peak = 0;
- std::set<const LiveTensor*> live_at_peak;
+ std::unordered_set<const LiveTensor*> live_at_peak;
size_t current = 0;
- std::set<const LiveTensor*> currently_live;
+ std::unordered_set<const LiveTensor*> currently_live;
for (int i = 0; i < events.size(); ++i) {
const auto& event = events[i];
namespace tensorflow {
namespace grappler {
-GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef&& graphDef) {
+GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
id = other.id;
feed = other.feed;
fetch = other.fetch;
restore_op = other.restore_op;
save_restore_loc_tensor = other.save_restore_loc_tensor;
queue_runners = other.queue_runners;
- graph.Swap(&graphDef);
+ graph.Swap(graph_def);
}
std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
// A TensorFlow model to optimize.
// Models are represented by the combination of a graph, one of more fetch
// nodes, and potentially a set of nodes to feed.
-// TODO(volunteer_needed): turn this struct into a class.
struct GrapplerItem {
GrapplerItem() = default;
- GrapplerItem(const GrapplerItem& other, GraphDef&& graphDef);
+ GrapplerItem(const GrapplerItem& other, GraphDef&& graph_def)
+ : GrapplerItem(other, &graph_def) {}
+ // Swaps *graph_def with an empty GraphDef.
+ GrapplerItem(const GrapplerItem& other, GraphDef* graph_def);
virtual ~GrapplerItem() = default;
string id; // A unique id for this item
const NodeDef& node, const NodeMap& node_map,
const std::unordered_set<string>& nodes_to_preserve) {
auto is_value_preserving_non_branching = [&](const NodeDef& node) {
- return IsValuePreserving(node) &&
- NumNonControlOutputs(node, node_map) == 1 &&
- nodes_to_preserve.count(node.name()) == 0;
+ return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
+ IsValuePreserving(node) && NumNonControlOutputs(node, node_map) == 1;
};
return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
is_value_preserving_non_branching);
Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
const GrapplerItem& item,
GraphDef* optimized_graph) {
- GrapplerItem optimized_item(item);
- optimized_graph_ = &optimized_item.graph;
-
// Set up helper data structures.
nodes_to_preserve_ = item.NodesToPreserve();
fetch_nodes_known_ = !item.fetch.empty();
+ *optimized_graph = item.graph;
+ optimized_graph_ = optimized_graph;
node_map_.reset(new NodeMap(optimized_graph_));
DedupComputations();
// optimize larger subgraphs starting from the roots with more inputs.
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
- // Shapes are only needed in aggressive mode.
- graph_properties_.reset(new GraphProperties(item));
+ GrapplerItem optimized_item(item, optimized_graph);
+ optimized_graph_ = &optimized_item.graph;
+ graph_properties_.reset(new GraphProperties(optimized_item));
const Status status = graph_properties_->InferStatically(false);
const bool can_use_shapes = status.ok();
if (!can_use_shapes) {
for (const auto& input : node.input()) {
int port = 0;
- ParseNodeName(input, &port);
+ ParseNodeNameAsStringPiece(input, &port);
if (port < 0) {
// Control dependency
break;
left_child_is_constant ? left_child : right_child;
// Make sure that it is safe to change the value of the child node->
if (op_child_node->input_size() < 2 ||
- NumNonControlOutputs(*op_child_node, *node_map_) > 1 ||
nodes_to_preserve_.find(op_child_node->name()) !=
- nodes_to_preserve_.end()) {
+ nodes_to_preserve_.end() ||
+ NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
continue;
}
}
bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
- if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
- return false;
- }
- if (!fetch_nodes_known_ || NumNonControlOutputs(node, *node_map_) > 0) {
- // The output values of this node may be needed.
- return false;
- }
- if (IsMerge(node) || IsSwitch(node)) {
- return false;
- }
- if (ModifiesFrameInfo(node)) {
- return false;
- }
- if (!IsFreeOfSideEffect(node)) {
+ if (!fetch_nodes_known_ ||
+ nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
return false;
}
- if (node.op() == "ControlTrigger") {
+ if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node) ||
+ !IsFreeOfSideEffect(node)) {
return false;
}
if (node.op().rfind("Submodel", 0) == 0) {
if (!status.ok() || op_def->output_arg_size() == 0) {
return false;
}
-
+ const std::unordered_set<string> do_not_rewrite_ops{
+ "Assert", "CheckNumerics", "_Retval",
+ "_Arg", "_ParallelConcatUpdate", "_TPUExecute",
+ "_TPUCompile", "ControlTrigger"};
+ if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
+ return false;
+ }
if (!SafeToRemoveIdentity(node)) {
return false;
}
-
- const std::unordered_set<string> do_not_rewrite_ops{
- "Assert", "CheckNumerics", "_Retval",
- "_Arg", "_ParallelConcatUpdate", "_TPUExecute",
- "_TPUCompile"};
- return do_not_rewrite_ops.find(node.op()) == do_not_rewrite_ops.end();
+ if (NumNonControlOutputs(node, *node_map_) > 0) {
+ // The output values of this node may be needed.
+ return false;
+ }
+ return true;
}
void DependencyOptimizer::OptimizeNode(int node_idx,
bool data_connection = false;
for (int i = fanout->input_size() - 1; i >= 0; --i) {
int pos;
- string input_name = ParseNodeName(fanout->input(i), &pos);
+ StringPiece input_name =
+ ParseNodeNameAsStringPiece(fanout->input(i), &pos);
if (input_name == node_name) {
if (pos < 0) {
fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1);
for (int j = 0; j < consumer->input_size(); ++j) {
const string& old_input = consumer->input(j);
int old_input_pos;
- string old_input_node_name =
- ParseNodeName(old_input, &old_input_pos);
+ StringPiece old_input_node_name =
+ ParseNodeNameAsStringPiece(old_input, &old_input_pos);
if (old_input_node_name == node_name) {
if (old_input_pos >= 0) {
// Regular input
recomputation_targets_name_scope_, optimized_graph,
item);
- GrapplerItem optimized_item(item, std::move(*optimized_graph));
+ GrapplerItem optimized_item(item, optimized_graph);
std::unordered_set<string> skip_list;
// Bound the number of rewrite passes to avoid long processing times on graphs
// that simply won't fit in memory.
cfg_.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
? 1
: cfg_.meta_optimizer_iterations();
+ GrapplerItem optimized_item = item;
+ optimized_graph->Swap(&optimized_item.graph);
for (int iteration = 0; iteration < num_iterations; ++iteration) {
VLOG(1) << "Starting optimization iteration " << iteration + 1;
for (const auto& optimizer : optimizers) {
+ // Invariant: optimized_graph contains the most recently optimized
+ // version of the graph.
if (iteration > 0 && run_once_optimizers.count(optimizer->name())) {
continue;
}
- if (!already_optimized) {
- Status status = optimizer->Optimize(cluster, item, optimized_graph);
- string result;
- if (!status.ok()) {
- VLOG(1) << "Not able to apply optimizer " << optimizer->name()
- << ". Return status: " << status.ToString();
- result = status.ToString();
- } else {
- already_optimized = true;
- result = strings::StrCat(
- "OK. ", PrintSizesBeforeAfter(item.graph, *optimized_graph));
- }
- result_.push_back(std::make_pair(optimizer->name(), result));
- VLOG(1) << "Optimizer " << optimizer->name()
- << " return status: " << result;
+ uint64 start_us = Env::Default()->NowMicros();
+ // This swaps the current optimized_graph into optimized item and
+ // resets optimized_graph to an empty graph.
+ optimized_graph->Swap(&optimized_item.graph);
+ *optimized_graph = GraphDef();
+ Status status =
+ optimizer->Optimize(cluster, optimized_item, optimized_graph);
+
+ uint64 end_us = Env::Default()->NowMicros();
+ float duration_ms = (end_us - start_us) / 1000.0f;
+ string result;
+ if (!status.ok()) {
+ VLOG(1) << "Not able to apply optimizer " << optimizer->name() << ": "
+ << status.ToString();
+ optimized_graph->Swap(&optimized_item.graph);
+ result = status.ToString();
} else {
- GrapplerItem optimized_item(item, std::move(*optimized_graph));
- Status status =
- optimizer->Optimize(cluster, optimized_item, optimized_graph);
- string result;
- if (!status.ok()) {
- VLOG(1) << "Not able to apply optimizer " << optimizer->name() << ": "
- << status.ToString();
- optimized_graph->Swap(&optimized_item.graph);
- result = status.ToString();
- } else {
- result = strings::StrCat(
- optimizer->name(), ": ",
- PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph));
- }
- result_.push_back(std::make_pair(optimizer->name(), result));
- VLOG(1) << result;
+ already_optimized = true;
+ result = strings::StrCat(
+ optimizer->name(), ": ",
+ PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph),
+ ", time = ", duration_ms, "ms.");
}
+ result_.emplace_back(optimizer->name(), result);
+ VLOG(1) << result;
}
}
item.graph.library().gradient_size());
DCHECK_EQ(optimized_graph->versions().producer(),
item.graph.versions().producer());
- } else {
- *optimized_graph = item.graph;
}
-
return Status::OK();
}
return true;
}
int position1;
- string node1 = ParseNodeName(name1, &position1);
+ StringPiece node1 = ParseNodeNameAsStringPiece(name1, &position1);
int position2;
- string node2 = ParseNodeName(name2, &position2);
+ StringPiece node2 = ParseNodeNameAsStringPiece(name2, &position2);
return (position1 == position2) && (node1 == node2);
}
-string ParseNodeName(const string& name, int* position) {
- // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
- // to get a node name.
- strings::Scanner scan(name);
- scan.ZeroOrOneLiteral("^")
- .RestartCapture()
- .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
- .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
- StringPiece capture;
- StringPiece remaining;
- if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
- *position = 0;
- return "";
- } else {
- if (name[0] == '^') {
- *position = -1;
- } else if (remaining.empty()) {
- *position = 0;
- } else {
- // Skip the first ':' character.
- CHECK(strings::safe_strto32(remaining.substr(1), position));
- }
- return capture.ToString();
- }
-}
-
bool IsControlInput(const string& name) {
return !name.empty() && name[0] == '^';
}
int NodePosition(const string& name) {
int position;
- ParseNodeName(name, &position);
+ ParseNodeNameAsStringPiece(name, &position);
return position;
}
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
int num_outputs = 0;
+ int pos;
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (const string& node_as_input : output->input()) {
if (IsControlInput(node_as_input)) {
break;
}
- if (NodeName(node_as_input) == node.name()) {
+ if (node_as_input == node.name()) {
++num_outputs;
+ } else {
+ const StringPiece name =
+ ParseNodeNameAsStringPiece(node_as_input, &pos);
+ if (name == node.name()) {
+ ++num_outputs;
+ }
}
}
}
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
// Get the trailing position number ":{digits}" (if any) of a node name.
int NodePosition(const string& name);
+inline StringPiece ParseNodeNameAsStringPiece(const string& name,
+ int* position) {
+ // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
+ // to get a node name.
+ strings::Scanner scan(name);
+ scan.ZeroOrOneLiteral("^")
+ .RestartCapture()
+ .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
+ .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
+ StringPiece capture;
+ StringPiece remaining;
+ if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
+ *position = 0;
+ static const string empty;
+ return StringPiece(empty);
+ } else {
+ if (name[0] == '^') {
+ *position = -1;
+ } else if (remaining.empty()) {
+ *position = 0;
+ } else {
+ // Skip the first ':' character.
+ CHECK(strings::safe_strto32(remaining.substr(1), position));
+ }
+ return capture;
+ }
+}
+
// Returns the node name and position in a single call.
-string ParseNodeName(const string& name, int* position);
+inline string ParseNodeName(const string& name, int* position) {
+ return ParseNodeNameAsStringPiece(name, position).ToString();
+}
// Add a prefix to a node name with a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,