Add input information in RecordFunction calls (#18717)
authorIlia Cherniavskii <iliacher@fb.com>
Tue, 16 Apr 2019 03:24:10 +0000 (20:24 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 16 Apr 2019 03:28:08 +0000 (20:28 -0700)
Summary:
Add input information into generated RecordFunction calls in
VariableType wrappers, JIT operators and a few more locations
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18717

Differential Revision: D14729156

Pulled By: ilia-cher

fbshipit-source-id: 811ac4cbfd85af5c389ef030a7e82ef454afadec

test/cpp/jit/test_misc.h
tools/autograd/gen_variable_type.py
tools/jit/gen_jit_dispatch.py
tools/jit/templates/register_aten_ops.cpp
torch/csrc/autograd/VariableTypeManual.cpp
torch/csrc/autograd/function.h
torch/csrc/autograd/python_function.cpp
torch/csrc/autograd/record_function.cpp
torch/csrc/autograd/record_function.h
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/register_special_ops.cpp

index f5a637c..fe3fe18 100644 (file)
@@ -576,14 +576,54 @@ void testTopologicalIndex() {
   }
 }
 
-void invokeTestRecordFunction(at::Tensor& t) {
-  autograd::profiler::GetPackedInputsCallback inputs_cb = [t]() {
-    Stack st;
-    pack(st, t);
-    return st;
-  };
-  autograd::profiler::RecordFunction guard("test", inputs_cb);
-  t.add_(torch::ones_like(t));
+at::Tensor invokeTestRecordFunction(at::Tensor& t) {
+  RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
+
+  auto t2 = t.pow(2);
+  return t2;
+}
+
+static const auto invokeTestRecordFunction_JIT = R"JIT(
+  def forward(t):
+    t2 = t.pow(2)
+    return t2
+)JIT";
+
+at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
+  RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
+
+  auto cu = compile(invokeTestRecordFunction_JIT);
+  return cu->get_function("forward")({t}).toTensor();
+}
+
+using TracedTestInputs =
+    std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;
+
+void checkTracedInputs(const TracedTestInputs& inputs) {
+  bool found_test = false;
+  bool found_pow = false;
+  bool found_mul = false;
+  for (const auto& input : inputs) {
+    const auto& fn = std::get<0>(input);
+    const auto& sizes = std::get<1>(input);
+    if (fn == "test") {
+      found_test = true;
+      AT_CHECK(sizes.size() == 1);
+      AT_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
+    } else if (fn == "test::pow") {
+      found_pow = true;
+      AT_CHECK(sizes.size() == 2);
+      AT_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
+      AT_CHECK(sizes[1].empty());
+    } else if (fn.find("::mul") != std::string::npos) {
+      found_mul = true;
+      AT_CHECK(sizes.size() > 1);
+      AT_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
+    }
+  }
+  AT_CHECK(found_test);
+  AT_CHECK(found_pow);
+  AT_CHECK(found_mul);
 }
 
 std::string getFullName(const autograd::profiler::RecordFunction* fn_ptr) {
@@ -599,47 +639,42 @@ std::string getFullName(const autograd::profiler::RecordFunction* fn_ptr) {
   return full_name;
 }
 
-void invokeTestRecordFunctionNested() {
-  autograd::profiler::RecordFunction guard("inner");
-}
-
 void testRecordFunction() {
-  std::vector<std::vector<int64_t>> input_sizes;
+  // [(fn, [[sizes], [sizes], ...]), ...]
+  TracedTestInputs traced_inputs;
   autograd::profiler::pushCallback(
-      [&input_sizes](const autograd::profiler::RecordFunction& fn) {
-        for (const auto& input : fn.inputs()) {
+      [&traced_inputs](const autograd::profiler::RecordFunction& fn) {
+        auto inputs = fn.inputs();
+        std::vector<std::vector<int64_t>> sizes;
+        for (const auto& input : inputs) {
           if (input.isTensor()) {
-            std::vector<int64_t> t = input.toTensor().sizes().vec();
-            input_sizes.push_back(t);
+            sizes.push_back(input.toTensor().sizes().vec());
+          } else if (input.isScalar()){
+            sizes.push_back(std::vector<int64_t>());
           }
         }
-      });
+        traced_inputs.push_back(
+            std::make_tuple(std::string(getFullName(&fn)), sizes));
+      }, [](const autograd::profiler::RecordFunction&) {}, true);
 
   auto t = torch::randn({1, 2, 3}, at::kCPU);
-  invokeTestRecordFunction(t);
+  t.set_requires_grad(true);
+  auto t2 = invokeTestRecordFunction(t);
+  t2.backward();
+  auto eager_inputs = traced_inputs;
+  traced_inputs.clear();
+
+  t = torch::randn({1, 2, 3}, at::kCPU);
+  t.set_requires_grad(true);
+  t2 = invokeTestRecordFunctionJIT(t);
+  t2.backward();
+  auto jit_inputs = traced_inputs;
+  traced_inputs.clear();
 
   autograd::profiler::popCallback();
 
-  AT_CHECK(input_sizes.size() == 1);
-  AT_CHECK(input_sizes[0] == at::IntArrayRef({1, 2, 3}));
-
-  // test nested RecordFunctions
-  std::vector<std::string> nested_names;
-  autograd::profiler::pushCallback(
-      [&nested_names](const autograd::profiler::RecordFunction& fn) {
-        nested_names.push_back(getFullName(&fn));
-      });
-
-  {
-    autograd::profiler::RecordFunction guard("outer");
-    invokeTestRecordFunctionNested();
-    ;
-  }
-
-  autograd::profiler::popCallback();
-  AT_CHECK(nested_names.size() == 2);
-  AT_CHECK(nested_names[0] == "outer");
-  AT_CHECK(nested_names[1] == "outer::inner");
+  checkTracedInputs(eager_inputs);
+  checkTracedInputs(jit_inputs);
 }
 
 void testAutogradProfiler() {
index f0753e0..dd076b9 100644 (file)
@@ -207,7 +207,8 @@ if (${cond}) {
 """)
 
 RECORD_FUNCTION = CodeTemplate("""\
-profiler::RecordFunction profiler("${name}", Function::peek_at_next_sequence_nr());""")
+RECORD_FUNCTION("${name}", std::vector<c10::IValue>({${input_names}}), Function::peek_at_next_sequence_nr());
+""")
 
 SELECT = CodeTemplate("""\
 if (${cond}) {
@@ -847,12 +848,22 @@ def emit_body(declaration):
             return []
         return ['increment_version({});'.format(arg['name']) for arg in differentiable_outputs]
 
+    def check_record_function_input_type(simple_type):
+        return simple_type in ['Tensor', 'Scalar']
+
+    def record_function_input_names():
+        return ', '.join([
+            arg['name'] for arg in declaration['arguments']
+            if check_record_function_input_type(arg['simple_type'])])
+
     env = {}
     combined = nested_dict(env, declaration)
 
     body = []
     if base_name not in DONT_PROFILE:
-        body.append(RECORD_FUNCTION.substitute(combined))
+        input_names = record_function_input_names()
+        body.append(
+            RECORD_FUNCTION.substitute(combined, input_names=input_names))
     if strategy != 'use_type':
         body.extend(unpack_args(env, declaration))
     if requires_derivative:
index 149832b..14b0843 100644 (file)
@@ -147,7 +147,6 @@ auto result_ = (${first}).${name}(${args_with_tensor_options});
 
 CONSTRUCTOR = CodeTemplate("""\
 [](Stack & stack) {
-    autograd::profiler::RecordFunction record("${name}");
     ${lvalues}
     ${call}
     drop(stack, ${num_inputs});
index 0fb2975..2c46480 100644 (file)
@@ -83,7 +83,7 @@ RegisterOperators reg({
   Operator(
       "aten::get_device(Tensor self) -> int",
       [](Stack & stack) {
-          autograd::profiler::RecordFunction record("get_device");
+          RECORD_FUNCTION("get_device", std::vector<c10::IValue>());
           auto result = at::get_device(
               (std::move(peek(stack, 0, 1))).toTensor()
           );
@@ -95,7 +95,7 @@ RegisterOperators reg({
   Operator(
       "aten::storage_offset(Tensor self) -> int",
       [](Stack & stack) {
-          autograd::profiler::RecordFunction record("storage_offset");
+          RECORD_FUNCTION("storage_offset", std::vector<c10::IValue>());
           auto result = ((std::move(peek(stack, 0, 1))).toTensor()).storage_offset();
           drop(stack, 1);
           pack(stack, std::move(result));
@@ -105,7 +105,7 @@ RegisterOperators reg({
   Operator(
       "aten::is_contiguous(Tensor self) -> bool",
       [](Stack & stack) {
-          autograd::profiler::RecordFunction record("is_contiguous");
+          RECORD_FUNCTION("is_contiguous", std::vector<c10::IValue>());
           auto result = ((std::move(peek(stack, 0, 1))).toTensor()).is_contiguous();
           drop(stack, 1);
           pack(stack, std::move(result));
index be5f576..2a7414c 100644 (file)
@@ -316,7 +316,8 @@ Tensor & VariableType::resize_as_(Tensor & self, const Tensor & the_template) co
 }
 
 Tensor VariableType::detach(const Tensor & self) const {
-  profiler::RecordFunction profiler("detach");
+  RECORD_FUNCTION("detach", std::vector<c10::IValue>({self}));
+
   torch::jit::Node* node = nullptr;
   if (jit::tracer::isTracing()) {
     auto& graph = jit::tracer::getTracingState()->graph;
@@ -336,7 +337,8 @@ Tensor VariableType::detach(const Tensor & self) const {
 }
 
 Tensor & VariableType::detach_(Tensor & self) const {
-  profiler::RecordFunction profiler("detach_");
+  RECORD_FUNCTION("detach_", std::vector<c10::IValue>({self}));
+
   torch::jit::Node* node = nullptr;
   if (jit::tracer::isTracing()) {
     auto& graph = jit::tracer::getTracingState()->graph;
index 7c73893..380dfd9 100644 (file)
@@ -112,7 +112,9 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   /// Evaluates the function on the given inputs and returns the result of the
   /// function call.
   variable_list operator()(variable_list&& inputs) {
-    profiler::RecordFunction rec(this);
+    RECORD_FUNCTION(
+        this, std::vector<c10::IValue>(inputs.begin(), inputs.end()));
+
     return apply(std::move(inputs));
   }
 
index 81d7871..50fc160 100644 (file)
@@ -671,8 +671,10 @@ PyObject* process_outputs(PyObject *op_obj, THPFunction* grad_fn, const Unpacked
 PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
 {
   HANDLE_TH_ERRORS
-  torch::autograd::profiler::RecordFunction record(Py_TYPE(self)->tp_name,
-                                                   Function::peek_at_next_sequence_nr());
+  RECORD_FUNCTION(
+    Py_TYPE(self)->tp_name,
+    std::vector<c10::IValue>(),
+    Function::peek_at_next_sequence_nr());
 
   auto info_pair = unpack_input<true>(_inputs);
   auto& unpacked_input = info_pair.first;
@@ -702,8 +704,10 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
 PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
 {
   HANDLE_TH_ERRORS
-  torch::autograd::profiler::RecordFunction record(((PyTypeObject*)cls)->tp_name,
-                                                   Function::peek_at_next_sequence_nr());
+  RECORD_FUNCTION(
+    ((PyTypeObject*)cls)->tp_name,
+    std::vector<c10::IValue>(),
+    Function::peek_at_next_sequence_nr());
 
   THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
   if (!backward_cls) return nullptr;
index 57f83e9..75394b3 100644 (file)
@@ -4,20 +4,21 @@
 namespace torch { namespace autograd { namespace profiler {
 
 namespace {
-bool has_callbacks = false;
 std::vector<RecordFunctionCallback> start_callbacks;
 std::vector<RecordFunctionCallback> end_callbacks;
+size_t callback_needs_inputs = 0;
 thread_local RecordFunction* thread_local_func_ = nullptr;
 }
 
-void pushCallback(RecordFunctionCallback start, RecordFunctionCallback end) {
+void pushCallback(
+    RecordFunctionCallback start,
+    RecordFunctionCallback end,
+    bool needs_inputs) {
   start_callbacks.push_back(start);
   end_callbacks.push_back(end);
-  has_callbacks = true;
-}
-
-void pushCallback(RecordFunctionCallback start) {
-  pushCallback(start, [](const RecordFunction&){});
+  if (callback_needs_inputs > 0 || needs_inputs) {
+    ++callback_needs_inputs;
+  }
 }
 
 void popCallback() {
@@ -26,39 +27,53 @@ void popCallback() {
   }
   start_callbacks.pop_back();
   end_callbacks.pop_back();
-  has_callbacks = !start_callbacks.empty();
+  if (callback_needs_inputs > 0) {
+    --callback_needs_inputs;
+  }
+}
+
+bool hasCallbacks() {
+  return !start_callbacks.empty();
+}
+
+bool needsInputs() {
+  return callback_needs_inputs > 0;
 }
 
-RecordFunction::RecordFunction(Function* fn, GetPackedInputsCallback cb) {
-  if (!has_callbacks) {
+void RecordFunction::before(const char* name, int64_t sequence_nr) {
+  if (!hasCallbacks()) {
     return;
   }
-  fn_ = fn;
-  name_ = StringView(fn->name());
-  sequence_nr_ = fn->sequence_nr();
-  inputs_cb_ = cb;
+  AT_ASSERT(!initialized_);
+  name_ = StringView(name);
+  sequence_nr_ = sequence_nr;
+
+  initialized_ = true;
   processCallbacks();
 }
 
-RecordFunction::RecordFunction(
-    std::string name, int64_t sequence_nr, GetPackedInputsCallback cb) {
-  if (!has_callbacks) {
+void RecordFunction::before(std::string name, int64_t sequence_nr) {
+  if (!hasCallbacks()) {
     return;
   }
+  AT_ASSERT(!initialized_);
   name_ = StringView(std::move(name));
   sequence_nr_ = sequence_nr;
-  inputs_cb_ = cb;
+
+  initialized_ = true;
   processCallbacks();
 }
 
-RecordFunction::RecordFunction(
-    const char* name, int64_t sequence_nr, GetPackedInputsCallback cb) {
-  if (!has_callbacks) {
+void RecordFunction::before(Function* fn, int64_t sequence_nr) {
+  if (!hasCallbacks()) {
     return;
   }
-  name_ = StringView(name);
-  sequence_nr_ = sequence_nr;
-  inputs_cb_ = cb;
+  AT_ASSERT(!initialized_);
+  fn_ = fn;
+  name_ = StringView(fn->name());
+  sequence_nr_ = (sequence_nr >= 0) ? sequence_nr : fn->sequence_nr();
+
+  initialized_ = true;
   processCallbacks();
 }
 
@@ -72,7 +87,7 @@ void RecordFunction::processCallbacks() {
 }
 
 RecordFunction::~RecordFunction() {
-  if (has_callbacks) {
+  if (initialized_) {
     for (const auto& cb : end_callbacks) {
       cb(*this);
     }
index eef1a67..7d25f55 100644 (file)
@@ -26,32 +26,37 @@ struct TORCH_API StringView {
   const char* str_ptr_;
 };
 
-using GetPackedInputsCallback = std::function<std::vector<c10::IValue>()>;
-
 struct TORCH_API RecordFunction {
-  explicit RecordFunction(Function* fn, GetPackedInputsCallback cb = nullptr);
-
-  explicit RecordFunction(
-      std::string name,
-      int64_t current_sequence_nr = -1,
-      GetPackedInputsCallback cb = nullptr);
-
-  explicit RecordFunction(
-      const char* name,
-      int64_t current_sequence_nr = -1,
-      GetPackedInputsCallback cb = nullptr);
-
-  explicit RecordFunction(
-      std::string name,
-      GetPackedInputsCallback cb) : RecordFunction(name, -1, cb) {}
+  // Default constructor is used with before function called afterwards
+  RecordFunction() {}
+
+  // before function initializes RecordFunction members and calls
+  // start callbacks
+  void before(const char* name, int64_t sequence_nr = -1);
+  void before(std::string name, int64_t sequence_nr = -1);
+  void before(Function* fn, int64_t sequence_nr = -1);
+
+  template<typename F>
+  void before(
+      F fn,
+      c10::ArrayRef<c10::IValue> args,
+      int64_t current_sequence_nr = -1) {
+    inputs_ = args.vec();
+    before(fn, current_sequence_nr);
+  }
 
-  explicit RecordFunction(
-      const char* name,
-      GetPackedInputsCallback cb) : RecordFunction(name, -1, cb) {}
+  template<typename F>
+  void before(
+      F fn,
+      std::vector<c10::IValue>&& args,
+      int64_t current_sequence_nr = -1) {
+    inputs_ = std::move(args);
+    before(fn, current_sequence_nr);
+  }
 
+  // Destructor calls end callbacks
   virtual ~RecordFunction();
 
-
   inline Function* func() const {
     return fn_;
   }
@@ -65,10 +70,6 @@ struct TORCH_API RecordFunction {
   }
 
   const std::vector<c10::IValue>& inputs() const {
-    if (inputs_cb_ && !inputs_initialized_) {
-      inputs_ = inputs_cb_();
-      inputs_initialized_ = true;
-    }
     return inputs_;
   }
 
@@ -82,20 +83,33 @@ struct TORCH_API RecordFunction {
   Function* fn_ = nullptr;
   StringView name_;
   int64_t sequence_nr_ = -1;
-
+  std::vector<c10::IValue> inputs_;
   RecordFunction* parent_ = nullptr;
 
-  GetPackedInputsCallback inputs_cb_ = nullptr;
-  mutable bool inputs_initialized_ = false;
-  // initialized lazily by inputs_cb_
-  mutable std::vector<c10::IValue> inputs_;
+  bool initialized_ = false;
 };
 
+TORCH_API bool hasCallbacks();
+TORCH_API bool needsInputs();
+
+// optional argument - function's seq_no
+#define RECORD_FUNCTION(fn, inputs, ...) \
+  torch::autograd::profiler::RecordFunction guard; \
+  if (torch::autograd::profiler::hasCallbacks()) { \
+    if (torch::autograd::profiler::needsInputs()) { \
+      guard.before(fn, inputs, ##__VA_ARGS__); \
+    } else { \
+      guard.before(fn, ##__VA_ARGS__); \
+    } \
+  }
+
 // WARNING: all calls to pushCallback/popCallback are not thread safe and
 // must not overlap with other code execution
 using RecordFunctionCallback = std::function<void(const RecordFunction&)>;
-TORCH_API void pushCallback(RecordFunctionCallback, RecordFunctionCallback);
-TORCH_API void pushCallback(RecordFunctionCallback);
+TORCH_API void pushCallback(
+    RecordFunctionCallback start,
+    RecordFunctionCallback end = [](const RecordFunction&){},
+    bool needs_inputs = false);
 TORCH_API void popCallback();
 
 } // namespace profiler
index 27393d2..31142ea 100644 (file)
@@ -117,7 +117,7 @@ RegisterOperators reg(
          [](const Node* node) {
            const auto key = registerFusion(node);
            return [key](Stack& stack) {
-             autograd::profiler::RecordFunction record("FusionGroup");
+             RECORD_FUNCTION("FusionGroup", std::vector<c10::IValue>());
              runFusion(key, stack);
              return 0;
            };
@@ -660,7 +660,8 @@ RegisterOperators reg(
              return v->uses().size() > 0;
            });
            return [=](Stack& stack) {
-             autograd::profiler::RecordFunction record("chunk");
+             RECORD_FUNCTION("chunk", last(stack, 1));
+
              at::Tensor t;
              pop(stack, t);
              auto result = at::chunk(t, chunks, dim);
index 1979a03..05a4d6d 100644 (file)
@@ -140,7 +140,8 @@ RegisterOperators reg({
     Operator(
         "aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]",
         [](Stack& stack) {
-          autograd::profiler::RecordFunction record("split_with_sizes");
+          RECORD_FUNCTION("split_with_sizes", last(stack, 3));
+
           auto result = at::split_with_sizes(
               (std::move(peek(stack, 0, 3))).toTensor(),
               (std::move(peek(stack, 1, 3))).toIntList()->elements(),
@@ -155,7 +156,8 @@ RegisterOperators reg({
     Operator(
         "aten::size(Tensor self) -> int[]",
         [](Stack& stack) {
-          autograd::profiler::RecordFunction record("sizes");
+          RECORD_FUNCTION("size", last(stack, 1));
+
           auto t = std::move(pop(stack)).toTensor();
           pack(stack, t.sizes().vec());
           return 0;
@@ -163,7 +165,8 @@ RegisterOperators reg({
     Operator(
         "aten::list_with_default(int[] list, int[] defaults) -> int[]",
         [](Stack& stack) {
-          autograd::profiler::RecordFunction record("sizes");
+          RECORD_FUNCTION("sizes", last(stack, 2));
+
           auto list = peek(stack, 0, 2).toIntListRef();
           auto defaults = peek(stack, 1, 2).toIntListRef();
           drop(stack, 2);