return index2data[index];
}
+ /**
+ * @brief Get the Data object
+ *
+ * @return const std::vector<T>& underlying data
+ */
+ const std::vector<T> &getData() const { return index2data; }
+
private:
std::unordered_map<T, unsigned int> data2index; /**< data -> index map */
std::vector<T> index2data; /**< index -> data map */
return std::get<BidirectionalIndexMap<T>>(maps);
}
+ template <typename T> const BidirectionalIndexMap<T> &getIndexMap() const {
+ return std::get<BidirectionalIndexMap<T>>(maps);
+ }
+
private:
float empty_buffer[0]; /**< unintialized tensor points to this buffer */
flatbuffers::Offset<
flatbuffers::Vector<flatbuffers::Offset<tflite::OperatorCode>>>
buildOperatorCodes(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
- /** NYI! */
- return flatbuffers::Offset<
- flatbuffers::Vector<flatbuffers::Offset<tflite::OperatorCode>>>();
-};
+ const auto &op_codes = map.getIndexMap<tflite::BuiltinOperator>().getData();
+
+ std::vector<flatbuffers::Offset<tflite::OperatorCode>> fb_op_codes;
+ fb_op_codes.reserve(op_codes.size());
+
+ auto create_op_offset = [&fbb](const tflite::BuiltinOperator &op,
+ int32_t version = 1) {
+ tflite::OperatorCodeBuilder builder(fbb);
+ builder.add_deprecated_builtin_code(static_cast<int8_t>(op));
+ /// @todo find reason why version field is not shown
+ /// on json when version is 1 (other versions are fine)
+ builder.add_version(version);
+ builder.add_builtin_code(op);
+ return builder.Finish();
+ };
+
+ std::transform(op_codes.begin(), op_codes.end(),
+ std::back_inserter(fb_op_codes), create_op_offset);
+
+ return fbb.CreateVector(fb_op_codes);
+}
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::SubGraph>>>
buildSubGraph(const TfOpNodes &nodes, const TfOpIdxMap &map,
auto opNodes = buildOpNodes(representation);
TfOpIdxMap map(opNodes); /// build TfOpIdxMap
- auto UNUSED(opcodes) = buildOperatorCodes(map, fbb);
+ auto opcodes = buildOperatorCodes(map, fbb);
auto UNUSED(buffers) = buildBuffers(map, fbb);
auto UNUSED(subgraph) = buildSubGraph(opNodes, map, fbb);
auto desc = fbb.CreateString("This file is generated from NNTrainer");
tflite::ModelBuilder model_builder(fbb);
- WILL(model_builder.add_operator_codes(opcode_offset));
+ model_builder.add_operator_codes(opcodes);
WILL(model_builder.add_buffers(buffers));
WILL(model_builder.add_subgraphs(subgraph));
model_builder.add_version(3);