[ML][common] Add a method for cloning native tensors 99/252699/9
authorPawel Wasowski <p.wasowski2@samsung.com>
Mon, 1 Feb 2021 14:01:02 +0000 (15:01 +0100)
committerPawel Wasowski <p.wasowski2@samsung.com>
Tue, 2 Feb 2021 09:14:04 +0000 (10:14 +0100)
ACR: TWDAPI-273/TWDAPI-274

This commit adds a method for cloning data and info from
ml_tensors_data_h and ml_tensors_info_h to TensorsData and related
TensorsInfo objects.

[Verification] Code compiles. Operation is verified in the next commit.

Change-Id: I0268e4d902b7f9a25b5ae531471cdf2028acd77b
Signed-off-by: Pawel Wasowski <p.wasowski2@samsung.com>
src/ml/ml_tensors_info_manager.cc
src/ml/ml_tensors_info_manager.h

index 9a052df58ae403d54367281309dccf7ac4e8c75c..b376e85ac7f08556167a87d20367444e5ffa27ab 100644 (file)
@@ -20,6 +20,7 @@
 
 using common::ErrorCode;
 using common::PlatformResult;
+using extension::ml::TensorRawData;
 
 namespace extension {
 namespace ml {
@@ -402,5 +403,63 @@ TensorsData* TensorsInfoManager::CreateTensorsData(TensorsInfo* tensors_info) {
   return tensors_data_manager_->CreateTensorsData(t_info);
 };
 
+TensorsData* TensorsInfoManager::CloneNativeTensorWithData(ml_tensors_info_h tensors_info_src,
+                                                           ml_tensors_data_h tensors_data_src) {
+  ScopeLogger("tensors_info_src: [%p], tensors_data_src: [%p]", tensors_info_src, tensors_data_src);
+
+  auto* tensors_info_clone = CreateTensorsInfo();
+  if (!tensors_info_clone) {
+    LoggerE("Could not create TensorsInfo");
+    return nullptr;
+  }
+
+  auto ret = ml_tensors_info_clone(tensors_info_clone->Handle(), tensors_info_src);
+  if (ML_ERROR_NONE != ret) {
+    LoggerE("ml_tensors_info_clone() failed: [%d] (%s)", ret, get_error_message(ret));
+    DisposeTensorsInfo(tensors_info_clone);
+    return nullptr;
+  }
+
+  auto* tensors_data_clone = tensors_data_manager_->CreateTensorsData(tensors_info_clone);
+  if (!tensors_data_clone) {
+    LoggerE("Could not create TensorsData");
+    DisposeTensorsInfo(tensors_info_clone);
+    return nullptr;
+  }
+
+  unsigned int tensors_count = 0;
+  auto result = tensors_info_clone->NativeGetCount(&tensors_count);
+  if (!result) {
+    LoggerE("Getting count failed");
+    DisposeTensorsInfo(tensors_info_clone);
+    tensors_data_manager_->DisposeTensorsData(tensors_data_clone);
+    return nullptr;
+  }
+
+  for (unsigned int i = 0; i < tensors_count; ++i) {
+    void* data = nullptr;
+    size_t data_size = 0;
+    ret = ml_tensors_data_get_tensor_data(tensors_data_src, i, &data, &data_size);
+    if (ML_ERROR_NONE != ret) {
+      LoggerE("ml_tensors_data_get_tensor_data() failed: [%d] (%s), i: [%u]", ret,
+              get_error_message(ret), i);
+      DisposeTensorsInfo(tensors_info_clone);
+      tensors_data_manager_->DisposeTensorsData(tensors_data_clone);
+      return nullptr;
+    }
+
+    ret = ml_tensors_data_set_tensor_data(tensors_data_clone->Handle(), i, data, data_size);
+    if (ML_ERROR_NONE != ret) {
+      LoggerE("ml_tensors_data_set_tensor_data() failed: [%d] (%s), i: [%u]", ret,
+              get_error_message(ret), i);
+      DisposeTensorsInfo(tensors_info_clone);
+      tensors_data_manager_->DisposeTensorsData(tensors_data_clone);
+      return nullptr;
+    }
+  }
+
+  LoggerD("Cloning tensor with data successful");
+  return tensors_data_clone;
+}
 }  // ml
 }  // extension
index 56f1c1ddbe04571490e2987c37933de2c3f58685..344ee4cddf907e15a1655306c2109aa7a6868cd4 100644 (file)
@@ -81,6 +81,8 @@ class TensorsInfoManager {
   PlatformResult DisposeTensorsInfo(TensorsInfo* t);
 
   TensorsData* CreateTensorsData(TensorsInfo* tensors_info);
+  TensorsData* CloneNativeTensorWithData(ml_tensors_info_h tensors_info_src,
+                                         ml_tensors_data_h tensors_data_src);
 
  private:
   TensorsInfoManager(TensorsInfoManager const&) = delete;