Use AT_WARN for warnings in the JIT (#14770)
authorAdam Paszke <adam.paszke@gmail.com>
Wed, 5 Dec 2018 08:07:51 +0000 (00:07 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 08:16:09 +0000 (00:16 -0800)
Summary:
Previously their implementation dispatched to prim::Print, which kept
printing the warnings.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14770

Differential Revision: D13327629

Pulled By: suo

fbshipit-source-id: b9913f533d4530eb7c29146c39981ba7f72b6b68

aten/src/ATen/core/interned_strings.h
test/expect/TestJit.test_warnings.expect
torch/csrc/jit/passes/common_subexpression_elimination.cpp
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/builtin_functions.cpp

index ec0e045..78ad29c 100644 (file)
@@ -73,6 +73,7 @@ namespace c10 {
   _(prim, NoneGenerator)           \
   _(prim, MMTreeReduce)            \
   _(prim, MMBatchSide)             \
+  _(aten, warn)                    \
   _(aten, floordiv)                \
   _(aten, __round_to_zero_floordiv)\
   _(prim, fork)                    \
index 60fd27b..c4ab59d 100644 (file)
@@ -5,7 +5,7 @@ graph(%x : Tensor) {
   %4 : bool = prim::TensorToBool(%3)
    = prim::If(%4)
     block0() {
-       = prim::Print(%1)
+       = aten::warn(%1, %2)
       -> ()
     }
     block1() {
index 72b1501..fe02712 100644 (file)
@@ -23,9 +23,9 @@ void EliminateCommonSubexpression(
   std::unordered_set<Node*, HashNode, EqualNode> subexprs;
   for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
     auto node = *it;
-    if (node->isNondeterministic() || node->kind() == prim::PythonOp ||
-        node->kind() == prim::Print || aliasDb.hasWriters(node) ||
-        aliasDb.hasWildcard(node)) {
+    if (node->kind() == prim::PythonOp || node->kind() == prim::Print ||
+        node->kind() == aten::warn || node->isNondeterministic() ||
+        aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
       // Do NOT have enough information to do CSE on these nodes.
       continue;
     }
index bd5d5a3..af6dfef 100644 (file)
@@ -181,6 +181,7 @@ class DeadCodeEliminator {
     if (it != memo_.end())
       return it->second;
     bool has_side_effects = node->kind() == prim::Print ||
+        node->kind() == aten::warn ||
         node->kind() == prim::RaiseException ||
         node->kind() == prim::PythonOp ||
         std::any_of(node->blocks().begin(),
index 7ed54c0..301ef80 100644 (file)
@@ -330,6 +330,16 @@ RegisterOperators reg({
           };
         }),
     Operator(
+        FunctionSchema("aten::warn", {Argument("message", StringType::get()), Argument("stacklevel", IntType::get(), c10::nullopt, 2, true)}, {}),
+        [](const Node* node) {
+          return [](Stack& stack) {
+            drop(stack, 1);
+            AT_WARN(pop(stack).toStringRef());
+            return 0;
+          };
+        }),
+
+    Operator(
         "prim::RaiseException(str msg) -> ()",
         [](const Node* node) -> Operation {
           return [](Stack& stack) {
index ca42e68..a6b35ee 100644 (file)
@@ -28,16 +28,6 @@ def div(a : ${Scalar}, b : Tensor) -> Tensor:
   return torch.reciprocal(b) * a
 )SCRIPT");
 
-auto python_builtins_source = R"SCRIPT(
-def warn(string: str):
-  print(string)
-)SCRIPT";
-
-auto python_builtins_source_overloads = R"SCRIPT(
-def warn(string: str, stacklevel: int):
-  print(string)
-)SCRIPT";
-
 auto _ntuple_ops = CodeTemplate(
 R"SCRIPT(
 def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
@@ -84,8 +74,6 @@ private:
       env.s("Scalar", scalar);
       loadSource(scalar_operators_source.format(env));
     }
-    loadSource(python_builtins_source);
-    loadSource(python_builtins_source_overloads);
 
     using str_pair = std::pair<std::string, std::string>;
     const std::vector<str_pair> name_len = {