return me;
}
-private:
- GraphBuilderRegistry()
+public:
+ void add(const std::string op, std::unique_ptr<GraphBuilder> &&builder)
{
- // TODO add operations
+ _builder_map[op] = std::move(builder);
}
private:
- std::map<std::string, std::unique_ptr<GraphBuilder>> _builder_map;
+ std::map<const std::string, std::unique_ptr<GraphBuilder>> _builder_map;
};
} // namespace onnx
} // namespace moco
+#include <stdex/Memory.h>
+
+#define REGISTER_OP_BUILDER(NAME, BUILDER) \
+ namespace \
+ { \
+ __attribute__((constructor)) void reg_op(void) \
+ { \
+ std::unique_ptr<moco::onnx::BUILDER> builder = stdex::make_unique<moco::onnx::BUILDER>(); \
+ moco::onnx::GraphBuilderRegistry::get().add(#NAME, std::move(builder)); \
+ } \
+ }
+
#endif // __MOCO_FRONTEND_ONNX_GRAPH_BUILDER_REGISTRY_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphBuilder.h"
+
+#include <cassert>
+
+namespace moco
+{
+namespace onnx
+{
+
+/**
+ * @brief GraphBuilder for Identity node
+ */
+class IdentityGraphBuilder : public GraphBuilder
+{
+public:
+ bool validate(const ::onnx::NodeProto &) const override;
+ void build(const ::onnx::NodeProto &, GraphBuilderContext *) const override;
+};
+
+bool IdentityGraphBuilder::validate(const ::onnx::NodeProto &node) const { return true; }
+
+void IdentityGraphBuilder::build(const ::onnx::NodeProto &node, GraphBuilderContext *context) const
+{
+ assert(context != nullptr);
+
+ loco::Graph *graph = context->graph();
+ SymbolTable *nodes = context->nodes();
+ SymbolTable *input_names = context->input_names();
+
+ // Create a "Forward" node for Identity
+ auto forward_node = graph->nodes()->create<loco::Forward>();
+
+ nodes->enroll(node.name(), forward_node);
+
+ // Record all inputs to forward_node
+ for (int i = 0; i < node.input_size(); ++i)
+ {
+ const auto &input_name = node.input(i);
+ input_names->list(forward_node, input_name);
+ }
+}
+
+} // namespace onnx
+} // namespace moco
+
+#include "GraphBuilderRegistry.h"
+
+REGISTER_OP_BUILDER(Identity, IdentityGraphBuilder)