dim(0) = shape;
}
+loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata,
+ const TFDataLayout &data_layout)
+{
+ if (shapedata.domain() == loco::Domain::Feature)
+ return shapedata.feature_shape();
+
+ loco::FeatureShape feature_shape;
+
+ // only convert from tensor to feature
+ if (shapedata.domain() != loco::Domain::Tensor)
+ {
+ throw std::runtime_error("as_feature_shape: domain is not tensor");
+ }
+ if (shapedata.rank() != 4)
+ {
+ throw std::runtime_error("as_feature_shape: rank is not 4");
+ }
+
+ // TODO support for other data_layout if needed
+ if (data_layout != "NHWC" && data_layout != "NCHW")
+ {
+ throw std::runtime_error("as_feature_shape: only supports NHWC or NCHW");
+ }
+
+ if (data_layout == "NHWC")
+ {
+ feature_shape.count() = shapedata.dim(0);
+ feature_shape.height() = shapedata.dim(1);
+ feature_shape.width() = shapedata.dim(2);
+ feature_shape.depth() = shapedata.dim(3);
+ }
+ else
+ {
+ feature_shape.count() = shapedata.dim(0);
+ feature_shape.depth() = shapedata.dim(1);
+ feature_shape.height() = shapedata.dim(2);
+ feature_shape.width() = shapedata.dim(3);
+ }
+
+ return feature_shape;
+}
+
} // namespace tf
} // namespace moco
namespace tf
{
+/// @note Below alias may be introduced as separate class
+using TFDataLayout = std::string;
+
/**
* @brief ShapeInferenceData provides shape inference data tracking from the start(input)
*
loco::Domain _domain{loco::Domain::Tensor};
};
+loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata,
+ const TFDataLayout &data_layout);
+
} // namespace tf
} // namespace moco
ASSERT_EQ(bias_g, 3);
}
+
+TEST(TensorFlowImport, shapeinferencedata_as_feature)
+{
+ loco::TensorShape tensor_s;
+
+ tensor_s.rank(4);
+ tensor_s.dim(0) = 1;
+ tensor_s.dim(1) = 2;
+ tensor_s.dim(2) = 3;
+ tensor_s.dim(3) = 4;
+
+ moco::tf::ShapeInferenceData shapedata;
+
+ shapedata.tensor_shape(tensor_s);
+
+ loco::FeatureShape feature_g = as_feature_shape(shapedata, "NHWC");
+
+ ASSERT_EQ(feature_g.count(), 1);
+ ASSERT_EQ(feature_g.height(), 2);
+ ASSERT_EQ(feature_g.width(), 3);
+ ASSERT_EQ(feature_g.depth(), 4);
+}