// 0.0 0.0 0.0 0.0
// 6.0 8.0 -2.0 -2.0
- engine engine;
+ const auto& engine = get_test_engine();
auto batch = 1;
auto sequence_length = 3;
auto num_output_size = 4;
auto vocab_size = 3;
- auto input_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, 1, 1, sequence_length } });
+ auto input_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, 1, sequence_length, 1 } });
auto weights_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ num_output_size, 1, vocab_size, 1 } });
auto bias_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, 1, 1, num_output_size } });
auto output_ref = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, sequence_length, num_output_size, 1 } });
// -1.0 0.0 1.0 -1.0 4.0 4.0
// 10.0 18.0 19.0 -1.0 0.0 1.0
- engine engine;
+ const auto& engine = get_test_engine();
auto batch = 2;
auto sequence_length = 2;
auto num_output_size = 3;
auto vocab_size = 3;
- auto input_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, 1, 1, sequence_length } });
+ auto input_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, 1, sequence_length, 1 } });
auto weights_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ num_output_size, 1, vocab_size, 1 } });
auto bias_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ 1, 1, 1, num_output_size } });
auto output_ref = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, sequence_length, num_output_size, 1 } });