From 227a0ff5d13a9dbda2d02d3e8698db0c6e0261be Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 4 Jul 2019 19:27:03 +0900 Subject: [PATCH] [loco] Add DataType attribute to Graph Input/Output (#4097) This commit allows loco frontends to record the data type of graph-level input/output in GraphInput/GraphOutput. Signed-off-by: Jonghyun Park --- contrib/loco/include/loco/IR/Graph.h | 30 ++++++++++++++++++++++++++++-- contrib/loco/src/IR/Graph.test.cpp | 15 +++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/contrib/loco/include/loco/IR/Graph.h b/contrib/loco/include/loco/IR/Graph.h index 2f06321..7e77e0b 100644 --- a/contrib/loco/include/loco/IR/Graph.h +++ b/contrib/loco/include/loco/IR/Graph.h @@ -17,6 +17,7 @@ #ifndef __LOCO_IR_GRAPH_H__ #define __LOCO_IR_GRAPH_H__ +#include "loco/IR/DataType.h" #include "loco/IR/Nodes.h" #include "loco/ADT/ObjectPool.h" @@ -29,6 +30,31 @@ namespace loco { +// TODO Introduce Named trait +enum class Trait +{ + // Any "DataTyped" class has the following methods + // - DataType dtype(void) const; + // - void dtype(const DataType &value); + DataTyped, +}; + +template class Mixin; + +// TODO Re-implement NodeMixin using this mixin +template <> class Mixin +{ +public: + Mixin() = default; + +public: + const DataType &dtype(void) const { return _dtype; } + void dtype(const DataType &value) { _dtype = value; } + +private: + DataType _dtype = DataType::Unknown; +}; + /** * @brief Trait for elements with name */ @@ -49,7 +75,7 @@ private: /** * @brief Graph-level Input Metadata */ -class GraphInput final : private NamedEntity +class GraphInput final : private NamedEntity, public Mixin { public: LOCO_NAMED_ENTITY_EXPOSE; @@ -72,7 +98,7 @@ private: /** * @brief Graph-level Output Metadata */ -class GraphOutput final : private NamedEntity +class GraphOutput final : private NamedEntity, public Mixin { public: LOCO_NAMED_ENTITY_EXPOSE; diff --git a/contrib/loco/src/IR/Graph.test.cpp b/contrib/loco/src/IR/Graph.test.cpp index ea683ee..0a68c04 100644 --- a/contrib/loco/src/IR/Graph.test.cpp +++ b/contrib/loco/src/IR/Graph.test.cpp @@ -44,6 +44,21 @@ TEST(NamedTest, setter_and_getter) ASSERT_EQ(elem.name(), "name"); } +TEST(DataTypedMixinTest, constructor) +{ + loco::Mixin mixin; + + ASSERT_EQ(mixin.dtype(), loco::DataType::Unknown); +} + +TEST(DataTypedMixinTest, setter_and_getter) +{ + loco::Mixin mixin; + + mixin.dtype(loco::DataType::FLOAT32); + ASSERT_EQ(mixin.dtype(), loco::DataType::FLOAT32); +} + TEST(GraphTest, create_and_destroy_node) { auto g = loco::make_graph(); -- 2.7.4