+++ /dev/null
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "Canonicalization.h"
-
-#include "Knob.h"
-
-#include "Dialect/TFDialect.h"
-#include "Dialect/TFNodes.h"
-#include "Dialect/TFNodeVisitor.h"
-#include "Dialect/TFNodeImpl.h"
-
-#include <moco/tf/Names.h>
-#include <moco/Log.h>
-
-namespace moco
-{
-namespace tf
-{
-
-class TFNodeCanonicalize final : public TFNodeMutableVisitor<bool>
-{
-public:
- TFNodeCanonicalize(loco::Graph *graph) : _graph(graph){};
-
-public:
- bool visit(TFBiasAdd *node);
-
-private:
- loco::Graph *_graph;
-};
-
-bool TFNodeCanonicalize::visit(TFBiasAdd *node)
-{
- if (!moco::tf::get<moco::tf::Knob::CanonicalizeBiasAdd>())
- return false;
-
- LOGGER(l);
-
- /**
- * @note This will replace TFBiasAdd node with Canonical BiasEncode + TensorBiasAdd
- *
- * Before
- * A -- TFBiasAdd - C
- * B -/
- *
- * After
- * - TFBiasAdd -
- * A --------------- TensorBiasAdd - C
- * B - BiasEncode -/
- *
- * Where
- * A : value of TFBiasAdd
- * B : bias of TFBiasAdd
- * C : a node that uses TFBiasAdd as an input
- * TFBiasAdd is disconnected from other nodes
- */
-
- INFO(l) << "TFNodeCanonicalize TFBiasAdd begin";
-
- // tensorflow data_format: one of NHWC or NCHW.
- auto data_layout = as_DataLayout(node->data_layout());
-
- // creating loco nodes
- auto bias_enc = _graph->nodes()->create<loco::BiasEncode>();
-
- auto bias_add = _graph->nodes()->create<loco::TensorBiasAdd>();
- {
- if (data_layout == DataLayout::NHWC)
- {
- INFO(l) << "TFNodeCanonicalize TFBiasAdd axis 3";
- bias_add->axis(3);
- }
- else if (data_layout == DataLayout::NCHW)
- {
- INFO(l) << "TFNodeCanonicalize TFBiasAdd axis 1";
- bias_add->axis(1); // Channel
- // Note: the following descrition of TF 1.13 at
- // https://www.tensorflow.org/api_docs/python/tf/nn/bias_add seems wrong:
- // "bias: A 1-D Tensor with size matching the last dimension of value."
- // because providing the size of W (last dimension) to bias throws an error with TensorFlow
- }
- }
-
- auto node_A = node->value();
- auto node_B = node->bias();
-
- // update connections
- bias_add->value(node_A);
- bias_add->bias(bias_enc);
- bias_enc->input(node_B);
-
- // replace old with new : about C in above note
- replace(node).with(bias_add);
- node->value(nullptr);
- node->bias(nullptr);
-
- INFO(l) << "TFNodeCanonicalize TFBiasAdd done";
-
- return true;
-}
-
-bool Canonicalization::run(loco::Graph *graph)
-{
- auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
- bool changed = false;
-
- for (auto node : active_nodes)
- {
- if (node->dialect() == TFDialect::get())
- {
- auto tf_node = dynamic_cast<moco::tf::TFNode *>(node);
- assert(tf_node != nullptr);
- TFNodeCanonicalize canonicalize(graph);
- if (tf_node->accept(&canonicalize))
- changed = true;
- }
- }
-
- return changed;
-}
-
-} // namespace tf
-} // namespace moco
+++ /dev/null
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __MOCO_TF_CANONICALIZATION_H__
-#define __MOCO_TF_CANONICALIZATION_H__
-
-#include "Transform.h"
-
-#include <loco.h>
-
-namespace moco
-{
-namespace tf
-{
-
-/**
- * @brief Convert to Canonical dialect
- */
-class Canonicalization : public Transform
-{
-public:
- const char *name(void) const final { return "Canonicalization"; }
-
-public:
- bool run(loco::Graph *graph) override;
-};
-
-} // namespace tf
-} // namespace moco
-
-#endif // __MOCO_TF_CANONICALIZATION_H__