Imported Upstream version 6.1
[platform/upstream/ffmpeg.git] / libavfilter / vf_dnn_processing.c
1 /*
2  * Copyright (c) 2019 Guo Yejun
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20
21 /**
22  * @file
23  * implementing a generic image processing filter using deep learning networks.
24  */
25
26 #include "libavutil/opt.h"
27 #include "libavutil/pixdesc.h"
28 #include "libavutil/avassert.h"
29 #include "libavutil/imgutils.h"
30 #include "filters.h"
31 #include "dnn_filter_common.h"
32 #include "internal.h"
33 #include "video.h"
34 #include "libswscale/swscale.h"
35 #include "libavutil/time.h"
36
37 typedef struct DnnProcessingContext {
38     const AVClass *class;
39     DnnContext dnnctx;
40     struct SwsContext *sws_uv_scale;
41     int sws_uv_height;
42 } DnnProcessingContext;
43
44 #define OFFSET(x) offsetof(DnnProcessingContext, dnnctx.x)
45 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
46 static const AVOption dnn_processing_options[] = {
47     { "dnn_backend", "DNN backend",                OFFSET(backend_type),     AV_OPT_TYPE_INT,       { .i64 = DNN_TF },    INT_MIN, INT_MAX, FLAGS, "backend" },
48 #if (CONFIG_LIBTENSORFLOW == 1)
49     { "tensorflow",  "tensorflow backend flag",    0,                        AV_OPT_TYPE_CONST,     { .i64 = DNN_TF },    0, 0, FLAGS, "backend" },
50 #endif
51 #if (CONFIG_LIBOPENVINO == 1)
52     { "openvino",    "openvino backend flag",      0,                        AV_OPT_TYPE_CONST,     { .i64 = DNN_OV },    0, 0, FLAGS, "backend" },
53 #endif
54     DNN_COMMON_OPTIONS
55     { NULL }
56 };
57
58 AVFILTER_DEFINE_CLASS(dnn_processing);
59
60 static av_cold int init(AVFilterContext *context)
61 {
62     DnnProcessingContext *ctx = context->priv;
63     return ff_dnn_init(&ctx->dnnctx, DFT_PROCESS_FRAME, context);
64 }
65
66 static const enum AVPixelFormat pix_fmts[] = {
67     AV_PIX_FMT_RGB24, AV_PIX_FMT_BGR24,
68     AV_PIX_FMT_GRAY8, AV_PIX_FMT_GRAYF32,
69     AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P,
70     AV_PIX_FMT_YUV444P, AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P,
71     AV_PIX_FMT_NV12,
72     AV_PIX_FMT_NONE
73 };
74
75 #define LOG_FORMAT_CHANNEL_MISMATCH()                       \
76     av_log(ctx, AV_LOG_ERROR,                               \
77            "the frame's format %s does not match "          \
78            "the model input channel %d\n",                  \
79            av_get_pix_fmt_name(fmt),                        \
80            model_input->channels);
81
82 static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLink *inlink)
83 {
84     AVFilterContext *ctx   = inlink->dst;
85     enum AVPixelFormat fmt = inlink->format;
86
87     // the design is to add explicit scale filter before this filter
88     if (model_input->height != -1 && model_input->height != inlink->h) {
89         av_log(ctx, AV_LOG_ERROR, "the model requires frame height %d but got %d\n",
90                                    model_input->height, inlink->h);
91         return AVERROR(EIO);
92     }
93     if (model_input->width != -1 && model_input->width != inlink->w) {
94         av_log(ctx, AV_LOG_ERROR, "the model requires frame width %d but got %d\n",
95                                    model_input->width, inlink->w);
96         return AVERROR(EIO);
97     }
98     if (model_input->dt != DNN_FLOAT) {
99         avpriv_report_missing_feature(ctx, "data type rather than DNN_FLOAT");
100         return AVERROR(EIO);
101     }
102
103     switch (fmt) {
104     case AV_PIX_FMT_RGB24:
105     case AV_PIX_FMT_BGR24:
106         if (model_input->channels != 3) {
107             LOG_FORMAT_CHANNEL_MISMATCH();
108             return AVERROR(EIO);
109         }
110         return 0;
111     case AV_PIX_FMT_GRAY8:
112     case AV_PIX_FMT_GRAYF32:
113     case AV_PIX_FMT_YUV420P:
114     case AV_PIX_FMT_YUV422P:
115     case AV_PIX_FMT_YUV444P:
116     case AV_PIX_FMT_YUV410P:
117     case AV_PIX_FMT_YUV411P:
118     case AV_PIX_FMT_NV12:
119         if (model_input->channels != 1) {
120             LOG_FORMAT_CHANNEL_MISMATCH();
121             return AVERROR(EIO);
122         }
123         return 0;
124     default:
125         avpriv_report_missing_feature(ctx, "%s", av_get_pix_fmt_name(fmt));
126         return AVERROR(EIO);
127     }
128
129     return 0;
130 }
131
132 static int config_input(AVFilterLink *inlink)
133 {
134     AVFilterContext *context     = inlink->dst;
135     DnnProcessingContext *ctx = context->priv;
136     int result;
137     DNNData model_input;
138     int check;
139
140     result = ff_dnn_get_input(&ctx->dnnctx, &model_input);
141     if (result != 0) {
142         av_log(ctx, AV_LOG_ERROR, "could not get input from the model\n");
143         return result;
144     }
145
146     check = check_modelinput_inlink(&model_input, inlink);
147     if (check != 0) {
148         return check;
149     }
150
151     return 0;
152 }
153
154 static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt)
155 {
156     const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt);
157     av_assert0(desc);
158     return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3;
159 }
160
161 static int prepare_uv_scale(AVFilterLink *outlink)
162 {
163     AVFilterContext *context = outlink->src;
164     DnnProcessingContext *ctx = context->priv;
165     AVFilterLink *inlink = context->inputs[0];
166     enum AVPixelFormat fmt = inlink->format;
167
168     if (isPlanarYUV(fmt)) {
169         if (inlink->w != outlink->w || inlink->h != outlink->h) {
170             if (fmt == AV_PIX_FMT_NV12) {
171                 ctx->sws_uv_scale = sws_getContext(inlink->w >> 1, inlink->h >> 1, AV_PIX_FMT_YA8,
172                                                    outlink->w >> 1, outlink->h >> 1, AV_PIX_FMT_YA8,
173                                                    SWS_BICUBIC, NULL, NULL, NULL);
174                 ctx->sws_uv_height = inlink->h >> 1;
175             } else {
176                 const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(fmt);
177                 int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
178                 int sws_src_w = AV_CEIL_RSHIFT(inlink->w, desc->log2_chroma_w);
179                 int sws_dst_h = AV_CEIL_RSHIFT(outlink->h, desc->log2_chroma_h);
180                 int sws_dst_w = AV_CEIL_RSHIFT(outlink->w, desc->log2_chroma_w);
181                 ctx->sws_uv_scale = sws_getContext(sws_src_w, sws_src_h, AV_PIX_FMT_GRAY8,
182                                                    sws_dst_w, sws_dst_h, AV_PIX_FMT_GRAY8,
183                                                    SWS_BICUBIC, NULL, NULL, NULL);
184                 ctx->sws_uv_height = sws_src_h;
185             }
186         }
187     }
188
189     return 0;
190 }
191
192 static int config_output(AVFilterLink *outlink)
193 {
194     AVFilterContext *context = outlink->src;
195     DnnProcessingContext *ctx = context->priv;
196     int result;
197     AVFilterLink *inlink = context->inputs[0];
198
199     // have a try run in case that the dnn model resize the frame
200     result = ff_dnn_get_output(&ctx->dnnctx, inlink->w, inlink->h, &outlink->w, &outlink->h);
201     if (result != 0) {
202         av_log(ctx, AV_LOG_ERROR, "could not get output from the model\n");
203         return result;
204     }
205
206     prepare_uv_scale(outlink);
207
208     return 0;
209 }
210
211 static int copy_uv_planes(DnnProcessingContext *ctx, AVFrame *out, const AVFrame *in)
212 {
213     const AVPixFmtDescriptor *desc;
214     int uv_height;
215
216     if (!ctx->sws_uv_scale) {
217         av_assert0(in->height == out->height && in->width == out->width);
218         desc = av_pix_fmt_desc_get(in->format);
219         uv_height = AV_CEIL_RSHIFT(in->height, desc->log2_chroma_h);
220         for (int i = 1; i < 3; ++i) {
221             int bytewidth = av_image_get_linesize(in->format, in->width, i);
222             if (bytewidth < 0) {
223                 return AVERROR(EINVAL);
224             }
225             av_image_copy_plane(out->data[i], out->linesize[i],
226                                 in->data[i], in->linesize[i],
227                                 bytewidth, uv_height);
228         }
229     } else if (in->format == AV_PIX_FMT_NV12) {
230         sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 1), in->linesize + 1,
231                   0, ctx->sws_uv_height, out->data + 1, out->linesize + 1);
232     } else {
233         sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 1), in->linesize + 1,
234                   0, ctx->sws_uv_height, out->data + 1, out->linesize + 1);
235         sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 2), in->linesize + 2,
236                   0, ctx->sws_uv_height, out->data + 2, out->linesize + 2);
237     }
238
239     return 0;
240 }
241
242 static int flush_frame(AVFilterLink *outlink, int64_t pts, int64_t *out_pts)
243 {
244     DnnProcessingContext *ctx = outlink->src->priv;
245     int ret;
246     DNNAsyncStatusType async_state;
247
248     ret = ff_dnn_flush(&ctx->dnnctx);
249     if (ret != 0) {
250         return -1;
251     }
252
253     do {
254         AVFrame *in_frame = NULL;
255         AVFrame *out_frame = NULL;
256         async_state = ff_dnn_get_result(&ctx->dnnctx, &in_frame, &out_frame);
257         if (out_frame) {
258             if (isPlanarYUV(in_frame->format))
259                 copy_uv_planes(ctx, out_frame, in_frame);
260             av_frame_free(&in_frame);
261             ret = ff_filter_frame(outlink, out_frame);
262             if (ret < 0)
263                 return ret;
264             if (out_pts)
265                 *out_pts = out_frame->pts + pts;
266         }
267         av_usleep(5000);
268     } while (async_state >= DAST_NOT_READY);
269
270     return 0;
271 }
272
273 static int activate(AVFilterContext *filter_ctx)
274 {
275     AVFilterLink *inlink = filter_ctx->inputs[0];
276     AVFilterLink *outlink = filter_ctx->outputs[0];
277     DnnProcessingContext *ctx = filter_ctx->priv;
278     AVFrame *in = NULL, *out = NULL;
279     int64_t pts;
280     int ret, status;
281     int got_frame = 0;
282     int async_state;
283
284     FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink);
285
286     do {
287         // drain all input frames
288         ret = ff_inlink_consume_frame(inlink, &in);
289         if (ret < 0)
290             return ret;
291         if (ret > 0) {
292             out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
293             if (!out) {
294                 av_frame_free(&in);
295                 return AVERROR(ENOMEM);
296             }
297             av_frame_copy_props(out, in);
298             if (ff_dnn_execute_model(&ctx->dnnctx, in, out) != 0) {
299                 return AVERROR(EIO);
300             }
301         }
302     } while (ret > 0);
303
304     // drain all processed frames
305     do {
306         AVFrame *in_frame = NULL;
307         AVFrame *out_frame = NULL;
308         async_state = ff_dnn_get_result(&ctx->dnnctx, &in_frame, &out_frame);
309         if (out_frame) {
310             if (isPlanarYUV(in_frame->format))
311                 copy_uv_planes(ctx, out_frame, in_frame);
312             av_frame_free(&in_frame);
313             ret = ff_filter_frame(outlink, out_frame);
314             if (ret < 0)
315                 return ret;
316             got_frame = 1;
317         }
318     } while (async_state == DAST_SUCCESS);
319
320     // if frame got, schedule to next filter
321     if (got_frame)
322         return 0;
323
324     if (ff_inlink_acknowledge_status(inlink, &status, &pts)) {
325         if (status == AVERROR_EOF) {
326             int64_t out_pts = pts;
327             ret = flush_frame(outlink, pts, &out_pts);
328             ff_outlink_set_status(outlink, status, out_pts);
329             return ret;
330         }
331     }
332
333     FF_FILTER_FORWARD_WANTED(outlink, inlink);
334
335     return 0;
336 }
337
338 static av_cold void uninit(AVFilterContext *ctx)
339 {
340     DnnProcessingContext *context = ctx->priv;
341
342     sws_freeContext(context->sws_uv_scale);
343     ff_dnn_uninit(&context->dnnctx);
344 }
345
346 static const AVFilterPad dnn_processing_inputs[] = {
347     {
348         .name         = "default",
349         .type         = AVMEDIA_TYPE_VIDEO,
350         .config_props = config_input,
351     },
352 };
353
354 static const AVFilterPad dnn_processing_outputs[] = {
355     {
356         .name = "default",
357         .type = AVMEDIA_TYPE_VIDEO,
358         .config_props  = config_output,
359     },
360 };
361
362 const AVFilter ff_vf_dnn_processing = {
363     .name          = "dnn_processing",
364     .description   = NULL_IF_CONFIG_SMALL("Apply DNN processing filter to the input."),
365     .priv_size     = sizeof(DnnProcessingContext),
366     .init          = init,
367     .uninit        = uninit,
368     FILTER_INPUTS(dnn_processing_inputs),
369     FILTER_OUTPUTS(dnn_processing_outputs),
370     FILTER_PIXFMTS_ARRAY(pix_fmts),
371     .priv_class    = &dnn_processing_class,
372     .activate      = activate,
373 };