[IE CLDNN] Added Mish operation (#1125)
author Roman Lyamin <Roman.Lyamin@intel.com>
Thu, 9 Jul 2020 13:57:59 +0000 (16:57 +0300)
committer GitHub <noreply@github.com>
Thu, 9 Jul 2020 13:57:59 +0000 (16:57 +0300)
inference-engine/src/cldnn_engine/cldnn_program.cpp
inference-engine/src/cldnn_engine/cldnn_program.h
inference-engine/thirdparty/clDNN/api/activation.hpp
inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h
inference-engine/thirdparty/clDNN/kernel_selector/core/common/jitter.cpp
inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp
inference-engine/thirdparty/clDNN/src/kernel_selector_helper.cpp
inference-engine/thirdparty/clDNN/tests/test_cases/activation_simple_gpu_test.cpp

inference-engine/src/cldnn_engine/cldnn_program.cpp
index 72146b468c752821fa09f24ccab19ff682c8af21..affd81a26cb73e4b8538c7719d0778b06432b265 100644
@@ -575,6 +575,7 @@ Program::LayerType Program::LayerTypeFromStr(const std::string &str) {
         { "Sinh" , Sinh },
         { "Cosh" , Cosh },
         { "Swish" , Swish },
+        { "Mish" , Mish },
         { "Gelu" , Gelu },
         { "Atanh" , Atanh },
         { "Floor" , Floor },
@@ -1159,6 +1160,7 @@ void Program::CreateSingleLayerPrimitive(cldnn::topology& topology, InferenceEng
         case SoftPlus:
         case SoftSign:
         case Swish:
+        case Mish:
         case Gelu:
             CreateActivationPrimitive(topology, layer, LayerTypeFromStr(layer->type));
             break;
@@ -2767,6 +2769,8 @@ void Program::CreateActivationPrimitive(cldnn::topology& topology, InferenceEngi
             activationType = ELU;
         } else if (activation_type == "swish")  {
             activationType = Swish;
+        } else if (activation_type == "mish")  {
+            activationType = Mish;
         } else if (activation_type == "gelu")  {
             activationType = Gelu;
         } else if (activation_type == "relu")  {
@@ -2957,6 +2961,11 @@ void Program::CreateActivationPrimitive(cldnn::topology& topology, InferenceEngi
         func = cldnn::activation_func::swish;
         break;
     }
+    case Mish:
+    {
+        func = cldnn::activation_func::mish;
+        break;
+    }
     case Gelu:
     {
         func = cldnn::activation_func::gelu;
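
The three hunks above in cldnn_program.cpp wire the new operation through the plugin: the IR layer name "Mish" is added to the LayerType string map, the dispatch switch in CreateSingleLayerPrimitive routes it to CreateActivationPrimitive, and that function finally selects cldnn::activation_func::mish. Below is a minimal sketch of the clDNN call the plugin ends up issuing for such a layer; the primitive ids are invented for illustration, and the include paths follow this clDNN tree but may differ by build layout.

```cpp
#include <api/activation.hpp>
#include <api/topology.hpp>

// Hypothetical ids; in the plugin they come from the IR layer and its input.
void add_mish_primitive(cldnn::topology& topology) {
    topology.add(cldnn::activation("mish_layer", "mish_layer_input",
                                   cldnn::activation_func::mish));
}
```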
inference-engine/src/cldnn_engine/cldnn_program.h
index 2d3f8cf0f6d2d2137de367de520df3f350d17b0f..4338852ca42749a962799fe957a11d86a154066a 100644
@@ -201,6 +201,7 @@ public:
         SoftPlus,
         SoftSign,
         Swish,
+        Mish,
         Gelu,
         Sin,
         Sinh,
inference-engine/thirdparty/clDNN/api/activation.hpp
index 80a120d3f73bb1dd5cadd539e2e4eb1fefd22aaf..e19deacb3ead202d753b36f065229ff134a64132 100644
@@ -68,6 +68,7 @@ enum class activation_func {
     softplus,             // ln(exp(val) + 1)
     softsign,             // (val/(1+|val|))
     swish,                // (val*sigmoid(val))
+    mish,                 // val*tanh(ln(1 + exp(val)))
     gelu                  // (0.5*val*(1 + erf(val / sqrt(2)))
 };
 
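For reference, the function added to the enum above is mish(val) = val * tanh(ln(1 + exp(val))), i.e. val * tanh(softplus(val)). A small standalone reference implementation (a sketch, not part of the patch) that the expected values in the tests below can be checked against:

```cpp
#include <cmath>

// Reference mish, straight from the definition: x * tanh(softplus(x)),
// with softplus(x) = ln(1 + exp(x)). log1p keeps softplus accurate
// when exp(x) is small.
inline float mish_ref(float x) {
    return x * std::tanh(std::log1p(std::exp(x)));
}
```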
inference-engine/thirdparty/clDNN/kernel_selector/common/common_types.h
index c8e39446e3e8c525df5afe34e3959338d38cddaf..ced9e807179df428f37d9a839559fe4dc50cc885 100644
@@ -150,6 +150,7 @@ enum class ActivationFunction {
     SOFTPLUS,
     SOFTSIGN,
     SWISH,
+    MISH,
     GELU
 };
 
inference-engine/thirdparty/clDNN/kernel_selector/core/common/jitter.cpp
index a164454c948d43ce43f3d377400a8e97ca4e268b..734e02a449a4c23c139506346b96423f562cd7a2 100644
@@ -717,6 +717,18 @@ JitConstants MakeActivationJitConstants(ActivationFunction activation_function,
                     (input / (one + exp(neg(input)))).str()));
             break;
         }
+        case ActivationFunction::MISH: {
+            std::string type_suffix = out_dt == Datatype::F32 ? "f" : "h";
+            auto bound = out_dt == Datatype::F32 ? "9.9f"_jit : "4.75h"_jit;
+            const JitTerm two("2." + type_suffix);
+            const JitTerm n((exp(input) + two) * exp(input));
+            const JitTerm common_mish_formula((input * n) / (n + two));
+
+            jitConstants.AddConstant(MakeJitConstant(
+                macro_def,
+                ternary(input.ge(bound), input, common_mish_formula).str()));
+            break;
+        }
         case ActivationFunction::GELU: {
             std::string type_suffix = out_dt == Datatype::F32 ? "f" : "h";
             const JitTerm half{"0.5" + type_suffix};
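
The MISH case above does not emit tanh(ln(1 + exp(x))) literally. Writing s = softplus(x), we have exp(s) = 1 + exp(x), so tanh(s) = (exp(2s) - 1) / (exp(2s) + 1) = (exp(2x) + 2*exp(x)) / (exp(2x) + 2*exp(x) + 2); with n = exp(x) * (exp(x) + 2) this is n / (n + 2), which is exactly common_mish_formula. The ternary returns the input unchanged once it reaches the bound (9.9 for fp32, 4.75 for fp16), presumably because beyond that point n / (n + 2) rounds to 1 in the output type anyway, and in half precision exp(x) * exp(x) would soon overflow. A quick host-side check of the identity (a sketch, not part of the kernel code):

```cpp
#include <cassert>
#include <cmath>

// Verify (in double) that the kernel's rearrangement  x * n / (n + 2),
// with n = exp(x) * (exp(x) + 2), matches x * tanh(ln(1 + exp(x))).
int main() {
    for (double x = -8.0; x <= 8.0; x += 0.25) {
        const double n = std::exp(x) * (std::exp(x) + 2.0);
        const double rearranged = x * n / (n + 2.0);
        const double reference  = x * std::tanh(std::log1p(std::exp(x)));
        assert(std::fabs(rearranged - reference) < 1e-9);
    }
    return 0;
}
```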
inference-engine/thirdparty/clDNN/kernel_selector/core/kernel_selector_common.cpp
index c420f51fc59e378a99eb5a0d2f336d2c10294cac..593838e2dfebc80d5854d4829ea76db6a3a64a7c 100644
@@ -82,6 +82,7 @@ std::string toString(ActivationFunction activation) {
         case ActivationFunction::SOFTPLUS:                 method = "SOFTPLUS"; break;
         case ActivationFunction::SOFTSIGN:                 method = "SOFTSIGN"; break;
         case ActivationFunction::SWISH:                    method = "SWISH"; break;
+        case ActivationFunction::MISH:                     method = "MISH"; break;
         case ActivationFunction::GELU:                     method = "GELU"; break;
         default: break;
     }
inference-engine/thirdparty/clDNN/src/kernel_selector_helper.cpp
index 87776e97629a86987f6b130652da40dcb6b3595b..d45c90327fdf1b68d50c2de5fde11e422c9e1bbf 100644
@@ -666,6 +666,8 @@ kernel_selector::activation_function get_kernel_selector_activation_param(activa
             return kernel_selector::activation_function::HARD_SIGMOID;
         case cldnn::activation_func::swish:
             return kernel_selector::activation_function::SWISH;
+        case cldnn::activation_func::mish:
+            return kernel_selector::activation_function::MISH;
         case cldnn::activation_func::gelu:
             return kernel_selector::activation_function::GELU;
         default:
inference-engine/thirdparty/clDNN/tests/test_cases/activation_simple_gpu_test.cpp
index eb178a4f7c21ca20f27c372b689ef1f4fd3979f2..da9f4612d95cf97e8308e6d2cb51c6bce6ef6e7f 100644
@@ -655,6 +655,46 @@ TEST(activation_f32_fw_gpu, relu_basic_bfzyx) {
     }
 }
 
+TEST(activation_f16_fw_gpu, basic_yxfb_mish) {
+    const auto& engine = get_test_engine();
+
+    auto input = memory::allocate(engine, { data_types::f16, format::yxfb, { 1, 1, 5, 4 } });
+    set_values(input,
+    { FLOAT16(0.0f), FLOAT16(-2.0f), FLOAT16(-3.0f), FLOAT16(4.0f), FLOAT16(5.0f),
+      FLOAT16(2.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(-6.0f),
+      FLOAT16(3.0f), FLOAT16(-3.0f), FLOAT16(3.0f), FLOAT16(5.0f), FLOAT16(1.0f),
+      FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(1.0f), FLOAT16(-1.0f), FLOAT16(1.0f) });
+
+    topology topology(
+        input_layout("input", input.get_layout()),
+        activation("mish", "input", activation_func::mish));
+    network network(engine, topology);
+    network.set_input_data("input", input);
+    auto outputs = network.execute();
+    EXPECT_EQ(outputs.size(), size_t(1));
+    EXPECT_EQ(outputs.begin()->first, "mish");
+
+    auto output_memory = outputs.at("mish").get_memory();
+    auto output_layout = output_memory.get_layout();
+    auto output_ptr = output_memory.pointer<FLOAT16>();
+    auto input_ptr = input.pointer<FLOAT16>();
+
+    int y_size = output_layout.size.spatial[1];
+    int x_size = output_layout.size.spatial[0];
+    int f_size = output_layout.size.feature[0];
+    int b_size = output_layout.size.batch[0];
+    EXPECT_EQ(output_layout.format, format::yxfb);
+    EXPECT_EQ(y_size, 4);
+    EXPECT_EQ(x_size, 5);
+    EXPECT_EQ(f_size, 1);
+    EXPECT_EQ(b_size, 1);
+
+    for (size_t i = 0; i < output_layout.get_linear_size(); ++i) {
+        EXPECT_NEAR((FLOAT16)((float)input_ptr[i] * std::tanh(std::log(1.f + std::exp((float)input_ptr[i])))),
+                    output_ptr[i], 1e-2f);
+    }
+}
+
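A note on the new fp16 test: the reference is computed in float and compared with an absolute tolerance of 1e-2, which also covers the kernel's clamp path, since the input 5.0 lies above the fp16 bound of 4.75 from the jitter change and the kernel therefore returns 5.0 while the exact value is about 4.9996. A back-of-the-envelope check (a sketch, not part of the test suite):

```cpp
#include <cmath>
#include <cstdio>

// For the test input 5.0 (above the fp16 clamp of 4.75) the kernel returns
// the input unchanged; the gap to the exact definition is ~4.5e-4, far
// below the test's 1e-2 tolerance.
int main() {
    const float x = 5.0f;
    const float exact = x * std::tanh(std::log1p(std::exp(x)));  // ~4.99955
    std::printf("mish(5) = %.5f, clamp error = %.2e\n", exact, x - exact);
    return 0;
}
```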
 TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
 {
     //  Input:
@@ -700,6 +740,7 @@ TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
         activation_func::negative,
         activation_func::abs,
         activation_func::swish,
+        activation_func::mish,
         activation_func::gelu
     };
 
@@ -817,6 +858,10 @@ TEST(activation_f32_fw_gpu, basic_yxfb_all_functions)
                 case activation_func::swish:
                     EXPECT_FLOAT_EQ((float)input_ptr[i] / (1.f + std::exp((float)(-input_ptr[i]))), output_ptr[i]);
                     break;
+                case activation_func::mish:
+                    EXPECT_NEAR((float)input_ptr[i] * std::tanh(std::log(1.f + std::exp((float)input_ptr[i]))),
+                                output_ptr[i], 1e-5f);
+                    break;
                 case activation_func::gelu:
                     EXPECT_NEAR(0.5f * (float)input_ptr[i] * (1.f + std::erf((float)(input_ptr[i]) / std::sqrt(2.0f))),
                                 output_ptr[i], 1e-5f);