Don't add cast for TF batch norm when type isn't changing (#5731)
authorTrevor Morris <trevmorr@amazon.com>
Mon, 8 Jun 2020 23:43:28 +0000 (16:43 -0700)
committerGitHub <noreply@github.com>
Mon, 8 Jun 2020 23:43:28 +0000 (05:13 +0530)
python/tvm/relay/frontend/tensorflow.py

index 201c6ba..50987f9 100644 (file)
@@ -1227,7 +1227,7 @@ def _fused_batch_norm():
             attr['data_format'] = attr['data_format'].decode("utf-8")
             if attr['data_format'] == 'NCHW':
                 axis = 1
-        if 'U' in attr:
+        if 'U' in attr and attr['U'].name != attr['T'].name:
             need_cast = True
             inputs[0] = _op.cast(inputs[0], dtype=attr['U'].name)
         # Check if mean and variance are empty