From 7a61306031648fc4e62fe9e5868893ebf9ae6fb6 Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Fri, 14 Dec 2018 13:30:35 -0800
Subject: [PATCH] Enable all clang-tidy performance checks (#15198)

Summary:
This PR enables the last set of clang-tidy checks planned for our codebase: the remaining performance-related checks. Most fixes change `auto` to `const auto&` where an unnecessary copy was being made, or add a `reserve()` call before a loop that does repeated `push_back()`. A few fixes also replace calls to `std::string::find` that pass a single-character string literal with the single-char overload; the string-literal form uses a slower substring search algorithm meant for longer patterns. A small standalone sketch of these patterns appears after the patch.

![image](https://user-images.githubusercontent.com/6429851/49978940-adc1a780-ff01-11e8-99da-a4e431361f07.png)

ezyang apaszke

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15198

Differential Revision: D13468797

Pulled By: goldsborough

fbshipit-source-id: 2bed1ea1c7c162b7f3e0e1026f17125e88c4d5b2
---
 .clang-tidy | 3 ++-
 torch/csrc/Exceptions.h | 2 +-
 torch/csrc/api/include/torch/nn/module.h | 8 ++++----
 torch/csrc/api/src/data/datasets/mnist.cpp | 4 ++--
 torch/csrc/api/src/nn/module.cpp | 20 ++++++++++----------
 torch/csrc/api/src/nn/modules/batchnorm.cpp | 3 +--
 torch/csrc/api/src/nn/modules/embedding.cpp | 2 +-
 torch/csrc/api/src/nn/modules/linear.cpp | 2 +-
 torch/csrc/api/src/nn/modules/rnn.cpp | 12 +++++-------
 torch/csrc/api/src/optim/serialize.cpp | 1 +
 torch/csrc/autograd/VariableTypeUtils.h | 2 +-
 torch/csrc/autograd/engine.cpp | 2 +-
 torch/csrc/autograd/python_function.cpp | 5 +----
 torch/csrc/jit/batched/BatchTensor.cpp | 2 +-
 torch/csrc/jit/constants.cpp | 8 ++++----
 torch/csrc/jit/custom_operator.h | 2 +-
 torch/csrc/jit/fuser/compiler.cpp | 2 +-
 torch/csrc/jit/graph_executor.cpp | 2 +-
 torch/csrc/jit/init.cpp | 4 ++--
 torch/csrc/jit/interpreter.h | 2 +-
 torch/csrc/jit/operator.cpp | 4 ++--
 torch/csrc/jit/passes/alias_analysis.cpp | 2 +-
 torch/csrc/jit/passes/graph_fuser.cpp | 1 +
 torch/csrc/jit/passes/python_print.cpp | 2 +-
 torch/csrc/jit/passes/specialize_undef.cpp | 2 +-
 torch/csrc/jit/python_ir.cpp | 2 +-
 torch/csrc/jit/python_tracer.cpp | 4 ++--
 torch/csrc/jit/register_prim_ops.cpp | 3 ++-
 torch/csrc/jit/script/compiler.cpp | 10 +++++-----
 torch/csrc/jit/script/compiler.h | 2 +-
 torch/csrc/jit/script/init.cpp | 3 ++-
 torch/csrc/jit/script/module.h | 2 +-
 torch/csrc/jit/script/python_tree_views.cpp | 10 +++++-----
 torch/csrc/jit/script/tree.h | 6 +++---
 torch/csrc/tensor/python_tensor.cpp | 2 +-
 torch/csrc/utils/python_arg_parser.cpp | 2 +-
 torch/csrc/utils/tensor_new.cpp | 6 +++---
 37 files changed, 75 insertions(+), 76 deletions(-)

diff --git a/.clang-tidy b/.clang-tidy
index fe9b9ac..8f22f83 100644
--- a/.clang-tidy
+++ b/.clang-tidy
@@ -24,7 +24,8 @@ Checks: '
 ,-modernize-use-auto
 ,-modernize-use-default-member-init
 ,-modernize-use-using
-,performance-unnecessary-value-param
+,performance-*
+,-performance-noexcept-move-constructor
 '
 WarningsAsErrors: '*'
 HeaderFilterRegex: 'torch/csrc/.*'
diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h
index c0f3b6b..5471999 100644
--- a/torch/csrc/Exceptions.h
+++ b/torch/csrc/Exceptions.h
@@ -50,7 +50,7 @@ struct python_error : public std::exception {
   }
 
   python_error(python_error&& other) {
-    type = std::move(other.type);
+    type = other.type;
     value = other.value;
     traceback = other.traceback;
     other.type = nullptr;
diff --git a/torch/csrc/api/include/torch/nn/module.h
b/torch/csrc/api/include/torch/nn/module.h index 43c6231..bb3fa9e 100644 --- a/torch/csrc/api/include/torch/nn/module.h +++ b/torch/csrc/api/include/torch/nn/module.h @@ -150,7 +150,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// \endrst void apply( const NamedModuleApplyFunction& function, - std::string name_prefix = std::string()); + const std::string& name_prefix = std::string()); /// Applies the `function` to the `Module` and recursively to every submodule. /// The function must accept a `const std::string&` for the key of the module, @@ -167,7 +167,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// \endrst void apply( const ConstNamedModuleApplyFunction& function, - std::string name_prefix = std::string()) const; + const std::string& name_prefix = std::string()) const; /// Applies the `function` to the `Module` and recursively to every submodule. /// The function must accept a `const std::shared_ptr&`. @@ -198,7 +198,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// \endrst void apply( const NamedModulePointerApplyFunction& function, - std::string name_prefix = std::string()) const; + const std::string& name_prefix = std::string()) const; /// Returns the parameters of this `Module` and if `recurse` is true, also /// recursively of every submodule. @@ -243,7 +243,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// stored in a `shared_ptr`. /// \endrst OrderedDict> named_modules( - std::string name_prefix = std::string(), + const std::string& name_prefix = std::string(), bool include_self = true) const; /// Returns the direct submodules of this `Module`. diff --git a/torch/csrc/api/src/data/datasets/mnist.cpp b/torch/csrc/api/src/data/datasets/mnist.cpp index aa92cda..d77b457 100644 --- a/torch/csrc/api/src/data/datasets/mnist.cpp +++ b/torch/csrc/api/src/data/datasets/mnist.cpp @@ -51,11 +51,11 @@ uint32_t expect_int32(std::ifstream& stream, uint32_t expected) { return value; } -std::string join_paths(std::string head, std::string tail) { +std::string join_paths(std::string head, const std::string& tail) { if (head.back() != '/') { head.push_back('/'); } - head += std::move(tail); + head += tail; return head; } diff --git a/torch/csrc/api/src/nn/module.cpp b/torch/csrc/api/src/nn/module.cpp index 5a38195..de526b2 100644 --- a/torch/csrc/api/src/nn/module.cpp +++ b/torch/csrc/api/src/nn/module.cpp @@ -101,26 +101,26 @@ void Module::apply(const ConstModuleApplyFunction& function) const { void Module::apply( const NamedModuleApplyFunction& function, - std::string name_prefix) { + const std::string& name_prefix) { function(/*name=*/name_prefix, *this); apply_to_submodules( [&function]( const std::string& name, const std::shared_ptr& module) { function(name, *module); }, - std::move(name_prefix)); + name_prefix); } void Module::apply( const ConstNamedModuleApplyFunction& function, - std::string name_prefix) const { + const std::string& name_prefix) const { function(/*name=*/name_prefix, *this); apply_to_submodules( [&function]( const std::string& name, const std::shared_ptr& module) { function(name, *module); }, - std::move(name_prefix)); + name_prefix); } void Module::apply(const ModulePointerApplyFunction& function) const { @@ -133,10 +133,10 @@ void Module::apply(const ModulePointerApplyFunction& function) const { void Module::apply( const NamedModulePointerApplyFunction& function, - std::string name_prefix) const { + const std::string& name_prefix) const { function( /*name=*/name_prefix, 
shared_from_this_checked()); - apply_to_submodules(function, std::move(name_prefix)); + apply_to_submodules(function, name_prefix); } std::vector Module::parameters(bool recurse) const { @@ -199,7 +199,7 @@ std::vector> Module::modules(bool include_self) const { } OrderedDict> Module::named_modules( - std::string name_prefix, + const std::string& name_prefix, bool include_self) const { OrderedDict> result; if (include_self) { @@ -208,14 +208,14 @@ OrderedDict> Module::named_modules( const std::string& key, const std::shared_ptr& module) { result.insert(key, module); }, - std::move(name_prefix)); + name_prefix); } else { apply_to_submodules( [&result]( const std::string& key, const std::shared_ptr& module) { result.insert(key, module); }, - std::move(name_prefix)); + name_prefix); } return result; } @@ -329,7 +329,7 @@ void Module::apply_to_submodules( for (const auto& child : children_) { auto qualified_name = join_name(name_prefix, child.key()); function(qualified_name, child.value()); - child.value()->apply_to_submodules(function, std::move(qualified_name)); + child.value()->apply_to_submodules(function, qualified_name); } } diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp index 43f1be1..cb1e91d 100644 --- a/torch/csrc/api/src/nn/modules/batchnorm.cpp +++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp @@ -13,8 +13,7 @@ namespace torch { namespace nn { BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {} -BatchNormImpl::BatchNormImpl(BatchNormOptions options) - : options(std::move(options)) { +BatchNormImpl::BatchNormImpl(BatchNormOptions options) : options(options) { reset(); } diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp index 7797561..be972e5 100644 --- a/torch/csrc/api/src/nn/modules/embedding.cpp +++ b/torch/csrc/api/src/nn/modules/embedding.cpp @@ -14,7 +14,7 @@ EmbeddingOptions::EmbeddingOptions(int64_t count, int64_t dimension) : count_(count), dimension_(dimension) {} EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options) - : options(std::move(options)) { + : options(options) { reset(); } diff --git a/torch/csrc/api/src/nn/modules/linear.cpp b/torch/csrc/api/src/nn/modules/linear.cpp index e2766ab..f9ec0d1 100644 --- a/torch/csrc/api/src/nn/modules/linear.cpp +++ b/torch/csrc/api/src/nn/modules/linear.cpp @@ -10,7 +10,7 @@ namespace torch { namespace nn { LinearOptions::LinearOptions(int64_t in, int64_t out) : in_(in), out_(out) {} -LinearImpl::LinearImpl(LinearOptions options) : options(std::move(options)) { +LinearImpl::LinearImpl(LinearOptions options) : options(options) { reset(); } diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp index 8058cc5..632e4d4 100644 --- a/torch/csrc/api/src/nn/modules/rnn.cpp +++ b/torch/csrc/api/src/nn/modules/rnn.cpp @@ -131,7 +131,7 @@ RNNOutput RNNImplBase::generic_forward( } Tensor output, new_state; std::tie(output, new_state) = function( - std::move(input), + input, std::move(state), flat_weights_, options.with_bias_, @@ -208,12 +208,12 @@ RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) { case RNNActivation::ReLU: return generic_forward( static_cast(&torch::rnn_relu), - std::move(input), + input, std::move(state)); case RNNActivation::Tanh: return generic_forward( static_cast(&torch::rnn_tanh), - std::move(input), + input, std::move(state)); default: AT_ERROR("Unhandled RNN activation function!"); @@ -244,7 +244,7 @@ RNNOutput LSTMImpl::forward(const 
Tensor& input, Tensor state) { } Tensor output, hidden_state, cell_state; std::tie(output, hidden_state, cell_state) = torch::lstm( - std::move(input), + input, {state[0], state[1]}, flat_weights_, options.with_bias_, @@ -266,9 +266,7 @@ GRUImpl::GRUImpl(const GRUOptions& options) RNNOutput GRUImpl::forward(const Tensor& input, Tensor state) { return generic_forward( - static_cast(&torch::gru), - std::move(input), - std::move(state)); + static_cast(&torch::gru), input, std::move(state)); } } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/optim/serialize.cpp b/torch/csrc/api/src/optim/serialize.cpp index 808b604..02d1000 100644 --- a/torch/csrc/api/src/optim/serialize.cpp +++ b/torch/csrc/api/src/optim/serialize.cpp @@ -17,6 +17,7 @@ void serialize( const std::string& key, const std::vector& steps) { std::vector tensors; + tensors.reserve(steps.size()); for (const auto& step : steps) { tensors.push_back(torch::tensor(static_cast(step))); } diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index 5025370..b37c36d 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -153,7 +153,7 @@ inline Tensor as_variable(Tensor tensor) { inline std::vector as_variable(TensorList tl) { return fmap(tl, [](const Tensor& t) -> Tensor { - return make_variable(std::move(t), /*requires_grad=*/false); + return make_variable(t, /*requires_grad=*/false); }); } diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index b2f5da5..29149e4 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -408,7 +408,7 @@ static variable_list call_function(FunctionTask& task) { if(has_post_hooks){ // NOLINTNEXTLINE(bugprone-use-after-move) - return call_post_hooks(fn, std::move(outputs), std::move(inputs)); + return call_post_hooks(fn, std::move(outputs), inputs); } return outputs; } diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 9a9d8c6..061f85f 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -585,10 +585,7 @@ static Node* _trace_pre_record( Py_INCREF(op_obj); auto pyobj = THPObjectPtr(op_obj); return jit::tracer::preRecordPythonTrace( - std::move(pyobj), - std::move(arg_types), - input_vars, - std::move(scalar_args)); + std::move(pyobj), arg_types, input_vars, std::move(scalar_args)); } static void _trace_post_record( diff --git a/torch/csrc/jit/batched/BatchTensor.cpp b/torch/csrc/jit/batched/BatchTensor.cpp index 625166d..d514cc4 100644 --- a/torch/csrc/jit/batched/BatchTensor.cpp +++ b/torch/csrc/jit/batched/BatchTensor.cpp @@ -31,7 +31,7 @@ BatchTensor::BatchTensor(const std::vector& datalist, at::Tensor dim sizes[0] = bs; mask_sizes[0] = bs; for(int64_t i = 1; i < dims.size(0) + 1; i++){ - for(auto x : datalist){ + for(const auto& x : datalist){ sizes[i] = std::max(sizes[i], x.size(i)); } mask_sizes[i] = *dims[i - 1].data() ? 
sizes[i] : 1; diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp index 4912d58..c1d5884 100644 --- a/torch/csrc/jit/constants.cpp +++ b/torch/csrc/jit/constants.cpp @@ -102,19 +102,19 @@ RegisterOperators reg({ return 0; }; } else if(type->isSubtypeOf(ListType::ofInts())) { - auto is = node->is(attr::value); + const auto& is = node->is(attr::value); return [is](Stack& stack) { push(stack, is); return 0; }; } else if(type->isSubtypeOf(ListType::ofBools())) { - auto bs = node->is(attr::value); + const auto& bs = node->is(attr::value); return [bs](Stack& stack) { push(stack, bs); return 0; }; } else if(type->isSubtypeOf(ListType::ofTensors())) { - auto ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor { + const auto& ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor { return autograd::make_variable(t); }); return [ts](Stack& stack) { @@ -122,7 +122,7 @@ RegisterOperators reg({ return 0; }; } else if (type == StringType::get()) { - auto s = node->s(attr::value); + const auto& s = node->s(attr::value); return [s](Stack& stack) { push(stack, s); return 0; diff --git a/torch/csrc/jit/custom_operator.h b/torch/csrc/jit/custom_operator.h index 29d518b..051a7bf 100644 --- a/torch/csrc/jit/custom_operator.h +++ b/torch/csrc/jit/custom_operator.h @@ -63,7 +63,7 @@ Node* getTracedNode( const std::tuple& tuple) { auto symbol = Symbol::fromQualString(schema.name()); const auto& graph = tracer::getTracingState()->graph; - Node* node = graph->create(std::move(symbol), /*num_outputs=*/0); + Node* node = graph->create(symbol, /*num_outputs=*/0); tracer::recordSourceLocation(node); // Hack to call addInputs for the parameter pack in a sequenced fashion. diff --git a/torch/csrc/jit/fuser/compiler.cpp b/torch/csrc/jit/fuser/compiler.cpp index dadf9cc..74d6dc3 100644 --- a/torch/csrc/jit/fuser/compiler.cpp +++ b/torch/csrc/jit/fuser/compiler.cpp @@ -51,7 +51,7 @@ int debugFuser() { // If the given node is used once by a chunk node, returns that node. // Returns nullptr otherwise. 
static const Node* usedInFusedChunk(const Value* input) { - const auto uses = input->uses(); + const auto& uses = input->uses(); if (uses.size() == 1) { const Node *user = uses[0].user; if (user->kind() == prim::ConstantChunk) { diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index a59bc51..07b3019 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -503,7 +503,7 @@ private: } void runTraced(Stack & stack) { - auto state = tracer::getTracingState(); + const auto& state = tracer::getTracingState(); auto inputs = last(stack, num_inputs); auto input_values = fmap(inputs, [](const IValue & v) { return tracer::getNestedValueTrace(v); diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index e5b05d0..77c9d54 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -300,7 +300,7 @@ void initJITBindings(PyObject *module) { m.def("_jit_get_operation", [](const std::string& qualified_name) { try { auto symbol = Symbol::fromQualString(qualified_name); - auto operations = getAllOperatorsFor(std::move(symbol)); + auto operations = getAllOperatorsFor(symbol); AT_CHECK(!operations.empty(), "No such operator ", qualified_name); AT_CHECK( operations.size() == 1, @@ -338,7 +338,7 @@ void initJITBindings(PyObject *module) { }); m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) { auto symbol = Symbol::fromQualString(qualified_name); - auto operations = getAllOperatorsFor(std::move(symbol)); + auto operations = getAllOperatorsFor(symbol); return fmap(operations, [](const std::shared_ptr& op) { return op->schema(); }); diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h index d969f0c..facbd61 100644 --- a/torch/csrc/jit/interpreter.h +++ b/torch/csrc/jit/interpreter.h @@ -73,7 +73,7 @@ struct Suspend : public std::exception { struct InterpreterContinuation { InterpreterContinuation(InterpreterState state_, Stack stack_) - : state(std::move(state_)), stack(std::move(stack_)) {} + : state(state_), stack(std::move(stack_)) {} void operator()() { state.runAsync(stack); diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 00a5027..dbe7e64 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -253,7 +253,7 @@ struct SchemaParser { n = "-" + L.expect(TK_NUMBER).text(); else n = L.expect(TK_NUMBER).text(); - if(kind == TypeKind::FloatType || n.find(".") != std::string::npos || n.find("e") != std::string::npos) { + if(kind == TypeKind::FloatType || n.find('.') != std::string::npos || n.find('e') != std::string::npos) { return std::stod(n); } else { int64_t v = std::stoll(n); @@ -405,7 +405,7 @@ private: // XXX - caller must be holding lock void registerPendingOperators() { - for(auto op : to_register) { + for(const auto& op : to_register) { Symbol sym = Symbol::fromQualString(op->schema().name()); operators[sym].push_back(op); operators_by_sig[canonicalSchemaString(op->schema())] = op; diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 0f931a2..2cd25ef 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -531,7 +531,7 @@ void AliasDb::addAlias(const Value* value, Symbol alias) { valueToAlias_[value].addSet(alias); } else { AliasInfo aliasInfo; - aliasInfo.addSet(std::move(alias)); + aliasInfo.addSet(alias); valueToAlias_.insert({value, std::move(aliasInfo)}); } } diff --git a/torch/csrc/jit/passes/graph_fuser.cpp 
b/torch/csrc/jit/passes/graph_fuser.cpp index 8f43048..ae79808 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -973,6 +973,7 @@ void PeepholeOptimizeShapeExpressions(Block * block) { } if (unique_to_value.size() != node->inputs().size()) { std::vector inputs; + inputs.reserve(unique_to_value.size()); for (auto & entry : unique_to_value) { inputs.push_back(entry.second); } diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 8f23952..3930b2d 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -637,7 +637,7 @@ struct PythonPrintPass { } else if(v.isTensorList()) { stmt << "["; const char* delim = ""; - for(auto t : v.toTensorListRef()) { + for(const auto& t : v.toTensorListRef()) { stmt << delim << "CONSTANTS.c" << getOrAddTensorConstant(t); delim = ", "; } diff --git a/torch/csrc/jit/passes/specialize_undef.cpp b/torch/csrc/jit/passes/specialize_undef.cpp index 4c3f94d..3e258cf 100644 --- a/torch/csrc/jit/passes/specialize_undef.cpp +++ b/torch/csrc/jit/passes/specialize_undef.cpp @@ -15,7 +15,7 @@ void specializeUndef(Graph & g) { std::unordered_map state; for (Value* input : g.inputs()) { - auto tp = input->type(); + const auto& tp = input->type(); if (tp->isSubtypeOf(UndefinedTensorType::get())) { state[input] = State::Undefined; } else if (tp->isSubtypeOf(DynamicType::get())) { diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 671bad8..c34335a 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -467,7 +467,7 @@ void initPythonIRBindings(PyObject * module_) { .def(py::init([](std::vector a){ return TupleType::create(a); })) .def("elements", [](TupleType &self){ std::vector types; - for (auto type : self.elements()) { + for (const auto& type : self.elements()) { types.push_back(type); } return types; diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp index e5619a9..a83b645 100644 --- a/torch/csrc/jit/python_tracer.cpp +++ b/torch/csrc/jit/python_tracer.cpp @@ -165,7 +165,7 @@ void initPythonTracerBindings(PyObject* module) { return setValueTrace(var, value); }); m.def("_tracer_set_get_unique_name_fn", [](py::function func) { - auto tracing_state = getTracingState(); + const auto& tracing_state = getTracingState(); JIT_ASSERT(tracing_state); tracing_state->lookup_var_name_fn = [func](const Variable& var) -> std::string { AutoGIL ag; @@ -173,7 +173,7 @@ void initPythonTracerBindings(PyObject* module) { }; }); m.def("_tracer_set_force_outplace", [](bool force_outplace) { - auto tracing_state = getTracingState(); + const auto& tracing_state = getTracingState(); JIT_ASSERT(tracing_state); tracing_state->force_outplace = force_outplace; }); diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 10e91a5..128145c 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -401,7 +401,7 @@ RegisterOperators reg({ return [=](Stack& stack) { bool result = false; for (const IValue& t : last(stack, num_inputs)) { - if (std::move(t).toTensor().defined()) { + if (t.toTensor().defined()) { result = true; break; } @@ -1135,6 +1135,7 @@ Operator( \ at::Tensor t; pop(stack, t); std::vector elems; + elems.reserve(t.size(0)); for(int i = 0; i < t.size(0); i++){ elems.push_back(*t[i].data()); } diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 5dd09fa..c46f4e1 100644 
--- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -419,7 +419,7 @@ static inline bool isIntOrFloatUsedAsList( const Value* value, const Argument& arg) { // Look for int[N] or float[N] - auto v_type = value->type(); + const auto& v_type = value->type(); if (v_type != FloatType::get() && v_type != IntType::get()) return false; auto arg_type = unwrapOptional(arg.type()); @@ -1054,7 +1054,7 @@ private: << " return (" << schema.returns().size() << ") does not match" << " the number of returns from the function (" << results.size() << ")!"; } - auto range = return_stmt.range(); + const auto& range = return_stmt.range(); size_t return_type_idx = 0; for (auto r : results) { TypePtr type = DynamicType::get(); @@ -1506,7 +1506,7 @@ private: auto instances = sv->asTuple(stmt.range(), method); const std::string& target_name = target.name(); pushFrame(environment_stack->block()); - for(auto inst : instances) { + for(const auto& inst : instances) { environment_stack->setSugaredVar(itrs[0].range(), target_name, inst); emitStatements(body); } @@ -1988,7 +1988,7 @@ private: if(maybe_unpack && tree->kind() == TK_STARRED) { auto starred = Starred(tree); auto entries = emitSugaredExpr(starred.expr(), 1)->asTuple(starred.range(), method); - for(auto entry : entries) { + for(const auto& entry : entries) { values.emplace_back( tree->range(), entry->asValue(starred.range(), method)); } @@ -2639,7 +2639,7 @@ void defineMethodsInModule(const std::shared_ptr& m, const std::vector methods; std::unordered_map function_table; - for(Def def : definitions) { + for(const Def& def : definitions) { const std::string& name = def.name().name(); auto resolver = *resolver_it++; JIT_ASSERT(resolver); diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h index bb98996..e3e28aa 100644 --- a/torch/csrc/jit/script/compiler.h +++ b/torch/csrc/jit/script/compiler.h @@ -121,7 +121,7 @@ private: struct TORCH_API BuiltinFunction : public SugaredValue { BuiltinFunction(Symbol symbol, c10::optional self) - : symbol(std::move(symbol)), self(std::move(self)) {} + : symbol(symbol), self(std::move(self)) {} // The symbol of the function (e.g. `aten::relu`). 
Symbol symbol; diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 7e6e975..eafe330 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -185,7 +185,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue { const SourceRange& loc, Method& m) override { std::vector values; - for (auto sugared_item : asTuple(loc, m)) { + for (const auto& sugared_item : asTuple(loc, m)) { values.push_back(sugared_item->asValue(loc, m)); } auto node = m.graph()->createTuple(values); @@ -532,6 +532,7 @@ void initJitScriptBindings(PyObject* module) { const std::vector& rcbs, const std::vector& defaults) { std::vector resolvers; + resolvers.reserve(rcbs.size()); for(auto & callback : rcbs) { resolvers.push_back(pythonResolver(callback)); } diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 4920b2f..ed3fade 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -123,7 +123,7 @@ struct Method { stack.push_back(*inp); } const auto size = stack.size(); - setInputTypes(*retval, ArgumentSpec(with_grad, std::move(stack), size)); + setInputTypes(*retval, ArgumentSpec(with_grad, stack, size)); PropagateInputShapes(retval); return retval; } diff --git a/torch/csrc/jit/script/python_tree_views.cpp b/torch/csrc/jit/script/python_tree_views.cpp index 4833c34..8d2c39d 100644 --- a/torch/csrc/jit/script/python_tree_views.cpp +++ b/torch/csrc/jit/script/python_tree_views.cpp @@ -107,10 +107,10 @@ void initTreeViewBindings(PyObject *module) { .def(py::init([](const Ident& name, Decl decl, std::vector body) { - auto r = name.range(); + const auto& r = name.range(); return Def::create(r, name, - std::move(decl), + decl, wrap_list(r, std::move(body))); })); py::class_(m, "Decl") @@ -127,7 +127,7 @@ void initTreeViewBindings(PyObject *module) { })); py::class_(m, "AugAssign") .def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) { - auto r = lhs.range(); + const auto& r = lhs.range(); auto kind = AugAssignKind(Compound::create(stringToKind(kind_str), r, {})); return AugAssign::create(r, lhs, kind, rhs); })); @@ -198,13 +198,13 @@ void initTreeViewBindings(PyObject *module) { })); py::class_(m, "Apply") .def(py::init([](const Expr& expr, std::vector args, std::vector kwargs) { - auto r = expr.range(); + const auto& r = expr.range(); return Apply::create(expr.range(), expr, wrap_list(r, std::move(args)), wrap_list(r, std::move(kwargs))); })); py::class_(m, "Select") .def(py::init([](const Expr& expr, const Ident& field) { - auto r = expr.range(); + const auto& r = expr.range(); return Select::create(expr.range(), expr, field); })); py::class_(m, "TernaryIf") diff --git a/torch/csrc/jit/script/tree.h b/torch/csrc/jit/script/tree.h index aab95f5..787c0df 100644 --- a/torch/csrc/jit/script/tree.h +++ b/torch/csrc/jit/script/tree.h @@ -111,7 +111,7 @@ struct String : public Tree { }; static SourceRange mergeRanges(SourceRange c, const TreeList& others) { - for (auto t : others) { + for (const auto& t : others) { if (t->isAtom()) continue; size_t s = std::min(c.start(), t->range().start()); @@ -171,7 +171,7 @@ struct pretty_tree { break; default: out << "(" << kindToString(t->kind()); - for (auto e : t->trees()) { + for (const auto& e : t->trees()) { out << " " << get_flat(e); } out << ")"; @@ -188,7 +188,7 @@ struct pretty_tree { } std::string k = kindToString(t->kind()); out << "(" << k; - for (auto e : t->trees()) { + for (const auto& e : t->trees()) { out << "\n" << 
std::string(indent + 2, ' '); print(out, e, indent + 2); } diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index 04f05b5..54bda65 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -290,7 +290,7 @@ static void py_bind_tensor_types(const std::vector& tensor_types) for (auto& tensor_type : tensor_types) { auto name = std::string(tensor_type.name); - auto idx = name.rfind("."); + auto idx = name.rfind('.'); auto type_name = name.substr(idx + 1); auto module_name = name.substr(0, idx); diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 187b9ec..8a13094 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -286,7 +286,7 @@ FunctionSignature::FunctionSignature(const std::string& fmt) while (!done) { auto offset = fmt.find(", ", last_offset); if (offset == std::string::npos) { - offset = fmt.find(")", last_offset); + offset = fmt.find(')', last_offset); done = true; next_offset = offset + 1; } else { diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 7a1fcd0..6721ab7 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -209,7 +209,7 @@ Tensor internal_new_from_data( auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(torch::getDeviceType(type))); AutoNoGIL no_gil; maybe_initialize_cuda(device); - return var.to(device, scalar_type, /*blocking=*/false, /*copy=*/copy_variables); + return var.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/copy_variables); } #ifdef USE_NUMPY @@ -219,7 +219,7 @@ Tensor internal_new_from_data( auto device = device_opt.has_value() ? *device_opt : at::Device(type.device_type()); AutoNoGIL no_gil; maybe_initialize_cuda(device); - return tensor.to(device, scalar_type, /*blocking=*/false, /*copy=*/copy_numpy); + return tensor.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy); } #endif @@ -232,7 +232,7 @@ Tensor internal_new_from_data( auto device = device_opt.has_value() ? *device_opt : at::Device(torch::getDeviceType(type)); AutoNoGIL no_gil; maybe_initialize_cuda(device); - return tensor.to(device, scalar_type, /*blocking=*/false, /*copy=*/false); + return tensor.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/false); } Tensor new_from_data_copy( -- 2.7.4
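
For reference, here is a minimal, self-contained C++ sketch of the three patterns the newly enabled `performance-*` checks target. The code and names (`table`, `sizes`, `name`) are illustrative toy examples written for this note, not taken from the PyTorch sources.

```cpp
// Toy, self-contained illustration (not from the PyTorch tree) of the three
// performance patterns addressed in this patch.
#include <cstddef>
#include <map>
#include <string>
#include <vector>

int main() {
  std::map<std::string, std::vector<int>> table = {
      {"alpha", {1, 2, 3}}, {"beta", {4, 5}}};

  // 1. performance-for-range-copy: bind loop variables by const reference;
  //    a plain `auto entry` would copy the key string and the vector on
  //    every iteration.
  for (const auto& entry : table) {
    (void)entry;
  }

  // 2. performance-inefficient-vector-operation: call reserve() before a
  //    loop of push_back() calls to avoid repeated reallocations.
  std::vector<std::size_t> sizes;
  sizes.reserve(table.size());
  for (const auto& entry : table) {
    sizes.push_back(entry.second.size());
  }

  // 3. performance-faster-string-find: pass a single char to find()/rfind();
  //    the string-literal overload runs a general substring search.
  std::string name = "torch.FloatTensor";
  auto idx = name.rfind('.');  // preferred over name.rfind(".")
  auto module_name = name.substr(0, idx);
  (void)module_name;
  return 0;
}
```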