[fx2trt] Factor out add_matrix_multiply_layer
authorJerry Zhang <jerryzh@fb.com>
Mon, 16 Aug 2021 21:07:43 +0000 (14:07 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 21:13:37 +0000 (14:13 -0700)
Summary: Factor out the function so that it can be reused in future diffs

Test Plan: buck run mode/opt caffe2/torch/fb/fx2trt:test_matmul

Reviewed By: 842974287

Differential Revision: D30322823

fbshipit-source-id: 069b945de2c744cdbcca1618b62827692dfb4174

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py

index b85ab40..6be0e6f 100644 (file)
@@ -182,6 +182,31 @@ def add_transpose_layer(
     layer.name = name
     return layer.get_output(0)
 
+def add_matrix_multiply_layer(network, input_val, other_val, name):
+    """ Adds a matrix multiply layer to the TensorRT network
+    Args:
+        network: TensorRT Network
+        input_val: input matrix/vector TensorRT ITensor
+        other_val: another input matrix/vector TensorRT ITensor
+        name: Name of the matrix multiply layer
+    Returns:
+        output TensorRT ITensor from the matrix multiply layer
+    """
+    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
+    preset_diff = 0
+
+    if len(input_val.shape) == 1:
+        preset_diff -= 1
+        input_matrix_op = trt.MatrixOperation.VECTOR
+
+    if len(other_val.shape) == 1:
+        preset_diff += 1
+        other_matrix_op = trt.MatrixOperation.VECTOR
+
+    input_val, other_val = broadcast(network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff)
+    layer = network.add_matrix_multiply(input_val, input_matrix_op, other_val, other_matrix_op)
+    layer.name = name
+    return layer.get_output(0)
 
 def process_attr(val, num_elem):
     if not isinstance(val, tuple):
@@ -1023,26 +1048,7 @@ def acc_ops_matmul(network, target, args, kwargs, name):
                 f"matmul received input {i} that is not part " "of the TensorRT region!"
             )
 
-    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
-    preset_diff = 0
-
-    if len(input_val.shape) == 1:
-        preset_diff -= 1
-        input_matrix_op = trt.MatrixOperation.VECTOR
-
-    if len(other_val.shape) == 1:
-        preset_diff += 1
-        other_matrix_op = trt.MatrixOperation.VECTOR
-
-    input_val, other_val = broadcast(
-        network, input_val, other_val, f"{name}_input", f"{name}_other", preset_diff
-    )
-    layer = network.add_matrix_multiply(
-        input_val, input_matrix_op, other_val, other_matrix_op
-    )
-    layer.name = name
-    return layer.get_output(0)
-
+    return add_matrix_multiply_layer(network, input_val, other_val, name)
 
 @tensorrt_converter(acc_ops.sigmoid)
 def acc_ops_sigmoid(network, target, args, kwargs, name):