Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / common / jitter.cpp
index 1a426a0..0f1cf4d 100644 (file)
@@ -1,5 +1,5 @@
 /*
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2019 Intel Corporation
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
 */
 
 #include "jitter.h"
+#include "tensor_type.h"
 
 namespace kernel_selector {
 
@@ -23,6 +24,7 @@ namespace kernel_selector {
         switch (wType)
         {
         case WeightsType::INT8: return GetTypeName<int8_t>();
+        case WeightsType::UINT8: return GetTypeName<uint8_t>();
         case WeightsType::F16:  return "half";
         case WeightsType::F32:  return GetTypeName<float>();
         default: return "";
@@ -58,6 +60,28 @@ namespace kernel_selector {
         }
     }
 
+    std::string toCodeString(float val) {
+        if (std::isinf(val))
+            return std::signbit(val) ? "-INFINITY" : "INFINITY";
+        std::stringstream ss;
+        // Workaround GCC compiler/STL bug
+        ss << "as_float(0x" << std::hex << *reinterpret_cast<uint32_t*>(&val) << ")";
+
+        ss << " /*" << std::scientific << val << "*/";
+        return ss.str();
+    }
+
+    std::string toCodeString(double val) {
+        if (std::isinf(val))
+            return std::signbit(val) ? "-INFINITY" : "INFINITY";
+        std::stringstream ss;
+        // Workaround GCC compiler/STL bug
+        ss << "as_double(0x" << std::hex << *reinterpret_cast<uint64_t*>(&val) << ")";
+
+        ss << " /*" << std::scientific << val << "*/";
+        return ss.str();
+    }
+
     JitDefinitions JitConstants::GetDefinitions() const
     {
         JitDefinitions definitons;
@@ -70,6 +94,53 @@ namespace kernel_selector {
         return definitons;
     }
 
+    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+    // TensorBaseTJitConstant
+    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+    template<typename DType, typename Layout>
+    class TensorBaseTJitConstant : public JitConstant
+    {
+    protected:
+        TensorBaseTJitConstant(const std::string& name) : JitConstant(name) {}
+
+    public:
+
+        JitDefinitions GetDefinitions(const Tensor::TensorBaseT<DType, Layout>& t) const
+        {
+            JitDefinitions definitions{
+                { _name + "_TYPE",          toCLType(t.GetDType()) },
+            { _name + "_OFFSET",        toCodeString(t.GetFirstElementOffset()) },
+            { _name + "_VIEW_OFFSET",   toCodeString(t.GetViewOffset()) },
+            { _name + "_LENGTH",        toCodeString(t.LogicalSize()) },
+            { _name + "_DIMS",          toCodeString(t.GetDims().size()) },
+            { _name + "_SIMPLE",        toCodeString(t.SimpleLayout()) },
+            { "TO_" + _name + "_TYPE",  "convert_" + toCLType(t.GetDType()) },
+            { _name + "_LAYOUT_" + toString(t.GetLayout()), "1" },
+            };
+
+            definitions.push_back({ _name + "_SIZE",        toCodeString(t.GetDims().size()) });
+            definitions.push_back({ _name + "_SIZES",       toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 1, [](const Tensor::Dim& d) { return d.v; }) });
+            definitions.push_back({ _name + "_PITCHES",     toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 1, [](const Tensor::Dim& d) { return d.pitch; }) });
+            definitions.push_back({ _name + "_PAD_BEFORE",  toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 0, [](const Tensor::Dim& d) { return d.pad.before; }) });
+            definitions.push_back({ _name + "_PAD_AFTER",   toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 0, [](const Tensor::Dim& d) { return d.pad.after; }) });
+
+            return definitions;
+        }
+    };
+
+    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+    // DataTensorJitConstant
+    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+    class DataTensorJitConstant : public TensorBaseTJitConstant<Datatype, DataLayout>
+    {
+        const DataTensor _tensor;
+
+    public:
+        DataTensorJitConstant(const std::string& name, const DataTensor& t) : TensorBaseTJitConstant(name), _tensor(t) {}
+
+        JitDefinitions GetDefinitions() const override;
+    };
+
     JitDefinitions DataTensorJitConstant::GetDefinitions() const
     {
         JitDefinitions baseDefinitions = TensorBaseTJitConstant::GetDefinitions(_tensor);
@@ -100,19 +171,37 @@ namespace kernel_selector {
         return definitions;
     }
 
+    std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const DataTensor& value)
+    {
+        return std::static_pointer_cast<JitConstant>(std::make_shared<DataTensorJitConstant>(name, value));
+    }
+
+    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+    // WeightTensorJitConstant
+    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+    class WeightTensorJitConstant : public TensorBaseTJitConstant<WeightsType, WeightsLayout>
+    {
+        const WeightsTensor _tensor;
+
+    public:
+        WeightTensorJitConstant(const std::string& name, const WeightsTensor& t) : TensorBaseTJitConstant(name), _tensor(t) {}
+
+        JitDefinitions GetDefinitions() const override;
+    };
+
     JitDefinitions WeightTensorJitConstant::GetDefinitions() const
     {
         JitDefinitions baseDefinitions = TensorBaseTJitConstant::GetDefinitions(_tensor);
 
         JitDefinitions definitions{
-        { _name + "_SIZE_X",        toCodeString(_tensor.X().v) },
-        { _name + "_SIZE_Y",        toCodeString(_tensor.Y().v) },
-        { _name + "_IFM_NUM",       toCodeString(_tensor.IFM().v) },
-        { _name + "_OFM_NUM",       toCodeString(_tensor.OFM().v) },
-        { _name + "_X_PITCH",       toCodeString(_tensor.X().pitch) },
-        { _name + "_Y_PITCH",       toCodeString(_tensor.Y().pitch) },
-        { _name + "_IFM_PITCH",     toCodeString(_tensor.IFM().pitch) },
-        { _name + "_OFM_PITCH",     toCodeString(_tensor.OFM().pitch) },
+            { _name + "_SIZE_X",        toCodeString(_tensor.X().v) },
+            { _name + "_SIZE_Y",        toCodeString(_tensor.Y().v) },
+            { _name + "_IFM_NUM",       toCodeString(_tensor.IFM().v) },
+            { _name + "_OFM_NUM",       toCodeString(_tensor.OFM().v) },
+            { _name + "_X_PITCH",       toCodeString(_tensor.X().pitch) },
+            { _name + "_Y_PITCH",       toCodeString(_tensor.Y().pitch) },
+            { _name + "_IFM_PITCH",     toCodeString(_tensor.IFM().pitch) },
+            { _name + "_OFM_PITCH",     toCodeString(_tensor.OFM().pitch) },
         };
 
         definitions.insert(definitions.end(), baseDefinitions.begin(), baseDefinitions.end());
@@ -120,63 +209,71 @@ namespace kernel_selector {
         return definitions;
     }
 
-    std::shared_ptr<JitConstant> MakeActivationJitConstants(ActivationFunction activation_function)
+    std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const WeightsTensor& value)
     {
+        return std::static_pointer_cast<JitConstant>(std::make_shared<WeightTensorJitConstant>(name, value));
+    }
+
+    std::shared_ptr<JitConstant> MakeActivationJitConstants(ActivationFunction activation_function, const std::string& suffix)
+    {
+        std::string name = "ACTIVATION" + suffix;
         // TODO: use native_exp and use cast for APL
         switch (activation_function)
         {
         case ActivationFunction::LOGISTIC:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(UNIT_VAL_ONE/(UNIT_VAL_ONE + exp(-input)))");
+            return MakeJitConstant(name + "(input, m, n)", "(UNIT_VAL_ONE/(UNIT_VAL_ONE + exp(-input)))");
         case ActivationFunction::HYPERBOLIC_TAN:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(tanh(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(tanh(input))");
         case ActivationFunction::RELU:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(UNIT_MAX_FUNC(UNIT_VAL_ZERO, input))");
+            return MakeJitConstant(name + "(input, m, n)", "(UNIT_MAX_FUNC(UNIT_VAL_ZERO, input))");
         case ActivationFunction::RELU_NEGATIVE_SLOPE:
-            return MakeJitConstant("ACTIVATION(input, slope, n)", "isinf(TO_UNIT_TYPE(slope)) ? ((input >= UNIT_VAL_ZERO) ? \
+            return MakeJitConstant(name + "(input, slope, n)", "isinf(TO_UNIT_TYPE(slope)) ? ((input >= UNIT_VAL_ZERO) ? \
                                                         input : -TO_UNIT_TYPE(slope)) : \
                                                         (UNIT_MAX_FUNC(input, UNIT_VAL_ZERO) + TO_UNIT_TYPE(slope) * UNIT_MIN_FUNC(input, UNIT_VAL_ZERO))");
         case ActivationFunction::ELU:
-            return MakeJitConstant("ACTIVATION(input, alpha, n)", "(UNIT_MAX_FUNC(input, UNIT_VAL_ZERO) +  \
+            return MakeJitConstant(name + "(input, alpha, n)", "(UNIT_MAX_FUNC(input, UNIT_VAL_ZERO) +  \
                                                         TO_UNIT_TYPE(alpha) * (exp(UNIT_MIN_FUNC(input, UNIT_VAL_ZERO)) - UNIT_VAL_ONE));");
         case ActivationFunction::CLAMP:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(UNIT_MAX_FUNC(TO_UNIT_TYPE(m), UNIT_MIN_FUNC(TO_UNIT_TYPE(n), input)))");
+            return MakeJitConstant(name + "(input, m, n)", "(UNIT_MAX_FUNC(TO_UNIT_TYPE(m), UNIT_MIN_FUNC(TO_UNIT_TYPE(n), input)))");
         case ActivationFunction::SOFTRELU:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(log(UNIT_VAL_ONE + exp(input)))");
+            return MakeJitConstant(name + "(input, m, n)", "(log(UNIT_VAL_ONE + exp(input)))");
         case ActivationFunction::ABS:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(fabs(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(fabs(input))");
         case ActivationFunction::LINEAR:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(m*input + n)");
+            return MakeJitConstant(name + "(input, m, n)", "(m*input + n)");
         case ActivationFunction::SQUARE:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(input*input)");
+            return MakeJitConstant(name + "(input, m, n)", "(input*input)");
         case ActivationFunction::SQRT:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(sqrt(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(sqrt(input))");
         case ActivationFunction::SIN:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(sin(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(sin(input))");
         case ActivationFunction::ASIN:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(asin(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(asin(input))");
         case ActivationFunction::SINH:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(sinh(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(sinh(input))");
         case ActivationFunction::COS:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(cos(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(cos(input))");
         case ActivationFunction::ACOS:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(acos(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(acos(input))");
         case ActivationFunction::COSH:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(cosh(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(cosh(input))");
         case ActivationFunction::LOG:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(log(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(log(input))");
         case ActivationFunction::LOG2:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(log2(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(log2(input))");
         case ActivationFunction::EXP:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "(exp(input))");
+            return MakeJitConstant(name + "(input, m, n)", "(exp(input))");
+        case ActivationFunction::NOT:
+            return MakeJitConstant(name + "(input, m, n)", "((input != 0) ? UNIT_VAL_ZERO : UNIT_VAL_ONE)");
         case ActivationFunction::RELU_GRAD:
-            return MakeJitConstant("ACTIVATION(input_grad, input, m, n)", "(input_grad * (input > UNIT_VAL_ZERO ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0)))");
+            return MakeJitConstant(name + "(input_grad, input, m, n)", "(input_grad * (input > UNIT_VAL_ZERO ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0)))");
         case ActivationFunction::RELU_NEGATIVE_SLOPE_GRAD:
-            return MakeJitConstant("ACTIVATION(input_grad, input, slope, n)", "(input_grad * ((input > UNIT_VAL_ZERO ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0)) + TO_UNIT_TYPE(slope) * (input <= 0 ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0))))");
+            return MakeJitConstant(name + "(input_grad, input, slope, n)", "(input_grad * ((input > UNIT_VAL_ZERO ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0)) + TO_UNIT_TYPE(slope) * (input <= 0 ? TO_UNIT_TYPE(1) : TO_UNIT_TYPE(0))))");
         case ActivationFunction::NONE_GRAD:
-            return MakeJitConstant("ACTIVATION(input_grad, input, m, n)", "input_grad");
+            return MakeJitConstant(name + "(input_grad, input, m, n)", "input_grad");
         case ActivationFunction::NONE:
         default:
-            return MakeJitConstant("ACTIVATION(input, m, n)", "input");
+            return MakeJitConstant(name + "(input, m, n)", "input");
         }
     }
 
@@ -195,27 +292,47 @@ namespace kernel_selector {
         case Datatype::INT8:
             unit_type = "char";
             unit_max_val = "CHAR_MAX";
-            unit_min_val = "-UNIT_VAL_MAX";
+            unit_min_val = "CHAR_MIN";
             unit_val_one = "(char) 1";
             unit_val_zero = "(char) 0";
             to_unit_type = "convert_char(v)";
             unit_max_func = "max";
             unit_min_func = "min";
             break;
+        case Datatype::UINT8:
+            unit_type = "uchar";
+            unit_max_val = "UCHAR_MAX";
+            unit_min_val = "0";
+            unit_val_one = "(uchar) 1";
+            unit_val_zero = "(uchar) 0";
+            to_unit_type = "convert_uchar(v)";
+            unit_max_func = "max";
+            unit_min_func = "min";
+            break;
         case Datatype::INT32:
             unit_type = "int";
             unit_max_val = "INT_MAX";
-            unit_min_val = "-UNIT_VAL_MAX";
+            unit_min_val = "INT_MIN";
             unit_val_one = "(int) 1";
             unit_val_zero = "(int) 0";
             to_unit_type = "convert_int(v)";
             unit_max_func = "max";
             unit_min_func = "min";
             break;
+        case Datatype::UINT32:
+            unit_type = "uint";
+            unit_max_val = "UINT_MAX";
+            unit_min_val = "0";
+            unit_val_one = "(uint) 1";
+            unit_val_zero = "(uint) 0";
+            to_unit_type = "convert_uint(v)";
+            unit_max_func = "max";
+            unit_min_func = "min";
+            break;
         case Datatype::INT64:
             unit_type = "long";
             unit_max_val = "LONG_MAX";
-            unit_min_val = "-UNIT_VAL_MAX";
+            unit_min_val = "LONG_MIN";
             unit_val_one = "(long) 1";
             unit_val_zero = "(long) 0";
             to_unit_type = "convert_long(v)";
@@ -256,6 +373,16 @@ namespace kernel_selector {
             MakeJitConstant("UNIT_MIN_FUNC",        unit_min_func),
         };
     }
+
+    JitConstants MakeActivationJitConstants(const base_activation_params& params, const std::string& suffix)
+    {
+        return JitConstants{
+            MakeJitConstant("NL_M" + suffix, params.m),
+            MakeJitConstant("NL_N" + suffix, params.n),
+            MakeActivationJitConstants(params.function, suffix)
+        };
+    }
+
     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
     // MakeBaseParamsJitConstants
     ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -265,12 +392,16 @@ namespace kernel_selector {
         bool bInt8Used = params.output.GetDType() == Datatype::INT8;
         bool bInt32Used = params.output.GetDType() == Datatype::INT32;
         bool bInt64Used = params.output.GetDType() == Datatype::INT64;
+        bool bUInt8Used = params.output.GetDType() == Datatype::UINT8;
+        bool bUInt32Used = params.output.GetDType() == Datatype::INT32;
         for (const auto& i : params.inputs)
         {
             bFP16Used |= i.GetDType() == Datatype::F16;
             bInt8Used |= i.GetDType() == Datatype::INT8;
             bInt32Used |= i.GetDType() == Datatype::INT32;
             bInt64Used |= i.GetDType() == Datatype::INT64;
+            bUInt8Used |= i.GetDType() == Datatype::UINT8;
+            bUInt32Used |= i.GetDType() == Datatype::UINT32;
         }
 
         JitConstants jit{
@@ -281,16 +412,11 @@ namespace kernel_selector {
             MakeJitConstant("INT8_UNIT_USED",       bInt8Used),
             MakeJitConstant("INT32_UNIT_USED",      bInt32Used),
             MakeJitConstant("INT64_UNIT_USED",      bInt64Used),
+            MakeJitConstant("UINT8_UNIT_USED",      bUInt8Used),
+            MakeJitConstant("UINT32_UNIT_USED",     bUInt32Used),
             MakeJitConstant("GRADIENT",             params.gradient),
         };
 
-        // for activation function
-        jit.AddConstants({
-            MakeJitConstant("NL_M",                 params.activationParams.m),
-            MakeJitConstant("NL_N",                 params.activationParams.n),
-            MakeActivationJitConstants(params.activationFunc),
-            });
-
         if (bInt8Used)
         {
             jit.Merge(MakeUnitTypeJitConstants(Datatype::INT8));
@@ -307,11 +433,22 @@ namespace kernel_selector {
         {
             jit.Merge(MakeUnitTypeJitConstants(Datatype::INT64));
         }
+        else if (bUInt8Used)
+        {
+            jit.Merge(MakeUnitTypeJitConstants(Datatype::UINT8));
+        }
+        else if (bUInt32Used)
+        {
+            jit.Merge(MakeUnitTypeJitConstants(Datatype::UINT32));
+        }
         else
         {
             jit.Merge(MakeUnitTypeJitConstants(Datatype::F32));
         }
 
+        // for activation function
+        jit.Merge(MakeActivationJitConstants(params.activation));
+
         for (size_t i = 0; i < params.inputs.size(); i++)
         {
             jit.AddConstant(MakeJitConstant("INPUT" + toCodeString(i), params.inputs[i]));
@@ -344,4 +481,4 @@ namespace kernel_selector {
         return jit;
     }
 
-}
\ No newline at end of file
+}