return node_shape_with_check(node->x());
}
+ loco::NodeShape visit(const moco::tf::TFShape *node) final
+ {
+ auto input_shape = node_shape(node->input());
+ auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+ loco::TensorShape output_shape;
+
+ // Note that input shape becomes node(TFShape)'s value
+ output_shape.rank(1);
+ output_shape.dim(0) = input_tensor_shape.rank();
+
+ return loco::NodeShape(output_shape);
+ }
+
+ loco::NodeShape visit(const moco::tf::TFSoftmax *node) final
+ {
+ return node_shape_with_check(node->logits());
+ }
+
+ loco::NodeShape visit(const moco::tf::TFSqrt *node) final
+ {
+ return node_shape_with_check(node->x());
+ }
+
+ loco::NodeShape visit(const moco::tf::TFSquaredDifference *node) final
+ {
+ return binary_node_shape(node);
+ }
+
+ loco::NodeShape visit(const moco::tf::TFSqueeze *node) final
+ {
+ auto input_shape = node_shape(node->input());
+
+ // TODO Not sure Squeeze only get input as Tensor
+ // Note that tensor_shape() has assertion in it
+ auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+ auto squeeze_dims_vec = node->squeeze_dims();
+ std::set<int64_t> squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend());
+
+ loco::TensorShape output_shape;
+ uint32_t output_rank = 0;
+
+ if (squeeze_dims.empty())
+ {
+ // Remove all dimensions whose value is 1
+ for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
+ {
+ assert(input_tensor_shape.dim(axis).known());
+ auto dim = input_tensor_shape.dim(axis).value();
+ if (dim != 1)
+ {
+ assert(dim > 1);
+ output_shape.rank(++output_rank);
+ output_shape.dim(output_rank - 1) = dim;
+ }
+ }
+ }
+ else
+ {
+ uint32_t input_rank = input_tensor_shape.rank();
+
+ // Sanity check for 'squeeze_dims'
+ auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
+ if (!(squeeze_dims.size() < input_rank))
+ return false;
+ for (auto squeeze_dim : squeeze_dims)
+ {
+ if (!(squeeze_dim >= -(int64_t)input_rank))
+ return false;
+ if (!(squeeze_dim < (int64_t)input_rank))
+ return false;
+ }
+ return true;
+ };
+
+ if (!is_valid_squeeze_dims())
+ {
+ throw std::runtime_error("Fix shape for TFSqueeze: invalid squeeze dimension");
+ }
+
+ // Resolve negative squeeze dimension
+ std::set<int64_t> resolved_squeeze_dims;
+ for (auto squeeze_dim : squeeze_dims)
+ {
+ if (squeeze_dim < 0)
+ resolved_squeeze_dims.insert(squeeze_dim + (int64_t)input_rank);
+ else
+ resolved_squeeze_dims.insert(squeeze_dim);
+ }
+
+ // Remove squeeze dimensions only
+ for (uint32_t axis = 0; axis < input_rank; ++axis)
+ {
+ assert(input_tensor_shape.dim(axis).known());
+ auto dim = input_tensor_shape.dim(axis).value();
+ if (resolved_squeeze_dims.find((int64_t)axis) == resolved_squeeze_dims.cend())
+ {
+ // Not squeeze dim
+ output_shape.rank(++output_rank);
+ output_shape.dim(output_rank - 1) = dim;
+ }
+ else
+ {
+ // Is squeeze dim
+ assert(dim == 1);
+ // DO NOTHING
+ }
+ }
+ }
+
+ assert(output_shape.rank() > 0);
+
+ return loco::NodeShape(output_shape);
+ }
+
+ loco::NodeShape visit(const moco::tf::TFStopGradient *node) final
+ {
+ return node_shape_with_check(node->input());
+ }
+
+ loco::NodeShape visit(const moco::tf::TFSub *node) final { return binary_node_shape(node); }
+
+ loco::NodeShape visit(const moco::tf::TFTanh *node) final
+ {
+ return node_shape_with_check(node->x());
+ }
+
public:
loco::NodeShape visit(const moco::tf::TFNode *node) final
{
bool fix_shape(moco::tf::TFRsqrt *node) { return false; }
-bool fix_shape(moco::tf::TFShape *node)
-{
- if (shape_inference_done(node))
- return false;
-
- auto input = node->input();
- loco::NodeShape input_shape;
- if (!node_shape(input, input_shape))
- {
- // Input shape is required for TFShape shape inference
- return false;
- }
- loco::TensorShape input_tensor_shape = input_shape.as<loco::TensorShape>();
-
- loco::TensorShape node_shape;
-
- // Note that input shape becomes node(TFShape)'s value
- node_shape.rank(1);
- node_shape.dim(0) = input_tensor_shape.rank();
-
- auto shape_annot = stdex::make_unique<ShapeInferenceData>();
- shape_annot->tensor_shape(node_shape);
- node->annot(std::move(shape_annot));
-
- LOGGER(l);
- INFO(l) << "Fix TFShape shape = " << node_shape;
-
- return true;
-}
-
-bool fix_shape(moco::tf::TFSqrt *node)
-{
- // Output shape is same as the input x
- auto x = node->x();
- return copy_shapedata(x, node);
-}
-
-bool fix_shape(moco::tf::TFSoftmax *node)
-{
- // Output shape is same as the input x
- auto logits = node->logits();
- return copy_shapedata(logits, node);
-}
-
-bool fix_shape(moco::tf::TFSquaredDifference *node)
-{
- auto x = node->x();
- auto y = node->y();
- return copy_shapedata(x, y, node);
-}
-
-bool fix_shape(moco::tf::TFSqueeze *node)
-{
- if (shape_inference_done(node))
- return false;
-
- auto input = node->input();
- loco::NodeShape input_shape;
- if (!node_shape(input, input_shape))
- {
- // Input shape is required for TFSqueeze shape inference
- return false;
- }
-
- // TODO Not sure Squeeze only get input as Tensor
- // Note that tensor_shape() has assertion in it
- auto input_tensor_shape = input_shape.as<loco::TensorShape>();
-
- auto squeeze_dims_vec = node->squeeze_dims();
- std::set<int64_t> squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend());
-
- loco::TensorShape node_shape;
- uint32_t node_rank = 0;
-
- if (squeeze_dims.empty())
- {
- // Remove all dimensions whose value is 1
- for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
- {
- assert(input_tensor_shape.dim(axis).known());
- auto dim = input_tensor_shape.dim(axis).value();
- if (dim != 1)
- {
- assert(dim > 1);
- node_shape.rank(++node_rank);
- node_shape.dim(node_rank - 1) = dim;
- }
- }
- }
- else
- {
- uint32_t input_rank = input_tensor_shape.rank();
-
- // Sanity check for 'squeeze_dims'
- auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
- if (!(squeeze_dims.size() < input_rank))
- return false;
- for (auto squeeze_dim : squeeze_dims)
- {
- if (!(squeeze_dim >= -(int64_t)input_rank))
- return false;
- if (!(squeeze_dim < (int64_t)input_rank))
- return false;
- }
- return true;
- };
-
- if (!is_valid_squeeze_dims())
- {
- throw std::runtime_error("Fix shape for TFSqueeze: invalid squeeze dimension");
- }
-
- // Resolve negative squeeze dimension
- std::set<int64_t> resolved_squeeze_dims;
- for (auto squeeze_dim : squeeze_dims)
- {
- if (squeeze_dim < 0)
- resolved_squeeze_dims.insert(squeeze_dim + (int64_t)input_rank);
- else
- resolved_squeeze_dims.insert(squeeze_dim);
- }
-
- // Remove squeeze dimensions only
- for (uint32_t axis = 0; axis < input_rank; ++axis)
- {
- assert(input_tensor_shape.dim(axis).known());
- auto dim = input_tensor_shape.dim(axis).value();
- if (resolved_squeeze_dims.find((int64_t)axis) == resolved_squeeze_dims.cend())
- {
- // Not squeeze dim
- node_shape.rank(++node_rank);
- node_shape.dim(node_rank - 1) = dim;
- }
- else
- {
- // Is squeeze dim
- assert(dim == 1);
- // DO NOTHING
- }
- }
- }
-
- assert(node_shape.rank() > 0);
+bool fix_shape(moco::tf::TFShape *node) { return false; }
- auto shape_annot = stdex::make_unique<ShapeInferenceData>();
- shape_annot->tensor_shape(node_shape);
- node->annot(std::move(shape_annot));
+bool fix_shape(moco::tf::TFSqrt *node) { return false; }
- LOGGER(l);
- INFO(l) << "Fix TFSqueeze shape = " << node_shape;
-
- return true;
-}
+bool fix_shape(moco::tf::TFSoftmax *node) { return false; }
-bool fix_shape(moco::tf::TFStopGradient *node)
-{
- // Output shape is same as the input
- auto input = node->input();
- return copy_shapedata(input, node);
-}
+bool fix_shape(moco::tf::TFSquaredDifference *node) { return false; }
-bool fix_shape(moco::tf::TFSub *node)
-{
- auto x = node->x();
- auto y = node->y();
- loco::NodeShape x_shape;
- loco::NodeShape y_shape;
+bool fix_shape(moco::tf::TFSqueeze *node) { return false; }
- if (!node_shape(x, x_shape))
- return false;
- if (!node_shape(y, y_shape))
- return false;
+bool fix_shape(moco::tf::TFStopGradient *node) { return false; }
- // Output shape is same as the input
- return copy_shapedata(x, y, node);
-}
+bool fix_shape(moco::tf::TFSub *node) { return false; }
-bool fix_shape(moco::tf::TFTanh *node)
-{
- // Output shape is same as the input
- auto x = node->x();
- return copy_shapedata(x, node);
-}
+bool fix_shape(moco::tf::TFTanh *node) { return false; }
bool fix_shape(locoex::COpCall *node)
{