#include "shape_helper.h"
#include "tflite_ir_visitor.h"
+#include "tflite_op_creator.h"
namespace nncc
{
// TODO: make this a smart pointer. Note that it requires changing the NNImporter interface,
// because currently it returns a void*.
graph = new Graph();
+ opCreator.reset(new OpCreator(graph));
}
void IrVisitor::visit(const Model *m)
void IrVisitor::visit(const Operator *op)
{
- throw std::runtime_error{"Not yet implemented"};
+ auto inputs = createOpInputs(op);
+ auto params = createOpParams(op);
+
+ std::vector<INode::Ref> outputs;
+
+ unsigned int opcode = (*opcodes)[op->opcode_index()]->builtin_code();
+ // TODO: support other NN operator types
+ switch (opcode)
+ {
+ case BuiltinOperator_CONV_2D:
+ outputs = opCreator->createConv2D(inputs, params, op->builtin_options_as<Conv2DOptions>());
+ break;
+ case BuiltinOperator_DEPTHWISE_CONV_2D:
+ outputs = opCreator->createDepthConv2D(inputs, params,
+ op->builtin_options_as<DepthwiseConv2DOptions>());
+ break;
+ case BuiltinOperator_MAX_POOL_2D:
+ outputs = opCreator->createMaxPool(inputs, params, op->builtin_options_as<Pool2DOptions>());
+ break;
+ case BuiltinOperator_AVERAGE_POOL_2D:
+ outputs = opCreator->createAvgPool(inputs, params, op->builtin_options_as<Pool2DOptions>());
+ break;
+ case BuiltinOperator_CONCATENATION:
+ outputs = opCreator->createConcat(inputs, params, op->builtin_options_as<ConcatenationOptions>());
+ break;
+ case BuiltinOperator_RESHAPE:
+ outputs = opCreator->createReshape(inputs, params, op->builtin_options_as<ReshapeOptions>());
+ break;
+ case BuiltinOperator_SOFTMAX:
+ outputs = opCreator->createSoftmax(inputs, params, op->builtin_options_as<SoftmaxOptions>());
+ break;
+ default:
+ throw PluginException(
+ std::string("Encountered unsupported TFLite operator: ") +
+ EnumNamesBuiltinOperator()[opcode]);
+ }
+
+ for (int i = 0; i < op->outputs()->size(); ++i)
+ {
+ opsForTensorsTheyOutput[(*(op->outputs()))[i]] = outputs[i];
+ }
}
void IrVisitor::visit(const Tensor *) {}