From 3c003aa6aee742caa48d5d0ac9499a58f38fd07b Mon Sep 17 00:00:00 2001 From: Pavithran Ramachandran Date: Fri, 17 Sep 2021 10:22:41 -0700 Subject: [PATCH] [PyTorchEdge] promote prim ops by using ops table for mobile runtime (#64816) MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64816 ## Context: Promoting prim ops: Certain prim ops are frequent than others (like tupleIndex, raiseException, ...). These ops are frequent that they are chosen to be promoted as first class instructions. To promote it requires multiple steps and support from TS team as it changes how the bytecode is serialized and deserialized. So to prevent multiple bytecode version bumps and provided stability while these changes happen, an iterim iterative process is proposed which uses a table to lookup for "promoted" op's function. This allows us to rapidly update the ops list and test on production model without having to change the bytecode. In case of failure, we can quickly revert this change. ## Observation The ops are chosen based on the notebook N1135657 which examines the top frequent ops. ## Fix An iterim solution of having a static table, which when given a prim op name returns a function to be applied on the stack. This helps us check in `function.cpp` to get the "promoted" op. As a fall back, the "promoted" op still resides in `register_prim_ops.cpp` so that the function of prim op is never missed. ghstack-source-id: 138261338 Test Plan: ``` [pavithran@67109.od ~/fbsource/fbcode (eddab7da6)]$ buck test caffe2/test/cpp/jit:jit -- BackendTest.TestComposite Building: finished in 5.4 sec (100%) 7284/7284 jobs, 0/7284 updated Total time: 5.8 sec More details at https://www.internalfb.com/intern/buck/build/480191aa-a1ba-42ca-99e9-ee4bf2b06d65 Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details. Running with tpx session id: 867382eb-327f-43d7-a45c-875b7f484b15 Trace available for this run at /tmp/tpx-20210914-100224.283682/trace.log Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/844425134506115 ✓ ListingSuccess: caffe2/test/cpp/jit:jit - main (12.159) ✓ Pass: caffe2/test/cpp/jit:jit - BackendTest.TestCompositeWithSetStates (0.797) ✓ Pass: caffe2/test/cpp/jit:jit - BackendTest.TestComposite (0.779) Summary Pass: 2 ListingSuccess: 1 If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users Finished test run: https://www.internalfb.com/intern/testinfra/testrun/844425134506115 ``` {F663491347} Reviewed By: iseeyuan Differential Revision: D30819926 fbshipit-source-id: 4cbe05d5761bdc9d62ef08e18172dcf64cb49526 --- tools/build_variables.bzl | 2 + torch/csrc/jit/mobile/function.cpp | 130 +++++++++--------- torch/csrc/jit/mobile/prim_ops_registery.cpp | 34 +++++ torch/csrc/jit/mobile/prim_ops_registery.h | 32 +++++ torch/csrc/jit/mobile/promoted_prim_ops.cpp | 190 +++++++++++++++++++++++++++ torch/csrc/jit/mobile/promoted_prim_ops.h | 43 ++++++ torch/csrc/jit/runtime/register_prim_ops.cpp | 167 +++-------------------- 7 files changed, 388 insertions(+), 210 deletions(-) create mode 100644 torch/csrc/jit/mobile/prim_ops_registery.cpp create mode 100644 torch/csrc/jit/mobile/prim_ops_registery.h create mode 100644 torch/csrc/jit/mobile/promoted_prim_ops.cpp create mode 100644 torch/csrc/jit/mobile/promoted_prim_ops.h diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 6b11812..8f0282e 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -109,6 +109,8 @@ core_sources_common = [ "torch/csrc/jit/runtime/slice_indices_adjust.cpp", "torch/csrc/jit/runtime/register_ops_utils.cpp", "torch/csrc/jit/runtime/vararg_functions.cpp", + "torch/csrc/jit/mobile/promoted_prim_ops.cpp", + "torch/csrc/jit/mobile/prim_ops_registery.cpp", "torch/csrc/jit/serialization/import_read.cpp", "torch/csrc/jit/serialization/unpickler.cpp", ] diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index 67e298f..e53d5c1 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -38,79 +39,86 @@ bool Function::append_operator( are removed */ // Keep the original opname in code_ code_->op_names_.emplace_back(name, overload_name); - auto opname = code_->op_names_.back(); + const auto& opname = code_->op_names_.back(); - const auto& opname_c10 = opname; std::function fn; - auto jit_op = findOperatorFor(opname); const std::vector* pArgs = nullptr; - if (jit_op) { - fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); }; - pArgs = &jit_op->schema().arguments(); + bool promoted_op = mobile::hasPrimOpsFn(name); + if (promoted_op) { + fn = mobile::getPrimOpsFn(name); } else { - auto op = c10::Dispatcher::singleton().findSchema(opname_c10); - if (op.has_value()) { - fn = [op](Stack& stack) { op->callBoxed(&stack); }; - if (op->hasSchema()) { - pArgs = &op->schema().arguments(); + std::shared_ptr jit_op = findOperatorFor(opname); + if (jit_op) { + fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); }; + pArgs = &jit_op->schema().arguments(); + } else { + auto op = c10::Dispatcher::singleton().findSchema(opname); + if (op.has_value()) { + fn = [op](Stack& stack) { op->callBoxed(&stack); }; + if (op->hasSchema()) { + pArgs = &op->schema().arguments(); + } else { + TORCH_CHECK(false, "arguments are missing for operator ", opname); + } } else { - TORCH_CHECK(false, "arguments are missing for operator ", opname); + return false; } - } else { - return false; } } - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs); - const auto& args = *pArgs; - if (model_version == 0x3LL && - opname == c10::OperatorName("aten::_convolution", "")) { - // Since byte-code versions 0x4L, convolution has an additional - // default-value argument (allow_tf32=True, see - // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles - // backward compatibility with models of byte-code version <= 0x3L, where - // this bool argument does not yet exist. - fn = [fn](Stack& stack) { - stack.push_back(true); - fn(stack); - }; - } else { - // num_specified_args >= 0 indicates number of arguments are available - // from model. We can use it to handle backward compatibility. - if (num_specified_args && - num_specified_args.value() < static_cast(args.size())) { - fn = [fn, num_specified_args, args](Stack& stack) { - std::vector out_args; - // The following logic pops and temporarily stores all out arguments - // from the stack (which can be 0 or more, and always appended to the - // schema), in order to push the necessary default values. Finally, the - // out arguments are pushed back into the stack. - for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) { - out_args.push_back(stack.back()); - stack.pop_back(); - } - size_t start_index = num_specified_args.value() - out_args.size(); - TORCH_CHECK( - start_index >= 0, - "The number of output arguments is: ", - out_args.size(), - ", which is more then the number of specified arguments: ", - num_specified_args.value()); - for (size_t i = start_index; i < (args.size() - out_args.size()); ++i) { - TORCH_CHECK( - args[i].default_value().has_value(), - "Error happened at preparing for default values for the argument. The ", - i, - "th argument ", - args[i].name(), - " does not have a specified value or default value. "); - - stack.push_back(args[i].default_value()); - } - stack.insert(stack.end(), out_args.rbegin(), out_args.rend()); + if (!promoted_op) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs); + const auto& args = *pArgs; + if (model_version == 0x3LL && + opname == c10::OperatorName("aten::_convolution", "")) { + // Since byte-code versions 0x4L, convolution has an additional + // default-value argument (allow_tf32=True, see + // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles + // backward compatibility with models of byte-code version <= 0x3L, where + // this bool argument does not yet exist. + fn = [fn](Stack& stack) { + stack.push_back(true); fn(stack); }; + } else { + // num_specified_args >= 0 indicates number of arguments are available + // from model. We can use it to handle backward compatibility. + if (num_specified_args && + num_specified_args.value() < static_cast(args.size())) { + fn = [fn, num_specified_args, args](Stack& stack) { + std::vector out_args; + // The following logic pops and temporarily stores all out arguments + // from the stack (which can be 0 or more, and always appended to the + // schema), in order to push the necessary default values. Finally, + // the out arguments are pushed back into the stack. + for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) { + out_args.push_back(stack.back()); + stack.pop_back(); + } + size_t start_index = num_specified_args.value() - out_args.size(); + TORCH_CHECK( + start_index >= 0, + "The number of output arguments is: ", + out_args.size(), + ", which is more then the number of specified arguments: ", + num_specified_args.value()); + for (size_t i = start_index; i < (args.size() - out_args.size()); + ++i) { + TORCH_CHECK( + args[i].default_value().has_value(), + "Error happened at preparing for default values for the argument. The ", + i, + "th argument ", + args[i].name(), + " does not have a specified value or default value. "); + + stack.push_back(args[i].default_value()); + } + stack.insert(stack.end(), out_args.rbegin(), out_args.rend()); + fn(stack); + }; + } } } code_->operators_.emplace_back(fn); diff --git a/torch/csrc/jit/mobile/prim_ops_registery.cpp b/torch/csrc/jit/mobile/prim_ops_registery.cpp new file mode 100644 index 0000000..8349431 --- /dev/null +++ b/torch/csrc/jit/mobile/prim_ops_registery.cpp @@ -0,0 +1,34 @@ +#include + +namespace torch { +namespace jit { +namespace mobile { + +std::unordered_map>& primOpsFnTable() { + static std::unordered_map> + prim_ops_fn; + return prim_ops_fn; +} + +void registerPrimOpsFunction( + const std::string& name, + const std::function& fn) { + primOpsFnTable()[name] = fn; +} + +bool hasPrimOpsFn(const std::string& name) { + return primOpsFnTable().count(name); +} + +std::function& getPrimOpsFn(const std::string& name) { + TORCH_CHECK( + hasPrimOpsFn(name), + "Prim Ops Function for ", + name, + " is not promoted yet."); + return primOpsFnTable()[name]; +} + +} // namespace mobile +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/mobile/prim_ops_registery.h b/torch/csrc/jit/mobile/prim_ops_registery.h new file mode 100644 index 0000000..f14848c --- /dev/null +++ b/torch/csrc/jit/mobile/prim_ops_registery.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { +namespace mobile { + +using Stack = std::vector; + +void registerPrimOpsFunction( + const std::string& name, + const std::function& fn); + +bool hasPrimOpsFn(const std::string& name); + +std::function& getPrimOpsFn(const std::string& name); + +class prim_op_fn_register { + public: + prim_op_fn_register( + const std::string& name, + const std::function& fn) { + registerPrimOpsFunction(name, fn); + } +}; + +} // namespace mobile +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/mobile/promoted_prim_ops.cpp b/torch/csrc/jit/mobile/promoted_prim_ops.cpp new file mode 100644 index 0000000..c96f090 --- /dev/null +++ b/torch/csrc/jit/mobile/promoted_prim_ops.cpp @@ -0,0 +1,190 @@ +#include + +namespace torch { +namespace jit { +void tupleIndex(Stack& stack) { + int64_t index = pop(stack).toInt(); + auto tuple = pop(stack).toTuple(); + auto norm_index = normalizeIndex(index, tuple->elements().size()); + if (norm_index < 0 || + norm_index > static_cast(tuple->elements().size())) { + throw std::out_of_range("Tuple list index out of range"); + } + stack.emplace_back(tuple->elements()[norm_index]); +} + +void raiseException(Stack& stack) { + throw JITException(pop(stack).toStringRef()); +} + +void is(Stack& stack) { + IValue self, obj; + pop(stack, self, obj); + push(stack, self.is(obj)); +} + +void unInitialized(Stack& stack) { + push(stack, IValue::uninitialized()); +} + +void isNot(Stack& stack) { + IValue self, obj; + pop(stack, self, obj); + push(stack, !self.is(obj)); +} + +void aten_format(Stack& stack) { + size_t num_inputs = pop(stack).toInt(); + format(stack, num_inputs); +} + +void size(Stack& stack) { + auto t = std::move(pop(stack)).toTensor(); + pack(stack, t.sizes().vec()); +} + +void device(Stack& stack) { + push(stack, pop(stack).toTensor().device()); +} + +void dtype(Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, static_cast(a.scalar_type())); +} + +void toPrimDType(Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool non_blocking; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool copy; + pop(stack, non_blocking, copy); + c10::optional scalarType = + pop(stack).toOptional(); + c10::optional device = c10::nullopt; + at::Tensor self = pop(stack).toTensor(); + push(stack, to_dispatch(self, device, scalarType, non_blocking, copy)); +} + +void dim(Stack& stack) { + at::Tensor arg = pop(stack).toTensor(); + push(stack, arg.dim()); +} + +void _not(Stack& stack) { + push(stack, !pop(stack).toBool()); +} + +void boolTensor(Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_nonzero()); +} + +void toList(Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int elem_ty_val; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int dim_val; + at::Tensor t; + + pop(stack, elem_ty_val); + pop(stack, dim_val); + pop(stack, t); + + // If the Tensor is not on the CPU, transfer it. + if (!t.device().is_cpu()) { + t = t.cpu(); + } + + // Rebuild the output type using elem_ty_val and dim_val. Start + // with the element type corresponding to elem_ty_val. + TypePtr out_ty; + if (elem_ty_val == 0) { + out_ty = IntType::get(); + } else if (elem_ty_val == 1) { + out_ty = FloatType::get(); + } else if (elem_ty_val == 2) { + out_ty = BoolType::get(); + } else if (elem_ty_val == 3) { + out_ty = ComplexType::get(); + } else { + TORCH_CHECK( + false, + "Unsupported element type for tolist; only int, float, complex and bool are supported"); + } + + // Check that type of the Tensor matches that of the annotation. + // Make an exception for the case in which the annotated type is + // float/complex and the Tensor data type is also float/complex; + // the elements will be casted to double/c10::complex + // later. + TORCH_CHECK( + (out_ty == FloatType::get() && t.is_floating_point()) || + (out_ty == ComplexType::get() && t.is_complex()) || + tryScalarTypeFromJitType(out_ty) == t.scalar_type(), + "Output annotation element type and runtime tensor element type must match for tolist()"); + + // Check that the dimension of the Tensor matches that of the + // annotation. + TORCH_CHECK( + dim_val == t.dim(), + "Output annotation list dimension and runtime tensor dimension must match for tolist()"); + + // Wrap out_ty in a ListType dim times. + for (const auto i : c10::irange(dim_val)) { + (void)i; // Suppress unused variable warning + out_ty = ListType::create(out_ty); + } + + int64_t dim = t.dim(); + auto sizes = t.sizes(); + auto strides = t.strides(); + size_t element_size = t.element_size(); + char* data = static_cast(t.data_ptr()); + auto result = tensorToListRecursive( + data, 0, dim, out_ty, t.scalar_type(), sizes, strides, element_size); + push(stack, std::move(result)); +} + +void numToTensorScalar(Stack& stack) { + at::Scalar s; + pop(stack, s); + push(stack, at::scalar_to_tensor(s)); +} + +void isCuda(Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, a.is_cuda()); +} + +void numToTensorBool(Stack& stack) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool b; + pop(stack, b); + push(stack, at::scalar_to_tensor(b)); +} + +static const std::array op_reg = { + mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex), + mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor), + mobile::prim_op_fn_register("aten::format", aten_format), + mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar), + mobile::prim_op_fn_register("prim::RaiseException", raiseException), + mobile::prim_op_fn_register("prim::device", device), + mobile::prim_op_fn_register("prim::dtype", dtype), + mobile::prim_op_fn_register("aten::__not__", _not), + mobile::prim_op_fn_register("aten::__is__", is), + mobile::prim_op_fn_register("aten::__isnot__", isNot), + mobile::prim_op_fn_register("aten::dim", dim), + mobile::prim_op_fn_register("prim::Uninitialized", unInitialized), + mobile::prim_op_fn_register("aten::to.prim_dtype", toPrimDType), + mobile::prim_op_fn_register("prim::is_cuda", isCuda) + // TODO: (@pavithran) size is overloaded with int[] and Tensor + // so this throws error expecting int not Tensor + // mobile::prim_op_fn_register("aten::size", size) +}; + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/mobile/promoted_prim_ops.h b/torch/csrc/jit/mobile/promoted_prim_ops.h new file mode 100644 index 0000000..0e46b62 --- /dev/null +++ b/torch/csrc/jit/mobile/promoted_prim_ops.h @@ -0,0 +1,43 @@ +#pragma once +#include +#include + +namespace torch { +namespace jit { + +void tupleIndex(Stack& stack); + +void raiseException(Stack& stack); + +void is(Stack& stack); + +void unInitialized(Stack& stack); + +void isNot(Stack& stack); + +void aten_format(Stack& stack); + +void size(Stack& stack); + +void device(Stack& stack); + +void dtype(Stack& stack); + +void toPrimDType(Stack& stack); + +void dim(Stack& stack); + +void _not(Stack& stack); + +void boolTensor(Stack& stack); + +void toList(Stack& stack); + +void numToTensorScalar(Stack& stack); + +void isCuda(Stack& stack); + +void numToTensorBool(Stack& stack); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index 9164471..05bd5b7 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -201,11 +202,7 @@ static const OperatorGeneratorArgs opGenArgs[] = { aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Bool.Tensor(Tensor a) -> bool"), - [](Stack& stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_nonzero()); - }, + boolTensor, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Bool.int(int a) -> bool"), @@ -363,10 +360,7 @@ static const OperatorGeneratorArgs opGenArgs[] = { aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::format(str self, ...) -> str"), - [](Stack& stack) { - size_t num_inputs = pop(stack).toInt(); - format(stack, num_inputs); - }, + [](Stack& stack) { aten_format(stack); }, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::einsum.sublist(Tensor a, ...) -> Tensor"), @@ -377,15 +371,11 @@ static const OperatorGeneratorArgs opGenArgs[] = { aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"), - [](Stack& stack) { - at::Scalar s; - pop(stack, s); - push(stack, at::scalar_to_tensor(s)); - }, + numToTensorScalar, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::RaiseException(str msg) -> ()"), - [](Stack& stack) { throw JITException(pop(stack).toStringRef()); }, + raiseException, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Size(int[] sizes) -> int[]"), @@ -393,10 +383,7 @@ static const OperatorGeneratorArgs opGenArgs[] = { aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::size(Tensor self) -> int[]"), - [](Stack& stack) { - auto t = std::move(pop(stack)).toTensor(); - pack(stack, t.sizes().vec()); - }, + size, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::EnumName(AnyEnumType enum) -> str"), @@ -431,16 +418,7 @@ static const OperatorGeneratorArgs opGenArgs[] = { // note the compiler knows to type TupleIndex more accurately than it // is listed here. TORCH_SELECTIVE_SCHEMA("prim::TupleIndex(Any tup, int i) -> Any"), - [](Stack& stack) { - int64_t index = pop(stack).toInt(); - auto tuple = pop(stack).toTuple(); - auto norm_index = normalizeIndex(index, tuple->elements().size()); - if (norm_index < 0 || - norm_index > static_cast(tuple->elements().size())) { - throw std::out_of_range("Tuple list index out of range"); - } - stack.emplace_back(tuple->elements()[norm_index]); - }, + tupleIndex, aliasAnalysisSpecialCase()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::ne.int_list(int[] a, int[] b) -> bool"), @@ -453,35 +431,23 @@ static const OperatorGeneratorArgs opGenArgs[] = { aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::device(Tensor a) -> Device"), - [](Stack& stack) { push(stack, pop(stack).toTensor().device()); }, + device, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::dtype(Tensor a) -> int"), - [](Stack& stack) { - at::Tensor a; - pop(stack, a); - push(stack, static_cast(a.scalar_type())); - }, + dtype, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::__not__(bool self) -> bool"), - [](Stack& stack) { push(stack, !pop(stack).toBool()); }, + _not, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::__is__(t1 self, t2 obj) -> bool"), - [](Stack& stack) { - IValue self, obj; - pop(stack, self, obj); - push(stack, self.is(obj)); - }, + is, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::__isnot__(t1 self, t2 obj) -> bool"), - [](Stack& stack) { - IValue self, obj; - pop(stack, self, obj); - push(stack, !self.is(obj)); - }, + isNot, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::element_size(Tensor self) -> int"), @@ -499,10 +465,7 @@ static const OperatorGeneratorArgs opGenArgs[] = { aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::dim(Tensor self) -> int"), - [](Stack& stack) { - at::Tensor arg = pop(stack).toTensor(); - push(stack, arg.dim()); - }, + dim, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::get_device(Tensor self) -> int"), @@ -655,7 +618,7 @@ static const OperatorGeneratorArgs opGenArgs[] = { aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"), - [](Stack& stack) { push(stack, IValue::uninitialized()); }, + unInitialized, aliasAnalysisSpecialCase()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::Print(...) -> ()"), @@ -1044,27 +1007,11 @@ static const OperatorGeneratorArgs opGenArgs[] = { OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA( "aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), - [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool non_blocking; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool copy; - pop(stack, non_blocking, copy); - c10::optional scalarType = - pop(stack).toOptional(); - c10::optional device = c10::nullopt; - at::Tensor self = pop(stack).toTensor(); - push( - stack, to_dispatch(self, device, scalarType, non_blocking, copy)); - }, + toPrimDType, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::is_cuda(Tensor a) -> bool"), - [](Stack& stack) { - at::Tensor a; - pop(stack, a); - push(stack, a.is_cuda()); - }, + isCuda, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::is_xpu(Tensor a) -> bool"), @@ -1182,80 +1129,7 @@ RegisterOperators reg(([]() { // depends on the type hint and input. The implementation of this // operator below is intended to be as close to the Python // implementation in torch/csrc/utils/tensor_list.cpp as possible. - [](const Node* /*node*/) -> Operation { - return [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int elem_ty_val; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int dim_val; - at::Tensor t; - - pop(stack, elem_ty_val); - pop(stack, dim_val); - pop(stack, t); - - // If the Tensor is not on the CPU, transfer it. - if (!t.device().is_cpu()) { - t = t.cpu(); - } - - // Rebuild the output type using elem_ty_val and dim_val. Start - // with the element type corresponding to elem_ty_val. - TypePtr out_ty; - if (elem_ty_val == 0) { - out_ty = IntType::get(); - } else if (elem_ty_val == 1) { - out_ty = FloatType::get(); - } else if (elem_ty_val == 2) { - out_ty = BoolType::get(); - } else if (elem_ty_val == 3) { - out_ty = ComplexType::get(); - } else { - TORCH_CHECK( - false, - "Unsupported element type for tolist; only int, float, complex and bool are supported"); - } - - // Check that type of the Tensor matches that of the annotation. - // Make an exception for the case in which the annotated type is - // float/complex and the Tensor data type is also float/complex; - // the elements will be casted to double/c10::complex - // later. - TORCH_CHECK( - (out_ty == FloatType::get() && t.is_floating_point()) || - (out_ty == ComplexType::get() && t.is_complex()) || - tryScalarTypeFromJitType(out_ty) == t.scalar_type(), - "Output annotation element type and runtime tensor element type must match for tolist()"); - - // Check that the dimension of the Tensor matches that of the - // annotation. - TORCH_CHECK( - dim_val == t.dim(), - "Output annotation list dimension and runtime tensor dimension must match for tolist()"); - - // Wrap out_ty in a ListType dim times. - for (const auto i : c10::irange(dim_val)) { - (void)i; // Suppress unused variable warning - out_ty = ListType::create(out_ty); - } - - int64_t dim = t.dim(); - auto sizes = t.sizes(); - auto strides = t.strides(); - size_t element_size = t.element_size(); - char* data = static_cast(t.data_ptr()); - auto result = tensorToListRecursive( - data, - 0, - dim, - out_ty, - t.scalar_type(), - sizes, - strides, - element_size); - push(stack, std::move(result)); - }; - }, + [](const Node* /*node*/) -> Operation { return toList; }, aliasAnalysisSpecialCase())); return v; })()); @@ -2136,12 +2010,7 @@ static const OperatorGeneratorArgs opGenArgs1[] = { // because all _to_tensor conversion have to have the same operator namet OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.bool(bool a) -> Tensor"), - [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool b; - pop(stack, b); - push(stack, at::scalar_to_tensor(b)); - }, + numToTensorBool, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::device(str a) -> Device"), -- 2.7.4