} // namespace
-TEST(TensorFlowImport, load_model_withio)
-{
- moco::tf::ModelSignature signature;
-
- signature.add_input(moco::tf::TensorName("Placeholder", 0));
- signature.add_output(moco::tf::TensorName("output/identity", 0));
-
- tensorflow::GraphDef graph_def;
- EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def));
-
- using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl<ImportTarget::Canonical>;
-
- moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
- r.add("Identity", stdex::make_unique<IdentityGraphBuilder>());
- moco::tf::Importer importer{&r};
-
- std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
-
- // what to test:
- // - import reads Pull
- // - import reads Forward
- // - attribute values should match
-
- auto pull = find_first_node_bytype<loco::Pull>(graph.get());
- ASSERT_NE(pull, nullptr);
- auto forward = find_first_node_bytype<loco::Forward>(graph.get());
- ASSERT_NE(forward, nullptr);
-
- ASSERT_EQ(pull->dtype(), loco::DataType::FLOAT32);
- ASSERT_EQ(pull->rank(), 4);
- loco::Dimension dim1 = 1;
- loco::Dimension dim2 = 2;
- ASSERT_EQ(pull->dim(0).value(), dim1.value());
- ASSERT_EQ(pull->dim(1).value(), dim2.value());
- ASSERT_EQ(pull->dim(2).value(), dim1.value());
- ASSERT_EQ(pull->dim(3).value(), dim2.value());
-}
-
TEST(TensorFlowImport, load_model_withio_tf)
{
moco::tf::ModelSignature signature;
tensorflow::GraphDef graph_def;
EXPECT_TRUE(plier::tf::parse_graphdef(basic_pbtxtdata, graph_def));
- using IdentityGraphBuilder = moco::tf::IdentityGraphBuilderImpl<ImportTarget::TensorFlow>;
-
- moco::tf::GraphBuilderRegistry r{&moco::tf::GraphBuilderRegistry::get()};
- // TODO add Placeholder
- r.add("Identity", stdex::make_unique<IdentityGraphBuilder>());
- moco::tf::Importer importer{&r};
+ moco::tf::Importer importer;
std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
using namespace moco::tf;
-class IdentityGraphUpdate final : public GraphUpdate
-{
-public:
- IdentityGraphUpdate(loco::Forward *node, const std::vector<TensorName> &names)
- : _node(node), _names(names)
- {
- }
-
- void input(const SymbolTable *) const override;
-
-private:
- loco::Forward *_node;
- const std::vector<TensorName> _names;
-};
-
class TFIdentityGraphUpdate final : public GraphUpdate
{
public:
const std::vector<TensorName> _names;
};
-void IdentityGraphUpdate::input(const SymbolTable *tensor_names) const
-{
- for (auto &name : _names)
- {
- loco::Node *target = tensor_names->node(name);
- _node->input(target);
- }
-}
-
void TFIdentityGraphUpdate::input(const SymbolTable *tensor_names) const
{
for (auto &name : _names)
namespace tf
{
-/**
- * @brief GraphBuilder for Identity node
- */
-class IdentityGraphBuilder final : public IdentityGraphBuilderBase
-{
-public:
- void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
-};
-
-bool IdentityGraphBuilderBase::validate(const tensorflow::NodeDef &node) const
+bool IdentityGraphBuilder::validate(const tensorflow::NodeDef &node) const
{
if (node.input_size() < 1) // from TensorFlow lite toco
return false;
{
assert(context != nullptr);
- if (moco::tf::get<moco::tf::Knob::ImportAsTFIdentity>())
- {
- IdentityGraphBuilderImpl<ImportTarget::TensorFlow> builder;
- return builder.build(node, context);
- }
- else
- {
- IdentityGraphBuilderImpl<ImportTarget::Canonical> builder;
- return builder.build(node, context);
- }
-}
-
-void IdentityGraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
- GraphBuilderContext *context) const
-{
- loco::Graph *graph = context->graph();
- SymbolTable *tensor_names = context->tensor_names();
- UpdateQueue *updates = context->updates();
-
- // Create a "Forward" node for Identity
- auto forward_node = graph->nodes()->create<loco::Forward>();
-
- // register string-name to node
- TensorName output_name(node.name(), 0);
- tensor_names->enroll(output_name, forward_node);
-
- // Queue node input update
- // TODO: Check if we really need multiple input handlings
- std::vector<TensorName> names;
- for (int i = 0; i < node.input_size(); ++i)
- {
- names.emplace_back(TensorName(node.input(i)));
- }
- auto update = stdex::make_unique<IdentityGraphUpdate>(forward_node, names);
- updates->enroll(std::move(update));
-}
-
-void IdentityGraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
- GraphBuilderContext *context) const
-{
loco::Graph *graph = context->graph();
SymbolTable *tensor_names = context->tensor_names();
UpdateQueue *updates = context->updates();
#define __OP_IDENTITY_H__
#include "GraphBuilder.h"
-#include "ImportTarget.h"
namespace moco
{
namespace tf
{
-struct IdentityGraphBuilderBase : public GraphBuilder
+class IdentityGraphBuilder : public GraphBuilder
{
- virtual ~IdentityGraphBuilderBase() = default;
-
+public:
bool validate(const tensorflow::NodeDef &) const final;
-};
-
-template <ImportTarget T> class IdentityGraphBuilderImpl;
-
-template <>
-struct IdentityGraphBuilderImpl<ImportTarget::Canonical> final : public IdentityGraphBuilderBase
-{
- void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
-};
-
-template <>
-struct IdentityGraphBuilderImpl<ImportTarget::TensorFlow> final : public IdentityGraphBuilderBase
-{
void build(const tensorflow::NodeDef &, GraphBuilderContext *) const final;
};