/*
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
*/
#include "jitter.h"
+#include "tensor_type.h"
namespace kernel_selector {
switch (wType)
{
case WeightsType::INT8: return GetTypeName<int8_t>();
+ case WeightsType::UINT8: return GetTypeName<uint8_t>();
case WeightsType::F16: return "half";
case WeightsType::F32: return GetTypeName<float>();
default: return "";
}
}
+ std::string toCodeString(float val) {
+ if (std::isinf(val))
+ return std::signbit(val) ? "-INFINITY" : "INFINITY";
+ std::stringstream ss;
+ // Workaround GCC compiler/STL bug
+ ss << "as_float(0x" << std::hex << *reinterpret_cast<uint32_t*>(&val) << ")";
+
+ ss << " /*" << std::scientific << val << "*/";
+ return ss.str();
+ }
+
+ std::string toCodeString(double val) {
+ if (std::isinf(val))
+ return std::signbit(val) ? "-INFINITY" : "INFINITY";
+ std::stringstream ss;
+ // Workaround GCC compiler/STL bug
+ ss << "as_double(0x" << std::hex << *reinterpret_cast<uint64_t*>(&val) << ")";
+
+ ss << " /*" << std::scientific << val << "*/";
+ return ss.str();
+ }
+
JitDefinitions JitConstants::GetDefinitions() const
{
JitDefinitions definitons;
return definitons;
}
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ // TensorBaseTJitConstant
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ template<typename DType, typename Layout>
+ class TensorBaseTJitConstant : public JitConstant
+ {
+ protected:
+ TensorBaseTJitConstant(const std::string& name) : JitConstant(name) {}
+
+ public:
+
+ JitDefinitions GetDefinitions(const Tensor::TensorBaseT<DType, Layout>& t) const
+ {
+ JitDefinitions definitions{
+ { _name + "_TYPE", toCLType(t.GetDType()) },
+ { _name + "_OFFSET", toCodeString(t.GetFirstElementOffset()) },
+ { _name + "_VIEW_OFFSET", toCodeString(t.GetViewOffset()) },
+ { _name + "_LENGTH", toCodeString(t.LogicalSize()) },
+ { _name + "_DIMS", toCodeString(t.GetDims().size()) },
+ { _name + "_SIMPLE", toCodeString(t.SimpleLayout()) },
+ { "TO_" + _name + "_TYPE", "convert_" + toCLType(t.GetDType()) },
+ { _name + "_LAYOUT_" + toString(t.GetLayout()), "1" },
+ };
+
+ definitions.push_back({ _name + "_SIZE", toCodeString(t.GetDims().size()) });
+ definitions.push_back({ _name + "_SIZES", toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 1, [](const Tensor::Dim& d) { return d.v; }) });
+ definitions.push_back({ _name + "_PITCHES", toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 1, [](const Tensor::Dim& d) { return d.pitch; }) });
+ definitions.push_back({ _name + "_PAD_BEFORE", toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 0, [](const Tensor::Dim& d) { return d.pad.before; }) });
+ definitions.push_back({ _name + "_PAD_AFTER", toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 0, [](const Tensor::Dim& d) { return d.pad.after; }) });
+
+ return definitions;
+ }
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ // DataTensorJitConstant
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ class DataTensorJitConstant : public TensorBaseTJitConstant<Datatype, DataLayout>
+ {
+ const DataTensor _tensor;
+
+ public:
+ DataTensorJitConstant(const std::string& name, const DataTensor& t) : TensorBaseTJitConstant(name), _tensor(t) {}
+
+ JitDefinitions GetDefinitions() const override;
+ };
+
JitDefinitions DataTensorJitConstant::GetDefinitions() const
{
JitDefinitions baseDefinitions = TensorBaseTJitConstant::GetDefinitions(_tensor);
return definitions;
}
+ std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const DataTensor& value)
+ {
+ return std::static_pointer_cast<JitConstant>(std::make_shared<DataTensorJitConstant>(name, value));
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ // WeightTensorJitConstant
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+ class WeightTensorJitConstant : public TensorBaseTJitConstant<WeightsType, WeightsLayout>
+ {
+ const WeightsTensor _tensor;
+
+ public:
+ WeightTensorJitConstant(const std::string& name, const WeightsTensor& t) : TensorBaseTJitConstant(name), _tensor(t) {}
+
+ JitDefinitions GetDefinitions() const override;
+ };
+
JitDefinitions WeightTensorJitConstant::GetDefinitions() const
{
JitDefinitions baseDefinitions = TensorBaseTJitConstant::GetDefinitions(_tensor);
JitDefinitions definitions{
- { _name + "_SIZE_X", toCodeString(_tensor.X().v) },
- { _name + "_SIZE_Y", toCodeString(_tensor.Y().v) },
- { _name + "_IFM_NUM", toCodeString(_tensor.IFM().v) },
- { _name + "_OFM_NUM", toCodeString(_tensor.OFM().v) },
- { _name + "_X_PITCH", toCodeString(_tensor.X().pitch) },
- { _name + "_Y_PITCH", toCodeString(_tensor.Y().pitch) },
- { _name + "_IFM_PITCH", toCodeString(_tensor.IFM().pitch) },
- { _name + "_OFM_PITCH", toCodeString(_tensor.OFM().pitch) },
+ { _name + "_SIZE_X", toCodeString(_tensor.X().v) },
+ { _name + "_SIZE_Y", toCodeString(_tensor.Y().v) },
+ { _name + "_IFM_NUM", toCodeString(_tensor.IFM().v) },
+ { _name + "_OFM_NUM", toCodeString(_tensor.OFM().v) },
+ { _name + "_X_PITCH", toCodeString(_tensor.X().pitch) },
+ { _name + "_Y_PITCH", toCodeString(_tensor.Y().pitch) },
+ { _name + "_IFM_PITCH", toCodeString(_tensor.IFM().pitch) },
+ { _name + "_OFM_PITCH", toCodeString(_tensor.OFM().pitch) },
};
definitions.insert(definitions.end(), baseDefinitions.begin(), baseDefinitions.end());
return definitions;
}
- std::shared_ptr<JitConstant> MakeActivationJitConstants(ActivationFunction activation_function)
+ std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const WeightsTensor& value)
{
+ return std::static_pointer_cast<JitConstant>(std::make_shared<WeightTensorJitConstant>(name, value));
+ }
+
+ std::shared_ptr<JitConstant> MakeActivationJitConstants(ActivationFunction activation_function, const std::string& suffix)
+ {
+ std::string name = "ACTIVATION" + suffix;
// TODO: use native_exp and use cast for APL
switch (activation_function)
{
case ActivationFunction::LOGISTIC:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(UNIT_VAL_ONE/(UNIT_VAL_ONE + exp(-input)))");
+ return MakeJitConstant(name + "(input, m, n)", "(UNIT_VAL_ONE/(UNIT_VAL_ONE + exp(-input)))");
case ActivationFunction::HYPERBOLIC_TAN:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(tanh(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(tanh(input))");
case ActivationFunction::RELU:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(UNIT_MAX_FUNC(UNIT_VAL_ZERO, input))");
+ return MakeJitConstant(name + "(input, m, n)", "(UNIT_MAX_FUNC(UNIT_VAL_ZERO, input))");
case ActivationFunction::RELU_NEGATIVE_SLOPE:
- return MakeJitConstant("ACTIVATION(input, slope, n)", "isinf(TO_UNIT_TYPE(slope)) ? ((input >= UNIT_VAL_ZERO) ? \
+ return MakeJitConstant(name + "(input, slope, n)", "isinf(TO_UNIT_TYPE(slope)) ? ((input >= UNIT_VAL_ZERO) ? \
input : -TO_UNIT_TYPE(slope)) : \
(UNIT_MAX_FUNC(input, UNIT_VAL_ZERO) + TO_UNIT_TYPE(slope) * UNIT_MIN_FUNC(input, UNIT_VAL_ZERO))");
case ActivationFunction::ELU:
- return MakeJitConstant("ACTIVATION(input, alpha, n)", "(UNIT_MAX_FUNC(input, UNIT_VAL_ZERO) + \
+ return MakeJitConstant(name + "(input, alpha, n)", "(UNIT_MAX_FUNC(input, UNIT_VAL_ZERO) + \
TO_UNIT_TYPE(alpha) * (exp(UNIT_MIN_FUNC(input, UNIT_VAL_ZERO)) - UNIT_VAL_ONE));");
case ActivationFunction::CLAMP:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(UNIT_MAX_FUNC(TO_UNIT_TYPE(m), UNIT_MIN_FUNC(TO_UNIT_TYPE(n), input)))");
+ return MakeJitConstant(name + "(input, m, n)", "(UNIT_MAX_FUNC(TO_UNIT_TYPE(m), UNIT_MIN_FUNC(TO_UNIT_TYPE(n), input)))");
case ActivationFunction::SOFTRELU:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(log(UNIT_VAL_ONE + exp(input)))");
+ return MakeJitConstant(name + "(input, m, n)", "(log(UNIT_VAL_ONE + exp(input)))");
case ActivationFunction::ABS:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(fabs(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(fabs(input))");
case ActivationFunction::LINEAR:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(m*input + n)");
+ return MakeJitConstant(name + "(input, m, n)", "(m*input + n)");
case ActivationFunction::SQUARE:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(input*input)");
+ return MakeJitConstant(name + "(input, m, n)", "(input*input)");
case ActivationFunction::SQRT:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(sqrt(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(sqrt(input))");
case ActivationFunction::SIN:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(sin(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(sin(input))");
case ActivationFunction::ASIN:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(asin(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(asin(input))");
case ActivationFunction::SINH:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(sinh(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(sinh(input))");
case ActivationFunction::COS:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(cos(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(cos(input))");
case ActivationFunction::ACOS:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(acos(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(acos(input))");
case ActivationFunction::COSH:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(cosh(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(cosh(input))");
case ActivationFunction::LOG:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(log(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(log(input))");
case ActivationFunction::LOG2:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(log2(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(log2(input))");
case ActivationFunction::EXP:
- return MakeJitConstant("ACTIVATION(input, m, n)", "(exp(input))");
+ return MakeJitConstant(name + "(input, m, n)", "(exp(input))");
+ case ActivationFunction::NOT:
+ return MakeJitConstant(name + "(input, m, n)", "((input != 0) ? UNIT_VAL_ZERO : UNIT_VAL_ONE)");
case ActivationFunction::RELU_GRAD:
- return MakeJitConstant("ACTIVATION(input_grad, input, m, n)", "(input_grad * (input > UNIT_VAL_ZERO ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0)))");
+ return MakeJitConstant(name + "(input_grad, input, m, n)", "(input_grad * (input > UNIT_VAL_ZERO ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0)))");
case ActivationFunction::RELU_NEGATIVE_SLOPE_GRAD:
- return MakeJitConstant("ACTIVATION(input_grad, input, slope, n)", "(input_grad * ((input > UNIT_VAL_ZERO ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0)) + TO_UNIT_TYPE(slope) * (input <= 0 ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0))))");
+ return MakeJitConstant(name + "(input_grad, input, slope, n)", "(input_grad * ((input > UNIT_VAL_ZERO ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0)) + TO_UNIT_TYPE(slope) * (input <= 0 ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0))))");
case ActivationFunction::NONE_GRAD:
- return MakeJitConstant("ACTIVATION(input_grad, input, m, n)", "input_grad");
+ return MakeJitConstant(name + "(input_grad, input, m, n)", "input_grad");
case ActivationFunction::NONE:
default:
- return MakeJitConstant("ACTIVATION(input, m, n)", "input");
+ return MakeJitConstant(name + "(input, m, n)", "input");
}
}
case Datatype::INT8:
unit_type = "char";
unit_max_val = "CHAR_MAX";
- unit_min_val = "-UNIT_VAL_MAX";
+ unit_min_val = "CHAR_MIN";
unit_val_one = "(char) 1";
unit_val_zero = "(char) 0";
to_unit_type = "convert_char(v)";
unit_max_func = "max";
unit_min_func = "min";
break;
+ case Datatype::UINT8:
+ unit_type = "uchar";
+ unit_max_val = "UCHAR_MAX";
+ unit_min_val = "0";
+ unit_val_one = "(uchar) 1";
+ unit_val_zero = "(uchar) 0";
+ to_unit_type = "convert_uchar(v)";
+ unit_max_func = "max";
+ unit_min_func = "min";
+ break;
case Datatype::INT32:
unit_type = "int";
unit_max_val = "INT_MAX";
- unit_min_val = "-UNIT_VAL_MAX";
+ unit_min_val = "INT_MIN";
unit_val_one = "(int) 1";
unit_val_zero = "(int) 0";
to_unit_type = "convert_int(v)";
unit_max_func = "max";
unit_min_func = "min";
break;
+ case Datatype::UINT32:
+ unit_type = "uint";
+ unit_max_val = "UINT_MAX";
+ unit_min_val = "0";
+ unit_val_one = "(uint) 1";
+ unit_val_zero = "(uint) 0";
+ to_unit_type = "convert_uint(v)";
+ unit_max_func = "max";
+ unit_min_func = "min";
+ break;
case Datatype::INT64:
unit_type = "long";
unit_max_val = "LONG_MAX";
- unit_min_val = "-UNIT_VAL_MAX";
+ unit_min_val = "LONG_MIN";
unit_val_one = "(long) 1";
unit_val_zero = "(long) 0";
to_unit_type = "convert_long(v)";
MakeJitConstant("UNIT_MIN_FUNC", unit_min_func),
};
}
+
+ JitConstants MakeActivationJitConstants(const base_activation_params& params, const std::string& suffix)
+ {
+ return JitConstants{
+ MakeJitConstant("NL_M" + suffix, params.m),
+ MakeJitConstant("NL_N" + suffix, params.n),
+ MakeActivationJitConstants(params.function, suffix)
+ };
+ }
+
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// MakeBaseParamsJitConstants
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
bool bInt8Used = params.output.GetDType() == Datatype::INT8;
bool bInt32Used = params.output.GetDType() == Datatype::INT32;
bool bInt64Used = params.output.GetDType() == Datatype::INT64;
+ bool bUInt8Used = params.output.GetDType() == Datatype::UINT8;
+ bool bUInt32Used = params.output.GetDType() == Datatype::INT32;
for (const auto& i : params.inputs)
{
bFP16Used |= i.GetDType() == Datatype::F16;
bInt8Used |= i.GetDType() == Datatype::INT8;
bInt32Used |= i.GetDType() == Datatype::INT32;
bInt64Used |= i.GetDType() == Datatype::INT64;
+ bUInt8Used |= i.GetDType() == Datatype::UINT8;
+ bUInt32Used |= i.GetDType() == Datatype::UINT32;
}
JitConstants jit{
MakeJitConstant("INT8_UNIT_USED", bInt8Used),
MakeJitConstant("INT32_UNIT_USED", bInt32Used),
MakeJitConstant("INT64_UNIT_USED", bInt64Used),
+ MakeJitConstant("UINT8_UNIT_USED", bUInt8Used),
+ MakeJitConstant("UINT32_UNIT_USED", bUInt32Used),
MakeJitConstant("GRADIENT", params.gradient),
};
- // for activation function
- jit.AddConstants({
- MakeJitConstant("NL_M", params.activationParams.m),
- MakeJitConstant("NL_N", params.activationParams.n),
- MakeActivationJitConstants(params.activationFunc),
- });
-
if (bInt8Used)
{
jit.Merge(MakeUnitTypeJitConstants(Datatype::INT8));
{
jit.Merge(MakeUnitTypeJitConstants(Datatype::INT64));
}
+ else if (bUInt8Used)
+ {
+ jit.Merge(MakeUnitTypeJitConstants(Datatype::UINT8));
+ }
+ else if (bUInt32Used)
+ {
+ jit.Merge(MakeUnitTypeJitConstants(Datatype::UINT32));
+ }
else
{
jit.Merge(MakeUnitTypeJitConstants(Datatype::F32));
}
+ // for activation function
+ jit.Merge(MakeActivationJitConstants(params.activation));
+
for (size_t i = 0; i < params.inputs.size(); i++)
{
jit.AddConstant(MakeJitConstant("INPUT" + toCodeString(i), params.inputs[i]));
return jit;
}
-}
\ No newline at end of file
+}