The benchmark binary support multiple batches in one run (#15443)
authorFei Sun <feisun@fb.com>
Fri, 21 Dec 2018 16:39:05 +0000 (08:39 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 21 Dec 2018 16:45:41 +0000 (08:45 -0800)
Summary:
It is sometimes beneficial to run multiple batches in one benchmark and check the aggregated results.

This PR enables this functionality.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15443

Reviewed By: llyfacebook

Differential Revision: D13531129

Pulled By: sf-wind

fbshipit-source-id: 553a762a5cbadf5a3d9fd6af767ae34899bc1aa2

binaries/benchmark_helper.cc
binaries/benchmark_helper.h
binaries/convert_image_to_tensor.cc

index b4e8714..10b48f0 100644 (file)
@@ -93,7 +93,7 @@ void setOperatorEngine(caffe2::NetDef* net_def, const string& backend) {
   }
 }
 
-void loadInput(
+int loadInput(
     shared_ptr<caffe2::Workspace> workspace,
     const bool run_on_gpu,
     map<string, caffe2::TensorProtos>& tensor_protos_map,
@@ -101,6 +101,8 @@ void loadInput(
     const string& input_file,
     const string& input_dims,
     const string& input_type) {
+  // How many input blobs are in the inputs
+  int blob_num = 1;
   // Load input.
   if (input.size()) {
     vector<string> input_names = caffe2::split(',', input);
@@ -117,6 +119,15 @@ void loadInput(
         workspace->CreateBlob(input_names[i]);
         tensor_protos_map.insert(std::make_pair(input_names[i], tensor_protos));
       }
+      // Check that all blobs have the same number of entries
+      blob_num = tensor_protos_map[input_names[0]].protos_size();
+      for (int i = 1; i < input_names.size(); ++i) {
+        int bnum = tensor_protos_map[input_names[i]].protos_size();
+        CAFFE_ENFORCE_EQ(
+            blob_num,
+            bnum,
+            "Number of blobs are not the same for all inputs");
+      }
     } else if (input_dims.size() || input_type.size()) {
       CAFFE_ENFORCE_GE(
           input_dims.size(),
@@ -186,6 +197,7 @@ void loadInput(
           "input_dims is set.");
     }
   }
+  return blob_num;
 }
 
 void fillInputBlob(
@@ -222,11 +234,17 @@ void runNetwork(
     map<string, caffe2::TensorProtos>& tensor_protos_map,
     const bool wipe_cache,
     const bool run_individual,
+    const bool run_on_gpu,
+    const bool text_output,
     const int warmup,
     const int iter,
+    const int num_blobs,
     const int sleep_before_run,
     const int sleep_between_iteration,
-    const int sleep_between_net_and_operator) {
+    const int sleep_between_net_and_operator,
+    const std::string& output,
+    const std::string& output_folder) {
+
   if (!net_def.has_name()) {
     net_def.set_name("benchmark");
   }
@@ -262,6 +280,15 @@ void runNetwork(
       caffe2::wipe_cache();
     }
     CAFFE_ENFORCE(net->Run(), "Main run ", i, " has failed.");
+    // Write the output for the first num_blobs times
+    writeOutput(
+        workspace,
+        run_on_gpu,
+        output,
+        output_folder,
+        text_output,
+        i,
+        num_blobs);
     if (wipe_cache) {
       caffe2::wipe_cache();
     }
@@ -296,39 +323,50 @@ void writeOutput(
     const bool run_on_gpu,
     const string& output,
     const string& output_folder,
-    const bool text_output) {
+    const bool text_output,
+    const int index,
+    const int num_blobs) {
+  if (output.size() == 0) {
+    return;
+  }
   string output_prefix = output_folder.size() ? output_folder + "/" : "";
-  if (output.size()) {
-    vector<string> output_names = caffe2::split(',', output);
-    if (output == "*") {
-      output_names = workspace->Blobs();
-    }
-    for (const string& name : output_names) {
-      CAFFE_ENFORCE(
-          workspace->HasBlob(name),
-          "You requested a non-existing blob: ",
-          name);
-      if (text_output) {
-        if (run_on_gpu) {
+  vector<string> output_names = caffe2::split(',', output);
+  if (output == "*") {
+    output_names = workspace->Blobs();
+  }
+  for (const string& name : output_names) {
+    CAFFE_ENFORCE(
+        workspace->HasBlob(name),
+        "You requested a non-existing blob: ",
+        name);
+    if (text_output) {
+      if (run_on_gpu) {
 #ifdef __CUDA_ARCH__
-          writeTextOutput<caffe2::CUDAContext, caffe2::TensorCUDA>(
-              workspace->GetBlob(name)->GetMutable<caffe2::TensorCUDA>(),
-              output_prefix,
-              name);
+        writeTextOutput<caffe2::CUDAContext, caffe2::TensorCUDA>(
+            workspace->GetBlob(name)->GetMutable<caffe2::TensorCUDA>(),
+            output_prefix,
+            name,
+            index,
+            num_blobs);
 #else
-          CAFFE_THROW("Not support GPU.");
+        CAFFE_THROW("Not support GPU.");
 #endif
-        } else {
-          writeTextOutput<caffe2::CPUContext, caffe2::TensorCPU>(
-              BlobGetMutableTensor(workspace->GetBlob(name), caffe2::CPU),
-              output_prefix,
-              name);
-        }
       } else {
-        string serialized = SerializeBlob(*workspace->GetBlob(name), name);
-        string output_filename = output_prefix + name;
-        caffe2::WriteStringToFile(serialized, output_filename.c_str());
+        writeTextOutput<caffe2::CPUContext, caffe2::TensorCPU>(
+            BlobGetMutableTensor(workspace->GetBlob(name), caffe2::CPU),
+            output_prefix,
+            name,
+            index,
+            num_blobs);
       }
+    } else {
+      // Do not support multiple entries per blob.
+      CAFFE_ENFORCE(
+          index == 0,
+          "Binary file only support one output.");
+      string serialized = SerializeBlob(*workspace->GetBlob(name), name);
+      string output_filename = output_prefix + name;
+      caffe2::WriteStringToFile(serialized, output_filename.c_str());
     }
   }
 }
@@ -393,7 +431,7 @@ int benchmark(
 
   map<string, caffe2::TensorProtos> tensor_protos_map;
 
-  loadInput(
+  int num_blobs = loadInput(
       workspace,
       run_on_gpu,
       tensor_protos_map,
@@ -408,18 +446,16 @@ int benchmark(
       tensor_protos_map,
       FLAGS_wipe_cache,
       FLAGS_run_individual,
+      run_on_gpu,
+      FLAGS_text_output,
       FLAGS_warmup,
       FLAGS_iter,
+      num_blobs,
       FLAGS_sleep_before_run,
       FLAGS_sleep_between_iteration,
-      FLAGS_sleep_between_net_and_operator);
-
-  writeOutput(
-      workspace,
-      run_on_gpu,
+      FLAGS_sleep_between_net_and_operator,
       FLAGS_output,
-      FLAGS_output_folder,
-      FLAGS_text_output);
+      FLAGS_output_folder);
 
   return 0;
 }
index 6f12878..0301440 100644 (file)
@@ -34,7 +34,12 @@ template <typename ContextType, typename TensorType>
 void writeTextOutput(
     TensorType* tensor,
     const string& output_prefix,
-    const string& name) {
+    const string& name,
+    int index,
+    int num_blobs) {
+  if (index >= num_blobs) {
+    return;
+  }
   string filename = name;
   std::replace(filename.begin(), filename.end(), '/', '_');
   string output_name = output_prefix + "/" + filename + ".txt";
@@ -80,7 +85,13 @@ void writeTextOutput(
   str.pop_back();
   lines.push_back(str);
 
-  std::ofstream output_file(output_name);
+  auto flags = std::ios::out;
+  if (index != 0) {
+    flags |= std::ios::app;
+  } else {
+    flags |= std::ios::trunc;
+  }
+  std::ofstream output_file(output_name, flags);
   std::ostream_iterator<std::string> output_iterator(output_file, "\n");
   std::copy(lines.begin(), lines.end(), output_iterator);
 }
@@ -89,35 +100,42 @@ void observerConfig();
 bool backendCudaSet(const string&);
 void setDeviceType(caffe2::NetDef*, caffe2::DeviceType&);
 void setOperatorEngine(caffe2::NetDef*, const string&);
-void loadInput(
-    shared_ptr<caffe2::Workspace>,
-    const bool,
-    map<string, caffe2::TensorProtos>&,
-    const string&,
-    const string&,
-    const string&,
-    const string&);
+int loadInput(
+    shared_ptr<caffe2::Workspace> workspace,
+    const bool run_on_gpu,
+    map<string, caffe2::TensorProtos>& tensor_protos_map,
+    const string& input,
+    const string& input_file,
+    const string& input_dims,
+    const string& input_type);
 void fillInputBlob(
-    shared_ptr<caffe2::Workspace>,
-    map<string, caffe2::TensorProtos>&,
+    shared_ptr<caffe2::Workspace> workspace,
+    map<string, caffe2::TensorProtos>& tensor_protos_map,
     int iteration);
 void writeOutput(
-    shared_ptr<caffe2::Workspace>,
-    const bool,
-    const string&,
-    const string&,
-    const bool);
+    shared_ptr<caffe2::Workspace> workspace,
+    const bool run_on_gpu,
+    const string& output,
+    const string& output_folder,
+    const bool text_output,
+    const int index,
+    const int num_blobs);
 void runNetwork(
-    shared_ptr<caffe2::Workspace>,
-    caffe2::NetDef&,
-    map<string, caffe2::TensorProtos>&,
-    const bool,
-    const bool,
-    const int,
-    const int,
-    const int,
-    const int,
-    const int);
+    shared_ptr<caffe2::Workspace> workspace,
+    caffe2::NetDef& net_def,
+    map<string, caffe2::TensorProtos>& tensor_protos_map,
+    const bool wipe_cache,
+    const bool run_individual,
+    const bool run_on_gpu,
+    const bool text_output,
+    const int warmup,
+    const int iter,
+    const int num_blobs,
+    const int sleep_before_run,
+    const int sleep_between_iteration,
+    const int sleep_between_net_and_operator,
+    const std::string& output,
+    const std::string& output_folder);
 int benchmark(
     int argc,
     char* argv[],
index 785102f..26397a1 100755 (executable)
@@ -285,36 +285,36 @@ int getBatchSize(int num_items) {
 }
 
 void writeValues(
-    std::vector<std::vector<float>>& values,
-    std::vector<int>& dims,
+    std::vector<std::vector<std::vector<float>>>& values,
+    std::vector<std::vector<int>>& dims,
     std::string output_file) {
 
   caffe2::Timer timer;
   timer.Start();
 
-  int batch_size = getBatchSize(values.size());
-  int num_batches = values.size() / batch_size;
-  assert(dims[0] == batch_size);
+  assert(dims.size() == values.size());
+  int num_batches = dims.size();
 
   TensorProtos protos;
   for (int k = 0; k < num_batches; k++) {
     TensorProto* data;
     data = protos.add_protos();
     data->set_data_type(TensorProto::FLOAT);
-    for (int dim : dims) {
+    auto one_dim = dims[k];
+    for (int dim : one_dim) {
       data->add_dims(dim);
     }
+    int batch_size = one_dim[0];
     long long int entry_size = 1;
-    for (int i = 1; i < dims.size(); i++) {
-      entry_size *= dims[i];
+    for (int i = 1; i < one_dim.size(); i++) {
+      entry_size *= one_dim[i];
     }
 
     // Not optimized
     for (int i = 0; i < batch_size; i++) {
-      int idx = k * batch_size + i;
-      assert(values[idx].size() == entry_size);
-      for (int j = 0; j < values[idx].size(); j++) {
-        data->add_float_data(values[idx][j]);
+      assert(values[k][i].size() == entry_size);
+      for (int j = 0; j < values[k][i].size(); j++) {
+        data->add_float_data(values[k][i][j]);
       }
     }
   }
@@ -348,26 +348,34 @@ void convertImages() {
   } else {
     return;
   }
-  std::vector<std::vector<float>> values;
+  int batch_size = getBatchSize(file_names.size());
+  int num_batches = file_names.size() / batch_size;
+  assert(file_names.size() == batch_size * num_batches);
+  std::vector<std::vector<std::vector<float>>> values;
+  std::vector<std::vector<int>> dims;
   int C = FLAGS_color ? 3 : 1;
-  int height = -1;
-  int width = -1;
-  for (int i = 0; i < file_names.size(); i++) {
-    int one_height, one_width;
-    std::vector<float> one_image_values =
-        convertOneImage(file_names[i], &one_height, &one_width);
-    if (height < 0 && width < 0) {
-      height = one_height;
-      width = one_width;
-    } else {
-      assert(height == one_height);
-      assert(width == one_width);
+  for (int k = 0; k < num_batches; k++) {
+    std::vector<std::vector<float>> one_value;
+    int height = -1;
+    int width = -1;
+    for (int i = 0; i < batch_size; i++) {
+      int idx = k * batch_size + i;
+      int one_height, one_width;
+      std::vector<float> one_image_values =
+          convertOneImage(file_names[idx], &one_height, &one_width);
+      if (height < 0 && width < 0) {
+        height = one_height;
+        width = one_width;
+      } else {
+        assert(height == one_height);
+        assert(width == one_width);
+      }
+      one_value.push_back(one_image_values);
     }
-    values.push_back(one_image_values);
+    vector<int> one_dim = {batch_size, C, height, width};
+    dims.push_back(one_dim);
+    values.push_back(one_value);
   }
-
-  int batch_size = getBatchSize(values.size());
-  vector<int> dims = {batch_size, C, height, width};
   writeValues(values, dims, FLAGS_output_tensor);
 }
 
@@ -395,29 +403,39 @@ void convertValues() {
   std::ifstream infile(FLAGS_input_text_file);
   std::string line;
   std::getline(infile, line);
-  vector<int> dims = splitString <int>(line);
-  assert(dims.size() >= 2);
+  vector<int> file_dims = splitString <int>(line);
+  assert(file_dims.size() >= 2);
 
-  int num_items = dims[0];
+  int num_items = file_dims[0];
   int batch_size = getBatchSize(num_items);
+  int num_batches = num_items / batch_size;
+  assert(num_items == batch_size * num_batches);
   vector<string> lines;
   while (std::getline(infile, line)) {
     lines.push_back(line);
   }
   assert(lines.size() == num_items);
-  std::vector<std::vector<float>> values;
-  int num = -1;
-  for (std::string line : lines) {
-    vector<float> item = splitString<float>(line);
-    if (num < 0) {
-      num = item.size();
-    } else {
-      assert(num == item.size());
+  std::vector<std::vector<std::vector<float>>> values;
+  std::vector<std::vector<int>> dims;
+  for (int i = 0; i < num_batches; i++) {
+    std::vector<std::vector<float>> one_value;
+    int num = -1;
+    for (int j = 0; j < batch_size; j++) {
+      int idx = i * batch_size + j;
+      std::string line = lines[idx];
+      vector<float> item = splitString<float>(line);
+      if (num < 0) {
+        num = item.size();
+      } else {
+        assert(num == item.size());
+      }
+      one_value.push_back(item);
     }
-    values.push_back(item);
+    vector<int> batch_dims = file_dims;
+    batch_dims[0] = batch_size;
+    dims.push_back(batch_dims);
+    values.push_back(one_value);
   }
-  vector<int> batch_dims = dims;
-  batch_dims[0] = batch_size;
 
   writeValues(values, dims, FLAGS_output_text_tensor);
 }