[Relay] use unordered_map instead of map in ANF (#3024)
author雾雨魔理沙 <lolisa@marisa.moe>
Mon, 15 Apr 2019 19:56:31 +0000 (12:56 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Mon, 15 Apr 2019 19:56:31 +0000 (12:56 -0700)
src/relay/pass/to_a_normal_form.cc

index 1f0ed9e..5e4253d 100644 (file)
@@ -34,7 +34,9 @@
 namespace tvm {
 namespace relay {
 
-Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
+Expr ToANormalForm(const Expr& e,
+                   const Module& m,
+                   std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv);
 
 struct ScopeNode;
 using Scope = std::shared_ptr<ScopeNode>;
@@ -104,7 +106,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
                             const Module& m,
                             const DependencyGraph& dg,
                             std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
-                            std::set<GlobalVar>* gv) {
+                            std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
     Fill fi(m, dg, node_scope, gv);
     return fi.GetScope(e)->ll->Get(fi.VisitExpr(e));
   }
@@ -113,13 +115,13 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
   Module mod_;
   const DependencyGraph& dg_;
   std::unordered_map<DependencyGraph::Node*, Scope>* node_scope_;
-  std::set<GlobalVar>* visited_;
+  std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited_;
   std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;
 
   Fill(Module mod,
        const DependencyGraph& dg,
        std::unordered_map<DependencyGraph::Node*, Scope>* node_scope,
-       std::set<GlobalVar>* visited) :
+       std::unordered_set<GlobalVar, NodeHash, NodeEqual>* visited) :
     mod_(mod),
     dg_(dg),
     node_scope_(node_scope),
@@ -273,7 +275,9 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
   }
 };
 
-Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
+Expr ToANormalFormAux(const Expr& e,
+                      const Module& m,
+                      std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
   /* When you lift a lambda, what is inside is also being lift.
    *
    * So we must determine the scope of the lambda before determining the scope of it's body.
@@ -299,12 +303,14 @@ Expr ToANormalFormAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
   return Fill::ToANormalForm(e, m, dg, &node_scope, gv);
 }
 
-Expr ToANormalForm(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
+Expr ToANormalForm(const Expr& e,
+                   const Module& m,
+                   std::unordered_set<GlobalVar, NodeHash, NodeEqual>* gv) {
   return TransformF([&](const Expr& e) { return ToANormalFormAux(e, m, gv); }, e);
 }
 
 Expr ToANormalForm(const Expr& e, const Module& m) {
-  std::set<GlobalVar> gv;
+  std::unordered_set<GlobalVar, NodeHash, NodeEqual> gv;
   return ToANormalForm(e, m, &gv);
 }