Merge pull request #17604 from LupusSanctus:am/pytorch_tf_cls_tutorial
author Anastasia M <anastasia.murzova@xperience.ai>
Tue, 26 Jan 2021 11:06:15 +0000 (14:06 +0300)
committer GitHub <noreply@github.com>
Tue, 26 Jan 2021 11:06:15 +0000 (11:06 +0000)
[GSoC] Added TF and PyTorch classification conversion cases

* Added TF and PyTorch classification conversion cases

* Modified structure, some processing scripts. Added evaluation pipeline

* Minor structure change

* Removed extra functions, minor structure change

* Modified structure, code corrections

* Updated classification code block, added classification tutorials

* Added minor modifications of paths

* Classification block corrections in accordance with comments

29 files changed:
doc/tutorials/dnn/dnn_pytorch_tf_classification/images/opencv_resnet50_test_res_c.jpg [new file with mode: 0644]
doc/tutorials/dnn/dnn_pytorch_tf_classification/images/pytorch_resnet50_opencv_test_res.jpg [new file with mode: 0644]
doc/tutorials/dnn/dnn_pytorch_tf_classification/images/squirrel_cls.jpg [new file with mode: 0644]
doc/tutorials/dnn/dnn_pytorch_tf_classification/images/tf_mobilenet_opencv_test_res.jpg [new file with mode: 0644]
doc/tutorials/dnn/dnn_pytorch_tf_classification/pytorch_cls_model_conversion_c_tutorial.md [new file with mode: 0644]
doc/tutorials/dnn/dnn_pytorch_tf_classification/pytorch_cls_model_conversion_tutorial.md [new file with mode: 0644]
doc/tutorials/dnn/dnn_pytorch_tf_classification/tf_cls_model_conversion_tutorial.md [new file with mode: 0644]
doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown
doc/tutorials/dnn/table_of_content_dnn.markdown
samples/data/squirrel_cls.jpg [new file with mode: 0644]
samples/dnn/classification.cpp
samples/dnn/classification.py
samples/dnn/dnn_model_runner/dnn_conversion/common/abstract_model.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/common/evaluation/classification/cls_accuracy_evaluator.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/common/evaluation/classification/cls_data_fetcher.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/common/img_utils.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/common/test/cls_model_test_pipeline.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/default_preprocess_config.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/test_config.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/common/test/model_test_pipeline.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/common/utils.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/pytorch/classification/py_to_py_cls.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/pytorch/classification/py_to_py_resnet50.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/pytorch/classification/py_to_py_resnet50_onnx.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/pytorch/pytorch_model.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/requirements.txt [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/tf/classification/py_to_py_cls.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/tf/classification/py_to_py_mobilenet.py [new file with mode: 0644]
samples/dnn/dnn_model_runner/dnn_conversion/tf/tf_model.py [new file with mode: 0644]

diff --git a/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/opencv_resnet50_test_res_c.jpg b/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/opencv_resnet50_test_res_c.jpg
new file mode 100644 (file)
index 0000000..4d1ba30
Binary files /dev/null and b/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/opencv_resnet50_test_res_c.jpg differ
diff --git a/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/pytorch_resnet50_opencv_test_res.jpg b/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/pytorch_resnet50_opencv_test_res.jpg
new file mode 100644 (file)
index 0000000..7bee270
Binary files /dev/null and b/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/pytorch_resnet50_opencv_test_res.jpg differ
diff --git a/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/squirrel_cls.jpg b/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/squirrel_cls.jpg
new file mode 100644 (file)
index 0000000..289b13b
Binary files /dev/null and b/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/squirrel_cls.jpg differ
diff --git a/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/tf_mobilenet_opencv_test_res.jpg b/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/tf_mobilenet_opencv_test_res.jpg
new file mode 100644 (file)
index 0000000..cc18156
Binary files /dev/null and b/doc/tutorials/dnn/dnn_pytorch_tf_classification/images/tf_mobilenet_opencv_test_res.jpg differ
diff --git a/doc/tutorials/dnn/dnn_pytorch_tf_classification/pytorch_cls_model_conversion_c_tutorial.md b/doc/tutorials/dnn/dnn_pytorch_tf_classification/pytorch_cls_model_conversion_c_tutorial.md
new file mode 100644 (file)
index 0000000..1807caf
--- /dev/null
@@ -0,0 +1,220 @@
+# Conversion of PyTorch Classification Models and Launch with OpenCV C++ {#pytorch_cls_c_tutorial_dnn_conversion}
+
+@prev_tutorial{pytorch_cls_tutorial_dnn_conversion}
+
+|    |    |
+| -: | :- |
+| Original author | Anastasia Murzova |
+| Compatibility | OpenCV >= 4.5 |
+
+## Goals
+In this tutorial you will learn how to:
+* convert PyTorch classification models into ONNX format
+* run converted PyTorch model with OpenCV C/C++ API
+* provide model inference
+
+We will explore the above-listed points by the example of the ResNet-50 architecture.
+
+## Introduction
+Let's briefly review the key concepts involved in the pipeline of PyTorch model transition with OpenCV API. The initial step in the conversion of a PyTorch model into cv::dnn::Net
+is transferring the model into [ONNX](https://onnx.ai/about.html) format. ONNX aims at interchangeability of neural networks between various frameworks. PyTorch has a built-in function for ONNX conversion: [``torch.onnx.export``](https://pytorch.org/docs/stable/onnx.html#torch.onnx.export).
+The obtained ``.onnx`` model is then passed into cv::dnn::readNetFromONNX or cv::dnn::readNet.
+
+## Requirements
+To be able to experiment with the below code you will need to install a set of libraries. We will use a virtual environment with python3.7+ for this:
+
+```console
+virtualenv -p /usr/bin/python3.7 <env_dir_path>
+source <env_dir_path>/bin/activate
+```
+
+For OpenCV-Python building from source, follow the corresponding instructions from the @ref tutorial_py_table_of_contents_setup.
+
+Before you start installing the libraries, you can customize the [requirements.txt](https://github.com/opencv/opencv/tree/master/samples/dnn/dnn_model_runner/dnn_conversion/requirements.txt), excluding or including some dependencies (for example, ``opencv-python``).
+The below line initiates requirements installation into the previously activated virtual environment:
+
+```console
+pip install -r requirements.txt
+```
+
+## Practice
+In this part we are going to cover the following points:
+1. create a classification model conversion pipeline
+2. provide the inference, process prediction results
+
+### Model Conversion Pipeline
+The code in this subchapter is located in the ``samples/dnn/dnn_model_runner`` module and can be executed with the line:
+
+```console
+python -m dnn_model_runner.dnn_conversion.pytorch.classification.py_to_py_resnet50_onnx
+```
+
+The following code contains the description of the below-listed steps:
+1. instantiate PyTorch model
+2. convert PyTorch model into ``.onnx``
+
+```python
+# initialize PyTorch ResNet-50 model
+original_model = models.resnet50(pretrained=True)
+
+# get the path to the PyTorch model converted into ONNX format
+full_model_path = get_pytorch_onnx_model(original_model)
+print("PyTorch ResNet-50 model was successfully converted: ", full_model_path)
+```
+
+The ``get_pytorch_onnx_model(original_model)`` function is based on the ``torch.onnx.export(...)`` call:
+
+```python
+# define the directory to save the converted model in
+onnx_model_path = "models"
+# define the name of the converted model
+onnx_model_name = "resnet50.onnx"
+
+# create the directory for the converted model
+os.makedirs(onnx_model_path, exist_ok=True)
+
+# get full path to the converted model
+full_model_path = os.path.join(onnx_model_path, onnx_model_name)
+
+# generate model input
+generated_input = Variable(
+    torch.randn(1, 3, 224, 224)
+)
+
+# model export into ONNX format
+torch.onnx.export(
+    original_model,
+    generated_input,
+    full_model_path,
+    verbose=True,
+    input_names=["input"],
+    output_names=["output"],
+    opset_version=11
+)
+```
+
+After the successful execution of the above code we will get the following output:
+
+```console
+PyTorch ResNet-50 model was successfully converted: models/resnet50.onnx
+```
+
+The ``dnn_model_runner`` module proposed in ``samples/dnn`` allows us to reproduce the above conversion steps for the following PyTorch classification models:
+* alexnet
+* vgg11
+* vgg13
+* vgg16
+* vgg19
+* resnet18
+* resnet34
+* resnet50
+* resnet101
+* resnet152
+* squeezenet1_0
+* squeezenet1_1
+* resnext50_32x4d
+* resnext101_32x8d
+* wide_resnet50_2
+* wide_resnet101_2
+
+To obtain the converted model, the following line should be executed:
+
+```console
+python -m dnn_model_runner.dnn_conversion.pytorch.classification.py_to_py_cls --model_name <pytorch_cls_model_name> --evaluate False
+```
+
+For the ResNet-50 case the below line should be run:
+
+```console
+python -m dnn_model_runner.dnn_conversion.pytorch.classification.py_to_py_cls --model_name resnet50 --evaluate False
+```
+
+The default root directory for converted model storage is defined in the ``CommonConfig`` module:
+
+```python
+@dataclass
+class CommonConfig:
+    output_data_root_dir: str = "dnn_model_runner/dnn_conversion"
+```
+
+Thus, the converted ResNet-50 will be saved in ``dnn_model_runner/dnn_conversion/models``.
+
+### Inference Pipeline
+Now we can use ``models/resnet50.onnx`` for the inference pipeline using OpenCV C/C++ API. The implemented pipeline can be found in [samples/dnn/classification.cpp](https://github.com/opencv/opencv/blob/master/samples/dnn/classification.cpp).
+After building the samples (the ``BUILD_EXAMPLES`` flag value should be ``ON``), the ``example_dnn_classification`` executable is provided.
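+
+For reference, a minimal sketch of such a build (assuming an out-of-source OpenCV build directory; the make target name is inferred from the executable name) might look like:
+
+```console
+cmake -DBUILD_EXAMPLES=ON ..
+make -j4 example_dnn_classification
+```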
+
+To provide model inference we will use the below [squirrel photo](https://www.pexels.com/photo/brown-squirrel-eating-1564292) (under [CC0](https://www.pexels.com/terms-of-service/) license) corresponding to ImageNet class ID 335:
+```console
+fox squirrel, eastern fox squirrel, Sciurus niger
+```
+
+![Classification model input image](images/squirrel_cls.jpg)
+
+For the label decoding of the obtained prediction, we also need the ``imagenet_classes.txt`` file, which contains the full list of ImageNet classes.
+
+In this tutorial we will run the inference process for the converted PyTorch ResNet-50 model from the build (``samples/build``) directory:
+
+```console
+./dnn/example_dnn_classification --model=../dnn/models/resnet50.onnx --input=../data/squirrel_cls.jpg --width=224 --height=224 --rgb=true --scale="0.003921569" --mean="123.675 116.28 103.53" --std="0.229 0.224 0.225" --crop=true --initial_width=256 --initial_height=256 --classes=../data/dnn/classification_classes_ILSVRC2012.txt
+```
+
+Let's explore ``classification.cpp`` key points step by step:
+
+* read the model with cv::dnn::readNet and initialize the network:
+
+```cpp
+Net net = readNet(model, config, framework);
+```
+
+The ``model`` parameter value is taken from the ``--model`` key. In our case, it is ``resnet50.onnx``.
+
+* preprocess input image:
+
+```cpp
+if (rszWidth != 0 && rszHeight != 0)
+{
+    resize(frame, frame, Size(rszWidth, rszHeight));
+}
+
+// Create a 4D blob from a frame
+blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, crop);
+
+// Check std values.
+if (std.val[0] != 0.0 && std.val[1] != 0.0 && std.val[2] != 0.0)
+{
+    // Divide blob by std.
+    divide(blob, std, blob);
+}
+```
+
+In this step we use the cv::dnn::blobFromImage function to prepare the model input.
+We set ``Size(rszWidth, rszHeight)`` with ``--initial_width=256 --initial_height=256`` for the initial image resize, as described in the [PyTorch ResNet inference pipeline](https://pytorch.org/hub/pytorch_vision_resnet/).
+
+It should be noted that in cv::dnn::blobFromImage the mean value is subtracted first, and only then are the pixel values multiplied by the scale.
+Thus, we use ``--mean="123.675 116.28 103.53"``, which is equivalent to ``[0.485, 0.456, 0.406]`` multiplied by ``255.0``, to reproduce the original image preprocessing order for PyTorch classification models:
+
+```python
+img /= 255.0
+img -= [0.485, 0.456, 0.406]
+img /= [0.229, 0.224, 0.225]
+```
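+
+As a quick sanity check of this equivalence (a sketch, not part of the sample; it assumes only NumPy and OpenCV-Python), both preprocessing orders can be compared numerically:
+
+```python
+import cv2
+import numpy as np
+
+img = np.random.randint(0, 256, (224, 224, 3)).astype(np.float32)
+
+# PyTorch-style reference preprocessing (channel order left unchanged here)
+ref = img / 255.0
+ref -= [0.485, 0.456, 0.406]
+ref /= [0.229, 0.224, 0.225]
+ref = ref.transpose(2, 0, 1)  # HWC -> CHW
+
+# OpenCV order: subtract mean first, then multiply by scale, then divide by std
+blob = cv2.dnn.blobFromImage(img, scalefactor=1 / 255.0, size=(224, 224),
+                             mean=[123.675, 116.28, 103.53], swapRB=False, crop=False)
+blob[0] /= np.asarray([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
+
+print(np.max(np.abs(blob[0] - ref)))  # expected to be ~0 up to float precision
+```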
+
+* make forward pass:
+
+```cpp
+net.setInput(blob);
+Mat prob = net.forward();
+```
+
+* process the prediction:
+
+```cpp
+Point classIdPoint;
+double confidence;
+minMaxLoc(prob.reshape(1, 1), 0, &confidence, 0, &classIdPoint);
+int classId = classIdPoint.x;
+```
+
+Here we choose the most likely object class. The ``classId`` result for our case is 335 - fox squirrel, eastern fox squirrel, Sciurus niger:
+
+![ResNet50 OpenCV C++ inference output](images/opencv_resnet50_test_res_c.jpg)
diff --git a/doc/tutorials/dnn/dnn_pytorch_tf_classification/pytorch_cls_model_conversion_tutorial.md b/doc/tutorials/dnn/dnn_pytorch_tf_classification/pytorch_cls_model_conversion_tutorial.md
new file mode 100644 (file)
index 0000000..409d2f5
--- /dev/null
@@ -0,0 +1,362 @@
+# Conversion of PyTorch Classification Models and Launch with OpenCV Python {#pytorch_cls_tutorial_dnn_conversion}
+
+@prev_tutorial{tutorial_dnn_OCR}
+@next_tutorial{pytorch_cls_c_tutorial_dnn_conversion}
+
+|    |    |
+| -: | :- |
+| Original author | Anastasia Murzova |
+| Compatibility | OpenCV >= 4.5 |
+
+## Goals
+In this tutorial you will learn how to:
+* convert PyTorch classification models into ONNX format
+* run converted PyTorch model with OpenCV Python API
+* obtain an evaluation of the PyTorch and OpenCV DNN models
+
+We will explore the above-listed points by the example of the ResNet-50 architecture.
+
+## Introduction
+Let's briefly review the key concepts involved in the pipeline of PyTorch model transition with OpenCV API. The initial step in the conversion of a PyTorch model into cv.dnn.Net
+is transferring the model into [ONNX](https://onnx.ai/about.html) format. ONNX aims at interchangeability of neural networks between various frameworks. PyTorch has a built-in function for ONNX conversion: [``torch.onnx.export``](https://pytorch.org/docs/stable/onnx.html#torch.onnx.export).
+The obtained ``.onnx`` model is then passed into cv.dnn.readNetFromONNX.
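+
+As a compact preview of the whole route (a sketch assuming ``torch``, ``torchvision`` and OpenCV-Python are installed; the tutorial below breaks these steps down):
+
+```python
+import cv2
+import torch
+import torchvision.models as models
+
+model = models.resnet50(pretrained=True)
+model.eval()
+
+# export the model to ONNX using a dummy input of the expected shape
+dummy_input = torch.randn(1, 3, 224, 224)
+torch.onnx.export(model, dummy_input, "resnet50.onnx", opset_version=11)
+
+# read the exported model back with the OpenCV DNN module
+opencv_net = cv2.dnn.readNetFromONNX("resnet50.onnx")
+```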
+
+## Requirements
+To be able to experiment with the below code you will need to install a set of libraries. We will use a virtual environment with python3.7+ for this:
+
+```console
+virtualenv -p /usr/bin/python3.7 <env_dir_path>
+source <env_dir_path>/bin/activate
+```
+
+For OpenCV-Python building from source, follow the corresponding instructions from the @ref tutorial_py_table_of_contents_setup.
+
+Before you start installing the libraries, you can customize the [requirements.txt](https://github.com/opencv/opencv/tree/master/samples/dnn/dnn_model_runner/dnn_conversion/requirements.txt), excluding or including some dependencies (for example, ``opencv-python``).
+The below line initiates requirements installation into the previously activated virtual environment:
+
+```console
+pip install -r requirements.txt
+```
+
+## Practice
+In this part we are going to cover the following points:
+1. create a classification model conversion pipeline and provide the inference
+2. evaluate and test classification models
+
+If you only want to run evaluation or test model pipelines, the "Model Conversion Pipeline" part can be skipped.
+
+### Model Conversion Pipeline
+The code in this subchapter is located in the ``dnn_model_runner`` module and can be executed with the line:
+
+```console
+python -m dnn_model_runner.dnn_conversion.pytorch.classification.py_to_py_resnet50
+```
+
+The following code contains the description of the below-listed steps:
+1. instantiate PyTorch model
+2. convert PyTorch model into ``.onnx``
+3. read the transferred network with OpenCV API
+4. prepare input data
+5. provide inference
+
+```python
+# initialize PyTorch ResNet-50 model
+original_model = models.resnet50(pretrained=True)
+
+# get the path to the PyTorch model converted into ONNX format
+full_model_path = get_pytorch_onnx_model(original_model)
+
+# read converted .onnx model with OpenCV API
+opencv_net = cv2.dnn.readNetFromONNX(full_model_path)
+print("OpenCV model was successfully read. Layer IDs: \n", opencv_net.getLayerNames())
+
+# get preprocessed image
+input_img = get_preprocessed_img("../data/squirrel_cls.jpg")
+
+# get ImageNet labels
+imagenet_labels = get_imagenet_labels("../data/dnn/classification_classes_ILSVRC2012.txt")
+
+# obtain OpenCV DNN predictions
+get_opencv_dnn_prediction(opencv_net, input_img, imagenet_labels)
+
+# obtain original PyTorch ResNet50 predictions
+get_pytorch_dnn_prediction(original_model, input_img, imagenet_labels)
+```
+
+To provide model inference we will use the below [squirrel photo](https://www.pexels.com/photo/brown-squirrel-eating-1564292) (under [CC0](https://www.pexels.com/terms-of-service/) license) corresponding to ImageNet class ID 335:
+```console
+fox squirrel, eastern fox squirrel, Sciurus niger
+```
+
+![Classification model input image](images/squirrel_cls.jpg)
+
+For the label decoding of the obtained prediction, we also need the ``imagenet_classes.txt`` file, which contains the full list of ImageNet classes.
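+
+A possible sketch of such label decoding (the actual ``get_imagenet_labels`` helper lives in the ``dnn_model_runner`` sources; this version only illustrates the idea):
+
+```python
+def get_imagenet_labels(labels_path):
+    # each line holds one human-readable class name;
+    # the line index corresponds to the ImageNet class ID
+    with open(labels_path) as f:
+        return [line.rstrip('\n') for line in f]
+```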
+
+Let's go deeper into each step by the example of pretrained PyTorch ResNet-50:
+*  instantiate PyTorch ResNet-50 model:
+
+```python
+# initialize PyTorch ResNet-50 model
+original_model = models.resnet50(pretrained=True)
+```
+
+*  convert PyTorch model into ONNX:
+
+```python
+# define the directory to save the converted model in
+onnx_model_path = "models"
+# define the name of the converted model
+onnx_model_name = "resnet50.onnx"
+
+# create the directory for the converted model
+os.makedirs(onnx_model_path, exist_ok=True)
+
+# get full path to the converted model
+full_model_path = os.path.join(onnx_model_path, onnx_model_name)
+
+# generate model input
+generated_input = Variable(
+    torch.randn(1, 3, 224, 224)
+)
+
+# model export into ONNX format
+torch.onnx.export(
+    original_model,
+    generated_input,
+    full_model_path,
+    verbose=True,
+    input_names=["input"],
+    output_names=["output"],
+    opset_version=11
+)
+```
+
+After the successful execution of the above code, we will get ``models/resnet50.onnx``.
+
+* read the transferred network with cv.dnn.readNetFromONNX, passing the ONNX model obtained in the previous step into it:
+
+```python
+# read converted .onnx model with OpenCV API
+opencv_net = cv2.dnn.readNetFromONNX(full_model_path)
+```
+
+* prepare input data:
+
+```python
+# read the image
+input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+input_img = input_img.astype(np.float32)
+
+input_img = cv2.resize(input_img, (256, 256))
+
+# define preprocess parameters
+mean = np.array([0.485, 0.456, 0.406]) * 255.0
+scale = 1 / 255.0
+std = [0.229, 0.224, 0.225]
+
+# prepare input blob to fit the model input:
+# 1. subtract mean
+# 2. scale to set pixel values from 0 to 1
+input_blob = cv2.dnn.blobFromImage(
+    image=input_img,
+    scalefactor=scale,
+    size=(224, 224),  # img target size
+    mean=mean,
+    swapRB=True,  # BGR -> RGB
+    crop=True  # center crop
+)
+# 3. divide by std
+input_blob[0] /= np.asarray(std, dtype=np.float32).reshape(3, 1, 1)
+```
+
+In this step we read the image and prepare the model input with the cv.dnn.blobFromImage function, which returns a 4-dimensional blob.
+It should be noted that in cv.dnn.blobFromImage the mean value is subtracted first, and only then are the pixel values multiplied by the scale. Thus, ``mean`` is multiplied by ``255.0`` to reproduce the original image preprocessing order:
+
+```python
+img /= 255.0
+img -= [0.485, 0.456, 0.406]
+img /= [0.229, 0.224, 0.225]
+```
+
+* OpenCV cv.dnn.Net inference:
+
+```python
+# set OpenCV DNN input
+opencv_net.setInput(preproc_img)
+
+# OpenCV DNN inference
+out = opencv_net.forward()
+print("OpenCV DNN prediction: \n")
+print("* shape: ", out.shape)
+
+# get the predicted class ID
+imagenet_class_id = np.argmax(out)
+
+# get confidence
+confidence = out[0][imagenet_class_id]
+print("* class ID: {}, label: {}".format(imagenet_class_id, imagenet_labels[imagenet_class_id]))
+print("* confidence: {:.4f}".format(confidence))
+```
+
+After the above code execution we will get the following output:
+
+```console
+OpenCV DNN prediction:
+* shape:  (1, 1000)
+* class ID: 335, label: fox squirrel, eastern fox squirrel, Sciurus niger
+* confidence: 14.8308
+```
+
+* PyTorch ResNet-50 model inference:
+
+```python
+original_net.eval()
+preproc_img = torch.FloatTensor(preproc_img)
+
+# inference
+out = original_net(preproc_img)
+print("\nPyTorch model prediction: \n")
+print("* shape: ", out.shape)
+
+# get the predicted class ID
+imagenet_class_id = torch.argmax(out, axis=1).item()
+print("* class ID: {}, label: {}".format(imagenet_class_id, imagenet_labels[imagenet_class_id]))
+
+# get confidence
+confidence = out[0][imagenet_class_id]
+print("* confidence: {:.4f}".format(confidence.item()))
+```
+
+After launching the above code, we will get the following output:
+
+```console
+PyTorch model prediction:
+* shape:  torch.Size([1, 1000])
+* class ID: 335, label: fox squirrel, eastern fox squirrel, Sciurus niger
+* confidence: 14.8308
+```
+
+The inference results of the original ResNet-50 model and cv.dnn.Net are equal. For an extended evaluation of the models we can use ``py_to_py_cls`` of the ``dnn_model_runner`` module. This module part is described in the next subchapter.
+
+### Evaluation of the Models
+
+The ``dnn_model_runner`` module proposed in ``samples/dnn`` allows running the full evaluation pipeline on the ImageNet dataset and test execution for the following PyTorch classification models:
+* alexnet
+* vgg11
+* vgg13
+* vgg16
+* vgg19
+* resnet18
+* resnet34
+* resnet50
+* resnet101
+* resnet152
+* squeezenet1_0
+* squeezenet1_1
+* resnext50_32x4d
+* resnext101_32x8d
+* wide_resnet50_2
+* wide_resnet101_2
+
+This list can also be extended with appropriate configuration of the evaluation pipeline.
+
+#### Evaluation Mode
+
+The line below runs the module in evaluation mode:
+
+```console
+python -m dnn_model_runner.dnn_conversion.pytorch.classification.py_to_py_cls --model_name <pytorch_cls_model_name>
+```
+
+The classification model chosen from the list will be read into an OpenCV cv.dnn.Net object. Evaluation results of the PyTorch and OpenCV models (accuracy, inference time, L1 difference) will be written into a log file. Inference time values will also be depicted in a chart to summarize the obtained model information.
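+
+For reference, the reported L1 value is the mean absolute difference between the output blobs of the two frameworks, computed essentially as follows (a sketch mirroring ``cls_accuracy_evaluator.py``):
+
+```python
+import numpy as np
+
+def l1_difference(out_a, out_b):
+    # mean absolute difference between two output blobs
+    diff = np.abs(out_a - out_b)
+    return np.sum(diff) / diff.size
+```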
+
+Necessary evaluation configurations are defined in [test_config.py](https://github.com/opencv/opencv/tree/master/samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/test_config.py) and can be modified in accordance with the actual data location paths:
+
+```python
+@dataclass
+class TestClsConfig:
+    batch_size: int = 50
+    frame_size: int = 224
+    img_root_dir: str = "./ILSVRC2012_img_val"
+    # location of image-class matching
+    img_cls_file: str = "./val.txt"
+    bgr_to_rgb: bool = True
+```
+
+To initiate the evaluation of the PyTorch ResNet-50, run the following line:
+
+```console
+python -m dnn_model_runner.dnn_conversion.pytorch.classification.py_to_py_cls --model_name resnet50
+```
+
+After the script is launched, a log file with the evaluation data will be generated in ``dnn_model_runner/dnn_conversion/logs``:
+
+```console
+The model PyTorch resnet50 was successfully obtained and converted to OpenCV DNN resnet50
+===== Running evaluation of the model with the following params:
+    * val data location: ./ILSVRC2012_img_val
+    * log file location: dnn_model_runner/dnn_conversion/logs/PyTorch_resnet50_log.txt
+```
+
+#### Test Mode
+
+The line below runs the module in test mode, namely it provides the steps for model inference:
+
+```console
+python -m dnn_model_runner.dnn_conversion.pytorch.classification.py_to_py_cls --model_name <pytorch_cls_model_name> --test True --default_img_preprocess <True/False> --evaluate False
+```
+
+Here the ``default_img_preprocess`` key defines whether you'd like to parametrize the model test process with particular values or use the default values, for example, ``scale``, ``mean`` or ``std``.
+
+Test configuration is represented in the ``TestClsModuleConfig`` class of [test_config.py](https://github.com/opencv/opencv/tree/master/samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/test_config.py):
+
+```python
+@dataclass
+class TestClsModuleConfig:
+    cls_test_data_dir: str = "../data"
+    test_module_name: str = "classification"
+    test_module_path: str = "classification.py"
+    input_img: str = os.path.join(cls_test_data_dir, "squirrel_cls.jpg")
+    model: str = ""
+
+    frame_height: str = str(TestClsConfig.frame_size)
+    frame_width: str = str(TestClsConfig.frame_size)
+    scale: str = "1.0"
+    mean: List[str] = field(default_factory=lambda: ["0.0", "0.0", "0.0"])
+    std: List[str] = field(default_factory=list)
+    crop: str = "False"
+    rgb: str = "True"
+    rsz_height: str = ""
+    rsz_width: str = ""
+    classes: str = os.path.join(cls_test_data_dir, "dnn", "classification_classes_ILSVRC2012.txt")
+```
+
+The default image preprocessing options are defined in [default_preprocess_config.py](https://github.com/opencv/opencv/tree/master/samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/default_preprocess_config.py). For instance:
+
+```python
+BASE_IMG_SCALE_FACTOR = 1 / 255.0
+PYTORCH_RSZ_HEIGHT = 256
+PYTORCH_RSZ_WIDTH = 256
+
+pytorch_resize_input_blob = {
+    "mean": ["123.675", "116.28", "103.53"],
+    "scale": str(BASE_IMG_SCALE_FACTOR),
+    "std": ["0.229", "0.224", "0.225"],
+    "crop": "True",
+    "rgb": "True",
+    "rsz_height": str(PYTORCH_RSZ_HEIGHT),
+    "rsz_width": str(PYTORCH_RSZ_WIDTH)
+}
+```
+
+The basis of the model testing is represented in [samples/dnn/classification.py](https://github.com/opencv/opencv/blob/master/samples/dnn/classification.py). ``classification.py`` can be executed autonomously with the provided converted model and populated parameters for cv.dnn.blobFromImage.
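+
+For instance, a hypothetical standalone invocation (the flag spellings below are assumptions based on the sample's argument parser; check ``python classification.py --help`` for the exact interface) might look like:
+
+```console
+python classification.py --model models/resnet50.onnx --input ../data/squirrel_cls.jpg \
+    --width 224 --height 224 --initial_width 256 --initial_height 256 \
+    --scale 0.003921569 --mean 123.675 116.28 103.53 --std 0.229 0.224 0.225 \
+    --rgb --crop True --classes ../data/dnn/classification_classes_ILSVRC2012.txt
+```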
+
+To reproduce the OpenCV steps described in "Model Conversion Pipeline" from scratch with ``dnn_model_runner``, execute the line below:
+
+```console
+python -m dnn_model_runner.dnn_conversion.pytorch.classification.py_to_py_cls --model_name resnet50 --test True --default_img_preprocess True --evaluate False
+```
+
+The network prediction is depicted in the top left corner of the output window:
+
+![ResNet50 OpenCV inference output](images/pytorch_resnet50_opencv_test_res.jpg)
diff --git a/doc/tutorials/dnn/dnn_pytorch_tf_classification/tf_cls_model_conversion_tutorial.md b/doc/tutorials/dnn/dnn_pytorch_tf_classification/tf_cls_model_conversion_tutorial.md
new file mode 100644 (file)
index 0000000..c2da541
--- /dev/null
@@ -0,0 +1,360 @@
+# Conversion of TensorFlow Classification Models and Launch with OpenCV Python {#tf_cls_tutorial_dnn_conversion}
+
+|    |    |
+| -: | :- |
+| Original author | Anastasia Murzova |
+| Compatibility | OpenCV >= 4.5 |
+
+## Goals
+In this tutorial you will learn how to:
+* obtain frozen graphs of TensorFlow (TF) classification models
+* run converted TensorFlow model with OpenCV Python API
+* obtain an evaluation of the TensorFlow and OpenCV DNN models
+
+We will explore the above-listed points by the example of the MobileNet architecture.
+
+## Introduction
+Let's briefly review the key concepts involved in the pipeline of TensorFlow model transition with OpenCV API. The initial step in the conversion of a TensorFlow model into cv.dnn.Net
+is obtaining the frozen TF model graph. A frozen graph combines the model graph structure with the saved values of the required variables, for example, weights. A frozen graph is usually saved in a [protobuf](https://en.wikipedia.org/wiki/Protocol_Buffers) (``.pb``) file.
+After the model ``.pb`` file has been generated, it can be read with the cv.dnn.readNetFromTensorflow function.
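+
+As a minimal sketch of this last step (assuming a frozen ``mobilenet.pb`` has already been produced):
+
+```python
+import cv2
+
+# read a frozen TensorFlow graph (.pb) into an OpenCV DNN network
+opencv_net = cv2.dnn.readNetFromTensorflow("mobilenet.pb")
+print(opencv_net.getLayerNames())
+```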
+
+## Requirements
+To be able to experiment with the below code you will need to install a set of libraries. We will use a virtual environment with python3.7+ for this:
+
+```console
+virtualenv -p /usr/bin/python3.7 <env_dir_path>
+source <env_dir_path>/bin/activate
+```
+
+For OpenCV-Python building from source, follow the corresponding instructions from the @ref tutorial_py_table_of_contents_setup.
+
+Before you start installing the libraries, you can customize the [requirements.txt](https://github.com/opencv/opencv/tree/master/samples/dnn/dnn_model_runner/dnn_conversion/requirements.txt), excluding or including some dependencies (for example, ``opencv-python``).
+The below line initiates requirements installation into the previously activated virtual environment:
+
+```console
+pip install -r requirements.txt
+```
+
+## Practice
+In this part we are going to cover the following points:
+1. create a TF classification model conversion pipeline and provide the inference
+2. evaluate and test TF classification models
+
+If you only want to run evaluation or test model pipelines, the "Model Conversion Pipeline" tutorial part can be skipped.
+
+### Model Conversion Pipeline
+The code in this subchapter is located in the ``dnn_model_runner`` module and can be executed with the line:
+
+```console
+python -m dnn_model_runner.dnn_conversion.tf.classification.py_to_py_mobilenet
+```
+
+The following code contains the description of the below-listed steps:
+1. instantiate TF model
+2. create TF frozen graph
+3. read TF frozen graph with OpenCV API
+4. prepare input data
+5. provide inference
+
+```python
+# initialize TF MobileNet model
+original_tf_model = MobileNet(
+    include_top=True,
+    weights="imagenet"
+)
+
+# get TF frozen graph path
+full_pb_path = get_tf_model_proto(original_tf_model)
+
+# read frozen graph with OpenCV API
+opencv_net = cv2.dnn.readNetFromTensorflow(full_pb_path)
+print("OpenCV model was successfully read. Model layers: \n", opencv_net.getLayerNames())
+
+# get preprocessed image
+input_img = get_preprocessed_img("../data/squirrel_cls.jpg")
+
+# get ImageNet labels
+imagenet_labels = get_imagenet_labels("../data/dnn/classification_classes_ILSVRC2012.txt")
+
+# obtain OpenCV DNN predictions
+get_opencv_dnn_prediction(opencv_net, input_img, imagenet_labels)
+
+# obtain TF model predictions
+get_tf_dnn_prediction(original_tf_model, input_img, imagenet_labels)
+```
+
+To provide model inference we will use the below [squirrel photo](https://www.pexels.com/photo/brown-squirrel-eating-1564292) (under [CC0](https://www.pexels.com/terms-of-service/) license) corresponding to ImageNet class ID 335:
+```console
+fox squirrel, eastern fox squirrel, Sciurus niger
+```
+
+![Classification model input image](images/squirrel_cls.jpg)
+
+For the label decoding of the obtained prediction, we also need the ``imagenet_classes.txt`` file, which contains the full list of ImageNet classes.
+
+Let's go deeper into each step by the example of pretrained TF MobileNet:
+* instantiate TF model:
+
+```python
+# initialize TF MobileNet model
+original_tf_model = MobileNet(
+    include_top=True,
+    weights="imagenet"
+)
+```
+
+* create TF frozen graph:
+
+```python
+# define the directory for .pb model
+pb_model_path = "models"
+
+# define the name of .pb model
+pb_model_name = "mobilenet.pb"
+
+# create directory for further converted model
+os.makedirs(pb_model_path, exist_ok=True)
+
+# get model TF graph
+tf_model_graph = tf.function(lambda x: tf_model(x))
+
+# get concrete function
+tf_model_graph = tf_model_graph.get_concrete_function(
+    tf.TensorSpec(tf_model.inputs[0].shape, tf_model.inputs[0].dtype))
+
+# obtain frozen concrete function
+frozen_tf_func = convert_variables_to_constants_v2(tf_model_graph)
+# get frozen graph
+frozen_tf_func.graph.as_graph_def()
+
+# save full tf model
+tf.io.write_graph(graph_or_graph_def=frozen_tf_func.graph,
+                  logdir=pb_model_path,
+                  name=pb_model_name,
+                  as_text=False)
+```
+
+After the successful execution of the above code, we will get a frozen graph in ``models/mobilenet.pb``.
+
+* read the TF frozen graph with cv.dnn.readNetFromTensorflow, passing the ``mobilenet.pb`` obtained in the previous step into it:
+
+```python
+# get TF frozen graph path
+full_pb_path = get_tf_model_proto(original_tf_model)
+```
+
+* prepare input data with the cv2.dnn.blobFromImage function:
+
+```python
+# read the image
+input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+input_img = input_img.astype(np.float32)
+
+# define preprocess parameters
+mean = np.array([1.0, 1.0, 1.0]) * 127.5
+scale = 1 / 127.5
+
+# prepare input blob to fit the model input:
+# 1. subtract mean
+# 2. scale to set pixel values from 0 to 1
+input_blob = cv2.dnn.blobFromImage(
+    image=input_img,
+    scalefactor=scale,
+    size=(224, 224),  # img target size
+    mean=mean,
+    swapRB=True,  # BGR -> RGB
+    crop=True  # center crop
+)
+print("Input blob shape: {}\n".format(input_blob.shape))
+```
+
+Please pay attention to the preprocessing order in the cv2.dnn.blobFromImage function: the mean value is subtracted first, and only then are the pixel values multiplied by the defined scale.
+Therefore, to reproduce the image preprocessing pipeline of the TF [``mobilenet.preprocess_input``](https://github.com/tensorflow/tensorflow/blob/02032fb477e9417197132648ec81e75beee9063a/tensorflow/python/keras/applications/mobilenet.py#L443-L445) function, we multiply ``mean`` by ``127.5``.
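+
+A quick numerical check of this equivalence (a sketch, not part of the sample):
+
+```python
+import numpy as np
+
+x = np.linspace(0, 255, 16, dtype=np.float32)
+opencv_style = (x - 127.5) * (1 / 127.5)  # blobFromImage: subtract mean, then scale
+tf_style = x / 127.5 - 1.0                # what mobilenet.preprocess_input computes
+print(np.allclose(opencv_style, tf_style))  # True
+```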
+
+As a result, a 4-dimensional ``input_blob`` is obtained:
+
+``Input blob shape: (1, 3, 224, 224)``
+
+* provide OpenCV cv.dnn.Net inference:
+
+```python
+# set OpenCV DNN input
+opencv_net.setInput(preproc_img)
+
+# OpenCV DNN inference
+out = opencv_net.forward()
+print("OpenCV DNN prediction: \n")
+print("* shape: ", out.shape)
+
+# get the predicted class ID
+imagenet_class_id = np.argmax(out)
+
+# get confidence
+confidence = out[0][imagenet_class_id]
+print("* class ID: {}, label: {}".format(imagenet_class_id, imagenet_labels[imagenet_class_id]))
+print("* confidence: {:.4f}\n".format(confidence))
+```
+
+After the above code execution we will get the following output:
+
+```console
+OpenCV DNN prediction:
+* shape:  (1, 1000)
+* class ID: 335, label: fox squirrel, eastern fox squirrel, Sciurus niger
+* confidence: 0.9525
+```
+
+* provide TF MobileNet inference:
+
+```python
+# inference
+preproc_img = preproc_img.transpose(0, 2, 3, 1)
+print("TF input blob shape: {}\n".format(preproc_img.shape))
+
+out = original_net(preproc_img)
+
+print("\nTensorFlow model prediction: \n")
+print("* shape: ", out.shape)
+
+# get the predicted class ID
+imagenet_class_id = np.argmax(out)
+print("* class ID: {}, label: {}".format(imagenet_class_id, imagenet_labels[imagenet_class_id]))
+
+# get confidence
+confidence = out[0][imagenet_class_id]
+print("* confidence: {:.4f}".format(confidence))
+```
+
+To fit the TF model input, ``input_blob`` was transposed:
+
+```console
+TF input blob shape: (1, 224, 224, 3)
+```
+
+TF inference results are the following:
+
+```console
+TensorFlow model prediction:
+* shape:  (1, 1000)
+* class ID: 335, label: fox squirrel, eastern fox squirrel, Sciurus niger
+* confidence: 0.9525
+```
+
+As can be seen from the experiments, the OpenCV and TF inference results are equal.
+
+### Evaluation of the Models
+
+The ``dnn_model_runner`` module proposed in ``samples/dnn`` allows running the full evaluation pipeline on the ImageNet dataset and test execution for the following TensorFlow classification models:
+* vgg16
+* vgg19
+* resnet50
+* resnet101
+* resnet152
+* densenet121
+* densenet169
+* densenet201
+* inceptionresnetv2
+* inceptionv3
+* mobilenet
+* mobilenetv2
+* nasnetlarge
+* nasnetmobile
+* xception
+
+This list can also be extended with appropriate configuration of the evaluation pipeline.
+
+#### Evaluation Mode
+
+The line below runs the module in evaluation mode:
+
+```console
+python -m dnn_model_runner.dnn_conversion.tf.classification.py_to_py_cls --model_name <tf_cls_model_name>
+```
+
+The classification model chosen from the list will be read into an OpenCV ``cv.dnn_Net`` object. Evaluation results of the TF and OpenCV models (accuracy, inference time, L1 difference) will be written into a log file. Inference time values will also be depicted in a chart to summarize the obtained model information.
+
+Necessary evaluation configurations are defined in [test_config.py](https://github.com/opencv/opencv/tree/master/samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/test_config.py) and can be modified in accordance with the actual data location paths:
+
+```python
+@dataclass
+class TestClsConfig:
+    batch_size: int = 50
+    frame_size: int = 224
+    img_root_dir: str = "./ILSVRC2012_img_val"
+    # location of image-class matching
+    img_cls_file: str = "./val.txt"
+    bgr_to_rgb: bool = True
+```
+
+The values from ``TestClsConfig`` can be customized in accordance with the chosen model.
+
+To initiate the evaluation of the TensorFlow MobileNet, run the following line:
+
+```console
+python -m dnn_model_runner.dnn_conversion.tf.classification.py_to_py_cls --model_name mobilenet
+```
+
+After the script is launched, a log file with the evaluation data will be generated in ``dnn_model_runner/dnn_conversion/logs``:
+
+```console
+===== Running evaluation of the model with the following params:
+    * val data location: ./ILSVRC2012_img_val
+    * log file location: dnn_model_runner/dnn_conversion/logs/TF_mobilenet_log.txt
+```
+
+#### Test Mode
+
+The line below runs the module in test mode, namely it provides the steps for model inference:
+
+```console
+python -m dnn_model_runner.dnn_conversion.tf.classification.py_to_py_cls --model_name <tf_cls_model_name> --test True --default_img_preprocess <True/False> --evaluate False
+```
+
+Here the ``default_img_preprocess`` key defines whether you'd like to parametrize the model test process with particular values or use the default values, for example, ``scale``, ``mean`` or ``std``.
+
+Test configuration is represented in the ``TestClsModuleConfig`` class of [test_config.py](https://github.com/opencv/opencv/tree/master/samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/test_config.py):
+
+```python
+@dataclass
+class TestClsModuleConfig:
+    cls_test_data_dir: str = "../data"
+    test_module_name: str = "classification"
+    test_module_path: str = "classification.py"
+    input_img: str = os.path.join(cls_test_data_dir, "squirrel_cls.jpg")
+    model: str = ""
+
+    frame_height: str = str(TestClsConfig.frame_size)
+    frame_width: str = str(TestClsConfig.frame_size)
+    scale: str = "1.0"
+    mean: List[str] = field(default_factory=lambda: ["0.0", "0.0", "0.0"])
+    std: List[str] = field(default_factory=list)
+    crop: str = "False"
+    rgb: str = "True"
+    rsz_height: str = ""
+    rsz_width: str = ""
+    classes: str = os.path.join(cls_test_data_dir, "dnn", "classification_classes_ILSVRC2012.txt")
+```
+
+The default image preprocessing options are defined in ``default_preprocess_config.py``. For instance, for MobileNet:
+
+```python
+tf_input_blob = {
+    "mean": ["127.5", "127.5", "127.5"],
+    "scale": str(1 / 127.5),
+    "std": [],
+    "crop": "True",
+    "rgb": "True"
+}
+```
+
+The basis of the model testing is represented in [samples/dnn/classification.py](https://github.com/opencv/opencv/blob/master/samples/dnn/classification.py). ``classification.py`` can be executed autonomously with the provided converted model and populated parameters for cv.dnn.blobFromImage.
+
+To reproduce the OpenCV steps described in "Model Conversion Pipeline" from scratch with ``dnn_model_runner``, execute the line below:
+
+```console
+python -m dnn_model_runner.dnn_conversion.tf.classification.py_to_py_cls --model_name mobilenet --test True --default_img_preprocess True --evaluate False
+```
+
+The network prediction is depicted in the top left corner of the output window:
+
+![TF MobileNet OpenCV inference output](images/tf_mobilenet_opencv_test_res.jpg)
diff --git a/doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown b/doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown
index 603ae9a..5f28b6c 100644 (file)
@@ -3,6 +3,7 @@
 @tableofcontents
 
 @prev_tutorial{tutorial_dnn_OCR}
+@next_tutorial{pytorch_cls_tutorial_dnn_conversion}
 
 |    |    |
 | -: | :- |
diff --git a/doc/tutorials/dnn/table_of_content_dnn.markdown b/doc/tutorials/dnn/table_of_content_dnn.markdown
index dd3e596..a74554f 100644 (file)
@@ -10,3 +10,12 @@ Deep Neural Networks (dnn module) {#tutorial_table_of_content_dnn}
 -   @subpage tutorial_dnn_custom_layers
 -   @subpage tutorial_dnn_OCR
 -   @subpage tutorial_dnn_text_spotting
+
+#### PyTorch models with OpenCV
+In this section you will find guides that describe how to run classification, segmentation and detection PyTorch DNN models with OpenCV.
+-   @subpage pytorch_cls_tutorial_dnn_conversion
+-   @subpage pytorch_cls_c_tutorial_dnn_conversion
+
+#### TensorFlow models with OpenCV
+In this section you will find guides that describe how to run classification, segmentation and detection TensorFlow DNN models with OpenCV.
+-   @subpage tf_cls_tutorial_dnn_conversion
diff --git a/samples/data/squirrel_cls.jpg b/samples/data/squirrel_cls.jpg
new file mode 100644 (file)
index 0000000..289b13b
Binary files /dev/null and b/samples/data/squirrel_cls.jpg differ
diff --git a/samples/dnn/classification.cpp b/samples/dnn/classification.cpp
index 0ae9e6e..8440371 100644 (file)
@@ -8,22 +8,26 @@
 #include "common.hpp"
 
 std::string keys =
-    "{ help  h     | | Print help message. }"
-    "{ @alias      | | An alias name of model to extract preprocessing parameters from models.yml file. }"
-    "{ zoo         | models.yml | An optional path to file with preprocessing parameters }"
-    "{ input i     | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
-    "{ framework f | | Optional name of an origin framework of the model. Detect it automatically if it does not set. }"
-    "{ classes     | | Optional path to a text file with names of classes. }"
-    "{ backend     | 0 | Choose one of computation backends: "
-                        "0: automatically (by default), "
-                        "1: Halide language (http://halide-lang.org/), "
-                        "2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
-                        "3: OpenCV implementation }"
-    "{ target      | 0 | Choose one of target computation devices: "
-                        "0: CPU target (by default), "
-                        "1: OpenCL, "
-                        "2: OpenCL fp16 (half-float precision), "
-                        "3: VPU }";
+    "{ help  h          | | Print help message. }"
+    "{ @alias           | | An alias name of model to extract preprocessing parameters from models.yml file. }"
+    "{ zoo              | models.yml | An optional path to file with preprocessing parameters }"
+    "{ input i          | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
+    "{ initial_width    | 0 | Preprocess input image by initial resizing to a specific width.}"
+    "{ initial_height   | 0 | Preprocess input image by initial resizing to a specific height.}"
+    "{ std              | 0.0 0.0 0.0 | Preprocess input image by dividing on a standard deviation.}"
+    "{ crop             | false | Preprocess input image by center cropping.}"
+    "{ framework f      | | Optional name of an origin framework of the model. Detect it automatically if it does not set. }"
+    "{ classes          | | Optional path to a text file with names of classes. }"
+    "{ backend          | 0 | Choose one of computation backends: "
+                            "0: automatically (by default), "
+                            "1: Halide language (http://halide-lang.org/), "
+                            "2: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
+                            "3: OpenCV implementation }"
+    "{ target           | 0 | Choose one of target computation devices: "
+                            "0: CPU target (by default), "
+                            "1: OpenCL, "
+                            "2: OpenCL fp16 (half-float precision), "
+                            "3: VPU }";
 
 using namespace cv;
 using namespace dnn;
@@ -47,9 +51,13 @@ int main(int argc, char** argv)
         return 0;
     }
 
+    int rszWidth = parser.get<int>("initial_width");
+    int rszHeight = parser.get<int>("initial_height");
     float scale = parser.get<float>("scale");
     Scalar mean = parser.get<Scalar>("mean");
+    Scalar std = parser.get<Scalar>("std");
     bool swapRB = parser.get<bool>("rgb");
+    bool crop = parser.get<bool>("crop");
     int inpWidth = parser.get<int>("width");
     int inpHeight = parser.get<int>("height");
     String model = findFile(parser.get<String>("model"));
@@ -108,8 +116,20 @@ int main(int argc, char** argv)
             break;
         }
 
+        if (rszWidth != 0 && rszHeight != 0)
+        {
+            resize(frame, frame, Size(rszWidth, rszHeight));
+        }
+
         //! [Create a 4D blob from a frame]
-        blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, false);
+        blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, crop);
+
+        // Check std values.
+        if (std.val[0] != 0.0 && std.val[1] != 0.0 && std.val[2] != 0.0)
+        {
+            // Divide blob by std.
+            divide(blob, std, blob);
+        }
         //! [Create a 4D blob from a frame]
 
         //! [Set input blob]
diff --git a/samples/dnn/classification.py b/samples/dnn/classification.py
index 1c6908a..558c8b0 100644 (file)
-import cv2 as cv
 import argparse
-import numpy as np
 
+import cv2 as cv
+import numpy as np
 from common import *
 
-backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_HALIDE, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE, cv.dnn.DNN_BACKEND_OPENCV)
-targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD, cv.dnn.DNN_TARGET_HDDL)
-
-parser = argparse.ArgumentParser(add_help=False)
-parser.add_argument('--zoo', default=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models.yml'),
-                    help='An optional path to file with preprocessing parameters.')
-parser.add_argument('--input', help='Path to input image or video file. Skip this argument to capture frames from a camera.')
-parser.add_argument('--framework', choices=['caffe', 'tensorflow', 'torch', 'darknet'],
-                    help='Optional name of an origin framework of the model. '
-                         'Detect it automatically if it does not set.')
-parser.add_argument('--backend', choices=backends, default=cv.dnn.DNN_BACKEND_DEFAULT, type=int,
-                    help="Choose one of computation backends: "
-                         "%d: automatically (by default), "
-                         "%d: Halide language (http://halide-lang.org/), "
-                         "%d: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
-                         "%d: OpenCV implementation" % backends)
-parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU, type=int,
-                    help='Choose one of target computation devices: '
-                         '%d: CPU target (by default), '
-                         '%d: OpenCL, '
-                         '%d: OpenCL fp16 (half-float precision), '
-                         '%d: NCS2 VPU, '
-                         '%d: HDDL VPU' % targets)
-args, _ = parser.parse_known_args()
-add_preproc_args(args.zoo, parser, 'classification')
-parser = argparse.ArgumentParser(parents=[parser],
-                                 description='Use this script to run classification deep learning networks using OpenCV.',
-                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-args = parser.parse_args()
-
-args.model = findFile(args.model)
-args.config = findFile(args.config)
-args.classes = findFile(args.classes)
-
-# Load names of classes
-classes = None
-if args.classes:
-    with open(args.classes, 'rt') as f:
-        classes = f.read().rstrip('\n').split('\n')
-
-# Load a network
-net = cv.dnn.readNet(args.model, args.config, args.framework)
-net.setPreferableBackend(args.backend)
-net.setPreferableTarget(args.target)
-
-winName = 'Deep learning image classification in OpenCV'
-cv.namedWindow(winName, cv.WINDOW_NORMAL)
-
-cap = cv.VideoCapture(args.input if args.input else 0)
-while cv.waitKey(1) < 0:
-    hasFrame, frame = cap.read()
-    if not hasFrame:
-        cv.waitKey()
-        break
-
-    # Create a 4D blob from a frame.
-    inpWidth = args.width if args.width else frame.shape[1]
-    inpHeight = args.height if args.height else frame.shape[0]
-    blob = cv.dnn.blobFromImage(frame, args.scale, (inpWidth, inpHeight), args.mean, args.rgb, crop=False)
-
-    # Run a model
-    net.setInput(blob)
-    out = net.forward()
-
-    # Get a class with a highest score.
-    out = out.flatten()
-    classId = np.argmax(out)
-    confidence = out[classId]
-
-    # Put efficiency information.
-    t, _ = net.getPerfProfile()
-    label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
-    cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
-
-    # Print predicted class.
-    label = '%s: %.4f' % (classes[classId] if classes else 'Class #%d' % classId, confidence)
-    cv.putText(frame, label, (0, 40), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
-
-    cv.imshow(winName, frame)
+
+def get_args_parser(func_args):
+    backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_HALIDE, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE,
+                cv.dnn.DNN_BACKEND_OPENCV)
+    targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD,
+               cv.dnn.DNN_TARGET_HDDL)
+
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument('--zoo', default=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models.yml'),
+                        help='An optional path to file with preprocessing parameters.')
+    parser.add_argument('--input',
+                        help='Path to input image or video file. Skip this argument to capture frames from a camera.')
+    parser.add_argument('--framework', choices=['caffe', 'tensorflow', 'torch', 'darknet'],
+                        help='Optional name of an origin framework of the model. '
+                             'Detect it automatically if it is not set.')
+    parser.add_argument('--std', nargs='*', type=float,
+                        help='Preprocess input image by dividing it by the standard deviation.')
+    parser.add_argument('--crop', type=bool, default=False,
+                        help='Preprocess input image by center cropping.')
+    parser.add_argument('--initial_width', type=int,
+                        help='Preprocess input image by initial resizing to a specific width.')
+    parser.add_argument('--initial_height', type=int,
+                        help='Preprocess input image by initial resizing to a specific height.')
+    parser.add_argument('--backend', choices=backends, default=cv.dnn.DNN_BACKEND_DEFAULT, type=int,
+                        help="Choose one of computation backends: "
+                             "%d: automatically (by default), "
+                             "%d: Halide language (http://halide-lang.org/), "
+                             "%d: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
+                             "%d: OpenCV implementation" % backends)
+    parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU, type=int,
+                        help='Choose one of target computation devices: '
+                             '%d: CPU target (by default), '
+                             '%d: OpenCL, '
+                             '%d: OpenCL fp16 (half-float precision), '
+                             '%d: NCS2 VPU, '
+                             '%d: HDDL VPU' % targets)
+
+    args, _ = parser.parse_known_args()
+    add_preproc_args(args.zoo, parser, 'classification')
+    parser = argparse.ArgumentParser(parents=[parser],
+                                     description='Use this script to run classification deep learning networks using OpenCV.',
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    return parser.parse_args(func_args)
+
+
+def main(func_args=None):
+    args = get_args_parser(func_args)
+    args.model = findFile(args.model)
+    args.config = findFile(args.config)
+    args.classes = findFile(args.classes)
+
+    # Load names of classes
+    classes = None
+    if args.classes:
+        with open(args.classes, 'rt') as f:
+            classes = f.read().rstrip('\n').split('\n')
+
+    # Load a network
+    net = cv.dnn.readNet(args.model, args.config, args.framework)
+    net.setPreferableBackend(args.backend)
+    net.setPreferableTarget(args.target)
+
+    winName = 'Deep learning image classification in OpenCV'
+    cv.namedWindow(winName, cv.WINDOW_NORMAL)
+
+    cap = cv.VideoCapture(args.input if args.input else 0)
+    while cv.waitKey(1) < 0:
+        hasFrame, frame = cap.read()
+        if not hasFrame:
+            cv.waitKey()
+            break
+
+        # Create a 4D blob from a frame.
+        inpWidth = args.width if args.width else frame.shape[1]
+        inpHeight = args.height if args.height else frame.shape[0]
+
+        if args.initial_width and args.initial_height:
+            frame = cv.resize(frame, (args.initial_width, args.initial_height))
+
+        blob = cv.dnn.blobFromImage(frame, args.scale, (inpWidth, inpHeight), args.mean, args.rgb, crop=args.crop)
+        if args.std:
+            blob[0] /= np.asarray(args.std, dtype=np.float32).reshape(3, 1, 1)
+
+        # Run a model
+        net.setInput(blob)
+        out = net.forward()
+
+        # Get a class with a highest score.
+        out = out.flatten()
+        classId = np.argmax(out)
+        confidence = out[classId]
+
+        # Put efficiency information.
+        t, _ = net.getPerfProfile()
+        label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
+        cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
+
+        # Print predicted class.
+        label = '%s: %.4f' % (classes[classId] if classes else 'Class #%d' % classId, confidence)
+        cv.putText(frame, label, (0, 40), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
+
+        cv.imshow(winName, frame)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/common/abstract_model.py b/samples/dnn/dnn_model_runner/dnn_conversion/common/abstract_model.py
new file mode 100644 (file)
index 0000000..c968c53
--- /dev/null
@@ -0,0 +1,23 @@
+from abc import ABC, ABCMeta, abstractmethod
+
+
+class AbstractModel(ABC):
+
+    @abstractmethod
+    def get_prepared_models(self):
+        pass
+
+
+class Framework(metaclass=ABCMeta):
+    # note: in Python 3 the metaclass must be declared in the class signature;
+    # a bare __metaclass__ attribute would be silently ignored
+    in_blob_name = ''
+    out_blob_name = ''
+
+    @abstractmethod
+    def get_name(self):
+        pass
+
+    @abstractmethod
+    def get_output(self, input_blob):
+        pass
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/common/evaluation/classification/cls_accuracy_evaluator.py b/samples/dnn/dnn_model_runner/dnn_conversion/common/evaluation/classification/cls_accuracy_evaluator.py
new file mode 100644 (file)
index 0000000..5028f92
--- /dev/null
@@ -0,0 +1,96 @@
+import sys
+import time
+
+import numpy as np
+
+from ...utils import get_final_summary_info
+
+
+class ClsAccEvaluation:
+    log = sys.stdout
+    img_classes = {}
+    batch_size = 0
+
+    def __init__(self, log_path, img_classes_file, batch_size):
+        self.log = open(log_path, 'w')
+        self.img_classes = self.read_classes(img_classes_file)
+        self.batch_size = batch_size
+
+        # collect the accuracies for both models
+        self.general_quality_metric = []
+        self.general_inference_time = []
+
+    @staticmethod
+    def read_classes(img_classes_file):
+        result = {}
+        with open(img_classes_file) as file:
+            for line in file:
+                tokens = line.split()
+                result[tokens[0]] = int(tokens[1])
+        return result
+
+    def get_correct_answers(self, img_list, net_output_blob):
+        correct_answers = 0
+        for i in range(len(img_list)):
+            indexes = np.argsort(net_output_blob[i])[-5:]
+            correct_index = self.img_classes[img_list[i]]
+            if correct_index in indexes:
+                correct_answers += 1
+        return correct_answers
+
+    def process(self, frameworks, data_fetcher):
+        sorted_imgs_names = sorted(self.img_classes.keys())
+        correct_answers = [0] * len(frameworks)
+        samples_handled = 0
+        blobs_l1_diff = [0] * len(frameworks)
+        blobs_l1_diff_count = [0] * len(frameworks)
+        blobs_l_inf_diff = [0.0] * len(frameworks)
+        inference_time = [0.0] * len(frameworks)
+
+        for x in range(0, len(sorted_imgs_names), self.batch_size):
+            sublist = sorted_imgs_names[x:x + self.batch_size]
+            batch = data_fetcher.get_batch(sublist)
+
+            samples_handled += len(sublist)
+            fw_accuracy = []
+            fw_time = []
+            frameworks_out = []
+            for i in range(len(frameworks)):
+                start = time.time()
+                out = frameworks[i].get_output(batch)
+                end = time.time()
+                correct_answers[i] += self.get_correct_answers(sublist, out)
+                fw_accuracy.append(100 * correct_answers[i] / float(samples_handled))
+                frameworks_out.append(out)
+                inference_time[i] += end - start
+                fw_time.append(inference_time[i] / samples_handled * 1000)
+                print(samples_handled, 'Accuracy for', frameworks[i].get_name() + ':', fw_accuracy[i], file=self.log)
+                print("Inference time, ms ", frameworks[i].get_name(), fw_time[i], file=self.log)
+
+            # store one accuracy/time entry per batch, after all frameworks ran
+            self.general_quality_metric.append(fw_accuracy)
+            self.general_inference_time.append(fw_time)
+
+            for i in range(1, len(frameworks)):
+                log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'
+                diff = np.abs(frameworks_out[0] - frameworks_out[i])
+                l1_diff = np.sum(diff) / diff.size
+                print(samples_handled, "L1 difference", log_str, l1_diff, file=self.log)
+                blobs_l1_diff[i] += l1_diff
+                blobs_l1_diff_count[i] += 1
+                if np.max(diff) > blobs_l_inf_diff[i]:
+                    blobs_l_inf_diff[i] = np.max(diff)
+                print(samples_handled, "L_INF difference", log_str, blobs_l_inf_diff[i], file=self.log)
+
+            self.log.flush()
+
+        for i in range(1, len(blobs_l1_diff)):
+            log_str = frameworks[0].get_name() + " vs " + frameworks[i].get_name() + ':'
+            print('Final l1 diff', log_str, blobs_l1_diff[i] / blobs_l1_diff_count[i], file=self.log)
+
+        print(
+            get_final_summary_info(
+                self.general_quality_metric,
+                self.general_inference_time,
+                "accuracy"
+            ),
+            file=self.log
+        )
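+
+
+# Hedged usage sketch (illustrative names; not executed by the runner):
+# the evaluator drives a list of Framework-like processors over batches
+# produced by a DataFetch-like object.
+#
+#     evaluator = ClsAccEvaluation("eval_log.txt", "val.txt", batch_size=1)
+#     evaluator.process([original_processor, dnn_processor], data_fetcher)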
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/common/evaluation/classification/cls_data_fetcher.py b/samples/dnn/dnn_model_runner/dnn_conversion/common/evaluation/classification/cls_data_fetcher.py
new file mode 100644 (file)
index 0000000..805f540
--- /dev/null
@@ -0,0 +1,87 @@
+import os
+from abc import ABC, abstractmethod
+
+import cv2
+import numpy as np
+
+from ...img_utils import read_rgb_img, get_pytorch_preprocess
+from ...test.configs.default_preprocess_config import PYTORCH_RSZ_HEIGHT, PYTORCH_RSZ_WIDTH
+
+
+class DataFetch(ABC):
+    imgs_dir = ''
+    frame_size = 0
+    bgr_to_rgb = False
+
+    @abstractmethod
+    def preprocess(self, img):
+        pass
+
+    @staticmethod
+    def reshape_img(img):
+        img = img[:, :, 0:3].transpose(2, 0, 1)
+        return np.expand_dims(img, 0)
+
+    def center_crop(self, img):
+        cols = img.shape[1]
+        rows = img.shape[0]
+
+        y1 = round((rows - self.frame_size) / 2)
+        y2 = round(y1 + self.frame_size)
+        x1 = round((cols - self.frame_size) / 2)
+        x2 = round(x1 + self.frame_size)
+        return img[y1:y2, x1:x2]
+
+    def initial_preprocess(self, img):
+        min_dim = min(img.shape[-3], img.shape[-2])
+        resize_ratio = self.frame_size / float(min_dim)
+
+        img = cv2.resize(img, (0, 0), fx=resize_ratio, fy=resize_ratio)
+        img = self.center_crop(img)
+        return img
+
+    def get_preprocessed_img(self, img_path):
+        image_data = read_rgb_img(img_path, self.bgr_to_rgb)
+        image_data = self.preprocess(image_data)
+        return self.reshape_img(image_data)
+
+    def get_batch(self, img_names):
+        assert type(img_names) is list
+        batch = np.zeros((len(img_names), 3, self.frame_size, self.frame_size), dtype=np.float32)
+
+        for i in range(len(img_names)):
+            img_name = img_names[i]
+            img_file = os.path.join(self.imgs_dir, img_name)
+            assert os.path.exists(img_file)
+
+            batch[i] = self.get_preprocessed_img(img_file)
+        return batch
+
+
+class PyTorchPreprocessedFetch(DataFetch):
+    def __init__(self, pytorch_cls_config, preprocess_input=None):
+        self.imgs_dir = pytorch_cls_config.img_root_dir
+        self.frame_size = pytorch_cls_config.frame_size
+        self.bgr_to_rgb = pytorch_cls_config.bgr_to_rgb
+        self.preprocess_input = preprocess_input
+
+    def preprocess(self, img):
+        img = cv2.resize(img, (PYTORCH_RSZ_WIDTH, PYTORCH_RSZ_HEIGHT))
+        img = self.center_crop(img)
+        if self.preprocess_input:
+            return self.preprocess_input(img)
+        return get_pytorch_preprocess(img)
+
+
+class TFPreprocessedFetch(DataFetch):
+    def __init__(self, tf_cls_config, preprocess_input):
+        self.imgs_dir = tf_cls_config.img_root_dir
+        self.frame_size = tf_cls_config.frame_size
+        self.bgr_to_rgb = tf_cls_config.bgr_to_rgb
+        self.preprocess_input = preprocess_input
+
+    def preprocess(self, img):
+        img = self.initial_preprocess(img)
+        return self.preprocess_input(img)
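+
+
+# Illustrative sketch (assumes the default TestClsConfig paths exist):
+#
+#     from ...test.configs.test_config import TestClsConfig
+#     fetcher = PyTorchPreprocessedFetch(TestClsConfig())
+#     batch = fetcher.get_batch(["ILSVRC2012_val_00000001.JPEG"])
+#     # batch.shape == (1, 3, 224, 224), ready for both PyTorch and OpenCV DNN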
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/common/img_utils.py b/samples/dnn/dnn_model_runner/dnn_conversion/common/img_utils.py
new file mode 100644 (file)
index 0000000..3e17ec8
--- /dev/null
@@ -0,0 +1,19 @@
+import cv2
+import numpy as np
+
+from .test.configs.default_preprocess_config import BASE_IMG_SCALE_FACTOR
+
+
+def read_rgb_img(img_file, is_bgr_to_rgb=True):
+    img = cv2.imread(img_file, cv2.IMREAD_COLOR)
+    if is_bgr_to_rgb:
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+
+def get_pytorch_preprocess(img):
+    img = img.astype(np.float32)
+    img *= BASE_IMG_SCALE_FACTOR
+    img -= [0.485, 0.456, 0.406]
+    img /= [0.229, 0.224, 0.225]
+    return img
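+
+
+# The two steps above implement the standard torchvision ImageNet
+# normalization, channel-wise: x_norm = (x / 255 - mean) / std.
+# Minimal numeric check (illustrative):
+#
+#     >>> px = np.array([[[123.675, 116.28, 103.53]]], dtype=np.float32)
+#     >>> np.allclose(get_pytorch_preprocess(px), 0.0, atol=1e-4)
+#     True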
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/common/test/cls_model_test_pipeline.py b/samples/dnn/dnn_model_runner/dnn_conversion/common/test/cls_model_test_pipeline.py
new file mode 100644 (file)
index 0000000..757d513
--- /dev/null
@@ -0,0 +1,60 @@
+from .configs.test_config import TestClsConfig, TestClsModuleConfig
+from .model_test_pipeline import ModelTestPipeline
+from ..evaluation.classification.cls_accuracy_evaluator import ClsAccEvaluation
+from ..utils import get_test_module
+
+
+class ClsModelTestPipeline(ModelTestPipeline):
+    def __init__(
+            self,
+            network_model,
+            model_processor,
+            dnn_model_processor,
+            data_fetcher,
+            img_processor=None,
+            cls_args_parser=None,
+            default_input_blob_preproc=None
+    ):
+        super(ClsModelTestPipeline, self).__init__(
+            network_model,
+            model_processor,
+            dnn_model_processor
+        )
+
+        if cls_args_parser:
+            self._parser = cls_args_parser
+
+        self.test_config = TestClsConfig()
+
+        parser_args = self._parser.parse_args()
+
+        if parser_args.test:
+            self._test_module_config = TestClsModuleConfig()
+            self._test_module = get_test_module(
+                self._test_module_config.test_module_name,
+                self._test_module_config.test_module_path
+            )
+
+            if parser_args.default_img_preprocess:
+                self._default_input_blob_preproc = default_input_blob_preproc
+        if parser_args.evaluate:
+            self._data_fetcher = data_fetcher(self.test_config, img_processor)
+
+    def _configure_test_module_params(self):
+        self._test_module_param_list.extend((
+            '--crop', self._test_module_config.crop,
+            '--std', *self._test_module_config.std
+        ))
+
+        if self._test_module_config.rsz_height and self._test_module_config.rsz_width:
+            self._test_module_param_list.extend((
+                '--initial_height', self._test_module_config.rsz_height,
+                '--initial_width', self._test_module_config.rsz_width,
+            ))
+
+    def _configure_acc_eval(self, log_path):
+        self._accuracy_evaluator = ClsAccEvaluation(
+            log_path,
+            self.test_config.img_cls_file,
+            self.test_config.batch_size
+        )
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/default_preprocess_config.py b/samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/default_preprocess_config.py
new file mode 100644 (file)
index 0000000..e7ade91
--- /dev/null
@@ -0,0 +1,37 @@
+BASE_IMG_SCALE_FACTOR = 1 / 255.0
+PYTORCH_RSZ_HEIGHT = 256
+PYTORCH_RSZ_WIDTH = 256
+
+pytorch_resize_input_blob = {
+    "mean": ["123.675", "116.28", "103.53"],
+    "scale": str(BASE_IMG_SCALE_FACTOR),
+    "std": ["0.229", "0.224", "0.225"],
+    "crop": "True",
+    "rgb": True,
+    "rsz_height": str(PYTORCH_RSZ_HEIGHT),
+    "rsz_width": str(PYTORCH_RSZ_WIDTH)
+}
+
+pytorch_input_blob = {
+    "mean": ["123.675", "116.28", "103.53"],
+    "scale": str(BASE_IMG_SCALE_FACTOR),
+    "std": ["0.229", "0.224", "0.225"],
+    "crop": "True",
+    "rgb": True
+}
+
+tf_input_blob = {
+    "scale": str(1 / 127.5),
+    "mean": ["127.5", "127.5", "127.5"],
+    "std": [],
+    "crop": "True",
+    "rgb": True
+}
+
+tf_model_blob_caffe_mode = {
+    "mean": ["103.939", "116.779", "123.68"],
+    "scale": "1.0",
+    "std": [],
+    "crop": "True",
+    "rgb": False
+}
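+
+# Note: the PyTorch mean values above are the torchvision ImageNet means in
+# pixel scale (0.485 * 255 = 123.675, 0.456 * 255 = 116.28,
+# 0.406 * 255 = 103.53). blobFromImage subtracts the mean first and then
+# multiplies by "scale", so the std division has to be applied to the blob
+# as a separate step (see the --std handling in classification.py).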
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/test_config.py b/samples/dnn/dnn_model_runner/dnn_conversion/common/test/configs/test_config.py
new file mode 100644 (file)
index 0000000..5568091
--- /dev/null
@@ -0,0 +1,40 @@
+import os
+from dataclasses import dataclass, field
+from typing import List
+
+
+@dataclass
+class CommonConfig:
+    output_data_root_dir: str = "dnn_model_runner/dnn_conversion"
+    logs_dir: str = os.path.join(output_data_root_dir, "logs")
+    log_file_path: str = os.path.join(logs_dir, "{}_log.txt")
+
+
+@dataclass
+class TestClsConfig:
+    batch_size: int = 1
+    frame_size: int = 224
+    img_root_dir: str = "./ILSVRC2012_img_val"
+    # path to the image-to-class mapping file
+    img_cls_file: str = "./val.txt"
+    bgr_to_rgb: bool = True
+
+
+@dataclass
+class TestClsModuleConfig:
+    cls_test_data_dir: str = "../data"
+    test_module_name: str = "classification"
+    test_module_path: str = "classification.py"
+    input_img: str = os.path.join(cls_test_data_dir, "squirrel_cls.jpg")
+    model: str = ""
+
+    frame_height: str = str(TestClsConfig.frame_size)
+    frame_width: str = str(TestClsConfig.frame_size)
+    scale: str = "1.0"
+    mean: List[str] = field(default_factory=lambda: ["0.0", "0.0", "0.0"])
+    std: List[str] = field(default_factory=list)
+    crop: str = "False"
+    rgb: str = "True"
+    rsz_height: str = ""
+    rsz_width: str = ""
+    classes: str = os.path.join(cls_test_data_dir, "dnn", "classification_classes_ILSVRC2012.txt")
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/common/test/model_test_pipeline.py b/samples/dnn/dnn_model_runner/dnn_conversion/common/test/model_test_pipeline.py
new file mode 100644 (file)
index 0000000..38b9c38
--- /dev/null
@@ -0,0 +1,126 @@
+import os
+
+import numpy as np
+
+from .configs.test_config import CommonConfig
+from ..utils import create_parser, plot_acc
+
+
+class ModelTestPipeline:
+    def __init__(
+            self,
+            network_model,
+            model_processor,
+            dnn_model_processor
+    ):
+        self._net_model = network_model
+        self._model_processor = model_processor
+        self._dnn_model_processor = dnn_model_processor
+
+        self._parser = create_parser()
+
+        self._test_module = None
+        self._test_module_config = None
+        self._test_module_param_list = None
+
+        self.test_config = None
+        self._data_fetcher = None
+
+        self._default_input_blob_preproc = None
+        self._accuracy_evaluator = None
+
+    def init_test_pipeline(self):
+        cmd_args = self._parser.parse_args()
+        model_dict = self._net_model.get_prepared_models()
+
+        model_names = list(model_dict.keys())
+        print(
+            "The model {} was successfully obtained and converted to OpenCV {}".format(model_names[0], model_names[1])
+        )
+
+        if cmd_args.test:
+            if not self._test_module_config.model:
+                self._test_module_config.model = self._net_model.model_path["full_path"]
+
+            if cmd_args.default_img_preprocess:
+                self._test_module_config.scale = self._default_input_blob_preproc["scale"]
+                self._test_module_config.mean = self._default_input_blob_preproc["mean"]
+                self._test_module_config.std = self._default_input_blob_preproc["std"]
+                self._test_module_config.crop = self._default_input_blob_preproc["crop"]
+
+                if "rsz_height" in self._default_input_blob_preproc and "rsz_width" in self._default_input_blob_preproc:
+                    self._test_module_config.rsz_height = self._default_input_blob_preproc["rsz_height"]
+                    self._test_module_config.rsz_width = self._default_input_blob_preproc["rsz_width"]
+
+                self._test_module_param_list = [
+                    '--model', self._test_module_config.model,
+                    '--input', self._test_module_config.input_img,
+                    '--width', self._test_module_config.frame_width,
+                    '--height', self._test_module_config.frame_height,
+                    '--scale', self._test_module_config.scale,
+                    '--mean', *self._test_module_config.mean,
+                    '--std', *self._test_module_config.std,
+                    '--classes', self._test_module_config.classes,
+                ]
+
+                if self._default_input_blob_preproc["rgb"]:
+                    self._test_module_param_list.append('--rgb')
+
+                self._configure_test_module_params()
+
+            self._test_module.main(
+                self._test_module_param_list
+            )
+
+        if cmd_args.evaluate:
+            original_model_name = model_names[0]
+            dnn_model_name = model_names[1]
+
+            self.run_test_pipeline(
+                [
+                    self._model_processor(model_dict[original_model_name], original_model_name),
+                    self._dnn_model_processor(model_dict[dnn_model_name], dnn_model_name)
+                ],
+                original_model_name.replace(" ", "_")
+            )
+
+    def run_test_pipeline(
+            self,
+            models_list,
+            formatted_exp_name,
+            is_plot_acc=True
+    ):
+        log_path, logs_dir = self._configure_eval_log(formatted_exp_name)
+
+        print(
+            "===== Running evaluation of the model with the following params:\n"
+            "\t* val data location: {}\n"
+            "\t* log file location: {}\n".format(
+                self.test_config.img_root_dir,
+                log_path
+            )
+        )
+
+        os.makedirs(logs_dir, exist_ok=True)
+
+        self._configure_acc_eval(log_path)
+        self._accuracy_evaluator.process(models_list, self._data_fetcher)
+
+        if is_plot_acc:
+            plot_acc(
+                np.array(self._accuracy_evaluator.general_inference_time),
+                formatted_exp_name
+            )
+
+        print("===== End of the evaluation pipeline =====")
+
+    def _configure_acc_eval(self, log_path):
+        pass
+
+    def _configure_test_module_params(self):
+        pass
+
+    @staticmethod
+    def _configure_eval_log(formatted_exp_name):
+        common_test_config = CommonConfig()
+        return common_test_config.log_file_path.format(formatted_exp_name), common_test_config.logs_dir
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/common/utils.py b/samples/dnn/dnn_model_runner/dnn_conversion/common/utils.py
new file mode 100644 (file)
index 0000000..cf24dd3
--- /dev/null
@@ -0,0 +1,153 @@
+import argparse
+import importlib.util
+import os
+import random
+
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+import torch
+
+from .test.configs.test_config import CommonConfig
+
+SEED_VAL = 42
+DNN_LIB = "DNN"
+# common path for model savings
+MODEL_PATH_ROOT = os.path.join(CommonConfig().output_data_root_dir, "{}/models")
+
+
+def get_full_model_path(lib_name, model_full_name):
+    model_path = MODEL_PATH_ROOT.format(lib_name)
+    return {
+        "path": model_path,
+        "full_path": os.path.join(model_path, model_full_name)
+    }
+
+
+def plot_acc(data_list, experiment_name):
+    plt.figure(figsize=[8, 6])
+    plt.plot(data_list[:, 0], "r", linewidth=2.5, label="Original Model")
+    plt.plot(data_list[:, 1], "b", linewidth=2.5, label="Converted DNN Model")
+    plt.xlabel("Iterations ", fontsize=15)
+    plt.ylabel("Time (ms)", fontsize=15)
+    plt.title(experiment_name, fontsize=15)
+    plt.legend()
+    full_path_to_fig = os.path.join(CommonConfig().output_data_root_dir, experiment_name + ".png")
+    plt.savefig(full_path_to_fig, bbox_inches="tight")
+
+
+def get_final_summary_info(general_quality_metric, general_inference_time, metric_name):
+    general_quality_metric = np.array(general_quality_metric)
+    general_inference_time = np.array(general_inference_time)
+    summary_line = "===== End of processing. General results:\n"
+    "\t* mean {} for the original model: {}\t"
+    "\t* mean time (min) for the original model inferences: {}\n"
+    "\t* mean {} for the DNN model: {}\t"
+    "\t* mean time (min) for the DNN model inferences: {}\n".format(
+        metric_name, np.mean(general_quality_metric[:, 0]),
+        np.mean(general_inference_time[:, 0]) / 60000,
+        metric_name, np.mean(general_quality_metric[:, 1]),
+        np.mean(general_inference_time[:, 1]) / 60000,
+    )
+    return summary_line
+
+
+def set_common_reproducibility():
+    random.seed(SEED_VAL)
+    np.random.seed(SEED_VAL)
+
+
+def set_pytorch_env():
+    set_common_reproducibility()
+    torch.manual_seed(SEED_VAL)
+    torch.set_printoptions(precision=10)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(SEED_VAL)
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+
+
+def set_tf_env(is_use_gpu=True):
+    set_common_reproducibility()
+    tf.random.set_seed(SEED_VAL)
+    os.environ["TF_DETERMINISTIC_OPS"] = "1"
+
+    if tf.config.list_physical_devices("GPU") and is_use_gpu:
+        gpu_devices = tf.config.list_physical_devices("GPU")
+        tf.config.experimental.set_visible_devices(gpu_devices[0], "GPU")
+        tf.config.experimental.set_memory_growth(gpu_devices[0], True)
+        os.environ["TF_USE_CUDNN"] = "1"
+    else:
+        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+
+def str_bool(input_val):
+    if input_val.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif input_val.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value was expected')
+
+
+def get_formatted_model_list(model_list):
+    note_line = 'Please choose a model from the list below:\n'
+    spaces_to_set = ' ' * (len(note_line) - 2)
+    return note_line + ''.join([spaces_to_set, '{} \n'] * len(model_list)).format(*model_list)
+
+
+def model_str(model_list):
+    def type_model_list(input_val):
+        if input_val.lower() in model_list:
+            return input_val.lower()
+        else:
+            raise argparse.ArgumentTypeError(
+                'This model is currently unavailable for testing.\n' +
+                get_formatted_model_list(model_list)
+            )
+
+    return type_model_list
+
+
+def get_test_module(test_module_name, test_module_path):
+    module_spec = importlib.util.spec_from_file_location(test_module_name, test_module_path)
+    test_module = importlib.util.module_from_spec(module_spec)
+    module_spec.loader.exec_module(test_module)
+    return test_module
+
+
+def create_parser():
+    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
+    parser.add_argument(
+        "--test",
+        type=str_bool,
+        help="Define whether you'd like to run the model with OpenCV for testing.",
+        default=False
+    )
+    parser.add_argument(
+        "--default_img_preprocess",
+        type=str_bool,
+        help="Define whether you'd like to preprocess the input image with defined"
+             " PyTorch or TF functions for model test with OpenCV.",
+        default=False
+    )
+    parser.add_argument(
+        "--evaluate",
+        type=str_bool,
+        help="Define whether you'd like to run evaluation of the models (ex.: TF vs OpenCV networks).",
+        default=True
+    )
+    return parser
+
+
+def create_extended_parser(model_list):
+    parser = create_parser()
+    parser.add_argument(
+        "--model_name",
+        type=model_str(model_list=model_list),
+        help="\nDefine the model name to test.\n" +
+             get_formatted_model_list(model_list),
+        required=True
+    )
+    return parser
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/pytorch/classification/py_to_py_cls.py b/samples/dnn/dnn_model_runner/dnn_conversion/pytorch/classification/py_to_py_cls.py
new file mode 100644 (file)
index 0000000..abbd738
--- /dev/null
@@ -0,0 +1,71 @@
+from torchvision import models
+
+from ..pytorch_model import (
+    PyTorchModelPreparer,
+    PyTorchModelProcessor,
+    PyTorchDnnModelProcessor
+)
+from ...common.evaluation.classification.cls_data_fetcher import PyTorchPreprocessedFetch
+from ...common.test.cls_model_test_pipeline import ClsModelTestPipeline
+from ...common.test.configs.default_preprocess_config import pytorch_resize_input_blob
+from ...common.test.configs.test_config import TestClsConfig
+from ...common.utils import set_pytorch_env, create_extended_parser
+
+model_dict = {
+    "alexnet": models.alexnet,
+
+    "vgg11": models.vgg11,
+    "vgg13": models.vgg13,
+    "vgg16": models.vgg16,
+    "vgg19": models.vgg19,
+
+    "resnet18": models.resnet18,
+    "resnet34": models.resnet34,
+    "resnet50": models.resnet50,
+    "resnet101": models.resnet101,
+    "resnet152": models.resnet152,
+
+    "squeezenet1_0": models.squeezenet1_0,
+    "squeezenet1_1": models.squeezenet1_1,
+
+    "resnext50_32x4d": models.resnext50_32x4d,
+    "resnext101_32x8d": models.resnext101_32x8d,
+
+    "wide_resnet50_2": models.wide_resnet50_2,
+    "wide_resnet101_2": models.wide_resnet101_2
+}
+
+
+class PyTorchClsModel(PyTorchModelPreparer):
+    def __init__(self, height, width, model_name, original_model):
+        super(PyTorchClsModel, self).__init__(height, width, model_name, original_model)
+
+
+def main():
+    set_pytorch_env()
+
+    parser = create_extended_parser(list(model_dict.keys()))
+    cmd_args = parser.parse_args()
+    model_name = cmd_args.model_name
+
+    cls_model = PyTorchClsModel(
+        height=TestClsConfig().frame_size,
+        width=TestClsConfig().frame_size,
+        model_name=model_name,
+        original_model=model_dict[model_name](pretrained=True)
+    )
+
+    pytorch_cls_pipeline = ClsModelTestPipeline(
+        network_model=cls_model,
+        model_processor=PyTorchModelProcessor,
+        dnn_model_processor=PyTorchDnnModelProcessor,
+        data_fetcher=PyTorchPreprocessedFetch,
+        cls_args_parser=parser,
+        default_input_blob_preproc=pytorch_resize_input_blob
+    )
+
+    pytorch_cls_pipeline.init_test_pipeline()
+
+
+if __name__ == "__main__":
+    main()
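+
+
+# Illustrative invocation (an assumption: the working directory is
+# samples/dnn, so that the relative imports above resolve; the flags come
+# from create_parser/create_extended_parser in common/utils.py):
+#
+#     python -m dnn_model_runner.dnn_conversion.pytorch.classification.py_to_py_cls \
+#         --model_name resnet50 --test True --default_img_preprocess True --evaluate False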
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/pytorch/classification/py_to_py_resnet50.py b/samples/dnn/dnn_model_runner/dnn_conversion/pytorch/classification/py_to_py_resnet50.py
new file mode 100644 (file)
index 0000000..3a228bf
--- /dev/null
@@ -0,0 +1,139 @@
+import os
+
+import cv2
+import numpy as np
+import torch
+import torch.onnx
+from torch.autograd import Variable
+from torchvision import models
+
+
+def get_pytorch_onnx_model(original_model):
+    # directory where the converted model will be saved
+    onnx_model_path = "models"
+    # name of the converted model file
+    onnx_model_name = "resnet50.onnx"
+
+    # create the directory for the converted model
+    os.makedirs(onnx_model_path, exist_ok=True)
+
+    # get full path to the converted model
+    full_model_path = os.path.join(onnx_model_path, onnx_model_name)
+
+    # generate model input
+    generated_input = Variable(
+        torch.randn(1, 3, 224, 224)
+    )
+
+    # model export into ONNX format
+    torch.onnx.export(
+        original_model,
+        generated_input,
+        full_model_path,
+        verbose=True,
+        input_names=["input"],
+        output_names=["output"],
+        opset_version=11
+    )
+
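+    # Optional sanity check (illustrative; assumes the onnx package listed
+    # in requirements.txt is installed):
+    #     import onnx
+    #     onnx.checker.check_model(onnx.load(full_model_path))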
+    return full_model_path
+
+
+def get_preprocessed_img(img_path):
+    # read the image
+    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+    input_img = input_img.astype(np.float32)
+
+    input_img = cv2.resize(input_img, (256, 256))
+
+    # define preprocess parameters
+    mean = np.array([0.485, 0.456, 0.406]) * 255.0
+    scale = 1 / 255.0
+    std = [0.229, 0.224, 0.225]
+
+    # prepare input blob to fit the model input:
+    # 1. subtract mean
+    # 2. scale by 1/255 (blobFromImage subtracts the mean first,
+    #    then multiplies by the scale factor)
+    input_blob = cv2.dnn.blobFromImage(
+        image=input_img,
+        scalefactor=scale,
+        size=(224, 224),  # img target size
+        mean=mean,
+        swapRB=True,  # BGR -> RGB
+        crop=True  # center crop
+    )
+    # 3. divide by std
+    input_blob[0] /= np.asarray(std, dtype=np.float32).reshape(3, 1, 1)
+    return input_blob
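+
+
+# The preprocessing above mirrors the usual torchvision ImageNet pipeline;
+# a hedged equivalent (illustrative, not used by this script) would be:
+#
+#     from torchvision import transforms
+#     preprocess = transforms.Compose([
+#         transforms.Resize((256, 256)),  # cv2.resize above also ignores aspect ratio
+#         transforms.CenterCrop(224),
+#         transforms.ToTensor(),
+#         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+#                              std=[0.229, 0.224, 0.225]),
+#     ])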
+
+
+def get_imagenet_labels(labels_path):
+    with open(labels_path) as f:
+        imagenet_labels = [line.strip() for line in f.readlines()]
+    return imagenet_labels
+
+
+def get_opencv_dnn_prediction(opencv_net, preproc_img, imagenet_labels):
+    # set OpenCV DNN input
+    opencv_net.setInput(preproc_img)
+
+    # OpenCV DNN inference
+    out = opencv_net.forward()
+    print("OpenCV DNN prediction: \n")
+    print("* shape: ", out.shape)
+
+    # get the predicted class ID
+    imagenet_class_id = np.argmax(out)
+
+    # get confidence
+    confidence = out[0][imagenet_class_id]
+    print("* class ID: {}, label: {}".format(imagenet_class_id, imagenet_labels[imagenet_class_id]))
+    print("* confidence: {:.4f}".format(confidence))
+
+
+def get_pytorch_dnn_prediction(original_net, preproc_img, imagenet_labels):
+    original_net.eval()
+    preproc_img = torch.FloatTensor(preproc_img)
+
+    # inference
+    with torch.no_grad():
+        out = original_net(preproc_img)
+
+    print("\nPyTorch model prediction: \n")
+    print("* shape: ", out.shape)
+
+    # get the predicted class ID
+    imagenet_class_id = torch.argmax(out, dim=1).item()
+    print("* class ID: {}, label: {}".format(imagenet_class_id, imagenet_labels[imagenet_class_id]))
+
+    # get confidence
+    confidence = out[0][imagenet_class_id]
+    print("* confidence: {:.4f}".format(confidence.item()))
+
+
+def main():
+    # initialize PyTorch ResNet-50 model
+    original_model = models.resnet50(pretrained=True)
+
+    # convert the PyTorch model to ONNX and get the path to the result
+    full_model_path = get_pytorch_onnx_model(original_model)
+
+    # read converted .onnx model with OpenCV API
+    opencv_net = cv2.dnn.readNetFromONNX(full_model_path)
+    print("OpenCV model was successfully read. Layer IDs: \n", opencv_net.getLayerNames())
+
+    # get preprocessed image
+    input_img = get_preprocessed_img("../data/squirrel_cls.jpg")
+
+    # get ImageNet labels
+    imagenet_labels = get_imagenet_labels("../data/dnn/classification_classes_ILSVRC2012.txt")
+
+    # obtain OpenCV DNN predictions
+    get_opencv_dnn_prediction(opencv_net, input_img, imagenet_labels)
+
+    # obtain original PyTorch ResNet50 predictions
+    get_pytorch_dnn_prediction(original_model, input_img, imagenet_labels)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/pytorch/classification/py_to_py_resnet50_onnx.py b/samples/dnn/dnn_model_runner/dnn_conversion/pytorch/classification/py_to_py_resnet50_onnx.py
new file mode 100644 (file)
index 0000000..969c359
--- /dev/null
@@ -0,0 +1,50 @@
+import os
+
+import torch
+import torch.onnx
+from torch.autograd import Variable
+from torchvision import models
+
+
+def get_pytorch_onnx_model(original_model):
+    # directory where the converted model will be saved
+    onnx_model_path = "models"
+    # name of the converted model file
+    onnx_model_name = "resnet50.onnx"
+
+    # create the directory for the converted model
+    os.makedirs(onnx_model_path, exist_ok=True)
+
+    # get full path to the converted model
+    full_model_path = os.path.join(onnx_model_path, onnx_model_name)
+
+    # generate model input
+    generated_input = Variable(
+        torch.randn(1, 3, 224, 224)
+    )
+
+    # model export into ONNX format
+    torch.onnx.export(
+        original_model,
+        generated_input,
+        full_model_path,
+        verbose=True,
+        input_names=["input"],
+        output_names=["output"],
+        opset_version=11
+    )
+
+    return full_model_path
+
+
+def main():
+    # initialize PyTorch ResNet-50 model
+    original_model = models.resnet50(pretrained=True)
+
+    # convert the PyTorch model to ONNX and get the path to the result
+    full_model_path = get_pytorch_onnx_model(original_model)
+    print("PyTorch ResNet-50 model was successfully converted: ", full_model_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/pytorch/pytorch_model.py b/samples/dnn/dnn_model_runner/dnn_conversion/pytorch/pytorch_model.py
new file mode 100644 (file)
index 0000000..cb32004
--- /dev/null
@@ -0,0 +1,98 @@
+import os
+
+import cv2
+import torch.onnx
+from torch.autograd import Variable
+
+from ..common.abstract_model import AbstractModel, Framework
+from ..common.utils import DNN_LIB, get_full_model_path
+
+CURRENT_LIB = "PyTorch"
+MODEL_FORMAT = ".onnx"
+
+
+class PyTorchModelPreparer(AbstractModel):
+
+    def __init__(
+            self,
+            height,
+            width,
+            model_name="default",
+            original_model=None,
+            batch_size=1,
+            default_input_name="input",
+            default_output_name="output"
+    ):
+        self._height = height
+        self._width = width
+        self._model_name = model_name
+        self._original_model = original_model
+        self._batch_size = batch_size
+        self._default_input_name = default_input_name
+        self._default_output_name = default_output_name
+
+        self.model_path = self._set_model_path()
+        self._dnn_model = self._set_dnn_model()
+
+    def _set_dnn_model(self):
+        generated_input = Variable(torch.randn(
+            self._batch_size, 3, self._height, self._width)
+        )
+        os.makedirs(self.model_path["path"], exist_ok=True)
+        torch.onnx.export(
+            self._original_model,
+            generated_input,
+            self.model_path["full_path"],
+            verbose=True,
+            input_names=[self._default_input_name],
+            output_names=[self._default_output_name],
+            opset_version=11
+        )
+
+        return cv2.dnn.readNetFromONNX(self.model_path["full_path"])
+
+    def _set_model_path(self):
+        model_to_save = self._model_name + MODEL_FORMAT
+        return get_full_model_path(CURRENT_LIB.lower(), model_to_save)
+
+    def get_prepared_models(self):
+        return {
+            CURRENT_LIB + " " + self._model_name: self._original_model,
+            DNN_LIB + " " + self._model_name: self._dnn_model
+        }
+
+
+class PyTorchModelProcessor(Framework):
+    def __init__(self, prepared_model, model_name):
+        self._prepared_model = prepared_model
+        self._name = model_name
+
+    def get_output(self, input_blob):
+        tensor = torch.FloatTensor(input_blob)
+        self._prepared_model.eval()
+
+        with torch.no_grad():
+            model_out = self._prepared_model(tensor)
+
+        # segmentation models return a dict of outputs; take the main one
+        if isinstance(model_out, dict):
+            model_out = model_out['out']
+
+        out = model_out.detach().numpy()
+        return out
+
+    def get_name(self):
+        return self._name
+
+
+class PyTorchDnnModelProcessor(Framework):
+    def __init__(self, prepared_dnn_model, model_name):
+        self._prepared_dnn_model = prepared_dnn_model
+        self._name = model_name
+
+    def get_output(self, input_blob):
+        self._prepared_dnn_model.setInput(input_blob, '')
+        return self._prepared_dnn_model.forward()
+
+    def get_name(self):
+        return self._name
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/requirements.txt b/samples/dnn/dnn_model_runner/dnn_conversion/requirements.txt
new file mode 100644 (file)
index 0000000..65ab56a
--- /dev/null
@@ -0,0 +1,9 @@
+# Python 3.7.5
+onnx>=1.7.0
+numpy>=1.19.1
+
+torch>=1.5.1
+torchvision>=0.6.1
+
+tensorflow>=2.1.0
+tensorflow-gpu>=2.1.0
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/tf/classification/py_to_py_cls.py b/samples/dnn/dnn_model_runner/dnn_conversion/tf/classification/py_to_py_cls.py
new file mode 100644 (file)
index 0000000..0eabe87
--- /dev/null
@@ -0,0 +1,104 @@
+from tensorflow.keras.applications import (
+    VGG16, vgg16,
+    VGG19, vgg19,
+
+    ResNet50, resnet,
+    ResNet101,
+    ResNet152,
+
+    DenseNet121, densenet,
+    DenseNet169,
+    DenseNet201,
+
+    InceptionResNetV2, inception_resnet_v2,
+    InceptionV3, inception_v3,
+
+    MobileNet, mobilenet,
+    MobileNetV2, mobilenet_v2,
+
+    NASNetLarge, nasnet,
+    NASNetMobile,
+
+    Xception, xception
+)
+
+from ..tf_model import TFModelPreparer
+from ..tf_model import (
+    TFModelProcessor,
+    TFDnnModelProcessor
+)
+from ...common.evaluation.classification.cls_data_fetcher import TFPreprocessedFetch
+from ...common.test.cls_model_test_pipeline import ClsModelTestPipeline
+from ...common.test.configs.default_preprocess_config import (
+    tf_input_blob,
+    pytorch_input_blob,
+    tf_model_blob_caffe_mode
+)
+from ...common.utils import set_tf_env, create_extended_parser
+
+model_dict = {
+    "vgg16": [VGG16, vgg16, tf_model_blob_caffe_mode],
+    "vgg19": [VGG19, vgg19, tf_model_blob_caffe_mode],
+
+    "resnet50": [ResNet50, resnet, tf_model_blob_caffe_mode],
+    "resnet101": [ResNet101, resnet, tf_model_blob_caffe_mode],
+    "resnet152": [ResNet152, resnet, tf_model_blob_caffe_mode],
+
+    "densenet121": [DenseNet121, densenet, pytorch_input_blob],
+    "densenet169": [DenseNet169, densenet, pytorch_input_blob],
+    "densenet201": [DenseNet201, densenet, pytorch_input_blob],
+
+    "inceptionresnetv2": [InceptionResNetV2, inception_resnet_v2, tf_input_blob],
+    "inceptionv3": [InceptionV3, inception_v3, tf_input_blob],
+
+    "mobilenet": [MobileNet, mobilenet, tf_input_blob],
+    "mobilenetv2": [MobileNetV2, mobilenet_v2, tf_input_blob],
+
+    "nasnetlarge": [NASNetLarge, nasnet, tf_input_blob],
+    "nasnetmobile": [NASNetMobile, nasnet, tf_input_blob],
+
+    "xception": [Xception, xception, tf_input_blob]
+}
+
+CNN_CLASS_ID = 0
+CNN_UTILS_ID = 1
+DEFAULT_BLOB_PARAMS_ID = 2
+
+
+class TFClsModel(TFModelPreparer):
+    def __init__(self, model_name, original_model):
+        super(TFClsModel, self).__init__(model_name, original_model)
+
+
+def main():
+    set_tf_env()
+
+    parser = create_extended_parser(list(model_dict.keys()))
+    cmd_args = parser.parse_args()
+
+    model_name = cmd_args.model_name
+    model_name_val = model_dict[model_name]
+
+    cls_model = TFClsModel(
+        model_name=model_name,
+        original_model=model_name_val[CNN_CLASS_ID](
+            include_top=True,
+            weights="imagenet"
+        )
+    )
+
+    tf_cls_pipeline = ClsModelTestPipeline(
+        network_model=cls_model,
+        model_processor=TFModelProcessor,
+        dnn_model_processor=TFDnnModelProcessor,
+        data_fetcher=TFPreprocessedFetch,
+        img_processor=model_name_val[CNN_UTILS_ID].preprocess_input,
+        cls_args_parser=parser,
+        default_input_blob_preproc=model_name_val[DEFAULT_BLOB_PARAMS_ID]
+    )
+
+    tf_cls_pipeline.init_test_pipeline()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/tf/classification/py_to_py_mobilenet.py b/samples/dnn/dnn_model_runner/dnn_conversion/tf/classification/py_to_py_mobilenet.py
new file mode 100644 (file)
index 0000000..ebc6cfe
--- /dev/null
@@ -0,0 +1,142 @@
+import os
+
+import cv2
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.applications import MobileNet
+from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+
+from ...common.utils import set_tf_env
+
+
+def get_tf_model_proto(tf_model):
+    # define the directory for .pb model
+    pb_model_path = "models"
+
+    # define the name of .pb model
+    pb_model_name = "mobilenet.pb"
+
+    # create the directory for the converted model
+    os.makedirs(pb_model_path, exist_ok=True)
+
+    # get model TF graph
+    tf_model_graph = tf.function(lambda x: tf_model(x))
+
+    # get concrete function
+    tf_model_graph = tf_model_graph.get_concrete_function(
+        tf.TensorSpec(tf_model.inputs[0].shape, tf_model.inputs[0].dtype))
+
+    # obtain frozen concrete function
+    frozen_tf_func = convert_variables_to_constants_v2(tf_model_graph)
+    # get frozen graph
+    frozen_tf_func.graph.as_graph_def()
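+
+    # optional: the frozen signature can be inspected here (illustrative):
+    #     print("inputs: ", frozen_tf_func.inputs)
+    #     print("outputs:", frozen_tf_func.outputs)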
+
+    # save full tf model
+    tf.io.write_graph(graph_or_graph_def=frozen_tf_func.graph,
+                      logdir=pb_model_path,
+                      name=pb_model_name,
+                      as_text=False)
+
+    return os.path.join(pb_model_path, pb_model_name)
+
+
+def get_preprocessed_img(img_path):
+    # read the image
+    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+    input_img = input_img.astype(np.float32)
+
+    # define preprocess parameters
+    mean = np.array([1.0, 1.0, 1.0]) * 127.5
+    scale = 1 / 127.5
+
+    # prepare input blob to fit the model input:
+    # 1. subtract mean (127.5)
+    # 2. scale by 1/127.5 to map pixel values to [-1, 1]
+    input_blob = cv2.dnn.blobFromImage(
+        image=input_img,
+        scalefactor=scale,
+        size=(224, 224),  # img target size
+        mean=mean,
+        swapRB=True,  # BGR -> RGB
+        crop=True  # center crop
+    )
+    print("Input blob shape: {}\n".format(input_blob.shape))
+
+    return input_blob
+
+
+def get_imagenet_labels(labels_path):
+    with open(labels_path) as f:
+        imagenet_labels = [line.strip() for line in f.readlines()]
+    return imagenet_labels
+
+
+def get_opencv_dnn_prediction(opencv_net, preproc_img, imagenet_labels):
+    # set OpenCV DNN input
+    opencv_net.setInput(preproc_img)
+
+    # OpenCV DNN inference
+    out = opencv_net.forward()
+    print("OpenCV DNN prediction: \n")
+    print("* shape: ", out.shape)
+
+    # get the predicted class ID
+    imagenet_class_id = np.argmax(out)
+
+    # get confidence
+    confidence = out[0][imagenet_class_id]
+    print("* class ID: {}, label: {}".format(imagenet_class_id, imagenet_labels[imagenet_class_id]))
+    print("* confidence: {:.4f}\n".format(confidence))
+
+
+def get_tf_dnn_prediction(original_net, preproc_img, imagenet_labels):
+    # inference
+    preproc_img = preproc_img.transpose(0, 2, 3, 1)
+    print("TF input blob shape: {}\n".format(preproc_img.shape))
+
+    out = original_net(preproc_img)
+
+    print("\nTensorFlow model prediction: \n")
+    print("* shape: ", out.shape)
+
+    # get the predicted class ID
+    imagenet_class_id = np.argmax(out)
+    print("* class ID: {}, label: {}".format(imagenet_class_id, imagenet_labels[imagenet_class_id]))
+
+    # get confidence
+    confidence = out[0][imagenet_class_id]
+    print("* confidence: {:.4f}".format(confidence))
+
+
+def main():
+    # configure TF launching
+    set_tf_env()
+
+    # initialize TF MobileNet model
+    original_tf_model = MobileNet(
+        include_top=True,
+        weights="imagenet"
+    )
+
+    # get TF frozen graph path
+    full_pb_path = get_tf_model_proto(original_tf_model)
+
+    # read frozen graph with OpenCV API
+    opencv_net = cv2.dnn.readNetFromTensorflow(full_pb_path)
+    print("OpenCV model was successfully read. Model layers: \n", opencv_net.getLayerNames())
+
+    # get preprocessed image
+    input_img = get_preprocessed_img("../data/squirrel_cls.jpg")
+
+    # get ImageNet labels
+    imagenet_labels = get_imagenet_labels("../data/dnn/classification_classes_ILSVRC2012.txt")
+
+    # obtain OpenCV DNN predictions
+    get_opencv_dnn_prediction(opencv_net, input_img, imagenet_labels)
+
+    # obtain TF model predictions
+    get_tf_dnn_prediction(original_tf_model, input_img, imagenet_labels)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/samples/dnn/dnn_model_runner/dnn_conversion/tf/tf_model.py b/samples/dnn/dnn_model_runner/dnn_conversion/tf/tf_model.py
new file mode 100644 (file)
index 0000000..2411821
--- /dev/null
@@ -0,0 +1,112 @@
+import cv2
+import tensorflow as tf
+from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+
+from ..common.abstract_model import AbstractModel, Framework
+from ..common.utils import DNN_LIB, get_full_model_path
+
+CURRENT_LIB = "TF"
+MODEL_FORMAT = ".pb"
+
+
+class TFModelPreparer(AbstractModel):
+    """ Class for the preparation of the TF models: original and converted OpenCV Net.
+
+    Args:
+        model_name: TF model name
+        original_model: TF configured model object or session
+        is_ready_graph: indicates whether ready .pb file already exists
+        tf_model_graph_path: path to the existing frozen TF graph
+    """
+
+    def __init__(
+            self,
+            model_name="default",
+            original_model=None,
+            is_ready_graph=False,
+            tf_model_graph_path=""
+    ):
+        self._model_name = model_name
+        self._original_model = original_model
+        self._model_to_save = ""
+
+        self._is_ready_to_transfer_graph = is_ready_graph
+        self.model_path = self._set_model_path(tf_model_graph_path)
+        self._dnn_model = self._set_dnn_model()
+
+    def _set_dnn_model(self):
+        if not self._is_ready_to_transfer_graph:
+            # get model TF graph
+            tf_model_graph = tf.function(lambda x: self._original_model(x))
+
+            tf_model_graph = tf_model_graph.get_concrete_function(
+                tf.TensorSpec(self._original_model.inputs[0].shape, self._original_model.inputs[0].dtype))
+
+            # obtain frozen concrete function
+            frozen_tf_func = convert_variables_to_constants_v2(tf_model_graph)
+            frozen_tf_func.graph.as_graph_def()
+
+            # save full TF model
+            tf.io.write_graph(graph_or_graph_def=frozen_tf_func.graph,
+                              logdir=self.model_path["path"],
+                              name=self._model_to_save,
+                              as_text=False)
+
+        return cv2.dnn.readNetFromTensorflow(self.model_path["full_path"])
+
+    def _set_model_path(self, tf_pb_file_path):
+        """ Method for setting model paths.
+
+        Args:
+            tf_pb_file_path: path to the existing TF .pb
+
+        Returns:
+            dictionary in which the "full_path" key holds the path to the saved model, including its file name.
+        """
+        model_paths_dict = {
+            "path": "",
+            "full_path": tf_pb_file_path
+        }
+
+        if not self._is_ready_to_transfer_graph:
+            self._model_to_save = self._model_name + MODEL_FORMAT
+            model_paths_dict = get_full_model_path(CURRENT_LIB.lower(), self._model_to_save)
+
+        return model_paths_dict
+
+    def get_prepared_models(self):
+        original_lib_name = CURRENT_LIB + " " + self._model_name
+        configured_model_dict = {
+            original_lib_name: self._original_model,
+            DNN_LIB + " " + self._model_name: self._dnn_model
+        }
+        return configured_model_dict
+
+
+class TFModelProcessor(Framework):
+    def __init__(self, prepared_model, model_name):
+        self._prepared_model = prepared_model
+        self._name = model_name
+
+    def get_output(self, input_blob):
+        assert len(input_blob.shape) == 4
+        batch_tf = input_blob.transpose(0, 2, 3, 1)
+        out = self._prepared_model(batch_tf)
+        return out
+
+    def get_name(self):
+        return CURRENT_LIB
+
+
+class TFDnnModelProcessor(Framework):
+    def __init__(self, prepared_dnn_model, model_name):
+        self._prepared_dnn_model = prepared_dnn_model
+        self._name = model_name
+
+    def get_output(self, input_blob):
+        self._prepared_dnn_model.setInput(input_blob)
+        ret_val = self._prepared_dnn_model.forward()
+        return ret_val
+
+    def get_name(self):
+        return DNN_LIB