if (params.has_cell) {
const auto& cell = params.cell;
- jit.AddConstants({ MakeJitConstant("CELL_TERM", true), MakeJitConstant("CELL", cell) });
+ jit.AddConstants({
+ MakeJitConstant("CELL_TERM", true),
+ MakeJitConstant("CELL", cell),
+ MakeJitConstant("CELL_DIRECTION", params.cell_direction)
+ });
}
if (params.clip > 0) {
std::string psclip = toCodeString(params.clip);
if (params.input_forget) {
jit.AddConstants({ MakeJitConstant("INPUT_FORGET", true) });
}
+ jit.AddConstants({ MakeJitConstant("DIRECTION", params.direction) });
const auto& GEMMInput = params.inputs[0];
size_t size = GEMMInput.X().v / 4;