return false;
}
}
+ return !ModifiesInputsInPlace(node);
+}
+
+bool ModifiesInputsInPlace(const NodeDef& node) {
// Some nodes do in-place updates on regular tensor inputs.
- if (GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace") ||
- StringPiece(op_name).starts_with("Inplace")) {
- return false;
+ string op_name = node.op();
+ std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower);
+ if (StringPiece(op_name).contains("inplace")) {
+ return true;
}
- return true;
+ return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
}
bool ModifiesFrameInfo(const NodeDef& node) {
bool IsPersistent(const NodeDef& node);
bool IsFreeOfSideEffect(const NodeDef& node);
+
bool ModifiesFrameInfo(const NodeDef& node);
+// Returns true if the op is known to write to one or more of its inputs.
+bool ModifiesInputsInPlace(const NodeDef& node);
+
// Returns true if the op is an element-wise involution, i.e. if it is its
// own inverse such that f(f(x)) == x.
bool IsInvolution(const NodeDef& node);
return node_map_->NodeExists(OptimizedNodeName(node, suffix));
}
+namespace {
+
+bool FeedsInPlaceOp(const SimpleGraphView& graph_view, const NodeDef& node) {
+ const std::unordered_set<string> op_types_to_traverse = {
+ node.op(), "Identity", "IdentityN", "Reshape"};
+ int node_idx = graph_view.index(node.name());
+ std::set<int> node_fanout;
+ graph_view.DepthFirstSearch(op_types_to_traverse, node_idx, &node_fanout);
+ for (int fanout : node_fanout) {
+ if (ModifiesInputsInPlace(graph_view.graph()->node(fanout))) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
return false;
void ArithmeticOptimizer::DedupComputations() {
bool stop = true;
+ SimpleGraphView graph_view;
+ if (!graph_view.Initialize(*optimized_graph_).ok()) {
+ LOG(WARNING) << "Failed to build SimpleGraphView.";
+ return;
+ }
std::set<int> duplicates;
do {
stop = true;
if (rep == node) {
continue;
}
+ // If either node feeds an inplace op, deduping them may cause data races.
+ // For example: If we dedup nodes initializing two independent inplace
+ // accumulations, they will write to the same buffer, clobbering each
+ // other's results.
+ if (FeedsInPlaceOp(graph_view, *rep) ||
+ FeedsInPlaceOp(graph_view, *node)) {
+ continue;
+ }
const std::set<NodeDef*>& fanouts = node_map_->GetOutputs(node->name());
for (NodeDef* fanout : fanouts) {
- for (string& name : *fanout->mutable_input()) {
+ for (int i = 0; i < fanout->input_size(); ++i) {
+ string* name = fanout->mutable_input(i);
int position;
- const string nodename = ParseNodeName(name, &position);
+ const string nodename = ParseNodeName(*name, &position);
if (nodename == node->name()) {
// Update name in-place.
if (position > 0) {
- name = StrCat(rep->name(), ":", position);
+ *name = StrCat(rep->name(), ":", position);
} else if (position == 0) {
- name = rep->name();
+ *name = rep->name();
} else {
- name = StrCat("^", rep->name());
+ *name = StrCat("^", rep->name());
}
node_map_->AddOutput(rep->name(), fanout->name());
}