This will introduce ShapeInferenceData helper to set as TensorShape with FeatureShape and TFDataLayout
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
dim(0) = shape;
}
+void as_tensor_shape(ShapeInferenceData &shapedata, const loco::FeatureShape &feature_shape,
+ const TFDataLayout &data_layout)
+{
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(4);
+ if (data_layout == "NHWC")
+ {
+ tensor_shape.dim(0) = feature_shape.count();
+ tensor_shape.dim(1) = feature_shape.height();
+ tensor_shape.dim(2) = feature_shape.width();
+ tensor_shape.dim(3) = feature_shape.depth();
+ }
+ else if (data_layout == "NCHW")
+ {
+ tensor_shape.dim(0) = feature_shape.count();
+ tensor_shape.dim(1) = feature_shape.depth();
+ tensor_shape.dim(2) = feature_shape.height();
+ tensor_shape.dim(3) = feature_shape.width();
+ }
+ else
+ {
+ // TODO support for other data_layout if needed
+ throw std::runtime_error("as_tensor_shape: only supports NHWC or NCHW");
+ }
+
+ shapedata.tensor_shape(tensor_shape);
+}
+
loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata,
const TFDataLayout &data_layout)
{
loco::Domain _domain{loco::Domain::Tensor};
};
+void as_tensor_shape(ShapeInferenceData &shapedata, const loco::FeatureShape &shape,
+ const TFDataLayout &data_layout);
+
loco::FeatureShape as_feature_shape(const ShapeInferenceData &shapedata,
const TFDataLayout &data_layout);
ASSERT_EQ(bias_g, 3);
}
+TEST(TensorFlowImport, shapeinferencedata_as_tensor_set)
+{
+ loco::FeatureShape feature_s;
+
+ feature_s.count() = 1;
+ feature_s.height() = 2;
+ feature_s.width() = 3;
+ feature_s.depth() = 4;
+
+ moco::tf::ShapeInferenceData shapedata;
+
+ as_tensor_shape(shapedata, feature_s, "NHWC");
+
+ loco::TensorShape tensor_g;
+
+ tensor_g = shapedata.tensor_shape();
+
+ ASSERT_EQ(tensor_g.rank(), 4);
+ ASSERT_EQ(tensor_g.dim(0), 1);
+ ASSERT_EQ(tensor_g.dim(1), 2);
+ ASSERT_EQ(tensor_g.dim(2), 3);
+ ASSERT_EQ(tensor_g.dim(3), 4);
+}
+
TEST(TensorFlowImport, shapeinferencedata_as_feature)
{
loco::TensorShape tensor_s;