*/
#include "ResolveConstantShape.h"
+#include "TFShapeInferenceHelper.h"
#include "IR/TFShape.h"
#include "IR/TFConst.h"
-#include "Annotations/ShapeInferenceData.h"
#include <loco.h>
*/
bool resolve_constant_shape(loco::Graph *graph, moco::tf::TFShape *shape_node)
{
- using moco::tf::ShapeInferenceData;
-
- auto input_shape = shape_node->input()->annot<ShapeInferenceData>();
+ auto input_shape = moco::tf::node_shape(shape_node->input());
// Check condition
- if (!input_shape)
+ if (input_shape.domain() == loco::Domain::Unknown)
{
// Cannot resolve without known input_shape
return false;
}
- auto shape_rank = input_shape->rank();
+
+ auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+ auto shape_rank = input_tensor_shape.rank();
for (uint32_t axis = 0; axis < shape_rank; ++axis)
{
- if (!input_shape->dim(axis).known())
+ if (!input_tensor_shape.dim(axis).known())
{
// Cannot resolve with unknown dimension
return false;
}
}
- auto input_tensor_shape = input_shape->tensor_shape();
-
// Make TFConst to replace TFShape
auto const_node = graph->nodes()->create<moco::tf::TFConst>();