namespace {
-details::OperatorKey GetOperatorKey(const ::toco::Operator& op) {
+details::OperatorKey GetOperatorKey(
+ const ::toco::Operator& op,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
string custom_code;
if (op.type == OperatorType::kTensorFlowUnsupported) {
const TensorFlowUnsupportedOperator& unsupported_op =
static_cast<const TensorFlowUnsupportedOperator&>(op);
custom_code = unsupported_op.tensorflow_op;
}
- return details::OperatorKey(op.type, custom_code);
+ int version = 1;
+ if (ops_by_type.count(op.type) != 0) {
+ version = ops_by_type.at(op.type)->GetVersion(op);
+ }
+ return details::OperatorKey(op.type, custom_code, version);
}
} // Anonymous namespace.
}
}
-void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map) {
+void LoadOperatorsMap(
+ const Model& model, OperatorsMap* operators_map,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
// First find a list of unique operator types.
std::set<OperatorKey> keys;
for (const auto& op : model.operators) {
- keys.insert(GetOperatorKey(*op));
+ keys.insert(GetOperatorKey(*op, ops_by_type));
}
// Now assign indices to them and fill in the map.
int index = 0;
std::map<int, Offset<OperatorCode>> ordered_opcodes;
for (const auto& op : model.operators) {
- const details::OperatorKey operator_key = GetOperatorKey(*op);
+ const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type);
int op_index = operators_map.at(operator_key);
+ int op_version = operator_key.version;
string name = HelpfulOperatorTypeName(*op);
bool is_builtin = false;
if (is_builtin) {
ordered_opcodes[op_index] =
- CreateOperatorCode(*builder, builtin_ops[name], 0);
+ CreateOperatorCode(*builder, builtin_ops[name], 0, op_version);
} else {
// This could be a kTensorFlowUnsupported, in which case we should be
// able to retrieve the original Tensorflow name from the OperatorKey, or
if (error_summary) {
error_summary->insert(name);
}
- ordered_opcodes[op_index] = CreateOperatorCode(
- *builder, BuiltinOperator_CUSTOM, builder->CreateString(name));
+ ordered_opcodes[op_index] =
+ CreateOperatorCode(*builder, BuiltinOperator_CUSTOM,
+ builder->CreateString(name), op_version);
}
}
outputs.push_back(tensors_map.at(output));
}
- int op_index = operators_map.at(GetOperatorKey(*op));
+ int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type));
// This is a custom op unless we can find it in ops_by_type, and even then
// it could be a custom op (such as kTensorFlowUnsupported).
void Export(const Model& model, bool allow_custom_ops,
string* output_file_contents) {
- flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
-
const auto ops_by_type = BuildOperatorByTypeMap();
+ Export(model, allow_custom_ops, output_file_contents, ops_by_type);
+}
+
+void Export(
+ const Model& model, bool allow_custom_ops, string* output_file_contents,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
details::TensorsMap tensors_map;
details::LoadTensorsMap(model, &tensors_map);
details::OperatorsMap operators_map;
- details::LoadOperatorsMap(model, &operators_map);
+ details::LoadOperatorsMap(model, &operators_map, ops_by_type);
std::vector<const Array*> buffers_to_write;
Array empty_array;
#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
namespace toco {
// result in the given string.
void Export(const Model& model, bool allow_custom_ops,
string* output_file_contents);
+
// This is for backward compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
inline void Export(const Model& model, string* output_file_contents) {
Export(model, true, output_file_contents);
}
+// Export API with custom TFLite operator mapping.
+void Export(
+ const Model& model, bool allow_custom_ops, string* output_file_contents,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+
namespace details {
// A maps from tensor name to its final position in the TF Lite buffer.
// Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to
// identify which operation is used.
struct OperatorKey {
- OperatorKey(OperatorType type, const std::string& custom_code)
- : type(type), custom_code(custom_code) {}
+ OperatorKey(OperatorType type, const std::string& custom_code, int version)
+ : type(type), custom_code(custom_code), version(version) {}
const OperatorType type;
const std::string custom_code;
+ const int version;
bool operator<(const OperatorKey& other) const {
if (type < other.type) return true;
- if (type > other.type) return false;
- return custom_code < other.custom_code;
+ else if (type > other.type)
+ return false;
+ else if (custom_code < other.custom_code)
+ return true;
+ else if (custom_code > other.custom_code)
+ return false;
+ else
+ return version < other.version;
}
bool operator==(const OperatorKey& other) const {
- return type == other.type && custom_code == other.custom_code;
+ return type == other.type && custom_code == other.custom_code &&
+ version == other.version;
}
struct Hash {
- std::size_t operator()(const OperatorKey& key) const {
- return std::hash<size_t>()(static_cast<size_t>(key.type)) ^
- std::hash<std::string>()(key.custom_code);
+ size_t operator()(const OperatorKey& key) const {
+ return CombineHashes({std::hash<size_t>()(static_cast<size_t>(key.type)),
+ std::hash<std::string>()(key.custom_code),
+ std::hash<int>()(key.version)});
+ }
+
+ private:
+ // TODO(ycling): Refactoring and extract this function into a common
+ // utility module.
+ static size_t CombineHashes(std::initializer_list<size_t> hashes) {
+ size_t result = 0;
+ // Hash combiner used by TensorFlow core.
+ for (size_t hash : hashes) {
+ result = result ^ (hash + 0x9e3779b97f4a7800ULL + (result << 10) +
+ (result >> 4));
+ }
+ return result;
}
};
};
using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
-void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map);
+void LoadOperatorsMap(
+ const Model& model, OperatorsMap* operators_map,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
} // namespace details
} // namespace tflite
-
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+#include "tensorflow/contrib/lite/toco/tflite/types.h"
namespace toco {
namespace tflite {
BuildTestModel();
details::OperatorsMap operators;
- details::LoadOperatorsMap(input_model_, &operators);
- EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "")]);
- EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "")]);
- EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "")]);
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
+ EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
+ EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
EXPECT_EQ(3, operators[details::OperatorKey(
- OperatorType::kTensorFlowUnsupported, "MyCrazyOp")]);
+ OperatorType::kTensorFlowUnsupported, "MyCrazyOp", 1)]);
}
TEST_F(ExportTest, Export) {
EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2));
}
+// This test is based on a hypothetical scenario that dilation is supported
+// only in Conv version 2. So Toco populates version=1 when dilation
+// parameters are all 1, and version=2 otherwise.
+class FakeConvolutionOperator
+    : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
+                             ::tflite::BuiltinOptions_Conv2DOptions> {
+ public:
+  FakeConvolutionOperator()
+      : BuiltinOperator(::tflite::BuiltinOperator_CONV_2D,
+                        OperatorType::kConv) {}
+
+  // Returning the op version according to the op parameters.
+  int GetVersion(const Operator& op) const override {
+    const TocoOperator& conv_op = static_cast<const TocoOperator&>(op);
+    if (conv_op.dilation_width_factor != 1 ||
+        conv_op.dilation_height_factor != 1) {
+      // Version 2 if dilation is used.
+      return 2;
+    }
+    return 1;
+  }
+
+  // Note: The read / write code doesn't need to be changed if we stick with
+  // the restrictions:
+  // * Only adding parameters at the bottom of the Flatbuffer tables.
+  // * When the default value of parameters are used, the op works consistently
+  //   with the previous version.
+  flatbuffers::Offset<TfLiteOptions> WriteOptions(
+      const TocoOperator& op,
+      flatbuffers::FlatBufferBuilder* builder) const override {
+    auto padding = Padding::Serialize(op.padding.type);
+    auto activation_function =
+        ActivationFunction::Serialize(op.fused_activation_function);
+    return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
+                                         op.stride_height, activation_function,
+                                         op.dilation_width_factor,
+                                         op.dilation_height_factor);
+  }
+
+  void ReadOptions(const TfLiteOptions& options,
+                   TocoOperator* op) const override {
+    op->padding.type = Padding::Deserialize(options.padding());
+    op->stride_width = options.stride_w();
+    op->stride_height = options.stride_h();
+    op->dilation_width_factor = options.dilation_w_factor();
+    op->dilation_height_factor = options.dilation_h_factor();
+    op->fused_activation_function =
+        ActivationFunction::Deserialize(options.fused_activation_function());
+  }
+};
+
+// Test fixture for versioned-operator export: builds a tiny model
+// (input, filter -> Conv -> output) using FakeConvolutionOperator.
+class VersionedOpExportTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    input_model_.GetOrCreateArray("input");
+    input_model_.GetOrCreateArray("filter");
+    input_model_.GetOrCreateArray("output");
+  }
+
+  // Appends one ConvOperator to the model. With `use_dilation` the dilation
+  // factors are 2 (so the fake op reports version 2), otherwise 1 (version 1).
+  void AddConvOp(bool use_dilation) {
+    {
+      auto* op = new ConvOperator;
+      op->inputs.push_back("input");
+      op->inputs.push_back("filter");
+      // Bug fix: "output" is the op's output array, not a third input.
+      op->outputs.push_back("output");
+
+      op->padding.type = PaddingType::kSame;
+      op->stride_width = 1;
+      op->stride_height = 1;
+      if (use_dilation) {
+        op->dilation_width_factor = 2;
+        op->dilation_height_factor = 2;
+      } else {
+        op->dilation_width_factor = 1;
+        op->dilation_height_factor = 1;
+      }
+      input_model_.operators.emplace_back(op);
+    }
+  }
+
+  // Returns an operator map containing only the fake Conv operator.
+  std::map<OperatorType, std::unique_ptr<BaseOperator>>
+  BuildFakeOperatorByTypeMap() {
+    std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
+    result[OperatorType::kConv] =
+        std::unique_ptr<BaseOperator>(new FakeConvolutionOperator);
+    return result;
+  }
+
+  Model input_model_;
+};
+
+TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
+  // Conv with all dilation factors == 1 must be keyed as version 1.
+  AddConvOp(false);
+
+  details::OperatorsMap operators;
+  const auto ops_by_type = BuildFakeOperatorByTypeMap();
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+
+  EXPECT_EQ(1, operators.size());
+  EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
+}
+
+TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
+  // Conv with dilation factors != 1 must be keyed as version 2.
+  AddConvOp(true);
+
+  details::OperatorsMap operators;
+  const auto ops_by_type = BuildFakeOperatorByTypeMap();
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+
+  EXPECT_EQ(1, operators.size());
+  EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
+}
+
+TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
+  // Two Conv ops with different versions must get distinct map entries.
+  AddConvOp(false);
+  AddConvOp(true);
+
+  details::OperatorsMap operators;
+  const auto ops_by_type = BuildFakeOperatorByTypeMap();
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+
+  EXPECT_EQ(2, operators.size());
+  EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
+  EXPECT_EQ(1, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
+}
+
+TEST_F(VersionedOpExportTest, Export) {
+  AddConvOp(false);
+  AddConvOp(true);
+
+  string result;
+  const auto ops_by_type = BuildFakeOperatorByTypeMap();
+  Export(input_model_, true, &result, ops_by_type);
+
+  auto* model = ::tflite::GetModel(result.data());
+  auto operator_codes = model->operator_codes();
+
+  // Verify that 2 operator codes are populated. Both are CONV_2D but with
+  // different versions.
+  EXPECT_EQ(2, operator_codes->size());
+  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
+            (*operator_codes)[0]->builtin_code());
+  EXPECT_EQ(1, (*operator_codes)[0]->version());
+  EXPECT_EQ(::tflite::BuiltinOperator_CONV_2D,
+            (*operator_codes)[1]->builtin_code());
+  EXPECT_EQ(2, (*operator_codes)[1]->version());
+
+  // Verify that the 2 operators point to the correct indices of the operator
+  // codes.
+  auto operators = (*model->subgraphs())[0]->operators();
+  EXPECT_EQ(2, operators->size());
+  EXPECT_EQ(0, (*operators)[0]->opcode_index());
+  EXPECT_EQ(1, (*operators)[1]->opcode_index());
+}
+
// TODO(ahentz): tests for tensors, inputs, outpus, opcodes and operators.
} // namespace
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Convolution
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class DepthwiseConvolution
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class SpaceToBatchND
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class BatchToSpaceND
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
op->src_data_type = DataType::Deserialize(options.in_data_type());
op->dst_data_type = DataType::Deserialize(options.out_data_type());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Concatenation
TocoOperator* op) const override {
op->axis = options.axis();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
op->block_size = m["block_size"].AsInt64();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class FakeQuant : public CustomOperator<FakeQuantOperator> {
const auto& num_bits = m["num_bits"];
op->num_bits = num_bits.IsInt() ? num_bits.AsInt32() : 8;
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class FullyConnected
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
TocoOperator* op) const override {
op->axis = options.axis();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
ActivationFunction::Deserialize(options.fused_activation_function());
op->rank = options.rank();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class L2Normalization
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class LocalResponseNormalization
op->alpha = options.alpha();
op->beta = options.beta();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Reshape
op->shape.insert(op->shape.end(), options.new_shape()->begin(),
options.new_shape()->end());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Softmax
TocoOperator* op) const override {
op->beta = options.beta();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class SpaceToDepth
TocoOperator* op) const override {
op->block_size = options.block_size();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Transpose
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
CHECK(options.fused_activation_function() ==
::tflite::ActivationFunctionType_TANH);
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
TocoOperator* op) const override {
op->keep_dims = options.keep_dims();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class ResizeBilinear
TocoOperator* op) const override {
op->align_corners = options.align_corners();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Squeeze
options.squeeze_dims()->begin(),
options.squeeze_dims()->end());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class Split
TocoOperator* op) const override {
op->num_split = options.num_splits();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class StridedSlice
op->new_axis_mask = options.new_axis_mask();
op->shrink_axis_mask = options.shrink_axis_mask();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
TocoOperator* op) const override {
op->output_data_type = DataType::Deserialize(options.output_type());
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class TransposeConv
op->stride_width = options.stride_w();
op->stride_height = options.stride_h();
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
class TensorFlowUnsupported : public BaseOperator {
}
node_def.SerializeToString(&op->tensorflow_node_def);
}
+
+  int GetVersion(const Operator& op) const override {
+    // TODO(ycling): Design and implement a way to plumb the version of
+    // custom ops.
+    return 1;
+  }
};
namespace {
const BuiltinOptions* builtin_options,
const CustomOptions* custom_options) const = 0;
+  // Get the op version based on the op parameters.
+  // This function needs to be overridden to return the op version based on
+  // the parameters. Note:
+  // * The first version for each op should be 1 (to be consistent with the
+  //   default value in the Flatbuffer schema). `return 1;` is okay for newly
+  //   implemented ops.
+  // * When multiple versions are defined for an op, this function needs to be
+  //   overridden. (See the example in `operator_test.cc`.)
+ virtual int GetVersion(const Operator& op) const = 0;
+
private:
string name_;
OperatorType type_;
const CustomOptions* custom_options) const override {
return std::unique_ptr<Operator>(new T);
}
+
+ int GetVersion(const Operator& op) const override { return 1; }
};
} // namespace tflite