Add scalar_type_to_pytorch_type dict in ONNX symbolic
authorzrphercule <zrphercule@gmail.com>
Fri, 11 Jan 2019 18:45:47 +0000 (10:45 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 18:55:43 +0000 (10:55 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15965

Differential Revision: D13637521

Pulled By: zrphercule

fbshipit-source-id: 922cadc56f6380f67c14444cff4aa354a87150af

torch/onnx/symbolic.py

index bbdb761..82e469e 100644 (file)
@@ -1051,6 +1051,19 @@ scalar_name_to_pytorch = {
     'int16_t': 'Short',
 }
 
+# This indicates each scalar type's corresponding
+# torch type. Related source:
+# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
+scalar_type_to_pytorch_type = [
+        torch.uint8,    # 0
+        torch.int8,     # 1
+        torch.short,    # 2
+        torch.int,      # 3
+        torch.int64,    # 4
+        torch.half,     # 5
+        torch.float,    # 6
+        torch.double,   # 7
+]
 
 def _cast_func_template(to_i, g, input, non_blocking):
     return g.op("Cast", input, to_i=to_i)