throw new IllegalArgumentException("DataType " + this + " is not supported yet");
}
+ /** Gets string names of the data type. */
+ String toStringName() {
+ switch (this) {
+ case FLOAT32:
+ return "float";
+ case INT32:
+ return "int";
+ case UINT8:
+ return "byte";
+ case INT64:
+ return "long";
+ case BYTEBUFFER:
+ return "ByteBuffer";
+ }
+ throw new IllegalArgumentException("DataType " + this + " is not supported yet");
+ }
+
// Cached to avoid copying it
private static final DataType[] values = values();
}
return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds;
}
+ /**
+ * Gets the dimensions of an input. It throws IllegalArgumentException if input index is invalid.
+ */
+ int[] getInputDims(int index) {
+ return getInputDims(interpreterHandle, index, -1);
+ }
+
+ /**
+ * Gets the dimensions of an input. If numBytes >= 0, it will check whether num of bytes match the
+ * input.
+ */
+ private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
+
+ /** Gets the type of an output. It throws IllegalArgumentException if output index is invalid. */
+ String getOutputDataType(int index) {
+ int type = getOutputDataType(interpreterHandle, index);
+ return DataType.fromNumber(type).toStringName();
+ }
+
+ private static native int getOutputDataType(long interpreterHandle, int outputIdx);
+
private static final int ERROR_BUFFER_SIZE = 512;
private long errorHandle;
private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
- private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
-
static {
TensorFlowLite.init();
}
}
}
+int getDataType(TfLiteType data_type) {
+ switch (data_type) {
+ case kTfLiteFloat32:
+ return 1;
+ case kTfLiteInt32:
+ return 2;
+ case kTfLiteUInt8:
+ return 3;
+ case kTfLiteInt64:
+ return 4;
+ default:
+ return -1;
+ }
+}
+
void printDims(char* buffer, int max_size, int* dims, int num_dims) {
if (max_size <= 0) return;
buffer[0] = '?';
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
const int idx = static_cast<int>(input_idx);
- if (input_idx >= interpreter->inputs().size()) {
+ if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
throwException(env, kIllegalArgumentException,
"Out of range: Failed to get %d-th input out of %d inputs",
input_idx, interpreter->inputs().size());
}
TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]);
int size = target->dims->size;
- int expected_num_bytes = elementByteSize(target->type);
- for (int i = 0; i < size; ++i) {
- expected_num_bytes *= target->dims->data[i];
- }
- if (num_bytes != expected_num_bytes) {
- throwException(env, kIllegalArgumentException,
- "Failed to get input dimensions. %d-th input should have"
- " %d bytes, but found %d bytes.",
- idx, expected_num_bytes, num_bytes);
- return nullptr;
+ if (num_bytes >= 0) { // verifies num of bytes matches if num_bytes if valid.
+ int expected_num_bytes = elementByteSize(target->type);
+ for (int i = 0; i < size; ++i) {
+ expected_num_bytes *= target->dims->data[i];
+ }
+ if (num_bytes != expected_num_bytes) {
+ throwException(env, kIllegalArgumentException,
+ "Failed to get input dimensions. %d-th input should have"
+ " %d bytes, but found %d bytes.",
+ idx, expected_num_bytes, num_bytes);
+ return nullptr;
+ }
}
jintArray outputs = env->NewIntArray(size);
env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0]));
return outputs;
}
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
+ JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return -1;
+ const int idx = static_cast<int>(output_idx);
+ if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
+ throwException(env, kIllegalArgumentException,
+ "Out of range: Failed to get %d-th output out of %d outputs",
+ output_idx, interpreter->outputs().size());
+ return -1;
+ }
+ TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]);
+ int type = getDataType(target->type);
+ return static_cast<jint>(type);
+}
+
JNIEXPORT jboolean JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
* Method:
* Signature: (JII)[I
*
- * It gets input dimensions if num_bytes matches number of bytes required by
- * the input, else returns null and throws IllegalArgumentException.
+ * Gets input dimensions. If num_bytes is non-negative, it will check whether
+ * num_bytes matches num of bytes required by the input, and return null and
+ * throw IllegalArgumentException if not.
*/
JNIEXPORT jintArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
+ * Signature: (JI)I
+ *
+ * Gets output dimensions.
+ */
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
+ JNIEnv* env, jclass clazz, jlong handle, jint output_idx);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
* Signature: (JJI[I)Z
*
* It returns true if resizing input tensor to different dimensions, else return
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull();
wrapper.close();
}
+
+ @Test
+ public void testGetInputDims() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ int[] expectedDims = {1, 8, 8, 3};
+ assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims);
+ wrapper.close();
+ }
+
+ @Test
+ public void testGetInputDimsOutOfRange() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ try {
+ wrapper.getInputDims(-1);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("Out of range");
+ }
+ try {
+ wrapper.getInputDims(1);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("Out of range");
+ }
+ wrapper.close();
+ }
+
+ @Test
+ public void testGetOutputDataType() {
+ NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+ assertThat(wrapper.getOutputDataType(0)).contains("float");
+ wrapper.close();
+ wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH);
+ assertThat(wrapper.getOutputDataType(0)).contains("long");
+ wrapper.close();
+ wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH);
+ assertThat(wrapper.getOutputDataType(0)).contains("int");
+ wrapper.close();
+ wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
+ assertThat(wrapper.getOutputDataType(0)).contains("byte");
+ wrapper.close();
+ }
}
throw new IllegalArgumentException("Interpreter has not initialized; Failed to get latency.");
}
}
+
+ /**
+ * Gets the dimensions of an input.
+ *
+ * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
+ * IllegalArgumentException} will be thrown.
+ * @param index an integer index of the input. If it is invalid, an {@code
+ * IllegalArgumentException} will be thrown.
+ */
+ public static int[] getInputDims(Interpreter interpreter, int index) {
+ if (interpreter != null && interpreter.wrapper != null) {
+ return interpreter.wrapper.getInputDims(index);
+ } else {
+ throw new IllegalArgumentException(
+ "Interpreter has not initialized;" + " Failed to get input dimensions.");
+ }
+ }
+
+ /**
+ * Gets the string name of the data type of an output.
+ *
+ * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
+ * IllegalArgumentException} will be thrown.
+ * @param index an integer index of the output. If it is invalid, an {@code
+ * IllegalArgumentException} will be thrown.
+ * @return string name of the data type. Possible values include "float", "int", "byte", and
+ * "long".
+ */
+ public static String getOutputDataType(Interpreter interpreter, int index) {
+ if (interpreter != null && interpreter.wrapper != null) {
+ return interpreter.wrapper.getOutputDataType(index);
+ } else {
+ throw new IllegalArgumentException(
+ "Interpreter has not initialized;" + " Failed to get output data type.");
+ }
+ }
}