From: Jerry Zhang Date: Mon, 16 Aug 2021 21:07:43 +0000 (-0700) Subject: [fx2trt] Factor out add_matrix_multiply_layer X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~989 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a12b371f7c0b0180b4661d25c1a2981863d693d1;p=platform%2Fupstream%2Fpytorch.git [fx2trt] Factor out add_matrix_multiply_layer Summary: Factor out the function so that it can be reused in future diffs Test Plan: buck run mode/opt caffe2/torch/fb/fx2trt:test_matmul Reviewed By: 842974287 Differential Revision: D30322823 fbshipit-source-id: 069b945de2c744cdbcca1618b62827692dfb4174 --- diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py index b85ab40..6be0e6f 100644 --- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py +++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py @@ -182,6 +182,31 @@ def add_transpose_layer( layer.name = name return layer.get_output(0) +def add_matrix_multiply_layer(network, input_val, other_val, name): + """ Adds a matrix multiply layer to the TensorRT network + Args: + network: TensorRT Network + input_val: input matrix/vector TensorRT ITensor + other_val: another input matrix/vector TensorRT ITensor + name: Name of the matrix multiply layer + Returns: + output TensorRT ITensor from the matrix multiply layer + """ + input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE + preset_diff = 0 + + if len(input_val.shape) == 1: + preset_diff -= 1 + input_matrix_op = trt.MatrixOperation.VECTOR + + if len(other_val.shape) == 1: + preset_diff += 1 + other_matrix_op = trt.MatrixOperation.VECTOR + + input_val, other_val = broadcast(network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff) + layer = network.add_matrix_multiply(input_val, input_matrix_op, other_val, other_matrix_op) + layer.name = name + return layer.get_output(0) def process_attr(val, num_elem): if not isinstance(val, tuple): @@ -1023,26 +1048,7 @@ def acc_ops_matmul(network, target, args, kwargs, name): f"matmul received input {i} that is not part " "of the TensorRT region!" ) - input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE - preset_diff = 0 - - if len(input_val.shape) == 1: - preset_diff -= 1 - input_matrix_op = trt.MatrixOperation.VECTOR - - if len(other_val.shape) == 1: - preset_diff += 1 - other_matrix_op = trt.MatrixOperation.VECTOR - - input_val, other_val = broadcast( - network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff - ) - layer = network.add_matrix_multiply( - input_val, input_matrix_op, other_val, other_matrix_op - ) - layer.name = name - return layer.get_output(0) - + return add_matrix_multiply_layer(network, input_val, other_val, name) @tensorrt_converter(acc_ops.sigmoid) def acc_ops_sigmoid(network, target, args, kwargs, name):