From a710df9548200d180e2b575cb1d8e1821a48fede Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 31 May 2019 14:52:36 +0900 Subject: [PATCH] [loco] add int32 type (#3636) * [loco] add int32 type This will add int32 type for data types Signed-off-by: SaeHie Park * change order * another change order * move after todo --- contrib/loco/include/loco/IR/DataTypeTraits.h | 8 ++++++ contrib/loco/src/IR/Nodes.cpp | 1 + contrib/loco/src/IR/Nodes.test.cpp | 41 +++++++++++++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/contrib/loco/include/loco/IR/DataTypeTraits.h b/contrib/loco/include/loco/IR/DataTypeTraits.h index 660d5d5..ef73187 100644 --- a/contrib/loco/include/loco/IR/DataTypeTraits.h +++ b/contrib/loco/include/loco/IR/DataTypeTraits.h @@ -19,6 +19,8 @@ #include "loco/IR/DataType.h" +#include + namespace loco { @@ -31,6 +33,12 @@ template struct DataTypeImpl }; // TODO Support other enum values +template <> struct DataTypeImpl +{ + // Use C++ int32_t type for 32bit integer + using Type = int32_t; +}; + template <> struct DataTypeImpl { // Use C++ float type for IEEE 32-bit floating-point numbers diff --git a/contrib/loco/src/IR/Nodes.cpp b/contrib/loco/src/IR/Nodes.cpp index 1c5c600..5f197f2 100644 --- a/contrib/loco/src/IR/Nodes.cpp +++ b/contrib/loco/src/IR/Nodes.cpp @@ -59,6 +59,7 @@ template typename DataTypeImpl
::Type &ConstGen::at(uint32_t n) template const typename DataTypeImpl
::Type &ConstGen::at
(uint32_t) const; \ template typename DataTypeImpl
::Type &ConstGen::at
(uint32_t); +INSTANTIATE(DataType::S32); INSTANTIATE(DataType::FLOAT32); #undef INSTANTIATE diff --git a/contrib/loco/src/IR/Nodes.test.cpp b/contrib/loco/src/IR/Nodes.test.cpp index 4337bce..45a4111 100644 --- a/contrib/loco/src/IR/Nodes.test.cpp +++ b/contrib/loco/src/IR/Nodes.test.cpp @@ -118,6 +118,47 @@ TEST(ConstGenTest, constructor) ASSERT_EQ(constgen_node.at(5), 5.0f); } +TEST(ConstGenTest, constructor_s32) +{ + loco::ConstGen constgen_node; + + ASSERT_EQ(constgen_node.dtype(), loco::DataType::Unknown); + ASSERT_EQ(constgen_node.rank(), 0); + + constgen_node.dtype(loco::DataType::S32); + ASSERT_EQ(constgen_node.dtype(), loco::DataType::S32); + + constgen_node.rank(2); + ASSERT_EQ(constgen_node.rank(), 2); + + constgen_node.dim(0) = loco::make_dimension(2); + constgen_node.dim(1) = loco::make_dimension(3); + + ASSERT_TRUE(constgen_node.dim(0).known()); + ASSERT_TRUE(constgen_node.dim(1).known()); + + ASSERT_EQ(constgen_node.dim(0), 2); + ASSERT_EQ(constgen_node.dim(1), 3); + + constgen_node.size(6); + + ASSERT_EQ(constgen_node.size(), 6); + + constgen_node.at(0) = 0; // Set 0,0 + constgen_node.at(1) = 1; // Set 0,1 + constgen_node.at(2) = 2; // Set 0,2 + constgen_node.at(3) = -3; // Set 1,0 + constgen_node.at(4) = -4; // Set 1,1 + constgen_node.at(5) = -5; // Set 1,2 + + ASSERT_EQ(constgen_node.at(0), 0); + ASSERT_EQ(constgen_node.at(1), 1); + ASSERT_EQ(constgen_node.at(2), 2); + ASSERT_EQ(constgen_node.at(3), -3); + ASSERT_EQ(constgen_node.at(4), -4); + ASSERT_EQ(constgen_node.at(5), -5); +} + TEST(MaxPool2DTest, constructor) { loco::MaxPool2D maxpool_node; -- 2.7.4