#include "IR/TFFusedBatchNorm.h"
#include "IR/TFIdentity.h"
#include "IR/TFMaxPool.h"
+#include "IR/TFMean.h"
#include "IR/TFMul.h"
#include "IR/TFRealDiv.h"
#include "IR/TFRelu.h"
TENSORFLOW_NODE(FusedBatchNorm, TFFusedBatchNorm)
TENSORFLOW_NODE(Identity, TFIdentity)
TENSORFLOW_NODE(MaxPool, TFMaxPool)
+TENSORFLOW_NODE(Mean, TFMean)
TENSORFLOW_NODE(Mul, TFMul)
TENSORFLOW_NODE(RealDiv, TFRealDiv)
TENSORFLOW_NODE(Relu, TFRelu)
loco::DataType visit(const TFFusedBatchNorm *node) { return dtype_get(node->input()); }
loco::DataType visit(const TFIdentity *node) { return dtype_get(node->input()); }
loco::DataType visit(const TFMaxPool *node) { return dtype_get(node->value()); }
+ loco::DataType visit(const TFMean *node) { return dtype_get(node->input()); }
loco::DataType visit(const TFMul *node) { return dtype_get(node->x()); }
loco::DataType visit(const TFRealDiv *node) { return dtype_get(node->x()); }
loco::DataType visit(const TFRelu *node) { return dtype_get(node->features()); }
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFMEAN_H__
+#define __MOCO_TF_IR_TFMEAN_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+#include <vector>
+
+namespace moco
+{
+namespace tf
+{
+
+class TFMean final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Mean>>
+{
+public:
+ TFMean() = default;
+
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+
+ Node *reduction_indices(void) const { return at(1)->node(); }
+ void reduction_indices(Node *node) { at(1)->node(node); }
+
+public:
+ bool keep_dims(void) const { return _keep_dims; }
+ void keep_dims(bool keep_dims) { _keep_dims = keep_dims; }
+
+private:
+ bool _keep_dims = false;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFMEAN_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFMean.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFMeanTest, constructor)
+{
+ moco::tf::TFMean mean_node;
+
+ ASSERT_EQ(mean_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(mean_node.opcode(), moco::tf::TFOpcode::Mean);
+
+ ASSERT_EQ(mean_node.input(), nullptr);
+ ASSERT_EQ(mean_node.reduction_indices(), nullptr);
+ ASSERT_EQ(mean_node.keep_dims(), false);
+}
return copy_shapedata(x, node);
}
+bool fix_shape(moco::tf::TFMean *node)
+{
+ if (shape_inference_done(node))
+ return false;
+
+ LOGGER(l);
+
+ auto input = node->input();
+ auto reduction_indices = node->reduction_indices();
+ loco::NodeShape input_shape;
+ loco::NodeShape reduction_indices_shape;
+
+ if (!node_shape(input, input_shape) || !node_shape(reduction_indices, reduction_indices_shape))
+ {
+ // Input and reduction_indices shape are required for TFMean shape inference
+ return false;
+ }
+
+ // Get constant values if reduction_indeces is const
+ std::vector<int32_t> reduction_values;
+ if (auto tfconst = dynamic_cast<moco::tf::TFConst *>(reduction_indices))
+ {
+ assert(tfconst->dtype() == loco::DataType::S32);
+ auto const_size = tfconst->size<loco::DataType::S32>();
+ for (uint32_t i = 0; i < const_size; ++i)
+ {
+ int32_t axis = tfconst->at<loco::DataType::S32>(i);
+ if (axis < 0)
+ axis += input_shape.as<loco::TensorShape>().rank();
+ reduction_values.push_back(axis);
+ }
+ }
+ else
+ {
+ // we cannot find a valid reduction indices value
+ INFO(l) << "Fix shape TFMean fail : reduction indeces are not constant or not valid";
+ return false;
+ }
+
+ loco::TensorShape shape_data;
+ loco::TensorShape input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+ if (node->keep_dims())
+ {
+ shape_data.rank(input_tensor_shape.rank());
+ for (uint32_t i = 0; i < input_tensor_shape.rank(); ++i)
+ shape_data.dim(i) = input_tensor_shape.dim(i);
+ for (uint32_t i = 0; i < reduction_values.size(); ++i)
+ shape_data.dim(reduction_values.at(i)) = 1;
+ }
+ else
+ {
+ std::vector<bool> check_reduce(input_tensor_shape.rank(), false);
+ for (uint32_t i = 0; i < reduction_values.size(); ++i)
+ check_reduce.at(reduction_values.at(i)) = true;
+
+ uint32_t reduce_cnt = 0;
+ for (uint32_t i = 0; i < check_reduce.size(); ++i)
+ if (check_reduce.at(i))
+ ++reduce_cnt;
+
+ shape_data.rank(input_tensor_shape.rank() - reduce_cnt);
+ for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
+ if (check_reduce.at(i) == false)
+ shape_data.dim(j++) = i;
+ }
+
+ auto shape_annot = stdex::make_unique<ShapeInferenceData>();
+ shape_annot->tensor_shape(shape_data);
+ node->annot(std::move(shape_annot));
+
+ return true;
+}
+
bool fix_shape(moco::tf::TFRealDiv *node)
{
auto x = node->x();