From: Piotr Kosko Date: Mon, 1 Feb 2021 07:33:51 +0000 (+0000) Subject: Merge "[systeminfo] Prevent possible crash when failure initialization" into tizen X-Git-Tag: submit/tizen/20210202.064821^0 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=fabc6da2918c6e757a1c703ce40984025a73e3ee;hp=dde9626c01553c1516a4ea63d0ccaaf1ee7a9b43;p=platform%2Fcore%2Fapi%2Fwebapi-plugins.git Merge "[systeminfo] Prevent possible crash when failure initialization" into tizen --- diff --git a/src/ml/js/ml_common.js b/src/ml/js/ml_common.js index 1c11756..699cddf 100755 --- a/src/ml/js/ml_common.js +++ b/src/ml/js/ml_common.js @@ -26,25 +26,22 @@ var MAX_TENSORS_INFO_COUNT = 16; // TensorRawData -var TensorRawData = function() { +var TensorRawData = function(data, size, shape) { Object.defineProperties(this, { data: { enumerable: true, - get: function() { - throw new WebAPIException(WebAPIException.ABORT_ERR, 'Not implemented'); - } + writable: false, + value: data }, size: { enumerable: true, - get: function() { - throw new WebAPIException(WebAPIException.ABORT_ERR, 'Not implemented'); - } + writable: false, + value: size }, shape: { enumerable: true, - get: function() { - throw new WebAPIException(WebAPIException.ABORT_ERR, 'Not implemented'); - } + writable: false, + value: shape } }); }; @@ -63,6 +60,47 @@ var TensorType = { UNKNOWN: 'UNKNOWN' }; +function _GetBufferTypeFromTensorType(tensorType) { + switch (tensorType) { + case 'INT8': + return Int8Array; + case 'UINT8': + return Uint8Array; + case 'INT16': + return Int16Array; + case 'UINT16': + return Uint16Array; + case 'FLOAT32': + return Float32Array; + case 'INT32': + return Int32Array; + case 'UINT32': + return Uint32Array; + case 'FLOAT64': + return Float64Array; + case 'INT64': + return BigInt64Array; + case 'UINT64': + return BigUint64Array; + } + return Uint8Array; +} + +function _CheckIfArrayHasOnlyNumbersAndThrow(array, arrayName) { + if (xwalk.utils.type.isNullOrUndefined(array)) { + return; + } + + array.forEach(function(d) { + if (Number.isInteger(d) == false) { + throw new WebAPIException( + WebAPIException.TYPE_MISMATCH_ERR, + arrayName + ' array has to contain only integers' + ); + } + }); +} + // TensorsData var _ValidTensorsDataIds = new Set(); @@ -99,13 +137,150 @@ var TensorsData = function(id, tensorsInfoId) { }; TensorsData.prototype.getTensorRawData = function() { - _CheckIfTensorsDataNotDisposed(); - throw new WebAPIException(WebAPIException.ABORT_ERR, 'Not implemented'); + _CheckIfTensorsDataNotDisposed(this._id); + var args = validator_.validateArgs(arguments, [ + { + name: 'index', + type: types_.LONG + }, + { + name: 'location', + type: types_.ARRAY, + optional: true + }, + { + name: 'size', + type: types_.ARRAY, + optional: true + } + ]); + + if (!args.has.index) { + throw new WebAPIException( + WebAPIException.INVALID_VALUES_ERR, + 'Invalid parameter: index is undefined' + ); + } + + _CheckIfArrayHasOnlyNumbersAndThrow(args.location, 'location'); + _CheckIfArrayHasOnlyNumbersAndThrow(args.size, 'size'); + + var callArgs = { + tensorsDataId: this._id, + index: args.index, + location: args.location ? args.location : [], + size: args.size ? args.size : [] + }; + + var result = native_.callSync('MLTensorsDataGetTensorRawData', callArgs); + + if (native_.isFailure(result)) { + throw native_.getErrorObjectAndValidate( + result, + TensorsInfoGettersSettersValidExceptions, + AbortError + ); + } + + // TODO: modify StringToArray to accept also float types, not only int + var data = privUtils_.StringToArray(result.buffer, Uint8Array); + var ArrayType = _GetBufferTypeFromTensorType(result.type); + var shape = result.shape; + return new TensorRawData(new ArrayType(data.buffer), data.byteLength, shape); }; -TensorsData.prototype.setTensorData = function() { - _CheckIfTensorsDataNotDisposed(); - throw new WebAPIException(WebAPIException.ABORT_ERR, 'Not implemented'); +var TensorsDataSetTensorRawDataExceptions = [ + 'InvalidValuesError', + 'TypeMismatchError', + 'NotSupportedError', + 'AbortError' +]; + +function ValidateBufferForTensorsData(tensorsData, index, buffer) { + var result = native_.callSync('MLTensorsDataGetTensorType', { + tensorsDataId: tensorsData._id, + index: index + }); + + if (native_.isFailure(result)) { + throw AbortError; + } + var tensorType = native_.getResultObject(result); + var ret = buffer; + + var ArrayType = _GetBufferTypeFromTensorType(tensorType); + if (Array.isArray(buffer)) { + // in case of standard Array - create TypedArray from it + ret = new ArrayType(buffer); + } else if (false == buffer instanceof ArrayType) { + throw new WebAPIException( + WebAPIException.TYPE_MISMATCH_ERR, + 'buffer array has incompatible type, expected: ' + + ArrayType.name + + ', got: ' + + x.constructor.name + ); + } + return ret; +} + +TensorsData.prototype.setTensorRawData = function() { + _CheckIfTensorsDataNotDisposed(this._id); + var argsIndex = validator_.validateArgs(arguments, [ + { + name: 'index', + type: types_.LONG + } + ]); + var argsLocSize = validator_.validateArgs(Array.prototype.slice.call(arguments, 2), [ + { + name: 'location', + type: types_.ARRAY, + optional: true + }, + { + name: 'size', + type: types_.ARRAY, + optional: true + } + ]); + + if (!argsIndex.has.index) { + throw new WebAPIException( + WebAPIException.INVALID_VALUES_ERR, + 'Invalid parameter: index is undefined' + ); + } + + if (arguments.length < 2) { + throw new WebAPIException( + WebAPIException.INVALID_VALUES_ERR, + 'Invalid parameter: buffer is undefined' + ); + } + var buffer = ValidateBufferForTensorsData(this, argsIndex.index, arguments[1]); + + _CheckIfArrayHasOnlyNumbersAndThrow(argsLocSize.location, 'location'); + _CheckIfArrayHasOnlyNumbersAndThrow(argsLocSize.size, 'size'); + + // TODO: modify ArrayToString to accept also float types, not only int + var encodedData = privUtils_.ArrayToString(new Uint8Array(buffer.buffer)); + var callArgs = { + index: argsIndex.index, + tensorsDataId: this._id, + buffer: encodedData, + location: argsLocSize.location ? argsLocSize.location : [], + size: argsLocSize.size ? argsLocSize.size : [] + }; + var result = native_.callSync('MLTensorsDataSetTensorRawData', callArgs); + + if (native_.isFailure(result)) { + throw native_.getErrorObjectAndValidate( + result, + TensorsDataSetTensorRawDataExceptions, + AbortError + ); + } }; TensorsData.prototype.dispose = function() { @@ -209,14 +384,7 @@ TensorsInfo.prototype.addTensorInfo = function() { } ]); - args.dimensions.forEach(function(d) { - if (Number.isInteger(d) == false) { - throw new WebAPIException( - WebAPIException.TYPE_MISMATCH_ERR, - 'dimensions array has to contain only integers' - ); - } - }); + _CheckIfArrayHasOnlyNumbersAndThrow(args.dimensions, 'dimensions'); var callArgs = { name: args.name, @@ -337,14 +505,7 @@ TensorsInfo.prototype.setDimensions = function() { } ]); - args.dimensions.forEach(function(d) { - if (Number.isInteger(d) == false) { - throw new WebAPIException( - WebAPIException.TYPE_MISMATCH_ERR, - 'dimensions array has to contain only integers' - ); - } - }); + _CheckIfArrayHasOnlyNumbersAndThrow(args.dimensions, 'dimensions'); var callArgs = { index: args.index, diff --git a/src/ml/js/ml_pipeline.js b/src/ml/js/ml_pipeline.js index 425e34c..cf92f3f 100755 --- a/src/ml/js/ml_pipeline.js +++ b/src/ml/js/ml_pipeline.js @@ -208,7 +208,75 @@ Pipeline.prototype.getNodeInfo = function() { //Pipeline::getNodeInfo() end //Pipeline::getSource() begin +var ValidInputTensorsInfoExceptions = ['NotFoundError', 'AbortError']; +function Source(name, pipeline_id) { + Object.defineProperties(this, { + name: { + enumerable: true, + value: name + }, + inputTensorsInfo: { + get: function() { + var result = native_.callSync('MLPipelineGetInputTensorsInfo', { + id: this._pipeline_id, + name: this.name + }); + if (native_.isFailure(result)) { + throw native_.getErrorObjectAndValidate( + result, + ValidInputTensorsInfoExceptions, + AbortError + ); + } + + return new TensorsInfo(result.id); + } + }, + _pipeline_id: { + value: pipeline_id + } + }); +} + +var ValidPipelineGetSourceExceptions = [ + 'InvalidStateError', + 'InvalidValuesError', + 'NotFoundError', + 'NotSupportedError', + 'AbortError' +]; + +Pipeline.prototype.getSource = function() { + var args = validator_.validateArgs(arguments, [ + { + name: 'name', + type: validator_.Types.STRING + } + ]); + + if (!args.has.name) { + throw new WebAPIException( + WebAPIException.INVALID_VALUES_ERR, + 'Invalid parameter: name is mandatory' + ); + } + + var nativeArgs = { + id: this._id, + name: args.name + }; + var result = native_.callSync('MLPipelineGetSource', nativeArgs); + if (native_.isFailure(result)) { + throw native_.getErrorObjectAndValidate( + result, + ValidPipelineGetSourceExceptions, + AbortError + ); + } + + return new Source(args.name, this._id); +}; //Pipeline::getSource() end //Pipeline::getSwitch() begin @@ -261,6 +329,11 @@ Pipeline.prototype.getSwitch = function() { //Pipeline::getSwitch() end //Pipeline::getValve() begin +var ValidValveIsOpenAndSetOpenExceptions = [ + 'NotFoundError', + 'NotSupportedError', + 'AbortError' +]; function Valve(name, pipeline_id) { Object.defineProperties(this, { name: { @@ -269,6 +342,26 @@ function Valve(name, pipeline_id) { }, _pipeline_id: { value: pipeline_id + }, + isOpen: { + get: function() { + var result = native_.callSync('MLPipelineValveIsOpen', { + id: this._pipeline_id, + name: this.name + }); + + if (native_.isFailure(result)) { + throw native_.getErrorObjectAndValidate( + result, + ValidValveIsOpenAndSetOpenExceptions, + AbortError + ); + } + + return result.result; + }, + set: function() {}, + enumerable: true } }); } @@ -491,7 +584,36 @@ Switch.prototype.select = function() { //Switch::select() end //Valve::setOpen() begin +Valve.prototype.setOpen = function() { + var args = validator_.validateArgs(arguments, [ + { + name: 'open', + type: validator_.Types.BOOLEAN + } + ]); + + if (!args.has.open) { + throw new WebAPIException( + WebAPIException.INVALID_VALUES_ERR, + 'Invalid parameter: open is mandatory' + ); + } + + var nativeArgs = { + id: this._pipeline_id, + name: this.name, + open: args.open + }; + var result = native_.callSync('MLPipelineValveSetOpen', nativeArgs); + if (native_.isFailure(result)) { + throw native_.getErrorObjectAndValidate( + result, + ValidValveIsOpenAndSetOpenExceptions, + AbortError + ); + } +}; //Valve::setOpen() end var MachineLearningPipeline = function() {}; diff --git a/src/ml/ml.gyp b/src/ml/ml.gyp index 40aebd3..0b08553 100644 --- a/src/ml/ml.gyp +++ b/src/ml/ml.gyp @@ -23,8 +23,8 @@ 'ml_pipeline_nodeinfo.h', 'ml_pipeline_switch.cc', 'ml_pipeline_switch.h', - #TODO pipeline Source - #TODO pipeline Valve + 'ml_pipeline_source.h', + 'ml_pipeline_source.cc', 'ml_pipeline_valve.h', 'ml_pipeline_valve.cc', 'ml_tensors_data_manager.cc', diff --git a/src/ml/ml_instance.cc b/src/ml/ml_instance.cc index e4c7833..078b7b5 100644 --- a/src/ml/ml_instance.cc +++ b/src/ml/ml_instance.cc @@ -40,10 +40,15 @@ const std::string kDefinition = "definition"; const std::string kPipelineStateChangeListenerName = "listenerName"; const std::string kOtherId = "otherId"; const std::string kPadName = "padName"; +const std::string kOpen = "open"; const std::string kNodeName = "nodeName"; const std::string kProperty = "property"; const std::string kBOOLEAN = "BOOLEAN"; const std::string kSTRING = "STRING"; +const std::string kBuffer = "buffer"; +const std::string kSize = "size"; +const std::string kLocation = "location"; +const std::string kShape = "shape"; } // namespace using namespace common; @@ -76,7 +81,7 @@ using namespace common; MlInstance::MlInstance() : tensors_info_manager_{&tensors_data_manager_}, single_manager_{&tensors_info_manager_}, - pipeline_manager_{this} { + pipeline_manager_{this, &tensors_info_manager_} { ScopeLogger(); using namespace std::placeholders; @@ -98,8 +103,13 @@ MlInstance::MlInstance() REGISTER_METHOD(MLTensorsInfoClone); REGISTER_METHOD(MLTensorsInfoEquals); REGISTER_METHOD(MLTensorsInfoDispose); + REGISTER_METHOD(MLPipelineValveSetOpen); + REGISTER_METHOD(MLPipelineValveIsOpen); REGISTER_METHOD(MLTensorsDataDispose); + REGISTER_METHOD(MLTensorsDataGetTensorRawData); + REGISTER_METHOD(MLTensorsDataGetTensorType); + REGISTER_METHOD(MLTensorsDataSetTensorRawData); // Single API begin REGISTER_METHOD(MLSingleManagerOpenModel); @@ -128,6 +138,8 @@ MlInstance::MlInstance() REGISTER_METHOD(MLPipelineGetValve); REGISTER_METHOD(MLPipelineNodeInfoGetProperty); REGISTER_METHOD(MLPipelineNodeInfoSetProperty); + REGISTER_METHOD(MLPipelineGetSource); + REGISTER_METHOD(MLPipelineGetInputTensorsInfo); // Pipeline API end #undef REGISTER_METHOD @@ -602,6 +614,110 @@ void MlInstance::MLTensorsDataDispose(const picojson::value& args, picojson::obj } ReportSuccess(out); } +void MlInstance::MLTensorsDataGetTensorRawData(const picojson::value& args, picojson::object& out) { + ScopeLogger("args: %s", args.serialize().c_str()); + CHECK_ARGS(args, kTensorsDataId, double, out); + CHECK_ARGS(args, kIndex, double, out); + CHECK_ARGS(args, kLocation, picojson::array, out); + CHECK_ARGS(args, kSize, picojson::array, out); + + int tensor_data_id = static_cast(args.get(kTensorsDataId).get()); + int index = static_cast(args.get(kIndex).get()); + + TensorsData* tensors_data = GetTensorsDataManager().GetTensorsData(tensor_data_id); + if (nullptr == tensors_data) { + LogAndReportError(PlatformResult(ErrorCode::ABORT_ERR, "Internal TensorsData error"), &out, + ("Could not find TensorsData handle with given id: %d", tensor_data_id)); + return; + } + // TODO: validate location and size - will be done in future commit + int location[ML_TENSOR_RANK_LIMIT]; + int size[ML_TENSOR_RANK_LIMIT]; + TensorRawData raw_data; + PlatformResult result = tensors_data->GetTensorRawData(index, location, size, &raw_data); + if (!result) { + LogAndReportError(result, &out); + return; + } + + std::vector out_data{raw_data.data, raw_data.data + raw_data.size}; + out[kBuffer] = picojson::value(picojson::string_type, true); + common::encode_binary_in_string(out_data, out[kBuffer].get()); + + out[kType] = picojson::value(raw_data.type_str); + picojson::array shape = picojson::array{}; + for (int i = 0; i < ML_TENSOR_RANK_LIMIT; i++) { + shape.push_back(picojson::value{static_cast(raw_data.shape[i])}); + } + out[kShape] = picojson::value{shape}; + + ReportSuccess(out); +} + +void MlInstance::MLTensorsDataGetTensorType(const picojson::value& args, picojson::object& out) { + ScopeLogger("args: %s", args.serialize().c_str()); + CHECK_ARGS(args, kTensorsDataId, double, out); + CHECK_ARGS(args, kIndex, double, out); + + int tensors_data_id = static_cast(args.get(kTensorsDataId).get()); + int index = static_cast(args.get(kIndex).get()); + + TensorsData* tensors_data = GetTensorsDataManager().GetTensorsData(tensors_data_id); + if (nullptr == tensors_data) { + LogAndReportError(PlatformResult(ErrorCode::ABORT_ERR, "Internal TensorsData error"), &out, + ("Could not find TensorsData handle with given id: %d", tensors_data_id)); + return; + } + + std::string tensor_type_string; + PlatformResult result = + types::TensorTypeEnum.getName(tensors_data->GetTensorType(index), &tensor_type_string); + if (!result) { + LogAndReportError(PlatformResult(ErrorCode::ABORT_ERR, "Error getting name of TensorType"), + &out, + ("TensorTypeEnum.getName() failed, error: %s", result.message().c_str())); + return; + } + + picojson::value val = picojson::value{tensor_type_string}; + ReportSuccess(val, out); +} + +void MlInstance::MLTensorsDataSetTensorRawData(const picojson::value& args, picojson::object& out) { + ScopeLogger("args: %s", args.serialize().c_str()); + CHECK_ARGS(args, kTensorsDataId, double, out); + CHECK_ARGS(args, kIndex, double, out); + CHECK_ARGS(args, kBuffer, std::string, out); + CHECK_ARGS(args, kLocation, picojson::array, out); + CHECK_ARGS(args, kSize, picojson::array, out); + + int tensors_data_id = static_cast(args.get(kTensorsDataId).get()); + int index = static_cast(args.get(kIndex).get()); + + TensorsData* tensors_data = GetTensorsDataManager().GetTensorsData(tensors_data_id); + if (nullptr == tensors_data) { + LogAndReportError(PlatformResult(ErrorCode::ABORT_ERR, "Internal TensorsData error"), &out, + ("Could not find TensorsData handle with given id: %d", tensors_data_id)); + return; + } + + int location[ML_TENSOR_RANK_LIMIT] = {0, 0, 0, 0}; + int size[ML_TENSOR_RANK_LIMIT] = {-1, -1, -1, -1}; + // TODO: validate location and size - will be done in future commit + + const std::string& str_buffer = args.get(kBuffer).get(); + std::vector buffer; + common::decode_binary_from_string(str_buffer, buffer); + + TensorRawData rawData{.data = buffer.data(), .size = buffer.size()}; + PlatformResult result = tensors_data->SetTensorRawData(index, location, size, rawData); + if (!result) { + LogAndReportError(result, &out); + return; + } + + ReportSuccess(out); +} // Common ML API end // Single API begin @@ -901,7 +1017,24 @@ void MlInstance::MLPipelineGetNodeInfo(const picojson::value& args, picojson::ob // Pipeline::getNodeInfo() end // Pipeline::getSource() begin +void MlInstance::MLPipelineGetSource(const picojson::value& args, picojson::object& out) { + ScopeLogger("args: %s", args.serialize().c_str()); + + CHECK_ARGS(args, kId, double, out); + CHECK_ARGS(args, kName, std::string, out); + + auto name = args.get(kName).get(); + auto id = static_cast(args.get(kId).get()); + + PlatformResult result = pipeline_manager_.GetSource(id, name); + + if (!result) { + LogAndReportError(result, &out); + return; + } + ReportSuccess(out); +} // Pipeline::getSource() end // Pipeline::getSwitch() begin @@ -1030,7 +1163,24 @@ void MlInstance::MLPipelineNodeInfoSetProperty(const picojson::value& args, pico // NodeInfo::setProperty() end // Source::inputTensorsInfo begin +void MlInstance::MLPipelineGetInputTensorsInfo(const picojson::value& args, picojson::object& out) { + ScopeLogger("args: [%s]", args.serialize().c_str()); + + CHECK_ARGS(args, kId, double, out); + CHECK_ARGS(args, kName, std::string, out); + + auto id = static_cast(args.get(kId).get()); + const auto& name = args.get(kName).get(); + + int res_id = -1; + PlatformResult result = pipeline_manager_.getInputTensorsInfo(id, name, &res_id); + if (!result) { + LogAndReportError(result, &out); + return; + } + ReportSuccess(out); +} // Source::inputTensorsInfo end // Source::inputData() begin @@ -1081,9 +1231,48 @@ void MlInstance::MLPipelineSwitchSelect(const picojson::value& args, picojson::o // Switch::select() end // Valve::setOpen() begin +void MlInstance::MLPipelineValveSetOpen(const picojson::value& args, picojson::object& out) { + ScopeLogger("args: %s", args.serialize().c_str()); + + CHECK_ARGS(args, kId, double, out); + CHECK_ARGS(args, kName, std::string, out); + CHECK_ARGS(args, kOpen, bool, out); + + auto name = args.get(kName).get(); + auto pipeline_id = args.get(kId).get(); + auto open = args.get(kOpen).get(); + + auto ret = pipeline_manager_.ValveSetOpen(pipeline_id, name, open); + if (!ret) { + LogAndReportError(ret, &out); + return; + } + ReportSuccess(out); +} // Valve::setOpen() end +// Valve::isOpen() begin +void MlInstance::MLPipelineValveIsOpen(const picojson::value& args, picojson::object& out) { + ScopeLogger("args: %s", args.serialize().c_str()); + + CHECK_ARGS(args, kId, double, out); + CHECK_ARGS(args, kName, std::string, out); + + auto name = args.get(kName).get(); + auto pipeline_id = args.get(kId).get(); + auto open = true; + + auto ret = pipeline_manager_.ValveIsOpen(pipeline_id, name, &open); + if (!ret) { + LogAndReportError(ret, &out); + return; + } + + ReportSuccess(picojson::value{open}, out); +} +// Valve::isOpen() end + // Pipeline API end #undef CHECK_EXIST diff --git a/src/ml/ml_instance.h b/src/ml/ml_instance.h index 95125fb..6bbc16e 100644 --- a/src/ml/ml_instance.h +++ b/src/ml/ml_instance.h @@ -56,6 +56,9 @@ class MlInstance : public common::ParsedInstance { void MLTensorsInfoDispose(const picojson::value& args, picojson::object& out); void MLTensorsDataDispose(const picojson::value& args, picojson::object& out); + void MLTensorsDataGetTensorRawData(const picojson::value& args, picojson::object& out); + void MLTensorsDataGetTensorType(const picojson::value& args, picojson::object& out); + void MLTensorsDataSetTensorRawData(const picojson::value& args, picojson::object& out); TensorsInfoManager tensors_info_manager_; TensorsDataManager tensors_data_manager_; @@ -103,7 +106,7 @@ class MlInstance : public common::ParsedInstance { // Pipeline::getNodeInfo() end // Pipeline::getSource() begin - + void MLPipelineGetSource(const picojson::value& args, picojson::object& out); // Pipeline::getSource() end // Pipeline::getSwitch() begin @@ -139,7 +142,7 @@ class MlInstance : public common::ParsedInstance { // NodeInfo::setProperty() end // Source::inputTensorsInfo begin - + void MLPipelineGetInputTensorsInfo(const picojson::value& args, picojson::object& out); // Source::inputTensorsInfo end // Source::inputData() begin @@ -155,8 +158,12 @@ class MlInstance : public common::ParsedInstance { // Switch::select() end // Valve::setOpen() begin - + void MLPipelineValveSetOpen(const picojson::value& args, picojson::object& out); // Valve::setOpen() end + + // Valve::isOpen() begin + void MLPipelineValveIsOpen(const picojson::value& args, picojson::object& out); + // Valve::isOpen() end // Pipeline API end }; diff --git a/src/ml/ml_pipeline.cc b/src/ml/ml_pipeline.cc index fe35948..1d1cce5 100644 --- a/src/ml/ml_pipeline.cc +++ b/src/ml/ml_pipeline.cc @@ -196,6 +196,8 @@ PlatformResult Pipeline::Dispose() { valves_.clear(); + sources_.clear(); + auto ret = ml_pipeline_destroy(pipeline_); if (ML_ERROR_NONE != ret) { LoggerE("ml_pipeline_destroy() failed: [%d] (%s)", ret, get_error_message(ret)); @@ -210,7 +212,7 @@ PlatformResult Pipeline::Dispose() { // Pipeline::dispose() end // Pipeline::getNodeInfo() begin -PlatformResult Pipeline::GetNodeInfo(std::string& name) { +PlatformResult Pipeline::GetNodeInfo(const std::string& name) { ScopeLogger("id_: [%d], name: [%s]", id_, name.c_str()); auto nodeinfo_it = node_info_.find(name); @@ -233,7 +235,22 @@ PlatformResult Pipeline::GetNodeInfo(std::string& name) { // Pipeline::getNodeInfo() end // Pipeline::getSource() begin +PlatformResult Pipeline::GetSource(const std::string& name) { + ScopeLogger("id: [%d], name: [%s]", id_, name.c_str()); + auto source_it = sources_.find(name); + if (sources_.end() != source_it) { + LoggerD("Source [%s] found", name.c_str()); + return PlatformResult{}; + } + + std::unique_ptr source_ptr; + auto ret = Source::CreateSource(name, pipeline_, &source_ptr); + if (ret) { + sources_.insert({name, std::move(source_ptr)}); + } + return ret; +} // Pipeline::getSource() end // Pipeline::getSwitch() begin @@ -270,7 +287,7 @@ PlatformResult Pipeline::GetValve(const std::string& name) { LoggerD("Creating [%s] Valve", name.c_str()); std::unique_ptr valve_ptr; - auto ret = Valve::CreateValve(name, pipeline_, &valve_ptr); + auto ret = Valve::CreateValve(name, pipeline_, *this, &valve_ptr); if (ret) { valves_.insert({name, std::move(valve_ptr)}); } @@ -305,9 +322,7 @@ PlatformResult Pipeline::getProperty(const std::string& node_name, const std::st return PlatformResult{ErrorCode::NOT_FOUND_ERR, "NodeInfo not found"}; } - auto ret = nodeinfo_it->second->getProperty(name, type, property); - - return ret; + return nodeinfo_it->second->getProperty(name, type, property); } // NodeInfo::getProperty() end @@ -322,13 +337,22 @@ PlatformResult Pipeline::setProperty(const std::string& node_name, const std::st return PlatformResult{ErrorCode::NOT_FOUND_ERR, "NodeInfo not found"}; } - auto ret = nodeinfo_it->second->setProperty(name, type, property); - return ret; + return nodeinfo_it->second->setProperty(name, type, property); } // NodeInfo::setProperty() end // Source::inputTensorsInfo begin +PlatformResult Pipeline::getInputTensorsInfo(const std::string& name, ml_tensors_info_h* result) { + ScopeLogger(); + + auto source_it = sources_.find(name); + if (sources_.end() == source_it) { + LoggerD("Source [%s] not found", name.c_str()); + return PlatformResult{ErrorCode::NOT_FOUND_ERR, "Source not found"}; + } + return source_it->second->getInputTensorsInfo(result); +} // Source::inputTensorsInfo end // Source::inputData() begin @@ -352,7 +376,26 @@ PlatformResult Pipeline::GetSwitch(const std::string& name, Switch** out) { // Switch::getPadList() end // Valve::setOpen() begin +PlatformResult Pipeline::GetNodeInfo(const std::string& name, NodeInfo** out) { + ScopeLogger("id_: [%d], name: [%s]", id_, name.c_str()); + + auto ret = GetNodeInfo(name); + if (ret) { + *out = node_info_[name].get(); + } + + return ret; +} + +PlatformResult Pipeline::GetValve(const std::string& name, Valve** out) { + ScopeLogger("id: [%d], name: [%s]", id_, name.c_str()); + auto ret = GetValve(name); + if (ret) { + *out = valves_[name].get(); + } + return ret; +} // Valve::setOpen() end } // namespace extension diff --git a/src/ml/ml_pipeline.h b/src/ml/ml_pipeline.h index afe3fee..d6606c8 100644 --- a/src/ml/ml_pipeline.h +++ b/src/ml/ml_pipeline.h @@ -27,6 +27,7 @@ #include "common/picojson.h" #include "common/platform_result.h" #include "ml_pipeline_nodeinfo.h" +#include "ml_pipeline_source.h" #include "ml_pipeline_switch.h" #include "ml_pipeline_valve.h" @@ -72,11 +73,11 @@ class Pipeline { // Pipeline::dispose() end // Pipeline::getNodeInfo() begin - PlatformResult GetNodeInfo(std::string& name); + PlatformResult GetNodeInfo(const std::string& name); // Pipeline::getNodeInfo() end // Pipeline::getSource() begin - + PlatformResult GetSource(const std::string& name); // Pipeline::getSource() end // Pipeline::getSwitch() begin @@ -114,7 +115,7 @@ class Pipeline { // NodeInfo::setProperty() end // Source::inputTensorsInfo begin - + PlatformResult getInputTensorsInfo(const std::string& name, ml_tensors_info_h* result); // Source::inputTensorsInfo end // Source::inputData() begin @@ -127,7 +128,8 @@ class Pipeline { // Switch::getPadList() end // Valve::setOpen() begin - + PlatformResult GetNodeInfo(const std::string& name, NodeInfo** out); + PlatformResult GetValve(const std::string& name, Valve** out); // Valve::setOpen() end private: Pipeline(int id, const std::string& state_change_listener_name, common::Instance* instance_ptr); @@ -152,6 +154,7 @@ class Pipeline { std::unordered_map> switches_; std::map> node_info_; std::unordered_map> valves_; + std::map> sources_; static void PipelineStateChangeListener(ml_pipeline_state_e state, void* user_data); }; diff --git a/src/ml/ml_pipeline_manager.cc b/src/ml/ml_pipeline_manager.cc index 3e1cc41..3251229 100644 --- a/src/ml/ml_pipeline_manager.cc +++ b/src/ml/ml_pipeline_manager.cc @@ -26,7 +26,9 @@ using common::tools::ReportSuccess; namespace extension { namespace ml { -PipelineManager::PipelineManager(common::Instance* instance_ptr) : instance_ptr_{instance_ptr} { +PipelineManager::PipelineManager(common::Instance* instance_ptr, + TensorsInfoManager* tensors_info_manager) + : instance_ptr_{instance_ptr}, tensors_info_manager_{tensors_info_manager} { ScopeLogger(); } @@ -144,7 +146,17 @@ PlatformResult PipelineManager::GetNodeInfo(int id, std::string& name) { // Pipeline::getNodeInfo() end // Pipeline::getSource() begin +PlatformResult PipelineManager::GetSource(int pipeline_id, const std::string& name) { + ScopeLogger("name: [%s], pipeline_id: [%d]", name.c_str(), pipeline_id); + auto pipeline_it = pipelines_.find(pipeline_id); + if (pipelines_.end() == pipeline_it) { + LoggerD("Pipeline not found: [%d]", pipeline_id); + return PlatformResult{ErrorCode::NOT_FOUND_ERR, "Pipeline not found"}; + } + + return pipeline_it->second->GetSource(name); +} // Pipeline::getSource() end // Pipeline::getSwitch() begin @@ -227,7 +239,24 @@ PlatformResult PipelineManager::setProperty(int id, const std::string& node_name // NodeInfo::setProperty() end // Source::inputTensorsInfo begin +PlatformResult PipelineManager::getInputTensorsInfo(int id, const std::string& name, int* res_id) { + ScopeLogger(); + + auto pipeline_it = pipelines_.find(id); + if (pipelines_.end() == pipeline_it) { + LoggerD("Pipeline not found: [%d]", id); + return PlatformResult{ErrorCode::NOT_FOUND_ERR, "Pipeline not found"}; + } + ml_tensors_info_h in_info = nullptr; + PlatformResult ret = pipeline_it->second->getInputTensorsInfo(name, &in_info); + if (!ret) { + return ret; + } + auto tensor_info = tensors_info_manager_->CreateTensorsInfo(in_info); + *res_id = tensor_info->Id(); + return PlatformResult{}; +} // Source::inputTensorsInfo end // Source::inputData() begin @@ -278,8 +307,47 @@ PlatformResult PipelineManager::SwitchSelect(int pipeline_id, const std::string& // Switch::select() end // Valve::setOpen() begin +PlatformResult PipelineManager::ValveSetOpen(int pipeline_id, const std::string& valve_name, + bool open) { + ScopeLogger("pipeline_id: [%d], valve_name: [%s], open: [%s]", pipeline_id, valve_name.c_str(), + open ? "true" : "false"); + + auto pipeline_it = pipelines_.find(pipeline_id); + if (pipelines_.end() == pipeline_it) { + LoggerD("Pipeline not found: [%d]", pipeline_id); + return PlatformResult{ErrorCode::NOT_FOUND_ERR, "Pipeline not found"}; + } + Valve* valve_ptr = nullptr; + auto ret = pipeline_it->second->GetValve(valve_name, &valve_ptr); + if (!ret) { + return ret; + } + + return valve_ptr->SetOpen(open); +} // Valve::setOpen() end +// Valve::isOpen() begin +PlatformResult PipelineManager::ValveIsOpen(int pipeline_id, const std::string& valve_name, + bool* open) { + ScopeLogger("pipeline_id: [%d], valve_name: [%s]", pipeline_id, valve_name.c_str()); + + auto pipeline_it = pipelines_.find(pipeline_id); + if (pipelines_.end() == pipeline_it) { + LoggerD("Pipeline not found: [%d]", pipeline_id); + return PlatformResult{ErrorCode::NOT_FOUND_ERR, "Pipeline not found"}; + } + + Valve* valve_ptr = nullptr; + auto ret = pipeline_it->second->GetValve(valve_name, &valve_ptr); + if (!ret) { + return ret; + } + + return valve_ptr->IsOpen(open); +} +// Valve::isOpen() end + } // namespace ml } // namespace extension diff --git a/src/ml/ml_pipeline_manager.h b/src/ml/ml_pipeline_manager.h index f986624..18e8c4f 100644 --- a/src/ml/ml_pipeline_manager.h +++ b/src/ml/ml_pipeline_manager.h @@ -22,6 +22,7 @@ #include "common/platform_result.h" #include "ml_pipeline.h" +#include "ml_tensors_info_manager.h" using common::PlatformResult; @@ -30,7 +31,7 @@ namespace ml { class PipelineManager { public: - PipelineManager(common::Instance* instance_ptr); + PipelineManager(common::Instance* instance_ptr, TensorsInfoManager* tim); ~PipelineManager(); @@ -64,7 +65,7 @@ class PipelineManager { // Pipeline::getNodeInfo() end // Pipeline::getSource() begin - + PlatformResult GetSource(int pipeline_id, const std::string& name); // Pipeline::getSource() end // Pipeline::getSwitch() begin @@ -102,7 +103,7 @@ class PipelineManager { // NodeInfo::setProperty() end // Source::inputTensorsInfo begin - + PlatformResult getInputTensorsInfo(int id, const std::string& name, int* res_id); // Source::inputTensorsInfo end // Source::inputData() begin @@ -120,10 +121,15 @@ class PipelineManager { // Switch::select() end // Valve::setOpen() begin - + PlatformResult ValveSetOpen(int pipeline_id, const std::string& valve_name, bool open); // Valve::setOpen() end + + // Valve::isOpen() begin + PlatformResult ValveIsOpen(int pipeline_id, const std::string& valve_name, bool* open); + // Valve::isOpen() end private: common::Instance* instance_ptr_; + TensorsInfoManager* tensors_info_manager_; std::map> pipelines_; }; diff --git a/src/ml/ml_pipeline_source.cc b/src/ml/ml_pipeline_source.cc new file mode 100644 index 0000000..7ac0104 --- /dev/null +++ b/src/ml/ml_pipeline_source.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ml_pipeline_source.h" +#include "ml_utils.h" + +using common::PlatformResult; +using common::ErrorCode; + +namespace extension { +namespace ml { +namespace pipeline { + +PlatformResult Source::CreateSource(const std::string& name, ml_pipeline_h pipeline, + std::unique_ptr* out) { + ScopeLogger("name: [%s], pipeline: [%p]", name.c_str(), pipeline); + ml_pipeline_src_h source_handle = nullptr; + auto ret = ml_pipeline_src_get_handle(pipeline, name.c_str(), &source_handle); + if (ML_ERROR_NONE != ret) { + LoggerE("ml_pipeline_src_get_handle() failed: [%d] (%s)", ret, get_error_message(ret)); + return util::ToPlatformResult(ret, "Could not get source"); + } + + out->reset(new (std::nothrow) Source{name, source_handle}); + if (!out) { + ret = ml_pipeline_src_release_handle(source_handle); + if (ML_ERROR_NONE != ret) { + LoggerE("ml_pipeline_src_release_handle() failed: [%d] (%s)", ret, get_error_message(ret)); + } else { + LoggerD("ml_pipeline_src_release_handle() succeeded"); + } + return LogAndCreateResult(ErrorCode::ABORT_ERR, "Could not get the source", + ("Could not allocate memory")); + } + + return PlatformResult{}; +} + +Source::Source(const std::string& name, ml_pipeline_src_h source_handle) + : name_{name}, source_{source_handle} { + ScopeLogger("name: [%s], handle: [%p]", name.c_str(), source_handle); +} + +Source::~Source() { + ScopeLogger("name: [%s], handle: [%p]", name_.c_str(), source_); + + auto ret = ml_pipeline_src_release_handle(source_); + if (ML_ERROR_NONE != ret) { + LoggerE("ml_pipeline_src_release_handle() failed: [%d] (%s)", ret, get_error_message(ret)); + } else { + LoggerD("ml_pipeline_src_release_handle() succeeded"); + } +} + +PlatformResult Source::getInputTensorsInfo(ml_tensors_info_h* result) { + ScopeLogger(); + + ml_tensors_info_h info = nullptr; + auto ret = ml_pipeline_src_get_tensors_info(source_, &info); + + if (ML_ERROR_NONE != ret) { + LoggerE(" ml_pipeline_src_get_tensors_info failed: %d (%s)", ret, get_error_message(ret)); + return util::ToPlatformResult(ret, "Failed to get tensor info"); + } + + *result = info; + + return PlatformResult{}; +} + +} // namespace pipeline +} // namespace ml +} // namespace extension \ No newline at end of file diff --git a/src/ml/ml_pipeline_source.h b/src/ml/ml_pipeline_source.h new file mode 100644 index 0000000..cdd0154 --- /dev/null +++ b/src/ml/ml_pipeline_source.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ML_ML_PIPELINE_SOURCE_H_ +#define ML_ML_PIPELINE_SOURCE_H_ + +#include +#include + +#include + +#include "common/platform_result.h" + +using common::PlatformResult; + +namespace extension { +namespace ml { +namespace pipeline { + +class Source { + public: + static PlatformResult CreateSource(const std::string& name, ml_pipeline_h pipeline, + std::unique_ptr* out); + + ~Source(); + + PlatformResult getInputTensorsInfo(ml_tensors_info_h* result); + + Source(const Source&) = delete; + Source& operator=(const Source&) = delete; + + private: + Source(const std::string& name, ml_pipeline_src_h source_handle); + const std::string name_; + const ml_pipeline_src_h source_; +}; + +} // namespace pipeline +} // namespace ml +} // namespace extension + +#endif // ML_ML_PIPELINE_SOURCE_H_ \ No newline at end of file diff --git a/src/ml/ml_pipeline_valve.cc b/src/ml/ml_pipeline_valve.cc index cbd6fda..06d5a5e 100644 --- a/src/ml/ml_pipeline_valve.cc +++ b/src/ml/ml_pipeline_valve.cc @@ -14,29 +14,33 @@ * limitations under the License. */ +#include "common/picojson.h" + +#include "ml_pipeline.h" #include "ml_pipeline_valve.h" #include "ml_utils.h" using common::PlatformResult; using common::ErrorCode; +using extension::ml::Pipeline; namespace extension { namespace ml { namespace pipeline { -PlatformResult Valve::CreateValve(const std::string& name, ml_pipeline_h pipeline, - std::unique_ptr* out) { - ScopeLogger("name: [%s], pipeline: [%p]", name.c_str(), pipeline); +PlatformResult Valve::CreateValve(const std::string& name, ml_pipeline_h native_pipeline_handle, + Pipeline& pipeline, std::unique_ptr* out) { + ScopeLogger("name: [%s], native_pipeline_handle: [%p]", name.c_str(), native_pipeline_handle); ml_pipeline_valve_h valve_handle = nullptr; - auto ret = ml_pipeline_valve_get_handle(pipeline, name.c_str(), &valve_handle); + auto ret = ml_pipeline_valve_get_handle(native_pipeline_handle, name.c_str(), &valve_handle); if (ML_ERROR_NONE != ret) { LoggerE("ml_pipeline_valve_get_handle() failed: [%d] (%s)", ret, get_error_message(ret)); return util::ToPlatformResult(ret, "Could not get valve"); } LoggerD("ml_pipeline_valve_get_handle() succeeded"); - out->reset(new (std::nothrow) Valve{name, valve_handle}); + out->reset(new (std::nothrow) Valve{name, valve_handle, pipeline}); if (!out) { ret = ml_pipeline_valve_release_handle(valve_handle); if (ML_ERROR_NONE != ret) { @@ -51,8 +55,8 @@ PlatformResult Valve::CreateValve(const std::string& name, ml_pipeline_h pipelin return PlatformResult{}; } -Valve::Valve(const std::string& name, ml_pipeline_valve_h valve_handle) - : name_{name}, valve_{valve_handle} { +Valve::Valve(const std::string& name, ml_pipeline_valve_h valve_handle, Pipeline& pipeline) + : name_{name}, valve_{valve_handle}, pipeline_{pipeline} { ScopeLogger("name: [%s], handle: [%p]", name.c_str(), valve_handle); } @@ -67,6 +71,42 @@ Valve::~Valve() { } } +PlatformResult Valve::SetOpen(bool open) { + ScopeLogger("name: [%s], open: [%s]", name_.c_str(), open ? "true" : "false"); + + auto ret = ml_pipeline_valve_set_open(valve_, open); + if (ML_ERROR_NONE != ret) { + LoggerE("ml_pipeline_valve_set_open() failed: [%d] (%s)", ret, get_error_message(ret)); + return util::ToPlatformResult(ret, "Could not set valve open state"); + } + LoggerD("ml_pipeline_valve_set_open() succeeded"); + + return PlatformResult{}; +} + +PlatformResult Valve::IsOpen(bool* out) { + ScopeLogger("name: [%s]", name_.c_str()); + + NodeInfo* node_info_ptr = nullptr; + auto ret = pipeline_.GetNodeInfo(name_, &node_info_ptr); + if (!ret) { + return ret; + } + + // Valve's drop property doc: + // https://gstreamer.freedesktop.org/documentation/coreelements/valve.html?gi-language=c#valve:drop + const std::string kPropertyToNegate = "drop"; + const std::string kProperty = "property"; + const std::string kType = "BOOLEAN"; + + picojson::object propertyToNegate; + ret = node_info_ptr->getProperty(kPropertyToNegate, kType, &propertyToNegate); + if (ret) { + *out = !propertyToNegate[kProperty].get(); + } + return ret; +} + } // namespace pipeline } // namespace ml } // namespace extension \ No newline at end of file diff --git a/src/ml/ml_pipeline_valve.h b/src/ml/ml_pipeline_valve.h index a03e0e7..75452b0 100644 --- a/src/ml/ml_pipeline_valve.h +++ b/src/ml/ml_pipeline_valve.h @@ -28,22 +28,28 @@ using common::PlatformResult; namespace extension { namespace ml { + +class Pipeline; namespace pipeline { class Valve { public: - static PlatformResult CreateValve(const std::string& name, ml_pipeline_h pipeline, - std::unique_ptr* out); + static PlatformResult CreateValve(const std::string& name, ml_pipeline_h native_pipeline_handle, + Pipeline& pipeline, std::unique_ptr* out); ~Valve(); + PlatformResult SetOpen(bool open); + PlatformResult IsOpen(bool* out); + Valve(const Valve&) = delete; Valve& operator=(const Valve&) = delete; private: - Valve(const std::string& name, ml_pipeline_valve_h valve_handle); + Valve(const std::string& name, ml_pipeline_valve_h valve_handle, Pipeline& pipeline); const std::string name_; const ml_pipeline_valve_h valve_; + Pipeline& pipeline_; }; } // namespace pipeline diff --git a/src/ml/ml_tensors_data_manager.cc b/src/ml/ml_tensors_data_manager.cc index 225e289..410cd68 100644 --- a/src/ml/ml_tensors_data_manager.cc +++ b/src/ml/ml_tensors_data_manager.cc @@ -52,6 +52,71 @@ int TensorsData::Count() { return tensors_info_->Count(); } +ml_tensor_type_e TensorsData::GetTensorType(int index) { + ScopeLogger("id_: %d, index: %d", id_, index); + ml_tensor_type_e tensor_type_enum = ML_TENSOR_TYPE_UNKNOWN; + PlatformResult result = tensors_info_->NativeGetTensorType(index, &tensor_type_enum); + if (!result) { + LoggerE("Failed to get tensor type"); + } + return tensor_type_enum; +} + +PlatformResult TensorsData::GetTensorRawData(int index, int location[ML_TENSOR_RANK_LIMIT], + int size[ML_TENSOR_RANK_LIMIT], + TensorRawData* tensor_raw_data) { + ScopeLogger("id_: %d, index: %d", id_, index); + if (nullptr == tensor_raw_data) { + LoggerE("Invalid tensor_raw_data"); + return PlatformResult(ErrorCode::ABORT_ERR); + } + void* data; + size_t data_size; + int ret = ml_tensors_data_get_tensor_data(handle_, index, &data, &data_size); + if (ML_ERROR_NONE != ret) { + LoggerE("ml_tensors_data_get_tensor_data failed: %d (%s)", ret, get_error_message(ret)); + return util::ToPlatformResult(ret, "Internal TensorsData error"); + } + // TODO: add support for location and size - will be done in future commit + + // Dimensions of whole tensor + unsigned int dim[ML_TENSOR_RANK_LIMIT]; + PlatformResult result = tensors_info_->NativeGetTensorDimensions(index, dim); + if (!result) { + return result; + } + + for (int i = 0; i < ML_TENSOR_RANK_LIMIT; i++) { + tensor_raw_data->shape[i] = dim[i]; + } + + result = types::TensorTypeEnum.getName(this->GetTensorType(index), &tensor_raw_data->type_str); + if (!result) { + return result; + } + + tensor_raw_data->data = static_cast(data); + tensor_raw_data->size = data_size; + + return PlatformResult(ErrorCode::NO_ERROR); +} + +PlatformResult TensorsData::SetTensorRawData(int index, int location[ML_TENSOR_RANK_LIMIT], + int size[ML_TENSOR_RANK_LIMIT], + TensorRawData& tensor_raw_data) { + ScopeLogger("id_: %d, index: %d", id_, index); + + // TODO: add support for location and size - will be done in future commit + int ret = + ml_tensors_data_set_tensor_data(handle_, index, tensor_raw_data.data, tensor_raw_data.size); + if (ML_ERROR_NONE != ret) { + LoggerE("ml_tensors_data_set_tensor_data failed: %d (%s)", ret, get_error_message(ret)); + return util::ToPlatformResult(ret, "Internal TensorsData error"); + } + + return PlatformResult(ErrorCode::NO_ERROR); +} + PlatformResult TensorsData::NativeDestroy() { ScopeLogger("id_: %d", id_); int ret = ml_tensors_data_destroy(handle_); diff --git a/src/ml/ml_tensors_data_manager.h b/src/ml/ml_tensors_data_manager.h index 39811df..41db5e7 100644 --- a/src/ml/ml_tensors_data_manager.h +++ b/src/ml/ml_tensors_data_manager.h @@ -31,6 +31,14 @@ namespace ml { class TensorsInfo; +struct TensorRawData { + // TensorRawData does not take ownership of data, remember to handle it outside + uint8_t* data; + size_t size; + std::string type_str; + unsigned int shape[ML_TENSOR_RANK_LIMIT]; +}; + class TensorsData { public: TensorsData(ml_tensors_data_h handle, int id, TensorsInfo* tensors_info); @@ -40,6 +48,11 @@ class TensorsData { int Id(); int TensorsInfoId(); int Count(); + ml_tensor_type_e GetTensorType(int index); + PlatformResult GetTensorRawData(int index, int location[ML_TENSOR_RANK_LIMIT], + int size[ML_TENSOR_RANK_LIMIT], TensorRawData* tensor_raw_data); + PlatformResult SetTensorRawData(int index, int location[ML_TENSOR_RANK_LIMIT], + int size[ML_TENSOR_RANK_LIMIT], TensorRawData& tensor_raw_data); PlatformResult NativeDestroy();