Improve type handling in PyTorch frontend (#5834)
authorThomas Viehmann <tv.code@beamnet.de>
Mon, 22 Jun 2020 13:33:04 +0000 (15:33 +0200)
committerGitHub <noreply@github.com>
Mon, 22 Jun 2020 13:33:04 +0000 (19:03 +0530)
commit4eb49f04f1a1eaf93fa2aa67a533e752c45d96f8
tree6c2651f3a746a001862be05e7026a58df986c3fa
parent8942b78387e333f2bf5679e2acd6e5916a159dd5
Improve type handling in PyTorch frontend (#5834)

* Improve type handling in PyTorch frontend

- Use type information from graph for inputs if available. Check
  against shape information from graph if available.
- Allow user to set default dtype (default to float32 for sanity and
  compatibility).
- Implement type promotion to follow PyTorch mechanism. This includes
  fixing the handling of many "Scalar" overloads in PyTorch binary ops.
- Fix arange/linspace type semantics.
- Added support for traced functions. (Because it really is about the
  "self" input handling.)

Aside from adding an optional default_dtype keyword argument, this does not
change the signature/requirements of from_pytorch.

* Fix scalar detection using numpy.isscalar

and address other review comments. Thank you @siju-samuel

* refine test criteron on qnn_test::test_serialized_modules, fix bool conversion of const
python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/qnn_test.py
tests/python/frontend/pytorch/test_forward.py