[moco-tf] Introduce Broadcast Helper (#7334)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 11 Sep 2019 05:02:47 +0000 (14:02 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 11 Sep 2019 05:02:47 +0000 (14:02 +0900)
* [moco-tf] Introduce Broadcast Helper

This commit introduces Broadcast Helper module which facilitates
numpy-style broadcasting support.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Fix a typo

compiler/moco-tf/CMakeLists.txt
compiler/moco-tf/src/BroadcastHelper.cpp [new file with mode: 0644]
compiler/moco-tf/src/BroadcastHelper.h [new file with mode: 0644]
compiler/moco-tf/src/BroadcastHelper.test.cpp [new file with mode: 0644]

index 6f94702..9ca5777 100644 (file)
@@ -21,6 +21,8 @@ target_include_directories(moco_tf_frontend PRIVATE src)
 target_include_directories(moco_tf_frontend PUBLIC include)
 target_link_libraries(moco_tf_frontend PUBLIC moco_tf_proto)
 target_link_libraries(moco_tf_frontend PUBLIC loco)
+target_link_libraries(moco_tf_frontend PRIVATE bino)
+target_link_libraries(moco_tf_frontend PRIVATE fipe)
 target_link_libraries(moco_tf_frontend PRIVATE locop)
 target_link_libraries(moco_tf_frontend PRIVATE stdex)
 target_link_libraries(moco_tf_frontend PRIVATE cwrap)
@@ -41,6 +43,8 @@ nncc_find_package(GTest REQUIRED)
 add_executable(moco_tf_frontend_test ${TESTS})
 target_include_directories(moco_tf_frontend_test PRIVATE src)
 target_link_libraries(moco_tf_frontend_test gtest_main)
+target_link_libraries(moco_tf_frontend_test bino)
+target_link_libraries(moco_tf_frontend_test fipe)
 target_link_libraries(moco_tf_frontend_test locop)
 target_link_libraries(moco_tf_frontend_test moco_log)
 target_link_libraries(moco_tf_frontend_test moco_tf_frontend)
diff --git a/compiler/moco-tf/src/BroadcastHelper.cpp b/compiler/moco-tf/src/BroadcastHelper.cpp
new file mode 100644 (file)
index 0000000..fc058c1
--- /dev/null
@@ -0,0 +1,226 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BroadcastHelper.h"
+
+#include <loco/IR/Nodes.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+
+namespace
+{
+
+/**
+ * @brief Value type that pairs a loco node with its (tensor) shape
+ *
+ * Shape information is threaded through the broadcast pipeline below
+ * alongside the node itself, so later stages need not re-infer it.
+ */
+class NodeWithTensorShape
+{
+public:
+  NodeWithTensorShape() = default;
+
+public:
+  NodeWithTensorShape(loco::Node *node, const loco::TensorShape &shape) : _node{node}, _shape{shape}
+  {
+    // DO NOTHING
+  }
+
+public:
+  loco::Node *node(void) const { return _node; }
+  const loco::TensorShape &shape(void) const { return _shape; }
+
+private:
+  loco::Node *_node = nullptr;
+  loco::TensorShape _shape;
+};
+
+// Convenience factory: pack a node and its shape into a NodeWithTensorShape.
+NodeWithTensorShape glue(loco::Node *node, const loco::TensorShape &shape)
+{
+  return NodeWithTensorShape(node, shape);
+}
+
+/**
+ * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
+ *
+ * HOW TO USE:
+ *
+ *   auto expanded_tensor_shape = expand(tensor_shape).to(N);
+ */
+class TensorShapeExpander
+{
+public:
+  TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
+  {
+    // DO NOTHING
+  }
+
+public:
+  // Return a rank-'output_rank' shape: new leading axes get dimension 1,
+  // and the original dimensions are copied into the trailing axes.
+  loco::TensorShape to(uint32_t output_rank)
+  {
+    auto const &input_shape = _shape;
+    uint32_t const input_rank = input_shape.rank();
+
+    assert(input_rank <= output_rank && "Cannot shrink rank");
+    uint32_t const axis_shift = output_rank - input_rank;
+
+    loco::TensorShape output_shape;
+
+    output_shape.rank(output_rank);
+    for (uint32_t axis = 0; axis < output_rank; ++axis)
+    {
+      // Axes before 'axis_shift' are newly introduced (dimension 1);
+      // the rest are shifted copies of the input dimensions.
+      output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
+    }
+
+    return output_shape;
+  }
+
+private:
+  const loco::TensorShape _shape;
+};
+
+TensorShapeExpander expand(const loco::TensorShape &shape) { return TensorShapeExpander{shape}; }
+
+/**
+ * @brief Create a rank-expanded node (if required)
+ *
+ * When the incoming rank is below the target rank, a canonical FixedReshape
+ * node is created whose shape is the NumPy-style rank expansion of the
+ * input shape; otherwise, the input is passed through unchanged.
+ */
+class ExpandRankFunctor final
+{
+public:
+  ExpandRankFunctor(uint32_t rank) : _rank{rank}
+  {
+    // DO NOTHING
+  }
+
+public:
+  NodeWithTensorShape operator()(const NodeWithTensorShape &in) const
+  {
+    auto const input_node = in.node();
+    auto const input_shape = in.shape();
+    auto const input_rank = input_shape.rank();
+
+    uint32_t const expected_rank = _rank;
+
+    // Rank shrinking is not supported (NumPy broadcasting only grows rank)
+    assert(input_rank <= expected_rank);
+    if (input_rank == expected_rank)
+    {
+      // Nothing to expand
+      return in;
+    }
+
+    auto g = input_node->graph();
+    assert(g != nullptr);
+
+    auto output_shape = expand(input_shape).to(expected_rank);
+    auto output_node = g->nodes()->create<loco::FixedReshape>();
+
+    output_node->input(input_node);
+    output_node->rank(expected_rank);
+    for (uint32_t axis = 0; axis < expected_rank; ++axis)
+    {
+      output_node->dim(axis) = output_shape.dim(axis);
+    }
+
+    return glue(output_node, output_shape);
+  }
+
+private:
+  uint32_t _rank;
+};
+
+ExpandRankFunctor expand_rank_to(uint32_t rank) { return ExpandRankFunctor{rank}; }
+
+/**
+ * @brief Create a dimension-expanded node (if required)
+ *
+ * Precondition: the input already has the same rank as the target shape
+ * (i.e. rank expansion runs first). A canonical TensorBroadcast node is
+ * created with a mapping entry for every axis whose dimension differs;
+ * such axes must be 1 on the input side (NumPy broadcasting rule).
+ */
+class ExpandDimsFunctor final
+{
+public:
+  ExpandDimsFunctor(const loco::TensorShape &shape) : _shape{shape}
+  {
+    // DO NOTHING
+  }
+
+public:
+  NodeWithTensorShape operator()(const NodeWithTensorShape &in) const
+  {
+    auto const input_node = in.node();
+    auto const input_shape = in.shape();
+    const auto &output_shape = _shape;
+
+    assert(input_shape.rank() == output_shape.rank());
+
+    if (input_shape == output_shape)
+    {
+      // Nothing to expand
+      return in;
+    }
+
+    uint32_t const rank = output_shape.rank();
+
+    auto g = input_node->graph();
+    assert(g != nullptr);
+
+    auto output_node = g->nodes()->create<loco::TensorBroadcast>();
+
+    for (uint32_t axis = 0; axis < rank; ++axis)
+    {
+      auto input_dim = input_shape.dim(axis);
+      auto output_dim = output_shape.dim(axis);
+
+      // Unknown dimensions cannot be broadcast
+      assert(input_dim.known() and output_dim.known());
+
+      if (!(input_dim == output_dim))
+      {
+        // Only 1 -> N expansion is legal; record it in the broadcast mapping
+        assert(input_dim == 1);
+        output_node->mapping()->dim(axis) = output_dim;
+      }
+    }
+
+    output_node->input(input_node);
+
+    return glue(output_node, output_shape);
+  }
+
+private:
+  loco::TensorShape _shape;
+};
+
+// Pipeline stage factory: "value | expand_dims_as(shape)" broadcasts dims.
+ExpandDimsFunctor expand_dims_as(const loco::TensorShape &shape)
+{
+  return ExpandDimsFunctor{shape};
+}
+
+} // namespace
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Broadcast 'node' (whose shape is 'shape') to the target '_shape'
+ *
+ * The fipe-style pipeline first equalizes rank (may insert FixedReshape),
+ * then equalizes dimensions (may insert TensorBroadcast), and finally
+ * unwraps the resulting node from its shape wrapper.
+ */
+loco::Node *BroadcastFunctor::build(loco::Node *node, const loco::TensorShape &shape) const
+{
+  // clang-format off
+  return glue(node, shape)
+       | expand_rank_to(_shape.rank())
+       | expand_dims_as(_shape)
+       | [] (const NodeWithTensorShape &in) { return in.node(); };
+  // clang-format on
+}
+
+// Convenience overload that obtains the shape via shape inference.
+// Precondition (documented in the header): loco::shape_known(node) holds
+// and the inferred shape's domain is Tensor.
+loco::Node *BroadcastFunctor::build(loco::Node *node) const
+{
+  return build(node, loco::shape_get(node).as<loco::TensorShape>());
+}
+
+} // namespace tf
+} // namespace moco
diff --git a/compiler/moco-tf/src/BroadcastHelper.h b/compiler/moco-tf/src/BroadcastHelper.h
new file mode 100644 (file)
index 0000000..6238ad2
--- /dev/null
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __BROADCAST_HELPER_H__
+#define __BROADCAST_HELPER_H__
+
+#include <loco/IR/Node.h>
+#include <loco/IR/Dimension.h>
+#include <loco/IR/TensorShape.h>
+
+#include <bino.h>
+#include <fipe.h> // include "fipe.h" for clients
+
+namespace moco
+{
+namespace tf
+{
+
+/**
+ * @brief Functor that broadcasts a given node to a fixed target shape
+ *
+ * Construct with the target shape; invoke (or call "build") with a node,
+ * optionally paired with that node's known tensor shape.
+ */
+class BroadcastFunctor final
+{
+public:
+  BroadcastFunctor(const loco::TensorShape &shape) : _shape{shape}
+  {
+    // DO NOTHING
+  }
+
+public:
+  /// @brief Broadcast 'in_node' (whose shape is 'in_shape') to the target shape
+  loco::Node *build(loco::Node *in_node, const loco::TensorShape &in_shape) const;
+
+  loco::Node *operator()(loco::Node *in_node, const loco::TensorShape &in_shape) const
+  {
+    return build(in_node, in_shape);
+  }
+
+  // This method assumes the followings:
+  // - loco::shape_known(node) returns true, and
+  // - loco::shape_get(node).domain() is loco::Domain::Tensor
+  loco::Node *build(loco::Node *node) const;
+
+  loco::Node *operator()(loco::Node *node) const { return build(node); }
+
+private:
+  loco::TensorShape _shape;
+};
+
+/**
+ * @brief Create a broadcasted node
+ *
+ * First, append canonical.FixedReshape if rank expansion is required.
+ * Then, append canonical.TensorBroadcast if dimension expansion is required
+ *
+ * This mimics "tf.broadcast_to" API in TensorFlow.
+ *
+ * NOTE(review): bino::transform_both presumably applies the functor to both
+ *               elements of a pair of nodes at once — confirm against bino.h
+ */
+static inline auto broadcast_to(const loco::TensorShape &shape)
+    -> decltype(bino::transform_both(std::declval<BroadcastFunctor>()))
+{
+  return bino::transform_both(BroadcastFunctor{shape});
+}
+
+} // namespace tf
+} // namespace moco
+
+#endif // __BROADCAST_HELPER_H__
diff --git a/compiler/moco-tf/src/BroadcastHelper.test.cpp b/compiler/moco-tf/src/BroadcastHelper.test.cpp
new file mode 100644 (file)
index 0000000..a6cbd71
--- /dev/null
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BroadcastHelper.h"
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+TEST(BroadcastFunctorTest, expand_rank)
+{
+  // Broadcast Tensor<3> as Tensor<1 x 3>
+  auto g = loco::make_graph();
+
+  // Create a graph input so that "pull->index(0)" below refers to a valid
+  // slot; the variable itself is intentionally not used further.
+  auto input = g->inputs()->create();
+
+  auto pull = g->nodes()->create<loco::Pull>();
+  pull->index(0);
+
+  // Source shape: rank-1, [3]
+  loco::TensorShape current_shape;
+  {
+    current_shape.rank(1);
+    current_shape.dim(0) = 3;
+  }
+
+  // Target shape: rank-2, [1 x 3]
+  loco::TensorShape expected_shape;
+  {
+    expected_shape.rank(2);
+    expected_shape.dim(0) = 1;
+    expected_shape.dim(1) = 3;
+  }
+
+  moco::tf::BroadcastFunctor functor{expected_shape};
+
+  auto node = functor.build(pull, current_shape);
+
+  // Pure rank expansion should insert exactly one FixedReshape fed by 'pull'
+  ASSERT_EQ(node->opnum(), static_cast<uint32_t>(loco::CanonicalOpcode::FixedReshape));
+  ASSERT_EQ(node->arg(0), pull);
+}
+
+TEST(BroadcastFunctorTest, expand_dims)
+{
+  // Broadcast Tensor<1> as Tensor<3>
+  auto g = loco::make_graph();
+
+  // Create a graph input so that "pull->index(0)" below refers to a valid
+  // slot; the variable itself is intentionally not used further.
+  auto input = g->inputs()->create();
+
+  auto pull = g->nodes()->create<loco::Pull>();
+  pull->index(0);
+
+  // Source shape: rank-1, [1]
+  loco::TensorShape current_shape;
+  {
+    current_shape.rank(1);
+    current_shape.dim(0) = 1;
+  }
+
+  // Target shape: rank-1, [3] (same rank -> dimension expansion only)
+  loco::TensorShape expected_shape;
+  {
+    expected_shape.rank(1);
+    expected_shape.dim(0) = 3;
+  }
+
+  moco::tf::BroadcastFunctor functor{expected_shape};
+
+  auto node = functor.build(pull, current_shape);
+
+  // Pure dimension expansion should insert exactly one TensorBroadcast
+  ASSERT_EQ(node->opnum(), static_cast<uint32_t>(loco::CanonicalOpcode::TensorBroadcast));
+  ASSERT_EQ(node->arg(0), pull);
+
+  auto tensor_broadcast = dynamic_cast<loco::TensorBroadcast *>(node);
+
+  // The broadcast mapping should record the 1 -> 3 expansion on axis 0
+  ASSERT_NE(tensor_broadcast, nullptr);
+  ASSERT_TRUE(tensor_broadcast->mapping()->defined(0));
+  ASSERT_EQ(tensor_broadcast->mapping()->dim(0), 3);
+}