collectUnsupportedOps();
}
+static const std::set<tflite::BuiltinOperator> supportedOperators = {
+ BuiltinOperator_ADD,
+ BuiltinOperator_AVERAGE_POOL_2D,
+ BuiltinOperator_CONCATENATION,
+ BuiltinOperator_CONV_2D,
+ BuiltinOperator_DEPTHWISE_CONV_2D,
+ BuiltinOperator_DIV,
+ BuiltinOperator_FULLY_CONNECTED,
+ BuiltinOperator_LEAKY_RELU,
+ BuiltinOperator_LOGISTIC,
+ BuiltinOperator_MAX_POOL_2D,
+ BuiltinOperator_MAXIMUM,
+ BuiltinOperator_MEAN,
+ BuiltinOperator_MUL,
+ BuiltinOperator_PAD,
+ BuiltinOperator_RELU,
+ BuiltinOperator_RELU6,
+ BuiltinOperator_RESHAPE,
+ BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
+ BuiltinOperator_SHAPE,
+ BuiltinOperator_SLICE,
+ BuiltinOperator_SOFTMAX,
+ BuiltinOperator_SQRT,
+ BuiltinOperator_SQUARED_DIFFERENCE,
+ BuiltinOperator_SQUEEZE,
+ BuiltinOperator_STRIDED_SLICE,
+ BuiltinOperator_SUB,
+ BuiltinOperator_TANH,
+ BuiltinOperator_TRANSPOSE,
+ BuiltinOperator_TRANSPOSE_CONV,
+};
+
void TfliteImporter::collectUnsupportedOps()
{
+ std::set<std::string> errors;
for (auto sub_graph : *(_modelPacked->subgraphs()))
for (auto op : *(sub_graph->operators()))
- processUnsupportedOp(op);
+ {
+ BuiltinOperator opcode = (*_opcodes)[op->opcode_index()]->builtin_code();
+ if (supportedOperators.find(opcode) == supportedOperators.end())
+ {
+ if (opcode <= BuiltinOperator_MAX)
+ errors.insert(std::string(EnumNameBuiltinOperator(opcode)) + ": unsupported operator");
+ else
+ errors.insert(std::to_string(opcode) + ": unsuppored in tflite custom opcode");
+ }
+ }
- if (!_problemsOpSet.empty())
+ if (!errors.empty())
{
std::string msg("NNC can't load model. Detected problems:");
- for (const auto &problemStr : _problemsOpSet)
- msg.append("\n * " + problemStr);
+ for (const auto &e : errors)
+ msg.append("\n * " + e);
throw std::runtime_error(msg);
}
}
-void TfliteImporter::processUnsupportedOp(const Operator *op)
-{
- BuiltinOperator opcode = (*_opcodes)[op->opcode_index()]->builtin_code();
- switch (opcode)
- {
- case BuiltinOperator_MAX_POOL_2D:
- case BuiltinOperator_AVERAGE_POOL_2D:
- _opCreator->checkPool2D(op->builtin_options_as<Pool2DOptions>(), _problemsOpSet);
- break;
- case BuiltinOperator_CONCATENATION:
- _opCreator->checkConcatenation(op->builtin_options_as<ConcatenationOptions>(),
- _problemsOpSet);
- break;
- case BuiltinOperator_CONV_2D:
- _opCreator->checkConv2D(op->builtin_options_as<Conv2DOptions>(), _problemsOpSet);
- break;
- case BuiltinOperator_DEPTHWISE_CONV_2D:
- _opCreator->checkDepthwiseConv2D(op->builtin_options_as<DepthwiseConv2DOptions>(),
- _problemsOpSet);
- break;
- case BuiltinOperator_FULLY_CONNECTED:
- _opCreator->checkFullyConnected(op->builtin_options_as<FullyConnectedOptions>(),
- _problemsOpSet);
- break;
- case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
- _opCreator->checkResizeNearestNeighbor(op->builtin_options_as<ResizeNearestNeighborOptions>(),
- _problemsOpSet);
- break;
- case BuiltinOperator_STRIDED_SLICE:
- _opCreator->checkStridedSlice(op->builtin_options_as<StridedSliceOptions>(), _problemsOpSet);
- break;
- case BuiltinOperator_SHAPE:
- _opCreator->checkShape(op->builtin_options_as<ShapeOptions>(), _problemsOpSet);
- break;
- case BuiltinOperator_SOFTMAX:
- case BuiltinOperator_SLICE:
- case BuiltinOperator_RESHAPE:
- case BuiltinOperator_SQUEEZE:
- case BuiltinOperator_LOGISTIC:
- case BuiltinOperator_SQRT:
- case BuiltinOperator_PAD:
- case BuiltinOperator_ADD:
- case BuiltinOperator_SUB:
- case BuiltinOperator_SQUARED_DIFFERENCE:
- case BuiltinOperator_MUL:
- case BuiltinOperator_MEAN:
- case BuiltinOperator_MAXIMUM:
- case BuiltinOperator_DIV:
- case BuiltinOperator_TRANSPOSE_CONV:
- case BuiltinOperator_TANH:
- case BuiltinOperator_RELU:
- case BuiltinOperator_RELU6:
- case BuiltinOperator_TRANSPOSE:
- case BuiltinOperator_LEAKY_RELU:
- // No checks
- break;
- default:
- if (opcode <= BuiltinOperator_MAX)
- {
- _problemsOpSet.insert(std::string(EnumNameBuiltinOperator(opcode)) +
- ": unsupported operator");
- }
- else
- {
- _problemsOpSet.insert(std::to_string(opcode) + ": unsuppored in tflite custom opcode");
- }
- }
-}
-
std::unique_ptr<Graph> TfliteImporter::createIR()
{
walkGraphAndCreateMIR();
// Maps TFLite tensors indices to corresponding MIR operation outputs.
std::map<int, mir::Operation::Output *> _tensorMap;
- // set of strings describing incorrect parts of network and parts of network unsupported by NNC
- std::set<std::string> _problemsOpSet;
-
void import();
std::unique_ptr<mir::Graph> createIR();
*/
void collectUnsupportedOps();
- void processUnsupportedOp(const ::tflite::Operator *op);
-
/**
* @brief Mark output MIR nodes
*/
return constant_op->getValue();
}
-void TFLiteOpCreator::checkConv2D(const Conv2DOptions *opts,
- std::set<std::string> &problems_ops_set)
-{
- checkActivationType(opts->fused_activation_function(), problems_ops_set);
-}
-
std::vector<mir::Operation::Output *>
TFLiteOpCreator::convertConv2D(const Conv2DOptions *opts,
const std::vector<mir::Operation::Output *> &inputs)
return {addFusedActivation(result, opts->fused_activation_function())};
}
-void TFLiteOpCreator::checkDepthwiseConv2D(const DepthwiseConv2DOptions *opts,
- std::set<std::string> &problems_ops_set)
-{
- checkActivationType(opts->fused_activation_function(), problems_ops_set);
-}
-
std::vector<mir::Operation::Output *>
TFLiteOpCreator::convertDepthwiseConv2D(const DepthwiseConv2DOptions *opts,
const std::vector<mir::Operation::Output *> &inputs)
return {addFusedActivation(result, opts->fused_activation_function())};
}
-void TFLiteOpCreator::checkConcatenation(const ConcatenationOptions *opts,
- std::set<std::string> &problems_ops_set)
-{
- checkActivationType(opts->fused_activation_function(), problems_ops_set);
-}
-
std::vector<mir::Operation::Output *>
TFLiteOpCreator::convertConcatenation(const ::tflite::ConcatenationOptions *opts,
const std::vector<mir::Operation::Output *> &inputs)
return {addFusedActivation(result->getOutput(0), opts->fused_activation_function())};
}
-void TFLiteOpCreator::checkPool2D(const Pool2DOptions *opts,
- std::set<std::string> &problems_ops_set)
-{
- checkActivationType(opts->fused_activation_function(), problems_ops_set);
-}
-
std::vector<mir::Operation::Output *>
TFLiteOpCreator::convertMaxPool2D(const ::tflite::Pool2DOptions *opts,
const std::vector<mir::Operation::Output *> &inputs)
return {result->getOutput(0)};
}
-void TFLiteOpCreator::checkResizeNearestNeighbor(const ::tflite::ResizeNearestNeighborOptions *opts,
- std::set<std::string> &problems_ops_set)
+std::vector<mir::Operation::Output *>
+TFLiteOpCreator::convertResizeNearestNeighbor(const ::tflite::ResizeNearestNeighborOptions *opts,
+ const std::vector<mir::Operation::Output *> &inputs)
{
if (opts->align_corners())
- problems_ops_set.insert("'align_corners' is not currently supported");
-}
+ throw std::runtime_error("'align_corners' is not currently supported");
-std::vector<mir::Operation::Output *> TFLiteOpCreator::convertResizeNearestNeighbor(
- const ::tflite::ResizeNearestNeighborOptions * /*opts*/,
- const std::vector<mir::Operation::Output *> &inputs)
-{
auto input = inputs.at(0);
mir::Tensor<int32_t> size_tensor(extractTensor(inputs.at(1)));
return {result->getOutput(0)};
}
-void TFLiteOpCreator::checkFullyConnected(const FullyConnectedOptions *opts,
- std::set<std::string> &problems_ops_set)
-{
- checkActivationType(opts->fused_activation_function(), problems_ops_set);
-}
-
std::vector<mir::Operation::Output *>
TFLiteOpCreator::convertFullyConnected(const ::tflite::FullyConnectedOptions *opts,
const std::vector<mir::Operation::Output *> &inputs)
return {addFusedActivation(result, opts->fused_activation_function())};
}
-void TFLiteOpCreator::checkActivationType(ActivationFunctionType activation_type,
- std::set<std::string> &problems_ops_set)
-{
- if (activation_type != ActivationFunctionType_NONE &&
- activation_type != ActivationFunctionType_RELU &&
- activation_type != ActivationFunctionType_RELU6 &&
- activation_type != ActivationFunctionType_TANH)
- problems_ops_set.insert(std::string("Unsupported activation type: ") +
- EnumNameActivationFunctionType(activation_type));
-}
-
mir::Operation::Output *TFLiteOpCreator::addFusedActivation(mir::Operation::Output *input,
ActivationFunctionType activation_type)
{
- // TODO Support other activation function types.
switch (activation_type)
{
case ActivationFunctionType_NONE:
case ActivationFunctionType_TANH:
return createOp<ops::TanhOp>(input)->getOutput(0);
default:
- assert(false && "Unsupported activation types must be detected before this pass");
+ throw std::runtime_error(std::string("Unsupported activation type: ") +
+ tflite::EnumNameActivationFunctionType(activation_type));
}
}
return {result->getOutput(0)};
}
-void TFLiteOpCreator::checkStridedSlice(const ::tflite::StridedSliceOptions *opts,
- std::set<std::string> &problems_ops_set)
+std::vector<mir::Operation::Output *>
+TFLiteOpCreator::convertStridedSlice(const ::tflite::StridedSliceOptions *opts,
+ const std::vector<mir::Operation::Output *> &inputs)
{
if (opts->ellipsis_mask() != 0)
- problems_ops_set.insert("StridedSlice: parameter 'ellipsis_mask' is not supported.");
+ throw std::runtime_error("StridedSlice: parameter 'ellipsis_mask' is not supported.");
if (opts->new_axis_mask() != 0)
- problems_ops_set.insert("StridedSlice: parameter 'new_axis_mask' is not supported.");
-}
+ throw std::runtime_error("StridedSlice: parameter 'new_axis_mask' is not supported.");
-std::vector<mir::Operation::Output *>
-TFLiteOpCreator::convertStridedSlice(const ::tflite::StridedSliceOptions *opts,
- const std::vector<mir::Operation::Output *> &inputs)
-{
auto input = inputs.at(0);
mir::Tensor<int32_t> begin_tensor(extractTensor(inputs.at(1)));
mir::Tensor<int32_t> end_tensor(extractTensor(inputs.at(2)));
return {result->getOutput(0)};
}
-void TFLiteOpCreator::checkShape(const ::tflite::ShapeOptions *opts,
- std::set<std::string> &problems_ops_set)
+std::vector<mir::Operation::Output *>
+TFLiteOpCreator::convertShape(const ::tflite::ShapeOptions *opts,
+ const std::vector<mir::Operation::Output *> &inputs)
{
if (opts->out_type() != TensorType_INT32)
{
- problems_ops_set.insert(std::string("SHAPE: Unsupported tensor type: ") +
- EnumNameTensorType(opts->out_type()));
+ throw std::runtime_error(std::string("SHAPE: Unsupported tensor type: ") +
+ EnumNameTensorType(opts->out_type()));
}
-}
-std::vector<mir::Operation::Output *>
-TFLiteOpCreator::convertShape(const ::tflite::ShapeOptions * /*opts*/,
- const std::vector<mir::Operation::Output *> &inputs)
-{
const auto &input_shape = inputs[0]->getShape();
int32_t rank = input_shape.rank();
Shape output_shape{rank};
convertShape(const ::tflite::ShapeOptions *opts,
const std::vector<mir::Operation::Output *> &inputs);
- void checkPool2D(const ::tflite::Pool2DOptions *opts, std::set<std::string> &problems_ops_set);
-
- void checkConcatenation(const ::tflite::ConcatenationOptions *opts,
- std::set<std::string> &problems_ops_set);
-
- void checkConv2D(const ::tflite::Conv2DOptions *opts, std::set<std::string> &problems_ops_set);
-
- void checkDepthwiseConv2D(const ::tflite::DepthwiseConv2DOptions *opts,
- std::set<std::string> &problems_ops_set);
-
- void checkFullyConnected(const ::tflite::FullyConnectedOptions *opts,
- std::set<std::string> &problems_ops_set);
-
- void checkResizeNearestNeighbor(const ::tflite::ResizeNearestNeighborOptions *opts,
- std::set<std::string> &problems_ops_set);
-
- void checkStridedSlice(const ::tflite::StridedSliceOptions *opts,
- std::set<std::string> &problems_ops_set);
-
- void checkShape(const ::tflite::ShapeOptions *opts, std::set<std::string> &problems_ops_set);
-
private:
Graph *_graph;
{::tflite::Padding_SAME, ops::PaddingType::Same},
{::tflite::Padding_VALID, ops::PaddingType::Valid}};
- void checkActivationType(::tflite::ActivationFunctionType activation_type,
- std::set<std::string> &problems_ops_set);
-
mir::Operation::Output *addFusedActivation(mir::Operation::Output *input,
::tflite::ActivationFunctionType activation_type);