TEST_F(FunctionOptimizerTest, SymbolicGradients) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
-  // auto T = DT_FLOAT;
  FunctionDef func = FunctionDefHelper::Define(
      "TestFunc", {"x:float", "y:float"}, {"l:float"}, {},
      {
          {{"z"}, "Add", {"x", "y"}, {{"T", DT_FLOAT}}},
          FunctionDefHelper::Const("zero", 0),
          FunctionDefHelper::Const("one", 1),
          {{"r"}, "Rank", {"z"}, {{"T", DT_FLOAT}}},
          {{"indices"}, "Range", {"zero", "r", "one"}},
          {{"l"}, "Sum", {"z", "indices"}, {{"T", DT_FLOAT}}},
      });
-  auto dummy_variable = ops::Variable(scope, {2, 2}, DT_FLOAT);
  auto x = ops::Const(scope, 1.0f);
  auto y = ops::Const(scope, 2.0f);
  auto dl = ops::Const(scope, 3.0f);
  test::ExpectTensorEqual<float>(expected[1], optimized[1]);
}
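+
+// Inlines the SymbolicGradient of a trivial Identity function and checks that
+// the rewritten graph has the expected structure and produces the same values.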
+TEST_F(FunctionOptimizerTest, SymbolicGradientsIdentity) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+  FunctionDef func = FunctionDefHelper::Create(
+      // Name
+      "Identity_func",
+      // Args
+      {"in: float"},
+      // Return values
+      {"out: float"},
+      // Attr def
+      {},
+      // Nodes
+      {{{"Identity"}, "Identity", {"in"}, {{"T", DT_FLOAT}}}},
+      // Mapping
+      {{"out", "Identity:output:0"}});
+
+  auto x = ops::Const(scope, 1.0f, {3, 5, 7});
+  auto z = ops::Const(scope, 3.0f, {3, 5, 7});
+
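+  // The SymbolicGradient inputs are the inputs of Identity_func followed by
+  // the incoming gradient for its single output; it returns the gradient with
+  // respect to "in".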
+  NameAttrList fn;
+  fn.set_name("Identity_func");
+  auto g0 = ops::SymbolicGradient(scope, std::initializer_list<Input>{x, z},
+                                  {DT_FLOAT}, fn);
+  auto out = ops::Identity(scope.WithOpName("out"), g0.output[0]);
+
+  GrapplerItem item;
+  TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
+  *item.graph.mutable_library()->add_function() = func;
+
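+  // Run the optimizer in aggressive mode so that the SymbolicGradient node is
+  // inlined into the main graph.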
+  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
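+  // No SymbolicGradient op should remain: the gradient function has been
+  // expanded into IdentityN and Identity nodes with the names checked below.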
+  EXPECT_EQ(13, output.node_size());
+  EXPECT_EQ("Const", output.node(0).name());
+  EXPECT_EQ("Const_1", output.node(1).name());
+  EXPECT_EQ("SymbolicGradient/FunctionInputs", output.node(2).name());
+  EXPECT_EQ("SymbolicGradient", output.node(3).name());
+  EXPECT_EQ("SymbolicGradient/SymbolicGradient/Identity",
+            output.node(4).name());
+  EXPECT_EQ("SymbolicGradient/Func/_0", output.node(5).name());
+  EXPECT_EQ("SymbolicGradient/Func/_1", output.node(6).name());
+  EXPECT_EQ("SymbolicGradient/Func/_2", output.node(7).name());
+  EXPECT_EQ("SymbolicGradient/SymbolicGradient/Func/_1/dx",
+            output.node(8).name());
+  EXPECT_EQ("SymbolicGradient/Func/_3", output.node(9).name());
+  EXPECT_EQ("SymbolicGradient/Func/_4", output.node(10).name());
+  EXPECT_EQ("SymbolicGradient/Func/_5", output.node(11).name());
+  EXPECT_EQ("out", output.node(12).name());
+  for (int i = 2; i < 4; ++i) {
+    EXPECT_EQ("IdentityN", output.node(i).op());
+  }
+  for (int i = 4; i < 11; ++i) {
+    EXPECT_EQ("Identity", output.node(i).op());
+  }
+
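+  // The inlined graph must produce the same gradient values as the original.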
+  std::vector<Tensor> expected = EvaluateNodes(item.graph, {"out"});
+  std::vector<Tensor> optimized = EvaluateNodes(output, {"out"});
+  test::ExpectTensorEqual<float>(expected[0], optimized[0]);
+}
+
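+// Functions tagged with the _noinline attribute must not be inlined: the
+// optimizer should succeed and leave the graph untouched.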
+TEST_F(FunctionOptimizerTest, SymbolicGradientsNoInlineFunc) {
+  FunctionDef func = FunctionDefHelper::Define(
+      "TestFunc", {"x:float", "y:float"}, {"l:float"}, {},
+      {
+          {{"z"}, "Add", {"x", "y"}, {{"T", DT_FLOAT}}},
+          FunctionDefHelper::Const("zero", 0),
+          FunctionDefHelper::Const("one", 1),
+          {{"r"}, "Rank", {"z"}, {{"T", DT_FLOAT}}},
+          {{"indices"}, "Range", {"zero", "r", "one"}},
+          {{"l"}, "Sum", {"z", "indices"}, {{"T", DT_FLOAT}}},
+      });
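+  // Mark the function so that the optimizer is not allowed to inline it.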
+  (*func.mutable_attr())["_noinline"].set_b(true);
+
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  auto x = ops::Const(scope, 1.0f);
+  auto y = ops::Const(scope, 2.0f);
+  auto dl = ops::Const(scope, 3.0f);
+
+  NameAttrList fn;
+  fn.set_name("TestFunc");
+  (*fn.mutable_attr())["T"].set_type(DT_FLOAT);
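+  // The SymbolicGradient inputs are the function arguments (x, y) followed by
+  // the upstream gradient dl; its two outputs are the gradients with respect
+  // to x and y.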
+  auto g0 = ops::SymbolicGradient(scope, std::initializer_list<Input>{x, y, dl},
+                                  {DT_FLOAT, DT_FLOAT}, fn);
+  auto out1 = ops::Identity(scope.WithOpName("out1"), g0.output[0]);
+  auto out2 = ops::Identity(scope.WithOpName("out2"), g0.output[1]);
+
+  GrapplerItem item;
+  TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
+  *item.graph.mutable_library()->add_function() = func;
+
+  FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  // The optimizer should succeed, but the graph should be left unchanged.
+  TF_EXPECT_OK(status);
+  CompareGraphs(item.graph, output);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow