#include "FixShapeTransform.h"
+#include "Annotations/ConcatData.h"
#include "Annotations/ShapeInferenceData.h"
#include <loco.h>
bool fix_shape(loco::TensorConcat *node)
{
- // TODO fill this
- return false;
+ const ConcatData *concat_data = node->annot<ConcatData>();
+ if (concat_data == nullptr)
+ {
+ // shape inference is already done for TensorConcat
+ assert(node->annot<ShapeInferenceData>() != nullptr);
+ return false;
+ }
+ assert(node->annot<ShapeInferenceData>() == nullptr);
+
+ loco::Node *lhs = node->lhs();
+ loco::Node *rhs = node->rhs();
+ const ShapeInferenceData *lhs_shapedata = lhs->annot<ShapeInferenceData>();
+ const ShapeInferenceData *rhs_shapedata = rhs->annot<ShapeInferenceData>();
+ if (lhs_shapedata == nullptr || rhs_shapedata == nullptr)
+ {
+ // postpone as previous input node(s) hasn't been processed.
+ // this will return false as there was nothing changed, but the input of this
+ // node should be changed and from that this method should be called again.
+ // if not, the network may have some problem and the final output may not have
+ // the right shape value and we can identify with some validation at final stage.
+ return false;
+ }
+
+ uint32_t lhs_rank = lhs_shapedata->rank();
+ uint32_t rhs_rank = rhs_shapedata->rank();
+ assert(lhs_rank == rhs_rank);
+
+ int32_t axis_tf = concat_data->axis();
+ if (axis_tf < 0)
+ {
+ axis_tf = static_cast<int32_t>(lhs_rank) + axis_tf;
+ }
+ assert(0 <= axis_tf && axis_tf < static_cast<int32_t>(lhs_rank));
+ // clear annotation ConcatData
+ node->annot<ConcatData>(nullptr);
+
+ uint32_t axis_loco = static_cast<uint32_t>(axis_tf);
+ node->axis(axis_loco);
+
+ // Set ShapeInferenceData for TensorConcat
+ auto shape_data = stdex::make_unique<ShapeInferenceData>();
+ shape_data->rank(lhs_rank);
+ for (uint32_t index = 0; index < lhs_rank; ++index)
+ {
+ uint32_t lhs_dim = lhs_shapedata->dim(index).value();
+ uint32_t rhs_dim = rhs_shapedata->dim(index).value();
+ // "lhs_dim == rhs_dim" should hold when "index != axis_loco"
+ // or doesn't care when "index == axis_loco"
+ assert(index == axis_loco || lhs_dim == rhs_dim);
+
+ uint32_t new_dim = (index == axis_loco) ? lhs_dim + rhs_dim : lhs_dim;
+
+ if (lhs_shapedata->dim(index).known())
+ shape_data->dim(index) = loco::make_dimension(new_dim);
+ else
+ shape_data->dim(index).unset();
+ }
+ node->annot(std::move(shape_data));
+
+ return true;
}
} // namespace