#include "machine_learning_exception.h"
#include "object_detection_adapter.h"
#include "object_detection_external.h"
+#include "mv_object_detection_config.h"
using namespace std;
using namespace MediaVision::Common;
{
template<typename T, typename V> ObjectDetectionAdapter<T, V>::ObjectDetectionAdapter() : _source()
{
- // In default, Mobilenet v1 ssd model will be used.
- // If other model is set by user then strategy pattern will be used
- // to create its corresponding concrete class by calling create().
- _object_detection = make_unique<MobilenetV1Ssd>(ObjectDetectionTaskType::MOBILENET_V1_SSD);
+ auto config = make_unique<EngineConfig>(string(MV_CONFIG_PATH) + _meta_file_name);
+
+ string defaultModelName;
+
+ int ret = config->getStringAttribute(MV_OBJECT_DETECTION_DEFAULT_MODEL_NAME, &defaultModelName);
+ if (ret != MEDIA_VISION_ERROR_NONE)
+ throw InvalidOperation("Fail to get default model name.");
+
+ create(convertToTaskType(defaultModelName.c_str()));
}
template<typename T, typename V> ObjectDetectionAdapter<T, V>::~ObjectDetectionAdapter()
template<typename T, typename V> void ObjectDetectionAdapter<T, V>::create(ObjectDetectionTaskType task_type)
{
- // If default task type is same as a given one then skip.
- if (_object_detection->getTaskType() == task_type)
- return;
+ // If a concrete class object created already exists, reset the object
+ // so that other concrete class object can be created again according to a given task_type.
+ if (_object_detection) {
+ // If default task type is same as a given one then skip.
+ if (_object_detection->getTaskType() == task_type)
+ return;
- _object_detection.reset();
+ _object_detection.reset();
+ }
if (task_type == ObjectDetectionTaskType::MOBILENET_V1_SSD)
_object_detection = make_unique<MobilenetV1Ssd>(task_type);
template<typename T, typename V> void ObjectDetectionAdapter<T, V>::configure()
{
- _object_detection->configure("object_detection.json");
+ _object_detection->configure(_meta_file_name);
}
template<typename T, typename V> void ObjectDetectionAdapter<T, V>::getNumberOfEngines(unsigned int *number_of_engines)