From dabb748d5bab7f26a2439496b6f96b189fc98a6f Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 5 Jul 2019 17:25:26 +0900 Subject: [PATCH] [moco/tf] Import as TFBiasAdd by Knob (#4116) This will add ImportAsTFBiasAdd knob and import as TFBiasAdd when turned on. Signed-off-by: SaeHie Park --- contrib/moco-tf/src/Knob.lst | 1 + contrib/moco-tf/src/Op/BiasAdd.cpp | 70 +++++++++++++++++++++++++++++++++ contrib/moco-tf/src/Op/BiasAdd.test.cpp | 45 +++++++++++++++++++++ 3 files changed, 116 insertions(+) diff --git a/contrib/moco-tf/src/Knob.lst b/contrib/moco-tf/src/Knob.lst index f40052c..a0538e3 100644 --- a/contrib/moco-tf/src/Knob.lst +++ b/contrib/moco-tf/src/Knob.lst @@ -3,6 +3,7 @@ #endif // KNOB_BOOL // KNOB_BOOL(NAME, DEFAULT_VALUE, DESCRIPTION) +KNOB_BOOL(ImportAsTFBiasAdd, false, Import BiasAdd node as TFBiasAdd node) KNOB_BOOL(ImportAsTFConv2D, false, Import Conv2D node as TFConv2D node) KNOB_BOOL(RemoveDeadNode, false, Enable RemoveDeadNode optimization) KNOB_BOOL(RemoveForwardNode, false, Enable RemoveForwardNode optimization) diff --git a/contrib/moco-tf/src/Op/BiasAdd.cpp b/contrib/moco-tf/src/Op/BiasAdd.cpp index b50f5b2..78d6438 100644 --- a/contrib/moco-tf/src/Op/BiasAdd.cpp +++ b/contrib/moco-tf/src/Op/BiasAdd.cpp @@ -17,6 +17,9 @@ #include "Convert.h" #include "GraphBuilder.h" #include "GraphBuilderContext.h" +#include "Knob.h" + +#include "IR/TFBiasAdd.h" #include @@ -28,6 +31,7 @@ #include #include +#include namespace { @@ -75,6 +79,34 @@ void BiasInputUpdate::input(const SymbolTable *node_table) const _bias_enc->input(input_node); } +class TFBiasAddGraphUpdate final : public GraphUpdate +{ +public: + TFBiasAddGraphUpdate(moco::tf::TFBiasAdd *biasadd, std::vector &names) + : _biasadd(biasadd), _names(names) + { + } + + void input(const SymbolTable *) const override; + +private: + moco::tf::TFBiasAdd *_biasadd; + std::vector _names; +}; + +void TFBiasAddGraphUpdate::input(const SymbolTable *node_table) const +{ + assert(_names.size() == 2); + + auto value_node = node_table->node(_names[0]); + auto bias_node = node_table->node(_names[1]); + assert(value_node != nullptr); + assert(bias_node != nullptr); + + _biasadd->value(value_node); + _biasadd->bias(bias_node); +} + } // namespace namespace moco @@ -90,6 +122,10 @@ class BiasAddGraphBuilder final : public GraphBuilder public: bool validate(const tensorflow::NodeDef &) const override; void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override; + +private: + void buildCanonical(const tensorflow::NodeDef &node, GraphBuilderContext *context) const; + void buildTF(const tensorflow::NodeDef &node, GraphBuilderContext *context) const; }; bool BiasAddGraphBuilder::validate(const tensorflow::NodeDef &node) const @@ -116,6 +152,15 @@ void BiasAddGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCon { assert(context != nullptr); + if (moco::tf::get()) + buildTF(node, context); + else + buildCanonical(node, context); +} + +void BiasAddGraphBuilder::buildCanonical(const tensorflow::NodeDef &node, + GraphBuilderContext *context) const +{ loco::Graph *graph = context->graph(); SymbolTable *tensor_names = context->tensor_names(); UpdateQueue *updates = context->updates(); @@ -157,6 +202,31 @@ void BiasAddGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderCon updates->enroll(std::move(bias_update)); } +void BiasAddGraphBuilder::buildTF(const tensorflow::NodeDef &node, + GraphBuilderContext *context) const +{ + loco::Graph *graph = context->graph(); + SymbolTable *tensor_names = context->tensor_names(); + UpdateQueue *updates = context->updates(); + + // tensorflow data_format: one of NHWC or NCHW. + auto data_layout = get_string_attr(node, "data_format"); + auto tf_bias_add = graph->nodes()->create(); + + tf_bias_add->data_layout(data_layout); + + // To set the input node of encode_node with biasAdd_name + TensorName output_name(node.name(), 0); + tensor_names->enroll(output_name, tf_bias_add); + + std::vector input_names; + input_names.push_back(TensorName(node.input(0))); + input_names.push_back(TensorName(node.input(1))); + + auto update = stdex::make_unique(tf_bias_add, input_names); + updates->enroll(std::move(update)); +} + } // namespace tf } // namespace moco diff --git a/contrib/moco-tf/src/Op/BiasAdd.test.cpp b/contrib/moco-tf/src/Op/BiasAdd.test.cpp index b679950..93c230a 100644 --- a/contrib/moco-tf/src/Op/BiasAdd.test.cpp +++ b/contrib/moco-tf/src/Op/BiasAdd.test.cpp @@ -17,6 +17,8 @@ #include "TestHelper.h" #include "Importer.h" +#include "Knob.h" +#include "IR/TFBiasAdd.h" #include @@ -103,6 +105,34 @@ TEST(TensorFlowImport, bias_add_01) EXPECT_TRUE(parse_graphdef(bias_add_01_pbtxtdata, graph_def)); std::unique_ptr graph = importer.import(signature, graph_def); + if (moco::tf::get()) + { + loco::Graph::NodeContext *loco_nodes = graph->nodes(); + loco::Graph::InputContext *loco_inputs = graph->inputs(); + ASSERT_EQ(loco_inputs->size(), 0); + ASSERT_EQ(loco_nodes->size(), 4); + + int idx = 0; + + loco::ConstGen *value = dynamic_cast(loco_nodes->at(idx++)); + loco::ConstGen *bias = dynamic_cast(loco_nodes->at(idx++)); + moco::tf::TFBiasAdd *bias_add = dynamic_cast(loco_nodes->at(idx++)); + loco::Push *push = dynamic_cast(loco_nodes->at(idx++)); + + ASSERT_NE(value, nullptr); + ASSERT_NE(bias, nullptr); + ASSERT_NE(bias_add, nullptr); + ASSERT_NE(push, nullptr); + + ASSERT_TRUE(bias_add->value() == value); + ASSERT_TRUE(bias_add->bias() == bias); + ASSERT_TRUE(push->from() == bias_add); + ASSERT_TRUE(bias_add->data_layout() == "NHWC"); + } + else + { + // TODO fix indentation and remove clang switch + // clang-format off // test 1. // loco node : 1. ConstGen ------------------+-- 4. BiasAdd -- 5. Push // 2. ConstGen - 3. BiasEncode -/ @@ -138,6 +168,8 @@ TEST(TensorFlowImport, bias_add_01) // axis ASSERT_EQ(bias_add->axis(), 3); // NHWC + // clang-format on + } } namespace @@ -219,6 +251,17 @@ TEST(TensorFlowImport, bias_add_NCHW_axis) EXPECT_TRUE(parse_graphdef(bias_add_NCHW_pbtxtdata, graph_def)); std::unique_ptr graph = importer.import(signature, graph_def); + if (moco::tf::get()) + { + loco::Graph::NodeContext *loco_nodes = graph->nodes(); + moco::tf::TFBiasAdd *bias_add = dynamic_cast(loco_nodes->at(2)); + ASSERT_NE(bias_add, nullptr); + } + else + { + // TODO fix indentation and remove clang switch + // clang-format off + // testing axis value of biasAdd loco::Graph::NodeContext *loco_nodes = graph->nodes(); loco::BiasAdd *bias_add = @@ -226,4 +269,6 @@ TEST(TensorFlowImport, bias_add_NCHW_axis) ASSERT_NE(bias_add, nullptr); ASSERT_EQ(bias_add->axis(), 1); // NCHW + // clang-format on + } } -- 2.7.4