From ed47b85d3bd9c33ed723d2ae5309bebb4d619ac6 Mon Sep 17 00:00:00 2001
From: Natalia Gimelshein
Date: Fri, 22 Mar 2019 13:48:59 -0700
Subject: [PATCH] Allow fusion of float function arguments (#18087)

Summary:
Allows scalar float arguments to be fused, so that functions like
`def fn(x, p: float)` can be fused. Fixes #9940 and #11186. Only float
(not integer) arguments are fused, to keep assembling the arguments for
the fusion launch simple. CPU fusion is disabled in CI, so that path is
not exercised there, but I tested it locally.
cc t-vi, apaszke
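For example, the following now compiles to a single fused kernel (an
illustrative snippet, not part of this patch's test; it assumes a CUDA
build, since CPU fusion is disabled by default):

```python
import torch

def fn(x, p):
    # type: (Tensor, float) -> Tensor
    return p * (x * x + x)

scripted = torch.jit.script(fn)
x = torch.randn(4, 4, device='cuda')
scripted(x, 3.0)  # run once so the graph executor can optimize and fuse
# The optimized graph should now show a single prim::FusionGroup that
# takes the scalar p as an input alongside the tensor x.
print(scripted.graph_for(x, 3.0))
```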
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18087

Differential Revision: D14581206

Pulled By: wanchaol

fbshipit-source-id: ccb0cf79b1751706f9b2cdf1715115eae5a39fb6
---
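Note: only float scalars are admitted into fusion groups; a non-constant
`int` argument still keeps the ops that consume it out of the fused
kernel, since graph_fuser.cpp only accepts non-constant scalar inputs of
FloatType. An illustrative contrast (hypothetical snippet, same CUDA
assumptions as the example above):

```python
import torch

def fn_int(x, n):
    # type: (Tensor, int) -> Tensor
    return n * (x * x + x)

scripted = torch.jit.script(fn_int)
x = torch.randn(4, 4, device='cuda')
scripted(x, 3)
# The multiplication by n stays outside any prim::FusionGroup,
# unlike the float version above.
print(scripted.graph_for(x, 3))
```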
 test/test_jit.py                      | 17 ++++++++
 torch/csrc/jit/fuser/codegen.cpp      | 75 ++++++++++++++++++++++++-----------
 torch/csrc/jit/fuser/codegen.h        |  2 +-
 torch/csrc/jit/fuser/compiler.cpp     |  5 ++-
 torch/csrc/jit/fuser/executor.cpp     | 18 +++++++--
 torch/csrc/jit/passes/graph_fuser.cpp |  8 ++--
 6 files changed, 93 insertions(+), 32 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index fad5307..20673f6 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -12214,6 +12214,23 @@ class TestFuser(JitTestCase):
             self.assertEqual(f(x), scripted(x))
             self.assertAllFused(scripted.graph_for(x))
 
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_scalar_arg_cuda(self):
+        def fn_test_scalar_arg(x, p):
+            # type: (Tensor, float) -> Tensor
+            return p * (x * x + x)
+
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        p = 3
+        scripted = torch.jit.script(fn_test_scalar_arg, (x, p))
+        self.assertEqual(fn_test_scalar_arg(x, p), scripted(x, p))
+        self.assertAllFused(scripted.graph_for(x, p))
+        x.requires_grad_(True)
+        out = scripted(x, p)
+        self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes"))
+
     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
     @enable_cpu_fuser
     def test_fuser_deduplication(self):
diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp
index 274d41a..fd5beef 100644
--- a/torch/csrc/jit/fuser/codegen.cpp
+++ b/torch/csrc/jit/fuser/codegen.cpp
@@ -300,7 +300,7 @@ static void emitIndexingFor(
 std::string generateKernel(
     const std::string& name,
     const Graph& graph,
-    const std::vector<std::pair<const Value*, const TensorDesc>>& inputs,
+    const std::vector<std::pair<const Value*, const c10::optional<TensorDesc>>>& inputs,
     const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
     const bool use_cuda) {
   TemplateEnv env;
@@ -316,29 +316,54 @@ std::string generateKernel(
 
   // Lambda for writing arguments
   auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
-    std::string tensor =
-        "t" +
+    env.d(
+        "formal_index",
+        formals.size() +
+            1); // + 1 because the first argument is the linearIndex
+    std::string tensor =
+        "t" +
+        std::to_string(
+            formals.size()); // can't be unique() because Param may be an output
+    const auto nDim = desc.nDim();
+    emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
+    env.s("tensor", tensor);
+    env.d("nDim", nDim);
+    env.s("scalar_type", scalarTypeName(desc.scalar_type));
+    formals.push_back(
+        format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
+    argument_loads.push_back(format(
+        "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
+        env));
+  };
+
+  auto emitScalarFormal = [&](const Value* n){
+    env.d(
+        "formal_index",
+        formals.size() +
+            1); // + 1 because the first argument is the linearIndex
+    std::string scalar =
+        "s" +
         std::to_string(
             formals.size()); // can't be unique() because Param may be an output
-    const auto nDim = desc.nDim();
-    emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
-    env.s("tensor", tensor);
     env.d(
         "formal_index",
         formals.size() +
             1); // + 1 because the first argument is the linearIndex
-    env.d("nDim", nDim);
-    env.s("scalar_type", scalarTypeName(desc.scalar_type));
-    formals.push_back(
-        format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
+    env.s("scalar", scalar);
+    env.s("scalar_type", variableType(n->type()));
+    formals.push_back(format("${scalar_type} ${scalar}", env));
     argument_loads.push_back(format(
-        "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
-        env));
+        "*static_cast<${scalar_type}*>(args[${formal_index}])", env));
   };
 
+  // Writes input parameters
   for (const auto& input : inputs) {
-    emitFormal(input.first, input.second);
+    if (input.second.has_value()){
+      emitFormal(input.first, *input.second);
+    } else {
+      emitScalarFormal(input.first);
+    }
   }
 
   // Writes output parameters
   for (const auto& output : outputs) {
@@ -358,18 +383,22 @@ std::string generateKernel(
     // Note: conversion from half is only supported for CUDA kernels.
     // The conversion immediately converts fp16 inputs to float.
     // Access for other types is common to CUDA and CPU kernels.
-    const auto is_half = (input.second.scalar_type == at::ScalarType::Half);
-    if (is_half) {
-      AT_ASSERT(use_cuda);
-      env.s(
-          "access",
-          format("__half2float(t${formal}.data[t${formal}_offset])", env));
-      has_half_tensor = true;
+    if (input.second.has_value()) {
+      const auto is_half = input.second.has_value() && ((*input.second).scalar_type == at::ScalarType::Half);
+      if (is_half) {
+        AT_ASSERT(use_cuda);
+        env.s(
+            "access",
+            format("__half2float(t${formal}.data[t${formal}_offset])", env));
+        has_half_tensor = true;
+      } else {
+        env.s("access", format("t${formal}.data[t${formal}_offset]", env));
+      }
+      env.s("lhs_type", calcScalarTypeName(input.second.value().scalar_type));
     } else {
-      env.s("access", format("t${formal}.data[t${formal}_offset]", env));
+      env.s("access", format("s${formal}", env));
+      env.s("lhs_type", variableType(input.first->type()));
     }
-    env.s("lhs_type", calcScalarTypeName(input.second.scalar_type));
-
     body << format("${lhs_type} ${node} = ${access};\n", env);
   }
 
diff --git a/torch/csrc/jit/fuser/codegen.h b/torch/csrc/jit/fuser/codegen.h
index 1135cfc..52db7bb 100644
--- a/torch/csrc/jit/fuser/codegen.h
+++ b/torch/csrc/jit/fuser/codegen.h
@@ -20,7 +20,7 @@ namespace fuser {
 TORCH_API std::string generateKernel(
     const std::string& name,
     const Graph& graph,
-    const std::vector<std::pair<const Value*, const TensorDesc>>& inputs,
+    const std::vector<std::pair<const Value*, const c10::optional<TensorDesc>>>& inputs,
     const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
     const bool use_cuda);
diff --git a/torch/csrc/jit/fuser/compiler.cpp b/torch/csrc/jit/fuser/compiler.cpp
index 0e1fdf3..ebfa45e 100644
--- a/torch/csrc/jit/fuser/compiler.cpp
+++ b/torch/csrc/jit/fuser/compiler.cpp
@@ -280,10 +280,13 @@ std::shared_ptr<FusedKernel> compileKernel(
   // Creates chunk and flattened input descriptions
   std::vector<PartitionDesc> chunk_desc;
-  std::vector<std::pair<const Value*, const TensorDesc>> flat_inputs;
+  std::vector<std::pair<const Value*, const c10::optional<TensorDesc>>> flat_inputs;
   {
     size_t input_index = 0;
     for (const auto& p : graph->inputs()) {
+      if (p->type()->isSubtypeOf(FloatType::get())) {
+        flat_inputs.emplace_back(p, c10::nullopt);
+      }
       if (!p->type()->isSubtypeOf(TensorType::get())) {
         continue;
       }
diff --git a/torch/csrc/jit/fuser/executor.cpp b/torch/csrc/jit/fuser/executor.cpp
index c3c29af..2003366 100644
--- a/torch/csrc/jit/fuser/executor.cpp
+++ b/torch/csrc/jit/fuser/executor.cpp
@@ -133,7 +133,7 @@ static bool expandArgs(
 static bool shouldExpandArgs(
     const KernelSpec& spec,
     std::vector<at::Tensor>& args,
     std::vector<int64_t>& map_size) {
   return expandArgs(spec, args, map_size, /*dry_run=*/true);
 }
 
@@ -191,6 +191,7 @@ void launchFusion(
     const FusedKernel& fusion,
     const at::Device device,
     const at::ArrayRef<at::Tensor>& inputs,
+    const at::ArrayRef<IValue>& all_inputs,
     std::vector<at::Tensor>& outputs) {
   // Fails if fusion and given inputs disagree
   AT_ASSERT(inputs.size() == fusion.inputDesc().size());
@@ -222,6 +223,13 @@ void launchFusion(
     numel = computeNumel(map_size);
   }
 
+  // compute number of scalar inputs and convert them to float
+  std::vector<float> scalar_inputs;
+  scalar_inputs.reserve(all_inputs.size());
+  for (auto const &input: all_inputs){
+    if (input.isDouble()) scalar_inputs.push_back(input.to<float>());
+  }
+
   // Computes the storage needed to store TensorInfo structs for inputs and
   // outputs.
   size_t uncompressedDim = fusion.inputDesc().at(0).contiguity.size();
@@ -234,7 +242,7 @@ void launchFusion(
   // A vector of arguments to the kernel (numel, *input_desc_s, *output_desc_s)
   std::vector<void*> arguments;
-  arguments.reserve(3 + flat_inputs_size + flat_outputs_size);
+  arguments.reserve(3 + scalar_inputs.size() + flat_inputs_size + flat_outputs_size);
   arguments.push_back(&numel);
 
   auto addTensorInfoRaw = [&](const TensorDesc& desc,
@@ -274,6 +282,10 @@ void launchFusion(
       }
     }
   }
+  // Adds scalar arguments
+  for (float &s: scalar_inputs){
+    arguments.push_back(&s);
+  }
 
   // Adds (flattened) output arguments
   outputs.reserve(fusion.outputDesc().size());
@@ -363,7 +375,7 @@ bool runFusion(const int64_t key, Stack& stack) {
 
   // Launches fusion
   std::vector<at::Tensor> raw_outputs;
-  launchFusion(*(*maybe_kernel), device, inputs, raw_outputs);
+  launchFusion(*(*maybe_kernel), device, inputs, all_inputs, raw_outputs);
 
   auto outputs =
       fmap(spec.outputMapAndSizes(), [&](const OutputMapAndSize& omap) {
         if (omap.needsSumToSize()) {
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index f3945ed..5923582 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -105,9 +105,8 @@ bool isSimpleMap(Node* node) {
   if (!simple_mappable.find(node)) {
     return false;
   }
-  // Check that all non-tensor inputs are constant
   for (Value* input : node->inputs()) {
-    if (input->type()->isSubtypeOf(TensorType::get())) {
+    if (input->type()->isSubtypeOf(TensorType::get()) || input->type()->isSubtypeOf(FloatType::get())) {
       continue;
     }
     if (input->node()->kind() != prim::Constant) {
@@ -384,8 +383,9 @@ struct GraphFuser {
         group->insertInput(tensor_insert_idx, input);
         tensor_insert_idx++;
       } else if (
-          n->kind() == aten::_grad_sum_to_size &&
-          input->type()->isSubtypeOf(ListType::ofInts())) {
+          (input->type()->isSubtypeOf(FloatType::get()) && input->node()->kind() != prim::Constant) ||
+          (n->kind() == aten::_grad_sum_to_size &&
+          input->type()->isSubtypeOf(ListType::ofInts()))) {
         auto in_group = subgraph.addInput();
         in_group->setType(input->type());
         inputs_map[input] = in_group;
--
2.7.4