#include "TFLExporterImpl.h"
+#include "TypeInference.h"
#include "TensorExporter.h"
#include "OperationExporter.h"
#include "ExporterUtils.h"
void TFLExporter::Impl::exportGraph(loco::Graph *graph)
{
+ // Infer the type of each node
+ TypeInference::run(graph);
+ // TypeInference::get(node) now works
+
_builder.Clear();
SerializedModelData gd;
auto shape_offset = encodeShape(builder, shape_description);
// encode and register output tensor type
- auto tensor_type = getOpResultType(node, gd);
- gd._node_to_type[node] = tensor_type;
+ auto tensor_type = TypeInference::get(node);
+ // gd._node_to_type[node] = tensor_type;
// encode and register output tensor buffer
auto buffer = encodeOpBuffer(builder, node);
#include "schema_generated.h"
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+
+#include <stdex/Memory.h>
+
#include <type_traits>
namespace
return translateLocoTypeToTFLite(node->dtype());
}
+tflite::TensorType getOpResultType(loco::Push *node, TypeContext &)
+{
+ return TypeInference::get(node->from());
+}
+
tflite::TensorType getOpResultType(loco::ReLU *node, TypeContext &gd)
{
return gd._node_to_type[node->input()];
return value_type;
}
+namespace
+{
+
+struct TypeAnnotation : public loco::NodeAnnotation
+{
+public:
+ TypeAnnotation(const tflite::TensorType &type) : _type{type}
+ {
+ // DO NOTHING
+ }
+
+public:
+ const tflite::TensorType &type(void) const { return _type; }
+
+private:
+ tflite::TensorType _type;
+};
+
+class TypeAnnotator final : public loco::CanonicalNodeMutableVisitor<void>
+{
+public:
+ TypeAnnotator() = default;
+
+public:
+#define NODE(NAME) \
+ void visit(loco::NAME *node) final \
+ { \
+ auto t = getOpResultType(node, _ctx); \
+ node->annot(stdex::make_unique<TypeAnnotation>(t)); \
+ }
+ NODE(ConstGen)
+ NODE(Pull)
+ NODE(Push)
+ NODE(FeatureEncode)
+ NODE(FeatureDecode)
+ NODE(FilterEncode)
+ NODE(MaxPool2D)
+ NODE(AvgPool2D)
+ NODE(Conv2D)
+ NODE(ReLU)
+ NODE(TensorConcat)
+ NODE(BiasEncode)
+ NODE(TensorBiasAdd)
+#undef NODE
+
+private:
+ // TODO Remove this variable
+ TypeContext _ctx;
+};
+
+} // namespace
+
+void TypeInference::run(loco::Graph *g)
+{
+ TypeAnnotator type_annotator;
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
+ {
+ canonical_node->accept(&type_annotator);
+ }
+ }
+}
+
+tflite::TensorType TypeInference::get(loco::Node *node)
+{
+ assert(node->annot<TypeAnnotation>() != nullptr);
+ return node->annot<TypeAnnotation>()->type();
+}
+
int32_t decodeShapeDimension(const loco::Dimension &dim)
{
if (!dim.known())
#include <loco/IR/Nodes.h>
// Tensor type inference functions
-
+#if 0
tflite::TensorType getOpResultType(loco::ConstGen *node, TypeContext &);
tflite::TensorType getOpResultType(loco::Pull *node, TypeContext &);
tflite::TensorType getOpResultType(loco::BiasEncode *node, TypeContext &gd);
tflite::TensorType getOpResultType(loco::BiasAdd<loco::Domain::Tensor> *node, TypeContext &gd);
+#endif
+
+/**
+ * @brief Annotate the type of each node as NodeAnnotation
+ *
+ * HOW TO USE
+ *
+ * TypeInference::run(g);
+ *
+ * TypeInference::get(g->nodes()->at(0));
+ * TypeInference::get(g->nodes()->at(...));
+ */
+struct TypeInference
+{
+ static void run(loco::Graph *g);
+
+ static tflite::TensorType get(loco::Node *node);
+};
// Shape inference functions