This commit enables tflite2circle to support SparsityParameters.
ONE-DCO-1.0-Signed-off-by: seongwoo mhs4670go@naver.com
// is_variable
bool is_variable = it->is_variable();
+ flatbuffers::Offset<circle::SparsityParameters> sparsity;
+ // sparsity
+ if (it->sparsity())
+ {
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> traversal_order;
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_map;
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<circle::DimensionMetadata>>>
+ dim_metadata;
+
+ // traversal_order
+ if (it->sparsity()->traversal_order())
+ {
+ auto traversal_order_vec = std::vector<int32_t>{
+ it->sparsity()->traversal_order()->begin(), it->sparsity()->traversal_order()->end()};
+ traversal_order = fb->CreateVector(traversal_order_vec);
+ }
+
+ // block_map
+ if (it->sparsity()->block_map())
+ {
+ auto block_map_vec = std::vector<int32_t>{it->sparsity()->block_map()->begin(),
+ it->sparsity()->block_map()->end()};
+ block_map = fb->CreateVector(block_map_vec);
+ }
+
+ // dim_metadata
+ std::vector<flatbuffers::Offset<circle::DimensionMetadata>> dim_metadata_vec;
+ auto tflite_dim_metadata = it->sparsity()->dim_metadata();
+ for (auto it : *tflite_dim_metadata)
+ {
+ // array_segments
+ auto tflite_array_segments_type = it->array_segments_type();
+ auto circle_array_segments =
+ get_circle_sparse_index_vector(*fb, it, tflite_array_segments_type);
+ auto circle_array_segments_type =
+ get_circle_sparse_index_vector_type(tflite_array_segments_type);
+
+ // array_indices
+ auto tflite_array_indices_type = it->array_indices_type();
+ auto circle_array_indices =
+ get_circle_sparse_index_vector(*fb, it, tflite_array_indices_type);
+ auto circle_array_indices_type =
+ get_circle_sparse_index_vector_type(tflite_array_indices_type);
+
+ auto circle_dim_metadata_builder = circle::DimensionMetadataBuilder{*fb};
+
+ circle_dim_metadata_builder.add_format(get_circle_dimension_type(it->format()));
+ circle_dim_metadata_builder.add_dense_size(it->dense_size());
+ circle_dim_metadata_builder.add_array_segments(circle_array_segments);
+ circle_dim_metadata_builder.add_array_segments_type(circle_array_segments_type);
+ circle_dim_metadata_builder.add_array_indices(circle_array_indices);
+ circle_dim_metadata_builder.add_array_indices_type(circle_array_indices_type);
+ auto dim_metadata = circle_dim_metadata_builder.Finish();
+ dim_metadata_vec.emplace_back(dim_metadata);
+ }
+ dim_metadata = fb->CreateVector(dim_metadata_vec);
+
+ sparsity = circle::CreateSparsityParameters(*fb, traversal_order, block_map, dim_metadata);
+ }
+
// shape signature
flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature;
if (it->shape_signature())
tensor_builder.add_name(name);
tensor_builder.add_quantization(quantization);
tensor_builder.add_is_variable(is_variable);
+ tensor_builder.add_sparsity(sparsity);
tensor_builder.add_shape_signature(shape_signature);
auto tensor = tensor_builder.Finish();
tensor_vec.emplace_back(tensor);
}
}
+circle::DimensionType get_circle_dimension_type(tflite::DimensionType tfl_dim_type)
+{
+ switch (tfl_dim_type)
+ {
+ case tflite::DimensionType_DENSE:
+ return circle::DimensionType_DENSE;
+ case tflite::DimensionType_SPARSE_CSR:
+ return circle::DimensionType_SPARSE_CSR;
+ default:
+ throw std::runtime_error("tflite2circle: wrong dimension type.");
+ }
+}
+
+flatbuffers::Offset<void>
+get_circle_sparse_index_vector(flatbuffers::FlatBufferBuilder &fb,
+ const tflite::DimensionMetadata *dm,
+ const tflite::SparseIndexVector &tfl_sparse_index_vector_type)
+{
+ switch (tfl_sparse_index_vector_type)
+ {
+ case tflite::SparseIndexVector_Int32Vector:
+ {
+ auto values_vec_int32 =
+ std::vector<int32_t>{dm->array_segments_as_Int32Vector()->values()->begin(),
+ dm->array_segments_as_Int32Vector()->values()->end()};
+ auto values_int32 = fb.CreateVector(values_vec_int32);
+ circle::Int32VectorBuilder int32_vector_builder{fb};
+ int32_vector_builder.add_values(values_int32);
+ return int32_vector_builder.Finish().Union();
+ }
+ case tflite::SparseIndexVector_Uint16Vector:
+ {
+ auto values_vec_uint16 =
+ std::vector<uint16_t>{dm->array_segments_as_Uint16Vector()->values()->begin(),
+ dm->array_segments_as_Uint16Vector()->values()->end()};
+ auto values_uint16 = fb.CreateVector(values_vec_uint16);
+ circle::Uint16VectorBuilder uint16_vector_builder{fb};
+ uint16_vector_builder.add_values(values_uint16);
+ return uint16_vector_builder.Finish().Union();
+ }
+ case tflite::SparseIndexVector_Uint8Vector:
+ {
+ auto values_vec_uint8 =
+ std::vector<uint8_t>{dm->array_segments_as_Uint8Vector()->values()->begin(),
+ dm->array_segments_as_Uint8Vector()->values()->end()};
+ auto values_uint8 = fb.CreateVector(values_vec_uint8);
+ circle::Uint8VectorBuilder uint8_vector_builder{fb};
+ uint8_vector_builder.add_values(values_uint8);
+ return uint8_vector_builder.Finish().Union();
+ }
+ default:
+ throw std::runtime_error("tflite2circle: wrong SparseIndexVector type.");
+ }
+}
+
+circle::SparseIndexVector
+get_circle_sparse_index_vector_type(const tflite::SparseIndexVector &tfl_sparse_index_vector_type)
+{
+ switch (tfl_sparse_index_vector_type)
+ {
+ case tflite::SparseIndexVector_Int32Vector:
+ return circle::SparseIndexVector_Int32Vector;
+ case tflite::SparseIndexVector_Uint16Vector:
+ return circle::SparseIndexVector_Uint16Vector;
+ case tflite::SparseIndexVector_Uint8Vector:
+ return circle::SparseIndexVector_Uint8Vector;
+ default:
+ throw std::runtime_error("tflite2circle: wrong SparseIndexVector type.");
+ }
+}
+
} // namespace tflite2circle
*/
circle::MirrorPadMode get_circle_mirrorpad_mode(tflite::MirrorPadMode tfl_mode);
+/**
+ * @brief Returns circle DimensionType according to tflite.
+*/
+circle::DimensionType get_circle_dimension_type(tflite::DimensionType tfl_dim_type);
+
+/**
+ * @brief Returns circle SparseIndexVector according to tflite.
+*/
+flatbuffers::Offset<void>
+get_circle_sparse_index_vector(flatbuffers::FlatBufferBuilder &fb,
+ const tflite::DimensionMetadata *dm,
+ const tflite::SparseIndexVector &tfl_sparse_index_vector_type);
+
+/**
+ * @brief Returns circle SparseIndexVector type according to tflite.
+*/
+circle::SparseIndexVector
+get_circle_sparse_index_vector_type(const tflite::SparseIndexVector &tfl_sparse_index_vector_type);
+
} // namespace tflite2circle
#endif // __DATA_LOOKUP_H__