#include "TFDialect.h"
#include "TFNode.h"
-#include "TFNodeVisitor.h"
+
+#include "Annotations/ShapeInferenceData.h"
#include <loco/IR/NodeShape.h>
#include <loco/Service/ShapeInference.h>
#include <cassert>
-// temporary headers that will be removed
-// TODO remove using ShapeInferenceData
-#include "Annotations/ShapeInferenceData.h"
-
-namespace
-{
-
-using namespace moco::tf;
-
-/**
- * @note "Forward" means that this algorithm computes the ouput shape from inputs shapes
- */
-class ForwardShapeInferenceAlgorithm final : public TFNodeVisitor<loco::NodeShape>
-{
-public:
- // TODO TFAdd
- // TODO TFAvgPool
- // TODO TFBiasAdd
- // TODO TFConcatV2
- // TODO TFConst
- // TODO TFConv2D
- // TODO TFDepthwiseConv2dNative
- // TODO TFFusedBatchNorm
- // TODO TFIdentity
- // TODO TFMaxPool
- // TODO TFMul
- // TODO TFRealDiv
- // TODO TFRelu
- // TODO TFRelu6
- // TODO TFReshape
- // TODO TFRsqrt
- // TODO TFShape
- // TODO TFSoftmax
- // TODO TFSqrt
- // TODO TFSquaredDifference
- // TODO TFSqueeze
- // TODO TFStopGradient
- // TODO TFSub
- // TODO TFTanh
-
- // temporary default fallback to use legacy ShapeInferenceData
- // TODO remove using ShapeInferenceData
- loco::NodeShape visit(const TFNode *node) final;
-};
-
-// TODO TFAdd
-
-// TODO TFAvgPool
-
-// TODO TFBiasAdd
-
-// TODO TFConcatV2
-
-// TODO TFConst
-
-// TODO TFConv2D
-
-// TODO TFDepthwiseConv2dNative
-
-// TODO TFFusedBatchNorm
-
-// TODO TFIdentity
-
-// TODO TFMaxPool
-
-// TODO TFMul
-
-// TODO TFRealDiv
-
-// TODO TFRelu
-
-// TODO TFRelu6
-
-// TODO TFReshape
-
-// TODO TFRsqrt
-
-// TODO TFShape
-
-// TODO TFSoftmax
-
-// TODO TFSquaredDifference
-
-// TODO TFSqueeze
-
-// TODO TFStopGradient
-
-// TODO TFSub
-
-// TODO TFTanh
-
-// temporary default fallback to use ShapeInferenceData
-// TODO remove using ShapeInferenceData
-loco::NodeShape ForwardShapeInferenceAlgorithm::visit(const TFNode *node)
-{
- auto shapedata = node->annot<ShapeInferenceData>();
- assert(shapedata != nullptr);
-
- loco::NodeShape nodeshape;
- assert(shapedata->domain() == loco::Domain::Tensor);
-
- nodeshape.set(shapedata->tensor_shape());
-
- return nodeshape;
-}
-
-} // namespace
-
namespace moco
{
namespace tf