#include "IR/TFMaxPool.h"
#include "IR/TFMean.h"
#include "IR/TFMul.h"
+#include "IR/TFPad.h"
#include "IR/TFRealDiv.h"
#include "IR/TFRelu.h"
#include "IR/TFRelu6.h"
TENSORFLOW_NODE(MaxPool, TFMaxPool)
TENSORFLOW_NODE(Mean, TFMean)
TENSORFLOW_NODE(Mul, TFMul)
+TENSORFLOW_NODE(Pad, TFPad)
TENSORFLOW_NODE(RealDiv, TFRealDiv)
TENSORFLOW_NODE(Relu, TFRelu)
TENSORFLOW_NODE(Relu6, TFRelu6)
return moco::tf::as_tensor_shape(output_feature_shape, node->data_layout());
}
+ loco::NodeShape visit(const moco::tf::TFPad *node) final
+ {
+ auto input_shape = node_shape(node->input());
+ assert(input_shape.domain() == loco::Domain::Tensor);
+
+ auto const_paddings = dynamic_cast<moco::tf::TFConst *>(node->paddings());
+ assert(const_paddings);
+ assert(const_paddings->dtype() == loco::DataType::S32);
+ assert(const_paddings->rank() == 2);
+
+ loco::TensorShape input_tensor_shape = input_shape.as<loco::TensorShape>();
+ loco::TensorShape output_tensor_shape;
+
+ output_tensor_shape.rank(input_tensor_shape.rank());
+ for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
+ {
+ output_tensor_shape.dim(axis) = input_tensor_shape.dim(axis).value() +
+ const_paddings->at<loco::DataType::S32>(axis * 2) +
+ const_paddings->at<loco::DataType::S32>(axis * 2 + 1);
+ }
+
+ return loco::NodeShape{output_tensor_shape};
+ }
+
public:
loco::NodeShape visit(const moco::tf::TFNode *node) final
{
loco::DataType visit(const TFMaxPool *node) { return dtype_get(node->value()); }
loco::DataType visit(const TFMean *node) { return dtype_get(node->input()); }
loco::DataType visit(const TFMul *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFPad *node) { return dtype_get(node->input()); }
loco::DataType visit(const TFRealDiv *node) { return dtype_get(node->x()); }
loco::DataType visit(const TFRelu *node) { return dtype_get(node->features()); }
loco::DataType visit(const TFRelu6 *node) { return dtype_get(node->features()); }
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_IR_TFPAD_H__
+#define __MOCO_TF_IR_TFPAD_H__
+
+#include "Dialect/TFNodeDecl.h"
+
+namespace moco
+{
+namespace tf
+{
+/// @note TFPad corresponds to the following GraphDef
+/*
+node {
+ name: "Pad"
+ op: "Pad"
+ input: "Const_tensor"
+ input: "Const_paddings"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tpaddings"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+*/
+
+class TFPad final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Pad>>
+{
+public:
+ TFPad() = default;
+
+public:
+ Node *input(void) const { return at(0)->node(); }
+ void input(Node *node) { at(0)->node(node); }
+
+ Node *paddings(void) const { return at(1)->node(); }
+ void paddings(Node *node) { at(1)->node(node); }
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_IR_TFPAD_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IR/TFPad.h"
+
+#include "Dialect/TFDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(TFPadTest, constructor)
+{
+ moco::tf::TFPad pad;
+
+ ASSERT_EQ(pad.dialect(), moco::tf::TFDialect::get());
+ ASSERT_EQ(pad.opcode(), moco::tf::TFOpcode::Pad);
+
+ ASSERT_EQ(pad.input(), nullptr);
+ ASSERT_EQ(pad.paddings(), nullptr);
+}
return true;
}
+bool fix_shape(moco::tf::TFPad *node) { return false; }
+
bool fix_shape(moco::tf::TFRealDiv *node)
{
auto x = node->x();