class Mean : public KernelWithParams<ReducerParams>
{
public:
- Mean(const Tensor *input, const Tensor *axes, Tensor *output, const ReducerParams ¶ms);
+ Mean(const Tensor *input, const Tensor *axes, Tensor *output, Tensor *temp_index,
+ Tensor *resolved_axes, Tensor *temp_sum, const ReducerParams ¶ms);
const Tensor *input() const { return _inputs[0]; }
const Tensor *axes() const { return _inputs[1]; }
void evalQuantizedS16() const;
private:
- std::unique_ptr<Tensor> _temp_index;
- std::unique_ptr<Tensor> _resolved_axes;
- std::unique_ptr<Tensor> _temp_sum;
+ bool _need_temporaries = false;
};
} // namespace kernels