Internal cleanup.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 13 Mar 2018 18:46:05 +0000 (11:46 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Mar 2018 18:55:27 +0000 (11:55 -0700)
PiperOrigin-RevId: 188905507

tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java

index d63c299..fc16488 100644 (file)
@@ -71,6 +71,23 @@ enum DataType {
     throw new IllegalArgumentException("DataType " + this + " is not supported yet");
   }
 
+  /** Gets string names of the data type. */
+  String toStringName() {
+    switch (this) {
+      case FLOAT32:
+        return "float";
+      case INT32:
+        return "int";
+      case UINT8:
+        return "byte";
+      case INT64:
+        return "long";
+      case BYTEBUFFER:
+        return "ByteBuffer";
+    }
+    throw new IllegalArgumentException("DataType " + this + " is not supported yet");
+  }
+
   // Cached to avoid copying it
   private static final DataType[] values = values();
 }
index bca4a3c..014636f 100644 (file)
@@ -261,6 +261,27 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     return (inferenceDurationNanoseconds < 0) ? null : inferenceDurationNanoseconds;
   }
 
+  /**
+   * Gets the dimensions of an input. It throws IllegalArgumentException if input index is invalid.
+   */
+  int[] getInputDims(int index) {
+    return getInputDims(interpreterHandle, index, -1);
+  }
+
+  /**
+   * Gets the dimensions of an input. If numBytes >= 0, it will check whether num of bytes match the
+   * input.
+   */
+  private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
+
+  /** Gets the type of an output. It throws IllegalArgumentException if output index is invalid. */
+  String getOutputDataType(int index) {
+    int type = getOutputDataType(interpreterHandle, index);
+    return DataType.fromNumber(type).toStringName();
+  }
+
+  private static native int getOutputDataType(long interpreterHandle, int outputIdx);
+
   private static final int ERROR_BUFFER_SIZE = 512;
 
   private long errorHandle;
@@ -297,8 +318,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
 
   private static native void delete(long errorHandle, long modelHandle, long interpreterHandle);
 
-  private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
-
   static {
     TensorFlowLite.init();
   }
index 475b467..2870ffe 100644 (file)
@@ -79,6 +79,21 @@ TfLiteType resolveDataType(jint data_type) {
   }
 }
 
+int getDataType(TfLiteType data_type) {
+  switch (data_type) {
+    case kTfLiteFloat32:
+      return 1;
+    case kTfLiteInt32:
+      return 2;
+    case kTfLiteUInt8:
+      return 3;
+    case kTfLiteInt64:
+      return 4;
+    default:
+      return -1;
+  }
+}
+
 void printDims(char* buffer, int max_size, int* dims, int num_dims) {
   if (max_size <= 0) return;
   buffer[0] = '?';
@@ -477,7 +492,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
   tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return nullptr;
   const int idx = static_cast<int>(input_idx);
-  if (input_idx >= interpreter->inputs().size()) {
+  if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
     throwException(env, kIllegalArgumentException,
                    "Out of range: Failed to get %d-th input out of %d inputs",
                    input_idx, interpreter->inputs().size());
@@ -485,22 +500,41 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
   }
   TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]);
   int size = target->dims->size;
-  int expected_num_bytes = elementByteSize(target->type);
-  for (int i = 0; i < size; ++i) {
-    expected_num_bytes *= target->dims->data[i];
-  }
-  if (num_bytes != expected_num_bytes) {
-    throwException(env, kIllegalArgumentException,
-                   "Failed to get input dimensions. %d-th input should have"
-                   " %d bytes, but found %d bytes.",
-                   idx, expected_num_bytes, num_bytes);
-    return nullptr;
+  if (num_bytes >= 0) {  // verifies num of bytes matches if num_bytes if valid.
+    int expected_num_bytes = elementByteSize(target->type);
+    for (int i = 0; i < size; ++i) {
+      expected_num_bytes *= target->dims->data[i];
+    }
+    if (num_bytes != expected_num_bytes) {
+      throwException(env, kIllegalArgumentException,
+                     "Failed to get input dimensions. %d-th input should have"
+                     " %d bytes, but found %d bytes.",
+                     idx, expected_num_bytes, num_bytes);
+      return nullptr;
+    }
   }
   jintArray outputs = env->NewIntArray(size);
   env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0]));
   return outputs;
 }
 
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
+    JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
+  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return -1;
+  const int idx = static_cast<int>(output_idx);
+  if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
+    throwException(env, kIllegalArgumentException,
+                   "Out of range: Failed to get %d-th output out of %d outputs",
+                   output_idx, interpreter->outputs().size());
+    return -1;
+  }
+  TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]);
+  int type = getDataType(target->type);
+  return static_cast<jint>(type);
+}
+
 JNIEXPORT jboolean JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
index f7c2d9b..d611ec7 100644 (file)
@@ -122,8 +122,9 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
  *  Method:
  *  Signature: (JII)[I
  *
- * It gets input dimensions if num_bytes matches number of bytes required by
- * the input, else returns null and throws IllegalArgumentException.
+ * Gets input dimensions. If num_bytes is non-negative, it will check whether
+ * num_bytes matches num of bytes required by the input, and return null and
+ * throw IllegalArgumentException if not.
  */
 JNIEXPORT jintArray JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
@@ -132,6 +133,17 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
  *  Method:
+ *  Signature: (JI)I
+ *
+ * Gets output dimensions.
+ */
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
+    JNIEnv* env, jclass clazz, jlong handle, jint output_idx);
+
+/*
+ *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
+ *  Method:
  *  Signature: (JJI[I)Z
  *
  * It returns true if resizing input tensor to different dimensions, else return
index 6371fb5..d6b4e9f 100644 (file)
@@ -482,4 +482,46 @@ public final class NativeInterpreterWrapperTest {
     assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull();
     wrapper.close();
   }
+
+  @Test
+  public void testGetInputDims() {
+    NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+    int[] expectedDims = {1, 8, 8, 3};
+    assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims);
+    wrapper.close();
+  }
+
+  @Test
+  public void testGetInputDimsOutOfRange() {
+    NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+    try {
+      wrapper.getInputDims(-1);
+      fail();
+    } catch (IllegalArgumentException e) {
+      assertThat(e).hasMessageThat().contains("Out of range");
+    }
+    try {
+      wrapper.getInputDims(1);
+      fail();
+    } catch (IllegalArgumentException e) {
+      assertThat(e).hasMessageThat().contains("Out of range");
+    }
+    wrapper.close();
+  }
+
+  @Test
+  public void testGetOutputDataType() {
+    NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
+    assertThat(wrapper.getOutputDataType(0)).contains("float");
+    wrapper.close();
+    wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH);
+    assertThat(wrapper.getOutputDataType(0)).contains("long");
+    wrapper.close();
+    wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH);
+    assertThat(wrapper.getOutputDataType(0)).contains("int");
+    wrapper.close();
+    wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
+    assertThat(wrapper.getOutputDataType(0)).contains("byte");
+    wrapper.close();
+  }
 }
index a5c1305..3aef0c3 100644 (file)
@@ -47,4 +47,40 @@ public class TestHelper {
       throw new IllegalArgumentException("Interpreter has not initialized; Failed to get latency.");
     }
   }
+
+  /**
+   * Gets the dimensions of an input.
+   *
+   * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
+   *     IllegalArgumentException} will be thrown.
+   * @param index an integer index of the input. If it is invalid, an {@code
+   *     IllegalArgumentException} will be thrown.
+   */
+  public static int[] getInputDims(Interpreter interpreter, int index) {
+    if (interpreter != null && interpreter.wrapper != null) {
+      return interpreter.wrapper.getInputDims(index);
+    } else {
+      throw new IllegalArgumentException(
+          "Interpreter has not initialized;" + " Failed to get input dimensions.");
+    }
+  }
+
+  /**
+   * Gets the string name of the data type of an output.
+   *
+   * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
+   *     IllegalArgumentException} will be thrown.
+   * @param index an integer index of the output. If it is invalid, an {@code
+   *     IllegalArgumentException} will be thrown.
+   * @return string name of the data type. Possible values include "float", "int", "byte", and
+   *     "long".
+   */
+  public static String getOutputDataType(Interpreter interpreter, int index) {
+    if (interpreter != null && interpreter.wrapper != null) {
+      return interpreter.wrapper.getOutputDataType(index);
+    } else {
+      throw new IllegalArgumentException(
+          "Interpreter has not initialized;" + " Failed to get output data type.");
+    }
+  }
 }