Add generic list/dict custom op bindings (#17037)
authorDavid Riazati <davidriazati@fb.com>
Fri, 22 Feb 2019 22:38:33 +0000 (14:38 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Feb 2019 22:49:43 +0000 (14:49 -0800)
Summary:
Fixes #17017
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17037

Differential Revision: D14095703

Pulled By: driazati

fbshipit-source-id: 2b5ae20d42ad21c98c86a8f1cd7f1de175510507

aten/src/ATen/core/ivalue.h
aten/src/ATen/core/jit_type.h
test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/tracer.cpp
torch/csrc/jit/tracer.h

index 08b2298..1ae317f 100644 (file)
@@ -750,6 +750,39 @@ DEFINE_TO(c10::Device, toDevice)
 DEFINE_TO(at::ScalarType, toScalarType)
 DEFINE_TO(at::Layout, toLayout)
 
+template <typename T>
+struct _fake_type {};
+
+template <typename Elem>
+std::vector<Elem> generic_to(
+    const IValue* ivalue,
+    _fake_type<std::vector<Elem>>) {
+  return fmap(ivalue->toGenericListRef(), [](IValue item_ivalue) { return item_ivalue.to<Elem>(); });
+}
+
+template <typename K, typename V>
+std::unordered_map<K, V> generic_to(
+    const IValue* ivalue,
+    _fake_type<std::unordered_map<K, V>>) {
+  std::unordered_map<K, V> specialized_dict;
+
+  for (auto item : ivalue->toGenericDictRef()) {
+    specialized_dict[item.first.to<K>()] = item.second.to<V>();
+  }
+
+  return specialized_dict;
+}
+
+template <typename T>
+inline T IValue::to() && {
+  return generic_to(this, _fake_type<T>{});
+}
+
+template <typename T>
+inline T IValue::to() const& {
+  return generic_to(this, _fake_type<T>{});
+}
+
 // note: when adding a DEFINE_TO case here you should also add a
 // toX method to IValue. These named methods are much more discoverable
 // than the to templated function.
index 4db6f06..7091208 100644 (file)
@@ -1042,7 +1042,16 @@ template<class T> struct getTypePtr_<ArrayRef<T>> final {
     return type;
   }
 };
-template<class T> struct getTypePtr_<at::optional<T>> final {
+template <class K, class V>
+struct getTypePtr_<std::unordered_map<K, V>> final {
+  static TypePtr call() {
+    static auto type =
+        DictType::create(getTypePtr_<K>::call(), getTypePtr_<V>::call());
+    return type;
+  }
+};
+template <class T>
+struct getTypePtr_<at::optional<T>> final {
   static TypePtr call() {
     static auto type = OptionalType::create(getTypePtr_<T>::call());
     return type;
@@ -1050,6 +1059,8 @@ template<class T> struct getTypePtr_<at::optional<T>> final {
 };
 }
 template<class T> inline TypePtr getTypePtr() {
+  // TODO: static_assert that a templated function exists, and throw a friendy
+  // error message if not
   return detail::getTypePtr_<T>::call();
 }
 
index 1fb8276..9a425af 100644 (file)
@@ -12367,6 +12367,9 @@ graph(%x : Tensor):
   return (%1)
 ''')
 
+    def test_generic_list(self):
+        self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello')
+
 
 class TestJitGeneratedAutograd(JitTestCase):
     pass
index 4b091b2..ce85ec6 100644 (file)
@@ -1910,11 +1910,16 @@ at::Tensor cat(const std::vector<at::Tensor>& tensors) {
   return at::cat(tensors);
 }
 
+std::string get_first(const std::vector<std::vector<std::string>>& strings) {
+  return strings[0][0];
+}
+
 static auto reg4 =
     torch::jit::RegisterOperators()
         .op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor",
             &leaky_relu)
-        .op("_test::cat(Tensor[] inputs) -> Tensor", &cat);
+        .op("_test::cat(Tensor[] inputs) -> Tensor", &cat)
+        .op("_test::get_first", &get_first);
 
 } // namespace
 } // namespace jit
index 86c4733..7855b75 100644 (file)
@@ -429,6 +429,12 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
 void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
   AT_ERROR("Tracing float lists currently not supported!");
 }
+void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<double>& value) {
+  AT_ERROR("Tracing float lists currently not supported!");
+}
 
 void addOutput(Node* node, const at::Tensor& output) {
   setOutput(node->addOutput(), output);
index 16c6ce7..80fd4a0 100644 (file)
@@ -96,6 +96,10 @@ TORCH_API void addInputs(
     Node* n,
     const char* name,
     const ArrayRef<double>& value);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<double>& value);
 TORCH_API void addInputs(Node* n, const char* name, const std::string& value);
 TORCH_API void addInputs(
     Node* n,
@@ -114,6 +118,21 @@ TORCH_API void addInputs(
     const c10::optional<at::ScalarType>& value);
 TORCH_API void addInputs(Node* n, const char* name, at::Generator* value);
 
+template<typename T>
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<T>& value) {
+  AT_ERROR("Tracing generic lists currently not supported!");
+}
+template<typename K, typename V>
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const std::unordered_map<K, V>& value) {
+  AT_ERROR("Tracing generic dicts currently not supported!");
+}
+
 template <size_t N>
 void addInputs(Node* n, const char* name, std::array<bool, N> value) {
   throw std::runtime_error(