Update to provide resources to PyArmNN examples manually
author     Pavel Macenauer <pavel.macenauer@linaro.org>
Tue, 2 Jun 2020 11:54:59 +0000 (11:54 +0000)
committer  Jim Flynn <jim.flynn@arm.com>
Wed, 19 Aug 2020 20:44:03 +0000 (20:44 +0000)
Change-Id: I9ee751512abd5d4ec9faca499b5cea7c19028d22
Signed-off-by: Pavel Macenauer <pavel.macenauer@nxp.com>
python/pyarmnn/examples/example_utils.py
python/pyarmnn/examples/onnx_mobilenetv2.py
python/pyarmnn/examples/tflite_mobilenetv1_quantized.py
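
With this change the examples no longer download their resources unconditionally: parse_command_line() gains -d/--data-dir and -m/--model-dir options, and new helpers in example_utils.py look for the model, labels and images locally first, falling back to a download. A minimal sketch of the resulting flow, using the filenames and URL from the TF Lite diff below (the ./models and ./images directories in the invocation comment are hypothetical):

# e.g.  python tflite_mobilenetv1_quantized.py -m ./models -d ./images
import example_utils as eu

args = eu.parse_command_line()

# filenames and archive URL as used by the TF Lite example below
labels_filename = 'labels_mobilenet_quant_v1_224.txt'
model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'
archive_url = ('https://storage.googleapis.com/download.tensorflow.org/models/tflite/'
               'mobilenet_v1_1.0_224_quant_and_labels.zip')

# uses the files from --model-dir when present, otherwise downloads (and unzips) the archive
model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
                                                          archive_filename, archive_url)

# uses the images from --data-dir when present, otherwise downloads a sample image
image_filenames = eu.get_images(args.data_dir)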

diff --git a/python/pyarmnn/examples/example_utils.py b/python/pyarmnn/examples/example_utils.py
index 5ef30f2..e5425dd 100644
@@ -2,27 +2,77 @@
 # SPDX-License-Identifier: MIT
 
 from urllib.parse import urlparse
-import os
 from PIL import Image
+from zipfile import ZipFile
+import os
 import pyarmnn as ann
 import numpy as np
 import requests
 import argparse
 import warnings
 
+DEFAULT_IMAGE_URL = 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg'
+
+
+def run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info):
+    """Runs inference on a set of images.
+
+    Args:
+        runtime: Arm NN runtime
+        net_id: Network ID
+        images: Loaded images to run inference on
+        labels: Loaded labels per class
+        input_binding_info: Network input information
+        output_binding_info: Network output information
+
+    Returns:
+        None
+    """
+    output_tensors = ann.make_output_tensors([output_binding_info])
+    for idx, im in enumerate(images):
+        # Create input tensors
+        input_tensors = ann.make_input_tensors([input_binding_info], [im])
+
+        # Run inference
+        print("Running inference({0}) ...".format(idx))
+        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
+
+        # Process output
+        out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0]
+        results = np.argsort(out_tensor)[::-1]
+        print_top_n(5, results, labels, out_tensor)
+
+
+def unzip_file(filename: str):
+    """Unzips a file.
+
+    Args:
+        filename (str): Name of the file
+
+    Returns:
+        None
+    """
+    with ZipFile(filename, 'r') as zip_obj:
+        zip_obj.extractall()
+
 
 def parse_command_line(desc: str = ""):
     """Adds arguments to the script.
 
     Args:
-        desc(str): Script description.
+        desc (str): Script description
 
     Returns:
-        Namespace: Arguments to the script command.
+        Namespace: Arguments to the script command
     """
     parser = argparse.ArgumentParser(description=desc)
     parser.add_argument("-v", "--verbose", help="Increase output verbosity",
                         action="store_true")
+    parser.add_argument("-d", "--data-dir", help="Data directory which contains all the images.",
+                        action="store", default="")
+    parser.add_argument("-m", "--model-dir",
+                        help="Model directory which contains the model file (TF, TFLite, ONNX, Caffe).", action="store",
+                        default="")
     return parser.parse_args()
 
 
@@ -30,15 +80,14 @@ def __create_network(model_file: str, backends: list, parser=None):
     """Creates a network based on a file and parser type.
 
     Args:
-        model_file (str): Path of the model file.
+        model_file (str): Path of the model file
         backends (list): List of backends to use when running inference.
         parser_type: Parser instance. (pyarmnn.ITFliteParser/pyarmnn.IOnnxParser...)
 
     Returns:
-        int: Network ID.
-        int: Graph ID.
-        IParser: TF Lite parser instance.
-        IRuntime: Runtime object instance.
+        int: Network ID
+        IParser: TF Lite parser instance
+        IRuntime: Runtime object instance
     """
     args = parse_command_line()
     options = ann.CreationOptions()
@@ -189,7 +238,7 @@ def print_top_n(N: int, results: list, labels: list, prob: list):
         print("class={0} ; value={1}".format(labels[results[i]], prob[results[i]]))
 
 
-def download_file(url: str, force: bool = False, filename: str = None, dest: str = "tmp"):
+def download_file(url: str, force: bool = False, filename: str = None):
     """Downloads a file.
 
     Args:
@@ -197,25 +246,113 @@ def download_file(url: str, force: bool = False, filename: str = None, dest: str
         force (bool): Forces to download the file even if it exists.
         filename (str): Renames the file when set.
 
+    Raises:
+        RuntimeError: If the download fails.
+
     Returns:
         str: Path to the downloaded file.
     """
-    if filename is None:  # extract filename from url when None
-        filename = urlparse(url)
-        filename = os.path.basename(filename.path)
-
-    if str is not None:
-        if not os.path.exists(dest):
-            os.makedirs(dest)
-        filename = os.path.join(dest, filename)
-
-    print("Downloading '{0}' from '{1}' ...".format(filename, url))
-    if not os.path.exists(filename) or force is True:
-        r = requests.get(url)
-        with open(filename, 'wb') as f:
-            f.write(r.content)
-        print("Finished.")
-    else:
-        print("File already exists.")
+    try:
+        if filename is None:  # extract filename from url when None
+            filename = urlparse(url)
+            filename = os.path.basename(filename.path)
+
+        print("Downloading '{0}' from '{1}' ...".format(filename, url))
+        if not os.path.exists(filename) or force is True:
+            r = requests.get(url)
+            with open(filename, 'wb') as f:
+                f.write(r.content)
+            print("Finished.")
+        else:
+            print("File already exists.")
+    except Exception as e:
+        raise RuntimeError("Unable to download file.") from e
 
     return filename
+
+
+def get_model_and_labels(model_dir: str, model: str, labels: str, archive: str = None, download_url: str = None):
+    """Gets model and labels.
+
+    Args:
+        model_dir (str): Folder in which the model and labels files can be found
+        model (str): Name of the model file
+        labels (str): Name of the labels file
+        archive (str): Name of the archive file (optional - when omitted, the model and labels files are used directly)
+        download_url (str or list): URL or list of URLs to download the files from (optional - used only when the files are not found locally)
+
+    Returns:
+        tuple (str, str): Model and labels filenames
+    """
+    labels = os.path.join(model_dir, labels)
+    model = os.path.join(model_dir, model)
+
+    if os.path.exists(labels) and os.path.exists(model):
+        print("Found model ({0}) and labels ({1}).".format(model, labels))
+    elif archive is not None and os.path.exists(os.path.join(model_dir, archive)):
+        print("Found archive ({0}). Unzipping ...".format(archive))
+        unzip_file(archive)
+    elif download_url is not None:
+        print("Model, labels or archive not found. Downloading ...".format(archive))
+        try:
+            if isinstance(download_url, str):
+                download_url = [download_url]
+            for dl in download_url:
+                archive = download_file(dl)
+                if dl.lower().endswith(".zip"):
+                    unzip_file(archive)
+        except RuntimeError:
+            print("Unable to download file ({}).".format(archive_url))
+
+    if not os.path.exists(labels) or not os.path.exists(model):
+        raise RuntimeError("Unable to provide model and labels.")
+
+    return model, labels
+
+
+def list_images(folder: str = None, formats: list = ['.jpg', '.jpeg']):
+    """Lists files of a certain format in a folder.
+
+    Args:
+        folder (str): Path to the folder to search
+        formats (list): List of supported file extensions
+
+    Returns:
+        list: A list of found files
+    """
+    files = []
+    if folder and not os.path.exists(folder):
+        print("Folder '{}' does not exist.".format(folder))
+        return files
+
+    for file in os.listdir(folder if folder else os.getcwd()):
+        for frmt in formats:
+            if file.lower().endswith(frmt):
+                files.append(os.path.join(folder, file) if folder else file)
+                break  # only break the inner formats loop
+
+    return files
+
+
+def get_images(image_dir: str, image_url: str = DEFAULT_IMAGE_URL):
+    """Gets image.
+
+    Args:
+        image (str): Image filename
+        image_url (str): Image url
+
+    Returns:
+        str: Output image filename
+    """
+    images = list_images(image_dir)
+    if not images and image_url is not None:
+        print("No images found. Downloading ...")
+        try:
+            images = [download_file(image_url)]
+        except RuntimeError:
+            print("Unable to download file ({0}).".format(image_url))
+
+    if not images:
+        raise RuntimeError("Unable to provide images.")
+
+    return images
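
The smaller helpers added above can also be used on their own; a brief sketch (the ./data directory is hypothetical, the synset.txt URL is the one used by the ONNX example below):

import example_utils as eu

# re-download the labels file even if a copy already exists in the working directory
labels_file = eu.download_file('https://s3.amazonaws.com/onnx-model-zoo/synset.txt', force=True)

# also pick up .png files from a hypothetical ./data directory
images = eu.list_images('./data', formats=['.jpg', '.jpeg', '.png'])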
diff --git a/python/pyarmnn/examples/onnx_mobilenetv2.py b/python/pyarmnn/examples/onnx_mobilenetv2.py
index 5ba0849..05bfd7b 100755
@@ -4,6 +4,7 @@
 
 import pyarmnn as ann
 import numpy as np
+import os
 from PIL import Image
 import example_utils as eu
 
@@ -43,45 +44,48 @@ def preprocess_onnx(img: Image, width: int, height: int, data_type, scale: float
     return img
 
 
-if __name__ == "__main__":
-    # Download resources
-    kitten_filename = eu.download_file('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
-    labels_filename = eu.download_file('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
-    model_filename = eu.download_file(
-        'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.onnx')
-
-    # Create a network from a model file
-    net_id, parser, runtime = eu.create_onnx_network(model_filename)
-
-    # Load input information from the model and create input tensors
-    input_binding_info = parser.GetNetworkInputBindingInfo("data")
-
-    # Load output information from the model and create output tensors
-    output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
-    output_tensors = ann.make_output_tensors([output_binding_info])
-
-    # Load labels
-    labels = eu.load_labels(labels_filename)
-
-    # Load images and resize to expected size
-    image_names = [kitten_filename]
-    images = eu.load_images(image_names,
-                            224, 224,
-                            np.float32,
-                            255.0,
-                            [0.485, 0.456, 0.406],
-                            [0.229, 0.224, 0.225],
-                            preprocess_onnx)
-
-    for idx, im in enumerate(images):
-        # Create input tensors
-        input_tensors = ann.make_input_tensors([input_binding_info], [im])
-
-        # Run inference
-        print("Running inference on '{0}' ...".format(image_names[idx]))
-        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
-
-        # Process output
-        out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
-        results = np.argsort(out_tensor)[::-1]
-        eu.print_top_n(5, results, labels, out_tensor)
+args = eu.parse_command_line()
+
+model_filename = 'mobilenetv2-1.0.onnx'
+labels_filename = 'synset.txt'
+archive_filename = 'mobilenetv2-1.0.zip'
+labels_url = 'https://s3.amazonaws.com/onnx-model-zoo/' + labels_filename
+model_url = 'https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/' + model_filename
+
+# Download resources
+image_filenames = eu.get_images(args.data_dir)
+
+model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+                                                          archive_filename,
+                                                          [model_url, labels_url])
+
+# all 3 resources must exist to proceed further
+assert os.path.exists(labels_filename)
+assert os.path.exists(model_filename)
+assert image_filenames
+for im in image_filenames:
+    assert os.path.exists(im)
+
+# Create a network from a model file
+net_id, parser, runtime = eu.create_onnx_network(model_filename)
+
+# Load input information from the model and create input tensors
+input_binding_info = parser.GetNetworkInputBindingInfo("data")
+
+# Load output information from the model and create output tensors
+output_binding_info = parser.GetNetworkOutputBindingInfo("mobilenetv20_output_flatten0_reshape0")
+output_tensors = ann.make_output_tensors([output_binding_info])
+
+# Load labels
+labels = eu.load_labels(labels_filename)
+
+# Load images and resize to expected size
+images = eu.load_images(image_filenames,
+                        224, 224,
+                        np.float32,
+                        255.0,
+                        [0.485, 0.456, 0.406],
+                        [0.229, 0.224, 0.225],
+                        preprocess_onnx)
+
+eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)
diff --git a/python/pyarmnn/examples/tflite_mobilenetv1_quantized.py b/python/pyarmnn/examples/tflite_mobilenetv1_quantized.py
index aa18a52..cb2c91c 100755
@@ -2,71 +2,54 @@
 # Copyright 2020 NXP
 # SPDX-License-Identifier: MIT
 
-from zipfile import ZipFile
 import numpy as np
 import pyarmnn as ann
 import example_utils as eu
 import os
 
+args = eu.parse_command_line()
 
-def unzip_file(filename):
-    """Unzips a file to its current location.
+# names of the files in the archive
+labels_filename = 'labels_mobilenet_quant_v1_224.txt'
+model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
+archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'
 
-    Args:
-        filename (str): Name of the archive.
+archive_url = \
+    'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip'
 
-    Returns:
-        str: Directory path of the extracted files.
-    """
-    with ZipFile(filename, 'r') as zip_obj:
-        zip_obj.extractall(os.path.dirname(filename))
-    return os.path.dirname(filename)
+model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
+                                                          archive_filename, archive_url)
 
+image_filenames = eu.get_images(args.data_dir)
 
-if __name__ == "__main__":
-    # Download resources
-    archive_filename = eu.download_file(
-        'https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_quant_and_labels.zip')
-    dir_path = unzip_file(archive_filename)
-    # names of the files in the archive
-    labels_filename = os.path.join(dir_path, 'labels_mobilenet_quant_v1_224.txt')
-    model_filename = os.path.join(dir_path, 'mobilenet_v1_1.0_224_quant.tflite')
-    kitten_filename = eu.download_file('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
+# all 3 resources must exist to proceed further
+assert os.path.exists(labels_filename)
+assert os.path.exists(model_filename)
+assert image_filenames
+for im in image_filenames:
+    assert os.path.exists(im)
 
-    # Create a network from the model file
-    net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
+# Create a network from the model file
+net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)
 
-    # Load input information from the model
-    # tflite has all the need information in the model unlike other formats
-    input_names = parser.GetSubgraphInputTensorNames(graph_id)
-    assert len(input_names) == 1  # there should be 1 input tensor in mobilenet
+# Load input information from the model
+# tflite has all the needed information in the model, unlike other formats
+input_names = parser.GetSubgraphInputTensorNames(graph_id)
+assert len(input_names) == 1  # there should be 1 input tensor in mobilenet
 
-    input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
-    input_width = input_binding_info[1].GetShape()[1]
-    input_height = input_binding_info[1].GetShape()[2]
+input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+input_width = input_binding_info[1].GetShape()[1]
+input_height = input_binding_info[1].GetShape()[2]
 
-    # Load output information from the model and create output tensors
-    output_names = parser.GetSubgraphOutputTensorNames(graph_id)
-    assert len(output_names) == 1  # and only one output tensor
-    output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
-    output_tensors = ann.make_output_tensors([output_binding_info])
+# Load output information from the model and create output tensors
+output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+assert len(output_names) == 1  # and only one output tensor
+output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])
 
-    # Load labels file
-    labels = eu.load_labels(labels_filename)
+# Load labels file
+labels = eu.load_labels(labels_filename)
 
-    # Load images and resize to expected size
-    image_names = [kitten_filename]
-    images = eu.load_images(image_names, input_width, input_height)
+# Load images and resize to expected size
+images = eu.load_images(image_filenames, input_width, input_height)
 
-    for idx, im in enumerate(images):
-        # Create input tensors
-        input_tensors = ann.make_input_tensors([input_binding_info], [im])
-
-        # Run inference
-        print("Running inference on '{0}' ...".format(image_names[idx]))
-        runtime.EnqueueWorkload(net_id, input_tensors, output_tensors)
-
-        # Process output
-        out_tensor = ann.workload_tensors_to_ndarray(output_tensors)[0][0]
-        results = np.argsort(out_tensor)[::-1]
-        eu.print_top_n(5, results, labels, out_tensor)
+eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)