/*
-// Copyright (c) 2016 Intel Corporation
+// Copyright (c) 2016-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
#include <memory>
#include <cstddef>
#include "common_types.h"
-#include "common_tools.h"
#include "tensor_type.h"
+#include "document.h"
namespace kernel_selector
{
struct val_t
{
uint32_t different_types : 1;
+ uint32_t different_input_weights_types : 1;
uint32_t offset : 1;
uint32_t pitches : 1;
uint32_t batching : 1;
uint32_t fixedKenrelDivider : 1;
uint32_t dynamicKenrelDivider : 1;
uint32_t dynamicKenrelDividerWithPadding : 1;
+ uint32_t position_sensitive : 1;
} pooling;
struct conv_t
{
uint32_t split : 1;
uint32_t dilation : 1;
- uint32_t depthwiseSeparableOpt : 1;
+ uint32_t depthwise_separable_opt : 1;
uint32_t transposed : 1;
uint32_t quantization : 1;
uint32_t calibration : 1;
+ uint32_t local : 1;
+ uint32_t grouped : 1;
} conv;
struct fc_t {} fc;
struct softmax_t
{
uint32_t winograd : 1;
} reorder;
+ struct eltwise_t
+ {
+ uint32_t stride : 1;
+ uint32_t broadcast : 1;
+ } eltwise;
struct lstm_gemm_t {
uint32_t bias : 1;
uint32_t hidden : 1;
struct lstm_elt_t {
uint32_t cell : 1;
} lstm_elt;
+ struct fused_conv_eltw_t {
+ // conv
+ uint32_t split : 1;
+ uint32_t dilation : 1;
+ uint32_t depthwise_separable_opt : 1;
+ uint32_t transposed : 1;
+ uint32_t quantization : 1;
+ uint32_t calibration : 1;
+ uint32_t local : 1;
+ uint32_t grouped : 1;
+ // eltw
+ uint32_t stride : 1;
+ // fused conv eltw
+ uint32_t rw_out_opt : 1;
+ } fused_conv_eltw;
} dedicated;
} val;
uint64_t raw;
void EnableAllOutputWeightsType();
void EnableFP16Emulation() { key.restrict.val.FP16Emulation = 1; }
void EnableDifferentTypes() { key.restrict.val.different_types = 1; }
+ void EnableDifferentInputWeightsTypes() {
+ key.restrict.val.different_input_weights_types = 1; }
void EnableInputLayout(DataLayout l) { key.inputLayout |= (1 << l); }
void EnableAllInputLayout() { key.inputLayout = 0xffffffff; }
void EnableOutputLayout(DataLayout l) { key.outputLayout |= (1 << l); }
void EnablePoolKernelDividerMode(KernelDividerMode m);
void EnablePoolType(PoolType t);
void EnablePoolRemainder(PoolRemainder r);
+ void EnablePositionSensitivePooling() { key.restrict.val.dedicated.pooling.position_sensitive = 1; }
void EnableSplitSupport() { key.restrict.val.dedicated.conv.split = 1; }
void EnableDilation() { key.restrict.val.dedicated.conv.dilation = 1; }
- void EnableDepthwiseSeparableOpt() { key.restrict.val.dedicated.conv.depthwiseSeparableOpt = 1; }
+ void EnableDepthwiseSeparableOpt() { key.restrict.val.dedicated.conv.depthwise_separable_opt = 1; }
+ void EnableLocalConvolution() { key.restrict.val.dedicated.conv.local = 1; }
+ void EnableGroupedConvolution() { key.restrict.val.dedicated.conv.grouped = 1; }
void EnableTranspose() { key.restrict.val.dedicated.conv.transposed = 1; }
void EnableInt8Quantization() { key.restrict.val.dedicated.conv.quantization = 1; }
void EnableOutputCalibration() { key.restrict.val.dedicated.conv.calibration = 1; }
+
+ void EnableFusedConvEltwSplitSupport() { key.restrict.val.dedicated.fused_conv_eltw.split = 1; }
+ void EnableFusedConvEltwDilation() { key.restrict.val.dedicated.fused_conv_eltw.dilation = 1; }
+ void EnableFusedConvEltwDepthwiseSeparableOpt() { key.restrict.val.dedicated.fused_conv_eltw.depthwise_separable_opt = 1; }
+ void EnableFusedConvEltwLocalConvolution() { key.restrict.val.dedicated.fused_conv_eltw.local = 1; }
+ void EnableFusedConvEltwGroupedConvolution() { key.restrict.val.dedicated.fused_conv_eltw.grouped = 1; }
+ void EnableFusedConvEltwTranspose() { key.restrict.val.dedicated.fused_conv_eltw.transposed = 1; }
+ void EnableFusedConvEltwInt8Quantization() { key.restrict.val.dedicated.fused_conv_eltw.quantization = 1; }
+ void EnableFusedConvEltwOutputCalibration() { key.restrict.val.dedicated.fused_conv_eltw.calibration = 1; }
+ void EnableFusedConvEltwEltwiseStride();
+
void EnableWinogradReorder() { key.restrict.val.dedicated.reorder.winograd = 1; }
void EnableSoftmaxDim(SoftmaxDim d);
void EnableConcatAxis(ConcatAxis a);
void EnableUpSamplingSampleType(SampleType a);
+ void EnableEltwiseStride();
+ void EnableEltwiseBroadcast() { key.restrict.val.dedicated.eltwise.broadcast = 1; }
void EnableLSTMGEMMBias() { key.restrict.val.dedicated.lstm_gemm.bias = 1; }
void EnableLSTMGEMMHidden() { key.restrict.val.dedicated.lstm_gemm.hidden = 1; }
void EnableLSTMEltCell() { key.restrict.val.dedicated.lstm_elt.cell = 1; }
void EnableArgMaxMinAxis(ArgMaxMinAxis a);
void EnableLookUpTableIndicesFormat(Datatype a);
void EnableIndexSelectAxis(IndexSelectAxis a);
+ void EnableFusedConvEltwiseRWOutOpt();
bool Support(const ParamsKey& k) const;
bool TuningSupport() const
{
return true;
return false;
}
+ bool isEnabledDifferentInputWeightsTypes() const {
+ return key.restrict.val.different_input_weights_types ? true : false;
+ }
ParamsKey Merge(const ParamsKey& k) const;
private:
bool bImageSupport = false;
bool bIMADSupport = false;
bool bIMMADSupport = false;
+ uint32_t computeUnitsCount = 0;
uint64_t maxWorkGroupSize = 0;
uint64_t maxLocalMemSize = 0;
uint64_t maxImage2dWidth = 0;
std::string deviceId = "";
std::string driverVersion = "";
std::string hostVersion = "";
+ std::shared_ptr<rapidjson::Document> deviceCache;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ // base_activation_params
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ struct base_activation_params
+ {
+ ActivationFunction function = ActivationFunction::NONE;
+ float m = 1.f;
+ float n = 0.f;
+
+ base_activation_params() = default;
+ base_activation_params(const float m, const float n) : m(m), n(n) {}
+
+ virtual std::string to_string() const;
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// base_params
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct base_params : public Params
{
virtual ~base_params() {}
- ActivationFunction activationFunc = ActivationFunction::NONE;
- NonLinearParams activationParams;
- MultiDataTensor inputs;
- DataTensor output;
- bool gradient = false;
+ base_activation_params activation;
+ MultiDataTensor inputs;
+ DataTensor output;
+ bool gradient = false;
virtual std::string to_string() const;
virtual ParamsKey GetParamsKey() const;