From 78bf1a906586e137132325a0f39799fdecf219b5 Mon Sep 17 00:00:00 2001
From: Michael Suo
Date: Thu, 13 Dec 2018 22:10:56 -0800
Subject: [PATCH] Revert D13407930: [pytorch][PR] Support torch.tensor in script

Differential Revision: D13407930

Original commit changeset: d17f1195a221

fbshipit-source-id: f4458872c48ec4a2c9983b21ed90bcdc0ae665b7
---
 aten/src/ATen/core/ivalue.h             |   1 -
 test/test_jit.py                        |  88 ------------
 torch/csrc/jit/constants.cpp            |   3 +-
 torch/csrc/jit/pybind_utils.h           |   2 -
 torch/csrc/jit/register_prim_ops.cpp    |  90 +++---------
 torch/csrc/jit/register_special_ops.cpp | 239 +++-----------------------------
 torch/csrc/jit/script/compiler.cpp      |   5 +-
 7 files changed, 45 insertions(+), 383 deletions(-)

diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index ed46d11..8b99ecc 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -639,7 +639,6 @@ DEFINE_TO(int64_t, toInt)
 DEFINE_TO(bool, toBool)
 DEFINE_TO(c10::intrusive_ptr, toDoubleList)
 DEFINE_TO(c10::intrusive_ptr, toIntList)
-DEFINE_TO(c10::intrusive_ptr, toBoolList)
 DEFINE_TO(c10::intrusive_ptr, toTensorList)
 DEFINE_TO(c10::intrusive_ptr, toGenericList)
 DEFINE_TO(c10::intrusive_ptr, toString)
diff --git a/test/test_jit.py b/test/test_jit.py
index ae6ba8d..732229a 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -2782,17 +2782,6 @@ class TestScript(JitTestCase):
             return torch.ones(x), x
         self.checkScript(stuff3, ([3, 2],))
 
-    def test_bool_list_io(self):
-        @torch.jit.script
-        def stuff4(x):
-            # type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]]
-            return x, [True, False], [[True]]
-
-        li_1, li_2, li_3 = stuff4([True])
-        li_3 = li_3[0]
-        for li in [li_1, li_2, li_3]:
-            self.assertTrue(type(li[0]) == type(True))
-
     def test_nested_list(self):
         def foo(z):
             # type: (Tuple[int, List[List[int]]]) -> int
@@ -4442,83 +4431,6 @@ a")
     def test_tensor_number_math(self):
         self._test_tensor_number_math()
 
-    def test_torch_tensor_bad_input(self):
-        with self.assertRaisesRegex(RuntimeError, "Input list to torch.tensor must be of ints, floats, "
-                                    "or bools, got None"):
-            @torch.jit.script
-            def test():
-                return torch.tensor([None])
-
-        with self.assertRaisesRegex(RuntimeError, "Note: empty lists are constructed as Tensor"):
-            @torch.jit.script
-            def tmp():
-                return torch.tensor([])
-
-        @torch.jit.script
-        def foo():
-            return torch.tensor([[2, 2], [1]])
-        with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"):
-            foo()
-
-    def test_torch_tensor_empty_list(self):
-        def func():
-            return torch.tensor(torch.jit.annotate(List[int], []))
-        cu = torch.jit.script(func)
-        t1 = cu()
-        t2 = func()
-
-        # torchscript returns int tensor, python returns float tensor
-        self.assertNotEqual(t1.dtype, t2.dtype)
-
-        def func():
-            li = torch.jit.annotate(List[int], [])
-            return torch.tensor([li, li])
-
-        self.checkScript(func, ())
-
-        def func():
-            li = torch.jit.annotate(List[int], [])
-            return torch.tensor([[[li]]])
-
-        self.checkScript(func, ())
-
-    def test_torch_tensor(self):
-        template = dedent('''
-        def func():
-            li = {list_create}
-            return torch.tensor(li {options})
-        ''')
-
-        lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]",
-                 "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]
-
-        dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
-                  ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short",
-                  ", dtype=torch.int", ", dtype=torch.long"]
-
-        devices = ['', ", device='cpu'"]
-        if RUN_CUDA:
-            devices.append(", device='cuda'")
-
-        option_pairs = [dtype + device for dtype in dtypes for device in devices]
-        for li in lists:
-            for option in option_pairs:
-                # tensor from empty list is type float in python and annotated type in torchscript
-                if "annotate" in li and "dtype" not in option:
-                    continue
-                code = template.format(list_create=li, options=option)
-                scope = {}
-                exec(code, globals(), scope)
-                cu = torch.jit.CompilationUnit(code)
-                t1 = cu.func()
-                t2 = scope['func']()
-                if t1.dtype == torch.float16:  # equality NYI for half tensor
-                    self.assertTrue(str(t1) == str(t2))
-                else:
-                    self.assertEqual(t1, t2)
-                self.assertEqual(t1.dtype, t2.dtype)
-                self.assertEqual(t1.device, t2.device)
-
     @unittest.skipIf(not RUN_CUDA, "No CUDA")
     @skipIfRocm
     def test_tensor_number_math_cuda(self):
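The tests removed above are the end-to-end coverage for the scripted torch.tensor overloads. For reference, a minimal sketch of the usage they exercised, assuming a build that still has D13407930 applied (with this revert, scripting these functions is expected to fail because aten::tensor is no longer registered for script); the empty-list case mirrors the comment in the removed test_torch_tensor_empty_list:

import torch
from typing import List

@torch.jit.script
def from_literal():
    # dtype/device keyword arguments behaved like the eager overload
    return torch.tensor([2, 2], dtype=torch.float, device='cpu')

@torch.jit.script
def from_empty_list():
    # eager Python gives the default float dtype for an empty list, while
    # script used the annotated element type (int here), so the dtypes
    # differ unless an explicit dtype= is passed
    li = torch.jit.annotate(List[int], [])
    return torch.tensor(li)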
diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp
index 42ffc7f..4912d58 100644
--- a/torch/csrc/jit/constants.cpp
+++ b/torch/csrc/jit/constants.cpp
@@ -108,8 +108,7 @@ RegisterOperators reg({
         return 0;
       };
     } else if(type->isSubtypeOf(ListType::ofBools())) {
-      auto int_list = node->is(attr::value);
-      std::vector bs(int_list.begin(), int_list.end());
+      auto bs = node->is(attr::value);
       return [bs](Stack& stack) {
         push(stack, bs);
         return 0;
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index 40acf42..2b6c336 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -172,8 +172,6 @@ inline IValue toIValue(py::handle obj, const TypePtr& type, c10::optional
       repeated(*N, value);
       return repeated;
     }
-    case TypeKind::BoolType:
-      return py::cast>(obj);
     case TypeKind::TensorType:
     case TypeKind::DynamicType:
       return py::cast>(obj);
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 2fd4938..10e91a5 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -54,19 +54,6 @@ void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
   }
 }
 
-template // int64_t, bool, double
-Operation listConstruct(int64_t num_inputs) {
-  return [=](Stack& stack) {
-    auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
-    std::vector vals = fmap(inputs, [](const IValue& v) {
-      return v.to();
-    });
-    drop(stack, num_inputs);
-    push(stack, std::move(vals));
-    return 0;
-  };
-}
-
 RegisterOperators reg({
     Operator(
         prim::FusionGroup,
@@ -579,11 +566,25 @@ RegisterOperators reg({
           const auto num_inputs = node->inputs().size();
           ListTypePtr lt = node->output()->type()->expect();
           if(IntType::get() == lt->getElementType()) {
-            return listConstruct(num_inputs);
+            return [=](Stack& stack) {
+              auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
+              std::vector vals = fmap(inputs, [](const IValue& v) {
+                return v.toInt();
+              });
+              drop(stack, num_inputs);
+              push(stack, std::move(vals));
+              return 0;
+            };
           } else if(FloatType::get() == lt->getElementType()) {
-            return listConstruct(num_inputs);
-          } else if (lt->getElementType() == BoolType::get()) {
-            return listConstruct(num_inputs);
+            return [=](Stack& stack) {
+              auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
+              std::vector vals = fmap(inputs, [](const IValue& v) {
+                return v.toDouble();
+              });
+              drop(stack, num_inputs);
+              push(stack, std::move(vals));
+              return 0;
+            };
           } else if (lt->getElementType()->isSubtypeOf(DynamicType::get())) {
             return [=](Stack& stack) {
               const size_t stack_size = stack.size();
@@ -747,16 +748,6 @@ typename TList::element_type::ElemType& getItem(TList& list, int64_t idx) {
   return list->elements()[normalized_idx];
 }
 
-// cannot return a reference to an element in a bool vector
-bool getBoolItem(const std::vector& list, int64_t idx) {
-  const int64_t list_size = list.size();
-  const int64_t normalized_idx = normalizeIndex(idx, list_size);
-  if (normalized_idx < 0 || normalized_idx >= list_size) {
-    throw std::out_of_range("list index out of range");
-  }
-  return list[normalized_idx];
-}
-
 template
 Operation listAppend(const Node* node) {
   return [](Stack& stack) {
@@ -784,21 +775,6 @@ Operation listSelect(const Node* node) {
   };
 }
 
-// needs specialization because cannot return a pointer to a bool in an array
-template<>
-Operation listSelect>(const Node* node) {
-  return [=](Stack& stack) {
-    Shared list;
-    int64_t idx;
-    pop(stack, list, idx);
-
-    auto element = getBoolItem(list->elements(), idx);
-    push(stack, std::move(element));
-    return 0;
-  };
-}
-
-
 template
 Operation listLen(const Node* node) {
   return [=](Stack& stack) {
@@ -810,7 +786,6 @@ Operation listLen(const Node* node) {
   };
 }
 
-
 template
 Operation listEq(const Node* node) {
   return [=](Stack& stack) {
@@ -950,29 +925,6 @@ Operation listSetItem(const Node* node) {
   };
 }
 
-
-template<>
-Operation listSetItem, bool>(const Node* node) {
-  return [](Stack& stack) {
-    Shared list;
-    int64_t idx;
-    bool value;
-
-    pop(stack, list, idx, value);
-
-    int64_t list_size = list->elements().size();
-    auto normalized_idx = normalizeIndex(idx, list_size);
-    if (normalized_idx < 0 || normalized_idx >= list_size) {
-      throw std::out_of_range("list index out of range");
-    }
-    list->elements()[normalized_idx] = value;
-
-    push(stack, list);
-    return 0;
-  };
-}
-
-
 RegisterOperators reg2({
 
 #define DEFINE_STRING_OP(op_name, string_op, result) \
@@ -1021,8 +973,6 @@ Operator( \
     CREATE_IMMUTABLE_LIST_OPS("int", IntList)
     CREATE_IMMUTABLE_LIST_OPS("float", DoubleList)
     CREATE_IMMUTABLE_LIST_OPS("t", GenericList)
-    CREATE_IMMUTABLE_LIST_OPS("t", BoolList)
-
 
 #define CREATE_LIST_OPS(decl_type, c_type) \
   Operator("aten::len(" decl_type "[] a) -> int", listLen>), \
@@ -1035,7 +985,6 @@ Operator( \
     CREATE_LIST_OPS("int", IntList)
     CREATE_LIST_OPS("float", DoubleList)
     CREATE_LIST_OPS("Tensor", TensorList)
-    CREATE_LIST_OPS("bool", BoolList)
     CREATE_LIST_OPS("t", GenericList)
 #undef CREATE_LIST_OPS
 
@@ -1043,11 +992,10 @@ Operator( \
     Operator("aten::eq(int[] a, int[] b) -> bool", listEq>),
     Operator("aten::eq(float[] a, float[] b) -> bool", listEq>),
     Operator("aten::eq(Tensor[] a, Tensor[] b) -> bool", listEq>),
-    Operator("aten::eq(bool[] a, bool[] b) -> bool", listEq>),
     Operator("aten::ne(int[] a, int[] b) -> bool", listNe>),
     Operator("aten::ne(float[] a, float[] b) -> bool", listNe>),
     Operator("aten::ne(Tensor[] a, Tensor[] b) -> bool", listNe>),
-    Operator("aten::ne(bool[] a, bool[] b) -> bool", listNe>),
+
 
 #define CREATE_COPY_OP(other_type, c_type) \
   Operator( \
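The register_prim_ops.cpp hunks above drop the specialized bool-list operators (construction, select, len, eq/ne, setitem). A sketch of the List[bool] round trip that the removed test_bool_list_io covered, again assuming a build where those specializations are still present:

import torch
from typing import List, Tuple

@torch.jit.script
def bool_lists(x):
    # type: (List[bool]) -> Tuple[List[bool], List[bool]]
    return x, [True, False]

outs = bool_lists([True])
# the removed test asserted that elements come back as genuine Python bools
assert all(type(v) is bool for out in outs for v in out)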
diff --git a/torch/csrc/jit/register_special_ops.cpp b/torch/csrc/jit/register_special_ops.cpp
index 2de49f8..9ae08a5 100644
--- a/torch/csrc/jit/register_special_ops.cpp
+++ b/torch/csrc/jit/register_special_ops.cpp
@@ -2,10 +2,7 @@
 #include
 #include
 #include
-#include
-#include
-#include
-#include
+#include
 #include
 #include
@@ -14,126 +11,6 @@
 namespace torch {
 namespace jit {
 namespace {
-
-
-void checkListInputType(const c10::TypePtr& elem_type, const Node* node) {
-  if (!elem_type->isSubtypeOf(NumberType::get()) && elem_type != BoolType::get()) {
-    auto error = script::ErrorReport(node->getSourceLocation());
-    error << "Input list to torch.tensor must be of ints, floats, or bools, " <<
-      "got " << elem_type->str();
-    // special case empty list torch.tensor([])
-    if (elem_type->isSubtypeOf(DynamicType::get())) {
-      auto input = node->inputs().at(0);
-      if (input->node()->kind() == prim::ListConstruct && input->node()->inputs().size() == 0) {
-        error << "\n(Note: empty lists are constructed as Tensor[]; \n"
-              << "if you want an empty list of a different type, \n"
-              << "use `torch.jit.annotate(List[T], [])`, \n"
-              << "where `T` is the type of elements in the list)";
-      }
-    }
-    throw error;
-  }
-}
-
-at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
-  if (type == FloatType::get()) {
-    return at::ScalarType::Double;
-  } else if (type == IntType::get()) {
-    return at::ScalarType::Long;
-  } else if (type == BoolType::get()) {
-    return at::ScalarType::Byte;
-  }
-  AT_ASSERTM(0, "Add new condition, expected Float, Int, or Bool but got",
-             type->str());
-}
-
-
-int64_t list_size(IValue list) {
-  if (list.isGenericList()) {
-    return list.toGenericListRef().size();
-  } else if (list.isIntList()) {
-    return list.toIntListRef().size();
-  } else if (list.isDoubleList()){
-    return list.toDoubleListRef().size();
-  } else if (list.isBoolList()) {
-    return list.toBoolListRef().size();
-  }
-  AT_ASSERTM(0, "Unexpected list type", list);
-}
-
-std::vector compute_sizes(IValue seq) {
-  std::vector sizes;
-  // because bool, int, and float lists are specialized, inner array will
-  // will not be generic list
-  while (seq.isGenericList()) {
-    auto seq_list = seq.toGenericListRef();
-    auto length = seq_list.size();
-    AT_ASSERT(length != 0);
-    sizes.push_back(length);
-    seq = seq_list[0];
-  }
-  sizes.push_back(list_size(seq));
-  return sizes;
-}
-
-void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
-  if (seq_size != n) {
-    AT_ERROR("Expected sequence of length ", n, " at dim ", dim, " (got ", seq_size, ")");
-  }
-}
-
-template
-void storeLastDimension(char* data, const std::vector& sizes, const c10::ArrayRef& strides, int64_t dim,
-                        int elementSize, std::vector obj) {
-  auto n = sizes[dim];
-  auto seq_size = obj.size();
-  checkSequenceSize(n, dim, seq_size);
-  for (int64_t i = 0; i < n; i++) {
-    *(DTYPE*)data = obj[i];
-    data += strides[dim] * elementSize;
-  }
-}
-
-// bool vector needs to be cast to uint8_t
-template<>
-void storeLastDimension(char* data, const std::vector& sizes, const c10::ArrayRef& strides, int64_t dim,
-                        int elementSize, std::vector obj) {
-  auto n = sizes[dim];
-  auto seq_size = obj.size();
-  checkSequenceSize(n, dim, seq_size);
-  for (int64_t i = 0; i < n; i++) {
-    *(uint8_t*)data = static_cast(obj[i]);
-    data += strides[dim] * elementSize;
-  }
-}
-
-// refernce python implementation recursive_store in tensor_new.cpp
-
-void recursiveStore(char* data, const std::vector& sizes, const c10::ArrayRef& strides, int64_t dim,
-                    int elementSize, IValue obj) {
-
-  auto ndim = sizes.size();
-  auto n = sizes[dim];
-  auto seq_size = list_size(obj);
-  checkSequenceSize(n, dim, seq_size);
-  if (dim + 1 < static_cast(ndim)) {
-    auto items = obj.toGenericListRef();
-    for (int64_t i = 0; i < n; i++) {
-      recursiveStore(data, sizes, strides, dim + 1, elementSize, items[i]);
-      data += strides[dim] * elementSize;
-    }
-  } else {
-    JIT_ASSERT(obj.isIntList() || obj.isDoubleList() || obj.isBoolList());
-    if (obj.isIntList()) {
-      storeLastDimension(data, sizes, strides, dim, elementSize, obj.toIntListRef());
-    } else if (obj.isDoubleList()){
-      storeLastDimension(data, sizes, strides, dim, elementSize, obj.toDoubleListRef());
-    } else {
-      storeLastDimension(data, sizes, strides, dim, elementSize, obj.toBoolListRef());
-    }
-  }
-}
-
 RegisterOperators reg({
     Operator(
         "aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]",
@@ -177,35 +54,6 @@ RegisterOperators reg({
         return 0;
       }),
     Operator(
-        "aten::_infer_size(int[] a, int[] b) -> int[]",
-        [](const Node* node) {
-          return [](Stack& stack) {
-            auto a = pop(stack).toIntList()->elements();
-            auto b = pop(stack).toIntList()->elements();
-            push(stack, at::infer_size(a, b));
-            return 0;
-          };
-        }),
-    Operator(
-        "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
-        [](const Node* node) {
-          return [](Stack& stack) {
-            at::Tensor weight;
-            at::Tensor input;
-            double max_norm;
-            double norm_type;
-            pop(stack, weight, input, max_norm, norm_type);
-
-            // TODO: remove when script supports setting grad mode
-            torch::NoGradGuard no_grad;
-
-            at::Tensor result = at::embedding_renorm_(weight, input, max_norm, norm_type);
-            push(stack, result);
-
-            return 0;
-          };
-        }),
-    Operator(
         "aten::format(str self, ...) -> str",
         [](const Node* node) {
           size_t num_inputs = node->inputs().size();
@@ -238,75 +86,32 @@ RegisterOperators reg({
             return 0;
           };
         }),
-
-#define DEFINE_TORCH_TENSOR_OP(operator_type, c_type, tensor_creation_op) \
-Operator( \
-    "aten::tensor(" #operator_type " t, *, ScalarType? dtype=None, Device? device=None"\
-    ") -> Tensor", \
-    [](const Node* node) { \
-      auto initial_scalar_type = scalarTypeFromJitType(node->inputs().at(0)->type()); \
-      return [initial_scalar_type](Stack& stack) { \
-        c_type scalar_val; \
-        IValue dtype; \
-        IValue device; \
-        pop(stack, scalar_val, dtype, device); \
-        auto tensor = autograd::make_variable(tensor_creation_op); \
-        at::ScalarType scalar_type = dtype.isNone() ? \
-          tensor.scalar_type() : dtype.toScalarType(); \
-        c10::Device dev = device.isNone() ? tensor.device() : device.toDevice(); \
-        if (scalar_type != initial_scalar_type || dev != tensor.device()) { \
-          tensor = tensor.to(dev, scalar_type); \
-        } \
-        push(stack, tensor); \
-        return 0; \
-      }; \
-    }),
-
-DEFINE_TORCH_TENSOR_OP(float, double, at::scalar_to_tensor(scalar_val))
-DEFINE_TORCH_TENSOR_OP(int, int64_t, at::scalar_to_tensor(scalar_val))
-DEFINE_TORCH_TENSOR_OP(bool, bool, at::empty({}, at::CPU(at::kByte).options()).fill_(scalar_val))
-
-
-    // reference python implementation: internal_new_from_data in tensor_new.cpp
     Operator(
-        "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
+        "aten::_infer_size(int[] a, int[] b) -> int[]",
+        [](const Node* node) {
+          return [](Stack& stack) {
+            auto a = pop(stack).toIntList()->elements();
+            auto b = pop(stack).toIntList()->elements();
+            push(stack, at::infer_size(a, b));
+            return 0;
+          };
+        }),
+    Operator(
+        "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
         [](const Node* node) {
-          auto input = node->inputs().at(0);
-          auto elem_type = input->type();
-          while (auto list_type = elem_type->cast()) {
-            elem_type = list_type->getElementType();
-          }
-          checkListInputType(elem_type, node);
-          at::ScalarType initial_scalar_type = scalarTypeFromJitType(elem_type);
-          return [initial_scalar_type, elem_type](Stack& stack) {
-            IValue data;
-            IValue dtype;
-            IValue device;
-            pop(stack, data, dtype, device);
-            auto sizes = compute_sizes(data);
-            auto tensor = autograd::make_variable(
-                at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type)));
-
-            recursiveStore((char*)tensor.data_ptr(), sizes, tensor.strides(), 0,
-                           tensor.type().elementSizeInBytes(), data);
-
-            at::ScalarType scalar_type = dtype.isNone() ? tensor.scalar_type() : dtype.toScalarType();
-            c10::Device dev = device.isNone() ? tensor.device() : device.toDevice();
-            if (scalar_type != initial_scalar_type || dev != tensor.device()) {
-              tensor = tensor.to(dev, scalar_type);
-            }
+          return [](Stack& stack) {
+            at::Tensor weight;
+            at::Tensor input;
+            double max_norm;
+            double norm_type;
+            pop(stack, weight, input, max_norm, norm_type);
 
-            auto default_type = at::typeMetaToScalarType(at::get_default_dtype());
+            // TODO: remove when script supports setting grad mode
+            torch::NoGradGuard no_grad;
 
-            if (dtype.isNone() && tensor.scalar_type() != default_type &&
-                tensor.numel() == 0) {
-              AT_WARN("Creating a tensor from an empty ", elem_type->str(),
-                      "list will create a tensor of default floating point type (currently ", default_type,
-                      ") in python but a tensor of type ", elem_type->str(), " in torchscript.\n",
-                      "Pass in a dtype argument to ensure consistent behavior");
-            }
+            at::Tensor result = at::embedding_renorm_(weight, input, max_norm, norm_type);
+            push(stack, result);
 
-            push(stack, tensor);
             return 0;
           };
         }),
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 549d2e0..5dd09fa 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -589,6 +589,7 @@ c10::optional tryMatchSchema(
       failure_messages << "\nfor operator " << schema << ":\n";
       return failure_messages;
     };
+
   TypeEnv type_env;
   std::vector positional_inputs;
   std::vector used_kwarg(kwargs.size(), false);
@@ -603,7 +604,7 @@ c10::optional tryMatchSchema(
       self = c10::nullopt;
     } else if (!arg.kwarg_only() && used_args < args.size()) {
       // allow zeros(IntList sizes) to work with zeros(1, 2) or zeros(1)
-      if (allow_conversions && arg.type()->kind() == TypeKind::ListType && // the formal must be a list
+      if (arg.type()->kind() == TypeKind::ListType && // the formal must be a list
          !arg.N() && // it must not be a broadcasting list like int[3], otherwise
                      // a single int is a valid input
          (schema_i + 1 == schema.arguments().size() ||
@@ -2275,7 +2276,7 @@ private:
       elem_type = values.at(0)->type();
     }
     for (auto v : values) {
-      if (*v->type() != *elem_type) {
+      if (v->type() != elem_type) {
        throw ErrorReport(tree)
            << "Lists must contain only a single type, expected: "
            << *elem_type << " but found " << *v->type() << " instead";
-- 
2.7.4
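With this revert applied, torch.tensor is no longer resolvable inside @torch.jit.script functions. A hedged sketch of the alternatives that remain (assumptions: factory ops such as torch.ones stay scriptable, as the unchanged tests above still use them, and tensors can always be built eagerly and passed into script):

import torch
from typing import List

@torch.jit.script
def make_ones(sizes):
    # type: (List[int]) -> Tensor
    return torch.ones(sizes)

@torch.jit.script
def consume(t):
    # type: (Tensor) -> Tensor
    return t * 2

# build the tensor eagerly, then hand it to the scripted function instead of
# calling torch.tensor inside script
print(consume(torch.tensor([2, 2])))
print(make_ones([3, 2]))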