};
/**
- * @brief Record the (tensor) shape of each loco node
- */
-struct ShapeContext
-{
- std::unordered_map<loco::Node *, ShapeDescription> _node_to_shape;
-};
-
-/**
* @breif Record the information of T/F Lite SubGraph and its mapping to loco
*/
struct SubGraphContext
};
// Prerequisites for tflite::Model object creation
-struct SerializedModelData final : public ShapeContext, public SubGraphContext
+struct SerializedModelData final : public SubGraphContext
{
SerializedModelData() = default;
SerializedModelData(const SerializedModelData &) = delete;
#include "OperationExporter.h"
#include "ExporterUtils.h"
+#include "TypeInference.h"
using namespace flatbuffers;
using namespace tflite;
// zero bias.
auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker());
assert(ker);
- int32_t bias_vec_size = gd._node_to_shape[ker]._dims[0]; // output kernel count
+ int32_t bias_vec_size = ShapeInference::get(ker)._dims[0]; // output kernel count
auto bias_vec_shape_offset = builder.CreateVector(std::vector<int32_t>{bias_vec_size});
size_t raw_bias_vec_size = bias_vec_size * sizeof(int32_t);
{
// Create and register output tensor shape
ShapeDescription shape_description = ShapeInference::get(node);
- gd._node_to_shape[node] = shape_description;
auto shape_offset = encodeShape(builder, shape_description);
// encode and register output tensor type
return node->annot<TypeAnnotation>()->type();
}
+namespace
+{
+
+/**
+ * @brief Record the (tensor) shape of each loco node
+ */
+struct ShapeContext
+{
+ std::unordered_map<loco::Node *, ShapeDescription> _node_to_shape;
+};
+
+} // namespace
+
int32_t decodeShapeDimension(const loco::Dimension &dim)
{
if (!dim.known())