bool fix_shape(moco::tf::TFConcatV2 *node)
{
- auto concat_data = node->annot<ConcatData>();
- if (concat_data == nullptr)
- {
- // shape inference is already done for TFConcatV2
- assert(node->annot<ShapeInferenceData>() != nullptr);
- return false;
- }
- assert(node->annot<ShapeInferenceData>() == nullptr);
+ (void)node;
- auto lhs = node->lhs();
- auto rhs = node->rhs();
- auto lhs_shapedata = lhs->annot<ShapeInferenceData>();
- auto rhs_shapedata = rhs->annot<ShapeInferenceData>();
- if (lhs_shapedata == nullptr || rhs_shapedata == nullptr)
- {
- // postpone as previous input node(s) hasn't been processed.
- // this will return false as there was nothing changed, but the input of this
- // node should be changed and from that this method should be called again.
- // if not, the network may have some problem and the final output may not have
- // the right shape value and we can identify with some validation at final stage.
- return false;
- }
-
- uint32_t lhs_rank = lhs_shapedata->rank();
- uint32_t rhs_rank = rhs_shapedata->rank();
- assert(lhs_rank == rhs_rank);
+ throw std::runtime_error("NYI fix_shape TFConcatV2");
- int32_t axis_tf = concat_data->axis();
- if (axis_tf < 0)
- {
- axis_tf = static_cast<int32_t>(lhs_rank) + axis_tf;
- }
- assert(0 <= axis_tf && axis_tf < static_cast<int32_t>(lhs_rank));
- // clear annotation ConcatData
- node->annot<ConcatData>(nullptr);
-
- uint32_t axis_loco = static_cast<uint32_t>(axis_tf);
- node->axis(axis_loco);
-
- // Set ShapeInferenceData for TensorConcat
- auto shape_data = stdex::make_unique<ShapeInferenceData>();
- shape_data->rank(lhs_rank);
- for (uint32_t index = 0; index < lhs_rank; ++index)
- {
- uint32_t lhs_dim = lhs_shapedata->dim(index).value();
- uint32_t rhs_dim = rhs_shapedata->dim(index).value();
- // "lhs_dim == rhs_dim" should hold when "index != axis_loco"
- // or doesn't care when "index == axis_loco"
- assert(index == axis_loco || lhs_dim == rhs_dim);
-
- uint32_t new_dim = (index == axis_loco) ? lhs_dim + rhs_dim : lhs_dim;
-
- if (lhs_shapedata->dim(index).known())
- shape_data->dim(index) = new_dim;
- else
- shape_data->dim(index).unset();
- }
- node->annot(std::move(shape_data));
-
- return true;
+ return false;
}
bool fix_shape(moco::tf::TFConst *node)