namespace tf
{
+loco::TensorShape ShapeInferenceData::tensor_shape(void) const
+{
+ loco::TensorShape shape;
+
+ shape.rank(rank());
+ for (uint32_t r = 0; r < rank(); ++r)
+ {
+ if (dim(r).known())
+ shape.dim(r) = loco::make_dimension(dim(r).value());
+ else
+ shape.dim(r).unset();
+ }
+
+ return shape;
+}
+
loco::FeatureShape ShapeInferenceData::feature_shape(void) const
{
loco::FeatureShape shape;
dim(3) = shape.depth();
}
+void ShapeInferenceData::filter_shape(const loco::FilterShape &shape)
+{
+ rank(4);
+ dim(0) = shape.count();
+ dim(1) = shape.height();
+ dim(2) = shape.width();
+ dim(3) = shape.depth();
+}
+
} // namespace tf
} // namespace moco
/**
* @brief ShapeInferenceData provides shape inference data tracking from the start(input)
*
- * @note For Feature, NHWC is used for shape layout
+ * @note For Feature and Filter, NHWC is used for shape layout
*/
class ShapeInferenceData : public loco::NodeAnnotation,
public loco::NodeMixin<loco::NodeTrait::TensorShape>
~ShapeInferenceData(){};
public:
+ loco::TensorShape tensor_shape(void) const;
loco::FeatureShape feature_shape(void) const;
void feature_shape(const loco::FeatureShape &shape);
+ void filter_shape(const loco::FilterShape &shape);
};
} // namespace tf
#include <gtest/gtest.h>
+TEST(TensorFlowFrontend, shapeinferencedata_tensor_get)
+{
+ moco::tf::ShapeInferenceData shapedata;
+
+ shapedata.rank(4);
+ shapedata.dim(0) = 1;
+ shapedata.dim(1) = 2;
+ shapedata.dim(2) = 3;
+ shapedata.dim(3) = 4;
+
+ loco::TensorShape tensor = shapedata.tensor_shape();
+
+ ASSERT_EQ(tensor.rank(), 4);
+ ASSERT_EQ(tensor.dim(0), 1);
+ ASSERT_EQ(tensor.dim(1), 2);
+ ASSERT_EQ(tensor.dim(2), 3);
+ ASSERT_EQ(tensor.dim(3), 4);
+}
+
TEST(TensorFlowFrontend, shapeinferencedata_feature_set)
{
loco::FeatureShape feature;
ASSERT_EQ(feature.width(), 3);
ASSERT_EQ(feature.depth(), 4);
}
+
+TEST(TensorFlowFrontend, shapeinferencedata_filter_set)
+{
+ loco::FilterShape filter;
+
+ filter.count() = 1;
+ filter.height() = 2;
+ filter.width() = 3;
+ filter.depth() = 4;
+
+ moco::tf::ShapeInferenceData shapedata;
+
+ shapedata.filter_shape(filter);
+
+ ASSERT_EQ(shapedata.rank(), 4);
+ ASSERT_EQ(shapedata.dim(0), 1);
+ ASSERT_EQ(shapedata.dim(1), 2);
+ ASSERT_EQ(shapedata.dim(2), 3);
+ ASSERT_EQ(shapedata.dim(3), 4);
+}