Caffe2 - Add flag to fails if float point exceptions is detected in operator runs...
authorDuc Ngo <duc@fb.com>
Sat, 16 Mar 2019 19:21:55 +0000 (12:21 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 16 Mar 2019 19:28:05 +0000 (12:28 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18040

Add flag to fails if float point exceptions is detected in operator runs

Sample exception

Exception [enforce fail at operator.h:837] !std::fetestexcept(FE_DIVBYZERO). Division by zero floating point exception (FE_DIVBYZERO) reported.
Error from operator:
input: "1" input: "0" output: "out" name: "" type: "Div"

Reviewed By: jspark1105

Differential Revision: D14467731

fbshipit-source-id: fad030b1d619a5a661ff2114edb947e4562cecdd

caffe2/core/operator.cc
caffe2/core/operator.h
caffe2/python/operator_fp_exceptions_test.py [new file with mode: 0644]

index 170a67e..d7b66bc 100644 (file)
@@ -25,6 +25,11 @@ C10_DEFINE_bool(
     false,
     "If set, disable implicit engine preferences. This is useful for unit "
     "testing and debugging cases.");
+C10_DEFINE_bool(
+    caffe2_operator_throw_if_fp_exceptions,
+    false,
+    "If set, throws if floating point exceptions (FE_DIVBYZERO, FE_INVALID, "
+    "FE_OVERFLOW) are detected when running any operator.");
 
 namespace caffe2 {
 
index 40fdb51..17a0392 100644 (file)
@@ -2,6 +2,7 @@
 #define CAFFE2_CORE_OPERATOR_H_
 
 #include <array>
+#include <cfenv>
 #include <climits>
 #include <cstddef>
 #include <exception>
@@ -27,6 +28,8 @@
 #include <ATen/core/function_schema.h>
 #include <ATen/core/ivalue.h>
 
+C10_DECLARE_bool(caffe2_operator_throw_if_fp_exceptions);
+
 namespace caffe2 {
 
 class CAFFE2_API OperatorBase;
@@ -823,7 +826,22 @@ class Operator : public OperatorBase {
       StartAllObservers();
 
       context_.SwitchToDevice(stream_id);
+
+      if (FLAGS_caffe2_operator_throw_if_fp_exceptions) {
+        std::feclearexcept(FE_ALL_EXCEPT);
+      }
       bool result = RunOnDevice();
+      if (FLAGS_caffe2_operator_throw_if_fp_exceptions) {
+        CAFFE_ENFORCE(
+            !std::fetestexcept(FE_DIVBYZERO),
+            "Division by zero floating point exception (FE_DIVBYZERO) reported.");
+        CAFFE_ENFORCE(
+            !std::fetestexcept(FE_INVALID),
+            "Invalid floating point exception (FE_INVALID) reported.");
+        CAFFE_ENFORCE(
+            !std::fetestexcept(FE_OVERFLOW),
+            "Overflow floating point exception (FE_OVERFLOW) reported.");
+      }
       if (!result) {
         this->RecordLastFailedOpNetPosition();
       }
diff --git a/caffe2/python/operator_fp_exceptions_test.py b/caffe2/python/operator_fp_exceptions_test.py
new file mode 100644 (file)
index 0000000..6e08f92
--- /dev/null
@@ -0,0 +1,41 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from caffe2.python import core, workspace
+from caffe2.proto import caffe2_pb2
+from caffe2.python.test_util import TestCase
+
+import numpy as np
+import unittest
+
+
+def setThrowIfFpExceptions(enabled):
+    core.GlobalInit(["caffe2", "--caffe2_operator_throw_if_fp_exceptions=%d" % (1 if enabled else 0)])
+
+
+class OperatorFPExceptionsTest(TestCase):
+    def test_fp_exception_divbyzero(self):
+        # This test asserts the followings
+        # - If flag caffe2_operator_throw_if_fp_exceptions is set,
+        # floating point exceptions will be thrown
+        # - If flag caffe2_operator_throw_if_fp_exceptions is not set,
+        # floating point exceptions will not be thrown
+        workspace.blobs["0"] = np.array([0.0], dtype=np.float32)
+        workspace.blobs["1"] = np.array([1.0], dtype=np.float32)
+
+        net = core.Net("test_fp")
+        net.Div(["1", "0"], "out")
+
+        for throw_if_fp_exceptions in (True, False):
+            setThrowIfFpExceptions(throw_if_fp_exceptions)
+            exception_raised = False
+            try:
+                workspace.RunNetOnce(net)
+            except Exception as e:
+                exception_raised = True
+            self.assertEquals(exception_raised, throw_if_fp_exceptions)
+
+
+if __name__ == '__main__':
+    unittest.main()