audiofxbasefirfilter: FFT convolution implementation
authorSebastian Dröge <sebastian.droege@collabora.co.uk>
Thu, 3 Dec 2009 16:27:13 +0000 (17:27 +0100)
committerSebastian Dröge <sebastian.droege@collabora.co.uk>
Tue, 15 Dec 2009 17:12:46 +0000 (18:12 +0100)
This provides a great speedup, especially the relationship between kernel
length and processing size is now logarithmic instead of linear. Below a
kernel size of 32 it's a bit slower, afterwards it's much faster:

17     0.788000 -> 0.950000
33     1.208000 -> 1.146000
65     2.166000 -> 1.146000
...
4097 107.444000 -> 1.508000

For sizes smaller 32 the normal time-domain convolution is chosen,
for larger sizes the FFT convolution is automatically used.

Fixes bug #594381.

gst/audiofx/Makefile.am
gst/audiofx/audiofxbasefirfilter.c
gst/audiofx/audiofxbasefirfilter.h

index 22f5fd0..bd66963 100644 (file)
@@ -29,6 +29,7 @@ libgstaudiofx_la_LIBADD = $(GST_LIBS) \
        $(GST_CONTROLLER_LIBS) \
        $(GST_PLUGINS_BASE_LIBS) \
        -lgstaudio-$(GST_MAJORMINOR) \
+       -lgstfft-$(GST_MAJORMINOR) \
        $(LIBM)
 libgstaudiofx_la_LDFLAGS = $(GST_PLUGIN_LDFLAGS)
 libgstaudiofx_la_LIBTOOLFLAGS = --tag=disable-static
index c66bd26..c75f98d 100644 (file)
@@ -59,6 +59,9 @@ GST_DEBUG_CATEGORY_STATIC (GST_CAT_DEFAULT);
   GST_DEBUG_CATEGORY_INIT (gst_audio_fx_base_fir_filter_debug, "audiofxbasefirfilter", 0, \
       "FIR filter base class");
 
+/* Switch from time-domain to FFT convolution for kernels >= this */
+#define FFT_THRESHOLD 32
+
 GST_BOILERPLATE_FULL (GstAudioFXBaseFIRFilter, gst_audio_fx_base_fir_filter,
     GstAudioFilter, GST_TYPE_AUDIO_FILTER, DEBUG_INIT);
 
@@ -68,6 +71,9 @@ static gboolean gst_audio_fx_base_fir_filter_start (GstBaseTransform * base);
 static gboolean gst_audio_fx_base_fir_filter_stop (GstBaseTransform * base);
 static gboolean gst_audio_fx_base_fir_filter_event (GstBaseTransform * base,
     GstEvent * event);
+static gboolean gst_audio_fx_base_fir_filter_transform_size (GstBaseTransform *
+    base, GstPadDirection direction, GstCaps * caps, guint size,
+    GstCaps * othercaps, guint * othersize);
 static gboolean gst_audio_fx_base_fir_filter_setup (GstAudioFilter * base,
     GstRingBufferSpec * format);
 
@@ -83,16 +89,23 @@ gst_audio_fx_base_fir_filter_dispose (GObject * object)
 {
   GstAudioFXBaseFIRFilter *self = GST_AUDIO_FX_BASE_FIR_FILTER (object);
 
-  if (self->buffer) {
-    g_free (self->buffer);
-    self->buffer = NULL;
-    self->buffer_length = 0;
-  }
+  g_free (self->buffer);
+  self->buffer = NULL;
+  self->buffer_length = 0;
 
-  if (self->kernel) {
-    g_free (self->kernel);
-    self->kernel = NULL;
-  }
+  g_free (self->kernel);
+  self->kernel = NULL;
+
+  gst_fft_f64_free (self->fft);
+  self->fft = NULL;
+  gst_fft_f64_free (self->ifft);
+  self->ifft = NULL;
+
+  g_free (self->frequency_response);
+  self->frequency_response = NULL;
+
+  g_free (self->fft_buffer);
+  self->fft_buffer = NULL;
 
   G_OBJECT_CLASS (parent_class)->dispose (object);
 }
@@ -122,6 +135,8 @@ gst_audio_fx_base_fir_filter_class_init (GstAudioFXBaseFIRFilterClass * klass)
   trans_class->start = GST_DEBUG_FUNCPTR (gst_audio_fx_base_fir_filter_start);
   trans_class->stop = GST_DEBUG_FUNCPTR (gst_audio_fx_base_fir_filter_stop);
   trans_class->event = GST_DEBUG_FUNCPTR (gst_audio_fx_base_fir_filter_event);
+  trans_class->transform_size =
+      GST_DEBUG_FUNCPTR (gst_audio_fx_base_fir_filter_transform_size);
   filter_class->setup = GST_DEBUG_FUNCPTR (gst_audio_fx_base_fir_filter_setup);
 }
 
@@ -144,6 +159,176 @@ gst_audio_fx_base_fir_filter_init (GstAudioFXBaseFIRFilter * self,
       gst_audio_fx_base_fir_filter_query_type);
 }
 
+/* This implements FFT convolution and uses the overlap-save algorithm.
+ * See http://cnx.org/content/m12022/latest/ or your favorite
+ * digital signal processing book for details.
+ *
+ * In every pass the following is calculated:
+ *
+ * y = IFFT (FFT(x) * FFT(h))
+ *
+ * where y is the output in the time domain, x the
+ * input and h the filter kernel. * is the multiplication
+ * of complex numbers.
+ *
+ * Due to the circular convolution theorem this
+ * gives in the time domain:
+ *
+ * y[t] = \sum_{u=0}^{M-1} x[t - u] * h[u]
+ *
+ * where y is the output, M is the kernel length,
+ * x the periodically extended[0] input and h the
+ * filter kernel.
+ *
+ * ([0] Periodically extended means:    )
+ * (    x[t] = x[t+kN] \forall k \in Z  )
+ * (    where N is the length of x      )
+ *
+ * This means:
+ * - Obviously x and h need to be of the same size for the FFT
+ * - The first M-1 output values are useless because they're
+ *   built from 1 up to M-1 values from the end of the input
+ *   (circular convolusion!).
+ * - The last M-1 input values are only used for 1 up to M-1
+ *   output values, i.e. they need to be used again in the
+ *   next pass for the first M-1 input values.
+ *
+ * => The first pass needs M-1 zeroes at the beginning of the
+ * input and the last M-1 input values of every pass need to
+ * be used as the first M-1 input values of the next pass.
+ *
+ * => x must be larger than h to give a useful number of output
+ * samples and h needs to be padded by zeroes at the end to give
+ * it virtually the same size as x (by M we denote the number of
+ * non-padding samples of h). If len(x)==len(h)==M only 1 output
+ * sample would be calculated per pass, len(x)==2*len(h) would
+ * give M+1 output samples, etc. Usually a factor between 4 and 8
+ * gives a low number of operations per output samples (see website
+ * given above).
+ *
+ * Overall this gives a runtime complexity per sample of
+ *
+ *   (  N log N  )
+ * O ( --------- ) compared to O (M) for the direct calculation.
+ *   ( N - M + 1 )
+ */
+#define DEFINE_FFT_PROCESS_FUNC(width,ctype) \
+static guint \
+process_fft_##width (GstAudioFXBaseFIRFilter * self, const g##ctype * src, \
+    g##ctype * dst, guint input_samples) \
+{ \
+  gint channels = GST_AUDIO_FILTER_CAST (self)->format.channels; \
+  gint i, j; \
+  guint pass; \
+  guint kernel_length = self->kernel_length; \
+  guint block_length = self->block_length; \
+  guint buffer_length = self->buffer_length; \
+  guint real_buffer_length = buffer_length + kernel_length - 1; \
+  guint buffer_fill = self->buffer_fill; \
+  GstFFTF64 *fft = self->fft; \
+  GstFFTF64 *ifft = self->ifft; \
+  GstFFTF64Complex *frequency_response = self->frequency_response; \
+  GstFFTF64Complex *fft_buffer = self->fft_buffer; \
+  guint frequency_response_length = self->frequency_response_length; \
+  gdouble *buffer = self->buffer; \
+  guint generated = 0; \
+  gdouble re, im; \
+  \
+  input_samples /= channels; \
+  \
+  if (!fft_buffer) \
+    self->fft_buffer = fft_buffer = \
+        g_new (GstFFTF64Complex, frequency_response_length); \
+  \
+  /* Buffer contains the time domain samples of input data for one chunk \
+   * plus some more space for the inverse FFT below. \
+   * \
+   * The samples are put at offset kernel_length, the inverse FFT \
+   * overwrites everthing from offset 0 to length-kernel_length+1, keeping \
+   * the last kernel_length-1 samples for copying to the next processing \
+   * step. \
+   */ \
+  if (!buffer) { \
+    self->buffer_length = buffer_length = block_length; \
+    real_buffer_length = buffer_length + kernel_length - 1; \
+    \
+    self->buffer = buffer = g_new0 (gdouble, real_buffer_length * channels); \
+    \
+    /* Beginning has kernel_length-1 zeroes at the beginning */ \
+    self->buffer_fill = buffer_fill = kernel_length - 1; \
+  } \
+  \
+  while (input_samples) { \
+    pass = MIN (buffer_length - buffer_fill, input_samples); \
+    \
+    /* Deinterleave channels */ \
+    for (i = 0; i < pass; i++) { \
+      for (j = 0; j < channels; j++) { \
+        buffer[real_buffer_length * j + buffer_fill + kernel_length - 1 + i] = \
+            src[i * channels + j]; \
+      } \
+    } \
+    buffer_fill += pass; \
+    src += channels * pass; \
+    input_samples -= channels * pass; \
+    \
+    /* If we don't have a complete buffer go out */ \
+    if (buffer_fill < buffer_length) \
+      break; \
+    \
+    for (j = 0; j < channels; j++) { \
+      /* Calculate FFT of input block */ \
+      gst_fft_f64_fft (fft, \
+          buffer + real_buffer_length * j + kernel_length - 1, fft_buffer); \
+      \
+      /* Complex multiplication of input and filter spectrum */ \
+      for (i = 0; i < frequency_response_length; i++) { \
+       re = fft_buffer[i].r; \
+       im = fft_buffer[i].i; \
+        \
+        fft_buffer[i].r = \
+            re * frequency_response[i].r - \
+            im * frequency_response[i].i; \
+        fft_buffer[i].i = \
+            re * frequency_response[i].i + \
+            im * frequency_response[i].r; \
+      } \
+      \
+      /* Calculate inverse FFT of the result */ \
+      gst_fft_f64_inverse_fft (ifft, fft_buffer, \
+          buffer + real_buffer_length * j); \
+      \
+      /* Copy all except the first kernel_length-1 samples to the output */ \
+      for (i = 0; i < buffer_length - kernel_length + 1; i++) { \
+        dst[i * channels + j] = \
+            buffer[real_buffer_length * j + kernel_length - 1 + i]; \
+      } \
+      \
+      /* Copy the last kernel_length-1 samples to the beginning for the next block */ \
+      for (i = 0; i < kernel_length - 1; i++) { \
+        buffer[real_buffer_length * j + kernel_length - 1 + i] = \
+            buffer[real_buffer_length * j + buffer_length + i]; \
+      } \
+    } \
+    \
+    generated += buffer_length - kernel_length + 1; \
+    dst += channels * (buffer_length - kernel_length + 1); \
+    \
+    /* The the first kernel_length-1 samples are there already */ \
+    buffer_fill = kernel_length - 1; \
+  } \
+  \
+  /* Write back cached buffer_fill value */ \
+  self->buffer_fill = buffer_fill; \
+  \
+  return generated; \
+}
+
+DEFINE_FFT_PROCESS_FUNC (32, float);
+DEFINE_FFT_PROCESS_FUNC (64, double);
+
+#undef DEFINE_FFT_PROCESS_FUNC
+
 /* 
  * The code below calculates the linear convolution:
  *
@@ -231,7 +416,6 @@ gst_audio_fx_base_fir_filter_push_residue (GstAudioFXBaseFIRFilter * self)
   gint channels = GST_AUDIO_FILTER_CAST (self)->format.channels;
   gint width = GST_AUDIO_FILTER_CAST (self)->format.width / 8;
   guint outsize, outsamples;
-  gint64 diffsize, diffsamples;
   guint8 *in, *out;
 
   if (channels == 0 || rate == 0 || self->nsamples_in == 0) {
@@ -252,39 +436,66 @@ gst_audio_fx_base_fir_filter_push_residue (GstAudioFXBaseFIRFilter * self)
   }
   outsize = outsamples * channels * width;
 
-  /* Process the difference between latency and residue length samples
-   * to start at the actual data instead of starting at the zeros before
-   * when we only got one buffer smaller than latency */
-
-  /* FIXME: still time domain convolution specific */
-  diffsamples =
-      ((gint64) self->latency) - ((gint64) self->buffer_fill) / channels;
-  if (diffsamples > 0) {
-    diffsize = diffsamples * channels * width;
-    in = g_new0 (guint8, diffsize);
-    out = g_new0 (guint8, diffsize);
-    self->nsamples_out += self->process (self, in, out, diffsamples * channels);
+  if (!self->fft) {
+    gint64 diffsize, diffsamples;
+
+    /* Process the difference between latency and residue length samples
+     * to start at the actual data instead of starting at the zeros before
+     * when we only got one buffer smaller than latency */
+    diffsamples =
+        ((gint64) self->latency) - ((gint64) self->buffer_fill) / channels;
+    if (diffsamples > 0) {
+      diffsize = diffsamples * channels * width;
+      in = g_new0 (guint8, diffsize);
+      out = g_new0 (guint8, diffsize);
+      self->nsamples_out +=
+          self->process (self, in, out, diffsamples * channels);
+      g_free (in);
+      g_free (out);
+    }
+
+    res = gst_pad_alloc_buffer (GST_BASE_TRANSFORM_CAST (self)->srcpad,
+        GST_BUFFER_OFFSET_NONE, outsize,
+        GST_PAD_CAPS (GST_BASE_TRANSFORM_CAST (self)->srcpad), &outbuf);
+
+    if (G_UNLIKELY (res != GST_FLOW_OK)) {
+      GST_WARNING_OBJECT (self, "failed allocating buffer of %d bytes",
+          outsize);
+      self->buffer_fill = 0;
+      return;
+    }
+
+    /* Convolve the residue with zeros to get the actual remaining data */
+    in = g_new0 (guint8, outsize);
+    self->nsamples_out +=
+        self->process (self, in, GST_BUFFER_DATA (outbuf),
+        outsamples * channels);
     g_free (in);
-    g_free (out);
-  }
+  } else {
+    guint gensamples = 0;
+    guint8 *data;
 
-  res = gst_pad_alloc_buffer (GST_BASE_TRANSFORM_CAST (self)->srcpad,
-      GST_BUFFER_OFFSET_NONE, outsize,
-      GST_PAD_CAPS (GST_BASE_TRANSFORM_CAST (self)->srcpad), &outbuf);
+    outbuf = gst_buffer_new_and_alloc (outsize);
+    data = GST_BUFFER_DATA (outbuf);
 
-  if (G_UNLIKELY (res != GST_FLOW_OK)) {
-    GST_WARNING_OBJECT (self, "failed allocating buffer of %d bytes", outsize);
-    self->buffer_fill = 0;
-    return;
-  }
+    while (gensamples < outsamples) {
+      guint step_insamples =
+          (self->block_length - self->buffer_fill) * channels;
+      guint8 *zeroes = g_new0 (guint8, step_insamples * width);
+      guint8 *out = g_new (guint8, self->block_length * channels * width);
+      guint step_gensamples;
 
-  /* Convolve the residue with zeros to get the actual remaining data */
-  in = g_new0 (guint8, outsize);
-  self->nsamples_out +=
-      self->process (self, in, GST_BUFFER_DATA (outbuf), outsamples * channels);
-  g_free (in);
+      step_gensamples = self->process (self, zeroes, out, step_insamples);
+      g_free (zeroes);
 
-  /* FIXME: time domain convolution specific */
+      memcpy (data + gensamples * width, out, MIN (step_gensamples,
+              outsamples - gensamples) * width);
+      gensamples += MIN (step_gensamples, outsamples - gensamples);
+
+      g_free (out);
+    }
+    self->nsamples_out += gensamples;
+  }
 
   /* Set timestamp, offset, etc from the values we
    * saved when processing the regular buffers */
@@ -343,18 +554,53 @@ gst_audio_fx_base_fir_filter_setup (GstAudioFilter * base,
     self->nsamples_in = 0;
   }
 
-  if (format->width == 32)
+  if (format->width == 32 && self->fft)
+    self->process = (GstAudioFXBaseFIRFilterProcessFunc) process_fft_32;
+  else if (format->width == 64 && self->fft)
+    self->process = (GstAudioFXBaseFIRFilterProcessFunc) process_fft_64;
+  else if (format->width == 32)
     self->process = (GstAudioFXBaseFIRFilterProcessFunc) process_32;
   else if (format->width == 64)
     self->process = (GstAudioFXBaseFIRFilterProcessFunc) process_64;
-  else
-    ret = FALSE;
+  ret = FALSE;
 
   return TRUE;
 }
 
 /* GstBaseTransform vmethod implementations */
 
+static gboolean
+gst_audio_fx_base_fir_filter_transform_size (GstBaseTransform * base,
+    GstPadDirection direction, GstCaps * caps, guint size, GstCaps * othercaps,
+    guint * othersize)
+{
+  GstAudioFXBaseFIRFilter *self = GST_AUDIO_FX_BASE_FIR_FILTER (base);
+  guint blocklen;
+  GstStructure *s;
+  gint width, channels;
+
+  if (!self->fft || direction == GST_PAD_SRC) {
+    *othersize = size;
+    return TRUE;
+  }
+
+  s = gst_caps_get_structure (caps, 0);
+  if (!gst_structure_get_int (s, "width", &width) ||
+      !gst_structure_get_int (s, "channels", &channels))
+    return FALSE;
+
+  width /= 8;
+
+  size /= width * channels;
+
+  blocklen = self->block_length - self->kernel_length + 1;
+  *othersize = ((size + blocklen - 1) / blocklen) * blocklen;
+
+  *othersize *= width * channels;
+
+  return TRUE;
+}
+
 static GstFlowReturn
 gst_audio_fx_base_fir_filter_transform (GstBaseTransform * base,
     GstBuffer * inbuf, GstBuffer * outbuf)
@@ -512,9 +758,13 @@ gst_audio_fx_base_fir_filter_query (GstPad * pad, GstQuery * query)
               GST_TIME_FORMAT " max %" GST_TIME_FORMAT,
               GST_TIME_ARGS (min), GST_TIME_ARGS (max));
 
+          if (self->fft)
+            latency = self->block_length - self->kernel_length + 1;
+          else
+            latency = self->latency;
+
           /* add our own latency */
-          latency =
-              gst_util_uint64_scale_round (self->latency, GST_SECOND, rate);
+          latency = gst_util_uint64_scale_round (latency, GST_SECOND, rate);
 
           GST_DEBUG_OBJECT (self, "Our latency: %"
               GST_TIME_FORMAT, GST_TIME_ARGS (latency));
@@ -576,6 +826,11 @@ void
 gst_audio_fx_base_fir_filter_set_kernel (GstAudioFXBaseFIRFilter * self,
     gdouble * kernel, guint kernel_length, guint64 latency)
 {
+  gdouble *kernel_tmp;
+  guint i;
+  gboolean latency_changed;
+  gint width;
+
   g_return_if_fail (kernel != NULL);
   g_return_if_fail (self != NULL);
 
@@ -589,16 +844,64 @@ gst_audio_fx_base_fir_filter_set_kernel (GstAudioFXBaseFIRFilter * self,
     self->buffer_fill = 0;
   }
 
+  latency_changed = (self->latency != latency
+      || (self->kernel_length < FFT_THRESHOLD && kernel_length >= FFT_THRESHOLD)
+      || (self->kernel_length >= FFT_THRESHOLD
+          && kernel_length < FFT_THRESHOLD));
+
   g_free (self->kernel);
   g_free (self->buffer);
   self->buffer = NULL;
   self->buffer_fill = 0;
   self->buffer_length = 0;
 
+  gst_fft_f64_free (self->fft);
+  self->fft = NULL;
+  gst_fft_f64_free (self->ifft);
+  self->ifft = NULL;
+  g_free (self->frequency_response);
+  self->frequency_response_length = 0;
+  g_free (self->fft_buffer);
+  self->fft_buffer = NULL;
+
   self->kernel = kernel;
   self->kernel_length = kernel_length;
 
-  if (self->latency != latency) {
+  if (kernel_length >= FFT_THRESHOLD) {
+    /* We process 4 * kernel_length samples per pass in FFT mode */
+    kernel_length = 4 * kernel_length;
+    kernel_length = gst_fft_next_fast_length (kernel_length);
+    self->block_length = kernel_length;
+
+    kernel_tmp = g_new0 (gdouble, kernel_length);
+    memcpy (kernel_tmp, kernel, self->kernel_length * sizeof (gdouble));
+
+    self->fft = gst_fft_f64_new (kernel_length, FALSE);
+    self->ifft = gst_fft_f64_new (kernel_length, TRUE);
+    self->frequency_response_length = kernel_length / 2 + 1;
+    self->frequency_response =
+        g_new (GstFFTF64Complex, self->frequency_response_length);
+    gst_fft_f64_fft (self->fft, kernel_tmp, self->frequency_response);
+    g_free (kernel_tmp);
+
+    /* Normalize to make sure IFFT(FFT(x)) == x */
+    for (i = 0; i < self->frequency_response_length; i++) {
+      self->frequency_response[i].r /= kernel_length;
+      self->frequency_response[i].i /= kernel_length;
+    }
+  }
+
+  width = GST_AUDIO_FILTER_CAST (self)->format.width;
+  if (width == 32 && self->fft)
+    self->process = (GstAudioFXBaseFIRFilterProcessFunc) process_fft_32;
+  else if (width == 64 && self->fft)
+    self->process = (GstAudioFXBaseFIRFilterProcessFunc) process_fft_64;
+  else if (width == 32)
+    self->process = (GstAudioFXBaseFIRFilterProcessFunc) process_32;
+  else if (width == 64)
+    self->process = (GstAudioFXBaseFIRFilterProcessFunc) process_64;
+
+  if (latency_changed) {
     self->latency = latency;
     gst_element_post_message (GST_ELEMENT (self),
         gst_message_new_latency (GST_OBJECT (self)));
index aa03b1c..fd3c3bd 100644 (file)
@@ -27,6 +27,7 @@
 
 #include <gst/gst.h>
 #include <gst/audio/gstaudiofilter.h>
+#include <gst/fft/gstfftf64.h>
 
 G_BEGIN_DECLS
 
@@ -54,17 +55,26 @@ typedef guint (*GstAudioFXBaseFIRFilterProcessFunc) (GstAudioFXBaseFIRFilter *,
 struct _GstAudioFXBaseFIRFilter {
   GstAudioFilter element;
 
-  /* < private > */
-  GstAudioFXBaseFIRFilterProcessFunc process;
-
+  /* properties */
   gdouble *kernel;              /* filter kernel -- time domain */
   guint kernel_length;          /* length of the filter kernel -- time domain */
 
+  guint64 latency;              /* pre-latency of the filter kernel */
+
+  /* < private > */
+  GstAudioFXBaseFIRFilterProcessFunc process;
+
   gdouble *buffer;              /* buffer for storing samples of previous buffers */
   guint buffer_fill;            /* fill level of buffer */
-  guint buffer_length;          /* length of the buffer */
-
-  guint64 latency;
+  guint buffer_length;          /* length of the buffer -- meaning depends on processing mode */
+
+  /* FFT convolution specific data */
+  GstFFTF64 *fft;
+  GstFFTF64 *ifft;
+  GstFFTF64Complex *frequency_response;  /* filter kernel -- frequency domain */
+  guint frequency_response_length;       /* length of filter kernel -- frequency domain */
+  GstFFTF64Complex *fft_buffer;          /* FFT buffer, has the length of the frequency response */
+  guint block_length;                    /* Length of the processing blocks -- time domain */
 
   GstClockTime start_ts;        /* start timestamp after a discont */
   guint64 start_off;            /* start offset after a discont */