Allow fusion of float function arguments (#18087)
authorNatalia Gimelshein <ngimelshein@nvidia.com>
Fri, 22 Mar 2019 20:48:59 +0000 (13:48 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 20:52:33 +0000 (13:52 -0700)
Summary:
so that functions like `def fn(x, p:float)` can be fused. Fixes #9940 and #11186. Fuses only float (not integer) arguments to simplify assembling arguments for fusion launch.
CPU fusion is disabled in CI and this won't be tested, but I tested it locally.
cc t-vi, apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18087

Differential Revision: D14581206

Pulled By: wanchaol

fbshipit-source-id: ccb0cf79b1751706f9b2cdf1715115eae5a39fb6

test/test_jit.py
torch/csrc/jit/fuser/codegen.cpp
torch/csrc/jit/fuser/codegen.h
torch/csrc/jit/fuser/compiler.cpp
torch/csrc/jit/fuser/executor.cpp
torch/csrc/jit/passes/graph_fuser.cpp

index fad5307..20673f6 100644 (file)
@@ -12214,6 +12214,23 @@ class TestFuser(JitTestCase):
         self.assertEqual(f(x), scripted(x))
         self.assertAllFused(scripted.graph_for(x))
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_scalar_arg_cuda(self):
+        def fn_test_scalar_arg(x, p):
+            # type: (Tensor, float) -> Tensor
+            return p * (x * x + x)
+
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        p = 3
+        scripted = torch.jit.script(fn_test_scalar_arg, (x, p))
+        self.assertEqual(fn_test_scalar_arg(x, p), scripted(x, p))
+        self.assertAllFused(scripted.graph_for(x, p))
+        x.requires_grad_(True)
+        out = scripted(x, p)
+        self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes"))
+
     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
     @enable_cpu_fuser
     def test_fuser_deduplication(self):
index 274d41a..fd5beef 100644 (file)
@@ -300,7 +300,7 @@ static void emitIndexingFor(
 std::string generateKernel(
     const std::string& name,
     const Graph& graph,
-    const std::vector<std::pair<const Value*, const TensorDesc>>& inputs,
+    const std::vector<std::pair<const Value*, const c10::optional<TensorDesc>>>& inputs,
     const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
     const bool use_cuda) {
   TemplateEnv env;
@@ -316,29 +316,54 @@ std::string generateKernel(
 
   // Lambda for writing arguments
   auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
-    std::string tensor =
-        "t" +
+    env.d(
+        "formal_index",
+        formals.size() +
+            1); // + 1 because the first argument is the linearIndex
+      std::string tensor =
+          "t" +
+          std::to_string(
+              formals.size()); // can't be unique() because Param may be an output
+      const auto nDim = desc.nDim();
+      emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
+      env.s("tensor", tensor);
+      env.d("nDim", nDim);
+      env.s("scalar_type", scalarTypeName(desc.scalar_type));
+      formals.push_back(
+          format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
+      argument_loads.push_back(format(
+          "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
+          env));
+  };
+
+  auto emitScalarFormal = [&](const Value* n){
+    env.d(
+        "formal_index",
+        formals.size() +
+            1); // + 1 because the first argument is the linearIndex
+    std::string scalar =
+        "s" +
         std::to_string(
             formals.size()); // can't be unique() because Param may be an output
-    const auto nDim = desc.nDim();
-    emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
-    env.s("tensor", tensor);
     env.d(
         "formal_index",
         formals.size() +
             1); // + 1 because the first argument is the linearIndex
-    env.d("nDim", nDim);
-    env.s("scalar_type", scalarTypeName(desc.scalar_type));
-    formals.push_back(
-        format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
+    env.s("scalar", scalar);
+    env.s("scalar_type", variableType(n->type()));
+    formals.push_back(format("${scalar_type} ${scalar}", env));
     argument_loads.push_back(format(
-        "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
-        env));
+    "*static_cast<${scalar_type}*>(args[${formal_index}])", env));
   };
 
+
   // Writes input parameters
   for (const auto& input : inputs) {
-    emitFormal(input.first, input.second);
+    if (input.second.has_value()){
+      emitFormal(input.first, *input.second);
+    } else {
+      emitScalarFormal(input.first);
+    }
   }
 
   // Writes output parameters
@@ -358,18 +383,22 @@ std::string generateKernel(
     // Note: conversion from half is only supported for CUDA kernels.
     //  The conversion immediately converts fp16 inputs to float.
     //  Access for other types is common to CUDA and CPU kernels.
-    const auto is_half = (input.second.scalar_type == at::ScalarType::Half);
-    if (is_half) {
-      AT_ASSERT(use_cuda);
-      env.s(
-          "access",
-          format("__half2float(t${formal}.data[t${formal}_offset])", env));
-      has_half_tensor = true;
+    if (input.second.has_value()) {
+      const auto is_half = input.second.has_value() && ((*input.second).scalar_type == at::ScalarType::Half);
+      if (is_half) {
+        AT_ASSERT(use_cuda);
+        env.s(
+            "access",
+            format("__half2float(t${formal}.data[t${formal}_offset])", env));
+        has_half_tensor = true;
+      } else {
+        env.s("access", format("t${formal}.data[t${formal}_offset]", env));
+      }
+      env.s("lhs_type", calcScalarTypeName(input.second.value().scalar_type));
     } else {
-      env.s("access", format("t${formal}.data[t${formal}_offset]", env));
+      env.s("access", format("s${formal}", env));
+      env.s("lhs_type", variableType(input.first->type()));
     }
-    env.s("lhs_type", calcScalarTypeName(input.second.scalar_type));
-
     body << format("${lhs_type} ${node} = ${access};\n", env);
   }
 
index 1135cfc..52db7bb 100644 (file)
@@ -20,7 +20,7 @@ namespace fuser {
 TORCH_API std::string generateKernel(
     const std::string& name,
     const Graph& graph,
-    const std::vector<std::pair<const Value*, const TensorDesc>>& inputs,
+    const std::vector<std::pair<const Value*, const c10::optional<TensorDesc>>>& inputs,
     const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
     const bool use_cuda);
 
index 0e1fdf3..ebfa45e 100644 (file)
@@ -280,10 +280,13 @@ std::shared_ptr<FusedKernel> compileKernel(
 
   // Creates chunk and flattened input descriptions
   std::vector<PartitionDesc> chunk_desc;
-  std::vector<std::pair<const Value*, const TensorDesc>> flat_inputs;
+  std::vector<std::pair<const Value*, const c10::optional<TensorDesc>>> flat_inputs;
   {
     size_t input_index = 0;
     for (const auto& p : graph->inputs()) {
+      if (p->type()->isSubtypeOf(FloatType::get())) {
+        flat_inputs.emplace_back(p, c10::nullopt);
+      }
       if (!p->type()->isSubtypeOf(TensorType::get())) {
         continue;
       }
index c3c29af..2003366 100644 (file)
@@ -133,7 +133,7 @@ static bool expandArgs(
 static bool shouldExpandArgs(
     const KernelSpec& spec,
     std::vector<at::Tensor>& args,
-    std::vector<int64_t>& map_size) {  
+    std::vector<int64_t>& map_size) {
   return expandArgs(spec, args, map_size, /*dry_run=*/true);
 }
 
@@ -191,6 +191,7 @@ void launchFusion(
     const FusedKernel& fusion,
     const at::Device device,
     const at::ArrayRef<at::Tensor>& inputs,
+    const at::ArrayRef<IValue>& all_inputs,
     std::vector<at::Tensor>& outputs) {
   // Fails if fusion and given inputs disagree
   AT_ASSERT(inputs.size() == fusion.inputDesc().size());
@@ -222,6 +223,13 @@ void launchFusion(
     numel = computeNumel(map_size);
   }
 
+  // compute number of scalar inputs and convert them to float
+  std::vector<float> scalar_inputs;
+  scalar_inputs.reserve(all_inputs.size());
+  for (auto const &input: all_inputs){
+    if (input.isDouble()) scalar_inputs.push_back(input.to<float>());
+  }
+
   // Computes the storage needed to store TensorInfo structs for inputs and
   // outputs.
   size_t uncompressedDim = fusion.inputDesc().at(0).contiguity.size();
@@ -234,7 +242,7 @@ void launchFusion(
 
   // A vector of arguments to the kernel (numel, *input_desc_s, *output_desc_s)
   std::vector<void*> arguments;
-  arguments.reserve(3 + flat_inputs_size + flat_outputs_size);
+  arguments.reserve(3 + scalar_inputs.size() + flat_inputs_size + flat_outputs_size);
   arguments.push_back(&numel);
 
   auto addTensorInfoRaw = [&](const TensorDesc& desc,
@@ -274,6 +282,10 @@ void launchFusion(
       }
     }
   }
+  // Adds scalar arguments
+  for (float &s: scalar_inputs){
+    arguments.push_back(&s);
+  }
 
   // Adds (flattened) output arguments
   outputs.reserve(fusion.outputDesc().size());
@@ -363,7 +375,7 @@ bool runFusion(const int64_t key, Stack& stack) {
 
   // Launches fusion
   std::vector<at::Tensor> raw_outputs;
-  launchFusion(*(*maybe_kernel), device, inputs, raw_outputs);
+  launchFusion(*(*maybe_kernel), device, inputs, all_inputs, raw_outputs);
 
   auto outputs = fmap(spec.outputMapAndSizes(), [&](const OutputMapAndSize& omap) {
     if (omap.needsSumToSize()) {
index f3945ed..5923582 100644 (file)
@@ -105,9 +105,8 @@ bool isSimpleMap(Node* node) {
   if (!simple_mappable.find(node)) {
     return false;
   }
-  // Check that all non-tensor inputs are constant
   for (Value* input : node->inputs()) {
-    if (input->type()->isSubtypeOf(TensorType::get())) {
+    if (input->type()->isSubtypeOf(TensorType::get()) || input->type()->isSubtypeOf(FloatType::get())) {
       continue;
     }
     if (input->node()->kind() != prim::Constant) {
@@ -384,8 +383,9 @@ struct GraphFuser {
           group->insertInput(tensor_insert_idx, input);
           tensor_insert_idx++;
         } else if (
-            n->kind() == aten::_grad_sum_to_size &&
-            input->type()->isSubtypeOf(ListType::ofInts())) {
+          (input->type()->isSubtypeOf(FloatType::get()) && input->node()->kind() != prim::Constant) ||
+          (n->kind() == aten::_grad_sum_to_size &&
+            input->type()->isSubtypeOf(ListType::ofInts()))) {
           auto in_group = subgraph.addInput();
           in_group->setType(input->type());
           inputs_map[input] = in_group;