#include "Dialect/TFNodes.h"
#include <loco.h>
+#include <loco/IR/NodeShape.h>
#include <moco/Log.h>
#include <stdex/Memory.h>
#include <plier/tf/Convert.h>
}
}
+std::unique_ptr<ShapeInferenceData> make_shape_inference_data(const loco::NodeShape &src)
+{
+ auto shape_data = stdex::make_unique<ShapeInferenceData>();
+
+ switch (src.domain())
+ {
+ case loco::Domain::Tensor:
+ shape_data->tensor_shape(src.as<loco::TensorShape>());
+ break;
+
+ case loco::Domain::Feature:
+ shape_data->feature_shape(src.as<loco::FeatureShape>());
+ break;
+
+ case loco::Domain::Filter:
+ shape_data->filter_shape(src.as<loco::FilterShape>());
+ break;
+
+ case loco::Domain::DepthwiseFilter:
+ shape_data->depthwisefilter_shape(src.as<loco::DepthwiseFilterShape>());
+ break;
+
+ case loco::Domain::Bias:
+ shape_data->bias_shape(src.as<loco::BiasShape>());
+ break;
+
+ default:
+ throw std::runtime_error("Unsupported Domain in make_shape_inference_data");
+ }
+
+ return std::move(shape_data);
+}
+
/**
* @brief Copy ShapeInferenceData from loco::Node pointer src to dst
*/