#include "error_handler.h"
#include "kernel_runner.h"
-#include "api/CPP/reorder.hpp"
-#include "api/CPP/input_layout.hpp"
+#include "api/reorder.hpp"
+#include "api/input_layout.hpp"
#include <memory>
namespace cldnn {
protected:
kernel::kernel_arguments_data get_arguments(typed_primitive_inst<fully_connected>& instance,
- int32_t) const override {
- kernel::kernel_arguments_data args;
+ int32_t split) const override {
+ kernel::kernel_arguments_data args = parent::get_arguments(instance, split);
args.inputs = {(memory_impl::cptr) &instance.input_memory()};
args.output = (memory_impl::cptr) &instance.output_memory();
arg.get_program());
fc_optional_params.allowInputReordering = true;
- if (arg.get_primitive()->with_activation)
- convert_activation_func_params(arg.get_primitive(), fc_params.activation);
-
fc_params.output = fc_params.output.FlattenFeatureAndSpatials();
const auto primitive = arg.get_primitive();
}
};
-namespace {
-struct attach {
- attach() {
- auto val_fw = fully_connected_gpu::create;
-
- implementation_map<fully_connected>::add({
- {std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw},
- {std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw},
- {std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw},
- {std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw},
- {std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw},
- {std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw},
- {std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), val_fw},
- {std::make_tuple(engine_types::ocl, data_types::u8, format::bfyx), val_fw},
- // MMAD
- {std::make_tuple(engine_types::ocl, data_types::i8, format::byxf_af32), val_fw},
- {std::make_tuple(engine_types::ocl, data_types::i8, format::fs_bs_yx_bsv4_fsv32), val_fw},
- // IMAD
- {std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv4), val_fw},
- {std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv4), val_fw},
- // fs_b_yx_fsv32
- {std::make_tuple(engine_types::ocl, data_types::f16, format::fs_b_yx_fsv32), val_fw},
- });
- }
- ~attach() {}
-};
-attach attach_impl;
-} // namespace
+namespace detail {
+
+attach_fully_connected_gpu::attach_fully_connected_gpu() {
+ auto val_fw = fully_connected_gpu::create;
+
+ implementation_map<fully_connected>::add({
+ {std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw},
+ {std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw},
+ {std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw},
+ {std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw},
+ {std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw},
+ {std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw},
+ {std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), val_fw},
+ {std::make_tuple(engine_types::ocl, data_types::u8, format::bfyx), val_fw},
+ // MMAD
+ {std::make_tuple(engine_types::ocl, data_types::i8, format::byxf_af32), val_fw},
+ {std::make_tuple(engine_types::ocl, data_types::i8, format::fs_bs_yx_bsv4_fsv32), val_fw},
+ // IMAD
+ {std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv4), val_fw},
+ {std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv4), val_fw},
+ // fs_b_yx_fsv32
+ {std::make_tuple(engine_types::ocl, data_types::f16, format::fs_b_yx_fsv32), val_fw},
+ });
+}
+
+} // namespace detail
} // namespace gpu
} // namespace cldnn