Fix gelu in PyTorch frontend, tighten numerical checks (#5763)
authorThomas Viehmann <tv.code@beamnet.de>
Thu, 11 Jun 2020 12:10:27 +0000 (14:10 +0200)
committerGitHub <noreply@github.com>
Thu, 11 Jun 2020 12:10:27 +0000 (21:10 +0900)
Previously, the PyTorch frontend approximated gelu with fastgelu.
To provide a more faithful conversion, we implement gelu instead.

We also tighten the numerical comparisons between PyTorch and
TVM-from-PyTorch to 1e-5. The object detection models need an
increased tolerance of 1e-4 to pass.

I had to throw in a few fixes for missing conversions
(probably due to working with very new PyTorch).

I must admit the GoogLeNet/NasNet test didn't run on my machine,
probably due to problems at my end.

python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 7b96530..380388a 100644 (file)
@@ -481,7 +481,10 @@ def _full():
             msg = "Data type %s could not be parsed in zeros op" % (type(data))
             raise AssertionError(msg)
 
-        dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
+        if inputs[2] is not None: # dtype given
+            dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
+        else:
+            dtype = data.type_annotation.dtype
 
         return _op.full(_expr.const(fill_value), shape, dtype=dtype)
     return _impl
@@ -567,14 +570,13 @@ def _celu():
 
 def _gelu():
     def _impl(inputs, input_types):
-        import math
         data = inputs[0]
-
-        def _pow3(x):
-            return x * x * x
-        return _expr.const(0.5) * data * (_expr.const(1.0) +
-                                          _op.tanh(_expr.const(math.sqrt(2.0 / math.pi)) *
-                                                   (data + _expr.const(0.044715) * _pow3(data))))
+        # gelu is data  * normcdf(data)
+        # normcdf expressed as erf because we don't currently have that intrinsic
+        # note that there is also a fastgelu variant approximating normcdf
+        # with tanh and third order polynomials, but this is "true" gelu
+        return data * (_expr.const(0.5) +
+                       _op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5))
     return _impl
 
 def _selu():
@@ -1839,6 +1841,7 @@ def _get_convert_map(prelude):
         "aten::Int"                             : _int(),
         "prim::NumToTensor"                     : _numtotensor(),
         "prim::ImplicitTensorToNum"             : _tensortonum(),
+        "aten::ScalarImplicit"                  : _tensortonum(),
         "aten::constant_pad_nd"                 : _pad("constant"),
         "aten::reflection_pad1d"                : _pad("reflect"),
         "aten::reflection_pad2d"                : _pad("reflect"),
@@ -1877,6 +1880,7 @@ def _get_convert_map(prelude):
         "aten::floor"                           : _unary("floor"),
         "aten::round"                           : _unary("round"),
         "aten::isfinite"                        : _unary("isfinite"),
+        "aten::isinf"                           : _unary("isinf"),
         "aten::isnan"                           : _unary("isnan"),
         "aten::clamp"                           : _clamp(),
         "aten::detach"                          : _identity(),
index 3c7ff4f..c9c76be 100644 (file)
@@ -135,7 +135,8 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
 
 def verify_model(model_name, input_data=[],
                  custom_convert_map={},
-                 ctx_list=ctx_list()):
+                 ctx_list=ctx_list(),
+                 rtol=1e-5, atol=1e-5):
     """Assert that the output of a compiled model matches with that of its
     baseline."""
     if isinstance(model_name, str):
@@ -190,7 +191,7 @@ def verify_model(model_name, input_data=[],
 
                 assert_shapes_match(baseline_output, compiled_output)
                 tvm.testing.assert_allclose(baseline_output, compiled_output,
-                                            rtol=1e-3, atol=1e-3)
+                                            rtol=rtol, atol=atol)
 
     del model_name
     del baseline_model
@@ -1216,35 +1217,35 @@ def test_conv3d_transpose():
 # Model tests
 def test_resnet18():
     torch.set_grad_enabled(False)
-    verify_model("resnet18")
+    verify_model("resnet18", atol=1e-4, rtol=1e-4)
 
 def test_squeezenet1_0():
     torch.set_grad_enabled(False)
-    verify_model("squeezenet1_0")
+    verify_model("squeezenet1_0", atol=1e-4, rtol=1e-4)
 
 def test_squeezenet1_1():
     torch.set_grad_enabled(False)
-    verify_model("squeezenet1_1")
+    verify_model("squeezenet1_1", atol=1e-4, rtol=1e-4)
 
 def test_densenet121():
     torch.set_grad_enabled(False)
-    verify_model("densenet121")
+    verify_model("densenet121", atol=1e-4, rtol=1e-4)
 
 def test_inception_v3():
     torch.set_grad_enabled(False)
-    verify_model("inception_v3")
+    verify_model("inception_v3", atol=1e-4, rtol=1e-4)
 
 def test_googlenet():
     torch.set_grad_enabled(False)
-    verify_model("googlenet")
+    verify_model("googlenet", atol=1e-4, rtol=1e-4)
 
 def test_mnasnet0_5():
     torch.set_grad_enabled(False)
-    verify_model("mnasnet0_5")
+    verify_model("mnasnet0_5", atol=1e-4, rtol=1e-4)
 
 def test_mobilenet_v2():
     torch.set_grad_enabled(False)
-    verify_model("mobilenet_v2")
+    verify_model("mobilenet_v2", atol=1e-4, rtol=1e-4)
 
 """
 #TODO: Fix VGG and AlexNet issues (probably due to pooling)
@@ -1305,19 +1306,19 @@ def test_segmentaton_models():
 
     inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]
 
-    verify_model(SegmentationModelWrapper(fcn.eval()), inp)
+    verify_model(SegmentationModelWrapper(fcn.eval()), inp, atol=1e-4, rtol=1e-4)
 
     # depthwise + dilated covolution not supported on x86
     # see https://github.com/apache/incubator-tvm/issues/4962
     cuda_ctx = ("cuda", tvm.gpu(0))
     if cuda_ctx[1].exist:
-        verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx])
+        verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx], atol=1e-4, rtol=1e-4)
 
 
 def test_3d_models():
     input_shape = (1, 3, 4, 56, 56)
     resnet3d = torchvision.models.video.r3d_18(pretrained=True).eval()
-    verify_model(resnet3d, [torch.rand(input_shape)])
+    verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4)
 
 
 def verify_script_model(pt_model, ishapes):