Instead of writing another `OpKernel` with redundant code as above, you can
often use a C++ template. You will still have one kernel registration
(`REGISTER_KERNEL_BUILDER` call) per overload.
-<pre class="prettyprint"><code class="lang-cpp">
-<b>template <typename T></b>
+```c++
+template <typename T>
class ZeroOutOp : public OpKernel {
public:
- explicit ZeroOutOp(OpKernelConstruction\* context) : OpKernel(context) {}<br/>
- void Compute(OpKernelContext\* context) override {
+ explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
// Grab the input tensor
- const Tensor& input\_tensor = context->input(0);
- auto input = input\_tensor.flat<b><T></b>();<br/>
+ const Tensor& input_tensor = context->input(0);
+ auto input = input_tensor.flat<T>();
+
// Create an output tensor
Tensor* output = NULL;
- OP\_REQUIRES\_OK(context,
- context->allocate\_output(0, input_tensor.shape(), &output));
- auto output\_flat = output->template flat<b><T></b>();<br/>
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_tensor.shape(), &output));
+ auto output_flat = output->template flat<T>();
+
// Set all the elements of the output tensor to 0
const int N = input.size();
- for (int i = 0; i < N; i++) {
- output\_flat(i) = 0;
- }<br/>
+ for (int i = 0; i < N; i++) {
+ output_flat(i) = 0;
+ }
+
// Preserve the first input value
- if (N > 0) output\_flat(0) = input(0);
+ if (N > 0) output_flat(0) = input(0);
}
-};<br/>
-// Note that TypeConstraint<int32>("T") means that attr "T" (defined
+};
+
+// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
-// instantiation.</b>
-REGISTER\_KERNEL\_BUILDER(
+// instantiation.
+REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
- .Device(DEVICE\_CPU)
- .TypeConstraint<int32>("T"),
- <b>ZeroOutOp<int32></b>);
-REGISTER\_KERNEL\_BUILDER(
+ .Device(DEVICE_CPU)
+ .TypeConstraint<int32>("T"),
+ ZeroOutOp<int32>);
+REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
- .Device(DEVICE\_CPU)
- .TypeConstraint<float>("T"),
- <b>ZeroOutOp<float></b>);
-<b>REGISTER\_KERNEL\_BUILDER(
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ ZeroOutOp<float>);
+REGISTER_KERNEL_BUILDER(
Name("ZeroOut")
- .Device(DEVICE\_CPU)
- .TypeConstraint<double>("T"),
- ZeroOutOp<double>);
-</b></code></pre>
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T"),
+ ZeroOutOp<double>);
+```
If you have more than a couple of overloads, you can put the registration in a
macro.