// See the License for the specific language governing permissions and
// limitations under the License.
+#include "include/include_all.cl"
-#include "include/common.cl"
-#include "include/data_types.cl"
-
-KERNEL(embed_ref)(const __global UNIT_TYPE* input0, __global UNIT_TYPE* output, __global UNIT_TYPE* weights, __global UNIT_TYPE* biases)
+KERNEL(embed_ref)(const __global UNIT_TYPE* input0,
+ __global UNIT_TYPE* output,
+ const __global UNIT_TYPE* weights
+#if BIAS_TERM
+ ,const __global UNIT_TYPE* biases
+#endif
+)
{
const uint x = (uint)get_global_id(0);
const uint y = (uint)get_global_id(1);
const uint b = (uint)get_global_id(2);
+
uint output_idx = (b*INPUT0_ELEMENTS_COUNT*NUM_OUTPUT_SIZE)+(uint)(x*NUM_OUTPUT_SIZE+y);
- output[output_idx] = weights[(uint)(input0[(b*INPUT0_ELEMENTS_COUNT)+x]*NUM_OUTPUT_SIZE+y)] + biases[y];
+ output[output_idx] = weights[(uint)(input0[(b*INPUT0_ELEMENTS_COUNT)+x]*NUM_OUTPUT_SIZE+y)];
+#if BIAS_TERM
+ output[output_idx] += biases[y];
+#endif
}
-
\ No newline at end of file