2e828d5e5025e387517f1df84ede11b655e106e7
[platform/core/ml/beyond.git] /
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package com.samsung.android.beyond.inference;
18
19 import java.util.*;
20 import android.util.Log;
21
22 import com.samsung.android.beyond.inference.tensor.TensorSet;
23 import com.samsung.android.beyond.NativeInstance;
24
25 import static com.samsung.android.beyond.inference.Option.TAG;
26
27 import androidx.annotation.NonNull;
28
29 public class InferenceHandler extends NativeInstance {
30     private List<Peer> peerList = new ArrayList<Peer>();
31
32     private TensorOutputCallback tensorOutputCallback;
33
34     InferenceHandler(@NonNull InferenceMode inferenceMode) {
35         long nativeInstance = create(inferenceMode.toString());
36         if (nativeInstance == 0L) {
37             throw new RuntimeException("The native instance of InferenceHandler is not created successfully.");
38         }
39
40         registerNativeInstance(nativeInstance, (Long instance) -> destroy(instance));
41     }
42
43     // TODO:
44     // removePeer() should be provided
45     public boolean addInferencePeer(@NonNull Peer inferencePeer) {
46         if (instance == 0L) {
47             Log.e(TAG, "Instance is invalid.");
48             return false;
49         }
50
51         if (inferencePeer.getNativeInstance() == 0L) {
52             Log.e(TAG, "The given peer instance is invalid.");
53             return false;
54         }
55
56         if (addPeer(instance, inferencePeer.getNativeInstance()) == true) {
57             // NOTE:
58             // Hold the peer instance until release them from the this.
59             peerList.add(inferencePeer);
60             return true;
61         }
62
63         return false;
64     }
65
66     public boolean loadModel(@NonNull String modelPath) {
67         if (instance == 0L) {
68             Log.e(TAG, "Instance is invalid.");
69             return false;
70         }
71
72         Log.i(TAG, "Model path in a client : " + modelPath);
73         return loadModel(instance, modelPath);
74     }
75
76     public boolean prepare() {
77         if (instance == 0L) {
78             Log.e(TAG, "Instance is invalid.");
79             return false;
80         }
81
82         return prepare(instance);
83     }
84
85     public boolean run(@NonNull TensorSet inputTensors) {
86         if (instance == 0L) {
87             Log.e(TAG, "Instance is invalid.");
88             return false;
89         }
90
91         if (inputTensors.getTensorsInstance() == 0L) {
92             Log.e(TAG, "The given tensor set is invalid.");
93             return false;
94         }
95
96         return run(instance, inputTensors.getTensorsInstance(), inputTensors.getTensors().length);
97     }
98
99     public boolean setOutputCallback(@NonNull TensorOutputCallback tensorOutputCallback) {
100         if (instance == 0L) {
101             Log.e(TAG, "Instance is invalid.");
102             return false;
103         }
104
105         this.tensorOutputCallback = tensorOutputCallback;
106
107         if (setCallback(instance, tensorOutputCallback) < 0) {
108             Log.e(TAG, "Fail to set the given output callback.");
109             return false;
110         }
111
112         return true;
113     }
114
115     private native long create(String inferenceMode);
116
117     private native boolean addPeer(long inferenceHandle, long peerHandle);
118
119     private native boolean loadModel(long handle, String modelPath);
120
121     private native boolean prepare(long inferenceHandle);
122
123     private native boolean run(long inferenceHandle, long tensorsInstance, int numberOfTensors);
124
125     private static native void destroy(long inferenceHandle);
126
127     private native int setCallback(long instance, TensorOutputCallback tensorOutputCallback);
128 }