From: Michael Suo Date: Wed, 26 Dec 2018 14:52:25 +0000 (-0800) Subject: clang format world (#15524) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2089 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f636dc927687cc50a527c9185f9d95ed65e32996;p=platform%2Fupstream%2Fpytorch.git clang format world (#15524) Summary: The PR clang-formats everything in `torch/csrc/jit/` and adds it to the pre-commit hook. Here is a list of non-mechanical changes: - I went over each file and fixed up whenever I could tell that clang-format was clobbering comment formatting. - Made the macros in register_prim_ops a little more clang-format friendly by omitting trailing commas - Refactored autodiff.cpp to use a helper class with explicit state rather than a bunch of capturing lambdas - Small improvements to the precommit hook clang-format Pull Request resolved: https://github.com/pytorch/pytorch/pull/15524 Differential Revision: D13547989 Pulled By: suo fbshipit-source-id: 3ff1541bb06433ccfe6de6e33f29227a2b5bb493 --- diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 42e33d8..a0c71fd 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -16,17 +16,18 @@ } catch (const std::exception& e) { \ ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \ } -#define ASSERT_ANY_THROW(statement) \ - bool threw = false; \ - try { \ - (void)statement; \ - } catch (const std::exception& e) { \ - threw = true; \ - } \ - ASSERT_TRUE(threw); \ +#define ASSERT_ANY_THROW(statement) \ + bool threw = false; \ + try { \ + (void)statement; \ + } catch (const std::exception& e) { \ + threw = true; \ + } \ + ASSERT_TRUE(threw); #endif // defined(USE_GTEST) +#include "torch/csrc/autograd/generated/variable_factories.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/jit/argument_spec.h" #include "torch/csrc/jit/assertions.h" @@ -51,11 +52,10 @@ #include "torch/csrc/jit/passes/requires_grad_analysis.h" #include "torch/csrc/jit/passes/shape_analysis.h" #include "torch/csrc/jit/passes/utils/subgraph_utils.h" -#include "torch/csrc/jit/symbolic_variable.h" #include "torch/csrc/jit/symbolic_script.h" +#include "torch/csrc/jit/symbolic_variable.h" #include "torch/csrc/jit/tracer.h" #include "torch/csrc/utils/hash.h" -#include "torch/csrc/autograd/generated/variable_factories.h" #include "torch/csrc/autograd/engine.h" #include "torch/csrc/autograd/variable.h" @@ -440,7 +440,9 @@ std::shared_ptr build_lstm() { return r; } -std::vector run(InterpreterState & interp, const std::vector & inputs) { +std::vector run( + InterpreterState& interp, + const std::vector& inputs) { std::vector stack(inputs.begin(), inputs.end()); interp.run(stack); return fmap(stack, [](const IValue& i) { return i.toTensor(); }); @@ -469,8 +471,7 @@ std::pair runGradient( df_interpreter.run(df_stack); // Outputs of f needs to be sliced - f_stack.erase( - f_stack.begin() + grad_spec.f_real_outputs, f_stack.end()); + f_stack.erase(f_stack.begin() + grad_spec.f_real_outputs, f_stack.end()); return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack)); } @@ -515,13 +516,14 @@ void testTHNNConv() { // make inputs at::Tensor input = torch::randn(input_size); - at::Tensor weight = torch::randn({out_channels, input_size[1], kernel_size[0], kernel_size[1]}); + at::Tensor weight = torch::randn( + {out_channels, input_size[1], kernel_size[0], kernel_size[1]}); at::Tensor bias = torch::randn({out_channels}); // run forward eagerly at::Tensor output, finput, fgradinput; - std::tie(output, finput, fgradinput) = 
at::thnn_conv2d_forward(input, weight, kernel_size, - bias, stride, padding); + std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward( + input, weight, kernel_size, bias, stride, padding); // make grad_outputs at::Tensor grad_output = torch::randn_like(output); @@ -530,9 +532,16 @@ void testTHNNConv() { // run backward eagerly at::Tensor grad_input, grad_weight, grad_bias; - std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(grad_output, input, weight, - kernel_size, stride, padding, - finput, fgradinput, {true, true, true}); + std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward( + grad_output, + input, + weight, + kernel_size, + stride, + padding, + finput, + fgradinput, + {true, true, true}); // make JIT graph auto graph = std::make_shared(); @@ -544,7 +553,9 @@ void testTHNNConv() { auto weightg = graph->addInput("weight"); auto biasg = graph->addInput("bias"); - Value* conv = graph->insert(aten::thnn_conv2d_forward, {inputg, weightg, ksz_val, biasg, kst_val, pad_val}); + Value* conv = graph->insert( + aten::thnn_conv2d_forward, + {inputg, weightg, ksz_val, biasg, kst_val, pad_val}); auto outputs = conv->node()->outputs(); for (auto output : outputs) { graph->registerOutput(output); @@ -572,7 +583,7 @@ void testTHNNConv() { // Get outputs from the interpreter tensor_list tensors_out, tensor_grads_out; std::tie(tensors_out, tensor_grads_out) = - runGradient(grad_spec, tensors_in, tensor_grads_in); + runGradient(grad_spec, tensors_in, tensor_grads_in); // prepare expected structs tensor_list expected_tensors_out, expected_tensor_grads_out; @@ -589,7 +600,9 @@ void testTHNNConv() { } void testATenNativeBatchNorm() { - // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor + // running_mean, Tensor running_var, bool training, float momentum, float eps) + // -> (Tensor, Tensor, Tensor) std::vector input_size = {4, 3, 15, 17}; // B x C x H x W bool training = true; float momentum = 0.9; @@ -610,7 +623,15 @@ void testATenNativeBatchNorm() { // run forward eagerly at::Tensor output, savemean, saveinvstd; - std::tie(output, savemean, saveinvstd) = at::native_batch_norm(input, weight, bias, running_mean_eager, running_var_eager, training, momentum, eps); + std::tie(output, savemean, saveinvstd) = at::native_batch_norm( + input, + weight, + bias, + running_mean_eager, + running_var_eager, + training, + momentum, + eps); // make grad_outputs at::Tensor grad_output = torch::randn_like(output); @@ -619,10 +640,21 @@ void testATenNativeBatchNorm() { // run backward eagerly at::Tensor grad_input, grad_weight, grad_bias; - // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(grad_output, input, weight, - running_mean_eager, running_var_eager, - savemean, saveinvstd, training, eps, {true, true, true}); + // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor + // weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor + // save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, + // Tensor, Tensor) + std::tie(grad_input, grad_weight, grad_bias) = 
at::native_batch_norm_backward( + grad_output, + input, + weight, + running_mean_eager, + running_var_eager, + savemean, + saveinvstd, + training, + eps, + {true, true, true}); // make JIT graph auto graph = std::make_shared(); @@ -636,7 +668,16 @@ void testATenNativeBatchNorm() { auto running_meang = graph->addInput("running_mean"); auto running_varg = graph->addInput("running_var"); - Value* bn = graph->insert(aten::native_batch_norm, {inputg, weightg, biasg, running_meang, running_varg, training_val, momentum_val, eps_val}); + Value* bn = graph->insert( + aten::native_batch_norm, + {inputg, + weightg, + biasg, + running_meang, + running_varg, + training_val, + momentum_val, + eps_val}); auto outputs = bn->node()->outputs(); for (auto output : outputs) { graph->registerOutput(output); @@ -666,7 +707,7 @@ void testATenNativeBatchNorm() { // Get outputs from the interpreter tensor_list tensors_out, tensor_grads_out; std::tie(tensors_out, tensor_grads_out) = - runGradient(grad_spec, tensors_in, tensor_grads_in); + runGradient(grad_spec, tensors_in, tensor_grads_in); // prepare expected structs tensor_list expected_tensors_out, expected_tensor_grads_out; @@ -770,15 +811,20 @@ void testADFormulas() { [](const VL& v) -> VL { return {v[0].tanh()}; }}, {"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }}, {"view", - unary_pointwise_2d, - [](const VL& v) -> VL { return {v[0].view({3, 2})}; }}, + unary_pointwise_2d, + [](const VL& v) -> VL { + return {v[0].view({3, 2})}; + }}, {"expand", - {{2, 1}}, - [](const VL& v) -> VL { return {v[0].expand({2, 3})}; }}, + {{2, 1}}, + [](const VL& v) -> VL { + return {v[0].expand({2, 3})}; + }}, {"mm", {{10, 12}, {12, 15}}, [](const VL& v) -> VL { return {v[0].mm(v[1])}; }}, - // TODO: enable once we'll be able to capture lists across forward-backward + // TODO: enable once we'll be able to capture lists across + // forward-backward //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return // fmap(v[0].chunk(4, 1)); }}, //{"chunk", {{10, 12, 15}}, [](const VL& v) -> VL { return @@ -860,8 +906,10 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) { graph->registerOutput(d.value()); graph->registerOutput(e.value()); - auto a_var = autograd::make_variable(at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true); - auto b_var = autograd::make_variable(at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false); + auto a_var = autograd::make_variable( + at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true); + auto b_var = autograd::make_variable( + at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false); setInputTypes(*graph, ArgumentSpec(true, {a_var, b_var}, 2)); PropagateInputShapes(graph); PropagateRequiresGrad(graph); @@ -899,9 +947,10 @@ void testRegisterFusionCachesKernel(std::ostream& out = std::cout) { auto getFusionGroup = [](const std::shared_ptr& graph) { const auto& nodes = graph->nodes(); - auto maybe_fusion_group = std::find_if( - nodes.begin(), nodes.end(), - [](const Node* node) { return node->kind() == prim::FusionGroup; }); + auto maybe_fusion_group = + std::find_if(nodes.begin(), nodes.end(), [](const Node* node) { + return node->kind() == prim::FusionGroup; + }); JIT_ASSERTM( maybe_fusion_group != nodes.end(), "testRegisterFusionCachesKernel: could not create FusionGroup"); @@ -1215,7 +1264,8 @@ void testCustomOperators() { ASSERT_EQ(op->schema().arguments()[0].name(), "_0"); ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); 
ASSERT_EQ(op->schema().arguments()[1].name(), "_1"); - ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType); + ASSERT_EQ( + op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType); ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::DynamicType); @@ -1243,7 +1293,8 @@ void testCustomOperators() { ASSERT_EQ(op->schema().arguments()[0].name(), "a"); ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); ASSERT_EQ(op->schema().arguments()[1].name(), "b"); - ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType); + ASSERT_EQ( + op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType); ASSERT_EQ(op->schema().returns().size(), 1); ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::DynamicType); @@ -1405,29 +1456,33 @@ void testSchemaParser() { // nested arrays auto s = parseSchema("at::what(int[][4] foo) -> ()"); ASSERT_TRUE(s.arguments().at(0).N() == 4); - ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments().at(0) - .type()->expect() + ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments() + .at(0) + .type() + ->expect() ->getElementType() ->expect() ->getElementType())); auto s2 = parseSchema("at::what(int[][] foo) -> ()"); - ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments().at(0) - .type()->expect() - ->getElementType() - ->expect() - ->getElementType())); + ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments() + .at(0) + .type() + ->expect() + ->getElementType() + ->expect() + ->getElementType())); // named returns parseSchema("at::what(Tensor! i_will_be_written_to) -> ()"); - auto s3 = parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)"); + auto s3 = + parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)"); ASSERT_TRUE(s3.returns().at(0).name() == "the_return"); ASSERT_TRUE(s3.returns().at(1).name() == "the_return2"); // futures auto s4 = parseSchema("at::what(Future(int) foo) -> ()"); - ASSERT_TRUE(IntType::get()->isSubtypeOf(s4.arguments().at(0) - .type()->expect() - ->getElementType())); + ASSERT_TRUE(IntType::get()->isSubtypeOf( + s4.arguments().at(0).type()->expect()->getElementType())); // test tensor with annotated alias sets parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))"); @@ -1530,9 +1585,9 @@ void testTopologicalIndex() { } } - std::unique_ptr> newDynamicDAG() { - return std::unique_ptr>(new detail::DynamicDAG()); + return std::unique_ptr>( + new detail::DynamicDAG()); } void testNewVertex() { @@ -1781,20 +1836,20 @@ struct TopoMoveTestFixture { bool moveBeforeTopologicallyValid( const std::string& toInsert, const std::string& insertPoint) { - std::function func = [this](Node* toInsert, - Node* insertPoint) { - return toInsert->moveBeforeTopologicallyValid(insertPoint, *aliasDb); - }; + std::function func = + [this](Node* toInsert, Node* insertPoint) { + return toInsert->moveBeforeTopologicallyValid(insertPoint, *aliasDb); + }; return moveWithChecks(toInsert, insertPoint, func); } bool moveAfterTopologicallyValid( const std::string& toInsert, const std::string& insertPoint) { - std::function func = [this](Node* toInsert, - Node* insertPoint) { - return toInsert->moveAfterTopologicallyValid(insertPoint, *aliasDb); - }; + std::function func = + [this](Node* toInsert, Node* insertPoint) { + return toInsert->moveAfterTopologicallyValid(insertPoint, *aliasDb); + }; return moveWithChecks(toInsert, insertPoint, func); } diff --git a/tools/clang_format.py b/tools/clang_format.py index ca44baf..454bd34 100644 --- a/tools/clang_format.py +++ 
b/tools/clang_format.py @@ -10,18 +10,19 @@ Running tools/clang_format.py manually with no arguments should replicate the pr Only files that are in CLANG_FORMAT_WHITELIST are checked. """ import subprocess -import glob -import itertools import os import argparse +import fnmatch import difflib import sys +import re -# Whitelist of files to check. Takes a glob syntax. Does not support -# recursive globs ("**") because I am lazy and don't want to make that -# work with Python 2. -CLANG_FORMAT_WHITELIST = ["torch/csrc/jit/passes/alias_analysis*"] +# Whitelist of directories to check. All files that in that directory +# (recursively) will be checked. +CLANG_FORMAT_WHITELIST = ["torch/csrc/jit/", "test/cpp/jit/"] + +CPP_FILE_REGEX = re.compile(".*\\.(h|cpp|cc|c|hpp)$") def parse_args(): @@ -43,21 +44,28 @@ def parse_args(): "Otherwise, just print the changes and exit" ), ) + parser.add_argument( + "--check-all", + action="store_true", + default=False, + help="If true, check all whitelisted files instead of just working copy changes", + ) parser.add_argument("--verbose", "-v", action="store_true", default=False) return parser.parse_args() def get_whitelisted_files(): """ - Parse CLANG_FORMAT_WHITELIST and resolve all globs. - Returns the set of all whitelisted filenames. + Parse CLANG_FORMAT_WHITELIST and resolve all directories. + Returns the set of whitelist cpp source files. """ - paths = [glob.glob(entry) for entry in CLANG_FORMAT_WHITELIST] - # flatten the files list - paths = itertools.chain(*paths) - # filter out directories - filenames = filter(lambda path: os.path.isfile(path), paths) - return set(filenames) + matches = [] + for dir in CLANG_FORMAT_WHITELIST: + for root, dirnames, filenames in os.walk(dir): + for filename in filenames: + if CPP_FILE_REGEX.fullmatch(filename): + matches.append(os.path.join(root, filename)) + return set(matches) def get_changed_files(rev): @@ -98,10 +106,13 @@ def get_diffs(files): def main(): args = parse_args() - changed_files = get_changed_files(args.diff) whitelisted_files = get_whitelisted_files() - files_to_check = changed_files & whitelisted_files + if args.check_all: + files_to_check = whitelisted_files + else: + changed_files = get_changed_files(args.diff) + files_to_check = changed_files & whitelisted_files if args.verbose: print("Running clang-format on whitelisted files: ") @@ -118,6 +129,11 @@ def main(): args = ["clang-format", "-i"] args.extend(name_to_diffs.keys()) subprocess.check_output(args) + + # add the changes so they will be committed + args = ["git", "add"] + args.extend(name_to_diffs.keys()) + subprocess.check_output(args) else: print("ERROR: Running clang-format created changes: ") for name, diff in name_to_diffs.items(): diff --git a/torch/csrc/jit/alias_info.h b/torch/csrc/jit/alias_info.h index e0d79fb..443a8b5 100644 --- a/torch/csrc/jit/alias_info.h +++ b/torch/csrc/jit/alias_info.h @@ -1,6 +1,7 @@ #include -namespace torch { namespace jit { +namespace torch { +namespace jit { using ::c10::AliasInfo; diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h index ec3d988..ba9c0cd 100644 --- a/torch/csrc/jit/argument_spec.h +++ b/torch/csrc/jit/argument_spec.h @@ -1,20 +1,21 @@ #pragma once -#include -#include #include -#include +#include #include #include -#include #include +#include +#include +#include -namespace torch { namespace jit { +namespace torch { +namespace jit { -// GraphExecutor creates specializations of Graphs for different dimensionalitities -// and types of inputs. 
+// GraphExecutor creates specializations of Graphs for different +// dimensionalitities and types of inputs. -inline static at::Device ConvertIntToCPUOrCUDA(int device){ +inline static at::Device ConvertIntToCPUOrCUDA(int device) { return device < 0 ? at::kCPU : at::Device(at::DeviceType::CUDA, device); } struct ArgumentInfo { @@ -30,7 +31,8 @@ struct ArgumentInfo { int device() const { return device_; } - // XXX: It is guaranteed that this will return false when called on non-tensor arguments + // XXX: It is guaranteed that this will return false when called on non-tensor + // arguments bool requires_grad() const { return requires_grad_; } @@ -46,23 +48,29 @@ struct ArgumentInfo { return TensorType::create(type(), ConvertIntToCPUOrCUDA(device()), dim()); } -private: + private: unsigned is_tensor_ : 1; unsigned defined_ : 1; unsigned requires_grad_ : 1; unsigned : 5; unsigned dim_ : 8; - int device_ : 8; // NOTE: this needs to be signed because we use -1 to represent CPU + int device_ : 8; // NOTE: this needs to be signed because we use -1 to + // represent CPU unsigned type_ : 8; }; -static_assert(std::is_pod::value, - "ArgumentInfo is to be a POD struct"); -static_assert(sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type), - "ArgumentInfo is expected to be a 32-bit struct"); +static_assert( + std::is_pod::value, + "ArgumentInfo is to be a POD struct"); +static_assert( + sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type), + "ArgumentInfo is expected to be a 32-bit struct"); struct ArgumentSpec { - ArgumentSpec(bool with_grad, at::ArrayRef inputs, size_t num_flat_inputs) { + ArgumentSpec( + bool with_grad, + at::ArrayRef inputs, + size_t num_flat_inputs) { hash_code = num_flat_inputs; args.resize(num_flat_inputs); size_t offset = 0; @@ -73,7 +81,7 @@ struct ArgumentSpec { } void addInput(const IValue& input, size_t& offset, bool with_grad) { - auto & arg = args.at(offset); + auto& arg = args.at(offset); // Initialize all fields to 0. This is convenient, because e.g. // requires_grad() can be checked even on tensors AND will make // padding bits all 0s. @@ -92,17 +100,18 @@ struct ArgumentSpec { combineHash(arg); offset++; } else if (input.isTuple()) { - for (const IValue & elem : input.toTuple()->elements()) { + for (const IValue& elem : input.toTuple()->elements()) { addInput(elem, offset, with_grad); } } else { - // NB: no need to set is_tensor to false, because we memset the struct to 0 above + // NB: no need to set is_tensor to false, because we memset the struct to + // 0 above combineHash(arg); offset++; } } - void combineHash(const ArgumentInfo &arg) { + void combineHash(const ArgumentInfo& arg) { ArgumentInfo::plain_data_type arg_data; std::memcpy(&arg_data, &arg, sizeof(ArgumentInfo)); hash_code = hash_combine(hash_code, arg_data); @@ -110,14 +119,19 @@ struct ArgumentSpec { // equality is fast: check ninputs, and then check the raw array data, // there are no size/stride indirections - bool operator==(const ArgumentSpec & spec) const { - if (args.size() != spec.args.size()) return false; - // NB: we need to break out early when there are no elements, because passing a - // nullptr to memcmp is UB. 
- if (args.size() == 0) return true; - return std::memcmp(args.data(), spec.args.data(), args.size() * sizeof(ArgumentInfo)) == 0; - } - bool operator!=(const ArgumentSpec & spec) const { + bool operator==(const ArgumentSpec& spec) const { + if (args.size() != spec.args.size()) + return false; + // NB: we need to break out early when there are no elements, because + // passing a nullptr to memcmp is UB. + if (args.size() == 0) + return true; + return std::memcmp( + args.data(), + spec.args.data(), + args.size() * sizeof(ArgumentInfo)) == 0; + } + bool operator!=(const ArgumentSpec& spec) const { return !(*this == spec); } size_t size() const { @@ -133,21 +147,25 @@ struct ArgumentSpec { // inferred for it based on this ArgumentSpec. std::vector getTypes(Graph& graph) const { size_t offset = 0; - return fmap(graph.inputs(), - [&](Value *v) { return fillType(v->type(), offset); }); + return fmap( + graph.inputs(), [&](Value* v) { return fillType(v->type(), offset); }); } -private: + private: TypePtr fillType(TypePtr original, size_t& offset) const { if (original->isSubtypeOf(DynamicType::get())) { - auto & arg = args.at(offset++); + auto& arg = args.at(offset++); if (!arg.defined()) return UndefinedTensorType::get(); - return TensorType::create(arg.type(), ConvertIntToCPUOrCUDA(arg.device()), arg.dim(), arg.requires_grad()); + return TensorType::create( + arg.type(), + ConvertIntToCPUOrCUDA(arg.device()), + arg.dim(), + arg.requires_grad()); } else if (auto tuple_type = original->cast()) { - return TupleType::create(fmap(tuple_type->elements(), [&](const TypePtr& subtype) { - return fillType(subtype, offset); - })); + return TupleType::create(fmap( + tuple_type->elements(), + [&](const TypePtr& subtype) { return fillType(subtype, offset); })); } else { offset++; return original; @@ -171,38 +189,41 @@ struct CompleteArgumentInfoPOD { unsigned defined : 1; unsigned requires_grad : 1; signed device : 14; - uint32_t total_dims; // all TensorInfoPODs are in CompleteArgumentSpec's tensor_info() array. - // total_dims is the total number of dimensions seen so far - // in all previous members of tensor_info(), including this tensor - // 2*total_dims becomes the offset into the sizes_strides list - // for the _next_ tensor in the tensor_info array - // for tensor 0, the offset is always 0 + uint32_t total_dims; // all TensorInfoPODs are in CompleteArgumentSpec's + // tensor_info() array. total_dims is the total number of + // dimensions seen so far in all previous members of + // tensor_info(), including this tensor 2*total_dims + // becomes the offset into the sizes_strides list for the + // _next_ tensor in the tensor_info array for tensor 0, + // the offset is always 0 }; -static_assert(sizeof(CompleteArgumentInfoPOD) == sizeof(int64_t), - "CompleteArgumentInfoPOD must be 64-bit struct for CompleteArgumentSpec encoding to work"); +static_assert( + sizeof(CompleteArgumentInfoPOD) == sizeof(int64_t), + "CompleteArgumentInfoPOD must be 64-bit struct for CompleteArgumentSpec encoding to work"); struct CompleteArgumentInfo; struct CompleteArgumentSpec { CompleteArgumentSpec(bool with_grad, at::ArrayRef inputs) - : hash_code(0), ninputs(inputs.size()) { + : hash_code(0), ninputs(inputs.size()) { int32_t all_dims = 0; const int32_t num_inputs = inputs.size(); for (int32_t i = 0; i < num_inputs; i++) { - if (!inputs[i].isTensor()) continue; + if (!inputs[i].isTensor()) + continue; auto tensor = inputs[i].toTensor(); all_dims += tensor.defined() ? 
tensor.ndimension() : 0; } // allocate enough room for all TensorPODs and dimensions - data.resize(ninputs + all_dims*2); + data.resize(ninputs + all_dims * 2); // and reinterpret our data array as these structs auto* pods = reinterpret_cast(data.data()); - int64_t * next_dim = sizes_strides(); + int64_t* next_dim = sizes_strides(); int32_t total_dims = 0; - for(int32_t i = 0; i < num_inputs; i++) { - auto & pod = pods[i]; + for (int32_t i = 0; i < num_inputs; i++) { + auto& pod = pods[i]; pod.is_tensor = static_cast(inputs[i].isTensor()); if (pod.is_tensor) { at::Tensor t = inputs[i].toTensor(); @@ -210,10 +231,11 @@ struct CompleteArgumentSpec { if (pod.defined) { pod.type = static_cast(t.type().scalarType()); pod.device = (!t.is_cuda()) ? -1 : t.get_device(); - pod.requires_grad = with_grad && autograd::as_variable_ref(t).requires_grad(); + pod.requires_grad = + with_grad && autograd::as_variable_ref(t).requires_grad(); total_dims += t.ndimension(); auto sizes = t.sizes(); - std::copy(sizes.begin(),sizes.end(), next_dim); + std::copy(sizes.begin(), sizes.end(), next_dim); next_dim += sizes.size(); auto strides = t.strides(); std::copy(strides.begin(), strides.end(), next_dim); @@ -226,17 +248,17 @@ struct CompleteArgumentSpec { // we precompute the hash_code to minimize the time inside of hash // table operations where we may need to hold a compiler cache lock. hash_code = hash_combine(0, ninputs); - for(auto d : data) { + for (auto d : data) { hash_code = hash_combine(hash_code, d); } } // equality is fast: check ninputs, and then check the raw array data, // there are no size/stride indirections - bool operator==(const CompleteArgumentSpec & spec) const { + bool operator==(const CompleteArgumentSpec& spec) const { return ninputs == spec.ninputs && data == spec.data; } - bool operator!=(const CompleteArgumentSpec & spec) const { + bool operator!=(const CompleteArgumentSpec& spec) const { return !(*this == spec); } friend struct CompleteArgumentInfo; @@ -248,12 +270,13 @@ struct CompleteArgumentSpec { return hash_code; } -private: + private: ArrayRef tensor_info() const { return ArrayRef( - reinterpret_cast(data.data()), ninputs); + reinterpret_cast(data.data()), ninputs); } - // the start of the sizes_strides information, which comes after the CompleteArgumentInfoPOD list. + // the start of the sizes_strides information, which comes after the + // CompleteArgumentInfoPOD list. const int64_t* sizes_strides() const { return data.data() + ninputs; } @@ -262,15 +285,17 @@ private: } size_t hash_code; // precomputed on construction int32_t ninputs; - // layout is ninputs of TensorPOD (each 64-bit) followed by their size and stride info - // for 3 tensors: [t0POD][t1POD][t2POD][t0 sizes][t0 strides][t1 sizes][t1 strides][t2 sizes][t2 strides] + // layout is ninputs of TensorPOD (each 64-bit) followed by their size and + // stride info for 3 tensors: + // [t0POD][t1POD][t2POD]... 
+ // [t0 sizes][t0 strides][t1 sizes][t1 strides][t2 sizes][t2 strides] std::vector data; }; // public view of compressed CompleteArgumentInfo struct CompleteArgumentInfo { - CompleteArgumentInfo(const CompleteArgumentSpec & spec, const int i) - : spec(spec), i(i) {} + CompleteArgumentInfo(const CompleteArgumentSpec& spec, const int i) + : spec(spec), i(i) {} bool isTensor() const { return pod(i).is_tensor; } @@ -288,49 +313,54 @@ struct CompleteArgumentInfo { } int ndimension() const { // See [valid range], it is always valid to ask for offset for (i + 1) - return (sizes_strides_offset(i + 1) - sizes_strides_offset(i))/2; + return (sizes_strides_offset(i + 1) - sizes_strides_offset(i)) / 2; } at::IntList sizes() const { - return at::IntList(spec.sizes_strides() + sizes_strides_offset(i), ndimension()); + return at::IntList( + spec.sizes_strides() + sizes_strides_offset(i), ndimension()); } at::IntList strides() const { int ndim = ndimension(); - return at::IntList(spec.sizes_strides() + sizes_strides_offset(i) + ndim, ndim); + return at::IntList( + spec.sizes_strides() + sizes_strides_offset(i) + ndim, ndim); } operator TypePtr() const { - if(!defined()) + if (!defined()) return DynamicType::get(); - return CompleteTensorType::create(type(), ConvertIntToCPUOrCUDA(device()), sizes(), strides()); + return CompleteTensorType::create( + type(), ConvertIntToCPUOrCUDA(device()), sizes(), strides()); } -private: + + private: // offsetinto sizes_strides() array where the sizes start for tensor j // [valid range] valid range is [0, ninputs] - // (i.e. you can ask for the offset at ninputs, which would be the offset of the next tensor if it existed) + // (i.e. you can ask for the offset at ninputs, which would be the offset of + // the next tensor if it existed) int sizes_strides_offset(int j) const { - if(j == 0) return 0; - return 2*pod(j - 1).total_dims; + if (j == 0) + return 0; + return 2 * pod(j - 1).total_dims; } - const CompleteArgumentInfoPOD & pod(int j) const { + const CompleteArgumentInfoPOD& pod(int j) const { return spec.tensor_info().at(j); } - const CompleteArgumentSpec & spec; + const CompleteArgumentSpec& spec; const int i; }; -inline std::ostream & operator<<(std::ostream & out, const ArgumentInfo & info) { - if(!info.defined()) { +inline std::ostream& operator<<(std::ostream& out, const ArgumentInfo& info) { + if (!info.defined()) { return out << ""; } - out << "Tensor(device=" << info.device() - << ", type=" << toString(info.type()) - << ", requires_grad=" << info.requires_grad() - << ", dims=" << info.dim() << ")"; + out << "Tensor(device=" << info.device() << ", type=" << toString(info.type()) + << ", requires_grad=" << info.requires_grad() << ", dims=" << info.dim() + << ")"; return out; } -inline std::ostream& operator<<(std::ostream & out, const ArgumentSpec & spec) { +inline std::ostream& operator<<(std::ostream& out, const ArgumentSpec& spec) { out << "{"; - for(size_t i = 0; i < spec.size(); ++i) { + for (size_t i = 0; i < spec.size(); ++i) { if (i > 0) out << ", "; out << spec.at(i); @@ -339,21 +369,23 @@ inline std::ostream& operator<<(std::ostream & out, const ArgumentSpec & spec) { return out; } -inline std::ostream & operator<<(std::ostream & out, const CompleteArgumentInfo & info) { - if(!info.defined()) { +inline std::ostream& operator<<( + std::ostream& out, + const CompleteArgumentInfo& info) { + if (!info.defined()) { return out << ""; } - out << "Tensor(device=" << info.device() - << ", type=" << toString(info.type()) - << ", requires_grad=" << 
info.requires_grad() - << ", sizes=" << info.sizes() - << ", strides=" << info.strides() << ")"; + out << "Tensor(device=" << info.device() << ", type=" << toString(info.type()) + << ", requires_grad=" << info.requires_grad() + << ", sizes=" << info.sizes() << ", strides=" << info.strides() << ")"; return out; } -inline std::ostream& operator<<(std::ostream & out, const CompleteArgumentSpec & spec) { +inline std::ostream& operator<<( + std::ostream& out, + const CompleteArgumentSpec& spec) { out << "{"; - for(size_t i = 0; i < spec.size(); ++i) { + for (size_t i = 0; i < spec.size(); ++i) { if (i > 0) out << ", "; out << spec.at(i); @@ -374,19 +406,20 @@ inline void setInputTypes(Graph& g, const ArgumentSpec& spec) { } } -}} +} // namespace jit +} // namespace torch namespace std { - template<> - struct hash { - size_t operator()(const torch::jit::ArgumentSpec & spec) const { - return spec.hashCode(); - } - }; - template<> - struct hash { - size_t operator()(const torch::jit::CompleteArgumentSpec & spec) const { - return spec.hashCode(); - } - }; -} +template <> +struct hash { + size_t operator()(const torch::jit::ArgumentSpec& spec) const { + return spec.hashCode(); + } +}; +template <> +struct hash { + size_t operator()(const torch::jit::CompleteArgumentSpec& spec) const { + return spec.hashCode(); + } +}; +} // namespace std diff --git a/torch/csrc/jit/attributes.h b/torch/csrc/jit/attributes.h index b48fbf3..b84da83 100644 --- a/torch/csrc/jit/attributes.h +++ b/torch/csrc/jit/attributes.h @@ -1,31 +1,29 @@ #pragma once -#include +#include +#include #include -#include #include +#include #include -#include -#include #include #include -namespace torch { namespace jit { +namespace torch { +namespace jit { constexpr int max_tensor_display_size = 10; -enum class AttributeKind { - f,fs,i,is,s,ss,t,ts,g,gs -}; -static inline const char * toString(AttributeKind kind) { - static const char* names[] = {"f","fs","i","is","s","ss","t","ts","g","gs"}; - JIT_ASSERT(size_t(kind) < sizeof(names)/sizeof(AttributeKind)); +enum class AttributeKind { f, fs, i, is, s, ss, t, ts, g, gs }; +static inline const char* toString(AttributeKind kind) { + static const char* names[] = { + "f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs"}; + JIT_ASSERT(size_t(kind) < sizeof(names) / sizeof(AttributeKind)); return names[int(kind)]; } struct AttributeValue { - AttributeValue(Symbol name) - : name(name) {} + AttributeValue(Symbol name) : name(name) {} using Ptr = std::unique_ptr; Symbol name; virtual AttributeKind kind() const = 0; @@ -33,67 +31,78 @@ struct AttributeValue { virtual ~AttributeValue() = default; }; -template +template struct ScalarAttributeValue : public AttributeValue { using ConstructorType = T; using ValueType = T; ScalarAttributeValue(Symbol name, ConstructorType value_) - : AttributeValue(name), value_(std::move(value_)) {} - ValueType & value() { + : AttributeValue(name), value_(std::move(value_)) {} + ValueType& value() { return value_; } Ptr clone() const override { return Ptr(new ScalarAttributeValue(name, value_)); } - AttributeKind kind() const override { return Kind; } -private: + AttributeKind kind() const override { + return Kind; + } + + private: ValueType value_; }; -template +template struct VectorAttributeValue : public AttributeValue { using ConstructorType = std::vector; using ValueType = std::vector; VectorAttributeValue(Symbol name, ConstructorType value_) - : AttributeValue(name), value_(std::move(value_)) {} - ValueType & value() { + : AttributeValue(name), 
value_(std::move(value_)) {} + ValueType& value() { return value_; } - AttributeKind kind() const override { return Kind; } + AttributeKind kind() const override { + return Kind; + } std::unique_ptr clone() const override { auto copy = value_; return Ptr(new VectorAttributeValue(name, std::move(copy))); } -private: + + private: ValueType value_; }; -using FloatAttr = ScalarAttributeValue; -using FloatsAttr = VectorAttributeValue; -using IntAttr = ScalarAttributeValue; -using IntsAttr = VectorAttributeValue; -using StringAttr = ScalarAttributeValue; -using StringsAttr = VectorAttributeValue; -using TensorAttr = ScalarAttributeValue; -using TensorsAttr = VectorAttributeValue; +using FloatAttr = ScalarAttributeValue; +using FloatsAttr = VectorAttributeValue; +using IntAttr = ScalarAttributeValue; +using IntsAttr = VectorAttributeValue; +using StringAttr = ScalarAttributeValue; +using StringsAttr = VectorAttributeValue; +using TensorAttr = ScalarAttributeValue; +using TensorsAttr = VectorAttributeValue; struct Graph; -using GraphAttr = ScalarAttributeValue,AttributeKind::g>; -using GraphsAttr = VectorAttributeValue,AttributeKind::gs>; +using GraphAttr = + ScalarAttributeValue, AttributeKind::g>; +using GraphsAttr = + VectorAttributeValue, AttributeKind::gs>; struct AttributeError : public std::exception { AttributeError(Symbol name, bool defined) { std::stringstream ss; - if(!defined) { - ss << "required keyword attribute '" << name.toUnqualString() << "' is undefined."; + if (!defined) { + ss << "required keyword attribute '" << name.toUnqualString() + << "' is undefined."; } else { - ss << "required keyword attribute '" << name.toUnqualString() << "' has the wrong type"; + ss << "required keyword attribute '" << name.toUnqualString() + << "' has the wrong type"; } msg = ss.str(); } - const char* what() const noexcept override { + const char* what() const noexcept override { return msg.c_str(); } -private: + + private: std::string msg; }; @@ -101,18 +110,18 @@ private: // method chaining e.g: // Node * n = g->create(kSelect)->i_(kOffset,3)->f_(kValue,3.5); // we return Derived* pointers because Nodes are normally held as pointers. -template +template struct Attributes { Attributes() = default; - void copyAttributes(const Attributes & rhs) { + void copyAttributes(const Attributes& rhs) { values_.clear(); - for(auto & i : rhs.values_) { + for (auto& i : rhs.values_) { values_.push_back(i->clone()); } } bool hasAttribute(Symbol name) const { JIT_ASSERT(name.is_attr()); - return find(name,false) != values_.end(); + return find(name, false) != values_.end(); } // We want direct string accessors, as it is nicer to use than // hasAttribute(Symbol::attr("blah")) @@ -127,14 +136,14 @@ struct Attributes { } AttributeKind kindOf(Symbol name) const { JIT_ASSERT(name.is_attr()); - return (*find(name,true))->kind(); + return (*find(name, true))->kind(); } AttributeKind kindOfS(const std::string& name) const { return kindOf(Symbol::attr(name)); } Derived* removeAttribute(Symbol name) { JIT_ASSERT(name.is_attr()); - values_.erase(find(name,true)); + values_.erase(find(name, true)); return This(); } Derived* removeAttributeS(const std::string& name) { @@ -149,35 +158,36 @@ struct Attributes { // The names are returned in order, since name actually is the index. 
std::vector attributeNames() const { std::vector names; - for(auto & a : values_) + for (auto& a : values_) names.push_back(a->name); return names; } std::vector attributeNamesS() const { std::vector names; - for(auto & a : values_) + for (auto& a : values_) names.push_back(a->name.toUnqualString()); return names; } - #define CREATE_ACCESSOR(Kind, method) \ +#define CREATE_ACCESSOR(Kind, method) \ Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \ - return set(name,std::forward(v)); \ - } \ - const Kind##Attr::ValueType& method(Symbol name) const { \ - return get(name); \ + return set( \ + name, std::forward(v)); \ + } \ + const Kind##Attr::ValueType& method(Symbol name) const { \ + return get(name); \ } - CREATE_ACCESSOR(Float,f) - CREATE_ACCESSOR(Floats,fs) - CREATE_ACCESSOR(String,s) - CREATE_ACCESSOR(Strings,ss) - CREATE_ACCESSOR(Int,i) - CREATE_ACCESSOR(Ints,is) - CREATE_ACCESSOR(Graph,g) - CREATE_ACCESSOR(Graphs,gs) + CREATE_ACCESSOR(Float, f) + CREATE_ACCESSOR(Floats, fs) + CREATE_ACCESSOR(String, s) + CREATE_ACCESSOR(Strings, ss) + CREATE_ACCESSOR(Int, i) + CREATE_ACCESSOR(Ints, is) + CREATE_ACCESSOR(Graph, g) + CREATE_ACCESSOR(Graphs, gs) - #undef CREATE_ACCESSOR +#undef CREATE_ACCESSOR // Our Graphs are not very const-correct, so we need to allow returning // non-const references too @@ -188,28 +198,29 @@ struct Attributes { // does not use CREATE_ACCESSOR because we need additional asserts Derived* t_(Symbol name, TensorAttr::ConstructorType v) { JIT_ASSERT(!v.defined() || !v.is_variable()); - return set(name,std::forward(v)); + return set(name, std::forward(v)); } const TensorAttr::ValueType& t(Symbol name) const { return get(name); } Derived* ts_(Symbol name, TensorsAttr::ConstructorType v) { - for(auto & t : v) { + for (auto& t : v) { JIT_ASSERT(!t.defined() || !t.is_variable()); } - return set(name,std::forward(v)); + return set( + name, std::forward(v)); } const TensorsAttr::ValueType& ts(Symbol name) const { return get(name); } - template - static void printPrimList(std::ostream & out, const std::vector & items) { + template + static void printPrimList(std::ostream& out, const std::vector& items) { out << "["; int i = 0; - for(auto & item : items) { - if(i++ > 0) + for (auto& item : items) { + if (i++ > 0) out << ", "; out << item; } @@ -221,7 +232,7 @@ struct Attributes { std::vector replace = {"\\n", "\\t", "\\v"}; for (size_t i = 0; i < search.size(); i++) { size_t pos = s.find(search[i]); - while(pos != std::string::npos) { + while (pos != std::string::npos) { s.replace(pos, 1, replace[i]); pos = s.find(search[i], pos + 1); } @@ -229,8 +240,8 @@ struct Attributes { return s; } - void printValue(std::ostream & out, const Symbol & name) const { - switch(kindOf(name)) { + void printValue(std::ostream& out, const Symbol& name) const { + switch (kindOf(name)) { case AttributeKind::f: out << f(name); break; @@ -247,34 +258,33 @@ struct Attributes { out << "\"" << escapeString(s(name)) << "\""; break; case AttributeKind::ss: - printPrimList(out,ss(name)); + printPrimList(out, ss(name)); break; - case AttributeKind::t: - { - at::Tensor tensor = t(name); - // 1-elem tensors are usually boxed scalars, so print them like it - if (tensor.numel() == 1) { - auto scalar_tensor = tensor.view({}).item(); - out << "{"; - if (scalar_tensor.isFloatingPoint()) { - out << scalar_tensor.toDouble(); - } else { - out << scalar_tensor.toLong(); - } - out << "}"; - } else if (tensor.numel() <= max_tensor_display_size) { - // TODO: This is awful code. 
Also it doesn't work on Windows. - std::ostringstream tensor_ss; - tensor_ss << tensor; - std::string tensor_s{tensor_ss.str()}; - // Remove newlines - std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' '); - out << tensor_s; + case AttributeKind::t: { + at::Tensor tensor = t(name); + // 1-elem tensors are usually boxed scalars, so print them like it + if (tensor.numel() == 1) { + auto scalar_tensor = tensor.view({}).item(); + out << "{"; + if (scalar_tensor.isFloatingPoint()) { + out << scalar_tensor.toDouble(); } else { - out << ""; + out << scalar_tensor.toLong(); } - break; + out << "}"; + } else if (tensor.numel() <= max_tensor_display_size) { + // TODO: This is awful code. Also it doesn't work on Windows. + std::ostringstream tensor_ss; + tensor_ss << tensor; + std::string tensor_s{tensor_ss.str()}; + // Remove newlines + std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' '); + out << tensor_s; + } else { + out << ""; } + break; + } case AttributeKind::ts: out << "[]"; break; @@ -287,29 +297,29 @@ struct Attributes { } } -private: + private: // UBSAN error: https://github.com/pytorch/pytorch/issues/9055 Derived* This() __ubsan_ignore_vptr__ { return static_cast(this); } - template + template Derived* set(Symbol name, typename T::ConstructorType v) { JIT_ASSERT(name.is_attr()); auto it = find(name, false); auto nv = AVPtr(new T(name, std::forward(v))); - if(it == values_.end()) { + if (it == values_.end()) { values_.push_back(std::move(nv)); } else { *it = std::move(nv); } return This(); } - template - typename T::ValueType & get(Symbol name) const { + template + typename T::ValueType& get(Symbol name) const { JIT_ASSERT(name.is_attr()); auto it = find(name, true); auto* child = dynamic_cast(it->get()); - if(child == nullptr) { + if (child == nullptr) { throw AttributeError(name, true); } return child->value(); @@ -322,10 +332,10 @@ private: using iterator = std::vector::iterator; iterator find(Symbol name, bool required) { JIT_ASSERT(name.is_attr()); - auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) { + auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; }); - if(required && it == values_.end()) { + if (required && it == values_.end()) { throw AttributeError(name, false); } JIT_ASSERT(!required || it != values_.end()); @@ -334,10 +344,10 @@ private: using const_iterator = std::vector::const_iterator; const_iterator find(Symbol name, bool required) const { JIT_ASSERT(name.is_attr()); - auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) { + auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; }); - if(required && it == values_.end()) { + if (required && it == values_.end()) { throw AttributeError(name, false); } JIT_ASSERT(!required || it != values_.end()); @@ -345,4 +355,5 @@ private: } }; -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index 686a0b9..5bec04b 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -1,113 +1,111 @@ #include -#include "torch/csrc/jit/passes/lower_tuples.h" +#include #include -#include #include -#include "torch/csrc/jit/symbolic_script.h" +#include #include -#include #include +#include "torch/csrc/jit/passes/lower_tuples.h" #include "torch/csrc/jit/script/compiler.h" +#include "torch/csrc/jit/symbolic_script.h" #include #include #include -namespace torch { namespace jit { +namespace torch { +namespace jit { 
using value_map = std::unordered_map; using value_set = std::unordered_set; -void wrapDim(int64_t & dim, const std::vector & sizes) { +void wrapDim(int64_t& dim, const std::vector& sizes) { if (dim < 0) { dim += sizes.size(); } } -bool isDifferentiable(Node * n) { +bool isDifferentiable(Node* n) { // TODO: scalar-tensor ops should be canonicalized static OperatorSet differentiable_ops = { - "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", - "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", - "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", - "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor", - "aten::mul(Tensor self, Tensor other) -> Tensor", - "aten::mul(Tensor self, Scalar other) -> Tensor", - "aten::div(Tensor self, Tensor other) -> Tensor", - "aten::div(Tensor self, Scalar other) -> Tensor", - "aten::max(Tensor self, Tensor other) -> Tensor", - "aten::min(Tensor self, Tensor other) -> Tensor", - "aten::sigmoid(Tensor self) -> Tensor", - "aten::tanh(Tensor self) -> Tensor", - "aten::relu(Tensor self) -> Tensor", - "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", - "aten::erf(Tensor self) -> Tensor", - "aten::erfc(Tensor self) -> Tensor", - "aten::exp(Tensor self) -> Tensor", - "aten::t(Tensor self) -> Tensor", - "aten::neg(Tensor self) -> Tensor", - "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor", - "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor", - "aten::type_as(Tensor self, Tensor other) -> Tensor", - "aten::unsqueeze(Tensor self, int dim) -> Tensor", - "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor", - "aten::mm(Tensor self, Tensor mat2) -> Tensor", - "aten::lt(Tensor self, Tensor other) -> Tensor", - "aten::le(Tensor self, Tensor other) -> Tensor", - "aten::gt(Tensor self, Tensor other) -> Tensor", - "aten::ge(Tensor self, Tensor other) -> Tensor", - "aten::eq(Tensor self, Tensor other) -> Tensor", - "aten::ne(Tensor self, Tensor other) -> Tensor", - "aten::lt(Tensor self, Scalar other) -> Tensor", - "aten::le(Tensor self, Scalar other) -> Tensor", - "aten::gt(Tensor self, Scalar other) -> Tensor", - "aten::ge(Tensor self, Scalar other) -> Tensor", - "aten::eq(Tensor self, Scalar other) -> Tensor", - "aten::ne(Tensor self, Scalar other) -> Tensor", - "aten::abs(Tensor self) -> Tensor", - "aten::acos(Tensor self) -> Tensor", - "aten::asin(Tensor self) -> Tensor", - "aten::atan(Tensor self) -> Tensor", - "aten::ceil(Tensor self) -> Tensor", - "aten::cos(Tensor self) -> Tensor", - "aten::cosh(Tensor self) -> Tensor", - "aten::exp(Tensor self) -> Tensor", - "aten::expm1(Tensor self) -> Tensor", - "aten::floor(Tensor self) -> Tensor", - "aten::fmod(Tensor self, Scalar other) -> Tensor", - "aten::frac(Tensor self) -> Tensor", - "aten::log(Tensor self) -> Tensor", - "aten::log10(Tensor self) -> Tensor", - "aten::log1p(Tensor self) -> Tensor", - "aten::log2(Tensor self) -> Tensor", - "aten::reciprocal(Tensor self) -> Tensor", - "aten::remainder(Tensor self, Scalar other) -> Tensor", - "aten::round(Tensor self) -> Tensor", - "aten::rsqrt(Tensor self) -> Tensor", - "aten::sin(Tensor self) -> Tensor", - "aten::sinh(Tensor self) -> Tensor", - "aten::tan(Tensor self) -> Tensor", - "aten::trunc(Tensor self) -> Tensor", - "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)", - "aten::log_softmax(Tensor self, int dim) -> Tensor", - "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, 
bool count_include_pad) -> Tensor", - "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)", - "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)", - "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", + "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", + "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor", + "aten::mul(Tensor self, Tensor other) -> Tensor", + "aten::mul(Tensor self, Scalar other) -> Tensor", + "aten::div(Tensor self, Tensor other) -> Tensor", + "aten::div(Tensor self, Scalar other) -> Tensor", + "aten::max(Tensor self, Tensor other) -> Tensor", + "aten::min(Tensor self, Tensor other) -> Tensor", + "aten::sigmoid(Tensor self) -> Tensor", + "aten::tanh(Tensor self) -> Tensor", + "aten::relu(Tensor self) -> Tensor", + "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", + "aten::erf(Tensor self) -> Tensor", + "aten::erfc(Tensor self) -> Tensor", + "aten::exp(Tensor self) -> Tensor", + "aten::t(Tensor self) -> Tensor", + "aten::neg(Tensor self) -> Tensor", + "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor", + "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor", + "aten::type_as(Tensor self, Tensor other) -> Tensor", + "aten::unsqueeze(Tensor self, int dim) -> Tensor", + "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor", + "aten::mm(Tensor self, Tensor mat2) -> Tensor", + "aten::lt(Tensor self, Tensor other) -> Tensor", + "aten::le(Tensor self, Tensor other) -> Tensor", + "aten::gt(Tensor self, Tensor other) -> Tensor", + "aten::ge(Tensor self, Tensor other) -> Tensor", + "aten::eq(Tensor self, Tensor other) -> Tensor", + "aten::ne(Tensor self, Tensor other) -> Tensor", + "aten::lt(Tensor self, Scalar other) -> Tensor", + "aten::le(Tensor self, Scalar other) -> Tensor", + "aten::gt(Tensor self, Scalar other) -> Tensor", + "aten::ge(Tensor self, Scalar other) -> Tensor", + "aten::eq(Tensor self, Scalar other) -> Tensor", + "aten::ne(Tensor self, Scalar other) -> Tensor", + "aten::abs(Tensor self) -> Tensor", + "aten::acos(Tensor self) -> Tensor", + "aten::asin(Tensor self) -> Tensor", + "aten::atan(Tensor self) -> Tensor", + "aten::ceil(Tensor self) -> Tensor", + "aten::cos(Tensor self) -> Tensor", + "aten::cosh(Tensor self) -> Tensor", + "aten::exp(Tensor self) -> Tensor", + "aten::expm1(Tensor self) -> Tensor", + "aten::floor(Tensor self) -> Tensor", + "aten::fmod(Tensor self, Scalar other) -> Tensor", + "aten::frac(Tensor self) -> Tensor", + "aten::log(Tensor self) -> Tensor", + "aten::log10(Tensor self) -> Tensor", + "aten::log1p(Tensor self) -> Tensor", + "aten::log2(Tensor self) -> Tensor", + "aten::reciprocal(Tensor self) -> Tensor", + "aten::remainder(Tensor self, Scalar other) -> Tensor", + "aten::round(Tensor self) -> Tensor", + "aten::rsqrt(Tensor self) -> Tensor", + "aten::sin(Tensor self) -> Tensor", + "aten::sinh(Tensor self) -> Tensor", + "aten::tan(Tensor self) -> Tensor", + "aten::trunc(Tensor self) -> Tensor", + "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)", + 
"aten::log_softmax(Tensor self, int dim) -> Tensor", + "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor", + "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)", + "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)", + "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", }; // TODO: add support for the following fusible operators. - // They're a little tricky to implement; max/min require mutability for best perf - // "aten::atan2(Tensor self) -> Tensor", - // "aten::max(Tensor self) -> Tensor", - // "aten::min(Tensor self) -> Tensor" - - if (n->kind() == prim::Constant || - n->kind() == prim::Undefined || - n->kind() == prim::AutogradAdd || - n->kind() == prim::ConstantChunk || + // They're a little tricky to implement; max/min require mutability for best + // perf "aten::atan2(Tensor self) -> Tensor", "aten::max(Tensor self) -> + // Tensor", "aten::min(Tensor self) -> Tensor" + + if (n->kind() == prim::Constant || n->kind() == prim::Undefined || + n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk || n->kind() == prim::None) return true; if (differentiable_ops.find(n)) @@ -118,15 +116,18 @@ bool isDifferentiable(Node * n) { return true; } - if (n->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) { - return n->get>(attr::size) && n->is_constant(attr::implicit) && - n->namedInput(attr::self)->type()->cast(); + if (n->matches( + "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) { + return n->get>(attr::size) && + n->is_constant(attr::implicit) && + n->namedInput(attr::self)->type()->cast(); } if (n->matches("aten::view(Tensor self, int[] size) -> Tensor")) { return n->get>(attr::size) && - n->namedInput(attr::self)->type()->cast(); + n->namedInput(attr::self)->type()->cast(); } - if (n->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) { + if (n->matches( + "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) { // TODO(asuhan): support weight return n->namedInput(attr::weight)->node()->kind() == prim::Undefined; } @@ -144,10 +145,11 @@ bool isDifferentiable(Node * n) { return false; } - -bool isDifferentiable(Graph & g) { - return std::all_of(g.nodes().begin(), g.nodes().end(), - static_cast(isDifferentiable)); +bool isDifferentiable(Graph& g) { + return std::all_of( + g.nodes().begin(), + g.nodes().end(), + static_cast(isDifferentiable)); } // NB: Write gradient using torchscript @@ -160,22 +162,25 @@ bool isDifferentiable(Graph & g) { // // Here ctx is a tuple that carries all input/intermediate results needed in // backward from forward pass. -// This python code is compiled into a GradientPair which includes a forward graph -// and a backward graph. Forward graph will be used to replace the node in grad_desc.f, -// and backward graph will be used to construct GradOf(node) in reverse_block. 
-// Grad_values(a.k.a gradOutputs) propagated through node->owningGraph() in -// **reversed** order, thus GradientPair.forward ahould be inserted **after** -// the node being replaced, so that we don't traverse the graph infinite times. +// +// This python code is compiled into a GradientPair which includes a forward +// graph and a backward graph. Forward graph will be used to replace the node in +// grad_desc.f, and backward graph will be used to construct GradOf(node) in +// reverse_block. Grad_values(a.k.a gradOutputs) propagated through +// node->owningGraph() in **reversed** order, thus GradientPair.forward ahould +// be inserted **after** the node being replaced, so that we don't traverse the +// graph infinite times. +// // The output of compiled forward graph is [real_outputs, ctx] // The input of compiled backward graph is [ctx, grad_values] -// We run LowerSimpleTuples afterwards to elmininate all tuples generated in this process. -// The original node and TupleConstruct nodes in forward graph will be cleaned up -// later using EliminateDeadCode(block). -// TupleUnPack node in backward graph will be removed in eliminateDeadcode(ReverseDetails) -// defined in this file. +// We run LowerSimpleTuples afterwards to elmininate all tuples generated in +// this process. The original node and TupleConstruct nodes in forward graph +// will be cleaned up later using EliminateDeadCode(block). TupleUnPack node in +// backward graph will be removed in eliminateDeadcode(ReverseDetails) defined +// in this file. static c10::optional> build_script_grad( - Node* node, - const ArrayRef& grads) { + Node* node, + const ArrayRef& grads) { auto graph = node->owningGraph(); auto compiled_graphs = gradientInfoForSchema(node->schema()); @@ -187,7 +192,8 @@ static c10::optional> build_script_grad( { WithInsertPoint guard(node->next()); auto fw_graph = compiled_graphs->forward; - new_outputs = inlineCallTo(*graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true); + new_outputs = inlineCallTo( + *graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true); for (size_t i = 0; i < node->outputs().size(); ++i) { new_outputs.at(i)->setType(node->outputs()[i]->type()); new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]); @@ -200,83 +206,141 @@ static c10::optional> build_script_grad( auto it = grad_vec.begin(); grad_vec.insert(it, new_outputs.back()); ArrayRef grad(grad_vec); - auto grad_inputs = inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true); + auto grad_inputs = + inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true); return grad_inputs; }; -static std::vector gradientForNode(Node* node, ArrayRef grad_values) { - static const OperatorSet comparison_ops = { - "aten::lt(Tensor self, Tensor other) -> Tensor", - "aten::le(Tensor self, Tensor other) -> Tensor", - "aten::gt(Tensor self, Tensor other) -> Tensor", - "aten::ge(Tensor self, Tensor other) -> Tensor", - "aten::eq(Tensor self, Tensor other) -> Tensor", - "aten::ne(Tensor self, Tensor other) -> Tensor", - "aten::lt(Tensor self, Scalar other) -> Tensor", - "aten::le(Tensor self, Scalar other) -> Tensor", - "aten::gt(Tensor self, Scalar other) -> Tensor", - "aten::ge(Tensor self, Scalar other) -> Tensor", - "aten::eq(Tensor self, Scalar other) -> Tensor", - "aten::ne(Tensor self, Scalar other) -> Tensor", - }; - const auto sumToSizeOf = [node](SymbolicVariable v, Symbol input_name) -> SymbolicVariable { - Value * size; +namespace { +class GradientHelper { + public: + GradientHelper(Node* n) : node(n) {} + + std::vector 
gradient(ArrayRef grad_values) { + if (!isDifferentiable(node)) { + throw std::runtime_error( + std::string("differentiation of ") + node->kind().toDisplayString() + + " is not supported, or it is missing necessary type information"); + } + // If AD is defined using torchscript, use it instead of symbolic + auto script_grads = build_script_grad(node, grad_values); + if (script_grads) + return *script_grads; + // Definition not found in torchscript, look up in the buildSymbolicGradient + // TODO: migrate all to using torchscript + auto sym_grads = buildSymbolicGradient(fmap(grad_values)); + return fmap(sym_grads, [](const SymbolicVariable& v) { return v.value(); }); + } + + private: + Node* node; + + SymbolicVariable sumToSizeOf(SymbolicVariable v, Symbol input_name) { + Value* size; { - WithInsertPoint insert_guard {node}; + WithInsertPoint insert_guard{node}; size = SymbolicVariable(node->namedInput(input_name)).size(); } return v.sumToSize(size); }; - const auto build_sym_grad = [node, &sumToSizeOf](const std::vector& grads) -> std::vector { + + std::vector buildSymbolicGradient( + const std::vector& grads) { + static const OperatorSet comparison_ops = { + "aten::lt(Tensor self, Tensor other) -> Tensor", + "aten::le(Tensor self, Tensor other) -> Tensor", + "aten::gt(Tensor self, Tensor other) -> Tensor", + "aten::ge(Tensor self, Tensor other) -> Tensor", + "aten::eq(Tensor self, Tensor other) -> Tensor", + "aten::ne(Tensor self, Tensor other) -> Tensor", + "aten::lt(Tensor self, Scalar other) -> Tensor", + "aten::le(Tensor self, Scalar other) -> Tensor", + "aten::gt(Tensor self, Scalar other) -> Tensor", + "aten::ge(Tensor self, Scalar other) -> Tensor", + "aten::eq(Tensor self, Scalar other) -> Tensor", + "aten::ne(Tensor self, Scalar other) -> Tensor", + }; auto inputs = fmap(node->inputs()); auto outputs = fmap(node->outputs()); - if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) { - return {sumToSizeOf(grads.at(0), attr::self), - sumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other), - nullptr}; + if (node->matches( + "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) { + return { + sumToSizeOf(grads.at(0), attr::self), + sumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other), + nullptr}; - } else if (node->matches("aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) { + } else if ( + node->matches( + "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) { return {grads.at(0), nullptr, nullptr}; } else if (node->kind() == prim::AutogradAdd) { // NB: AutogradAdds don't broadcast return {grads.at(0), grads.at(0)}; - } else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) { + } else if ( + node->matches( + "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) { return {sumToSizeOf(grads.at(0), attr::self), - sumToSizeOf(-grads.at(0) * node->namedInput(attr::alpha), attr::other), + sumToSizeOf( + -grads.at(0) * node->namedInput(attr::alpha), attr::other), nullptr}; - } else if (node->matches("aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) { + } else if ( + node->matches( + "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) { return {grads.at(0), nullptr, nullptr}; - } else if (node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) { + } else if (node->matches( + "aten::mul(Tensor self, Tensor other) -> Tensor")) { return {sumToSizeOf(grads.at(0) * inputs.at(1), attr::self), 
sumToSizeOf(grads.at(0) * inputs.at(0), attr::other)}; - } else if (node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) { + } else if (node->matches( + "aten::mul(Tensor self, Scalar other) -> Tensor")) { return {grads.at(0) * inputs.at(1), nullptr}; - } else if (node->matches("aten::div(Tensor self, Tensor other) -> Tensor")) { + } else if (node->matches( + "aten::div(Tensor self, Tensor other) -> Tensor")) { return {sumToSizeOf(grads.at(0) / inputs.at(1), attr::self), - sumToSizeOf(-grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)), attr::other)}; + sumToSizeOf( + -grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)), + attr::other)}; - } else if (node->matches("aten::div(Tensor self, Scalar other) -> Tensor")) { + } else if (node->matches( + "aten::div(Tensor self, Scalar other) -> Tensor")) { return {grads.at(0) / inputs.at(1), nullptr}; - } else if (node->matches("aten::max(Tensor self, Tensor other) -> Tensor")) { - return {sumToSizeOf(grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)), attr::self), - sumToSizeOf(grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)), attr::other)}; - - } else if (node->matches("aten::min(Tensor self, Tensor other) -> Tensor")) { - return {sumToSizeOf(grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)), attr::self), - sumToSizeOf(grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)), attr::other)}; - - } else if (node->matches("aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) { + } else if (node->matches( + "aten::max(Tensor self, Tensor other) -> Tensor")) { + return { + sumToSizeOf( + grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)), + attr::self), + sumToSizeOf( + grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)), + attr::other)}; + + } else if (node->matches( + "aten::min(Tensor self, Tensor other) -> Tensor")) { + return { + sumToSizeOf( + grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)), + attr::self), + sumToSizeOf( + grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)), + attr::other)}; + + } else if ( + node->matches( + "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) { return {nullptr, - sumToSizeOf(grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self), - sumToSizeOf(grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)), attr::other)}; + sumToSizeOf( + grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self), + sumToSizeOf( + grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)), + attr::other)}; } else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) { // TODO: The order of operations matter in this case. This @@ -288,39 +352,55 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))}; } else if (node->matches("aten::relu(Tensor self) -> Tensor")) { - return {grads.at(0) * (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))}; + return {grads.at(0) * + (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))}; - } else if (node->matches("aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor")) { + } else if ( + node->matches( + "aten::clamp(Tensor self, Scalar? min, Scalar? 
max) -> Tensor")) { // handle the case that min/max is None Value* min = inputs.at(1); bool min_must_be_none = min->node()->kind() == prim::None; Value* max = inputs.at(2); bool max_must_be_none = max->node()->kind() == prim::None; // XXX - this formula is wrong when min or max are not stricly prim::None - // but may be None dynamically. In this case an internal compiler error will - // get thrown when trying to generate expressions involving the values of min/max + // but may be None dynamically. In this case an internal compiler error + // will get thrown when trying to generate expressions involving the + // values of min/max if (!min_must_be_none && !max_must_be_none) { - return {grads.at(0) - * (1-(inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))) - * (1-(inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))), nullptr, nullptr}; + return {grads.at(0) * + (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))) * + (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))), + nullptr, + nullptr}; } else if (max_must_be_none) { - return {grads.at(0) - * (1-(inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))), nullptr, nullptr}; + return {grads.at(0) * + (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))), + nullptr, + nullptr}; } else if (min_must_be_none) { - return {grads.at(0) - * (1-(inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))), nullptr, nullptr}; + return {grads.at(0) * + (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))), + nullptr, + nullptr}; } else { return {grads.at(0), nullptr, nullptr}; } - } else if (node->matches("aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor")) { + } else if ( + node->matches( + "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor")) { auto threshold = node->get(attr::threshold).value(); - return {grads.at(0) * (inputs.at(0) > threshold).type_as(outputs.at(0)), nullptr, nullptr}; + return {grads.at(0) * (inputs.at(0) > threshold).type_as(outputs.at(0)), + nullptr, + nullptr}; } else if (node->matches("aten::erf(Tensor self) -> Tensor")) { - return {grads.at(0) * 1.12837916709551 * (-inputs.at(0) * inputs.at(0)).exp()}; + return {grads.at(0) * 1.12837916709551 * + (-inputs.at(0) * inputs.at(0)).exp()}; } else if (node->matches("aten::erfc(Tensor self) -> Tensor")) { - return {-grads.at(0) * 1.12837916709551 * (-inputs.at(0) * inputs.at(0)).exp()}; + return {-grads.at(0) * 1.12837916709551 * + (-inputs.at(0) * inputs.at(0)).exp()}; } else if (node->matches("aten::exp(Tensor self) -> Tensor")) { return {grads.at(0) * (outputs.at(0))}; @@ -335,18 +415,22 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val return {grads.at(0) * inputs.at(0).sign()}; } else if (node->matches("aten::acos(Tensor self) -> Tensor")) { - return {grads.at(0) * -((-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt())}; + return {grads.at(0) * + -((-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt())}; } else if (node->matches("aten::asin(Tensor self) -> Tensor")) { - return {grads.at(0) * (-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt()}; + return {grads.at(0) * + (-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt()}; } else if (node->matches("aten::atan(Tensor self) -> Tensor")) { return {grads.at(0) / (inputs.at(0) * inputs.at(0) + at::Scalar(1))}; - } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) { - Value * self_size; + } else if ( + node->matches( + "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) { + Value* self_size; { - 
WithInsertPoint insert_guard { node }; + WithInsertPoint insert_guard{node}; self_size = inputs.at(0).size(); } return {grads.at(0).expand(self_size), nullptr}; @@ -369,7 +453,8 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } else if (node->matches("aten::floor(Tensor self) -> Tensor")) { return {SymbolicVariable::zeros_like(grads.at(0))}; - } else if (node->matches("aten::fmod(Tensor self, Scalar other) -> Tensor")) { + } else if (node->matches( + "aten::fmod(Tensor self, Scalar other) -> Tensor")) { return {grads.at(0), nullptr}; } else if (node->matches("aten::frac(Tensor self) -> Tensor")) { @@ -390,7 +475,8 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } else if (node->matches("aten::reciprocal(Tensor self) -> Tensor")) { return {-grads.at(0) * outputs.at(0) * outputs.at(0)}; - } else if (node->matches("aten::remainder(Tensor self, Scalar other) -> Tensor")) { + } else if (node->matches( + "aten::remainder(Tensor self, Scalar other) -> Tensor")) { return {grads.at(0), nullptr}; } else if (node->matches("aten::round(Tensor self) -> Tensor")) { @@ -414,28 +500,42 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } else if (node->kind() == prim::ConstantChunk) { return {SymbolicVariable::cat(grads, node->i(attr::dim))}; - } else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") || - node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) { - // TODO: if sizes are not available statically, add an operator that reutrns them as a tuple - auto sizes = node->namedInput(attr::self)->type()->expect()->sizes(); + } else if ( + node->matches("aten::view(Tensor self, int[] size) -> Tensor") || + node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) { + // TODO: if sizes are not available statically, add an operator that + // reutrns them as a tuple + auto sizes = node->namedInput(attr::self) + ->type() + ->expect() + ->sizes(); return {grads.at(0).reshape(sizes), nullptr}; - } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) { + } else if (node->matches( + "aten::type_as(Tensor self, Tensor other) -> Tensor")) { return {grads.at(0).type_as(inputs.at(0)), nullptr}; - } else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) { + } else if (node->matches( + "aten::unsqueeze(Tensor self, int dim) -> Tensor")) { return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr}; - } else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) { - return {sumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self), - grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha), - inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha), - nullptr, nullptr}; + } else if ( + node->matches( + "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) { + return { + sumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self), + grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha), + inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha), + nullptr, + nullptr}; } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) { - return {grads.at(0).mm(inputs.at(1).t()), inputs.at(0).t().mm(grads.at(0))}; + return {grads.at(0).mm(inputs.at(1).t()), + inputs.at(0).t().mm(grads.at(0))}; - } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) { + } else if ( + 
node->matches( + "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) { const auto& input_sizes = inputs.at(0).sizes(); if (input_sizes.size() == 0) return {grads.at(0).sum(), nullptr, nullptr}; @@ -456,7 +556,8 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val const auto& sizes = inputs.at(0).sizes(); std::vector squeezed_dims; for (size_t i = 0; i < sizes.size(); ++i) { - if (sizes[i] != 1) continue; + if (sizes[i] != 1) + continue; squeezed_dims.push_back(i); } SymbolicVariable returned_grad = grads.at(0); @@ -465,18 +566,24 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } return {returned_grad}; - } else if (node->matches("aten::squeeze(Tensor self, int dim) -> Tensor", /*const_inputs=*/attr::dim)) { + } else if (node->matches( + "aten::squeeze(Tensor self, int dim) -> Tensor", + /*const_inputs=*/attr::dim)) { int64_t dim = *node->get(attr::dim); const auto& sizes = inputs.at(0).sizes(); wrapDim(dim, sizes); - if (sizes.size() == 0) { + if (sizes.size() == 0) { return {grads.at(0), nullptr}; } - return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim), nullptr}; + return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim), + nullptr}; - } else if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor", /*const_inputs=*/attr::dim)) { + } else if (node->matches( + "aten::cat(Tensor[] tensors, int dim) -> Tensor", + /*const_inputs=*/attr::dim)) { int dim = *node->get(attr::dim); - auto tensor_inputs = inputs; tensor_inputs.pop_back(); + auto tensor_inputs = inputs; + tensor_inputs.pop_back(); const auto& first_sizes = tensor_inputs.at(0).sizes(); const auto has_first_sizes = [&first_sizes](SymbolicVariable var) { return var.sizes() == first_sizes; @@ -484,7 +591,8 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val // NB: this is a specialization for the common case where all inputs are // of equal sizes. We can use a single split operation to handle that. 
- if (std::all_of(tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) { + if (std::all_of( + tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) { auto tensor_grads = grads.at(0).chunk(tensor_inputs.size(), dim); tensor_grads.emplace_back(nullptr); // for attr::dim return tensor_grads; @@ -502,153 +610,191 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } else if (comparison_ops.find(node)) { return {nullptr, nullptr}; - } else if (node->matches("aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor")) { + } else if ( + node->matches( + "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor")) { JIT_ASSERT(grads.size() == 1); auto graph = node->owningGraph(); - auto backward_value = graph->insert(aten::avg_pool2d_backward, { - grads.at(0).value(), - node->namedInput(attr::self), - node->namedInput(attr::kernel_size), - node->namedInput(attr::stride), - node->namedInput(attr::padding), - node->namedInput(attr::ceil_mode), - node->namedInput(attr::count_include_pad)}); - return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr, nullptr}; - - } else if (node->matches("aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) { + auto backward_value = graph->insert( + aten::avg_pool2d_backward, + {grads.at(0).value(), + node->namedInput(attr::self), + node->namedInput(attr::kernel_size), + node->namedInput(attr::stride), + node->namedInput(attr::padding), + node->namedInput(attr::ceil_mode), + node->namedInput(attr::count_include_pad)}); + return {backward_value->node()->output(0), + nullptr, + nullptr, + nullptr, + nullptr, + nullptr}; + + } else if ( + node->matches( + "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) { JIT_ASSERT(grads.size() == 2); auto graph = node->owningGraph(); - auto backward_value = graph->insert(aten::max_pool2d_with_indices_backward, { - grads.at(0).value(), - node->namedInput(attr::self), - node->namedInput(attr::kernel_size), - node->namedInput(attr::stride), - node->namedInput(attr::padding), - node->namedInput(attr::dilation), - node->namedInput(attr::ceil_mode), - outputs.at(1).value() - }); - return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr, nullptr}; - - } else if (node->matches("aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) { + auto backward_value = graph->insert( + aten::max_pool2d_with_indices_backward, + {grads.at(0).value(), + node->namedInput(attr::self), + node->namedInput(attr::kernel_size), + node->namedInput(attr::stride), + node->namedInput(attr::padding), + node->namedInput(attr::dilation), + node->namedInput(attr::ceil_mode), + outputs.at(1).value()}); + return {backward_value->node()->output(0), + nullptr, + nullptr, + nullptr, + nullptr, + nullptr}; + + } else if ( + node->matches( + "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? 
bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) { auto graph = node->owningGraph(); - auto backward_value = graph->insert(aten::thnn_conv2d_backward, { - grads.at(0).value(), - inputs.at(0).value(), - inputs.at(1).value(), - node->namedInput(attr::kernel_size), - node->namedInput(attr::stride), - node->namedInput(attr::padding), - outputs.at(1).value(), - outputs.at(2).value(), - graph->insertConstant(std::vector{true, true, true}) - }); - // graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again. - Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value)); + auto backward_value = graph->insert( + aten::thnn_conv2d_backward, + {grads.at(0).value(), + inputs.at(0).value(), + inputs.at(1).value(), + node->namedInput(attr::kernel_size), + node->namedInput(attr::stride), + node->namedInput(attr::padding), + outputs.at(1).value(), + outputs.at(2).value(), + graph->insertConstant(std::vector{true, true, true})}); + // graph->insert returns a tuple automatically if multiple outputs are + // returned. So unpack them again. + Node* tuple_unpack_node = + graph->insertNode(graph->createTupleUnpack(backward_value)); auto tuple_outputs = tuple_unpack_node->outputs(); JIT_ASSERT(tuple_outputs.size() == size_t(3)); - return {tuple_outputs[0], tuple_outputs[1], nullptr, tuple_outputs[2], nullptr, nullptr}; + return {tuple_outputs[0], + tuple_outputs[1], + nullptr, + tuple_outputs[2], + nullptr, + nullptr}; - } else if (node->matches("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) { + } else if ( + node->matches( + "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) { auto graph = node->owningGraph(); - auto backward_value = graph->insert(aten::native_batch_norm_backward, { - grads.at(0).value(), - inputs.at(0).value(), - inputs.at(1).value(), - inputs.at(3).value(), - inputs.at(4).value(), - outputs.at(1).value(), - outputs.at(2).value(), - inputs.at(5).value(), - inputs.at(7).value(), - graph->insertConstant(std::vector{true, true, true}) - }); - // graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again. - Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value)); + auto backward_value = graph->insert( + aten::native_batch_norm_backward, + {grads.at(0).value(), + inputs.at(0).value(), + inputs.at(1).value(), + inputs.at(3).value(), + inputs.at(4).value(), + outputs.at(1).value(), + outputs.at(2).value(), + inputs.at(5).value(), + inputs.at(7).value(), + graph->insertConstant(std::vector{true, true, true})}); + // graph->insert returns a tuple automatically if multiple outputs are + // returned. So unpack them again. + Node* tuple_unpack_node = + graph->insertNode(graph->createTupleUnpack(backward_value)); auto tuple_outputs = tuple_unpack_node->outputs(); JIT_ASSERT(tuple_outputs.size() == size_t(3)); - return {tuple_outputs[0], tuple_outputs[1], tuple_outputs[2], nullptr, nullptr, nullptr, nullptr, nullptr}; + return {tuple_outputs[0], + tuple_outputs[1], + tuple_outputs[2], + nullptr, + nullptr, + nullptr, + nullptr, + nullptr}; - } else if (node->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? 
weight, int reduction, int ignore_index) -> Tensor")) { + } else if ( + node->matches( + "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) { auto graph = node->owningGraph(); auto total_weight = graph->insertNode(graph->createUndefined()); auto weight = graph->insertNode(graph->createUndefined()); - auto backward_value = graph->insert(aten::nll_loss_backward, { - grads.at(0).value(), - inputs.at(0).value(), - inputs.at(1).value(), - weight->output(), - inputs.at(3).value(), - inputs.at(4).value(), - total_weight->output() - }); - return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr}; - - } else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) { + auto backward_value = graph->insert( + aten::nll_loss_backward, + {grads.at(0).value(), + inputs.at(0).value(), + inputs.at(1).value(), + weight->output(), + inputs.at(3).value(), + inputs.at(4).value(), + total_weight->output()}); + return {backward_value->node()->output(0), + nullptr, + nullptr, + nullptr, + nullptr}; + + } else if (node->matches( + "aten::log_softmax(Tensor self, int dim) -> Tensor")) { JIT_ASSERT(grads.size() == 1); auto graph = node->owningGraph(); - auto backward_value = graph->insert(aten::_log_softmax_backward_data, { - grads.at(0).value(), - outputs.at(0).value(), - node->namedInput(attr::dim), - node->namedInput(attr::self) - }); + auto backward_value = graph->insert( + aten::_log_softmax_backward_data, + {grads.at(0).value(), + outputs.at(0).value(), + node->namedInput(attr::dim), + node->namedInput(attr::self)}); return {backward_value->node()->output(0), nullptr}; - } else if (node->kind() == prim::Constant || node->kind() == prim::Undefined || node->kind() == prim::None) { + } else if ( + node->kind() == prim::Constant || node->kind() == prim::Undefined || + node->kind() == prim::None) { return {}; } - throw std::runtime_error(std::string("failed to differentiate `") + node->kind().toDisplayString() + "`"); - }; - if (!isDifferentiable(node)) { - throw std::runtime_error(std::string("differentiation of ") + node->kind().toDisplayString() + " " - "is not supported, or it is missing necessary type information"); + throw std::runtime_error( + std::string("failed to differentiate `") + + node->kind().toDisplayString() + "`"); } - // If AD is defined using torchscript, use it instead of symbolic - auto script_grads = build_script_grad(node, grad_values); - if (script_grads) - return *script_grads; - // Definition not found in torchscript, look up in the build_sym_grad - // TODO: migrate all to using torchscript - auto sym_grads = build_sym_grad(fmap(grad_values)); - return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); }); -} - -// If we have a function y = f(x) with jacobian J, the backwards of f is dx = J^t dy. -// Note that because the backwards always implements this matrix multiply, -// we know that it maps an input vector of zeros to an output vector of zero -// regardless of what operations it choses to do inside to actually implement -// the matrix multiply (most use some optimized form and never generate J^t). -// More generally, we know that all of the backward computations are linear and -// can use this property to do more aggressive optimizations later. -// It is ok to replace any backward function with known-zero inputs with something -// that produces known-zero outputs. 
This function encloses each know-linear -// backward function in a 'GradOf' sub-block so that we can perform optimizations -// using this information. In particular, specializeUndef will observe if -// all the inputs to the linear block are Undef, which the autograd uses to represent -// zeros, and then propagate the undefs to the outputs of the block. -static std::vector linearGradientForNode(Node* node, ArrayRef grad_values) { - auto & graph = *node->owningGraph(); +}; +} // namespace + +// If we have a function y = f(x) with jacobian J, the backwards of f is dx = +// J^t dy. Note that because the backwards always implements this matrix +// multiply, we know that it maps an input vector of zeros to an output vector +// of zero regardless of what operations it choses to do inside to actually +// implement the matrix multiply (most use some optimized form and never +// generate J^t). More generally, we know that all of the backward computations +// are linear and can use this property to do more aggressive optimizations +// later. It is ok to replace any backward function with known-zero inputs with +// something that produces known-zero outputs. This function encloses each +// know-linear backward function in a 'GradOf' sub-block so that we can perform +// optimizations using this information. In particular, specializeUndef will +// observe if all the inputs to the linear block are Undef, which the autograd +// uses to represent zeros, and then propagate the undefs to the outputs of the +// block. +static std::vector linearGradientForNode( + Node* node, + ArrayRef grad_values) { + auto& graph = *node->owningGraph(); auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0)); // to make reading gradient graphs easier, remember the name of the forward op linear->s_(attr::name, node->kind().toDisplayString()); auto block = linear->addBlock(); WithInsertPoint guard(block); - auto results = gradientForNode(node, grad_values); - return fmap(results, [block, linear](Value *grad) -> Value* { - if (!grad) return nullptr; + auto results = GradientHelper(node).gradient(grad_values); + return fmap(results, [block, linear](Value* grad) -> Value* { + if (!grad) + return nullptr; block->registerOutput(grad); return linear->addOutput()->copyMetadata(grad); }); } struct ReverseDetails { - ReverseDetails(value_map&& grad_map, Block * reverse_block) - : grad_map(std::move(grad_map)) - , reverse_block(reverse_block) {} + ReverseDetails(value_map&& grad_map, Block* reverse_block) + : grad_map(std::move(grad_map)), reverse_block(reverse_block) {} value_map grad_map; - Block * reverse_block; + Block* reverse_block; }; // AutogradAdd is a special addition function that handles Undef @@ -670,7 +816,7 @@ static Value* createAutogradAdd(Value* a, Value* b) { // - grad_desc has df_input_vjps and df_output_vjps set // (but df_input_vjps will be modified later as well) static ReverseDetails addReverseInline(Gradient& grad_desc) { - auto & graph = *grad_desc.f; + auto& graph = *grad_desc.f; // note: reverse_node is intentionally not inserted to avoid // accidentally acting on it (e.g. in elminate dead code), // std::cout << *reverse_node << to view its state. 
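The comment above makes a concrete claim: every symbolic backward emitted by buildSymbolicGradient is linear in the incoming grad, so a zero (or Undef) vjp necessarily yields zero outputs, which is the property specializeUndef exploits to collapse an entire GradOf block. One way to convince yourself, outside the JIT IR, is to evaluate one of the cases above with plain ATen calls; the shapes below are made up for illustration and the snippet is not part of the patch.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Made-up shapes; the zero tensor stands in for the Undef vjp the autograd
  // uses to represent a zero gradient.
  at::Tensor self = at::randn({2, 3});
  at::Tensor other = at::randn({2, 3});
  at::Tensor grad_out = at::zeros({2, 3});

  // Mirrors what the aten::mul case above emits (SumToSize omitted, since the
  // shapes already match): both results are linear in grad_out.
  at::Tensor grad_self = grad_out * other;
  at::Tensor grad_other = grad_out * self;

  // Zeros in -> zeros out, regardless of the values of self/other.
  std::cout << grad_self.sum() << "\n" << grad_other.sum() << std::endl;
  return 0;
}

The same reasoning covers the other cases handled above, since they only multiply, divide, reshape, chunk, or sum the incoming grads, which is why wrapping each of them in a prim::GradOf sub-block is sound.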
@@ -687,8 +833,8 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) { } return it->second; }; - const auto set_grad = [&](Value *x, Value *dx) { - if (Value * prev_grad = grad_map[x]) { + const auto set_grad = [&](Value* x, Value* dx) { + if (Value* prev_grad = grad_map[x]) { grad_map[x] = createAutogradAdd(prev_grad, dx); } else { grad_map[x] = dx; @@ -697,45 +843,54 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) { auto outputs = graph.outputs(); for (size_t i = 0, num_outputs = outputs.size(); i < num_outputs; ++i) { - Value * output = outputs[i]; + Value* output = outputs[i]; if (!output->requires_grad()) continue; - Value * output_grad = reverse_block->addInput()->setType(output->type()); + Value* output_grad = reverse_block->addInput()->setType(output->type()); set_grad(output, output_grad); grad_desc.df_input_vjps.push_back(i); } - for (auto it = graph.nodes().rbegin(), end = graph.nodes().rend(); it != end; ++it) { - Node *node = *it; + for (auto it = graph.nodes().rbegin(), end = graph.nodes().rend(); it != end; + ++it) { + Node* node = *it; auto inputs = node->inputs(); auto outputs = node->outputs(); - if (std::all_of(outputs.begin(), outputs.end(), [](Value *v) { return !v->requires_grad(); })) { + if (std::all_of(outputs.begin(), outputs.end(), [](Value* v) { + return !v->requires_grad(); + })) { continue; } - value_list grad_inputs = linearGradientForNode(node, fmap(node->outputs(), get_grad)); + value_list grad_inputs = + linearGradientForNode(node, fmap(node->outputs(), get_grad)); LowerSimpleTuples(reverse_block); JIT_ASSERT(grad_inputs.size() == node->inputs().size()); for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) { - if (!inputs[i]->requires_grad()) continue; - // NB: Not returning a gradient w.r.t. a value that requires grad is normal if the - // input is non-differentiable. This happens e.g. in the aten::type_as case. - if (!grad_inputs[i]) continue; + if (!inputs[i]->requires_grad()) + continue; + // NB: Not returning a gradient w.r.t. a value that requires grad is + // normal if the input is non-differentiable. This happens e.g. in the + // aten::type_as case. + if (!grad_inputs[i]) + continue; set_grad(inputs[i], grad_inputs[i]); } } auto inputs = graph.inputs(); for (size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) { - Value * input = inputs[i]; + Value* input = inputs[i]; if (!input->requires_grad()) continue; - // NB: Not having a gradient defined w.r.t. an input to the graph which requires grad - // can happen and is not an error. It might have been used only in non-differentiable - // contexts (e.g. as second input to aten::type_as). In that case we simply ignore it - // as an output, because it won't ever produce any meaningful values. - if (grad_map.count(input) == 0) continue; + // NB: Not having a gradient defined w.r.t. an input to the graph which + // requires grad can happen and is not an error. It might have been used + // only in non-differentiable contexts (e.g. as second input to + // aten::type_as). In that case we simply ignore it as an output, because it + // won't ever produce any meaningful values. + if (grad_map.count(input) == 0) + continue; reverse_block->registerOutput(get_grad(input)); grad_desc.df_output_vjps.push_back(i); } @@ -743,14 +898,15 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) { return ReverseDetails(std::move(grad_map), reverse_block); } -// Returns a topologically-sorted list of values produced in f, and used in its reverse program. 
+// Returns a topologically-sorted list of values produced in f, and used in its +// reverse program. static value_list getReverseCaptures(Gradient& grad_desc) { - auto & graph = *grad_desc.f; + auto& graph = *grad_desc.f; auto primal_block = graph.block(); value_set reverse_captures_set; value_list reverse_captures; // Invariant: topo sorted - auto check_uses = [&](Value *v) { + auto check_uses = [&](Value* v) { for (auto use : v->uses()) { if (use.user->owningBlock() == primal_block) continue; @@ -759,39 +915,43 @@ static value_list getReverseCaptures(Gradient& grad_desc) { } } }; - for (Value * input : graph.inputs()) { + for (Value* input : graph.inputs()) { check_uses(input); } - for (Node * node : graph.nodes()) { - for (Value * output : node->outputs()) + for (Node* node : graph.nodes()) { + for (Value* output : node->outputs()) check_uses(output); } return reverse_captures; } -// Any temporary value from the primal graphs needs to be captured for later use in the -// reverse graph, to avoid costly recomputations. However, a lot of the nodes we have -// in our graphs are simply constants, which are cheap to execute and replicate, and so -// it's better to just copy them into the reverse graph, without polluting the output -// lists unnecessarily. +// Any temporary value from the primal graphs needs to be captured for later use +// in the reverse graph, to avoid costly recomputations. However, a lot of the +// nodes we have in our graphs are simply constants, which are cheap to execute +// and replicate, and so it's better to just copy them into the reverse graph, +// without polluting the output lists unnecessarily. static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) { static const auto err = [](Value*) -> Value* { throw std::runtime_error("unexpected input"); }; - auto & graph = *grad_desc.f; + auto& graph = *grad_desc.f; Block* reverse_block = rev_info.reverse_block; - for (Node *top_node : reverse_block->nodes()) { - JIT_ASSERT(top_node->kind() == prim::GradOf || - top_node->kind() == prim::AutogradAdd || - top_node->kind() == prim::Undefined); - if (top_node->kind() != prim::GradOf) continue; - Block * grad_body = top_node->blocks().at(0); - for (Node *node : grad_body->nodes()) { - for (Value * input : node->inputs()) { - if (input->node()->kind() != prim::Constant) continue; - if (input->node()->owningBlock() == grad_body) continue; - Node *lifted_constant = graph.createClone(input->node(), err); + for (Node* top_node : reverse_block->nodes()) { + JIT_ASSERT( + top_node->kind() == prim::GradOf || + top_node->kind() == prim::AutogradAdd || + top_node->kind() == prim::Undefined); + if (top_node->kind() != prim::GradOf) + continue; + Block* grad_body = top_node->blocks().at(0); + for (Node* node : grad_body->nodes()) { + for (Value* input : node->inputs()) { + if (input->node()->kind() != prim::Constant) + continue; + if (input->node()->owningBlock() == grad_body) + continue; + Node* lifted_constant = graph.createClone(input->node(), err); reverse_block->prependNode(lifted_constant); node->replaceInputWith(input, lifted_constant->output()); } @@ -799,22 +959,25 @@ static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) { } } -static void deduplicateSizeCaptures(Gradient& grad_desc, ReverseDetails& rev_info) { - Block * primal_block = grad_desc.f->block(); - const auto usedOnlyInReverse = [primal_block](Value * v) { - const auto & uses = v->uses(); - return std::all_of(uses.begin(), uses.end(), - [primal_block](const Use& u) { return 
u.user->owningBlock() != primal_block; }); +static void deduplicateSizeCaptures( + Gradient& grad_desc, + ReverseDetails& rev_info) { + Block* primal_block = grad_desc.f->block(); + const auto usedOnlyInReverse = [primal_block](Value* v) { + const auto& uses = v->uses(); + return std::all_of(uses.begin(), uses.end(), [primal_block](const Use& u) { + return u.user->owningBlock() != primal_block; + }); }; auto captures = getReverseCaptures(grad_desc); - value_set capture_set (captures.begin(), captures.end()); - for (Value * capture : captures) { - Node * node = capture->node(); + value_set capture_set(captures.begin(), captures.end()); + for (Value* capture : captures) { + Node* node = capture->node(); if (!node->matches("aten::size(Tensor self) -> int[]")) { continue; } if (usedOnlyInReverse(capture) && capture_set.count(node->input())) { - WithInsertPoint insert_guard { *rev_info.reverse_block->nodes().begin() }; + WithInsertPoint insert_guard{*rev_info.reverse_block->nodes().begin()}; capture->replaceAllUsesWith(SymbolicVariable(node->input()).size()); node->destroy(); } @@ -846,16 +1009,17 @@ static void eliminateDeadCode(ReverseDetails& rev_info) { } static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) { - // TODO: we are sometimes emitting expressions like SumToSize(SumToSize(x, s1), s2), - // which are equivalent to SumToSize(x, s2), and could save us some captures, but I'm - // not 100% sure how to optimize this at this stage, since we don't know which - // GradOf blocks will be stitched together to form the derivative. I guess a smart - // analysis could implement this, but I didn't have time before the 1.0 release, - // so I put this only as a peephole optimization. + // TODO: we are sometimes emitting expressions like SumToSize(SumToSize(x, + // s1), s2), which are equivalent to SumToSize(x, s2), and could save us some + // captures, but I'm not 100% sure how to optimize this at this stage, since + // we don't know which GradOf blocks will be stitched together to form the + // derivative. I guess a smart analysis could implement this, but I didn't + // have time before the 1.0 release, so I put this only as a peephole + // optimization. liftConstants(grad_desc, rev_info); // We generally add a lot of aten::size calls (for derivatives of broadcasting - // operators), and they often end up duplicated, and would get captured multiple - // times. Make sure we deduplicate them before lifting. + // operators), and they often end up duplicated, and would get captured + // multiple times. Make sure we deduplicate them before lifting. EliminateCommonSubexpression(grad_desc.f); deduplicateSizeCaptures(grad_desc, rev_info); eliminateDeadCode(rev_info); @@ -866,10 +1030,11 @@ static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) { // All intermediates needed in the second stage are added to // outputs of f, and taken as inputs in df. For a more // detailed description see Note [Gradient graphs] in autodiff.h. -// This function also initializes the fields in grad_desc that were undefined after -// `addReverseInline` (and extends `df_input_vjps` with vjps for captured temporaries). +// This function also initializes the fields in grad_desc that were undefined +// after `addReverseInline` (and extends `df_input_vjps` with vjps for captured +// temporaries). 
static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) { - auto & graph = *grad_desc.f; + auto& graph = *grad_desc.f; auto primal_block = graph.block(); auto reverse_block = rev_info.reverse_block; @@ -902,49 +1067,58 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) { std::unordered_map orig_primal_outputs_idx; std::unordered_map orig_primal_inputs_idx; - // NOTE: we use emplace to avoid replacing an existing index if an output is repeated + // NOTE: we use emplace to avoid replacing an existing index if an output is + // repeated for (size_t i = 0, num_outputs = graph.outputs().size(); i < num_outputs; ++i) orig_primal_outputs_idx.emplace(graph.outputs()[i], i); for (size_t i = 0, num_inputs = graph.inputs().size(); i < num_inputs; ++i) orig_primal_inputs_idx[graph.inputs()[i]] = i; // NB: reverse_captures are already deduplicated, and in topo order - for (Value * capture_val : reverse_captures) { + for (Value* capture_val : reverse_captures) { // If it's already an output we don't have to add anything, // but register the fact that it needs to be captured. if (orig_primal_outputs_idx.count(capture_val) > 0) { - grad_desc.df_input_captured_outputs.push_back(orig_primal_outputs_idx[capture_val]); - // If it's an input, we could add it as an output but in fact it's - // more efficient to use a special kind of capture. + grad_desc.df_input_captured_outputs.push_back( + orig_primal_outputs_idx[capture_val]); + // If it's an input, we could add it as an output but in fact it's + // more efficient to use a special kind of capture. } else if (orig_primal_inputs_idx.count(capture_val) > 0) { - grad_desc.df_input_captured_inputs.push_back(orig_primal_inputs_idx.at(capture_val)); - // Otherwise it's just a regular intermediate value that we need to add as an output + grad_desc.df_input_captured_inputs.push_back( + orig_primal_inputs_idx.at(capture_val)); + // Otherwise it's just a regular intermediate value that we need to add as + // an output } else { - // we need to create a new temporary output for this capture because it wasn't availiable. + // we need to create a new temporary output for this capture because it + // wasn't availiable. graph.registerOutput(capture_val); - grad_desc.df_input_captured_outputs.emplace_back(graph.outputs().size() - 1); + grad_desc.df_input_captured_outputs.emplace_back( + graph.outputs().size() - 1); } } // -- Add VJPs for temporaries, adjust df_input_vjps ------------------------- - // NB [possible optimization]: use the newly added vjp input as soon as the first - // vjp for that value is generated, to reduce the lifespan of this input + // NB [possible optimization]: use the newly added vjp input as soon as the + // first vjp for that value is generated, to reduce the lifespan of this input // (currently we add it to the final vjp after all adds). for (size_t i = grad_desc.f_real_outputs; i < graph.outputs().size(); ++i) { - Value * tmp = graph.outputs().at(i); + Value* tmp = graph.outputs().at(i); // Add VJP inputs only for intermediates that actually required grad. - // Note that we check the contents of the grad_map instead of tmp->requires_grad(), - // becuase it's actually a more faithful source. tmp->requires_grad() is really an - // overapproximation (i.e. it can have false positives), while the gradients we will - // emit for this value can get DCE-d in the optimization pass (because it has no - // influence on the real f's outputs that we differentiate). 
- if (rev_info.grad_map.count(tmp) == 0) continue; - Value * tmp_vjp_in = reverse_block->addInput()->setType(tmp->type()); - Value * tmp_vjp_prev = rev_info.grad_map.at(tmp); - // This is quite weird because we can't first make a sum and then replace all uses - // of tmp_vjp_prev (that would replace its use in the sum too!), so we create an - // incorrect sum that doesn't use prev vjp, replace uses, and fix the sum. - Value * new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in); + // Note that we check the contents of the grad_map instead of + // tmp->requires_grad(), becuase it's actually a more faithful source. + // tmp->requires_grad() is really an overapproximation (i.e. it can have + // false positives), while the gradients we will emit for this value can get + // DCE-d in the optimization pass (because it has no influence on the real + // f's outputs that we differentiate). + if (rev_info.grad_map.count(tmp) == 0) + continue; + Value* tmp_vjp_in = reverse_block->addInput()->setType(tmp->type()); + Value* tmp_vjp_prev = rev_info.grad_map.at(tmp); + // This is quite weird because we can't first make a sum and then replace + // all uses of tmp_vjp_prev (that would replace its use in the sum too!), so + // we create an incorrect sum that doesn't use prev vjp, replace uses, and + // fix the sum. + Value* new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in); new_vjp->node()->moveAfter(tmp_vjp_prev->node()); tmp_vjp_prev->replaceAllUsesWith(new_vjp); new_vjp->node()->replaceInput(1, tmp_vjp_prev); @@ -956,13 +1130,13 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) { // construct a map from captured 'value' to the index in the input list // used to extract this block into its own function std::unordered_map capture_to_formal_index; - const auto & add_capture = [&](Value * captured) { + const auto& add_capture = [&](Value* captured) { capture_to_formal_index[captured] = reverse_block->inputs().size(); reverse_block->addInput()->copyMetadata(captured); }; - for(auto & offset : grad_desc.df_input_captured_inputs) + for (auto& offset : grad_desc.df_input_captured_inputs) add_capture(graph.inputs()[offset]); - for(auto & offset : grad_desc.df_input_captured_outputs) + for (auto& offset : grad_desc.df_input_captured_outputs) add_capture(graph.outputs()[offset]); grad_desc.df = std::make_shared(); @@ -974,13 +1148,14 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) { reverse_block->owningNode()->destroy(); } - Gradient differentiate(std::shared_ptr& graph) { Gradient grad_desc; // Take ownership of the graph - JIT_ASSERTM(graph.use_count() == 1, - "differentiate will mutate and destroy the graph, so it requires " - "graph.use_count() == 1, but found %d", graph.use_count()); + JIT_ASSERTM( + graph.use_count() == 1, + "differentiate will mutate and destroy the graph, so it requires " + "graph.use_count() == 1, but found %d", + graph.use_count()); std::swap(graph, grad_desc.f); // XXX: Take care when handling outputs - they can be duplicated! 
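For orientation, the entry point reshaped by these hunks is differentiate(); a minimal calling sketch (not part of this patch, with the construction of a differentiable graph assumed to happen elsewhere, e.g. via the tracer, and differentiateOnce being a made-up wrapper name) could look like the following.

#include "torch/csrc/jit/autodiff.h"
#include "torch/csrc/jit/ir.h"

namespace torch {
namespace jit {

// The graph must be uniquely owned (use_count() == 1): differentiate()
// asserts this, then swaps the graph into grad_spec.f, so the caller's
// pointer is not usable afterwards.
Gradient differentiateOnce(std::shared_ptr<Graph> graph) {
  Gradient grad_spec = differentiate(graph);
  // grad_spec.f  : forward graph; its first f_real_outputs outputs are the
  //                real outputs, the rest are temporaries captured for df.
  // grad_spec.df : backward graph taking vjps plus the captured values and
  //                producing vjps for the f inputs listed in df_output_vjps.
  return grad_spec;
}

} // namespace jit
} // namespace torch

How the two graphs are wired into an autograd function is spelled out in the Gradient struct comments in autodiff.h below.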
@@ -1000,4 +1175,5 @@ Gradient differentiate(std::shared_ptr& graph) { return grad_desc; } -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/autodiff.h b/torch/csrc/jit/autodiff.h index 519a7ae..74df442 100644 --- a/torch/csrc/jit/autodiff.h +++ b/torch/csrc/jit/autodiff.h @@ -5,12 +5,14 @@ #include -#include #include +#include -namespace torch { namespace jit { +namespace torch { +namespace jit { using value_list = std::vector; +// clang-format off // Example showcasing how Gradient is constructed: // // Let's assume we have a function f, `m` and `n` do not require grad @@ -34,6 +36,7 @@ using value_list = std::vector; // df_output_vjps = {0} // i.e. connect next_edge[0] of grad_fn to x's (grad_fn, output_nr). // // Terminology: vjp = vector-jacobian product +// clang-format on struct Gradient { explicit operator bool() const { @@ -45,23 +48,22 @@ struct Gradient { // Describes how to construct outputs of f from what its graph will return. // This is necessary because some trailing outputs are intermediates produced // only to be saved for df (and should be ignored). - size_t f_real_outputs = 0; // initialized for safety. + size_t f_real_outputs = 0; // initialized for safety. - // df inputs are split into two sections: vjps (aka grad_outputs) and captures. - // VJPs are "seeds" for the gradient computation given for each input capture - // of an Output kind. - // Captures are values the need to be saved when f is run. We handle inputs - // specially, because this allows us to avoid adding extra vjps as df inputs. + // df inputs are split into two sections: vjps (aka grad_outputs) and + // captures. VJPs are "seeds" for the gradient computation given for each + // input capture of an Output kind. Captures are values the need to be saved + // when f is run. We handle inputs specially, because this allows us to avoid + // adding extra vjps as df inputs. std::vector df_input_vjps; // Offsets into f's outputs. // capture can come from inputs or outputs std::vector df_input_captured_inputs; // Offsets into f's inputs std::vector df_input_captured_outputs; // Offsets into f's outputs - // df will produce vjps for a subset of inputs of f that required grad. - // df_output_vjps[idx] == inp_idx means that idx-th output of df produces a vjp - // for inp_idx-th input of f. + // df_output_vjps[idx] == inp_idx means that idx-th output of df produces a + // vjp for inp_idx-th input of f. std::vector df_output_vjps; // Offsets into f's inputs. // How to use gradient to implement a differentiable autograd function: @@ -76,8 +78,8 @@ struct Gradient { // - Use df_output_vjps to connect next_edges of grad_fn: // for idx in df_output_vjps: // grad_fn.add_next_edge(inputs[idx].gradient_edge()) - // - Save captures for df (care needs to be taken to use SavedVariables for inputs and - // outputs that we will actually return) + // - Save captures for df (care needs to be taken to use SavedVariables for + // inputs and outputs that we will actually return) // - Return outputs[:f_real_outputs] // // When running df: @@ -88,8 +90,9 @@ struct Gradient { TORCH_API Gradient differentiate(std::shared_ptr& graph); // can we take a derivative of this node symbolically? 
-TORCH_API bool isDifferentiable(Node * n); -TORCH_API bool isDifferentiable(Graph & g); -TORCH_API bool isZero(Value * v); +TORCH_API bool isDifferentiable(Node* n); +TORCH_API bool isDifferentiable(Graph& g); +TORCH_API bool isZero(Value* v); -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/batched/BatchTensor.cpp b/torch/csrc/jit/batched/BatchTensor.cpp index d514cc4..7d709a6 100644 --- a/torch/csrc/jit/batched/BatchTensor.cpp +++ b/torch/csrc/jit/batched/BatchTensor.cpp @@ -1,19 +1,21 @@ #include -namespace torch { namespace jit { +namespace torch { +namespace jit { -BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims){ - if(data.dim() != mask.dim() || mask.dim() != dims.size(0) + 1){ - throw std::runtime_error("malformed MaskedBatch with data.dim(): " - + std::to_string(data.dim()) + ", mask.dim(): " + std::to_string(mask.dim()) - + ", dims.size(0): " + std::to_string(dims.size(0))); +BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims) { + if (data.dim() != mask.dim() || mask.dim() != dims.size(0) + 1) { + throw std::runtime_error( + "malformed MaskedBatch with data.dim(): " + std::to_string(data.dim()) + + ", mask.dim(): " + std::to_string(mask.dim()) + + ", dims.size(0): " + std::to_string(dims.size(0))); } this->data = std::move(data); this->mask = std::move(mask); this->dims = std::move(dims); } -BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size){ +BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size) { dims = at::empty(data.dim(), data.options().dtype(at::kByte)); dims.fill_(0); std::vector sizes(data.dim() + 1, -1); @@ -25,13 +27,16 @@ BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size){ mask.fill_(1); } -BatchTensor::BatchTensor(const std::vector& datalist, at::Tensor dims) { +BatchTensor::BatchTensor( + const std::vector& datalist, + at::Tensor dims) { auto bs = datalist.size(); - std::vector sizes(dims.size(0) + 1, 0), mask_sizes(dims.size(0) + 1, 0); + std::vector sizes(dims.size(0) + 1, 0), + mask_sizes(dims.size(0) + 1, 0); sizes[0] = bs; mask_sizes[0] = bs; - for(int64_t i = 1; i < dims.size(0) + 1; i++){ - for(const auto& x : datalist){ + for (int64_t i = 1; i < dims.size(0) + 1; i++) { + for (const auto& x : datalist) { sizes[i] = std::max(sizes[i], x.size(i)); } mask_sizes[i] = *dims[i - 1].data() ? 
sizes[i] : 1; @@ -40,11 +45,11 @@ BatchTensor::BatchTensor(const std::vector& datalist, at::Tensor dim data.fill_(0); mask = at::empty(mask_sizes, datalist[0].options().dtype(at::kByte)); mask.fill_(0); - for(std::size_t i = 0; i < datalist.size(); i++){ + for (std::size_t i = 0; i < datalist.size(); i++) { auto data_item = data.narrow(0, i, 1); auto mask_item = mask.narrow(0, i, 1); - for(int64_t j = 0; j < dims.size(0); j++){ - if(*dims[j].data()){ + for (int64_t j = 0; j < dims.size(0); j++) { + if (*dims[j].data()) { data_item = data_item.narrow(j + 1, 0, datalist[i].size(j + 1)); mask_item = mask_item.narrow(j + 1, 0, datalist[i].size(j + 1)); } @@ -58,16 +63,16 @@ BatchTensor::BatchTensor(const std::vector& datalist, at::Tensor dim std::vector BatchTensor::examples() { std::vector result; // calculate number of valid entries in dth dimension of data - auto mask_sum = [](at::Tensor data, int d) -> int64_t{ + auto mask_sum = [](at::Tensor data, int d) -> int64_t { data = data.sum(d, /*keepdim=*/true); - while(data.dim() >= 1) + while (data.dim() >= 1) data = data[0]; return *data.data(); }; - for(int64_t i = 0; i < data.size(0); i++){ + for (int64_t i = 0; i < data.size(0); i++) { auto data_tmp = data.narrow(0, i, 1); - for(int64_t d = 0; d < dims.size(0); d++){ - if(*dims[d].data()){ + for (int64_t d = 0; d < dims.size(0); d++) { + if (*dims[d].data()) { data_tmp = data_tmp.narrow(d + 1, 0, mask_sum(mask[i], d)); } } @@ -89,4 +94,5 @@ void initBatchTensorBindings(PyObject* module) { .def("get_dims", &BatchTensor::get_dims); } -}} // namespace torch::jit +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/batched/BatchTensor.h b/torch/csrc/jit/batched/BatchTensor.h index a7acd27..d74bf12 100644 --- a/torch/csrc/jit/batched/BatchTensor.h +++ b/torch/csrc/jit/batched/BatchTensor.h @@ -1,18 +1,19 @@ #pragma once +#include #include #include -#include #include #include -namespace torch { namespace jit { +namespace torch { +namespace jit { struct BatchTensor { -public: + public: BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims); // expand a tensor to a batchtensor given batch_size BatchTensor(const at::Tensor& data, int64_t batch_size); BatchTensor(const std::vector& datalist, at::Tensor dims); - const char * toString() const { + const char* toString() const { return "BatchTensor"; } at::IntList sizes() const { @@ -22,26 +23,29 @@ public: return data.dim(); } std::vector examples(); - at::Tensor get_data(){ + at::Tensor get_data() { return data; } - at::Tensor get_mask(){ + at::Tensor get_mask() { return mask; } - at::Tensor get_dims(){ + at::Tensor get_dims() { return dims; } -public: + public: // data is a Tensor whose size is the batch size in the batch dimension, // the size of all examples in static dimensions, - // and at least as large as the largest example in the batch in dynamic dimensions. + // and at least as large as the largest example in the batch in dynamic + // dimensions. at::Tensor data; // mask is a Tensor whose size is the batch size in the batch dimension, // one in static dimensions, - // and at least as large as the largest example in the batch in dynamic dimensions. - // Each entry in the mask corresponds to one or more entries in the data array (singleton, i.e., static, dimensions are broadcasted), - // with a one in the mask denoting that the corresponding data entries represent valid, meaningful data and a zero denoting that they do not. + // and at least as large as the largest example in the batch in dynamic + // dimensions. 
Each entry in the mask corresponds to one or more entries in + // the data array (singleton, i.e., static, dimensions are broadcasted), with + // a one in the mask denoting that the corresponding data entries represent + // valid, meaningful data and a zero denoting that they do not. at::Tensor mask; // dims is a 1-dimensional tensor with a bool for each non-batch dimension, // representing whether that dimension is static (False) or dynamic (True). @@ -49,4 +53,5 @@ public: }; void initBatchTensorBindings(PyObject* module); -}} // namespace torch::jit +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/catch_utils.hpp b/torch/csrc/jit/catch_utils.hpp index b9b0a87..9e7696b 100644 --- a/torch/csrc/jit/catch_utils.hpp +++ b/torch/csrc/jit/catch_utils.hpp @@ -3,6 +3,8 @@ #define CATCH_CONFIG_PREFIX_ALL #include -// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes warning; -// define our own version that doesn't warn. -#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ ) +// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes +// warning; define our own version that doesn't warn. +#define _CATCH_REQUIRE_THROWS(...) \ + INTERNAL_CATCH_THROWS( \ + "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__) diff --git a/torch/csrc/jit/code_template.h b/torch/csrc/jit/code_template.h index 63082ee..13871c1 100644 --- a/torch/csrc/jit/code_template.h +++ b/torch/csrc/jit/code_template.h @@ -1,10 +1,11 @@ #pragma once +#include #include -#include #include -#include +#include -namespace torch { namespace jit { +namespace torch { +namespace jit { // A template environment is a mapping from template variable names, e.g., // identifier (corresponding to $identifier) to their expansions. @@ -14,31 +15,29 @@ namespace torch { namespace jit { // in the top level environment, and then recurses into a parent // environment if the key is not found.) struct TemplateEnv { - TemplateEnv() - : parent(nullptr) {} - TemplateEnv(TemplateEnv & parent) - : parent(&parent) {} + TemplateEnv() : parent(nullptr) {} + TemplateEnv(TemplateEnv& parent) : parent(&parent) {} using string_list = std::vector; // Add a string 'v' to the map at key 'k'. - void s(const std::string & k, const std::string & v) { + void s(const std::string& k, const std::string& v) { strings_[k] = v; lists_.erase(k); } // Add a number 'v' to the map at key 'k' - template - void d(const std::string & k, const T & v) { + template + void d(const std::string& k, const T& v) { strings_[k] = std::to_string(v); lists_.erase(k); } // Retrieve the string representation of the value stored at 'k' from the map. // Raises an exception if the key is not found. - const std::string & s(const std::string & k) const { - if(strings_.count(k) == 0) { - if(parent) { + const std::string& s(const std::string& k) const { + if (strings_.count(k) == 0) { + if (parent) { return parent->s(k); } notFound(k); @@ -47,16 +46,16 @@ struct TemplateEnv { } // Store a list of strings 'v' in the map at 'k'. - void v(const std::string & k, const string_list & v) { + void v(const std::string& k, const string_list& v) { lists_[k] = v; strings_.erase(k); } // Retrieve a list of strings stored at 'k' from the map. // Raises an exception if the key is not found. 
- const string_list & v(const std::string & k) const { - if(lists_.count(k) == 0) { - if(parent) { + const string_list& v(const std::string& k) const { + if (lists_.count(k) == 0) { + if (parent) { return parent->v(k); } notFound(k); @@ -65,25 +64,25 @@ struct TemplateEnv { } // Test if a string 'k' is a string (as opposed to a list.) - bool keyIsString(const std::string & k) const { - if(strings_.count(k) > 0) + bool keyIsString(const std::string& k) const { + if (strings_.count(k) > 0) return true; - if(lists_.count(k) > 0) + if (lists_.count(k) > 0) return false; - if(parent) + if (parent) return parent->keyIsString(k); notFound(k); } -private: - [[ noreturn ]] - void notFound(const std::string & k) const { + + private: + [[noreturn]] void notFound(const std::string& k) const { std::stringstream ss; ss << "key not found: " << k; throw std::logic_error(ss.str()); } - std::unordered_map strings_; - std::unordered_map lists_; - TemplateEnv * parent; + std::unordered_map strings_; + std::unordered_map lists_; + TemplateEnv* parent; }; /* @@ -96,30 +95,29 @@ private: # if this list is not empty and ${foo,} will insert one after. */ struct CodeTemplate { - /* implicit */ CodeTemplate(std::string t) - : template_text(std::move(t)) {} + /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {} - std::string format(const TemplateEnv & env) { + std::string format(const TemplateEnv& env) { std::stringstream out; size_t pos = 0; size_t indent = 0; bool all_whitespace = true; - while(pos < template_text.size()) { + while (pos < template_text.size()) { char c = template_text[pos]; - if(c == '$') { + if (c == '$') { std::stringstream kss; bool comma_before; bool comma_after; - size_t new_pos = parseKey(pos,kss,comma_before,comma_after); + size_t new_pos = parseKey(pos, kss, comma_before, comma_after); std::string k = kss.str(); bool is_string = env.keyIsString(k); - if(all_whitespace) { - if(is_string) + if (all_whitespace) { + if (is_string) emitStringWithIndents(out, indent, env.s(k)); else emitLinesIndented(out, indent, env.v(k)); } else { - if(is_string) + if (is_string) out << env.s(k); else emitCommaSeparatedList(out, env.v(k), comma_before, comma_after); @@ -128,10 +126,10 @@ struct CodeTemplate { pos = new_pos; } else { out << c; - if(!isspace(c)) + if (!isspace(c)) all_whitespace = false; indent++; - if(c == '\n') { + if (c == '\n') { indent = 0; all_whitespace = true; } @@ -140,29 +138,34 @@ struct CodeTemplate { } return out.str(); } -private: + + private: using string_list = std::vector; char charAt(size_t p) { if (p >= template_text.size()) throw std::logic_error("EOS found in key"); return template_text[p]; } - size_t parseKey(size_t pos, std::ostream & k, bool & comma_before, bool & comma_after) { + size_t parseKey( + size_t pos, + std::ostream& k, + bool& comma_before, + bool& comma_after) { comma_before = false; comma_after = false; pos++; - if(charAt(pos) == '{') { + if (charAt(pos) == '{') { pos++; - if(charAt(pos) == ',') { + if (charAt(pos) == ',') { comma_before = true; pos++; } pos = parseIdent(pos, k); - if(charAt(pos) == ',') { + if (charAt(pos) == ',') { comma_after = true; pos++; } - if(charAt(pos) != '}') + if (charAt(pos) != '}') throw std::logic_error("missing terminating '}'"); pos++; return pos; @@ -170,55 +173,66 @@ private: return parseIdent(pos, k); } } - size_t parseIdent(size_t pos, std::ostream & k) { - while(pos < template_text.size() && - (isalnum(template_text[pos]) || template_text[pos] == '_')) { + size_t parseIdent(size_t pos, std::ostream& 
k) { + while (pos < template_text.size() && + (isalnum(template_text[pos]) || template_text[pos] == '_')) { k << template_text[pos]; pos++; } return pos; } - void emitCommaSeparatedList(std::ostream & out, const string_list & strings, bool comma_before, bool comma_after) { - if(comma_before && strings.size() > 0) + void emitCommaSeparatedList( + std::ostream& out, + const string_list& strings, + bool comma_before, + bool comma_after) { + if (comma_before && strings.size() > 0) out << ", "; - for(size_t i = 0; i < strings.size(); ++i) { - if(i > 0) + for (size_t i = 0; i < strings.size(); ++i) { + if (i > 0) out << ", "; out << strings[i]; } - if(comma_after && strings.size() > 0) + if (comma_after && strings.size() > 0) out << ", "; } // These indentation functions follow the convention that they never emit // leading or trailing newlines when the input string does not have leading // or trailing newlines. It's the responsibility of the calling function // to indent correctly in the context. - void emitIndent(std::ostream & out, size_t indent) { - for(size_t i = 0; i < indent; ++i) { + void emitIndent(std::ostream& out, size_t indent) { + for (size_t i = 0; i < indent; ++i) { out << " "; } } - void emitStringWithIndents(std::ostream & out, size_t indent, const std::string & str) { - for(auto c : str) { + void emitStringWithIndents( + std::ostream& out, + size_t indent, + const std::string& str) { + for (auto c : str) { out << c; - if(c == '\n') { + if (c == '\n') { emitIndent(out, indent); } } } - void emitLinesIndented(std::stringstream & out, size_t indent, const string_list & strings) { - for(size_t i = 0; i < strings.size(); ++i) { - if(i > 0) + void emitLinesIndented( + std::stringstream& out, + size_t indent, + const string_list& strings) { + for (size_t i = 0; i < strings.size(); ++i) { + if (i > 0) emitIndent(out, indent); - emitStringWithIndents(out,indent,strings[i]); - if(i+1 != strings.size()) + emitStringWithIndents(out, indent, strings[i]); + if (i + 1 != strings.size()) out << "\n"; } } std::string template_text; }; -static inline std::string format(const std::string & fmt, TemplateEnv & env) { +static inline std::string format(const std::string& fmt, TemplateEnv& env) { return CodeTemplate(fmt).format(env); } -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp index c1d5884..6a3aded 100644 --- a/torch/csrc/jit/constants.cpp +++ b/torch/csrc/jit/constants.cpp @@ -1,10 +1,11 @@ +#include #include -#include #include -#include +#include #include -namespace torch { namespace jit { +namespace torch { +namespace jit { // IValue -> Constant node Value* insertConstant( @@ -12,22 +13,23 @@ Value* insertConstant( const IValue& val, c10::optional loc, c10::optional scope) { - Node * n = g.create(prim::Constant); - if(val.isTensor()) { + Node* n = g.create(prim::Constant); + if (val.isTensor()) { at::Tensor ref = val.toTensor(); - if(!ref.defined()) { + if (!ref.defined()) { n->destroy(); return g.insertNode(g.createUndefined())->output(); } if (ref.is_variable()) { ref = autograd::Variable(ref).data(); } - n->output()->inferTypeFrom(ref); // note: before t_ because of std::move(ref) + n->output()->inferTypeFrom( + ref); // note: before t_ because of std::move(ref) n->t_(attr::value, std::move(ref)); - } else if(val.isInt()) { + } else if (val.isInt()) { n->i_(attr::value, val.toInt()); n->output()->setType(IntType::get()); - } else if(val.isDouble()) { + } else if (val.isDouble()) { n->f_(attr::value, val.toDouble()); 
n->output()->setType(FloatType::get()); } else if (val.isBool()) { @@ -35,110 +37,120 @@ Value* insertConstant( n->output()->setType(BoolType::get()); } else if (val.isBoolList()) { auto bool_list = val.toBoolList()->elements(); - n->is_(attr::value, std::vector(bool_list.begin(), bool_list.end())); + n->is_( + attr::value, std::vector(bool_list.begin(), bool_list.end())); n->output()->setType(ListType::ofBools()); - } else if(val.isIntList()) { + } else if (val.isIntList()) { n->is_(attr::value, val.toIntList()->elements()); n->output()->setType(ListType::ofInts()); - } else if(val.isTensorList()) { - n->ts_(attr::value, fmap(val.toTensorList()->elements(), [](const at::Tensor & t) { - return autograd::Variable(t).data(); - })); + } else if (val.isTensorList()) { + n->ts_( + attr::value, + fmap(val.toTensorList()->elements(), [](const at::Tensor& t) { + return autograd::Variable(t).data(); + })); n->output()->setType(ListType::ofTensors()); - } else if(val.isString()) { + } else if (val.isString()) { n->s_(attr::value, val.toString()->string()); n->output()->setType(StringType::get()); - } else if(val.isDevice()) { + } else if (val.isDevice()) { std::stringstream ss; ss << val.toDevice(); n->s_(attr::value, ss.str()); n->output()->setType(DeviceObjType::get()); - } else if(val.isNone()) { + } else if (val.isNone()) { n->destroy(); n = g.create(prim::None); n->output()->setType(NoneType::get()); } else { - throw constant_not_supported_error("Unsupported value kind: " + val.tagKind()); + throw constant_not_supported_error( + "Unsupported value kind: " + val.tagKind()); } - if(loc) + if (loc) n->setSourceLocation(std::make_shared(*loc)); - if(scope) + if (scope) n->setScope(*scope); return g.insertNode(n)->output(); } RegisterOperators reg({ - // Implementation of constant node, computes and IValue - Operator( - FunctionSchema(prim::Constant, {}, {}, /*is_vararg=*/false, /*is_varret=*/true), - [](const Node* node) -> Operation { - TypePtr type = node->output()->type(); - if(type->isSubtypeOf(DynamicType::get())) { - auto t = autograd::make_variable(node->t(attr::value)); - return [t](Stack& stack) { - push(stack, t); - return 0; - }; - } else if (type->isSubtypeOf(BoolType::get())) { - bool b = node->i(attr::value); - return [b](Stack& stack) { - push(stack, b); - return 0; - }; - } else if ( - type->isSubtypeOf(NumberType::get()) && - node->kindOf(attr::value) == AttributeKind::i) { - auto i = node->i(attr::value); - return [i](Stack& stack) { - push(stack, i); - return 0; - }; - } else if ( - type->isSubtypeOf(NumberType::get()) && - node->kindOf(attr::value) == AttributeKind::f) { - auto f = node->f(attr::value); - return [f](Stack& stack) { - push(stack, f); - return 0; - }; - } else if(type->isSubtypeOf(ListType::ofInts())) { - const auto& is = node->is(attr::value); - return [is](Stack& stack) { - push(stack, is); - return 0; - }; - } else if(type->isSubtypeOf(ListType::ofBools())) { - const auto& bs = node->is(attr::value); - return [bs](Stack& stack) { - push(stack, bs); - return 0; - }; - } else if(type->isSubtypeOf(ListType::ofTensors())) { - const auto& ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor { - return autograd::make_variable(t); - }); - return [ts](Stack& stack) { - push(stack, ts); - return 0; - }; - } else if (type == StringType::get()) { - const auto& s = node->s(attr::value); - return [s](Stack& stack) { - push(stack, s); - return 0; - }; - } else if (type == DeviceObjType::get()) { - auto d = c10::Device(node->s(attr::value)); - return 
[d](Stack& stack) { - push(stack, d); - return 0; - }; - } else { - std::stringstream ss; - ss << "constant literal not supported for: " << type->str(); - throw std::runtime_error(ss.str()); - } - }), + // Implementation of constant node, computes and IValue + Operator( + FunctionSchema( + prim::Constant, + {}, + {}, + /*is_vararg=*/false, + /*is_varret=*/true), + [](const Node* node) -> Operation { + TypePtr type = node->output()->type(); + if (type->isSubtypeOf(DynamicType::get())) { + auto t = autograd::make_variable(node->t(attr::value)); + return [t](Stack& stack) { + push(stack, t); + return 0; + }; + } else if (type->isSubtypeOf(BoolType::get())) { + bool b = node->i(attr::value); + return [b](Stack& stack) { + push(stack, b); + return 0; + }; + } else if ( + type->isSubtypeOf(NumberType::get()) && + node->kindOf(attr::value) == AttributeKind::i) { + auto i = node->i(attr::value); + return [i](Stack& stack) { + push(stack, i); + return 0; + }; + } else if ( + type->isSubtypeOf(NumberType::get()) && + node->kindOf(attr::value) == AttributeKind::f) { + auto f = node->f(attr::value); + return [f](Stack& stack) { + push(stack, f); + return 0; + }; + } else if (type->isSubtypeOf(ListType::ofInts())) { + const auto& is = node->is(attr::value); + return [is](Stack& stack) { + push(stack, is); + return 0; + }; + } else if (type->isSubtypeOf(ListType::ofBools())) { + const auto& bs = node->is(attr::value); + return [bs](Stack& stack) { + push(stack, bs); + return 0; + }; + } else if (type->isSubtypeOf(ListType::ofTensors())) { + const auto& ts = fmap( + node->ts(attr::value), [](const at::Tensor& t) -> at::Tensor { + return autograd::make_variable(t); + }); + return [ts](Stack& stack) { + push(stack, ts); + return 0; + }; + } else if (type == StringType::get()) { + const auto& s = node->s(attr::value); + return [s](Stack& stack) { + push(stack, s); + return 0; + }; + } else if (type == DeviceObjType::get()) { + auto d = c10::Device(node->s(attr::value)); + return [d](Stack& stack) { + push(stack, d); + return 0; + }; + } else { + std::stringstream ss; + ss << "constant literal not supported for: " << type->str(); + throw std::runtime_error(ss.str()); + } + }), }); c10::optional toIValue(const Value* v) { @@ -151,4 +163,5 @@ c10::optional toIValue(const Value* v) { op(stack); return stack.back(); } -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/constants.h b/torch/csrc/jit/constants.h index d64bc15..3a787f0 100644 --- a/torch/csrc/jit/constants.h +++ b/torch/csrc/jit/constants.h @@ -1,13 +1,14 @@ #pragma once +#include #include #include #include -#include // helpers for handling constants in the IR // - create constant nodes from ints, floats, intlist, Tensors, and other types // - implement primitive constant ops. 
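A rough sketch of the helpers this header declares, assuming a Graph has already been constructed elsewhere (the graph setup below is illustrative only):

    auto graph = std::make_shared<Graph>();
    Value* five = insertConstant(*graph, IValue(static_cast<int64_t>(5)));  // prim::Constant, IntType output
    Value* pi = insertConstant(*graph, IValue(3.14));                       // FloatType output
    if (auto iv = toIValue(five)) {
      int64_t unpacked = iv->toInt();  // folds the constant node back into an IValue, yields 5
    }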
-namespace torch { namespace jit { +namespace torch { +namespace jit { struct Graph; struct Value; @@ -19,14 +20,14 @@ struct TORCH_API constant_not_supported_error : public std::runtime_error { // note: prefer g.insertConsant(val, loc) which does exactly the same thing // this function is only declared/defined here because its implementation is -// closely related to the implementation of prim::Constant that is also in constants.cpp +// closely related to the implementation of prim::Constant that is also in +// constants.cpp TORCH_API Value* insertConstant( Graph& g, const IValue& val, c10::optional loc = c10::nullopt, c10::optional scope = c10::nullopt); - ////////////////////////////////////////////////////////////////////////////////// // Helper for retrieving constants ////////////////////////////////////////////////////////////////////////////////// @@ -39,9 +40,10 @@ TORCH_API c10::optional toIValue(const Value* v); // same rules as the interpreter template c10::optional constant_as(const Value* v) { - if(auto ivalue = toIValue(v)) { + if (auto ivalue = toIValue(v)) { return ivalue->to(); } return c10::nullopt; } -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/dynamic_dag.h b/torch/csrc/jit/dynamic_dag.h index 0e786c7..79dce52 100644 --- a/torch/csrc/jit/dynamic_dag.h +++ b/torch/csrc/jit/dynamic_dag.h @@ -9,7 +9,9 @@ #include #include -namespace torch { namespace jit { namespace detail { +namespace torch { +namespace jit { +namespace detail { // DynamicDAG is a simple directed acyclic graph that dynamically maintains a // topological order as edges/vertices are added and removed. @@ -19,12 +21,12 @@ namespace torch { namespace jit { namespace detail { // merge black nodes that are directly connected by contracting the // edge between them while still maintaining the DAG and a topological order? // Use contractEdge(). -// - Let's say you have a DAG where each vertex is a Node* and the edges represent -// data dependencies. We wish to determine if adding a new Node* with certain -// data dependencies (or moving an existing one to use new dependencies) is valid. -// Use DynamicDAG::addEdge() to add the new data dependencies to the DAG: -// it will either find a valid reordering of the DAG's topological order or throw -// if the resulting DAG is invalid. +// - Let's say you have a DAG where each vertex is a Node* and the edges +// represent data dependencies. We wish to determine if adding a new Node* +// with certain data dependencies (or moving an existing one to use new +// dependencies) is valid. Use DynamicDAG::addEdge() to add the new data +// dependencies to the DAG: it will either find a valid reordering of the +// DAG's topological order or throw if the resulting DAG is invalid. 
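A small sketch of the usage pattern described in the comment above; newVertex and the Node* values are assumptions made for illustration, the real declaration of DynamicDAG appears further down in this file:

    DynamicDAG<Node*> dag;
    Vertex<Node*>* producer = dag.newVertex(node_a);  // node_a, node_b: hypothetical Node*
    Vertex<Node*>* consumer = dag.newVertex(node_b);
    dag.addEdge(producer, consumer);                  // keeps a valid topological order, throws on a cycle
    if (dag.contractEdge(producer, consumer)) {
      // contraction kept the DAG acyclic; the surviving vertex carries both data entries
    }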
// // The implementation is based off of the PK algorithm in the following paper: // "A Dynamic Topsort Algorithm for Directed Acyclic Graphs" @@ -32,12 +34,16 @@ namespace torch { namespace jit { namespace detail { // https://www.doc.ic.ac.uk/~phjk/Publications/DynamicTopoSortAlg-JEA-07.pdf // It is summarized in [Edge addition] (see DynamicDAG::addEdge) -template struct Vertex; -template struct DynamicDAG; -template using vertex_list = std::vector*>; -template using unique_vertex = std::unique_ptr>; +template +struct Vertex; +template +struct DynamicDAG; +template +using vertex_list = std::vector*>; +template +using unique_vertex = std::unique_ptr>; -enum class DFSDirection {forward, backward}; +enum class DFSDirection { forward, backward }; // Used to represent adjacency lists in DynamicDAG. // Has set semantics: stores distinct elements. @@ -71,11 +77,21 @@ struct vertex_set { return a->ord < b->ord; }); } - size_t size() const { return data_.size(); } - iterator begin() { return data_.begin(); } - iterator end() { return data_.end(); } - reverse_iterator rbegin() { return data_.rbegin(); } - reverse_iterator rend() { return data_.rend(); } + size_t size() const { + return data_.size(); + } + iterator begin() { + return data_.begin(); + } + iterator end() { + return data_.end(); + } + reverse_iterator rbegin() { + return data_.rbegin(); + } + reverse_iterator rend() { + return data_.rend(); + } private: std::vector*> data_; @@ -110,7 +126,9 @@ struct visited_list { }); } - const vertex_list& vector() { return data_; } + const vertex_list& vector() { + return data_; + } private: vertex_list data_; @@ -118,20 +136,29 @@ struct visited_list { template struct Vertex { - Vertex(size_t ord, T datum) - : ord(ord), visited_(false) { data.push_back(datum); } + Vertex(size_t ord, T datum) : ord(ord), visited_(false) { + data.push_back(datum); + } std::vector data; size_t ord; // unique topological index std::string toString(); - vertex_set& in_edges() { return edges_.in_edges; } - vertex_set& out_edges() { return edges_.out_edges; } - IOEdges&& move_edges() { return std::move(edges_); } + vertex_set& in_edges() { + return edges_.in_edges; + } + vertex_set& out_edges() { + return edges_.out_edges; + } + IOEdges&& move_edges() { + return std::move(edges_); + } - bool visited() { return visited_; } + bool visited() { + return visited_; + } -private: + private: IOEdges edges_; friend visited_list; @@ -149,7 +176,9 @@ struct DynamicDAG { // max_size() >= the number of live vertices. // for all vertices v, v.ord < max_size() - size_t max_size() const { return vertices_.size(); }; + size_t max_size() const { + return vertices_.size(); + }; c10::optional*> at(size_t ord) const; std::string toString(); @@ -179,9 +208,10 @@ struct DynamicDAG { // O(vertices_.size()). Used for testing, don't call this often. 
template size_t DynamicDAG::debugNumVertices() const { - return std::count_if(vertices_.begin(), vertices_.end(), - [](const unique_vertex& v) { - if (v) return true; + return std::count_if( + vertices_.begin(), vertices_.end(), [](const unique_vertex& v) { + if (v) + return true; return false; }); } @@ -205,7 +235,8 @@ template void DynamicDAG::debugCheckInvariants() { for (size_t ord = 0; ord < vertices_.size(); ++ord) { const auto& vertex = vertices_.at(ord); - if (!vertex) continue; + if (!vertex) + continue; AT_ASSERTM(vertex->ord == ord, toString()); for (auto* v : vertex->in_edges()) { @@ -248,14 +279,15 @@ IOEdges DynamicDAG::removeVertex(Vertex* v) { * * Assume we are adding an edge x -> y and that ord(x) > ord(y). * First, if there is a path y ----> x through some other vertices, then this - * edge addition would create a cycle. Figure this out via DFS and throw if necessary. + * edge addition would create a cycle. Figure this out via DFS and throw if + * necessary. * * Now, consider the set of all vertices v such that ord(x) > ord(v) > ord(y). * Call this set the affected region (AR) -- these are the only vertices we * need to consider for reordering to make the resulting graph valid. * - * Find all children of y (through DFS) in AR (call this set deltaF and add y to it) - * Find all parents of x in AR (call this set deltaB and add x to it). + * Find all children of y (through DFS) in AR (call this set deltaF and add y to + * it) Find all parents of x in AR (call this set deltaB and add x to it). * * Move y and all the children of y to after x and all the parents of x. The * result topological ordering is valid. @@ -291,7 +323,8 @@ IOEdges DynamicDAG::removeVertex(Vertex* v) { * deltaB (sorted) = {c(2), d(4), x(6)}. deltaB ords = { 2, 4, 6 } * deltaF (sorted) = {y(1), a(3), b(5)}. deltaF ords = { 1, 3, 5 } * - * 2) append the two lists: the result is the order we want these vertices to have. + * 2) append the two lists: the result is the order we want these vertices to + * have. * L = {c(2), d(4), x(6), y(1), a(3), b(5)}. * * 3) Merge the sorted ords: R = { 1, 2, 3, 4, 5, 6 }. @@ -304,9 +337,9 @@ IOEdges DynamicDAG::removeVertex(Vertex* v) { * * [Analysis] * This is O(|AR| log |AR|). |AR| is equal to ord(consumer) - ord(producer). - * AR is the "affected region": { v s.t. ord(v) in [ord(producer), ord(consumer)] } - * consisting of the only vertices that can possibly be moved around due to this - * edge addition. + * AR is the "affected region": { v s.t. ord(v) in [ord(producer), + * ord(consumer)] } consisting of the only vertices that can possibly be moved + * around due to this edge addition. * * NB: Pearce and Kelly give a complexity bound of <> where * delta = union(deltaF, deltaB) and <> on a set S is @@ -316,9 +349,11 @@ template void DynamicDAG::addEdge(Vertex* producer, Vertex* consumer) { JIT_ASSERT(producer != consumer); - // NB: DynamicDAG is a simple graph. If an edge exists already, don't do anything. + // NB: DynamicDAG is a simple graph. If an edge exists already, don't do + // anything. bool is_distinct = producer->out_edges().insert(consumer); - if (!is_distinct) return; + if (!is_distinct) + return; is_distinct = consumer->in_edges().insert(producer); JIT_ASSERT(is_distinct); @@ -330,18 +365,26 @@ void DynamicDAG::addEdge(Vertex* producer, Vertex* consumer) { visited_list deltaF; visited_list deltaB; - // Search for vertices that are reachable from consumer that have a now incorrect - // topological ordering. 
- if (dfsSearch(DFSDirection::forward, consumer, producer, - /*bound=*/producer->ord, deltaF)) { + // Search for vertices that are reachable from consumer that have a now + // incorrect topological ordering. + if (dfsSearch( + DFSDirection::forward, + consumer, + producer, + /*bound=*/producer->ord, + deltaF)) { // Path found! This means there's a cycle. AT_ERROR("Cycle detected while trying to add edge."); } // Search for vertices that can reach producer that have a now incorrect // topological ordering - JIT_ASSERT(!dfsSearch(DFSDirection::backward, producer, consumer, - /*bound=*/consumer->ord, deltaB)); + JIT_ASSERT(!dfsSearch( + DFSDirection::backward, + producer, + consumer, + /*bound=*/consumer->ord, + deltaB)); // Reorder the vertices that are reachable from consumer to occur BEFORE // the vertices that can reach producer. @@ -353,7 +396,8 @@ void DynamicDAG::addEdge(Vertex* producer, Vertex* consumer) { // These are the only vertices that can possibly be moved around // during edge contraction. // -// contractEdge is O(|AR| log |AR| * min(|out_edges(producer)|, |in_edges(consumer)|)) +// contractEdge is O(|AR| log |AR| * min(|out_edges(producer)|, +// |in_edges(consumer)|)) template bool DynamicDAG::contractEdge(Vertex* producer, Vertex* consumer) { JIT_ASSERT(producer != consumer); @@ -374,10 +418,13 @@ bool DynamicDAG::contractEdge(Vertex* producer, Vertex* consumer) { } template -void DynamicDAG::mergeProducerIntoConsumer(Vertex* producer, Vertex* consumer) { +void DynamicDAG::mergeProducerIntoConsumer( + Vertex* producer, + Vertex* consumer) { // Optimization: we want to concat lists [producer.data, consumer.data]. // Instead of inserting into the beginning of consumer.data, do a swap. - producer->data.insert(producer->data.end(), consumer->data.begin(), consumer->data.end()); + producer->data.insert( + producer->data.end(), consumer->data.begin(), consumer->data.end()); std::swap(consumer->data, producer->data); auto edges = removeVertex(producer); @@ -396,8 +443,11 @@ void DynamicDAG::mergeProducerIntoConsumer(Vertex* producer, Vertex* co } template -void DynamicDAG::mergeConsumerIntoProducer(Vertex* producer, Vertex* consumer) { - producer->data.insert(producer->data.end(), consumer->data.begin(), consumer->data.end()); +void DynamicDAG::mergeConsumerIntoProducer( + Vertex* producer, + Vertex* consumer) { + producer->data.insert( + producer->data.end(), consumer->data.begin(), consumer->data.end()); auto edges = removeVertex(consumer); @@ -412,11 +462,12 @@ void DynamicDAG::mergeConsumerIntoProducer(Vertex* producer, Vertex* co for (auto* parent : edges.in_edges) { addEdge(parent, producer); } - } template -bool DynamicDAG::contractionProducesCycle(Vertex* producer, Vertex* consumer) { +bool DynamicDAG::contractionProducesCycle( + Vertex* producer, + Vertex* consumer) { visited_list visited; // If there are multiple paths from producer to consumer then contracting @@ -426,17 +477,22 @@ bool DynamicDAG::contractionProducesCycle(Vertex* producer, Vertex* con // producer -> consumer edge. 
size_t upper_bound = consumer->ord; for (auto* child : producer->out_edges()) { - if (child == consumer) continue; - if (child->visited()) continue; // already visited by dfs - if (dfsSearch(DFSDirection::forward, child, consumer, upper_bound, visited)) { + if (child == consumer) + continue; + if (child->visited()) + continue; // already visited by dfs + if (dfsSearch( + DFSDirection::forward, child, consumer, upper_bound, visited)) { return true; } } return false; } - -static bool is_within_bound(DFSDirection direction, size_t value, size_t bound) { +static bool is_within_bound( + DFSDirection direction, + size_t value, + size_t bound) { if (direction == DFSDirection::forward) { return value < bound; // upper bound } else { @@ -467,9 +523,9 @@ bool DynamicDAG::dfsSearch( auto* vertex = stack.back(); stack.pop_back(); - auto& next_edges = (direction == DFSDirection::forward) ? - vertex->out_edges() : - vertex->in_edges(); + auto& next_edges = (direction == DFSDirection::forward) + ? vertex->out_edges() + : vertex->in_edges(); for (Vertex* next : next_edges) { if (next == end) { @@ -485,7 +541,6 @@ bool DynamicDAG::dfsSearch( return false; } - // Reorder deltaB vertices to occur before deltaF vertices. template void DynamicDAG::reorder(visited_list deltaF, visited_list deltaB) { @@ -508,7 +563,8 @@ void DynamicDAG::reorder(visited_list deltaF, visited_list deltaB) { } // Sort the ords by merging two already sorted lists into a large sorted list. - // input (example): deltaB = { v(1), v(4), v(7) } , deltaF = { v(0), v(2), v(5) }. + // input (example): deltaB = { v(1), v(4), v(7) } , + // deltaF = { v(0), v(2), v(5) }. // output: { 0, 1, 2, 4, 5, 7 }. std::vector gathered_ords; gathered_ords.reserve(num_affected); @@ -519,7 +575,10 @@ void DynamicDAG::reorder(visited_list deltaF, visited_list deltaB) { for (const auto* v : deltaF_) { gathered_ords.push_back(v->ord); } - std::inplace_merge(gathered_ords.begin(), gathered_ords.begin() + middle, gathered_ords.end()); + std::inplace_merge( + gathered_ords.begin(), + gathered_ords.begin() + middle, + gathered_ords.end()); // Return the vertices back into the vertices_ storage. 
for (size_t i = 0; i < num_affected; ++i) { @@ -555,7 +614,7 @@ std::string Vertex::toString() { ss << " " << d; } } - ss << "} ("<< ord << ") -> ["; + ss << "} (" << ord << ") -> ["; for (auto* c : out_edges()) { ss << c->ord << " "; } @@ -563,4 +622,6 @@ std::string Vertex::toString() { return ss.str(); } -}}} +} // namespace detail +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index f7207f7..d6cf849 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -1,15 +1,14 @@ #include #include -#include #include +#include #include -#include #include #include #include - +#include #include #include @@ -27,7 +26,8 @@ #include #include -namespace torch { namespace jit { +namespace torch { +namespace jit { namespace { namespace onnx_torch = ::torch::onnx; @@ -45,86 +45,104 @@ std::string getNodeStackTraceString(const Node* n) { return ss.str(); } -void validateBlock(Block *b, onnx_torch::OperatorExportTypes operator_export_type) { +void validateBlock( + Block* b, + onnx_torch::OperatorExportTypes operator_export_type) { for (auto node : b->nodes()) { - for (Block *sub_block : node->blocks()) { + for (Block* sub_block : node->blocks()) { validateBlock(sub_block, operator_export_type); } // Macro'ed so we get a marginally better line number on failed export -#define FAIL_EXPORT(name) \ - throw std::runtime_error(std::string("ONNX export failed: ") + name + "\n\nGraph we tried to export:\n" + b->owningGraph()->toString()); +#define FAIL_EXPORT(name) \ + throw std::runtime_error( \ + std::string("ONNX export failed: ") + name + \ + "\n\nGraph we tried to export:\n" + b->owningGraph()->toString()); IR_IF(node, PythonOp) - auto py_node = static_cast(value); - FAIL_EXPORT( - "Couldn't export Python operator " + py_node->name() + - "\n\nDefined at:\n" + getNodeStackTraceString(node)) + auto py_node = static_cast(value); + FAIL_EXPORT( + "Couldn't export Python operator " + py_node->name() + + "\n\nDefined at:\n" + getNodeStackTraceString(node)) IR_ELSE() - // Special error messages for certain types of operators - if (node->kind() == aten::expand) { - if (operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) { - WithInsertPoint guard(node); - auto* new_node = b->owningGraph()->insertNode( - b->owningGraph()->create(Symbol(::torch::jit::onnx::ATen), node->inputs(), node->outputs().size())); - for (size_t i = 0; i < node->outputs().size(); ++i) { - node->output(i)->replaceAllUsesWith(new_node->output(i)); - } - new_node->s_(Symbol::fromQualString("attr::operator"), "expand"); + // Special error messages for certain types of operators + if (node->kind() == aten::expand) { + if (operator_export_type == + onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) { + WithInsertPoint guard(node); + auto* new_node = b->owningGraph()->insertNode(b->owningGraph()->create( + Symbol(::torch::jit::onnx::ATen), + node->inputs(), + node->outputs().size())); + for (size_t i = 0; i < node->outputs().size(); ++i) { + node->output(i)->replaceAllUsesWith(new_node->output(i)); } + new_node->s_(Symbol::fromQualString("attr::operator"), "expand"); } - if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) { - FAIL_EXPORT( - "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" + - getNodeStackTraceString(node)); - } - bool is_aten_enabled = operator_export_type == - onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK 
|| - operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN; - if (!node->kind().is_onnx() && !is_aten_enabled && - node->kind() != prim::Undefined) { - FAIL_EXPORT( - "Couldn't export operator " + node->kind().toDisplayString() + "\n\nDefined at:\n" + - getNodeStackTraceString(node)); - } + } + if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) { + FAIL_EXPORT( + "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" + + getNodeStackTraceString(node)); + } + bool is_aten_enabled = operator_export_type == + onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK || + operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN; + if (!node->kind().is_onnx() && !is_aten_enabled && + node->kind() != prim::Undefined) { + FAIL_EXPORT( + "Couldn't export operator " + node->kind().toDisplayString() + + "\n\nDefined at:\n" + getNodeStackTraceString(node)); + } IR_END() #undef FAIL_EXPORT } } -void validateGraph(const std::shared_ptr& graph, onnx_torch::OperatorExportTypes operator_export_type) { +void validateGraph( + const std::shared_ptr& graph, + onnx_torch::OperatorExportTypes operator_export_type) { validateBlock(graph->block(), operator_export_type); EliminateDeadCode(graph->block()); } class EncoderBase { public: - EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, bool strip_doc); + EncoderBase( + onnx_torch::OperatorExportTypes operator_export_type, + bool strip_doc); onnx::ModelProto get_model_proto() { return model_proto_; } protected: - void EncodeGraph(onnx::GraphProto *graph_proto, - const std::shared_ptr &graph, - const std::vector &initializers = {}); + void EncodeGraph( + onnx::GraphProto* graph_proto, + const std::shared_ptr& graph, + const std::vector& initializers = {}); - void EncodeBlock(onnx::GraphProto *graph_proto, - const Block *block, - const std::vector &initializers = {}); + void EncodeBlock( + onnx::GraphProto* graph_proto, + const Block* block, + const std::vector& initializers = {}); virtual void EncodeTensor( onnx::TensorProto* tensor_proto, const at::Tensor& tensor, const c10::optional external_ref = {}) = 0; - virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto, - const Value* n) {}; + virtual void EncodeIntermediateValueInfo( + onnx::GraphProto* graph_proto, + const Value* n){}; - virtual void EncodeValueInfo(onnx::GraphProto *graph_proto, - onnx::ValueInfoProto* v, - const Value* n); + virtual void EncodeValueInfo( + onnx::GraphProto* graph_proto, + onnx::ValueInfoProto* v, + const Value* n); - void AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name); + void AddAttribute( + onnx::NodeProto* node_proto, + const jit::Node* node, + const jit::Symbol name); onnx::ModelProto model_proto_; size_t num_blocks_; @@ -133,7 +151,7 @@ class EncoderBase { }; onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) { - switch(at_type) { + switch (at_type) { case at::kDouble: return onnx::TensorProto_DataType_DOUBLE; case at::kFloat: @@ -155,7 +173,9 @@ onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) { } } -EncoderBase::EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, bool strip_doc) +EncoderBase::EncoderBase( + onnx_torch::OperatorExportTypes operator_export_type, + bool strip_doc) : num_blocks_(0), operator_export_type_(operator_export_type), strip_doc_(strip_doc) { @@ -165,7 +185,7 @@ 
EncoderBase::EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, b } void EncoderBase::EncodeValueInfo( - onnx::GraphProto *graph_proto, + onnx::GraphProto* graph_proto, onnx::ValueInfoProto* v, const Value* n) { v->set_name(n->uniqueName()); @@ -186,15 +206,16 @@ void EncoderBase::EncodeValueInfo( } void EncoderBase::EncodeGraph( - onnx::GraphProto *graph_proto, - const std::shared_ptr &graph, - const std::vector &initializers) { + onnx::GraphProto* graph_proto, + const std::shared_ptr& graph, + const std::vector& initializers) { EncodeBlock(graph_proto, graph->block(), initializers); } void EncoderBase::EncodeBlock( - onnx::GraphProto *graph_proto, const Block *block, - const std::vector &initializers) { + onnx::GraphProto* graph_proto, + const Block* block, + const std::vector& initializers) { JIT_ASSERT(graph_proto != nullptr); std::string block_name = "torch-jit-export"; if (num_blocks_) { @@ -212,7 +233,8 @@ void EncoderBase::EncodeBlock( EncodeValueInfo(graph_proto, v, output); } for (auto node : block->nodes()) { - bool is_raw_export = operator_export_type_ == onnx_torch::OperatorExportTypes::RAW; + bool is_raw_export = + operator_export_type_ == onnx_torch::OperatorExportTypes::RAW; if (node->kind() == prim::Undefined && !is_raw_export) { // Undefined nodes are used to implement optional inputs. One // way to "not provide" an optional input is to create an @@ -225,26 +247,25 @@ void EncoderBase::EncodeBlock( node->getSourceLocation()->highlight(ss); p_n->set_doc_string(ss.str()); } - for(auto input : node->inputs()) { + for (auto input : node->inputs()) { if (input->node()->kind() == prim::Undefined && !is_raw_export) { p_n->add_input(""); } else { p_n->add_input(input->uniqueName()); } } - for(auto output : node->outputs()) { + for (auto output : node->outputs()) { p_n->add_output(output->uniqueName()); EncodeIntermediateValueInfo(graph_proto, output); } if (is_raw_export) { JIT_ASSERT(!node->kind().is_onnx()); p_n->set_domain(node->kind().domainString()); - } - else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) { + } else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) { JIT_ASSERT(node->kind().is_onnx()); } p_n->set_op_type(node->kind().toUnqualString()); - for(auto attr_name : node->attributeNames()) { + for (auto attr_name : node->attributeNames()) { AddAttribute(p_n, node, attr_name); } if (is_raw_export && node->blocks().size() > 0) { @@ -284,7 +305,7 @@ void EncoderBase::EncodeBlock( auto num_initializers = initializers.size(); JIT_ASSERT(block->inputs().size() >= num_initializers); size_t inputs_count = block->inputs().size() - num_initializers; - for (auto & tensor : initializers) { + for (auto& tensor : initializers) { // TODO: stop using positions to determine which initializers // match to which inputs std::string name = graph_proto->input(inputs_count++).name(); @@ -294,18 +315,21 @@ void EncoderBase::EncodeBlock( } } -void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name) { +void EncoderBase::AddAttribute( + onnx::NodeProto* node_proto, + const jit::Node* node, + const jit::Symbol name) { auto attr = node_proto->add_attribute(); JIT_ASSERT(name.is_attr()); attr->set_name(name.toUnqualString()); - switch(node->kindOf(name)) { + switch (node->kindOf(name)) { case AttributeKind::f: attr->set_f(node->f(name)); attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); break; case AttributeKind::fs: attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); - 
for(auto & v : node->fs(name)) + for (auto& v : node->fs(name)) attr->add_floats(v); break; case AttributeKind::i: @@ -314,7 +338,7 @@ void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *nod break; case AttributeKind::is: attr->set_type(onnx::AttributeProto_AttributeType_INTS); - for(auto & v : node->is(name)) + for (auto& v : node->is(name)) attr->add_ints(v); break; case AttributeKind::s: @@ -323,7 +347,7 @@ void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *nod break; case AttributeKind::ss: attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); - for(auto & v : node->ss(name)) + for (auto& v : node->ss(name)) attr->add_strings(v); break; case AttributeKind::t: { @@ -333,7 +357,7 @@ void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *nod } break; case AttributeKind::ts: attr->set_type(onnx::AttributeProto_AttributeType_TENSORS); - for(auto & v : node->ts(name)) { + for (auto& v : node->ts(name)) { auto t = attr->add_tensors(); EncodeTensor(t, v); } @@ -345,7 +369,7 @@ void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *nod } break; case AttributeKind::gs: attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS); - for(auto & v : node->gs(name)) { + for (auto& v : node->gs(name)) { auto g = attr->add_graphs(); EncodeGraph(g, v); } @@ -355,14 +379,15 @@ void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *nod } } -class GraphEncoder: public EncoderBase { +class GraphEncoder : public EncoderBase { public: - GraphEncoder(const std::shared_ptr &graph, - int64_t onnx_opset_version, - onnx_torch::OperatorExportTypes operator_export_type, - const std::vector &initializers, - bool defer_weight_export, - bool strip_doc); + GraphEncoder( + const std::shared_ptr& graph, + int64_t onnx_opset_version, + onnx_torch::OperatorExportTypes operator_export_type, + const std::vector& initializers, + bool defer_weight_export, + bool strip_doc); RawDataExportMap get_raw_data_export_map() { return raw_data_export_map_; @@ -379,10 +404,10 @@ class GraphEncoder: public EncoderBase { }; GraphEncoder::GraphEncoder( - const std::shared_ptr &graph, + const std::shared_ptr& graph, int64_t onnx_opset_version, onnx_torch::OperatorExportTypes operator_export_type, - const std::vector &initializers, + const std::vector& initializers, bool defer_weight_export, bool strip_doc) : EncoderBase(operator_export_type, strip_doc), @@ -402,7 +427,7 @@ void GraphEncoder::EncodeTensor( onnx::TensorProto* tensor_proto, const at::Tensor& tensor, const c10::optional external_ref) { - for(auto d : tensor.sizes()) { + for (auto d : tensor.sizes()) { tensor_proto->add_dims(d); } tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType())); @@ -420,7 +445,9 @@ void GraphEncoder::EncodeTensor( tensor_proto->set_raw_data("__EXTERNAL"); } else { JIT_ASSERT(t.is_contiguous()); - tensor_proto->set_raw_data(std::string(static_cast(t.data_ptr()), t.type().elementSizeInBytes() * t.numel())); + tensor_proto->set_raw_data(std::string( + static_cast(t.data_ptr()), + t.type().elementSizeInBytes() * t.numel())); } } @@ -517,7 +544,8 @@ void ScriptModuleSerializer::convertModel( model_def->set_producer_version("1.0"); // TODO: set the producer version // using appropriate function call model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST); - convertModule(module, "", writer_.archiveName(), model_def->mutable_main_module()); + convertModule( + module, "", writer_.archiveName(), 
model_def->mutable_main_module()); writeTensorTable(model_def); } @@ -564,7 +592,8 @@ void ScriptModuleSerializer::convertAndWriteTensor( /* stride = */ {1}) .cpu(); AT_ASSERT( - storage_tensor.type().elementSizeInBytes() * storage_tensor.storage().size() == + storage_tensor.type().elementSizeInBytes() * + storage_tensor.storage().size() == record_size); } std::string name = "tensors/" + std::to_string(tensor_id); @@ -616,7 +645,8 @@ void ScriptModuleSerializer::convertModule( std::stringstream filename; filename << "code/" << module_name.str() << ".py"; std::string methods_str = methods.str(); - writer_.writeRecord(filename.str(), methods_str.c_str(), methods_str.size()); + writer_.writeRecord( + filename.str(), methods_str.c_str(), methods_str.size()); record->set_key(filename.str()); } @@ -656,7 +686,7 @@ void dump(const onnx::TensorProto& tensor, std::ostream& stream) { void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) { for (int i = 0; i < shape.dim_size(); ++i) { - auto &dim = shape.dim(i); + auto& dim = shape.dim(i); if (dim.has_dim_value()) { stream << dim.dim_value(); } else { @@ -676,15 +706,17 @@ void dump(const onnx::TypeProto& type, std::ostream& stream) { } void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) { - stream << "{name: \"" << value_info.name() - << "\", type:"; + stream << "{name: \"" << value_info.name() << "\", type:"; dump(value_info.type(), stream); stream << "}"; } void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent); -void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent) { +void dump( + const onnx::AttributeProto& attr, + std::ostream& stream, + size_t indent) { stream << "{ name: '" << attr.name() << "', type: "; if (attr.has_f()) { stream << "float, value: " << attr.f(); @@ -694,7 +726,7 @@ void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent) stream << "string, value: '" << attr.s() << "'"; } else if (attr.has_g()) { stream << "graph, value:\n"; - dump(attr.g(), stream, indent+1); + dump(attr.g(), stream, indent + 1); stream << nlidt(indent); } else if (attr.has_t()) { stream << "tensor, value:"; @@ -712,7 +744,8 @@ void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent) } else if (attr.strings_size()) { stream << "strings, values: ["; for (int i = 0; i < attr.strings_size(); ++i) - stream << "'" << attr.strings(i) << "'" << (i == attr.strings_size() - 1 ? "" : " "); + stream << "'" << attr.strings(i) << "'" + << (i == attr.strings_size() - 1 ? "" : " "); stream << "]"; } else if (attr.tensors_size()) { stream << "tensors, values: ["; @@ -723,7 +756,7 @@ void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent) } else if (attr.graphs_size()) { stream << "graphs, values: ["; for (auto& g : attr.graphs()) { - dump(g, stream, indent+1); + dump(g, stream, indent + 1); } stream << "]"; } else { @@ -743,58 +776,56 @@ void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) { } stream << "], attributes: ["; for (int i = 0; i < node.attribute_size(); ++i) { - dump(node.attribute(i), stream, indent+1); + dump(node.attribute(i), stream, indent + 1); stream << (i == node.attribute_size() - 1 ? 
"" : ","); } stream << "]}"; } void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) { - stream << idt(indent) << "GraphProto {" << nlidt(indent+1) - << "name: \"" << graph.name() << "\"" << nlidt(indent+1) - << "inputs: ["; + stream << idt(indent) << "GraphProto {" << nlidt(indent + 1) << "name: \"" + << graph.name() << "\"" << nlidt(indent + 1) << "inputs: ["; for (int i = 0; i < graph.input_size(); ++i) { dump(graph.input(i), stream); stream << (i == graph.input_size() - 1 ? "" : ","); } - stream << "]" << nlidt(indent+1) - << "outputs: ["; + stream << "]" << nlidt(indent + 1) << "outputs: ["; for (int i = 0; i < graph.output_size(); ++i) { dump(graph.output(i), stream); stream << (i == graph.output_size() - 1 ? "" : ","); } - stream << "]" << nlidt(indent+1) - << "initializers: ["; + stream << "]" << nlidt(indent + 1) << "initializers: ["; for (int i = 0; i < graph.initializer_size(); ++i) { dump(graph.initializer(i), stream); stream << (i == graph.initializer_size() - 1 ? "" : ","); } - stream << "]" << nlidt(indent+1) - << "nodes: [" << nlidt(indent+2); + stream << "]" << nlidt(indent + 1) << "nodes: [" << nlidt(indent + 2); for (int i = 0; i < graph.node_size(); ++i) { - dump(graph.node(i), stream, indent+2); - if (i != graph.node_size() - 1) stream << "," << nlidt(indent+2); + dump(graph.node(i), stream, indent + 2); + if (i != graph.node_size() - 1) + stream << "," << nlidt(indent + 2); } - stream << nlidt(indent+1) << "]\n" << idt(indent) << "}\n"; + stream << nlidt(indent + 1) << "]\n" << idt(indent) << "}\n"; } -void dump(const onnx::OperatorSetIdProto& operator_set_id, std::ostream& stream) { +void dump( + const onnx::OperatorSetIdProto& operator_set_id, + std::ostream& stream) { stream << "OperatorSetIdProto { domain: " << operator_set_id.domain() << "}"; } void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) { - stream << idt(indent) - << "ModelProto {" << nlidt(indent+1) - << "producer_name: \"" << model.producer_name() << "\"" << nlidt(indent+1) - << "domain: \"" << model.domain() << "\"" << nlidt(indent+1) - << "doc_string: \"" << model.doc_string() << "\""; + stream << idt(indent) << "ModelProto {" << nlidt(indent + 1) + << "producer_name: \"" << model.producer_name() << "\"" + << nlidt(indent + 1) << "domain: \"" << model.domain() << "\"" + << nlidt(indent + 1) << "doc_string: \"" << model.doc_string() << "\""; if (model.has_graph()) { - stream << nlidt(indent+1) << "graph:\n"; - dump(model.graph(), stream, indent+2); + stream << nlidt(indent + 1) << "graph:\n"; + dump(model.graph(), stream, indent + 2); } if (model.opset_import_size()) { - stream << idt(indent+1) << "opset_import: ["; - for (auto &opset_imp : model.opset_import()) { + stream << idt(indent + 1) << "opset_import: ["; + for (auto& opset_imp : model.opset_import()) { dump(opset_imp, stream); } stream << "],\n"; @@ -811,14 +842,19 @@ std::string prettyPrint(const onnx::ModelProto& model) { } // namespace std::string pretty_print_onnx( - const std::shared_ptr &graph, - const std::vector &initializers, - int64_t onnx_opset_version, - bool defer_weight_export, - ::torch::onnx::OperatorExportTypes operator_export_type, - bool google_printer) { + const std::shared_ptr& graph, + const std::vector& initializers, + int64_t onnx_opset_version, + bool defer_weight_export, + ::torch::onnx::OperatorExportTypes operator_export_type, + bool google_printer) { auto graph_encoder = GraphEncoder( - graph, onnx_opset_version, operator_export_type, initializers, 
defer_weight_export, true); + graph, + onnx_opset_version, + operator_export_type, + initializers, + defer_weight_export, + true); if (google_printer) { return graph_encoder.get_model_proto().DebugString(); } @@ -831,15 +867,21 @@ std::string pretty_print_onnx( // be interpretable by a ONNX-compatible framework. However, PyTorch or // libtorch will be able to import the IR and play it back. std::tuple export_onnx( - const std::shared_ptr &graph, - const std::vector &initializers, - int64_t onnx_opset_version, - bool defer_weight_export, - ::torch::onnx::OperatorExportTypes operator_export_type) { + const std::shared_ptr& graph, + const std::vector& initializers, + int64_t onnx_opset_version, + bool defer_weight_export, + ::torch::onnx::OperatorExportTypes operator_export_type) { auto graph_encoder = GraphEncoder( - graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export, false); - return std::make_tuple(graph_encoder.get_model_proto().SerializeAsString(), - graph_encoder.get_raw_data_export_map()); + graph, + onnx_opset_version, + operator_export_type, + initializers, + defer_weight_export, + false); + return std::make_tuple( + graph_encoder.get_model_proto().SerializeAsString(), + graph_encoder.get_raw_data_export_map()); } void ExportModule(const script::Module& module, std::ostream& out) { @@ -847,9 +889,10 @@ void ExportModule(const script::Module& module, std::ostream& out) { serializer.serialize(module); } -void ExportModule(const script::Module& module, const std::string &filename) { +void ExportModule(const script::Module& module, const std::string& filename) { ScriptModuleSerializer serializer(filename); serializer.serialize(module); } -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h index 357f2ae..1e274d8 100644 --- a/torch/csrc/jit/export.h +++ b/torch/csrc/jit/export.h @@ -6,7 +6,8 @@ #include -namespace torch { namespace jit { +namespace torch { +namespace jit { // This map is used to keep track of parameters that should be exported // externally. 
When `defer_weight_export` is true, the returned map contains @@ -23,25 +24,24 @@ TORCH_API std::tuple export_onnx( const std::vector& initializers, int64_t onnx_opset_version, bool defer_weight_export = false, - ::torch::onnx::OperatorExportTypes operator_export_type - = ::torch::onnx::OperatorExportTypes::ONNX); + ::torch::onnx::OperatorExportTypes operator_export_type = + ::torch::onnx::OperatorExportTypes::ONNX); // For testing purposes TORCH_API std::string pretty_print_onnx( const std::shared_ptr& graph, - const std::vector & initializers, + const std::vector& initializers, int64_t onnx_opset_version, bool defer_weight_export, - ::torch::onnx::OperatorExportTypes operator_export_type - = ::torch::onnx::OperatorExportTypes::ONNX, + ::torch::onnx::OperatorExportTypes operator_export_type = + ::torch::onnx::OperatorExportTypes::ONNX, bool google_printer = false); -TORCH_API void ExportModule( - const script::Module& module, - std::ostream& out); +TORCH_API void ExportModule(const script::Module& module, std::ostream& out); TORCH_API void ExportModule( const script::Module& module, const std::string& filename); -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/function_schema.h b/torch/csrc/jit/function_schema.h index a3a6110..350783c 100644 --- a/torch/csrc/jit/function_schema.h +++ b/torch/csrc/jit/function_schema.h @@ -1,8 +1,10 @@ #include -namespace torch { namespace jit { +namespace torch { +namespace jit { -using ::c10::FunctionSchema; using ::c10::Argument; +using ::c10::FunctionSchema; -}} // namespace torch::jit +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/fuser/arg_spec.h b/torch/csrc/jit/fuser/arg_spec.h index e8c7352..d099395 100644 --- a/torch/csrc/jit/fuser/arg_spec.h +++ b/torch/csrc/jit/fuser/arg_spec.h @@ -4,14 +4,16 @@ #include #include +#include #include // fmap #include -#include -#include #include +#include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // Describes the (runtime) arguments to a kernel. // ArgSpecs are also used as keys to lookup instantiated kernels, so @@ -19,22 +21,19 @@ namespace torch { namespace jit { namespace fuser { // Note: the device to run on is included in the arg spec because kernels // are compiled per-device. 
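To make the caching role described above concrete, a hedged sketch of how an ArgSpec (declared next) could key a per-device kernel cache; the hasher, tensors, mapped value, and device index are illustrative assumptions rather than the fuser's actual cache:

    struct ArgSpecHasher {
      size_t operator()(const ArgSpec& spec) const {
        return ArgSpec::hash(spec);  // the common hash function exposed by ArgSpec
      }
    };
    std::unordered_map<ArgSpec, std::string, ArgSpecHasher> compiled;  // std::string stands in for a kernel handle
    at::Tensor a = at::ones({2, 3});
    at::Tensor b = at::zeros({2, 3});
    ArgSpec spec({a, b}, /*_device=*/-1);  // -1 used here as a stand-in CPU device index
    compiled.emplace(spec, "kernel_0");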
struct TORCH_API ArgSpec { - ArgSpec( - at::TensorList inputs - , const int _device) - : descs_{fmap(inputs)} - , hash_code_{torch::get_hash(_device, inputs.size(), descs_)} - , device_{_device} - { } + ArgSpec(at::TensorList inputs, const int _device) + : descs_{fmap(inputs)}, + hash_code_{torch::get_hash(_device, inputs.size(), descs_)}, + device_{_device} {} // (Common) hash function - static size_t hash(const ArgSpec& spec) { return spec.hash_code_; } + static size_t hash(const ArgSpec& spec) { + return spec.hash_code_; + } // Comparators bool operator==(const ArgSpec& other) const { - return ( - descs_ == other.descs_ - && device_ == other.device_); + return (descs_ == other.descs_ && device_ == other.device_); } bool operator!=(const ArgSpec& spec) const { @@ -42,18 +41,24 @@ struct TORCH_API ArgSpec { } // Getters - size_t hashCode() const { return hash_code_; } - const std::vector& descs() const { return descs_; } - int device() const { return device_; } + size_t hashCode() const { + return hash_code_; + } + const std::vector& descs() const { + return descs_; + } + int device() const { + return device_; + } -private: + private: std::vector descs_; size_t hash_code_; int device_; }; } // namespace fuser -} // namespace jit +} // namespace jit } // namespace torch #endif // USE_CUDA_FUSER || USE_CPU_FUSER diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp index 47d6445..512865e 100644 --- a/torch/csrc/jit/fuser/codegen.cpp +++ b/torch/csrc/jit/fuser/codegen.cpp @@ -1,30 +1,32 @@ #include #include -#include -#include #include +#include #include #include #include #include +#include #if USE_CUDA_FUSER - #include +#include #endif #if USE_CPU_FUSER - #include +#include #endif -#include +#include +#include #include #include -#include +#include #include -#include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // Template for computing the offset into the tensor to access a value static auto dim_calc = CodeTemplate(R"( @@ -33,7 +35,6 @@ size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes}; ${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride}; )"); - static std::string valueName(const Value* n) { return "n" + std::to_string(n->unique()); } @@ -71,11 +72,12 @@ static const char* scalarTypeName(const at::ScalarType type) { return "half"; } - switch(type) { - #define DEFINE_CASE(ctype,name,_) \ - case at::ScalarType::name: return #ctype; + switch (type) { +#define DEFINE_CASE(ctype, name, _) \ + case at::ScalarType::name: \ + return #ctype; AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(DEFINE_CASE) - #undef DEFINE_CASE +#undef DEFINE_CASE default: throw std::runtime_error("unknown scalar type"); } @@ -88,7 +90,6 @@ static const char* calcScalarTypeName(const at::ScalarType type) { return scalarTypeName(type); } - static std::string variableType(const std::shared_ptr& t) { if (t->kind() == TypeKind::IntType) { return "int"; @@ -101,17 +102,21 @@ static std::string variableType(const std::shared_ptr& t) { return calcScalarTypeName(tt->scalarType()); } // something went wrong with the type analysis during shape propagation - throw std::runtime_error("unknown scalar type during JIT fusion code generation"); + throw std::runtime_error( + "unknown scalar type during JIT fusion code generation"); } -static std::string typeCastedValueName(const std::shared_ptr& t, const at::ScalarType outtype, const std::string& vn) { +static std::string typeCastedValueName( + const std::shared_ptr& t, + const 
at::ScalarType outtype, + const std::string& vn) { if (t->kind() == TypeKind::IntType || t->kind() == TypeKind::BoolType) { - if (! isIntegralType(outtype)) { + if (!isIntegralType(outtype)) { return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")"; } return vn; } else if (t->kind() == TypeKind::FloatType) { - if (! isFloatingType(outtype)) { + if (!isFloatingType(outtype)) { return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")"; } return vn; @@ -123,89 +128,90 @@ static std::string typeCastedValueName(const std::shared_ptr& t, cons return vn; } // something went wrong with the type analysis during shape propagation - throw std::runtime_error("unknown scalar type during JIT fusion code generation"); + throw std::runtime_error( + "unknown scalar type during JIT fusion code generation"); } // Writes "simple mappable" ops static std::string encodeRHS(const Node* n) { static std::unordered_map simple_map_ops = { - // unary - {aten::_cast_Float, "static_cast(${0})"}, - {aten::abs, "fabs(${0})"}, - {aten::sigmoid, "1.f / (1.f + expf(-${0}))"}, - {aten::relu, "${0} < 0 ? 0.f : ${0} "}, - {aten::log, "logf(${0})"}, - {aten::log10, "log10f(${0})"}, - {aten::log1p, "log1pf(${0})"}, - {aten::log2, "log2f(${0})"}, - {aten::lgamma, "lgammaf(${0})"}, - {aten::exp, "expf(${0})"}, - {aten::expm1, "expm1f(${0})"}, - {aten::erf, "erff(${0})"}, - {aten::erfc, "erfcf(${0})"}, - {aten::cos, "cosf(${0})"}, - {aten::acos, "acosf(${0})"}, - {aten::cosh, "coshf(${0})"}, - {aten::sin, "sinf(${0})"}, - {aten::asin, "asinf(${0})"}, - {aten::sinh, "sinhf(${0})"}, - {aten::tan, "tanf(${0})"}, - {aten::atan, "atanf(${0})"}, - {aten::tanh, "tanhf(${0})"}, - {aten::sqrt, "sqrtf(${0})"}, - {aten::rsqrt, "rsqrtf(${0})"}, - {aten::ceil, "ceilf(${0})"}, - {aten::floor, "floorf(${0})"}, - {aten::round, "roundf(${0})"}, - {aten::trunc, "truncf(${0})"}, - {aten::frac, "fracf(${0})"}, - {aten::reciprocal, "1.f/(${0})"}, - {aten::neg, "-${0}"}, - //simple binary - {aten::atan2, "atan2(${0}, ${1})"}, - {aten::min, "fminf(${0}, ${1})"}, - {aten::max, "fmaxf(${0}, ${1})"}, - - //binary with other - // TODO: some of these ops will not get generated because - // we only work on float inputs/outputs, but they are here to record - // that they are valid mappable ops once we handle more type - - {aten::__and__, "${0} && ${1}"}, - {aten::__lshift__, "${0} << ${1}"}, - {aten::__or__, "${0} || ${1}"}, - {aten::__rshift__, "${0} >> ${1}"}, - {aten::__xor__, "${0} ^ ${1}"}, - {aten::div, "${cast_0} / ${cast_1}"}, - {aten::eq, "${0} == ${1}"}, - {aten::fmod, "fmodf(${cast_0}, ${cast_1})"}, - {aten::ge, "(${0} >= ${1})"}, - {aten::gt, "${0} > ${1}"}, - {aten::le, "(${0} <= ${1})"}, - {aten::lt, "${0} < ${1}"}, - {aten::type_as, "(${cast_0})"}, - {aten::mul, "${cast_0} * ${cast_1}"}, - {aten::ne, "${0} != ${1}"}, - {aten::remainder, "remainderf(${0}, ${1})"}, - {aten::pow, "powf(${cast_0}, ${cast_1})"}, - - //alpha - {aten::add, "${cast_0} + ${cast_2}*${cast_1}"}, - {aten::sub, "(${cast_0} - ${cast_2}*${cast_1})"}, - {aten::rand_like, "uniform(rnd())"}, - - // min, max - // It may seem unusual to have the bounds as the first case below, - // this is so that if min or max is NaN, they are "ignored" - // and when the input is NaN, the output is, too - {aten::clamp, "(${0}<${1}?${1}:(${0}>${2}?${2}:${0}))"}, - - //where - {aten::where, "(${0} ? 
${1} : ${2})"}, - - // simple derivatives - {aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"}, - {aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"}, + // unary + {aten::_cast_Float, "static_cast(${0})"}, + {aten::abs, "fabs(${0})"}, + {aten::sigmoid, "1.f / (1.f + expf(-${0}))"}, + {aten::relu, "${0} < 0 ? 0.f : ${0} "}, + {aten::log, "logf(${0})"}, + {aten::log10, "log10f(${0})"}, + {aten::log1p, "log1pf(${0})"}, + {aten::log2, "log2f(${0})"}, + {aten::lgamma, "lgammaf(${0})"}, + {aten::exp, "expf(${0})"}, + {aten::expm1, "expm1f(${0})"}, + {aten::erf, "erff(${0})"}, + {aten::erfc, "erfcf(${0})"}, + {aten::cos, "cosf(${0})"}, + {aten::acos, "acosf(${0})"}, + {aten::cosh, "coshf(${0})"}, + {aten::sin, "sinf(${0})"}, + {aten::asin, "asinf(${0})"}, + {aten::sinh, "sinhf(${0})"}, + {aten::tan, "tanf(${0})"}, + {aten::atan, "atanf(${0})"}, + {aten::tanh, "tanhf(${0})"}, + {aten::sqrt, "sqrtf(${0})"}, + {aten::rsqrt, "rsqrtf(${0})"}, + {aten::ceil, "ceilf(${0})"}, + {aten::floor, "floorf(${0})"}, + {aten::round, "roundf(${0})"}, + {aten::trunc, "truncf(${0})"}, + {aten::frac, "fracf(${0})"}, + {aten::reciprocal, "1.f/(${0})"}, + {aten::neg, "-${0}"}, + // simple binary + {aten::atan2, "atan2(${0}, ${1})"}, + {aten::min, "fminf(${0}, ${1})"}, + {aten::max, "fmaxf(${0}, ${1})"}, + + // binary with other + // TODO: some of these ops will not get generated because + // we only work on float inputs/outputs, but they are here to record + // that they are valid mappable ops once we handle more type + + {aten::__and__, "${0} && ${1}"}, + {aten::__lshift__, "${0} << ${1}"}, + {aten::__or__, "${0} || ${1}"}, + {aten::__rshift__, "${0} >> ${1}"}, + {aten::__xor__, "${0} ^ ${1}"}, + {aten::div, "${cast_0} / ${cast_1}"}, + {aten::eq, "${0} == ${1}"}, + {aten::fmod, "fmodf(${cast_0}, ${cast_1})"}, + {aten::ge, "(${0} >= ${1})"}, + {aten::gt, "${0} > ${1}"}, + {aten::le, "(${0} <= ${1})"}, + {aten::lt, "${0} < ${1}"}, + {aten::type_as, "(${cast_0})"}, + {aten::mul, "${cast_0} * ${cast_1}"}, + {aten::ne, "${0} != ${1}"}, + {aten::remainder, "remainderf(${0}, ${1})"}, + {aten::pow, "powf(${cast_0}, ${cast_1})"}, + + // alpha + {aten::add, "${cast_0} + ${cast_2}*${cast_1}"}, + {aten::sub, "(${cast_0} - ${cast_2}*${cast_1})"}, + {aten::rand_like, "uniform(rnd())"}, + + // min, max + // It may seem unusual to have the bounds as the first case below, + // this is so that if min or max is NaN, they are "ignored" + // and when the input is NaN, the output is, too + {aten::clamp, "(${0}<${1}?${1}:(${0}>${2}?${2}:${0}))"}, + + // where + {aten::where, "(${0} ? ${1} : ${2})"}, + + // simple derivatives + {aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"}, + {aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"}, }; if (n->kind() == prim::Constant) { @@ -222,16 +228,19 @@ static std::string encodeRHS(const Node* n) { TemplateEnv env; size_t i = 0; - auto outtype = n->output()->type()->expect()->scalarType(); - for(auto in : n->inputs()) { - // PyTorch converts (scalar) argument types to result before applying the operator - // e.g. 1.4-torch.tensor(3) = -2 + auto outtype = + n->output()->type()->expect()->scalarType(); + for (auto in : n->inputs()) { + // PyTorch converts (scalar) argument types to result before applying the + // operator e.g. 
1.4-torch.tensor(3) = -2 env.s(std::to_string(i), valueName(in)); - env.s(std::string("cast_")+std::to_string(i), typeCastedValueName(in->type(), outtype, valueName(in))); + env.s( + std::string("cast_") + std::to_string(i), + typeCastedValueName(in->type(), outtype, valueName(in))); i++; } - const auto & str = simple_map_ops.at(n->kind()); + const auto& str = simple_map_ops.at(n->kind()); return format(str, env); } @@ -240,7 +249,7 @@ static std::string encodeRHS(const Node* n) { static Node* usedInFusedChunk(const Value* input) { auto uses = input->uses(); if (uses.size() == 1) { - Node *user = uses[0].user; + Node* user = uses[0].user; if (user->kind() == prim::ConstantChunk) { return user; } @@ -249,41 +258,46 @@ static Node* usedInFusedChunk(const Value* input) { } static void emitIndexingFor( - std::ostream& out -, const std::string& tensor -, const int ndim -, const bool last_is_cont) { + std::ostream& out, + const std::string& tensor, + const int ndim, + const bool last_is_cont) { TemplateEnv env; - env.s("tensor",tensor); - out << format("IndexType ${tensor}_offset = 0;\n",env); - out << format("IndexType ${tensor}_linearIndex = linearIndex;\n",env); + env.s("tensor", tensor); + out << format("IndexType ${tensor}_offset = 0;\n", env); + out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env); for (int d = ndim - 1; d >= 0; --d) { - env.d("d",d); - env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]",env) : ""); - env.s("times_stride",(d < ndim - 1 || !last_is_cont) ? - format("* ${tensor}.strides[${d}]",env) : ""); + env.d("d", d); + env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]", env) : ""); + env.s( + "times_stride", + (d < ndim - 1 || !last_is_cont) + ? format("* ${tensor}.strides[${d}]", env) + : ""); out << dim_calc.format(env); if (d > 0) { - out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n",env); + out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env); } } } // TODO: handle cases where we need to generate > 2^32 element tensors std::tuple< - std::string -, std::vector -, std::vector -, bool> + std::string, + std::vector, + std::vector, + bool> generateKernel( - const std::string& name -, const Graph& graph -, const std::vector& input_desc -, const std::vector& output_desc -, const bool use_cuda) { + const std::string& name, + const Graph& graph, + const std::vector& input_desc, + const std::vector& output_desc, + const bool use_cuda) { TemplateEnv env; env.s("kernelName", name); - env.s("IndexType","unsigned int"); // Note: not uint32_t to avoid including cstdint + env.s( + "IndexType", + "unsigned int"); // Note: not uint32_t to avoid including cstdint std::stringstream body; std::stringstream tensorOffsets; @@ -292,15 +306,24 @@ generateKernel( // Lambda for writing arguments auto emitFormal = [&](const Value* n, const TensorDesc& desc) { - std::string tensor = "t" + std::to_string(formals.size()); //can't be unique() because Param may be an output + std::string tensor = + "t" + + std::to_string( + formals.size()); // can't be unique() because Param may be an output const auto nDim = desc.nDim(); - emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous()); + emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous()); env.s("tensor", tensor); - env.d("formal_index", formals.size() + 1); // + 1 because the first argument is the linearIndex + env.d( + "formal_index", + formals.size() + + 1); // + 1 because the first argument is the linearIndex env.d("nDim", nDim); env.s("scalar_type", 
scalarTypeName(desc.scalar_type)); - formals.push_back(format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env)); - argument_loads.push_back(format("*static_cast*>(args[${formal_index}])", env)); + formals.push_back( + format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env)); + argument_loads.push_back(format( + "*static_cast*>(args[${formal_index}])", + env)); }; // Writes input parameters and creates flattened inputs @@ -308,7 +331,7 @@ generateKernel( std::vector> flat_inputs; { size_t input_index = 0; - for(const auto& p : graph.inputs()) { + for (const auto& p : graph.inputs()) { if (const Node* chunk = usedInFusedChunk(p)) { int64_t dim = chunk->i(attr::dim); int64_t chunks = chunk->i(attr::chunks); @@ -340,7 +363,7 @@ generateKernel( } else { const auto cat = o->node(); concat_desc.emplace_back(desc, cat->inputs().size(), cat->i(attr::dim)); - for(const auto& c : cat->inputs()) { + for (const auto& c : cat->inputs()) { emitFormal(c, *concat_desc.back().subTensorDesc()); flat_output_nodes.emplace_back(c, desc); } @@ -364,8 +387,8 @@ generateKernel( if (is_half) { JIT_ASSERT(use_cuda); env.s( - "access" - , format("__half2float(t${formal}.data[t${formal}_offset])", env)); + "access", + format("__half2float(t${formal}.data[t${formal}_offset])", env)); has_half_tensor = true; } else { env.s("access", format("t${formal}.data[t${formal}_offset]", env)); @@ -380,9 +403,12 @@ generateKernel( // Note: Concat and Chunk are implicitly generated // Note: Random number generation is only supported for CUDA kernels. for (const auto& n : graph.nodes()) { - // Note: FusedConcat nodes work by narrowing the output Tensors before the kernel runs - if (n->kind() == prim::FusedConcat) continue; - if (n->kind() == prim::ConstantChunk) continue; + // Note: FusedConcat nodes work by narrowing the output Tensors before the + // kernel runs + if (n->kind() == prim::FusedConcat) + continue; + if (n->kind() == prim::ConstantChunk) + continue; if (n->kind() == aten::rand_like) { JIT_ASSERT(use_cuda); has_random = true; @@ -390,14 +416,14 @@ generateKernel( env.s("node", valueName(n->output())); env.s("rhs", encodeRHS(n)); env.s("lhs_type", variableType(n->output()->type())); - body << format("${lhs_type} ${node} = ${rhs};\n",env); + body << format("${lhs_type} ${node} = ${rhs};\n", env); } // Generates writes to output tensors for (const auto& output : flat_output_nodes) { const auto& o = output.first; env.d("formal", formal_count++); - env.s("access", format("t${formal}.data[t${formal}_offset]",env)); + env.s("access", format("t${formal}.data[t${formal}_offset]", env)); env.s("node", valueName(o)); // Acquires and converts (if needed) outputs @@ -405,32 +431,32 @@ generateKernel( const auto is_half = (output.second.scalar_type == at::ScalarType::Half); if (is_half) { JIT_ASSERT(use_cuda); - body << format("${access} = __float2half(${node});\n",env); + body << format("${access} = __float2half(${node});\n", env); has_half_tensor = true; } else { - body << format("${access} = ${node};\n",env); + body << format("${access} = ${node};\n", env); } } - // Includes headers - // Note: CUDA kernels support halfs and random generation, CPU kernels do not - #if USE_CUDA_FUSER - if (has_half_tensor) { - env.s("HalfHeader", cuda::half_support_literal); - } else { - env.s("HalfHeader", ""); - } +// Includes headers +// Note: CUDA kernels support halfs and random generation, CPU kernels do not +#if USE_CUDA_FUSER + if (has_half_tensor) { + env.s("HalfHeader", cuda::half_support_literal); + } else { + env.s("HalfHeader", 
""); + } - if (has_random) { - env.s("RandHeader", cuda::rand_support_literal); - env.s("RandParam", cuda::rand_param); - env.s("RandInit", cuda::rand_init); - } else { - env.s("RandHeader", ""); - env.s("RandParam", ""); - env.s("RandInit", ""); - } - #endif // USE_CUDA_FUSER + if (has_random) { + env.s("RandHeader", cuda::rand_support_literal); + env.s("RandParam", cuda::rand_param); + env.s("RandInit", cuda::rand_init); + } else { + env.s("RandHeader", ""); + env.s("RandParam", ""); + env.s("RandInit", ""); + } +#endif // USE_CUDA_FUSER // Insantiates the CUDA or CPU-specific templates env.s("tensorOffsets", tensorOffsets.str()); @@ -439,25 +465,26 @@ generateKernel( env.v("argument_loads", argument_loads); std::string code_string; if (use_cuda) { - #if USE_CUDA_FUSER - env.s("type_declarations", cuda::type_declarations_template.format(env)); - code_string = cuda::cuda_compilation_unit_template.format(env); - #else - throw std::runtime_error("CUDA Fusion requested but not supported."); - #endif // USE_CUDA_FUSER +#if USE_CUDA_FUSER + env.s("type_declarations", cuda::type_declarations_template.format(env)); + code_string = cuda::cuda_compilation_unit_template.format(env); +#else + throw std::runtime_error("CUDA Fusion requested but not supported."); +#endif // USE_CUDA_FUSER } else { - #if USE_CPU_FUSER - env.s("type_declarations", cpu::type_declarations_template.format(env)); - code_string = cpu::cpu_compilation_unit_template.format(env); - #else - throw std::runtime_error("CPU Fusion requested but not supported"); - #endif // USE_CPU_FUSER +#if USE_CPU_FUSER + env.s("type_declarations", cpu::type_declarations_template.format(env)); + code_string = cpu::cpu_compilation_unit_template.format(env); +#else + throw std::runtime_error("CPU Fusion requested but not supported"); +#endif // USE_CPU_FUSER } if (debugFuser()) { std::cerr << "fusion code:" << code_string << std::endl; } - return std::make_tuple(code_string, std::move(chunk_desc), std::move(concat_desc), has_random); + return std::make_tuple( + code_string, std::move(chunk_desc), std::move(concat_desc), has_random); } } // namespace fuser diff --git a/torch/csrc/jit/fuser/codegen.h b/torch/csrc/jit/fuser/codegen.h index a86a336..26ce490 100644 --- a/torch/csrc/jit/fuser/codegen.h +++ b/torch/csrc/jit/fuser/codegen.h @@ -3,35 +3,37 @@ #if USE_CUDA_FUSER || USE_CPU_FUSER #include -#include #include #include #include +#include -#include -#include #include #include +#include +#include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // Creates a CPU or CUDA kernel for the given graph. // Returns a tuple consisting of the generated code (as a string), -// two vectors of PartitionDescs, the chunk and concat descriptions, -// respectively, and a bool indicating whether the generated code +// two vectors of PartitionDescs, the chunk and concat descriptions, +// respectively, and a bool indicating whether the generated code // generates random numbers. // TODO: the partition descriptions should be generated by the executor. 
TORCH_API std::tuple< - std::string -, std::vector -, std::vector -, bool> + std::string, + std::vector, + std::vector, + bool> generateKernel( - const std::string& name -, const Graph& graph -, const std::vector& input_desc -, const std::vector& output_desc -, const bool use_cuda); + const std::string& name, + const Graph& graph, + const std::vector& input_desc, + const std::vector& output_desc, + const bool use_cuda); } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/fuser/compiler.cpp b/torch/csrc/jit/fuser/compiler.cpp index 74d6dc3..d94ba85 100644 --- a/torch/csrc/jit/fuser/compiler.cpp +++ b/torch/csrc/jit/fuser/compiler.cpp @@ -1,44 +1,48 @@ #include #include -#include -#include -#include #include -#include -#include +#include +#include #include #include -#include #include +#include +#include +#include +#include #include "torch/csrc/jit/fuser/interface.h" #if USE_CUDA_FUSER - #include +#include #endif // USE_CUDA_FUSER #if USE_CPU_FUSER - #include +#include #endif // USE_CUDA_FUSER +#include #include #include -#include -#include -#include -#include #include #include +#include #include +#include +#include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // Counter for number of kernels compiled, used for debugging and // creating arbitrary kernel names. static std::atomic next_kernel_id{0}; static int debug_fusion{-1}; -size_t nCompiledKernels() { return next_kernel_id.load(); } +size_t nCompiledKernels() { + return next_kernel_id.load(); +} int debugFuser() { if (debug_fusion < 0) { @@ -53,7 +57,7 @@ int debugFuser() { static const Node* usedInFusedChunk(const Value* input) { const auto& uses = input->uses(); if (uses.size() == 1) { - const Node *user = uses[0].user; + const Node* user = uses[0].user; if (user->kind() == prim::ConstantChunk) { return user; } @@ -65,7 +69,8 @@ static void setInputChunkDescriptors(KernelSpec& spec) { spec.inputChunks().reserve((spec.graph())->inputs().size()); for (const Value* input : (spec.graph())->inputs()) { if (const Node* chunk = usedInFusedChunk(input)) { - spec.inputChunks().emplace_back(chunk->i(attr::chunks), chunk->i(attr::dim)); + spec.inputChunks().emplace_back( + chunk->i(attr::chunks), chunk->i(attr::dim)); } else { spec.inputChunks().emplace_back(1, 0); } @@ -78,14 +83,15 @@ static std::vector getInputDependencies(const Value* output) { std::unordered_set inputs; std::unordered_set seen; while (!queue.empty()) { - const Value* val = queue.back(); queue.pop_back(); + const Value* val = queue.back(); + queue.pop_back(); const Node* producer = val->node(); if (producer->kind() == prim::Param) { inputs.insert(val); continue; } for (const Value* input : producer->inputs()) { - if (/*bool inserted = */seen.insert(input).second) { + if (/*bool inserted = */ seen.insert(input).second) { queue.push_back(input); } } @@ -103,7 +109,8 @@ static std::vector getInputDependencies(const Value* output) { } static void setInputBroadcastGroups(KernelSpec& spec) { - std::unordered_set, torch::hash>> broadcast_groups; + std::unordered_set, torch::hash>> + broadcast_groups; for (const Value* output : (spec.graph())->outputs()) { if (output->node()->kind() == prim::FusedConcat) { for (const Value* concat_input : output->node()->inputs()) { @@ -114,9 +121,9 @@ static void setInputBroadcastGroups(KernelSpec& spec) { } } std::copy( - broadcast_groups.begin() - , broadcast_groups.end() - , std::back_inserter(spec.inputBroadcastGroups())); + broadcast_groups.begin(), + 
broadcast_groups.end(), + std::back_inserter(spec.inputBroadcastGroups())); } // Performs "upfront" compilation where storage is known but shapes are not. @@ -156,10 +163,10 @@ int64_t registerFusion(const Node* fusion_group) { } std::shared_ptr compileKernel( - const KernelSpec& spec -, const ArgSpec& arg_spec -, const std::vector& map_size -, const at::Device device) { + const KernelSpec& spec, + const ArgSpec& arg_spec, + const std::vector& map_size, + const at::Device device) { const std::vector& input_desc = arg_spec.descs(); auto graph = spec.graph()->copy(); @@ -167,7 +174,10 @@ std::shared_ptr compileKernel( c10::optional scalar_type; for (size_t i = 0; i < input_desc.size(); i++) { const auto& desc = input_desc[i]; - graph->inputs()[i]->setType(TensorType::create(desc.scalar_type, device, desc.nDim())); // TODO: nDim is bad, as it is collapsed + graph->inputs()[i]->setType(TensorType::create( + desc.scalar_type, + device, + desc.nDim())); // TODO: nDim is bad, as it is collapsed } PropagateInputShapes(graph); @@ -179,7 +189,8 @@ std::shared_ptr compileKernel( if (output->node()->kind() == prim::FusedConcat) { sizes.at(output->node()->i(attr::dim)) *= output->node()->inputs().size(); } - auto scalar_type = output->type()->expect()->scalarType(); + auto scalar_type = + output->type()->expect()->scalarType(); auto type = CompleteTensorType::create(scalar_type, device, sizes); output_desc.emplace_back(std::move(type)); } @@ -190,42 +201,37 @@ std::shared_ptr compileKernel( std::vector chunk_desc; std::vector concat_desc; bool has_random; - std::tie(code, chunk_desc, concat_desc, has_random) - = generateKernel( - name - , *graph - , input_desc - , output_desc - , use_cuda); + std::tie(code, chunk_desc, concat_desc, has_random) = + generateKernel(name, *graph, input_desc, output_desc, use_cuda); std::shared_ptr fused_kernel; if (use_cuda) { - #if USE_CUDA_FUSER - fused_kernel = std::make_shared( - device.index() - , name - , code - , input_desc - , output_desc - , chunk_desc - , concat_desc - , has_random); - #else - throw std::runtime_error("CUDA Fusion is not supported on this build."); - #endif // USE_CUDA_FUSER +#if USE_CUDA_FUSER + fused_kernel = std::make_shared( + device.index(), + name, + code, + input_desc, + output_desc, + chunk_desc, + concat_desc, + has_random); +#else + throw std::runtime_error("CUDA Fusion is not supported on this build."); +#endif // USE_CUDA_FUSER } else { - #if USE_CPU_FUSER - fused_kernel = std::make_shared( - name - , code - , input_desc - , output_desc - , chunk_desc - , concat_desc - , has_random); - #else - throw std::runtime_error("CPU Fusion is not supported on this build."); - #endif // USE_CPU_FUSER +#if USE_CPU_FUSER + fused_kernel = std::make_shared( + name, + code, + input_desc, + output_desc, + chunk_desc, + concat_desc, + has_random); +#else + throw std::runtime_error("CPU Fusion is not supported on this build."); +#endif // USE_CPU_FUSER } return fused_kernel; diff --git a/torch/csrc/jit/fuser/compiler.h b/torch/csrc/jit/fuser/compiler.h index 2a7f6f0..38e1ef1 100644 --- a/torch/csrc/jit/fuser/compiler.h +++ b/torch/csrc/jit/fuser/compiler.h @@ -3,18 +3,20 @@ #if USE_CUDA_FUSER || USE_CPU_FUSER #include -#include -#include +#include #include +#include #include #include -#include -#include +#include +#include #include #include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // Performs device-independent "upfront" compilation of the given fusion_group, // if it has not been 
registered already. @@ -25,10 +27,10 @@ TORCH_API int64_t registerFusion(const Node* fusion_group); // with the runtime arguments specified in ArgSpec. // Outputs are allocated using map_size on the specified device. TORCH_API std::shared_ptr compileKernel( - const KernelSpec& spec -, const ArgSpec& arg_spec -, const std::vector& map_size -, const at::Device device); + const KernelSpec& spec, + const ArgSpec& arg_spec, + const std::vector& map_size, + const at::Device device); TORCH_API size_t nCompiledKernels(); diff --git a/torch/csrc/jit/fuser/config.h.in b/torch/csrc/jit/fuser/config.h.in index 0809591..02306ed 100644 --- a/torch/csrc/jit/fuser/config.h.in +++ b/torch/csrc/jit/fuser/config.h.in @@ -1,4 +1,6 @@ #pragma once +// clang-format off #define USE_CPU_FUSER @USE_CPU_FUSER@ #define USE_CUDA_FUSER @USE_CUDA_FUSER@ +// clang-format on diff --git a/torch/csrc/jit/fuser/cpu/dynamic_library.h b/torch/csrc/jit/fuser/cpu/dynamic_library.h index 55adc6b..25f8e39 100644 --- a/torch/csrc/jit/fuser/cpu/dynamic_library.h +++ b/torch/csrc/jit/fuser/cpu/dynamic_library.h @@ -3,10 +3,14 @@ #if USE_CPU_FUSER #include +#include #include -namespace torch { namespace jit { namespace fuser { namespace cpu { +namespace torch { +namespace jit { +namespace fuser { +namespace cpu { static void* checkDL(void* x) { if (!x) { @@ -30,11 +34,12 @@ struct DynamicLibrary { } ~DynamicLibrary() { - if (!handle) return; + if (!handle) + return; dlclose(handle); } -private: + private: void* handle = nullptr; }; diff --git a/torch/csrc/jit/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/fuser/cpu/fused_kernel.cpp index ad11b14..dbe954a 100644 --- a/torch/csrc/jit/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/fuser/cpu/fused_kernel.cpp @@ -3,17 +3,20 @@ #include #include #include -#include #include +#include #include -#include #include #include -#include +#include #include +#include -namespace torch { namespace jit { namespace fuser { namespace cpu { +namespace torch { +namespace jit { +namespace fuser { +namespace cpu { static const std::string so_template = "/tmp/pytorch_fuserXXXXXX.so"; static const std::string cpp_template = "/tmp/pytorch_fuserXXXXXX.cpp"; @@ -63,15 +66,15 @@ static CompilerConfig& getConfig() { // optimization can be re-enabled by tracking down the platforms where // this error occurs and only selectively disabling it. 
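// For intuition, after TemplateEnv substitution the command below comes out
// roughly as (a sketch; the compiler binary, the openmp flag, and the temp
// file names are filled in at runtime from CompilerConfig and TempFile):
//
//   "g++" -O3 -g -std=c++11 -fPIC -fopenmp -shared \
//       "/tmp/pytorch_fuserAbC123.cpp" -o "/tmp/pytorch_fuserAbC123.so" -lm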
static const std::string compile_string = - "\"${cxx}\" -O3 -g " + "\"${cxx}\" -O3 -g " #ifndef __PPC64__ // "-march=native " #endif - "-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm"; + "-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm"; static void runCompiler( - const std::string& cpp_file -, const std::string& so_file) { + const std::string& cpp_file, + const std::string& so_file) { auto& config = getConfig(); TemplateEnv env; env.s("cxx", config.cxx); @@ -81,15 +84,15 @@ static void runCompiler( std::string result = format(compile_string, env); int r = system(result.c_str()); if (config.openmp && r != 0) { - std::cerr << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n"; + std::cerr + << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n"; config.openmp = false; // disable for future compiles return runCompiler(cpp_file, so_file); } JIT_ASSERTM(r == 0, "Failed to compile a fused CPU kernel"); } -static const std::string disas_string = - "objdump -M intel -d \"${so_file}\""; +static const std::string disas_string = "objdump -M intel -d \"${so_file}\""; static void disas(const std::string& so_file) { TemplateEnv env; env.s("so_file", so_file); @@ -119,11 +122,13 @@ FusedKernelCPU::FusedKernelCPU( cpp_file.write(code_); cpp_file.sync(); runCompiler(cpp_file.name(), so_file.name()); - if (debugFuser() >= 2) disas(so_file.name()); + if (debugFuser() >= 2) + disas(so_file.name()); so_lib = make_unique(so_file.name().c_str()); - #pragma GCC diagnostic ignored "-Wpedantic" - kernel = reinterpret_cast(so_lib->sym(name_.c_str())); - #pragma GCC diagnostic pop +#pragma GCC diagnostic ignored "-Wpedantic" + kernel = + reinterpret_cast(so_lib->sym(name_.c_str())); +#pragma GCC diagnostic pop } } // namespace cpu diff --git a/torch/csrc/jit/fuser/cpu/fused_kernel.h b/torch/csrc/jit/fuser/cpu/fused_kernel.h index 3f32018..272c837 100644 --- a/torch/csrc/jit/fuser/cpu/fused_kernel.h +++ b/torch/csrc/jit/fuser/cpu/fused_kernel.h @@ -4,15 +4,18 @@ #include #include -#include #include #include +#include -#include #include #include +#include -namespace torch { namespace jit { namespace fuser { namespace cpu { +namespace torch { +namespace jit { +namespace fuser { +namespace cpu { // Represents a compiled CPU kernel and the metadata necessary to run it struct TORCH_API FusedKernelCPU : public ::torch::jit::fuser::FusedKernel { @@ -34,7 +37,7 @@ struct TORCH_API FusedKernelCPU : public ::torch::jit::fuser::FusedKernel { kernel(numel, arguments.data()); } -private: + private: std::unique_ptr so_lib; void (*kernel)(uint32_t, void**) = nullptr; }; diff --git a/torch/csrc/jit/fuser/cpu/resource_strings.h b/torch/csrc/jit/fuser/cpu/resource_strings.h index 63e9051..8d9e13a 100644 --- a/torch/csrc/jit/fuser/cpu/resource_strings.h +++ b/torch/csrc/jit/fuser/cpu/resource_strings.h @@ -4,11 +4,15 @@ #include -namespace torch { namespace jit { namespace fuser { namespace cpu { +namespace torch { +namespace jit { +namespace fuser { +namespace cpu { -/*with type_as not checking type of its input, a fusion group can have non-fp32 tensor as input. -Correct code for this case is generated, however, nvrtc does not know how to handle int*_t integer types, -so typedefs help it handle those cases*/ +/*with type_as not checking type of its input, a fusion group can have non-fp32 +tensor as input. 
Correct code for this case is generated, however, nvrtc does +not know how to handle int*_t integer types, so typedefs help it handle those +cases*/ static auto type_declarations_template = CodeTemplate(R"( @@ -29,9 +33,9 @@ struct TensorInfo { )"); static auto cpu_compilation_unit_template = CodeTemplate(R"( +#include #include #include -#include template scalar_t rsqrtf(scalar_t x) { @@ -61,7 +65,7 @@ void ${kernelName}(IndexType totalElements, void ** args) { } // namespace cpu } // namespace fuser -} // namespace jit +} // namespace jit } // namespace torch #endif // USE_CPU_FUSER diff --git a/torch/csrc/jit/fuser/cpu/temp_file.h b/torch/csrc/jit/fuser/cpu/temp_file.h index 2ded2ba..b889974 100644 --- a/torch/csrc/jit/fuser/cpu/temp_file.h +++ b/torch/csrc/jit/fuser/cpu/temp_file.h @@ -4,15 +4,18 @@ #include #include -#include #include +#include #include #include #include -namespace torch { namespace jit { namespace fuser { namespace cpu { +namespace torch { +namespace jit { +namespace fuser { +namespace cpu { struct TempFile { TH_DISALLOW_COPY_AND_ASSIGN(TempFile); @@ -38,12 +41,12 @@ struct TempFile { fflush(file_); } - void write(const std::string & str) { + void write(const std::string& str) { size_t result = fwrite(str.c_str(), 1, str.size(), file_); JIT_ASSERT(str.size() == result); } - FILE* file() { + FILE* file() { return file_; } @@ -55,14 +58,15 @@ struct TempFile { fclose(file_); } } -private: + + private: FILE* file_ = nullptr; std::string name_; }; } // namespace cpu } // namespace fuser -} // namespace jit +} // namespace jit } // namespace torch #endif // USE_CPU_FUSER diff --git a/torch/csrc/jit/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/fuser/cuda/fused_kernel.cpp index 42432ba..3f6ba60 100644 --- a/torch/csrc/jit/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/fuser/cuda/fused_kernel.cpp @@ -1,42 +1,48 @@ #include #include -#include #include +#include #include #include // Note: unclear why this forward declaration is necessary -#include #include +#include THCGenerator* THCRandom_getGenerator(THCState* state); -#include #include #include +#include -#include +#include +#include #include +#include #include #include -#include -#include -namespace torch { namespace jit { namespace fuser { namespace cuda { +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { -void checkCUDAVersion( - const cudaDeviceProp& prop) { +void checkCUDAVersion(const cudaDeviceProp& prop) { if ((prop.major >= 6 && CUDA_VERSION < 8000) || (prop.major >= 7 && CUDA_VERSION < 9000)) { std::stringstream err_string; - err_string << "In CUDAFusedKernel, PyTorch compiled with insufficient CUDA version: " - << CUDA_VERSION << " for the current GPU device " << prop.name - << " with device capability " << prop.major << "." << prop.minor; + err_string + << "In CUDAFusedKernel, PyTorch compiled with insufficient CUDA version: " + << CUDA_VERSION << " for the current GPU device " << prop.name + << " with device capability " << prop.major << "." 
<< prop.minor; throw std::runtime_error(err_string.str()); } } -static void getMajorMinor(const cudaDeviceProp* const prop, int& major, int& minor) { +static void getMajorMinor( + const cudaDeviceProp* const prop, + int& major, + int& minor) { int nvrtc_major, nvrtc_minor; TORCH_NVRTC_CHECK(nvrtcVersion(&nvrtc_major, &nvrtc_minor)); @@ -56,12 +62,16 @@ static void getMajorMinor(const cudaDeviceProp* const prop, int& major, int& min minor = 0; } else if (nvrtc_major <= 9 && prop->major >= 7) { // 9 supports 3-7.2 major = 7; - if (prop->major == 7 && prop->minor <= 2) minor = prop->minor; - else minor = 0; + if (prop->major == 7 && prop->minor <= 2) + minor = prop->minor; + else + minor = 0; } else if (nvrtc_major <= 10 && prop->major >= 7) { // 10 supports 3-7.5 major = 7; - if (prop->major == 7 && prop->minor <= 5) minor = prop->minor; - else minor = 0; + if (prop->major == 7 && prop->minor <= 5) + minor = prop->minor; + else + minor = 0; } } @@ -88,9 +98,9 @@ FusedKernelCUDA::FusedKernelCUDA( CUcontext pctx = 0; TORCH_CU_CHECK(cuCtxGetCurrent(&pctx)); if (!pctx) { - std::unique_lock cudaFreeMutexLock( - *(THCCachingAllocator_getCudaFreeMutex())); - cudaFree(0); + std::unique_lock cudaFreeMutexLock( + *(THCCachingAllocator_getCudaFreeMutex())); + cudaFree(0); } // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work @@ -98,7 +108,8 @@ FusedKernelCUDA::FusedKernelCUDA( const auto prior_device = at::cuda::current_device(); at::cuda::set_device(device_); - // Acquires device and NVRTC properties (for compile arch and occupancy calculations) + // Acquires device and NVRTC properties (for compile arch and occupancy + // calculations) prop_ = at::cuda::getCurrentDeviceProperties(); int major, minor; getMajorMinor(prop_, major, minor); @@ -106,15 +117,12 @@ FusedKernelCUDA::FusedKernelCUDA( // Creates the NVRTC program nvrtcProgram program; TORCH_NVRTC_CHECK(nvrtcCreateProgram( - &program - , code_.c_str() - , nullptr - , 0 - , nullptr - , nullptr)); - - const std::string compute = "--gpu-architecture=compute_" + std::to_string(major) + std::to_string(minor); - const std::vector args = {"--std=c++11", compute.c_str(), "-default-device"}; + &program, code_.c_str(), nullptr, 0, nullptr, nullptr)); + + const std::string compute = "--gpu-architecture=compute_" + + std::to_string(major) + std::to_string(minor); + const std::vector args = { + "--std=c++11", compute.c_str(), "-default-device"}; const auto result = nvrtcCompileProgram(program, args.size(), args.data()); if (result == NVRTC_ERROR_COMPILATION) { size_t logsize; @@ -125,9 +133,8 @@ FusedKernelCUDA::FusedKernelCUDA( cu << log.data(); throw std::runtime_error(cu.str()); } - ResourceGuard holdProgram([&] { - TORCH_NVRTC_CHECK(nvrtcDestroyProgram(&program)); - }); + ResourceGuard holdProgram( + [&] { TORCH_NVRTC_CHECK(nvrtcDestroyProgram(&program)); }); TORCH_NVRTC_CHECK(result); size_t ptx_size; TORCH_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size)); @@ -139,7 +146,7 @@ FusedKernelCUDA::FusedKernelCUDA( // Computes max blocks TORCH_CU_CHECK(cuOccupancyMaxActiveBlocksPerMultiprocessor( - &maxBlocks_, function_, 128, 0)); + &maxBlocks_, function_, 128, 0)); maxBlocks_ *= prop_->multiProcessorCount; // Resets device (end of hacked at::DeviceGuard) @@ -151,8 +158,8 @@ static int ceilDiv(const int a, const int b) { } void FusedKernelCUDA::launch_raw( - const uint32_t numel -, std::vector& arguments) const { + const uint32_t numel, + std::vector& arguments) const { at::cuda::CUDAGuard{device_}; // Hacked at::DeviceGuard (see note 
above) const auto prior_device = at::cuda::current_device(); @@ -164,7 +171,8 @@ void FusedKernelCUDA::launch_raw( // Note: offset defined here so its lifetime extends to the launch uint64_t offset; if (has_random_) { - const auto rand_offset = 4 * (std::ceil(numel / (4.0 * kBlockSize * nBlocks)) + 1); + const auto rand_offset = + 4 * (std::ceil(numel / (4.0 * kBlockSize * nBlocks)) + 1); auto gen = THCRandom_getGenerator(at::globalContext().getTHCState()); offset = gen->state.philox_seed_offset.fetch_add(rand_offset); arguments.push_back(&gen->state.initial_seed); @@ -174,12 +182,17 @@ void FusedKernelCUDA::launch_raw( // Launches kernel on current stream (device was set by executor) auto stream = at::cuda::getCurrentCUDAStream(); TORCH_CU_CHECK(cuLaunchKernel( - function_, - nBlocks, 1, 1, - kBlockSize, 1, 1, - 0, stream, - arguments.data(), - nullptr)); + function_, + nBlocks, + 1, + 1, + kBlockSize, + 1, + 1, + 0, + stream, + arguments.data(), + nullptr)); // Resets device (see at::DeviceGuard notes above) at::cuda::set_device(prior_device); diff --git a/torch/csrc/jit/fuser/cuda/fused_kernel.h b/torch/csrc/jit/fuser/cuda/fused_kernel.h index 31a2909..233c001 100644 --- a/torch/csrc/jit/fuser/cuda/fused_kernel.h +++ b/torch/csrc/jit/fuser/cuda/fused_kernel.h @@ -6,15 +6,18 @@ #include #include -#include #include #include +#include #include -#include #include +#include -namespace torch { namespace jit { namespace fuser { namespace cuda { +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { // A class holding metadata for an actual CUDA function. // Note: CUDA functions are per device. @@ -40,7 +43,7 @@ struct TORCH_API FusedKernelCUDA : public ::torch::jit::fuser::FusedKernel { return at::Backend::CUDA; } -private: + private: static constexpr auto kBlockSize = 128; // Note: per device to store device properties and compute launch heuristics diff --git a/torch/csrc/jit/fuser/cuda/resource_strings.h b/torch/csrc/jit/fuser/cuda/resource_strings.h index ab3b8c1..ce56b81 100644 --- a/torch/csrc/jit/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/fuser/cuda/resource_strings.h @@ -5,11 +5,15 @@ #include #include -namespace torch { namespace jit { namespace fuser { namespace cuda { +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { -/*with type_as not checking type of its input, a fusion group can have non-fp32 tensor as input. -Correct code for this case is generated, however, nvrtc does not know how to handle int*_t integer types, -so typedefs help it handle those cases*/ +/*with type_as not checking type of its input, a fusion group can have non-fp32 +tensor as input. Correct code for this case is generated, however, nvrtc does +not know how to handle int*_t integer types, so typedefs help it handle those +cases*/ static auto type_declarations_template = CodeTemplate(R"( typedef unsigned char uint8_t; @@ -132,7 +136,8 @@ constexpr auto rand_support_literal = R"( } )"; -constexpr auto rand_param = ",unsigned long long seed, unsigned long long offset"; +constexpr auto rand_param = + ",unsigned long long seed, unsigned long long offset"; constexpr auto rand_init = R"( int idx = blockIdx.x*blockDim.x + threadIdx.x; @@ -156,13 +161,12 @@ void ${kernelName}(IndexType totalElements, ${formals} ${RandParam}) { } )"); - // This snippet enables half support in the jit. Following the pattern for // reductions, fp16 input data is immediately upconverted to float // with __half2float(). 
All mathematical operations are done on float // values, and if needed the intermediate float representation is // converted to half with __float2half() when writing to a half tensor. -constexpr auto half_support_literal = R"( +constexpr auto half_support_literal = R"( #define __HALF_TO_US(var) *(reinterpret_cast(&(var))) #define __HALF_TO_CUS(var) *(reinterpret_cast(&(var))) #if defined(__cplusplus) @@ -197,7 +201,7 @@ typedef __half half; } // namespace cuda } // namespace fuser -} // namespace jit +} // namespace jit } // namespace torch #endif // USE_CUDA_FUSER diff --git a/torch/csrc/jit/fuser/executor.cpp b/torch/csrc/jit/fuser/executor.cpp index 48de0e1..b8f43bc 100644 --- a/torch/csrc/jit/fuser/executor.cpp +++ b/torch/csrc/jit/fuser/executor.cpp @@ -3,31 +3,32 @@ #include #include #include -#include -#include +#include #include #include #include #include -#include #include +#include +#include -#include -#include -#include #include -#include #include // TODO: remove, debugging only +#include +#include +#include +#include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // Returns the "map size" for this run, which is the common size for all // intermediate tensors. static c10::optional> getMapSize( - const KernelSpec& spec -, at::TensorList args -, at::IntList arg_subset) { - + const KernelSpec& spec, + at::TensorList args, + at::IntList arg_subset) { // TODO: this keeps reallocating map_size at every iteration, but we know // exactly how much storage do we need, so this could be fixed in-place at // every step. We're just missing a few functions for ATen, but the fix @@ -47,7 +48,8 @@ static c10::optional> getMapSize( } else { auto tensor_sizes = arg.sizes().vec(); const auto num_chunks = chunk_desc.nSubTensors(); - const auto dim = at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size()); + const auto dim = + at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size()); if (tensor_sizes[dim] % num_chunks != 0) { return c10::nullopt; } @@ -65,22 +67,27 @@ static c10::optional> getMapSize( // Tries to determine a map size for the instantiated kernel (see above) static c10::optional> canRunKernel( - const KernelSpec& spec -, at::TensorList args) { + const KernelSpec& spec, + at::TensorList args) { // Short-circuits on size mismatch AT_CHECK( - args.size() == spec.inputChunks().size() - , "Expected ", spec.inputChunks().size(), " arguments, but got ", args.size()); + args.size() == spec.inputChunks().size(), + "Expected ", + spec.inputChunks().size(), + " arguments, but got ", + args.size()); c10::optional> map_size; for (const auto& broadcast_group : spec.inputBroadcastGroups()) { if (!map_size) { map_size = getMapSize(spec, args, broadcast_group); - if (!map_size) return c10::nullopt; + if (!map_size) + return c10::nullopt; } else { const auto group_map_size = getMapSize(spec, args, broadcast_group); // Note: this checks that group_map_size is defined AND equal to map_size - if (map_size != group_map_size) return c10::nullopt; + if (map_size != group_map_size) + return c10::nullopt; } } @@ -92,14 +99,15 @@ static c10::optional> canRunKernel( // Note: Arguments are mutated by this call, although map_size is restored // to its original value. 
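// Worked example (illustrative shapes): with map_size = {4, 5}, a non-chunked
// argument of size {1, 5} is expanded in place to {4, 5}; for an argument that
// feeds a 2-way chunk along dim 0, map_size is scaled to {8, 5} for the
// expansion and, per the note above, restored to {4, 5} before returning.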
static void expandArgs( - const KernelSpec& spec -, std::vector& args -, std::vector& map_size) { + const KernelSpec& spec, + std::vector& args, + std::vector& map_size) { for (size_t i = 0; i < args.size(); ++i) { auto& arg = args[i]; const auto& pdesc = spec.inputChunks()[i]; if (pdesc.nSubTensors() == 1) { - if (arg.sizes().equals(map_size)) continue; + if (arg.sizes().equals(map_size)) + continue; arg = arg.expand(map_size); } else { map_size.at(pdesc.dim()) *= pdesc.nSubTensors(); @@ -123,8 +131,8 @@ static uint32_t computeNumel(const at::ArrayRef& sizes) { // Note: Assumes that after at::chunk, all inputs are the same size static std::vector computeMapSize( - const at::Tensor& tensor -, const PartitionDesc& chunkDesc) { + const at::Tensor& tensor, + const PartitionDesc& chunkDesc) { std::vector sizes(tensor.sizes().begin(), tensor.sizes().end()); JIT_ASSERT(sizes[chunkDesc.dim()] % chunkDesc.nSubTensors() == 0); sizes[chunkDesc.dim()] /= chunkDesc.nSubTensors(); @@ -134,37 +142,38 @@ static std::vector computeMapSize( // Tries to compress sizes and strides according to cont. Emits the result t // c_sizes, c_strides and throws an error on failure (if can't compress) static void compressContiguous( - const at::IntList& sizes -, const at::IntList& strides -, const std::vector& cont -, uint32_t* c_sizes -, uint32_t* c_strides) { + const at::IntList& sizes, + const at::IntList& strides, + const std::vector& cont, + uint32_t* c_sizes, + uint32_t* c_strides) { size_t compressed_dims = 0; size_t cur = 0; size_t ndim = sizes.size(); while (cur < ndim) { size_t total_size = sizes[cur]; cur++; - while (cont[cur-1] && cur < ndim) { - JIT_ASSERT(strides[cur-1] == sizes[cur]*strides[cur]); + while (cont[cur - 1] && cur < ndim) { + JIT_ASSERT(strides[cur - 1] == sizes[cur] * strides[cur]); total_size *= sizes[cur]; cur++; } c_sizes[compressed_dims] = total_size; - c_strides[compressed_dims] = strides[cur-1]; + c_strides[compressed_dims] = strides[cur - 1]; compressed_dims++; } - if (ndim > 0) JIT_ASSERT(!cont.back() || strides.back() == 1); + if (ndim > 0) + JIT_ASSERT(!cont.back() || strides.back() == 1); } // Launches the requested fusion on the given device with the given inputs. // Output pointers are stored in outputs (to be put on the stack later). void launchFusion( - const FusedKernel& fusion -, const at::Device device -, const at::ArrayRef& inputs -, std::vector& outputs) { + const FusedKernel& fusion, + const at::Device device, + const at::ArrayRef& inputs, + std::vector& outputs) { // Fails if fusion and given inputs disagree JIT_ASSERT(inputs.size() == fusion.inputDesc().size()); @@ -195,10 +204,13 @@ void launchFusion( numel = computeNumel(map_size); } - // Computes the storage needed to store TensorInfo structs for inputs and outputs. + // Computes the storage needed to store TensorInfo structs for inputs and + // outputs. 
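// Worked example (illustrative counts): with uncompressedDim = 4, each slot
// reserves sizeof(TensorInfo) + 2 * sizeof(uint32_t) * 4 bytes (compressed
// sizes plus strides), and with 3 flattened inputs and 1 flattened output the
// scratch buffer below holds 4 such slots.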
size_t uncompressedDim = fusion.inputDesc().at(0).contiguity.size(); - size_t maxPossibleTensorInfoSize = sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim; - size_t maxPossibleBufferSize = maxPossibleTensorInfoSize * (flat_inputs_size + flat_outputs_size); + size_t maxPossibleTensorInfoSize = + sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim; + size_t maxPossibleBufferSize = + maxPossibleTensorInfoSize * (flat_inputs_size + flat_outputs_size); std::vector buffer(maxPossibleBufferSize); char* buffer_next = buffer.data(); @@ -207,21 +219,16 @@ void launchFusion( arguments.reserve(3 + flat_inputs_size + flat_outputs_size); arguments.push_back(&numel); - auto addTensorInfoRaw = [&]( - const TensorDesc& desc - , void* data_ptr - , at::IntList sizes - , at::IntList strides) { + auto addTensorInfoRaw = [&](const TensorDesc& desc, + void* data_ptr, + at::IntList sizes, + at::IntList strides) { const auto nDim = desc.nDim(); // NOTE: this is the compressed dim JIT_ASSERT(nDim <= uncompressedDim); // We'd overflow the space otherwise auto ti = reinterpret_cast(buffer_next); ti->data = data_ptr; compressContiguous( - sizes - , strides - , desc.contiguity - , ti->sizes(nDim) - , ti->strides(nDim)); + sizes, strides, desc.contiguity, ti->sizes(nDim), ti->strides(nDim)); buffer_next += maxPossibleTensorInfoSize; arguments.push_back(ti); }; @@ -239,10 +246,12 @@ void launchFusion( if (chunk.isNoop()) { addTensorInfo(fusion.inputDesc()[i], tensor); } else { - size_t chunk_offset = map_size[chunk.dim()] * tensor.stride(chunk.dim()) * elementSize(tensor.type().scalarType()); + size_t chunk_offset = map_size[chunk.dim()] * tensor.stride(chunk.dim()) * + elementSize(tensor.type().scalarType()); char* data_ptr = reinterpret_cast(tensor.data_ptr()); for (size_t chunks = 0; chunks < chunk.nSubTensors(); ++chunks) { - addTensorInfoRaw(*chunk.subTensorDesc(), data_ptr, map_size, tensor.strides()); + addTensorInfoRaw( + *chunk.subTensorDesc(), data_ptr, map_size, tensor.strides()); data_ptr += chunk_offset; } } @@ -254,7 +263,8 @@ void launchFusion( for (size_t i = 0; i < fusion.outputDesc().size(); ++i) { const auto& c = fusion.concatDesc()[i]; if (c.isNoop()) { - outputs.push_back(at::empty(map_size, ref_options.dtype(fusion.outputDesc()[i].scalar_type))); + outputs.push_back(at::empty( + map_size, ref_options.dtype(fusion.outputDesc()[i].scalar_type))); addTensorInfo(fusion.outputDesc()[i], outputs[i]); } else { size_t small_size = map_size[c.dim()]; @@ -277,12 +287,10 @@ void launchFusion( fusion.launch_raw(numel, arguments); } - -bool runFusion( - const int64_t key -, Stack& stack) { +bool runFusion(const int64_t key, Stack& stack) { // Short-circuits if fusion isn't enabled - if (!canFuseOnCPU() && !canFuseOnGPU()) return false; + if (!canFuseOnCPU() && !canFuseOnGPU()) + return false; // Acquires the FusionSpec auto maybe_spec = retrieve(key); @@ -294,8 +302,8 @@ bool runFusion( return i.toTensor(); }); - // Determines device to dispatch to. If there's a device mismatch in the inputs, - // we use the fallback (which should give a nice error message). + // Determines device to dispatch to. If there's a device mismatch in the + // inputs, we use the fallback (which should give a nice error message). 
at::Device device = inputs.at(0).device(); at::ScalarType dtype = inputs[0].type().scalarType(); for (const auto& t : at::TensorList(inputs).slice(1)) { @@ -305,14 +313,17 @@ bool runFusion( } // Attempts to run fallback if device fusion is disabled - if (device.is_cuda() && !canFuseOnGPU()) return false; - if (device.is_cpu() && !canFuseOnCPU()) return false; + if (device.is_cuda() && !canFuseOnGPU()) + return false; + if (device.is_cpu() && !canFuseOnCPU()) + return false; // Validates sizes and expands inputs as needed auto maybe_map_size = canRunKernel(spec, inputs); // Tries to run fallback if map size can't be computed - if (!maybe_map_size) return false; + if (!maybe_map_size) + return false; expandArgs(spec, inputs, *maybe_map_size); // Retrieves the kernel, compiling (and caching) if necessary @@ -332,9 +343,9 @@ bool runFusion( // Updates stack drop(stack, spec.nInputs()); stack.insert( - stack.end() - , std::make_move_iterator(outputs.begin()) - , std::make_move_iterator(outputs.end())); + stack.end(), + std::make_move_iterator(outputs.begin()), + std::make_move_iterator(outputs.end())); return true; } diff --git a/torch/csrc/jit/fuser/executor.h b/torch/csrc/jit/fuser/executor.h index c83a621..9af2cd9 100644 --- a/torch/csrc/jit/fuser/executor.h +++ b/torch/csrc/jit/fuser/executor.h @@ -7,13 +7,13 @@ #include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // Runs the fusion associated with the key (see registerFusion() in interface.h) // on the inputs taken from the given Stack. -TORCH_API bool runFusion( - const int64_t key -, Stack& stack); +TORCH_API bool runFusion(const int64_t key, Stack& stack); } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/fuser/fallback.cpp b/torch/csrc/jit/fuser/fallback.cpp index 335f7ea..23868c1 100644 --- a/torch/csrc/jit/fuser/fallback.cpp +++ b/torch/csrc/jit/fuser/fallback.cpp @@ -1,40 +1,41 @@ #include -#include //fmap +#include +#include #include #include #include -#include -#include +#include //fmap #include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { -// Registers fused operators so that fused graphs can properly generate fallback code. -RegisterOperators reg_fused_operators({ - Operator( - prim::FusedConcat - , [](const Node* node) { - int64_t dim = node->i(attr::dim); - int64_t num_inputs = node->inputs().size(); - return [dim, num_inputs](Stack& stack) { - auto result = at::cat( - fmap(last(stack, num_inputs), [](const IValue& i) { return i.toTensor(); }) - , dim - ); - drop(stack, num_inputs); - pack(stack, std::move(result)); - return 0; - }; - }) -}); +// Registers fused operators so that fused graphs can properly generate fallback +// code. 
+RegisterOperators reg_fused_operators( + {Operator(prim::FusedConcat, [](const Node* node) { + int64_t dim = node->i(attr::dim); + int64_t num_inputs = node->inputs().size(); + return [dim, num_inputs](Stack& stack) { + auto result = at::cat( + fmap( + last(stack, num_inputs), + [](const IValue& i) { return i.toTensor(); }), + dim); + drop(stack, num_inputs); + pack(stack, std::move(result)); + return 0; + }; + })}); void runFallback(int64_t key, Stack& stack) { auto maybe_spec = retrieve(key); if (!maybe_spec) throw std::runtime_error("Failed to find fusion spec to run fallback."); - + InterpreterState{(*maybe_spec)->code()}.run(stack); } diff --git a/torch/csrc/jit/fuser/fallback.h b/torch/csrc/jit/fuser/fallback.h index b94a325..ab55218 100644 --- a/torch/csrc/jit/fuser/fallback.h +++ b/torch/csrc/jit/fuser/fallback.h @@ -6,7 +6,9 @@ #include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { void runFallback(int64_t key, Stack& stack); diff --git a/torch/csrc/jit/fuser/fused_kernel.h b/torch/csrc/jit/fuser/fused_kernel.h index 41a4994..39a590c 100644 --- a/torch/csrc/jit/fuser/fused_kernel.h +++ b/torch/csrc/jit/fuser/fused_kernel.h @@ -3,15 +3,17 @@ #if USE_CUDA_FUSER || USE_CPU_FUSER #include -#include -#include #include +#include +#include -#include #include +#include #include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { struct FusedKernel { TH_DISALLOW_COPY_AND_ASSIGN(FusedKernel); @@ -34,7 +36,6 @@ struct FusedKernel { virtual ~FusedKernel() = default; - // arguments is a list of pointers to the arguments for the compiled CUDA/CPU // code. // The format of arguments is suitable for directly passing to a call to @@ -43,23 +44,36 @@ struct FusedKernel { // CUDA code), and the remainder are pointers to the TensorInfo structs // that compiled code uses to load Tensor data. // launch_with_tensors handles packing at::Tensors into this arguments array. - // CPU code uses the same convension so that launch_with_tensors can be shared. - virtual void launch_raw( - const uint32_t numel - , std::vector& arguments) const = 0; + // CPU code uses the same convension so that launch_with_tensors can be + // shared. 
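// Layout sketch for a kernel with two tensor parameters (illustrative only;
// see launchFusion in executor.cpp for the actual packing, and note that CUDA
// kernels using random numbers append extra seed/offset arguments):
//
//   arguments[0] -> &numel              // element count / linearIndex bound
//   arguments[1] -> TensorInfo* for t0
//   arguments[2] -> TensorInfo* for t1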
+ virtual void launch_raw(const uint32_t numel, std::vector& arguments) + const = 0; virtual at::Backend backend() const = 0; // Getters - const std::string& name() const { return name_; } - const std::string& code() const { return code_; } - const std::vector& inputDesc() const { return input_desc_; } - const std::vector& outputDesc() const { return output_desc_; } - const std::vector& chunkDesc() const { return chunk_desc_; } - const std::vector& concatDesc() const { return concat_desc_; } - bool hasRandom() const { return has_random_; } - - -protected: + const std::string& name() const { + return name_; + } + const std::string& code() const { + return code_; + } + const std::vector& inputDesc() const { + return input_desc_; + } + const std::vector& outputDesc() const { + return output_desc_; + } + const std::vector& chunkDesc() const { + return chunk_desc_; + } + const std::vector& concatDesc() const { + return concat_desc_; + } + bool hasRandom() const { + return has_random_; + } + + protected: const std::string name_; const std::string code_; const std::vector input_desc_; diff --git a/torch/csrc/jit/fuser/interface.cpp b/torch/csrc/jit/fuser/interface.cpp index e8d63aa..4e63c6f 100644 --- a/torch/csrc/jit/fuser/interface.cpp +++ b/torch/csrc/jit/fuser/interface.cpp @@ -2,14 +2,15 @@ #include #if USE_CUDA_FUSER || USE_CPU_FUSER - #include - #include - #include +#include +#include +#include #endif // USE_CUDA_FUSER || USE_CPU_FUSER #include -namespace torch { namespace jit { +namespace torch { +namespace jit { namespace detail { @@ -19,34 +20,35 @@ bool cpu_fuser_enabled = false; } // namespace detail int64_t registerFusion(const Node* fusion_group) { - #if USE_CUDA_FUSER || USE_CPU_FUSER - return fuser::registerFusion(fusion_group); - #else - throw std::runtime_error("Fusion not supported for this build."); - #endif // USE_CUDA_FUSER || USE_CPU_FUSER +#if USE_CUDA_FUSER || USE_CPU_FUSER + return fuser::registerFusion(fusion_group); +#else + throw std::runtime_error("Fusion not supported for this build."); +#endif // USE_CUDA_FUSER || USE_CPU_FUSER } void runFusion(const int64_t key, Stack& stack) { - #if USE_CUDA_FUSER || USE_CPU_FUSER - const auto result = fuser::runFusion(key, stack); - if (!result) fuser::runFallback(key, stack); - #else - throw std::runtime_error("Fusion not supported for this build."); - #endif // USE_CUDA_FUSER || USE_CPU_FUSER +#if USE_CUDA_FUSER || USE_CPU_FUSER + const auto result = fuser::runFusion(key, stack); + if (!result) + fuser::runFallback(key, stack); +#else + throw std::runtime_error("Fusion not supported for this build."); +#endif // USE_CUDA_FUSER || USE_CPU_FUSER } bool canFuseOnCPU() { - #if USE_CPU_FUSER - return detail::cpu_fuser_enabled; - #endif // USE_CPU_FUSER +#if USE_CPU_FUSER + return detail::cpu_fuser_enabled; +#endif // USE_CPU_FUSER return false; } bool canFuseOnGPU() { - #if USE_CUDA_FUSER - return true; - #endif // USE_CUDA_FUSER +#if USE_CUDA_FUSER + return true; +#endif // USE_CUDA_FUSER return false; } @@ -58,36 +60,37 @@ void overrideCanFuseOnCPU(bool value) { // Uses the above interface by stuffing the graph into a node and treating that // node as a fusion group. 
std::vector debugLaunchGraph( - Graph& graph -, at::ArrayRef inputs) { - #if USE_CUDA_FUSER || USE_CPU_FUSER - // Creates a fusion group node - auto wrapper_graph = std::make_shared(); - Node* fusion_group = wrapper_graph->insertNode(wrapper_graph->createFusionGroup()); - fusion_group->g_(attr::Subgraph, graph.copy()); - for (size_t i = 0; i < graph.inputs().size(); ++i) { - fusion_group->addInput(wrapper_graph->addInput()); - } - for (size_t i = 0; i < graph.outputs().size(); ++i) { - wrapper_graph->registerOutput(fusion_group->addOutput()); - } - - // Creates the stack, registers and runs the fusion - Stack stack = fmap(inputs); - const auto key = fuser::registerFusion(fusion_group); - fuser::runFusion(key, stack); - return fmap(stack, [](const IValue& iv) { return iv.toTensor(); }); - #else - throw std::runtime_error("Fusion not supported for this build."); - #endif // USE_CUDA_FUSER || USE_CPU_FUSER + Graph& graph, + at::ArrayRef inputs) { +#if USE_CUDA_FUSER || USE_CPU_FUSER + // Creates a fusion group node + auto wrapper_graph = std::make_shared(); + Node* fusion_group = + wrapper_graph->insertNode(wrapper_graph->createFusionGroup()); + fusion_group->g_(attr::Subgraph, graph.copy()); + for (size_t i = 0; i < graph.inputs().size(); ++i) { + fusion_group->addInput(wrapper_graph->addInput()); + } + for (size_t i = 0; i < graph.outputs().size(); ++i) { + wrapper_graph->registerOutput(fusion_group->addOutput()); + } + + // Creates the stack, registers and runs the fusion + Stack stack = fmap(inputs); + const auto key = fuser::registerFusion(fusion_group); + fuser::runFusion(key, stack); + return fmap(stack, [](const IValue& iv) { return iv.toTensor(); }); +#else + throw std::runtime_error("Fusion not supported for this build."); +#endif // USE_CUDA_FUSER || USE_CPU_FUSER } -size_t nCompiledKernels() { - #if USE_CUDA_FUSER || USE_CPU_FUSER - return fuser::nCompiledKernels(); - #else - return 0; - #endif // USE_CUDA_FUSER || USE_CPU_FUSER +size_t nCompiledKernels() { +#if USE_CUDA_FUSER || USE_CPU_FUSER + return fuser::nCompiledKernels(); +#else + return 0; +#endif // USE_CUDA_FUSER || USE_CPU_FUSER } } // namespace jit diff --git a/torch/csrc/jit/fuser/interface.h b/torch/csrc/jit/fuser/interface.h index 89d3b88..8136334 100644 --- a/torch/csrc/jit/fuser/interface.h +++ b/torch/csrc/jit/fuser/interface.h @@ -5,11 +5,12 @@ #include #include +#include #include #include -#include -namespace torch { namespace jit { +namespace torch { +namespace jit { constexpr int kCPUDevice = -1; @@ -27,15 +28,16 @@ TORCH_API void runFusion(const int64_t key, Stack& stack); TORCH_API bool canFuseOnCPU(); TORCH_API bool canFuseOnGPU(); -// Sets whether fusion on the CPU is allowed (disabled by default due to flakiness) +// Sets whether fusion on the CPU is allowed (disabled by default due to +// flakiness) TORCH_API void overrideCanFuseOnCPU(bool value); // Treats the given graph as a fusion group and launches it on the // specified device with the given inputs. // Returns the outputs. 
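// Usage sketch (graph and tensor names are placeholders):
//
//   std::vector<at::Tensor> outputs =
//       torch::jit::debugLaunchGraph(*graph, {a, b});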
TORCH_API std::vector debugLaunchGraph( - Graph& graph -, at::ArrayRef inputs); + Graph& graph, + at::ArrayRef inputs); TORCH_API size_t nCompiledKernels(); diff --git a/torch/csrc/jit/fuser/kernel_cache.cpp b/torch/csrc/jit/fuser/kernel_cache.cpp index 3c52ad8..4496900 100644 --- a/torch/csrc/jit/fuser/kernel_cache.cpp +++ b/torch/csrc/jit/fuser/kernel_cache.cpp @@ -2,15 +2,17 @@ #include #include -#include -#include #include +#include +#include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { struct KernelCacheImpl { // Note: std::unordered_map does not invalidate references even if rehashing - // occurs. This is a critical property for thread-safety. + // occurs. This is a critical property for thread-safety. std::mutex mutex_; int64_t kernel_counter{0}; @@ -33,7 +35,8 @@ int64_t debugNumCachedKernelSpecs() { return cache.specMap_.size(); } -std::shared_ptr normalizeGraphForCache(const std::shared_ptr& graph) { +std::shared_ptr normalizeGraphForCache( + const std::shared_ptr& graph) { auto result = Canonicalize(graph, /*keep_unique_names=*/false); EraseShapeInformation(result); return result; @@ -49,22 +52,24 @@ int64_t store(std::shared_ptr graph) { std::lock_guard guard{cache.mutex_}; const auto key = cache.kernel_counter++; cache.specMap_.emplace( - std::piecewise_construct - , std::forward_as_tuple(key) - , std::forward_as_tuple(key, graph)); + std::piecewise_construct, + std::forward_as_tuple(key), + std::forward_as_tuple(key, graph)); cache.graphToKey_.emplace(std::make_pair(std::move(repr), key)); return key; } // XXX: Does not grab mutex static at::optional nolock_retrieve( - KernelCacheImpl& cache, const int64_t key) { + KernelCacheImpl& cache, + const int64_t key) { auto it = cache.specMap_.find(key); - if (it == cache.specMap_.end()) return at::nullopt; + if (it == cache.specMap_.end()) + return at::nullopt; return &(it->second); } -at::optional retrieve(const int64_t key) { +at::optional retrieve(const int64_t key) { auto& cache = getKernelCache(); std::lock_guard guard{cache.mutex_}; return nolock_retrieve(cache, key); @@ -77,7 +82,8 @@ at::optional lookupGraph(std::shared_ptr graph) { std::lock_guard guard{cache.mutex_}; auto it = cache.graphToKey_.find(repr); - if (it == cache.graphToKey_.end()) return at::nullopt; + if (it == cache.graphToKey_.end()) + return at::nullopt; return nolock_retrieve(cache, it->second); } diff --git a/torch/csrc/jit/fuser/kernel_cache.h b/torch/csrc/jit/fuser/kernel_cache.h index 63b4710..792591c 100644 --- a/torch/csrc/jit/fuser/kernel_cache.h +++ b/torch/csrc/jit/fuser/kernel_cache.h @@ -4,18 +4,21 @@ #include #include -#include #include +#include -#include +#include #include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // A thread-safe cache interface. 
// Normalizes the graph by canonicalizing and erasing shape information -TORCH_API std::shared_ptr normalizeGraphForCache(const std::shared_ptr& graph); +TORCH_API std::shared_ptr normalizeGraphForCache( + const std::shared_ptr& graph); // Stores the given graph, returning the key used to access it TORCH_API int64_t store(std::shared_ptr graph); diff --git a/torch/csrc/jit/fuser/kernel_spec.h b/torch/csrc/jit/fuser/kernel_spec.h index 5942bac..50a94a6 100644 --- a/torch/csrc/jit/fuser/kernel_spec.h +++ b/torch/csrc/jit/fuser/kernel_spec.h @@ -3,22 +3,24 @@ #if USE_CUDA_FUSER || USE_CPU_FUSER #include -#include #include -#include -#include -#include -#include +#include #include #include +#include +#include +#include +#include -#include #include -#include -#include +#include #include +#include +#include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // Helper struct containing partition information: the number of tensors // created and the dimension the partitioning is performed on. @@ -26,32 +28,32 @@ namespace torch { namespace jit { namespace fuser { // at runtime the partition info is logically combined with the tensor // descriptions to create PartitionDesc objects. struct TORCH_API PartitionInfo { - PartitionInfo( - const int64_t _nSubTensors - , const int64_t _dim) - : nSubTensors_{_nSubTensors} - , dim_{_dim} - { }; + PartitionInfo(const int64_t _nSubTensors, const int64_t _dim) + : nSubTensors_{_nSubTensors}, dim_{_dim} {}; - int64_t nSubTensors() const { return nSubTensors_; } - int64_t dim() const { return dim_; } + int64_t nSubTensors() const { + return nSubTensors_; + } + int64_t dim() const { + return dim_; + } -private: + private: int64_t nSubTensors_; int64_t dim_; }; - // "Kernel Specification." - Contains device-independent fusion information. - // Each kernel specification contains a map of instantiated generated functions - // that implement some or most of its functionality. Multiple generated - // functions are needed by each abstract specification because of different - // devices (cpu vs gpu, different gpus) and different inputs (int vs float, - // contiguous vs discontiguous). - // Note: uses a mutex to control access to its kernel store - // Note: unordered containers do not invalidate references/pointers on - // rehashing, which is critical for thread-safety. - // TODO: allow abstract kernels to use multiple generated kernels - // TODO: allow abstract kernels to reuse generated kernels from common pool +// "Kernel Specification." - Contains device-independent fusion information. +// Each kernel specification contains a map of instantiated generated functions +// that implement some or most of its functionality. Multiple generated +// functions are needed by each abstract specification because of different +// devices (cpu vs gpu, different gpus) and different inputs (int vs float, +// contiguous vs discontiguous). +// Note: uses a mutex to control access to its kernel store +// Note: unordered containers do not invalidate references/pointers on +// rehashing, which is critical for thread-safety. 
+// TODO: allow abstract kernels to use multiple generated kernels +// TODO: allow abstract kernels to reuse generated kernels from common pool struct TORCH_API KernelSpec { KernelSpec(const int64_t _key, const std::shared_ptr& _graph) : key_{_key}, @@ -63,10 +65,18 @@ struct TORCH_API KernelSpec { kernels_{} {} // Getters - int64_t key() const { return key_; } - std::shared_ptr graph() const { return graph_; } - const Code& code() const { return code_; } - int64_t nInputs() const { return nInputs_; } + int64_t key() const { + return key_; + } + std::shared_ptr graph() const { + return graph_; + } + const Code& code() const { + return code_; + } + int64_t nInputs() const { + return nInputs_; + } std::vector>& inputBroadcastGroups() { return inputBroadcastGroups_; @@ -75,24 +85,29 @@ struct TORCH_API KernelSpec { return inputBroadcastGroups_; } - std::vector& inputChunks() { return inputChunks_; } - const std::vector& inputChunks() const { return inputChunks_; } + std::vector& inputChunks() { + return inputChunks_; + } + const std::vector& inputChunks() const { + return inputChunks_; + } // Cache functions - c10::optional> findKernel(const ArgSpec& arg_spec) const { + c10::optional> findKernel( + const ArgSpec& arg_spec) const { std::lock_guard guard{mutex_}; const auto it = kernels_.find(arg_spec); - if (it == kernels_.end()) return c10::nullopt; + if (it == kernels_.end()) + return c10::nullopt; return it->second; } - void cacheKernel( - const ArgSpec& arg_spec - , std::shared_ptr kernel) const { + void cacheKernel(const ArgSpec& arg_spec, std::shared_ptr kernel) + const { std::lock_guard guard{mutex_}; kernels_.emplace(arg_spec, kernel); } -private: + private: int64_t key_; std::shared_ptr graph_; Code code_; @@ -100,10 +115,9 @@ private: std::vector> inputBroadcastGroups_; std::vector inputChunks_; mutable std::mutex mutex_; - mutable std::unordered_map< - ArgSpec - , std::shared_ptr - , torch::hash> kernels_; + mutable std:: + unordered_map, torch::hash> + kernels_; }; } // namespace fuser diff --git a/torch/csrc/jit/fuser/partition_desc.h b/torch/csrc/jit/fuser/partition_desc.h index 08b0802..16408d5 100644 --- a/torch/csrc/jit/fuser/partition_desc.h +++ b/torch/csrc/jit/fuser/partition_desc.h @@ -6,28 +6,23 @@ #include #include -#include #include +#include #include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // Descriptor for chunk-ing an input tensor into subtensors // OR concat-ing an output tensor from subtensors // Note: default constructed used for tensors that do not participate in // chunk or cat operations. 
struct TORCH_API PartitionDesc { - PartitionDesc() - : nSubTensors_{1} - , dim_{0} - { } - - PartitionDesc( - const TensorDesc& _desc - , size_t _nSubTensors - , size_t _dim) - : nSubTensors_{_nSubTensors} - , dim_{_dim} { + PartitionDesc() : nSubTensors_{1}, dim_{0} {} + + PartitionDesc(const TensorDesc& _desc, size_t _nSubTensors, size_t _dim) + : nSubTensors_{_nSubTensors}, dim_{_dim} { JIT_ASSERT(nSubTensors_ > 1); std::vector cont = _desc.contiguity; if (dim_ > 0) { @@ -40,20 +35,32 @@ struct TORCH_API PartitionDesc { subTensorDesc_.reset(new TensorDesc(_desc.scalar_type, cont)); } - bool isNoop() const { return (nSubTensors_ == 1);} - size_t nSubTensors() const { return nSubTensors_; } - size_t dim() const { return dim_; } - std::shared_ptr subTensorDesc() { return subTensorDesc_; } - const std::shared_ptr subTensorDesc() const { return subTensorDesc_; } + bool isNoop() const { + return (nSubTensors_ == 1); + } + size_t nSubTensors() const { + return nSubTensors_; + } + size_t dim() const { + return dim_; + } + std::shared_ptr subTensorDesc() { + return subTensorDesc_; + } + const std::shared_ptr subTensorDesc() const { + return subTensorDesc_; + } -private: - size_t nSubTensors_; // == 1 for tensors that should not be operated on via chunk/cat + private: + size_t nSubTensors_; // == 1 for tensors that should not be operated on via + // chunk/cat size_t dim_; // dimension along which the chunk/concat occurs - std::shared_ptr subTensorDesc_; // descriptor for the subtensor, if it exists + std::shared_ptr + subTensorDesc_; // descriptor for the subtensor, if it exists }; } // namespace fuser -} // namespace jit +} // namespace jit } // namespace torch #endif // USE_CUDA_FUSER || USE_CPU_FUSER diff --git a/torch/csrc/jit/fuser/tensor_desc.h b/torch/csrc/jit/fuser/tensor_desc.h index d1f0b60..fb02867 100644 --- a/torch/csrc/jit/fuser/tensor_desc.h +++ b/torch/csrc/jit/fuser/tensor_desc.h @@ -4,15 +4,17 @@ #include #include -#include -#include #include +#include +#include -#include -#include #include +#include +#include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // type information needed by the compiler for input/outputs // contiguity[i] is true if the dim i is contiguous with dim i + 1. @@ -21,33 +23,33 @@ struct TORCH_API TensorDesc { at::ScalarType scalar_type; std::vector contiguity; - TensorDesc( - const at::ScalarType& type - , const std::vector& contiguity) - : scalar_type{type} - , contiguity{contiguity} { + TensorDesc(const at::ScalarType& type, const std::vector& contiguity) + : scalar_type{type}, contiguity{contiguity} { if (contiguity.size() == 0) { nDim_ = 0; } else { - nDim_ = std::count(contiguity.begin(), contiguity.end(), false) + (lastIsContiguous() ? 1 : 0); + nDim_ = std::count(contiguity.begin(), contiguity.end(), false) + + (lastIsContiguous() ? 
1 : 0); } } // Delegating constructors TensorDesc( - const at::ScalarType& type - , const at::IntList& sizes - , const at::IntList& strides) - : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {} + const at::ScalarType& type, + const at::IntList& sizes, + const at::IntList& strides) + : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {} TensorDesc(const at::Tensor& t) - : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {} + : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {} TensorDesc(const CompleteTensorTypePtr& type) - : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {} + : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {} // number of dimensions after contiguity compression - size_t nDim() const { return nDim_; } + size_t nDim() const { + return nDim_; + } // True iff innermost stride is 1 bool lastIsContiguous() const { @@ -55,12 +57,13 @@ struct TORCH_API TensorDesc { } static std::vector findContiguous( - const at::IntList& sizes - , const at::IntList& strides) { + const at::IntList& sizes, + const at::IntList& strides) { JIT_ASSERT(sizes.size() == strides.size()); std::vector cont(sizes.size()); for (size_t i = 0; i < sizes.size(); ++i) { - const auto expected_stride = (i + 1 < sizes.size()) ? sizes[i+1]*strides[i+1] : 1; + const auto expected_stride = + (i + 1 < sizes.size()) ? sizes[i + 1] * strides[i + 1] : 1; cont[i] = (strides[i] == expected_stride); } return cont; @@ -75,10 +78,13 @@ struct TORCH_API TensorDesc { } static size_t hash(const TensorDesc& spec) { - return torch::get_hash(spec.scalar_type, spec.nDim_, std::hash>{}(spec.contiguity)); + return torch::get_hash( + spec.scalar_type, + spec.nDim_, + std::hash>{}(spec.contiguity)); } -private: + private: size_t nDim_; }; diff --git a/torch/csrc/jit/fuser/tensor_info.h b/torch/csrc/jit/fuser/tensor_info.h index 487359c..161cf3b 100644 --- a/torch/csrc/jit/fuser/tensor_info.h +++ b/torch/csrc/jit/fuser/tensor_info.h @@ -6,23 +6,28 @@ #include -namespace torch { namespace jit { namespace fuser { +namespace torch { +namespace jit { +namespace fuser { // Host-side view of TensorInfo // Note dims[0] - we need to dynamically allocate the dims. struct TORCH_API TensorInfo { - - uint32_t* sizes(size_t nDim) { return &sizes_strides[0]; } - uint32_t* strides(size_t nDim) { return &sizes_strides[nDim]; } + uint32_t* sizes(size_t nDim) { + return &sizes_strides[0]; + } + uint32_t* strides(size_t nDim) { + return &sizes_strides[nDim]; + } void* data; - #pragma GCC diagnostic ignored "-Wpedantic" - uint32_t sizes_strides[0]; - #pragma GCC diagnostic pop +#pragma GCC diagnostic ignored "-Wpedantic" + uint32_t sizes_strides[0]; +#pragma GCC diagnostic pop }; } // namespace fuser -} // namespace jit +} // namespace jit } // namespace torch #endif // USE_CUDA_FUSER || USE_CPU_FUSER diff --git a/torch/csrc/jit/generic_if.h b/torch/csrc/jit/generic_if.h index fa839ba..407b305 100644 --- a/torch/csrc/jit/generic_if.h +++ b/torch/csrc/jit/generic_if.h @@ -1,17 +1,23 @@ // TODO: I'm pretty sure Constness can be done with C++ templates, ala // std::is_const, but no time to work it out... 
-#define GENERIC_IF(Constness, FullKind, x, Kind) \ - auto && __match_key = x; \ - switch(__match_key->kind()) { \ - case FullKind: { \ - auto * value = static_cast(__match_key); (void) value; -#define GENERIC_ELSEIF(Constness, FullKind, Kind) \ - } break; \ - case FullKind: { \ - auto * value = static_cast(__match_key); (void) value; +#define GENERIC_IF(Constness, FullKind, x, Kind) \ + auto&& __match_key = x; \ + switch (__match_key->kind()) { \ + case FullKind: { \ + auto* value = static_cast(__match_key); \ + (void)value; +#define GENERIC_ELSEIF(Constness, FullKind, Kind) \ + } \ + break; \ + case FullKind: { \ + auto* value = static_cast(__match_key); \ + (void)value; #define GENERIC_ELSE() \ - } break; \ - default: { + } \ + break; \ + default: { #define GENERIC_END() \ - } break; \ - }; + } \ + break; \ + } \ + ; diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index ab0ab40..6d4552f 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -1,46 +1,47 @@ #include -#include #include #include +#include #include +#include #include #include -#include +#include #include +#include #include #include +#include #include #include #include +#include #include -#include -#include -#include -#include -#include #include #include -#include -#include +#include +#include #include +#include +#include #include -#include -#include +#include #include #include #include #include +#include #include #include #include #include #include -#include -namespace torch { namespace jit { +namespace torch { +namespace jit { namespace { @@ -51,8 +52,7 @@ using autograd::variable_list; struct ExecutionPlan { ExecutionPlan() = default; ExecutionPlan(std::shared_ptr graph) - : code(graph) - , graph(std::move(graph)) {} + : code(graph), graph(std::move(graph)) {} void run(Stack& stack) const { return InterpreterState(code).run(stack); @@ -75,7 +75,7 @@ struct ExecutionPlan { struct DifferentiableGraphBackward : public autograd::Function { DifferentiableGraphBackward(GraphExecutor executor, size_t capture_size) - : executor(std::move(executor)) { + : executor(std::move(executor)) { is_var_capture.reserve(capture_size); var_captures.reserve(capture_size); ivalue_captures.reserve(capture_size); @@ -84,8 +84,10 @@ struct DifferentiableGraphBackward : public autograd::Function { variable_list apply(variable_list&& inputs) override { Stack stack; stack.reserve(is_var_capture.size() + inputs.size()); - stack.insert(stack.end(), std::make_move_iterator(inputs.begin()), - std::make_move_iterator(inputs.end())); + stack.insert( + stack.end(), + std::make_move_iterator(inputs.begin()), + std::make_move_iterator(inputs.end())); auto var_capture_it = var_captures.begin(); auto ivalue_capture_it = ivalue_captures.begin(); for (bool is_var : is_var_capture) { @@ -106,11 +108,12 @@ struct DifferentiableGraphBackward : public autograd::Function { for (size_t i = 0; i < num_outputs(); ++i) { if (should_compute_output(i)) { auto output = std::move(stack[i]).toTensor(); - const auto & edge = next_edge(i); + const auto& edge = next_edge(i); if (output.defined()) { outputs.emplace_back(std::move(output)); } else if (edge.is_valid()) { - outputs.emplace_back(edge.function->input_metadata(edge.input_nr).zeros_like()); + outputs.emplace_back( + edge.function->input_metadata(edge.input_nr).zeros_like()); } else { outputs.emplace_back(); } @@ -121,7 +124,7 @@ struct DifferentiableGraphBackward : public autograd::Function { return outputs; } - void capture(const IValue & val, bool 
is_output) { + void capture(const IValue& val, bool is_output) { const bool is_tensor = val.isTensor(); is_var_capture.push_back(is_tensor); if (is_tensor) { @@ -130,11 +133,13 @@ struct DifferentiableGraphBackward : public autograd::Function { ivalue_captures.push_back(val); } } -private: + + private: friend struct ExecutionPlan; GraphExecutor executor; - // INVARIANT: is_var_capture.size() == var_captures.size() + ivalue_captures.size() + // INVARIANT: is_var_capture.size() == var_captures.size() + + // ivalue_captures.size() std::vector is_var_capture; std::vector var_captures; std::vector ivalue_captures; @@ -154,16 +159,20 @@ struct DifferentiableGraphOp { num_outputs(this->grad.f->outputs().size()) {} // XXX: keep in mind that stack can be larger than the inputs we need! - int operator()(Stack & stack) const { - auto grad_fn = std::make_shared(grad_executor, - grad.df_input_captured_inputs.size() + grad.df_input_captured_outputs.size()); + int operator()(Stack& stack) const { + auto grad_fn = std::make_shared( + grad_executor, + grad.df_input_captured_inputs.size() + + grad.df_input_captured_outputs.size()); { auto inputs = last(stack, num_inputs); - // hook up the outputs of df to the gradient functions of the inputs that require gradients - for(auto idx : grad.df_output_vjps) { + // hook up the outputs of df to the gradient functions of the inputs that + // require gradients + for (auto idx : grad.df_output_vjps) { auto v = Variable(inputs[idx].toTensor()); - grad_fn->add_next_edge(v.defined() ? v.gradient_edge() : autograd::Edge{}); + grad_fn->add_next_edge( + v.defined() ? v.gradient_edge() : autograd::Edge{}); } captureInputs(*grad_fn, inputs); } @@ -175,19 +184,19 @@ struct DifferentiableGraphOp { auto outputs = last(stack, num_outputs); // hookup the gradients for the output tensors that require gradients // to the inputs to our gradient function df - // TODO - XXX - if any output is the same tensor multiple times, views have to be - // setup here. We need to refactor autograd until it is safe for - // tensors to be constructed without all the viewing infrastructure. - // this is currently intentionally not done here so we can get an idea of our - // perf before introducing overhead for correctness - for(auto idx : grad.df_input_vjps) { + // TODO - XXX - if any output is the same tensor multiple times, views + // have to be setup here. We need to refactor autograd until it is safe + // for tensors to be constructed without all the viewing infrastructure. + // this is currently intentionally not done here so we can get an idea of + // our perf before introducing overhead for correctness + for (auto idx : grad.df_input_vjps) { // Note: we have to set this up in place, or we have to throw away and // reallocate variables that were already created in wrapTensors. We // should add an API for this. Variable output = outputs[idx].toTensor(); - // NB: since our requires_grad setting is only a heuristic we might end up - // wanting to differentiate through integral tensors, which is generally a - // hard error in autograd. + // NB: since our requires_grad setting is only a heuristic we might end + // up wanting to differentiate through integral tensors, which is + // generally a hard error in autograd. 
if (at::isFloatingType(output.type().scalarType())) { autograd::create_gradient_edge(output, grad_fn); output.set_requires_grad(true); @@ -204,30 +213,37 @@ struct DifferentiableGraphOp { return 0; } -private: + private: friend GraphExecutor* detail::getGradExecutor(Operation& op); - void detachVariables(Stack & stack) const { - // It would be nice to use an ArrayRef here, but unfortunately those can only - // return const references, so we need to do a bunch of indexing ourselves. + void detachVariables(Stack& stack) const { + // It would be nice to use an ArrayRef here, but unfortunately those can + // only return const references, so we need to do a bunch of indexing + // ourselves. const int64_t stack_size = stack.size(); const int64_t stack_offset = stack_size - num_inputs; for (int64_t i = stack_offset; i < stack_size; ++i) { - auto & v = stack[i]; - if (!v.isTensor()) continue; + auto& v = stack[i]; + if (!v.isTensor()) + continue; auto t = std::move(v).toTensor(); - v = IValue{t.defined() ? autograd::as_variable_ref(t).detach() : std::move(t)}; + v = IValue{t.defined() ? autograd::as_variable_ref(t).detach() + : std::move(t)}; } } // Capture (save) inputs that would be required to subsequently run backwards - void captureInputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef inputs) const { + void captureInputs( + DifferentiableGraphBackward& grad_fn, + at::ArrayRef inputs) const { for (size_t offset : grad.df_input_captured_inputs) { - grad_fn.capture(inputs[offset], /*is_output*/false); + grad_fn.capture(inputs[offset], /*is_output*/ false); } } - void captureOutputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef outputs) const { + void captureOutputs( + DifferentiableGraphBackward& grad_fn, + at::ArrayRef outputs) const { for (size_t offset : grad.df_input_captured_outputs) { - grad_fn.capture(outputs[offset], /*is_output*/true); + grad_fn.capture(outputs[offset], /*is_output*/ true); } } @@ -239,39 +255,42 @@ private: const size_t num_outputs; }; -void packGradient(Gradient gradient, Node *dnode) { +void packGradient(Gradient gradient, Node* dnode) { JIT_ASSERT(dnode->kind() == prim::DifferentiableGraph); dnode->g_(attr::Subgraph, gradient.f) - ->g_(attr::ReverseSubgraph, gradient.df) - ->i_(attr::f_real_outputs, gradient.f_real_outputs) - ->is_(attr::df_input_vjps, fmap(gradient.df_input_vjps)) - ->is_(attr::df_input_captured_inputs, fmap(gradient.df_input_captured_inputs)) - ->is_(attr::df_input_captured_outputs, fmap(gradient.df_input_captured_outputs)) - ->is_(attr::df_output_vjps, fmap(gradient.df_output_vjps)); + ->g_(attr::ReverseSubgraph, gradient.df) + ->i_(attr::f_real_outputs, gradient.f_real_outputs) + ->is_(attr::df_input_vjps, fmap(gradient.df_input_vjps)) + ->is_( + attr::df_input_captured_inputs, + fmap(gradient.df_input_captured_inputs)) + ->is_( + attr::df_input_captured_outputs, + fmap(gradient.df_input_captured_outputs)) + ->is_(attr::df_output_vjps, fmap(gradient.df_output_vjps)); } -Gradient getGradient(const Node *n) { +Gradient getGradient(const Node* n) { JIT_ASSERT(n->kind() == prim::DifferentiableGraph); Gradient grad; grad.f = n->g(attr::Subgraph); grad.df = n->g(attr::ReverseSubgraph); grad.f_real_outputs = n->i(attr::f_real_outputs); grad.df_input_vjps = fmap(n->is(attr::df_input_vjps)); - grad.df_input_captured_inputs = fmap(n->is(attr::df_input_captured_inputs)); - grad.df_input_captured_outputs = fmap(n->is(attr::df_input_captured_outputs)); + grad.df_input_captured_inputs = + fmap(n->is(attr::df_input_captured_inputs)); + 
grad.df_input_captured_outputs = + fmap(n->is(attr::df_input_captured_outputs)); grad.df_output_vjps = fmap(n->is(attr::df_output_vjps)); return grad; } } // anonymous namespace -RegisterOperators reg_graph_executor_ops({ - Operator( - prim::DifferentiableGraph, - [](const Node *n) -> Operation { +RegisterOperators reg_graph_executor_ops( + {Operator(prim::DifferentiableGraph, [](const Node* n) -> Operation { return DifferentiableGraphOp(getGradient(n)); - }) -}); + })}); namespace detail { @@ -286,11 +305,10 @@ GraphExecutor* getGradExecutor(Operation& op) { // a Graph can be created via tracing, or via a language-based frontend // GraphExecutor runs it. It can run the same graph on many different sizes -// and different requires_grad states, and handles specializations for each situation. -// GraphExecutor is completely unaware of tracing or module parameters to keep the -// tracing concerns separated. +// and different requires_grad states, and handles specializations for each +// situation. GraphExecutor is completely unaware of tracing or module +// parameters to keep the tracing concerns separated. struct GraphExecutorImpl { - static std::shared_ptr prepareGraph(std::shared_ptr& graph) { auto copy = graph->copy(); EraseShapeInformation(copy); @@ -303,7 +321,7 @@ struct GraphExecutorImpl { } if (auto tuple_type = ptr->cast()) { size_t total = 0; - for (auto & elem : tuple_type->elements()) { + for (auto& elem : tuple_type->elements()) { total += countFlatInputs(elem); } return total; @@ -313,18 +331,18 @@ struct GraphExecutorImpl { static size_t countFlatInputs(const std::shared_ptr& graph) { size_t total = 0; - for (Value * input : graph->inputs()) { + for (Value* input : graph->inputs()) { total += countFlatInputs(input->type()); } return total; } inline bool hasMutableOperators(Block* block) { - for(auto n : block->nodes()) { - if(n->kind().is_aten() && n->schema().is_mutable()) + for (auto n : block->nodes()) { + if (n->kind().is_aten() && n->schema().is_mutable()) return true; - for(auto b : n->blocks()) { - if(hasMutableOperators(b)) + for (auto b : n->blocks()) { + if (hasMutableOperators(b)) return true; } } @@ -341,21 +359,28 @@ struct GraphExecutorImpl { num_outputs(this->graph->outputs().size()) {} // entry point where execution begins - void run(Stack & stack) { - AT_CHECK(stack.size() >= num_inputs, "expected ", num_inputs, " inputs, but got only ", stack.size()); - - if(tracer::isTracing()) { + void run(Stack& stack) { + AT_CHECK( + stack.size() >= num_inputs, + "expected ", + num_inputs, + " inputs, but got only ", + stack.size()); + + if (tracer::isTracing()) { return runTraced(stack); } - auto & execution_plan = optimize ? getOrCompile(stack) : getOrCompileFallback(); + auto& execution_plan = + optimize ? 
getOrCompile(stack) : getOrCompileFallback(); return execution_plan.run(stack); } std::shared_ptr graphFor(const Stack& stack) const { JIT_ASSERT(stack.size() >= num_inputs); auto inputs = last(stack, num_inputs); - ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs, num_flat_inputs); + ArgumentSpec spec( + autograd::GradMode::is_enabled(), inputs, num_flat_inputs); if (!optimize) { AT_CHECK(fallback, "No graph found for given inputs"); @@ -373,7 +398,7 @@ struct GraphExecutorImpl { if (fallback) { state.fallback = fallback.getDebugState(); } - for (auto & entry : plan_cache) { + for (auto& entry : plan_cache) { state.execution_plans.emplace(entry.first, entry.second.getDebugState()); } return state; @@ -387,12 +412,12 @@ autodiffSubgraphInlineThreshold = 1; } -private: + private: friend struct GraphExecutor; - const ExecutionPlan & getOrCompileFallback() { + const ExecutionPlan& getOrCompileFallback() { std::lock_guard lock(compile_mutex); - if(!fallback) { + if (!fallback) { auto graph_ = graph->copy(); runRequiredPasses(graph_); fallback = ExecutionPlan(graph_); @@ -400,10 +425,13 @@ private: return fallback; } - const ExecutionPlan & getOrCompile(const Stack& stack) { - // outside lock guard, to minimize the time holding the lock on the fast path - // ArgumentSpec even computes its hashCode here. - ArgumentSpec spec(autograd::GradMode::is_enabled(), last(stack, num_inputs), num_flat_inputs); + const ExecutionPlan& getOrCompile(const Stack& stack) { - // outside lock guard, to minimize the time holding the lock on the fast + // path ArgumentSpec even computes its hashCode here. + ArgumentSpec spec( + autograd::GradMode::is_enabled(), + last(stack, num_inputs), + num_flat_inputs); { std::lock_guard lock(compile_mutex); auto it = plan_cache.find(spec); @@ -415,7 +443,7 @@ private: } } - ExecutionPlan compileSpec(const ArgumentSpec & spec) { + ExecutionPlan compileSpec(const ArgumentSpec& spec) { auto opt_graph = graph->copy(); setInputTypes(*opt_graph, spec); @@ -427,13 +455,14 @@ private: // Phase 2. Propagate detailed information about the spec through the // graph (enables more specializations in later passes). // Shape propagation sometimes depends on certain arguments being - // constants, and constant propagation doesn't need shape information - // anyway, so it's better to run it first. + // constants, and constant propagation doesn't need shape + // information anyway, so it's better to run it first. ConstantPropagation(opt_graph); PropagateInputShapes(opt_graph); PropagateRequiresGrad(opt_graph); - // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites that + // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites + // that // we can still execute using autograd). runOptimization(opt_graph, spec); @@ -442,8 +471,9 @@ private: // Phase 5. Apply non-differentiable optimizations to the graphs we've found // (or the whole graph if we know we won't need its derivative).
if (needsGradient(opt_graph)) { - auto diff_nodes = CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold); - for (Node * dnode : diff_nodes) { + auto diff_nodes = + CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold); + for (Node* dnode : diff_nodes) { auto diff_graph = std::move(dnode->g(attr::Subgraph)); Gradient gradient = differentiate(diff_graph); runNondiffOptimization(gradient.f); @@ -458,7 +488,9 @@ private: return ExecutionPlan(opt_graph); } - void runOptimization(std::shared_ptr& graph, const ArgumentSpec& spec) { + void runOptimization( + std::shared_ptr& graph, + const ArgumentSpec& spec) { // Basic graph preprocessing to eliminate noise. EliminateDeadCode(graph); EliminateCommonSubexpression(graph); @@ -486,7 +518,7 @@ private: return false; if (mayIntroduceGradient(graph->block())) return true; - for (const Value * input : graph->inputs()) { + for (const Value* input : graph->inputs()) { if (input->type()->requires_grad()) return true; } @@ -505,17 +537,18 @@ private: return false; } - void runTraced(Stack & stack) { + void runTraced(Stack& stack) { const auto& state = tracer::getTracingState(); auto inputs = last(stack, num_inputs); - auto input_values = fmap(inputs, [](const IValue & v) { - return tracer::getNestedValueTrace(v); - }); - - ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs, num_flat_inputs); - // NB: we could just run the fallback in here and call it a day, but that would loose all - // the control flow information we have in the graph. Thus, we run the fallback to - // get the correct output values, but we will override the tracing states later. + auto input_values = fmap( + inputs, [](const IValue& v) { return tracer::getNestedValueTrace(v); }); + + ArgumentSpec spec( + autograd::GradMode::is_enabled(), inputs, num_flat_inputs); + // NB: we could just run the fallback in here and call it a day, but that + // would loose all the control flow information we have in the graph. Thus, + // we run the fallback to get the correct output values, but we will + // override the tracing states later. { // No need to trace a script module. ResourceGuard guard(tracer::pauseTracing()); @@ -530,7 +563,8 @@ private: auto local_graph = this->graph->copy(); setInputTypes(*local_graph, spec); PropagateInputShapes(local_graph); - auto output_values = inlineCallTo(*state->graph, *local_graph, input_values); + auto output_values = + inlineCallTo(*state->graph, *local_graph, input_values); auto outputs = last(stack, num_outputs); for (size_t i = 0; i < outputs.size(); ++i) { @@ -538,27 +572,29 @@ private: } } - // The unoptimized starting graph. This field is effectively const, but we can't make it so - // because Graph::copy() is not const (and making it const is not that easy at this point). + // The unoptimized starting graph. This field is effectively const, but we + // can't make it so because Graph::copy() is not const (and making it const is + // not that easy at this point). std::shared_ptr graph; - // If false, we'll run the graph as we get it, without any optimizations. Useful - // for debugging. + // If false, we'll run the graph as we get it, without any optimizations. + // Useful for debugging. const bool optimize; const size_t num_inputs; - const size_t num_flat_inputs; // Number of inputs, assuming all tuples would be flattened. + const size_t num_flat_inputs; // Number of inputs, assuming all tuples would + // be flattened. 
const size_t num_outputs; - // Populated only when optimize is false (and in that case plan_cache will be unused). - // The compiled version of graph. + // Populated only when optimize is false (and in that case plan_cache will be + // unused). The compiled version of graph. ExecutionPlan fallback; - // Mapping from argument configurations to optimized versions of the graph that are - // specialized to the spec. + // Mapping from argument configurations to optimized versions of the graph + // that are specialized to the spec. std::unordered_map plan_cache; - // GraphExecutors can be accessed from multiple threads, so this thread needs to be - // held every time we access the fallback or plan_cache. + // GraphExecutors can be accessed from multiple threads, so this thread needs + // to be held every time we access the fallback or plan_cache. std::mutex compile_mutex; // Some tunable parameters @@ -567,9 +603,9 @@ private: }; GraphExecutor::GraphExecutor(std::shared_ptr graph, bool optimize) -: pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {} + : pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {} -void GraphExecutor::run(Stack & inputs) { +void GraphExecutor::run(Stack& inputs) { return pImpl->run(inputs); } @@ -589,8 +625,7 @@ void GraphExecutor::debugDisableAutodiffSubgraphInlining() { return pImpl->debugDisableAutodiffSubgraphInlining(); } - -void runRequiredPasses(const std::shared_ptr& g) { +void runRequiredPasses(const std::shared_ptr& g) { specializeUndef(*g); LowerGradOf(*g); // implicit inserted expand nodes are not necessarily always valid @@ -602,4 +637,5 @@ void runRequiredPasses(const std::shared_ptr& g) { EliminateDeadCode(g); } -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/graph_executor.h b/torch/csrc/jit/graph_executor.h index a53056c..e6564e5 100644 --- a/torch/csrc/jit/graph_executor.h +++ b/torch/csrc/jit/graph_executor.h @@ -1,13 +1,14 @@ #pragma once -#include +#include +#include +#include #include #include -#include -#include -#include +#include -namespace torch { namespace jit { +namespace torch { +namespace jit { struct GraphExecutorState; @@ -29,7 +30,7 @@ struct GraphExecutorImpl; struct TORCH_API GraphExecutor { GraphExecutor() = default; GraphExecutor(std::shared_ptr graph, bool optimize = true); - void run(Stack & inputs); + void run(Stack& inputs); explicit operator bool() const { return pImpl != nullptr; } @@ -37,7 +38,8 @@ struct TORCH_API GraphExecutor { std::shared_ptr graphFor(const Stack& inputs) const; GraphExecutorState getDebugState(); void debugDisableAutodiffSubgraphInlining(); -private: + + private: std::shared_ptr pImpl; }; @@ -51,5 +53,5 @@ GraphExecutor* getGradExecutor(Operation& op); } // namespace detail - -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/graph_node_list.h b/torch/csrc/jit/graph_node_list.h index 3fe3dcd..c20cb72 100644 --- a/torch/csrc/jit/graph_node_list.h +++ b/torch/csrc/jit/graph_node_list.h @@ -2,7 +2,8 @@ #include -namespace torch { namespace jit { +namespace torch { +namespace jit { // Intrusive doubly linked lists with sane reverse iterators. 
// The header file is named generic_graph_node_list.h because it is ONLY @@ -39,21 +40,28 @@ struct Node; using graph_node_list = generic_graph_node_list; using const_graph_node_list = generic_graph_node_list; using graph_node_list_iterator = generic_graph_node_list_iterator; -using const_graph_node_list_iterator = generic_graph_node_list_iterator; +using const_graph_node_list_iterator = + generic_graph_node_list_iterator; template struct generic_graph_node_list_iterator { - generic_graph_node_list_iterator() - : cur(nullptr), d(kNextDirection) {} - generic_graph_node_list_iterator(T * cur, int d) - : cur(cur), d(d) {} - generic_graph_node_list_iterator(const generic_graph_node_list_iterator & rhs) = default; - generic_graph_node_list_iterator(generic_graph_node_list_iterator && rhs) = default; - generic_graph_node_list_iterator& operator=(const generic_graph_node_list_iterator & rhs) = default; - generic_graph_node_list_iterator& operator=(generic_graph_node_list_iterator && rhs) = default; - T * operator*() const { return cur; } - T * operator->() const { return cur; } - generic_graph_node_list_iterator & operator++() { + generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {} + generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {} + generic_graph_node_list_iterator( + const generic_graph_node_list_iterator& rhs) = default; + generic_graph_node_list_iterator(generic_graph_node_list_iterator&& rhs) = + default; + generic_graph_node_list_iterator& operator=( + const generic_graph_node_list_iterator& rhs) = default; + generic_graph_node_list_iterator& operator=( + generic_graph_node_list_iterator&& rhs) = default; + T* operator*() const { + return cur; + } + T* operator->() const { + return cur; + } + generic_graph_node_list_iterator& operator++() { JIT_ASSERT(cur); cur = cur->next_in_graph[d]; return *this; @@ -63,7 +71,7 @@ struct generic_graph_node_list_iterator { ++(*this); return old; } - generic_graph_node_list_iterator & operator--() { + generic_graph_node_list_iterator& operator--() { JIT_ASSERT(cur); cur = cur->next_in_graph[reverseDir()]; return *this; @@ -79,19 +87,20 @@ struct generic_graph_node_list_iterator { // silently cause the wrong one to be called. // iterator will point to the previous entry after call void destroyCurrent() { - T * n = cur; + T* n = cur; cur = cur->next_in_graph[reverseDir()]; n->destroy(); } generic_graph_node_list_iterator reverse() { return generic_graph_node_list_iterator(cur, reverseDir()); } -private: + + private: int reverseDir() { return d == kNextDirection ? kPrevDirection : kNextDirection; } - T * cur; - int d; //direction 0 is forward 1 is reverse, see next_in_graph + T* cur; + int d; // direction 0 is forward 1 is reverse, see next_in_graph }; template @@ -105,10 +114,10 @@ struct generic_graph_node_list { return generic_graph_node_list_iterator(head->next_in_graph[d], d); } generic_graph_node_list_iterator end() { - return generic_graph_node_list_iterator(head,d); + return generic_graph_node_list_iterator(head, d); } generic_graph_node_list_iterator end() const { - return generic_graph_node_list_iterator(head,d); + return generic_graph_node_list_iterator(head, d); } generic_graph_node_list_iterator rbegin() { return reverse().begin(); @@ -123,37 +132,52 @@ struct generic_graph_node_list { return reverse().end(); } generic_graph_node_list reverse() { - return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection); + return generic_graph_node_list( + head, d == kNextDirection ? 
kPrevDirection : kNextDirection); } const generic_graph_node_list reverse() const { - return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection); - } - T* front() { return head->next_in_graph[d]; } - const T* front() const { return head->next_in_graph[d]; } - T* back() { return head->next_in_graph[!d]; } - const T* back() const { return head->next_in_graph[!d]; } - generic_graph_node_list(T * head, int d) - : head(head), d(d) {} -private: - T * head; + return generic_graph_node_list( + head, d == kNextDirection ? kPrevDirection : kNextDirection); + } + T* front() { + return head->next_in_graph[d]; + } + const T* front() const { + return head->next_in_graph[d]; + } + T* back() { + return head->next_in_graph[!d]; + } + const T* back() const { + return head->next_in_graph[!d]; + } + generic_graph_node_list(T* head, int d) : head(head), d(d) {} + + private: + T* head; int d; }; template -static inline bool operator==(generic_graph_node_list_iterator a, generic_graph_node_list_iterator b) { +static inline bool operator==( + generic_graph_node_list_iterator a, + generic_graph_node_list_iterator b) { return *a == *b; } template -static inline bool operator!=(generic_graph_node_list_iterator a, generic_graph_node_list_iterator b) { +static inline bool operator!=( + generic_graph_node_list_iterator a, + generic_graph_node_list_iterator b) { return *a != *b; } -}} +} // namespace jit +} // namespace torch namespace std { -template +template struct iterator_traits> { using difference_type = int64_t; using value_type = T*; @@ -162,4 +186,4 @@ struct iterator_traits> { using iterator_category = bidirectional_iterator_tag; }; -} +} // namespace std diff --git a/torch/csrc/jit/hooks_for_testing.cpp b/torch/csrc/jit/hooks_for_testing.cpp index b80f0ab..7ee0195 100644 --- a/torch/csrc/jit/hooks_for_testing.cpp +++ b/torch/csrc/jit/hooks_for_testing.cpp @@ -4,13 +4,15 @@ namespace torch { namespace jit { -static std::function module)> emit_module_callback; +static std::function module)> + emit_module_callback; TORCH_API void didFinishEmitModule(std::shared_ptr module) { - if(emit_module_callback) { + if (emit_module_callback) { emit_module_callback(std::move(module)); } } -TORCH_API void setEmitModuleHook(std::function module)> cb) { +TORCH_API void setEmitModuleHook( + std::function module)> cb) { emit_module_callback = std::move(cb); } } // namespace jit diff --git a/torch/csrc/jit/hooks_for_testing.h b/torch/csrc/jit/hooks_for_testing.h index f160672..46b3398 100644 --- a/torch/csrc/jit/hooks_for_testing.h +++ b/torch/csrc/jit/hooks_for_testing.h @@ -1,6 +1,6 @@ #pragma once -#include #include +#include #include namespace torch { @@ -9,6 +9,7 @@ namespace script { struct Module; } TORCH_API void didFinishEmitModule(std::shared_ptr module); -TORCH_API void setEmitModuleHook(std::function module)> cb); +TORCH_API void setEmitModuleHook( + std::function module)> cb); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index bba1d1b..56c5859 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -1,13 +1,12 @@ #include #include +#include #include +#include #include -#include -#include #include -#include - +#include #include #include @@ -16,12 +15,13 @@ #include +#include +#include #include #include -#include -#include -namespace torch { namespace jit { +namespace torch { +namespace jit { namespace { @@ -37,26 +37,27 @@ class ScriptModuleDeserializer final { ScriptModuleDeserializer(std::istream* is); - 
void deserialize(ModuleLookup module_lookup, + void deserialize( + ModuleLookup module_lookup, c10::optional device); -private: - at::Tensor loadTensor( - const torch::TensorDef& tensor_proto, - std::unordered_map& storageMap); + private: + at::Tensor loadTensor( + const torch::TensorDef& tensor_proto, + std::unordered_map& storageMap); - void convertModule(const torch::ModuleDef& module_def); + void convertModule(const torch::ModuleDef& module_def); - void loadTensorTable(torch::ModelDef* model_def); + void loadTensorTable(torch::ModelDef* model_def); - PyTorchStreamReader reader_; - // this is a hack to make sure the script module created in C++ is the - // same as created in Python - ModuleLookup moduleLookup_; - c10::optional device_; - std::vector moduleStack_; + PyTorchStreamReader reader_; + // this is a hack to make sure the script module created in C++ is the + // same as created in Python + ModuleLookup moduleLookup_; + c10::optional device_; + std::vector moduleStack_; - std::vector tensor_table_; + std::vector tensor_table_; }; ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename) @@ -67,7 +68,8 @@ ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename) ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is) : reader_(is) {} -void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup, +void ScriptModuleDeserializer::deserialize( + ModuleLookup module_lookup, c10::optional device) { torch::ModelDef model_def; at::DataPtr data_ptr; @@ -108,15 +110,18 @@ void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup, void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) { std::unordered_map storageMap; - for(const torch::TensorDef& tensor : model_def->tensors()) { + for (const torch::TensorDef& tensor : model_def->tensors()) { tensor_table_.emplace_back(loadTensor(tensor, storageMap)); } } -at::Tensor ScriptModuleDeserializer::loadTensor(const torch::TensorDef& tensor_proto, - std::unordered_map& storageMap) { - std::vector dims(tensor_proto.dims().begin(), tensor_proto.dims().end()); - std::vector strides(tensor_proto.strides().begin(), tensor_proto.strides().end()); +at::Tensor ScriptModuleDeserializer::loadTensor( + const torch::TensorDef& tensor_proto, + std::unordered_map& storageMap) { + std::vector dims( + tensor_proto.dims().begin(), tensor_proto.dims().end()); + std::vector strides( + tensor_proto.strides().begin(), tensor_proto.strides().end()); auto type = at::typeMetaToScalarType( caffe2::DataTypeToTypeMeta(tensor_proto.data_type())); const std::string& record_key = tensor_proto.data().key(); @@ -138,17 +143,19 @@ at::Tensor ScriptModuleDeserializer::loadTensor(const torch::TensorDef& tensor_p record_size / at::CPU(type).typeMeta().itemsize(), nullptr); // NB: we didn't set any allocator for the tensor if (device.type() == at::DeviceType::CPU) { - storage_it = storageMap.insert(std::make_pair( - record_key, cpu_storage)).first; + storage_it = + storageMap.insert(std::make_pair(record_key, cpu_storage)).first; } else if (device.type() == at::DeviceType::CUDA) { - at::Tensor cpu_tensor = at::empty({0}, at::CPU(type).options()).set_( - cpu_storage, tensor_proto.offset(), dims, strides); - at::Storage cuda_storage = cpu_tensor.to(device, - cpu_tensor.scalar_type()).storage(); - storage_it = storageMap.insert(std::make_pair( - record_key, cuda_storage)).first; + at::Tensor cpu_tensor = + at::empty({0}, at::CPU(type).options()) + .set_(cpu_storage, tensor_proto.offset(), 
dims, strides); + at::Storage cuda_storage = + cpu_tensor.to(device, cpu_tensor.scalar_type()).storage(); + storage_it = + storageMap.insert(std::make_pair(record_key, cuda_storage)).first; } else { - AT_ERROR("supported devices include CPU and CUDA, however got ", + AT_ERROR( + "supported devices include CPU and CUDA, however got ", at::DeviceTypeName(device.type(), false)); } } @@ -157,19 +164,20 @@ at::Tensor ScriptModuleDeserializer::loadTensor(const torch::TensorDef& tensor_p storage_it->second.device().index() != device.index())) { std::stringstream oss; oss << "storage previously was specified with device " - << storage_it->second.device() - << "but now is specified with device " - << device << std::endl; + << storage_it->second.device() << "but now is specified with device " + << device << std::endl; AT_ERROR(oss.str()); } at::Tensor result; if (device.type() == at::DeviceType::CPU) { - result = at::empty({0}, at::CPU(type).options()).set_( - storage_it->second, tensor_proto.offset(), dims, strides); + result = + at::empty({0}, at::CPU(type).options()) + .set_(storage_it->second, tensor_proto.offset(), dims, strides); } else if (device.type() == at::DeviceType::CUDA) { - result = at::empty({0}, at::CUDA(type).options()).set_( - storage_it->second, tensor_proto.offset(), dims, strides); + result = + at::empty({0}, at::CUDA(type).options()) + .set_(storage_it->second, tensor_proto.offset(), dims, strides); } AT_ASSERT(result.defined()); @@ -191,19 +199,19 @@ void ScriptModuleDeserializer::convertModule( for (int i = 0; i < module_def.parameters_size(); ++i) { const torch::ParameterDef& param_def = module_def.parameters(i); at::Tensor tensor = tensor_table_.at(param_def.tensor_id()); - module->register_parameter( - param_def.name(), tensor, param_def.is_buffer()); + module->register_parameter(param_def.name(), tensor, param_def.is_buffer()); } if (module_def.has_torchscript_arena()) { at::DataPtr data; size_t size; - std::tie(data, size) = reader_.getRecord(module_def.torchscript_arena().key()); + std::tie(data, size) = + reader_.getRecord(module_def.torchscript_arena().key()); std::string data_str(static_cast(data.get()), size); import_methods(module, data_str, tensor_table_); } } -} // namespace +} // namespace void import_ir_module( ModuleLookup module_lookup, @@ -221,7 +229,8 @@ void import_ir_module( deserializer.deserialize(module_lookup, device); } -std::shared_ptr load(std::istream& in, +std::shared_ptr load( + std::istream& in, c10::optional device) { auto module = std::make_shared(); @@ -242,15 +251,17 @@ std::shared_ptr load(std::istream& in, return module; } -std::shared_ptr load(const std::string& filename, +std::shared_ptr load( + const std::string& filename, c10::optional device) { std::ifstream in(filename, std::ios_base::binary); - AT_CHECK(! in.fail(), "load: could not open file ", filename); + AT_CHECK(!in.fail(), "load: could not open file ", filename); auto module = load(in, device); return module; } -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/import.h b/torch/csrc/jit/import.h index 5560a7d..2252ba4 100644 --- a/torch/csrc/jit/import.h +++ b/torch/csrc/jit/import.h @@ -25,7 +25,8 @@ TORCH_API void import_ir_module( /// /// The istream must contain a serialized `script::Module`, exported via /// `torch::jit::ExportModule` in C++. -TORCH_API std::shared_ptr load(std::istream& in, +TORCH_API std::shared_ptr load( + std::istream& in, c10::optional device = c10::nullopt); /// Loads a serialized `script::Module` from the given `filename`. 
@@ -33,7 +34,8 @@ TORCH_API std::shared_ptr load(std::istream& in, /// The file stored at the location given in `filename` must contain a /// serialized `script::Module`, exported either via `ScriptModule.save()` in /// Python or `torch::jit::ExportModule` in C++. -TORCH_API std::shared_ptr load(const std::string& filename, +TORCH_API std::shared_ptr load( + const std::string& filename, c10::optional device = c10::nullopt); } // namespace jit diff --git a/torch/csrc/jit/import_method.cpp b/torch/csrc/jit/import_method.cpp index 6caff06..a51e32b 100644 --- a/torch/csrc/jit/import_method.cpp +++ b/torch/csrc/jit/import_method.cpp @@ -1,51 +1,59 @@ #include #include -namespace torch { namespace jit { - +namespace torch { +namespace jit { // this is a much simpler accessor that only handles modules, parameters, // and methods. It does not depend on python to work. struct ModuleAccessorValue : public script::SugaredValue { ModuleAccessorValue(std::shared_ptr module) - : module(std::move(module)) {} + : module(std::move(module)) {} std::string kind() const override { return "module"; } // select an attribute on it, e.g. `this.field` - std::shared_ptr attr(const SourceRange& loc, script::Method & m, const std::string& field) override { - if(script::NamedModule* v = module->find_module(field)) { + std::shared_ptr attr( + const SourceRange& loc, + script::Method& m, + const std::string& field) override { + if (script::NamedModule* v = module->find_module(field)) { return std::make_shared(v->module); - } else if(script::NamedParameter* v = module->find_parameter(field)) { - return std::make_shared(m.get_or_add_parameter(v->slot())); - } else if(script::Method* m = module->find_method(field)) { + } else if (script::NamedParameter* v = module->find_parameter(field)) { + return std::make_shared( + m.get_or_add_parameter(v->slot())); + } else if (script::Method* m = module->find_method(field)) { return std::make_shared(module, *m); } else { throw script::ErrorReport(loc) << "unknown attr: " << field; } } -private: + + private: std::shared_ptr module; }; struct OpsValue : public script::SugaredValue { - OpsValue(size_t version) - : version_(version) {} + OpsValue(size_t version) : version_(version) {} std::string kind() const override { return "ops"; } - std::shared_ptr attr(const SourceRange& loc, script::Method & m, const std::string& field) override { + std::shared_ptr attr( + const SourceRange& loc, + script::Method& m, + const std::string& field) override { return std::make_shared(field, version_); } size_t version_; }; struct ConstantValue : public script::SugaredValue { - ConstantValue(IValue value) - : value_(std::move(value)) {} + ConstantValue(IValue value) : value_(std::move(value)) {} IValue value_; - std::string kind() const override { return "constant"; } - Value * asValue(const SourceRange& loc, script::Method & m) override { + std::string kind() const override { + return "constant"; + } + Value* asValue(const SourceRange& loc, script::Method& m) override { return m.graph()->insertConstant(value_); } }; @@ -54,17 +62,19 @@ struct ConstantValue : public script::SugaredValue { // in the 'constants' vector. This table will be stored in a container format // and given to the import_method when restoring the code.
struct ConstantTableValue : public script::SugaredValue { - ConstantTableValue(ArrayRef constants) - : constants_(constants) {} + ConstantTableValue(ArrayRef constants) : constants_(constants) {} std::string kind() const override { return "CONSTANTS"; } // select an attribute on it, e.g. `this.field` - std::shared_ptr attr(const SourceRange& loc, script::Method & m, const std::string& field) override { + std::shared_ptr attr( + const SourceRange& loc, + script::Method& m, + const std::string& field) override { const char* field_s = field.c_str(); char* end; int64_t offset = std::strtoll(field_s + 1, &end, 10); - if(field.size() < 2 || *end != 0) + if (field.size() < 2 || *end != 0) throw script::ErrorReport(loc) << "invalid constant specifier: " << field; if (offset < 0 || size_t(offset) >= constants_.size()) { throw script::ErrorReport(loc) << "constant index " << offset @@ -76,7 +86,7 @@ struct ConstantTableValue : public script::SugaredValue { } private: - ArrayRef constants_; + ArrayRef constants_; }; static size_t parseVersionNumber(script::Lexer& L) { @@ -87,29 +97,40 @@ static size_t parseVersionNumber(script::Lexer& L) { L.expect(script::TK_NEWLINE); auto version = script::Const::create(L.cur().range, version_text); if (name != "op_version_set") - throw script::ErrorReport(range) << "expected an assignment to op_version_set"; + throw script::ErrorReport(range) + << "expected an assignment to op_version_set"; if (!version.isIntegral()) - throw script::ErrorReport(range) << "expected an integral version but found " << version.text(); - return size_t(version.asIntegral()); + throw script::ErrorReport(range) + << "expected an integral version but found " << version.text(); + return size_t(version.asIntegral()); } -void import_methods(const std::shared_ptr& mod, const std::string& src, const std::vector& constant_table) { +void import_methods( + const std::shared_ptr& mod, + const std::string& src, + const std::vector& constant_table) { script::Parser p(src); size_t version = parseVersionNumber(p.lexer()); std::unordered_map> env = { - {"torch", std::make_shared("aten", version)}, - {"ops", std::make_shared(version)}, - {"CONSTANTS", std::make_shared(constant_table)}, - {"fork", std::make_shared()}, - {"annotate", std::make_shared()}, - {"inf", std::make_shared(std::numeric_limits::infinity())}, - {"nan", std::make_shared(std::numeric_limits::quiet_NaN())}, + {"torch", std::make_shared("aten", version)}, + {"ops", std::make_shared(version)}, + {"CONSTANTS", std::make_shared(constant_table)}, + {"fork", std::make_shared()}, + {"annotate", std::make_shared()}, + {"inf", + std::make_shared( + std::numeric_limits::infinity())}, + {"nan", + std::make_shared( + std::numeric_limits::quiet_NaN())}, }; - auto resolver = [&](const std::string& name, script::Method& m, const SourceRange& loc) - -> std::shared_ptr { + auto resolver = + [&](const std::string& name, + script::Method& m, + const SourceRange& loc) -> std::shared_ptr { auto it = env.find(name); if (it == env.end()) return nullptr; @@ -128,4 +149,5 @@ void import_methods(const std::shared_ptr& mod, const std::strin script::defineMethodsInModule(mod, definitions, resolvers, self); } -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/import_method.h b/torch/csrc/jit/import_method.h index 9ceeb74..c8575d6 100644 --- a/torch/csrc/jit/import_method.h +++ b/torch/csrc/jit/import_method.h @@ -1,13 +1,16 @@ #pragma once #include -#include #include +#include namespace torch { namespace jit { -TORCH_API void 
import_methods(const std::shared_ptr& mod, const std::string& src, const std::vector& constant_table); +TORCH_API void import_methods( + const std::shared_ptr& mod, + const std::string& src, + const std::vector& constant_table); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index f0e35f6..38dbe69 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -1,46 +1,45 @@ -#include #include +#include -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include #include -#include -#include -#include #include -#include +#include +#include #include +#include +#include +#include +#include #include -#include #include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include #include -#include -#include -#include -#include #include -#include -#include -#include -#include +#include +#include +#include +#include #include +#include +#include #include @@ -53,14 +52,13 @@ #include #include -namespace torch { namespace jit { +namespace torch { +namespace jit { // TODO: make a fake future for python namespace detail { -class Future { - -}; -} +class Future {}; +} // namespace detail namespace { @@ -68,10 +66,10 @@ using autograd::variable_list; bool loadPythonClasses() { // Leaving this code here, because it will likely be useful at some point - //PyObject *jit_module = PyImport_ImportModule("torch.jit"); - //THPUtils_assert(jit_module, "class loader couldn't access " - //"torch.jit module"); - //PyObject *jit_dict = PyModule_GetDict(jit_module); + // PyObject *jit_module = PyImport_ImportModule("torch.jit"); + // THPUtils_assert(jit_module, "class loader couldn't access " + //"torch.jit module"); + // PyObject *jit_dict = PyModule_GetDict(jit_module); return true; } @@ -86,96 +84,127 @@ std::string runJITCPPTests() { std::string runJITCPPTests(); #endif -void initJITBindings(PyObject *module) { +void initJITBindings(PyObject* module) { auto m = py::handle(module).cast(); py::register_exception(m, "JITException"); - py::class_(m, "IODescriptor"); // NOLINT(bugprone-unused-raii) + py::class_( + m, "IODescriptor"); // NOLINT(bugprone-unused-raii) m.def("_jit_init", loadPythonClasses) #if USE_CUDA_FUSER || USE_CPU_FUSER - .def("_jit_debug_fuser_num_cached_kernel_specs", - torch::jit::fuser::debugNumCachedKernelSpecs) + .def( + "_jit_debug_fuser_num_cached_kernel_specs", + torch::jit::fuser::debugNumCachedKernelSpecs) #endif - .def("_jit_pass_onnx", ToONNX) - .def("_jit_pass_lower_all_tuples", LowerAllTuples) - .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX) - .def("_jit_pass_fuse", FuseGraph) - .def("_jit_pass_dce", [](std::shared_ptr& g) { - return EliminateDeadCode(g->block()); // overload resolution - }) - .def("_jit_pass_cse", [](std::shared_ptr& g) { - return EliminateCommonSubexpression(g); // overload resolution - }) - .def("_jit_pass_remove_inplace_ops", [](std::shared_ptr g) { - return RemoveInplaceOps(g); - }) - .def("_jit_pass_constant_pooling", ConstantPooling) - .def("_jit_pass_peephole", [](const std::shared_ptr& g, bool addmm_fusion_enabled) { - return PeepholeOptimize(g, addmm_fusion_enabled); - }, py::arg("graph"), py::arg("addmm_fusion_enabled") = false) - .def("_jit_pass_canonicalize", [](const std::shared_ptr& g) { - return Canonicalize(g); - }) - .def("_jit_pass_lint", LintGraph) - 
.def("_jit_pass_shape_analysis", [](std::shared_ptr graph, std::vector inputs, bool with_grad) { - setInputTypes(*graph, ArgumentSpec(with_grad, fmap(inputs), inputs.size())); - PropagateInputShapes(graph); - }) - .def("_jit_pass_complete_shape_analysis", [](std::shared_ptr graph, py::tuple inputs, bool with_grad) { - CompleteArgumentSpec spec(with_grad, evilDeprecatedBadCreateStackDoNotUse(inputs, graph->inputs())); - auto graph_inputs = graph->inputs(); - JIT_ASSERT(spec.size() == graph_inputs.size()); - for (size_t i = 0; i < graph_inputs.size(); ++i) { - graph_inputs[i]->setType(spec.at(i)); - } - PropagateInputShapes(graph); - }) - .def("_jit_pass_remove_expands", RemoveExpands) - .def("_jit_pass_erase_number_types", EraseNumberTypes) - .def("_jit_pass_prepare_division_for_onnx", PrepareDivisionForONNX) - .def("_jit_pass_loop_unrolling", UnrollLoops) - .def("_jit_pass_constant_propagation", [](std::shared_ptr& g) { - return ConstantPropagation(g); - }) - .def("_jit_pass_erase_shape_information", EraseShapeInformation) - .def("_jit_pass_create_autodiff_subgraphs", [](std::shared_ptr graph) { - CreateAutodiffSubgraphs(graph); - }) - .def("_jit_run_cpp_tests", [] { - // We have to release the GIL inside this method, because if we happen to - // initialize the autograd engine in these tests, the newly spawned worker threads will - // try to initialize their PyThreadState*, and they need the GIL for this. - AutoNoGIL _no_gil; - return runJITCPPTests(); - }) - .def("_jit_flatten", [](py::handle& obj) { - auto res = python::flatten(obj); - return std::make_pair(res.vars, res.desc); - }) - .def("_jit_unflatten", [](autograd::variable_list vars, python::IODescriptor& desc) { - return py::reinterpret_steal(python::unflatten(vars, desc)); - }) - .def("_jit_pass_onnx_block", BlockToONNX) - .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops) - .def("_jit_pass_canonicalize_ops", CanonicalizeOps) - .def("_jit_pass_specialize_undef", specializeUndef) - .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU) - .def("_jit_differentiate", [](Graph &g) { - // the python binding slightly differs in semantics - // it makes a copy of the input Graph, and works on that - // jit::differentiate mutates the input Graph - auto g_clone = g.copy(); - return differentiate(g_clone); - }) - .def("_jit_check_alias_annotation", []( - std::shared_ptr g, - py::tuple args, - const std::string& unqualified_op_name) { - auto stack = toStack(args); - checkAliasAnnotation(g, std::move(stack), unqualified_op_name); - }); + .def("_jit_pass_onnx", ToONNX) + .def("_jit_pass_lower_all_tuples", LowerAllTuples) + .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX) + .def("_jit_pass_fuse", FuseGraph) + .def( + "_jit_pass_dce", + [](std::shared_ptr& g) { + return EliminateDeadCode(g->block()); // overload resolution + }) + .def( + "_jit_pass_cse", + [](std::shared_ptr& g) { + return EliminateCommonSubexpression(g); // overload resolution + }) + .def( + "_jit_pass_remove_inplace_ops", + [](std::shared_ptr g) { return RemoveInplaceOps(g); }) + .def("_jit_pass_constant_pooling", ConstantPooling) + .def( + "_jit_pass_peephole", + [](const std::shared_ptr& g, bool addmm_fusion_enabled) { + return PeepholeOptimize(g, addmm_fusion_enabled); + }, + py::arg("graph"), + py::arg("addmm_fusion_enabled") = false) + .def( + "_jit_pass_canonicalize", + [](const std::shared_ptr& g) { return Canonicalize(g); }) + .def("_jit_pass_lint", LintGraph) + .def( + "_jit_pass_shape_analysis", + [](std::shared_ptr graph, + std::vector inputs, + bool 
with_grad) { + setInputTypes( + *graph, + ArgumentSpec(with_grad, fmap(inputs), inputs.size())); + PropagateInputShapes(graph); + }) + .def( + "_jit_pass_complete_shape_analysis", + [](std::shared_ptr graph, py::tuple inputs, bool with_grad) { + CompleteArgumentSpec spec( + with_grad, + evilDeprecatedBadCreateStackDoNotUse(inputs, graph->inputs())); + auto graph_inputs = graph->inputs(); + JIT_ASSERT(spec.size() == graph_inputs.size()); + for (size_t i = 0; i < graph_inputs.size(); ++i) { + graph_inputs[i]->setType(spec.at(i)); + } + PropagateInputShapes(graph); + }) + .def("_jit_pass_remove_expands", RemoveExpands) + .def("_jit_pass_erase_number_types", EraseNumberTypes) + .def("_jit_pass_prepare_division_for_onnx", PrepareDivisionForONNX) + .def("_jit_pass_loop_unrolling", UnrollLoops) + .def( + "_jit_pass_constant_propagation", + [](std::shared_ptr& g) { return ConstantPropagation(g); }) + .def("_jit_pass_erase_shape_information", EraseShapeInformation) + .def( + "_jit_pass_create_autodiff_subgraphs", + [](std::shared_ptr graph) { CreateAutodiffSubgraphs(graph); }) + .def( + "_jit_run_cpp_tests", + [] { + // We have to release the GIL inside this method, because if we + // happen to initialize the autograd engine in these tests, the + // newly spawned worker threads will try to initialize their + // PyThreadState*, and they need the GIL for this. + AutoNoGIL _no_gil; + return runJITCPPTests(); + }) + .def( + "_jit_flatten", + [](py::handle& obj) { + auto res = python::flatten(obj); + return std::make_pair(res.vars, res.desc); + }) + .def( + "_jit_unflatten", + [](autograd::variable_list vars, python::IODescriptor& desc) { + return py::reinterpret_steal( + python::unflatten(vars, desc)); + }) + .def("_jit_pass_onnx_block", BlockToONNX) + .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops) + .def("_jit_pass_canonicalize_ops", CanonicalizeOps) + .def("_jit_pass_specialize_undef", specializeUndef) + .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU) + .def( + "_jit_differentiate", + [](Graph& g) { + // the python binding slightly differs in semantics + // it makes a copy of the input Graph, and works on that + // jit::differentiate mutates the input Graph + auto g_clone = g.copy(); + return differentiate(g_clone); + }) + .def( + "_jit_check_alias_annotation", + [](std::shared_ptr g, + py::tuple args, + const std::string& unqualified_op_name) { + auto stack = toStack(args); + checkAliasAnnotation(g, std::move(stack), unqualified_op_name); + }); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_(m, "CompleteArgumentSpec") @@ -186,52 +215,41 @@ void initJITBindings(PyObject *module) { }); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_(m, "ArgumentSpec"); - py::class_(m, "Code") - .def("grad_executors", [](Code& c) { - return py::make_iterator(c.grad_executors().begin(), c.grad_executors().end()); - }); + py::class_(m, "Code").def("grad_executors", [](Code& c) { + return py::make_iterator( + c.grad_executors().begin(), c.grad_executors().end()); + }); py::class_(m, "ExecutionPlanState") - .def_property_readonly("graph", [](ExecutionPlanState& s) { - return s.graph; - }) - .def_property_readonly("code", [](ExecutionPlanState& s) { - return s.code; - }); + .def_property_readonly( + "graph", [](ExecutionPlanState& s) { return s.graph; }) + .def_property_readonly( + "code", [](ExecutionPlanState& s) { return s.code; }); py::class_(m, "Gradient") - .def_property_readonly("f", [](Gradient& m) { - return m.f; - }) - .def_property_readonly("df", [](Gradient& m) { - return m.df; - }) - 
.def_property_readonly("f_real_outputs", [](Gradient& m) { - return m.f_real_outputs; - }) - .def_property_readonly("df_input_vjps", [](Gradient& m) { - return m.df_input_vjps; - }) - .def_property_readonly("df_input_captured_inputs", [](Gradient& m) { - return m.df_input_captured_inputs; - }) - .def_property_readonly("df_input_captured_outputs", [](Gradient& m) { - return m.df_input_captured_outputs; - }) - .def_property_readonly("df_output_vjps", [](Gradient& m) { - return m.df_output_vjps; - }); + .def_property_readonly("f", [](Gradient& m) { return m.f; }) + .def_property_readonly("df", [](Gradient& m) { return m.df; }) + .def_property_readonly( + "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; }) + .def_property_readonly( + "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; }) + .def_property_readonly( + "df_input_captured_inputs", + [](Gradient& m) { return m.df_input_captured_inputs; }) + .def_property_readonly( + "df_input_captured_outputs", + [](Gradient& m) { return m.df_input_captured_outputs; }) + .def_property_readonly( + "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; }); py::class_(m, "GraphExecutorState") - .def_property_readonly("graph", [](GraphExecutorState& s) { - return s.graph; - }) - .def_property_readonly("execution_plans", [](GraphExecutorState& s) { - return s.execution_plans; - }) - .def_property_readonly("fallback", [](GraphExecutorState& s) { - return s.fallback; - }); + .def_property_readonly( + "graph", [](GraphExecutorState& s) { return s.graph; }) + .def_property_readonly( + "execution_plans", + [](GraphExecutorState& s) { return s.execution_plans; }) + .def_property_readonly( + "fallback", [](GraphExecutorState& s) { return s.fallback; }); py::class_(m, "GraphExecutor", py::dynamic_attr()) .def( @@ -267,8 +285,9 @@ void initJITBindings(PyObject *module) { "get_debug_state", [](GraphExecutor& ge) { return ge.getDebugState(); }) .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object { - const auto & graph = ge.graph(); - auto stack = evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs()); + const auto& graph = ge.graph(); + auto stack = + evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs()); { AutoNoGIL no_gil_guard; ge.run(stack); @@ -280,79 +299,91 @@ void initJITBindings(PyObject *module) { .def(py::init()) .def( "write_record", - [](PyTorchStreamWriter& self, const std::string& name, const char* data, size_t size) { - return self.writeRecord(name, data, size); - }) + [](PyTorchStreamWriter& self, + const std::string& name, + const char* data, + size_t size) { return self.writeRecord(name, data, size); }) .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile); py::class_(m, "PyTorchFileReader") .def(py::init()) - .def( - "get_record", - [](PyTorchStreamReader& self, const std::string& key) { - at::DataPtr data; - size_t size; - std::tie(data, size) = self.getRecord(key); - return py::bytes(reinterpret_cast(data.get()), size); - }); - + .def("get_record", [](PyTorchStreamReader& self, const std::string& key) { + at::DataPtr data; + size_t size; + std::tie(data, size) = self.getRecord(key); + return py::bytes(reinterpret_cast(data.get()), size); + }); - m.def("_jit_get_operation", [](const std::string& qualified_name) { - try { - auto symbol = Symbol::fromQualString(qualified_name); - auto operations = getAllOperatorsFor(symbol); - AT_CHECK(!operations.empty(), "No such operator ", qualified_name); - AT_CHECK( - operations.size() == 1, - "Found ", operations.size(), " overloads for operator 
", - qualified_name, "! Overloads are not supported from Python."); - std::shared_ptr op = operations[0]; - AT_ASSERT(op != nullptr); - std::ostringstream docstring; - docstring << "Automatically bound operator '" << qualified_name - << "' with schema: " << op->schema(); - return py::cpp_function([op](py::args args, py::kwargs kwargs) { - return invokeOperatorFromPython( - *op, std::move(args), std::move(kwargs)); - }, py::name(qualified_name.c_str()), py::doc(docstring.str().c_str())); - } catch (const c10::Error& error) { - throw std::runtime_error(error.what_without_backtrace()); - } - }, py::arg("qualified_name")); + m.def( + "_jit_get_operation", + [](const std::string& qualified_name) { + try { + auto symbol = Symbol::fromQualString(qualified_name); + auto operations = getAllOperatorsFor(symbol); + AT_CHECK(!operations.empty(), "No such operator ", qualified_name); + AT_CHECK( + operations.size() == 1, + "Found ", + operations.size(), + " overloads for operator ", + qualified_name, + "! Overloads are not supported from Python."); + std::shared_ptr op = operations[0]; + AT_ASSERT(op != nullptr); + std::ostringstream docstring; + docstring << "Automatically bound operator '" << qualified_name + << "' with schema: " << op->schema(); + return py::cpp_function( + [op](py::args args, py::kwargs kwargs) { + return invokeOperatorFromPython( + *op, std::move(args), std::move(kwargs)); + }, + py::name(qualified_name.c_str()), + py::doc(docstring.str().c_str())); + } catch (const c10::Error& error) { + throw std::runtime_error(error.what_without_backtrace()); + } + }, + py::arg("qualified_name")); py::class_(m, "FunctionSchema") - .def_property_readonly("name", [](FunctionSchema& self) { return self.name(); }) - .def_property_readonly("arguments", [](FunctionSchema& self) { return self.arguments(); }) - .def_property_readonly("returns", [](FunctionSchema& self) { return self.returns(); }); + .def_property_readonly( + "name", [](FunctionSchema& self) { return self.name(); }) + .def_property_readonly( + "arguments", [](FunctionSchema& self) { return self.arguments(); }) + .def_property_readonly( + "returns", [](FunctionSchema& self) { return self.returns(); }); py::class_(m, "Argument") - .def_property_readonly("name", [](Argument& self) { return self.name(); }) - .def_property_readonly("type", [](Argument& self) { return self.type(); }) - .def_property_readonly("N", [](Argument& self) -> py::object { - return (self.N()) ? py::cast(*self.N()) : py::none(); - }) - .def_property_readonly("default_value", [](Argument& self) -> py::object { - if(!self.default_value()) - return py::none(); - IValue v = *self.default_value(); - return toPyObject(std::move(v)); - }); + .def_property_readonly("name", [](Argument& self) { return self.name(); }) + .def_property_readonly("type", [](Argument& self) { return self.type(); }) + .def_property_readonly( + "N", + [](Argument& self) -> py::object { + return (self.N()) ? 
py::cast(*self.N()) : py::none(); + }) + .def_property_readonly("default_value", [](Argument& self) -> py::object { + if (!self.default_value()) + return py::none(); + IValue v = *self.default_value(); + return toPyObject(std::move(v)); + }); m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) { auto symbol = Symbol::fromQualString(qualified_name); auto operations = getAllOperatorsFor(symbol); return fmap(operations, [](const std::shared_ptr& op) { - return op->schema(); - }); + return op->schema(); + }); }); // NOLINTNEXTLINE(bugprone-unused-raii) py::class_(m, "Future"); - m.def("fork", [](script::Module &sm, py::args args) { + m.def("fork", [](script::Module& sm, py::args args) { // TODO: this is a fake stub return detail::Future(); }); - m.def("wait", [](detail::Future &fut) { + m.def("wait", [](detail::Future& fut) { // TODO: this is a fake stub }); @@ -368,4 +399,5 @@ void initJITBindings(PyObject *module) { initRegisterBatchOpsBindings(module); } -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/init.h b/torch/csrc/jit/init.h index fbc902e..99b21d4 100644 --- a/torch/csrc/jit/init.h +++ b/torch/csrc/jit/init.h @@ -1,7 +1,9 @@ #pragma once -namespace torch { namespace jit { +namespace torch { +namespace jit { -void initJITBindings(PyObject *module); +void initJITBindings(PyObject* module); -}} +} +} // namespace torch diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 2c45c19..d7e9028 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -6,10 +6,10 @@ #include #include #include +#include #include #include #include -#include #include #include @@ -25,7 +25,8 @@ #include #include -namespace torch { namespace jit { +namespace torch { +namespace jit { // Before we translate to intepreter instructions, we do // some preprocessing of the graph to turn it into a form that is closer @@ -36,11 +37,11 @@ namespace torch { namespace jit { // *. computes move_flags (see Outputs), and inserts // * Drop nodes are inserted for any node that is unused to create a dummy use // that will cause the interpreter to free the node. -// A drop node is just a node with no outputs that just pops its inputs off the stack, -// to ensure the interpreter release references to nodes that are never used. -// Drop nodes are also inserted when the last use of a node is in some conditionally -// run control flow (e.g. one side of an If) and the interpreter must free -// the node only after the control flow has reconverged +// A drop node is just a node with no outputs that just pops its inputs off +// the stack, to ensure the interpreter release references to nodes that are +// never used. Drop nodes are also inserted when the last use of a node is in +// some conditionally run control flow (e.g. 
one side of an If) and the +// interpreter must free the node only after the control flow has reconverged // Outputs are: // * graph - the post processed copy of g // * move_flags[n] - a list of booleans, one for each input, @@ -58,24 +59,25 @@ Value* createTripCountConjunctiveCondition( // Emit initial comparison -- initial_trip_count < max_trip_count Value* initial_comparison_value = g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1)) - ->output()->setType(BoolType::get()); + ->output() + ->setType(BoolType::get()); // Replace initial condition with logical `and` of trip count and // initial condition Value* new_cond = g->insertNode( g->create(aten::__and__, {initial_comparison_value, cond}, 1)) - ->output()->setType(BoolType::get()); + ->output() + ->setType(BoolType::get()); return new_cond; } // this currently just _removes_ the trip count inputs and checks they are // unused. In the future they will be desugared into normal arithmetic to // provide a loop counter -void desugarTripCounts(Block * b) { - for(auto n : b->nodes()) { - - if(n->kind() == prim::Loop) { +void desugarTripCounts(Block* b) { + for (auto n : b->nodes()) { + if (n->kind() == prim::Loop) { auto g = n->owningGraph(); auto body_block = n->blocks()[0]; @@ -109,9 +111,10 @@ void desugarTripCounts(Block * b) { Value* const_one = g->insertConstant(1); Value* inc_trip_count = - g->insertNode(g->create( - aten::add, {block_trip_count_input, const_one}, 1)) - ->output()->setType(IntType::get()); + g->insertNode( + g->create(aten::add, {block_trip_count_input, const_one}, 1)) + ->output() + ->setType(IntType::get()); body_block->insertOutput(1, inc_trip_count); Value* body_cond = createTripCountConjunctiveCondition( @@ -120,16 +123,17 @@ void desugarTripCounts(Block * b) { body_block->insertOutput(0, body_cond); } } - for(auto sb : n->blocks()) { + for (auto sb : n->blocks()) { desugarTripCounts(sb); } } } -// removes all inputs and outputs to a graph, replacing them with Load Store nodes -static void flattenIO(Graph & graph) { +// removes all inputs and outputs to a graph, replacing them with Load Store +// nodes +static void flattenIO(Graph& graph) { auto load = graph.prependNode(graph.create(prim::Load, 0)); - for(auto old_input : graph.inputs()) { + for (auto old_input : graph.inputs()) { auto nv = load->addOutput(); nv->setType(old_input->type()); old_input->replaceAllUsesWith(nv); @@ -142,42 +146,40 @@ static void flattenIO(Graph & graph) { graph.eraseOutput(graph.outputs().size() - 1); } - // insert Drop nodes to kill references for anything unused: // this can happen in a few places, e.g. 
when a node returns // many values but only one is used // a, b = foo() // return a -void dropUnused(Block *b) { +void dropUnused(Block* b) { auto createDropIfUnused = [&](ArrayRef values) -> Node* { std::vector to_drop; - for(auto v : values) { - if(v->uses().size() == 0) + for (auto v : values) { + if (v->uses().size() == 0) to_drop.push_back(v); } - if(to_drop.size() == 0) + if (to_drop.size() == 0) return nullptr; return b->owningGraph()->create(prim::Drop, to_drop, 0); }; - if(auto d = createDropIfUnused(b->inputs())) { + if (auto d = createDropIfUnused(b->inputs())) { b->prependNode(d); } - for(auto n : b->nodes()) { - if(auto d = createDropIfUnused(n->outputs())) { + for (auto n : b->nodes()) { + if (auto d = createDropIfUnused(n->outputs())) { d->insertAfter(n); } - for(auto b : n->blocks()) + for (auto b : n->blocks()) dropUnused(b); } } - // for each input, should we move rather than copy the inputs -std::unordered_map> findLastUses(Graph & g) { +std::unordered_map> findLastUses(Graph& g) { // struct to share common data structures struct FindLastUses { - Graph & graph; + Graph& graph; // have we seen this value, yet, if not, it is the last use of the value std::unordered_set seen; @@ -187,40 +189,39 @@ std::unordered_map> findLastUses(Graph & g) { // when the If/Loop exits. These are created and inserted on demand. std::unordered_map drop_for_node; - FindLastUses(Graph & g) - : graph(g) { + FindLastUses(Graph& g) : graph(g) { scanBlock(graph.block()); } - void scanBlock(Block * b) { + void scanBlock(Block* b) { scanNode(b->return_node()); - for(auto n : b->nodes().reverse()) { + for (auto n : b->nodes().reverse()) { scanNode(n); } } - void scanNode(Node * n) { - for(auto b : n->blocks()) { + void scanNode(Node* n) { + for (auto b : n->blocks()) { scanBlock(b); } move_flags[n].resize(n->inputs().size()); - // scan backwards so if a value is used twice in the list then it is a move - for(size_t i = n->inputs().size(); i > 0; --i) { - scanUse(n, i-1); + // scan backwards so if a value is used twice in the list then it is a + // move + for (size_t i = n->inputs().size(); i > 0; --i) { + scanUse(n, i - 1); } } - void scanUse(Node * n, size_t i) { - auto & move_flags_n = move_flags[n]; + void scanUse(Node* n, size_t i) { + auto& move_flags_n = move_flags[n]; auto v = n->inputs()[i]; auto inserted = seen.insert(v).second; - if(!inserted) { + if (!inserted) { move_flags_n[i] = false; return; } // the last use of v may be in a nested block of an If or Loop statement - // find the node 'same_depth_node' at the same depth as the definition of v, - // and consider that node to be the last use of v. - // This ensures we do not delete nodes in nested scopes - // that may be executed multiple times + // find the node 'same_depth_node' at the same depth as the definition of + // v, and consider that node to be the last use of v. This ensures we do + // not delete nodes in nested scopes that may be executed multiple times // and that nodes used on one side of an if // but not the other get deleted regardless of the branch // e.g. @@ -230,12 +231,13 @@ std::unordered_map> findLastUses(Graph & g) { // drop(a) // In other words, we find the first program point for v that // _reverse_ dominates the definition of v, and add a drop point there. - Node * same_depth_node = findOwnerInBlock(n, v->node()->owningBlock()); - JIT_ASSERT(same_depth_node); // failure means v is not in scope for n, use lint! 
+ Node* same_depth_node = findOwnerInBlock(n, v->node()->owningBlock()); + JIT_ASSERT( + same_depth_node); // failure means v is not in scope for n, use lint! // In the case where v and n are in the same block, just mark // its move_flags to be true - if(same_depth_node == n) { + if (same_depth_node == n) { move_flags_n[i] = true; return; } @@ -243,7 +245,8 @@ std::unordered_map> findLastUses(Graph & g) { // in the case where the use is nested in a block // add a Drop node after that block which will drop 'v'. move_flags_n[i] = false; - addToDropIfNotExists(findOrCreateDropInstructionForNode(same_depth_node), v); + addToDropIfNotExists( + findOrCreateDropInstructionForNode(same_depth_node), v); } // finds the node in block 'block' that contains in 'n' @@ -252,16 +255,16 @@ std::unordered_map> findLastUses(Graph & g) { // n1: if : // n2: b = a + a // findOwnerInBlock(n2, n0.block()) == n1 - Node * findOwnerInBlock(Node * n, Block * block) { - while(n != nullptr && block != n->owningBlock()) { + Node* findOwnerInBlock(Node* n, Block* block) { + while (n != nullptr && block != n->owningBlock()) { n = n->owningBlock()->owningNode(); } return n; } - Node * findOrCreateDropInstructionForNode(Node * n) { + Node* findOrCreateDropInstructionForNode(Node* n) { auto it = drop_for_node.find(n); - if(it == drop_for_node.end()) { + if (it == drop_for_node.end()) { auto drop_node = graph.create(prim::Drop, 0); drop_node->insertAfter(n); it = drop_for_node.emplace(n, drop_node).first; @@ -269,10 +272,10 @@ std::unordered_map> findLastUses(Graph & g) { return it->second; } - void addToDropIfNotExists(Node * drop, Value * v) { - for(auto i : drop->inputs()) { + void addToDropIfNotExists(Node* drop, Value* v) { + for (auto i : drop->inputs()) { // we already accounted for this use - if(i == v) + if (i == v) return; } drop->addInput(v); @@ -282,19 +285,18 @@ std::unordered_map> findLastUses(Graph & g) { return FindLastUses(g).move_flags; } -} //namespace +} // namespace // pre-processing that happens once per graph struct PreprocessGraph { - PreprocessGraph(Graph & g) - : graph(g.copy()) { + PreprocessGraph(Graph& g) : graph(g.copy()) { n_outputs = graph->outputs().size(); desugarTripCounts(graph->block()); flattenIO(*graph); dropUnused(graph->block()); // fill in move_flags by scanning blocks; move_flags = findLastUses(*graph); - //TODO: desugar Loop trip counts, for now we drop trip counts + // TODO: desugar Loop trip counts, for now we drop trip counts } // Outputs of the preprocessing: std::shared_ptr graph; @@ -311,9 +313,13 @@ struct PreprocessGraph { // Note: this is currently unused but will probably be useful in the future, // so we keep it around struct ContainerTensor : public at::TensorImpl { -public: + public: ContainerTensor() - : TensorImpl(at::UndefinedTensorId(), caffe2::TypeMeta(), nullptr, /* is_variable */ false) {} + : TensorImpl( + at::UndefinedTensorId(), + caffe2::TypeMeta(), + nullptr, + /* is_variable */ false) {} ~ContainerTensor() override = default; at::IntList sizes() const override { @@ -335,7 +341,7 @@ public: // which are stored in the ListHandle struct // start is an offset into int_data of Code for ListHandle // and bool_data of Code for ListHandle -template +template struct ListHandle { int start; int size; @@ -358,24 +364,22 @@ struct Instruction { std::shared_ptr debug_location; // for error reporting }; - int relativeJump(int from_inst, int to_inst) { return to_inst - (from_inst + 1); } struct CodeImpl { - CodeImpl(const std::shared_ptr& graph_) - : preprocess(*graph_) { 
+ CodeImpl(const std::shared_ptr& graph_) : preprocess(*graph_) { graph = preprocess.graph; insertNodesFromBlock(graph->block()); } // jump when input is false void createJumpFalse(int from_inst, int to_inst) { - auto & inst = instructions[from_inst]; + auto& inst = instructions[from_inst]; JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); - inst.callback = [offset](Stack & stack) { + inst.callback = [offset](Stack& stack) { auto t = pop(stack).toBool(); return t ? 0 : offset; }; @@ -384,10 +388,10 @@ struct CodeImpl { // jump when input is true void createJumpTrue(int from_inst, int to_inst) { - auto & inst = instructions[from_inst]; + auto& inst = instructions[from_inst]; JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); - inst.callback = [offset](Stack & stack) { + inst.callback = [offset](Stack& stack) { auto t = pop(stack).toBool(); return t ? offset : 0; }; @@ -395,19 +399,17 @@ struct CodeImpl { } void createJump(int from_inst, int to_inst) { - auto & inst = instructions[from_inst]; + auto& inst = instructions[from_inst]; JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); - inst.callback = [=](Stack & stack) { - return offset; - }; + inst.callback = [=](Stack& stack) { return offset; }; inst.debug_name = prim::Jump; } void insertNodesFromBlock(Block* block) { - for(auto node : block->nodes()) { - const auto & source_location = node->getSourceLocation(); - switch(node->kind()) { + for (auto node : block->nodes()) { + const auto& source_location = node->getSourceLocation(); + switch (node->kind()) { case prim::If: { // x = if c: // @@ -426,17 +428,31 @@ struct CodeImpl { // x = vt // end: - // prim::Placeholder instructions are replaced with branch instructions - // when the branch target locations are known - auto cond_branch = insertInstruction(prim::Placeholder, source_location, node->inputs(), moveFlags(node), {}); + // prim::Placeholder instructions are replaced with branch + // instructions when the branch target locations are known + auto cond_branch = insertInstruction( + prim::Placeholder, + source_location, + node->inputs(), + moveFlags(node), + {}); auto then_block = node->blocks()[0]; auto else_block = node->blocks()[1]; insertNodesFromBlock(else_block); - insertAssign(source_location,else_block->outputs(), moveFlags(else_block), node->outputs()); - auto jump = insertInstruction(prim::Placeholder, source_location, {}, {}, {}); + insertAssign( + source_location, + else_block->outputs(), + moveFlags(else_block), + node->outputs()); + auto jump = + insertInstruction(prim::Placeholder, source_location, {}, {}, {}); auto then_block_start = instructions.size(); insertNodesFromBlock(then_block); - insertAssign(source_location, then_block->outputs(), moveFlags(then_block), node->outputs()); + insertAssign( + source_location, + then_block->outputs(), + moveFlags(then_block), + node->outputs()); createJump(jump, instructions.size()); createJumpTrue(cond_branch, then_block_start); } break; @@ -458,114 +474,140 @@ struct CodeImpl { auto body_block = node->blocks()[0]; // before assign op: stack: ... - insertAssign(source_location, node->inputs(), moveFlags(node), body_block->inputs()); + insertAssign( + source_location, + node->inputs(), + moveFlags(node), + body_block->inputs()); // after assign op: stack: ... 
// cond_branch consumes from top of the stack - auto cond_branch = insertInstruction(prim::Placeholder, source_location,{}, {}, {}); + auto cond_branch = + insertInstruction(prim::Placeholder, source_location, {}, {}, {}); // after branch: stack: ... auto entry = instructions.size(); insertNodesFromBlock(body_block); // before assign op: stack: ... - insertAssign(source_location, body_block->outputs(), moveFlags(body_block), body_block->inputs()); + insertAssign( + source_location, + body_block->outputs(), + moveFlags(body_block), + body_block->inputs()); // after assign op: stack: ... - auto cond_branch_end = insertInstruction(prim::Placeholder, source_location, {}, {}, {}); + auto cond_branch_end = + insertInstruction(prim::Placeholder, source_location, {}, {}, {}); // after branch: stack: ... aliasRegistersTo(node->outputs(), body_block->inputs()); createJumpFalse(cond_branch, instructions.size()); createJumpTrue(cond_branch_end, entry); } break; - default: { - insertInstruction(node); - } break; + default: { insertInstruction(node); } break; } } } - size_t insertInstruction(Node * n) { - auto inst = insertInstruction(n->kind(), n->getSourceLocation(), n->inputs(), moveFlags(n) , n->outputs()); + size_t insertInstruction(Node* n) { + auto inst = insertInstruction( + n->kind(), + n->getSourceLocation(), + n->inputs(), + moveFlags(n), + n->outputs()); instructions[inst].callback = getOperation(n); return inst; } - size_t insertInstruction(Symbol sym, - std::shared_ptr debug_location, - ArrayRef inputs, - ArrayRef move_flags, - ArrayRef outputs) { + size_t insertInstruction( + Symbol sym, + std::shared_ptr debug_location, + ArrayRef inputs, + ArrayRef move_flags, + ArrayRef outputs) { instructions.emplace_back(); - auto & inst = instructions.back(); + auto& inst = instructions.back(); inst.debug_name = sym; inst.debug_location = std::move(debug_location); listBegin(inst.inputs.values); - for(auto input : inputs) { + for (auto input : inputs) { listInsert(inst.inputs.values, getOrAllocateRegister(input, true)); } listBegin(inst.inputs.free_flags); - for(auto flag : move_flags) { + for (auto flag : move_flags) { listInsert(inst.inputs.free_flags, flag); } listBegin(inst.outputs); - for(auto output : outputs) { + for (auto output : outputs) { listInsert(inst.outputs, getOrAllocateRegister(output)); } return instructions.size() - 1; } - ArrayRef moveFlags(Node * n) { + ArrayRef moveFlags(Node* n) { return preprocess.move_flags.at(n); } - ArrayRef moveFlags(Block *b) { + ArrayRef moveFlags(Block* b) { return moveFlags(b->return_node()); } - size_t insertAssign(std::shared_ptr debug_location, ArrayRef inputs, ArrayRef move_flags, ArrayRef outputs) { - auto inst = insertInstruction(prim::Assign, std::move(debug_location),inputs, move_flags, outputs); - // This node effectively forwards its inputs into different places in a register list. - // We don't need to manipulate the stack in any way, because all inputs are also outputs, - // and the interpreter will take care of putting them in correct places. + size_t insertAssign( + std::shared_ptr debug_location, + ArrayRef inputs, + ArrayRef move_flags, + ArrayRef outputs) { + auto inst = insertInstruction( + prim::Assign, std::move(debug_location), inputs, move_flags, outputs); + // This node effectively forwards its inputs into different places in a + // register list. We don't need to manipulate the stack in any way, because + // all inputs are also outputs, and the interpreter will take care of + // putting them in correct places. 
instructions[inst].callback = [](Stack& stack) { return 0; }; return inst; } // helpers to build/access RegList objects - int get(const ListHandle & list, int i) const { + int get(const ListHandle& list, int i) const { return int_data[list.start + i]; } - bool get(const ListHandle & list, int i) const { + bool get(const ListHandle& list, int i) const { return bool_data[list.start + i]; } - void listBegin(ListHandle & list) { + void listBegin(ListHandle& list) { list.start = int_data.size(); list.size = 0; } - void listInsert(ListHandle & list, int value) { - JIT_ASSERTM(list.start + list.size == (int)int_data.size(), "another list already started"); + void listInsert(ListHandle& list, int value) { + JIT_ASSERTM( + list.start + list.size == (int)int_data.size(), + "another list already started"); int_data.push_back(value); list.size++; } - void listBegin(ListHandle & list) { + void listBegin(ListHandle& list) { list.start = bool_data.size(); list.size = 0; } - void listInsert(ListHandle & list, int value) { - JIT_ASSERTM(list.start + list.size == (int)bool_data.size(), "another list already started"); + void listInsert(ListHandle& list, int value) { + JIT_ASSERTM( + list.start + list.size == (int)bool_data.size(), + "another list already started"); bool_data.push_back(value); list.size++; } // must be called before any new_allocations are used, otherwise they will // already have registers assigned - void aliasRegistersTo(ArrayRef new_allocations, ArrayRef existing_allocations) { + void aliasRegistersTo( + ArrayRef new_allocations, + ArrayRef existing_allocations) { JIT_ASSERT(new_allocations.size() == existing_allocations.size()); - for(size_t i = 0; i < new_allocations.size(); ++i) { + for (size_t i = 0; i < new_allocations.size(); ++i) { auto n = new_allocations[i]->unique(); auto e = existing_allocations[i]->unique(); JIT_ASSERT(unique_to_reg.count(e) > 0 && unique_to_reg.count(n) == 0); unique_to_reg[n] = unique_to_reg[e]; } } - int getOrAllocateRegister(Value * n, bool required = false) { + int getOrAllocateRegister(Value* n, bool required = false) { size_t u = n->unique(); - if(unique_to_reg.count(u) > 0) + if (unique_to_reg.count(u) > 0) return unique_to_reg[u]; JIT_ASSERT(!required); int r = register_size++; @@ -576,7 +618,7 @@ struct CodeImpl { const std::vector& grad_executors() { if (!grad_executors_) { grad_executors_.emplace(); - for (Instruction & instr : instructions) { + for (Instruction& instr : instructions) { if (auto executor = detail::getGradExecutor(instr.callback)) { grad_executors_->push_back(executor); } @@ -585,33 +627,33 @@ struct CodeImpl { return *grad_executors_; } - void dumpInstruction(std::ostream & out, size_t pc) const { - auto writeList = [&](const ListHandle & list) { - for(int i = 0; i < list.size; i++) { - if(i > 0) + void dumpInstruction(std::ostream& out, size_t pc) const { + auto writeList = [&](const ListHandle& list) { + for (int i = 0; i < list.size; i++) { + if (i > 0) out << ", "; out << get(list, i); } }; - auto writeUseList = [&](const UseList & list) { - for(int i = 0; i < list.values.size; i++) { - if(i > 0) + auto writeUseList = [&](const UseList& list) { + for (int i = 0; i < list.values.size; i++) { + if (i > 0) out << ", "; - if(get(list.free_flags, i)) + if (get(list.free_flags, i)) out << "move(" << get(list.values, i) << ")"; else out << get(list.values, i); } }; - auto & inst = instructions.at(pc); + auto& inst = instructions.at(pc); writeList(inst.outputs); // NB: debug names are the kind of operator used to select // dispatch out 
<< " = " << inst.debug_name.toUnqualString() << " "; writeUseList(inst.inputs); } - void dump(std::ostream & out) const { - for(size_t i = 0; i < instructions.size(); ++i) { + void dump(std::ostream& out) const { + for (size_t i = 0; i < instructions.size(); ++i) { dumpInstruction(out, i); out << "\n"; } @@ -626,7 +668,8 @@ struct CodeImpl { c10::optional> grad_executors_; PreprocessGraph preprocess; - std::unordered_map unique_to_reg; // map from unique of nodes to register in register table + std::unordered_map + unique_to_reg; // map from unique of nodes to register in register table friend struct InterpreterState; std::vector instructions; @@ -640,12 +683,11 @@ struct CodeImpl { // InterpreterState state that and used to compute a Code struct InterpreterStateImpl : c10::intrusive_ptr_target { - InterpreterStateImpl(const Code & code) - : function(code.pImpl), - int_data(function->int_data.data()), - bool_data(function->bool_data), - registers(function->register_size) { - } + InterpreterStateImpl(const Code& code) + : function(code.pImpl), + int_data(function->int_data.data()), + bool_data(function->bool_data), + registers(function->register_size) {} private: c10::intrusive_ptr intrusive_from_this() { @@ -654,60 +696,61 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } bool runImpl(Stack& stack) { - auto & instructions = function->instructions; + auto& instructions = function->instructions; size_t last = instructions.size(); while (pc < last) { - // std::cout << "executing " << pc << ": "; - // function->dumpInstruction(std::cout, pc); - // std::cout << "\n"; - auto & inst = instructions[pc]; - try { - loadTensorsFromRegisters(inst.inputs, stack); - size_t new_pc = pc + 1 + inst.callback(stack); - for (int i = inst.outputs.size - 1; i >= 0; --i) { - int reg = get(inst.outputs, i); - registers[reg] = pop(stack); - // std::cout << "pop reg[" << reg << "];\n" << registers[reg] << "\n"; - } - pc = new_pc; - } catch (Suspend& e) { - // wait() expects a single input - JIT_ASSERT(inst.inputs.values.size == 1); - - getOrCreateFuture(); - - if (get(inst.inputs.free_flags, 0)) { - // make sure the register is not freed once we are waked up - registers[get(inst.inputs.values, 0)] = e.future; - } - - // Make sure adding callback is the last step. - // Otherwise if e.future has completed, - // the current thread will continue running before it suspends. - InterpreterState state(intrusive_from_this()); - e.future->addCallback([state]() { - c10::global_work_queue().run( - InterpreterContinuation(state, Stack())); - }); - - return true; - } catch (Future::FutureError& e) { - // Error from the forked thread. 
- auto msg = e.error_msg; // copy the error for each callback - handleError(std::move(msg), false); - return false; - } catch (std::exception& e) { - // Error from the current thread - bool is_jit_exception = dynamic_cast(&e); - if (instructions[pc].debug_location) { - handleError(instructions[pc].debug_location->wrapException( - e, "operation failed in interpreter"), is_jit_exception); - } else { - handleError(e.what(), is_jit_exception); - } - return false; + // std::cout << "executing " << pc << ": "; + // function->dumpInstruction(std::cout, pc); + // std::cout << "\n"; + auto& inst = instructions[pc]; + try { + loadTensorsFromRegisters(inst.inputs, stack); + size_t new_pc = pc + 1 + inst.callback(stack); + for (int i = inst.outputs.size - 1; i >= 0; --i) { + int reg = get(inst.outputs, i); + registers[reg] = pop(stack); + // std::cout << "pop reg[" << reg << "];\n" << registers[reg] << "\n"; } + pc = new_pc; + } catch (Suspend& e) { + // wait() expects a single input + JIT_ASSERT(inst.inputs.values.size == 1); + + getOrCreateFuture(); + + if (get(inst.inputs.free_flags, 0)) { + // make sure the register is not freed once we are waked up + registers[get(inst.inputs.values, 0)] = e.future; + } + + // Make sure adding callback is the last step. + // Otherwise if e.future has completed, + // the current thread will continue running before it suspends. + InterpreterState state(intrusive_from_this()); + e.future->addCallback([state]() { + c10::global_work_queue().run(InterpreterContinuation(state, Stack())); + }); + + return true; + } catch (Future::FutureError& e) { + // Error from the forked thread. + auto msg = e.error_msg; // copy the error for each callback + handleError(std::move(msg), false); + return false; + } catch (std::exception& e) { + // Error from the current thread + bool is_jit_exception = dynamic_cast(&e); + if (instructions[pc].debug_location) { + handleError( + instructions[pc].debug_location->wrapException( + e, "operation failed in interpreter"), + is_jit_exception); + } else { + handleError(e.what(), is_jit_exception); + } + return false; + } } if (future) { auto num_outputs = function->preprocess.n_outputs; @@ -762,22 +805,21 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } } - int get(const ListHandle & list, int i) { + int get(const ListHandle& list, int i) { return int_data[list.start + i]; }; - bool get(const ListHandle & list, int i) { + bool get(const ListHandle& list, int i) { return bool_data[list.start + i]; } - void loadTensorsFromRegisters(const UseList & uses, Stack & stack) { - for(int i = 0; i < uses.values.size; i++) { - int reg = get(uses.values,i); + void loadTensorsFromRegisters(const UseList& uses, Stack& stack) { + for (int i = 0; i < uses.values.size; i++) { + int reg = get(uses.values, i); // std::cout << "push reg[" << reg << "];\n" << registers[reg] << "\n\n"; - if(get(uses.free_flags,i)) { + if (get(uses.free_flags, i)) { stack.push_back(std::move(registers[reg])); } else { stack.push_back(registers[reg]); } - } } @@ -786,9 +828,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { c10::intrusive_ptr future; std::shared_ptr function; // keep function alive // these are just copies of function to prevent indirections in interpreter - int * int_data; - const std::vector & bool_data; - + int* int_data; + const std::vector& bool_data; // this holds all the tensors for this interpreter run // we don't bother minimizing the size of this vector, since the extra @@ -797,31 +838,31 @@ struct InterpreterStateImpl : 
c10::intrusive_ptr_target { // to make sure memory management happens efficiently. // We optimize for the case where derivatives are run with retain_graph=False - // in the case where it is true, then the interpreter and this array get copied - // if this every becomes a bottleneck then we _should_ consider minimizing the - // total number or register + // in the case where it is true, then the interpreter and this array get + // copied if this every becomes a bottleneck then we _should_ consider + // minimizing the total number or register std::vector registers; - // single buffer for input/output calls to ATen functions, so that we do not reallocate + // single buffer for input/output calls to ATen functions, so that we do not + // reallocate Stack stack; }; -std::ostream & operator<<(std::ostream & out, const Code & code) { +std::ostream& operator<<(std::ostream& out, const Code& code) { out << *code.pImpl->graph << "\n"; code.pImpl->dump(out); return out; } -Code::Code(const std::shared_ptr& graph) - : pImpl(new CodeImpl(graph)) {} +Code::Code(const std::shared_ptr& graph) : pImpl(new CodeImpl(graph)) {} Code::~Code() = default; const std::vector& Code::grad_executors() { return pImpl->grad_executors(); } -InterpreterState::InterpreterState(const Code & code) - : pImpl(c10::make_intrusive(code)) {} +InterpreterState::InterpreterState(const Code& code) + : pImpl(c10::make_intrusive(code)) {} InterpreterState::~InterpreterState() = default; void InterpreterState::run(Stack& stack) { @@ -836,6 +877,8 @@ c10::intrusive_ptr InterpreterState::getFuture() { return static_cast(pImpl.get())->getOrCreateFuture(); } -InterpreterState::InterpreterState(c10::intrusive_ptr pImpl_) +InterpreterState::InterpreterState( + c10::intrusive_ptr pImpl_) : pImpl(std::move(pImpl_)) {} -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h index facbd61..3ef761f 100644 --- a/torch/csrc/jit/interpreter.h +++ b/torch/csrc/jit/interpreter.h @@ -1,18 +1,19 @@ #pragma once +#include #include #include -#include -#include #include +#include namespace at { - class Tensor; +class Tensor; } namespace c10 { struct IValue; } -namespace torch { namespace jit { +namespace torch { +namespace jit { // The interpreter run Graphs with Tensor inputs and Tensor outputs // a separate component in the autograd handles unwrapping and wrapping @@ -27,8 +28,7 @@ struct Node; using Stack = std::vector; struct TORCH_API Code { - Code() - : pImpl(nullptr) {} + Code() : pImpl(nullptr) {} explicit Code(const std::shared_ptr& graph); ~Code(); @@ -38,19 +38,20 @@ struct TORCH_API Code { return pImpl != nullptr; } -private: + private: std::shared_ptr pImpl; friend struct InterpreterStateImpl; - friend std::ostream & operator<<(std::ostream & out, const Code & code); + friend std::ostream& operator<<(std::ostream& out, const Code& code); }; struct InterpreterState { - InterpreterState(const Code & code); + InterpreterState(const Code& code); void run(Stack& stack); c10::intrusive_ptr runAsync(Stack& stack); c10::intrusive_ptr getFuture(); ~InterpreterState(); -private: + + private: InterpreterState(c10::intrusive_ptr pImpl); // Ideally we should use c10::intrusive_ptr for pImpl; // but intrusive_ptr requires full definition of InterpreterStateImpl, @@ -83,4 +84,5 @@ struct InterpreterContinuation { InterpreterState state; Stack stack; }; -}} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index f30e92a..17d640d 100644 --- 
a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -1,13 +1,12 @@ #include - -#include #include -#include #include -#include -#include +#include +#include #include +#include +#include #include #include @@ -19,7 +18,8 @@ #include #include -namespace torch { namespace jit { +namespace torch { +namespace jit { // Constants relating to maintaining the topological index of nodes. // // Lower and upper bounds of the index. Inclusive range. @@ -32,26 +32,27 @@ static constexpr topo_position_t kMidPoint = 0; // - 2^(64-n) is the maximum number of appends to the end without reindex static constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */; -// Sigh, see https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char +// Sigh, see +// https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char constexpr Symbol PythonOp::Kind; -void printValueRef(std::ostream & out, const Value * n) { +void printValueRef(std::ostream& out, const Value* n) { out << "%" << n->uniqueName(); } // NB: This overload will become ambiguous with the one Caffe2 provides in its // logging, if they ever intersect. template -std::ostream& operator<<(std::ostream & out, const std::vector & nodes) { +std::ostream& operator<<(std::ostream& out, const std::vector& nodes) { out << at::ArrayRef{nodes}; return out; } template -std::ostream& printValueRefs(std::ostream & out, const at::ArrayRef & nodes) { +std::ostream& printValueRefs(std::ostream& out, const at::ArrayRef& nodes) { size_t i = 0; - for(auto n : nodes) { - if(i++ > 0) + for (auto n : nodes) { + if (i++ > 0) out << ", "; printValueRef(out, n); } @@ -61,24 +62,28 @@ std::ostream& printValueRefs(std::ostream & out, const at::ArrayRef & nodes) // Can't make these two overloads directly a template, it'll be ambiguous with // the global printer for operator<<. -std::ostream& operator<<(std::ostream & out, const at::ArrayRef & nodes) { +std::ostream& operator<<( + std::ostream& out, + const at::ArrayRef& nodes) { return printValueRefs(out, nodes); } -std::ostream& operator<<(std::ostream & out, const at::ArrayRef & nodes) { +std::ostream& operator<<(std::ostream& out, const at::ArrayRef& nodes) { return printValueRefs(out, nodes); } struct const_value_list_with_types { const ArrayRef values; bool use_newlines; - const_value_list_with_types(ArrayRef values, bool use_newlines = false) - : values(values), use_newlines(use_newlines) {} + const_value_list_with_types( + ArrayRef values, + bool use_newlines = false) + : values(values), use_newlines(use_newlines) {} }; -std::ostream& operator<<(std::ostream & out, const_value_list_with_types l) { +std::ostream& operator<<(std::ostream& out, const_value_list_with_types l) { size_t i = 0; - for(auto n : l.values) { - if(i++ > 0) { + for (auto n : l.values) { + if (i++ > 0) { if (l.use_newlines) { // TODO: Indent here is hard-coded for "graph(": un-hard-code it out << "\n "; @@ -93,14 +98,17 @@ std::ostream& operator<<(std::ostream & out, const_value_list_with_types l) { return out; } -void printAttributes(std::ostream & out, const Node * n, bool ignore_subgraph=false) { +void printAttributes( + std::ostream& out, + const Node* n, + bool ignore_subgraph = false) { out << "["; auto names = n->attributeNames(); int i = 0; - for(auto name : names) { + for (auto name : names) { if (ignore_subgraph && name == attr::Subgraph) continue; - if(i++ > 0) + if (i++ > 0) out << ", "; // TODO: debugging mode to see the qualifier. 
We definitely // don't want to print the qualifier since it should always @@ -113,46 +121,51 @@ void printAttributes(std::ostream & out, const Node * n, bool ignore_subgraph=fa out << "]"; } -static std::ostream & indent(std::ostream & out, size_t level) { - for(size_t i = 0; i < level; ++i) +static std::ostream& indent(std::ostream& out, size_t level) { + for (size_t i = 0; i < level; ++i) out << " "; return out; } -std::ostream& printNode(std::ostream & out, size_t level, const Node * n, std::vector * groups) { +std::ostream& printNode( + std::ostream& out, + size_t level, + const Node* n, + std::vector* groups) { auto outputs = n->outputs(); indent(out, level) << const_value_list_with_types(outputs); out << " = "; - IR_IFM_CONST(n,PythonOp) - out << "^" << value->name(); - value->writeScalars(out); + IR_IFM_CONST(n, PythonOp) + out << "^" << value->name(); + value->writeScalars(out); IR_ELSE() - if(n->hasAttribute(attr::Subgraph) && groups) { - out << n->kind().toQualString() << "_" << groups->size(); - if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) { - printAttributes(out, n, /*ignore_subgraph=*/true); - } - groups->push_back(n); - } else { - out << n->kind().toQualString(); - if(n->hasAttributes()) { - printAttributes(out,n); - } + if (n->hasAttribute(attr::Subgraph) && groups) { + out << n->kind().toQualString() << "_" << groups->size(); + if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) { + printAttributes(out, n, /*ignore_subgraph=*/true); + } + groups->push_back(n); + } else { + out << n->kind().toQualString(); + if (n->hasAttributes()) { + printAttributes(out, n); } + } IR_END() out << "(" << n->inputs() << ")"; std::string scopeName = n->scopeName(); if (scopeName.empty()) { out << "\n"; - } - else { + } else { out << ", "; out << "scope: " << scopeName << "\n"; } - for(size_t i = 0; i < n->blocks().size(); ++i) { + for (size_t i = 0; i < n->blocks().size(); ++i) { auto b = n->blocks()[i]; - indent(out, level + 1) << "block" << i << "(" << const_value_list_with_types(b->inputs(), false) << ") {\n"; - for(auto n : b->nodes()) { + indent(out, level + 1) << "block" << i << "(" + << const_value_list_with_types(b->inputs(), false) + << ") {\n"; + for (auto n : b->nodes()) { printNode(out, level + 2, n, groups); } indent(out, level + 2) << "-> (" << b->outputs() << ")\n"; @@ -161,20 +174,21 @@ std::ostream& printNode(std::ostream & out, size_t level, const Node * n, std::v return out; } -std::ostream& operator<<(std::ostream & out, const Node & n) { +std::ostream& operator<<(std::ostream& out, const Node& n) { return printNode(out, 0, &n, nullptr); } -std::ostream& operator<<(std::ostream & out, const Graph & g) { +std::ostream& operator<<(std::ostream& out, const Graph& g) { out << "graph(" << const_value_list_with_types(g.inputs(), true) << ") {\n"; std::vector groups; - for(auto n : g.nodes()) { + for (auto n : g.nodes()) { printNode(out, 1, n, &groups); } out << " return (" << g.outputs() << ");\n}\n"; size_t i = 0; - for(auto fg : groups) { - out << "with " << fg->kind().toQualString() << "_" <kind().toQualString() << "_" << i++ << " = " + << *fg->g(attr::Subgraph); } /* // Uncomment this to debug all_nodes issues @@ -189,7 +203,7 @@ std::ostream& operator<<(std::ostream & out, const Graph & g) { return out; } -std::ostream& Graph::prettyPrint(std::ostream & out) { +std::ostream& Graph::prettyPrint(std::ostream& out) { std::vector tensor_table; PythonPrint(out, *this, tensor_table); return out; @@ -204,8 +218,8 @@ static void 
checkSameDevice(const Node* node) { bool has_device = false; c10::optional device = c10::nullopt; auto checkValue = [&](const Value* v) { - if(CompleteTensorTypePtr type = v->type()->cast()) { - if(!has_device) { + if (CompleteTensorTypePtr type = v->type()->cast()) { + if (!has_device) { has_device = true; device = type->device(); } else { @@ -213,10 +227,10 @@ static void checkSameDevice(const Node* node) { } } }; - for(auto input : node->inputs()) { + for (auto input : node->inputs()) { checkValue(input); } - for(auto output : node->outputs()) { + for (auto output : node->outputs()) { checkValue(output); } } @@ -247,13 +261,15 @@ void Node::lint() const { for (auto input : inputs_) { // WARNING: O(n^2) // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - JIT_ASSERT(std::find(ALL_OF(input->uses_), Use(const_cast(this), i)) != input->uses_.end()); + JIT_ASSERT( + std::find(ALL_OF(input->uses_), Use(const_cast(this), i)) != + input->uses_.end()); JIT_ASSERT(graph_->all_nodes.count(this) == 1); i++; } } - for(auto o : outputs()) { + for (auto o : outputs()) { size_t i = 0; for (auto use : o->uses()) { // Use invariants @@ -265,38 +281,37 @@ void Node::lint() const { } // Node subclass invariants - IR_IF(this,Constant) - JIT_ASSERT(inputs_.size() == 0); + IR_IF(this, Constant) + JIT_ASSERT(inputs_.size() == 0); IR_ELSEIF(Return) - // Return uses is zero - JIT_ASSERT(outputs().size() == 0); + // Return uses is zero + JIT_ASSERT(outputs().size() == 0); IR_ELSEIF(Param) - // Param inputs is zero - JIT_ASSERT(inputs_.size() == 0); + // Param inputs is zero + JIT_ASSERT(inputs_.size() == 0); IR_ELSEIFM_CONST(PythonOp) - // Python operator cconv is correct - size_t n_scalars = 0, n_tensors = 0; - for (auto c : value->cconv) { - if (c == 'c') { - n_scalars++; - } else if (c == 'd') { - n_tensors++; - } else { - JIT_ASSERT(0); - } - JIT_ASSERT(static_cast(value->pyobj)); + // Python operator cconv is correct + size_t n_scalars = 0, n_tensors = 0; + for (auto c : value->cconv) { + if (c == 'c') { + n_scalars++; + } else if (c == 'd') { + n_tensors++; + } else { + JIT_ASSERT(0); } - JIT_ASSERT(n_scalars == value->scalar_args.size()); - JIT_ASSERT(n_tensors == inputs_.size()); + JIT_ASSERT(static_cast(value->pyobj)); + } + JIT_ASSERT(n_scalars == value->scalar_args.size()); + JIT_ASSERT(n_tensors == inputs_.size()); IR_ELSEIF(Eval) - // TODO: add invariants + // TODO: add invariants // TODO: It's not good for these ops to be top-level, it makes cases longer. 
IR_ELSEIF(FusionGroup) - checkSameDevice(value); - // TODO: Typecheck the parameters - value->g(attr::Subgraph)->lint(); + checkSameDevice(value); + // TODO: Typecheck the parameters + value->g(attr::Subgraph)->lint(); IR_END() - } // TODO: When lint fails, give better indication about which @@ -317,35 +332,35 @@ void Graph::lint() const { struct LintScope { LintScope() = default; - LintScope(std::unique_ptr parent) - : parent(std::move(parent)) {} - bool contains(const Value * v) { + LintScope(std::unique_ptr parent) : parent(std::move(parent)) {} + bool contains(const Value* v) { return values.count(v) > 0 || (parent && parent->contains(v)); } - bool contains(const Node * n) { + bool contains(const Node* n) { return nodes.count(n) > 0 || (parent && parent->contains(n)); } - void insert(const Value * v) { + void insert(const Value* v) { JIT_ASSERT(!contains(v)); values.insert(v); } - void insert(const Node * n) { + void insert(const Node* n) { JIT_ASSERT(!contains(n)); nodes.insert(n); } std::unique_ptr parent; - private: + + private: std::unordered_set values; std::unordered_set nodes; }; // Struct enables mutual recursion in linting methods. // Putting it inside Graph::lint enables access to private Graph members struct LintImpl { - LintImpl(const Graph & g) - : g(g) - , scope(new LintScope()) - , all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered* - const Graph & g; + LintImpl(const Graph& g) + : g(g), + scope(new LintScope()), + all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered* + const Graph& g; std::unique_ptr scope; std::unordered_set seen_uniques; std::unordered_map anticipated_uses; @@ -355,13 +370,13 @@ void Graph::lint() const { void check_value(const Value* v) { scope->insert(v); auto b2 = seen_uniques.insert(v->unique()); - JIT_ASSERT(b2.second); // insertion took place + JIT_ASSERT(b2.second); // insertion took place JIT_ASSERT(v->unique() < g.next_unique_); for (auto use : v->uses()) { JIT_ASSERT(!scope->contains(use.user)); JIT_ASSERT(g.all_nodes.count(use.user) == 1); - anticipated_uses[use.user]++; // int default constructs to 0 + anticipated_uses[use.user]++; // int default constructs to 0 } } void check_node(const Node* n) { @@ -370,24 +385,25 @@ void Graph::lint() const { JIT_ASSERTM(0, input->unique(), " not in scope"); } } - JIT_ASSERT(anticipated_uses[n] == static_cast(n->inputs_.size())); - anticipated_uses[n] = -1; // we saw the anticipated user! + JIT_ASSERT( + anticipated_uses[n] == static_cast(n->inputs_.size())); + anticipated_uses[n] = -1; // we saw the anticipated user! scope->insert(n); - for(auto block : n->blocks()) { + for (auto block : n->blocks()) { std::unique_ptr new_scope(new LintScope(std::move(scope))); scope = std::move(new_scope); check_block(block); scope = std::move(scope->parent); } size_t i = 0; - for(auto o : n->outputs()) { + for (auto o : n->outputs()) { JIT_ASSERT(o->node() == n); JIT_ASSERT(i++ == o->offset_); check_value(o); } n->lint(); } - void check_block(const Block *b) { + void check_block(const Block* b) { // Check topological ordering JIT_ASSERT(b->param_node()->isBefore(*b->nodes().begin())); auto curNode = *b->nodes().begin(); @@ -417,10 +433,10 @@ void Graph::lint() const { // - only one return node??? 
node_set nodes_set(ALL_OF(b->nodes())); - node_set inputs_set {b->input_}; - node_set output_set {b->output_}; - // TODO: Make a more type safe std::includes wrapper which disallows use on - // non-ordered containers + node_set inputs_set{b->input_}; + node_set output_set{b->output_}; + // TODO: Make a more type safe std::includes wrapper which disallows use + // on non-ordered containers JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set))); JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set))); JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set))); @@ -430,7 +446,8 @@ void Graph::lint() const { sum_set.insert(ALL_OF(output_set)); } void check_graph() { - node_set all_nodes_set(ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered* + node_set all_nodes_set( + ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered* check_block(g.block_); for (auto kv : anticipated_uses) { @@ -471,30 +488,30 @@ void Block::reIndexTopology() { } } -void Block::cloneFrom(Block * src, std::function value_map) { +void Block::cloneFrom(Block* src, std::function value_map) { std::unordered_map local_map; - auto env = [&](Value * v) { + auto env = [&](Value* v) { auto it = local_map.find(v); - if(it != local_map.end()) + if (it != local_map.end()) return it->second; return value_map(v); }; auto graph = owningGraph(); - for(auto input : src->inputs()) { + for (auto input : src->inputs()) { local_map[input] = this->addInput()->copyMetadata(input); } - for(auto node : src->nodes()) { + for (auto node : src->nodes()) { auto new_node = this->appendNode(graph->createClone(node, env)); - for(size_t i = 0; i < node->outputs().size(); ++i) { + for (size_t i = 0; i < node->outputs().size(); ++i) { auto oo = node->outputs()[i]; auto no = new_node->outputs()[i]; local_map[oo] = no; no->copyMetadata(oo); } } - for(auto output : src->outputs()) { + for (auto output : src->outputs()) { this->registerOutput(env(output)); } } @@ -503,9 +520,10 @@ void Block::destroy() { // we cannot destroy the output because it is used as the sentinel // for the nodes() list and has to remain valid for the loop output_->removeAllInputs(); - for(auto it = this->nodes().reverse().begin(), - end = this->nodes().reverse().end(); - it != end; ++it) { + for (auto it = this->nodes().reverse().begin(), + end = this->nodes().reverse().end(); + it != end; + ++it) { it.destroyCurrent(); } output_->destroy(); @@ -532,38 +550,41 @@ std::string Value::uniqueNameBase() const { std::string name_base = name; auto last_dot_pos = name.find_last_of('.'); if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) { - if (name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) { + if (name.find_first_not_of("0123456789", last_dot_pos + 1) == + std::string::npos) { name_base = name.substr(0, last_dot_pos); } } return name_base; } -Value* Value::setUniqueName(const std::string & name) { - if (name.size() > 0 && name.find_first_not_of("0123456789") == std::string::npos) { +Value* Value::setUniqueName(const std::string& name) { + if (name.size() > 0 && + name.find_first_not_of("0123456789") == std::string::npos) { throw std::runtime_error("names may not be integers: " + name); } - auto & names = node()->owningGraph()->unique_names_; + auto& names = node()->owningGraph()->unique_names_; // clear any old name from the map - if(hasUniqueName()) { + if (hasUniqueName()) { names.erase(unique_name_); unique_name_ = ""; } // allow "" to clear the uniquename - if(name == "") + if (name == "") return 
this; // if someone else has this name, then rename the other value auto old_owner_of_name = names.find(name); - if(old_owner_of_name != names.end()) { + if (old_owner_of_name != names.end()) { size_t suffix = 1; std::string name_base = name; auto last_dot_pos = name.find_last_of('.'); if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) { - if (name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) { + if (name.find_first_not_of("0123456789", last_dot_pos + 1) == + std::string::npos) { suffix = std::stoll(name.substr(last_dot_pos + 1)); name_base = name.substr(0, last_dot_pos); } @@ -573,7 +594,7 @@ Value* Value::setUniqueName(const std::string & name) { std::stringstream ss; ss << name_base << "." << suffix++; replacement_name = ss.str(); - } while(names.count(replacement_name) > 0); + } while (names.count(replacement_name) > 0); old_owner_of_name->second->setUniqueName(replacement_name); } @@ -582,14 +603,14 @@ Value* Value::setUniqueName(const std::string & name) { return this; } -Value* Value::copyMetadata(Value * from) { +Value* Value::copyMetadata(Value* from) { setType(from->type()); if (from->hasUniqueName()) setUniqueName(from->uniqueName()); return this; } -void Value::replaceFirstUseWith(Value * newValue) { +void Value::replaceFirstUseWith(Value* newValue) { JIT_ASSERT(owningGraph() == newValue->owningGraph()); auto u = uses()[0]; u.user->inputs_[u.offset] = newValue; @@ -597,7 +618,7 @@ void Value::replaceFirstUseWith(Value * newValue) { uses_.erase(uses_.begin()); } -void Value::replaceAllUsesWith(Value * newValue) { +void Value::replaceAllUsesWith(Value* newValue) { while (!uses().empty()) { replaceFirstUseWith(newValue); } @@ -611,7 +632,8 @@ size_t findArgument(const FunctionSchema& the_schema, Symbol name) { return i; } } - throw std::runtime_error(std::string("Couldn't find an argument called ") + name.toQualString()); + throw std::runtime_error( + std::string("Couldn't find an argument called ") + name.toQualString()); } c10::optional Node::get(Symbol name) const { @@ -622,10 +644,14 @@ Value* Node::namedInput(Symbol name) const { return input(findArgument(schema(), name)); } -bool Node::matches(const char *signature_literal, at::ArrayRef const_inputs) const { - if (!sig(signature_literal).matches(this)) return false; +bool Node::matches( + const char* signature_literal, + at::ArrayRef const_inputs) const { + if (!sig(signature_literal).matches(this)) + return false; for (Symbol s : const_inputs) { - if (!is_constant(s)) return false; + if (!is_constant(s)) + return false; } return true; } @@ -639,8 +665,8 @@ void Node::findSchema() const { } const FunctionSchema* Node::maybeSchema() const { - if(!schema_) { - if(auto op = findOperatorFor(this)) { + if (!schema_) { + if (auto op = findOperatorFor(this)) { schema_ = &op->schema(); } } @@ -649,38 +675,38 @@ const FunctionSchema* Node::maybeSchema() const { bool Node::isNondeterministic() const { static const OperatorSet nondeterministic_ops = { - "aten::dropout(Tensor input, float p, bool train) -> Tensor", - "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)", - "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor", - "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor", - "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor", - "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor", - "aten::normal(Tensor mean, Tensor std, *, Generator? 
generator) -> Tensor", - "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor", - "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor", - "aten::poisson(Tensor self, Generator? generator) -> Tensor", - "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", - "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", - "aten::rand(int[] size, *, int dtype, int layout, Device device) -> Tensor", - "aten::rand_like(Tensor self) -> Tensor", - "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor", - "aten::randint(int high, int[] size, *, int dtype, int layout, Device device) -> Tensor", - "aten::randint(int low, int high, int[] size, *, int dtype, int layout, Device device) -> Tensor", - "aten::randint_like(Tensor self, int high) -> Tensor", - "aten::randint_like(Tensor self, int low, int high) -> Tensor", - "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor", - "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor", - "aten::randn(int[] size, *, int dtype, int layout, Device device) -> Tensor", - "aten::randn_like(Tensor self) -> Tensor", - "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor", - "aten::randperm(int n, *, int dtype, int layout, Device device) -> Tensor" - }; + "aten::dropout(Tensor input, float p, bool train) -> Tensor", + "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)", + "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor", + "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor", + "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor", + "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor", + "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor", + "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor", + "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor", + "aten::poisson(Tensor self, Generator? generator) -> Tensor", + "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", + "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? 
generator) -> Tensor", + "aten::rand(int[] size, *, int dtype, int layout, Device device) -> Tensor", + "aten::rand_like(Tensor self) -> Tensor", + "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor", + "aten::randint(int high, int[] size, *, int dtype, int layout, Device device) -> Tensor", + "aten::randint(int low, int high, int[] size, *, int dtype, int layout, Device device) -> Tensor", + "aten::randint_like(Tensor self, int high) -> Tensor", + "aten::randint_like(Tensor self, int low, int high) -> Tensor", + "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor", + "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor", + "aten::randn(int[] size, *, int dtype, int layout, Device device) -> Tensor", + "aten::randn_like(Tensor self) -> Tensor", + "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor", + "aten::randperm(int n, *, int dtype, int layout, Device device) -> Tensor"}; if (nondeterministic_ops.find(this) == nullptr) { return false; } // Dropout with train = False is deterministic - if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") && is_constant(attr::train) && !get(attr::train).value()) { + if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") && + is_constant(attr::train) && !get(attr::train).value()) { return false; } return true; @@ -753,13 +779,13 @@ void Node::assignTopoPosition() { } } -Node::Node(Graph * graph_, NodeKind kind_) : - kind_(kind_), - graph_(graph_), - owning_block_(nullptr), - scope_(graph_->current_scope_), - schema_(nullptr), - topo_position_(0) { +Node::Node(Graph* graph_, NodeKind kind_) + : kind_(kind_), + graph_(graph_), + owning_block_(nullptr), + scope_(graph_->current_scope_), + schema_(nullptr), + topo_position_(0) { graph_->all_nodes.emplace(this); } @@ -767,15 +793,15 @@ void Node::eraseOutput(size_t i) { JIT_ASSERT(i < outputs_.size()); JIT_ASSERT(outputs_[i]->uses().empty()); schema_ = nullptr; - Value * n = outputs_[i]; + Value* n = outputs_[i]; outputs_.erase(outputs_.begin() + i); owningGraph()->freeValue(n); - for(size_t j = i; j < outputs_.size(); j++) { + for (size_t j = i; j < outputs_.size(); j++) { outputs_[j]->offset_--; } } -Block * Node::addBlock() { +Block* Node::addBlock() { schema_ = nullptr; blocks_.push_back(new Block(owningGraph(), this)); return blocks_.back(); @@ -784,33 +810,33 @@ Block * Node::addBlock() { void Node::eraseBlock(size_t i) { JIT_ASSERT(i < blocks_.size()); schema_ = nullptr; - Block * n = blocks_[i]; + Block* n = blocks_[i]; blocks_.erase(blocks_.begin() + i); n->destroy(); } void Node::destroy() { - while(!outputs().empty()) + while (!outputs().empty()) eraseOutput(outputs().size() - 1); - while(!blocks().empty()) + while (!blocks().empty()) eraseBlock(blocks().size() - 1); removeAllInputs(); - if(inBlockList()) + if (inBlockList()) removeFromList(); graph_->freeNode(this); } -void Node::cloneFrom(Node * s) { +void Node::cloneFrom(Node* s) { setSourceLocation(s->getSourceLocation()); - if(s->scope_ && !s->scope_->isBlank()) + if (s->scope_ && !s->scope_->isBlank()) scope_ = s->scope_; copyAttributes(*s); } -void Node::replaceAllUsesWith(Node * n) { +void Node::replaceAllUsesWith(Node* n) { JIT_ASSERT(outputs().size() == n->outputs().size()); size_t nOutputs = outputs().size(); - for(size_t i = 0; i < nOutputs; i++) { + for (size_t i = 0; i < nOutputs; i++) { outputs()[i]->replaceAllUsesWith(n->outputs()[i]); } } @@ 
-834,7 +860,7 @@ Value* Node::insertInput(size_t i, Value* value) { return value; } -Value* Node::addInput(Value * value) { +Value* Node::addInput(Value* value) { JIT_ASSERT(graph_ == value->owningGraph()); schema_ = nullptr; value->uses_.emplace_back(this, inputs_.size()); @@ -842,22 +868,22 @@ Value* Node::addInput(Value * value) { return value; } -Value* Node::replaceInput(size_t i, Value * newValue) { +Value* Node::replaceInput(size_t i, Value* newValue) { JIT_ASSERT(newValue->owningGraph() == graph_); schema_ = nullptr; - Value * old = dropInput(i); + Value* old = dropInput(i); inputs_[i] = newValue; newValue->uses_.emplace_back(this, i); return old; } -void Node::replaceInputWith(Value * from, Value * to) { +void Node::replaceInputWith(Value* from, Value* to) { JIT_ASSERT(from->owningGraph() == graph_); JIT_ASSERT(to->owningGraph() == graph_); schema_ = nullptr; size_t i = 0; - for(auto input : inputs()) { - if(input == from) + for (auto input : inputs()) { + if (input == from) replaceInput(i, to); i++; } @@ -914,28 +940,27 @@ bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const { } // should never reach here, since both nodes are ultimately in the same graph JIT_ASSERT(false); - } -bool Node::isBefore(const Node * n) const { +bool Node::isBefore(const Node* n) const { return isBeforeOrAfter(n, MoveSide::BEFORE); } -bool Node::isAfter(const Node * n) const { +bool Node::isAfter(const Node* n) const { return isBeforeOrAfter(n, MoveSide::AFTER); } -Node* Node::insertBefore(Node * n) { +Node* Node::insertBefore(Node* n) { JIT_ASSERT(n->inBlockList()); insertAfter(n->prev()); return this; } -Node* Node::insertAfter(Node * n) { +Node* Node::insertAfter(Node* n) { JIT_ASSERT(!inBlockList() && n->inBlockList()); JIT_ASSERT(n->owningBlock()); this->owning_block_ = n->owningBlock(); - Node * next = n->next(); + Node* next = n->next(); n->next() = this; this->prev() = n; this->next() = next; @@ -968,8 +993,7 @@ bool Node::couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb) { namespace { struct WorkingSet { public: - explicit WorkingSet(Node* mover, const AliasDb& aliasDb) - : aliasDb_(aliasDb) { + explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) { add(mover); } @@ -1141,7 +1165,11 @@ struct WorkingSet { // node at a time. When we can't move past a node (because it depends on the // working set), then add it to the working set and keep moving until we hit // `moveAfter`. 
-bool Node::tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun) { +bool Node::tryMove( + Node* movePoint, + MoveSide moveSide, + const AliasDb& aliasDb, + bool dryRun) { JIT_ASSERT(this->inBlockList() && movePoint->inBlockList()); JIT_ASSERT(this->owningBlock() == movePoint->owningBlock()); if (this == movePoint) { @@ -1241,12 +1269,12 @@ void Node::move(Node* movePoint, MoveSide moveSide) { } } -void Node::moveAfter(Node * n) { +void Node::moveAfter(Node* n) { removeFromList(); insertAfter(n); } -void Node::moveBefore(Node * n) { +void Node::moveBefore(Node* n) { removeFromList(); insertBefore(n); } @@ -1256,7 +1284,7 @@ void Node::removeInput(size_t i) { dropInput(i); // everything after this input shifts left, // so we need to update their use offsets to match - for(size_t j = i+1; j < inputs_.size(); j++) { + for (size_t j = i + 1; j < inputs_.size(); j++) { auto it = findUseForInput(j); it->offset--; } @@ -1265,13 +1293,13 @@ void Node::removeInput(size_t i) { void Node::removeAllInputs() { schema_ = nullptr; - for(size_t i = 0; i < inputs().size(); ++i) + for (size_t i = 0; i < inputs().size(); ++i) dropInput(i); inputs_.clear(); } use_list::iterator Node::findUseForInput(size_t i) { - auto & input_uses = inputs_[i]->uses_; + auto& input_uses = inputs_[i]->uses_; // O(N) on the use list, but unless we get nodes with +100 uses // vector traversal still is probably faster than linked list auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i)); @@ -1291,8 +1319,8 @@ Value* Node::dropInput(size_t i) { void Node::removeFromList() { JIT_ASSERT(inBlockList()); this->owning_block_ = nullptr; - Node * next = this->next(); - Node * prev = this->prev(); + Node* next = this->next(); + Node* prev = this->prev(); prev->next() = next; next->prev() = prev; this->next() = nullptr; @@ -1300,7 +1328,8 @@ void Node::removeFromList() { } inline const SourceRange& fakeRange() { - static SourceRange range(std::make_shared(""), 0, 1); + static SourceRange range( + std::make_shared(""), 0, 1); return range; } @@ -1322,14 +1351,17 @@ Value* Graph::insert( Node* Graph::create(NodeKind kind, size_t num_outputs) { // NB: Node constructor adds node to all_nodes auto n = new Node(this, kind); - for(size_t i = 0; i < num_outputs; i++) + for (size_t i = 0; i < num_outputs; i++) n->addOutput(); return n; } -Node* Graph::create(NodeKind kind, ArrayRef inputs, size_t num_outputs) { +Node* Graph::create( + NodeKind kind, + ArrayRef inputs, + size_t num_outputs) { auto n = create(kind, num_outputs); - for(auto i : inputs) + for (auto i : inputs) n->addInput(i); return n; } @@ -1339,14 +1371,14 @@ Node* Graph::createUndefined() { } Node* Graph::createNone(TypePtr typ) { - Node * n = create(prim::None); + Node* n = create(prim::None); n->output()->setType(OptionalType::create(std::move(typ))); return n; } -Node * Graph::createFusionGroup() { +Node* Graph::createFusionGroup() { auto n = create(prim::FusionGroup, 0); - n->g_(attr::Subgraph,std::make_shared(current_scope())); + n->g_(attr::Subgraph, std::make_shared(current_scope())); return n; } @@ -1358,16 +1390,16 @@ Node* Graph::createTuple(at::ArrayRef values) { return n; } -Node* Graph::createTupleUnpack(Value * v) { +Node* Graph::createTupleUnpack(Value* v) { TupleTypePtr tt = v->type()->expect(); auto n = create(prim::TupleUnpack, {v}, 0); - for(auto & element : tt->elements()) { + for (auto& element : tt->elements()) { n->addOutput()->setType(element); } return n; } -Node* Graph::createTupleIndex(Value * tup, int64_t 
index) { +Node* Graph::createTupleIndex(Value* tup, int64_t index) { auto n = create(prim::TupleIndex, {tup}); n->i_(attr::index, index); auto tuple_type = tup->type()->expect(); @@ -1375,7 +1407,7 @@ Node* Graph::createTupleIndex(Value * tup, int64_t index) { return n; } -Node* Graph::createTupleSlice(Value * tup, int64_t beg, int64_t end) { +Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) { auto n = create(prim::TupleSlice, {tup}); auto tuple_type = tup->type()->expect(); n->i_(attr::beg, beg); @@ -1391,13 +1423,13 @@ Node* Graph::createTupleSlice(Value * tup, int64_t beg, int64_t end) { Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef values) { auto n = create(prim::ListConstruct, values); - for(const auto & v : values) { + for (const auto& v : values) { JIT_ASSERT(v->type()->isSubtypeOf(elem_type)); } n->output()->setType(ListType::create(elem_type)); return n; } -Node* Graph::createListUnpack(Value *v, size_t size) { +Node* Graph::createListUnpack(Value* v, size_t size) { ListTypePtr list_type = v->type()->expect(); TypePtr elem_type = list_type->getElementType(); auto n = create(prim::ListUnpack, {v}, 0); @@ -1409,7 +1441,7 @@ Node* Graph::createListUnpack(Value *v, size_t size) { Node* Graph::createNumToTensor(Value* value) { auto typ = value->type(); - Node * result = create(prim::NumToTensor, {value}); + Node* result = create(prim::NumToTensor, {value}); result->output()->setType(CompleteTensorType::fromNumberType(std::move(typ))); return result; } @@ -1420,18 +1452,21 @@ Node* Graph::createImplicitTensorToNum(const TypePtr& type, Value* value) { return result; } -Node* Graph::createClone(Node * n, const std::function& value_map, bool copy_blocks) { - //n can be from a different graph - Node * r = n->allocNewInstance(this); - for(auto o : n->outputs()) { +Node* Graph::createClone( + Node* n, + const std::function& value_map, + bool copy_blocks) { + // n can be from a different graph + Node* r = n->allocNewInstance(this); + for (auto o : n->outputs()) { r->addOutput()->copyMetadata(o); } r->cloneFrom(n); - for(auto i : n->inputs()) { + for (auto i : n->inputs()) { r->addInput(value_map(i)); } - if(copy_blocks) { - for(auto b : n->blocks()) { + if (copy_blocks) { + for (auto b : n->blocks()) { r->addBlock()->cloneFrom(b, value_map); } } @@ -1442,7 +1477,8 @@ Value* Graph::insertConstant( IValue val, c10::optional loc, c10::optional scope) { - return jit::insertConstant(*this, std::move(val), std::move(loc), std::move(scope)); + return jit::insertConstant( + *this, std::move(val), std::move(loc), std::move(scope)); } std::string Graph::toString() const { @@ -1452,28 +1488,28 @@ std::string Graph::toString() const { } Graph::~Graph() { - for (const Node * n : all_nodes) + for (const Node* n : all_nodes) delete n; - for (const Value * v : all_values) + for (const Value* v : all_values) delete v; - for (const Block * b : all_blocks) + for (const Block* b : all_blocks) delete b; } -void Graph::freeNode(Node * n) { +void Graph::freeNode(Node* n) { auto it = all_nodes.find(n); JIT_ASSERT(it != all_nodes.end()); delete *it; all_nodes.erase(it); } -void Graph::freeValue(Value * v) { +void Graph::freeValue(Value* v) { v->setUniqueName(""); auto it = all_values.find(v); JIT_ASSERT(it != all_values.end()); delete *it; all_values.erase(it); } -void Graph::freeBlock(Block * b) { +void Graph::freeBlock(Block* b) { auto it = all_blocks.find(b); JIT_ASSERT(it != all_blocks.end()); delete *it; @@ -1483,13 +1519,17 @@ void Graph::freeBlock(Block * b) { at::ArrayRef 
createTupleUnpack(Value* v) { // small peephole optimization to ensure IntList attributes can still turn // into constants e.g. in x.expand([3, 4]) - if(v->node()->kind() == prim::TupleConstruct) + if (v->node()->kind() == prim::TupleConstruct) return v->node()->inputs(); - auto & g = *v->owningGraph(); + auto& g = *v->owningGraph(); return g.insertNode(g.createTupleUnpack(v))->outputs(); } -std::vector inlineCallTo(Graph& g, Graph& callee, ArrayRef inputs, bool unpack_outputs) { +std::vector inlineCallTo( + Graph& g, + Graph& callee, + ArrayRef inputs, + bool unpack_outputs) { std::unordered_map value_map; auto value_map_func = [&](Value* v) { return value_map.at(v); }; JIT_ASSERT(callee.inputs().size() == inputs.size()); @@ -1497,8 +1537,7 @@ std::vector inlineCallTo(Graph& g, Graph& callee, ArrayRef input value_map[callee.inputs()[i]] = inputs[i]; } for (auto* node : callee.nodes()) { - auto* new_node = - g.insertNode(g.createClone(node, value_map_func)); + auto* new_node = g.insertNode(g.createClone(node, value_map_func)); for (size_t i = 0; i < node->outputs().size(); ++i) { value_map[node->outputs()[i]] = new_node->outputs()[i]; } @@ -1511,23 +1550,25 @@ std::vector inlineCallTo(Graph& g, Graph& callee, ArrayRef input if (unpack_outputs && outputs.size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) { - auto tup = outputs[0]; - outputs.clear(); - for(Value* v : createTupleUnpack(tup)) { - outputs.emplace_back(v); - } - // if this was a peephole tuple unpack we can just get rid of - // the tuple construct here and prevent needing DCE - if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) { - tup->node()->destroy(); - } + auto tup = outputs[0]; + outputs.clear(); + for (Value* v : createTupleUnpack(tup)) { + outputs.emplace_back(v); + } + // if this was a peephole tuple unpack we can just get rid of + // the tuple construct here and prevent needing DCE + if (tup->node()->kind() == prim::TupleConstruct && + !tup->node()->hasUses()) { + tup->node()->destroy(); + } } return outputs; } -PythonOp* defaultAllocPythonOp(Graph*g) { - throw std::runtime_error("Trying to allocate a Python object without python bindings loaded"); +PythonOp* defaultAllocPythonOp(Graph* g) { + throw std::runtime_error( + "Trying to allocate a Python object without python bindings loaded"); } std::atomic alloc_python_op; @@ -1539,5 +1580,5 @@ void setAllocPythonOp(PythonOp* (*v)(Graph* g)) { alloc_python_op.store(v); } - -}} // namespace torch::jit +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 5a8f4f0..0d4f777 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -1,25 +1,25 @@ #pragma once -#include #include +#include +#include +#include #include #include #include +#include +#include #include #include #include #include -#include -#include -#include #include -#include +#include #include #include #include #include -#include #include #include @@ -33,18 +33,21 @@ #include #include -namespace torch { namespace autograd { +namespace torch { +namespace autograd { struct Function; -}} // namespace torch::autograd +} +} // namespace torch -namespace torch { namespace jit { +namespace torch { +namespace jit { // Graph represents one "function" of computation. -// It uses a simple ownership model where the graph owns all the nodes inside it. -// All references inside the graph are raw pointers. -// Destroying the Graph will invalidate any pointers to nodes in the graph. 
+// It uses a simple ownership model where the graph owns all the nodes inside +// it. All references inside the graph are raw pointers. Destroying the Graph +// will invalidate any pointers to nodes in the graph. struct Graph; // Node is the base class of the IR graph. It represents one computation @@ -55,8 +58,8 @@ struct Node; // Tensor or an opaque Handle object, as determined by type(). struct Value; -TORCH_API std::ostream& operator<<(std::ostream & out, const Graph & g); -TORCH_API std::ostream& operator<<(std::ostream & out, const Node & n); +TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g); +TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n); // A list of nodes, with inputs and outputs struct Block; @@ -65,12 +68,11 @@ struct Block; // 'user' is the consumer of the value, offset is the index into // 'user's input this where the produces will be found. struct Use { - Use(Node * user, size_t offset) - : user(user), offset(offset) {} - Node * user; + Use(Node* user, size_t offset) : user(user), offset(offset) {} + Node* user; size_t offset; - bool operator==(const Use & b) { + bool operator==(const Use& b) { return user == b.user && offset == b.offset; } }; @@ -106,29 +108,31 @@ using node_list = std::vector; using value_list = std::vector; using use_list = std::vector; using pyobj_list = std::vector; -template +template using ArrayRef = at::ArrayRef; using NodeKind = Symbol; using topo_position_t = int64_t; struct Value { TH_DISALLOW_COPY_AND_ASSIGN(Value); - Value(Node * node_, size_t offset_); -private: + Value(Node* node_, size_t offset_); + + private: friend struct Node; friend struct Graph; - Node * node_; + Node* node_; size_t offset_; - size_t unique_ = 0; // unique id + size_t unique_ = 0; // unique id use_list uses_; std::string unique_name_; TypePtr type_; -public: + + public: Value* setType(TypePtr type); void inferTypeFrom(const at::Tensor& output) { setType(CompleteTensorType::create(output)); } - const TypePtr & type() const { + const TypePtr& type() const { JIT_ASSERT(type_ != nullptr); return type_; } @@ -145,7 +149,7 @@ public: bool hasUniqueName() const { return !unique_name_.empty(); } - TORCH_API Value* setUniqueName(const std::string & name); + TORCH_API Value* setUniqueName(const std::string& name); std::string uniqueName() const { if (hasUniqueName()) return unique_name_; @@ -161,13 +165,13 @@ public: void setOffset(size_t offset) { offset_ = offset; } - const Node * node() const { + const Node* node() const { return node_; } - Graph * owningGraph(); - const Graph * owningGraph() const; + Graph* owningGraph(); + const Graph* owningGraph() const; // TODO: make this more const correct - const use_list & uses() const { + const use_list& uses() const { return uses_; } @@ -175,7 +179,7 @@ public: return !uses().empty(); } - TORCH_API void replaceFirstUseWith(Value * newValue); + TORCH_API void replaceFirstUseWith(Value* newValue); // Replaces all uses of this value with 'newValue'. 
// @@ -186,12 +190,11 @@ public: // Result: %3 = f(%1, %2) // %4 = g(%6) // %5 = h(%6, %6) - TORCH_API void replaceAllUsesWith(Value * newValue); + TORCH_API void replaceAllUsesWith(Value* newValue); - TORCH_API Value* copyMetadata(Value * from); + TORCH_API Value* copyMetadata(Value* from); }; - struct Node : public Attributes { TH_DISALLOW_COPY_AND_ASSIGN(Node); friend struct Graph; @@ -201,18 +204,23 @@ struct Node : public Attributes { friend const_graph_node_list; friend graph_node_list_iterator; friend const_graph_node_list_iterator; -private: + + private: // each node but Return/Param // is associated with exactly one place in the node list... // of the graph_ - // this circular is a doubly-linked list, the Return node is used as the sentinel for the beginning and end of the list - // such that the list never has null pointers - // next_in_graph[0] is next pointer - // next_in_graph[1] is prev pointer - // using an array to allow the same iterator class for forward and reverse node lists + // this circular is a doubly-linked list. The Return node is used as the + // sentinel for the beginning and end of the list such that the list never has + // null pointers. + // - next_in_graph[0] is next pointer + // - next_in_graph[1] is prev pointer + // + // Using an array to allow the same iterator class for forward and + // reverse node lists + // // This list represents a topological sort - Node* next_in_graph[2] = { nullptr, nullptr }; + Node* next_in_graph[2] = {nullptr, nullptr}; const NodeKind kind_; std::vector inputs_; @@ -225,18 +233,26 @@ private: ScopePtr scope_; // Assumes FunctionSchemas are persistent, so we don't manage their lifetime. // This field is effective a cache that's populated on attribute lookups and - // invalidated every time we perform an operation that could potentially change - // the schema. - // note: mutable because schema_ is effectively a cache + // invalidated every time we perform an operation that could potentially + // change the schema. note: mutable because schema_ is effectively a cache mutable const FunctionSchema* schema_; topo_position_t topo_position_ = 0; -protected: - TORCH_API Node(Graph * graph_, NodeKind kind_); //defined after graph -public: - Node* & next() { return next_in_graph[kNextDirection]; } - Node* & prev() { return next_in_graph[kPrevDirection]; } - Node* const & next() const { return next_in_graph[kNextDirection]; } - Node* const & prev() const { return next_in_graph[kPrevDirection]; } + + protected: + TORCH_API Node(Graph* graph_, NodeKind kind_); // defined after graph + public: + Node*& next() { + return next_in_graph[kNextDirection]; + } + Node*& prev() { + return next_in_graph[kPrevDirection]; + } + Node* const& next() const { + return next_in_graph[kNextDirection]; + } + Node* const& prev() const { + return next_in_graph[kPrevDirection]; + } NodeKind kind() const { return kind_; @@ -248,16 +264,16 @@ public: std::shared_ptr getSourceLocation() const { return source_location_; } - Graph * owningGraph() { + Graph* owningGraph() { return graph_; } - const Graph * owningGraph() const { + const Graph* owningGraph() const { return graph_; } - Block * owningBlock() { + Block* owningBlock() { return owning_block_; } - const Block * owningBlock() const { + const Block* owningBlock() const { return owning_block_; } ScopePtr scope() { @@ -300,26 +316,26 @@ public: // raw pointers are. 
return {outputs_.data(), outputs_.size()}; } - Value * output(size_t i) const { + Value* output(size_t i) const { return outputs_.at(i); } bool hasUses() const { - for(auto o : outputs()) { - if(!o->uses().empty()) + for (auto o : outputs()) { + if (!o->uses().empty()) return true; } return false; } - TORCH_API void replaceAllUsesWith(Node * n); + TORCH_API void replaceAllUsesWith(Node* n); - // lots of things like chunk have a single input or single output, so we have a - // helper to make accessing it easier - Value * input() { + // lots of things like chunk have a single input or single output, so we have + // a helper to make accessing it easier + Value* input() { JIT_ASSERT(inputs_.size() == 1); return inputs_.at(0); } - Value * output() { + Value* output() { JIT_ASSERT(outputs_.size() == 1); return outputs_.at(0); } @@ -327,12 +343,12 @@ public: JIT_ASSERT(outputs_.size() == 1); return outputs_.at(0); } - const Value * input() const { + const Value* input() const { JIT_ASSERT(inputs_.size() == 1); return inputs_.at(0); } // Access a particular input. This is a checked index. - Value * input(size_t i) const { + Value* input(size_t i) const { return inputs_.at(i); } @@ -342,7 +358,7 @@ public: template c10::optional get(Symbol name) const { - if(auto v = get(name)) + if (auto v = get(name)) return v->template to(); return c10::nullopt; } @@ -353,7 +369,7 @@ public: } TORCH_API bool isNondeterministic() const; - TORCH_API bool hasSideEffects () const; + TORCH_API bool hasSideEffects() const; // Graphs @@ -375,7 +391,7 @@ public: // Given: %3 = f(%1, %2) // Execute: %3.addInput(%4) // Result: %3 = f(%1, %2, %4) - TORCH_API Value* addInput(Value * value); + TORCH_API Value* addInput(Value* value); // Add 'value' as an input to 'this' at the specified position in the // arguments. Returns the added value for ease of chaining. @@ -387,7 +403,7 @@ public: // Given: %3 = f(%1, %2) // Execute: %3.replaceInput(1, %4) // Result: %3 = f(%1, %4) - TORCH_API Value * replaceInput(size_t i, Value * newValue); + TORCH_API Value* replaceInput(size_t i, Value* newValue); // Replace all occurrences of 'from' in the inputs of this // node with 'to'. Corresponds to llvm's replaceUsesOfWith. @@ -395,7 +411,7 @@ public: // Given: %3 = f(%1, %2, %1) // Execute: %3.replaceInputWith(%1, %4) // Result: %3 = f(%4, %2, %4) - TORCH_API void replaceInputWith(Value * from, Value * to); + TORCH_API void replaceInputWith(Value* from, Value* to); TORCH_API Value* addOutput(); @@ -403,17 +419,18 @@ public: TORCH_API void eraseOutput(size_t i); - TORCH_API Block * addBlock(); + TORCH_API Block* addBlock(); TORCH_API void eraseBlock(size_t i); // Each Node can have a list of subblocks. These are used to define structured // nested control flow operators such as If and Loop. // The meaning of a block is specific to the kind of node it is in, but // all blocks share these semantics: - // * Nested lexical scoping: If a node 'Parent' has a subblock which contains a - // node 'Child', Child can use any value that was in scope for the Parent + // * Nested lexical scoping: If a node 'Parent' has a subblock which contains + // a node 'Child', Child can use any value that was in scope for the Parent // node in addition to any values defined before 'Child' in the subblock. 
- // * The list of inputs to the block are in scope for the duration of the block + // * The list of inputs to the block are in scope for the duration of the + // block // * the outputs of the Parent node are not in scope for the subblocks // Typically the inputs to a block that represents control flow act as // as the equivalents phi-nodes in standard SSA form, @@ -432,10 +449,10 @@ public: } // Is 'this' before 'n' in the topological order? - TORCH_API bool isBefore(const Node * n) const; + TORCH_API bool isBefore(const Node* n) const; // Is 'this' after 'n' in the topological order? - TORCH_API bool isAfter(const Node * n) const; + TORCH_API bool isAfter(const Node* n) const; // Insert unattached 'this' node before 'n' in the topological order. // Returns this (for chaining). @@ -447,7 +464,7 @@ public: // Result: %3 = f(%1, %2) // %5 = h(%1) // %4 = g(%3) - TORCH_API Node* insertBefore(Node * n); + TORCH_API Node* insertBefore(Node* n); // Insert unattached 'this' node after 'n' in the topological order. // Returns this (for chaining). @@ -459,7 +476,7 @@ public: // Result: %3 = f(%1, %2) // %4 = g(%3) // %5 = h(%1) - TORCH_API Node* insertAfter(Node * n); + TORCH_API Node* insertAfter(Node* n); // Move 'this' (already in the graph) after 'n' in the topological order. // @@ -472,7 +489,7 @@ public: // Result: %3 = g(%1) // %2 = f(%1) // - TORCH_API void moveAfter(Node * n); + TORCH_API void moveAfter(Node* n); // Move 'this' (already in the graph) after 'n' in the topological order. // @@ -500,7 +517,7 @@ public: // Execute: %3.moveBefore(%2) // Result: %3 = g(%1) // %2 = f(%1) - TORCH_API void moveBefore(Node * n); + TORCH_API void moveBefore(Node* n); // Move 'this' (already in the graph) before 'n' in the topological order. // @@ -565,23 +582,27 @@ public: // Example usage: if(auto s = n.cast