2 * Copyright (c) 2021 Paul Buxton
4 * This file is part of FFmpeg.
6 * FFmpeg is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
11 * FFmpeg is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with FFmpeg; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
23 * Color correction filter based on
24 * https://www.researchgate.net/publication/275213614_A_New_Color_Correction_Method_for_Underwater_Imaging
28 #include "libavutil/imgutils.h"
29 #include "libavutil/opt.h"
30 #include "libavutil/pixdesc.h"
/* Per-frame payload handed to the slice jobs; code in this file reads its
 * in, out, a_avg and b_avg members. */
37 typedef struct ThreadData {
/* Private filter state reached through ctx->priv; code in this file uses its
 * tmpplab, line_sum and line_count_pels buffers. */
44 typedef struct GrayWorldContext {
/* Byte offset of member x inside GrayWorldContext (presumably for option-table entries). */
51 #define OFFSET(x) offsetof(GrayWorldContext, x)
/* Flag set combining the filtering, video and runtime option flags. */
52 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM | AV_OPT_FLAG_RUNTIME_PARAM
/* Option table of the filter; its entries are not visible in this excerpt. */
53 static const AVOption grayworld_options[] = {
/* NOTE(review): presumably provides the grayworld_class object that the filter
 * definition points .priv_class at -- confirm against the macro definition. */
57 AVFILTER_DEFINE_CLASS(grayworld);
/**
 * Multiply a 3x3 matrix by a 3-component vector.
 *
 * @param matrix coefficients, indexed as matrix[row][column]
 * @param input  vector to transform
 * @param output receives matrix * input, one component per matrix row
 */
static void apply_matrix(const float matrix[3][3], const float input[3], float output[3])
{
    for (int row = 0; row < 3; row++) {
        const float *coef = matrix[row];

        output[row] = coef[0] * input[0] + coef[1] * input[1] + coef[2] * input[2];
    }
}
/* Matrix applied by rgb2lab() to the log-encoded LMS triple to produce lab.
 * NOTE(review): only two of the three rows are visible in this excerpt. */
66 static const float lms2lab[3][3] = {
67 {0.5774, 0.5774, 0.5774},
68 {0.40825, 0.40825, -0.816458},
/* Matrix applied by lab2rgb() to a lab triple to get back log-encoded LMS.
 * NOTE(review): only two of the three rows are visible in this excerpt. */
72 static const float lab2lms[3][3] = {
73 {0.57735, 0.40825, 0.707},
74 {0.57735, 0.40825, -0.707},
/* Linear RGB -> LMS; each row produces one LMS component (used by rgb2lab()). */
static const float rgb2lms[3][3] = {
    { 0.3811, 0.5783, 0.0402 },
    { 0.1967, 0.7244, 0.0782 },
    { 0.0241, 0.1288, 0.8444 },
};
/* LMS -> linear RGB; each row produces one RGB component (used by lab2rgb()). */
static const float lms2rgb[3][3] = {
    {  4.4679, -3.5873,  0.1193 },
    { -1.2186,  2.3809, -0.1624 },
    {  0.0497, -0.2439,  1.2045 },
};
91 * Convert from Linear RGB to logspace LAB
93 * @param rgb Input array of rgb components
94 * @param lab output array of lab components
96 static void rgb2lab(const float rgb[3], float lab[3])
100 apply_matrix(rgb2lms, rgb, lms);
101 lms[0] = lms[0] > 0.f ? logf(lms[0]) : -1024.f;
102 lms[1] = lms[1] > 0.f ? logf(lms[1]) : -1024.f;
103 lms[2] = lms[2] > 0.f ? logf(lms[2]) : -1024.f;
104 apply_matrix(lms2lab, lms, lab);
108 * Convert from Logspace LAB to Linear RGB
110 * @param lab input array of lab components
111 * @param rgb output array of rgb components
113 static void lab2rgb(const float lab[3], float rgb[3])
117 apply_matrix(lab2lms, lab, lms);
118 lms[0] = expf(lms[0]);
119 lms[1] = expf(lms[1]);
120 lms[2] = expf(lms[2]);
121 apply_matrix(lms2rgb, lms, rgb);
125 * Convert a frame from linear RGB to logspace LAB, and accumulate channel totals for each row
126 * Convert from RGB -> lms using equation 4 in color transfer paper.
128 * @param ctx Filter context
129 * @param arg Thread data pointer
130 * @param jobnr job number
131 * @param nb_jobs number of jobs
/* Slice job: handles rows [slice_start, slice_end) of the frame for job
 * jobnr out of nb_jobs, reading the input planes and updating the per-row
 * accumulators in the filter context. */
133 static int convert_frame(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
135 GrayWorldContext *s = ctx->priv;
136 ThreadData *td = arg;
137 AVFrame *in = td->in;
138 AVFrame *out = td->out;
139 AVFilterLink *outlink = ctx->outputs[0];
/* split the output height evenly across the jobs */
140 const int slice_start = (out->height * jobnr) / nb_jobs;
141 const int slice_end = (out->height * (jobnr + 1)) / nb_jobs;
142 float rgb[3], lab[3];
144 for (int i = slice_start; i < slice_end; i++) {
/* planar float input: plane 0 is read as green, plane 1 as blue, plane 2 as red */
145 float *b_in_row = (float *)(in->data[1] + i * in->linesize[1]);
146 float *g_in_row = (float *)(in->data[0] + i * in->linesize[0]);
147 float *r_in_row = (float *)(in->data[2] + i * in->linesize[2]);
/* tmpplab is addressed as three consecutive w*h planes: l at offset 0,
 * a at offset w*h, b at offset 2*w*h */
148 float *acur = s->tmpplab + i * outlink->w + outlink->w * outlink->h;
149 float *bcur = s->tmpplab + i * outlink->w + 2 * outlink->w * outlink->h;
150 float *lcur = s->tmpplab + i * outlink->w;
/* line_sum holds per-row totals: entry i for lab[1], entry i + h for lab[2] */
152 s->line_sum[i] = 0.f;
153 s->line_sum[i + outlink->h] = 0.f;
154 s->line_count_pels[i] = 0;
156 for (int j = 0; j < outlink->w; j++) {
157 rgb[0] = r_in_row[j];
158 rgb[1] = g_in_row[j];
159 rgb[2] = b_in_row[j];
/* NOTE(review): the statements that derive lab[] from rgb[] and store it
 * through lcur/acur/bcur are not visible in this excerpt */
164 s->line_sum[i] += lab[1];
165 s->line_sum[i + outlink->h] += lab[2];
166 s->line_count_pels[i]++;
173 * Sum the channel totals and compute the mean for each channel
175 * @param s Filter private context
176 * @param td thread data
/* Reduces the per-row sums and pixel counts held in s into the two
 * frame-wide averages td->a_avg and td->b_avg. */
178 static void compute_correction(GrayWorldContext *s, ThreadData *td)
180 float asum = 0.f, bsum = 0.f;
/* one iteration per row of the output frame; the second channel's row sums
 * live height entries further into line_sum */
183 for (int y = 0; y < td->out->height; y++) {
184 asum += s->line_sum[y];
185 bsum += s->line_sum[y + td->out->height];
186 pixels += s->line_count_pels[y];
/* average = channel total / number of counted pixels */
189 td->a_avg = asum / pixels;
190 td->b_avg = bsum / pixels;
194 * Subtract the mean logspace AB values from each pixel.
196 * @param ctx Filter context
197 * @param arg Thread data pointer
198 * @param jobnr job number
199 * @param nb_jobs number of jobs
/* Slice job: handles rows [slice_start, slice_end) for job jobnr out of
 * nb_jobs and writes the resulting rgb values into the output planes. */
201 static int correct_frame(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
203 GrayWorldContext *s = ctx->priv;
204 ThreadData *td = arg;
205 AVFrame *out = td->out;
206 AVFilterLink *outlink = ctx->outputs[0];
/* split the output height evenly across the jobs */
207 const int slice_start = (out->height * jobnr) / nb_jobs;
208 const int slice_end = (out->height * (jobnr + 1)) / nb_jobs;
209 float rgb[3], lab[3];
211 for (int i = slice_start; i < slice_end; i++) {
/* planar float output: plane 0 is written as green, plane 1 as blue, plane 2 as red */
212 float *g_out_row = (float *)(out->data[0] + i * out->linesize[0]);
213 float *b_out_row = (float *)(out->data[1] + i * out->linesize[1]);
214 float *r_out_row = (float *)(out->data[2] + i * out->linesize[2]);
/* row i of the three w*h planes inside tmpplab (offsets 0, w*h, 2*w*h) */
215 float *lcur = s->tmpplab + i * outlink->w;
216 float *acur = s->tmpplab + i * outlink->w + outlink->w * outlink->h;
217 float *bcur = s->tmpplab + i * outlink->w + 2 * outlink->w * outlink->h;
219 for (int j = 0; j < outlink->w; j++) {
/* NOTE(review): the statements that load lab[], apply the averages and
 * convert to rgb[] are not visible in this excerpt */
224 // subtract the average for the color channels
228 //convert back to linear rgb
230 r_out_row[j] = rgb[0];
231 g_out_row[j] = rgb[1];
232 b_out_row[j] = rgb[2];
238 static int config_input(AVFilterLink *inlink)
240 GrayWorldContext *s = inlink->dst->priv;
242 FF_ALLOC_TYPED_ARRAY(s->tmpplab, inlink->h * inlink->w * 3);
243 FF_ALLOC_TYPED_ARRAY(s->line_count_pels, inlink->h);
244 FF_ALLOC_TYPED_ARRAY(s->line_sum, inlink->h * 2);
245 if (!s->tmpplab || !s->line_count_pels || !s->line_sum)
246 return AVERROR(ENOMEM);
251 static av_cold void uninit(AVFilterContext *ctx)
253 GrayWorldContext *s = ctx->priv;
255 av_freep(&s->tmpplab);
256 av_freep(&s->line_count_pels);
257 av_freep(&s->line_sum);
/* Per-frame entry point of the input pad: runs the convert_frame jobs,
 * computes the averages, runs the correct_frame jobs and passes the output
 * frame on to the output link. */
260 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
262 AVFilterContext *ctx = inlink->dst;
263 GrayWorldContext *s = ctx->priv;
264 AVFilterLink *outlink = ctx->outputs[0];
/* branch on whether the input frame may be written to directly */
268 if (av_frame_is_writable(in)) {
/* request an output buffer with the output link's dimensions */
271 out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
274 return AVERROR(ENOMEM);
/* carry the input frame's properties over to the output frame */
276 av_frame_copy_props(out, in);
278 /* input and output transfer will be linear */
279 if (in->color_trc == AVCOL_TRC_UNSPECIFIED) {
280 av_log(s, AV_LOG_WARNING, "Untagged transfer, assuming linear light.\n");
281 out->color_trc = AVCOL_TRC_LINEAR;
282 } else if (in->color_trc != AVCOL_TRC_LINEAR) {
283 av_log(s, AV_LOG_WARNING, "Gray world color correction works on linear light only.\n");
/* both passes use at most one job per row, capped by the thread count */
289 ff_filter_execute(ctx, convert_frame, &td, NULL, FFMIN(outlink->h, ff_filter_get_nb_threads(ctx)));
290 compute_correction(s, &td);
291 ff_filter_execute(ctx, correct_frame, &td, NULL, FFMIN(outlink->h, ff_filter_get_nb_threads(ctx)));
/* copy plane 3 from in to out; the row width is passed in bytes (w * 4).
 * NOTE(review): presumably the alpha plane of AV_PIX_FMT_GBRAPF32; the
 * condition guarding this copy is not visible in this excerpt */
294 av_image_copy_plane(out->data[3], out->linesize[3],
295 in->data[3], in->linesize[3], outlink->w * 4, outlink->h);
299 return ff_filter_frame(outlink, out);
/* Video input pad: frames are delivered to filter_frame(), link
 * configuration goes through config_input(). */
302 static const AVFilterPad grayworld_inputs[] = {
305 .type = AVMEDIA_TYPE_VIDEO,
306 .filter_frame = filter_frame,
307 .config_props = config_input,
/* Video output pad. */
311 static const AVFilterPad grayworld_outputs[] = {
314 .type = AVMEDIA_TYPE_VIDEO,
/* Filter definition: private data is a GrayWorldContext, accepted pixel
 * formats are AV_PIX_FMT_GBRPF32 and AV_PIX_FMT_GBRAPF32, and the filter
 * declares slice threading and generic timeline support. */
318 const AVFilter ff_vf_grayworld = {
320 .description = NULL_IF_CONFIG_SMALL("Adjust white balance using LAB gray world algorithm"),
321 .priv_size = sizeof(GrayWorldContext),
322 .priv_class = &grayworld_class,
323 FILTER_INPUTS(grayworld_inputs),
324 FILTER_OUTPUTS(grayworld_outputs),
325 FILTER_PIXFMTS(AV_PIX_FMT_GBRPF32, AV_PIX_FMT_GBRAPF32),
326 .flags = AVFILTER_FLAG_SUPPORT_TIMELINE_GENERIC | AVFILTER_FLAG_SLICE_THREADS,