// Split bidirectional case to forward + reverse passes.
// split inputs
std::vector<std::vector<char>> H_split(
- 2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
std::vector<std::vector<char>> C_split(
- 2, std::vector<char>(ngraph::shape_size(C_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(C_shape) / 2));
std::vector<std::vector<char>> W_split(
- 2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
std::vector<std::vector<char>> R_split(
- 2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
std::vector<std::vector<char>> B_split(
- 2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
char* c_pointers[2] = {C_split[0].data(), C_split[1].data()};
char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
reference::split(C, C_shape, sizeof(T), 1, 2, c_pointers);
- reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
- reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
- reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
+ reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
+ reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
+ reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
+ std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+ X_shape[1]);
+ std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+ X_shape[1]);
std::vector<std::vector<char>> forward_res(
- 3, std::vector<char>(H_shape[0] * H_shape[2]));
+ 2, std::vector<char>(sizeof(T) * H_shape[0] * H_shape[2]));
std::vector<std::vector<char>> reverse_res(
- 3, std::vector<char>(H_shape[0] * H_shape[2]));
+ 2, std::vector<char>(sizeof(T) * H_shape[0] * H_shape[2]));
CellArgs args;
args.activation_f = activation_f;
args.clip = clip;
std::vector<Shape> shapes = {
X_shape, seq_lengths_shape, H_shape, C_shape, W_shape, R_shape, B_shape};
+ // update H,C,W,R,B shapes after split
+ shapes[2][1] = 1;
+ shapes[3][1] = 1;
+ for (int i = 4; i < shapes.size(); ++i)
+ {
+ shapes[i][0] = 1;
+ }
// forward pass
cell_pass<T>(
CellType::LSTM,
r_pointers[0],
b_pointers[0]},
shapes,
- {forward_res[0].data(), forward_res[1].data(), forward_res[2].data()},
+ {forward_res_y.data(), forward_res[0].data(), forward_res[1].data()},
args,
false);
// reverse pass
r_pointers[1],
b_pointers[1]},
shapes,
- {reverse_res[0].data(), reverse_res[1].data(), reverse_res[2].data()},
+ {reverse_res_y.data(), reverse_res[0].data(), reverse_res[1].data()},
args,
true);
// Stack together respective outputs from both forward and reverse passes.
- std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
- {H_shape[0], 1, H_shape[2]},
- {H_shape[0], 1, H_shape[2]}};
- Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
+ std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
+ {H_shape[0], 1, X_shape[1], H_shape[2]}};
+ std::vector<Shape> in_shapes_h_c = {{H_shape[0], 1, H_shape[2]},
+ {H_shape[0], 1, H_shape[2]}};
+ Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
+ Shape output_shape_h_c{H_shape[0], 2, H_shape[2]};
- runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
+ runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
Y,
- in_shapes,
- output_shape,
+ in_shapes_y,
+ output_shape_y,
1,
sizeof(T));
- runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
+ runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
Ho,
- in_shapes,
- output_shape,
+ in_shapes_h_c,
+ output_shape_h_c,
1,
sizeof(T));
- runtime::reference::concat({forward_res[2].data(), reverse_res[2].data()},
+ runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
Co,
- in_shapes,
- output_shape,
+ in_shapes_h_c,
+ output_shape_h_c,
1,
sizeof(T));
}
// Split bidirectional case to forward + reverse passes.
// split inputs
std::vector<std::vector<char>> H_split(
- 2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
std::vector<std::vector<char>> W_split(
- 2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
std::vector<std::vector<char>> R_split(
- 2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
std::vector<std::vector<char>> B_split(
- 2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
- reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
- reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
- reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
- std::vector<std::vector<char>> forward_res(
- 2, std::vector<char>(H_shape[0] * H_shape[2]));
- std::vector<std::vector<char>> reverse_res(
- 2, std::vector<char>(H_shape[0] * H_shape[2]));
+ reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
+ reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
+ reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
+ std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+ X_shape[1]);
+ std::vector<char> forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
+ std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+ X_shape[1]);
+ std::vector<char> reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
CellArgs args;
args.activation_f = activation_f;
args.clip = clip;
std::vector<Shape> shapes = {
X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
+ // update H,W,R,B shapes after split
+ shapes[2][1] = 1;
+ for (int i = 3; i < shapes.size(); ++i)
+ {
+ shapes[i][0] = 1;
+ }
// forward pass
cell_pass<T>(CellType::GRU,
{X,
r_pointers[0],
b_pointers[0]},
shapes,
- {forward_res[0].data(), forward_res[1].data()},
+ {forward_res_y.data(), forward_res_h.data()},
args,
false);
// reverse pass
r_pointers[1],
b_pointers[1]},
shapes,
- {reverse_res[0].data(), reverse_res[1].data()},
+ {reverse_res_y.data(), reverse_res_h.data()},
args,
true);
// Stack together respective outputs from both forward and reverse passes.
- std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
- {H_shape[0], 1, H_shape[2]}};
- Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
+ std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
+ {H_shape[0], 1, X_shape[1], H_shape[2]}};
+ std::vector<Shape> in_shapes_h = {{H_shape[0], 1, H_shape[2]},
+ {H_shape[0], 1, H_shape[2]}};
+ Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
+ Shape output_shape_h{H_shape[0], 2, H_shape[2]};
- runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
+ runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
Y,
- in_shapes,
- output_shape,
+ in_shapes_y,
+ output_shape_y,
1,
sizeof(T));
- runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
+ runtime::reference::concat({forward_res_h.data(), reverse_res_h.data()},
Ho,
- in_shapes,
- output_shape,
+ in_shapes_h,
+ output_shape_h,
1,
sizeof(T));
}
// Split bidirectional case to forward + reverse passes.
// split inputs
std::vector<std::vector<char>> H_split(
- 2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
std::vector<std::vector<char>> W_split(
- 2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
std::vector<std::vector<char>> R_split(
- 2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
std::vector<std::vector<char>> B_split(
- 2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
+ 2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
- reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
- reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
- reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
- std::vector<std::vector<char>> forward_res(
- 2, std::vector<char>(H_shape[0] * H_shape[2]));
- std::vector<std::vector<char>> reverse_res(
- 2, std::vector<char>(H_shape[0] * H_shape[2]));
+ reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
+ reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
+ reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
+ std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+ X_shape[1]);
+ std::vector<char> forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
+ std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+ X_shape[1]);
+ std::vector<char> reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
CellArgs args;
args.activation_f = activation_f;
args.clip = clip;
std::vector<Shape> shapes = {
X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
+ // update H,W,R,B shapes after split
+ shapes[2][1] = 1;
+ for (int i = 3; i < shapes.size(); ++i)
+ {
+ shapes[i][0] = 1;
+ }
// forward pass
cell_pass<T>(CellType::RNN,
{X,
r_pointers[0],
b_pointers[0]},
shapes,
- {forward_res[0].data(), forward_res[1].data()},
+ {forward_res_y.data(), forward_res_h.data()},
args,
false);
// reverse pass
r_pointers[1],
b_pointers[1]},
shapes,
- {reverse_res[0].data(), reverse_res[1].data()},
+ {reverse_res_y.data(), reverse_res_h.data()},
args,
true);
// Stack together respective outputs from both forward and reverse passes.
- std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
- {H_shape[0], 1, H_shape[2]}};
- Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
+ std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
+ {H_shape[0], 1, X_shape[1], H_shape[2]}};
+ std::vector<Shape> in_shapes_h = {{H_shape[0], 1, H_shape[2]},
+ {H_shape[0], 1, H_shape[2]}};
+ Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
+ Shape output_shape_h{H_shape[0], 2, H_shape[2]};
- runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
+ runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
Y,
- in_shapes,
- output_shape,
+ in_shapes_y,
+ output_shape_y,
1,
sizeof(T));
- runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
+ runtime::reference::concat({forward_res_h.data(), reverse_res_h.data()},
Ho,
- in_shapes,
- output_shape,
+ in_shapes_h,
+ output_shape_h,
1,
sizeof(T));
}