target_link_libraries(exo_tflite PUBLIC exo_tflite_fbs)
target_link_libraries(exo_tflite PUBLIC loco)
target_link_libraries(exo_tflite PRIVATE stdex)
+target_link_libraries(exo_tflite PRIVATE pepper_strcast)
# Let's apply nncc common compile options
#
# NOTE This will enable strict compilation (warnings as error).
#include "ExporterUtils.h"
+ShapeDescription to_shape_description(const loco::TensorShape &shape)
+{
+ ShapeDescription res;
+
+ res._rank_known = true;
+
+ res._dims.resize(shape.rank());
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ {
+ // All the dimensions SHOULD be known
+ assert(shape.dim(axis).known());
+ res._dims.at(axis) = shape.dim(axis).value();
+ }
+
+ return res;
+}
+
+ShapeDescription to_shape_description(const loco::NodeShape &shape)
+{
+ switch (shape.domain())
+ {
+ case loco::Domain::Tensor:
+ return to_shape_description(shape.as<loco::TensorShape>());
+ default:
+ break;
+ }
+
+ throw std::runtime_error{"Not implemented yet"};
+}
+
uint32_t SerializedModelData::registerBuiltinOpcode(tflite::BuiltinOperator builtin_code)
{
auto it = _operator_codes.find(OpCode{builtin_code});
#include "loco.h"
#include "loco/IR/PermutingCodec.h"
+#include "loco/IR/NodeShape.h"
#include <unordered_map>
bool _rank_known;
};
+ShapeDescription to_shape_description(const loco::TensorShape &shape);
+ShapeDescription to_shape_description(const loco::NodeShape &shape);
+
/**
* @breif Record the information of T/F Lite SubGraph and its mapping to loco
*/
#include <loco/IR/CanonicalNode.h>
#include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/Service/ShapeInference.h>
+#include <loco/Service/CanonicalShapeInferenceRule.h>
+#include <pepper/strcast.h>
#include <stdex/Memory.h>
#include <type_traits>
+#include <cstdlib>
+
namespace
{
+// This Knob is a temporary workaround for incermental migration
+//
+// TODO Remove this workaround!
+struct Knob
+{
+ Knob()
+ {
+ // Off by default
+ auto s = std::getenv("EXOTFLITE_USE_LOCO_SHAPE_INFERENCE");
+ enable_loco_shape_inferene_framework = pepper::safe_strcast<int>(s, 0 /* DEFAULT */) != 0;
+ }
+
+ bool enable_loco_shape_inferene_framework = false;
+};
+
+Knob knob;
+
template <typename T, typename If = typename std::enable_if<std::is_integral<T>::value, int>::type>
T ceil_div(T dividend, T divisor)
{
void ShapeInference::run(loco::Graph *g)
{
+ if (knob.enable_loco_shape_inferene_framework)
+ {
+ loco::CanonicalShapeInferenceRule rule;
+ loco::apply(&rule).to(g);
+ return;
+ }
+
ShapeContext shape_ctx;
ShapeGetter shape_getter{shape_ctx};
ShapeDescription ShapeInference::get(loco::Node *node)
{
+ if (knob.enable_loco_shape_inferene_framework)
+ {
+ assert(loco::shape_known(node));
+ return to_shape_description(loco::shape_get(node));
+ }
+
assert(node->annot<ShapeAnnotation>() != nullptr);
return node->annot<ShapeAnnotation>()->shape();
}