speech-recognition: started working on dictionary switching support.
author Krisztian Litkey <krisztian.litkey@intel.com>
Thu, 6 Jun 2013 13:40:46 +0000 (16:40 +0300)
committer Krisztian Litkey <krisztian.litkey@intel.com>
Thu, 6 Jun 2013 13:53:21 +0000 (16:53 +0300)
src/daemon/recognizer.c
src/daemon/recognizer.h
src/plugins/simple-disambiguator/disambiguator.c

index af0d0f8..a959af2 100644 (file)
  */
 
 typedef struct {
-    srs_context_t   *srs;                /* main context */
-    char            *name;               /* recognizer name */
-    mrp_list_hook_t  hook;               /* to list of recognizers */
-    srs_srec_api_t   api;                /* backend API */
-    void            *api_data;           /* opaque backend data */
+    srs_context_t     *srs;              /* main context */
+    char              *name;             /* recognizer name */
+    mrp_list_hook_t    hook;             /* to list of recognizers */
+    srs_srec_api_t     api;              /* backend API */
+    void              *api_data;         /* opaque backend data */
+    srs_srec_result_t *result;           /* result being processed, if any */
 } srs_srec_t;
 
 
@@ -210,19 +211,71 @@ static srs_srec_t *find_srec(srs_context_t *srs, const char *name)
 }
 
 
+static void free_match(srs_srec_match_t *m)  /* release a single match entry */
+{
+    mrp_free(m);  /* NOTE(review): m->tokens is not released here; confirm who owns it */
+}
+
+
 static void free_srec_result(srs_srec_result_t *res)
 {
+    srs_srec_match_t *m;
+    mrp_list_hook_t  *p, *n;
+
+    switch (res->type) {  /* release the type-specific payload first */
+    case SRS_SREC_RESULT_MATCH:
+        mrp_list_foreach(&res->result.matches, p, n) {
+            m = mrp_list_entry(p, typeof(*m), hook);
+            mrp_list_delete(&m->hook);  /* unlink the entry before freeing it */
+            free_match(m);
+        }
+        break;
+
+    case SRS_SREC_RESULT_DICT:
+        break;  /* nothing is released for dictionary results */
+
+    case SRS_SREC_RESULT_AMBIGUOUS:
+        break;  /* nothing is released for ambiguous results */
+
+    default:
+        break;
+    }  /* NOTE(review): res->tokens and res->dicts are not released here; confirm ownership */
+
     mrp_list_delete(&res->hook);
     mrp_free(res);
 }
 
 
+static void process_match_result(srs_srec_t *srec, srs_srec_result_t *res)  /* notify clients of matched commands */
+{  /* srec is not used in this body */
+    mrp_list_hook_t  *p, *n;
+    srs_srec_match_t *match;
+
+    mrp_list_foreach(&res->result.matches, p, n) {  /* one notification per match entry */
+        match = mrp_list_entry(p, typeof(*match), hook);
+        client_notify_command(match->client, match->index);
+    }
+}
+
+
+static void process_dict_result(srs_srec_t *srec, srs_srec_result_t *res)  /* handler for dictionary results */
+{
+    printf("*** should process dictionary operation ***\n");  /* only prints a marker; srec and res are unused */
+    return;
+}
+
+
+static void process_ambiguity(srs_srec_t *srec, srs_srec_result_t *res)  /* handler for ambiguous results */
+{
+    return;  /* currently a no-op; srec and res are unused */
+}
+
+
 static int srec_notify_cb(srs_srec_utterance_t *utt, void *notify_data)
 {
     srs_srec_t           *srec = (srs_srec_t *)notify_data;
     srs_disamb_t         *dis;
     srs_srec_candidate_t *c;
-    mrp_list_hook_t       results, *p, *n;
     srs_srec_result_t    *res;
     srs_srec_token_t     *t;
     int                   i, j;
@@ -241,22 +294,33 @@ static int srec_notify_cb(srs_srec_utterance_t *utt, void *notify_data)
     dis = find_disamb(srec->srs, SRS_DEFAULT_DISAMBIGUATOR);
 
     if (dis != NULL) {
-        mrp_list_init(&results);
+        res = srec->result;
 
-        if (dis->api.disambiguate(utt, &results, dis->api_data) == 0) {
+        if (dis->api.disambiguate(utt, &res, dis->api_data) == 0 && res) {
             mrp_log_info("Disambiguation succeeded.");
 
-            mrp_list_foreach(&results, p, n) {
-                res = mrp_list_entry(p, typeof(*res), hook);
+            switch (res->type) {
+            case SRS_SREC_RESULT_MATCH:
+                process_match_result(srec, res);
+                break;
 
-                client_notify_command(res->client, res->index);
+            case SRS_SREC_RESULT_DICT:
+                process_dict_result(srec, res);
+                break;
 
-                free_srec_result(res);
+            case SRS_SREC_RESULT_AMBIGUOUS:
+                process_ambiguity(srec, res);
+                break;
+
+            default:
+                break;
             }
+
+            free_srec_result(res);
         }
     }
 
-    return -1;
+    return SRS_SREC_FLUSH_ALL;
 }
 
 
index ef61977..090ca3d 100644 (file)
@@ -58,11 +58,14 @@ typedef struct {
     /** Schedule a rescan of the given portion of the audio buffer. */
     int (*rescan)(uint32_t start, uint32_t end, void *user_data);
     /** Get a copy of the audio samples in the buffer. */
-    void *(*sampledup)(uint32_t start, uint32_t end, void *user_data);
+    void *(*sampledup)(uint32_t start, uint32_t end, size_t *size,
+                       void *user_data);
     /** Check if the given language model exists/is usable. */
     int (*check_decoder)(const char *decoder, void *user_data);
     /** Set language model to be used. */
     int (*select_decoder)(const char *decoder, void *user_data);
+    /** Get the used language model. */
+    const char *(*get_decoder)(void *user_data);
 } srs_srec_api_t;
 
 /*
@@ -138,32 +141,39 @@ typedef enum {
     SRS_DISAMB_AMBIGUOUS,                /* failed to (fully) disambiguate */
 } srs_disamb_type_t;
 
+typedef enum {                           /* type tag stored in srs_srec_result_s.type */
+    SRS_SREC_RESULT_UNKNOWN = 0,         /* unknown result */
+    SRS_SREC_RESULT_MATCH,               /* full command match */
+    SRS_SREC_RESULT_DICT,                /* dictionary switch required */
+    SRS_SREC_RESULT_AMBIGUOUS,           /* further disambiguation needed */
+} srs_srec_result_type_t;
+
 typedef struct {
-    srs_disamb_type_t type;
-    union {
-        struct {
-            srs_client_t *clients;
-            int           indices;
-            int           nclient;
-        } match;
-        struct {
-            uint32_t      flush_start;
-            uint32_t      flush_end;
-        } rescan;
-        struct {
-            srs_client_t *clients;
-            int           indices;
-            int           nclient;
-        } ambiguity;
-    } result;
-} srs_disamb_result_t;
+    mrp_list_hook_t   hook;              /* to more commands */
+    srs_client_t     *client;            /* actual client */
+    int               index;             /* client command index */
+    double            score;             /* backend score */
+    int               fuzz;              /* disambiguation fuzz */
+    char            **tokens;            /* command tokens */
+} srs_srec_match_t;
 
 struct srs_srec_result_s {
-    srs_client_t    *client;             /* client */
-    int              index;              /* command index */
-    double           score;              /* recognition backend score */
-    int              fuzz;               /* disambiguation fuzz */
-    mrp_list_hook_t  hook;               /* to more results */
+    srs_srec_result_type_t   type;       /* result type */
+    mrp_list_hook_t          hook;       /* to list of results */
+    char                   **tokens;     /* matched tokens */
+    int                      ntoken;     /* number of tokens */
+    char                   **dicts;      /* dictionary stack */
+    int                      ndict;      /* stack depth */
+
+    union {                              /* type specific data */
+        mrp_list_hook_t    matches;      /* full match(es): list of srs_srec_match_t */
+        struct {                         /* dictionary operation details */
+            srs_dict_op_t  op;           /* push/pop/switch */
+            char          *dict;         /* dictionary for switch/push */
+            int            rescan;       /* rescan starting at this token */
+            void          *state;        /* disambiguator continuation */
+        } dict;                          /* filled in for SRS_SREC_RESULT_DICT results */
+    } result;                            /* matches and dict share storage */
 };
 
 
@@ -177,7 +187,7 @@ typedef struct {
     /** Unregister the commands of a client. */
     void (*del_client)(srs_client_t *client, void *api_data);
     /** Disambiguate an utterance with candidates. */
-    int (*disambiguate)(srs_srec_utterance_t *utt, mrp_list_hook_t *results,
+    int (*disambiguate)(srs_srec_utterance_t *utt, srs_srec_result_t **result,
                         void *api_data);
 } srs_disamb_api_t;
 
index aa7ae50..f92334e 100644 (file)
@@ -116,8 +116,11 @@ static srs_dict_op_t parse_dictionary(const char *tkn, char *dict, size_t size)
 static node_t *get_token_node(node_t *prnt, const char *token, int insert)
 {
     mrp_list_hook_t *p, *n;
-    node_t          *node;
+    node_t          *node, *any;
+    int              cnt;
 
+    cnt = 0;
+    any = NULL;
     mrp_list_foreach(&prnt->children, p, n) {
         node = mrp_list_entry(p, typeof(*node), hook);
 
@@ -130,10 +133,33 @@ static node_t *get_token_node(node_t *prnt, const char *token, int insert)
             mrp_debug("found token node %s", token);
             return node;
         }
+
+        if (!strcmp(node->data.token, SRS_TOKEN_WILDCARD))
+            any = node;
+
+        cnt++;
     }
 
+    /*
+     * wildcard node matches all tokens but only for pure lookups
+     */
+
     if (!insert) {
-        errno = ENOENT;
+        if (any != NULL)
+            return any;
+        else {
+            errno = ENOENT;
+            return NULL;
+        }
+    }
+
+    /*
+     * a wildcard node must be the only child of its parent
+     */
+
+    if (any != NULL || (cnt > 0 && !strcmp(token, SRS_TOKEN_WILDCARD))) {
+        mrp_log_error("Wildcard/non-wildcard token conflict.");
+        errno = EILSEQ;
         return NULL;
     }
 
@@ -168,11 +194,19 @@ static node_t *get_dictionary_node(node_t *prnt, const char *token, int insert)
     mrp_list_hook_t *p, *n;
     node_t          *node;
 
-    op = parse_dictionary(token, dict, sizeof(dict));
+    if (token != NULL) {
+        op = parse_dictionary(token, dict, sizeof(dict));
 
-    if (op == SRS_DICT_OP_UNKNOWN) {
-        errno = EINVAL;
-        return NULL;
+        if (op == SRS_DICT_OP_UNKNOWN) {
+            errno = EINVAL;
+            return NULL;
+        }
+    }
+    else {
+        if (insert) {
+            errno = EINVAL;
+            return NULL;
+        }
     }
 
     if (prnt->type != NODE_TYPE_TOKEN) {
@@ -183,6 +217,10 @@ static node_t *get_dictionary_node(node_t *prnt, const char *token, int insert)
     mrp_list_foreach(&prnt->children, p, n) {
         node = mrp_list_entry(p, typeof(*node), hook);
 
+        if (!insert && token == NULL)
+            if (node->type == NODE_TYPE_DICTIONARY)
+                return node;
+
         if (node->type != NODE_TYPE_DICTIONARY || node->data.dict.op != op ||
             strcmp(dict, node->data.dict.dict) != 0) {
             errno = EILSEQ;
@@ -457,63 +495,126 @@ static void disamb_del_client(srs_client_t *client, void *api_data)
 }
 
 
-static int disambiguate(srs_srec_utterance_t *utt, mrp_list_hook_t *results,
+static int disambiguate(srs_srec_utterance_t *utt, srs_srec_result_t **result,
                         void *api_data)
 {
     disamb_t             *dis = (disamb_t *)api_data;
     srs_srec_candidate_t *src;
     srs_srec_result_t    *res;
+    srs_srec_match_t     *m;
     const char           *tkn;
     mrp_list_hook_t      *p, *n;
-    node_t               *node, *child;
-    int                   i, j, match;
-
-    mrp_list_init(results);
+    node_t               *node, *child, *prnt;
+    int                   i, j, end, match;
 
     mrp_debug("should disambiguate utterance %p", utt);
 
-    for (i = 0; i < (int)utt->ncand; i++) {
-        src  = utt->cands[i];
+    /* XXX handling multiple candidates currently not implemented */
+    if (utt->ncand > 1) {
+        mrp_log_error("handling multiple candidates not implemented");
+        return -1;
+    }
+
+    src = utt->cands[0];
+    res = *result;
+
+    if (res != NULL) {
+        if (res->type == SRS_SREC_RESULT_DICT)
+            node = res->result.dict.state;
+    }
+    else {
         node = dis->root;
+        res  = mrp_allocz(sizeof(*res));
 
-        for (j = 0, match = TRUE; j < (int)src->ntoken && match; j++) {
-            tkn = src->tokens[j].token;
+        if (res == NULL)
+            return -1;
 
-            node = get_token_node(node, tkn, FALSE);
+        mrp_list_init(&res->hook);
+        mrp_list_init(&res->result.matches);
+    }
 
-            if (node == NULL)
-                match = FALSE;
+    for (i = 0, match = TRUE; i < (int)src->ntoken && match; i++) {
+        tkn = src->tokens[i].token;
+
+        prnt = node;
+        node = get_token_node(prnt, tkn, FALSE);
+
+        if (node == NULL)
+            node = get_token_node(prnt, SRS_TOKEN_WILDCARD, FALSE);
+
+        if (node == NULL) {
+            node = get_dictionary_node(prnt, NULL, FALSE);
+
+            if (node != NULL) {
+                printf("*** found dictionary node %s ***\n",
+                       node->data.dict.dict);
+
+                res->type = SRS_SREC_RESULT_DICT;
+                res->result.dict.op    = node->data.dict.op;
+                res->result.dict.dict  = node->data.dict.dict;
+                res->result.dict.state = node;
+
+                *result = res;
+
+                return 0;
+            }
             else
-                mrp_debug("found matching node for %s", tkn);
+                match = FALSE;
         }
+        else {
+            mrp_debug("found matching node for %s", tkn);
 
-        if (match) {
-            mrp_list_foreach(&node->children, p, n) {
-                child = mrp_list_entry(p, typeof(*child), hook);
+            if (strcmp(node->data.token, SRS_TOKEN_WILDCARD))
+                end = i;
+            else
+                end = (int)src->ntoken - 1;
 
-                if (child->type != NODE_TYPE_CLIENT) {
-                    mrp_log_error("Unexpected non-client node type 0x%x.",
-                                  node->type);
-                    continue;
-                }
+            for (j = i; j <= end; j++) {
+                if (mrp_reallocz(res->tokens, res->ntoken, res->ntoken + 1)) {
+                    res->tokens[res->ntoken] = mrp_strdup(tkn);
 
-                res = mrp_allocz(sizeof(*res));
+                    if (res->tokens[res->ntoken] == NULL)
+                        return -1;
+                    else
+                        res->ntoken++;
+                }
+            }
 
-                if (res == NULL)
-                    return -1;
+            i = end;
+        }
+    }
 
-                mrp_list_init(&res->hook);
-                res->client = child->data.client.client;
-                res->index  = child->data.client.index;
-                res->fuzz   = 0;
-                res->score  = src->score;
+    if (match && i == (int)src->ntoken) {
+        res->type = SRS_SREC_RESULT_MATCH;
 
-                mrp_list_append(results, &res->hook);
+        mrp_list_foreach(&node->children, p, n) {
+            child = mrp_list_entry(p, typeof(*child), hook);
 
-                mrp_log_info("Found matching command %s/#%d.",
-                             res->client->id, res->index);
+            if (child->type != NODE_TYPE_CLIENT) {
+                mrp_log_error("Unexpected non-client node type 0x%x.",
+                              node->type);
+                continue;
             }
+
+            m = mrp_allocz(sizeof(*m));
+
+            if (m == NULL)
+                return -1;
+
+            mrp_list_init(&m->hook);
+            m->client = child->data.client.client;
+            m->index  = child->data.client.index;
+            m->score  = src->score;
+            m->fuzz   = 0;
+            m->tokens = NULL;
+
+            mrp_list_append(&res->result.matches, &m->hook);
+
+            mrp_log_info("Found matching command %s/#%d.",
+                         m->client->id, m->index);
         }
+
+        *result = res;
     }
 
     return 0;