* limitations under the License.
*/
+#include "Conv2D.h"
+
#include "Convert.h"
#include "GraphBuilder.h"
#include "GraphBuilderContext.h"
/**
* @brief GraphBuilder for Conv2D node
*/
-class Conv2DGraphBuilder final : public GraphBuilder
+class Conv2DGraphBuilder final : public Conv2DGraphBuilderBase
{
public:
- bool validate(const tensorflow::NodeDef &) const override;
void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
-
-private:
- void buildCanonical(const tensorflow::NodeDef &node, GraphBuilderContext *context) const;
- void buildTF(const tensorflow::NodeDef &node, GraphBuilderContext *context) const;
};
-bool Conv2DGraphBuilder::validate(const tensorflow::NodeDef &node) const
+bool Conv2DGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
{
assert(node.input_size() == 2);
assert(context != nullptr);
if (moco::tf::get<moco::tf::Knob::ImportAsTFConv2D>())
- buildTF(node, context);
+ {
+ Conv2DGraphBuilderImpl<ImportTarget::TensorFlow> builder;
+ builder.build(node, context);
+ }
else
- buildCanonical(node, context);
+ {
+ Conv2DGraphBuilderImpl<ImportTarget::Canonical> builder;
+ builder.build(node, context);
+ }
}
-void Conv2DGraphBuilder::buildCanonical(const tensorflow::NodeDef &node,
- GraphBuilderContext *context) const
+void Conv2DGraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
{
loco::Graph *graph = context->graph();
SymbolTable *tensor_names = context->tensor_names();
updates->enroll(std::move(ker_update));
}
-void Conv2DGraphBuilder::buildTF(const tensorflow::NodeDef &node,
- GraphBuilderContext *context) const
+void Conv2DGraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
{
loco::Graph *graph = context->graph();
SymbolTable *tensor_names = context->tensor_names();
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __OP_CONV_2D_H__
+#define __OP_CONV_2D_H__
+
+#include "GraphBuilder.h"
+#include "ImportTarget.h"
+
+namespace moco
+{
+namespace tf
+{
+
+struct Conv2DGraphBuilderBase : public GraphBuilder
+{
+ virtual ~Conv2DGraphBuilderBase() = default;
+
+ bool validate(const tensorflow::NodeDef &) const final;
+};
+
+template <ImportTarget T> class Conv2DGraphBuilderImpl;
+
+template <>
+struct Conv2DGraphBuilderImpl<ImportTarget::Canonical> final : public Conv2DGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+template <>
+struct Conv2DGraphBuilderImpl<ImportTarget::TensorFlow> final : public Conv2DGraphBuilderBase
+{
+ void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __OP_CONV_2D_H__
* limitations under the License.
*/
+#include "Conv2D.h"
+
#include "TestHelper.h"
#include "Importer.h"
-#include "Knob.h"
#include "IR/TFConv2D.h"
#include <loco.h>
#include <memory>
+using namespace moco::tf;
using namespace moco::tf::test;
namespace
tensorflow::GraphDef graph_def;
EXPECT_TRUE(parse_graphdef(conv2d_01_pbtxtdata, graph_def));
- std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
- // TODO remove using knob
- if (moco::tf::get<moco::tf::Knob::ImportAsTFConv2D>())
+ // Test loco.TF Importer
+ {
+ using Conv2DGraphBuilder = Conv2DGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Conv2D", stdex::make_unique<Conv2DGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
verify_TFConv2D_01(graph.get());
- else
+ }
+
+ // Test loco.Canonical Importer
+ {
+ using Conv2DGraphBuilder = Conv2DGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Conv2D", stdex::make_unique<Conv2DGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
verify_Conv2D_01(graph.get());
+ }
}
namespace
EXPECT_TRUE(parse_graphdef(conv2d_inception_pbtxtdata, graph_def));
std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
- // TODO remove using knob
- if (moco::tf::get<moco::tf::Knob::ImportAsTFConv2D>())
+ // Test loco.TF Importer
+ {
+ using Conv2DGraphBuilder = Conv2DGraphBuilderImpl<ImportTarget::TensorFlow>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Conv2D", stdex::make_unique<Conv2DGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
verify_TFConv2D_inception_indexed_tensor_name(graph.get());
- else
+ }
+
+ // Test loco.Canonical Importer
+ {
+ using Conv2DGraphBuilder = Conv2DGraphBuilderImpl<ImportTarget::Canonical>;
+
+ moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
+ r.add("Conv2D", stdex::make_unique<Conv2DGraphBuilder>());
+ moco::tf::Importer importer{&r};
+
+ std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
+
verify_Conv2D_inception_indexed_tensor_name(graph.get());
+ }
}