[locoex/custom op] enhancing locoex custom op (#5810)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 24 Jul 2019 22:43:02 +0000 (07:43 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 24 Jul 2019 22:43:02 +0000 (07:43 +0900)
Now customop node can store dtype info. Two methods are defined as const.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/locoex-customop/include/locoex/COpCall.h
compiler/locoex-customop/src/COpCall.cpp

index 716377f..197fd8d 100644 (file)
@@ -33,7 +33,8 @@ namespace locoex
  * @brief Class to calls custom operation
  */
 class COpCall final : public VariadicArityNode<COpNode>,
-                      public loco::NodeMixin<loco::NodeTrait::TensorShape>
+                      public loco::NodeMixin<loco::NodeTrait::TensorShape>,
+                      public loco::NodeMixin<loco::NodeTrait::DataType>
 {
 public:
   COpCall(unsigned arity) : VariadicArityNode<COpNode>(arity) {}
@@ -53,10 +54,10 @@ public:
 
   /// @brief  Retrieve attr_data stored with attr_name
   template <COpAttrType AT>
-  const typename AttrTypeTrait<AT>::Type *attr(const std::string &attr_name);
+  const typename AttrTypeTrait<AT>::Type *attr(const std::string &attr_name) const;
 
   /// @brief get all the names of attr
-  std::vector<std::string> attr_names();
+  std::vector<std::string> attr_names() const;
 
 private:
   std::string _op;
index 527e0d4..0299147 100644 (file)
@@ -22,12 +22,13 @@ namespace locoex
 {
 
 template <COpAttrType AT>
-const typename AttrTypeTrait<AT>::Type *COpCall::attr(const std::string &attr_name)
+const typename AttrTypeTrait<AT>::Type *COpCall::attr(const std::string &attr_name) const
 {
   COpAttrData *attr_data;
-  if (_attrs.find(attr_name) != _attrs.end())
+  auto found = _attrs.find(attr_name);
+  if (found != _attrs.end())
   {
-    attr_data = _attrs[attr_name].get();
+    attr_data = found->second.get();
     return dynamic_cast<const typename AttrTypeTrait<AT>::Type *>(attr_data);
   }
   else
@@ -42,12 +43,11 @@ void COpCall::attr(const std::string &attr_name, std::unique_ptr<COpAttrData> &&
     throw std::runtime_error("Attr already inserted");
 }
 
-std::vector<std::string> COpCall::attr_names()
+std::vector<std::string> COpCall::attr_names() const
 {
   std::vector<std::string> attr_names;
 
-  for (std::map<std::string, std::unique_ptr<COpAttrData>>::iterator it = _attrs.begin();
-       it != _attrs.end(); ++it)
+  for (auto it = _attrs.cbegin(); it != _attrs.cend(); ++it)
   {
     attr_names.emplace_back(it->first);
   }
@@ -55,8 +55,9 @@ std::vector<std::string> COpCall::attr_names()
   return attr_names;
 }
 
-#define INSTANTIATE(AT) \
-  template const typename AttrTypeTrait<AT>::Type *COpCall::attr<AT>(const std::string &attr_name);
+#define INSTANTIATE(AT)                                                                            \
+  template const typename AttrTypeTrait<AT>::Type *COpCall::attr<AT>(const std::string &attr_name) \
+      const;
 
 INSTANTIATE(COpAttrType::Float)
 INSTANTIATE(COpAttrType::Int)