[RELAY][FRONTEND][TF] Fix FuseBatchNorm output cast error if need_cast is True (...
authorhcyang <yhcvb@126.com>
Wed, 19 Feb 2020 06:33:16 +0000 (14:33 +0800)
committerGitHub <noreply@github.com>
Wed, 19 Feb 2020 06:33:15 +0000 (22:33 -0800)
python/tvm/relay/frontend/tensorflow.py

index f920682..587b076 100644 (file)
@@ -897,6 +897,7 @@ def _fused_batch_norm():
                       disables=['momentum'])(inputs, attr)
 
         if need_cast:
+            out = _expr.TupleGetItem(out.astuple(), 0)
             out = _op.cast(out, dtype=attr['T'].name)
         return out
     return _impl