Allow benchmark model graph to be specified in text proto format.
authorShashi Shekhar <shashishekhar@google.com>
Sat, 5 May 2018 18:55:53 +0000 (11:55 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 7 May 2018 22:45:26 +0000 (15:45 -0700)
PiperOrigin-RevId: 195547670

tensorflow/tools/benchmark/benchmark_model.cc
tensorflow/tools/benchmark/benchmark_model_test.cc

index 1552302..eeb1fab 100644 (file)
@@ -262,6 +262,10 @@ Status InitializeSession(int num_threads, const string& graph,
   tensorflow::GraphDef tensorflow_graph;
   Status s = ReadBinaryProto(Env::Default(), graph, graph_def->get());
   if (!s.ok()) {
+    s = ReadTextProto(Env::Default(), graph, graph_def->get());
+  }
+
+  if (!s.ok()) {
     LOG(ERROR) << "Could not create TensorFlow Graph: " << s;
     return s;
   }
index 16ab2ff..6813045 100644 (file)
@@ -26,30 +26,36 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-TEST(BenchmarkModelTest, InitializeAndRun) {
-  const string dir = testing::TmpDir();
-  const string filename_pb = io::JoinPath(dir, "graphdef.pb");
-
+void CreateTestGraph(const ::tensorflow::Scope& root,
+                     benchmark_model::InputLayerInfo* input,
+                     string* output_name, GraphDef* graph_def) {
   // Create a simple graph and write it to filename_pb.
   const int input_width = 400;
   const int input_height = 10;
-  benchmark_model::InputLayerInfo input;
-  input.shape = TensorShape({input_width, input_height});
-  input.data_type = DT_FLOAT;
+  input->shape = TensorShape({input_width, input_height});
+  input->data_type = DT_FLOAT;
   const TensorShape constant_shape({input_height, input_width});
 
   Tensor constant_tensor(DT_FLOAT, constant_shape);
   test::FillFn<float>(&constant_tensor, [](int) -> float { return 3.0; });
 
-  auto root = Scope::NewRootScope().ExitOnError();
   auto placeholder =
-      ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input.shape));
-  input.name = placeholder.node()->name();
+      ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input->shape));
+  input->name = placeholder.node()->name();
   auto m = ops::MatMul(root, placeholder, constant_tensor);
-  const string output_name = m.node()->name();
+  *output_name = m.node()->name();
+  TF_ASSERT_OK(root.ToGraphDef(graph_def));
+}
+
+TEST(BenchmarkModelTest, InitializeAndRun) {
+  const string dir = testing::TmpDir();
+  const string filename_pb = io::JoinPath(dir, "graphdef.pb");
+  auto root = Scope::NewRootScope().ExitOnError();
 
+  benchmark_model::InputLayerInfo input;
+  string output_name;
   GraphDef graph_def;
-  TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+  CreateTestGraph(root, &input, &output_name, &graph_def);
   string graph_def_serialized;
   graph_def.SerializeToString(&graph_def_serialized);
   TF_ASSERT_OK(
@@ -69,5 +75,30 @@ TEST(BenchmarkModelTest, InitializeAndRun) {
   ASSERT_EQ(num_runs, 10);
 }
 
+TEST(BenchmarkModeTest, TextProto) {
+  const string dir = testing::TmpDir();
+  const string filename_txt = io::JoinPath(dir, "graphdef.pb.txt");
+  auto root = Scope::NewRootScope().ExitOnError();
+
+  benchmark_model::InputLayerInfo input;
+  string output_name;
+  GraphDef graph_def;
+  CreateTestGraph(root, &input, &output_name, &graph_def);
+  TF_ASSERT_OK(WriteTextProto(Env::Default(), filename_txt, graph_def));
+
+  std::unique_ptr<Session> session;
+  std::unique_ptr<GraphDef> loaded_graph_def;
+  TF_ASSERT_OK(benchmark_model::InitializeSession(1, filename_txt, &session,
+                                                  &loaded_graph_def));
+  std::unique_ptr<StatSummarizer> stats;
+  stats.reset(new tensorflow::StatSummarizer(*(loaded_graph_def.get())));
+  int64 time;
+  int64 num_runs = 0;
+  TF_ASSERT_OK(benchmark_model::TimeMultipleRuns(
+      0.0, 10, 0.0, {input}, {output_name}, {}, session.get(), stats.get(),
+      &time, &num_runs));
+  ASSERT_EQ(num_runs, 10);
+}
+
 }  // namespace
 }  // namespace tensorflow