// Copyright (c) 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include <gtest/gtest.h>
19 #include "api/memory.hpp"
20 #include <api/input_layout.hpp>
21 #include "api/embed.hpp"
22 #include <api/topology.hpp>
23 #include <api/tensor.hpp>
24 #include <api/network.hpp>
25 #include <api/engine.hpp>
26 #include <api/data.hpp>
27 #include "test_utils/test_utils.h"
31 using namespace cldnn;
32 using namespace tests;
34 TEST(embed_gpu, seq3num4) {
54 const auto& engine = get_test_engine();
56 auto sequence_length = 3;
57 auto num_output_size = 4;
59 auto input_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, 1, sequence_length, 1 } });
60 auto weights_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ num_output_size, 1, vocab_size, 1 } });
61 auto bias_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, 1, 1, num_output_size } });
62 auto output_ref = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, sequence_length, num_output_size, 1 } });
64 set_values(input_prim, { 1.0f, 2.0f, 0.0f });
65 set_values(weights_prim, { 1.0f, 1.0f, 1.0f, 1.0f,
66 2.0f, 2.0f, 2.0f, 2.0f,
67 3.0f, 3.0f, 3.0f, 3.0f });
68 set_values(bias_prim, { 1.0f, 2.0f, 3.0f, 4.0f });
69 set_values(output_ref, { 3.0f, 4.0f, 5.0f, 6.0f,
70 4.0f, 5.0f, 6.0f, 7.0f,
71 2.0f, 3.0f, 4.0f, 5.0f });
73 auto input = input_layout("input", input_prim.get_layout());
74 auto w_data = data("weights", weights_prim);
75 auto b_data = data("bias", bias_prim);
77 auto embed_test = embed("embed_prim", "input", "weights", "bias");
82 topology.add(embed_test);
84 network network(engine, topology);
85 network.set_input_data("input", input_prim);
87 auto outputs = network.execute();
88 EXPECT_EQ(outputs.size(), size_t(1));
89 EXPECT_EQ(outputs.begin()->first, "embed_prim");
91 auto output_prim = outputs.begin()->second.get_memory();
92 auto ref = output_ref.pointer<float>();
93 auto output_ptr = output_prim.pointer<float>();
94 for (auto i = 0; i < batch * sequence_length * num_output_size; i++) {
95 EXPECT_EQ(ref[i], output_ptr[i]);
100 TEST(embed_gpu, b2seq2num3) {
117 // -1.0 0.0 1.0 -1.0 4.0 4.0
118 // 10.0 18.0 19.0 -1.0 0.0 1.0
120 const auto& engine = get_test_engine();
122 auto sequence_length = 2;
123 auto num_output_size = 3;
125 auto input_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, 1, sequence_length, 1 } });
126 auto weights_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ num_output_size, 1, vocab_size, 1 } });
127 auto bias_prim = memory::allocate(engine, { data_types::f32,format::bfyx,{ 1, 1, 1, num_output_size } });
128 auto output_ref = memory::allocate(engine, { data_types::f32,format::bfyx,{ batch, sequence_length, num_output_size, 1 } });
130 set_values(input_prim, { 0.0f, 1.0f, 2.0f, 0.0f });
131 set_values(weights_prim, { -1.0f, -2.0f, -3.0f,
133 10.0f, 16.0f, 15.0f });
134 set_values(bias_prim, { 0.0f, 2.0f, 4.0f });
135 set_values(output_ref, { -1.0f, 0.0f, 1.0f, -1.0f, 4.0f, 4.0f,
136 10.0f, 18.0f, 19.0f, -1.0f, 0.0f, 1.0f });
138 auto input = input_layout("input", input_prim.get_layout());
139 auto w_data = data("weights", weights_prim);
140 auto b_data = data("bias", bias_prim);
142 auto embed_test = embed("embed_prim", "input", "weights", "bias");
145 topology.add(w_data);
146 topology.add(b_data);
147 topology.add(embed_test);
149 network network(engine, topology);
150 network.set_input_data("input", input_prim);
152 auto outputs = network.execute();
153 EXPECT_EQ(outputs.size(), size_t(1));
154 EXPECT_EQ(outputs.begin()->first, "embed_prim");
156 auto output_prim = outputs.begin()->second.get_memory();
157 auto ref = output_ref.pointer<float>();
158 auto output_ptr = output_prim.pointer<float>();
159 for (auto i = 0; i < batch * sequence_length * num_output_size; i++) {
160 EXPECT_EQ(ref[i], output_ptr[i]);