From 1e9c384afb094afde5b4dd8800e3521cd12bc0fd Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Thu, 13 Dec 2018 16:09:08 -0800
Subject: [PATCH] Enable performance-unnecessary-value-param in .clang-tidy (#15026)

Summary:
This PR fixes around 250 places in the codebase where we were making
unnecessary copies of objects (some large, some small).

ezyang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15026

Differential Revision: D13458784

Pulled By: goldsborough

fbshipit-source-id: be5148b2ce09493588d70952e6f6d6ff5ec5199b
---
 .clang-tidy                                         |  1 +
 test/cpp/api/module.cpp                             |  3 +-
 tools/run-clang-tidy-in-ci.sh                       |  2 +-
 torch/csrc/Exceptions.h                             |  4 +-
 torch/csrc/api/include/torch/nn/cloneable.h         |  4 +-
 torch/csrc/api/include/torch/nn/module.h            | 18 ++--
 .../csrc/api/include/torch/nn/modules/batchnorm.h   |  7 +-
 torch/csrc/api/include/torch/nn/modules/conv.h      |  8 +-
 torch/csrc/api/include/torch/nn/modules/dropout.h   |  4 +-
 .../csrc/api/include/torch/nn/modules/embedding.h   |  2 +-
 torch/csrc/api/include/torch/nn/modules/linear.h    |  2 +-
 torch/csrc/api/include/torch/nn/modules/rnn.h       | 16 ++--
 .../csrc/api/include/torch/nn/modules/sequential.h  |  2 +-
 torch/csrc/api/src/nn/module.cpp                    | 22 ++---
 torch/csrc/api/src/nn/modules/batchnorm.cpp         |  7 +-
 torch/csrc/api/src/nn/modules/conv.cpp              |  6 +-
 torch/csrc/api/src/nn/modules/dropout.cpp           |  4 +-
 torch/csrc/api/src/nn/modules/embedding.cpp         |  2 +-
 torch/csrc/api/src/nn/modules/linear.cpp            |  2 +-
 torch/csrc/api/src/nn/modules/rnn.cpp               | 30 ++++---
 torch/csrc/api/src/serialize/input-archive.cpp      |  4 +-
 torch/csrc/autograd/VariableTypeUtils.h             |  2 +-
 torch/csrc/autograd/engine.cpp                      |  2 +-
 torch/csrc/autograd/functions/basic_ops.h           |  4 +-
 torch/csrc/autograd/functions/pybind.h              |  2 +-
 torch/csrc/autograd/functions/utils.cpp             |  2 +-
 torch/csrc/autograd/functions/utils.h               |  2 +-
 torch/csrc/autograd/python_cpp_function.cpp         |  2 +-
 torch/csrc/autograd/python_cpp_function.h           |  2 +-
 torch/csrc/autograd/python_function.cpp             |  2 +-
 torch/csrc/autograd/variable.h                      |  4 +-
 torch/csrc/jit/batched/BatchTensor.cpp              | 12 +--
 torch/csrc/jit/batched/BatchTensor.h                |  4 +-
 torch/csrc/jit/constants.cpp                        |  2 +-
 torch/csrc/jit/constants.h                          |  2 +-
 torch/csrc/jit/fuser/codegen.cpp                    | 14 ++--
 torch/csrc/jit/fuser/kernel_spec.h                  | 21 ++---
 torch/csrc/jit/fuser/tensor_desc.h                  |  4 +-
 torch/csrc/jit/hooks_for_testing.cpp                |  8 +-
 torch/csrc/jit/import_method.cpp                    |  8 +-
 torch/csrc/jit/ir.cpp                               | 19 +++--
 torch/csrc/jit/ir.h                                 | 18 ++--
 torch/csrc/jit/operator.cpp                         |  9 +-
 torch/csrc/jit/passes/alias_analysis.cpp            |  4 +-
 torch/csrc/jit/passes/alias_analysis.h              |  2 +-
 .../csrc/jit/passes/create_autodiff_subgraphs.cpp   |  2 +-
 torch/csrc/jit/passes/create_autodiff_subgraphs.h   |  2 +-
 torch/csrc/jit/passes/dead_code_elimination.cpp     |  2 +-
 torch/csrc/jit/passes/python_print.cpp              | 10 +--
 torch/csrc/jit/passes/remove_inplace_ops.cpp        |  2 +-
 torch/csrc/jit/passes/remove_inplace_ops.h          |  4 +-
 torch/csrc/jit/passes/shape_analysis.cpp            |  6 +-
 torch/csrc/jit/passes/shape_analysis.h              |  4 +-
 torch/csrc/jit/passes/to_batch.cpp                  |  2 +-
 torch/csrc/jit/passes/to_batch.h                    |  2 +-
 .../jit/passes/utils/check_alias_annotation.cpp     |  8 +-
 .../csrc/jit/passes/utils/check_alias_annotation.h  |  2 +-
 torch/csrc/jit/pybind_utils.h                       |  4 +-
 torch/csrc/jit/python_tracer.cpp                    |  8 +-
 torch/csrc/jit/python_tracer.h                      | 15 ++--
 torch/csrc/jit/register_prim_ops.cpp                | 18 ++--
 torch/csrc/jit/scope.h                              |  2 +-
 torch/csrc/jit/script/compiler.cpp                  | 98 +++++++++++-----------
 torch/csrc/jit/script/compiler.h                    | 43 ++++++----
 torch/csrc/jit/script/init.cpp                      | 32 +++----
 torch/csrc/jit/script/lexer.cpp                     |  2 +-
 torch/csrc/jit/script/lexer.h                       |  2 +-
 torch/csrc/jit/script/module.cpp                    | 10 +--
 torch/csrc/jit/script/module.h                      | 12 +--
 torch/csrc/jit/script/parser.h                      | 10 +--
 torch/csrc/jit/script/tree.h                        |  7 +-
 torch/csrc/jit/script/tree_views.h                  | 14 ++--
 torch/csrc/jit/symbolic_variable.h                  |  8 +-
 torch/csrc/jit/tracer.cpp                           |  2 +-
 torch/csrc/jit/tracing_state.h                      |  4 +-
 torch/csrc/utils/invalid_arguments.cpp              |  2 +-
 torch/csrc/utils/pybind.h                           |  5 +-
 torch/csrc/utils/tensor_new.cpp                     | 18 ++--
 78 files changed, 350 insertions(+), 313 deletions(-)

diff --git a/.clang-tidy b/.clang-tidy
index 6d02359..fe9b9ac 100644
--- a/.clang-tidy
+++ b/.clang-tidy
@@ -24,6 +24,7 @@ Checks: '
 ,-modernize-use-auto
 ,-modernize-use-default-member-init
 ,-modernize-use-using
+,performance-unnecessary-value-param
 '
 WarningsAsErrors: '*'
 HeaderFilterRegex: 'torch/csrc/.*'
diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp
index 1dac422..2e77a8d 100644
--- a/test/cpp/api/module.cpp
+++ b/test/cpp/api/module.cpp
@@ -240,7 +240,8 @@ TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
 TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
   struct Cloneable : Module {
     std::shared_ptr<Module> clone(
-        torch::optional<torch::Device> device = torch::nullopt) const override {
+        const torch::optional<torch::Device>& device =
+            torch::nullopt) const override {
       return nullptr;
     }
   };
diff --git a/tools/run-clang-tidy-in-ci.sh b/tools/run-clang-tidy-in-ci.sh
index 4d384da..8220c4d 100755
--- a/tools/run-clang-tidy-in-ci.sh
+++ b/tools/run-clang-tidy-in-ci.sh
@@ -40,7 +40,7 @@ fi
 # otherwise we'd have to build ONNX protos as part of this CI job.
 time python tools/clang_tidy.py \
   --verbose \
-  --paths torch/csrc \
+  --paths torch/csrc/ \
   --diff "$BASE_BRANCH" \
   -g"-torch/csrc/distributed/Module.cpp" \
   -g"-torch/csrc/jit/export.cpp" \
diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h
index 51a3d33..c0f3b6b 100644
--- a/torch/csrc/Exceptions.h
+++ b/torch/csrc/Exceptions.h
@@ -51,8 +51,8 @@ struct python_error : public std::exception {
   python_error(python_error&& other) {
     type = std::move(other.type);
-    value = std::move(other.value);
-    traceback = std::move(other.traceback);
+    value = other.value;
+    traceback = other.traceback;
     other.type = nullptr;
     other.value = nullptr;
     other.traceback = nullptr;
diff --git a/torch/csrc/api/include/torch/nn/cloneable.h b/torch/csrc/api/include/torch/nn/cloneable.h
index 1f4f1ba..9903725 100644
--- a/torch/csrc/api/include/torch/nn/cloneable.h
+++ b/torch/csrc/api/include/torch/nn/cloneable.h
@@ -32,7 +32,7 @@ class Cloneable : public virtual Module {
   /// and submodules in the cloned module are different from those in the
   /// original module.
   std::shared_ptr<Module> clone(
-      optional<Device> device = nullopt) const override {
+      const optional<Device>& device = nullopt) const override {
     NoGradGuard no_grad;

     const auto& self = static_cast<const Derived&>(*this);
@@ -75,7 +75,7 @@ class Cloneable : public virtual Module {
   }

  private:
-  void clone_(Module& other, optional<Device> device) final {
+  void clone_(Module& other, const optional<Device>& device) final {
     // Here we are *pretty* certain that `other's` type is `Derived` (because it
     // was registered under the same name as `this`), but you never know what
     // crazy things `reset()` does, so `dynamic_cast` just to be safe.
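[Editor's note, not part of the patch] The Summary above says the PR removes roughly 250 unnecessary copies. As a hedged illustration of what performance-unnecessary-value-param flags and the two standard fixes it suggests (all names below are invented for the example; this code does not appear in PyTorch):

#include <string>
#include <utility>
#include <vector>

// Flagged by the check: both parameters are copied even though they are only read.
int count_matches_by_value(std::vector<std::string> haystack, std::string name) {
  int n = 0;
  for (const auto& s : haystack) {
    if (s == name) ++n;
  }
  return n;
}

// Fix 1: pass read-only arguments by const reference.
int count_matches(const std::vector<std::string>& haystack, const std::string& name) {
  int n = 0;
  for (const auto& s : haystack) {
    if (s == name) ++n;
  }
  return n;
}

// Fix 2: when the argument is stored, keep it by value and move it into place.
struct Tagged {
  explicit Tagged(std::string tag) : tag_(std::move(tag)) {}
  std::string tag_;
};

int main() {
  const std::vector<std::string> v{"a", "b", "a"};
  Tagged t{"a"};
  return count_matches(v, t.tag_) - count_matches_by_value(v, t.tag_);  // 0
}

Most hunks in this patch are instances of Fix 1; a handful (shown later) use Fix 2.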
diff --git a/torch/csrc/api/include/torch/nn/module.h b/torch/csrc/api/include/torch/nn/module.h index d1fe89c..43c6231 100644 --- a/torch/csrc/api/include/torch/nn/module.h +++ b/torch/csrc/api/include/torch/nn/module.h @@ -109,7 +109,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// easier-to-use polymorphic interface. /// \endrst virtual std::shared_ptr clone( - optional device = nullopt) const; + const optional& device = nullopt) const; /// Applies the `function` to the `Module` and recursively to every submodule. /// The function must accept a `Module&`. @@ -121,7 +121,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// std::cout << module.name() << std::endl; /// }); /// \endrst - void apply(ModuleApplyFunction function); + void apply(const ModuleApplyFunction& function); /// Applies the `function` to the `Module` and recursively to every submodule. /// The function must accept a `const Module&`. @@ -133,7 +133,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// std::cout << module.name() << std::endl; /// }); /// \endrst - void apply(ConstModuleApplyFunction function) const; + void apply(const ConstModuleApplyFunction& function) const; /// Applies the `function` to the `Module` and recursively to every submodule. /// The function must accept a `const std::string&` for the key of the module, @@ -149,7 +149,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// }); /// \endrst void apply( - NamedModuleApplyFunction function, + const NamedModuleApplyFunction& function, std::string name_prefix = std::string()); /// Applies the `function` to the `Module` and recursively to every submodule. @@ -166,7 +166,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// }); /// \endrst void apply( - ConstNamedModuleApplyFunction function, + const ConstNamedModuleApplyFunction& function, std::string name_prefix = std::string()) const; /// Applies the `function` to the `Module` and recursively to every submodule. @@ -179,7 +179,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// std::cout << module->name() << std::endl; /// }); /// \endrst - void apply(ModulePointerApplyFunction function) const; + void apply(const ModulePointerApplyFunction& function) const; /// Applies the `function` to the `Module` and recursively to every submodule. /// The function must accept a `const std::string&` for the key of the module, @@ -197,7 +197,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// }); /// \endrst void apply( - NamedModulePointerApplyFunction function, + const NamedModulePointerApplyFunction& function, std::string name_prefix = std::string()) const; /// Returns the parameters of this `Module` and if `recurse` is true, also @@ -465,7 +465,7 @@ class TORCH_API Module : public std::enable_shared_from_this { // Private methods. /// Used in the implementation of `Cloneable`. - virtual void clone_(Module& other, optional device); + virtual void clone_(Module& other, const optional& device); /// The implementation of the various `to()` methods. template @@ -475,7 +475,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// `Module`'s children (thus not including the module itself). void apply_to_submodules( const NamedModulePointerApplyFunction& function, - std::string name_name_prefix = std::string()) const; + const std::string& name_prefix = std::string()) const; /// Returns a shared_ptr to `this` in a safe (checked) way. 
std::shared_ptr shared_from_this_checked() const; diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h index 2d5de89..33aabfd 100644 --- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h +++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h @@ -59,11 +59,14 @@ class TORCH_API BatchNormImpl : public torch::nn::Cloneable { /// The module must be constructed with `stateful = true` when calling this /// method, as the module will otherwise not store running statistics. If you /// want to supply the mean and variance yourself, use `pure_forward`. - Tensor forward(Tensor input); + Tensor forward(const Tensor& input); /// Applies batch normalization on the `input` using the given `mean` and /// `variance` statistics. - Tensor pure_forward(Tensor input, Tensor mean, Tensor variance); + Tensor pure_forward( + const Tensor& input, + const Tensor& mean, + const Tensor& variance); /// The options with which this module was constructed. BatchNormOptions options; diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h index 7b146aa..a242af1 100644 --- a/torch/csrc/api/include/torch/nn/modules/conv.h +++ b/torch/csrc/api/include/torch/nn/modules/conv.h @@ -17,7 +17,7 @@ struct ConvOptions { ConvOptions( int64_t input_channels, int64_t output_channels, - ExpandingArray kernel_size) : + ExpandingArray kernel_size) : input_channels_(input_channels), output_channels_(output_channels), kernel_size_(std::move(kernel_size)) {} @@ -106,7 +106,7 @@ class ConvImpl : public torch::nn::Cloneable { class TORCH_API Conv1dImpl : public ConvImpl<1, Conv1dImpl> { public: using ConvImpl<1, Conv1dImpl>::ConvImpl; - Tensor forward(Tensor input); + Tensor forward(const Tensor& input); }; /// `ConvOptions` specialized for 1-D convolution. @@ -126,7 +126,7 @@ TORCH_MODULE(Conv1d); class TORCH_API Conv2dImpl : public ConvImpl<2, Conv2dImpl> { public: using ConvImpl<2, Conv2dImpl>::ConvImpl; - Tensor forward(Tensor input); + Tensor forward(const Tensor& input); }; /// `ConvOptions` specialized for 2-D convolution. @@ -146,7 +146,7 @@ TORCH_MODULE(Conv2d); class TORCH_API Conv3dImpl : public ConvImpl<3, Conv3dImpl> { public: using ConvImpl<3, Conv3dImpl>::ConvImpl; - Tensor forward(Tensor input); + Tensor forward(const Tensor& input); }; /// `ConvOptions` specialized for 3-D convolution. diff --git a/torch/csrc/api/include/torch/nn/modules/dropout.h b/torch/csrc/api/include/torch/nn/modules/dropout.h index 9d2d8b3..46fb0e3 100644 --- a/torch/csrc/api/include/torch/nn/modules/dropout.h +++ b/torch/csrc/api/include/torch/nn/modules/dropout.h @@ -41,7 +41,7 @@ class TORCH_API DropoutImpl : public detail::DropoutImplBase { using detail::DropoutImplBase::DropoutImplBase; /// During training, applies a noise mask to the input tensor. /// During evaluation, applies an identity function. - Tensor forward(Tensor input); + Tensor forward(const Tensor& input); }; /// Applies spatial [Dropout](https://arxiv.org/abs/1207.0580) to inputs with @@ -58,7 +58,7 @@ class TORCH_API FeatureDropoutImpl : public detail::DropoutImplBase::DropoutImplBase; /// During training, applies a noise mask to the input tensor. /// During evaluation, applies an identity function. - Tensor forward(Tensor input); + Tensor forward(const Tensor& input); }; /// A `ModuleHolder` subclass for `DropoutImpl`. 
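[Editor's note, not part of the patch] The hunks above change forward(Tensor input) to forward(const Tensor& input) across the nn modules. A Tensor is a reference-counted handle, so the by-value parameter costs an atomic refcount bump rather than a deep copy, but the const reference avoids even that. A minimal sketch under that assumption; Handle below is a stand-in type invented to make the copy visible, not a PyTorch type:

#include <iostream>
#include <memory>

// Stand-in for a refcounted handle such as at::Tensor.
struct Handle {
  std::shared_ptr<int> impl = std::make_shared<int>(0);
};

// By value: the parameter is a copy, so the refcount is bumped for the call.
long use_count_by_value(Handle h) { return h.impl.use_count(); }

// By const reference: the handle is borrowed and no refcount traffic occurs.
long use_count_by_ref(const Handle& h) { return h.impl.use_count(); }

int main() {
  Handle h;
  std::cout << use_count_by_value(h) << "\n";  // prints 2
  std::cout << use_count_by_ref(h) << "\n";    // prints 1
}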
diff --git a/torch/csrc/api/include/torch/nn/modules/embedding.h b/torch/csrc/api/include/torch/nn/modules/embedding.h index 7b7a4b0..1c94884 100644 --- a/torch/csrc/api/include/torch/nn/modules/embedding.h +++ b/torch/csrc/api/include/torch/nn/modules/embedding.h @@ -30,7 +30,7 @@ class TORCH_API EmbeddingImpl : public torch::nn::Cloneable { /// Performs a lookup on the embedding table stored in `weight` using the /// `indices` supplied and returns the result. - Tensor forward(Tensor indices); + Tensor forward(const Tensor& indices); /// The `Options` used to configure this `Embedding` module. /// Changes to `EmbeddingOptions` *after construction* have no effect. diff --git a/torch/csrc/api/include/torch/nn/modules/linear.h b/torch/csrc/api/include/torch/nn/modules/linear.h index 379f802..3a11543 100644 --- a/torch/csrc/api/include/torch/nn/modules/linear.h +++ b/torch/csrc/api/include/torch/nn/modules/linear.h @@ -31,7 +31,7 @@ class TORCH_API LinearImpl : public Cloneable { /// Transforms the `input` tensor by multiplying with the `weight` and /// optionally adding the `bias`, if `with_bias` is true in the options. - Tensor forward(Tensor input); + Tensor forward(const Tensor& input); /// The options used to configure this module. LinearOptions options; diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h index 0c48777..c418ed8 100644 --- a/torch/csrc/api/include/torch/nn/modules/rnn.h +++ b/torch/csrc/api/include/torch/nn/modules/rnn.h @@ -60,7 +60,7 @@ class RNNImplBase : public torch::nn::Cloneable { enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 }; explicit RNNImplBase( - RNNOptionsBase options_, + const RNNOptionsBase& options_, optional cudnn_mode = nullopt, int64_t number_of_gates = 1); @@ -113,7 +113,7 @@ class RNNImplBase : public torch::nn::Cloneable { /// RNN function as first argument. RNNOutput generic_forward( std::function function, - Tensor input, + const Tensor& input, Tensor state); /// Returns a flat vector of all weights, with layer weights following each @@ -175,13 +175,13 @@ class TORCH_API RNNImpl : public detail::RNNImplBase { public: RNNImpl(int64_t input_size, int64_t hidden_size) : RNNImpl(RNNOptions(input_size, hidden_size)) {} - explicit RNNImpl(RNNOptions options); + explicit RNNImpl(const RNNOptions& options); /// Applies the `RNN` module to an input sequence and input state. /// The `input` should follow a `(sequence, batch, features)` layout unless /// `batch_first` is true, in which case the layout should be `(batch, /// sequence, features)`. - RNNOutput forward(Tensor input, Tensor state = {}); + RNNOutput forward(const Tensor& input, Tensor state = {}); RNNOptions options; }; @@ -203,13 +203,13 @@ class TORCH_API LSTMImpl : public detail::RNNImplBase { public: LSTMImpl(int64_t input_size, int64_t hidden_size) : LSTMImpl(LSTMOptions(input_size, hidden_size)) {} - explicit LSTMImpl(LSTMOptions options); + explicit LSTMImpl(const LSTMOptions& options); /// Applies the `LSTM` module to an input sequence and input state. /// The `input` should follow a `(sequence, batch, features)` layout unless /// `batch_first` is true, in which case the layout should be `(batch, /// sequence, features)`. - RNNOutput forward(Tensor input, Tensor state = {}); + RNNOutput forward(const Tensor& input, Tensor state = {}); }; /// A `ModuleHolder` subclass for `LSTMImpl`. 
@@ -229,13 +229,13 @@ class TORCH_API GRUImpl : public detail::RNNImplBase { public: GRUImpl(int64_t input_size, int64_t hidden_size) : GRUImpl(GRUOptions(input_size, hidden_size)) {} - explicit GRUImpl(GRUOptions options); + explicit GRUImpl(const GRUOptions& options); /// Applies the `GRU` module to an input sequence and input state. /// The `input` should follow a `(sequence, batch, features)` layout unless /// `batch_first` is true, in which case the layout should be `(batch, /// sequence, features)`. - RNNOutput forward(Tensor input, Tensor state = {}); + RNNOutput forward(const Tensor& input, Tensor state = {}); }; /// A `ModuleHolder` subclass for `GRUImpl`. diff --git a/torch/csrc/api/include/torch/nn/modules/sequential.h b/torch/csrc/api/include/torch/nn/modules/sequential.h index 779e32a..32cd913 100644 --- a/torch/csrc/api/include/torch/nn/modules/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/sequential.h @@ -104,7 +104,7 @@ class SequentialImpl : public Cloneable { /// Special cloning function for `Sequential` because it does not use /// `reset()`. std::shared_ptr clone( - optional device = nullopt) const override { + const optional& device = nullopt) const override { auto clone = std::make_shared(); for (const auto& module : modules_) { clone->push_back(module.clone(device)); diff --git a/torch/csrc/api/src/nn/module.cpp b/torch/csrc/api/src/nn/module.cpp index f608aca..5a38195 100644 --- a/torch/csrc/api/src/nn/module.cpp +++ b/torch/csrc/api/src/nn/module.cpp @@ -74,7 +74,7 @@ const std::string& Module::name() const noexcept { return *name_; } -std::shared_ptr Module::clone(optional device) const { +std::shared_ptr Module::clone(const optional& device) const { AT_ERROR( "clone() has not been implemented for ", name(), @@ -83,7 +83,7 @@ std::shared_ptr Module::clone(optional device) const { "> instead of torch::nn::Module to inherit the ability to clone."); } -void Module::apply(ModuleApplyFunction function) { +void Module::apply(const ModuleApplyFunction& function) { function(*this); apply_to_submodules( [&function](const std::string&, const std::shared_ptr& module) { @@ -91,7 +91,7 @@ void Module::apply(ModuleApplyFunction function) { }); } -void Module::apply(ConstModuleApplyFunction function) const { +void Module::apply(const ConstModuleApplyFunction& function) const { function(*this); apply_to_submodules( [&function](const std::string&, const std::shared_ptr& module) { @@ -99,7 +99,9 @@ void Module::apply(ConstModuleApplyFunction function) const { }); } -void Module::apply(NamedModuleApplyFunction function, std::string name_prefix) { +void Module::apply( + const NamedModuleApplyFunction& function, + std::string name_prefix) { function(/*name=*/name_prefix, *this); apply_to_submodules( [&function]( @@ -110,7 +112,7 @@ void Module::apply(NamedModuleApplyFunction function, std::string name_prefix) { } void Module::apply( - ConstNamedModuleApplyFunction function, + const ConstNamedModuleApplyFunction& function, std::string name_prefix) const { function(/*name=*/name_prefix, *this); apply_to_submodules( @@ -121,7 +123,7 @@ void Module::apply( std::move(name_prefix)); } -void Module::apply(ModulePointerApplyFunction function) const { +void Module::apply(const ModulePointerApplyFunction& function) const { function(shared_from_this_checked()); apply_to_submodules( [&function](const std::string&, const std::shared_ptr& module) { @@ -130,7 +132,7 @@ void Module::apply(ModulePointerApplyFunction function) const { } void Module::apply( - NamedModulePointerApplyFunction 
function, + const NamedModulePointerApplyFunction& function, std::string name_prefix) const { function( /*name=*/name_prefix, shared_from_this_checked()); @@ -319,13 +321,13 @@ Tensor& Module::register_buffer(std::string name, Tensor tensor) { return buffers_.insert(std::move(name), std::move(tensor)); } -void Module::clone_(Module& other, optional device) {} +void Module::clone_(Module& other, const optional& device) {} void Module::apply_to_submodules( const NamedModulePointerApplyFunction& function, - std::string name_name_prefix) const { + const std::string& name_prefix) const { for (const auto& child : children_) { - auto qualified_name = join_name(name_name_prefix, child.key()); + auto qualified_name = join_name(name_prefix, child.key()); function(qualified_name, child.value()); child.value()->apply_to_submodules(function, std::move(qualified_name)); } diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp index a8aa886..43f1be1 100644 --- a/torch/csrc/api/src/nn/modules/batchnorm.cpp +++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp @@ -33,7 +33,7 @@ void BatchNormImpl::reset() { } } -Tensor BatchNormImpl::forward(Tensor input) { +Tensor BatchNormImpl::forward(const Tensor& input) { AT_CHECK( options.stateful_, "Calling BatchNorm::forward is only permitted when " @@ -42,7 +42,10 @@ Tensor BatchNormImpl::forward(Tensor input) { return pure_forward(input, running_mean, running_variance); } -Tensor BatchNormImpl::pure_forward(Tensor input, Tensor mean, Tensor variance) { +Tensor BatchNormImpl::pure_forward( + const Tensor& input, + const Tensor& mean, + const Tensor& variance) { if (is_training()) { const auto num_channels = input.dim() > 1 ? input.size(1) : 1; AT_CHECK( diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp index 1eae8af..581d38d 100644 --- a/torch/csrc/api/src/nn/modules/conv.cpp +++ b/torch/csrc/api/src/nn/modules/conv.cpp @@ -59,7 +59,7 @@ void ConvImpl::reset() { } } -Tensor Conv1dImpl::forward(Tensor input) { +Tensor Conv1dImpl::forward(const Tensor& input) { if (options.transposed_) { return torch::conv_transpose1d( input, @@ -81,7 +81,7 @@ Tensor Conv1dImpl::forward(Tensor input) { options.groups_); } -Tensor Conv2dImpl::forward(Tensor input) { +Tensor Conv2dImpl::forward(const Tensor& input) { if (options.transposed_) { return torch::conv_transpose2d( input, @@ -103,7 +103,7 @@ Tensor Conv2dImpl::forward(Tensor input) { options.groups_); } -Tensor Conv3dImpl::forward(Tensor input) { +Tensor Conv3dImpl::forward(const Tensor& input) { if (options.transposed_) { return torch::conv_transpose3d( input, diff --git a/torch/csrc/api/src/nn/modules/dropout.cpp b/torch/csrc/api/src/nn/modules/dropout.cpp index b7887f4..2ec5400 100644 --- a/torch/csrc/api/src/nn/modules/dropout.cpp +++ b/torch/csrc/api/src/nn/modules/dropout.cpp @@ -26,11 +26,11 @@ template class DropoutImplBase; DropoutOptions::DropoutOptions(double rate) : rate_(rate) {} -Tensor DropoutImpl::forward(Tensor input) { +Tensor DropoutImpl::forward(const Tensor& input) { return torch::dropout(input, options.rate_, this->is_training()); } -Tensor FeatureDropoutImpl::forward(Tensor input) { +Tensor FeatureDropoutImpl::forward(const Tensor& input) { return torch::feature_dropout(input, options.rate_, this->is_training()); } } // namespace nn diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp index d09afe0..7797561 100644 --- a/torch/csrc/api/src/nn/modules/embedding.cpp +++ 
b/torch/csrc/api/src/nn/modules/embedding.cpp @@ -25,7 +25,7 @@ void EmbeddingImpl::reset() { weight.normal_(0, 1); } -Tensor EmbeddingImpl::forward(Tensor input) { +Tensor EmbeddingImpl::forward(const Tensor& input) { return torch::embedding(weight, /*indices=*/input); } } // namespace nn diff --git a/torch/csrc/api/src/nn/modules/linear.cpp b/torch/csrc/api/src/nn/modules/linear.cpp index c7c17bf..e2766ab 100644 --- a/torch/csrc/api/src/nn/modules/linear.cpp +++ b/torch/csrc/api/src/nn/modules/linear.cpp @@ -28,7 +28,7 @@ void LinearImpl::reset() { } } -Tensor LinearImpl::forward(Tensor input) { +Tensor LinearImpl::forward(const Tensor& input) { AT_ASSERT(!options.with_bias_ || bias.defined()); return torch::linear(input, weight, bias); } diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp index 206be24..8058cc5 100644 --- a/torch/csrc/api/src/nn/modules/rnn.cpp +++ b/torch/csrc/api/src/nn/modules/rnn.cpp @@ -29,7 +29,7 @@ RNNOptionsBase::RNNOptionsBase(int64_t input_size, int64_t hidden_size) template RNNImplBase::RNNImplBase( - RNNOptionsBase options_, + const RNNOptionsBase& options_, optional cudnn_mode, int64_t number_of_gates) : options(options_), @@ -121,7 +121,7 @@ void RNNImplBase::flatten_parameters() { template RNNOutput RNNImplBase::generic_forward( std::function function, - Tensor input, + const Tensor& input, Tensor state) { if (!state.defined()) { // #layers, batch size, state size @@ -192,7 +192,7 @@ RNNOptions& RNNOptions::relu() { return activation(RNNActivation::ReLU); } -RNNImpl::RNNImpl(RNNOptions options) +RNNImpl::RNNImpl(const RNNOptions& options) : detail::RNNImplBase( detail::RNNOptionsBase(options.input_size_, options.hidden_size_) .layers(options.layers_) @@ -203,14 +203,18 @@ RNNImpl::RNNImpl(RNNOptions options) static_cast(options.activation_)), options(options) {} -RNNOutput RNNImpl::forward(Tensor input, Tensor state) { +RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) { switch (options.activation_) { case RNNActivation::ReLU: return generic_forward( - static_cast(&torch::rnn_relu), input, state); + static_cast(&torch::rnn_relu), + std::move(input), + std::move(state)); case RNNActivation::Tanh: return generic_forward( - static_cast(&torch::rnn_tanh), input, state); + static_cast(&torch::rnn_tanh), + std::move(input), + std::move(state)); default: AT_ERROR("Unhandled RNN activation function!"); } @@ -218,13 +222,13 @@ RNNOutput RNNImpl::forward(Tensor input, Tensor state) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -LSTMImpl::LSTMImpl(LSTMOptions options) +LSTMImpl::LSTMImpl(const LSTMOptions& options) : detail::RNNImplBase( options, CuDNNMode::LSTM, /*number_of_gates=*/4) {} -RNNOutput LSTMImpl::forward(Tensor input, Tensor state) { +RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) { // It would be trickier to adapt the `generic_forward` for the LSTM because // its output has a different dimensionality (3-tuple vs. 
2-tuple), while we // always return one state variable (stacking the hidden/cell state into one), @@ -240,7 +244,7 @@ RNNOutput LSTMImpl::forward(Tensor input, Tensor state) { } Tensor output, hidden_state, cell_state; std::tie(output, hidden_state, cell_state) = torch::lstm( - input, + std::move(input), {state[0], state[1]}, flat_weights_, options.with_bias_, @@ -254,15 +258,17 @@ RNNOutput LSTMImpl::forward(Tensor input, Tensor state) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -GRUImpl::GRUImpl(GRUOptions options) +GRUImpl::GRUImpl(const GRUOptions& options) : detail::RNNImplBase( options, CuDNNMode::GRU, /*number_of_gates=*/3) {} -RNNOutput GRUImpl::forward(Tensor input, Tensor state) { +RNNOutput GRUImpl::forward(const Tensor& input, Tensor state) { return generic_forward( - static_cast(&torch::gru), input, state); + static_cast(&torch::gru), + std::move(input), + std::move(state)); } } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/serialize/input-archive.cpp b/torch/csrc/api/src/serialize/input-archive.cpp index 444fe79..4cd934a 100644 --- a/torch/csrc/api/src/serialize/input-archive.cpp +++ b/torch/csrc/api/src/serialize/input-archive.cpp @@ -54,12 +54,12 @@ void InputArchive::read(const std::string& key, InputArchive& archive) { void InputArchive::load_from(const std::string& filename, c10::optional device /*= c10::nullopt*/) { - module_ = torch::jit::load(filename, device); + module_ = torch::jit::load(filename, std::move(device)); } void InputArchive::load_from(std::istream& stream, c10::optional device /*= c10::nullopt*/) { - module_ = torch::jit::load(stream, device); + module_ = torch::jit::load(stream, std::move(device)); } } // namespace serialize } // namespace torch diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index 1550783..5025370 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -71,7 +71,7 @@ inline void rebase_history(std::vector&& vars, std::shared_ptradd_input_metadata(var); - var.rebase_history({grad_fn, output_nr}); + var.rebase_history({std::move(grad_fn), output_nr}); } else { grad_fn->add_input_metadata(Function::undefined_input()); } diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index f12a927..b2f5da5 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -306,7 +306,7 @@ static variable_list call_pre_hooks(Function& fn, variable_list inputs) { return inputs; } -static variable_list call_post_hooks(Function& fn, variable_list outputs, variable_list inputs) { +static variable_list call_post_hooks(Function& fn, variable_list outputs, const variable_list& inputs) { for (const auto& hook : fn.post_hooks()) { outputs = (*hook)(outputs, inputs); } diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h index 18c3ddb..6492835 100644 --- a/torch/csrc/autograd/functions/basic_ops.h +++ b/torch/csrc/autograd/functions/basic_ops.h @@ -28,11 +28,11 @@ struct TORCH_API Error : public Function { // NYI, grad_fn= will be printed if we use Error, which is confusing. So // special case with a new NotImplemented function here. 
struct TORCH_API NotImplemented : public Error { - NotImplemented(std::string forward_fn, edge_list&& next_edges) + NotImplemented(const std::string& forward_fn, edge_list&& next_edges) : Error("derivative for " + forward_fn + " is not implemented", std::move(next_edges)) {} - NotImplemented(std::string forward_fn) + NotImplemented(const std::string& forward_fn) : Error("derivative for " + forward_fn + " is not implemented") {} }; diff --git a/torch/csrc/autograd/functions/pybind.h b/torch/csrc/autograd/functions/pybind.h index 89a1547..e04053f 100644 --- a/torch/csrc/autograd/functions/pybind.h +++ b/torch/csrc/autograd/functions/pybind.h @@ -22,7 +22,7 @@ public: return true; } static handle cast(std::shared_ptr src, return_value_policy /* policy */, handle /* parent */) { - auto fn = functionToPyObject(src); + auto fn = functionToPyObject(std::move(src)); return handle(fn); } }; diff --git a/torch/csrc/autograd/functions/utils.cpp b/torch/csrc/autograd/functions/utils.cpp index 7b41e89..70b4aa7 100644 --- a/torch/csrc/autograd/functions/utils.cpp +++ b/torch/csrc/autograd/functions/utils.cpp @@ -10,7 +10,7 @@ namespace torch { namespace autograd { variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs, - function_constructor ctr) { + const function_constructor& ctr) { variable_list result; result.reserve(outputs.size()); if (!any_variable_requires_grad(inputs)) { diff --git a/torch/csrc/autograd/functions/utils.h b/torch/csrc/autograd/functions/utils.h index 65d45c3..c632526 100644 --- a/torch/csrc/autograd/functions/utils.h +++ b/torch/csrc/autograd/functions/utils.h @@ -20,7 +20,7 @@ using function_constructor = std::function(edge_list&& * grad_fn if necessary. */ TORCH_API variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs, - function_constructor ctr); + const function_constructor& ctr); /// Checks that inputs contains exactly `args` items and that the first `required_args` /// items are not nullptr. If not specified, `required_args` defaults to `args`. 
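[Editor's note, not part of the patch] The autograd hunks above (wrap_outputs' function_constructor, and the Module::apply visitors earlier) stop copying std::function arguments. A std::function may own a heap-allocated callable, so copying it on every call is wasteful when the callee only invokes it; taking it by const reference borrows it instead. A small sketch with hypothetical names:

#include <functional>
#include <iostream>
#include <vector>

using Visitor = std::function<void(int)>;

// Before: `visit` is copied (potentially allocating) on every call.
void for_each_by_value(const std::vector<int>& xs, Visitor visit) {
  for (int x : xs) visit(x);
}

// After: the callable is borrowed for the duration of the call.
void for_each(const std::vector<int>& xs, const Visitor& visit) {
  for (int x : xs) visit(x);
}

int main() {
  const std::vector<int> xs{1, 2, 3};
  int sum = 0;
  for_each(xs, [&](int x) { sum += x; });
  for_each_by_value(xs, [&](int x) { sum += x; });
  std::cout << sum << "\n";  // 12
}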
diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp index 1edf794..08e10a5 100644 --- a/torch/csrc/autograd/python_cpp_function.cpp +++ b/torch/csrc/autograd/python_cpp_function.cpp @@ -198,7 +198,7 @@ struct DefaultFunctionType { PyTypeObject type; }; -PyObject* functionToPyObject(std::shared_ptr cdata) +PyObject* functionToPyObject(const std::shared_ptr& cdata) { static DefaultFunctionType default_type; diff --git a/torch/csrc/autograd/python_cpp_function.h b/torch/csrc/autograd/python_cpp_function.h index 4de2d7b..90352e3 100644 --- a/torch/csrc/autograd/python_cpp_function.h +++ b/torch/csrc/autograd/python_cpp_function.h @@ -61,6 +61,6 @@ PyTypeObject* createForwardFunctionPyTypeObject(PyTypeObject& type, const char* } void registerCppFunction(const std::type_info& type, PyTypeObject* pytype); -PyObject* functionToPyObject(std::shared_ptr cdata); +PyObject* functionToPyObject(const std::shared_ptr& cdata); }} // namespace torch::autograd diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 9ec8899..9a9d8c6 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -882,7 +882,7 @@ PyObject* THPFunction_register_hook(THPFunction *self, PyObject *hook) static PyObject *unpack_saved_variables( THPFunction *self, - std::function unpack_fn) + const std::function& unpack_fn) { THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE); auto& saved_variables = self->saved_variables; diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index e28ff26..b67a810 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -596,11 +596,11 @@ inline void Variable::backward( c10::optional gradient, bool keep_graph, bool create_graph) const { - get()->backward(gradient, keep_graph, create_graph); + get()->backward(std::move(gradient), keep_graph, create_graph); } inline void Variable::set_data(Tensor new_data) const { - get()->set_data(new_data); + get()->set_data(std::move(new_data)); } inline void Variable::set_gradient_edge(Edge edge) noexcept { diff --git a/torch/csrc/jit/batched/BatchTensor.cpp b/torch/csrc/jit/batched/BatchTensor.cpp index f8a3cc7..625166d 100644 --- a/torch/csrc/jit/batched/BatchTensor.cpp +++ b/torch/csrc/jit/batched/BatchTensor.cpp @@ -8,12 +8,12 @@ BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims){ + std::to_string(data.dim()) + ", mask.dim(): " + std::to_string(mask.dim()) + ", dims.size(0): " + std::to_string(dims.size(0))); } - this->data = data; - this->mask = mask; - this->dims = dims; + this->data = std::move(data); + this->mask = std::move(mask); + this->dims = std::move(dims); } -BatchTensor::BatchTensor(at::Tensor data, int64_t batch_size){ +BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size){ dims = at::empty(data.dim(), data.options().dtype(at::kByte)); dims.fill_(0); std::vector sizes(data.dim() + 1, -1); @@ -25,7 +25,7 @@ BatchTensor::BatchTensor(at::Tensor data, int64_t batch_size){ mask.fill_(1); } -BatchTensor::BatchTensor(const std::vector datalist, at::Tensor dims) { +BatchTensor::BatchTensor(const std::vector& datalist, at::Tensor dims) { auto bs = datalist.size(); std::vector sizes(dims.size(0) + 1, 0), mask_sizes(dims.size(0) + 1, 0); sizes[0] = bs; @@ -52,7 +52,7 @@ BatchTensor::BatchTensor(const std::vector datalist, at::Tensor dims data_item += datalist[i]; mask_item.fill_(1); } - this->dims = dims; + this->dims = 
std::move(dims); } std::vector BatchTensor::examples() { diff --git a/torch/csrc/jit/batched/BatchTensor.h b/torch/csrc/jit/batched/BatchTensor.h index fc35e5b..a7acd27 100644 --- a/torch/csrc/jit/batched/BatchTensor.h +++ b/torch/csrc/jit/batched/BatchTensor.h @@ -10,8 +10,8 @@ struct BatchTensor { public: BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims); // expand a tensor to a batchtensor given batch_size - BatchTensor(at::Tensor data, int64_t batch_size); - BatchTensor(const std::vector datalist, at::Tensor dims); + BatchTensor(const at::Tensor& data, int64_t batch_size); + BatchTensor(const std::vector& datalist, at::Tensor dims); const char * toString() const { return "BatchTensor"; } diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp index 86d3f2c..4912d58 100644 --- a/torch/csrc/jit/constants.cpp +++ b/torch/csrc/jit/constants.cpp @@ -9,7 +9,7 @@ namespace torch { namespace jit { // IValue -> Constant node Value* insertConstant( Graph& g, - IValue val, + const IValue& val, c10::optional loc, c10::optional scope) { Node * n = g.create(prim::Constant); diff --git a/torch/csrc/jit/constants.h b/torch/csrc/jit/constants.h index b86d746..d64bc15 100644 --- a/torch/csrc/jit/constants.h +++ b/torch/csrc/jit/constants.h @@ -22,7 +22,7 @@ struct TORCH_API constant_not_supported_error : public std::runtime_error { // closely related to the implementation of prim::Constant that is also in constants.cpp TORCH_API Value* insertConstant( Graph& g, - IValue val, + const IValue& val, c10::optional loc = c10::nullopt, c10::optional scope = c10::nullopt); diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp index ebfee5c..371226e 100644 --- a/torch/csrc/jit/fuser/codegen.cpp +++ b/torch/csrc/jit/fuser/codegen.cpp @@ -11,11 +11,11 @@ #if USE_CUDA_FUSER #include -#endif +#endif #if USE_CPU_FUSER #include -#endif +#endif #include #include @@ -47,7 +47,7 @@ static std::string scalarValue(const bool v) { } // Note: The NAN, NEG_INFINITY and POS_INFINITY strings map to device-specific -// implementations of these special values. These macros are found in the +// implementations of these special values. These macros are found in the // resource strings for each device. static std::string scalarValue(const double v) { std::ostringstream out; @@ -89,7 +89,7 @@ static const char* calcScalarTypeName(const at::ScalarType type) { } -static std::string variableType(const std::shared_ptr t) { +static std::string variableType(const std::shared_ptr& t) { if (t->kind() == TypeKind::IntType) { return "int"; } else if (t->kind() == TypeKind::FloatType) { @@ -104,7 +104,7 @@ static std::string variableType(const std::shared_ptr t) { throw std::runtime_error("unknown scalar type during JIT fusion code generation"); } -static std::string typeCastedValueName(const std::shared_ptr t, const at::ScalarType outtype, const std::string& vn) { +static std::string typeCastedValueName(const std::shared_ptr& t, const at::ScalarType outtype, const std::string& vn) { if (t->kind() == TypeKind::IntType || t->kind() == TypeKind::BoolType) { if (! 
isIntegralType(outtype)) { return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")"; @@ -272,7 +272,7 @@ std::tuple< std::string , std::vector , std::vector -, bool> +, bool> generateKernel( const std::string& name , const Graph& graph @@ -376,7 +376,7 @@ generateKernel( bool has_random = false; // Generates code for intermediate nodes // Note: Concat and Chunk are implicitly generated - // Note: Random number generation is only supported for CUDA kernels. + // Note: Random number generation is only supported for CUDA kernels. for (const auto& n : graph.nodes()) { // Note: FusedConcat nodes work by narrowing the output Tensors before the kernel runs if (n->kind() == prim::FusedConcat) continue; diff --git a/torch/csrc/jit/fuser/kernel_spec.h b/torch/csrc/jit/fuser/kernel_spec.h index e768377..5942bac 100644 --- a/torch/csrc/jit/fuser/kernel_spec.h +++ b/torch/csrc/jit/fuser/kernel_spec.h @@ -53,19 +53,16 @@ private: // TODO: allow abstract kernels to use multiple generated kernels // TODO: allow abstract kernels to reuse generated kernels from common pool struct TORCH_API KernelSpec { - KernelSpec( - const int64_t _key - , std::shared_ptr _graph) - : key_{_key} - , graph_{_graph} - , code_{_graph} - , nInputs_{_graph->inputs().size()} - , inputBroadcastGroups_{} - , inputChunks_{} - , kernels_{} - { } + KernelSpec(const int64_t _key, const std::shared_ptr& _graph) + : key_{_key}, + graph_{_graph}, + code_{_graph}, + nInputs_{_graph->inputs().size()}, + inputBroadcastGroups_{}, + inputChunks_{}, + kernels_{} {} - // Getters + // Getters int64_t key() const { return key_; } std::shared_ptr graph() const { return graph_; } const Code& code() const { return code_; } diff --git a/torch/csrc/jit/fuser/tensor_desc.h b/torch/csrc/jit/fuser/tensor_desc.h index 4c43169..d1f0b60 100644 --- a/torch/csrc/jit/fuser/tensor_desc.h +++ b/torch/csrc/jit/fuser/tensor_desc.h @@ -43,7 +43,7 @@ struct TORCH_API TensorDesc { TensorDesc(const at::Tensor& t) : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {} - TensorDesc(CompleteTensorTypePtr type) + TensorDesc(const CompleteTensorTypePtr& type) : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {} // number of dimensions after contiguity compression @@ -91,7 +91,7 @@ inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) { } } // namespace fuser -} // namespace jit +} // namespace jit } // namespace torch #endif // USE_CUDA_FUSER || USE_CPU_FUSER diff --git a/torch/csrc/jit/hooks_for_testing.cpp b/torch/csrc/jit/hooks_for_testing.cpp index 9e51d908..b80f0ab 100644 --- a/torch/csrc/jit/hooks_for_testing.cpp +++ b/torch/csrc/jit/hooks_for_testing.cpp @@ -6,12 +6,12 @@ namespace jit { static std::function module)> emit_module_callback; TORCH_API void didFinishEmitModule(std::shared_ptr module) { - if(emit_module_callback) - emit_module_callback(module); - + if(emit_module_callback) { + emit_module_callback(std::move(module)); + } } TORCH_API void setEmitModuleHook(std::function module)> cb) { - emit_module_callback = cb; + emit_module_callback = std::move(cb); } } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/import_method.cpp b/torch/csrc/jit/import_method.cpp index 3bcedb2..6caff06 100644 --- a/torch/csrc/jit/import_method.cpp +++ b/torch/csrc/jit/import_method.cpp @@ -13,7 +13,7 @@ struct ModuleAccessorValue : public script::SugaredValue { return "module"; } // select an attribute on it, e.g. 
`this.field` - std::shared_ptr attr(SourceRange loc, script::Method & m, const std::string& field) override { + std::shared_ptr attr(const SourceRange& loc, script::Method & m, const std::string& field) override { if(script::NamedModule* v = module->find_module(field)) { return std::make_shared(v->module); } else if(script::NamedParameter* v = module->find_parameter(field)) { @@ -34,7 +34,7 @@ struct OpsValue : public script::SugaredValue { std::string kind() const override { return "ops"; } - std::shared_ptr attr(SourceRange loc, script::Method & m, const std::string& field) override { + std::shared_ptr attr(const SourceRange& loc, script::Method & m, const std::string& field) override { return std::make_shared(field, version_); } size_t version_; @@ -45,7 +45,7 @@ struct ConstantValue : public script::SugaredValue { : value_(std::move(value)) {} IValue value_; std::string kind() const override { return "constant"; } - Value * asValue(SourceRange loc, script::Method & m) override { + Value * asValue(const SourceRange& loc, script::Method & m) override { return m.graph()->insertConstant(value_); } }; @@ -60,7 +60,7 @@ struct ConstantTableValue : public script::SugaredValue { return "CONSTANTS"; } // select an attribute on it, e.g. `this.field` - std::shared_ptr attr(SourceRange loc, script::Method & m, const std::string& field) override { + std::shared_ptr attr(const SourceRange& loc, script::Method & m, const std::string& field) override { const char* field_s = field.c_str(); char* end; int64_t offset = std::strtoll(field_s + 1, &end, 10); diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 80f6079..5c085a7 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -9,14 +9,15 @@ #include #include +#include #include -#include -#include #include -#include #include -#include +#include #include +#include +#include +#include namespace torch { namespace jit { // Constants relating to maintaining the topological index of nodes. 
@@ -1294,7 +1295,7 @@ Value* Graph::insert( Symbol opname, at::ArrayRef args, at::ArrayRef kwargs, - c10::optional range) { + const c10::optional& range) { return script::emitBuiltinCall( range.value_or(fakeRange()), *this, @@ -1326,7 +1327,7 @@ Node* Graph::createUndefined() { Node* Graph::createNone(TypePtr typ) { Node * n = create(prim::None); - n->output()->setType(OptionalType::create(typ)); + n->output()->setType(OptionalType::create(std::move(typ))); return n; } @@ -1402,7 +1403,7 @@ Node* Graph::createListUnpack(Value *v, size_t size) { Node* Graph::createNumToTensor(Value* value) { auto typ = value->type(); Node * result = create(prim::NumToTensor, {value}); - result->output()->setType(CompleteTensorType::fromNumberType(typ)); + result->output()->setType(CompleteTensorType::fromNumberType(std::move(typ))); return result; } @@ -1412,7 +1413,7 @@ Node* Graph::createImplicitTensorToNum(const TypePtr& type, Value* value) { return result; } -Node* Graph::createClone(Node * n, std::function value_map, bool copy_blocks) { +Node* Graph::createClone(Node * n, const std::function& value_map, bool copy_blocks) { //n can be from a different graph Node * r = n->allocNewInstance(this); for(auto o : n->outputs()) { @@ -1434,7 +1435,7 @@ Value* Graph::insertConstant( IValue val, c10::optional loc, c10::optional scope) { - return jit::insertConstant(*this, std::move(val), loc, scope); + return jit::insertConstant(*this, std::move(val), std::move(loc), std::move(scope)); } std::string Graph::toString() const { diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 2d023cc..4823067 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -264,7 +264,7 @@ public: return scope_; } void setScope(ScopePtr scope) { - scope_ = scope; + scope_ = std::move(scope); } std::string scopeName() const { if (!scope_) { @@ -676,12 +676,12 @@ struct Block { } Value * addInput(std::string name="") { Value * v = input_->addOutput(); - v->setUniqueName(name); + v->setUniqueName(std::move(name)); return v; } Value* insertInput(size_t i, std::string name = "") { Value* v = input_->insertOutput(i); - v->setUniqueName(name); + v->setUniqueName(std::move(name)); return v; } void eraseInput(size_t i) { @@ -829,7 +829,7 @@ public: return current_scope_; } void set_current_scope(ScopePtr scope) { - current_scope_ = scope; + current_scope_ = std::move(scope); } Value * addInput(std::string name="") { return block_->addInput(std::move(name)); @@ -876,7 +876,7 @@ public: // use node_map to translate inputs of n to inputs of the cloned node // if copy_blocks is false, it will not recursively clone the nested blocks // this node contains. 
- TORCH_API Node * createClone(Node * n, std::function value_map, bool copy_blocks=true); + TORCH_API Node * createClone(Node * n, const std::function& value_map, bool copy_blocks=true); TORCH_API Value* insertConstant( IValue val, @@ -893,7 +893,7 @@ public: Symbol opname, at::ArrayRef args, at::ArrayRef kwargs = {}, - c10::optional range = {}); + const c10::optional& range = {}); Node * appendNode(Node * n) { return block_->appendNode(n); @@ -976,7 +976,7 @@ struct WithCurrentScope : public ResourceGuard { g.set_current_scope(prev_scope); }) , prev_scope(g.current_scope()) { - g.set_current_scope(scope); + g.set_current_scope(std::move(scope)); } private: ScopePtr prev_scope; @@ -990,9 +990,9 @@ inline Value::Value(Node * node_, size_t offset_) node_->graph_->all_values.emplace(this); } -inline Value* Value::setType(const TypePtr type) { +inline Value* Value::setType(TypePtr type) { JIT_ASSERT(type); - type_ = type; + type_ = std::move(type); for (Use & use : uses_) { use.user->schema_ = nullptr; } diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 9fa6548..00a5027 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -7,6 +7,11 @@ #include #include +#include +#include +#include +#include + namespace torch { namespace jit { namespace script { @@ -290,7 +295,7 @@ struct SchemaParser { L.expect(TK_NONE); return IValue(); } - IValue parseDefaultValue(TypePtr arg_type, c10::optional arg_N) { + IValue parseDefaultValue(const TypePtr& arg_type, c10::optional arg_N) { auto range = L.cur().range; switch(arg_type->kind()) { case TypeKind::DynamicType: @@ -328,7 +333,7 @@ struct SchemaParser { return IValue(); // silence warnings } - void parseList(int begin, int sep, int end, std::function callback) { + void parseList(int begin, int sep, int end, const std::function& callback) { auto r = L.cur().range; if (begin != TK_NOTHING) L.expect(begin); diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 9330030..0f931a2 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -5,7 +5,7 @@ namespace torch { namespace jit { namespace { -bool shouldAnnotate(TypePtr type) { +bool shouldAnnotate(const TypePtr& type) { return type->isSubtypeOf(DynamicType::get()) || type->kind() == TypeKind::ListType || type->kind() == TypeKind::TupleType || @@ -202,7 +202,7 @@ void AliasDb::dump() const { } } -void AliasDb::analyze(std::shared_ptr graph) { +void AliasDb::analyze(const std::shared_ptr& graph) { // Assign aliases to the graph's inputs, assuming that all inputs of a given // type may alias to each other. 
const auto tensorAlias = getFreshAlias(/*isGraphInput=*/true); diff --git a/torch/csrc/jit/passes/alias_analysis.h b/torch/csrc/jit/passes/alias_analysis.h index 5c543a7..687b171 100644 --- a/torch/csrc/jit/passes/alias_analysis.h +++ b/torch/csrc/jit/passes/alias_analysis.h @@ -53,7 +53,7 @@ class AliasDb { void dump() const; private: - void analyze(std::shared_ptr graph); + void analyze(const std::shared_ptr& graph); void analyze(Block* block); void analyze(Node* node); diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index 17ad5a8..198323a 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -168,7 +168,7 @@ class SubgraphSlicer { } // anonymous namespace std::vector CreateAutodiffSubgraphs( - std::shared_ptr graph, + const std::shared_ptr& graph, size_t threshold) { std::vector diff_nodes; SubgraphSlicer(graph->block(), graph, threshold).run(diff_nodes); diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.h b/torch/csrc/jit/passes/create_autodiff_subgraphs.h index cf63239..320a0db 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.h +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.h @@ -12,6 +12,6 @@ namespace torch { namespace jit { // threshold - minimum number of nodes that will appear in a block // returns all differentiable blocks that have been found TORCH_API std::vector CreateAutodiffSubgraphs( - std::shared_ptr graph, + const std::shared_ptr& graph, size_t threshold = 2); }} diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index 5503a68..bea2d5b 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -10,7 +10,7 @@ namespace jit { class DeadCodeEliminator { public: explicit DeadCodeEliminator(std::shared_ptr graph) - : aliasDb_(AliasAnalysis(graph)) {} + : aliasDb_(AliasAnalysis(std::move(graph))) {} DeadCodeEliminator(bool collect_only = false) : collect_only_(collect_only) {} diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index e160b30..8f23952 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -125,7 +125,7 @@ private: void createTensorToParameterNameMap( const script::Module& module, - QualifiedNamePtr prefix, + const QualifiedNamePtr& prefix, std::unordered_map& result) { for (const auto& elem : module.get_parameters()) { @@ -615,7 +615,7 @@ struct PythonPrintPass { std::ostream& stmt, const char* the_type, size_t list_size, - IValue the_list) { + const IValue& the_list) { if(list_size == 0) { stmt << "annotate(List[" << the_type << "], [])"; } else { @@ -623,7 +623,7 @@ struct PythonPrintPass { } } - void printConstant(std::ostream& stmt, IValue v) { + void printConstant(std::ostream& stmt, const IValue& v) { if(v.isTensor()) { stmt << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor()); } else if(v.isString()) { @@ -813,7 +813,7 @@ struct PythonPrintPass { return out; } - void printDefaultValue(std::ostream& stmt, IValue value) { + void printDefaultValue(std::ostream& stmt, const IValue& value) { if (value.isTensor() && !value.toTensor().defined()) { // XXX - because undefined tensors are not stored as None, we need special handling. 
// otherwise they get printed as CONSTANTS.c0 and then cannot be recreated because @@ -827,7 +827,7 @@ struct PythonPrintPass { void printFunctionDefinition( Graph& graph, const std::string& name, - const std::vector> defaults = {}, + const std::vector>& defaults = {}, const std::vector& param_names = {}) { used_names_.clear(); // each graph can reuse local names diff --git a/torch/csrc/jit/passes/remove_inplace_ops.cpp b/torch/csrc/jit/passes/remove_inplace_ops.cpp index 5dc85f1..0a17c63 100644 --- a/torch/csrc/jit/passes/remove_inplace_ops.cpp +++ b/torch/csrc/jit/passes/remove_inplace_ops.cpp @@ -52,7 +52,7 @@ void RemoveInplaceOps(Block* block) { } } -void RemoveInplaceOps(std::shared_ptr graph) { +void RemoveInplaceOps(const std::shared_ptr& graph) { RemoveInplaceOps(graph->block()); } } // namespace jit diff --git a/torch/csrc/jit/passes/remove_inplace_ops.h b/torch/csrc/jit/passes/remove_inplace_ops.h index 9919f9a..a26c53d 100644 --- a/torch/csrc/jit/passes/remove_inplace_ops.h +++ b/torch/csrc/jit/passes/remove_inplace_ops.h @@ -2,9 +2,11 @@ #include +#include + namespace torch { namespace jit { // see .cpp for docs -TORCH_API void RemoveInplaceOps(std::shared_ptr graph); +TORCH_API void RemoveInplaceOps(const std::shared_ptr& graph); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index cedc221..129f478 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -44,7 +44,7 @@ bool isValidReturnForRunning(Value* v) { class ShapePropagator { public: explicit ShapePropagator(std::shared_ptr graph) - : aliasDb_(AliasAnalysis(graph)) {} + : aliasDb_(AliasAnalysis(std::move(graph))) {} void PropagateShapeOnBlock(Block* block, bool insert_expands = true) { for (Node* node : block->nodes()) { @@ -1583,7 +1583,7 @@ class ShapePropagator { }; } // anonymous namespace -void PropagateInputShapes(std::shared_ptr graph) { +void PropagateInputShapes(const std::shared_ptr& graph) { ShapePropagator(graph).PropagateShapeOnBlock(graph->block()); } @@ -1608,7 +1608,7 @@ void EraseShapeInformation(Block * b) { } // anonymous namespace -void EraseShapeInformation(std::shared_ptr graph) { +void EraseShapeInformation(const std::shared_ptr& graph) { EraseShapeInformation(graph->block()); } diff --git a/torch/csrc/jit/passes/shape_analysis.h b/torch/csrc/jit/passes/shape_analysis.h index b61c7a8..5886c68 100644 --- a/torch/csrc/jit/passes/shape_analysis.h +++ b/torch/csrc/jit/passes/shape_analysis.h @@ -7,7 +7,7 @@ namespace torch { namespace jit { struct Graph; -TORCH_API void EraseShapeInformation(std::shared_ptr graph); -TORCH_API void PropagateInputShapes(std::shared_ptr graph); +TORCH_API void EraseShapeInformation(const std::shared_ptr& graph); +TORCH_API void PropagateInputShapes(const std::shared_ptr& graph); }} diff --git a/torch/csrc/jit/passes/to_batch.cpp b/torch/csrc/jit/passes/to_batch.cpp index be6bd31..8877806 100644 --- a/torch/csrc/jit/passes/to_batch.cpp +++ b/torch/csrc/jit/passes/to_batch.cpp @@ -5,7 +5,7 @@ namespace torch { namespace jit { std::unordered_map>> ToBatch::batch_operator_table; -std::shared_ptr ToBatch::getBatchOperator(std::string name, int64_t num_inputs){ +std::shared_ptr ToBatch::getBatchOperator(const std::string& name, int64_t num_inputs){ if(batch_operator_table.find(name) == batch_operator_table.end()){ throw std::runtime_error("function " + name + " is not supported in batched tensor yet"); } diff --git 
a/torch/csrc/jit/passes/to_batch.h b/torch/csrc/jit/passes/to_batch.h index 8aeb6e7..4c2a098 100644 --- a/torch/csrc/jit/passes/to_batch.h +++ b/torch/csrc/jit/passes/to_batch.h @@ -19,7 +19,7 @@ private: std::function rn_fn = [this](Value* v) { return rn_env.at(v); }; private: - std::shared_ptr getBatchOperator(std::string name, int64_t input_num = -1); + std::shared_ptr getBatchOperator(const std::string& name, int64_t input_num = -1); void visitAten(Node* n, Block* block, Block* res_block); void visitConstant(Node* n, Block* block, Block* res_block); void visitNumToTensor(Node* n, Block* block, Block* res_block); diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp index f81837d..a9432c5 100644 --- a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp +++ b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp @@ -73,9 +73,9 @@ bool deepEquals(const IValue& lhs, const IValue& rhs) { struct AliasAndIValue { AliasAndIValue( - const c10::optional& aliasInfo, - const IValue& iValue) - : aliasInfo(aliasInfo), iValue(iValue) {} + c10::optional aliasInfo, + IValue iValue) + : aliasInfo(std::move(aliasInfo)), iValue(std::move(iValue)) {} const c10::optional aliasInfo; const IValue iValue; @@ -185,7 +185,7 @@ c10::optional toIValueProp(const Value* v) { } // namespace void checkAliasAnnotation( - std::shared_ptr graph, + const std::shared_ptr& graph, std::vector pythonInputs, const std::string& unqualifiedOpName) { // Find the node that corresponds to our op name diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.h b/torch/csrc/jit/passes/utils/check_alias_annotation.h index 841963c..b2d7e8e 100644 --- a/torch/csrc/jit/passes/utils/check_alias_annotation.h +++ b/torch/csrc/jit/passes/utils/check_alias_annotation.h @@ -13,7 +13,7 @@ namespace jit { // This function expects a graph with a single op with `unqualifiedOpName`, plus // the inputs that you would otherwise have passed to the graph executor. 
TORCH_API void checkAliasAnnotation( - std::shared_ptr graph, + const std::shared_ptr& graph, std::vector pythonInputs, const std::string& unqualifiedOpName); } // namespace jit diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index 560d944..2b6c336 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -321,8 +321,8 @@ private: inline Stack createStackForSchema( const FunctionSchema& schema, - tuple_slice args, - py::kwargs kwargs = py::kwargs()) { + const tuple_slice& args, + const py::kwargs& kwargs = py::kwargs()) { if(args.size() + kwargs.size() > schema.arguments().size()) { throw std::runtime_error(c10::str( schema.name(), "() expected at most ", schema.arguments().size(), diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp index 3546f93..e5619a9 100644 --- a/torch/csrc/jit/python_tracer.cpp +++ b/torch/csrc/jit/python_tracer.cpp @@ -37,11 +37,11 @@ std::string getPythonInterpreterStackTrace() { } std::shared_ptr createGraphByTracing( - py::function func, + const py::function& func, Stack trace_inputs, - py::function var_name_lookup_fn, + const py::function& var_name_lookup_fn, bool force_outplace, - c10::optional num_real_inputs) { + const c10::optional& num_real_inputs) { size_t num_func_inputs = num_real_inputs.value_or(trace_inputs.size()); auto enter_info = tracer::enter(std::move(trace_inputs)); getTracingState()->lookup_var_name_fn = [var_name_lookup_fn](const Variable& var) -> std::string { @@ -76,7 +76,7 @@ std::shared_ptr createGraphByTracing( } Node* preRecordPythonTrace(THPObjectPtr pyobj, - std::string arg_types, + const std::string& arg_types, at::ArrayRef inputs, pyobj_list scalar_args) { THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(), "apply")); diff --git a/torch/csrc/jit/python_tracer.h b/torch/csrc/jit/python_tracer.h index cf357ad..1bb52bb 100644 --- a/torch/csrc/jit/python_tracer.h +++ b/torch/csrc/jit/python_tracer.h @@ -1,25 +1,28 @@ #pragma once #include -#include #include #include +#include +#include + namespace torch { namespace jit { namespace tracer { void initPythonTracerBindings(PyObject *module); std::string getPythonInterpreterStackTrace(); Node* preRecordPythonTrace( - THPObjectPtr pyobj, std::string arg_types, at::ArrayRef inputs, + THPObjectPtr pyobj, + const std::string& arg_types, + at::ArrayRef inputs, pyobj_list scalar_args); std::shared_ptr createGraphByTracing( - py::function func, + const py::function& func, Stack inputs, - py::function var_name_lookup_fn, + const py::function& var_name_lookup_fn, bool force_outplace, - c10::optional num_real_inputs = {}); + const c10::optional& num_real_inputs = c10::nullopt); } // namespace tracer - }} // namespace torch::jit diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index a478fd6..10e91a5 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -1163,7 +1163,7 @@ Operator( \ // checking one of size & scale_factor is set // if scale_factor is a double list check that it's len == dim // reference: _check_size_scale_factor in torch/nn/functional.py -void _check_size_factor(size_t dim, const IValue& size, IValue scale_factor) { +void _check_size_factor(size_t dim, const IValue& size, const IValue& scale_factor) { if (size.isNone() && scale_factor.isNone()) { throw std::runtime_error("either size or scale_factor should be defined"); } @@ -1184,7 +1184,7 @@ void _check_size_factor(size_t dim, const IValue& size, IValue scale_factor) { // 
reference: _output_size in torch/nn/functional.py // size can be none, int or intlist // scale_factors can be none, float, or floatlist -std::vector _output_size(at::Tensor input, size_t dim, IValue size, IValue scale_factors) { +std::vector _output_size(const at::Tensor& input, size_t dim, const IValue& size, const IValue& scale_factors) { if (!size.isNone()) { if (size.isInt()) { std::vector repeated(dim, size.toInt()); @@ -1209,8 +1209,12 @@ std::vector _output_size(at::Tensor input, size_t dim, IValue size, IVa // reference: interpolate in torch/nn/functional.py // size can be none, int or intlist // scale_factors can be none, float, or floatlist -at::Tensor interpolate(at::Tensor input, IValue size, IValue scale_factors, - std::string mode, c10::optional align_corners) { +at::Tensor interpolate( + const at::Tensor& input, + const IValue& size, + const IValue& scale_factors, + const std::string& mode, + c10::optional align_corners) { if ((mode == "nearest" || mode == "area")) { if (align_corners != c10::nullopt) { throw std::runtime_error("align_corners option can only be set with the " @@ -1279,7 +1283,7 @@ Operation interpolate_op(const Node* n) { // interpolate takes in float & float[] for scale factor // upsample takes in int & int[], so convert the ints to floats before // passing on to the interpolate op -IValue convert_scale_factor_to_double(IValue int_ivalue) { +IValue convert_scale_factor_to_double(const IValue& int_ivalue) { IValue scale_factor_double; if (int_ivalue.isInt()) { scale_factor_double = static_cast(int_ivalue.toInt()); @@ -1384,10 +1388,10 @@ RegisterOperators reg3({ }); -at::Tensor leaky_relu(at::Tensor tensor, double scalar) { +at::Tensor leaky_relu(const at::Tensor& tensor, double scalar) { return at::leaky_relu(tensor, scalar); } -at::Tensor cat(std::vector tensors) { +at::Tensor cat(const std::vector& tensors) { return at::cat(tensors); } diff --git a/torch/csrc/jit/scope.h b/torch/csrc/jit/scope.h index fdb0fbe..1ece1f9 100644 --- a/torch/csrc/jit/scope.h +++ b/torch/csrc/jit/scope.h @@ -36,7 +36,7 @@ public: } Scope(ScopePtr parent, Symbol name) { name_ = name; - parent_ = parent; + parent_ = std::move(parent); } ScopePtr push(Symbol name); diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 1f83cff..5dd09fa 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -47,7 +47,7 @@ struct PrintValue : public SugaredValue { return "print"; } std::shared_ptr call( - SourceRange loc, + const SourceRange& loc, Method & m, at::ArrayRef inputs, at::ArrayRef attributes, @@ -81,7 +81,7 @@ struct CastValue : public BuiltinFunction { : BuiltinFunction(method, c10::nullopt) , type_(std::move(type)) {} std::shared_ptr call( - SourceRange loc, + const SourceRange& loc, Method & m, at::ArrayRef inputs, at::ArrayRef attributes, @@ -98,7 +98,7 @@ private: TypePtr type_; }; -static Value* asSimple(SugaredValuePtr value) { +static Value* asSimple(const SugaredValuePtr& value) { if(SimpleValue* sv = dynamic_cast(value.get())) { return sv->getValue(); } @@ -229,7 +229,7 @@ struct Environment { return sv; } - SugaredValuePtr createCapturedInputIfNeeded(const SourceRange& loc, std::string ident) { + SugaredValuePtr createCapturedInputIfNeeded(const SourceRange& loc, const std::string& ident) { auto in_frame = findInThisFrame(ident); if (in_frame) { return in_frame; @@ -318,7 +318,7 @@ struct Environment { return getSugaredVar(ident)->asValue(ident.range(), method); } - SugaredValuePtr getSugaredVar(const 
std::string& ident, SourceRange range, bool required=true) { + SugaredValuePtr getSugaredVar(const std::string& ident, const SourceRange& range, bool required=true) { auto retval = createCapturedInputIfNeeded(range, ident); if(!retval) { @@ -352,7 +352,7 @@ struct Environment { return retval; } - Value* getVar(const std::string& ident, SourceRange range) { + Value* getVar(const std::string& ident, const SourceRange& range) { return getSugaredVar(ident, range)->asValue(range, method); } @@ -427,7 +427,7 @@ static inline bool isIntOrFloatUsedAsList( return list_type && list_type->getElementType() == v_type && arg.N(); } -inline bool convertibleToList(TypePtr type, TypePtr list_type_) { +inline bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) { auto list_type = list_type_->cast(); if(!list_type) { return false; @@ -451,7 +451,7 @@ inline bool convertibleToList(TypePtr type, TypePtr list_type_) { Value* tryConvertToType( const SourceRange& loc, Graph& graph, - TypePtr concrete_type, + const TypePtr& concrete_type, Value* value, bool allow_conversions) { // Allow homogeneous tuples to be casted implicitly to lists of appropriate @@ -497,7 +497,7 @@ Value* tryMatchArgument( Graph& graph, const SourceRange& loc, const NamedValue& named_value, - std::function err, + const std::function& err, bool allow_conversions, TypeEnv & type_env) { Value* value = named_value.value(graph); @@ -543,11 +543,11 @@ c10::optional findInputWithName( } Value* tryCreateList( - TypePtr elem_type, + const TypePtr& elem_type, Graph& graph, const SourceRange& loc, at::ArrayRef varargs, - std::function err, + const std::function& err, bool convert_tensor_to_num, TypeEnv & type_env) { Argument elem_arg("", elem_type); @@ -694,7 +694,7 @@ c10::optional tryMatchSchema( return MatchedSchema{std::move(positional_inputs), std::move(return_types)}; } -static std::string prefixLine(const std::string& str, std::string prefix) { +static std::string prefixLine(const std::string& str, const std::string& prefix) { std::stringstream ss; bool was_newline = true; for(auto c : str) { @@ -733,7 +733,7 @@ Value* emitBuiltinCall( const SourceRange& loc, Graph& graph, Symbol name, - c10::optional self, + const c10::optional& self, at::ArrayRef inputs, at::ArrayRef attributes, // if true, emitBuiltinCall will throw an exception if this builtin does not exist, @@ -801,7 +801,7 @@ static Value* ensureInt(const SourceRange& range, Value* v) { } std::shared_ptr BuiltinFunction::call( - SourceRange loc, + const SourceRange& loc, Method& m, at::ArrayRef inputs, at::ArrayRef attributes, @@ -810,25 +810,25 @@ std::shared_ptr BuiltinFunction::call( loc, *m.graph(), symbol, self, inputs, attributes, true)); } -inline bool isSupportedListElementType(TypePtr type) { +inline bool isSupportedListElementType(const TypePtr& type) { return type->isSubtypeOf(DynamicType::get()) || type->isSubtypeOf(NumberType::get()); } -TypePtr parseTypeFromExpr(Expr expr); -c10::optional> handleBroadcastList(Expr expr); +TypePtr parseTypeFromExpr(const Expr& expr); +c10::optional> handleBroadcastList(const Expr& expr); struct to_ir { to_ir( - Def def, + Def def_, Resolver resolver_, - SugaredValuePtr self, + SugaredValuePtr self_, Method& method) // method being constructed : method(method) , graph(method.graph()) - , def(def) + , def(std::move(def_)) , resolver(std::move(resolver_)) - , self(self) + , self(std::move(self_)) , environment_stack(nullptr) { JIT_ASSERT(resolver); pushFrame(graph->block()); @@ -912,7 +912,7 @@ private: return default_values; 
} - std::vector parseArgsFromDecl(Decl decl) { + std::vector parseArgsFromDecl(const Decl& decl) { auto params_begin = decl.params().begin(); auto params_end = decl.params().end(); if (self) @@ -962,7 +962,7 @@ private: return retval; } - std::vector parseReturnsFromDecl(Decl decl) { + std::vector parseReturnsFromDecl(const Decl& decl) { JIT_ASSERT(decl.return_type().present()); if (handleBroadcastList(decl.return_type().get())) throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type"; @@ -1002,7 +1002,7 @@ private: return FunctionSchema(name, args, returns, false, is_varret); } - std::vector emitFormalArguments(SugaredValuePtr self, const FunctionSchema& schema) { + std::vector emitFormalArguments(const SugaredValuePtr& self, const FunctionSchema& schema) { std::vector arguments; // for schema // inputs auto it = def.decl().params().begin(); @@ -1135,7 +1135,7 @@ private: std::shared_ptr emitSingleIfBranch( Block* b, - const List branch) { + const List& branch) { pushFrame(b); WithInsertPoint guard(b); emitStatements(branch); @@ -1198,8 +1198,8 @@ private: popFrame(); }; - emit_if_expr(true_block, true_expr); - emit_if_expr(false_block, false_expr); + emit_if_expr(true_block, std::move(true_expr)); + emit_if_expr(false_block, std::move(false_expr)); auto true_type = unshapedType(true_block->outputs().at(0)->type()); auto false_type = unshapedType(false_block->outputs().at(0)->type()); @@ -1215,7 +1215,7 @@ private: return expr_value; } - Value* emitCond(Expr cond) { + Value* emitCond(const Expr& cond) { Value* v = emitExpr(cond); if (!v->type()->isSubtypeOf(BoolType::get())) { ErrorReport error(cond); @@ -1460,7 +1460,7 @@ private: } } - void emitForRange(SourceRange range, const Ident& target, const List& args, const List& body) { + void emitForRange(const SourceRange& range, const Ident& target, const List& args, const List& body) { // TODO: start, stop, step loop if (args.size() != 1) { throw ErrorReport(range) @@ -1981,7 +1981,7 @@ private: std::vector getNamedValues( - TreeList trees, + const TreeList& trees, bool maybe_unpack) { std::vector values; for (const auto& tree : trees) { @@ -1999,23 +1999,23 @@ private: return values; } std::vector getNamedValues( - List trees, + const List& trees, bool maybe_unpack) { return getNamedValues(trees.tree()->trees(), maybe_unpack); } std::vector getValues( - TreeList trees, + const TreeList& trees, bool maybe_unpack) { return toValues(*graph, getNamedValues(trees, maybe_unpack)); } std::vector getValues( - List trees, + const List& trees, bool maybe_unpack) { return getValues(trees.tree()->trees(), maybe_unpack); } - std::vector emitAttributes(const List attributes) { + std::vector emitAttributes(const List& attributes) { return fmap(attributes, [&](const Attribute& attr) { return NamedValue(attr.range(), attr.name().name(), emitExpr(attr.value())); }); @@ -2077,8 +2077,8 @@ private: } } - Value* emitExpr(Expr tree, TypePtr type_hint = nullptr) { - return emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method); + Value* emitExpr(const Expr& tree, TypePtr type_hint = nullptr) { + return emitSugaredExpr(tree, 1, std::move(type_hint))->asValue(tree.range(), method); } NodeKind reverseComparision(NodeKind kind) { @@ -2101,7 +2101,7 @@ private: // or a = torch.jit.annotate(List[int], []) // the caller is responsible for checking that the result matches type_hint // emitSugaredExpr is free to ignore it. 
- std::shared_ptr emitSugaredExpr(Expr tree, size_t n_binders, TypePtr type_hint=nullptr) { + std::shared_ptr emitSugaredExpr(const Expr& tree, size_t n_binders, TypePtr type_hint=nullptr) { switch(tree.kind()) { case TK_VAR: return environment_stack->getSugaredVar(Var(tree).name()); @@ -2115,7 +2115,7 @@ private: return emitApplyExpr(apply, n_binders); } break; default: - return std::make_shared(emitSimpleExpr(tree, type_hint)); + return std::make_shared(emitSimpleExpr(tree, std::move(type_hint))); } } @@ -2191,7 +2191,7 @@ private: Value* emitSimpleExpr( const TreeRef& tree, - TypePtr type_hint = nullptr) { + const TypePtr& type_hint = nullptr) { switch (tree->kind()) { case '@': case TK_POW: @@ -2583,7 +2583,7 @@ static const std::unordered_map &builtin_cast_methods( // support syntax sugar for x.foo(y, z) by allowing x.foo to return a // callable value that will resolve to foo(x, y, z) when called. -std::shared_ptr SimpleValue::attr(SourceRange loc, Method & m, const std::string& field) { +std::shared_ptr SimpleValue::attr(const SourceRange& loc, Method & m, const std::string& field) { // Allow method-style casts on Tensor types. e.g. x.int() if (value->type()->isSubtypeOf(DynamicType::get())) { if (builtin_cast_methods().count(field)) { @@ -2634,7 +2634,7 @@ std::vector inlineCallTo(Graph& g, Graph& callee, ArrayRef input return outputs; } -void defineMethodsInModule(std::shared_ptr m, const std::vector& definitions, const std::vector& resolvers, SugaredValuePtr self) { +void defineMethodsInModule(const std::shared_ptr& m, const std::vector& definitions, const std::vector& resolvers, const SugaredValuePtr& self) { JIT_ASSERT(definitions.size() == resolvers.size()); auto resolver_it = resolvers.begin(); std::vector methods; @@ -2714,13 +2714,13 @@ const std::unordered_map> &subscr return map; } -bool isTorch(Expr expr) { +bool isTorch(const Expr& expr) { return expr.kind() == TK_VAR && Var(expr).name().name() == "torch"; } // gets the base type name given namespaces where the types live // turns torch.Tensor -> Tensor, X -> X -c10::optional parseBaseTypeName(Expr expr) { +c10::optional parseBaseTypeName(const Expr& expr) { switch (expr.kind()) { case TK_VAR: { return Var(expr).name().name(); @@ -2735,7 +2735,7 @@ c10::optional parseBaseTypeName(Expr expr) { return at::nullopt; } -TypePtr parseTypeFromExpr(Expr expr) { +TypePtr parseTypeFromExpr(const Expr& expr) { if (expr.kind() == TK_SUBSCRIPT) { auto subscript = Subscript(expr); auto value_name = parseBaseTypeName(subscript.value()); @@ -2757,7 +2757,7 @@ TypePtr parseTypeFromExpr(Expr expr) { << " cannot be used in a type expression"; } -c10::optional> handleBroadcastList(Expr expr) { +c10::optional> handleBroadcastList(const Expr& expr) { if (expr.kind() != TK_SUBSCRIPT) return c10::nullopt; auto subscript = Subscript(expr); @@ -2808,7 +2808,7 @@ c10::optional> handleBroadcastList(Expr expr) { return std::pair(list_ptr, len_v); } -void defineMethodsInModule(std::shared_ptr m, const std::string& source, Resolver resolver, SugaredValuePtr self) { +void defineMethodsInModule(std::shared_ptr m, const std::string& source, const Resolver& resolver, const SugaredValuePtr& self) { Parser p(source); std::vector definitions; std::vector resolvers; @@ -2817,13 +2817,13 @@ void defineMethodsInModule(std::shared_ptr m, const std::string& source, definitions.push_back(def); resolvers.push_back(resolver); } - defineMethodsInModule(m, definitions, resolvers, self); + defineMethodsInModule(std::move(m), definitions, resolvers, self); } 
std::vector> SimpleValue::asTuple( - SourceRange loc, + const SourceRange& loc, Method& m, - c10::optional size_hint) { + const c10::optional& size_hint) { static const auto make_simple_value = [](Value* v) -> std::shared_ptr { return std::make_shared(v); }; diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h index 3b7c62f..bb98996 100644 --- a/torch/csrc/jit/script/compiler.h +++ b/torch/csrc/jit/script/compiler.h @@ -38,12 +38,12 @@ struct SugaredValue : public std::enable_shared_from_this { // what can we do with this thing? // use it as a value e.g. `this + 4` - virtual Value * asValue(SourceRange loc, Method & m) { + virtual Value * asValue(const SourceRange& loc, Method & m) { throw ErrorReport(loc) << kind() << " cannot be used as a value"; } // select an attribute on it, e.g. `this.field` - virtual std::shared_ptr attr(SourceRange loc, Method & m, const std::string& field) { + virtual std::shared_ptr attr(const SourceRange& loc, Method & m, const std::string& field) { throw ErrorReport(loc) << "attribute lookup is not defined on " << kind(); } virtual NoneStatus isNone() { @@ -53,15 +53,15 @@ struct SugaredValue : public std::enable_shared_from_this { // use it as a vector of values, e.g. a tuple of values as return value from // a method invocation virtual std::vector> asTuple( - SourceRange loc, + const SourceRange& loc, Method& m, - c10::optional size_hint = {}) { + const c10::optional& size_hint = {}) { throw ErrorReport(loc) << kind() << " cannot be used as a tuple"; } // call it like a function, e.g. `outputs = this(inputs)` virtual std::shared_ptr call( - SourceRange loc, + const SourceRange& loc, Method & m, // note: names for args will be 'argument 0', 'argument 1', etc.. at::ArrayRef inputs_, @@ -96,7 +96,7 @@ struct TORCH_API SimpleValue : public SugaredValue { std::string kind() const override { return "value"; } - Value * asValue(SourceRange range, Method & m) override { + Value * asValue(const SourceRange& range, Method & m) override { return value; } NoneStatus isNone() override { @@ -108,10 +108,10 @@ struct TORCH_API SimpleValue : public SugaredValue { return NEVER; } std::vector> asTuple( - SourceRange loc, + const SourceRange& loc, Method& m, - c10::optional size_hint = {}) override; - std::shared_ptr attr(SourceRange loc, Method & m, const std::string& field) override; + const c10::optional& size_hint = {}) override; + std::shared_ptr attr(const SourceRange& loc, Method & m, const std::string& field) override; Value* getValue() const { return value; } @@ -133,7 +133,7 @@ struct TORCH_API BuiltinFunction : public SugaredValue { return "builtin"; } std::shared_ptr call( - SourceRange loc, + const SourceRange& loc, Method& m, at::ArrayRef attributes, at::ArrayRef inputs, @@ -149,7 +149,7 @@ struct TORCH_API BuiltinModule : public SugaredValue { std::string kind() const override { return "builtin module"; } - std::shared_ptr attr(SourceRange loc, Method & m, const std::string& field) override { + std::shared_ptr attr(const SourceRange& loc, Method & m, const std::string& field) override { return std::make_shared(Symbol::fromQualString(name+"::"+field), c10::nullopt); } @@ -187,14 +187,14 @@ inline std::shared_ptr nativeResolver(const std::string& name, Met } TORCH_API void defineMethodsInModule( - std::shared_ptr m, + const std::shared_ptr& m, const std::vector& definitions, const std::vector& resolvers, /* determines how we handle free variables in each definition*/ - std::shared_ptr self /* if non-null, the first argument to each def, 
is bound to this value */ + const std::shared_ptr& self /* if non-null, the first argument to each def, is bound to this value */ ); // same as above but parse the definitions from source -TORCH_API void defineMethodsInModule(std::shared_ptr m, const std::string& source, Resolver resolver, std::shared_ptr self); +TORCH_API void defineMethodsInModule(std::shared_ptr m, const std::string& source, const Resolver& resolver, const std::shared_ptr& self); // pack outputs of a function following python rules. If there is a single value return // a SimpleValue, otherwise pack all the values into a Tuple. @@ -209,10 +209,17 @@ struct MethodValue : public SugaredValue { std::string kind() const override { return "method"; } - std::shared_ptr call(SourceRange loc, Method & caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { - return std::make_shared(packOutputs(*caller.graph(), caller.emit_call_to(loc, method, inputs, attributes))); + std::shared_ptr call( + const SourceRange& loc, + Method& caller, + at::ArrayRef inputs, + at::ArrayRef attributes, + size_t n_binders) override { + return std::make_shared(packOutputs( + *caller.graph(), caller.emit_call_to(loc, method, inputs, attributes))); } -private: + + private: std::shared_ptr module; Method& method; @@ -243,7 +250,7 @@ TORCH_API Value* emitBuiltinCall( const SourceRange& loc, Graph& graph, Symbol name, - c10::optional self, + const c10::optional& self, at::ArrayRef inputs, at::ArrayRef attributes, // if true, emitBuiltinCall will throw an exception if this builtin does not exist, diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 1780c1f..7e6e975 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -105,7 +105,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { } // call it like a function, e.g. 
`outputs = this(inputs)` - std::shared_ptr call(SourceRange loc, Method & m, at::ArrayRef inputs_, at::ArrayRef attributes, size_t n_binders) override { + std::shared_ptr call(const SourceRange& loc, Method & m, at::ArrayRef inputs_, at::ArrayRef attributes, size_t n_binders) override { auto inputs = toValues(*m.graph(), inputs_); auto schema = getSchema(inputs.size(), n_binders); @@ -139,7 +139,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { protected: - py::object getattr(SourceRange loc, const std::string& name) { + py::object getattr(const SourceRange& loc, const std::string& name) { try { return py::getattr(self, name.c_str()); } catch (py::error_already_set& e) { @@ -151,10 +151,10 @@ protected: }; struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { - explicit PythonModuleValue(py::object mod) : PythonValue(mod) {} + explicit PythonModuleValue(py::object mod) : PythonValue(std::move(mod)) {} std::shared_ptr attr( - SourceRange loc, + const SourceRange& loc, Method& m, const std::string& field) override { py::object member = getattr(loc, field); @@ -166,11 +166,11 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { }; struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { - explicit ConstantPythonTupleValue(py::object tup) : PythonValue(tup) {} + explicit ConstantPythonTupleValue(py::object tup) : PythonValue(std::move(tup)) {} std::vector> asTuple( - SourceRange loc, + const SourceRange& loc, Method& m, - c10::optional size_hint = {}) override { + const c10::optional& size_hint = {}) override { py::tuple tup = self; std::vector> result; result.reserve(tup.size()); @@ -182,7 +182,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { } Value* asValue( - SourceRange loc, + const SourceRange& loc, Method& m) override { std::vector values; for (auto sugared_item : asTuple(loc, m)) { @@ -210,7 +210,7 @@ struct ModuleValue : public SugaredValue { } // select an attribute on it, e.g. 
`this.field` - std::shared_ptr attr(SourceRange loc, Method & m, const std::string& field) override { + std::shared_ptr attr(const SourceRange& loc, Method & m, const std::string& field) override { // workaround to make self.training work // it adds a buffer 'training' to the model if one doesn't exist // and then loads that parameter, casting it to bool @@ -251,14 +251,14 @@ struct ModuleValue : public SugaredValue { } // call module.forward - std::shared_ptr call(SourceRange loc, Method & caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { + std::shared_ptr call(const SourceRange& loc, Method & caller, at::ArrayRef inputs, at::ArrayRef attributes, size_t n_binders) override { return attr(loc, caller, "forward")->call(loc, caller, inputs, attributes, n_binders); } std::vector> asTuple( - SourceRange loc, + const SourceRange& loc, Method& m, - c10::optional size_hint = {}) override { + const c10::optional& size_hint = {}) override { py::object py_module = py::cast(module); if(!py::isinstance(py_module, py::module::import("torch.jit").attr("_ConstModuleList"))) return SugaredValue::asTuple(loc, m, size_hint); @@ -296,7 +296,7 @@ struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { } std::shared_ptr call( - SourceRange loc, + const SourceRange& loc, Method& caller, at::ArrayRef inputs, at::ArrayRef attributes, @@ -450,7 +450,7 @@ static void gatherParametersAndBuffers(std::vector & values, const namespace { -Resolver pythonResolver(ResolutionCallback rcb) { +Resolver pythonResolver(const ResolutionCallback& rcb) { return [rcb](const std::string& name, Method& m, const SourceRange& loc) -> std::shared_ptr { AutoGIL ag; @@ -466,8 +466,8 @@ Resolver pythonResolver(ResolutionCallback rcb) { FunctionSchema getSchemaWithNameAndDefaults( const SourceRange& range, - const FunctionSchema schema, - at::optional new_name, + const FunctionSchema& schema, + const at::optional& new_name, const FunctionDefaults& default_args) { std::vector new_args; for (auto& arg : schema.arguments()) { diff --git a/torch/csrc/jit/script/lexer.cpp b/torch/csrc/jit/script/lexer.cpp index 9b67a1e..e320163 100644 --- a/torch/csrc/jit/script/lexer.cpp +++ b/torch/csrc/jit/script/lexer.cpp @@ -59,7 +59,7 @@ bool SharedParserData::isBinary(int kind, int* prec) { return false; } -int stringToKind(std::string str) { +int stringToKind(const std::string& str) { static std::once_flag init_flag; static std::unordered_map str_to_kind; std::call_once(init_flag, []() { diff --git a/torch/csrc/jit/script/lexer.h b/torch/csrc/jit/script/lexer.h index 24b7a31..6a0c178 100644 --- a/torch/csrc/jit/script/lexer.h +++ b/torch/csrc/jit/script/lexer.h @@ -109,7 +109,7 @@ enum TokenKind { }; std::string kindToString(int kind); -int stringToKind(std::string str); +int stringToKind(const std::string& str); // nested hash tables that indicate char-by-char what is a valid token. 
struct TokenTrie; diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp index bd735eb..5adf873 100644 --- a/torch/csrc/jit/script/module.cpp +++ b/torch/csrc/jit/script/module.cpp @@ -15,7 +15,7 @@ void placeholderCreator(Method&) { c10::optional> try_emit_call_to( Graph& graph, - SourceRange loc, + const SourceRange& loc, Method& callee, c10::optional self, ArrayRef args, @@ -33,7 +33,7 @@ c10::optional> try_emit_call_to( auto matched_schema = tryMatchSchema( callee.getSchema(), - loc, graph, self, args, kwargs, failure_messages, conv_tensors_to_nums); + loc, graph, std::move(self), args, kwargs, failure_messages, conv_tensors_to_nums); if(!matched_schema) return c10::nullopt; @@ -48,7 +48,7 @@ c10::optional> try_emit_call_to( return inlineCallTo(graph, *callee.graph(), matched_schema->inputs); } -std::vector Method::emit_call_to(SourceRange loc, Method & callee, ArrayRef args, ArrayRef kwargs) { +std::vector Method::emit_call_to(const SourceRange& loc, Method & callee, ArrayRef args, ArrayRef kwargs) { JIT_ASSERT(!executor); std::stringstream failure_messages; if (auto result = try_emit_call_to( @@ -96,8 +96,8 @@ void Module::save(const std::string& filename) { } void Module::to_impl( - c10::optional device, - c10::optional dtype, + const c10::optional& device, + const c10::optional& dtype, bool non_blocking) { // First call `to()` on every child module. for (auto& child : modules) { diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 83afd29..4920b2f 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -92,7 +92,7 @@ struct Method { // adding any extra parameters necessary to do this call // defined here to keep details of member_input handling confined to this class - std::vector emit_call_to(SourceRange loc, Method & callee, ArrayRef args, ArrayRef kwargs); + std::vector emit_call_to(const SourceRange& loc, Method & callee, ArrayRef args, ArrayRef kwargs); // if this isn't yet defined, run its method_creator function TORCH_API void ensure_defined(); @@ -334,7 +334,7 @@ struct Module { } IValue forward(std::vector inputs) { - return get_method("forward")(inputs); + return get_method("forward")(std::move(inputs)); } void register_parameter(const std::string & name, autograd::Variable v, bool is_buffer) { @@ -356,7 +356,7 @@ struct Module { } Method& create_method(const std::string & name, std::function creator) { - std::unique_ptr method(new Method(this, name, optimize, std::make_shared(), {}, creator)); + std::unique_ptr method(new Method(this, name, optimize, std::make_shared(), {}, std::move(creator))); return *methods.insert(name, std::move(method)); } @@ -462,8 +462,8 @@ struct Module { private: void to_impl( - c10::optional device, - c10::optional dtype, + const c10::optional& device, + const c10::optional& dtype, bool non_blocking); // invariant: to ensure member_inputs of Methods stay valid, @@ -480,7 +480,7 @@ struct Module { // match the functions schema c10::optional> try_emit_call_to( Graph& graph, - SourceRange loc, + const SourceRange& loc, Method& callee, c10::optional self, ArrayRef args, diff --git a/torch/csrc/jit/script/parser.h b/torch/csrc/jit/script/parser.h index fc6d9ce..12cff19 100644 --- a/torch/csrc/jit/script/parser.h +++ b/torch/csrc/jit/script/parser.h @@ -10,7 +10,7 @@ namespace jit { namespace script { -inline Decl mergeTypesFromTypeComment(Decl decl, Decl type_annotation_decl, bool is_method) { +inline Decl mergeTypesFromTypeComment(const Decl& decl, const Decl& 
type_annotation_decl, bool is_method) { auto expected_num_annotations = decl.params().size(); if (is_method) { // `self` argument @@ -48,7 +48,7 @@ struct Parser { // of the Compound tree are in the same place. return Ident::create(t.range, t.text()); } - TreeRef createApply(Expr expr) { + TreeRef createApply(const Expr& expr) { TreeList attributes; auto range = L.cur().range; TreeList inputs; @@ -165,7 +165,7 @@ struct Parser { auto cond = parseExp(); L.expect(TK_ELSE); auto false_branch = parseExp(binary_prec); - return c(TK_IF_EXPR, range, {cond, true_branch, false_branch}); + return c(TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch}); } // parse the longest expression whose binary operators have // precedence strictly greater than 'precedence' @@ -287,7 +287,7 @@ struct Parser { } } - TreeRef parseSubscript(TreeRef value) { + TreeRef parseSubscript(const TreeRef& value) { const auto range = L.cur().range; auto subscript_exprs = parseList('[', ',', ']', &Parser::parseSubscriptExp); @@ -337,7 +337,7 @@ struct Parser { // 'first' has already been parsed since expressions can exist // alone on a line: // first[,other,lhs] = rhs - TreeRef parseAssign(Expr lhs) { + TreeRef parseAssign(const Expr& lhs) { auto op = parseAssignmentOp(); auto rhs = parseExpOrExpTuple(); L.expect(TK_NEWLINE); diff --git a/torch/csrc/jit/script/tree.h b/torch/csrc/jit/script/tree.h index 36c0c1d..aab95f5 100644 --- a/torch/csrc/jit/script/tree.h +++ b/torch/csrc/jit/script/tree.h @@ -52,7 +52,8 @@ struct Tree : std::enable_shared_from_this { const TreeRef& tree(size_t i) const { return trees().at(i); } - virtual TreeRef map(std::function fn) { + virtual TreeRef map(const std::function& fn) { + (void)fn; return shared_from_this(); } template @@ -136,7 +137,7 @@ struct Compound : public Tree { bool isAtom() const override { return false; } - TreeRef map(std::function fn) override { + TreeRef map(const std::function& fn) override { TreeList trees_; for (auto& t : trees()) { trees_.push_back(fn(t)); @@ -200,7 +201,7 @@ static inline std::ostream& operator<<(std::ostream& out, pretty_tree t_) { return out << std::endl; } -static inline std::ostream& operator<<(std::ostream& out, TreeRef t) { +static inline std::ostream& operator<<(std::ostream& out, const TreeRef& t) { return out << pretty_tree(t); } diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h index 724a6e3..270e1aa 100644 --- a/torch/csrc/jit/script/tree_views.h +++ b/torch/csrc/jit/script/tree_views.h @@ -151,7 +151,7 @@ struct List : public TreeView { T operator[](size_t i) const { return T(subtree(i)); } - TreeRef map(std::function fn) { + TreeRef map(const std::function& fn) { return tree_->map([&](TreeRef v) { return fn(T(v)); }); } static List create(const SourceRange& range, const std::vector& subtrees) { @@ -177,7 +177,7 @@ struct Maybe : public TreeView { T get() const { return T(tree_->trees().at(0)); } - TreeRef map(std::function fn) { + TreeRef map(const std::function& fn) { return tree_->map([&](TreeRef v) { return fn(T(v)); }); } static Maybe create(const SourceRange& range) { @@ -297,7 +297,7 @@ struct Param : public TreeView { explicit Param(const TreeRef& tree) : TreeView(tree) { tree_->match(TK_PARAM); } - static Param create(const SourceRange& range, const Ident& ident, const Expr& type, Maybe def) { + static Param create(const SourceRange& range, const Ident& ident, const Expr& type, const Maybe& def) { return Param(Compound::create(TK_PARAM, range, {ident, type, def})); } Ident ident() const 
{ @@ -309,7 +309,7 @@ struct Param : public TreeView { Maybe defaultValue() const { return Maybe(subtree(2)); } - Param withType(Expr typ) const { + Param withType(const Expr& typ) const { return Param::create(range(), ident(), typ, defaultValue()); } }; @@ -328,7 +328,7 @@ struct Decl : public TreeView { Maybe return_type() const { return Maybe(subtree(1)); } - static Decl create(const SourceRange& range, const List& params, Maybe return_type) { + static Decl create(const SourceRange& range, const List& params, const Maybe& return_type) { return Decl(Compound::create(TK_DECL, range, {params, return_type})); } }; @@ -338,7 +338,7 @@ struct Def : public TreeView { tree->match(TK_DEF); } Def withName(std::string new_name) const { - auto new_ident = Ident::create(name().range(), new_name); + auto new_ident = Ident::create(name().range(), std::move(new_name)); return create(range(), new_ident, decl(), statements()); } Ident name() const { @@ -553,7 +553,7 @@ struct ExprStmt : public Stmt { Expr expr() { return Expr(subtree(0)); } - static ExprStmt create(const SourceRange& range, const Expr list) { + static ExprStmt create(const SourceRange& range, const Expr& list) { return ExprStmt(Compound::create(TK_EXPR_STMT, range, {list})); } }; diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h index 7cf47a4..f216b5d 100644 --- a/torch/csrc/jit/symbolic_variable.h +++ b/torch/csrc/jit/symbolic_variable.h @@ -14,7 +14,7 @@ struct SymbolicVariable { return v; } static SymbolicVariable asNewInput(Graph & g, std::string name = "") { - return g.addInput(name); + return g.addInput(std::move(name)); } static SymbolicVariable asNewInput(Graph & g, TypePtr type) { return g.addInput()->setType(std::move(type)); @@ -239,13 +239,13 @@ struct SymbolicVariable { return create(aten::view, {*this, sizes})[0]; } SymbolicVariable view(std::vector sizes) const { - return view(insertConstant(sizes)); + return view(insertConstant(std::move(sizes))); } SymbolicVariable reshape(Value* sizes) const { return create(aten::reshape, {*this, sizes})[0]; } SymbolicVariable reshape(std::vector sizes) const { - return reshape(insertConstant(sizes)); + return reshape(insertConstant(std::move(sizes))); } SymbolicVariable addmm(SymbolicVariable mat1, SymbolicVariable mat2) const { return create(aten::addmm, {*this, mat1, mat2, insertConstant(1), insertConstant(1)})[0]; @@ -255,7 +255,7 @@ struct SymbolicVariable { } private: Value * insertConstant(IValue value) const { - return v->owningGraph()->insertConstant(value); + return v->owningGraph()->insertConstant(std::move(value)); } SymbolicVariable typeLike(SymbolicVariable other) const { if (auto other_type = other.v->type()->cast()) diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index ef15721..91b333c 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -191,7 +191,7 @@ void ArgumentStash::stashIntListElem(const std::string& arg_name, size_t size, s list_trace[idx] = prim; } -void ArgumentStash::stashValue(const std::string& arg_name, size_t idx, const Variable& var, TypePtr type) { +void ArgumentStash::stashValue(const std::string& arg_name, size_t idx, const Variable& var, const TypePtr& type) { if (!isTracing()) return; Value* ten = getValueTrace(var); diff --git a/torch/csrc/jit/tracing_state.h b/torch/csrc/jit/tracing_state.h index ef9dce5..c906932 100644 --- a/torch/csrc/jit/tracing_state.h +++ b/torch/csrc/jit/tracing_state.h @@ -91,7 +91,7 @@ struct ArgumentStash { TORCH_API static void stashValue(const 
std::string& arg_name, size_t idx, const Variable& var, - TypePtr type=nullptr); + const TypePtr& type=nullptr); static bool hasValue(const std::string& arg_name) { return stash.values.count(arg_name) > 0; @@ -123,7 +123,7 @@ TORCH_API extern const char * WARN_CONSTRUCTOR; TORCH_API extern const char * WARN_RESIZE; TORCH_API void _do_warn(const char * _reason, const char * _kind); inline void warn(const char * _reason, const char * _kind=nullptr) { - if (auto state = getTracingState()) { + if (const auto& state = getTracingState()) { if (!state->warn) return; _do_warn(_reason, _kind); } diff --git a/torch/csrc/utils/invalid_arguments.cpp b/torch/csrc/utils/invalid_arguments.cpp index 1c7f604..520e91b 100644 --- a/torch/csrc/utils/invalid_arguments.cpp +++ b/torch/csrc/utils/invalid_arguments.cpp @@ -149,7 +149,7 @@ std::unique_ptr _buildType(std::string type_name, bool is_nullable) { } std::pair _parseOption(const std::string& _option_str, - const std::unordered_map kwargs) + const std::unordered_map& kwargs) { if (_option_str == "no arguments") return std::pair(Option(false, false), _option_str); diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index f4338eb..74c2288 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -12,6 +12,7 @@ #include #include +#include namespace py = pybind11; @@ -33,7 +34,7 @@ struct type_caster { } static handle - cast(at::Tensor src, return_value_policy /* policy */, handle /* parent */) { + cast(const at::Tensor& src, return_value_policy /* policy */, handle /* parent */) { if (!src.is_variable()) { throw std::runtime_error( "Expected tensor's dynamic type to be Variable, not Tensor"); @@ -55,7 +56,7 @@ public: } } static handle cast(torch::autograd::Variable src, return_value_policy /* policy */, handle /* parent */) { - return handle(THPVariable_Wrap(src)); + return handle(THPVariable_Wrap(std::move(src))); } }; diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index bdb9b8c..7a1fcd0 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -56,34 +56,34 @@ void maybe_initialize_cuda(const Device device) { Tensor dispatch_zeros(const Type& type, optional device, IntList sizes) { maybe_initialize_cuda(type); AutoNoGIL no_gil; - return torch::zeros(sizes, type.options(device)); + return torch::zeros(sizes, type.options(std::move(device))); } Tensor dispatch_ones(const Type& type, optional device, IntList sizes) { maybe_initialize_cuda(type); AutoNoGIL no_gil; - return torch::ones(sizes, type.options(device)); + return torch::ones(sizes, type.options(std::move(device))); } Tensor dispatch_full(const Type& type, Scalar fill_value, optional device, IntList sizes) { maybe_initialize_cuda(type); AutoNoGIL no_gil; - return torch::full(sizes, fill_value, type.options(device)); + return torch::full(sizes, fill_value, type.options(std::move(device))); } Tensor new_with_sizes(const Type& type, optional device, IntList sizes) { maybe_initialize_cuda(type); AutoNoGIL no_gil; - return torch::empty(sizes, type.options(device)); + return torch::empty(sizes, type.options(std::move(device))); } Tensor new_with_storage(const Type& type, Storage storage) { auto tensor = at::empty({}, type.options()); - tensor.set_(storage); + tensor.set_(std::move(storage)); return tensor; } -Tensor new_with_tensor(const Type& type, Tensor other) { +Tensor new_with_tensor(const Type& type, const Tensor& other) { if (other.type() != type) { throw TypeError("expected %s (got %s)", type.toString(), 
other.type().toString()); } @@ -239,7 +239,7 @@ Tensor new_from_data_copy( const Type& type, c10::optional device, PyObject* data) { - return internal_new_from_data(type, device, data, true, true, false); + return internal_new_from_data(type, std::move(device), data, true, true, false); } Tensor legacy_new_from_sequence( @@ -249,7 +249,7 @@ Tensor legacy_new_from_sequence( if (!PySequence_Check(data)) { throw TypeError("new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name); } - return legacy_new_from_data(type, device, data); + return legacy_new_from_data(type, std::move(device), data); } void check_legacy_ctor_device(const Type& type, c10::optional device) { @@ -453,7 +453,7 @@ Tensor legacy_new_from_data( const Type& type, c10::optional device, PyObject* data) { - return internal_new_from_data(type, device, data, false, false, false); + return internal_new_from_data(type, std::move(device), data, false, false, false); } Tensor sparse_coo_tensor_ctor(const Type& default_type, PyObject* args, PyObject* kwargs) { -- 2.7.4
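For readers unfamiliar with the check, the conversions in this patch all follow the two idioms that performance-unnecessary-value-param suggests: a parameter that is only read becomes a const reference, and a parameter that is stored or forwarded keeps its by-value form and is moved into place. The sketch below is a generic, hedged illustration of those two idioms, not code from this patch; countWords, Label, and the other names are made up for the example.

    #include <string>
    #include <utility>
    #include <vector>

    // Read-only parameter: take it by const reference so no copy is made.
    // (Illustrative only; this function does not appear in the patch.)
    static std::size_t countWords(const std::string& text) {
      std::size_t words = 0;
      bool in_word = false;
      for (char c : text) {
        if (c == ' ') {
          in_word = false;
        } else if (!in_word) {
          in_word = true;
          ++words;
        }
      }
      return words;
    }

    // Sink parameter: keep it by value and std::move it into the member,
    // so callers passing rvalues pay a move instead of a copy.
    struct Label {
      explicit Label(std::string name) : name_(std::move(name)) {}
      std::string name_;
    };

    int main() {
      std::vector<Label> labels;
      labels.emplace_back(std::string("input"));  // rvalue: moved, no copy
      const std::string s = "two words";
      labels.emplace_back(s);                     // lvalue: exactly one copy
      return countWords(s) == 2 ? 0 : 1;
    }

The changes above have the same shape: pass entry points such as PropagateInputShapes and RemoveInplaceOps only walk the graph, so they now take const std::shared_ptr<Graph>&, while code that consumes its argument, such as the DeadCodeEliminator and ShapePropagator constructors or AliasAndIValue, keeps the by-value parameter and hands it off with std::move.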