clang format world (#15524)
authorMichael Suo <suo@fb.com>
Wed, 26 Dec 2018 14:52:25 +0000 (06:52 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 26 Dec 2018 14:55:01 +0000 (06:55 -0800)
Summary:
The PR clang-formats everything in `torch/csrc/jit/` and adds it to the pre-commit hook.

Here is a list of non-mechanical changes:
- I went over each file and fixed up whenever I could tell that clang-format was clobbering comment formatting.
- Made the macros in register_prim_ops a little more clang-format friendly by omitting trailing commas
- Refactored autodiff.cpp to use a helper class with explicit state rather than a bunch of capturing lambdas
- Small improvements to the precommit hook clang-format
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15524

Differential Revision: D13547989

Pulled By: suo

fbshipit-source-id: 3ff1541bb06433ccfe6de6e33f29227a2b5bb493

169 files changed:
test/cpp/jit/tests.h
tools/clang_format.py
torch/csrc/jit/alias_info.h
torch/csrc/jit/argument_spec.h
torch/csrc/jit/attributes.h
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/autodiff.h
torch/csrc/jit/batched/BatchTensor.cpp
torch/csrc/jit/batched/BatchTensor.h
torch/csrc/jit/catch_utils.hpp
torch/csrc/jit/code_template.h
torch/csrc/jit/constants.cpp
torch/csrc/jit/constants.h
torch/csrc/jit/dynamic_dag.h
torch/csrc/jit/export.cpp
torch/csrc/jit/export.h
torch/csrc/jit/function_schema.h
torch/csrc/jit/fuser/arg_spec.h
torch/csrc/jit/fuser/codegen.cpp
torch/csrc/jit/fuser/codegen.h
torch/csrc/jit/fuser/compiler.cpp
torch/csrc/jit/fuser/compiler.h
torch/csrc/jit/fuser/config.h.in
torch/csrc/jit/fuser/cpu/dynamic_library.h
torch/csrc/jit/fuser/cpu/fused_kernel.cpp
torch/csrc/jit/fuser/cpu/fused_kernel.h
torch/csrc/jit/fuser/cpu/resource_strings.h
torch/csrc/jit/fuser/cpu/temp_file.h
torch/csrc/jit/fuser/cuda/fused_kernel.cpp
torch/csrc/jit/fuser/cuda/fused_kernel.h
torch/csrc/jit/fuser/cuda/resource_strings.h
torch/csrc/jit/fuser/executor.cpp
torch/csrc/jit/fuser/executor.h
torch/csrc/jit/fuser/fallback.cpp
torch/csrc/jit/fuser/fallback.h
torch/csrc/jit/fuser/fused_kernel.h
torch/csrc/jit/fuser/interface.cpp
torch/csrc/jit/fuser/interface.h
torch/csrc/jit/fuser/kernel_cache.cpp
torch/csrc/jit/fuser/kernel_cache.h
torch/csrc/jit/fuser/kernel_spec.h
torch/csrc/jit/fuser/partition_desc.h
torch/csrc/jit/fuser/tensor_desc.h
torch/csrc/jit/fuser/tensor_info.h
torch/csrc/jit/generic_if.h
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/graph_executor.h
torch/csrc/jit/graph_node_list.h
torch/csrc/jit/hooks_for_testing.cpp
torch/csrc/jit/hooks_for_testing.h
torch/csrc/jit/import.cpp
torch/csrc/jit/import.h
torch/csrc/jit/import_method.cpp
torch/csrc/jit/import_method.h
torch/csrc/jit/init.cpp
torch/csrc/jit/init.h
torch/csrc/jit/interpreter.cpp
torch/csrc/jit/interpreter.h
torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/ivalue.h
torch/csrc/jit/named_value.h
torch/csrc/jit/node_hashing.cpp
torch/csrc/jit/node_hashing.h
torch/csrc/jit/operator.cpp
torch/csrc/jit/operator.h
torch/csrc/jit/passes/batch_mm.cpp
torch/csrc/jit/passes/batch_mm.h
torch/csrc/jit/passes/canonicalize.cpp
torch/csrc/jit/passes/canonicalize.h
torch/csrc/jit/passes/canonicalize_ops.cpp
torch/csrc/jit/passes/canonicalize_ops.h
torch/csrc/jit/passes/common_subexpression_elimination.cpp
torch/csrc/jit/passes/common_subexpression_elimination.h
torch/csrc/jit/passes/constant_pooling.cpp
torch/csrc/jit/passes/constant_pooling.h
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/passes/constant_propagation.h
torch/csrc/jit/passes/create_autodiff_subgraphs.h
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/csrc/jit/passes/dead_code_elimination.h
torch/csrc/jit/passes/erase_number_types.cpp
torch/csrc/jit/passes/erase_number_types.h
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/passes/graph_fuser.h
torch/csrc/jit/passes/inline_autodiff_subgraphs.h
torch/csrc/jit/passes/inplace_check.cpp
torch/csrc/jit/passes/inplace_check.h
torch/csrc/jit/passes/loop_unrolling.cpp
torch/csrc/jit/passes/loop_unrolling.h
torch/csrc/jit/passes/lower_grad_of.cpp
torch/csrc/jit/passes/lower_grad_of.h
torch/csrc/jit/passes/lower_tuples.cpp
torch/csrc/jit/passes/lower_tuples.h
torch/csrc/jit/passes/onnx.cpp
torch/csrc/jit/passes/onnx.h
torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
torch/csrc/jit/passes/onnx/fixup_onnx_loop.h
torch/csrc/jit/passes/onnx/peephole.cpp
torch/csrc/jit/passes/onnx/peephole.h
torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h
torch/csrc/jit/passes/peephole.cpp
torch/csrc/jit/passes/peephole.h
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/passes/python_print.h
torch/csrc/jit/passes/remove_expands.cpp
torch/csrc/jit/passes/remove_expands.h
torch/csrc/jit/passes/remove_inplace_ops.cpp
torch/csrc/jit/passes/requires_grad_analysis.cpp
torch/csrc/jit/passes/requires_grad_analysis.h
torch/csrc/jit/passes/shape_analysis.cpp
torch/csrc/jit/passes/shape_analysis.h
torch/csrc/jit/passes/specialize_undef.cpp
torch/csrc/jit/passes/specialize_undef.h
torch/csrc/jit/passes/to_batch.cpp
torch/csrc/jit/passes/to_batch.h
torch/csrc/jit/passes/utils/check_alias_annotation.cpp
torch/csrc/jit/pybind.h
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/python_arg_flatten.cpp
torch/csrc/jit/python_arg_flatten.h
torch/csrc/jit/python_interpreter.cpp
torch/csrc/jit/python_ir.cpp
torch/csrc/jit/python_ir.h
torch/csrc/jit/python_tracer.cpp
torch/csrc/jit/python_tracer.h
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/register_special_ops.cpp
torch/csrc/jit/resource_guard.h
torch/csrc/jit/scope.cpp
torch/csrc/jit/scope.h
torch/csrc/jit/script/builtin_functions.cpp
torch/csrc/jit/script/builtin_functions.h
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/compiler.h
torch/csrc/jit/script/error_report.h
torch/csrc/jit/script/final_returns.cpp
torch/csrc/jit/script/final_returns.h
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/jit_exception.h
torch/csrc/jit/script/lexer.cpp
torch/csrc/jit/script/lexer.h
torch/csrc/jit/script/module.cpp
torch/csrc/jit/script/module.h
torch/csrc/jit/script/parse_string_literal.h
torch/csrc/jit/script/parser.cpp
torch/csrc/jit/script/parser.h
torch/csrc/jit/script/python_tree_views.cpp
torch/csrc/jit/script/python_tree_views.h
torch/csrc/jit/script/schema_matching.cpp
torch/csrc/jit/script/schema_matching.h
torch/csrc/jit/script/sugared_value.cpp
torch/csrc/jit/script/sugared_value.h
torch/csrc/jit/script/tree.h
torch/csrc/jit/script/tree_views.h
torch/csrc/jit/script/type_parser.cpp
torch/csrc/jit/script/type_parser.h
torch/csrc/jit/source_location.h
torch/csrc/jit/source_range.h
torch/csrc/jit/stack.h
torch/csrc/jit/symbolic_script.cpp
torch/csrc/jit/symbolic_script.h
torch/csrc/jit/symbolic_variable.h
torch/csrc/jit/tracer.cpp
torch/csrc/jit/tracer.h
torch/csrc/jit/tracing_state.h
torch/csrc/jit/type.h
torch/csrc/jit/variable_tensor_list.h

index 42e33d8..a0c71fd 100644 (file)
   } catch (const std::exception& e) {                                    \
     ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
   }
-#define ASSERT_ANY_THROW(statement)                                      \
-  bool threw = false;                                                    \
-  try {                                                                  \
-    (void)statement;                                                     \
-  } catch (const std::exception& e) {                                    \
-    threw = true;                                                        \
-  }                                                                      \
-  ASSERT_TRUE(threw);                                                    \
+#define ASSERT_ANY_THROW(statement)   \
+  bool threw = false;                 \
+  try {                               \
+    (void)statement;                  \
+  } catch (const std::exception& e) { \
+    threw = true;                     \
+  }                                   \
+  ASSERT_TRUE(threw);
 
 #endif // defined(USE_GTEST)
 
+#include "torch/csrc/autograd/generated/variable_factories.h"
 #include "torch/csrc/autograd/variable.h"
 #include "torch/csrc/jit/argument_spec.h"
 #include "torch/csrc/jit/assertions.h"
 #include "torch/csrc/jit/passes/requires_grad_analysis.h"
 #include "torch/csrc/jit/passes/shape_analysis.h"
 #include "torch/csrc/jit/passes/utils/subgraph_utils.h"
-#include "torch/csrc/jit/symbolic_variable.h"
 #include "torch/csrc/jit/symbolic_script.h"
+#include "torch/csrc/jit/symbolic_variable.h"
 #include "torch/csrc/jit/tracer.h"
 #include "torch/csrc/utils/hash.h"
-#include "torch/csrc/autograd/generated/variable_factories.h"
 
 #include "torch/csrc/autograd/engine.h"
 #include "torch/csrc/autograd/variable.h"
@@ -440,7 +440,9 @@ std::shared_ptr<Graph> build_lstm() {
   return r;
 }
 
-std::vector<at::Tensor> run(InterpreterState & interp, const std::vector<at::Tensor> & inputs) {
+std::vector<at::Tensor> run(
+    InterpreterState& interp,
+    const std::vector<at::Tensor>& inputs) {
   std::vector<IValue> stack(inputs.begin(), inputs.end());
   interp.run(stack);
   return fmap(stack, [](const IValue& i) { return i.toTensor(); });
@@ -469,8 +471,7 @@ std::pair<tensor_list, tensor_list> runGradient(
   df_interpreter.run(df_stack);
 
   // Outputs of f needs to be sliced
-  f_stack.erase(
-      f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
+  f_stack.erase(f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
   return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
 }
 
@@ -515,13 +516,14 @@ void testTHNNConv() {
 
   // make inputs
   at::Tensor input = torch::randn(input_size);
-  at::Tensor weight = torch::randn({out_channels, input_size[1], kernel_size[0], kernel_size[1]});
+  at::Tensor weight = torch::randn(
+      {out_channels, input_size[1], kernel_size[0], kernel_size[1]});
   at::Tensor bias = torch::randn({out_channels});
 
   // run forward eagerly
   at::Tensor output, finput, fgradinput;
-  std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(input, weight, kernel_size,
-                                                                bias, stride, padding);
+  std::tie(output, finput, fgradinput) = at::thnn_conv2d_forward(
+      input, weight, kernel_size, bias, stride, padding);
 
   // make grad_outputs
   at::Tensor grad_output = torch::randn_like(output);
@@ -530,9 +532,16 @@ void testTHNNConv() {
 
   // run backward eagerly
   at::Tensor grad_input, grad_weight, grad_bias;
-  std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(grad_output, input, weight,
-                                                                         kernel_size, stride, padding,
-                                                                         finput, fgradinput, {true, true, true});
+  std::tie(grad_input, grad_weight, grad_bias) = at::thnn_conv2d_backward(
+      grad_output,
+      input,
+      weight,
+      kernel_size,
+      stride,
+      padding,
+      finput,
+      fgradinput,
+      {true, true, true});
 
   // make JIT graph
   auto graph = std::make_shared<Graph>();
@@ -544,7 +553,9 @@ void testTHNNConv() {
   auto weightg = graph->addInput("weight");
   auto biasg = graph->addInput("bias");
 
-  Value* conv = graph->insert(aten::thnn_conv2d_forward, {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
+  Value* conv = graph->insert(
+      aten::thnn_conv2d_forward,
+      {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
   auto outputs = conv->node()->outputs();
   for (auto output : outputs) {
     graph->registerOutput(output);
@@ -572,7 +583,7 @@ void testTHNNConv() {
   // Get outputs from the interpreter
   tensor_list tensors_out, tensor_grads_out;
   std::tie(tensors_out, tensor_grads_out) =
-    runGradient(grad_spec, tensors_in, tensor_grads_in);
+      runGradient(grad_spec, tensors_in, tensor_grads_in);
 
   // prepare expected structs
   tensor_list expected_tensors_out, expected_tensor_grads_out;
@@ -589,7 +600,9 @@ void testTHNNConv() {
 }
 
 void testATenNativeBatchNorm() {
-  // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+  // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
+  // running_mean, Tensor running_var, bool training, float momentum, float eps)
+  // -> (Tensor, Tensor, Tensor)
   std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
   bool training = true;
   float momentum = 0.9;
@@ -610,7 +623,15 @@ void testATenNativeBatchNorm() {
 
   // run forward eagerly
   at::Tensor output, savemean, saveinvstd;
-  std::tie(output, savemean, saveinvstd) = at::native_batch_norm(input, weight, bias, running_mean_eager, running_var_eager, training, momentum, eps);
+  std::tie(output, savemean, saveinvstd) = at::native_batch_norm(
+      input,
+      weight,
+      bias,
+      running_mean_eager,
+      running_var_eager,
+      training,
+      momentum,
+      eps);
 
   // make grad_outputs
   at::Tensor grad_output = torch::randn_like(output);
@@ -619,10 +640,21 @@ void testATenNativeBatchNorm() {
 
   // run backward eagerly
   at::Tensor grad_input, grad_weight, grad_bias;
-  // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
-  std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(grad_output, input, weight,
-                                                                               running_mean_eager, running_var_eager,
-                                                                               savemean, saveinvstd, training, eps, {true, true, true});
+  // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
+  // weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
+  // save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
+  // Tensor, Tensor)
+  std::tie(grad_input, grad_weight, grad_bias) = at::native_batch_norm_backward(
+      grad_output,
+      input,
+      weight,
+      running_mean_eager,
+      running_var_eager,
+      savemean,
+      saveinvstd,
+      training,
+      eps,
+      {true, true, true});
 
   // make JIT graph
   auto graph = std::make_shared<Graph>();
@@ -636,7 +668,16 @@ void testATenNativeBatchNorm() {
   auto running_meang = graph->addInput("running_mean");
   auto running_varg = graph->addInput("running_var");
 
-  Value* bn = graph->insert(aten::native_batch_norm, {inputg, weightg, biasg, running_meang, running_varg, training_val, momentum_val, eps_val});
+  Value* bn = graph->insert(
+      aten::native_batch_norm,
+      {inputg,
+       weightg,
+       biasg,
+       running_meang,
+       running_varg,
+       training_val,
+       momentum_val,
+       eps_val});
   auto outputs = bn->node()->outputs();
   for (auto output : outputs) {
     graph->registerOutput(output);
@@ -666,7 +707,7 @@ void testATenNativeBatchNorm() {
   // Get outputs from the interpreter
   tensor_list tensors_out, tensor_grads_out;
   std::tie(tensors_out, tensor_grads_out) =
-    runGradient(grad_spec, tensors_in, tensor_grads_in);
+      runGradient(grad_spec, tensors_in, tensor_grads_in);
 
   // prepare expected structs
   tensor_list expected_tensors_out, expected_tensor_grads_out;
@@ -770,15 +811,20 @@ void testADFormulas() {
        [](const VL& v) -> VL { return {v[0].tanh()}; }},
       {"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }},
       {"view",
-        unary_pointwise_2d,
-        [](const VL& v) -> VL { return {v[0].view({3, 2})}; }},
+       unary_pointwise_2d,
+       [](const VL& v) -> VL {
+         return {v[0].view({3, 2})};
+       }},
       {"expand",
-        {{2, 1}},
-        [](const VL& v) -> VL { return {v[0].expand({2, 3})}; }},
+       {{2, 1}},
+       [](const VL& v) -> VL {
+         return {v[0].expand({2, 3})};
+       }},
       {"mm",
        {{10, 12}, {12, 15}},
        [](const VL& v) -> VL { return {v[0].mm(v[1])}; }},
-      // TODO: enable once we'll be able to capture lists across forward-backward
+      // TODO: enable once we'll be able to capture lists across
+      // forward-backward
       //{"chunk",   {{10, 12, 15}}, [](const VL& v) -> VL { return
       // fmap<Variable>(v[0].chunk(4, 1)); }},
       //{"chunk",   {{10, 12, 15}}, [](const VL& v) -> VL { return
@@ -860,8 +906,10 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
   graph->registerOutput(d.value());
   graph->registerOutput(e.value());
 
-  auto a_var = autograd::make_variable(at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
-  auto b_var = autograd::make_variable(at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);
+  auto a_var = autograd::make_variable(
+      at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
+  auto b_var = autograd::make_variable(
+      at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);
   setInputTypes(*graph, ArgumentSpec(true, {a_var, b_var}, 2));
   PropagateInputShapes(graph);
   PropagateRequiresGrad(graph);
@@ -899,9 +947,10 @@ void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
 
   auto getFusionGroup = [](const std::shared_ptr<Graph>& graph) {
     const auto& nodes = graph->nodes();
-    auto maybe_fusion_group = std::find_if(
-        nodes.begin(), nodes.end(),
-        [](const Node* node) { return node->kind() == prim::FusionGroup; });
+    auto maybe_fusion_group =
+        std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
+          return node->kind() == prim::FusionGroup;
+        });
     JIT_ASSERTM(
         maybe_fusion_group != nodes.end(),
         "testRegisterFusionCachesKernel: could not create FusionGroup");
@@ -1215,7 +1264,8 @@ void testCustomOperators() {
     ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
     ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
     ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
-    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType);
+    ASSERT_EQ(
+        op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType);
 
     ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::DynamicType);
 
@@ -1243,7 +1293,8 @@ void testCustomOperators() {
     ASSERT_EQ(op->schema().arguments()[0].name(), "a");
     ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
     ASSERT_EQ(op->schema().arguments()[1].name(), "b");
-    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType);
+    ASSERT_EQ(
+        op->schema().arguments()[1].type()->kind(), TypeKind::DynamicType);
 
     ASSERT_EQ(op->schema().returns().size(), 1);
     ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::DynamicType);
@@ -1405,29 +1456,33 @@ void testSchemaParser() {
   // nested arrays
   auto s = parseSchema("at::what(int[][4] foo) -> ()");
   ASSERT_TRUE(s.arguments().at(0).N() == 4);
-  ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments().at(0)
-                                              .type()->expect<ListType>()
+  ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments()
+                                              .at(0)
+                                              .type()
+                                              ->expect<ListType>()
                                               ->getElementType()
                                               ->expect<ListType>()
                                               ->getElementType()));
   auto s2 = parseSchema("at::what(int[][] foo) -> ()");
-  ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments().at(0)
-                                            .type()->expect<ListType>()
-                                            ->getElementType()
-                                            ->expect<ListType>()
-                                            ->getElementType()));
+  ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments()
+                                              .at(0)
+                                              .type()
+                                              ->expect<ListType>()
+                                              ->getElementType()
+                                              ->expect<ListType>()
+                                              ->getElementType()));
 
   // named returns
   parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
-  auto s3 = parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
+  auto s3 =
+      parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
   ASSERT_TRUE(s3.returns().at(0).name() == "the_return");
   ASSERT_TRUE(s3.returns().at(1).name() == "the_return2");
 
   // futures
   auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
-  ASSERT_TRUE(IntType::get()->isSubtypeOf(s4.arguments().at(0)
-                                          .type()->expect<FutureType>()
-                                          ->getElementType()));
+  ASSERT_TRUE(IntType::get()->isSubtypeOf(
+      s4.arguments().at(0).type()->expect<FutureType>()->getElementType()));
 
   // test tensor with annotated alias sets
   parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))");
@@ -1530,9 +1585,9 @@ void testTopologicalIndex() {
   }
 }
 
-
 std::unique_ptr<detail::DynamicDAG<std::string>> newDynamicDAG() {
-  return std::unique_ptr<detail::DynamicDAG<std::string>>(new detail::DynamicDAG<std::string>());
+  return std::unique_ptr<detail::DynamicDAG<std::string>>(
+      new detail::DynamicDAG<std::string>());
 }
 
 void testNewVertex() {
@@ -1781,20 +1836,20 @@ struct TopoMoveTestFixture {
   bool moveBeforeTopologicallyValid(
       const std::string& toInsert,
       const std::string& insertPoint) {
-    std::function<bool(Node*, Node*)> func = [this](Node* toInsert,
-                                                Node* insertPoint) {
-      return toInsert->moveBeforeTopologicallyValid(insertPoint, *aliasDb);
-    };
+    std::function<bool(Node*, Node*)> func =
+        [this](Node* toInsert, Node* insertPoint) {
+          return toInsert->moveBeforeTopologicallyValid(insertPoint, *aliasDb);
+        };
     return moveWithChecks(toInsert, insertPoint, func);
   }
 
   bool moveAfterTopologicallyValid(
       const std::string& toInsert,
       const std::string& insertPoint) {
-    std::function<bool(Node*, Node*)> func = [this](Node* toInsert,
-                                                Node* insertPoint) {
-      return toInsert->moveAfterTopologicallyValid(insertPoint, *aliasDb);
-    };
+    std::function<bool(Node*, Node*)> func =
+        [this](Node* toInsert, Node* insertPoint) {
+          return toInsert->moveAfterTopologicallyValid(insertPoint, *aliasDb);
+        };
     return moveWithChecks(toInsert, insertPoint, func);
   }
 
index ca44baf..454bd34 100644 (file)
@@ -10,18 +10,19 @@ Running tools/clang_format.py manually with no arguments should replicate the pr
 Only files that are in CLANG_FORMAT_WHITELIST are checked.
 """
 import subprocess
-import glob
-import itertools
 import os
 import argparse
+import fnmatch
 import difflib
 import sys
+import re
 
 
-# Whitelist of files to check. Takes a glob syntax. Does not support
-# recursive globs ("**") because I am lazy and don't want to make that
-# work with Python 2.
-CLANG_FORMAT_WHITELIST = ["torch/csrc/jit/passes/alias_analysis*"]
+# Whitelist of directories to check. All files that in that directory
+# (recursively) will be checked.
+CLANG_FORMAT_WHITELIST = ["torch/csrc/jit/", "test/cpp/jit/"]
+
+CPP_FILE_REGEX = re.compile(".*\\.(h|cpp|cc|c|hpp)$")
 
 
 def parse_args():
@@ -43,21 +44,28 @@ def parse_args():
             "Otherwise, just print the changes and exit"
         ),
     )
+    parser.add_argument(
+        "--check-all",
+        action="store_true",
+        default=False,
+        help="If true, check all whitelisted files instead of just working copy changes",
+    )
     parser.add_argument("--verbose", "-v", action="store_true", default=False)
     return parser.parse_args()
 
 
 def get_whitelisted_files():
     """
-    Parse CLANG_FORMAT_WHITELIST and resolve all globs.
-    Returns the set of all whitelisted filenames.
+    Parse CLANG_FORMAT_WHITELIST and resolve all directories.
+    Returns the set of whitelist cpp source files.
     """
-    paths = [glob.glob(entry) for entry in CLANG_FORMAT_WHITELIST]
-    # flatten the files list
-    paths = itertools.chain(*paths)
-    # filter out directories
-    filenames = filter(lambda path: os.path.isfile(path), paths)
-    return set(filenames)
+    matches = []
+    for dir in CLANG_FORMAT_WHITELIST:
+        for root, dirnames, filenames in os.walk(dir):
+            for filename in filenames:
+                if CPP_FILE_REGEX.fullmatch(filename):
+                    matches.append(os.path.join(root, filename))
+    return set(matches)
 
 
 def get_changed_files(rev):
@@ -98,10 +106,13 @@ def get_diffs(files):
 def main():
     args = parse_args()
 
-    changed_files = get_changed_files(args.diff)
     whitelisted_files = get_whitelisted_files()
 
-    files_to_check = changed_files & whitelisted_files
+    if args.check_all:
+        files_to_check = whitelisted_files
+    else:
+        changed_files = get_changed_files(args.diff)
+        files_to_check = changed_files & whitelisted_files
 
     if args.verbose:
         print("Running clang-format on whitelisted files: ")
@@ -118,6 +129,11 @@ def main():
         args = ["clang-format", "-i"]
         args.extend(name_to_diffs.keys())
         subprocess.check_output(args)
+
+        # add the changes so they will be committed
+        args = ["git", "add"]
+        args.extend(name_to_diffs.keys())
+        subprocess.check_output(args)
     else:
         print("ERROR: Running clang-format created changes: ")
         for name, diff in name_to_diffs.items():
index e0d79fb..443a8b5 100644 (file)
@@ -1,6 +1,7 @@
 #include <ATen/core/alias_info.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 using ::c10::AliasInfo;
 
index ec3d988..ba9c0cd 100644 (file)
@@ -1,20 +1,21 @@
 #pragma once
 
-#include <iostream>
-#include <vector>
 #include <torch/csrc/autograd/variable.h>
-#include <torch/csrc/utils/hash.h>
+#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/stack.h>
 #include <torch/csrc/jit/type.h>
-#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/variable_tensor_list.h>
+#include <torch/csrc/utils/hash.h>
+#include <iostream>
+#include <vector>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-// GraphExecutor creates specializations of Graphs for different dimensionalitities
-// and types of inputs.
+// GraphExecutor creates specializations of Graphs for different
+// dimensionalitities and types of inputs.
 
-inline static at::Device ConvertIntToCPUOrCUDA(int device){
+inline static at::Device ConvertIntToCPUOrCUDA(int device) {
   return device < 0 ? at::kCPU : at::Device(at::DeviceType::CUDA, device);
 }
 struct ArgumentInfo {
@@ -30,7 +31,8 @@ struct ArgumentInfo {
   int device() const {
     return device_;
   }
-  // XXX: It is guaranteed that this will return false when called on non-tensor arguments
+  // XXX: It is guaranteed that this will return false when called on non-tensor
+  // arguments
   bool requires_grad() const {
     return requires_grad_;
   }
@@ -46,23 +48,29 @@ struct ArgumentInfo {
     return TensorType::create(type(), ConvertIntToCPUOrCUDA(device()), dim());
   }
 
-private:
+ private:
   unsigned is_tensor_ : 1;
   unsigned defined_ : 1;
   unsigned requires_grad_ : 1;
   unsigned : 5;
   unsigned dim_ : 8;
-  int device_ : 8; // NOTE: this needs to be signed because we use -1 to represent CPU
+  int device_ : 8; // NOTE: this needs to be signed because we use -1 to
+                   // represent CPU
   unsigned type_ : 8;
 };
 
-static_assert(std::is_pod<ArgumentInfo>::value,
-  "ArgumentInfo is to be a POD struct");
-static_assert(sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
-  "ArgumentInfo is expected to be a 32-bit struct");
+static_assert(
+    std::is_pod<ArgumentInfo>::value,
+    "ArgumentInfo is to be a POD struct");
+static_assert(
+    sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
+    "ArgumentInfo is expected to be a 32-bit struct");
 
 struct ArgumentSpec {
-  ArgumentSpec(bool with_grad, at::ArrayRef<IValue> inputs, size_t num_flat_inputs) {
+  ArgumentSpec(
+      bool with_grad,
+      at::ArrayRef<IValue> inputs,
+      size_t num_flat_inputs) {
     hash_code = num_flat_inputs;
     args.resize(num_flat_inputs);
     size_t offset = 0;
@@ -73,7 +81,7 @@ struct ArgumentSpec {
   }
 
   void addInput(const IValue& input, size_t& offset, bool with_grad) {
-    auto & arg = args.at(offset);
+    auto& arg = args.at(offset);
     // Initialize all fields to 0. This is convenient, because e.g.
     // requires_grad() can be checked even on tensors AND will make
     // padding bits all 0s.
@@ -92,17 +100,18 @@ struct ArgumentSpec {
       combineHash(arg);
       offset++;
     } else if (input.isTuple()) {
-      for (const IValue & elem : input.toTuple()->elements()) {
+      for (const IValue& elem : input.toTuple()->elements()) {
         addInput(elem, offset, with_grad);
       }
     } else {
-      // NB: no need to set is_tensor to false, because we memset the struct to 0 above
+      // NB: no need to set is_tensor to false, because we memset the struct to
+      // 0 above
       combineHash(arg);
       offset++;
     }
   }
 
-  void combineHash(const ArgumentInfo &arg) {
+  void combineHash(const ArgumentInfoarg) {
     ArgumentInfo::plain_data_type arg_data;
     std::memcpy(&arg_data, &arg, sizeof(ArgumentInfo));
     hash_code = hash_combine(hash_code, arg_data);
@@ -110,14 +119,19 @@ struct ArgumentSpec {
 
   // equality is fast: check ninputs, and then check the raw array data,
   // there are no size/stride indirections
-  bool operator==(const ArgumentSpec & spec) const {
-    if (args.size() != spec.args.size()) return false;
-    // NB: we need to break out early when there are no elements, because passing a
-    // nullptr to memcmp is UB.
-    if (args.size() == 0) return true;
-    return std::memcmp(args.data(), spec.args.data(), args.size() * sizeof(ArgumentInfo)) == 0;
-  }
-  bool operator!=(const ArgumentSpec & spec) const {
+  bool operator==(const ArgumentSpec& spec) const {
+    if (args.size() != spec.args.size())
+      return false;
+    // NB: we need to break out early when there are no elements, because
+    // passing a nullptr to memcmp is UB.
+    if (args.size() == 0)
+      return true;
+    return std::memcmp(
+               args.data(),
+               spec.args.data(),
+               args.size() * sizeof(ArgumentInfo)) == 0;
+  }
+  bool operator!=(const ArgumentSpec& spec) const {
     return !(*this == spec);
   }
   size_t size() const {
@@ -133,21 +147,25 @@ struct ArgumentSpec {
   // inferred for it based on this ArgumentSpec.
   std::vector<TypePtr> getTypes(Graph& graph) const {
     size_t offset = 0;
-    return fmap(graph.inputs(),
-                [&](Value *v) { return fillType(v->type(), offset); });
+    return fmap(
+        graph.inputs(), [&](Value* v) { return fillType(v->type(), offset); });
   }
 
-private:
+ private:
   TypePtr fillType(TypePtr original, size_t& offset) const {
     if (original->isSubtypeOf(DynamicType::get())) {
-      auto & arg = args.at(offset++);
+      auto& arg = args.at(offset++);
       if (!arg.defined())
         return UndefinedTensorType::get();
-      return TensorType::create(arg.type(), ConvertIntToCPUOrCUDA(arg.device()), arg.dim(), arg.requires_grad());
+      return TensorType::create(
+          arg.type(),
+          ConvertIntToCPUOrCUDA(arg.device()),
+          arg.dim(),
+          arg.requires_grad());
     } else if (auto tuple_type = original->cast<TupleType>()) {
-      return TupleType::create(fmap(tuple_type->elements(), [&](const TypePtr& subtype) {
-        return fillType(subtype, offset);
-      }));
+      return TupleType::create(fmap(
+          tuple_type->elements(),
+          [&](const TypePtr& subtype) { return fillType(subtype, offset); }));
     } else {
       offset++;
       return original;
@@ -171,38 +189,41 @@ struct CompleteArgumentInfoPOD {
   unsigned defined : 1;
   unsigned requires_grad : 1;
   signed device : 14;
-  uint32_t total_dims; // all TensorInfoPODs are in CompleteArgumentSpec's tensor_info() array.
-                       // total_dims is the total number of dimensions seen so far
-                       // in all previous members of tensor_info(), including this tensor
-                       // 2*total_dims becomes the offset into the sizes_strides list
-                       // for the _next_ tensor in the tensor_info array
-                       // for tensor 0, the offset is always 0
+  uint32_t total_dims; // all TensorInfoPODs are in CompleteArgumentSpec's
+                       // tensor_info() array. total_dims is the total number of
+                       // dimensions seen so far in all previous members of
+                       // tensor_info(), including this tensor 2*total_dims
+                       // becomes the offset into the sizes_strides list for the
+                       // _next_ tensor in the tensor_info array for tensor 0,
+                       // the offset is always 0
 };
 
-static_assert(sizeof(CompleteArgumentInfoPOD) == sizeof(int64_t),
-  "CompleteArgumentInfoPOD must be 64-bit struct for CompleteArgumentSpec encoding to work");
+static_assert(
+    sizeof(CompleteArgumentInfoPOD) == sizeof(int64_t),
+    "CompleteArgumentInfoPOD must be 64-bit struct for CompleteArgumentSpec encoding to work");
 
 struct CompleteArgumentInfo;
 
 struct CompleteArgumentSpec {
   CompleteArgumentSpec(bool with_grad, at::ArrayRef<IValue> inputs)
-   hash_code(0), ninputs(inputs.size()) {
+      : hash_code(0), ninputs(inputs.size()) {
     int32_t all_dims = 0;
     const int32_t num_inputs = inputs.size();
     for (int32_t i = 0; i < num_inputs; i++) {
-      if (!inputs[i].isTensor()) continue;
+      if (!inputs[i].isTensor())
+        continue;
       auto tensor = inputs[i].toTensor();
       all_dims += tensor.defined() ? tensor.ndimension() : 0;
     }
     // allocate enough room for all TensorPODs and dimensions
-    data.resize(ninputs + all_dims*2);
+    data.resize(ninputs + all_dims * 2);
 
     // and reinterpret our data array as these structs
     auto* pods = reinterpret_cast<CompleteArgumentInfoPOD*>(data.data());
-    int64_t * next_dim = sizes_strides();
+    int64_t* next_dim = sizes_strides();
     int32_t total_dims = 0;
-    for(int32_t i = 0; i < num_inputs; i++) {
-      auto & pod = pods[i];
+    for (int32_t i = 0; i < num_inputs; i++) {
+      auto& pod = pods[i];
       pod.is_tensor = static_cast<uint32_t>(inputs[i].isTensor());
       if (pod.is_tensor) {
         at::Tensor t = inputs[i].toTensor();
@@ -210,10 +231,11 @@ struct CompleteArgumentSpec {
         if (pod.defined) {
           pod.type = static_cast<int>(t.type().scalarType());
           pod.device = (!t.is_cuda()) ? -1 : t.get_device();
-          pod.requires_grad = with_grad && autograd::as_variable_ref(t).requires_grad();
+          pod.requires_grad =
+              with_grad && autograd::as_variable_ref(t).requires_grad();
           total_dims += t.ndimension();
           auto sizes = t.sizes();
-          std::copy(sizes.begin(),sizes.end(), next_dim);
+          std::copy(sizes.begin(), sizes.end(), next_dim);
           next_dim += sizes.size();
           auto strides = t.strides();
           std::copy(strides.begin(), strides.end(), next_dim);
@@ -226,17 +248,17 @@ struct CompleteArgumentSpec {
     // we precompute the hash_code to minimize the time inside of hash
     // table operations where we may need to hold a compiler cache lock.
     hash_code = hash_combine(0, ninputs);
-    for(auto d : data) {
+    for (auto d : data) {
       hash_code = hash_combine(hash_code, d);
     }
   }
 
   // equality is fast: check ninputs, and then check the raw array data,
   // there are no size/stride indirections
-  bool operator==(const CompleteArgumentSpec & spec) const {
+  bool operator==(const CompleteArgumentSpec& spec) const {
     return ninputs == spec.ninputs && data == spec.data;
   }
-  bool operator!=(const CompleteArgumentSpec & spec) const {
+  bool operator!=(const CompleteArgumentSpec& spec) const {
     return !(*this == spec);
   }
   friend struct CompleteArgumentInfo;
@@ -248,12 +270,13 @@ struct CompleteArgumentSpec {
     return hash_code;
   }
 
-private:
+ private:
   ArrayRef<CompleteArgumentInfoPOD> tensor_info() const {
     return ArrayRef<CompleteArgumentInfoPOD>(
-            reinterpret_cast<const CompleteArgumentInfoPOD*>(data.data()), ninputs);
+        reinterpret_cast<const CompleteArgumentInfoPOD*>(data.data()), ninputs);
   }
-  // the start of the sizes_strides information, which comes after the CompleteArgumentInfoPOD list.
+  // the start of the sizes_strides information, which comes after the
+  // CompleteArgumentInfoPOD list.
   const int64_t* sizes_strides() const {
     return data.data() + ninputs;
   }
@@ -262,15 +285,17 @@ private:
   }
   size_t hash_code; // precomputed on construction
   int32_t ninputs;
-  // layout is ninputs of TensorPOD (each 64-bit) followed by their size and stride info
-  // for 3 tensors: [t0POD][t1POD][t2POD][t0 sizes][t0 strides][t1 sizes][t1 strides][t2 sizes][t2 strides]
+  // layout is ninputs of TensorPOD (each 64-bit) followed by their size and
+  // stride info for 3 tensors:
+  // [t0POD][t1POD][t2POD]...
+  // [t0 sizes][t0 strides][t1 sizes][t1 strides][t2 sizes][t2 strides]
   std::vector<int64_t> data;
 };
 
 // public view of compressed CompleteArgumentInfo
 struct CompleteArgumentInfo {
-  CompleteArgumentInfo(const CompleteArgumentSpec & spec, const int i)
-  : spec(spec), i(i) {}
+  CompleteArgumentInfo(const CompleteArgumentSpec& spec, const int i)
+      : spec(spec), i(i) {}
   bool isTensor() const {
     return pod(i).is_tensor;
   }
@@ -288,49 +313,54 @@ struct CompleteArgumentInfo {
   }
   int ndimension() const {
     // See [valid range], it is always valid to ask for offset for (i + 1)
-    return (sizes_strides_offset(i + 1) - sizes_strides_offset(i))/2;
+    return (sizes_strides_offset(i + 1) - sizes_strides_offset(i)) / 2;
   }
   at::IntList sizes() const {
-    return at::IntList(spec.sizes_strides() + sizes_strides_offset(i), ndimension());
+    return at::IntList(
+        spec.sizes_strides() + sizes_strides_offset(i), ndimension());
   }
   at::IntList strides() const {
     int ndim = ndimension();
-    return at::IntList(spec.sizes_strides() + sizes_strides_offset(i) + ndim, ndim);
+    return at::IntList(
+        spec.sizes_strides() + sizes_strides_offset(i) + ndim, ndim);
   }
   operator TypePtr() const {
-    if(!defined())
+    if (!defined())
       return DynamicType::get();
-    return CompleteTensorType::create(type(), ConvertIntToCPUOrCUDA(device()), sizes(), strides());
+    return CompleteTensorType::create(
+        type(), ConvertIntToCPUOrCUDA(device()), sizes(), strides());
   }
-private:
+
+ private:
   // offsetinto sizes_strides() array where the sizes start for tensor j
   // [valid range] valid range is [0, ninputs]
-  // (i.e. you can ask for the offset at ninputs, which would be the offset of the next tensor if it existed)
+  // (i.e. you can ask for the offset at ninputs, which would be the offset of
+  // the next tensor if it existed)
   int sizes_strides_offset(int j) const {
-    if(j == 0) return 0;
-    return 2*pod(j - 1).total_dims;
+    if (j == 0)
+      return 0;
+    return 2 * pod(j - 1).total_dims;
   }
-  const CompleteArgumentInfoPOD & pod(int j) const {
+  const CompleteArgumentInfoPOD& pod(int j) const {
     return spec.tensor_info().at(j);
   }
-  const CompleteArgumentSpec & spec;
+  const CompleteArgumentSpec& spec;
   const int i;
 };
 
-inline std::ostream & operator<<(std::ostream & out, const ArgumentInfo & info) {
-  if(!info.defined()) {
+inline std::ostream& operator<<(std::ostream& out, const ArgumentInfo& info) {
+  if (!info.defined()) {
     return out << "<undefined>";
   }
-  out << "Tensor(device=" << info.device()
-    << ", type=" << toString(info.type())
-    << ", requires_grad=" << info.requires_grad()
-    << ", dims=" << info.dim() << ")";
+  out << "Tensor(device=" << info.device() << ", type=" << toString(info.type())
+      << ", requires_grad=" << info.requires_grad() << ", dims=" << info.dim()
+      << ")";
   return out;
 }
 
-inline std::ostream& operator<<(std::ostream & out, const ArgumentSpec & spec) {
+inline std::ostream& operator<<(std::ostream& out, const ArgumentSpec& spec) {
   out << "{";
-  for(size_t i = 0; i < spec.size(); ++i) {
+  for (size_t i = 0; i < spec.size(); ++i) {
     if (i > 0)
       out << ", ";
     out << spec.at(i);
@@ -339,21 +369,23 @@ inline std::ostream& operator<<(std::ostream & out, const ArgumentSpec & spec) {
   return out;
 }
 
-inline std::ostream & operator<<(std::ostream & out, const CompleteArgumentInfo & info) {
-  if(!info.defined()) {
+inline std::ostream& operator<<(
+    std::ostream& out,
+    const CompleteArgumentInfo& info) {
+  if (!info.defined()) {
     return out << "<undefined>";
   }
-  out << "Tensor(device=" << info.device()
-    << ", type=" << toString(info.type())
-    << ", requires_grad=" << info.requires_grad()
-    << ", sizes=" << info.sizes()
-    << ", strides=" << info.strides() << ")";
+  out << "Tensor(device=" << info.device() << ", type=" << toString(info.type())
+      << ", requires_grad=" << info.requires_grad()
+      << ", sizes=" << info.sizes() << ", strides=" << info.strides() << ")";
   return out;
 }
 
-inline std::ostream& operator<<(std::ostream & out, const CompleteArgumentSpec & spec) {
+inline std::ostream& operator<<(
+    std::ostream& out,
+    const CompleteArgumentSpec& spec) {
   out << "{";
-  for(size_t i = 0; i < spec.size(); ++i) {
+  for (size_t i = 0; i < spec.size(); ++i) {
     if (i > 0)
       out << ", ";
     out << spec.at(i);
@@ -374,19 +406,20 @@ inline void setInputTypes(Graph& g, const ArgumentSpec& spec) {
   }
 }
 
-}}
+} // namespace jit
+} // namespace torch
 
 namespace std {
-  template<>
-  struct hash<torch::jit::ArgumentSpec> {
-    size_t operator()(const torch::jit::ArgumentSpec & spec) const {
-      return spec.hashCode();
-    }
-  };
-  template<>
-  struct hash<torch::jit::CompleteArgumentSpec> {
-    size_t operator()(const torch::jit::CompleteArgumentSpec & spec) const {
-      return spec.hashCode();
-    }
-  };
-}
+template <>
+struct hash<torch::jit::ArgumentSpec> {
+  size_t operator()(const torch::jit::ArgumentSpec& spec) const {
+    return spec.hashCode();
+  }
+};
+template <>
+struct hash<torch::jit::CompleteArgumentSpec> {
+  size_t operator()(const torch::jit::CompleteArgumentSpec& spec) const {
+    return spec.hashCode();
+  }
+};
+} // namespace std
index b48fbf3..b84da83 100644 (file)
@@ -1,31 +1,29 @@
 #pragma once
-#include <vector>
+#include <ATen/ATen.h>
+#include <ATen/Utils.h>
 #include <cstdint>
-#include <string>
 #include <memory>
+#include <string>
 #include <vector>
-#include <ATen/ATen.h>
-#include <ATen/Utils.h>
 
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/interned_strings.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 constexpr int max_tensor_display_size = 10;
 
-enum class AttributeKind {
-  f,fs,i,is,s,ss,t,ts,g,gs
-};
-static inline const char * toString(AttributeKind kind) {
-  static const char* names[] = {"f","fs","i","is","s","ss","t","ts","g","gs"};
-  JIT_ASSERT(size_t(kind) < sizeof(names)/sizeof(AttributeKind));
+enum class AttributeKind { f, fs, i, is, s, ss, t, ts, g, gs };
+static inline const char* toString(AttributeKind kind) {
+  static const char* names[] = {
+      "f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs"};
+  JIT_ASSERT(size_t(kind) < sizeof(names) / sizeof(AttributeKind));
   return names[int(kind)];
 }
 
 struct AttributeValue {
-  AttributeValue(Symbol name)
-  : name(name) {}
+  AttributeValue(Symbol name) : name(name) {}
   using Ptr = std::unique_ptr<AttributeValue>;
   Symbol name;
   virtual AttributeKind kind() const = 0;
@@ -33,67 +31,78 @@ struct AttributeValue {
   virtual ~AttributeValue() = default;
 };
 
-template<typename T, AttributeKind Kind>
+template <typename T, AttributeKind Kind>
 struct ScalarAttributeValue : public AttributeValue {
   using ConstructorType = T;
   using ValueType = T;
   ScalarAttributeValue(Symbol name, ConstructorType value_)
-  : AttributeValue(name), value_(std::move(value_)) {}
-  ValueType & value() {
+      : AttributeValue(name), value_(std::move(value_)) {}
+  ValueType& value() {
     return value_;
   }
   Ptr clone() const override {
     return Ptr(new ScalarAttributeValue(name, value_));
   }
-  AttributeKind kind() const override { return Kind; }
-private:
+  AttributeKind kind() const override {
+    return Kind;
+  }
+
+ private:
   ValueType value_;
 };
 
-template<typename T, AttributeKind Kind>
+template <typename T, AttributeKind Kind>
 struct VectorAttributeValue : public AttributeValue {
   using ConstructorType = std::vector<T>;
   using ValueType = std::vector<T>;
   VectorAttributeValue(Symbol name, ConstructorType value_)
-  : AttributeValue(name), value_(std::move(value_)) {}
-  ValueType & value() {
+      : AttributeValue(name), value_(std::move(value_)) {}
+  ValueType& value() {
     return value_;
   }
-  AttributeKind kind() const override { return Kind; }
+  AttributeKind kind() const override {
+    return Kind;
+  }
   std::unique_ptr<AttributeValue> clone() const override {
     auto copy = value_;
     return Ptr(new VectorAttributeValue(name, std::move(copy)));
   }
-private:
+
+ private:
   ValueType value_;
 };
 
-using FloatAttr = ScalarAttributeValue<double,AttributeKind::f>;
-using FloatsAttr = VectorAttributeValue<double,AttributeKind::fs>;
-using IntAttr = ScalarAttributeValue<int64_t,AttributeKind::i>;
-using IntsAttr = VectorAttributeValue<int64_t,AttributeKind::is>;
-using StringAttr = ScalarAttributeValue<std::string,AttributeKind::s>;
-using StringsAttr = VectorAttributeValue<std::string,AttributeKind::ss>;
-using TensorAttr = ScalarAttributeValue<at::Tensor,AttributeKind::t>;
-using TensorsAttr = VectorAttributeValue<at::Tensor,AttributeKind::ts>;
+using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
+using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
+using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
+using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
+using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
+using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
+using TensorAttr = ScalarAttributeValue<at::Tensor, AttributeKind::t>;
+using TensorsAttr = VectorAttributeValue<at::Tensor, AttributeKind::ts>;
 struct Graph;
-using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>,AttributeKind::g>;
-using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>,AttributeKind::gs>;
+using GraphAttr =
+    ScalarAttributeValue<std::shared_ptr<Graph>, AttributeKind::g>;
+using GraphsAttr =
+    VectorAttributeValue<std::shared_ptr<Graph>, AttributeKind::gs>;
 
 struct AttributeError : public std::exception {
   AttributeError(Symbol name, bool defined) {
     std::stringstream ss;
-    if(!defined) {
-      ss << "required keyword attribute '" << name.toUnqualString() << "' is undefined.";
+    if (!defined) {
+      ss << "required keyword attribute '" << name.toUnqualString()
+         << "' is undefined.";
     } else {
-      ss << "required keyword attribute '" << name.toUnqualString() << "' has the wrong type";
+      ss << "required keyword attribute '" << name.toUnqualString()
+         << "' has the wrong type";
     }
     msg = ss.str();
   }
-  const char* what() const noexcept override  {
+  const char* what() const noexcept override {
     return msg.c_str();
   }
-private:
+
+ private:
   std::string msg;
 };
 
@@ -101,18 +110,18 @@ private:
 // method chaining e.g:
 // Node * n = g->create(kSelect)->i_(kOffset,3)->f_(kValue,3.5);
 // we return Derived* pointers because Nodes are normally held as pointers.
-template<typename Derived>
+template <typename Derived>
 struct Attributes {
   Attributes() = default;
-  void copyAttributes(const Attributes & rhs) {
+  void copyAttributes(const Attributes& rhs) {
     values_.clear();
-    for(auto & i : rhs.values_) {
+    for (auto& i : rhs.values_) {
       values_.push_back(i->clone());
     }
   }
   bool hasAttribute(Symbol name) const {
     JIT_ASSERT(name.is_attr());
-    return find(name,false) != values_.end();
+    return find(name, false) != values_.end();
   }
   // We want direct string accessors, as it is nicer to use than
   // hasAttribute(Symbol::attr("blah"))
@@ -127,14 +136,14 @@ struct Attributes {
   }
   AttributeKind kindOf(Symbol name) const {
     JIT_ASSERT(name.is_attr());
-    return (*find(name,true))->kind();
+    return (*find(name, true))->kind();
   }
   AttributeKind kindOfS(const std::string& name) const {
     return kindOf(Symbol::attr(name));
   }
   Derived* removeAttribute(Symbol name) {
     JIT_ASSERT(name.is_attr());
-    values_.erase(find(name,true));
+    values_.erase(find(name, true));
     return This();
   }
   Derived* removeAttributeS(const std::string& name) {
@@ -149,35 +158,36 @@ struct Attributes {
   // The names are returned in order, since name actually is the index.
   std::vector<Symbol> attributeNames() const {
     std::vector<Symbol> names;
-    for(auto & a : values_)
+    for (auto& a : values_)
       names.push_back(a->name);
     return names;
   }
   std::vector<const char*> attributeNamesS() const {
     std::vector<const char*> names;
-    for(auto & a : values_)
+    for (auto& a : values_)
       names.push_back(a->name.toUnqualString());
     return names;
   }
 
-  #define CREATE_ACCESSOR(Kind, method) \
+#define CREATE_ACCESSOR(Kind, method)                              \
   Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
-    return set<Kind##Attr>(name,std::forward<Kind##Attr::ConstructorType>(v)); \
-  } \
-  const Kind##Attr::ValueType& method(Symbol name) const { \
-    return get<Kind##Attr>(name); \
+    return set<Kind##Attr>(                                        \
+        name, std::forward<Kind##Attr::ConstructorType>(v));       \
+  }                                                                \
+  const Kind##Attr::ValueType& method(Symbol name) const {         \
+    return get<Kind##Attr>(name);                                  \
   }
 
-  CREATE_ACCESSOR(Float,f)
-  CREATE_ACCESSOR(Floats,fs)
-  CREATE_ACCESSOR(String,s)
-  CREATE_ACCESSOR(Strings,ss)
-  CREATE_ACCESSOR(Int,i)
-  CREATE_ACCESSOR(Ints,is)
-  CREATE_ACCESSOR(Graph,g)
-  CREATE_ACCESSOR(Graphs,gs)
+  CREATE_ACCESSOR(Float, f)
+  CREATE_ACCESSOR(Floats, fs)
+  CREATE_ACCESSOR(String, s)
+  CREATE_ACCESSOR(Strings, ss)
+  CREATE_ACCESSOR(Int, i)
+  CREATE_ACCESSOR(Ints, is)
+  CREATE_ACCESSOR(Graph, g)
+  CREATE_ACCESSOR(Graphs, gs)
 
-  #undef CREATE_ACCESSOR
+#undef CREATE_ACCESSOR
 
   // Our Graphs are not very const-correct, so we need to allow returning
   // non-const references too
@@ -188,28 +198,29 @@ struct Attributes {
   // does not use CREATE_ACCESSOR because we need additional asserts
   Derived* t_(Symbol name, TensorAttr::ConstructorType v) {
     JIT_ASSERT(!v.defined() || !v.is_variable());
-    return set<TensorAttr>(name,std::forward<TensorAttr::ConstructorType>(v));
+    return set<TensorAttr>(name, std::forward<TensorAttr::ConstructorType>(v));
   }
   const TensorAttr::ValueType& t(Symbol name) const {
     return get<TensorAttr>(name);
   }
 
   Derived* ts_(Symbol name, TensorsAttr::ConstructorType v) {
-    for(auto & t : v) {
+    for (auto& t : v) {
       JIT_ASSERT(!t.defined() || !t.is_variable());
     }
-    return set<TensorsAttr>(name,std::forward<TensorsAttr::ConstructorType>(v));
+    return set<TensorsAttr>(
+        name, std::forward<TensorsAttr::ConstructorType>(v));
   }
   const TensorsAttr::ValueType& ts(Symbol name) const {
     return get<TensorsAttr>(name);
   }
 
-  template<typename T>
-  static void printPrimList(std::ostream & out, const std::vector<T> & items) {
+  template <typename T>
+  static void printPrimList(std::ostream& out, const std::vector<T>& items) {
     out << "[";
     int i = 0;
-    for(auto & item : items) {
-      if(i++ > 0)
+    for (auto& item : items) {
+      if (i++ > 0)
         out << ", ";
       out << item;
     }
@@ -221,7 +232,7 @@ struct Attributes {
     std::vector<std::string> replace = {"\\n", "\\t", "\\v"};
     for (size_t i = 0; i < search.size(); i++) {
       size_t pos = s.find(search[i]);
-      while(pos != std::string::npos) {
+      while (pos != std::string::npos) {
         s.replace(pos, 1, replace[i]);
         pos = s.find(search[i], pos + 1);
       }
@@ -229,8 +240,8 @@ struct Attributes {
     return s;
   }
 
-  void printValue(std::ostream & out, const Symbol & name) const {
-    switch(kindOf(name)) {
+  void printValue(std::ostream& out, const Symbol& name) const {
+    switch (kindOf(name)) {
       case AttributeKind::f:
         out << f(name);
         break;
@@ -247,34 +258,33 @@ struct Attributes {
         out << "\"" << escapeString(s(name)) << "\"";
         break;
       case AttributeKind::ss:
-        printPrimList(out,ss(name));
+        printPrimList(out, ss(name));
         break;
-      case AttributeKind::t:
-        {
-          at::Tensor tensor = t(name);
-          // 1-elem tensors are usually boxed scalars, so print them like it
-          if (tensor.numel() == 1) {
-            auto scalar_tensor = tensor.view({}).item();
-            out << "{";
-            if (scalar_tensor.isFloatingPoint()) {
-              out << scalar_tensor.toDouble();
-            } else {
-              out << scalar_tensor.toLong();
-            }
-            out << "}";
-          } else if (tensor.numel() <= max_tensor_display_size) {
-            // TODO: This is awful code.  Also it doesn't work on Windows.
-            std::ostringstream tensor_ss;
-            tensor_ss << tensor;
-            std::string tensor_s{tensor_ss.str()};
-            // Remove newlines
-            std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
-            out << tensor_s;
+      case AttributeKind::t: {
+        at::Tensor tensor = t(name);
+        // 1-elem tensors are usually boxed scalars, so print them like it
+        if (tensor.numel() == 1) {
+          auto scalar_tensor = tensor.view({}).item();
+          out << "{";
+          if (scalar_tensor.isFloatingPoint()) {
+            out << scalar_tensor.toDouble();
           } else {
-            out << "<Tensor>";
+            out << scalar_tensor.toLong();
           }
-          break;
+          out << "}";
+        } else if (tensor.numel() <= max_tensor_display_size) {
+          // TODO: This is awful code.  Also it doesn't work on Windows.
+          std::ostringstream tensor_ss;
+          tensor_ss << tensor;
+          std::string tensor_s{tensor_ss.str()};
+          // Remove newlines
+          std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
+          out << tensor_s;
+        } else {
+          out << "<Tensor>";
         }
+        break;
+      }
       case AttributeKind::ts:
         out << "[<Tensors>]";
         break;
@@ -287,29 +297,29 @@ struct Attributes {
     }
   }
 
-private:
+ private:
   // UBSAN error: https://github.com/pytorch/pytorch/issues/9055
   Derived* This() __ubsan_ignore_vptr__ {
     return static_cast<Derived*>(this);
   }
-  template<typename T>
+  template <typename T>
   Derived* set(Symbol name, typename T::ConstructorType v) {
     JIT_ASSERT(name.is_attr());
     auto it = find(name, false);
     auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
-    if(it == values_.end()) {
+    if (it == values_.end()) {
       values_.push_back(std::move(nv));
     } else {
       *it = std::move(nv);
     }
     return This();
   }
-  template<typename T>
-  typename T::ValueType & get(Symbol name) const {
+  template <typename T>
+  typename T::ValueType& get(Symbol name) const {
     JIT_ASSERT(name.is_attr());
     auto it = find(name, true);
     auto* child = dynamic_cast<T*>(it->get());
-    if(child == nullptr) {
+    if (child == nullptr) {
       throw AttributeError(name, true);
     }
     return child->value();
@@ -322,10 +332,10 @@ private:
   using iterator = std::vector<AVPtr>::iterator;
   iterator find(Symbol name, bool required) {
     JIT_ASSERT(name.is_attr());
-    auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) {
+    auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
       return v->name == name;
     });
-    if(required && it == values_.end()) {
+    if (required && it == values_.end()) {
       throw AttributeError(name, false);
     }
     JIT_ASSERT(!required || it != values_.end());
@@ -334,10 +344,10 @@ private:
   using const_iterator = std::vector<AVPtr>::const_iterator;
   const_iterator find(Symbol name, bool required) const {
     JIT_ASSERT(name.is_attr());
-    auto it = std::find_if(values_.begin(), values_.end(),[&](const AVPtr & v) {
+    auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
       return v->name == name;
     });
-    if(required && it == values_.end()) {
+    if (required && it == values_.end()) {
       throw AttributeError(name, false);
     }
     JIT_ASSERT(!required || it != values_.end());
@@ -345,4 +355,5 @@ private:
   }
 };
 
-}}
+} // namespace jit
+} // namespace torch
index 686a0b9..5bec04b 100644 (file)
 #include <torch/csrc/jit/autodiff.h>
 
-#include "torch/csrc/jit/passes/lower_tuples.h"
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
-#include "torch/csrc/jit/symbolic_script.h"
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/symbolic_variable.h>
-#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/utils/functional.h>
+#include "torch/csrc/jit/passes/lower_tuples.h"
 #include "torch/csrc/jit/script/compiler.h"
+#include "torch/csrc/jit/symbolic_script.h"
 
 #include <torch/csrc/jit/assertions.h>
 
 #include <algorithm>
 #include <memory>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 using value_map = std::unordered_map<Value*, Value*>;
 using value_set = std::unordered_set<Value*>;
 
-void wrapDim(int64_t & dim, const std::vector<int64_t> & sizes) {
+void wrapDim(int64_t& dim, const std::vector<int64_t>& sizes) {
   if (dim < 0) {
     dim += sizes.size();
   }
 }
 
-bool isDifferentiable(Node * n) {
+bool isDifferentiable(Node* n) {
   // TODO: scalar-tensor ops should be canonicalized
   static OperatorSet differentiable_ops = {
-    "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
-    "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
-    "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
-    "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
-    "aten::mul(Tensor self, Tensor other) -> Tensor",
-    "aten::mul(Tensor self, Scalar other) -> Tensor",
-    "aten::div(Tensor self, Tensor other) -> Tensor",
-    "aten::div(Tensor self, Scalar other) -> Tensor",
-    "aten::max(Tensor self, Tensor other) -> Tensor",
-    "aten::min(Tensor self, Tensor other) -> Tensor",
-    "aten::sigmoid(Tensor self) -> Tensor",
-    "aten::tanh(Tensor self) -> Tensor",
-    "aten::relu(Tensor self) -> Tensor",
-    "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
-    "aten::erf(Tensor self) -> Tensor",
-    "aten::erfc(Tensor self) -> Tensor",
-    "aten::exp(Tensor self) -> Tensor",
-    "aten::t(Tensor self) -> Tensor",
-    "aten::neg(Tensor self) -> Tensor",
-    "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
-    "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
-    "aten::type_as(Tensor self, Tensor other) -> Tensor",
-    "aten::unsqueeze(Tensor self, int dim) -> Tensor",
-    "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
-    "aten::mm(Tensor self, Tensor mat2) -> Tensor",
-    "aten::lt(Tensor self, Tensor other) -> Tensor",
-    "aten::le(Tensor self, Tensor other) -> Tensor",
-    "aten::gt(Tensor self, Tensor other) -> Tensor",
-    "aten::ge(Tensor self, Tensor other) -> Tensor",
-    "aten::eq(Tensor self, Tensor other) -> Tensor",
-    "aten::ne(Tensor self, Tensor other) -> Tensor",
-    "aten::lt(Tensor self, Scalar other) -> Tensor",
-    "aten::le(Tensor self, Scalar other) -> Tensor",
-    "aten::gt(Tensor self, Scalar other) -> Tensor",
-    "aten::ge(Tensor self, Scalar other) -> Tensor",
-    "aten::eq(Tensor self, Scalar other) -> Tensor",
-    "aten::ne(Tensor self, Scalar other) -> Tensor",
-    "aten::abs(Tensor self) -> Tensor",
-    "aten::acos(Tensor self) -> Tensor",
-    "aten::asin(Tensor self) -> Tensor",
-    "aten::atan(Tensor self) -> Tensor",
-    "aten::ceil(Tensor self) -> Tensor",
-    "aten::cos(Tensor self) -> Tensor",
-    "aten::cosh(Tensor self) -> Tensor",
-    "aten::exp(Tensor self) -> Tensor",
-    "aten::expm1(Tensor self) -> Tensor",
-    "aten::floor(Tensor self) -> Tensor",
-    "aten::fmod(Tensor self, Scalar other) -> Tensor",
-    "aten::frac(Tensor self) -> Tensor",
-    "aten::log(Tensor self) -> Tensor",
-    "aten::log10(Tensor self) -> Tensor",
-    "aten::log1p(Tensor self) -> Tensor",
-    "aten::log2(Tensor self) -> Tensor",
-    "aten::reciprocal(Tensor self) -> Tensor",
-    "aten::remainder(Tensor self, Scalar other) -> Tensor",
-    "aten::round(Tensor self) -> Tensor",
-    "aten::rsqrt(Tensor self) -> Tensor",
-    "aten::sin(Tensor self) -> Tensor",
-    "aten::sinh(Tensor self) -> Tensor",
-    "aten::tan(Tensor self) -> Tensor",
-    "aten::trunc(Tensor self) -> Tensor",
-    "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)",
-    "aten::log_softmax(Tensor self, int dim) -> Tensor",
-    "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
-    "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
-    "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
-    "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
+      "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+      "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+      "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+      "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+      "aten::mul(Tensor self, Tensor other) -> Tensor",
+      "aten::mul(Tensor self, Scalar other) -> Tensor",
+      "aten::div(Tensor self, Tensor other) -> Tensor",
+      "aten::div(Tensor self, Scalar other) -> Tensor",
+      "aten::max(Tensor self, Tensor other) -> Tensor",
+      "aten::min(Tensor self, Tensor other) -> Tensor",
+      "aten::sigmoid(Tensor self) -> Tensor",
+      "aten::tanh(Tensor self) -> Tensor",
+      "aten::relu(Tensor self) -> Tensor",
+      "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
+      "aten::erf(Tensor self) -> Tensor",
+      "aten::erfc(Tensor self) -> Tensor",
+      "aten::exp(Tensor self) -> Tensor",
+      "aten::t(Tensor self) -> Tensor",
+      "aten::neg(Tensor self) -> Tensor",
+      "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
+      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
+      "aten::type_as(Tensor self, Tensor other) -> Tensor",
+      "aten::unsqueeze(Tensor self, int dim) -> Tensor",
+      "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
+      "aten::mm(Tensor self, Tensor mat2) -> Tensor",
+      "aten::lt(Tensor self, Tensor other) -> Tensor",
+      "aten::le(Tensor self, Tensor other) -> Tensor",
+      "aten::gt(Tensor self, Tensor other) -> Tensor",
+      "aten::ge(Tensor self, Tensor other) -> Tensor",
+      "aten::eq(Tensor self, Tensor other) -> Tensor",
+      "aten::ne(Tensor self, Tensor other) -> Tensor",
+      "aten::lt(Tensor self, Scalar other) -> Tensor",
+      "aten::le(Tensor self, Scalar other) -> Tensor",
+      "aten::gt(Tensor self, Scalar other) -> Tensor",
+      "aten::ge(Tensor self, Scalar other) -> Tensor",
+      "aten::eq(Tensor self, Scalar other) -> Tensor",
+      "aten::ne(Tensor self, Scalar other) -> Tensor",
+      "aten::abs(Tensor self) -> Tensor",
+      "aten::acos(Tensor self) -> Tensor",
+      "aten::asin(Tensor self) -> Tensor",
+      "aten::atan(Tensor self) -> Tensor",
+      "aten::ceil(Tensor self) -> Tensor",
+      "aten::cos(Tensor self) -> Tensor",
+      "aten::cosh(Tensor self) -> Tensor",
+      "aten::exp(Tensor self) -> Tensor",
+      "aten::expm1(Tensor self) -> Tensor",
+      "aten::floor(Tensor self) -> Tensor",
+      "aten::fmod(Tensor self, Scalar other) -> Tensor",
+      "aten::frac(Tensor self) -> Tensor",
+      "aten::log(Tensor self) -> Tensor",
+      "aten::log10(Tensor self) -> Tensor",
+      "aten::log1p(Tensor self) -> Tensor",
+      "aten::log2(Tensor self) -> Tensor",
+      "aten::reciprocal(Tensor self) -> Tensor",
+      "aten::remainder(Tensor self, Scalar other) -> Tensor",
+      "aten::round(Tensor self) -> Tensor",
+      "aten::rsqrt(Tensor self) -> Tensor",
+      "aten::sin(Tensor self) -> Tensor",
+      "aten::sinh(Tensor self) -> Tensor",
+      "aten::tan(Tensor self) -> Tensor",
+      "aten::trunc(Tensor self) -> Tensor",
+      "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)",
+      "aten::log_softmax(Tensor self, int dim) -> Tensor",
+      "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
+      "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)",
+      "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)",
+      "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
   };
 
   // TODO: add support for the following fusible operators.
-  // They're a little tricky to implement; max/min require mutability for best perf
-  // "aten::atan2(Tensor self) -> Tensor",
-  // "aten::max(Tensor self) -> Tensor",
-  // "aten::min(Tensor self) -> Tensor"
-
-  if (n->kind() == prim::Constant ||
-      n->kind() == prim::Undefined ||
-      n->kind() == prim::AutogradAdd ||
-      n->kind() == prim::ConstantChunk ||
+  // They're a little tricky to implement; max/min require mutability for best
+  // perf "aten::atan2(Tensor self) -> Tensor", "aten::max(Tensor self) ->
+  // Tensor", "aten::min(Tensor self) -> Tensor"
+
+  if (n->kind() == prim::Constant || n->kind() == prim::Undefined ||
+      n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk ||
       n->kind() == prim::None)
     return true;
   if (differentiable_ops.find(n))
@@ -118,15 +116,18 @@ bool isDifferentiable(Node * n) {
     return true;
   }
 
-  if (n->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
-    return n->get<std::vector<int64_t>>(attr::size) && n->is_constant(attr::implicit) &&
-      n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
+  if (n->matches(
+          "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
+    return n->get<std::vector<int64_t>>(attr::size) &&
+        n->is_constant(attr::implicit) &&
+        n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
   }
   if (n->matches("aten::view(Tensor self, int[] size) -> Tensor")) {
     return n->get<std::vector<int64_t>>(attr::size) &&
-      n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
+        n->namedInput(attr::self)->type()->cast<CompleteTensorType>();
   }
-  if (n->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
+  if (n->matches(
+          "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
     // TODO(asuhan): support weight
     return n->namedInput(attr::weight)->node()->kind() == prim::Undefined;
   }
@@ -144,10 +145,11 @@ bool isDifferentiable(Node * n) {
   return false;
 }
 
-
-bool isDifferentiable(Graph & g) {
-  return std::all_of(g.nodes().begin(), g.nodes().end(),
-                     static_cast<bool(*)(Node*)>(isDifferentiable));
+bool isDifferentiable(Graph& g) {
+  return std::all_of(
+      g.nodes().begin(),
+      g.nodes().end(),
+      static_cast<bool (*)(Node*)>(isDifferentiable));
 }
 
 // NB: Write gradient using torchscript
@@ -160,22 +162,25 @@ bool isDifferentiable(Graph & g) {
 //
 // Here ctx is a tuple that carries all input/intermediate results needed in
 // backward from forward pass.
-// This python code is compiled into a GradientPair which includes a forward graph
-// and a backward graph. Forward graph will be used to replace the node in grad_desc.f,
-// and backward graph will be used to construct GradOf(node) in reverse_block.
-// Grad_values(a.k.a gradOutputs) propagated through node->owningGraph() in
-// **reversed** order, thus GradientPair.forward ahould be inserted **after**
-// the node being replaced, so that we don't traverse the graph infinite times.
+//
+// This python code is compiled into a GradientPair which includes a forward
+// graph and a backward graph. Forward graph will be used to replace the node in
+// grad_desc.f, and backward graph will be used to construct GradOf(node) in
+// reverse_block. Grad_values(a.k.a gradOutputs) propagated through
+// node->owningGraph() in **reversed** order, thus GradientPair.forward ahould
+// be inserted **after** the node being replaced, so that we don't traverse the
+// graph infinite times.
+//
 // The output of compiled forward graph is [real_outputs, ctx]
 // The input of compiled backward graph is [ctx, grad_values]
-// We run LowerSimpleTuples afterwards to elmininate all tuples generated in this process.
-// The original node and TupleConstruct nodes in forward graph will be cleaned up
-// later using EliminateDeadCode(block).
-// TupleUnPack node in backward graph will be removed in eliminateDeadcode(ReverseDetails)
-// defined in this file.
+// We run LowerSimpleTuples afterwards to elmininate all tuples generated in
+// this process. The original node and TupleConstruct nodes in forward graph
+// will be cleaned up later using EliminateDeadCode(block). TupleUnPack node in
+// backward graph will be removed in eliminateDeadcode(ReverseDetails) defined
+// in this file.
 static c10::optional<std::vector<Value*>> build_script_grad(
-        Node* node,
-        const ArrayRef<Value*>& grads) {
+    Node* node,
+    const ArrayRef<Value*>& grads) {
   auto graph = node->owningGraph();
 
   auto compiled_graphs = gradientInfoForSchema(node->schema());
@@ -187,7 +192,8 @@ static c10::optional<std::vector<Value*>> build_script_grad(
   {
     WithInsertPoint guard(node->next());
     auto fw_graph = compiled_graphs->forward;
-    new_outputs = inlineCallTo(*graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
+    new_outputs = inlineCallTo(
+        *graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
     for (size_t i = 0; i < node->outputs().size(); ++i) {
       new_outputs.at(i)->setType(node->outputs()[i]->type());
       new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]);
@@ -200,83 +206,141 @@ static c10::optional<std::vector<Value*>> build_script_grad(
   auto it = grad_vec.begin();
   grad_vec.insert(it, new_outputs.back());
   ArrayRef<Value*> grad(grad_vec);
-  auto grad_inputs = inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true);
+  auto grad_inputs =
+      inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true);
   return grad_inputs;
 };
 
-static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_values) {
-  static const OperatorSet comparison_ops = {
-    "aten::lt(Tensor self, Tensor other) -> Tensor",
-    "aten::le(Tensor self, Tensor other) -> Tensor",
-    "aten::gt(Tensor self, Tensor other) -> Tensor",
-    "aten::ge(Tensor self, Tensor other) -> Tensor",
-    "aten::eq(Tensor self, Tensor other) -> Tensor",
-    "aten::ne(Tensor self, Tensor other) -> Tensor",
-    "aten::lt(Tensor self, Scalar other) -> Tensor",
-    "aten::le(Tensor self, Scalar other) -> Tensor",
-    "aten::gt(Tensor self, Scalar other) -> Tensor",
-    "aten::ge(Tensor self, Scalar other) -> Tensor",
-    "aten::eq(Tensor self, Scalar other) -> Tensor",
-    "aten::ne(Tensor self, Scalar other) -> Tensor",
-  };
-  const auto sumToSizeOf = [node](SymbolicVariable v, Symbol input_name) -> SymbolicVariable {
-    Value * size;
+namespace {
+class GradientHelper {
+ public:
+  GradientHelper(Node* n) : node(n) {}
+
+  std::vector<Value*> gradient(ArrayRef<Value*> grad_values) {
+    if (!isDifferentiable(node)) {
+      throw std::runtime_error(
+          std::string("differentiation of ") + node->kind().toDisplayString() +
+          " is not supported, or it is missing necessary type information");
+    }
+    // If AD is defined using torchscript, use it instead of symbolic
+    auto script_grads = build_script_grad(node, grad_values);
+    if (script_grads)
+      return *script_grads;
+    // Definition not found in torchscript, look up in the buildSymbolicGradient
+    // TODO: migrate all to using torchscript
+    auto sym_grads = buildSymbolicGradient(fmap<SymbolicVariable>(grad_values));
+    return fmap(sym_grads, [](const SymbolicVariable& v) { return v.value(); });
+  }
+
+ private:
+  Node* node;
+
+  SymbolicVariable sumToSizeOf(SymbolicVariable v, Symbol input_name) {
+    Value* size;
     {
-      WithInsertPoint insert_guard {node};
+      WithInsertPoint insert_guard{node};
       size = SymbolicVariable(node->namedInput(input_name)).size();
     }
     return v.sumToSize(size);
   };
-  const auto build_sym_grad = [node, &sumToSizeOf](const std::vector<SymbolicVariable>& grads) -> std::vector<SymbolicVariable> {
+
+  std::vector<SymbolicVariable> buildSymbolicGradient(
+      const std::vector<SymbolicVariable>& grads) {
+    static const OperatorSet comparison_ops = {
+        "aten::lt(Tensor self, Tensor other) -> Tensor",
+        "aten::le(Tensor self, Tensor other) -> Tensor",
+        "aten::gt(Tensor self, Tensor other) -> Tensor",
+        "aten::ge(Tensor self, Tensor other) -> Tensor",
+        "aten::eq(Tensor self, Tensor other) -> Tensor",
+        "aten::ne(Tensor self, Tensor other) -> Tensor",
+        "aten::lt(Tensor self, Scalar other) -> Tensor",
+        "aten::le(Tensor self, Scalar other) -> Tensor",
+        "aten::gt(Tensor self, Scalar other) -> Tensor",
+        "aten::ge(Tensor self, Scalar other) -> Tensor",
+        "aten::eq(Tensor self, Scalar other) -> Tensor",
+        "aten::ne(Tensor self, Scalar other) -> Tensor",
+    };
     auto inputs = fmap<SymbolicVariable>(node->inputs());
     auto outputs = fmap<SymbolicVariable>(node->outputs());
 
-    if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
-      return {sumToSizeOf(grads.at(0), attr::self),
-              sumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other),
-              nullptr};
+    if (node->matches(
+            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
+      return {
+          sumToSizeOf(grads.at(0), attr::self),
+          sumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other),
+          nullptr};
 
-    } else if (node->matches("aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
+    } else if (
+        node->matches(
+            "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
       return {grads.at(0), nullptr, nullptr};
 
     } else if (node->kind() == prim::AutogradAdd) {
       // NB: AutogradAdds don't broadcast
       return {grads.at(0), grads.at(0)};
 
-    } else if (node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
+    } else if (
+        node->matches(
+            "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
       return {sumToSizeOf(grads.at(0), attr::self),
-              sumToSizeOf(-grads.at(0) * node->namedInput(attr::alpha), attr::other),
+              sumToSizeOf(
+                  -grads.at(0) * node->namedInput(attr::alpha), attr::other),
               nullptr};
 
-    } else if (node->matches("aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
+    } else if (
+        node->matches(
+            "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor")) {
       return {grads.at(0), nullptr, nullptr};
 
-    } else if (node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
+    } else if (node->matches(
+                   "aten::mul(Tensor self, Tensor other) -> Tensor")) {
       return {sumToSizeOf(grads.at(0) * inputs.at(1), attr::self),
               sumToSizeOf(grads.at(0) * inputs.at(0), attr::other)};
 
-    } else if (node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) {
+    } else if (node->matches(
+                   "aten::mul(Tensor self, Scalar other) -> Tensor")) {
       return {grads.at(0) * inputs.at(1), nullptr};
 
-    } else if (node->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
+    } else if (node->matches(
+                   "aten::div(Tensor self, Tensor other) -> Tensor")) {
       return {sumToSizeOf(grads.at(0) / inputs.at(1), attr::self),
-              sumToSizeOf(-grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)), attr::other)};
+              sumToSizeOf(
+                  -grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1)),
+                  attr::other)};
 
-    } else if (node->matches("aten::div(Tensor self, Scalar other) -> Tensor")) {
+    } else if (node->matches(
+                   "aten::div(Tensor self, Scalar other) -> Tensor")) {
       return {grads.at(0) / inputs.at(1), nullptr};
 
-    } else if (node->matches("aten::max(Tensor self, Tensor other) -> Tensor")) {
-      return {sumToSizeOf(grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)), attr::self),
-              sumToSizeOf(grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)), attr::other)};
-
-    } else if (node->matches("aten::min(Tensor self, Tensor other) -> Tensor")) {
-      return {sumToSizeOf(grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)), attr::self),
-              sumToSizeOf(grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)), attr::other)};
-
-    } else if (node->matches("aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) {
+    } else if (node->matches(
+                   "aten::max(Tensor self, Tensor other) -> Tensor")) {
+      return {
+          sumToSizeOf(
+              grads.at(0) * (inputs.at(0) > inputs.at(1)).type_as(grads.at(0)),
+              attr::self),
+          sumToSizeOf(
+              grads.at(0) * (inputs.at(1) > inputs.at(0)).type_as(grads.at(0)),
+              attr::other)};
+
+    } else if (node->matches(
+                   "aten::min(Tensor self, Tensor other) -> Tensor")) {
+      return {
+          sumToSizeOf(
+              grads.at(0) * (inputs.at(0) < inputs.at(1)).type_as(grads.at(0)),
+              attr::self),
+          sumToSizeOf(
+              grads.at(0) * (inputs.at(1) < inputs.at(0)).type_as(grads.at(0)),
+              attr::other)};
+
+    } else if (
+        node->matches(
+            "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor")) {
       return {nullptr,
-              sumToSizeOf(grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self),
-              sumToSizeOf(grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)), attr::other)};
+              sumToSizeOf(
+                  grads.at(0) * inputs.at(0).type_as(grads.at(0)), attr::self),
+              sumToSizeOf(
+                  grads.at(0) * (1 - inputs.at(0)).type_as(grads.at(0)),
+                  attr::other)};
 
     } else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) {
       // TODO: The order of operations matter in this case. This
@@ -288,39 +352,55 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
       return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))};
 
     } else if (node->matches("aten::relu(Tensor self) -> Tensor")) {
-      return {grads.at(0) * (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))};
+      return {grads.at(0) *
+              (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))};
 
-    } else if (node->matches("aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor")) {
+    } else if (
+        node->matches(
+            "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor")) {
       // handle the case that min/max is None
       Value* min = inputs.at(1);
       bool min_must_be_none = min->node()->kind() == prim::None;
       Value* max = inputs.at(2);
       bool max_must_be_none = max->node()->kind() == prim::None;
       // XXX - this formula is wrong when min or max are not stricly prim::None
-      // but may be None dynamically. In this case an internal compiler error will
-      // get thrown when trying to generate expressions involving the values of min/max
+      // but may be None dynamically. In this case an internal compiler error
+      // will get thrown when trying to generate expressions involving the
+      // values of min/max
       if (!min_must_be_none && !max_must_be_none) {
-        return {grads.at(0)
-          * (1-(inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0)))
-          * (1-(inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))), nullptr, nullptr};
+        return {grads.at(0) *
+                    (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))) *
+                    (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))),
+                nullptr,
+                nullptr};
       } else if (max_must_be_none) {
-        return {grads.at(0)
-          * (1-(inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))), nullptr, nullptr};
+        return {grads.at(0) *
+                    (1 - (inputs.at(0) <= inputs.at(1)).type_as(inputs.at(0))),
+                nullptr,
+                nullptr};
       } else if (min_must_be_none) {
-        return {grads.at(0)
-          * (1-(inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))), nullptr, nullptr};
+        return {grads.at(0) *
+                    (1 - (inputs.at(0) >= inputs.at(2)).type_as(inputs.at(0))),
+                nullptr,
+                nullptr};
       } else {
         return {grads.at(0), nullptr, nullptr};
       }
-    } else if (node->matches("aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor")) {
+    } else if (
+        node->matches(
+            "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor")) {
       auto threshold = node->get<at::Scalar>(attr::threshold).value();
-      return {grads.at(0) * (inputs.at(0) > threshold).type_as(outputs.at(0)), nullptr, nullptr};
+      return {grads.at(0) * (inputs.at(0) > threshold).type_as(outputs.at(0)),
+              nullptr,
+              nullptr};
 
     } else if (node->matches("aten::erf(Tensor self) -> Tensor")) {
-      return {grads.at(0) * 1.12837916709551 * (-inputs.at(0) * inputs.at(0)).exp()};
+      return {grads.at(0) * 1.12837916709551 *
+              (-inputs.at(0) * inputs.at(0)).exp()};
 
     } else if (node->matches("aten::erfc(Tensor self) -> Tensor")) {
-      return {-grads.at(0) * 1.12837916709551 * (-inputs.at(0) * inputs.at(0)).exp()};
+      return {-grads.at(0) * 1.12837916709551 *
+              (-inputs.at(0) * inputs.at(0)).exp()};
 
     } else if (node->matches("aten::exp(Tensor self) -> Tensor")) {
       return {grads.at(0) * (outputs.at(0))};
@@ -335,18 +415,22 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
       return {grads.at(0) * inputs.at(0).sign()};
 
     } else if (node->matches("aten::acos(Tensor self) -> Tensor")) {
-      return {grads.at(0) * -((-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt())};
+      return {grads.at(0) *
+              -((-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt())};
 
     } else if (node->matches("aten::asin(Tensor self) -> Tensor")) {
-      return {grads.at(0) * (-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt()};
+      return {grads.at(0) *
+              (-inputs.at(0) * inputs.at(0) + at::Scalar(1)).rsqrt()};
 
     } else if (node->matches("aten::atan(Tensor self) -> Tensor")) {
       return {grads.at(0) / (inputs.at(0) * inputs.at(0) + at::Scalar(1))};
 
-    } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
-      Value * self_size;
+    } else if (
+        node->matches(
+            "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+      Value* self_size;
       {
-        WithInsertPoint insert_guard { node };
+        WithInsertPoint insert_guard{node};
         self_size = inputs.at(0).size();
       }
       return {grads.at(0).expand(self_size), nullptr};
@@ -369,7 +453,8 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     } else if (node->matches("aten::floor(Tensor self) -> Tensor")) {
       return {SymbolicVariable::zeros_like(grads.at(0))};
 
-    } else if (node->matches("aten::fmod(Tensor self, Scalar other) -> Tensor")) {
+    } else if (node->matches(
+                   "aten::fmod(Tensor self, Scalar other) -> Tensor")) {
       return {grads.at(0), nullptr};
 
     } else if (node->matches("aten::frac(Tensor self) -> Tensor")) {
@@ -390,7 +475,8 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     } else if (node->matches("aten::reciprocal(Tensor self) -> Tensor")) {
       return {-grads.at(0) * outputs.at(0) * outputs.at(0)};
 
-    } else if (node->matches("aten::remainder(Tensor self, Scalar other) -> Tensor")) {
+    } else if (node->matches(
+                   "aten::remainder(Tensor self, Scalar other) -> Tensor")) {
       return {grads.at(0), nullptr};
 
     } else if (node->matches("aten::round(Tensor self) -> Tensor")) {
@@ -414,28 +500,42 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     } else if (node->kind() == prim::ConstantChunk) {
       return {SymbolicVariable::cat(grads, node->i(attr::dim))};
 
-    } else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
-               node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) {
-      // TODO: if sizes are not available statically, add an operator that reutrns them as a tuple
-      auto sizes = node->namedInput(attr::self)->type()->expect<CompleteTensorType>()->sizes();
+    } else if (
+        node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
+        node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) {
+      // TODO: if sizes are not available statically, add an operator that
+      // reutrns them as a tuple
+      auto sizes = node->namedInput(attr::self)
+                       ->type()
+                       ->expect<CompleteTensorType>()
+                       ->sizes();
       return {grads.at(0).reshape(sizes), nullptr};
 
-    } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
+    } else if (node->matches(
+                   "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
       return {grads.at(0).type_as(inputs.at(0)), nullptr};
 
-    } else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
+    } else if (node->matches(
+                   "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
       return {grads.at(0).squeeze(node->namedInput(attr::dim)), nullptr};
 
-    } else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
-      return {sumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self),
-              grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
-              inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
-              nullptr, nullptr};
+    } else if (
+        node->matches(
+            "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
+      return {
+          sumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self),
+          grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
+          inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
+          nullptr,
+          nullptr};
 
     } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
-      return {grads.at(0).mm(inputs.at(1).t()), inputs.at(0).t().mm(grads.at(0))};
+      return {grads.at(0).mm(inputs.at(1).t()),
+              inputs.at(0).t().mm(grads.at(0))};
 
-    } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
+    } else if (
+        node->matches(
+            "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
       const auto& input_sizes = inputs.at(0).sizes();
       if (input_sizes.size() == 0)
         return {grads.at(0).sum(), nullptr, nullptr};
@@ -456,7 +556,8 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
       const auto& sizes = inputs.at(0).sizes();
       std::vector<size_t> squeezed_dims;
       for (size_t i = 0; i < sizes.size(); ++i) {
-        if (sizes[i] != 1) continue;
+        if (sizes[i] != 1)
+          continue;
         squeezed_dims.push_back(i);
       }
       SymbolicVariable returned_grad = grads.at(0);
@@ -465,18 +566,24 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
       }
       return {returned_grad};
 
-    } else if (node->matches("aten::squeeze(Tensor self, int dim) -> Tensor", /*const_inputs=*/attr::dim)) {
+    } else if (node->matches(
+                   "aten::squeeze(Tensor self, int dim) -> Tensor",
+                   /*const_inputs=*/attr::dim)) {
       int64_t dim = *node->get<int64_t>(attr::dim);
       const auto& sizes = inputs.at(0).sizes();
       wrapDim(dim, sizes);
-      if (sizes.size() == 0)  {
+      if (sizes.size() == 0) {
         return {grads.at(0), nullptr};
       }
-      return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim), nullptr};
+      return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim),
+              nullptr};
 
-    } else if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor", /*const_inputs=*/attr::dim)) {
+    } else if (node->matches(
+                   "aten::cat(Tensor[] tensors, int dim) -> Tensor",
+                   /*const_inputs=*/attr::dim)) {
       int dim = *node->get<int64_t>(attr::dim);
-      auto tensor_inputs = inputs; tensor_inputs.pop_back();
+      auto tensor_inputs = inputs;
+      tensor_inputs.pop_back();
       const auto& first_sizes = tensor_inputs.at(0).sizes();
       const auto has_first_sizes = [&first_sizes](SymbolicVariable var) {
         return var.sizes() == first_sizes;
@@ -484,7 +591,8 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
 
       // NB: this is a specialization for the common case where all inputs are
       // of equal sizes. We can use a single split operation to handle that.
-      if (std::all_of(tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) {
+      if (std::all_of(
+              tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) {
         auto tensor_grads = grads.at(0).chunk(tensor_inputs.size(), dim);
         tensor_grads.emplace_back(nullptr); // for attr::dim
         return tensor_grads;
@@ -502,153 +610,191 @@ static std::vector<Value*> gradientForNode(Node* node, ArrayRef<Value*> grad_val
     } else if (comparison_ops.find(node)) {
       return {nullptr, nullptr};
 
-    } else if (node->matches("aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor")) {
+    } else if (
+        node->matches(
+            "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor")) {
       JIT_ASSERT(grads.size() == 1);
       auto graph = node->owningGraph();
-      auto backward_value = graph->insert(aten::avg_pool2d_backward, {
-        grads.at(0).value(),
-        node->namedInput(attr::self),
-        node->namedInput(attr::kernel_size),
-        node->namedInput(attr::stride),
-        node->namedInput(attr::padding),
-        node->namedInput(attr::ceil_mode),
-        node->namedInput(attr::count_include_pad)});
-      return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr, nullptr};
-
-    } else if (node->matches("aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) {
+      auto backward_value = graph->insert(
+          aten::avg_pool2d_backward,
+          {grads.at(0).value(),
+           node->namedInput(attr::self),
+           node->namedInput(attr::kernel_size),
+           node->namedInput(attr::stride),
+           node->namedInput(attr::padding),
+           node->namedInput(attr::ceil_mode),
+           node->namedInput(attr::count_include_pad)});
+      return {backward_value->node()->output(0),
+              nullptr,
+              nullptr,
+              nullptr,
+              nullptr,
+              nullptr};
+
+    } else if (
+        node->matches(
+            "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)")) {
       JIT_ASSERT(grads.size() == 2);
       auto graph = node->owningGraph();
-      auto backward_value = graph->insert(aten::max_pool2d_with_indices_backward, {
-        grads.at(0).value(),
-        node->namedInput(attr::self),
-        node->namedInput(attr::kernel_size),
-        node->namedInput(attr::stride),
-        node->namedInput(attr::padding),
-        node->namedInput(attr::dilation),
-        node->namedInput(attr::ceil_mode),
-        outputs.at(1).value()
-      });
-      return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr, nullptr};
-
-    } else if (node->matches("aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
+      auto backward_value = graph->insert(
+          aten::max_pool2d_with_indices_backward,
+          {grads.at(0).value(),
+           node->namedInput(attr::self),
+           node->namedInput(attr::kernel_size),
+           node->namedInput(attr::stride),
+           node->namedInput(attr::padding),
+           node->namedInput(attr::dilation),
+           node->namedInput(attr::ceil_mode),
+           outputs.at(1).value()});
+      return {backward_value->node()->output(0),
+              nullptr,
+              nullptr,
+              nullptr,
+              nullptr,
+              nullptr};
+
+    } else if (
+        node->matches(
+            "aten::thnn_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> (Tensor, Tensor, Tensor)")) {
       auto graph = node->owningGraph();
-      auto backward_value = graph->insert(aten::thnn_conv2d_backward, {
-        grads.at(0).value(),
-        inputs.at(0).value(),
-        inputs.at(1).value(),
-        node->namedInput(attr::kernel_size),
-        node->namedInput(attr::stride),
-        node->namedInput(attr::padding),
-        outputs.at(1).value(),
-        outputs.at(2).value(),
-        graph->insertConstant(std::vector<bool>{true, true, true})
-      });
-      // graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again.
-      Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
+      auto backward_value = graph->insert(
+          aten::thnn_conv2d_backward,
+          {grads.at(0).value(),
+           inputs.at(0).value(),
+           inputs.at(1).value(),
+           node->namedInput(attr::kernel_size),
+           node->namedInput(attr::stride),
+           node->namedInput(attr::padding),
+           outputs.at(1).value(),
+           outputs.at(2).value(),
+           graph->insertConstant(std::vector<bool>{true, true, true})});
+      // graph->insert returns a tuple automatically if multiple outputs are
+      // returned. So unpack them again.
+      Node* tuple_unpack_node =
+          graph->insertNode(graph->createTupleUnpack(backward_value));
       auto tuple_outputs = tuple_unpack_node->outputs();
       JIT_ASSERT(tuple_outputs.size() == size_t(3));
-      return {tuple_outputs[0], tuple_outputs[1], nullptr, tuple_outputs[2], nullptr, nullptr};
+      return {tuple_outputs[0],
+              tuple_outputs[1],
+              nullptr,
+              tuple_outputs[2],
+              nullptr,
+              nullptr};
 
-    } else if (node->matches("aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
+    } else if (
+        node->matches(
+            "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
       auto graph = node->owningGraph();
-      auto backward_value = graph->insert(aten::native_batch_norm_backward, {
-        grads.at(0).value(),
-        inputs.at(0).value(),
-        inputs.at(1).value(),
-        inputs.at(3).value(),
-        inputs.at(4).value(),
-        outputs.at(1).value(),
-        outputs.at(2).value(),
-        inputs.at(5).value(),
-        inputs.at(7).value(),
-        graph->insertConstant(std::vector<bool>{true, true, true})
-      });
-      // graph->insert returns a tuple automatically if multiple outputs are returned. So unpack them again.
-      Node* tuple_unpack_node = graph->insertNode(graph->createTupleUnpack(backward_value));
+      auto backward_value = graph->insert(
+          aten::native_batch_norm_backward,
+          {grads.at(0).value(),
+           inputs.at(0).value(),
+           inputs.at(1).value(),
+           inputs.at(3).value(),
+           inputs.at(4).value(),
+           outputs.at(1).value(),
+           outputs.at(2).value(),
+           inputs.at(5).value(),
+           inputs.at(7).value(),
+           graph->insertConstant(std::vector<bool>{true, true, true})});
+      // graph->insert returns a tuple automatically if multiple outputs are
+      // returned. So unpack them again.
+      Node* tuple_unpack_node =
+          graph->insertNode(graph->createTupleUnpack(backward_value));
       auto tuple_outputs = tuple_unpack_node->outputs();
       JIT_ASSERT(tuple_outputs.size() == size_t(3));
-      return {tuple_outputs[0], tuple_outputs[1], tuple_outputs[2], nullptr, nullptr, nullptr, nullptr, nullptr};
+      return {tuple_outputs[0],
+              tuple_outputs[1],
+              tuple_outputs[2],
+              nullptr,
+              nullptr,
+              nullptr,
+              nullptr,
+              nullptr};
 
-    } else if (node->matches("aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
+    } else if (
+        node->matches(
+            "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
       auto graph = node->owningGraph();
       auto total_weight = graph->insertNode(graph->createUndefined());
       auto weight = graph->insertNode(graph->createUndefined());
-      auto backward_value = graph->insert(aten::nll_loss_backward, {
-        grads.at(0).value(),
-        inputs.at(0).value(),
-        inputs.at(1).value(),
-        weight->output(),
-        inputs.at(3).value(),
-        inputs.at(4).value(),
-        total_weight->output()
-      });
-      return {backward_value->node()->output(0), nullptr, nullptr, nullptr, nullptr};
-
-    } else if (node->matches("aten::log_softmax(Tensor self, int dim) -> Tensor")) {
+      auto backward_value = graph->insert(
+          aten::nll_loss_backward,
+          {grads.at(0).value(),
+           inputs.at(0).value(),
+           inputs.at(1).value(),
+           weight->output(),
+           inputs.at(3).value(),
+           inputs.at(4).value(),
+           total_weight->output()});
+      return {backward_value->node()->output(0),
+              nullptr,
+              nullptr,
+              nullptr,
+              nullptr};
+
+    } else if (node->matches(
+                   "aten::log_softmax(Tensor self, int dim) -> Tensor")) {
       JIT_ASSERT(grads.size() == 1);
       auto graph = node->owningGraph();
-      auto backward_value = graph->insert(aten::_log_softmax_backward_data, {
-        grads.at(0).value(),
-        outputs.at(0).value(),
-        node->namedInput(attr::dim),
-        node->namedInput(attr::self)
-      });
+      auto backward_value = graph->insert(
+          aten::_log_softmax_backward_data,
+          {grads.at(0).value(),
+           outputs.at(0).value(),
+           node->namedInput(attr::dim),
+           node->namedInput(attr::self)});
       return {backward_value->node()->output(0), nullptr};
 
-    } else if (node->kind() == prim::Constant || node->kind() == prim::Undefined || node->kind() == prim::None) {
+    } else if (
+        node->kind() == prim::Constant || node->kind() == prim::Undefined ||
+        node->kind() == prim::None) {
       return {};
     }
-    throw std::runtime_error(std::string("failed to differentiate `") + node->kind().toDisplayString() + "`");
-  };
-  if (!isDifferentiable(node)) {
-    throw std::runtime_error(std::string("differentiation of ") + node->kind().toDisplayString() + " "
-                             "is not supported, or it is missing necessary type information");
+    throw std::runtime_error(
+        std::string("failed to differentiate `") +
+        node->kind().toDisplayString() + "`");
   }
-  // If AD is defined using torchscript, use it instead of symbolic
-  auto script_grads = build_script_grad(node, grad_values);
-  if (script_grads)
-    return *script_grads;
-  // Definition not found in torchscript, look up in the build_sym_grad
-  // TODO: migrate all to using torchscript
-  auto sym_grads = build_sym_grad(fmap<SymbolicVariable>(grad_values));
-  return fmap(sym_grads, [](const SymbolicVariable &v) { return v.value(); });
-}
-
-// If we have a function y = f(x) with jacobian J, the backwards of f is dx = J^t dy.
-// Note that because the backwards always implements this matrix multiply,
-// we know that it maps an input vector of zeros to an output vector of zero
-// regardless of what operations it choses to do inside to actually implement
-// the matrix multiply (most use some optimized form and never generate J^t).
-// More generally, we know that all of the backward computations are linear and
-// can use this property to do more aggressive optimizations later.
-// It is ok to replace any backward function with known-zero inputs with something
-// that produces known-zero outputs. This function encloses each know-linear
-// backward function in a 'GradOf' sub-block so that we can perform optimizations
-// using this information. In particular, specializeUndef will observe if
-// all the inputs to the linear block are Undef, which the autograd uses to represent
-// zeros, and then propagate the undefs to the outputs of the block.
-static std::vector<Value*> linearGradientForNode(Node* node, ArrayRef<Value*> grad_values) {
-  auto & graph = *node->owningGraph();
+};
+} // namespace
+
+// If we have a function y = f(x) with jacobian J, the backwards of f is dx =
+// J^t dy. Note that because the backwards always implements this matrix
+// multiply, we know that it maps an input vector of zeros to an output vector
+// of zero regardless of what operations it choses to do inside to actually
+// implement the matrix multiply (most use some optimized form and never
+// generate J^t). More generally, we know that all of the backward computations
+// are linear and can use this property to do more aggressive optimizations
+// later. It is ok to replace any backward function with known-zero inputs with
+// something that produces known-zero outputs. This function encloses each
+// know-linear backward function in a 'GradOf' sub-block so that we can perform
+// optimizations using this information. In particular, specializeUndef will
+// observe if all the inputs to the linear block are Undef, which the autograd
+// uses to represent zeros, and then propagate the undefs to the outputs of the
+// block.
+static std::vector<Value*> linearGradientForNode(
+    Node* node,
+    ArrayRef<Value*> grad_values) {
+  auto& graph = *node->owningGraph();
   auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0));
   // to make reading gradient graphs easier, remember the name of the forward op
   linear->s_(attr::name, node->kind().toDisplayString());
   auto block = linear->addBlock();
   WithInsertPoint guard(block);
-  auto results = gradientForNode(node, grad_values);
-  return fmap(results, [block, linear](Value *grad) -> Value* {
-    if (!grad) return nullptr;
+  auto results = GradientHelper(node).gradient(grad_values);
+  return fmap(results, [block, linear](Value* grad) -> Value* {
+    if (!grad)
+      return nullptr;
     block->registerOutput(grad);
     return linear->addOutput()->copyMetadata(grad);
   });
 }
 
 struct ReverseDetails {
-  ReverseDetails(value_map&& grad_map, Block * reverse_block)
-    : grad_map(std::move(grad_map))
-    , reverse_block(reverse_block) {}
+  ReverseDetails(value_map&& grad_map, Block* reverse_block)
+      : grad_map(std::move(grad_map)), reverse_block(reverse_block) {}
 
   value_map grad_map;
-  Block * reverse_block;
+  Block* reverse_block;
 };
 
 // AutogradAdd is a special addition function that handles Undef
@@ -670,7 +816,7 @@ static Value* createAutogradAdd(Value* a, Value* b) {
 //   - grad_desc has df_input_vjps and df_output_vjps set
 //     (but df_input_vjps will be modified later as well)
 static ReverseDetails addReverseInline(Gradient& grad_desc) {
-  auto & graph = *grad_desc.f;
+  auto& graph = *grad_desc.f;
   // note: reverse_node is intentionally not inserted to avoid
   // accidentally acting on it (e.g. in elminate dead code),
   // std::cout << *reverse_node << to view its state.
@@ -687,8 +833,8 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) {
     }
     return it->second;
   };
-  const auto set_grad = [&](Value *x, Value *dx) {
-    if (Value * prev_grad = grad_map[x]) {
+  const auto set_grad = [&](Value* x, Value* dx) {
+    if (Value* prev_grad = grad_map[x]) {
       grad_map[x] = createAutogradAdd(prev_grad, dx);
     } else {
       grad_map[x] = dx;
@@ -697,45 +843,54 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) {
 
   auto outputs = graph.outputs();
   for (size_t i = 0, num_outputs = outputs.size(); i < num_outputs; ++i) {
-    Value * output = outputs[i];
+    Value* output = outputs[i];
     if (!output->requires_grad())
       continue;
-    Value * output_grad = reverse_block->addInput()->setType(output->type());
+    Value* output_grad = reverse_block->addInput()->setType(output->type());
     set_grad(output, output_grad);
     grad_desc.df_input_vjps.push_back(i);
   }
 
-  for (auto it = graph.nodes().rbegin(), end = graph.nodes().rend(); it != end; ++it) {
-    Node *node = *it;
+  for (auto it = graph.nodes().rbegin(), end = graph.nodes().rend(); it != end;
+       ++it) {
+    Node* node = *it;
     auto inputs = node->inputs();
     auto outputs = node->outputs();
-    if (std::all_of(outputs.begin(), outputs.end(), [](Value *v) { return !v->requires_grad(); })) {
+    if (std::all_of(outputs.begin(), outputs.end(), [](Value* v) {
+          return !v->requires_grad();
+        })) {
       continue;
     }
 
-    value_list grad_inputs = linearGradientForNode(node, fmap(node->outputs(), get_grad));
+    value_list grad_inputs =
+        linearGradientForNode(node, fmap(node->outputs(), get_grad));
     LowerSimpleTuples(reverse_block);
 
     JIT_ASSERT(grad_inputs.size() == node->inputs().size());
     for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
-      if (!inputs[i]->requires_grad()) continue;
-      // NB: Not returning a gradient w.r.t. a value that requires grad is normal if the
-      // input is non-differentiable. This happens e.g. in the aten::type_as case.
-      if (!grad_inputs[i]) continue;
+      if (!inputs[i]->requires_grad())
+        continue;
+      // NB: Not returning a gradient w.r.t. a value that requires grad is
+      // normal if the input is non-differentiable. This happens e.g. in the
+      // aten::type_as case.
+      if (!grad_inputs[i])
+        continue;
       set_grad(inputs[i], grad_inputs[i]);
     }
   }
 
   auto inputs = graph.inputs();
   for (size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) {
-    Value * input = inputs[i];
+    Value* input = inputs[i];
     if (!input->requires_grad())
       continue;
-    // NB: Not having a gradient defined w.r.t. an input to the graph which requires grad
-    // can happen and is not an error. It might have been used only in non-differentiable
-    // contexts (e.g. as second input to aten::type_as). In that case we simply ignore it
-    // as an output, because it won't ever produce any meaningful values.
-    if (grad_map.count(input) == 0) continue;
+    // NB: Not having a gradient defined w.r.t. an input to the graph which
+    // requires grad can happen and is not an error. It might have been used
+    // only in non-differentiable contexts (e.g. as second input to
+    // aten::type_as). In that case we simply ignore it as an output, because it
+    // won't ever produce any meaningful values.
+    if (grad_map.count(input) == 0)
+      continue;
     reverse_block->registerOutput(get_grad(input));
     grad_desc.df_output_vjps.push_back(i);
   }
@@ -743,14 +898,15 @@ static ReverseDetails addReverseInline(Gradient& grad_desc) {
   return ReverseDetails(std::move(grad_map), reverse_block);
 }
 
-// Returns a topologically-sorted list of values produced in f, and used in its reverse program.
+// Returns a topologically-sorted list of values produced in f, and used in its
+// reverse program.
 static value_list getReverseCaptures(Gradient& grad_desc) {
-  auto & graph = *grad_desc.f;
+  auto& graph = *grad_desc.f;
   auto primal_block = graph.block();
 
   value_set reverse_captures_set;
   value_list reverse_captures; // Invariant: topo sorted
-  auto check_uses = [&](Value *v) {
+  auto check_uses = [&](Valuev) {
     for (auto use : v->uses()) {
       if (use.user->owningBlock() == primal_block)
         continue;
@@ -759,39 +915,43 @@ static value_list getReverseCaptures(Gradient& grad_desc) {
       }
     }
   };
-  for (Value * input : graph.inputs()) {
+  for (Value* input : graph.inputs()) {
     check_uses(input);
   }
-  for (Node * node : graph.nodes()) {
-    for (Value * output : node->outputs())
+  for (Node* node : graph.nodes()) {
+    for (Value* output : node->outputs())
       check_uses(output);
   }
   return reverse_captures;
 }
 
-// Any temporary value from the primal graphs needs to be captured for later use in the
-// reverse graph, to avoid costly recomputations. However, a lot of the nodes we have
-// in our graphs are simply constants, which are cheap to execute and replicate, and so
-// it's better to just copy them into the reverse graph, without polluting the output
-// lists unnecessarily.
+// Any temporary value from the primal graphs needs to be captured for later use
+// in the reverse graph, to avoid costly recomputations. However, a lot of the
+// nodes we have in our graphs are simply constants, which are cheap to execute
+// and replicate, and so it's better to just copy them into the reverse graph,
+// without polluting the output lists unnecessarily.
 static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) {
   static const auto err = [](Value*) -> Value* {
     throw std::runtime_error("unexpected input");
   };
-  auto & graph = *grad_desc.f;
+  auto& graph = *grad_desc.f;
   Block* reverse_block = rev_info.reverse_block;
 
-  for (Node *top_node : reverse_block->nodes()) {
-    JIT_ASSERT(top_node->kind() == prim::GradOf ||
-               top_node->kind() == prim::AutogradAdd ||
-               top_node->kind() == prim::Undefined);
-    if (top_node->kind() != prim::GradOf) continue;
-    Block * grad_body = top_node->blocks().at(0);
-    for (Node *node : grad_body->nodes()) {
-      for (Value * input : node->inputs()) {
-        if (input->node()->kind() != prim::Constant) continue;
-        if (input->node()->owningBlock() == grad_body) continue;
-        Node *lifted_constant = graph.createClone(input->node(), err);
+  for (Node* top_node : reverse_block->nodes()) {
+    JIT_ASSERT(
+        top_node->kind() == prim::GradOf ||
+        top_node->kind() == prim::AutogradAdd ||
+        top_node->kind() == prim::Undefined);
+    if (top_node->kind() != prim::GradOf)
+      continue;
+    Block* grad_body = top_node->blocks().at(0);
+    for (Node* node : grad_body->nodes()) {
+      for (Value* input : node->inputs()) {
+        if (input->node()->kind() != prim::Constant)
+          continue;
+        if (input->node()->owningBlock() == grad_body)
+          continue;
+        Node* lifted_constant = graph.createClone(input->node(), err);
         reverse_block->prependNode(lifted_constant);
         node->replaceInputWith(input, lifted_constant->output());
       }
@@ -799,22 +959,25 @@ static void liftConstants(Gradient& grad_desc, ReverseDetails& rev_info) {
   }
 }
 
-static void deduplicateSizeCaptures(Gradient& grad_desc, ReverseDetails& rev_info) {
-  Block * primal_block = grad_desc.f->block();
-  const auto usedOnlyInReverse = [primal_block](Value * v) {
-    const auto & uses = v->uses();
-    return std::all_of(uses.begin(), uses.end(),
-                       [primal_block](const Use& u) { return u.user->owningBlock() != primal_block; });
+static void deduplicateSizeCaptures(
+    Gradient& grad_desc,
+    ReverseDetails& rev_info) {
+  Block* primal_block = grad_desc.f->block();
+  const auto usedOnlyInReverse = [primal_block](Value* v) {
+    const auto& uses = v->uses();
+    return std::all_of(uses.begin(), uses.end(), [primal_block](const Use& u) {
+      return u.user->owningBlock() != primal_block;
+    });
   };
   auto captures = getReverseCaptures(grad_desc);
-  value_set capture_set (captures.begin(), captures.end());
-  for (Value * capture : captures) {
-    Node * node = capture->node();
+  value_set capture_set(captures.begin(), captures.end());
+  for (Value* capture : captures) {
+    Node* node = capture->node();
     if (!node->matches("aten::size(Tensor self) -> int[]")) {
       continue;
     }
     if (usedOnlyInReverse(capture) && capture_set.count(node->input())) {
-      WithInsertPoint insert_guard { *rev_info.reverse_block->nodes().begin() };
+      WithInsertPoint insert_guard{*rev_info.reverse_block->nodes().begin()};
       capture->replaceAllUsesWith(SymbolicVariable(node->input()).size());
       node->destroy();
     }
@@ -846,16 +1009,17 @@ static void eliminateDeadCode(ReverseDetails& rev_info) {
 }
 
 static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) {
-  // TODO: we are sometimes emitting expressions like SumToSize(SumToSize(x, s1), s2),
-  // which are equivalent to SumToSize(x, s2), and could save us some captures, but I'm
-  // not 100% sure how to optimize this at this stage, since we don't know which
-  // GradOf blocks will be stitched together to form the derivative. I guess a smart
-  // analysis could implement this, but I didn't have time before the 1.0 release,
-  // so I put this only as a peephole optimization.
+  // TODO: we are sometimes emitting expressions like SumToSize(SumToSize(x,
+  // s1), s2), which are equivalent to SumToSize(x, s2), and could save us some
+  // captures, but I'm not 100% sure how to optimize this at this stage, since
+  // we don't know which GradOf blocks will be stitched together to form the
+  // derivative. I guess a smart analysis could implement this, but I didn't
+  // have time before the 1.0 release, so I put this only as a peephole
+  // optimization.
   liftConstants(grad_desc, rev_info);
   // We generally add a lot of aten::size calls (for derivatives of broadcasting
-  // operators), and they often end up duplicated, and would get captured multiple
-  // times. Make sure we deduplicate them before lifting.
+  // operators), and they often end up duplicated, and would get captured
+  // multiple times. Make sure we deduplicate them before lifting.
   EliminateCommonSubexpression(grad_desc.f);
   deduplicateSizeCaptures(grad_desc, rev_info);
   eliminateDeadCode(rev_info);
@@ -866,10 +1030,11 @@ static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) {
 // All intermediates needed in the second stage are added to
 // outputs of f, and taken as inputs in df. For a more
 // detailed description see Note [Gradient graphs] in autodiff.h.
-// This function also initializes the fields in grad_desc that were undefined after
-// `addReverseInline` (and extends `df_input_vjps` with vjps for captured temporaries).
+// This function also initializes the fields in grad_desc that were undefined
+// after `addReverseInline` (and extends `df_input_vjps` with vjps for captured
+// temporaries).
 static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
-  auto & graph = *grad_desc.f;
+  auto& graph = *grad_desc.f;
   auto primal_block = graph.block();
   auto reverse_block = rev_info.reverse_block;
 
@@ -902,49 +1067,58 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
 
   std::unordered_map<Value*, size_t> orig_primal_outputs_idx;
   std::unordered_map<Value*, size_t> orig_primal_inputs_idx;
-  // NOTE: we use emplace to avoid replacing an existing index if an output is repeated
+  // NOTE: we use emplace to avoid replacing an existing index if an output is
+  // repeated
   for (size_t i = 0, num_outputs = graph.outputs().size(); i < num_outputs; ++i)
     orig_primal_outputs_idx.emplace(graph.outputs()[i], i);
   for (size_t i = 0, num_inputs = graph.inputs().size(); i < num_inputs; ++i)
     orig_primal_inputs_idx[graph.inputs()[i]] = i;
 
   // NB: reverse_captures are already deduplicated, and in topo order
-  for (Value * capture_val : reverse_captures) {
+  for (Value* capture_val : reverse_captures) {
     // If it's already an output we don't have to add anything,
     // but register the fact that it needs to be captured.
     if (orig_primal_outputs_idx.count(capture_val) > 0) {
-      grad_desc.df_input_captured_outputs.push_back(orig_primal_outputs_idx[capture_val]);
-    // If it's an input, we could add it as an output but in fact it's
-    // more efficient to use a special kind of capture.
+      grad_desc.df_input_captured_outputs.push_back(
+          orig_primal_outputs_idx[capture_val]);
+      // If it's an input, we could add it as an output but in fact it's
+      // more efficient to use a special kind of capture.
     } else if (orig_primal_inputs_idx.count(capture_val) > 0) {
-      grad_desc.df_input_captured_inputs.push_back(orig_primal_inputs_idx.at(capture_val));
-    // Otherwise it's just a regular intermediate value that we need to add as an output
+      grad_desc.df_input_captured_inputs.push_back(
+          orig_primal_inputs_idx.at(capture_val));
+      // Otherwise it's just a regular intermediate value that we need to add as
+      // an output
     } else {
-      // we need to create a new temporary output for this capture because it wasn't availiable.
+      // we need to create a new temporary output for this capture because it
+      // wasn't availiable.
       graph.registerOutput(capture_val);
-      grad_desc.df_input_captured_outputs.emplace_back(graph.outputs().size() - 1);
+      grad_desc.df_input_captured_outputs.emplace_back(
+          graph.outputs().size() - 1);
     }
   }
 
   // -- Add VJPs for temporaries, adjust df_input_vjps -------------------------
-  // NB [possible optimization]: use the newly added vjp input as soon as the first
-  // vjp for that value is generated, to reduce the lifespan of this input
+  // NB [possible optimization]: use the newly added vjp input as soon as the
+  // first vjp for that value is generated, to reduce the lifespan of this input
   // (currently we add it to the final vjp after all adds).
   for (size_t i = grad_desc.f_real_outputs; i < graph.outputs().size(); ++i) {
-    Value * tmp = graph.outputs().at(i);
+    Value* tmp = graph.outputs().at(i);
     // Add VJP inputs only for intermediates that actually required grad.
-    // Note that we check the contents of the grad_map instead of tmp->requires_grad(),
-    // becuase it's actually a more faithful source. tmp->requires_grad() is really an
-    // overapproximation (i.e. it can have false positives), while the gradients we will
-    // emit for this value can get DCE-d in the optimization pass (because it has no
-    // influence on the real f's outputs that we differentiate).
-    if (rev_info.grad_map.count(tmp) == 0) continue;
-    Value * tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
-    Value * tmp_vjp_prev = rev_info.grad_map.at(tmp);
-    // This is quite weird because we can't first make a sum and then replace all uses
-    // of tmp_vjp_prev (that would replace its use in the sum too!), so we create an
-    // incorrect sum that doesn't use prev vjp, replace uses, and fix the sum.
-    Value * new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
+    // Note that we check the contents of the grad_map instead of
+    // tmp->requires_grad(), becuase it's actually a more faithful source.
+    // tmp->requires_grad() is really an overapproximation (i.e. it can have
+    // false positives), while the gradients we will emit for this value can get
+    // DCE-d in the optimization pass (because it has no influence on the real
+    // f's outputs that we differentiate).
+    if (rev_info.grad_map.count(tmp) == 0)
+      continue;
+    Value* tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
+    Value* tmp_vjp_prev = rev_info.grad_map.at(tmp);
+    // This is quite weird because we can't first make a sum and then replace
+    // all uses of tmp_vjp_prev (that would replace its use in the sum too!), so
+    // we create an incorrect sum that doesn't use prev vjp, replace uses, and
+    // fix the sum.
+    Value* new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
     new_vjp->node()->moveAfter(tmp_vjp_prev->node());
     tmp_vjp_prev->replaceAllUsesWith(new_vjp);
     new_vjp->node()->replaceInput(1, tmp_vjp_prev);
@@ -956,13 +1130,13 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
   // construct a map from captured 'value' to the index in the input list
   // used to extract this block into its own function
   std::unordered_map<Value*, size_t> capture_to_formal_index;
-  const auto & add_capture = [&](Value * captured) {
+  const auto& add_capture = [&](Value* captured) {
     capture_to_formal_index[captured] = reverse_block->inputs().size();
     reverse_block->addInput()->copyMetadata(captured);
   };
-  for(auto & offset : grad_desc.df_input_captured_inputs)
+  for (auto& offset : grad_desc.df_input_captured_inputs)
     add_capture(graph.inputs()[offset]);
-  for(auto & offset : grad_desc.df_input_captured_outputs)
+  for (auto& offset : grad_desc.df_input_captured_outputs)
     add_capture(graph.outputs()[offset]);
 
   grad_desc.df = std::make_shared<Graph>();
@@ -974,13 +1148,14 @@ static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
   reverse_block->owningNode()->destroy();
 }
 
-
 Gradient differentiate(std::shared_ptr<Graph>& graph) {
   Gradient grad_desc;
   // Take ownership of the graph
-  JIT_ASSERTM(graph.use_count() == 1,
-              "differentiate will mutate and destroy the graph, so it requires "
-              "graph.use_count() == 1, but found %d", graph.use_count());
+  JIT_ASSERTM(
+      graph.use_count() == 1,
+      "differentiate will mutate and destroy the graph, so it requires "
+      "graph.use_count() == 1, but found %d",
+      graph.use_count());
   std::swap(graph, grad_desc.f);
   // XXX: Take care when handling outputs - they can be duplicated!
 
@@ -1000,4 +1175,5 @@ Gradient differentiate(std::shared_ptr<Graph>& graph) {
   return grad_desc;
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 519a7ae..74df442 100644 (file)
@@ -5,12 +5,14 @@
 
 #include <ATen/ATen.h>
 
-#include <vector>
 #include <memory>
+#include <vector>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 using value_list = std::vector<Value*>;
+// clang-format off
 // Example showcasing how Gradient is constructed:
 //
 // Let's assume we have a function f, `m` and `n` do not require grad
@@ -34,6 +36,7 @@ using value_list = std::vector<Value*>;
 //   df_output_vjps = {0}     // i.e. connect next_edge[0] of grad_fn to x's (grad_fn, output_nr).
 //
 // Terminology: vjp = vector-jacobian product
+// clang-format on
 
 struct Gradient {
   explicit operator bool() const {
@@ -45,23 +48,22 @@ struct Gradient {
   // Describes how to construct outputs of f from what its graph will return.
   // This is necessary because some trailing outputs are intermediates produced
   // only to be saved for df (and should be ignored).
-  size_t f_real_outputs = 0;  // initialized for safety.
+  size_t f_real_outputs = 0; // initialized for safety.
 
-  // df inputs are split into two sections: vjps (aka grad_outputs) and captures.
-  // VJPs are "seeds" for the gradient computation given for each input capture
-  // of an Output kind.
-  // Captures are values the need to be saved when f is run. We handle inputs
-  // specially, because this allows us to avoid adding extra vjps as df inputs.
+  // df inputs are split into two sections: vjps (aka grad_outputs) and
+  // captures. VJPs are "seeds" for the gradient computation given for each
+  // input capture of an Output kind. Captures are values the need to be saved
+  // when f is run. We handle inputs specially, because this allows us to avoid
+  // adding extra vjps as df inputs.
 
   std::vector<size_t> df_input_vjps; // Offsets into f's outputs.
   // capture can come from inputs or outputs
   std::vector<size_t> df_input_captured_inputs; // Offsets into f's inputs
   std::vector<size_t> df_input_captured_outputs; // Offsets into f's outputs
 
-
   // df will produce vjps for a subset of inputs of f that required grad.
-  // df_output_vjps[idx] == inp_idx means that idx-th output of df produces a vjp
-  // for inp_idx-th input of f.
+  // df_output_vjps[idx] == inp_idx means that idx-th output of df produces a
+  // vjp for inp_idx-th input of f.
   std::vector<size_t> df_output_vjps; // Offsets into f's inputs.
 
   // How to use gradient to implement a differentiable autograd function:
@@ -76,8 +78,8 @@ struct Gradient {
   //   - Use df_output_vjps to connect next_edges of grad_fn:
   //       for idx in df_output_vjps:
   //         grad_fn.add_next_edge(inputs[idx].gradient_edge())
-  //   - Save captures for df (care needs to be taken to use SavedVariables for inputs and
-  //                           outputs that we will actually return)
+  //   - Save captures for df (care needs to be taken to use SavedVariables for
+  //                           inputs and outputs that we will actually return)
   //   - Return outputs[:f_real_outputs]
   //
   // When running df:
@@ -88,8 +90,9 @@ struct Gradient {
 TORCH_API Gradient differentiate(std::shared_ptr<Graph>& graph);
 
 // can we take a derivative of this node symbolically?
-TORCH_API bool isDifferentiable(Node * n);
-TORCH_API bool isDifferentiable(Graph & g);
-TORCH_API bool isZero(Value * v);
+TORCH_API bool isDifferentiable(Node* n);
+TORCH_API bool isDifferentiable(Graph& g);
+TORCH_API bool isZero(Value* v);
 
-}}
+} // namespace jit
+} // namespace torch
index d514cc4..7d709a6 100644 (file)
@@ -1,19 +1,21 @@
 #include <torch/csrc/jit/batched/BatchTensor.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims){
-  if(data.dim() != mask.dim() || mask.dim() != dims.size(0) + 1){
-    throw std::runtime_error("malformed MaskedBatch with data.dim(): "
-      + std::to_string(data.dim()) + ", mask.dim(): " + std::to_string(mask.dim())
-      + ", dims.size(0): " + std::to_string(dims.size(0)));
+BatchTensor::BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims) {
+  if (data.dim() != mask.dim() || mask.dim() != dims.size(0) + 1) {
+    throw std::runtime_error(
+        "malformed MaskedBatch with data.dim(): " + std::to_string(data.dim()) +
+        ", mask.dim(): " + std::to_string(mask.dim()) +
+        ", dims.size(0): " + std::to_string(dims.size(0)));
   }
   this->data = std::move(data);
   this->mask = std::move(mask);
   this->dims = std::move(dims);
 }
 
-BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size){
+BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size) {
   dims = at::empty(data.dim(), data.options().dtype(at::kByte));
   dims.fill_(0);
   std::vector<int64_t> sizes(data.dim() + 1, -1);
@@ -25,13 +27,16 @@ BatchTensor::BatchTensor(const at::Tensor& data, int64_t batch_size){
   mask.fill_(1);
 }
 
-BatchTensor::BatchTensor(const std::vector<at::Tensor>& datalist, at::Tensor dims) {
+BatchTensor::BatchTensor(
+    const std::vector<at::Tensor>& datalist,
+    at::Tensor dims) {
   auto bs = datalist.size();
-  std::vector<int64_t> sizes(dims.size(0) + 1, 0), mask_sizes(dims.size(0) + 1, 0);
+  std::vector<int64_t> sizes(dims.size(0) + 1, 0),
+      mask_sizes(dims.size(0) + 1, 0);
   sizes[0] = bs;
   mask_sizes[0] = bs;
-  for(int64_t i = 1; i < dims.size(0) + 1; i++){
-    for(const auto& x : datalist){
+  for (int64_t i = 1; i < dims.size(0) + 1; i++) {
+    for (const auto& x : datalist) {
       sizes[i] = std::max(sizes[i], x.size(i));
     }
     mask_sizes[i] = *dims[i - 1].data<uint8_t>() ? sizes[i] : 1;
@@ -40,11 +45,11 @@ BatchTensor::BatchTensor(const std::vector<at::Tensor>& datalist, at::Tensor dim
   data.fill_(0);
   mask = at::empty(mask_sizes, datalist[0].options().dtype(at::kByte));
   mask.fill_(0);
-  for(std::size_t i = 0; i < datalist.size(); i++){
+  for (std::size_t i = 0; i < datalist.size(); i++) {
     auto data_item = data.narrow(0, i, 1);
     auto mask_item = mask.narrow(0, i, 1);
-    for(int64_t j = 0; j < dims.size(0); j++){
-      if(*dims[j].data<uint8_t>()){
+    for (int64_t j = 0; j < dims.size(0); j++) {
+      if (*dims[j].data<uint8_t>()) {
         data_item = data_item.narrow(j + 1, 0, datalist[i].size(j + 1));
         mask_item = mask_item.narrow(j + 1, 0, datalist[i].size(j + 1));
       }
@@ -58,16 +63,16 @@ BatchTensor::BatchTensor(const std::vector<at::Tensor>& datalist, at::Tensor dim
 std::vector<at::Tensor> BatchTensor::examples() {
   std::vector<at::Tensor> result;
   // calculate number of valid entries in dth dimension of data
-  auto mask_sum = [](at::Tensor data, int d) -> int64_t{
+  auto mask_sum = [](at::Tensor data, int d) -> int64_t {
     data = data.sum(d, /*keepdim=*/true);
-    while(data.dim() >= 1)
+    while (data.dim() >= 1)
       data = data[0];
     return *data.data<int64_t>();
   };
-  for(int64_t i = 0; i < data.size(0); i++){
+  for (int64_t i = 0; i < data.size(0); i++) {
     auto data_tmp = data.narrow(0, i, 1);
-    for(int64_t d = 0; d < dims.size(0); d++){
-      if(*dims[d].data<uint8_t>()){
+    for (int64_t d = 0; d < dims.size(0); d++) {
+      if (*dims[d].data<uint8_t>()) {
         data_tmp = data_tmp.narrow(d + 1, 0, mask_sum(mask[i], d));
       }
     }
@@ -89,4 +94,5 @@ void initBatchTensorBindings(PyObject* module) {
       .def("get_dims", &BatchTensor::get_dims);
 }
 
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index a7acd27..d74bf12 100644 (file)
@@ -1,18 +1,19 @@
 #pragma once
+#include <ATen/ATen.h>
 #include <ATen/Tensor.h>
 #include <torch/csrc/jit/pybind.h>
-#include <ATen/ATen.h>
 #include <iostream>
 #include <vector>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 struct BatchTensor {
-public:
+ public:
   BatchTensor(at::Tensor data, at::Tensor mask, at::Tensor dims);
   // expand a tensor to a batchtensor given batch_size
   BatchTensor(const at::Tensor& data, int64_t batch_size);
   BatchTensor(const std::vector<at::Tensor>& datalist, at::Tensor dims);
-  const char * toString() const {
+  const char* toString() const {
     return "BatchTensor";
   }
   at::IntList sizes() const {
@@ -22,26 +23,29 @@ public:
     return data.dim();
   }
   std::vector<at::Tensor> examples();
-  at::Tensor get_data(){
+  at::Tensor get_data() {
     return data;
   }
-  at::Tensor get_mask(){
+  at::Tensor get_mask() {
     return mask;
   }
-  at::Tensor get_dims(){
+  at::Tensor get_dims() {
     return dims;
   }
 
-public:
+ public:
   // data is a Tensor whose size is the batch size in the batch dimension,
   // the size of all examples in static dimensions,
-  // and at least as large as the largest example in the batch in dynamic dimensions.
+  // and at least as large as the largest example in the batch in dynamic
+  // dimensions.
   at::Tensor data;
   // mask is a Tensor whose size is the batch size in the batch dimension,
   // one in static dimensions,
-  // and at least as large as the largest example in the batch in dynamic dimensions.
-  // Each entry in the mask corresponds to one or more entries in the data array (singleton, i.e., static, dimensions are broadcasted),
-  // with a one in the mask denoting that the corresponding data entries represent valid, meaningful data and a zero denoting that they do not.
+  // and at least as large as the largest example in the batch in dynamic
+  // dimensions. Each entry in the mask corresponds to one or more entries in
+  // the data array (singleton, i.e., static, dimensions are broadcasted), with
+  // a one in the mask denoting that the corresponding data entries represent
+  // valid, meaningful data and a zero denoting that they do not.
   at::Tensor mask;
   // dims is a 1-dimensional tensor with a bool for each non-batch dimension,
   // representing whether that dimension is static (False) or dynamic (True).
@@ -49,4 +53,5 @@ public:
 };
 
 void initBatchTensorBindings(PyObject* module);
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index b9b0a87..9e7696b 100644 (file)
@@ -3,6 +3,8 @@
 #define CATCH_CONFIG_PREFIX_ALL
 #include <catch.hpp>
 
-// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes warning;
-// define our own version that doesn't warn.
-#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
+// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes
+// warning; define our own version that doesn't warn.
+#define _CATCH_REQUIRE_THROWS(...) \
+  INTERNAL_CATCH_THROWS(           \
+      "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__)
index 63082ee..13871c1 100644 (file)
@@ -1,10 +1,11 @@
 #pragma once
+#include <sstream>
 #include <string>
-#include <vector>
 #include <unordered_map>
-#include <sstream>
+#include <vector>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // A template environment is a mapping from template variable names, e.g.,
 // identifier (corresponding to $identifier) to their expansions.
@@ -14,31 +15,29 @@ namespace torch { namespace jit {
 // in the top level environment, and then recurses into a parent
 // environment if the key is not found.)
 struct TemplateEnv {
-  TemplateEnv()
-  : parent(nullptr) {}
-  TemplateEnv(TemplateEnv & parent)
-  : parent(&parent) {}
+  TemplateEnv() : parent(nullptr) {}
+  TemplateEnv(TemplateEnv& parent) : parent(&parent) {}
 
   using string_list = std::vector<std::string>;
 
   // Add a string 'v' to the map at key 'k'.
-  void s(const std::string & k, const std::string & v) {
+  void s(const std::string& k, const std::string& v) {
     strings_[k] = v;
     lists_.erase(k);
   }
 
   // Add a number 'v' to the map at key 'k'
-  template<typename T>
-  void d(const std::string & k, const T & v) {
+  template <typename T>
+  void d(const std::string& k, const T& v) {
     strings_[k] = std::to_string(v);
     lists_.erase(k);
   }
 
   // Retrieve the string representation of the value stored at 'k' from the map.
   // Raises an exception if the key is not found.
-  const std::string & s(const std::string & k) const {
-    if(strings_.count(k) == 0) {
-      if(parent) {
+  const std::string& s(const std::string& k) const {
+    if (strings_.count(k) == 0) {
+      if (parent) {
         return parent->s(k);
       }
       notFound(k);
@@ -47,16 +46,16 @@ struct TemplateEnv {
   }
 
   // Store a list of strings 'v' in the map at 'k'.
-  void v(const std::string & k, const string_list & v) {
+  void v(const std::string& k, const string_list& v) {
     lists_[k] = v;
     strings_.erase(k);
   }
 
   // Retrieve a list of strings stored at 'k' from the map.
   // Raises an exception if the key is not found.
-  const string_list & v(const std::string & k) const {
-    if(lists_.count(k) == 0) {
-      if(parent) {
+  const string_list& v(const std::string& k) const {
+    if (lists_.count(k) == 0) {
+      if (parent) {
         return parent->v(k);
       }
       notFound(k);
@@ -65,25 +64,25 @@ struct TemplateEnv {
   }
 
   // Test if a string 'k' is a string (as opposed to a list.)
-  bool keyIsString(const std::string & k) const {
-    if(strings_.count(k) > 0)
+  bool keyIsString(const std::string& k) const {
+    if (strings_.count(k) > 0)
       return true;
-    if(lists_.count(k) > 0)
+    if (lists_.count(k) > 0)
       return false;
-    if(parent)
+    if (parent)
       return parent->keyIsString(k);
     notFound(k);
   }
-private:
-  [[ noreturn ]]
-  void notFound(const std::string & k) const {
+
+ private:
+  [[noreturn]] void notFound(const std::string& k) const {
     std::stringstream ss;
     ss << "key not found: " << k;
     throw std::logic_error(ss.str());
   }
-  std::unordered_map<std::string,std::string> strings_;
-  std::unordered_map<std::string,string_list> lists_;
-  TemplateEnv * parent;
+  std::unordered_map<std::string, std::string> strings_;
+  std::unordered_map<std::string, string_list> lists_;
+  TemplateEnv* parent;
 };
 
 /*
@@ -96,30 +95,29 @@ private:
 # if this list is not empty and ${foo,} will insert one after.
 */
 struct CodeTemplate {
-  /* implicit */ CodeTemplate(std::string t)
-  : template_text(std::move(t)) {}
+  /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {}
 
-  std::string format(const TemplateEnv & env) {
+  std::string format(const TemplateEnv& env) {
     std::stringstream out;
     size_t pos = 0;
     size_t indent = 0;
     bool all_whitespace = true;
-    while(pos < template_text.size()) {
+    while (pos < template_text.size()) {
       char c = template_text[pos];
-      if(c == '$') {
+      if (c == '$') {
         std::stringstream kss;
         bool comma_before;
         bool comma_after;
-        size_t new_pos = parseKey(pos,kss,comma_before,comma_after);
+        size_t new_pos = parseKey(pos, kss, comma_before, comma_after);
         std::string k = kss.str();
         bool is_string = env.keyIsString(k);
-        if(all_whitespace) {
-          if(is_string)
+        if (all_whitespace) {
+          if (is_string)
             emitStringWithIndents(out, indent, env.s(k));
           else
             emitLinesIndented(out, indent, env.v(k));
         } else {
-          if(is_string)
+          if (is_string)
             out << env.s(k);
           else
             emitCommaSeparatedList(out, env.v(k), comma_before, comma_after);
@@ -128,10 +126,10 @@ struct CodeTemplate {
         pos = new_pos;
       } else {
         out << c;
-        if(!isspace(c))
+        if (!isspace(c))
           all_whitespace = false;
         indent++;
-        if(c == '\n') {
+        if (c == '\n') {
           indent = 0;
           all_whitespace = true;
         }
@@ -140,29 +138,34 @@ struct CodeTemplate {
     }
     return out.str();
   }
-private:
+
+ private:
   using string_list = std::vector<std::string>;
   char charAt(size_t p) {
     if (p >= template_text.size())
       throw std::logic_error("EOS found in key");
     return template_text[p];
   }
-  size_t parseKey(size_t pos, std::ostream & k, bool & comma_before, bool & comma_after) {
+  size_t parseKey(
+      size_t pos,
+      std::ostream& k,
+      bool& comma_before,
+      bool& comma_after) {
     comma_before = false;
     comma_after = false;
     pos++;
-    if(charAt(pos) == '{') {
+    if (charAt(pos) == '{') {
       pos++;
-      if(charAt(pos) == ',') {
+      if (charAt(pos) == ',') {
         comma_before = true;
         pos++;
       }
       pos = parseIdent(pos, k);
-      if(charAt(pos) == ',') {
+      if (charAt(pos) == ',') {
         comma_after = true;
         pos++;
       }
-      if(charAt(pos) != '}')
+      if (charAt(pos) != '}')
         throw std::logic_error("missing terminating '}'");
       pos++;
       return pos;
@@ -170,55 +173,66 @@ private:
       return parseIdent(pos, k);
     }
   }
-  size_t parseIdent(size_t pos, std::ostream & k) {
-    while(pos < template_text.size() &&
-      (isalnum(template_text[pos]) || template_text[pos] == '_')) {
+  size_t parseIdent(size_t pos, std::ostream& k) {
+    while (pos < template_text.size() &&
+           (isalnum(template_text[pos]) || template_text[pos] == '_')) {
       k << template_text[pos];
       pos++;
     }
     return pos;
   }
-  void emitCommaSeparatedList(std::ostream & out, const string_list & strings, bool comma_before, bool comma_after) {
-    if(comma_before && strings.size() > 0)
+  void emitCommaSeparatedList(
+      std::ostream& out,
+      const string_list& strings,
+      bool comma_before,
+      bool comma_after) {
+    if (comma_before && strings.size() > 0)
       out << ", ";
-    for(size_t i = 0; i < strings.size(); ++i) {
-      if(i > 0)
+    for (size_t i = 0; i < strings.size(); ++i) {
+      if (i > 0)
         out << ", ";
       out << strings[i];
     }
-    if(comma_after && strings.size() > 0)
+    if (comma_after && strings.size() > 0)
       out << ", ";
   }
   // These indentation functions follow the convention that they never emit
   // leading or trailing newlines when the input string does not have leading
   // or trailing newlines. It's the responsibility of the calling function
   // to indent correctly in the context.
-  void emitIndent(std::ostream & out, size_t indent) {
-    for(size_t i = 0; i < indent; ++i) {
+  void emitIndent(std::ostream& out, size_t indent) {
+    for (size_t i = 0; i < indent; ++i) {
       out << " ";
     }
   }
-  void emitStringWithIndents(std::ostream & out, size_t indent, const std::string & str) {
-    for(auto c : str) {
+  void emitStringWithIndents(
+      std::ostream& out,
+      size_t indent,
+      const std::string& str) {
+    for (auto c : str) {
       out << c;
-      if(c == '\n') {
+      if (c == '\n') {
         emitIndent(out, indent);
       }
     }
   }
-  void emitLinesIndented(std::stringstream & out, size_t indent, const string_list & strings) {
-    for(size_t i = 0; i < strings.size(); ++i) {
-      if(i > 0)
+  void emitLinesIndented(
+      std::stringstream& out,
+      size_t indent,
+      const string_list& strings) {
+    for (size_t i = 0; i < strings.size(); ++i) {
+      if (i > 0)
         emitIndent(out, indent);
-      emitStringWithIndents(out,indent,strings[i]);
-      if(i+1 != strings.size())
+      emitStringWithIndents(out, indent, strings[i]);
+      if (i + 1 != strings.size())
         out << "\n";
     }
   }
   std::string template_text;
 };
-static inline std::string format(const std::string & fmt, TemplateEnv & env) {
+static inline std::string format(const std::string& fmt, TemplateEnv& env) {
   return CodeTemplate(fmt).format(env);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index c1d5884..6a3aded 100644 (file)
@@ -1,10 +1,11 @@
+#include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/constants.h>
-#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/custom_operator.h>
-#include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/utils/functional.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // IValue -> Constant node
 Value* insertConstant(
@@ -12,22 +13,23 @@ Value* insertConstant(
     const IValue& val,
     c10::optional<SourceRange> loc,
     c10::optional<ScopePtr> scope) {
-  Node * n = g.create(prim::Constant);
-  if(val.isTensor()) {
+  Node* n = g.create(prim::Constant);
+  if (val.isTensor()) {
     at::Tensor ref = val.toTensor();
-    if(!ref.defined()) {
+    if (!ref.defined()) {
       n->destroy();
       return g.insertNode(g.createUndefined())->output();
     }
     if (ref.is_variable()) {
       ref = autograd::Variable(ref).data();
     }
-    n->output()->inferTypeFrom(ref); // note: before t_ because of std::move(ref)
+    n->output()->inferTypeFrom(
+        ref); // note: before t_ because of std::move(ref)
     n->t_(attr::value, std::move(ref));
-  } else if(val.isInt()) {
+  } else if (val.isInt()) {
     n->i_(attr::value, val.toInt());
     n->output()->setType(IntType::get());
-  } else if(val.isDouble()) {
+  } else if (val.isDouble()) {
     n->f_(attr::value, val.toDouble());
     n->output()->setType(FloatType::get());
   } else if (val.isBool()) {
@@ -35,110 +37,120 @@ Value* insertConstant(
     n->output()->setType(BoolType::get());
   } else if (val.isBoolList()) {
     auto bool_list = val.toBoolList()->elements();
-    n->is_(attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
+    n->is_(
+        attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
     n->output()->setType(ListType::ofBools());
-  } else if(val.isIntList()) {
+  } else if (val.isIntList()) {
     n->is_(attr::value, val.toIntList()->elements());
     n->output()->setType(ListType::ofInts());
-  } else if(val.isTensorList()) {
-    n->ts_(attr::value, fmap(val.toTensorList()->elements(), [](const at::Tensor & t) {
-      return autograd::Variable(t).data();
-    }));
+  } else if (val.isTensorList()) {
+    n->ts_(
+        attr::value,
+        fmap(val.toTensorList()->elements(), [](const at::Tensor& t) {
+          return autograd::Variable(t).data();
+        }));
     n->output()->setType(ListType::ofTensors());
-  } else if(val.isString()) {
+  } else if (val.isString()) {
     n->s_(attr::value, val.toString()->string());
     n->output()->setType(StringType::get());
-  } else if(val.isDevice()) {
+  } else if (val.isDevice()) {
     std::stringstream ss;
     ss << val.toDevice();
     n->s_(attr::value, ss.str());
     n->output()->setType(DeviceObjType::get());
-  } else if(val.isNone()) {
+  } else if (val.isNone()) {
     n->destroy();
     n = g.create(prim::None);
     n->output()->setType(NoneType::get());
   } else {
-    throw constant_not_supported_error("Unsupported value kind: " + val.tagKind());
+    throw constant_not_supported_error(
+        "Unsupported value kind: " + val.tagKind());
   }
-  if(loc)
+  if (loc)
     n->setSourceLocation(std::make_shared<SourceRange>(*loc));
-  if(scope)
+  if (scope)
     n->setScope(*scope);
   return g.insertNode(n)->output();
 }
 
 RegisterOperators reg({
-  // Implementation of constant node, computes and IValue
-  Operator(
-      FunctionSchema(prim::Constant, {}, {}, /*is_vararg=*/false, /*is_varret=*/true),
-      [](const Node* node) -> Operation {
-        TypePtr type = node->output()->type();
-        if(type->isSubtypeOf(DynamicType::get())) {
-          auto t = autograd::make_variable(node->t(attr::value));
-          return [t](Stack& stack) {
-            push(stack, t);
-            return 0;
-          };
-        } else if (type->isSubtypeOf(BoolType::get())) {
-          bool b = node->i(attr::value);
-          return [b](Stack& stack) {
-            push(stack, b);
-            return 0;
-          };
-        } else if (
-            type->isSubtypeOf(NumberType::get()) &&
-            node->kindOf(attr::value) == AttributeKind::i) {
-          auto i = node->i(attr::value);
-          return [i](Stack& stack) {
-            push(stack, i);
-            return 0;
-          };
-        } else if (
-            type->isSubtypeOf(NumberType::get()) &&
-            node->kindOf(attr::value) == AttributeKind::f) {
-          auto f = node->f(attr::value);
-          return [f](Stack& stack) {
-            push(stack, f);
-            return 0;
-          };
-        } else if(type->isSubtypeOf(ListType::ofInts())) {
-          const auto& is = node->is(attr::value);
-          return [is](Stack& stack) {
-            push(stack, is);
-            return 0;
-          };
-        } else if(type->isSubtypeOf(ListType::ofBools())) {
-          const auto& bs = node->is(attr::value);
-          return [bs](Stack& stack) {
-            push(stack, bs);
-            return 0;
-          };
-        } else if(type->isSubtypeOf(ListType::ofTensors())) {
-          const auto& ts = fmap(node->ts(attr::value), [](const at::Tensor & t) -> at::Tensor {
-            return autograd::make_variable(t);
-          });
-          return [ts](Stack& stack) {
-            push(stack, ts);
-            return 0;
-          };
-        } else if (type == StringType::get()) {
-          const auto& s = node->s(attr::value);
-          return [s](Stack& stack) {
-            push(stack, s);
-            return 0;
-          };
-        } else if (type == DeviceObjType::get()) {
-          auto d = c10::Device(node->s(attr::value));
-          return [d](Stack& stack) {
-            push(stack, d);
-            return 0;
-          };
-        } else {
-          std::stringstream ss;
-          ss << "constant literal not supported for: " << type->str();
-          throw std::runtime_error(ss.str());
-        }
-      }),
+    // Implementation of constant node, computes and IValue
+    Operator(
+        FunctionSchema(
+            prim::Constant,
+            {},
+            {},
+            /*is_vararg=*/false,
+            /*is_varret=*/true),
+        [](const Node* node) -> Operation {
+          TypePtr type = node->output()->type();
+          if (type->isSubtypeOf(DynamicType::get())) {
+            auto t = autograd::make_variable(node->t(attr::value));
+            return [t](Stack& stack) {
+              push(stack, t);
+              return 0;
+            };
+          } else if (type->isSubtypeOf(BoolType::get())) {
+            bool b = node->i(attr::value);
+            return [b](Stack& stack) {
+              push(stack, b);
+              return 0;
+            };
+          } else if (
+              type->isSubtypeOf(NumberType::get()) &&
+              node->kindOf(attr::value) == AttributeKind::i) {
+            auto i = node->i(attr::value);
+            return [i](Stack& stack) {
+              push(stack, i);
+              return 0;
+            };
+          } else if (
+              type->isSubtypeOf(NumberType::get()) &&
+              node->kindOf(attr::value) == AttributeKind::f) {
+            auto f = node->f(attr::value);
+            return [f](Stack& stack) {
+              push(stack, f);
+              return 0;
+            };
+          } else if (type->isSubtypeOf(ListType::ofInts())) {
+            const auto& is = node->is(attr::value);
+            return [is](Stack& stack) {
+              push(stack, is);
+              return 0;
+            };
+          } else if (type->isSubtypeOf(ListType::ofBools())) {
+            const auto& bs = node->is(attr::value);
+            return [bs](Stack& stack) {
+              push(stack, bs);
+              return 0;
+            };
+          } else if (type->isSubtypeOf(ListType::ofTensors())) {
+            const auto& ts = fmap(
+                node->ts(attr::value), [](const at::Tensor& t) -> at::Tensor {
+                  return autograd::make_variable(t);
+                });
+            return [ts](Stack& stack) {
+              push(stack, ts);
+              return 0;
+            };
+          } else if (type == StringType::get()) {
+            const auto& s = node->s(attr::value);
+            return [s](Stack& stack) {
+              push(stack, s);
+              return 0;
+            };
+          } else if (type == DeviceObjType::get()) {
+            auto d = c10::Device(node->s(attr::value));
+            return [d](Stack& stack) {
+              push(stack, d);
+              return 0;
+            };
+          } else {
+            std::stringstream ss;
+            ss << "constant literal not supported for: " << type->str();
+            throw std::runtime_error(ss.str());
+          }
+        }),
 });
 
 c10::optional<IValue> toIValue(const Value* v) {
@@ -151,4 +163,5 @@ c10::optional<IValue> toIValue(const Value* v) {
   op(stack);
   return stack.back();
 }
-}}
+} // namespace jit
+} // namespace torch
index d64bc15..3a787f0 100644 (file)
@@ -1,13 +1,14 @@
 #pragma once
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/ivalue.h>
 #include <torch/csrc/jit/scope.h>
 #include <torch/csrc/jit/source_range.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 
 // helpers for handling constants in the IR
 // - create constant nodes from ints, floats, intlist, Tensors, and other types
 // - implement primitive constant ops.
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 struct Graph;
 struct Value;
@@ -19,14 +20,14 @@ struct TORCH_API constant_not_supported_error : public std::runtime_error {
 
 // note: prefer g.insertConsant(val, loc) which does exactly the same thing
 // this function is only declared/defined here because its implementation is
-// closely related to the implementation of prim::Constant that is also in constants.cpp
+// closely related to the implementation of prim::Constant that is also in
+// constants.cpp
 TORCH_API Value* insertConstant(
     Graph& g,
     const IValue& val,
     c10::optional<SourceRange> loc = c10::nullopt,
     c10::optional<ScopePtr> scope = c10::nullopt);
 
-
 //////////////////////////////////////////////////////////////////////////////////
 // Helper for retrieving constants
 //////////////////////////////////////////////////////////////////////////////////
@@ -39,9 +40,10 @@ TORCH_API c10::optional<IValue> toIValue(const Value* v);
 // same rules as the interpreter
 template <typename T>
 c10::optional<T> constant_as(const Value* v) {
-  if(auto ivalue = toIValue(v)) {
+  if (auto ivalue = toIValue(v)) {
     return ivalue->to<T>();
   }
   return c10::nullopt;
 }
-}}
+} // namespace jit
+} // namespace torch
index 0e786c7..79dce52 100644 (file)
@@ -9,7 +9,9 @@
 #include <torch/csrc/utils/functional.h>
 #include <torch/csrc/utils/memory.h>
 
-namespace torch { namespace jit { namespace detail {
+namespace torch {
+namespace jit {
+namespace detail {
 
 // DynamicDAG is a simple directed acyclic graph that dynamically maintains a
 // topological order as edges/vertices are added and removed.
@@ -19,12 +21,12 @@ namespace torch { namespace jit { namespace detail {
 //   merge black nodes that are directly connected by contracting the
 //   edge between them while still maintaining the DAG and a topological order?
 //   Use contractEdge().
-// - Let's say you have a DAG where each vertex is a Node* and the edges represent
-//   data dependencies. We wish to determine if adding a new Node* with certain
-//   data dependencies (or moving an existing one to use new dependencies) is valid.
-//   Use DynamicDAG::addEdge() to add the new data dependencies to the DAG:
-//   it will either find a valid reordering of the DAG's topological order or throw
-//   if the resulting DAG is invalid.
+// - Let's say you have a DAG where each vertex is a Node* and the edges
+//   represent data dependencies. We wish to determine if adding a new Node*
+//   with certain data dependencies (or moving an existing one to use new
+//   dependencies) is valid. Use DynamicDAG::addEdge() to add the new data
+//   dependencies to the DAG: it will either find a valid reordering of the
+//   DAG's topological order or throw if the resulting DAG is invalid.
 //
 // The implementation is based off of the PK algorithm in the following paper:
 // "A Dynamic Topsort Algorithm for Directed Acyclic Graphs"
@@ -32,12 +34,16 @@ namespace torch { namespace jit { namespace detail {
 // https://www.doc.ic.ac.uk/~phjk/Publications/DynamicTopoSortAlg-JEA-07.pdf
 // It is summarized in [Edge addition] (see DynamicDAG<T>::addEdge)
 
-template <typename T> struct Vertex;
-template <typename T> struct DynamicDAG;
-template <typename T> using vertex_list = std::vector<Vertex<T>*>;
-template <typename T> using unique_vertex = std::unique_ptr<Vertex<T>>;
+template <typename T>
+struct Vertex;
+template <typename T>
+struct DynamicDAG;
+template <typename T>
+using vertex_list = std::vector<Vertex<T>*>;
+template <typename T>
+using unique_vertex = std::unique_ptr<Vertex<T>>;
 
-enum class DFSDirection {forward, backward};
+enum class DFSDirection { forward, backward };
 
 // Used to represent adjacency lists in DynamicDAG.
 // Has set semantics: stores distinct elements.
@@ -71,11 +77,21 @@ struct vertex_set {
       return a->ord < b->ord;
     });
   }
-  size_t size() const { return data_.size(); }
-  iterator begin() { return data_.begin(); }
-  iterator end() { return data_.end(); }
-  reverse_iterator rbegin() { return data_.rbegin(); }
-  reverse_iterator rend() { return data_.rend(); }
+  size_t size() const {
+    return data_.size();
+  }
+  iterator begin() {
+    return data_.begin();
+  }
+  iterator end() {
+    return data_.end();
+  }
+  reverse_iterator rbegin() {
+    return data_.rbegin();
+  }
+  reverse_iterator rend() {
+    return data_.rend();
+  }
 
  private:
   std::vector<Vertex<T>*> data_;
@@ -110,7 +126,9 @@ struct visited_list {
     });
   }
 
-  const vertex_list<T>& vector() { return data_; }
+  const vertex_list<T>& vector() {
+    return data_;
+  }
 
  private:
   vertex_list<T> data_;
@@ -118,20 +136,29 @@ struct visited_list {
 
 template <typename T>
 struct Vertex {
-  Vertex(size_t ord, T datum)
-  : ord(ord), visited_(false) { data.push_back(datum); }
+  Vertex(size_t ord, T datum) : ord(ord), visited_(false) {
+    data.push_back(datum);
+  }
 
   std::vector<T> data;
   size_t ord; // unique topological index
 
   std::string toString();
-  vertex_set<T>& in_edges() { return edges_.in_edges; }
-  vertex_set<T>& out_edges() { return edges_.out_edges; }
-  IOEdges<T>&& move_edges() { return std::move(edges_); }
+  vertex_set<T>& in_edges() {
+    return edges_.in_edges;
+  }
+  vertex_set<T>& out_edges() {
+    return edges_.out_edges;
+  }
+  IOEdges<T>&& move_edges() {
+    return std::move(edges_);
+  }
 
-  bool visited() { return visited_; }
+  bool visited() {
+    return visited_;
+  }
 
-private:
+ private:
   IOEdges<T> edges_;
 
   friend visited_list<T>;
@@ -149,7 +176,9 @@ struct DynamicDAG {
 
   // max_size() >= the number of live vertices.
   // for all vertices v, v.ord < max_size()
-  size_t max_size() const { return vertices_.size(); };
+  size_t max_size() const {
+    return vertices_.size();
+  };
   c10::optional<Vertex<T>*> at(size_t ord) const;
 
   std::string toString();
@@ -179,9 +208,10 @@ struct DynamicDAG {
 // O(vertices_.size()). Used for testing, don't call this often.
 template <typename T>
 size_t DynamicDAG<T>::debugNumVertices() const {
-  return std::count_if(vertices_.begin(), vertices_.end(),
-      [](const unique_vertex<T>& v) {
-        if (v) return true;
+  return std::count_if(
+      vertices_.begin(), vertices_.end(), [](const unique_vertex<T>& v) {
+        if (v)
+          return true;
         return false;
       });
 }
@@ -205,7 +235,8 @@ template <typename T>
 void DynamicDAG<T>::debugCheckInvariants() {
   for (size_t ord = 0; ord < vertices_.size(); ++ord) {
     const auto& vertex = vertices_.at(ord);
-    if (!vertex) continue;
+    if (!vertex)
+      continue;
 
     AT_ASSERTM(vertex->ord == ord, toString());
     for (auto* v : vertex->in_edges()) {
@@ -248,14 +279,15 @@ IOEdges<T> DynamicDAG<T>::removeVertex(Vertex<T>* v) {
  *
  * Assume we are adding an edge x -> y and that ord(x) > ord(y).
  * First, if there is a path y ----> x through some other vertices, then this
- * edge addition would create a cycle. Figure this out via DFS and throw if necessary.
+ * edge addition would create a cycle. Figure this out via DFS and throw if
+ * necessary.
  *
  * Now, consider the set of all vertices v such that ord(x) > ord(v) > ord(y).
  * Call this set the affected region (AR) -- these are the only vertices we
  * need to consider for reordering to make the resulting graph valid.
  *
- * Find all children of y (through DFS) in AR (call this set deltaF and add y to it)
- * Find all parents of x in AR (call this set deltaB and add x to it).
+ * Find all children of y (through DFS) in AR (call this set deltaF and add y to
+ * it) Find all parents of x in AR (call this set deltaB and add x to it).
  *
  * Move y and all the children of y to after x and all the parents of x. The
  * result topological ordering is valid.
@@ -291,7 +323,8 @@ IOEdges<T> DynamicDAG<T>::removeVertex(Vertex<T>* v) {
  * deltaB (sorted) = {c(2), d(4), x(6)}. deltaB ords = { 2, 4, 6 }
  * deltaF (sorted) = {y(1), a(3), b(5)}. deltaF ords = { 1, 3, 5 }
  *
- * 2) append the two lists: the result is the order we want these vertices to have.
+ * 2) append the two lists: the result is the order we want these vertices to
+ *    have.
  * L = {c(2), d(4), x(6), y(1), a(3), b(5)}.
  *
  * 3) Merge the sorted ords: R = { 1, 2, 3, 4, 5, 6 }.
@@ -304,9 +337,9 @@ IOEdges<T> DynamicDAG<T>::removeVertex(Vertex<T>* v) {
  *
  * [Analysis]
  * This is O(|AR| log |AR|). |AR| is equal to ord(consumer) - ord(producer).
- * AR is the "affected region": { v s.t. ord(v) in [ord(producer), ord(consumer)] }
- * consisting of the only vertices that can possibly be moved around due to this
- * edge addition.
+ * AR is the "affected region": { v s.t. ord(v) in [ord(producer),
+ * ord(consumer)] } consisting of the only vertices that can possibly be moved
+ * around due to this edge addition.
  *
  * NB: Pearce and Kelly give a complexity bound of <<delta>> where
  * delta = union(deltaF, deltaB) and <<S>> on a set S is
@@ -316,9 +349,11 @@ template <typename T>
 void DynamicDAG<T>::addEdge(Vertex<T>* producer, Vertex<T>* consumer) {
   JIT_ASSERT(producer != consumer);
 
-  // NB: DynamicDAG is a simple graph. If an edge exists already, don't do anything.
+  // NB: DynamicDAG is a simple graph. If an edge exists already, don't do
+  // anything.
   bool is_distinct = producer->out_edges().insert(consumer);
-  if (!is_distinct) return;
+  if (!is_distinct)
+    return;
   is_distinct = consumer->in_edges().insert(producer);
   JIT_ASSERT(is_distinct);
 
@@ -330,18 +365,26 @@ void DynamicDAG<T>::addEdge(Vertex<T>* producer, Vertex<T>* consumer) {
   visited_list<T> deltaF;
   visited_list<T> deltaB;
 
-  // Search for vertices that are reachable from consumer that have a now incorrect
-  // topological ordering.
-  if (dfsSearch(DFSDirection::forward, consumer, producer,
-                /*bound=*/producer->ord, deltaF)) {
+  // Search for vertices that are reachable from consumer that have a now
+  // incorrect topological ordering.
+  if (dfsSearch(
+          DFSDirection::forward,
+          consumer,
+          producer,
+          /*bound=*/producer->ord,
+          deltaF)) {
     // Path found! This means there's a cycle.
     AT_ERROR("Cycle detected while trying to add edge.");
   }
 
   // Search for vertices that can reach producer that have a now incorrect
   // topological ordering
-  JIT_ASSERT(!dfsSearch(DFSDirection::backward, producer, consumer,
-                        /*bound=*/consumer->ord, deltaB));
+  JIT_ASSERT(!dfsSearch(
+      DFSDirection::backward,
+      producer,
+      consumer,
+      /*bound=*/consumer->ord,
+      deltaB));
 
   // Reorder the vertices that are reachable from consumer to occur BEFORE
   // the vertices that can reach producer.
@@ -353,7 +396,8 @@ void DynamicDAG<T>::addEdge(Vertex<T>* producer, Vertex<T>* consumer) {
 // These are the only vertices that can possibly be moved around
 // during edge contraction.
 //
-// contractEdge is O(|AR| log |AR| * min(|out_edges(producer)|, |in_edges(consumer)|))
+// contractEdge is O(|AR| log |AR| * min(|out_edges(producer)|,
+//                   |in_edges(consumer)|))
 template <typename T>
 bool DynamicDAG<T>::contractEdge(Vertex<T>* producer, Vertex<T>* consumer) {
   JIT_ASSERT(producer != consumer);
@@ -374,10 +418,13 @@ bool DynamicDAG<T>::contractEdge(Vertex<T>* producer, Vertex<T>* consumer) {
 }
 
 template <typename T>
-void DynamicDAG<T>::mergeProducerIntoConsumer(Vertex<T>* producer, Vertex<T>* consumer) {
+void DynamicDAG<T>::mergeProducerIntoConsumer(
+    Vertex<T>* producer,
+    Vertex<T>* consumer) {
   // Optimization: we want to concat lists [producer.data, consumer.data].
   // Instead of inserting into the beginning of consumer.data, do a swap.
-  producer->data.insert(producer->data.end(), consumer->data.begin(), consumer->data.end());
+  producer->data.insert(
+      producer->data.end(), consumer->data.begin(), consumer->data.end());
   std::swap(consumer->data, producer->data);
 
   auto edges = removeVertex(producer);
@@ -396,8 +443,11 @@ void DynamicDAG<T>::mergeProducerIntoConsumer(Vertex<T>* producer, Vertex<T>* co
 }
 
 template <typename T>
-void DynamicDAG<T>::mergeConsumerIntoProducer(Vertex<T>* producer, Vertex<T>* consumer) {
-  producer->data.insert(producer->data.end(), consumer->data.begin(), consumer->data.end());
+void DynamicDAG<T>::mergeConsumerIntoProducer(
+    Vertex<T>* producer,
+    Vertex<T>* consumer) {
+  producer->data.insert(
+      producer->data.end(), consumer->data.begin(), consumer->data.end());
 
   auto edges = removeVertex(consumer);
 
@@ -412,11 +462,12 @@ void DynamicDAG<T>::mergeConsumerIntoProducer(Vertex<T>* producer, Vertex<T>* co
   for (auto* parent : edges.in_edges) {
     addEdge(parent, producer);
   }
-
 }
 
 template <typename T>
-bool DynamicDAG<T>::contractionProducesCycle(Vertex<T>* producer, Vertex<T>* consumer) {
+bool DynamicDAG<T>::contractionProducesCycle(
+    Vertex<T>* producer,
+    Vertex<T>* consumer) {
   visited_list<T> visited;
 
   // If there are multiple paths from producer to consumer then contracting
@@ -426,17 +477,22 @@ bool DynamicDAG<T>::contractionProducesCycle(Vertex<T>* producer, Vertex<T>* con
   // producer -> consumer edge.
   size_t upper_bound = consumer->ord;
   for (auto* child : producer->out_edges()) {
-    if (child == consumer) continue;
-    if (child->visited()) continue; // already visited by dfs
-    if (dfsSearch(DFSDirection::forward, child, consumer, upper_bound, visited)) {
+    if (child == consumer)
+      continue;
+    if (child->visited())
+      continue; // already visited by dfs
+    if (dfsSearch(
+            DFSDirection::forward, child, consumer, upper_bound, visited)) {
       return true;
     }
   }
   return false;
 }
 
-
-static bool is_within_bound(DFSDirection direction, size_t value, size_t bound) {
+static bool is_within_bound(
+    DFSDirection direction,
+    size_t value,
+    size_t bound) {
   if (direction == DFSDirection::forward) {
     return value < bound; // upper bound
   } else {
@@ -467,9 +523,9 @@ bool DynamicDAG<T>::dfsSearch(
     auto* vertex = stack.back();
     stack.pop_back();
 
-    auto& next_edges = (direction == DFSDirection::forward) ?
-      vertex->out_edges() :
-      vertex->in_edges();
+    auto& next_edges = (direction == DFSDirection::forward)
+        ? vertex->out_edges()
+        : vertex->in_edges();
 
     for (Vertex<T>* next : next_edges) {
       if (next == end) {
@@ -485,7 +541,6 @@ bool DynamicDAG<T>::dfsSearch(
   return false;
 }
 
-
 // Reorder deltaB vertices to occur before deltaF vertices.
 template <typename T>
 void DynamicDAG<T>::reorder(visited_list<T> deltaF, visited_list<T> deltaB) {
@@ -508,7 +563,8 @@ void DynamicDAG<T>::reorder(visited_list<T> deltaF, visited_list<T> deltaB) {
   }
 
   // Sort the ords by merging two already sorted lists into a large sorted list.
-  // input (example): deltaB = { v(1), v(4), v(7) } , deltaF = { v(0), v(2), v(5) }.
+  // input (example): deltaB = { v(1), v(4), v(7) } ,
+  //                  deltaF = { v(0), v(2), v(5) }.
   // output: { 0, 1, 2, 4, 5, 7 }.
   std::vector<size_t> gathered_ords;
   gathered_ords.reserve(num_affected);
@@ -519,7 +575,10 @@ void DynamicDAG<T>::reorder(visited_list<T> deltaF, visited_list<T> deltaB) {
   for (const auto* v : deltaF_) {
     gathered_ords.push_back(v->ord);
   }
-  std::inplace_merge(gathered_ords.begin(), gathered_ords.begin() + middle, gathered_ords.end());
+  std::inplace_merge(
+      gathered_ords.begin(),
+      gathered_ords.begin() + middle,
+      gathered_ords.end());
 
   // Return the vertices back into the vertices_ storage.
   for (size_t i = 0; i < num_affected; ++i) {
@@ -555,7 +614,7 @@ std::string Vertex<T>::toString() {
       ss << "  " << d;
     }
   }
-  ss << "} ("<< ord << ") -> [";
+  ss << "} (" << ord << ") -> [";
   for (auto* c : out_edges()) {
     ss << c->ord << " ";
   }
@@ -563,4 +622,6 @@ std::string Vertex<T>::toString() {
   return ss.str();
 }
 
-}}}
+} // namespace detail
+} // namespace jit
+} // namespace torch
index f7207f7..d6cf849 100644 (file)
@@ -1,15 +1,14 @@
 #include <google/protobuf/util/json_util.h>
 #include <google/protobuf/util/type_resolver_util.h>
 
-#include <torch/csrc/jit/export.h>
 #include <torch/csrc/autograd/symbolic.h>
+#include <torch/csrc/jit/export.h>
 #include <torch/csrc/onnx/onnx.h>
 
-#include <torch/csrc/utils/functional.h>
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/python_print.h>
-
+#include <torch/csrc/utils/functional.h>
 
 #include <caffe2/core/types.h>
 #include <caffe2/proto/caffe2_pb.h>
@@ -27,7 +26,8 @@
 #include <string>
 #include <vector>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 namespace onnx_torch = ::torch::onnx;
@@ -45,86 +45,104 @@ std::string getNodeStackTraceString(const Node* n) {
   return ss.str();
 }
 
-void validateBlock(Block *b, onnx_torch::OperatorExportTypes operator_export_type) {
+void validateBlock(
+    Block* b,
+    onnx_torch::OperatorExportTypes operator_export_type) {
   for (auto node : b->nodes()) {
-    for (Block *sub_block : node->blocks()) {
+    for (Blocksub_block : node->blocks()) {
       validateBlock(sub_block, operator_export_type);
     }
     // Macro'ed so we get a marginally better line number on failed export
-#define FAIL_EXPORT(name) \
-      throw std::runtime_error(std::string("ONNX export failed: ") + name + "\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
+#define FAIL_EXPORT(name)                          \
+  throw std::runtime_error(                        \
+      std::string("ONNX export failed: ") + name + \
+      "\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
     IR_IF(node, PythonOp)
-      auto py_node = static_cast<torch::jit::PythonOp*>(value);
-      FAIL_EXPORT(
-          "Couldn't export Python operator " + py_node->name() +
-          "\n\nDefined at:\n" + getNodeStackTraceString(node))
+    auto py_node = static_cast<torch::jit::PythonOp*>(value);
+    FAIL_EXPORT(
+        "Couldn't export Python operator " + py_node->name() +
+        "\n\nDefined at:\n" + getNodeStackTraceString(node))
     IR_ELSE()
-      // Special error messages for certain types of operators
-      if (node->kind() == aten::expand) {
-        if (operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
-          WithInsertPoint guard(node);
-          auto* new_node = b->owningGraph()->insertNode(
-            b->owningGraph()->create(Symbol(::torch::jit::onnx::ATen), node->inputs(), node->outputs().size()));
-          for (size_t i = 0; i < node->outputs().size(); ++i) {
-            node->output(i)->replaceAllUsesWith(new_node->output(i));
-          }
-          new_node->s_(Symbol::fromQualString("attr::operator"), "expand");
+    // Special error messages for certain types of operators
+    if (node->kind() == aten::expand) {
+      if (operator_export_type ==
+          onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
+        WithInsertPoint guard(node);
+        auto* new_node = b->owningGraph()->insertNode(b->owningGraph()->create(
+            Symbol(::torch::jit::onnx::ATen),
+            node->inputs(),
+            node->outputs().size()));
+        for (size_t i = 0; i < node->outputs().size(); ++i) {
+          node->output(i)->replaceAllUsesWith(new_node->output(i));
         }
+        new_node->s_(Symbol::fromQualString("attr::operator"), "expand");
       }
-      if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
-        FAIL_EXPORT(
-            "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
-            getNodeStackTraceString(node));
-      }
-      bool is_aten_enabled = operator_export_type ==
-              onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
-          operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
-      if (!node->kind().is_onnx() && !is_aten_enabled &&
-          node->kind() != prim::Undefined) {
-        FAIL_EXPORT(
-            "Couldn't export operator " + node->kind().toDisplayString() + "\n\nDefined at:\n" +
-            getNodeStackTraceString(node));
-      }
+    }
+    if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
+      FAIL_EXPORT(
+          "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
+          getNodeStackTraceString(node));
+    }
+    bool is_aten_enabled = operator_export_type ==
+            onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
+        operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
+    if (!node->kind().is_onnx() && !is_aten_enabled &&
+        node->kind() != prim::Undefined) {
+      FAIL_EXPORT(
+          "Couldn't export operator " + node->kind().toDisplayString() +
+          "\n\nDefined at:\n" + getNodeStackTraceString(node));
+    }
     IR_END()
 #undef FAIL_EXPORT
   }
 }
 
-void validateGraph(const std::shared_ptr<Graph>& graph, onnx_torch::OperatorExportTypes operator_export_type) {
+void validateGraph(
+    const std::shared_ptr<Graph>& graph,
+    onnx_torch::OperatorExportTypes operator_export_type) {
   validateBlock(graph->block(), operator_export_type);
   EliminateDeadCode(graph->block());
 }
 
 class EncoderBase {
  public:
-  EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, bool strip_doc);
+  EncoderBase(
+      onnx_torch::OperatorExportTypes operator_export_type,
+      bool strip_doc);
 
   onnx::ModelProto get_model_proto() {
     return model_proto_;
   }
 
  protected:
-  void EncodeGraph(onnx::GraphProto *graph_proto,
-                   const std::shared_ptr<Graph> &graph,
-                   const std::vector<at::Tensor> &initializers = {});
+  void EncodeGraph(
+      onnx::GraphProto* graph_proto,
+      const std::shared_ptr<Graph>& graph,
+      const std::vector<at::Tensor>& initializers = {});
 
-  void EncodeBlock(onnx::GraphProto *graph_proto,
-                   const Block *block,
-                   const std::vector<at::Tensor> &initializers = {});
+  void EncodeBlock(
+      onnx::GraphProto* graph_proto,
+      const Block* block,
+      const std::vector<at::Tensor>& initializers = {});
 
   virtual void EncodeTensor(
       onnx::TensorProto* tensor_proto,
       const at::Tensor& tensor,
       const c10::optional<std::string> external_ref = {}) = 0;
 
-  virtual void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
-                                           const Value* n) {};
+  virtual void EncodeIntermediateValueInfo(
+      onnx::GraphProto* graph_proto,
+      const Value* n){};
 
-  virtual void EncodeValueInfo(onnx::GraphProto *graph_proto,
-                               onnx::ValueInfoProto* v,
-                               const Value* n);
+  virtual void EncodeValueInfo(
+      onnx::GraphProto* graph_proto,
+      onnx::ValueInfoProto* v,
+      const Value* n);
 
-  void AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name);
+  void AddAttribute(
+      onnx::NodeProto* node_proto,
+      const jit::Node* node,
+      const jit::Symbol name);
 
   onnx::ModelProto model_proto_;
   size_t num_blocks_;
@@ -133,7 +151,7 @@ class EncoderBase {
 };
 
 onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
-  switch(at_type) {
+  switch (at_type) {
     case at::kDouble:
       return onnx::TensorProto_DataType_DOUBLE;
     case at::kFloat:
@@ -155,7 +173,9 @@ onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
   }
 }
 
-EncoderBase::EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, bool strip_doc)
+EncoderBase::EncoderBase(
+    onnx_torch::OperatorExportTypes operator_export_type,
+    bool strip_doc)
     : num_blocks_(0),
       operator_export_type_(operator_export_type),
       strip_doc_(strip_doc) {
@@ -165,7 +185,7 @@ EncoderBase::EncoderBase(onnx_torch::OperatorExportTypes operator_export_type, b
 }
 
 void EncoderBase::EncodeValueInfo(
-    onnx::GraphProto *graph_proto,
+    onnx::GraphProtograph_proto,
     onnx::ValueInfoProto* v,
     const Value* n) {
   v->set_name(n->uniqueName());
@@ -186,15 +206,16 @@ void EncoderBase::EncodeValueInfo(
 }
 
 void EncoderBase::EncodeGraph(
-    onnx::GraphProto *graph_proto,
-    const std::shared_ptr<Graph> &graph,
-    const std::vector<at::Tensor> &initializers) {
+    onnx::GraphProtograph_proto,
+    const std::shared_ptr<Graph>graph,
+    const std::vector<at::Tensor>initializers) {
   EncodeBlock(graph_proto, graph->block(), initializers);
 }
 
 void EncoderBase::EncodeBlock(
-    onnx::GraphProto *graph_proto, const Block *block,
-    const std::vector<at::Tensor> &initializers) {
+    onnx::GraphProto* graph_proto,
+    const Block* block,
+    const std::vector<at::Tensor>& initializers) {
   JIT_ASSERT(graph_proto != nullptr);
   std::string block_name = "torch-jit-export";
   if (num_blocks_) {
@@ -212,7 +233,8 @@ void EncoderBase::EncodeBlock(
     EncodeValueInfo(graph_proto, v, output);
   }
   for (auto node : block->nodes()) {
-    bool is_raw_export = operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
+    bool is_raw_export =
+        operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
     if (node->kind() == prim::Undefined && !is_raw_export) {
       // Undefined nodes are used to implement optional inputs. One
       // way to "not provide" an optional input is to create an
@@ -225,26 +247,25 @@ void EncoderBase::EncodeBlock(
       node->getSourceLocation()->highlight(ss);
       p_n->set_doc_string(ss.str());
     }
-    for(auto input : node->inputs()) {
+    for (auto input : node->inputs()) {
       if (input->node()->kind() == prim::Undefined && !is_raw_export) {
         p_n->add_input("");
       } else {
         p_n->add_input(input->uniqueName());
       }
     }
-    for(auto output : node->outputs()) {
+    for (auto output : node->outputs()) {
       p_n->add_output(output->uniqueName());
       EncodeIntermediateValueInfo(graph_proto, output);
     }
     if (is_raw_export) {
       JIT_ASSERT(!node->kind().is_onnx());
       p_n->set_domain(node->kind().domainString());
-    }
-    else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
+    } else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
       JIT_ASSERT(node->kind().is_onnx());
     }
     p_n->set_op_type(node->kind().toUnqualString());
-    for(auto attr_name : node->attributeNames()) {
+    for (auto attr_name : node->attributeNames()) {
       AddAttribute(p_n, node, attr_name);
     }
     if (is_raw_export && node->blocks().size() > 0) {
@@ -284,7 +305,7 @@ void EncoderBase::EncodeBlock(
   auto num_initializers = initializers.size();
   JIT_ASSERT(block->inputs().size() >= num_initializers);
   size_t inputs_count = block->inputs().size() - num_initializers;
-  for (auto & tensor : initializers) {
+  for (auto& tensor : initializers) {
     // TODO: stop using positions to determine which initializers
     // match to which inputs
     std::string name = graph_proto->input(inputs_count++).name();
@@ -294,18 +315,21 @@ void EncoderBase::EncodeBlock(
   }
 }
 
-void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *node, const jit::Symbol name) {
+void EncoderBase::AddAttribute(
+    onnx::NodeProto* node_proto,
+    const jit::Node* node,
+    const jit::Symbol name) {
   auto attr = node_proto->add_attribute();
   JIT_ASSERT(name.is_attr());
   attr->set_name(name.toUnqualString());
-  switch(node->kindOf(name)) {
+  switch (node->kindOf(name)) {
     case AttributeKind::f:
       attr->set_f(node->f(name));
       attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
       break;
     case AttributeKind::fs:
       attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
-      for(auto & v : node->fs(name))
+      for (auto& v : node->fs(name))
         attr->add_floats(v);
       break;
     case AttributeKind::i:
@@ -314,7 +338,7 @@ void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *nod
       break;
     case AttributeKind::is:
       attr->set_type(onnx::AttributeProto_AttributeType_INTS);
-      for(auto & v : node->is(name))
+      for (auto& v : node->is(name))
         attr->add_ints(v);
       break;
     case AttributeKind::s:
@@ -323,7 +347,7 @@ void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *nod
       break;
     case AttributeKind::ss:
       attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
-      for(auto & v : node->ss(name))
+      for (auto& v : node->ss(name))
         attr->add_strings(v);
       break;
     case AttributeKind::t: {
@@ -333,7 +357,7 @@ void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *nod
     } break;
     case AttributeKind::ts:
       attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
-      for(auto & v : node->ts(name)) {
+      for (auto& v : node->ts(name)) {
         auto t = attr->add_tensors();
         EncodeTensor(t, v);
       }
@@ -345,7 +369,7 @@ void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *nod
     } break;
     case AttributeKind::gs:
       attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
-      for(auto & v : node->gs(name)) {
+      for (auto& v : node->gs(name)) {
         auto g = attr->add_graphs();
         EncodeGraph(g, v);
       }
@@ -355,14 +379,15 @@ void EncoderBase::AddAttribute(onnx::NodeProto *node_proto, const jit::Node *nod
   }
 }
 
-class GraphEncoder: public EncoderBase {
+class GraphEncoder : public EncoderBase {
  public:
-  GraphEncoder(const std::shared_ptr<Graph> &graph,
-               int64_t onnx_opset_version,
-               onnx_torch::OperatorExportTypes operator_export_type,
-               const std::vector<at::Tensor> &initializers,
-               bool defer_weight_export,
-               bool strip_doc);
+  GraphEncoder(
+      const std::shared_ptr<Graph>& graph,
+      int64_t onnx_opset_version,
+      onnx_torch::OperatorExportTypes operator_export_type,
+      const std::vector<at::Tensor>& initializers,
+      bool defer_weight_export,
+      bool strip_doc);
 
   RawDataExportMap get_raw_data_export_map() {
     return raw_data_export_map_;
@@ -379,10 +404,10 @@ class GraphEncoder: public EncoderBase {
 };
 
 GraphEncoder::GraphEncoder(
-    const std::shared_ptr<Graph> &graph,
+    const std::shared_ptr<Graph>graph,
     int64_t onnx_opset_version,
     onnx_torch::OperatorExportTypes operator_export_type,
-    const std::vector<at::Tensor> &initializers,
+    const std::vector<at::Tensor>initializers,
     bool defer_weight_export,
     bool strip_doc)
     : EncoderBase(operator_export_type, strip_doc),
@@ -402,7 +427,7 @@ void GraphEncoder::EncodeTensor(
     onnx::TensorProto* tensor_proto,
     const at::Tensor& tensor,
     const c10::optional<std::string> external_ref) {
-  for(auto d : tensor.sizes()) {
+  for (auto d : tensor.sizes()) {
     tensor_proto->add_dims(d);
   }
   tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.type().scalarType()));
@@ -420,7 +445,9 @@ void GraphEncoder::EncodeTensor(
     tensor_proto->set_raw_data("__EXTERNAL");
   } else {
     JIT_ASSERT(t.is_contiguous());
-    tensor_proto->set_raw_data(std::string(static_cast<char*>(t.data_ptr()),  t.type().elementSizeInBytes() * t.numel()));
+    tensor_proto->set_raw_data(std::string(
+        static_cast<char*>(t.data_ptr()),
+        t.type().elementSizeInBytes() * t.numel()));
   }
 }
 
@@ -517,7 +544,8 @@ void ScriptModuleSerializer::convertModel(
   model_def->set_producer_version("1.0"); // TODO: set the producer version
                                           // using appropriate function call
   model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);
-  convertModule(module, "", writer_.archiveName(), model_def->mutable_main_module());
+  convertModule(
+      module, "", writer_.archiveName(), model_def->mutable_main_module());
   writeTensorTable(model_def);
 }
 
@@ -564,7 +592,8 @@ void ScriptModuleSerializer::convertAndWriteTensor(
                                /* stride = */ {1})
                            .cpu();
       AT_ASSERT(
-          storage_tensor.type().elementSizeInBytes() * storage_tensor.storage().size() ==
+          storage_tensor.type().elementSizeInBytes() *
+              storage_tensor.storage().size() ==
           record_size);
     }
     std::string name = "tensors/" + std::to_string(tensor_id);
@@ -616,7 +645,8 @@ void ScriptModuleSerializer::convertModule(
     std::stringstream filename;
     filename << "code/" << module_name.str() << ".py";
     std::string methods_str = methods.str();
-    writer_.writeRecord(filename.str(), methods_str.c_str(), methods_str.size());
+    writer_.writeRecord(
+        filename.str(), methods_str.c_str(), methods_str.size());
     record->set_key(filename.str());
   }
 
@@ -656,7 +686,7 @@ void dump(const onnx::TensorProto& tensor, std::ostream& stream) {
 
 void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
   for (int i = 0; i < shape.dim_size(); ++i) {
-    auto &dim = shape.dim(i);
+    autodim = shape.dim(i);
     if (dim.has_dim_value()) {
       stream << dim.dim_value();
     } else {
@@ -676,15 +706,17 @@ void dump(const onnx::TypeProto& type, std::ostream& stream) {
 }
 
 void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
-  stream << "{name: \"" << value_info.name()
-         << "\", type:";
+  stream << "{name: \"" << value_info.name() << "\", type:";
   dump(value_info.type(), stream);
   stream << "}";
 }
 
 void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);
 
-void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent) {
+void dump(
+    const onnx::AttributeProto& attr,
+    std::ostream& stream,
+    size_t indent) {
   stream << "{ name: '" << attr.name() << "', type: ";
   if (attr.has_f()) {
     stream << "float, value: " << attr.f();
@@ -694,7 +726,7 @@ void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent)
     stream << "string, value: '" << attr.s() << "'";
   } else if (attr.has_g()) {
     stream << "graph, value:\n";
-    dump(attr.g(), stream, indent+1);
+    dump(attr.g(), stream, indent + 1);
     stream << nlidt(indent);
   } else if (attr.has_t()) {
     stream << "tensor, value:";
@@ -712,7 +744,8 @@ void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent)
   } else if (attr.strings_size()) {
     stream << "strings, values: [";
     for (int i = 0; i < attr.strings_size(); ++i)
-      stream << "'" << attr.strings(i) << "'" << (i == attr.strings_size() - 1 ? "" : " ");
+      stream << "'" << attr.strings(i) << "'"
+             << (i == attr.strings_size() - 1 ? "" : " ");
     stream << "]";
   } else if (attr.tensors_size()) {
     stream << "tensors, values: [";
@@ -723,7 +756,7 @@ void dump(const onnx::AttributeProto& attr, std::ostream& stream, size_t indent)
   } else if (attr.graphs_size()) {
     stream << "graphs, values: [";
     for (auto& g : attr.graphs()) {
-      dump(g, stream, indent+1);
+      dump(g, stream, indent + 1);
     }
     stream << "]";
   } else {
@@ -743,58 +776,56 @@ void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) {
   }
   stream << "], attributes: [";
   for (int i = 0; i < node.attribute_size(); ++i) {
-    dump(node.attribute(i), stream, indent+1);
+    dump(node.attribute(i), stream, indent + 1);
     stream << (i == node.attribute_size() - 1 ? "" : ",");
   }
   stream << "]}";
 }
 
 void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
-  stream << idt(indent) << "GraphProto {" << nlidt(indent+1)
-         << "name: \"" << graph.name() << "\"" << nlidt(indent+1)
-         << "inputs: [";
+  stream << idt(indent) << "GraphProto {" << nlidt(indent + 1) << "name: \""
+         << graph.name() << "\"" << nlidt(indent + 1) << "inputs: [";
   for (int i = 0; i < graph.input_size(); ++i) {
     dump(graph.input(i), stream);
     stream << (i == graph.input_size() - 1 ? "" : ",");
   }
-  stream << "]" << nlidt(indent+1)
-         << "outputs: [";
+  stream << "]" << nlidt(indent + 1) << "outputs: [";
   for (int i = 0; i < graph.output_size(); ++i) {
     dump(graph.output(i), stream);
     stream << (i == graph.output_size() - 1 ? "" : ",");
   }
-  stream << "]" << nlidt(indent+1)
-         << "initializers: [";
+  stream << "]" << nlidt(indent + 1) << "initializers: [";
   for (int i = 0; i < graph.initializer_size(); ++i) {
     dump(graph.initializer(i), stream);
     stream << (i == graph.initializer_size() - 1 ? "" : ",");
   }
-  stream << "]" << nlidt(indent+1)
-         << "nodes: [" << nlidt(indent+2);
+  stream << "]" << nlidt(indent + 1) << "nodes: [" << nlidt(indent + 2);
   for (int i = 0; i < graph.node_size(); ++i) {
-    dump(graph.node(i), stream, indent+2);
-    if (i != graph.node_size() - 1) stream << "," << nlidt(indent+2);
+    dump(graph.node(i), stream, indent + 2);
+    if (i != graph.node_size() - 1)
+      stream << "," << nlidt(indent + 2);
   }
-  stream << nlidt(indent+1) << "]\n" << idt(indent) << "}\n";
+  stream << nlidt(indent + 1) << "]\n" << idt(indent) << "}\n";
 }
 
-void dump(const onnx::OperatorSetIdProto& operator_set_id, std::ostream& stream) {
+void dump(
+    const onnx::OperatorSetIdProto& operator_set_id,
+    std::ostream& stream) {
   stream << "OperatorSetIdProto { domain: " << operator_set_id.domain() << "}";
 }
 
 void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
-  stream << idt(indent)
-         << "ModelProto {" << nlidt(indent+1)
-         << "producer_name: \"" << model.producer_name() << "\"" << nlidt(indent+1)
-         << "domain: \"" << model.domain() << "\"" << nlidt(indent+1)
-         << "doc_string: \"" << model.doc_string() << "\"";
+  stream << idt(indent) << "ModelProto {" << nlidt(indent + 1)
+         << "producer_name: \"" << model.producer_name() << "\""
+         << nlidt(indent + 1) << "domain: \"" << model.domain() << "\""
+         << nlidt(indent + 1) << "doc_string: \"" << model.doc_string() << "\"";
   if (model.has_graph()) {
-    stream << nlidt(indent+1) << "graph:\n";
-    dump(model.graph(), stream, indent+2);
+    stream << nlidt(indent + 1) << "graph:\n";
+    dump(model.graph(), stream, indent + 2);
   }
   if (model.opset_import_size()) {
-    stream << idt(indent+1) << "opset_import: [";
-    for (auto &opset_imp : model.opset_import()) {
+    stream << idt(indent + 1) << "opset_import: [";
+    for (autoopset_imp : model.opset_import()) {
       dump(opset_imp, stream);
     }
     stream << "],\n";
@@ -811,14 +842,19 @@ std::string prettyPrint(const onnx::ModelProto& model) {
 } // namespace
 
 std::string pretty_print_onnx(
-                        const std::shared_ptr<Graph> &graph,
-                        const std::vector<at::Tensor> &initializers,
-                        int64_t onnx_opset_version,
-                        bool defer_weight_export,
-                        ::torch::onnx::OperatorExportTypes operator_export_type,
-                        bool google_printer) {
+    const std::shared_ptr<Graph>& graph,
+    const std::vector<at::Tensor>& initializers,
+    int64_t onnx_opset_version,
+    bool defer_weight_export,
+    ::torch::onnx::OperatorExportTypes operator_export_type,
+    bool google_printer) {
   auto graph_encoder = GraphEncoder(
-    graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export, true);
+      graph,
+      onnx_opset_version,
+      operator_export_type,
+      initializers,
+      defer_weight_export,
+      true);
   if (google_printer) {
     return graph_encoder.get_model_proto().DebugString();
   }
@@ -831,15 +867,21 @@ std::string pretty_print_onnx(
 // be interpretable by a ONNX-compatible framework. However, PyTorch or
 // libtorch will be able to import the IR and play it back.
 std::tuple<std::string, RawDataExportMap> export_onnx(
-                        const std::shared_ptr<Graph> &graph,
-                        const std::vector<at::Tensor> &initializers,
-                        int64_t onnx_opset_version,
-                        bool defer_weight_export,
-                        ::torch::onnx::OperatorExportTypes operator_export_type) {
+    const std::shared_ptr<Graph>& graph,
+    const std::vector<at::Tensor>& initializers,
+    int64_t onnx_opset_version,
+    bool defer_weight_export,
+    ::torch::onnx::OperatorExportTypes operator_export_type) {
   auto graph_encoder = GraphEncoder(
-    graph, onnx_opset_version, operator_export_type, initializers, defer_weight_export, false);
-  return std::make_tuple(graph_encoder.get_model_proto().SerializeAsString(),
-                         graph_encoder.get_raw_data_export_map());
+      graph,
+      onnx_opset_version,
+      operator_export_type,
+      initializers,
+      defer_weight_export,
+      false);
+  return std::make_tuple(
+      graph_encoder.get_model_proto().SerializeAsString(),
+      graph_encoder.get_raw_data_export_map());
 }
 
 void ExportModule(const script::Module& module, std::ostream& out) {
@@ -847,9 +889,10 @@ void ExportModule(const script::Module& module, std::ostream& out) {
   serializer.serialize(module);
 }
 
-void ExportModule(const script::Module& module, const std::string &filename) {
+void ExportModule(const script::Module& module, const std::stringfilename) {
   ScriptModuleSerializer serializer(filename);
   serializer.serialize(module);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 357f2ae..1e274d8 100644 (file)
@@ -6,7 +6,8 @@
 
 #include <ostream>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // This map is used to keep track of parameters that should be exported
 // externally. When `defer_weight_export` is true, the returned map contains
@@ -23,25 +24,24 @@ TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::vector<at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export = false,
-    ::torch::onnx::OperatorExportTypes operator_export_type
-      = ::torch::onnx::OperatorExportTypes::ONNX);
+    ::torch::onnx::OperatorExportTypes operator_export_type =
+        ::torch::onnx::OperatorExportTypes::ONNX);
 
 // For testing purposes
 TORCH_API std::string pretty_print_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor> & initializers,
+    const std::vector<at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
-    ::torch::onnx::OperatorExportTypes operator_export_type
-      = ::torch::onnx::OperatorExportTypes::ONNX,
+    ::torch::onnx::OperatorExportTypes operator_export_type =
+        ::torch::onnx::OperatorExportTypes::ONNX,
     bool google_printer = false);
 
-TORCH_API void ExportModule(
-    const script::Module& module,
-    std::ostream& out);
+TORCH_API void ExportModule(const script::Module& module, std::ostream& out);
 
 TORCH_API void ExportModule(
     const script::Module& module,
     const std::string& filename);
 
-}}
+} // namespace jit
+} // namespace torch
index a3a6110..350783c 100644 (file)
@@ -1,8 +1,10 @@
 #include <ATen/core/function_schema.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-using ::c10::FunctionSchema;
 using ::c10::Argument;
+using ::c10::FunctionSchema;
 
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index e8c7352..d099395 100644 (file)
@@ -4,14 +4,16 @@
 
 #include <ATen/ATen.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/fuser/tensor_desc.h>
 #include <torch/csrc/utils/functional.h> // fmap
 #include <torch/csrc/utils/hash.h>
-#include <torch/csrc/jit/fuser/tensor_desc.h>
 
-#include <vector>
 #include <cstdint>
+#include <vector>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // Describes the (runtime) arguments to a kernel.
 // ArgSpecs are also used as keys to lookup instantiated kernels, so
@@ -19,22 +21,19 @@ namespace torch { namespace jit { namespace fuser {
 // Note: the device to run on is included in the arg spec because kernels
 //  are compiled per-device.
 struct TORCH_API ArgSpec {
-  ArgSpec(
-    at::TensorList inputs
-  , const int _device)
-  : descs_{fmap<TensorDesc>(inputs)}
-  , hash_code_{torch::get_hash(_device, inputs.size(), descs_)} 
-  , device_{_device}
-  { }
+  ArgSpec(at::TensorList inputs, const int _device)
+      : descs_{fmap<TensorDesc>(inputs)},
+        hash_code_{torch::get_hash(_device, inputs.size(), descs_)},
+        device_{_device} {}
 
   // (Common) hash function
-  static size_t hash(const ArgSpec& spec) { return spec.hash_code_; }
+  static size_t hash(const ArgSpec& spec) {
+    return spec.hash_code_;
+  }
 
   // Comparators
   bool operator==(const ArgSpec& other) const {
-    return (
-       descs_ == other.descs_
-    && device_ == other.device_);
+    return (descs_ == other.descs_ && device_ == other.device_);
   }
 
   bool operator!=(const ArgSpec& spec) const {
@@ -42,18 +41,24 @@ struct TORCH_API ArgSpec {
   }
 
   // Getters
-  size_t hashCode() const { return hash_code_; }
-  const std::vector<TensorDesc>& descs() const { return descs_; }
-  int device() const { return device_; }
+  size_t hashCode() const {
+    return hash_code_;
+  }
+  const std::vector<TensorDesc>& descs() const {
+    return descs_;
+  }
+  int device() const {
+    return device_;
+  }
 
-private:
+ private:
   std::vector<TensorDesc> descs_;
   size_t hash_code_;
   int device_;
 };
 
 } // namespace fuser
-} // namespace jit 
+} // namespace jit
 } // namespace torch
 
 #endif // USE_CUDA_FUSER || USE_CPU_FUSER
index 47d6445..512865e 100644 (file)
@@ -1,30 +1,32 @@
 #include <torch/csrc/jit/fuser/codegen.h>
 
 #include <ATen/ATen.h>
-#include <torch/csrc/jit/code_template.h>
-#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/code_template.h>
 #include <torch/csrc/jit/fuser/compiler.h>
 #include <torch/csrc/jit/fuser/config.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/fuser/tensor_info.h>
+#include <torch/csrc/jit/ir.h>
 
 #if USE_CUDA_FUSER
-  #include <torch/csrc/jit/fuser/cuda/resource_strings.h>
+#include <torch/csrc/jit/fuser/cuda/resource_strings.h>
 #endif
 
 #if USE_CPU_FUSER
-  #include <torch/csrc/jit/fuser/cpu/resource_strings.h>
+#include <torch/csrc/jit/fuser/cpu/resource_strings.h>
 #endif
 
-#include <tuple>
+#include <cmath>
+#include <cstdint>
 #include <iostream>
 #include <sstream>
-#include <cstdint>
+#include <tuple>
 #include <vector>
-#include <cmath>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // Template for computing the offset into the tensor to access a value
 static auto dim_calc = CodeTemplate(R"(
@@ -33,7 +35,6 @@ size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
 ${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
 )");
 
-
 static std::string valueName(const Value* n) {
   return "n" + std::to_string(n->unique());
 }
@@ -71,11 +72,12 @@ static const char* scalarTypeName(const at::ScalarType type) {
     return "half";
   }
 
-  switch(type) {
-    #define DEFINE_CASE(ctype,name,_) \
-      case at::ScalarType::name: return #ctype;
+  switch (type) {
+#define DEFINE_CASE(ctype, name, _) \
+  case at::ScalarType::name:        \
+    return #ctype;
     AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(DEFINE_CASE)
-    #undef DEFINE_CASE
+#undef DEFINE_CASE
     default:
       throw std::runtime_error("unknown scalar type");
   }
@@ -88,7 +90,6 @@ static const char* calcScalarTypeName(const at::ScalarType type) {
   return scalarTypeName(type);
 }
 
-
 static std::string variableType(const std::shared_ptr<c10::Type>& t) {
   if (t->kind() == TypeKind::IntType) {
     return "int";
@@ -101,17 +102,21 @@ static std::string variableType(const std::shared_ptr<c10::Type>& t) {
     return calcScalarTypeName(tt->scalarType());
   }
   // something went wrong with the type analysis during shape propagation
-  throw std::runtime_error("unknown scalar type during JIT fusion code generation");
+  throw std::runtime_error(
+      "unknown scalar type during JIT fusion code generation");
 }
 
-static std::string typeCastedValueName(const std::shared_ptr<c10::Type>& t, const at::ScalarType outtype, const std::string& vn) {
+static std::string typeCastedValueName(
+    const std::shared_ptr<c10::Type>& t,
+    const at::ScalarType outtype,
+    const std::string& vn) {
   if (t->kind() == TypeKind::IntType || t->kind() == TypeKind::BoolType) {
-    if (! isIntegralType(outtype)) {
+    if (!isIntegralType(outtype)) {
       return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
     }
     return vn;
   } else if (t->kind() == TypeKind::FloatType) {
-    if (! isFloatingType(outtype)) {
+    if (!isFloatingType(outtype)) {
       return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
     }
     return vn;
@@ -123,89 +128,90 @@ static std::string typeCastedValueName(const std::shared_ptr<c10::Type>& t, cons
     return vn;
   }
   // something went wrong with the type analysis during shape propagation
-  throw std::runtime_error("unknown scalar type during JIT fusion code generation");
+  throw std::runtime_error(
+      "unknown scalar type during JIT fusion code generation");
 }
 
 // Writes "simple mappable" ops
 static std::string encodeRHS(const Node* n) {
   static std::unordered_map<NodeKind, std::string> simple_map_ops = {
-    // unary
-    {aten::_cast_Float, "static_cast<float>(${0})"},
-    {aten::abs, "fabs(${0})"},
-    {aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
-    {aten::relu, "${0} < 0 ? 0.f : ${0} "},
-    {aten::log, "logf(${0})"},
-    {aten::log10, "log10f(${0})"},
-    {aten::log1p, "log1pf(${0})"},
-    {aten::log2,  "log2f(${0})"},
-    {aten::lgamma, "lgammaf(${0})"},
-    {aten::exp, "expf(${0})"},
-    {aten::expm1, "expm1f(${0})"},
-    {aten::erf, "erff(${0})"},
-    {aten::erfc, "erfcf(${0})"},
-    {aten::cos, "cosf(${0})"},
-    {aten::acos, "acosf(${0})"},
-    {aten::cosh, "coshf(${0})"},
-    {aten::sin, "sinf(${0})"},
-    {aten::asin, "asinf(${0})"},
-    {aten::sinh, "sinhf(${0})"},
-    {aten::tan, "tanf(${0})"},
-    {aten::atan, "atanf(${0})"},
-    {aten::tanh, "tanhf(${0})"},
-    {aten::sqrt, "sqrtf(${0})"},
-    {aten::rsqrt, "rsqrtf(${0})"},
-    {aten::ceil, "ceilf(${0})"},
-    {aten::floor, "floorf(${0})"},
-    {aten::round, "roundf(${0})"},
-    {aten::trunc, "truncf(${0})"},
-    {aten::frac, "fracf(${0})"},
-    {aten::reciprocal, "1.f/(${0})"},
-    {aten::neg, "-${0}"},
-    //simple binary
-    {aten::atan2, "atan2(${0}, ${1})"},
-    {aten::min, "fminf(${0}, ${1})"},
-    {aten::max, "fmaxf(${0}, ${1})"},
-
-    //binary with other
-    // TODO: some of these ops will not get generated because
-    // we only work on float inputs/outputs, but they are here to record
-    // that they are valid mappable ops once we handle more type
-
-    {aten::__and__, "${0} && ${1}"},
-    {aten::__lshift__, "${0} << ${1}"},
-    {aten::__or__, "${0} || ${1}"},
-    {aten::__rshift__, "${0} >> ${1}"},
-    {aten::__xor__, "${0} ^ ${1}"},
-    {aten::div, "${cast_0} / ${cast_1}"},
-    {aten::eq, "${0} == ${1}"},
-    {aten::fmod, "fmodf(${cast_0}, ${cast_1})"},
-    {aten::ge, "(${0} >= ${1})"},
-    {aten::gt, "${0} > ${1}"},
-    {aten::le, "(${0} <= ${1})"},
-    {aten::lt, "${0} < ${1}"},
-    {aten::type_as, "(${cast_0})"},
-    {aten::mul, "${cast_0} * ${cast_1}"},
-    {aten::ne, "${0} != ${1}"},
-    {aten::remainder, "remainderf(${0}, ${1})"},
-    {aten::pow, "powf(${cast_0}, ${cast_1})"},
-
-    //alpha
-    {aten::add, "${cast_0} + ${cast_2}*${cast_1}"},
-    {aten::sub, "(${cast_0} - ${cast_2}*${cast_1})"},
-    {aten::rand_like, "uniform(rnd())"},
-
-    // min, max
-    // It may seem unusual to have the bounds as the first case below,
-    // this is so that if min or max is NaN, they are "ignored"
-    // and when the input is NaN, the output is, too
-    {aten::clamp, "(${0}<${1}?${1}:(${0}>${2}?${2}:${0}))"},
-
-    //where
-    {aten::where, "(${0} ? ${1} : ${2})"},
-
-    // simple derivatives
-    {aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"},
-    {aten::_tanh_backward,    "${0} * (1.f - ${1} * ${1})"},
+      // unary
+      {aten::_cast_Float, "static_cast<float>(${0})"},
+      {aten::abs, "fabs(${0})"},
+      {aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
+      {aten::relu, "${0} < 0 ? 0.f : ${0} "},
+      {aten::log, "logf(${0})"},
+      {aten::log10, "log10f(${0})"},
+      {aten::log1p, "log1pf(${0})"},
+      {aten::log2, "log2f(${0})"},
+      {aten::lgamma, "lgammaf(${0})"},
+      {aten::exp, "expf(${0})"},
+      {aten::expm1, "expm1f(${0})"},
+      {aten::erf, "erff(${0})"},
+      {aten::erfc, "erfcf(${0})"},
+      {aten::cos, "cosf(${0})"},
+      {aten::acos, "acosf(${0})"},
+      {aten::cosh, "coshf(${0})"},
+      {aten::sin, "sinf(${0})"},
+      {aten::asin, "asinf(${0})"},
+      {aten::sinh, "sinhf(${0})"},
+      {aten::tan, "tanf(${0})"},
+      {aten::atan, "atanf(${0})"},
+      {aten::tanh, "tanhf(${0})"},
+      {aten::sqrt, "sqrtf(${0})"},
+      {aten::rsqrt, "rsqrtf(${0})"},
+      {aten::ceil, "ceilf(${0})"},
+      {aten::floor, "floorf(${0})"},
+      {aten::round, "roundf(${0})"},
+      {aten::trunc, "truncf(${0})"},
+      {aten::frac, "fracf(${0})"},
+      {aten::reciprocal, "1.f/(${0})"},
+      {aten::neg, "-${0}"},
+      // simple binary
+      {aten::atan2, "atan2(${0}, ${1})"},
+      {aten::min, "fminf(${0}, ${1})"},
+      {aten::max, "fmaxf(${0}, ${1})"},
+
+      // binary with other
+      // TODO: some of these ops will not get generated because
+      // we only work on float inputs/outputs, but they are here to record
+      // that they are valid mappable ops once we handle more type
+
+      {aten::__and__, "${0} && ${1}"},
+      {aten::__lshift__, "${0} << ${1}"},
+      {aten::__or__, "${0} || ${1}"},
+      {aten::__rshift__, "${0} >> ${1}"},
+      {aten::__xor__, "${0} ^ ${1}"},
+      {aten::div, "${cast_0} / ${cast_1}"},
+      {aten::eq, "${0} == ${1}"},
+      {aten::fmod, "fmodf(${cast_0}, ${cast_1})"},
+      {aten::ge, "(${0} >= ${1})"},
+      {aten::gt, "${0} > ${1}"},
+      {aten::le, "(${0} <= ${1})"},
+      {aten::lt, "${0} < ${1}"},
+      {aten::type_as, "(${cast_0})"},
+      {aten::mul, "${cast_0} * ${cast_1}"},
+      {aten::ne, "${0} != ${1}"},
+      {aten::remainder, "remainderf(${0}, ${1})"},
+      {aten::pow, "powf(${cast_0}, ${cast_1})"},
+
+      // alpha
+      {aten::add, "${cast_0} + ${cast_2}*${cast_1}"},
+      {aten::sub, "(${cast_0} - ${cast_2}*${cast_1})"},
+      {aten::rand_like, "uniform(rnd())"},
+
+      // min, max
+      // It may seem unusual to have the bounds as the first case below,
+      // this is so that if min or max is NaN, they are "ignored"
+      // and when the input is NaN, the output is, too
+      {aten::clamp, "(${0}<${1}?${1}:(${0}>${2}?${2}:${0}))"},
+
+      // where
+      {aten::where, "(${0} ? ${1} : ${2})"},
+
+      // simple derivatives
+      {aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"},
+      {aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"},
   };
 
   if (n->kind() == prim::Constant) {
@@ -222,16 +228,19 @@ static std::string encodeRHS(const Node* n) {
 
   TemplateEnv env;
   size_t i = 0;
-  auto outtype = n->output()->type()->expect<c10::TensorType const>()->scalarType();
-  for(auto in : n->inputs()) {
-    // PyTorch converts (scalar) argument types to result before applying the operator
-    // e.g. 1.4-torch.tensor(3) = -2
+  auto outtype =
+      n->output()->type()->expect<c10::TensorType const>()->scalarType();
+  for (auto in : n->inputs()) {
+    // PyTorch converts (scalar) argument types to result before applying the
+    // operator e.g. 1.4-torch.tensor(3) = -2
     env.s(std::to_string(i), valueName(in));
-    env.s(std::string("cast_")+std::to_string(i), typeCastedValueName(in->type(), outtype, valueName(in)));
+    env.s(
+        std::string("cast_") + std::to_string(i),
+        typeCastedValueName(in->type(), outtype, valueName(in)));
     i++;
   }
 
-  const auto & str = simple_map_ops.at(n->kind());
+  const auto& str = simple_map_ops.at(n->kind());
   return format(str, env);
 }
 
@@ -240,7 +249,7 @@ static std::string encodeRHS(const Node* n) {
 static Node* usedInFusedChunk(const Value* input) {
   auto uses = input->uses();
   if (uses.size() == 1) {
-    Node *user = uses[0].user;
+    Nodeuser = uses[0].user;
     if (user->kind() == prim::ConstantChunk) {
       return user;
     }
@@ -249,41 +258,46 @@ static Node* usedInFusedChunk(const Value* input) {
 }
 
 static void emitIndexingFor(
-  std::ostream& out
-, const std::string& tensor
-, const int ndim
-, const bool last_is_cont) {
+    std::ostream& out,
+    const std::string& tensor,
+    const int ndim,
+    const bool last_is_cont) {
   TemplateEnv env;
-  env.s("tensor",tensor);
-  out << format("IndexType ${tensor}_offset = 0;\n",env);
-  out << format("IndexType ${tensor}_linearIndex = linearIndex;\n",env);
+  env.s("tensor", tensor);
+  out << format("IndexType ${tensor}_offset = 0;\n", env);
+  out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env);
   for (int d = ndim - 1; d >= 0; --d) {
-    env.d("d",d);
-    env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]",env) : "");
-    env.s("times_stride",(d < ndim - 1 || !last_is_cont) ?
-      format("* ${tensor}.strides[${d}]",env) : "");
+    env.d("d", d);
+    env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]", env) : "");
+    env.s(
+        "times_stride",
+        (d < ndim - 1 || !last_is_cont)
+            ? format("* ${tensor}.strides[${d}]", env)
+            : "");
     out << dim_calc.format(env);
     if (d > 0) {
-      out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n",env);
+      out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env);
     }
   }
 }
 
 // TODO: handle cases where we need to generate > 2^32 element tensors
 std::tuple<
-  std::string
-, std::vector<PartitionDesc>
-, std::vector<PartitionDesc>
-, bool>
+    std::string,
+    std::vector<PartitionDesc>,
+    std::vector<PartitionDesc>,
+    bool>
 generateKernel(
-  const std::string& name
-, const Graph& graph
-, const std::vector<TensorDesc>& input_desc
-, const std::vector<TensorDesc>& output_desc
-, const bool use_cuda) {
+    const std::string& name,
+    const Graph& graph,
+    const std::vector<TensorDesc>& input_desc,
+    const std::vector<TensorDesc>& output_desc,
+    const bool use_cuda) {
   TemplateEnv env;
   env.s("kernelName", name);
-  env.s("IndexType","unsigned int"); // Note: not uint32_t to avoid including cstdint
+  env.s(
+      "IndexType",
+      "unsigned int"); // Note: not uint32_t to avoid including cstdint
 
   std::stringstream body;
   std::stringstream tensorOffsets;
@@ -292,15 +306,24 @@ generateKernel(
 
   // Lambda for writing arguments
   auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
-    std::string tensor = "t" + std::to_string(formals.size()); //can't be unique() because Param may be an output
+    std::string tensor =
+        "t" +
+        std::to_string(
+            formals.size()); // can't be unique() because Param may be an output
     const auto nDim = desc.nDim();
-    emitIndexingFor(tensorOffsets, tensor, nDim,  desc.lastIsContiguous());
+    emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
     env.s("tensor", tensor);
-    env.d("formal_index", formals.size() + 1); // + 1 because the first argument is the linearIndex
+    env.d(
+        "formal_index",
+        formals.size() +
+            1); // + 1 because the first argument is the linearIndex
     env.d("nDim", nDim);
     env.s("scalar_type", scalarTypeName(desc.scalar_type));
-    formals.push_back(format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
-    argument_loads.push_back(format("*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])", env));
+    formals.push_back(
+        format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
+    argument_loads.push_back(format(
+        "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
+        env));
   };
 
   // Writes input parameters and creates flattened inputs
@@ -308,7 +331,7 @@ generateKernel(
   std::vector<std::pair<const Value*, const TensorDesc&>> flat_inputs;
   {
     size_t input_index = 0;
-    for(const auto& p : graph.inputs()) {
+    for (const auto& p : graph.inputs()) {
       if (const Node* chunk = usedInFusedChunk(p)) {
         int64_t dim = chunk->i(attr::dim);
         int64_t chunks = chunk->i(attr::chunks);
@@ -340,7 +363,7 @@ generateKernel(
       } else {
         const auto cat = o->node();
         concat_desc.emplace_back(desc, cat->inputs().size(), cat->i(attr::dim));
-        for(const auto& c : cat->inputs()) {
+        for (const auto& c : cat->inputs()) {
           emitFormal(c, *concat_desc.back().subTensorDesc());
           flat_output_nodes.emplace_back(c, desc);
         }
@@ -364,8 +387,8 @@ generateKernel(
     if (is_half) {
       JIT_ASSERT(use_cuda);
       env.s(
-        "access"
-      , format("__half2float(t${formal}.data[t${formal}_offset])", env));
+          "access",
+          format("__half2float(t${formal}.data[t${formal}_offset])", env));
       has_half_tensor = true;
     } else {
       env.s("access", format("t${formal}.data[t${formal}_offset]", env));
@@ -380,9 +403,12 @@ generateKernel(
   // Note: Concat and Chunk are implicitly generated
   // Note: Random number generation is only supported for CUDA kernels.
   for (const auto& n : graph.nodes()) {
-    // Note: FusedConcat nodes work by narrowing the output Tensors before the kernel runs
-    if (n->kind() == prim::FusedConcat) continue;
-    if (n->kind() == prim::ConstantChunk) continue;
+    // Note: FusedConcat nodes work by narrowing the output Tensors before the
+    // kernel runs
+    if (n->kind() == prim::FusedConcat)
+      continue;
+    if (n->kind() == prim::ConstantChunk)
+      continue;
     if (n->kind() == aten::rand_like) {
       JIT_ASSERT(use_cuda);
       has_random = true;
@@ -390,14 +416,14 @@ generateKernel(
     env.s("node", valueName(n->output()));
     env.s("rhs", encodeRHS(n));
     env.s("lhs_type", variableType(n->output()->type()));
-    body << format("${lhs_type} ${node} = ${rhs};\n",env);
+    body << format("${lhs_type} ${node} = ${rhs};\n", env);
   }
 
   // Generates writes to output tensors
   for (const auto& output : flat_output_nodes) {
     const auto& o = output.first;
     env.d("formal", formal_count++);
-    env.s("access", format("t${formal}.data[t${formal}_offset]",env));
+    env.s("access", format("t${formal}.data[t${formal}_offset]", env));
     env.s("node", valueName(o));
 
     // Acquires and converts (if needed) outputs
@@ -405,32 +431,32 @@ generateKernel(
     const auto is_half = (output.second.scalar_type == at::ScalarType::Half);
     if (is_half) {
       JIT_ASSERT(use_cuda);
-      body << format("${access} = __float2half(${node});\n",env);
+      body << format("${access} = __float2half(${node});\n", env);
       has_half_tensor = true;
     } else {
-      body << format("${access} = ${node};\n",env);
+      body << format("${access} = ${node};\n", env);
     }
   }
 
-  // Includes headers
-  // Note: CUDA kernels support halfs and random generation, CPU kernels do not
-  #if USE_CUDA_FUSER
-    if (has_half_tensor) {
-      env.s("HalfHeader", cuda::half_support_literal);
-    } else {
-      env.s("HalfHeader", "");
-    }
+// Includes headers
+// Note: CUDA kernels support halfs and random generation, CPU kernels do not
+#if USE_CUDA_FUSER
+  if (has_half_tensor) {
+    env.s("HalfHeader", cuda::half_support_literal);
+  } else {
+    env.s("HalfHeader", "");
+  }
 
-    if (has_random) {
-      env.s("RandHeader", cuda::rand_support_literal);
-      env.s("RandParam", cuda::rand_param);
-      env.s("RandInit", cuda::rand_init);
-    } else {
-      env.s("RandHeader", "");
-      env.s("RandParam", "");
-      env.s("RandInit", "");
-    }
-  #endif // USE_CUDA_FUSER
+  if (has_random) {
+    env.s("RandHeader", cuda::rand_support_literal);
+    env.s("RandParam", cuda::rand_param);
+    env.s("RandInit", cuda::rand_init);
+  } else {
+    env.s("RandHeader", "");
+    env.s("RandParam", "");
+    env.s("RandInit", "");
+  }
+#endif // USE_CUDA_FUSER
 
   // Insantiates the CUDA or CPU-specific templates
   env.s("tensorOffsets", tensorOffsets.str());
@@ -439,25 +465,26 @@ generateKernel(
   env.v("argument_loads", argument_loads);
   std::string code_string;
   if (use_cuda) {
-    #if USE_CUDA_FUSER
-      env.s("type_declarations", cuda::type_declarations_template.format(env));
-      code_string = cuda::cuda_compilation_unit_template.format(env);
-    #else
-      throw std::runtime_error("CUDA Fusion requested but not supported.");
-    #endif // USE_CUDA_FUSER
+#if USE_CUDA_FUSER
+    env.s("type_declarations", cuda::type_declarations_template.format(env));
+    code_string = cuda::cuda_compilation_unit_template.format(env);
+#else
+    throw std::runtime_error("CUDA Fusion requested but not supported.");
+#endif // USE_CUDA_FUSER
   } else {
-    #if USE_CPU_FUSER
-      env.s("type_declarations", cpu::type_declarations_template.format(env));
-      code_string = cpu::cpu_compilation_unit_template.format(env);
-    #else
-      throw std::runtime_error("CPU Fusion requested but not supported");
-    #endif // USE_CPU_FUSER
+#if USE_CPU_FUSER
+    env.s("type_declarations", cpu::type_declarations_template.format(env));
+    code_string = cpu::cpu_compilation_unit_template.format(env);
+#else
+    throw std::runtime_error("CPU Fusion requested but not supported");
+#endif // USE_CPU_FUSER
   }
 
   if (debugFuser()) {
     std::cerr << "fusion code:" << code_string << std::endl;
   }
-  return std::make_tuple(code_string, std::move(chunk_desc), std::move(concat_desc), has_random);
+  return std::make_tuple(
+      code_string, std::move(chunk_desc), std::move(concat_desc), has_random);
 }
 
 } // namespace fuser
index a86a336..26ce490 100644 (file)
@@ -3,35 +3,37 @@
 #if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/fuser/arg_spec.h>
 #include <torch/csrc/jit/fuser/partition_desc.h>
 #include <torch/csrc/jit/fuser/tensor_desc.h>
+#include <torch/csrc/jit/ir.h>
 
-#include <tuple>
-#include <vector>
 #include <iostream>
 #include <string>
+#include <tuple>
+#include <vector>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // Creates a CPU or CUDA kernel for the given graph.
 // Returns a tuple consisting of the generated code (as a string),
-// two vectors of PartitionDescs, the chunk and concat descriptions, 
-// respectively, and a bool indicating whether the generated code 
+// two vectors of PartitionDescs, the chunk and concat descriptions,
+// respectively, and a bool indicating whether the generated code
 // generates random numbers.
 // TODO: the partition descriptions should be generated by the executor.
 TORCH_API std::tuple<
-  std::string
-, std::vector<PartitionDesc>
-, std::vector<PartitionDesc>
-, bool> 
+    std::string,
+    std::vector<PartitionDesc>,
+    std::vector<PartitionDesc>,
+    bool>
 generateKernel(
-  const std::string& name
-, const Graph& graph
-, const std::vector<TensorDesc>& input_desc
-, const std::vector<TensorDesc>& output_desc
-, const bool use_cuda);
+    const std::string& name,
+    const Graph& graph,
+    const std::vector<TensorDesc>& input_desc,
+    const std::vector<TensorDesc>& output_desc,
+    const bool use_cuda);
 
 } // namespace fuser
 } // namespace jit
index 74d6dc3..d94ba85 100644 (file)
@@ -1,44 +1,48 @@
 #include <torch/csrc/jit/fuser/compiler.h>
 
 #include <ATen/ATen.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/type.h>
-#include <torch/csrc/jit/code_template.h>
 #include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/passes/canonicalize.h>
-#include <torch/csrc/jit/passes/shape_analysis.h>
+#include <torch/csrc/jit/code_template.h>
+#include <torch/csrc/jit/fuser/codegen.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/fuser/kernel_cache.h>
-#include <torch/csrc/jit/fuser/codegen.h>
 #include <torch/csrc/jit/fuser/tensor_desc.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/passes/canonicalize.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
+#include <torch/csrc/jit/type.h>
 #include "torch/csrc/jit/fuser/interface.h"
 
 #if USE_CUDA_FUSER
-  #include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
+#include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
 #endif // USE_CUDA_FUSER
 
 #if USE_CPU_FUSER
-  #include <torch/csrc/jit/fuser/cpu/fused_kernel.h>
+#include <torch/csrc/jit/fuser/cpu/fused_kernel.h>
 #endif // USE_CUDA_FUSER
 
+#include <atomic>
 #include <iostream>
 #include <memory>
-#include <unordered_set>
-#include <utility>
-#include <string>
-#include <atomic>
 #include <sstream>
 #include <stdexcept>
+#include <string>
 #include <tuple>
+#include <unordered_set>
+#include <utility>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // Counter for number of kernels compiled, used for debugging and
 // creating arbitrary kernel names.
 static std::atomic<size_t> next_kernel_id{0};
 static int debug_fusion{-1};
 
-size_t nCompiledKernels() { return next_kernel_id.load(); }
+size_t nCompiledKernels() {
+  return next_kernel_id.load();
+}
 
 int debugFuser() {
   if (debug_fusion < 0) {
@@ -53,7 +57,7 @@ int debugFuser() {
 static const Node* usedInFusedChunk(const Value* input) {
   const auto& uses = input->uses();
   if (uses.size() == 1) {
-    const Node *user = uses[0].user;
+    const Nodeuser = uses[0].user;
     if (user->kind() == prim::ConstantChunk) {
       return user;
     }
@@ -65,7 +69,8 @@ static void setInputChunkDescriptors(KernelSpec& spec) {
   spec.inputChunks().reserve((spec.graph())->inputs().size());
   for (const Value* input : (spec.graph())->inputs()) {
     if (const Node* chunk = usedInFusedChunk(input)) {
-      spec.inputChunks().emplace_back(chunk->i(attr::chunks), chunk->i(attr::dim));
+      spec.inputChunks().emplace_back(
+          chunk->i(attr::chunks), chunk->i(attr::dim));
     } else {
       spec.inputChunks().emplace_back(1, 0);
     }
@@ -78,14 +83,15 @@ static std::vector<int64_t> getInputDependencies(const Value* output) {
   std::unordered_set<const Value*> inputs;
   std::unordered_set<const Value*> seen;
   while (!queue.empty()) {
-    const Value* val = queue.back(); queue.pop_back();
+    const Value* val = queue.back();
+    queue.pop_back();
     const Node* producer = val->node();
     if (producer->kind() == prim::Param) {
       inputs.insert(val);
       continue;
     }
     for (const Value* input : producer->inputs()) {
-      if (/*bool inserted = */seen.insert(input).second) {
+      if (/*bool inserted = */ seen.insert(input).second) {
         queue.push_back(input);
       }
     }
@@ -103,7 +109,8 @@ static std::vector<int64_t> getInputDependencies(const Value* output) {
 }
 
 static void setInputBroadcastGroups(KernelSpec& spec) {
-  std::unordered_set<std::vector<int64_t>, torch::hash<std::vector<int64_t>>> broadcast_groups;
+  std::unordered_set<std::vector<int64_t>, torch::hash<std::vector<int64_t>>>
+      broadcast_groups;
   for (const Value* output : (spec.graph())->outputs()) {
     if (output->node()->kind() == prim::FusedConcat) {
       for (const Value* concat_input : output->node()->inputs()) {
@@ -114,9 +121,9 @@ static void setInputBroadcastGroups(KernelSpec& spec) {
     }
   }
   std::copy(
-    broadcast_groups.begin()
-  , broadcast_groups.end()
-  , std::back_inserter(spec.inputBroadcastGroups()));
+      broadcast_groups.begin(),
+      broadcast_groups.end(),
+      std::back_inserter(spec.inputBroadcastGroups()));
 }
 
 // Performs "upfront" compilation where storage is known but shapes are not.
@@ -156,10 +163,10 @@ int64_t registerFusion(const Node* fusion_group) {
 }
 
 std::shared_ptr<FusedKernel> compileKernel(
-  const KernelSpec& spec
-, const ArgSpec& arg_spec
-, const std::vector<int64_t>& map_size
-, const at::Device device) {
+    const KernelSpec& spec,
+    const ArgSpec& arg_spec,
+    const std::vector<int64_t>& map_size,
+    const at::Device device) {
   const std::vector<TensorDesc>& input_desc = arg_spec.descs();
 
   auto graph = spec.graph()->copy();
@@ -167,7 +174,10 @@ std::shared_ptr<FusedKernel> compileKernel(
   c10::optional<at::ScalarType> scalar_type;
   for (size_t i = 0; i < input_desc.size(); i++) {
     const auto& desc = input_desc[i];
-    graph->inputs()[i]->setType(TensorType::create(desc.scalar_type, device, desc.nDim())); // TODO: nDim is bad, as it is collapsed
+    graph->inputs()[i]->setType(TensorType::create(
+        desc.scalar_type,
+        device,
+        desc.nDim())); // TODO: nDim is bad, as it is collapsed
   }
 
   PropagateInputShapes(graph);
@@ -179,7 +189,8 @@ std::shared_ptr<FusedKernel> compileKernel(
     if (output->node()->kind() == prim::FusedConcat) {
       sizes.at(output->node()->i(attr::dim)) *= output->node()->inputs().size();
     }
-    auto scalar_type = output->type()->expect<c10::TensorType const>()->scalarType();
+    auto scalar_type =
+        output->type()->expect<c10::TensorType const>()->scalarType();
     auto type = CompleteTensorType::create(scalar_type, device, sizes);
     output_desc.emplace_back(std::move(type));
   }
@@ -190,42 +201,37 @@ std::shared_ptr<FusedKernel> compileKernel(
   std::vector<PartitionDesc> chunk_desc;
   std::vector<PartitionDesc> concat_desc;
   bool has_random;
-  std::tie(code, chunk_desc, concat_desc, has_random)
-    = generateKernel(
-        name
-      , *graph
-      , input_desc
-      , output_desc
-      , use_cuda);
+  std::tie(code, chunk_desc, concat_desc, has_random) =
+      generateKernel(name, *graph, input_desc, output_desc, use_cuda);
 
   std::shared_ptr<FusedKernel> fused_kernel;
   if (use_cuda) {
-    #if USE_CUDA_FUSER
-      fused_kernel = std::make_shared<cuda::FusedKernelCUDA>(
-        device.index()
-      , name
-      , code
-      , input_desc
-      , output_desc
-      , chunk_desc
-      , concat_desc
-      , has_random);
-    #else
-      throw std::runtime_error("CUDA Fusion is not supported on this build.");
-    #endif // USE_CUDA_FUSER
+#if USE_CUDA_FUSER
+    fused_kernel = std::make_shared<cuda::FusedKernelCUDA>(
+        device.index(),
+        name,
+        code,
+        input_desc,
+        output_desc,
+        chunk_desc,
+        concat_desc,
+        has_random);
+#else
+    throw std::runtime_error("CUDA Fusion is not supported on this build.");
+#endif // USE_CUDA_FUSER
   } else {
-    #if USE_CPU_FUSER
-      fused_kernel = std::make_shared<cpu::FusedKernelCPU>(
-        name
-      , code
-      , input_desc
-      , output_desc
-      , chunk_desc
-      , concat_desc
-      , has_random);
-    #else
-      throw std::runtime_error("CPU Fusion is not supported on this build.");
-    #endif // USE_CPU_FUSER
+#if USE_CPU_FUSER
+    fused_kernel = std::make_shared<cpu::FusedKernelCPU>(
+        name,
+        code,
+        input_desc,
+        output_desc,
+        chunk_desc,
+        concat_desc,
+        has_random);
+#else
+    throw std::runtime_error("CPU Fusion is not supported on this build.");
+#endif // USE_CPU_FUSER
   }
 
   return fused_kernel;
index 2a7f6f0..38e1ef1 100644 (file)
@@ -3,18 +3,20 @@
 #if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/stack.h>
+#include <torch/csrc/jit/fuser/arg_spec.h>
 #include <torch/csrc/jit/fuser/config.h>
+#include <torch/csrc/jit/fuser/fused_kernel.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/fuser/kernel_spec.h>
-#include <torch/csrc/jit/fuser/arg_spec.h>
-#include <torch/csrc/jit/fuser/fused_kernel.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/stack.h>
 
 #include <cstdint>
 #include <vector>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // Performs device-independent "upfront" compilation of the given fusion_group,
 // if it has not been registered already.
@@ -25,10 +27,10 @@ TORCH_API int64_t registerFusion(const Node* fusion_group);
 //  with the runtime arguments specified in ArgSpec.
 //  Outputs are allocated using map_size on the specified device.
 TORCH_API std::shared_ptr<FusedKernel> compileKernel(
-  const KernelSpec& spec
-, const ArgSpec& arg_spec
-, const std::vector<int64_t>& map_size
-, const at::Device device);
+    const KernelSpec& spec,
+    const ArgSpec& arg_spec,
+    const std::vector<int64_t>& map_size,
+    const at::Device device);
 
 TORCH_API size_t nCompiledKernels();
 
index 0809591..02306ed 100644 (file)
@@ -1,4 +1,6 @@
 #pragma once
 
+// clang-format off
 #define USE_CPU_FUSER @USE_CPU_FUSER@
 #define USE_CUDA_FUSER @USE_CUDA_FUSER@
+// clang-format on
index 55adc6b..25f8e39 100644 (file)
@@ -3,10 +3,14 @@
 #if USE_CPU_FUSER
 
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/utils/disallow_copy.h>
 
 #include <dlfcn.h>
 
-namespace torch { namespace jit { namespace fuser { namespace cpu {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
 
 static void* checkDL(void* x) {
   if (!x) {
@@ -30,11 +34,12 @@ struct DynamicLibrary {
   }
 
   ~DynamicLibrary() {
-    if (!handle) return;
+    if (!handle)
+      return;
     dlclose(handle);
   }
 
-private:
+ private:
   void* handle = nullptr;
 };
 
index ad11b14..dbe954a 100644 (file)
@@ -3,17 +3,20 @@
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/code_template.h>
 #include <torch/csrc/jit/fuser/compiler.h>
-#include <torch/csrc/jit/fuser/cpu/temp_file.h>
 #include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
+#include <torch/csrc/jit/fuser/cpu/temp_file.h>
 #include <torch/csrc/utils/memory.h>
 
-#include <sstream>
 #include <cstdlib>
 #include <iostream>
-#include <string>
+#include <sstream>
 #include <stdexcept>
+#include <string>
 
-namespace torch { namespace jit { namespace fuser { namespace cpu {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
 
 static const std::string so_template = "/tmp/pytorch_fuserXXXXXX.so";
 static const std::string cpp_template = "/tmp/pytorch_fuserXXXXXX.cpp";
@@ -63,15 +66,15 @@ static CompilerConfig& getConfig() {
 // optimization can be re-enabled by tracking down the platforms where
 // this error occurs and only selectively disabling it.
 static const std::string compile_string =
-  "\"${cxx}\" -O3 -g "
+    "\"${cxx}\" -O3 -g "
 #ifndef __PPC64__
 //  "-march=native "
 #endif
-  "-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm";
+    "-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm";
 
 static void runCompiler(
-  const std::string& cpp_file
-, const std::string& so_file) {
+    const std::string& cpp_file,
+    const std::string& so_file) {
   auto& config = getConfig();
   TemplateEnv env;
   env.s("cxx", config.cxx);
@@ -81,15 +84,15 @@ static void runCompiler(
   std::string result = format(compile_string, env);
   int r = system(result.c_str());
   if (config.openmp && r != 0) {
-    std::cerr << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n";
+    std::cerr
+        << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n";
     config.openmp = false; // disable for future compiles
     return runCompiler(cpp_file, so_file);
   }
   JIT_ASSERTM(r == 0, "Failed to compile a fused CPU kernel");
 }
 
-static const std::string disas_string =
-  "objdump -M  intel -d \"${so_file}\"";
+static const std::string disas_string = "objdump -M  intel -d \"${so_file}\"";
 static void disas(const std::string& so_file) {
   TemplateEnv env;
   env.s("so_file", so_file);
@@ -119,11 +122,13 @@ FusedKernelCPU::FusedKernelCPU(
   cpp_file.write(code_);
   cpp_file.sync();
   runCompiler(cpp_file.name(), so_file.name());
-  if (debugFuser() >= 2) disas(so_file.name());
+  if (debugFuser() >= 2)
+    disas(so_file.name());
   so_lib = make_unique<DynamicLibrary>(so_file.name().c_str());
-  #pragma GCC diagnostic ignored "-Wpedantic"
-    kernel = reinterpret_cast<void(*)(uint32_t, void**)>(so_lib->sym(name_.c_str()));
-  #pragma GCC diagnostic pop
+#pragma GCC diagnostic ignored "-Wpedantic"
+  kernel =
+      reinterpret_cast<void (*)(uint32_t, void**)>(so_lib->sym(name_.c_str()));
+#pragma GCC diagnostic pop
 }
 
 } // namespace cpu
index 3f32018..272c837 100644 (file)
@@ -4,15 +4,18 @@
 
 #include <ATen/ATen.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/utils/disallow_copy.h>
 #include <torch/csrc/jit/fuser/cpu/dynamic_library.h>
 #include <torch/csrc/jit/fuser/fused_kernel.h>
+#include <torch/csrc/utils/disallow_copy.h>
 
-#include <string>
 #include <cstdint>
 #include <memory>
+#include <string>
 
-namespace torch { namespace jit { namespace fuser { namespace cpu {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
 
 // Represents a compiled CPU kernel and the metadata necessary to run it
 struct TORCH_API FusedKernelCPU : public ::torch::jit::fuser::FusedKernel {
@@ -34,7 +37,7 @@ struct TORCH_API FusedKernelCPU : public ::torch::jit::fuser::FusedKernel {
     kernel(numel, arguments.data());
   }
 
-private:
+ private:
   std::unique_ptr<DynamicLibrary> so_lib;
   void (*kernel)(uint32_t, void**) = nullptr;
 };
index 63e9051..8d9e13a 100644 (file)
@@ -4,11 +4,15 @@
 
 #include <torch/csrc/jit/code_template.h>
 
-namespace torch { namespace jit { namespace fuser { namespace cpu {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
 
-/*with type_as not checking type of its input, a fusion group can have non-fp32 tensor as input.
-Correct code for this case is generated, however, nvrtc does not know how to handle int*_t integer types,
-so typedefs help it handle those cases*/
+/*with type_as not checking type of its input, a fusion group can have non-fp32
+tensor as input. Correct code for this case is generated, however, nvrtc does
+not know how to handle int*_t integer types, so typedefs help it handle those
+cases*/
 
 static auto type_declarations_template = CodeTemplate(R"(
 
@@ -29,9 +33,9 @@ struct TensorInfo<T, 0> {
 )");
 
 static auto cpu_compilation_unit_template = CodeTemplate(R"(
+#include <math.h>
 #include <cstddef>
 #include <cstdint>
-#include <math.h>
 
 template <typename scalar_t>
 scalar_t rsqrtf(scalar_t x) {
@@ -61,7 +65,7 @@ void ${kernelName}(IndexType totalElements, void ** args) {
 
 } // namespace cpu
 } // namespace fuser
-} // namespace jit 
+} // namespace jit
 } // namespace torch
 
 #endif // USE_CPU_FUSER
index 2ded2ba..b889974 100644 (file)
@@ -4,15 +4,18 @@
 
 #include <ATen/ATen.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/utils/disallow_copy.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/utils/disallow_copy.h>
 
 #include <unistd.h>
 
 #include <string>
 #include <vector>
 
-namespace torch { namespace jit { namespace fuser { namespace cpu {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cpu {
 
 struct TempFile {
   TH_DISALLOW_COPY_AND_ASSIGN(TempFile);
@@ -38,12 +41,12 @@ struct TempFile {
     fflush(file_);
   }
 
-  void write(const std::string & str) {
+  void write(const std::string& str) {
     size_t result = fwrite(str.c_str(), 1, str.size(), file_);
     JIT_ASSERT(str.size() == result);
   }
 
-  FILE* file()  {
+  FILE* file() {
     return file_;
   }
 
@@ -55,14 +58,15 @@ struct TempFile {
       fclose(file_);
     }
   }
-private:
+
+ private:
   FILE* file_ = nullptr;
   std::string name_;
 };
 
 } // namespace cpu
 } // namespace fuser
-} // namespace jit 
+} // namespace jit
 } // namespace torch
 
 #endif // USE_CPU_FUSER
index 42432ba..3f6ba60 100644 (file)
@@ -1,42 +1,48 @@
 #include <torch/csrc/jit/fuser/cuda/fused_kernel.h>
 
 #include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
 #include <THC/THC.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <torch/csrc/cuda/cuda_check.h>
 #include <torch/csrc/jit/resource_guard.h>
 
 // Note: unclear why this forward declaration is necessary
-#include <THC/THCTensorRandom.h>
 #include <THC/THCGenerator.hpp>
+#include <THC/THCTensorRandom.h>
 THCGenerator* THCRandom_getGenerator(THCState* state);
 
-#include <nvrtc.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <nvrtc.h>
 
-#include <stdexcept>
+#include <algorithm>
+#include <cmath>
 #include <sstream>
+#include <stdexcept>
 #include <tuple>
 #include <vector>
-#include <algorithm>
-#include <cmath>
 
-namespace torch { namespace jit { namespace fuser { namespace cuda {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
 
-void checkCUDAVersion(
-  const cudaDeviceProp& prop) {
+void checkCUDAVersion(const cudaDeviceProp& prop) {
   if ((prop.major >= 6 && CUDA_VERSION < 8000) ||
       (prop.major >= 7 && CUDA_VERSION < 9000)) {
     std::stringstream err_string;
-    err_string << "In CUDAFusedKernel, PyTorch compiled with insufficient CUDA version: "
-         << CUDA_VERSION << " for the current GPU device " << prop.name
-         << " with device capability " << prop.major << "." << prop.minor;
+    err_string
+        << "In CUDAFusedKernel, PyTorch compiled with insufficient CUDA version: "
+        << CUDA_VERSION << " for the current GPU device " << prop.name
+        << " with device capability " << prop.major << "." << prop.minor;
     throw std::runtime_error(err_string.str());
   }
 }
 
-static void getMajorMinor(const cudaDeviceProp* const prop, int& major, int& minor) {
+static void getMajorMinor(
+    const cudaDeviceProp* const prop,
+    int& major,
+    int& minor) {
   int nvrtc_major, nvrtc_minor;
   TORCH_NVRTC_CHECK(nvrtcVersion(&nvrtc_major, &nvrtc_minor));
 
@@ -56,12 +62,16 @@ static void getMajorMinor(const cudaDeviceProp* const prop, int& major, int& min
     minor = 0;
   } else if (nvrtc_major <= 9 && prop->major >= 7) { // 9 supports 3-7.2
     major = 7;
-    if (prop->major == 7 && prop->minor <= 2) minor = prop->minor;
-    else minor = 0;
+    if (prop->major == 7 && prop->minor <= 2)
+      minor = prop->minor;
+    else
+      minor = 0;
   } else if (nvrtc_major <= 10 && prop->major >= 7) { // 10 supports 3-7.5
     major = 7;
-    if (prop->major == 7 && prop->minor <= 5) minor = prop->minor;
-    else minor = 0;
+    if (prop->major == 7 && prop->minor <= 5)
+      minor = prop->minor;
+    else
+      minor = 0;
   }
 }
 
@@ -88,9 +98,9 @@ FusedKernelCUDA::FusedKernelCUDA(
   CUcontext pctx = 0;
   TORCH_CU_CHECK(cuCtxGetCurrent(&pctx));
   if (!pctx) {
-     std::unique_lock<std::mutex> cudaFreeMutexLock(
-     *(THCCachingAllocator_getCudaFreeMutex()));
-     cudaFree(0);
+    std::unique_lock<std::mutex> cudaFreeMutexLock(
+        *(THCCachingAllocator_getCudaFreeMutex()));
+    cudaFree(0);
   }
 
   // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
@@ -98,7 +108,8 @@ FusedKernelCUDA::FusedKernelCUDA(
   const auto prior_device = at::cuda::current_device();
   at::cuda::set_device(device_);
 
-  // Acquires device and NVRTC properties (for compile arch and occupancy calculations)
+  // Acquires device and NVRTC properties (for compile arch and occupancy
+  // calculations)
   prop_ = at::cuda::getCurrentDeviceProperties();
   int major, minor;
   getMajorMinor(prop_, major, minor);
@@ -106,15 +117,12 @@ FusedKernelCUDA::FusedKernelCUDA(
   // Creates the NVRTC program
   nvrtcProgram program;
   TORCH_NVRTC_CHECK(nvrtcCreateProgram(
-    &program
-  , code_.c_str()
-  , nullptr
-  , 0
-  , nullptr
-  , nullptr));
-
-  const std::string compute = "--gpu-architecture=compute_" + std::to_string(major) + std::to_string(minor);
-  const std::vector<const char *> args = {"--std=c++11", compute.c_str(), "-default-device"};
+      &program, code_.c_str(), nullptr, 0, nullptr, nullptr));
+
+  const std::string compute = "--gpu-architecture=compute_" +
+      std::to_string(major) + std::to_string(minor);
+  const std::vector<const char*> args = {
+      "--std=c++11", compute.c_str(), "-default-device"};
   const auto result = nvrtcCompileProgram(program, args.size(), args.data());
   if (result == NVRTC_ERROR_COMPILATION) {
     size_t logsize;
@@ -125,9 +133,8 @@ FusedKernelCUDA::FusedKernelCUDA(
     cu << log.data();
     throw std::runtime_error(cu.str());
   }
-  ResourceGuard holdProgram([&] {
-    TORCH_NVRTC_CHECK(nvrtcDestroyProgram(&program));
-  });
+  ResourceGuard holdProgram(
+      [&] { TORCH_NVRTC_CHECK(nvrtcDestroyProgram(&program)); });
   TORCH_NVRTC_CHECK(result);
   size_t ptx_size;
   TORCH_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
@@ -139,7 +146,7 @@ FusedKernelCUDA::FusedKernelCUDA(
 
   // Computes max blocks
   TORCH_CU_CHECK(cuOccupancyMaxActiveBlocksPerMultiprocessor(
-    &maxBlocks_, function_, 128, 0));
+      &maxBlocks_, function_, 128, 0));
   maxBlocks_ *= prop_->multiProcessorCount;
 
   // Resets device (end of hacked at::DeviceGuard)
@@ -151,8 +158,8 @@ static int ceilDiv(const int a, const int b) {
 }
 
 void FusedKernelCUDA::launch_raw(
-  const uint32_t numel
-, std::vector<void*>& arguments) const {
+    const uint32_t numel,
+    std::vector<void*>& arguments) const {
   at::cuda::CUDAGuard{device_};
   // Hacked at::DeviceGuard (see note above)
   const auto prior_device = at::cuda::current_device();
@@ -164,7 +171,8 @@ void FusedKernelCUDA::launch_raw(
   // Note: offset defined here so its lifetime extends to the launch
   uint64_t offset;
   if (has_random_) {
-    const auto rand_offset = 4 * (std::ceil(numel / (4.0 * kBlockSize * nBlocks)) + 1);
+    const auto rand_offset =
+        4 * (std::ceil(numel / (4.0 * kBlockSize * nBlocks)) + 1);
     auto gen = THCRandom_getGenerator(at::globalContext().getTHCState());
     offset = gen->state.philox_seed_offset.fetch_add(rand_offset);
     arguments.push_back(&gen->state.initial_seed);
@@ -174,12 +182,17 @@ void FusedKernelCUDA::launch_raw(
   // Launches kernel on current stream (device was set by executor)
   auto stream = at::cuda::getCurrentCUDAStream();
   TORCH_CU_CHECK(cuLaunchKernel(
-    function_,
-    nBlocks, 1, 1,
-    kBlockSize, 1, 1,
-    0, stream,
-    arguments.data(),
-    nullptr));
+      function_,
+      nBlocks,
+      1,
+      1,
+      kBlockSize,
+      1,
+      1,
+      0,
+      stream,
+      arguments.data(),
+      nullptr));
 
   // Resets device (see at::DeviceGuard notes above)
   at::cuda::set_device(prior_device);
index 31a2909..233c001 100644 (file)
@@ -6,15 +6,18 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/fuser/fused_kernel.h>
 
-#include <nvrtc.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <nvrtc.h>
 
 #include <cstdint>
-#include <vector>
 #include <string>
+#include <vector>
 
-namespace torch { namespace jit { namespace fuser { namespace cuda {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
 
 // A class holding metadata for an actual CUDA function.
 // Note: CUDA functions are per device.
@@ -40,7 +43,7 @@ struct TORCH_API FusedKernelCUDA : public ::torch::jit::fuser::FusedKernel {
     return at::Backend::CUDA;
   }
 
-private:
+ private:
   static constexpr auto kBlockSize = 128;
 
   // Note: per device to store device properties and compute launch heuristics
index ab3b8c1..ce56b81 100644 (file)
@@ -5,11 +5,15 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/code_template.h>
 
-namespace torch { namespace jit { namespace fuser { namespace cuda {
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
 
-/*with type_as not checking type of its input, a fusion group can have non-fp32 tensor as input.
-Correct code for this case is generated, however, nvrtc does not know how to handle int*_t integer types,
-so typedefs help it handle those cases*/
+/*with type_as not checking type of its input, a fusion group can have non-fp32
+tensor as input. Correct code for this case is generated, however, nvrtc does
+not know how to handle int*_t integer types, so typedefs help it handle those
+cases*/
 
 static auto type_declarations_template = CodeTemplate(R"(
 typedef unsigned char uint8_t;
@@ -132,7 +136,8 @@ constexpr auto rand_support_literal = R"(
   }
 )";
 
-constexpr auto rand_param = ",unsigned long long seed, unsigned long long offset";
+constexpr auto rand_param =
+    ",unsigned long long seed, unsigned long long offset";
 
 constexpr auto rand_init = R"(
   int idx = blockIdx.x*blockDim.x + threadIdx.x;
@@ -156,13 +161,12 @@ void ${kernelName}(IndexType totalElements, ${formals} ${RandParam}) {
 }
 )");
 
-
 // This snippet enables half support in the jit. Following the pattern for
 // reductions, fp16 input data is immediately upconverted to float
 // with __half2float(). All mathematical operations are done on float
 // values, and if needed the intermediate float representation is
 // converted to half with __float2half() when writing to a half tensor.
-constexpr auto half_support_literal  = R"(
+constexpr auto half_support_literal = R"(
 #define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
 #define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
 #if defined(__cplusplus)
@@ -197,7 +201,7 @@ typedef __half half;
 
 } // namespace cuda
 } // namespace fuser
-} // namespace jit 
+} // namespace jit
 } // namespace torch
 
 #endif // USE_CUDA_FUSER
index 48de0e1..b8f43bc 100644 (file)
@@ -3,31 +3,32 @@
 #include <ATen/ATen.h>
 #include <ATen/ExpandUtils.h>
 #include <c10/util/Optional.h>
-#include <torch/csrc/utils/functional.h>
-#include <torch/csrc/jit/stack.h>
+#include <torch/csrc/jit/fuser/compiler.h>
 #include <torch/csrc/jit/fuser/config.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/fuser/kernel_cache.h>
 #include <torch/csrc/jit/fuser/kernel_spec.h>
-#include <torch/csrc/jit/fuser/compiler.h>
 #include <torch/csrc/jit/fuser/tensor_info.h>
+#include <torch/csrc/jit/stack.h>
+#include <torch/csrc/utils/functional.h>
 
-#include <vector>
-#include <tuple>
-#include <stdexcept>
 #include <algorithm>
-#include <map>
 #include <iostream> // TODO: remove, debugging only
+#include <map>
+#include <stdexcept>
+#include <tuple>
+#include <vector>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // Returns the "map size" for this run, which is the common size for all
 // intermediate tensors.
 static c10::optional<std::vector<int64_t>> getMapSize(
-  const KernelSpec& spec
-, at::TensorList args
-, at::IntList arg_subset) {
-
+    const KernelSpec& spec,
+    at::TensorList args,
+    at::IntList arg_subset) {
   // TODO: this keeps reallocating map_size at every iteration, but we know
   // exactly how much storage do we need, so this could be fixed in-place at
   // every step. We're just missing a few functions for ATen, but the fix
@@ -47,7 +48,8 @@ static c10::optional<std::vector<int64_t>> getMapSize(
     } else {
       auto tensor_sizes = arg.sizes().vec();
       const auto num_chunks = chunk_desc.nSubTensors();
-      const auto dim = at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size());
+      const auto dim =
+          at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size());
       if (tensor_sizes[dim] % num_chunks != 0) {
         return c10::nullopt;
       }
@@ -65,22 +67,27 @@ static c10::optional<std::vector<int64_t>> getMapSize(
 
 // Tries to determine a map size for the instantiated kernel (see above)
 static c10::optional<std::vector<int64_t>> canRunKernel(
-  const KernelSpec& spec
-, at::TensorList args) {
+    const KernelSpec& spec,
+    at::TensorList args) {
   // Short-circuits on size mismatch
   AT_CHECK(
-    args.size() == spec.inputChunks().size()
-  , "Expected ", spec.inputChunks().size(), " arguments, but got ", args.size());
+      args.size() == spec.inputChunks().size(),
+      "Expected ",
+      spec.inputChunks().size(),
+      " arguments, but got ",
+      args.size());
 
   c10::optional<std::vector<int64_t>> map_size;
   for (const auto& broadcast_group : spec.inputBroadcastGroups()) {
     if (!map_size) {
       map_size = getMapSize(spec, args, broadcast_group);
-      if (!map_size) return c10::nullopt;
+      if (!map_size)
+        return c10::nullopt;
     } else {
       const auto group_map_size = getMapSize(spec, args, broadcast_group);
       // Note: this checks that group_map_size is defined AND equal to map_size
-      if (map_size != group_map_size) return c10::nullopt;
+      if (map_size != group_map_size)
+        return c10::nullopt;
     }
   }
 
@@ -92,14 +99,15 @@ static c10::optional<std::vector<int64_t>> canRunKernel(
 // Note: Arguments are mutated by this call, although map_size is restored
 // to its original value.
 static void expandArgs(
-  const KernelSpec& spec
-, std::vector<at::Tensor>& args
-, std::vector<int64_t>& map_size) {
+    const KernelSpec& spec,
+    std::vector<at::Tensor>& args,
+    std::vector<int64_t>& map_size) {
   for (size_t i = 0; i < args.size(); ++i) {
     auto& arg = args[i];
     const auto& pdesc = spec.inputChunks()[i];
     if (pdesc.nSubTensors() == 1) {
-      if (arg.sizes().equals(map_size)) continue;
+      if (arg.sizes().equals(map_size))
+        continue;
       arg = arg.expand(map_size);
     } else {
       map_size.at(pdesc.dim()) *= pdesc.nSubTensors();
@@ -123,8 +131,8 @@ static uint32_t computeNumel(const at::ArrayRef<int64_t>& sizes) {
 
 // Note: Assumes that after at::chunk, all inputs are the same size
 static std::vector<int64_t> computeMapSize(
-  const at::Tensor& tensor
-, const PartitionDesc& chunkDesc) {
+    const at::Tensor& tensor,
+    const PartitionDesc& chunkDesc) {
   std::vector<int64_t> sizes(tensor.sizes().begin(), tensor.sizes().end());
   JIT_ASSERT(sizes[chunkDesc.dim()] % chunkDesc.nSubTensors() == 0);
   sizes[chunkDesc.dim()] /= chunkDesc.nSubTensors();
@@ -134,37 +142,38 @@ static std::vector<int64_t> computeMapSize(
 // Tries to compress sizes and strides according to cont. Emits the result t
 // c_sizes, c_strides and throws an error on failure (if can't compress)
 static void compressContiguous(
-  const at::IntList& sizes
-, const at::IntList& strides
-, const std::vector<bool>& cont
-, uint32_t* c_sizes
-, uint32_t* c_strides) {
+    const at::IntList& sizes,
+    const at::IntList& strides,
+    const std::vector<bool>& cont,
+    uint32_t* c_sizes,
+    uint32_t* c_strides) {
   size_t compressed_dims = 0;
   size_t cur = 0;
   size_t ndim = sizes.size();
   while (cur < ndim) {
     size_t total_size = sizes[cur];
     cur++;
-    while (cont[cur-1] && cur < ndim) {
-      JIT_ASSERT(strides[cur-1] == sizes[cur]*strides[cur]);
+    while (cont[cur - 1] && cur < ndim) {
+      JIT_ASSERT(strides[cur - 1] == sizes[cur] * strides[cur]);
       total_size *= sizes[cur];
       cur++;
     }
     c_sizes[compressed_dims] = total_size;
-    c_strides[compressed_dims] = strides[cur-1];
+    c_strides[compressed_dims] = strides[cur - 1];
     compressed_dims++;
   }
 
-  if (ndim > 0) JIT_ASSERT(!cont.back() || strides.back() == 1);
+  if (ndim > 0)
+    JIT_ASSERT(!cont.back() || strides.back() == 1);
 }
 
 // Launches the requested fusion on the given device with the given inputs.
 // Output pointers are stored in outputs (to be put on the stack later).
 void launchFusion(
-  const FusedKernel& fusion
-, const at::Device device
-, const at::ArrayRef<at::Tensor>& inputs
-, std::vector<at::Tensor>& outputs) {
+    const FusedKernel& fusion,
+    const at::Device device,
+    const at::ArrayRef<at::Tensor>& inputs,
+    std::vector<at::Tensor>& outputs) {
   // Fails if fusion and given inputs disagree
   JIT_ASSERT(inputs.size() == fusion.inputDesc().size());
 
@@ -195,10 +204,13 @@ void launchFusion(
     numel = computeNumel(map_size);
   }
 
-  // Computes the storage needed to store TensorInfo structs for inputs and outputs.
+  // Computes the storage needed to store TensorInfo structs for inputs and
+  // outputs.
   size_t uncompressedDim = fusion.inputDesc().at(0).contiguity.size();
-  size_t maxPossibleTensorInfoSize = sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim;
-  size_t maxPossibleBufferSize = maxPossibleTensorInfoSize * (flat_inputs_size + flat_outputs_size);
+  size_t maxPossibleTensorInfoSize =
+      sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim;
+  size_t maxPossibleBufferSize =
+      maxPossibleTensorInfoSize * (flat_inputs_size + flat_outputs_size);
   std::vector<char> buffer(maxPossibleBufferSize);
   char* buffer_next = buffer.data();
 
@@ -207,21 +219,16 @@ void launchFusion(
   arguments.reserve(3 + flat_inputs_size + flat_outputs_size);
   arguments.push_back(&numel);
 
-  auto addTensorInfoRaw = [&](
-    const TensorDesc& desc
-  , void* data_ptr
-  , at::IntList sizes
-  , at::IntList strides) {
+  auto addTensorInfoRaw = [&](const TensorDesc& desc,
+                              void* data_ptr,
+                              at::IntList sizes,
+                              at::IntList strides) {
     const auto nDim = desc.nDim(); // NOTE: this is the compressed dim
     JIT_ASSERT(nDim <= uncompressedDim); // We'd overflow the space otherwise
     auto ti = reinterpret_cast<TensorInfo*>(buffer_next);
     ti->data = data_ptr;
     compressContiguous(
-      sizes
-    , strides
-    , desc.contiguity
-    , ti->sizes(nDim)
-    , ti->strides(nDim));
+        sizes, strides, desc.contiguity, ti->sizes(nDim), ti->strides(nDim));
     buffer_next += maxPossibleTensorInfoSize;
     arguments.push_back(ti);
   };
@@ -239,10 +246,12 @@ void launchFusion(
     if (chunk.isNoop()) {
       addTensorInfo(fusion.inputDesc()[i], tensor);
     } else {
-      size_t chunk_offset = map_size[chunk.dim()] * tensor.stride(chunk.dim()) * elementSize(tensor.type().scalarType());
+      size_t chunk_offset = map_size[chunk.dim()] * tensor.stride(chunk.dim()) *
+          elementSize(tensor.type().scalarType());
       char* data_ptr = reinterpret_cast<char*>(tensor.data_ptr());
       for (size_t chunks = 0; chunks < chunk.nSubTensors(); ++chunks) {
-        addTensorInfoRaw(*chunk.subTensorDesc(), data_ptr, map_size, tensor.strides());
+        addTensorInfoRaw(
+            *chunk.subTensorDesc(), data_ptr, map_size, tensor.strides());
         data_ptr += chunk_offset;
       }
     }
@@ -254,7 +263,8 @@ void launchFusion(
   for (size_t i = 0; i < fusion.outputDesc().size(); ++i) {
     const auto& c = fusion.concatDesc()[i];
     if (c.isNoop()) {
-      outputs.push_back(at::empty(map_size, ref_options.dtype(fusion.outputDesc()[i].scalar_type)));
+      outputs.push_back(at::empty(
+          map_size, ref_options.dtype(fusion.outputDesc()[i].scalar_type)));
       addTensorInfo(fusion.outputDesc()[i], outputs[i]);
     } else {
       size_t small_size = map_size[c.dim()];
@@ -277,12 +287,10 @@ void launchFusion(
   fusion.launch_raw(numel, arguments);
 }
 
-
-bool runFusion(
-  const int64_t key
-, Stack& stack) {
+bool runFusion(const int64_t key, Stack& stack) {
   // Short-circuits if fusion isn't enabled
-  if (!canFuseOnCPU() && !canFuseOnGPU()) return false;
+  if (!canFuseOnCPU() && !canFuseOnGPU())
+    return false;
 
   // Acquires the FusionSpec
   auto maybe_spec = retrieve(key);
@@ -294,8 +302,8 @@ bool runFusion(
     return i.toTensor();
   });
 
-  // Determines device to dispatch to. If there's a device mismatch in the inputs,
-  // we use the fallback (which should give a nice error message).
+  // Determines device to dispatch to. If there's a device mismatch in the
+  // inputs, we use the fallback (which should give a nice error message).
   at::Device device = inputs.at(0).device();
   at::ScalarType dtype = inputs[0].type().scalarType();
   for (const auto& t : at::TensorList(inputs).slice(1)) {
@@ -305,14 +313,17 @@ bool runFusion(
   }
 
   // Attempts to run fallback if device fusion is disabled
-  if (device.is_cuda() && !canFuseOnGPU()) return false;
-  if (device.is_cpu() && !canFuseOnCPU()) return false;
+  if (device.is_cuda() && !canFuseOnGPU())
+    return false;
+  if (device.is_cpu() && !canFuseOnCPU())
+    return false;
 
   // Validates sizes and expands inputs as needed
   auto maybe_map_size = canRunKernel(spec, inputs);
 
   // Tries to run fallback if map size can't be computed
-  if (!maybe_map_size) return false;
+  if (!maybe_map_size)
+    return false;
   expandArgs(spec, inputs, *maybe_map_size);
 
   // Retrieves the kernel, compiling (and caching) if necessary
@@ -332,9 +343,9 @@ bool runFusion(
   // Updates stack
   drop(stack, spec.nInputs());
   stack.insert(
-    stack.end()
-  , std::make_move_iterator(outputs.begin())
-  , std::make_move_iterator(outputs.end()));
+      stack.end(),
+      std::make_move_iterator(outputs.begin()),
+      std::make_move_iterator(outputs.end()));
 
   return true;
 }
index c83a621..9af2cd9 100644 (file)
@@ -7,13 +7,13 @@
 
 #include <cstdint>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // Runs the fusion associated with the key (see registerFusion() in interface.h)
 // on the inputs taken from the given Stack.
-TORCH_API bool runFusion(
-  const int64_t key
-, Stack& stack);
+TORCH_API bool runFusion(const int64_t key, Stack& stack);
 
 } // namespace fuser
 } // namespace jit
index 335f7ea..23868c1 100644 (file)
@@ -1,40 +1,41 @@
 #include <torch/csrc/jit/fuser/fallback.h>
 
-#include <torch/csrc/utils/functional.h> //fmap
+#include <torch/csrc/jit/custom_operator.h>
+#include <torch/csrc/jit/fuser/kernel_cache.h>
 #include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/stack.h>
-#include <torch/csrc/jit/custom_operator.h>
-#include <torch/csrc/jit/fuser/kernel_cache.h>
+#include <torch/csrc/utils/functional.h> //fmap
 
 #include <stdexcept>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
-// Registers fused operators so that fused graphs can properly generate fallback code.
-RegisterOperators reg_fused_operators({
-  Operator(
-    prim::FusedConcat
-  , [](const Node* node) {
-    int64_t dim = node->i(attr::dim);
-    int64_t num_inputs = node->inputs().size();
-    return [dim, num_inputs](Stack& stack) {
-    auto result = at::cat(
-      fmap(last(stack, num_inputs), [](const IValue& i) { return i.toTensor(); })
-    , dim
-    );
-    drop(stack, num_inputs);
-    pack(stack, std::move(result));
-    return 0;
-    };
-  })
-});
+// Registers fused operators so that fused graphs can properly generate fallback
+// code.
+RegisterOperators reg_fused_operators(
+    {Operator(prim::FusedConcat, [](const Node* node) {
+      int64_t dim = node->i(attr::dim);
+      int64_t num_inputs = node->inputs().size();
+      return [dim, num_inputs](Stack& stack) {
+        auto result = at::cat(
+            fmap(
+                last(stack, num_inputs),
+                [](const IValue& i) { return i.toTensor(); }),
+            dim);
+        drop(stack, num_inputs);
+        pack(stack, std::move(result));
+        return 0;
+      };
+    })});
 
 void runFallback(int64_t key, Stack& stack) {
   auto maybe_spec = retrieve(key);
   if (!maybe_spec)
     throw std::runtime_error("Failed to find fusion spec to run fallback.");
-  
+
   InterpreterState{(*maybe_spec)->code()}.run(stack);
 }
 
index b94a325..ab55218 100644 (file)
@@ -6,7 +6,9 @@
 
 #include <cstdlib>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 void runFallback(int64_t key, Stack& stack);
 
index 41a4994..39a590c 100644 (file)
@@ -3,15 +3,17 @@
 #if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <ATen/ATen.h>
-#include <torch/csrc/utils/disallow_copy.h>
-#include <torch/csrc/jit/fuser/tensor_desc.h>
 #include <torch/csrc/jit/fuser/partition_desc.h>
+#include <torch/csrc/jit/fuser/tensor_desc.h>
+#include <torch/csrc/utils/disallow_copy.h>
 
-#include <string>
 #include <cstdint>
+#include <string>
 #include <vector>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 struct FusedKernel {
   TH_DISALLOW_COPY_AND_ASSIGN(FusedKernel);
@@ -34,7 +36,6 @@ struct FusedKernel {
 
   virtual ~FusedKernel() = default;
 
-
   // arguments is a list of pointers to the arguments for the compiled CUDA/CPU
   // code.
   // The format of arguments is suitable for directly passing to a call to
@@ -43,23 +44,36 @@ struct FusedKernel {
   // CUDA code), and the remainder are pointers to the TensorInfo<T> structs
   // that compiled code uses to load Tensor data.
   // launch_with_tensors handles packing at::Tensors into this arguments array.
-  // CPU code uses the same convension so that launch_with_tensors can be shared.
-  virtual void launch_raw(
-    const uint32_t numel
-  , std::vector<void*>& arguments) const = 0;
+  // CPU code uses the same convension so that launch_with_tensors can be
+  // shared.
+  virtual void launch_raw(const uint32_t numel, std::vector<void*>& arguments)
+      const = 0;
   virtual at::Backend backend() const = 0;
 
   // Getters
-  const std::string& name() const { return name_; }
-  const std::string& code() const { return code_; }
-  const std::vector<TensorDesc>& inputDesc() const { return input_desc_; }
-  const std::vector<TensorDesc>& outputDesc() const { return output_desc_; }
-  const std::vector<PartitionDesc>& chunkDesc() const { return chunk_desc_; }
-  const std::vector<PartitionDesc>& concatDesc() const { return concat_desc_; }
-  bool hasRandom() const { return has_random_; }
-
-
-protected:
+  const std::string& name() const {
+    return name_;
+  }
+  const std::string& code() const {
+    return code_;
+  }
+  const std::vector<TensorDesc>& inputDesc() const {
+    return input_desc_;
+  }
+  const std::vector<TensorDesc>& outputDesc() const {
+    return output_desc_;
+  }
+  const std::vector<PartitionDesc>& chunkDesc() const {
+    return chunk_desc_;
+  }
+  const std::vector<PartitionDesc>& concatDesc() const {
+    return concat_desc_;
+  }
+  bool hasRandom() const {
+    return has_random_;
+  }
+
+ protected:
   const std::string name_;
   const std::string code_;
   const std::vector<TensorDesc> input_desc_;
index e8d63aa..4e63c6f 100644 (file)
@@ -2,14 +2,15 @@
 
 #include <torch/csrc/jit/fuser/config.h>
 #if USE_CUDA_FUSER || USE_CPU_FUSER
-  #include <torch/csrc/jit/fuser/compiler.h>
-  #include <torch/csrc/jit/fuser/executor.h>
-  #include <torch/csrc/jit/fuser/fallback.h>
+#include <torch/csrc/jit/fuser/compiler.h>
+#include <torch/csrc/jit/fuser/executor.h>
+#include <torch/csrc/jit/fuser/fallback.h>
 #endif // USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <stdexcept>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace detail {
 
@@ -19,34 +20,35 @@ bool cpu_fuser_enabled = false;
 } // namespace detail
 
 int64_t registerFusion(const Node* fusion_group) {
-  #if USE_CUDA_FUSER || USE_CPU_FUSER
-    return fuser::registerFusion(fusion_group);
-  #else
-    throw std::runtime_error("Fusion not supported for this build.");
-  #endif // USE_CUDA_FUSER || USE_CPU_FUSER
+#if USE_CUDA_FUSER || USE_CPU_FUSER
+  return fuser::registerFusion(fusion_group);
+#else
+  throw std::runtime_error("Fusion not supported for this build.");
+#endif // USE_CUDA_FUSER || USE_CPU_FUSER
 }
 
 void runFusion(const int64_t key, Stack& stack) {
-  #if USE_CUDA_FUSER || USE_CPU_FUSER
-    const auto result = fuser::runFusion(key, stack);
-    if (!result) fuser::runFallback(key, stack);
-  #else 
-    throw std::runtime_error("Fusion not supported for this build.");
-  #endif // USE_CUDA_FUSER || USE_CPU_FUSER
+#if USE_CUDA_FUSER || USE_CPU_FUSER
+  const auto result = fuser::runFusion(key, stack);
+  if (!result)
+    fuser::runFallback(key, stack);
+#else
+  throw std::runtime_error("Fusion not supported for this build.");
+#endif // USE_CUDA_FUSER || USE_CPU_FUSER
 }
 
 bool canFuseOnCPU() {
-  #if USE_CPU_FUSER
-    return detail::cpu_fuser_enabled;
-  #endif // USE_CPU_FUSER
+#if USE_CPU_FUSER
+  return detail::cpu_fuser_enabled;
+#endif // USE_CPU_FUSER
 
   return false;
 }
 
 bool canFuseOnGPU() {
-  #if USE_CUDA_FUSER
-    return true;
-  #endif  // USE_CUDA_FUSER
+#if USE_CUDA_FUSER
+  return true;
+#endif // USE_CUDA_FUSER
 
   return false;
 }
@@ -58,36 +60,37 @@ void overrideCanFuseOnCPU(bool value) {
 // Uses the above interface by stuffing the graph into a node and treating that
 // node as a fusion group.
 std::vector<at::Tensor> debugLaunchGraph(
-  Graph& graph
-, at::ArrayRef<at::Tensor> inputs) {
-  #if USE_CUDA_FUSER || USE_CPU_FUSER
-    // Creates a fusion group node
-    auto wrapper_graph = std::make_shared<Graph>();
-    Node* fusion_group = wrapper_graph->insertNode(wrapper_graph->createFusionGroup());
-    fusion_group->g_(attr::Subgraph, graph.copy());
-    for (size_t i = 0; i < graph.inputs().size(); ++i) {
-      fusion_group->addInput(wrapper_graph->addInput());
-    }
-    for (size_t i = 0; i < graph.outputs().size(); ++i) {
-      wrapper_graph->registerOutput(fusion_group->addOutput());
-    }
-
-    // Creates the stack, registers and runs the fusion
-    Stack stack = fmap<IValue>(inputs);
-    const auto key = fuser::registerFusion(fusion_group);
-    fuser::runFusion(key, stack);
-    return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
-  #else 
-    throw std::runtime_error("Fusion not supported for this build.");
-  #endif // USE_CUDA_FUSER || USE_CPU_FUSER
+    Graph& graph,
+    at::ArrayRef<at::Tensor> inputs) {
+#if USE_CUDA_FUSER || USE_CPU_FUSER
+  // Creates a fusion group node
+  auto wrapper_graph = std::make_shared<Graph>();
+  Node* fusion_group =
+      wrapper_graph->insertNode(wrapper_graph->createFusionGroup());
+  fusion_group->g_(attr::Subgraph, graph.copy());
+  for (size_t i = 0; i < graph.inputs().size(); ++i) {
+    fusion_group->addInput(wrapper_graph->addInput());
+  }
+  for (size_t i = 0; i < graph.outputs().size(); ++i) {
+    wrapper_graph->registerOutput(fusion_group->addOutput());
+  }
+
+  // Creates the stack, registers and runs the fusion
+  Stack stack = fmap<IValue>(inputs);
+  const auto key = fuser::registerFusion(fusion_group);
+  fuser::runFusion(key, stack);
+  return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
+#else
+  throw std::runtime_error("Fusion not supported for this build.");
+#endif // USE_CUDA_FUSER || USE_CPU_FUSER
 }
 
-size_t nCompiledKernels() { 
-  #if USE_CUDA_FUSER || USE_CPU_FUSER
-    return fuser::nCompiledKernels(); 
-  #else
-    return 0;
-  #endif // USE_CUDA_FUSER || USE_CPU_FUSER
+size_t nCompiledKernels() {
+#if USE_CUDA_FUSER || USE_CPU_FUSER
+  return fuser::nCompiledKernels();
+#else
+  return 0;
+#endif // USE_CUDA_FUSER || USE_CPU_FUSER
 }
 
 } // namespace jit
index 89d3b88..8136334 100644 (file)
@@ -5,11 +5,12 @@
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/stack.h>
 
+#include <cstdint>
 #include <memory>
 #include <vector>
-#include <cstdint>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 constexpr int kCPUDevice = -1;
 
@@ -27,15 +28,16 @@ TORCH_API void runFusion(const int64_t key, Stack& stack);
 TORCH_API bool canFuseOnCPU();
 TORCH_API bool canFuseOnGPU();
 
-// Sets whether fusion on the CPU is allowed (disabled by default due to flakiness)
+// Sets whether fusion on the CPU is allowed (disabled by default due to
+// flakiness)
 TORCH_API void overrideCanFuseOnCPU(bool value);
 
 // Treats the given graph as a fusion group and launches it on the
 // specified device with the given inputs.
 // Returns the outputs.
 TORCH_API std::vector<at::Tensor> debugLaunchGraph(
-  Graph& graph
-, at::ArrayRef<at::Tensor> inputs);
+    Graph& graph,
+    at::ArrayRef<at::Tensor> inputs);
 
 TORCH_API size_t nCompiledKernels();
 
index 3c52ad8..4496900 100644 (file)
@@ -2,15 +2,17 @@
 #include <torch/csrc/jit/passes/canonicalize.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
 
-#include <unordered_map>
-#include <mutex>
 #include <cstdint>
+#include <mutex>
+#include <unordered_map>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 struct KernelCacheImpl {
   // Note: std::unordered_map does not invalidate references even if rehashing
-  // occurs. This is a critical property for thread-safety. 
+  // occurs. This is a critical property for thread-safety.
   std::mutex mutex_;
   int64_t kernel_counter{0};
 
@@ -33,7 +35,8 @@ int64_t debugNumCachedKernelSpecs() {
   return cache.specMap_.size();
 }
 
-std::shared_ptr<Graph> normalizeGraphForCache(const std::shared_ptr<Graph>& graph) {
+std::shared_ptr<Graph> normalizeGraphForCache(
+    const std::shared_ptr<Graph>& graph) {
   auto result = Canonicalize(graph, /*keep_unique_names=*/false);
   EraseShapeInformation(result);
   return result;
@@ -49,22 +52,24 @@ int64_t store(std::shared_ptr<Graph> graph) {
   std::lock_guard<std::mutex> guard{cache.mutex_};
   const auto key = cache.kernel_counter++;
   cache.specMap_.emplace(
-    std::piecewise_construct
-  , std::forward_as_tuple(key)
-  , std::forward_as_tuple(key, graph));
+      std::piecewise_construct,
+      std::forward_as_tuple(key),
+      std::forward_as_tuple(key, graph));
   cache.graphToKey_.emplace(std::make_pair(std::move(repr), key));
   return key;
 }
 
 // XXX: Does not grab mutex
 static at::optional<KernelSpec*> nolock_retrieve(
-    KernelCacheImpl& cache, const int64_t key) {
+    KernelCacheImpl& cache,
+    const int64_t key) {
   auto it = cache.specMap_.find(key);
-  if (it == cache.specMap_.end()) return at::nullopt;
+  if (it == cache.specMap_.end())
+    return at::nullopt;
   return &(it->second);
 }
 
-at::optional<KernelSpec*> retrieve(const int64_t key) { 
+at::optional<KernelSpec*> retrieve(const int64_t key) {
   auto& cache = getKernelCache();
   std::lock_guard<std::mutex> guard{cache.mutex_};
   return nolock_retrieve(cache, key);
@@ -77,7 +82,8 @@ at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph) {
 
   std::lock_guard<std::mutex> guard{cache.mutex_};
   auto it = cache.graphToKey_.find(repr);
-  if (it == cache.graphToKey_.end()) return at::nullopt;
+  if (it == cache.graphToKey_.end())
+    return at::nullopt;
   return nolock_retrieve(cache, it->second);
 }
 
index 63b4710..792591c 100644 (file)
@@ -4,18 +4,21 @@
 
 #include <c10/util/Optional.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/fuser/kernel_spec.h>
+#include <torch/csrc/jit/ir.h>
 
-#include <cstdint> 
+#include <cstdint>
 #include <functional>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // A thread-safe cache interface.
 
 // Normalizes the graph by canonicalizing and erasing shape information
-TORCH_API std::shared_ptr<Graph> normalizeGraphForCache(const std::shared_ptr<Graph>& graph);
+TORCH_API std::shared_ptr<Graph> normalizeGraphForCache(
+    const std::shared_ptr<Graph>& graph);
 
 // Stores the given graph, returning the key used to access it
 TORCH_API int64_t store(std::shared_ptr<Graph> graph);
index 5942bac..50a94a6 100644 (file)
@@ -3,22 +3,24 @@
 #if USE_CUDA_FUSER || USE_CPU_FUSER
 
 #include <ATen/ATen.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <c10/util/Optional.h>
-#include <torch/csrc/jit/stack.h>
-#include <torch/csrc/jit/interpreter.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/fuser/interface.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/fuser/arg_spec.h>
 #include <torch/csrc/jit/fuser/fused_kernel.h>
+#include <torch/csrc/jit/fuser/interface.h>
+#include <torch/csrc/jit/interpreter.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/stack.h>
 
-#include <memory>
 #include <cstdint>
-#include <vector>
-#include <unordered_map>
+#include <memory>
 #include <mutex>
+#include <unordered_map>
+#include <vector>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // Helper struct containing partition information: the number of tensors
 // created and the dimension the partitioning is performed on.
@@ -26,32 +28,32 @@ namespace torch { namespace jit { namespace fuser {
 // at runtime the partition info is logically combined with the tensor
 // descriptions to create PartitionDesc objects.
 struct TORCH_API PartitionInfo {
-  PartitionInfo(
-    const int64_t _nSubTensors
-  , const int64_t _dim)
-  : nSubTensors_{_nSubTensors}
-  , dim_{_dim}
-  { };
+  PartitionInfo(const int64_t _nSubTensors, const int64_t _dim)
+      : nSubTensors_{_nSubTensors}, dim_{_dim} {};
 
-  int64_t nSubTensors() const { return nSubTensors_; }
-  int64_t dim() const { return dim_; }
+  int64_t nSubTensors() const {
+    return nSubTensors_;
+  }
+  int64_t dim() const {
+    return dim_;
+  }
 
-private:
+ private:
   int64_t nSubTensors_;
   int64_t dim_;
 };
 
- // "Kernel Specification." - Contains device-independent fusion information.
- // Each kernel specification contains a map of instantiated generated functions
- // that implement some or most of its functionality. Multiple generated
- // functions are needed by each abstract specification because of different
- // devices (cpu vs gpu, different gpus) and different inputs (int vs float,
- // contiguous vs discontiguous).
- // Note: uses a mutex to control access to its kernel store
- // Note: unordered containers do not invalidate references/pointers on
- //   rehashing, which is critical for thread-safety.
- // TODO: allow abstract kernels to use multiple generated kernels
- // TODO: allow abstract kernels to reuse generated kernels from common pool
+// "Kernel Specification." - Contains device-independent fusion information.
+// Each kernel specification contains a map of instantiated generated functions
+// that implement some or most of its functionality. Multiple generated
+// functions are needed by each abstract specification because of different
+// devices (cpu vs gpu, different gpus) and different inputs (int vs float,
+// contiguous vs discontiguous).
+// Note: uses a mutex to control access to its kernel store
+// Note: unordered containers do not invalidate references/pointers on
+//   rehashing, which is critical for thread-safety.
+// TODO: allow abstract kernels to use multiple generated kernels
+// TODO: allow abstract kernels to reuse generated kernels from common pool
 struct TORCH_API KernelSpec {
   KernelSpec(const int64_t _key, const std::shared_ptr<Graph>& _graph)
       : key_{_key},
@@ -63,10 +65,18 @@ struct TORCH_API KernelSpec {
         kernels_{} {}
 
   // Getters
-  int64_t key() const { return key_; }
-  std::shared_ptr<Graph> graph() const { return graph_; }
-  const Code& code() const { return code_; }
-  int64_t nInputs() const { return nInputs_; }
+  int64_t key() const {
+    return key_;
+  }
+  std::shared_ptr<Graph> graph() const {
+    return graph_;
+  }
+  const Code& code() const {
+    return code_;
+  }
+  int64_t nInputs() const {
+    return nInputs_;
+  }
 
   std::vector<std::vector<int64_t>>& inputBroadcastGroups() {
     return inputBroadcastGroups_;
@@ -75,24 +85,29 @@ struct TORCH_API KernelSpec {
     return inputBroadcastGroups_;
   }
 
-  std::vector<PartitionInfo>& inputChunks() { return inputChunks_; }
-  const std::vector<PartitionInfo>& inputChunks() const { return inputChunks_; }
+  std::vector<PartitionInfo>& inputChunks() {
+    return inputChunks_;
+  }
+  const std::vector<PartitionInfo>& inputChunks() const {
+    return inputChunks_;
+  }
 
   // Cache functions
-  c10::optional<std::shared_ptr<FusedKernel>> findKernel(const ArgSpec& arg_spec) const {
+  c10::optional<std::shared_ptr<FusedKernel>> findKernel(
+      const ArgSpec& arg_spec) const {
     std::lock_guard<std::mutex> guard{mutex_};
     const auto it = kernels_.find(arg_spec);
-    if (it == kernels_.end()) return c10::nullopt;
+    if (it == kernels_.end())
+      return c10::nullopt;
     return it->second;
   }
-  void cacheKernel(
-    const ArgSpec& arg_spec
-  , std::shared_ptr<FusedKernel> kernel) const {
+  void cacheKernel(const ArgSpec& arg_spec, std::shared_ptr<FusedKernel> kernel)
+      const {
     std::lock_guard<std::mutex> guard{mutex_};
     kernels_.emplace(arg_spec, kernel);
   }
 
-private:
+ private:
   int64_t key_;
   std::shared_ptr<Graph> graph_;
   Code code_;
@@ -100,10 +115,9 @@ private:
   std::vector<std::vector<int64_t>> inputBroadcastGroups_;
   std::vector<PartitionInfo> inputChunks_;
   mutable std::mutex mutex_;
-  mutable std::unordered_map<
-    ArgSpec
-  , std::shared_ptr<FusedKernel>
-  , torch::hash<ArgSpec>> kernels_;
+  mutable std::
+      unordered_map<ArgSpec, std::shared_ptr<FusedKernel>, torch::hash<ArgSpec>>
+          kernels_;
 };
 
 } // namespace fuser
index 08b0802..16408d5 100644 (file)
@@ -6,28 +6,23 @@
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/fuser/tensor_desc.h>
 
-#include <memory>
 #include <cstdint>
+#include <memory>
 #include <vector>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // Descriptor for chunk-ing an input tensor into subtensors
 // OR concat-ing an output tensor from subtensors
 // Note: default constructed used for tensors that do not participate in
 // chunk or cat operations.
 struct TORCH_API PartitionDesc {
-  PartitionDesc()
-  : nSubTensors_{1}
-  , dim_{0} 
-  { }
-
-  PartitionDesc(
-    const TensorDesc& _desc
-  , size_t _nSubTensors
-  , size_t _dim)
-  : nSubTensors_{_nSubTensors}
-  , dim_{_dim} {
+  PartitionDesc() : nSubTensors_{1}, dim_{0} {}
+
+  PartitionDesc(const TensorDesc& _desc, size_t _nSubTensors, size_t _dim)
+      : nSubTensors_{_nSubTensors}, dim_{_dim} {
     JIT_ASSERT(nSubTensors_ > 1);
     std::vector<bool> cont = _desc.contiguity;
     if (dim_ > 0) {
@@ -40,20 +35,32 @@ struct TORCH_API PartitionDesc {
     subTensorDesc_.reset(new TensorDesc(_desc.scalar_type, cont));
   }
 
-  bool isNoop() const { return (nSubTensors_ == 1);}
-  size_t nSubTensors() const { return nSubTensors_; }
-  size_t dim() const { return dim_; }
-  std::shared_ptr<TensorDesc> subTensorDesc() { return subTensorDesc_; }
-  const std::shared_ptr<TensorDesc> subTensorDesc() const { return subTensorDesc_; }
+  bool isNoop() const {
+    return (nSubTensors_ == 1);
+  }
+  size_t nSubTensors() const {
+    return nSubTensors_;
+  }
+  size_t dim() const {
+    return dim_;
+  }
+  std::shared_ptr<TensorDesc> subTensorDesc() {
+    return subTensorDesc_;
+  }
+  const std::shared_ptr<TensorDesc> subTensorDesc() const {
+    return subTensorDesc_;
+  }
 
-private: 
-  size_t nSubTensors_; // == 1 for tensors that should not be operated on via chunk/cat
+ private:
+  size_t nSubTensors_; // == 1 for tensors that should not be operated on via
+                       // chunk/cat
   size_t dim_; // dimension along which the chunk/concat occurs
-  std::shared_ptr<TensorDesc> subTensorDesc_; // descriptor for the subtensor, if it exists
+  std::shared_ptr<TensorDesc>
+      subTensorDesc_; // descriptor for the subtensor, if it exists
 };
 
 } // namespace fuser
-} // namespace jit 
+} // namespace jit
 } // namespace torch
 
 #endif // USE_CUDA_FUSER || USE_CPU_FUSER
index d1f0b60..fb02867 100644 (file)
@@ -4,15 +4,17 @@
 
 #include <ATen/ATen.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/utils/hash.h>
-#include <torch/csrc/jit/type.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/type.h>
+#include <torch/csrc/utils/hash.h>
 
-#include <vector>
-#include <iostream>
 #include <algorithm>
+#include <iostream>
+#include <vector>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // type information needed by the compiler for input/outputs
 // contiguity[i] is true if the dim i is contiguous with dim i + 1.
@@ -21,33 +23,33 @@ struct TORCH_API TensorDesc {
   at::ScalarType scalar_type;
   std::vector<bool> contiguity;
 
-  TensorDesc(
-    const at::ScalarType& type
-  , const std::vector<bool>& contiguity)
-  : scalar_type{type}
-  , contiguity{contiguity} {
+  TensorDesc(const at::ScalarType& type, const std::vector<bool>& contiguity)
+      : scalar_type{type}, contiguity{contiguity} {
     if (contiguity.size() == 0) {
       nDim_ = 0;
     } else {
-      nDim_ = std::count(contiguity.begin(), contiguity.end(), false) + (lastIsContiguous() ? 1 : 0);
+      nDim_ = std::count(contiguity.begin(), contiguity.end(), false) +
+          (lastIsContiguous() ? 1 : 0);
     }
   }
 
   // Delegating constructors
   TensorDesc(
-    const at::ScalarType& type
-  , const at::IntList& sizes
-  , const at::IntList& strides)
-  : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
+      const at::ScalarType& type,
+      const at::IntList& sizes,
+      const at::IntList& strides)
+      : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
 
   TensorDesc(const at::Tensor& t)
-  : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {}
+      : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {}
 
   TensorDesc(const CompleteTensorTypePtr& type)
-  : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {}
+      : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {}
 
   // number of dimensions after contiguity compression
-  size_t nDim() const { return nDim_; }
+  size_t nDim() const {
+    return nDim_;
+  }
 
   // True iff innermost stride is 1
   bool lastIsContiguous() const {
@@ -55,12 +57,13 @@ struct TORCH_API TensorDesc {
   }
 
   static std::vector<bool> findContiguous(
-    const at::IntList& sizes
-  , const at::IntList& strides) {
+      const at::IntList& sizes,
+      const at::IntList& strides) {
     JIT_ASSERT(sizes.size() == strides.size());
     std::vector<bool> cont(sizes.size());
     for (size_t i = 0; i < sizes.size(); ++i) {
-      const auto expected_stride = (i + 1 < sizes.size()) ? sizes[i+1]*strides[i+1] : 1;
+      const auto expected_stride =
+          (i + 1 < sizes.size()) ? sizes[i + 1] * strides[i + 1] : 1;
       cont[i] = (strides[i] == expected_stride);
     }
     return cont;
@@ -75,10 +78,13 @@ struct TORCH_API TensorDesc {
   }
 
   static size_t hash(const TensorDesc& spec) {
-    return torch::get_hash(spec.scalar_type, spec.nDim_, std::hash<std::vector<bool>>{}(spec.contiguity));
+    return torch::get_hash(
+        spec.scalar_type,
+        spec.nDim_,
+        std::hash<std::vector<bool>>{}(spec.contiguity));
   }
 
-private:
+ private:
   size_t nDim_;
 };
 
index 487359c..161cf3b 100644 (file)
@@ -6,23 +6,28 @@
 
 #include <cstdint>
 
-namespace torch { namespace jit { namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // Host-side view of TensorInfo
 // Note dims[0] - we need to dynamically allocate the dims.
 struct TORCH_API TensorInfo {
-  
-  uint32_t* sizes(size_t nDim) { return &sizes_strides[0]; }
-  uint32_t* strides(size_t nDim) { return &sizes_strides[nDim]; }
+  uint32_t* sizes(size_t nDim) {
+    return &sizes_strides[0];
+  }
+  uint32_t* strides(size_t nDim) {
+    return &sizes_strides[nDim];
+  }
 
   void* data;
-  #pragma GCC diagnostic ignored "-Wpedantic"
-    uint32_t sizes_strides[0];
-  #pragma GCC diagnostic pop
+#pragma GCC diagnostic ignored "-Wpedantic"
+  uint32_t sizes_strides[0];
+#pragma GCC diagnostic pop
 };
 
 } // namespace fuser
-} // namespace jit 
+} // namespace jit
 } // namespace torch
 
 #endif // USE_CUDA_FUSER || USE_CPU_FUSER
index fa839ba..407b305 100644 (file)
@@ -1,17 +1,23 @@
 // TODO: I'm pretty sure Constness can be done with C++ templates, ala
 // std::is_const, but no time to work it out...
-#define GENERIC_IF(Constness, FullKind, x, Kind) \
-  auto && __match_key = x; \
-  switch(__match_key->kind()) { \
-    case FullKind: { \
-      auto * value = static_cast<Constness ::torch::jit::Kind*>(__match_key); (void) value;
-#define GENERIC_ELSEIF(Constness, FullKind, Kind) \
-    } break; \
-    case FullKind: { \
-      auto * value = static_cast<Constness ::torch::jit::Kind*>(__match_key); (void) value;
+#define GENERIC_IF(Constness, FullKind, x, Kind)                             \
+  auto&& __match_key = x;                                                    \
+  switch (__match_key->kind()) {                                             \
+    case FullKind: {                                                         \
+      auto* value = static_cast<Constness ::torch::jit::Kind*>(__match_key); \
+      (void)value;
+#define GENERIC_ELSEIF(Constness, FullKind, Kind)                          \
+  }                                                                        \
+  break;                                                                   \
+  case FullKind: {                                                         \
+    auto* value = static_cast<Constness ::torch::jit::Kind*>(__match_key); \
+    (void)value;
 #define GENERIC_ELSE() \
-    } break; \
-    default: {
+  }                    \
+  break;               \
+  default: {
 #define GENERIC_END() \
-    } break; \
-  };
+  }                   \
+  break;              \
+  }                   \
+  ;
index ab0ab40..6d4552f 100644 (file)
@@ -1,46 +1,47 @@
 #include <torch/csrc/jit/graph_executor.h>
 
-#include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/autograd/grad_mode.h>
 #include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/autodiff.h>
+#include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/tracer.h>
+#include <torch/csrc/jit/ivalue.h>
 #include <torch/csrc/jit/passes/batch_mm.h>
+#include <torch/csrc/jit/passes/canonicalize_ops.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/graph_fuser.h>
+#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
 #include <torch/csrc/jit/passes/inplace_check.h>
-#include <torch/csrc/jit/passes/peephole.h>
-#include <torch/csrc/jit/passes/shape_analysis.h>
-#include <torch/csrc/jit/passes/remove_expands.h>
-#include <torch/csrc/jit/passes/canonicalize_ops.h>
-#include <torch/csrc/jit/passes/specialize_undef.h>
 #include <torch/csrc/jit/passes/loop_unrolling.h>
 #include <torch/csrc/jit/passes/lower_grad_of.h>
-#include <torch/csrc/jit/passes/constant_propagation.h>
-#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
+#include <torch/csrc/jit/passes/peephole.h>
+#include <torch/csrc/jit/passes/remove_expands.h>
 #include <torch/csrc/jit/passes/requires_grad_analysis.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
+#include <torch/csrc/jit/passes/specialize_undef.h>
 #include <torch/csrc/jit/symbolic_variable.h>
-#include <torch/csrc/jit/ivalue.h>
-#include <torch/csrc/jit/custom_operator.h>
+#include <torch/csrc/jit/tracer.h>
 
 #include <torch/csrc/autograd/edge.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/jit/script/compiler.h>
 
 #include <cstdint>
+#include <iterator>
 #include <memory>
 #include <mutex>
 #include <unordered_map>
 #include <utility>
 #include <vector>
-#include <iterator>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
@@ -51,8 +52,7 @@ using autograd::variable_list;
 struct ExecutionPlan {
   ExecutionPlan() = default;
   ExecutionPlan(std::shared_ptr<Graph> graph)
-    : code(graph)
-    , graph(std::move(graph)) {}
+      : code(graph), graph(std::move(graph)) {}
 
   void run(Stack& stack) const {
     return InterpreterState(code).run(stack);
@@ -75,7 +75,7 @@ struct ExecutionPlan {
 
 struct DifferentiableGraphBackward : public autograd::Function {
   DifferentiableGraphBackward(GraphExecutor executor, size_t capture_size)
-  : executor(std::move(executor)) {
+      : executor(std::move(executor)) {
     is_var_capture.reserve(capture_size);
     var_captures.reserve(capture_size);
     ivalue_captures.reserve(capture_size);
@@ -84,8 +84,10 @@ struct DifferentiableGraphBackward : public autograd::Function {
   variable_list apply(variable_list&& inputs) override {
     Stack stack;
     stack.reserve(is_var_capture.size() + inputs.size());
-    stack.insert(stack.end(), std::make_move_iterator(inputs.begin()),
-                              std::make_move_iterator(inputs.end()));
+    stack.insert(
+        stack.end(),
+        std::make_move_iterator(inputs.begin()),
+        std::make_move_iterator(inputs.end()));
     auto var_capture_it = var_captures.begin();
     auto ivalue_capture_it = ivalue_captures.begin();
     for (bool is_var : is_var_capture) {
@@ -106,11 +108,12 @@ struct DifferentiableGraphBackward : public autograd::Function {
     for (size_t i = 0; i < num_outputs(); ++i) {
       if (should_compute_output(i)) {
         auto output = std::move(stack[i]).toTensor();
-        const auto & edge = next_edge(i);
+        const auto& edge = next_edge(i);
         if (output.defined()) {
           outputs.emplace_back(std::move(output));
         } else if (edge.is_valid()) {
-          outputs.emplace_back(edge.function->input_metadata(edge.input_nr).zeros_like());
+          outputs.emplace_back(
+              edge.function->input_metadata(edge.input_nr).zeros_like());
         } else {
           outputs.emplace_back();
         }
@@ -121,7 +124,7 @@ struct DifferentiableGraphBackward : public autograd::Function {
     return outputs;
   }
 
-  void capture(const IValue & val, bool is_output) {
+  void capture(const IValue& val, bool is_output) {
     const bool is_tensor = val.isTensor();
     is_var_capture.push_back(is_tensor);
     if (is_tensor) {
@@ -130,11 +133,13 @@ struct DifferentiableGraphBackward : public autograd::Function {
       ivalue_captures.push_back(val);
     }
   }
-private:
+
+ private:
   friend struct ExecutionPlan;
   GraphExecutor executor;
 
-  // INVARIANT: is_var_capture.size() == var_captures.size() + ivalue_captures.size()
+  // INVARIANT: is_var_capture.size() == var_captures.size() +
+  // ivalue_captures.size()
   std::vector<bool> is_var_capture;
   std::vector<autograd::SavedVariable> var_captures;
   std::vector<IValue> ivalue_captures;
@@ -154,16 +159,20 @@ struct DifferentiableGraphOp {
         num_outputs(this->grad.f->outputs().size()) {}
 
   // XXX: keep in mind that stack can be larger than the inputs we need!
-  int operator()(Stack & stack) const {
-    auto grad_fn = std::make_shared<DifferentiableGraphBackward>(grad_executor,
-      grad.df_input_captured_inputs.size() + grad.df_input_captured_outputs.size());
+  int operator()(Stack& stack) const {
+    auto grad_fn = std::make_shared<DifferentiableGraphBackward>(
+        grad_executor,
+        grad.df_input_captured_inputs.size() +
+            grad.df_input_captured_outputs.size());
 
     {
       auto inputs = last(stack, num_inputs);
-      // hook up the outputs of df to the gradient functions of the inputs that require gradients
-      for(auto idx : grad.df_output_vjps) {
+      // hook up the outputs of df to the gradient functions of the inputs that
+      // require gradients
+      for (auto idx : grad.df_output_vjps) {
         auto v = Variable(inputs[idx].toTensor());
-        grad_fn->add_next_edge(v.defined() ? v.gradient_edge() : autograd::Edge{});
+        grad_fn->add_next_edge(
+            v.defined() ? v.gradient_edge() : autograd::Edge{});
       }
       captureInputs(*grad_fn, inputs);
     }
@@ -175,19 +184,19 @@ struct DifferentiableGraphOp {
       auto outputs = last(stack, num_outputs);
       // hookup the gradients for the output tensors that require gradients
       // to the inputs to our gradient function df
-      // TODO - XXX - if any output is the same tensor multiple times, views have to be
-      // setup here. We need to refactor autograd until it is safe for
-      // tensors to be constructed without all the viewing infrastructure.
-      // this is currently intentionally not done here so we can get an idea of our
-      // perf before introducing overhead for correctness
-      for(auto idx : grad.df_input_vjps) {
+      // TODO - XXX - if any output is the same tensor multiple times, views
+      // have to be setup here. We need to refactor autograd until it is safe
+      // for tensors to be constructed without all the viewing infrastructure.
+      // this is currently intentionally not done here so we can get an idea of
+      // our perf before introducing overhead for correctness
+      for (auto idx : grad.df_input_vjps) {
         // Note: we have to set this up in place, or we have to throw away and
         // reallocate variables that were already created in wrapTensors. We
         // should add an API for this.
         Variable output = outputs[idx].toTensor();
-        // NB: since our requires_grad setting is only a heuristic we might end up
-        // wanting to differentiate through integral tensors, which is generally a
-        // hard error in autograd.
+        // NB: since our requires_grad setting is only a heuristic we might end
+        // up wanting to differentiate through integral tensors, which is
+        // generally a hard error in autograd.
         if (at::isFloatingType(output.type().scalarType())) {
           autograd::create_gradient_edge(output, grad_fn);
           output.set_requires_grad(true);
@@ -204,30 +213,37 @@ struct DifferentiableGraphOp {
     return 0;
   }
 
-private:
+ private:
   friend GraphExecutor* detail::getGradExecutor(Operation& op);
 
-  void detachVariables(Stack & stack) const {
-    // It would be nice to use an ArrayRef here, but unfortunately those can only
-    // return const references, so we need to do a bunch of indexing ourselves.
+  void detachVariables(Stack& stack) const {
+    // It would be nice to use an ArrayRef here, but unfortunately those can
+    // only return const references, so we need to do a bunch of indexing
+    // ourselves.
     const int64_t stack_size = stack.size();
     const int64_t stack_offset = stack_size - num_inputs;
     for (int64_t i = stack_offset; i < stack_size; ++i) {
-      auto & v = stack[i];
-      if (!v.isTensor()) continue;
+      auto& v = stack[i];
+      if (!v.isTensor())
+        continue;
       auto t = std::move(v).toTensor();
-      v = IValue{t.defined() ? autograd::as_variable_ref(t).detach() : std::move(t)};
+      v = IValue{t.defined() ? autograd::as_variable_ref(t).detach()
+                             : std::move(t)};
     }
   }
   // Capture (save) inputs that would be required to subsequently run backwards
-  void captureInputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef<IValue> inputs) const {
+  void captureInputs(
+      DifferentiableGraphBackward& grad_fn,
+      at::ArrayRef<IValue> inputs) const {
     for (size_t offset : grad.df_input_captured_inputs) {
-      grad_fn.capture(inputs[offset], /*is_output*/false);
+      grad_fn.capture(inputs[offset], /*is_output*/ false);
     }
   }
-  void captureOutputs(DifferentiableGraphBackward & grad_fn, at::ArrayRef<IValue> outputs) const {
+  void captureOutputs(
+      DifferentiableGraphBackward& grad_fn,
+      at::ArrayRef<IValue> outputs) const {
     for (size_t offset : grad.df_input_captured_outputs) {
-      grad_fn.capture(outputs[offset], /*is_output*/true);
+      grad_fn.capture(outputs[offset], /*is_output*/ true);
     }
   }
 
@@ -239,39 +255,42 @@ private:
   const size_t num_outputs;
 };
 
-void packGradient(Gradient gradient, Node *dnode) {
+void packGradient(Gradient gradient, Nodednode) {
   JIT_ASSERT(dnode->kind() == prim::DifferentiableGraph);
   dnode->g_(attr::Subgraph, gradient.f)
-       ->g_(attr::ReverseSubgraph, gradient.df)
-       ->i_(attr::f_real_outputs, gradient.f_real_outputs)
-       ->is_(attr::df_input_vjps, fmap<int64_t>(gradient.df_input_vjps))
-       ->is_(attr::df_input_captured_inputs, fmap<int64_t>(gradient.df_input_captured_inputs))
-       ->is_(attr::df_input_captured_outputs, fmap<int64_t>(gradient.df_input_captured_outputs))
-       ->is_(attr::df_output_vjps, fmap<int64_t>(gradient.df_output_vjps));
+      ->g_(attr::ReverseSubgraph, gradient.df)
+      ->i_(attr::f_real_outputs, gradient.f_real_outputs)
+      ->is_(attr::df_input_vjps, fmap<int64_t>(gradient.df_input_vjps))
+      ->is_(
+          attr::df_input_captured_inputs,
+          fmap<int64_t>(gradient.df_input_captured_inputs))
+      ->is_(
+          attr::df_input_captured_outputs,
+          fmap<int64_t>(gradient.df_input_captured_outputs))
+      ->is_(attr::df_output_vjps, fmap<int64_t>(gradient.df_output_vjps));
 }
 
-Gradient getGradient(const Node *n) {
+Gradient getGradient(const Noden) {
   JIT_ASSERT(n->kind() == prim::DifferentiableGraph);
   Gradient grad;
   grad.f = n->g(attr::Subgraph);
   grad.df = n->g(attr::ReverseSubgraph);
   grad.f_real_outputs = n->i(attr::f_real_outputs);
   grad.df_input_vjps = fmap<size_t>(n->is(attr::df_input_vjps));
-  grad.df_input_captured_inputs = fmap<size_t>(n->is(attr::df_input_captured_inputs));
-  grad.df_input_captured_outputs = fmap<size_t>(n->is(attr::df_input_captured_outputs));
+  grad.df_input_captured_inputs =
+      fmap<size_t>(n->is(attr::df_input_captured_inputs));
+  grad.df_input_captured_outputs =
+      fmap<size_t>(n->is(attr::df_input_captured_outputs));
   grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
   return grad;
 }
 
 } // anonymous namespace
 
-RegisterOperators reg_graph_executor_ops({
-  Operator(
-    prim::DifferentiableGraph,
-    [](const Node *n) -> Operation {
+RegisterOperators reg_graph_executor_ops(
+    {Operator(prim::DifferentiableGraph, [](const Node* n) -> Operation {
       return DifferentiableGraphOp(getGradient(n));
-    })
-});
+    })});
 
 namespace detail {
 
@@ -286,11 +305,10 @@ GraphExecutor* getGradExecutor(Operation& op) {
 
 // a Graph can be created via tracing, or via a language-based frontend
 // GraphExecutor runs it. It can run the same graph on many different sizes
-// and different requires_grad states, and handles specializations for each situation.
-// GraphExecutor is completely unaware of tracing or module parameters to keep the
-// tracing concerns separated.
+// and different requires_grad states, and handles specializations for each
+// situation. GraphExecutor is completely unaware of tracing or module
+// parameters to keep the tracing concerns separated.
 struct GraphExecutorImpl {
-
   static std::shared_ptr<Graph> prepareGraph(std::shared_ptr<Graph>& graph) {
     auto copy = graph->copy();
     EraseShapeInformation(copy);
@@ -303,7 +321,7 @@ struct GraphExecutorImpl {
     }
     if (auto tuple_type = ptr->cast<TupleType>()) {
       size_t total = 0;
-      for (auto & elem : tuple_type->elements()) {
+      for (auto& elem : tuple_type->elements()) {
         total += countFlatInputs(elem);
       }
       return total;
@@ -313,18 +331,18 @@ struct GraphExecutorImpl {
 
   static size_t countFlatInputs(const std::shared_ptr<Graph>& graph) {
     size_t total = 0;
-    for (Value * input : graph->inputs()) {
+    for (Value* input : graph->inputs()) {
       total += countFlatInputs(input->type());
     }
     return total;
   }
 
   inline bool hasMutableOperators(Block* block) {
-    for(auto n : block->nodes()) {
-      if(n->kind().is_aten() && n->schema().is_mutable())
+    for (auto n : block->nodes()) {
+      if (n->kind().is_aten() && n->schema().is_mutable())
         return true;
-      for(auto b : n->blocks()) {
-        if(hasMutableOperators(b))
+      for (auto b : n->blocks()) {
+        if (hasMutableOperators(b))
           return true;
       }
     }
@@ -341,21 +359,28 @@ struct GraphExecutorImpl {
         num_outputs(this->graph->outputs().size()) {}
 
   // entry point where execution begins
-  void run(Stack & stack) {
-    AT_CHECK(stack.size() >= num_inputs, "expected ", num_inputs, " inputs, but got only ", stack.size());
-
-    if(tracer::isTracing()) {
+  void run(Stack& stack) {
+    AT_CHECK(
+        stack.size() >= num_inputs,
+        "expected ",
+        num_inputs,
+        " inputs, but got only ",
+        stack.size());
+
+    if (tracer::isTracing()) {
       return runTraced(stack);
     }
 
-    auto & execution_plan = optimize ? getOrCompile(stack) : getOrCompileFallback();
+    auto& execution_plan =
+        optimize ? getOrCompile(stack) : getOrCompileFallback();
     return execution_plan.run(stack);
   }
 
   std::shared_ptr<Graph> graphFor(const Stack& stack) const {
     JIT_ASSERT(stack.size() >= num_inputs);
     auto inputs = last(stack, num_inputs);
-    ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
+    ArgumentSpec spec(
+        autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
 
     if (!optimize) {
       AT_CHECK(fallback, "No graph found for given inputs");
@@ -373,7 +398,7 @@ struct GraphExecutorImpl {
     if (fallback) {
       state.fallback = fallback.getDebugState();
     }
-    for (auto & entry : plan_cache) {
+    for (auto& entry : plan_cache) {
       state.execution_plans.emplace(entry.first, entry.second.getDebugState());
     }
     return state;
@@ -387,12 +412,12 @@ struct GraphExecutorImpl {
     autodiffSubgraphInlineThreshold = 1;
   }
 
-private:
+ private:
   friend struct GraphExecutor;
 
-  const ExecutionPlan & getOrCompileFallback() {
+  const ExecutionPlan& getOrCompileFallback() {
     std::lock_guard<std::mutex> lock(compile_mutex);
-    if(!fallback) {
+    if (!fallback) {
       auto graph_ = graph->copy();
       runRequiredPasses(graph_);
       fallback = ExecutionPlan(graph_);
@@ -400,10 +425,13 @@ private:
     return fallback;
   }
 
-  const ExecutionPlan & getOrCompile(const Stack& stack) {
-    // outside lock guard, to minimize the time holding the lock on the fast path
-    // ArgumentSpec even computes its hashCode here.
-    ArgumentSpec spec(autograd::GradMode::is_enabled(), last(stack, num_inputs), num_flat_inputs);
+  const ExecutionPlan& getOrCompile(const Stack& stack) {
+    // outside lock guard, to minimize the time holding the lock on the fast
+    // path ArgumentSpec even computes its hashCode here.
+    ArgumentSpec spec(
+        autograd::GradMode::is_enabled(),
+        last(stack, num_inputs),
+        num_flat_inputs);
     {
       std::lock_guard<std::mutex> lock(compile_mutex);
       auto it = plan_cache.find(spec);
@@ -415,7 +443,7 @@ private:
     }
   }
 
-  ExecutionPlan compileSpec(const ArgumentSpec & spec) {
+  ExecutionPlan compileSpec(const ArgumentSpec& spec) {
     auto opt_graph = graph->copy();
     setInputTypes(*opt_graph, spec);
 
@@ -427,13 +455,14 @@ private:
     // Phase 2. Propagate detailed information about the spec through the
     //          graph (enabled more specializations in later passes).
     //          Shape propagation sometimes depends on certain arguments being
-    //          constants, and constant propagation doesn't need shape information
-    //          anyway, so it's better to run it first.
+    //          constants, and constant propagation doesn't need shape
+    //          information anyway, so it's better to run it first.
     ConstantPropagation(opt_graph);
     PropagateInputShapes(opt_graph);
     PropagateRequiresGrad(opt_graph);
 
-    // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites that
+    // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
+    // that
     //          we can still execute using autograd).
     runOptimization(opt_graph, spec);
 
@@ -442,8 +471,9 @@ private:
     // Phase 5. Apply non-differentiable optimizations to the graphs we've found
     //          (or the whole grpah if we know we won't need its derivative).
     if (needsGradient(opt_graph)) {
-      auto diff_nodes = CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold);
-      for (Node * dnode : diff_nodes) {
+      auto diff_nodes =
+          CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold);
+      for (Node* dnode : diff_nodes) {
         auto diff_graph = std::move(dnode->g(attr::Subgraph));
         Gradient gradient = differentiate(diff_graph);
         runNondiffOptimization(gradient.f);
@@ -458,7 +488,9 @@ private:
     return ExecutionPlan(opt_graph);
   }
 
-  void runOptimization(std::shared_ptr<Graph>& graph, const ArgumentSpec& spec) {
+  void runOptimization(
+      std::shared_ptr<Graph>& graph,
+      const ArgumentSpec& spec) {
     // Basic graph preprocessing to eliminate noise.
     EliminateDeadCode(graph);
     EliminateCommonSubexpression(graph);
@@ -486,7 +518,7 @@ private:
       return false;
     if (mayIntroduceGradient(graph->block()))
       return true;
-    for (const Value * input : graph->inputs()) {
+    for (const Value* input : graph->inputs()) {
       if (input->type()->requires_grad())
         return true;
     }
@@ -505,17 +537,18 @@ private:
     return false;
   }
 
-  void runTraced(Stack & stack) {
+  void runTraced(Stack& stack) {
     const auto& state = tracer::getTracingState();
     auto inputs = last(stack, num_inputs);
-    auto input_values = fmap(inputs, [](const IValue & v) {
-      return tracer::getNestedValueTrace(v);
-    });
-
-    ArgumentSpec spec(autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
-    // NB: we could just run the fallback in here and call it a day, but that would loose all
-    // the control flow information we have in the graph. Thus, we run the fallback to
-    // get the correct output values, but we will override the tracing states later.
+    auto input_values = fmap(
+        inputs, [](const IValue& v) { return tracer::getNestedValueTrace(v); });
+
+    ArgumentSpec spec(
+        autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
+    // NB: we could just run the fallback in here and call it a day, but that
+    // would loose all the control flow information we have in the graph. Thus,
+    // we run the fallback to get the correct output values, but we will
+    // override the tracing states later.
     {
       // No need to trace a script module.
       ResourceGuard guard(tracer::pauseTracing());
@@ -530,7 +563,8 @@ private:
     auto local_graph = this->graph->copy();
     setInputTypes(*local_graph, spec);
     PropagateInputShapes(local_graph);
-    auto output_values = inlineCallTo(*state->graph, *local_graph, input_values);
+    auto output_values =
+        inlineCallTo(*state->graph, *local_graph, input_values);
 
     auto outputs = last(stack, num_outputs);
     for (size_t i = 0; i < outputs.size(); ++i) {
@@ -538,27 +572,29 @@ private:
     }
   }
 
-  // The unoptimized starting graph. This field is effectively const, but we can't make it so
-  // because Graph::copy() is not const (and making it const is not that easy at this point).
+  // The unoptimized starting graph. This field is effectively const, but we
+  // can't make it so because Graph::copy() is not const (and making it const is
+  // not that easy at this point).
   std::shared_ptr<Graph> graph;
 
-  // If false, we'll run the graph as we get it, without any optimizations. Useful
-  // for debugging.
+  // If false, we'll run the graph as we get it, without any optimizations.
+  // Useful for debugging.
   const bool optimize;
   const size_t num_inputs;
-  const size_t num_flat_inputs; // Number of inputs, assuming all tuples would be flattened.
+  const size_t num_flat_inputs; // Number of inputs, assuming all tuples would
+                                // be flattened.
   const size_t num_outputs;
 
-  // Populated only when optimize is false (and in that case plan_cache will be unused).
-  // The compiled version of graph.
+  // Populated only when optimize is false (and in that case plan_cache will be
+  // unused). The compiled version of graph.
   ExecutionPlan fallback;
 
-  // Mapping from argument configurations to optimized versions of the graph that are
-  // specialized to the spec.
+  // Mapping from argument configurations to optimized versions of the graph
+  // that are specialized to the spec.
   std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;
 
-  // GraphExecutors can be accessed from multiple threads, so this thread needs to be
-  // held every time we access the fallback or plan_cache.
+  // GraphExecutors can be accessed from multiple threads, so this thread needs
+  // to be held every time we access the fallback or plan_cache.
   std::mutex compile_mutex;
 
   // Some tunable parameters
@@ -567,9 +603,9 @@ private:
 };
 
 GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
-: pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}
+    : pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}
 
-void GraphExecutor::run(Stack & inputs) {
+void GraphExecutor::run(Stack& inputs) {
   return pImpl->run(inputs);
 }
 
@@ -589,8 +625,7 @@ void GraphExecutor::debugDisableAutodiffSubgraphInlining() {
   return pImpl->debugDisableAutodiffSubgraphInlining();
 }
 
-
-void runRequiredPasses(const std::shared_ptr<Graph>& g)  {
+void runRequiredPasses(const std::shared_ptr<Graph>& g) {
   specializeUndef(*g);
   LowerGradOf(*g);
   // implicit inserted expand nodes are not necessarily always valid
@@ -602,4 +637,5 @@ void runRequiredPasses(const std::shared_ptr<Graph>& g)  {
   EliminateDeadCode(g);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index a53056c..e6564e5 100644 (file)
@@ -1,13 +1,14 @@
 #pragma once
 
-#include <memory>
+#include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/autodiff.h>
+#include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/variable_tensor_list.h>
-#include <torch/csrc/jit/interpreter.h>
-#include <torch/csrc/jit/autodiff.h>
-#include <torch/csrc/jit/argument_spec.h>
+#include <memory>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 struct GraphExecutorState;
 
@@ -29,7 +30,7 @@ struct GraphExecutorImpl;
 struct TORCH_API GraphExecutor {
   GraphExecutor() = default;
   GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
-  void run(Stack & inputs);
+  void run(Stack& inputs);
   explicit operator bool() const {
     return pImpl != nullptr;
   }
@@ -37,7 +38,8 @@ struct TORCH_API GraphExecutor {
   std::shared_ptr<Graph> graphFor(const Stack& inputs) const;
   GraphExecutorState getDebugState();
   void debugDisableAutodiffSubgraphInlining();
-private:
+
+ private:
   std::shared_ptr<GraphExecutorImpl> pImpl;
 };
 
@@ -51,5 +53,5 @@ GraphExecutor* getGradExecutor(Operation& op);
 
 } // namespace detail
 
-
-}}
+} // namespace jit
+} // namespace torch
index 3fe3dcd..c20cb72 100644 (file)
@@ -2,7 +2,8 @@
 
 #include <torch/csrc/jit/assertions.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // Intrusive doubly linked lists with sane reverse iterators.
 // The header file is named generic_graph_node_list.h because it is ONLY
@@ -39,21 +40,28 @@ struct Node;
 using graph_node_list = generic_graph_node_list<Node>;
 using const_graph_node_list = generic_graph_node_list<const Node>;
 using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
-using const_graph_node_list_iterator = generic_graph_node_list_iterator<const Node>;
+using const_graph_node_list_iterator =
+    generic_graph_node_list_iterator<const Node>;
 
 template <typename T>
 struct generic_graph_node_list_iterator {
-  generic_graph_node_list_iterator()
-    : cur(nullptr), d(kNextDirection) {}
-  generic_graph_node_list_iterator(T * cur, int d)
-    : cur(cur), d(d) {}
-  generic_graph_node_list_iterator(const generic_graph_node_list_iterator & rhs) = default;
-  generic_graph_node_list_iterator(generic_graph_node_list_iterator && rhs) = default;
-  generic_graph_node_list_iterator& operator=(const generic_graph_node_list_iterator & rhs) = default;
-  generic_graph_node_list_iterator& operator=(generic_graph_node_list_iterator && rhs) = default;
-  T * operator*() const { return cur; }
-  T * operator->() const { return cur; }
-  generic_graph_node_list_iterator & operator++() {
+  generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
+  generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {}
+  generic_graph_node_list_iterator(
+      const generic_graph_node_list_iterator& rhs) = default;
+  generic_graph_node_list_iterator(generic_graph_node_list_iterator&& rhs) =
+      default;
+  generic_graph_node_list_iterator& operator=(
+      const generic_graph_node_list_iterator& rhs) = default;
+  generic_graph_node_list_iterator& operator=(
+      generic_graph_node_list_iterator&& rhs) = default;
+  T* operator*() const {
+    return cur;
+  }
+  T* operator->() const {
+    return cur;
+  }
+  generic_graph_node_list_iterator& operator++() {
     JIT_ASSERT(cur);
     cur = cur->next_in_graph[d];
     return *this;
@@ -63,7 +71,7 @@ struct generic_graph_node_list_iterator {
     ++(*this);
     return old;
   }
-  generic_graph_node_list_iterator & operator--() {
+  generic_graph_node_list_iterator& operator--() {
     JIT_ASSERT(cur);
     cur = cur->next_in_graph[reverseDir()];
     return *this;
@@ -79,19 +87,20 @@ struct generic_graph_node_list_iterator {
   // silently cause the wrong one to be called.
   // iterator will point to the previous entry after call
   void destroyCurrent() {
-    T * n = cur;
+    T* n = cur;
     cur = cur->next_in_graph[reverseDir()];
     n->destroy();
   }
   generic_graph_node_list_iterator reverse() {
     return generic_graph_node_list_iterator(cur, reverseDir());
   }
-private:
+
+ private:
   int reverseDir() {
     return d == kNextDirection ? kPrevDirection : kNextDirection;
   }
-  T * cur;
-  int d; //direction 0 is forward 1 is reverse, see next_in_graph
+  T* cur;
+  int d; // direction 0 is forward 1 is reverse, see next_in_graph
 };
 
 template <typename T>
@@ -105,10 +114,10 @@ struct generic_graph_node_list {
     return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
   }
   generic_graph_node_list_iterator<T> end() {
-    return generic_graph_node_list_iterator<T>(head,d);
+    return generic_graph_node_list_iterator<T>(head, d);
   }
   generic_graph_node_list_iterator<const T> end() const {
-    return generic_graph_node_list_iterator<const T>(head,d);
+    return generic_graph_node_list_iterator<const T>(head, d);
   }
   generic_graph_node_list_iterator<T> rbegin() {
     return reverse().begin();
@@ -123,37 +132,52 @@ struct generic_graph_node_list {
     return reverse().end();
   }
   generic_graph_node_list reverse() {
-    return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection);
+    return generic_graph_node_list(
+        head, d == kNextDirection ? kPrevDirection : kNextDirection);
   }
   const generic_graph_node_list reverse() const {
-    return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection);
-  }
-  T* front() { return head->next_in_graph[d]; }
-  const T* front() const { return head->next_in_graph[d]; }
-  T* back() { return head->next_in_graph[!d]; }
-  const T* back() const { return head->next_in_graph[!d]; }
-  generic_graph_node_list(T * head, int d)
-    : head(head), d(d) {}
-private:
-  T * head;
+    return generic_graph_node_list(
+        head, d == kNextDirection ? kPrevDirection : kNextDirection);
+  }
+  T* front() {
+    return head->next_in_graph[d];
+  }
+  const T* front() const {
+    return head->next_in_graph[d];
+  }
+  T* back() {
+    return head->next_in_graph[!d];
+  }
+  const T* back() const {
+    return head->next_in_graph[!d];
+  }
+  generic_graph_node_list(T* head, int d) : head(head), d(d) {}
+
+ private:
+  T* head;
   int d;
 };
 
 template <typename T>
-static inline bool operator==(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) {
+static inline bool operator==(
+    generic_graph_node_list_iterator<T> a,
+    generic_graph_node_list_iterator<T> b) {
   return *a == *b;
 }
 
 template <typename T>
-static inline bool operator!=(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) {
+static inline bool operator!=(
+    generic_graph_node_list_iterator<T> a,
+    generic_graph_node_list_iterator<T> b) {
   return *a != *b;
 }
 
-}}
+} // namespace jit
+} // namespace torch
 
 namespace std {
 
-template<typename T>
+template <typename T>
 struct iterator_traits<torch::jit::generic_graph_node_list_iterator<T>> {
   using difference_type = int64_t;
   using value_type = T*;
@@ -162,4 +186,4 @@ struct iterator_traits<torch::jit::generic_graph_node_list_iterator<T>> {
   using iterator_category = bidirectional_iterator_tag;
 };
 
-}
+} // namespace std
index b80f0ab..7ee0195 100644 (file)
@@ -4,13 +4,15 @@
 namespace torch {
 namespace jit {
 
-static std::function<void(std::shared_ptr<script::Module> module)> emit_module_callback;
+static std::function<void(std::shared_ptr<script::Module> module)>
+    emit_module_callback;
 TORCH_API void didFinishEmitModule(std::shared_ptr<script::Module> module) {
-  if(emit_module_callback) {
+  if (emit_module_callback) {
     emit_module_callback(std::move(module));
   }
 }
-TORCH_API void setEmitModuleHook(std::function<void(std::shared_ptr<script::Module> module)> cb) {
+TORCH_API void setEmitModuleHook(
+    std::function<void(std::shared_ptr<script::Module> module)> cb) {
   emit_module_callback = std::move(cb);
 }
 } // namespace jit
index f160672..46b3398 100644 (file)
@@ -1,6 +1,6 @@
 #pragma once
-#include <functional>
 #include <torch/csrc/WindowsTorchApiMacro.h>
+#include <functional>
 #include <memory>
 
 namespace torch {
@@ -9,6 +9,7 @@ namespace script {
 struct Module;
 }
 TORCH_API void didFinishEmitModule(std::shared_ptr<script::Module> module);
-TORCH_API void setEmitModuleHook(std::function<void(std::shared_ptr<script::Module> module)> cb);
+TORCH_API void setEmitModuleHook(
+    std::function<void(std::shared_ptr<script::Module> module)> cb);
 } // namespace jit
 } // namespace torch
index bba1d1b..56c5859 100644 (file)
@@ -1,13 +1,12 @@
 #include <google/protobuf/util/json_util.h>
 #include <google/protobuf/util/type_resolver_util.h>
 
+#include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/import.h>
+#include <torch/csrc/jit/import_method.h>
 #include <torch/csrc/jit/ir.h>
-#include <torch/csrc/utils/functional.h>
-#include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/import_method.h>
-
+#include <torch/csrc/utils/functional.h>
 
 #include <caffe2/core/types.h>
 #include <caffe2/proto/caffe2_pb.h>
 
 #include <ATen/ATen.h>
 
+#include <fstream>
+#include <string>
 #include <unordered_map>
 #include <vector>
-#include <string>
-#include <fstream>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
@@ -37,26 +37,27 @@ class ScriptModuleDeserializer final {
 
   ScriptModuleDeserializer(std::istream* is);
 
-  void deserialize(ModuleLookup module_lookup,
+  void deserialize(
+      ModuleLookup module_lookup,
       c10::optional<at::Device> device);
 
-private:
- at::Tensor loadTensor(
-     const torch::TensorDef& tensor_proto,
-     std::unordered_map<std::string, at::Storage>& storageMap);
+ private:
 at::Tensor loadTensor(
+      const torch::TensorDef& tensor_proto,
+      std::unordered_map<std::string, at::Storage>& storageMap);
 
- void convertModule(const torch::ModuleDef& module_def);
 void convertModule(const torch::ModuleDef& module_def);
 
- void loadTensorTable(torch::ModelDef* model_def);
 void loadTensorTable(torch::ModelDef* model_def);
 
- PyTorchStreamReader reader_;
- // this is a hack to make sure the script module created in C++ is the
- // same as created in Python
- ModuleLookup moduleLookup_;
- c10::optional<at::Device> device_;
- std::vector<std::string> moduleStack_;
 PyTorchStreamReader reader_;
 // this is a hack to make sure the script module created in C++ is the
 // same as created in Python
 ModuleLookup moduleLookup_;
 c10::optional<at::Device> device_;
 std::vector<std::string> moduleStack_;
 
- std::vector<at::Tensor> tensor_table_;
 std::vector<at::Tensor> tensor_table_;
 };
 
 ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
@@ -67,7 +68,8 @@ ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
 ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
     : reader_(is) {}
 
-void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup,
+void ScriptModuleDeserializer::deserialize(
+    ModuleLookup module_lookup,
     c10::optional<at::Device> device) {
   torch::ModelDef model_def;
   at::DataPtr data_ptr;
@@ -108,15 +110,18 @@ void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup,
 
 void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
   std::unordered_map<std::string, at::Storage> storageMap;
-  for(const torch::TensorDef& tensor : model_def->tensors()) {
+  for (const torch::TensorDef& tensor : model_def->tensors()) {
     tensor_table_.emplace_back(loadTensor(tensor, storageMap));
   }
 }
 
-at::Tensor ScriptModuleDeserializer::loadTensor(const torch::TensorDef& tensor_proto,
-                std::unordered_map<std::string, at::Storage>& storageMap) {
-  std::vector<int64_t> dims(tensor_proto.dims().begin(), tensor_proto.dims().end());
-  std::vector<int64_t> strides(tensor_proto.strides().begin(), tensor_proto.strides().end());
+at::Tensor ScriptModuleDeserializer::loadTensor(
+    const torch::TensorDef& tensor_proto,
+    std::unordered_map<std::string, at::Storage>& storageMap) {
+  std::vector<int64_t> dims(
+      tensor_proto.dims().begin(), tensor_proto.dims().end());
+  std::vector<int64_t> strides(
+      tensor_proto.strides().begin(), tensor_proto.strides().end());
   auto type = at::typeMetaToScalarType(
       caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
   const std::string& record_key = tensor_proto.data().key();
@@ -138,17 +143,19 @@ at::Tensor ScriptModuleDeserializer::loadTensor(const torch::TensorDef& tensor_p
         record_size / at::CPU(type).typeMeta().itemsize(),
         nullptr); // NB: we didn't set any allocator for the tensor
     if (device.type() == at::DeviceType::CPU) {
-      storage_it = storageMap.insert(std::make_pair(
-            record_key, cpu_storage)).first;
+      storage_it =
+          storageMap.insert(std::make_pair(record_key, cpu_storage)).first;
     } else if (device.type() == at::DeviceType::CUDA) {
-      at::Tensor cpu_tensor = at::empty({0}, at::CPU(type).options()).set_(
-          cpu_storage, tensor_proto.offset(), dims, strides);
-      at::Storage cuda_storage = cpu_tensor.to(device,
-          cpu_tensor.scalar_type()).storage();
-      storage_it = storageMap.insert(std::make_pair(
-            record_key, cuda_storage)).first;
+      at::Tensor cpu_tensor =
+          at::empty({0}, at::CPU(type).options())
+              .set_(cpu_storage, tensor_proto.offset(), dims, strides);
+      at::Storage cuda_storage =
+          cpu_tensor.to(device, cpu_tensor.scalar_type()).storage();
+      storage_it =
+          storageMap.insert(std::make_pair(record_key, cuda_storage)).first;
     } else {
-      AT_ERROR("supported devices include CPU and CUDA, however got ",
+      AT_ERROR(
+          "supported devices include CPU and CUDA, however got ",
           at::DeviceTypeName(device.type(), false));
     }
   }
@@ -157,19 +164,20 @@ at::Tensor ScriptModuleDeserializer::loadTensor(const torch::TensorDef& tensor_p
        storage_it->second.device().index() != device.index())) {
     std::stringstream oss;
     oss << "storage previously was specified with device "
-      << storage_it->second.device()
-      << "but now is specified with device "
-      << device << std::endl;
+        << storage_it->second.device() << "but now is specified with device "
+        << device << std::endl;
     AT_ERROR(oss.str());
   }
 
   at::Tensor result;
   if (device.type() == at::DeviceType::CPU) {
-    result = at::empty({0}, at::CPU(type).options()).set_(
-        storage_it->second, tensor_proto.offset(), dims, strides);
+    result =
+        at::empty({0}, at::CPU(type).options())
+            .set_(storage_it->second, tensor_proto.offset(), dims, strides);
   } else if (device.type() == at::DeviceType::CUDA) {
-    result = at::empty({0}, at::CUDA(type).options()).set_(
-        storage_it->second, tensor_proto.offset(), dims, strides);
+    result =
+        at::empty({0}, at::CUDA(type).options())
+            .set_(storage_it->second, tensor_proto.offset(), dims, strides);
   }
   AT_ASSERT(result.defined());
 
@@ -191,19 +199,19 @@ void ScriptModuleDeserializer::convertModule(
   for (int i = 0; i < module_def.parameters_size(); ++i) {
     const torch::ParameterDef& param_def = module_def.parameters(i);
     at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
-    module->register_parameter(
-        param_def.name(), tensor, param_def.is_buffer());
+    module->register_parameter(param_def.name(), tensor, param_def.is_buffer());
   }
   if (module_def.has_torchscript_arena()) {
     at::DataPtr data;
     size_t size;
-    std::tie(data, size) = reader_.getRecord(module_def.torchscript_arena().key());
+    std::tie(data, size) =
+        reader_.getRecord(module_def.torchscript_arena().key());
     std::string data_str(static_cast<const char*>(data.get()), size);
     import_methods(module, data_str, tensor_table_);
   }
 }
 
-}  // namespace
+} // namespace
 
 void import_ir_module(
     ModuleLookup module_lookup,
@@ -221,7 +229,8 @@ void import_ir_module(
   deserializer.deserialize(module_lookup, device);
 }
 
-std::shared_ptr<script::Module> load(std::istream& in,
+std::shared_ptr<script::Module> load(
+    std::istream& in,
     c10::optional<at::Device> device) {
   auto module = std::make_shared<script::Module>();
 
@@ -242,15 +251,17 @@ std::shared_ptr<script::Module> load(std::istream& in,
   return module;
 }
 
-std::shared_ptr<script::Module> load(const std::string& filename,
+std::shared_ptr<script::Module> load(
+    const std::string& filename,
     c10::optional<at::Device> device) {
   std::ifstream in(filename, std::ios_base::binary);
 
-  AT_CHECK(! in.fail(), "load: could not open file ", filename);
+  AT_CHECK(!in.fail(), "load: could not open file ", filename);
 
   auto module = load(in, device);
 
   return module;
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 5560a7d..2252ba4 100644 (file)
@@ -25,7 +25,8 @@ TORCH_API void import_ir_module(
 ///
 /// The istream must contain a serialized `script::Module`, exported via
 /// `torch::jit::ExportModule` in C++.
-TORCH_API std::shared_ptr<script::Module> load(std::istream& in,
+TORCH_API std::shared_ptr<script::Module> load(
+    std::istream& in,
     c10::optional<c10::Device> device = c10::nullopt);
 
 /// Loads a serialized `script::Module` from the given `filename`.
@@ -33,7 +34,8 @@ TORCH_API std::shared_ptr<script::Module> load(std::istream& in,
 /// The file stored at the location given in `filename` must contain a
 /// serialized `script::Module`, exported either via `ScriptModule.save()` in
 /// Python or `torch::jit::ExportModule` in C++.
-TORCH_API std::shared_ptr<script::Module> load(const std::string& filename,
+TORCH_API std::shared_ptr<script::Module> load(
+    const std::string& filename,
     c10::optional<c10::Device> device = c10::nullopt);
 
 } // namespace jit
index 6caff06..a51e32b 100644 (file)
@@ -1,51 +1,59 @@
 #include <torch/csrc/jit/import_method.h>
 #include <torch/csrc/jit/script/parser.h>
 
-namespace torch { namespace jit {
-
+namespace torch {
+namespace jit {
 
 // this is a much simpler accessor that only handles modules, parameters, and
 // and methods. It does not depend on python to work.
 struct ModuleAccessorValue : public script::SugaredValue {
   ModuleAccessorValue(std::shared_ptr<script::Module> module)
-  : module(std::move(module)) {}
+      : module(std::move(module)) {}
   std::string kind() const override {
     return "module";
   }
   // select an attribute on it, e.g. `this.field`
-  std::shared_ptr<SugaredValue> attr(const SourceRange& loc, script::Method & m, const std::string& field) override {
-    if(script::NamedModule* v = module->find_module(field)) {
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      script::Method& m,
+      const std::string& field) override {
+    if (script::NamedModule* v = module->find_module(field)) {
       return std::make_shared<ModuleAccessorValue>(v->module);
-    } else if(script::NamedParameter* v = module->find_parameter(field)) {
-      return std::make_shared<script::SimpleValue>(m.get_or_add_parameter(v->slot()));
-    } else if(script::Method* m = module->find_method(field)) {
+    } else if (script::NamedParameter* v = module->find_parameter(field)) {
+      return std::make_shared<script::SimpleValue>(
+          m.get_or_add_parameter(v->slot()));
+    } else if (script::Method* m = module->find_method(field)) {
       return std::make_shared<script::MethodValue>(module, *m);
     } else {
       throw script::ErrorReport(loc) << "unknown attr: " << field;
     }
   }
-private:
+
+ private:
   std::shared_ptr<script::Module> module;
 };
 
 struct OpsValue : public script::SugaredValue {
-  OpsValue(size_t version)
-  : version_(version) {}
+  OpsValue(size_t version) : version_(version) {}
   std::string kind() const override {
     return "ops";
   }
-  std::shared_ptr<SugaredValue> attr(const SourceRange& loc, script::Method & m, const std::string& field) override {
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      script::Method& m,
+      const std::string& field) override {
     return std::make_shared<script::BuiltinModule>(field, version_);
   }
   size_t version_;
 };
 
 struct ConstantValue : public script::SugaredValue {
-  ConstantValue(IValue value)
-  : value_(std::move(value)) {}
+  ConstantValue(IValue value) : value_(std::move(value)) {}
   IValue value_;
-  std::string kind() const override { return "constant"; }
-  Value * asValue(const SourceRange& loc, script::Method & m) override {
+  std::string kind() const override {
+    return "constant";
+  }
+  Value* asValue(const SourceRange& loc, script::Method& m) override {
     return m.graph()->insertConstant(value_);
   }
 };
@@ -54,17 +62,19 @@ struct ConstantValue : public script::SugaredValue {
 // in the 'constants' vector. This table is will be stored in a container format
 // and given to the import_method when restoring the code.
 struct ConstantTableValue : public script::SugaredValue {
-  ConstantTableValue(ArrayRef<at::Tensor> constants)
-  : constants_(constants) {}
+  ConstantTableValue(ArrayRef<at::Tensor> constants) : constants_(constants) {}
   std::string kind() const override {
     return "CONSTANTS";
   }
   // select an attribute on it, e.g. `this.field`
-  std::shared_ptr<SugaredValue> attr(const SourceRange& loc, script::Method & m, const std::string& field) override {
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      script::Method& m,
+      const std::string& field) override {
     const char* field_s = field.c_str();
     char* end;
     int64_t offset = std::strtoll(field_s + 1, &end, 10);
-    if(field.size() < 2 || *end != 0)
+    if (field.size() < 2 || *end != 0)
       throw script::ErrorReport(loc) << "invalid constant specifier: " << field;
     if (offset < 0 || size_t(offset) >= constants_.size()) {
       throw script::ErrorReport(loc) << "constant index " << offset
@@ -76,7 +86,7 @@ struct ConstantTableValue : public script::SugaredValue {
   }
 
  private:
-   ArrayRef<at::Tensor> constants_;
+  ArrayRef<at::Tensor> constants_;
 };
 
 static size_t parseVersionNumber(script::Lexer& L) {
@@ -87,29 +97,40 @@ static size_t parseVersionNumber(script::Lexer& L) {
   L.expect(script::TK_NEWLINE);
   auto version = script::Const::create(L.cur().range, version_text);
   if (name != "op_version_set")
-    throw script::ErrorReport(range) << "expected an assignment to op_version_set";
+    throw script::ErrorReport(range)
+        << "expected an assignment to op_version_set";
   if (!version.isIntegral())
-    throw script::ErrorReport(range) << "expected an integral version but found " << version.text();
-   return size_t(version.asIntegral());
+    throw script::ErrorReport(range)
+        << "expected an integral version but found " << version.text();
+  return size_t(version.asIntegral());
 }
 
-void import_methods(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table) {
+void import_methods(
+    const std::shared_ptr<script::Module>& mod,
+    const std::string& src,
+    const std::vector<at::Tensor>& constant_table) {
   script::Parser p(src);
 
   size_t version = parseVersionNumber(p.lexer());
 
   std::unordered_map<std::string, std::shared_ptr<script::SugaredValue>> env = {
-    {"torch", std::make_shared<script::BuiltinModule>("aten", version)},
-    {"ops", std::make_shared<OpsValue>(version)},
-    {"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
-    {"fork", std::make_shared<script::ForkValue>()},
-    {"annotate", std::make_shared<script::AnnotateValue>()},
-    {"inf", std::make_shared<ConstantValue>(std::numeric_limits<double>::infinity())},
-    {"nan", std::make_shared<ConstantValue>(std::numeric_limits<double>::quiet_NaN())},
+      {"torch", std::make_shared<script::BuiltinModule>("aten", version)},
+      {"ops", std::make_shared<OpsValue>(version)},
+      {"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
+      {"fork", std::make_shared<script::ForkValue>()},
+      {"annotate", std::make_shared<script::AnnotateValue>()},
+      {"inf",
+       std::make_shared<ConstantValue>(
+           std::numeric_limits<double>::infinity())},
+      {"nan",
+       std::make_shared<ConstantValue>(
+           std::numeric_limits<double>::quiet_NaN())},
   };
 
-  auto resolver = [&](const std::string& name, script::Method& m, const SourceRange& loc)
-  -> std::shared_ptr<script::SugaredValue> {
+  auto resolver =
+      [&](const std::string& name,
+          script::Method& m,
+          const SourceRange& loc) -> std::shared_ptr<script::SugaredValue> {
     auto it = env.find(name);
     if (it == env.end())
       return nullptr;
@@ -128,4 +149,5 @@ void import_methods(const std::shared_ptr<script::Module>& mod, const std::strin
   script::defineMethodsInModule(mod, definitions, resolvers, self);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 9ceeb74..c8575d6 100644 (file)
@@ -1,13 +1,16 @@
 #pragma once
 
 #include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/script/module.h>
 #include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/module.h>
 
 namespace torch {
 namespace jit {
 
-TORCH_API void import_methods(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table);
+TORCH_API void import_methods(
+    const std::shared_ptr<script::Module>& mod,
+    const std::string& src,
+    const std::vector<at::Tensor>& constant_table);
 
 } // namespace jit
 } // namespace torch
index f0e35f6..38dbe69 100644 (file)
@@ -1,46 +1,45 @@
-#include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/auto_gil.h>
+#include <torch/csrc/utils/pybind.h>
 
-#include <torch/csrc/jit/python_tracer.h>
-#include <torch/csrc/jit/tracer.h>
-#include <torch/csrc/jit/python_ir.h>
-#include <torch/csrc/jit/python_arg_flatten.h>
-#include <torch/csrc/jit/export.h>
-#include <torch/csrc/jit/import.h>
 #include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/batched/BatchTensor.h>
+#include <torch/csrc/jit/export.h>
+#include <torch/csrc/jit/function_schema.h>
+#include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/fuser/kernel_cache.h>
-#include <torch/csrc/jit/passes/remove_expands.h>
-#include <torch/csrc/jit/passes/graph_fuser.h>
-#include <torch/csrc/jit/passes/onnx.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
-#include <torch/csrc/jit/passes/erase_number_types.h>
-#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
-#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
-#include <torch/csrc/jit/passes/constant_pooling.h>
-#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
-#include <torch/csrc/jit/passes/peephole.h>
+#include <torch/csrc/jit/graph_executor.h>
+#include <torch/csrc/jit/import.h>
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/canonicalize.h>
-#include <torch/csrc/jit/passes/onnx/peephole.h>
-#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
-#include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/passes/canonicalize_ops.h>
-#include <torch/csrc/jit/passes/remove_inplace_ops.h>
+#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/erase_number_types.h>
+#include <torch/csrc/jit/passes/graph_fuser.h>
 #include <torch/csrc/jit/passes/loop_unrolling.h>
-#include <torch/csrc/jit/passes/to_batch.h>
 #include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/passes/onnx.h>
+#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
+#include <torch/csrc/jit/passes/onnx/peephole.h>
+#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
+#include <torch/csrc/jit/passes/peephole.h>
+#include <torch/csrc/jit/passes/remove_expands.h>
+#include <torch/csrc/jit/passes/remove_inplace_ops.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/passes/specialize_undef.h>
+#include <torch/csrc/jit/passes/to_batch.h>
 #include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
-#include <torch/csrc/jit/graph_executor.h>
-#include <torch/csrc/jit/script/init.h>
-#include <torch/csrc/jit/script/python_tree_views.h>
-#include <torch/csrc/jit/batched/BatchTensor.h>
 #include <torch/csrc/jit/pybind_utils.h>
-#include <torch/csrc/jit/function_schema.h>
-#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/fuser/interface.h>
-#include <torch/csrc/jit/script/jit_exception.h>
+#include <torch/csrc/jit/python_arg_flatten.h>
+#include <torch/csrc/jit/python_ir.h>
+#include <torch/csrc/jit/python_tracer.h>
+#include <torch/csrc/jit/script/init.h>
 #include <torch/csrc/jit/script/jit_exception.h>
+#include <torch/csrc/jit/script/python_tree_views.h>
+#include <torch/csrc/jit/tracer.h>
 
 #include <caffe2/serialize/inline_container.h>
 
 #include <tuple>
 #include <utility>
 
-namespace torch  { namespace jit {
+namespace torch {
+namespace jit {
 
 // TODO: make a fake future for python
 namespace detail {
-class Future {
-
-};
-}
+class Future {};
+} // namespace detail
 
 namespace {
 
@@ -68,10 +66,10 @@ using autograd::variable_list;
 
 bool loadPythonClasses() {
   // Leaving this code here, because it will likely be useful at some point
-  //PyObject *jit_module = PyImport_ImportModule("torch.jit");
-  //THPUtils_assert(jit_module, "class loader couldn't access "
-          //"torch.jit module");
-  //PyObject *jit_dict = PyModule_GetDict(jit_module);
+  // PyObject *jit_module = PyImport_ImportModule("torch.jit");
+  // THPUtils_assert(jit_module, "class loader couldn't access "
+  //"torch.jit module");
+  // PyObject *jit_dict = PyModule_GetDict(jit_module);
 
   return true;
 }
@@ -86,96 +84,127 @@ std::string runJITCPPTests() {
 std::string runJITCPPTests();
 #endif
 
-void initJITBindings(PyObject *module) {
+void initJITBindings(PyObjectmodule) {
   auto m = py::handle(module).cast<py::module>();
 
   py::register_exception<JITException>(m, "JITException");
 
-  py::class_<python::IODescriptor>(m, "IODescriptor"); // NOLINT(bugprone-unused-raii)
+  py::class_<python::IODescriptor>(
+      m, "IODescriptor"); // NOLINT(bugprone-unused-raii)
 
   m.def("_jit_init", loadPythonClasses)
 #if USE_CUDA_FUSER || USE_CPU_FUSER
-   .def("_jit_debug_fuser_num_cached_kernel_specs",
-       torch::jit::fuser::debugNumCachedKernelSpecs)
+      .def(
+          "_jit_debug_fuser_num_cached_kernel_specs",
+          torch::jit::fuser::debugNumCachedKernelSpecs)
 #endif
-   .def("_jit_pass_onnx", ToONNX)
-   .def("_jit_pass_lower_all_tuples", LowerAllTuples)
-   .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
-   .def("_jit_pass_fuse", FuseGraph)
-   .def("_jit_pass_dce", [](std::shared_ptr<Graph>& g) {
-     return EliminateDeadCode(g->block()); // overload resolution
-   })
-   .def("_jit_pass_cse", [](std::shared_ptr<Graph>& g) {
-     return EliminateCommonSubexpression(g); // overload resolution
-   })
-   .def("_jit_pass_remove_inplace_ops", [](std::shared_ptr<Graph> g) {
-      return RemoveInplaceOps(g);
-   })
-   .def("_jit_pass_constant_pooling", ConstantPooling)
-   .def("_jit_pass_peephole", [](const std::shared_ptr<Graph>& g, bool addmm_fusion_enabled) {
-     return PeepholeOptimize(g, addmm_fusion_enabled);
-   }, py::arg("graph"), py::arg("addmm_fusion_enabled") = false)
-   .def("_jit_pass_canonicalize", [](const std::shared_ptr<Graph>& g) {
-     return Canonicalize(g);
-   })
-   .def("_jit_pass_lint", LintGraph)
-   .def("_jit_pass_shape_analysis", [](std::shared_ptr<Graph> graph, std::vector<at::Tensor> inputs, bool with_grad) {
-     setInputTypes(*graph, ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
-     PropagateInputShapes(graph);
-   })
-   .def("_jit_pass_complete_shape_analysis", [](std::shared_ptr<Graph> graph, py::tuple inputs, bool with_grad) {
-     CompleteArgumentSpec spec(with_grad, evilDeprecatedBadCreateStackDoNotUse(inputs, graph->inputs()));
-     auto graph_inputs = graph->inputs();
-     JIT_ASSERT(spec.size() == graph_inputs.size());
-     for (size_t i = 0; i < graph_inputs.size(); ++i) {
-       graph_inputs[i]->setType(spec.at(i));
-     }
-     PropagateInputShapes(graph);
-   })
-   .def("_jit_pass_remove_expands", RemoveExpands)
-   .def("_jit_pass_erase_number_types", EraseNumberTypes)
-   .def("_jit_pass_prepare_division_for_onnx", PrepareDivisionForONNX)
-   .def("_jit_pass_loop_unrolling", UnrollLoops)
-   .def("_jit_pass_constant_propagation", [](std::shared_ptr<Graph>& g) {
-     return ConstantPropagation(g);
-   })
-   .def("_jit_pass_erase_shape_information", EraseShapeInformation)
-   .def("_jit_pass_create_autodiff_subgraphs", [](std::shared_ptr<Graph> graph) {
-     CreateAutodiffSubgraphs(graph);
-   })
-   .def("_jit_run_cpp_tests", [] {
-     // We have to release the GIL inside this method, because if we happen to
-     // initialize the autograd engine in these tests, the newly spawned worker threads will
-     // try to initialize their PyThreadState*, and they need the GIL for this.
-     AutoNoGIL _no_gil;
-     return runJITCPPTests();
-   })
-   .def("_jit_flatten", [](py::handle& obj) {
-     auto res =  python::flatten(obj);
-     return std::make_pair(res.vars, res.desc);
-   })
-   .def("_jit_unflatten", [](autograd::variable_list vars, python::IODescriptor& desc) {
-     return py::reinterpret_steal<py::object>(python::unflatten(vars, desc));
-   })
-   .def("_jit_pass_onnx_block", BlockToONNX)
-   .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
-   .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
-   .def("_jit_pass_specialize_undef", specializeUndef)
-   .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
-   .def("_jit_differentiate", [](Graph &g) {
-       // the python binding slightly differs in semantics
-       // it makes a copy of the input Graph, and works on that
-       // jit::differentiate mutates the input Graph
-       auto g_clone = g.copy();
-       return differentiate(g_clone);
-   })
-   .def("_jit_check_alias_annotation", [](
-         std::shared_ptr<Graph> g,
-         py::tuple args,
-         const std::string& unqualified_op_name) {
-       auto stack = toStack(args);
-       checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
-   });
+      .def("_jit_pass_onnx", ToONNX)
+      .def("_jit_pass_lower_all_tuples", LowerAllTuples)
+      .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
+      .def("_jit_pass_fuse", FuseGraph)
+      .def(
+          "_jit_pass_dce",
+          [](std::shared_ptr<Graph>& g) {
+            return EliminateDeadCode(g->block()); // overload resolution
+          })
+      .def(
+          "_jit_pass_cse",
+          [](std::shared_ptr<Graph>& g) {
+            return EliminateCommonSubexpression(g); // overload resolution
+          })
+      .def(
+          "_jit_pass_remove_inplace_ops",
+          [](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
+      .def("_jit_pass_constant_pooling", ConstantPooling)
+      .def(
+          "_jit_pass_peephole",
+          [](const std::shared_ptr<Graph>& g, bool addmm_fusion_enabled) {
+            return PeepholeOptimize(g, addmm_fusion_enabled);
+          },
+          py::arg("graph"),
+          py::arg("addmm_fusion_enabled") = false)
+      .def(
+          "_jit_pass_canonicalize",
+          [](const std::shared_ptr<Graph>& g) { return Canonicalize(g); })
+      .def("_jit_pass_lint", LintGraph)
+      .def(
+          "_jit_pass_shape_analysis",
+          [](std::shared_ptr<Graph> graph,
+             std::vector<at::Tensor> inputs,
+             bool with_grad) {
+            setInputTypes(
+                *graph,
+                ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
+            PropagateInputShapes(graph);
+          })
+      .def(
+          "_jit_pass_complete_shape_analysis",
+          [](std::shared_ptr<Graph> graph, py::tuple inputs, bool with_grad) {
+            CompleteArgumentSpec spec(
+                with_grad,
+                evilDeprecatedBadCreateStackDoNotUse(inputs, graph->inputs()));
+            auto graph_inputs = graph->inputs();
+            JIT_ASSERT(spec.size() == graph_inputs.size());
+            for (size_t i = 0; i < graph_inputs.size(); ++i) {
+              graph_inputs[i]->setType(spec.at(i));
+            }
+            PropagateInputShapes(graph);
+          })
+      .def("_jit_pass_remove_expands", RemoveExpands)
+      .def("_jit_pass_erase_number_types", EraseNumberTypes)
+      .def("_jit_pass_prepare_division_for_onnx", PrepareDivisionForONNX)
+      .def("_jit_pass_loop_unrolling", UnrollLoops)
+      .def(
+          "_jit_pass_constant_propagation",
+          [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); })
+      .def("_jit_pass_erase_shape_information", EraseShapeInformation)
+      .def(
+          "_jit_pass_create_autodiff_subgraphs",
+          [](std::shared_ptr<Graph> graph) { CreateAutodiffSubgraphs(graph); })
+      .def(
+          "_jit_run_cpp_tests",
+          [] {
+            // We have to release the GIL inside this method, because if we
+            // happen to initialize the autograd engine in these tests, the
+            // newly spawned worker threads will try to initialize their
+            // PyThreadState*, and they need the GIL for this.
+            AutoNoGIL _no_gil;
+            return runJITCPPTests();
+          })
+      .def(
+          "_jit_flatten",
+          [](py::handle& obj) {
+            auto res = python::flatten(obj);
+            return std::make_pair(res.vars, res.desc);
+          })
+      .def(
+          "_jit_unflatten",
+          [](autograd::variable_list vars, python::IODescriptor& desc) {
+            return py::reinterpret_steal<py::object>(
+                python::unflatten(vars, desc));
+          })
+      .def("_jit_pass_onnx_block", BlockToONNX)
+      .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
+      .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
+      .def("_jit_pass_specialize_undef", specializeUndef)
+      .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
+      .def(
+          "_jit_differentiate",
+          [](Graph& g) {
+            // the python binding slightly differs in semantics
+            // it makes a copy of the input Graph, and works on that
+            // jit::differentiate mutates the input Graph
+            auto g_clone = g.copy();
+            return differentiate(g_clone);
+          })
+      .def(
+          "_jit_check_alias_annotation",
+          [](std::shared_ptr<Graph> g,
+             py::tuple args,
+             const std::string& unqualified_op_name) {
+            auto stack = toStack(args);
+            checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
+          });
 
   // NOLINTNEXTLINE(bugprone-unused-raii)
   py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
@@ -186,52 +215,41 @@ void initJITBindings(PyObject *module) {
       });
   // NOLINTNEXTLINE(bugprone-unused-raii)
   py::class_<ArgumentSpec>(m, "ArgumentSpec");
-  py::class_<Code>(m, "Code")
-      .def("grad_executors", [](Code& c) {
-        return py::make_iterator(c.grad_executors().begin(), c.grad_executors().end());
-      });
+  py::class_<Code>(m, "Code").def("grad_executors", [](Code& c) {
+    return py::make_iterator(
+        c.grad_executors().begin(), c.grad_executors().end());
+  });
 
   py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
-    .def_property_readonly("graph", [](ExecutionPlanState& s) {
-      return s.graph;
-    })
-    .def_property_readonly("code", [](ExecutionPlanState& s) {
-      return s.code;
-    });
+      .def_property_readonly(
+          "graph", [](ExecutionPlanState& s) { return s.graph; })
+      .def_property_readonly(
+          "code", [](ExecutionPlanState& s) { return s.code; });
 
   py::class_<Gradient>(m, "Gradient")
-    .def_property_readonly("f", [](Gradient& m) {
-      return m.f;
-    })
-    .def_property_readonly("df", [](Gradient& m) {
-      return m.df;
-    })
-    .def_property_readonly("f_real_outputs", [](Gradient& m) {
-      return m.f_real_outputs;
-    })
-    .def_property_readonly("df_input_vjps", [](Gradient& m) {
-      return m.df_input_vjps;
-    })
-    .def_property_readonly("df_input_captured_inputs", [](Gradient& m) {
-      return m.df_input_captured_inputs;
-    })
-    .def_property_readonly("df_input_captured_outputs", [](Gradient& m) {
-      return m.df_input_captured_outputs;
-    })
-    .def_property_readonly("df_output_vjps", [](Gradient& m) {
-      return m.df_output_vjps;
-    });
+      .def_property_readonly("f", [](Gradient& m) { return m.f; })
+      .def_property_readonly("df", [](Gradient& m) { return m.df; })
+      .def_property_readonly(
+          "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; })
+      .def_property_readonly(
+          "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; })
+      .def_property_readonly(
+          "df_input_captured_inputs",
+          [](Gradient& m) { return m.df_input_captured_inputs; })
+      .def_property_readonly(
+          "df_input_captured_outputs",
+          [](Gradient& m) { return m.df_input_captured_outputs; })
+      .def_property_readonly(
+          "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; });
 
   py::class_<GraphExecutorState>(m, "GraphExecutorState")
-    .def_property_readonly("graph", [](GraphExecutorState& s) {
-      return s.graph;
-    })
-    .def_property_readonly("execution_plans", [](GraphExecutorState& s) {
-      return s.execution_plans;
-    })
-    .def_property_readonly("fallback", [](GraphExecutorState& s) {
-      return s.fallback;
-    });
+      .def_property_readonly(
+          "graph", [](GraphExecutorState& s) { return s.graph; })
+      .def_property_readonly(
+          "execution_plans",
+          [](GraphExecutorState& s) { return s.execution_plans; })
+      .def_property_readonly(
+          "fallback", [](GraphExecutorState& s) { return s.fallback; });
 
   py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
       .def(
@@ -267,8 +285,9 @@ void initJITBindings(PyObject *module) {
           "get_debug_state",
           [](GraphExecutor& ge) { return ge.getDebugState(); })
       .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
-        const auto & graph = ge.graph();
-        auto stack = evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
+        const auto& graph = ge.graph();
+        auto stack =
+            evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
         {
           AutoNoGIL no_gil_guard;
           ge.run(stack);
@@ -280,79 +299,91 @@ void initJITBindings(PyObject *module) {
       .def(py::init<std::string>())
       .def(
           "write_record",
-          [](PyTorchStreamWriter& self, const std::string& name, const char* data, size_t size) {
-            return self.writeRecord(name, data, size);
-          })
+          [](PyTorchStreamWriter& self,
+             const std::string& name,
+             const char* data,
+             size_t size) { return self.writeRecord(name, data, size); })
       .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile);
 
   py::class_<PyTorchStreamReader>(m, "PyTorchFileReader")
       .def(py::init<std::string>())
-      .def(
-          "get_record",
-          [](PyTorchStreamReader& self, const std::string& key) {
-            at::DataPtr data;
-            size_t size;
-            std::tie(data, size) = self.getRecord(key);
-            return py::bytes(reinterpret_cast<const char*>(data.get()), size);
-          });
-
+      .def("get_record", [](PyTorchStreamReader& self, const std::string& key) {
+        at::DataPtr data;
+        size_t size;
+        std::tie(data, size) = self.getRecord(key);
+        return py::bytes(reinterpret_cast<const char*>(data.get()), size);
+      });
 
-  m.def("_jit_get_operation", [](const std::string& qualified_name) {
-    try {
-      auto symbol = Symbol::fromQualString(qualified_name);
-      auto operations = getAllOperatorsFor(symbol);
-      AT_CHECK(!operations.empty(), "No such operator ", qualified_name);
-      AT_CHECK(
-          operations.size() == 1,
-          "Found ", operations.size(), " overloads for operator ",
-          qualified_name, "! Overloads are not supported from Python.");
-      std::shared_ptr<Operator> op = operations[0];
-      AT_ASSERT(op != nullptr);
-      std::ostringstream docstring;
-      docstring << "Automatically bound operator '" << qualified_name
-                << "' with schema: " << op->schema();
-      return py::cpp_function([op](py::args args, py::kwargs kwargs) {
-        return invokeOperatorFromPython(
-            *op, std::move(args), std::move(kwargs));
-      }, py::name(qualified_name.c_str()), py::doc(docstring.str().c_str()));
-    } catch (const c10::Error& error) {
-      throw std::runtime_error(error.what_without_backtrace());
-    }
-  }, py::arg("qualified_name"));
+  m.def(
+      "_jit_get_operation",
+      [](const std::string& qualified_name) {
+        try {
+          auto symbol = Symbol::fromQualString(qualified_name);
+          auto operations = getAllOperatorsFor(symbol);
+          AT_CHECK(!operations.empty(), "No such operator ", qualified_name);
+          AT_CHECK(
+              operations.size() == 1,
+              "Found ",
+              operations.size(),
+              " overloads for operator ",
+              qualified_name,
+              "! Overloads are not supported from Python.");
+          std::shared_ptr<Operator> op = operations[0];
+          AT_ASSERT(op != nullptr);
+          std::ostringstream docstring;
+          docstring << "Automatically bound operator '" << qualified_name
+                    << "' with schema: " << op->schema();
+          return py::cpp_function(
+              [op](py::args args, py::kwargs kwargs) {
+                return invokeOperatorFromPython(
+                    *op, std::move(args), std::move(kwargs));
+              },
+              py::name(qualified_name.c_str()),
+              py::doc(docstring.str().c_str()));
+        } catch (const c10::Error& error) {
+          throw std::runtime_error(error.what_without_backtrace());
+        }
+      },
+      py::arg("qualified_name"));
 
   py::class_<FunctionSchema>(m, "FunctionSchema")
-  .def_property_readonly("name", [](FunctionSchema& self) { return self.name(); })
-  .def_property_readonly("arguments", [](FunctionSchema& self) { return self.arguments(); })
-  .def_property_readonly("returns", [](FunctionSchema& self) { return self.returns(); });
+      .def_property_readonly(
+          "name", [](FunctionSchema& self) { return self.name(); })
+      .def_property_readonly(
+          "arguments", [](FunctionSchema& self) { return self.arguments(); })
+      .def_property_readonly(
+          "returns", [](FunctionSchema& self) { return self.returns(); });
   py::class_<Argument>(m, "Argument")
-  .def_property_readonly("name", [](Argument& self) { return self.name(); })
-  .def_property_readonly("type", [](Argument& self) { return self.type(); })
-  .def_property_readonly("N", [](Argument& self) -> py::object {
-    return (self.N()) ? py::cast(*self.N()) :  py::none();
-  })
-  .def_property_readonly("default_value", [](Argument& self) -> py::object {
-    if(!self.default_value())
-      return py::none();
-    IValue v = *self.default_value();
-    return toPyObject(std::move(v));
-  });
+      .def_property_readonly("name", [](Argument& self) { return self.name(); })
+      .def_property_readonly("type", [](Argument& self) { return self.type(); })
+      .def_property_readonly(
+          "N",
+          [](Argument& self) -> py::object {
+            return (self.N()) ? py::cast(*self.N()) : py::none();
+          })
+      .def_property_readonly("default_value", [](Argument& self) -> py::object {
+        if (!self.default_value())
+          return py::none();
+        IValue v = *self.default_value();
+        return toPyObject(std::move(v));
+      });
   m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) {
     auto symbol = Symbol::fromQualString(qualified_name);
     auto operations = getAllOperatorsFor(symbol);
     return fmap(operations, [](const std::shared_ptr<Operator>& op) {
-        return op->schema();
-      });
+      return op->schema();
+    });
   });
 
   // NOLINTNEXTLINE(bugprone-unused-raii)
   py::class_<detail::Future>(m, "Future");
 
-  m.def("fork", [](script::Module &sm, py::args args) {
+  m.def("fork", [](script::Modulesm, py::args args) {
     // TODO: this is a fake stub
     return detail::Future();
   });
 
-  m.def("wait", [](detail::Future &fut) {
+  m.def("wait", [](detail::Futurefut) {
     // TODO: this is a fake stub
   });
 
@@ -368,4 +399,5 @@ void initJITBindings(PyObject *module) {
   initRegisterBatchOpsBindings(module);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index fbc902e..99b21d4 100644 (file)
@@ -1,7 +1,9 @@
 #pragma once
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-void initJITBindings(PyObject *module);
+void initJITBindings(PyObjectmodule);
 
-}}
+}
+} // namespace torch
index 2c45c19..d7e9028 100644 (file)
@@ -6,10 +6,10 @@
 #include <torch/csrc/autograd/profiler.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/ivalue.h>
-#include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/jit_exception.h>
 
@@ -25,7 +25,8 @@
 #include <utility>
 #include <vector>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // Before we translate to intepreter instructions, we do
 // some preprocessing of the graph to turn it into a form that is closer
@@ -36,11 +37,11 @@ namespace torch { namespace jit {
 // *. computes move_flags (see Outputs), and inserts
 // *  Drop nodes are inserted for any node that is unused to create a dummy use
 //    that will cause the interpreter to free the node.
-//    A drop node is just a node with no outputs that just pops its inputs off the stack,
-//    to ensure the interpreter release references to nodes that are never used.
-//    Drop nodes are also inserted when the last use of a node is in some conditionally
-//    run control flow (e.g. one side of an If) and the interpreter must free
-//    the node only after the control flow has reconverged
+//    A drop node is just a node with no outputs that just pops its inputs off
+//    the stack, to ensure the interpreter release references to nodes that are
+//    never used. Drop nodes are also inserted when the last use of a node is in
+//    some conditionally run control flow (e.g. one side of an If) and the
+//    interpreter must free the node only after the control flow has reconverged
 // Outputs are:
 // * graph - the post processed copy of g
 // * move_flags[n] - a list of booleans, one for each input,
@@ -58,24 +59,25 @@ Value* createTripCountConjunctiveCondition(
   // Emit initial comparison -- initial_trip_count < max_trip_count
   Value* initial_comparison_value =
       g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1))
-          ->output()->setType(BoolType::get());
+          ->output()
+          ->setType(BoolType::get());
 
   // Replace initial condition with logical `and` of trip count and
   // initial condition
   Value* new_cond =
       g->insertNode(
            g->create(aten::__and__, {initial_comparison_value, cond}, 1))
-          ->output()->setType(BoolType::get());
+          ->output()
+          ->setType(BoolType::get());
   return new_cond;
 }
 
 // this currently just _removes_ the trip count inputs and checks they are
 // unused. In the future they will be desugared into normal arithmetic to
 // provide a loop counter
-void desugarTripCounts(Block * b) {
-  for(auto n : b->nodes()) {
-
-    if(n->kind() == prim::Loop) {
+void desugarTripCounts(Block* b) {
+  for (auto n : b->nodes()) {
+    if (n->kind() == prim::Loop) {
       auto g = n->owningGraph();
       auto body_block = n->blocks()[0];
 
@@ -109,9 +111,10 @@ void desugarTripCounts(Block * b) {
         Value* const_one = g->insertConstant(1);
 
         Value* inc_trip_count =
-            g->insertNode(g->create(
-                    aten::add, {block_trip_count_input, const_one}, 1))
-             ->output()->setType(IntType::get());
+            g->insertNode(
+                 g->create(aten::add, {block_trip_count_input, const_one}, 1))
+                ->output()
+                ->setType(IntType::get());
         body_block->insertOutput(1, inc_trip_count);
 
         Value* body_cond = createTripCountConjunctiveCondition(
@@ -120,16 +123,17 @@ void desugarTripCounts(Block * b) {
         body_block->insertOutput(0, body_cond);
       }
     }
-    for(auto sb : n->blocks()) {
+    for (auto sb : n->blocks()) {
       desugarTripCounts(sb);
     }
   }
 }
 
-// removes all inputs and outputs to a graph, replacing them with Load Store nodes
-static void flattenIO(Graph & graph) {
+// removes all inputs and outputs to a graph, replacing them with Load Store
+// nodes
+static void flattenIO(Graph& graph) {
   auto load = graph.prependNode(graph.create(prim::Load, 0));
-  for(auto old_input : graph.inputs()) {
+  for (auto old_input : graph.inputs()) {
     auto nv = load->addOutput();
     nv->setType(old_input->type());
     old_input->replaceAllUsesWith(nv);
@@ -142,42 +146,40 @@ static void flattenIO(Graph & graph) {
     graph.eraseOutput(graph.outputs().size() - 1);
 }
 
-
 // insert Drop nodes to kill references for anything unused:
 // this can happen in a few places, e.g. when a node returns
 // many values but only one is used
 // a, b = foo()
 // return a
-void dropUnused(Block *b) {
+void dropUnused(Blockb) {
   auto createDropIfUnused = [&](ArrayRef<Value*> values) -> Node* {
     std::vector<Value*> to_drop;
-    for(auto v : values) {
-      if(v->uses().size() == 0)
+    for (auto v : values) {
+      if (v->uses().size() == 0)
         to_drop.push_back(v);
     }
-    if(to_drop.size() == 0)
+    if (to_drop.size() == 0)
       return nullptr;
     return b->owningGraph()->create(prim::Drop, to_drop, 0);
   };
 
-  if(auto d = createDropIfUnused(b->inputs())) {
+  if (auto d = createDropIfUnused(b->inputs())) {
     b->prependNode(d);
   }
-  for(auto n : b->nodes()) {
-    if(auto d = createDropIfUnused(n->outputs())) {
+  for (auto n : b->nodes()) {
+    if (auto d = createDropIfUnused(n->outputs())) {
       d->insertAfter(n);
     }
-    for(auto b : n->blocks())
+    for (auto b : n->blocks())
       dropUnused(b);
   }
 }
 
-
 // for each input, should we move rather than copy the inputs
-std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph & g) {
+std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph& g) {
   // struct to share common data structures
   struct FindLastUses {
-    Graph & graph;
+    Graph& graph;
     // have we seen this value, yet, if not, it is the last use of the value
     std::unordered_set<Value*> seen;
 
@@ -187,40 +189,39 @@ std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph & g) {
     // when the If/Loop exits. These are created and inserted on demand.
     std::unordered_map<Node*, Node*> drop_for_node;
 
-    FindLastUses(Graph & g)
-    : graph(g) {
+    FindLastUses(Graph& g) : graph(g) {
       scanBlock(graph.block());
     }
-    void scanBlock(Block * b) {
+    void scanBlock(Block* b) {
       scanNode(b->return_node());
-      for(auto n : b->nodes().reverse()) {
+      for (auto n : b->nodes().reverse()) {
         scanNode(n);
       }
     }
-    void scanNode(Node * n) {
-      for(auto b : n->blocks()) {
+    void scanNode(Node* n) {
+      for (auto b : n->blocks()) {
         scanBlock(b);
       }
       move_flags[n].resize(n->inputs().size());
-      // scan backwards so if a value is used twice in the list then it is a move
-      for(size_t i = n->inputs().size(); i > 0; --i) {
-        scanUse(n, i-1);
+      // scan backwards so if a value is used twice in the list then it is a
+      // move
+      for (size_t i = n->inputs().size(); i > 0; --i) {
+        scanUse(n, i - 1);
       }
     }
-    void scanUse(Node * n, size_t i) {
-      auto & move_flags_n = move_flags[n];
+    void scanUse(Node* n, size_t i) {
+      auto& move_flags_n = move_flags[n];
       auto v = n->inputs()[i];
       auto inserted = seen.insert(v).second;
-      if(!inserted) {
+      if (!inserted) {
         move_flags_n[i] = false;
         return;
       }
 
       // the last use of v may be in a nested block of an If or Loop statement
-      // find the node 'same_depth_node' at the same depth as the definition of v,
-      // and consider that node to be the last use of v.
-      // This ensures we do not delete nodes in nested scopes
-      // that may be executed multiple times
+      // find the node 'same_depth_node' at the same depth as the definition of
+      // v, and consider that node to be the last use of v. This ensures we do
+      // not delete nodes in nested scopes that may be executed multiple times
       // and that nodes used on one side of an if
       // but not the other get deleted regardless of the branch
       // e.g.
@@ -230,12 +231,13 @@ std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph & g) {
       // drop(a)
       // In other words, we find the first program point for v that
       // _reverse_ dominates the definition of v, and add a drop point there.
-      Node * same_depth_node = findOwnerInBlock(n, v->node()->owningBlock());
-      JIT_ASSERT(same_depth_node); // failure means v is not in scope for n, use lint!
+      Node* same_depth_node = findOwnerInBlock(n, v->node()->owningBlock());
+      JIT_ASSERT(
+          same_depth_node); // failure means v is not in scope for n, use lint!
 
       // In the case where v and n are in the same block, just mark
       // its move_flags to be true
-      if(same_depth_node == n) {
+      if (same_depth_node == n) {
         move_flags_n[i] = true;
         return;
       }
@@ -243,7 +245,8 @@ std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph & g) {
       // in the case where the use is nested in a block
       // add a Drop node after that block which will drop 'v'.
       move_flags_n[i] = false;
-      addToDropIfNotExists(findOrCreateDropInstructionForNode(same_depth_node), v);
+      addToDropIfNotExists(
+          findOrCreateDropInstructionForNode(same_depth_node), v);
     }
 
     // finds the node in block 'block' that contains in 'n'
@@ -252,16 +255,16 @@ std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph & g) {
     // n1: if <cond>:
     // n2:    b = a + a
     // findOwnerInBlock(n2, n0.block()) == n1
-    Node * findOwnerInBlock(Node * n, Block * block) {
-      while(n != nullptr && block != n->owningBlock()) {
+    Node* findOwnerInBlock(Node* n, Block* block) {
+      while (n != nullptr && block != n->owningBlock()) {
         n = n->owningBlock()->owningNode();
       }
       return n;
     }
 
-    Node * findOrCreateDropInstructionForNode(Node * n) {
+    Node* findOrCreateDropInstructionForNode(Node* n) {
       auto it = drop_for_node.find(n);
-      if(it == drop_for_node.end()) {
+      if (it == drop_for_node.end()) {
         auto drop_node = graph.create(prim::Drop, 0);
         drop_node->insertAfter(n);
         it = drop_for_node.emplace(n, drop_node).first;
@@ -269,10 +272,10 @@ std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph & g) {
       return it->second;
     }
 
-    void addToDropIfNotExists(Node * drop, Value * v) {
-      for(auto i : drop->inputs()) {
+    void addToDropIfNotExists(Node* drop, Value* v) {
+      for (auto i : drop->inputs()) {
         // we already accounted for this use
-        if(i == v)
+        if (i == v)
           return;
       }
       drop->addInput(v);
@@ -282,19 +285,18 @@ std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph & g) {
 
   return FindLastUses(g).move_flags;
 }
-} //namespace
+} // namespace
 
 // pre-processing that happens once per graph
 struct PreprocessGraph {
-  PreprocessGraph(Graph & g)
-  : graph(g.copy()) {
+  PreprocessGraph(Graph& g) : graph(g.copy()) {
     n_outputs = graph->outputs().size();
     desugarTripCounts(graph->block());
     flattenIO(*graph);
     dropUnused(graph->block());
     // fill in move_flags by scanning blocks;
     move_flags = findLastUses(*graph);
-    //TODO: desugar Loop trip counts, for now we drop trip counts
+    // TODO: desugar Loop trip counts, for now we drop trip counts
   }
   // Outputs of the preprocessing:
   std::shared_ptr<Graph> graph;
@@ -311,9 +313,13 @@ struct PreprocessGraph {
 // Note: this is currently unused but will probably be useful in the future,
 // so we keep it around
 struct ContainerTensor : public at::TensorImpl {
-public:
+ public:
   ContainerTensor()
-  : TensorImpl(at::UndefinedTensorId(), caffe2::TypeMeta(), nullptr, /* is_variable */ false) {}
+      : TensorImpl(
+            at::UndefinedTensorId(),
+            caffe2::TypeMeta(),
+            nullptr,
+            /* is_variable */ false) {}
 
   ~ContainerTensor() override = default;
   at::IntList sizes() const override {
@@ -335,7 +341,7 @@ public:
 // which are stored in the ListHandle struct
 // start is an offset into int_data of Code for ListHandle<int>
 // and bool_data of Code for ListHandle<bool>
-template<typename T>
+template <typename T>
 struct ListHandle {
   int start;
   int size;
@@ -358,24 +364,22 @@ struct Instruction {
   std::shared_ptr<SourceLocation> debug_location; // for error reporting
 };
 
-
 int relativeJump(int from_inst, int to_inst) {
   return to_inst - (from_inst + 1);
 }
 
 struct CodeImpl {
-  CodeImpl(const std::shared_ptr<Graph>& graph_)
-      : preprocess(*graph_) {
+  CodeImpl(const std::shared_ptr<Graph>& graph_) : preprocess(*graph_) {
     graph = preprocess.graph;
     insertNodesFromBlock(graph->block());
   }
 
   // jump when input is false
   void createJumpFalse(int from_inst, int to_inst) {
-    auto & inst = instructions[from_inst];
+    auto& inst = instructions[from_inst];
     JIT_ASSERT(inst.debug_name == prim::Placeholder);
     auto offset = relativeJump(from_inst, to_inst);
-    inst.callback = [offset](Stack & stack) {
+    inst.callback = [offset](Stack& stack) {
       auto t = pop(stack).toBool();
       return t ? 0 : offset;
     };
@@ -384,10 +388,10 @@ struct CodeImpl {
 
   // jump when input is true
   void createJumpTrue(int from_inst, int to_inst) {
-    auto & inst = instructions[from_inst];
+    auto& inst = instructions[from_inst];
     JIT_ASSERT(inst.debug_name == prim::Placeholder);
     auto offset = relativeJump(from_inst, to_inst);
-    inst.callback = [offset](Stack & stack) {
+    inst.callback = [offset](Stack& stack) {
       auto t = pop(stack).toBool();
       return t ? offset : 0;
     };
@@ -395,19 +399,17 @@ struct CodeImpl {
   }
 
   void createJump(int from_inst, int to_inst) {
-    auto & inst = instructions[from_inst];
+    auto& inst = instructions[from_inst];
     JIT_ASSERT(inst.debug_name == prim::Placeholder);
     auto offset = relativeJump(from_inst, to_inst);
-    inst.callback = [=](Stack & stack) {
-      return offset;
-    };
+    inst.callback = [=](Stack& stack) { return offset; };
     inst.debug_name = prim::Jump;
   }
 
   void insertNodesFromBlock(Block* block) {
-    for(auto node : block->nodes()) {
-      const auto & source_location = node->getSourceLocation();
-      switch(node->kind()) {
+    for (auto node : block->nodes()) {
+      const auto& source_location = node->getSourceLocation();
+      switch (node->kind()) {
         case prim::If: {
           // x = if c:
           //   <then_block>
@@ -426,17 +428,31 @@ struct CodeImpl {
           //   x = vt
           // end:
 
-          // prim::Placeholder instructions are replaced with branch instructions
-          // when the branch target locations are known
-          auto cond_branch = insertInstruction(prim::Placeholder, source_location, node->inputs(), moveFlags(node), {});
+          // prim::Placeholder instructions are replaced with branch
+          // instructions when the branch target locations are known
+          auto cond_branch = insertInstruction(
+              prim::Placeholder,
+              source_location,
+              node->inputs(),
+              moveFlags(node),
+              {});
           auto then_block = node->blocks()[0];
           auto else_block = node->blocks()[1];
           insertNodesFromBlock(else_block);
-          insertAssign(source_location,else_block->outputs(), moveFlags(else_block), node->outputs());
-          auto jump = insertInstruction(prim::Placeholder, source_location, {}, {}, {});
+          insertAssign(
+              source_location,
+              else_block->outputs(),
+              moveFlags(else_block),
+              node->outputs());
+          auto jump =
+              insertInstruction(prim::Placeholder, source_location, {}, {}, {});
           auto then_block_start = instructions.size();
           insertNodesFromBlock(then_block);
-          insertAssign(source_location, then_block->outputs(), moveFlags(then_block), node->outputs());
+          insertAssign(
+              source_location,
+              then_block->outputs(),
+              moveFlags(then_block),
+              node->outputs());
           createJump(jump, instructions.size());
           createJumpTrue(cond_branch, then_block_start);
         } break;
@@ -458,114 +474,140 @@ struct CodeImpl {
           auto body_block = node->blocks()[0];
 
           // before assign op: stack: ... <cond> <loop-carried-depdencies>
-          insertAssign(source_location, node->inputs(), moveFlags(node), body_block->inputs());
+          insertAssign(
+              source_location,
+              node->inputs(),
+              moveFlags(node),
+              body_block->inputs());
           // after assign op: stack: ... <cond>
           // cond_branch consumes <cond> from top of the stack
-          auto cond_branch = insertInstruction(prim::Placeholder, source_location,{}, {}, {});
+          auto cond_branch =
+              insertInstruction(prim::Placeholder, source_location, {}, {}, {});
           // after branch: stack: ...
 
           auto entry = instructions.size();
           insertNodesFromBlock(body_block);
           // before assign op: stack: ... <cond> <loop-carried-depdencies>
-          insertAssign(source_location, body_block->outputs(), moveFlags(body_block), body_block->inputs());
+          insertAssign(
+              source_location,
+              body_block->outputs(),
+              moveFlags(body_block),
+              body_block->inputs());
           // after assign op: stack: ... <cond>
-          auto cond_branch_end = insertInstruction(prim::Placeholder, source_location, {}, {}, {});
+          auto cond_branch_end =
+              insertInstruction(prim::Placeholder, source_location, {}, {}, {});
           // after branch: stack: ...
 
           aliasRegistersTo(node->outputs(), body_block->inputs());
           createJumpFalse(cond_branch, instructions.size());
           createJumpTrue(cond_branch_end, entry);
         } break;
-        default: {
-          insertInstruction(node);
-        } break;
+        default: { insertInstruction(node); } break;
       }
     }
   }
 
-  size_t insertInstruction(Node * n) {
-    auto inst = insertInstruction(n->kind(), n->getSourceLocation(), n->inputs(), moveFlags(n) , n->outputs());
+  size_t insertInstruction(Node* n) {
+    auto inst = insertInstruction(
+        n->kind(),
+        n->getSourceLocation(),
+        n->inputs(),
+        moveFlags(n),
+        n->outputs());
     instructions[inst].callback = getOperation(n);
     return inst;
   }
-  size_t insertInstruction(Symbol sym,
-                           std::shared_ptr<SourceLocation> debug_location,
-                                 ArrayRef<Value*> inputs,
-                                 ArrayRef<uint8_t> move_flags,
-                                 ArrayRef<Value*> outputs) {
+  size_t insertInstruction(
+      Symbol sym,
+      std::shared_ptr<SourceLocation> debug_location,
+      ArrayRef<Value*> inputs,
+      ArrayRef<uint8_t> move_flags,
+      ArrayRef<Value*> outputs) {
     instructions.emplace_back();
-    auto & inst = instructions.back();
+    auto& inst = instructions.back();
     inst.debug_name = sym;
     inst.debug_location = std::move(debug_location);
     listBegin(inst.inputs.values);
-    for(auto input : inputs) {
+    for (auto input : inputs) {
       listInsert(inst.inputs.values, getOrAllocateRegister(input, true));
     }
     listBegin(inst.inputs.free_flags);
-    for(auto flag : move_flags) {
+    for (auto flag : move_flags) {
       listInsert(inst.inputs.free_flags, flag);
     }
     listBegin(inst.outputs);
-    for(auto output : outputs) {
+    for (auto output : outputs) {
       listInsert(inst.outputs, getOrAllocateRegister(output));
     }
     return instructions.size() - 1;
   }
-  ArrayRef<uint8_t> moveFlags(Node * n) {
+  ArrayRef<uint8_t> moveFlags(Node* n) {
     return preprocess.move_flags.at(n);
   }
-  ArrayRef<uint8_t> moveFlags(Block *b) {
+  ArrayRef<uint8_t> moveFlags(Blockb) {
     return moveFlags(b->return_node());
   }
 
-  size_t insertAssign(std::shared_ptr<SourceLocation> debug_location, ArrayRef<Value*> inputs, ArrayRef<uint8_t> move_flags, ArrayRef<Value*> outputs) {
-    auto inst = insertInstruction(prim::Assign, std::move(debug_location),inputs, move_flags, outputs);
-    // This node effectively forwards its inputs into different places in a register list.
-    // We don't need to manipulate the stack in any way, because all inputs are also outputs,
-    // and the interpreter will take care of putting them in correct places.
+  size_t insertAssign(
+      std::shared_ptr<SourceLocation> debug_location,
+      ArrayRef<Value*> inputs,
+      ArrayRef<uint8_t> move_flags,
+      ArrayRef<Value*> outputs) {
+    auto inst = insertInstruction(
+        prim::Assign, std::move(debug_location), inputs, move_flags, outputs);
+    // This node effectively forwards its inputs into different places in a
+    // register list. We don't need to manipulate the stack in any way, because
+    // all inputs are also outputs, and the interpreter will take care of
+    // putting them in correct places.
     instructions[inst].callback = [](Stack& stack) { return 0; };
     return inst;
   }
 
   // helpers to build/access RegList objects
-  int get(const ListHandle<int> & list, int i)  const {
+  int get(const ListHandle<int>& list, int i) const {
     return int_data[list.start + i];
   }
-  bool get(const ListHandle<bool> & list, int i) const {
+  bool get(const ListHandle<bool>& list, int i) const {
     return bool_data[list.start + i];
   }
-  void listBegin(ListHandle<int> & list) {
+  void listBegin(ListHandle<int>& list) {
     list.start = int_data.size();
     list.size = 0;
   }
-  void listInsert(ListHandle<int> & list, int value) {
-    JIT_ASSERTM(list.start + list.size == (int)int_data.size(), "another list already started");
+  void listInsert(ListHandle<int>& list, int value) {
+    JIT_ASSERTM(
+        list.start + list.size == (int)int_data.size(),
+        "another list already started");
     int_data.push_back(value);
     list.size++;
   }
-  void listBegin(ListHandle<bool> & list) {
+  void listBegin(ListHandle<bool>& list) {
     list.start = bool_data.size();
     list.size = 0;
   }
-  void listInsert(ListHandle<bool> & list, int value) {
-    JIT_ASSERTM(list.start + list.size == (int)bool_data.size(), "another list already started");
+  void listInsert(ListHandle<bool>& list, int value) {
+    JIT_ASSERTM(
+        list.start + list.size == (int)bool_data.size(),
+        "another list already started");
     bool_data.push_back(value);
     list.size++;
   }
   // must be called before any new_allocations are used, otherwise they will
   // already have registers assigned
-  void aliasRegistersTo(ArrayRef<Value*> new_allocations, ArrayRef<Value*> existing_allocations) {
+  void aliasRegistersTo(
+      ArrayRef<Value*> new_allocations,
+      ArrayRef<Value*> existing_allocations) {
     JIT_ASSERT(new_allocations.size() == existing_allocations.size());
-    for(size_t i = 0; i < new_allocations.size(); ++i) {
+    for (size_t i = 0; i < new_allocations.size(); ++i) {
       auto n = new_allocations[i]->unique();
       auto e = existing_allocations[i]->unique();
       JIT_ASSERT(unique_to_reg.count(e) > 0 && unique_to_reg.count(n) == 0);
       unique_to_reg[n] = unique_to_reg[e];
     }
   }
-  int getOrAllocateRegister(Value * n, bool required = false) {
+  int getOrAllocateRegister(Value* n, bool required = false) {
     size_t u = n->unique();
-    if(unique_to_reg.count(u) > 0)
+    if (unique_to_reg.count(u) > 0)
       return unique_to_reg[u];
     JIT_ASSERT(!required);
     int r = register_size++;
@@ -576,7 +618,7 @@ struct CodeImpl {
   const std::vector<GraphExecutor*>& grad_executors() {
     if (!grad_executors_) {
       grad_executors_.emplace();
-      for (Instruction & instr : instructions) {
+      for (Instruction& instr : instructions) {
         if (auto executor = detail::getGradExecutor(instr.callback)) {
           grad_executors_->push_back(executor);
         }
@@ -585,33 +627,33 @@ struct CodeImpl {
     return *grad_executors_;
   }
 
-  void dumpInstruction(std::ostream & out, size_t pc) const {
-    auto writeList = [&](const ListHandle<int> & list) {
-      for(int i = 0; i < list.size; i++) {
-        if(i > 0)
+  void dumpInstruction(std::ostream& out, size_t pc) const {
+    auto writeList = [&](const ListHandle<int>& list) {
+      for (int i = 0; i < list.size; i++) {
+        if (i > 0)
           out << ", ";
         out << get(list, i);
       }
     };
-    auto writeUseList = [&](const UseList & list) {
-      for(int i = 0; i < list.values.size; i++) {
-        if(i > 0)
+    auto writeUseList = [&](const UseList& list) {
+      for (int i = 0; i < list.values.size; i++) {
+        if (i > 0)
           out << ", ";
-        if(get(list.free_flags, i))
+        if (get(list.free_flags, i))
           out << "move(" << get(list.values, i) << ")";
         else
           out << get(list.values, i);
       }
     };
-    auto & inst = instructions.at(pc);
+    auto& inst = instructions.at(pc);
     writeList(inst.outputs);
     // NB: debug names are the kind of operator used to select
     // dispatch
     out << " = " << inst.debug_name.toUnqualString() << " ";
     writeUseList(inst.inputs);
   }
-  void dump(std::ostream & out) const {
-    for(size_t i = 0; i < instructions.size(); ++i) {
+  void dump(std::ostream& out) const {
+    for (size_t i = 0; i < instructions.size(); ++i) {
       dumpInstruction(out, i);
       out << "\n";
     }
@@ -626,7 +668,8 @@ struct CodeImpl {
   c10::optional<std::vector<GraphExecutor*>> grad_executors_;
   PreprocessGraph preprocess;
 
-  std::unordered_map<size_t, int> unique_to_reg; // map from unique of nodes to register in register table
+  std::unordered_map<size_t, int>
+      unique_to_reg; // map from unique of nodes to register in register table
 
   friend struct InterpreterState;
   std::vector<Instruction> instructions;
@@ -640,12 +683,11 @@ struct CodeImpl {
 
 // InterpreterState state that and used to compute a Code
 struct InterpreterStateImpl : c10::intrusive_ptr_target {
-  InterpreterStateImpl(const Code & code)
-  : function(code.pImpl),
-    int_data(function->int_data.data()),
-    bool_data(function->bool_data),
-    registers(function->register_size) {
-  }
+  InterpreterStateImpl(const Code& code)
+      : function(code.pImpl),
+        int_data(function->int_data.data()),
+        bool_data(function->bool_data),
+        registers(function->register_size) {}
 
  private:
   c10::intrusive_ptr<InterpreterStateImpl> intrusive_from_this() {
@@ -654,60 +696,61 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
   }
 
   bool runImpl(Stack& stack) {
-    auto & instructions = function->instructions;
+    auto& instructions = function->instructions;
     size_t last = instructions.size();
 
     while (pc < last) {
-        // std::cout << "executing " << pc << ": ";
-        // function->dumpInstruction(std::cout, pc);
-        // std::cout << "\n";
-        auto & inst = instructions[pc];
-        try {
-          loadTensorsFromRegisters(inst.inputs, stack);
-          size_t new_pc = pc + 1 + inst.callback(stack);
-          for (int i = inst.outputs.size - 1; i >= 0; --i) {
-            int reg = get(inst.outputs, i);
-            registers[reg] = pop(stack);
-            // std::cout << "pop reg[" << reg << "];\n" << registers[reg] << "\n";
-          }
-          pc = new_pc;
-        } catch (Suspend& e) {
-          // wait() expects a single input
-          JIT_ASSERT(inst.inputs.values.size == 1);
-
-          getOrCreateFuture();
-
-          if (get(inst.inputs.free_flags, 0)) {
-            // make sure the register is not freed once we are waked up
-            registers[get(inst.inputs.values, 0)] = e.future;
-          }
-
-          // Make sure adding callback is the last step.
-          // Otherwise if e.future has completed,
-          // the current thread will continue running before it suspends.
-          InterpreterState state(intrusive_from_this());
-          e.future->addCallback([state]() {
-            c10::global_work_queue().run(
-                InterpreterContinuation(state, Stack()));
-          });
-
-          return true;
-        } catch (Future::FutureError& e) {
-          // Error from the forked thread.
-          auto msg = e.error_msg; // copy the error for each callback
-          handleError(std::move(msg), false);
-          return false;
-        } catch (std::exception& e) {
-          // Error from the current thread
-          bool is_jit_exception = dynamic_cast<JITException*>(&e);
-          if (instructions[pc].debug_location) {
-            handleError(instructions[pc].debug_location->wrapException(
-                e, "operation failed in interpreter"), is_jit_exception);
-          } else {
-            handleError(e.what(), is_jit_exception);
-          }
-          return false;
+      // std::cout << "executing " << pc << ": ";
+      // function->dumpInstruction(std::cout, pc);
+      // std::cout << "\n";
+      auto& inst = instructions[pc];
+      try {
+        loadTensorsFromRegisters(inst.inputs, stack);
+        size_t new_pc = pc + 1 + inst.callback(stack);
+        for (int i = inst.outputs.size - 1; i >= 0; --i) {
+          int reg = get(inst.outputs, i);
+          registers[reg] = pop(stack);
+          // std::cout << "pop reg[" << reg << "];\n" << registers[reg] << "\n";
         }
+        pc = new_pc;
+      } catch (Suspend& e) {
+        // wait() expects a single input
+        JIT_ASSERT(inst.inputs.values.size == 1);
+
+        getOrCreateFuture();
+
+        if (get(inst.inputs.free_flags, 0)) {
+          // make sure the register is not freed once we are waked up
+          registers[get(inst.inputs.values, 0)] = e.future;
+        }
+
+        // Make sure adding callback is the last step.
+        // Otherwise if e.future has completed,
+        // the current thread will continue running before it suspends.
+        InterpreterState state(intrusive_from_this());
+        e.future->addCallback([state]() {
+          c10::global_work_queue().run(InterpreterContinuation(state, Stack()));
+        });
+
+        return true;
+      } catch (Future::FutureError& e) {
+        // Error from the forked thread.
+        auto msg = e.error_msg; // copy the error for each callback
+        handleError(std::move(msg), false);
+        return false;
+      } catch (std::exception& e) {
+        // Error from the current thread
+        bool is_jit_exception = dynamic_cast<JITException*>(&e);
+        if (instructions[pc].debug_location) {
+          handleError(
+              instructions[pc].debug_location->wrapException(
+                  e, "operation failed in interpreter"),
+              is_jit_exception);
+        } else {
+          handleError(e.what(), is_jit_exception);
+        }
+        return false;
+      }
     }
     if (future) {
       auto num_outputs = function->preprocess.n_outputs;
@@ -762,22 +805,21 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
     }
   }
 
-  int get(const ListHandle<int> & list, int i) {
+  int get(const ListHandle<int>& list, int i) {
     return int_data[list.start + i];
   };
-  bool get(const ListHandle<bool> & list, int i) {
+  bool get(const ListHandle<bool>& list, int i) {
     return bool_data[list.start + i];
   }
-  void loadTensorsFromRegisters(const UseList & uses, Stack & stack) {
-    for(int i = 0; i < uses.values.size; i++) {
-      int reg = get(uses.values,i);
+  void loadTensorsFromRegisters(const UseList& uses, Stack& stack) {
+    for (int i = 0; i < uses.values.size; i++) {
+      int reg = get(uses.values, i);
       // std::cout << "push reg[" << reg << "];\n" << registers[reg] << "\n\n";
-      if(get(uses.free_flags,i)) {
+      if (get(uses.free_flags, i)) {
         stack.push_back(std::move(registers[reg]));
       } else {
         stack.push_back(registers[reg]);
       }
-
     }
   }
 
@@ -786,9 +828,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
   c10::intrusive_ptr<Future> future;
   std::shared_ptr<CodeImpl> function; // keep function alive
   // these are just copies of function to prevent indirections in interpreter
-  int * int_data;
-  const std::vector<bool> & bool_data;
-
+  int* int_data;
+  const std::vector<bool>& bool_data;
 
   // this holds all the tensors for this interpreter run
   // we don't bother minimizing the size of this vector, since the extra
@@ -797,31 +838,31 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
   // to make sure memory management happens efficiently.
 
   // We optimize for the case where derivatives are run with retain_graph=False
-  // in the case where it is true, then the interpreter and this array get copied
-  // if this every becomes a bottleneck then we _should_ consider minimizing the
-  // total number or register
+  // in the case where it is true, then the interpreter and this array get
+  // copied if this every becomes a bottleneck then we _should_ consider
+  // minimizing the total number or register
   std::vector<IValue> registers;
 
-  // single buffer for input/output calls to ATen functions, so that we do not reallocate
+  // single buffer for input/output calls to ATen functions, so that we do not
+  // reallocate
   Stack stack;
 };
 
-std::ostream & operator<<(std::ostream & out, const Code & code) {
+std::ostream& operator<<(std::ostream& out, const Code& code) {
   out << *code.pImpl->graph << "\n";
   code.pImpl->dump(out);
   return out;
 }
 
-Code::Code(const std::shared_ptr<Graph>& graph)
-    : pImpl(new CodeImpl(graph)) {}
+Code::Code(const std::shared_ptr<Graph>& graph) : pImpl(new CodeImpl(graph)) {}
 Code::~Code() = default;
 
 const std::vector<GraphExecutor*>& Code::grad_executors() {
   return pImpl->grad_executors();
 }
 
-InterpreterState::InterpreterState(const Code & code)
-  : pImpl(c10::make_intrusive<InterpreterStateImpl>(code)) {}
+InterpreterState::InterpreterState(const Code& code)
+    : pImpl(c10::make_intrusive<InterpreterStateImpl>(code)) {}
 InterpreterState::~InterpreterState() = default;
 
 void InterpreterState::run(Stack& stack) {
@@ -836,6 +877,8 @@ c10::intrusive_ptr<Future> InterpreterState::getFuture() {
   return static_cast<InterpreterStateImpl*>(pImpl.get())->getOrCreateFuture();
 }
 
-InterpreterState::InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
+InterpreterState::InterpreterState(
+    c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
     : pImpl(std::move(pImpl_)) {}
-}}
+} // namespace jit
+} // namespace torch
index facbd61..3ef761f 100644 (file)
@@ -1,18 +1,19 @@
 #pragma once
+#include <c10/util/Optional.h>
 #include <memory>
 #include <vector>
-#include <c10/util/Optional.h>
 
-#include <torch/csrc/jit/ivalue.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/ivalue.h>
 
 namespace at {
-  class Tensor;
+class Tensor;
 }
 namespace c10 {
 struct IValue;
 }
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // The interpreter run Graphs with Tensor inputs and Tensor outputs
 // a separate component in the autograd handles unwrapping and wrapping
@@ -27,8 +28,7 @@ struct Node;
 using Stack = std::vector<c10::IValue>;
 
 struct TORCH_API Code {
-  Code()
-    : pImpl(nullptr) {}
+  Code() : pImpl(nullptr) {}
   explicit Code(const std::shared_ptr<Graph>& graph);
   ~Code();
 
@@ -38,19 +38,20 @@ struct TORCH_API Code {
     return pImpl != nullptr;
   }
 
-private:
+ private:
   std::shared_ptr<CodeImpl> pImpl;
   friend struct InterpreterStateImpl;
-  friend std::ostream & operator<<(std::ostream & out, const Code & code);
+  friend std::ostream& operator<<(std::ostream& out, const Code& code);
 };
 
 struct InterpreterState {
-  InterpreterState(const Code & code);
+  InterpreterState(const Code& code);
   void run(Stack& stack);
   c10::intrusive_ptr<Future> runAsync(Stack& stack);
   c10::intrusive_ptr<Future> getFuture();
   ~InterpreterState();
-private:
+
+ private:
   InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl);
   // Ideally we should use c10::intrusive_ptr<InterpreterStateImpl> for pImpl;
   // but intrusive_ptr requires full definition of InterpreterStateImpl,
@@ -83,4 +84,5 @@ struct InterpreterContinuation {
   InterpreterState state;
   Stack stack;
 };
-}}
+} // namespace jit
+} // namespace torch
index f30e92a..17d640d 100644 (file)
@@ -1,13 +1,12 @@
 #include <torch/csrc/jit/ir.h>
 
-
-#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/autograd/function.h>
-#include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
+#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/script/schema_matching.h>
 
 #include <algorithm>
 #include <iostream>
@@ -19,7 +18,8 @@
 #include <unordered_set>
 #include <utility>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 // Constants relating to maintaining the topological index of nodes.
 //
 // Lower and upper bounds of the index. Inclusive range.
@@ -32,26 +32,27 @@ static constexpr topo_position_t kMidPoint = 0;
 //   - 2^(64-n) is the maximum number of appends to the end without reindex
 static constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */;
 
-// Sigh, see https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char
+// Sigh, see
+// https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char
 constexpr Symbol PythonOp::Kind;
 
-void printValueRef(std::ostream & out, const Value * n) {
+void printValueRef(std::ostream& out, const Value* n) {
   out << "%" << n->uniqueName();
 }
 
 // NB: This overload will become ambiguous with the one Caffe2 provides in its
 // logging, if they ever intersect.
 template <typename T>
-std::ostream& operator<<(std::ostream & out, const std::vector<T> & nodes) {
+std::ostream& operator<<(std::ostream& out, const std::vector<T>& nodes) {
   out << at::ArrayRef<T>{nodes};
   return out;
 }
 
 template <typename T>
-std::ostream& printValueRefs(std::ostream & out, const at::ArrayRef<T> & nodes) {
+std::ostream& printValueRefs(std::ostream& out, const at::ArrayRef<T>& nodes) {
   size_t i = 0;
-  for(auto n : nodes) {
-    if(i++ > 0)
+  for (auto n : nodes) {
+    if (i++ > 0)
       out << ", ";
     printValueRef(out, n);
   }
@@ -61,24 +62,28 @@ std::ostream& printValueRefs(std::ostream & out, const at::ArrayRef<T> & nodes)
 // Can't make these two overloads directly a template, it'll be ambiguous with
 // the global printer for operator<<.
 
-std::ostream& operator<<(std::ostream & out, const at::ArrayRef<const Value*> & nodes) {
+std::ostream& operator<<(
+    std::ostream& out,
+    const at::ArrayRef<const Value*>& nodes) {
   return printValueRefs(out, nodes);
 }
 
-std::ostream& operator<<(std::ostream & out, const at::ArrayRef<Value*> & nodes) {
+std::ostream& operator<<(std::ostream& out, const at::ArrayRef<Value*>& nodes) {
   return printValueRefs(out, nodes);
 }
 
 struct const_value_list_with_types {
   const ArrayRef<const Value*> values;
   bool use_newlines;
-  const_value_list_with_types(ArrayRef<const Value*> values, bool use_newlines = false)
-    : values(values), use_newlines(use_newlines) {}
+  const_value_list_with_types(
+      ArrayRef<const Value*> values,
+      bool use_newlines = false)
+      : values(values), use_newlines(use_newlines) {}
 };
-std::ostream& operator<<(std::ostream & out, const_value_list_with_types l) {
+std::ostream& operator<<(std::ostream& out, const_value_list_with_types l) {
   size_t i = 0;
-  for(auto n : l.values) {
-    if(i++ > 0) {
+  for (auto n : l.values) {
+    if (i++ > 0) {
       if (l.use_newlines) {
         // TODO: Indent here is hard-coded for "graph(": un-hard-code it
         out << "\n      ";
@@ -93,14 +98,17 @@ std::ostream& operator<<(std::ostream & out, const_value_list_with_types l) {
   return out;
 }
 
-void printAttributes(std::ostream & out, const Node * n, bool ignore_subgraph=false) {
+void printAttributes(
+    std::ostream& out,
+    const Node* n,
+    bool ignore_subgraph = false) {
   out << "[";
   auto names = n->attributeNames();
   int i = 0;
-  for(auto name : names) {
+  for (auto name : names) {
     if (ignore_subgraph && name == attr::Subgraph)
       continue;
-    if(i++ > 0)
+    if (i++ > 0)
       out << ", ";
     // TODO: debugging mode to see the qualifier.  We definitely
     // don't want to print the qualifier since it should always
@@ -113,46 +121,51 @@ void printAttributes(std::ostream & out, const Node * n, bool ignore_subgraph=fa
   out << "]";
 }
 
-static std::ostream & indent(std::ostream & out, size_t level) {
-  for(size_t i = 0; i < level; ++i)
+static std::ostream& indent(std::ostream& out, size_t level) {
+  for (size_t i = 0; i < level; ++i)
     out << "  ";
   return out;
 }
 
-std::ostream& printNode(std::ostream & out, size_t level, const Node * n, std::vector<const Node*> * groups) {
+std::ostream& printNode(
+    std::ostream& out,
+    size_t level,
+    const Node* n,
+    std::vector<const Node*>* groups) {
   auto outputs = n->outputs();
   indent(out, level) << const_value_list_with_types(outputs);
   out << " = ";
-  IR_IFM_CONST(n,PythonOp)
-    out << "^" << value->name();
-    value->writeScalars(out);
+  IR_IFM_CONST(n, PythonOp)
+  out << "^" << value->name();
+  value->writeScalars(out);
   IR_ELSE()
-    if(n->hasAttribute(attr::Subgraph) && groups) {
-      out << n->kind().toQualString() << "_" << groups->size();
-      if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) {
-        printAttributes(out, n, /*ignore_subgraph=*/true);
-      }
-      groups->push_back(n);
-    } else {
-      out << n->kind().toQualString();
-      if(n->hasAttributes()) {
-        printAttributes(out,n);
-      }
+  if (n->hasAttribute(attr::Subgraph) && groups) {
+    out << n->kind().toQualString() << "_" << groups->size();
+    if (n->numAttributes() > 1 && n->kind() != prim::DifferentiableGraph) {
+      printAttributes(out, n, /*ignore_subgraph=*/true);
+    }
+    groups->push_back(n);
+  } else {
+    out << n->kind().toQualString();
+    if (n->hasAttributes()) {
+      printAttributes(out, n);
     }
+  }
   IR_END()
   out << "(" << n->inputs() << ")";
   std::string scopeName = n->scopeName();
   if (scopeName.empty()) {
     out << "\n";
-  }
-  else {
+  } else {
     out << ", ";
     out << "scope: " << scopeName << "\n";
   }
-  for(size_t i = 0; i < n->blocks().size(); ++i) {
+  for (size_t i = 0; i < n->blocks().size(); ++i) {
     auto b = n->blocks()[i];
-    indent(out, level + 1) << "block" << i << "(" << const_value_list_with_types(b->inputs(), false) << ") {\n";
-    for(auto n : b->nodes()) {
+    indent(out, level + 1) << "block" << i << "("
+                           << const_value_list_with_types(b->inputs(), false)
+                           << ") {\n";
+    for (auto n : b->nodes()) {
       printNode(out, level + 2, n, groups);
     }
     indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
@@ -161,20 +174,21 @@ std::ostream& printNode(std::ostream & out, size_t level, const Node * n, std::v
   return out;
 }
 
-std::ostream& operator<<(std::ostream & out, const Node & n) {
+std::ostream& operator<<(std::ostream& out, const Node& n) {
   return printNode(out, 0, &n, nullptr);
 }
 
-std::ostream& operator<<(std::ostream & out, const Graph & g) {
+std::ostream& operator<<(std::ostream& out, const Graph& g) {
   out << "graph(" << const_value_list_with_types(g.inputs(), true) << ") {\n";
   std::vector<const Node*> groups;
-  for(auto n : g.nodes()) {
+  for (auto n : g.nodes()) {
     printNode(out, 1, n, &groups);
   }
   out << "  return (" << g.outputs() << ");\n}\n";
   size_t i = 0;
-  for(auto fg : groups) {
-    out << "with " << fg->kind().toQualString() << "_" <<i++ << " = " << *fg->g(attr::Subgraph);
+  for (auto fg : groups) {
+    out << "with " << fg->kind().toQualString() << "_" << i++ << " = "
+        << *fg->g(attr::Subgraph);
   }
   /*
   // Uncomment this to debug all_nodes issues
@@ -189,7 +203,7 @@ std::ostream& operator<<(std::ostream & out, const Graph & g) {
   return out;
 }
 
-std::ostream& Graph::prettyPrint(std::ostream & out) {
+std::ostream& Graph::prettyPrint(std::ostream& out) {
   std::vector<at::Tensor> tensor_table;
   PythonPrint(out, *this, tensor_table);
   return out;
@@ -204,8 +218,8 @@ static void checkSameDevice(const Node* node) {
   bool has_device = false;
   c10::optional<at::Device> device = c10::nullopt;
   auto checkValue = [&](const Value* v) {
-    if(CompleteTensorTypePtr type = v->type()->cast<CompleteTensorType>()) {
-      if(!has_device) {
+    if (CompleteTensorTypePtr type = v->type()->cast<CompleteTensorType>()) {
+      if (!has_device) {
         has_device = true;
         device = type->device();
       } else {
@@ -213,10 +227,10 @@ static void checkSameDevice(const Node* node) {
       }
     }
   };
-  for(auto input : node->inputs()) {
+  for (auto input : node->inputs()) {
     checkValue(input);
   }
-  for(auto output : node->outputs()) {
+  for (auto output : node->outputs()) {
     checkValue(output);
   }
 }
@@ -247,13 +261,15 @@ void Node::lint() const {
     for (auto input : inputs_) {
       // WARNING: O(n^2)
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-      JIT_ASSERT(std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) != input->uses_.end());
+      JIT_ASSERT(
+          std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) !=
+          input->uses_.end());
       JIT_ASSERT(graph_->all_nodes.count(this) == 1);
       i++;
     }
   }
 
-  for(auto o : outputs()) {
+  for (auto o : outputs()) {
     size_t i = 0;
     for (auto use : o->uses()) {
       // Use invariants
@@ -265,38 +281,37 @@ void Node::lint() const {
   }
 
   // Node subclass invariants
-  IR_IF(this,Constant)
-    JIT_ASSERT(inputs_.size() == 0);
+  IR_IF(this, Constant)
+  JIT_ASSERT(inputs_.size() == 0);
   IR_ELSEIF(Return)
-    // Return uses is zero
-    JIT_ASSERT(outputs().size() == 0);
+  // Return uses is zero
+  JIT_ASSERT(outputs().size() == 0);
   IR_ELSEIF(Param)
-    // Param inputs is zero
-    JIT_ASSERT(inputs_.size() == 0);
+  // Param inputs is zero
+  JIT_ASSERT(inputs_.size() == 0);
   IR_ELSEIFM_CONST(PythonOp)
-    // Python operator cconv is correct
-    size_t n_scalars = 0, n_tensors = 0;
-    for (auto c : value->cconv) {
-      if (c == 'c') {
-        n_scalars++;
-      } else if (c == 'd') {
-        n_tensors++;
-      } else {
-        JIT_ASSERT(0);
-      }
-      JIT_ASSERT(static_cast<bool>(value->pyobj));
+  // Python operator cconv is correct
+  size_t n_scalars = 0, n_tensors = 0;
+  for (auto c : value->cconv) {
+    if (c == 'c') {
+      n_scalars++;
+    } else if (c == 'd') {
+      n_tensors++;
+    } else {
+      JIT_ASSERT(0);
     }
-    JIT_ASSERT(n_scalars == value->scalar_args.size());
-    JIT_ASSERT(n_tensors == inputs_.size());
+    JIT_ASSERT(static_cast<bool>(value->pyobj));
+  }
+  JIT_ASSERT(n_scalars == value->scalar_args.size());
+  JIT_ASSERT(n_tensors == inputs_.size());
   IR_ELSEIF(Eval)
-    // TODO: add invariants
+  // TODO: add invariants
   // TODO: It's not good for these ops to be top-level, it makes cases longer.
   IR_ELSEIF(FusionGroup)
-    checkSameDevice(value);
-    // TODO: Typecheck the parameters
-    value->g(attr::Subgraph)->lint();
+  checkSameDevice(value);
+  // TODO: Typecheck the parameters
+  value->g(attr::Subgraph)->lint();
   IR_END()
-
 }
 
 // TODO: When lint fails, give better indication about which
@@ -317,35 +332,35 @@ void Graph::lint() const {
 
   struct LintScope {
     LintScope() = default;
-    LintScope(std::unique_ptr<LintScope> parent)
-    : parent(std::move(parent)) {}
-    bool contains(const Value * v) {
+    LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {}
+    bool contains(const Value* v) {
       return values.count(v) > 0 || (parent && parent->contains(v));
     }
-    bool contains(const Node * n) {
+    bool contains(const Node* n) {
       return nodes.count(n) > 0 || (parent && parent->contains(n));
     }
-    void insert(const Value * v) {
+    void insert(const Value* v) {
       JIT_ASSERT(!contains(v));
       values.insert(v);
     }
-    void insert(const Node * n) {
+    void insert(const Node* n) {
       JIT_ASSERT(!contains(n));
       nodes.insert(n);
     }
     std::unique_ptr<LintScope> parent;
-  private:
+
+   private:
     std::unordered_set<const Value*> values;
     std::unordered_set<const Node*> nodes;
   };
   // Struct enables mutual recursion in linting methods.
   // Putting it inside Graph::lint enables access to private Graph members
   struct LintImpl {
-    LintImpl(const Graph & g)
-    : g(g)
-    , scope(new LintScope())
-    , all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
-    const Graph & g;
+    LintImpl(const Graph& g)
+        : g(g),
+          scope(new LintScope()),
+          all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
+    const Graph& g;
     std::unique_ptr<LintScope> scope;
     std::unordered_set<size_t> seen_uniques;
     std::unordered_map<const Node*, int64_t> anticipated_uses;
@@ -355,13 +370,13 @@ void Graph::lint() const {
     void check_value(const Value* v) {
       scope->insert(v);
       auto b2 = seen_uniques.insert(v->unique());
-      JIT_ASSERT(b2.second);  // insertion took place
+      JIT_ASSERT(b2.second); // insertion took place
       JIT_ASSERT(v->unique() < g.next_unique_);
 
       for (auto use : v->uses()) {
         JIT_ASSERT(!scope->contains(use.user));
         JIT_ASSERT(g.all_nodes.count(use.user) == 1);
-        anticipated_uses[use.user]++;  // int default constructs to 0
+        anticipated_uses[use.user]++; // int default constructs to 0
       }
     }
     void check_node(const Node* n) {
@@ -370,24 +385,25 @@ void Graph::lint() const {
           JIT_ASSERTM(0, input->unique(), " not in scope");
         }
       }
-      JIT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
-      anticipated_uses[n] = -1;  // we saw the anticipated user!
+      JIT_ASSERT(
+          anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
+      anticipated_uses[n] = -1; // we saw the anticipated user!
       scope->insert(n);
-      for(auto block : n->blocks()) {
+      for (auto block : n->blocks()) {
         std::unique_ptr<LintScope> new_scope(new LintScope(std::move(scope)));
         scope = std::move(new_scope);
         check_block(block);
         scope = std::move(scope->parent);
       }
       size_t i = 0;
-      for(auto o : n->outputs()) {
+      for (auto o : n->outputs()) {
         JIT_ASSERT(o->node() == n);
         JIT_ASSERT(i++ == o->offset_);
         check_value(o);
       }
       n->lint();
     }
-    void check_block(const Block *b) {
+    void check_block(const Blockb) {
       // Check topological ordering
       JIT_ASSERT(b->param_node()->isBefore(*b->nodes().begin()));
       auto curNode = *b->nodes().begin();
@@ -417,10 +433,10 @@ void Graph::lint() const {
       // - only one return node???
 
       node_set nodes_set(ALL_OF(b->nodes()));
-      node_set inputs_set {b->input_};
-      node_set output_set {b->output_};
-      // TODO: Make a more type safe std::includes wrapper which disallows use on
-      // non-ordered containers
+      node_set inputs_set{b->input_};
+      node_set output_set{b->output_};
+      // TODO: Make a more type safe std::includes wrapper which disallows use
+      // on non-ordered containers
       JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set)));
       JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set)));
       JIT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set)));
@@ -430,7 +446,8 @@ void Graph::lint() const {
       sum_set.insert(ALL_OF(output_set));
     }
     void check_graph() {
-      node_set all_nodes_set(ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered*
+      node_set all_nodes_set(
+          ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered*
 
       check_block(g.block_);
       for (auto kv : anticipated_uses) {
@@ -471,30 +488,30 @@ void Block::reIndexTopology() {
   }
 }
 
-void Block::cloneFrom(Block * src, std::function<Value*(Value*)> value_map) {
+void Block::cloneFrom(Block* src, std::function<Value*(Value*)> value_map) {
   std::unordered_map<Value*, Value*> local_map;
-  auto env = [&](Value * v) {
+  auto env = [&](Value* v) {
     auto it = local_map.find(v);
-    if(it != local_map.end())
+    if (it != local_map.end())
       return it->second;
     return value_map(v);
   };
 
   auto graph = owningGraph();
-  for(auto input : src->inputs()) {
+  for (auto input : src->inputs()) {
     local_map[input] = this->addInput()->copyMetadata(input);
   }
 
-  for(auto node : src->nodes()) {
+  for (auto node : src->nodes()) {
     auto new_node = this->appendNode(graph->createClone(node, env));
-    for(size_t i = 0; i < node->outputs().size(); ++i) {
+    for (size_t i = 0; i < node->outputs().size(); ++i) {
       auto oo = node->outputs()[i];
       auto no = new_node->outputs()[i];
       local_map[oo] = no;
       no->copyMetadata(oo);
     }
   }
-  for(auto output : src->outputs()) {
+  for (auto output : src->outputs()) {
     this->registerOutput(env(output));
   }
 }
@@ -503,9 +520,10 @@ void Block::destroy() {
   // we cannot destroy the output because it is used as the sentinel
   // for the nodes() list and has to remain valid for the loop
   output_->removeAllInputs();
-  for(auto it = this->nodes().reverse().begin(),
-      end = this->nodes().reverse().end();
-      it != end; ++it) {
+  for (auto it = this->nodes().reverse().begin(),
+            end = this->nodes().reverse().end();
+       it != end;
+       ++it) {
     it.destroyCurrent();
   }
   output_->destroy();
@@ -532,38 +550,41 @@ std::string Value::uniqueNameBase() const {
   std::string name_base = name;
   auto last_dot_pos = name.find_last_of('.');
   if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
-    if (name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) {
+    if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
+        std::string::npos) {
       name_base = name.substr(0, last_dot_pos);
     }
   }
   return name_base;
 }
 
-Value* Value::setUniqueName(const std::string & name) {
-  if (name.size() > 0 && name.find_first_not_of("0123456789") == std::string::npos) {
+Value* Value::setUniqueName(const std::string& name) {
+  if (name.size() > 0 &&
+      name.find_first_not_of("0123456789") == std::string::npos) {
     throw std::runtime_error("names may not be integers: " + name);
   }
 
-  auto & names = node()->owningGraph()->unique_names_;
+  auto& names = node()->owningGraph()->unique_names_;
 
   // clear any old name from the map
-  if(hasUniqueName()) {
+  if (hasUniqueName()) {
     names.erase(unique_name_);
     unique_name_ = "";
   }
 
   // allow "" to clear the uniquename
-  if(name == "")
+  if (name == "")
     return this;
 
   // if someone else has this name, then rename the other value
   auto old_owner_of_name = names.find(name);
-  if(old_owner_of_name != names.end()) {
+  if (old_owner_of_name != names.end()) {
     size_t suffix = 1;
     std::string name_base = name;
     auto last_dot_pos = name.find_last_of('.');
     if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
-      if (name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) {
+      if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
+          std::string::npos) {
         suffix = std::stoll(name.substr(last_dot_pos + 1));
         name_base = name.substr(0, last_dot_pos);
       }
@@ -573,7 +594,7 @@ Value* Value::setUniqueName(const std::string & name) {
       std::stringstream ss;
       ss << name_base << "." << suffix++;
       replacement_name = ss.str();
-    } while(names.count(replacement_name) > 0);
+    } while (names.count(replacement_name) > 0);
     old_owner_of_name->second->setUniqueName(replacement_name);
   }
 
@@ -582,14 +603,14 @@ Value* Value::setUniqueName(const std::string & name) {
   return this;
 }
 
-Value* Value::copyMetadata(Value * from) {
+Value* Value::copyMetadata(Value* from) {
   setType(from->type());
   if (from->hasUniqueName())
     setUniqueName(from->uniqueName());
   return this;
 }
 
-void Value::replaceFirstUseWith(Value * newValue) {
+void Value::replaceFirstUseWith(Value* newValue) {
   JIT_ASSERT(owningGraph() == newValue->owningGraph());
   auto u = uses()[0];
   u.user->inputs_[u.offset] = newValue;
@@ -597,7 +618,7 @@ void Value::replaceFirstUseWith(Value * newValue) {
   uses_.erase(uses_.begin());
 }
 
-void Value::replaceAllUsesWith(Value * newValue) {
+void Value::replaceAllUsesWith(Value* newValue) {
   while (!uses().empty()) {
     replaceFirstUseWith(newValue);
   }
@@ -611,7 +632,8 @@ size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
       return i;
     }
   }
-  throw std::runtime_error(std::string("Couldn't find an argument called ") + name.toQualString());
+  throw std::runtime_error(
+      std::string("Couldn't find an argument called ") + name.toQualString());
 }
 
 c10::optional<IValue> Node::get(Symbol name) const {
@@ -622,10 +644,14 @@ Value* Node::namedInput(Symbol name) const {
   return input(findArgument(schema(), name));
 }
 
-bool Node::matches(const char *signature_literal, at::ArrayRef<Symbol> const_inputs) const {
-  if (!sig(signature_literal).matches(this)) return false;
+bool Node::matches(
+    const char* signature_literal,
+    at::ArrayRef<Symbol> const_inputs) const {
+  if (!sig(signature_literal).matches(this))
+    return false;
   for (Symbol s : const_inputs) {
-    if (!is_constant(s)) return false;
+    if (!is_constant(s))
+      return false;
   }
   return true;
 }
@@ -639,8 +665,8 @@ void Node::findSchema() const {
 }
 
 const FunctionSchema* Node::maybeSchema() const {
-  if(!schema_) {
-    if(auto op = findOperatorFor(this)) {
+  if (!schema_) {
+    if (auto op = findOperatorFor(this)) {
       schema_ = &op->schema();
     }
   }
@@ -649,38 +675,38 @@ const FunctionSchema* Node::maybeSchema() const {
 
 bool Node::isNondeterministic() const {
   static const OperatorSet nondeterministic_ops = {
-    "aten::dropout(Tensor input, float p, bool train) -> Tensor",
-    "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
-    "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
-    "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
-    "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
-    "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
-    "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
-    "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
-    "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
-    "aten::poisson(Tensor self, Generator? generator) -> Tensor",
-    "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-    "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
-    "aten::rand(int[] size, *, int dtype, int layout, Device device) -> Tensor",
-    "aten::rand_like(Tensor self) -> Tensor",
-    "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-    "aten::randint(int high, int[] size, *, int dtype, int layout, Device device) -> Tensor",
-    "aten::randint(int low, int high, int[] size, *, int dtype, int layout, Device device) -> Tensor",
-    "aten::randint_like(Tensor self, int high) -> Tensor",
-    "aten::randint_like(Tensor self, int low, int high) -> Tensor",
-    "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
-    "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
-    "aten::randn(int[] size, *, int dtype, int layout, Device device) -> Tensor",
-    "aten::randn_like(Tensor self) -> Tensor",
-    "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
-    "aten::randperm(int n, *, int dtype, int layout, Device device) -> Tensor"
-  };
+      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
+      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
+      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
+      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
+      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
+      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
+      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
+      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
+      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
+      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
+      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
+      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
+      "aten::rand(int[] size, *, int dtype, int layout, Device device) -> Tensor",
+      "aten::rand_like(Tensor self) -> Tensor",
+      "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
+      "aten::randint(int high, int[] size, *, int dtype, int layout, Device device) -> Tensor",
+      "aten::randint(int low, int high, int[] size, *, int dtype, int layout, Device device) -> Tensor",
+      "aten::randint_like(Tensor self, int high) -> Tensor",
+      "aten::randint_like(Tensor self, int low, int high) -> Tensor",
+      "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
+      "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
+      "aten::randn(int[] size, *, int dtype, int layout, Device device) -> Tensor",
+      "aten::randn_like(Tensor self) -> Tensor",
+      "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
+      "aten::randperm(int n, *, int dtype, int layout, Device device) -> Tensor"};
 
   if (nondeterministic_ops.find(this) == nullptr) {
     return false;
   }
   // Dropout with train = False is deterministic
-  if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") && is_constant(attr::train) && !get<bool>(attr::train).value()) {
+  if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") &&
+      is_constant(attr::train) && !get<bool>(attr::train).value()) {
     return false;
   }
   return true;
@@ -753,13 +779,13 @@ void Node::assignTopoPosition() {
   }
 }
 
-Node::Node(Graph * graph_, NodeKind kind_) :
-  kind_(kind_),
-  graph_(graph_),
-  owning_block_(nullptr),
-  scope_(graph_->current_scope_),
-  schema_(nullptr),
-  topo_position_(0) {
+Node::Node(Graph* graph_, NodeKind kind_)
+    : kind_(kind_),
+      graph_(graph_),
+      owning_block_(nullptr),
+      scope_(graph_->current_scope_),
+      schema_(nullptr),
+      topo_position_(0) {
   graph_->all_nodes.emplace(this);
 }
 
@@ -767,15 +793,15 @@ void Node::eraseOutput(size_t i) {
   JIT_ASSERT(i < outputs_.size());
   JIT_ASSERT(outputs_[i]->uses().empty());
   schema_ = nullptr;
-  Value * n = outputs_[i];
+  Value* n = outputs_[i];
   outputs_.erase(outputs_.begin() + i);
   owningGraph()->freeValue(n);
-  for(size_t j = i; j < outputs_.size(); j++) {
+  for (size_t j = i; j < outputs_.size(); j++) {
     outputs_[j]->offset_--;
   }
 }
 
-Block * Node::addBlock() {
+Block* Node::addBlock() {
   schema_ = nullptr;
   blocks_.push_back(new Block(owningGraph(), this));
   return blocks_.back();
@@ -784,33 +810,33 @@ Block * Node::addBlock() {
 void Node::eraseBlock(size_t i) {
   JIT_ASSERT(i < blocks_.size());
   schema_ = nullptr;
-  Block * n = blocks_[i];
+  Block* n = blocks_[i];
   blocks_.erase(blocks_.begin() + i);
   n->destroy();
 }
 
 void Node::destroy() {
-  while(!outputs().empty())
+  while (!outputs().empty())
     eraseOutput(outputs().size() - 1);
-  while(!blocks().empty())
+  while (!blocks().empty())
     eraseBlock(blocks().size() - 1);
   removeAllInputs();
-  if(inBlockList())
+  if (inBlockList())
     removeFromList();
   graph_->freeNode(this);
 }
 
-void Node::cloneFrom(Node * s) {
+void Node::cloneFrom(Node* s) {
   setSourceLocation(s->getSourceLocation());
-  if(s->scope_ && !s->scope_->isBlank())
+  if (s->scope_ && !s->scope_->isBlank())
     scope_ = s->scope_;
   copyAttributes(*s);
 }
 
-void Node::replaceAllUsesWith(Node * n) {
+void Node::replaceAllUsesWith(Node* n) {
   JIT_ASSERT(outputs().size() == n->outputs().size());
   size_t nOutputs = outputs().size();
-  for(size_t i = 0; i < nOutputs; i++) {
+  for (size_t i = 0; i < nOutputs; i++) {
     outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
   }
 }
@@ -834,7 +860,7 @@ Value* Node::insertInput(size_t i, Value* value) {
   return value;
 }
 
-Value* Node::addInput(Value * value) {
+Value* Node::addInput(Value* value) {
   JIT_ASSERT(graph_ == value->owningGraph());
   schema_ = nullptr;
   value->uses_.emplace_back(this, inputs_.size());
@@ -842,22 +868,22 @@ Value* Node::addInput(Value * value) {
   return value;
 }
 
-Value* Node::replaceInput(size_t i, Value * newValue) {
+Value* Node::replaceInput(size_t i, Value* newValue) {
   JIT_ASSERT(newValue->owningGraph() == graph_);
   schema_ = nullptr;
-  Value * old = dropInput(i);
+  Value* old = dropInput(i);
   inputs_[i] = newValue;
   newValue->uses_.emplace_back(this, i);
   return old;
 }
 
-void Node::replaceInputWith(Value * from, Value * to) {
+void Node::replaceInputWith(Value* from, Value* to) {
   JIT_ASSERT(from->owningGraph() == graph_);
   JIT_ASSERT(to->owningGraph() == graph_);
   schema_ = nullptr;
   size_t i = 0;
-  for(auto input : inputs()) {
-    if(input == from)
+  for (auto input : inputs()) {
+    if (input == from)
       replaceInput(i, to);
     i++;
   }
@@ -914,28 +940,27 @@ bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
   }
   // should never reach here, since both nodes are ultimately in the same graph
   JIT_ASSERT(false);
-
 }
 
-bool Node::isBefore(const Node * n) const {
+bool Node::isBefore(const Node* n) const {
   return isBeforeOrAfter(n, MoveSide::BEFORE);
 }
 
-bool Node::isAfter(const Node * n) const {
+bool Node::isAfter(const Node* n) const {
   return isBeforeOrAfter(n, MoveSide::AFTER);
 }
 
-Node* Node::insertBefore(Node * n) {
+Node* Node::insertBefore(Node* n) {
   JIT_ASSERT(n->inBlockList());
   insertAfter(n->prev());
   return this;
 }
 
-Node* Node::insertAfter(Node * n) {
+Node* Node::insertAfter(Node* n) {
   JIT_ASSERT(!inBlockList() && n->inBlockList());
   JIT_ASSERT(n->owningBlock());
   this->owning_block_ = n->owningBlock();
-  Node * next = n->next();
+  Node* next = n->next();
   n->next() = this;
   this->prev() = n;
   this->next() = next;
@@ -968,8 +993,7 @@ bool Node::couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb) {
 namespace {
 struct WorkingSet {
  public:
-  explicit WorkingSet(Node* mover, const AliasDb& aliasDb)
-      : aliasDb_(aliasDb) {
+  explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) {
     add(mover);
   }
 
@@ -1141,7 +1165,11 @@ struct WorkingSet {
 // node at a time. When we can't move past a node (because it depends on the
 // working set), then add it to the working set and keep moving until we hit
 // `moveAfter`.
-bool Node::tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun) {
+bool Node::tryMove(
+    Node* movePoint,
+    MoveSide moveSide,
+    const AliasDb& aliasDb,
+    bool dryRun) {
   JIT_ASSERT(this->inBlockList() && movePoint->inBlockList());
   JIT_ASSERT(this->owningBlock() == movePoint->owningBlock());
   if (this == movePoint) {
@@ -1241,12 +1269,12 @@ void Node::move(Node* movePoint, MoveSide moveSide) {
   }
 }
 
-void Node::moveAfter(Node * n) {
+void Node::moveAfter(Node* n) {
   removeFromList();
   insertAfter(n);
 }
 
-void Node::moveBefore(Node * n) {
+void Node::moveBefore(Node* n) {
   removeFromList();
   insertBefore(n);
 }
@@ -1256,7 +1284,7 @@ void Node::removeInput(size_t i) {
   dropInput(i);
   // everything after this input shifts left,
   // so we need to update their use offsets to match
-  for(size_t j = i+1; j < inputs_.size(); j++) {
+  for (size_t j = i + 1; j < inputs_.size(); j++) {
     auto it = findUseForInput(j);
     it->offset--;
   }
@@ -1265,13 +1293,13 @@ void Node::removeInput(size_t i) {
 
 void Node::removeAllInputs() {
   schema_ = nullptr;
-  for(size_t i = 0; i < inputs().size(); ++i)
+  for (size_t i = 0; i < inputs().size(); ++i)
     dropInput(i);
   inputs_.clear();
 }
 
 use_list::iterator Node::findUseForInput(size_t i) {
-  auto & input_uses = inputs_[i]->uses_;
+  auto& input_uses = inputs_[i]->uses_;
   // O(N) on the use list, but unless we get nodes with +100 uses
   // vector traversal still is probably faster than linked list
   auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
@@ -1291,8 +1319,8 @@ Value* Node::dropInput(size_t i) {
 void Node::removeFromList() {
   JIT_ASSERT(inBlockList());
   this->owning_block_ = nullptr;
-  Node * next = this->next();
-  Node * prev = this->prev();
+  Node* next = this->next();
+  Node* prev = this->prev();
   prev->next() = next;
   next->prev() = prev;
   this->next() = nullptr;
@@ -1300,7 +1328,8 @@ void Node::removeFromList() {
 }
 
 inline const SourceRange& fakeRange() {
-  static SourceRange range(std::make_shared<std::string>("<internally-created-node>"), 0, 1);
+  static SourceRange range(
+      std::make_shared<std::string>("<internally-created-node>"), 0, 1);
   return range;
 }
 
@@ -1322,14 +1351,17 @@ Value* Graph::insert(
 Node* Graph::create(NodeKind kind, size_t num_outputs) {
   // NB: Node constructor adds node to all_nodes
   auto n = new Node(this, kind);
-  for(size_t i = 0; i < num_outputs; i++)
+  for (size_t i = 0; i < num_outputs; i++)
     n->addOutput();
   return n;
 }
 
-Node* Graph::create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs) {
+Node* Graph::create(
+    NodeKind kind,
+    ArrayRef<Value*> inputs,
+    size_t num_outputs) {
   auto n = create(kind, num_outputs);
-  for(auto i : inputs)
+  for (auto i : inputs)
     n->addInput(i);
   return n;
 }
@@ -1339,14 +1371,14 @@ Node* Graph::createUndefined() {
 }
 
 Node* Graph::createNone(TypePtr typ) {
-  Node * n = create(prim::None);
+  Node* n = create(prim::None);
   n->output()->setType(OptionalType::create(std::move(typ)));
   return n;
 }
 
-Node * Graph::createFusionGroup() {
+Node* Graph::createFusionGroup() {
   auto n = create(prim::FusionGroup, 0);
-  n->g_(attr::Subgraph,std::make_shared<Graph>(current_scope()));
+  n->g_(attr::Subgraph, std::make_shared<Graph>(current_scope()));
   return n;
 }
 
@@ -1358,16 +1390,16 @@ Node* Graph::createTuple(at::ArrayRef<Value*> values) {
   return n;
 }
 
-Node* Graph::createTupleUnpack(Value * v) {
+Node* Graph::createTupleUnpack(Value* v) {
   TupleTypePtr tt = v->type()->expect<TupleType>();
   auto n = create(prim::TupleUnpack, {v}, 0);
-  for(auto & element : tt->elements()) {
+  for (auto& element : tt->elements()) {
     n->addOutput()->setType(element);
   }
   return n;
 }
 
-Node* Graph::createTupleIndex(Value * tup, int64_t index) {
+Node* Graph::createTupleIndex(Value* tup, int64_t index) {
   auto n = create(prim::TupleIndex, {tup});
   n->i_(attr::index, index);
   auto tuple_type = tup->type()->expect<TupleType>();
@@ -1375,7 +1407,7 @@ Node* Graph::createTupleIndex(Value * tup, int64_t index) {
   return n;
 }
 
-Node* Graph::createTupleSlice(Value * tup, int64_t beg, int64_t end) {
+Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) {
   auto n = create(prim::TupleSlice, {tup});
   auto tuple_type = tup->type()->expect<TupleType>();
   n->i_(attr::beg, beg);
@@ -1391,13 +1423,13 @@ Node* Graph::createTupleSlice(Value * tup, int64_t beg, int64_t end) {
 
 Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
   auto n = create(prim::ListConstruct, values);
-  for(const auto & v : values) {
+  for (const auto& v : values) {
     JIT_ASSERT(v->type()->isSubtypeOf(elem_type));
   }
   n->output()->setType(ListType::create(elem_type));
   return n;
 }
-Node* Graph::createListUnpack(Value *v, size_t size) {
+Node* Graph::createListUnpack(Valuev, size_t size) {
   ListTypePtr list_type = v->type()->expect<ListType>();
   TypePtr elem_type = list_type->getElementType();
   auto n = create(prim::ListUnpack, {v}, 0);
@@ -1409,7 +1441,7 @@ Node* Graph::createListUnpack(Value *v, size_t size) {
 
 Node* Graph::createNumToTensor(Value* value) {
   auto typ = value->type();
-  Node * result = create(prim::NumToTensor, {value});
+  Node* result = create(prim::NumToTensor, {value});
   result->output()->setType(CompleteTensorType::fromNumberType(std::move(typ)));
   return result;
 }
@@ -1420,18 +1452,21 @@ Node* Graph::createImplicitTensorToNum(const TypePtr& type, Value* value) {
   return result;
 }
 
-Node* Graph::createClone(Node * n, const std::function<Value*(Value*)>& value_map, bool copy_blocks) {
-  //n can be from a different graph
-  Node * r = n->allocNewInstance(this);
-  for(auto o : n->outputs()) {
+Node* Graph::createClone(
+    Node* n,
+    const std::function<Value*(Value*)>& value_map,
+    bool copy_blocks) {
+  // n can be from a different graph
+  Node* r = n->allocNewInstance(this);
+  for (auto o : n->outputs()) {
     r->addOutput()->copyMetadata(o);
   }
   r->cloneFrom(n);
-  for(auto i : n->inputs()) {
+  for (auto i : n->inputs()) {
     r->addInput(value_map(i));
   }
-  if(copy_blocks) {
-    for(auto b : n->blocks()) {
+  if (copy_blocks) {
+    for (auto b : n->blocks()) {
       r->addBlock()->cloneFrom(b, value_map);
     }
   }
@@ -1442,7 +1477,8 @@ Value* Graph::insertConstant(
     IValue val,
     c10::optional<SourceRange> loc,
     c10::optional<ScopePtr> scope) {
-  return jit::insertConstant(*this, std::move(val), std::move(loc), std::move(scope));
+  return jit::insertConstant(
+      *this, std::move(val), std::move(loc), std::move(scope));
 }
 
 std::string Graph::toString() const {
@@ -1452,28 +1488,28 @@ std::string Graph::toString() const {
 }
 
 Graph::~Graph() {
-  for (const Node * n : all_nodes)
+  for (const Node* n : all_nodes)
     delete n;
-  for (const Value * v : all_values)
+  for (const Value* v : all_values)
     delete v;
-  for (const Block * b : all_blocks)
+  for (const Block* b : all_blocks)
     delete b;
 }
 
-void Graph::freeNode(Node * n) {
+void Graph::freeNode(Node* n) {
   auto it = all_nodes.find(n);
   JIT_ASSERT(it != all_nodes.end());
   delete *it;
   all_nodes.erase(it);
 }
-void Graph::freeValue(Value * v) {
+void Graph::freeValue(Value* v) {
   v->setUniqueName("");
   auto it = all_values.find(v);
   JIT_ASSERT(it != all_values.end());
   delete *it;
   all_values.erase(it);
 }
-void Graph::freeBlock(Block * b) {
+void Graph::freeBlock(Block* b) {
   auto it = all_blocks.find(b);
   JIT_ASSERT(it != all_blocks.end());
   delete *it;
@@ -1483,13 +1519,17 @@ void Graph::freeBlock(Block * b) {
 at::ArrayRef<Value*> createTupleUnpack(Value* v) {
   // small peephole optimization to ensure IntList attributes can still turn
   // into constants e.g. in x.expand([3, 4])
-  if(v->node()->kind() == prim::TupleConstruct)
+  if (v->node()->kind() == prim::TupleConstruct)
     return v->node()->inputs();
-  auto & g = *v->owningGraph();
+  auto& g = *v->owningGraph();
   return g.insertNode(g.createTupleUnpack(v))->outputs();
 }
 
-std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs) {
+std::vector<Value*> inlineCallTo(
+    Graph& g,
+    Graph& callee,
+    ArrayRef<Value*> inputs,
+    bool unpack_outputs) {
   std::unordered_map<Value*, Value*> value_map;
   auto value_map_func = [&](Value* v) { return value_map.at(v); };
   JIT_ASSERT(callee.inputs().size() == inputs.size());
@@ -1497,8 +1537,7 @@ std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> input
     value_map[callee.inputs()[i]] = inputs[i];
   }
   for (auto* node : callee.nodes()) {
-    auto* new_node =
-        g.insertNode(g.createClone(node, value_map_func));
+    auto* new_node = g.insertNode(g.createClone(node, value_map_func));
     for (size_t i = 0; i < node->outputs().size(); ++i) {
       value_map[node->outputs()[i]] = new_node->outputs()[i];
     }
@@ -1511,23 +1550,25 @@ std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> input
 
   if (unpack_outputs && outputs.size() == 1 &&
       callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
-      auto tup = outputs[0];
-      outputs.clear();
-      for(Value* v : createTupleUnpack(tup)) {
-        outputs.emplace_back(v);
-      }
-      // if this was a peephole tuple unpack we can just get rid of
-      // the tuple construct here and prevent needing DCE
-      if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) {
-        tup->node()->destroy();
-      }
+    auto tup = outputs[0];
+    outputs.clear();
+    for (Value* v : createTupleUnpack(tup)) {
+      outputs.emplace_back(v);
+    }
+    // if this was a peephole tuple unpack we can just get rid of
+    // the tuple construct here and prevent needing DCE
+    if (tup->node()->kind() == prim::TupleConstruct &&
+        !tup->node()->hasUses()) {
+      tup->node()->destroy();
+    }
   }
 
   return outputs;
 }
 
-PythonOp* defaultAllocPythonOp(Graph*g) {
-  throw std::runtime_error("Trying to allocate a Python object without python bindings loaded");
+PythonOp* defaultAllocPythonOp(Graph* g) {
+  throw std::runtime_error(
+      "Trying to allocate a Python object without python bindings loaded");
 }
 std::atomic<decltype(&defaultAllocPythonOp)> alloc_python_op;
 
@@ -1539,5 +1580,5 @@ void setAllocPythonOp(PythonOp* (*v)(Graph* g)) {
   alloc_python_op.store(v);
 }
 
-
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index 5a8f4f0..0d4f777 100644 (file)
@@ -1,25 +1,25 @@
 #pragma once
 
-#include <torch/csrc/jit/attributes.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/attributes.h>
+#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/function_schema.h>
 #include <torch/csrc/jit/generic_if.h>
 #include <torch/csrc/jit/graph_node_list.h>
 #include <torch/csrc/jit/interned_strings.h>
+#include <torch/csrc/jit/ivalue.h>
+#include <torch/csrc/jit/named_value.h>
 #include <torch/csrc/jit/resource_guard.h>
 #include <torch/csrc/jit/scope.h>
 #include <torch/csrc/jit/source_location.h>
 #include <torch/csrc/jit/source_range.h>
-#include <torch/csrc/jit/constants.h>
-#include <torch/csrc/jit/function_schema.h>
-#include <torch/csrc/jit/ivalue.h>
 #include <torch/csrc/jit/type.h>
-#include <torch/csrc/jit/named_value.h>
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/utils/disallow_copy.h>
 #include <torch/csrc/utils/functional.h>
 #include <torch/csrc/utils/object_ptr.h>
 #include <torch/csrc/utils/python_stub.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <ATen/ATen.h>
 #include <c10/util/ArrayRef.h>
 #include <unordered_set>
 #include <vector>
 
-namespace torch { namespace autograd {
+namespace torch {
+namespace autograd {
 
 struct Function;
 
-}} // namespace torch::autograd
+}
+} // namespace torch
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // Graph represents one "function" of computation.
-// It uses a simple ownership model where the graph owns all the nodes inside it.
-// All references inside the graph are raw pointers.
-// Destroying the Graph will invalidate any pointers to nodes in the graph.
+// It uses a simple ownership model where the graph owns all the nodes inside
+// it. All references inside the graph are raw pointers. Destroying the Graph
+// will invalidate any pointers to nodes in the graph.
 struct Graph;
 
 // Node is the base class of the IR graph. It represents one computation
@@ -55,8 +58,8 @@ struct Node;
 // Tensor or an opaque Handle object, as determined by type().
 struct Value;
 
-TORCH_API std::ostream& operator<<(std::ostream & out, const Graph & g);
-TORCH_API std::ostream& operator<<(std::ostream & out, const Node & n);
+TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
+TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n);
 
 // A list of nodes, with inputs and outputs
 struct Block;
@@ -65,12 +68,11 @@ struct Block;
 // 'user' is the consumer of the value, offset is the index into
 // 'user's input this where the produces will be found.
 struct Use {
-  Use(Node * user, size_t offset)
-  : user(user), offset(offset) {}
-  Node * user;
+  Use(Node* user, size_t offset) : user(user), offset(offset) {}
+  Node* user;
   size_t offset;
 
-  bool operator==(const Use & b) {
+  bool operator==(const Use& b) {
     return user == b.user && offset == b.offset;
   }
 };
@@ -106,29 +108,31 @@ using node_list = std::vector<Node*>;
 using value_list = std::vector<Value*>;
 using use_list = std::vector<Use>;
 using pyobj_list = std::vector<THPObjectPtr>;
-template<typename T>
+template <typename T>
 using ArrayRef = at::ArrayRef<T>;
 using NodeKind = Symbol;
 using topo_position_t = int64_t;
 
 struct Value {
   TH_DISALLOW_COPY_AND_ASSIGN(Value);
-  Value(Node * node_, size_t offset_);
-private:
+  Value(Node* node_, size_t offset_);
+
+ private:
   friend struct Node;
   friend struct Graph;
-  Node * node_;
+  Node* node_;
   size_t offset_;
-  size_t unique_ = 0;          // unique id
+  size_t unique_ = 0; // unique id
   use_list uses_;
   std::string unique_name_;
   TypePtr type_;
-public:
+
+ public:
   Value* setType(TypePtr type);
   void inferTypeFrom(const at::Tensor& output) {
     setType(CompleteTensorType::create(output));
   }
-  const TypePtr & type() const {
+  const TypePtr& type() const {
     JIT_ASSERT(type_ != nullptr);
     return type_;
   }
@@ -145,7 +149,7 @@ public:
   bool hasUniqueName() const {
     return !unique_name_.empty();
   }
-  TORCH_API Value* setUniqueName(const std::string & name);
+  TORCH_API Value* setUniqueName(const std::string& name);
   std::string uniqueName() const {
     if (hasUniqueName())
       return unique_name_;
@@ -161,13 +165,13 @@ public:
   void setOffset(size_t offset) {
     offset_ = offset;
   }
-  const Node * node() const {
+  const Node* node() const {
     return node_;
   }
-  Graph * owningGraph();
-  const Graph * owningGraph() const;
+  Graph* owningGraph();
+  const Graph* owningGraph() const;
   // TODO: make this more const correct
-  const use_list & uses() const {
+  const use_list& uses() const {
     return uses_;
   }
 
@@ -175,7 +179,7 @@ public:
     return !uses().empty();
   }
 
-  TORCH_API void replaceFirstUseWith(Value * newValue);
+  TORCH_API void replaceFirstUseWith(Value* newValue);
 
   // Replaces all uses of this value with 'newValue'.
   //
@@ -186,12 +190,11 @@ public:
   // Result:  %3 = f(%1, %2)
   //          %4 = g(%6)
   //          %5 = h(%6, %6)
-  TORCH_API void replaceAllUsesWith(Value * newValue);
+  TORCH_API void replaceAllUsesWith(Value* newValue);
 
-  TORCH_API Value* copyMetadata(Value * from);
+  TORCH_API Value* copyMetadata(Value* from);
 };
 
-
 struct Node : public Attributes<Node> {
   TH_DISALLOW_COPY_AND_ASSIGN(Node);
   friend struct Graph;
@@ -201,18 +204,23 @@ struct Node : public Attributes<Node> {
   friend const_graph_node_list;
   friend graph_node_list_iterator;
   friend const_graph_node_list_iterator;
-private:
+
+ private:
   // each node but Return/Param
   // is associated with exactly one place in the node list...
   // of the graph_
-  // this circular is a doubly-linked list, the Return node is used as the sentinel for the beginning and end of the list
-  // such that the list never has null pointers
-  // next_in_graph[0] is next pointer
-  // next_in_graph[1] is prev pointer
-  // using an array to allow the same iterator class for forward and reverse node lists
+  // this circular is a doubly-linked list. The Return node is used as the
+  // sentinel for the beginning and end of the list such that the list never has
+  // null pointers.
+  // - next_in_graph[0] is next pointer
+  // - next_in_graph[1] is prev pointer
+  //
+  // Using an array to allow the same iterator class for forward and
+  // reverse node lists
+  //
   // This list represents a topological sort
 
-  Node* next_in_graph[2] = { nullptr, nullptr };
+  Node* next_in_graph[2] = {nullptr, nullptr};
 
   const NodeKind kind_;
   std::vector<Value*> inputs_;
@@ -225,18 +233,26 @@ private:
   ScopePtr scope_;
   // Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
   // This field is effective a cache that's populated on attribute lookups and
-  // invalidated every time we perform an operation that could potentially change
-  // the schema.
-  // note: mutable because schema_ is effectively a cache
+  // invalidated every time we perform an operation that could potentially
+  // change the schema. note: mutable because schema_ is effectively a cache
   mutable const FunctionSchema* schema_;
   topo_position_t topo_position_ = 0;
-protected:
-  TORCH_API Node(Graph * graph_, NodeKind kind_); //defined after graph
-public:
-  Node* & next() { return next_in_graph[kNextDirection]; }
-  Node* & prev() { return next_in_graph[kPrevDirection]; }
-  Node* const & next() const { return next_in_graph[kNextDirection]; }
-  Node* const & prev() const { return next_in_graph[kPrevDirection]; }
+
+ protected:
+  TORCH_API Node(Graph* graph_, NodeKind kind_); // defined after graph
+ public:
+  Node*& next() {
+    return next_in_graph[kNextDirection];
+  }
+  Node*& prev() {
+    return next_in_graph[kPrevDirection];
+  }
+  Node* const& next() const {
+    return next_in_graph[kNextDirection];
+  }
+  Node* const& prev() const {
+    return next_in_graph[kPrevDirection];
+  }
 
   NodeKind kind() const {
     return kind_;
@@ -248,16 +264,16 @@ public:
   std::shared_ptr<SourceLocation> getSourceLocation() const {
     return source_location_;
   }
-  Graph * owningGraph() {
+  Graph* owningGraph() {
     return graph_;
   }
-  const Graph * owningGraph() const {
+  const Graph* owningGraph() const {
     return graph_;
   }
-  Block * owningBlock() {
+  Block* owningBlock() {
     return owning_block_;
   }
-  const Block * owningBlock() const {
+  const Block* owningBlock() const {
     return owning_block_;
   }
   ScopePtr scope() {
@@ -300,26 +316,26 @@ public:
     // raw pointers are.
     return {outputs_.data(), outputs_.size()};
   }
-  Value * output(size_t i) const {
+  Value* output(size_t i) const {
     return outputs_.at(i);
   }
   bool hasUses() const {
-    for(auto o : outputs()) {
-      if(!o->uses().empty())
+    for (auto o : outputs()) {
+      if (!o->uses().empty())
         return true;
     }
     return false;
   }
 
-  TORCH_API void replaceAllUsesWith(Node * n);
+  TORCH_API void replaceAllUsesWith(Node* n);
 
-  // lots of things like chunk have a single input or single output, so we have a
-  // helper to make accessing it easier
-  Value * input() {
+  // lots of things like chunk have a single input or single output, so we have
+  // helper to make accessing it easier
+  Value* input() {
     JIT_ASSERT(inputs_.size() == 1);
     return inputs_.at(0);
   }
-  Value * output() {
+  Value* output() {
     JIT_ASSERT(outputs_.size() == 1);
     return outputs_.at(0);
   }
@@ -327,12 +343,12 @@ public:
     JIT_ASSERT(outputs_.size() == 1);
     return outputs_.at(0);
   }
-  const  Value * input() const {
+  const Value* input() const {
     JIT_ASSERT(inputs_.size() == 1);
     return inputs_.at(0);
   }
   // Access a particular input.  This is a checked index.
-  Value * input(size_t i) const {
+  Value* input(size_t i) const {
     return inputs_.at(i);
   }
 
@@ -342,7 +358,7 @@ public:
 
   template <typename T>
   c10::optional<T> get(Symbol name) const {
-    if(auto v = get(name))
+    if (auto v = get(name))
       return v->template to<T>();
     return c10::nullopt;
   }
@@ -353,7 +369,7 @@ public:
   }
 
   TORCH_API bool isNondeterministic() const;
-  TORCH_API bool hasSideEffects () const;
+  TORCH_API bool hasSideEffects() const;
 
   // Graphs
 
@@ -375,7 +391,7 @@ public:
   // Given:   %3 = f(%1, %2)
   // Execute: %3.addInput(%4)
   // Result:  %3 = f(%1, %2, %4)
-  TORCH_API Value* addInput(Value * value);
+  TORCH_API Value* addInput(Value* value);
 
   // Add 'value' as an input to 'this' at the specified position in the
   // arguments. Returns the added value for ease of chaining.
@@ -387,7 +403,7 @@ public:
   // Given:   %3 = f(%1, %2)
   // Execute: %3.replaceInput(1, %4)
   // Result:  %3 = f(%1, %4)
-  TORCH_API Value * replaceInput(size_t i, Value * newValue);
+  TORCH_API Value* replaceInput(size_t i, Value* newValue);
 
   // Replace all occurrences of 'from' in the inputs of this
   // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
@@ -395,7 +411,7 @@ public:
   // Given:   %3 = f(%1, %2, %1)
   // Execute: %3.replaceInputWith(%1, %4)
   // Result:  %3 = f(%4, %2, %4)
-  TORCH_API void replaceInputWith(Value * from, Value * to);
+  TORCH_API void replaceInputWith(Value* from, Value* to);
 
   TORCH_API Value* addOutput();
 
@@ -403,17 +419,18 @@ public:
 
   TORCH_API void eraseOutput(size_t i);
 
-  TORCH_API Block * addBlock();
+  TORCH_API Block* addBlock();
   TORCH_API void eraseBlock(size_t i);
 
   // Each Node can have a list of subblocks. These are used to define structured
   // nested control flow operators such as If and Loop.
   // The meaning of a block is specific to the kind of node it is in, but
   // all blocks share these semantics:
-  // * Nested lexical scoping: If a node 'Parent' has a subblock which contains a
-  //   node 'Child', Child can use any value that was in scope for the Parent
+  // * Nested lexical scoping: If a node 'Parent' has a subblock which contains
+  //   node 'Child', Child can use any value that was in scope for the Parent
   //   node in addition to any values defined before 'Child' in the subblock.
-  // * The list of inputs to the block are in scope for the duration of the block
+  // * The list of inputs to the block are in scope for the duration of the
+  //   block
   // * the outputs of the Parent node are not in scope for the subblocks
   // Typically the inputs to a block that represents control flow act as
   // as the equivalents phi-nodes in standard SSA form,
@@ -432,10 +449,10 @@ public:
   }
 
   // Is 'this' before 'n' in the topological order?
-  TORCH_API bool isBefore(const Node * n) const;
+  TORCH_API bool isBefore(const Node* n) const;
 
   // Is 'this' after 'n' in the topological order?
-  TORCH_API bool isAfter(const Node * n) const;
+  TORCH_API bool isAfter(const Node* n) const;
 
   // Insert unattached 'this' node before 'n' in the topological order.
   // Returns this (for chaining).
@@ -447,7 +464,7 @@ public:
   // Result:  %3 = f(%1, %2)
   //          %5 = h(%1)
   //          %4 = g(%3)
-  TORCH_API Node* insertBefore(Node * n);
+  TORCH_API Node* insertBefore(Node* n);
 
   // Insert unattached 'this' node after 'n' in the topological order.
   // Returns this (for chaining).
@@ -459,7 +476,7 @@ public:
   // Result:  %3 = f(%1, %2)
   //          %4 = g(%3)
   //          %5 = h(%1)
-  TORCH_API Node* insertAfter(Node * n);
+  TORCH_API Node* insertAfter(Node* n);
 
   // Move 'this' (already in the graph) after 'n' in the topological order.
   //
@@ -472,7 +489,7 @@ public:
   // Result: %3 = g(%1)
   //         %2 = f(%1)
   //
-  TORCH_API void moveAfter(Node * n);
+  TORCH_API void moveAfter(Node* n);
 
   // Move 'this' (already in the graph) after 'n' in the topological order.
   //
@@ -500,7 +517,7 @@ public:
   // Execute: %3.moveBefore(%2)
   // Result: %3 = g(%1)
   //         %2 = f(%1)
-  TORCH_API void moveBefore(Node * n);
+  TORCH_API void moveBefore(Node* n);
 
   // Move 'this' (already in the graph) before 'n' in the topological order.
   //
@@ -565,23 +582,27 @@ public:
   // Example usage: if(auto s = n.cast<Select>()) { ... }
   //
   // TODO: Make this const correct
-  template<typename T>
+  template <typename T>
   T* cast() {
-    if(T::Kind == kind())
+    if (T::Kind == kind())
       return static_cast<T*>(this);
     return nullptr;
   }
-  template<typename T>
+  template <typename T>
   T* expect() {
     JIT_ASSERTM(
         T::Kind == kind(),
-        "expected a ", T::Kind.toDisplayString(),
-        " but found a ", kind().toDisplayString());
+        "expected a ",
+        T::Kind.toDisplayString(),
+        " but found a ",
+        kind().toDisplayString());
     return static_cast<T*>(this);
   }
 
   // XXX: this function is meant to be used with string literals only!
-  TORCH_API bool matches(const char *signature_literal, at::ArrayRef<Symbol> const_inputs={}) const;
+  TORCH_API bool matches(
+      const char* signature_literal,
+      at::ArrayRef<Symbol> const_inputs = {}) const;
 
   const FunctionSchema& schema() const {
     if (!schema_)
@@ -596,13 +617,18 @@ public:
 
  private:
   enum class MoveSide { BEFORE, AFTER };
-  bool tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun);
+  bool tryMove(
+      Node* movePoint,
+      MoveSide moveSide,
+      const AliasDb& aliasDb,
+      bool dryRun);
   void move(Node* movePoint, MoveSide moveSide);
   bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
 
   std::pair<Value*, const Argument&> findInput(Symbol name);
   void findSchema() const;
-  // Lookup iterator in use list of _input i_ that corresponds to its use of _this_
+  // Lookup iterator in use list of _input i_ that corresponds to its use of
+  // _this_
   TORCH_API use_list::iterator findUseForInput(size_t i);
 
   // remove the use of input i, this sets input i to nullptr, but
@@ -611,7 +637,7 @@ public:
   TORCH_API Value* dropInput(size_t i);
 
   bool inBlockList() const {
-    if(next() == nullptr) {
+    if (next() == nullptr) {
       JIT_ASSERT(prev() == nullptr);
     }
     return next() != nullptr;
@@ -622,13 +648,13 @@ public:
 
   void assignTopoPosition();
 
-protected:
+ protected:
   // subclasses must override
   // this function is used by createClone to initialize a new version
   // of a node in another graph. It should allocate a new instance of the same
   // concrete type as 'this', but in graph 'g' which might be different
   // than graph_
-  virtual Node * allocNewInstance(Graph * g) {
+  virtual Node* allocNewInstance(Graph* g) {
     return new Node(g, kind());
   }
   // create a copy of all properties of Node s into this.
@@ -636,19 +662,19 @@ protected:
   // 'this' will be allocated with s->allocNewInstance(g) so it should have
   // the same concrete type as 's'
   //
-  TORCH_API virtual void cloneFrom(Node * s);
+  TORCH_API virtual void cloneFrom(Node* s);
 };
 
 struct Block {
   friend struct Node;
   friend struct Graph;
   TH_DISALLOW_COPY_AND_ASSIGN(Block);
-  TORCH_API Block(Graph * graph_, Node * node_);
+  TORCH_API Block(Graph* graph_, Node* node_);
   at::ArrayRef<Value*> inputs() {
     return input_->outputs();
   }
   at::ArrayRef<const Value*> inputs() const {
-    const auto & inputs = input_->outputs();
+    const auto& inputs = input_->outputs();
     return {inputs.data(), inputs.size()};
   }
   at::ArrayRef<Value*> outputs() {
@@ -663,20 +689,20 @@ struct Block {
   const_graph_node_list nodes() const {
     return {output_, kNextDirection};
   }
-  Node * return_node() {
+  Node* return_node() {
     return output_;
   }
-  const Node * return_node() const {
+  const Node* return_node() const {
     return output_;
   }
-  Node * param_node() {
+  Node* param_node() {
     return input_;
   }
-  const Node * param_node() const {
+  const Node* param_node() const {
     return input_;
   }
-  Value * addInput(std::string name="") {
-    Value * v = input_->addOutput();
+  Value* addInput(std::string name = "") {
+    Value* v = input_->addOutput();
     v->setUniqueName(std::move(name));
     return v;
   }
@@ -688,7 +714,7 @@ struct Block {
   void eraseInput(size_t i) {
     input_->eraseOutput(i);
   }
-  size_t registerOutput(Value * v) {
+  size_t registerOutput(Value* v) {
     output_->addInput(v);
     return outputs().size() - 1;
   }
@@ -699,35 +725,36 @@ struct Block {
   void eraseOutput(size_t i) {
     output_->removeInput(i);
   }
-  Node * appendNode(Node * n) {
+  Node* appendNode(Node* n) {
     JIT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
     n->insertBefore(output_);
     return n;
   }
 
-  Node * prependNode(Node * n) {
+  Node* prependNode(Node* n) {
     JIT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
     n->insertAfter(output_);
     return n;
   }
-  Graph * owningGraph() {
+  Graph* owningGraph() {
     return graph_;
   }
-  const Graph * owningGraph() const {
+  const Graph* owningGraph() const {
     return graph_;
   }
-  Node * owningNode() {
+  Node* owningNode() {
     return owning_node_;
   }
-  const Node * owningNode() const {
+  const Node* owningNode() const {
     return owning_node_;
   }
   // clone all inputs, nodes, and outputs from src and append them
   // to the inputs, nodes, and outputs of this block
   // value_map is used whenever a node in src references a free variable
   // in src to look up its corresponding value
-  TORCH_API void cloneFrom(Block * src, std::function<Value*(Value*)> value_map);
-private:
+  TORCH_API void cloneFrom(Block* src, std::function<Value*(Value*)> value_map);
+
+ private:
   void reIndexTopology();
 
   // should only be called in the constructor
@@ -742,23 +769,24 @@ private:
   // do not have to be removed before you can destroy the block
   void destroy();
 
-  Graph * const graph_;
+  Graph* const graph_;
   // holds outputs in a way that can be reflected
   // as a Use object
   // also used as the beginning/end of the circular node list to avoid
   // having corner cases where the list is empty.
-  Node * const output_;
-  Node * const input_;
-  Node * const owning_node_; // either the node that has this block or nullptr for root
+  Node* const output_;
+  Node* const input_;
+  Node* const
+      owning_node_; // either the node that has this block or nullptr for root
 };
 
 struct Graph {
-TH_DISALLOW_COPY_AND_ASSIGN(Graph);
-friend struct Node;
-friend struct Value;
-friend struct Block;
-private:
+  TH_DISALLOW_COPY_AND_ASSIGN(Graph);
+  friend struct Node;
+  friend struct Value;
+  friend struct Block;
 
+ private:
   // only used to keep track of allocated nodes
   // actual representation of Graph is done with
   // inputs, outputs, nodes
@@ -777,13 +805,12 @@ private:
   // by default this is set to append to the top level block
   Node* insert_before_;
 
-public:
-
+ public:
   Graph(ScopePtr scope_root)
-  : next_unique_(0)
-  , current_scope_(std::move(scope_root))
-  , block_(new Block(this, nullptr))
-  , insert_before_(return_node()) {}
+      : next_unique_(0),
+        current_scope_(std::move(scope_root)),
+        block_(new Block(this, nullptr)),
+        insert_before_(return_node()) {}
 
   Graph() : Graph(c10::make_intrusive<Scope>()) {}
 
@@ -791,33 +818,33 @@ public:
     return block_->inputs();
   }
   at::ArrayRef<const Value*> inputs() const {
-    const auto & block = *block_;
+    const auto& block = *block_;
     return block.inputs();
   }
   at::ArrayRef<Value*> outputs() {
     return block_->outputs();
   }
   at::ArrayRef<const Value*> outputs() const {
-    const auto & block = *block_;
+    const auto& block = *block_;
     return block.outputs();
   }
   graph_node_list nodes() {
     return block_->nodes();
   }
   const_graph_node_list nodes() const {
-    const auto & block = *block_;
+    const auto& block = *block_;
     return block.nodes();
   }
-  Node * param_node() {
+  Node* param_node() {
     return block_->param_node();
   }
-  const Node * param_node() const {
+  const Node* param_node() const {
     return block_->param_node();
   }
-  Node * return_node() {
+  Node* return_node() {
     return block_->return_node();
   }
-  const Node * return_node() const {
+  const Node* return_node() const {
     return block_->return_node();
   }
   void push_scope(const std::string& scope_name) {
@@ -832,7 +859,7 @@ public:
   void set_current_scope(ScopePtr scope) {
     current_scope_ = std::move(scope);
   }
-  Value * addInput(std::string name="") {
+  Value* addInput(std::string name = "") {
     return block_->addInput(std::move(name));
   }
   Value* insertInput(size_t i, std::string name = "") {
@@ -848,24 +875,29 @@ public:
     return unique_names_;
   }
 
-  size_t registerOutput(Value * n) {
+  size_t registerOutput(Value* n) {
     return block_->registerOutput(n);
   }
 
-  TORCH_API Node * create(NodeKind kind, size_t num_outputs=1);
-  TORCH_API Node * create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs=1);
+  TORCH_API Node* create(NodeKind kind, size_t num_outputs = 1);
+  TORCH_API Node* create(
+      NodeKind kind,
+      ArrayRef<Value*> inputs,
+      size_t num_outputs = 1);
 
-
-  TORCH_API Node* createNone(TypePtr typ); // value of None with type Optional[typ]
+  TORCH_API Node* createNone(
+      TypePtr typ); // value of None with type Optional[typ]
   TORCH_API Node* createUndefined();
   TORCH_API Node* createFusionGroup();
   TORCH_API Node* createDifferentiableSubgraph();
   TORCH_API Node* createTuple(at::ArrayRef<Value*> values);
-  TORCH_API Node* createTupleUnpack(Value * v);
-  TORCH_API Node* createTupleIndex(Value * tup, int64_t index);
-  TORCH_API Node* createTupleSlice(Value * tup, int64_t beg, int64_t end);
-  TORCH_API Node* createList(const TypePtr& elem_type, at::ArrayRef<Value*> values);
-  TORCH_API Node* createListUnpack(Value *v, size_t size);
+  TORCH_API Node* createTupleUnpack(Value* v);
+  TORCH_API Node* createTupleIndex(Value* tup, int64_t index);
+  TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end);
+  TORCH_API Node* createList(
+      const TypePtr& elem_type,
+      at::ArrayRef<Value*> values);
+  TORCH_API Node* createListUnpack(Value* v, size_t size);
   TORCH_API Node* createNumToTensor(Value* value);
   TORCH_API Node* createImplicitTensorToNum(const TypePtr& type, Value* value);
   Node* createPythonOp(
@@ -876,60 +908,67 @@ public:
   // use node_map to translate inputs of n to inputs of the cloned node
   // if copy_blocks is false, it will not recursively clone the nested blocks
   // this node contains.
-  TORCH_API Node * createClone(Node * n, const std::function<Value*(Value*)>& value_map, bool copy_blocks=true);
+  TORCH_API Node* createClone(
+      Node* n,
+      const std::function<Value*(Value*)>& value_map,
+      bool copy_blocks = true);
 
   TORCH_API Value* insertConstant(
       IValue val,
       c10::optional<SourceRange> loc = c10::nullopt,
       c10::optional<ScopePtr> scope = c10::nullopt);
 
-
-  // schema-driven insert
-  // this inserts a node into the graph with inputs determined from args and kwargs using Python
-  // argument matching rules, and checks that the op matches a known schema
-  // if this node successfully completes, it guarentees the node is a correctly-formed invocation
-  // of opname
+  // Schema-driven insert:
+  // This inserts a node into the graph with inputs determined from args and
+  // kwargs using Python argument matching rules, and checks that the op matches
+  // a known schema.
+  //
+  // If this node successfully completes, it guarentees the node
+  // is a correctly-formed invocation of opname
   TORCH_API Value* insert(
       Symbol opname,
       at::ArrayRef<NamedValue> args,
       at::ArrayRef<NamedValue> kwargs = {},
       const c10::optional<SourceRange>& range = {});
 
-  Node * appendNode(Node * n) {
+  Node* appendNode(Node* n) {
     return block_->appendNode(n);
   }
 
-  Node * prependNode(Node * n) {
+  Node* prependNode(Node* n) {
     return block_->prependNode(n);
   }
 
   // insert before insert_before_ node
   // initialized to insert at the end of the top level block
   // can be changed with setInsertPoint()
-  Node * insertNode(Node * n) {
-    JIT_ASSERT(insert_before_->inBlockList() && "insert point node is no longer in a block list");
+  Node* insertNode(Node* n) {
+    JIT_ASSERT(
+        insert_before_->inBlockList() &&
+        "insert point node is no longer in a block list");
     return n->insertBefore(insert_before_);
   }
   // set where nodes are inserted to append to the end of this block
-  void setInsertPoint(Block * b) {
+  void setInsertPoint(Block* b) {
     JIT_ASSERT(b->owningGraph() == this);
     insert_before_ = b->return_node();
   }
   // set where nodes are inserted to insert _before_ this node
-  // for implementation simplicity we only support inserting before a node for now
-  void setInsertPoint(Node * n) {
+  // for implementation simplicity we only support inserting before a node for
+  // now
+  void setInsertPoint(Node* n) {
     JIT_ASSERT(n->owningGraph() == this && n->inBlockList());
     insert_before_ = n;
   }
-  Node * insertPoint() {
+  Node* insertPoint() {
     return insert_before_;
   }
 
   // the top level block
-  Block * block() {
+  Block* block() {
     return block_;
   }
-  const Block * block() const {
+  const Block* block() const {
     return block_;
   }
 
@@ -942,68 +981,64 @@ public:
 
   TORCH_API std::string toString() const;
 
-  friend TORCH_API std::ostream& operator<<(std::ostream & out, const Graph & g);
+  friend TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
 
-  TORCH_API std::ostream& prettyPrint(std::ostream & out);
+  TORCH_API std::ostream& prettyPrint(std::ostream& out);
   TORCH_API void dumpPretty();
 
   TORCH_API std::shared_ptr<Graph> copy();
 
-private:
-
-  TORCH_API void freeNode(Node * n);
-  TORCH_API void freeValue(Value * v);
-  TORCH_API void freeBlock(Block * b);
+ private:
+  TORCH_API void freeNode(Node* n);
+  TORCH_API void freeValue(Value* v);
+  TORCH_API void freeBlock(Block* b);
 };
 
 struct WithInsertPoint : public ResourceGuard {
-  WithInsertPoint(Node * n)
-  : ResourceGuard([this] {
-    prev->owningGraph()->setInsertPoint(prev);
-  })
-  , prev(n->owningGraph()->insertPoint()) {
+  WithInsertPoint(Node* n)
+      : ResourceGuard([this] { prev->owningGraph()->setInsertPoint(prev); }),
+        prev(n->owningGraph()->insertPoint()) {
     n->owningGraph()->setInsertPoint(n);
   }
-  WithInsertPoint(Block * b)
-  : WithInsertPoint(b->return_node()) {}
-private:
-  Node * prev;
+  WithInsertPoint(Block* b) : WithInsertPoint(b->return_node()) {}
+
+ private:
+  Node* prev;
 };
 
 struct WithCurrentScope : public ResourceGuard {
-  WithCurrentScope(Graph & g, ScopePtr scope)
-  : ResourceGuard([&g, this]() {
-    g.set_current_scope(prev_scope);
-  })
-  , prev_scope(g.current_scope()) {
+  WithCurrentScope(Graph& g, ScopePtr scope)
+      : ResourceGuard([&g, this]() { g.set_current_scope(prev_scope); }),
+        prev_scope(g.current_scope()) {
     g.set_current_scope(std::move(scope));
   }
-private:
+
+ private:
   ScopePtr prev_scope;
 };
 
-inline Value::Value(Node * node_, size_t offset_)
-: node_(node_),
-  offset_(offset_),
-  unique_(node_->graph_->next_unique_++),
-  type_(DynamicType::get()) {
+inline Value::Value(Node* node_, size_t offset_)
+    : node_(node_),
+      offset_(offset_),
+      unique_(node_->graph_->next_unique_++),
+      type_(DynamicType::get()) {
   node_->graph_->all_values.emplace(this);
 }
 
 inline Value* Value::setType(TypePtr type) {
   JIT_ASSERT(type);
   type_ = std::move(type);
-  for (Use & use : uses_) {
+  for (Use& use : uses_) {
     use.user->schema_ = nullptr;
   }
   return this;
 }
 
-inline Graph * Value::owningGraph() {
+inline Graph* Value::owningGraph() {
   return node()->owningGraph();
 }
 
-inline const Graph * Value::owningGraph() const {
+inline const Graph* Value::owningGraph() const {
   return node()->owningGraph();
 }
 
@@ -1015,11 +1050,11 @@ inline const Graph * Value::owningGraph() const {
 // Mutable case
 // The IFM/ELSEIFM indicate that subclass *refinement* occurs.
 // This is only valid for node types for which we have subclasses.
-#define IR_IFM(x,Kind) GENERIC_IF(,prim::Kind,x,Kind)
-#define IR_ELSEIFM(Kind) GENERIC_ELSEIF(,prim::Kind,Kind)
+#define IR_IFM(x, Kind) GENERIC_IF(, prim::Kind, x, Kind)
+#define IR_ELSEIFM(Kind) GENERIC_ELSEIF(, prim::Kind, Kind)
 
-#define IR_IFM_CONST(x,Kind) GENERIC_IF(const,prim::Kind,x,Kind)
-#define IR_ELSEIFM_CONST(Kind) GENERIC_ELSEIF(const,prim::Kind,Kind)
+#define IR_IFM_CONST(x, Kind) GENERIC_IF(const, prim::Kind, x, Kind)
+#define IR_ELSEIFM_CONST(Kind) GENERIC_ELSEIF(const, prim::Kind, Kind)
 
 #define IR_IF(x, Kind)           \
   auto&& __match_key = x;        \
@@ -1052,12 +1087,12 @@ inline const Graph * Value::owningGraph() const {
 
 /************* All nodes not required to be defined before Graph **************/
 
- // execute a Python function, used for Ops we can't optimize but that we want to optimize around
+// execute a Python function, used for Ops we can't optimize but that we want to
+// optimize around
 struct PythonOp : public Node {
   static constexpr Symbol Kind = prim::PythonOp;
 
-  PythonOp(Graph * graph)
-  : Node(graph,prim::PythonOp) {}
+  PythonOp(Graph* graph) : Node(graph, prim::PythonOp) {}
   PythonOp* init(
       THPObjectPtr&& pyobj,
       const std::string& cconv,
@@ -1080,8 +1115,8 @@ struct PythonOp : public Node {
   std::vector<THPObjectPtr> scalar_args;
   virtual std::string name() const = 0;
   virtual void writeScalars(std::ostream& out) const = 0;
-  void cloneFrom(Node * other_) override = 0;
-  Node * allocNewInstance(Graph * g) override = 0;
+  void cloneFrom(Node* other_) override = 0;
+  Node* allocNewInstance(Graph* g) override = 0;
   // recover the autograd.Function instance, if this PythonOp's function
   // was originally SomeFunction.apply
   // used in ONNX for discovering symbolics
@@ -1095,18 +1130,20 @@ inline Node* Graph::createPythonOp(
     const std::string& cconv,
     pyobj_list&& scalar_args) {
   auto op = allocPythonOp(this);
-  return op->init(
-      std::move(pyobj),
-      cconv,
-      std::move(scalar_args));
+  return op->init(std::move(pyobj), cconv, std::move(scalar_args));
 }
 
 TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);
 
 TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
-// unpack_outputs - if true, and the callee returns a single tuple value, then insert a tuple unpack node
+// unpack_outputs - if true, and the callee returns a single tuple value, then
+// insert a tuple unpack node
 //                  and return the resulting values
-TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs=false);
-
-
-}} // namespace torch::jit
+TORCH_API std::vector<Value*> inlineCallTo(
+    Graph& g,
+    Graph& callee,
+    ArrayRef<Value*> inputs,
+    bool unpack_outputs = false);
+
+} // namespace jit
+} // namespace torch
index 808171b..87d617c 100644 (file)
@@ -7,8 +7,8 @@ using ::c10::ivalue::List;
 using ::c10::ivalue::Shared;
 
 using ::c10::IValue;
-using ::c10::ivalue::Tuple;
 using ::c10::ivalue::Future;
+using ::c10::ivalue::Tuple;
 
 using ::c10::ivalue::BoolList;
 using ::c10::ivalue::DoubleList;
index 890b7ac..983e08f 100644 (file)
@@ -1,36 +1,36 @@
 #pragma once
 #include <ATen/ATen.h>
+#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/ivalue.h>
 #include <torch/csrc/jit/source_range.h>
 #include <torch/csrc/utils/variadic.h>
-#include <torch/csrc/jit/ivalue.h>
-#include <torch/csrc/jit/constants.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 struct Value;
 
 struct NamedValue {
   NamedValue(const SourceRange& loc, const std::string& name, Value* value)
-  : loc_(loc), name_(name), value_(value) {}
-  NamedValue(const SourceRange& loc, Value* value)
-  : loc_(loc), value_(value) {}
+      : loc_(loc), name_(name), value_(value) {}
+  NamedValue(const SourceRange& loc, Value* value) : loc_(loc), value_(value) {}
 
-  /* implicit */ NamedValue(Value* value)
-  : value_(value) {}
+  /* implicit */ NamedValue(Value* value) : value_(value) {}
   NamedValue(const std::string& name, Value* value)
-  : name_(name), value_(value) {}
+      : name_(name), value_(value) {}
 
   /* implicit */ NamedValue(IValue value)
-  : value_(nullptr), ivalue_(std::move(value)) {}
+      : value_(nullptr), ivalue_(std::move(value)) {}
 
   NamedValue(const std::string& name, IValue value)
-   name_(name), ivalue_(std::move(value)) {}
+      : name_(name), ivalue_(std::move(value)) {}
 
   template <
       typename T,
       typename = enable_if_t<
           (!std::is_same<decay_t<T>, NamedValue>::value &&
-           !std::is_same<decay_t<T>, Value*>::value && !std::is_same<decay_t<T>, IValue>::value)>>
+           !std::is_same<decay_t<T>, Value*>::value &&
+           !std::is_same<decay_t<T>, IValue>::value)>>
   NamedValue(T&& t) : NamedValue(IValue(std::forward<T>(t))) {}
 
   template <
@@ -39,10 +39,10 @@ struct NamedValue {
           (!std::is_same<decay_t<T>, Value*>::value &&
            !std::is_same<decay_t<T>, IValue>::value)>>
   NamedValue(const std::string& name, T&& t)
-  : NamedValue(name, IValue(std::forward<T>(t))) {}
+      : NamedValue(name, IValue(std::forward<T>(t))) {}
 
   SourceRange locOr(const SourceRange& backup_location) const {
-    if(!loc_)
+    if (!loc_)
       return backup_location;
     return loc();
   }
@@ -50,8 +50,9 @@ struct NamedValue {
   // note: this will insert a constant node into the graph at the current
   // insert point if this NamedValue is actually a constant
   Value* value(Graph& g) const {
-    if(!value_)
-      return insertConstant(g, ivalue_); // use insertConstant to remove need to include ir.h here
+    if (!value_)
+      return insertConstant(
+          g, ivalue_); // use insertConstant to remove need to include ir.h here
     return value_;
   }
 
@@ -65,12 +66,13 @@ struct NamedValue {
     return *loc_;
   }
 
-private:
- c10::optional<SourceRange> loc_;
- c10::optional<std::string> name_;
- Value* value_{nullptr};
- // only valid if value_ == nullptr;
- IValue ivalue_;
+ private:
 c10::optional<SourceRange> loc_;
 c10::optional<std::string> name_;
 Value* value_{nullptr};
 // only valid if value_ == nullptr;
 IValue ivalue_;
 };
 
-}}
+} // namespace jit
+} // namespace torch
index 3bf23a2..5c32cca 100644 (file)
@@ -5,12 +5,13 @@
 
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/interned_strings.h>
-#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/node_hashing.h>
+#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/utils/functional.h>
 #include <torch/csrc/utils/hash.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
@@ -18,12 +19,14 @@ bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
   return &lhs.type() == &rhs.type() && lhs.equal(rhs);
 }
 
-bool tensorListEqual(const std::vector<at::Tensor>& lhs, const std::vector<at::Tensor>& rhs) {
-  if (lhs.size() != rhs.size()) return false;
+bool tensorListEqual(
+    const std::vector<at::Tensor>& lhs,
+    const std::vector<at::Tensor>& rhs) {
+  if (lhs.size() != rhs.size())
+    return false;
   return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
 }
 
-
 // Check whether two nodes have the same attributes in CSE.
 // This function may be too conservative for general use.
 // Do NOT support g/gs attributes.
@@ -31,24 +34,30 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
   JIT_ASSERT(lhs != nullptr);
   JIT_ASSERT(rhs != nullptr);
   // One has attributes, the other does not.
-  if (lhs->hasAttributes() != rhs->hasAttributes()) return false;
+  if (lhs->hasAttributes() != rhs->hasAttributes())
+    return false;
   // Neither has attributes.
-  if (!lhs->hasAttributes() && !rhs->hasAttributes()) return true;
+  if (!lhs->hasAttributes() && !rhs->hasAttributes())
+    return true;
 
   auto lnames = lhs->attributeNames();
   auto rnames = rhs->attributeNames();
   std::sort(lnames.begin(), lnames.end());
   std::sort(rnames.begin(), rnames.end());
-  if (lnames != rnames) return false;
+  if (lnames != rnames)
+    return false;
 
   for (auto name : lnames) {
-    if (lhs->kindOf(name) != rhs->kindOf(name)) return false;
+    if (lhs->kindOf(name) != rhs->kindOf(name))
+      return false;
 
-    #define COMPARE_ATTRIBUTEVALUE(type) \
-      case AttributeKind::type: \
-        { if (lhs->type(name) != rhs->type(name)) return false; } break;
+#define COMPARE_ATTRIBUTEVALUE(type)        \
+  case AttributeKind::type: {               \
+    if (lhs->type(name) != rhs->type(name)) \
+      return false;                         \
+  } break;
 
-    switch(lhs->kindOf(name)) {
+    switch (lhs->kindOf(name)) {
       COMPARE_ATTRIBUTEVALUE(f)
       COMPARE_ATTRIBUTEVALUE(fs)
       COMPARE_ATTRIBUTEVALUE(i)
@@ -56,11 +65,13 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
       COMPARE_ATTRIBUTEVALUE(s)
       COMPARE_ATTRIBUTEVALUE(ss)
       case AttributeKind::t: {
-        if (!tensorEqual(lhs->t(name), rhs->t(name))) return false;
+        if (!tensorEqual(lhs->t(name), rhs->t(name)))
+          return false;
         break;
       }
       case AttributeKind::ts: {
-        if (!tensorListEqual(lhs->ts(name), rhs->ts(name))) return false;
+        if (!tensorListEqual(lhs->ts(name), rhs->ts(name)))
+          return false;
         break;
       }
       case AttributeKind::g:
@@ -68,7 +79,7 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
         return false;
     }
 
-    #undef COMPARE_ATTRIBUTEVALUE
+#undef COMPARE_ATTRIBUTEVALUE
   }
 
   return true;
@@ -76,24 +87,28 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
 
 } // anonymous namespace
 
-
 size_t HashNode::operator()(const Node* k) const {
   JIT_ASSERT(k != nullptr);
-  return get_hash(k->kind(),
-                  fmap(k->outputs(), [](const Value *v) { return v->type()->kind(); }),
-                  fmap(k->inputs(), [](const Value *v) { return v->unique(); }));
+  return get_hash(
+      k->kind(),
+      fmap(k->outputs(), [](const Value* v) { return v->type()->kind(); }),
+      fmap(k->inputs(), [](const Value* v) { return v->unique(); }));
 };
 
 bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
-  if (lhs == nullptr && rhs == nullptr) return true;
-  if (lhs == nullptr || rhs == nullptr) return false;
+  if (lhs == nullptr && rhs == nullptr)
+    return true;
+  if (lhs == nullptr || rhs == nullptr)
+    return false;
 
-  if (lhs->kind() != rhs->kind()) return false;
+  if (lhs->kind() != rhs->kind())
+    return false;
 
   // Check whether the output types are the same.
   auto lhs_outputs = lhs->outputs();
   auto rhs_outputs = rhs->outputs();
-  if (lhs_outputs.size() != rhs_outputs.size()) return false;
+  if (lhs_outputs.size() != rhs_outputs.size())
+    return false;
   for (size_t i = 0; i < lhs_outputs.size(); ++i) {
     if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type())
       return false;
@@ -102,12 +117,16 @@ bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
   // Check whether the inputs are the same.
   auto lhs_inputs = lhs->inputs();
   auto rhs_inputs = rhs->inputs();
-  if (lhs_inputs.size() != rhs_inputs.size()) return false;
-  if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin())) return false;
+  if (lhs_inputs.size() != rhs_inputs.size())
+    return false;
+  if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin()))
+    return false;
 
-  if (!attributesEqualCSE(lhs, rhs)) return false;
+  if (!attributesEqualCSE(lhs, rhs))
+    return false;
 
   return true;
 };
 
-}}
+} // namespace jit
+} // namespace torch
index d67da4c..112e565 100644 (file)
@@ -2,7 +2,8 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 struct HashNode {
   size_t operator()(const Node* k) const;
@@ -12,4 +13,5 @@ struct EqualNode {
   bool operator()(const Node* lhs, const Node* rhs) const;
 };
 
-}}
+} // namespace jit
+} // namespace torch
index 5168702..833fbd0 100644 (file)
@@ -1,27 +1,27 @@
 #include <ATen/ATen.h>
 #include <torch/csrc/jit/alias_info.h>
-#include <torch/csrc/jit/script/lexer.h>
-#include <torch/csrc/jit/script/parse_string_literal.h>
-#include <torch/csrc/jit/script/tree.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/script/error_report.h>
+#include <torch/csrc/jit/script/lexer.h>
+#include <torch/csrc/jit/script/parse_string_literal.h>
+#include <torch/csrc/jit/script/tree.h>
 
 #include <functional>
 #include <memory>
 #include <utility>
 #include <vector>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace script {
 struct SchemaParser {
-  SchemaParser(const std::string& str)
-  : L(str) {}
+  SchemaParser(const std::string& str) : L(str) {}
 
   FunctionSchema parseDeclaration() {
     auto name = L.expect(TK_IDENT).text();
-    if(L.nextIf(':')) {
+    if (L.nextIf(':')) {
       L.expect(':');
       name = name + "::" + L.expect(TK_IDENT).text();
     }
@@ -31,11 +31,12 @@ struct SchemaParser {
     bool is_vararg = false;
     size_t idx = 0;
     parseList('(', ',', ')', [&] {
-      if(is_vararg)
-        throw ErrorReport(L.cur()) << "... must be the last element of the argument list";
+      if (is_vararg)
+        throw ErrorReport(L.cur())
+            << "... must be the last element of the argument list";
       if (L.nextIf('*')) {
         kwarg_only = true;
-      } else if(L.nextIf(TK_DOTS)) {
+      } else if (L.nextIf(TK_DOTS)) {
         is_vararg = true;
       } else {
         arguments.push_back(parseArgument(
@@ -61,7 +62,7 @@ struct SchemaParser {
     std::vector<FunctionSchema> results;
     do {
       results.push_back(parseDeclaration());
-    } while(L.nextIf(TK_NEWLINE));
+    } while (L.nextIf(TK_NEWLINE));
     L.expect(TK_EOF);
     return results;
   }
@@ -71,21 +72,21 @@ struct SchemaParser {
   }
   TypePtr parseBaseType() {
     static std::unordered_map<std::string, TypePtr> type_map = {
-      {"Generator", GeneratorType::get() },
-      {"ScalarType", IntType::get() },
-      {"Layout", IntType::get() },
-      {"Device", DeviceObjType::get() },
-      {"Scalar", NumberType::get() },
-      {"str", StringType::get() },
-      {"float", FloatType::get() },
-      {"int", IntType::get() },
-      {"bool", BoolType::get() },
+        {"Generator", GeneratorType::get()},
+        {"ScalarType", IntType::get()},
+        {"Layout", IntType::get()},
+        {"Device", DeviceObjType::get()},
+        {"Scalar", NumberType::get()},
+        {"str", StringType::get()},
+        {"float", FloatType::get()},
+        {"int", IntType::get()},
+        {"bool", BoolType::get()},
     };
     auto tok = L.expect(TK_IDENT);
     auto text = tok.text();
     auto it = type_map.find(text);
-    if(it == type_map.end()) {
-      if(text.size() > 0 && islower(text[0])) {
+    if (it == type_map.end()) {
+      if (text.size() > 0 && islower(text[0])) {
         // lower case identifiers that are not otherwise valid types
         // are treated as type variables
         return VarType::create(text);
@@ -123,7 +124,7 @@ struct SchemaParser {
       alias_info.addSet(
           Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
       alias_info.setIsWrite(true);
-    } else{
+    } else {
       return c10::nullopt;
     }
 
@@ -159,8 +160,8 @@ struct SchemaParser {
     } else {
       value = parseBaseType();
     }
-    while(true) {
-      if(L.cur().kind == '[' && L.lookahead().kind == ']') {
+    while (true) {
+      if (L.cur().kind == '[' && L.lookahead().kind == ']') {
         L.next(); // [
         L.next(); // ]
         value = ListType::create(value);
@@ -169,7 +170,7 @@ struct SchemaParser {
           container->addContainedType(std::move(*alias_info));
         }
         alias_info = std::move(container);
-      } else if(L.nextIf('?')) {
+      } else if (L.nextIf('?')) {
         value = OptionalType::create(value);
       } else {
         break;
@@ -187,7 +188,7 @@ struct SchemaParser {
     c10::optional<IValue> default_value;
     c10::optional<std::string> alias_set;
     std::string name;
-    if(L.nextIf('[')) {
+    if (L.nextIf('[')) {
       // note: an array with a size hint can only occur at the Argument level
       type = ListType::create(type);
       N = std::stoll(L.expect(TK_NUMBER).text());
@@ -198,16 +199,16 @@ struct SchemaParser {
       }
       alias_info = std::move(container);
     }
-    if(is_return) {
+    if (is_return) {
       // optionally named return values
-      if(L.cur().kind == TK_IDENT) {
+      if (L.cur().kind == TK_IDENT) {
         name = L.next().text();
       } else {
         name = "ret" + std::to_string(idx);
       }
     } else {
       name = L.expect(TK_IDENT).text();
-      if(L.nextIf('=')) {
+      if (L.nextIf('=')) {
         default_value = parseDefaultValue(type, N);
       }
     }
@@ -220,7 +221,7 @@ struct SchemaParser {
         std::move(alias_info));
   }
   IValue parseSingleConstant(TypeKind kind) {
-    switch(L.cur().kind) {
+    switch (L.cur().kind) {
       case TK_TRUE:
         L.next();
         return true;
@@ -237,11 +238,11 @@ struct SchemaParser {
       case TK_IDENT: {
         auto tok = L.next();
         auto text = tok.text();
-        if("float" == text) {
+        if ("float" == text) {
           return static_cast<int64_t>(at::kFloat);
-        } else if("strided" == text) {
+        } else if ("strided" == text) {
           return static_cast<int64_t>(at::kStrided);
-        } else if("Mean" == text) {
+        } else if ("Mean" == text) {
           return static_cast<int64_t>(Reduction::Mean);
         } else {
           throw ErrorReport(L.cur().range) << "invalid numeric default value";
@@ -249,11 +250,12 @@ struct SchemaParser {
       }
       default:
         std::string n;
-        if(L.nextIf('-'))
+        if (L.nextIf('-'))
           n = "-" + L.expect(TK_NUMBER).text();
         else
           n = L.expect(TK_NUMBER).text();
-        if(kind == TypeKind::FloatType || n.find('.') != std::string::npos || n.find('e') != std::string::npos) {
+        if (kind == TypeKind::FloatType || n.find('.') != std::string::npos ||
+            n.find('e') != std::string::npos) {
           return std::stod(n);
         } else {
           int64_t v = std::stoll(n);
@@ -261,31 +263,29 @@ struct SchemaParser {
         }
     }
   }
-  IValue convertToList(TypeKind kind, const SourceRange& range, std::vector<IValue> vs) {
-    switch(kind) {
-        case TypeKind::FloatType:
-          return fmap(vs, [](IValue v) {
-            return v.toDouble();
-          });
-        case TypeKind::IntType:
-          return fmap(vs, [](IValue v) {
-            return v.toInt();
-          });
-        case TypeKind::BoolType:
-          return fmap(vs, [](IValue v) {
-            return v.toBool();
-          });
-        default:
-          throw ErrorReport(range) << "lists are only supported for float or int types.";
-      }
+  IValue convertToList(
+      TypeKind kind,
+      const SourceRange& range,
+      std::vector<IValue> vs) {
+    switch (kind) {
+      case TypeKind::FloatType:
+        return fmap(vs, [](IValue v) { return v.toDouble(); });
+      case TypeKind::IntType:
+        return fmap(vs, [](IValue v) { return v.toInt(); });
+      case TypeKind::BoolType:
+        return fmap(vs, [](IValue v) { return v.toBool(); });
+      default:
+        throw ErrorReport(range)
+            << "lists are only supported for float or int types.";
+    }
   }
   IValue parseConstantList(TypeKind kind) {
     auto tok = L.expect('[');
     std::vector<IValue> vs;
-    if(L.cur().kind != ']') {
+    if (L.cur().kind != ']') {
       do {
         vs.push_back(parseSingleConstant(kind));
-      } while(L.nextIf(','));
+      } while (L.nextIf(','));
     }
     L.expect(']');
     return convertToList(kind, tok.range, std::move(vs));
@@ -295,13 +295,15 @@ struct SchemaParser {
     L.expect(TK_NONE);
     return IValue();
   }
-  IValue parseDefaultValue(const TypePtr& arg_type, c10::optional<int32_t> arg_N) {
+  IValue parseDefaultValue(
+      const TypePtr& arg_type,
+      c10::optional<int32_t> arg_N) {
     auto range = L.cur().range;
-    switch(arg_type->kind()) {
+    switch (arg_type->kind()) {
       case TypeKind::DynamicType:
       case TypeKind::GeneratorType: {
         return parseTensorDefault(range);
-      }  break;
+      } break;
       case TypeKind::StringType:
       case TypeKind::OptionalType:
       case TypeKind::NumberType:
@@ -311,15 +313,16 @@ struct SchemaParser {
         return parseSingleConstant(arg_type->kind());
         break;
       case TypeKind::DeviceObjType: {
-        auto device_text = parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
+        auto device_text =
+            parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
         return c10::Device(device_text);
         break;
       }
       case TypeKind::ListType: {
         auto elem_kind = arg_type->cast<ListType>()->getElementType();
-        if(L.cur().kind == TK_IDENT) {
+        if (L.cur().kind == TK_IDENT) {
           return parseTensorDefault(range);
-        } else if(arg_N && L.cur().kind != '[') {
+        } else if (arg_N && L.cur().kind != '[') {
           IValue v = parseSingleConstant(elem_kind->kind());
           std::vector<IValue> repeated(*arg_N, v);
           return convertToList(elem_kind->kind(), range, repeated);
@@ -333,7 +336,11 @@ struct SchemaParser {
     return IValue(); // silence warnings
   }
 
-  void parseList(int begin, int sep, int end, const std::function<void()>& callback) {
+  void parseList(
+      int begin,
+      int sep,
+      int end,
+      const std::function<void()>& callback) {
     auto r = L.cur().range;
     if (begin != TK_NOTHING)
       L.expect(begin);
@@ -352,28 +359,32 @@ struct SchemaParser {
 } // namespace script
 
 namespace {
-using OperatorMap = std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
-struct OperatorRegistry  {
-private:
+using OperatorMap =
+    std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
+struct OperatorRegistry {
+ private:
   std::mutex lock;
   OperatorMap operators;
   // list of operators whose schema have not yet been parsed, and must
   // be registered before any call to lookup an opeator
   std::vector<std::shared_ptr<Operator>> to_register;
-  // Those two maps are used to implement lookupByLiteral, which is needed for the n->match(...) calls.
-  // Basically, every function schema is assigned a unique string you can use to match it. However,
-  // parsing those strings or comparing and hashing them character by character would be very slow, so
-  // we use a trick here! Every string literal in your program is guaranteed to have static storage
-  // duration and so its address won't change at runtime. This allows us to memoize answers for every
-  // pointer, which is done by the operators_by_sig_literal map. Still, this map is initially
-  // empty, and so we still need to do the complete string matching at the first time, which is implemented
-  // by performing a lookup in the operators_by_sig map.
+  // Those two maps are used to implement lookupByLiteral, which is needed for
+  // the n->match(...) calls. Basically, every function schema is assigned a
+  // unique string you can use to match it. However, parsing those strings or
+  // comparing and hashing them character by character would be very slow, so we
+  // use a trick here! Every string literal in your program is guaranteed to
+  // have static storage duration and so its address won't change at runtime.
+  // This allows us to memoize answers for every pointer, which is done by the
+  // operators_by_sig_literal map. Still, this map is initially empty, and so we
+  // still need to do the complete string matching at the first time, which is
+  // implemented by performing a lookup in the operators_by_sig map.
   std::unordered_map<std::string, std::shared_ptr<Operator>> operators_by_sig;
-  std::unordered_map<const char *, std::shared_ptr<Operator>> operators_by_sig_literal;
+  std::unordered_map<const char*, std::shared_ptr<Operator>>
+      operators_by_sig_literal;
 
   // XXX - caller must be holding lock
   void registerPendingOperators() {
-    for(const auto& op : to_register) {
+    for (const auto& op : to_register) {
       Symbol sym = Symbol::fromQualString(op->schema().name());
       operators[sym].push_back(op);
       operators_by_sig[canonicalSchemaString(op->schema())] = op;
@@ -381,18 +392,19 @@ private:
     to_register.clear();
   }
 
-public:
+ public:
   void registerOperator(Operator&& op) {
     std::lock_guard<std::mutex> guard(lock);
     to_register.push_back(std::make_shared<Operator>(std::move(op)));
   }
 
-  const std::shared_ptr<Operator>& lookupByLiteral(const char * name) {
+  const std::shared_ptr<Operator>& lookupByLiteral(const char* name) {
     std::lock_guard<std::mutex> guard(lock);
     registerPendingOperators();
     auto it = operators_by_sig_literal.find(name);
     if (it == operators_by_sig_literal.end()) {
-      auto op_ptr_it = operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
+      auto op_ptr_it =
+          operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
       // Handy debugging code that dumps all operators we know about on mismatch
 #if 0
       if (op_ptr_it == operators_by_sig.end()) {
@@ -401,19 +413,21 @@ public:
         }
       }
 #endif
-      JIT_ASSERTM(op_ptr_it != operators_by_sig.end(), "Couldn't find an operator for ", name);
+      JIT_ASSERTM(
+          op_ptr_it != operators_by_sig.end(),
+          "Couldn't find an operator for ",
+          name);
       it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second);
     }
     return it->second;
   }
 
-
   const std::vector<std::shared_ptr<Operator>>& getOperators(Symbol name) {
     std::lock_guard<std::mutex> guard(lock);
     registerPendingOperators();
     static std::vector<std::shared_ptr<Operator>> empty;
     auto it = operators.find(name);
-    if(it != operators.end())
+    if (it != operators.end())
       return it->second;
     return empty;
   }
@@ -427,7 +441,7 @@ OperatorRegistry& getRegistry() {
 } // anonymous namespace
 
 void registerOperator(Operator&& op) {
-  if(op.schema().is_varret()) {
+  if (op.schema().is_varret()) {
     Symbol s = Symbol::fromQualString(op.schema().name());
     if (!printerHasSpecialCaseFor(s)) {
       std::cout << c10::str(
@@ -444,7 +458,7 @@ const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name) {
   return getRegistry().getOperators(name);
 }
 
-Operator& sig(const char *signature) {
+Operator& sig(const charsignature) {
   return *getRegistry().lookupByLiteral(signature);
 }
 
@@ -459,13 +473,14 @@ std::string canonicalSchemaString(const FunctionSchema& schema) {
   out << "(";
 
   bool seen_kwarg_only = false;
-  for(size_t i = 0; i < schema.arguments().size(); ++i) {
-    if (i > 0) out << ", ";
+  for (size_t i = 0; i < schema.arguments().size(); ++i) {
+    if (i > 0)
+      out << ", ";
     if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
       out << "*, ";
       seen_kwarg_only = true;
     }
-    const auto & arg = schema.arguments()[i];
+    const auto& arg = schema.arguments()[i];
     out << arg.type()->str() << " " << arg.name();
   }
 
@@ -475,7 +490,8 @@ std::string canonicalSchemaString(const FunctionSchema& schema) {
   } else if (schema.returns().size() > 1) {
     out << "(";
     for (size_t i = 0; i < schema.returns().size(); ++i) {
-      if (i > 0) out << ", ";
+      if (i > 0)
+        out << ", ";
       out << schema.returns()[i].type()->str();
     }
     out << ")";
@@ -492,12 +508,11 @@ bool Operator::matches(const Node* node) const {
   const auto& formals = schema().arguments();
 
   // not enough inputs
-  if(actuals.size() < formals.size())
+  if (actuals.size() < formals.size())
     return false;
 
-
   TypeEnv type_env;
-  for(size_t i = 0; i < formals.size(); ++i) {
+  for (size_t i = 0; i < formals.size(); ++i) {
     const MatchTypeReturn matched_type =
         matchTypeVariables(formals[i].type(), actuals[i]->type(), type_env);
     if (!matched_type.type) {
@@ -510,8 +525,9 @@ bool Operator::matches(const Node* node) const {
   }
 
   // too many inputs
-  if(!schema().is_vararg() && actuals.size() != formals.size()) {
-    // std::cout << "not all inputs used\n" << input_i << " " << inputs_size << "\n";
+  if (!schema().is_vararg() && actuals.size() != formals.size()) {
+    // std::cout << "not all inputs used\n" << input_i << " " << inputs_size <<
+    // "\n";
     return false;
   }
 
@@ -520,8 +536,8 @@ bool Operator::matches(const Node* node) const {
 
 std::shared_ptr<Operator> findOperatorFor(const Node* node) {
   const auto& candidates = getAllOperatorsFor(node->kind());
-  for(const auto& candidate : candidates) {
-    if(candidate->matches(node)) {
+  for (const auto& candidate : candidates) {
+    if (candidate->matches(node)) {
       return candidate;
     }
   }
@@ -530,42 +546,41 @@ std::shared_ptr<Operator> findOperatorFor(const Node* node) {
 
 const Operator& getOperatorFor(const Node* node) {
   auto op = findOperatorFor(node);
-  if(op)
+  if (op)
     return *op;
 
   auto er = script::ErrorReport(node->getSourceLocation());
   er << "Schema not found for node. File a bug report.\n";
   er << "Node: " << *node << "\n";
   er << "Input types:";
-  for(size_t i = 0; i < node->inputs().size(); ++i) {
-    if(i > 0)
+  for (size_t i = 0; i < node->inputs().size(); ++i) {
+    if (i > 0)
       er << ", ";
     er << *node->inputs()[i]->type();
   }
   er << "\ncandidates were:\n";
   const auto& candidates = getAllOperatorsFor(node->kind());
-  for(auto & candidate : candidates) {
+  for (auto& candidate : candidates) {
     er << "  " << candidate->schema() << "\n";
   }
   er << *node->owningGraph() << "\n";
   throw er;
 }
 
-
-OperatorSet::OperatorSet(std::initializer_list<const char *> sig_literals) {
-  auto & registry = getRegistry();
-  for (const char * sig : sig_literals) {
+OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
+  auto& registry = getRegistry();
+  for (const char* sig : sig_literals) {
     auto op = registry.lookupByLiteral(sig);
     ops[Symbol::fromQualString(op->schema().name())].push_back(op);
   }
 }
 
-Operator* OperatorSet::find(const Node *n) const  {
+Operator* OperatorSet::find(const Node* n) const {
   auto it = ops.find(n->kind());
   if (it == ops.end()) {
     return nullptr;
   }
-  for (auto & op : it->second) {
+  for (auto& op : it->second) {
     if (op->matches(n)) {
       return op.get();
     }
@@ -573,4 +588,5 @@ Operator* OperatorSet::find(const Node *n) const  {
   return nullptr;
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 5d66f5a..9265fe2 100644 (file)
@@ -4,8 +4,8 @@
 #pragma once
 
 #include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/function_schema.h>
+#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/stack.h>
 
 #include <ATen/ATen.h>
@@ -18,7 +18,8 @@
 #include <utility>
 #include <vector>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 TORCH_API FunctionSchema parseSchema(const std::string& schema);
 
@@ -38,7 +39,14 @@ struct TORCH_API Operator {
   // arguments. This is used for things like prim::While or prim::If that can
   // take a number of different valid input types and lengths.
   Operator(Symbol name, OperationCreator op_creator)
-      : Operator(FunctionSchema(name, {}, {}, /*is_vararg*/true, /*is_varret*/true), std::move(op_creator)) {}
+      : Operator(
+            FunctionSchema(
+                name,
+                {},
+                {},
+                /*is_vararg*/ true,
+                /*is_varret*/ true),
+            std::move(op_creator)) {}
 
   Operator(FunctionSchema schema, Operation op)
       : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
@@ -58,52 +66,56 @@ struct TORCH_API Operator {
     return op_creator_(node);
   }
 
-  const FunctionSchema & schema() const {
+  const FunctionSchema& schema() const {
     // we lazily parse schema initialized from strings so that
     // we do less work during static operator registration
-    if(!schema_) {
-      schema_ = std::make_shared<FunctionSchema>(parseSchema(schema_string_.value()));
+    if (!schema_) {
+      schema_ =
+          std::make_shared<FunctionSchema>(parseSchema(schema_string_.value()));
       schema_string_ = c10::nullopt;
     }
     return *schema_;
   }
-private:
- mutable c10::optional<std::string> schema_string_;
- // cannot use c10::optional because windows has issues that require an
- // assignment operator to be generated cannot use std::unique_ptr because
- // initializer lists of Operators end up copying the Operator
- mutable std::shared_ptr<FunctionSchema> schema_;
-
- // Essentially a variant<Operation, OperationCreator>.
- // NB: std::function has a default state (where it == nullptr).
- std::shared_ptr<Operation> op_;
- OperationCreator op_creator_;
+
+ private:
+  mutable c10::optional<std::string> schema_string_;
+  // cannot use c10::optional because windows has issues that require an
+  // assignment operator to be generated cannot use std::unique_ptr because
+  // initializer lists of Operators end up copying the Operator
+  mutable std::shared_ptr<FunctionSchema> schema_;
+
+  // Essentially a variant<Operation, OperationCreator>.
+  // NB: std::function has a default state (where it == nullptr).
+  std::shared_ptr<Operation> op_;
+  OperationCreator op_creator_;
 };
 
 TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema);
 
-TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name);
+TORCH_API const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
+    Symbol name);
 std::shared_ptr<Operator> findOperatorFor(const Node* node);
 const Operator& getOperatorFor(const Node* node);
 
 inline Operation getOperation(const Node* node) {
-  // note: getOperatorFor ensures that getOperatorFor(node).matches(node) == true
-  // so the call to selectVariant is always valid.
+  // note: getOperatorFor ensures that getOperatorFor(node).matches(node) ==
+  // true so the call to selectVariant is always valid.
   return getOperatorFor(node).getOperation(node);
 }
 
 TORCH_API void registerOperator(Operator&& op);
 
 // XXX: this function is meant to be used with string literals only!
-Operator& sig(const char *signature_literal);
+Operator& sig(const charsignature_literal);
 
 struct OperatorSet {
-  OperatorSet(std::initializer_list<const char *> sig_literals);
+  OperatorSet(std::initializer_list<const char*> sig_literals);
   // XXX: Returns a nullptr if no Operator in the set matches n
-  Operator* find(const Node *n) const;
-private:
+  Operator* find(const Node* n) const;
+
+ private:
   std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;
 };
 
-
-}}
+} // namespace jit
+} // namespace torch
index 1db977c..2c99d89 100644 (file)
@@ -1,28 +1,29 @@
 #include <torch/csrc/jit/passes/batch_mm.h>
 
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/custom_operator.h>
+#include <torch/csrc/jit/interned_strings.h>
+#include <torch/csrc/jit/passes/alias_analysis.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/peephole.h>
-#include <torch/csrc/jit/passes/alias_analysis.h>
-#include <torch/csrc/jit/interned_strings.h>
-#include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/symbolic_variable.h>
-#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/utils/functional.h>
 
 #include <ATen/ATen.h>
 #include <algorithm>
 #include <unordered_map>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-// This pass looks for trees in the graph, where leaves are mm ops, and the inner
-// vertices are add nodes. Once we have such a tree they can be reduced to two
-// concats and a single mm (basically into a single multiply of a wide matrix, with
-// a tall matrix).
-// Such patterns show up mostly in backward of RNNs, since the derivative of many
-// uses of matrix multiplies with same weights forms exactly such a tree
-// (note that it's usually also highly imbalanced i.e. has O(n) depth).
+// This pass looks for trees in the graph, where leaves are mm ops, and the
+// inner vertices are add nodes. Once we have such a tree they can be reduced to
+// two concats and a single mm (basically into a single multiply of a wide
+// matrix, with a tall matrix). Such patterns show up mostly in backward of
+// RNNs, since the derivative of many uses of matrix multiplies with same
+// weights forms exactly such a tree (note that it's usually also highly
+// imbalanced i.e. has O(n) depth).
 //
 // This (or any tree of adds of MMs):
 //
@@ -51,34 +52,35 @@ namespace torch { namespace jit {
 // +------+------+ +------+
 
 // Note [Further optimizations]
-// It would be straightforward to extend the TreeToken class to also detect if all
-// MMs had the same lhs/rhs. In such case it's more efficient to expand the lhs
-// and use bmm + sum instead of repeating it in memory via concat.
+// It would be straightforward to extend the TreeToken class to also detect if
+// all MMs had the same lhs/rhs. In such case it's more efficient to expand the
+// lhs and use bmm + sum instead of repeating it in memory via concat.
 
 // Note [Overlapping trees]
 // Additionally it wouldn't be too hard to add support for partially overlapping
 // trees. Right now the it's forbidden in the algorithm (only a single tree will
-// be allowed), so theoretically we might miss some optimization options, especially
-// that the rejected tree could be much larger. I didn't implement that because it's
-// not necessary for the simple RNN cases I saw, so I decided to keep stuff simple.
-// If we ever get around implementing this, the right solution is probably to fuse
-// MMs for the common part, and assume it's an input leaf for the outer two parts
-// (I don't think it's beneficial to recompute, unless the subtree is super small,
-// but let's not get into such details).
+// be allowed), so theoretically we might miss some optimization options,
+// especially that the rejected tree could be much larger. I didn't implement
+// that because it's not necessary for the simple RNN cases I saw, so I decided
+// to keep stuff simple. If we ever get around implementing this, the right
+// solution is probably to fuse MMs for the common part, and assume it's an
+// input leaf for the outer two parts (I don't think it's beneficial to
+// recompute, unless the subtree is super small, but let's not get into such
+// details).
 
 // The algorithm we're using is simple. We're iterating through the graph in the
-// topological order and labeling nodes with TreeTokens. Then, we look for roots of
-// the trees we formed and fuse them.
+// topological order and labeling nodes with TreeTokens. Then, we look for roots
+// of the trees we formed and fuse them.
 
 // Tunable parameter. Set to something larger if it turns out to be better.
 static constexpr size_t min_fusion_size = 4;
 
 bool have_same_shape(at::TensorList inputs) {
   auto expected_sizes = inputs[0].sizes();
-  return std::all_of(inputs.begin(), inputs.end(),
-                     [expected_sizes](const at::Tensor& t) {
-                       return t.sizes() == expected_sizes;
-                     });
+  return std::all_of(
+      inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
+        return t.sizes() == expected_sizes;
+      });
 }
 
 bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
@@ -89,10 +91,8 @@ bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
   return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
 }
 
-RegisterOperators mm_tree_reduction_reg({
-  Operator(
-    prim::MMTreeReduce,
-    [](const Node* node) {
+RegisterOperators mm_tree_reduction_reg(
+    {Operator(prim::MMTreeReduce, [](const Node* node) {
       size_t num_inputs = node->inputs().size();
       return [num_inputs](Stack& stack) {
         std::vector<at::Tensor> inputs;
@@ -107,8 +107,10 @@ RegisterOperators mm_tree_reduction_reg({
         size_t side_num_elems = inputs.size() / 2;
         auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
         auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
-        // TODO: checking this is not free, so we should stop if this keeps failing
-        if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) && shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
+        // TODO: checking this is not free, so we should stop if this keeps
+        // failing
+        if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
+            shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
           auto lhs = at::cat(lhs_inputs, /*dim=*/1);
           auto rhs = at::cat(rhs_inputs, /*dim=*/0);
           push(stack, at::mm(lhs, rhs));
@@ -121,8 +123,7 @@ RegisterOperators mm_tree_reduction_reg({
         }
         return 0;
       };
-    })
-});
+    })});
 
 // TreeTokens will be used to label nodes of the graph, if the nodes will fit
 // our mm/add tree pattern. Basically we do dynamic programming on DAGs, where
@@ -131,10 +132,10 @@ RegisterOperators mm_tree_reduction_reg({
 // and build a larger tree.
 struct TreeToken {
   uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
-  Node *node = nullptr;
+  Nodenode = nullptr;
   bool is_root = false;
 
-  static TreeToken mm(Node *mm) {
+  static TreeToken mm(Nodemm) {
     TreeToken token;
     token.tree_size = 1;
     token.node = mm;
@@ -142,10 +143,12 @@ struct TreeToken {
     return token;
   }
 
-  // NB: the returned token might be invalid, so make sure to check its boolean value!
-  static TreeToken transpose(Node *t, TreeToken& inp_token) {
+  // NB: the returned token might be invalid, so make sure to check its boolean
+  // value!
+  static TreeToken transpose(Node* t, TreeToken& inp_token) {
     TreeToken token;
-    if (!inp_token.node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+    if (!inp_token.node->matches(
+            "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
       return token;
     }
     token.tree_size = 1;
@@ -155,8 +158,9 @@ struct TreeToken {
     return token;
   }
 
-  // NB: the returned token might be invalid, so make sure to check its boolean value!
-  static TreeToken add(Node *add, TreeToken& l, TreeToken& r) {
+  // NB: the returned token might be invalid, so make sure to check its boolean
+  // value!
+  static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
     TreeToken token;
     // See Note [Overlapping trees]
     if (&l == &r || !l.is_root || !r.is_root)
@@ -164,7 +168,8 @@ struct TreeToken {
     token.tree_size = l.tree_size + r.tree_size;
     token.node = add;
     token.is_root = true;
-    l.is_root = r.is_root = false; // Reserve the subtrees, so they can't be used again.
+    l.is_root = r.is_root =
+        false; // Reserve the subtrees, so they can't be used again.
     return token;
   }
 
@@ -174,25 +179,29 @@ struct TreeToken {
 
   std::vector<Node*> removeTransposesAndGatherMatmuls() {
     std::vector<Node*> matmuls;
-    std::vector<Node*> queue {node};
+    std::vector<Node*> queue{node};
     Graph* graph = node->owningGraph();
     while (!queue.empty()) {
-      auto n = queue.back(); queue.pop_back();
+      auto n = queue.back();
+      queue.pop_back();
       if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
         matmuls.push_back(n);
       } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
-        Node * input_node = n->input()->node();
-        JIT_ASSERT(input_node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor"));
+        Node* input_node = n->input()->node();
+        JIT_ASSERT(input_node->matches(
+            "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
         // (AB)^T == B^TA^T
-        WithInsertPoint insert_guard { input_node };
-        Value * A = input_node->inputs()[0];
-        Value * B = input_node->inputs()[1];
-        Value * AT = graph->insert(aten::t, {A});
-        Value * BT = graph->insert(aten::t, {B});
-        Value * BTAT = graph->insert(aten::mm, {BT, AT});
+        WithInsertPoint insert_guard{input_node};
+        Value* A = input_node->inputs()[0];
+        Value* B = input_node->inputs()[1];
+        Value* AT = graph->insert(aten::t, {A});
+        Value* BT = graph->insert(aten::t, {B});
+        Value* BTAT = graph->insert(aten::mm, {BT, AT});
         n->output()->replaceAllUsesWith(BTAT);
         matmuls.push_back(BTAT->node());
-      } else if (n->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
+      } else if (
+          n->matches(
+              "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
         queue.push_back(n->inputs()[0]->node());
         queue.push_back(n->inputs()[1]->node());
       } else {
@@ -218,9 +227,11 @@ void BatchMMTreeReduce(Block* block) {
       if (input_it != tokens.end()) {
         tokens[node] = TreeToken::transpose(node, input_it->second);
       }
-    } else if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
-      Node *lhs = node->inputs()[0]->node();
-      Node *rhs = node->inputs()[1]->node();
+    } else if (
+        node->matches(
+            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
+      Node* lhs = node->inputs()[0]->node();
+      Node* rhs = node->inputs()[1]->node();
       auto lhs_it = tokens.find(lhs);
       auto rhs_it = tokens.find(rhs);
       // See Note [Overlapping trees] (regarding the uses().size() == 1 check)
@@ -228,9 +239,11 @@ void BatchMMTreeReduce(Block* block) {
       // XXX: uses().size() == 1 is also something that guarantees that this
       // transform is valid, because we know for sure that the none of these
       // operands depend on the result of the other. If we were to remove this,
-      // we need to compute a transitive closure and actually check the dependencies.
+      // we need to compute a transitive closure and actually check the
+      // dependencies.
       if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
-          lhs->output()->uses().size() == 1 && rhs->output()->uses().size() == 1) {
+          lhs->output()->uses().size() == 1 &&
+          rhs->output()->uses().size() == 1) {
         if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
           tokens[node] = token;
         }
@@ -243,17 +256,18 @@ void BatchMMTreeReduce(Block* block) {
   }
 
   // Merge trees we've found
-  for (auto & item : tokens) {
-    auto & root = item.second;
+  for (auto& item : tokens) {
+    auto& root = item.second;
     if (!root || root.tree_size < min_fusion_size)
       continue;
     auto matmuls = root.removeTransposesAndGatherMatmuls();
-    WithInsertPoint insert_guard {root.node};
-    Node * tree_reduce = graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
-    for (Node * matmul : matmuls) {
+    WithInsertPoint insert_guard{root.node};
+    Node* tree_reduce =
+        graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
+    for (Node* matmul : matmuls) {
       tree_reduce->addInput(matmul->inputs().at(0));
     }
-    for (Node * matmul : matmuls) {
+    for (Node* matmul : matmuls) {
       tree_reduce->addInput(matmul->inputs().at(1));
     }
     root.node->output()->replaceAllUsesWith(tree_reduce->output());
@@ -266,29 +280,37 @@ bool shape_is_fast_for_side(const at::Tensor& other_side_input) {
   return other_side_input.numel() <= 1024 * 2048;
 }
 
-RegisterOperators mm_batch_side_reg({
-  Operator(
-    prim::MMBatchSide,
-    [](const Node* node) {
+RegisterOperators mm_batch_side_reg(
+    {Operator(prim::MMBatchSide, [](const Node* node) {
       size_t num_other_side_inputs = node->inputs().size() - 1;
       Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
       return [num_other_side_inputs, single_side](Stack& stack) {
         at::Tensor side_input;
         std::vector<at::Tensor> other_side_inputs;
         other_side_inputs.reserve(num_other_side_inputs);
-        for (auto it = stack.end() - num_other_side_inputs; it != stack.end(); ++it) {
+        for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
+             ++it) {
           other_side_inputs.push_back(std::move(*it).toTensor());
         }
         drop(stack, num_other_side_inputs);
         pop(stack, side_input);
 
         auto any_other_input = other_side_inputs[0];
-        if (have_same_shape(other_side_inputs) && shape_is_fast_for_side(other_side_inputs[0])) {
-          auto other_side_input = at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
-          auto mm_out = single_side == Side::LHS ? side_input.mm(other_side_input) : other_side_input.mm(side_input);
-          auto outputs = at::chunk(mm_out, num_other_side_inputs, /*dim=*/single_side == Side::LHS ? 1 : 0);
-          stack.insert(stack.end(), std::make_move_iterator(outputs.begin()),
-                                    std::make_move_iterator(outputs.end()));
+        if (have_same_shape(other_side_inputs) &&
+            shape_is_fast_for_side(other_side_inputs[0])) {
+          auto other_side_input =
+              at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
+          auto mm_out = single_side == Side::LHS
+              ? side_input.mm(other_side_input)
+              : other_side_input.mm(side_input);
+          auto outputs = at::chunk(
+              mm_out,
+              num_other_side_inputs,
+              /*dim=*/single_side == Side::LHS ? 1 : 0);
+          stack.insert(
+              stack.end(),
+              std::make_move_iterator(outputs.begin()),
+              std::make_move_iterator(outputs.end()));
         } else {
           if (single_side == Side::LHS) {
             for (at::Tensor& other : other_side_inputs) {
@@ -303,32 +325,36 @@ RegisterOperators mm_batch_side_reg({
 
         return 0;
       };
-    })
-});
+    })});
 
-std::pair<std::vector<Node*>, std::vector<Node*>>
-gatherIndependentMMUses(Value *value, const AliasDb& alias_db) {
+std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
+    Value* value,
+    const AliasDb& alias_db) {
   const auto postprocess = [&](std::vector<Node*> mms) {
     if (mms.size() == 0) {
       return mms;
     }
-    std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) { return n->isBefore(m); });
-    // Filter out dependent MMs. This algorithm might do very badly if e.g. you have
-    // a lot of independent MMs, that depend on the first one, but I doubt this will
-    // be a common scenario.
+    std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) {
+      return n->isBefore(m);
+    });
+    // Filter out dependent MMs. This algorithm might do very badly if e.g. you
+    // have a lot of independent MMs, that depend on the first one, but I doubt
+    // this will be a common scenario.
     for (size_t i = 0; i < mms.size(); ++i) {
-      if (mms[i] == nullptr) continue;
+      if (mms[i] == nullptr)
+        continue;
       for (size_t j = i + 1; j < mms.size(); ++j) {
-        if (mms[j] == nullptr) continue;
+        if (mms[j] == nullptr)
+          continue;
         if (!mms[j]->couldMoveBeforeTopologically(mms[i], alias_db)) {
           mms[j] = nullptr;
         }
       }
     }
-    return filter(mms, [](Node *n) { return n != nullptr; });
+    return filter(mms, [](Noden) { return n != nullptr; });
   };
 
-  Block * block = value->node()->owningBlock();
+  Block* block = value->node()->owningBlock();
   std::vector<Node*> lhses; // Will contain nodes where value is used as an lhs
   std::vector<Node*> rhses; // Like above, but rhs
   for (Use u : value->uses()) {
@@ -344,7 +370,7 @@ gatherIndependentMMUses(Value *value, const AliasDb& alias_db) {
   return std::make_pair(postprocess(lhses), postprocess(rhses));
 }
 
-void BatchMMSide(Block * block, const AliasDb& alias_db) {
+void BatchMMSide(Block* block, const AliasDb& alias_db) {
   // NB: 8 is the current loop unrolling factor
   static constexpr size_t how_many_is_many = 8;
   const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
@@ -353,10 +379,12 @@ void BatchMMSide(Block * block, const AliasDb& alias_db) {
       bool move_ok = mms[i]->moveBeforeTopologicallyValid(mms[i + 1], alias_db);
       JIT_ASSERT(move_ok);
     }
-    WithInsertPoint insert_guard { mms[0] };
+    WithInsertPoint insert_guard{mms[0]};
     Graph* graph = mms[0]->owningGraph();
-    Node* batch_mm = graph->create(prim::MMBatchSide,
-                                   /*inputs=*/{}, /*num_outputs=*/mms.size());
+    Node* batch_mm = graph->create(
+        prim::MMBatchSide,
+        /*inputs=*/{},
+        /*num_outputs=*/mms.size());
     graph->insertNode(batch_mm);
     batch_mm->i_(Symbol::attr("side"), static_cast<int>(side));
     Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
@@ -368,10 +396,10 @@ void BatchMMSide(Block * block, const AliasDb& alias_db) {
   };
 
   std::unordered_set<Value*> considered_values;
-  for (Node * node : block->nodes()) {
+  for (Node* node : block->nodes()) {
     if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
-      for (Value * input : node->inputs()) {
-        if (/*bool not_inserted = */!considered_values.emplace(input).second) {
+      for (Value* input : node->inputs()) {
+        if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
           continue;
         }
         auto uses_with_many = gatherIndependentMMUses(input, alias_db);
@@ -388,7 +416,6 @@ void BatchMMSide(Block * block, const AliasDb& alias_db) {
       }
     }
   }
-
 }
 
 bool hasMutableOperators(Block* block) {
@@ -412,9 +439,10 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
   BatchMMTreeReduce(graph->block());
   BatchMMSide(graph->block(), alias_db);
   EliminateDeadCode(graph);
-  // It's possible that transpose rearrangements have created sequences of consecutive
-  // transposes that didn't exist before.
+  // It's possible that transpose rearrangements have created sequences of
+  // consecutive transposes that didn't exist before.
   PeepholeOptimize(graph);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index d35461c..a2181bc 100644 (file)
@@ -2,8 +2,10 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 TORCH_API void BatchMM(std::shared_ptr<Graph>& graph);
 
-}}
+}
+} // namespace torch
index 80971e6..7900a52 100644 (file)
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/passes/canonicalize.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // Canonicalize a graph, renumbering it so that all structurally equivalent
 // graphs have same numbers.
@@ -8,14 +9,16 @@ namespace torch { namespace jit {
 //   and replacing them with normal value names.
 //   Otherwise, ignores values with unique names.
 std::shared_ptr<Graph> Canonicalize(
-    const std::shared_ptr<Graph>& graph, bool keep_unique_names) {
+    const std::shared_ptr<Graph>& graph,
+    bool keep_unique_names) {
   auto r = std::make_shared<Graph>(graph->current_scope());
   std::unordered_map<Value*, Value*> rn_env;
   auto rn_fn = [&](Value* v) { return rn_env.at(v); };
   for (auto* input : graph->inputs()) {
     auto* r_input = r->addInput();
     r_input->copyMetadata(input);
-    if (!keep_unique_names) r_input->setUniqueName("");
+    if (!keep_unique_names)
+      r_input->setUniqueName("");
     rn_env[input] = r_input;
   }
   for (auto* node : graph->nodes()) {
@@ -32,7 +35,9 @@ std::shared_ptr<Graph> Canonicalize(
       rn_env[outputs.at(i)] = r_outputs.at(i);
     }
     if (node->hasAttribute(attr::Subgraph)) {
-      r_node->g_(attr::Subgraph, Canonicalize(node->g(attr::Subgraph), keep_unique_names));
+      r_node->g_(
+          attr::Subgraph,
+          Canonicalize(node->g(attr::Subgraph), keep_unique_names));
     }
   }
   for (auto* output : graph->outputs()) {
@@ -40,7 +45,7 @@ std::shared_ptr<Graph> Canonicalize(
   }
 
   return r;
-
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 0d1e1fa..ed3cb0e 100644 (file)
@@ -2,9 +2,12 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 TORCH_API std::shared_ptr<Graph> Canonicalize(
-    const std::shared_ptr<Graph>& graph, bool keep_unique_names=true);
+    const std::shared_ptr<Graph>& graph,
+    bool keep_unique_names = true);
 
-}}
+}
+} // namespace torch
index 2aefaf3..0554b25 100644 (file)
@@ -2,25 +2,28 @@
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/symbolic_variable.h>
 
-
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 struct ChunkOutput {
-  ChunkOutput(Value * v, size_t o)
-    : val(v), offset(o) {};
-  Value * val;
+  ChunkOutput(Value* v, size_t o) : val(v), offset(o){};
+  Value* val;
   size_t offset;
 };
 
 static c10::optional<std::vector<ChunkOutput>> getChunkOutputs(Node* chunk) {
   std::vector<ChunkOutput> outputs;
   for (auto list_use : chunk->output()->uses()) {
-    if (list_use.user->matches("aten::select(Tensor[] list, int idx) -> Tensor", attr::b)) {
-      outputs.emplace_back(list_use.user->output(),
-                            list_use.user->get<int64_t>(attr::b).value());
+    if (list_use.user->matches(
+            "aten::select(Tensor[] list, int idx) -> Tensor", attr::b)) {
+      outputs.emplace_back(
+          list_use.user->output(),
+          list_use.user->get<int64_t>(attr::b).value());
     } else if (list_use.user->kind() == prim::ListUnpack) {
-      // This sometimes happens if the sizes can't be evenly divided by the number of chunks
-      if (static_cast<int64_t>(list_use.user->outputs().size()) != chunk->get<int64_t>(attr::chunks).value()) {
+      // This sometimes happens if the sizes can't be evenly divided by the
+      // number of chunks
+      if (static_cast<int64_t>(list_use.user->outputs().size()) !=
+          chunk->get<int64_t>(attr::chunks).value()) {
         return c10::nullopt;
       }
       auto unpack_outputs = list_use.user->outputs();
@@ -44,8 +47,9 @@ static void CanonicalizeOps(Block* block) {
     // followed by an add so that it can go through the existing optimization,
     // shape analysis and differentiation passes for those two individual ops.
     // Later, we will fuse together those two ops into a single addmm.
-    if (it->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
-                    /*const_inputs=*/{attr::beta, attr::alpha})) {
+    if (it->matches(
+            "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
+            /*const_inputs=*/{attr::beta, attr::alpha})) {
       if (it->get<at::Scalar>(attr::alpha)->toDouble() != 1.0 ||
           it->get<at::Scalar>(attr::beta)->toDouble() != 1.0) {
         continue;
@@ -58,36 +62,43 @@ static void CanonicalizeOps(Block* block) {
       SymbolicVariable mat2(it->inputs()[2]);
 
       auto mm_result = mat1.mm(mat2);
-      // Set this intermediate aten::mm node to have the same output type as the original aten::addmm
-      // otherwise the canonicalized graph will have DynamicType as the output of this node which is incorrect
+      // Set this intermediate aten::mm node to have the same output type as the
+      // original aten::addmm otherwise the canonicalized graph will have
+      // DynamicType as the output of this node which is incorrect
       (static_cast<Value*>(mm_result))->setType(it->output()->type());
       auto result = mat + mm_result;
       (static_cast<Value*>(result))->setType(it->output()->type());
 
       it->output()->replaceAllUsesWith(result);
       it.destroyCurrent();
-    } else if (it->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
-               it->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
-               it->matches("aten::mul(Tensor self, Tensor other) -> Tensor") ||
-               it->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
+    } else if (
+        it->matches(
+            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
+        it->matches(
+            "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
+        it->matches("aten::mul(Tensor self, Tensor other) -> Tensor") ||
+        it->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
       if (auto other = it->get<at::Tensor>(attr::other)) {
         if (other->dim() == 0) {
-          WithInsertPoint insert_guard {*it};
+          WithInsertPoint insert_guard{*it};
           auto graph = it->owningGraph();
           auto new_other = graph->insertConstant(other->item());
           std::vector<Value*> inputs = it->inputs().vec();
           inputs.at(1) = new_other;
-          Value * new_output = graph->insertNode(graph->create(it->kind(), inputs))->output();
+          Value* new_output =
+              graph->insertNode(graph->create(it->kind(), inputs))->output();
           it->output()->replaceAllUsesWith(new_output);
         }
       }
-    } else if (it->matches("aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
-                           /*const_inputs=*/{attr::chunks, attr::dim})) {
+    } else if (it->matches(
+                   "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
+                   /*const_inputs=*/{attr::chunks, attr::dim})) {
       if (auto orig_outputs = getChunkOutputs(*it)) {
         WithInsertPoint guard(*it);
-        SymbolicVariable self {it->namedInput(attr::self)};
-        auto outputs = self.chunk(it->get<int64_t>(attr::chunks).value(),
-                                  it->get<int64_t>(attr::dim).value());
+        SymbolicVariable self{it->namedInput(attr::self)};
+        auto outputs = self.chunk(
+            it->get<int64_t>(attr::chunks).value(),
+            it->get<int64_t>(attr::dim).value());
         for (ChunkOutput orig_out : *orig_outputs) {
           orig_out.val->replaceAllUsesWith(outputs.at(orig_out.offset));
           outputs[orig_out.offset].value()->setType(orig_out.val->type());
@@ -102,5 +113,5 @@ void CanonicalizeOps(const std::shared_ptr<Graph>& graph) {
   EliminateDeadCode(graph);
 }
 
-
-}}
+} // namespace jit
+} // namespace torch
index 21221a7..8dc7cf6 100644 (file)
@@ -2,8 +2,10 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 TORCH_API void CanonicalizeOps(const std::shared_ptr<Graph>& graph);
 
-}}
+}
+} // namespace torch
index cac8f6b..7a87c1f 100644 (file)
@@ -5,9 +5,9 @@
 
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/interned_strings.h>
+#include <torch/csrc/jit/node_hashing.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
-#include <torch/csrc/jit/node_hashing.h>
 #include <torch/csrc/utils/functional.h>
 #include <torch/csrc/utils/hash.h>
 
@@ -21,7 +21,7 @@ void EliminateCommonSubexpression(
     const AliasDb& aliasDb,
     std::function<Node*(Node*)> parent_lookup_fn) {
   std::unordered_set<Node*, HashNode, EqualNode> subexprs;
-  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
     auto node = *it;
     if (node->hasSideEffects() || node->isNondeterministic() ||
         aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
@@ -64,7 +64,7 @@ void EliminateCommonSubexpression(
     }
   }
 }
-}
+} // namespace
 
 void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
   const auto aliasDb = AliasAnalysis(graph);
index 2f81c30..71efb6d 100644 (file)
@@ -2,8 +2,10 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 TORCH_API void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph);
 
-}}
+}
+} // namespace torch
index 54489e3..af6de16 100644 (file)
@@ -1,16 +1,19 @@
-#include <torch/csrc/jit/ir.h>
-#include <unordered_set>
 #include <torch/csrc/jit/interned_strings.h>
-#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/node_hashing.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <unordered_set>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
-//Very similar to the common subexpression elimination pass
-//Move all constants to the beginning of the graph, and deduplicate
-void ConstantPooling(Block * block, std::unordered_set<Node*, HashNode, EqualNode>& constants) {
+// Very similar to the common subexpression elimination pass
+// Move all constants to the beginning of the graph, and deduplicate
+void ConstantPooling(
+    Block* block,
+    std::unordered_set<Node*, HashNode, EqualNode>& constants) {
   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
     auto node = *it;
     // node may be moved to a different block so advance iterator now
@@ -44,10 +47,10 @@ void ConstantPooling(Block * block, std::unordered_set<Node*, HashNode, EqualNod
 
 } // anonymous namespace
 
-
 void ConstantPooling(const std::shared_ptr<Graph>& graph) {
   std::unordered_set<Node*, HashNode, EqualNode> constants;
   ConstantPooling(graph->block(), constants);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index ffa63c9..7af03d6 100644 (file)
@@ -2,8 +2,10 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 TORCH_API void ConstantPooling(const std::shared_ptr<Graph>& graph);
 
-}}
+}
+} // namespace torch
index 6d66dc2..1fa33f6 100644 (file)
@@ -1,4 +1,3 @@
-#include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/interpreter.h>
@@ -6,23 +5,25 @@
 #include <torch/csrc/jit/ivalue.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/utils/functional.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
 std::unordered_set<Symbol> skip_list = {
-  prim::If,
-  prim::Loop, //TODO: handle Loop
-  prim::Constant,
-  prim::Undefined,
-  prim::None, // it is already a constant and propagating it will lose
-              // important type information about which Optional type it is
-  // TODO (zach): we should consider skipping tensor factories in the cases
-  // where the constant tensor would be large but cheap to create.
- };
+    prim::If,
+    prim::Loop, // TODO: handle Loop
+    prim::Constant,
+    prim::Undefined,
+    prim::None, // it is already a constant and propagating it will lose
+                // important type information about which Optional type it is
+    // TODO (zach): we should consider skipping tensor factories in the cases
+    // where the constant tensor would be large but cheap to create.
+};
 
 std::vector<IValue> runNode(Node* n) {
   auto op = getOperation(n);
@@ -34,7 +35,7 @@ std::vector<IValue> runNode(Node* n) {
   auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
     if (v.isTensor()) {
       auto t = std::move(v).toTensor();
-      if(t.defined()) {
+      if (t.defined()) {
         return IValue(autograd::as_variable_ref(t).data());
       } else {
         return t;
@@ -54,7 +55,7 @@ void propagateNode(Node* n) {
     try {
       auto new_output = graph->insertConstant(outputs[i]);
       n->outputs()[i]->replaceAllUsesWith(new_output);
-    } catch(constant_not_supported_error& err) {
+    } catch (constant_not_supported_error& err) {
       // we cannot actually represent the IValue as a constant node,
       // so we give up replacing it
     }
@@ -62,28 +63,29 @@ void propagateNode(Node* n) {
   }
 }
 
-void inlineIf(Block *body, Node * n) {
-  for(auto it = body->nodes().begin(); it != body->nodes().end();) {
-    Node *body_node = *it;
-    //advance iterator because after body_node is moved its next pointer will be
-    //to n
+void inlineIf(Block* body, Node* n) {
+  for (auto it = body->nodes().begin(); it != body->nodes().end();) {
+    Nodebody_node = *it;
+    // advance iterator because after body_node is moved its next pointer will
+    // be to n
     it++;
     body_node->moveBefore(n);
   }
   for (size_t i = 0; i < n->outputs().size(); ++i) {
     n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
   }
-  // NB: destroy the node here, because it might contain side effects, like print
+  // NB: destroy the node here, because it might contain side effects, like
+  // print
   n->destroy();
 }
 
-bool isTrueConstant(Value *val) {
+bool isTrueConstant(Valueval) {
   c10::optional<bool> maybe_value = constant_as<bool>(val);
   JIT_ASSERT(maybe_value);
   return *maybe_value;
 }
 
-void inlineIf(Node *n) {
+void inlineIf(Noden) {
   if (isTrueConstant(n->input())) {
     inlineIf(n->blocks()[0], n);
   } else {
@@ -91,24 +93,24 @@ void inlineIf(Node *n) {
   }
 }
 
-//remove extra outputs from the node
-bool removeExtraNodeOutputs(Node *n) {
+// remove extra outputs from the node
+bool removeExtraNodeOutputs(Noden) {
   JIT_ASSERTM(n->kind() == prim::If, "Only supported for If nodes");
   auto true_block = n->blocks()[0];
   auto false_block = n->blocks()[1];
   auto initial_outputs = true_block->outputs().size();
-  for (size_t i = 0; i < true_block->outputs().size(); ) {
-    //neither block changes the output value
+  for (size_t i = 0; i < true_block->outputs().size();) {
+    // neither block changes the output value
     if (true_block->outputs()[i] == false_block->outputs()[i]) {
       n->outputs().at(i)->replaceAllUsesWith(true_block->outputs()[i]);
       n->eraseOutput(i);
       true_block->eraseOutput(i);
       false_block->eraseOutput(i);
     } else {
-      i++; //increment bc we didn't remove current index
+      i++; // increment bc we didn't remove current index
     }
   }
-  //an output was removed
+  // an output was removed
   return initial_outputs != true_block->outputs().size();
 }
 
@@ -120,47 +122,47 @@ void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) {
         return v->node()->kind() == prim::Constant;
       });
   bool supported_node = !n->kind().is_onnx() &&
-      skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && !n->hasSideEffects() &&
-      !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
+      skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
+      !n->hasSideEffects() && !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
   auto run_blocks = [&]() {
     if (recurse) {
-      for (Block * block : n->blocks()) {
+      for (Block* block : n->blocks()) {
         ConstantPropagation(block, aliasDb, recurse);
       }
     }
   };
   if (n->kind() == prim::If) {
     run_blocks();
-    //inline node if we can, otherwise check for simplified outputs
+    // inline node if we can, otherwise check for simplified outputs
     if (constant_inputs) {
       inlineIf(n);
     } else {
       removeExtraNodeOutputs(n);
     }
-    //don't rerun run_blocks
+    // don't rerun run_blocks
     return;
   } else if (constant_inputs && supported_node) {
     propagateNode(n);
   }
-  //TODO handle loop nodes. Even if a loop node contains an if that is
-  //inlined its mutated variables currently don't get updated
+  // TODO handle loop nodes. Even if a loop node contains an if that is
+  // inlined its mutated variables currently don't get updated
   run_blocks();
 }
 
 void ConstantPropagation(Block* block, const AliasDb& aliasDb, bool recurse) {
-  for(auto it = block->nodes().begin(); it != block->nodes().end();) {
-    Node *n = *it;
-    it++; //advance iterator bc the current node may be destroyed
+  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+    Noden = *it;
+    it++; // advance iterator bc the current node may be destroyed
     ConstantPropagation(n, aliasDb, recurse);
   }
 }
 } // anonymous namespace
 
-
 void ConstantPropagation(std::shared_ptr<Graph>& graph) {
   const auto aliasDb = AliasAnalysis(graph);
   ConstantPropagation(graph->block(), aliasDb, true);
   EliminateDeadCode(graph);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 2af5369..0141004 100644 (file)
@@ -2,8 +2,10 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 TORCH_API void ConstantPropagation(std::shared_ptr<Graph>& graph);
 
-}}
+}
+} // namespace torch
index 320a0db..dd2faa0 100644 (file)
@@ -1,11 +1,12 @@
 #pragma once
 
-#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/ir.h>
 
 #include <cstddef>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // insert GraphExecutor nodes that group together
 // subgraphs that are differentiable by the jit's autodiff passes
@@ -14,4 +15,5 @@ namespace torch { namespace jit {
 TORCH_API std::vector<Node*> CreateAutodiffSubgraphs(
     const std::shared_ptr<Graph>& graph,
     size_t threshold = 2);
-}}
+} // namespace jit
+} // namespace torch
index b7d606c..3389dce 100644 (file)
@@ -1,6 +1,5 @@
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 
-#include <torch/csrc/jit/passes/alias_analysis.h>
 #include <torch/csrc/jit/ir_views.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 
index 9b6ef08..c12f638 100644 (file)
@@ -2,7 +2,8 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // If given a top-level graph, DCE will construct do alias analysis that allows
 // for "smarter" dead code elimination (we will eliminate mutable ops if we can
@@ -11,7 +12,7 @@ namespace torch { namespace jit {
 //
 // So, prefer to use the graph version if you can.
 TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph);
-TORCH_API void EliminateDeadCode(Block *block, bool recurse=true);
+TORCH_API void EliminateDeadCode(Block* block, bool recurse = true);
 
 // Invoke the user-provided callback on all live values before deleting anything
 TORCH_API void EliminateDeadCode(
index 7c5fa1b..12eeaa7 100644 (file)
@@ -1,7 +1,8 @@
-#include <torch/csrc/jit/passes/erase_number_types.h>
 #include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/passes/erase_number_types.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 static void EraseNumberTypesOnBlock(Block* block) {
   for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
@@ -36,7 +37,7 @@ static void EraseNumberTypesOnBlock(Block* block) {
         // Let DCE cleanup
       } break;
       default: {
-        for(auto o : it->outputs()) {
+        for (auto o : it->outputs()) {
           if (o->type()->isSubtypeOf(NumberType::get())) {
             o->setType(CompleteTensorType::fromNumberType(o->type()));
           } else if (o->type()->isSubtypeOf(BoolType::get())) {
@@ -52,4 +53,5 @@ void EraseNumberTypes(const std::shared_ptr<Graph>& graph) {
   EraseNumberTypesOnBlock(graph->block());
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 872ad0d..3624356 100644 (file)
@@ -2,7 +2,8 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // Erase NumberType information. This is necessary for and only used in
 // exporting to ONNX. This pass ensures that no remaining Values have
@@ -12,9 +13,10 @@ namespace torch { namespace jit {
 // - prim::Constant nodes which are numbers get changed into 0-dim tensors of
 //   the corresponding type
 // - prim::TensorToNum, prim::ImplicitTensorToNum and prim::NumToTensor nodes
-// are erased.
+//   are erased.
 //
 // The pass assumes that DCE will be called sometime after.
 TORCH_API void EraseNumberTypes(const std::shared_ptr<Graph>& graph);
 
-}}
+} // namespace jit
+} // namespace torch
index ae79808..0157000 100644 (file)
@@ -1,21 +1,22 @@
 #include <torch/csrc/jit/passes/graph_fuser.h>
 
+#include <ATen/ExpandUtils.h>
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/autodiff.h>
+#include <torch/csrc/jit/fuser/interface.h>
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/symbolic_variable.h>
-#include <torch/csrc/jit/fuser/interface.h>
-#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/autodiff.h>
-#include <torch/csrc/jit/assertions.h>
-#include <ATen/ExpandUtils.h>
 #include <unordered_map>
 
 #ifdef USE_CUDA
-  #include <cuda.h> // for CUDA_VERSION
+#include <cuda.h> // for CUDA_VERSION
 #endif
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
@@ -29,82 +30,82 @@ namespace {
 //    - Produces contiguous outputs
 // Some of these restrictions may be relaxable, but you should
 // carefully read the code first, as we rely on these assumptions.
-bool isSimpleMap(Node *node) {
-  static OperatorSet simple_mappable {{
-    "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
-
-    "aten::abs(Tensor self) -> Tensor",
-    "aten::acos(Tensor self) -> Tensor",
-    "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
-    "aten::asin(Tensor self) -> Tensor",
-    "aten::atan(Tensor self) -> Tensor",
-    "aten::atan2(Tensor self, Tensor other) -> Tensor",
-    "aten::ceil(Tensor self) -> Tensor",
-    "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
-    "aten::cos(Tensor self) -> Tensor",
-    "aten::cosh(Tensor self) -> Tensor",
-    "aten::div(Tensor self, Tensor other) -> Tensor",
-    "aten::exp(Tensor self) -> Tensor",
-    "aten::expm1(Tensor self) -> Tensor",
-    "aten::erf(Tensor self) -> Tensor",
-    "aten::erfc(Tensor self) -> Tensor",
-    "aten::floor(Tensor self) -> Tensor",
-    "aten::fmod(Tensor self, Tensor other) -> Tensor",
-    "aten::frac(Tensor self) -> Tensor",
-    "aten::lgamma(Tensor self) -> Tensor",
-    "aten::log(Tensor self) -> Tensor",
-    "aten::log10(Tensor self) -> Tensor",
-    "aten::log1p(Tensor self) -> Tensor",
-    "aten::log2(Tensor self) -> Tensor",
-    "aten::max(Tensor self, Tensor other) -> Tensor",
-    "aten::min(Tensor self, Tensor other) -> Tensor",
-    "aten::mul(Tensor self, Tensor other) -> Tensor",
-    "aten::neg(Tensor self) -> Tensor",
-    "aten::pow(Tensor self, Tensor exponent) -> Tensor",
-    "aten::pow(Tensor self, Scalar exponent) -> Tensor",
-    // See https://github.com/pytorch/pytorch/issues/14674 and make sure you
-    // won't make the same mistake before you reenable this.
-    //"aten::rand_like(Tensor self) -> Tensor",
-    "aten::reciprocal(Tensor self) -> Tensor",
-    "aten::relu(Tensor self) -> Tensor",
-    "aten::remainder(Tensor self, Tensor other) -> Tensor",
-    "aten::round(Tensor self) -> Tensor",
-    "aten::rsqrt(Tensor self) -> Tensor",
-    "aten::sigmoid(Tensor self) -> Tensor",
-    "aten::sin(Tensor self) -> Tensor",
-    "aten::sinh(Tensor self) -> Tensor",
-    "aten::sqrt(Tensor self) -> Tensor",
-    "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
-    "aten::tan(Tensor self) -> Tensor",
-    "aten::tanh(Tensor self) -> Tensor",
-    "aten::trunc(Tensor self) -> Tensor",
-    "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
-    "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
-    "aten::mul(Tensor self, Scalar other) -> Tensor",
-    "aten::div(Tensor self, Scalar other) -> Tensor",
-
-    "aten::eq(Tensor self, Tensor other) -> Tensor",
-    "aten::eq(Tensor self, Scalar other) -> Tensor",
-    "aten::ne(Tensor self, Tensor other) -> Tensor",
-    "aten::ne(Tensor self, Scalar other) -> Tensor",
-    "aten::ge(Tensor self, Tensor other) -> Tensor",
-    "aten::ge(Tensor self, Scalar other) -> Tensor",
-    "aten::gt(Tensor self, Tensor other) -> Tensor",
-    "aten::gt(Tensor self, Scalar other) -> Tensor",
-    "aten::le(Tensor self, Tensor other) -> Tensor",
-    "aten::le(Tensor self, Scalar other) -> Tensor",
-    "aten::lt(Tensor self, Tensor other) -> Tensor",
-    "aten::lt(Tensor self, Scalar other) -> Tensor",
-
-    "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
-
-    "aten::type_as(Tensor self, Tensor other) -> Tensor",
+bool isSimpleMap(Nodenode) {
+  static OperatorSet simple_mappable{{
+      "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
+
+      "aten::abs(Tensor self) -> Tensor",
+      "aten::acos(Tensor self) -> Tensor",
+      "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+      "aten::asin(Tensor self) -> Tensor",
+      "aten::atan(Tensor self) -> Tensor",
+      "aten::atan2(Tensor self, Tensor other) -> Tensor",
+      "aten::ceil(Tensor self) -> Tensor",
+      "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
+      "aten::cos(Tensor self) -> Tensor",
+      "aten::cosh(Tensor self) -> Tensor",
+      "aten::div(Tensor self, Tensor other) -> Tensor",
+      "aten::exp(Tensor self) -> Tensor",
+      "aten::expm1(Tensor self) -> Tensor",
+      "aten::erf(Tensor self) -> Tensor",
+      "aten::erfc(Tensor self) -> Tensor",
+      "aten::floor(Tensor self) -> Tensor",
+      "aten::fmod(Tensor self, Tensor other) -> Tensor",
+      "aten::frac(Tensor self) -> Tensor",
+      "aten::lgamma(Tensor self) -> Tensor",
+      "aten::log(Tensor self) -> Tensor",
+      "aten::log10(Tensor self) -> Tensor",
+      "aten::log1p(Tensor self) -> Tensor",
+      "aten::log2(Tensor self) -> Tensor",
+      "aten::max(Tensor self, Tensor other) -> Tensor",
+      "aten::min(Tensor self, Tensor other) -> Tensor",
+      "aten::mul(Tensor self, Tensor other) -> Tensor",
+      "aten::neg(Tensor self) -> Tensor",
+      "aten::pow(Tensor self, Tensor exponent) -> Tensor",
+      "aten::pow(Tensor self, Scalar exponent) -> Tensor",
+      // See https://github.com/pytorch/pytorch/issues/14674 and make sure you
+      // won't make the same mistake before you reenable this.
+      //"aten::rand_like(Tensor self) -> Tensor",
+      "aten::reciprocal(Tensor self) -> Tensor",
+      "aten::relu(Tensor self) -> Tensor",
+      "aten::remainder(Tensor self, Tensor other) -> Tensor",
+      "aten::round(Tensor self) -> Tensor",
+      "aten::rsqrt(Tensor self) -> Tensor",
+      "aten::sigmoid(Tensor self) -> Tensor",
+      "aten::sin(Tensor self) -> Tensor",
+      "aten::sinh(Tensor self) -> Tensor",
+      "aten::sqrt(Tensor self) -> Tensor",
+      "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+      "aten::tan(Tensor self) -> Tensor",
+      "aten::tanh(Tensor self) -> Tensor",
+      "aten::trunc(Tensor self) -> Tensor",
+      "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+      "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+      "aten::mul(Tensor self, Scalar other) -> Tensor",
+      "aten::div(Tensor self, Scalar other) -> Tensor",
+
+      "aten::eq(Tensor self, Tensor other) -> Tensor",
+      "aten::eq(Tensor self, Scalar other) -> Tensor",
+      "aten::ne(Tensor self, Tensor other) -> Tensor",
+      "aten::ne(Tensor self, Scalar other) -> Tensor",
+      "aten::ge(Tensor self, Tensor other) -> Tensor",
+      "aten::ge(Tensor self, Scalar other) -> Tensor",
+      "aten::gt(Tensor self, Tensor other) -> Tensor",
+      "aten::gt(Tensor self, Scalar other) -> Tensor",
+      "aten::le(Tensor self, Tensor other) -> Tensor",
+      "aten::le(Tensor self, Scalar other) -> Tensor",
+      "aten::lt(Tensor self, Tensor other) -> Tensor",
+      "aten::lt(Tensor self, Scalar other) -> Tensor",
+
+      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
+
+      "aten::type_as(Tensor self, Tensor other) -> Tensor",
   }};
   if (!simple_mappable.find(node)) {
     return false;
   }
   // Check that all non-tensor inputs are constant
-  for (Value * input : node->inputs()) {
+  for (Value* input : node->inputs()) {
     if (input->type()->isSubtypeOf(DynamicType::get())) {
       continue;
     }
@@ -115,93 +116,99 @@ bool isSimpleMap(Node *node) {
   return true;
 }
 
-Value * broadcastSizes(at::ArrayRef<Value*> sizes) {
+Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
   JIT_ASSERT(!sizes.empty());
-  Graph * graph = sizes[0]->owningGraph();
-  Node * broadcast_n = graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
+  Graph* graph = sizes[0]->owningGraph();
+  Node* broadcast_n =
+      graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
   broadcast_n->output()->setType(ListType::ofInts());
   return broadcast_n->output();
 }
 
 struct GraphFuser {
-  Block * block_;
+  Block* block_;
   std::shared_ptr<Graph> graph_;
 
   GraphFuser(Block* block, std::shared_ptr<Graph> graph)
       : block_(block), graph_(std::move(graph)) {}
 
-  value_list tensorInputs(Node * node) {
-    return filter(node->inputs(), [](Value * v) {
+  value_list tensorInputs(Node* node) {
+    return filter(node->inputs(), [](Value* v) {
       return v->type()->isSubtypeOf(DynamicType::get());
     });
   }
 
-  bool isFusable(Node * node) {
+  bool isFusable(Node* node) {
     // We don't want to bother with cross-block node movements, as they
     // are not necessarily correct.
-    if (node->owningBlock() != block_) return false;
+    if (node->owningBlock() != block_)
+      return false;
     return node->kind() == prim::FusionGroup || isSimpleMap(node);
   }
 
-  bool isFusableCatNode(Node * node) {
+  bool isFusableCatNode(Node* node) {
     if (node->kind() != aten::cat)
       return false;
     if (!node->is_constant(attr::dim))
       return false;
     auto tensors_node = node->namedInput(attr::tensors)->node();
-    if (tensors_node->kind() != prim::ListConstruct) return false;
-    // NB: Note that technically other uses of the list aren't a big problem for us.
-    // It would be enough to place the prim::FusedConcat before the prim::ListConstruct, and
-    // allUsersAreThisConsumerOrOccurAfterIt would still be satisfied. However, I don't expect this
-    // to be necessary any time soon, and so we're simply assuming that we don't have to deal with it.
-    if (tensors_node->output()->uses().size() > 1) return false;
+    if (tensors_node->kind() != prim::ListConstruct)
+      return false;
+    // NB: Note that technically other uses of the list aren't a big problem for
+    // us. It would be enough to place the prim::FusedConcat before the
+    // prim::ListConstruct, and allUsersAreThisConsumerOrOccurAfterIt would
+    // still be satisfied. However, I don't expect this to be necessary any time
+    // soon, and so we're simply assuming that we don't have to deal with it.
+    if (tensors_node->output()->uses().size() > 1)
+      return false;
     return true;
   }
 
   // Can this node produce an _output_ of a fusion group?
-  // all Fusable nodes can do this, but additionally Concat, which normally cannot be fused
-  // because it is not a simple map, can be put in a fusion group
-  // as long as no items in the group read the output of concat
-  bool isFusableAsExitNode(Node * node) {
+  // all Fusable nodes can do this, but additionally Concat, which normally
+  // cannot be fused because it is not a simple map, can be put in a fusion
+  // group as long as no items in the group read the output of concat
+  bool isFusableAsExitNode(Node* node) {
     return isFusable(node) || isFusableOnlyAsExitNode(node);
   }
 
-  bool isFusableOnlyAsExitNode(Node * node) {
+  bool isFusableOnlyAsExitNode(Node* node) {
     return isFusableCatNode(node) || node->kind() == prim::FusedConcat;
   }
 
-  bool calculatesSize(Node * node) {
+  bool calculatesSize(Node* node) {
     return node->matches("aten::size(Tensor self) -> int[]");
   }
 
-  bool allUsersAreThisConsumerOrCalcSizes(Node * consumer, Value * producer) {
+  bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) {
     auto defining_node = producer->node();
-    for(auto o : defining_node->outputs()) {
-      for(auto u : o->uses()) {
-        if(u.user != consumer && !calculatesSize(u.user))
+    for (auto o : defining_node->outputs()) {
+      for (auto u : o->uses()) {
+        if (u.user != consumer && !calculatesSize(u.user))
           return false;
       }
     }
     return true;
   }
 
-  bool mustRemainAsFusionGroupOutput(Value * producer) {
+  bool mustRemainAsFusionGroupOutput(Value* producer) {
     if (producer->node()->kind() != prim::FusionGroup) {
       return false;
     }
     auto subgraph = producer->node()->g(attr::Subgraph);
-    auto * node = subgraph->outputs().at(producer->offset())->node();
+    auto* node = subgraph->outputs().at(producer->offset())->node();
     return isFusableOnlyAsExitNode(node);
   }
 
-  Graph & getSubgraph(Node * n) {
+  Graph& getSubgraph(Node* n) {
     JIT_ASSERT(n->kind() == prim::FusionGroup);
     return *n->g(attr::Subgraph);
   }
 
-  void mergeFusionGroups(Node *consumer_group, Node *producer_group) {
+  void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
     // Now we have two fusion groups!
-    // Revert the fusion - place all inner nodes of producer back in the outer graph.
+    // Revert the fusion - place all inner nodes of producer back in the outer
+    // graph.
     std::vector<Node*> temporary_nodes;
     auto producer_subgraph = &getSubgraph(producer_group);
 
@@ -215,9 +222,8 @@ struct GraphFuser {
 
     // Clone all nodes
     for (auto inner : producer_subgraph->nodes()) {
-      Node * outer = block_->owningGraph()->createClone(inner, [&](Value * k) -> Value* {
-        return inner_to_outer.at(k);
-      });
+      Node* outer = block_->owningGraph()->createClone(
+          inner, [&](Value* k) -> Value* { return inner_to_outer.at(k); });
       outer->insertBefore(producer_group);
       temporary_nodes.emplace_back(outer);
       auto inner_outputs = inner->outputs();
@@ -233,18 +239,21 @@ struct GraphFuser {
       producer_group->outputs()[i]->replaceAllUsesWith(outer_output);
     }
     producer_group->destroy();
-    producer_group = nullptr; // Just to get a clear error in case someone uses it
+    producer_group =
+        nullptr; // Just to get a clear error in case someone uses it
 
     // Inline the temporary nodes into the first group
     auto consumer_subgraph = &getSubgraph(consumer_group);
-    for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend(); ++it) {
-      Node *node = *it;
-      Node *merged = mergeNodeIntoGroup(consumer_group, node);
+    for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend();
+         ++it) {
+      Node* node = *it;
+      Node* merged = mergeNodeIntoGroup(consumer_group, node);
       // If any of the outputs are still used then we need to add them
       auto outputs = node->outputs();
       for (size_t i = 0; i < outputs.size(); ++i) {
         auto output = outputs[i];
-        if (output->uses().size() == 0) continue;
+        if (output->uses().size() == 0)
+          continue;
         consumer_subgraph->registerOutput(merged->outputs()[i]);
         auto new_output = consumer_group->addOutput();
         output->replaceAllUsesWith(new_output);
@@ -257,18 +266,19 @@ struct GraphFuser {
   // insert a producer node into a consuming fusion group.
   // DOES NOT WORK if n is a consumer of an output of the fusion group
   // returns the node _inside_ the group that represents the node
-  Node * mergeNodeIntoGroup(Node* group, Node * n) {
+  Node* mergeNodeIntoGroup(Node* group, Node* n) {
     JIT_ASSERT(n->kind() != prim::FusionGroup);
-    auto & subgraph = getSubgraph(group);
+    auto& subgraph = getSubgraph(group);
     // map from nodes in the surrounding graph to parameters in the fusion
     // group's subgraph that correspond to them
-    std::unordered_map<Value*,Value*> inputs_map;
+    std::unordered_map<Value*, Value*> inputs_map;
     size_t i = 0;
     JIT_ASSERT(group->inputs().size() == subgraph.inputs().size());
-    for(auto input : group->inputs()) {
+    for (auto input : group->inputs()) {
       inputs_map[input] = subgraph.inputs()[i++];
     }
-    // add n's inputs to the fusion group's input list if we don't already have them
+    // add n's inputs to the fusion group's input list if we don't already have
+    // them
     WithInsertPoint guard(*subgraph.nodes().begin());
     for (auto input : n->inputs()) {
       if (inputs_map.count(input) == 0) {
@@ -278,20 +288,23 @@ struct GraphFuser {
           inputs_map[input] = in_group;
           group->addInput(input);
         } else {
-          // We don't support passing in scalars as arguments to fused kernels, so we generally
-          // don't allow fusing tensor-scalar operations unless the scalar is constant. In those
-          // cases we inline the constants directly in the body of the fused group.
+          // We don't support passing in scalars as arguments to fused kernels,
+          // so we generally don't allow fusing tensor-scalar operations unless
+          // the scalar is constant. In those cases we inline the constants
+          // directly in the body of the fused group.
           JIT_ASSERT(input->node()->kind() == prim::Constant);
-          Node * in_const = subgraph.createClone(input->node(), [](Value*) -> Value* { throw std::runtime_error("unexpected input"); });
+          Node* in_const =
+              subgraph.createClone(input->node(), [](Value*) -> Value* {
+                throw std::runtime_error("unexpected input");
+              });
           subgraph.insertNode(in_const);
           inputs_map[input] = in_const->output();
         }
       }
     }
     // copy n into the graph, remapping its inputs to internal nodes
-    Node * in_graph = subgraph.createClone(n,[&](Value * k)-> Value* {
-      return inputs_map[k];
-    });
+    Node* in_graph = subgraph.createClone(
+        n, [&](Value* k) -> Value* { return inputs_map[k]; });
     // if n's outputs are already inputs to the fusion group,
     // we need to remove them because n is now inside the fusion group.
     //
@@ -304,7 +317,7 @@ struct GraphFuser {
     auto inputs = group->inputs();
     for (size_t i = 0; i < n->outputs().size(); ++i) {
       auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]);
-      if(it != inputs.end()) {
+      if (it != inputs.end()) {
         size_t p = it - inputs.begin();
         group->removeInput(p);
         subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]);
@@ -316,12 +329,12 @@ struct GraphFuser {
 
   // turn consumer node n into a fusion group with just n inside
   // to prepare for fusion and replace uses of n with the new group
-  Node * createSingletonFusionGroup(Node * n) {
+  Node* createSingletonFusionGroup(Node* n) {
     auto group = block_->owningGraph()->createFusionGroup();
     // propogate position information for the new node so we can always
     // have a valid mapping
     group->insertBefore(n);
-    Node * mergedNode = mergeNodeIntoGroup(group,n);
+    Node* mergedNode = mergeNodeIntoGroup(group, n);
     getSubgraph(group).registerOutput(mergedNode->output());
     auto sel = group->addOutput();
     sel->copyMetadata(n->output());
@@ -331,7 +344,7 @@ struct GraphFuser {
   }
 
   // TODO: remove this and use WithInsertPoint instead
-  void insertAt(Node ** insertion_point, Node * n) {
+  void insertAt(Node** insertion_point, Node* n) {
     n->insertAfter(*insertion_point);
     *insertion_point = n;
   }
@@ -340,11 +353,13 @@ struct GraphFuser {
       Node* consumer,
       Value* producer,
       const AliasDb& aliasDb) {
-    // this handles cases where producer can be moved _into_ the fusion group of consumer.
+    // this handles cases where producer can be moved _into_ the fusion group of
+    // consumer.
     // TODO: extend to fusion of consumer into _producer's_ fusion blob
     // if the consumer allInputsAreThisProducer(consumer,producer)
     // we can move the consumer up into the producer.
-    // but this requires better handling of merging fusion groups so it is not done now
+    // but this requires better handling of merging fusion groups so it is not
+    // done now
     Node* real_consumer = consumer->kind() == aten::cat
         ? consumer->namedInput(attr::tensors)->node()
         : consumer;
@@ -361,11 +376,13 @@ struct GraphFuser {
 
     auto group = consumer;
     if (consumer->kind() == aten::cat) {
-      Graph * graph = consumer->owningGraph();
-      Node * list_construct = consumer->namedInput(attr::tensors)->node();
+      Graph* graph = consumer->owningGraph();
+      Node* list_construct = consumer->namedInput(attr::tensors)->node();
       int64_t dim = consumer->get<int64_t>(attr::dim).value();
 
-      Node * fused_cat = graph->create(prim::FusedConcat, list_construct->inputs())->i_(attr::dim, dim);
+      Node* fused_cat =
+          graph->create(prim::FusedConcat, list_construct->inputs())
+              ->i_(attr::dim, dim);
       fused_cat->insertBefore(list_construct);
       fused_cat->output()->copyMetadata(consumer->output());
       consumer->output()->replaceAllUsesWith(fused_cat->output());
@@ -384,14 +401,14 @@ struct GraphFuser {
       return group;
     }
     JIT_ASSERT(producer->node()->outputs().size() == 1);
-    Node * merged = mergeNodeIntoGroup(group, producer->node());
+    Node* merged = mergeNodeIntoGroup(group, producer->node());
     // remaining uses of this producer can occur because we allow
     // fusion in cases where uses remain after the consumer
     // if these exist, re-route them to the version of producer
     // created in FusionGroup
-    if(producer->uses().size() != 0) {
+    if (producer->uses().size() != 0) {
       getSubgraph(group).registerOutput(merged->output());
-      Value * new_producer = group->addOutput();
+      Value* new_producer = group->addOutput();
       new_producer->copyMetadata(producer);
       producer->replaceAllUsesWith(new_producer);
     }
@@ -404,9 +421,9 @@ struct GraphFuser {
       return false;
     }
     // Does the chunk have constant chunks/dim?
-    auto * chunk = producer->node();
+    auto* chunk = producer->node();
     if (chunk->kind() != prim::ConstantChunk)
-        return false;
+      return false;
     // And all uses of the chunk are in this consumer
     for (auto s : chunk->outputs()) {
       for (auto u : s->uses()) {
@@ -430,10 +447,10 @@ struct GraphFuser {
       return c10::nullopt;
     }
     size_t input_index = it - group->inputs().begin();
-    auto & subgraph = getSubgraph(group);
-    auto * subgraph_input = subgraph.inputs().at(input_index);
+    auto& subgraph = getSubgraph(group);
+    auto* subgraph_input = subgraph.inputs().at(input_index);
     // If subgraph_input is an input to prim::ConstantChunk, it will have 1 use
-    auto * node = subgraph_input->uses().at(0).user;
+    auto* node = subgraph_input->uses().at(0).user;
     if (node->kind() == prim::ConstantChunk) {
       JIT_ASSERT(subgraph_input->uses().size() == 1);
       return node;
@@ -442,15 +459,17 @@ struct GraphFuser {
   }
 
   void fuseChunkByReusingExistingFusedChunk(
-      Node * group, Node * chunk, Node * existingFusedChunk) {
+      Node* group,
+      Node* chunk,
+      Node* existingFusedChunk) {
     if (chunk->outputs().size() != existingFusedChunk->outputs().size()) {
       return;
     }
-    auto & subgraph = getSubgraph(group);
+    auto& subgraph = getSubgraph(group);
     for (size_t i = 0; i < chunk->outputs().size(); ++i) {
       // Find the input to the FusionGroup (group)
-      auto * replacement_val = existingFusedChunk->outputs().at(i);
-      auto * val = chunk->outputs().at(i);
+      auto* replacement_val = existingFusedChunk->outputs().at(i);
+      auto* val = chunk->outputs().at(i);
       auto it = std::find(group->inputs().begin(), group->inputs().end(), val);
       auto input_index = it - group->inputs().begin();
 
@@ -466,18 +485,20 @@ struct GraphFuser {
   }
 
   // There are two invariants for prim::ConstantChunk:
-  // (1) the tensor input to prim::ConstantChunk must be an input to the fusion group
-  // (2) no two ConstantChunks in the same FusionGroup can share a tensor input.
-  graph_node_list::iterator fuseChunk(Node * consumer, Value * producer) {
-    auto * chunk = producer->node();
+  // (1) the tensor input to prim::ConstantChunk must be an input to the fusion
+  // group (2) no two ConstantChunks in the same FusionGroup can share a tensor
+  // input.
+  graph_node_list::iterator fuseChunk(Node* consumer, Value* producer) {
+    auto* chunk = producer->node();
     JIT_ASSERT(consumer->kind() == prim::FusionGroup);
     JIT_ASSERT(chunk->kind() == prim::ConstantChunk);
 
     // if producer's input is already an input to a prim::ConstantChunk node,
     // we cannot add a new prim::ConstantChunk node because of invariant (2).
-    auto * chunked_tensor = producer->node()->input();
+    auto* chunked_tensor = producer->node()->input();
     if (auto existingFusedChunk = findFusedChunk(consumer, chunked_tensor)) {
-      fuseChunkByReusingExistingFusedChunk(consumer, chunk, *existingFusedChunk);
+      fuseChunkByReusingExistingFusedChunk(
+          consumer, chunk, *existingFusedChunk);
       return consumer->reverseIterator();
     }
 
@@ -495,16 +516,16 @@ struct GraphFuser {
       }
     }
     // Sort in reverse topological order
-    std::sort(result.begin(), result.end(), [&](Value * a, Value * b) {
+    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
       return a->node()->isAfter(b->node());
     });
     return result;
   }
 
-  graph_node_list::iterator scanNodeForChunks(Node * consumer) {
+  graph_node_list::iterator scanNodeForChunks(Node* consumer) {
     if (consumer->kind() == prim::FusionGroup) {
       auto inputs = sortReverseTopological(consumer->inputs());
-      for(auto producer : inputs) {
+      for (auto producer : inputs) {
         if (!canFuseChunk(consumer, producer)) {
           continue;
         }
@@ -514,10 +535,11 @@ struct GraphFuser {
     return ++consumer->reverseIterator();
   }
 
-  void insertExplicitBroadcast(Node *node) {
-    WithInsertPoint insert_guard { node };
+  void insertExplicitBroadcast(Nodenode) {
+    WithInsertPoint insert_guard{node};
     auto tensors = tensorInputs(node);
-    auto new_tensors = SymbolicVariable::broadcast_tensors(fmap<SymbolicVariable>(tensors));
+    auto new_tensors =
+        SymbolicVariable::broadcast_tensors(fmap<SymbolicVariable>(tensors));
 
     // Replace tensors inputs with broadcasted values
     auto new_tensors_it = new_tensors.begin();
@@ -529,15 +551,16 @@ struct GraphFuser {
     }
   }
 
-  Node * promoteChunkToBroadcastingChunk(Node * chunk) {
+  Node* promoteChunkToBroadcastingChunk(Node* chunk) {
     JIT_ASSERT(chunk->kind() == prim::ConstantChunk);
 
     size_t nchunks = chunk->i(attr::chunks);
-    Node * bchunk = chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks);
+    Node* bchunk =
+        chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks);
     bchunk->addInput(chunk->input());
     for (size_t i = 0; i < nchunks; ++i) {
-      auto * old_output = chunk->outputs().at(i);
-      auto * new_output = bchunk->outputs().at(i);
+      auto* old_output = chunk->outputs().at(i);
+      auto* new_output = bchunk->outputs().at(i);
       new_output->copyMetadata(old_output);
       old_output->replaceAllUsesWith(new_output);
     }
@@ -582,8 +605,9 @@ struct GraphFuser {
   // we exit the fusion group.
   //
   // NB: The intermediate BroadcastingChunk is important for moving chunks past
-  // more than one operation: the graph fuser is not able to easily move operations
-  // around broadcast_tensors + chunk nodes. Let f, g, h be fusible ops
+  // more than one operation: the graph fuser is not able to easily move
+  // operations around broadcast_tensors + chunk nodes. Let f, g, h be fusible
+  // ops
   //   x = f(v, w)
   //   z = g(x, y)
   //   a, b = chunk(z)
@@ -612,25 +636,26 @@ struct GraphFuser {
   //   b = g(bx, by)
   //   c = h(a, b)
 
-  bool tryToMoveChunk(Node * consumer, Value * producer) {
+  bool tryToMoveChunk(Node* consumer, Value* producer) {
     // is the output from a chunk/bchunk node?
-    auto * chunk = producer->node();
-    if (chunk->kind() != prim::ConstantChunk && chunk->kind() != prim::BroadcastingChunk)
+    auto* chunk = producer->node();
+    if (chunk->kind() != prim::ConstantChunk &&
+        chunk->kind() != prim::BroadcastingChunk)
       return false;
 
-    // try to find a producer to move after the chunk/bchunk. The producer must be
-    // fusible into the consumer.
+    // try to find a producer to move after the chunk/bchunk. The producer must
+    // be fusible into the consumer.
     auto it = std::find_if(
         chunk->inputs().begin(),
         chunk->inputs().end(),
-        [&](Value * producer_for_chunk) {
+        [&](Value* producer_for_chunk) {
           return isFusable(producer_for_chunk->node()) &&
               allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
         });
     if (it == chunk->inputs().end()) {
       return false;
     }
-    Value * producer_for_chunk = *it;
+    Value* producer_for_chunk = *it;
     size_t producer_index = it - chunk->inputs().begin();
 
     // all uses of the chunk must be in in this consumer
@@ -641,12 +666,12 @@ struct GraphFuser {
       }
     }
     // multiple return operators
-    Node * producer_for_chunk_node = producer_for_chunk->node();
+    Node* producer_for_chunk_node = producer_for_chunk->node();
     JIT_ASSERT(producer_for_chunk_node->outputs().size() == 1);
 
     // Convert chunk to bchunk, if it isn't one already. The bchunk represents a
     // broadcast and one or more chunk operations.
-    auto * bchunk = chunk;
+    auto* bchunk = chunk;
     if (chunk->kind() == prim::ConstantChunk) {
       bchunk = promoteChunkToBroadcastingChunk(chunk);
     }
@@ -655,7 +680,8 @@ struct GraphFuser {
 
     std::vector<Value*> producer_chunk_outputs;
     for (size_t i = 0; i < nchunks; i++) {
-      producer_chunk_outputs.push_back(bchunk->output(nchunks * producer_index + i));
+      producer_chunk_outputs.push_back(
+          bchunk->output(nchunks * producer_index + i));
     }
 
     // Add each of op's operands to the bchunk node.
@@ -664,8 +690,9 @@ struct GraphFuser {
     std::vector<std::vector<Value*>> chunked_inputs;
 
     for (auto input : producer_for_chunk_node->inputs()) {
-      // XXX: we only work with pointwise ops in here, so we know it is valid to push
-      // the concat only through tensor arguments (and all other args can be safely ignored).
+      // XXX: we only work with pointwise ops in here, so we know it is valid to
+      // push the concat only through tensor arguments (and all other args can
+      // be safely ignored).
       if (!input->type()->isSubtypeOf(DynamicType::get()))
         continue;
 
@@ -691,7 +718,7 @@ struct GraphFuser {
       bchunk->addInput(input);
       chunked_inputs.emplace_back(); // alas, to not be C++17
       for (auto chunk_sel : producer_chunk_outputs) {
-        Value * input_chunk_sel = bchunk->addOutput();
+        Value* input_chunk_sel = bchunk->addOutput();
         input_chunk_sel->setType(chunk_sel->type());
         chunked_inputs.back().push_back(input_chunk_sel);
       }
@@ -701,14 +728,16 @@ struct GraphFuser {
     // and then rewrite the graph to use them!
     for (auto chunk_sel : producer_chunk_outputs) {
       auto original_inputs = producer_for_chunk_node->inputs();
-      Node * chunked_op = block_->owningGraph()->create(producer_for_chunk_node->kind());
+      Node* chunked_op =
+          block_->owningGraph()->create(producer_for_chunk_node->kind());
       chunked_op->copyAttributes(*producer_for_chunk_node);
       chunked_op->output()->setType(chunk_sel->type());
       auto chunked_inputs_it = chunked_inputs.begin();
       for (Value* original_input : original_inputs) {
         if (original_input->type()->isSubtypeOf(DynamicType::get())) {
           JIT_ASSERT(chunked_inputs_it != chunked_inputs.end());
-          chunked_op->addInput(chunked_inputs_it->at(chunk_sel->offset() % nchunks));
+          chunked_op->addInput(
+              chunked_inputs_it->at(chunk_sel->offset() % nchunks));
           ++chunked_inputs_it;
         } else {
           chunked_op->addInput(original_input);
@@ -723,16 +752,21 @@ struct GraphFuser {
       bchunk->eraseOutput(nchunks * producer_index);
     }
 
-    // The output of producer_for_chunk_node could have been used in some aten::size
-    // operators, so we need to clean those up as well (we simply broadcast all its tensor inputs).
+    // The output of producer_for_chunk_node could have been used in some
+    // aten::size operators, so we need to clean those up as well (we simply
+    // broadcast all its tensor inputs).
     auto size_calc_uses = producer_for_chunk_node->output()->uses();
     if (!size_calc_uses.empty()) {
-      auto tensor_inputs = filter(producer_for_chunk_node->inputs(),
-                                  [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); });
-      auto tensor_sizes = fmap(tensor_inputs,
-                               [](Value * v) { return v->owningGraph()->insert(aten::size, {v}); });
+      auto tensor_inputs = filter(
+          producer_for_chunk_node->inputs(),
+          [](Value* v) { return v->type()->isSubtypeOf(DynamicType::get()); });
+      auto tensor_sizes = fmap(tensor_inputs, [](Value* v) {
+        return v->owningGraph()->insert(aten::size, {v});
+      });
       JIT_ASSERT(!tensor_sizes.empty());
-      Value * output_size = tensor_sizes.size() == 1 ? tensor_sizes[0] : broadcastSizes(tensor_sizes);
+      Value* output_size = tensor_sizes.size() == 1
+          ? tensor_sizes[0]
+          : broadcastSizes(tensor_sizes);
       for (Use u : size_calc_uses) {
         u.user->output()->replaceAllUsesWith(output_size);
         u.user->destroy();
@@ -746,18 +780,19 @@ struct GraphFuser {
   std::pair<graph_node_list::iterator, bool> scanNode(
       Node* consumer,
       const AliasDb& aliasDb) {
-    if(isFusableAsExitNode(consumer)) {
-      auto consumer_inputs = consumer->kind() == aten::cat ?
-        consumer->namedInput(attr::tensors)->node()->inputs() :
-        consumer->inputs();
+    if (isFusableAsExitNode(consumer)) {
+      auto consumer_inputs = consumer->kind() == aten::cat
+          ? consumer->namedInput(attr::tensors)->node()->inputs()
+          : consumer->inputs();
       // handle inputs in reverse topological order as well...
       // otherwise in f(a,a+b) it will appear a is used twice if we consider
       // the f-a fusion before the f-(a+b) fusion first.
       auto inputs = sortReverseTopological(consumer_inputs);
-      for(auto producer : inputs) {
+      for (auto producer : inputs) {
         // Don't fuse if producer must come from a FusionGroup exit node
-        if (mustRemainAsFusionGroupOutput(producer)) continue;
-        if(tryToMoveChunk(consumer,producer)) {
+        if (mustRemainAsFusionGroupOutput(producer))
+          continue;
+        if (tryToMoveChunk(consumer, producer)) {
           // the chunk before this consumer was re-arranged to allow fusion,
           // we scan this consumer again to perform the fusion
           return std::make_pair(consumer->reverseIterator(), true);
@@ -775,27 +810,31 @@ struct GraphFuser {
 
   void replaceIntermediateBroadcastingChunks() {
     for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
-      auto * node = *it;
-      ++it;  // We might delete node, so increment the iterator now.
+      auto* node = *it;
+      ++it; // We might delete node, so increment the iterator now.
       if (node->kind() != prim::BroadcastingChunk) {
         continue;
       }
-      auto * bchunk = node;
+      auto* bchunk = node;
       insertExplicitBroadcast(bchunk);
 
-      auto * graph = block_->owningGraph();
+      auto* graph = block_->owningGraph();
       size_t nchunks = bchunk->i(attr::chunks);
       WithInsertPoint guard(bchunk->next());
 
       // Split the bchunk into bchunks.inputs().size() number of chunk nodes.
-      for (size_t input_offset = 0; input_offset < bchunk->inputs().size(); input_offset++) {
+      for (size_t input_offset = 0; input_offset < bchunk->inputs().size();
+           input_offset++) {
         auto* input = bchunk->inputs().at(input_offset);
 
-        Node * new_chunk = graph->insertNode(graph->create(prim::ConstantChunk, input, 0));
+        Node* new_chunk =
+            graph->insertNode(graph->create(prim::ConstantChunk, input, 0));
         new_chunk->copyAttributes(*bchunk);
-        for (size_t output_offset = 0; output_offset < nchunks; output_offset++) {
+        for (size_t output_offset = 0; output_offset < nchunks;
+             output_offset++) {
           auto new_output = new_chunk->addOutput();
-          auto old_output = bchunk->outputs().at(input_offset * nchunks + output_offset);
+          auto old_output =
+              bchunk->outputs().at(input_offset * nchunks + output_offset);
           new_output->copyMetadata(old_output);
           old_output->replaceAllUsesWith(new_output);
         }
@@ -804,20 +843,21 @@ struct GraphFuser {
     }
   }
 
-  bool usedOnlyInSize(Value * v) {
-    const auto & uses = v->uses();
-    return std::all_of(uses.begin(), uses.end(),
-                       [](const Use& u) { return u.user->matches("aten::size(Tensor self) -> int[]"); });
+  bool usedOnlyInSize(Value* v) {
+    const auto& uses = v->uses();
+    return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
+      return u.user->matches("aten::size(Tensor self) -> int[]");
+    });
   }
 
-  // Builds up expressions that compute shapes of all intermediates (and outputs)
-  // of the fusion group, based on the sizes of inputs. You should run DCE to remove
-  // those that you end up not using.
-  std::unordered_map<Value*, Value*> buildShapeExpressions(Node * fusion_group) {
-    WithInsertPoint insert_guard { fusion_group->next() };
+  // Builds up expressions that compute shapes of all intermediates (and
+  // outputs) of the fusion group, based on the sizes of inputs. You should run
+  // DCE to remove those that you end up not using.
+  std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
+    WithInsertPoint insert_guard{fusion_group->next()};
     std::unordered_map<Value*, Value*> shape_of;
 
-    Graph * graph = fusion_group->owningGraph();
+    Graph* graph = fusion_group->owningGraph();
     auto subgraph = fusion_group->g(attr::Subgraph);
 
     auto inputs = fusion_group->inputs();
@@ -829,18 +869,20 @@ struct GraphFuser {
 
     // When we have a guarantee that an output won't be removed, because it's
     // used in expressions that don't involve size checks, we can use its size
-    // instead of computing a long chain of broadcasts, starting from the beginning
-    // of the kernel.
+    // instead of computing a long chain of broadcasts, starting from the
+    // beginning of the kernel.
     auto outputs = fusion_group->outputs();
     auto soutputs = subgraph->outputs();
     JIT_ASSERT(outputs.size() == soutputs.size());
     for (size_t i = 0; i < outputs.size(); ++i) {
-      if (usedOnlyInSize(outputs[i])) continue;
+      if (usedOnlyInSize(outputs[i]))
+        continue;
       shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
     }
 
-    for (Node * n : subgraph->nodes()) {
-      // XXX: Use of shape_of.emplace is crucial to the output shape optimization!
+    for (Node* n : subgraph->nodes()) {
+      // XXX: Use of shape_of.emplace is crucial to the output shape
+      // optimization!
       if (n->kind() == prim::FusedConcat) {
         // This is a bit more involved, because we have to account for the case
         // when inputs have different shapes, but fortunately those tensors are
@@ -852,31 +894,36 @@ struct GraphFuser {
         continue;
       }
       if (n->kind() == prim::ConstantChunk) {
-        Node * sizes_node = graph->insertNode(graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
+        Node* sizes_node = graph->insertNode(
+            graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
         sizes_node->i_(attr::dim, n->i(attr::dim));
         sizes_node->i_(attr::chunks, n->i(attr::chunks));
-        Value * regular_size = sizes_node->outputs().at(0);
-        Value * last_size = sizes_node->outputs().at(1);
+        Value* regular_size = sizes_node->outputs().at(0);
+        Value* last_size = sizes_node->outputs().at(1);
         regular_size->setType(ListType::ofInts());
         last_size->setType(ListType::ofInts());
         auto outputs = n->outputs();
-        for (Value * o : outputs.slice(0, outputs.size() - 1)) {
+        for (Value* o : outputs.slice(0, outputs.size() - 1)) {
           shape_of.emplace(o, regular_size);
         }
         shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
         continue;
       }
-      auto tensor_inputs = filter(n->inputs(),
-                                  [](Value * v) { return v->type()->isSubtypeOf(DynamicType::get()); });
-      auto shapes = fmap(tensor_inputs, [&](Value * v) { return shape_of.at(v); });
+      auto tensor_inputs = filter(n->inputs(), [](Value* v) {
+        return v->type()->isSubtypeOf(DynamicType::get());
+      });
+      auto shapes =
+          fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); });
       JIT_ASSERT(!shapes.empty());
-      shape_of.emplace(n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
+      shape_of.emplace(
+          n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
     }
     return shape_of;
   }
 
-  void removeOutputsUsedOnlyInSize(Node * fusion_group) {
-    if (fusion_group->kind() != prim::FusionGroup) return;
+  void removeOutputsUsedOnlyInSize(Node* fusion_group) {
+    if (fusion_group->kind() != prim::FusionGroup)
+      return;
     auto subgraph = fusion_group->g(attr::Subgraph);
 
     auto shape_of = buildShapeExpressions(fusion_group);
@@ -915,9 +962,9 @@ struct GraphFuser {
     // where f, g, h, l are simple map ops.
     // The first iteration will fuse %4 and %3, and see that %1 is an input, but
     // can't be fused, because it has a different use before the fusion group
-    // in our topological ordering. Then, %2 will be considered, and fused with %1.
-    // If we do another iteration, the algorithm will consider the fusion of these
-    // two groups and fix the situation.
+    // in our topological ordering. Then, %2 will be considered, and fused with
+    // %1. If we do another iteration, the algorithm will consider the fusion of
+    // these two groups and fix the situation.
     bool any_changed = true;
     while (any_changed) {
       any_changed = false;
@@ -939,23 +986,23 @@ struct GraphFuser {
     }
 
     // Remove outputs that have been added only because we need their size
-    for (Node * n : block_->nodes()) {
+    for (Node* n : block_->nodes()) {
       removeOutputsUsedOnlyInSize(n);
     }
 
-    for (Node * node : block_->nodes()) {
-      for (Block * sub_block : node->blocks()) {
+    for (Node* node : block_->nodes()) {
+      for (Block* sub_block : node->blocks()) {
         GraphFuser(sub_block, graph_).run();
       }
     }
   }
 };
 
-void PeepholeOptimizeShapeExpressions(Block * block) {
+void PeepholeOptimizeShapeExpressions(Block* block) {
   auto nodes = block->nodes();
   for (auto it = nodes.begin(); it != nodes.end(); ++it) {
-    Node * node = *it;
-    for (Block * subblock : node->blocks()) {
+    Node* node = *it;
+    for (Block* subblock : node->blocks()) {
       PeepholeOptimizeShapeExpressions(subblock);
     }
     if (node->kind() == prim::BroadcastSizes) {
@@ -968,19 +1015,19 @@ void PeepholeOptimizeShapeExpressions(Block * block) {
       // Deduplicate inputs, but use their unique() values to ensure
       // this process only depends on the graph.
       std::map<size_t, Value*> unique_to_value;
-      for (Value * input : node->inputs()) {
+      for (Value* input : node->inputs()) {
         unique_to_value.emplace(input->unique(), input);
       }
       if (unique_to_value.size() != node->inputs().size()) {
         std::vector<Value*> inputs;
         inputs.reserve(unique_to_value.size());
-        for (auto & entry : unique_to_value) {
+        for (auto& entry : unique_to_value) {
           inputs.push_back(entry.second);
         }
         if (inputs.size() == 1) {
           node->output()->replaceAllUsesWith(inputs[0]);
         } else {
-          WithInsertPoint insert_guard { node };
+          WithInsertPoint insert_guard{node};
           node->output()->replaceAllUsesWith(broadcastSizes(inputs));
         }
         it.destroyCurrent();
@@ -988,12 +1035,13 @@ void PeepholeOptimizeShapeExpressions(Block * block) {
         continue;
       }
       // Remove compose simple chains of broadcasts into a single node.
-      const auto & uses = node->output()->uses();
+      const auto& uses = node->output()->uses();
       if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
-        Node * user = uses[0].user;
+        Node* user = uses[0].user;
         user->removeInput(uses[0].offset);
-        // NB: we don't care about deduplication in here, as we will visit user later.
-        for (Value * i : node->inputs()) {
+        // NB: we don't care about deduplication in here, as we will visit user
+        // later.
+        for (Value* i : node->inputs()) {
           user->addInput(i);
         }
         it.destroyCurrent();
@@ -1005,8 +1053,8 @@ void PeepholeOptimizeShapeExpressions(Block * block) {
 } // anonymous namespace
 
 void FuseGraph(std::shared_ptr<Graph>& graph) {
-  // NYI on Windows
-  #ifndef _WIN32
+// NYI on Windows
+#ifndef _WIN32
 
   GraphFuser(graph->block(), graph).run();
   // After FuseGraph some common subexpressions may come back
@@ -1017,7 +1065,8 @@ void FuseGraph(std::shared_ptr<Graph>& graph) {
   // Improve the quality of shape propagation code that was left
   PeepholeOptimizeShapeExpressions(graph->block());
 
-  #endif
+#endif
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 3166b5b..ffc156e 100644 (file)
@@ -2,11 +2,13 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // NB: Be sure to run DCE before fusion, because dead instructions
 // can prevent fusion opportunities from being exploited.
 // On Windows will noop, NYI
 TORCH_API void FuseGraph(std::shared_ptr<Graph>& graph);
 
-}}
+} // namespace jit
+} // namespace torch
index 5dca7f4..17b44bd 100644 (file)
@@ -2,8 +2,12 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-TORCH_API void InlineAutodiffSubgraphs(std::shared_ptr<Graph>& graph, size_t threshold = 5);
+TORCH_API void InlineAutodiffSubgraphs(
+    std::shared_ptr<Graph>& graph,
+    size_t threshold = 5);
 
-}} // namespace torch::jit
+}
+} // namespace torch
index e6e90d6..3684634 100644 (file)
@@ -1,14 +1,15 @@
 #include <torch/csrc/jit/passes/inplace_check.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-void CheckInplace(Block * block) {
+void CheckInplace(Block* block) {
   for (auto node : block->nodes()) {
     if (node->kind() == prim::PythonOp && node->hasAttribute(attr::inplace)) {
       if (node->i(attr::inplace)) {
-        throw std::runtime_error(std::string("inplace ") +
-                                 static_cast<PythonOp*>(node)->name() +
-                                 " not supported in the JIT");
+        throw std::runtime_error(
+            std::string("inplace ") + static_cast<PythonOp*>(node)->name() +
+            " not supported in the JIT");
       }
     }
   }
@@ -18,4 +19,5 @@ void CheckInplace(std::shared_ptr<Graph>& graph) {
   CheckInplace(graph->block());
 }
 
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index e168a20..1e0cbba 100644 (file)
@@ -2,8 +2,10 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 TORCH_API void CheckInplace(std::shared_ptr<Graph>& graph);
 
-}}
+}
+} // namespace torch
index 18fab73..216d08e 100644 (file)
@@ -1,13 +1,14 @@
 #include <torch/csrc/jit/passes/loop_unrolling.h>
 
-#include <torch/csrc/jit/interned_strings.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/interned_strings.h>
 #include <torch/csrc/jit/symbolic_variable.h>
 
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
@@ -15,7 +16,7 @@ static constexpr int64_t kUnrollFactor = 8;
 static constexpr int64_t kMaxBodySize = 32;
 static constexpr int64_t kMaxBodyRepeats = 64;
 
-bool isTrueConstant(Value *val) {
+bool isTrueConstant(Valueval) {
   c10::optional<bool> maybe_value = constant_as<bool>(val);
   return maybe_value && *maybe_value;
 }
@@ -23,17 +24,18 @@ bool isTrueConstant(Value *val) {
 bool isForLoop(Node* node) {
   if (node->kind() != prim::Loop)
     return false;
-  Value *start_cond = node->inputs().at(1);
-  Value *continue_cond = node->blocks().at(0)->outputs().at(0);
+  Valuestart_cond = node->inputs().at(1);
+  Valuecontinue_cond = node->blocks().at(0)->outputs().at(0);
   return isTrueConstant(start_cond) && isTrueConstant(continue_cond);
 }
 
-// Counts the size of this block, stopping and returning once reaches limit instructions.
-int64_t limitedBlockSize(Block *body, int64_t limit) {
+// Counts the size of this block, stopping and returning once reaches limit
+// instructions.
+int64_t limitedBlockSize(Block* body, int64_t limit) {
   auto it = body->nodes().begin();
   auto end = body->nodes().end();
   for (int64_t i = 0; i < limit; ++i, ++it) {
-    for (Block *subblock : it->blocks()) {
+    for (Blocksubblock : it->blocks()) {
       i += limitedBlockSize(subblock, limit - i);
     }
     if (it == end) {
@@ -43,18 +45,19 @@ int64_t limitedBlockSize(Block *body, int64_t limit) {
   return limit;
 }
 
-bool isSmallBlock(Block *body) {
+bool isSmallBlock(Blockbody) {
   return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize;
 }
 
-// XXX: This function can only be called with a loop that is guaranteed to execute EXACTLY ONCE.
-void inlineBody(Node *loop) {
+// XXX: This function can only be called with a loop that is guaranteed to
+// execute EXACTLY ONCE.
+void inlineBody(Node* loop) {
   auto graph = loop->owningGraph();
   auto body = loop->blocks().at(0);
-  WithInsertPoint insert_point_guard { loop };
+  WithInsertPoint insert_point_guard{loop};
 
   std::unordered_map<Value*, Value*> value_map;
-  auto get_value = [&](Value *v) {
+  auto get_value = [&](Valuev) {
     auto it = value_map.find(v);
     if (it != value_map.end())
       return it->second;
@@ -67,32 +70,34 @@ void inlineBody(Node *loop) {
     value_map[body->inputs()[i - 1]] = loop->inputs()[i];
   }
 
-  for (Node *orig : body->nodes()) {
-    Node *clone = graph->insertNode(graph->createClone(orig, get_value));
+  for (Nodeorig : body->nodes()) {
+    Nodeclone = graph->insertNode(graph->createClone(orig, get_value));
     for (size_t i = 0; i < orig->outputs().size(); ++i) {
       value_map[orig->outputs()[i]] = clone->outputs()[i];
     }
   }
   for (size_t i = 0; i < loop->outputs().size(); ++i) {
-    loop->outputs().at(i)->replaceAllUsesWith(get_value(body->outputs().at(i + 1)));
+    loop->outputs().at(i)->replaceAllUsesWith(
+        get_value(body->outputs().at(i + 1)));
   }
-  // XXX: it is extremely important to destroy the loop in here. DCE might not be able
-  // to conclude that it's safe, because the loop might contain side effects.
+  // XXX: it is extremely important to destroy the loop in here. DCE might not
+  // be able to conclude that it's safe, because the loop might contain side
+  // effects.
   loop->destroy();
 }
 
-void repeatBody(Block *body, int64_t times) {
+void repeatBody(Blockbody, int64_t times) {
   // We will be adding nodes to the body, so cache the initial start and end.
   // XXX: they are both inclusive, because the exclusive body_end would point to
-  //      return_node, which would move further away if we were to add nodes, and we
-  //      would enter an infinite loop.
+  //      return_node, which would move further away if we were to add nodes,
+  //      and we would enter an infinite loop.
   auto body_start = body->nodes().begin();
   auto body_end = std::prev(body->nodes().end());
   auto graph = body->owningGraph();
-  WithInsertPoint insert_point_guard { body };
+  WithInsertPoint insert_point_guard{body};
 
   std::unordered_map<Value*, Value*> value_map;
-  auto get_value = [&](Value *v) {
+  auto get_value = [&](Valuev) {
     auto it = value_map.find(v);
     if (it != value_map.end())
       return it->second;
@@ -101,8 +106,8 @@ void repeatBody(Block *body, int64_t times) {
 
   for (int64_t i = 1; i < times; ++i) {
     // Update loop-carried values
-    // NB: note that we don't need to worry about the loop counter, because we've
-    //     replaced it with a loop-carried variable
+    // NB: note that we don't need to worry about the loop counter, because
+    // we've replaced it with a loop-carried variable
     JIT_ASSERT(body->inputs().size() == body->outputs().size());
     for (size_t i = 1; i < body->inputs().size(); ++i) {
       value_map[body->inputs()[i]] = get_value(body->outputs()[i]);
@@ -110,8 +115,8 @@ void repeatBody(Block *body, int64_t times) {
 
     // Clone the nodes
     for (auto it = body_start; it != std::next(body_end); ++it) {
-      Node *orig = *it;
-      Node *clone = graph->insertNode(graph->createClone(orig, get_value));
+      Nodeorig = *it;
+      Nodeclone = graph->insertNode(graph->createClone(orig, get_value));
       for (size_t i = 0; i < orig->outputs().size(); ++i) {
         value_map[orig->outputs()[i]] = clone->outputs()[i];
       }
@@ -123,50 +128,52 @@ void repeatBody(Block *body, int64_t times) {
   for (int64_t i = new_outputs.size() - 1; i >= 0; --i) {
     body->eraseOutput(i);
   }
-  for (Value *output : new_outputs) {
+  for (Valueoutput : new_outputs) {
     body->registerOutput(output);
   }
 
-  // It's likely that we have some dead nodes now - for example the "true" constant
-  // that prevents the loop from breaking. We shouldn't wait too long before removing
-  // them because they might artificially increase the loop size and prevent outer loop
-  // unrolling.
+  // It's likely that we have some dead nodes now - for example the "true"
+  // constant that prevents the loop from breaking. We shouldn't wait too long
+  // before removing them because they might artificially increase the loop size
+  // and prevent outer loop unrolling.
   EliminateDeadCode(body, false);
 }
 
-// Replaces the builtin loop counter with a "mutable" variable outside of the loop.
-void replaceLoopCounter(Node *loop) {
-  Graph *graph = loop->owningGraph();
-  Block *body = loop->blocks().at(0);
+// Replaces the builtin loop counter with a "mutable" variable outside of the
+// loop.
+void replaceLoopCounter(Node* loop) {
+  Graph* graph = loop->owningGraph();
+  Block* body = loop->blocks().at(0);
   WithInsertPoint guard(loop);
   Value* init_counter = graph->insertConstant(0);
 
   loop->insertInput(2, init_counter);
   loop->insertOutput(0)->setType(IntType::get());
 
-  Value * internal_counter = body->insertInput(1)->setType(init_counter->type());
+  Value* internal_counter = body->insertInput(1)->setType(init_counter->type());
   body->inputs()[0]->replaceAllUsesWith(internal_counter);
 
-  WithInsertPoint insertPointGuard{ body->return_node() };
+  WithInsertPoint insertPointGuard{body->return_node()};
   Value* result = graph->insert(aten::add, {internal_counter, 1});
   body->insertOutput(1, result);
 }
 
-void unroll(Node *loop) {
-  Graph *graph = loop->owningGraph();
-  Block *body = loop->blocks().at(0);
+void unroll(Nodeloop) {
+  Graphgraph = loop->owningGraph();
+  Blockbody = loop->blocks().at(0);
   if (!isSmallBlock(body))
     return;
 
-  // We will be using a "mutable" counter outside of the loop instead of the default
-  // one, because this will allow us to share it between the unrolled loop and its epilogue.
-  // This is necessary only if the loop counter is actually used in the body.
+  // We will be using a "mutable" counter outside of the loop instead of the
+  // default one, because this will allow us to share it between the unrolled
+  // loop and its epilogue. This is necessary only if the loop counter is
+  // actually used in the body.
   if (body->inputs()[0]->uses().size() > 0)
     replaceLoopCounter(loop);
 
-  // Some optimization for constant-length loops. If we know they won't run too many
-  // times, then we can unroll them entirely.
-  Value *trip_count = loop->inputs().at(0);
+  // Some optimization for constant-length loops. If we know they won't run too
+  // many times, then we can unroll them entirely.
+  Valuetrip_count = loop->inputs().at(0);
   int64_t const_len = constant_as<int64_t>(trip_count).value_or(-1);
   if (const_len != -1 && const_len < kMaxBodyRepeats) {
     repeatBody(body, const_len);
@@ -174,11 +181,11 @@ void unroll(Node *loop) {
     return;
   }
 
-  WithInsertPoint insert_point_guard { loop };
+  WithInsertPoint insert_point_guard{loop};
 
   // Clone the loop before we unroll it. The clone will become the epilogue.
-  Node *loop_epilogue = graph->createClone(loop, [](Value *v) { return v; })
-                             ->insertAfter(loop);
+  Node* loop_epilogue =
+      graph->createClone(loop, [](Value* v) { return v; })->insertAfter(loop);
   for (size_t i = 0; i < loop->outputs().size(); ++i) {
     loop->outputs()[i]->replaceAllUsesWith(loop_epilogue->outputs()[i]);
     loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
@@ -188,16 +195,24 @@ void unroll(Node *loop) {
 
   // Change the iteration counts of both loops
   Value* iter_count = loop->inputs().at(0);
-  Value* unrolled_iter_count = graph->insert(aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
+  Value* unrolled_iter_count = graph->insert(
+      aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
   loop->replaceInput(0, unrolled_iter_count);
-  loop_epilogue->replaceInput(0, graph->insert(aten::sub, {iter_count, graph->insert(aten::mul,{unrolled_iter_count , kUnrollFactor})}));
+  loop_epilogue->replaceInput(
+      0,
+      graph->insert(
+          aten::sub,
+          {iter_count,
+           graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
 }
 
-void UnrollLoops(Block *block) {
+void UnrollLoops(Blockblock) {
   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
-    // XXX: unroll might destroy the current node, so we need to pre-increment the iterator
-    Node *node = *it; ++it;
-    for (Block *subblock : node->blocks()) {
+    // XXX: unroll might destroy the current node, so we need to pre-increment
+    // the iterator
+    Node* node = *it;
+    ++it;
+    for (Block* subblock : node->blocks()) {
       UnrollLoops(subblock);
     }
     if (isForLoop(node)) {
@@ -213,4 +228,5 @@ void UnrollLoops(std::shared_ptr<Graph>& graph) {
   EliminateDeadCode(graph);
 }
 
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index 23b1b0f..f2d6ed6 100644 (file)
@@ -2,8 +2,10 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 TORCH_API void UnrollLoops(std::shared_ptr<Graph>& graph);
 
-}} // namespace torch::jit
+}
+} // namespace torch
index 2912307..f64d66d 100644 (file)
@@ -1,10 +1,11 @@
 #include <torch/csrc/jit/passes/lower_grad_of.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 void LowerGradOf(Graph& g) {
-  for(auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
-    if(it->kind() == prim::GradOf) {
+  for (auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
+    if (it->kind() == prim::GradOf) {
       // if any_defined(inputs):
       //  outputs = <original_computation>
       // else:
@@ -13,8 +14,8 @@ void LowerGradOf(Graph& g) {
       auto cond = g.insertNode(g.create(prim::AnyDefined, it->inputs()))
                       ->output()
                       ->setType(IntType::get());
-      auto if_stat = g.insertNode(
-          g.create(prim::If, {cond}, it->outputs().size()));
+      auto if_stat =
+          g.insertNode(g.create(prim::If, {cond}, it->outputs().size()));
       if_stat->addBlock()->cloneFrom(
           it->blocks().at(0), [](Value* v) { return v; });
       auto else_block = if_stat->addBlock();
@@ -31,4 +32,5 @@ void LowerGradOf(Graph& g) {
   }
 }
 
-}}
+} // namespace jit
+} // namespace torch
index a431e48..63f52d6 100644 (file)
@@ -2,7 +2,8 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // This pass removes 'grad_of' nodes, replacing them with conditionals of
 // the form:
@@ -12,4 +13,5 @@ namespace torch { namespace jit {
 //  outputs = undefineds
 TORCH_API void LowerGradOf(Graph& g);
 
-}}
+} // namespace jit
+} // namespace torch
index cde02d6..9c1764c 100644 (file)
@@ -1,9 +1,10 @@
-#include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/lower_tuples.h>
 #include <torch/csrc/utils/functional.h>
-#include <torch/csrc/jit/assertions.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
@@ -11,19 +12,19 @@ namespace {
 // this is to assert we are only doing modifications when we know
 // we can flatten tuples
 std::unordered_set<Symbol> white_list = {
-  prim::If,
-  prim::Loop,
-  prim::TupleUnpack,
-  prim::TupleConstruct,
-  prim::TupleIndex,
-  prim::TupleSlice,
-  prim::Param,
-  prim::Return,
+    prim::If,
+    prim::Loop,
+    prim::TupleUnpack,
+    prim::TupleConstruct,
+    prim::TupleIndex,
+    prim::TupleSlice,
+    prim::Param,
+    prim::Return,
 };
 
-void removeTupleNodes(Node *n, bool must_remove_tuples) {
-  if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex
-      && n->kind() != prim::TupleSlice) {
+void removeTupleNodes(Noden, bool must_remove_tuples) {
+  if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex &&
+      n->kind() != prim::TupleSlice) {
     return;
   }
   auto construct = n->input()->node();
@@ -34,7 +35,7 @@ void removeTupleNodes(Node *n, bool must_remove_tuples) {
     return;
   }
   if (n->kind() == prim::TupleUnpack) {
-    for(size_t i = 0; i < n->outputs().size(); ++i) {
+    for (size_t i = 0; i < n->outputs().size(); ++i) {
       n->outputs()[i]->replaceAllUsesWith(construct->inputs().at(i));
     }
   } else if (n->kind() == prim::TupleIndex) {
@@ -55,35 +56,40 @@ void removeTupleNodes(Node *n, bool must_remove_tuples) {
   }
 }
 
-} //anonymous namespace
+} // anonymous namespace
 
 static void LowerAllTuples(Block* block);
 
 static void VisitNode(Node* n, Node* insert_point) {
-  auto & graph = *n->owningGraph();
+  auto& graph = *n->owningGraph();
 
   // tuple construction operators will become dead when the unpacks are replaced
-  if(n->kind() == prim::TupleConstruct) {
+  if (n->kind() == prim::TupleConstruct) {
     return;
   }
 
-  // note: changing the second argument to false changes this pass from a complete lowering
-  // pass to one that removes tuples when possible. When tuples are first-class
-  // in the interpreter, we should still run this pass to remove extraneous uses
+  // note: changing the second argument to false changes this pass from a
+  // complete lowering pass to one that removes tuples when possible. When
+  // tuples are first-class in the interpreter, we should still run this pass to
+  // remove extraneous uses
 
-  if(n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
+  if (n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
       n->kind() == prim::TupleSlice) {
-     removeTupleNodes(n, /*must_remove_tuples*/true);
-     return;
+    removeTupleNodes(n, /*must_remove_tuples*/ true);
+    return;
   }
 
   // flatten the input list  op(a, tup, b) --> op(a, t0, t1, b)
-  for(size_t i = 0; i < n->inputs().size();) {
+  for (size_t i = 0; i < n->inputs().size();) {
     auto input = n->inputs()[i];
-    if(TupleTypePtr tt = input->type()->cast<TupleType>()) {
-      JIT_ASSERTM(white_list.count(n->kind()) > 0, "tuple appears in op that does not forward tuples");
-      JIT_ASSERTM(input->node()->kind() == prim::TupleConstruct, "tuple use not matched to tuple construct");
-      for(size_t j = 0; j < tt->elements().size(); ++j) {
+    if (TupleTypePtr tt = input->type()->cast<TupleType>()) {
+      JIT_ASSERTM(
+          white_list.count(n->kind()) > 0,
+          "tuple appears in op that does not forward tuples");
+      JIT_ASSERTM(
+          input->node()->kind() == prim::TupleConstruct,
+          "tuple use not matched to tuple construct");
+      for (size_t j = 0; j < tt->elements().size(); ++j) {
         n->insertInput(i + 1 + j, input->node()->inputs().at(j));
       }
       n->removeInput(i);
@@ -94,23 +100,26 @@ static void VisitNode(Node* n, Node* insert_point) {
       ++i;
     }
   }
-  for(auto b : n->blocks()) {
+  for (auto b : n->blocks()) {
     LowerAllTuples(b);
   }
 
   // flatten the outputs list
-  for(size_t i = 0; i < n->outputs().size();) {
-    Value * output = n->outputs()[i];
+  for (size_t i = 0; i < n->outputs().size();) {
+    Value* output = n->outputs()[i];
     // (a, b, tup, c) -> (a, b, t0, t1, c)
     // and:
     //    tup = (t0, t1)
     // is placed at the current insertion point
-    if(TupleTypePtr tt = output->type()->cast<TupleType>()) {
-      JIT_ASSERTM(white_list.count(n->kind()) > 0, "tuple appears in op that does not forward tuples");
-      for(size_t j = 0; j < tt->elements().size(); j++) {
+    if (TupleTypePtr tt = output->type()->cast<TupleType>()) {
+      JIT_ASSERTM(
+          white_list.count(n->kind()) > 0,
+          "tuple appears in op that does not forward tuples");
+      for (size_t j = 0; j < tt->elements().size(); j++) {
         n->insertOutput(i + 1 + j)->setType(tt->elements()[j]);
       }
-      auto new_tup = graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
+      auto new_tup =
+          graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
       new_tup->insertBefore(insert_point);
       insert_point = new_tup;
       output->replaceAllUsesWith(new_tup->output());
@@ -127,7 +136,8 @@ static void LowerAllTuples(Block* block) {
   // _outputs_ of normal instructions, since the param_node represents the
   // parameters as outputs, we can handle it by simply visiting the node
   VisitNode(block->param_node(), *block->nodes().begin());
-  for(auto it = block->nodes().begin(), end = block->nodes().end(); it != end;) {
+  for (auto it = block->nodes().begin(), end = block->nodes().end();
+       it != end;) {
     auto n = *it++;
     VisitNode(n, *it);
   }
@@ -138,11 +148,10 @@ static void LowerAllTuples(Block* block) {
   VisitNode(block->return_node(), nullptr);
 }
 
-
 static void EnsureNoTuples(ArrayRef<Value*> values) {
-  for (Value * v : values) {
-    JIT_ASSERTM(v->type()->kind() != TypeKind::TupleType,
-                "Couldn't lower all tuples.");
+  for (Value* v : values) {
+    JIT_ASSERTM(
+        v->type()->kind() != TypeKind::TupleType, "Couldn't lower all tuples.");
   }
 }
 
@@ -162,9 +171,9 @@ void LowerAllTuples(std::shared_ptr<Graph>& graph) {
 }
 
 void LowerSimpleTuples(Block* block) {
-  for(auto n : block->nodes()) {
-    removeTupleNodes(n, /*must_remove_tuples*/false);
-    for(auto b : n->blocks()) {
+  for (auto n : block->nodes()) {
+    removeTupleNodes(n, /*must_remove_tuples*/ false);
+    for (auto b : n->blocks()) {
       LowerSimpleTuples(b);
     }
   }
@@ -175,4 +184,5 @@ void LowerSimpleTuples(std::shared_ptr<Graph>& graph) {
   EliminateDeadCode(graph);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 4e7c990..2cb783a 100644 (file)
@@ -2,7 +2,8 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // removes tuples where TupleConstruct and TupleUnpack are matched
 // but leaves tuples in place across if statements, loops, and as inputs/outputs
@@ -15,4 +16,5 @@ TORCH_API void LowerAllTuples(std::shared_ptr<Graph>& graph);
 
 TORCH_API void LowerSimpleTuples(Block* block);
 
-}}
+} // namespace jit
+} // namespace torch
index 98aa05f..e10f6aa 100644 (file)
@@ -1,24 +1,31 @@
-#include <torch/csrc/utils/pybind.h>
-#include <torch/csrc/jit/passes/onnx.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/symbolic.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/onnx.h>
 #include <torch/csrc/utils/functional.h>
-#include <unordered_map>
+#include <torch/csrc/utils/pybind.h>
 #include <sstream>
+#include <unordered_map>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // Transform PythonOps into Nodes that match ONNX semantics.
-std::shared_ptr<Graph> ToONNX(std::shared_ptr<Graph>& graph, ::torch::onnx::OperatorExportTypes operator_export_type) {
+std::shared_ptr<Graph> ToONNX(
+    std::shared_ptr<Graph>& graph,
+    ::torch::onnx::OperatorExportTypes operator_export_type) {
   auto new_graph = std::make_shared<Graph>(graph->current_scope());
   std::unordered_map<Value*, Value*> env;
   BlockToONNX(graph->block(), new_graph->block(), operator_export_type, env);
   return new_graph;
 }
 
-void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExportTypes operator_export_type, std::unordered_map<Value*, Value*> env) {
+void BlockToONNX(
+    Block* old_block,
+    Block* new_block,
+    ::torch::onnx::OperatorExportTypes operator_export_type,
+    std::unordered_map<Value*, Value*> env) {
   torch::autograd::SymbolicContext ctx{};
   ctx.block = new_block;
 
@@ -26,7 +33,7 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
   py::object onnx_symbolic = py::module::import("torch.onnx.symbolic");
 
   // Returns a node that n maps to in the new graph
-  auto envFn = [&env](Value * n) -> Value* {
+  auto envFn = [&env](Value* n) -> Value* {
     auto it = env.find(n);
     JIT_ASSERTM(it != env.end(), "Dangling node reference");
     JIT_ASSERTM(it->second, "Unused node was subsequently used");
@@ -41,13 +48,16 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
   // Put the new outputs in our environment map, and copy the type from the
   // input graph if they were not set by the symbolic. This is called only
   // with results of symbolic call (not for nodes that are just cloned).
-  auto setOutputs = [&](const std::string& op_name, Node * node, const value_list & outputs) {
+  auto setOutputs = [&](const std::string& op_name,
+                        Node* node,
+                        const value_list& outputs) {
     auto old_outputs = node->outputs();
     // Count all outputs, excluding Handles
     auto num_old_outputs = old_outputs.size();
     if (outputs.size() != num_old_outputs) {
       std::ostringstream ss;
-      ss << "symbolic for " << op_name << " produced an incorrect number of outputs (expected ";
+      ss << "symbolic for " << op_name
+         << " produced an incorrect number of outputs (expected ";
       ss << num_old_outputs << ", but got " << outputs.size() << ")";
       throw std::runtime_error(ss.str());
     }
@@ -69,7 +79,8 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
         env[old] = nullptr;
         if (!old->uses().empty()) {
           std::ostringstream ss;
-          ss << "symbolic for " << op_name << " returned None for the output " << i;
+          ss << "symbolic for " << op_name << " returned None for the output "
+             << i;
           ss << " (indicating conversion for that particular output is not supported), ";
           ss << "but the network uses this output later";
           // TODO: Say what actually used it
@@ -80,16 +91,19 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
   };
 
   // Clone the node and add it to the new graph
-  auto cloneNode = [&](Node * node) {
-    auto n_ = ctx.block->appendNode(ctx.block->owningGraph()->createClone(node, envFn));
-    for(size_t i = 0; i < node->outputs().size(); i++) {
+  auto cloneNode = [&](Node* node) {
+    auto n_ = ctx.block->appendNode(
+        ctx.block->owningGraph()->createClone(node, envFn));
+    for (size_t i = 0; i < node->outputs().size(); i++) {
       // n_->outputs()[i]->setType(node->outputs()[i]->type());
       env[node->outputs()[i]] = n_->outputs()[i];
     }
   };
 
   // Cast output of symbolic() python implementation
-  auto processSymbolicOutput = [&](const std::string& op_name, Node* n, const py::object& raw_output) {
+  auto processSymbolicOutput = [&](const std::string& op_name,
+                                   Node* n,
+                                   const py::object& raw_output) {
     if (raw_output.ptr() == Py_None) {
       cloneNode(n);
       return;
@@ -120,12 +134,13 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
     py::tuple py_inputs(n->inputs().size());
     Py_ssize_t input_nr = 0;
     for (auto* input : n->inputs()) {
-        py_inputs[input_nr++] = py::cast(envFn(input));
+      py_inputs[input_nr++] = py::cast(envFn(input));
     }
 
     WithInsertPoint insert_point_guard(ctx.block);
     WithCurrentScope scope_guard(*ctx.block->owningGraph(), n->scope());
-    py::object raw_output = onnx.attr("_run_symbolic_function")(ctx.block->owningGraph(), n, py_inputs, env, operator_export_type);
+    py::object raw_output = onnx.attr("_run_symbolic_function")(
+        ctx.block->owningGraph(), n, py_inputs, env, operator_export_type);
 
     // TODO: Assert it's an ATen identifier???
     // (Sometimes it's not...)
@@ -133,15 +148,14 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
   };
 
   auto callPySymbolicMethod = [&](PythonOp* op) {
-
     // Test if there is a symbolic function; bail if there is not
     auto pyobj = py::handle(op->pyobj.get());
     auto func = op->autogradFunction();
-    if(func) {
+    if (func) {
       pyobj = func->get();
     }
 
-    if(!py::hasattr(pyobj, "symbolic")) {
+    if (!py::hasattr(pyobj, "symbolic")) {
       cloneNode(op);
       return;
     }
@@ -157,8 +171,11 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
     for (auto arg_type : op->cconv) {
       py::object obj;
       if (arg_type == 'c') {
-        JIT_ASSERTM(scalar_it != op->scalar_args.end(), "expected too many scalar args");
-        obj = py::reinterpret_borrow<py::object>(py::handle((scalar_it++)->get()));
+        JIT_ASSERTM(
+            scalar_it != op->scalar_args.end(),
+            "expected too many scalar args");
+        obj = py::reinterpret_borrow<py::object>(
+            py::handle((scalar_it++)->get()));
       } else if (arg_type == 'd') {
         JIT_ASSERTM(node_it != inputs.end(), "expected too many inputs");
         obj = py::cast(envFn(*node_it++));
@@ -173,7 +190,8 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
     // Call the symbolic function
     // Use a little trampoline function so we can give good error messages
     // upon argument mismatch
-    py::object raw_output = onnx.attr("_run_symbolic_method")(op->name(), pyobj.attr("symbolic"), py_symbolic_args);
+    py::object raw_output = onnx.attr("_run_symbolic_method")(
+        op->name(), pyobj.attr("symbolic"), py_symbolic_args);
 
     processSymbolicOutput(op->name(), op, raw_output);
   };
@@ -181,9 +199,9 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
   // Finally, visit all nodes in the graph
   for (auto node : old_block->nodes()) {
     IR_IFM(node, PythonOp)
-      callPySymbolicMethod(value);
+    callPySymbolicMethod(value);
     IR_ELSE()
-      callPySymbolicFunction(node);
+    callPySymbolicFunction(node);
     IR_END()
   }
   for (auto output : old_block->outputs()) {
@@ -194,4 +212,5 @@ void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExpo
   EliminateDeadCode(ctx.block);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 2cc5045..97a3625 100644 (file)
@@ -3,9 +3,17 @@
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/onnx/onnx.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-TORCH_API std::shared_ptr<Graph> ToONNX(std::shared_ptr<Graph>& state, ::torch::onnx::OperatorExportTypes operator_export_type);
-TORCH_API void BlockToONNX(Block* old_block, Block* new_block, ::torch::onnx::OperatorExportTypes operator_export_type, std::unordered_map<Value*, Value*> env);
+TORCH_API std::shared_ptr<Graph> ToONNX(
+    std::shared_ptr<Graph>& state,
+    ::torch::onnx::OperatorExportTypes operator_export_type);
+TORCH_API void BlockToONNX(
+    Block* old_block,
+    Block* new_block,
+    ::torch::onnx::OperatorExportTypes operator_export_type,
+    std::unordered_map<Value*, Value*> env);
 
-}}
+} // namespace jit
+} // namespace torch
index 22744c4..bc6ba90 100644 (file)
@@ -1,15 +1,16 @@
 #include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-void FixupONNXLoops(Block *block) {
-  for (auto *node : block->nodes()) {
+void FixupONNXLoops(Blockblock) {
+  for (autonode : block->nodes()) {
     if (node->kind() == torch::jit::onnx::Loop) {
       JIT_ASSERT(node->blocks().size() == 1);
-      auto *sub_block = node->blocks()[0];
+      autosub_block = node->blocks()[0];
       sub_block->insertInput(1, "cond");
     }
-    for (Block * block : node->blocks()) {
+    for (Block* block : node->blocks()) {
       FixupONNXLoops(block);
     }
   }
@@ -19,4 +20,5 @@ void FixupONNXLoops(std::shared_ptr<Graph>& graph) {
   FixupONNXLoops(graph->block());
 }
 
-}}  // namespace torch::jit
+} // namespace jit
+} // namespace torch
index 30de755..096d128 100644 (file)
@@ -2,8 +2,10 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 void FixupONNXLoops(std::shared_ptr<Graph>& graph);
 
-}}
+}
+} // namespace torch
index 5d3f477..6987b93 100644 (file)
@@ -1,5 +1,5 @@
-#include <torch/csrc/jit/passes/onnx/peephole.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/passes/onnx/peephole.h>
 
 #include <c10/util/Optional.h>
 
@@ -8,14 +8,15 @@
 typedef SSIZE_T ssize_t;
 #endif
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-bool isRNN(const Node *node) {
+bool isRNN(const Nodenode) {
   auto k = node->kind();
   return k == onnx::RNN || k == onnx::LSTM || k == onnx::GRU;
 }
 
-bool isNopTranspose(const std::vector<int64_t> & perm) {
+bool isNopTranspose(const std::vector<int64_t>& perm) {
   for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++)
     if (perm[i] != i)
       return false;
@@ -31,8 +32,9 @@ bool isNopTranspose(const std::vector<int64_t> & perm) {
 // iteration would have folded all the transposes up to that point. Thus,
 // `ret[i] = t1[t2[i]]` says "the output of t2 at position i takes the value of
 // the input tensor index contained in t1 at position `t2[i]``".
-std::vector<int64_t> composeTransposes(const std::vector<int64_t> & t1,
-                                       const std::vector<int64_t> & t2) {
+std::vector<int64_t> composeTransposes(
+    const std::vector<int64_t>& t1,
+    const std::vector<int64_t>& t2) {
   JIT_ASSERT(t1.size() == t2.size());
   std::vector<int64_t> ret;
   ret.reserve(t1.size());
@@ -87,9 +89,9 @@ c10::optional<size_t> fusibleExpandTo(at::IntList from, at::IntList to) {
   return to.size() - from.size();
 }
 
-void fuseBroadcast(Block *b) {
-  for(auto n : b->nodes()) {
-    for (auto *child_block : n->blocks()) {
+void fuseBroadcast(Blockb) {
+  for (auto n : b->nodes()) {
+    for (autochild_block : n->blocks()) {
       fuseBroadcast(child_block);
     }
 
@@ -134,14 +136,18 @@ void fuseBroadcast(Block *b) {
   }
 }
 
-void fuseConsecutiveTransposes(Block *b) {
-  for(auto n : b->nodes()) {
-    for (auto *child_block : n->blocks()) {
+void fuseConsecutiveTransposes(Blockb) {
+  for (auto n : b->nodes()) {
+    for (autochild_block : n->blocks()) {
       fuseConsecutiveTransposes(child_block);
     }
-    if (n->kind() == onnx::Transpose && n->input()->node()->kind() == onnx::Transpose) {
+    if (n->kind() == onnx::Transpose &&
+        n->input()->node()->kind() == onnx::Transpose) {
       auto origInput = n->input();
-      n->is_(attr::perm, composeTransposes(origInput->node()->is(attr::perm), n->is(attr::perm)));
+      n->is_(
+          attr::perm,
+          composeTransposes(
+              origInput->node()->is(attr::perm), n->is(attr::perm)));
       n->replaceInput(0, origInput->node()->input());
       if (origInput->uses().size() == 0) {
         origInput->node()->destroy();
@@ -151,10 +157,10 @@ void fuseConsecutiveTransposes(Block *b) {
   }
 }
 
-void eliminateNopTranspose(Block *b) {
-  for(auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+void eliminateNopTranspose(Blockb) {
+  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
     auto n = *it;
-    for (auto *child_block : n->blocks()) {
+    for (autochild_block : n->blocks()) {
       eliminateNopTranspose(child_block);
     }
     if (n->kind() == onnx::Transpose) {
@@ -167,18 +173,19 @@ void eliminateNopTranspose(Block *b) {
   }
 }
 
-void fuseTransposeIntoGemm(Block *b) {
-  static const std::vector<int64_t> simpleTransPerm({1,0});
+void fuseTransposeIntoGemm(Blockb) {
+  static const std::vector<int64_t> simpleTransPerm({1, 0});
 
-  for(auto n : b->nodes()) {
-    for (auto *child_block : n->blocks()) {
+  for (auto n : b->nodes()) {
+    for (autochild_block : n->blocks()) {
       fuseTransposeIntoGemm(child_block);
     }
     if (n->kind() == onnx::Gemm) {
-      for (size_t i : {0,1}) {
+      for (size_t i : {0, 1}) {
         auto inp = n->inputs()[i];
         auto trans = i == 0 ? attr::transA : attr::transB;
-        if (inp->node()->kind() == onnx::Transpose && inp->node()->is(attr::perm) == simpleTransPerm) {
+        if (inp->node()->kind() == onnx::Transpose &&
+            inp->node()->is(attr::perm) == simpleTransPerm) {
           n->replaceInput(i, inp->node()->input());
           n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
           if (inp->uses().size() == 0) {
@@ -207,10 +214,10 @@ void fuseTransposeIntoGemm(Block *b) {
 //   entirely by pairing them with their inverse PadPacked. If the
 //   input graph does not pair the operations, export will fail.
 
-void pushPackingPastRnn(Block *b) {
+void pushPackingPastRnn(Blockb) {
   for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
     auto* n = *it;
-    for (auto *child_block : n->blocks()) {
+    for (autochild_block : n->blocks()) {
       pushPackingPastRnn(child_block);
     }
 
@@ -221,17 +228,18 @@ void pushPackingPastRnn(Block *b) {
       // For now, only handle the case where there is one consumer.
       continue;
     }
-    Node * rnn = n->outputs()[0]->uses()[0].user;
+    Node* rnn = n->outputs()[0]->uses()[0].user;
     if (!isRNN(rnn)) {
       continue;
     }
 
-    if(rnn->owningBlock() != n->owningBlock())
+    if (rnn->owningBlock() != n->owningBlock())
       continue;
 
-    // Packing only has an effect on a network when its outputs are actually used,
-    // so we can remove it here.
-    if (rnn->outputs().at(0)->uses().empty() && n->outputs().at(1)->uses().size() == 1) {
+    // Packing only has an effect on a network when its outputs are actually
+    // used, so we can remove it here.
+    if (rnn->outputs().at(0)->uses().empty() &&
+        n->outputs().at(1)->uses().size() == 1) {
       n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));
       n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
       it.destroyCurrent();
@@ -240,7 +248,7 @@ void pushPackingPastRnn(Block *b) {
 
     // The rnn is followed by a transpose and a reshape (if
     // bidirectional), or by a squeeze (if unidirectional).
-    Node * next = rnn->outputs().at(0)->uses().at(0).user;
+    Node* next = rnn->outputs().at(0)->uses().at(0).user;
     if (next->kind() == onnx::Transpose) {
       next = next->outputs().at(0)->uses().at(0).user;
       if (next->kind() != onnx::Reshape) {
@@ -258,7 +266,7 @@ void pushPackingPastRnn(Block *b) {
     n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
 
     // and insert new PackPadded after the RNN
-    Node * newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
+    Node* newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
     newPackPadded->insertAfter(next);
 
     // make things consume from the new PackPadded
@@ -274,7 +282,8 @@ void pushPackingPastRnn(Block *b) {
     // unhygenic way, Pytorch ends up propagating an incorrect type.
     // Until a long-term cleanup comes around, we can fix this by
     // resetting the size to the correct value.
-    CompleteTensorTypePtr oldType = rnn->inputs().at(0)->type()->cast<CompleteTensorType>();
+    CompleteTensorTypePtr oldType =
+        rnn->inputs().at(0)->type()->cast<CompleteTensorType>();
     if (oldType) {
       std::vector<int64_t> new_sizes;
       new_sizes.push_back(oldType->sizes().at(0));
@@ -292,7 +301,7 @@ void pushPackingPastRnn(Block *b) {
 void removeNopPacking(Block* graph) {
   for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
     auto* n = *it;
-    for (auto *child_block : n->blocks()) {
+    for (autochild_block : n->blocks()) {
       removeNopPacking(child_block);
     }
 
@@ -323,7 +332,7 @@ void hackFixupPadPackedShapes(Block* graph) {
   // of its input.
   for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
     auto* n = *it;
-    for (auto *child_block : n->blocks()) {
+    for (autochild_block : n->blocks()) {
       removeNopPacking(child_block);
     }
 
@@ -335,7 +344,7 @@ void hackFixupPadPackedShapes(Block* graph) {
   }
 }
 
-void fixDefaultRNNState(Graph* graph, Node * n, int input_index) {
+void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
   auto initial_state = n->inputs()[input_index];
 
   // The RNN code in pytorch accepts an optional hidden state. When it
@@ -345,54 +354,64 @@ void fixDefaultRNNState(Graph* graph, Node * n, int input_index) {
   // with something that doesn't fix the batch size.  Note that for
   // multi-layer RNNs there will be a Slice operation between the
   // Constant and the RNN.
-  bool needsFixing =
-    initial_state->node()->kind() == onnx::Constant ||
-    (initial_state->node()->kind() == onnx::Slice &&
-     initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);
+  bool needsFixing = initial_state->node()->kind() == onnx::Constant ||
+      (initial_state->node()->kind() == onnx::Slice &&
+       initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);
 
   if (!needsFixing) {
     return;
   }
 
-  Node * shape_of_input = graph->create(onnx::Shape, 1);
+  Node* shape_of_input = graph->create(onnx::Shape, 1);
   shape_of_input->insertBefore(n);
   shape_of_input->addInput(n->inputs()[0]);
 
-  Node * gather_indices = graph->create(onnx::Constant, 1);
+  Node* gather_indices = graph->create(onnx::Constant, 1);
   gather_indices->insertBefore(n);
   gather_indices->t_(attr::value, at::scalar_to_tensor(at::Scalar(1)));
 
-  Node * batch_size = graph->create(onnx::Gather, 1);
+  Node* batch_size = graph->create(onnx::Gather, 1);
   batch_size->insertBefore(n);
   batch_size->addInput(shape_of_input->outputs()[0]);
   batch_size->addInput(gather_indices->outputs()[0]);
 
-  Node * unsqueezed_batch_size = graph->create(onnx::Unsqueeze, 1);
+  Node* unsqueezed_batch_size = graph->create(onnx::Unsqueeze, 1);
   unsqueezed_batch_size->insertBefore(n);
   unsqueezed_batch_size->addInput(batch_size->outputs()[0]);
   unsqueezed_batch_size->is_(attr::axes, {0});
 
-  Node * hidden_size = graph->create(onnx::Constant, 1);
+  Node* hidden_size = graph->create(onnx::Constant, 1);
   hidden_size->insertBefore(n);
-  hidden_size->t_(attr::value, at::full({1}, n->i(attr::hidden_size), at::kLong)); // at::Scalar(n->i(attr::hidden_size)).toTensor());
-
-  Node * num_directions = graph->create(onnx::Constant, 1);
+  hidden_size->t_(
+      attr::value,
+      at::full(
+          {1},
+          n->i(attr::hidden_size),
+          at::kLong)); // at::Scalar(n->i(attr::hidden_size)).toTensor());
+
+  Node* num_directions = graph->create(onnx::Constant, 1);
   num_directions->insertBefore(n);
-  num_directions->t_(attr::value, scalar_to_tensor(at::Scalar(n->hasAttribute(attr::direction) && n->s(attr::direction) == "bidirectional" ? 2 : 1)));
-
-  Node * unsqueezed_num_directions = graph->create(onnx::Unsqueeze, 1);
+  num_directions->t_(
+      attr::value,
+      scalar_to_tensor(at::Scalar(
+          n->hasAttribute(attr::direction) &&
+                  n->s(attr::direction) == "bidirectional"
+              ? 2
+              : 1)));
+
+  Node* unsqueezed_num_directions = graph->create(onnx::Unsqueeze, 1);
   unsqueezed_num_directions->insertBefore(n);
   unsqueezed_num_directions->addInput(num_directions->outputs()[0]);
   unsqueezed_num_directions->is_(attr::axes, {0});
 
-  Node * concated_dims = graph->create(onnx::Concat, 1);
+  Node* concated_dims = graph->create(onnx::Concat, 1);
   concated_dims->insertBefore(n);
   concated_dims->i_(attr::axis, 0);
   concated_dims->addInput(unsqueezed_num_directions->outputs()[0]);
   concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
   concated_dims->addInput(hidden_size->outputs()[0]);
 
-  Node * constant_fill = graph->create(onnx::ConstantFill, 1);
+  Node* constant_fill = graph->create(onnx::ConstantFill, 1);
   constant_fill->insertBefore(n);
   constant_fill->i_(attr::input_as_shape, 1);
   constant_fill->addInput(concated_dims->outputs()[0]);
@@ -406,7 +425,7 @@ void fixDefaultRNNState(Graph* graph, Node * n, int input_index) {
 void fixDefaultRnnHiddenState(Block* b) {
   for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
     auto* n = *it;
-    for (auto *child_block : n->blocks()) {
+    for (autochild_block : n->blocks()) {
       fixDefaultRnnHiddenState(child_block);
     }
 
@@ -422,10 +441,10 @@ void fixDefaultRnnHiddenState(Block* b) {
   }
 }
 
-void fixDefaultLstmCellState(Block *b) {
+void fixDefaultLstmCellState(Blockb) {
   for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
     auto* n = *it;
-    for (auto *child_block : n->blocks()) {
+    for (autochild_block : n->blocks()) {
       fixDefaultLstmCellState(child_block);
     }
 
@@ -446,32 +465,32 @@ static bool isSafeToSpeculate(Node* n) {
 }
 
 static void speculateOps(Block* block) {
-  for(auto it = block->nodes().begin(), end = block->nodes().end();
-      it != end;) {
-    Node * n = *it;
-    ++it; //note: increment first so that it is safe to move the node if needed
+  for (auto it = block->nodes().begin(), end = block->nodes().end();
+       it != end;) {
+    Node* n = *it;
+    ++it; // note: increment first so that it is safe to move the node if needed
 
-    for(auto b : n->blocks()) {
+    for (auto b : n->blocks()) {
       speculateOps(b);
     }
-    if(!isSafeToSpeculate(n))
+    if (!isSafeToSpeculate(n))
       continue;
     // XXX - only works for nodes with a single input
     // move node n outside of the control flow it is nested in
     auto node_input = n->input()->node();
-    if(node_input->owningBlock() == n->owningBlock())
+    if (node_input->owningBlock() == n->owningBlock())
       continue;
     // find the control flow node in the same block as node_input that contains
     // Node n
     auto control_flow_node = n->owningBlock()->owningNode();
-    while(control_flow_node->owningBlock() != node_input->owningBlock())
+    while (control_flow_node->owningBlock() != node_input->owningBlock())
       control_flow_node = control_flow_node->owningBlock()->owningNode();
     // put the node right before this flow node
     n->moveBefore(control_flow_node);
   }
 }
 
-static void replaceInputWithList(Node *node, size_t i, ArrayRef<Value*> to) {
+static void replaceInputWithList(Nodenode, size_t i, ArrayRef<Value*> to) {
   node->removeInput(i);
   for (auto* to_val : to) {
     JIT_ASSERT(to_val->owningGraph() == node->owningGraph());
@@ -494,13 +513,15 @@ static void eraseListConstruct(Block* block) {
     for (auto* input : n->inputs()) {
       if (input->node()->kind() == prim::ListConstruct) {
         auto* lc_node = input->node();
-        TypePtr elem = lc_node->output()->type()->cast<ListType>()->getElementType();
+        TypePtr elem =
+            lc_node->output()->type()->cast<ListType>()->getElementType();
         if (elem->cast<IntType>()) {
-          // ListConstruct Int[] output case, we need to transfrom to ONNX Concat to ensure
-          // the output is a single tensor(dynamic) type in order to be consumed as inputs
+          // ListConstruct Int[] output case, we need to transfrom to ONNX
+          // Concat to ensure the output is a single tensor(dynamic) type in
+          // order to be consumed as inputs
           std::vector<Value*> unsqueezed;
-          Graph *g = block->owningGraph();
-          for (auto* input: lc_node->inputs()) {
+          Graphg = block->owningGraph();
+          for (auto* input : lc_node->inputs()) {
             Node* unsqueezed_node = g->create(onnx::Unsqueeze, 1);
             unsqueezed_node->insertBefore(lc_node);
             unsqueezed_node->addInput(input);
@@ -509,23 +530,25 @@ static void eraseListConstruct(Block* block) {
           }
           Node* concat_node = g->create(onnx::Concat, 1);
           concat_node->i_(attr::axis, 0);
-          for(auto v: unsqueezed) {
+          for (auto v : unsqueezed) {
             concat_node->addInput(v);
           }
           concat_node->insertBefore(lc_node);
 
-          // make concat node output as new input, then ListConstruct should become dead
-          replacements.emplace_back(i, std::vector<Value*>({concat_node->output()}));
+          // make concat node output as new input, then ListConstruct should
+          // become dead
+          replacements.emplace_back(
+              i, std::vector<Value*>({concat_node->output()}));
 
         } else {
-          // Tensor lists are used mostly for inputs to cat/stack. They are already handled
-          // in those symbolics, and should become dead afterwards.
+          // Tensor lists are used mostly for inputs to cat/stack. They are
+          // already handled in those symbolics, and should become dead
+          // afterwards.
           replacements.emplace_back(
               i,
               std::vector<Value*>(
                   lc_node->inputs().begin(), lc_node->inputs().end()));
         }
-
       }
       i++;
     }
@@ -541,9 +564,9 @@ static void eraseListConstruct(Block* block) {
 //
 // At the moment, here are the optimizations it does:
 //  - This optimization fuses expand calls into ONNX operators, because it is
-//    easier for non-strided backends to more efficiently do broadcasts if this is
-//    local information.  This optimization is not useful for PyTorch as 'expand'
-//    is free.
+//    easier for non-strided backends to more efficiently do broadcasts if this
+//    is local information.  This optimization is not useful for PyTorch as
+//    'expand' is free.
 //  - Fusing of consecutive transposes
 //  - Elimination of NOP transposes
 //  - Fusing of transposes into Gemm
@@ -571,4 +594,5 @@ void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
   eraseListConstruct(graph->block());
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 859f111..63e9132 100644 (file)
@@ -1,7 +1,9 @@
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph);
 
-}}
+}
+} // namespace torch
index 32c68a3..10e07c5 100644 (file)
@@ -1,7 +1,8 @@
-#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
 #include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 static void PrepareDivisionForONNXOnBlock(Block* block) {
   for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
@@ -13,16 +14,21 @@ static void PrepareDivisionForONNXOnBlock(Block* block) {
 
     if (it->matches("aten::div(int a, int b) -> float")) {
       // Cast to Float before dividing
-      std::vector<Value*> floattensor_inputs = fmap(it->inputs(), [&](Value* input) {
-        auto* longtensor = subgraph->insertNode(subgraph->createNumToTensor(input))->output();
-        auto* nonblocking = subgraph->insertConstant(0);
-        auto* cast = subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
-        return subgraph->insertNode(cast)->output();
-      });
+      std::vector<Value*> floattensor_inputs =
+          fmap(it->inputs(), [&](Value* input) {
+            auto* longtensor =
+                subgraph->insertNode(subgraph->createNumToTensor(input))
+                    ->output();
+            auto* nonblocking = subgraph->insertConstant(0);
+            auto* cast =
+                subgraph->create(aten::_cast_Float, {longtensor, nonblocking});
+            return subgraph->insertNode(cast)->output();
+          });
 
       it->replaceInput(0, floattensor_inputs[0]);
       it->replaceInput(1, floattensor_inputs[1]);
-      it->output()->setType(CompleteTensorType::fromNumberType(FloatType::get()));
+      it->output()->setType(
+          CompleteTensorType::fromNumberType(FloatType::get()));
     }
   }
 }
@@ -31,5 +37,5 @@ void PrepareDivisionForONNX(const std::shared_ptr<Graph>& graph) {
   PrepareDivisionForONNXOnBlock(graph->block());
 }
 
-}}
-
+} // namespace jit
+} // namespace torch
index 89b9fa1..5744f9c 100644 (file)
@@ -2,7 +2,8 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // Prepare division ops for ONNX export. This is necessary for and only used
 // by ONNX export.
@@ -14,4 +15,5 @@ namespace torch { namespace jit {
 //
 TORCH_API void PrepareDivisionForONNX(const std::shared_ptr<Graph>& graph);
 
-}}
+} // namespace jit
+} // namespace torch
index 130edba..955143c 100644 (file)
@@ -4,7 +4,8 @@
 
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // The intent for this optimization pass is to catch all of the small, easy to
 // catch peephole optimizations you might be interested in doing.
@@ -17,24 +18,28 @@ namespace torch { namespace jit {
 //
 // The parameter `addmm_fusion_enabled` exists because, as it is today, fusing
 // add + mm has no benefit within PyTorch running ATen ops. However, we rely on
-// seeing the fused version of addmm for ONNX export, since after ONNX translation
-// we would see redundant Gemm ops with sub-optimal inputs. This flag is exposed
-// so that ONNX export can pass `true` to get the fused behavior, but normal
-// JIT peephole optimization is left alone.
-void PeepholeOptimizeImpl(Block * block, bool addmm_fusion_enabled) {
+// seeing the fused version of addmm for ONNX export, since after ONNX
+// translation we would see redundant Gemm ops with sub-optimal inputs. This
+// flag is exposed so that ONNX export can pass `true` to get the fused
+// behavior, but normal JIT peephole optimization is left alone.
+void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
   for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
     auto* node = *it;
 
-    for (Block * sub_block : node->blocks()) {
+    for (Block* sub_block : node->blocks()) {
       PeepholeOptimizeImpl(sub_block, addmm_fusion_enabled);
     }
 
-    // XXX: remember that if you want to simplify an expression by combining multiple nodes
-    // into a different one, then you need to check that they all belong to the given block
-    if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
-        /*const_inputs=*/attr::size)) {
+    // XXX: remember that if you want to simplify an expression by combining
+    // multiple nodes into a different one, then you need to check that they all
+    // belong to the given block
+    if (node->matches(
+            "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
+            /*const_inputs=*/attr::size)) {
       // x.expand(x.size()) == x
-      if (auto input_type = node->namedInput(attr::self)->type()->cast<CompleteTensorType>()) {
+      if (auto input_type = node->namedInput(attr::self)
+                                ->type()
+                                ->cast<CompleteTensorType>()) {
         auto expanded_sizes = node->get<std::vector<int64_t>>(attr::size);
         if (expanded_sizes == input_type->sizes()) {
           node->output()->replaceAllUsesWith(node->namedInput(attr::self));
@@ -42,11 +47,12 @@ void PeepholeOptimizeImpl(Block * block, bool addmm_fusion_enabled) {
       }
     } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
       // x.t().t() == x
-      Node *input_node = node->input()->node();
+      Nodeinput_node = node->input()->node();
       if (input_node->matches("aten::t(Tensor self) -> Tensor")) {
         node->output()->replaceAllUsesWith(input_node->input());
       }
-    } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
+    } else if (node->matches(
+                   "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
       // x.type_as(y) == x iff x.type() == y.type()
       auto self_type = node->input(0)->type()->cast<TensorType>();
       auto other_type = node->input(1)->type()->cast<TensorType>();
@@ -55,31 +61,39 @@ void PeepholeOptimizeImpl(Block * block, bool addmm_fusion_enabled) {
           self_type->device() == other_type->device()) {
         node->output()->replaceAllUsesWith(node->input(0));
       }
-    } else if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
-               /*const_inputs=*/attr::alpha)) {
+    } else if (
+        node->matches(
+            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
+            /*const_inputs=*/attr::alpha)) {
       // z + x.mm(y) == z.addmm(x, y) == x.mm(y) + z
-      // This optimization has been disabled at the moment, because it's not helpful at all
-      // until we will be able to represent torch.addmm(a, b, c, out=a). That's because addmm
-      // dispatches internally to gemm, which computes:
+      // This optimization has been disabled at the moment, because it's not
+      // helpful at all until we will be able to represent torch.addmm(a, b, c,
+      // out=a). That's because addmm dispatches internally to gemm, which
+      // computes:
       //   C = beta * C + alpha * A @ B
       // but aten::addmm(a, b, c, 1, 1) is really:
       //   D = beta * C + alpha * A @ B
-      // and because it works out of place on C, we're only trading off an explicit add for
-      // a copy inside the addmm function. Note that it doesn't even result in fewer reads,
-      // because mm won't even load C (because beta == 0 for it).
-      if (addmm_fusion_enabled && node->get<at::Scalar>(attr::alpha).value().toDouble() == 1.) {
+      // and because it works out of place on C, we're only trading off an
+      // explicit add for a copy inside the addmm function. Note that it doesn't
+      // even result in fewer reads, because mm won't even load C (because beta
+      // == 0 for it).
+      if (addmm_fusion_enabled &&
+          node->get<at::Scalar>(attr::alpha).value().toDouble() == 1.) {
         // Look for mm from both sides of the add
         for (size_t mm_side = 0; mm_side < 2; mm_side++) {
+          // Add will accept tensors of mismatched scalar types, as long as one
+          // of them is a scalar. Addmm will throw in that case, so we can only
+          // perform this fusion if we're sure that it is correct, and for that
+          // we need the add_mat_type. An alternative would be to insert a
+          // type_as conditional on the tensor shape being a scalar, but that
+          // might add overhead, and make analysis harder.
+          auto add_mat_type =
+              node->input(1 - mm_side)->type()->cast<TensorType>();
+          if (!add_mat_type)
+            continue;
 
-          // Add will accept tensors of mismatched scalar types, as long as one of them is a scalar.
-          // Addmm will throw in that case, so we can only perform this fusion if we're sure
-          // that it is correct, and for that we need the add_mat_type.
-          // An alternative would be to insert a type_as conditional on the tensor shape being a
-          // scalar, but that might add overhead, and make analysis harder.
-          auto add_mat_type = node->input(1 - mm_side)->type()->cast<TensorType>();
-          if (!add_mat_type) continue;
-
-          if (node->input(mm_side)->node()->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
+          if (node->input(mm_side)->node()->matches(
+                  "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
             WithInsertPoint guard(node);
 
             auto mm_node = node->input(mm_side)->node();
@@ -91,10 +105,13 @@ void PeepholeOptimizeImpl(Block * block, bool addmm_fusion_enabled) {
             if (!mat_type) {
               mat_type = mat2.value()->type()->cast<TensorType>();
             }
-            // We insert the type_as if we're sure that the added element is a scalar, and we
-            // either don't know what is the type of the multiplied matrices, or know the type,
-            // and know that it's mismatched.
-            if (add_mat_type->dim() == 0 && (!mat_type || add_mat_type->scalarType() != mat_type->scalarType())) {
+            // We insert the type_as if we're sure that the added element is a
+            // scalar, and we either don't know what is the type of the
+            // multiplied matrices, or know the type, and know that it's
+            // mismatched.
+            if (add_mat_type->dim() == 0 &&
+                (!mat_type ||
+                 add_mat_type->scalarType() != mat_type->scalarType())) {
               add_mat = add_mat.type_as(mat1);
             }
 
@@ -106,29 +123,45 @@ void PeepholeOptimizeImpl(Block * block, bool addmm_fusion_enabled) {
           }
         }
       }
-    // TODO: this doesn't work with Scalar-Tensor ops! We should canonicalize those
-    } else if (node->matches("aten::mul(Tensor self, Scalar other) -> Tensor", /*const_inputs=*/attr::other) ||
-               node->matches("aten::div(Tensor self, Scalar other) -> Tensor", /*const_inputs=*/attr::other)) {
+      // TODO: this doesn't work with Scalar-Tensor ops! We should canonicalize
+      // those
+    } else if (
+        node->matches(
+            "aten::mul(Tensor self, Scalar other) -> Tensor",
+            /*const_inputs=*/attr::other) ||
+        node->matches(
+            "aten::div(Tensor self, Scalar other) -> Tensor",
+            /*const_inputs=*/attr::other)) {
       // x * 1 == x / 1 == x
       if (node->get<at::Scalar>(attr::other)->toDouble() == 1) {
         node->output()->replaceAllUsesWith(node->input(0));
       }
-    } else if (node->matches("aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", /*const_inputs=*/{attr::alpha, attr::other}) ||
-               node->matches("aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor", /*const_inputs=*/{attr::alpha, attr::other})) {
+    } else if (
+        node->matches(
+            "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+            /*const_inputs=*/{attr::alpha, attr::other}) ||
+        node->matches(
+            "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
+            /*const_inputs=*/{attr::alpha, attr::other})) {
       // x + 0 == x - 0 == x
       if (node->get<at::Scalar>(attr::alpha)->toDouble() == 1 &&
           node->get<at::Scalar>(attr::other)->toDouble() == 0) {
         node->output()->replaceAllUsesWith(node->input(0));
       }
-    } else if (node->kind() == prim::Float || node->kind() == prim::Int || node->kind() == prim::ImplicitTensorToNum) {
+    } else if (
+        node->kind() == prim::Float || node->kind() == prim::Int ||
+        node->kind() == prim::ImplicitTensorToNum) {
       Node* input_node = node->input()->node();
       if (input_node->kind() == prim::NumToTensor) {
         node->output()->replaceAllUsesWith(input_node->input());
       }
-    } else if (node->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+    } else if (
+        node->matches(
+            "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
       auto uses = node->output()->uses();
       for (Use u : uses) {
-        if (u.user->matches("prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
+        if (u.user->matches(
+                "prim::SumToSize(Tensor(a) self, int[] size) -> Tensor(a)")) {
           u.user->replaceInput(0, node->inputs().at(0));
         }
       }
@@ -142,8 +175,11 @@ void PeepholeOptimize(Block* block, bool addmm_fusion_enabled) {
   EliminateDeadCode(block);
 }
 
-void PeepholeOptimize(const std::shared_ptr<Graph>& graph, bool addmm_fusion_enabled) {
+void PeepholeOptimize(
+    const std::shared_ptr<Graph>& graph,
+    bool addmm_fusion_enabled) {
   PeepholeOptimize(graph->block(), addmm_fusion_enabled);
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 2815e69..b63fe7d 100644 (file)
@@ -2,9 +2,15 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-TORCH_API void PeepholeOptimize(const std::shared_ptr<Graph>& graph, bool addmm_fusion_enabled=false);
-TORCH_API void PeepholeOptimize(Block* block, bool addmm_fusion_enabled=false);
+TORCH_API void PeepholeOptimize(
+    const std::shared_ptr<Graph>& graph,
+    bool addmm_fusion_enabled = false);
+TORCH_API void PeepholeOptimize(
+    Block* block,
+    bool addmm_fusion_enabled = false);
 
-}}
+} // namespace jit
+} // namespace torch
index 1f28cc0..d736587 100644 (file)
@@ -1,14 +1,13 @@
-#include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/attributes.h>
+#include <torch/csrc/jit/export.h>
 #include <torch/csrc/jit/generic_if.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/ir_views.h>
-#include <torch/csrc/jit/export.h>
+#include <torch/csrc/jit/passes/python_print.h>
 #include <torch/csrc/jit/resource_guard.h>
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/jit/script/module.h>
 
-
 namespace torch {
 namespace jit {
 
@@ -19,7 +18,7 @@ static bool isPrint(char s) {
 
 void printQuotedString(std::ostream& stmt, const std::string& str) {
   stmt << "\"";
-  for(auto s : str) {
+  for (auto s : str) {
     switch (s) {
       case '\\':
         stmt << "\\\\";
@@ -58,8 +57,10 @@ void printQuotedString(std::ostream& stmt, const std::string& str) {
           // C++ io has stateful formatting settings. Messing with
           // them is probably worse than doing this manually.
           char buf[4] = "000";
-          buf[2] += s % 8; s /= 8;
-          buf[1] += s % 8; s /= 8;
+          buf[2] += s % 8;
+          s /= 8;
+          buf[1] += s % 8;
+          s /= 8;
           buf[0] += s;
           stmt << "\\" << buf;
         }
@@ -70,10 +71,10 @@ void printQuotedString(std::ostream& stmt, const std::string& str) {
 }
 
 static bool isValidIdentifierChar(char c, size_t pos) {
-  return islower(c) || isupper(c) || c == '_' ||  (pos > 0 && isdigit(c));
+  return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
 }
 
-static bool isValidIdentifier(const std::string & name) {
+static bool isValidIdentifier(const std::string& name) {
   if (name.size() == 0)
     return false;
   for (size_t i = 0; i < name.size(); ++i) {
@@ -90,21 +91,24 @@ struct QualifiedName;
 using QualifiedNamePtr = c10::intrusive_ptr<QualifiedName>;
 struct QualifiedName : c10::intrusive_ptr_target {
   QualifiedName(QualifiedNamePtr prefix, std::string name)
-  : prefix_(std::move(prefix)), name_(std::move(name)) {}
+      : prefix_(std::move(prefix)), name_(std::move(name)) {}
   QualifiedNamePtr prefix_;
   std::string name_;
   static QualifiedNamePtr create(QualifiedNamePtr prefix, std::string name) {
-    return c10::make_intrusive<QualifiedName>(std::move(prefix), std::move(name));
+    return c10::make_intrusive<QualifiedName>(
+        std::move(prefix), std::move(name));
   }
   static QualifiedNamePtr create(std::string name) {
-    return c10::make_intrusive<QualifiedName>(QualifiedNamePtr(), std::move(name));
+    return c10::make_intrusive<QualifiedName>(
+        QualifiedNamePtr(), std::move(name));
   }
   std::string str() const {
     std::stringstream ss;
     emit(ss);
     return ss.str();
   }
-private:
+
+ private:
   void emit(std::ostream& out) const {
     if (isValidIdentifier(name_)) {
       if (prefix_) {
@@ -127,7 +131,6 @@ void createTensorToParameterNameMap(
     const script::Module& module,
     const QualifiedNamePtr& prefix,
     std::unordered_map<at::Tensor*, QualifiedNamePtr>& result) {
-
   for (const auto& elem : module.get_parameters()) {
     const script::NamedParameter& param = elem.value();
     result[param.slot()] = QualifiedName::create(prefix, param.name);
@@ -138,9 +141,9 @@ void createTensorToParameterNameMap(
   }
 }
 
-  // some names are valid identifiers but off limits because
-  // they are keywords or namespaces used in the output
-  const static std::unordered_set<std::string> reserved_names = {
+// some names are valid identifiers but off limits because
+// they are keywords or namespaces used in the output
+const static std::unordered_set<std::string> reserved_names = {
     // identifiers in the environment while parsing
     "_", // avoid the confusing unnamed _
     "aten",
@@ -188,7 +191,7 @@ void createTensorToParameterNameMap(
     "while",
     "with",
     "yield",
-  };
+};
 
 struct PythonPrintPass {
   std::ostream& out;
@@ -221,21 +224,23 @@ struct PythonPrintPass {
   // we only do this if
   // (1) it is a constant, or
   // (2) the temporary is unnamed, is single output, is used once,
-  //     and would appear in the same order when the expression tree is reparsed.
+  //     and would appear in the same order when the expression tree is
+  //     reparsed.
   // The last case can be checked
   // becuase when we emit a expresion tree in the parser,
-  // we do a left-to-right postorder traversal of the expression tree (emit children, then emit op).
-  // The reverse of this is a right-to-left preorder traversal of the tree.
-  // By doing a right-to-left preorder traversal of the inputs of a node,
-  // while also scanning the list of emitted nodes backward, we can see if
-  // they line up with what would happen when parsed the node as an expression. While they line
-  // up we collapse them into an inline expression.
+  // we do a left-to-right postorder traversal of the expression tree (emit
+  // children, then emit op). The reverse of this is a right-to-left preorder
+  // traversal of the tree. By doing a right-to-left preorder traversal of the
+  // inputs of a node, while also scanning the list of emitted nodes backward,
+  // we can see if they line up with what would happen when parsed the node as
+  // an expression. While they line up we collapse them into an inline
+  // expression.
 
-  // The inductive step is that the right-most input should be produced by the node
-  // immediatly before the current node if it is in tree order.
+  // The inductive step is that the right-most input should be produced by the
+  // node immediatly before the current node if it is in tree order.
 
   bool isConstantLike(Node* n) {
-    switch(n->kind()) {
+    switch (n->kind()) {
       case prim::Constant:
       case prim::Undefined:
       case prim::None:
@@ -247,7 +252,8 @@ struct PythonPrintPass {
 
   bool canInline(Value* v) {
     Node* n = v->node();
-    // there must be only 1 values, otherwise we need an assignment to handle the multiple outout values
+    // there must be only 1 values, otherwise we need an assignment to handle
+    // the multiple outout values
     if (n->outputs().size() != 1)
       return false;
     // if it is used more than once, then we need a variable
@@ -263,19 +269,23 @@ struct PythonPrintPass {
     if (n->blocks().size() != 0)
       return false;
     // if it is a loop-carried input, we need a variable
-    // otherwise the condition or trip count may be emitted in the wrong order w.r.t. to it
+    // otherwise the condition or trip count may be emitted in the wrong order
+    // w.r.t. to it
     if (use.user->kind() == prim::Loop && use.offset >= 2)
       return false;
     return true;
   }
 
-  // block_point is the current node in the reverse linear scan of the emitted nodes
-  // v is the current value in the tree traversal that may match with block_point's output.
+  // block_point is the current node in the reverse linear scan of the emitted
+  // nodes v is the current value in the tree traversal that may match with
+  // block_point's output.
   Node* scanValue(Node* block_point, Value* v) {
     Node* n = v->node();
     JIT_ASSERT(isConstantLike(n) || output_inline_.count(n) == 0);
 
-    if (n == block_point && canInline(v)) { // the node must be at the expected point of the typical tree traversal
+    if (n == block_point &&
+        canInline(v)) { // the node must be at the expected point of the typical
+                        // tree traversal
       // recursively see if we can inline the inputs to this input
       block_point = scanNode(block_point);
       output_inline_.insert(n);
@@ -289,21 +299,21 @@ struct PythonPrintPass {
   Node* previousNonConstant(Node* n) {
     do {
       n = n->prev();
-    } while(isConstantLike(n));
+    } while (isConstantLike(n));
     return n;
   }
 
   Node* scanNode(Node* n) {
     // don't bother to scan nodes we have already determined to be inline
-    if(output_inline_.count(n)) {
+    if (output_inline_.count(n)) {
       return n;
     }
-    for(auto b : n->blocks()) {
+    for (auto b : n->blocks()) {
       scanBlock(b);
     }
     Node* block_point = previousNonConstant(n);
-    for(auto it = n->inputs().rbegin(),
-             end = n->inputs().rend(); it != end; ++it) {
+    for (auto it = n->inputs().rbegin(), end = n->inputs().rend(); it != end;
+         ++it) {
       block_point = scanValue(block_point, *it);
     }
     return block_point;
@@ -311,7 +321,7 @@ struct PythonPrintPass {
 
   void scanBlock(Block* b) {
     scanNode(b->return_node());
-    for(auto node : b->nodes().reverse()) {
+    for (auto node : b->nodes().reverse()) {
       scanNode(node);
     }
   }
@@ -321,7 +331,7 @@ struct PythonPrintPass {
     // ConstantPool, which is also N^2 in the size of the constants,
     // because it doesn't hash any information about the tensors.
     // We will probably need to optimize this at some point using hashing.
-    for(size_t i = 0; i < tensor_table_.size(); ++i) {
+    for (size_t i = 0; i < tensor_table_.size(); ++i) {
       if (t.type() == tensor_table_[i].type() && t.equal(tensor_table_[i])) {
         return i;
       }
@@ -333,18 +343,19 @@ struct PythonPrintPass {
 
   std::unordered_set<Node*> seen_constants;
   void buildConstantList(Node* n, std::vector<Node*>& constants) {
-    for(auto input : n->inputs()) {
-      if (isConstantLike(input->node()) && seen_constants.count(input->node()) == 0) {
+    for (auto input : n->inputs()) {
+      if (isConstantLike(input->node()) &&
+          seen_constants.count(input->node()) == 0) {
         constants.push_back(input->node());
         seen_constants.insert(input->node());
       }
     }
-    for(auto b : n->blocks()) {
+    for (auto b : n->blocks()) {
       buildConstantList(b, constants);
     }
   }
   void buildConstantList(Block* b, std::vector<Node*>& constants) {
-    for(auto n : b->nodes())
+    for (auto n : b->nodes())
       buildConstantList(n, constants);
     buildConstantList(b->return_node(), constants);
   }
@@ -352,9 +363,11 @@ struct PythonPrintPass {
   // anything we have used.
   size_t next_id = 0;
 
-  std::string genNameImpl(const std::string& candidate, std::unordered_set<std::string>& used) {
+  std::string genNameImpl(
+      const std::string& candidate,
+      std::unordered_set<std::string>& used) {
     std::string name = candidate;
-    while(used.count(name) || reserved_names.count(name)) {
+    while (used.count(name) || reserved_names.count(name)) {
       name = candidate + std::to_string(next_id++);
     }
     used.insert(name);
@@ -377,7 +390,7 @@ struct PythonPrintPass {
     std::stringstream ss;
     if (candidate.size() == 0 || isdigit(candidate[0]))
       ss << "_";
-    for(char c : candidate) {
+    for (char c : candidate) {
       if (isupper(c) || islower(c) || isdigit(c) || c == '_')
         ss << c;
       else
@@ -405,7 +418,7 @@ struct PythonPrintPass {
     assignValue(v, useOf(w));
   }
   void assignValuesToTheirUniqueNames(at::ArrayRef<Value*> values) {
-    for(auto v : values) {
+    for (auto v : values) {
       assignValue(v, genUniqueNameFor(v));
     }
   }
@@ -421,16 +434,12 @@ struct PythonPrintPass {
 
   ResourceGuard WithIndented() {
     level++;
-    return ResourceGuard([this]{
-      level--;
-    });
+    return ResourceGuard([this] { level--; });
   }
 
   template <class T0, class T1, class F>
-  void zipWith(
-      at::ArrayRef<T0> list_a,
-      at::ArrayRef<T1> list_b,
-      F action) const {
+  void zipWith(at::ArrayRef<T0> list_a, at::ArrayRef<T1> list_b, F action)
+      const {
     auto it_a = list_a.begin();
     auto it_b = list_b.begin();
 
@@ -443,7 +452,11 @@ struct PythonPrintPass {
     }
   }
 
-  void printValueList(std::ostream& stmt, at::ArrayRef<Value*> list, const char* begin = "", const char* end = "") {
+  void printValueList(
+      std::ostream& stmt,
+      at::ArrayRef<Value*> list,
+      const char* begin = "",
+      const char* end = "") {
     stmt << begin;
     auto delimiter = "";
     for (auto* value : list) {
@@ -454,10 +467,8 @@ struct PythonPrintPass {
     stmt << end;
   }
 
-  void printAssignment(
-      at::ArrayRef<Value*> lhs,
-      at::ArrayRef<Value*> rhs) {
-    if(lhs.size() > 0) {
+  void printAssignment(at::ArrayRef<Value*> lhs, at::ArrayRef<Value*> rhs) {
+    if (lhs.size() > 0) {
       indent();
       printValueList(out, lhs);
       out << " = ";
@@ -483,41 +494,43 @@ struct PythonPrintPass {
     }
   }
 
-  // our way of encoding loops makes them difficult to turn back into python syntax.
-  // we have to check properties of the condition and trip count inputs to
-  // figure out which one it initially was
+  // our way of encoding loops makes them difficult to turn back into python
+  // syntax. we have to check properties of the condition and trip count inputs
+  // to figure out which one it initially was
   static bool shouldEmitAsForLoop(LoopView stmt) {
-      auto trip_count = toIValue(stmt.maxTripCount());
-      auto cond_input = toIValue(stmt.inputCond());
-      auto cond_next = toIValue(stmt.nextCond());
-
-      bool condition_is_always_true = cond_input && cond_input->toBool() && cond_next &&
-        cond_next->toBool();
-      bool trip_count_is_specified = !trip_count || // trip is not a constant
-          trip_count->toInt() != std::numeric_limits<int64_t>::max() || // it is a constant but not the default one
-          stmt.currentTripCount()->uses().size() > 0; // it is actually being used in the body.
-
-      if (condition_is_always_true) {
-        // if the trip count was not specified this was a user-written while True:
-        return trip_count_is_specified;
-      } else {
-        // this must be a while loop, but check that there isn't _also_ a trip count
-        if (trip_count_is_specified) {
-          throw script::ErrorReport(stmt.node()->getSourceLocation())
-              << "loop cannot be printed as python because it has gone through an optimization "
-              << "that combined while and for loops. File a bug.";
-        }
-        return false;
+    auto trip_count = toIValue(stmt.maxTripCount());
+    auto cond_input = toIValue(stmt.inputCond());
+    auto cond_next = toIValue(stmt.nextCond());
+
+    bool condition_is_always_true =
+        cond_input && cond_input->toBool() && cond_next && cond_next->toBool();
+    bool trip_count_is_specified = !trip_count || // trip is not a constant
+        trip_count->toInt() !=
+            std::numeric_limits<int64_t>::max() || // it is a constant but not
+                                                   // the default one
+        stmt.currentTripCount()->uses().size() >
+            0; // it is actually being used in the body.
+
+    if (condition_is_always_true) {
+      // if the trip count was not specified this was a user-written while True:
+      return trip_count_is_specified;
+    } else {
+      // this must be a while loop, but check that there isn't _also_ a trip
+      // count
+      if (trip_count_is_specified) {
+        throw script::ErrorReport(stmt.node()->getSourceLocation())
+            << "loop cannot be printed as python because it has gone through an optimization "
+            << "that combined while and for loops. File a bug.";
       }
+      return false;
+    }
   }
 
   void printLoop(LoopView stmt) {
-
     // Loop carried dependencies are handled by assigning their initial
     // values to the node->outputs() before the loop,
     // and assign node->outputs() to the new values at the end of each trip.
 
-
     bool emit_as_for_loop = shouldEmitAsForLoop(stmt);
 
     assignValuesToTheirUniqueNames(stmt.carriedOutputs());
@@ -553,9 +566,11 @@ struct PythonPrintPass {
       // the condition is always True
       size_t offset = emit_as_for_loop ? 1 : 0;
       auto body_block = stmt.bodyBlock();
-      ArrayRef<Value*> loop_carried_block_inputs = body_block->inputs().slice(offset);
+      ArrayRef<Value*> loop_carried_block_inputs =
+          body_block->inputs().slice(offset);
       printBlock(body_block, loop_carried_block_inputs.size() > 0);
-      printAssignment(loop_carried_block_inputs, body_block->outputs().slice(offset));
+      printAssignment(
+          loop_carried_block_inputs, body_block->outputs().slice(offset));
     }
   }
 
@@ -602,7 +617,8 @@ struct PythonPrintPass {
         // this node is safe to inline, so assign the output value
         // to that expression directly
         // guard against really long lines
-        if (output_inline_.count(node) > 0 && ss.str().size() + level * 2 < 40) {
+        if (output_inline_.count(node) > 0 &&
+            ss.str().size() + level * 2 < 40) {
           assignValue(node->output(), ss.str());
           return;
         }
@@ -622,7 +638,7 @@ struct PythonPrintPass {
       const char* the_type,
       size_t list_size,
       const IValue& the_list) {
-    if(list_size == 0) {
+    if (list_size == 0) {
       stmt << "annotate(List[" << the_type << "], [])";
     } else {
       stmt << the_list;
@@ -630,30 +646,32 @@ struct PythonPrintPass {
   }
 
   void printConstant(std::ostream& stmt, const IValue& v) {
-    if(v.isTensor()) {
+    if (v.isTensor()) {
       stmt << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor());
-    } else if(v.isString()) {
+    } else if (v.isString()) {
       printQuotedString(stmt, v.toStringRef());
-    } else if(v.isDevice()) {
+    } else if (v.isDevice()) {
       std::stringstream ss;
       ss << v.toDevice();
       stmt << "torch.device(";
       printQuotedString(stmt, ss.str());
       stmt << ")";
-    } else if(v.isTensorList()) {
+    } else if (v.isTensorList()) {
       stmt << "[";
       const char* delim = "";
-      for(const auto& t : v.toTensorListRef()) {
+      for (const auto& t : v.toTensorListRef()) {
         stmt << delim << "CONSTANTS.c" << getOrAddTensorConstant(t);
         delim = ", ";
       }
       stmt << "]";
-    } else if(v.isBoolList()) {
-      printMaybeAnnotatedConstantList(stmt, "bool", v.toBoolListRef().size(), v);
-    } else if(v.isIntList()) {
+    } else if (v.isBoolList()) {
+      printMaybeAnnotatedConstantList(
+          stmt, "bool", v.toBoolListRef().size(), v);
+    } else if (v.isIntList()) {
       printMaybeAnnotatedConstantList(stmt, "int", v.toIntListRef().size(), v);
-    } else if(v.isDoubleList()) {
-      printMaybeAnnotatedConstantList(stmt, "float", v.toDoubleListRef().size(), v);
+    } else if (v.isDoubleList()) {
+      printMaybeAnnotatedConstantList(
+          stmt, "float", v.toDoubleListRef().size(), v);
     } else {
       stmt << v;
     }
@@ -661,7 +679,7 @@ struct PythonPrintPass {
 
   // Prints the RHS value of a Node, e.g. `aten.add(x, y)`
   void printRHS(std::ostream& stmt, Node* node) {
-    switch(node->kind()) {
+    switch (node->kind()) {
       case PythonOp::Kind: {
         auto value = static_cast<const PythonOp*>(node);
         if (enforce_importable_) {
@@ -692,10 +710,10 @@ struct PythonPrintPass {
 
         // XXX - when None has an Optional[T] type, we must ensure that type
         // can be recovered on parsing. It cannot be recovered if it will be
-        // matched to schema with free variables. If it is used only in places where
-        // there is schema and the scheme has no free variables, then we can
-        // recover it without annotation. Otherwise, we annotate None with the right
-        // optional type
+        // matched to schema with free variables. If it is used only in places
+        // where there is schema and the scheme has no free variables, then we
+        // can recover it without annotation. Otherwise, we annotate None with
+        // the right optional type
         const auto& uses = node->output()->uses();
         bool all_usable_schema =
             std::all_of(uses.begin(), uses.end(), [](const Use& u) {
@@ -704,9 +722,9 @@ struct PythonPrintPass {
                   return false;
                 }
                 return !schema->arguments()
-                    .at(u.offset)
-                    .type()
-                    ->hasFreeVariables();
+                            .at(u.offset)
+                            .type()
+                            ->hasFreeVariables();
               }
               return false;
             });
@@ -714,7 +732,8 @@ struct PythonPrintPass {
         if (all_usable_schema) {
           stmt << "None";
         } else {
-          stmt << "annotate(" << node->output()->type()->python_str() << ", None)";
+          stmt << "annotate(" << node->output()->type()->python_str()
+               << ", None)";
         }
       } break;
       case prim::ImplicitTensorToNum: {
@@ -731,14 +750,15 @@ struct PythonPrintPass {
         printValueList(stmt, node->inputs(), "bool(", ")");
       } break;
       case prim::Print: {
-        printValueList(stmt, node->inputs(), "print(",")");
+        printValueList(stmt, node->inputs(), "print(", ")");
       } break;
       case prim::TupleConstruct: {
         printValueList(
             stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")");
       } break;
       case prim::TupleIndex: {
-        stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::index) << "]";
+        stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::index)
+             << "]";
       } break;
       case prim::TupleSlice: {
         stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::beg) << ":"
@@ -750,7 +770,8 @@ struct PythonPrintPass {
         // to infer the type on import
         if (node->inputs().size() == 0 &&
             !node->output()->type()->isSubtypeOf(DynamicType::get())) {
-          stmt << "annotate(" << node->output()->type()->python_str() << ", [])";
+          stmt << "annotate(" << node->output()->type()->python_str()
+               << ", [])";
         } else {
           printValueList(stmt, node->inputs(), "[", "]");
         }
@@ -759,25 +780,24 @@ struct PythonPrintPass {
         // the subgraph gets emitted as another function
         auto name = genMethodName("__forked_function");
         std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
-        worklist.emplace_back([graph, name, this] {
-          printFunctionDefinition(*graph, name);
-        });
+        worklist.emplace_back(
+            [graph, name, this] { printFunctionDefinition(*graph, name); });
         // and we put a call to fork which invokes that function.
         stmt << "fork(self." << name;
-        for(Value* v : node->inputs()) {
+        for (Value* v : node->inputs()) {
           stmt << ", " << useOf(v);
         }
         stmt << ")";
       } break;
       case prim::Function: {
         if (enforce_importable_) {
-          throw script::ErrorReport(node->getSourceLocation()) << "closures are not exportable";
+          throw script::ErrorReport(node->getSourceLocation())
+              << "closures are not exportable";
         }
         auto name = genMethodName("__lambda");
         std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
-        worklist.emplace_back([graph, name, this] {
-          printFunctionDefinition(*graph, name);
-        });
+        worklist.emplace_back(
+            [graph, name, this] { printFunctionDefinition(*graph, name); });
         stmt << "self." << name;
       } break;
       default: {
@@ -788,25 +808,26 @@ struct PythonPrintPass {
           // doing it here ensures we do not have fix up archives later
           stmt << "torch." << kind.toUnqualString() << "(";
         } else {
-          stmt << "ops." << kind.ns().toUnqualString() << "." << kind.toUnqualString() << "(";
+          stmt << "ops." << kind.ns().toUnqualString() << "."
+               << kind.toUnqualString() << "(";
         }
         const FunctionSchema& schema = node->schema();
         for (size_t i = 0; i < node->inputs().size(); ++i) {
-            if (i > 0) {
-              stmt << ", ";
+          if (i > 0) {
+            stmt << ", ";
+          }
+          auto v = useOf(node->inputs().at(i));
+          // print the kwarg name if it is a kwarg only argument.
+          if (i < schema.arguments().size()) {
+            auto arg = schema.arguments().at(i);
+            if (arg.kwarg_only()) {
+              stmt << arg.name() << "=";
             }
-            auto v = useOf(node->inputs().at(i));
-            // print the kwarg name if it is a kwarg only argument.
-            if (i < schema.arguments().size()) {
-              auto arg = schema.arguments().at(i);
-              if (arg.kwarg_only()) {
-                stmt << arg.name() << "=";
-              }
-            } else {
-              // vararg functions like format can have extra arguments
-              JIT_ASSERT(schema.is_vararg());
-            }
-            stmt << v;
+          } else {
+            // vararg functions like format can have extra arguments
+            JIT_ASSERT(schema.is_vararg());
+          }
+          stmt << v;
         }
         stmt << ")";
       } break;
@@ -814,9 +835,9 @@ struct PythonPrintPass {
   }
 
   std::ostream& printBlock(Block* root, bool block_has_other_statements) {
-    // pythons weird 'pass' syntax creates a bunch of places where we have to check
-    // if this block would be empty. But not everything in a block is a node.
-    // Sometimes if, loop, and return statements will follow this block
+    // pythons weird 'pass' syntax creates a bunch of places where we have to
+    // check if this block would be empty. But not everything in a block is a
+    // node. Sometimes if, loop, and return statements will follow this block
     // and block_has_other_statements == true.
     if (!block_has_other_statements &&
         root->nodes().begin() == root->nodes().end()) {
@@ -829,22 +850,26 @@ struct PythonPrintPass {
     return out;
   }
 
-  void printDefaultValue(const TypePtr& typ, std::ostream& stmt, const IValue& value) {
-    // xxx - many weak script modules store default values for broadcasting lists
-    // that are not actually the same type as the argument. We can only serialize
-    // default values that will implicitly convert to their declared return type
-    // since we do not need to serialize these built-in modules with their defaults,
-    // we just drop them for now.
+  void printDefaultValue(
+      const TypePtr& typ,
+      std::ostream& stmt,
+      const IValue& value) {
+    // xxx - many weak script modules store default values for broadcasting
+    // lists that are not actually the same type as the argument. We can only
+    // serialize default values that will implicitly convert to their declared
+    // return type since we do not need to serialize these built-in modules with
+    // their defaults, we just drop them for now.
     if (typ->kind() == ListType::Kind &&
         (value.isInt() || value.isDouble() || value.isBool())) {
       return;
     }
     stmt << "=";
     if (value.isTensor() && !value.toTensor().defined()) {
-      // XXX - because undefined tensors are not stored as None, we need special handling.
-      // otherwise they get printed as CONSTANTS.c0 and then cannot be recreated because
-      // constant nodes cannot have an undefined value in them.
-      // The right solution is to make None of type Tensor actually be an IValue None.
+      // XXX - because undefined tensors are not stored as None, we need special
+      // handling. otherwise they get printed as CONSTANTS.c0 and then cannot be
+      // recreated because constant nodes cannot have an undefined value in
+      // them. The right solution is to make None of type Tensor actually be an
+      // IValue None.
       stmt << "None";
       return;
     }
@@ -855,7 +880,6 @@ struct PythonPrintPass {
       const std::string& name,
       const std::vector<c10::optional<IValue>>& defaults = {},
       const std::vector<std::string>& param_names = {}) {
-
     used_names_.clear(); // each graph can reuse local names
 
     // we always print constants at the top of the function, in the order
@@ -869,9 +893,10 @@ struct PythonPrintPass {
     // last param_names.size() arguments to the graph are parameters and not
     // actual inputs, we will print these as, e.g. self.foo.bar
     // while we print the true_inputs out as parameters
-    auto true_inputs = graph.inputs().slice(0, graph.inputs().size() - param_names.size());
+    auto true_inputs =
+        graph.inputs().slice(0, graph.inputs().size() - param_names.size());
     auto param_names_it = param_names.begin();
-    for(auto param : graph.inputs().slice(true_inputs.size())) {
+    for (auto param : graph.inputs().slice(true_inputs.size())) {
       assignValue(param, *param_names_it++);
     }
     assignValuesToTheirUniqueNames(true_inputs);
@@ -910,7 +935,9 @@ struct PythonPrintPass {
       std::ostream& out_,
       std::vector<at::Tensor>& tensor_table,
       bool enforce_importable)
-      : out(out_), tensor_table_(tensor_table), enforce_importable_(enforce_importable) {}
+      : out(out_),
+        tensor_table_(tensor_table),
+        enforce_importable_(enforce_importable) {}
 
   // TODO: we should consider forcing functions to return a single value
   // instead of handling this tuple logic both in the compiler and the printer
@@ -929,7 +956,7 @@ struct PythonPrintPass {
       const std::vector<c10::optional<IValue>>& defaults = {},
       const std::vector<std::string>& param_names = {}) {
     printFunctionDefinition(graph, name, defaults, param_names);
-    while(!worklist.empty()) {
+    while (!worklist.empty()) {
       out << "\n\n";
       auto work = worklist.back();
       worklist.pop_back();
@@ -937,8 +964,10 @@ struct PythonPrintPass {
     }
   }
   void printMethod(script::Method& method) {
-    std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;;
-    createTensorToParameterNameMap(method.owner(), QualifiedName::create("self"),  parameter_names);
+    std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
+    ;
+    createTensorToParameterNameMap(
+        method.owner(), QualifiedName::create("self"), parameter_names);
     printMethod(method, parameter_names);
   }
   void printMethod(
@@ -950,18 +979,21 @@ struct PythonPrintPass {
         [&](at::Tensor* slot) { return parameter_names.at(slot)->str(); });
     const std::string& name = method.name();
     Graph& graph = *method.graph();
-    auto defaults = fmap(method.getSchema().arguments(), [](const Argument& arg) {
-      return arg.default_value();
-    });
+    auto defaults = fmap(
+        method.getSchema().arguments(),
+        [](const Argument& arg) { return arg.default_value(); });
     printFunction(graph, name, defaults, param_names);
   }
   void printModule(script::Module& module) {
-    std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;;
-    createTensorToParameterNameMap(module, QualifiedName::create("self"),  parameter_names);
-    for(auto& method : module.get_methods()) {
+    std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
+    ;
+    createTensorToParameterNameMap(
+        module, QualifiedName::create("self"), parameter_names);
+    for (auto& method : module.get_methods()) {
       const std::string& name = method.value()->name();
       // we skip __forked_functions because they actually get inlined into their
-      // callers, exporting them again will lead to more code generated on each export
+      // callers, exporting them again will lead to more code generated on each
+      // export
       if (name.find("__forked_function") == 0) {
         continue;
       }
@@ -970,19 +1002,31 @@ struct PythonPrintPass {
   }
 };
 
-TORCH_API void PythonPrint(std::ostream& out, const Graph& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+TORCH_API void PythonPrint(
+    std::ostream& out,
+    const Graph& graph,
+    std::vector<at::Tensor>& tensor_table,
+    bool enforce_importable) {
   PythonPrintPass pp(out, tensor_table, enforce_importable);
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   pp.printFunction(const_cast<Graph&>(graph), "graph");
 }
 
-TORCH_API void PythonPrint(std::ostream& out, const script::Method& method, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+TORCH_API void PythonPrint(
+    std::ostream& out,
+    const script::Method& method,
+    std::vector<at::Tensor>& tensor_table,
+    bool enforce_importable) {
   PythonPrintPass pp(out, tensor_table, enforce_importable);
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   pp.printMethod(const_cast<script::Method&>(method));
 }
 
-TORCH_API void PythonPrint(std::ostream& out, const script::Module& module, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+TORCH_API void PythonPrint(
+    std::ostream& out,
+    const script::Module& module,
+    std::vector<at::Tensor>& tensor_table,
+    bool enforce_importable) {
   PythonPrintPass pp(out, tensor_table, enforce_importable);
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   pp.printModule(const_cast<script::Module&>(module));
@@ -997,18 +1041,18 @@ TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
   // schema to editing this list here. These cases should only be things
   // that require special handling because they do not fit normal schema
   const static std::unordered_set<Symbol> handled = {
-    prim::Constant,
-    prim::fork,
-    prim::ListConstruct,
-    prim::ListUnpack,
-    prim::None,
-    prim::Print,
-    prim::PythonOp,
-    prim::TupleConstruct,
-    prim::TupleIndex,
-    prim::TupleSlice,
-    prim::TupleUnpack,
-    prim::Undefined,
+      prim::Constant,
+      prim::fork,
+      prim::ListConstruct,
+      prim::ListUnpack,
+      prim::None,
+      prim::Print,
+      prim::PythonOp,
+      prim::TupleConstruct,
+      prim::TupleIndex,
+      prim::TupleSlice,
+      prim::TupleUnpack,
+      prim::Undefined,
   };
 
   // WARNING: by adding a value to this set, you are asserting that your
@@ -1016,21 +1060,21 @@ TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
   // to be correctly printed for export (a process that happens before
   // optimization passes run)
   const static std::unordered_set<Symbol> unneeded = {
-    onnx::Reshape, // only used in onnx
-    onnx::Shape, // only used in onnx
-    prim::AnyDefined, // temporarily inserted by autograd
-    prim::AutogradAdd, // temporarily inserted by autograd
-    prim::ConstantChunk, // optimization pass adds it
-    prim::DifferentiableGraph, // optimization pass adds it
-    prim::BroadcastSizes, // optimization pass (fuser) adds it
-    prim::ChunkSizes, // optimization pass (fuser) adds it
-    prim::Drop, // used in interpreter only
-    prim::FusedConcat, // optimization pass adds it
-    prim::FusionGroup, // optimization pass adds it
-    prim::Load, // used in interpreter only
-    prim::MMTreeReduce, // used as an optimization
-    prim::MMBatchSide, // used as an optimization
-    prim::Store, // used in interpreter only
+      onnx::Reshape, // only used in onnx
+      onnx::Shape, // only used in onnx
+      prim::AnyDefined, // temporarily inserted by autograd
+      prim::AutogradAdd, // temporarily inserted by autograd
+      prim::ConstantChunk, // optimization pass adds it
+      prim::DifferentiableGraph, // optimization pass adds it
+      prim::BroadcastSizes, // optimization pass (fuser) adds it
+      prim::ChunkSizes, // optimization pass (fuser) adds it
+      prim::Drop, // used in interpreter only
+      prim::FusedConcat, // optimization pass adds it
+      prim::FusionGroup, // optimization pass adds it
+      prim::Load, // used in interpreter only
+      prim::MMTreeReduce, // used as an optimization
+      prim::MMBatchSide, // used as an optimization
+      prim::Store, // used in interpreter only
 
   };
 
index 5906eeb..e9dc357 100644 (file)
@@ -4,17 +4,30 @@
 #include <iostream>
 #include <vector>
 
-
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace script {
-  struct Method;
-  struct Module;
-}
+struct Method;
+struct Module;
+} // namespace script
 
-TORCH_API void PythonPrint(std::ostream& out, const Graph& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
-TORCH_API void PythonPrint(std::ostream& out, const script::Method& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
-TORCH_API void PythonPrint(std::ostream& out, const script::Module& module, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
+TORCH_API void PythonPrint(
+    std::ostream& out,
+    const Graph& graph,
+    std::vector<at::Tensor>& tensor_table,
+    bool enforce_importable = false);
+TORCH_API void PythonPrint(
+    std::ostream& out,
+    const script::Method& graph,
+    std::vector<at::Tensor>& tensor_table,
+    bool enforce_importable = false);
+TORCH_API void PythonPrint(
+    std::ostream& out,
+    const script::Module& module,
+    std::vector<at::Tensor>& tensor_table,
+    bool enforce_importable = false);
 
 TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
-}}
+} // namespace jit
+} // namespace torch
index 892e362..a505574 100644 (file)
@@ -1,7 +1,8 @@
-#include <torch/csrc/jit/passes/remove_expands.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/remove_expands.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 static void RemoveExpands(Block* block) {
   for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
@@ -20,5 +21,5 @@ void RemoveExpands(const std::shared_ptr<Graph>& graph) {
   RemoveExpands(graph->block());
 }
 
-
-}}
+} // namespace jit
+} // namespace torch
index 465d452..cef18ba 100644 (file)
@@ -2,8 +2,10 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 TORCH_API void RemoveExpands(const std::shared_ptr<Graph>& graph);
 
-}}
+}
+} // namespace torch
index 0a17c63..cd23885 100644 (file)
@@ -1,5 +1,5 @@
-#include <torch/csrc/jit/passes/remove_inplace_ops.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/remove_inplace_ops.h>
 
 namespace torch {
 namespace jit {
@@ -50,7 +50,7 @@ void RemoveInplaceOps(Block* block) {
     }
   }
 }
-}
+} // namespace
 
 void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
   RemoveInplaceOps(graph->block());
index dd57a87..36d00ca 100644 (file)
@@ -1,32 +1,35 @@
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/type.h>
 #include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/type.h>
 
 #include <vector>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
-bool getRequiresGrad(Value * value) {
+bool getRequiresGrad(Value* value) {
   return value->requires_grad();
 }
 
-void setRequiresGrad(Value * value, bool req_value) {
+void setRequiresGrad(Value* value, bool req_value) {
   if (auto type = value->type()->cast<TensorType>()) {
     value->setType(type->withRequiresGrad(req_value));
   }
 }
 
-void setRequiresGrad(at::ArrayRef<Value*> outputs, const std::vector<bool>& values) {
+void setRequiresGrad(
+    at::ArrayRef<Value*> outputs,
+    const std::vector<bool>& values) {
   JIT_ASSERT(outputs.size() == values.size());
   for (size_t i = 0; i < values.size(); ++i) {
     setRequiresGrad(outputs[i], values[i]);
   }
 }
 
-void setRequiresGrad(Node * node, const std::vector<bool>& values) {
+void setRequiresGrad(Node* node, const std::vector<bool>& values) {
   setRequiresGrad(node->outputs(), values);
 }
 
@@ -38,26 +41,26 @@ std::vector<bool> bitwiseOr(std::vector<bool> a, const std::vector<bool>& b) {
   return a;
 }
 
-
 void PropagateRequiresGradSimpleNode(Node* node) {
   static const OperatorSet comparison_ops = {
-    "aten::lt(Tensor self, Tensor other) -> Tensor",
-    "aten::le(Tensor self, Tensor other) -> Tensor",
-    "aten::gt(Tensor self, Tensor other) -> Tensor",
-    "aten::ge(Tensor self, Tensor other) -> Tensor",
-    "aten::eq(Tensor self, Tensor other) -> Tensor",
-    "aten::ne(Tensor self, Tensor other) -> Tensor",
-    "aten::lt(Tensor self, Scalar other) -> Tensor",
-    "aten::le(Tensor self, Scalar other) -> Tensor",
-    "aten::gt(Tensor self, Scalar other) -> Tensor",
-    "aten::ge(Tensor self, Scalar other) -> Tensor",
-    "aten::eq(Tensor self, Scalar other) -> Tensor",
-    "aten::ne(Tensor self, Scalar other) -> Tensor",
+      "aten::lt(Tensor self, Tensor other) -> Tensor",
+      "aten::le(Tensor self, Tensor other) -> Tensor",
+      "aten::gt(Tensor self, Tensor other) -> Tensor",
+      "aten::ge(Tensor self, Tensor other) -> Tensor",
+      "aten::eq(Tensor self, Tensor other) -> Tensor",
+      "aten::ne(Tensor self, Tensor other) -> Tensor",
+      "aten::lt(Tensor self, Scalar other) -> Tensor",
+      "aten::le(Tensor self, Scalar other) -> Tensor",
+      "aten::gt(Tensor self, Scalar other) -> Tensor",
+      "aten::ge(Tensor self, Scalar other) -> Tensor",
+      "aten::eq(Tensor self, Scalar other) -> Tensor",
+      "aten::ne(Tensor self, Scalar other) -> Tensor",
   };
 
   if (comparison_ops.find(node)) {
     return setRequiresGrad(node->output(), false);
-  } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
+  } else if (node->matches(
+                 "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
     return setRequiresGrad(node->output(), node->input(0)->requires_grad());
   } else if (node->matches("aten::detach(Tensor self) -> Tensor")) {
     return setRequiresGrad(node->output(), false);
@@ -65,17 +68,19 @@ void PropagateRequiresGradSimpleNode(Node* node) {
 
   auto inputs = node->inputs();
   auto outputs = node->outputs();
-  bool should_require = std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
+  bool should_require =
+      std::any_of(inputs.begin(), inputs.end(), getRequiresGrad);
   for (Value* output : outputs) {
     if (auto type = output->type()->cast<TensorType>()) {
-      setRequiresGrad(output, should_require && at::isFloatingType(type->scalarType()));
+      setRequiresGrad(
+          output, should_require && at::isFloatingType(type->scalarType()));
     }
   }
 }
 
-void PropagateRequiresGrad(Block * block);
+void PropagateRequiresGrad(Block* block);
 
-void PropagateRequiresGrad(Node * node) {
+void PropagateRequiresGrad(Node* node) {
   if (node->kind() == prim::If) {
     auto blocks = node->blocks();
     auto true_block = blocks.at(0);
@@ -84,20 +89,24 @@ void PropagateRequiresGrad(Node * node) {
     PropagateRequiresGrad(true_block);
     PropagateRequiresGrad(false_block);
 
-    auto outputs_require =
-      bitwiseOr(fmap(true_block->outputs(), getRequiresGrad),
-                fmap(false_block->outputs(), getRequiresGrad));
+    auto outputs_require = bitwiseOr(
+        fmap(true_block->outputs(), getRequiresGrad),
+        fmap(false_block->outputs(), getRequiresGrad));
     setRequiresGrad(node, outputs_require);
   } else if (node->kind() == prim::Loop) {
     auto body = node->blocks().at(0);
-    std::vector<bool> body_inputs_require = fmap(node->inputs().slice(2), getRequiresGrad);
-    std::vector<bool> body_outputs_require (node->outputs().size(), false);
+    std::vector<bool> body_inputs_require =
+        fmap(node->inputs().slice(2), getRequiresGrad);
+    std::vector<bool> body_outputs_require(node->outputs().size(), false);
 
     while (body_inputs_require != body_outputs_require) {
-      body_inputs_require = bitwiseOr(body_inputs_require, body_outputs_require);
-      setRequiresGrad(body->param_node()->outputs().slice(1), body_inputs_require);
+      body_inputs_require =
+          bitwiseOr(body_inputs_require, body_outputs_require);
+      setRequiresGrad(
+          body->param_node()->outputs().slice(1), body_inputs_require);
       PropagateRequiresGrad(body);
-      body_outputs_require = fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
+      body_outputs_require =
+          fmap(body->return_node()->inputs().slice(1), getRequiresGrad);
     }
 
     setRequiresGrad(node, body_outputs_require);
@@ -106,8 +115,8 @@ void PropagateRequiresGrad(Node * node) {
   }
 }
 
-void PropagateRequiresGrad(Block * block) {
-  for (Node * node : block->nodes()) {
+void PropagateRequiresGrad(Block* block) {
+  for (Node* node : block->nodes()) {
     PropagateRequiresGrad(node);
   }
 }
@@ -118,4 +127,5 @@ void PropagateRequiresGrad(std::shared_ptr<Graph>& graph) {
   PropagateRequiresGrad(graph->block());
 }
 
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index 5300c26..cbe09ff 100644 (file)
@@ -4,12 +4,13 @@
 
 #include <memory>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 struct Graph;
 struct ArgumentSpec;
 
 TORCH_API void PropagateRequiresGrad(std::shared_ptr<Graph>& graph);
 
-}}
-
+} // namespace jit
+} // namespace torch
index e4f321d..69ed1ef 100644 (file)
@@ -1,10 +1,10 @@
 #include <torch/csrc/jit/passes/shape_analysis.h>
 
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/argument_spec.h>
-#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 
 #include <torch/csrc/autograd/variable.h>
 #include <utility>
 #include <vector>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 struct propagation_error : std::exception {};
 
-#define SHAPE_ASSERT(cond) if (!(cond)) throw propagation_error()
+#define SHAPE_ASSERT(cond) \
+  if (!(cond))             \
+  throw propagation_error()
 
 namespace {
 
@@ -61,6 +64,7 @@ class ShapePropagator {
       }
     }
   }
+
  private:
   const AliasDb aliasDb_;
 
@@ -191,7 +195,6 @@ class ShapePropagator {
       "aten::inverse(Tensor self) -> Tensor",
   };
 
-
   // Check if this node depends on a value that has been mutated previously. If
   // it has, then it's not safe to run this node in isolation, since we don't
   // know whether the dependency has been executed.
@@ -328,11 +331,12 @@ class ShapePropagator {
       }
       return false;
     };
-    auto list_node = ((cat_node->kind() == prim::FusedConcat)
-                     ? cat_node
-                     : cat_node->namedInput(attr::tensors)->node());
-    if (list_node->kind() == prim::ListConstruct
-       || cat_node->kind() == prim::FusedConcat) {
+    auto list_node =
+        ((cat_node->kind() == prim::FusedConcat)
+             ? cat_node
+             : cat_node->namedInput(attr::tensors)->node());
+    if (list_node->kind() == prim::ListConstruct ||
+        cat_node->kind() == prim::FusedConcat) {
       auto tensors = list_node->inputs();
       if (!tensors.empty()) {
         if (propagate_complete(cat_node, tensors)) {
@@ -392,7 +396,8 @@ class ShapePropagator {
         return; // correct num type is already set
       case prim::NumToTensor: {
         TypePtr typ = node->input()->type();
-        if (typ->isSubtypeOf(IntType::get()) || typ->isSubtypeOf(BoolType::get())) {
+        if (typ->isSubtypeOf(IntType::get()) ||
+            typ->isSubtypeOf(BoolType::get())) {
           node->output()->setType(TensorType::create(at::kLong, at::kCPU, 0));
         } else if (node->input()->type()->isSubtypeOf(FloatType::get())) {
           node->output()->setType(TensorType::create(at::kDouble, at::kCPU, 0));
@@ -446,8 +451,8 @@ class ShapePropagator {
       return;
     }
 
-    if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor")
-       || node->kind() == prim::FusedConcat) {
+    if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") ||
+        node->kind() == prim::FusedConcat) {
       return PropagateCatShape(node);
     }
 
@@ -495,8 +500,8 @@ class ShapePropagator {
   // primitive/tensor outputs.
 
   bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
-    static const auto broadcast =
-        [](std::vector<TensorTypePtr>& tensor_types, size_t arg_for_type) -> TensorTypePtr {
+    static const auto broadcast = [](std::vector<TensorTypePtr>& tensor_types,
+                                     size_t arg_for_type) -> TensorTypePtr {
       if (tensor_types.size() == 1) {
         return tensor_types[0];
       }
@@ -693,12 +698,12 @@ class ShapePropagator {
           return {};
         }};
 
-    // aten::where is special in that its return type is the second argument's (self)
-    // type rather than the that of condition
+    // aten::where is special in that its return type is the second argument's
+    // (self) type rather than the that of condition
     static const register_formula_for where_op{
         {
             "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
-       },
+        },
         [this](Node* node) -> type_vec_t {
           if (auto maybe_tensor_types = gatherTensorTypes<TensorType>(node)) {
             return {broadcast(*maybe_tensor_types, 1)};
@@ -764,8 +769,8 @@ class ShapePropagator {
 
     // Requirements:
     //   dims           : preserved from the first argument
-    //   scalar type    : preserved from the first argument (doesn't have to match other arguments)
-    //   device         : always matching and preserved
+    //   scalar type    : preserved from the first argument (doesn't have to
+    //   match other arguments) device         : always matching and preserved
     //   tensor inputs  : *
     //   tensor outputs : 1
     // NB: those ops (with slight adjustments) are good candidates for restarts.
@@ -946,24 +951,22 @@ class ShapePropagator {
               node, /*num_reduce_dim=*/1, /*integer_upcast=*/true);
         }};
 
-
     // Requirements:
-    //   dims           : preserved if keepdim == false, dim->size() smaller otherwise
-    //   scalar type    : preserved
-    //   device         : preserved
-    //   tensor inputs  : 1
-    //   tensor outputs : 1
+    //   dims           : preserved if keepdim == false, dim->size() smaller
+    //   otherwise scalar type    : preserved device         : preserved tensor
+    //   inputs  : 1 tensor outputs : 1
     // Additionally:
     //   - First input should be the only tensor input
     //   - has a bool keepdim argument
-    static const register_formula_for multidim_reduce_ops {
+    static const register_formula_for multidim_reduce_ops{
         {
             "aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
             "aten::std(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
         },
-        [](Node * node) -> type_vec_t {
+        [](Node* node) -> type_vec_t {
           if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
-            return multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/dim->size(), /*integer_upcast=*/false);
+            return multidim_reduce_with_postprocess(
+                node, /*num_reduce_dim=*/dim->size(), /*integer_upcast=*/false);
           }
           return {};
         }};
@@ -1578,7 +1581,6 @@ class ShapePropagator {
     setUnshapedType(node);
     return false;
   }
-
 };
 } // anonymous namespace
 
@@ -1589,17 +1591,17 @@ void PropagateInputShapes(const std::shared_ptr<Graph>& graph) {
 namespace {
 
 void EraseShapeInformation(at::ArrayRef<Value*> vals) {
-  for (Value * v : vals) {
+  for (Value* v : vals) {
     v->setType(unshapedType(v->type()));
   }
 }
 
-void EraseShapeInformation(Block * b) {
+void EraseShapeInformation(Block* b) {
   EraseShapeInformation(b->inputs());
   EraseShapeInformation(b->outputs());
-  for (Node * n : b->nodes()) {
+  for (Node* n : b->nodes()) {
     EraseShapeInformation(n->outputs());
-    for (Block *sb : n->blocks()) {
+    for (Blocksb : n->blocks()) {
       EraseShapeInformation(sb);
     }
   }
@@ -1611,4 +1613,5 @@ void EraseShapeInformation(const std::shared_ptr<Graph>& graph) {
   EraseShapeInformation(graph->block());
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 5886c68..6bc5ee8 100644 (file)
@@ -1,13 +1,15 @@
 #pragma once
 
-#include <memory>
 #include <torch/csrc/WindowsTorchApiMacro.h>
+#include <memory>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 struct Graph;
 
 TORCH_API void EraseShapeInformation(const std::shared_ptr<Graph>& graph);
 TORCH_API void PropagateInputShapes(const std::shared_ptr<Graph>& graph);
 
-}}
+} // namespace jit
+} // namespace torch
index 3e258cf..799a304 100644 (file)
@@ -1,8 +1,8 @@
 #include <torch/csrc/jit/passes/specialize_undef.h>
 #include <torch/csrc/jit/symbolic_variable.h>
 
-namespace torch { namespace jit {
-
+namespace torch {
+namespace jit {
 
 // propagate undefined information through a gradient graph and
 // remove grad_of blocks if present.
@@ -10,7 +10,7 @@ namespace torch { namespace jit {
 // operations generated by the symbolic autodiff code and cleans up
 // AutogradAdds when possible. Outputs of other nodes are conservatively
 // marked Unknown and not optimized.
-void specializeUndef(Graph & g) {
+void specializeUndef(Graph& g) {
   enum class State { Defined, Undefined, Unknown };
   std::unordered_map<Value*, State> state;
 
@@ -25,29 +25,31 @@ void specializeUndef(Graph & g) {
     }
   }
 
-  for(auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
+  for (auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
     auto n = *it;
-    switch(n->kind()) {
+    switch (n->kind()) {
       case prim::GradOf: {
         auto all_undefined =
             std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
               return state[v] == State::Undefined;
             });
         // Property 1: if all the gradInputs to the GradOf are undefined
-        // then the gradOutputs are also zero and will be represented as undefined nodes
-        if(all_undefined) {
+        // then the gradOutputs are also zero and will be represented as
+        // undefined nodes
+        if (all_undefined) {
           auto undef = g.createUndefined()->insertAfter(n)->output();
-          for(auto o : n->outputs()) {
+          for (auto o : n->outputs()) {
             o->replaceAllUsesWith(undef);
           }
         } else {
-        // Property 2: GradOfs are required to correctly handle combinations
-        // of defined and undefined inputs. They are expected to produce defined
-        // output tensors in this case.
+          // Property 2: GradOfs are required to correctly handle combinations
+          // of defined and undefined inputs. They are expected to produce
+          // defined output tensors in this case.
 
-          // Remove the GradOf, splicing its body back into the surrounding block
+          // Remove the GradOf, splicing its body back into the surrounding
+          // block
           auto body = n->blocks().at(0);
-          for(auto input : n->inputs()){
+          for (auto input : n->inputs()) {
             // we should never get into a situation when specializing a GradOf
             // where we do not know if a value is defined since at the top level
             // a gradient graph is composed of Linear nodes and AutogradAdds
@@ -55,12 +57,12 @@ void specializeUndef(Graph & g) {
             JIT_ASSERT(state[input] != State::Unknown);
           }
           // hoist the nodes in the GradOf body to be before the linear block
-          for(auto it = body->nodes().begin(); it != body->nodes().end();) {
+          for (auto it = body->nodes().begin(); it != body->nodes().end();) {
             auto block_node = *it++;
             block_node->moveBefore(n);
           }
 
-          for(size_t i = 0; i < n->outputs().size(); ++i)
+          for (size_t i = 0; i < n->outputs().size(); ++i)
             n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
         }
         it.destroyCurrent();
@@ -69,16 +71,17 @@ void specializeUndef(Graph & g) {
         auto a = n->input(0);
         auto b = n->input(1);
         // if one is undefined, we can just drop the add
-        if(state[a] == State::Undefined) {
+        if (state[a] == State::Undefined) {
           // Undef + b == b
           n->output()->replaceAllUsesWith(b);
           it.destroyCurrent();
-        } else if(state[b] == State::Undefined) {
+        } else if (state[b] == State::Undefined) {
           // a + Undef == a
           n->output()->replaceAllUsesWith(a);
           it.destroyCurrent();
-        } else if(state[a] == State::Defined && state[b] == State::Defined) {
-          // when both are defined, we can use a normal, optimizable add instruction
+        } else if (state[a] == State::Defined && state[b] == State::Defined) {
+          // when both are defined, we can use a normal, optimizable add
+          // instruction
           WithInsertPoint guard(n);
           Value* new_add = toVar(a) + toVar(b);
           state[new_add] = State::Defined;
@@ -95,7 +98,7 @@ void specializeUndef(Graph & g) {
         state[n->output()] = State::Undefined;
       } break;
       default:
-        for(auto o : n->outputs()) {
+        for (auto o : n->outputs()) {
           state[o] = State::Unknown;
         }
         break;
@@ -103,4 +106,5 @@ void specializeUndef(Graph & g) {
   }
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 09b9fe4..f829570 100644 (file)
@@ -2,8 +2,8 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
-
+namespace torch {
+namespace jit {
 
 // propagate undefined information through a gradient graph and
 // remove grad_of blocks if present.
@@ -11,6 +11,7 @@ namespace torch { namespace jit {
 // operations generated by the symbolic autodiff code and cleans up
 // AutogradAdds when possible. Outputs of other nodes are conservatively
 // marked Unknown and not optimized.
-TORCH_API void specializeUndef(Graph & g);
+TORCH_API void specializeUndef(Graph& g);
 
-}}
+} // namespace jit
+} // namespace torch
index 5a1c940..2f79d17 100644 (file)
@@ -1,47 +1,57 @@
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/to_batch.h>
 #include <torch/csrc/jit/script/compiler.h>
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>> ToBatch::batch_operator_table;
+std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
+    ToBatch::batch_operator_table;
 
-std::shared_ptr<Graph> ToBatch::getBatchOperator(const std::string& name, int64_t num_inputs){
-  if(batch_operator_table.find(name) == batch_operator_table.end()){
-    throw std::runtime_error("function " + name + " is not supported in batched tensor yet");
+std::shared_ptr<Graph> ToBatch::getBatchOperator(
+    const std::string& name,
+    int64_t num_inputs) {
+  if (batch_operator_table.find(name) == batch_operator_table.end()) {
+    throw std::runtime_error(
+        "function " + name + " is not supported in batched tensor yet");
   }
   auto ops = batch_operator_table.at(name);
-  if(num_inputs == -1)  // default function
+  if (num_inputs == -1) // default function
     return ops[0];
-  for(auto op : ops){
-    if(size_t(num_inputs) == op->inputs().size())
+  for (auto op : ops) {
+    if (size_t(num_inputs) == op->inputs().size())
       return op;
   }
-  throw std::runtime_error("function " + name + " with " + std::to_string(num_inputs) + " inputs is not supported in batched tensor yet");
+  throw std::runtime_error(
+      "function " + name + " with " + std::to_string(num_inputs) +
+      " inputs is not supported in batched tensor yet");
 }
 
-std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
+std::vector<Value*> inlineUnpackedCallTo(
+    Graph& g,
+    Graph& callee,
+    ArrayRef<Value*> inputs) {
   return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
 }
 
 // replace aten operator node with BatchTensor operator graph
-void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
+void ToBatch::visitAten(Node* n, Block* block, Block* res_block) {
   auto res_graph = res_block->owningGraph();
   auto func_name = std::string(n->kind().toUnqualString());
   std::vector<Value*> new_inputs;
-  for(Value *input : n->inputs()){
-    if(rn_env.find(input) == rn_env.end()){  // non-tensor input
+  for (Value* input : n->inputs()) {
+    if (rn_env.find(input) == rn_env.end()) { // non-tensor input
       auto new_input = batch_map.at(input);
       new_inputs.insert(new_inputs.end(), new_input.begin(), new_input.end());
-    }
-    else{  // batched tensor input
+    } else { // batched tensor input
       new_inputs.push_back(rn_env.at(input));
     }
   }
 
   // transform scalar to tensor before pass to batch operator script
-    for (auto& input : new_inputs) {
-    if(input->type() == IntType::get() || input->type() == FloatType::get() || input->type() == BoolType::get()){
+  for (auto& input : new_inputs) {
+    if (input->type() == IntType::get() || input->type() == FloatType::get() ||
+        input->type() == BoolType::get()) {
       auto to_tensor_node = res_graph->createNumToTensor(input);
       res_graph->insertNode(to_tensor_node);
       input = to_tensor_node->output();
@@ -49,11 +59,14 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
   }
 
   auto batch_graph = getBatchOperator(func_name, new_inputs.size());
-  auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
+  auto outputs =
+      inlineUnpackedCallTo(*res_block->owningGraph(), *batch_graph, new_inputs);
 
-  // Assume all outputs from inlined operator implementation are in the triple form batched tensor or just a single non-tensor.
+  // Assume all outputs from inlined operator implementation are in the triple
+  // form batched tensor or just a single non-tensor.
   if (outputs.size() == 1) {
-    // if previous output is scalar, transform new output back to scalar from dynamic
+    // if previous output is scalar, transform new output back to scalar from
+    // dynamic
     TypePtr orig_type = n->outputs()[0]->type();
     if (!orig_type->isSubtypeOf(outputs[0]->type())) {
       Symbol op;
@@ -64,24 +77,28 @@ void ToBatch::visitAten(Node* n, Block* block, Block* res_block){
       } else if (orig_type == BoolType::get()) {
         op = prim::Bool;
       } else {
-        throw std::runtime_error("NYI: scalar types other than int, float, and bool are not supported yet");
+        throw std::runtime_error(
+            "NYI: scalar types other than int, float, and bool are not supported yet");
       }
-      rn_env[n->outputs()[0]] = res_graph->insert(op, { outputs[0] });
+      rn_env[n->outputs()[0]] = res_graph->insert(op, {outputs[0]});
     } else {
       rn_env[n->outputs()[0]] = outputs[0];
     }
   } else {
-    for(size_t i = 0; i < n->outputs().size(); i++){
+    for (size_t i = 0; i < n->outputs().size(); i++) {
       auto output = n->outputs()[i];
-      batch_map[output] = std::vector<Value*>(outputs.begin() + i * EXP_BTENSOR_SIZE, outputs.begin() + i * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE);
+      batch_map[output] = std::vector<Value*>(
+          outputs.begin() + i * EXP_BTENSOR_SIZE,
+          outputs.begin() + i * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE);
     }
   }
 }
 
 // clone prim::Constant to new graph
 // batching transformation is applied to the output of prim::NumToTensor.
-// If there is a prim::NumToTensor following prim::Constant, it will be finally transformed to BatchTensor.
-void ToBatch::visitConstant(Node* n, Block* block, Block* res_block){
+// If there is a prim::NumToTensor following prim::Constant, it will be finally
+// transformed to BatchTensor.
+void ToBatch::visitConstant(Node* n, Block* block, Block* res_block) {
   auto res_graph = res_block->owningGraph();
   auto* r_node = res_graph->createClone(n, rn_fn);
   res_block->appendNode(r_node);
@@ -89,18 +106,21 @@ void ToBatch::visitConstant(Node* n, Block* block, Block* res_block){
 }
 
 // change return tensor to expanded batched tensor, eg: {data, mask, dims}
-void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block){
+void ToBatch::visitNumToTensor(Node* n, Block* block, Block* res_block) {
   auto res_graph = res_block->owningGraph();
   auto* r_node = res_graph->createClone(n, rn_fn);
   res_block->appendNode(r_node);
-  auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("batch_from_scalar_tensor"), r_node->outputs());
+  auto outputs = inlineUnpackedCallTo(
+      *res_block->owningGraph(),
+      *getBatchOperator("batch_from_scalar_tensor"),
+      r_node->outputs());
   batch_map[n->output()] = outputs;
 }
 
 // clone prim::TensorToNum to new graph
-void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block){
+void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block) {
   auto res_graph = res_block->owningGraph();
-  if(rn_env.find(n->input()) == rn_env.end()){
+  if (rn_env.find(n->input()) == rn_env.end()) {
     rn_env[n->input()] = batch_map.at(n->input())[0];
   }
   auto* r_node = res_graph->createClone(n, rn_fn);
@@ -110,32 +130,34 @@ void ToBatch::visitTensorToNum(Node* n, Block* block, Block* res_block){
 }
 
 // clone prim::ListConstruct to new graph
-void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block){
+void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block) {
   auto res_graph = res_block->owningGraph();
-  if(n->inputs()[0]->type() == DynamicType::get()){  // TensorList: expand directly
+  if (n->inputs()[0]->type() ==
+      DynamicType::get()) { // TensorList: expand directly
     std::vector<Value*> inputs;
-    for(Value* input: n->inputs()) {
+    for (Value* input : n->inputs()) {
       auto res = batch_map.at(input);
       inputs.insert(inputs.end(), res.begin(), res.end());
     }
     batch_map[n->output()] = inputs;
-  }
-  else {  // ScalarList: transform to tensor, then transform back
-    for(Value* input : n->inputs()) {
-      if(rn_env.find(input) == rn_env.end()){
+  } else { // ScalarList: transform to tensor, then transform back
+    for (Value* input : n->inputs()) {
+      if (rn_env.find(input) == rn_env.end()) {
         rn_env[input] = batch_map.at(input)[0];
       }
     }
     auto* r_node = res_graph->createClone(n, rn_fn);
     res_block->appendNode(r_node);
     // transform int[] to tensor
-    auto to_tensor_node = res_graph->create(Symbol::fromQualString("aten::_list_to_tensor"));
+    auto to_tensor_node =
+        res_graph->create(Symbol::fromQualString("aten::_list_to_tensor"));
     to_tensor_node->addInput(r_node->output());
     res_block->appendNode(to_tensor_node);
     rn_env[n->output()] = to_tensor_node->output();
   }
 }
 
+// clang-format off
 // prim::If transformation:
 // elif is not supported
 //
@@ -216,17 +238,17 @@ void ToBatch::visitListConstruct(Node* n, Block* block, Block* res_block){
 //   %res_dims : Dynamic = aten::__or__(%dims.1, %dims)
 //   return (%res_data, %res_mask, %res_dims);
 // }
-void ToBatch::visitIf(Node* n, Block* block, Block* res_block){
+// clang-format on
+void ToBatch::visitIf(Node* n, Block* block, Block* res_block) {
   toBatch(n->blocks()[0], res_block);
   toBatch(n->blocks()[1], res_block);
 
   // combine results from two if paths
-  for(size_t i = 0; i < n->outputs().size(); i++){
+  for (size_t i = 0; i < n->outputs().size(); i++) {
     std::vector<Value*> inputs;
-    if(batch_map.find(n->input()) == batch_map.end()){  // cond is scalar
+    if (batch_map.find(n->input()) == batch_map.end()) { // cond is scalar
       inputs.push_back(rn_env.at(n->input()));
-    }
-    else{  // cond is tensor
+    } else { // cond is tensor
       auto cond = batch_map.at(n->input());
       inputs.insert(inputs.end(), cond.begin(), cond.end());
     }
@@ -234,11 +256,15 @@ void ToBatch::visitIf(Node* n, Block* block, Block* res_block){
     inputs.insert(inputs.end(), if_output.begin(), if_output.end());
     auto else_output = batch_map.at(n->blocks()[1]->outputs()[i]);
     inputs.insert(inputs.end(), else_output.begin(), else_output.end());
-    auto outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("where", inputs.size()), inputs);
+    auto outputs = inlineUnpackedCallTo(
+        *res_block->owningGraph(),
+        *getBatchOperator("where", inputs.size()),
+        inputs);
     batch_map[n->outputs()[i]] = outputs;
   }
 }
 
+// clang-format off
 // prim::Loop transformation:
 //
 // transformation example:
@@ -326,48 +352,56 @@ void ToBatch::visitIf(Node* n, Block* block, Block* res_block){
 //     }
 //   return (%a, %60, %61);
 // }
-void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
+// clang-format on
+void ToBatch::visitLoop(Node* n, Block* block, Block* res_block) {
   auto res_graph = res_block->owningGraph();
   // bool cond_is_tensor indicates whether cond is tensor
   // cond_is_tensor = false, eg: for loop, n->inputs()[1] = byte()
   // cond_is_tensor = true, eg: in some while loop, cond is a batched tensor,
-  //                            we need to add expanded cond to the inputs of loop node and block,
-  //                            and compute cond_any as cond for while loop
+  //                            we need to add expanded cond to the inputs of
+  //                            loop node and block, and compute cond_any as
+  //                            cond for while loop
   bool cond_is_tensor = (batch_map.find(n->inputs()[1]) != batch_map.end());
 
   // create prim::Loop node for res_block
 
   // type of cond in loop should be int type
-  if(rn_env.at(n->inputs()[0])->type() != IntType::get()){
-    rn_env[n->inputs()[0]] = res_graph->insert(prim::Int, {rn_env.at(n->inputs()[0])});
+  if (rn_env.at(n->inputs()[0])->type() != IntType::get()) {
+    rn_env[n->inputs()[0]] =
+        res_graph->insert(prim::Int, {rn_env.at(n->inputs()[0])});
   }
-  if(cond_is_tensor){
+  if (cond_is_tensor) {
     auto cond = batch_map.at(n->inputs()[1]);
-    auto cond_any = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
-    rn_env[n->inputs()[1]] =res_graph->insert(prim::Bool, {cond_any[0]});
+    auto cond_any = inlineUnpackedCallTo(
+        *res_block->owningGraph(), *getBatchOperator("any"), cond);
+    rn_env[n->inputs()[1]] = res_graph->insert(prim::Bool, {cond_any[0]});
   }
-  for(size_t i = 2; i < n->inputs().size(); i++){
+  for (size_t i = 2; i < n->inputs().size(); i++) {
     auto input = n->inputs()[i];
     rn_env[input] = batch_map.at(input)[0];
   }
   auto* r_node = res_graph->createClone(n, rn_fn, /*copy_blocks=*/false);
 
   // change inputs of prim::Loop
-  if(cond_is_tensor){
-    for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+  if (cond_is_tensor) {
+    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
       auto cond = batch_map.at(n->inputs()[1]);
       r_node->insertInput(i + 2, cond[i]);
     }
   }
-  for(size_t i = 2; i < n->inputs().size(); i++){
-    for(size_t j = 1; j < EXP_BTENSOR_SIZE; j++){
-      r_node->insertInput((i - 2) * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE * cond_is_tensor + 2 + j, batch_map.at(n->inputs()[i])[j]);
+  for (size_t i = 2; i < n->inputs().size(); i++) {
+    for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
+      r_node->insertInput(
+          (i - 2) * EXP_BTENSOR_SIZE + EXP_BTENSOR_SIZE * cond_is_tensor + 2 +
+              j,
+          batch_map.at(n->inputs()[i])[j]);
     }
   }
   res_block->appendNode(r_node);
 
   // create block for Loop node in res_block
-  // if cond is tensor:    first 4 inputs of block: cond_any, cond_data, cond_mask, cond_dims
+  // if cond is tensor:    first 4 inputs of block: cond_any, cond_data,
+  //                       cond_mask, cond_dims
   // if cond is not tensor: first 1 input of block: cond
   auto loop_block = r_node->addBlock();
 
@@ -375,18 +409,24 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
   loop_block->addInput("loop_num");
   loop_block->inputs()[0]->setType(IntType::get());
   rn_env[n->blocks()[0]->inputs()[0]] = loop_block->inputs()[0];
-  if(cond_is_tensor){
-    for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+  if (cond_is_tensor) {
+    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
       loop_block->addInput("cond_" + EXP_BTENSOR_NAME[i]);
     }
   }
-  for(size_t i = 1; i < n->blocks()[0]->inputs().size(); i++){
+  for (size_t i = 1; i < n->blocks()[0]->inputs().size(); i++) {
     auto input = n->blocks()[0]->inputs()[i];
     auto name = input->uniqueName();
-    for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+    for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
       loop_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
     }
-    batch_map[input] = std::vector<Value*>(loop_block->inputs().slice((i - 1) * EXP_BTENSOR_SIZE + 1 + EXP_BTENSOR_SIZE * cond_is_tensor, EXP_BTENSOR_SIZE).vec());
+    batch_map[input] =
+        std::vector<Value*>(loop_block->inputs()
+                                .slice(
+                                    (i - 1) * EXP_BTENSOR_SIZE + 1 +
+                                        EXP_BTENSOR_SIZE * cond_is_tensor,
+                                    EXP_BTENSOR_SIZE)
+                                .vec());
   }
 
   toBatch(n->blocks()[0], loop_block);
@@ -394,59 +434,63 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
   WithInsertPoint guard(loop_block);
 
   // use where operator to update variables and add to outputs
-  for(size_t i = 0; i < n->outputs().size(); i++){
+  for (size_t i = 0; i < n->outputs().size(); i++) {
     std::vector<Value*> inputs, outputs;
-    if(cond_is_tensor){
-      for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+    if (cond_is_tensor) {
+      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
         inputs.push_back(loop_block->inputs()[j + 1]);
       }
       auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
       inputs.insert(inputs.end(), data.begin(), data.end());
-      for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
-        inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
+      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
+        inputs.push_back(
+            loop_block
+                ->inputs()[i * EXP_BTENSOR_SIZE + j + EXP_BTENSOR_SIZE + 1]);
       }
-      outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("where"), inputs);
-    }
-    else{
-      for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+      outputs = inlineUnpackedCallTo(
+          *res_block->owningGraph(), *getBatchOperator("where"), inputs);
+    } else {
+      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
         inputs.push_back(loop_block->inputs()[i * EXP_BTENSOR_SIZE + j + 1]);
       }
       auto data = batch_map.at(n->blocks()[0]->outputs()[i + 1]);
       inputs.insert(inputs.end(), data.begin(), data.end());
-      outputs = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("update"), inputs);
+      outputs = inlineUnpackedCallTo(
+          *res_block->owningGraph(), *getBatchOperator("update"), inputs);
     }
     batch_map[n->outputs()[i]] = outputs;
-    for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+    for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
       loop_block->registerOutput(outputs[j]);
     }
   }
 
   // update loop conditions
-  if(cond_is_tensor){
+  if (cond_is_tensor) {
     auto cond = batch_map.at(n->blocks()[0]->outputs()[0]);
-    auto cond_any = inlineUnpackedCallTo(*res_block->owningGraph(), *getBatchOperator("any"), cond);
+    auto cond_any = inlineUnpackedCallTo(
+        *res_block->owningGraph(), *getBatchOperator("any"), cond);
     auto to_bool_output = res_graph->insert(prim::Bool, {cond_any[0]});
-    loop_block->insertOutput(0,  to_bool_output);
-    for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+    loop_block->insertOutput(0, to_bool_output);
+    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
       loop_block->insertOutput(i + 1, cond[i]);
     }
-  }
-  else{
+  } else {
     auto cond = rn_env.at(n->blocks()[0]->outputs()[0]);
     loop_block->insertOutput(0, cond);
   }
 
   // change outputs of prim::Loop
   auto size = r_node->outputs().size();
-  for(size_t i = 0; i < size; i++){
-    for(size_t j = 1; j < EXP_BTENSOR_SIZE; j++){
+  for (size_t i = 0; i < size; i++) {
+    for (size_t j = 1; j < EXP_BTENSOR_SIZE; j++) {
       r_node->insertOutput(i * EXP_BTENSOR_SIZE + j);
     }
-    batch_map[n->outputs()[i]] = r_node->outputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec();
+    batch_map[n->outputs()[i]] =
+        r_node->outputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec();
   }
   // add cond to outputs of loop node
-  if(cond_is_tensor){
-    for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+  if (cond_is_tensor) {
+    for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
       r_node->insertOutput(i);
     }
   }
@@ -455,28 +499,30 @@ void ToBatch::visitLoop(Node* n, Block* block, Block* res_block){
 void ToBatch::toBatch(Block* block, Block* res_block) {
   WithInsertPoint guard(res_block);
 
-  // change inputs of block - expand tensor to batchtensor eg: (data, mask, dims)
-  // eg: a -> a_data, a_mask, a_dims
-  // for block in prim::Loop, register inputs separately to deal with cond
-  if(!block->owningNode() || block->owningNode()->kind() != prim::Loop){
+  // change inputs of block-expand tensor to batchtensor eg: (data, mask, dims)
+  // eg: a -> a_data, a_mask, a_dims for block in prim::Loop, register inputs
+  // separately to deal with cond
+  if (!block->owningNode() || block->owningNode()->kind() != prim::Loop) {
     auto size = block->inputs().size();
-    for(size_t i = 0; i < size; i++){
+    for (size_t i = 0; i < size; i++) {
       auto input = block->inputs()[i];
       auto name = input->uniqueName();
-      for(size_t j = 0; j < EXP_BTENSOR_SIZE; j++){
+      for (size_t j = 0; j < EXP_BTENSOR_SIZE; j++) {
         res_block->addInput(name + "_" + EXP_BTENSOR_NAME[j]);
       }
-      batch_map[input] = std::vector<Value*>(res_block->inputs().slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE).vec());
+      batch_map[input] =
+          std::vector<Value*>(res_block->inputs()
+                                  .slice(i * EXP_BTENSOR_SIZE, EXP_BTENSOR_SIZE)
+                                  .vec());
     }
   }
 
   for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
     auto n = *it;
-    if(n->kind().is_aten()){
+    if (n->kind().is_aten()) {
       visitAten(n, block, res_block);
-    }
-    else if(n->kind().is_prim()){
-      switch(n->kind()){
+    } else if (n->kind().is_prim()) {
+      switch (n->kind()) {
         case prim::Constant:
         case prim::None:
           visitConstant(n, block, res_block);
@@ -499,20 +545,26 @@ void ToBatch::toBatch(Block* block, Block* res_block) {
           visitLoop(n, block, res_block);
           break;
         default:
-          throw std::runtime_error("NYI: node of prim kind other than [Constant, NumToTensor, TensorToNum, If, Loop] is not supported yet");
+          throw std::runtime_error(
+              "NYI: node of prim kind other than [Constant, NumToTensor, TensorToNum, If, Loop] is not supported yet");
       }
-    }
-    else{
-      throw std::runtime_error("NYI: node that is not aten or prim kind is not supported yet");
+    } else {
+      throw std::runtime_error(
+          "NYI: node that is not aten or prim kind is not supported yet");
     }
   }
   // change outputs of block - expand tensor to batchtensor(data, mask, dims)
-  // for block in prim::Loop, register outputs separately to deal with cond and cond_any
-  // for block in prim::If, register outputs separately by combining outputs from two paths and return
-  if(!block->owningNode() || (block->owningNode()->kind() != prim::Loop && block->owningNode()->kind() != prim::If)) {
-    for(Value* output : block->outputs()){
+  // for block in prim::Loop, register outputs separately to deal with cond and
+  // cond_any
+  //
+  // for block in prim::If, register outputs separately by combining
+  // outputs from two paths and return
+  if (!block->owningNode() ||
+      (block->owningNode()->kind() != prim::Loop &&
+       block->owningNode()->kind() != prim::If)) {
+    for (Value* output : block->outputs()) {
       auto r_output = batch_map.at(output);
-      for(size_t i = 0; i < EXP_BTENSOR_SIZE; i++){
+      for (size_t i = 0; i < EXP_BTENSOR_SIZE; i++) {
         res_block->registerOutput(r_output[i]);
       }
     }
@@ -525,7 +577,7 @@ std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph) {
     graph = graph->copy();
     auto outs = createTupleUnpack(graph->outputs().at(0));
     graph->eraseOutput(0);
-    for(auto o : outs)
+    for (auto o : outs)
       graph->registerOutput(o);
     EliminateDeadCode(graph->block());
   }
@@ -533,8 +585,10 @@ std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph) {
   ToBatch to_batch;
   to_batch.toBatch(graph->block(), res_graph->block());
 
-  // methods should only have a single output, so we pack everything into a tuple
-  auto tup = res_graph->insertNode(res_graph->createTuple(res_graph->outputs()));
+  // methods should only have a single output, so we pack everything into a
+  // tuple
+  auto tup =
+      res_graph->insertNode(res_graph->createTuple(res_graph->outputs()));
   while (res_graph->outputs().size() > 0)
     res_graph->eraseOutput(res_graph->outputs().size() - 1);
   res_graph->registerOutput(tup->output());
@@ -546,9 +600,12 @@ std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph) {
 void initRegisterBatchOpsBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
   m.def("to_batch_graph", to_batch_graph);
-  m.def("register_batch_operator", [](std::string name, std::shared_ptr<Graph> graph){
-    ToBatch::batch_operator_table[name].push_back(graph);
-  });
+  m.def(
+      "register_batch_operator",
+      [](std::string name, std::shared_ptr<Graph> graph) {
+        ToBatch::batch_operator_table[name].push_back(graph);
+      });
 }
 
-}} // namespace torch.jit
+} // namespace jit
+} // namespace torch
index 959f265..76bf53d 100644 (file)
@@ -1,25 +1,32 @@
 #pragma once
 
-#include <torch/csrc/jit/pybind.h>
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/pybind.h>
 
 #include <ATen/ATen.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 class ToBatch {
-private:
-  // number of tensors to represent a expanded BatchTensor. {data, mask, dims} for now.
+ private:
+  // number of tensors to represent a expanded BatchTensor. {data, mask, dims}
+  // for now.
   const size_t EXP_BTENSOR_SIZE = 3;
   const std::vector<std::string> EXP_BTENSOR_NAME = {"data", "mask", "dims"};
   // mapping from tensor in original graph to {data, mask, dims} in new graph
   std::unordered_map<Value*, std::vector<Value*>> batch_map;
-  // mapping from input in original graph to new input in new graph - used in createClone
+  // mapping from input in original graph to new input in new graph - used in
+  // createClone
   std::unordered_map<Value*, Value*> rn_env;
-  std::function<Value*(Value*)> rn_fn = [this](Value* v) { return rn_env.at(v); };
+  std::function<Value*(Value*)> rn_fn = [this](Value* v) {
+    return rn_env.at(v);
+  };
 
-private:
-  std::shared_ptr<Graph> getBatchOperator(const std::string& name, int64_t input_num = -1);
+ private:
+  std::shared_ptr<Graph> getBatchOperator(
+      const std::string& name,
+      int64_t input_num = -1);
   void visitAten(Node* n, Block* block, Block* res_block);
   void visitConstant(Node* n, Block* block, Block* res_block);
   void visitNumToTensor(Node* n, Block* block, Block* res_block);
@@ -28,11 +35,13 @@ private:
   void visitIf(Node* n, Block* block, Block* res_block);
   void visitLoop(Node* n, Block* block, Block* res_block);
 
-public:
-  static std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>> batch_operator_table;
+ public:
+  static std::unordered_map<std::string, std::vector<std::shared_ptr<Graph>>>
+      batch_operator_table;
   TORCH_API void toBatch(Block* block, Block* res_block);
 };
 
 TORCH_API std::shared_ptr<Graph> to_batch_graph(std::shared_ptr<Graph> graph);
 TORCH_API void initRegisterBatchOpsBindings(PyObject* module);
-}}
+} // namespace jit
+} // namespace torch
index a9432c5..93e18df 100644 (file)
@@ -72,9 +72,7 @@ bool deepEquals(const IValue& lhs, const IValue& rhs) {
 }
 
 struct AliasAndIValue {
-  AliasAndIValue(
-      c10::optional<at::AliasInfo> aliasInfo,
-      IValue iValue)
+  AliasAndIValue(c10::optional<at::AliasInfo> aliasInfo, IValue iValue)
       : aliasInfo(std::move(aliasInfo)), iValue(std::move(iValue)) {}
 
   const c10::optional<at::AliasInfo> aliasInfo;
index 44e29dd..1e03fb8 100644 (file)
@@ -2,14 +2,14 @@
 
 #include <torch/csrc/python_headers.h>
 
-#include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/DynamicTypes.h>
 #include <torch/csrc/THP.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/interned_strings.h>
-#include <torch/csrc/jit/tracer.h>
 #include <torch/csrc/jit/ivalue.h>
 #include <torch/csrc/jit/pybind_utils.h>
+#include <torch/csrc/jit/tracer.h>
+#include <torch/csrc/utils/pybind.h>
 
 #include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 
 namespace py = pybind11;
 
-namespace pybind11 { namespace detail {
+namespace pybind11 {
+namespace detail {
 
-template <> struct type_caster<torch::jit::IValue> {
-public:
+template <>
+struct type_caster<torch::jit::IValue> {
+ public:
   PYBIND11_TYPE_CASTER(torch::jit::IValue, _("IValue"));
 
   bool load(handle src, bool) {
@@ -32,13 +34,17 @@ public:
     }
   }
 
-  static handle cast(torch::jit::IValue src, return_value_policy /* policy */, handle /* parent */) {
+  static handle cast(
+      torch::jit::IValue src,
+      return_value_policy /* policy */,
+      handle /* parent */) {
     return torch::jit::toPyObject(std::move(src)).release();
   }
 };
 
-template <> struct type_caster<torch::jit::Symbol> {
-public:
+template <>
+struct type_caster<torch::jit::Symbol> {
+ public:
   PYBIND11_TYPE_CASTER(torch::jit::Symbol, _("Symbol"));
 
   bool load(handle src, bool) {
@@ -54,45 +60,67 @@ public:
     return true;
   }
 
-  static handle cast(torch::jit::Symbol src, return_value_policy /* policy */, handle /* parent */) {
-    return py::cast(std::string(src.toQualString()), return_value_policy::copy).release();
+  static handle cast(
+      torch::jit::Symbol src,
+      return_value_policy /* policy */,
+      handle /* parent */) {
+    return py::cast(std::string(src.toQualString()), return_value_policy::copy)
+        .release();
   }
 };
 
-template <> struct type_caster<torch::jit::AttributeKind> {
-public:
+template <>
+struct type_caster<torch::jit::AttributeKind> {
+ public:
   PYBIND11_TYPE_CASTER(torch::jit::AttributeKind, _("AttributeKind"));
 
   bool load(handle src, bool) {
     return false;
   }
 
-  static handle cast(torch::jit::AttributeKind src, return_value_policy /* policy */, handle /* parent */) {
-    return py::cast(std::string(torch::jit::toString(src)), return_value_policy::copy).release();
+  static handle cast(
+      torch::jit::AttributeKind src,
+      return_value_policy /* policy */,
+      handle /* parent */) {
+    return py::cast(
+               std::string(torch::jit::toString(src)),
+               return_value_policy::copy)
+        .release();
   }
 };
 
 // See https://github.com/pybind/pybind11/issues/637
-using ListCasterBase = pybind11::detail::list_caster<std::vector<torch::jit::Node *>, torch::jit::Node *>;
-template<> struct type_caster<std::vector<torch::jit::Node *>> : ListCasterBase {
-    static handle cast(const std::vector<torch::jit::Node *> &src, return_value_policy, handle parent) {
-        return ListCasterBase::cast(src, return_value_policy::reference, parent);
-    }
-    static handle cast(const std::vector<torch::jit::Node *> *src, return_value_policy pol, handle parent) {
-        return cast(*src, pol, parent);
-    }
+using ListCasterBase = pybind11::detail::
+    list_caster<std::vector<torch::jit::Node*>, torch::jit::Node*>;
+template <>
+struct type_caster<std::vector<torch::jit::Node*>> : ListCasterBase {
+  static handle cast(
+      const std::vector<torch::jit::Node*>& src,
+      return_value_policy,
+      handle parent) {
+    return ListCasterBase::cast(src, return_value_policy::reference, parent);
+  }
+  static handle cast(
+      const std::vector<torch::jit::Node*>* src,
+      return_value_policy pol,
+      handle parent) {
+    return cast(*src, pol, parent);
+  }
 };
 
-}} // namespace pybind11::detail
+} // namespace detail
+} // namespace pybind11
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-static inline py::tuple tuple_tail(const py::tuple & tup) {
+static inline py::tuple tuple_tail(const py::tuple& tup) {
   py::tuple r(tup.size() - 1);
-  for(size_t i = 1; i < tup.size(); i++) {
-    r[i-1] = tup[i];
+  for (size_t i = 1; i < tup.size(); i++) {
+    r[i - 1] = tup[i];
   }
   return r;
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 5089ada..f18f954 100644 (file)
@@ -1,14 +1,14 @@
 #pragma once
 
+#include <torch/csrc/Device.h>
 #include <torch/csrc/jit/function_schema.h>
 #include <torch/csrc/jit/ivalue.h>
-#include <torch/csrc/jit/stack.h>
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/module.h>
+#include <torch/csrc/jit/stack.h>
 #include <torch/csrc/jit/type.h>
-#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/auto_gil.h>
-#include <torch/csrc/Device.h>
+#include <torch/csrc/utils/pybind.h>
 
 #include <c10/util/Exception.h>
 
@@ -26,7 +26,8 @@
 #define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
 #endif
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 namespace detail {
 
 // error reporting: when reporting user-caused errors, these functions should
@@ -34,18 +35,18 @@ namespace detail {
 // that is confusing to display to the end user since it always reports
 // locations in libtorch code rather than user code.
 
-inline void findErrorInKwargs(
-    const FunctionSchema& schema,
-    py::kwargs kwargs) {
+inline void findErrorInKwargs(const FunctionSchema& schema, py::kwargs kwargs) {
   const auto& arguments = schema.arguments();
   // First check if any of the kwargs are unknown, i.e. don't match the name of
   // any argument in the schema.
   for (const auto& kwarg : kwargs) {
     const auto key = py::cast<std::string>(kwarg.first);
-    if(!std::count_if(
+    if (!std::count_if(
             arguments.begin(),
             arguments.end(),
-            [&key](const Argument& argument) { return argument.name() == key; })) {
+            [&key](const Argument& argument) {
+              return argument.name() == key;
+            })) {
       throw std::runtime_error(c10::str(
           "Unknown keyword argument '",
           key,
@@ -87,8 +88,9 @@ inline IValue toIValue(py::handle input) {
     }
     return Tuple::create(s);
   } else {
-    AT_ERROR("Only tensors and (possibly nested) tuples of tensors are supported "
-             "as inputs or outputs of traced functions");
+    AT_ERROR(
+        "Only tensors and (possibly nested) tuples of tensors are supported "
+        "as inputs or outputs of traced functions");
   }
 }
 
@@ -96,111 +98,119 @@ inline Stack toStack(const py::tuple& inputs) {
   return toIValue(inputs).toTuple()->elements();
 }
 
-inline IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N = c10::nullopt);
+inline IValue toIValue(
+    py::handle obj,
+    const TypePtr& type,
+    c10::optional<int32_t> N = c10::nullopt);
 
 inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
   std::vector<IValue> elems;
-  for(auto elem : obj) {
+  for (auto elem : obj) {
     elems.push_back(toIValue(elem, elem_type));
   }
   return List<IValue>::create(std::move(elems));
 }
 
-inline IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
-    switch (type->kind()) {
-      case TypeKind::DynamicType:
-      case TypeKind::TensorType:
-      case TypeKind::UndefinedTensorType:
-      case TypeKind::CompleteTensorType: {
-        auto var = py::cast<autograd::Variable>(obj);
-        if (var.is_sparse()) {
-          AT_ERROR("sparse tensors not supported");
-        }
-        return var;
-      }
-      case TypeKind::FloatType:
-        return py::cast<double>(obj);
-      case TypeKind::IntType:
-        return py::cast<int64_t>(obj);
-      case TypeKind::NoneType:
-        if(obj != Py_None)
-          throw py::cast_error();
-
-        return {};
-      case TypeKind::BoolType:
-        return py::cast<bool>(obj);
-      case TypeKind::TupleType: {
-        if(!PyTuple_Check(obj.ptr()))
-          throw py::cast_error(); // note: the py::cast does not throw cast_error
-                                  // because it attempts to iterate a non-tuple
-        py::tuple tuple = py::cast<py::tuple>(obj);
-        size_t tuple_size = tuple.size();
-        const auto & elem_types = type->cast<TupleType>()->elements();
-        if (elem_types.size() != tuple_size) {
-          throw py::cast_error();
-        }
-        std::vector<IValue> values;
-        values.reserve(tuple_size);
-        for (size_t i = 0; i < tuple_size; ++i) {
-          values.push_back(toIValue(tuple[i], elem_types[i]));
-        }
-        return Tuple::create(std::move(values));
+inline IValue toIValue(
+    py::handle obj,
+    const TypePtr& type,
+    c10::optional<int32_t> N) {
+  switch (type->kind()) {
+    case TypeKind::DynamicType:
+    case TypeKind::TensorType:
+    case TypeKind::UndefinedTensorType:
+    case TypeKind::CompleteTensorType: {
+      auto var = py::cast<autograd::Variable>(obj);
+      if (var.is_sparse()) {
+        AT_ERROR("sparse tensors not supported");
       }
-      case TypeKind::StringType:
-        return ConstantString::create(py::cast<std::string>(obj));
-      case TypeKind::DeviceObjType: {
-        auto device = reinterpret_cast<THPDevice*>(obj.ptr());
-        return device->device;
+      return var;
+    }
+    case TypeKind::FloatType:
+      return py::cast<double>(obj);
+    case TypeKind::IntType:
+      return py::cast<int64_t>(obj);
+    case TypeKind::NoneType:
+      if (obj != Py_None)
+        throw py::cast_error();
+
+      return {};
+    case TypeKind::BoolType:
+      return py::cast<bool>(obj);
+    case TypeKind::TupleType: {
+      if (!PyTuple_Check(obj.ptr()))
+        throw py::cast_error(); // note: the py::cast does not throw cast_error
+                                // because it attempts to iterate a non-tuple
+      py::tuple tuple = py::cast<py::tuple>(obj);
+      size_t tuple_size = tuple.size();
+      const auto& elem_types = type->cast<TupleType>()->elements();
+      if (elem_types.size() != tuple_size) {
+        throw py::cast_error();
       }
-      case TypeKind::ListType: {
-        const auto& elem_type = type->expect<ListType>()->getElementType();
-        switch(elem_type->kind()) {
-          //allows single int/float to be broadcasted to a fixed size list
-          case TypeKind::IntType:
-            if (!N || !py::isinstance<py::int_>(obj)) {
-              return py::cast<std::vector<int64_t>>(obj);
-            } else {
-              double value = py::cast<int64_t>(obj);
-              std::vector<double> repeated(*N, value);
-              return repeated;
-            }
-          case TypeKind::FloatType:
-            if (!N || !py::isinstance<py::float_>(obj)) {
-              return py::cast<std::vector<double>>(obj);
-            } else {
-              double value = py::cast<double>(obj);
-              std::vector<double> repeated(*N, value);
-              return repeated;
-            }
-          case TypeKind::TensorType:
-          case TypeKind::DynamicType:
-            return py::cast<std::vector<at::Tensor>>(obj);
-          default:
-            return createGenericList(obj, elem_type);
-        }
+      std::vector<IValue> values;
+      values.reserve(tuple_size);
+      for (size_t i = 0; i < tuple_size; ++i) {
+        values.push_back(toIValue(tuple[i], elem_types[i]));
       }
-      case TypeKind::OptionalType: {
-        const auto& elem_type = type->expect<OptionalType>()->getElementType();
-        // check if it's a none obj since optional accepts NoneType
-        if (obj == Py_None)  {
-          if(elem_type->isSubtypeOf(DynamicType::get())) {
-            // return undefined tensor for Optional[Tensor]
-            return at::Tensor();
+      return Tuple::create(std::move(values));
+    }
+    case TypeKind::StringType:
+      return ConstantString::create(py::cast<std::string>(obj));
+    case TypeKind::DeviceObjType: {
+      auto device = reinterpret_cast<THPDevice*>(obj.ptr());
+      return device->device;
+    }
+    case TypeKind::ListType: {
+      const auto& elem_type = type->expect<ListType>()->getElementType();
+      switch (elem_type->kind()) {
+        // allows single int/float to be broadcasted to a fixed size list
+        case TypeKind::IntType:
+          if (!N || !py::isinstance<py::int_>(obj)) {
+            return py::cast<std::vector<int64_t>>(obj);
+          } else {
+            double value = py::cast<int64_t>(obj);
+            std::vector<double> repeated(*N, value);
+            return repeated;
           }
-          else {
-            // for other optional types, return an IValue() to denote a None
-            return {};
+        case TypeKind::FloatType:
+          if (!N || !py::isinstance<py::float_>(obj)) {
+            return py::cast<std::vector<double>>(obj);
+          } else {
+            double value = py::cast<double>(obj);
+            std::vector<double> repeated(*N, value);
+            return repeated;
           }
+        case TypeKind::TensorType:
+        case TypeKind::DynamicType:
+          return py::cast<std::vector<at::Tensor>>(obj);
+        default:
+          return createGenericList(obj, elem_type);
+      }
+    }
+    case TypeKind::OptionalType: {
+      const auto& elem_type = type->expect<OptionalType>()->getElementType();
+      // check if it's a none obj since optional accepts NoneType
+      if (obj == Py_None) {
+        if (elem_type->isSubtypeOf(DynamicType::get())) {
+          // return undefined tensor for Optional[Tensor]
+          return at::Tensor();
+        } else {
+          // for other optional types, return an IValue() to denote a None
+          return {};
         }
-        return toIValue(obj, type->expect<OptionalType>()->getElementType());
       }
-      case TypeKind::NumberType:
-      case TypeKind::GeneratorType:
-      case TypeKind::VarType:
-      case TypeKind::FutureType:
-        break;
+      return toIValue(obj, type->expect<OptionalType>()->getElementType());
     }
-  AT_ERROR("Missing cases in toIValue for type: ", type->str(), "! File a bug report.");
+    case TypeKind::NumberType:
+    case TypeKind::GeneratorType:
+    case TypeKind::VarType:
+    case TypeKind::FutureType:
+      break;
+  }
+  AT_ERROR(
+      "Missing cases in toIValue for type: ",
+      type->str(),
+      "! File a bug report.");
 }
 
 inline IValue argumentToIValue(
@@ -229,9 +239,7 @@ inline IValue argumentToIValue(
   }
 }
 
-inline IValue returnToIValue(
-    const TypePtr& type,
-    py::handle object) {
+inline IValue returnToIValue(const TypePtr& type, py::handle object) {
   try {
     return toIValue(object, type);
   } catch (const py::cast_error& error) {
@@ -273,16 +281,16 @@ inline py::object toPyObject(IValue&& ivalue) {
     return py::cast(ivalue.toTensorListRef());
   } else if (ivalue.isGenericList()) {
     auto list = ivalue.toGenericList();
-    const auto & elements = list->elements();
-    py::list t { elements.size() };
+    const auto& elements = list->elements();
+    py::list t{elements.size()};
     for (size_t i = 0; i < elements.size(); ++i) {
       t[i] = toPyObject(IValue{elements[i]});
     }
     return t;
   } else if (ivalue.isTuple()) {
     auto tuple = ivalue.toTuple();
-    const auto & elements = tuple->elements();
-    py::tuple t { elements.size() };
+    const auto& elements = tuple->elements();
+    py::tuple t{elements.size()};
     for (size_t i = 0; i < elements.size(); ++i) {
       t[i] = toPyObject(IValue{elements[i]});
     }
@@ -296,11 +304,11 @@ inline py::object toPyObject(IValue&& ivalue) {
 
 struct VISIBILITY_HIDDEN tuple_slice {
   /*implicit*/ tuple_slice(py::tuple tup_)
-  : tup(std::move(tup_)), b(0), e(tup.size()) {}
+      : tup(std::move(tup_)), b(0), e(tup.size()) {}
   tuple_slice(py::tuple tup_, int64_t b_)
-  : tup(std::move(tup_)), b(b_), e(tup.size()) {}
+      : tup(std::move(tup_)), b(b_), e(tup.size()) {}
   tuple_slice(py::tuple tup_, int64_t b_, int64_t e_)
-  : tup(std::move(tup_)), b(b_), e(e_) {}
+      : tup(std::move(tup_)), b(b_), e(e_) {}
   py::detail::tuple_iterator begin() const {
     return {tup, b};
   }
@@ -313,7 +321,8 @@ struct VISIBILITY_HIDDEN tuple_slice {
   py::detail::tuple_accessor operator[](size_t index) const {
     return {tup, b + index};
   }
-private:
+
+ private:
   py::tuple tup;
   int64_t b;
   int64_t e;
@@ -323,11 +332,15 @@ inline Stack createStackForSchema(
     const FunctionSchema& schema,
     const tuple_slice& args,
     const py::kwargs& kwargs = py::kwargs()) {
-  if(args.size() + kwargs.size() > schema.arguments().size()) {
+  if (args.size() + kwargs.size() > schema.arguments().size()) {
     throw std::runtime_error(c10::str(
-        schema.name(), "() expected at most ", schema.arguments().size(),
+        schema.name(),
+        "() expected at most ",
+        schema.arguments().size(),
         " argument(s) but received ",
-        args.size() + kwargs.size(), " argument(s). Declaration: ", schema));
+        args.size() + kwargs.size(),
+        " argument(s). Declaration: ",
+        schema));
   }
   Stack stack;
   stack.reserve(schema.arguments().size());
@@ -387,10 +400,14 @@ inline py::object createPyObjectForStack(Stack&& stack) {
 }
 
 // TODO: Remove once we clean up the GraphExecutor usage.
-inline Stack evilDeprecatedBadCreateStackDoNotUse(const py::tuple& tuple, at::ArrayRef<Value*> inputs, size_t reserve_extra_space = 0) {
+inline Stack evilDeprecatedBadCreateStackDoNotUse(
+    const py::tuple& tuple,
+    at::ArrayRef<Value*> inputs,
+    size_t reserve_extra_space = 0) {
   if (tuple.size() != inputs.size()) {
-    AT_ERROR("expected " + std::to_string(inputs.size()) +
-                             " inputs, but got " + std::to_string(tuple.size()));
+    AT_ERROR(
+        "expected " + std::to_string(inputs.size()) + " inputs, but got " +
+        std::to_string(tuple.size()));
   }
   Stack result;
   result.reserve(tuple.size() + reserve_extra_space);
@@ -402,8 +419,10 @@ inline Stack evilDeprecatedBadCreateStackDoNotUse(const py::tuple& tuple, at::Ar
 
 inline py::object invokeScriptMethodFromPython(
     script::Method& method,
-    tuple_slice args, py::kwargs kwargs) {
-  auto stack = createStackForSchema(method.getSchema(), std::move(args), std::move(kwargs));
+    tuple_slice args,
+    py::kwargs kwargs) {
+  auto stack = createStackForSchema(
+      method.getSchema(), std::move(args), std::move(kwargs));
   {
     AutoNoGIL no_gil_guard;
     method.run(stack);
@@ -424,4 +443,5 @@ inline py::object invokeOperatorFromPython(
 
   return createPyObjectForStack(std::move(stack));
 }
-}}  // namespace torch::jit
+} // namespace jit
+} // namespace torch
index ba7b13b..95f7770 100644 (file)
@@ -2,33 +2,35 @@
 
 #include <torch/csrc/autograd/grad_mode.h>
 
-namespace torch { namespace jit { namespace python {
+namespace torch {
+namespace jit {
+namespace python {
 
 using namespace torch::autograd;
 using namespace at;
 
 // Alphabet used to describe structure of inputs/outputs (D for desc)
 namespace D {
-static constexpr char ListOpen          = '[';
-static constexpr char ListClose         = ']';
-static constexpr char TupleOpen         = '(';
-static constexpr char TupleClose        = ')';
-static constexpr char Variable          = 'v';
+static constexpr char ListOpen = '[';
+static constexpr char ListClose = ']';
+static constexpr char TupleOpen = '(';
+static constexpr char TupleClose = ')';
+static constexpr char Variable = 'v';
 } // namespace D
 
 namespace {
 
-template<typename T>
+template <typename T>
 py::object cast_handle_sequence(std::vector<py::handle> objs) {
   auto num_objs = objs.size();
-  T sequence { num_objs };
+  T sequence{num_objs};
   for (size_t i = 0; i < num_objs; ++i)
     sequence[i] = py::reinterpret_borrow<py::object>(objs[i]);
   return sequence;
 }
 
 void flatten_rec(PyObject* obj, ParsedArgs& args) {
-  auto & structure = args.desc.structure;
+  auto& structure = args.desc.structure;
   if (PyTuple_Check(obj)) {
     structure.push_back(D::TupleOpen);
     for (auto item : py::reinterpret_borrow<py::tuple>(obj))
@@ -45,7 +47,8 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
     args.desc.metadata.emplace_back(var);
     args.desc.structure.push_back(D::Variable);
   } else {
-    std::string msg = "Only tuples, lists and Variables supported as JIT inputs, but got ";
+    std::string msg =
+        "Only tuples, lists and Variables supported as JIT inputs, but got ";
     msg += THPUtils_typename(obj);
     throw std::runtime_error(msg);
   }
@@ -62,18 +65,19 @@ ParsedArgs flatten(py::handle obj) {
 
 namespace {
 
-template<typename T>
+template <typename T>
 py::object cast_sequence(std::vector<py::object> objs) {
   auto num_objs = objs.size();
-  T sequence { num_objs };
+  T sequence{num_objs};
   for (size_t i = 0; i < num_objs; ++i)
     sequence[i] = std::move(objs[i]);
   return sequence;
 }
 
-py::object unflatten_rec(ArrayRef<Variable>::iterator& var_it,
-                         ArrayRef<Variable>::iterator& var_it_end,
-                         std::string::const_iterator& desc_it) {
+py::object unflatten_rec(
+    ArrayRef<Variable>::iterator& var_it,
+    ArrayRef<Variable>::iterator& var_it_end,
+    std::string::const_iterator& desc_it) {
   char type = *desc_it++;
   if (type == D::TupleOpen) {
     std::vector<py::object> objs;
@@ -109,4 +113,6 @@ PyObject* unflatten(ArrayRef<Variable> vars, const IODescriptor& desc) {
   return output.release().ptr();
 }
 
-}}} // namespace torch::jit::python
+} // namespace python
+} // namespace jit
+} // namespace torch
index 00b97b3..93dca38 100644 (file)
@@ -1,27 +1,29 @@
 #pragma once
 
-#include <torch/csrc/jit/pybind.h>
 #include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/jit/pybind.h>
 #include <torch/csrc/utils/hash.h>
 
 #include <ATen/ATen.h>
+#include <functional>
 #include <tuple>
 #include <vector>
-#include <functional>
 
-namespace torch { namespace jit { namespace python {
+namespace torch {
+namespace jit {
+namespace python {
 
 struct IODescriptor {
   struct VariableMetadata {
     VariableMetadata(const autograd::Variable& var)
-      : sizes(var.sizes().vec())
-      , type(var.type().scalarType())
-      , device(var.device())
-      , requires_grad(var.requires_grad()) {}
+        : sizes(var.sizes().vec()),
+          type(var.type().scalarType()),
+          device(var.device()),
+          requires_grad(var.requires_grad()) {}
 
     bool operator==(const VariableMetadata& o) const {
-      return std::tie(  device,   requires_grad,   type,  sizes) ==
-             std::tie(o.device, o.requires_grad, o.type, o.sizes);
+      return std::tie(device, requires_grad, type, sizes) ==
+          std::tie(o.device, o.requires_grad, o.type, o.sizes);
     }
 
     static size_t hash(const VariableMetadata& m) {
@@ -35,8 +37,8 @@ struct IODescriptor {
   };
 
   bool operator==(const IODescriptor& o) const {
-    return std::tie(  structure,   metadata,   grad_enabled) ==
-           std::tie(o.structure, o.metadata, o.grad_enabled);
+    return std::tie(structure, metadata, grad_enabled) ==
+        std::tie(o.structure, o.metadata, o.grad_enabled);
   }
 
   static size_t hash(const IODescriptor& o) {
@@ -45,7 +47,7 @@ struct IODescriptor {
 
   void extend(const autograd::variable_list& list) {
     metadata.reserve(metadata.size() + list.size());
-    for (auto & var : list)
+    for (auto& var : list)
       metadata.emplace_back(var);
   }
 
@@ -61,16 +63,19 @@ struct IODescriptor {
   bool grad_enabled = false;
 };
 
-static inline std::ostream& operator<<(std::ostream& out, const IODescriptor::VariableMetadata& meta) {
+static inline std::ostream& operator<<(
+    std::ostream& out,
+    const IODescriptor::VariableMetadata& meta) {
   at::Device meta_device = meta.device;
-  auto & t = at::getNonVariableType(meta_device.is_cpu() ? at::Backend::CPU : at::Backend::CUDA, meta.type);
+  auto& t = at::getNonVariableType(
+      meta_device.is_cpu() ? at::Backend::CPU : at::Backend::CUDA, meta.type);
   out << t << "(requires_grad=" << meta.requires_grad;
   if (meta_device.is_cuda()) {
     out << ", device=" << meta_device.index();
   }
   out << ") {";
-  for(size_t i = 0; i < meta.sizes.size(); ++i) {
-    if(i > 0)
+  for (size_t i = 0; i < meta.sizes.size(); ++i) {
+    if (i > 0)
       out << ", ";
     out << meta.sizes[i];
   }
@@ -78,10 +83,12 @@ static inline std::ostream& operator<<(std::ostream& out, const IODescriptor::Va
   return out;
 }
 
-static inline std::ostream& operator<<(std::ostream & out, const IODescriptor & desc) {
+static inline std::ostream& operator<<(
+    std::ostream& out,
+    const IODescriptor& desc) {
   out << desc.structure << "\n";
   out << "  with grad_enabled=" << desc.grad_enabled << "\n";
-  for(size_t i = 0; i < desc.metadata.size(); ++i) {
+  for (size_t i = 0; i < desc.metadata.size(); ++i) {
     out << "  with v" << i << " having type " << desc.metadata[i] << "\n";
   }
   return out;
@@ -95,17 +102,20 @@ struct ParsedArgs {
   IODescriptor desc;
 
   void extend(const autograd::variable_list& list) {
-    if (list.empty()) return;
+    if (list.empty())
+      return;
     vars.reserve(vars.size() + list.size());
-    for (auto & var : list)
+    for (auto& var : list)
       vars.emplace_back(var);
     desc.extend(list);
   }
 };
 
-
 ParsedArgs flatten(py::handle obj);
-PyObject* unflatten(at::ArrayRef<autograd::Variable> vars,
-                    const IODescriptor& structure);
+PyObject* unflatten(
+    at::ArrayRef<autograd::Variable> vars,
+    const IODescriptor& structure);
 
-}}} // namespace torch::jit::python
+} // namespace python
+} // namespace jit
+} // namespace torch
index a418d64..f2ae189 100644 (file)
@@ -1,27 +1,28 @@
-#include <torch/csrc/python_headers.h>
 #include <torch/csrc/jit/interpreter.h>
+#include <torch/csrc/python_headers.h>
 
 #include <torch/csrc/autograd/edge.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/profiler.h>
 #include <torch/csrc/autograd/variable.h>
-#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/pybind_utils.h>
 
 #include <typeinfo>
 
+#include <torch/csrc/Exceptions.h>
 #include <torch/csrc/autograd/python_engine.h>
 #include <torch/csrc/autograd/python_variable.h>
 #include <torch/csrc/jit/pybind.h>
 #include <torch/csrc/utils/auto_gil.h>
-#include <torch/csrc/Exceptions.h>
 
 namespace py = pybind11;
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
@@ -30,17 +31,18 @@ Operation createPythonOperation(const Node* op_) {
   AutoGIL gil;
   const PythonOp* op = static_cast<const PythonOp*>(op_);
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-  const py::function func = py::reinterpret_borrow<const py::function>(py::handle(const_cast<PythonOp*>(op)->pyobj.get()));
+  const py::function func = py::reinterpret_borrow<const py::function>(
+      py::handle(const_cast<PythonOp*>(op)->pyobj.get()));
 
   size_t num_inputs = 0;
-  for(auto arg_type : op->cconv) {
-    if(arg_type == 'd')
+  for (auto arg_type : op->cconv) {
+    if (arg_type == 'd')
       num_inputs++;
   }
 
   JIT_ASSERT(op->outputs().size() == 1);
 
-  return [=](Stack & stack) {
+  return [=](Stack& stack) {
     AutoGIL gil;
     py::tuple py_inputs(op->cconv.size());
     size_t i = 0;
@@ -49,9 +51,11 @@ Operation createPythonOperation(const Node* op_) {
     for (auto arg_type : op->cconv) {
       if (arg_type == 'c') {
         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-        py_inputs[i] = py::reinterpret_borrow<const py::object>(const_cast<PythonOp*>(op)->scalar_args[next_scalar++].get());
+        py_inputs[i] = py::reinterpret_borrow<const py::object>(
+            const_cast<PythonOp*>(op)->scalar_args[next_scalar++].get());
       } else if (arg_type == 'd') {
-        py_inputs[i] = toPyObject(std::move(peek(stack, next_tensor, num_inputs)));
+        py_inputs[i] =
+            toPyObject(std::move(peek(stack, next_tensor, num_inputs)));
         next_tensor++;
       }
       i++;
@@ -60,16 +64,15 @@ Operation createPythonOperation(const Node* op_) {
     try {
       py::object py_output(func(*py_inputs));
       stack.push_back(returnToIValue(op->output()->type(), py_output));
-    } catch (py::error_already_set & e) {
+    } catch (py::error_already_set& e) {
       throw std::runtime_error(e.what());
     }
     return 0;
   };
 }
 
+RegisterOperators reg({Operator(prim::PythonOp, createPythonOperation)});
 
-RegisterOperators reg({
-  Operator(prim::PythonOp, createPythonOperation)
-});
-
-}}} // torch::jit::anon
+} // namespace
+} // namespace jit
+} // namespace torch
index 5a39fe3..db832f0 100644 (file)
@@ -1,21 +1,21 @@
 #include <torch/csrc/python_headers.h>
 
+#include <torch/csrc/jit/argument_spec.h>
+#include <torch/csrc/jit/export.h>
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/pybind.h>
 #include <torch/csrc/jit/python_tracer.h>
-#include <torch/csrc/utils/pybind.h>
-#include <torch/csrc/jit/export.h>
-#include <torch/csrc/jit/passes/shape_analysis.h>
-#include <torch/csrc/jit/passes/python_print.h>
-#include <torch/csrc/jit/argument_spec.h>
 #include <torch/csrc/utils/auto_gil.h>
+#include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/python_strings.h>
 
-
 #include <iostream>
 #include <sstream>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 using c10::Type;
 
@@ -28,7 +28,7 @@ std::string getPythonName(const PyObject* obj_) {
   return py::str(v);
 }
 
-std::ostream& printPyObject(std::ostream & out, const THPObjectPtr& obj) {
+std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) {
   AutoGIL gil;
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   auto pyobj = py::handle(const_cast<PyObject*>(obj.get()));
@@ -68,416 +68,486 @@ std::ostream& printPyObject(std::ostream & out, const THPObjectPtr& obj) {
   }
 }
 
-// execute a Python function, used for Ops we can't optimize but that we want to optimize around
+// execute a Python function, used for Ops we can't optimize but that we want to
+// optimize around
 struct ConcretePythonOp : public PythonOp {
- ConcretePythonOp(Graph * graph)
- : PythonOp(graph) {}
- std::string name() const override {
-   AutoGIL gil;
-   if(auto autograd = autogradFunction()) {
-     return getPythonName(autograd->get());
-   } else {
-     return getPythonName(pyobj.get());
-   }
- }
- void cloneFrom(Node * other_) override {
-   Node::cloneFrom(other_);
-   auto other = other_->cast<PythonOp>();
-   this->cconv = other->cconv;
-   Py_INCREF(other->pyobj.get());
-   this->pyobj = THPObjectPtr(other->pyobj.get());
-   for(auto & sa : other->scalar_args) {
-     Py_INCREF(sa.get());
-     this->scalar_args.emplace_back(sa.get());
-   }
- }
- Node * allocNewInstance(Graph * g) override {
-   return new ConcretePythonOp(g);
- }
- // recover the autograd.Function instance, if this PythonOp's function
- // was originally SomeFunction.apply
- // used in ONNX for discovering symbolics
- c10::optional<THPObjectPtr> autogradFunction() const override {
-   AutoGIL gil;
-   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-   py::handle obj = const_cast<PyObject*>(pyobj.get());
-
-   auto r = py::getattr(obj, "__self__", py::none());
-   if(r.is_none())
-     return c10::nullopt;
+  ConcretePythonOp(Graph* graph) : PythonOp(graph) {}
+  std::string name() const override {
+    AutoGIL gil;
+    if (auto autograd = autogradFunction()) {
+      return getPythonName(autograd->get());
+    } else {
+      return getPythonName(pyobj.get());
+    }
+  }
+  void cloneFrom(Node* other_) override {
+    Node::cloneFrom(other_);
+    auto other = other_->cast<PythonOp>();
+    this->cconv = other->cconv;
+    Py_INCREF(other->pyobj.get());
+    this->pyobj = THPObjectPtr(other->pyobj.get());
+    for (auto& sa : other->scalar_args) {
+      Py_INCREF(sa.get());
+      this->scalar_args.emplace_back(sa.get());
+    }
+  }
+  Node* allocNewInstance(Graph* g) override {
+    return new ConcretePythonOp(g);
+  }
+  // recover the autograd.Function instance, if this PythonOp's function
+  // was originally SomeFunction.apply
+  // used in ONNX for discovering symbolics
+  c10::optional<THPObjectPtr> autogradFunction() const override {
+    AutoGIL gil;
+    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+    py::handle obj = const_cast<PyObject*>(pyobj.get());
 
-   auto apply = py::getattr(r, "apply", py::none());
-   if(apply.is_none())
-     return c10::nullopt;
+    auto r = py::getattr(obj, "__self__", py::none());
+    if (r.is_none())
+      return c10::nullopt;
 
-   auto c = PyObject_RichCompareBool(apply.ptr(), obj.ptr(), Py_NE);
-   if(PyErr_Occurred())
-     throw py::error_already_set();
-   if(c)
-     return c10::nullopt;
+    auto apply = py::getattr(r, "apply", py::none());
+    if (apply.is_none())
+      return c10::nullopt;
 
-   return THPObjectPtr(r.release().ptr());
- }
+    auto c = PyObject_RichCompareBool(apply.ptr(), obj.ptr(), Py_NE);
+    if (PyErr_Occurred())
+      throw py::error_already_set();
+    if (c)
+      return c10::nullopt;
 
- void writeScalars(std::ostream& out) const override {
-   out << "(";
-   int i = 0;
-   for (auto& scalar : scalar_args) {
-     if (i++ > 0)
-       out << ", ";
-     printPyObject(out, scalar);
-   }
-   out << ")";
- }
+    return THPObjectPtr(r.release().ptr());
+  }
 
+  void writeScalars(std::ostream& out) const override {
+    out << "(";
+    int i = 0;
+    for (auto& scalar : scalar_args) {
+      if (i++ > 0)
+        out << ", ";
+      printPyObject(out, scalar);
+    }
+    out << ")";
+  }
 };
 
 PythonOp* pythonAllocPythonOp(Graph* g) {
   return new ConcretePythonOp(g);
 }
 
-void initPythonIRBindings(PyObject * module_) {
+void initPythonIRBindings(PyObject* module_) {
   setAllocPythonOp(pythonAllocPythonOp);
 
   auto m = py::handle(module_).cast<py::module>();
-  #define GS(name) \
-    def(#name,&Graph :: name)
-  py::class_<Graph,std::shared_ptr<Graph>>(m,"Graph")
-    .def(py::init<>())
-    .def("__repr__",[](Graph & g) {
-      std::stringstream ss;
-      ss << g;
-      return ss.str();
-    })
-    .def("propagate_shapes", [](std::shared_ptr<Graph> g, std::vector<at::Tensor> inputs, bool with_grad) {
-      setInputTypes(*g, ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
-      PropagateInputShapes(g);
-    })
-    .def("_export_onnx", [](const std::shared_ptr<Graph> g, const std::vector<at::Tensor>& initializers,
-                      int64_t onnx_opset_version, bool defer_weight_export,
-                      ::torch::onnx::OperatorExportTypes operator_export_type) {
-      std::string graph;
-      RawDataExportMap export_map;
-      std::tie(graph, export_map) = export_onnx(
-        g, initializers, onnx_opset_version, defer_weight_export, operator_export_type);
-      std::unordered_map<std::string, py::bytes> python_serialized_export_map;
-      for (auto& kv : export_map) {
-        auto t = kv.second;
-        size_t copy_bytes = t.type().elementSizeInBytes() * t.numel();
-        // TODO: this is an unecessary copy. In theory we can directly return
-        // the map from identifier to Tensor, but we need some API in Python
-        // to get raw `bytes` containing the raw tensor data.
-        python_serialized_export_map[kv.first] = py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
-      }
-      return std::make_tuple(py::bytes(graph), python_serialized_export_map);
-    }, py::arg("initializers"),
-       py::arg("onnx_opset_version")=0,
-       py::arg("defer_weight_export")=false,
-       py::arg("operator_export_type")=::torch::onnx::OperatorExportTypes::ONNX)
-    .def("_pretty_print_onnx", [](const std::shared_ptr<Graph> g,
-          const std::vector<at::Tensor>& initializers,
-          int64_t onnx_opset_version, bool defer_weight_export,
-          ::torch::onnx::OperatorExportTypes operator_export_type,
-          bool google_printer) {
-      return pretty_print_onnx(
-        g, initializers, onnx_opset_version, defer_weight_export, operator_export_type,
-        google_printer);
-    }, py::arg("initializers"),
-       py::arg("onnx_opset_version")=0,
-       py::arg("defer_weight_export")=false,
-       py::arg("operator_export_type")=::torch::onnx::OperatorExportTypes::ONNX,
-       py::arg("google_printer")=false)
-    .def("inputs",[](Graph &g) {
-      return py::make_iterator(g.inputs().begin(), g.inputs().end());
-    })
-    .def("outputs",[](Graph &g) {
-      return py::make_iterator(g.outputs().begin(), g.outputs().end());
-    })
-    // TODO: Iterator invalidation might make this hazardous
-    .def("nodes",[](Graph &g) {
-      return py::make_iterator(g.nodes().begin(), g.nodes().end());
-    })
-    .def("addInput",[](Graph &g) { return g.addInput(); })
-    .def("copy",[](Graph &g) {
-      return g.copy();
-    })
-    .GS(eraseInput)
-    .GS(registerOutput)
-    .def("create",[](Graph & g, const char * str) {
-      return g.create(Symbol::fromQualString(str));
-    })
-    .def("create",[](Graph & g, const char * str, size_t noutputs) {
-      return g.create(Symbol::fromQualString(str), noutputs);
-    })
-    .def("create",[](Graph & g, const char * str, const std::vector<Value*> & inputs) {
-      return g.create(Symbol::fromQualString(str),inputs);
-    })
-    .def("create",[](Graph & g, const char * str, const std::vector<Value*> & inputs, size_t noutputs) {
-      return g.create(Symbol::fromQualString(str),inputs, noutputs);
-    })
-    .def("param_node", [](Graph &g) {
-      return g.block()->param_node();
-    })
-    .def("return_node", [](Graph &g) {
-      return g.block()->return_node();
-    })
-    .def("pretty_print", [](Graph &g) {
-      std::ostringstream oss;
-      g.prettyPrint(oss);
-      return oss.str();
-    })
-    .GS(createFusionGroup)
-    .def("createClone",[](Graph & g, Node * n, py::object fn) {
-      return g.createClone(n, [&](Value * e) {
-        return fn(e).cast<Value*>();
-      });
-    })
-    .GS(appendNode)
-    .GS(prependNode)
-    .GS(lint)
-    .GS(insertNode)
-    ;
-    #undef GS
+#define GS(name) def(#name, &Graph ::name)
+  py::class_<Graph, std::shared_ptr<Graph>>(m, "Graph")
+      .def(py::init<>())
+      .def(
+          "__repr__",
+          [](Graph& g) {
+            std::stringstream ss;
+            ss << g;
+            return ss.str();
+          })
+      .def(
+          "propagate_shapes",
+          [](std::shared_ptr<Graph> g,
+             std::vector<at::Tensor> inputs,
+             bool with_grad) {
+            setInputTypes(
+                *g,
+                ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
+            PropagateInputShapes(g);
+          })
+      .def(
+          "_export_onnx",
+          [](const std::shared_ptr<Graph> g,
+             const std::vector<at::Tensor>& initializers,
+             int64_t onnx_opset_version,
+             bool defer_weight_export,
+             ::torch::onnx::OperatorExportTypes operator_export_type) {
+            std::string graph;
+            RawDataExportMap export_map;
+            std::tie(graph, export_map) = export_onnx(
+                g,
+                initializers,
+                onnx_opset_version,
+                defer_weight_export,
+                operator_export_type);
+            std::unordered_map<std::string, py::bytes>
+                python_serialized_export_map;
+            for (auto& kv : export_map) {
+              auto t = kv.second;
+              size_t copy_bytes = t.type().elementSizeInBytes() * t.numel();
+              // TODO: this is an unecessary copy. In theory we can directly
+              // return the map from identifier to Tensor, but we need some API
+              // in Python to get raw `bytes` containing the raw tensor data.
+              python_serialized_export_map[kv.first] =
+                  py::bytes(static_cast<const char*>(t.data_ptr()), copy_bytes);
+            }
+            return std::make_tuple(
+                py::bytes(graph), python_serialized_export_map);
+          },
+          py::arg("initializers"),
+          py::arg("onnx_opset_version") = 0,
+          py::arg("defer_weight_export") = false,
+          py::arg("operator_export_type") =
+              ::torch::onnx::OperatorExportTypes::ONNX)
+      .def(
+          "_pretty_print_onnx",
+          [](const std::shared_ptr<Graph> g,
+             const std::vector<at::Tensor>& initializers,
+             int64_t onnx_opset_version,
+             bool defer_weight_export,
+             ::torch::onnx::OperatorExportTypes operator_export_type,
+             bool google_printer) {
+            return pretty_print_onnx(
+                g,
+                initializers,
+                onnx_opset_version,
+                defer_weight_export,
+                operator_export_type,
+                google_printer);
+          },
+          py::arg("initializers"),
+          py::arg("onnx_opset_version") = 0,
+          py::arg("defer_weight_export") = false,
+          py::arg("operator_export_type") =
+              ::torch::onnx::OperatorExportTypes::ONNX,
+          py::arg("google_printer") = false)
+      .def(
+          "inputs",
+          [](Graph& g) {
+            return py::make_iterator(g.inputs().begin(), g.inputs().end());
+          })
+      .def(
+          "outputs",
+          [](Graph& g) {
+            return py::make_iterator(g.outputs().begin(), g.outputs().end());
+          })
+      // TODO: Iterator invalidation might make this hazardous
+      .def(
+          "nodes",
+          [](Graph& g) {
+            return py::make_iterator(g.nodes().begin(), g.nodes().end());
+          })
+      .def("addInput", [](Graph& g) { return g.addInput(); })
+      .def("copy", [](Graph& g) { return g.copy(); })
+      .GS(eraseInput)
+      .GS(registerOutput)
+      .def(
+          "create",
+          [](Graph& g, const char* str) {
+            return g.create(Symbol::fromQualString(str));
+          })
+      .def(
+          "create",
+          [](Graph& g, const char* str, size_t noutputs) {
+            return g.create(Symbol::fromQualString(str), noutputs);
+          })
+      .def(
+          "create",
+          [](Graph& g, const char* str, const std::vector<Value*>& inputs) {
+            return g.create(Symbol::fromQualString(str), inputs);
+          })
+      .def(
+          "create",
+          [](Graph& g,
+             const char* str,
+             const std::vector<Value*>& inputs,
+             size_t noutputs) {
+            return g.create(Symbol::fromQualString(str), inputs, noutputs);
+          })
+      .def("param_node", [](Graph& g) { return g.block()->param_node(); })
+      .def("return_node", [](Graph& g) { return g.block()->return_node(); })
+      .def(
+          "pretty_print",
+          [](Graph& g) {
+            std::ostringstream oss;
+            g.prettyPrint(oss);
+            return oss.str();
+          })
+      .GS(createFusionGroup)
+      .def(
+          "createClone",
+          [](Graph& g, Node* n, py::object fn) {
+            return g.createClone(
+                n, [&](Value* e) { return fn(e).cast<Value*>(); });
+          })
+      .GS(appendNode)
+      .GS(prependNode)
+      .GS(lint)
+      .GS(insertNode);
+#undef GS
 
-  #define VS(name) \
-    def(#name,&Value :: name)
-  py::class_<Value,std::unique_ptr<Value, py::nodelete>>(m,"Value")
-    .def("__repr__",[](Value & n) {
-      std::stringstream ss;
-      ss << n.uniqueName() << " defined in (" << *n.node() << ")";
-      return ss.str();
-    })
-    .VS(type)
-    .VS(setType)
-    .VS(inferTypeFrom)
-    // skip owningGraph because it returns a raw pointer to a otherwise
-    // std::shared_ptr stored graph object, and would cause a double free
-    .VS(unique)
-    .VS(uniqueName)
-    .VS(setUniqueName)
-    .VS(offset)
-    .VS(uses)
-    .VS(replaceAllUsesWith)
-    .def("node",[](Value &v) { return v.node(); })
-    .def("setTypeAs", [](Value * node, Value * other) {
-      node->setType(other->type());
-      return node;
-    })
-    .VS(copyMetadata)
-    .VS(isTensor)
-    ;
+#define VS(name) def(#name, &Value ::name)
+  py::class_<Value, std::unique_ptr<Value, py::nodelete>>(m, "Value")
+      .def(
+          "__repr__",
+          [](Value& n) {
+            std::stringstream ss;
+            ss << n.uniqueName() << " defined in (" << *n.node() << ")";
+            return ss.str();
+          })
+      .VS(type)
+      .VS(setType)
+      .VS(inferTypeFrom)
+      // skip owningGraph because it returns a raw pointer to a otherwise
+      // std::shared_ptr stored graph object, and would cause a double free
+      .VS(unique)
+      .VS(uniqueName)
+      .VS(setUniqueName)
+      .VS(offset)
+      .VS(uses)
+      .VS(replaceAllUsesWith)
+      .def("node", [](Value& v) { return v.node(); })
+      .def(
+          "setTypeAs",
+          [](Value* node, Value* other) {
+            node->setType(other->type());
+            return node;
+          })
+      .VS(copyMetadata)
+      .VS(isTensor);
 
-  #undef VS
+#undef VS
 
   py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m, "Block")
-    .def("nodes",[](Block &b) {
-      return py::make_iterator(b.nodes().begin(), b.nodes().end());
-    });
+      .def("nodes", [](Block& b) {
+        return py::make_iterator(b.nodes().begin(), b.nodes().end());
+      });
 
-  #define NS(name) \
-    def(#name,&Node :: name)
-  py::class_<Node,std::unique_ptr<Node, py::nodelete>>(m,"Node")
-    .def("__repr__",[](Node & n) {
-      std::stringstream ss;
-      ss << n;
-      return ss.str();
-    })
-    .def("getSourceLocation", [](Node & n) -> py::object {
-      std::stringstream ss;
-      if (auto sl = n.getSourceLocation()) {
-        sl->highlight(ss);
-        return py::str(ss.str());
-      } else {
-        return py::none();
-      }
-    })
-    .def("hasMultipleOutputs",[](Node&n) {
-      return n.outputs().size() > 1;
-    })
-    .def("outputsSize",[](Node &n) {
-      return n.outputs().size();
-    })
-    .NS(kind)
-    .def("inputs",[](Node &n) {
-      return py::make_iterator(n.inputs().begin(), n.inputs().end());
-    })
-    .def("outputs",[](Node &n) {
-      return py::make_iterator(n.outputs().begin(), n.outputs().end());
-    })
-    .def("output", [](Node &n) {
-      return n.output();
-    })
-    .NS(addInput)
-    .NS(replaceInput)
-    .NS(replaceInputWith)
-    .NS(replaceAllUsesWith)
-    .NS(insertBefore)
-    .NS(insertAfter)
-    .NS(moveAfter)
-    .NS(moveBefore)
-    .NS(removeInput)
-    .NS(removeAllInputs)
-    .NS(destroy)
-    .NS(hasUses)
-    .NS(eraseOutput)
-    .NS(addOutput)
-    .NS(scopeName)
-    .NS(isNondeterministic)
-    .def("blocks", [](Node& n) {
-      return py::make_iterator(n.blocks().begin(), n.blocks().end());
-    })
-    .NS(addBlock)
+#define NS(name) def(#name, &Node ::name)
+  py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
+      .def(
+          "__repr__",
+          [](Node& n) {
+            std::stringstream ss;
+            ss << n;
+            return ss.str();
+          })
+      .def(
+          "getSourceLocation",
+          [](Node& n) -> py::object {
+            std::stringstream ss;
+            if (auto sl = n.getSourceLocation()) {
+              sl->highlight(ss);
+              return py::str(ss.str());
+            } else {
+              return py::none();
+            }
+          })
+      .def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; })
+      .def("outputsSize", [](Node& n) { return n.outputs().size(); })
+      .NS(kind)
+      .def(
+          "inputs",
+          [](Node& n) {
+            return py::make_iterator(n.inputs().begin(), n.inputs().end());
+          })
+      .def(
+          "outputs",
+          [](Node& n) {
+            return py::make_iterator(n.outputs().begin(), n.outputs().end());
+          })
+      .def("output", [](Node& n) { return n.output(); })
+      .NS(addInput)
+      .NS(replaceInput)
+      .NS(replaceInputWith)
+      .NS(replaceAllUsesWith)
+      .NS(insertBefore)
+      .NS(insertAfter)
+      .NS(moveAfter)
+      .NS(moveBefore)
+      .NS(removeInput)
+      .NS(removeAllInputs)
+      .NS(destroy)
+      .NS(hasUses)
+      .NS(eraseOutput)
+      .NS(addOutput)
+      .NS(scopeName)
+      .NS(isNondeterministic)
+      .def(
+          "blocks",
+          [](Node& n) {
+            return py::make_iterator(n.blocks().begin(), n.blocks().end());
+          })
+      .NS(addBlock)
 
-#define AS(name) def(#name,&Attributes<Node> :: name)
-    // methods from Attributes
-    .AS(copyAttributes)
-    .AS(hasAttributes)
+#define AS(name) def(#name, &Attributes<Node>::name)
+      // methods from Attributes
+      .AS(copyAttributes)
+      .AS(hasAttributes)
 #undef AS
-#define AS(name) def(#name,&Attributes<Node> :: name ## S)
-    // The default method names take Symbol, but the string conversion for
-    // Symbol you to qualify with attr::. This is not very user friendly
-    // for attributes, so expose the string variants instead.
-    .AS(hasAttribute)
-    .AS(kindOf)
-    .AS(removeAttribute)
-    .AS(attributeNames)
+#define AS(name) def(#name, &Attributes<Node>::name##S)
+      // The default method names take Symbol, but the string conversion for
+      // Symbol you to qualify with attr::. This is not very user friendly
+      // for attributes, so expose the string variants instead.
+      .AS(hasAttribute)
+      .AS(kindOf)
+      .AS(removeAttribute)
+      .AS(attributeNames)
 #undef AS
-#define CREATE_ACCESSOR(Kind,method) \
-    def(#method "_",[](Node & n, const char * name, Kind##Attr::ValueType v) { \
-      return n . method ## _(Symbol::attr(name), std::move(v)); \
-    }) \
-    .def(#method, [](Node & n, const char * name) { \
-      return n.method(Symbol::attr(name)); \
-    })
-    .CREATE_ACCESSOR(Float,f)
-    .CREATE_ACCESSOR(Floats,fs)
-    .CREATE_ACCESSOR(String,s)
-    .CREATE_ACCESSOR(Strings,ss)
-    .CREATE_ACCESSOR(Int,i)
-    .CREATE_ACCESSOR(Ints,is)
-    .CREATE_ACCESSOR(Graph,g)
-    .CREATE_ACCESSOR(Graphs,gs)
+#define CREATE_ACCESSOR(Kind, method)                          \
+  def(#method "_",                                             \
+      [](Node& n, const char* name, Kind##Attr::ValueType v) { \
+        return n.method##_(Symbol::attr(name), std::move(v));  \
+      })                                                       \
+      .def(#method, [](Node& n, const char* name) {            \
+        return n.method(Symbol::attr(name));                   \
+      })
+      .CREATE_ACCESSOR(Float, f)
+      .CREATE_ACCESSOR(Floats, fs)
+      .CREATE_ACCESSOR(String, s)
+      .CREATE_ACCESSOR(Strings, ss)
+      .CREATE_ACCESSOR(Int, i)
+      .CREATE_ACCESSOR(Ints, is)
+      .CREATE_ACCESSOR(Graph, g)
+      .CREATE_ACCESSOR(Graphs, gs)
 #undef CREATE_ACCESSOR
-    // Tensor (t_) -- manually written to unwrap the variable into a tensor.
-    .def("t_",[](Node & n, const char * name, torch::autograd::Variable v) {
-      return n.t_(Symbol::attr(name), std::move(v.data()));
-    })
-    .def("t", [](Node & n, const char * name) {
-      return torch::autograd::make_variable(n.t(Symbol::attr(name)), /*requires_grad=*/false);
-    })
-    // Tensors (ts_) -- manually written to unwrap variables into tensors.
-    .def("ts_",[](Node & n, const char * name, std::vector<torch::autograd::Variable> vs) {
-      std::vector<at::Tensor> tensors;
-      tensors.reserve(vs.size());
-      for (auto& variable : vs) {
-        tensors.push_back(std::move(variable.data()));
-      }
-      return n.ts_(Symbol::attr(name), std::move(tensors));
-    })
-    .def("ts", [](Node & n, const char * name) {
-      auto tensors = n.ts(Symbol::attr(name));
-      std::vector<torch::autograd::Variable> variables;
-      variables.reserve(tensors.size());
-      for (auto& tensor : tensors) {
-        variables.push_back(torch::autograd::make_variable(
-            std::move(tensor), /*requires_grad=*/false));
-      }
-      return variables;
-    })
-    .def("z_",[](Node & n, const char * name, at::Tensor v) {
-        return n.t_(Symbol::attr(name), autograd::Variable(v.view({})).data());
-    })
-    .def("z",[](Node & n, const char * name) {
-        return n.t(Symbol::attr(name));
-    })
-    .def("zs_",[](Node & n, const char * name, TensorsAttr::ValueType v) {
-        for (auto& i : v) {
-          i = autograd::Variable(i.view({})).data();
+      // Tensor (t_) -- manually written to unwrap the variable into a tensor.
+      .def(
+          "t_",
+          [](Node& n, const char* name, torch::autograd::Variable v) {
+            return n.t_(Symbol::attr(name), std::move(v.data()));
+          })
+      .def(
+          "t",
+          [](Node& n, const char* name) {
+            return torch::autograd::make_variable(
+                n.t(Symbol::attr(name)), /*requires_grad=*/false);
+          })
+      // Tensors (ts_) -- manually written to unwrap variables into tensors.
+      .def(
+          "ts_",
+          [](Node& n,
+             const char* name,
+             std::vector<torch::autograd::Variable> vs) {
+            std::vector<at::Tensor> tensors;
+            tensors.reserve(vs.size());
+            for (auto& variable : vs) {
+              tensors.push_back(std::move(variable.data()));
+            }
+            return n.ts_(Symbol::attr(name), std::move(tensors));
+          })
+      .def(
+          "ts",
+          [](Node& n, const char* name) {
+            auto tensors = n.ts(Symbol::attr(name));
+            std::vector<torch::autograd::Variable> variables;
+            variables.reserve(tensors.size());
+            for (auto& tensor : tensors) {
+              variables.push_back(torch::autograd::make_variable(
+                  std::move(tensor), /*requires_grad=*/false));
+            }
+            return variables;
+          })
+      .def(
+          "z_",
+          [](Node& n, const char* name, at::Tensor v) {
+            return n.t_(
+                Symbol::attr(name), autograd::Variable(v.view({})).data());
+          })
+      .def(
+          "z",
+          [](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
+      .def(
+          "zs_",
+          [](Node& n, const char* name, TensorsAttr::ValueType v) {
+            for (auto& i : v) {
+              i = autograd::Variable(i.view({})).data();
+            }
+            return n.ts_(Symbol::attr(name), std::move(v));
+          })
+      .def(
+          "zs",
+          [](Node& n, const char* name) { return n.ts(Symbol::attr(name)); })
+      .def(
+          "pyobj",
+          [](Node& n) {
+            return py::handle(n.expect<PythonOp>()->pyobj.get())
+                .cast<py::object>();
+          })
+      .def("cconv", [](Node& n) { return n.expect<PythonOp>()->cconv; })
+      .def("pyname", [](Node& n) { return n.expect<PythonOp>()->name(); })
+      .def("scalar_args", [](Node& n) {
+        auto op = n.expect<PythonOp>();
+        auto scalars = py::list();
+        auto append = scalars.attr("append");
+        for (auto& arg : op->scalar_args) {
+          append(py::handle(arg.get()));
         }
-        return n.ts_(Symbol::attr(name), std::move(v));
-    })
-    .def("zs",[](Node & n, const char * name) {
-        return n.ts(Symbol::attr(name));
-    })
-    .def("pyobj",[](Node & n) {
-      return py::handle(n.expect<PythonOp>()->pyobj.get()).cast<py::object>();
-    })
-    .def("cconv",[](Node & n) {
-      return n.expect<PythonOp>()->cconv;
-    })
-    .def("pyname",[](Node & n) {
-      return n.expect<PythonOp>()->name();
-    })
-    .def("scalar_args",[](Node & n) {
-      auto op = n.expect<PythonOp>();
-      auto scalars = py::list();
-      auto append = scalars.attr("append");
-      for(auto & arg : op->scalar_args) {
-        append(py::handle(arg.get()));
-      }
-      return scalars;
-    })
-    ;
+        return scalars;
+      });
 
   using ::c10::Type;
-  py::class_<Type,std::shared_ptr<Type>>(m,"Type")
-    .def("__repr__",[](Type & t) {
-      return t.python_str();
-    })
-    .def("str",[](Type & t) {
-      std::ostringstream s;
-      s << t;
-      return s.str();
-    })
-    .def("kind",[](const Type& t) {
-      return typeKindToString(t.kind());
-    })
-    .def("sizes",[](Type& t) {
-      return t.expect<CompleteTensorType>()->sizes();
-    })
-    .def("strides",[](Type& t) {
-      return t.expect<CompleteTensorType>()->strides();
-    })
-    .def("contiguous",[](Type& t) {
-      return std::static_pointer_cast<Type>(t.expect<CompleteTensorType>()->contiguous());
-    })
-    .def("scalarType",[](Type& t) {
-      return toString(t.expect<TensorType>()->scalarType());
-    })
-    .def("__eq__", [](std::shared_ptr<Type>& self, std::shared_ptr<Type>& other) {
-                 return *self == *other;
-    })
-    .def("isSubtypeOf", [](std::shared_ptr<Type>& self, std::shared_ptr<Type> other) {
-        return self->isSubtypeOf(other);
-    });
+  py::class_<Type, std::shared_ptr<Type>>(m, "Type")
+      .def("__repr__", [](Type& t) { return t.python_str(); })
+      .def(
+          "str",
+          [](Type& t) {
+            std::ostringstream s;
+            s << t;
+            return s.str();
+          })
+      .def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
+      .def(
+          "sizes",
+          [](Type& t) { return t.expect<CompleteTensorType>()->sizes(); })
+      .def(
+          "strides",
+          [](Type& t) { return t.expect<CompleteTensorType>()->strides(); })
+      .def(
+          "contiguous",
+          [](Type& t) {
+            return std::static_pointer_cast<Type>(
+                t.expect<CompleteTensorType>()->contiguous());
+          })
+      .def(
+          "scalarType",
+          [](Type& t) {
+            return toString(t.expect<TensorType>()->scalarType());
+          })
+      .def(
+          "__eq__",
+          [](std::shared_ptr<Type>& self, std::shared_ptr<Type>& other) {
+            return *self == *other;
+          })
+      .def(
+          "isSubtypeOf",
+          [](std::shared_ptr<Type>& self, std::shared_ptr<Type> other) {
+            return self->isSubtypeOf(other);
+          });
 
   py::class_<NumberType, Type, std::shared_ptr<NumberType>>(m, "NumberType")
-    .def_static("get", &NumberType::get);
+      .def_static("get", &NumberType::get);
   py::class_<IntType, Type, std::shared_ptr<IntType>>(m, "IntType")
-    .def_static("get", &IntType::get);
+      .def_static("get", &IntType::get);
   py::class_<FloatType, Type, std::shared_ptr<FloatType>>(m, "FloatType")
-    .def_static("get", &FloatType::get);
+      .def_static("get", &FloatType::get);
   py::class_<DynamicType, Type, std::shared_ptr<DynamicType>>(m, "DynamicType")
-    .def_static("get", &DynamicType::get);
+      .def_static("get", &DynamicType::get);
   py::class_<BoolType, Type, std::shared_ptr<BoolType>>(m, "BoolType")
-    .def_static("get", &BoolType::get);
+      .def_static("get", &BoolType::get);
 
   py::class_<TupleType, Type, std::shared_ptr<TupleType>>(m, "TupleType")
-    .def(py::init([](std::vector<TypePtr> a){ return TupleType::create(a); }))
-    .def("elements", [](TupleType &self){
-      std::vector<TypePtr> types;
-      for (const auto& type : self.elements()) {
-        types.push_back(type);
-      }
-      return types;
-    });
+      .def(
+          py::init([](std::vector<TypePtr> a) { return TupleType::create(a); }))
+      .def("elements", [](TupleType& self) {
+        std::vector<TypePtr> types;
+        for (const auto& type : self.elements()) {
+          types.push_back(type);
+        }
+        return types;
+      });
   py::class_<ListType, Type, std::shared_ptr<ListType>>(m, "ListType")
-    .def_static("ofInts", &ListType::ofInts)
-    .def_static("ofTensors", &ListType::ofTensors)
-    .def("getElementType", &ListType::getElementType);
+      .def_static("ofInts", &ListType::ofInts)
+      .def_static("ofTensors", &ListType::ofTensors)
+      .def("getElementType", &ListType::getElementType);
 
-  py::class_<Use>(m,"Use")
-  .def_readonly("user",&Use::user)
-  .def_readonly("offset",&Use::offset);
+  py::class_<Use>(m, "Use")
+      .def_readonly("user", &Use::user)
+      .def_readonly("offset", &Use::offset);
 }
-}}
+} // namespace jit
+} // namespace torch
index 7f2098c..18ad1ed 100644 (file)
@@ -2,8 +2,10 @@
 
 #include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 void initPythonIRBindings(PyObject* module);
 
-}}
+}
+} // namespace torch
index 30210a1..aa56f4a 100644 (file)
@@ -1,12 +1,12 @@
 #include <torch/csrc/python_headers.h>
 
-#include <torch/csrc/jit/python_tracer.h>
-#include <torch/csrc/jit/tracer.h>
 #include <torch/csrc/jit/export.h>
-#include <torch/csrc/jit/pybind.h>
-#include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/pybind.h>
+#include <torch/csrc/jit/python_tracer.h>
+#include <torch/csrc/jit/tracer.h>
+#include <torch/csrc/utils/python_strings.h>
 
 #include <c10/util/Exception.h>
 
@@ -16,16 +16,16 @@ using namespace torch::autograd;
 using namespace torch::jit;
 using namespace torch::jit::tracer;
 
-
-namespace torch { namespace jit { namespace tracer {
-
+namespace torch {
+namespace jit {
+namespace tracer {
 
 // Python interpreter retrieval routine adapted from
 // https://stackoverflow.com/a/8706144
 std::string getPythonInterpreterStackTrace() {
   std::stringstream stack_trace;
   AutoGIL gil;
-  PyFrameObject *frame = PyEval_GetFrame();
+  PyFrameObjectframe = PyEval_GetFrame();
   while (nullptr != frame) {
     int line = PyCode_Addr2Line(frame->f_code, frame->f_lasti);
     std::string filename = THPUtils_unpackString(frame->f_code->co_filename);
@@ -44,21 +44,22 @@ std::shared_ptr<torch::jit::Graph> createGraphByTracing(
     const c10::optional<size_t>& num_real_inputs) {
   size_t num_func_inputs = num_real_inputs.value_or(trace_inputs.size());
   auto enter_info = tracer::enter(std::move(trace_inputs));
-  getTracingState()->lookup_var_name_fn = [var_name_lookup_fn](const Variable& var) -> std::string {
+  getTracingState()->lookup_var_name_fn =
+      [var_name_lookup_fn](const Variable& var) -> std::string {
     AutoGIL ag;
     return py::cast<std::string>(var_name_lookup_fn(var));
   };
   getTracingState()->force_outplace = force_outplace;
   try {
-
     py::tuple py_inputs(num_func_inputs);
-    for(size_t i = 0; i < num_func_inputs; ++i) {
+    for (size_t i = 0; i < num_func_inputs; ++i) {
       py_inputs[i] = py::cast(enter_info.second[i]);
     }
     auto out = func(*py_inputs);
     if (out.ptr() == Py_None) {
-      AT_ERROR("The traced function didn't return any values! Side-effects are not "
-               "captured in traces, so it would be a no-op.");
+      AT_ERROR(
+          "The traced function didn't return any values! Side-effects are not "
+          "captured in traces, so it would be a no-op.");
     }
     tracer::exit({toIValue(out)});
     auto graph = enter_info.first->graph;
@@ -72,22 +73,23 @@ std::shared_ptr<torch::jit::Graph> createGraphByTracing(
   }
 }
 
-Node* preRecordPythonTrace(THPObjectPtr pyobj,
-                                  const std::string& arg_types,
-                                  at::ArrayRef<Variable> inputs,
-                                  pyobj_list scalar_args) {
+Node* preRecordPythonTrace(
+    THPObjectPtr pyobj,
+    const std::string& arg_types,
+    at::ArrayRef<Variable> inputs,
+    pyobj_list scalar_args) {
   THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(), "apply"));
-  if(!apply) {
+  if (!apply) {
     throw python_error();
   }
 
-  auto & graph = getTracingState()->graph;
+  auto& graph = getTracingState()->graph;
 
   Node* n = graph->createPythonOp(
       std::move(apply), arg_types, std::move(scalar_args));
   recordSourceLocation(n);
 
-  for (const Variable & input : inputs) {
+  for (const Variable& input : inputs) {
     n->addInput(getValueTrace(input));
   }
 
@@ -98,7 +100,8 @@ Node* preRecordPythonTrace(THPObjectPtr pyobj,
 }
 
 void pythonRecordSourceLocation(Node* n) {
-  auto sl = std::make_shared<StringSourceLocation>(getPythonInterpreterStackTrace());
+  auto sl =
+      std::make_shared<StringSourceLocation>(getPythonInterpreterStackTrace());
   n->setSourceLocation(sl);
 }
 
@@ -112,46 +115,43 @@ void initPythonTracerBindings(PyObject* module) {
   setRecordSourceLocation(pythonRecordSourceLocation);
 
   auto m = py::handle(module).cast<py::module>();
-  py::class_<TracingState,std::shared_ptr<TracingState>>(m, "TracingState", py::dynamic_attr())
-    // NB: no constructor; you have to get it from C++ code
-    .def("__repr__", [](const TracingState& s) {
-      std::ostringstream ss;
-      ss << "<TracingState " << (const void*)&s << ">";
-      return ss.str();
-    })
-    .def("__str__", [](const TracingState& s) -> std::string {
-      std::ostringstream ss;
-      ss << *s.graph;
-      return ss.str();
-    })
-    .def("push_scope", [](TracingState& s, const std::string& scope_name) {
-      s.graph->push_scope(scope_name);
-    })
-    .def("pop_scope", [](TracingState& s) {
-      s.graph->pop_scope();
-    })
-    .def("set_graph", [](TracingState& s, std::shared_ptr<Graph> g) {
-      s.graph = g;
-    })
-    .def("graph", [](TracingState& s) {
-      return s.graph;
-    });
-
-  m.def("_tracer_warn_use_python", []() {
-    tracer::setWarn(pythonWarn);
-  });
+  py::class_<TracingState, std::shared_ptr<TracingState>>(
+      m, "TracingState", py::dynamic_attr())
+      // NB: no constructor; you have to get it from C++ code
+      .def(
+          "__repr__",
+          [](const TracingState& s) {
+            std::ostringstream ss;
+            ss << "<TracingState " << (const void*)&s << ">";
+            return ss.str();
+          })
+      .def(
+          "__str__",
+          [](const TracingState& s) -> std::string {
+            std::ostringstream ss;
+            ss << *s.graph;
+            return ss.str();
+          })
+      .def(
+          "push_scope",
+          [](TracingState& s, const std::string& scope_name) {
+            s.graph->push_scope(scope_name);
+          })
+      .def("pop_scope", [](TracingState& s) { s.graph->pop_scope(); })
+      .def(
+          "set_graph",
+          [](TracingState& s, std::shared_ptr<Graph> g) { s.graph = g; })
+      .def("graph", [](TracingState& s) { return s.graph; });
+
+  m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
   m.def("_tracer_enter", [](py::args trace_inputs) {
     return tracer::enter(toStack(trace_inputs));
   });
   m.def("_tracer_exit", [](py::tuple var_outputs) {
     tracer::exit(toStack(var_outputs));
   });
-  m.def("_tracer_abandon", []() {
-    tracer::abandon();
-  });
-  m.def("_get_tracing_state", []() {
-    return getTracingState();
-  });
+  m.def("_tracer_abandon", []() { tracer::abandon(); });
+  m.def("_get_tracing_state", []() { return getTracingState(); });
   m.def("_set_tracing_state", [](std::shared_ptr<TracingState> state) {
     return setTracingState(state);
   });
@@ -164,7 +164,8 @@ void initPythonTracerBindings(PyObject* module) {
   m.def("_tracer_set_get_unique_name_fn", [](py::function func) {
     const auto& tracing_state = getTracingState();
     JIT_ASSERT(tracing_state);
-    tracing_state->lookup_var_name_fn = [func](const Variable& var) -> std::string {
+    tracing_state->lookup_var_name_fn =
+        [func](const Variable& var) -> std::string {
       AutoGIL ag;
       return py::cast<std::string>(func(var));
     };
@@ -176,4 +177,6 @@ void initPythonTracerBindings(PyObject* module) {
   });
 }
 
-}}} // namespace torch::jit::tracing
+} // namespace tracer
+} // namespace jit
+} // namespace torch
index 1bb52bb..ff68499 100644 (file)
@@ -1,15 +1,16 @@
 #pragma once
 
-#include <torch/csrc/python_headers.h>
 #include <torch/csrc/jit/tracer.h>
+#include <torch/csrc/python_headers.h>
 #include <torch/csrc/utils/pybind.h>
 
 #include <memory>
 #include <string>
 
-namespace torch { namespace jit { namespace tracer {
-void initPythonTracerBindings(PyObject *module);
-
+namespace torch {
+namespace jit {
+namespace tracer {
+void initPythonTracerBindings(PyObject* module);
 
 std::string getPythonInterpreterStackTrace();
 Node* preRecordPythonTrace(
@@ -25,4 +26,5 @@ std::shared_ptr<Graph> createGraphByTracing(
     bool force_outplace,
     const c10::optional<size_t>& num_real_inputs = c10::nullopt);
 } // namespace tracer
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index e9ebd1f..b4df638 100644 (file)
@@ -3,11 +3,11 @@
 #include <torch/csrc/autograd/generated/variable_factories.h>
 #include <torch/csrc/autograd/profiler.h>
 #include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/fuser/interface.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/script/jit_exception.h>
 
 #include <ATen/ExpandUtils.h>
@@ -42,14 +42,19 @@ Operation noop(const Node* n) {
 // and if the dest is an int the source must be integral type
 void checkImplicitTensorToNum(at::Tensor t, bool toInt) {
   if (autograd::as_variable_ref(t).requires_grad()) {
-    throw std::runtime_error("Cannot input a tensor that requires grad as a scalar argument");
+    throw std::runtime_error(
+        "Cannot input a tensor that requires grad as a scalar argument");
   }
   if (t.sizes().size() != 0) {
-    throw std::runtime_error("Cannot input a tensor of dimension other than 0 as a scalar argument");
+    throw std::runtime_error(
+        "Cannot input a tensor of dimension other than 0 as a scalar argument");
   }
-  if (toInt && !isIntegralType(autograd::as_variable_ref(t).data().type().scalarType())) {
+  if (toInt &&
+      !isIntegralType(
+          autograd::as_variable_ref(t).data().type().scalarType())) {
     std::stringstream ss;
-    ss << "Cannot input a tensor of type " << t.type().scalarType() << " as an integral argument";
+    ss << "Cannot input a tensor of type " << t.type().scalarType()
+       << " as an integral argument";
     throw std::runtime_error(ss.str());
   }
 }
@@ -88,21 +93,21 @@ RegisterOperators reg({
     Operator(
         "prim::Float(Tensor a) -> float",
         [](const Node* node) -> Operation {
-            return [](Stack& stack) {
-              at::Tensor a;
-              pop(stack, a);
-              push(stack, a.item<double>());
-              return 0;
-            };
+          return [](Stack& stack) {
+            at::Tensor a;
+            pop(stack, a);
+            push(stack, a.item<double>());
+            return 0;
+          };
         }),
     Operator(
         "prim::ImplicitTensorToNum(Tensor a) -> Scalar",
         [](const Node* node) -> Operation {
-          if(node->output()->type() == IntType::get()) {
+          if (node->output()->type() == IntType::get()) {
             return [](Stack& stack) {
               at::Tensor a;
               pop(stack, a);
-              checkImplicitTensorToNum(a, /*to int*/true);
+              checkImplicitTensorToNum(a, /*to int*/ true);
               push(stack, a.item<int64_t>());
               return 0;
             };
@@ -110,7 +115,7 @@ RegisterOperators reg({
             return [](Stack& stack) {
               at::Tensor a;
               pop(stack, a);
-              checkImplicitTensorToNum(a, /*to int*/false);
+              checkImplicitTensorToNum(a, /*to int*/ false);
               push(stack, a.item<double>());
               return 0;
             };
@@ -134,9 +139,7 @@ RegisterOperators reg({
           return [](Stack& stack) {
             bool b;
             pop(stack, b);
-            push(
-                stack,
-                autograd::make_variable(at::scalar_to_tensor(b)));
+            push(stack, autograd::make_variable(at::scalar_to_tensor(b)));
             return 0;
           };
         }),
@@ -252,13 +255,13 @@ RegisterOperators reg({
           };
         }),
     Operator(
-      prim::None,
-      [](const Node* node) {
-        return [](Stack& stack) {
-          stack.emplace_back(IValue());
-          return 0;
-        };
-      }),
+        prim::None,
+        [](const Node* node) {
+          return [](Stack& stack) {
+            stack.emplace_back(IValue());
+            return 0;
+          };
+        }),
     Operator(
         prim::Print,
         [](const Node* node) {
@@ -284,7 +287,8 @@ RegisterOperators reg({
             std::vector<int64_t> size;
             size.reserve(8);
             for (size_t i = 0; i < num_inputs; ++i) {
-              size = at::infer_size(size, peek(stack, i, num_inputs).toIntList()->elements());
+              size = at::infer_size(
+                  size, peek(stack, i, num_inputs).toIntList()->elements());
             }
             drop(stack, num_inputs);
             push(stack, std::move(size));
@@ -299,18 +303,21 @@ RegisterOperators reg({
           return [raw_dim, chunks](Stack& stack) {
             Shared<IntList> sizes_l;
             pop(stack, sizes_l);
-            const auto & shape = sizes_l->elements();
+            const auto& shape = sizes_l->elements();
             std::vector<int64_t> regular_shape = shape;
             std::vector<int64_t> last_shape = shape;
             int64_t dim = at::maybe_wrap_dim(raw_dim, shape.size());
-            AT_CHECK(dim < regular_shape.size(), "Dimension out of range for chunk");
+            AT_CHECK(
+                dim < regular_shape.size(), "Dimension out of range for chunk");
             int64_t split_size = (regular_shape[dim] + chunks - 1) / chunks;
             regular_shape[dim] = split_size;
             if (shape[dim] % chunks == 0) {
               last_shape[dim] = split_size;
             } else {
-              int64_t num_splits = std::max<int64_t>((shape[dim] + split_size - 1) / split_size, 1);
-              last_shape[dim] = split_size - (split_size * num_splits - shape[dim]);
+              int64_t num_splits = std::max<int64_t>(
+                  (shape[dim] + split_size - 1) / split_size, 1);
+              last_shape[dim] =
+                  split_size - (split_size * num_splits - shape[dim]);
               JIT_ASSERT(last_shape[dim] >= 0);
             }
             push(stack, std::move(regular_shape));
@@ -319,7 +326,11 @@ RegisterOperators reg({
           };
         }),
     Operator(
-        FunctionSchema("aten::warn", {Argument("message", StringType::get()), Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)}, {}),
+        FunctionSchema(
+            "aten::warn",
+            {Argument("message", StringType::get()),
+             Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)},
+            {}),
         [](const Node* node) {
           return [](Stack& stack) {
             drop(stack, 1);
@@ -436,9 +447,13 @@ RegisterOperators reg({
           size_t num_elems = node->outputs().size();
           return [=](Stack& stack) {
             auto t = pop(stack).toTuple();
-            const auto & elems = t->elements();
+            const auto& elems = t->elements();
             if (elems.size() != num_elems) {
-              AT_ERROR("Expected a tuple of ", num_elems, " elements, but got ", elems.size());
+              AT_ERROR(
+                  "Expected a tuple of ",
+                  num_elems,
+                  " elements, but got ",
+                  elems.size());
             }
             stack.insert(stack.end(), elems.begin(), elems.end());
             return 0;
@@ -451,7 +466,7 @@ RegisterOperators reg({
           int64_t end_ind = node->i(attr::end);
           return [=](Stack& stack) {
             auto t = pop(stack).toTuple();
-            const auto & elems = t->elements();
+            const auto& elems = t->elements();
             std::vector<IValue> output_elems;
             for (int64_t i = beg_ind; i < end_ind; ++i) {
               output_elems.emplace_back(elems.at(i));
@@ -461,26 +476,25 @@ RegisterOperators reg({
           };
         }),
     Operator(
-      prim::TupleIndex,
-      [](const Node* node) {
-        auto index = node->i(attr::index);
-        return [=](Stack& stack) {
-          auto tup = pop(stack).toTuple();
-          const auto & elems = tup->elements();
-          // index is normalized to be positive at compile time
-          stack.emplace_back(elems.at(index));
-          return 0;
-        };
-      }),
+        prim::TupleIndex,
+        [](const Node* node) {
+          auto index = node->i(attr::index);
+          return [=](Stack& stack) {
+            auto tup = pop(stack).toTuple();
+            const auto& elems = tup->elements();
+            // index is normalized to be positive at compile time
+            stack.emplace_back(elems.at(index));
+            return 0;
+          };
+        }),
     Operator(
         prim::TupleConstruct,
         [](const Node* node) {
           size_t num_inputs = node->inputs().size();
           return [=](Stack& stack) {
-            std::vector<IValue> elems {
-              std::make_move_iterator(stack.end() - num_inputs),
-              std::make_move_iterator(stack.end())
-            };
+            std::vector<IValue> elems{
+                std::make_move_iterator(stack.end() - num_inputs),
+                std::make_move_iterator(stack.end())};
             drop(stack, num_inputs);
             push(stack, Tuple::create(std::move(elems)));
             return 0;
@@ -491,25 +505,38 @@ RegisterOperators reg({
         [](const Node* node) {
           int64_t chunks = node->i(attr::chunks);
           int64_t dim = node->i(attr::dim);
-          auto outputs_used = fmap(node->outputs(), [](const Value *v) { return v->uses().size() > 0; });
+          auto outputs_used = fmap(node->outputs(), [](const Value* v) {
+            return v->uses().size() > 0;
+          });
           return [=](Stack& stack) {
             autograd::profiler::RecordFunction record("chunk");
             at::Tensor t;
             pop(stack, t);
             auto result = at::chunk(t, chunks, dim);
-            stack.insert(stack.end(), std::make_move_iterator(result.begin()),
-                                      std::make_move_iterator(result.end()));
+            stack.insert(
+                stack.end(),
+                std::make_move_iterator(result.begin()),
+                std::make_move_iterator(result.end()));
             // NB: Chunk can sometimes return a smaller number of outputs.
             int64_t num_results = result.size();
             if (num_results != chunks) {
               if (num_results > chunks) {
-                JIT_ASSERTM(num_results == chunks,
-                            "Expected chunk to return ", chunks, " outputs, but got ", num_results);
+                JIT_ASSERTM(
+                    num_results == chunks,
+                    "Expected chunk to return ",
+                    chunks,
+                    " outputs, but got ",
+                    num_results);
               }
               for (int64_t i = num_results; i < chunks; ++i) {
-                AT_CHECK(!outputs_used[i],
-                         "Expected chunk to return at least ", chunks, " outputs, but got only ", num_results);
-                // We know that the output is unused, so it's ok to push anything on the stack.
+                AT_CHECK(
+                    !outputs_used[i],
+                    "Expected chunk to return at least ",
+                    chunks,
+                    " outputs, but got only ",
+                    num_results);
+                // We know that the output is unused, so it's ok to push
+                // anything on the stack.
                 stack.emplace_back();
               }
             }
@@ -524,27 +551,39 @@ RegisterOperators reg({
           if (lt->getElementType() == IntType::get()) {
             return [=](Stack& stack) {
               auto ilist = pop(stack);
-              const auto & list = ilist.toIntList()->elements();
-              AT_CHECK(list.size() == num_outputs,
-                       "Expected ", num_outputs, " elements in a list but found ", list.size());
+              const auto& list = ilist.toIntList()->elements();
+              AT_CHECK(
+                  list.size() == num_outputs,
+                  "Expected ",
+                  num_outputs,
+                  " elements in a list but found ",
+                  list.size());
               stack.insert(stack.end(), list.begin(), list.end());
               return 0;
             };
           } else if (lt->getElementType() == FloatType::get()) {
             return [=](Stack& stack) {
               auto ilist = pop(stack);
-              const auto & list = ilist.toDoubleList()->elements();
-              AT_CHECK(list.size() == num_outputs,
-                       "Expected ", num_outputs, " elements in a list but found ", list.size());
+              const auto& list = ilist.toDoubleList()->elements();
+              AT_CHECK(
+                  list.size() == num_outputs,
+                  "Expected ",
+                  num_outputs,
+                  " elements in a list but found ",
+                  list.size());
               stack.insert(stack.end(), list.begin(), list.end());
               return 0;
             };
           } else if (lt->getElementType() == DynamicType::get()) {
             return [=](Stack& stack) {
               auto ilist = pop(stack);
-              const auto & list = ilist.toTensorList()->elements();
-              AT_CHECK(list.size() == num_outputs,
-                       "Expected ", num_outputs, " elements in a list but found ", list.size());
+              const auto& list = ilist.toTensorList()->elements();
+              AT_CHECK(
+                  list.size() == num_outputs,
+                  "Expected ",
+                  num_outputs,
+                  " elements in a list but found ",
+                  list.size());
               stack.insert(stack.end(), list.begin(), list.end());
               return 0;
             };
@@ -557,22 +596,20 @@ RegisterOperators reg({
         [](const Node* node) -> Operation {
           const auto num_inputs = node->inputs().size();
           ListTypePtr lt = node->output()->type()->expect<ListType>();
-          if(IntType::get() == lt->getElementType()) {
+          if (IntType::get() == lt->getElementType()) {
             return [=](Stack& stack) {
               auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
-              std::vector<int64_t> vals = fmap(inputs, [](const IValue& v) {
-                return v.toInt();
-              });
+              std::vector<int64_t> vals =
+                  fmap(inputs, [](const IValue& v) { return v.toInt(); });
               drop(stack, num_inputs);
               push(stack, std::move(vals));
               return 0;
             };
-          } else if(FloatType::get() == lt->getElementType()) {
+          } else if (FloatType::get() == lt->getElementType()) {
             return [=](Stack& stack) {
               auto inputs = peekSlice(stack, 0, num_inputs, num_inputs);
-              std::vector<double> vals = fmap(inputs, [](const IValue& v) {
-                return v.toDouble();
-              });
+              std::vector<double> vals =
+                  fmap(inputs, [](const IValue& v) { return v.toDouble(); });
               drop(stack, num_inputs);
               push(stack, std::move(vals));
               return 0;
@@ -603,15 +640,16 @@ RegisterOperators reg({
             };
           }
         }),
-    Operator("aten::_unwrap_optional(t? optional) -> t",
-      [](const Node* node) -> Operation {
-        return [=](Stack& stack) {
-          auto val = pop(stack);
-          JIT_ASSERTM(!val.isNone(), "Unwrapping null optional");
-          push(stack, val);
-          return 0;
-        };
-      }),
+    Operator(
+        "aten::_unwrap_optional(t? optional) -> t",
+        [](const Node* node) -> Operation {
+          return [=](Stack& stack) {
+            auto val = pop(stack);
+            JIT_ASSERTM(!val.isNone(), "Unwrapping null optional");
+            push(stack, val);
+            return 0;
+          };
+        }),
     Operator(
         prim::fork,
         [](const Node* node) {
@@ -652,7 +690,7 @@ RegisterOperators reg({
 #define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
   Operator(                                                                    \
       #aten_op "(int a, int b) -> " #int_result,                               \
-      [](const Node* node) {                                                         \
+      [](const Node* node) {                                                   \
         return [=](Stack& stack) {                                             \
           int64_t a, b;                                                        \
           pop(stack, a, b);                                                    \
@@ -660,19 +698,21 @@ RegisterOperators reg({
           return 0;                                                            \
         };                                                                     \
       }),                                                                      \
-  Operator(                                                                    \
-      #aten_op "(float a, float b) -> " #float_result, [](const Node* node) {        \
-        return [=](Stack& stack) {                                             \
-          double a, b;                                                         \
-          pop(stack, a, b);                                                    \
-          push(stack, float_op);                                               \
-          return 0;                                                            \
-        };                                                                     \
-      }),
+      Operator(                                                                \
+          #aten_op "(float a, float b) -> " #float_result,                     \
+          [](const Node* node) {                                               \
+            return [=](Stack& stack) {                                         \
+              double a, b;                                                     \
+              pop(stack, a, b);                                                \
+              push(stack, float_op);                                           \
+              return 0;                                                        \
+            };                                                                 \
+          })
 
 #define DEFINE_INT_FLOAT_OP(aten_op, op, result)                               \
   Operator(                                                                    \
-      #aten_op "(int a, float b) -> " #result, [](const Node* node) {          \
+      #aten_op "(int a, float b) -> " #result,                                 \
+      [](const Node* node) {                                                   \
         return [=](Stack& stack) {                                             \
           int64_t a;                                                           \
           double b;                                                            \
@@ -681,8 +721,7 @@ RegisterOperators reg({
           return 0;                                                            \
         };                                                                     \
       }),                                                                      \
-  Operator(                                                                    \
-      #aten_op "(float a, int b) -> " #result, [](const Node* node) {          \
+      Operator(#aten_op "(float a, int b) -> " #result, [](const Node* node) { \
         return [=](Stack& stack) {                                             \
           double a;                                                            \
           int64_t b;                                                           \
@@ -690,34 +729,33 @@ RegisterOperators reg({
           push(stack, op);                                                     \
           return 0;                                                            \
         };                                                                     \
-      }),
+      })
 
-
-#define DEFINE_INT_OP(aten_op, op)                            \
+#define DEFINE_INT_OP(aten_op, op)                                  \
   Operator(#aten_op "(int a, int b) -> int", [](const Node* node) { \
-    return [=](Stack& stack) {                                \
-      int64_t a, b;                                           \
-      pop(stack, a, b);                                       \
-      push(stack, op); /* NOLINT(hicpp-signed-bitwise) */     \
-      return 0;                                               \
-    };                                                        \
-  }),
-
-#define DEFINE_BINARY_OP(aten_op, op) \
-  DEFINE_GENERIC_OP(aten_op, op, op, int, float)  \
-  DEFINE_INT_FLOAT_OP(aten_op, op, float)
-#define DEFINE_COMPARISON_OP(aten_op, op) \
-  DEFINE_GENERIC_OP(aten_op, op, op, bool, bool) \
-  DEFINE_INT_FLOAT_OP(aten_op, op, bool)
-#define DEFINE_BOOL_OP(aten_op, op)                              \
+    return [=](Stack& stack) {                                      \
+      int64_t a, b;                                                 \
+      pop(stack, a, b);                                             \
+      push(stack, op); /* NOLINT(hicpp-signed-bitwise) */           \
+      return 0;                                                     \
+    };                                                              \
+  })
+
+#define DEFINE_BINARY_OP(aten_op, op)             \
+  DEFINE_GENERIC_OP(aten_op, op, op, int, float), \
+      DEFINE_INT_FLOAT_OP(aten_op, op, float)
+#define DEFINE_COMPARISON_OP(aten_op, op)         \
+  DEFINE_GENERIC_OP(aten_op, op, op, bool, bool), \
+      DEFINE_INT_FLOAT_OP(aten_op, op, bool)
+#define DEFINE_BOOL_OP(aten_op, op)                                    \
   Operator(#aten_op "(bool a, bool b) -> bool", [](const Node* node) { \
-    return [=](Stack& stack) {                                   \
-      bool a, b;                                                 \
-      pop(stack, a, b);                                          \
-      push(stack, op);                                           \
-      return 0;                                                  \
-    };                                                           \
-  }),
+    return [=](Stack& stack) {                                         \
+      bool a, b;                                                       \
+      pop(stack, a, b);                                                \
+      push(stack, op);                                                 \
+      return 0;                                                        \
+    };                                                                 \
+  })
 
 // Convert an python index (which may be negative) into an index usable for a
 // C++ container
@@ -919,124 +957,150 @@ Operation listSetItem(const Node* node) {
 
 RegisterOperators reg2({
 
-#define DEFINE_STRING_OP(op_name, string_op, result)                           \
-Operator(                                                                      \
-    #op_name "(str a, str b) ->" #result,                                \
-    [](const Node* node) {                                                    \
-      return [=](Stack& stack) {                                               \
-        auto b = pop(stack).toStringRef();                                     \
-        auto a = pop(stack).toStringRef();                                     \
-        push(stack, string_op);                                                \
-        return 0;                                                              \
-    };                                                                         \
-  }),
-
-  DEFINE_STRING_OP(aten::eq, a == b, bool)
-  DEFINE_STRING_OP(aten::ne, a != b, bool)
-  DEFINE_STRING_OP(aten::add, a + b, str)
+#define DEFINE_STRING_OP(op_name, string_op, result)                    \
+  Operator(#op_name "(str a, str b) ->" #result, [](const Node* node) { \
+    return [=](Stack& stack) {                                          \
+      auto b = pop(stack).toStringRef();                                \
+      auto a = pop(stack).toStringRef();                                \
+      push(stack, string_op);                                           \
+      return 0;                                                         \
+    };                                                                  \
+  })
+
+    DEFINE_STRING_OP(aten::eq, a == b, bool),
+    DEFINE_STRING_OP(aten::ne, a != b, bool),
+    DEFINE_STRING_OP(aten::add, a + b, str),
 #undef DEFINE_STRING_OP
 
     // tensor length op (size of 1st dimension)
     Operator(
-      "aten::len(Tensor t) -> int",
-      [](Stack& stack) {
-        at::Tensor t = pop(stack).toTensor();
-        if (t.dim() == 0) {
-          AT_ERROR("len() of a 0-d tensor");
-        }
-        push(stack, t.sizes()[0]);
-        return 0;
-      }
-    ),
+        "aten::len(Tensor t) -> int",
+        [](Stack& stack) {
+          at::Tensor t = pop(stack).toTensor();
+          if (t.dim() == 0) {
+            AT_ERROR("len() of a 0-d tensor");
+          }
+          push(stack, t.sizes()[0]);
+          return 0;
+        }),
     Operator(
         "aten::append(Tensor[](a!) self, Tensor(c) el) -> Tensor[](a!)",
         listAppend<Shared<TensorList>, at::Tensor>),
-    Operator("aten::select(Tensor[](a) list, int idx) -> Tensor(*)", listSelect<Shared<TensorList>>),
-    Operator("aten::_set_item(Tensor[](a!) l, int idx, Tensor el) -> Tensor[](a!)", listSetItem<Shared<TensorList>, at::Tensor>),
-
-  // Mutable ops for lists containing immutable types.
-#define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type) \
-    Operator("aten::select(" decl_type "[] a, int b) -> " decl_type, listSelect<Shared<c_type>>), \
-    Operator( \
-        "aten::append(" decl_type "[](a!) self, " decl_type " el) -> " decl_type "[](a!)", \
-        listAppend<Shared<c_type>, c_type::ElemType>), \
-    Operator("aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type " el) -> " decl_type"[](a!)", listSetItem<Shared<c_type>, c_type::ElemType>), \
-
-    CREATE_IMMUTABLE_LIST_OPS("int", IntList)
-    CREATE_IMMUTABLE_LIST_OPS("float", DoubleList)
-    CREATE_IMMUTABLE_LIST_OPS("t", GenericList)
-
-#define CREATE_LIST_OPS(decl_type, c_type) \
-    Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>), \
-    Operator("aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type "[]", listAdd<Shared<c_type>, c_type::ElemType>), \
-    Operator( \
-        "aten::slice(" decl_type "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type "[]", \
-        listSlice<Shared<c_type>, c_type::ElemType>), \
-
-
-    CREATE_LIST_OPS("int", IntList)
-    CREATE_LIST_OPS("float", DoubleList)
-    CREATE_LIST_OPS("Tensor", TensorList)
-    CREATE_LIST_OPS("t", GenericList)
+    Operator(
+        "aten::select(Tensor[](a) list, int idx) -> Tensor(*)",
+        listSelect<Shared<TensorList>>),
+    Operator(
+        "aten::_set_item(Tensor[](a!) l, int idx, Tensor el) -> Tensor[](a!)",
+        listSetItem<Shared<TensorList>, at::Tensor>),
+
+// Mutable ops for lists containing immutable types.
+#define CREATE_IMMUTABLE_LIST_OPS(decl_type, c_type)                   \
+  Operator(                                                            \
+      "aten::select(" decl_type "[] a, int b) -> " decl_type,          \
+      listSelect<Shared<c_type>>),                                     \
+      Operator(                                                        \
+          "aten::append(" decl_type "[](a!) self, " decl_type          \
+          " el) -> " decl_type "[](a!)",                               \
+          listAppend<Shared<c_type>, c_type::ElemType>),               \
+      Operator(                                                        \
+          "aten::_set_item(" decl_type "[](a!) l, int idx, " decl_type \
+          " el) -> " decl_type "[](a!)",                               \
+          listSetItem<Shared<c_type>, c_type::ElemType>)
+
+    CREATE_IMMUTABLE_LIST_OPS("int", IntList),
+    CREATE_IMMUTABLE_LIST_OPS("float", DoubleList),
+    CREATE_IMMUTABLE_LIST_OPS("t", GenericList),
+
+#define CREATE_LIST_OPS(decl_type, c_type)                                          \
+  Operator("aten::len(" decl_type "[] a) -> int", listLen<Shared<c_type>>),         \
+      Operator(                                                                     \
+          "aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type           \
+          "[]",                                                                     \
+          listAdd<Shared<c_type>, c_type::ElemType>),                               \
+      Operator(                                                                     \
+          "aten::slice(" decl_type                                                  \
+          "[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \
+          "[]",                                                                     \
+          listSlice<Shared<c_type>, c_type::ElemType>)
+
+    CREATE_LIST_OPS("int", IntList),
+    CREATE_LIST_OPS("float", DoubleList),
+    CREATE_LIST_OPS("Tensor", TensorList),
+    CREATE_LIST_OPS("t", GenericList),
 #undef CREATE_LIST_OPS
 
-
     Operator("aten::eq(int[] a, int[] b) -> bool", listEq<Shared<IntList>>),
-    Operator("aten::eq(float[] a, float[] b) -> bool", listEq<Shared<DoubleList>>),
-    Operator("aten::eq(Tensor[] a, Tensor[] b) -> bool", listEq<Shared<TensorList>>),
+    Operator(
+        "aten::eq(float[] a, float[] b) -> bool",
+        listEq<Shared<DoubleList>>),
+    Operator(
+        "aten::eq(Tensor[] a, Tensor[] b) -> bool",
+        listEq<Shared<TensorList>>),
     Operator("aten::ne(int[] a, int[] b) -> bool", listNe<Shared<IntList>>),
-    Operator("aten::ne(float[] a, float[] b) -> bool", listNe<Shared<DoubleList>>),
-    Operator("aten::ne(Tensor[] a, Tensor[] b) -> bool", listNe<Shared<TensorList>>),
-
-
-#define CREATE_COPY_OP(other_type, c_type)                              \
-  Operator(                                                             \
-      "aten::copy_(Tensor(a!) self, " #other_type                       \
-      " other) -> Tensor(a!)",                                          \
-      [](const Node* node) {                                            \
-        return [=](Stack& stack) {                                      \
-          at::Tensor t;                                                 \
-          c_type other;                                                 \
-          pop(stack, t, other);                                         \
-          std::move(t) = other; /* NOLINT(bugprone-use-after-move) */   \
+    Operator(
+        "aten::ne(float[] a, float[] b) -> bool",
+        listNe<Shared<DoubleList>>),
+    Operator(
+        "aten::ne(Tensor[] a, Tensor[] b) -> bool",
+        listNe<Shared<TensorList>>),
+
+#define CREATE_COPY_OP(other_type, c_type)                                 \
+  Operator(                                                                \
+      "aten::copy_(Tensor(a!) self, " #other_type " other) -> Tensor(a!)", \
+      [](const Node* node) {                                               \
+        return [=](Stack& stack) {                                         \
+          at::Tensor t;                                                    \
+          c_type other;                                                    \
+          pop(stack, t, other);                                            \
+          std::move(t) = other; /* NOLINT(bugprone-use-after-move) */      \
           push(stack, std::move(t)); /* NOLINT(bugprone-use-after-move) */ \
-          return 0;                                                     \
-        };                                                              \
-      }),
+          return 0;                                                        \
+        };                                                                 \
+      })
 
-    CREATE_COPY_OP(Tensor, at::Tensor)
-    CREATE_COPY_OP(int, int64_t)
-    CREATE_COPY_OP(float, double)
+    CREATE_COPY_OP(Tensor, at::Tensor),
+    CREATE_COPY_OP(int, int64_t),
+    CREATE_COPY_OP(float, double),
 #undef CREATE_COPY_OP
 
-    DEFINE_BINARY_OP(aten::add, a + b)
-    DEFINE_BINARY_OP(aten::sub, a - b)
-    DEFINE_BINARY_OP(aten::mul, a * b)
-    DEFINE_BINARY_OP(aten::pow, static_cast<decltype(a)>(pow(a, b)))
-
-    // Pass in two ops for handling int and float separately as % in C++ only works for int
-    // The modulus calculation is different between C++ and Python (on negative), we preserve
-    // the python behavior as it's more common and match python syntax, hence the conversion.
-    DEFINE_GENERIC_OP(aten::remainder, (b + (a % b)) % b, fmod((b + fmod(a, b)), b), int, float)
-    DEFINE_INT_FLOAT_OP(aten::remainder, fmod((b + fmod(a, b)), b), float)
-
+    DEFINE_BINARY_OP(aten::add, a + b),
+    DEFINE_BINARY_OP(aten::sub, a - b),
+    DEFINE_BINARY_OP(aten::mul, a* b),
+    DEFINE_BINARY_OP(aten::pow, static_cast<decltype(a)>(pow(a, b))),
+
+    // Pass in two ops for handling int and float separately as % in C++ only
+    // works for int The modulus calculation is different between C++ and Python
+    // (on negative), we preserve the python behavior as it's more common and
+    // match python syntax, hence the conversion.
+    DEFINE_GENERIC_OP(
+        aten::remainder,
+        (b + (a % b)) % b,
+        fmod((b + fmod(a, b)), b),
+        int,
+        float),
+    DEFINE_INT_FLOAT_OP(aten::remainder, fmod((b + fmod(a, b)), b), float),
 
     // in c++ int division rounds to the integer closer to 0, in python floordiv
     // rounds to lower integer
-    DEFINE_GENERIC_OP(aten::floordiv,
-      static_cast<int64_t>(std::floor(static_cast<double>(a) / static_cast<double>(b))),
-      std::floor(a / b), int, float)
-    DEFINE_INT_FLOAT_OP(aten::floordiv, std::floor(a / b), float)
-
-    //only used in loop unrolling, not exposed to end users
-    DEFINE_INT_OP(aten::__round_to_zero_floordiv, a / b)
-
-    DEFINE_INT_OP(aten::__and__, a & b)
-    DEFINE_INT_OP(aten::__or__, a | b)
-    DEFINE_INT_OP(aten::__xor__, a ^ b)
+    DEFINE_GENERIC_OP(
+        aten::floordiv,
+        static_cast<int64_t>(
+            std::floor(static_cast<double>(a) / static_cast<double>(b))),
+        std::floor(a / b),
+        int,
+        float),
+    DEFINE_INT_FLOAT_OP(aten::floordiv, std::floor(a / b), float),
+
+    // only used in loop unrolling, not exposed to end users
+    DEFINE_INT_OP(aten::__round_to_zero_floordiv, a / b),
+
+    DEFINE_INT_OP(aten::__and__, a& b),
+    DEFINE_INT_OP(aten::__or__, a | b),
+    DEFINE_INT_OP(aten::__xor__, a ^ b),
 
     // NB: This is the python truediv operation
-    Operator("aten::div(int a, int b) -> float",
+    Operator(
+        "aten::div(int a, int b) -> float",
         [](const Node* node) {
           return [=](Stack& stack) {
             int64_t a, b;
@@ -1045,7 +1109,8 @@ Operator(                                                                      \
             return 0;
           };
         }),
-    Operator("aten::div(float a, float b) -> float",
+    Operator(
+        "aten::div(float a, float b) -> float",
         [](const Node* node) {
           return [=](Stack& stack) {
             double a, b;
@@ -1055,7 +1120,8 @@ Operator(                                                                      \
           };
         }),
 
-    Operator("aten::floor(float a) -> int",
+    Operator(
+        "aten::floor(float a) -> int",
         [](const Node* node) {
           return [=](Stack& stack) {
             double a;
@@ -1065,16 +1131,16 @@ Operator(                                                                      \
           };
         }),
 
-    DEFINE_COMPARISON_OP(aten::ne, a != b)
-    DEFINE_COMPARISON_OP(aten::eq, a == b)
-    DEFINE_COMPARISON_OP(aten::lt, a < b)
-    DEFINE_COMPARISON_OP(aten::gt, a > b)
-    DEFINE_COMPARISON_OP(aten::le, a <= b)
-    DEFINE_COMPARISON_OP(aten::ge, a >= b)
+    DEFINE_COMPARISON_OP(aten::ne, a != b),
+    DEFINE_COMPARISON_OP(aten::eq, a == b),
+    DEFINE_COMPARISON_OP(aten::lt, a < b),
+    DEFINE_COMPARISON_OP(aten::gt, a > b),
+    DEFINE_COMPARISON_OP(aten::le, a <= b),
+    DEFINE_COMPARISON_OP(aten::ge, a >= b),
 
-    DEFINE_BOOL_OP(aten::__and__, a && b)
-    DEFINE_BOOL_OP(aten::__or__, a || b)
-    DEFINE_BOOL_OP(aten::__xor__, a != b)
+    DEFINE_BOOL_OP(aten::__and__, a&& b),
+    DEFINE_BOOL_OP(aten::__or__, a || b),
+    DEFINE_BOOL_OP(aten::__xor__, a != b),
 
     Operator(
         "aten::neg(int self) -> int",
@@ -1128,7 +1194,7 @@ Operator(                                                                      \
             pop(stack, t);
             std::vector<int64_t> elems;
             elems.reserve(t.size(0));
-            for(int i = 0; i < t.size(0); i++){
+            for (int i = 0; i < t.size(0); i++) {
               elems.push_back(*t[i].data<int32_t>());
             }
             push(stack, jit::IntList::create(elems));
@@ -1143,7 +1209,7 @@ Operator(                                                                      \
             pop(stack, l);
             auto t = torch::empty(
                 {static_cast<int64_t>(l.size())}, at::dtype(at::kInt));
-            for(size_t i = 0; i < l.size(); i++){
+            for (size_t i = 0; i < l.size(); i++) {
               t[i] = l[i];
             }
             push(stack, t);
@@ -1152,24 +1218,28 @@ Operator(                                                                      \
         }),
 });
 
-
 // checking one of size & scale_factor is set
 // if scale_factor is a double list check that it's len == dim
 // reference: _check_size_scale_factor in torch/nn/functional.py
-void _check_size_factor(size_t dim, const IValue& size, const IValue& scale_factor) {
+void _check_size_factor(
+    size_t dim,
+    const IValue& size,
+    const IValue& scale_factor) {
   if (size.isNone() && scale_factor.isNone()) {
     throw std::runtime_error("either size or scale_factor should be defined");
   }
   if (!size.isNone() && !scale_factor.isNone()) {
-    throw std::runtime_error("only one of size or scale_factor should be defined");
+    throw std::runtime_error(
+        "only one of size or scale_factor should be defined");
   }
   if (scale_factor.isDoubleList()) {
     auto scale_len = scale_factor.toDoubleListRef().size();
     if (scale_len != dim) {
       std::stringstream str;
       str << "scale_factor shape must match input shape. Input is " << dim
-        << "D, scale_factor size is " << scale_len;
-      throw std::runtime_error("only one of size or scale_factor should be defined");
+          << "D, scale_factor size is " << scale_len;
+      throw std::runtime_error(
+          "only one of size or scale_factor should be defined");
     }
   }
 }
@@ -1177,7 +1247,11 @@ void _check_size_factor(size_t dim, const IValue& size, const IValue& scale_fact
 // reference: _output_size in torch/nn/functional.py
 // size can be none, int or intlist
 // scale_factors can be none, float, or floatlist
-std::vector<int64_t> _output_size(const at::Tensor& input, size_t dim, const IValue& size, const IValue& scale_factors) {
+std::vector<int64_t> _output_size(
+    const at::Tensor& input,
+    size_t dim,
+    const IValue& size,
+    const IValue& scale_factors) {
   if (!size.isNone()) {
     if (size.isInt()) {
       std::vector<int64_t> repeated(dim, size.toInt());
@@ -1210,33 +1284,44 @@ at::Tensor interpolate(
     c10::optional<bool> align_corners) {
   if ((mode == "nearest" || mode == "area")) {
     if (align_corners != c10::nullopt) {
-      throw std::runtime_error("align_corners option can only be set with the "
-                             "interpolating modes: linear | bilinear | bicubic | trilinear");
+      throw std::runtime_error(
+          "align_corners option can only be set with the "
+          "interpolating modes: linear | bilinear | bicubic | trilinear");
     }
   } else {
     if (align_corners == c10::nullopt) {
-      AT_WARN("Default upsampling behavior when mode=", mode, " is changed "
-        "to align_corners=False since 0.4.0. Please specify align_corners=True "
-        "if the old behavior is desired. See the documentation of nn.Upsample for details");
+      AT_WARN(
+          "Default upsampling behavior when mode=",
+          mode,
+          " is changed "
+          "to align_corners=False since 0.4.0. Please specify align_corners=True "
+          "if the old behavior is desired. See the documentation of nn.Upsample for details");
       align_corners = false;
     }
   }
 
   auto input_dim = input.dim();
   if (input_dim == 3 && mode == "nearest")
-    return at::upsample_nearest1d(input, _output_size(input, 1, size, scale_factors));
+    return at::upsample_nearest1d(
+        input, _output_size(input, 1, size, scale_factors));
   if (input_dim == 4 && mode == "nearest")
-    return at::upsample_nearest2d(input, _output_size(input, 2, size, scale_factors));
+    return at::upsample_nearest2d(
+        input, _output_size(input, 2, size, scale_factors));
   if (input_dim == 5 && mode == "nearest")
-    return at::upsample_nearest3d(input, _output_size(input, 3, size, scale_factors));
+    return at::upsample_nearest3d(
+        input, _output_size(input, 3, size, scale_factors));
   if (input_dim == 3 && mode == "area")
-    return at::adaptive_avg_pool1d(input, _output_size(input, 1, size, scale_factors));
+    return at::adaptive_avg_pool1d(
+        input, _output_size(input, 1, size, scale_factors));
   if (input_dim == 4 && mode == "area")
-    return at::adaptive_avg_pool2d(input, _output_size(input, 2, size, scale_factors));
+    return at::adaptive_avg_pool2d(
+        input, _output_size(input, 2, size, scale_factors));
   if (input_dim == 5 && mode == "area")
-    return at::adaptive_avg_pool3d(input, _output_size(input, 3, size, scale_factors));
+    return at::adaptive_avg_pool3d(
+        input, _output_size(input, 3, size, scale_factors));
   if (input_dim == 3 && mode == "linear")
-    return at::upsample_linear1d(input, _output_size(input, 1, size, scale_factors), *align_corners);
+    return at::upsample_linear1d(
+        input, _output_size(input, 1, size, scale_factors), *align_corners);
   if (input_dim == 3 && mode == "bilinear")
     throw std::runtime_error("Got 3D input, but bilinear mode needs 4D input");
   if (input_dim == 3 && mode == "bicubic")
@@ -1246,9 +1331,11 @@ at::Tensor interpolate(
   if (input_dim == 4 && mode == "linear")
     throw std::runtime_error("Got 4D input, but linear mode needs 3D input");
   if (input_dim == 4 && mode == "bilinear")
-    return at::upsample_bilinear2d(input, _output_size(input, 2, size, scale_factors), *align_corners);
+    return at::upsample_bilinear2d(
+        input, _output_size(input, 2, size, scale_factors), *align_corners);
   if (input_dim == 4 && mode == "bicubic")
-    return at::upsample_bicubic2d(input, _output_size(input, 2, size, scale_factors), *align_corners);
+    return at::upsample_bicubic2d(
+        input, _output_size(input, 2, size, scale_factors), *align_corners);
   if (input_dim == 4 && mode == "trilinear")
     throw std::runtime_error("Got 4D input, but trilinear mode needs 5D input");
   if (input_dim == 5 && mode == "linear")
@@ -1258,11 +1345,17 @@ at::Tensor interpolate(
   if (input_dim == 5 && mode == "bicubic")
     throw std::runtime_error("Got 5D input, but bicubic mode needs 4D input");
   if (input_dim == 5 && mode == "trilinear")
-    return at::upsample_trilinear3d(input, _output_size(input, 3, size, scale_factors), *align_corners);
-
-  AT_ERROR("Input Error: Only 3D, 4D and 5D input Tensors supported",
-    " (got ", input_dim, "D) for the modes: nearest | linear | bilinear | trilinear",
-    " (got ", mode, ") ");
+    return at::upsample_trilinear3d(
+        input, _output_size(input, 3, size, scale_factors), *align_corners);
+
+  AT_ERROR(
+      "Input Error: Only 3D, 4D and 5D input Tensors supported",
+      " (got ",
+      input_dim,
+      "D) for the modes: nearest | linear | bilinear | trilinear",
+      " (got ",
+      mode,
+      ") ");
 }
 
 Operation interpolate_op(const Node* n) {
@@ -1273,7 +1366,8 @@ Operation interpolate_op(const Node* n) {
     std::string mode;
     IValue align_corners;
     pop(stack, input, size, scale_factors, mode, align_corners);
-    at::Tensor res = interpolate(input, size, scale_factors, mode, align_corners.toOptional<bool>());
+    at::Tensor res = interpolate(
+        input, size, scale_factors, mode, align_corners.toOptional<bool>());
     push(stack, res);
     return 0;
   };
@@ -1294,7 +1388,8 @@ IValue convert_scale_factor_to_double(const IValue& int_ivalue) {
     return IValue();
   } else {
     std::stringstream ss;
-    ss << "Expecting optional int or int list arg for scale factor, got" << int_ivalue;
+    ss << "Expecting optional int or int list arg for scale factor, got"
+       << int_ivalue;
     throw std::runtime_error(ss.str());
   }
   return scale_factor_double;
@@ -1306,8 +1401,10 @@ Operation upsample_nearest_op(const Node* n) {
     IValue size;
     IValue scale_factor_int;
     pop(stack, input, size, scale_factor_int);
-    IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int);
-    at::Tensor res = interpolate(input, size, scale_factor_double, "nearest", c10::nullopt);
+    IValue scale_factor_double =
+        convert_scale_factor_to_double(scale_factor_int);
+    at::Tensor res =
+        interpolate(input, size, scale_factor_double, "nearest", c10::nullopt);
     push(stack, res);
     return 0;
   };
@@ -1321,8 +1418,14 @@ Operation upsample_op(const Node* n) {
     std::string mode;
     IValue align_corners;
     pop(stack, input, size, scale_factor_int, mode, align_corners);
-    IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int);
-    at::Tensor res = interpolate(input, size, scale_factor_double, mode, align_corners.toOptional<bool>());
+    IValue scale_factor_double =
+        convert_scale_factor_to_double(scale_factor_int);
+    at::Tensor res = interpolate(
+        input,
+        size,
+        scale_factor_double,
+        mode,
+        align_corners.toOptional<bool>());
     push(stack, res);
     return 0;
   };
@@ -1334,58 +1437,57 @@ Operation upsample_bilinear_op(const Node* n) {
     IValue size;
     IValue scale_factor_int;
     pop(stack, input, size, scale_factor_int);
-    IValue scale_factor_double = convert_scale_factor_to_double(scale_factor_int);
-    at::Tensor res = interpolate(input, size, scale_factor_double, "bilinear", true);
+    IValue scale_factor_double =
+        convert_scale_factor_to_double(scale_factor_int);
+    at::Tensor res =
+        interpolate(input, size, scale_factor_double, "bilinear", true);
     push(stack, res);
     return 0;
   };
 }
 
-
 RegisterOperators reg3({
-  Operator(
-      "aten::__interpolate(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
-      interpolate_op),
-  Operator(
-      "aten::__interpolate(Tensor input, int[]? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
-      interpolate_op),
-  Operator(
-      "aten::__interpolate(Tensor input, int? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
-      interpolate_op),
-  Operator(
-      "aten::__interpolate(Tensor input, int[]? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
-      interpolate_op),
-
-  Operator(
-      "aten::__upsample_nearest(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
-      upsample_nearest_op),
-  Operator(
-      "aten::__upsample_nearest(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
-      upsample_nearest_op),
-
-  Operator(
-      "aten::__upsample(Tensor input, int? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
-      upsample_op),
-  Operator(
-      "aten::__upsample(Tensor input, int[]? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
-      upsample_op),
-
-
-  Operator(
-      "aten::__upsample_bilinear(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
-      upsample_bilinear_op),
-  Operator(
-      "aten::__upsample_bilinear(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
-      upsample_bilinear_op),
-  Operator(
-      "aten::__upsample_bilinear(Tensor input, int? size = None, int[]? scale_factor = None) -> Tensor",
-      upsample_bilinear_op),
-  Operator(
-      "aten::__upsample_bilinear(Tensor input, int[]? size = None, int[]? scale_factor = None) -> Tensor",
-      upsample_bilinear_op),
+    Operator(
+        "aten::__interpolate(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+        interpolate_op),
+    Operator(
+        "aten::__interpolate(Tensor input, int[]? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+        interpolate_op),
+    Operator(
+        "aten::__interpolate(Tensor input, int? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+        interpolate_op),
+    Operator(
+        "aten::__interpolate(Tensor input, int[]? size = None, float? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+        interpolate_op),
 
-});
+    Operator(
+        "aten::__upsample_nearest(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
+        upsample_nearest_op),
+    Operator(
+        "aten::__upsample_nearest(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
+        upsample_nearest_op),
+
+    Operator(
+        "aten::__upsample(Tensor input, int? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+        upsample_op),
+    Operator(
+        "aten::__upsample(Tensor input, int[]? size = None, int? scale_factor = None, str mode = 'nearest', bool? align_corners = None) -> Tensor",
+        upsample_op),
+
+    Operator(
+        "aten::__upsample_bilinear(Tensor input, int? size = None, int? scale_factor = None) -> Tensor",
+        upsample_bilinear_op),
+    Operator(
+        "aten::__upsample_bilinear(Tensor input, int[]? size = None, int? scale_factor = None) -> Tensor",
+        upsample_bilinear_op),
+    Operator(
+        "aten::__upsample_bilinear(Tensor input, int? size = None, int[]? scale_factor = None) -> Tensor",
+        upsample_bilinear_op),
+    Operator(
+        "aten::__upsample_bilinear(Tensor input, int[]? size = None, int[]? scale_factor = None) -> Tensor",
+        upsample_bilinear_op),
 
+});
 
 at::Tensor leaky_relu(const at::Tensor& tensor, double scalar) {
   return at::leaky_relu(tensor, scalar);
@@ -1396,7 +1498,10 @@ at::Tensor cat(const std::vector<at::Tensor>& tensors) {
 
 static auto reg4 =
     torch::jit::RegisterOperators()
-        .op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor", &leaky_relu)
+        .op("_test::leaky_relu(Tensor self, float v=0.01) -> Tensor",
+            &leaky_relu)
         .op("_test::cat(Tensor[] inputs) -> Tensor", &cat);
 
-}}} // torch::jit::anon
+} // namespace
+} // namespace jit
+} // namespace torch
index e467823..ad8265a 100644 (file)
@@ -1,11 +1,12 @@
+#include <ATen/ExpandUtils.h>
 #include <torch/csrc/autograd/profiler.h>
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/operator.h>
+
 #include <torch/csrc/api/include/torch/utils.h>
-#include <ATen/ExpandUtils.h>
 
-#include <sstream>
 #include <regex>
+#include <sstream>
 
 namespace torch {
 namespace jit {
@@ -26,9 +27,7 @@ RegisterOperators reg({
         }),
     Operator(
         "aten::Size(int[] sizes) -> int[]",
-        [](Stack& stack) {
-          return 0;
-        }),
+        [](Stack& stack) { return 0; }),
     Operator(
         "aten::size(Tensor self) -> int[]",
         [](Stack& stack) {
@@ -67,14 +66,14 @@ RegisterOperators reg({
 
             auto args = last(stack, num_inputs - 1);
             std::stringstream ss;
-            for(size_t begin = 0, used_args = 0; true; ++used_args) {
+            for (size_t begin = 0, used_args = 0; true; ++used_args) {
               size_t loc = format.find("{}", begin);
-              if(loc == std::string::npos) {
+              if (loc == std::string::npos) {
                 ss << format.substr(begin);
                 break;
               }
               ss << format.substr(begin, loc - begin);
-              if(used_args >= args.size()) {
+              if (used_args >= args.size()) {
                 AT_ERROR("Too few arguments for format string: ", format);
               }
               ss << args[used_args];
@@ -97,33 +96,35 @@ RegisterOperators reg({
           };
         }),
     Operator(
-      "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
-      [](const Node* node) {
-        return [](Stack& stack) {
-          at::Tensor weight;
-          at::Tensor input;
-          double max_norm;
-          double norm_type;
-          pop(stack, weight, input, max_norm, norm_type);
+        "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor",
+        [](const Node* node) {
+          return [](Stack& stack) {
+            at::Tensor weight;
+            at::Tensor input;
+            double max_norm;
+            double norm_type;
+            pop(stack, weight, input, max_norm, norm_type);
 
-          // TODO: remove when script supports setting grad mode
-          torch::NoGradGuard no_grad;
+            // TODO: remove when script supports setting grad mode
+            torch::NoGradGuard no_grad;
 
-          at::Tensor result = at::embedding_renorm_(weight, input, max_norm, norm_type);
-          push(stack, result);
+            at::Tensor result =
+                at::embedding_renorm_(weight, input, max_norm, norm_type);
+            push(stack, result);
 
-          return 0;
-        };
-      }),
+            return 0;
+          };
+        }),
     Operator(
-      "aten::_assert_int_or_pair(int[] vals, str name, str message) -> Tensor",
-      [](const Node* node) {
-        return [](Stack& stack) {
-          // Everything is a list at the point this is used, so don't do anything
-          drop(stack, 3);
-          return 0;
-        };
-      }),
+        "aten::_assert_int_or_pair(int[] vals, str name, str message) -> Tensor",
+        [](const Node* node) {
+          return [](Stack& stack) {
+            // Everything is a list at the point this is used, so don't do
+            // anything
+            drop(stack, 3);
+            return 0;
+          };
+        }),
 
 });
 }
index 430d0cb..60b529c 100644 (file)
@@ -1,19 +1,20 @@
 #pragma once
 #include <functional>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 class ResourceGuard {
   std::function<void()> _destructor;
   bool _released;
 
-public:
+ public:
   ResourceGuard(std::function<void()> destructor)
-    : _destructor(std::move(destructor))
-    , _released(false) {}
+      : _destructor(std::move(destructor)), _released(false) {}
 
   ~ResourceGuard() {
-    if (!_released) _destructor();
+    if (!_released)
+      _destructor();
   }
 
   void release() {
@@ -21,4 +22,5 @@ public:
   }
 };
 
-}}
+} // namespace jit
+} // namespace torch
index e9e7a3a..38f98ac 100644 (file)
@@ -1,19 +1,19 @@
 #include <torch/csrc/jit/ir.h>
 
-
-#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/operator.h>
 
+#include <algorithm>
 #include <iostream>
-#include <unordered_map>
-#include <unordered_set>
 #include <set>
-#include <stack>
 #include <sstream>
-#include <algorithm>
+#include <stack>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 ScopePtr Scope::push(Symbol name) {
   return c10::make_intrusive<Scope>(intrusive_from_this(), name);
@@ -52,4 +52,5 @@ std::string Scope::namesFromRoot(const std::string& separator) const {
   return out;
 }
 
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index 1ece1f9..263c080 100644 (file)
@@ -1,8 +1,8 @@
 #pragma once
-#include <torch/csrc/jit/interned_strings.h>
-#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <c10/macros/Macros.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/interned_strings.h>
 
 #include <memory>
 
@@ -20,7 +20,7 @@ struct Scope;
 using ScopePtr = c10::intrusive_ptr<Scope>;
 
 struct TORCH_API Scope : public c10::intrusive_ptr_target {
-private:
+ private:
   ScopePtr parent_;
   Symbol name_;
   ScopePtr intrusive_from_this() {
@@ -30,7 +30,8 @@ private:
                                            // to account for this ownership
     return c10::intrusive_ptr<Scope>::reclaim(this);
   }
-public:
+
+ public:
   Scope() {
     name_ = Symbol::scope("");
   }
@@ -62,7 +63,7 @@ public:
     return name_;
   }
 
-  std::string namesFromRoot(const std::string& separator="/") const;
+  std::string namesFromRoot(const std::string& separator = "/") const;
 };
 
 } // namespace jit
index 2b08145..7a5763d 100644 (file)
@@ -1,11 +1,13 @@
-#include <torch/csrc/jit/script/builtin_functions.h>
 #include <torch/csrc/api/include/torch/jit.h>
 #include <torch/csrc/jit/code_template.h>
+#include <torch/csrc/jit/script/builtin_functions.h>
 
-namespace torch { namespace jit { namespace script {
+namespace torch {
+namespace jit {
+namespace script {
 
 auto scalar_operators_source = CodeTemplate(
-R"SCRIPT(
+    R"SCRIPT(
 def mul(a : ${Scalar}, b : Tensor) -> Tensor:
   return b * a
 def add(a : ${Scalar}, b : Tensor) -> Tensor:
@@ -29,22 +31,22 @@ def div(a : ${Scalar}, b : Tensor) -> Tensor:
 )SCRIPT");
 
 auto _ntuple_ops = CodeTemplate(
-R"SCRIPT(
+    R"SCRIPT(
 def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
   return x
 )SCRIPT");
 
 struct BuiltinFunctionRegistry {
-
   const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name) {
     const static std::vector<Method*> empty;
     // when initializing the builtin function library, we will re-enter
     // getAllBuiltinFunctionsFor since it is called in the compiler to
-    // lookup builtins and initializing the builtin functions calls the compiler.
-    // To avoid deadlocking, we use a recursive mutex (same thread can re-lock,
-    // the mutex without waiting), and report no loaded builtins during init.
+    // lookup builtins and initializing the builtin functions calls the
+    // compiler. To avoid deadlocking, we use a recursive mutex (same thread can
+    // re-lock, the mutex without waiting), and report no loaded builtins during
+    // init.
     std::lock_guard<std::recursive_mutex> guard(mutex);
-    if(state == INTIIALIZING) {
+    if (state == INTIIALIZING) {
       return empty;
     } else if (state == UNINITIALIZED) {
       state = INTIIALIZING;
@@ -53,23 +55,24 @@ struct BuiltinFunctionRegistry {
     }
     JIT_ASSERT(state == INITIALIZED);
     auto it = builtins_by_name.find(name);
-    if(it == builtins_by_name.end())
+    if (it == builtins_by_name.end())
       return empty;
     return it->second;
   }
-private:
+
+ private:
   void loadSource(const std::string& source) {
     auto module = std::make_shared<script::Module>();
     defineMethodsInModule(
         module, source, script::nativeResolver, /*self=*/nullptr);
     modules.push_back(module);
     for (auto& method : module->get_methods()) {
-      builtins_by_name[Symbol::fromQualString("aten::" + method.key())].push_back(
-          method->get());
+      builtins_by_name[Symbol::fromQualString("aten::" + method.key())]
+          .push_back(method->get());
     }
   }
   void loadBuiltinFunctions() {
-    for(auto scalar : {"float", "int"}) {
+    for (auto scalar : {"float", "int"}) {
       TemplateEnv env;
       env.s("Scalar", scalar);
       loadSource(scalar_operators_source.format(env));
@@ -77,13 +80,13 @@ private:
 
     using str_pair = std::pair<std::string, std::string>;
     const std::vector<str_pair> name_len = {
-      str_pair("single", "1"),
-      str_pair("pair", "2"),
-      str_pair("triple", "3"),
-      str_pair("quadruple", "4"),
+        str_pair("single", "1"),
+        str_pair("pair", "2"),
+        str_pair("triple", "3"),
+        str_pair("quadruple", "4"),
     };
-    for(auto scalar: {"float", "int"}) {
-      for (auto pair: name_len) {
+    for (auto scalar : {"float", "int"}) {
+      for (auto pair : name_len) {
         TemplateEnv env;
         env.s("Scalar", scalar);
         env.s("name", pair.first);
@@ -92,7 +95,7 @@ private:
       }
     }
   }
-  enum {UNINITIALIZED, INTIIALIZING, INITIALIZED} state = UNINITIALIZED;
+  enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED;
   std::recursive_mutex mutex;
   std::vector<std::shared_ptr<Module>> modules;
   std::unordered_map<Symbol, std::vector<Method*>> builtins_by_name;
@@ -103,4 +106,6 @@ TORCH_API const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name) {
   return registry.getAllBuiltinFunctionsFor(name);
 }
 
-}}}
+} // namespace script
+} // namespace jit
+} // namespace torch
index c8d50bb..42e15e7 100644 (file)
@@ -3,11 +3,12 @@
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/script/module.h>
 
-namespace torch { namespace jit { namespace script {
-
+namespace torch {
+namespace jit {
+namespace script {
 
 TORCH_API const std::vector<Method*>& getAllBuiltinFunctionsFor(Symbol name);
 
-
-
-}}}
+}
+} // namespace jit
+} // namespace torch
index b4100e0..f2636d9 100644 (file)
@@ -1,17 +1,16 @@
-#include <torch/csrc/jit/script/compiler.h>
-#include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/script/final_returns.h>
-#include <torch/csrc/jit/passes/lower_tuples.h>
-#include <torch/csrc/jit/script/type_parser.h>
-#include <torch/csrc/jit/passes/constant_pooling.h>
-#include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/hooks_for_testing.h>
 #include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/final_returns.h>
 #include <torch/csrc/jit/script/parser.h>
-#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/script/schema_matching.h>
+#include <torch/csrc/jit/script/type_parser.h>
 #include <torch/csrc/utils/object_ptr.h>
-#include <torch/csrc/jit/operator.h>
-#include <torch/csrc/jit/hooks_for_testing.h>
 
 #include <torch/csrc/jit/constants.h>
 
@@ -31,7 +30,7 @@ using AttributeMap = std::unordered_map<std::string, Const>;
 using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
 
 static Value* asSimple(const SugaredValuePtr& value) {
-  if(SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
+  if (SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
     return sv->getValue();
   }
   return nullptr;
@@ -54,9 +53,8 @@ static bool meaningfulName(const std::string& name) {
 }
 
 // Auxiliary data structure for desugaring variable binding into our always
-// explicitly scoped language as we descend down
-// nested control structures in the frontend (which themselves don't introduce
-// scopes)
+// explicitly scoped language as we descend down nested control structures in
+// the frontend (which themselves don't introduce scopes)
 //
 // The algorithm is roughly as follows:
 // 1) While emitting a block within a control operator, add inputs and outputs
@@ -81,10 +79,17 @@ static bool meaningfulName(const std::string& name) {
 //      the IR API, but for now we choose to pessimisitically create inputs and
 //      delete unnecessary ones later with replaceAllusesWith().
 struct Environment {
-  Environment(Method & method, Resolver resolver, Block* b, std::shared_ptr<Environment> next = nullptr)
-      : method(method), resolver(std::move(resolver)), b(b), next(std::move(next)) {}
+  Environment(
+      Method& method,
+      Resolver resolver,
+      Block* b,
+      std::shared_ptr<Environment> next = nullptr)
+      : method(method),
+        resolver(std::move(resolver)),
+        b(b),
+        next(std::move(next)) {}
 
-  Method & method;
+  Method& method;
   Resolver resolver;
   std::vector<std::string> captured_inputs;
   std::unordered_map<std::string, std::string> error_messages;
@@ -94,7 +99,7 @@ struct Environment {
 
   // set type error in the lowest environment. if the variable is used after an
   // error has been set, then we will use the more informative error message
-  void setVariableTypeError(const std::string& name, const std::string &msg) {
+  void setVariableTypeError(const std::string& name, const std::stringmsg) {
     auto runner = this;
     while (runner->next) {
       runner = runner->next.get();
@@ -130,7 +135,7 @@ struct Environment {
 
   SugaredValuePtr findInAnyFrame(const std::string& name) {
     for (auto runner = this; runner; runner = runner->next.get()) {
-      if(auto r = runner->findInThisFrame(name)) {
+      if (auto r = runner->findInThisFrame(name)) {
         return r;
       }
     }
@@ -146,15 +151,17 @@ struct Environment {
     // this ensures consistency of the order of loop-carried dependencies
     // even when the use in the loop is in a different order
     size_t insert_pos = 0;
-    while (insert_pos < captured_inputs.size() && name > captured_inputs[insert_pos]) {
+    while (insert_pos < captured_inputs.size() &&
+           name > captured_inputs[insert_pos]) {
       insert_pos++;
     }
     captured_inputs.insert(captured_inputs.begin() + insert_pos, name);
 
     // Create the input
     const size_t loop_carried_block_inputs_offset = 1;
-    Value* new_input = b->insertInput(loop_carried_block_inputs_offset + insert_pos)
-                           ->setType(orig->type());
+    Value* new_input =
+        b->insertInput(loop_carried_block_inputs_offset + insert_pos)
+            ->setType(orig->type());
 
     // Associate this name with this value
     auto sv = std::make_shared<SimpleValue>(new_input);
@@ -163,14 +170,17 @@ struct Environment {
     return sv;
   }
 
-  SugaredValuePtr createCapturedInputIfNeeded(const SourceRange& loc, const std::string& ident) {
+  SugaredValuePtr createCapturedInputIfNeeded(
+      const SourceRange& loc,
+      const std::string& ident) {
     auto in_frame = findInThisFrame(ident);
     if (in_frame) {
       return in_frame;
     }
 
     // recursively handles the case where parent blocks are also loops
-    auto from_parent = next ? next->createCapturedInputIfNeeded(loc, ident) : nullptr;
+    auto from_parent =
+        next ? next->createCapturedInputIfNeeded(loc, ident) : nullptr;
 
     // recursively create the captured input if it is the loop block
     if (from_parent && getBlockOwningKind() == prim::Loop) {
@@ -195,13 +205,16 @@ struct Environment {
     setSugaredVar(loc, name, std::make_shared<SimpleValue>(value));
   }
 
-  void setSugaredVar(const SourceRange& loc, const std::string& name, SugaredValuePtr value) {
+  void setSugaredVar(
+      const SourceRange& loc,
+      const std::string& name,
+      SugaredValuePtr value) {
     Value* as_simple_value = asSimple(value);
     if (as_simple_value && !as_simple_value->hasUniqueName() &&
         meaningfulName(name) &&
-        // note: if the value wasn't defined in this block, we might be giving a name
-        // only used inside this block to a value outside of this. this is not
-        // normally helpful for debugging and causes import/export jitter.
+        // note: if the value wasn't defined in this block, we might be giving a
+        // name only used inside this block to a value outside of this. this is
+        // not normally helpful for debugging and causes import/export jitter.
         as_simple_value->node()->owningBlock() == block()) {
       as_simple_value->setUniqueName(name);
     }
@@ -212,15 +225,19 @@ struct Environment {
     //   a = ..
     // requires 'a' to be first-class in the graph since its value depends on
     // control flow
-    if(auto parent = findInParentFrame(name)) {
-      if(!as_simple_value) {
-        throw ErrorReport(loc) << "Cannot re-assign '" << name << "' to a value of type " << value->kind() <<
-       " because " << name << " is not a first-class value.  Only reassignments to first-class values are allowed";
+    if (auto parent = findInParentFrame(name)) {
+      if (!as_simple_value) {
+        throw ErrorReport(loc)
+            << "Cannot re-assign '" << name << "' to a value of type "
+            << value->kind() << " because " << name
+            << " is not a first-class value.  Only reassignments to first-class values are allowed";
       }
       Value* simple_parent = asSimple(parent);
-      if(!simple_parent) {
-        throw ErrorReport(loc) << "Cannot re-assign '" << name << "' because it has type " << value->kind() <<
-       " and " << name << " is not a first-class value.  Only reassignments to first-class values are allowed";
+      if (!simple_parent) {
+        throw ErrorReport(loc)
+            << "Cannot re-assign '" << name << "' because it has type "
+            << value->kind() << " and " << name
+            << " is not a first-class value.  Only reassignments to first-class values are allowed";
       }
       if (!as_simple_value->type()->isSubtypeOf(
               unshapedType(simple_parent->type()))) {
@@ -245,35 +262,39 @@ struct Environment {
     value_table[name] = std::move(value);
   }
 
-  SugaredValuePtr getSugaredVar(const Ident& ident, bool required=true) {
+  SugaredValuePtr getSugaredVar(const Ident& ident, bool required = true) {
     return getSugaredVar(ident.name(), ident.range());
   }
   Value* getVar(const Ident& ident) {
     return getSugaredVar(ident)->asValue(ident.range(), method);
   }
 
-  SugaredValuePtr getSugaredVar(const std::string& ident, const SourceRange& range, bool required=true) {
+  SugaredValuePtr getSugaredVar(
+      const std::string& ident,
+      const SourceRange& range,
+      bool required = true) {
     auto retval = createCapturedInputIfNeeded(range, ident);
 
-    if(!retval) {
+    if (!retval) {
       static std::unordered_map<std::string, SugaredValuePtr> globals = {
-        {"print", std::make_shared<PrintValue>()},
-        {"float", std::make_shared<CastValue>(FloatType::get(), prim::Float)},
-        {"int", std::make_shared<CastValue>(IntType::get(), prim::Int)},
-        {"bool", std::make_shared<CastValue>(BoolType::get(), prim::Bool)},
-        {"getattr", std::make_shared<GetAttrValue>()},
-        {"isinstance", std::make_shared<IsInstanceValue>()},
-        // todo(zach): remove when we can correctly export torch.full via ONNX
-        // or we have implicit conversion that can convert numbers to tensors
-        {"_to_tensor", std::make_shared<CastValue>(DynamicType::get(), prim::NumToTensor)},
-        {"len", std::make_shared<BuiltinFunction>(aten::len, at::nullopt)},
+          {"print", std::make_shared<PrintValue>()},
+          {"float", std::make_shared<CastValue>(FloatType::get(), prim::Float)},
+          {"int", std::make_shared<CastValue>(IntType::get(), prim::Int)},
+          {"bool", std::make_shared<CastValue>(BoolType::get(), prim::Bool)},
+          {"getattr", std::make_shared<GetAttrValue>()},
+          {"isinstance", std::make_shared<IsInstanceValue>()},
+          // todo(zach): remove when we can correctly export torch.full via ONNX
+          // or we have implicit conversion that can convert numbers to tensors
+          {"_to_tensor",
+           std::make_shared<CastValue>(DynamicType::get(), prim::NumToTensor)},
+          {"len", std::make_shared<BuiltinFunction>(aten::len, at::nullopt)},
       };
       auto it = globals.find(ident);
-      if(it != globals.end())
+      if (it != globals.end())
         retval = it->second;
     }
 
-    if(!retval) {
+    if (!retval) {
       retval = resolver(ident, method, range);
     }
 
@@ -305,9 +326,9 @@ struct Environment {
     // captured_inputs: lcd0, lcd1, ...
     JIT_ASSERT(b->inputs().size() == b->outputs().size());
     JIT_ASSERT(b->inputs().size() == captured_inputs.size() + 1);
-    for(size_t i = b->inputs().size() - 1; i > 0; i--) {
+    for (size_t i = b->inputs().size() - 1; i > 0; i--) {
       // nothing changed along this loop
-      if(b->inputs()[i] == b->outputs()[i]) {
+      if (b->inputs()[i] == b->outputs()[i]) {
         auto name = captured_inputs[i - 1];
         Value* orig = findInParentFrame(name)->asValue(loc, method);
         b->inputs()[i]->replaceAllUsesWith(orig);
@@ -319,18 +340,22 @@ struct Environment {
   }
   std::vector<std::string> definedVariables() {
     std::vector<std::string> result;
-    for(auto & kv : value_table) {
+    for (auto& kv : value_table) {
       result.push_back(kv.first);
     }
     return result;
   }
-private:
+
+ private:
   ValueTable value_table;
 };
 
-template<class T>
-static Value* materializeConstant(T val, Graph& graph,
-    const SourceRange& r, std::unordered_map<T, Value*>& map) {
+template <class T>
+static Value* materializeConstant(
+    T val,
+    Graph& graph,
+    const SourceRange& r,
+    std::unordered_map<T, Value*>& map) {
   auto existing_constant = map.find(val);
   if (existing_constant != map.end()) {
     return existing_constant->second;
@@ -344,9 +369,9 @@ static Value* materializeConstant(T val, Graph& graph,
 }
 
 static Value* ensureInt(const SourceRange& range, Value* v) {
-  if(!v->type()->isSubtypeOf(IntType::get())) {
-    throw ErrorReport(range) << "expected a int but found a "
-                             << v->type()->str();
+  if (!v->type()->isSubtypeOf(IntType::get())) {
+    throw ErrorReport(range)
+        << "expected a int but found a " << v->type()->str();
   }
   return v;
 }
@@ -357,8 +382,8 @@ std::shared_ptr<SugaredValue> BuiltinFunction::call(
     at::ArrayRef<NamedValue> inputs,
     at::ArrayRef<NamedValue> attributes,
     size_t n_binders) {
-  return std::make_shared<SimpleValue>(emitBuiltinCall(
-      loc, *m.graph(), symbol, self, inputs, attributes, true));
+  return std::make_shared<SimpleValue>(
+      emitBuiltinCall(loc, *m.graph(), symbol, self, inputs, attributes, true));
 }
 
 inline bool isSupportedListElementType(const TypePtr& type) {
@@ -380,25 +405,26 @@ struct to_ir {
       Resolver resolver_,
       const SugaredValuePtr& self,
       Method& method) // method being constructed
-      : method(method)
-      , graph(method.graph())
-      , resolver(std::move(resolver_))
-      , environment_stack(nullptr) {
+      : method(method),
+        graph(method.graph()),
+        resolver(std::move(resolver_)),
+        environment_stack(nullptr) {
     JIT_ASSERT(resolver);
     pushFrame(graph->block(), /*starts_def=*/true);
 
-    // Type annotations exclude explicitly typing the "self" parameter, so in the
-    // case that this is a method with self we expect one fewer parameter annotation
-    // than the number of parameters this Def takes.
+    // Type annotations exclude explicitly typing the "self" parameter, so in
+    // the case that this is a method with self we expect one fewer parameter
+    // annotation than the number of parameters this Def takes.
     if (self && def.decl().params().size() == 0) {
-      throw ErrorReport(def.decl().params().range()) << "methods must have a self argument";
+      throw ErrorReport(def.decl().params().range())
+          << "methods must have a self argument";
     }
 
     method.setSchema(emitDef(def, self, graph->block()));
     runCleanupPasses(graph);
   }
 
-private:
+ private:
   Method& method;
   std::shared_ptr<Graph> graph;
   Resolver resolver;
@@ -410,16 +436,17 @@ private:
   std::shared_ptr<Environment> environment_stack;
   std::vector<DefContext> def_stack_;
 
-  void pushFrame(Block * b, bool starts_def=false) {
+  void pushFrame(Block* b, bool starts_def = false) {
     if (starts_def) {
       def_stack_.emplace_back();
     }
-    environment_stack = std::make_shared<Environment>(method, resolver, b, environment_stack);
+    environment_stack =
+        std::make_shared<Environment>(method, resolver, b, environment_stack);
   }
-  std::shared_ptr<Environment> popFrame(bool ends_def=false) {
+  std::shared_ptr<Environment> popFrame(bool ends_def = false) {
     auto old_frame = environment_stack;
     environment_stack = environment_stack->next;
-    if(ends_def) {
+    if (ends_def) {
       def_stack_.pop_back();
     }
     return old_frame;
@@ -431,13 +458,16 @@ private:
     ConstantPooling(to_clean);
   }
 
-  FunctionSchema emitDef(const Def& def, const SugaredValuePtr& self, Block* block) {
+  FunctionSchema emitDef(
+      const Def& def,
+      const SugaredValuePtr& self,
+      Block* block) {
     auto schema = extractSchemaFromDef(def, self);
     if (schema.returns().size() == 1) {
       def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
     }
-    std::vector<Argument> arguments = emitFormalArguments(def, self, schema, block);
-
+    std::vector<Argument> arguments =
+        emitFormalArguments(def, self, schema, block);
 
     // body
     auto stmts_list = moveAllReturnsToEnd(def.statements());
@@ -446,24 +476,29 @@ private:
     return {def.name().name(), std::move(arguments), std::move(returns)};
   }
 
-  std::vector<IValue> evaluateDefaults(const SourceRange& r, const std::vector<Expr>& default_types, const std::vector<Expr>& default_exprs) {
+  std::vector<IValue> evaluateDefaults(
+      const SourceRange& r,
+      const std::vector<Expr>& default_types,
+      const std::vector<Expr>& default_exprs) {
     std::vector<IValue> default_values;
     if (default_exprs.empty())
       return default_values;
     // To evaluate the default expressions, we create a graph with no inputs,
     // and whose returns are the default values we need.
-    // We then run constant prop on this graph and check the results are constant.
-    // This approach avoids having to have separate handling of default arguments
-    // from standard expressions by piecing together existing machinery for
-    // graph generation, constant propgation, and constant extraction.
+    // We then run constant prop on this graph and check the results are
+    // constant. This approach avoids having to have separate handling of
+    // default arguments from standard expressions by piecing together existing
+    // machinery for graph generation, constant propgation, and constant
+    // extraction.
     auto tuple_type = Subscript::create(
         r,
         Var::create(r, Ident::create(r, "Tuple")),
         List<Expr>::create(r, default_types));
-    auto blank_decl =
-        Decl::create(r, List<Param>::create(r, {}), Maybe<Expr>::create(r, tuple_type));
+    auto blank_decl = Decl::create(
+        r, List<Param>::create(r, {}), Maybe<Expr>::create(r, tuple_type));
 
-    auto tuple_expr = TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
+    auto tuple_expr =
+        TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
     auto ret = Return::create(r, tuple_expr);
     auto def = Def::create(
         r,
@@ -477,7 +512,9 @@ private:
     return stack.at(0).toTuple()->elements();
   }
 
-  std::vector<Argument> parseArgsFromDecl(const Decl& decl, const SugaredValuePtr& self) {
+  std::vector<Argument> parseArgsFromDecl(
+      const Decl& decl,
+      const SugaredValuePtr& self) {
     auto params_begin = decl.params().begin();
     auto params_end = decl.params().end();
     if (self)
@@ -495,7 +532,8 @@ private:
         default_exprs.emplace_back(def.get());
       }
     }
-    auto default_values = evaluateDefaults(decl.range(), default_types, default_exprs);
+    auto default_values =
+        evaluateDefaults(decl.range(), default_types, default_exprs);
 
     auto defaults_it = default_values.begin();
     for (auto it = params_begin; it != params_end; ++it) {
@@ -504,7 +542,7 @@ private:
       TypePtr type;
       c10::optional<int32_t> N;
 
-      //BroadcastList list can only appear at the argument level
+      // BroadcastList list can only appear at the argument level
       if (auto maybe_broad_list = parseBroadcastList(decl_arg.type())) {
         type = maybe_broad_list->first;
         N = maybe_broad_list->second;
@@ -530,13 +568,14 @@ private:
   std::vector<Argument> parseReturnFromDecl(const Decl& decl) {
     // we represent no annoation on a return type as having no values in the
     // schema's return() list
-    // in emitReturn we take the actual return value to be the value of the return
-    // statement if no one was provided here
-    if(!decl.return_type().present())
+    // in emitReturn we take the actual return value to be the value of the
+    // return statement if no one was provided here
+    if (!decl.return_type().present())
       return {};
 
     if (parseBroadcastList(decl.return_type().get()))
-      throw ErrorReport(decl.return_type().range()) << "Broadcastable lists cannot appear as a return type";
+      throw ErrorReport(decl.return_type().range())
+          << "Broadcastable lists cannot appear as a return type";
     auto parsed_type = parseTypeFromExpr(decl.return_type().get());
     return {Argument(
         "",
@@ -545,35 +584,44 @@ private:
         /*default_value =*/c10::nullopt,
         /*kwarg_only =*/false)};
   }
-  FunctionSchema extractSchemaFromDef(const Def &def, const SugaredValuePtr& self) {
-      auto name = def.name().name();
-      std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
-      std::vector<Argument> returns = parseReturnFromDecl(def.decl());
-      return FunctionSchema(name, std::move(args), std::move(returns), false, false);
+  FunctionSchema extractSchemaFromDef(
+      const Def& def,
+      const SugaredValuePtr& self) {
+    auto name = def.name().name();
+    std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
+    std::vector<Argument> returns = parseReturnFromDecl(def.decl());
+    return FunctionSchema(
+        name, std::move(args), std::move(returns), false, false);
   }
 
-  std::vector<Argument> emitFormalArguments(const Def& def, const SugaredValuePtr& self, const FunctionSchema& schema, Block* block) {
+  std::vector<Argument> emitFormalArguments(
+      const Def& def,
+      const SugaredValuePtr& self,
+      const FunctionSchema& schema,
+      Block* block) {
     std::vector<Argument> arguments; // for schema
     // inputs
     auto it = def.decl().params().begin();
     auto end = def.decl().params().end();
-    auto expected_annotation_size = self ? def.decl().params().size() - 1 : def.decl().params().size();
+    auto expected_annotation_size =
+        self ? def.decl().params().size() - 1 : def.decl().params().size();
     if (schema.arguments().size() != expected_annotation_size) {
-      throw ErrorReport(def.decl().params().range()) << "Number of type annotations for"
-        << " function parameters (" << schema.arguments().size() << ")"
-        << " does not match the number of parameters on the function ("
-        << expected_annotation_size << ")!";
+      throw ErrorReport(def.decl().params().range())
+          << "Number of type annotations for"
+          << " function parameters (" << schema.arguments().size() << ")"
+          << " does not match the number of parameters on the function ("
+          << expected_annotation_size << ")!";
     }
-    if(self) {
+    if (self) {
       JIT_ASSERT(it != end);
       environment_stack->setSugaredVar(def.range(), (*it).ident().name(), self);
       ++it;
     }
     size_t arg_annotation_idx = 0;
-    for(;it != end; ++it) {
+    for (; it != end; ++it) {
       auto& name = (*it).ident().name();
       // Add the input to the graph
-      Value *new_input = block->addInput();
+      Valuenew_input = block->addInput();
       if (meaningfulName(name)) {
         new_input->setUniqueName(name);
       }
@@ -586,7 +634,10 @@ private:
     return arguments;
   }
 
-  Argument emitOutput(const SourceRange& range, const FunctionSchema& schema, Block* block) {
+  Argument emitOutput(
+      const SourceRange& range,
+      const FunctionSchema& schema,
+      Block* block) {
     // rewrites ensure there is always a return statement in program
     JIT_ASSERT(def_stack_.back().merged_return_type_);
     // outputs
@@ -599,48 +650,59 @@ private:
     return emitStatements(statements.begin(), statements.end());
   }
   std::pair<std::shared_ptr<Graph>, Value*> lambdaLift(Block* block) {
-      auto subgraph = std::make_shared<Graph>();
-      // note: type is set later on pack_context and context when we know it
-      Node* pack_context = graph->insertNode(graph->create(prim::TupleConstruct, {}, 1));
-      Value* context = subgraph->addInput("context");
-      // cannot use createTupleUnpack because the type is not known yet
-      Node* unpack_context = subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
-
-      std::unordered_map<Value*, Value*> captures;
-      auto env = [&](Value* v) -> Value* {
-        auto it = captures.find(v);
-        if (it != captures.end()) {
-            return it->second;
-        }
-        pack_context->addInput(v);
-        Value* r = unpack_context->addOutput()->copyMetadata(v);
-        captures[v] = r;
-        return r;
-      };
-      subgraph->block()->cloneFrom(block, env);
-      auto context_type = TupleType::create(
-          fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
-      pack_context->output()->setType(context_type);
-      context->setType(context_type);
-      return std::make_pair(std::move(subgraph), pack_context->output());
+    auto subgraph = std::make_shared<Graph>();
+    // note: type is set later on pack_context and context when we know it
+    Node* pack_context =
+        graph->insertNode(graph->create(prim::TupleConstruct, {}, 1));
+    Value* context = subgraph->addInput("context");
+    // cannot use createTupleUnpack because the type is not known yet
+    Node* unpack_context =
+        subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
+
+    std::unordered_map<Value*, Value*> captures;
+    auto env = [&](Value* v) -> Value* {
+      auto it = captures.find(v);
+      if (it != captures.end()) {
+        return it->second;
+      }
+      pack_context->addInput(v);
+      Value* r = unpack_context->addOutput()->copyMetadata(v);
+      captures[v] = r;
+      return r;
+    };
+    subgraph->block()->cloneFrom(block, env);
+    auto context_type = TupleType::create(
+        fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
+    pack_context->output()->setType(context_type);
+    context->setType(context_type);
+    return std::make_pair(std::move(subgraph), pack_context->output());
   }
   // XXX - right now closures are used _only_ for defining gradients internally
   // There are several unfinished aspects that make them unusable generally
-  // 1. We do not have a type, ivalue, operator to represent prim::Function, so closure_node has type None
+  // 1. We do not have a type, ivalue, operator to represent prim::Function, so
+  // closure_node has type None
   //    and any graphs that contain it cannot be run
-  // 2. There is no export logic for it yet, so it cannot be exported/python_printed
-  // 3. There is nothing preventing the assignment of already existing variables inside the closures
+  // 2. There is no export logic for it yet, so it cannot be
+  // exported/python_printed
+  // 3. There is nothing preventing the assignment of already existing variables
+  // inside the closures
   //    the changes to those variables will just get forgotten.
   // 4. There is no parsing support in frontend.py, this is intentional since it
   //    prevents people from accidentally using this feature.
   void emitClosure(const Def& def) {
     Node* closure_node = graph->insertNode(graph->create(prim::Function, 1));
-    closure_node->output()->setType(NoneType::get()); //it is not a real thing yet, so just say the type is none.
+    closure_node->output()->setType(
+        NoneType::get()); // it is not a real thing yet, so just say the type is
+                          // none.
     Block* block = closure_node->addBlock();
     {
       WithInsertPoint guard(block);
       pushFrame(block, /*starts_def=*/true);
-      emitDef(def, nullptr, block); //ignore schema return, we just wont use it for now since we never create a Method for the closure
+      emitDef(
+          def,
+          nullptr,
+          block); // ignore schema return, we just wont use it for now since we
+                  // never create a Method for the closure
       popFrame(/*ends_def=*/true);
     }
     std::shared_ptr<Graph> subgraph;
@@ -649,7 +711,9 @@ private:
     runCleanupPasses(subgraph);
     closure_node->eraseBlock(0);
     closure_node->g_(attr::Subgraph, std::move(subgraph));
-    auto tup = graph->insertNode(graph->createTuple({closure_node->output(), context}))->output();
+    auto tup =
+        graph->insertNode(graph->createTuple({closure_node->output(), context}))
+            ->output();
     environment_stack->setVar(def.name().range(), def.name().name(), tup);
   }
 
@@ -658,24 +722,31 @@ private:
     TypePtr result_type = def_stack_.back().declared_return_type_;
     // result type is annotated, every return must convert to that type
     if (result_type) {
-      // this guard skips implicit conversion from None -> Tensor for the return type.
-      // otherwise forgetting a return a function returning a tensor will cause a None to be
-      // converted to a tensor.
-      if (!(result_type->isSubtypeOf(DynamicType::get()) && result->type()->isSubtypeOf(NoneType::get()))) {
+      // this guard skips implicit conversion from None -> Tensor for the return
+      // type. otherwise forgetting a return a function returning a tensor will
+      // cause a None to be converted to a tensor.
+      if (!(result_type->isSubtypeOf(DynamicType::get()) &&
+            result->type()->isSubtypeOf(NoneType::get()))) {
         result = tryConvertToType(
-            stmt.range(), *graph, result_type, result, /*allow_conversions=*/true);
+            stmt.range(),
+            *graph,
+            result_type,
+            result,
+            /*allow_conversions=*/true);
       }
 
       if (!result->type()->isSubtypeOf(result_type)) {
-        throw ErrorReport(stmt.range()) << "Return value was annotated as having type " << result_type->python_str()
-          << " but is actually of type " << result->type()->python_str();
+        throw ErrorReport(stmt.range())
+            << "Return value was annotated as having type "
+            << result_type->python_str() << " but is actually of type "
+            << result->type()->python_str();
       }
     } else {
       result_type = def_stack_.back().merged_return_type_;
       if (!result_type) {
         result_type = result->type();
       }
-      if(!unifyTypes(result_type, result->type())) {
+      if (!unifyTypes(result_type, result->type())) {
         throw ErrorReport(stmt.range())
             << "Previous return statement returned a value of type "
             << result_type->python_str()
@@ -688,7 +759,9 @@ private:
     environment_stack->setVar(stmt.range(), "$return", result);
   }
 
-  void emitStatements(List<Stmt>::const_iterator begin, List<Stmt>::const_iterator end) {
+  void emitStatements(
+      List<Stmt>::const_iterator begin,
+      List<Stmt>::const_iterator end) {
     for (; begin != end; ++begin) {
       auto stmt = *begin;
       switch (stmt.kind()) {
@@ -710,14 +783,14 @@ private:
         case TK_GLOBAL:
           for (auto ident : Global(stmt).names()) {
             const auto& name = Ident(ident).name();
-            environment_stack->setVar(ident.range(), name, graph->addInput(name));
+            environment_stack->setVar(
+                ident.range(), name, graph->addInput(name));
           }
           break;
         case TK_EXPR_STMT: {
           auto expr = ExprStmt(stmt).expr();
           emitSugaredExpr(expr, 0);
-        }
-        break;
+        } break;
         case TK_RAISE:
           emitRaise(Raise(stmt).range());
           break;
@@ -749,36 +822,27 @@ private:
     return popFrame();
   }
 
-  Node* create(Symbol kind, const SourceRange& loc,  size_t n_outputs) {
-    return graph
-             ->create(kind, n_outputs)
-             ->setSourceLocation(std::make_shared<SourceRange>(loc));
+  Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
+    return graph->create(kind, n_outputs)
+        ->setSourceLocation(std::make_shared<SourceRange>(loc));
   }
 
   Value* emitTernaryIf(const TernaryIf& expr) {
     Value* cond_value = emitCond(expr.cond());
-    auto true_expr = [&] {
-      return emitExpr(expr.true_expr());
-    };
-    auto false_expr  = [&] {
-      return emitExpr(expr.false_expr());
-    };
+    auto true_expr = [&] { return emitExpr(expr.true_expr()); };
+    auto false_expr = [&] { return emitExpr(expr.false_expr()); };
     return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
   }
 
   Value* emitShortCircuitIf(
       const SourceRange& loc,
-      const TreeRef & first_expr,
-      const TreeRef & second_expr,
+      const TreeRef& first_expr,
+      const TreeRef& second_expr,
       bool is_or) {
-    Value * first_value = emitCond(Expr(first_expr));
+    Value* first_value = emitCond(Expr(first_expr));
 
-    auto get_first_expr = [first_value] {
-      return first_value;
-    };
-    auto get_second_expr = [&] {
-      return emitCond(Expr(second_expr));
-    };
+    auto get_first_expr = [first_value] { return first_value; };
+    auto get_second_expr = [&] { return emitCond(Expr(second_expr)); };
 
     // if this is an OR, eval second expression if first expr is False.
     // If this is an AND, eval second expression if first expr is True
@@ -789,8 +853,11 @@ private:
     }
   }
 
-  Value* emitIfExpr(const SourceRange& range, Value * cond_value,
-      std::function<Value*()> true_expr,  std::function<Value*()> false_expr) {
+  Value* emitIfExpr(
+      const SourceRange& range,
+      Value* cond_value,
+      std::function<Value*()> true_expr,
+      std::function<Value*()> false_expr) {
     Node* n = graph->insertNode(create(prim::If, range, 0));
 
     n->addInput(cond_value);
@@ -871,17 +938,16 @@ private:
     //   a =
     // ... = a # OK, a is defined along all paths
 
-
-    //ordered set, because we want deterministic graph output
+    // ordered set, because we want deterministic graph output
     std::set<std::string> mutated_variables;
 
-    for(auto & v : save_true->definedVariables()) {
-      if(save_false->findInAnyFrame(v)) {
+    for (auto& v : save_true->definedVariables()) {
+      if (save_false->findInAnyFrame(v)) {
         mutated_variables.insert(v);
       }
     }
-    for(auto & v : save_false->definedVariables()) {
-      if(save_true->findInAnyFrame(v)) {
+    for (auto& v : save_false->definedVariables()) {
+      if (save_true->findInAnyFrame(v)) {
         mutated_variables.insert(v);
       }
     }
@@ -892,8 +958,8 @@ private:
       auto fv = save_false->getVar(x, stmt.range());
       auto unified = unifyTypes(tv->type(), fv->type());
 
-      // attempt to unify the types. we allow variables to be set to different types
-      // in each branch as long as that variable is not already in scope,
+      // attempt to unify the types. we allow variables to be set to different
+      // types in each branch as long as that variable is not already in scope,
       // or if that variable does not get used later. here, we save the error
       // so that the error message will be more informative in the case that is
       // used later. When a is accessed in (a + 1), the error will get printed
@@ -905,27 +971,31 @@ private:
       //
       if (!unified) {
         ErrorReport error(stmt);
-        error << "Type mismatch: " << x << " is set to type " << tv->type()->str() << " in the true branch"
-        << " and type " << fv->type()->str() << " in the false branch";
-        if (save_true->findInParentFrame(x) || save_false->findInParentFrame(x)) {
+        error << "Type mismatch: " << x << " is set to type "
+              << tv->type()->str() << " in the true branch"
+              << " and type " << fv->type()->str() << " in the false branch";
+        if (save_true->findInParentFrame(x) ||
+            save_false->findInParentFrame(x)) {
           throw error;
         } else {
           // error gets saved in the lowest environment because all
-          // variables are scoped to the function. doesn't matter if this accessed
-          // through save_true or save_false
+          // variables are scoped to the function. doesn't matter if this
+          // accessed through save_true or save_false
           save_true->setVariableTypeError(x, error.what());
           continue;
         }
       }
       true_block->registerOutput(tv);
       false_block->registerOutput(fv);
-      environment_stack->setVar(stmt.range(), x, n->addOutput()->setType(*unified));
+      environment_stack->setVar(
+          stmt.range(), x, n->addOutput()->setType(*unified));
     }
   }
 
   void emitIf(const If& stmt) {
-    // NOTE: emitIf checks on If stmt condition to see if the cond AST kind == is/is not,
-    // for such cases we do meta programming and disable emitting the corresponding branches
+    // NOTE: emitIf checks on If stmt condition to see if the cond AST kind ==
+    // is/is not, for such cases we do meta programming and disable emitting the
+    // corresponding branches
     Expr cond = stmt.cond();
 
     if (cond.kind() != TK_IS && cond.kind() != TK_ISNOT) {
@@ -934,16 +1004,19 @@ private:
       emitIfElseBlocks(cond_value, stmt);
       return;
     }
-    // meta programming on AST for is/is not cases and emit branches base on the possible output of cond
+    // meta programming on AST for is/is not cases and emit branches base on the
+    // possible output of cond
     auto cond_op = BinOp(cond);
     SugaredValuePtr lhs_val = emitSugaredExpr(cond_op.lhs(), 1);
     SugaredValuePtr rhs_val = emitSugaredExpr(cond_op.rhs(), 1);
 
-    List<Stmt> always_none_branch = cond.kind() == TK_IS? stmt.trueBranch(): stmt.falseBranch();
-    List<Stmt> never_none_branch = cond.kind() == TK_IS? stmt.falseBranch(): stmt.trueBranch();
+    List<Stmt> always_none_branch =
+        cond.kind() == TK_IS ? stmt.trueBranch() : stmt.falseBranch();
+    List<Stmt> never_none_branch =
+        cond.kind() == TK_IS ? stmt.falseBranch() : stmt.trueBranch();
 
-    auto lhs_none= lhs_val->isNone();
-    auto rhs_none= rhs_val->isNone();
+    auto lhs_none = lhs_val->isNone();
+    auto rhs_none = rhs_val->isNone();
 
     // Dispatch logic (A: ALWAYS, N: NEVER, M: MAYBE):
     //
@@ -954,12 +1027,12 @@ private:
     if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
       // None is/is not None: only emit the always_none_branch
       emitStatements(always_none_branch);
-    } else if ((lhs_none == ALWAYS && rhs_none == NEVER) ||
-        (lhs_none == NEVER && rhs_none == ALWAYS)){
+    } else if (
+        (lhs_none == ALWAYS && rhs_none == NEVER) ||
+        (lhs_none == NEVER && rhs_none == ALWAYS)) {
       // lhs_val/rhs_val with A/M: only emit never_none_branch
       emitStatements(never_none_branch);
-    }
-    else {
+    } else {
       // all other cases for lhs_val and rhs_val
       // emit the whole If stmt as usual, finish emitCond first
       auto lhs_range = cond_op.lhs().get()->range();
@@ -970,13 +1043,12 @@ private:
           *method.graph(),
           kind,
           c10::nullopt,
-          {lhs_val->asValue(lhs_range, method), rhs_val->asValue(rhs_range, method)},
+          {lhs_val->asValue(lhs_range, method),
+           rhs_val->asValue(rhs_range, method)},
           {},
           /*required=*/true);
       emitIfElseBlocks(cond_value, stmt);
-
     }
-
   }
 
   // *********************** Loop Operators ************************************
@@ -986,12 +1058,11 @@ private:
 
   // the format of the Loop instruction is:
   // loop_carried_outputs* = Loop(max_trip_count, start_condition,
-  // loop_carried_inputs*)
-  //                          block0(loop_counter, loop_carried_block*) {
-  //                             <body>
-  //                             -> (continue_condition,
-  //                             loop_carried_block_outputs*)
-  //                          }
+  //                              loop_carried_inputs*)
+  //                    block0(loop_counter, loop_carried_block*) {
+  //                       <body>
+  //                       -> (continue_condition, loop_carried_block_outputs*)
+  //                    }
   // all loop_carried_... lists are the same length and represent the value of
   // loop-carried variables whose definitions are updated as the loop executes
   // in a way that ensure single static assignment.
@@ -1010,8 +1081,11 @@ private:
         max_trip_count_val = ensureInt(
             max_trip_count->range(), emitExpr(max_trip_count.value()));
       } else {
-        max_trip_count_val =
-            materializeConstant(std::numeric_limits<int64_t>::max(), *graph, range, integral_constants);
+        max_trip_count_val = materializeConstant(
+            std::numeric_limits<int64_t>::max(),
+            *graph,
+            range,
+            integral_constants);
       }
       if (cond) {
         cond_val = emitCond(cond.value());
@@ -1022,12 +1096,14 @@ private:
     n->addInput(max_trip_count_val);
     n->addInput(cond_val);
     auto* body_block = n->addBlock();
-    Value* trip_count = body_block->addInput()->setType(IntType::get()); // Iteration num
+    Value* trip_count =
+        body_block->addInput()->setType(IntType::get()); // Iteration num
 
     {
       pushFrame(body_block);
       if (itr_ident) {
-        environment_stack->setVar(itr_ident->range(), itr_ident->name(), trip_count);
+        environment_stack->setVar(
+            itr_ident->range(), itr_ident->name(), trip_count);
       }
       WithInsertPoint guard(body_block);
       emitStatements(body);
@@ -1056,7 +1132,7 @@ private:
       body_frame->deleteExtraInputs(range);
 
       // register node inputs/outputs for the true loop carried deps,
-      for(size_t i = 0; i < body_frame->captured_inputs.size(); ++i) {
+      for (size_t i = 0; i < body_frame->captured_inputs.size(); ++i) {
         auto x = body_frame->captured_inputs[i];
         n->addInput(outer_frame->getVar(x, range));
         // body_block->inputs(): loop_counter, lcd0, lcd1, ...
@@ -1064,11 +1140,14 @@ private:
         auto typ = body_block->inputs()[i + 1]->type();
         outer_frame->setVar(range, x, n->addOutput()->setType(typ));
       }
-
     }
   }
 
-  void emitForRange(const SourceRange& range, const Ident& target, const List<Expr>& args, const List<Stmt>& body) {
+  void emitForRange(
+      const SourceRange& range,
+      const Ident& target,
+      const List<Expr>& args,
+      const List<Stmt>& body) {
     // TODO: start, stop, step loop
     if (args.size() != 1) {
       throw ErrorReport(range)
@@ -1088,11 +1167,13 @@ private:
           << "List of iterables is not supported currently.";
     }
     if (targets.size() != 1) {
-      throw ErrorReport(stmt) << "Iteration variable unpacking is not supported";
+      throw ErrorReport(stmt)
+          << "Iteration variable unpacking is not supported";
     }
 
     if (targets[0].kind() != TK_VAR) {
-      throw ErrorReport(targets[0]) << "unexpected expression in variable initialization of for loop";
+      throw ErrorReport(targets[0])
+          << "unexpected expression in variable initialization of for loop";
     }
     auto target = Var(targets[0]).name();
 
@@ -1103,25 +1184,27 @@ private:
       if (range_iterator.callee().kind() == TK_VAR) {
         Var var = Var(range_iterator.callee());
         if (var.name().name() == "range") {
-          return emitForRange(stmt.range(), target, range_iterator.inputs(), body);
+          return emitForRange(
+              stmt.range(), target, range_iterator.inputs(), body);
         }
       }
     }
 
-    // it isn't a range(<expr>) loop, treat it as a sugared value that maybe can be
-    // unrolled
+    // it isn't a range(<expr>) loop, treat it as a sugared value that maybe can
+    // be unrolled
     auto sv = emitSugaredExpr(itrs[0], 1);
     auto instances = sv->asTuple(stmt.range(), method);
     const std::string& target_name = target.name();
     pushFrame(environment_stack->block());
-    for(const auto& inst : instances) {
+    for (const auto& inst : instances) {
       environment_stack->setSugaredVar(itrs[0].range(), target_name, inst);
       emitStatements(body);
     }
 
-    for (const auto & n : environment_stack->definedVariables()) {
+    for (const auto& n : environment_stack->definedVariables()) {
       if (environment_stack->findInParentFrame(n)) {
-        environment_stack->next->setVar(stmt.range(), n, environment_stack->getVar(n, stmt.range()));
+        environment_stack->next->setVar(
+            stmt.range(), n, environment_stack->getVar(n, stmt.range()));
       }
     }
     popFrame();
@@ -1132,7 +1215,6 @@ private:
     emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
   }
 
-
   // Currently we do not support assigning exceptions to variables,
   // a = Exception("hi")
   // raise a
@@ -1159,22 +1241,21 @@ private:
     /* true_block =*/n->addBlock();
     auto* false_block = n->addBlock();
 
-    //if assert test is false throw exception
+    // if assert test is false throw exception
     pushFrame(false_block);
     WithInsertPoint guard(false_block);
     emitRaise(stmt.range());
     popFrame();
   }
 
-
   // Validate that the `lhs` Expr's in an assignment statement are valid. That
   // is:
   //
   // 1) All lhs Expr's are either Var or Starred nodes
   // 2) There is at most one Starred node in the lhs Expr
-  // 3) A Starred node can only appear when there is another non-Starred lhs Expr
-  //    Concretely this means that `*abc = func()` is illegal. Unpacking all
-  //    outputs into a tuple is covered by `abc = func()`.
+  // 3) A Starred node can only appear when there is another non-Starred lhs
+  //    Expr. Concretely this means that `*abc = func()` is illegal. Unpacking
+  //    all outputs into a tuple is covered by `abc = func()`.
   bool calcNumStarredUnpack(const List<Expr>& lhs, const SourceRange& r) {
     size_t num_normal_assign = 0;
     size_t num_starred = 0;
@@ -1196,8 +1277,8 @@ private:
 
     if (num_starred > 0 && num_normal_assign == 0) {
       throw ErrorReport(r) << "A Starred expression may only appear on the "
-                              << "lhs within the presence of another non-starred"
-                              << " expression.";
+                           << "lhs within the presence of another non-starred"
+                           << " expression.";
     }
 
     return num_starred;
@@ -1207,19 +1288,19 @@ private:
   // If the RHS is a tensor, return the corresponding ATen in-place op
   // If it's a list of scalars, then return the corresponding list augment op
   Symbol getAugOp(const AugAssign& stmt, bool isTensor) {
-      switch (stmt.aug_op()) {
-        case '+':
-          return isTensor ? aten::add_ : aten::add;
-        case '-':
-          return isTensor ? aten::sub_ : aten::sub;
-        case '/':
-          return isTensor ? aten::div_ : aten::div;
-        case '*':
-          return isTensor ? aten::mul_ : aten::mul;
-        default:
-          throw ErrorReport(stmt) << "Unknown augmented assignment: "
-                                  << kindToString(stmt.aug_op());
-      }
+    switch (stmt.aug_op()) {
+      case '+':
+        return isTensor ? aten::add_ : aten::add;
+      case '-':
+        return isTensor ? aten::sub_ : aten::sub;
+      case '/':
+        return isTensor ? aten::div_ : aten::div;
+      case '*':
+        return isTensor ? aten::mul_ : aten::mul;
+      default:
+        throw ErrorReport(stmt)
+            << "Unknown augmented assignment: " << kindToString(stmt.aug_op());
+    }
   }
 
   // Emit nodes for augmented assignments like `+=`
@@ -1257,8 +1338,11 @@ private:
   // in place op, and throw error for other unsupported types
   void emitAugAssignmentToSelectVar(const AugAssign& stmt) {
     const auto lhs = Select(stmt.lhs());
-    const auto lhsSugaredVar = environment_stack->getSugaredVar(Var(lhs.value()).name());
-    const auto lhsValue = lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())->asValue(lhs.range(), method);
+    const auto lhsSugaredVar =
+        environment_stack->getSugaredVar(Var(lhs.value()).name());
+    const auto lhsValue =
+        lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())
+            ->asValue(lhs.range(), method);
     if (lhsValue->type()->isSubtypeOf(DynamicType::get())) {
       // for module parameter/buffer assignment, only consider tensor types,
       // emit the corresponding in-place op
@@ -1274,9 +1358,9 @@ private:
           /*required=*/true);
 
     } else {
-        throw ErrorReport(stmt.lhs())
-            << "left-hand side of augmented assignment to module "
-            << "parameters/buffers can only be tensor types";
+      throw ErrorReport(stmt.lhs())
+          << "left-hand side of augmented assignment to module "
+          << "parameters/buffers can only be tensor types";
     }
   }
 
@@ -1340,8 +1424,10 @@ private:
       } else {
         // Special case: we tried to do "advanced indexing". Lower this expr
         // into `index` and `index_put_` ops
-        const auto indices = graph->insertNode(
-          graph->createList(DynamicType::get(), tensorIndices))->output();
+        const auto indices = graph
+                                 ->insertNode(graph->createList(
+                                     DynamicType::get(), tensorIndices))
+                                 ->output();
         const auto indexed =
             graph->insert(aten::index, {slicedArg, indices}, {}, stmt.range());
         const auto augmented = emitBuiltinCall(
@@ -1397,8 +1483,7 @@ private:
       const SourceRange& stmtRange,
       const Subscript& lhs,
       const Expr& rhs) {
-    emitSubscriptAssign(
-        stmtRange, lhs, NamedValue(rhs.range(), emitExpr(rhs)));
+    emitSubscriptAssign(stmtRange, lhs, NamedValue(rhs.range(), emitExpr(rhs)));
   }
 
   void emitSubscriptAssign(
@@ -1428,15 +1513,17 @@ private:
       } else {
         // Special case: we tried to do "advanced indexing" with a tensor.
         // Dispatch to `aten::index_put_`.
-        const auto indices = graph->insertNode(
-          graph->createList(DynamicType::get(), tensorIndices))->output();
+        const auto indices = graph
+                                 ->insertNode(graph->createList(
+                                     DynamicType::get(), tensorIndices))
+                                 ->output();
 
         graph->insert(
             aten::index_put_, {slicedArg, indices, rhs}, {}, stmtRange);
       }
 
-    // Otherwise, this is a list. Dispatch to aten::_set_item to both select and
-    // assign
+      // Otherwise, this is a list. Dispatch to aten::_set_item to both select
+      // and assign
     } else {
       const auto subscript = lhs.subscript_exprs();
       if (subscript.size() != 1 || subscript[0].kind() == TK_SLICE_EXPR) {
@@ -1459,23 +1546,21 @@ private:
   void emitTupleAssign(const TupleLiteral& tl, const Expr& rhs) {
     size_t n_binders = tl.inputs().size();
     bool starred_unpack = calcNumStarredUnpack(tl.inputs(), tl.range());
-    if(starred_unpack)
+    if (starred_unpack)
       n_binders--;
     auto output = emitSugaredExpr(rhs, n_binders);
     auto outputs = output->asTuple(
         rhs.range(),
         method,
         starred_unpack ? c10::nullopt : c10::optional<size_t>{n_binders});
-    if(outputs.size() < n_binders) {
+    if (outputs.size() < n_binders) {
       throw ErrorReport(tl)
-        << "need " << (starred_unpack ? "at least " : "")
-        << n_binders << " values to unpack but found only "
-        << outputs.size();
+          << "need " << (starred_unpack ? "at least " : "") << n_binders
+          << " values to unpack but found only " << outputs.size();
     }
-    if(outputs.size() > n_binders && !starred_unpack) {
-      throw ErrorReport(tl)
-      << "too many values to unpack: need " << n_binders << " but found "
-      << outputs.size();
+    if (outputs.size() > n_binders && !starred_unpack) {
+      throw ErrorReport(tl) << "too many values to unpack: need " << n_binders
+                            << " but found " << outputs.size();
     }
     int i = 0;
     for (auto assignee : tl.inputs()) {
@@ -1489,32 +1574,36 @@ private:
           i++;
           break;
         case TK_VAR:
-          environment_stack->setSugaredVar(assignee.range(), Var(assignee).name().name(), outputs.at(i));
+          environment_stack->setSugaredVar(
+              assignee.range(), Var(assignee).name().name(), outputs.at(i));
           i++;
           break;
         case TK_STARRED: {
           auto var = Starred(assignee).expr();
           if (var.kind() != TK_VAR) {
-            throw ErrorReport(var) << "Cannot pack a tuple into a non-variable.";
+            throw ErrorReport(var)
+                << "Cannot pack a tuple into a non-variable.";
           }
           size_t n_matched = outputs.size() - n_binders;
           ArrayRef<std::shared_ptr<SugaredValue>> outputs_ref = outputs;
-          auto values = fmap(outputs_ref.slice(i, n_matched), [&](const std::shared_ptr<SugaredValue>& v) {
-            return v->asValue(assignee.range(), method);
-          });
+          auto values = fmap(
+              outputs_ref.slice(i, n_matched),
+              [&](const std::shared_ptr<SugaredValue>& v) {
+                return v->asValue(assignee.range(), method);
+              });
           auto tup = graph->insertNode(graph->createTuple(values))->output();
-          environment_stack->setVar(
-            var.range(), Var(var).name().name(), tup);
+          environment_stack->setVar(var.range(), Var(var).name().name(), tup);
           i += n_matched;
         } break;
         default:
-        throw ErrorReport(assignee) << "unexpected expression on the left-hand side";
+          throw ErrorReport(assignee)
+              << "unexpected expression on the left-hand side";
       }
     }
   }
 
   void emitAssignment(const Assign& stmt) {
-    switch(stmt.lhs().kind()) {
+    switch (stmt.lhs().kind()) {
       case TK_VAR: {
         auto v = Var(stmt.lhs());
         environment_stack->setSugaredVar(
@@ -1527,7 +1616,8 @@ private:
         emitSubscriptAssign(stmt.range(), Subscript(stmt.lhs()), stmt.rhs());
         break;
       default:
-        throw ErrorReport(stmt.lhs()) << "unexpected expression on left-hand side of assignment.";
+        throw ErrorReport(stmt.lhs())
+            << "unexpected expression on left-hand side of assignment.";
     }
   }
 
@@ -1586,17 +1676,16 @@ private:
     }
   }
 
-
-
   std::vector<NamedValue> getNamedValues(
       const TreeList& trees,
       bool maybe_unpack) {
     std::vector<NamedValue> values;
     for (const auto& tree : trees) {
-      if(maybe_unpack && tree->kind() == TK_STARRED) {
+      if (maybe_unpack && tree->kind() == TK_STARRED) {
         auto starred = Starred(tree);
-        auto entries = emitSugaredExpr(starred.expr(), 1)->asTuple(starred.range(), method);
-        for(const auto& entry : entries) {
+        auto entries = emitSugaredExpr(starred.expr(), 1)
+                           ->asTuple(starred.range(), method);
+        for (const auto& entry : entries) {
           values.emplace_back(
               tree->range(), entry->asValue(starred.range(), method));
         }
@@ -1612,38 +1701,33 @@ private:
     return getNamedValues(trees.tree()->trees(), maybe_unpack);
   }
 
-  std::vector<Value*> getValues(
-      const TreeList& trees,
-      bool maybe_unpack) {
+  std::vector<Value*> getValues(const TreeList& trees, bool maybe_unpack) {
     return toValues(*graph, getNamedValues(trees, maybe_unpack));
   }
-  std::vector<Value*> getValues(
-      const List<Expr>& trees,
-      bool maybe_unpack) {
+  std::vector<Value*> getValues(const List<Expr>& trees, bool maybe_unpack) {
     return getValues(trees.tree()->trees(), maybe_unpack);
   }
 
   std::vector<NamedValue> emitAttributes(const List<Attribute>& attributes) {
     return fmap(attributes, [&](const Attribute& attr) {
-      return NamedValue(attr.range(), attr.name().name(), emitExpr(attr.value()));
+      return NamedValue(
+          attr.range(), attr.name().name(), emitExpr(attr.value()));
     });
   }
 
   void checkApplyExpr(Apply& apply, SourceRange& loc) {
     if (apply.inputs().size() != 2) {
-      throw ErrorReport(loc)
-          << Var(apply.callee()).name().name()
-          << " expected exactly two arguments but found "
-          << apply.inputs().size();
+      throw ErrorReport(loc) << Var(apply.callee()).name().name()
+                             << " expected exactly two arguments but found "
+                             << apply.inputs().size();
     }
     if (apply.attributes().size() > 0) {
       throw ErrorReport(loc)
-          << Var(apply.callee()).name().name()
-          << " takes no keyword arguments";
+          << Var(apply.callee()).name().name() << " takes no keyword arguments";
     }
   }
 
-  std::shared_ptr<SugaredValue> emitApplyExpr(Apply &apply, size_t n_binders) {
+  std::shared_ptr<SugaredValue> emitApplyExpr(Applyapply, size_t n_binders) {
     auto sv = emitSugaredExpr(apply.callee(), 1);
     auto loc = apply.callee().range();
     if (auto fork_value = dynamic_cast<ForkValue*>(sv.get())) {
@@ -1672,23 +1756,26 @@ private:
             << " but found " << expr->type()->python_str();
       }
       return std::make_shared<SimpleValue>(expr);
-    } else if(auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
+    } else if (auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
       checkApplyExpr(apply, loc);
       auto obj = emitSugaredExpr(apply.inputs()[0], 1);
       auto selector = apply.inputs()[1];
       if (selector.kind() != TK_STRINGLITERAL) {
-        throw ErrorReport(loc) << "getattr's second argument must be a string literal";
+        throw ErrorReport(loc)
+            << "getattr's second argument must be a string literal";
       }
       const std::string& name = StringLiteral(selector).text();
       return obj->attr(apply.range(), method, name);
     } else if (auto isinstance = dynamic_cast<IsInstanceValue*>(sv.get())) {
-      // NOTE: for `isinstance` builtin call in JIT, we only check the static types
-      // on the inputs to evaluate, and insert the corresponding constant node
-      std::function<bool(Expr, Expr)> isInstanceCheck = [&](Expr obj, Expr classinfo) {
+      // NOTE: for `isinstance` builtin call in JIT, we only check the static
+      // types on the inputs to evaluate, and insert the corresponding constant
+      // node
+      std::function<bool(Expr, Expr)> isInstanceCheck = [&](Expr obj,
+                                                            Expr classinfo) {
         if (classinfo.kind() == TK_TUPLE_LITERAL) {
           // handle the case for recursive tuple classinfo
           // return true if obj is an instance of any of the types
-          for (Expr e: TupleLiteral(classinfo).inputs()) {
+          for (Expr e : TupleLiteral(classinfo).inputs()) {
             if (isInstanceCheck(obj, e)) {
               return true;
             }
@@ -1697,18 +1784,20 @@ private:
         }
         auto type_name = parseBaseTypeName(classinfo);
         if (!type_name) {
-          throw ErrorReport(classinfo.range()) << "type must be a type identifier";
+          throw ErrorReport(classinfo.range())
+              << "type must be a type identifier";
         }
         auto val = emitExpr(obj);
-        // Special casing for list and tuple since isintance(x, list) and isinstance(x, tuple)
-        // does not accept List[int] / Tuple[int] like subscript type annotation in python
+        // Special casing for list and tuple since isintance(x, list) and
+        // isinstance(x, tuple) does not accept List[int] / Tuple[int] like
+        // subscript type annotation in python
         if (*type_name == "list" && val->type()->cast<ListType>()) {
           return true;
         } else if (*type_name == "tuple" && val->type()->cast<TupleType>()) {
           return true;
         } else if (val->type()->cast<OptionalType>()) {
           throw ErrorReport(loc)
-                << "Optional isinstance check is not supported, consider use is/isnot None instead";
+              << "Optional isinstance check is not supported, consider use is/isnot None instead";
         } else {
           TypePtr type = parseTypeFromExpr(classinfo);
           if (val->type()->isSubtypeOf(type)) {
@@ -1718,8 +1807,10 @@ private:
         return false;
       };
       checkApplyExpr(apply, loc);
-      bool is_instance_val = isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
-      return std::make_shared<SimpleValue>(graph->insertConstant(is_instance_val, loc));
+      bool is_instance_val =
+          isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
+      return std::make_shared<SimpleValue>(
+          graph->insertConstant(is_instance_val, loc));
     } else {
       auto inputs = getNamedValues(apply.inputs(), true);
       auto attributes = emitAttributes(apply.attributes());
@@ -1741,7 +1832,8 @@ private:
     } else if (kind == aten::ge) {
       return aten::le;
     }
-    throw std::runtime_error("reverseComparision: unsupported NodeKind. File a bug");
+    throw std::runtime_error(
+        "reverseComparision: unsupported NodeKind. File a bug");
   }
 
   // any expression that can produce a SugaredValue is handled here
@@ -1751,8 +1843,11 @@ private:
   // or a = torch.jit.annotate(List[int], [])
   // the caller is responsible for checking that the result matches type_hint
   // emitSugaredExpr is free to ignore it.
-  std::shared_ptr<SugaredValue> emitSugaredExpr(const Expr& tree, size_t n_binders, const TypePtr& type_hint=nullptr) {
-    switch(tree.kind()) {
+  std::shared_ptr<SugaredValue> emitSugaredExpr(
+      const Expr& tree,
+      size_t n_binders,
+      const TypePtr& type_hint = nullptr) {
+    switch (tree.kind()) {
       case TK_VAR:
         return environment_stack->getSugaredVar(Var(tree).name());
       case '.': {
@@ -1769,18 +1864,18 @@ private:
     }
   }
 
-  Value * emitNegate(const TreeRef& tree) {
+  Value* emitNegate(const TreeRef& tree) {
     const auto& inputs = tree->trees();
     auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
 
     auto neg_val = emitBuiltinCall(
-               tree->range(),
-               *method.graph(),
-               aten::neg,
-               c10::nullopt,
-               named_values,
-               {},
-               /*required=*/true);
+        tree->range(),
+        *method.graph(),
+        aten::neg,
+        c10::nullopt,
+        named_values,
+        {},
+        /*required=*/true);
 
     // constant fold the input if possible
     auto maybe_constant_input = toIValue(neg_val->node()->input());
@@ -1798,22 +1893,25 @@ private:
   // This function extract a new graph from its original subgraph
   std::shared_ptr<SugaredValue> emitForkExpr(
       SourceRange loc,
-      const std::shared_ptr<SugaredValue> &forked,
+      const std::shared_ptr<SugaredValue>forked,
       at::ArrayRef<NamedValue> inputs,
       at::ArrayRef<NamedValue> attributes) {
     // Build the fork node without inputs
-    auto fork_node = method.graph()->insertNode(method.graph()->create(prim::fork, 1))
-                ->setSourceLocation(std::make_shared<SourceRange>(loc));
+    auto fork_node =
+        method.graph()
+            ->insertNode(method.graph()->create(prim::fork, 1))
+            ->setSourceLocation(std::make_shared<SourceRange>(loc));
     auto body_block = fork_node->addBlock();
 
     // Build a template of the graph to be executed
-    Value *node_output;
+    Valuenode_output;
     {
       WithInsertPoint guard(body_block);
       auto fn_sugared_output = forked->call(loc, method, inputs, attributes, 1);
       auto fn_simple_output = fn_sugared_output->asValue(loc, method);
       body_block->registerOutput(fn_simple_output);
-      node_output = fork_node->output()->setType(FutureType::create(fn_simple_output->type()));
+      node_output = fork_node->output()->setType(
+          FutureType::create(fn_simple_output->type()));
     }
 
     // Fork a new graph from its orignal owning graph
@@ -1867,13 +1965,13 @@ private:
         auto kind = getNodeKind(tree->kind(), inputs.size());
         auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
         return emitBuiltinCall(
-                   tree->range(),
-                   *method.graph(),
-                   kind,
-                   c10::nullopt,
-                   named_values,
-                   {},
-                   /*required=*/true);
+            tree->range(),
+            *method.graph(),
+            kind,
+            c10::nullopt,
+            named_values,
+            {},
+            /*required=*/true);
       }
       case TK_UNARY_MINUS: {
         return emitNegate(tree);
@@ -1882,13 +1980,11 @@ private:
       case TK_OR: {
         const auto& inputs = tree->trees();
         return emitShortCircuitIf(
-          tree->range(),
-          inputs[0],
-          inputs[1],
-          tree->kind() == TK_OR);
+            tree->range(), inputs[0], inputs[1], tree->kind() == TK_OR);
       }
       case TK_STARRED: {
-        throw ErrorReport(tree) << "Unexpected starred expansion. File a bug report.";
+        throw ErrorReport(tree)
+            << "Unexpected starred expansion. File a bug report.";
       }
       case TK_CONST: {
         return emitConst(Const(tree));
@@ -1932,8 +2028,8 @@ private:
                 << *elem_type << " but found " << *v->type() << " instead";
           }
         }
-        Value* result = graph->insertNode(graph->createList(elem_type, values))
-            ->output();
+        Value* result =
+            graph->insertNode(graph->createList(elem_type, values))->output();
         return result;
       } break;
       case TK_TUPLE_LITERAL: {
@@ -1949,9 +2045,11 @@ private:
 
   Value* emitConst(const Const& c) {
     if (c.isFloatingPoint())
-      return materializeConstant(c.asFloatingPoint(), *graph, c.range(), fp_constants);
+      return materializeConstant(
+          c.asFloatingPoint(), *graph, c.range(), fp_constants);
     else
-     return materializeConstant(c.asIntegral(), *graph, c.range(), integral_constants);
+      return materializeConstant(
+          c.asIntegral(), *graph, c.range(), integral_constants);
   }
 
   Value* emitStringLiteral(const StringLiteral& c) {
@@ -1965,11 +2063,17 @@ private:
       int64_t dim,
       Value* index) {
     return emitBuiltinCall(
-        loc, *graph, aten::select, c10::nullopt,
-        {input, graph->insertConstant(dim, loc), index}, {}, true);
+        loc,
+        *graph,
+        aten::select,
+        c10::nullopt,
+        {input, graph->insertConstant(dim, loc), index},
+        {},
+        true);
   }
 
-  // Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end, 1)
+  // Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end,
+  // 1)
   Value* emitSlice(
       const SourceRange& loc,
       Value* input,
@@ -1995,29 +2099,34 @@ private:
     }
     if (input->type()->cast<TupleType>()) {
       if (has_end) {
-        return emitTupleSlice(loc, args[0], args[1], /*end*/args[2]);
+        return emitTupleSlice(loc, args[0], args[1], /*end*/ args[2]);
       } else {
         return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
       }
     }
     NamedValue step = NamedValue(loc, "step", graph->insertConstant(1, loc));
-    return emitBuiltinCall(loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
+    return emitBuiltinCall(
+        loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
   }
 
   Value* emitIndex(
       const SourceRange& loc,
       Value* input,
       at::ArrayRef<Value*> indices) {
-    auto* index = graph->insertNode(
-        graph->createList(DynamicType::get(), indices))->output();
-    return emitBuiltinCall(loc, *graph, aten::index, c10::nullopt,  {input, index}, {}, true);
+    auto* index =
+        graph->insertNode(graph->createList(DynamicType::get(), indices))
+            ->output();
+    return emitBuiltinCall(
+        loc, *graph, aten::index, c10::nullopt, {input, index}, {}, true);
   }
 
   // Emits multidimensional slicing with int and slice indices.
   // Returns:
   // - Value*: the input after it has been indexed by int and slice indices.
-  // - vector<Value*>: A list of tensor Value* indices that have not been applied yet.
-  //   Should be NULL at indices where sliceable (post-slicing) isn't indexed by a tensor.
+  // - vector<Value*>: A list of tensor Value* indices that have not been
+  // applied yet.
+  //   Should be NULL at indices where sliceable (post-slicing) isn't indexed by
+  //   a tensor.
   std::pair<Value*, std::vector<Value*>> emitIntAndSliceIndexing(
       const SourceRange& loc,
       Value* sliceable,
@@ -2032,7 +2141,7 @@ private:
       dim++;
     };
 
-    for (const auto & subscript_expr : subscript_exprs) {
+    for (const auto& subscript_expr : subscript_exprs) {
       if (subscript_expr.kind() == TK_SLICE_EXPR) {
         sliceable = emitSlice(loc, sliceable, dim, SliceExpr(subscript_expr));
         ++dim;
@@ -2047,8 +2156,9 @@ private:
         continue;
       }
       throw ErrorReport(loc)
-        << "Unsupported operation: indexing tensor with unsupported index type "
-        << index->type()->str() << ". Only ints, slices, and tensors are supported.";
+          << "Unsupported operation: indexing tensor with unsupported index type "
+          << index->type()->str()
+          << ". Only ints, slices, and tensors are supported.";
     }
     // at::index takes in a TensorList where some tensors can be undefined.
     // Convert NULL tensorIndices to undefined tensors to pass to at::index.
@@ -2073,9 +2183,9 @@ private:
   // enough dimensions to index".
   //
   // The strategy is to slice and select the tensor for int and slices first
-  // in one pass and then apply at::index on the result of the slicing/selecting.
-  // Call the tensor after we've applied slice / select the `sliced`.
-  // tensor_indices should have the same size as sliced.dim():
+  // in one pass and then apply at::index on the result of the
+  // slicing/selecting. Call the tensor after we've applied slice / select the
+  // `sliced`. tensor_indices should have the same size as sliced.dim():
   // - tensor_indices[i] = NULL if we should not index `sliced` at dim i
   // - tensor_indices[i] = t if we should index `sliced` at dim i with tensor t.
   Value* emitMultidimSlicing(
@@ -2084,8 +2194,8 @@ private:
       const List<Expr>& subscript_exprs) {
     if (!sliceable->type()->isSubtypeOf(DynamicType::get())) {
       throw ErrorReport(loc)
-        << "Unsupported operation: attempted to use multidimensional "
-        << "indexing on a non-tensor type.";
+          << "Unsupported operation: attempted to use multidimensional "
+          << "indexing on a non-tensor type.";
     }
 
     std::vector<Value*> tensor_indices;
@@ -2117,46 +2227,49 @@ private:
     return emitSlice(loc, sliceable, maybe_dim, slice_exp);
   }
 
-  int64_t getTupleIndexVal(const SourceRange& loc,
-    const TupleTypePtr& tuple_type,
-      Value * idx_val,
+  int64_t getTupleIndexVal(
+      const SourceRange& loc,
+      const TupleTypePtr& tuple_type,
+      Value* idx_val,
       bool allow_out_of_bounds) {
-     int64_t index;
+    int64_t index;
     at::optional<IValue> ivalue = toIValue(idx_val);
     if (ivalue && ivalue->isInt()) {
       index = ivalue->to<int64_t>();
     } else {
-      throw ErrorReport(loc)
-        << "tuple indices must be integer constants";
+      throw ErrorReport(loc) << "tuple indices must be integer constants";
     }
-     // set index to be positive to simplify logic in runtime
+    // set index to be positive to simplify logic in runtime
     int64_t adj_index = index;
     int64_t tuple_len = tuple_type->elements().size();
     if (index < 0) {
       adj_index = tuple_len + index;
     }
     if (!allow_out_of_bounds && (adj_index >= tuple_len || adj_index < 0)) {
-      throw ErrorReport(loc)
-        << "Tuple index out of range. Tuple is length " << tuple_len
-        << " and index is " << index;
+      throw ErrorReport(loc) << "Tuple index out of range. Tuple is length "
+                             << tuple_len << " and index is " << index;
     }
     return adj_index;
   }
-   Value* emitTupleIndex(const SourceRange& loc,
-      Value * tuple_val,
-      Value * idx_val) {
+  Value* emitTupleIndex(
+      const SourceRange& loc,
+      Value* tuple_val,
+      Value* idx_val) {
     auto tuple_typ = tuple_val->type()->cast<TupleType>();
-    auto adj_index = getTupleIndexVal(loc, tuple_typ, idx_val, /*allow_out_of_bounds*/false);
-    return graph->insertNode(
-        graph->createTupleIndex(tuple_val, adj_index))->output();
+    auto adj_index = getTupleIndexVal(
+        loc, tuple_typ, idx_val, /*allow_out_of_bounds*/ false);
+    return graph->insertNode(graph->createTupleIndex(tuple_val, adj_index))
+        ->output();
   }
 
-  Value* emitTupleSlice(const SourceRange& loc,
+  Value* emitTupleSlice(
+      const SourceRange& loc,
       const NamedValue& tuple_val,
       const NamedValue& beg_val,
       const at::optional<NamedValue>& end_val) {
     auto tuple_type = tuple_val.value(*graph)->type()->expect<TupleType>();
-    int64_t beg = getTupleIndexVal(loc, tuple_type, beg_val.value(*graph), /*allow_out_of_bounds*/true);
+    int64_t beg = getTupleIndexVal(
+        loc, tuple_type, beg_val.value(*graph), /*allow_out_of_bounds*/ true);
     int64_t end;
     int64_t tuple_len = tuple_type->elements().size();
     if (end_val) {
@@ -2168,8 +2281,9 @@ private:
     end = std::min(std::max((int64_t)0, end), tuple_len);
     beg = std::min(std::max((int64_t)0, beg), tuple_len);
 
-    return graph->insertNode(
-        graph->createTupleSlice(tuple_val.value(*graph), beg, end))->output();
+    return graph
+        ->insertNode(graph->createTupleSlice(tuple_val.value(*graph), beg, end))
+        ->output();
   }
 
   Value* emitSubscript(const Subscript& subscript) {
@@ -2204,7 +2318,7 @@ private:
       // if it's a list, emit a regular index selection op
       auto* idx = emitExpr(subscript_exprs[0]);
       return emitBuiltinCall(
-                 loc, *graph, aten::select, c10::nullopt, {gatherable, idx}, {}, true);
+          loc, *graph, aten::select, c10::nullopt, {gatherable, idx}, {}, true);
     } else if (gatherable->type()->isSubtypeOf(DynamicType::get())) {
       return emitMultidimSlicing(loc, gatherable, subscript_exprs);
     } else if (auto tuple_type = gatherable->type()->cast<TupleType>()) {
@@ -2212,24 +2326,28 @@ private:
       return emitTupleIndex(loc, gatherable, idx);
     } else {
       throw ErrorReport(loc)
-        << "Indexing only supported on lists, tensors, and tuples.";
+          << "Indexing only supported on lists, tensors, and tuples.";
     }
   }
 };
 
-void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::vector<Def>& definitions, const std::vector<Resolver>& resolvers, const SugaredValuePtr& self) {
+void defineMethodsInModule(
+    const std::shared_ptr<Module>& m,
+    const std::vector<Def>& definitions,
+    const std::vector<Resolver>& resolvers,
+    const SugaredValuePtr& self) {
   JIT_ASSERT(definitions.size() == resolvers.size());
   auto resolver_it = resolvers.begin();
   std::vector<Method*> methods;
   std::unordered_map<std::string, Method*> function_table;
-  for(const Def& def : definitions) {
+  for (const Def& def : definitions) {
     const std::string& name = def.name().name();
     auto resolver = *resolver_it++;
     JIT_ASSERT(resolver);
-    if(!self) {
-      // if self is defined, then these are methods and do not go into the global namespace
-      // otherwise, they get defined together so we add them to the function table
-      // so the methods can see each other
+    if (!self) {
+      // if self is defined, then these are methods and do not go into the
+      // global namespace otherwise, they get defined together so we add them to
+      // the function table so the methods can see each other
       resolver = [resolver, &function_table](
                      const std::string& name,
                      Method& m,
@@ -2243,19 +2361,23 @@ void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::vector<D
     }
     auto creator = [def, resolver, self](Method& method) {
       JIT_ASSERT(resolver);
-      to_ir(def, resolver, self,  method);
+      to_ir(def, resolver, self, method);
     };
     Method& method = m->create_method(name, creator);
     function_table[name] = &method;
     methods.push_back(&method);
   }
-  for(Method* method : methods) {
+  for (Method* method : methods) {
     method->ensure_defined();
   }
   didFinishEmitModule(m);
 }
 
-void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::string& source, const Resolver& resolver, const SugaredValuePtr& self) {
+void defineMethodsInModule(
+    const std::shared_ptr<Module>& m,
+    const std::string& source,
+    const Resolver& resolver,
+    const SugaredValuePtr& self) {
   Parser p(source);
   std::vector<Def> definitions;
   std::vector<Resolver> resolvers;
@@ -2267,7 +2389,6 @@ void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::string&
   defineMethodsInModule(m, definitions, resolvers, self);
 }
 
-
 } // namespace script
 } // namespace jit
 } // namespace torch
index 4f4c42f..963308e 100644 (file)
@@ -5,17 +5,21 @@
 
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/script/error_report.h>
-#include <torch/csrc/jit/script/tree_views.h>
 #include <torch/csrc/jit/script/module.h>
 #include <torch/csrc/jit/script/sugared_value.h>
+#include <torch/csrc/jit/script/tree_views.h>
 
 namespace torch {
 namespace jit {
 namespace script {
 
-using Resolver = std::function<std::shared_ptr<SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>;
+using Resolver = std::function<std::shared_ptr<
+    SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>;
 
-inline std::shared_ptr<SugaredValue> nativeResolver(const std::string& name, Method& m, const SourceRange& loc){
+inline std::shared_ptr<SugaredValue> nativeResolver(
+    const std::string& name,
+    Method& m,
+    const SourceRange& loc) {
   if (name == "torch") {
     return std::make_shared<BuiltinModule>("aten");
   }
@@ -23,14 +27,21 @@ inline std::shared_ptr<SugaredValue> nativeResolver(const std::string& name, Met
 }
 
 TORCH_API void defineMethodsInModule(
-  const std::shared_ptr<Module>& m,
-  const std::vector<Def>& definitions,
-  const std::vector<Resolver>& resolvers, /* determines how we handle free variables in each definition*/
-  const std::shared_ptr<SugaredValue>& self /* if non-null, the first argument to each def, is bound to this value */
+    const std::shared_ptr<Module>& m,
+    const std::vector<Def>& definitions,
+    const std::vector<Resolver>& resolvers, /* determines how we handle free
+                                               variables in each definition*/
+    const std::shared_ptr<SugaredValue>&
+        self /* if non-null, the first argument to each def, is bound to this
+                value */
 );
 
 // same as above but parse the definitions from source
-TORCH_API void defineMethodsInModule(const std::shared_ptr<Module>& m, const std::string& source, const Resolver& resolver, const std::shared_ptr<SugaredValue>& self);
+TORCH_API void defineMethodsInModule(
+    const std::shared_ptr<Module>& m,
+    const std::string& source,
+    const Resolver& resolver,
+    const std::shared_ptr<SugaredValue>& self);
 
 } // namespace script
 } // namespace jit
index eb4f432..13f028e 100644 (file)
@@ -14,7 +14,7 @@ struct ErrorReport : public std::exception {
   explicit ErrorReport(const SourceRange& r)
       : context(std::make_shared<SourceRange>(r)) {}
   explicit ErrorReport(std::shared_ptr<SourceLocation> loc)
-  : context(std::move(loc)) {}
+      : context(std::move(loc)) {}
   explicit ErrorReport(const TreeRef& tree) : ErrorReport(tree->range()) {}
   explicit ErrorReport(const Token& tok) : ErrorReport(tok.range) {}
   const char* what() const noexcept override {
index 5541f27..82de29c 100644 (file)
@@ -1,5 +1,5 @@
-#include <torch/csrc/jit/script/final_returns.h>
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/script/final_returns.h>
 
 namespace torch {
 namespace jit {
@@ -14,7 +14,7 @@ struct ReturnInfo {
 void checkNoReturn(const TreeRef& ref) {
   if (ref->kind() == TK_RETURN)
     throw ErrorReport(ref) << "return is not allowed from a loop.";
-  for(const TreeRef& child : ref->trees()) {
+  for (const TreeRef& child : ref->trees()) {
     checkNoReturn(child);
   }
 }
@@ -22,29 +22,38 @@ void checkNoReturn(const TreeRef& ref) {
 // transform stmts so that its last action is to return or report that it
 // never returns.
 // return_none - if true, add an implicit `return None` to the end of the block
-//   this handles the case where the return is implicit at the end of the function.
-ReturnInfo makeReturnsFinal(const SourceRange& range, at::ArrayRef<TreeRef> stmts, bool return_none);
+//   this handles the case where the return is implicit at the end of the
+//   function.
+ReturnInfo makeReturnsFinal(
+    const SourceRange& range,
+    at::ArrayRef<TreeRef> stmts,
+    bool return_none);
 ReturnInfo makeReturnsFinal(const List<Stmt>& stmts, bool return_none) {
   return makeReturnsFinal(stmts.range(), stmts.get()->trees(), return_none);
 }
-ReturnInfo makeReturnsFinal(const SourceRange& range, at::ArrayRef<TreeRef> stmts, bool return_none) {
+ReturnInfo makeReturnsFinal(
+    const SourceRange& range,
+    at::ArrayRef<TreeRef> stmts,
+    bool return_none) {
   std::vector<TreeRef> changed;
   changed.reserve(stmts.size());
-  for(size_t i = 0; i < stmts.size(); ++i) {
+  for (size_t i = 0; i < stmts.size(); ++i) {
     const TreeRef& stmt = stmts[i];
-    switch(stmt->kind()) {
+    switch (stmt->kind()) {
       case TK_IF: {
         auto if_stmt = If(stmt);
         auto true_final = makeReturnsFinal(if_stmt.trueBranch(), false);
         // (3) early return an if statement without an else block:
         if (true_final.returns_ && if_stmt.falseBranch().size() == 0) {
-          auto rest_final = makeReturnsFinal(range, stmts.slice(i + 1), return_none);
+          auto rest_final =
+              makeReturnsFinal(range, stmts.slice(i + 1), return_none);
           if (!rest_final.returns_) {
             throw ErrorReport(if_stmt)
-                  << "This if statement performs an early return, but the block of code that follows it does not return."
-                  << " Early returns are only allowed when the block following them also returns.";
+                << "This if statement performs an early return, but the block of code that follows it does not return."
+                << " Early returns are only allowed when the block following them also returns.";
           }
-          changed.emplace_back(if_stmt.withNewBranches(true_final.stmts_, rest_final.stmts_));
+          changed.emplace_back(
+              if_stmt.withNewBranches(true_final.stmts_, rest_final.stmts_));
           return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
         }
 
@@ -56,12 +65,13 @@ ReturnInfo makeReturnsFinal(const SourceRange& range, at::ArrayRef<TreeRef> stmt
         }
         // (2) all branches return
         if (true_final.returns_ && false_final.returns_) {
-          changed.emplace_back(if_stmt.withNewBranches(true_final.stmts_, false_final.stmts_));
+          changed.emplace_back(
+              if_stmt.withNewBranches(true_final.stmts_, false_final.stmts_));
           return {true, List<Stmt>::unsafeCreate(range, std::move(changed))};
         }
         throw ErrorReport(if_stmt)
-              << "This if statement contains some paths that return and some paths that do not. "
-              << "If statements must either entirely return or never return.";
+            << "This if statement contains some paths that return and some paths that do not. "
+            << "If statements must either entirely return or never return.";
       } break;
       case TK_WHILE:
       case TK_FOR:
@@ -79,7 +89,8 @@ ReturnInfo makeReturnsFinal(const SourceRange& range, at::ArrayRef<TreeRef> stmt
   }
   if (return_none) {
     // add an implicit return none node
-    changed.emplace_back(Return::create(range, Expr(Compound::create(TK_NONE, range, {}))));
+    changed.emplace_back(
+        Return::create(range, Expr(Compound::create(TK_NONE, range, {}))));
   }
   // we reach the end of the block, no returns have happened
   // unless we just inserted a return_none implicit return.
index c2960a4..59f5824 100644 (file)
@@ -3,9 +3,9 @@
 #include <memory>
 #include <string>
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/jit/script/tree_views.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 
 namespace torch {
 namespace jit {
@@ -33,8 +33,8 @@ namespace script {
 // In particular we allow:
 // 1. If statements where neither <true> nor <false> branch returns.
 // 2. If statements where both <true> and <false> always return.
-// 3. An 'early return' if statement where <true> always returns <false> is empty, and <rest>
-// always returns.
+// 3. An 'early return' if statement where <true> always returns <false> is
+// empty, and <rest> always returns.
 //
 // We do not allow returns from loops in any case.
 //
@@ -44,8 +44,10 @@ namespace script {
 // 2. Both branches return, so we recursively transform the program such that
 // <true> and <false>'s final action is to return. We then delete <rest>
 // because the code is dead. The remaining program preserves the inductive
-// property that its last action is to return since both branches end in a return.
-// 3. In this case we know that <true> and <rest> always returns, and <false> is empty.
+// property that its last action is to return since both branches end in a
+// return.
+// 3. In this case we know that <true> and <rest> always returns, and <false> is
+// empty.
 //    We transform the graph to:
 //    if <cond>:
 //       <true>
@@ -55,6 +57,6 @@ namespace script {
 
 TORCH_API List<Stmt> moveAllReturnsToEnd(const List<Stmt>& stmts);
 
-}
+} // namespace script
 } // namespace jit
 } // namespace torch
index 6ccd194..7b8bd2f 100644 (file)
@@ -7,20 +7,21 @@
 #include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/script/schema_matching.h>
 
-#include <torch/csrc/jit/python_tracer.h>
-#include <torch/csrc/jit/pybind_utils.h>
 #include <torch/csrc/jit/constants.h>
-#include <torch/csrc/jit/passes/to_batch.h>
 #include <torch/csrc/jit/function_schema.h>
-#include <torch/csrc/jit/script/parser.h>
-#include <torch/csrc/jit/import_method.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
+#include <torch/csrc/jit/import_method.h>
 #include <torch/csrc/jit/passes/python_print.h>
+#include <torch/csrc/jit/passes/to_batch.h>
+#include <torch/csrc/jit/pybind_utils.h>
+#include <torch/csrc/jit/python_tracer.h>
+#include <torch/csrc/jit/script/parser.h>
 
 #include <torch/csrc/api/include/torch/ordered_dict.h>
 
 #include <ATen/ATen.h>
 
+#include <pybind11/functional.h>
 #include <cstddef>
 #include <memory>
 #include <sstream>
@@ -28,8 +29,6 @@
 #include <tuple>
 #include <utility>
 #include <vector>
-#include <pybind11/functional.h>
-
 
 namespace torch {
 namespace jit {
@@ -57,8 +56,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     bool is_submodule = false);
 
 struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
-  PythonValue(py::object self)
-  : self(std::move(self)) {}
+  PythonValue(py::object self) : self(std::move(self)) {}
 
   FunctionSchema getSchema(const size_t n_args, const size_t n_binders) {
     auto annotations = py::module::import("torch.jit.annotations");
@@ -70,11 +68,13 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
     if (!signature.is_none()) {
       std::vector<TypePtr> arg_types;
       TypePtr ret_type;
-      std::tie(arg_types, ret_type) = py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
+      std::tie(arg_types, ret_type) =
+          py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
       args.reserve(arg_types.size());
       size_t idx = 0; // Fake argument names by putting in the index
-      for (auto &arg_type : arg_types) {
-        args.push_back(Argument(std::to_string(idx++), std::move(arg_type), {}, {}, false));
+      for (auto& arg_type : arg_types) {
+        args.push_back(Argument(
+            std::to_string(idx++), std::move(arg_type), {}, {}, false));
       }
       rets.push_back(Argument("0", std::move(ret_type), {}, {}, false));
     } else {
@@ -92,11 +92,12 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
       // Construct the default signature: all arguments and returns will be
       // DynamicType
       args.reserve(actual_n_args);
-      for (size_t i=0; i < actual_n_args; ++i) {
-        args.push_back(Argument(std::to_string(i), DynamicType::get(), {}, {}, false));
+      for (size_t i = 0; i < actual_n_args; ++i) {
+        args.push_back(
+            Argument(std::to_string(i), DynamicType::get(), {}, {}, false));
       }
       TypePtr ret_type = DynamicType::get();
-      if(n_binders != 1) {
+      if (n_binders != 1) {
         std::vector<TypePtr> tuple_values(n_binders, ret_type);
         ret_type = TupleType::create(std::move(tuple_values));
       }
@@ -106,13 +107,25 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
   }
 
   // call it like a function, e.g. `outputs = this(inputs)`
-  std::shared_ptr<SugaredValue> call(const SourceRange& loc, Method & m, at::ArrayRef<NamedValue> inputs_, at::ArrayRef<NamedValue> attributes, size_t n_binders) override {
+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Method& m,
+      at::ArrayRef<NamedValue> inputs_,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) override {
     auto inputs = toValues(*m.graph(), inputs_);
     auto schema = getSchema(inputs.size(), n_binders);
 
     std::stringstream failure_messages;
-    c10::optional<MatchedSchema> matched_schema =
-      tryMatchSchema(schema, loc, *m.graph(), c10::nullopt, inputs_, attributes, failure_messages, /*conv_tensor_to_num*/true);
+    c10::optional<MatchedSchema> matched_schema = tryMatchSchema(
+        schema,
+        loc,
+        *m.graph(),
+        c10::nullopt,
+        inputs_,
+        attributes,
+        failure_messages,
+        /*conv_tensor_to_num*/ true);
     if (!matched_schema)
       throw ErrorReport(loc) << failure_messages.str();
 
@@ -120,13 +133,14 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
     py::object func = self;
     std::string cconv(inputs.size(), 'd');
     Node* new_node = m.graph()->insertNode(m.graph()->createPythonOp(
-      THPObjectPtr(func.release().ptr()), cconv, {}));
+        THPObjectPtr(func.release().ptr()), cconv, {}));
     new_node->setSourceLocation(std::make_shared<SourceRange>(loc));
-    for(auto &i : matched_schema->inputs)
+    for (auto& i : matched_schema->inputs)
       new_node->addInput(i);
 
     JIT_ASSERT(matched_schema->return_types.size() == 1);
-    Value* output = new_node->addOutput()->setType(matched_schema->return_types.at(0));
+    Value* output =
+        new_node->addOutput()->setType(matched_schema->return_types.at(0));
     return std::make_shared<SimpleValue>(output);
   }
 
@@ -136,8 +150,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
     return ss.str();
   }
 
-protected:
-
+ protected:
   py::object getattr(const SourceRange& loc, const std::string& name) {
     try {
       return py::getattr(self, name.c_str());
@@ -165,7 +178,8 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
 };
 
 struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
-  explicit ConstantPythonTupleValue(py::object tup) : PythonValue(std::move(tup)) {}
+  explicit ConstantPythonTupleValue(py::object tup)
+      : PythonValue(std::move(tup)) {}
   std::vector<std::shared_ptr<SugaredValue>> asTuple(
       const SourceRange& loc,
       Method& m,
@@ -180,9 +194,7 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
     return result;
   }
 
-  Value* asValue(
-      const SourceRange& loc,
-      Method& m) override {
+  Value* asValue(const SourceRange& loc, Method& m) override {
     std::vector<Value*> values;
     for (const auto& sugared_item : asTuple(loc, m)) {
       values.push_back(sugared_item->asValue(loc, m));
@@ -199,17 +211,18 @@ struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
 // anticipating we will eventually need to replace Module with a py::object
 // holding the actual nn.Module class.
 
-
 struct ModuleValue : public SugaredValue {
-  ModuleValue(std::shared_ptr<Module> module)
-  : module(std::move(module)) {}
+  ModuleValue(std::shared_ptr<Module> module) : module(std::move(module)) {}
 
   std::string kind() const override {
     return "module";
   }
 
   // select an attribute on it, e.g. `this.field`
-  std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) override {
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      Method& m,
+      const std::string& field) override {
     // workaround to make self.training work
     // it adds a buffer 'training' to the model if one doesn't exist
     // and then loads that parameter, casting it to bool
@@ -218,7 +231,8 @@ struct ModuleValue : public SugaredValue {
       if (!v) {
         py::object py_module = py::cast(module);
         bool training = py::cast<bool>(py::getattr(py_module, "training"));
-        auto t = autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
+        auto t =
+            autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
         module->register_parameter("training", std::move(t), true);
         v = module->find_parameter(field);
       }
@@ -227,31 +241,39 @@ struct ModuleValue : public SugaredValue {
       return std::make_shared<SimpleValue>(the_bool);
     }
 
-    if(NamedModule* v = module->find_module(field)) {
+    if (NamedModule* v = module->find_module(field)) {
       return std::make_shared<ModuleValue>(v->module);
-    } else if(Method* v = module->find_method(field)) {
+    } else if (Method* v = module->find_method(field)) {
       return std::make_shared<MethodValue>(module, *v);
-    } else if(NamedParameter* v = module->find_parameter(field)) {
+    } else if (NamedParameter* v = module->find_parameter(field)) {
       return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
     }
     // This can also be a call to a non-script module, or a plain
     // python method. If so return this as a python value.
     py::object py_module = py::cast(module);
-    if(py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
+    if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
       if (py::isinstance<py::function>(attr) ||
           py::isinstance(attr, py::module::import("torch.nn").attr("Module")) ||
           py_module.attr("_constants_set").contains(field.c_str())) {
         return toSugaredValue(attr, m, loc, true);
       } else {
-        throw ErrorReport(loc) << "attribute '" << field << "' of type '" << typeString(attr) << "' is not usable in a script method (did you forget to add it __constants__?)";
+        throw ErrorReport(loc)
+            << "attribute '" << field << "' of type '" << typeString(attr)
+            << "' is not usable in a script method (did you forget to add it __constants__?)";
       }
     }
     throw ErrorReport(loc) << "module has no attribute '" << field << "'";
   }
 
   // call module.forward
-  std::shared_ptr<SugaredValue> call(const SourceRange& loc, Method & caller, at::ArrayRef<NamedValue> inputs, at::ArrayRef<NamedValue> attributes, size_t n_binders) override {
-    return attr(loc, caller, "forward")->call(loc, caller, inputs, attributes, n_binders);
+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Method& caller,
+      at::ArrayRef<NamedValue> inputs,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) override {
+    return attr(loc, caller, "forward")
+        ->call(loc, caller, inputs, attributes, n_binders);
   }
 
   std::vector<std::shared_ptr<SugaredValue>> asTuple(
@@ -259,10 +281,12 @@ struct ModuleValue : public SugaredValue {
       Method& m,
       const c10::optional<size_t>& size_hint = {}) override {
     py::object py_module = py::cast(module);
-    if(!py::isinstance(py_module, py::module::import("torch.jit").attr("_ConstModuleList")))
+    if (!py::isinstance(
+            py_module,
+            py::module::import("torch.jit").attr("_ConstModuleList")))
       return SugaredValue::asTuple(loc, m, size_hint);
     std::vector<std::shared_ptr<SugaredValue>> result;
-    for(py::handle module : py_module) {
+    for (py::handle module : py_module) {
       py::object obj = py::reinterpret_borrow<py::object>(module);
       result.push_back(toSugaredValue(
           obj,
@@ -369,7 +393,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
       const auto v = static_cast<int64_t>(dtype->scalar_type);
       return toSimple(g.insertConstant(v, loc));
     } else if (py::isinstance<py::tuple>(obj)) {
-     return std::make_shared<ConstantPythonTupleValue>(obj);
+      return std::make_shared<ConstantPythonTupleValue>(obj);
     }
   }
 
@@ -394,11 +418,13 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     return std::make_shared<PythonModuleValue>(obj);
   } else if (obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr()) {
     return std::make_shared<ForkValue>();
-  } else if (obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) {
+  } else if (
+      obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) {
     return std::make_shared<AnnotateValue>();
   }
 
-  py::object builtin_name = py::module::import("torch.jit").attr("_find_builtin")(obj);
+  py::object builtin_name =
+      py::module::import("torch.jit").attr("_find_builtin")(obj);
   if (!builtin_name.is_none()) {
     return std::make_shared<BuiltinFunction>(
         Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
@@ -431,18 +457,20 @@ py::object unpackVariableTensorList(std::vector<at::Tensor> outputs) {
     return py::cast(autograd::as_variable_ref(outputs[0]));
   } else {
     py::tuple tuple(outputs.size());
-    for(size_t i = 0; i < outputs.size(); i++) {
+    for (size_t i = 0; i < outputs.size(); i++) {
       tuple[i] = py::cast(autograd::as_variable_ref(outputs[i]));
     }
     return tuple;
   }
 }
 
-static void gatherParametersAndBuffers(std::vector<at::Tensor*> & values, const Module & m) {
-  for(auto & param : m.get_parameters()) {
+static void gatherParametersAndBuffers(
+    std::vector<at::Tensor*>& values,
+    const Module& m) {
+  for (auto& param : m.get_parameters()) {
     values.push_back(param->slot());
   }
-  for(const auto & sub : m.get_modules()) {
+  for (const auto& sub : m.get_modules()) {
     gatherParametersAndBuffers(values, *sub->module);
   }
 }
@@ -461,7 +489,7 @@ Resolver pythonResolver(const ResolutionCallback& rcb) {
   };
 }
 
-}
+} // namespace
 
 FunctionSchema getSchemaWithNameAndDefaults(
     const SourceRange& range,
@@ -482,7 +510,8 @@ FunctionSchema getSchemaWithNameAndDefaults(
         } else {
           value = toIValue(it->second, arg.type());
         }
-        new_args.emplace_back(arg.name(), arg.type(), arg.N(), value, arg.kwarg_only());
+        new_args.emplace_back(
+            arg.name(), arg.type(), arg.N(), value, arg.kwarg_only());
       } catch (py::cast_error& e) {
         throw ErrorReport(range)
             << "Expected a default value of type " << arg.type()->str()
@@ -509,225 +538,280 @@ void initJitScriptBindings(PyObject* module) {
   // public.
   py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
       .def(py::init<>())
-      .def("save", [](std::shared_ptr<Module> m, const std::string& filename) {
-          m->save(filename);
-      })
-      .def("save_to_buffer", [](std::shared_ptr<Module> m) {
-          std::ostringstream buf;
-          m->save(buf);
-          return py::bytes(buf.str());
-      })
+      .def(
+          "save",
+          [](std::shared_ptr<Module> m, const std::string& filename) {
+            m->save(filename);
+          })
+      .def(
+          "save_to_buffer",
+          [](std::shared_ptr<Module> m) {
+            std::ostringstream buf;
+            m->save(buf);
+            return py::bytes(buf.str());
+          })
       .def("_set_optimized", &Module::set_optimized)
       .def(
           "_define",
           [](std::shared_ptr<Module> m,
              const std::string& script,
-             ResolutionCallback rcb, bool has_self) {
+             ResolutionCallback rcb,
+             bool has_self) {
             auto self = has_self ? std::make_shared<ModuleValue>(m) : nullptr;
             defineMethodsInModule(m, script, pythonResolver(rcb), self);
           })
-      .def("_create_methods", [](std::shared_ptr<Module> m,
-          const std::vector<Def>& defs,
-          const std::vector<ResolutionCallback>& rcbs,
-          const std::vector<FunctionDefaults>& defaults) {
-        std::vector<Resolver> resolvers;
-        resolvers.reserve(rcbs.size());
-        for(auto & callback : rcbs) {
-          resolvers.push_back(pythonResolver(callback));
-        }
-        defineMethodsInModule(
-          m,
-          defs,
-          resolvers,
-          std::make_shared<ModuleValue>(m));
-
-        // Stitch in default arguments for each Def if provided
-        auto defaults_it = defaults.begin();
-        auto defs_it = defs.begin();
-        while (defs_it != defs.end()) {
-          auto& method = m->get_method((*defs_it).name().name());
-          method.setSchema(getSchemaWithNameAndDefaults(
-              defs_it->range(), method.getSchema(), at::nullopt, *defaults_it));
-          ++defs_it;
-          ++defaults_it;
-        }
-        didFinishEmitModule(m);
-      })
-      .def("_get_method",
-      [](Module& self, const std::string& name) -> const Method& {
-        return self.get_method(name);
-      }, py::return_value_policy::reference_internal)
+      .def(
+          "_create_methods",
+          [](std::shared_ptr<Module> m,
+             const std::vector<Def>& defs,
+             const std::vector<ResolutionCallback>& rcbs,
+             const std::vector<FunctionDefaults>& defaults) {
+            std::vector<Resolver> resolvers;
+            resolvers.reserve(rcbs.size());
+            for (auto& callback : rcbs) {
+              resolvers.push_back(pythonResolver(callback));
+            }
+            defineMethodsInModule(
+                m, defs, resolvers, std::make_shared<ModuleValue>(m));
+
+            // Stitch in default arguments for each Def if provided
+            auto defaults_it = defaults.begin();
+            auto defs_it = defs.begin();
+            while (defs_it != defs.end()) {
+              auto& method = m->get_method((*defs_it).name().name());
+              method.setSchema(getSchemaWithNameAndDefaults(
+                  defs_it->range(),
+                  method.getSchema(),
+                  at::nullopt,
+                  *defaults_it));
+              ++defs_it;
+              ++defaults_it;
+            }
+            didFinishEmitModule(m);
+          })
+      .def(
+          "_get_method",
+          [](Module& self, const std::string& name) -> const Method& {
+            return self.get_method(name);
+          },
+          py::return_value_policy::reference_internal)
       .def("_register_parameter", &Module::register_parameter)
       .def("_register_module", &Module::register_module)
       .def("_set_parameter", &Module::set_parameter)
       .def("_get_parameter", &Module::get_parameter)
       .def("_get_module", &Module::get_module)
-      .def("_get_modules", [](Module& self) -> py::tuple {
-        auto & modules = self.get_modules();
-        py::tuple result(modules.size());
-        for(size_t i = 0; i < modules.size(); ++i) {
-          auto & item = modules[i];
-          result[i] = std::make_pair(item.key(), item.value().module);
-        }
-        return result;
-      })
-      .def("_get_parameters", [](Module& self) -> py::tuple {
-        auto & parameters = self.get_parameters();
-        py::tuple result(parameters.size());
-        for(size_t i = 0; i < parameters.size(); ++i) {
-          auto & p = parameters[i];
-          py::tuple r(3);
-          result[i] = std::make_tuple(
-            p.key(),
-            autograd::as_variable_ref(*p->slot()),
-            p->is_buffer);
-
-        }
-        return result;
-      })
-      .def("_has_parameter", [](Module& self, const std::string& name) {
-        if(auto r = self.find_parameter(name)) {
-          return !r->is_buffer;
-        }
-        return false;
-      })
-      .def("_has_buffer", [](Module& self, const std::string& name) {
-        if(auto r = self.find_parameter(name)) {
-          return r->is_buffer;
-        }
-        return false;
-      })
-      .def("_has_module", [](Module& self, const std::string& name) {
-        return bool(self.find_module(name));
-      })
-      .def("_has_method", [](Module& self, const std::string& name) {
-        return bool(self.find_method(name));
-      })
-      .def("_method_names", [](Module& self) {
-        using Item = torch::OrderedDict<std::string, std::unique_ptr<Method>>::Item;
-        return fmap(self.get_methods(), [](const Item & item) {
-          return (*item)->name();
-        });
-      })
-      .def("_create_method_from_graph", [](
-         Module& self,
-         const std::string& name,
-         std::shared_ptr<Graph> graph
-       ){
-         self.create_method(name, std::move(graph), {});
-      })
-      .def("_create_method_from_trace", [](
-        std::shared_ptr<Module> self,
-        const std::string& name,
-        py::function func,
-        py::tuple input_tuple,
-        py::function var_lookup_fn,
-        bool force_outplace) {
-          // prereq: Module's buffers and parameters are unique
-          // this was ensured in python before calling this function
-          std::vector<at::Tensor*> parameters;
-          gatherParametersAndBuffers(parameters, *self);
-          Stack inputs = toStack(input_tuple);
-          for(at::Tensor* param : parameters) {
-            inputs.emplace_back(*param);
-          }
-          auto graph = tracer::createGraphByTracing(
-              func, inputs, var_lookup_fn, force_outplace, input_tuple.size());
-          self->create_method(name, std::move(graph), std::move(parameters));
-          didFinishEmitModule(self);
-      })
-      .def("graph_for", [](py::args args, py::kwargs kwargs) {
-        // [pybind11 varargs] note: old version of pybind11 have a bug that leaks memory
-        // when py::args is mixed with positional arguments
-        // https://github.com/pybind/pybind11/pull/1216
-        // we work around this by not mixing positional arguments with varargs
-        Module& self = py::cast<Module&>(args[0]);
-        if (self.find_method("forward")) {
-          Method & m = self.get_method("forward");
-          return m.graph_for(
-              createStackForSchema(m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
-        }
-        throw std::runtime_error("Attempted to call graph_for on a Module without a compiled forward()");
-      })
-      .def("get_debug_state", [](Module& self) {
-        if (self.find_method("forward")) {
-          Method & m = self.get_method("forward");
-          return m.getDebugState();
-        }
-        throw std::runtime_error("Attempted to call get_debug_state on a Module without a compiled forward()");
-      })
-      .def("debug_disable_autodiff_subgraph_inlining", [](Module& self) {
-        if (self.find_method("forward")) {
-          Method & m = self.get_method("forward");
-          m.debugDisableAutodiffSubgraphInlining();
-        }
-      })
-      .def("forward", [](py::args args, py::kwargs kwargs) {
-        // We implement this in C++ to avoid incurring the pybind11 dispatch
-        // overhead twice: once to call into the method lookup for "forward"
-        // and once to actually invoke the method.
-        //
-        // There is a thin wrapper on top of this method in the C++ version of
-        // ScriptModule.
-
-        // see: [pybind11 varargs]
-        Module& self = py::cast<Module&>(args[0]);
-        return invokeScriptMethodFromPython(self.get_method("forward"), tuple_slice(std::move(args), 1), std::move(kwargs));
-      })
-      .def("_python_print", [](Module& self) {
-        std::ostringstream ss;
-        std::vector<at::Tensor> tensors;
-        PythonPrint(ss, self, tensors, true);
-        return std::make_pair(ss.str(), tensors);
-      })
-      .def_property_readonly("code", [](Module& self) {
-        std::ostringstream ss;
-        std::vector<at::Tensor> tensors;
-        PythonPrint(ss, self, tensors, false);
-        return ss.str();
-      })
+      .def(
+          "_get_modules",
+          [](Module& self) -> py::tuple {
+            auto& modules = self.get_modules();
+            py::tuple result(modules.size());
+            for (size_t i = 0; i < modules.size(); ++i) {
+              auto& item = modules[i];
+              result[i] = std::make_pair(item.key(), item.value().module);
+            }
+            return result;
+          })
+      .def(
+          "_get_parameters",
+          [](Module& self) -> py::tuple {
+            auto& parameters = self.get_parameters();
+            py::tuple result(parameters.size());
+            for (size_t i = 0; i < parameters.size(); ++i) {
+              auto& p = parameters[i];
+              py::tuple r(3);
+              result[i] = std::make_tuple(
+                  p.key(), autograd::as_variable_ref(*p->slot()), p->is_buffer);
+            }
+            return result;
+          })
+      .def(
+          "_has_parameter",
+          [](Module& self, const std::string& name) {
+            if (auto r = self.find_parameter(name)) {
+              return !r->is_buffer;
+            }
+            return false;
+          })
+      .def(
+          "_has_buffer",
+          [](Module& self, const std::string& name) {
+            if (auto r = self.find_parameter(name)) {
+              return r->is_buffer;
+            }
+            return false;
+          })
+      .def(
+          "_has_module",
+          [](Module& self, const std::string& name) {
+            return bool(self.find_module(name));
+          })
+      .def(
+          "_has_method",
+          [](Module& self, const std::string& name) {
+            return bool(self.find_method(name));
+          })
+      .def(
+          "_method_names",
+          [](Module& self) {
+            using Item =
+                torch::OrderedDict<std::string, std::unique_ptr<Method>>::Item;
+            return fmap(self.get_methods(), [](const Item& item) {
+              return (*item)->name();
+            });
+          })
+      .def(
+          "_create_method_from_graph",
+          [](Module& self,
+             const std::string& name,
+             std::shared_ptr<Graph> graph) {
+            self.create_method(name, std::move(graph), {});
+          })
+      .def(
+          "_create_method_from_trace",
+          [](std::shared_ptr<Module> self,
+             const std::string& name,
+             py::function func,
+             py::tuple input_tuple,
+             py::function var_lookup_fn,
+             bool force_outplace) {
+            // prereq: Module's buffers and parameters are unique
+            // this was ensured in python before calling this function
+            std::vector<at::Tensor*> parameters;
+            gatherParametersAndBuffers(parameters, *self);
+            Stack inputs = toStack(input_tuple);
+            for (at::Tensor* param : parameters) {
+              inputs.emplace_back(*param);
+            }
+            auto graph = tracer::createGraphByTracing(
+                func,
+                inputs,
+                var_lookup_fn,
+                force_outplace,
+                input_tuple.size());
+            self->create_method(name, std::move(graph), std::move(parameters));
+            didFinishEmitModule(self);
+          })
+      .def(
+          "graph_for",
+          [](py::args args, py::kwargs kwargs) {
+            // [pybind11 varargs] note: old version of pybind11 have a bug that
+            // leaks memory when py::args is mixed with positional arguments
+            // https://github.com/pybind/pybind11/pull/1216
+            // we work around this by not mixing positional arguments with
+            // varargs
+            Module& self = py::cast<Module&>(args[0]);
+            if (self.find_method("forward")) {
+              Method& m = self.get_method("forward");
+              return m.graph_for(createStackForSchema(
+                  m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
+            }
+            throw std::runtime_error(
+                "Attempted to call graph_for on a Module without a compiled forward()");
+          })
+      .def(
+          "get_debug_state",
+          [](Module& self) {
+            if (self.find_method("forward")) {
+              Method& m = self.get_method("forward");
+              return m.getDebugState();
+            }
+            throw std::runtime_error(
+                "Attempted to call get_debug_state on a Module without a compiled forward()");
+          })
+      .def(
+          "debug_disable_autodiff_subgraph_inlining",
+          [](Module& self) {
+            if (self.find_method("forward")) {
+              Method& m = self.get_method("forward");
+              m.debugDisableAutodiffSubgraphInlining();
+            }
+          })
+      .def(
+          "forward",
+          [](py::args args, py::kwargs kwargs) {
+            // We implement this in C++ to avoid incurring the pybind11 dispatch
+            // overhead twice: once to call into the method lookup for "forward"
+            // and once to actually invoke the method.
+            //
+            // There is a thin wrapper on top of this method in the C++ version
+            // of ScriptModule.
+
+            // see: [pybind11 varargs]
+            Module& self = py::cast<Module&>(args[0]);
+            return invokeScriptMethodFromPython(
+                self.get_method("forward"),
+                tuple_slice(std::move(args), 1),
+                std::move(kwargs));
+          })
+      .def(
+          "_python_print",
+          [](Module& self) {
+            std::ostringstream ss;
+            std::vector<at::Tensor> tensors;
+            PythonPrint(ss, self, tensors, true);
+            return std::make_pair(ss.str(), tensors);
+          })
+      .def_property_readonly(
+          "code",
+          [](Module& self) {
+            std::ostringstream ss;
+            std::vector<at::Tensor> tensors;
+            PythonPrint(ss, self, tensors, false);
+            return ss.str();
+          })
       .def("apply", &Module::apply)
       .def("_copy_into", &Module::copy_into);
 
   py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
-    .def("graph", [&](Method& self) {
-      return self.graph();
-    })
-    .def("__call__", [](py::args args, py::kwargs kwargs) {
-      // see: [pybind11 varargs]
-      Method& method = py::cast<Method&>(args[0]);
-      return invokeScriptMethodFromPython(method, tuple_slice(std::move(args), 1), std::move(kwargs));
-    })
-    .def_property_readonly("graph", [](Method& m) {
-      return m.graph();
-    })
-    .def("propagate_shapes", &Method::propagate_shapes)
-    .def("propagate_and_assign_input_and_output_shapes", &Method::propagate_and_assign_input_and_output_shapes)
-    .def("params", &Method::params)
-    .def("graph_for", [](py::args args, py::kwargs kwargs) {
-      // see: [pybind11 varargs]
-      Method& self = py::cast<Method&>(args[0]);
-      return self.graph_for(createStackForSchema(self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
-    })
-    .def("debug_disable_autodiff_subgraph_inlining", &Method::debugDisableAutodiffSubgraphInlining)
-    .def("schema", &Method::getSchema)
-    .def("pretty_print_schema", &Method::pretty_print_schema)
-    .def("python_print", [](Method &m) {
-      std::ostringstream oss;
-      std::vector<at::Tensor> constants;
-      PythonPrint(oss, m, constants, true);
-      return std::make_pair(oss.str(), std::move(constants));
-    });
-
-  m.def("_jit_script_compile", [](std::shared_ptr<Module> mod, const Def &def, ResolutionCallback rcb, FunctionDefaults defaults) {
-    auto def_f = def.withName("forward");
-    defineMethodsInModule(mod, {def_f}, {pythonResolver(rcb)}, nullptr);
-    auto& method = mod->get_method("forward");
-    method.setSchema(getSchemaWithNameAndDefaults(
-        def.range(), method.getSchema(), def.name().name(), defaults));
-    didFinishEmitModule(mod);
-    return mod;
-  });
+      .def("graph", [&](Method& self) { return self.graph(); })
+      .def(
+          "__call__",
+          [](py::args args, py::kwargs kwargs) {
+            // see: [pybind11 varargs]
+            Method& method = py::cast<Method&>(args[0]);
+            return invokeScriptMethodFromPython(
+                method, tuple_slice(std::move(args), 1), std::move(kwargs));
+          })
+      .def_property_readonly("graph", [](Method& m) { return m.graph(); })
+      .def("propagate_shapes", &Method::propagate_shapes)
+      .def(
+          "propagate_and_assign_input_and_output_shapes",
+          &Method::propagate_and_assign_input_and_output_shapes)
+      .def("params", &Method::params)
+      .def(
+          "graph_for",
+          [](py::args args, py::kwargs kwargs) {
+            // see: [pybind11 varargs]
+            Method& self = py::cast<Method&>(args[0]);
+            return self.graph_for(createStackForSchema(
+                self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
+          })
+      .def(
+          "debug_disable_autodiff_subgraph_inlining",
+          &Method::debugDisableAutodiffSubgraphInlining)
+      .def("schema", &Method::getSchema)
+      .def("pretty_print_schema", &Method::pretty_print_schema)
+      .def("python_print", [](Method& m) {
+        std::ostringstream oss;
+        std::vector<at::Tensor> constants;
+        PythonPrint(oss, m, constants, true);
+        return std::make_pair(oss.str(), std::move(constants));
+      });
+
+  m.def(
+      "_jit_script_compile",
+      [](std::shared_ptr<Module> mod,
+         const Def& def,
+         ResolutionCallback rcb,
+         FunctionDefaults defaults) {
+        auto def_f = def.withName("forward");
+        defineMethodsInModule(mod, {def_f}, {pythonResolver(rcb)}, nullptr);
+        auto& method = mod->get_method("forward");
+        method.setSchema(getSchemaWithNameAndDefaults(
+            def.range(), method.getSchema(), def.name().name(), defaults));
+        didFinishEmitModule(mod);
+        return mod;
+      });
 
   m.def("parse_type_comment", [](const std::string& comment) {
     Parser p(comment);
@@ -735,25 +819,33 @@ void initJitScriptBindings(PyObject* module) {
   });
 
   m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
-  m.def("import_ir_module", [](ModuleLookup module_lookup, const std::string& filename,
-        py::object map_location) {
-    c10::optional<at::Device> optional_device;
-    if (!map_location.is(py::none())) {
-      AT_ASSERT(THPDevice_Check(map_location.ptr()));
-      optional_device = reinterpret_cast<THPDevice*>(map_location.ptr())->device;
-    }
-    import_ir_module(module_lookup, filename, optional_device);
-  });
-  m.def("import_ir_module_from_buffer", [](ModuleLookup module_lookup,
-        const std::string& buffer, py::object map_location) {
-    std::istringstream in(buffer);
-    c10::optional<at::Device> optional_device;
-    if (!map_location.is(py::none())) {
-      AT_ASSERT(THPDevice_Check(map_location.ptr()));
-      optional_device = reinterpret_cast<THPDevice*>(map_location.ptr())->device;
-    }
-    import_ir_module(module_lookup, in, optional_device);
-  });
+  m.def(
+      "import_ir_module",
+      [](ModuleLookup module_lookup,
+         const std::string& filename,
+         py::object map_location) {
+        c10::optional<at::Device> optional_device;
+        if (!map_location.is(py::none())) {
+          AT_ASSERT(THPDevice_Check(map_location.ptr()));
+          optional_device =
+              reinterpret_cast<THPDevice*>(map_location.ptr())->device;
+        }
+        import_ir_module(module_lookup, filename, optional_device);
+      });
+  m.def(
+      "import_ir_module_from_buffer",
+      [](ModuleLookup module_lookup,
+         const std::string& buffer,
+         py::object map_location) {
+        std::istringstream in(buffer);
+        c10::optional<at::Device> optional_device;
+        if (!map_location.is(py::none())) {
+          AT_ASSERT(THPDevice_Check(map_location.ptr()));
+          optional_device =
+              reinterpret_cast<THPDevice*>(map_location.ptr())->device;
+        }
+        import_ir_module(module_lookup, in, optional_device);
+      });
   m.def("_jit_import_methods", import_methods);
   m.def("_jit_set_emit_module_hook", setEmitModuleHook);
 }
index 6de8ad5..31ea21b 100644 (file)
@@ -5,11 +5,9 @@
 namespace torch {
 namespace jit {
 
-struct  JITException
-    : public std::runtime_error {
+struct JITException : public std::runtime_error {
   JITException() = default;
-  explicit JITException(const std::string& msg)
-      : std::runtime_error(msg) {}
+  explicit JITException(const std::string& msg) : std::runtime_error(msg) {}
 };
 
 } // namespace jit
index e320163..fe86e4b 100644 (file)
@@ -2,44 +2,44 @@
 
 #include <c10/util/Exception.h>
 
+#include <mutex>
 #include <string>
 #include <unordered_map>
-#include <mutex>
 
 namespace torch {
 namespace jit {
 namespace script {
 
 static const std::unordered_map<int, int> binary_prec = {
-    {TK_IF,         1},
-    {TK_AND,        2},
-    {TK_OR,         2},
+    {TK_IF, 1},
+    {TK_AND, 2},
+    {TK_OR, 2},
     // reserve a level for unary not
-    {'<',           4},
-    {'>',           4},
-    {TK_IS,         4},
-    {TK_ISNOT,      4},
-    {TK_EQ,         4},
-    {TK_LE,         4},
-    {TK_GE,         4},
-    {TK_NE,         4},
-    {'|',           5},
-    {'^',           6},
-    {'&',           7},
-    {'+',           8},
-    {'-',           8},
-    {'*',           9},
-    {'/',           9},
-    {TK_FLOOR_DIV,  9},
-    {'%',           9},
-    {'@',           9},
-    {TK_POW,       10},
+    {'<', 4},
+    {'>', 4},
+    {TK_IS, 4},
+    {TK_ISNOT, 4},
+    {TK_EQ, 4},
+    {TK_LE, 4},
+    {TK_GE, 4},
+    {TK_NE, 4},
+    {'|', 5},
+    {'^', 6},
+    {'&', 7},
+    {'+', 8},
+    {'-', 8},
+    {'*', 9},
+    {'/', 9},
+    {TK_FLOOR_DIV, 9},
+    {'%', 9},
+    {'@', 9},
+    {TK_POW, 10},
 };
 
 static const std::unordered_map<int, int> unary_prec = {
-    {TK_NOT,        3},
-    {'-',           9},
-    {'*',           9},
+    {TK_NOT, 3},
+    {'-', 9},
+    {'*', 9},
 };
 
 bool SharedParserData::isUnary(int kind, int* prec) {
@@ -66,7 +66,8 @@ int stringToKind(const std::string& str) {
     for (char tok : std::string(valid_single_char_tokens))
       str_to_kind[std::string(1, tok)] = tok;
 #define DEFINE_CASE(tok, _, str) \
-    if (std::string(str) != "") str_to_kind[str] = tok;
+  if (std::string(str) != "")    \
+    str_to_kind[str] = tok;
     TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
 #undef DEFINE_CASE
   });
index 6a0c178..7890d7d 100644 (file)
@@ -1,15 +1,15 @@
 #pragma once
+#include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/source_range.h>
+#include <torch/csrc/utils/memory.h>
 #include <algorithm>
+#include <clocale>
 #include <iostream>
 #include <memory>
 #include <sstream>
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/source_range.h>
-#include <torch/csrc/utils/memory.h>
-#include <clocale>
 
 namespace torch {
 namespace jit {
@@ -95,7 +95,6 @@ namespace script {
   _(TK_DOTS, "dots", "...")                      \
   _(TK_PASS, "pass", "pass")
 
-
 static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|";
 
 enum TokenKind {
@@ -150,21 +149,21 @@ struct SharedParserData {
       head->insert(str.c_str(), *c);
     }
 
-#define ADD_CASE(tok, _, tokstring) \
+#define ADD_CASE(tok, _, tokstring)   \
   if (*(tokstring) != '\0') {         \
-    head->insert((tokstring), (tok));   \
+    head->insert((tokstring), (tok)); \
   }
     TC_FORALL_TOKEN_KINDS(ADD_CASE)
 #undef ADD_CASE
   }
 #ifdef _WIN32
-  static double strtod_c(const char * str, char** end) {
+  static double strtod_c(const char* str, char** end) {
     /// NOLINTNEXTLINE(hicpp-signed-bitwise)
     static _locale_t loc = _create_locale(LC_ALL, "C");
     return _strtod_l(str, end, loc);
   }
 #else
-  static double strtod_c(const char * str, char** end) {
+  static double strtod_c(const char* str, char** end) {
     /// NOLINTNEXTLINE(hicpp-signed-bitwise)
     static locale_t loc = newlocale(LC_ALL_MASK, "C", nullptr);
     return strtod_l(str, end, loc);
@@ -189,8 +188,9 @@ struct SharedParserData {
   }
 
   bool isCharCount(char c, const std::string& str, size_t start, int len) {
-    //count checks from [start, start + len)
-    return start + len <= str.size() && std::count(str.begin() + start, str.begin() + start + len, c) == len;
+    // count checks from [start, start + len)
+    return start + len <= str.size() &&
+        std::count(str.begin() + start, str.begin() + start + len, c) == len;
   }
 
   // python concatenates all adjacent strings "a" "b" == "ab"
@@ -203,25 +203,25 @@ struct SharedParserData {
       return false;
     int quote_len = isCharCount(quote, str, start, 3) ? 3 : 1;
 
-    //end is now set past the opening quotation marks
+    // end is now set past the opening quotation marks
     size_t end = start + quote_len;
-    while(end < str.size() && !isCharCount(quote, str, end, quote_len)) {
+    while (end < str.size() && !isCharCount(quote, str, end, quote_len)) {
       if (str[end] == '\n' && quote_len != 3) {
         return false;
       }
-      //handle escaped characters. advances past escaped quotation marks,
-      //escaped newlines and escaped backslashes
-      //multi-char escapes like \x1A are handled fine here because the
-      //remainder of the escape are valid string characters anyway
+      // handle escaped characters. advances past escaped quotation marks,
+      // escaped newlines and escaped backslashes
+      // multi-char escapes like \x1A are handled fine here because the
+      // remainder of the escape are valid string characters anyway
       if (str[end] == '\\') {
         end++;
       }
       end++;
     }
-    //set length equal to the complete string including quotations
+    // set length equal to the complete string including quotations
     *len = end - start + quote_len;
-    //if end finished without going past the last character of the string than
-    //there is a match
+    // if end finished without going past the last character of the string than
+    // there is a match
     return end < str.size();
   }
 
@@ -238,8 +238,7 @@ struct SharedParserData {
     return match_string == type_string;
   }
   // find the longest match of str.substring(pos) against a token, return true
-  // if successful
-  // filling in kind, start,and len
+  // if successful filling in kind, start,and len
   bool match(
       const std::string& str,
       size_t pos,
@@ -273,9 +272,8 @@ struct SharedParserData {
             str, pos + 1, continuation, !continuation, kind, start, len);
       }
     }
-    // we handle white space before EOF because in the case we have something like
-    // the following where we need to generate the dedent token
-    // if foo:
+    // we handle white space before EOF because in the case we have something
+    // like the following where we need to generate the dedent token if foo:
     //   ...
     // else:
     //   pass
@@ -321,14 +319,15 @@ struct SharedParserData {
       // identifier 'max'
       if (cur) {
         size_t child_offset = 0;
-        for (size_t e = cur->child_chars.size(); child_offset < e; ++child_offset) {
+        for (size_t e = cur->child_chars.size(); child_offset < e;
+             ++child_offset) {
           if (cur->child_chars[child_offset] == str[pos + i])
-          break;
+            break;
         }
 
         cur = (child_offset == cur->child_chars.size())
-          ? nullptr
-          : cur->child_tries[child_offset].get();
+            ? nullptr
+            : cur->child_tries[child_offset].get();
 
         if (cur && cur->kind != 0) {
           matched = true;
@@ -405,7 +404,8 @@ struct Lexer {
 
   [[noreturn]] void reportError(const std::string& what) {
     reportError(what, cur());
-  }[[noreturn]] void reportError(const std::string& what, const Token& t) {
+  }
+  [[noreturn]] void reportError(const std::string& what, const Token& t) {
     std::stringstream ss;
     ss << what << ":\n";
     t.range.highlight(ss);
@@ -417,7 +417,8 @@ struct Lexer {
        << "' here:\n";
     t.range.highlight(ss);
     throw std::runtime_error(ss.str());
-  }[[noreturn]] void expected(const std::string& what) {
+  }
+  [[noreturn]] void expected(const std::string& what) {
     expected(what, cur());
   }
   // Check that the current token has a given kind, return the current token,
index a4275f5..3aa1418 100644 (file)
@@ -1,13 +1,14 @@
 #include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/jit/script/module.h>
-#include <torch/csrc/jit/script/compiler.h>
-#include <torch/csrc/jit/script/schema_matching.h>
-#include <torch/csrc/jit/script/error_report.h>
 #include <torch/csrc/jit/export.h>
 #include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/script/compiler.h>
+#include <torch/csrc/jit/script/error_report.h>
+#include <torch/csrc/jit/script/module.h>
+#include <torch/csrc/jit/script/schema_matching.h>
 
-namespace torch { namespace jit { namespace script {
-
+namespace torch {
+namespace jit {
+namespace script {
 
 struct RecursiveMethodCallError : public std::exception {};
 void placeholderCreator(Method&) {
@@ -27,22 +28,30 @@ Value* try_emit_call_to(
   try {
     callee.ensure_defined();
   } catch (RecursiveMethodCallError&) {
-    throw ErrorReport(loc) << " method '" << callee.name()
+    throw ErrorReport(loc)
+        << " method '" << callee.name()
         << "' is called recursively involving this call site. Recursive calls are not supported";
   }
   auto fn = callee.graph();
 
   auto matched_schema = tryMatchSchema(
-    callee.getSchema(),
-    loc, graph, std::move(self), args, kwargs, failure_messages, conv_tensors_to_nums);
-  if(!matched_schema)
+      callee.getSchema(),
+      loc,
+      graph,
+      std::move(self),
+      args,
+      kwargs,
+      failure_messages,
+      conv_tensors_to_nums);
+  if (!matched_schema)
     return nullptr;
 
   // parameters to callee method (which become parameters to _this_ method
   // if they were not already)
-  for(at::Tensor* member : callee.params()) {
-    if(!caller) {
-      throw ErrorReport(loc) << " attempting to call a method with parameters from a raw graph. File a bug report";
+  for (at::Tensor* member : callee.params()) {
+    if (!caller) {
+      throw ErrorReport(loc)
+          << " attempting to call a method with parameters from a raw graph. File a bug report";
     }
     matched_schema->inputs.push_back(caller->get_or_add_parameter(member));
   }
@@ -50,7 +59,11 @@ Value* try_emit_call_to(
   return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0);
 }
 
-Value* Method::emit_call_to(const SourceRange& loc, Method & callee, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs) {
+Value* Method::emit_call_to(
+    const SourceRange& loc,
+    Method& callee,
+    ArrayRef<NamedValue> args,
+    ArrayRef<NamedValue> kwargs) {
   JIT_ASSERT(!executor);
   std::stringstream failure_messages;
   if (auto result = try_emit_call_to(
@@ -69,7 +82,7 @@ Value* Method::emit_call_to(const SourceRange& loc, Method & callee, ArrayRef<Na
 }
 
 void Method::ensure_defined() {
-  if(method_creator) {
+  if (method_creator) {
     auto creator = method_creator;
     method_creator = placeholderCreator;
     creator(*this);
@@ -119,4 +132,6 @@ void Module::to_impl(
   }
 }
 
-}}}
+} // namespace script
+} // namespace jit
+} // namespace torch
index ae6375c..8712404 100644 (file)
@@ -1,17 +1,17 @@
 #pragma once
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/autograd/variable.h>
-#include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/argument_spec.h>
-#include <torch/csrc/jit/function_schema.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/function_schema.h>
+#include <torch/csrc/jit/graph_executor.h>
+#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/named_value.h>
+#include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/source_range.h>
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/api/include/torch/ordered_dict.h>
 #include <torch/csrc/utils/memory.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <c10/util/ArrayRef.h>
 #include <c10/util/Optional.h>
 #include <functional>
 #include <memory>
 #include <mutex>
+#include <ostream>
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include <ostream>
 
 // This file contains classes which assist in desugaring Python style
 // modules and their methods into flattened graphs which don't have any
 // function calls.
 
-namespace torch { namespace jit { namespace script {
+namespace torch {
+namespace jit {
+namespace script {
 
 // A method in a module, e.g. f in:
 //
@@ -42,25 +44,28 @@ namespace torch { namespace jit { namespace script {
 struct Module;
 
 struct Method {
-  Method(Module* owner, std::string name, bool optimize,
-         std::shared_ptr<Graph> graph,
-         std::vector<at::Tensor*> initial_members,
-         std::function<void(Method&)> method_creator)
-  : owner_(owner)
-  , name_(std::move(name))
-  , graph_(std::move(graph))
-  , optimize(optimize)
-  , member_inputs(std::move(initial_members))
-  , method_creator(std::move(method_creator)) {
+  Method(
+      Module* owner,
+      std::string name,
+      bool optimize,
+      std::shared_ptr<Graph> graph,
+      std::vector<at::Tensor*> initial_members,
+      std::function<void(Method&)> method_creator)
+      : owner_(owner),
+        name_(std::move(name)),
+        graph_(std::move(graph)),
+        optimize(optimize),
+        member_inputs(std::move(initial_members)),
+        method_creator(std::move(method_creator)) {
     JIT_ASSERT(graph_->inputs().size() >= member_inputs.size());
     int i = graph_->inputs().size() - member_inputs.size();
-    for(at::Tensor* member : member_inputs) {
+    for (at::Tensor* member : member_inputs) {
       member_input_index[member] = i++;
     }
   }
 
-  void run(Stack & stack) {
-    for(at::Tensor* tp : member_inputs) {
+  void run(Stack& stack) {
+    for (at::Tensor* tp : member_inputs) {
       stack.emplace_back(*tp);
     }
     get_executor().run(stack);
@@ -73,7 +78,7 @@ struct Method {
   }
 
   std::shared_ptr<Graph> graph_for(Stack inputs) {
-    for(at::Tensor* tp : member_inputs) {
+    for (at::Tensor* tp : member_inputs) {
       inputs.emplace_back(*tp);
     }
     return get_executor().graphFor(inputs);
@@ -82,25 +87,29 @@ struct Method {
     return graph_;
   }
 
-  TORCH_API const std::string & name() const {
+  TORCH_API const std::string& name() const {
     return name_;
   }
   // emit a function call by inlining the callees Graph into this one
   // adding any extra parameters necessary to do this call
 
-  // defined here to keep details of member_input handling confined to this class
-  Value* emit_call_to(const SourceRange& loc, Method & callee, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs);
+  // defined here to keep details of member_input handling confined to this
+  // class
+  Value* emit_call_to(
+      const SourceRange& loc,
+      Method& callee,
+      ArrayRef<NamedValue> args,
+      ArrayRef<NamedValue> kwargs);
 
   // if this isn't yet defined, run its method_creator function
   TORCH_API void ensure_defined();
 
-
   size_t num_inputs() const {
     return graph()->inputs().size() - member_inputs.size();
   }
-  TORCH_API Value * get_or_add_parameter(at::Tensor* slot) {
+  TORCH_API Value* get_or_add_parameter(at::Tensor* slot) {
     auto it = member_input_index.find(slot);
-    if(it != member_input_index.end()) {
+    if (it != member_input_index.end()) {
       return graph()->inputs().at(it->second);
     }
     // add it as a new parameter
@@ -109,11 +118,13 @@ struct Method {
     return graph()->addInput();
   }
 
-  std::shared_ptr<Graph> propagate_shapes(std::vector<at::Tensor> inputs, bool with_grad=false) {
+  std::shared_ptr<Graph> propagate_shapes(
+      std::vector<at::Tensor> inputs,
+      bool with_grad = false) {
     auto retval = graph_->copy();
     Stack stack;
     stack.reserve(inputs.size() + member_inputs.size());
-    for (at::Tensor & i : inputs) {
+    for (at::Tensor& i : inputs) {
       stack.emplace_back(std::move(i));
     }
     for (at::Tensor* inp : member_inputs) {
@@ -125,33 +136,41 @@ struct Method {
     return retval;
   }
 
-  std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, bool with_grad=false, bool propagate=true) {
+  std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(
+      std::vector<at::Tensor> inputs,
+      std::vector<at::Tensor> outputs,
+      bool with_grad = false,
+      bool propagate = true) {
     auto retval = graph_->copy();
     for (auto inp : member_inputs) {
       inputs.push_back(*inp);
     }
     if (propagate) {
-      setInputTypes(*retval, ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
+      setInputTypes(
+          *retval,
+          ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
       PropagateInputShapes(retval);
     }
     JIT_ASSERT(retval->inputs().size() == inputs.size());
-    for (size_t i=0; i < retval->inputs().size(); ++i) {
+    for (size_t i = 0; i < retval->inputs().size(); ++i) {
       auto scalar_type = inputs[i].type().scalarType();
       auto sizes = inputs[i].sizes();
-      auto type = torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+      auto type =
+          torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
       retval->inputs()[i]->setType(type);
     }
     at::ArrayRef<Value*> output_values = retval->outputs();
     // patch this to still work if we are returning a tuple of multiple values
     if (output_values.at(0)->type()->kind() == TupleType::Kind) {
-      JIT_ASSERT(output_values.at(0)->node()->kind()== prim::TupleConstruct);
+      JIT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
       output_values = output_values.at(0)->node()->inputs();
     }
     JIT_ASSERT(output_values.size() == outputs.size());
-    for (size_t i=0; i < retval->outputs().size(); ++i) {
+    for (size_t i = 0; i < retval->outputs().size(); ++i) {
       auto scalar_type = outputs[i].type().scalarType();
       auto sizes = outputs[i].sizes();
-      auto type = torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+      auto type =
+          torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
       output_values[i]->setType(type);
     }
     return retval;
@@ -167,7 +186,7 @@ struct Method {
   }
 
   TORCH_API const FunctionSchema& getSchema() const {
-    if(schema == nullptr) {
+    if (schema == nullptr) {
       schema = make_unique<FunctionSchema>(defaultSchemaFor(*this));
     }
     return *schema;
@@ -202,22 +221,23 @@ struct Method {
         graph()->outputs().size() == 1,
         "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
   }
-private:
 
+ private:
   static FunctionSchema defaultSchemaFor(const Method& method) {
     std::vector<Argument> args;
     std::vector<Argument> returns;
     Graph& g = *method.graph();
     size_t num_inputs = method.num_inputs();
-    for(size_t i = 0; i < num_inputs; ++i) {
+    for (size_t i = 0; i < num_inputs; ++i) {
       const Value* v = g.inputs().at(i);
-      std::string name = v->hasUniqueName() ? v->uniqueNameBase() : ("argument_"  + std::to_string(i));
+      std::string name = v->hasUniqueName() ? v->uniqueNameBase()
+                                            : ("argument_" + std::to_string(i));
       args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
     }
-    for(size_t i = 0; i < g.outputs().size(); ++i) {
+    for (size_t i = 0; i < g.outputs().size(); ++i) {
       returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
     }
-    return { method.name(), std::move(args), std::move(returns) };
+    return {method.name(), std::move(args), std::move(returns)};
   }
 
   GraphExecutor& get_executor() {
@@ -233,9 +253,14 @@ private:
     // Do we have more inputs than the schema accepts?
     AT_CHECK(
         inputs.size() <= schema.arguments().size(),
-        "Expected at most ", schema.arguments().size(),
-        " argument(s) for operator '", schema.name(), "', but received ",
-        inputs.size(), " argument(s). Declaration: ", schema);
+        "Expected at most ",
+        schema.arguments().size(),
+        " argument(s) for operator '",
+        schema.name(),
+        "', but received ",
+        inputs.size(),
+        " argument(s). Declaration: ",
+        schema);
 
     for (size_t pos = 0; pos < schema.arguments().size(); ++pos) {
       const auto& argument = schema.arguments()[pos];
@@ -244,22 +269,31 @@ private:
         // and should be replaced with a function isSubvalueOf(ivalue, type)
         // That asks if the specific value is a valid instance of type.
         const TypePtr inputType = incompleteInferTypeFrom(inputs[pos]);
-        AT_CHECK(inputType->isSubtypeOf(argument.type()),
-              "Expected value of type ", *argument.type(),
-              " for argument '", argument.name(),
-              "' in position ", pos,
-              ", but instead got value of type ", *inputType,
-              ". Declaration: ", schema);
+        AT_CHECK(
+            inputType->isSubtypeOf(argument.type()),
+            "Expected value of type ",
+            *argument.type(),
+            " for argument '",
+            argument.name(),
+            "' in position ",
+            pos,
+            ", but instead got value of type ",
+            *inputType,
+            ". Declaration: ",
+            schema);
       } else if (argument.default_value()) {
         inputs.push_back(*argument.default_value());
       } else {
-        AT_ERROR(schema.name(), "() is missing value for argument '",
-                argument.name(), "'. Declaration: ", schema);
+        AT_ERROR(
+            schema.name(),
+            "() is missing value for argument '",
+            argument.name(),
+            "'. Declaration: ",
+            schema);
       }
     }
   }
 
-
   // Methods are uniqued onwed by a single module. This raw pointer allows
   // looking up the module.
   Module* owner_;
@@ -290,9 +324,9 @@ private:
 
   std::once_flag executor_init;
 
-  // an optional function that actually creates the method when emit_call_to(this,...)
-  // is first called.
-  // this is used by the compiler so that it can construct methods out of order
+  // an optional function that actually creates the method when
+  // emit_call_to(this,...) is first called. this is used by the compiler so
+  // that it can construct methods out of order
   std::function<void(Method&)> method_creator;
 
   // if absent, then we generate a default schema based on the graph
@@ -310,17 +344,18 @@ struct NamedModule {
 
 struct NamedParameter {
   NamedParameter(std::string name, at::Tensor tensor, bool is_buffer)
-  : name(std::move(name))
-  , is_buffer(is_buffer)
-  , parameter(torch::make_unique<at::Tensor>(std::move(tensor))) {}
+      : name(std::move(name)),
+        is_buffer(is_buffer),
+        parameter(torch::make_unique<at::Tensor>(std::move(tensor))) {}
 
   const std::string name;
   bool is_buffer; // buffers are part of the module state but
-                        // are not modified by optimizers during SGD
+                  // are not modified by optimizers during SGD
   at::Tensor* slot() const {
     return parameter.get();
   }
-private:
+
+ private:
   // the extra level of indirection allows Methods to safely store pointers
   // to the slots where parameters are kept while also allow parameters
   // to be reassigned
@@ -330,10 +365,10 @@ private:
 struct Module {
   TH_DISALLOW_COPY_AND_ASSIGN(Module);
   Module()
-  : modules("Module")
-  , parameters("Parameter")
-  , methods("Method")
-  , optimize(true) {}
+      : modules("Module"),
+        parameters("Parameter"),
+        methods("Method"),
+        optimize(true) {}
 
   // note this doesn't change the flags of existing methods just ones
   // added afterward.
@@ -349,34 +384,56 @@ struct Module {
     return get_method("forward")(std::move(inputs));
   }
 
-  void register_parameter(const std::string & name, autograd::Variable v, bool is_buffer) {
-    if(auto p = parameters.find(name)){
+  void register_parameter(
+      const std::string& name,
+      autograd::Variable v,
+      bool is_buffer) {
+    if (auto p = parameters.find(name)) {
       *p->slot() = v;
       p->is_buffer = is_buffer;
       return;
     }
     parameters.insert(name, NamedParameter(name, std::move(v), is_buffer));
   }
-  void register_module(const std::string& name, std::shared_ptr<Module> module) {
+  void register_module(
+      const std::string& name,
+      std::shared_ptr<Module> module) {
     modules.insert(name, {name, std::move(module)});
   }
 
-  Method& create_method(const std::string & name, std::shared_ptr<Graph> graph, std::vector<at::Tensor*> member_inputs) {
+  Method& create_method(
+      const std::string& name,
+      std::shared_ptr<Graph> graph,
+      std::vector<at::Tensor*> member_inputs) {
     JIT_ASSERT(graph);
-    std::unique_ptr<Method> method(new Method(this, name, optimize, std::move(graph), std::move(member_inputs), nullptr));
+    std::unique_ptr<Method> method(new Method(
+        this,
+        name,
+        optimize,
+        std::move(graph),
+        std::move(member_inputs),
+        nullptr));
     return *methods.insert(name, std::move(method));
   }
 
-  Method& create_method(const std::string & name, std::function<void(Method&)> creator) {
-    std::unique_ptr<Method> method(new Method(this, name, optimize, std::make_shared<Graph>(), {}, std::move(creator)));
+  Method& create_method(
+      const std::string& name,
+      std::function<void(Method&)> creator) {
+    std::unique_ptr<Method> method(new Method(
+        this,
+        name,
+        optimize,
+        std::make_shared<Graph>(),
+        {},
+        std::move(creator)));
     return *methods.insert(name, std::move(method));
   }
 
-  at::Tensor* parameter_slot(const std::string & name) const {
+  at::Tensor* parameter_slot(const std::string& name) const {
     return parameters[name].slot();
   }
 
-  void set_parameter(const std::string & name, at::Tensor v) {
+  void set_parameter(const std::string& name, at::Tensor v) {
     *parameter_slot(name) = std::move(v);
   }
 
@@ -397,10 +454,12 @@ struct Module {
   const torch::OrderedDict<std::string, NamedModule>& get_modules() const {
     return modules;
   }
-  const torch::OrderedDict<std::string, NamedParameter>& get_parameters() const {
+  const torch::OrderedDict<std::string, NamedParameter>& get_parameters()
+      const {
     return parameters;
   }
-  const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods() const {
+  const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods()
+      const {
     return methods;
   }
 
@@ -417,7 +476,7 @@ struct Module {
     return nullptr;
   }
   void apply(std::function<void(Module&)> fn) {
-    for (auto &submod : get_modules()) {
+    for (autosubmod : get_modules()) {
       submod.value().module->apply(fn);
     }
     fn(*this);
@@ -472,26 +531,29 @@ struct Module {
 
   void save(const std::string& filename);
 
-  void copy_into(std::function<std::shared_ptr<Module>(
-      std::vector<std::string>)> module_lookup,
-      // parameter_remap is needed when a parent module uses a parameter of a submodule
+  void copy_into(
+      std::function<std::shared_ptr<Module>(std::vector<std::string>)>
+          module_lookup,
+      // parameter_remap is needed when a parent module uses a parameter of a
+      // submodule
       std::unordered_map<at::Tensor*, at::Tensor*>& parameter_remap,
       std::vector<std::string> names = {}) const {
     auto curr = module_lookup(names);
-    for (auto &kv : parameters) {
-      curr->register_parameter(kv.key(), *kv.value().slot(), kv.value().is_buffer);
+    for (auto& kv : parameters) {
+      curr->register_parameter(
+          kv.key(), *kv.value().slot(), kv.value().is_buffer);
       parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key());
     }
-    for (auto &kv : modules) {
+    for (autokv : modules) {
       names.push_back(kv.key());
       // Submodules must be translated first, otherwise parameter_remap entries
       // will not be filled in for methods of this module.
       kv.value().module->copy_into(module_lookup, parameter_remap, names);
       names.pop_back();
     }
-    for (auto &kv : methods) {
+    for (autokv : methods) {
       std::vector<at::Tensor*> params;
-      for (auto &p : kv.value()->params()) {
+      for (autop : kv.value()->params()) {
         params.push_back(parameter_remap.at(p));
       }
       curr->create_method(kv.key(), kv.value()->graph()->copy(), params);
@@ -528,4 +590,6 @@ Value* try_emit_call_to(
     // unit, and not a method), then nullptr can be passed as caller.
     Method* caller,
     bool conv_tensors_to_nums);
-}}}
+} // namespace script
+} // namespace jit
+} // namespace torch
index aa34c5d..e8effe3 100644 (file)
@@ -1,15 +1,16 @@
 #pragma once
-#include <torch/csrc/jit/script/lexer.h>
-#include <torch/csrc/jit/script/error_report.h>
 #include <c10/util/Optional.h>
+#include <torch/csrc/jit/script/error_report.h>
+#include <torch/csrc/jit/script/lexer.h>
 
 namespace torch {
 namespace jit {
 namespace script {
 
 inline bool isCharCount(char c, const std::string& str, size_t start, int len) {
-  //count checks from [start, start + len)
-  return start + len <= str.size() && std::count(str.begin() + start, str.begin() + start + len, c) == len;
+  // count checks from [start, start + len)
+  return start + len <= str.size() &&
+      std::count(str.begin() + start, str.begin() + start + len, c) == len;
 }
 
 inline static bool isOctal(char c) {
@@ -21,23 +22,25 @@ inline c10::optional<char> parseOctal(const std::string& str, size_t pos) {
   if (pos + 3 >= str.size())
     return c10::nullopt;
   size_t c = 0;
-  for(size_t i = 1, b = 64; i < 4; ++i, b /= 8) {
+  for (size_t i = 1, b = 64; i < 4; ++i, b /= 8) {
     int d = str[pos + i];
     if (d < '0' || d > '7')
       return c10::nullopt;
     c += b * (d - '0');
   }
-  if(c >= 256)
+  if (c >= 256)
     return c10::nullopt;
   return c;
 }
 
-inline std::string parseStringLiteral(const SourceRange& range, const std::string &str) {
+inline std::string parseStringLiteral(
+    const SourceRange& range,
+    const std::string& str) {
   int quote_len = isCharCount(str[0], str, 0, 3) ? 3 : 1;
   auto ret_str = str.substr(quote_len, str.size() - quote_len * 2);
   size_t pos = ret_str.find('\\');
-  while(pos != std::string::npos) {
-    //invariant: pos has to escape a character because it is a valid string
+  while (pos != std::string::npos) {
+    // invariant: pos has to escape a character because it is a valid string
     char c = ret_str[pos + 1];
     size_t to_erase = 2;
     switch (ret_str[pos + 1]) {
@@ -65,16 +68,14 @@ inline std::string parseStringLiteral(const SourceRange& range, const std::strin
         c = '\t';
         break;
       case 'h':
-        throw ErrorReport(range)
-            << "unsupported hex specifier";
+        throw ErrorReport(range) << "unsupported hex specifier";
       default:
         // \0NN
         if (auto v = parseOctal(str, pos + 1)) {
           to_erase = 4;
           c = *v;
         } else {
-          throw ErrorReport(range)
-              << " ill formed octal specifier";
+          throw ErrorReport(range) << " ill formed octal specifier";
         }
     }
     ret_str.replace(pos, to_erase, /* num copies */ 1, c);
index 95ae7e2..c140779 100644 (file)
@@ -1,24 +1,29 @@
+#include <c10/util/Optional.h>
 #include <torch/csrc/jit/script/lexer.h>
-#include <torch/csrc/jit/script/tree.h>
+#include <torch/csrc/jit/script/parse_string_literal.h>
 #include <torch/csrc/jit/script/parser.h>
+#include <torch/csrc/jit/script/tree.h>
 #include <torch/csrc/jit/script/tree_views.h>
-#include <c10/util/Optional.h>
-#include <torch/csrc/jit/script/parse_string_literal.h>
 
 namespace torch {
 namespace jit {
 namespace script {
 
-Decl mergeTypesFromTypeComment(const Decl& decl, const Decl& type_annotation_decl, bool is_method) {
+Decl mergeTypesFromTypeComment(
+    const Decl& decl,
+    const Decl& type_annotation_decl,
+    bool is_method) {
   auto expected_num_annotations = decl.params().size();
   if (is_method) {
     // `self` argument
     expected_num_annotations -= 1;
   }
   if (expected_num_annotations != type_annotation_decl.params().size()) {
-    throw ErrorReport(type_annotation_decl.range()) << "Number of type annotations ("
-      << type_annotation_decl.params().size() << ") did not match the number of "
-      << "function parameters (" << expected_num_annotations << ")";
+    throw ErrorReport(type_annotation_decl.range())
+        << "Number of type annotations ("
+        << type_annotation_decl.params().size()
+        << ") did not match the number of "
+        << "function parameters (" << expected_num_annotations << ")";
   }
   auto old = decl.params();
   auto _new = type_annotation_decl.params();
@@ -33,7 +38,10 @@ Decl mergeTypesFromTypeComment(const Decl& decl, const Decl& type_annotation_dec
   for (; i < decl.params().size(); ++i, ++j) {
     new_params.emplace_back(old[i].withType(_new[j].type()));
   }
-  return Decl::create(decl.range(), List<Param>::create(decl.range(), new_params), type_annotation_decl.return_type());
+  return Decl::create(
+      decl.range(),
+      List<Param>::create(decl.range(), new_params),
+      type_annotation_decl.return_type());
 }
 
 struct ParserImpl {
@@ -60,7 +68,7 @@ struct ParserImpl {
   }
 
   static bool followsTuple(int kind) {
-    switch(kind) {
+    switch (kind) {
       case TK_PLUS_EQ:
       case TK_MINUS_EQ:
       case TK_TIMES_EQ:
@@ -77,9 +85,9 @@ struct ParserImpl {
   // exp | expr, | expr, expr, ...
   Expr parseExpOrExpTuple() {
     auto prefix = parseExp();
-    if(L.cur().kind == ',') {
-      std::vector<Expr> exprs = { prefix };
-      while(L.nextIf(',')) {
+    if (L.cur().kind == ',') {
+      std::vector<Expr> exprs = {prefix};
+      while (L.nextIf(',')) {
         if (followsTuple(L.cur().kind))
           break;
         exprs.push_back(parseExp());
@@ -159,8 +167,10 @@ struct ParserImpl {
       } break;
     }
   }
-  TreeRef
-  parseTrinary(TreeRef true_branch, const SourceRange& range, int binary_prec) {
+  TreeRef parseTrinary(
+      TreeRef true_branch,
+      const SourceRange& range,
+      int binary_prec) {
     auto cond = parseExp();
     L.expect(TK_ELSE);
     auto false_branch = parseExp(binary_prec);
@@ -170,7 +180,9 @@ struct ParserImpl {
   // precedence strictly greater than 'precedence'
   // precedence == 0 will parse _all_ expressions
   // this is the core loop of 'top-down precedence parsing'
-  Expr parseExp() { return parseExp(0); }
+  Expr parseExp() {
+    return parseExp(0);
+  }
   Expr parseExp(int precedence) {
     TreeRef prefix = nullptr;
     int unary_prec;
@@ -178,13 +190,12 @@ struct ParserImpl {
       auto kind = L.cur().kind;
       auto pos = L.cur().range;
       L.next();
-      auto unary_kind = kind == '*' ? TK_STARRED :
-                        kind == '-' ? TK_UNARY_MINUS :
-                                      kind;
+      auto unary_kind =
+          kind == '*' ? TK_STARRED : kind == '-' ? TK_UNARY_MINUS : kind;
       auto subexp = parseExp(unary_prec);
       // fold '-' into constant numbers, so that attributes can accept
       // things like -1
-      if(unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
+      if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
         prefix = Const::create(subexp.range(), "-" + Const(subexp).text());
       } else {
         prefix = c(unary_kind, pos, {subexp});
@@ -214,7 +225,7 @@ struct ParserImpl {
     }
     return Expr(prefix);
   }
-  template<typename T>
+  template <typename T>
   List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
     auto r = L.cur().range;
     if (begin != TK_NOTHING)
@@ -239,7 +250,7 @@ struct ParserImpl {
   StringLiteral parseConcatenatedStringLiterals() {
     auto range = L.cur().range;
     std::stringstream ss;
-    while(L.cur().kind == TK_STRINGLITERAL) {
+    while (L.cur().kind == TK_STRINGLITERAL) {
       auto literal_range = L.cur().range;
       ss << parseStringLiteral(literal_range, L.next().text());
     }
@@ -258,7 +269,8 @@ struct ParserImpl {
           auto ident = parseIdent();
           L.expect('=');
           auto v = parseAttributeValue();
-          attributes.push_back(Attribute::create(ident.range(), Ident(ident), v));
+          attributes.push_back(
+              Attribute::create(ident.range(), Ident(ident), v));
         } else {
           inputs.push_back(parseExp());
         }
@@ -278,8 +290,10 @@ struct ParserImpl {
       if (L.cur().kind != ',' && L.cur().kind != ']') {
         second = parseExp();
       }
-      auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first)) : Maybe<Expr>::create(range);
-      auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second)) : Maybe<Expr>::create(range);
+      auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first))
+                               : Maybe<Expr>::create(range);
+      auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second))
+                                 : Maybe<Expr>::create(range);
       return SliceExpr::create(range, maybe_first, maybe_second);
     } else {
       return Expr(first);
@@ -289,7 +303,8 @@ struct ParserImpl {
   TreeRef parseSubscript(const TreeRef& value) {
     const auto range = L.cur().range;
 
-    auto subscript_exprs = parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
+    auto subscript_exprs =
+        parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
     return Subscript::create(range, Expr(value), subscript_exprs);
   }
 
@@ -307,18 +322,24 @@ struct ParserImpl {
     } else {
       def = Maybe<Expr>::create(L.cur().range);
     }
-    return Param::create(type->range(), Ident(ident), Expr(type), Maybe<Expr>(def));
+    return Param::create(
+        type->range(), Ident(ident), Expr(type), Maybe<Expr>(def));
   }
 
   Param parseBareTypeAnnotation() {
     auto type = parseExp();
-    return Param::create(type.range(), Ident::create(type.range(), ""), type, Maybe<Expr>::create(type.range()));
+    return Param::create(
+        type.range(),
+        Ident::create(type.range(), ""),
+        type,
+        Maybe<Expr>::create(type.range()));
   }
 
   Decl parseTypeComment() {
     auto range = L.cur().range;
     L.expect(TK_TYPE_COMMENT);
-    auto param_types = parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
+    auto param_types =
+        parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
     TreeRef return_type;
     if (L.nextIf(TK_ARROW)) {
       auto return_type_range = L.cur().range;
@@ -344,8 +365,7 @@ struct ParserImpl {
         throw ErrorReport(lhs.range())
             << " augmented assignment can only have one LHS expression";
       }
-      return AugAssign::create(
-          lhs.range(), lhs, AugAssignKind(op), Expr(rhs));
+      return AugAssign::create(lhs.range(), lhs, AugAssignKind(op), Expr(rhs));
     }
   }
 
@@ -359,7 +379,8 @@ struct ParserImpl {
         return parseFor();
       case TK_GLOBAL: {
         auto range = L.next().range;
-        auto idents = parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
+        auto idents =
+            parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
         L.expect(TK_NEWLINE);
         return Global::create(range, idents);
       }
@@ -380,7 +401,7 @@ struct ParserImpl {
         auto range = L.next().range;
         auto cond = parseExp();
         Maybe<Expr> maybe_first = Maybe<Expr>::create(range);
-        if (L.nextIf(','))  {
+        if (L.nextIf(',')) {
           auto msg = parseExp();
           maybe_first = Maybe<Expr>::create(range, Expr(msg));
         }
@@ -415,7 +436,7 @@ struct ParserImpl {
     }
     return list;
   }
-  TreeRef parseIf(bool expect_if=true) {
+  TreeRef parseIf(bool expect_if = true) {
     auto r = L.cur().range;
     if (expect_if)
       L.expect(TK_IF);
@@ -433,7 +454,8 @@ struct ParserImpl {
       auto range = L.cur().range;
       false_branch = makeList(range, {parseIf(false)});
     }
-    return If::create(r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
+    return If::create(
+        r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
   }
   TreeRef parseWhile() {
     auto r = L.cur().range;
@@ -446,7 +468,8 @@ struct ParserImpl {
   TreeRef parseFor() {
     auto r = L.cur().range;
     L.expect(TK_FOR);
-    auto targets = parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
+    auto targets =
+        parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
     L.expect(TK_IN);
     auto itrs = parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
     L.expect(':');
@@ -454,7 +477,7 @@ struct ParserImpl {
     return For::create(r, targets, itrs, body);
   }
 
-  TreeRef parseStatements(bool expect_indent=true) {
+  TreeRef parseStatements(bool expect_indent = true) {
     auto r = L.cur().range;
     if (expect_indent) {
       L.expect(TK_INDENT);
@@ -462,7 +485,7 @@ struct ParserImpl {
     TreeList stmts;
     do {
       stmts.push_back(parseStmt());
-    } while(!L.nextIf(TK_DEDENT));
+    } while (!L.nextIf(TK_DEDENT));
     return c(TK_LIST, r, std::move(stmts));
   }
 
@@ -482,7 +505,8 @@ struct ParserImpl {
     TreeRef return_type;
     Maybe<Expr> return_annotation = parseReturnAnnotation();
     L.expect(':');
-    return Decl::create(paramlist.range(), List<Param>(paramlist), return_annotation);
+    return Decl::create(
+        paramlist.range(), List<Param>(paramlist), return_annotation);
   }
 
   TreeRef parseFunction(bool is_method) {
@@ -500,8 +524,8 @@ struct ParserImpl {
     }
 
     auto stmts_list = parseStatements(false);
-    return Def::create(name.range(), Ident(name), Decl(decl),
-                       List<Stmt>(stmts_list));
+    return Def::create(
+        name.range(), Ident(name), Decl(decl), List<Stmt>(stmts_list));
   }
   Lexer& lexer() {
     return L;
@@ -519,8 +543,7 @@ struct ParserImpl {
   SharedParserData& shared;
 };
 
-Parser::Parser(const std::string& src)
-: pImpl(new ParserImpl(src)) {}
+Parser::Parser(const std::string& src) : pImpl(new ParserImpl(src)) {}
 
 Parser::~Parser() = default;
 
index 8b9d710..b695e2c 100644 (file)
@@ -1,7 +1,7 @@
 #pragma once
-#include <memory>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/script/tree.h>
+#include <memory>
 
 namespace torch {
 namespace jit {
@@ -11,7 +11,10 @@ struct Decl;
 struct ParserImpl;
 struct Lexer;
 
-TORCH_API Decl mergeTypesFromTypeComment(const Decl& decl, const Decl& type_annotation_decl, bool is_method);
+TORCH_API Decl mergeTypesFromTypeComment(
+    const Decl& decl,
+    const Decl& type_annotation_decl,
+    bool is_method);
 
 struct TORCH_API Parser {
   explicit Parser(const std::string& str);
@@ -19,7 +22,8 @@ struct TORCH_API Parser {
   Decl parseTypeComment();
   Lexer& lexer();
   ~Parser();
-private:
+
+ private:
   std::unique_ptr<ParserImpl> pImpl;
 };
 
index d42fc0e..3d5d399 100644 (file)
 
 namespace py = pybind11;
 
-namespace torch { namespace jit { namespace script {
+namespace torch {
+namespace jit {
+namespace script {
 
 struct SourceRangeFactory {
   SourceRangeFactory(std::string source)
-    : source_(std::make_shared<std::string>(std::move(source))) {
+      : source_(std::make_shared<std::string>(std::move(source))) {
     size_t pos = 0;
     do {
       line_len_prefix_sum_.push_back(pos);
@@ -22,8 +24,8 @@ struct SourceRangeFactory {
     } while ((pos = source_->find('\n', pos)) != std::string::npos);
   }
   SourceRange create(int line, int start_col, int end_col) {
-    // Python has a weird convention where col_offset points to the column *before*
-    // the token starts.
+    // Python has a weird convention where col_offset points to the column
+    // *before* the token starts.
     start_col++;
     end_col++;
     // Also, lines are counted from 1.
@@ -36,47 +38,52 @@ struct SourceRangeFactory {
   std::vector<size_t> line_len_prefix_sum_;
 };
 
-template<typename T>
+template <typename T>
 List<T> wrap_list(const SourceRange& fallback_pos, std::vector<T>&& vec) {
   if (vec.empty())
     return List<T>::create(fallback_pos, std::move(vec));
   return List<T>::create(vec.front().range(), std::move(vec));
 }
 
-template<typename T>
+template <typename T>
 Maybe<T> wrap_maybe(const SourceRange& fallback_pos, T* val) {
-  return val ? Maybe<T>::create(val->range(), *val) : Maybe<T>::create(fallback_pos);
+  return val ? Maybe<T>::create(val->range(), *val)
+             : Maybe<T>::create(fallback_pos);
 }
 
-void initTreeViewBindings(PyObject *module) {
+void initTreeViewBindings(PyObjectmodule) {
   auto _C = py::handle(module).cast<py::module>();
   auto m = _C.def_submodule("_jit_tree_views");
 
   py::class_<SourceRange>(m, "SourceRange")
-    .def("highlight", [](const SourceRange& self) {
-      std::ostringstream stream;
-      self.highlight(stream);
-      return stream.str();
-    })
-    .def_property_readonly("start", &SourceRange::start)
-    .def_property_readonly("end", &SourceRange::end);
+      .def(
+          "highlight",
+          [](const SourceRange& self) {
+            std::ostringstream stream;
+            self.highlight(stream);
+            return stream.str();
+          })
+      .def_property_readonly("start", &SourceRange::start)
+      .def_property_readonly("end", &SourceRange::end);
   py::class_<SourceRangeFactory>(m, "SourceRangeFactory")
-    .def(py::init<std::string&&>())
-    .def("make_range", &SourceRangeFactory::create)
-    .def("make_raw_range", [](const SourceRangeFactory& self, size_t start, size_t end) {
-      return SourceRange(self.source_, start, end);
-    })
-    .def_property_readonly("source", [](const SourceRangeFactory& self) {
-      return *self.source_;
-    });
+      .def(py::init<std::string&&>())
+      .def("make_range", &SourceRangeFactory::create)
+      .def(
+          "make_raw_range",
+          [](const SourceRangeFactory& self, size_t start, size_t end) {
+            return SourceRange(self.source_, start, end);
+          })
+      .def_property_readonly("source", [](const SourceRangeFactory& self) {
+        return *self.source_;
+      });
 
   py::class_<TreeView>(m, "TreeView")
-    .def("range", &TreeView::range)
-    .def("__str__", [](const TreeView& tree) {
-      std::ostringstream stream;
-      stream << tree.get();
-      return stream.str();
-    });
+      .def("range", &TreeView::range)
+      .def("__str__", [](const TreeView& tree) {
+        std::ostringstream stream;
+        stream << tree.get();
+        return stream.str();
+      });
 
   py::class_<Ident, TreeView>(m, "Ident")
       .def(py::init(&Ident::create))
@@ -84,13 +91,14 @@ void initTreeViewBindings(PyObject *module) {
           "name", [](const Ident& self) { return self.name(); });
 
   py::class_<Param, TreeView>(m, "Param")
-    .def(py::init([](const Expr& type, const Ident& name) {
-      return Param::create(name.range(), name, type, Maybe<Expr>::create(name.range()));
-    }));
+      .def(py::init([](const Expr& type, const Ident& name) {
+        return Param::create(
+            name.range(), name, type, Maybe<Expr>::create(name.range()));
+      }));
   py::class_<Attribute, TreeView>(m, "Attribute")
-    .def(py::init([](const Ident& name, const Expr& value) {
-      return Attribute::create(name.range(), name, value);
-    }));
+      .def(py::init([](const Ident& name, const Expr& value) {
+        return Attribute::create(name.range(), name, value);
+      }));
   m.def("TrueLiteral", [](const SourceRange& range) {
     return Expr(Compound::create(TK_TRUE, range, {}));
   });
@@ -103,60 +111,60 @@ void initTreeViewBindings(PyObject *module) {
 
   py::class_<Stmt, TreeView>(m, "Stmt"); // NOLINT(bugprone-unused-raii)
   py::class_<Expr, TreeView>(m, "Expr"); // NOLINT(bugprone-unused-raii)
-  py::class_<Def, TreeView>(m, "Def")
-    .def(py::init([](const Ident& name,
-                     Decl decl,
-                     std::vector<Stmt> body) {
-      const auto& r = name.range();
-      return Def::create(r,
-                         name,
-                         decl,
-                         wrap_list(r, std::move(body)));
-    }));
-  py::class_<Decl, TreeView>(m, "Decl")
-    .def(py::init([](const SourceRange& r,
-                     std::vector<Param> params,
-                     Expr *return_type) {
-      return Decl::create(r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type));
-    }));
-
+  py::class_<Def, TreeView>(m, "Def").def(
+      py::init([](const Ident& name, Decl decl, std::vector<Stmt> body) {
+        const auto& r = name.range();
+        return Def::create(r, name, decl, wrap_list(r, std::move(body)));
+      }));
+  py::class_<Decl, TreeView>(m, "Decl").def(py::init(
+      [](const SourceRange& r, std::vector<Param> params, Expr* return_type) {
+        return Decl::create(
+            r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type));
+      }));
 
   py::class_<Assign, Stmt>(m, "Assign")
-    .def(py::init([](const Expr& lhs, const Expr& rhs) {
-      return Assign::create(lhs.range(), lhs, rhs);
-    }));
+      .def(py::init([](const Expr& lhs, const Expr& rhs) {
+        return Assign::create(lhs.range(), lhs, rhs);
+      }));
   py::class_<AugAssign, Stmt>(m, "AugAssign")
-    .def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) {
-      const auto& r = lhs.range();
-      auto kind = AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
-      return AugAssign::create(r, lhs, kind, rhs);
-    }));
+      .def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) {
+        const auto& r = lhs.range();
+        auto kind =
+            AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
+        return AugAssign::create(r, lhs, kind, rhs);
+      }));
   py::class_<Return, Stmt>(m, "Return")
-    .def(py::init([](const SourceRange& range, Expr* value) {
-      return Return::create(range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
-    }));
+      .def(py::init([](const SourceRange& range, Expr* value) {
+        return Return::create(
+            range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
+      }));
   py::class_<Raise, Stmt>(m, "Raise")
-    .def(py::init([](const SourceRange& range, Expr *expr) {
-      return Raise::create(range, wrap_maybe(range, expr));
-    }));
+      .def(py::init([](const SourceRange& range, Expr* expr) {
+        return Raise::create(range, wrap_maybe(range, expr));
+      }));
   py::class_<Assert, Stmt>(m, "Assert")
-    .def(py::init([](const SourceRange& range, const Expr& test, Expr *msg) {
-      return Assert::create(range, test, wrap_maybe(range, msg));
-    }));
-  py::class_<Pass, Stmt>(m, "Pass")
-    .def(py::init([](const SourceRange& range) {
-      return Pass::create(range);
-    }));
-  py::class_<If, Stmt>(m, "If")
-    .def(py::init([](const SourceRange& range, const Expr& cond, std::vector<Stmt> true_branch, std::vector<Stmt> false_branch) {
-      return If::create(range, cond,
-                        wrap_list(range, std::move(true_branch)),
-                        wrap_list(range, std::move(false_branch)));
-    }));
+      .def(py::init([](const SourceRange& range, const Expr& test, Expr* msg) {
+        return Assert::create(range, test, wrap_maybe(range, msg));
+      }));
+  py::class_<Pass, Stmt>(m, "Pass").def(
+      py::init([](const SourceRange& range) { return Pass::create(range); }));
+  py::class_<If, Stmt>(m, "If").def(
+      py::init([](const SourceRange& range,
+                  const Expr& cond,
+                  std::vector<Stmt> true_branch,
+                  std::vector<Stmt> false_branch) {
+        return If::create(
+            range,
+            cond,
+            wrap_list(range, std::move(true_branch)),
+            wrap_list(range, std::move(false_branch)));
+      }));
   py::class_<While, Stmt>(m, "While")
-    .def(py::init([](const SourceRange& range, const Expr& cond, std::vector<Stmt> body) {
-      return While::create(range, cond, wrap_list(range, std::move(body)));
-    }));
+      .def(py::init([](const SourceRange& range,
+                       const Expr& cond,
+                       std::vector<Stmt> body) {
+        return While::create(range, cond, wrap_list(range, std::move(body)));
+      }));
   py::class_<For, Stmt>(m, "For").def(py::init([](const SourceRange range,
                                                   std::vector<Expr>& targets,
                                                   std::vector<Expr>& itrs,
@@ -167,70 +175,83 @@ void initTreeViewBindings(PyObject *module) {
         wrap_list(range, std::move(itrs)),
         wrap_list(range, std::move(body)));
   }));
-  py::class_<ExprStmt, Stmt>(m, "ExprStmt")
-    .def(py::init([](const Expr& expr) {
-      return ExprStmt::create(expr.range(), expr);
-    }));
+  py::class_<ExprStmt, Stmt>(m, "ExprStmt").def(py::init([](const Expr& expr) {
+    return ExprStmt::create(expr.range(), expr);
+  }));
 
   py::class_<Var, Expr>(m, "Var")
-    .def(py::init([](const Ident& name) {
-      return Var::create(name.range(), name);
-    }))
-    .def_property_readonly("name", [](const Var& var) { return var.name(); });
+      .def(py::init(
+          [](const Ident& name) { return Var::create(name.range(), name); }))
+      .def_property_readonly("name", [](const Var& var) { return var.name(); });
   py::class_<BinOp, Expr>(m, "BinOp")
-    .def(py::init([](std::string kind, const Expr& lhs, const Expr& rhs) {
-      return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs);
-    }));
-  // NB: we take range here, because unary ops precede their exprs, so we need to include them
+      .def(py::init([](std::string kind, const Expr& lhs, const Expr& rhs) {
+        return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs);
+      }));
+  // NB: we take range here, because unary ops precede their exprs, so we need
+  // to include them
   py::class_<UnaryOp, Expr>(m, "UnaryOp")
-    .def(py::init([](const SourceRange& range, std::string kind, const Expr& expr) {
-      auto resolved_kind = stringToKind(kind);
-      resolved_kind = resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind;
-      return UnaryOp::create(range, resolved_kind, expr);
-    }));
+      .def(py::init(
+          [](const SourceRange& range, std::string kind, const Expr& expr) {
+            auto resolved_kind = stringToKind(kind);
+            resolved_kind =
+                resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind;
+            return UnaryOp::create(range, resolved_kind, expr);
+          }));
   py::class_<Const, Expr>(m, "Const")
-    .def(py::init([](const SourceRange& range, std::string value) {
-      return Const::create(range, value);
-    }));
+      .def(py::init([](const SourceRange& range, std::string value) {
+        return Const::create(range, value);
+      }));
   py::class_<StringLiteral, Expr>(m, "StringLiteral")
-    .def(py::init([](const SourceRange& range, std::string value) {
-      return StringLiteral::create(range, value);
-    }));
+      .def(py::init([](const SourceRange& range, std::string value) {
+        return StringLiteral::create(range, value);
+      }));
   py::class_<Apply, Expr>(m, "Apply")
-    .def(py::init([](const Expr& expr, std::vector<Expr> args, std::vector<Attribute> kwargs) {
-      const auto& r = expr.range();
-      return Apply::create(expr.range(), expr,
-                           wrap_list(r, std::move(args)), wrap_list(r, std::move(kwargs)));
-    }));
+      .def(py::init([](const Expr& expr,
+                       std::vector<Expr> args,
+                       std::vector<Attribute> kwargs) {
+        const auto& r = expr.range();
+        return Apply::create(
+            expr.range(),
+            expr,
+            wrap_list(r, std::move(args)),
+            wrap_list(r, std::move(kwargs)));
+      }));
   py::class_<Select, Expr>(m, "Select")
-    .def(py::init([](const Expr& expr, const Ident& field) {
-      const auto& r = expr.range();
-      return Select::create(expr.range(), expr, field);
-    }));
+      .def(py::init([](const Expr& expr, const Ident& field) {
+        const auto& r = expr.range();
+        return Select::create(expr.range(), expr, field);
+      }));
   py::class_<TernaryIf, Expr>(m, "TernaryIf")
-    .def(py::init([](const Expr& cond, const Expr& true_expr, const Expr& false_expr) {
-      return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
-    }));
+      .def(py::init(
+          [](const Expr& cond, const Expr& true_expr, const Expr& false_expr) {
+            return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
+          }));
   py::class_<ListLiteral, Expr>(m, "ListLiteral")
-    .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
-      return ListLiteral::create(range, wrap_list(range, std::move(args)));
-    }));
+      .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
+        return ListLiteral::create(range, wrap_list(range, std::move(args)));
+      }));
   py::class_<TupleLiteral, Expr>(m, "TupleLiteral")
-    .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
-      return TupleLiteral::create(range, wrap_list(range, std::move(args)));
-    }));
+      .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
+        return TupleLiteral::create(range, wrap_list(range, std::move(args)));
+      }));
   py::class_<Subscript, Expr>(m, "Subscript")
-    .def(py::init([](const Expr& base, std::vector<Expr> subscript_exprs) {
-      return Subscript::create(base.range(), base, wrap_list(base.range(), std::move(subscript_exprs)));
-    }));
+      .def(py::init([](const Expr& base, std::vector<Expr> subscript_exprs) {
+        return Subscript::create(
+            base.range(),
+            base,
+            wrap_list(base.range(), std::move(subscript_exprs)));
+      }));
   py::class_<SliceExpr, Expr>(m, "SliceExpr")
-    .def(py::init([](const SourceRange& range, Expr *lower, Expr *upper) {
-      return SliceExpr::create(range, wrap_maybe(range, lower), wrap_maybe(range, upper));
-    }));
+      .def(py::init([](const SourceRange& range, Expr* lower, Expr* upper) {
+        return SliceExpr::create(
+            range, wrap_maybe(range, lower), wrap_maybe(range, upper));
+      }));
   py::class_<Starred, Expr>(m, "Starred")
-    .def(py::init([](const SourceRange& range, Expr expr){
-      return Starred::create(range, expr);
-    }));
+      .def(py::init([](const SourceRange& range, Expr expr) {
+        return Starred::create(range, expr);
+      }));
 }
 
-}}} // namespace torch::jit::script
+} // namespace script
+} // namespace jit
+} // namespace torch
index 5158729..9f078b9 100644 (file)
@@ -1,8 +1,11 @@
 #include <torch/csrc/python_headers.h>
 
-namespace torch { namespace jit { namespace script {
+namespace torch {
+namespace jit {
+namespace script {
 
-void initTreeViewBindings(PyObject *module);
-
-}}} // namespace torch::jit::script
+void initTreeViewBindings(PyObject* module);
 
+}
+} // namespace jit
+} // namespace torch
index 4ec2651..0595d27 100644 (file)
@@ -1,8 +1,8 @@
-#include <torch/csrc/jit/script/schema_matching.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/builtin_functions.h>
 #include <torch/csrc/jit/script/error_report.h>
+#include <torch/csrc/jit/script/schema_matching.h>
 
 namespace torch {
 namespace jit {
@@ -29,13 +29,13 @@ static inline bool isIntOrFloatUsedAsList(
 
 inline bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
   auto list_type = list_type_->cast<ListType>();
-  if(!list_type) {
+  if (!list_type) {
     return false;
   }
-  if(type->isSubtypeOf(list_type_)) {
+  if (type->isSubtypeOf(list_type_)) {
     return true;
   }
-  if(auto tuple = type->cast<TupleType>()) {
+  if (auto tuple = type->cast<TupleType>()) {
     return std::all_of(
         tuple->elements().begin(),
         tuple->elements().end(),
@@ -46,21 +46,21 @@ inline bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
   return false;
 }
 
-// applies implict conversion from value trying to turn it into type concrete_type
-// it succeeds if the return_value->isSubclassOf(concrete_type)
+// applies implict conversion from value trying to turn it into type
+// concrete_type it succeeds if the return_value->isSubclassOf(concrete_type)
 Value* tryConvertToType(
     const SourceRange& loc,
     Graph& graph,
     const TypePtr& concrete_type,
     Value* value,
     bool allow_conversions) {
-
   if (auto value_tuple = value->type()->cast<TupleType>()) {
     // Allow homogeneous tuples to be casted implicitly to lists of appropriate
     // types
     if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
       auto unpacked = createTupleUnpack(value);
-      auto elem_type = unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
+      auto elem_type =
+          unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
       value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
     }
     // inductively apply implicit conversions to tuples
@@ -82,27 +82,30 @@ Value* tryConvertToType(
     }
   }
 
-  if (value->type()->isSubtypeOf(NoneType::get()) && !concrete_type->isSubtypeOf(NoneType::get())){
+  if (value->type()->isSubtypeOf(NoneType::get()) &&
+      !concrete_type->isSubtypeOf(NoneType::get())) {
     if (concrete_type->isSubtypeOf(OptionalType::ofTensor())) {
       // create undefined tensor when None pass to a optional[tensor] formal arg
       value = graph.insertNode(graph.createUndefined())->output();
     } else if (auto optional_type = concrete_type->cast<OptionalType>()) {
-      value = graph.insertNode(graph.createNone(optional_type->getElementType()))->output();
+      value =
+          graph.insertNode(graph.createNone(optional_type->getElementType()))
+              ->output();
     }
   }
 
-  //implicit conversions
-  if(allow_conversions) {
-     if(concrete_type->isSubtypeOf(NumberType::get())
-      && value->type()->isSubtypeOf(DynamicType::get())) {
+  // implicit conversions
+  if (allow_conversions) {
+    if (concrete_type->isSubtypeOf(NumberType::get()) &&
+        value->type()->isSubtypeOf(DynamicType::get())) {
       auto n = graph.createImplicitTensorToNum(concrete_type, value);
       value = graph.insertNode(n)
-        ->setSourceLocation(std::make_shared<SourceRange>(loc))
-        ->output();
+                  ->setSourceLocation(std::make_shared<SourceRange>(loc))
+                  ->output();
     }
     if (value->type()->isSubtypeOf(StringType::get()) &&
-        DeviceObjType::get()->isSubtypeOf(concrete_type))  {
-      return graph.insert(aten::device, { value }, {}, loc);
+        DeviceObjType::get()->isSubtypeOf(concrete_type)) {
+      return graph.insert(aten::device, {value}, {}, loc);
     }
   }
 
@@ -116,7 +119,7 @@ Value* tryMatchArgument(
     const NamedValue& named_value,
     const std::function<std::ostream&()>& err,
     bool allow_conversions,
-    TypeEnv & type_env) {
+    TypeEnv& type_env) {
   Value* value = named_value.value(graph);
 
   // some functions that take lists of integers or floats for fixed size arrays
@@ -124,7 +127,8 @@ Value* tryMatchArgument(
   // the single int/float is then repeated to the length of the list
   if (isIntOrFloatUsedAsList(value, arg)) {
     std::vector<Value*> repeated(*arg.N(), value);
-    value = graph.insertNode(graph.createList(value->type(), repeated))->output();
+    value =
+        graph.insertNode(graph.createList(value->type(), repeated))->output();
   }
 
   const MatchTypeReturn matched_type =
@@ -140,8 +144,9 @@ Value* tryMatchArgument(
 
   value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions);
 
-  if(!value->type()->isSubtypeOf(concrete_type)) {
-    err() << "expected a value of type " << concrete_type->str() << " for argument '" << arg.name() << "' but found "
+  if (!value->type()->isSubtypeOf(concrete_type)) {
+    err() << "expected a value of type " << concrete_type->str()
+          << " for argument '" << arg.name() << "' but found "
           << value->type()->str() << "\n"
           << named_value.locOr(loc);
     return nullptr;
@@ -152,8 +157,8 @@ Value* tryMatchArgument(
 c10::optional<size_t> findInputWithName(
     const std::string& name,
     at::ArrayRef<NamedValue> kwargs) {
-  for(size_t i = 0; i < kwargs.size(); ++i) {
-    if(kwargs[i].name() == name)
+  for (size_t i = 0; i < kwargs.size(); ++i) {
+    if (kwargs[i].name() == name)
       return i;
   }
   return c10::nullopt;
@@ -166,12 +171,13 @@ Value* tryCreateList(
     at::ArrayRef<NamedValue> varargs,
     const std::function<std::ostream&()>& err,
     bool convert_tensor_to_num,
-    TypeEnv & type_env) {
+    TypeEnv& type_env) {
   Argument elem_arg("<varargs>", elem_type);
   std::vector<Value*> list_ctor;
-  for(const auto& a : varargs) {
-    Value* av = tryMatchArgument(elem_arg, graph, loc, a, err, convert_tensor_to_num, type_env);
-    if(!av)
+  for (const auto& a : varargs) {
+    Value* av = tryMatchArgument(
+        elem_arg, graph, loc, a, err, convert_tensor_to_num, type_env);
+    if (!av)
       return nullptr;
     list_ctor.push_back(av);
   }
@@ -206,9 +212,10 @@ c10::optional<MatchedSchema> tryMatchSchema(
       self = c10::nullopt;
     } else if (!arg.kwarg_only() && used_args < args.size()) {
       // allow zeros(IntList sizes) to work with zeros(1, 2) or zeros(1)
-      if (arg.type()->kind() == TypeKind::ListType && // the formal must be a list
-          !arg.N() && // it must not be a broadcasting list like int[3], otherwise
-                    // a single int is a valid input
+      if (arg.type()->kind() ==
+              TypeKind::ListType && // the formal must be a list
+          !arg.N() && // it must not be a broadcasting list like int[3],
+                      // otherwise a single int is a valid input
           (schema_i + 1 == schema.arguments().size() ||
            schema.arguments()[schema_i + 1]
                .kwarg_only())) { // must be the last position argument
@@ -216,8 +223,10 @@ c10::optional<MatchedSchema> tryMatchSchema(
         if (actual_type->kind() != TypeKind::ListType &&
             !convertibleToList(
                 actual_type,
-                unwrapOptional(arg.type()))) { // and the actual should not be a list already
-          auto elem_type = unwrapOptional(arg.type())->expect<ListType>()->getElementType();
+                unwrapOptional(arg.type()))) { // and the actual should not be a
+                                               // list already
+          auto elem_type =
+              unwrapOptional(arg.type())->expect<ListType>()->getElementType();
           Value* list = tryCreateList(
               elem_type,
               graph,
@@ -254,19 +263,19 @@ c10::optional<MatchedSchema> tryMatchSchema(
             << loc;
       return c10::nullopt;
     }
-    Value* positional = tryMatchArgument(
-        arg, graph, loc, *v, err, allow_conversions, type_env);
+    Value* positional =
+        tryMatchArgument(arg, graph, loc, *v, err, allow_conversions, type_env);
     if (!positional)
       return c10::nullopt;
     positional_inputs.push_back(positional);
   }
   // check for unused self argument
-  if(self != c10::nullopt) {
+  if (self != c10::nullopt) {
     err() << "provided self argument not used in schema\n";
   }
 
   if (schema.is_vararg()) {
-    for(;used_args < args.size(); ++used_args) {
+    for (; used_args < args.size(); ++used_args) {
       positional_inputs.push_back(args[used_args].value(graph));
     }
   }
@@ -296,11 +305,10 @@ c10::optional<MatchedSchema> tryMatchSchema(
   return MatchedSchema{std::move(positional_inputs), std::move(return_types)};
 }
 
-
-// pack outputs of a function following python rules. If there is a single value return
-// a SimpleValue, otherwise pack all the values into a Tuple.
+// pack outputs of a function following python rules. If there is a single value
+// return a SimpleValue, otherwise pack all the values into a Tuple.
 Value* packOutputs(Graph& g, at::ArrayRef<Value*> values) {
-  if(values.size() == 1) {
+  if (values.size() == 1) {
     return values[0];
   }
   return g.insertNode(g.createTuple(values))->output();
@@ -314,9 +322,9 @@ static Value* emitBuiltinNode(
     Graph& graph,
     Symbol name) {
   auto n = graph.insertNode(graph.create(name, matched_schema.inputs, 0))
-                ->setSourceLocation(std::make_shared<SourceRange>(loc));
+               ->setSourceLocation(std::make_shared<SourceRange>(loc));
 
-  for(auto & ret : matched_schema.return_types) {
+  for (auto& ret : matched_schema.return_types) {
     n->addOutput()->setType(ret);
   }
 
@@ -327,11 +335,13 @@ static Value* emitBuiltinNode(
   return packOutputs(graph, n->outputs());
 }
 
-static std::string prefixLine(const std::string& str, const std::string& prefix) {
+static std::string prefixLine(
+    const std::string& str,
+    const std::string& prefix) {
   std::stringstream ss;
   bool was_newline = true;
-  for(auto c : str) {
-    if(was_newline)
+  for (auto c : str) {
+    if (was_newline)
       ss << prefix;
     ss.put(c);
     was_newline = c == '\n';
@@ -342,23 +352,21 @@ static std::string prefixLine(const std::string& str, const std::string& prefix)
 // Search for operators matching the provided symbol name and input types.
 // If one is found, emit a node to the graph for that operator.
 Value* emitBuiltinCall(
-  const SourceRange& loc,
-  Graph& graph,
-  Symbol name,
-  const c10::optional<NamedValue>& self,
-  at::ArrayRef<NamedValue> inputs,
-  at::ArrayRef<NamedValue> attributes,
-  // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
-  // otherwise it will return nullptr if the builtin is not found.
-  bool required) {
-
-
+    const SourceRange& loc,
+    Graph& graph,
+    Symbol name,
+    const c10::optional<NamedValue>& self,
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    // if true, emitBuiltinCall will throw an exception if this builtin does not
+    // exist, otherwise it will return nullptr if the builtin is not found.
+    bool required) {
   const auto& variants = getAllOperatorsFor(name);
   const auto& builtin_functions = getAllBuiltinFunctionsFor(name);
 
   std::stringstream failure_messages;
-  //first we try to match the schema without any conversion
-  //if no schema matches then insert ImplicitTensorToNum
+  // first we try to match the schema without any conversion
+  // if no schema matches then insert ImplicitTensorToNum
   for (bool allow_conversions : {false, true}) {
     // clear previous error messages
     failure_messages.str("");
@@ -396,7 +404,7 @@ Value* emitBuiltinCall(
   if (!required) {
     return nullptr;
   }
-  if(variants.size() == 0) {
+  if (variants.size() == 0) {
     throw ErrorReport(loc) << "unknown builtin op";
   }
   throw ErrorReport(loc) << "arguments for call are not valid:\n"
@@ -404,7 +412,6 @@ Value* emitBuiltinCall(
                          << "for call at";
 }
 
-
 } // namespace script
 } // namespace jit
 } // namespace torch
index 506a474..937bdc2 100644 (file)
@@ -1,18 +1,18 @@
 #pragma once
-#include <torch/csrc/jit/type.h>
-#include <torch/csrc/jit/named_value.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/function_schema.h>
+#include <torch/csrc/jit/named_value.h>
+#include <torch/csrc/jit/type.h>
 
 namespace torch {
 namespace jit {
 namespace script {
 
-  // try to match a list if inputs and keyword 'attributes' to this schema,
-  // if it works return the flat list of positional inputs to the call
-  // if it returns nullopt, then failure_messages contains a good error report
-  // set convert_tensor_to_num to true if ImplicitTensorToNums should be inserted to
-  // match the schema
+// try to match a list if inputs and keyword 'attributes' to this schema,
+// if it works return the flat list of positional inputs to the call
+// if it returns nullopt, then failure_messages contains a good error report
+// set convert_tensor_to_num to true if ImplicitTensorToNums should be inserted
+// to match the schema
 
 struct MatchedSchema {
   std::vector<Value*> inputs;
@@ -20,32 +20,32 @@ struct MatchedSchema {
 };
 
 TORCH_API c10::optional<MatchedSchema> tryMatchSchema(
-  const FunctionSchema& schema,
-  const SourceRange& loc,
-  Graph& graph,
-  c10::optional<NamedValue> self,
-  at::ArrayRef<NamedValue> inputs,
-  at::ArrayRef<NamedValue> attributes,
-  std::ostream& failure_messages,
-  bool allow_conversions);
+    const FunctionSchema& schema,
+    const SourceRange& loc,
+    Graph& graph,
+    c10::optional<NamedValue> self,
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    std::ostream& failure_messages,
+    bool allow_conversions);
 
 TORCH_API Value* emitBuiltinCall(
-  const SourceRange& loc,
-  Graph& graph,
-  Symbol name,
-  const c10::optional<NamedValue>& self,
-  at::ArrayRef<NamedValue> inputs,
-  at::ArrayRef<NamedValue> attributes,
-  // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
-  // otherwise it will return nullptr if the builtin is not found.
-  bool required);
+    const SourceRange& loc,
+    Graph& graph,
+    Symbol name,
+    const c10::optional<NamedValue>& self,
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    // if true, emitBuiltinCall will throw an exception if this builtin does not
+    // exist, otherwise it will return nullptr if the builtin is not found.
+    bool required);
 
 TORCH_API c10::optional<size_t> findInputWithName(
-  const std::string& name,
-  at::ArrayRef<NamedValue> kwargs);
+    const std::string& name,
+    at::ArrayRef<NamedValue> kwargs);
 
-// applies implict conversion from value trying to turn it into type concrete_type
-// it succeeds if the return_value->isSubclassOf(concrete_type)
+// applies implict conversion from value trying to turn it into type
+// concrete_type it succeeds if the return_value->isSubclassOf(concrete_type)
 TORCH_API Value* tryConvertToType(
     const SourceRange& loc,
     Graph& graph,
@@ -53,6 +53,6 @@ TORCH_API Value* tryConvertToType(
     Value* value,
     bool allow_conversions);
 
-}
+} // namespace script
 } // namespace jit
 } // namespace torch
index 418685b..605344b 100644 (file)
@@ -1,7 +1,7 @@
-#include <torch/csrc/jit/script/type_parser.h>
 #include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/script/tree_views.h>
 #include <torch/csrc/jit/script/sugared_value.h>
+#include <torch/csrc/jit/script/tree_views.h>
+#include <torch/csrc/jit/script/type_parser.h>
 
 namespace torch {
 namespace jit {
@@ -15,48 +15,53 @@ struct NoneValue : SugaredValue {
 };
 
 std::shared_ptr<SugaredValue> PrintValue::call(
-  const SourceRange& loc,
-  Method & m,
-  at::ArrayRef<NamedValue> inputs,
-  at::ArrayRef<NamedValue> attributes,
-  size_t n_binders) {
-    auto& g = *m.graph();
-    if (!attributes.empty())
-      throw ErrorReport(loc) << "print doesn't accept any keyword arguments";
+    const SourceRange& loc,
+    Method& m,
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    size_t n_binders) {
+  auto& g = *m.graph();
+  if (!attributes.empty())
+    throw ErrorReport(loc) << "print doesn't accept any keyword arguments";
 
-    //temporary hack to allow print statements to work in python 2, where
-    //print(a, b) is treated as a (a, b) tuple input.
+  // temporary hack to allow print statements to work in python 2, where
+  // print(a, b) is treated as a (a, b) tuple input.
 
-    std::vector<Value*> lowered_inputs = toValues(*m.graph(), inputs);
-    if(lowered_inputs.size() == 1 && lowered_inputs.at(0)->node()->kind() == prim::TupleConstruct) {
-      auto input = lowered_inputs[0];
-      for(size_t j = 0; j < input->node()->inputs().size(); ++j) {
-        lowered_inputs.insert(lowered_inputs.begin() + 1 + j, input->node()->inputs().at(j));
-      }
-      lowered_inputs.erase(lowered_inputs.begin());
+  std::vector<Value*> lowered_inputs = toValues(*m.graph(), inputs);
+  if (lowered_inputs.size() == 1 &&
+      lowered_inputs.at(0)->node()->kind() == prim::TupleConstruct) {
+    auto input = lowered_inputs[0];
+    for (size_t j = 0; j < input->node()->inputs().size(); ++j) {
+      lowered_inputs.insert(
+          lowered_inputs.begin() + 1 + j, input->node()->inputs().at(j));
     }
-    g.insertNode(g.create(prim::Print, lowered_inputs, 0)
-                     ->setSourceLocation(std::make_shared<SourceRange>(loc)));
-    return std::make_shared<NoneValue>();
+    lowered_inputs.erase(lowered_inputs.begin());
+  }
+  g.insertNode(g.create(prim::Print, lowered_inputs, 0)
+                   ->setSourceLocation(std::make_shared<SourceRange>(loc)));
+  return std::make_shared<NoneValue>();
 }
 
-static const std::unordered_map<std::string, std::string> &builtin_cast_methods() {
+static const std::unordered_map<std::string, std::string>&
+builtin_cast_methods() {
   static std::unordered_map<std::string, std::string> builtin_cast_methods = {
-    {"byte", "_cast_Byte"},
-    {"char", "_cast_Char"},
-    {"double", "_cast_Double"},
-    {"float", "_cast_Float"},
-    {"int", "_cast_Int"},
-    {"long", "_cast_Long"},
-    {"short", "_cast_Short"},
-    {"half", "_cast_Half"}
-  };
+      {"byte", "_cast_Byte"},
+      {"char", "_cast_Char"},
+      {"double", "_cast_Double"},
+      {"float", "_cast_Float"},
+      {"int", "_cast_Int"},
+      {"long", "_cast_Long"},
+      {"short", "_cast_Short"},
+      {"half", "_cast_Half"}};
   return builtin_cast_methods;
 }
 
 // support syntax sugar for x.foo(y, z) by allowing x.foo to return a
 // callable value that will resolve to foo(x, y, z) when called.
-std::shared_ptr<SugaredValue> SimpleValue::attr(const SourceRange& loc, Method & m, const std::string& field) {
+std::shared_ptr<SugaredValue> SimpleValue::attr(
+    const SourceRange& loc,
+    Method& m,
+    const std::string& field) {
   // Allow method-style casts on Tensor types. e.g. x.int()
   if (value->type()->isSubtypeOf(DynamicType::get())) {
     if (builtin_cast_methods().count(field)) {
@@ -67,14 +72,15 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(const SourceRange& loc, Method &
     // functions that are just direct property lookups on tensor
     // must be registered as prim::<name>(Tensor t) -> <return_type>
     static const std::unordered_set<std::string> fields = {
-      "dtype",
-      "device",
-      "shape",
-      "is_cuda",
-      "requires_grad",
+        "dtype",
+        "device",
+        "shape",
+        "is_cuda",
+        "requires_grad",
     };
     if (fields.count(field)) {
-      auto r = m.graph()->insert(Symbol::fromQualString("prim::"+field), {value});
+      auto r =
+          m.graph()->insert(Symbol::fromQualString("prim::" + field), {value});
       return std::make_shared<SimpleValue>(r);
     }
   }
@@ -89,21 +95,25 @@ std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
     const SourceRange& loc,
     Method& m,
     const c10::optional<size_t>& size_hint) {
-  static const auto make_simple_value = [](Value* v) -> std::shared_ptr<SugaredValue> {
+  static const auto make_simple_value =
+      [](Value* v) -> std::shared_ptr<SugaredValue> {
     return std::make_shared<SimpleValue>(v);
   };
-  if(value->type()->kind() == TypeKind::TupleType) {
+  if (value->type()->kind() == TypeKind::TupleType) {
     auto outputs = createTupleUnpack(value);
     return fmap(outputs, make_simple_value);
   } else if (value->type()->kind() == TypeKind::ListType) {
     if (!size_hint) {
-      throw ErrorReport(loc) << "cannot statically infer the expected size of a list in this context";
+      throw ErrorReport(loc)
+          << "cannot statically infer the expected size of a list in this context";
     }
     auto graph = value->owningGraph();
-    Node *unpack = graph->insertNode(graph->createListUnpack(value, *size_hint));
+    Node* unpack =
+        graph->insertNode(graph->createListUnpack(value, *size_hint));
     return fmap(unpack->outputs(), make_simple_value);
   }
-  throw ErrorReport(loc) << value->type()->str() << " cannot be used as a tuple";
+  throw ErrorReport(loc) << value->type()->str()
+                         << " cannot be used as a tuple";
 }
 
 } // namespace script
index 916792e..5f998b1 100644 (file)
@@ -5,8 +5,8 @@
 
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/script/error_report.h>
-#include <torch/csrc/jit/script/tree_views.h>
 #include <torch/csrc/jit/script/module.h>
+#include <torch/csrc/jit/script/tree_views.h>
 
 namespace torch {
 namespace jit {
@@ -20,11 +20,7 @@ namespace script {
 // that separates their behavior from the AST -> IR converter itself.
 // This allows us to keep dependencies on python minimal.
 
-enum NoneStatus {
- ALWAYS,
- MAYBE,
- NEVER
-};
+enum NoneStatus { ALWAYS, MAYBE, NEVER };
 
 struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
   // what is this node? for error reporting (e.g. Module, python function)
@@ -32,12 +28,15 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
 
   // what can we do with this thing?
   // use it as a value e.g.  `this + 4`
-  virtual Value * asValue(const SourceRange& loc, Method & m) {
+  virtual Value* asValue(const SourceRange& loc, Method& m) {
     throw ErrorReport(loc) << kind() << " cannot be used as a value";
   }
 
   // select an attribute on it, e.g. `this.field`
-  virtual std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) {
+  virtual std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      Method& m,
+      const std::string& field) {
     throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
   }
   virtual NoneStatus isNone() {
@@ -55,26 +54,25 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
 
   // call it like a function, e.g. `outputs = this(inputs)`
   virtual std::shared_ptr<SugaredValue> call(
-    const SourceRange& loc,
-    Method & m,
-    // note: names for args will be 'argument 0', 'argument 1', etc..
-    at::ArrayRef<NamedValue> inputs_,
-    at::ArrayRef<NamedValue> attributes,
-    size_t n_binders) {
-// n_binders is always set to the number of variables an expression is
-// syntactically bound to:
-//     a = foo() # 1 binder (note in this case the single binder might be a tuple)
-//     a, * b = foo() # 1 binder
-//     a, b = foo() # 2 binders
-//     foo() # 0 binders
-//
-// In subexpressions, like bar() in foo(bar()), n_binders is always set to
-// 1. n_binders is used as a hint to subexpressions to determine how many
-// values they should return when that number is ambiguous statically. In
-// particular it is currently used to decide how many tensors a call to a
-// python function will return. It is only a hint, functions do not have to
-// check that n_binders match the number of things they are returning, the
-// assignment logic will do that anyway.
+      const SourceRange& loc,
+      Method& m,
+      // note: names for args will be 'argument 0', 'argument 1', etc..
+      at::ArrayRef<NamedValue> inputs_,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) {
+    // n_binders is always set to the number of variables an expression is
+    // syntactically bound to:
+    //     a = foo() # 1 binder (note in this case the single binder might be a
+    //     tuple) a, * b = foo() # 1 binder a, b = foo() # 2 binders foo() # 0
+    //     binders
+    //
+    // In subexpressions, like bar() in foo(bar()), n_binders is always set to
+    // 1. n_binders is used as a hint to subexpressions to determine how many
+    // values they should return when that number is ambiguous statically. In
+    // particular it is currently used to decide how many tensors a call to a
+    // python function will return. It is only a hint, functions do not have to
+    // check that n_binders match the number of things they are returning, the
+    // assignment logic will do that anyway.
 
     throw ErrorReport(loc) << "cannot call a " << kind();
   }
@@ -85,12 +83,11 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
 // most things in the environment are just simple value types
 // and not special python syntax sugar types
 struct TORCH_API SimpleValue : public SugaredValue {
-  SimpleValue(Value * value)
-  : value(value) {}
+  SimpleValue(Value* value) : value(value) {}
   std::string kind() const override {
     return "value";
   }
-  Value * asValue(const SourceRange& range, Method & m) override {
+  Value* asValue(const SourceRange& range, Method& m) override {
     return value;
   }
   NoneStatus isNone() override {
@@ -105,11 +102,15 @@ struct TORCH_API SimpleValue : public SugaredValue {
       const SourceRange& loc,
       Method& m,
       const c10::optional<size_t>& size_hint = {}) override;
-  std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) override;
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      Method& m,
+      const std::string& field) override;
   Value* getValue() const {
     return value;
   }
-private:
+
+ private:
   Value* value;
 };
 
@@ -135,19 +136,21 @@ struct TORCH_API BuiltinFunction : public SugaredValue {
 };
 
 struct TORCH_API BuiltinModule : public SugaredValue {
-  BuiltinModule(std::string name,
-                c10::optional<int64_t> version = at::nullopt)
-    : name(std::move(name))
-    , version(std::move(version)) {}
+  BuiltinModule(std::string name, c10::optional<int64_t> version = at::nullopt)
+      : name(std::move(name)), version(std::move(version)) {}
 
   std::string kind() const override {
     return "builtin module";
   }
-  std::shared_ptr<SugaredValue> attr(const SourceRange& loc, Method & m, const std::string& field) override {
-    return std::make_shared<BuiltinFunction>(Symbol::fromQualString(name+"::"+field), c10::nullopt);
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      Method& m,
+      const std::string& field) override {
+    return std::make_shared<BuiltinFunction>(
+        Symbol::fromQualString(name + "::" + field), c10::nullopt);
   }
 
-private:
+ private:
   std::string name;
   // when we add operator versioning, emit this op as it exising at 'version'
   // if not set, use the latest version
@@ -157,8 +160,9 @@ private:
 // defines how a method obtained from a module behaves in script
 struct MethodValue : public SugaredValue {
   MethodValue(std::shared_ptr<Module> module, Method& method)
-  : module(std::move(module)) //insurance that method stays alive
-  , method(method) {}
+      : module(std::move(module)) // insurance that method stays alive
+        ,
+        method(method) {}
   std::string kind() const override {
     return "method";
   }
@@ -168,13 +172,13 @@ struct MethodValue : public SugaredValue {
       at::ArrayRef<NamedValue> inputs,
       at::ArrayRef<NamedValue> attributes,
       size_t n_binders) override {
-    return std::make_shared<SimpleValue>(caller.emit_call_to(loc, method, inputs, attributes));
+    return std::make_shared<SimpleValue>(
+        caller.emit_call_to(loc, method, inputs, attributes));
   }
 
  private:
   std::shared_ptr<Module> module;
   Method& method;
-
 };
 
 struct TORCH_API PrintValue : public SugaredValue {
@@ -182,11 +186,11 @@ struct TORCH_API PrintValue : public SugaredValue {
     return "print";
   }
   std::shared_ptr<SugaredValue> call(
-    const SourceRange& loc,
-    Method & m,
-    at::ArrayRef<NamedValue> inputs,
-    at::ArrayRef<NamedValue> attributes,
-    size_t n_binders) override;
+      const SourceRange& loc,
+      Method& m,
+      at::ArrayRef<NamedValue> inputs,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) override;
 };
 
 // expressions like int(x)
@@ -194,27 +198,26 @@ struct TORCH_API PrintValue : public SugaredValue {
 // is a noop when the input is a subtype of 'type'
 struct TORCH_API CastValue : public BuiltinFunction {
   CastValue(TypePtr type, c10::Symbol method)
-  : BuiltinFunction(method, c10::nullopt)
-  , type_(std::move(type)) {}
+      : BuiltinFunction(method, c10::nullopt), type_(std::move(type)) {}
   std::shared_ptr<SugaredValue> call(
-    const SourceRange& loc,
-    Method & m,
-    at::ArrayRef<NamedValue> inputs,
-    at::ArrayRef<NamedValue> attributes,
-    size_t n_binders) override {
-      if(inputs.size() == 1 && attributes.size() == 0) {
-        auto v = inputs[0].value(*m.graph());
-        if (v->type()->isSubtypeOf(type_)) {
-          return std::make_shared<SimpleValue>(v);
-        }
+      const SourceRange& loc,
+      Method& m,
+      at::ArrayRef<NamedValue> inputs,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) override {
+    if (inputs.size() == 1 && attributes.size() == 0) {
+      auto v = inputs[0].value(*m.graph());
+      if (v->type()->isSubtypeOf(type_)) {
+        return std::make_shared<SimpleValue>(v);
       }
-      return BuiltinFunction::call(loc, m , inputs, attributes, n_binders);
+    }
+    return BuiltinFunction::call(loc, m, inputs, attributes, n_binders);
   }
-private:
+
+ private:
   TypePtr type_;
 };
 
-
 // These SugaredValues have special handling in the compiler because they
 // change the normal evalution order of the expression they participate in.
 // They are exposed here so that the python frontend can inject them
@@ -249,12 +252,12 @@ struct TORCH_API IsInstanceValue : SugaredValue {
   }
 };
 
-static inline std::vector<Value*> toValues(Graph& g, at::ArrayRef<NamedValue> nvs) {
-  return fmap(nvs, [&](const NamedValue& v) {
-    return v.value(g);
-  });
+static inline std::vector<Value*> toValues(
+    Graph& g,
+    at::ArrayRef<NamedValue> nvs) {
+  return fmap(nvs, [&](const NamedValue& v) { return v.value(g); });
 }
 
-}
+} // namespace script
 } // namespace jit
 } // namespace torch
index 3b6c5d1..f12b11c 100644 (file)
@@ -1,8 +1,8 @@
 #pragma once
 
+#include <functional>
 #include <memory>
 #include <vector>
-#include <functional>
 
 #include <torch/csrc/jit/script/lexer.h>
 
@@ -72,8 +72,12 @@ struct Tree : std::enable_shared_from_this<Tree> {
   void matchNumSubtrees(int k, size_t expected_subtrees) {
     return matchNumSubtreesD(k, "unknown", 0, expected_subtrees, false);
   }
-  void matchNumSubtreesD(int k, const char* filename, int lineno,
-                         size_t expected_subtrees, bool allow_more) {
+  void matchNumSubtreesD(
+      int k,
+      const char* filename,
+      int lineno,
+      size_t expected_subtrees,
+      bool allow_more) {
     if (kind() != k) {
       std::stringstream ss;
       ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
@@ -84,8 +88,9 @@ struct Tree : std::enable_shared_from_this<Tree> {
     if (trees().size() < expected_subtrees ||
         (!allow_more && trees().size() != expected_subtrees)) {
       std::stringstream ss;
-      ss << filename << ":" << lineno << ": expected at least " << expected_subtrees
-         << " subtrees, but found only " << trees().size() << "\n";
+      ss << filename << ":" << lineno << ": expected at least "
+         << expected_subtrees << " subtrees, but found only " << trees().size()
+         << "\n";
       range().highlight(ss);
       throw std::runtime_error(ss.str());
     }
@@ -122,7 +127,8 @@ static SourceRange mergeRanges(SourceRange c, const TreeList& others) {
 }
 
 struct Compound : public Tree {
-  Compound(int kind, SourceRange range) : Tree(kind), range_(std::move(range)) {}
+  Compound(int kind, SourceRange range)
+      : Tree(kind), range_(std::move(range)) {}
   Compound(int kind, const SourceRange& range_, TreeList&& trees_)
       : Tree(kind),
         range_(mergeRanges(range_, trees_)),
@@ -130,8 +136,10 @@ struct Compound : public Tree {
   const TreeList& trees() const override {
     return trees_;
   }
-  static TreeRef
-  create(int kind, const SourceRange& range_, TreeList&& trees_) {
+  static TreeRef create(
+      int kind,
+      const SourceRange& range_,
+      TreeList&& trees_) {
     return std::make_shared<Compound>(kind, range_, std::move(trees_));
   }
   bool isAtom() const override {
index 1640f48..31a997a 100644 (file)
@@ -9,6 +9,7 @@ namespace torch {
 namespace jit {
 namespace script {
 
+// clang-format off
 // TreeView provides a statically-typed way to traverse the tree, which should
 // be formed according to the grammar below.
 //
@@ -28,7 +29,7 @@ namespace script {
 //       | Global(List<Ident> idents)                                   TK_GLOBAL
 //       -- NB: the only type of Expr's allowed on lhs are Var
 //          Or a tuple containing Var with an optional terminating Starred
-//       | Assign(Expr lhs, Expr rhs)                                  TK_ASSIGN
+//       | Assign(Expr lhs, Expr rhs)                                   TK_ASSIGN
 //       | AugAssign(Expr lhs, AugAssignKind aug_op, Expr rhs)          TK_AUG_ASSIGN
 //       | Return(List<Expr> values)                                    TK_RETURN
 //       | ExprStmt(List<Expr> expr)                                    TK_EXPR_STMT
@@ -89,6 +90,7 @@ namespace script {
 //    changes to the structure of Ident are always made right here rather
 //    than both in the parser and in this code.
 // XXX: these structs should have no fields to prevent slicing when passing by value
+// clang-format on
 struct TreeView {
   explicit TreeView(TreeRef tree) : tree_(std::move(tree)) {}
   TreeRef tree() const {
@@ -107,24 +109,39 @@ struct TreeView {
     return tree_->kind();
   }
 
-protected:
+ protected:
   const TreeRef& subtree(size_t i) const {
     return tree_->trees().at(i);
   }
   TreeRef tree_;
 };
 
-template<typename T>
+template <typename T>
 struct ListIterator {
   ListIterator(TreeList::const_iterator it) : it(it) {}
-  bool operator!=(const ListIterator& rhs) const { return it != rhs.it; }
-  bool operator==(const ListIterator& rhs) const { return it == rhs.it; }
-  T operator*() const { return T(*it); }
-  ListIterator& operator+=(std::ptrdiff_t n) { it += n; return *this; }
-  ListIterator& operator++() { ++it; return *this; }
-  ListIterator& operator--() { --it; return *this; }
-
-private:
+  bool operator!=(const ListIterator& rhs) const {
+    return it != rhs.it;
+  }
+  bool operator==(const ListIterator& rhs) const {
+    return it == rhs.it;
+  }
+  T operator*() const {
+    return T(*it);
+  }
+  ListIterator& operator+=(std::ptrdiff_t n) {
+    it += n;
+    return *this;
+  }
+  ListIterator& operator++() {
+    ++it;
+    return *this;
+  }
+  ListIterator& operator--() {
+    --it;
+    return *this;
+  }
+
+ private:
   TreeList::const_iterator it;
 };
 
@@ -137,7 +154,7 @@ struct List : public TreeView {
     tree->match(TK_LIST);
     // Iterate over list to temporarily instantiate Ts that will check the type
     for (const T& elem : *this) {
-      (void) elem; //silence unused warning
+      (void)elem; // silence unused warning
     }
   }
   iterator begin() const {
@@ -156,7 +173,7 @@ struct List : public TreeView {
     return tree_->map([&](TreeRef v) { return fn(T(v)); });
   }
   static List create(const SourceRange& range, const std::vector<T>& subtrees) {
-    TreeList type_erased_sub {subtrees.begin(), subtrees.end()};
+    TreeList type_erased_sub{subtrees.begin(), subtrees.end()};
     return List(Compound::create(TK_LIST, range, std::move(type_erased_sub)));
   }
   static List unsafeCreate(const SourceRange& range, TreeList&& subtrees) {
@@ -225,7 +242,8 @@ struct Stmt : public TreeView {
       case TK_DEF:
         return;
       default:
-        throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Stmt";
+        throw ErrorReport(tree)
+            << kindToString(tree->kind()) << " is not a valid Stmt";
     }
   }
 };
@@ -273,7 +291,8 @@ struct Expr : public TreeView {
       case '|':
         return;
       default:
-        throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid Expr";
+        throw ErrorReport(tree)
+            << kindToString(tree->kind()) << " is not a valid Expr";
     }
   }
 };
@@ -292,17 +311,23 @@ struct Attribute : public TreeView {
   Expr value() const {
     return Expr(subtree(1));
   }
-  static Attribute create(const SourceRange& range, const Ident& name, const TreeRef& value) {
+  static Attribute create(
+      const SourceRange& range,
+      const Ident& name,
+      const TreeRef& value) {
     return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
   }
 };
 
-
 struct Param : public TreeView {
   explicit Param(const TreeRef& tree) : TreeView(tree) {
     tree_->match(TK_PARAM);
   }
-  static Param create(const SourceRange& range, const Ident& ident, const Expr& type, const Maybe<Expr>& def) {
+  static Param create(
+      const SourceRange& range,
+      const Ident& ident,
+      const Expr& type,
+      const Maybe<Expr>& def) {
     return Param(Compound::create(TK_PARAM, range, {ident, type, def}));
   }
   Ident ident() const {
@@ -333,7 +358,10 @@ struct Decl : public TreeView {
   Maybe<Expr> return_type() const {
     return Maybe<Expr>(subtree(1));
   }
-  static Decl create(const SourceRange& range, const List<Param>& params, const Maybe<Expr>& return_type) {
+  static Decl create(
+      const SourceRange& range,
+      const List<Param>& params,
+      const Maybe<Expr>& return_type) {
     return Decl(Compound::create(TK_DECL, range, {params, return_type}));
   }
 };
@@ -360,12 +388,10 @@ struct Def : public TreeView {
       const Ident& name,
       const Decl& decl,
       const List<Stmt>& stmts) {
-    return Def(Compound::create(
-        TK_DEF, range, {name, decl, stmts}));
+    return Def(Compound::create(TK_DEF, range, {name, decl, stmts}));
   }
 };
 
-
 ////////////////////////////////////////////////////////////////////////////////
 // Statements
 ////////////////////////////////////////////////////////////////////////////////
@@ -383,7 +409,9 @@ struct If : public Stmt {
   List<Stmt> falseBranch() const {
     return List<Stmt>(subtree(2));
   }
-  If withNewBranches(const List<Stmt>& true_branch, const List<Stmt>& false_branch) const {
+  If withNewBranches(
+      const List<Stmt>& true_branch,
+      const List<Stmt>& false_branch) const {
     return create(range(), cond(), true_branch, false_branch);
   }
   static If create(
@@ -391,7 +419,8 @@ struct If : public Stmt {
       const Expr& cond,
       const List<Stmt>& true_branch,
       const List<Stmt>& false_branch) {
-    return If(Compound::create(TK_IF, range, {cond, true_branch, false_branch}));
+    return If(
+        Compound::create(TK_IF, range, {cond, true_branch, false_branch}));
   }
 };
 
@@ -405,7 +434,10 @@ struct While : public Stmt {
   List<Stmt> body() const {
     return List<Stmt>(subtree(1));
   }
-  static While create(const SourceRange& range, const Expr& cond, const List<Stmt>& body) {
+  static While create(
+      const SourceRange& range,
+      const Expr& cond,
+      const List<Stmt>& body) {
     return While(Compound::create(TK_WHILE, range, {cond, body}));
   }
 };
@@ -482,7 +514,6 @@ struct AugAssign : public Stmt {
   }
 };
 
-
 struct Assign : public Stmt {
   explicit Assign(const TreeRef& tree) : Stmt(tree) {
     tree_->match(TK_ASSIGN);
@@ -547,13 +578,11 @@ struct Pass : public Stmt {
   explicit Pass(const TreeRef& tree) : Stmt(tree) {
     tree_->match(TK_PASS);
   }
-  static Pass create(
-      const SourceRange& range) {
+  static Pass create(const SourceRange& range) {
     return Pass(Compound::create(TK_PASS, range, {}));
   }
 };
 
-
 struct ExprStmt : public Stmt {
   explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
     tree_->match(TK_EXPR_STMT);
@@ -566,7 +595,6 @@ struct ExprStmt : public Stmt {
   }
 };
 
-
 ////////////////////////////////////////////////////////////////////////////////
 // Expressions
 ////////////////////////////////////////////////////////////////////////////////
@@ -596,10 +624,12 @@ struct BinOp : public Expr {
       case '|':
       case TK_FLOOR_DIV:
         if (tree->trees().size() != 2)
-          throw ErrorReport(tree) << "BinOp expected 2 subtrees, found " << tree->trees().size();
+          throw ErrorReport(tree)
+              << "BinOp expected 2 subtrees, found " << tree->trees().size();
         return;
       default:
-        throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid BinOp";
+        throw ErrorReport(tree)
+            << kindToString(tree->kind()) << " is not a valid BinOp";
     }
   }
   Expr lhs() const {
@@ -608,7 +638,11 @@ struct BinOp : public Expr {
   Expr rhs() const {
     return Expr(subtree(1));
   }
-  static BinOp create(const SourceRange& range, int kind, const Expr& lhs, const Expr& rhs) {
+  static BinOp create(
+      const SourceRange& range,
+      int kind,
+      const Expr& lhs,
+      const Expr& rhs) {
     return BinOp(Compound::create(kind, range, {lhs, rhs}));
   }
 };
@@ -619,10 +653,12 @@ struct UnaryOp : public Expr {
       case TK_UNARY_MINUS:
       case TK_NOT:
         if (tree->trees().size() != 1)
-          throw ErrorReport(tree) << "UnaryOp expected 1 subtree, found " << tree->trees().size();
+          throw ErrorReport(tree)
+              << "UnaryOp expected 1 subtree, found " << tree->trees().size();
         return;
       default:
-        throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid UnaryOp";
+        throw ErrorReport(tree)
+            << kindToString(tree->kind()) << " is not a valid UnaryOp";
     }
   }
   static UnaryOp create(const SourceRange& range, int kind, const Expr& expr) {
@@ -644,7 +680,8 @@ struct Const : public Expr {
     return std::stoll(subtree(0)->stringValue());
   }
   double asFloatingPoint() const {
-    return SharedParserData::strtod_c(subtree(0)->stringValue().c_str(), nullptr);
+    return SharedParserData::strtod_c(
+        subtree(0)->stringValue().c_str(), nullptr);
   }
   const std::string& text() const {
     return subtree(0)->stringValue();
@@ -661,8 +698,11 @@ struct StringLiteral : public Expr {
   const std::string& text() const {
     return subtree(0)->stringValue();
   }
-  static StringLiteral create(const SourceRange& range, const std::string& value) {
-    return StringLiteral(Compound::create(TK_STRINGLITERAL, range, {String::create(value)}));
+  static StringLiteral create(
+      const SourceRange& range,
+      const std::string& value) {
+    return StringLiteral(
+        Compound::create(TK_STRINGLITERAL, range, {String::create(value)}));
   }
 };
 
@@ -684,7 +724,8 @@ struct Apply : public Expr {
       const Expr& callee,
       const List<Expr>& inputs,
       const List<Attribute>& attributes) {
-    return Apply(Compound::create(TK_APPLY, range, {callee, inputs, attributes}));
+    return Apply(
+        Compound::create(TK_APPLY, range, {callee, inputs, attributes}));
   }
 };
 
@@ -698,7 +739,10 @@ struct Select : public Expr {
   Ident selector() const {
     return Ident(subtree(1));
   }
-  static Select create(const SourceRange& range, const Expr& value, const Ident& selector) {
+  static Select create(
+      const SourceRange& range,
+      const Expr& value,
+      const Ident& selector) {
     return Select(Compound::create('.', range, {value, selector}));
   }
 };
@@ -727,7 +771,8 @@ struct SliceExpr : public Expr {
       const Maybe<Expr>& end) {
     return SliceExpr(Compound::create(TK_SLICE_EXPR, range, {start, end}));
   }
-private:
+
+ private:
   Expr createInt(int value) const {
     return Expr(Const::create(range(), std::to_string(value)));
   }
@@ -747,7 +792,8 @@ struct Subscript : public Expr {
       const SourceRange& range,
       const Expr& value,
       const List<Expr>& subscript_exprs) {
-    return Subscript(Compound::create(TK_SUBSCRIPT, range, {value, subscript_exprs}));
+    return Subscript(
+        Compound::create(TK_SUBSCRIPT, range, {value, subscript_exprs}));
   }
 };
 
@@ -776,15 +822,16 @@ struct TernaryIf : public Expr {
   Expr false_expr() const {
     return Expr(subtree(2));
   }
-  static TernaryIf create(const SourceRange& range,
-                          const Expr& cond,
-                          const Expr& true_expr,
-                          const Expr& false_expr) {
-    return TernaryIf(Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr}));
+  static TernaryIf create(
+      const SourceRange& range,
+      const Expr& cond,
+      const Expr& true_expr,
+      const Expr& false_expr) {
+    return TernaryIf(
+        Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr}));
   };
 };
 
-
 struct ListLiteral : public Expr {
   explicit ListLiteral(const TreeRef& tree) : Expr(tree) {
     tree_->match(TK_LIST_LITERAL);
@@ -792,7 +839,9 @@ struct ListLiteral : public Expr {
   List<Expr> inputs() const {
     return subtree(0);
   }
-  static ListLiteral create(const SourceRange& range, const List<Expr>& inputs) {
+  static ListLiteral create(
+      const SourceRange& range,
+      const List<Expr>& inputs) {
     return ListLiteral(Compound::create(TK_LIST_LITERAL, range, {inputs}));
   }
 };
@@ -804,7 +853,9 @@ struct TupleLiteral : public Expr {
   List<Expr> inputs() const {
     return subtree(0);
   }
-  static TupleLiteral create(const SourceRange& range, const List<Expr>& inputs) {
+  static TupleLiteral create(
+      const SourceRange& range,
+      const List<Expr>& inputs) {
     return TupleLiteral(Compound::create(TK_TUPLE_LITERAL, range, {inputs}));
   }
 };
@@ -827,8 +878,8 @@ struct Starred : public Expr {
 
 namespace std {
 
-template<typename T>
+template <typename T>
 struct iterator_traits<torch::jit::script::ListIterator<T>>
-  : std::iterator_traits<torch::jit::script::TreeList::const_iterator> {};
+    : std::iterator_traits<torch::jit::script::TreeList::const_iterator> {};
 
 } // namespace std
index e55c6de..61c137b 100644 (file)
@@ -1,58 +1,73 @@
-#include <torch/csrc/jit/script/type_parser.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/script/tree_views.h>
+#include <torch/csrc/jit/script/type_parser.h>
 
 namespace torch {
 namespace jit {
 namespace script {
 
-const std::unordered_map<std::string, TypePtr> &ident_to_type_lut() {
+const std::unordered_map<std::string, TypePtr>ident_to_type_lut() {
   static std::unordered_map<std::string, TypePtr> map = {
-    {"Tensor", DynamicType::get()},
-    {"int", IntType::get()},
-    {"float", FloatType::get()},
-    {"bool", BoolType::get()},
-    {"str", StringType::get()},
-    {"Device", DeviceObjType::get()},
-    // technically this is not a python type but we need it when
-    // parsing serialized methods that use implicit converions to Scalar
-    {"number", NumberType::get()},
-    {"None", NoneType::get()},
+      {"Tensor", DynamicType::get()},
+      {"int", IntType::get()},
+      {"float", FloatType::get()},
+      {"bool", BoolType::get()},
+      {"str", StringType::get()},
+      {"Device", DeviceObjType::get()},
+      // technically this is not a python type but we need it when
+      // parsing serialized methods that use implicit converions to Scalar
+      {"number", NumberType::get()},
+      {"None", NoneType::get()},
   };
   return map;
 }
 
-const std::unordered_map<std::string, std::function<TypePtr(Subscript)>> &subscript_to_type_fns() {
-  static std::unordered_map<std::string, std::function<TypePtr(Subscript)>> map = {
-    {"Tuple", [](Subscript subscript) -> TypePtr {
-      std::vector<TypePtr> subscript_expr_types;
-      for (auto expr : subscript.subscript_exprs()) {
-        subscript_expr_types.push_back(parseTypeFromExpr(expr));
-      }
-      return TupleType::create(subscript_expr_types);
-    }},
-    {"List", [](Subscript subscript) -> TypePtr {
-      if (subscript.subscript_exprs().size() != 1) {
-        throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
-      }
-      auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
-      return ListType::create(elem_type);
-    }},
-    {"Optional", [](Subscript subscript) -> TypePtr {
-      if (subscript.subscript_exprs().size() != 1) {
-        throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
-      }
-      auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
-      return OptionalType::create(elem_type);
-    }},
-    {"Future", [](Subscript subscript) -> TypePtr {
-      if (subscript.subscript_exprs().size() != 1) {
-        throw ErrorReport(subscript) << " expected exactly one element type but found " << subscript.subscript_exprs().size();
-      }
-      auto elem_type = parseTypeFromExpr(*subscript.subscript_exprs().begin());
-      return FutureType::create(elem_type);
-    }},
-  };
+const std::unordered_map<std::string, std::function<TypePtr(Subscript)>>&
+subscript_to_type_fns() {
+  static std::unordered_map<std::string, std::function<TypePtr(Subscript)>>
+      map = {
+          {"Tuple",
+           [](Subscript subscript) -> TypePtr {
+             std::vector<TypePtr> subscript_expr_types;
+             for (auto expr : subscript.subscript_exprs()) {
+               subscript_expr_types.push_back(parseTypeFromExpr(expr));
+             }
+             return TupleType::create(subscript_expr_types);
+           }},
+          {"List",
+           [](Subscript subscript) -> TypePtr {
+             if (subscript.subscript_exprs().size() != 1) {
+               throw ErrorReport(subscript)
+                   << " expected exactly one element type but found "
+                   << subscript.subscript_exprs().size();
+             }
+             auto elem_type =
+                 parseTypeFromExpr(*subscript.subscript_exprs().begin());
+             return ListType::create(elem_type);
+           }},
+          {"Optional",
+           [](Subscript subscript) -> TypePtr {
+             if (subscript.subscript_exprs().size() != 1) {
+               throw ErrorReport(subscript)
+                   << " expected exactly one element type but found "
+                   << subscript.subscript_exprs().size();
+             }
+             auto elem_type =
+                 parseTypeFromExpr(*subscript.subscript_exprs().begin());
+             return OptionalType::create(elem_type);
+           }},
+          {"Future",
+           [](Subscript subscript) -> TypePtr {
+             if (subscript.subscript_exprs().size() != 1) {
+               throw ErrorReport(subscript)
+                   << " expected exactly one element type but found "
+                   << subscript.subscript_exprs().size();
+             }
+             auto elem_type =
+                 parseTypeFromExpr(*subscript.subscript_exprs().begin());
+             return FutureType::create(elem_type);
+           }},
+      };
   return map;
 }
 
@@ -60,9 +75,8 @@ bool isTorch(const Expr& expr) {
   return expr.kind() == TK_VAR && Var(expr).name().name() == "torch";
 }
 
-
-
-c10::optional<std::pair<TypePtr, int32_t>> parseBroadcastList(const Expr& expr) {
+c10::optional<std::pair<TypePtr, int32_t>> parseBroadcastList(
+    const Expr& expr) {
   if (expr.kind() != TK_SUBSCRIPT)
     return c10::nullopt;
   auto subscript = Subscript(expr);
@@ -72,7 +86,7 @@ c10::optional<std::pair<TypePtr, int32_t>> parseBroadcastList(const Expr& expr)
   auto subscript_exprs = subscript.subscript_exprs();
 
   // handle the case where the BroadcastingList is wrapped in a Optional type
-  if(var.name().name() == "Optional") {
+  if (var.name().name() == "Optional") {
     auto broadcast_list = parseBroadcastList(subscript_exprs[0]);
     if (broadcast_list) {
       TypePtr opt_type = OptionalType::create(broadcast_list->first);
@@ -86,17 +100,19 @@ c10::optional<std::pair<TypePtr, int32_t>> parseBroadcastList(const Expr& expr)
 
   if (subscript_exprs.size() != 1)
     throw ErrorReport(subscript.subscript_exprs().range())
-      << "BroadcastingList/Optional[BroadcastingList] must be subscripted with a type";
+        << "BroadcastingList/Optional[BroadcastingList] must be subscripted with a type";
 
   auto typ = subscript_exprs[0];
   auto len = var.name().name().substr(strlen("BroadcastingList"));
 
   if (typ.kind() != TK_VAR)
-    throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier";
+    throw ErrorReport(subscript.value().range())
+        << "Subscripted type must be a type identifier";
 
   auto value_name = Var(typ).name().name();
   if (value_name != "float" && value_name != "int")
-    throw ErrorReport(subscript.value().range()) << "Broadcastable lists only supported for int or float";
+    throw ErrorReport(subscript.value().range())
+        << "Broadcastable lists only supported for int or float";
 
   auto elem_ptr = ident_to_type_lut().find(value_name);
   JIT_ASSERT(elem_ptr != ident_to_type_lut().end());
@@ -137,10 +153,12 @@ TypePtr parseTypeFromExpr(const Expr& expr) {
     auto subscript = Subscript(expr);
     auto value_name = parseBaseTypeName(subscript.value());
     if (!value_name) {
-      throw ErrorReport(subscript.value().range()) << "Subscripted type must be a type identifier";
+      throw ErrorReport(subscript.value().range())
+          << "Subscripted type must be a type identifier";
     }
     if (!subscript_to_type_fns().count(*value_name)) {
-      throw ErrorReport(subscript.range()) << "Unknown type constructor " << *value_name;
+      throw ErrorReport(subscript.range())
+          << "Unknown type constructor " << *value_name;
     }
     return subscript_to_type_fns().at(*value_name)(subscript);
   } else if (auto name = parseBaseTypeName(expr)) {
@@ -150,8 +168,9 @@ TypePtr parseTypeFromExpr(const Expr& expr) {
     }
     throw ErrorReport(expr) << "Unknown type name " << *name;
   }
-  throw ErrorReport(expr.range()) << "Expression of type " << kindToString(expr.kind())
-                                  << " cannot be used in a type expression";
+  throw ErrorReport(expr.range())
+      << "Expression of type " << kindToString(expr.kind())
+      << " cannot be used in a type expression";
 }
 } // namespace script
 } // namespace jit
index 50e2ea8..b35c55e 100644 (file)
@@ -7,7 +7,8 @@ namespace script {
 struct Expr;
 TORCH_API c10::optional<std::string> parseBaseTypeName(const Expr& expr);
 TORCH_API c10::TypePtr parseTypeFromExpr(const Expr& expr);
-TORCH_API c10::optional<std::pair<c10::TypePtr, int32_t>> parseBroadcastList(const Expr& expr);
-}
+TORCH_API c10::optional<std::pair<c10::TypePtr, int32_t>> parseBroadcastList(
+    const Expr& expr);
+} // namespace script
 } // namespace jit
 } // namespace torch
index acbd839..cd851d5 100644 (file)
@@ -5,7 +5,8 @@
 #include <stdexcept>
 #include <string>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 // SourceLocation represents source code-level debug information for a node.
 // It contains information about where a node got generated.
 // In the case of tracing this will be a python stack trace.
@@ -13,21 +14,24 @@ namespace torch { namespace jit {
 // by a SourceRange object
 struct SourceLocation {
   virtual ~SourceLocation() = default;
-  virtual void highlight(std::ostream & out) const = 0;
+  virtual void highlight(std::ostream& out) const = 0;
 
-  std::string wrapException(const std::exception & e, const std::string & additional = "") {
+  std::string wrapException(
+      const std::exception& e,
+      const std::string& additional = "") {
     std::stringstream msg;
     msg << "\n" << e.what() << ":\n";
-    if(!additional.empty()) {
+    if (!additional.empty()) {
       msg << additional << ":\n";
     }
     highlight(msg);
     return msg.str();
   }
-  void wrapAndRethrowException(const std::exception & e, const std::string & additional = "") {
+  void wrapAndRethrowException(
+      const std::exception& e,
+      const std::string& additional = "") {
     throw std::runtime_error(wrapException(e, additional));
   }
-
 };
 
 inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) {
@@ -35,16 +39,16 @@ inline std::ostream& operator<<(std::ostream& out, const SourceLocation& sl) {
   return out;
 }
 
-
 // normally a python stack trace
 struct StringSourceLocation : public SourceLocation {
-  StringSourceLocation(std::string context)
-  : context(std::move(context)) {}
-  void highlight(std::ostream & out) const override {
+  StringSourceLocation(std::string context) : context(std::move(context)) {}
+  void highlight(std::ostream& out) const override {
     out << context;
   }
-private:
+
+ private:
   std::string context;
 };
 
-}}
+} // namespace jit
+} // namespace torch
index 69049af..7229b8d 100644 (file)
@@ -1,7 +1,9 @@
 #pragma once
-#include <torch/csrc/jit/source_location.h>
 #include <torch/csrc/jit/assertions.h>
+#include <torch/csrc/jit/source_location.h>
 
+#include <algorithm>
+#include <memory>
 
 namespace torch {
 namespace jit {
@@ -10,10 +12,7 @@ namespace jit {
 // that
 // range.
 struct SourceRange : public SourceLocation {
-  SourceRange(
-      std::shared_ptr<std::string> file_,
-      size_t start_,
-      size_t end_)
+  SourceRange(std::shared_ptr<std::string> file_, size_t start_, size_t end_)
       : file_(std::move(file_)), start_(start_), end_(end_) {}
   const std::string text() const {
     return file().substr(start(), end() - start());
@@ -34,20 +33,22 @@ struct SourceRange : public SourceLocation {
     JIT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n');
     JIT_ASSERT(end_line == str.size() || str[end_line] == '\n');
 
-    size_t begin_highlight = begin_line; // beginning of context, CONTEXT lines before the highlight line
-    for(size_t i = 0; begin_highlight > 0; --begin_highlight) {
-      if(str[begin_highlight - 1] == '\n')
+    size_t begin_highlight = begin_line; // beginning of context, CONTEXT lines
+                                         // before the highlight line
+    for (size_t i = 0; begin_highlight > 0; --begin_highlight) {
+      if (str[begin_highlight - 1] == '\n')
         ++i;
-      if(i >= CONTEXT)
+      if (i >= CONTEXT)
         break;
     }
     JIT_ASSERT(begin_highlight == 0 || str[begin_highlight - 1] == '\n');
 
-    size_t end_highlight = end_line; // end of context, CONTEXT lines after the highlight line
-    for(size_t i = 0; end_highlight < str.size(); ++end_highlight) {
-      if(str[end_highlight] == '\n')
+    size_t end_highlight =
+        end_line; // end of context, CONTEXT lines after the highlight line
+    for (size_t i = 0; end_highlight < str.size(); ++end_highlight) {
+      if (str[end_highlight] == '\n')
         ++i;
-      if(i >= CONTEXT)
+      if (i >= CONTEXT)
         break;
     }
     JIT_ASSERT(end_highlight == str.size() || str[end_highlight] == '\n');
index 265cd97..e445691 100644 (file)
@@ -3,7 +3,8 @@
 
 #include <torch/csrc/jit/ivalue.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 using Stack = std::vector<IValue>;
 using Operation = std::function<int(Stack&)>;
@@ -12,31 +13,36 @@ using Operation = std::function<int(Stack&)>;
 // the stack and pushes its M inputs onto the stack
 // before: <other stack items> I0, I1, ... IN <- stack.back()
 // after: <other stack items> O0, O1, ... OM
-// operations are defined this way so that ownership of inputs can be transferred
-// to the operation and it can incrementally drop ownership of tensors
-// when they become unneeded. For large operations, like 'run an entire subgraph',
-// this functionality is very important for minimizing gpu memory usage
-// return value is the relative 'offset' to jump to for the next operation:
-// pc += 1 + offset
+// operations are defined this way so that ownership of inputs can be
+// transferred to the operation and it can incrementally drop ownership of
+// tensors when they become unneeded. For large operations, like 'run an entire
+// subgraph', this functionality is very important for minimizing gpu memory
+// usage return value is the relative 'offset' to jump to for the next
+// operation:
+//   pc += 1 + offset
 // so a return value of 0 goes to the next instruction
 
 // treat the last N elements of the stack as a list, looking up
 // element i
-static inline IValue & peek(Stack & stack, size_t i, size_t N) {
+static inline IValue& peek(Stack& stack, size_t i, size_t N) {
   return *(stack.end() - N + i);
 }
 // treat the last N elements of the stack as a list, looking up the
 // slice starting at index i and having length len
-static inline at::ArrayRef<IValue> peekSlice(const Stack & stack, size_t i, size_t len, size_t N) {
+static inline at::ArrayRef<IValue> peekSlice(
+    const Stack& stack,
+    size_t i,
+    size_t len,
+    size_t N) {
   return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
 }
-static inline at::ArrayRef<IValue> last(const Stack & stack, size_t N) {
+static inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
   return peekSlice(stack, 0, N, N);
 }
-static inline void drop(Stack & stack, size_t n) {
+static inline void drop(Stack& stack, size_t n) {
   stack.erase(stack.end() - n, stack.end());
 }
-static inline IValue pop(Stack & stack) {
+static inline IValue pop(Stack& stack) {
   auto r = std::move(stack.back());
   stack.pop_back();
   return r;
@@ -48,40 +54,35 @@ static inline IValue pop(Stack & stack) {
 // equivalent to:
 // b = pop(stack).toTensor();
 // a = pop(stack).toInt();
-template<typename... Types>
+template <typename... Types>
 static inline void pop(Stack& stack, Types&... args) {
   size_t i = 0;
   constexpr size_t N = sizeof...(args);
   int result[N] = {
-    (args = std::move(peek(stack,i++, N)).template to<Types>(),0)...
-  };
-  (void) result;
+      (args = std::move(peek(stack, i++, N)).template to<Types>(), 0)...};
+  (void)result;
   drop(stack, N);
 }
-template<typename... Types>
+template <typename... Types>
 static inline void push(Stack& stack, Types&&... args) {
   constexpr size_t N = sizeof...(args);
-  int result[N] = {
-    (stack.emplace_back(std::forward<Types>(args)), 0)...
-  };
-  (void) result;
+  int result[N] = {(stack.emplace_back(std::forward<Types>(args)), 0)...};
+  (void)result;
 }
 
 // The packer here is carefully written not to make any unnecessary
 // copies.
 
 // pack takes the return values of aten functions pushes them onto the stack
-template<typename T>
-inline void pack(Stack & stack, T&& v) {
+template <typename T>
+inline void pack(Stack& stack, T&& v) {
   stack.emplace_back(std::forward<T>(v));
 }
 
-template<std::size_t remaining, typename... Args>
-struct TuplePacker
-{
+template <std::size_t remaining, typename... Args>
+struct TuplePacker {
   // NB: *Not* a universal reference.
-  static void execute(Stack & stack, std::tuple<Args...> && t)
-  {
+  static void execute(Stack& stack, std::tuple<Args...>&& t) {
     // NB: The move here does not "destroy" the entire tuple, that is
     // not what std::move does; only the particular tuple index
     // processed here gets stolen.
@@ -90,15 +91,15 @@ struct TuplePacker
   }
 };
 
-template<typename... Args>
-struct TuplePacker<0, Args...>
-{
-  static void execute(Stack & stack, std::tuple<Args...> && t) {};
+template <typename... Args>
+struct TuplePacker<0, Args...> {
+  static void execute(Stack& stack, std::tuple<Args...>&& t){};
 };
 
-template<typename... Args>
-inline void pack(Stack & stack, std::tuple<Args...> && t) {
+template <typename... Args>
+inline void pack(Stack& stack, std::tuple<Args...>&& t) {
   TuplePacker<sizeof...(Args), Args...>::execute(stack, std::move(t));
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 12e45b5..3e14546 100644 (file)
@@ -1,12 +1,11 @@
 #include <torch/csrc/jit/symbolic_script.h>
 
-
-
-namespace torch { namespace jit {
-  namespace {
-    std::mutex lock;
-    const std::vector<std::string> functions = {
-      R"(
+namespace torch {
+namespace jit {
+namespace {
+std::mutex lock;
+const std::vector<std::string> functions = {
+    R"(
         def mul(self, other):
             def backward(grad_output):
                 grad_self = (grad_output * other).sum_to_size(self.size())
@@ -21,103 +20,113 @@ namespace torch { namespace jit {
                 return grad_self, None
 
             return torch.adaptive_avg_pool2d(self, output_size), backward
-      )"
-    };
-    std::unordered_map<std::string, GradientPair> schema_to_graphs;
-
-    // This map is a workaround to cache compiled gradient_pairs. Ideally this graph
-    // should be compiled only once and saved in Operator structure.
-    // This should be done along with merging into native_functions.yaml.
-    std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
-  } // anonymous namespace
-
-  std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
-    AT_CHECK(closure->node()->kind() == prim::TupleConstruct, "closure must be a literal tuple construct");
-    Value* fn = closure->node()->inputs().at(0);
-    Value* context = closure->node()->inputs().at(1);
-
-    AT_CHECK(fn->node()->kind() == prim::Function, "closure tuple must contain a prim::Function");
-    return std::make_pair(fn->node()->g(attr::Subgraph), context);
-  }
-
-  Argument originalReturnType(const TupleTypePtr& tup) {
-    AT_CHECK(tup->elements().size() > 1);
-    if(tup->elements().size() == 2)
-      return Argument("", tup->elements().at(0));
-    std::vector<TypePtr> types = tup->elements().vec();
-    types.pop_back();
-    return Argument("", TupleType::create(std::move(types)));
-  }
+      )"};
+std::unordered_map<std::string, GradientPair> schema_to_graphs;
+
+// This map is a workaround to cache compiled gradient_pairs. Ideally this graph
+// should be compiled only once and saved in Operator structure.
+// This should be done along with merging into native_functions.yaml.
+std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
+} // anonymous namespace
+
+std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
+  AT_CHECK(
+      closure->node()->kind() == prim::TupleConstruct,
+      "closure must be a literal tuple construct");
+  Value* fn = closure->node()->inputs().at(0);
+  Value* context = closure->node()->inputs().at(1);
+
+  AT_CHECK(
+      fn->node()->kind() == prim::Function,
+      "closure tuple must contain a prim::Function");
+  return std::make_pair(fn->node()->g(attr::Subgraph), context);
+}
+
+Argument originalReturnType(const TupleTypePtr& tup) {
+  AT_CHECK(tup->elements().size() > 1);
+  if (tup->elements().size() == 2)
+    return Argument("", tup->elements().at(0));
+  std::vector<TypePtr> types = tup->elements().vec();
+  types.pop_back();
+  return Argument("", TupleType::create(std::move(types)));
+}
+
+void loadModule(const std::shared_ptr<script::Module>& module) {
+  for (const auto& method_ : module->get_methods()) {
+    const auto& method = method_.value();
+    GradientPair pair;
+    pair.forward = method->graph();
+
+    // lookup the backward function
+    Node* forward_tuple = pair.forward->outputs().at(0)->node();
+
+    if (forward_tuple->kind() != prim::TupleConstruct) {
+      throw script::ErrorReport(forward_tuple->getSourceLocation())
+          << "gradient must return literal a tuple";
+    }
 
-  void loadModule(const std::shared_ptr<script::Module>& module) {
-    for(const auto& method_ : module->get_methods()) {
-      const auto& method = method_.value();
-      GradientPair pair;
-      pair.forward = method->graph();
-
-      // lookup the backward function
-      Node* forward_tuple = pair.forward->outputs().at(0)->node();
-
-      if (forward_tuple->kind() != prim::TupleConstruct) {
-        throw script::ErrorReport(forward_tuple->getSourceLocation()) << "gradient must return literal a tuple";
-      }
-
-      Value* context;
-      std::tie(pair.backward, context) = extractClosure(forward_tuple->inputs().back());
-
-      // do surgery on the forward function to remove the closure tuple and replace it with the
-      // context variable:
-      //  backward = (<lambda>, context_tuple)
-      //  return original, backward
-      //  -----
-      //  return original, context_tuple
-      std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
-      new_inputs.back() = context;
-      Value* new_tuple = pair.forward->appendNode(pair.forward->createTuple(new_inputs))->output();
-      pair.forward->eraseOutput(0);
-      pair.forward->registerOutput(new_tuple);
-      forward_tuple->destroy();
-
-      // derive schema from original function's schema:
-      const FunctionSchema& loaded_schema = method->getSchema();
-      FunctionSchema actual_schema(Symbol::aten(loaded_schema.name()),
+    Value* context;
+    std::tie(pair.backward, context) =
+        extractClosure(forward_tuple->inputs().back());
+
+    // do surgery on the forward function to remove the closure tuple and
+    // replace it with the context variable:
+    //  backward = (<lambda>, context_tuple)
+    //  return original, backward
+    //  -----
+    //  return original, context_tuple
+    std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
+    new_inputs.back() = context;
+    Value* new_tuple =
+        pair.forward->appendNode(pair.forward->createTuple(new_inputs))
+            ->output();
+    pair.forward->eraseOutput(0);
+    pair.forward->registerOutput(new_tuple);
+    forward_tuple->destroy();
+
+    // derive schema from original function's schema:
+    const FunctionSchema& loaded_schema = method->getSchema();
+    FunctionSchema actual_schema(
+        Symbol::aten(loaded_schema.name()),
         loaded_schema.arguments(),
-        {originalReturnType(new_tuple->type()->expect<TupleType>())}
-      );
-      std::string key = canonicalSchemaString(actual_schema);
-      schema_to_graphs[key] = std::move(pair);
-    }
+        {originalReturnType(new_tuple->type()->expect<TupleType>())});
+    std::string key = canonicalSchemaString(actual_schema);
+    schema_to_graphs[key] = std::move(pair);
   }
+}
 
-  void loadFunctions() {
-    for(const std::string& str : functions) {
-      auto cu = std::make_shared<script::Module>();
-      script::defineMethodsInModule(cu, str, script::nativeResolver, nullptr);
-      loadModule(cu);
-    }
+void loadFunctions() {
+  for (const std::string& str : functions) {
+    auto cu = std::make_shared<script::Module>();
+    script::defineMethodsInModule(cu, str, script::nativeResolver, nullptr);
+    loadModule(cu);
   }
+}
 
-  c10::optional<GradientPair> gradientInfoForSchema(const FunctionSchema& schema) {
-    std::lock_guard<std::mutex> guard(lock);
-    if (schema_to_graphs.size() == 0) {
-      loadFunctions();
-    }
-    auto cache_it = cached_gradient_pairs.find(&schema);
-    if (cache_it != cached_gradient_pairs.end()) {
-      return cache_it->second;
-    } else {
-      auto schema_str = canonicalSchemaString(schema);
-      auto sym_script_it = schema_to_graphs.find(schema_str);
-      if (sym_script_it != schema_to_graphs.end()) {
-        cached_gradient_pairs.emplace_hint(cache_it, &schema, sym_script_it->second);
-        return sym_script_it->second;
-      }
+c10::optional<GradientPair> gradientInfoForSchema(
+    const FunctionSchema& schema) {
+  std::lock_guard<std::mutex> guard(lock);
+  if (schema_to_graphs.size() == 0) {
+    loadFunctions();
+  }
+  auto cache_it = cached_gradient_pairs.find(&schema);
+  if (cache_it != cached_gradient_pairs.end()) {
+    return cache_it->second;
+  } else {
+    auto schema_str = canonicalSchemaString(schema);
+    auto sym_script_it = schema_to_graphs.find(schema_str);
+    if (sym_script_it != schema_to_graphs.end()) {
+      cached_gradient_pairs.emplace_hint(
+          cache_it, &schema, sym_script_it->second);
+      return sym_script_it->second;
     }
-    return c10::nullopt;
   }
+  return c10::nullopt;
+}
 
-  bool hasGradientInfoForSchema(const FunctionSchema& schema) {
-    return gradientInfoForSchema(schema).has_value();
-  }
+bool hasGradientInfoForSchema(const FunctionSchema& schema) {
+  return gradientInfoForSchema(schema).has_value();
+}
 
-}}
+} // namespace jit
+} // namespace torch
index 45496ab..bc8284c 100644 (file)
@@ -1,18 +1,21 @@
 #pragma once
-// This file is temporary until native_functions.yaml and derivatives.yaml are merged.
-// Ideally this should all go into native_functions.yaml
+// This file is temporary until native_functions.yaml and derivatives.yaml are
+// merged. Ideally this should all go into native_functions.yaml
 
 #include <c10/util/Optional.h>
+#include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/compiler.h>
 #include <torch/csrc/jit/script/module.h>
-#include <torch/csrc/jit/operator.h>
 
-namespace torch { namespace jit {
-  struct GradientPair {
-    std::shared_ptr<Graph> forward;
-    std::shared_ptr<Graph> backward;
-  };
+namespace torch {
+namespace jit {
+struct GradientPair {
+  std::shared_ptr<Graph> forward;
+  std::shared_ptr<Graph> backward;
+};
 
-  TORCH_API c10::optional<GradientPair> gradientInfoForSchema(const FunctionSchema& schema);
-  TORCH_API bool hasGradientInfoForSchema(const FunctionSchema& schema);
-}}
+TORCH_API c10::optional<GradientPair> gradientInfoForSchema(
+    const FunctionSchema& schema);
+TORCH_API bool hasGradientInfoForSchema(const FunctionSchema& schema);
+} // namespace jit
+} // namespace torch
index 4060f7c..c7e9ef4 100644 (file)
@@ -1,22 +1,23 @@
 #pragma once
 
-#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/ir.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 struct SymbolicVariable {
   SymbolicVariable() : v(nullptr) {}
-  /* implicit */ SymbolicVariable(Value * v) : v(v) {}
+  /* implicit */ SymbolicVariable(Value* v) : v(v) {}
   // we allow implicit conversions to/from Value since
   // this type truly just provides more methods for value
   operator Value*() const {
     return v;
   }
-  static SymbolicVariable asNewInput(Graph & g, std::string name = "") {
+  static SymbolicVariable asNewInput(Graph& g, std::string name = "") {
     return g.addInput(std::move(name));
   }
-  static SymbolicVariable asNewInput(Graph & g, TypePtr type) {
+  static SymbolicVariable asNewInput(Graph& g, TypePtr type) {
     return g.addInput()->setType(std::move(type));
   }
   const std::vector<int64_t>& sizes() const {
@@ -25,43 +26,45 @@ struct SymbolicVariable {
   void addAsOutput() const {
     v->owningGraph()->registerOutput(v);
   }
-  static std::vector<SymbolicVariable> create(Symbol kind, ArrayRef<SymbolicVariable> inputs,
-                                 int num_outputs = 1,
-                                 Node** created_node = nullptr,
-                                 Graph * g = nullptr) {
-      if(g == nullptr) {
-        g = inputs.at(0).value()->owningGraph();
-      }
-      Node* n = g->insertNode(g->create(kind, num_outputs));
-      size_t max_depth = 0;
-      ScopePtr s;
-      for(auto n : inputs) {
-        size_t d = n.value()->node()->scope()->getDepth();
-        if(d > max_depth) {
-          max_depth = d;
-          s = n.value()->node()->scope();
-        }
+  static std::vector<SymbolicVariable> create(
+      Symbol kind,
+      ArrayRef<SymbolicVariable> inputs,
+      int num_outputs = 1,
+      Node** created_node = nullptr,
+      Graph* g = nullptr) {
+    if (g == nullptr) {
+      g = inputs.at(0).value()->owningGraph();
+    }
+    Node* n = g->insertNode(g->create(kind, num_outputs));
+    size_t max_depth = 0;
+    ScopePtr s;
+    for (auto n : inputs) {
+      size_t d = n.value()->node()->scope()->getDepth();
+      if (d > max_depth) {
+        max_depth = d;
+        s = n.value()->node()->scope();
       }
-      n->setScope(s);
+    }
+    n->setScope(s);
 
-      for(auto i : inputs) {
-        n->addInput(i.value());
-      }
-      if(created_node) {
-        *created_node = n;
-      }
-      std::vector<SymbolicVariable> out;
-      for(auto v : n->outputs()) {
-        out.emplace_back(v);
-      }
-      return out;
+    for (auto i : inputs) {
+      n->addInput(i.value());
+    }
+    if (created_node) {
+      *created_node = n;
+    }
+    std::vector<SymbolicVariable> out;
+    for (auto v : n->outputs()) {
+      out.emplace_back(v);
+    }
+    return out;
   }
   static bool isConstInt(at::Scalar s, int32_t i) {
     // int32_t is safely convertible to both double and int64_t
-    if(s.isFloatingPoint()) {
-      return (double) i == s.toDouble();
+    if (s.isFloatingPoint()) {
+      return (double)i == s.toDouble();
     } else {
-      return (int64_t) i == s.toLong();
+      return (int64_t)i == s.toLong();
     }
   }
   SymbolicVariable operator*(const SymbolicVariable rhs) const {
@@ -76,37 +79,48 @@ struct SymbolicVariable {
     return (*this) * insertConstant(rhs);
   }
   SymbolicVariable operator>(at::Scalar rhs) const {
-    return create(aten::gt, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::gt, {*this, insertConstant(rhs)})[0]
+        .typeLikeWithScalarType(*this, at::kByte);
   }
   SymbolicVariable operator>(const SymbolicVariable rhs) const {
-    return create(aten::gt, {*this, rhs})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::gt, {*this, rhs})[0].typeLikeWithScalarType(
+        *this, at::kByte);
   }
   SymbolicVariable operator<(at::Scalar rhs) const {
-    return create(aten::lt, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::lt, {*this, insertConstant(rhs)})[0]
+        .typeLikeWithScalarType(*this, at::kByte);
   }
   SymbolicVariable operator<(const SymbolicVariable rhs) const {
-    return create(aten::lt, {*this, rhs})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::lt, {*this, rhs})[0].typeLikeWithScalarType(
+        *this, at::kByte);
   }
   SymbolicVariable operator>=(at::Scalar rhs) const {
-    return create(aten::ge, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::ge, {*this, insertConstant(rhs)})[0]
+        .typeLikeWithScalarType(*this, at::kByte);
   }
   SymbolicVariable operator>=(const SymbolicVariable rhs) const {
-    return create(aten::ge, {*this, rhs})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::ge, {*this, rhs})[0].typeLikeWithScalarType(
+        *this, at::kByte);
   }
   SymbolicVariable operator<=(at::Scalar rhs) const {
-    return create(aten::le, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::le, {*this, insertConstant(rhs)})[0]
+        .typeLikeWithScalarType(*this, at::kByte);
   }
   SymbolicVariable operator<=(const SymbolicVariable rhs) const {
-    return create(aten::le, {*this, rhs})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::le, {*this, rhs})[0].typeLikeWithScalarType(
+        *this, at::kByte);
   }
   SymbolicVariable operator==(at::Scalar rhs) const {
-    return create(aten::eq, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::eq, {*this, insertConstant(rhs)})[0]
+        .typeLikeWithScalarType(*this, at::kByte);
   }
   SymbolicVariable operator!=(at::Scalar rhs) const {
-    return create(aten::ne, {*this, insertConstant(rhs)})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::ne, {*this, insertConstant(rhs)})[0]
+        .typeLikeWithScalarType(*this, at::kByte);
   }
   SymbolicVariable operator+(const SymbolicVariable rhs) const {
-    return create(aten::add, {*this, rhs, insertConstant(1)})[0].typeLike(*this);
+    return create(aten::add, {*this, rhs, insertConstant(1)})[0].typeLike(
+        *this);
   }
   SymbolicVariable operator+(at::Scalar rhs) const {
     return (*this) + insertConstant(rhs);
@@ -115,25 +129,28 @@ struct SymbolicVariable {
     return create(aten::neg, {*this})[0].typeLike(*this);
   }
   SymbolicVariable operator-(const SymbolicVariable rhs) const {
-    return create(aten::sub, {*this, rhs, insertConstant(1)})[0].typeLike(*this);
+    return create(aten::sub, {*this, rhs, insertConstant(1)})[0].typeLike(
+        *this);
   }
   SymbolicVariable operator/(at::Scalar rhs) const {
     return create(aten::div, {*this, insertConstant(rhs)})[0].typeLike(*this);
   }
   SymbolicVariable operator%(at::Scalar rhs) const {
-    return create(aten::remainder, {*this, insertConstant(rhs)})[0].typeLike(*this);
+    return create(aten::remainder, {*this, insertConstant(rhs)})[0].typeLike(
+        *this);
   }
   Value* size() const {
     return v->owningGraph()->insert(aten::size, {v});
   }
-  SymbolicVariable sumToSize(Value * size) const {
+  SymbolicVariable sumToSize(Value* size) const {
     return create(prim::SumToSize, {*this, size})[0];
   }
-  SymbolicVariable expand(Value * size) const {
+  SymbolicVariable expand(Value* size) const {
     return v->owningGraph()->insert(aten::expand, {v, size});
   }
   SymbolicVariable isnan() const {
-    return create(aten::ne, {*this, *this})[0].typeLikeWithScalarType(*this, at::kByte);
+    return create(aten::ne, {*this, *this})[0].typeLikeWithScalarType(
+        *this, at::kByte);
   }
   SymbolicVariable mm(const SymbolicVariable rhs) const {
     return create(t("mm"), {*this, rhs})[0];
@@ -148,25 +165,36 @@ struct SymbolicVariable {
     return create(aten::tanh, {*this})[0].typeLike(*this);
   }
   std::vector<SymbolicVariable> chunk(int64_t chunks, int dim) const {
-    Node *chunk;
+    Nodechunk;
     auto outputs = create(prim::ConstantChunk, {value()}, chunks, &chunk);
     chunk->i_(attr::chunks, chunks)->i_(attr::dim, dim);
     return outputs;
   }
   SymbolicVariable type_as(const SymbolicVariable rhs) const {
-    return create(aten::type_as, {*this, rhs})[0].typeLikeWithRhsScalarType(*this, rhs);
+    return create(aten::type_as, {*this, rhs})[0].typeLikeWithRhsScalarType(
+        *this, rhs);
   }
   SymbolicVariable narrow(int dim, int64_t start, int64_t length) const {
-    return create(t("narrow"), { *this, insertConstant(dim), insertConstant(start), insertConstant(length) }, 1)[0];
+    return create(
+        t("narrow"),
+        {*this,
+         insertConstant(dim),
+         insertConstant(start),
+         insertConstant(length)},
+        1)[0];
   }
   static SymbolicVariable cat(ArrayRef<SymbolicVariable> inputs, Value* dim) {
-    Graph *g = dim->owningGraph();
-    Value * input_list;
-    if (inputs.size() == 1 && inputs[0].value()->type()->isSubtypeOf(ListType::ofTensors())) {
+    Graph* g = dim->owningGraph();
+    Value* input_list;
+    if (inputs.size() == 1 &&
+        inputs[0].value()->type()->isSubtypeOf(ListType::ofTensors())) {
       input_list = inputs[0];
     } else {
-      auto value_inputs = fmap(inputs, [](const SymbolicVariable & v) { return v.value(); });
-      input_list = g->insertNode(g->createList(DynamicType::get(), value_inputs))->output();
+      auto value_inputs =
+          fmap(inputs, [](const SymbolicVariable& v) { return v.value(); });
+      input_list =
+          g->insertNode(g->createList(DynamicType::get(), value_inputs))
+              ->output();
     }
     return create(aten::cat, {input_list, dim})[0];
   }
@@ -175,22 +203,30 @@ struct SymbolicVariable {
     return SymbolicVariable::cat(inputs, inputs[0].insertConstant(dim));
   }
   static SymbolicVariable stack(ArrayRef<SymbolicVariable> inputs, Value* dim) {
-    Graph *g = dim->owningGraph();
-    auto value_inputs = fmap(inputs, [](const SymbolicVariable & v) { return v.value(); });
-    Value *input_list = g->insertNode(g->createList(DynamicType::get(), value_inputs))->output();
+    Graph* g = dim->owningGraph();
+    auto value_inputs =
+        fmap(inputs, [](const SymbolicVariable& v) { return v.value(); });
+    Value* input_list =
+        g->insertNode(g->createList(DynamicType::get(), value_inputs))
+            ->output();
     return create(aten::stack, {input_list, dim})[0];
   }
   static SymbolicVariable stack(ArrayRef<SymbolicVariable> inputs, int dim) {
     JIT_ASSERT(inputs.size() > 0);
     return SymbolicVariable::stack(inputs, inputs[0].insertConstant(dim));
   }
-  static std::vector<SymbolicVariable> broadcast_tensors(ArrayRef<SymbolicVariable> inputs) {
+  static std::vector<SymbolicVariable> broadcast_tensors(
+      ArrayRef<SymbolicVariable> inputs) {
     JIT_ASSERT(inputs.size() > 0);
-    Graph *g = inputs[0].value()->owningGraph();
-    auto value_inputs = fmap(inputs, [](const SymbolicVariable & v) { return v.value(); });
-    Value * input_list = g->insertNode(g->createList(DynamicType::get(), value_inputs))->output();
-    Value * output_list = g->insert(aten::broadcast_tensors, {input_list});
-    Node * unpack = g->insertNode(g->create(prim::ListUnpack, {output_list}, inputs.size()));
+    Graph* g = inputs[0].value()->owningGraph();
+    auto value_inputs =
+        fmap(inputs, [](const SymbolicVariable& v) { return v.value(); });
+    Value* input_list =
+        g->insertNode(g->createList(DynamicType::get(), value_inputs))
+            ->output();
+    Value* output_list = g->insert(aten::broadcast_tensors, {input_list});
+    Node* unpack = g->insertNode(
+        g->create(prim::ListUnpack, {output_list}, inputs.size()));
     return fmap<SymbolicVariable>(unpack->outputs());
   }
   static SymbolicVariable zeros_like(const SymbolicVariable input) {
@@ -224,7 +260,9 @@ struct SymbolicVariable {
     return create(t("sum"), {*this})[0];
   }
   SymbolicVariable sum(int dim, bool keepdim) const {
-    return create(t("sum"), {*this, insertConstant(at::IntList{dim}), insertConstant(keepdim)})[0];
+    return create(
+        t("sum"),
+        {*this, insertConstant(at::IntList{dim}), insertConstant(keepdim)})[0];
   }
   SymbolicVariable squeeze(Value* dim) const {
     return create(t("squeeze"), {*this, dim})[0];
@@ -251,13 +289,16 @@ struct SymbolicVariable {
     return reshape(insertConstant(std::move(sizes)));
   }
   SymbolicVariable addmm(SymbolicVariable mat1, SymbolicVariable mat2) const {
-    return create(aten::addmm, {*this, mat1, mat2, insertConstant(1), insertConstant(1)})[0];
+    return create(
+        aten::addmm,
+        {*this, mat1, mat2, insertConstant(1), insertConstant(1)})[0];
   }
-  Value * value() const {
+  Value* value() const {
     return v;
   }
-private:
-  Value * insertConstant(IValue value) const {
+
+ private:
+  Value* insertConstant(IValue value) const {
     return v->owningGraph()->insertConstant(std::move(value));
   }
   SymbolicVariable typeLike(SymbolicVariable other) const {
@@ -268,7 +309,7 @@ private:
   SymbolicVariable typeLikeWithScalarType(
       SymbolicVariable other,
       at::ScalarType type) const {
-    if (auto other_type = other.v->type()->cast<CompleteTensorType>()){
+    if (auto other_type = other.v->type()->cast<CompleteTensorType>()) {
       auto new_type = other_type->toScalarType(type)->contiguous();
       v->setType(new_type);
     }
@@ -279,27 +320,30 @@ private:
       SymbolicVariable rhs) const {
     auto other_type = other.v->type()->cast<CompleteTensorType>();
     auto rhs_type = rhs.v->type()->cast<CompleteTensorType>();
-    if (other_type && rhs_type){
-      auto new_type = other_type->toScalarType(rhs_type->scalarType())->contiguous();
+    if (other_type && rhs_type) {
+      auto new_type =
+          other_type->toScalarType(rhs_type->scalarType())->contiguous();
       v->setType(new_type);
     }
     return *this;
   }
-  static Symbol a(const char * s_) {
+  static Symbol a(const char* s_) {
     return Symbol::attr(s_);
   }
-  static Symbol t(const char * s_) {
+  static Symbol t(const char* s_) {
     return Symbol::aten(s_);
   }
-  Value * v;
+  Value* v;
 };
 
 // shorter method so that toVar(v) + toVar(c) is short.
-static inline SymbolicVariable toVar(Value * v) {
+static inline SymbolicVariable toVar(Value* v) {
   return {v};
 }
 
-template<typename T, typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
+template <
+    typename T,
+    typename = typename std::enable_if<std::is_arithmetic<T>::value>::type>
 inline SymbolicVariable operator+(T lhs, SymbolicVariable rhs) {
   return rhs + at::Scalar(lhs);
 }
@@ -312,4 +356,5 @@ inline SymbolicVariable operator-(at::Scalar lhs, SymbolicVariable rhs) {
   return (lhs + (-rhs));
 }
 
-}}
+} // namespace jit
+} // namespace torch
index ae358be..199c732 100644 (file)
@@ -1,31 +1,33 @@
 #include <torch/csrc/jit/tracer.h>
 
-#include <torch/csrc/jit/assertions.h>
-#include <torch/csrc/autograd/variable.h>
-#include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/engine.h>
+#include <torch/csrc/autograd/function.h>
+#include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/remove_expands.h>
 
-#include <string>
-#include <sstream>
 #include <memory>
+#include <sstream>
+#include <string>
 
-namespace torch { namespace jit { namespace tracer {
+namespace torch {
+namespace jit {
+namespace tracer {
 
 ////////////////////////////////////////////////////////////////////////////////
 // Recording the traces
 ////////////////////////////////////////////////////////////////////////////////
 namespace detail {
 
-template<typename T>
-void genericAddInput(Node *n, T value) {
-  Value *v = n->owningGraph()->insertConstant(value);
+template <typename T>
+void genericAddInput(Noden, T value) {
+  Valuev = n->owningGraph()->insertConstant(value);
   recordSourceLocation(v->node());
   n->addInput(v);
 }
 
-template<typename T>
+template <typename T>
 void badArgType(const T& v) {
   AT_ERROR(
       "Found an unsupported argument type in the JIT tracer: ",
@@ -37,7 +39,7 @@ thread_local std::shared_ptr<TracingState> tracing_state;
 
 } // namespace detail
 
-void setValueTrace(const IValue &v, Value *value) {
+void setValueTrace(const IValue& v, Value* value) {
   if (v.isTensor()) {
     auto var = v.toTensor();
     JIT_ASSERT(var.defined());
@@ -45,14 +47,16 @@ void setValueTrace(const IValue &v, Value *value) {
   } else if (v.isTensorList()) {
     auto& outputs = v.toTensorList()->elements();
     auto graph = getTracingState()->graph;
-    Node * unpack_node = graph->appendNode(graph->create(prim::ListUnpack, {value}, outputs.size()));
+    Node* unpack_node = graph->appendNode(
+        graph->create(prim::ListUnpack, {value}, outputs.size()));
     for (size_t i = 0; i < outputs.size(); ++i) {
       setValueTrace(outputs[i], unpack_node->outputs()[i]);
     }
   } else if (v.isTuple()) {
     auto& outputs = v.toTuple()->elements();
     auto graph = getTracingState()->graph;
-    Node * unpack_node = graph->appendNode(graph->create(prim::TupleUnpack, {value}, outputs.size()));
+    Node* unpack_node = graph->appendNode(
+        graph->create(prim::TupleUnpack, {value}, outputs.size()));
     for (size_t i = 0; i < outputs.size(); ++i) {
       setValueTrace(outputs[i], unpack_node->outputs()[i]);
     }
@@ -64,114 +68,131 @@ void setValueTrace(const IValue &v, Value *value) {
   }
 }
 
-void addInputs(Node *n, const char * name, int64_t value) {
+void addInputs(Node* n, const char* name, int64_t value) {
   using ArgumentStash = jit::tracer::ArgumentStash;
   if (ArgumentStash::hasValue(name)) {
-    Value * v = ArgumentStash::popValue(name);
+    Value* v = ArgumentStash::popValue(name);
     n->addInput(v);
   } else {
     detail::genericAddInput(n, value);
   }
 }
 
-void addInputs(Node *n, const char * name, c10::optional<int64_t> value)     {
-  if(value) {
+void addInputs(Node* n, const char* name, c10::optional<int64_t> value) {
+  if (value) {
     detail::genericAddInput(n, *value);
   } else {
-    Graph * g = n->owningGraph();
-    Value* none =
-        g->insertNode(g->createNone(IntType::get()))
-            ->output();
+    Graph* g = n->owningGraph();
+    Value* none = g->insertNode(g->createNone(IntType::get()))->output();
     n->addInput(none);
   }
 }
-void addInputs(Node *n, const char * name, bool value)               { detail::genericAddInput(n, value); }
-void addInputs(Node *n, const char * name, double value)             { detail::genericAddInput(n, value); }
-void addInputs(Node *n, const char * name, const at::Scalar& value)  { detail::genericAddInput(n, value); }
-void addInputs(Node *n, const char * name, const c10::optional<at::Scalar>& value)  {
-  if(value) {
+void addInputs(Node* n, const char* name, bool value) {
+  detail::genericAddInput(n, value);
+}
+void addInputs(Node* n, const char* name, double value) {
+  detail::genericAddInput(n, value);
+}
+void addInputs(Node* n, const char* name, const at::Scalar& value) {
+  detail::genericAddInput(n, value);
+}
+void addInputs(
+    Node* n,
+    const char* name,
+    const c10::optional<at::Scalar>& value) {
+  if (value) {
     detail::genericAddInput(n, *value);
   } else {
-    Graph * g = n->owningGraph();
-    Value* none =
-        g->insertNode(g->createNone(NumberType::get()))
-            ->output();
+    Graph* g = n->owningGraph();
+    Value* none = g->insertNode(g->createNone(NumberType::get()))->output();
     n->addInput(none);
   }
 }
-void addInputs(Node *n, const char * name, const std::string& value) { detail::genericAddInput(n, value); }
-void addInputs(Node *n, const char * name, const at::Tensor& value)  { n->addInput(getValueTrace(value)); }
-void addInputs(Node *n, const char * name, const at::SparseTensorRef& value) { detail::badArgType(value); }
-void addInputs(Node *n, const char * name, at::Generator * value)            {
+void addInputs(Node* n, const char* name, const std::string& value) {
+  detail::genericAddInput(n, value);
+}
+void addInputs(Node* n, const char* name, const at::Tensor& value) {
+  n->addInput(getValueTrace(value));
+}
+void addInputs(Node* n, const char* name, const at::SparseTensorRef& value) {
+  detail::badArgType(value);
+}
+void addInputs(Node* n, const char* name, at::Generator* value) {
   if (value) {
     detail::badArgType(value);
   }
-  Graph * g = n->owningGraph();
-  Value * undef_gen = g->insertNode(g->createNone(GeneratorType::get()))->output();
+  Graph* g = n->owningGraph();
+  Value* undef_gen =
+      g->insertNode(g->createNone(GeneratorType::get()))->output();
   n->addInput(undef_gen);
 }
-void addInputs(Node *n, const char * name, at::Device value) {
+void addInputs(Node* n, const char* name, at::Device value) {
   detail::genericAddInput(n, value);
 }
-void addInputs(Node *n, const char * name, at::Layout value) {
+void addInputs(Node* n, const char* name, at::Layout value) {
   detail::genericAddInput(n, static_cast<int64_t>(value));
 }
-void addInputs(Node *n, const char * name, at::ScalarType value) {
+void addInputs(Node* n, const char* name, at::ScalarType value) {
   detail::genericAddInput(n, static_cast<int64_t>(value));
 }
-void addInputs(Node *n, const char * name, const c10::optional<at::ScalarType>& value)  {
-  if(value) {
+void addInputs(
+    Node* n,
+    const char* name,
+    const c10::optional<at::ScalarType>& value) {
+  if (value) {
     detail::genericAddInput(n, static_cast<int64_t>(*value));
   } else {
-    Graph * g = n->owningGraph();
-    Value* none =
-        g->insertNode(g->createNone(IntType::get()))
-            ->output();
+    Graph* g = n->owningGraph();
+    Value* none = g->insertNode(g->createNone(IntType::get()))->output();
     n->addInput(none);
   }
 }
 
-void addInputs(Node *n, const char * name, at::TensorList value) {
-  Graph *g = n->owningGraph();
-  Node *list_node = g->appendNode(g->createList(DynamicType::get(), fmap(value, getValueTrace)));
+void addInputs(Node* n, const char* name, at::TensorList value) {
+  Graph* g = n->owningGraph();
+  Node* list_node = g->appendNode(
+      g->createList(DynamicType::get(), fmap(value, getValueTrace)));
   n->addInput(list_node->output());
 }
 
-void addInputs(Node* n, const char * name, const at::TensorOptions& options) {
-  // [TensorOptions in script] - update this when you change how we schematize TensorOptions
+void addInputs(Node* n, const char* name, const at::TensorOptions& options) {
+  // [TensorOptions in script] - update this when you change how we schematize
+  // TensorOptions
   addInputs(n, name, at::typeMetaToScalarType(options.dtype()));
   addInputs(n, name, options.layout());
   addInputs(n, name, options.device());
 }
 
-void addInputs(Node *n, const char * name, at::IntList value) {
+void addInputs(Node* n, const char* name, at::IntList value) {
   using ArgumentStash = jit::tracer::ArgumentStash;
-  std::vector<Value*> info = ArgumentStash::hasIntList(name) ?
-    ArgumentStash::popIntList(name) :
-    ArgumentStash::IntListTrace(value.size());
+  std::vector<Value*> info = ArgumentStash::hasIntList(name)
+      ? ArgumentStash::popIntList(name)
+      : ArgumentStash::IntListTrace(value.size());
 
   auto& g = getTracingState()->graph;
   for (size_t i = 0; i < info.size(); ++i) {
-    if (info[i] != nullptr) continue;
+    if (info[i] != nullptr)
+      continue;
     info[i] = g->insertConstant(value[i]);
     recordSourceLocation(info[i]->node());
   }
   for (jit::Value* v : info) {
     if (*v->type() != *jit::IntType::get()) {
       throw std::runtime_error(
-        "Type mismatch in setposattr for IntList. Check that your program "
-        "is valid without tracing, and please file a bug report if it is.");
+          "Type mismatch in setposattr for IntList. Check that your program "
+          "is valid without tracing, and please file a bug report if it is.");
     }
   }
-  n->addInput(g->insertNode(g->createList(jit::IntType::get(), info))->output());
+  n->addInput(
+      g->insertNode(g->createList(jit::IntType::get(), info))->output());
 }
 
-void addInputs(Node *n, const char * name, const ArrayRef<double>& value) {
+void addInputs(Node* n, const char* name, const ArrayRef<double>& value) {
   AT_ERROR("Tracing float lists currently not supported!");
 }
 
 void addOutput(Node* node, const at::Tensor& output) {
-  Value * value = node->addOutput();
+  Value* value = node->addOutput();
   if (output.defined()) {
     value->inferTypeFrom(output);
     setValueTrace(autograd::as_variable_ref(output), value);
@@ -179,11 +200,12 @@ void addOutput(Node* node, const at::Tensor& output) {
 }
 
 void addOutput(Node* node, const std::vector<at::Tensor>& outputs) {
-  Value * value = node->addOutput()->setType(ListType::ofTensors());
-  Graph * graph = node->owningGraph();
-  Node * unpack_node = graph->appendNode(graph->create(prim::ListUnpack, {value}, outputs.size()));
+  Value* value = node->addOutput()->setType(ListType::ofTensors());
+  Graph* graph = node->owningGraph();
+  Node* unpack_node = graph->appendNode(
+      graph->create(prim::ListUnpack, {value}, outputs.size()));
   for (size_t i = 0; i < outputs.size(); ++i) {
-    Value * output_val = unpack_node->outputs()[i];
+    Value* output_val = unpack_node->outputs()[i];
     output_val->inferTypeFrom(outputs[i]);
     setValueTrace(outputs[i], output_val);
   }
@@ -197,18 +219,18 @@ void setTracingState(std::shared_ptr<TracingState> state) {
   detail::tracing_state = std::move(state);
 }
 
-TracingState::TracingState()
-    : graph(new Graph()) {}
+TracingState::TracingState() : graph(new Graph()) {}
 
 TracingState::~TracingState() = default;
 
 autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
-  auto & tracing_state = getTracingState();
-  auto & graph = tracing_state->graph;
+  auto& tracing_state = getTracingState();
+  auto& graph = tracing_state->graph;
 
-  auto size_var = autograd::make_variable(scalar_to_tensor(at::Scalar(var.size(dim))));
+  auto size_var =
+      autograd::make_variable(scalar_to_tensor(at::Scalar(var.size(dim))));
   auto* value = getValueTrace(var);
-  WithInsertPoint ipoint { graph->block() };
+  WithInsertPoint ipoint{graph->block()};
   auto dim_val = graph->insertConstant(dim);
   recordSourceLocation(dim_val->node());
   auto* node = graph->insertNode(graph->create(aten::size, {value, dim_val}));
@@ -226,10 +248,15 @@ autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
 ////////////////////////////////////////////////////////////////////////////////
 thread_local ArgumentStash ArgumentStash::stash;
 
-void ArgumentStash::stashIntListElem(const std::string& arg_name, size_t size, size_t idx, const Variable& var) {
+void ArgumentStash::stashIntListElem(
+    const std::string& arg_name,
+    size_t size,
+    size_t idx,
+    const Variable& var) {
   // TODO: check type?
-  if (!isTracing()) return;
-  auto & list_trace = stash.intlists.emplace(arg_name, size).first->second;
+  if (!isTracing())
+    return;
+  auto& list_trace = stash.intlists.emplace(arg_name, size).first->second;
   JIT_ASSERT(size == list_trace.size());
   JIT_ASSERT(idx < list_trace.size());
   JIT_ASSERT(list_trace[idx] == nullptr);
@@ -241,17 +268,22 @@ void ArgumentStash::stashIntListElem(const std::string& arg_name, size_t size, s
   list_trace[idx] = prim;
 }
 
-void ArgumentStash::stashValue(const std::string& arg_name, size_t idx, const Variable& var, const TypePtr& type) {
-  if (!isTracing()) return;
+void ArgumentStash::stashValue(
+    const std::string& arg_name,
+    size_t idx,
+    const Variable& var,
+    const TypePtr& type) {
+  if (!isTracing())
+    return;
 
   Value* ten = getValueTrace(var);
   WithInsertPoint guard(ten->node()->next());
   auto& g = *ten->owningGraph();
 
   if (type == IntType::get()) {
-    ten = g.insert(prim::Int, { ten });
+    ten = g.insert(prim::Int, {ten});
   } else if (type == FloatType::get()) {
-    ten = g.insert(prim::Float, { ten });
+    ten = g.insert(prim::Float, {ten});
   }
 
   stash.values.emplace(arg_name, ten);
@@ -262,7 +294,8 @@ void ArgumentStash::stashValue(const std::string& arg_name, size_t idx, const Va
 ////////////////////////////////////////////////////////////////////////////////
 // no python present so we just do not record source information
 void defaultRecordSourceLocation(Node* n) {}
-std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(defaultRecordSourceLocation);
+std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(
+    defaultRecordSourceLocation);
 void recordSourceLocation(Node* n) {
   return record_source_location.load()(n);
 }
@@ -273,26 +306,26 @@ void setRecordSourceLocation(void (*v)(Node*)) {
 void defaultWarn(const std::string& str) {
   AT_WARN(str);
 }
-std::atomic<warn_fn_type> warn_callback { defaultWarn };
-
-const char * WARN_PYTHON_DATAFLOW =
-  " might cause the trace to be incorrect. We can't record the data flow of "
-  "Python values, so this value will be treated as a constant in the future. "
-  "This means that the trace might not generalize to other inputs!";
-const char * WARN_CONSTRUCTOR =
-  " results are registered as constants in the trace. You can safely ignore this "
-  "warning if you use this function to create tensors out of constant variables "
-  "that would be the same every time you call this function. In any other case, "
-  "this might cause the trace to be incorrect.";
-const char * WARN_RESIZE =
-  " can't be represented in the JIT at the moment, so we won't connect any uses of "
-  "this value with its current trace. If you happen to use it again, it will show "
-  "up as a constant in the graph.";
+std::atomic<warn_fn_type> warn_callback{defaultWarn};
+
+const char* WARN_PYTHON_DATAFLOW =
+    " might cause the trace to be incorrect. We can't record the data flow of "
+    "Python values, so this value will be treated as a constant in the future. "
+    "This means that the trace might not generalize to other inputs!";
+const char* WARN_CONSTRUCTOR =
+    " results are registered as constants in the trace. You can safely ignore this "
+    "warning if you use this function to create tensors out of constant variables "
+    "that would be the same every time you call this function. In any other case, "
+    "this might cause the trace to be incorrect.";
+const char* WARN_RESIZE =
+    " can't be represented in the JIT at the moment, so we won't connect any uses of "
+    "this value with its current trace. If you happen to use it again, it will show "
+    "up as a constant in the graph.";
 
 // XXX: _kind can be a nullptr
-void _do_warn(const char * _reason, const char * _kind) {
-  std::string reason { _reason };
-  std::string kind { _kind ? _kind : "" };
+void _do_warn(const char* _reason, const char* _kind) {
+  std::string reason{_reason};
+  std::string kind{_kind ? _kind : ""};
   std::ostringstream s;
   s << reason << kind;
   warn_callback.load()(s.str());
@@ -302,4 +335,6 @@ void setWarn(warn_fn_type fn) {
   warn_callback.store(fn);
 }
 
-}}}
+} // namespace tracer
+} // namespace jit
+} // namespace torch
index bbbca8d..fc386a9 100644 (file)
@@ -1,27 +1,27 @@
 #pragma once
 
+#include <ATen/Backtrace.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/autograd/function_hook.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/assertions.h>
 #include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/stack.h>
 #include <torch/csrc/jit/tracing_state.h>
-#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/utils/functional.h>
 #include <torch/csrc/utils/functional.h>
 #include <torch/csrc/utils/variadic.h>
-#include <torch/csrc/utils/variadic.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
-#include <ATen/Backtrace.h>
 
+#include <cstdint>
+#include <iostream>
 #include <memory>
 #include <mutex>
-#include <vector>
-#include <iostream>
-#include <cstdint>
 #include <unordered_map>
+#include <vector>
 
-namespace torch { namespace jit { namespace tracer {
+namespace torch {
+namespace jit {
+namespace tracer {
 
 using torch::autograd::Variable;
 using variable_list = std::vector<Variable>;
@@ -29,9 +29,9 @@ using variable_list = std::vector<Variable>;
 TORCH_API void recordSourceLocation(Node* n);
 TORCH_API void setRecordSourceLocation(void (*v)(Node*));
 
-// Having finished adding a new 'node' to the graph IR 'setValueTrace' associates
-// this node with an output variable, so that further operations involving this
-// variable know which node in the IR to reference.
+// Having finished adding a new 'node' to the graph IR 'setValueTrace'
+// associates this node with an output variable, so that further operations
+// involving this variable know which node in the IR to reference.
 TORCH_API void setValueTrace(const IValue& v, Value* value);
 
 inline void delValueTrace(const Variable& var) {
@@ -43,9 +43,7 @@ inline std::function<void()> pauseTracing() {
   std::shared_ptr<tracer::TracingState> state = getTracingState();
   tracer::setTracingState(nullptr);
 
-  return [state]() {
-    tracer::setTracingState(state);
-  };
+  return [state]() { tracer::setTracingState(state); };
 }
 
 // Given a variable 'var', return the 'node' which represents the instruction
@@ -59,20 +57,20 @@ inline std::function<void()> pauseTracing() {
 //      return Addmm.apply(output, self, matrix, 0, 1, True)
 //
 // Here, mm fakes up a dummy variable with uninitialized data to do an inplace
-// update on, but subsequently ignores it because the alpha scaling factor is zero.
-// This is one of the cases where a Variable can be created inside of a trace, and
-// if we treat it as a constant, everything will work out.
+// update on, but subsequently ignores it because the alpha scaling factor is
+// zero. This is one of the cases where a Variable can be created inside of a
+// trace, and if we treat it as a constant, everything will work out.
 inline Value* getValueTrace(const Variable& var) {
-  auto &state = getTracingState();
+  autostate = getTracingState();
   if (!var.defined()) {
-    Node *n = state->graph->createUndefined();
+    Noden = state->graph->createUndefined();
     return state->graph->appendNode(n)->output();
   }
 
-  auto & value_map = getTracingState()->value_map;
+  auto& value_map = getTracingState()->value_map;
   auto it = value_map.find(var);
   if (it == value_map.end()) {
-    Value *constant = state->graph->insertConstant(var.data());
+    Valueconstant = state->graph->insertConstant(var.data());
     recordSourceLocation(constant->node());
     constant->inferTypeFrom(var.data());
     it = value_map.emplace_hint(it, var, constant);
@@ -89,31 +87,36 @@ inline Value* getValueTrace(const Variable& var) {
 // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] arguments
 // One might merge getValueTrace and getNestedValueTrace after checking that
 // casting to IValue instead  of Variable is OK
-inline Value* getNestedValueTrace(const IValue &v) {
-  auto &state = getTracingState();
+inline Value* getNestedValueTrace(const IValuev) {
+  autostate = getTracingState();
   if (v.isTensorList()) {
-    return state->graph->insertNode(state->graph->createList(
-        DynamicType::get(),
-        fmap(v.toTensorListRef(), [](const IValue &val) {
-          return getNestedValueTrace(val);
-       })))->output();
+    return state->graph
+        ->insertNode(state->graph->createList(
+            DynamicType::get(),
+            fmap(
+                v.toTensorListRef(),
+                [](const IValue& val) { return getNestedValueTrace(val); })))
+        ->output();
   } else if (v.isTuple()) {
-    return state->graph->insertNode(state->graph->createTuple(
-       fmap(v.toTuple()->elements(), [](const IValue &val) {
-          return getNestedValueTrace(val);
-       })))->output();
+    return state->graph
+        ->insertNode(state->graph->createTuple(fmap(
+            v.toTuple()->elements(),
+            [](const IValue& val) { return getNestedValueTrace(val); })))
+        ->output();
   }
   return getValueTrace(v.toTensor());
 }
 
-
-inline Value* getOutputTrace(const std::shared_ptr<TracingState>& state, const Variable& var, size_t output_no) {
+inline Value* getOutputTrace(
+    const std::shared_ptr<TracingState>& state,
+    const Variable& var,
+    size_t output_no) {
   if (!var.defined()) {
-    Node *n = state->graph->createUndefined();
+    Noden = state->graph->createUndefined();
     return state->graph->appendNode(n)->output();
   }
 
-  auto & value_map = getTracingState()->value_map;
+  auto& value_map = getTracingState()->value_map;
   auto it = value_map.find(var);
   if (it == value_map.end()) {
     std::ostringstream os;
@@ -135,7 +138,8 @@ inline std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
   auto state = std::make_shared<TracingState>();
   setTracingState(state);
   // XXX: this function mutates input
-  const std::function<IValue(IValue, TypePtr, Value*)> add_input = [&](IValue input, TypePtr type, Value* value) -> IValue {
+  const std::function<IValue(IValue, TypePtr, Value*)> add_input =
+      [&](IValue input, TypePtr type, Value* value) -> IValue {
     value->setType(type);
     if (type->isSubtypeOf(DynamicType::get())) {
       auto input_tensor = input.toTensor();
@@ -147,22 +151,26 @@ inline std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
       state->value_map[input_tensor] = value;
       return input_tensor;
     } else if (auto tuple_type = type->cast<TupleType>()) {
-      auto unpack_node = state->graph->insertNode(state->graph->createTupleUnpack(value));
+      auto unpack_node =
+          state->graph->insertNode(state->graph->createTupleUnpack(value));
       auto elem_values = unpack_node->outputs();
       auto elem_types = tuple_type->elements();
       Stack elems = input.toTuple()->elements();
       size_t num_elems = elems.size();
-      AT_ASSERT(elem_values.size() == num_elems && elem_types.size() == num_elems);
+      AT_ASSERT(
+          elem_values.size() == num_elems && elem_types.size() == num_elems);
       for (size_t i = 0; i < num_elems; ++i) {
         elems[i] = add_input(elems[i], elem_types[i], elem_values[i]);
       }
       return Tuple::create(std::move(elems));
     } else {
-      AT_ERROR("Only tensors or tuples of tensors can be inputs to traced functions");
+      AT_ERROR(
+          "Only tensors or tuples of tensors can be inputs to traced functions");
     }
   };
   for (IValue& input : inputs) {
-    input = add_input(input, incompleteInferTypeFrom(input), state->graph->addInput());
+    input = add_input(
+        input, incompleteInferTypeFrom(input), state->graph->addInput());
   }
   return std::make_pair(state, inputs);
 }
@@ -171,18 +179,20 @@ inline std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
 // are the variables whose values will be computed upon subsequent
 // invocations of the trace.
 inline void exit(const Stack& outputs) {
-  auto & state = getTracingState();
+  auto& state = getTracingState();
   size_t i = 0;
-  std::function<Value*(const IValue&)> reduce_ivalue = [&](const IValue& iv) -> Value* {
+  std::function<Value*(const IValue&)> reduce_ivalue =
+      [&](const IValue& iv) -> Value* {
     if (iv.isTensor()) {
       return getOutputTrace(state, iv.toTensor(), i);
     } else if (iv.isTuple()) {
-      const auto & elems = iv.toTuple()->elements();
+      const auto& elems = iv.toTuple()->elements();
       auto tuple_node = state->graph->createTuple(fmap(elems, reduce_ivalue));
       state->graph->appendNode(tuple_node);
       return tuple_node->output();
     } else {
-      AT_ERROR("Only tensors or tuples of tensors can be output from traced functions");
+      AT_ERROR(
+          "Only tensors or tuples of tensors can be output from traced functions");
     }
   };
   for (auto& output : outputs) {
@@ -199,31 +209,52 @@ inline void abandon() {
 
 // NB: those serve both as an intermediate steps in addInputs below,
 // as well as the overloads that terminate template recursion
-TORCH_API void addInputs(Node *n, const char * name, int64_t value);
-TORCH_API void addInputs(Node *n, const char * name, c10::optional<int64_t> value);
-TORCH_API void addInputs(Node *n, const char * name, bool value);
-TORCH_API void addInputs(Node *n, const char * name, double value);
-TORCH_API void addInputs(Node *n, const char * name, const at::Scalar& value);
-TORCH_API void addInputs(Node *n, const char * name, const c10::optional<at::Scalar>& value);
-TORCH_API void addInputs(Node *n, const char * name, const at::Tensor& value);
-TORCH_API void addInputs(Node *n, const char * name, at::IntList value);
-TORCH_API void addInputs(Node *n, const char * name, at::TensorList value);
-TORCH_API void addInputs(Node *n, const char * name, const ArrayRef<double>& value);
-TORCH_API void addInputs(Node *n, const char * name, const std::string& value);
-TORCH_API void addInputs(Node *n, const char * name, const at::SparseTensorRef& value);
-TORCH_API void addInputs(Node *n, const char * name, const at::TensorOptions& value);
-TORCH_API void addInputs(Node *n, const char * name, at::Device value);
-TORCH_API void addInputs(Node *n, const char * name, at::Layout value);
-TORCH_API void addInputs(Node *n, const char * name, at::ScalarType value);
-TORCH_API void addInputs(Node *n, const char * name, const c10::optional<at::ScalarType>& value);
-TORCH_API void addInputs(Node *n, const char * name, at::Generator * value);
+TORCH_API void addInputs(Node* n, const char* name, int64_t value);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    c10::optional<int64_t> value);
+TORCH_API void addInputs(Node* n, const char* name, bool value);
+TORCH_API void addInputs(Node* n, const char* name, double value);
+TORCH_API void addInputs(Node* n, const char* name, const at::Scalar& value);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const c10::optional<at::Scalar>& value);
+TORCH_API void addInputs(Node* n, const char* name, const at::Tensor& value);
+TORCH_API void addInputs(Node* n, const char* name, at::IntList value);
+TORCH_API void addInputs(Node* n, const char* name, at::TensorList value);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const ArrayRef<double>& value);
+TORCH_API void addInputs(Node* n, const char* name, const std::string& value);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const at::SparseTensorRef& value);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const at::TensorOptions& value);
+TORCH_API void addInputs(Node* n, const char* name, at::Device value);
+TORCH_API void addInputs(Node* n, const char* name, at::Layout value);
+TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const c10::optional<at::ScalarType>& value);
+TORCH_API void addInputs(Node* n, const char* name, at::Generator* value);
 
-template<size_t N>
-void addInputs(Node *n, const char * name, std::array<bool, N> value) {
-  throw std::runtime_error("Found an unsupported argument type in the JIT tracer. File a bug report.");
+template <size_t N>
+void addInputs(Node* n, const char* name, std::array<bool, N> value) {
+  throw std::runtime_error(
+      "Found an unsupported argument type in the JIT tracer. File a bug report.");
 }
 
-inline void ensureUniqueIfOutOfPlaced(const char * name, const at::Tensor& tensor) {
+inline void ensureUniqueIfOutOfPlaced(
+    const char* name,
+    const at::Tensor& tensor) {
   auto& state = getTracingState();
   if (state && state->force_outplace == false) {
     // If we're not converting in-place ops to out-of-place, this check is
@@ -235,7 +266,8 @@ inline void ensureUniqueIfOutOfPlaced(const char * name, const at::Tensor& tenso
     std::stringstream ss;
     ss << "There are " << aliases
        << " live references to the data region being modified when tracing in-place operator "
-       << name << ". This might cause the trace to be incorrect, because all other views "
+       << name
+       << ". This might cause the trace to be incorrect, because all other views "
        << "that also reference this data will not not reflect this change in the trace! "
        << "On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. "
        << "are outputs of torch.split), this might still be safe.";
@@ -243,7 +275,6 @@ inline void ensureUniqueIfOutOfPlaced(const char * name, const at::Tensor& tenso
   }
 }
 
-
 template <
     typename T,
     typename = torch::enable_if_t<
@@ -258,6 +289,10 @@ void addOutput(Node* node, T&&) {
 TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
 TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);
 
-TORCH_API autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim);
+TORCH_API autograd::Variable getSizeOf(
+    const autograd::Variable& var,
+    int64_t dim);
 
-}}} // namespace torch::jit::tracer
+} // namespace tracer
+} // namespace jit
+} // namespace torch
index c906932..ccb35a5 100644 (file)
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/autograd/function_hook.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/assertions.h>
@@ -7,26 +8,26 @@
 #include <torch/csrc/jit/stack.h>
 #include <torch/csrc/jit/type.h>
 #include <torch/csrc/utils/functional.h>
-#include <torch/csrc/utils/functional.h>
-#include <torch/csrc/utils/variadic.h>
 #include <torch/csrc/utils/variadic.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <ATen/Backtrace.h>
 
+#include <cstdint>
+#include <iostream>
 #include <memory>
 #include <mutex>
-#include <vector>
-#include <iostream>
-#include <cstdint>
 #include <unordered_map>
+#include <vector>
 
-namespace torch { namespace jit { namespace tracer {
+namespace torch {
+namespace jit {
+namespace tracer {
 
 using torch::autograd::Variable;
 using variable_list = std::vector<Variable>;
 
-struct TORCH_API TracingState : public std::enable_shared_from_this<TracingState> {
+struct TORCH_API TracingState
+    : public std::enable_shared_from_this<TracingState> {
   TracingState();
   ~TracingState();
 
@@ -44,15 +45,15 @@ struct TORCH_API TracingState : public std::enable_shared_from_this<TracingState
     }
   };
 
-  std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq> value_map;
+  std::unordered_map<WeakTensor, Value*, WeakTensorHasher, WeakTensorEq>
+      value_map;
   std::shared_ptr<Graph> graph;
   bool warn = true;
   bool force_outplace = false;
   std::function<std::string(const Variable& var)> lookup_var_name_fn =
-    [](const Variable& var) {return "";};
+      [](const Variable& var) { return ""; };
 };
 
-
 // This is meant to be used as a thread local place, where we can store extra
 // info that gets lost when we call into ATen from Python bindings. One example
 // for when this happens is when we get an IntList argument with e.g. sizes for
@@ -62,18 +63,18 @@ struct TORCH_API TracingState : public std::enable_shared_from_this<TracingState
 // information. To prevent this, we temporarily stash it in here.
 struct ArgumentStash {
   struct IntListTrace : std::vector<Value*> {
-    IntListTrace(int size)
-      : std::vector<Value*>(size, nullptr) {}
+    IntListTrace(int size) : std::vector<Value*>(size, nullptr) {}
   };
 
   static bool empty() {
     return stash.intlists.empty();
   }
 
-  TORCH_API static void stashIntListElem(const std::string& arg_name,
-                                         size_t size,
-                                         size_t idx,
-                                         const Variable& var);
+  TORCH_API static void stashIntListElem(
+      const std::string& arg_name,
+      size_t size,
+      size_t idx,
+      const Variable& var);
 
   static bool hasIntList(const std::string& arg_name) {
     return stash.intlists.count(arg_name) > 0;
@@ -88,10 +89,11 @@ struct ArgumentStash {
   // Value stashing: Use these methods to stash arguments which correspond
   // to regular Value*'s in the graph. i.e. they don't require special
   // handling like in the case of IntLists
-  TORCH_API static void stashValue(const std::string& arg_name,
-                                   size_t idx,
-                                   const Variable& var,
-                                   const TypePtr& type=nullptr);
+  TORCH_API static void stashValue(
+      const std::string& arg_name,
+      size_t idx,
+      const Variable& var,
+      const TypePtr& type = nullptr);
 
   static bool hasValue(const std::string& arg_name) {
     return stash.values.count(arg_name) > 0;
@@ -103,13 +105,14 @@ struct ArgumentStash {
     return info;
   }
 
-private:
+ private:
   static thread_local ArgumentStash stash;
   std::unordered_map<std::string, IntListTrace> intlists;
   std::unordered_map<std::string, Value*> values;
 };
 
-// Retrieve or set the current tracing state. Returns a nullptr if tracing is disabled.
+// Retrieve or set the current tracing state. Returns a nullptr if tracing is
+// disabled.
 TORCH_API const std::shared_ptr<TracingState>& getTracingState();
 TORCH_API void setTracingState(std::shared_ptr<TracingState> state);
 
@@ -118,20 +121,21 @@ inline bool isTracing() {
 }
 
 using warn_fn_type = void (*)(const std::string& msg);
-TORCH_API extern const char * WARN_PYTHON_DATAFLOW;
-TORCH_API extern const char * WARN_CONSTRUCTOR;
-TORCH_API extern const char * WARN_RESIZE;
-TORCH_API void _do_warn(const char * _reason, const char * _kind);
-inline void warn(const char * _reason, const char * _kind=nullptr) {
+TORCH_API extern const char* WARN_PYTHON_DATAFLOW;
+TORCH_API extern const char* WARN_CONSTRUCTOR;
+TORCH_API extern const char* WARN_RESIZE;
+TORCH_API void _do_warn(const char* _reason, const char* _kind);
+inline void warn(const char* _reason, const char* _kind = nullptr) {
   if (const auto& state = getTracingState()) {
-    if (!state->warn) return;
+    if (!state->warn)
+      return;
     _do_warn(_reason, _kind);
   }
 }
 TORCH_API void setWarn(warn_fn_type fn);
 
 struct TORCH_API NoWarn {
-  NoWarn(): state(getTracingState()) {
+  NoWarn() : state(getTracingState()) {
     if (state) {
       prev = state->warn;
       state->warn = false;
@@ -146,4 +150,6 @@ struct TORCH_API NoWarn {
   bool prev;
 };
 
-}}} // namespace torch::jit::tracer
+} // namespace tracer
+} // namespace jit
+} // namespace torch
index a4a88cc..ec81a65 100644 (file)
@@ -1,21 +1,23 @@
 #include <ATen/core/jit_type.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 #define C10_USING(T) using ::c10::T;
-  C10_FORALL_TYPES(C10_USING)
+C10_FORALL_TYPES(C10_USING)
 #undef C10_USING
 
 #define C10_USING(T) using ::c10::T##Ptr;
-  C10_FORALL_TYPES(C10_USING)
+C10_FORALL_TYPES(C10_USING)
 #undef C10_USING
 
 using ::c10::Type;
-using ::c10::TypePtr;
 using ::c10::TypeEnv;
+using ::c10::TypePtr;
 
 using ::c10::getTypePtr;
-using ::c10::TypeKind;
 using ::c10::MatchTypeReturn;
+using ::c10::TypeKind;
 
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index 29d0788..aab1cef 100644 (file)
@@ -1,17 +1,19 @@
 #pragma once
 #include <ATen/ATen.h>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 // a wrapper to mark places where we expect all the at::Tensors to be
 // variables
 struct variable_tensor_list : public std::vector<at::Tensor> {
   variable_tensor_list() = default;
-  template<class InputIt>
+  template <class InputIt>
   variable_tensor_list(InputIt first, InputIt last)
-  : std::vector<at::Tensor>(first, last) {}
-  explicit variable_tensor_list(std::vector<at::Tensor> && tensor)
-  : std::vector<at::Tensor>(std::move(tensor)) {}
+      : std::vector<at::Tensor>(first, last) {}
+  explicit variable_tensor_list(std::vector<at::Tensor>&& tensor)
+      : std::vector<at::Tensor>(std::move(tensor)) {}
 };
 
-}}
+} // namespace jit
+} // namespace torch