Add generic list/dict custom op bindings (#17587)
authorDavid Riazati <davidriazati@fb.com>
Thu, 28 Feb 2019 22:43:05 +0000 (14:43 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Feb 2019 23:00:26 +0000 (15:00 -0800)
Summary:
Fixes #17017

Sandcastle refuses to land #17037, so trying fresh here
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17587

Differential Revision: D14265402

Pulled By: driazati

fbshipit-source-id: b942721aa9360ac6b3862f552ac95529eb0cf52c

aten/src/ATen/core/ivalue.h
aten/src/ATen/core/jit_type.h
test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/tracer.cpp
torch/csrc/jit/tracer.h

index 533edfa..5347e48 100644 (file)
@@ -794,6 +794,39 @@ DEFINE_TO(c10::Device, toDevice)
 DEFINE_TO(at::ScalarType, toScalarType)
 DEFINE_TO(at::Layout, toLayout)
 
+template <typename T>
+struct _fake_type {};
+
+template <typename Elem>
+std::vector<Elem> generic_to(
+    const IValue* ivalue,
+    _fake_type<std::vector<Elem>>) {
+  return fmap(ivalue->toGenericListRef(), [](IValue item_ivalue) { return item_ivalue.to<Elem>(); });
+}
+
+template <typename K, typename V>
+std::unordered_map<K, V> generic_to(
+    const IValue* ivalue,
+    _fake_type<std::unordered_map<K, V>>) {
+  std::unordered_map<K, V> specialized_dict;
+
+  for (auto item : ivalue->toGenericDictRef()) {
+    specialized_dict[item.first.to<K>()] = item.second.to<V>();
+  }
+
+  return specialized_dict;
+}
+
+template <typename T>
+inline T IValue::to() && {
+  return generic_to(this, _fake_type<T>{});
+}
+
+template <typename T>
+inline T IValue::to() const& {
+  return generic_to(this, _fake_type<T>{});
+}
+
 // note: when adding a DEFINE_TO case here you should also add a
 // toX method to IValue. These named methods are much more discoverable
 // than the to templated function.
index 3e2cd3a..83c6a16 100644 (file)
@@ -1052,7 +1052,16 @@ template<class T> struct getTypePtr_<ArrayRef<T>> final {
     return type;
   }
 };
-template<class T> struct getTypePtr_<at::optional<T>> final {
+template <class K, class V>
+struct getTypePtr_<std::unordered_map<K, V>> final {
+  static TypePtr call() {
+    static auto type =
+        DictType::create(getTypePtr_<K>::call(), getTypePtr_<V>::call());
+    return type;
+  }
+};
+template <class T>
+struct getTypePtr_<at::optional<T>> final {
   static TypePtr call() {
     static auto type = OptionalType::create(getTypePtr_<T>::call());
     return type;
@@ -1060,6 +1069,8 @@ template<class T> struct getTypePtr_<at::optional<T>> final {
 };
 }
 template<class T> inline TypePtr getTypePtr() {
+  // TODO: static_assert that a templated function exists, and throw a friendy
+  // error message if not
   return detail::getTypePtr_<T>::call();
 }
 
index d3820a6..ceb4fe1 100644 (file)
@@ -12478,6 +12478,9 @@ graph(%x : Tensor):
   return (%1)
 ''')
 
+    def test_generic_list(self):
+        self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello')
+
 
 class TestJitGeneratedAutograd(JitTestCase):
     pass
index c1f425d..353357d 100644 (file)
@@ -1958,11 +1958,16 @@ at::Tensor cat(const std::vector<at::Tensor>& tensors) {
   return at::cat(tensors);
 }
 
+std::string get_first(const std::vector<std::vector<std::string>>& strings) {
+  return strings[0][0];
+}
+
 static auto reg4 =
     torch::jit::RegisterOperators()
         .op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor",
             &leaky_relu)
-        .op("_test::cat(Tensor[] inputs) -> Tensor", &cat);
+        .op("_test::cat(Tensor[] inputs) -> Tensor", &cat)
+        .op("_test::get_first", &get_first);
 
 } // namespace
 } // namespace jit
index 86c4733..7855b75 100644 (file)
@@ -429,6 +429,12 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
 void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
   AT_ERROR("Tracing float lists currently not supported!");
 }
+void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<double>& value) {
+  AT_ERROR("Tracing float lists currently not supported!");
+}
 
 void addOutput(Node* node, const at::Tensor& output) {
   setOutput(node->addOutput(), output);
index 16c6ce7..09ea106 100644 (file)
@@ -96,6 +96,10 @@ TORCH_API void addInputs(
     Node* n,
     const char* name,
     const ArrayRef<double>& value);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<double>& value);
 TORCH_API void addInputs(Node* n, const char* name, const std::string& value);
 TORCH_API void addInputs(
     Node* n,
@@ -114,6 +118,33 @@ TORCH_API void addInputs(
     const c10::optional<at::ScalarType>& value);
 TORCH_API void addInputs(Node* n, const char* name, at::Generator* value);
 
+template<typename T>
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<T>& value);
+
+template<typename K, typename V>
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const std::unordered_map<K, V>& value);
+
+template<typename T>
+void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<T>& value) {
+  AT_ERROR("Tracing a list of arbitrary type is currently not supported!");
+}
+template<typename K, typename V>
+void addInputs(
+    Node* n,
+    const char* name,
+    const std::unordered_map<K, V>& value) {
+  AT_ERROR("Tracing a dict of arbitrary types is currently not supported!");
+}
+
 template <size_t N>
 void addInputs(Node* n, const char* name, std::array<bool, N> value) {
   throw std::runtime_error(