graph_modified_ = true;
}
+void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
+ GraphDef* graph) {
+ node->set_op("Neg");
+ node->mutable_input()->SwapElements(0, 1);
+ const string ctrl_dep =
+ AddControlDependency(node->input(1), graph, node_map_.get());
+ node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
+ node->set_input(1, ctrl_dep);
+ graph_modified_ = true;
+}
+
Status ConstantFolding::ReplaceOperationWithConstant(
double value, const TensorShapeProto& shape, NodeDef* node,
GraphDef* graph) {
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
- // TODO(rmlarsen): Handle subtraction 0 - y.
// 1 * y = y or 0 + y = y.
ReplaceOperationWithSnapshot(1, node, output);
continue;
}
+ if (y_matches_output_shape && (is_sub && x_is_zero)) {
+ // Replace 0 - y with Neg(y).
+ ReplaceSubtractionFromZeroByNegation(node, output);
+ continue;
+ }
+
// Replace 1 / y with Reciprocal op.
if (y_matches_output_shape && is_any_div && x_is_one) {
DataType type = node->attr().at("T").type();
GraphDef* graph);
void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node,
GraphDef* graph);
+ void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
Status ReplaceOperationWithConstant(double value,
const TensorShapeProto& shape,
NodeDef* node, GraphDef* graph);
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^zeros", node.input(1));
} else if (name == "sub2") {
- // We don't handle this case yet.
- EXPECT_EQ("Sub", node.op());
- EXPECT_EQ("zeros", node.input(0));
- EXPECT_EQ("y", node.input(1));
+ EXPECT_EQ("Neg", node.op());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ("^zeros", node.input(1));
}
const std::set<string> square_zero_const{"mul1", "mul2", "mul5",
"mul6", "matmul1", "matmul2"};