"torch/csrc/jit/runtime/slice_indices_adjust.cpp",
"torch/csrc/jit/runtime/register_ops_utils.cpp",
"torch/csrc/jit/runtime/vararg_functions.cpp",
+ "torch/csrc/jit/mobile/promoted_prim_ops.cpp",
+ "torch/csrc/jit/mobile/prim_ops_registery.cpp",
"torch/csrc/jit/serialization/import_read.cpp",
"torch/csrc/jit/serialization/unpickler.cpp",
]
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/interpreter.h>
+#include <torch/csrc/jit/mobile/prim_ops_registery.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>
are removed */
// Keep the original opname in code_
code_->op_names_.emplace_back(name, overload_name);
- auto opname = code_->op_names_.back();
+ const auto& opname = code_->op_names_.back();
- const auto& opname_c10 = opname;
std::function<void(Stack&)> fn;
- auto jit_op = findOperatorFor(opname);
const std::vector<c10::Argument>* pArgs = nullptr;
- if (jit_op) {
- fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
- pArgs = &jit_op->schema().arguments();
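+ // Promoted prim ops ship a boxed implementation with the mobile runtime and
+ // are resolved from its registry first, before falling back to the JIT
+ // operator registry and the c10 dispatcher.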
+ // Look up by the fully qualified name: promoted ops are registered with the
+ // overload included (e.g. "aten::Bool.Tensor"), which `name` alone lacks.
+ const auto full_name = c10::toString(opname);
+ bool promoted_op = mobile::hasPrimOpsFn(full_name);
+ if (promoted_op) {
+ fn = mobile::getPrimOpsFn(full_name);
} else {
- auto op = c10::Dispatcher::singleton().findSchema(opname_c10);
- if (op.has_value()) {
- fn = [op](Stack& stack) { op->callBoxed(&stack); };
- if (op->hasSchema()) {
- pArgs = &op->schema().arguments();
+ std::shared_ptr<Operator> jit_op = findOperatorFor(opname);
+ if (jit_op) {
+ fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
+ pArgs = &jit_op->schema().arguments();
+ } else {
+ auto op = c10::Dispatcher::singleton().findSchema(opname);
+ if (op.has_value()) {
+ fn = [op](Stack& stack) { op->callBoxed(&stack); };
+ if (op->hasSchema()) {
+ pArgs = &op->schema().arguments();
+ } else {
+ TORCH_CHECK(false, "arguments are missing for operator ", opname);
+ }
} else {
- TORCH_CHECK(false, "arguments are missing for operator ", opname);
+ return false;
}
- } else {
- return false;
}
}
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);
- const auto& args = *pArgs;
- if (model_version == 0x3LL &&
- opname == c10::OperatorName("aten::_convolution", "")) {
- // Since byte-code versions 0x4L, convolution has an additional
- // default-value argument (allow_tf32=True, see
- // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
- // backward compatibility with models of byte-code version <= 0x3L, where
- // this bool argument does not yet exist.
- fn = [fn](Stack& stack) {
- stack.push_back(true);
- fn(stack);
- };
- } else {
- // num_specified_args >= 0 indicates number of arguments are available
- // from model. We can use it to handle backward compatibility.
- if (num_specified_args &&
- num_specified_args.value() < static_cast<int64_t>(args.size())) {
- fn = [fn, num_specified_args, args](Stack& stack) {
- std::vector<IValue> out_args;
- // The following logic pops and temporarily stores all out arguments
- // from the stack (which can be 0 or more, and always appended to the
- // schema), in order to push the necessary default values. Finally, the
- // out arguments are pushed back into the stack.
- for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
- out_args.push_back(stack.back());
- stack.pop_back();
- }
- size_t start_index = num_specified_args.value() - out_args.size();
- TORCH_CHECK(
- start_index >= 0,
- "The number of output arguments is: ",
- out_args.size(),
- ", which is more then the number of specified arguments: ",
- num_specified_args.value());
- for (size_t i = start_index; i < (args.size() - out_args.size()); ++i) {
- TORCH_CHECK(
- args[i].default_value().has_value(),
- "Error happened at preparing for default values for the argument. The ",
- i,
- "th argument ",
- args[i].name(),
- " does not have a specified value or default value. ");
-
- stack.push_back(args[i].default_value());
- }
- stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
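+ // Promoted ops consume the stack exactly as serialized; the schema-driven
+ // default-argument backfilling below only applies to ops resolved through
+ // the JIT registry or the dispatcher.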
+ if (!promoted_op) {
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);
+ const auto& args = *pArgs;
+ if (model_version == 0x3LL &&
+ opname == c10::OperatorName("aten::_convolution", "")) {
+ // Since byte-code version 0x4L, convolution has an additional
+ // default-value argument (allow_tf32=True, see
+ // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
+ // backward compatibility with models of byte-code version <= 0x3L, where
+ // this bool argument does not yet exist.
+ fn = [fn](Stack& stack) {
+ stack.push_back(true);
fn(stack);
};
+ } else {
+ // A non-negative num_specified_args gives the number of arguments that were
+ // serialized with the model; use it to handle backward compatibility.
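+ // Illustrative example (hypothetical schema): for an op foo(a, b, c=1, d=2)
+ // serialized with num_specified_args == 2, the wrapper below pushes the
+ // defaults for c and d so bytecode from older exporters keeps running.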
+ if (num_specified_args &&
+ num_specified_args.value() < static_cast<int64_t>(args.size())) {
+ fn = [fn, num_specified_args, args](Stack& stack) {
+ std::vector<IValue> out_args;
+ // Pop and temporarily store the out arguments (zero or more; they always
+ // come last in the schema) so that the missing default values can be
+ // pushed, then push the out arguments back onto the stack.
+ for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
+ out_args.push_back(stack.back());
+ stack.pop_back();
+ }
+ int64_t start_index =
+ num_specified_args.value() - static_cast<int64_t>(out_args.size());
+ TORCH_CHECK(
+ start_index >= 0,
+ "The number of output arguments is: ",
+ out_args.size(),
+ ", which is more than the number of specified arguments: ",
+ num_specified_args.value());
+ for (size_t i = start_index; i < (args.size() - out_args.size());
+ ++i) {
+ TORCH_CHECK(
+ args[i].default_value().has_value(),
+ "Failed to prepare default values: argument ",
+ i,
+ " (",
+ args[i].name(),
+ ") has neither a specified value nor a default value.");
+
+ stack.push_back(args[i].default_value());
+ }
+ stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
+ fn(stack);
+ };
+ }
}
}
code_->operators_.emplace_back(fn);
--- /dev/null
+#include <torch/csrc/jit/mobile/prim_ops_registery.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+
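+// Registry mapping a fully qualified operator name to its boxed function.
+// The map is an unsynchronized function-local static; registration is
+// expected to happen during static initialization, before concurrent lookups.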
+std::unordered_map<std::string, std::function<void(Stack&)>>& primOpsFnTable() {
+ static std::unordered_map<std::string, std::function<void(Stack&)>>
+ prim_ops_fn;
+ return prim_ops_fn;
+}
+
+void registerPrimOpsFunction(
+ const std::string& name,
+ const std::function<void(Stack&)>& fn) {
+ primOpsFnTable()[name] = fn;
+}
+
+bool hasPrimOpsFn(const std::string& name) {
+ return primOpsFnTable().count(name);
+}
+
+std::function<void(Stack&)>& getPrimOpsFn(const std::string& name) {
+ TORCH_CHECK(
+ hasPrimOpsFn(name),
+ "Prim Ops Function for ",
+ name,
+ " is not promoted yet.");
+ return primOpsFnTable()[name];
+}
+
+} // namespace mobile
+} // namespace jit
+} // namespace torch
--- /dev/null
+#pragma once
+
+#include <ATen/core/ivalue.h>
+#include <functional>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+
+using Stack = std::vector<c10::IValue>;
+
+void registerPrimOpsFunction(
+ const std::string& name,
+ const std::function<void(Stack&)>& fn);
+
+bool hasPrimOpsFn(const std::string& name);
+
+std::function<void(Stack&)>& getPrimOpsFn(const std::string& name);
+
+class prim_op_fn_register {
+ public:
+ prim_op_fn_register(
+ const std::string& name,
+ const std::function<void(Stack&)>& fn) {
+ registerPrimOpsFunction(name, fn);
+ }
+};
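+
+// Illustrative usage (hypothetical op name): a file-scope instance registers
+// its function during static initialization, e.g.
+//   static const prim_op_fn_register my_op_reg(
+//       "prim::MyOp", [](Stack& stack) { /* pop inputs, push result */ });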
+
+} // namespace mobile
+} // namespace jit
+} // namespace torch
--- /dev/null
+#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
+
+namespace torch {
+namespace jit {
+void tupleIndex(Stack& stack) {
+ int64_t index = pop(stack).toInt();
+ auto tuple = pop(stack).toTuple();
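+ // normalizeIndex maps negative (Python-style) indices into [0, size) before
+ // the bounds check below.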
+ auto norm_index = normalizeIndex(index, tuple->elements().size());
+ if (norm_index < 0 ||
+ norm_index > static_cast<int64_t>(tuple->elements().size())) {
+ throw std::out_of_range("Tuple list index out of range");
+ }
+ stack.emplace_back(tuple->elements()[norm_index]);
+}
+
+void raiseException(Stack& stack) {
+ throw JITException(pop(stack).toStringRef());
+}
+
+void is(Stack& stack) {
+ IValue self, obj;
+ pop(stack, self, obj);
+ push(stack, self.is(obj));
+}
+
+void unInitialized(Stack& stack) {
+ push(stack, IValue::uninitialized());
+}
+
+void isNot(Stack& stack) {
+ IValue self, obj;
+ pop(stack, self, obj);
+ push(stack, !self.is(obj));
+}
+
+void aten_format(Stack& stack) {
+ size_t num_inputs = pop(stack).toInt();
+ format(stack, num_inputs);
+}
+
+void size(Stack& stack) {
+ auto t = std::move(pop(stack)).toTensor();
+ pack(stack, t.sizes().vec());
+}
+
+void device(Stack& stack) {
+ push(stack, pop(stack).toTensor().device());
+}
+
+void dtype(Stack& stack) {
+ at::Tensor a;
+ pop(stack, a);
+ push(stack, static_cast<int64_t>(a.scalar_type()));
+}
+
+void toPrimDType(Stack& stack) {
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ bool non_blocking;
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ bool copy;
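+ // Schema order is (self, dtype, non_blocking, copy); arguments are popped in
+ // reverse: the two trailing bools together, then the optional dtype, then
+ // self.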
+ pop(stack, non_blocking, copy);
+ c10::optional<at::ScalarType> scalarType =
+ pop(stack).toOptional<at::ScalarType>();
+ c10::optional<c10::Device> device = c10::nullopt;
+ at::Tensor self = pop(stack).toTensor();
+ push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
+}
+
+void dim(Stack& stack) {
+ at::Tensor arg = pop(stack).toTensor();
+ push(stack, arg.dim());
+}
+
+void _not(Stack& stack) {
+ push(stack, !pop(stack).toBool());
+}
+
+void boolTensor(Stack& stack) {
+ at::Tensor a;
+ pop(stack, a);
+ push(stack, a.is_nonzero());
+}
+
+void toList(Stack& stack) {
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ int elem_ty_val;
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ int dim_val;
+ at::Tensor t;
+
+ pop(stack, elem_ty_val);
+ pop(stack, dim_val);
+ pop(stack, t);
+
+ // If the Tensor is not on the CPU, transfer it.
+ if (!t.device().is_cpu()) {
+ t = t.cpu();
+ }
+
+ // Rebuild the output type using elem_ty_val and dim_val. Start
+ // with the element type corresponding to elem_ty_val.
+ TypePtr out_ty;
+ if (elem_ty_val == 0) {
+ out_ty = IntType::get();
+ } else if (elem_ty_val == 1) {
+ out_ty = FloatType::get();
+ } else if (elem_ty_val == 2) {
+ out_ty = BoolType::get();
+ } else if (elem_ty_val == 3) {
+ out_ty = ComplexType::get();
+ } else {
+ TORCH_CHECK(
+ false,
+ "Unsupported element type for tolist; only int, float, complex and bool are supported");
+ }
+
+ // Check that type of the Tensor matches that of the annotation.
+ // Make an exception for the case in which the annotated type is
+ // float/complex and the Tensor data type is also float/complex;
+ // the elements will be casted to double/c10::complex<double>
+ // later.
+ TORCH_CHECK(
+ (out_ty == FloatType::get() && t.is_floating_point()) ||
+ (out_ty == ComplexType::get() && t.is_complex()) ||
+ tryScalarTypeFromJitType(out_ty) == t.scalar_type(),
+ "Output annotation element type and runtime tensor element type must match for tolist()");
+
+ // Check that the dimension of the Tensor matches that of the
+ // annotation.
+ TORCH_CHECK(
+ dim_val == t.dim(),
+ "Output annotation list dimension and runtime tensor dimension must match for tolist()");
+
+ // Wrap out_ty in a ListType dim times.
+ for (const auto i : c10::irange(dim_val)) {
+ (void)i; // Suppress unused variable warning
+ out_ty = ListType::create(out_ty);
+ }
+
+ int64_t dim = t.dim();
+ auto sizes = t.sizes();
+ auto strides = t.strides();
+ size_t element_size = t.element_size();
+ char* data = static_cast<char*>(t.data_ptr());
+ auto result = tensorToListRecursive(
+ data, 0, dim, out_ty, t.scalar_type(), sizes, strides, element_size);
+ push(stack, std::move(result));
+}
+
+void numToTensorScalar(Stack& stack) {
+ at::Scalar s;
+ pop(stack, s);
+ push(stack, at::scalar_to_tensor(s));
+}
+
+void isCuda(Stack& stack) {
+ at::Tensor a;
+ pop(stack, a);
+ push(stack, a.is_cuda());
+}
+
+void numToTensorBool(Stack& stack) {
+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ bool b;
+ pop(stack, b);
+ push(stack, at::scalar_to_tensor(b));
+}
+
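+// Constructing these file-scope registrars populates primOpsFnTable() during
+// static initialization, so promoted ops are visible to the mobile
+// interpreter before any model is loaded.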
+static const std::array<mobile::prim_op_fn_register, 14> op_reg = {
+ mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex),
+ mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor),
+ mobile::prim_op_fn_register("aten::format", aten_format),
+ mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar),
+ mobile::prim_op_fn_register("prim::RaiseException", raiseException),
+ mobile::prim_op_fn_register("prim::device", device),
+ mobile::prim_op_fn_register("prim::dtype", dtype),
+ mobile::prim_op_fn_register("aten::__not__", _not),
+ mobile::prim_op_fn_register("aten::__is__", is),
+ mobile::prim_op_fn_register("aten::__isnot__", isNot),
+ mobile::prim_op_fn_register("aten::dim", dim),
+ mobile::prim_op_fn_register("prim::Uninitialized", unInitialized),
+ mobile::prim_op_fn_register("aten::to.prim_dtype", toPrimDType),
+ mobile::prim_op_fn_register("prim::is_cuda", isCuda)
+ // TODO: (@pavithran) size is overloaded with int[] and Tensor
+ // so this throws error expecting int not Tensor
+ // mobile::prim_op_fn_register("aten::size", size)
+};
+
+} // namespace jit
+} // namespace torch
--- /dev/null
+#pragma once
+#include <torch/csrc/jit/mobile/prim_ops_registery.h>
+#include <torch/csrc/jit/runtime/register_ops_utils.h>
+
+namespace torch {
+namespace jit {
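+
+// Boxed implementations shared by the full JIT operator registry
+// (register_prim_ops.cpp) and the mobile promoted-op table
+// (promoted_prim_ops.cpp).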
+
+void tupleIndex(Stack& stack);
+
+void raiseException(Stack& stack);
+
+void is(Stack& stack);
+
+void unInitialized(Stack& stack);
+
+void isNot(Stack& stack);
+
+void aten_format(Stack& stack);
+
+void size(Stack& stack);
+
+void device(Stack& stack);
+
+void dtype(Stack& stack);
+
+void toPrimDType(Stack& stack);
+
+void dim(Stack& stack);
+
+void _not(Stack& stack);
+
+void boolTensor(Stack& stack);
+
+void toList(Stack& stack);
+
+void numToTensorScalar(Stack& stack);
+
+void isCuda(Stack& stack);
+
+void numToTensorBool(Stack& stack);
+
+} // namespace jit
+} // namespace torch
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
+#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::Bool.Tensor(Tensor a) -> bool"),
- [](Stack& stack) {
- at::Tensor a;
- pop(stack, a);
- push(stack, a.is_nonzero());
- },
+ boolTensor,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::Bool.int(int a) -> bool"),
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::format(str self, ...) -> str"),
- [](Stack& stack) {
- size_t num_inputs = pop(stack).toInt();
- format(stack, num_inputs);
- },
+ aten_format,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::einsum.sublist(Tensor a, ...) -> Tensor"),
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"),
- [](Stack& stack) {
- at::Scalar s;
- pop(stack, s);
- push(stack, at::scalar_to_tensor(s));
- },
+ numToTensorScalar,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::RaiseException(str msg) -> ()"),
- [](Stack& stack) { throw JITException(pop(stack).toStringRef()); },
+ raiseException,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::Size(int[] sizes) -> int[]"),
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::size(Tensor self) -> int[]"),
- [](Stack& stack) {
- auto t = std::move(pop(stack)).toTensor();
- pack(stack, t.sizes().vec());
- },
+ size,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::EnumName(AnyEnumType enum) -> str"),
// note the compiler knows to type TupleIndex more accurately than it
// is listed here.
TORCH_SELECTIVE_SCHEMA("prim::TupleIndex(Any tup, int i) -> Any"),
- [](Stack& stack) {
- int64_t index = pop(stack).toInt();
- auto tuple = pop(stack).toTuple();
- auto norm_index = normalizeIndex(index, tuple->elements().size());
- if (norm_index < 0 ||
- norm_index > static_cast<int64_t>(tuple->elements().size())) {
- throw std::out_of_range("Tuple list index out of range");
- }
- stack.emplace_back(tuple->elements()[norm_index]);
- },
+ tupleIndex,
aliasAnalysisSpecialCase()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::ne.int_list(int[] a, int[] b) -> bool"),
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::device(Tensor a) -> Device"),
- [](Stack& stack) { push(stack, pop(stack).toTensor().device()); },
+ device,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::dtype(Tensor a) -> int"),
- [](Stack& stack) {
- at::Tensor a;
- pop(stack, a);
- push(stack, static_cast<int64_t>(a.scalar_type()));
- },
+ dtype,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::__not__(bool self) -> bool"),
- [](Stack& stack) { push(stack, !pop(stack).toBool()); },
+ _not,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::__is__(t1 self, t2 obj) -> bool"),
- [](Stack& stack) {
- IValue self, obj;
- pop(stack, self, obj);
- push(stack, self.is(obj));
- },
+ is,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::__isnot__(t1 self, t2 obj) -> bool"),
- [](Stack& stack) {
- IValue self, obj;
- pop(stack, self, obj);
- push(stack, !self.is(obj));
- },
+ isNot,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::element_size(Tensor self) -> int"),
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::dim(Tensor self) -> int"),
- [](Stack& stack) {
- at::Tensor arg = pop(stack).toTensor();
- push(stack, arg.dim());
- },
+ dim,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::get_device(Tensor self) -> int"),
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"),
- [](Stack& stack) { push(stack, IValue::uninitialized()); },
+ unInitialized,
aliasAnalysisSpecialCase()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::Print(...) -> ()"),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"),
- [](Stack& stack) {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- bool non_blocking;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- bool copy;
- pop(stack, non_blocking, copy);
- c10::optional<at::ScalarType> scalarType =
- pop(stack).toOptional<at::ScalarType>();
- c10::optional<c10::Device> device = c10::nullopt;
- at::Tensor self = pop(stack).toTensor();
- push(
- stack, to_dispatch(self, device, scalarType, non_blocking, copy));
- },
+ toPrimDType,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::is_cuda(Tensor a) -> bool"),
- [](Stack& stack) {
- at::Tensor a;
- pop(stack, a);
- push(stack, a.is_cuda());
- },
+ isCuda,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::is_xpu(Tensor a) -> bool"),
// depends on the type hint and input. The implementation of this
// operator below is intended to be as close to the Python
// implementation in torch/csrc/utils/tensor_list.cpp as possible.
- [](const Node* /*node*/) -> Operation {
- return [](Stack& stack) {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int elem_ty_val;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int dim_val;
- at::Tensor t;
-
- pop(stack, elem_ty_val);
- pop(stack, dim_val);
- pop(stack, t);
-
- // If the Tensor is not on the CPU, transfer it.
- if (!t.device().is_cpu()) {
- t = t.cpu();
- }
-
- // Rebuild the output type using elem_ty_val and dim_val. Start
- // with the element type corresponding to elem_ty_val.
- TypePtr out_ty;
- if (elem_ty_val == 0) {
- out_ty = IntType::get();
- } else if (elem_ty_val == 1) {
- out_ty = FloatType::get();
- } else if (elem_ty_val == 2) {
- out_ty = BoolType::get();
- } else if (elem_ty_val == 3) {
- out_ty = ComplexType::get();
- } else {
- TORCH_CHECK(
- false,
- "Unsupported element type for tolist; only int, float, complex and bool are supported");
- }
-
- // Check that type of the Tensor matches that of the annotation.
- // Make an exception for the case in which the annotated type is
- // float/complex and the Tensor data type is also float/complex;
- // the elements will be casted to double/c10::complex<double>
- // later.
- TORCH_CHECK(
- (out_ty == FloatType::get() && t.is_floating_point()) ||
- (out_ty == ComplexType::get() && t.is_complex()) ||
- tryScalarTypeFromJitType(out_ty) == t.scalar_type(),
- "Output annotation element type and runtime tensor element type must match for tolist()");
-
- // Check that the dimension of the Tensor matches that of the
- // annotation.
- TORCH_CHECK(
- dim_val == t.dim(),
- "Output annotation list dimension and runtime tensor dimension must match for tolist()");
-
- // Wrap out_ty in a ListType dim times.
- for (const auto i : c10::irange(dim_val)) {
- (void)i; // Suppress unused variable warning
- out_ty = ListType::create(out_ty);
- }
-
- int64_t dim = t.dim();
- auto sizes = t.sizes();
- auto strides = t.strides();
- size_t element_size = t.element_size();
- char* data = static_cast<char*>(t.data_ptr());
- auto result = tensorToListRecursive(
- data,
- 0,
- dim,
- out_ty,
- t.scalar_type(),
- sizes,
- strides,
- element_size);
- push(stack, std::move(result));
- };
- },
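+ // The Node is unused: toList (promoted_prim_ops.cpp) recovers the annotated
+ // element type and dimension from values pushed on the stack.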
+ [](const Node* /*node*/) -> Operation { return toList; },
aliasAnalysisSpecialCase()));
return v;
})());
// because all _to_tensor conversions have to have the same operator name
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.bool(bool a) -> Tensor"),
- [](Stack& stack) {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- bool b;
- pop(stack, b);
- push(stack, at::scalar_to_tensor(b));
- },
+ numToTensorBool,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::device(str a) -> Device"),