Imported Upstream version 0.9.0
[platform/upstream/libjxl.git] / lib / jxl / enc_patch_dictionary.cc
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/enc_patch_dictionary.h"
7
8 #include <stdint.h>
9 #include <stdlib.h>
10 #include <sys/types.h>
11
12 #include <algorithm>
13 #include <atomic>
14 #include <string>
15 #include <tuple>
16 #include <utility>
17 #include <vector>
18
19 #include "lib/jxl/ans_params.h"
20 #include "lib/jxl/base/common.h"
21 #include "lib/jxl/base/compiler_specific.h"
22 #include "lib/jxl/base/override.h"
23 #include "lib/jxl/base/random.h"
24 #include "lib/jxl/base/status.h"
25 #include "lib/jxl/chroma_from_luma.h"
26 #include "lib/jxl/dec_cache.h"
27 #include "lib/jxl/dec_frame.h"
28 #include "lib/jxl/enc_ans.h"
29 #include "lib/jxl/enc_aux_out.h"
30 #include "lib/jxl/enc_cache.h"
31 #include "lib/jxl/enc_debug_image.h"
32 #include "lib/jxl/enc_dot_dictionary.h"
33 #include "lib/jxl/enc_frame.h"
34 #include "lib/jxl/entropy_coder.h"
35 #include "lib/jxl/frame_header.h"
36 #include "lib/jxl/image.h"
37 #include "lib/jxl/image_bundle.h"
38 #include "lib/jxl/image_ops.h"
39 #include "lib/jxl/pack_signed.h"
40 #include "lib/jxl/patch_dictionary_internal.h"
41
42 namespace jxl {
43
44 static constexpr size_t kPatchFrameReferenceId = 3;
45
46 // static
47 void PatchDictionaryEncoder::Encode(const PatchDictionary& pdic,
48                                     BitWriter* writer, size_t layer,
49                                     AuxOut* aux_out) {
50   JXL_ASSERT(pdic.HasAny());
51   std::vector<std::vector<Token>> tokens(1);
52   size_t num_ec = pdic.shared_->metadata->m.num_extra_channels;
53
54   auto add_num = [&](int context, size_t num) {
55     tokens[0].emplace_back(context, num);
56   };
57   size_t num_ref_patch = 0;
58   for (size_t i = 0; i < pdic.positions_.size();) {
59     size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx;
60     while (i < pdic.positions_.size() &&
61            pdic.positions_[i].ref_pos_idx == ref_pos_idx) {
62       i++;
63     }
64     num_ref_patch++;
65   }
66   add_num(kNumRefPatchContext, num_ref_patch);
67   size_t blend_pos = 0;
68   for (size_t i = 0; i < pdic.positions_.size();) {
69     size_t i_start = i;
70     size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx;
71     const auto& ref_pos = pdic.ref_positions_[ref_pos_idx];
72     while (i < pdic.positions_.size() &&
73            pdic.positions_[i].ref_pos_idx == ref_pos_idx) {
74       i++;
75     }
76     size_t num = i - i_start;
77     JXL_ASSERT(num > 0);
78     add_num(kReferenceFrameContext, ref_pos.ref);
79     add_num(kPatchReferencePositionContext, ref_pos.x0);
80     add_num(kPatchReferencePositionContext, ref_pos.y0);
81     add_num(kPatchSizeContext, ref_pos.xsize - 1);
82     add_num(kPatchSizeContext, ref_pos.ysize - 1);
83     add_num(kPatchCountContext, num - 1);
84     for (size_t j = i_start; j < i; j++) {
85       const PatchPosition& pos = pdic.positions_[j];
86       if (j == i_start) {
87         add_num(kPatchPositionContext, pos.x);
88         add_num(kPatchPositionContext, pos.y);
89       } else {
90         add_num(kPatchOffsetContext,
91                 PackSigned(pos.x - pdic.positions_[j - 1].x));
92         add_num(kPatchOffsetContext,
93                 PackSigned(pos.y - pdic.positions_[j - 1].y));
94       }
95       for (size_t j = 0; j < num_ec + 1; ++j, ++blend_pos) {
96         const PatchBlending& info = pdic.blendings_[blend_pos];
97         add_num(kPatchBlendModeContext, static_cast<uint32_t>(info.mode));
98         if (UsesAlpha(info.mode) &&
99             pdic.shared_->metadata->m.extra_channel_info.size() > 1) {
100           add_num(kPatchAlphaChannelContext, info.alpha_channel);
101         }
102         if (UsesClamp(info.mode)) {
103           add_num(kPatchClampContext, info.clamp);
104         }
105       }
106     }
107   }
108
109   EntropyEncodingData codes;
110   std::vector<uint8_t> context_map;
111   BuildAndEncodeHistograms(HistogramParams(), kNumPatchDictionaryContexts,
112                            tokens, &codes, &context_map, writer, layer,
113                            aux_out);
114   WriteTokens(tokens[0], codes, context_map, writer, layer, aux_out);
115 }
116
117 // static
118 void PatchDictionaryEncoder::SubtractFrom(const PatchDictionary& pdic,
119                                           Image3F* opsin) {
120   size_t num_ec = pdic.shared_->metadata->m.num_extra_channels;
121   // TODO(veluca): this can likely be optimized knowing it runs on full images.
122   for (size_t y = 0; y < opsin->ysize(); y++) {
123     float* JXL_RESTRICT rows[3] = {
124         opsin->PlaneRow(0, y),
125         opsin->PlaneRow(1, y),
126         opsin->PlaneRow(2, y),
127     };
128     for (size_t pos_idx : pdic.GetPatchesForRow(y)) {
129       const size_t blending_idx = pos_idx * (num_ec + 1);
130       const PatchPosition& pos = pdic.positions_[pos_idx];
131       const PatchReferencePosition& ref_pos =
132           pdic.ref_positions_[pos.ref_pos_idx];
133       const PatchBlendMode mode = pdic.blendings_[blending_idx].mode;
134       size_t by = pos.y;
135       size_t bx = pos.x;
136       size_t xsize = ref_pos.xsize;
137       JXL_DASSERT(y >= by);
138       JXL_DASSERT(y < by + ref_pos.ysize);
139       size_t iy = y - by;
140       size_t ref = ref_pos.ref;
141       const float* JXL_RESTRICT ref_rows[3] = {
142           pdic.shared_->reference_frames[ref].frame.color().ConstPlaneRow(
143               0, ref_pos.y0 + iy) +
144               ref_pos.x0,
145           pdic.shared_->reference_frames[ref].frame.color().ConstPlaneRow(
146               1, ref_pos.y0 + iy) +
147               ref_pos.x0,
148           pdic.shared_->reference_frames[ref].frame.color().ConstPlaneRow(
149               2, ref_pos.y0 + iy) +
150               ref_pos.x0,
151       };
152       for (size_t ix = 0; ix < xsize; ix++) {
153         for (size_t c = 0; c < 3; c++) {
154           if (mode == PatchBlendMode::kAdd) {
155             rows[c][bx + ix] -= ref_rows[c][ix];
156           } else if (mode == PatchBlendMode::kReplace) {
157             rows[c][bx + ix] = 0;
158           } else if (mode == PatchBlendMode::kNone) {
159             // Nothing to do.
160           } else {
161             JXL_UNREACHABLE("Blending mode %u not yet implemented",
162                             (uint32_t)mode);
163           }
164         }
165       }
166     }
167   }
168 }
169
170 namespace {
171
172 struct PatchColorspaceInfo {
173   float kChannelDequant[3];
174   float kChannelWeights[3];
175
176   explicit PatchColorspaceInfo(bool is_xyb) {
177     if (is_xyb) {
178       kChannelDequant[0] = 0.01615;
179       kChannelDequant[1] = 0.08875;
180       kChannelDequant[2] = 0.1922;
181       kChannelWeights[0] = 30.0;
182       kChannelWeights[1] = 3.0;
183       kChannelWeights[2] = 1.0;
184     } else {
185       kChannelDequant[0] = 20.0f / 255;
186       kChannelDequant[1] = 22.0f / 255;
187       kChannelDequant[2] = 20.0f / 255;
188       kChannelWeights[0] = 0.017 * 255;
189       kChannelWeights[1] = 0.02 * 255;
190       kChannelWeights[2] = 0.017 * 255;
191     }
192   }
193
194   float ScaleForQuantization(float val, size_t c) {
195     return val / kChannelDequant[c];
196   }
197
198   int Quantize(float val, size_t c) {
199     return truncf(ScaleForQuantization(val, c));
200   }
201
202   bool is_similar_v(const float v1[3], const float v2[3], float threshold) {
203     float distance = 0;
204     for (size_t c = 0; c < 3; c++) {
205       distance += std::fabs(v1[c] - v2[c]) * kChannelWeights[c];
206     }
207     return distance <= threshold;
208   }
209 };
210
211 std::vector<PatchInfo> FindTextLikePatches(
212     const CompressParams& cparams, const Image3F& opsin,
213     const PassesEncoderState* JXL_RESTRICT state, ThreadPool* pool,
214     AuxOut* aux_out, bool is_xyb) {
215   if (state->cparams.patches == Override::kOff) return {};
216   const auto& frame_dim = state->shared.frame_dim;
217
218   PatchColorspaceInfo pci(is_xyb);
219   float kSimilarThreshold = 0.8f;
220
221   auto is_similar_impl = [&pci](std::pair<uint32_t, uint32_t> p1,
222                                 std::pair<uint32_t, uint32_t> p2,
223                                 const float* JXL_RESTRICT rows[3],
224                                 size_t stride, float threshold) {
225     float v1[3], v2[3];
226     for (size_t c = 0; c < 3; c++) {
227       v1[c] = rows[c][p1.second * stride + p1.first];
228       v2[c] = rows[c][p2.second * stride + p2.first];
229     }
230     return pci.is_similar_v(v1, v2, threshold);
231   };
232
233   std::atomic<bool> has_screenshot_areas{false};
234   const size_t opsin_stride = opsin.PixelsPerRow();
235   const float* JXL_RESTRICT opsin_rows[3] = {opsin.ConstPlaneRow(0, 0),
236                                              opsin.ConstPlaneRow(1, 0),
237                                              opsin.ConstPlaneRow(2, 0)};
238
239   auto is_same = [&opsin_rows, opsin_stride](std::pair<uint32_t, uint32_t> p1,
240                                              std::pair<uint32_t, uint32_t> p2) {
241     for (size_t c = 0; c < 3; c++) {
242       float v1 = opsin_rows[c][p1.second * opsin_stride + p1.first];
243       float v2 = opsin_rows[c][p2.second * opsin_stride + p2.first];
244       if (std::fabs(v1 - v2) > 1e-4) {
245         return false;
246       }
247     }
248     return true;
249   };
250
251   auto is_similar = [&](std::pair<uint32_t, uint32_t> p1,
252                         std::pair<uint32_t, uint32_t> p2) {
253     return is_similar_impl(p1, p2, opsin_rows, opsin_stride, kSimilarThreshold);
254   };
255
256   constexpr int64_t kPatchSide = 4;
257   constexpr int64_t kExtraSide = 4;
258
259   // Look for kPatchSide size squares, naturally aligned, that all have the same
260   // pixel values.
261   ImageB is_screenshot_like(DivCeil(frame_dim.xsize, kPatchSide),
262                             DivCeil(frame_dim.ysize, kPatchSide));
263   ZeroFillImage(&is_screenshot_like);
264   uint8_t* JXL_RESTRICT screenshot_row = is_screenshot_like.Row(0);
265   const size_t screenshot_stride = is_screenshot_like.PixelsPerRow();
266   const auto process_row = [&](const uint32_t y, size_t /* thread */) {
267     for (uint64_t x = 0; x < frame_dim.xsize / kPatchSide; x++) {
268       bool all_same = true;
269       for (size_t iy = 0; iy < static_cast<size_t>(kPatchSide); iy++) {
270         for (size_t ix = 0; ix < static_cast<size_t>(kPatchSide); ix++) {
271           size_t cx = x * kPatchSide + ix;
272           size_t cy = y * kPatchSide + iy;
273           if (!is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) {
274             all_same = false;
275             break;
276           }
277         }
278       }
279       if (!all_same) continue;
280       size_t num = 0;
281       size_t num_same = 0;
282       for (int64_t iy = -kExtraSide; iy < kExtraSide + kPatchSide; iy++) {
283         for (int64_t ix = -kExtraSide; ix < kExtraSide + kPatchSide; ix++) {
284           int64_t cx = x * kPatchSide + ix;
285           int64_t cy = y * kPatchSide + iy;
286           if (cx < 0 || static_cast<uint64_t>(cx) >= frame_dim.xsize ||  //
287               cy < 0 || static_cast<uint64_t>(cy) >= frame_dim.ysize) {
288             continue;
289           }
290           num++;
291           if (is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) num_same++;
292         }
293       }
294       // Too few equal pixels nearby.
295       if (num_same * 8 < num * 7) continue;
296       screenshot_row[y * screenshot_stride + x] = 1;
297       has_screenshot_areas = true;
298     }
299   };
300   JXL_CHECK(RunOnPool(pool, 0, frame_dim.ysize / kPatchSide, ThreadPool::NoInit,
301                       process_row, "IsScreenshotLike"));
302
303   // TODO(veluca): also parallelize the rest of this function.
304   if (WantDebugOutput(cparams)) {
305     DumpPlaneNormalized(cparams, "screenshot_like", is_screenshot_like);
306   }
307
308   constexpr int kSearchRadius = 1;
309
310   if (!ApplyOverride(state->cparams.patches, has_screenshot_areas)) {
311     return {};
312   }
313
314   // Search for "similar enough" pixels near the screenshot-like areas.
315   ImageB is_background(frame_dim.xsize, frame_dim.ysize);
316   ZeroFillImage(&is_background);
317   Image3F background(frame_dim.xsize, frame_dim.ysize);
318   ZeroFillImage(&background);
319   constexpr size_t kDistanceLimit = 50;
320   float* JXL_RESTRICT background_rows[3] = {
321       background.PlaneRow(0, 0),
322       background.PlaneRow(1, 0),
323       background.PlaneRow(2, 0),
324   };
325   const size_t background_stride = background.PixelsPerRow();
326   uint8_t* JXL_RESTRICT is_background_row = is_background.Row(0);
327   const size_t is_background_stride = is_background.PixelsPerRow();
328   std::vector<
329       std::pair<std::pair<uint32_t, uint32_t>, std::pair<uint32_t, uint32_t>>>
330       queue;
331   size_t queue_front = 0;
332   for (size_t y = 0; y < frame_dim.ysize; y++) {
333     for (size_t x = 0; x < frame_dim.xsize; x++) {
334       if (!screenshot_row[screenshot_stride * (y / kPatchSide) +
335                           (x / kPatchSide)])
336         continue;
337       queue.push_back({{x, y}, {x, y}});
338     }
339   }
340   while (queue.size() != queue_front) {
341     std::pair<uint32_t, uint32_t> cur = queue[queue_front].first;
342     std::pair<uint32_t, uint32_t> src = queue[queue_front].second;
343     queue_front++;
344     if (is_background_row[cur.second * is_background_stride + cur.first])
345       continue;
346     is_background_row[cur.second * is_background_stride + cur.first] = 1;
347     for (size_t c = 0; c < 3; c++) {
348       background_rows[c][cur.second * background_stride + cur.first] =
349           opsin_rows[c][src.second * opsin_stride + src.first];
350     }
351     for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
352       for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
353         if (dx == 0 && dy == 0) continue;
354         int next_first = cur.first + dx;
355         int next_second = cur.second + dy;
356         if (next_first < 0 || next_second < 0 ||
357             static_cast<uint32_t>(next_first) >= frame_dim.xsize ||
358             static_cast<uint32_t>(next_second) >= frame_dim.ysize) {
359           continue;
360         }
361         if (static_cast<uint32_t>(
362                 std::abs(next_first - static_cast<int>(src.first)) +
363                 std::abs(next_second - static_cast<int>(src.second))) >
364             kDistanceLimit) {
365           continue;
366         }
367         std::pair<uint32_t, uint32_t> next{next_first, next_second};
368         if (is_similar(src, next)) {
369           if (!screenshot_row[next.second / kPatchSide * screenshot_stride +
370                               next.first / kPatchSide] ||
371               is_same(src, next)) {
372             if (!is_background_row[next.second * is_background_stride +
373                                    next.first])
374               queue.emplace_back(next, src);
375           }
376         }
377       }
378     }
379   }
380   queue.clear();
381
382   ImageF ccs;
383   Rng rng(0);
384   bool paint_ccs = false;
385   if (WantDebugOutput(cparams)) {
386     DumpPlaneNormalized(cparams, "is_background", is_background);
387     if (is_xyb) {
388       DumpXybImage(cparams, "background", background);
389     } else {
390       DumpImage(cparams, "background", background);
391     }
392     ccs = ImageF(frame_dim.xsize, frame_dim.ysize);
393     ZeroFillImage(&ccs);
394     paint_ccs = true;
395   }
396
397   constexpr float kVerySimilarThreshold = 0.03f;
398   constexpr float kHasSimilarThreshold = 0.03f;
399
400   const float* JXL_RESTRICT const_background_rows[3] = {
401       background_rows[0], background_rows[1], background_rows[2]};
402   auto is_similar_b = [&](std::pair<int, int> p1, std::pair<int, int> p2) {
403     return is_similar_impl(p1, p2, const_background_rows, background_stride,
404                            kVerySimilarThreshold);
405   };
406
407   constexpr int kMinPeak = 2;
408   constexpr int kHasSimilarRadius = 2;
409
410   std::vector<PatchInfo> info;
411
412   // Find small CC outside the "similar enough" areas, compute bounding boxes,
413   // and run heuristics to exclude some patches.
414   ImageB visited(frame_dim.xsize, frame_dim.ysize);
415   ZeroFillImage(&visited);
416   uint8_t* JXL_RESTRICT visited_row = visited.Row(0);
417   const size_t visited_stride = visited.PixelsPerRow();
418   std::vector<std::pair<uint32_t, uint32_t>> cc;
419   std::vector<std::pair<uint32_t, uint32_t>> stack;
420   for (size_t y = 0; y < frame_dim.ysize; y++) {
421     for (size_t x = 0; x < frame_dim.xsize; x++) {
422       if (is_background_row[y * is_background_stride + x]) continue;
423       cc.clear();
424       stack.clear();
425       stack.emplace_back(x, y);
426       size_t min_x = x;
427       size_t max_x = x;
428       size_t min_y = y;
429       size_t max_y = y;
430       std::pair<uint32_t, uint32_t> reference;
431       bool found_border = false;
432       bool all_similar = true;
433       while (!stack.empty()) {
434         std::pair<uint32_t, uint32_t> cur = stack.back();
435         stack.pop_back();
436         if (visited_row[cur.second * visited_stride + cur.first]) continue;
437         visited_row[cur.second * visited_stride + cur.first] = 1;
438         if (cur.first < min_x) min_x = cur.first;
439         if (cur.first > max_x) max_x = cur.first;
440         if (cur.second < min_y) min_y = cur.second;
441         if (cur.second > max_y) max_y = cur.second;
442         if (paint_ccs) {
443           cc.push_back(cur);
444         }
445         for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
446           for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
447             if (dx == 0 && dy == 0) continue;
448             int next_first = static_cast<int32_t>(cur.first) + dx;
449             int next_second = static_cast<int32_t>(cur.second) + dy;
450             if (next_first < 0 || next_second < 0 ||
451                 static_cast<uint32_t>(next_first) >= frame_dim.xsize ||
452                 static_cast<uint32_t>(next_second) >= frame_dim.ysize) {
453               continue;
454             }
455             std::pair<uint32_t, uint32_t> next{next_first, next_second};
456             if (!is_background_row[next.second * is_background_stride +
457                                    next.first]) {
458               stack.push_back(next);
459             } else {
460               if (!found_border) {
461                 reference = next;
462                 found_border = true;
463               } else {
464                 if (!is_similar_b(next, reference)) all_similar = false;
465               }
466             }
467           }
468         }
469       }
470       if (!found_border || !all_similar || max_x - min_x >= kMaxPatchSize ||
471           max_y - min_y >= kMaxPatchSize) {
472         continue;
473       }
474       size_t bpos = background_stride * reference.second + reference.first;
475       float ref[3] = {background_rows[0][bpos], background_rows[1][bpos],
476                       background_rows[2][bpos]};
477       bool has_similar = false;
478       for (size_t iy = std::max<int>(
479                static_cast<int32_t>(min_y) - kHasSimilarRadius, 0);
480            iy < std::min(max_y + kHasSimilarRadius + 1, frame_dim.ysize);
481            iy++) {
482         for (size_t ix = std::max<int>(
483                  static_cast<int32_t>(min_x) - kHasSimilarRadius, 0);
484              ix < std::min(max_x + kHasSimilarRadius + 1, frame_dim.xsize);
485              ix++) {
486           size_t opos = opsin_stride * iy + ix;
487           float px[3] = {opsin_rows[0][opos], opsin_rows[1][opos],
488                          opsin_rows[2][opos]};
489           if (pci.is_similar_v(ref, px, kHasSimilarThreshold)) {
490             has_similar = true;
491           }
492         }
493       }
494       if (!has_similar) continue;
495       info.emplace_back();
496       info.back().second.emplace_back(min_x, min_y);
497       QuantizedPatch& patch = info.back().first;
498       patch.xsize = max_x - min_x + 1;
499       patch.ysize = max_y - min_y + 1;
500       int max_value = 0;
501       for (size_t c : {1, 0, 2}) {
502         for (size_t iy = min_y; iy <= max_y; iy++) {
503           for (size_t ix = min_x; ix <= max_x; ix++) {
504             size_t offset = (iy - min_y) * patch.xsize + ix - min_x;
505             patch.fpixels[c][offset] =
506                 opsin_rows[c][iy * opsin_stride + ix] - ref[c];
507             int val = pci.Quantize(patch.fpixels[c][offset], c);
508             patch.pixels[c][offset] = val;
509             if (std::abs(val) > max_value) max_value = std::abs(val);
510           }
511         }
512       }
513       if (max_value < kMinPeak) {
514         info.pop_back();
515         continue;
516       }
517       if (paint_ccs) {
518         float cc_color = rng.UniformF(0.5, 1.0);
519         for (std::pair<uint32_t, uint32_t> p : cc) {
520           ccs.Row(p.second)[p.first] = cc_color;
521         }
522       }
523     }
524   }
525
526   if (paint_ccs) {
527     JXL_ASSERT(WantDebugOutput(cparams));
528     DumpPlaneNormalized(cparams, "ccs", ccs);
529   }
530   if (info.empty()) {
531     return {};
532   }
533
534   // Remove duplicates.
535   constexpr size_t kMinPatchOccurrences = 2;
536   std::sort(info.begin(), info.end());
537   size_t unique = 0;
538   for (size_t i = 1; i < info.size(); i++) {
539     if (info[i].first == info[unique].first) {
540       info[unique].second.insert(info[unique].second.end(),
541                                  info[i].second.begin(), info[i].second.end());
542     } else {
543       if (info[unique].second.size() >= kMinPatchOccurrences) {
544         unique++;
545       }
546       info[unique] = info[i];
547     }
548   }
549   if (info[unique].second.size() >= kMinPatchOccurrences) {
550     unique++;
551   }
552   info.resize(unique);
553
554   size_t max_patch_size = 0;
555
556   for (size_t i = 0; i < info.size(); i++) {
557     size_t pixels = info[i].first.xsize * info[i].first.ysize;
558     if (pixels > max_patch_size) max_patch_size = pixels;
559   }
560
561   // don't use patches if all patches are smaller than this
562   constexpr size_t kMinMaxPatchSize = 20;
563   if (max_patch_size < kMinMaxPatchSize) return {};
564
565   return info;
566 }
567
568 }  // namespace
569
570 void FindBestPatchDictionary(const Image3F& opsin,
571                              PassesEncoderState* JXL_RESTRICT state,
572                              const JxlCmsInterface& cms, ThreadPool* pool,
573                              AuxOut* aux_out, bool is_xyb) {
574   std::vector<PatchInfo> info =
575       FindTextLikePatches(state->cparams, opsin, state, pool, aux_out, is_xyb);
576
577   // TODO(veluca): this doesn't work if both dots and patches are enabled.
578   // For now, since dots and patches are not likely to occur in the same kind of
579   // images, disable dots if some patches were found.
580   if (info.empty() &&
581       ApplyOverride(
582           state->cparams.dots,
583           state->cparams.speed_tier <= SpeedTier::kSquirrel &&
584               state->cparams.butteraugli_distance >= kMinButteraugliForDots)) {
585     info = FindDotDictionary(state->cparams, opsin, state->shared.cmap, pool);
586   }
587
588   if (info.empty()) return;
589
590   std::sort(
591       info.begin(), info.end(), [&](const PatchInfo& a, const PatchInfo& b) {
592         return a.first.xsize * a.first.ysize > b.first.xsize * b.first.ysize;
593       });
594
595   size_t max_x_size = 0;
596   size_t max_y_size = 0;
597   size_t total_pixels = 0;
598
599   for (size_t i = 0; i < info.size(); i++) {
600     size_t pixels = info[i].first.xsize * info[i].first.ysize;
601     if (max_x_size < info[i].first.xsize) max_x_size = info[i].first.xsize;
602     if (max_y_size < info[i].first.ysize) max_y_size = info[i].first.ysize;
603     total_pixels += pixels;
604   }
605
606   // Bin-packing & conversion of patches.
607   constexpr float kBinPackingSlackness = 1.05f;
608   size_t ref_xsize = std::max<float>(max_x_size, std::sqrt(total_pixels));
609   size_t ref_ysize = std::max<float>(max_y_size, std::sqrt(total_pixels));
610   std::vector<std::pair<size_t, size_t>> ref_positions(info.size());
611   // TODO(veluca): allow partial overlaps of patches that have the same pixels.
612   size_t max_y = 0;
613   do {
614     max_y = 0;
615     // Increase packed image size.
616     ref_xsize = ref_xsize * kBinPackingSlackness + 1;
617     ref_ysize = ref_ysize * kBinPackingSlackness + 1;
618
619     ImageB occupied(ref_xsize, ref_ysize);
620     ZeroFillImage(&occupied);
621     uint8_t* JXL_RESTRICT occupied_rows = occupied.Row(0);
622     size_t occupied_stride = occupied.PixelsPerRow();
623
624     bool success = true;
625     // For every patch...
626     for (size_t patch = 0; patch < info.size(); patch++) {
627       size_t x0 = 0;
628       size_t y0 = 0;
629       size_t xsize = info[patch].first.xsize;
630       size_t ysize = info[patch].first.ysize;
631       bool found = false;
632       // For every possible start position ...
633       for (; y0 + ysize <= ref_ysize; y0++) {
634         x0 = 0;
635         for (; x0 + xsize <= ref_xsize; x0++) {
636           bool has_occupied_pixel = false;
637           size_t x = x0;
638           // Check if it is possible to place the patch in this position in the
639           // reference frame.
640           for (size_t y = y0; y < y0 + ysize; y++) {
641             x = x0;
642             for (; x < x0 + xsize; x++) {
643               if (occupied_rows[y * occupied_stride + x]) {
644                 has_occupied_pixel = true;
645                 break;
646               }
647             }
648           }  // end of positioning check
649           if (!has_occupied_pixel) {
650             found = true;
651             break;
652           }
653           x0 = x;  // Jump to next pixel after the occupied one.
654         }
655         if (found) break;
656       }  // end of start position checking
657
658       // We didn't find a possible position: repeat from the beginning with a
659       // larger reference frame size.
660       if (!found) {
661         success = false;
662         break;
663       }
664
665       // We found a position: mark the corresponding positions in the reference
666       // image as used.
667       ref_positions[patch] = {x0, y0};
668       for (size_t y = y0; y < y0 + ysize; y++) {
669         for (size_t x = x0; x < x0 + xsize; x++) {
670           occupied_rows[y * occupied_stride + x] = true;
671         }
672       }
673       max_y = std::max(max_y, y0 + ysize);
674     }
675
676     if (success) break;
677   } while (true);
678
679   JXL_ASSERT(ref_ysize >= max_y);
680
681   ref_ysize = max_y;
682
683   Image3F reference_frame(ref_xsize, ref_ysize);
684   // TODO(veluca): figure out a better way to fill the image.
685   ZeroFillImage(&reference_frame);
686   std::vector<PatchPosition> positions;
687   std::vector<PatchReferencePosition> pref_positions;
688   std::vector<PatchBlending> blendings;
689   float* JXL_RESTRICT ref_rows[3] = {
690       reference_frame.PlaneRow(0, 0),
691       reference_frame.PlaneRow(1, 0),
692       reference_frame.PlaneRow(2, 0),
693   };
694   size_t ref_stride = reference_frame.PixelsPerRow();
695   size_t num_ec = state->shared.metadata->m.num_extra_channels;
696
697   for (size_t i = 0; i < info.size(); i++) {
698     PatchReferencePosition ref_pos;
699     ref_pos.xsize = info[i].first.xsize;
700     ref_pos.ysize = info[i].first.ysize;
701     ref_pos.x0 = ref_positions[i].first;
702     ref_pos.y0 = ref_positions[i].second;
703     ref_pos.ref = kPatchFrameReferenceId;
704     for (size_t y = 0; y < ref_pos.ysize; y++) {
705       for (size_t x = 0; x < ref_pos.xsize; x++) {
706         for (size_t c = 0; c < 3; c++) {
707           ref_rows[c][(y + ref_pos.y0) * ref_stride + x + ref_pos.x0] =
708               info[i].first.fpixels[c][y * ref_pos.xsize + x];
709         }
710       }
711     }
712     for (const auto& pos : info[i].second) {
713       positions.emplace_back(
714           PatchPosition{pos.first, pos.second, pref_positions.size()});
715       // Add blending for color channels, ignore other channels.
716       blendings.push_back({PatchBlendMode::kAdd, 0, false});
717       for (size_t j = 0; j < num_ec; ++j) {
718         blendings.push_back({PatchBlendMode::kNone, 0, false});
719       }
720     }
721     pref_positions.emplace_back(std::move(ref_pos));
722   }
723
724   CompressParams cparams = state->cparams;
725   // Recursive application of patches could create very weird issues.
726   cparams.patches = Override::kOff;
727
728   RoundtripPatchFrame(&reference_frame, state, kPatchFrameReferenceId, cparams,
729                       cms, pool, aux_out, /*subtract=*/true);
730
731   // TODO(veluca): this assumes that applying patches is commutative, which is
732   // not true for all blending modes. This code only produces kAdd patches, so
733   // this works out.
734   PatchDictionaryEncoder::SetPositions(
735       &state->shared.image_features.patches, std::move(positions),
736       std::move(pref_positions), std::move(blendings));
737 }
738
739 void RoundtripPatchFrame(Image3F* reference_frame,
740                          PassesEncoderState* JXL_RESTRICT state, int idx,
741                          CompressParams& cparams, const JxlCmsInterface& cms,
742                          ThreadPool* pool, AuxOut* aux_out, bool subtract) {
743   FrameInfo patch_frame_info;
744   cparams.resampling = 1;
745   cparams.ec_resampling = 1;
746   cparams.dots = Override::kOff;
747   cparams.noise = Override::kOff;
748   cparams.modular_mode = true;
749   cparams.responsive = 0;
750   cparams.progressive_dc = 0;
751   cparams.progressive_mode = false;
752   cparams.qprogressive_mode = false;
753   // Use gradient predictor and not Predictor::Best.
754   cparams.options.predictor = Predictor::Gradient;
755   patch_frame_info.save_as_reference = idx;  // always saved.
756   patch_frame_info.frame_type = FrameType::kReferenceOnly;
757   patch_frame_info.save_before_color_transform = true;
758   ImageBundle ib(&state->shared.metadata->m);
759   // TODO(veluca): metadata.color_encoding is a lie: ib is in XYB, but there is
760   // no simple way to express that yet.
761   patch_frame_info.ib_needs_color_transform = false;
762   ib.SetFromImage(std::move(*reference_frame),
763                   state->shared.metadata->m.color_encoding);
764   if (!ib.metadata()->extra_channel_info.empty()) {
765     // Add placeholder extra channels to the patch image: patch encoding does
766     // not yet support extra channels, but the codec expects that the amount of
767     // extra channels in frames matches that in the metadata of the codestream.
768     std::vector<ImageF> extra_channels;
769     extra_channels.reserve(ib.metadata()->extra_channel_info.size());
770     for (size_t i = 0; i < ib.metadata()->extra_channel_info.size(); i++) {
771       extra_channels.emplace_back(ib.xsize(), ib.ysize());
772       // Must initialize the image with data to not affect blending with
773       // uninitialized memory.
774       // TODO(lode): patches must copy and use the real extra channels instead.
775       ZeroFillImage(&extra_channels.back());
776     }
777     ib.SetExtraChannels(std::move(extra_channels));
778   }
779   PassesEncoderState roundtrip_state;
780   auto special_frame = std::unique_ptr<BitWriter>(new BitWriter());
781   AuxOut patch_aux_out;
782   JXL_CHECK(EncodeFrame(cparams, patch_frame_info, state->shared.metadata, ib,
783                         &roundtrip_state, cms, pool, special_frame.get(),
784                         aux_out ? &patch_aux_out : nullptr));
785   if (aux_out) {
786     for (const auto& l : patch_aux_out.layers) {
787       aux_out->layers[kLayerDictionary].Assimilate(l);
788     }
789   }
790   const Span<const uint8_t> encoded = special_frame->GetSpan();
791   state->special_frames.emplace_back(std::move(special_frame));
792   if (subtract) {
793     ImageBundle decoded(&state->shared.metadata->m);
794     PassesDecoderState dec_state;
795     JXL_CHECK(dec_state.output_encoding_info.SetFromMetadata(
796         *state->shared.metadata));
797     const uint8_t* frame_start = encoded.data();
798     size_t encoded_size = encoded.size();
799     JXL_CHECK(DecodeFrame(&dec_state, pool, frame_start, encoded_size, &decoded,
800                           *state->shared.metadata));
801     frame_start += decoded.decoded_bytes();
802     encoded_size -= decoded.decoded_bytes();
803     size_t ref_xsize =
804         dec_state.shared_storage.reference_frames[idx].frame.color()->xsize();
805     // if the frame itself uses patches, we need to decode another frame
806     if (!ref_xsize) {
807       JXL_CHECK(DecodeFrame(&dec_state, pool, frame_start, encoded_size,
808                             &decoded, *state->shared.metadata));
809     }
810     JXL_CHECK(encoded_size == 0);
811     state->shared.reference_frames[idx] =
812         std::move(dec_state.shared_storage.reference_frames[idx]);
813   } else {
814     state->shared.reference_frames[idx].frame = std::move(ib);
815   }
816 }
817
818 }  // namespace jxl