Op version: Populate version in Toco TFLite exporter
author     Yu-Cheng Ling <ycling@google.com>
           Fri, 18 May 2018 17:39:45 +0000 (10:39 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Fri, 18 May 2018 17:42:12 +0000 (10:42 -0700)
PiperOrigin-RevId: 197166962

tensorflow/contrib/lite/toco/tflite/export.cc
tensorflow/contrib/lite/toco/tflite/export.h
tensorflow/contrib/lite/toco/tflite/export_test.cc
tensorflow/contrib/lite/toco/tflite/operator.cc
tensorflow/contrib/lite/toco/tflite/operator.h
tensorflow/contrib/lite/toco/tflite/simple_operator.h

tensorflow/contrib/lite/toco/tflite/export.cc
index a4c0b2d..5daa703 100644
@@ -45,14 +45,20 @@ using ::tflite::Tensor;
 
 namespace {
 
-details::OperatorKey GetOperatorKey(const ::toco::Operator& op) {
+details::OperatorKey GetOperatorKey(
+    const ::toco::Operator& op,
+    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
   string custom_code;
   if (op.type == OperatorType::kTensorFlowUnsupported) {
     const TensorFlowUnsupportedOperator& unsupported_op =
         static_cast<const TensorFlowUnsupportedOperator&>(op);
     custom_code = unsupported_op.tensorflow_op;
   }
-  return details::OperatorKey(op.type, custom_code);
+  int version = 1;
+  if (ops_by_type.count(op.type) != 0) {
+    version = ops_by_type.at(op.type)->GetVersion(op);
+  }
+  return details::OperatorKey(op.type, custom_code, version);
 }
 
 }  // Anonymous namespace.
@@ -74,11 +80,13 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
   }
 }
 
-void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map) {
+void LoadOperatorsMap(
+    const Model& model, OperatorsMap* operators_map,
+    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
   // First find a list of unique operator types.
   std::set<OperatorKey> keys;
   for (const auto& op : model.operators) {
-    keys.insert(GetOperatorKey(*op));
+    keys.insert(GetOperatorKey(*op, ops_by_type));
   }
   // Now assign indices to them and fill in the map.
   int index = 0;
@@ -185,8 +193,9 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
   std::map<int, Offset<OperatorCode>> ordered_opcodes;
 
   for (const auto& op : model.operators) {
-    const details::OperatorKey operator_key = GetOperatorKey(*op);
+    const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type);
     int op_index = operators_map.at(operator_key);
+    int op_version = operator_key.version;
 
     string name = HelpfulOperatorTypeName(*op);
     bool is_builtin = false;
@@ -197,7 +206,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
 
     if (is_builtin) {
       ordered_opcodes[op_index] =
-          CreateOperatorCode(*builder, builtin_ops[name], 0);
+          CreateOperatorCode(*builder, builtin_ops[name], 0, op_version);
     } else {
       // This could be a kTensorFlowUnsupported, in which case we should be
      // able to retrieve the original TensorFlow name from the OperatorKey, or
@@ -211,8 +220,9 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
       if (error_summary) {
         error_summary->insert(name);
       }
-      ordered_opcodes[op_index] = CreateOperatorCode(
-          *builder, BuiltinOperator_CUSTOM, builder->CreateString(name));
+      ordered_opcodes[op_index] =
+          CreateOperatorCode(*builder, BuiltinOperator_CUSTOM,
+                             builder->CreateString(name), op_version);
     }
   }
 
@@ -244,7 +254,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
       outputs.push_back(tensors_map.at(output));
     }
 
-    int op_index = operators_map.at(GetOperatorKey(*op));
+    int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type));
 
     // This is a custom op unless we can find it in ops_by_type, and even then
     // it could be a custom op (such as kTensorFlowUnsupported).
@@ -279,15 +289,20 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers(
 
 void Export(const Model& model, bool allow_custom_ops,
             string* output_file_contents) {
-  flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
-
   const auto ops_by_type = BuildOperatorByTypeMap();
+  Export(model, allow_custom_ops, output_file_contents, ops_by_type);
+}
+
+void Export(
+    const Model& model, bool allow_custom_ops, string* output_file_contents,
+    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+  flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
 
   details::TensorsMap tensors_map;
   details::LoadTensorsMap(model, &tensors_map);
 
   details::OperatorsMap operators_map;
-  details::LoadOperatorsMap(model, &operators_map);
+  details::LoadOperatorsMap(model, &operators_map, ops_by_type);
 
   std::vector<const Array*> buffers_to_write;
   Array empty_array;
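
For reference, a minimal caller sketch (not part of this change) showing how the new four-argument Export overload threads an operator map through export; it assumes only the declarations in this diff (Export and BuildOperatorByTypeMap):

    // Sketch only: drives the new Export overload with the standard
    // type -> BaseOperator map, so each op's GetVersion() result is
    // written into the corresponding OperatorCode.
    #include <string>
    #include "tensorflow/contrib/lite/toco/tflite/export.h"
    #include "tensorflow/contrib/lite/toco/tflite/operator.h"

    void ExportWithVersions(const toco::Model& model, std::string* out) {
      const auto ops_by_type = toco::tflite::BuildOperatorByTypeMap();
      toco::tflite::Export(model, /*allow_custom_ops=*/true, out, ops_by_type);
    }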
tensorflow/contrib/lite/toco/tflite/export.h
index 8c79cb8..90abfb9 100644
@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
 
 #include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
 
 namespace toco {
 
@@ -25,11 +26,18 @@ namespace tflite {
 // result in the given string.
 void Export(const Model& model, bool allow_custom_ops,
             string* output_file_contents);
+
 // This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
 inline void Export(const Model& model, string* output_file_contents) {
   Export(model, true, output_file_contents);
 }
 
+// Export API with custom TFLite operator mapping.
+void Export(
+    const Model& model, bool allow_custom_ops, string* output_file_contents,
+    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+
 namespace details {
 
 // A map from tensor name to its final position in the TF Lite buffer.
@@ -39,25 +47,47 @@ using TensorsMap = std::unordered_map<string, int>;
 // Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to
 // identify which operation is used.
 struct OperatorKey {
-  OperatorKey(OperatorType type, const std::string& custom_code)
-      : type(type), custom_code(custom_code) {}
+  OperatorKey(OperatorType type, const std::string& custom_code, int version)
+      : type(type), custom_code(custom_code), version(version) {}
   const OperatorType type;
   const std::string custom_code;
+  const int version;
 
   bool operator<(const OperatorKey& other) const {
     if (type < other.type) return true;
-    if (type > other.type) return false;
-    return custom_code < other.custom_code;
+    else if (type > other.type)
+      return false;
+    else if (custom_code < other.custom_code)
+      return true;
+    else if (custom_code > other.custom_code)
+      return false;
+    else
+      return version < other.version;
   }
 
   bool operator==(const OperatorKey& other) const {
-    return type == other.type && custom_code == other.custom_code;
+    return type == other.type && custom_code == other.custom_code &&
+           version == other.version;
   }
 
   struct Hash {
-    std::size_t operator()(const OperatorKey& key) const {
-      return std::hash<size_t>()(static_cast<size_t>(key.type)) ^
-             std::hash<std::string>()(key.custom_code);
+    size_t operator()(const OperatorKey& key) const {
+      return CombineHashes({std::hash<size_t>()(static_cast<size_t>(key.type)),
+                            std::hash<std::string>()(key.custom_code),
+                            std::hash<int>()(key.version)});
+    }
+
+   private:
+    // TODO(ycling): Refactor and extract this function into a common
+    // utility module.
+    static size_t CombineHashes(std::initializer_list<size_t> hashes) {
+      size_t result = 0;
+      // Hash combiner used by TensorFlow core.
+      for (size_t hash : hashes) {
+        result = result ^ (hash + 0x9e3779b97f4a7800ULL + (result << 10) +
+                           (result >> 4));
+      }
+      return result;
     }
   };
 };
@@ -66,11 +96,12 @@ struct OperatorKey {
 using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
 
 void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
-void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map);
+void LoadOperatorsMap(
+    const Model& model, OperatorsMap* operators_map,
+    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
 
 }  // namespace details
 }  // namespace tflite
-
 }  // namespace toco
 
 #endif  // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
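
For intuition, a small illustrative sketch (not part of the change) of why version is now part of OperatorKey: two keys that differ only in version compare as distinct, so a model mixing Conv v1 and Conv v2 produces two OperatorCode entries. This uses only the details:: declarations above:

    // Illustration only, assuming the details:: declarations in export.h.
    #include <cassert>

    void OperatorKeyVersionDemo() {
      using toco::OperatorType;
      using toco::tflite::details::OperatorKey;
      using toco::tflite::details::OperatorsMap;

      OperatorKey conv_v1(OperatorType::kConv, "", /*version=*/1);
      OperatorKey conv_v2(OperatorType::kConv, "", /*version=*/2);
      assert(!(conv_v1 == conv_v2));  // equality now includes version
      assert(conv_v1 < conv_v2);      // ordering falls through to version

      OperatorsMap indices;
      indices[conv_v1] = 0;  // distinct hashes keep both entries
      indices[conv_v2] = 1;
    }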
tensorflow/contrib/lite/toco/tflite/export_test.cc
index 6754372..409e7d7 100644
@@ -17,6 +17,9 @@ limitations under the License.
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
 
 namespace toco {
 namespace tflite {
@@ -65,12 +68,13 @@ TEST_F(ExportTest, LoadOperatorsMap) {
   BuildTestModel();
 
   details::OperatorsMap operators;
-  details::LoadOperatorsMap(input_model_, &operators);
-  EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "")]);
-  EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "")]);
-  EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "")]);
+  const auto ops_by_type = BuildOperatorByTypeMap();
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+  EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
+  EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
+  EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
   EXPECT_EQ(3, operators[details::OperatorKey(
-                   OperatorType::kTensorFlowUnsupported, "MyCrazyOp")]);
+                   OperatorType::kTensorFlowUnsupported, "MyCrazyOp", 1)]);
 }
 
 TEST_F(ExportTest, Export) {
@@ -104,6 +108,160 @@ TEST_F(ExportTest, Export) {
   EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2));
 }
 
+// This test is based on a hypothetical scenario in which dilation is
+// supported only in Conv version 2. Toco therefore populates version=1 when
+// the dilation parameters are all 1, and version=2 otherwise.
+class FakeConvolutionOperator
+    : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
+                             ::tflite::BuiltinOptions_Conv2DOptions> {
+ public:
+  FakeConvolutionOperator()
+      : BuiltinOperator(::tflite::BuiltinOperator_CONV_2D,
+                        OperatorType::kConv) {}
+
+  // Returns the op version according to the op parameters.
+  int GetVersion(const Operator& op) const override {
+    const TocoOperator& conv_op = static_cast<const TocoOperator&>(op);
+    if (conv_op.dilation_width_factor != 1 ||
+        conv_op.dilation_height_factor != 1) {
+      // Version 2 if dilation is used.
+      return 2;
+    }
+    return 1;
+  }
+
+  // Note: The read / write code doesn't need to change as long as we stick
+  // with these restrictions:
+  // * Only add parameters at the bottom of the Flatbuffer tables.
+  // * When the default values of the parameters are used, the op works
+  //   consistently with the previous version.
+  flatbuffers::Offset<TfLiteOptions> WriteOptions(
+      const TocoOperator& op,
+      flatbuffers::FlatBufferBuilder* builder) const override {
+    auto padding = Padding::Serialize(op.padding.type);
+    auto activation_function =
+        ActivationFunction::Serialize(op.fused_activation_function);
+    return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
+                                         op.stride_height, activation_function,
+                                         op.dilation_width_factor,
+                                         op.dilation_height_factor);
+  }
+
+  void ReadOptions(const TfLiteOptions& options,
+                   TocoOperator* op) const override {
+    op->padding.type = Padding::Deserialize(options.padding());
+    op->stride_width = options.stride_w();
+    op->stride_height = options.stride_h();
+    op->dilation_width_factor = options.dilation_w_factor();
+    op->dilation_height_factor = options.dilation_h_factor();
+    op->fused_activation_function =
+        ActivationFunction::Deserialize(options.fused_activation_function());
+  }
+};
+
+class VersionedOpExportTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    input_model_.GetOrCreateArray("input");
+    input_model_.GetOrCreateArray("filter");
+    input_model_.GetOrCreateArray("output");
+  }
+  void AddConvOp(bool use_dilation) {
+    {
+      auto* op = new ConvOperator;
+      op->inputs.push_back("input");
+      op->inputs.push_back("filter");
+      op->inputs.push_back("output");
+
+      op->padding.type = PaddingType::kSame;
+      op->stride_width = 1;
+      op->stride_height = 1;
+      if (use_dilation) {
+        op->dilation_width_factor = 2;
+        op->dilation_height_factor = 2;
+      } else {
+        op->dilation_width_factor = 1;
+        op->dilation_height_factor = 1;
+      }
+      input_model_.operators.emplace_back(op);
+    }
+  }
+
+  std::map<OperatorType, std::unique_ptr<BaseOperator>>
+  BuildFakeOperatorByTypeMap() {
+    std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
+    result[OperatorType::kConv] =
+        std::unique_ptr<BaseOperator>(new FakeConvolutionOperator);
+    return result;
+  }
+
+  Model input_model_;
+};
+
+TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
+  AddConvOp(false);
+
+  details::OperatorsMap operators;
+  const auto ops_by_type = BuildFakeOperatorByTypeMap();
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+
+  EXPECT_EQ(1, operators.size());
+  EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
+}
+
+TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
+  AddConvOp(true);
+
+  details::OperatorsMap operators;
+  const auto ops_by_type = BuildFakeOperatorByTypeMap();
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+
+  EXPECT_EQ(1, operators.size());
+  EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
+}
+
+TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
+  AddConvOp(false);
+  AddConvOp(true);
+
+  details::OperatorsMap operators;
+  const auto ops_by_type = BuildFakeOperatorByTypeMap();
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+
+  EXPECT_EQ(2, operators.size());
+  EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
+  EXPECT_EQ(1, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
+}
+
+TEST_F(VersionedOpExportTest, Export) {
+  AddConvOp(false);
+  AddConvOp(true);
+
+  string result;
+  const auto ops_by_type = BuildFakeOperatorByTypeMap();
+  Export(input_model_, true, &result, ops_by_type);
+
+  auto* model = ::tflite::GetModel(result.data());
+  auto operator_codes = model->operator_codes();
+
+  // Verify that 2 operator codes are populated. Both are CONV_2D but with
+  // different versions.
+  EXPECT_EQ(2, operator_codes->size());
+  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
+            (*operator_codes)[0]->builtin_code());
+  EXPECT_EQ(1, (*operator_codes)[0]->version());
+  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
+            (*operator_codes)[1]->builtin_code());
+  EXPECT_EQ(2, (*operator_codes)[1]->version());
+
+  // Verify that the 2 operators point to the correct indices of the
+  // operator codes.
+  auto operators = (*model->subgraphs())[0]->operators();
+  EXPECT_EQ(2, operators->size());
+  EXPECT_EQ(0, (*operators)[0]->opcode_index());
+  EXPECT_EQ(1, (*operators)[1]->opcode_index());
+}
+
 // TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.
 
 }  // namespace
tensorflow/contrib/lite/toco/tflite/operator.cc
index 2cd9700..6922e50 100644
@@ -53,6 +53,8 @@ class AveragePool
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Convolution
@@ -83,6 +85,8 @@ class Convolution
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class DepthwiseConvolution
@@ -112,6 +116,8 @@ class DepthwiseConvolution
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
@@ -132,6 +138,8 @@ class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class SpaceToBatchND
@@ -149,6 +157,8 @@ class SpaceToBatchND
 
   void ReadOptions(const TfLiteOptions& options,
                    TocoOperator* op) const override {}
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
@@ -169,6 +179,8 @@ class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
@@ -189,6 +201,8 @@ class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class BatchToSpaceND
@@ -206,6 +220,8 @@ class BatchToSpaceND
 
   void ReadOptions(const TfLiteOptions& options,
                    TocoOperator* op) const override {}
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
@@ -225,6 +241,8 @@ class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
     op->src_data_type = DataType::Deserialize(options.in_data_type());
     op->dst_data_type = DataType::Deserialize(options.out_data_type());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Concatenation
@@ -243,6 +261,8 @@ class Concatenation
                    TocoOperator* op) const override {
     op->axis = options.axis();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
@@ -255,6 +275,8 @@ class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
     op->block_size = m["block_size"].AsInt64();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class FakeQuant : public CustomOperator<FakeQuantOperator> {
@@ -274,6 +296,8 @@ class FakeQuant : public CustomOperator<FakeQuantOperator> {
     const auto& num_bits = m["num_bits"];
     op->num_bits = num_bits.IsInt() ? num_bits.AsInt32() : 8;
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class FullyConnected
@@ -295,6 +319,8 @@ class FullyConnected
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
@@ -311,6 +337,8 @@ class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
                    TocoOperator* op) const override {
     op->axis = options.axis();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
@@ -331,6 +359,8 @@ class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
         ActivationFunction::Deserialize(options.fused_activation_function());
     op->rank = options.rank();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class L2Normalization
@@ -351,6 +381,8 @@ class L2Normalization
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
@@ -378,6 +410,8 @@ class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class LocalResponseNormalization
@@ -401,6 +435,8 @@ class LocalResponseNormalization
     op->alpha = options.alpha();
     op->beta = options.beta();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
@@ -428,6 +464,8 @@ class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
@@ -448,6 +486,8 @@ class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
@@ -463,6 +503,8 @@ class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
 
   void ReadOptions(const TfLiteOptions& options,
                    TocoOperator* op) const override {}
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
@@ -478,6 +520,8 @@ class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
 
   void ReadOptions(const TfLiteOptions& options,
                    TocoOperator* op) const override {}
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Reshape
@@ -499,6 +543,8 @@ class Reshape
     op->shape.insert(op->shape.end(), options.new_shape()->begin(),
                      options.new_shape()->end());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Softmax
@@ -516,6 +562,8 @@ class Softmax
                    TocoOperator* op) const override {
     op->beta = options.beta();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class SpaceToDepth
@@ -534,6 +582,8 @@ class SpaceToDepth
                    TocoOperator* op) const override {
     op->block_size = options.block_size();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Transpose
@@ -549,6 +599,8 @@ class Transpose
 
   void ReadOptions(const TfLiteOptions& options,
                    TocoOperator* op) const override {}
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
@@ -571,6 +623,8 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
     CHECK(options.fused_activation_function() ==
           ::tflite::ActivationFunctionType_TANH);
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
@@ -587,6 +641,8 @@ class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
                    TocoOperator* op) const override {
     op->keep_dims = options.keep_dims();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class ResizeBilinear
@@ -605,6 +661,8 @@ class ResizeBilinear
                    TocoOperator* op) const override {
     op->align_corners = options.align_corners();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Squeeze
@@ -626,6 +684,8 @@ class Squeeze
                             options.squeeze_dims()->begin(),
                             options.squeeze_dims()->end());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class Split
@@ -644,6 +704,8 @@ class Split
                    TocoOperator* op) const override {
     op->num_split = options.num_splits();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class StridedSlice
@@ -668,6 +730,8 @@ class StridedSlice
     op->new_axis_mask = options.new_axis_mask();
     op->shrink_axis_mask = options.shrink_axis_mask();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
@@ -682,6 +746,8 @@ class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
 
   void ReadOptions(const TfLiteOptions& options,
                    TocoOperator* op) const override {}
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
@@ -699,6 +765,8 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
                    TocoOperator* op) const override {
     op->output_data_type = DataType::Deserialize(options.output_type());
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class TransposeConv
@@ -722,6 +790,8 @@ class TransposeConv
     op->stride_width = options.stride_w();
     op->stride_height = options.stride_h();
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 class TensorFlowUnsupported : public BaseOperator {
@@ -828,6 +898,12 @@ class TensorFlowUnsupported : public BaseOperator {
     }
     node_def.SerializeToString(&op->tensorflow_node_def);
   }
+
+  int GetVersion(const Operator& op) const override {
+    // TODO(ycling): Deisng and implement a way to plumb the version of
+    // custom ops.
+    return 1;
+  }
 };
 
 namespace {
tensorflow/contrib/lite/toco/tflite/operator.h
index 88af3d6..50f0620 100644
@@ -77,6 +77,16 @@ class BaseOperator {
       const BuiltinOptions* builtin_options,
       const CustomOptions* custom_options) const = 0;
 
+  // Get the op version based on the op parameters.
+  // This function needs to be overridden to return the op version derived
+  // from the parameters. Note:
+  // * The first version of each op should be 1 (to be consistent with the
+  //   default value in Flatbuffer); `return 1;` is okay for newly
+  //   implemented ops.
+  // * When multiple versions are defined for an op, this function needs to
+  //   be overridden accordingly. (See the example in `operator_test.cc`.)
+  virtual int GetVersion(const Operator& op) const = 0;
+
  private:
   string name_;
   OperatorType type_;
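
Restating the contract above as standalone logic may help: the sketch below mirrors the hypothetical Conv-v2 rule from FakeConvolutionOperator in export_test.cc (dilation as a version-2-only feature is the test's assumption, not a statement about the real Conv op):

    // Sketch of a version rule, matching the test's hypothetical scenario.
    int VersionForConv(const toco::Operator& op) {
      const auto& conv = static_cast<const toco::ConvOperator&>(op);
      // Stay at version 1 while both dilation factors keep their default
      // of 1, so models that don't use dilation remain readable by older
      // interpreters.
      if (conv.dilation_width_factor != 1 ||
          conv.dilation_height_factor != 1) {
        return 2;
      }
      return 1;
    }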
tensorflow/contrib/lite/toco/tflite/simple_operator.h
index 72678c8..a7f7e88 100644
@@ -41,6 +41,8 @@ class SimpleOperator : public BaseOperator {
       const CustomOptions* custom_options) const override {
     return std::unique_ptr<Operator>(new T);
   }
+
+  int GetVersion(const Operator& op) const override { return 1; }
 };
 
 }  // namespace tflite