From 2e94054e347d25d74506c9095b56b83d923f2a60 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Thu, 28 Feb 2019 16:25:37 -0800 Subject: [PATCH] Allow dispatch based on tensor list args (#17522) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17522 Dispatch is still based on the first tensor arg, but that first "tensor arg" is now allowed to be a tensor list. That is, the first argument that is either Tensor or TensorList will be the deciding factor for dispatch. If it is a TensorList, then that TensorList must not be empty or dispatch will fail. Reviewed By: ezyang Differential Revision: D14235840 fbshipit-source-id: 266c18912d56ce77aa84306c5605c4191f3d882b --- aten/src/ATen/core/dispatch/DispatchTable.h | 58 +++++++++++++++++++---------- 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h index 1cec3ec..665a8d5 100644 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ b/aten/src/ATen/core/dispatch/DispatchTable.h @@ -119,8 +119,7 @@ class DispatchTable final { public: explicit DispatchTable(const FunctionSchema& schema) : kernels_() - , reverse_index_of_first_tensor_arg_( - schema.arguments().size() - get_index_of_first_tensor_arg_(schema)) + , dispatch_strategy_(get_dispatch_strategy_(schema)) , operator_name_(schema.name()) {} DispatchTable(DispatchTable&&) = default; @@ -159,11 +158,7 @@ class DispatchTable final { * @return Kernel function pointing to the right kernel for the given arguments */ const DispatchTableEntry& lookup(const Stack* stack) const { - TensorTypeId dispatch_key = torch::jit::peek( - *stack, - 0, - reverse_index_of_first_tensor_arg_ - ).toTensor().type_id(); + TensorTypeId dispatch_key = dispatch_strategy_.get_dispatch_key(stack); return *kernels_.lookup(dispatch_key, operator_name_); } @@ -171,11 +166,43 @@ class DispatchTable final { return kernels_.isEmpty(); } - private: - static size_t get_index_of_first_tensor_arg_(const FunctionSchema& schema) { +private: + struct DispatchStrategy final { + // this is caching the index so we don't have to parse the schema inputs + // again and again for each dispatcher lookup. + // reverse_index means this is the distance from the first tensor argument + // to argument_list.end(), i.e. from the top of the stack. + // Since it is distance to end(), this means it's 1-indexed, + // i.e. '1' is the last argument. + size_t reverse_index_of_first_tensor_arg_; + bool first_tensor_arg_is_tensor_list_; + + TensorTypeId get_dispatch_key(const Stack* stack) const { + auto first_tensor_arg = torch::jit::peek( + *stack, + 0, + reverse_index_of_first_tensor_arg_ + ); + if (first_tensor_arg_is_tensor_list_) { + auto tensor_list = first_tensor_arg.toTensorList(); + if (tensor_list->elements().size() == 0) { + throw std::runtime_error("Tried to dispatch based on an empty tensor list. When the first tensor argument of an operator is a tensor list, then it must not be empty."); + } + return tensor_list->elements()[0].type_id(); + } else { + return first_tensor_arg.toTensor().type_id(); + } + } + }; + + static DispatchStrategy get_dispatch_strategy_(const FunctionSchema& schema) { for (size_t i = 0; i < schema.arguments().size(); ++i) { - if (schema.arguments()[i].type()->isSubtypeOf(TensorType::get())) { - return i; + const auto& type = schema.arguments()[i].type(); + if (type->isSubtypeOf(TensorType::get())) { + return {schema.arguments().size() - i, false}; + } + if (type->isSubtypeOf(ListType::ofTensors())) { + return {schema.arguments().size() - i, true}; } } @@ -183,14 +210,7 @@ class DispatchTable final { } detail::ThreadsafeOperatorTable_ kernels_; - - // this is caching the index so we don't have to parse the schema inputs - // again and again for each dispatcher lookup. - // reverse_index means this is the distance from the first tensor argument - // to argument_list.end(), i.e. from the top of the stack. - // Since it is distance to end(), this means it's 1-indexed, - // i.e. '1' is the last argument. - size_t reverse_index_of_first_tensor_arg_; + DispatchStrategy dispatch_strategy_; std::string operator_name_; }; -- 2.7.4