[ML][Pipeline] Implement {register, unregister}CustomFilter 10/253510/13
authorPawel Wasowski <p.wasowski2@samsung.com>
Thu, 18 Feb 2021 10:09:04 +0000 (11:09 +0100)
committerPawel Wasowski <p.wasowski2@samsung.com>
Tue, 23 Feb 2021 17:45:46 +0000 (18:45 +0100)
ACR: TWDAPI-274

This is the first part of the implementation of nnstreamer's pipeline
CustomFilter Web API. It enables registering and unregistering
CustomFilters, but they don't process the data.

[Verification] Tested with the snippets below, works fine

var inputTI = new tizen.ml.TensorsInfo();
inputTI.addTensorInfo('ti1', 'UINT8', [4, 20, 15, 1]);
var outputTI = new tizen.ml.TensorsInfo();
outputTI.addTensorInfo('ti1', 'UINT8', [1200]);
var flattenPlusOne = function(input) {
    console.log("Custom filter called with: ");
    var outputTD = outputTI.getTensorsData();
    var rawInputData = input.getTensorRawData(0);
    for (var i = 0; i < rawInputData.data.size; ++i) {
        rawInputData.data[i] = rawInputData.data[i] + 1;
    }
        outputTD.setTensorRawData(0, rawInputData.data);
        return new tizen.ml.CustomFilterOutput(0, outputTD);
}

// register - the happy scenario
tizen.ml.pipeline.registerCustomFilter('testfilter2', flattenPlusOne, inputTI,
    outputTI, function errorCallback(error) {
        console.warn('custom filter error:') ; console.warn(error);
    });

var pipeline_def = "videotestsrc num-buffers=3 "
                   + "! video/x-raw,width=20,height=15,format=BGRA "
                   + "! tensor_converter "
                   + "! tensor_filter framework=custom-easy model=testfilter2 "
                   + "! appsink name=mysink";

var pipeline = tizen.ml.pipeline.createPipeline(pipeline_def,
                                                state => {console.log(state);})
// READY
pipeline.start()

// unregister - the happy scenario
tizen.ml.pipeline.unregisterCustomFilter('testfilter2')

// overwrite a previously registered filter - that's ok

tizen.ml.pipeline.registerCustomFilter('testfilter2', flattenPlusOne, inputTI,
    outputTI, function errorCallback(error) {
        console.warn('custom filter error:') ; console.warn(error);
    });

tizen.ml.pipeline.registerCustomFilter('testfilter2', flattenPlusOne, inputTI,
    outputTI, function errorCallback(error) {
        console.warn('custom filter error:') ; console.warn(error);
    });

// unregister nonexistent filter
tizen.ml.pipeline.unregisterCustomFilter('nonexistentfilter')
// InvalidValuesError: ""nonexistentfilter" CustomFilter not found"

Change-Id: Ib01490e669376149eae654937131116f5ec1a4df
Signed-off-by: Pawel Wasowski <p.wasowski2@samsung.com>
src/ml/js/ml_pipeline.js
src/ml/ml.gyp
src/ml/ml_instance.cc
src/ml/ml_instance.h
src/ml/ml_pipeline_custom_filter.cc [new file with mode: 0644]
src/ml/ml_pipeline_custom_filter.h [new file with mode: 0644]
src/ml/ml_pipeline_manager.cc
src/ml/ml_pipeline_manager.h

index ef68dcd..e1b211a 100755 (executable)
@@ -16,6 +16,7 @@
 
 var kPipelineStateChangeListenerNamePrefix = 'MLPipelineStateChangeListener';
 var kSinkListenerNamePrefix = 'MLPipelineSinkListener';
+var kCustomFilterListenerNamePrefix = 'MLPipelineCustomFilterListener';
 
 //PipelineManager::createPipeline() begin
 var ValidPipelineDisposeExceptions = ['NotFoundError', 'NotSupportedError', 'AbortError'];
@@ -485,14 +486,6 @@ Pipeline.prototype.unregisterSinkListener = function() {
 };
 //Pipeline::unregisterSinkListener() end
 
-//Pipeline::registerCustomFilter() begin
-
-//Pipeline::registerCustomFilter() end
-
-//Pipeline::unregisterCustomFilter() begin
-
-//Pipeline::unregisterCustomFilter() end
-
 var PropertyType = {
     BOOLEAN: 'BOOLEAN',
     DOUBLE: 'DOUBLE',
@@ -726,4 +719,110 @@ var MachineLearningPipeline = function() {};
 
 MachineLearningPipeline.prototype.createPipeline = CreatePipeline;
 
+//Pipeline::registerCustomFilter() begin
+var ValidRegisterCustomFilterExceptions = [
+    'InvalidValuesError',
+    'NotSupportedError',
+    'TypeMismatchError',
+    'AbortError'
+];
+
+MachineLearningPipeline.prototype.registerCustomFilter = function() {
+    var args = validator_.validateArgs(arguments, [
+        {
+            name: 'name',
+            type: validator_.Types.STRING
+        },
+        {
+            name: 'customFilter',
+            type: types_.FUNCTION
+        },
+        {
+            name: 'inputInfo',
+            type: types_.PLATFORM_OBJECT,
+            values: TensorsInfo
+        },
+        {
+            name: 'outputInfo',
+            type: types_.PLATFORM_OBJECT,
+            values: TensorsInfo
+        },
+        {
+            name: 'errorCallback',
+            type: types_.FUNCTION,
+            optional: true,
+            nullable: true
+        }
+    ]);
+
+    var nativeArgs = {
+        name: args.name,
+        listenerName: kCustomFilterListenerNamePrefix + args.name,
+        inputTensorsInfoId: args.inputInfo._id,
+        outputTensorsInfoId: args.outputInfo._id
+    };
+
+    var customFilterWrapper = function(msg) {
+        // TODO: In the next commit
+    };
+
+    if (native_.listeners_.hasOwnProperty(nativeArgs.listenerName)) {
+        throw new WebAPIException(
+            WebAPIException.INVALID_VALUES_ERR,
+            '"' + nativeArgs.name + '" custom filter is already registered'
+        );
+    }
+
+    native_.addListener(nativeArgs.listenerName, customFilterWrapper);
+
+    var result = native_.callSync('MLPipelineManagerRegisterCustomFilter', nativeArgs);
+    if (native_.isFailure(result)) {
+        native_.removeListener(nativeArgs.listenerName);
+
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidRegisterCustomFilterExceptions,
+            AbortError
+        );
+    }
+};
+//Pipeline::registerCustomFilter() end
+
+//Pipeline::unregisterCustomFilter() begin
+var ValidUnregisterCustomFilterExceptions = [
+    'InvalidValuesError',
+    'NotSupportedError',
+    'AbortError'
+];
+MachineLearningPipeline.prototype.unregisterCustomFilter = function() {
+    var args = validator_.validateArgs(arguments, [
+        {
+            name: 'name',
+            type: validator_.Types.STRING
+        }
+    ]);
+
+    if (!args.has.name) {
+        throw new WebAPIException(
+            WebAPIException.INVALID_VALUES_ERR,
+            'Invalid parameter: custom filter name is mandatory'
+        );
+    }
+
+    var result = native_.callSync('MLPipelineManagerUnregisterCustomFilter', {
+        name: args.name
+    });
+    if (native_.isFailure(result)) {
+        throw native_.getErrorObjectAndValidate(
+            result,
+            ValidUnregisterCustomFilterExceptions,
+            AbortError
+        );
+    }
+
+    var customFilterListenerName = kCustomFilterListenerNamePrefix + args.name;
+    native_.removeListener(customFilterListenerName);
+};
+//Pipeline::unregisterCustomFilter() end
+
 // ML Pipeline API
index 419c44c..b5a7a9e 100644 (file)
@@ -17,6 +17,8 @@
         'ml_instance.h',
         'ml_pipeline.cc',
         'ml_pipeline.h',
+        'ml_pipeline_custom_filter.cc',
+        'ml_pipeline_custom_filter.h',
         'ml_pipeline_manager.cc',
         'ml_pipeline_manager.h',
         'ml_pipeline_nodeinfo.cc',
index aee61d0..0cdf5f5 100644 (file)
@@ -53,6 +53,8 @@ const std::string kSize = "size";
 const std::string kLocation = "location";
 const std::string kShape = "shape";
 const std::string kListenerName = "listenerName";
+const std::string kInputTensorsInfoId = "inputTensorsInfoId";
+const std::string kOutputTensorsInfoId = "outputTensorsInfoId";
 }  //  namespace
 
 using namespace common;
@@ -85,7 +87,7 @@ using namespace common;
 MlInstance::MlInstance()
     : tensors_info_manager_{&tensors_data_manager_},
       single_manager_{&tensors_info_manager_},
-      pipeline_manager_{this, &tensors_info_manager_} {
+      pipeline_manager_{this, &tensors_info_manager_, &tensors_data_manager_} {
   ScopeLogger();
   using namespace std::placeholders;
 
@@ -145,6 +147,8 @@ MlInstance::MlInstance()
   REGISTER_METHOD(MLPipelineSourceInputData);
   REGISTER_METHOD(MLPipelineRegisterSinkListener);
   REGISTER_METHOD(MLPipelineUnregisterSinkListener);
+  REGISTER_METHOD(MLPipelineManagerRegisterCustomFilter);
+  REGISTER_METHOD(MLPipelineManagerUnregisterCustomFilter);
 // Pipeline API end
 
 #undef REGISTER_METHOD
@@ -1191,11 +1195,66 @@ void MlInstance::MLPipelineUnregisterSinkListener(const picojson::value& args,
 // Pipeline::unregisterSinkCallback() end
 
 // Pipeline::registerCustomFilter() begin
+void MlInstance::MLPipelineManagerRegisterCustomFilter(const picojson::value& args,
+                                                       picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+
+  CHECK_ARGS(args, kName, std::string, out);
+  CHECK_ARGS(args, kListenerName, std::string, out);
+  CHECK_ARGS(args, kInputTensorsInfoId, double, out);
+  CHECK_ARGS(args, kOutputTensorsInfoId, double, out);
+
+  const auto& custom_filter_name = args.get(kName).get<std::string>();
+  const auto& listener_name = args.get(kListenerName).get<std::string>();
+  auto input_tensors_info_id = static_cast<int>(args.get(kInputTensorsInfoId).get<double>());
+  auto output_tensors_info_id = static_cast<int>(args.get(kOutputTensorsInfoId).get<double>());
+
+  TensorsInfo* input_tensors_info_ptr =
+      GetTensorsInfoManager().GetTensorsInfo(input_tensors_info_id);
+  if (!input_tensors_info_ptr) {
+    LogAndReportError(
+        PlatformResult(ErrorCode::ABORT_ERR, "Internal TensorsInfo error"), &out,
+        ("Could not find TensorsInfo handle with given id: %d", input_tensors_info_id));
+    return;
+  }
 
+  TensorsInfo* output_tensors_info_ptr =
+      GetTensorsInfoManager().GetTensorsInfo(output_tensors_info_id);
+  if (!output_tensors_info_ptr) {
+    LogAndReportError(
+        PlatformResult(ErrorCode::ABORT_ERR, "Internal TensorsInfo error"), &out,
+        ("Could not find TensorsInfo handle with given id: %d", output_tensors_info_id));
+    return;
+  }
+
+  auto ret = pipeline_manager_.RegisterCustomFilter(
+      custom_filter_name, listener_name, input_tensors_info_ptr, output_tensors_info_ptr);
+  if (!ret) {
+    LogAndReportError(ret, &out);
+    return;
+  }
+
+  ReportSuccess(out);
+}
 // Pipeline::registerCustomFilter() end
 
 // Pipeline::unregisterCustomFilter() begin
+void MlInstance::MLPipelineManagerUnregisterCustomFilter(const picojson::value& args,
+                                                         picojson::object& out) {
+  ScopeLogger("args: %s", args.serialize().c_str());
+
+  CHECK_ARGS(args, kName, std::string, out);
+
+  const auto& custom_filter_name = args.get(kName).get<std::string>();
 
+  auto ret = pipeline_manager_.UnregisterCustomFilter(custom_filter_name);
+  if (!ret) {
+    LogAndReportError(ret, &out);
+    return;
+  }
+
+  ReportSuccess(out);
+}
 // Pipeline::unregisterCustomFilter() end
 
 // NodeInfo::getProperty() begin
index a44ce69..35c8322 100644 (file)
@@ -152,11 +152,11 @@ class MlInstance : public common::ParsedInstance {
   // Pipeline::unregisterSinkCallback() end
 
   // Pipeline::registerCustomFilter() begin
-
+  void MLPipelineManagerRegisterCustomFilter(const picojson::value& args, picojson::object& out);
   // Pipeline::registerCustomFilter() end
 
   // Pipeline::unregisterCustomFilter() begin
-
+  void MLPipelineManagerUnregisterCustomFilter(const picojson::value& args, picojson::object& out);
   // Pipeline::unregisterCustomFilter() end
 
   // NodeInfo::getProperty() begin
diff --git a/src/ml/ml_pipeline_custom_filter.cc b/src/ml/ml_pipeline_custom_filter.cc
new file mode 100644 (file)
index 0000000..43d5576
--- /dev/null
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ *    Licensed under the Apache License, Version 2.0 (the "License");
+ *    you may not use this file except in compliance with the License.
+ *    You may obtain a copy of the License at
+ *
+ *        http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *    Unless required by applicable law or agreed to in writing, software
+ *    distributed under the License is distributed on an "AS IS" BASIS,
+ *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *    See the License for the specific language governing permissions and
+ *    limitations under the License.
+ */
+
+#include <nnstreamer/nnstreamer.h>
+#include <utility>
+
+#include "common/tools.h"
+#include "ml_pipeline_custom_filter.h"
+#include "ml_utils.h"
+
+using common::ErrorCode;
+using common::PlatformResult;
+
+namespace {
+
+const std::string kListenerId = "listenerId";
+
+}  //  namespace
+
+namespace extension {
+namespace ml {
+namespace pipeline {
+
+PlatformResult CustomFilter::CreateAndRegisterCustomFilter(
+    const std::string& name, const std::string& listener_name, TensorsInfo* input_tensors_info_ptr,
+    TensorsInfo* output_tensors_info_ptr, common::Instance* instance_ptr,
+    TensorsInfoManager* tensors_info_manager_ptr, TensorsDataManager* tensors_data_manager_ptr,
+    std::unique_ptr<CustomFilter>* out) {
+  ScopeLogger(
+      "name: [%s], listener_name: [%s], input_tensors_info::id: [%d], "
+      "output_tensors_info::id: [%d]",
+      name.c_str(), listener_name.c_str(), input_tensors_info_ptr->Id(),
+      output_tensors_info_ptr->Id());
+
+  auto* input_tensors_info_clone_ptr =
+      tensors_info_manager_ptr->CloneTensorsInfo(input_tensors_info_ptr);
+  if (!input_tensors_info_clone_ptr) {
+    return LogAndCreateResult(
+        ErrorCode::ABORT_ERR, "Could not register custom filter",
+        ("Could not clone TensorsInfo with id: [%d]", input_tensors_info_ptr->Id()));
+  }
+
+  auto* output_tensors_info_clone_ptr =
+      tensors_info_manager_ptr->CloneTensorsInfo(output_tensors_info_ptr);
+  if (!output_tensors_info_clone_ptr) {
+    return LogAndCreateResult(
+        ErrorCode::ABORT_ERR, "Could not register custom filter",
+        ("Could not clone TensorsInfo with id: [%d]", output_tensors_info_ptr->Id()));
+  }
+
+  auto custom_filter_ptr = std::unique_ptr<CustomFilter>(new (std::nothrow) CustomFilter{
+      name, listener_name, input_tensors_info_clone_ptr, output_tensors_info_clone_ptr,
+      instance_ptr, tensors_info_manager_ptr, tensors_data_manager_ptr,
+      std::this_thread::get_id()});
+  ;
+  if (!custom_filter_ptr) {
+    return LogAndCreateResult(ErrorCode::ABORT_ERR, "Could not register custom filter",
+                              ("Could not allocate memory"));
+  }
+
+  ml_custom_easy_filter_h custom_filter_handle = nullptr;
+  auto ret = ml_pipeline_custom_easy_filter_register(
+      name.c_str(), input_tensors_info_ptr->Handle(), output_tensors_info_ptr->Handle(),
+      CustomFilterListener, static_cast<void*>(custom_filter_ptr.get()), &custom_filter_handle);
+  if (ML_ERROR_NONE != ret) {
+    LoggerE("ml_pipeline_custom_easy_filter_register() failed: [%d] (%s)", ret,
+            get_error_message(ret));
+    return util::ToPlatformResult(ret, "Could not register custom filter");
+  }
+  LoggerD("ml_pipeline_custom_easy_filter_register() succeeded");
+  custom_filter_ptr->custom_filter_ = custom_filter_handle;
+
+  *out = std::move(custom_filter_ptr);
+
+  return PlatformResult{};
+}
+
+CustomFilter::CustomFilter(const std::string& name, const std::string& listener_name,
+                           TensorsInfo* input_tensors_info, TensorsInfo* output_tensors_info,
+                           common::Instance* instance_ptr,
+                           TensorsInfoManager* tensors_info_manager_ptr,
+                           TensorsDataManager* tensors_data_manager_ptr,
+                           const std::thread::id& main_thread_id)
+    : name_{name},
+      listener_name_{listener_name},
+      input_tensors_info_ptr_{input_tensors_info},
+      output_tensors_info_ptr_{output_tensors_info},
+      custom_filter_{nullptr},
+      instance_ptr_{instance_ptr},
+      tensors_info_manager_ptr_{tensors_info_manager_ptr},
+      tensors_data_manager_ptr_{tensors_data_manager_ptr},
+      main_thread_id_{main_thread_id} {
+  ScopeLogger(
+      "name_: [%s], listener_name_: [%s], input_tensors_info::id: [%d], "
+      "output_tensors_info::id: [%d]",
+      name_.c_str(), listener_name_.c_str(), input_tensors_info_ptr_->Id(),
+      output_tensors_info_ptr_->Id());
+}
+
+CustomFilter::~CustomFilter() {
+  ScopeLogger("name: [%s]_, listener_name_: [%s], custom_filter_: [%p]", name_.c_str(),
+              listener_name_.c_str(), custom_filter_);
+
+  Unregister();
+  tensors_info_manager_ptr_->DisposeTensorsInfo(input_tensors_info_ptr_);
+  tensors_info_manager_ptr_->DisposeTensorsInfo(output_tensors_info_ptr_);
+}
+
+PlatformResult CustomFilter::Unregister() {
+  ScopeLogger("name_: [%s], listener_name_: [%s], custom_filter_: [%p]", name_.c_str(),
+              listener_name_.c_str(), custom_filter_);
+
+  if (!custom_filter_) {
+    LoggerD("CustomFilter was already unregistered");
+    return PlatformResult{};
+  }
+
+  auto ret = ml_pipeline_custom_easy_filter_unregister(custom_filter_);
+  if (ML_ERROR_NONE != ret) {
+    LoggerE("ml_pipeline_custom_easy_filter_unregister() failed: [%d] (%s)", ret,
+            get_error_message(ret));
+    return util::ToPlatformResult(ret, "Could not unregister custom_filter");
+  }
+  LoggerD("ml_pipeline_custom_easy_filter_unregister() succeeded");
+
+  custom_filter_ = nullptr;
+
+  return PlatformResult{};
+}
+
+int CustomFilter::CustomFilterListener(const ml_tensors_data_h input_tensors_data,
+                                       ml_tensors_data_h output_tensors_data, void* user_data) {
+  ScopeLogger("input_tensors_data: [%p], tensors_info_out: [%p], user_data: [%p]",
+              input_tensors_data, output_tensors_data, user_data);
+  // TODO: in next commit
+  return -1;
+}
+
+}  // namespace pipeline
+}  // namespace ml
+}  // namespace extension
\ No newline at end of file
diff --git a/src/ml/ml_pipeline_custom_filter.h b/src/ml/ml_pipeline_custom_filter.h
new file mode 100644 (file)
index 0000000..8f86ea4
--- /dev/null
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
+ *
+ *    Licensed under the Apache License, Version 2.0 (the "License");
+ *    you may not use this file except in compliance with the License.
+ *    You may obtain a copy of the License at
+ *
+ *        http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *    Unless required by applicable law or agreed to in writing, software
+ *    distributed under the License is distributed on an "AS IS" BASIS,
+ *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *    See the License for the specific language governing permissions and
+ *    limitations under the License.
+ */
+
+#ifndef ML_ML_PIPELINE_CUSTOM_FILTER_H_
+#define ML_ML_PIPELINE_CUSTOM_FILTER_H_
+
+#include <memory>
+#include <string>
+#include <thread>
+#include <unordered_map>
+
+#include <nnstreamer/nnstreamer.h>
+
+#include "common/extension.h"
+#include "common/platform_result.h"
+
+#include "ml_tensors_data_manager.h"
+#include "ml_tensors_info_manager.h"
+
+using common::PlatformResult;
+
+namespace extension {
+namespace ml {
+namespace pipeline {
+
+class CustomFilter {
+ public:
+  static PlatformResult CreateAndRegisterCustomFilter(
+      const std::string& name, const std::string& listener_name,
+      TensorsInfo* input_tensors_info_ptr, TensorsInfo* output_tensors_info_ptr,
+      common::Instance* instance_ptr, TensorsInfoManager* tensors_info_manager_ptr,
+      TensorsDataManager* tensors_data_manager_ptr, std::unique_ptr<CustomFilter>* out);
+
+  ~CustomFilter();
+
+  PlatformResult Unregister();
+
+  CustomFilter(const CustomFilter&) = delete;
+  CustomFilter& operator=(const CustomFilter&) = delete;
+
+ private:
+  CustomFilter(const std::string& name, const std::string& listener_name,
+               TensorsInfo* input_tensors_info, TensorsInfo* output_tensors_info,
+               common::Instance* instance_ptr, TensorsInfoManager* tensors_info_manager_ptr,
+               TensorsDataManager* tensors_data_manager_ptrr,
+               const std::thread::id& main_thread_id);
+
+  static int CustomFilterListener(const ml_tensors_data_h tensors_data_in,
+                                  ml_tensors_data_h tensors_data_out, void* user_data);
+
+  const std::string name_;
+  const std::string listener_name_;
+  TensorsInfo* input_tensors_info_ptr_;
+  TensorsInfo* output_tensors_info_ptr_;
+  ml_custom_easy_filter_h custom_filter_;
+  common::Instance* instance_ptr_;
+  TensorsInfoManager* tensors_info_manager_ptr_;
+  TensorsDataManager* tensors_data_manager_ptr_;
+
+  std::thread::id main_thread_id_;
+};
+
+}  // namespace pipeline
+}  // namespace ml
+}  // namespace extension
+
+#endif  // ML_ML_PIPELINE_CUSTOM_FILTER_H_
index c2c3e1a..7ed62b5 100644 (file)
  *    limitations under the License.
  */
 
-#include "ml_pipeline_manager.h"
+#include <algorithm>
+#include <regex>
+
 #include "common/tools.h"
+#include "ml_pipeline_manager.h"
 #include "ml_pipeline_switch.h"
 
 using common::PlatformResult;
@@ -27,8 +30,11 @@ namespace extension {
 namespace ml {
 
 PipelineManager::PipelineManager(common::Instance* instance_ptr,
-                                 TensorsInfoManager* tensors_info_manager)
-    : instance_ptr_{instance_ptr}, tensors_info_manager_{tensors_info_manager} {
+                                 TensorsInfoManager* tensors_info_manager,
+                                 TensorsDataManager* tensors_data_manager)
+    : instance_ptr_{instance_ptr},
+      tensors_info_manager_{tensors_info_manager},
+      tensors_data_manager_{tensors_data_manager} {
   ScopeLogger();
 }
 
@@ -220,11 +226,53 @@ PlatformResult PipelineManager::UnregisterSinkListener(const std::string& sink_n
 // Pipeline::unregisterSinkCallback() end
 
 // Pipeline::registerCustomFilter() begin
+PlatformResult PipelineManager::RegisterCustomFilter(const std::string& custom_filter_name,
+                                                     const std::string& listener_name,
+                                                     TensorsInfo* input_tensors_info_ptr,
+                                                     TensorsInfo* output_tensors_info_ptr) {
+  ScopeLogger(
+      "custom_filter_name: [%s], listener_name: [%s], input_tensors_info::id: [%d], "
+      "output_tensors_info::id: [%d]",
+      custom_filter_name.c_str(), listener_name.c_str(), input_tensors_info_ptr->Id(),
+      output_tensors_info_ptr->Id());
+
+  if (custom_filters_.count(custom_filter_name)) {
+    LoggerE("Listener for [%s] custom_filter is already registered", custom_filter_name.c_str());
+    return PlatformResult{ErrorCode::ABORT_ERR, "Internal CustomFilter error"};
+  }
+
+  std::unique_ptr<CustomFilter> custom_filter_ptr;
+  auto ret = CustomFilter::CreateAndRegisterCustomFilter(
+      custom_filter_name, listener_name, input_tensors_info_ptr, output_tensors_info_ptr,
+      instance_ptr_, tensors_info_manager_, tensors_data_manager_, &custom_filter_ptr);
+
+  if (!ret) {
+    return ret;
+  }
 
+  custom_filters_.insert({custom_filter_name, std::move(custom_filter_ptr)});
+
+  return PlatformResult{};
+}
 // Pipeline::registerCustomFilter() end
 
 // Pipeline::unregisterCustomFilter() begin
+PlatformResult PipelineManager::UnregisterCustomFilter(const std::string& custom_filter_name) {
+  ScopeLogger("custom_filter_name: [%s]", custom_filter_name.c_str());
+
+  auto custom_filter_it = custom_filters_.find(custom_filter_name);
+  if (custom_filters_.end() == custom_filter_it) {
+    LoggerD("custom_filter [%s] not found", custom_filter_name.c_str());
+    return PlatformResult{ErrorCode::INVALID_VALUES_ERR,
+                          "\"" + custom_filter_name + "\" CustomFilter not found"};
+  }
 
+  auto ret = custom_filter_it->second->Unregister();
+  if (ret) {
+    custom_filters_.erase(custom_filter_it);
+  }
+  return ret;
+}
 // Pipeline::unregisterCustomFilter() end
 
 // NodeInfo::getProperty() begin
index a2e6fa9..e490f82 100644 (file)
@@ -22,6 +22,7 @@
 #include "common/platform_result.h"
 
 #include "ml_pipeline.h"
+#include "ml_pipeline_custom_filter.h"
 #include "ml_tensors_data_manager.h"
 #include "ml_tensors_info_manager.h"
 
@@ -32,7 +33,7 @@ namespace ml {
 
 class PipelineManager {
  public:
-  PipelineManager(common::Instance* instance_ptr, TensorsInfoManager* tim);
+  PipelineManager(common::Instance* instance_ptr, TensorsInfoManager* tim, TensorsDataManager* tdm);
 
   ~PipelineManager();
 
@@ -87,11 +88,15 @@ class PipelineManager {
   // Pipeline::unregisterSinkCallback() end
 
   // Pipeline::registerCustomFilter() begin
+  PlatformResult RegisterCustomFilter(const std::string& custom_filter_name,
+                                      const std::string& listener_name,
+                                      TensorsInfo* input_tensors_info_ptr,
+                                      TensorsInfo* output_tensors_info_ptr);
 
   // Pipeline::registerCustomFilter() end
 
   // Pipeline::unregisterCustomFilter() begin
-
+  PlatformResult UnregisterCustomFilter(const std::string& custom_filter_name);
   // Pipeline::unregisterCustomFilter() end
 
   // NodeInfo::getProperty() begin
@@ -132,7 +137,9 @@ class PipelineManager {
  private:
   common::Instance* instance_ptr_;
   TensorsInfoManager* tensors_info_manager_;
+  TensorsDataManager* tensors_data_manager_;
   std::map<int, std::unique_ptr<Pipeline>> pipelines_;
+  std::unordered_map<std::string, std::unique_ptr<CustomFilter>> custom_filters_;
 };
 
 }  // namespace ml