/*
-// Copyright (c) 2016 Intel Corporation
+// Copyright (c) 2016-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
#pragma once
#include "jitter.h"
-#include "tensor_type.h"
namespace kernel_selector {
struct weight_bias_params;
- struct convolution_params;
+ struct optional_params;
+ struct WeightsReorderParams;
- bool CheckConvolutionPaddedInputDesc(const convolution_params& params, const DataTensor& reqDesc);
- DataTensor GetConvolutionBFYXPaddedTensor(const convolution_params& cp);
- bool CovolutionCheckInput(const Params& p, const optional_params& o);
- bool CovolutionUpdateInputParams(convolution_params& params);
- WeightsType DataTypeToWeightsType(Datatype t);
- bool CheckWeights(const WeightsTensor& tensor, WeightsType reqType, std::vector<WeightsLayout> reqLayouts);
std::vector<size_t> GetImageSizes(const kernel_selector::WeightsTensor& dimensions, const WeightsLayout layout);
bool CheckImageSize(const weight_bias_params& newParams, const WeightsLayout layout);
- bool UpdateWeightsParams(weight_bias_params& newParams, const optional_params& options, std::vector<WeightsLayout> layouts, WeightsReorderParams& weightsReorderParams);
+ bool UpdateWeightsParams(weight_bias_params& newParams, const optional_params& options, std::vector<WeightsLayout> layouts, WeightsReorderParams& weightsReorderParams, const ParamsKey& paramsKey = ParamsKey());
JitConstants GetTensorFriendlyWorkGroupsJit(const DataTensor& t);
std::vector<size_t> GetTensorFriendlyWorkGroups(const DataTensor& t);
std::vector<size_t> GetOptimalLocalWorkGroupSizes(std::vector<size_t> gws);