From: Pawel Wasowski Date: Thu, 18 Feb 2021 10:09:04 +0000 (+0100) Subject: [ML][Pipeline] Implement {register, unregister}CustomFilter X-Git-Tag: submit/tizen/20210304.103045~6^2 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=refs%2Fchanges%2F10%2F253510%2F13;p=platform%2Fcore%2Fapi%2Fwebapi-plugins.git [ML][Pipeline] Implement {register, unregister}CustomFilter ACR: TWDAPI-274 This is the first part of the implementation of nnstreamer's pipeline CustomFilter Web API. It enables registering and unregistering CustomFilters, but they don't process the data. [Verification] Tested with the snippets below, works fine var inputTI = new tizen.ml.TensorsInfo(); inputTI.addTensorInfo('ti1', 'UINT8', [4, 20, 15, 1]); var outputTI = new tizen.ml.TensorsInfo(); outputTI.addTensorInfo('ti1', 'UINT8', [1200]); var flattenPlusOne = function(input) { console.log("Custom filter called with: "); var outputTD = outputTI.getTensorsData(); var rawInputData = input.getTensorRawData(0); for (var i = 0; i < rawInputData.data.size; ++i) { rawInputData.data[i] = rawInputData.data[i] + 1; } outputTD.setTensorRawData(0, rawInputData.data); return new tizen.ml.CustomFilterOutput(0, outputTD); } // register - the happy scenario tizen.ml.pipeline.registerCustomFilter('testfilter2', flattenPlusOne, inputTI, outputTI, function errorCallback(error) { console.warn('custom filter error:') ; console.warn(error); }); var pipeline_def = "videotestsrc num-buffers=3 " + "! video/x-raw,width=20,height=15,format=BGRA " + "! tensor_converter " + "! tensor_filter framework=custom-easy model=testfilter2 " + "! appsink name=mysink"; var pipeline = tizen.ml.pipeline.createPipeline(pipeline_def, state => {console.log(state);}) // READY pipeline.start() // unregister - the happy scenario tizen.ml.pipeline.unregisterCustomFilter('testfilter2') // overwrite a previously registered filter - that's ok tizen.ml.pipeline.registerCustomFilter('testfilter2', flattenPlusOne, inputTI, outputTI, function errorCallback(error) { console.warn('custom filter error:') ; console.warn(error); }); tizen.ml.pipeline.registerCustomFilter('testfilter2', flattenPlusOne, inputTI, outputTI, function errorCallback(error) { console.warn('custom filter error:') ; console.warn(error); }); // unregister nonexistent filter tizen.ml.pipeline.unregisterCustomFilter('nonexistentfilter') // InvalidValuesError: ""nonexistentfilter" CustomFilter not found" Change-Id: Ib01490e669376149eae654937131116f5ec1a4df Signed-off-by: Pawel Wasowski --- diff --git a/src/ml/js/ml_pipeline.js b/src/ml/js/ml_pipeline.js index ef68dcd7..e1b211ad 100755 --- a/src/ml/js/ml_pipeline.js +++ b/src/ml/js/ml_pipeline.js @@ -16,6 +16,7 @@ var kPipelineStateChangeListenerNamePrefix = 'MLPipelineStateChangeListener'; var kSinkListenerNamePrefix = 'MLPipelineSinkListener'; +var kCustomFilterListenerNamePrefix = 'MLPipelineCustomFilterListener'; //PipelineManager::createPipeline() begin var ValidPipelineDisposeExceptions = ['NotFoundError', 'NotSupportedError', 'AbortError']; @@ -485,14 +486,6 @@ Pipeline.prototype.unregisterSinkListener = function() { }; //Pipeline::unregisterSinkListener() end -//Pipeline::registerCustomFilter() begin - -//Pipeline::registerCustomFilter() end - -//Pipeline::unregisterCustomFilter() begin - -//Pipeline::unregisterCustomFilter() end - var PropertyType = { BOOLEAN: 'BOOLEAN', DOUBLE: 'DOUBLE', @@ -726,4 +719,110 @@ var MachineLearningPipeline = function() {}; MachineLearningPipeline.prototype.createPipeline = CreatePipeline; +//Pipeline::registerCustomFilter() begin +var ValidRegisterCustomFilterExceptions = [ + 'InvalidValuesError', + 'NotSupportedError', + 'TypeMismatchError', + 'AbortError' +]; + +MachineLearningPipeline.prototype.registerCustomFilter = function() { + var args = validator_.validateArgs(arguments, [ + { + name: 'name', + type: validator_.Types.STRING + }, + { + name: 'customFilter', + type: types_.FUNCTION + }, + { + name: 'inputInfo', + type: types_.PLATFORM_OBJECT, + values: TensorsInfo + }, + { + name: 'outputInfo', + type: types_.PLATFORM_OBJECT, + values: TensorsInfo + }, + { + name: 'errorCallback', + type: types_.FUNCTION, + optional: true, + nullable: true + } + ]); + + var nativeArgs = { + name: args.name, + listenerName: kCustomFilterListenerNamePrefix + args.name, + inputTensorsInfoId: args.inputInfo._id, + outputTensorsInfoId: args.outputInfo._id + }; + + var customFilterWrapper = function(msg) { + // TODO: In the next commit + }; + + if (native_.listeners_.hasOwnProperty(nativeArgs.listenerName)) { + throw new WebAPIException( + WebAPIException.INVALID_VALUES_ERR, + '"' + nativeArgs.name + '" custom filter is already registered' + ); + } + + native_.addListener(nativeArgs.listenerName, customFilterWrapper); + + var result = native_.callSync('MLPipelineManagerRegisterCustomFilter', nativeArgs); + if (native_.isFailure(result)) { + native_.removeListener(nativeArgs.listenerName); + + throw native_.getErrorObjectAndValidate( + result, + ValidRegisterCustomFilterExceptions, + AbortError + ); + } +}; +//Pipeline::registerCustomFilter() end + +//Pipeline::unregisterCustomFilter() begin +var ValidUnregisterCustomFilterExceptions = [ + 'InvalidValuesError', + 'NotSupportedError', + 'AbortError' +]; +MachineLearningPipeline.prototype.unregisterCustomFilter = function() { + var args = validator_.validateArgs(arguments, [ + { + name: 'name', + type: validator_.Types.STRING + } + ]); + + if (!args.has.name) { + throw new WebAPIException( + WebAPIException.INVALID_VALUES_ERR, + 'Invalid parameter: custom filter name is mandatory' + ); + } + + var result = native_.callSync('MLPipelineManagerUnregisterCustomFilter', { + name: args.name + }); + if (native_.isFailure(result)) { + throw native_.getErrorObjectAndValidate( + result, + ValidUnregisterCustomFilterExceptions, + AbortError + ); + } + + var customFilterListenerName = kCustomFilterListenerNamePrefix + args.name; + native_.removeListener(customFilterListenerName); +}; +//Pipeline::unregisterCustomFilter() end + // ML Pipeline API diff --git a/src/ml/ml.gyp b/src/ml/ml.gyp index 419c44c6..b5a7a9e4 100644 --- a/src/ml/ml.gyp +++ b/src/ml/ml.gyp @@ -17,6 +17,8 @@ 'ml_instance.h', 'ml_pipeline.cc', 'ml_pipeline.h', + 'ml_pipeline_custom_filter.cc', + 'ml_pipeline_custom_filter.h', 'ml_pipeline_manager.cc', 'ml_pipeline_manager.h', 'ml_pipeline_nodeinfo.cc', diff --git a/src/ml/ml_instance.cc b/src/ml/ml_instance.cc index aee61d0a..0cdf5f52 100644 --- a/src/ml/ml_instance.cc +++ b/src/ml/ml_instance.cc @@ -53,6 +53,8 @@ const std::string kSize = "size"; const std::string kLocation = "location"; const std::string kShape = "shape"; const std::string kListenerName = "listenerName"; +const std::string kInputTensorsInfoId = "inputTensorsInfoId"; +const std::string kOutputTensorsInfoId = "outputTensorsInfoId"; } // namespace using namespace common; @@ -85,7 +87,7 @@ using namespace common; MlInstance::MlInstance() : tensors_info_manager_{&tensors_data_manager_}, single_manager_{&tensors_info_manager_}, - pipeline_manager_{this, &tensors_info_manager_} { + pipeline_manager_{this, &tensors_info_manager_, &tensors_data_manager_} { ScopeLogger(); using namespace std::placeholders; @@ -145,6 +147,8 @@ MlInstance::MlInstance() REGISTER_METHOD(MLPipelineSourceInputData); REGISTER_METHOD(MLPipelineRegisterSinkListener); REGISTER_METHOD(MLPipelineUnregisterSinkListener); + REGISTER_METHOD(MLPipelineManagerRegisterCustomFilter); + REGISTER_METHOD(MLPipelineManagerUnregisterCustomFilter); // Pipeline API end #undef REGISTER_METHOD @@ -1191,11 +1195,66 @@ void MlInstance::MLPipelineUnregisterSinkListener(const picojson::value& args, // Pipeline::unregisterSinkCallback() end // Pipeline::registerCustomFilter() begin +void MlInstance::MLPipelineManagerRegisterCustomFilter(const picojson::value& args, + picojson::object& out) { + ScopeLogger("args: %s", args.serialize().c_str()); + + CHECK_ARGS(args, kName, std::string, out); + CHECK_ARGS(args, kListenerName, std::string, out); + CHECK_ARGS(args, kInputTensorsInfoId, double, out); + CHECK_ARGS(args, kOutputTensorsInfoId, double, out); + + const auto& custom_filter_name = args.get(kName).get(); + const auto& listener_name = args.get(kListenerName).get(); + auto input_tensors_info_id = static_cast(args.get(kInputTensorsInfoId).get()); + auto output_tensors_info_id = static_cast(args.get(kOutputTensorsInfoId).get()); + + TensorsInfo* input_tensors_info_ptr = + GetTensorsInfoManager().GetTensorsInfo(input_tensors_info_id); + if (!input_tensors_info_ptr) { + LogAndReportError( + PlatformResult(ErrorCode::ABORT_ERR, "Internal TensorsInfo error"), &out, + ("Could not find TensorsInfo handle with given id: %d", input_tensors_info_id)); + return; + } + TensorsInfo* output_tensors_info_ptr = + GetTensorsInfoManager().GetTensorsInfo(output_tensors_info_id); + if (!output_tensors_info_ptr) { + LogAndReportError( + PlatformResult(ErrorCode::ABORT_ERR, "Internal TensorsInfo error"), &out, + ("Could not find TensorsInfo handle with given id: %d", output_tensors_info_id)); + return; + } + + auto ret = pipeline_manager_.RegisterCustomFilter( + custom_filter_name, listener_name, input_tensors_info_ptr, output_tensors_info_ptr); + if (!ret) { + LogAndReportError(ret, &out); + return; + } + + ReportSuccess(out); +} // Pipeline::registerCustomFilter() end // Pipeline::unregisterCustomFilter() begin +void MlInstance::MLPipelineManagerUnregisterCustomFilter(const picojson::value& args, + picojson::object& out) { + ScopeLogger("args: %s", args.serialize().c_str()); + + CHECK_ARGS(args, kName, std::string, out); + + const auto& custom_filter_name = args.get(kName).get(); + auto ret = pipeline_manager_.UnregisterCustomFilter(custom_filter_name); + if (!ret) { + LogAndReportError(ret, &out); + return; + } + + ReportSuccess(out); +} // Pipeline::unregisterCustomFilter() end // NodeInfo::getProperty() begin diff --git a/src/ml/ml_instance.h b/src/ml/ml_instance.h index a44ce690..35c83228 100644 --- a/src/ml/ml_instance.h +++ b/src/ml/ml_instance.h @@ -152,11 +152,11 @@ class MlInstance : public common::ParsedInstance { // Pipeline::unregisterSinkCallback() end // Pipeline::registerCustomFilter() begin - + void MLPipelineManagerRegisterCustomFilter(const picojson::value& args, picojson::object& out); // Pipeline::registerCustomFilter() end // Pipeline::unregisterCustomFilter() begin - + void MLPipelineManagerUnregisterCustomFilter(const picojson::value& args, picojson::object& out); // Pipeline::unregisterCustomFilter() end // NodeInfo::getProperty() begin diff --git a/src/ml/ml_pipeline_custom_filter.cc b/src/ml/ml_pipeline_custom_filter.cc new file mode 100644 index 00000000..43d55765 --- /dev/null +++ b/src/ml/ml_pipeline_custom_filter.cc @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "common/tools.h" +#include "ml_pipeline_custom_filter.h" +#include "ml_utils.h" + +using common::ErrorCode; +using common::PlatformResult; + +namespace { + +const std::string kListenerId = "listenerId"; + +} // namespace + +namespace extension { +namespace ml { +namespace pipeline { + +PlatformResult CustomFilter::CreateAndRegisterCustomFilter( + const std::string& name, const std::string& listener_name, TensorsInfo* input_tensors_info_ptr, + TensorsInfo* output_tensors_info_ptr, common::Instance* instance_ptr, + TensorsInfoManager* tensors_info_manager_ptr, TensorsDataManager* tensors_data_manager_ptr, + std::unique_ptr* out) { + ScopeLogger( + "name: [%s], listener_name: [%s], input_tensors_info::id: [%d], " + "output_tensors_info::id: [%d]", + name.c_str(), listener_name.c_str(), input_tensors_info_ptr->Id(), + output_tensors_info_ptr->Id()); + + auto* input_tensors_info_clone_ptr = + tensors_info_manager_ptr->CloneTensorsInfo(input_tensors_info_ptr); + if (!input_tensors_info_clone_ptr) { + return LogAndCreateResult( + ErrorCode::ABORT_ERR, "Could not register custom filter", + ("Could not clone TensorsInfo with id: [%d]", input_tensors_info_ptr->Id())); + } + + auto* output_tensors_info_clone_ptr = + tensors_info_manager_ptr->CloneTensorsInfo(output_tensors_info_ptr); + if (!output_tensors_info_clone_ptr) { + return LogAndCreateResult( + ErrorCode::ABORT_ERR, "Could not register custom filter", + ("Could not clone TensorsInfo with id: [%d]", output_tensors_info_ptr->Id())); + } + + auto custom_filter_ptr = std::unique_ptr(new (std::nothrow) CustomFilter{ + name, listener_name, input_tensors_info_clone_ptr, output_tensors_info_clone_ptr, + instance_ptr, tensors_info_manager_ptr, tensors_data_manager_ptr, + std::this_thread::get_id()}); + ; + if (!custom_filter_ptr) { + return LogAndCreateResult(ErrorCode::ABORT_ERR, "Could not register custom filter", + ("Could not allocate memory")); + } + + ml_custom_easy_filter_h custom_filter_handle = nullptr; + auto ret = ml_pipeline_custom_easy_filter_register( + name.c_str(), input_tensors_info_ptr->Handle(), output_tensors_info_ptr->Handle(), + CustomFilterListener, static_cast(custom_filter_ptr.get()), &custom_filter_handle); + if (ML_ERROR_NONE != ret) { + LoggerE("ml_pipeline_custom_easy_filter_register() failed: [%d] (%s)", ret, + get_error_message(ret)); + return util::ToPlatformResult(ret, "Could not register custom filter"); + } + LoggerD("ml_pipeline_custom_easy_filter_register() succeeded"); + custom_filter_ptr->custom_filter_ = custom_filter_handle; + + *out = std::move(custom_filter_ptr); + + return PlatformResult{}; +} + +CustomFilter::CustomFilter(const std::string& name, const std::string& listener_name, + TensorsInfo* input_tensors_info, TensorsInfo* output_tensors_info, + common::Instance* instance_ptr, + TensorsInfoManager* tensors_info_manager_ptr, + TensorsDataManager* tensors_data_manager_ptr, + const std::thread::id& main_thread_id) + : name_{name}, + listener_name_{listener_name}, + input_tensors_info_ptr_{input_tensors_info}, + output_tensors_info_ptr_{output_tensors_info}, + custom_filter_{nullptr}, + instance_ptr_{instance_ptr}, + tensors_info_manager_ptr_{tensors_info_manager_ptr}, + tensors_data_manager_ptr_{tensors_data_manager_ptr}, + main_thread_id_{main_thread_id} { + ScopeLogger( + "name_: [%s], listener_name_: [%s], input_tensors_info::id: [%d], " + "output_tensors_info::id: [%d]", + name_.c_str(), listener_name_.c_str(), input_tensors_info_ptr_->Id(), + output_tensors_info_ptr_->Id()); +} + +CustomFilter::~CustomFilter() { + ScopeLogger("name: [%s]_, listener_name_: [%s], custom_filter_: [%p]", name_.c_str(), + listener_name_.c_str(), custom_filter_); + + Unregister(); + tensors_info_manager_ptr_->DisposeTensorsInfo(input_tensors_info_ptr_); + tensors_info_manager_ptr_->DisposeTensorsInfo(output_tensors_info_ptr_); +} + +PlatformResult CustomFilter::Unregister() { + ScopeLogger("name_: [%s], listener_name_: [%s], custom_filter_: [%p]", name_.c_str(), + listener_name_.c_str(), custom_filter_); + + if (!custom_filter_) { + LoggerD("CustomFilter was already unregistered"); + return PlatformResult{}; + } + + auto ret = ml_pipeline_custom_easy_filter_unregister(custom_filter_); + if (ML_ERROR_NONE != ret) { + LoggerE("ml_pipeline_custom_easy_filter_unregister() failed: [%d] (%s)", ret, + get_error_message(ret)); + return util::ToPlatformResult(ret, "Could not unregister custom_filter"); + } + LoggerD("ml_pipeline_custom_easy_filter_unregister() succeeded"); + + custom_filter_ = nullptr; + + return PlatformResult{}; +} + +int CustomFilter::CustomFilterListener(const ml_tensors_data_h input_tensors_data, + ml_tensors_data_h output_tensors_data, void* user_data) { + ScopeLogger("input_tensors_data: [%p], tensors_info_out: [%p], user_data: [%p]", + input_tensors_data, output_tensors_data, user_data); + // TODO: in next commit + return -1; +} + +} // namespace pipeline +} // namespace ml +} // namespace extension \ No newline at end of file diff --git a/src/ml/ml_pipeline_custom_filter.h b/src/ml/ml_pipeline_custom_filter.h new file mode 100644 index 00000000..8f86ea43 --- /dev/null +++ b/src/ml/ml_pipeline_custom_filter.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ML_ML_PIPELINE_CUSTOM_FILTER_H_ +#define ML_ML_PIPELINE_CUSTOM_FILTER_H_ + +#include +#include +#include +#include + +#include + +#include "common/extension.h" +#include "common/platform_result.h" + +#include "ml_tensors_data_manager.h" +#include "ml_tensors_info_manager.h" + +using common::PlatformResult; + +namespace extension { +namespace ml { +namespace pipeline { + +class CustomFilter { + public: + static PlatformResult CreateAndRegisterCustomFilter( + const std::string& name, const std::string& listener_name, + TensorsInfo* input_tensors_info_ptr, TensorsInfo* output_tensors_info_ptr, + common::Instance* instance_ptr, TensorsInfoManager* tensors_info_manager_ptr, + TensorsDataManager* tensors_data_manager_ptr, std::unique_ptr* out); + + ~CustomFilter(); + + PlatformResult Unregister(); + + CustomFilter(const CustomFilter&) = delete; + CustomFilter& operator=(const CustomFilter&) = delete; + + private: + CustomFilter(const std::string& name, const std::string& listener_name, + TensorsInfo* input_tensors_info, TensorsInfo* output_tensors_info, + common::Instance* instance_ptr, TensorsInfoManager* tensors_info_manager_ptr, + TensorsDataManager* tensors_data_manager_ptrr, + const std::thread::id& main_thread_id); + + static int CustomFilterListener(const ml_tensors_data_h tensors_data_in, + ml_tensors_data_h tensors_data_out, void* user_data); + + const std::string name_; + const std::string listener_name_; + TensorsInfo* input_tensors_info_ptr_; + TensorsInfo* output_tensors_info_ptr_; + ml_custom_easy_filter_h custom_filter_; + common::Instance* instance_ptr_; + TensorsInfoManager* tensors_info_manager_ptr_; + TensorsDataManager* tensors_data_manager_ptr_; + + std::thread::id main_thread_id_; +}; + +} // namespace pipeline +} // namespace ml +} // namespace extension + +#endif // ML_ML_PIPELINE_CUSTOM_FILTER_H_ diff --git a/src/ml/ml_pipeline_manager.cc b/src/ml/ml_pipeline_manager.cc index c2c3e1a8..7ed62b57 100644 --- a/src/ml/ml_pipeline_manager.cc +++ b/src/ml/ml_pipeline_manager.cc @@ -14,8 +14,11 @@ * limitations under the License. */ -#include "ml_pipeline_manager.h" +#include +#include + #include "common/tools.h" +#include "ml_pipeline_manager.h" #include "ml_pipeline_switch.h" using common::PlatformResult; @@ -27,8 +30,11 @@ namespace extension { namespace ml { PipelineManager::PipelineManager(common::Instance* instance_ptr, - TensorsInfoManager* tensors_info_manager) - : instance_ptr_{instance_ptr}, tensors_info_manager_{tensors_info_manager} { + TensorsInfoManager* tensors_info_manager, + TensorsDataManager* tensors_data_manager) + : instance_ptr_{instance_ptr}, + tensors_info_manager_{tensors_info_manager}, + tensors_data_manager_{tensors_data_manager} { ScopeLogger(); } @@ -220,11 +226,53 @@ PlatformResult PipelineManager::UnregisterSinkListener(const std::string& sink_n // Pipeline::unregisterSinkCallback() end // Pipeline::registerCustomFilter() begin +PlatformResult PipelineManager::RegisterCustomFilter(const std::string& custom_filter_name, + const std::string& listener_name, + TensorsInfo* input_tensors_info_ptr, + TensorsInfo* output_tensors_info_ptr) { + ScopeLogger( + "custom_filter_name: [%s], listener_name: [%s], input_tensors_info::id: [%d], " + "output_tensors_info::id: [%d]", + custom_filter_name.c_str(), listener_name.c_str(), input_tensors_info_ptr->Id(), + output_tensors_info_ptr->Id()); + + if (custom_filters_.count(custom_filter_name)) { + LoggerE("Listener for [%s] custom_filter is already registered", custom_filter_name.c_str()); + return PlatformResult{ErrorCode::ABORT_ERR, "Internal CustomFilter error"}; + } + + std::unique_ptr custom_filter_ptr; + auto ret = CustomFilter::CreateAndRegisterCustomFilter( + custom_filter_name, listener_name, input_tensors_info_ptr, output_tensors_info_ptr, + instance_ptr_, tensors_info_manager_, tensors_data_manager_, &custom_filter_ptr); + + if (!ret) { + return ret; + } + custom_filters_.insert({custom_filter_name, std::move(custom_filter_ptr)}); + + return PlatformResult{}; +} // Pipeline::registerCustomFilter() end // Pipeline::unregisterCustomFilter() begin +PlatformResult PipelineManager::UnregisterCustomFilter(const std::string& custom_filter_name) { + ScopeLogger("custom_filter_name: [%s]", custom_filter_name.c_str()); + + auto custom_filter_it = custom_filters_.find(custom_filter_name); + if (custom_filters_.end() == custom_filter_it) { + LoggerD("custom_filter [%s] not found", custom_filter_name.c_str()); + return PlatformResult{ErrorCode::INVALID_VALUES_ERR, + "\"" + custom_filter_name + "\" CustomFilter not found"}; + } + auto ret = custom_filter_it->second->Unregister(); + if (ret) { + custom_filters_.erase(custom_filter_it); + } + return ret; +} // Pipeline::unregisterCustomFilter() end // NodeInfo::getProperty() begin diff --git a/src/ml/ml_pipeline_manager.h b/src/ml/ml_pipeline_manager.h index a2e6fa9c..e490f829 100644 --- a/src/ml/ml_pipeline_manager.h +++ b/src/ml/ml_pipeline_manager.h @@ -22,6 +22,7 @@ #include "common/platform_result.h" #include "ml_pipeline.h" +#include "ml_pipeline_custom_filter.h" #include "ml_tensors_data_manager.h" #include "ml_tensors_info_manager.h" @@ -32,7 +33,7 @@ namespace ml { class PipelineManager { public: - PipelineManager(common::Instance* instance_ptr, TensorsInfoManager* tim); + PipelineManager(common::Instance* instance_ptr, TensorsInfoManager* tim, TensorsDataManager* tdm); ~PipelineManager(); @@ -87,11 +88,15 @@ class PipelineManager { // Pipeline::unregisterSinkCallback() end // Pipeline::registerCustomFilter() begin + PlatformResult RegisterCustomFilter(const std::string& custom_filter_name, + const std::string& listener_name, + TensorsInfo* input_tensors_info_ptr, + TensorsInfo* output_tensors_info_ptr); // Pipeline::registerCustomFilter() end // Pipeline::unregisterCustomFilter() begin - + PlatformResult UnregisterCustomFilter(const std::string& custom_filter_name); // Pipeline::unregisterCustomFilter() end // NodeInfo::getProperty() begin @@ -132,7 +137,9 @@ class PipelineManager { private: common::Instance* instance_ptr_; TensorsInfoManager* tensors_info_manager_; + TensorsDataManager* tensors_data_manager_; std::map> pipelines_; + std::unordered_map> custom_filters_; }; } // namespace ml