f20412cb104d4131c583aa2091ca893bbbd764d6
[platform/upstream/ffmpeg.git] / libavfilter / vf_grayworld.c
1 /*
2  * Copyright (c) 2021 Paul Buxton
3  *
4  * This file is part of FFmpeg.
5  *
6  * FFmpeg is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * FFmpeg is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with FFmpeg; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19  */
20
21  /**
22   * @file
23   * Color correction filter based on
24   * https://www.researchgate.net/publication/275213614_A_New_Color_Correction_Method_for_Underwater_Imaging
25   *
26   */
27
28 #include "libavutil/imgutils.h"
29 #include "libavutil/opt.h"
30 #include "libavutil/pixdesc.h"
31
32 #include "avfilter.h"
33 #include "formats.h"
34 #include "internal.h"
35 #include "video.h"
36
37 typedef struct ThreadData {
38     AVFrame *in, *out;
39     float l_avg;
40     float a_avg;
41     float b_avg;
42 } ThreadData;
43
44 typedef struct GrayWorldContext {
45     const AVClass *class;
46     float *tmpplab;
47     int *line_count_pels;
48     float *line_sum;
49 } GrayWorldContext;
50
51 #define OFFSET(x) offsetof(GrayWorldContext, x)
52 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM | AV_OPT_FLAG_RUNTIME_PARAM
53 static const AVOption grayworld_options[] = {
54     { NULL }
55 };
56
57 AVFILTER_DEFINE_CLASS(grayworld);
58
59 static void apply_matrix(const float matrix[3][3], const float input[3], float output[3])
60 {
61     output[0] = matrix[0][0] * input[0] + matrix[0][1] * input[1] + matrix[0][2] * input[2];
62     output[1] = matrix[1][0] * input[0] + matrix[1][1] * input[1] + matrix[1][2] * input[2];
63     output[2] = matrix[2][0] * input[0] + matrix[2][1] * input[1] + matrix[2][2] * input[2];
64 }
65
66 static const float lms2lab[3][3] = {
67     {0.5774, 0.5774, 0.5774},
68     {0.40825, 0.40825, -0.816458},
69     {0.707, -0.707, 0}
70 };
71
72 static const float lab2lms[3][3] = {
73     {0.57735, 0.40825, 0.707},
74     {0.57735, 0.40825, -0.707},
75     {0.57735, -0.8165, 0}
76 };
77
78 static const float rgb2lms[3][3] = {
79     {0.3811, 0.5783, 0.0402},
80     {0.1967, 0.7244, 0.0782},
81     {0.0241, 0.1288, 0.8444}
82 };
83
84 static const float lms2rgb[3][3] = {
85     {4.4679, -3.5873, 0.1193},
86     {-1.2186, 2.3809, -0.1624},
87     {0.0497, -0.2439, 1.2045}
88 };
89
90 /**
91  * Convert from Linear RGB to logspace LAB
92  *
93  * @param rgb Input array of rgb components
94  * @param lab output array of lab components
95  */
96 static void rgb2lab(const float rgb[3], float lab[3])
97 {
98     float lms[3];
99
100     apply_matrix(rgb2lms, rgb, lms);
101     lms[0] = lms[0] > 0.f ? logf(lms[0]) : -1024.f;
102     lms[1] = lms[1] > 0.f ? logf(lms[1]) : -1024.f;
103     lms[2] = lms[2] > 0.f ? logf(lms[2]) : -1024.f;
104     apply_matrix(lms2lab, lms, lab);
105 }
106
107 /**
108  * Convert from Logspace LAB to Linear RGB
109  *
110  * @param lab input array of lab components
111  * @param rgb output array of rgb components
112  */
113 static void lab2rgb(const float lab[3], float rgb[3])
114 {
115     float lms[3];
116
117     apply_matrix(lab2lms, lab, lms);
118     lms[0] = expf(lms[0]);
119     lms[1] = expf(lms[1]);
120     lms[2] = expf(lms[2]);
121     apply_matrix(lms2rgb, lms, rgb);
122 }
123
124 /**
125  * Convert a frame from linear RGB to logspace LAB, and accumulate channel totals for each row
126  * Convert from RGB -> lms using equation 4 in color transfer paper.
127  *
128  * @param ctx Filter context
129  * @param arg Thread data pointer
130  * @param jobnr job number
131  * @param nb_jobs number of jobs
132  */
133 static int convert_frame(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
134 {
135     GrayWorldContext *s = ctx->priv;
136     ThreadData *td = arg;
137     AVFrame *in = td->in;
138     AVFrame *out = td->out;
139     AVFilterLink *outlink = ctx->outputs[0];
140     const int slice_start = (out->height * jobnr) / nb_jobs;
141     const int slice_end = (out->height * (jobnr + 1)) / nb_jobs;
142     float rgb[3], lab[3];
143
144     for (int i = slice_start; i < slice_end; i++) {
145         float *b_in_row = (float *)(in->data[1] + i * in->linesize[1]);
146         float *g_in_row = (float *)(in->data[0] + i * in->linesize[0]);
147         float *r_in_row = (float *)(in->data[2] + i * in->linesize[2]);
148         float *acur = s->tmpplab + i * outlink->w + outlink->w * outlink->h;
149         float *bcur = s->tmpplab + i * outlink->w + 2 * outlink->w * outlink->h;
150         float *lcur = s->tmpplab + i * outlink->w;
151
152         s->line_sum[i] = 0.f;
153         s->line_sum[i + outlink->h] = 0.f;
154         s->line_count_pels[i] = 0;
155
156         for (int j = 0; j < outlink->w; j++) {
157             rgb[0] = r_in_row[j];
158             rgb[1] = g_in_row[j];
159             rgb[2] = b_in_row[j];
160             rgb2lab(rgb, lab);
161             *(lcur++) = lab[0];
162             *(acur++) = lab[1];
163             *(bcur++) = lab[2];
164             s->line_sum[i] += lab[1];
165             s->line_sum[i + outlink->h] += lab[2];
166             s->line_count_pels[i]++;
167         }
168     }
169     return 0;
170 }
171
172 /**
173  * Sum the channel totals and compute the mean for each channel
174  *
175  * @param s Frame context
176  * @param td thread data
177  */
178 static void compute_correction(GrayWorldContext *s, ThreadData *td)
179 {
180     float asum = 0.f, bsum = 0.f;
181     int pixels = 0;
182
183     for (int y = 0; y < td->out->height; y++) {
184         asum += s->line_sum[y];
185         bsum += s->line_sum[y + td->out->height];
186         pixels += s->line_count_pels[y];
187     }
188
189     td->a_avg = asum / pixels;
190     td->b_avg = bsum / pixels;
191 }
192
193 /**
194  * Subtract the mean logspace AB values from each pixel.
195  *
196  * @param ctx Filter context
197  * @param arg Thread data pointer
198  * @param jobnr job number
199  * @param nb_jobs number of jobs
200  */
201 static int correct_frame(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
202 {
203     GrayWorldContext *s = ctx->priv;
204     ThreadData *td = arg;
205     AVFrame *out = td->out;
206     AVFilterLink *outlink = ctx->outputs[0];
207     const int slice_start = (out->height * jobnr) / nb_jobs;
208     const int slice_end = (out->height * (jobnr + 1)) / nb_jobs;
209     float rgb[3], lab[3];
210
211     for (int i = slice_start; i < slice_end; i++) {
212         float *g_out_row = (float *)(out->data[0] + i * out->linesize[0]);
213         float *b_out_row = (float *)(out->data[1] + i * out->linesize[1]);
214         float *r_out_row = (float *)(out->data[2] + i * out->linesize[2]);
215         float *lcur = s->tmpplab + i * outlink->w;
216         float *acur = s->tmpplab + i * outlink->w + outlink->w * outlink->h;
217         float *bcur = s->tmpplab + i * outlink->w + 2 * outlink->w * outlink->h;
218
219         for (int j = 0; j < outlink->w; j++) {
220             lab[0] = *lcur++;
221             lab[1] = *acur++;
222             lab[2] = *bcur++;
223
224             // subtract the average for the color channels
225             lab[1] -= td->a_avg;
226             lab[2] -= td->b_avg;
227
228             //convert back to linear rgb
229             lab2rgb(lab, rgb);
230             r_out_row[j] = rgb[0];
231             g_out_row[j] = rgb[1];
232             b_out_row[j] = rgb[2];
233         }
234     }
235     return 0;
236 }
237
238 static int config_input(AVFilterLink *inlink)
239 {
240     GrayWorldContext *s = inlink->dst->priv;
241
242     FF_ALLOC_TYPED_ARRAY(s->tmpplab, inlink->h * inlink->w * 3);
243     FF_ALLOC_TYPED_ARRAY(s->line_count_pels, inlink->h);
244     FF_ALLOC_TYPED_ARRAY(s->line_sum, inlink->h * 2);
245     if (!s->tmpplab || !s->line_count_pels || !s->line_sum)
246         return AVERROR(ENOMEM);
247
248     return 0;
249 }
250
251 static av_cold void uninit(AVFilterContext *ctx)
252 {
253     GrayWorldContext *s = ctx->priv;
254
255     av_freep(&s->tmpplab);
256     av_freep(&s->line_count_pels);
257     av_freep(&s->line_sum);
258 }
259
260 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
261 {
262     AVFilterContext *ctx = inlink->dst;
263     GrayWorldContext *s = ctx->priv;
264     AVFilterLink *outlink = ctx->outputs[0];
265     ThreadData td;
266     AVFrame *out;
267
268     if (av_frame_is_writable(in)) {
269         out = in;
270     } else {
271         out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
272         if (!out) {
273             av_frame_free(&in);
274             return AVERROR(ENOMEM);
275         }
276         av_frame_copy_props(out, in);
277     }
278     /* input and output transfer will be linear */
279     if (in->color_trc == AVCOL_TRC_UNSPECIFIED) {
280         av_log(s, AV_LOG_WARNING, "Untagged transfer, assuming linear light.\n");
281         out->color_trc = AVCOL_TRC_LINEAR;
282     } else if (in->color_trc != AVCOL_TRC_LINEAR) {
283         av_log(s, AV_LOG_WARNING, "Gray world color correction works on linear light only.\n");
284     }
285
286     td.in = in;
287     td.out = out;
288
289     ff_filter_execute(ctx, convert_frame, &td, NULL, FFMIN(outlink->h, ff_filter_get_nb_threads(ctx)));
290     compute_correction(s, &td);
291     ff_filter_execute(ctx, correct_frame, &td, NULL, FFMIN(outlink->h, ff_filter_get_nb_threads(ctx)));
292
293     if (in != out) {
294         av_image_copy_plane(out->data[3], out->linesize[3],
295             in->data[3], in->linesize[3], outlink->w * 4, outlink->h);
296         av_frame_free(&in);
297     }
298
299     return ff_filter_frame(outlink, out);
300 }
301
302 static const AVFilterPad grayworld_inputs[] = {
303     {
304         .name         = "default",
305         .type         = AVMEDIA_TYPE_VIDEO,
306         .filter_frame = filter_frame,
307         .config_props = config_input,
308     }
309 };
310
311 static const AVFilterPad grayworld_outputs[] = {
312     {
313         .name = "default",
314         .type = AVMEDIA_TYPE_VIDEO,
315     }
316 };
317
318 const AVFilter ff_vf_grayworld = {
319     .name          = "grayworld",
320     .description   = NULL_IF_CONFIG_SMALL("Adjust white balance using LAB gray world algorithm"),
321     .priv_size     = sizeof(GrayWorldContext),
322     .priv_class    = &grayworld_class,
323     FILTER_INPUTS(grayworld_inputs),
324     FILTER_OUTPUTS(grayworld_outputs),
325     FILTER_PIXFMTS(AV_PIX_FMT_GBRPF32, AV_PIX_FMT_GBRAPF32),
326     .flags         = AVFILTER_FLAG_SUPPORT_TIMELINE_GENERIC | AVFILTER_FLAG_SLICE_THREADS,
327     .uninit        = uninit,
328 };