}
Status ConstantFolding::SimplifyGraph(GraphDef* output,
- const GraphProperties& properties,
+ GraphProperties* properties,
bool use_shape_info) {
const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
for (int i = 0; i < output->node_size(); ++i) {
if (use_shape_info &&
(IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
const auto& shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || all dims have size 1)
bool replaceable = !shape.unknown_rank();
}
if (use_shape_info && IsSlice(*node) &&
- properties.GetInputProperties(node->name()).size() == 3) {
- const auto& input = properties.GetInputProperties(node->name())[0];
- const auto& b = properties.GetInputProperties(node->name())[1];
- const auto& s = properties.GetInputProperties(node->name())[2];
+ properties->GetInputProperties(node->name()).size() == 3) {
+ const auto& input = properties->GetInputProperties(node->name())[0];
+ const auto& b = properties->GetInputProperties(node->name())[1];
+ const auto& s = properties->GetInputProperties(node->name())[2];
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
TensorShape::IsValid(s.shape()) && s.has_value()) {
Tensor begin(b.dtype(), b.shape());
}
if (IsTile(*node) &&
- properties.GetInputProperties(node->name()).size() == 2) {
- const auto& m = properties.GetInputProperties(node->name())[1];
+ properties->GetInputProperties(node->name()).size() == 2) {
+ const auto& m = properties->GetInputProperties(node->name())[1];
if (TensorShape::IsValid(m.shape()) && m.has_value()) {
Tensor multiplies(m.dtype(), m.shape());
if (!multiplies.FromProto(m.value())) {
}
if (IsPad(*node) &&
- properties.GetInputProperties(node->name()).size() >= 2) {
- const auto& p = properties.GetInputProperties(node->name())[1];
+ properties->GetInputProperties(node->name()).size() >= 2) {
+ const auto& p = properties->GetInputProperties(node->name())[1];
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
Tensor paddings(p.dtype(), p.shape());
if (!paddings.FromProto(p.value())) {
}
if (use_shape_info && IsSqueeze(*node) &&
- !properties.GetInputProperties(node->name()).empty()) {
+ !properties->GetInputProperties(node->name()).empty()) {
// https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
// error to squeeze a dimension that is not 1, so we only need to check
// whether the input has > 1 size for each dimension.
const auto& shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || all dims have size > 1)
bool replaceable = !shape.unknown_rank();
}
}
+ // A Pack (tf.stack) with exactly one non-control input is equivalent to
+ // an ExpandDims along the pack axis; rewrite the node in place and feed
+ // it a synthesized constant axis input.
+ if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
+ !OptimizedNodeExists(*node, "_const_axis")) {
+ // Create constant axis node.
+ Tensor axis_t(DT_INT32, TensorShape({}));
+ NodeDef* axis_node = output->add_node();
+ axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
+ // NOTE(review): at("axis") assumes the attr is populated on every Pack
+ // node — TODO confirm defaults are materialized before this pass runs.
+ const int axis = node->attr().at("axis").i();
+ // If the axis constant cannot be materialized, skip this node entirely
+ // (the graph is left unchanged for it).
+ if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
+ !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
+ .ok()) {
+ continue;
+ }
+ VLOG(1) << "*** Rewriting trivial Pack node: " << node->DebugString();
+ // Add a control dependency to make sure axis_node is in the right frame.
+ const string ctrl_dep = ConstantFolding::AddControlDependency(
+ node->input(0), graph_, node_map_.get());
+ axis_node->add_input(ctrl_dep);
+ axis_node->set_device(node->device());
+ // Turn the Pack into an ExpandDims: drop the Pack-only attrs ("axis",
+ // "N") and add the "Tdim" attr ExpandDims expects for its axis input.
+ node->set_op("ExpandDims");
+ if (node->attr().count("axis") != 0) {
+ node->mutable_attr()->erase("axis");
+ }
+ if (node->attr().count("N") != 0) {
+ node->mutable_attr()->erase("N");
+ }
+ (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
+ node->add_input(axis_node->name());
+ // add_input appended the axis after any existing control inputs; swap
+ // it into position 1 so the control inputs stay last.
+ if (node->input_size() > 2) {
+ node->mutable_input()->SwapElements(1, node->input_size() - 1);
+ }
+ }
+
// Switch(x, x) will always feed false to its false branch and true to
// its true branch. By rewriting the graph a bit, we can propagate these
// constants down the two output branches, and just use control dependencies
graph_modified_ = true;
continue;
}
- if (use_shape_info && IsSimplifiableReshape(*node, properties)) {
+ if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
DataType output_type = node->attr().at("T").type();
node->set_op("Identity");
node->clear_attr();
// Simplify arithmetic operations with ones or zeros.
if (use_shape_info &&
(is_mul || is_matmul || is_add || is_sub || is_any_div) &&
- properties.HasInputProperties(node->name()) &&
- properties.HasOutputProperties(node->name())) {
+ properties->HasInputProperties(node->name()) &&
+ properties->HasOutputProperties(node->name())) {
const NodeDef* x = node_map_->GetNode(node->input(0));
const NodeDef* y = node_map_->GetNode(node->input(1));
if (x == nullptr || y == nullptr) {
node->DebugString());
}
const TensorShapeProto& output_shape =
- properties.GetOutputProperties(node->name())[0].shape();
+ properties->GetOutputProperties(node->name())[0].shape();
// Simplify element-wise multiplication by ones or addition/subtraction
// of zeros.
const TensorShapeProto& y_shape =
- properties.GetInputProperties(node->name())[1].shape();
+ properties->GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = IsOnes(*x);
const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
}
const TensorShapeProto& x_shape =
- properties.GetInputProperties(node->name())[0].shape();
+ properties->GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = IsOnes(*y);
const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
}
TF_RETURN_IF_ERROR(FoldGraph(output));
node_map_.reset(new NodeMap(output));
- TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+ TF_RETURN_IF_ERROR(SimplifyGraph(output, &properties, can_use_shape_info));
return Status::OK();
}
EXPECT_EQ("^id_n", output.node(7).input(2));
}
+// A Pack (tf.stack) with a single non-control input should be rewritten
+// into an ExpandDims fed by a synthesized constant axis node.
+TEST_F(ConstantFoldingTest, TrivialPack) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+ Output x =
+ ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
+ Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
+ // The control dependency on y must survive the rewrite as a control input
+ // of the transformed node.
+ auto stack =
+ ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
+ ops::Stack::Axis(1));
+
+ GrapplerItem item;
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+ item.fetch.push_back("stack");
+
+ ConstantFolding fold(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+ // NOTE(review): expects 5 nodes — presumably x, y, stack, the synthesized
+ // axis const, and the shape const feeding RandomNormal; confirm.
+ EXPECT_EQ(5, output.node_size());
+ for (const auto& node : output.node()) {
+ if (node.name() == "stack") {
+ EXPECT_EQ("stack", node.name());
+ // The Pack node is rewritten in place: same name, new op, and the
+ // axis input spliced in before the control input.
+ EXPECT_EQ("ExpandDims", node.op());
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
+ EXPECT_EQ("^y", node.input(2));
+ } else if (node.name() == "ConstantFolding/stack_const_axis") {
+ // The synthesized axis const carries a control dependency on the
+ // Pack's data input (added to keep it in the right frame).
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("^x", node.input(0));
+ }
+ }
+
+ std::vector<string> fetch = {"stack"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+ EXPECT_EQ(1, tensors.size());
+ // x is RandomNormal, so values are nondeterministic; only the output
+ // shapes before and after the rewrite can be compared.
+ EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow