Add string index/slice operations (#18247)
author David Riazati <davidriazati@fb.com>
Mon, 1 Apr 2019 18:58:28 +0000 (11:58 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 1 Apr 2019 19:11:35 +0000 (12:11 -0700)
Summary:
Adds support for string indexing (`"a"[0]`), slicing (`"abc"[1:3]`), and
the `ord()` builtin to script.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18247

Differential Revision: D14574486

Pulled By: driazati

fbshipit-source-id: 4b42aa0881e5398ea7f112be46c0335e6e19dced

aten/src/ATen/core/interned_strings.h
test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/schema_matching.cpp

index dfc6905..db61cb0 100644 (file)
@@ -56,6 +56,7 @@ namespace c10 {
   _(prim, ListUnpack)              \
   _(prim, DictConstruct)           \
   _(prim, DictIndex)               \
+  _(prim, StringIndex)             \
   _(prim, NumToTensor)             \
   _(prim, ImplicitTensorToNum)     \
   _(prim, Bool)                    \
@@ -102,6 +103,7 @@ namespace c10 {
   _(aten, len)                     \
   _(aten, list)                    \
   _(aten, wait)                    \
+  _(aten, ord)                     \
   _(prim, unchecked_unwrap_optional)\
   FORALL_ATEN_BASE_SYMBOLS(_)      \
   _(onnx, Add)                     \
index eff70f5..7820a3c 100644 (file)
@@ -8953,7 +8953,7 @@ a")
                 return torch.jit._unwrap_optional(None)
 
     def test_indexing_error(self):
-        with self.assertRaisesRegex(RuntimeError, "only supported on lists, dictionaries, tensors, and tuples"):
+        with self.assertRaisesRegex(RuntimeError, "only supported on List, Dict, Tensor, Tuple, and str"):
             @torch.jit.script
             def test_wrong_type():
                 a = 8
@@ -10184,6 +10184,43 @@ a")
         with self.capture_stdout() as captured:
             print(fn(x, scale, shift))
 
+    def test_string_index(self):
+        # Indexing a str with a non-negative int is compared against Python.
+        def third_char(text):
+            # type: (str) -> str
+            return text[2]
+        self.checkScript(third_char, ("abcde",))
+
+    def test_ord(self):
+        def fn(x):
+            # type: (str) -> int
+            return ord(x)
+        # Pass real 1-tuples: ("h") is just the str "h", not a tuple of inputs.
+        self.checkScript(fn, ("h",))
+        self.checkScript(fn, ("y",))
+
+    def test_string_slicing(self):
+        # Each slice below is checked against Python on the same input.
+        def inner(s):
+            # type: (str) -> str
+            return s[1:3]
+
+        def negative_start(s):
+            # type: (str) -> str
+            return s[-1:3]
+
+        def reversed_bounds(s):
+            # type: (str) -> str
+            return s[3:1]
+
+        def end_past_length(s):
+            # type: (str) -> str
+            return s[3:100]
+
+        sample = ("abcdefghi",)
+        for sliced in (inner, negative_start, reversed_bounds, end_past_length):
+            self.checkScript(sliced, sample)
+
     def test_non_final_return(self):
 
         def simple(x):
index d000db1..3337bdb 100644 (file)
@@ -997,6 +997,30 @@ int64_t normalizeIndex(int64_t idx, int64_t list_size) {
   return idx;
 }
 
+// Stack op behind the str overload of aten::slice: pops step, end, start and
+// the string (in that order) and pushes the characters in [start, end).
+// Only step == 1 is accepted; always returns 0.
+int stringSlice(Stack& stack) {
+  const auto stride = pop(stack).toInt();
+  AT_CHECK(stride == 1, "Slicing a string only supports step=1");
+  auto last = pop(stack).toInt();
+  auto first = pop(stack).toInt();
+  auto text = pop(stack).toStringRef();
+  const int64_t length = text.size();
+
+  // Run both bounds through normalizeIndex; keep start >= 0, end <= length.
+  first = std::max(int64_t(0), normalizeIndex(first, length));
+  last = std::min(length, normalizeIndex(last, length));
+
+  if (first < last) {
+    push(stack, std::string(text.begin() + first, text.begin() + last));
+  } else {
+    // Nothing lies between the bounds, so the slice is the empty string.
+    push(stack, std::string(""));
+  }
+  return 0;
+}
+
 // Equivalent to list.at(idx)
 template <typename TList> // something like Shared<IntList>
 typename TList::element_type::ElemType& getItem(TList& list, int64_t idx) {
@@ -1743,7 +1767,29 @@ RegisterOperators reg2({
         "aten::ne(Tensor[] a, Tensor[] b) -> bool",
         listNe<Shared<TensorList>>),
     Operator("aten::ne(bool[] a, bool[] b) -> bool", listNe<Shared<BoolList>>),
-
+    Operator(
+        "aten::slice(str string, int start, int end=9223372036854775807, int step=1) -> str",
+        stringSlice),
+    Operator(
+        "prim::StringIndex(str string, int index) -> str",
+        [](Stack& stack) {
+          auto idx = pop(stack).toInt();
+          auto str = pop(stack).toStringRef();
+          // Negative indices count from the end; at() still bounds-checks.
+          push(stack, std::string(1, str.at(normalizeIndex(idx, str.size()))));
+          return 0;
+        }),
+    Operator(
+        "aten::ord(str string) -> int",
+        [](Stack& stack) {
+          auto string = pop(stack).toStringRef();
+          const auto n = string.size();
+          AT_CHECK(n == 1, "String for ord() must be 1 character, found ", n);
+          // char may be signed: convert through unsigned char so a byte
+          // >= 0x80 yields 128..255 rather than a negative code.
+          push(stack, int64_t(static_cast<unsigned char>(string.at(0))));
+          return 0;
+        }),
 #define CREATE_COPY_OP(other_type, c_type)                                 \
   Operator(                                                                \
       "aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \
index f87e066..997d2df 100644 (file)
@@ -403,6 +403,8 @@ struct Environment {
           {"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
           {"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
           {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
+          {"ord", std::make_shared<BuiltinFunction>(aten::ord, at::nullopt)},
+          {"rangelist", std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
           {"rangelist",
            std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
       };
@@ -2712,11 +2714,21 @@ struct to_ir {
     } else if (auto dict_type = gatherable->type()->cast<DictType>()) {
       auto* idx = emitExpr(subscript_exprs[0]);
       return emitDictIndex(loc, gatherable, idx);
+    } else if (auto string_type = gatherable->type()->cast<StringType>()) {
+      auto* idx = emitExpr(subscript_exprs[0]);
+      return emitBuiltinCall(
+          loc,
+          *graph,
+          prim::StringIndex,
+          c10::nullopt,
+          {gatherable, idx},
+          {},
+          true);
     } else {
       throw ErrorReport(loc)
-          << "Indexing only supported on lists, dictionaries, "
-             "tensors, and tuples, but got type '"
-          << gatherable->type()->str() << "'";
+          << "Indexing only supported on List, Dict, "
+             "Tensor, Tuple, and str but got type '"
+          << gatherable->type()->python_str() << "'";
     }
   }
 };
index 66cc678..30c273d 100644 (file)
@@ -153,6 +153,14 @@ Value* tryMatchArgument(
                           << " for argument '" << arg.name() << "' but found "
                           << value->type()->str() << "\n";
 
+    if (auto v = value->type()->cast<ListType>()) {
+      if (v->getElementType()->isSubtypeOf(TensorType::get())) {
+        ostream << "Empty lists default to List[Tensor]. Use torch.jit."
+                   "annotate(List[my_type], []) to create an empty list of"
+                   " another type\n";
+      }
+    }
+
     if (value->type() == NumberType::get() &&
         value->node()->kind() == aten::item) {
       ostream