return involution_ops->count(node.op()) > 0;
}
+bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
+ if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
+ return true;
+ }
+ static const std::unordered_set<string>*
+ value_and_order_and_shape_preserving_ops =
+ CHECK_NOTNULL((new const std::unordered_set<string>{
+ "CheckNumerics",
+ "DebugGradientIdentity",
+ "DeepCopy"
+ "Enter",
+ "Exit",
+ "Identity",
+ "IdentityN",
+ "PreventGradient",
+ "Print",
+ "Snapshot",
+ "StopGradient",
+ }));
+ return value_and_order_and_shape_preserving_ops->count(node.op()) > 0;
+}
+
bool IsValueAndOrderPreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
static const std::unordered_set<string>* value_and_order_preserving_ops =
CHECK_NOTNULL((new const std::unordered_set<string>{
- "CheckNumerics",
- "DebugGradientIdentity",
- "DeepCopy"
- "Enter",
- "Exit",
"ExpandDims",
- "Identity",
- "IdentityN",
- "PreventGradient",
- "Print",
- "Reshape",
"Snapshot",
"Squeeze",
- "StopGradient",
}));
- return value_and_order_preserving_ops->count(node.op()) > 0;
+ return value_and_order_preserving_ops->count(node.op()) > 0 ||
+ IsValueAndOrderAndShapePreserving(node);
}
bool IsValuePreserving(const NodeDef& node) {
"Tanh",
}));
return element_wise_ops->count(node.op()) > 0 ||
- (!IsIdentityN(node) && IsValueAndOrderPreserving(node));
+ (!IsIdentityN(node) && IsValueAndOrderAndShapePreserving(node));
}
bool HasOpDef(const NodeDef& node) {
// own inverse such that f(f(x)) == x.
bool IsInvolution(const NodeDef& node);
+// Returns true if the op preserves the order and value of elements
+// and shape of its first input tensor.
+bool IsValueAndOrderAndShapePreserving(const NodeDef& node);
+
// Returns true if the op preserves the order and value of elements in its
// first input tensor and possible changes its shape.
bool IsValueAndOrderPreserving(const NodeDef& node);
return n > 1;
} else if (IsSplit(*node) || IsSplitV(*node)) {
const int num_split = node->attr().at("num_split").i();
+ if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
+ // TODO(rmlarsen): Remove this constraint when we have optimizations
+ // in place for merging slices into splits.
+ return false;
+ }
return num_split > 1 && !IsAlreadyOptimized(*node);
}
return false;
if (tails.empty()) {
return Status::OK();
}
- AddControlInputs(ctrl_inputs, root_node);
AddToOptimizationQueue(root_node);
optimized_nodes_.insert(root_node->name());
if (node_is_concat_) {
+ AddControlInputs(ctrl_inputs, root_node);
return HoistChainForConcat(prefix_length, tails, root_node);
} else {
- return HoistChainForSplit(prefix_length, tails, root_node);
+ return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
}
}
IsInPreserveSet(*op)) {
return false;
}
- if (node_is_concat_ &&
- ctx().node_map->GetOutputs(op->name()).size() > 1) {
- // TODO(rmlarsen): Allow and hoist outgoing control edges.
+ if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
+ // TODO(rmlarsen): Allow outgoing control edges.
return false;
}
}
}
Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
+ std::set<string>* ctrl_inputs,
NodeDef* split_node) {
// Create a new chain before the split node to process the input tensor.
const string& split_name = split_node->name();
cur_copy->add_input(orig_input);
ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
cur_copy->name());
+ // Make sure all the control inputs are satisfied before running the first
+ // node in the new chain.
+ AddControlInputs(ctrl_inputs, cur_copy);
// Connect all consumers of the tail nodes directly to the
// output port of Split from which the chain started.
bool remove_redundant_bitcast = true;
bool remove_redundant_cast = true;
bool remove_negation = true;
- bool hoist_cwise_unary_chains = false;
+ bool hoist_cwise_unary_chains = true;
bool convert_sqrt_div_to_rsqrt_mul = false;
// Choose which arithmetic optimizer stages will be enabled for a given
EXPECT_NE(node.name(), "cos_exp_b2");
if (node.name() == "split1") {
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("axis", node.input(0));
EXPECT_EQ("ArithmeticOptimizer/_sin_a_split1", node.input(1));
- EXPECT_EQ("^ctrl1", node.input(2));
found++;
}
if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
EXPECT_EQ("Sin", node.op());
- EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^ctrl1", node.input(1));
found++;
}
if (node.name() == "id_a") {
}
if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
EXPECT_EQ("Exp", node.op());
- EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ(4, node.input_size());
EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^ctrl1", node.input(1));
+ EXPECT_EQ("^ctrl2", node.input(2));
+ EXPECT_EQ("^ctrl3", node.input(3));
found++;
}
if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
found++;
}
if (node.name() == "split2") {
- EXPECT_EQ(6, node.input_size());
+ EXPECT_EQ(3, node.input_size());
EXPECT_EQ("ArithmeticOptimizer/_cos_exp_a2_split2", node.input(0));
EXPECT_EQ("size_splits2", node.input(1));
EXPECT_EQ("axis", node.input(2));
- EXPECT_EQ("^ctrl1", node.input(3));
- EXPECT_EQ("^ctrl2", node.input(4));
- EXPECT_EQ("^ctrl3", node.input(5));
found++;
}
if (node.name() == "id_a2") {