item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
ArithmeticOptimizer optimizer;
EnableOnlyHoistCommonFactor(&optimizer);
EXPECT_EQ("id", id_node->name());
EXPECT_EQ(HoistDivName("add"), id_node->input(0));
}
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ if (use_ints) {
+ test::ExpectTensorEqual<int32>(tensors_expected[0], tensors[0]);
+ } else {
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ }
}
}
}