/*
-// Copyright (c) 2016 Intel Corporation
+// Copyright (c) 2016-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
*/
#include "eltwise_kernel_base.h"
-#include "kernel_selector_utils.h"
+#include "kernel_selector_utils.h"
namespace kernel_selector
{
case EltwiseMode::MAX:
case EltwiseMode::POW:
case EltwiseMode::MODULU:
+ case EltwiseMode::EQ:
+ case EltwiseMode::NE:
+ case EltwiseMode::LT:
+ case EltwiseMode::LE:
+ case EltwiseMode::GT:
+ case EltwiseMode::GE:
+ case EltwiseMode::LOGIC_AND:
+ case EltwiseMode::LOGIC_OR:
+ case EltwiseMode::LOGIC_XOR:
+ case EltwiseMode::SQUARED_DIFF:
return 2;
case EltwiseMode::SQRT:
case EltwiseMode::RSQRT:
k.EnableOutputCalibration();
}
+ if (!stride.empty())
+ {
+ k.EnableEltwiseStride();
+ }
+
+ if (broadcast)
+ {
+ k.EnableEltwiseBroadcast();
+ }
+
return k;
}
jit.AddConstants({
MakeJitConstant("ELTWISE_LAYOUT_BASED", params.layoutBased),
MakeJitConstant("QUANTIZATION_TERM", params.int8_quantization),
+ MakeJitConstant("ELTWISE_BROADCAST", params.broadcast),
});
if (params.int8_quantization)
}
inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
+ if (!params.stride.empty()) // per-input stride constants are emitted only when strides were supplied
+ {
+ jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_X", params.stride[i].x)); // NOTE(review): only non-emptiness is checked above; stride[i] assumes one entry per input - confirm
+ jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Y", params.stride[i].y)); // constant name carries the same index i as the inputN declaration
+ }
if (useVload8)
{
vload_decls += "\\\n\tconst " + toCLType(params.inputs[i].GetDType()) + "8 in" + std::to_string(i);
}
}
- std::string input0_str, input1_str, cast_type, op;
+ std::string input0_str, input1_str, cast_type, output_cast, op;
if (useVload8)
{
op = "const UNIT_TYPE tmp" + op_num_str + " = ";
}
+ if (params.output.GetDType() == Datatype::INT8 && !params.int8_quantization) { // INT8 output without quantization: add explicit casts to the generated expression
+ output_cast = "(char)"; // prefix applied to the comparison / logical results in the switch below
+ cast_type = "(" + toCLType(params.inputs[op_num].GetDType()) + ")"; // NOTE(review): op_num appears to number operations (see tmp<op_num>), yet it indexes params.inputs here - confirm it is always a valid input index
+ }
+
input0_str = cast_type + "INPUT_" + op_num_str + "_0";
input1_str = cast_type + "INPUT_" + op_num_str + "_1";
switch (ew.mode)
{
- case EltwiseMode::ADD: op += input0_str + " + " + input1_str; break;
- case EltwiseMode::SUB: op += input0_str + " - " + input1_str; break;
- case EltwiseMode::MUL: op += input0_str + " * " + input1_str; break;
- case EltwiseMode::DIV: op += input0_str + " / " + input1_str; break;
- case EltwiseMode::MODULU: op += cast_type + "fmod(" + input0_str + ", " + input1_str + ")"; break;
- case EltwiseMode::MIN: op += cast_type + "fmin(" + input0_str + ", " + input1_str + ")"; break;
- case EltwiseMode::MAX: op += cast_type + "fmax(" + input0_str + ", " + input1_str + ")"; break;
- case EltwiseMode::POW: op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
- case EltwiseMode::SQRT: op += cast_type + "sqrt(" + input0_str + ")"; break;
- case EltwiseMode::RSQRT: op += cast_type + "1/sqrt(" + input0_str + ")"; break;
- case EltwiseMode::ASSIGN: op += input0_str; break;
+ case EltwiseMode::ADD: op += input0_str + " + " + input1_str; break;
+ case EltwiseMode::SUB: op += input0_str + " - " + input1_str; break;
+ case EltwiseMode::MUL: op += input0_str + " * " + input1_str; break;
+ case EltwiseMode::DIV: op += input0_str + " / " + input1_str; break;
+ case EltwiseMode::MODULU: // MODULU, MIN and MAX share one emitter; the generated expression depends on the operand data types
+ case EltwiseMode::MIN:
+ case EltwiseMode::MAX:
+ {
+ auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max" )); // base name of the generated call; "f" is prepended below for the floating-point forms (fmod/fmin/fmax)
+ auto input_0_type = params.inputs[0].GetDType(); // NOTE(review): types are read from inputs[0] / inputs[1], not from the operands named by input0_str / input1_str - confirm they always correspond
+ auto input_1_type = params.inputs[1].GetDType();
+
+ // input_0 == int
+ if (input_0_type == kernel_selector::Datatype::INT8 ||
+ input_0_type == kernel_selector::Datatype::INT32 ||
+ input_0_type == kernel_selector::Datatype::INT64)
+ {
+ // input_0 == int && input_1 == int
+ if (input_1_type == kernel_selector::Datatype::INT8 ||
+ input_1_type == kernel_selector::Datatype::INT32 ||
+ input_1_type == kernel_selector::Datatype::INT64)
+ {
+ if (ew.mode == EltwiseMode::MODULU) // integer modulo is emitted as the % operator rather than a function call
+ op += input0_str + " % " + input1_str;
+ else
+ op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")"; // emits min(a, b) or max(a, b)
+ }
+ // input_0 == int && input_1 != int
+ else
+ {
+ op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")"; // only the integer operand is wrapped in convert_float
+ }
+ }
+ // input_0 != int && input_1 == int
+ else if ( input_1_type == kernel_selector::Datatype::INT8 ||
+ input_1_type == kernel_selector::Datatype::INT32 ||
+ input_1_type == kernel_selector::Datatype::INT64)
+ {
+ op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))"; // mirror of the branch above: second operand converted
+ }
+ // input_0 != int && input_1 != int
+ else
+ {
+ op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")"; // fmod / fmin / fmax on the operands as-is
+ }
+ } break;
+ case EltwiseMode::POW: op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
+ case EltwiseMode::SQRT: op += cast_type + "sqrt(" + input0_str + ")"; break;
+ case EltwiseMode::RSQRT: op += cast_type + "1/sqrt(" + input0_str + ")"; break;
+ case EltwiseMode::SQUARED_DIFF: op += cast_type + "((" + input0_str + " - " + input1_str + ")"
+ " * (" + input0_str + " - " + input1_str + "))"; break;
+ case EltwiseMode::EQ: op += output_cast + "(" + input0_str + " == " + input1_str + ")"; break;
+ case EltwiseMode::NE: op += output_cast + "(" + input0_str + " != " + input1_str + ")"; break;
+ case EltwiseMode::LT: op += output_cast + "(" + input0_str + " < " + input1_str + ")"; break;
+ case EltwiseMode::LE: op += output_cast + "(" + input0_str + " <= " + input1_str + ")"; break;
+ case EltwiseMode::GT: op += output_cast + "(" + input0_str + " > " + input1_str + ")"; break;
+ case EltwiseMode::GE: op += output_cast + "(" + input0_str + " >= " + input1_str + ")"; break;
+ case EltwiseMode::LOGIC_AND: op += output_cast + "(" + input0_str + " && " + input1_str + ")"; break;
+ case EltwiseMode::LOGIC_OR: op += output_cast + "(" + input0_str + " || " + input1_str + ")"; break;
+ case EltwiseMode::LOGIC_XOR: op += output_cast + "(!" + input0_str + " != !" + input1_str + ")"; break;
+ case EltwiseMode::ASSIGN: op += input0_str; break;
default:
break;
}
}
for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
- do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) +
+ do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) +
"[GET_INDEX(INPUT, " + std::to_string(updateInputs[update_input_idx].inputId) +
")] = tmp" + std::to_string(updateInputs[update_input_idx].tmpId) + ";";
jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
- if (params.layoutBased || params.int8_quantization)
+ if (params.layoutBased || params.int8_quantization || params.broadcast)
+ {
+ jit.Merge(GetTensorFriendlyWorkGroupsJit(params.output));
+ }
+
+ if (!params.stride.empty())
{
- jit.Merge(GetTensorFriendlyWorkGroupsJit(params.inputs[0]));
+ jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
}
return jit;
{
DispatchData kd;
- if (params.layoutBased || params.int8_quantization)
+ if (params.layoutBased || params.int8_quantization || params.broadcast)
{
- auto global = GetTensorFriendlyWorkGroups(params.inputs[0]);
+ auto global = GetTensorFriendlyWorkGroups(params.output);
kd.gws0 = global[0];
kd.gws1 = global[1];
kd.gws2 = global[2];
+ if (!params.stride.empty()) // shrink the global work size when strides were supplied
+ {
+ kd.gws0 /= params.stride[0].x; // only the first stride entry is consulted here
+ kd.gws0 /= params.stride[0].y; // NOTE(review): both X and Y strides divide gws0 while gws1 / gws2 stay untouched - confirm gws0 covers both spatial dimensions on this path
+ }
}
else if (CheckInputsOutputNoPitchSameDims(params))
{
kernel.workGroups.global = { runInfo.gws0, runInfo.gws1, runInfo.gws2 };
kernel.workGroups.local = { runInfo.lws0, runInfo.lws1, runInfo.lws2 };
- kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, ROUND_ROBIN);
+ kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, DEFAULT);
kernel.arguments = GetArgsDesc((uint32_t)newParams.inputs.size(), false, false, newParams.int8_quantization, newParams.output_calibration);
kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;