DEFINE_TO(at::ScalarType, toScalarType)
DEFINE_TO(at::Layout, toLayout)
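+// _fake_type<T> is an empty tag struct: it carries the requested return type
+// into overload resolution so the correct generic_to overload is chosen.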
+template <typename T>
+struct _fake_type {};
+
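+// Converts an IValue holding a GenericList into a concrete std::vector<Elem>,
+// converting each element in turn.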
+template <typename Elem>
+std::vector<Elem> generic_to(
+    const IValue* ivalue,
+    _fake_type<std::vector<Elem>>) {
+  return fmap(ivalue->toGenericListRef(), [](IValue item_ivalue) {
+    return item_ivalue.to<Elem>();
+  });
+}
+
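+// Converts an IValue holding a GenericDict into a concrete
+// std::unordered_map<K, V>, converting every key and value.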
+template <typename K, typename V>
+std::unordered_map<K, V> generic_to(
+    const IValue* ivalue,
+    _fake_type<std::unordered_map<K, V>>) {
+  std::unordered_map<K, V> specialized_dict;
+
+  for (const auto& item : ivalue->toGenericDictRef()) {
+    specialized_dict[item.first.to<K>()] = item.second.to<V>();
+  }
+
+  return specialized_dict;
+}
+
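+// Both the rvalue and lvalue to<T>() dispatch through generic_to on the
+// _fake_type<T> tag; e.g. ivalue.to<std::vector<std::string>>() selects the
+// std::vector overload above.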
+template <typename T>
+inline T IValue::to() && {
+  return generic_to(this, _fake_type<T>{});
+}
+
+template <typename T>
+inline T IValue::to() const& {
+  return generic_to(this, _fake_type<T>{});
+}
+
// note: when adding a DEFINE_TO case here you should also add a
// toX method to IValue. These named methods are much more discoverable
// than the templated to<T>() function.
return type;
}
};
-template<class T> struct getTypePtr_<at::optional<T>> final {
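+// std::unordered_map<K, V> arguments are typed as the JIT's DictType.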
+template <class K, class V>
+struct getTypePtr_<std::unordered_map<K, V>> final {
+  static TypePtr call() {
+    static auto type =
+        DictType::create(getTypePtr_<K>::call(), getTypePtr_<V>::call());
+    return type;
+  }
+};
+template <class T>
+struct getTypePtr_<at::optional<T>> final {
  static TypePtr call() {
    static auto type = OptionalType::create(getTypePtr_<T>::call());
    return type;
  }
};
template<class T> inline TypePtr getTypePtr() {
+  // TODO: static_assert that a templated function exists, and throw a
+  // friendly error message if not
return detail::getTypePtr_<T>::call();
}
return (%1)
''')
+    def test_generic_list(self):
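+        # The nested Python list is converted to the
+        # std::vector<std::vector<std::string>> argument of _test::get_first.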
+        self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello')
+
class TestJitGeneratedAutograd(JitTestCase):
pass
return at::cat(tensors);
}
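+// Returns the first string of the first inner list; exercises generic list
+// arguments in custom ops. Registered below without an explicit schema, so
+// the signature is inferred from the function type.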
+std::string get_first(const std::vector<std::vector<std::string>>& strings) {
+  return strings[0][0];
+}
+
static auto reg4 =
torch::jit::RegisterOperators()
.op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor",
&leaky_relu)
-        .op("_test::cat(Tensor[] inputs) -> Tensor", &cat);
+        .op("_test::cat(Tensor[] inputs) -> Tensor", &cat)
+        .op("_test::get_first", &get_first);
} // namespace
} // namespace jit
void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
AT_ERROR("Tracing float lists currently not supported!");
}
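+// std::vector<double> overload: like the ArrayRef<double> overload above,
+// tracing float lists is not supported yet.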
+void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<double>& value) {
+  AT_ERROR("Tracing float lists currently not supported!");
+}
void addOutput(Node* node, const at::Tensor& output) {
setOutput(node->addOutput(), output);
Node* n,
const char* name,
const ArrayRef<double>& value);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<double>& value);
TORCH_API void addInputs(Node* n, const char* name, const std::string& value);
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const c10::optional<at::ScalarType>& value);
TORCH_API void addInputs(Node* n, const char* name, at::Generator* value);
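+// Catch-all templates for generic list and dict arguments: these values can
+// be passed to operators but cannot be traced, so they fail at trace time.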
+template <typename T>
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const std::vector<T>& value) {
+  AT_ERROR("Tracing generic lists currently not supported!");
+}
+template <typename K, typename V>
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const std::unordered_map<K, V>& value) {
+  AT_ERROR("Tracing generic dicts currently not supported!");
+}
+
template <size_t N>
void addInputs(Node* n, const char* name, std::array<bool, N> value) {
throw std::runtime_error(