Extract ShapeInfo and some util functions into a separate file. (#17025)
authorYing Zhang <yingz@fb.com>
Wed, 13 Feb 2019 00:37:50 +0000 (16:37 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 13 Feb 2019 01:06:29 +0000 (17:06 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17025

Extract ShapeInfo and some util functions into a separate file.

Reviewed By: yinghai

Differential Revision: D14017432

fbshipit-source-id: 201db46bce6d52d9355a1a86925aa6206d0336bf

caffe2/opt/bound_shape_inferencer.h
caffe2/opt/onnxifi_transformer.cc
caffe2/opt/shape_info.cc [new file with mode: 0644]
caffe2/opt/shape_info.h [new file with mode: 0644]

index 9699a3b..e42472b 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "caffe2/core/logging.h"
+#include "caffe2/opt/shape_info.h"
 #include "caffe2/proto/caffe2_pb.h"
 
 #include <sstream>
@@ -9,20 +10,6 @@
 #include <unordered_set>
 
 namespace caffe2 {
-
-struct CAFFE2_API ShapeInfo {
-  enum DimType : int8_t { UNKNOWN = 0, CONSTANT = 1, BATCH = 2, SEQ = 3 };
-  ShapeInfo() {}
-  ShapeInfo(DimType t, TensorShape&& s) : dim_type(t), shape(std::move(s)) {}
-  ShapeInfo(DimType t, const TensorShape& s) : dim_type(t), shape(s) {}
-
-  // type of the shape according its first dim
-  DimType dim_type{DimType::UNKNOWN};
-  TensorShape shape;
-};
-
-using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
-
 // This struct stores the max bound size for batch in the general sense. We have
 // the conventioal batch size and the look-up sequence, which is also batch in a
 // sense.
@@ -51,7 +38,7 @@ class CAFFE2_API BoundShapeInferencer {
       const NetDef& net,
       const std::unordered_map<std::string, ShapeInfo>& info);
 
-  const std::unordered_map<std::string, ShapeInfo>& shape_info() const {
+  const ShapeInfoMap& shape_info() const {
     return shape_info_;
   }
 
index a1e0896..344f58a 100644 (file)
@@ -97,12 +97,9 @@ ShapeInfoMap InferShapes(
     // Populate shapes from workplace
     const std::vector<std::string> ws_blobs = ws->Blobs();
     for (const auto& s : ws_blobs) {
-      auto shape = GetTensorShapeOfBlob(ws->GetBlob(s));
-      if (!shape.unknown_shape()) {
-        shape_map.emplace(
-            std::piecewise_construct,
-            std::forward_as_tuple(s),
-            std::forward_as_tuple(ShapeInfo::DimType::CONSTANT, shape));
+      auto shape_info = getShapeInfoFromBlob(ws->GetBlob(s));
+      if (shape_info.dim_type != ShapeInfo::DimType::UNKNOWN) {
+        shape_map[s] = shape_info;
       }
     }
     for (const auto& kv : *shape_hints_ordered) {
diff --git a/caffe2/opt/shape_info.cc b/caffe2/opt/shape_info.cc
new file mode 100644 (file)
index 0000000..2f09bf7
--- /dev/null
@@ -0,0 +1,21 @@
+#include "caffe2/opt/shape_info.h"
+
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+ShapeInfo getShapeInfoFromBlob(const Blob* blob) {
+  ShapeInfo shape_info;
+  shape_info.shape = GetTensorShapeOfBlob(blob);
+  shape_info.dim_type = shape_info.shape.unknown_shape()
+      ? ShapeInfo::DimType::UNKNOWN
+      : ShapeInfo::DimType::CONSTANT;
+  return shape_info;
+}
+
+bool operator==(const ShapeInfo& lhs, const ShapeInfo& rhs) {
+  return lhs.dim_type == rhs.dim_type &&
+      lhs.shape.SerializeAsString() == rhs.shape.SerializeAsString();
+}
+
+} // namespace caffe2
diff --git a/caffe2/opt/shape_info.h b/caffe2/opt/shape_info.h
new file mode 100644 (file)
index 0000000..24672a2
--- /dev/null
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+struct CAFFE2_API ShapeInfo {
+  enum DimType : int8_t { UNKNOWN = 0, CONSTANT = 1, BATCH = 2, SEQ = 3 };
+  ShapeInfo() {}
+  ShapeInfo(DimType t, TensorShape&& s) : dim_type(t), shape(std::move(s)) {}
+  ShapeInfo(DimType t, const TensorShape& s) : dim_type(t), shape(s) {}
+
+  // type of the shape according its first dim
+  DimType dim_type{DimType::UNKNOWN};
+  TensorShape shape;
+};
+
+using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
+
+// Generates ShapeInfo from Blob.
+ShapeInfo getShapeInfoFromBlob(const Blob* blob);
+
+bool operator==(const ShapeInfo& lhs, const ShapeInfo& rhs);
+
+} // namespace caffe2