/*
-// Copyright (c) 2018 Intel Corporation
+// Copyright (c) 2018-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
*/
#include "eltwise_kernel_fs_bs_yx_bsv4_fsv32.h"
-#include "kernel_selector_utils.h"
+#include "kernel_selector_utils.h"
namespace kernel_selector {
k.EnableBatching();
k.EnableInt8Quantization();
k.EnableOutputCalibration();
+ k.EnableEltwiseStride();
return k;
}
kd.lws1 = 1;
kd.lws2 = 8;
+ kd.effiency = FORCE_PRIORITY_3;
return kd;
}
}
inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
+
+ if (!params.stride.empty()) // NOTE(review): only emptiness is checked; stride[i] below presumably relies on one stride entry per input - confirm against caller
+ {
+ jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_X", params.stride[i].x)); // JIT constant INPUT<i>_STRIDE_X for input i
+ jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Y", params.stride[i].y)); // JIT constant INPUT<i>_STRIDE_Y for input i
+ }
}
jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
switch (ew.mode)
{
- case EltwiseMode::ADD: op += input0_str + " + " + input1_str; break;
- case EltwiseMode::SUB: op += input0_str + " - " + input1_str; break;
- case EltwiseMode::MUL: op += input0_str + " * " + input1_str; break;
- case EltwiseMode::DIV: op += input0_str + " / " + input1_str; break;
- case EltwiseMode::MODULU: op += cast_type + "fmod(" + input0_str + ", " + input1_str + ")"; break;
- case EltwiseMode::MIN: op += cast_type + "fmin(" + input0_str + ", " + input1_str + ")"; break;
- case EltwiseMode::MAX: op += cast_type + "fmax(" + input0_str + ", " + input1_str + ")"; break;
- case EltwiseMode::POW: op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
- case EltwiseMode::SQRT: op += cast_type + "sqrt(" + input0_str + ")"; break;
- case EltwiseMode::RSQRT: op += cast_type + "1/sqrt(" + input0_str + ")"; break;
- case EltwiseMode::ASSIGN: op += input0_str; break;
+ case EltwiseMode::ADD: op += input0_str + " + " + input1_str; break;
+ case EltwiseMode::SUB: op += input0_str + " - " + input1_str; break;
+ case EltwiseMode::MUL: op += input0_str + " * " + input1_str; break;
+ case EltwiseMode::DIV: op += input0_str + " / " + input1_str; break;
+ case EltwiseMode::MODULU:
+ case EltwiseMode::MIN:
+ case EltwiseMode::MAX:
+ {
+ auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max")); // base name of the emitted call; the float branches below prepend "f" (fmod / fmin / fmax)
+ auto input_0_type = params.inputs[0].GetDType(); // NOTE(review): only inputs[0] and inputs[1] are inspected; presumably these match input0_str / input1_str - confirm
+ auto input_1_type = params.inputs[1].GetDType();
+
+ // input_0 is one of INT8 / INT32 / INT64 (the only types this block treats as integer)
+ if (input_0_type == kernel_selector::Datatype::INT8 ||
+ input_0_type == kernel_selector::Datatype::INT32 ||
+ input_0_type == kernel_selector::Datatype::INT64)
+ {
+ // both operands are integer: emit the integer form of the operation
+ if (input_1_type == kernel_selector::Datatype::INT8 ||
+ input_1_type == kernel_selector::Datatype::INT32 ||
+ input_1_type == kernel_selector::Datatype::INT64)
+ {
+ if (ew.mode == EltwiseMode::MODULU)
+ op += input0_str + " % " + input1_str; // integer modulo uses the % operator, without cast_type
+ else
+ op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")"; // emits min(a, b) / max(a, b)
+ }
+ // input_0 is integer, input_1 is not: wrap input_0 in convert_float and use the f-prefixed function
+ else
+ {
+ op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
+ }
+ }
+ // input_0 is not integer, input_1 is: wrap input_1 in convert_float and use the f-prefixed function
+ else if (input_1_type == kernel_selector::Datatype::INT8 ||
+ input_1_type == kernel_selector::Datatype::INT32 ||
+ input_1_type == kernel_selector::Datatype::INT64)
+ {
+ op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
+ }
+ // neither operand is INT8 / INT32 / INT64: use the f-prefixed function on both operands as-is
+ // NOTE(review): any other Datatype value (e.g. unsigned integer types, if the enum has them) also lands here - confirm this is intended
+ else
+ {
+ op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
+ }
+ } break;
+ case EltwiseMode::POW: op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
+ case EltwiseMode::SQRT: op += cast_type + "sqrt(" + input0_str + ")"; break;
+ case EltwiseMode::RSQRT: op += cast_type + "1/sqrt(" + input0_str + ")"; break;
+ case EltwiseMode::SQUARED_DIFF: op += cast_type + "((" + input0_str + " - " + input1_str + ")"
+ " * (" + input0_str + " - " + input1_str + "))"; break;
+ case EltwiseMode::EQ: op += cast_type + "(" + input0_str + " == " + input1_str + ")"; break;
+ case EltwiseMode::NE: op += cast_type + "(" + input0_str + " != " + input1_str + ")"; break;
+ case EltwiseMode::LT: op += cast_type + "(" + input0_str + " < " + input1_str + ")"; break;
+ case EltwiseMode::LE: op += cast_type + "(" + input0_str + " <= " + input1_str + ")"; break;
+ case EltwiseMode::GT: op += cast_type + "(" + input0_str + " > " + input1_str + ")"; break;
+ case EltwiseMode::GE: op += cast_type + "(" + input0_str + " >= " + input1_str + ")"; break;
+ case EltwiseMode::LOGIC_AND: op += cast_type + "(" + input0_str + " && " + input1_str + ")"; break;
+ case EltwiseMode::LOGIC_OR: op += cast_type + "(" + input0_str + " || " + input1_str + ")"; break;
+ case EltwiseMode::LOGIC_XOR: op += cast_type + "(!" + input0_str + " != !" + input1_str + ")"; break;
+ case EltwiseMode::ASSIGN: op += input0_str; break;
default:
break;
}
jit.Merge(GetTensorFriendlyWorkGroupsJit(params.inputs[0]));
}
+ if (!params.stride.empty()) // any stride entry present => flag the strided variant
+ {
+ jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1)); // defines INPUT_STRIDED = 1; presumably tested by the kernel source - TODO confirm
+ }
+
///////////////
return jit;
}
{
return GetCommonKernelsData(params, options);
}
-}
\ No newline at end of file
+}