From 487d2ab835286f4eea891d93bc32adfd5543aef8 Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Fri, 18 May 2018 10:39:45 -0700 Subject: [PATCH] Op version: Populate version in Toco TFLite exporter PiperOrigin-RevId: 197166962 --- tensorflow/contrib/lite/toco/tflite/export.cc | 39 +++-- tensorflow/contrib/lite/toco/tflite/export.h | 51 +++++-- tensorflow/contrib/lite/toco/tflite/export_test.cc | 168 ++++++++++++++++++++- tensorflow/contrib/lite/toco/tflite/operator.cc | 76 ++++++++++ tensorflow/contrib/lite/toco/tflite/operator.h | 10 ++ .../contrib/lite/toco/tflite/simple_operator.h | 2 + 6 files changed, 319 insertions(+), 27 deletions(-) diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index a4c0b2d..5daa703 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -45,14 +45,20 @@ using ::tflite::Tensor; namespace { -details::OperatorKey GetOperatorKey(const ::toco::Operator& op) { +details::OperatorKey GetOperatorKey( + const ::toco::Operator& op, + const std::map>& ops_by_type) { string custom_code; if (op.type == OperatorType::kTensorFlowUnsupported) { const TensorFlowUnsupportedOperator& unsupported_op = static_cast(op); custom_code = unsupported_op.tensorflow_op; } - return details::OperatorKey(op.type, custom_code); + int version = 1; + if (ops_by_type.count(op.type) != 0) { + version = ops_by_type.at(op.type)->GetVersion(op); + } + return details::OperatorKey(op.type, custom_code, version); } } // Anonymous namespace. @@ -74,11 +80,13 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { } } -void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map) { +void LoadOperatorsMap( + const Model& model, OperatorsMap* operators_map, + const std::map>& ops_by_type) { // First find a list of unique operator types. 
std::set keys; for (const auto& op : model.operators) { - keys.insert(GetOperatorKey(*op)); + keys.insert(GetOperatorKey(*op, ops_by_type)); } // Now assign indices to them and fill in the map. int index = 0; @@ -185,8 +193,9 @@ Offset>> ExportOperatorCodes( std::map> ordered_opcodes; for (const auto& op : model.operators) { - const details::OperatorKey operator_key = GetOperatorKey(*op); + const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type); int op_index = operators_map.at(operator_key); + int op_version = operator_key.version; string name = HelpfulOperatorTypeName(*op); bool is_builtin = false; @@ -197,7 +206,7 @@ Offset>> ExportOperatorCodes( if (is_builtin) { ordered_opcodes[op_index] = - CreateOperatorCode(*builder, builtin_ops[name], 0); + CreateOperatorCode(*builder, builtin_ops[name], 0, op_version); } else { // This could be a kTensorFlowUnsupported, in which case we should be // able to retrieve the original Tensorflow name from the OperatorKey, or @@ -211,8 +220,9 @@ Offset>> ExportOperatorCodes( if (error_summary) { error_summary->insert(name); } - ordered_opcodes[op_index] = CreateOperatorCode( - *builder, BuiltinOperator_CUSTOM, builder->CreateString(name)); + ordered_opcodes[op_index] = + CreateOperatorCode(*builder, BuiltinOperator_CUSTOM, + builder->CreateString(name), op_version); } } @@ -244,7 +254,7 @@ Offset>> ExportOperators( outputs.push_back(tensors_map.at(output)); } - int op_index = operators_map.at(GetOperatorKey(*op)); + int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type)); // This is a custom op unless we can find it in ops_by_type, and even then // it could be a custom op (such as kTensorFlowUnsupported). 
@@ -279,15 +289,20 @@ Offset>> ExportBuffers( void Export(const Model& model, bool allow_custom_ops, string* output_file_contents) { - flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); - const auto ops_by_type = BuildOperatorByTypeMap(); + Export(model, allow_custom_ops, output_file_contents, ops_by_type); +} + +void Export( + const Model& model, bool allow_custom_ops, string* output_file_contents, + const std::map>& ops_by_type) { + flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); details::TensorsMap tensors_map; details::LoadTensorsMap(model, &tensors_map); details::OperatorsMap operators_map; - details::LoadOperatorsMap(model, &operators_map); + details::LoadOperatorsMap(model, &operators_map, ops_by_type); std::vector buffers_to_write; Array empty_array; diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 8c79cb8..90abfb9 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ #include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" namespace toco { @@ -25,11 +26,18 @@ namespace tflite { // result in the given string. void Export(const Model& model, bool allow_custom_ops, string* output_file_contents); + // This is for backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. inline void Export(const Model& model, string* output_file_contents) { Export(model, true, output_file_contents); } +// Export API with custom TFLite operator mapping. +void Export( + const Model& model, bool allow_custom_ops, string* output_file_contents, + const std::map>& ops_by_type); + namespace details { // A maps from tensor name to its final position in the TF Lite buffer. 
@@ -39,25 +47,47 @@ using TensorsMap = std::unordered_map; // Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to // identify which operation is used. struct OperatorKey { - OperatorKey(OperatorType type, const std::string& custom_code) - : type(type), custom_code(custom_code) {} + OperatorKey(OperatorType type, const std::string& custom_code, int version) + : type(type), custom_code(custom_code), version(version) {} const OperatorType type; const std::string custom_code; + const int version; bool operator<(const OperatorKey& other) const { if (type < other.type) return true; - if (type > other.type) return false; - return custom_code < other.custom_code; + else if (type > other.type) + return false; + else if (custom_code < other.custom_code) + return true; + else if (custom_code > other.custom_code) + return false; + else + return version < other.version; } bool operator==(const OperatorKey& other) const { - return type == other.type && custom_code == other.custom_code; + return type == other.type && custom_code == other.custom_code && + version == other.version; } struct Hash { - std::size_t operator()(const OperatorKey& key) const { - return std::hash()(static_cast(key.type)) ^ - std::hash()(key.custom_code); + size_t operator()(const OperatorKey& key) const { + return CombineHashes({std::hash()(static_cast(key.type)), + std::hash()(key.custom_code), + std::hash()(key.version)}); + } + + private: + // TODO(ycling): Refactoring and extract this function into a common + // utility module. + static size_t CombineHashes(std::initializer_list hashes) { + size_t result = 0; + // Hash combiner used by TensorFlow core. 
+ for (size_t hash : hashes) { + result = result ^ (hash + 0x9e3779b97f4a7800ULL + (result << 10) + + (result >> 4)); + } + return result; } }; }; @@ -66,11 +96,12 @@ struct OperatorKey { using OperatorsMap = std::unordered_map; void LoadTensorsMap(const Model& model, TensorsMap* tensors_map); -void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map); +void LoadOperatorsMap( + const Model& model, OperatorsMap* operators_map, + const std::map>& ops_by_type); } // namespace details } // namespace tflite - } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index 6754372..409e7d7 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -17,6 +17,9 @@ limitations under the License. #include #include #include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/toco/tflite/types.h" namespace toco { namespace tflite { @@ -65,12 +68,13 @@ TEST_F(ExportTest, LoadOperatorsMap) { BuildTestModel(); details::OperatorsMap operators; - details::LoadOperatorsMap(input_model_, &operators); - EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "")]); - EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "")]); - EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "")]); + const auto ops_by_type = BuildOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]); + EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]); + EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]); EXPECT_EQ(3, operators[details::OperatorKey( - 
OperatorType::kTensorFlowUnsupported, "MyCrazyOp", 1)]); } TEST_F(ExportTest, Export) { @@ -104,6 +108,160 @@ TEST_F(ExportTest, Export) { EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2)); } +// This test is based on a hypothetical scenario that dilation is supported +// only in Conv version 2. So Toco populates version=1 when dilation +// parameters are all 1, and version=2 otherwise. +class FakeConvolutionOperator + : public BuiltinOperator { + public: + FakeConvolutionOperator() + : BuiltinOperator(::tflite::BuiltinOperator_CONV_2D, + OperatorType::kConv) {} + + // Returning the op version according to the op parameters. + int GetVersion(const Operator& op) const override { + const TocoOperator& conv_op = static_cast(op); + if (conv_op.dilation_width_factor != 1 || + conv_op.dilation_height_factor != 1) { + // Version 2 if dilation is used. + return 2; + } + return 1; + } + + // Note: The read / write code doesn't need to be changed if we stick with + // the restrictions: + // * Only adding parameters at the bottom of the Flatbuffer tables. + // * When the default values of parameters are used, the op works consistently + // with the previous version. 
+ flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width, + op.stride_height, activation_function, + op.dilation_width_factor, + op.dilation_height_factor); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->dilation_width_factor = options.dilation_w_factor(); + op->dilation_height_factor = options.dilation_h_factor(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class VersionedOpExportTest : public ::testing::Test { + protected: + void SetUp() override { + input_model_.GetOrCreateArray("input"); + input_model_.GetOrCreateArray("filter"); + input_model_.GetOrCreateArray("output"); + } + void AddConvOp(bool use_dialation) { + { + auto* op = new ConvOperator; + op->inputs.push_back("input"); + op->inputs.push_back("filter"); + op->inputs.push_back("output"); + + op->padding.type = PaddingType::kSame; + op->stride_width = 1; + op->stride_height = 1; + if (use_dialation) { + op->dilation_width_factor = 2; + op->dilation_height_factor = 2; + } else { + op->dilation_width_factor = 1; + op->dilation_height_factor = 1; + } + input_model_.operators.emplace_back(op); + } + } + + std::map> + BuildFakeOperatorByTypeMap() { + std::map> result; + result[OperatorType::kConv] = + std::unique_ptr(new FakeConvolutionOperator); + return result; + } + + Model input_model_; +}; + +TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) { + AddConvOp(false); + + details::OperatorsMap operators; + const auto ops_by_type = 
BuildFakeOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + + EXPECT_EQ(1, operators.size()); + EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); +} + +TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) { + AddConvOp(true); + + details::OperatorsMap operators; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + + EXPECT_EQ(1, operators.size()); + EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); +} + +TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) { + AddConvOp(false); + AddConvOp(true); + + details::OperatorsMap operators; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + + EXPECT_EQ(2, operators.size()); + EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); + EXPECT_EQ(1, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); +} + +TEST_F(VersionedOpExportTest, Export) { + AddConvOp(false); + AddConvOp(true); + + string result; + const auto ops_by_type = BuildFakeOperatorByTypeMap(); + Export(input_model_, true, &result, ops_by_type); + + auto* model = ::tflite::GetModel(result.data()); + auto operator_codes = model->operator_codes(); + + // Verify that 2 operator codes are populated. Both are CONV_2D but with + // different versions. + EXPECT_EQ(2, operator_codes->size()); + EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, + (*operator_codes)[0]->builtin_code()); + EXPECT_EQ(1, (*operator_codes)[0]->version()); + EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D, + (*operator_codes)[1]->builtin_code()); + EXPECT_EQ(2, (*operator_codes)[1]->version()); + + // Verify that the 2 operators point to the correct indices of the operation + // codes. 
+ auto operators = (*model->subgraphs())[0]->operators(); + EXPECT_EQ(2, operators->size()); + EXPECT_EQ(0, (*operators)[0]->opcode_index()); + EXPECT_EQ(1, (*operators)[1]->opcode_index()); +} + // TODO(ahentz): tests for tensors, inputs, outpus, opcodes and operators. } // namespace diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 2cd9700..6922e50 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -53,6 +53,8 @@ class AveragePool op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Convolution @@ -83,6 +85,8 @@ class Convolution op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class DepthwiseConvolution @@ -112,6 +116,8 @@ class DepthwiseConvolution op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Add : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class SpaceToBatchND @@ -149,6 +157,8 @@ class SpaceToBatchND void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } }; class Sub : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Div : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int 
GetVersion(const Operator& op) const override { return 1; } }; class BatchToSpaceND @@ -206,6 +220,8 @@ class BatchToSpaceND void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } }; class Cast : public BuiltinOperatorsrc_data_type = DataType::Deserialize(options.in_data_type()); op->dst_data_type = DataType::Deserialize(options.out_data_type()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Concatenation @@ -243,6 +261,8 @@ class Concatenation TocoOperator* op) const override { op->axis = options.axis(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class DepthToSpace : public CustomOperator { @@ -255,6 +275,8 @@ class DepthToSpace : public CustomOperator { void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { op->block_size = m["block_size"].AsInt64(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class FakeQuant : public CustomOperator { @@ -274,6 +296,8 @@ class FakeQuant : public CustomOperator { const auto& num_bits = m["num_bits"]; op->num_bits = num_bits.IsInt() ? 
num_bits.AsInt32() : 8; } + + int GetVersion(const Operator& op) const override { return 1; } }; class FullyConnected @@ -295,6 +319,8 @@ class FullyConnected op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Gather : public BuiltinOperatoraxis = options.axis(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Svdf : public BuiltinOperatorrank = options.rank(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class L2Normalization @@ -351,6 +381,8 @@ class L2Normalization op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class L2Pool : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class LocalResponseNormalization @@ -401,6 +435,8 @@ class LocalResponseNormalization op->alpha = options.alpha(); op->beta = options.beta(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class MaxPool : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Mul : public BuiltinOperatorfused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Pad : public BuiltinOperatorshape.insert(op->shape.end(), options.new_shape()->begin(), options.new_shape()->end()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Softmax @@ -516,6 +562,8 @@ class Softmax TocoOperator* op) const override { op->beta = options.beta(); } + + int 
GetVersion(const Operator& op) const override { return 1; } }; class SpaceToDepth @@ -534,6 +582,8 @@ class SpaceToDepth TocoOperator* op) const override { op->block_size = options.block_size(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Transpose @@ -549,6 +599,8 @@ class Transpose void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } }; class Lstm : public BuiltinOperatorkeep_dims = options.keep_dims(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class ResizeBilinear @@ -605,6 +661,8 @@ class ResizeBilinear TocoOperator* op) const override { op->align_corners = options.align_corners(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Squeeze @@ -626,6 +684,8 @@ class Squeeze options.squeeze_dims()->begin(), options.squeeze_dims()->end()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class Split @@ -644,6 +704,8 @@ class Split TocoOperator* op) const override { op->num_split = options.num_splits(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class StridedSlice @@ -668,6 +730,8 @@ class StridedSlice op->new_axis_mask = options.new_axis_mask(); op->shrink_axis_mask = options.shrink_axis_mask(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class TopK_V2 : public BuiltinOperatoroutput_data_type = DataType::Deserialize(options.output_type()); } + + int GetVersion(const Operator& op) const override { return 1; } }; class TransposeConv @@ -722,6 +790,8 @@ class TransposeConv op->stride_width = options.stride_w(); op->stride_height = options.stride_h(); } + + int GetVersion(const Operator& op) const override { return 1; } }; class TensorFlowUnsupported : public BaseOperator { @@ -828,6 +898,12 @@ class TensorFlowUnsupported : public BaseOperator { } 
node_def.SerializeToString(&op->tensorflow_node_def); } + + int GetVersion(const Operator& op) const override { + // TODO(ycling): Design and implement a way to plumb the version of + // custom ops. + return 1; + } }; namespace { diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index 88af3d6..50f0620 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -77,6 +77,16 @@ class BaseOperator { const BuiltinOptions* builtin_options, const CustomOptions* custom_options) const = 0; + // Get the op version by op parameters. + // The function needs to be overridden to return the op version based on the + // parameters. Note: + // * The first version for each op should be 1 (to be consistent with the + // default value in Flatbuffer). `return 1;` is okay for newly implemented + // ops. + // * When multiple versions are defined for an op, this function needs to be + // overridden. (See example in `operator_test.cc`) + virtual int GetVersion(const Operator& op) const = 0; + private: string name_; OperatorType type_; diff --git a/tensorflow/contrib/lite/toco/tflite/simple_operator.h b/tensorflow/contrib/lite/toco/tflite/simple_operator.h index 72678c8..a7f7e88 100644 --- a/tensorflow/contrib/lite/toco/tflite/simple_operator.h +++ b/tensorflow/contrib/lite/toco/tflite/simple_operator.h @@ -41,6 +41,8 @@ class SimpleOperator : public BaseOperator { const CustomOptions* custom_options) const override { return std::unique_ptr(new T); } + + int GetVersion(const Operator& op) const override { return 1; } }; } // namespace tflite -- 2.7.4