diff options
Diffstat (limited to 'src/aho-corasick.c')
| -rw-r--r-- | src/aho-corasick.c | 317 |
1 files changed, 317 insertions, 0 deletions
diff --git a/src/aho-corasick.c b/src/aho-corasick.c new file mode 100644 index 0000000..c69fafd --- /dev/null +++ b/src/aho-corasick.c @@ -0,0 +1,317 @@ +// TODO ac_match_range() -- match within a range of a buffer, file, ... +// TODO streaming support -- match over a socket +// TODO ac_strerror() -- report errors better + +// Aho-Corasick searching algorithm implementation +// https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm + +#include <stdio.h> +#include <stdlib.h> +#include <stdbool.h> +#include <string.h> +#include <unistd.h> +#include <fcntl.h> +#include <errno.h> + +#include "aho-corasick.h" + +struct pattern_node { + ac_node_t *children[AC_ALPHABET_SIZE]; + ac_node_t *fail; + ac_pattern_t *matches; +}; + +struct pattern_entry { + const char *id; + uint8_t *pattern; + size_t len; + ac_pattern_t *next; +}; + +struct ac_context { + ac_node_t *root; +}; + +static ac_node_t *node_new(void) { + return calloc(1, sizeof(ac_node_t)); +} + +ac_context_t *ac_new(void) { + ac_context_t *ctx = calloc(1, sizeof(ac_context_t)); + + ctx->root = node_new(); + + return ctx; +} + +static void free_node(ac_node_t *node) { + if (!node) { + return; + } + + for (int i = 0; i < AC_ALPHABET_SIZE; i++) { + free_node(node->children[i]); + } + + ac_pattern_t *p = node->matches; + while(p) { + ac_pattern_t *next = p->next; + free(p->pattern); + free(p); + p = next; + } + + free(node); +} + +void ac_free(ac_context_t *ctx) { + if (!ctx) { + return; + } + + free_node(ctx->root); + free(ctx); +} + +int ac_add_pattern(ac_context_t *ctx, const char *id, const uint8_t *pattern, size_t len) { + if (!ctx || !id || !pattern || len == 0) { + return -1; + } + + ac_node_t *node = ctx->root; + + for (size_t i = 0; i < len; i++) { + uint8_t c = pattern[i]; + + if (!node->children[c]) { + node->children[c] = node_new(); + } + + node = node->children[c]; + } + + for (ac_pattern_t *p = node->matches; p; p = p->next) { + if (p->len == len && + memcmp(p->pattern, pattern, len) == 0 && + strcmp(p->id, id) == 0) { + return 0; // duplicate pattern. don't add. + } + } + + ac_pattern_t *entry = malloc(sizeof(*entry)); + entry->id = id; + entry->len = len; + entry->pattern = malloc(len); + if (!entry->pattern) { + return -1; + } else { + memcpy(entry->pattern, pattern, len); + } + entry->next = node->matches; + node->matches = entry; + + return 0; +} + +struct node_queue { + ac_node_t **data; + size_t head; + size_t tail; + size_t cap; +}; + +static struct node_queue *queue_new(size_t cap) { + struct node_queue *q = calloc(1, sizeof(*q)); + + if (cap == 0) { + cap = 64; + } + + q->data = malloc(sizeof(void *) * cap); + q->cap = cap; + + return q; +} + +static void queue_free(struct node_queue *q) { + free(q->data); + free(q); +} + +static void queue_push(struct node_queue *q, ac_node_t *node) { + if (q->tail >= q->cap) { + size_t new_cap = q->cap ? q->cap * 2 : 64; + ac_node_t **new_data = realloc(q->data, sizeof(void *) * new_cap); + + if (!new_data) { + fprintf(stderr, "queue_push: realloc failure\n"); + exit(EXIT_FAILURE); // TODO error flag instead of exit + } + + q->data = new_data; + q->cap = new_cap; + } + + q->data[q->tail++] = node; +} + +static ac_node_t *queue_pop(struct node_queue *q) { + return q->head < q->tail ? q->data[q->head++] : NULL; +} + +static bool queue_empty(struct node_queue *q) { + return q->head == q->tail; +} + +int ac_build(ac_context_t *ctx) { + if (!ctx || !ctx->root) { + return -1; + } + + struct node_queue *q = queue_new(0); + if (!q) { + return -1; + } + + ac_node_t *root = ctx->root; + + for (int i = 0; i < AC_ALPHABET_SIZE; i++) { + if (root->children[i]) { + root->children[i]->fail = root; + queue_push(q, root->children[i]); + } + } + + while (!queue_empty(q)) { + ac_node_t *node = queue_pop(q); + + for (int i = 0; i < AC_ALPHABET_SIZE; i++) { + ac_node_t *child = node->children[i]; + if (!child) { + continue; + } + + ac_node_t *fail = node->fail; + while (fail && !fail->children[i]) { + fail = fail->fail; + } + + if (fail) { + child->fail = fail->children[i]; + } else { + child->fail = root; + } + + queue_push(q, child); + } + } + + queue_free(q); + + return 0; +} + +int ac_match(ac_context_t *ctx, const uint8_t *data, size_t len, ac_callback callback, void *user_data) { + if (!ctx || !data || !callback) { + // TODO error context + return -1; + } + + ac_node_t *node = ctx->root; + + for (size_t i = 0; i < len; i++) { + uint8_t c = data[i]; + + while (node && !node->children[c]) { + node = node->fail; + } + + if (!node) { + node = ctx->root; + } else { + node = node->children[c]; + } + + ac_node_t *temp = node; + while (temp) { + for (ac_pattern_t *p = temp->matches; p; p = p->next) { + ac_match_t match = { + .id = p->id, + .offset = i + 1 - p->len, + .len = p->len + }; + + callback(&match, user_data); + } + + temp = temp->fail; + } + } + + return 0; +} + +int ac_match_fd(ac_context_t *ctx, int fd, ac_callback callback, void *user_data) { + if (!ctx || fd < 0 || !callback) { + // TODO error context + return -1; + } + + char buffer[AC_BUF_SIZE]; + size_t offset = 0; + ssize_t n; + + ac_node_t *node = ctx->root; + + while ((n = read(fd, buffer, sizeof(buffer))) > 0) { + for (ssize_t i = 0; i < n; i++) { + uint8_t c = buffer[i]; + + while (node && !node->children[c]) { + node = node->fail; + } + + if (!node) { + node = ctx->root; + } else { + node = node->children[c]; + } + + ac_node_t *temp = node; + while (temp) { + for (ac_pattern_t *p = temp->matches; p; p = p->next) { + ac_match_t match = { + .id = p->id, + .offset = (offset + i + 1) - p->len, + .len = p->len + }; + + callback(&match, user_data); + } + + temp = temp->fail; + } + } + + offset += n; + } + + return (n < 0) ? -1 : 0; +} + +int ac_match_path(ac_context_t *ctx, const char *path, ac_callback callback, void *user_data) { + if (!ctx || !path || !callback) { + return -1; + } + + int fd = open(path, O_RDONLY); + if (fd < 0) { + return -1; + } + + int ret = ac_match_fd(ctx, fd, callback, user_data); + + close(fd); + + return ret; +} |
