Allow dispatch based on tensor list args (#17522)
authorSebastian Messmer <messmer@fb.com>
Fri, 1 Mar 2019 00:25:37 +0000 (16:25 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 1 Mar 2019 00:32:00 +0000 (16:32 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17522

Dispatch is still based on the first tensor arg, but that first "tensor arg" is now allowed to be a tensor list.
That is, the first argument that is either Tensor or TensorList will be the deciding factor for dispatch.
If it is a TensorList, then that TensorList must not be empty or dispatch will fail.

Reviewed By: ezyang

Differential Revision: D14235840

fbshipit-source-id: 266c18912d56ce77aa84306c5605c4191f3d882b

aten/src/ATen/core/dispatch/DispatchTable.h

index 1cec3ec..665a8d5 100644 (file)
@@ -119,8 +119,7 @@ class DispatchTable final {
  public:
   explicit DispatchTable(const FunctionSchema& schema)
   : kernels_()
-  , reverse_index_of_first_tensor_arg_(
-        schema.arguments().size() - get_index_of_first_tensor_arg_(schema))
+  , dispatch_strategy_(get_dispatch_strategy_(schema))
   , operator_name_(schema.name()) {}
 
   DispatchTable(DispatchTable&&) = default;
@@ -159,11 +158,7 @@ class DispatchTable final {
    * @return Kernel function pointing to the right kernel for the given arguments
    */
    const DispatchTableEntry& lookup(const Stack* stack) const {
-     TensorTypeId dispatch_key = torch::jit::peek(
-       *stack,
-       0,
-       reverse_index_of_first_tensor_arg_
-     ).toTensor().type_id();
+     TensorTypeId dispatch_key = dispatch_strategy_.get_dispatch_key(stack);
      return *kernels_.lookup(dispatch_key, operator_name_);
    }
 
@@ -171,11 +166,43 @@ class DispatchTable final {
      return kernels_.isEmpty();
    }
 
- private:
-  static size_t get_index_of_first_tensor_arg_(const FunctionSchema& schema) {
+private:
+  struct DispatchStrategy final {
+    // this is caching the index so we don't have to parse the schema inputs
+    // again and again for each dispatcher lookup.
+    // reverse_index means this is the distance from the first tensor argument
+    // to argument_list.end(), i.e. from the top of the stack.
+    // Since it is distance to end(), this means it's 1-indexed,
+    // i.e. '1' is the last argument.
+    size_t reverse_index_of_first_tensor_arg_;
+    bool first_tensor_arg_is_tensor_list_;
+
+    TensorTypeId get_dispatch_key(const Stack* stack) const {
+      auto first_tensor_arg = torch::jit::peek(
+        *stack,
+        0,
+        reverse_index_of_first_tensor_arg_
+      );
+      if (first_tensor_arg_is_tensor_list_) {
+        auto tensor_list = first_tensor_arg.toTensorList();
+        if (tensor_list->elements().size() == 0) {
+          throw std::runtime_error("Tried to dispatch based on an empty tensor list. When the first tensor argument of an operator is a tensor list, then it must not be empty.");
+        }
+        return tensor_list->elements()[0].type_id();
+      } else {
+        return first_tensor_arg.toTensor().type_id();
+      }
+    }
+  };
+
+  static DispatchStrategy get_dispatch_strategy_(const FunctionSchema& schema) {
     for (size_t i = 0; i < schema.arguments().size(); ++i) {
-      if (schema.arguments()[i].type()->isSubtypeOf(TensorType::get())) {
-        return i;
+      const auto& type = schema.arguments()[i].type();
+      if (type->isSubtypeOf(TensorType::get())) {
+        return {schema.arguments().size() - i, false};
+      }
+      if (type->isSubtypeOf(ListType::ofTensors())) {
+        return {schema.arguments().size() - i, true};
       }
     }
 
@@ -183,14 +210,7 @@ class DispatchTable final {
   }
 
   detail::ThreadsafeOperatorTable_ kernels_;
-
-  // this is caching the index so we don't have to parse the schema inputs
-  // again and again for each dispatcher lookup.
-  // reverse_index means this is the distance from the first tensor argument
-  // to argument_list.end(), i.e. from the top of the stack.
-  // Since it is distance to end(), this means it's 1-indexed,
-  // i.e. '1' is the last argument.
-  size_t reverse_index_of_first_tensor_arg_;
+  DispatchStrategy dispatch_strategy_;
   std::string operator_name_;
 };