Supports PReLU in TFLite & Toco.
author Yu-Cheng Ling <ycling@google.com>
Thu, 22 Mar 2018 18:25:49 +0000 (11:25 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 18:29:59 +0000 (11:29 -0700)
PiperOrigin-RevId: 190097557

19 files changed:
tensorflow/contrib/lite/builtin_ops.h
tensorflow/contrib/lite/kernels/activations.cc
tensorflow/contrib/lite/kernels/activations_test.cc
tensorflow/contrib/lite/kernels/register.cc
tensorflow/contrib/lite/model.cc
tensorflow/contrib/lite/nnapi_delegate.cc
tensorflow/contrib/lite/schema/schema.fbs
tensorflow/contrib/lite/schema/schema_generated.h
tensorflow/contrib/lite/testing/BUILD
tensorflow/contrib/lite/testing/generate_examples.py
tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
tensorflow/contrib/lite/toco/BUILD
tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc [new file with mode: 0644]
tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
tensorflow/contrib/lite/toco/model.h
tensorflow/contrib/lite/toco/tflite/operator.cc
tensorflow/contrib/lite/toco/toco_tooling.cc
tensorflow/contrib/lite/toco/tooling_util.cc

diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index e4652a3..d7993e6 100644
@@ -78,6 +78,7 @@ typedef enum {
   kTfLiteBuiltinDelegate = 51,
   kTfLiteBuiltinBidirectionalSequenceLstm = 52,
   kTfLiteBuiltinCast = 53,
+  kTfLiteBuiltinPrelu = 54,
 } TfLiteBuiltinOperator;
 
 #ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 093761c..39a54c9 100644
@@ -150,6 +150,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
                                TfLiteIntArrayCopy(input->dims));
 }
 
+TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* alpha = GetInput(context, node, 1);
+
+  output->type = input->type;
+
+  // Currently only Float32 is supported
+  // TODO(ycling): Support other data types.
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, alpha->type, kTfLiteFloat32);
+
+  // Currently, only a 4D `input` and a 3D `alpha` with shape
+  // (1, 1, channels) are supported.
+  // TODO(impjdi): Support other cases where `alpha` is broadcastable
+  // to `input`.
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], input->dims->data[3]);
+
+  return context->ResizeTensor(context, output,
+                               TfLiteIntArrayCopy(input->dims));
+}
+
 TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
@@ -388,6 +416,35 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
   }
 }
 
+TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* alpha = GetInput(context, node, 1);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+
+  if (input->type != kTfLiteFloat32) {
+    context->ReportError(context, "Only float32 supported currently.");
+    return kTfLiteError;
+  }
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+  const int batches = input->dims->data[0];
+  const int height = input->dims->data[1];
+  const int width = input->dims->data[2];
+  const int channels = input->dims->data[3];
+
+  TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1);
+  TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], channels);
+
+  const int n = batches * height * width * channels;
+  for (int i = 0; i < n; ++i) {
+    const float x = input->data.f[i];
+    output->data.f[i] = x >= 0.0f ? x : alpha->data.f[i % channels] * x;
+  }
+
+  return kTfLiteOk;
+}
+
 }  // namespace activations
 
 TfLiteRegistration* Register_RELU() {
@@ -439,6 +496,13 @@ TfLiteRegistration* Register_LOG_SOFTMAX() {
   return &r;
 }
 
+TfLiteRegistration* Register_PRELU() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 activations::PreluPrepare,
+                                 activations::PreluEval};
+  return &r;
+}
+
 }  // namespace builtin
 }  // namespace ops
 }  // namespace tflite
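
For reference, the flat loop in the new PreluEval is equivalent to the NumPy sketch below (illustrative only, not part of the change; prelu_reference is a hypothetical helper). The `i % channels` indexing is valid because PreluPrepare pins `alpha` to shape (1, 1, channels):

import numpy as np

def prelu_reference(x, alpha):
  # x: float32 array of shape (batch, height, width, channels).
  # alpha: float32 array of shape (1, 1, channels); broadcasting over the
  # leading dimensions matches the flat `i % channels` loop in PreluEval.
  return np.where(x >= 0.0, x, alpha * x)

x = np.array([[[[1.0, -1.0, -2.0]]]], dtype=np.float32)  # shape (1, 1, 1, 3)
alpha = np.array([[[0.0, 1.0, 2.0]]], dtype=np.float32)  # shape (1, 1, 3)
print(prelu_reference(x, alpha))  # [[[[ 1. -1. -4.]]]]
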
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index b9a96e3..50a84ed 100644
@@ -383,6 +383,49 @@ TEST(FloatActivationsOpTest, LogSoftmax) {
                               })));
 }
 
+class PReluOpModel : public SingleOpModel {
+ public:
+  PReluOpModel(const TensorData& input, const TensorData& alpha) {
+    input_ = AddInput(input);
+    alpha_ = AddInput(alpha);
+    output_ = AddOutput(input);
+    SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE, 0);
+    BuildInterpreter({GetShape(input_), GetShape(alpha_)});
+  }
+  void SetInput(std::initializer_list<float> data) {
+    PopulateTensor(input_, data);
+  }
+  void SetAlpha(std::initializer_list<float> data) {
+    PopulateTensor(alpha_, data);
+  }
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  int input_;
+  int alpha_;
+  int output_;
+};
+
+TEST(FloatActivationsOpTest, PRelu) {
+  PReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}},
+                 {TensorType_FLOAT32, {1, 1, 3}});
+
+  m.SetInput({
+      0.0f, 0.0f, 0.0f,     // Row 1, Column 1
+      1.0f, 1.0f, 1.0f,     // Row 1, Column 2
+      -1.0f, -1.0f, -1.0f,  // Row 2, Column 1
+      -2.0f, -2.0f, -2.0f,  // Row 2, Column 2
+  });
+  m.SetAlpha({0.0f, 1.0f, 2.0f});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 0.0f, 0.0f, 0.0f,    // Row 1, Column 1
+                                 1.0f, 1.0f, 1.0f,    // Row 1, Column 2
+                                 0.0f, -1.0f, -2.0f,  // Row 2, Column 1
+                                 0.0f, -2.0f, -4.0f,  // Row 2, Column 2
+                             }));
+}
+
 }  // namespace
 }  // namespace tflite
 
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 369d3b9..62045f0 100644
@@ -75,6 +75,7 @@ TfLiteRegistration* Register_TOPK_V2();
 TfLiteRegistration* Register_LOG_SOFTMAX();
 TfLiteRegistration* Register_CAST();
 TfLiteRegistration* Register_DEQUANTIZE();
+TfLiteRegistration* Register_PRELU();
 
 BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -131,6 +132,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
   AddBuiltin(BuiltinOperator_CAST, Register_CAST());
   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE());
+  AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
 
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 9c619f8..b7ccdf0 100644
@@ -309,6 +309,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_LOG_SOFTMAX:
     case BuiltinOperator_CAST:
     case BuiltinOperator_DEQUANTIZE:
+    case BuiltinOperator_PRELU:
       break;
     case BuiltinOperator_LSH_PROJECTION: {
       TfLiteLSHProjectionParams* params =
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 9d00d96..e31b7c0 100644
@@ -349,6 +349,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
       case tflite::BuiltinOperator_DEQUANTIZE:
       case tflite::BuiltinOperator_DELEGATE:
       case tflite::BuiltinOperator_CAST:
+      case tflite::BuiltinOperator_PRELU:
         FATAL("Op code %d is currently not delegated to NNAPI", builtin);
         nn_op_type = -1;  // set to invalid
         break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 04387fe..e107597 100644
@@ -130,6 +130,7 @@ enum BuiltinOperator : byte {
   DELEGATE = 51,
   BIDIRECTIONAL_SEQUENCE_LSTM = 52,
   CAST = 53,
+  PRELU = 54,
 }
 
 // Options for the builtin operators.
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index b922de2..86daeaf 100755
@@ -254,11 +254,12 @@ enum BuiltinOperator {
   BuiltinOperator_DELEGATE = 51,
   BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52,
   BuiltinOperator_CAST = 53,
+  BuiltinOperator_PRELU = 54,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_CAST
+  BuiltinOperator_MAX = BuiltinOperator_PRELU
 };
 
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[52] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[53] {
   static BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -311,7 +312,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[52] {
     BuiltinOperator_LOG_SOFTMAX,
     BuiltinOperator_DELEGATE,
     BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
-    BuiltinOperator_CAST
+    BuiltinOperator_CAST,
+    BuiltinOperator_PRELU
   };
   return values;
 }
@@ -372,6 +374,7 @@ inline const char **EnumNamesBuiltinOperator() {
     "DELEGATE",
     "BIDIRECTIONAL_SEQUENCE_LSTM",
     "CAST",
+    "PRELU",
     nullptr
   };
   return names;
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index f1b18ad..555ea90 100644
@@ -39,6 +39,7 @@ gen_zipped_test_files(
         "mean.zip",
         "mul.zip",
         "pad.zip",
+        "prelu.zip",
         "relu.zip",
         "relu1.zip",
         "relu6.zip",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 420bdb4..38de9dc 100644
@@ -617,6 +617,54 @@ def make_relu6_tests(zip_path):
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
 
+def make_prelu_tests(zip_path):
+  """Make a set of tests to do PReLU."""
+
+  test_parameters = [{
+      # The canonical case for image processing is having a 4D `input` (NHWC)
+      # and `shared_axes`=[1, 2], so the alpha parameter is per channel.
+      "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]],
+      "shared_axes": [[1, 2], [1]],
+  }]
+
+  def build_graph(parameters):
+    """Build the graph for the test case."""
+
+    input_tensor = tf.placeholder(
+        dtype=tf.float32, name="input", shape=parameters["input_shape"])
+    prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"])
+    out = prelu(input_tensor)
+    return [input_tensor], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    """Build the inputs for the test case."""
+
+    input_shape = parameters["input_shape"]
+    input_values = create_tensor_data(
+        np.float32, input_shape, min_value=-10, max_value=10)
+    shared_axes = parameters["shared_axes"]
+
+    alpha_shape = []
+    for dim in range(1, len(input_shape)):
+      alpha_shape.append(1 if dim in shared_axes else input_shape[dim])
+
+    alpha_values = create_tensor_data(np.float32, alpha_shape)
+
+    with tf.variable_scope("", reuse=True):
+      alpha = tf.get_variable("p_re_lu/alpha")
+      sess.run(alpha.assign(alpha_values))
+
+    return [input_values], sess.run(
+        outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+  make_zip_of_tests(
+      zip_path,
+      test_parameters,
+      build_graph,
+      build_inputs,
+      use_frozen_graph=True)
+
+
 # This function tests various TensorFlow functions that generate Const ops,
 # including `tf.ones`, `tf.zeros` and random functions.
 def make_constant_tests(zip_path):
@@ -1911,6 +1959,7 @@ def main(unused_args):
         "relu.zip": make_relu_tests,
         "relu1.zip": make_relu1_tests,
         "relu6.zip": make_relu6_tests,
+        "prelu.zip": make_prelu_tests,
         "l2_pool.zip": make_pool_tests(make_l2_pool),
         "avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
         "max_pool.zip": make_pool_tests(tf.nn.max_pool),
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 5e76e7c..ba2d259 100644
@@ -88,6 +88,9 @@ std::map<string, string> kBrokenTests = {
 
     // Transpose only supports 1D-4D input tensors.
     {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
+
+    // For now, PRelu only supports 4D input with a (1, 1, channels) 3D alpha.
+    {R"(^\/prelu.*shared_axes=\[1\])", "75975192"},
 };
 
 // Allows test data to be unzipped into a temporary directory and makes
@@ -253,6 +256,7 @@ INSTANTIATE_TESTS(mul)
 INSTANTIATE_TESTS(pad)
 INSTANTIATE_TESTS(relu)
 INSTANTIATE_TESTS(relu1)
+INSTANTIATE_TESTS(prelu)
 INSTANTIATE_TESTS(relu6)
 INSTANTIATE_TESTS(reshape)
 INSTANTIATE_TESTS(resize_bilinear)
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 395abc5..486ff1e 100644
@@ -193,6 +193,7 @@ cc_library(
         "graph_transformations/identify_lstm.cc",
         "graph_transformations/identify_lstm_merge_inputs.cc",
         "graph_transformations/identify_lstm_split_inputs.cc",
+        "graph_transformations/identify_prelu.cc",
         "graph_transformations/identify_relu1.cc",
         "graph_transformations/lstm_utils.cc",
         "graph_transformations/make_initial_dequantize_operator.cc",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 11e5e19..640afc7 100644
@@ -129,6 +129,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
 DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs)
 DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
 DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
 DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
 DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
 DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
new file mode 100644
index 0000000..30be4ac
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
@@ -0,0 +1,119 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+// This transformation rule tries to identify the PRelu structure generated by
+// Keras, and convert it to a single op.
+//
+// The formula of PReLU is:
+// f(x) = alpha * x for x < 0, f(x) = x for x >= 0.
+//
+// `x` is the input, and `alpha` is a trainable tensor which can be broadcasted
+// to the shape of `x`.
+//
+// There's no native PRelu op in TensorFlow, so Keras generates the following
+// structure which does the equivalent calculation:
+// f(x) = Relu(x) + (-alpha * Relu(-x))
+//
+// Practically, alpha is always a constant in the inference graph, and Toco has
+// other graph transformations that fold activation functions into other ops.
+// Therefore, we're looking for the structure:
+//
+// f(x) = Relu(x) + (negative_alpha * Neg(x, activation=Relu))
+
+namespace toco {
+
+bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
+  const auto add_op_it = model->operators.begin() + op_index;
+  const auto* add_op = add_op_it->get();
+  if (add_op == nullptr || add_op->type != OperatorType::kAdd ||
+      add_op->inputs.size() != 2 ||
+      add_op->fused_activation_function != FusedActivationFunctionType::kNone) {
+    return false;
+  }
+
+  const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]);
+  if (relu_input_op == nullptr || relu_input_op->type != OperatorType::kRelu ||
+      relu_input_op->inputs.size() != 1 ||
+      relu_input_op->fused_activation_function !=
+          FusedActivationFunctionType::kNone) {
+    return false;
+  }
+
+  // TODO(ycling): Both Add and Mul are commutative. Support the case where
+  // the position of operands are exchanged.
+  const auto* mul_op = GetOpWithOutput(*model, add_op->inputs[1]);
+  if (mul_op == nullptr || mul_op->type != OperatorType::kMul ||
+      mul_op->inputs.size() != 2 ||
+      mul_op->fused_activation_function != FusedActivationFunctionType::kNone) {
+    return false;
+  }
+
+  const auto neg_alpha_tensor_name = mul_op->inputs[0];
+
+  const auto* relu_neg_input_op = GetOpWithOutput(*model, mul_op->inputs[1]);
+
+  if (relu_neg_input_op == nullptr ||
+      relu_neg_input_op->type != OperatorType::kNeg ||
+      relu_neg_input_op->fused_activation_function !=
+          FusedActivationFunctionType::kRelu ||
+      relu_neg_input_op->inputs.size() != 1) {
+    return false;
+  }
+
+  if (relu_input_op->inputs[0] != relu_neg_input_op->inputs[0]) {
+    return false;
+  }
+
+  const auto input_tensor_name = relu_input_op->inputs[0];
+  const auto output_tensor_name = add_op->outputs[0];
+
+  // Construct a tensor for positive alpha (double negative).
+  const auto alpha_tensor_name =
+      AvailableArrayName(*model, neg_alpha_tensor_name + "_neg");
+  model->GetOrCreateArray(alpha_tensor_name);
+
+  auto* neg_neg_alpha_op = new NegOperator;
+  neg_neg_alpha_op->inputs = {neg_alpha_tensor_name};
+  neg_neg_alpha_op->outputs = {alpha_tensor_name};
+  model->operators.emplace(add_op_it, neg_neg_alpha_op);
+
+  auto* prelu_op = new PReluOperator;
+  prelu_op->inputs = {input_tensor_name, alpha_tensor_name};
+  prelu_op->outputs = {output_tensor_name};
+  model->operators.emplace(add_op_it, prelu_op);
+  AddMessageF("Creating %s replacing equivalent subgraph", LogName(*prelu_op));
+
+  DeleteArrayIfUsedOnce(neg_alpha_tensor_name, model);
+  DeleteArrayIfUsedOnce(add_op->inputs[0], model);
+  DeleteArrayIfUsedOnce(add_op->inputs[1], model);
+  DeleteArrayIfUsedOnce(mul_op->inputs[1], model);
+  // Remove the existing Add op that outputs the final result. If the other
+  // intermediate tensors aren't used by other ops, those will be removed by
+  // other graph transformation rules.
+  model->operators.erase(FindOp(*model, add_op));
+
+  return true;
+}
+
+}  // namespace toco
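
The algebraic identity that IdentifyPRelu relies on, f(x) = Relu(x) + neg_alpha * Relu(-x) with neg_alpha = -alpha, can be sanity-checked numerically. A minimal NumPy sketch (illustrative only, not part of the change):

import numpy as np

x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)
alpha = np.float32(0.25)
neg_alpha = -alpha

keras_pattern = np.maximum(x, 0.0) + neg_alpha * np.maximum(-x, 0.0)
prelu = np.where(x >= 0.0, x, alpha * x)
assert np.allclose(keras_pattern, prelu)  # identical for x < 0 and x >= 0
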
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 375848a..676736c 100644
@@ -1467,6 +1467,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
     case OperatorType::kRelu:
     case OperatorType::kRelu1:
     case OperatorType::kRelu6:
+    case OperatorType::kPRelu:
     case OperatorType::kSoftmax:
     case OperatorType::kLogSoftmax:
     case OperatorType::kLogistic:
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 3fa0089..5199e29 100644
@@ -65,6 +65,7 @@ enum class OperatorType {
   kRelu,
   kRelu1,
   kRelu6,
+  kPRelu,
   kSoftmax,
   kLogSoftmax,
   kSub,
@@ -566,6 +567,18 @@ struct Relu6Operator : Operator {
   Relu6Operator() : Operator(OperatorType::kRelu6) {}
 };
 
+// PRelu
+//   f(x) = alpha * x for x < 0, f(x) = x for x >= 0.
+//
+// Inputs:
+//   inputs[0]: required: the input array
+//   inputs[1]: required: the alpha array
+//
+// Equivalent to keras.layers.PReLU.
+struct PReluOperator : Operator {
+  PReluOperator() : Operator(OperatorType::kPRelu) {}
+};
+
 // Element-wise Logistic operator:
 //   x -> Logistic(x) = 1 / (1 + exp(-x))
 //
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index f2cc4ef..f23249c 100644
@@ -854,6 +854,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
       new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1));
   ops.emplace_back(
       new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
+  ops.emplace_back(
+      new SimpleOperator<PReluOperator>("PRELU", OperatorType::kPRelu));
   ops.emplace_back(new SimpleOperator<LogisticOperator>(
       "LOGISTIC", OperatorType::kLogistic));
   ops.emplace_back(
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index ca66110..30dd6fa 100644
@@ -94,6 +94,7 @@ void MakeGeneralGraphTransformationsSet(
   transformations->Add(new IdentifyL2Normalization);
   transformations->Add(new IdentifyL2Pool);
   transformations->Add(new IdentifyRelu1);
+  transformations->Add(new IdentifyPRelu);
   transformations->Add(new RemoveTrivialBinaryOperator);
   transformations->Add(new ReadFakeQuantMinMax);
   transformations->Add(new ResolveSpaceToBatchNDAttributes);
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 2362206..ec1770c 100644
@@ -300,6 +300,7 @@ const char* OperatorTypeName(OperatorType type) {
     HANDLE_OPERATORTYPENAME_CASE(Relu)
     HANDLE_OPERATORTYPENAME_CASE(Relu1)
     HANDLE_OPERATORTYPENAME_CASE(Relu6)
+    HANDLE_OPERATORTYPENAME_CASE(PRelu)
     HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
     HANDLE_OPERATORTYPENAME_CASE(Softmax)
     HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)