#pragma once
#include "fully_connected_kernel_base.h"
-#include <algorithm>
+
namespace kernel_selector {
class FullyConnectedBlockKernelBase : public FullyConnectedKernelBase {
JitConstants GetJitConstants(const fully_connected_params& params, const DispatchData& kd) const override;
// how many batches will a single work item compute
- static size_t GetBatchesPerWorkItem(const fully_connected_params& params) {
- auto batchSize = params.output.Batch().v;
- return std::min(batchSize, static_cast<size_t>(32U));
- }
+ virtual size_t GetBatchesPerWorkItem(const fully_connected_params& params) const;
- static size_t GetLocalGroupsSize(const fully_connected_params& params) {
- auto batchSize = params.output.Batch().v;
- return std::max(static_cast<size_t>(1U), batchSize / GetBatchesPerWorkItem(params));
- }
+ size_t GetLocalGroupsSize(const fully_connected_params& params) const;
// how many neurons for a single batch will a single work item produce
static size_t GetNeuronsPerWorkItem(const fully_connected_params& params) {
return 1;
}
};
-} // namespace kernel_selector
\ No newline at end of file
+} // namespace kernel_selector