#define TORCH_META_FUNC(name) void structured_##name::meta
#define TORCH_META_FUNC2(name, overload) void structured_##name##_##overload::meta
+// These are versions of TORCH_META_FUNC(2) that return a precompute_out struct instead of void.
+// Use them when the kernel has precomputed elements declared in native_functions.yaml; the
+// corresponding meta implementation must then return an instance of that struct.
+#define TORCH_PRECOMPUTE_META_FUNC(name) structured_##name::meta_return_ty structured_##name::meta
+#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) structured_##name##_##overload::meta_return_ty structured_##name##_##overload::meta
+
+// Use this to create a precompute struct in a meta function.
+#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
+#define TORCH_PRECOMPUTE_STRUCT2(name, overload) structured_##name##_##overload::precompute_out<>
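+//
+// A minimal sketch of how these macros fit together (assuming a hypothetical op `foo`
+// whose `dim` argument is declared as precomputed in native_functions.yaml):
+//
+//   TORCH_PRECOMPUTE_META_FUNC(foo)(const Tensor& self, int64_t dim) {
+//     set_output(0, self.sizes(), self.options());
+//     return TORCH_PRECOMPUTE_STRUCT(foo)().set_dim(maybe_wrap_dim(dim, self.dim()));
+//   }
+//
+// The impl function then receives the precomputed `dim` in place of the original argument.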
+
// Use this to define the prototype for an implementation. This takes only
// one argument, which is the name of the dispatch key entry you're
// implementing.
namespace meta {
using namespace native;
-TORCH_META_FUNC(avg_pool2d) (
- const Tensor& input,
- IntArrayRef kernel_size,
- IntArrayRef stride,
- IntArrayRef padding,
- bool ceil_mode,
- bool count_include_pad,
- c10::optional<int64_t> divisor_override
-) {
+TORCH_PRECOMPUTE_META_FUNC(avg_pool2d)
+(const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad,
+ c10::optional<int64_t> divisor_override) {
// #20866, #22032: Guarantee this for the official C++ API?
TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
"avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
- const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
- const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
+ const int64_t kH = kernel_size[0];
+ const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1];
TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
"avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
- const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
- const int dW = stride.empty() ? kW :
- stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
+ const int64_t dH = stride.empty() ? kH : stride[0];
+ const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1];
TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
"avg_pool2d: padding must either be a single int, or a tuple of two ints");
- const int padH = safe_downcast<int, int64_t>(padding[0]);
- const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
+ const int64_t padH = padding[0];
+ const int64_t padW = padding.size() == 1 ? padH : padding[1];
TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
"divisor must be not zero");
- /* sizes */
const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
const int64_t nInputPlane = input.size(-3);
const int64_t inputHeight = input.size(-2);
const int64_t inputWidth = input.size(-1);
- const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
- const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
+ const int64_t outputHeight = pooling_output_shape<int64_t>(
+ inputHeight, kH, padH, dH, 1, ceil_mode);
+ const int64_t outputWidth =
+ pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
auto memory_format = input.suggest_memory_format();
pool2d_shape_check(
- input,
- kH, kW, dH, dW, padH, padW, 1, 1,
- nInputPlane,
- inputHeight, inputWidth,
- outputHeight, outputWidth, memory_format);
+ input,
+ kH,
+ kW,
+ dH,
+ dW,
+ padH,
+ padW,
+ 1,
+ 1,
+ nInputPlane,
+ inputHeight,
+ inputWidth,
+ outputHeight,
+ outputWidth,
+ memory_format);
/* resize output */
if (input.ndimension() == 3) {
- set_output(0, {nInputPlane, outputHeight, outputWidth}, input.options());
+ set_output(
+ 0,
+ {nInputPlane,
+ outputHeight,
+ outputWidth},
+ input.options());
}
else {
- set_output(0, {nbatch, nInputPlane, outputHeight, outputWidth}, input.options().memory_format(memory_format));
+ set_output(
+ 0,
+ {nbatch,
+ nInputPlane,
+ outputHeight,
+ outputWidth},
+ input.options().memory_format(memory_format));
}
+
+  return TORCH_PRECOMPUTE_STRUCT(avg_pool2d)()
+      .set_kH(kH)
+      .set_kW(kW)
+      .set_dH(dH)
+      .set_dW(dW)
+      .set_padH(padH)
+      .set_padW(padW);
}
TORCH_META_FUNC(avg_pool2d_backward) (
namespace native {
-TORCH_IMPL_FUNC(avg_pool2d_out_cpu) (
- const Tensor &input,
- IntArrayRef kernel_size,
- IntArrayRef stride,
- IntArrayRef padding,
- bool ceil_mode,
- bool count_include_pad,
- c10::optional<int64_t> divisor_override,
- const Tensor &output
-) {
- const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
- const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
-
- const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
- const int dW = stride.empty() ? kW :
- stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
-
- const int padH = safe_downcast<int, int64_t>(padding[0]);
- const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
-
+TORCH_IMPL_FUNC(avg_pool2d_out_cpu)
+(const Tensor& input,
+ int64_t kH,
+ int64_t kW,
+ int64_t dH,
+ int64_t dW,
+ int64_t padH,
+ int64_t padW,
+ bool ceil_mode,
+ bool count_include_pad,
+ c10::optional<int64_t> divisor_override,
+ const Tensor& output) {
avg_pool2d_kernel(
- kCPU, output, input,
- kW, kH, dW, dH, padW, padH,
- count_include_pad, divisor_override);
+ kCPU,
+ output,
+ input,
+ kW,
+ kH,
+ dW,
+ dH,
+ padW,
+ padH,
+ count_include_pad,
+ divisor_override);
}
TORCH_IMPL_FUNC(avg_pool2d_backward_out_cpu) (
DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel);
-// averge pooling has same signature for forward and backward
+// average pooling: the forward kernel takes the precomputed spatial params as int64_t,
+// while the backward kernel keeps the original int signature.
-using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
+using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH,
+ int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
+using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
int dW, int dH, int padW, int padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
+
DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel);
-DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_backward_kernel);
+DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel);
namespace {
}
}
-TORCH_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any("all", self, maybe_get_output());
auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
+ return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}
-TORCH_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_all_any("any", self, maybe_get_output());
auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
+ return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}
void check_argmax_argmin(
TORCH_IMPL_FUNC(all_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
- dim = maybe_wrap_dim(dim, self.dim());
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) {
}
TORCH_IMPL_FUNC(any_out)
-(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
- dim = maybe_wrap_dim(dim, self.dim());
+(const Tensor& self,
+ int64_t dim,
+ bool keepdim,
+ const Tensor& result) {
auto iter = get_allany_iter(self, result, dim, keepdim);
auto mut_result = const_cast<Tensor&>(result);
if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) {
void cpu_avg_pool(
const Tensor& output_,
const Tensor& input_,
- int kW, int kH,
- int dW, int dH,
- int padW, int padH,
+ int64_t kW, int64_t kH,
+ int64_t dW, int64_t dH,
+ int64_t padW, int64_t padH,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
auto input = input_.contiguous();
void cpu_avg_pool_channels_last(
const Tensor& output_,
const Tensor& input_,
- int kW, int kH,
- int dW, int dH,
- int padW, int padH,
+ int64_t kW, int64_t kH,
+ int64_t dW, int64_t dH,
+ int64_t padW, int64_t padH,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
TORCH_CHECK(input_.ndimension() == 4,
void avg_pool2d_kernel_impl(
const Tensor& output,
const Tensor& input,
- int kW, int kH,
- int dW, int dH,
- int padW, int padH,
+ int64_t kW, int64_t kH,
+ int64_t dW, int64_t dH,
+ int64_t padW, int64_t padH,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
switch (input.suggest_memory_format()) {
} // anonymous namespace
-TORCH_IMPL_FUNC(avg_pool2d_out_cuda) (
- const Tensor& input_,
- IntArrayRef kernel_size,
- IntArrayRef stride,
- IntArrayRef padding,
- bool ceil_mode,
- bool count_include_pad,
- c10::optional<int64_t> divisor_override,
- const Tensor& output
-) {
+TORCH_IMPL_FUNC(avg_pool2d_out_cuda)
+(const Tensor& input_,
+ int64_t kH_,
+ int64_t kW_,
+ int64_t dH_,
+ int64_t dW_,
+ int64_t padH_,
+ int64_t padW_,
+ bool ceil_mode,
+ bool count_include_pad,
+ c10::optional<int64_t> divisor_override,
+ const Tensor& output) {
TensorArg output_arg{ output, "output", 1 };
TensorArg input_arg{ input_, "input_", 2 };
checkAllSameGPU("avg_pool2d_out_cuda", {output_arg, input_arg});
- const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
- const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
+ const int kH = safe_downcast<int, int64_t>(kH_);
+ const int kW = safe_downcast<int, int64_t>(kW_);
- const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
- const int dW = stride.empty() ? kW :
- stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
+ const int dH = safe_downcast<int, int64_t>(dH_);
+ const int dW = safe_downcast<int, int64_t>(dW_);
- const int padH = safe_downcast<int, int64_t>(padding[0]);
- const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
+ const int padH = safe_downcast<int, int64_t>(padH_);
+ const int padW = safe_downcast<int, int64_t>(padW_);
/* sizes */
const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1;
const int64_t inputHeight = input_.size(-2);
const int64_t inputWidth = input_.size(-1);
- const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
- const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
+ int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
+ int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
const auto memory_format = input_.suggest_memory_format();
Tensor input = input_.contiguous(memory_format);
case MemoryFormat::ChannelsLast: {
output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
- <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
- count,
- input_data,
- nbatch,
- nInputPlane,
- inputHeight, inputWidth,
- outputHeight, outputWidth,
- kH, kW,
- dH, dW,
- padH, padW,
- output_data,
- divisor_override_value,
- count_include_pad, use_divisor);
+ <<<num_blocks,
+ num_threads,
+ 0,
+ at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ input_data,
+ nbatch,
+ nInputPlane,
+ inputHeight,
+ inputWidth,
+ outputHeight,
+ outputWidth,
+ kH,
+ kW,
+ dH,
+ dW,
+ padH,
+ padW,
+ output_data,
+ divisor_override_value,
+ count_include_pad,
+ use_divisor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}
case MemoryFormat::Contiguous: {
avg_pool2d_out_cuda_frame<scalar_t, accscalar_t>
- <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
- count,
- input_data,
- nbatch,
- nInputPlane,
- inputHeight, inputWidth,
- outputHeight, outputWidth,
- kH, kW,
- dH, dW,
- padH, padW,
- output_data,
- divisor_override_value,
- count_include_pad, use_divisor);
+ <<<num_blocks,
+ num_threads,
+ 0,
+ at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ input_data,
+ nbatch,
+ nInputPlane,
+ inputHeight,
+ inputWidth,
+ outputHeight,
+ outputWidth,
+ kH,
+ kW,
+ dH,
+ dW,
+ padH,
+ padW,
+ output_data,
+ divisor_override_value,
+ count_include_pad,
+ use_divisor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}
- func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
+ precomputed:
+ - dim -> int dim
dispatch:
CPU, CUDA: all_out
- func: any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
+ precomputed:
+ - dim -> int dim
dispatch:
CPU, CUDA: any_out
- func: avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
structured: True
+ precomputed:
+ - kernel_size -> int kH, int kW
+ - stride -> int dH, int dW
+ - padding -> int padH, int padW
dispatch:
CPU: avg_pool2d_out_cpu
CUDA: avg_pool2d_out_cuda
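+  # Each entry under `precomputed` maps an original kernel argument to the precomputed
+  # elements that replace it in the impl signature: the meta function computes kH/kW,
+  # dH/dW and padH/padW once, and the generated code passes those to the impl in place
+  # of kernel_size, stride and padding.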
def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
- args.extend(g.out.func.arguments.non_out)
+
+ if g.out.precomputed:
+ # A list of parameters for the impl function with
+ # certain parameters replaced with precomputed counterparts
+ # as specified in native_functions.yaml.
+ non_out_args_replaced: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+
+ for a in g.out.func.arguments.non_out:
+ if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
+ # If a is in precompute.replace, append the parameters
+ # that should replace it onto non_out_args_replaced.
+ for replacement in g.out.precomputed.replace[a.name]:
+ non_out_args_replaced.append(replacement)
+ else:
+ # If not, push a as it is.
+ non_out_args_replaced.append(a)
+
+ args.extend(non_out_args_replaced)
+ else:
+ args.extend(g.out.func.arguments.non_out)
+
args.extend(g.out.func.arguments.out)
return [r for arg in args for r in argument(arg)]
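+# For illustration (a rough sketch, not actual generated output): with the avg_pool2d.out
+# entry above, the resulting impl bindings are roughly
+#   (self, kH, kW, dH, dW, padH, padW, ceil_mode, count_include_pad, divisor_override, out)
+# i.e. kernel_size, stride and padding have been replaced by their precomputed counterparts.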
method=False
)
)
- sig_body.append(f"op.meta({meta_exprs});")
+
+ if self.g.out.precomputed:
+ # If this function group has precomputed elements, the meta function
+ # returns a struct containing them which must be saved so that it
+ # can be unpacked when generating code to call the impl.
+ sig_body.append(f"auto precompute = op.meta({meta_exprs});")
+
+ # Put all of the contents of the precompute struct into the context
+ # so that translate will be able to return the correct args for the
+ # call to the impl.
+ for precomputed_elems in self.g.out.precomputed.replace.values():
+ for arg in precomputed_elems:
+ context.append(Expr(
+ expr=f"precompute.{arg.name}",
+ type=structured.argument_type(arg, binds=arg.name),
+ ))
+
+ # Add a use of the precompute struct so FB internal compilers don't
+ # complain that there is an unused variable.
+ sig_body.append("(void)precompute;")
+ else:
+ sig_body.append(f"op.meta({meta_exprs});")
+
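+        # Rough sketch of the generated wrapper body for a precomputed op (using the
+        # avg_pool2d example above):
+        #   auto precompute = op.meta(self, kernel_size, stride, padding, ...);
+        #   (void)precompute;
+        #   op.impl(self, precompute.kH, precompute.kW, ..., /* outputs from op.outputs_ */);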
# After running meta, op.outputs_ is guaranteed to be valid;
# add it to the context
parent_class = g.out.structured_inherits
if parent_class is None:
parent_class = "at::impl::MetaBase"
+ meta_return = "void"
+ precomputed = g.out.precomputed if g.structured else None
+
+ if precomputed:
+ # Generate the template declaration with one bool parameter for each
+ # precomputed element. Each parameter is true if the corresponding (in
+ # terms of position) precomputed element has been set.
+ precomputed_elements = [elem for replace_list in precomputed.replace.values() for elem in replace_list]
+ precomputed_template_parameters = [elem.name.upper() for elem in precomputed_elements]
+ precomputed_template_params_str = ", ".join(f"bool {param} = false" for param in precomputed_template_parameters)
+ precompute_template_decl = f"template <{precomputed_template_params_str}>"
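+    # For two precomputed elements kH and kW this yields, roughly,
+    # "template <bool KH = false, bool KW = false>".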
+
+ # Generate a string containing declarations of all precomputed elements.
+ precomputed_elements_with_cpp_types = [
+ structured.argument_type(elem, binds=elem.name)
+ for elem in precomputed_elements
+ ]
+
+ precomputed_elements_decl = ";\n".join(
+ f"{elem.cpp_type(strip_ref=True)} {elem.name}" for elem in precomputed_elements_with_cpp_types
+ )
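+    # e.g. for kH and kW this produces roughly "int64_t kH;\nint64_t kW" (the trailing
+    # semicolon is added where the declaration is spliced in below).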
+
+ # Generate "setter" methods for each precomputed element. Each method will return
+ # a new instance of precompute_out with the template parameter that corresponds to
+ # the member set by the method to true (to indicate that it has been set).
+ setter_methods = []
+ for i, elem in enumerate(precomputed_elements):
+ # Generate the signature. The return type will be the same
+ # as the type of `this` but with the template parameter
+ # corresponding to the element set by this method set to true.
+ # The assert generated below will ensure that this template
+ # parameter is false on the type of `this`.
+ return_ty_templates = ", ".join(
+ precomputed_template_parameters[:i] + ["true"] + precomputed_template_parameters[i + 1:]
+ )
+ return_ty = f"precompute_out<{return_ty_templates}>"
+ elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(strip_ref=True)
+ signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"
+
+ # Generate an assert which checks that the
+ # template parameter corresponding to the precomputed
+ # element that is set by this method is false on the
+ # class corresponding to the object that `this` points to.
+ # This ensures that each element can be set only once.
+ assert_msg = f"\"{precomputed_elements[i].name} already set\""
+ assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"
+
+ # Generate the new object construction block. All state
+ # except the element that this method sets is copied from the
+ # object that `this` points to. The value for the element that
+ # the method sets is taken from a method parameter.
+ construction_stmts = []
+ construction_stmts.append(f"{return_ty} ret;")
+
+ for j, elem in enumerate(precomputed_elements):
+ if i == j:
+ construction_stmts.append(f"ret.{elem.name} = value;")
+ else:
+ construction_stmts.append(f"ret.{elem.name} = this->{elem.name};")
+
+ construction_stmts.append("return ret;")
+ construction_block = "\n".join(construction_stmts)
+
+ setter_methods.append(f"""
+ {signature} {{
+ {assert_stmt}
+ {construction_block}
+ }}
+ """)
+ setter_methods_decl = "\n".join(setter_methods)
+
+ # Meta should return an instance of the struct containing the precomputed elements.
+ meta_return_template_params = ", ".join(["true"] * len(precomputed_template_parameters))
+    # This typedef (actually a using statement) is needed so that TORCH_PRECOMPUTE_META_FUNC
+    # can reuse the return type (which has a variable number of template parameters).
+ meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
+ meta_return = "meta_return_ty"
+ precomputed_decl = f"""
+ {precompute_template_decl}
+ struct TORCH_API precompute_out {{
+ {setter_methods_decl}
+ {precomputed_elements_decl};
+ }};"""
+ else:
+ meta_return_typedef = ""
+ precomputed_decl = ""
+
return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
- void meta({args_str});
+ {precomputed_decl}
+ {meta_return_typedef}
+ {meta_return} meta({args_str});
}};
"""
# changes the semantics of set_output to call the parent class.
structured_inherits: Optional[str]
+ # Structured kernels can declare elements as "precomputed". These elements
+ # are returned by the meta function in one struct and passed to the impl
+ # function in lieu of certain kernel arguments that these precomputed
+ # elements supersede. Information about the names and types of these
+ # precomputed elements and how they correspond to kernel arguments is stored
+ # in this member, if applicable.
+ precomputed: Optional['Precompute']
+
# Argument names whose default should be excluded from the C++ interface.
# Intended for resolving overload ambiguities between signatures.
cpp_no_default_args: Set[str]
category_override = e.pop('category_override', None)
assert category_override is None or isinstance(category_override, str), f'not a str: {category_override}'
+ precomputed_dict = e.pop('precomputed', None)
+ assert precomputed_dict is None or structured is True
+ precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None
+
from tools.codegen.api import cpp
raw_dispatch = e.pop('dispatch', None)
structured=structured,
structured_delegate=structured_delegate,
structured_inherits=structured_inherits,
+ precomputed=precomputed,
manual_kernel_registration=manual_kernel_registration,
manual_cpp_binding=manual_cpp_binding,
python_module=python_module,
if return_decl[0] == '(' and return_decl[-1] == ')':
return_decl = return_decl[1:-1]
return tuple(Return.parse(arg) for arg in return_decl.split(', '))
+
+
+# A Precompute instance consists of a map from kernel argument name
+# to the list of Argument instances that should replace that
+# kernel argument in the impl function.
+@dataclass(frozen=True)
+class Precompute:
+ # A map from kernel argument name -> a list of precomputed
+ # elements that replaces/supersedes it.
+ replace: Dict[str, List[Argument]]
+
+ @staticmethod
+ def parse(src: object) -> 'Precompute':
+ assert isinstance(src, list)
+
+ # src is a list of strings of the format:
+ # {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
+        # Parse this list to determine which precomputed elements
+        # should replace which kernel arguments.
+ replace = {}
+ for raw_replace_item in src:
+ assert isinstance(raw_replace_item, str)
+
+ arg, with_list_raw = raw_replace_item.split(' -> ')
+ with_list = with_list_raw.split(',')
+ with_list_args = [Argument.parse(name.strip()) for name in with_list]
+ replace[arg] = with_list_args
+
+ r = Precompute(replace=replace)
+ assert r.to_list() == src, 'r.to_list() != src'
+ return r
+
+ def to_list(self) -> List[str]:
+ replace_list = []
+ for kernel_param, replacement_params in self.replace.items():
+ replacements = ', '.join(str(param) for param in replacement_params)
+ replace_list.append(f'{kernel_param} -> {replacements}')
+
+ return replace_list
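+
+# Example usage (a small sketch): parsing the avg_pool2d entry from native_functions.yaml
+# yields a mapping from each kernel argument to its replacement Arguments, e.g.
+#   p = Precompute.parse(['kernel_size -> int kH, int kW'])
+#   p.replace == {'kernel_size': [Argument.parse('int kH'), Argument.parse('int kW')]}
+#   p.to_list() == ['kernel_size -> int kH, int kW']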