- reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
- reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
- reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
- std::vector<std::vector<char>> forward_res(
- 2, std::vector<char>(H_shape[0] * H_shape[2]));
- std::vector<std::vector<char>> reverse_res(
- 2, std::vector<char>(H_shape[0] * H_shape[2]));
+ reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
+ reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
+ reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
+ std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+ X_shape[1]);
+ std::vector<char> forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
+ std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+ X_shape[1]);
+ std::vector<char> reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]);