From: zrphercule Date: Fri, 11 Jan 2019 18:45:47 +0000 (-0800) Subject: Add scalar_type_to_pytorch_type dict in ONNX symbolic X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1897 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c9d7ead0c462f4a0b21e8717fc564e43b2827738;p=platform%2Fupstream%2Fpytorch.git Add scalar_type_to_pytorch_type dict in ONNX symbolic Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15965 Differential Revision: D13637521 Pulled By: zrphercule fbshipit-source-id: 922cadc56f6380f67c14444cff4aa354a87150af --- diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index bbdb761..82e469e 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -1051,6 +1051,19 @@ scalar_name_to_pytorch = { 'int16_t': 'Short', } +# This indicates each scalar type's corresponding +# torch type. Related source: +# https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h +scalar_type_to_pytorch_type = [ + torch.uint8, # 0 + torch.int8, # 1 + torch.short, # 2 + torch.int, # 3 + torch.int64, # 4 + torch.half, # 5 + torch.float, # 6 + torch.double, # 7 +] def _cast_func_template(to_i, g, input, non_blocking): return g.op("Cast", input, to_i=to_i)