This will update shape inference to use loco service in Canonicalizers of Add,Mul,Realdiv and Sub nodes
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
#include "AddCanonicalizer.h"
-#include "Annotations/ShapeInferenceData.h"
-
#include "Dialect/TFDialect.h"
#include "Dialect/TFNodes.h"
#include "Dialect/TFNodeVisitor.h"
#include "Dialect/TFNodeImpl.h"
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
#include <stdex/Memory.h>
namespace
auto node_A = node->x();
auto node_B = node->y();
- auto x_shapedata = node_A->annot<moco::tf::ShapeInferenceData>();
- assert(x_shapedata != nullptr);
- auto y_shapedata = node_B->annot<moco::tf::ShapeInferenceData>();
- assert(y_shapedata != nullptr);
+ if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+ {
+ // Wait for shape inference
+ return false;
+ }
+
+ const auto &x_shape = loco::shape_get(node_A);
+ const auto &y_shape = loco::shape_get(node_B);
- if (!(*x_shapedata == *y_shapedata))
+ if (!(x_shape == y_shape))
{
// TODO support broadcast Add
return false;
#include "MulCanonicalizer.h"
-#include "Annotations/ShapeInferenceData.h"
-
#include "Dialect/TFDialect.h"
#include "Dialect/TFNodes.h"
#include "Dialect/TFNodeVisitor.h"
#include "Dialect/TFNodeImpl.h"
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
#include <stdex/Memory.h>
namespace
auto node_A = node->x();
auto node_B = node->y();
- auto x_shapedata = node_A->annot<moco::tf::ShapeInferenceData>();
- assert(x_shapedata != nullptr);
- auto y_shapedata = node_B->annot<moco::tf::ShapeInferenceData>();
- assert(y_shapedata != nullptr);
+ if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+ {
+ // Wait for shape inference
+ return false;
+ }
+
+ const auto &x_shape = loco::shape_get(node_A);
+ const auto &y_shape = loco::shape_get(node_B);
- if (!(*x_shapedata == *y_shapedata))
+ if (!(x_shape == y_shape))
{
// TODO support broadcast Mul
return false;
#include "RealDivCanonicalizer.h"
-#include "Annotations/ShapeInferenceData.h"
-
#include "Dialect/TFDialect.h"
#include "Dialect/TFNodes.h"
#include "Dialect/TFNodeVisitor.h"
#include "Dialect/TFNodeImpl.h"
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
#include <stdex/Memory.h>
namespace
auto node_A = node->x();
auto node_B = node->y();
- auto x_shapedata = node_A->annot<moco::tf::ShapeInferenceData>();
- assert(x_shapedata != nullptr);
- auto y_shapedata = node_B->annot<moco::tf::ShapeInferenceData>();
- assert(y_shapedata != nullptr);
+ if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+ {
+ // Wait for shape inference
+ return false;
+ }
+
+ const auto &x_shape = loco::shape_get(node_A);
+ const auto &y_shape = loco::shape_get(node_B);
- if (!(*x_shapedata == *y_shapedata))
+ if (!(x_shape == y_shape))
{
// TODO support broadcast Div
return false;
#include "SubCanonicalizer.h"
-#include "Annotations/ShapeInferenceData.h"
-
#include "Dialect/TFDialect.h"
#include "Dialect/TFNodes.h"
#include "Dialect/TFNodeVisitor.h"
#include "Dialect/TFNodeImpl.h"
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
#include <stdex/Memory.h>
namespace
auto node_A = node->x();
auto node_B = node->y();
- auto x_shapedata = node_A->annot<moco::tf::ShapeInferenceData>();
- assert(x_shapedata != nullptr);
- auto y_shapedata = node_B->annot<moco::tf::ShapeInferenceData>();
- assert(y_shapedata != nullptr);
+ if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+ {
+ // Wait for shape inference
+ return false;
+ }
+
+ const auto &x_shape = loco::shape_get(node_A);
+ const auto &y_shape = loco::shape_get(node_B);
- if (!(*x_shapedata == *y_shapedata))
+ if (!(x_shape == y_shape))
{
// TODO support broadcast Sub
return false;