1 /* Karatsuba convolution
3 * Copyright (C) 1999 Ralph Loader <suckfish@ihug.co.nz>
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation; either version 2 of the License, or
8 * (at your option) any later version.
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
15 * You should have received a copy of the GNU General Public License
16 * along with this program; if not, write to the Free Software
17 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
24 /* The algorithm is based on the following. For the convolution of a pair
25 * of pairs, (a,b) * (c,d) = (0, a.c, a.d+b.c, b.d), we can reduce the four
26 * multiplications to three, by the formulae a.d+b.c = (a+b).(c+d) - a.c -
27 * b.d. A similar relation enables us to compute a 2n by 2n convolution
28 * using 3 n by n convolutions, and thus a 2^n by 2^n convolution using 3^n
29 * multiplications (as opposed to the 4^n that the quadratic algorithm
32 /* For large n, this is slower than the O(n log n) that the FFT method
33 * takes, but we avoid using complex numbers, and we only have to compute
34 * one convolution, as opposed to 3 FFTs. We have good locality-of-
35 * reference as well, which will help on CPUs with tiny caches. */
37 /* E.g., for a 512 x 512 convolution, the FFT method takes 55 * 512 = 28160
38 * (real) multiplications, as opposed to 3^9 = 19683 for the Karatsuba
39 * algorithm. We actually want 257 outputs of a 256 x 512 convolution;
40 * that doesn't appear to give an easy advantage for the FFT algorithm, but
41 * for the Karatsuba algorithm, it's easy to use two 256 x 256
42 * convolutions, taking 2 x 3^8 = 13122 multiplications. [The difference
43 * is that the FFT method "wraps" the arrays, doing a 2^n x 2^n -> 2^n,
44 * while the Karatsuba algorithm pads with zeros, doing 2^n x 2^n -> 2 * 2^n
47 /* There's a big lie above, actually... for a 4x4 convolution, it's quicker
48 * to do it using 16 multiplications than the more complex Karatsuba
49 * algorithm... So the recursion bottoms out at 4x4s. This increases the
50 * number of multiplications by a factor of 16/9, but reduces the overheads
53 /* The convolution algorithm is implemented as a stack machine. We have a
54 * stack of commands, each in one of the forms "do a 2^n x 2^n
55 * convolution", or "combine these three length 2^n outputs into one
65 typedef union stack_entry_s
69 const double *left, *right;
82 #define STACK_SIZE (CONVOLVE_DEPTH * 3)
84 struct _struct_convolve_state
86 double left[CONVOLVE_BIG];
87 double right[CONVOLVE_SMALL * 3];
88 double scratch[CONVOLVE_SMALL * 3];
89 stack_entry stack[STACK_SIZE];
93 * Initialisation routine - allocates the space to work in.
94 * Returns a pointer to internal state, to be used when performing calls.
95 * On error, returns NULL.
96 * The pointer should be freed when it is finished with, by convolve_close().
101 return (convolve_state *) malloc (sizeof (convolve_state));
105 * Free the state allocated with convolve_init().
108 convolve_close (convolve_state * state)
115 convolve_4 (double *out, const double *left, const double *right)
116 /* This does a 4x4 -> 7 convolution. For what it's worth, the slightly odd
117 * ordering gives about a 1% speed up on my Pentium II. */
119 double l0, l1, l2, l3, r0, r1, r2, r3;
128 a = (l0 * r1) + (l1 * r0);
132 a = (l0 * r2) + (l1 * r1) + (l2 * r0);
137 out[3] = (l0 * r3) + (l1 * r2) + (l2 * r1) + (l3 * r0);
138 out[4] = (l1 * r3) + (l2 * r2) + (l3 * r1);
139 out[5] = (l2 * r3) + (l3 * r2);
144 convolve_run (stack_entry * top, unsigned size, double *scratch)
145 /* Interpret a stack of commands. The stack starts with two entries; the
146 * convolution to do, and an illegal entry used to mark the stack top. The
147 * size is the number of entries in each input, and must be a power of 2,
148 * and at least 8. It is OK to have out equal to left and/or right.
149 * scratch must have length 3*size. The number of stack entries needed is
150 * 3n-4 where size=2^n. */
157 /* When we get here, the stack top is always a convolve,
158 * with size > 4. So we will split it. We repeatedly split
159 * the top entry until we get to size = 4. */
162 right = top->v.right;
167 double *s_left, *s_right;
170 /* Halve the size. */
173 /* Allocate the scratch areas. */
174 s_left = scratch + size * 3;
175 /* s_right is a length 2*size buffer also used for
176 * intermediate output. */
177 s_right = scratch + size * 4;
179 /* Create the intermediate factors. */
180 for (i = 0; i < size; i++) {
181 double l = left[i] + left[i + size];
182 double r = right[i] + right[i + size];
184 s_left[i + size] = r;
188 /* Push the combine entry onto the stack. */
191 top[2].b.null = NULL;
193 /* Push the low entry onto the stack. This must be
194 * the last of the three sub-convolutions, because
195 * it may overwrite the arguments. */
196 top[1].v.left = left;
197 top[1].v.right = right;
200 /* Push the mid entry onto the stack. */
201 top[0].v.left = s_left;
202 top[0].v.right = s_right;
203 top[0].v.out = s_right;
205 /* Leave the high entry in variables. */
212 /* When we get here, the stack top is a group of 3
213 * convolves, with size = 4, followed by some combines. */
214 convolve_4 (out, left, right);
215 convolve_4 (top[0].v.out, top[0].v.left, top[0].v.right);
216 convolve_4 (top[1].v.out, top[1].v.left, top[1].v.right);
219 /* Now process combines. */
221 /* b.main is the output buffer, mid is the middle
222 * part which needs to be adjusted in place, and
223 * then folded back into the output. We do this in
224 * a slightly strange way, so as to avoid having
226 double *out = top->b.main;
227 double *mid = scratch + size * 4;
231 out[size * 2 - 1] = 0;
232 for (i = 0; i < size - 1; i++) {
236 lo = mid[0] - (out[0] + out[2 * size]) + out[size];
237 hi = mid[size] - (out[size] + out[3 * size]) + out[2 * size];
244 } while (top->b.null == NULL);
245 } while (top->b.main != NULL);
249 convolve_match (const int *lastchoice,
250 const short *input, convolve_state * state)
251 /* lastchoice is a 256 sized array. input is a 512 array. We find the
252 * contiguous length 256 sub-array of input that best matches lastchoice.
253 * A measure of how good a sub-array is compared with the lastchoice is
254 * given by the sum of the products of each pair of entries. We maximise
255 * that, by taking an appropriate convolution, and then finding the maximum
256 * entry in the convolutions. state is a (non-NULL) pointer returned by
263 double *left = state->left;
264 double *right = state->right;
265 double *scratch = state->scratch;
266 stack_entry *top = state->stack + STACK_SIZE - 1;
269 for (i = 0; i < 512; i++)
273 for (i = 0; i < 256; i++) {
274 double a = lastchoice[255 - i];
280 /* We adjust the smaller of the two input arrays to have average
281 * value 0. This makes the eventual result insensitive to both
282 * constant offsets and positive multipliers of the inputs. */
284 for (i = 0; i < 256; i++)
286 /* End-of-stack marker. */
287 #if 0 /* The following line produces a CRASH, need to figure out why?!! */
288 top[1].b.null = scratch;
290 top[1].b.main = NULL;
291 /* The low 256x256, of which we want the high 256 outputs. */
293 top->v.right = right;
294 top->v.out = right + 256;
295 convolve_run (top, 256, scratch);
297 /* The high 256x256, of which we want the low 256 outputs. */
298 top->v.left = left + 256;
299 top->v.right = right;
301 convolve_run (top, 256, scratch);
303 /* Now find the best position amongst these outputs. Apart from the first
304 * and last, the required convolution outputs are formed by adding
305 * outputs from the two convolutions above. */
309 for (i = 0; i < 256; i++) {
310 double a = right[i] + right[i + 512];
321 /* This is some debugging code... */
325 for (i = 0; i < 256; i++)
326 best += ((double) input[i + p]) * ((double) lastchoice[i] - avg);
328 for (i = 0; i < 257; i++) {
332 for (j = 0; j < 256; j++)
333 tot += ((double) input[i + j]) * ((double) lastchoice[j] - avg);
336 if (tot != left[i + 255])