From f89de64796104ca3a48a465a62017edea1cd7506 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 5 Dec 2018 00:07:51 -0800 Subject: [PATCH] Use AT_WARN for warnings in the JIT (#14770) Summary: Previously their implementation dispatched to prim::Print, which kept printing the warnings. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14770 Differential Revision: D13327629 Pulled By: suo fbshipit-source-id: b9913f533d4530eb7c29146c39981ba7f72b6b68 --- aten/src/ATen/core/interned_strings.h | 1 + test/expect/TestJit.test_warnings.expect | 2 +- torch/csrc/jit/passes/common_subexpression_elimination.cpp | 6 +++--- torch/csrc/jit/passes/dead_code_elimination.cpp | 1 + torch/csrc/jit/register_prim_ops.cpp | 10 ++++++++++ torch/csrc/jit/script/builtin_functions.cpp | 12 ------------ 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index ec0e045..78ad29c 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -73,6 +73,7 @@ namespace c10 { _(prim, NoneGenerator) \ _(prim, MMTreeReduce) \ _(prim, MMBatchSide) \ + _(aten, warn) \ _(aten, floordiv) \ _(aten, __round_to_zero_floordiv)\ _(prim, fork) \ diff --git a/test/expect/TestJit.test_warnings.expect b/test/expect/TestJit.test_warnings.expect index 60fd27b..c4ab59d 100644 --- a/test/expect/TestJit.test_warnings.expect +++ b/test/expect/TestJit.test_warnings.expect @@ -5,7 +5,7 @@ graph(%x : Tensor) { %4 : bool = prim::TensorToBool(%3) = prim::If(%4) block0() { - = prim::Print(%1) + = aten::warn(%1, %2) -> () } block1() { diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp index 72b1501..fe02712 100644 --- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp +++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp @@ -23,9 +23,9 @@ void EliminateCommonSubexpression( std::unordered_set subexprs; for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) { auto node = *it; - if (node->isNondeterministic() || node->kind() == prim::PythonOp || - node->kind() == prim::Print || aliasDb.hasWriters(node) || - aliasDb.hasWildcard(node)) { + if (node->kind() == prim::PythonOp || node->kind() == prim::Print || + node->kind() == aten::warn || node->isNondeterministic() || + aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) { // Do NOT have enough information to do CSE on these nodes. continue; } diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index bd5d5a3..af6dfef 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -181,6 +181,7 @@ class DeadCodeEliminator { if (it != memo_.end()) return it->second; bool has_side_effects = node->kind() == prim::Print || + node->kind() == aten::warn || node->kind() == prim::RaiseException || node->kind() == prim::PythonOp || std::any_of(node->blocks().begin(), diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 7ed54c0..301ef80 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -330,6 +330,16 @@ RegisterOperators reg({ }; }), Operator( + FunctionSchema("aten::warn", {Argument("message", StringType::get()), Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)}, {}), + [](const Node* node) { + return [](Stack& stack) { + drop(stack, 1); + AT_WARN(pop(stack).toStringRef()); + return 0; + }; + }), + + Operator( "prim::RaiseException(str msg) -> ()", [](const Node* node) -> Operation { return [](Stack& stack) { diff --git a/torch/csrc/jit/script/builtin_functions.cpp b/torch/csrc/jit/script/builtin_functions.cpp index ca42e68..a6b35ee 100644 --- a/torch/csrc/jit/script/builtin_functions.cpp +++ b/torch/csrc/jit/script/builtin_functions.cpp @@ -28,16 +28,6 @@ def div(a : ${Scalar}, b : Tensor) -> Tensor: return torch.reciprocal(b) * a )SCRIPT"); -auto python_builtins_source = R"SCRIPT( -def warn(string: str): - print(string) -)SCRIPT"; - -auto python_builtins_source_overloads = R"SCRIPT( -def warn(string: str, stacklevel: int): - print(string) -)SCRIPT"; - auto _ntuple_ops = CodeTemplate( R"SCRIPT( def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]: @@ -84,8 +74,6 @@ private: env.s("Scalar", scalar); loadSource(scalar_operators_source.format(env)); } - loadSource(python_builtins_source); - loadSource(python_builtins_source_overloads); using str_pair = std::pair; const std::vector name_len = { -- 2.7.4