[structured] Preserve computed elements from meta func to impl (#61746)
author    Meghan Lele <meghanl@fb.com>
Wed, 1 Sep 2021 21:24:54 +0000 (14:24 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 21:34:25 +0000 (14:34 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61746

**Summary**
This commit introduces a new feature for structured kernels that allows
a kernel to declare quantities as "precomputed" in
`native_functions.yaml`, compute them once in its `meta` function, and
reuse them in its `impl`. The names and types of these quantities are
used to generate code for a struct containing them that the `meta`
function must return. In a handful of surveyed kernels (`all`, `any`,
`avg_pool2d`), the quantities used in both the `meta` and the `impl`
have the same meaning as certain kernel arguments and in fact supersede
them. Accordingly, the correspondence between a kernel argument and the
precomputed elements that supersede it is also captured in
`native_functions.yaml`. This information is used to unpack the struct
returned by `meta` and pass its contents correctly to the `impl`
function.

The primary goals are to avoid redundant recomputation and to improve
the developer experience (e.g. it is easy to forget to recompute these
elements in the `impl` when porting a kernel).
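
To make the mechanics concrete: in `native_functions.yaml` a structured
kernel lists its precomputed elements under a `precomputed:` key (e.g.
`kernel_size -> int kH, int kW` for `avg_pool2d.out`), its `meta` function
is declared with `TORCH_PRECOMPUTE_META_FUNC(2)` and returns a chain such
as `TORCH_PRECOMPUTE_STRUCT(avg_pool2d)().set_kH(kH).set_kW(kW)...`, and
its `impl` receives the precomputed values in place of the superseded
arguments. The snippet below is a minimal, hand-written sketch of roughly
what the generated `precompute_out` struct looks like for two precomputed
elements; the real struct is emitted by
`compute_meta_function_declaration` in `tools/codegen/gen.py`, lives
inside the `structured_<kernel>` class, and covers every declared element.

```cpp
// Minimal hand-written sketch (NOT the actual generated code) of the
// precompute_out struct that the codegen emits, shown for just two
// precomputed elements, kH and kW. Each bool template parameter records
// whether the corresponding element has been set; each setter returns a
// copy with that flag flipped to true, and a static_assert turns setting
// the same element twice into a compile-time error.
#include <cstdint>

template <bool KH = false, bool KW = false>
struct precompute_out {
  precompute_out<true, KW> set_kH(std::int64_t value) {
    static_assert(KH == false, "kH already set");
    precompute_out<true, KW> ret;
    ret.kH = value;
    ret.kW = this->kW;
    return ret;
  }
  precompute_out<KH, true> set_kW(std::int64_t value) {
    static_assert(KW == false, "kW already set");
    precompute_out<KH, true> ret;
    ret.kH = this->kH;
    ret.kW = value;
    return ret;
  }
  // Zero-initialization is added here for the sketch only; the generated
  // struct leaves members unset until their setters run.
  std::int64_t kH = 0;
  std::int64_t kW = 0;
};

// The meta function is declared to return the fully-set instantiation,
// which the generated code names via a `meta_return_ty` alias.
using meta_return_ty = precompute_out<true, true>;

int main() {
  meta_return_ty p = precompute_out<>().set_kH(3).set_kW(2);
  return (p.kH == 3 && p.kW == 2) ? 0 : 1;
}
```

Because `meta_return_ty` aliases the fully-set instantiation, forgetting
to set an element (or setting one twice) fails to compile, and the
codegen'd caller in `register_dispatch_key.py` can read `precompute.kH`,
`precompute.kW`, etc. when invoking the impl.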

Test Plan: Imported from OSS

Reviewed By: tugsbayasgalan

Differential Revision: D30407831

Pulled By: SplitInfinity

fbshipit-source-id: 00975525ea373721fe52d06f75cd4ac91f3dc556

aten/src/ATen/TensorMeta.h
aten/src/ATen/native/AveragePool2d.cpp
aten/src/ATen/native/Pool.h
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/cpu/AvgPoolKernel.cpp
aten/src/ATen/native/cuda/AveragePool2d.cu
aten/src/ATen/native/native_functions.yaml
tools/codegen/api/structured.py
tools/codegen/dest/register_dispatch_key.py
tools/codegen/gen.py
tools/codegen/model.py

diff --git a/aten/src/ATen/TensorMeta.h b/aten/src/ATen/TensorMeta.h
index ac295ec..6a5491a 100644
@@ -26,6 +26,16 @@ namespace impl {
 #define TORCH_META_FUNC(name) void structured_##name::meta
 #define TORCH_META_FUNC2(name, overload) void structured_##name##_##overload::meta
 
+// These are versions of TORCH_META_FUNC(2) that include a precompute_out struct as a return value.
+// They should be used when the kernel in question has precomputed values declared in native_functions.yaml and
+// the corresponding implementation should return an instance of the aforementioned struct.
+#define TORCH_PRECOMPUTE_META_FUNC(name) structured_##name::meta_return_ty structured_##name::meta
+#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) structured_##name##_##overload::meta_return_ty structured_##name##_##overload::meta
+
+// Use this to create a precompute struct in a meta function.
+#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
+#define TORCH_PRECOMPUTE_STRUCT2(name, overload) structured_##name##_##overload::precompute_out<>
+
 // Use this to define the prototype for an implementation.  This takes only
 // one argument, which is the name of the dispatch key entry you're
 // implementing.
diff --git a/aten/src/ATen/native/AveragePool2d.cpp b/aten/src/ATen/native/AveragePool2d.cpp
index 2693cc6..8f264c0 100644
@@ -8,59 +8,81 @@ namespace at {
 namespace meta{
 using namespace native;
 
-TORCH_META_FUNC(avg_pool2d) (
-  const Tensor& input,
-  IntArrayRef kernel_size,
-  IntArrayRef stride,
-  IntArrayRef padding,
-  bool ceil_mode,
-  bool count_include_pad,
-  c10::optional<int64_t> divisor_override
-) {
+TORCH_PRECOMPUTE_META_FUNC(avg_pool2d)
+(const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad,
+ c10::optional<int64_t> divisor_override) {
   // #20866, #22032: Guarantee this for the official C++ API?
   TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
     "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
-  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
-  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
+  const int64_t kH = kernel_size[0];
+  const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1];
 
   TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
     "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
-  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
-  const int dW = stride.empty() ? kW :
-                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
+  const int64_t dH = stride.empty() ? kH : stride[0];
+  const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1];
 
   TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
     "avg_pool2d: padding must either be a single int, or a tuple of two ints");
-  const int padH = safe_downcast<int, int64_t>(padding[0]);
-  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
+  const int64_t padH = padding[0];
+  const int64_t padW = padding.size() == 1 ? padH : padding[1];
 
   TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
     "divisor must be not zero");
 
-  /* sizes */
   const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
   const int64_t nInputPlane = input.size(-3);
   const int64_t inputHeight = input.size(-2);
   const int64_t inputWidth = input.size(-1);
 
-  const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
-  const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
+  const int64_t outputHeight = pooling_output_shape<int64_t>(
+      inputHeight, kH, padH, dH, 1, ceil_mode);
+  const int64_t outputWidth =
+      pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
 
   auto memory_format = input.suggest_memory_format();
   pool2d_shape_check(
-    input,
-    kH, kW, dH, dW, padH, padW, 1, 1,
-    nInputPlane,
-    inputHeight, inputWidth,
-    outputHeight, outputWidth, memory_format);
+      input,
+      kH,
+      kW,
+      dH,
+      dW,
+      padH,
+      padW,
+      1,
+      1,
+      nInputPlane,
+      inputHeight,
+      inputWidth,
+      outputHeight,
+      outputWidth,
+      memory_format);
 
   /* resize output */
   if (input.ndimension() == 3) {
-    set_output(0, {nInputPlane, outputHeight, outputWidth}, input.options());
+    set_output(
+        0,
+        {nInputPlane,
+         outputHeight,
+         outputWidth},
+        input.options());
   }
   else {
-    set_output(0, {nbatch, nInputPlane, outputHeight, outputWidth}, input.options().memory_format(memory_format));
+    set_output(
+        0,
+        {nbatch,
+         nInputPlane,
+         outputHeight,
+         outputWidth},
+        input.options().memory_format(memory_format));
   }
+
+  return TORCH_PRECOMPUTE_STRUCT(avg_pool2d)().set_kH(kH).set_kW(kW).set_dH(dH).set_dW(dW).set_padH(padH).set_padW(padW);
 }
 
 TORCH_META_FUNC(avg_pool2d_backward) (
@@ -119,30 +141,30 @@ TORCH_META_FUNC(avg_pool2d_backward) (
 
 namespace native {
 
-TORCH_IMPL_FUNC(avg_pool2d_out_cpu) (
-  const Tensor &input,
-  IntArrayRef kernel_size,
-  IntArrayRef stride,
-  IntArrayRef padding,
-  bool ceil_mode,
-  bool count_include_pad,
-  c10::optional<int64_t> divisor_override,
-  const Tensor &output
-) {
-  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
-  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
-
-  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
-  const int dW = stride.empty() ? kW :
-                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
-
-  const int padH = safe_downcast<int, int64_t>(padding[0]);
-  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
-
+TORCH_IMPL_FUNC(avg_pool2d_out_cpu)
+(const Tensor& input,
+ int64_t kH,
+ int64_t kW,
+ int64_t dH,
+ int64_t dW,
+ int64_t padH,
+ int64_t padW,
+ bool ceil_mode,
+ bool count_include_pad,
+ c10::optional<int64_t> divisor_override,
+ const Tensor& output) {
   avg_pool2d_kernel(
-      kCPU, output, input,
-      kW, kH, dW, dH, padW, padH,
-      count_include_pad, divisor_override);
+      kCPU,
+      output,
+      input,
+      kW,
+      kH,
+      dW,
+      dH,
+      padW,
+      padH,
+      count_include_pad,
+      divisor_override);
 }
 
 TORCH_IMPL_FUNC(avg_pool2d_backward_out_cpu) (
diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h
index 5fe979d..da77491 100644
@@ -16,10 +16,13 @@ DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
 DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel);
 
 // average pooling has same signature for forward and backward
-using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
+using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH,
+    int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
+using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
     int dW, int dH, int padW, int padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
+
 DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel);
-DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_backward_kernel);
+DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel);
 
 namespace {
 
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 6e5a153..620908b 100644
@@ -109,16 +109,18 @@ void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
   }
 }
 
-TORCH_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
   check_all_any("all", self, maybe_get_output());
   auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
   resize_reduction(*this, self, dim, keepdim, out_dtype);
+  return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
 }
 
-TORCH_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
   check_all_any("any", self, maybe_get_output());
   auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
   resize_reduction(*this, self, dim, keepdim, out_dtype);
+  return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
 }
 
 void check_argmax_argmin(
@@ -1338,7 +1340,6 @@ Tensor all(const Tensor& self) {
 
 TORCH_IMPL_FUNC(all_out)
 (const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
-  dim = maybe_wrap_dim(dim, self.dim());
   auto iter = get_allany_iter(self, result, dim, keepdim);
   auto mut_result = const_cast<Tensor&>(result);
   if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) {
@@ -1370,8 +1371,10 @@ Tensor any(const Tensor& self) {
 }
 
 TORCH_IMPL_FUNC(any_out)
-(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
-  dim = maybe_wrap_dim(dim, self.dim());
+(const Tensor& self,
+ int64_t dim,
+ bool keepdim,
+ const Tensor& result) {
   auto iter = get_allany_iter(self, result, dim, keepdim);
   auto mut_result = const_cast<Tensor&>(result);
   if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) {
diff --git a/aten/src/ATen/native/cpu/AvgPoolKernel.cpp b/aten/src/ATen/native/cpu/AvgPoolKernel.cpp
index 2aa075f..2bee020 100644
@@ -14,9 +14,9 @@ template <typename scalar_t>
 void cpu_avg_pool(
     const Tensor& output_,
     const Tensor& input_,
-    int kW, int kH,
-    int dW, int dH,
-    int padW, int padH,
+    int64_t kW, int64_t kH,
+    int64_t dW, int64_t dH,
+    int64_t padW, int64_t padH,
     bool count_include_pad,
     c10::optional<int64_t> divisor_override) {
   auto input = input_.contiguous();
@@ -98,9 +98,9 @@ template <typename scalar_t>
 void cpu_avg_pool_channels_last(
     const Tensor& output_,
     const Tensor& input_,
-    int kW, int kH,
-    int dW, int dH,
-    int padW, int padH,
+    int64_t kW, int64_t kH,
+    int64_t dW, int64_t dH,
+    int64_t padW, int64_t padH,
     bool count_include_pad,
     c10::optional<int64_t> divisor_override) {
   TORCH_CHECK(input_.ndimension() == 4,
@@ -359,9 +359,9 @@ void cpu_avg_pool_backward_channels_last(
 void avg_pool2d_kernel_impl(
     const Tensor& output,
     const Tensor& input,
-    int kW, int kH,
-    int dW, int dH,
-    int padW, int padH,
+    int64_t kW, int64_t kH,
+    int64_t dW, int64_t dH,
+    int64_t padW, int64_t padH,
     bool count_include_pad,
     c10::optional<int64_t> divisor_override) {
   switch (input.suggest_memory_format()) {
diff --git a/aten/src/ATen/native/cuda/AveragePool2d.cu b/aten/src/ATen/native/cuda/AveragePool2d.cu
index 5de3adc..df9fcfe 100644
@@ -232,30 +232,31 @@ __global__ void avg_pool2d_backward_out_cuda_frame_nhwc(const int nthreads,
 
 } // anonymous namespace
 
-TORCH_IMPL_FUNC(avg_pool2d_out_cuda) (
-  const Tensor& input_,
-  IntArrayRef kernel_size,
-  IntArrayRef stride,
-  IntArrayRef padding,
-  bool ceil_mode,
-  bool count_include_pad,
-  c10::optional<int64_t> divisor_override,
-  const Tensor& output
-) {
+TORCH_IMPL_FUNC(avg_pool2d_out_cuda)
+(const Tensor& input_,
+ int64_t kH_,
+ int64_t kW_,
+ int64_t dH_,
+ int64_t dW_,
+ int64_t padH_,
+ int64_t padW_,
+ bool ceil_mode,
+ bool count_include_pad,
+ c10::optional<int64_t> divisor_override,
+ const Tensor& output) {
   TensorArg output_arg{ output, "output", 1 };
   TensorArg input_arg{ input_, "input_", 2 };
 
   checkAllSameGPU("avg_pool2d_out_cuda", {output_arg, input_arg});
 
-  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
-  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
+  const int kH = safe_downcast<int, int64_t>(kH_);
+  const int kW = safe_downcast<int, int64_t>(kW_);
 
-  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
-  const int dW = stride.empty() ? kW :
-                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
+  const int dH = safe_downcast<int, int64_t>(dH_);
+  const int dW = safe_downcast<int, int64_t>(dW_);
 
-  const int padH = safe_downcast<int, int64_t>(padding[0]);
-  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
+  const int padH = safe_downcast<int, int64_t>(padH_);
+  const int padW = safe_downcast<int, int64_t>(padW_);
 
   /* sizes */
   const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1;
@@ -263,8 +264,8 @@ TORCH_IMPL_FUNC(avg_pool2d_out_cuda) (
   const int64_t inputHeight = input_.size(-2);
   const int64_t inputWidth = input_.size(-1);
 
-  const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
-  const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
+  int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
+  int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
   const auto memory_format = input_.suggest_memory_format();
 
   Tensor input = input_.contiguous(memory_format);
@@ -289,37 +290,55 @@ TORCH_IMPL_FUNC(avg_pool2d_out_cuda) (
           case MemoryFormat::ChannelsLast: {
             output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
             avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
-                <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  count,
-                  input_data,
-                  nbatch,
-                  nInputPlane,
-                  inputHeight, inputWidth,
-                  outputHeight, outputWidth,
-                  kH, kW,
-                  dH, dW,
-                  padH, padW,
-                  output_data,
-                  divisor_override_value,
-                  count_include_pad, use_divisor);
+                <<<num_blocks,
+                   num_threads,
+                   0,
+                   at::cuda::getCurrentCUDAStream()>>>(
+                    count,
+                    input_data,
+                    nbatch,
+                    nInputPlane,
+                    inputHeight,
+                    inputWidth,
+                    outputHeight,
+                    outputWidth,
+                    kH,
+                    kW,
+                    dH,
+                    dW,
+                    padH,
+                    padW,
+                    output_data,
+                    divisor_override_value,
+                    count_include_pad,
+                    use_divisor);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
             break;
           }
           case MemoryFormat::Contiguous: {
             avg_pool2d_out_cuda_frame<scalar_t, accscalar_t>
-              <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                count,
-                input_data,
-                nbatch,
-                nInputPlane,
-                inputHeight, inputWidth,
-                outputHeight, outputWidth,
-                kH, kW,
-                dH, dW,
-                padH, padW,
-                output_data,
-                divisor_override_value,
-                count_include_pad, use_divisor);
+                <<<num_blocks,
+                   num_threads,
+                   0,
+                   at::cuda::getCurrentCUDAStream()>>>(
+                    count,
+                    input_data,
+                    nbatch,
+                    nInputPlane,
+                    inputHeight,
+                    inputWidth,
+                    outputHeight,
+                    outputWidth,
+                    kH,
+                    kW,
+                    dH,
+                    dW,
+                    padH,
+                    padW,
+                    output_data,
+                    divisor_override_value,
+                    count_include_pad,
+                    use_divisor);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
             break;
           }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 688763e..fae433c 100644
 - func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
   structured: True
+  precomputed:
+  - dim -> int dim
   dispatch:
     CPU, CUDA: all_out
 
 - func: any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
   structured: True
+  precomputed:
+  - dim -> int dim
   dispatch:
     CPU, CUDA: any_out
 
 - func: avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   structured: True
+  precomputed:
+  - kernel_size -> int kH, int kW
+  - stride -> int dH, int dW
+  - padding -> int padH, int padW
   dispatch:
     CPU: avg_pool2d_out_cpu
     CUDA: avg_pool2d_out_cuda
diff --git a/tools/codegen/api/structured.py b/tools/codegen/api/structured.py
index 4f1437f..6aab794 100644
@@ -84,7 +84,27 @@ def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[B
 
 def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
     args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
-    args.extend(g.out.func.arguments.non_out)
+
+    if g.out.precomputed:
+        # A list of parameters for the impl function with
+        # certain parameters replaced with precomputed counterparts
+        # as specified in native_functions.yaml.
+        non_out_args_replaced: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+
+        for a in g.out.func.arguments.non_out:
+            if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
+                # If a is in precompute.replace, append the parameters
+                # that should replace it onto non_out_args_replaced.
+                for replacement in g.out.precomputed.replace[a.name]:
+                    non_out_args_replaced.append(replacement)
+            else:
+                # If not, push a as it is.
+                non_out_args_replaced.append(a)
+
+        args.extend(non_out_args_replaced)
+    else:
+        args.extend(g.out.func.arguments.non_out)
+
     args.extend(g.out.func.arguments.out)
     return [r for arg in args for r in argument(arg)]
 
diff --git a/tools/codegen/dest/register_dispatch_key.py b/tools/codegen/dest/register_dispatch_key.py
index 784ee56..ec3a2e6 100644
@@ -584,7 +584,29 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
                     method=False
                 )
             )
-            sig_body.append(f"op.meta({meta_exprs});")
+
+            if self.g.out.precomputed:
+                # If this function group has precomputed elements, the meta function
+                # returns a struct containing them which must be saved so that it
+                # can be unpacked when generating code to call the impl.
+                sig_body.append(f"auto precompute = op.meta({meta_exprs});")
+
+                # Put all of the contents of the precompute struct into the context
+                # so that translate will be able to return the correct args for the
+                # call to the impl.
+                for precomputed_elems in self.g.out.precomputed.replace.values():
+                    for arg in precomputed_elems:
+                        context.append(Expr(
+                            expr=f"precompute.{arg.name}",
+                            type=structured.argument_type(arg, binds=arg.name),
+                        ))
+
+                # Add a use of the precompute struct so FB internal compilers don't
+                # complain that there is an unused variable.
+                sig_body.append("(void)precompute;")
+            else:
+                sig_body.append(f"op.meta({meta_exprs});")
+
 
             # After running meta, op.outputs_ is guaranteed to be valid;
             # add it to the context
diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py
index 203b5a9..c986f83 100644
@@ -456,9 +456,98 @@ def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
         parent_class = g.out.structured_inherits
         if parent_class is None:
             parent_class = "at::impl::MetaBase"
+        meta_return = "void"
+        precomputed = g.out.precomputed if g.structured else None
+
+        if precomputed:
+            # Generate the template declaration with one bool parameter for each
+            # precomputed element. Each parameter is true if the corresponding (in
+            # terms of position) precomputed element has been set.
+            precomputed_elements = [elem for replace_list in precomputed.replace.values() for elem in replace_list]
+            precomputed_template_parameters = [elem.name.upper() for elem in precomputed_elements]
+            precomputed_template_params_str = ", ".join(f"bool {param} = false" for param in precomputed_template_parameters)
+            precompute_template_decl = f"template <{precomputed_template_params_str}>"
+
+            # Generate a string containing declarations of all precomputed elements.
+            precomputed_elements_with_cpp_types = [
+                structured.argument_type(elem, binds=elem.name)
+                for elem in precomputed_elements
+            ]
+
+            precomputed_elements_decl = ";\n".join(
+                f"{elem.cpp_type(strip_ref=True)} {elem.name}" for elem in precomputed_elements_with_cpp_types
+            )
+
+            # Generate "setter" methods for each precomputed element. Each method will return
+            # a new instance of precompute_out with the template parameter that corresponds to
+            # the member set by the method to true (to indicate that it has been set).
+            setter_methods = []
+            for i, elem in enumerate(precomputed_elements):
+                # Generate the signature. The return type will be the same
+                # as the type of `this` but with the template parameter
+                # corresponding to the element set by this method set to true.
+                # The assert generated below will ensure that this template
+                # parameter is false on the type of `this`.
+                return_ty_templates = ", ".join(
+                    precomputed_template_parameters[:i] + ["true"] + precomputed_template_parameters[i + 1:]
+                )
+                return_ty = f"precompute_out<{return_ty_templates}>"
+                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(strip_ref=True)
+                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"
+
+                # Generate an assert which checks that the
+                # template parameter corresponding to the precomputed
+                # element that is set by this method is false on the
+                # class corresponding to the object that `this` points to.
+                # This ensures that each element can be set only once.
+                assert_msg = f"\"{precomputed_elements[i].name} already set\""
+                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"
+
+                # Generate the new object construction block. All state
+                # except the element that this method sets is copied from the
+                # object that `this` points to. The value for the element that
+                # the method sets is taken from a method parameter.
+                construction_stmts = []
+                construction_stmts.append(f"{return_ty} ret;")
+
+                for j, elem in enumerate(precomputed_elements):
+                    if i == j:
+                        construction_stmts.append(f"ret.{elem.name} = value;")
+                    else:
+                        construction_stmts.append(f"ret.{elem.name} = this->{elem.name};")
+
+                construction_stmts.append("return ret;")
+                construction_block = "\n".join(construction_stmts)
+
+                setter_methods.append(f"""
+                    {signature} {{
+                        {assert_stmt}
+                        {construction_block}
+                    }}
+                """)
+            setter_methods_decl = "\n".join(setter_methods)
+
+            # Meta should return an instance of the struct containing the precomputed elements.
+            meta_return_template_params = ", ".join(["true"] * len(precomputed_template_parameters))
+            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
+            # type (which has a variable number of template parameters).
+            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
+            meta_return = "meta_return_ty"
+            precomputed_decl = f"""
+                {precompute_template_decl}
+                struct TORCH_API precompute_out {{
+                    {setter_methods_decl}
+                    {precomputed_elements_decl};
+            }};"""
+        else:
+            meta_return_typedef = ""
+            precomputed_decl = ""
+
         return f"""\
 struct TORCH_API structured_{name} : public {parent_class} {{
-    void meta({args_str});
+    {precomputed_decl}
+    {meta_return_typedef}
+    {meta_return} meta({args_str});
 }};
 """
 
diff --git a/tools/codegen/model.py b/tools/codegen/model.py
index 4f82b70..e604e72 100644
@@ -229,6 +229,14 @@ class NativeFunction:
     # changes the semantics of set_output to call the parent class.
     structured_inherits: Optional[str]
 
+    # Structured kernels can declare elements as "precomputed". These elements
+    # are returned by the meta function in one struct and passed to the impl
+    # function in lieu of certain kernel arguments that these precomputed
+    # elements supersede. Information about the names and types of these
+    # precomputed elements and how they correspond to kernel arguments is stored
+    # in this member, if applicable.
+    precomputed: Optional['Precompute']
+
     # Argument names whose default  should be excluded from the C++ interface.
     # Intended for resolving overload ambiguities between signatures.
     cpp_no_default_args: Set[str]
@@ -320,6 +328,10 @@ class NativeFunction:
         category_override = e.pop('category_override', None)
         assert category_override is None or isinstance(category_override, str), f'not a str: {category_override}'
 
+        precomputed_dict = e.pop('precomputed', None)
+        assert precomputed_dict is None or structured is True
+        precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None
+
         from tools.codegen.api import cpp
 
         raw_dispatch = e.pop('dispatch', None)
@@ -389,6 +401,7 @@ class NativeFunction:
             structured=structured,
             structured_delegate=structured_delegate,
             structured_inherits=structured_inherits,
+            precomputed=precomputed,
             manual_kernel_registration=manual_kernel_registration,
             manual_cpp_binding=manual_cpp_binding,
             python_module=python_module,
@@ -1496,3 +1509,42 @@ def parse_returns(return_decl: str) -> Tuple[Return, ...]:
     if return_decl[0] == '(' and return_decl[-1] == ')':
         return_decl = return_decl[1:-1]
     return tuple(Return.parse(arg) for arg in return_decl.split(', '))
+
+
+# A Precompute instance consists of a map from kernel argument name
+# to the list of Argument instances that should replace that
+# kernel argument in the impl function.
+@dataclass(frozen=True)
+class Precompute:
+    # A map from kernel argument name -> a list of precomputed
+    # elements that replaces/supersedes it.
+    replace: Dict[str, List[Argument]]
+
+    @staticmethod
+    def parse(src: object) -> 'Precompute':
+        assert isinstance(src, list)
+
+        # src is a list of strings of the format:
+        #   {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
+        # Parse this list to get the names of which precomputed elements
+        # should replace which kernel arguments.
+        replace = {}
+        for raw_replace_item in src:
+            assert isinstance(raw_replace_item, str)
+
+            arg, with_list_raw = raw_replace_item.split(' -> ')
+            with_list = with_list_raw.split(',')
+            with_list_args = [Argument.parse(name.strip()) for name in with_list]
+            replace[arg] = with_list_args
+
+        r = Precompute(replace=replace)
+        assert r.to_list() == src, 'r.to_list() != src'
+        return r
+
+    def to_list(self) -> List[str]:
+        replace_list = []
+        for kernel_param, replacement_params in self.replace.items():
+            replacements = ', '.join(str(param) for param in replacement_params)
+            replace_list.append(f'{kernel_param} -> {replacements}')
+
+        return replace_list