namespace
{
+
+class ShapeInferenceAlgorithm final : public moco::tf::TFNodeVisitor<loco::NodeShape>
+{
+public:
+ ShapeInferenceAlgorithm(const loco::ShapeInferenceRule::Context *ctx) : _ctx{ctx}
+ {
+ // DO NOTHING
+ }
+
+private:
+ const loco::ShapeInferenceRule::Context *_ctx;
+
+private:
+ bool shape_known(const loco::Node *node) const { return _ctx->known(node); }
+ loco::NodeShape node_shape(const loco::Node *node) const { return _ctx->get(node); }
+
+public:
+ loco::NodeShape visit(const moco::tf::TFNode *node) final
+ {
+ loco::NodeShape unknown;
+ return unknown;
+ }
+};
+
+} // namespace
+
+namespace
+{
namespace compat
{
assert(node->dialect() == TFDialect::get());
assert(dynamic_cast<const TFNode *>(node) != nullptr);
- sink->fail();
+ ShapeInferenceAlgorithm alg{ctx};
+ auto shape = dynamic_cast<const TFNode *>(node)->accept(&alg);
+
+ if (shape.domain() == loco::Domain::Unknown)
+ sink->fail();
+ else
+ sink->okay(shape);
}
} // namespace tf