jit.AddConstants({ MakeJitConstant("BIAS", bias), MakeJitConstant("BIAS_TERM", true) });
}
if (params.hasHidden) {
- jit.AddConstants({ MakeJitConstant("HIDDEN", hidden), MakeJitConstant("HIDDEN_TERM", true) , MakeJitConstant("RECURRENT", recurrent) });
+ jit.AddConstants({ MakeJitConstant("HIDDEN", hidden),
+ MakeJitConstant("HIDDEN_TERM", true),
+ MakeJitConstant("RECURRENT", recurrent),
+ MakeJitConstant("HIDDEN_DIRECTION", params.hidden_direction)
+ });
}
-
jit.AddConstants({ MakeJitConstant("WEIGHTS", weights)});
jit.AddConstants({ MakeJitConstant("DIRECTION", params.direction)});
+ jit.AddConstants({ MakeJitConstant("INPUT_DIRECTION", params.input_direction)});
return jit;
}
KernelData kd = KernelData::Default<lstm_gemm_params>(params, orgParams.inputs.size());
- float effiency = FORCE_PRIORITY_1;
+ float effiency = FORCE_PRIORITY_9;
const auto& input = orgParams.inputs[0];
auto newParams = orgParams;