_(prim, ListUnpack) \
_(prim, DictConstruct) \
_(prim, DictIndex) \
+ _(prim, StringIndex) \
_(prim, NumToTensor) \
_(prim, ImplicitTensorToNum) \
_(prim, Bool) \
_(aten, len) \
_(aten, list) \
_(aten, wait) \
+ _(aten, ord) \
_(prim, unchecked_unwrap_optional)\
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
return torch.jit._unwrap_optional(None)
def test_indexing_error(self):
- with self.assertRaisesRegex(RuntimeError, "only supported on lists, dictionaries, tensors, and tuples"):
+ with self.assertRaisesRegex(RuntimeError, "only supported on List, Dict, Tensor, Tuple, and str"):
@torch.jit.script
def test_wrong_type():
a = 8
with self.capture_stdout() as captured:
print(fn(x, scale, shift))
+ def test_string_index(self):
+ def fn(x):
+ # type: (str) -> str
+ return x[2]
+
+ self.checkScript(fn, ("abcde",))
+
+ def test_ord(self):
+ def fn(x):
+ # type: (str) -> int
+ return ord(x)
+
+ self.checkScript(fn, ("h"))
+ self.checkScript(fn, ("y"))
+
+ def test_string_slicing(self):
+ def fn1(x):
+ # type: (str) -> str
+ return x[1:3]
+
+ def fn2(x):
+ # type: (str) -> str
+ return x[-1:3]
+
+ def fn3(x):
+ # type: (str) -> str
+ return x[3:1]
+
+ def fn4(x):
+ # type: (str) -> str
+ return x[3:100]
+
+ self.checkScript(fn1, ("abcdefghi",))
+ self.checkScript(fn2, ("abcdefghi",))
+ self.checkScript(fn3, ("abcdefghi",))
+ self.checkScript(fn4, ("abcdefghi",))
+
def test_non_final_return(self):
def simple(x):
return idx;
}
+int stringSlice(Stack& stack) {
+ auto step = pop(stack).toInt();
+ AT_CHECK(step == 1, "Slicing a string only supports step=1");
+
+ auto end = pop(stack).toInt();
+ auto start = pop(stack).toInt();
+ auto string = pop(stack).toStringRef();
+ const int64_t size = string.size();
+
+ // Clamp start and end to the bounds of the list
+ start = std::max(int64_t(0), normalizeIndex(start, size));
+ end = std::min(size, normalizeIndex(end, size));
+
+ if (end <= start) {
+ // Slice is empty
+ push(stack, std::string(""));
+ return 0;
+ }
+
+ std::string result(string.begin() + start, string.begin() + end);
+ push(stack, result);
+ return 0;
+}
+
// Equivalent to list.at(idx)
template <typename TList> // something like Shared<IntList>
typename TList::element_type::ElemType& getItem(TList& list, int64_t idx) {
"aten::ne(Tensor[] a, Tensor[] b) -> bool",
listNe<Shared<TensorList>>),
Operator("aten::ne(bool[] a, bool[] b) -> bool", listNe<Shared<BoolList>>),
-
+ Operator(
+ "aten::slice(str string, int start, int end=9223372036854775807, int step=1) -> str",
+ stringSlice),
+ Operator(
+ "prim::StringIndex(str string, int index) -> str",
+ [](Stack& stack) {
+ auto index = pop(stack).toInt();
+ auto string = pop(stack).toStringRef();
+ char c = string.at(index);
+ push(stack, std::string(&c, 1));
+ return 0;
+ }),
+ Operator(
+ "aten::ord(str string) -> int",
+ [](Stack& stack) {
+ auto string = pop(stack).toStringRef();
+ AT_CHECK(
+ string.size() == 1,
+ "String for ord() must be 1 character, found",
+ string.size());
+ push(stack, int64_t(string.at(0)));
+ return 0;
+ }),
#define CREATE_COPY_OP(other_type, c_type) \
Operator( \
"aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \
{"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
{"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
{"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
+ {"ord", std::make_shared<BuiltinFunction>(aten::ord, at::nullopt)},
+ {"rangelist", std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
{"rangelist",
std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
};
} else if (auto dict_type = gatherable->type()->cast<DictType>()) {
auto* idx = emitExpr(subscript_exprs[0]);
return emitDictIndex(loc, gatherable, idx);
+ } else if (auto string_type = gatherable->type()->cast<StringType>()) {
+ auto* idx = emitExpr(subscript_exprs[0]);
+ return emitBuiltinCall(
+ loc,
+ *graph,
+ prim::StringIndex,
+ c10::nullopt,
+ {gatherable, idx},
+ {},
+ true);
} else {
throw ErrorReport(loc)
- << "Indexing only supported on lists, dictionaries, "
- "tensors, and tuples, but got type '"
- << gatherable->type()->str() << "'";
+ << "Indexing only supported on List, Dict, "
+ "Tensor, Tuple, and str but got type '"
+ << gatherable->type()->python_str() << "'";
}
}
};
<< " for argument '" << arg.name() << "' but found "
<< value->type()->str() << "\n";
+ if (auto v = value->type()->cast<ListType>()) {
+ if (v->getElementType()->isSubtypeOf(TensorType::get())) {
+ ostream << "Empty lists default to List[Tensor]. Use torch.jit."
+ "annotate(List[my_type], []) to create an empty list of"
+ " another type\n";
+ }
+ }
+
if (value->type() == NumberType::get() &&
value->node()->kind() == aten::item) {
ostream