Remove templates for GenericDict
authorDavid Riazati <davidriazati@fb.com>
Sat, 16 Feb 2019 05:32:34 +0000 (21:32 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 16 Feb 2019 05:35:19 +0000 (21:35 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17175

Differential Revision: D14113022

Pulled By: driazati

fbshipit-source-id: 5183e131cc8ccb58525875f76fa03133570a59ea

aten/src/ATen/core/ivalue.h
test/cpp/api/jit.cpp
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/register_prim_ops.cpp

index 05109a8..cb633ec 100644 (file)
@@ -73,37 +73,10 @@ struct DictEqualTo {
   bool operator()(const IValue& lhs, const IValue& rhs) const;
 };
 
-template <typename Key, typename Value>
-using DictUnorderedMap = std::unordered_map<Key, Value, DictHash, DictEqualTo>;
-
-template <typename Key, typename Value>
-struct CAFFE2_API Dict : c10::intrusive_ptr_target {
- private:
-  DictUnorderedMap<Key, Value> elements_;
-
- public:
-  Dict(DictUnorderedMap<Key, Value> elements_)
-      : elements_(std::move(elements_)) {}
-  static c10::intrusive_ptr<Dict> create(
-      DictUnorderedMap<Key, Value> elements_) {
-    return c10::make_intrusive<Dict>(std::move(elements_));
-  }
-  const DictUnorderedMap<Key, Value>& elements() const {
-    return elements_;
-  }
-  operator const DictUnorderedMap<Key, Value>&() const {
-    return elements();
-  }
-
-  DictUnorderedMap<Key, Value>& elements() {
-    return elements_;
-  }
-  operator DictUnorderedMap<Key, Value>&() {
-    return elements();
-  }
-};
+using UnorderedMap = std::unordered_map<IValue, IValue, DictHash, DictEqualTo>;
 
 struct Future;
+struct GenericDict;
 
 struct CAFFE2_API Tuple : public List<IValue> {
   using List<IValue>::List;
@@ -116,7 +89,6 @@ using TensorList = List<at::Tensor>;
 using DoubleList = List<double>;
 using BoolList = List<bool>;
 using GenericList = List<IValue>;
-using GenericDict = Dict<IValue, IValue>;
 
 
 }
@@ -339,7 +311,7 @@ struct CAFFE2_API IValue final {
   const std::vector<bool>& toBoolListRef() const;
   const std::vector<at::Tensor>& toTensorListRef() const;
   const std::vector<IValue>& toGenericListRef() const;
-  const ivalue::DictUnorderedMap<IValue, IValue>& toGenericDictRef() const;
+  const ivalue::UnorderedMap& toGenericDictRef() const;
   const std::string& toStringRef() const;
 
   // ConstantString
@@ -409,7 +381,7 @@ struct CAFFE2_API IValue final {
 
   // GenericDict
   IValue(c10::intrusive_ptr<ivalue::GenericDict> v);
-  IValue(ivalue::DictUnorderedMap<IValue, IValue> v);
+  IValue(ivalue::UnorderedMap v);
   bool isGenericDict() const { return Tag::GenericDict == tag; }
   c10::intrusive_ptr<ivalue::GenericDict> toGenericDict() && {
     AT_ASSERT(isGenericDict());
@@ -692,6 +664,32 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
   FutureError error;
 };
 
+struct C10_EXPORT ivalue::GenericDict : c10::intrusive_ptr_target {
+ private:
+  UnorderedMap elements_;
+
+ public:
+  GenericDict(UnorderedMap elements_)
+      : elements_(std::move(elements_)) {}
+  static c10::intrusive_ptr<GenericDict> create(
+      UnorderedMap elements_) {
+    return c10::make_intrusive<GenericDict>(std::move(elements_));
+  }
+  const UnorderedMap& elements() const {
+    return elements_;
+  }
+  operator const UnorderedMap&() const {
+    return elements();
+  }
+
+  UnorderedMap& elements() {
+    return elements_;
+  }
+  operator UnorderedMap&() {
+    return elements();
+  }
+};
+
 #undef TORCH_FORALL_TAGS
 
 namespace detail {
@@ -806,7 +804,7 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::GenericDict> v)
 : tag(Tag::GenericDict), is_intrusive_ptr(true) {
   payload.as_intrusive_ptr = v.release();
 }
-inline IValue::IValue(ivalue::DictUnorderedMap<IValue, IValue> v)
+inline IValue::IValue(ivalue::UnorderedMap v)
 : IValue(ivalue::GenericDict::create(std::move(v))) {}
 
 inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
@@ -834,7 +832,7 @@ inline const std::vector<IValue>& IValue::toGenericListRef() const {
   return toGenericList()->elements();
 }
 
-inline const c10::ivalue::DictUnorderedMap<IValue, IValue>& IValue::
+inline const c10::ivalue::UnorderedMap& IValue::
     toGenericDictRef() const {
   return toGenericDict()->elements();
 }
index 43a5958..a433578 100644 (file)
@@ -93,7 +93,7 @@ TEST(TorchScriptTest, TestDictArgMatching) {
       def dict_op(a: Dict[str, Tensor], b: str):
         return a[b]
     )JIT");
-  c10::ivalue::DictUnorderedMap<torch::jit::IValue, torch::jit::IValue> dict;
+  c10::ivalue::UnorderedMap dict;
   dict[std::string("hello")] = torch::ones({2});
   auto output = module->run_method("dict_op", dict, std::string("hello"));
   ASSERT_EQ(1, output.toTensor()[0].item<int64_t>());
index 9f43e72..e16e07a 100644 (file)
@@ -119,13 +119,13 @@ inline IValue createGenericDict(
     py::handle obj,
     const TypePtr& key_type,
     const TypePtr& value_type) {
-  at::ivalue::DictUnorderedMap<IValue, IValue> elems;
+  at::ivalue::UnorderedMap elems;
   elems.reserve(py::len(obj));
   for (auto key : obj) {
     elems.insert(std::make_pair(
         toIValue(key, key_type), toIValue(obj[key], value_type)));
   }
-  return at::ivalue::Dict<IValue, IValue>::create(std::move(elems));
+  return at::ivalue::GenericDict::create(std::move(elems));
 }
 
 inline IValue toIValue(
index b5cd2f2..969c2ab 100644 (file)
@@ -825,7 +825,7 @@ RegisterOperators reg({
                 "DictConstruct must have an even number of inputs");
           }
           return [=](Stack& stack) {
-            c10::ivalue::DictUnorderedMap<IValue, IValue> vals;
+            c10::ivalue::UnorderedMap vals;
             for (size_t i = 0; i < num_inputs; i += 2) {
               auto val = pop(stack);
               auto key = pop(stack);