From 57eda692192ed4bcd291f09bf5b2e409c3031db6 Mon Sep 17 00:00:00 2001
From: Shiyan Deng
Date: Tue, 14 Sep 2021 12:25:45 -0700
Subject: [PATCH] [fx2trt] fix elementwise op converter with one operand being
 a literal with a different type (#65004)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65004

If we have some code like `torch.add(x, 1)` where x is a float tensor, then
conversion would fall apart because currently we add a constant layer of int32
dtype for `1`, while we actually need float dtype.

This diff adds an arg to `get_trt_tensor` which specifies the dtype of the
constant layer we create. Also, start adding docstrings to functions.

Reviewed By: yinghai

Differential Revision: D30852156

fbshipit-source-id: 650ce72d2794093a4616e640ea503dcc1c6b2bc4
---
 .../fx2trt/converters/acc_ops_converters.py | 97 ++++++++++++++++------
 1 file changed, 73 insertions(+), 24 deletions(-)

diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index 46a6a51..f555456 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -48,38 +48,24 @@ def get_axes_for_reduce_op(dim, has_implicit_batch_dimension):
     return axes
 
 
-def create_constant(network, tensor, name, squeeze_vector=True):
-    """
-    Args:
-        squeeze_vector: if set to True, we'll squeeze a vector of shape (1, ..1, n) to (n,)
-        and rely on broadcasting to expand the dimensions as needed
-    """
+def create_constant(network, tensor, name, dtype):
     if isinstance(tensor, int):
         tensor = torch.IntTensor([tensor])
 
     if isinstance(tensor, float):
         tensor = torch.Tensor([tensor])
 
-    shape = tuple(tensor.shape)
-    if squeeze_vector:
-        # Remove all preceding 1s as they can be re-inserted later during broadcasting.
-        num_preceding_ones = 0
-        for j in range(len(shape)):
-            if int(shape[j]) == 1:
-                num_preceding_ones += 1
-            else:
-                break
+    if dtype:
+        tensor = tensor.to(dtype)
 
-        # If shape is all 1s, we want last digit.
-        shape = shape[num_preceding_ones:] if num_preceding_ones < len(shape) else (1,)
-    constant = network.add_constant(shape, to_numpy(tensor))
+    constant = network.add_constant(tensor.shape, to_numpy(tensor))
     constant.name = name
     return constant.get_output(0)
 
 
-def get_trt_tensor(network, input_val, name, squeeze_vector=True):
+def get_trt_tensor(network, input_val, name, dtype=None):
     if isinstance(input_val, (torch.Tensor, int, float)):
-        return create_constant(network, input_val, name, squeeze_vector)
+        return create_constant(network, input_val, name, dtype)
     elif not isinstance(input_val, trt.tensorrt.ITensor):
         raise RuntimeError(
             f"Received input {input_val} of name {name} that "
@@ -113,6 +99,24 @@ def append_ones(network, input, name, num_prepend_ones):
 
 
 def broadcast(network, a, b, a_name, b_name, preset_diff=0):
+    """
+    Broadcast two TensorRT tensors to the same number of dimensions by
+    prepending 1s to the tensor that has fewer dimensions.
+
+    Args:
+        network: TensorRT network object.
+        a: A TensorRT tensor.
+        b: A TensorRT tensor.
+        a_name: Name of tensor a.
+        b_name: Name of tensor b.
+        preset_diff: The difference in number of dimensions after broadcast.
+            A positive number means that after broadcast, tensor `a` will have
+            `preset_diff` more dimensions than `b`. This is used in matmul,
+            since we need to broadcast tensors but not always to the same
+            number of dimensions. The reason is that matmul supports Matrix
+            x Vector, and in that case the broadcasted vector should have
+            one fewer dimension than the matrix tensor.
+    """
     a_shape = tuple(a.shape)
     b_shape = tuple(b.shape)
 
@@ -125,8 +129,51 @@ def broadcast(network, a, b, a_name, b_name, preset_diff=0):
     return a, b
 
 def add_binary_elementwise_layer(network, lhs_val, rhs_val, op_type, name):
-    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs")
-    rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs")
+    """
+    This function adds a TensorRT elementwise layer. We allow at most one
+    operand to not be a trt tensor; otherwise, the model should be const
+    folded first. If one operand is not a trt tensor, we turn it into a trt
+    constant layer with the same dtype as the other trt tensor. Then we
+    broadcast these two inputs to the same number of dimensions.
+
+    Limitation:
+        If we are using implicit batch dim mode, the operand that is not a trt
+        tensor is not allowed to have a larger rank than the trt tensor operand.
+
+    Args:
+        network: TensorRT network object.
+        lhs_val: Left operand of the binary operation. Could be a TensorRT
+            tensor, a PyTorch tensor or a simple value.
+        rhs_val: Right operand of the binary operation. Similar to lhs_val.
+        op_type: Type of the TensorRT elementwise binary operation.
+        name: The name we want to assign to the created TensorRT layer.
+
+    Returns:
+        The output of the TensorRT elementwise layer.
+    """
+    dtype = None
+    is_lhs_trt_tensor = False
+    is_rhs_trt_tensor = False
+    if isinstance(lhs_val, trt.tensorrt.ITensor):
+        dtype = torch_dtype_from_trt(lhs_val.dtype)
+        is_lhs_trt_tensor = True
+    if isinstance(rhs_val, trt.tensorrt.ITensor):
+        dtype = torch_dtype_from_trt(rhs_val.dtype)
+        is_rhs_trt_tensor = True
+    if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
+        raise RuntimeError(f"Both operands of the binary elementwise op {name} "
+                           "are constant. In this case, please consider constant folding the model first.")
+
+    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", dtype)
+    rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", dtype)
+
+    # Check the limitation in the doc string.
+    if network.has_implicit_batch_dimension:
+        if is_lhs_trt_tensor and not is_rhs_trt_tensor:
+            assert len(lhs_val.shape) >= len(rhs_val.shape)
+        elif not is_lhs_trt_tensor and is_rhs_trt_tensor:
+            assert len(rhs_val.shape) >= len(lhs_val.shape)
+
     lhs_val, rhs_val = broadcast(
         network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
     )
@@ -932,7 +979,7 @@ def acc_ops_topk(network, target, args, kwargs, name):
         input_val, operation, k, get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension)
     )
     layer.name = name
-    return (layer.get_output(0), layer.get_output(1))
+    return layer.get_output(0), layer.get_output(1)
 
 @tensorrt_converter(acc_ops.adaptive_avg_pool2d)
 def acc_ops_adaptive_avg_pool2d(network, target, args, kwargs, name):
@@ -1034,7 +1081,9 @@ def acc_ops_reshape(network, target, args, kwargs, name):
             s = append_ones(network, s, f"{name}_{i}", 1)
             trt_shape.append(s)
         else:
-            trt_shape.append(get_trt_tensor(network, s, f"{name}_{i}"))
+            trt_shape.append(
+                get_trt_tensor(network, s, f"{name}_{i}")
+            )
 
     shape_layer = network.add_concatenation(inputs=trt_shape)
     shape_layer.axis = 0
-- 
2.7.4
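The core of the fix, stripped of the TensorRT plumbing, is a dtype-propagation rule: when only one operand of an elementwise op is already a tensor, the constant built for the literal operand should take that tensor's dtype instead of defaulting to int32 for an int literal. Below is a minimal, self-contained sketch of that rule in plain PyTorch; the helper names `make_constant` and `binary_elementwise` are illustrative stand-ins, not functions from this patch.

    import torch

    def make_constant(value, dtype=None):
        # Wrap a Python scalar in a tensor. Without an explicit dtype, an int
        # literal becomes int32 (the old, problematic behavior).
        if isinstance(value, int):
            t = torch.tensor([value], dtype=torch.int32)
        elif isinstance(value, float):
            t = torch.tensor([value], dtype=torch.float32)
        else:
            t = value
        # The fix: cast the constant to the dtype of the other operand when given.
        return t.to(dtype) if dtype is not None else t

    def binary_elementwise(lhs, rhs, op):
        # At most one operand may be a plain literal; otherwise there is nothing
        # to convert and the expression should be constant folded beforehand.
        if not isinstance(lhs, torch.Tensor) and not isinstance(rhs, torch.Tensor):
            raise RuntimeError("Both operands are constants; constant fold the model first.")
        # Take the dtype from whichever operand is already a tensor.
        dtype = lhs.dtype if isinstance(lhs, torch.Tensor) else rhs.dtype
        lhs = lhs if isinstance(lhs, torch.Tensor) else make_constant(lhs, dtype)
        rhs = rhs if isinstance(rhs, torch.Tensor) else make_constant(rhs, dtype)
        return op(lhs, rhs)

    x = torch.rand(2, 3)                       # float32 tensor
    out = binary_elementwise(x, 1, torch.add)  # literal 1 is built as float32, not int32
    assert out.dtype == torch.float32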
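The `preset_diff` argument documented in the `broadcast` docstring boils down to simple rank arithmetic: after prepending 1s, tensor `a` should end up with exactly `preset_diff` more dimensions than `b`. A small sketch of that arithmetic follows, with `num_prepend_ones_needed` as a hypothetical helper (the patch itself performs the actual padding via `append_ones`).

    def num_prepend_ones_needed(a_ndim, b_ndim, preset_diff=0):
        # How many leading 1s each operand needs so that, afterwards,
        # a has exactly `preset_diff` more dimensions than b.
        diff = a_ndim - preset_diff - b_ndim
        return max(-diff, 0), max(diff, 0)  # (ones for a, ones for b)

    # Plain elementwise case: pad the lower-rank operand up to the higher rank.
    assert num_prepend_ones_needed(4, 2) == (0, 2)
    # Matrix x Vector matmul: the vector stays one dimension short of the matrix.
    assert num_prepend_ones_needed(3, 1, preset_diff=1) == (0, 1)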