#pragma once
#include "caffe2/core/logging.h"
+#include "caffe2/opt/shape_info.h"
#include "caffe2/proto/caffe2_pb.h"
#include <sstream>
#include <unordered_set>
namespace caffe2 {
-
-struct CAFFE2_API ShapeInfo {
- enum DimType : int8_t { UNKNOWN = 0, CONSTANT = 1, BATCH = 2, SEQ = 3 };
- ShapeInfo() {}
- ShapeInfo(DimType t, TensorShape&& s) : dim_type(t), shape(std::move(s)) {}
- ShapeInfo(DimType t, const TensorShape& s) : dim_type(t), shape(s) {}
-
- // type of the shape according its first dim
- DimType dim_type{DimType::UNKNOWN};
- TensorShape shape;
-};
-
-using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
-
// This struct stores the max bound size for batch in the general sense. We have
// the conventioal batch size and the look-up sequence, which is also batch in a
// sense.
const NetDef& net,
const std::unordered_map<std::string, ShapeInfo>& info);
- const std::unordered_map<std::string, ShapeInfo>& shape_info() const {
+ const ShapeInfoMap& shape_info() const {
return shape_info_;
}
// Populate shapes from workplace
const std::vector<std::string> ws_blobs = ws->Blobs();
for (const auto& s : ws_blobs) {
- auto shape = GetTensorShapeOfBlob(ws->GetBlob(s));
- if (!shape.unknown_shape()) {
- shape_map.emplace(
- std::piecewise_construct,
- std::forward_as_tuple(s),
- std::forward_as_tuple(ShapeInfo::DimType::CONSTANT, shape));
+ auto shape_info = getShapeInfoFromBlob(ws->GetBlob(s));
+ if (shape_info.dim_type != ShapeInfo::DimType::UNKNOWN) {
+ shape_map[s] = shape_info;
}
}
for (const auto& kv : *shape_hints_ordered) {
--- /dev/null
+#include "caffe2/opt/shape_info.h"
+
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+ShapeInfo getShapeInfoFromBlob(const Blob* blob) {
+ ShapeInfo shape_info;
+ shape_info.shape = GetTensorShapeOfBlob(blob);
+ shape_info.dim_type = shape_info.shape.unknown_shape()
+ ? ShapeInfo::DimType::UNKNOWN
+ : ShapeInfo::DimType::CONSTANT;
+ return shape_info;
+}
+
+bool operator==(const ShapeInfo& lhs, const ShapeInfo& rhs) {
+ return lhs.dim_type == rhs.dim_type &&
+ lhs.shape.SerializeAsString() == rhs.shape.SerializeAsString();
+}
+
+} // namespace caffe2
--- /dev/null
+#pragma once
+
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+struct CAFFE2_API ShapeInfo {
+ enum DimType : int8_t { UNKNOWN = 0, CONSTANT = 1, BATCH = 2, SEQ = 3 };
+ ShapeInfo() {}
+ ShapeInfo(DimType t, TensorShape&& s) : dim_type(t), shape(std::move(s)) {}
+ ShapeInfo(DimType t, const TensorShape& s) : dim_type(t), shape(s) {}
+
+ // type of the shape according its first dim
+ DimType dim_type{DimType::UNKNOWN};
+ TensorShape shape;
+};
+
+using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
+
+// Generates ShapeInfo from Blob.
+ShapeInfo getShapeInfoFromBlob(const Blob* blob);
+
+bool operator==(const ShapeInfo& lhs, const ShapeInfo& rhs);
+
+} // namespace caffe2