return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
}
+bool is_valid(const circle::OperatorCode *opcode)
+{
+ assert(opcode != nullptr);
+ circle::BuiltinOperator code = opcode->builtin_code();
+ return (circle::BuiltinOperator_MIN <= code && code <= circle::BuiltinOperator_MAX);
+}
+
bool is_custom(const circle::OperatorCodeT &opcode)
{
circle::BuiltinOperator code = opcode.builtin_code;
return (code == circle::BuiltinOperator_CUSTOM);
}
+bool is_custom(const circle::OperatorCode *opcode)
+{
+ assert(opcode != nullptr);
+ circle::BuiltinOperator code = opcode->builtin_code();
+ return (code == circle::BuiltinOperator_CUSTOM);
+}
+
std::string opcode_name(const circle::OperatorCodeT &opcode)
{
if (!is_valid(opcode))
return circle::EnumNameBuiltinOperator(code);
}
+std::string opcode_name(const circle::OperatorCode *opcode)
+{
+ assert(opcode != nullptr);
+
+ if (!is_valid(opcode))
+ {
+ std::ostringstream oss;
+ oss << "(invalid)";
+ return oss.str();
+ }
+
+ if (is_custom(opcode))
+ {
+ auto custom_code = opcode->custom_code()->str();
+ if (custom_code.empty())
+ return "(invalid custom)";
+
+ return custom_code;
+ }
+
+ circle::BuiltinOperator code = opcode->builtin_code();
+ return circle::EnumNameBuiltinOperator(code);
+}
+
const char *tensor_name(const circle::TensorT &tensor)
{
static const char *kEmptyTensorName = "(noname)";
return kEmptyTensorName;
}
+const char *tensor_name(const circle::Tensor *tensor)
+{
+ assert(tensor != nullptr);
+
+ static const char *kEmptyTensorName = "(noname)";
+ const auto tensor_name = tensor->name()->c_str();
+
+ if (!std::string(tensor_name).empty())
+ return tensor_name;
+
+ return kEmptyTensorName;
+}
+
const circle::QuantizationParametersT *tensor_quantization(const circle::TensorT &tensor)
{
return tensor.quantization.get();
}
+const circle::QuantizationParameters *tensor_quantization(const circle::Tensor *tensor)
+{
+ assert(tensor != nullptr);
+ return tensor->quantization();
+}
+
loco::DataType luci_datatype(const circle::TensorType type)
{
switch (type)
return nullptr;
}
+std::unique_ptr<CircleQuantParam> luci_quantparam(const circle::QuantizationParameters *qparams)
+{
+ // create temporary unpacked API object
+ assert(qparams != nullptr);
+ circle::QuantizationParametersT quantization;
+ qparams->UnPackTo(&quantization);
+
+ return luci_quantparam(&quantization);
+}
+
std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParametersT *sparsity)
{
assert(sparsity);
return sparsityparam;
}
+std::unique_ptr<SparsityParam> luci_sparsityparam(const circle::SparsityParameters *sparparam)
+{
+ // create temporary unpacked API object
+ assert(sparparam != nullptr);
+ circle::SparsityParametersT sparsity;
+ sparparam->UnPackTo(&sparsity);
+
+ return luci_sparsityparam(&sparsity);
+}
+
void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
{
node->name(tensor_name(tensor));
}
}
+void copy_tensor_attributes(const circle::Tensor *tensor, CircleNode *node)
+{
+ assert(tensor != nullptr);
+
+ node->name(tensor_name(tensor));
+ node->dtype(luci_datatype(tensor->type()));
+
+ const auto tensor_shape_signature = wrap(tensor->shape_signature());
+ const auto tensor_shape = wrap(tensor->shape());
+ assert(tensor_shape_signature.size() == 0 ||
+ tensor_shape_signature.size() == tensor_shape.size());
+
+ const auto dims = tensor_shape; // in NHWC
+ node->rank(dims.size());
+ for (uint32_t r = 0; r < dims.size(); ++r)
+ {
+ if (tensor_shape_signature.size() > 0 && tensor_shape_signature.at(r) == -1)
+ node->dim(r).unset();
+ else
+ node->dim(r).set(dims[r]);
+ }
+
+ const auto quantization = tensor->quantization();
+ if (quantization != nullptr)
+ {
+ auto quantparam = luci_quantparam(quantization);
+ if (quantparam)
+ node->quantparam(std::move(quantparam));
+ }
+
+ const auto sparsity = tensor->sparsity();
+ if (sparsity != nullptr)
+ {
+ auto sparsityparam = luci_sparsityparam(sparsity);
+ if (sparsityparam)
+ node->sparsityparam(std::move(sparsityparam));
+ }
+}
+
circle::BuiltinOperator CircleReader::builtin_code(const circle::OperatorT &op) const
{
const auto &op_codes = opcodes();
_model.reset(model->UnPack());
// for direct pointer access
- _model_ptr = model;
+ _native_model = model;
return true;
}
_current_subgraph = _model->subgraphs[sgindex].get();
// for direct pointer access
- auto subgraphs = _model_ptr->subgraphs();
- const circle::SubGraph *subgraph = (*subgraphs)[sgindex];
+ auto subgraphs = _native_model->subgraphs();
+ assert(subgraphs != nullptr);
+
+ _native_subgraph = subgraphs->Get(sgindex);
+ assert(_native_subgraph != nullptr);
- _tensors_ptr = subgraph->tensors();
+ _tensors_ptr = _native_subgraph->tensors();
return true;
}
+template <typename T>
+VectorWrapper<T>::VectorWrapper(const flatbuffers::Vector<T> *ptr) : _vector(ptr)
+{
+ // Do nothing
+}
+
+template <typename T> uint32_t VectorWrapper<T>::size() const
+{
+ return null() ? 0 : _vector->size();
+}
+
+template <typename T> const T *VectorWrapper<T>::data() const
+{
+ return null() ? nullptr : _vector->data();
+}
+
+template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::begin() const
+{
+ return null() ? iterator(nullptr, 0) : _vector->begin();
+}
+
+template <typename T> typename VectorWrapper<T>::iterator VectorWrapper<T>::end() const
+{
+ return null() ? begin() : _vector->end();
+}
+
+template <typename T> typename VectorWrapper<T>::value_type VectorWrapper<T>::at(uint32_t i) const
+{
+ if (i >= size())
+ {
+ // TODO find better error message
+ throw std::range_error("Access to prohibited vector element");
+ }
+
+ return _vector->Get(i);
+}
+
+template <typename T>
+typename VectorWrapper<T>::value_type VectorWrapper<T>::operator[](uint32_t i) const
+{
+ return at(i);
+}
+
+template <typename T> bool VectorWrapper<T>::null() const { return _vector == nullptr; }
+template <typename T> bool VectorWrapper<T>::empty() const { return size() == 0; }
+
+#define REGISTER_WRAPPER(T) template class VectorWrapper<T>
+REGISTER_WRAPPER(flatbuffers::Offset<circle::SubGraph>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Buffer>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Tensor>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Operator>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::OperatorCode>);
+REGISTER_WRAPPER(flatbuffers::Offset<circle::Metadata>);
+REGISTER_WRAPPER(int32_t);
+REGISTER_WRAPPER(uint8_t);
+#undef REGISTER_WRAPPER
+
} // namespace luci