#include "Dialect/TFNodeVisitor.h"
#include "Dialect/TFNodeImpl.h"
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
#include <stdex/Memory.h>
namespace
* A and B are drawn multiple times to simplify the diagram
*/
- auto sub_node = graph->nodes()->create<loco::EltwiseSub>();
- auto mul_node = graph->nodes()->create<loco::EltwiseMul>();
-
auto node_A = node->x();
auto node_B = node->y();
+ if (!loco::shape_known(node_A) || !loco::shape_known(node_B))
+ {
+ // Wait for shape inference
+ return false;
+ }
+
+ const auto &x_shape = loco::shape_get(node_A);
+ const auto &y_shape = loco::shape_get(node_B);
+
+ if (!(x_shape == y_shape))
+ {
+ // TODO support broadcast
+ return false;
+ }
+
+ auto sub_node = graph->nodes()->create<loco::EltwiseSub>();
+ auto mul_node = graph->nodes()->create<loco::EltwiseMul>();
+
// update connections
sub_node->lhs(node_A);
sub_node->rhs(node_B);