[tflite2circle] Support SparsityParameters (#3947)
authorseongwoo chae <mhs4670go@naver.com>
Tue, 25 Aug 2020 00:28:47 +0000 (09:28 +0900)
committerGitHub <noreply@github.com>
Tue, 25 Aug 2020 00:28:47 +0000 (09:28 +0900)
This commit enables tflite2circle to support SparsityParameters.

ONE-DCO-1.0-Signed-off-by: seongwoo mhs4670go@naver.com

compiler/tflite2circle/src/CircleModel.cpp
compiler/tflite2circle/src/DataLookup.cpp
compiler/tflite2circle/src/DataLookup.h

index a950f15..14c44cb 100644 (file)
@@ -119,6 +119,66 @@ Offset<SubGraphLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_fla
       // is_variable
       bool is_variable = it->is_variable();
 
+      flatbuffers::Offset<circle::SparsityParameters> sparsity;
+      // sparsity
+      if (it->sparsity())
+      {
+        flatbuffers::Offset<flatbuffers::Vector<int32_t>> traversal_order;
+        flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_map;
+        flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<circle::DimensionMetadata>>>
+            dim_metadata;
+
+        // traversal_order
+        if (it->sparsity()->traversal_order())
+        {
+          auto traversal_order_vec = std::vector<int32_t>{
+              it->sparsity()->traversal_order()->begin(), it->sparsity()->traversal_order()->end()};
+          traversal_order = fb->CreateVector(traversal_order_vec);
+        }
+
+        // block_map
+        if (it->sparsity()->block_map())
+        {
+          auto block_map_vec = std::vector<int32_t>{it->sparsity()->block_map()->begin(),
+                                                    it->sparsity()->block_map()->end()};
+          block_map = fb->CreateVector(block_map_vec);
+        }
+
+        // dim_metadata
+        std::vector<flatbuffers::Offset<circle::DimensionMetadata>> dim_metadata_vec;
+        auto tflite_dim_metadata = it->sparsity()->dim_metadata();
+        for (auto it : *tflite_dim_metadata)
+        {
+          // array_segments
+          auto tflite_array_segments_type = it->array_segments_type();
+          auto circle_array_segments =
+              get_circle_sparse_index_vector(*fb, it, tflite_array_segments_type);
+          auto circle_array_segments_type =
+              get_circle_sparse_index_vector_type(tflite_array_segments_type);
+
+          // array_indices
+          auto tflite_array_indices_type = it->array_indices_type();
+          auto circle_array_indices =
+              get_circle_sparse_index_vector(*fb, it, tflite_array_indices_type);
+          auto circle_array_indices_type =
+              get_circle_sparse_index_vector_type(tflite_array_indices_type);
+
+          auto circle_dim_metadata_builder = circle::DimensionMetadataBuilder{*fb};
+
+          circle_dim_metadata_builder.add_format(get_circle_dimension_type(it->format()));
+          circle_dim_metadata_builder.add_dense_size(it->dense_size());
+          circle_dim_metadata_builder.add_array_segments(circle_array_segments);
+          circle_dim_metadata_builder.add_array_segments_type(circle_array_segments_type);
+          circle_dim_metadata_builder.add_array_indices(circle_array_indices);
+          circle_dim_metadata_builder.add_array_indices_type(circle_array_indices_type);
+          auto dim_metadata = circle_dim_metadata_builder.Finish();
+          dim_metadata_vec.emplace_back(dim_metadata);
+        }
+        dim_metadata = fb->CreateVector(dim_metadata_vec);
+
+        sparsity = circle::CreateSparsityParameters(*fb, traversal_order, block_map, dim_metadata);
+      }
+
       // shape signature
       flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature;
       if (it->shape_signature())
@@ -135,6 +195,7 @@ Offset<SubGraphLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_fla
       tensor_builder.add_name(name);
       tensor_builder.add_quantization(quantization);
       tensor_builder.add_is_variable(is_variable);
+      tensor_builder.add_sparsity(sparsity);
       tensor_builder.add_shape_signature(shape_signature);
       auto tensor = tensor_builder.Finish();
       tensor_vec.emplace_back(tensor);
index b0d35d1..7122cdf 100644 (file)
@@ -123,4 +123,75 @@ circle::MirrorPadMode get_circle_mirrorpad_mode(tflite::MirrorPadMode tfl_mode)
   }
 }
 
+circle::DimensionType get_circle_dimension_type(tflite::DimensionType tfl_dim_type)
+{
+  switch (tfl_dim_type)
+  {
+    case tflite::DimensionType_DENSE:
+      return circle::DimensionType_DENSE;
+    case tflite::DimensionType_SPARSE_CSR:
+      return circle::DimensionType_SPARSE_CSR;
+    default:
+      throw std::runtime_error("tflite2circle: wrong dimension type.");
+  }
+}
+
+flatbuffers::Offset<void>
+get_circle_sparse_index_vector(flatbuffers::FlatBufferBuilder &fb,
+                               const tflite::DimensionMetadata *dm,
+                               const tflite::SparseIndexVector &tfl_sparse_index_vector_type)
+{
+  switch (tfl_sparse_index_vector_type)
+  {
+    case tflite::SparseIndexVector_Int32Vector:
+    {
+      auto values_vec_int32 =
+          std::vector<int32_t>{dm->array_segments_as_Int32Vector()->values()->begin(),
+                               dm->array_segments_as_Int32Vector()->values()->end()};
+      auto values_int32 = fb.CreateVector(values_vec_int32);
+      circle::Int32VectorBuilder int32_vector_builder{fb};
+      int32_vector_builder.add_values(values_int32);
+      return int32_vector_builder.Finish().Union();
+    }
+    case tflite::SparseIndexVector_Uint16Vector:
+    {
+      auto values_vec_uint16 =
+          std::vector<uint16_t>{dm->array_segments_as_Uint16Vector()->values()->begin(),
+                                dm->array_segments_as_Uint16Vector()->values()->end()};
+      auto values_uint16 = fb.CreateVector(values_vec_uint16);
+      circle::Uint16VectorBuilder uint16_vector_builder{fb};
+      uint16_vector_builder.add_values(values_uint16);
+      return uint16_vector_builder.Finish().Union();
+    }
+    case tflite::SparseIndexVector_Uint8Vector:
+    {
+      auto values_vec_uint8 =
+          std::vector<uint8_t>{dm->array_segments_as_Uint8Vector()->values()->begin(),
+                               dm->array_segments_as_Uint8Vector()->values()->end()};
+      auto values_uint8 = fb.CreateVector(values_vec_uint8);
+      circle::Uint8VectorBuilder uint8_vector_builder{fb};
+      uint8_vector_builder.add_values(values_uint8);
+      return uint8_vector_builder.Finish().Union();
+    }
+    default:
+      throw std::runtime_error("tflite2circle: wrong SparseIndexVector type.");
+  }
+}
+
+circle::SparseIndexVector
+get_circle_sparse_index_vector_type(const tflite::SparseIndexVector &tfl_sparse_index_vector_type)
+{
+  switch (tfl_sparse_index_vector_type)
+  {
+    case tflite::SparseIndexVector_Int32Vector:
+      return circle::SparseIndexVector_Int32Vector;
+    case tflite::SparseIndexVector_Uint16Vector:
+      return circle::SparseIndexVector_Uint16Vector;
+    case tflite::SparseIndexVector_Uint8Vector:
+      return circle::SparseIndexVector_Uint8Vector;
+    default:
+      throw std::runtime_error("tflite2circle: wrong SparseIndexVector type.");
+  }
+}
+
 } // namespace tflite2circle
index 7ea01b9..26ad746 100644 (file)
@@ -76,6 +76,25 @@ circle::BuiltinOptions get_circle_builtin_options_type(const tflite::Operator *o
 */
 circle::MirrorPadMode get_circle_mirrorpad_mode(tflite::MirrorPadMode tfl_mode);
 
+/**
+ * @brief Returns circle DimensionType according to tflite.
+*/
+circle::DimensionType get_circle_dimension_type(tflite::DimensionType tfl_dim_type);
+
+/**
+ * @brief Returns circle SparseIndexVector according to tflite.
+*/
+flatbuffers::Offset<void>
+get_circle_sparse_index_vector(flatbuffers::FlatBufferBuilder &fb,
+                               const tflite::DimensionMetadata *dm,
+                               const tflite::SparseIndexVector &tfl_sparse_index_vector_type);
+
+/**
+ * @brief Returns circle SparseIndexVector type according to tflite.
+*/
+circle::SparseIndexVector
+get_circle_sparse_index_vector_type(const tflite::SparseIndexVector &tfl_sparse_index_vector_type);
+
 } // namespace tflite2circle
 
 #endif // __DATA_LOOKUP_H__