// TODO ac_match_range() -- match within a range of a buffer, file, ... // TODO streaming support -- match over a socket // TODO ac_strerror() -- report errors better // Aho-Corasick searching algorithm implementation // https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm #include #include #include #include #include #include #include #include "aho-corasick.h" struct pattern_node { ac_node_t *children[AC_ALPHABET_SIZE]; ac_node_t *fail; ac_pattern_t *matches; }; struct pattern_entry { const char *id; uint8_t *pattern; size_t len; ac_pattern_t *next; }; struct ac_context { ac_node_t *root; }; static ac_node_t *node_new(void) { return calloc(1, sizeof(ac_node_t)); } ac_context_t *ac_new(void) { ac_context_t *ctx = calloc(1, sizeof(ac_context_t)); ctx->root = node_new(); return ctx; } static void free_node(ac_node_t *node) { if (!node) { return; } for (int i = 0; i < AC_ALPHABET_SIZE; i++) { free_node(node->children[i]); } ac_pattern_t *p = node->matches; while(p) { ac_pattern_t *next = p->next; free(p->pattern); free(p); p = next; } free(node); } void ac_free(ac_context_t *ctx) { if (!ctx) { return; } free_node(ctx->root); free(ctx); } int ac_add_pattern(ac_context_t *ctx, const char *id, const uint8_t *pattern, size_t len) { if (!ctx || !id || !pattern || len == 0) { return -1; } ac_node_t *node = ctx->root; for (size_t i = 0; i < len; i++) { uint8_t c = pattern[i]; if (!node->children[c]) { node->children[c] = node_new(); } node = node->children[c]; } for (ac_pattern_t *p = node->matches; p; p = p->next) { if (p->len == len && memcmp(p->pattern, pattern, len) == 0 && strcmp(p->id, id) == 0) { return 0; // duplicate pattern. don't add. } } ac_pattern_t *entry = malloc(sizeof(*entry)); entry->id = id; entry->len = len; entry->pattern = malloc(len); if (!entry->pattern) { return -1; } else { memcpy(entry->pattern, pattern, len); } entry->next = node->matches; node->matches = entry; return 0; } struct node_queue { ac_node_t **data; size_t head; size_t tail; size_t cap; }; static struct node_queue *queue_new(size_t cap) { struct node_queue *q = calloc(1, sizeof(*q)); if (cap == 0) { cap = 64; } q->data = malloc(sizeof(void *) * cap); q->cap = cap; return q; } static void queue_free(struct node_queue *q) { free(q->data); free(q); } static void queue_push(struct node_queue *q, ac_node_t *node) { if (q->tail >= q->cap) { size_t new_cap = q->cap ? q->cap * 2 : 64; ac_node_t **new_data = realloc(q->data, sizeof(void *) * new_cap); if (!new_data) { fprintf(stderr, "queue_push: realloc failure\n"); exit(EXIT_FAILURE); // TODO error flag instead of exit } q->data = new_data; q->cap = new_cap; } q->data[q->tail++] = node; } static ac_node_t *queue_pop(struct node_queue *q) { return q->head < q->tail ? q->data[q->head++] : NULL; } static bool queue_empty(struct node_queue *q) { return q->head == q->tail; } int ac_build(ac_context_t *ctx) { if (!ctx || !ctx->root) { return -1; } struct node_queue *q = queue_new(0); if (!q) { return -1; } ac_node_t *root = ctx->root; for (int i = 0; i < AC_ALPHABET_SIZE; i++) { if (root->children[i]) { root->children[i]->fail = root; queue_push(q, root->children[i]); } } while (!queue_empty(q)) { ac_node_t *node = queue_pop(q); for (int i = 0; i < AC_ALPHABET_SIZE; i++) { ac_node_t *child = node->children[i]; if (!child) { continue; } ac_node_t *fail = node->fail; while (fail && !fail->children[i]) { fail = fail->fail; } if (fail) { child->fail = fail->children[i]; } else { child->fail = root; } queue_push(q, child); } } queue_free(q); return 0; } int ac_match(ac_context_t *ctx, const uint8_t *data, size_t len, ac_callback callback, void *user_data) { if (!ctx || !data || !callback) { // TODO error context return -1; } ac_node_t *node = ctx->root; for (size_t i = 0; i < len; i++) { uint8_t c = data[i]; while (node && !node->children[c]) { node = node->fail; } if (!node) { node = ctx->root; } else { node = node->children[c]; } ac_node_t *temp = node; while (temp) { for (ac_pattern_t *p = temp->matches; p; p = p->next) { ac_match_t match = { .id = p->id, .offset = i + 1 - p->len, .len = p->len }; callback(&match, user_data); } temp = temp->fail; } } return 0; } int ac_match_fd(ac_context_t *ctx, int fd, ac_callback callback, void *user_data) { if (!ctx || fd < 0 || !callback) { // TODO error context return -1; } char buffer[AC_BUF_SIZE]; size_t offset = 0; ssize_t n; ac_node_t *node = ctx->root; while ((n = read(fd, buffer, sizeof(buffer))) > 0) { for (ssize_t i = 0; i < n; i++) { uint8_t c = buffer[i]; while (node && !node->children[c]) { node = node->fail; } if (!node) { node = ctx->root; } else { node = node->children[c]; } ac_node_t *temp = node; while (temp) { for (ac_pattern_t *p = temp->matches; p; p = p->next) { ac_match_t match = { .id = p->id, .offset = (offset + i + 1) - p->len, .len = p->len }; callback(&match, user_data); } temp = temp->fail; } } offset += n; } return (n < 0) ? -1 : 0; } int ac_match_path(ac_context_t *ctx, const char *path, ac_callback callback, void *user_data) { if (!ctx || !path || !callback) { return -1; } int fd = open(path, O_RDONLY); if (fd < 0) { return -1; } int ret = ac_match_fd(ctx, fd, callback, user_data); close(fd); return ret; }