    # dim, which is a very rare case. For now we simply don't support dim=None.
    assert dim is not None, "We don't support dim=None right now."
+    dim = dim % (len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0))
    if network.has_implicit_batch_dimension:
        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
        dim -= 1
    layer.name = name
    return layer.get_output(0)
+@tensorrt_converter(acc_ops.topk)
+def acc_ops_topk(network, target, args, kwargs, name):
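+    # Maps acc_ops.topk to TensorRT's ITopKLayer and returns (values, indices) ITensors.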
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"topk received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
+    if kwargs["sorted"] and kwargs["k"] != 1:
+        raise RuntimeError("Currently we don't support sorted=True in topk when k != 1.")
+
+    if not network.has_implicit_batch_dimension and len(input_val.shape) <= 1:
+        raise RuntimeError("At least 2 dimensions are required for input to topk.")
+
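+    # The implicit batch dimension, if present, counts toward the input's logical rank.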
+    num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
+    k = kwargs["k"]
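+    # Default to the last dimension when dim isn't given and normalize negative dims.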
+    dim = (kwargs["dim"] if kwargs["dim"] is not None else -1) % num_dims
+    operation = trt.TopKOperation.MAX if kwargs["largest"] else trt.TopKOperation.MIN
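+    # get_axes_for_reduce_op turns dim into the axes bitmask that add_topk expects.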
+    layer = network.add_topk(
+        input_val, operation, k, get_axes_for_reduce_op(dim, network.has_implicit_batch_dimension)
+    )
+    layer.name = name
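+    # TensorRT's topk layer has two outputs: values (output 0) and indices (output 1).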
+    return (layer.get_output(0), layer.get_output(1))
@tensorrt_converter(acc_ops.adaptive_avg_pool2d)
def acc_ops_adaptive_avg_pool2d(network, target, args, kwargs, name):
def layer_norm(*, input, normalized_shape, weight, bias, eps):
    return nn.functional.layer_norm(**locals())
+def argmin_max_mapper_impl(node: torch.fx.Node, largest: bool) -> torch.fx.Node:
+    """
+    Map torch.argmin or torch.argmax to acc_ops.flatten (depending on dim) + acc_ops.topk
+    + acc_ops.getitem + acc_ops.squeeze (depending on keepdim).
+    """
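+    # A sketch of the rewrite this mapper performs (argmin is the same with largest=False):
+    #   argmax(x)                       -> flatten -> topk(k=1, dim=-1) -> getitem(idx=1) -> squeeze(dim=-1)
+    #   argmax(x, dim=d, keepdim=True)  -> topk(k=1, dim=d) -> getitem(idx=1)
+    #   argmax(x, dim=d, keepdim=False) -> topk(k=1, dim=d) -> getitem(idx=1) -> squeeze(dim=d)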
+    input_node = node.kwargs["input"]
+    dim = node.kwargs["dim"]
+    keepdim = node.kwargs["keepdim"]
+
+    if dim is None and keepdim:
+        raise RuntimeError("We currently don't support argmin/argmax with dim=None and keepdim=True.")
+
+    with node.graph.inserting_before(node):
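+        # dim=None means "argmin/argmax over all elements": flatten first so topk along the
+        # last dim sees every element, and the resulting index is into the flattened tensor.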
+        if dim is None:
+            flatten_kwargs = {"input": node.kwargs["input"], "start_dim": 0, "end_dim": -1}
+            flatten_node = node.graph.call_function(flatten, kwargs=flatten_kwargs)
+            flatten_node.meta["type"] = torch.Tensor
+            input_node = flatten_node
+            dim = -1
+
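+        # k=1 with largest=True/False selects the position of the max/min along dim.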
+        topk_kwargs = {"input": input_node, "k": 1, "dim": dim, "largest": largest, "sorted": False}
+        topk_node = node.graph.call_function(topk, kwargs=topk_kwargs)
+        # It's actually more like a NamedTuple, but tuple is fine here.
+        topk_node.meta["type"] = tuple
+
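+        # acc_ops.topk returns (values, indices); idx=1 picks the indices, which is the
+        # result argmin/argmax should produce.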
+        getitem_kwargs = {"input": topk_node, "idx": 1}
+        getitem_node = node.graph.call_function(getitem, kwargs=getitem_kwargs)
+        getitem_node.meta["type"] = torch.Tensor
+        output_node = getitem_node
+
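+        # Without keepdim, drop the size-1 dimension left over from topk's k=1.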
+        if not keepdim:
+            squeeze_kwargs = {"input": getitem_node, "dim": dim}
+            output_node = node.graph.call_function(squeeze, kwargs=squeeze_kwargs)
+
+        output_node.meta = node.meta.copy()
+        return output_node
+
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.argmin),
+    arg_replacement_tuples=[
+        ("input", "input"),
+        ("dim", "dim"),
+        ("keepdim", "keepdim"),
+    ],
+)
+def torch_argmin_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node:
+    """
+    Map torch.argmin to acc_ops.flatten (depending on dim) + acc_ops.topk + acc_ops.getitem
+    + acc_ops.squeeze (depending on keepdim).
+    """
+    return argmin_max_mapper_impl(node, largest=False)
@register_custom_acc_mapper_fn(
    op_and_target=("call_method", "split"),