Update the output format for benchmark_helper. It outputs the dimensi… (#15108)
authorFei Sun <feisun@fb.com>
Wed, 12 Dec 2018 06:22:42 +0000 (22:22 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 12 Dec 2018 06:24:56 +0000 (22:24 -0800)
Summary:
…on first and all the values in the next line. This way, it can output arbitrary blob
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15108

Reviewed By: llyfacebook

Differential Revision: D13429346

Pulled By: sf-wind

fbshipit-source-id: 5e0bba2a46fbe8d997dfc3d55a698484552e3af8

binaries/benchmark_helper.h

index e9ed70f..55d7d07 100644 (file)
@@ -47,35 +47,37 @@ void writeTextOutput(
   CAFFE_ENFORCE(blob_proto.has_tensor());
   caffe2::TensorProto tensor_proto = blob_proto.tensor();
   int dims_size = tensor_proto.dims_size();
-  // For NCHW or NHWC, print one line per CHW/HWC.
-  // If the output is one dimension, it means N==1,
-  // print everything to one line.
-  int loop_count = dims_size > 1 ? tensor_proto.dims(0) : 1;
   long long elem_dim_size =
       dims_size > 1 ? tensor_proto.dims(1) : tensor_proto.dims(0);
   for (int i = 2; i < dims_size; i++) {
     elem_dim_size *= tensor_proto.dims(i);
   }
   std::vector<std::string> lines;
-  for (int i = 0; i < loop_count; i++) {
-    int start_idx = i * elem_dim_size;
-    std::stringstream line;
-    if (tensor_proto.data_type() == caffe2::TensorProto::FLOAT) {
-      auto start = tensor_proto.float_data().begin() + start_idx;
-      auto end = start + elem_dim_size;
-      copy(start, end, std::ostream_iterator<float>(line, ","));
-    } else if (tensor_proto.data_type() == caffe2::TensorProto::INT32) {
-      auto start = tensor_proto.int32_data().begin() + start_idx;
-      auto end = start + elem_dim_size;
-      copy(start, end, std::ostream_iterator<int>(line, ","));
-    } else {
-      CAFFE_THROW("Unimplemented Blob type.");
+  std::string dims;
+  for (int i = 0; i < dims_size; i++) {
+    int dim = tensor_proto.dims(i);
+    if (i > 0) {
+      dims += ", ";
     }
-    // remove the last ,
-    string str = line.str();
-    str.pop_back();
-    lines.push_back(str);
+    dims += std::to_string(dim);
   }
+  lines.push_back(dims);
+  std::stringstream line;
+  if (tensor_proto.data_type() == caffe2::TensorProto::FLOAT) {
+    auto start = tensor_proto.float_data().begin();
+    auto end = tensor_proto.float_data().end();
+    copy(start, end, std::ostream_iterator<float>(line, ","));
+  } else if (tensor_proto.data_type() == caffe2::TensorProto::INT32) {
+    auto start = tensor_proto.int32_data().begin();
+    auto end = tensor_proto.int32_data().end();
+    copy(start, end, std::ostream_iterator<int>(line, ","));
+  } else {
+    CAFFE_THROW("Unimplemented Blob type.");
+  }
+  // remove the last ,
+  string str = line.str();
+  str.pop_back();
+  lines.push_back(str);
 
   std::ofstream output_file(output_name);
   std::ostream_iterator<std::string> output_iterator(output_file, "\n");