Make TFLite Mean op have parity with TF Reduce Mean op by changing the
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 29 Jan 2018 19:47:02 +0000 (11:47 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 29 Jan 2018 19:51:04 +0000 (11:51 -0800)
representation of axis from an attribute to a tensor.

PiperOrigin-RevId: 183701017

tensorflow/contrib/lite/builtin_op_data.h
tensorflow/contrib/lite/kernels/mean.cc
tensorflow/contrib/lite/kernels/mean_test.cc
tensorflow/contrib/lite/model.cc
tensorflow/contrib/lite/schema/schema.fbs
tensorflow/contrib/lite/schema/schema_generated.h
tensorflow/contrib/lite/testing/generate_examples.py
tensorflow/contrib/lite/toco/tflite/operator.cc
tensorflow/contrib/lite/toco/tflite/operator_test.cc

index 6dd9cb3..7a7e20a 100644 (file)
@@ -199,10 +199,6 @@ typedef struct {
 } TfLiteTransposeParams;
 
 typedef struct {
-  // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
-  // For now we will fix the maximum possible number of dimensions.
-  int axis[8];
-  int num_axis_dimensions;
   bool keep_dims;
 } TfLiteMeanParams;
 
index 540e5a3..ec1c402 100644 (file)
@@ -35,10 +35,12 @@ struct MeanContext {
   MeanContext(TfLiteContext* context, TfLiteNode* node) {
     params = reinterpret_cast<TfLiteMeanParams*>(node->builtin_data);
     input = GetInput(context, node, 0);
+    axis = GetInput(context, node, 1);
     output = GetOutput(context, node, 0);
   }
   TfLiteMeanParams* params;
   TfLiteTensor* input;
+  TfLiteTensor* axis;
   TfLiteTensor* output;
 };
 
@@ -54,45 +56,26 @@ void Free(TfLiteContext* context, void* buffer) {
   delete reinterpret_cast<int*>(buffer);
 }
 
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
-  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-
-  MeanContext op_context(context, node);
-  int input_num_dims = NumDimensions(op_context.input);
-  int axis_num_dims = op_context.params->num_axis_dimensions;
-
-  // Creates a temp index to iterate through input data.
-  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
-  TfLiteIntArrayFree(node->temporaries);
-  node->temporaries = TfLiteIntArrayCreate(2);
-  node->temporaries->data[0] = *scratch_tensor_index;
-  TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]];
-  scratch_tensor->type = kTfLiteInt32;
-  scratch_tensor->allocation_type = kTfLiteArenaRw;
-  TfLiteIntArray* index_size = TfLiteIntArrayCreate(1);
-  index_size->data[0] = input_num_dims;
-  TF_LITE_ENSURE_OK(context,
-                    context->ResizeTensor(context, scratch_tensor, index_size));
-
-  // Creates a temp tensor to store resolved axis given input data.
-  node->temporaries->data[1] = *scratch_tensor_index + 1;
-  TfLiteTensor* axis_tensor = &context->tensors[node->temporaries->data[1]];
-  axis_tensor->type = kTfLiteInt32;
-  axis_tensor->allocation_type = kTfLiteArenaRw;
+// Resizes the temp tensor that stores resolved axis.
+TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context,
+                            TfLiteTensor* resolved_axis) {
   TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1);
-  axis_size->data[0] = op_context.params->num_axis_dimensions;
-  TF_LITE_ENSURE_OK(context,
-                    context->ResizeTensor(context, axis_tensor, axis_size));
+  axis_size->data[0] = static_cast<int>(NumElements(op_context->axis));
+  return context->ResizeTensor(context, resolved_axis, axis_size);
+}
 
-  // Determines size of output tensor.
-  const TfLiteIntArray* input_dims = op_context.input->dims;
-  const int* axis = op_context.params->axis;
-  if (op_context.params->keep_dims) {
+// Resizes output array based on the input size and resolved axis.
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+                                MeanContext* op_context) {
+  size_t num_axis = NumElements(op_context->axis);
+  const TfLiteIntArray* input_dims = op_context->input->dims;
+  int input_num_dims = NumDimensions(op_context->input);
+  const int* axis = GetTensorData<int>(op_context->axis);
+  if (op_context->params->keep_dims) {
     TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims);
     for (int idx = 0; idx < input_num_dims; ++idx) {
       bool is_axis = false;
-      for (int axis_idx = 0; axis_idx < axis_num_dims; ++axis_idx) {
+      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
         if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
           is_axis = true;
           break;
@@ -104,11 +87,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
         output_dims->data[idx] = input_dims->data[idx];
       }
     }
-    return context->ResizeTensor(context, op_context.output, output_dims);
+    return context->ResizeTensor(context, op_context->output, output_dims);
   } else {
     // Calculates size of reducing axis.
-    int num_reduce_axis = axis_num_dims;
-    for (int i = 0; i < axis_num_dims; ++i) {
+    int num_reduce_axis = num_axis;
+    for (int i = 0; i < num_axis; ++i) {
       int current = axis[i];
       if (current < 0) {
         current += input_num_dims;
@@ -131,7 +114,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     int num_skip_axis = 0;
     for (int idx = 0; idx < input_num_dims; ++idx) {
       bool is_axis = false;
-      for (int axis_idx = 0; axis_idx < axis_num_dims; ++axis_idx) {
+      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
         if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
           ++num_skip_axis;
           is_axis = true;
@@ -142,24 +125,76 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
         output_dims->data[idx - num_skip_axis] = input_dims->data[idx];
       }
     }
-    return context->ResizeTensor(context, op_context.output, output_dims);
+    return context->ResizeTensor(context, op_context->output, output_dims);
+  }
+}
+
+// Initializes temp tensors to store index and resolved axis.
+TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
+                                   MeanContext* op_context) {
+  // Creates a temp index to iterate through input data.
+  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+  TfLiteIntArrayFree(node->temporaries);
+  node->temporaries = TfLiteIntArrayCreate(2);
+  node->temporaries->data[0] = *scratch_tensor_index;
+  TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]];
+  scratch_tensor->type = kTfLiteInt32;
+  scratch_tensor->allocation_type = kTfLiteArenaRw;
+  TfLiteIntArray* index_size = TfLiteIntArrayCreate(1);
+  index_size->data[0] = NumDimensions(op_context->input);
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, scratch_tensor, index_size));
+
+  // Creates a temp tensor to store resolved axis given input data.
+  node->temporaries->data[1] = *scratch_tensor_index + 1;
+  TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]];
+  resolved_axis->type = kTfLiteInt32;
+  return kTfLiteOk;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  MeanContext op_context(context, node);
+  TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));
+
+  TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]];
+  // Leaves work to Eval if axis is not constant; else resizes output.
+  if (!IsConstantTensor(op_context.axis)) {
+    SetTensorToDynamic(op_context.output);
+    SetTensorToDynamic(resolved_axis);
+    return kTfLiteOk;
   }
+  resolved_axis->allocation_type = kTfLiteArenaRw;
+  TF_LITE_ENSURE_OK(context,
+                    ResizeTempAxis(context, &op_context, resolved_axis));
+  return ResizeOutputTensor(context, &op_context);
 }
 
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   MeanContext op_context(context, node);
+  int num_axis = static_cast<int>(NumElements(op_context.axis));
   TfLiteTensor* temp_index = &context->tensors[node->temporaries->data[0]];
   TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]];
+  // Resize the output tensor if the output tensor is dynamic.
+  if (IsDynamicTensor(op_context.output)) {
+    TF_LITE_ENSURE_OK(context,
+                      ResizeTempAxis(context, &op_context, resolved_axis));
+    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+    TfLiteTensorRealloc(resolved_axis->bytes, resolved_axis);
+    TfLiteTensorRealloc(op_context.output->bytes, op_context.output);
+  }
 
-#define TF_LITE_MEAN(kernel_type, data_type)                           \
-  kernel_type::Mean<>(                                                 \
-      GetTensorData<data_type>(op_context.input),                      \
-      op_context.input->dims->data, op_context.input->dims->size,      \
-      GetTensorData<data_type>(op_context.output),                     \
-      op_context.output->dims->data, op_context.output->dims->size,    \
-      op_context.params->axis, op_context.params->num_axis_dimensions, \
-      op_context.params->keep_dims, GetTensorData<int>(temp_index),    \
+#define TF_LITE_MEAN(kernel_type, data_type)                        \
+  kernel_type::Mean<>(                                              \
+      GetTensorData<data_type>(op_context.input),                   \
+      op_context.input->dims->data, op_context.input->dims->size,   \
+      GetTensorData<data_type>(op_context.output),                  \
+      op_context.output->dims->data, op_context.output->dims->size, \
+      GetTensorData<int>(op_context.axis), num_axis,                \
+      op_context.params->keep_dims, GetTensorData<int>(temp_index), \
       GetTensorData<int>(resolved_axis))
 
   if (kernel_type == kReference) {
index 4305c06..c4c53c2 100644 (file)
@@ -25,58 +25,108 @@ using ::testing::ElementsAreArray;
 
 class BaseMeanOpModel : public SingleOpModel {
  public:
-  BaseMeanOpModel(const TensorData& input, const TensorData& output,
-                  std::initializer_list<int> axis, bool keep_dims) {
-    input_ = AddInput(input);
-    output_ = AddOutput(output);
-    SetBuiltinOp(
-        BuiltinOperator_MEAN, BuiltinOptions_MeanOptions,
-        CreateMeanOptions(builder_, builder_.CreateVector<int>(axis), keep_dims)
-            .Union());
-    BuildInterpreter({GetShape(input_)});
+  void SetAxis(std::initializer_list<int> data) { PopulateTensor(axis_, data); }
+
+  template <class T>
+  void SetInput(std::initializer_list<T> data) {
+    PopulateTensor(input_, data);
   }
 
-  int input() { return input_; }
+  template <class T>
+  std::vector<T> GetOutput() {
+    return ExtractVector<T>(output_);
+  }
+
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
 
  protected:
   int input_;
+  int axis_;
   int output_;
 };
 
-class FloatMeanOpModel : public BaseMeanOpModel {
+// Model for the tests case where axis is a const tensor.
+class MeanOpConstModel : public BaseMeanOpModel {
  public:
-  using BaseMeanOpModel::BaseMeanOpModel;
-
-  void SetInput(std::initializer_list<float> data) {
-    PopulateTensor(input_, data);
+  MeanOpConstModel(const TensorData& input, const TensorData& output,
+                   std::initializer_list<int> axis_shape,
+                   std::initializer_list<int> axis, bool keep_dims) {
+    input_ = AddInput(input);
+    axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions,
+                 CreateMeanOptions(builder_, keep_dims).Union());
+    BuildInterpreter({GetShape(input_)});
   }
+};
 
-  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
-  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+// Model for the tests case where axis is a dynamic tensor.
+class MeanOpDynamicModel : public BaseMeanOpModel {
+ public:
+  MeanOpDynamicModel(const TensorData& input, const TensorData& output,
+                     const TensorData& axis, bool keep_dims) {
+    input_ = AddInput(input);
+    axis_ = AddInput(axis);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions,
+                 CreateMeanOptions(builder_, keep_dims).Union());
+    BuildInterpreter({GetShape(input_)});
+  }
 };
 
-TEST(FloatMeanOpTest, NotKeepDims) {
+TEST(ConstMeanOpTest, NotKeepDims) {
+  std::initializer_list<float> data = {
+      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
+      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+  MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
+                     {4}, {1, 0, -3, -3}, false);
+  m.SetInput(data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+  EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({12, 13})));
+}
+
+TEST(ConstMeanOpTest, KeepDims) {
+  std::initializer_list<float> data = {
+      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
+      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+  MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
+                     {2}, {0, 2}, true);
+  m.SetInput(data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+  EXPECT_THAT(m.GetOutput<float>(),
+              ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
+}
+
+TEST(DynamicMeanOpTest, NotKeepDims) {
   std::initializer_list<float> data = {
       1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
       13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
-  FloatMeanOpModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
-                     {1, 0, -3, -3}, false);
+  MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+                       {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
+                       false);
+  std::initializer_list<int> axis = {1, 0, -3, -3};
+  m.SetAxis(axis);
   m.SetInput(data);
   m.Invoke();
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
-  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13})));
+  EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({12, 13})));
 }
 
-TEST(FloatMeanOpTest, KeepDims) {
+TEST(DynamicMeanOpTest, KeepDims) {
   std::initializer_list<float> data = {
       1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
       13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
-  FloatMeanOpModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
-                     {0, 2}, true);
+  MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+                       {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}},
+                       true);
+  std::initializer_list<int> axis = {0, 2};
+  m.SetAxis(axis);
   m.SetInput(data);
   m.Invoke();
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
 }
 
index 64b8f55..c82ae27 100644 (file)
@@ -544,11 +544,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_MEAN: {
       auto* params = MallocPOD<TfLiteMeanParams>();
       if (auto* schema_params = op->builtin_options_as_MeanOptions()) {
-        const auto& axis = schema_params->axis();
-        FlatBufferIntVectorToArray(sizeof(params->axis), axis, params->axis,
-                                   error_reporter);
         params->keep_dims = schema_params->keep_dims();
-        params->num_axis_dimensions = axis->Length();
       }
       builtin_data = reinterpret_cast<void*>(params);
       break;
index f621756..91eac2a 100644 (file)
@@ -333,7 +333,6 @@ table TransposeOptions {
 }
 
 table MeanOptions {
-  axis:[int];
   keep_dims: bool;
 }
 
index ad75852..a8370b3 100755 (executable)
@@ -3388,21 +3388,16 @@ flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
 
 struct MeanOptionsT : public flatbuffers::NativeTable {
   typedef MeanOptions TableType;
-  std::vector<int32_t> axis;
   bool keep_dims;
   MeanOptionsT() : keep_dims(false) {}
 };
 
 struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   typedef MeanOptionsT NativeTableType;
-  enum { VT_AXIS = 4, VT_KEEP_DIMS = 6 };
-  const flatbuffers::Vector<int32_t> *axis() const {
-    return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_AXIS);
-  }
+  enum { VT_KEEP_DIMS = 4 };
   bool keep_dims() const { return GetField<uint8_t>(VT_KEEP_DIMS, 0) != 0; }
   bool Verify(flatbuffers::Verifier &verifier) const {
-    return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_AXIS) &&
-           verifier.Verify(axis()) &&
+    return VerifyTableStart(verifier) &&
            VerifyField<uint8_t>(verifier, VT_KEEP_DIMS) && verifier.EndTable();
   }
   MeanOptionsT *UnPack(
@@ -3418,9 +3413,6 @@ struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
 struct MeanOptionsBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
-  void add_axis(flatbuffers::Offset<flatbuffers::Vector<int32_t>> axis) {
-    fbb_.AddOffset(MeanOptions::VT_AXIS, axis);
-  }
   void add_keep_dims(bool keep_dims) {
     fbb_.AddElement<uint8_t>(MeanOptions::VT_KEEP_DIMS,
                              static_cast<uint8_t>(keep_dims), 0);
@@ -3438,22 +3430,12 @@ struct MeanOptionsBuilder {
 };
 
 inline flatbuffers::Offset<MeanOptions> CreateMeanOptions(
-    flatbuffers::FlatBufferBuilder &_fbb,
-    flatbuffers::Offset<flatbuffers::Vector<int32_t>> axis = 0,
-    bool keep_dims = false) {
+    flatbuffers::FlatBufferBuilder &_fbb, bool keep_dims = false) {
   MeanOptionsBuilder builder_(_fbb);
-  builder_.add_axis(axis);
   builder_.add_keep_dims(keep_dims);
   return builder_.Finish();
 }
 
-inline flatbuffers::Offset<MeanOptions> CreateMeanOptionsDirect(
-    flatbuffers::FlatBufferBuilder &_fbb,
-    const std::vector<int32_t> *axis = nullptr, bool keep_dims = false) {
-  return tflite::CreateMeanOptions(
-      _fbb, axis ? _fbb.CreateVector<int32_t>(*axis) : 0, keep_dims);
-}
-
 flatbuffers::Offset<MeanOptions> CreateMeanOptions(
     flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o,
     const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -6040,15 +6022,6 @@ inline void MeanOptions::UnPackTo(
   (void)_o;
   (void)_resolver;
   {
-    auto _e = axis();
-    if (_e) {
-      _o->axis.resize(_e->size());
-      for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
-        _o->axis[_i] = _e->Get(_i);
-      }
-    }
-  };
-  {
     auto _e = keep_dims();
     _o->keep_dims = _e;
   };
@@ -6071,9 +6044,8 @@ inline flatbuffers::Offset<MeanOptions> CreateMeanOptions(
     const flatbuffers::rehasher_function_t *__rehasher;
   } _va = {&_fbb, _o, _rehasher};
   (void)_va;
-  auto _axis = _o->axis.size() ? _fbb.CreateVector(_o->axis) : 0;
   auto _keep_dims = _o->keep_dims;
-  return tflite::CreateMeanOptions(_fbb, _axis, _keep_dims);
+  return tflite::CreateMeanOptions(_fbb, _keep_dims);
 }
 
 inline SqueezeOptionsT *SqueezeOptions::UnPack(
index 4ae6ccb..fc8149b 100644 (file)
@@ -695,6 +695,7 @@ def make_mean_tests(zip_path):
           [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0],
           [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3]
       ],
+      "const_axis": [True, False],
       "keep_dims": [True, False],
   }, {
       "input_dtype": [tf.float32, tf.int32, tf.int64],
@@ -705,6 +706,7 @@ def make_mean_tests(zip_path):
           -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2],
           [2, 2, 3], [-3, -3, -4], [-3, 2, 1]
       ],
+      "const_axis": [True, False],
       "keep_dims": [True, False],
   }]
 
@@ -714,17 +716,31 @@ def make_mean_tests(zip_path):
         dtype=parameters["input_dtype"],
         name="input",
         shape=parameters["input_shape"])
+
+    # Get axis as either a placeholder or constants.
+    if parameters["const_axis"]:
+      axis = parameters["axis"]
+      input_tensors = [input_tensor]
+    else:
+      if isinstance(parameters["axis"], list):
+        shape = [len(parameters["axis"])]
+      else:
+        shape = [0]  # shape for None or integers.
+      axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape)
+      input_tensors = [input_tensor, axis]
+
     out = tf.reduce_mean(
-        input_tensor,
-        axis=parameters["axis"],
-        keep_dims=parameters["keep_dims"])
-    return [input_tensor], [out]
+        input_tensor, axis=axis, keep_dims=parameters["keep_dims"])
+    return input_tensors, [out]
 
   def build_inputs(parameters, sess, inputs, outputs):
-    input_values = create_tensor_data(parameters["input_dtype"],
-                                      parameters["input_shape"])
-    return [input_values], sess.run(
-        outputs, feed_dict=dict(zip(inputs, [input_values])))
+    values = [
+        create_tensor_data(parameters["input_dtype"], parameters["input_shape"])
+    ]
+    if not parameters["const_axis"]:
+      if parameters["axis"]:
+        values.append(np.array(parameters["axis"]))
+    return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
 
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
index 853a9f4..e33a578 100644 (file)
@@ -548,14 +548,11 @@ class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
   flatbuffers::Offset<TfLiteOptions> WriteOptions(
       const TocoOperator& op,
       flatbuffers::FlatBufferBuilder* builder) const override {
-    auto axis = builder->CreateVector(op.axis);
-    return ::tflite::CreateMeanOptions(*builder, axis, op.keep_dims);
+    return ::tflite::CreateMeanOptions(*builder, op.keep_dims);
   }
 
   void ReadOptions(const TfLiteOptions& options,
                    TocoOperator* op) const override {
-    op->axis.insert(op->axis.end(), options.axis()->begin(),
-                    options.axis()->end());
     op->keep_dims = options.keep_dims();
   }
 };
index 4df9071..b4ec7bb 100644 (file)
@@ -134,12 +134,10 @@ TEST_F(OperatorTest, BuiltinSpaceToBatchND) {
 
 TEST_F(OperatorTest, BuiltinMean) {
   MeanOperator op;
-  op.axis = {1, 2};
   op.keep_dims = false;
 
   auto output_toco_op =
       SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op);
-  EXPECT_EQ(op.axis, output_toco_op->axis);
   EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
 }