Support torch.tensor in script (#14913)
authorElias Ellison <eellison@fb.com>
Fri, 14 Dec 2018 01:36:21 +0000 (17:36 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 14 Dec 2018 01:38:38 +0000 (17:38 -0800)
Summary:
Adding support for torch.tensor in script.

The input list is typed as t[], because it can be arbitrarily nested. I added a check a compile time check  that the inner type of the list is a bool, float, or int.

Also adds specialization for Boolean Lists, which already existed at the ivalue level but had not been added to the compiler yet
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14913

Differential Revision: D13407930

Pulled By: eellison

fbshipit-source-id: d17f1195a22149d5b0d08d76c89a7fab8444f7c5

aten/src/ATen/core/ivalue.h
test/test_jit.py
torch/csrc/jit/constants.cpp
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/register_special_ops.cpp
torch/csrc/jit/script/compiler.cpp

index 8b99ecc..ed46d11 100644 (file)
@@ -639,6 +639,7 @@ DEFINE_TO(int64_t, toInt)
 DEFINE_TO(bool, toBool)
 DEFINE_TO(c10::intrusive_ptr<ivalue::DoubleList>, toDoubleList)
 DEFINE_TO(c10::intrusive_ptr<ivalue::IntList>, toIntList)
+DEFINE_TO(c10::intrusive_ptr<ivalue::BoolList>, toBoolList)
 DEFINE_TO(c10::intrusive_ptr<ivalue::TensorList>, toTensorList)
 DEFINE_TO(c10::intrusive_ptr<ivalue::GenericList>, toGenericList)
 DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString)
index 80801c3..5101a13 100644 (file)
@@ -2782,6 +2782,17 @@ class TestScript(JitTestCase):
             return torch.ones(x), x
         self.checkScript(stuff3, ([3, 2],))
 
+    def test_bool_list_io(self):
+        @torch.jit.script
+        def stuff4(x):
+            # type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]]
+            return x, [True, False], [[True]]
+
+        li_1, li_2, li_3 = stuff4([True])
+        li_3 = li_3[0]
+        for li in [li_1, li_2, li_3]:
+            self.assertTrue(type(li[0]) == type(True))
+
     def test_nested_list(self):
         def foo(z):
             # type: (Tuple[int, List[List[int]]]) -> int
@@ -4431,6 +4442,83 @@ a")
     def test_tensor_number_math(self):
         self._test_tensor_number_math()
 
+    def test_torch_tensor_bad_input(self):
+        with self.assertRaisesRegex(RuntimeError, "Input list to torch.tensor must be of ints, floats, "
+                                    "or bools, got None"):
+            @torch.jit.script
+            def test():
+                return torch.tensor([None])
+
+        with self.assertRaisesRegex(RuntimeError, "Note: empty lists are constructed as Tensor"):
+            @torch.jit.script
+            def tmp():
+                return torch.tensor([])
+
+        @torch.jit.script
+        def foo():
+            return torch.tensor([[2, 2], [1]])
+        with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"):
+            foo()
+
+    def test_torch_tensor_empty_list(self):
+        def func():
+            return torch.tensor(torch.jit.annotate(List[int], []))
+        cu = torch.jit.script(func)
+        t1 = cu()
+        t2 = func()
+
+        # torchscript returns int tensor, python returns float tensor
+        self.assertNotEqual(t1.dtype, t2.dtype)
+
+        def func():
+            li = torch.jit.annotate(List[int], [])
+            return torch.tensor([li, li])
+
+        self.checkScript(func, ())
+
+        def func():
+            li = torch.jit.annotate(List[int], [])
+            return torch.tensor([[[li]]])
+
+        self.checkScript(func, ())
+
+    def test_torch_tensor(self):
+        template = dedent('''
+        def func():
+            li = {list_create}
+            return torch.tensor(li {options})
+        ''')
+
+        lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]",
+                 "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]
+
+        dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
+                  ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short",
+                  ", dtype=torch.int", ", dtype=torch.long"]
+
+        devices = ['', ", device='cpu'"]
+        if RUN_CUDA:
+            devices.append(", device='cuda'")
+
+        option_pairs = [dtype + device for dtype in dtypes for device in devices]
+        for li in lists:
+            for option in option_pairs:
+                # tensor from empty list is type float in python and annotated type in torchscript
+                if "annotate" in li and "dtype" not in option:
+                    continue
+                code = template.format(list_create=li, options=option)
+                scope = {}
+                exec(code, globals(), scope)
+                cu = torch.jit.CompilationUnit(code)
+                t1 = cu.func()
+                t2 = scope['func']()
+                if t1.dtype == torch.float16:  # equality NYI for half tensor
+                    self.assertTrue(str(t1) == str(t2))
+                else:
+                    self.assertEqual(t1, t2)
+                self.assertEqual(t1.dtype, t2.dtype)
+                self.assertEqual(t1.device, t2.device)
+
     @unittest.skipIf(not RUN_CUDA, "No CUDA")
     @skipIfRocm
     def test_tensor_number_math_cuda(self):
index 4912d58..42ffc7f 100644 (file)
@@ -108,7 +108,8 @@ RegisterOperators reg({
             return 0;
           };
         } else if(type->isSubtypeOf(ListType::ofBools())) {
-          auto bs = node->is(attr::value);
+          auto int_list = node->is(attr::value);
+          std::vector<bool> bs(int_list.begin(), int_list.end());
           return [bs](Stack& stack) {
             push(stack, bs);
             return 0;
index 2b6c336..40acf42 100644 (file)
@@ -172,6 +172,8 @@ inline IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_
               std::vector<double> repeated(*N, value);
               return repeated;
             }
+          case TypeKind::BoolType:
+            return py::cast<std::vector<bool>>(obj);
           case TypeKind::TensorType:
           case TypeKind::DynamicType:
             return py::cast<std::vector<at::Tensor>>(obj);
index 10e91a5..2fd4938 100644 (file)
@@ -54,6 +54,19 @@ void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
   }
 }
 
+template <typename dtype> // int64_t, bool, double
+Operation listConstruct(int64_t num_inputs) {
+  return [=](Stack& stack) {
+    auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
+    std::vector<dtype> vals = fmap(inputs, [](const IValue& v) {
+      return v.to<dtype>();
+    });
+    drop(stack, num_inputs);
+    push(stack, std::move(vals));
+    return 0;
+  };
+}
+
 RegisterOperators reg({
     Operator(
         prim::FusionGroup,
@@ -566,25 +579,11 @@ RegisterOperators reg({
           const auto num_inputs = node->inputs().size();
           ListTypePtr lt = node->output()->type()->expect<ListType>();
           if(IntType::get() == lt->getElementType()) {
-            return [=](Stack& stack) {
-              auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
-              std::vector<int64_t> vals = fmap(inputs, [](const IValue& v) {
-                return v.toInt();
-              });
-              drop(stack, num_inputs);
-              push(stack, std::move(vals));
-              return 0;
-            };
+            return listConstruct<int64_t>(num_inputs);
           } else if(FloatType::get() == lt->getElementType()) {
-            return [=](Stack& stack) {
-              auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
-              std::vector<double> vals = fmap(inputs, [](const IValue& v) {
-                return v.toDouble();
-              });
-              drop(stack, num_inputs);
-              push(stack, std::move(vals));
-              return 0;
-            };
+            return listConstruct<double>(num_inputs);
+          } else if (lt->getElementType() == BoolType::get()) {
+            return listConstruct<bool>(num_inputs);
           } else if (lt->getElementType()->isSubtypeOf(DynamicType::get())) {
             return [=](Stack& stack) {
               const size_t stack_size = stack.size();
@@ -748,6 +747,16 @@ typename TList::element_type::ElemType& getItem(TList& list, int64_t idx) {
   return list->elements()[normalized_idx];
 }
 
+// cannot return a reference to an element in a bool vector
+bool getBoolItem(const std::vector<bool>& list, int64_t idx) {
+  const int64_t list_size = list.size();
+  const int64_t normalized_idx = normalizeIndex(idx, list_size);
+  if (normalized_idx < 0 || normalized_idx >= list_size) {
+    throw std::out_of_range("list index out of range");
+  }
+  return list[normalized_idx];
+}
+
 template <typename TList, typename TElement>
 Operation listAppend(const Node* node) {
   return [](Stack& stack) {
@@ -775,6 +784,21 @@ Operation listSelect(const Node* node) {
   };
 }
 
+// needs specialization because cannot return a pointer to a bool in an array
+template<>
+Operation listSelect<Shared<BoolList>>(const Node* node) {
+  return [=](Stack& stack) {
+    Shared<BoolList> list;
+    int64_t idx;
+    pop(stack, list, idx);
+
+    auto element = getBoolItem(list->elements(), idx);
+    push(stack, std::move(element));
+    return 0;
+  };
+}
+
+
 template <typename T>
 Operation listLen(const Node* node) {
   return [=](Stack& stack) {
@@ -786,6 +810,7 @@ Operation listLen(const Node* node) {
   };
 }
 
+
 template <typename T>
 Operation listEq(const Node* node) {
   return [=](Stack& stack) {
@@ -925,6 +950,29 @@ Operation listSetItem(const Node* node) {
   };
 }
 
+
+template<>
+Operation listSetItem<Shared<BoolList>, bool>(const Node* node) {
+  return [](Stack& stack) {
+    Shared<BoolList> list;
+    int64_t idx;
+    bool value;
+
+    pop(stack, list, idx, value);
+
+    int64_t list_size = list->elements().size();
+    auto normalized_idx = normalizeIndex(idx, list_size);
+    if (normalized_idx < 0 || normalized_idx >= list_size) {
+      throw std::out_of_range("list index out of range");
+    }
+    list->elements()[normalized_idx] = value;
+
+    push(stack, list);
+    return 0;
+  };
+}
+
+
 RegisterOperators reg2({
 
 #define DEFINE_STRING_OP(op_name, string_op, result)                           \
@@ -973,6 +1021,8 @@ Operator(                                                                      \
     CREATE_IMMUTABLE_LIST_OPS("int", IntList)
     CREATE_IMMUTABLE_LIST_OPS("float", DoubleList)
     CREATE_IMMUTABLE_LIST_OPS("t", GenericList)
+    CREATE_IMMUTABLE_LIST_OPS("t", BoolList)
+
 
 #define CREATE_LIST_OPS(decl_type, c_type) \
     Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
@@ -985,6 +1035,7 @@ Operator(                                                                      \
     CREATE_LIST_OPS("int", IntList)
     CREATE_LIST_OPS("float", DoubleList)
     CREATE_LIST_OPS("Tensor", TensorList)
+    CREATE_LIST_OPS("bool", BoolList)
     CREATE_LIST_OPS("t", GenericList)
 #undef CREATE_LIST_OPS
 
@@ -992,10 +1043,11 @@ Operator(                                                                      \
     Operator("aten::eq(int[] a, int[] b) -> bool", listEq<Shared<IntList>>),
     Operator("aten::eq(float[] a, float[] b) -> bool", listEq<Shared<DoubleList>>),
     Operator("aten::eq(Tensor[] a, Tensor[] b) -> bool", listEq<Shared<TensorList>>),
+    Operator("aten::eq(bool[] a, bool[] b) -> bool", listEq<Shared<BoolList>>),
     Operator("aten::ne(int[] a, int[] b) -> bool", listNe<Shared<IntList>>),
     Operator("aten::ne(float[] a, float[] b) -> bool", listNe<Shared<DoubleList>>),
     Operator("aten::ne(Tensor[] a, Tensor[] b) -> bool", listNe<Shared<TensorList>>),
-
+    Operator("aten::ne(bool[] a, bool[] b) -> bool", listNe<Shared<BoolList>>),
 
 #define CREATE_COPY_OP(other_type, c_type)                              \
   Operator(                                                             \
index 9ae08a5..2de49f8 100644 (file)
@@ -2,7 +2,10 @@
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/api/include/torch/utils.h>
-#include <ATen/ExpandUtils.h>
+#include <aten/src/ATen/ExpandUtils.h>
+#include <c10/core/ScalarType.h>
+#include <aten/src/ATen/InitialTensorOptions.h>
+#include <torch/csrc/jit/script/error_report.h>
 
 #include <sstream>
 #include <regex>
@@ -11,6 +14,126 @@ namespace torch {
 namespace jit {
 
 namespace {
+
+
+void checkListInputType(const c10::TypePtr& elem_type, const Node* node) {
+  if (!elem_type->isSubtypeOf(NumberType::get()) && elem_type != BoolType::get()) {
+    auto error = script::ErrorReport(node->getSourceLocation());
+    error << "Input list to torch.tensor must be of ints, floats, or bools, " <<
+      "got " << elem_type->str();
+    // special case empty list torch.tensor([])
+    if (elem_type->isSubtypeOf(DynamicType::get())) {
+      auto input = node->inputs().at(0);
+      if (input->node()->kind() == prim::ListConstruct && input->node()->inputs().size() == 0) {
+        error << "\n(Note: empty lists are constructed as Tensor[]; \n"
+               << "if you want an empty list of a different type, \n"
+               << "use `torch.jit.annotate(List[T], [])`, \n"
+               << "where `T` is the type of elements in the list)";
+      }
+    }
+    throw error;
+  }
+}
+
+at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
+  if (type == FloatType::get()) {
+    return at::ScalarType::Double;
+  } else if (type == IntType::get()) {
+    return at::ScalarType::Long;
+  } else if (type == BoolType::get()) {
+    return at::ScalarType::Byte;
+  }
+  AT_ASSERTM(0, "Add new condition, expected Float, Int, or Bool but got",
+      type->str());
+}
+
+
+int64_t list_size(IValue list) {
+  if (list.isGenericList()) {
+    return list.toGenericListRef().size();
+  } else if (list.isIntList()) {
+    return list.toIntListRef().size();
+  } else if (list.isDoubleList()){
+    return list.toDoubleListRef().size();
+  } else if (list.isBoolList()) {
+    return list.toBoolListRef().size();
+  }
+  AT_ASSERTM(0, "Unexpected list type", list);
+}
+
+std::vector<int64_t> compute_sizes(IValue seq) {
+  std::vector<int64_t> sizes;
+  // because bool, int, and float lists are specialized, inner array will
+  // will not be generic list
+  while (seq.isGenericList()) {
+    auto seq_list = seq.toGenericListRef();
+    auto length = seq_list.size();
+    AT_ASSERT(length != 0);
+    sizes.push_back(length);
+    seq = seq_list[0];
+  }
+  sizes.push_back(list_size(seq));
+  return sizes;
+}
+
+void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
+  if (seq_size != n) {
+    AT_ERROR("Expected sequence of length ", n, " at dim ", dim, " (got ", seq_size, ")");
+  }
+}
+
+template <typename DTYPE>
+void storeLastDimension(char* data, const std::vector<int64_t>& sizes, const c10::ArrayRef<int64_t>& strides, int64_t dim,
+    int elementSize, std::vector<DTYPE> obj) {
+  auto n = sizes[dim];
+  auto seq_size = obj.size();
+  checkSequenceSize(n, dim, seq_size);
+  for (int64_t i = 0; i < n; i++) {
+    *(DTYPE*)data = obj[i];
+    data += strides[dim] * elementSize;
+  }
+}
+
+// bool vector needs to be cast to uint8_t
+template<>
+void storeLastDimension<bool>(char* data, const std::vector<int64_t>& sizes, const c10::ArrayRef<int64_t>& strides, int64_t dim,
+    int elementSize, std::vector<bool> obj) {
+  auto n = sizes[dim];
+  auto seq_size = obj.size();
+  checkSequenceSize(n, dim, seq_size);
+  for (int64_t i = 0; i < n; i++) {
+    *(uint8_t*)data = static_cast<uint8_t>(obj[i]);
+    data += strides[dim] * elementSize;
+  }
+}
+
+// refernce python implementation recursive_store in tensor_new.cpp
+
+void recursiveStore(char* data, const std::vector<int64_t>& sizes, const c10::ArrayRef<int64_t>& strides, int64_t dim,
+   int elementSize, IValue obj) {
+
+  auto ndim = sizes.size();
+  auto n = sizes[dim];
+  auto seq_size = list_size(obj);
+  checkSequenceSize(n, dim, seq_size);
+  if (dim + 1 < static_cast<long>(ndim)) {
+    auto items = obj.toGenericListRef();
+    for (int64_t i = 0; i < n; i++) {
+      recursiveStore(data, sizes, strides, dim + 1, elementSize, items[i]);
+      data += strides[dim] * elementSize;
+    }
+  } else {
+    JIT_ASSERT(obj.isIntList() || obj.isDoubleList() || obj.isBoolList());
+    if (obj.isIntList()) {
+      storeLastDimension<int64_t>(data, sizes, strides, dim, elementSize, obj.toIntListRef());
+    } else if (obj.isDoubleList()){
+      storeLastDimension<double>(data, sizes, strides, dim, elementSize, obj.toDoubleListRef());
+    } else {
+      storeLastDimension<bool>(data, sizes, strides, dim, elementSize, obj.toBoolListRef());
+    }
+  }
+}
+
 RegisterOperators reg({
     Operator(
         "aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]",
@@ -54,6 +177,35 @@ RegisterOperators reg({
           return 0;
         }),
     Operator(
+        "aten::_infer_size(int[] a, int[] b) -> int[]",
+        [](const Node* node) {
+          return [](Stack& stack) {
+            auto a = pop(stack).toIntList()->elements();
+            auto b = pop(stack).toIntList()->elements();
+            push(stack, at::infer_size(a, b));
+            return 0;
+          };
+        }),
+    Operator(
+      "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
+      [](const Node* node) {
+        return [](Stack& stack) {
+          at::Tensor weight;
+          at::Tensor input;
+          double max_norm;
+          double norm_type;
+          pop(stack, weight, input, max_norm, norm_type);
+
+          // TODO: remove when script supports setting grad mode
+          torch::NoGradGuard no_grad;
+
+          at::Tensor result = at::embedding_renorm_(weight, input, max_norm, norm_type);
+          push(stack, result);
+
+          return 0;
+        };
+      }),
+    Operator(
         "aten::format(str self, ...) -> str",
         [](const Node* node) {
           size_t num_inputs = node->inputs().size();
@@ -86,32 +238,75 @@ RegisterOperators reg({
             return 0;
           };
         }),
+
+#define DEFINE_TORCH_TENSOR_OP(operator_type, c_type, tensor_creation_op)             \
+Operator(                                                                             \
+  "aten::tensor(" #operator_type " t, *, ScalarType? dtype=None, Device? device=None"\
+      ") -> Tensor",                                                                  \
+  [](const Node* node) {                                                              \
+    auto initial_scalar_type = scalarTypeFromJitType(node->inputs().at(0)->type());   \
+    return [initial_scalar_type](Stack& stack) {                                      \
+      c_type scalar_val;                                                              \
+      IValue dtype;                                                                   \
+      IValue device;                                                                  \
+      pop(stack, scalar_val, dtype, device);                                          \
+      auto tensor = autograd::make_variable(tensor_creation_op);                      \
+      at::ScalarType scalar_type = dtype.isNone() ?                                   \
+        tensor.scalar_type() : dtype.toScalarType();                                  \
+      c10::Device dev = device.isNone() ? tensor.device() : device.toDevice();        \
+      if (scalar_type != initial_scalar_type || dev != tensor.device()) {             \
+        tensor = tensor.to(dev, scalar_type);                                         \
+      }                                                                               \
+      push(stack, tensor);                                                            \
+      return 0;                                                                       \
+    };                                                                                \
+  }),
+
+DEFINE_TORCH_TENSOR_OP(float, double, at::scalar_to_tensor(scalar_val))
+DEFINE_TORCH_TENSOR_OP(int, int64_t, at::scalar_to_tensor(scalar_val))
+DEFINE_TORCH_TENSOR_OP(bool, bool, at::empty({}, at::CPU(at::kByte).options()).fill_(scalar_val))
+
+
+    // reference python implementation: internal_new_from_data in tensor_new.cpp
     Operator(
-        "aten::_infer_size(int[] a, int[] b) -> int[]",
-        [](const Node* node) {
-          return [](Stack& stack) {
-            auto a = pop(stack).toIntList()->elements();
-            auto b = pop(stack).toIntList()->elements();
-            push(stack, at::infer_size(a, b));
-            return 0;
-          };
-        }),
-    Operator(
-      "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
+      "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
       [](const Node* node) {
-        return [](Stack& stack) {
-          at::Tensor weight;
-          at::Tensor input;
-          double max_norm;
-          double norm_type;
-          pop(stack, weight, input, max_norm, norm_type);
+        auto input = node->inputs().at(0);
+        auto elem_type = input->type();
+        while (auto list_type = elem_type->cast<ListType>()) {
+          elem_type = list_type->getElementType();
+        }
+        checkListInputType(elem_type, node);
+        at::ScalarType initial_scalar_type = scalarTypeFromJitType(elem_type);
+        return [initial_scalar_type, elem_type](Stack& stack) {
+          IValue data;
+          IValue dtype;
+          IValue device;
+          pop(stack, data, dtype, device);
+          auto sizes = compute_sizes(data);
+          auto tensor = autograd::make_variable(
+            at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type)));
 
-          // TODO: remove when script supports setting grad mode
-          torch::NoGradGuard no_grad;
+          recursiveStore((char*)tensor.data_ptr(), sizes, tensor.strides(), 0,
+              tensor.type().elementSizeInBytes(), data);
 
-          at::Tensor result = at::embedding_renorm_(weight, input, max_norm, norm_type);
-          push(stack, result);
+          at::ScalarType scalar_type = dtype.isNone() ? tensor.scalar_type() : dtype.toScalarType();
+          c10::Device dev = device.isNone() ? tensor.device() : device.toDevice();
+          if (scalar_type != initial_scalar_type || dev != tensor.device()) {
+            tensor = tensor.to(dev, scalar_type);
+          }
+
+          auto default_type = at::typeMetaToScalarType(at::get_default_dtype());
+
+          if (dtype.isNone() && tensor.scalar_type() != default_type &&
+              tensor.numel() == 0) {
+            AT_WARN("Creating a tensor from an empty ", elem_type->str(),
+              "list will create a tensor of default floating point type  (currently ", default_type,
+              ") in python but a tensor of type ", elem_type->str(), " in torchscript.\n",
+              "Pass in a dtype argument to ensure consistent behavior");
+          }
 
+          push(stack, tensor);
           return 0;
         };
       }),
index 5dd09fa..549d2e0 100644 (file)
@@ -589,7 +589,6 @@ c10::optional<MatchedSchema> tryMatchSchema(
     failure_messages << "\nfor operator " << schema << ":\n";
     return failure_messages;
   };
-
   TypeEnv type_env;
   std::vector<Value*> positional_inputs;
   std::vector<bool> used_kwarg(kwargs.size(), false);
@@ -604,7 +603,7 @@ c10::optional<MatchedSchema> tryMatchSchema(
       self = c10::nullopt;
     } else if (!arg.kwarg_only() && used_args < args.size()) {
       // allow zeros(IntList sizes) to work with zeros(1, 2) or zeros(1)
-      if (arg.type()->kind() == TypeKind::ListType && // the formal must be a list
+      if (allow_conversions && arg.type()->kind() == TypeKind::ListType && // the formal must be a list
           !arg.N() && // it must not be a broadcasting list like int[3], otherwise
                     // a single int is a valid input
           (schema_i + 1 == schema.arguments().size() ||
@@ -2276,7 +2275,7 @@ private:
           elem_type = values.at(0)->type();
         }
         for (auto v : values) {
-          if (v->type() != elem_type) {
+          if (*v->type() != *elem_type)  {
             throw ErrorReport(tree)
                 << "Lists must contain only a single type, expected: "
                 << *elem_type << " but found " << *v->type() << " instead";