Revert D13407930: [pytorch][PR] Support torch.tensor in script
authorMichael Suo <suo@fb.com>
Fri, 14 Dec 2018 06:10:56 +0000 (22:10 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 14 Dec 2018 06:13:07 +0000 (22:13 -0800)
Differential Revision:
D13407930

Original commit changeset: d17f1195a221

fbshipit-source-id: f4458872c48ec4a2c9983b21ed90bcdc0ae665b7

aten/src/ATen/core/ivalue.h
test/test_jit.py
torch/csrc/jit/constants.cpp
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/register_special_ops.cpp
torch/csrc/jit/script/compiler.cpp

index ed46d11..8b99ecc 100644 (file)
@@ -639,7 +639,6 @@ DEFINE_TO(int64_t, toInt)
 DEFINE_TO(bool, toBool)
 DEFINE_TO(c10::intrusive_ptr<ivalue::DoubleList>, toDoubleList)
 DEFINE_TO(c10::intrusive_ptr<ivalue::IntList>, toIntList)
-DEFINE_TO(c10::intrusive_ptr<ivalue::BoolList>, toBoolList)
 DEFINE_TO(c10::intrusive_ptr<ivalue::TensorList>, toTensorList)
 DEFINE_TO(c10::intrusive_ptr<ivalue::GenericList>, toGenericList)
 DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString)
index ae6ba8d..732229a 100644 (file)
@@ -2782,17 +2782,6 @@ class TestScript(JitTestCase):
             return torch.ones(x), x
         self.checkScript(stuff3, ([3, 2],))
 
-    def test_bool_list_io(self):
-        @torch.jit.script
-        def stuff4(x):
-            # type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]]
-            return x, [True, False], [[True]]
-
-        li_1, li_2, li_3 = stuff4([True])
-        li_3 = li_3[0]
-        for li in [li_1, li_2, li_3]:
-            self.assertTrue(type(li[0]) == type(True))
-
     def test_nested_list(self):
         def foo(z):
             # type: (Tuple[int, List[List[int]]]) -> int
@@ -4442,83 +4431,6 @@ a")
     def test_tensor_number_math(self):
         self._test_tensor_number_math()
 
-    def test_torch_tensor_bad_input(self):
-        with self.assertRaisesRegex(RuntimeError, "Input list to torch.tensor must be of ints, floats, "
-                                    "or bools, got None"):
-            @torch.jit.script
-            def test():
-                return torch.tensor([None])
-
-        with self.assertRaisesRegex(RuntimeError, "Note: empty lists are constructed as Tensor"):
-            @torch.jit.script
-            def tmp():
-                return torch.tensor([])
-
-        @torch.jit.script
-        def foo():
-            return torch.tensor([[2, 2], [1]])
-        with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"):
-            foo()
-
-    def test_torch_tensor_empty_list(self):
-        def func():
-            return torch.tensor(torch.jit.annotate(List[int], []))
-        cu = torch.jit.script(func)
-        t1 = cu()
-        t2 = func()
-
-        # torchscript returns int tensor, python returns float tensor
-        self.assertNotEqual(t1.dtype, t2.dtype)
-
-        def func():
-            li = torch.jit.annotate(List[int], [])
-            return torch.tensor([li, li])
-
-        self.checkScript(func, ())
-
-        def func():
-            li = torch.jit.annotate(List[int], [])
-            return torch.tensor([[[li]]])
-
-        self.checkScript(func, ())
-
-    def test_torch_tensor(self):
-        template = dedent('''
-        def func():
-            li = {list_create}
-            return torch.tensor(li {options})
-        ''')
-
-        lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]",
-                 "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]
-
-        dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
-                  ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short",
-                  ", dtype=torch.int", ", dtype=torch.long"]
-
-        devices = ['', ", device='cpu'"]
-        if RUN_CUDA:
-            devices.append(", device='cuda'")
-
-        option_pairs = [dtype + device for dtype in dtypes for device in devices]
-        for li in lists:
-            for option in option_pairs:
-                # tensor from empty list is type float in python and annotated type in torchscript
-                if "annotate" in li and "dtype" not in option:
-                    continue
-                code = template.format(list_create=li, options=option)
-                scope = {}
-                exec(code, globals(), scope)
-                cu = torch.jit.CompilationUnit(code)
-                t1 = cu.func()
-                t2 = scope['func']()
-                if t1.dtype == torch.float16:  # equality NYI for half tensor
-                    self.assertTrue(str(t1) == str(t2))
-                else:
-                    self.assertEqual(t1, t2)
-                self.assertEqual(t1.dtype, t2.dtype)
-                self.assertEqual(t1.device, t2.device)
-
     @unittest.skipIf(not RUN_CUDA, "No CUDA")
     @skipIfRocm
     def test_tensor_number_math_cuda(self):
index 42ffc7f..4912d58 100644 (file)
@@ -108,8 +108,7 @@ RegisterOperators reg({
             return 0;
           };
         } else if(type->isSubtypeOf(ListType::ofBools())) {
-          auto int_list = node->is(attr::value);
-          std::vector<bool> bs(int_list.begin(), int_list.end());
+          auto bs = node->is(attr::value);
           return [bs](Stack& stack) {
             push(stack, bs);
             return 0;
index 40acf42..2b6c336 100644 (file)
@@ -172,8 +172,6 @@ inline IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_
               std::vector<double> repeated(*N, value);
               return repeated;
             }
-          case TypeKind::BoolType:
-            return py::cast<std::vector<bool>>(obj);
           case TypeKind::TensorType:
           case TypeKind::DynamicType:
             return py::cast<std::vector<at::Tensor>>(obj);
index 2fd4938..10e91a5 100644 (file)
@@ -54,19 +54,6 @@ void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
   }
 }
 
-template <typename dtype> // int64_t, bool, double
-Operation listConstruct(int64_t num_inputs) {
-  return [=](Stack& stack) {
-    auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
-    std::vector<dtype> vals = fmap(inputs, [](const IValue& v) {
-      return v.to<dtype>();
-    });
-    drop(stack, num_inputs);
-    push(stack, std::move(vals));
-    return 0;
-  };
-}
-
 RegisterOperators reg({
     Operator(
         prim::FusionGroup,
@@ -579,11 +566,25 @@ RegisterOperators reg({
           const auto num_inputs = node->inputs().size();
           ListTypePtr lt = node->output()->type()->expect<ListType>();
           if(IntType::get() == lt->getElementType()) {
-            return listConstruct<int64_t>(num_inputs);
+            return [=](Stack& stack) {
+              auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
+              std::vector<int64_t> vals = fmap(inputs, [](const IValue& v) {
+                return v.toInt();
+              });
+              drop(stack, num_inputs);
+              push(stack, std::move(vals));
+              return 0;
+            };
           } else if(FloatType::get() == lt->getElementType()) {
-            return listConstruct<double>(num_inputs);
-          } else if (lt->getElementType() == BoolType::get()) {
-            return listConstruct<bool>(num_inputs);
+            return [=](Stack& stack) {
+              auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
+              std::vector<double> vals = fmap(inputs, [](const IValue& v) {
+                return v.toDouble();
+              });
+              drop(stack, num_inputs);
+              push(stack, std::move(vals));
+              return 0;
+            };
           } else if (lt->getElementType()->isSubtypeOf(DynamicType::get())) {
             return [=](Stack& stack) {
               const size_t stack_size = stack.size();
@@ -747,16 +748,6 @@ typename TList::element_type::ElemType& getItem(TList& list, int64_t idx) {
   return list->elements()[normalized_idx];
 }
 
-// cannot return a reference to an element in a bool vector
-bool getBoolItem(const std::vector<bool>& list, int64_t idx) {
-  const int64_t list_size = list.size();
-  const int64_t normalized_idx = normalizeIndex(idx, list_size);
-  if (normalized_idx < 0 || normalized_idx >= list_size) {
-    throw std::out_of_range("list index out of range");
-  }
-  return list[normalized_idx];
-}
-
 template <typename TList, typename TElement>
 Operation listAppend(const Node* node) {
   return [](Stack& stack) {
@@ -784,21 +775,6 @@ Operation listSelect(const Node* node) {
   };
 }
 
-// needs specialization because cannot return a pointer to a bool in an array
-template<>
-Operation listSelect<Shared<BoolList>>(const Node* node) {
-  return [=](Stack& stack) {
-    Shared<BoolList> list;
-    int64_t idx;
-    pop(stack, list, idx);
-
-    auto element = getBoolItem(list->elements(), idx);
-    push(stack, std::move(element));
-    return 0;
-  };
-}
-
-
 template <typename T>
 Operation listLen(const Node* node) {
   return [=](Stack& stack) {
@@ -810,7 +786,6 @@ Operation listLen(const Node* node) {
   };
 }
 
-
 template <typename T>
 Operation listEq(const Node* node) {
   return [=](Stack& stack) {
@@ -950,29 +925,6 @@ Operation listSetItem(const Node* node) {
   };
 }
 
-
-template<>
-Operation listSetItem<Shared<BoolList>, bool>(const Node* node) {
-  return [](Stack& stack) {
-    Shared<BoolList> list;
-    int64_t idx;
-    bool value;
-
-    pop(stack, list, idx, value);
-
-    int64_t list_size = list->elements().size();
-    auto normalized_idx = normalizeIndex(idx, list_size);
-    if (normalized_idx < 0 || normalized_idx >= list_size) {
-      throw std::out_of_range("list index out of range");
-    }
-    list->elements()[normalized_idx] = value;
-
-    push(stack, list);
-    return 0;
-  };
-}
-
-
 RegisterOperators reg2({
 
 #define DEFINE_STRING_OP(op_name, string_op, result)                           \
@@ -1021,8 +973,6 @@ Operator(                                                                      \
     CREATE_IMMUTABLE_LIST_OPS("int", IntList)
     CREATE_IMMUTABLE_LIST_OPS("float", DoubleList)
     CREATE_IMMUTABLE_LIST_OPS("t", GenericList)
-    CREATE_IMMUTABLE_LIST_OPS("t", BoolList)
-
 
 #define CREATE_LIST_OPS(decl_type, c_type) \
     Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
@@ -1035,7 +985,6 @@ Operator(                                                                      \
     CREATE_LIST_OPS("int", IntList)
     CREATE_LIST_OPS("float", DoubleList)
     CREATE_LIST_OPS("Tensor", TensorList)
-    CREATE_LIST_OPS("bool", BoolList)
     CREATE_LIST_OPS("t", GenericList)
 #undef CREATE_LIST_OPS
 
@@ -1043,11 +992,10 @@ Operator(                                                                      \
     Operator("aten::eq(int[] a, int[] b) -> bool", listEq<Shared<IntList>>),
     Operator("aten::eq(float[] a, float[] b) -> bool", listEq<Shared<DoubleList>>),
     Operator("aten::eq(Tensor[] a, Tensor[] b) -> bool", listEq<Shared<TensorList>>),
-    Operator("aten::eq(bool[] a, bool[] b) -> bool", listEq<Shared<BoolList>>),
     Operator("aten::ne(int[] a, int[] b) -> bool", listNe<Shared<IntList>>),
     Operator("aten::ne(float[] a, float[] b) -> bool", listNe<Shared<DoubleList>>),
     Operator("aten::ne(Tensor[] a, Tensor[] b) -> bool", listNe<Shared<TensorList>>),
-    Operator("aten::ne(bool[] a, bool[] b) -> bool", listNe<Shared<BoolList>>),
+
 
 #define CREATE_COPY_OP(other_type, c_type)                              \
   Operator(                                                             \
index 2de49f8..9ae08a5 100644 (file)
@@ -2,10 +2,7 @@
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/api/include/torch/utils.h>
-#include <aten/src/ATen/ExpandUtils.h>
-#include <c10/core/ScalarType.h>
-#include <aten/src/ATen/InitialTensorOptions.h>
-#include <torch/csrc/jit/script/error_report.h>
+#include <ATen/ExpandUtils.h>
 
 #include <sstream>
 #include <regex>
@@ -14,126 +11,6 @@ namespace torch {
 namespace jit {
 
 namespace {
-
-
-void checkListInputType(const c10::TypePtr& elem_type, const Node* node) {
-  if (!elem_type->isSubtypeOf(NumberType::get()) && elem_type != BoolType::get()) {
-    auto error = script::ErrorReport(node->getSourceLocation());
-    error << "Input list to torch.tensor must be of ints, floats, or bools, " <<
-      "got " << elem_type->str();
-    // special case empty list torch.tensor([])
-    if (elem_type->isSubtypeOf(DynamicType::get())) {
-      auto input = node->inputs().at(0);
-      if (input->node()->kind() == prim::ListConstruct && input->node()->inputs().size() == 0) {
-        error << "\n(Note: empty lists are constructed as Tensor[]; \n"
-               << "if you want an empty list of a different type, \n"
-               << "use `torch.jit.annotate(List[T], [])`, \n"
-               << "where `T` is the type of elements in the list)";
-      }
-    }
-    throw error;
-  }
-}
-
-at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
-  if (type == FloatType::get()) {
-    return at::ScalarType::Double;
-  } else if (type == IntType::get()) {
-    return at::ScalarType::Long;
-  } else if (type == BoolType::get()) {
-    return at::ScalarType::Byte;
-  }
-  AT_ASSERTM(0, "Add new condition, expected Float, Int, or Bool but got",
-      type->str());
-}
-
-
-int64_t list_size(IValue list) {
-  if (list.isGenericList()) {
-    return list.toGenericListRef().size();
-  } else if (list.isIntList()) {
-    return list.toIntListRef().size();
-  } else if (list.isDoubleList()){
-    return list.toDoubleListRef().size();
-  } else if (list.isBoolList()) {
-    return list.toBoolListRef().size();
-  }
-  AT_ASSERTM(0, "Unexpected list type", list);
-}
-
-std::vector<int64_t> compute_sizes(IValue seq) {
-  std::vector<int64_t> sizes;
-  // because bool, int, and float lists are specialized, inner array will
-  // will not be generic list
-  while (seq.isGenericList()) {
-    auto seq_list = seq.toGenericListRef();
-    auto length = seq_list.size();
-    AT_ASSERT(length != 0);
-    sizes.push_back(length);
-    seq = seq_list[0];
-  }
-  sizes.push_back(list_size(seq));
-  return sizes;
-}
-
-void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
-  if (seq_size != n) {
-    AT_ERROR("Expected sequence of length ", n, " at dim ", dim, " (got ", seq_size, ")");
-  }
-}
-
-template <typename DTYPE>
-void storeLastDimension(char* data, const std::vector<int64_t>& sizes, const c10::ArrayRef<int64_t>& strides, int64_t dim,
-    int elementSize, std::vector<DTYPE> obj) {
-  auto n = sizes[dim];
-  auto seq_size = obj.size();
-  checkSequenceSize(n, dim, seq_size);
-  for (int64_t i = 0; i < n; i++) {
-    *(DTYPE*)data = obj[i];
-    data += strides[dim] * elementSize;
-  }
-}
-
-// bool vector needs to be cast to uint8_t
-template<>
-void storeLastDimension<bool>(char* data, const std::vector<int64_t>& sizes, const c10::ArrayRef<int64_t>& strides, int64_t dim,
-    int elementSize, std::vector<bool> obj) {
-  auto n = sizes[dim];
-  auto seq_size = obj.size();
-  checkSequenceSize(n, dim, seq_size);
-  for (int64_t i = 0; i < n; i++) {
-    *(uint8_t*)data = static_cast<uint8_t>(obj[i]);
-    data += strides[dim] * elementSize;
-  }
-}
-
-// refernce python implementation recursive_store in tensor_new.cpp
-
-void recursiveStore(char* data, const std::vector<int64_t>& sizes, const c10::ArrayRef<int64_t>& strides, int64_t dim,
-   int elementSize, IValue obj) {
-
-  auto ndim = sizes.size();
-  auto n = sizes[dim];
-  auto seq_size = list_size(obj);
-  checkSequenceSize(n, dim, seq_size);
-  if (dim + 1 < static_cast<long>(ndim)) {
-    auto items = obj.toGenericListRef();
-    for (int64_t i = 0; i < n; i++) {
-      recursiveStore(data, sizes, strides, dim + 1, elementSize, items[i]);
-      data += strides[dim] * elementSize;
-    }
-  } else {
-    JIT_ASSERT(obj.isIntList() || obj.isDoubleList() || obj.isBoolList());
-    if (obj.isIntList()) {
-      storeLastDimension<int64_t>(data, sizes, strides, dim, elementSize, obj.toIntListRef());
-    } else if (obj.isDoubleList()){
-      storeLastDimension<double>(data, sizes, strides, dim, elementSize, obj.toDoubleListRef());
-    } else {
-      storeLastDimension<bool>(data, sizes, strides, dim, elementSize, obj.toBoolListRef());
-    }
-  }
-}
-
 RegisterOperators reg({
     Operator(
         "aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]",
@@ -177,35 +54,6 @@ RegisterOperators reg({
           return 0;
         }),
     Operator(
-        "aten::_infer_size(int[] a, int[] b) -> int[]",
-        [](const Node* node) {
-          return [](Stack& stack) {
-            auto a = pop(stack).toIntList()->elements();
-            auto b = pop(stack).toIntList()->elements();
-            push(stack, at::infer_size(a, b));
-            return 0;
-          };
-        }),
-    Operator(
-      "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
-      [](const Node* node) {
-        return [](Stack& stack) {
-          at::Tensor weight;
-          at::Tensor input;
-          double max_norm;
-          double norm_type;
-          pop(stack, weight, input, max_norm, norm_type);
-
-          // TODO: remove when script supports setting grad mode
-          torch::NoGradGuard no_grad;
-
-          at::Tensor result = at::embedding_renorm_(weight, input, max_norm, norm_type);
-          push(stack, result);
-
-          return 0;
-        };
-      }),
-    Operator(
         "aten::format(str self, ...) -> str",
         [](const Node* node) {
           size_t num_inputs = node->inputs().size();
@@ -238,75 +86,32 @@ RegisterOperators reg({
             return 0;
           };
         }),
-
-#define DEFINE_TORCH_TENSOR_OP(operator_type, c_type, tensor_creation_op)             \
-Operator(                                                                             \
-  "aten::tensor(" #operator_type " t, *, ScalarType? dtype=None, Device? device=None"\
-      ") -> Tensor",                                                                  \
-  [](const Node* node) {                                                              \
-    auto initial_scalar_type = scalarTypeFromJitType(node->inputs().at(0)->type());   \
-    return [initial_scalar_type](Stack& stack) {                                      \
-      c_type scalar_val;                                                              \
-      IValue dtype;                                                                   \
-      IValue device;                                                                  \
-      pop(stack, scalar_val, dtype, device);                                          \
-      auto tensor = autograd::make_variable(tensor_creation_op);                      \
-      at::ScalarType scalar_type = dtype.isNone() ?                                   \
-        tensor.scalar_type() : dtype.toScalarType();                                  \
-      c10::Device dev = device.isNone() ? tensor.device() : device.toDevice();        \
-      if (scalar_type != initial_scalar_type || dev != tensor.device()) {             \
-        tensor = tensor.to(dev, scalar_type);                                         \
-      }                                                                               \
-      push(stack, tensor);                                                            \
-      return 0;                                                                       \
-    };                                                                                \
-  }),
-
-DEFINE_TORCH_TENSOR_OP(float, double, at::scalar_to_tensor(scalar_val))
-DEFINE_TORCH_TENSOR_OP(int, int64_t, at::scalar_to_tensor(scalar_val))
-DEFINE_TORCH_TENSOR_OP(bool, bool, at::empty({}, at::CPU(at::kByte).options()).fill_(scalar_val))
-
-
-    // reference python implementation: internal_new_from_data in tensor_new.cpp
     Operator(
-      "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
+        "aten::_infer_size(int[] a, int[] b) -> int[]",
+        [](const Node* node) {
+          return [](Stack& stack) {
+            auto a = pop(stack).toIntList()->elements();
+            auto b = pop(stack).toIntList()->elements();
+            push(stack, at::infer_size(a, b));
+            return 0;
+          };
+        }),
+    Operator(
+      "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
       [](const Node* node) {
-        auto input = node->inputs().at(0);
-        auto elem_type = input->type();
-        while (auto list_type = elem_type->cast<ListType>()) {
-          elem_type = list_type->getElementType();
-        }
-        checkListInputType(elem_type, node);
-        at::ScalarType initial_scalar_type = scalarTypeFromJitType(elem_type);
-        return [initial_scalar_type, elem_type](Stack& stack) {
-          IValue data;
-          IValue dtype;
-          IValue device;
-          pop(stack, data, dtype, device);
-          auto sizes = compute_sizes(data);
-          auto tensor = autograd::make_variable(
-            at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type)));
-
-          recursiveStore((char*)tensor.data_ptr(), sizes, tensor.strides(), 0,
-              tensor.type().elementSizeInBytes(), data);
-
-          at::ScalarType scalar_type = dtype.isNone() ? tensor.scalar_type() : dtype.toScalarType();
-          c10::Device dev = device.isNone() ? tensor.device() : device.toDevice();
-          if (scalar_type != initial_scalar_type || dev != tensor.device()) {
-            tensor = tensor.to(dev, scalar_type);
-          }
+        return [](Stack& stack) {
+          at::Tensor weight;
+          at::Tensor input;
+          double max_norm;
+          double norm_type;
+          pop(stack, weight, input, max_norm, norm_type);
 
-          auto default_type = at::typeMetaToScalarType(at::get_default_dtype());
+          // TODO: remove when script supports setting grad mode
+          torch::NoGradGuard no_grad;
 
-          if (dtype.isNone() && tensor.scalar_type() != default_type &&
-              tensor.numel() == 0) {
-            AT_WARN("Creating a tensor from an empty ", elem_type->str(),
-              "list will create a tensor of default floating point type  (currently ", default_type,
-              ") in python but a tensor of type ", elem_type->str(), " in torchscript.\n",
-              "Pass in a dtype argument to ensure consistent behavior");
-          }
+          at::Tensor result = at::embedding_renorm_(weight, input, max_norm, norm_type);
+          push(stack, result);
 
-          push(stack, tensor);
           return 0;
         };
       }),
index 549d2e0..5dd09fa 100644 (file)
@@ -589,6 +589,7 @@ c10::optional<MatchedSchema> tryMatchSchema(
     failure_messages << "\nfor operator " << schema << ":\n";
     return failure_messages;
   };
+
   TypeEnv type_env;
   std::vector<Value*> positional_inputs;
   std::vector<bool> used_kwarg(kwargs.size(), false);
@@ -603,7 +604,7 @@ c10::optional<MatchedSchema> tryMatchSchema(
       self = c10::nullopt;
     } else if (!arg.kwarg_only() && used_args < args.size()) {
       // allow zeros(IntList sizes) to work with zeros(1, 2) or zeros(1)
-      if (allow_conversions && arg.type()->kind() == TypeKind::ListType && // the formal must be a list
+      if (arg.type()->kind() == TypeKind::ListType && // the formal must be a list
           !arg.N() && // it must not be a broadcasting list like int[3], otherwise
                     // a single int is a valid input
           (schema_i + 1 == schema.arguments().size() ||
@@ -2275,7 +2276,7 @@ private:
           elem_type = values.at(0)->type();
         }
         for (auto v : values) {
-          if (*v->type() != *elem_type)  {
+          if (v->type() != elem_type) {
             throw ErrorReport(tree)
                 << "Lists must contain only a single type, expected: "
                 << *elem_type << " but found " << *v->type() << " instead";