From 7c4cdb8bae0e8760ebe4793d49ea5aee68768655 Mon Sep 17 00:00:00 2001
From: Yu-Cheng Ling
Date: Thu, 22 Mar 2018 11:25:49 -0700
Subject: [PATCH] Supports PReLU in TFLite & Toco.

PiperOrigin-RevId: 190097557
---
 tensorflow/contrib/lite/builtin_ops.h              |   1 +
 tensorflow/contrib/lite/kernels/activations.cc     |  64 +++++++++++
 .../contrib/lite/kernels/activations_test.cc       |  43 ++++++++
 tensorflow/contrib/lite/kernels/register.cc        |   2 +
 tensorflow/contrib/lite/model.cc                   |   1 +
 tensorflow/contrib/lite/nnapi_delegate.cc          |   1 +
 tensorflow/contrib/lite/schema/schema.fbs          |   1 +
 tensorflow/contrib/lite/schema/schema_generated.h  |   9 +-
 tensorflow/contrib/lite/testing/BUILD              |   1 +
 .../contrib/lite/testing/generate_examples.py      |  49 +++++++++
 .../lite/testing/generated_examples_zip_test.cc    |   4 +
 tensorflow/contrib/lite/toco/BUILD                 |   1 +
 .../graph_transformations/graph_transformations.h  |   1 +
 .../toco/graph_transformations/identify_prelu.cc   | 119 +++++++++++++++++++++
 .../graph_transformations/propagate_fixed_sizes.cc |   1 +
 tensorflow/contrib/lite/toco/model.h               |  13 +++
 tensorflow/contrib/lite/toco/tflite/operator.cc    |   2 +
 tensorflow/contrib/lite/toco/toco_tooling.cc       |   1 +
 tensorflow/contrib/lite/toco/tooling_util.cc       |   1 +
 19 files changed, 312 insertions(+), 3 deletions(-)
 create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc

diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index e4652a3..d7993e6 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -78,6 +78,7 @@ typedef enum {
   kTfLiteBuiltinDelegate = 51,
   kTfLiteBuiltinBidirectionalSequenceLstm = 52,
   kTfLiteBuiltinCast = 53,
+  kTfLiteBuiltinPrelu = 54,
 } TfLiteBuiltinOperator;

 #ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 093761c..39a54c9 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -150,6 +150,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
                              TfLiteIntArrayCopy(input->dims));
 }

+TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* alpha = GetInput(context, node, 1);
+
+  output->type = input->type;
+
+  // Currently only Float32 is supported.
+  // TODO(ycling): Support other data types.
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, alpha->type, kTfLiteFloat32);
+
+  // Currently, only a 4D `input` and a 3D `alpha` with shape
+  // (1, 1, channels) are supported.
+  // TODO(impjdi): Support other cases where `alpha` is broadcastable
+  // to `input`.
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], input->dims->data[3]);
+
+  return context->ResizeTensor(context, output,
+                               TfLiteIntArrayCopy(input->dims));
+}
+
 TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
@@ -388,6 +416,35 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
   }
 }

+TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* alpha = GetInput(context, node, 1);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+
+  if (input->type != kTfLiteFloat32) {
+    context->ReportError(context, "Only float32 supported currently.");
+    return kTfLiteError;
+  }
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+  const int batches = input->dims->data[0];
+  const int height = input->dims->data[1];
+  const int width = input->dims->data[2];
+  const int channels = input->dims->data[3];
+
+  TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], channels);
+
+  const int n = batches * height * width * channels;
+  for (int i = 0; i < n; ++i) {
+    const float x = input->data.f[i];
+    output->data.f[i] = x >= 0.0f ? x : alpha->data.f[i % channels] * x;
+  }
+
+  return kTfLiteOk;
+}
+
 }  // namespace activations

 TfLiteRegistration* Register_RELU() {
@@ -439,6 +496,13 @@ TfLiteRegistration* Register_LOG_SOFTMAX() {
   return &r;
 }

+TfLiteRegistration* Register_PRELU() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 activations::PreluPrepare,
+                                 activations::PreluEval};
+  return &r;
+}
+
 }  // namespace builtin
 }  // namespace ops
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index b9a96e3..50a84ed 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -383,6 +383,49 @@ TEST(FloatActivationsOpTest, LogSoftmax) {
                              })));
 }

+class PReluOpModel : public SingleOpModel {
+ public:
+  PReluOpModel(const TensorData& input, const TensorData& alpha) {
+    input_ = AddInput(input);
+    alpha_ = AddInput(alpha);
+    output_ = AddOutput(input);
+    SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE, 0);
+    BuildInterpreter({GetShape(input_), GetShape(alpha_)});
+  }
+  void SetInput(std::initializer_list<float> data) {
+    PopulateTensor(input_, data);
+  }
+  void SetAlpha(std::initializer_list<float> data) {
+    PopulateTensor(alpha_, data);
+  }
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  int input_;
+  int alpha_;
+  int output_;
+};
+
+TEST(FloatActivationsOpTest, PRelu) {
+  PReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}},
+                 {TensorType_FLOAT32, {1, 1, 3}});
+
+  m.SetInput({
+      0.0f, 0.0f, 0.0f,     // Row 1, Column 1
+      1.0f, 1.0f, 1.0f,     // Row 1, Column 2
+      -1.0f, -1.0f, -1.0f,  // Row 2, Column 1
+      -2.0f, -2.0f, -2.0f,  // Row 2, Column 2
+  });
+  m.SetAlpha({0.0f, 1.0f, 2.0f});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 0.0f, 0.0f, 0.0f,    // Row 1, Column 1
+                                 1.0f, 1.0f, 1.0f,    // Row 1, Column 2
+                                 0.0f, -1.0f, -2.0f,  // Row 2, Column 1
+                                 0.0f, -2.0f, -4.0f,  // Row 2, Column 2
+                             }));
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 369d3b9..62045f0 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -75,6 +75,7 @@ TfLiteRegistration* Register_TOPK_V2();
 TfLiteRegistration* Register_LOG_SOFTMAX();
 TfLiteRegistration* Register_CAST();
 TfLiteRegistration* Register_DEQUANTIZE();
+TfLiteRegistration* Register_PRELU();

 BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -131,6 +132,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
   AddBuiltin(BuiltinOperator_CAST, Register_CAST());
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE());
+  AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());

   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 9c619f8..b7ccdf0 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -309,6 +309,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_LOG_SOFTMAX:
     case BuiltinOperator_CAST:
     case BuiltinOperator_DEQUANTIZE:
+    case BuiltinOperator_PRELU:
       break;
     case BuiltinOperator_LSH_PROJECTION: {
       TfLiteLSHProjectionParams* params =
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 9d00d96..e31b7c0 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -349,6 +349,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
       case tflite::BuiltinOperator_DEQUANTIZE:
       case tflite::BuiltinOperator_DELEGATE:
       case tflite::BuiltinOperator_CAST:
+      case tflite::BuiltinOperator_PRELU:
         FATAL("Op code %d is currently not delegated to NNAPI", builtin);
         nn_op_type = -1;  // set to invalid
         break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 04387fe..e107597 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -130,6 +130,7 @@ enum BuiltinOperator : byte {
   DELEGATE = 51,
   BIDIRECTIONAL_SEQUENCE_LSTM = 52,
   CAST = 53,
+  PRELU = 54,
 }

 // Options for the builtin operators.
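A note on the kernel added above: because TFLite tensors are laid out NHWC, the channel index of flat element i is just i % channels, which is how PreluEval picks the per-channel alpha; PreluPrepare's shape checks are exactly what make that shortcut valid. The self-contained sketch below mirrors the same arithmetic outside the TFLite runtime. NaivePrelu and the hard-coded test shape are illustrative assumptions, not names from this patch.

#include <cassert>
#include <cstdio>
#include <vector>

// Minimal sketch of the computation PreluEval performs: a 4D NHWC input
// with a per-channel alpha of shape (1, 1, channels). NaivePrelu is a
// hypothetical helper, not part of the TFLite API.
std::vector<float> NaivePrelu(const std::vector<float>& input,
                              const std::vector<float>& alpha, int batches,
                              int height, int width, int channels) {
  assert(static_cast<int>(input.size()) == batches * height * width * channels);
  assert(static_cast<int>(alpha.size()) == channels);
  std::vector<float> output(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    const float x = input[i];
    // In NHWC layout, channels is the innermost dimension, so the channel
    // of flat index i is i % channels; this matches PreluEval exactly.
    output[i] = x >= 0.0f ? x : alpha[i % channels] * x;
  }
  return output;
}

int main() {
  // Same data as the PRelu unit test above: input shape (1, 2, 2, 3).
  const std::vector<float> input = {0.0f,  0.0f,  0.0f,  1.0f,  1.0f,  1.0f,
                                    -1.0f, -1.0f, -1.0f, -2.0f, -2.0f, -2.0f};
  const std::vector<float> alpha = {0.0f, 1.0f, 2.0f};
  for (float v : NaivePrelu(input, alpha, 1, 2, 2, 3)) std::printf("%g ", v);
  std::printf("\n");  // Prints: 0 0 0 1 1 1 0 -1 -2 0 -2 -4
  return 0;
}

The printed values match the ElementsAreArray expectation in the unit test above.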
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index b922de2..86daeaf 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -254,11 +254,12 @@ enum BuiltinOperator {
   BuiltinOperator_DELEGATE = 51,
   BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52,
   BuiltinOperator_CAST = 53,
+  BuiltinOperator_PRELU = 54,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_CAST
+  BuiltinOperator_MAX = BuiltinOperator_PRELU
 };

-inline BuiltinOperator (&EnumValuesBuiltinOperator())[52] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[53] {
   static BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -311,7 +312,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[52] {
     BuiltinOperator_LOG_SOFTMAX,
     BuiltinOperator_DELEGATE,
     BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
-    BuiltinOperator_CAST
+    BuiltinOperator_CAST,
+    BuiltinOperator_PRELU
   };
   return values;
 }
@@ -372,6 +374,7 @@ inline const char **EnumNamesBuiltinOperator() {
     "DELEGATE",
     "BIDIRECTIONAL_SEQUENCE_LSTM",
     "CAST",
+    "PRELU",
     nullptr
   };
   return names;
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index f1b18ad..555ea90 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -39,6 +39,7 @@ gen_zipped_test_files(
         "mean.zip",
         "mul.zip",
         "pad.zip",
+        "prelu.zip",
         "relu.zip",
         "relu1.zip",
         "relu6.zip",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 420bdb4..38de9dc 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -617,6 +617,54 @@ def make_relu6_tests(zip_path):
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


+def make_prelu_tests(zip_path):
+  """Make a set of tests to do PReLU."""
+
+  test_parameters = [{
+      # The canonical case for image processing is having a 4D `input` (NHWC)
+      # and `shared_axes`=[1, 2], so the alpha parameter is per channel.
+      "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]],
+      "shared_axes": [[1, 2], [1]],
+  }]
+
+  def build_graph(parameters):
+    """Build the graph for the test case."""
+
+    input_tensor = tf.placeholder(
+        dtype=tf.float32, name="input", shape=parameters["input_shape"])
+    prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"])
+    out = prelu(input_tensor)
+    return [input_tensor], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    """Build the inputs for the test case."""
+
+    input_shape = parameters["input_shape"]
+    input_values = create_tensor_data(
+        np.float32, input_shape, min_value=-10, max_value=10)
+    shared_axes = parameters["shared_axes"]
+
+    alpha_shape = []
+    for dim in range(1, len(input_shape)):
+      alpha_shape.append(1 if dim in shared_axes else input_shape[dim])
+
+    alpha_values = create_tensor_data(np.float32, alpha_shape)
+
+    with tf.variable_scope("", reuse=True):
+      alpha = tf.get_variable("p_re_lu/alpha")
+      sess.run(alpha.assign(alpha_values))
+
+    return [input_values], sess.run(
+        outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+  make_zip_of_tests(
+      zip_path,
+      test_parameters,
+      build_graph,
+      build_inputs,
+      use_frozen_graph=True)
+
+
 # This function tests various TensorFlow functions that generate Const ops,
 # including `tf.ones`, `tf.zeros` and random functions.
 def make_constant_tests(zip_path):
@@ -1911,6 +1959,7 @@ def main(unused_args):
       "relu.zip": make_relu_tests,
       "relu1.zip": make_relu1_tests,
       "relu6.zip": make_relu6_tests,
+      "prelu.zip": make_prelu_tests,
       "l2_pool.zip": make_pool_tests(make_l2_pool),
       "avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
       "max_pool.zip": make_pool_tests(tf.nn.max_pool),
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 5e76e7c..ba2d259 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -88,6 +88,9 @@ std::map<string, string> kBrokenTests = {

     // Transpose only supports 1D-4D input tensors.
     {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
+
+    // PRelu only supports 4D input with (1, 1, channels) 3D alpha now.
+    {R"(^\/prelu.*shared_axes=\[1\])", "75975192"},
 };

 // Allows test data to be unzipped into a temporary directory and makes
@@ -253,6 +256,7 @@ INSTANTIATE_TESTS(mul)
 INSTANTIATE_TESTS(pad)
 INSTANTIATE_TESTS(relu)
 INSTANTIATE_TESTS(relu1)
+INSTANTIATE_TESTS(prelu)
 INSTANTIATE_TESTS(relu6)
 INSTANTIATE_TESTS(reshape)
 INSTANTIATE_TESTS(resize_bilinear)
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 395abc5..486ff1e 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -193,6 +193,7 @@ cc_library(
        "graph_transformations/identify_lstm.cc",
        "graph_transformations/identify_lstm_merge_inputs.cc",
        "graph_transformations/identify_lstm_split_inputs.cc",
+       "graph_transformations/identify_prelu.cc",
        "graph_transformations/identify_relu1.cc",
        "graph_transformations/lstm_utils.cc",
        "graph_transformations/make_initial_dequantize_operator.cc",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 11e5e19..640afc7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -129,6 +129,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
 DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs)
 DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
 DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
 DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
 DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
 DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
new file mode 100644
index 0000000..30be4ac
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
@@ -0,0 +1,119 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+// This transformation rule tries to identify the PRelu structure generated by
+// Keras, and convert it to a single op.
+//
+// The formula of PReLU is:
+// f(x) = alpha * x for x < 0, f(x) = x for x >= 0.
+//
+// `x` is the input, and `alpha` is a trainable tensor which can be broadcasted
+// to the shape of `x`.
+//
+// There's no native PRelu op in TensorFlow, so Keras generates the following
+// structure which does the equivalent calculation:
+// f(x) = Relu(x) + (-alpha * Relu(-x))
+//
+// Practically, alpha is always a constant in the inference graph, and Toco has
+// other graph transformations that fold the activation functions into other
+// ops. Therefore, we're looking for the structure:
+//
+// f(x) = Relu(x) + (negative_alpha * Neg(x, activation=Relu))

+namespace toco {
+
+bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
+  const auto add_op_it = model->operators.begin() + op_index;
+  const auto* add_op = add_op_it->get();
+  if (add_op == nullptr || add_op->type != OperatorType::kAdd ||
+      add_op->inputs.size() != 2 ||
+      add_op->fused_activation_function != FusedActivationFunctionType::kNone) {
+    return false;
+  }
+
+  const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]);
+  if (relu_input_op == nullptr || relu_input_op->type != OperatorType::kRelu ||
+      relu_input_op->inputs.size() != 1 ||
+      relu_input_op->fused_activation_function !=
+          FusedActivationFunctionType::kNone) {
+    return false;
+  }
+
+  // TODO(ycling): Both Add and Mul are commutative. Support the case where
+  // the positions of the operands are exchanged.
+  const auto* mul_op = GetOpWithOutput(*model, add_op->inputs[1]);
+  if (mul_op == nullptr || mul_op->type != OperatorType::kMul ||
+      mul_op->inputs.size() != 2 ||
+      mul_op->fused_activation_function != FusedActivationFunctionType::kNone) {
+    return false;
+  }
+
+  const auto neg_alpha_tensor_name = mul_op->inputs[0];
+
+  const auto* relu_neg_input_op = GetOpWithOutput(*model, mul_op->inputs[1]);
+
+  if (relu_neg_input_op == nullptr ||
+      relu_neg_input_op->type != OperatorType::kNeg ||
+      relu_neg_input_op->fused_activation_function !=
+          FusedActivationFunctionType::kRelu ||
+      relu_neg_input_op->inputs.size() != 1) {
+    return false;
+  }
+
+  if (relu_input_op->inputs[0] != relu_neg_input_op->inputs[0]) {
+    return false;
+  }
+
+  const auto input_tensor_name = relu_input_op->inputs[0];
+  const auto output_tensor_name = add_op->outputs[0];
+
+  // Construct a tensor for positive alpha (double negative).
+  const auto alpha_tensor_name =
+      AvailableArrayName(*model, neg_alpha_tensor_name + "_neg");
+  model->GetOrCreateArray(alpha_tensor_name);
+
+  auto* neg_neg_alpha_op = new NegOperator;
+  neg_neg_alpha_op->inputs = {neg_alpha_tensor_name};
+  neg_neg_alpha_op->outputs = {alpha_tensor_name};
+  model->operators.emplace(add_op_it, neg_neg_alpha_op);
+
+  auto* prelu_op = new PReluOperator;
+  prelu_op->inputs = {input_tensor_name, alpha_tensor_name};
+  prelu_op->outputs = {output_tensor_name};
+  model->operators.emplace(add_op_it, prelu_op);
+  AddMessageF("Creating %s replacing equivalent subgraph", LogName(*prelu_op));
+
+  DeleteArrayIfUsedOnce(neg_alpha_tensor_name, model);
+  DeleteArrayIfUsedOnce(add_op->inputs[0], model);
+  DeleteArrayIfUsedOnce(add_op->inputs[1], model);
+  DeleteArrayIfUsedOnce(mul_op->inputs[1], model);
+
+  // Remove the existing Add op that outputs the final result. If the other
+  // intermediate tensors aren't used by other ops, those will be removed by
+  // other graph transformation rules.
+  model->operators.erase(FindOp(*model, add_op));
+
+  return true;
+}
+
+}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 375848a..676736c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1467,6 +1467,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
     case OperatorType::kRelu:
     case OperatorType::kRelu1:
     case OperatorType::kRelu6:
+    case OperatorType::kPRelu:
     case OperatorType::kSoftmax:
     case OperatorType::kLogSoftmax:
     case OperatorType::kLogistic:
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 3fa0089..5199e29 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -65,6 +65,7 @@ enum class OperatorType {
   kRelu,
   kRelu1,
   kRelu6,
+  kPRelu,
   kSoftmax,
   kLogSoftmax,
   kSub,
@@ -566,6 +567,18 @@ struct Relu6Operator : Operator {
   Relu6Operator() : Operator(OperatorType::kRelu6) {}
 };

+// PRelu
+//   f(x) = alpha * x for x < 0, f(x) = x for x >= 0.
+//
+// Inputs:
+//   inputs[0]: required: the input array
+//   inputs[1]: required: the alpha array
+//
+// Equivalent to keras.layers.PReLU.
+struct PReluOperator : Operator {
+  PReluOperator() : Operator(OperatorType::kPRelu) {}
+};
+
 // Element-wise Logistic operator:
 //   x -> Logistic(x) = 1 / (1 + exp(-x))
 //
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index f2cc4ef..f23249c 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -854,6 +854,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
   ops.emplace_back(
       new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1));
   ops.emplace_back(
       new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
+  ops.emplace_back(
+      new SimpleOperator<PReluOperator>("PRELU", OperatorType::kPRelu));
   ops.emplace_back(new SimpleOperator<LogisticOperator>(
       "LOGISTIC", OperatorType::kLogistic));
   ops.emplace_back(
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index ca66110..30dd6fa 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -94,6 +94,7 @@ void MakeGeneralGraphTransformationsSet(
   transformations->Add(new IdentifyL2Normalization);
   transformations->Add(new IdentifyL2Pool);
   transformations->Add(new IdentifyRelu1);
+  transformations->Add(new IdentifyPRelu);
   transformations->Add(new RemoveTrivialBinaryOperator);
   transformations->Add(new ReadFakeQuantMinMax);
   transformations->Add(new ResolveSpaceToBatchNDAttributes);
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 2362206..ec1770c 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -300,6 +300,7 @@ const char* OperatorTypeName(OperatorType type) {
     HANDLE_OPERATORTYPENAME_CASE(Relu)
     HANDLE_OPERATORTYPENAME_CASE(Relu1)
     HANDLE_OPERATORTYPENAME_CASE(Relu6)
+    HANDLE_OPERATORTYPENAME_CASE(PRelu)
    HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
    HANDLE_OPERATORTYPENAME_CASE(Softmax)
    HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
--
2.7.4
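A closing note on the IdentifyPRelu rule: it is sound because of the identity PRelu(x) = Relu(x) + (-alpha) * Relu(-x); for x >= 0 the second term is zero, and for x < 0 the first is. Below is a minimal standalone check of that identity in plain C++. The function names are illustrative, not Toco APIs.

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <initializer_list>

// Verifies the identity that IdentifyPRelu pattern-matches:
//   PRelu(x) == Relu(x) + neg_alpha * Relu(-x), where neg_alpha = -alpha.
// Standalone sketch; none of these helpers exist in the Toco code base.
float Relu(float x) { return std::max(x, 0.0f); }
float PRelu(float x, float alpha) { return x >= 0.0f ? x : alpha * x; }
float KerasSubgraph(float x, float alpha) {
  const float neg_alpha = -alpha;  // the constant Mul operand in the graph
  // Add(Relu(x), Mul(neg_alpha, Neg(x) with fused Relu activation))
  return Relu(x) + neg_alpha * Relu(-x);
}

int main() {
  for (float x = -3.0f; x <= 3.0f; x += 0.5f) {
    for (float alpha : {0.0f, 0.25f, 2.0f}) {
      assert(PRelu(x, alpha) == KerasSubgraph(x, alpha));
    }
  }
  std::printf("PRelu identity holds\n");
  return 0;
}

This also explains the extra Neg op the rule inserts: the constant feeding the Mul in Keras's folded graph is -alpha, so Toco negates it once more to hand a positive alpha to the new PRELU op's second input.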