Add hash() global (#18258)
author David Riazati <davidriazati@fb.com>
Sat, 30 Mar 2019 01:23:28 +0000 (18:23 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 30 Mar 2019 01:29:34 +0000 (18:29 -0700)
Summary:
This adds `hash()`, which supports `int`, `str`, and `float`. It relies on `std::hash`, which is implementation-defined, so the result of `hash()` in TorchScript is not the same as in Python, but it satisfies the same property: equal values hash to equal results.
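
For example, a minimal sketch of the builtin in use (a sketch assuming this patch is applied; the printed values are implementation-defined and generally differ from CPython's `hash`):

```python
import torch

@torch.jit.script
def float_hash(x):
    # type: (float) -> int
    return hash(x)  # lowers to aten::hash(float t) -> int

# Equal inputs must hash equal within a process...
assert float_hash(2.5) == float_hash(2.5)
# ...but the value comes from std::hash<double>, so it need not match
# CPython's hash(2.5).
print(float_hash(2.5), hash(2.5))
```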
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18258

Differential Revision: D14692317

Pulled By: driazati

fbshipit-source-id: 909df5d024bb3feea157d5a203b7de53c72261c9

aten/src/ATen/core/interned_strings.h
test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/compiler.cpp

diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 25783d1..dfc6905 100644
@@ -98,6 +98,7 @@ namespace c10 {
   _(aten, _set_item)               \
   _(aten, index_put_)              \
   _(aten, device)                  \
+  _(aten, hash)                    \
   _(aten, len)                     \
   _(aten, list)                    \
   _(aten, wait)                    \
diff --git a/test/test_jit.py b/test/test_jit.py
index 7641070..4599c96 100644
@@ -9989,6 +9989,34 @@ a")
 
         self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
 
+    def test_hash(self):
+        def tester(fn, inputs):
+            for x in inputs:
+                for y in inputs:
+                    if x == y:
+                        self.assertEqual(fn(x), fn(y))
+                    else:
+                        self.assertNotEqual(fn(x), fn(y))
+
+        @torch.jit.script
+        def int_hash(x):
+            # type: (int) -> int
+            return hash(x)
+
+        @torch.jit.script
+        def float_hash(x):
+            # type: (float) -> int
+            return hash(x)
+
+        @torch.jit.script
+        def str_hash(x):
+            # type: (str) -> int
+            return hash(x)
+
+        tester(int_hash, (20, 21, 22))
+        tester(float_hash, (20.0, 21.00001, 22.443))
+        tester(str_hash, ("", "hello", "a"))
+
     def test_mutable_dce(self):
         @torch.jit.script
         def foo():
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 9e22d24..d000db1 100644
@@ -1552,6 +1552,14 @@ int dictGetDefault(Stack& stack) {
   return 0;
 }
 
+template<typename T>
+int hashValue(Stack& stack) {
+  auto value = pop(stack);
+  auto hash = std::hash<T>()(value.to<T>());
+  push(stack, int64_t(hash));
+  return 0;
+}
+
 RegisterOperators reg2({
 
 #define DEFINE_STRING_OP(op_name, string_op, result)                \
@@ -1914,6 +1922,11 @@ RegisterOperators reg2({
     CREATE_DICT_OPS("int"),
     CREATE_DICT_OPS("float"),
 #undef CREATE_DICT_OPS
+
+
+    Operator("aten::hash(str t) -> int", hashValue<std::string>),
+    Operator("aten::hash(int t) -> int", hashValue<int>),
+    Operator("aten::hash(float t) -> int", hashValue<double>),
 });
 
 // reference: _output_size in torch/nn/functional.py
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 1a06161..f87e066 100644
@@ -399,6 +399,7 @@ struct Environment {
           {"_to_tensor",
            std::make_shared<CastValue>(TensorType::get(), prim::NumToTensor)},
           {"len", std::make_shared<BuiltinFunction>(aten::len, at::nullopt)},
+          {"hash", std::make_shared<BuiltinFunction>(aten::hash, at::nullopt)},
           {"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
           {"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
           {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},