#include <loco.h>
#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
#include <moco/Log.h>
#include <stdex/Memory.h>
#include <plier/tf/Convert.h>
return true;
}
+/**
+ * @note While in shape inference, Node maybe Canonical, TF dialect or other dialects
+ * This will provide common loco::NodeShape as shape information
+ */
+bool node_shape(const loco::Node *node, loco::NodeShape &nodeshape)
+{
+ if (loco::shape_known(node))
+ {
+ nodeshape = loco::shape_get(node);
+ return true;
+ }
+
+ auto shapedata = node->annot<ShapeInferenceData>();
+ if (shapedata == nullptr)
+ {
+ return false;
+ }
+
+ switch (shapedata->domain())
+ {
+ case loco::Domain::Tensor:
+ nodeshape.set(shapedata->tensor_shape());
+ break;
+
+ case loco::Domain::Feature:
+ nodeshape.set(shapedata->feature_shape());
+ break;
+
+ case loco::Domain::Filter:
+ nodeshape.set(shapedata->filter_shape());
+ break;
+
+ case loco::Domain::DepthwiseFilter:
+ nodeshape.set(shapedata->depthwisefilter_shape());
+ break;
+
+ case loco::Domain::Bias:
+ nodeshape.set(shapedata->bias_shape());
+ break;
+
+ default:
+ throw std::runtime_error("Unsupported Domain in node_shape()");
+ }
+ return true;
+}
+
struct FixPadContext
{
uint32_t input_height;