Imported Upstream version 6.1
[platform/upstream/ffmpeg.git] / libavfilter / vf_libvmaf.c
1 /*
2  * Copyright (c) 2017 Ronald S. Bultje <rsbultje@gmail.com>
3  * Copyright (c) 2017 Ashish Pratap Singh <ashk43712@gmail.com>
4  *
5  * This file is part of FFmpeg.
6  *
7  * FFmpeg is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * FFmpeg is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with FFmpeg; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  */
21
22 /**
23  * @file
24  * Calculate the VMAF between two input videos.
25  */
26
27 #include "config_components.h"
28
29 #include <libvmaf.h>
30
31 #include "libavutil/avstring.h"
32 #include "libavutil/opt.h"
33 #include "libavutil/pixdesc.h"
34 #include "avfilter.h"
35 #include "drawutils.h"
36 #include "formats.h"
37 #include "framesync.h"
38 #include "internal.h"
39 #include "video.h"
40
41 #if CONFIG_LIBVMAF_CUDA_FILTER
42 #include <libvmaf_cuda.h>
43
44 #include "libavutil/hwcontext.h"
45 #include "libavutil/hwcontext_cuda_internal.h"
46 #endif
47
48 typedef struct LIBVMAFContext {
49     const AVClass *class;
50     FFFrameSync fs;
51     char *model_path;
52     char *log_path;
53     char *log_fmt;
54     int enable_transform;
55     int phone_model;
56     int psnr;
57     int ssim;
58     int ms_ssim;
59     char *pool;
60     int n_threads;
61     int n_subsample;
62     int enable_conf_interval;
63     char *model_cfg;
64     char *feature_cfg;
65     VmafContext *vmaf;
66     VmafModel **model;
67     unsigned model_cnt;
68     unsigned frame_cnt;
69     unsigned bpc;
70 #if CONFIG_LIBVMAF_CUDA_FILTER
71     VmafCudaState *cu_state;
72 #endif
73 } LIBVMAFContext;
74
75 #define OFFSET(x) offsetof(LIBVMAFContext, x)
76 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_VIDEO_PARAM
77
78 static const AVOption libvmaf_options[] = {
79     {"log_path",  "Set the file path to be used to write log.",                         OFFSET(log_path), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 1, FLAGS},
80     {"log_fmt",  "Set the format of the log (csv, json, xml, or sub).",                 OFFSET(log_fmt), AV_OPT_TYPE_STRING, {.str="xml"}, 0, 1, FLAGS},
81     {"pool",  "Set the pool method to be used for computing vmaf.",                     OFFSET(pool), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 1, FLAGS},
82     {"n_threads", "Set number of threads to be used when computing vmaf.",              OFFSET(n_threads), AV_OPT_TYPE_INT, {.i64=0}, 0, UINT_MAX, FLAGS},
83     {"n_subsample", "Set interval for frame subsampling used when computing vmaf.",     OFFSET(n_subsample), AV_OPT_TYPE_INT, {.i64=1}, 1, UINT_MAX, FLAGS},
84     {"model",  "Set the model to be used for computing vmaf.",                          OFFSET(model_cfg), AV_OPT_TYPE_STRING, {.str="version=vmaf_v0.6.1"}, 0, 1, FLAGS},
85     {"feature",  "Set the feature to be used for computing vmaf.",                      OFFSET(feature_cfg), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 1, FLAGS},
86     { NULL }
87 };
88
89 FRAMESYNC_DEFINE_CLASS(libvmaf, LIBVMAFContext, fs);
90
91 static enum VmafPixelFormat pix_fmt_map(enum AVPixelFormat av_pix_fmt)
92 {
93     switch (av_pix_fmt) {
94     case AV_PIX_FMT_YUV420P:
95     case AV_PIX_FMT_YUV420P10LE:
96     case AV_PIX_FMT_YUV420P12LE:
97     case AV_PIX_FMT_YUV420P16LE:
98         return VMAF_PIX_FMT_YUV420P;
99     case AV_PIX_FMT_YUV422P:
100     case AV_PIX_FMT_YUV422P10LE:
101     case AV_PIX_FMT_YUV422P12LE:
102     case AV_PIX_FMT_YUV422P16LE:
103         return VMAF_PIX_FMT_YUV422P;
104     case AV_PIX_FMT_YUV444P:
105     case AV_PIX_FMT_YUV444P10LE:
106     case AV_PIX_FMT_YUV444P12LE:
107     case AV_PIX_FMT_YUV444P16LE:
108         return VMAF_PIX_FMT_YUV444P;
109     default:
110         return VMAF_PIX_FMT_UNKNOWN;
111     }
112 }
113
114 static int copy_picture_data(AVFrame *src, VmafPicture *dst, unsigned bpc)
115 {
116     const int bytes_per_value = bpc > 8 ? 2 : 1;
117     int err = vmaf_picture_alloc(dst, pix_fmt_map(src->format), bpc,
118                                  src->width, src->height);
119     if (err)
120         return AVERROR(ENOMEM);
121
122     for (unsigned i = 0; i < 3; i++) {
123         uint8_t *src_data = src->data[i];
124         uint8_t *dst_data = dst->data[i];
125         for (unsigned j = 0; j < dst->h[i]; j++) {
126             memcpy(dst_data, src_data, bytes_per_value * dst->w[i]);
127             src_data += src->linesize[i];
128             dst_data += dst->stride[i];
129         }
130     }
131
132     return 0;
133 }
134
135 static int do_vmaf(FFFrameSync *fs)
136 {
137     AVFilterContext *ctx = fs->parent;
138     LIBVMAFContext *s = ctx->priv;
139     VmafPicture pic_ref, pic_dist;
140     AVFrame *ref, *dist;
141     int err = 0;
142
143     int ret = ff_framesync_dualinput_get(fs, &dist, &ref);
144     if (ret < 0)
145         return ret;
146     if (ctx->is_disabled || !ref)
147         return ff_filter_frame(ctx->outputs[0], dist);
148
149     if (dist->color_range != ref->color_range) {
150         av_log(ctx, AV_LOG_WARNING, "distorted and reference "
151                "frames use different color ranges (%s != %s)\n",
152                av_color_range_name(dist->color_range),
153                av_color_range_name(ref->color_range));
154     }
155
156     err = copy_picture_data(ref, &pic_ref, s->bpc);
157     if (err) {
158         av_log(s, AV_LOG_ERROR, "problem during vmaf_picture_alloc.\n");
159         return AVERROR(ENOMEM);
160     }
161
162     err = copy_picture_data(dist, &pic_dist, s->bpc);
163     if (err) {
164         av_log(s, AV_LOG_ERROR, "problem during vmaf_picture_alloc.\n");
165         vmaf_picture_unref(&pic_ref);
166         return AVERROR(ENOMEM);
167     }
168
169     err = vmaf_read_pictures(s->vmaf, &pic_ref, &pic_dist, s->frame_cnt++);
170     if (err) {
171         av_log(s, AV_LOG_ERROR, "problem during vmaf_read_pictures.\n");
172         return AVERROR(EINVAL);
173     }
174
175     return ff_filter_frame(ctx->outputs[0], dist);
176 }
177
178
179 static AVDictionary **delimited_dict_parse(char *str, unsigned *cnt)
180 {
181     AVDictionary **dict = NULL;
182     char *str_copy = NULL;
183     char *saveptr = NULL;
184     unsigned cnt2;
185     int err = 0;
186
187     if (!str)
188         return NULL;
189
190     cnt2 = 1;
191     for (char *p = str; *p; p++) {
192         if (*p == '|')
193             cnt2++;
194     }
195
196     dict = av_calloc(cnt2, sizeof(*dict));
197     if (!dict)
198         goto fail;
199
200     str_copy = av_strdup(str);
201     if (!str_copy)
202         goto fail;
203
204     *cnt = 0;
205     for (unsigned i = 0; i < cnt2; i++) {
206         char *s = av_strtok(i == 0 ? str_copy : NULL, "|", &saveptr);
207         if (!s)
208             continue;
209         err = av_dict_parse_string(&dict[(*cnt)++], s, "=", ":", 0);
210         if (err)
211             goto fail;
212     }
213
214     av_free(str_copy);
215     return dict;
216
217 fail:
218     if (dict) {
219         for (unsigned i = 0; i < *cnt; i++) {
220             if (dict[i])
221                 av_dict_free(&dict[i]);
222         }
223         av_free(dict);
224     }
225
226     av_free(str_copy);
227     *cnt = 0;
228     return NULL;
229 }
230
231 static int parse_features(AVFilterContext *ctx)
232 {
233     LIBVMAFContext *s = ctx->priv;
234     AVDictionary **dict = NULL;
235     unsigned dict_cnt;
236     int err = 0;
237
238     if (!s->feature_cfg)
239         return 0;
240
241     dict = delimited_dict_parse(s->feature_cfg, &dict_cnt);
242     if (!dict) {
243         av_log(ctx, AV_LOG_ERROR,
244                "could not parse feature config: %s\n", s->feature_cfg);
245         return AVERROR(EINVAL);
246     }
247
248     for (unsigned i = 0; i < dict_cnt; i++) {
249         char *feature_name = NULL;
250         VmafFeatureDictionary *feature_opts_dict = NULL;
251         const AVDictionaryEntry *e = NULL;
252
253         while (e = av_dict_iterate(dict[i], e)) {
254             if (av_stristr(e->key, "name")) {
255                 feature_name = e->value;
256                 continue;
257             }
258
259             err = vmaf_feature_dictionary_set(&feature_opts_dict, e->key,
260                                               e->value);
261             if (err) {
262                 av_log(ctx, AV_LOG_ERROR,
263                        "could not set feature option: %s.%s=%s\n",
264                        feature_name, e->key, e->value);
265                 goto exit;
266             }
267         }
268
269         err = vmaf_use_feature(s->vmaf, feature_name, feature_opts_dict);
270         if (err) {
271             av_log(ctx, AV_LOG_ERROR,
272                    "problem during vmaf_use_feature: %s\n", feature_name);
273             goto exit;
274         }
275     }
276
277 exit:
278     for (unsigned i = 0; i < dict_cnt; i++) {
279         if (dict[i])
280             av_dict_free(&dict[i]);
281     }
282     av_free(dict);
283     return err;
284 }
285
286 static int parse_models(AVFilterContext *ctx)
287 {
288     LIBVMAFContext *s = ctx->priv;
289     AVDictionary **dict;
290     unsigned dict_cnt;
291     int err = 0;
292
293     if (!s->model_cfg) return 0;
294
295     dict_cnt = 0;
296     dict = delimited_dict_parse(s->model_cfg, &dict_cnt);
297     if (!dict) {
298         av_log(ctx, AV_LOG_ERROR,
299                "could not parse model config: %s\n", s->model_cfg);
300         return AVERROR(EINVAL);
301     }
302
303     s->model_cnt = dict_cnt;
304     s->model = av_calloc(s->model_cnt, sizeof(*s->model));
305     if (!s->model)
306         return AVERROR(ENOMEM);
307
308     for (unsigned i = 0; i < dict_cnt; i++) {
309         VmafModelConfig model_cfg = { 0 };
310         const AVDictionaryEntry *e = NULL;
311         char *version = NULL;
312         char  *path = NULL;
313
314         while (e = av_dict_iterate(dict[i], e)) {
315             if (av_stristr(e->key, "disable_clip")) {
316                 model_cfg.flags |= av_stristr(e->value, "true") ?
317                     VMAF_MODEL_FLAG_DISABLE_CLIP : 0;
318                 continue;
319             }
320
321             if (av_stristr(e->key, "enable_transform")) {
322                 model_cfg.flags |= av_stristr(e->value, "true") ?
323                     VMAF_MODEL_FLAG_ENABLE_TRANSFORM : 0;
324                 continue;
325             }
326
327             if (av_stristr(e->key, "name")) {
328                 model_cfg.name = e->value;
329                 continue;
330             }
331
332             if (av_stristr(e->key, "version")) {
333                 version = e->value;
334                 continue;
335             }
336
337             if (av_stristr(e->key, "path")) {
338                 path = e->value;
339                 continue;
340             }
341         }
342
343         if (version) {
344             err = vmaf_model_load(&s->model[i], &model_cfg, version);
345             if (err) {
346                 av_log(ctx, AV_LOG_ERROR,
347                        "could not load libvmaf model with version: %s\n",
348                        version);
349                 goto exit;
350             }
351         }
352
353         if (path && !s->model[i]) {
354             err = vmaf_model_load_from_path(&s->model[i], &model_cfg, path);
355             if (err) {
356                 av_log(ctx, AV_LOG_ERROR,
357                        "could not load libvmaf model with path: %s\n",
358                        path);
359                 goto exit;
360             }
361         }
362
363         if (!s->model[i]) {
364             av_log(ctx, AV_LOG_ERROR,
365                    "could not load libvmaf model with config: %s\n",
366                    s->model_cfg);
367             goto exit;
368         }
369
370         while (e = av_dict_iterate(dict[i], e)) {
371             VmafFeatureDictionary *feature_opts_dict = NULL;
372             char *feature_opt = NULL;
373
374             char *feature_name = av_strtok(e->key, ".", &feature_opt);
375             if (!feature_opt)
376                 continue;
377
378             err = vmaf_feature_dictionary_set(&feature_opts_dict,
379                                               feature_opt, e->value);
380             if (err) {
381                 av_log(ctx, AV_LOG_ERROR,
382                        "could not set feature option: %s.%s=%s\n",
383                        feature_name, feature_opt, e->value);
384                 err = AVERROR(EINVAL);
385                 goto exit;
386             }
387
388             err = vmaf_model_feature_overload(s->model[i], feature_name,
389                                               feature_opts_dict);
390             if (err) {
391                 av_log(ctx, AV_LOG_ERROR,
392                        "could not overload feature: %s\n", feature_name);
393                 err = AVERROR(EINVAL);
394                 goto exit;
395             }
396         }
397     }
398
399     for (unsigned i = 0; i < s->model_cnt; i++) {
400         err = vmaf_use_features_from_model(s->vmaf, s->model[i]);
401         if (err) {
402             av_log(ctx, AV_LOG_ERROR,
403                    "problem during vmaf_use_features_from_model\n");
404             err = AVERROR(EINVAL);
405             goto exit;
406         }
407     }
408
409 exit:
410     for (unsigned i = 0; i < dict_cnt; i++) {
411         if (dict[i])
412             av_dict_free(&dict[i]);
413     }
414     av_free(dict);
415     return err;
416 }
417
418 static enum VmafLogLevel log_level_map(int log_level)
419 {
420     switch (log_level) {
421     case AV_LOG_QUIET:
422         return VMAF_LOG_LEVEL_NONE;
423     case AV_LOG_ERROR:
424         return VMAF_LOG_LEVEL_ERROR;
425     case AV_LOG_WARNING:
426         return VMAF_LOG_LEVEL_WARNING;
427     case AV_LOG_INFO:
428         return VMAF_LOG_LEVEL_INFO;
429     case AV_LOG_DEBUG:
430         return VMAF_LOG_LEVEL_DEBUG;
431     default:
432         return VMAF_LOG_LEVEL_INFO;
433     }
434 }
435
436 static av_cold int init(AVFilterContext *ctx)
437 {
438     LIBVMAFContext *s = ctx->priv;
439     int err = 0;
440
441     VmafConfiguration cfg = {
442         .log_level = log_level_map(av_log_get_level()),
443         .n_subsample = s->n_subsample,
444         .n_threads = s->n_threads,
445     };
446
447     err = vmaf_init(&s->vmaf, cfg);
448     if (err)
449         return AVERROR(EINVAL);
450
451     err = parse_models(ctx);
452     if (err)
453         return err;
454
455     err = parse_features(ctx);
456     if (err)
457         return err;
458
459     s->fs.on_event = do_vmaf;
460     return 0;
461 }
462
463 static const enum AVPixelFormat pix_fmts[] = {
464     AV_PIX_FMT_YUV444P, AV_PIX_FMT_YUV422P, AV_PIX_FMT_YUV420P,
465     AV_PIX_FMT_YUV444P10LE, AV_PIX_FMT_YUV422P10LE, AV_PIX_FMT_YUV420P10LE,
466     AV_PIX_FMT_YUV444P12LE, AV_PIX_FMT_YUV422P12LE, AV_PIX_FMT_YUV420P12LE,
467     AV_PIX_FMT_YUV444P16LE, AV_PIX_FMT_YUV422P16LE, AV_PIX_FMT_YUV420P16LE,
468     AV_PIX_FMT_NONE
469 };
470
471 static int config_input_ref(AVFilterLink *inlink)
472 {
473     AVFilterContext *ctx = inlink->dst;
474     LIBVMAFContext *s = ctx->priv;
475     const AVPixFmtDescriptor *desc;
476     int err = 0;
477
478     if (ctx->inputs[0]->w != ctx->inputs[1]->w) {
479         av_log(ctx, AV_LOG_ERROR, "input width must match.\n");
480         err |= AVERROR(EINVAL);
481     }
482
483     if (ctx->inputs[0]->h != ctx->inputs[1]->h) {
484         av_log(ctx, AV_LOG_ERROR, "input height must match.\n");
485         err |= AVERROR(EINVAL);
486     }
487
488     if (ctx->inputs[0]->format != ctx->inputs[1]->format) {
489         av_log(ctx, AV_LOG_ERROR, "input pix_fmt must match.\n");
490         err |= AVERROR(EINVAL);
491     }
492
493     if (err)
494         return err;
495
496     desc = av_pix_fmt_desc_get(inlink->format);
497     s->bpc = desc->comp[0].depth;
498
499     return 0;
500 }
501
502 static int config_output(AVFilterLink *outlink)
503 {
504     AVFilterContext *ctx = outlink->src;
505     LIBVMAFContext *s = ctx->priv;
506     AVFilterLink *mainlink = ctx->inputs[0];
507     int ret;
508
509     ret = ff_framesync_init_dualinput(&s->fs, ctx);
510     if (ret < 0)
511         return ret;
512     outlink->w = mainlink->w;
513     outlink->h = mainlink->h;
514     outlink->time_base = mainlink->time_base;
515     outlink->sample_aspect_ratio = mainlink->sample_aspect_ratio;
516     outlink->frame_rate = mainlink->frame_rate;
517     if ((ret = ff_framesync_configure(&s->fs)) < 0)
518         return ret;
519
520     return 0;
521 }
522
523 static int activate(AVFilterContext *ctx)
524 {
525     LIBVMAFContext *s = ctx->priv;
526     return ff_framesync_activate(&s->fs);
527 }
528
529 static enum VmafOutputFormat log_fmt_map(const char *log_fmt)
530 {
531     if (log_fmt) {
532         if (av_stristr(log_fmt, "xml"))
533             return VMAF_OUTPUT_FORMAT_XML;
534         if (av_stristr(log_fmt, "json"))
535             return VMAF_OUTPUT_FORMAT_JSON;
536         if (av_stristr(log_fmt, "csv"))
537             return VMAF_OUTPUT_FORMAT_CSV;
538         if (av_stristr(log_fmt, "sub"))
539             return VMAF_OUTPUT_FORMAT_SUB;
540     }
541
542     return VMAF_OUTPUT_FORMAT_XML;
543 }
544
545 static enum VmafPoolingMethod pool_method_map(const char *pool_method)
546 {
547     if (pool_method) {
548         if (av_stristr(pool_method, "min"))
549             return VMAF_POOL_METHOD_MIN;
550         if (av_stristr(pool_method, "mean"))
551             return VMAF_POOL_METHOD_MEAN;
552         if (av_stristr(pool_method, "harmonic_mean"))
553             return VMAF_POOL_METHOD_HARMONIC_MEAN;
554     }
555
556     return VMAF_POOL_METHOD_MEAN;
557 }
558
559 static av_cold void uninit(AVFilterContext *ctx)
560 {
561     LIBVMAFContext *s = ctx->priv;
562     int err = 0;
563
564     ff_framesync_uninit(&s->fs);
565
566     if (!s->frame_cnt)
567         goto clean_up;
568
569     err = vmaf_read_pictures(s->vmaf, NULL, NULL, 0);
570     if (err) {
571         av_log(ctx, AV_LOG_ERROR,
572                "problem flushing libvmaf context.\n");
573     }
574
575     for (unsigned i = 0; i < s->model_cnt; i++) {
576         double vmaf_score;
577         err = vmaf_score_pooled(s->vmaf, s->model[i], pool_method_map(s->pool),
578                                 &vmaf_score, 0, s->frame_cnt - 1);
579         if (err) {
580             av_log(ctx, AV_LOG_ERROR,
581                    "problem getting pooled vmaf score.\n");
582         }
583
584         av_log(ctx, AV_LOG_INFO, "VMAF score: %f\n", vmaf_score);
585     }
586
587     if (s->vmaf) {
588         if (s->log_path && !err)
589             vmaf_write_output(s->vmaf, s->log_path, log_fmt_map(s->log_fmt));
590     }
591
592 clean_up:
593     if (s->model) {
594         for (unsigned i = 0; i < s->model_cnt; i++) {
595             if (s->model[i])
596                 vmaf_model_destroy(s->model[i]);
597         }
598         av_free(s->model);
599     }
600
601     if (s->vmaf)
602         vmaf_close(s->vmaf);
603 }
604
605 static const AVFilterPad libvmaf_inputs[] = {
606     {
607         .name         = "main",
608         .type         = AVMEDIA_TYPE_VIDEO,
609     },{
610         .name         = "reference",
611         .type         = AVMEDIA_TYPE_VIDEO,
612         .config_props = config_input_ref,
613     },
614 };
615
616 static const AVFilterPad libvmaf_outputs[] = {
617     {
618         .name          = "default",
619         .type          = AVMEDIA_TYPE_VIDEO,
620         .config_props  = config_output,
621     },
622 };
623
624 const AVFilter ff_vf_libvmaf = {
625     .name          = "libvmaf",
626     .description   = NULL_IF_CONFIG_SMALL("Calculate the VMAF between two video streams."),
627     .preinit       = libvmaf_framesync_preinit,
628     .init          = init,
629     .uninit        = uninit,
630     .activate      = activate,
631     .priv_size     = sizeof(LIBVMAFContext),
632     .priv_class    = &libvmaf_class,
633     FILTER_INPUTS(libvmaf_inputs),
634     FILTER_OUTPUTS(libvmaf_outputs),
635     FILTER_PIXFMTS_ARRAY(pix_fmts),
636 };
637
638 #if CONFIG_LIBVMAF_CUDA_FILTER
639 static const enum AVPixelFormat supported_formats[] = {
640     AV_PIX_FMT_YUV420P,
641     AV_PIX_FMT_YUV444P16,
642 };
643
644 static int format_is_supported(enum AVPixelFormat fmt)
645 {
646     int i;
647
648     for (i = 0; i < FF_ARRAY_ELEMS(supported_formats); i++)
649         if (supported_formats[i] == fmt)
650             return 1;
651     return 0;
652 }
653
654 static int config_props_cuda(AVFilterLink *outlink)
655 {
656     int err;
657     AVFilterContext *ctx = outlink->src;
658     LIBVMAFContext *s = ctx->priv;
659     AVFilterLink *inlink = ctx->inputs[0];
660     AVHWFramesContext *frames_ctx = (AVHWFramesContext*) inlink->hw_frames_ctx->data;
661     AVCUDADeviceContext *device_hwctx = frames_ctx->device_ctx->hwctx;
662     CUcontext cu_ctx = device_hwctx->cuda_ctx;
663     const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(frames_ctx->sw_format);
664
665     VmafConfiguration cfg = {
666         .log_level = log_level_map(av_log_get_level()),
667         .n_subsample = s->n_subsample,
668         .n_threads = s->n_threads,
669     };
670
671     VmafCudaPictureConfiguration cuda_pic_cfg = {
672         .pic_params = {
673             .bpc = desc->comp[0].depth,
674             .w = inlink->w,
675             .h = inlink->h,
676             .pix_fmt = pix_fmt_map(frames_ctx->sw_format),
677         },
678         .pic_prealloc_method = VMAF_CUDA_PICTURE_PREALLOCATION_METHOD_DEVICE,
679     };
680
681     VmafCudaConfiguration cuda_cfg = {
682         .cu_ctx = cu_ctx,
683     };
684
685     if (!format_is_supported(frames_ctx->sw_format)) {
686         av_log(s, AV_LOG_ERROR,
687                "Unsupported input format: %s\n", desc->name);
688         return AVERROR(EINVAL);
689     }
690
691     err = vmaf_init(&s->vmaf, cfg);
692     if (err)
693         return AVERROR(EINVAL);
694
695     err = vmaf_cuda_state_init(&s->cu_state, cuda_cfg);
696     if (err)
697         return AVERROR(EINVAL);
698
699     err = vmaf_cuda_import_state(s->vmaf, s->cu_state);
700     if (err)
701         return AVERROR(EINVAL);
702
703     err = vmaf_cuda_preallocate_pictures(s->vmaf, cuda_pic_cfg);
704     if (err < 0)
705         return err;
706
707     err = parse_models(ctx);
708     if (err)
709         return err;
710
711     err = parse_features(ctx);
712     if (err)
713         return err;
714
715     return config_output(outlink);
716 }
717
718 static int copy_picture_data_cuda(VmafContext* vmaf,
719                                   AVCUDADeviceContext* device_hwctx,
720                                   AVFrame* src, VmafPicture* dst,
721                                   enum AVPixelFormat pix_fmt)
722 {
723     const AVPixFmtDescriptor *pix_desc = av_pix_fmt_desc_get(pix_fmt);
724     CudaFunctions *cu = device_hwctx->internal->cuda_dl;
725
726     CUDA_MEMCPY2D m = {
727         .srcMemoryType = CU_MEMORYTYPE_DEVICE,
728         .dstMemoryType = CU_MEMORYTYPE_DEVICE,
729     };
730
731     int err = vmaf_cuda_fetch_preallocated_picture(vmaf, dst);
732     if (err)
733         return AVERROR(ENOMEM);
734
735     err = cu->cuCtxPushCurrent(device_hwctx->cuda_ctx);
736     if (err)
737         return AVERROR_EXTERNAL;
738
739     for (unsigned i = 0; i < pix_desc->nb_components; i++) {
740         m.srcDevice = (CUdeviceptr) src->data[i];
741         m.srcPitch = src->linesize[i];
742         m.dstDevice = (CUdeviceptr) dst->data[i];
743         m.dstPitch = dst->stride[i];
744         m.WidthInBytes = dst->w[i] * ((dst->bpc + 7) / 8);
745         m.Height = dst->h[i];
746
747         err = cu->cuMemcpy2D(&m);
748         if (err)
749             return AVERROR_EXTERNAL;
750         break;
751     }
752
753     err = cu->cuCtxPopCurrent(NULL);
754     if (err)
755         return AVERROR_EXTERNAL;
756
757     return 0;
758 }
759
760 static int do_vmaf_cuda(FFFrameSync* fs)
761 {
762     AVFilterContext* ctx = fs->parent;
763     LIBVMAFContext* s = ctx->priv;
764     AVFilterLink *inlink = ctx->inputs[0];
765     AVHWFramesContext *frames_ctx = (AVHWFramesContext*) inlink->hw_frames_ctx->data;
766     AVCUDADeviceContext *device_hwctx = frames_ctx->device_ctx->hwctx;
767     VmafPicture pic_ref, pic_dist;
768     AVFrame *ref, *dist;
769
770     int err = 0;
771
772     err = ff_framesync_dualinput_get(fs, &dist, &ref);
773     if (err < 0)
774         return err;
775     if (ctx->is_disabled || !ref)
776         return ff_filter_frame(ctx->outputs[0], dist);
777
778     err = copy_picture_data_cuda(s->vmaf, device_hwctx, ref, &pic_ref,
779                                  frames_ctx->sw_format);
780     if (err) {
781         av_log(s, AV_LOG_ERROR, "problem during copy_picture_data_cuda.\n");
782         return AVERROR(ENOMEM);
783     }
784
785     err = copy_picture_data_cuda(s->vmaf, device_hwctx, dist, &pic_dist,
786                                  frames_ctx->sw_format);
787     if (err) {
788         av_log(s, AV_LOG_ERROR, "problem during copy_picture_data_cuda.\n");
789         return AVERROR(ENOMEM);
790     }
791
792     err = vmaf_read_pictures(s->vmaf, &pic_ref, &pic_dist, s->frame_cnt++);
793     if (err) {
794         av_log(s, AV_LOG_ERROR, "problem during vmaf_read_pictures.\n");
795         return AVERROR(EINVAL);
796     }
797
798     return ff_filter_frame(ctx->outputs[0], dist);
799 }
800
801 static av_cold int init_cuda(AVFilterContext *ctx)
802 {
803     LIBVMAFContext *s = ctx->priv;
804     s->fs.on_event = do_vmaf_cuda;
805     return 0;
806 }
807
808 static const AVFilterPad libvmaf_outputs_cuda[] = {
809     {
810         .name         = "default",
811         .type         = AVMEDIA_TYPE_VIDEO,
812         .config_props = config_props_cuda,
813     },
814 };
815
816 const AVFilter ff_vf_libvmaf_cuda = {
817     .name           = "libvmaf_cuda",
818     .description    = NULL_IF_CONFIG_SMALL("Calculate the VMAF between two video streams."),
819     .preinit        = libvmaf_framesync_preinit,
820     .init           = init_cuda,
821     .uninit         = uninit,
822     .activate       = activate,
823     .priv_size      = sizeof(LIBVMAFContext),
824     .priv_class     = &libvmaf_class,
825     FILTER_INPUTS(libvmaf_inputs),
826     FILTER_OUTPUTS(libvmaf_outputs_cuda),
827     FILTER_SINGLE_PIXFMT(AV_PIX_FMT_CUDA),
828     .flags_internal = FF_FILTER_FLAG_HWFRAME_AWARE,
829 };
830 #endif