}
// Record a copy of the output nodes before FoldNode() modifies them.
std::set<NodeDef*> outputs = node_map_->GetOutputs(node->name());
+
Status s = FoldNode(node, output);
processed_nodes.insert(node->name());
if (!s.ok()) {
const GraphProperties& properties,
bool use_shape_info) {
const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
- for (auto& node : *output->mutable_node()) {
- if (IsSimplifiableReduction(node)) {
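+ // Iterate by index rather than with a range-based for loop: the rewrites
+ // below can add nodes to the graph, which may invalidate iterators.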
+ for (int i = 0; i < output->node_size(); ++i) {
+ NodeDef* node = output->mutable_node(i);
+ if (IsSimplifiableReduction(*node)) {
// Replace the reduction node with an identity node, which can be further
// optimized by the model pruner.
DataType output_type;
- if (node.attr().count("T") > 0) {
- output_type = node.attr().at("T").type();
+ if (node->attr().count("T") > 0) {
+ output_type = node->attr().at("T").type();
} else {
// This is an 'any' or 'all' reduction. The output is always boolean.
output_type = DT_BOOL;
}
- node.set_op("Identity");
- node.clear_attr();
- (*node.mutable_attr())["T"].set_type(output_type);
- *node.mutable_input(1) = AsControlDependency(node.input(1));
+ node->set_op("Identity");
+ node->clear_attr();
+ (*node->mutable_attr())["T"].set_type(output_type);
+ *node->mutable_input(1) = AsControlDependency(node->input(1));
+ continue;
}
const bool safe_to_use_shapes =
use_shape_info && (feed_nodes_.empty() || is_aggressive);
- if (safe_to_use_shapes && IsSimplifiableReshape(node, properties)) {
- DataType output_type = node.attr().at("T").type();
- node.set_op("Identity");
- node.clear_attr();
- (*node.mutable_attr())["T"].set_type(output_type);
- *node.mutable_input(1) = AsControlDependency(node.input(1));
+ if (safe_to_use_shapes && IsSimplifiableReshape(*node, properties)) {
+ DataType output_type = node->attr().at("T").type();
+ node->set_op("Identity");
+ node->clear_attr();
+ (*node->mutable_attr())["T"].set_type(output_type);
+ *node->mutable_input(1) = AsControlDependency(node->input(1));
+ continue;
}
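+ // Classify the arithmetic op once up front; the simplifications below
+ // dispatch on these flags.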
+ const bool is_mul = IsMul(*node);
+ const bool is_matmul = IsMatMul(*node);
+ const bool is_add = IsAdd(*node) || IsBiasAdd(*node);
+ const bool is_sub = IsSub(*node);
+ const bool is_any_div = IsAnyDiv(*node);
// Simplify multiplication by ones or zeros, and addition/subtraction of
// zeros.
- // TODO(rmlarsen): Rewrite x / const -> x * (1/const).
- bool is_mul = IsMul(node);
- bool is_matmul = IsMatMul(node);
- bool is_add = IsAdd(node) || IsBiasAdd(node);
- bool is_sub = IsSub(node);
- bool is_div = IsAnyDiv(node);
- if (use_shape_info && (is_mul || is_matmul || is_add || is_sub || is_div) &&
- properties.HasInputProperties(node.name()) &&
- properties.HasOutputProperties(node.name())) {
- const NodeDef* x = node_map_->GetNode(node.input(0));
- const NodeDef* y = node_map_->GetNode(node.input(1));
+ if (use_shape_info &&
+ (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
+ properties.HasInputProperties(node->name()) &&
+ properties.HasOutputProperties(node->name())) {
+ const NodeDef* x = node_map_->GetNode(node->input(0));
+ const NodeDef* y = node_map_->GetNode(node->input(1));
if (x == nullptr || y == nullptr) {
return errors::InvalidArgument("Invalid inputs to node: ",
- node.DebugString());
+ node->DebugString());
}
const TensorShapeProto& output_shape =
- properties.GetOutputProperties(node.name())[0].shape();
+ properties.GetOutputProperties(node->name())[0].shape();
// Simplify element-wise multiplication by ones or addition/subtraction
// of zeros.
const TensorShapeProto& y_shape =
- properties.GetInputProperties(node.name())[1].shape();
+ properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = IsOnes(*x);
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
if (y_matches_output_shape &&
    ((is_mul && x_is_one) || (is_add && x_is_zero))) {
// TODO(rmlarsen): Handle subtraction 0 - y.
// 1 * y = y or 0 + y = y.
- ReplaceOperationWithIdentity(1, &node);
+ ReplaceOperationWithIdentity(1, node);
continue;
}
// Replace 1 / y with Reciprocal op.
- if (y_matches_output_shape && is_div && x_is_one) {
- ReplaceDivisionOfOnesByReciprocal(&node);
+ if (y_matches_output_shape && is_any_div && x_is_one) {
+ ReplaceDivisionOfOnesByReciprocal(node);
continue;
}
const TensorShapeProto& x_shape =
- properties.GetInputProperties(node.name())[0].shape();
+ properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = IsOnes(*y);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
if (x_matches_output_shape &&
- (((is_mul || is_div) && y_is_one) ||
+ (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero && is_aggressive))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
- ReplaceOperationWithIdentity(0, &node);
+ ReplaceOperationWithIdentity(0, node);
continue;
}
// Simplify multiplication and matmul by zeros.
// Also optimize zeros divided by a tensor, but only if we are in
// aggressive mode, since we might get rid of divisions by zero.
- bool optimize_zeros_divided_by_y = is_div && x_is_zero && is_aggressive;
+ bool optimize_zeros_divided_by_y =
+ is_any_div && x_is_zero && is_aggressive;
if ((x_is_zero || y_is_zero) &&
(is_mul || is_matmul || optimize_zeros_divided_by_y)) {
const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined()) {
TF_RETURN_IF_ERROR(
- ReplaceOperationWithConstant(0, output_shape, &node));
+ ReplaceOperationWithConstant(0, output_shape, node));
continue;
}
// Even if an input shape is only partially known, we may know that it
// matches the output shape and thus forward the corresponding zero
// input.
- if ((is_mul || is_div) && x_is_zero && x_matches_output_shape) {
- ReplaceOperationWithIdentity(0, &node);
+ if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
+ ReplaceOperationWithIdentity(0, node);
+ continue;
} else if (is_mul && y_is_zero && y_matches_output_shape) {
- ReplaceOperationWithIdentity(1, &node);
+ ReplaceOperationWithIdentity(1, node);
+ continue;
}
}
}
+
+ // Strength reduce floating point division by a constant Div(x, const) to
+ // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
+ // will be constant folded to Mul(x, 1.0/const).
+ if (node->input_size() >= 2 && (IsRealDiv(*node) || IsDiv(*node))) {
+ const string& const_input = node->input(1);
+ const NodeDef* denom = node_map_->GetNode(const_input);
+ CHECK(denom != nullptr);
+ if (!IsReallyConstant(*denom)) {
+ continue;
+ }
+ if (node->attr().count("T") == 0) {
+ continue;
+ }
+ DataType type = node->attr().at("T").type();
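+ // Only rewrite real divisions: integer division truncates, so
+ // x / c is not equivalent to x * (1/c) for integer types.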
+ if (IsDiv(*node) && !DataTypeIsFloating(type)) {
+ continue;
+ }
+ // Insert new reciprocal op and change node from Div to Mul.
+ NodeDef* reciprocal_node = output->add_node();
+ reciprocal_node->set_name(AddPrefixToNodeName(
+ strings::StrCat(node->name(), "_recip"), kConstantFoldingConst));
+ reciprocal_node->set_op("Reciprocal");
+ reciprocal_node->set_device(node->device());
+ node->set_op("Mul");
+ // Re-wire inputs and outputs.
+ reciprocal_node->add_input(const_input);
+ (*reciprocal_node->mutable_attr())["T"].set_type(type);
+ node->set_input(1, reciprocal_node->name());
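+ // Keep the NodeMap consistent with the modified GraphDef so that
+ // subsequent rewrites see the new edges.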
+ node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
+ node_map_->UpdateInput(node->name(), const_input,
+ reciprocal_node->name());
+ node_map_->AddOutput(NodeName(const_input), reciprocal_node->name());
+ graph_modified_ = true;
+ }
}
+
return Status::OK();
}
}
TF_RETURN_IF_ERROR(FoldGraph(output));
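+ // FoldGraph() may have added or removed nodes, so rebuild the node map
+ // before running the simplifications that depend on it.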
+ node_map_.reset(new NodeMap(output));
TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
return Status::OK();
}
}
+TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
+ Output xf = ops::Placeholder(s.WithOpName("xf"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output xi = ops::Placeholder(s.WithOpName("xi"), DT_INT32,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Output ci = ops::Const(s.WithOpName("ci"), 2, {1});
+ Output cf = ops::Const(s.WithOpName("cf"), 2.0f, {1});
+ Output div_i = ops::Div(s.WithOpName("div_i"), xi, ci);
+ Output div_f = ops::Div(s.WithOpName("div_f"), xf, cf);
+ Output realdiv = ops::RealDiv(s.WithOpName("realdiv"), xf, cf);
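+ // div_i is an integer division and must be left untouched; div_f and
+ // realdiv are floating point divisions that should be strength reduced
+ // to multiplications by a reciprocal.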
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"div_f", "div_i", "realdiv"};
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
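+ // The inserted Reciprocal ops are themselves constant folded, so the
+ // optimized graph contains them as Const nodes holding 1/2 = 0.5.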
+ EXPECT_EQ(8, output.node_size());
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ const string& name = node.name();
+ if (name == "div_i") {
+ // Integer division is unchanged.
+ EXPECT_EQ("Div", node.op());
+ EXPECT_EQ("xi", node.input(0));
+ EXPECT_EQ("ci", node.input(1));
+ } else if (name == "div_f") {
+ EXPECT_EQ("Mul", node.op());
+ EXPECT_EQ("xf", node.input(0));
+ EXPECT_EQ("ConstantFolding/div_f_recip", node.input(1));
+ } else if (name == "realdiv") {
+ EXPECT_EQ("Mul", node.op());
+ EXPECT_EQ("xf", node.input(0));
+ EXPECT_EQ("ConstantFolding/realdiv_recip", node.input(1));
+ } else if (name == "ConstantFolding/div_f_recip") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
+ TensorProto t = node.attr().at("value").tensor();
+ EXPECT_EQ(DT_FLOAT, t.dtype());
+ EXPECT_EQ(1, t.tensor_shape().dim_size());
+ EXPECT_EQ(1, t.tensor_shape().dim(0).size());
+ } else if (name == "ConstantFolding/realdiv_recip") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
+ TensorProto t = node.attr().at("value").tensor();
+ EXPECT_EQ(DT_FLOAT, t.dtype());
+ EXPECT_EQ(1, t.tensor_shape().dim_size());
+ EXPECT_EQ(1, t.tensor_shape().dim(0).size());
+ }
+ }
+
+ // Check that the reciprocals have the expected value.
+ std::vector<string> fetch = {"cf_half"};
+ auto tensor_expected = EvaluateNodes(item.graph, fetch);
+ EXPECT_EQ(fetch.size(), tensor_expected.size());
+ fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(fetch.size(), tensors.size());
+ for (int i = 0; i < fetch.size(); i++) {
+ test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
+ }
+}
+
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x_known =