public:
explicit DispatchTable(const FunctionSchema& schema)
: kernels_()
- , reverse_index_of_first_tensor_arg_(
- schema.arguments().size() - get_index_of_first_tensor_arg_(schema))
+ , dispatch_strategy_(get_dispatch_strategy_(schema))
, operator_name_(schema.name()) {}
DispatchTable(DispatchTable&&) = default;
* @return Kernel function pointing to the right kernel for the given arguments
*/
const DispatchTableEntry& lookup(const Stack* stack) const {
- TensorTypeId dispatch_key = torch::jit::peek(
- *stack,
- 0,
- reverse_index_of_first_tensor_arg_
- ).toTensor().type_id();
+ TensorTypeId dispatch_key = dispatch_strategy_.get_dispatch_key(stack);
return *kernels_.lookup(dispatch_key, operator_name_);
}
return kernels_.isEmpty();
}
- private:
- static size_t get_index_of_first_tensor_arg_(const FunctionSchema& schema) {
+private:
+ struct DispatchStrategy final {
+ // this is caching the index so we don't have to parse the schema inputs
+ // again and again for each dispatcher lookup.
+ // reverse_index means this is the distance from the first tensor argument
+ // to argument_list.end(), i.e. from the top of the stack.
+ // Since it is distance to end(), this means it's 1-indexed,
+ // i.e. '1' is the last argument.
+ size_t reverse_index_of_first_tensor_arg_;
+ bool first_tensor_arg_is_tensor_list_;
+
+ TensorTypeId get_dispatch_key(const Stack* stack) const {
+ auto first_tensor_arg = torch::jit::peek(
+ *stack,
+ 0,
+ reverse_index_of_first_tensor_arg_
+ );
+ if (first_tensor_arg_is_tensor_list_) {
+ auto tensor_list = first_tensor_arg.toTensorList();
+ if (tensor_list->elements().size() == 0) {
+ throw std::runtime_error("Tried to dispatch based on an empty tensor list. When the first tensor argument of an operator is a tensor list, then it must not be empty.");
+ }
+ return tensor_list->elements()[0].type_id();
+ } else {
+ return first_tensor_arg.toTensor().type_id();
+ }
+ }
+ };
+
+ static DispatchStrategy get_dispatch_strategy_(const FunctionSchema& schema) {
for (size_t i = 0; i < schema.arguments().size(); ++i) {
- if (schema.arguments()[i].type()->isSubtypeOf(TensorType::get())) {
- return i;
+ const auto& type = schema.arguments()[i].type();
+ if (type->isSubtypeOf(TensorType::get())) {
+ return {schema.arguments().size() - i, false};
+ }
+ if (type->isSubtypeOf(ListType::ofTensors())) {
+ return {schema.arguments().size() - i, true};
}
}
}
detail::ThreadsafeOperatorTable_ kernels_;
-
- // this is caching the index so we don't have to parse the schema inputs
- // again and again for each dispatcher lookup.
- // reverse_index means this is the distance from the first tensor argument
- // to argument_list.end(), i.e. from the top of the stack.
- // Since it is distance to end(), this means it's 1-indexed,
- // i.e. '1' is the last argument.
- size_t reverse_index_of_first_tensor_arg_;
+ DispatchStrategy dispatch_strategy_;
std::string operator_name_;
};