#define EIGEN_USE_THREADS
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
Status ConstantFolding::ReplaceOperationWithConstant(
double value, const GraphProperties& properties,
- const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
+ const TensorShapeProto& shape, NodeDef* node, GraphDef* graph,
+ bool* success) {
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
- if (dtype == DT_INVALID) return Status::OK();
+ if (dtype == DT_INVALID) {
+ *success = false;
+ return Status::OK();
+ }
AttrValue tensor_attr;
TF_RETURN_IF_ERROR(
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
node->set_input(i, ctrl_dep);
}
- graph_modified_ = true;
+ *success = true;
return Status::OK();
}
Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
GraphProperties* properties,
bool use_shape_info) {
- const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
return Status::OK();
return Status::OK();
}
+ bool arithmetic_simplification_succeed = false;
+ Status simplify_arithmetic_status = SimplifyArithmeticOperations(
+ optimized_graph, properties, node, use_shape_info,
+ &arithmetic_simplification_succeed);
+ if (!simplify_arithmetic_status.ok()) {
+ return simplify_arithmetic_status;
+ } else if (arithmetic_simplification_succeed) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (ReduceDivToReciprocalMul(optimized_graph, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (ConstantPushDown(node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialConstPropThroughIdentityN(node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ if (PartialConcatConstFolding(optimized_graph, properties, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+Status ConstantFolding::SimplifyArithmeticOperations(
+ GraphDef* optimized_graph, GraphProperties* properties, NodeDef* node,
+ bool use_shape_info, bool* success) {
const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
const bool is_matmul = IsMatMul(*node);
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
((is_mul && x_is_one) || (is_add && x_is_zero))) {
// 1 * y = y or 0 + y = y.
ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
+ *success = true;
return Status::OK();
}
if (y_matches_output_shape && (is_sub && x_is_zero)) {
// Replace 0 - y with Neg(y).
ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
+ *success = true;
return Status::OK();
}
DataType type = node->attr().at("T").type();
if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
+ *success = true;
return Status::OK();
}
}
((is_add || is_sub) && y_is_zero))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
+ *success = true;
return Status::OK();
}
// x OR true = true OR y = true.
+ bool updated_graph = false;
const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
- TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
- 1, *properties, output_shape, node, optimized_graph));
+ bool replace_succeed = false;
+ Status replace_op_status =
+ ReplaceOperationWithConstant(1, *properties, output_shape, node,
+ optimized_graph, &replace_succeed);
+ if (!replace_op_status.ok()) {
+ return replace_op_status;
+ } else if (replace_succeed) {
+ updated_graph = true;
+ }
}
// Simplify multiplication and matmul by zeros.
// Also optimize zeros divided by a tensor, but only if we are in
// aggressive mode, since we might get rid of divisions by zero.
+ const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
if ((x_is_zero || y_is_zero) &&
(is_mul || is_matmul || optimize_zeros_divided_by_y)) {
if (shp.IsFullyDefined()) {
- TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
- 0, *properties, output_shape, node, optimized_graph));
- return Status::OK();
+ bool replace_succeed = false;
+ Status replace_op_status =
+ ReplaceOperationWithConstant(0, *properties, output_shape, node,
+ optimized_graph, &replace_succeed);
+ if (!replace_op_status.ok()) {
+ return replace_op_status;
+ } else if (replace_succeed) {
+ *success = true;
+ return Status::OK();
+ }
}
// Even if an input shape is only partially known, we may known that it
// matches the output shape and thus forward the corresponding zero
// input.
if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ *success = true;
return Status::OK();
} else if (is_mul && y_is_zero && y_matches_output_shape) {
ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
+ *success = true;
return Status::OK();
}
}
+ if (updated_graph) {
+ *success = true;
+ return Status::OK();
+ }
}
+ *success = false;
+ return Status::OK();
+}
+bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
+ NodeDef* node) {
// Strength reduce floating point division by a constant Div(x, const) to
// multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
// will be constant folded to Mul(x, 1.0/const).
const NodeDef* denom = node_map_->GetNode(const_input);
CHECK(denom != nullptr);
if (!IsReallyConstant(*denom)) {
- return Status::OK();
+ return false;
}
if (node->attr().count("T") == 0) {
- return Status::OK();
+ return false;
}
DataType type = node->attr().at("T").type();
if (IsDiv(*node) &&
!(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
- return Status::OK();
+ return false;
}
// Insert new reciprocal op and change node from Div to Mul.
NodeDef* reciprocal_node = optimized_graph->add_node();
node->set_input(1, reciprocal_node->name());
node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (ConstantPushDown(node)) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (PartialConstPropThroughIdentityN(node)) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
- graph_modified_ = true;
- return Status::OK();
- }
-
- if (PartialConcatConstFolding(optimized_graph, properties, node)) {
- graph_modified_ = true;
- return Status::OK();
+ return true;
}
-
- return Status::OK();
+ return false;
}
bool ConstantFolding::ConstantPushDown(NodeDef* node) {