Summary:
Fix https://github.com/pytorch/pytorch/issues/14104
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15832
Reviewed By: bddppq
Differential Revision:
D13598332
Pulled By: yinghai
fbshipit-source-id:
3302ac47928974f49353c5da8af440e5c1716c22
if value_info.name in initialized:
continue
shape = list(d.dim_value for d in value_info.type.tensor_type.shape.dim)
- ws.FeedBlob(value_info.name, np.ones(shape), device_option)
+ ws.FeedBlob(
+ value_info.name,
+ np.ones(shape, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[value_info.type.tensor_type.elem_type]),
+ device_option)
@staticmethod
def optimize_onnx(input, init=False, predict=False):