target_link_libraries(moco_tf_frontend PRIVATE cwrap)
target_link_libraries(moco_tf_frontend PRIVATE moco_log)
target_link_libraries(moco_tf_frontend PRIVATE pepper_strcast)
+target_link_libraries(moco_tf_frontend PRIVATE locomotiv)
if(NOT ENABLE_TEST)
return()
require("stdex")
require("moco-log")
require("pepper-strcast")
+require("locomotiv")
#define __MOCO_TF_TRANSFORMS_H__
#include "Transforms/ClearAnnotTransform.h"
+#include "Transforms/ConstantFoldingTransform.h"
#include "Transforms/FixPaddingTransform.h"
#include "Transforms/FixShapeTransform.h"
#include "Transforms/FuseBinaryIntoPreceding.h"
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConstantFoldingTransform.h"
+
+#include <loco.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <moco/Log.h>
+#include <stdex/Memory.h>
+
+#include <locomotiv/Session.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace
+{
+
+using namespace moco::tf;
+
+uint64_t num_elements(const loco::NodeMixin<loco::NodeTrait::TensorShape> &shape)
+{
+ if (shape.rank() == 0)
+ {
+ return 0;
+ }
+
+ uint64_t res = 1;
+
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ {
+ assert(shape.dim(axis).known());
+ res *= shape.dim(axis).value();
+ }
+
+ return res;
+}
+
+/// @brief For some op, constant folding should not be performed. This returns true if node is such
+/// op.
+bool skip(const loco::Node *node)
+{
+ static std::set<uint32_t> skip_op = {
+ // TODO Current implementation works for 'Tensor' domain only. Support other domains such as
+ // `Feature`, `Filter`, `Bias`, etc.
+ static_cast<uint32_t>(loco::CanonicalOpcode::FilterEncode),
+ static_cast<uint32_t>(loco::CanonicalOpcode::FeatureEncode),
+ static_cast<uint32_t>(loco::CanonicalOpcode::BiasEncode),
+ static_cast<uint32_t>(loco::CanonicalOpcode::DepthwiseFilterEncode),
+
+ // We don't perform constant folding for Push
+ static_cast<uint32_t>(loco::CanonicalOpcode::Push),
+ };
+
+ if (node->dialect() == loco::CanonicalDialect::get())
+ {
+ if (skip_op.find(node->opnum()) != skip_op.end())
+ return true;
+ }
+
+ return false;
+}
+
+/// @brief Checks if a node is a target of constant folding transform
+bool foldable(const loco::Node *node)
+{
+ if (node->dialect() == loco::CanonicalDialect::get())
+ {
+ if (skip(node))
+ return false;
+
+ if (node->arity() == 0) // e.g., when a node is e.g, ConstGen or Pull
+ return false;
+
+ // When all args are ConstGen, let's do Constant Folding Transforms
+ for (int i = 0; i < node->arity(); i++)
+ {
+ if (node->arg(i)->opnum() != static_cast<uint32_t>(loco::CanonicalOpcode::ConstGen))
+ return false;
+ }
+
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+}
+
+void fold(loco::Graph *graph, loco::Node *node)
+{
+ assert(foldable(node)); // sanity check to find a mistake when this function is reused later
+
+ // calcluate foldable node
+ locomotiv::Session sess(graph, std::vector<loco::Node *>{node});
+ sess.infer();
+ auto data = sess.get_output(0);
+
+ assert(data != nullptr);
+
+ auto shape = data->shape();
+ auto dtype = data->dtype();
+
+ // build ConstGen
+ auto new_const = graph->nodes()->create<loco::ConstGen>();
+ {
+ new_const->dtype(dtype);
+
+ new_const->rank(shape->rank());
+ for (int d = 0; d < shape->rank(); d++)
+ new_const->dim(d) = shape->dim(d);
+
+ auto count = num_elements(*new_const);
+
+ if (dtype == loco::DataType::FLOAT32)
+ {
+ new_const->size<loco::DataType::FLOAT32>(count);
+
+ auto const_buf = data->as_f32_bufptr()->base();
+ for (int x = 0; x < count; x++)
+ new_const->at<loco::DataType::FLOAT32>(x) = const_buf[x];
+ }
+ else if (dtype == loco::DataType::S32)
+ {
+ new_const->size<loco::DataType::S32>(count);
+
+ auto const_buf = data->as_s32_bufptr()->base();
+ for (int x = 0; x < count; x++)
+ new_const->at<loco::DataType::S32>(x) = const_buf[x];
+ }
+ }
+
+ // replace node with new_const
+ loco::replace(node).with(new_const);
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+bool ConstantFoldingTransform::run(loco::Graph *graph)
+{
+ auto outputs = loco::output_nodes(graph);
+
+ bool changed = false;
+ for (auto node : loco::postorder_traversal(outputs))
+ {
+ if (foldable(node))
+ {
+ fold(graph, node);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace tf
+} // namespace moco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_TF_CONSTANT_FOLDING_TRANSFORM_H__
+#define __MOCO_TF_CONSTANT_FOLDING_TRANSFORM_H__
+
+#include "Transform.h"
+
+#include <loco.h>
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Performs constant folding optimization
+ */
+class ConstantFoldingTransform : public Transform
+{
+public:
+ const char *name(void) const final { return "ConstantFoldingTransform"; }
+
+public:
+ bool run(loco::Graph *graph) override;
+};
+
+} // namespace tf
+} // namespace moco
+
+#endif // __MOCO_TF_CONSTANT_FOLDING_TRANSFORM_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConstantFoldingTransform.h"
+
+#include "TestHelper.h"
+#include "IR/TFFusedBatchNorm.h"
+#include "Importer.h"
+#include "Canonicalizer.h"
+
+#include <loco.h>
+#include <locop/FormattedGraph.h>
+#include <moco/Log.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// clang-format off
+
+/*
+ test case:
+ ConstGen -- Relu -- Push
+
+ after constant folding
+ ConstGen -- Push
+*/
// GraphDef (pbtxt) fixture: a 2-element float Const feeding a Relu;
// the test takes "relu" as the graph output
const char *case01 = STRING_CONTENT(
node {
  name: "input"
  op: "Const"
  attr {
    key: "dtype" value { type: DT_FLOAT }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape { dim { size: 2 } }
        float_val: -3.14
        float_val: 3.14
      }
    }
  }
}
node {
  name: "relu"
  op: "Relu"
  input: "input"
  attr { key: "T" value { type: DT_FLOAT } }
}
);
+// clang-format on
+
+} // namespace
+
+namespace
+{
+
+char to_char(bool b) { return b ? 'Y' : 'N'; }
+
+} // namespace
+
// Verifies that Const -> Relu -> Push is folded into ConstGen -> Push,
// and that the folded values equal relu(-3.14, 3.14) == (0, 3.14)
TEST(ConstantFolding, case01)
{
  LOGGER(l);

  // load graph
  moco::tf::Importer importer;
  moco::tf::ModelSignature signature;
  signature.add_output(moco::tf::TensorName("relu", 0));

  tensorflow::GraphDef graph_def;
  EXPECT_TRUE(parse_graphdef(case01, graph_def));
  auto graph = importer.import(signature, graph_def);

  // Convert graph to hold only Canonical dialect
  moco::tf::Canonicalizer canonicalizer;
  canonicalizer.canonicalize(graph.get());

  INFO(l) << "Before ConstantFolding";
  INFO(l) << locop::fmt<locop::LinearV1>(graph);

  // Run the transform to a fixed point: stop once no more nodes fold
  moco::tf::ConstantFoldingTransform transform;
  while (transform.run(graph.get()) == true)
  {
    INFO(l) << "running ConstantFolding...";
  }

  INFO(l) << "After ConstantFolding ";
  INFO(l) << locop::fmt<locop::LinearV1>(graph);

  // The Push input must now be a ConstGen carrying the folded Relu result
  auto push = moco::tf::test::find_first_node_bytype<loco::Push>(graph.get());
  auto const_gen = dynamic_cast<loco::ConstGen *>(push->from());

  ASSERT_NE(const_gen, nullptr);
  ASSERT_EQ(const_gen->size<loco::DataType::FLOAT32>(), 2);
  ASSERT_EQ(const_gen->at<loco::DataType::FLOAT32>(0), 0); // result of relu (-3.14)
  ASSERT_EQ(const_gen->at<loco::DataType::FLOAT32>(1), 3.14f);
}
+
+// TODO Add more complex cases