#include "IR/TFAdd.h"
#include "IR/TFAvgPool.h"
#include "IR/TFBiasAdd.h"
+#include "IR/TFConcatV2.h"
#include "IR/TFConst.h"
#include "IR/TFConv2D.h"
#include "IR/TFDepthwiseConv2dNative.h"
// Registry of supported TensorFlow ops: maps a GraphDef op name (first
// argument) to the moco IR node class (second argument) that represents it.
// Entries are kept sorted alphabetically by op name.
TENSORFLOW_NODE(Add, TFAdd)
TENSORFLOW_NODE(AvgPool, TFAvgPool)
TENSORFLOW_NODE(BiasAdd, TFBiasAdd)
TENSORFLOW_NODE(ConcatV2, TFConcatV2)
TENSORFLOW_NODE(Const, TFConst)
TENSORFLOW_NODE(Conv2D, TFConv2D)
TENSORFLOW_NODE(DepthwiseConv2dNative, TFDepthwiseConv2dNative)
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFCONCATV2_H__
+#define __MOCO_TF_IR_TFCONCATV2_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+
+/// @note TFConcatV2 corresponds to the following GraphDef
+/*
+node {
+ name: "Concat"
+ op: "ConcatV2"
+ input: "Input01"
+ input: "Input02"
+ input: "Axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+*/
+
+/**
+ * @note As there is no VariableArityNode for now, we will import ConcatV2
+ * as cascading of multiple TFConcatV2 nodes like loco::TensorConcat
+ */
+
+class TFConcatV2 final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::ConcatV2>>
+{
+public:
+ TFConcatV2() = default;
+
+public:
+ Node *lhs(void) const { return at(0)->node(); }
+ void lhs(Node *node) { at(0)->node(node); }
+
+ Node *rhs(void) const { return at(1)->node(); }
+ void rhs(Node *node) { at(1)->node(node); }
+
+public:
+ uint32_t axis(void) const { return _axis; }
+ void axis(uint32_t val) { _axis = val; }
+
+private:
+ // Axis
+ uint32_t _axis{0};
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFCONCATV2_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFConcatV2.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFConcatV2Test, constructor)
+{
+ moco::tf::TFConcatV2 concatv2_node;
+
+ ASSERT_EQ(concatv2_node.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(concatv2_node.opcode(), moco::tf::TFOpcode::ConcatV2);
+
+ ASSERT_EQ(concatv2_node.lhs(), nullptr);
+ ASSERT_EQ(concatv2_node.rhs(), nullptr);
+ ASSERT_EQ(concatv2_node.axis(), 0);
+}
return false;
}
+bool fix_padding(moco::tf::TFConcatV2 *node)
+{
+ // Nothing to do with padding
+ return false;
+}
+
bool fix_padding(moco::tf::TFConst *node)
{
// Nothing to do with padding
return copy_shapedata(value, node);
}
/**
 * @brief Shape inference for TFConcatV2
 *
 * Consumes the ConcatData annotation left by the importer (holding the raw
 * TF axis, possibly negative), normalizes the axis onto the node, and
 * attaches a ShapeInferenceData annotation describing the output shape.
 *
 * @return true when shape data was attached; false when nothing changed
 *         (either already processed, or an input's shape is not yet known)
 */
bool fix_shape(moco::tf::TFConcatV2 *node)
{
  auto concat_data = node->annot<ConcatData>();
  if (concat_data == nullptr)
  {
    // shape inference is already done for TFConcatV2
    assert(node->annot<ShapeInferenceData>() != nullptr);
    return false;
  }
  // ConcatData still present => this node has not been processed yet
  assert(node->annot<ShapeInferenceData>() == nullptr);

  auto lhs = node->lhs();
  auto rhs = node->rhs();
  auto lhs_shapedata = lhs->annot<ShapeInferenceData>();
  auto rhs_shapedata = rhs->annot<ShapeInferenceData>();
  if (lhs_shapedata == nullptr || rhs_shapedata == nullptr)
  {
    // postpone as previous input node(s) hasn't been processed.
    // this will return false as there was nothing changed, but the input of this
    // node should be changed and from that this method should be called again.
    // if not, the network may have some problem and the final output may not have
    // the right shape value and we can identify with some validation at final stage.
    return false;
  }

  // Both inputs must agree on rank for concatenation to be meaningful
  uint32_t lhs_rank = lhs_shapedata->rank();
  uint32_t rhs_rank = rhs_shapedata->rank();
  assert(lhs_rank == rhs_rank);

  // TF allows a negative axis counting from the end; fold it into [0, rank)
  int32_t axis_tf = concat_data->axis();
  if (axis_tf < 0)
  {
    axis_tf = static_cast<int32_t>(lhs_rank) + axis_tf;
  }
  assert(0 <= axis_tf && axis_tf < static_cast<int32_t>(lhs_rank));
  // clear annotation ConcatData (marks this node as processed)
  node->annot<ConcatData>(nullptr);

  uint32_t axis_loco = static_cast<uint32_t>(axis_tf);
  node->axis(axis_loco);

  // Set ShapeInferenceData for TensorConcat: dims match the inputs except on
  // the concat axis, where the two extents are summed.
  auto shape_data = stdex::make_unique<ShapeInferenceData>();
  shape_data->rank(lhs_rank);
  for (uint32_t index = 0; index < lhs_rank; ++index)
  {
    // NOTE(review): dim(index).value() is read before checking known();
    // confirm loco::Dimension::value() is safe (or zero) on unknown dims.
    uint32_t lhs_dim = lhs_shapedata->dim(index).value();
    uint32_t rhs_dim = rhs_shapedata->dim(index).value();
    // "lhs_dim == rhs_dim" should hold when "index != axis_loco"
    // or doesn't care when "index == axis_loco"
    assert(index == axis_loco || lhs_dim == rhs_dim);

    uint32_t new_dim = (index == axis_loco) ? lhs_dim + rhs_dim : lhs_dim;

    // NOTE(review): only the lhs dim's known() flag is consulted here; if the
    // rhs dim on the concat axis is unknown, the sum may be wrong — verify
    // whether rhs known() should also gate this branch.
    if (lhs_shapedata->dim(index).known())
      shape_data->dim(index) = new_dim;
    else
      shape_data->dim(index).unset();
  }
  node->annot(std::move(shape_data));

  return true;
}
+
bool fix_shape(moco::tf::TFConst *node)
{
auto shapedata = node->annot<ShapeInferenceData>();