lhs_transpose_values),
act_values))
{
- // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ // Validate output only if validate() is successful
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
lhs_transpose_values),
act_values))
{
- // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ // Validate output only if validate() is successful
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
lhs_transpose_values),
act_values))
{
- // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ // Validate output only if validate() is successful
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
lhs_transpose_values),
act_values))
{
- // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ // Validate output only if validate() is successful
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
lhs_transpose_values),
act_values))
{
- // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ // Validate output only if validate() is successful
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
lhs_transpose_values),
act_values))
{
- // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ // Validate output only if validate() is successful
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
lhs_transpose_values),
act_values))
{
- // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ // Validate output only if validate() is successful
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
lhs_transpose_values),
act_values))
{
- // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ // Validate output only if validate() is successful
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
boundary_handling_cases))
{
// Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ if(validate_result)
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
}
FIXTURE_DATA_TEST_CASE(RunPrecommitBoundaryHandlingPartialInXFullInY, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::PRECOMMIT,
boundary_handling_cases))
{
// Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ if(validate_result)
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
}
FIXTURE_DATA_TEST_CASE(RunPrecommitBoundaryHandlingFullInXFullInY, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::PRECOMMIT,
boundary_handling_cases))
{
// Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ if(validate_result)
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
}
FIXTURE_DATA_TEST_CASE(RunPrecommitBoundaryHandlingFullInXPartialInY, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::PRECOMMIT,
boundary_handling_cases))
{
// Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
-}
-
-FIXTURE_DATA_TEST_CASE(RunPrecommit, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_values,
- n_values),
- k_values),
- b_values),
- m0_values_precommit),
- n0_values_precommit),
- k0_values_precommit),
- h0_values),
- i_values_rhs),
- t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", false)),
- framework::dataset::make("DataType", DataType::F32)),
- a_values),
- beta_values),
- broadcast_bias_values),
- act_values))
-{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
-}
-
-FIXTURE_DATA_TEST_CASE(RunNightly, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_values,
- n_values),
- k_values),
- b_values),
- m0_values_nightly),
- n0_values_nightly),
- k0_values_nightly),
- h0_values),
- i_values_rhs),
- t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", false)),
- framework::dataset::make("DataType", DataType::F32)),
- a_values),
- beta_values),
- broadcast_bias_values),
- act_values))
-{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
-}
-
-FIXTURE_DATA_TEST_CASE(RunPrecommit3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<float>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_w_values,
- m_h_values),
- n_values),
- k_values),
- b_values),
- m0_values_precommit),
- n0_values_precommit),
- k0_values_precommit),
- h0_values),
- i_values_rhs),
- t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", false)),
- framework::dataset::make("has_pad_y", {false, true})),
- framework::dataset::make("DataType", DataType::F32)),
- a_values),
- beta_values),
- act_values))
-{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
-}
-
-FIXTURE_DATA_TEST_CASE(RunNightly3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<float>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_w_values,
- m_h_values),
- n_values),
- k_values),
- b_values),
- m0_values_nightly),
- n0_values_nightly),
- k0_values_nightly),
- h0_values),
- i_values_rhs),
- t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", false)),
- framework::dataset::make("has_pad_y", {false, true})),
- framework::dataset::make("DataType", DataType::F32)),
- a_values),
- beta_values),
- act_values))
-{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ if(validate_result)
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
}
-TEST_SUITE(ExportToCLImage)
FIXTURE_DATA_TEST_CASE(RunPrecommit, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<float>, framework::DatasetMode::PRECOMMIT,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
h0_values),
i_values_rhs),
t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("export_to_cl_image_rhs", false, true)),
framework::dataset::make("DataType", DataType::F32)),
a_values),
beta_values),
act_values))
{
// Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
h0_values),
i_values_rhs),
t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("export_to_cl_image_rhs", false, true)),
framework::dataset::make("DataType", DataType::F32)),
a_values),
beta_values),
act_values))
{
// Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
h0_values),
i_values_rhs),
t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("export_to_cl_image_rhs", false, true)),
framework::dataset::make("has_pad_y", {false, true})),
framework::dataset::make("DataType", DataType::F32)),
a_values),
beta_values),
act_values))
{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
+ if(validate_result)
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
}
FIXTURE_DATA_TEST_CASE(RunNightly3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<float>, framework::DatasetMode::NIGHTLY,
h0_values),
i_values_rhs),
t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", true)),
+ framework::dataset::make("export_to_cl_image_rhs", false, true)),
framework::dataset::make("has_pad_y", {false, true})),
framework::dataset::make("DataType", DataType::F32)),
a_values),
beta_values),
act_values))
{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
+ if(validate_result)
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
}
-TEST_SUITE_END() // ExportToCLImage
TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunPrecommit, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_values,
- n_values),
- k_values),
- b_values),
- m0_values_precommit),
- n0_values_precommit),
- k0_values_precommit),
- h0_values),
- i_values_rhs),
- t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", false)),
- framework::dataset::make("DataType", DataType::F16)),
- a_values),
- beta_values),
- broadcast_bias_values),
- act_values))
-{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
-}
-
-FIXTURE_DATA_TEST_CASE(RunNightly, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_values,
- n_values),
- k_values),
- b_values),
- m0_values_nightly),
- n0_values_nightly),
- k0_values_nightly),
- h0_values),
- i_values_rhs),
- t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", false)),
- framework::dataset::make("DataType", DataType::F16)),
- a_values),
- beta_values),
- broadcast_bias_values),
- act_values))
-{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
-}
-
-FIXTURE_DATA_TEST_CASE(RunPrecommit3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<half>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_w_values,
- m_h_values),
- n_values),
- k_values),
- b_values),
- m0_values_precommit),
- n0_values_precommit),
- k0_values_precommit),
- h0_values),
- i_values_rhs),
- t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", false)),
- framework::dataset::make("has_pad_y", {false, true})),
- framework::dataset::make("DataType", DataType::F16)),
- a_values),
- beta_values),
- act_values))
-{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
-}
-
-FIXTURE_DATA_TEST_CASE(RunNightly3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<half>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_w_values,
- m_h_values),
- n_values),
- k_values),
- b_values),
- m0_values_nightly),
- n0_values_nightly),
- k0_values_nightly),
- h0_values),
- i_values_rhs),
- t_values_rhs),
- framework::dataset::make("export_to_cl_image_rhs", false)),
- framework::dataset::make("has_pad_y", {false, true})),
- framework::dataset::make("DataType", DataType::F16)),
- a_values),
- beta_values),
- act_values))
-{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
-}
-
-TEST_SUITE(ExportToCLImage)
FIXTURE_DATA_TEST_CASE(RunPrecommit, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half>, framework::DatasetMode::PRECOMMIT,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
act_values))
{
// Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
act_values))
{
// Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
- if(image2d_from_buffer_supported(CLKernelLibrary::get().get_device()))
+ if(validate_result)
{
validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
beta_values),
act_values))
{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+ // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
+ if(validate_result)
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
}
FIXTURE_DATA_TEST_CASE(RunNightly3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<half>, framework::DatasetMode::NIGHTLY,
beta_values),
act_values))
{
- // Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+ // Validate output only if the target platform supports the OpenCL cl_khr_image2d_from_buffer extension
+ if(validate_result)
+ {
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+ }
+ else
+ {
+ ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+ framework::ARM_COMPUTE_PRINT_INFO();
+ }
}
-TEST_SUITE_END() // ExportToCLImage
-
TEST_SUITE_END() // FP16
TEST_SUITE_END() // Float
broadcast_bias ? 1 : m,
broadcast_bias ? 1 : batch_size);
- _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias, act_info);
- _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, broadcast_bias, act_info);
+ _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias, act_info);
+ if(validate_result)
+ {
+ _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, broadcast_bias, act_info);
+ }
}
protected:
ReshapeLHSFunctionType reshape_lhs;
ReshapeRHSFunctionType reshape_rhs;
GEMMFunctionType gemm;
+
+ validate_result = bool(reshape_rhs.validate(rhs.info(), rhs_reshaped.info(), rhs_info));
+ validate_result = validate_result || !rhs_info.export_to_cl_image;
+ if(!validate_result)
+ {
+ return nullptr;
+ }
+
reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info);
reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
}
}
+ bool validate_result = true;
TensorType _target{};
SimpleTensor<T> _reference{};
};
const TensorShape rhs_shape(n, k, batch_size);
const TensorShape bias_shape(n, 1, 1);
- _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, m_h, act_info);
- _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, m_h, act_info);
+ _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, m_h, act_info);
+ if(validate_result)
+ {
+ _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, m_h, act_info);
+ }
}
protected:
ReshapeLHSFunctionType reshape_lhs;
ReshapeRHSFunctionType reshape_rhs;
GEMMFunctionType gemm;
+
+ validate_result = bool(reshape_rhs.validate(rhs.info(), rhs_reshaped.info(), rhs_info));
+ validate_result = validate_result || !rhs_info.export_to_cl_image;
+ if(!validate_result)
+ {
+ return nullptr;
+ }
+
reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info);
reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
}
}
+ bool validate_result = true;
TensorType _target{};
SimpleTensor<T> _reference{};
};
broadcast_bias ? 1 : m,
broadcast_bias ? 1 : batch_size);
- _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias, act_info);
- _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, broadcast_bias, act_info);
+ _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias, act_info);
+ if(validate_result)
+ {
+ _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, broadcast_bias, act_info);
+ }
}
protected:
// Create and configure function
ReshapeRHSFunctionType reshape_rhs;
GEMMFunctionType gemm;
+
+ validate_result = bool(reshape_rhs.validate(rhs.info(), rhs_reshaped.info(), rhs_info));
+ validate_result = validate_result || !rhs_info.export_to_cl_image;
+ if(!validate_result)
+ {
+ return nullptr;
+ }
+
reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
}
+ bool validate_result = true;
TensorType _target{};
SimpleTensor<T> _reference{};
};
const TensorShape rhs_shape(n, k, batch_size);
const TensorShape bias_shape(n, 1, 1);
- _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, m_h, act_info, has_pad_y);
- _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, m_h, act_info);
+ _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, m_h, act_info, has_pad_y);
+ if(validate_result)
+ {
+ _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, m_h, act_info);
+ }
}
protected:
// Create and configure function
ReshapeRHSFunctionType reshape_rhs;
GEMMFunctionType gemm;
+
+ validate_result = bool(reshape_rhs.validate(rhs.info(), rhs_reshaped.info(), rhs_info));
+ validate_result = validate_result || !rhs_info.export_to_cl_image;
+ if(!validate_result)
+ {
+ return nullptr;
+ }
+
reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, kernel_info);
return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
}
+ bool validate_result = true;
TensorType _target{};
SimpleTensor<T> _reference{};
};