#include <sys/types.h>
#include <ctype.h>
#include <errno.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include "patterns.h"
#define uchar(c) ((unsigned char)(c))
#define CAP_UNFINISHED (-1)
#define CAP_POSITION (-2)
#define L_ESC '%'
#define SPECIALS "^$*+?.([%-"
/* State shared by all matching helpers during one search. */
struct match_state {
/* remaining match() nesting budget; set to MAXCCALLS before a search */
int matchdepth;
/* budget consumed when a single item fails to match; set to MAXREPETITION */
int repetitioncounter;
/* start_capture() refuses to open a capture once level reaches this */
int maxcaptures;
/* first byte of the subject string */
const char *src_init;
/* one past the last byte of the subject string */
const char *src_end;
/* one past the last byte of the pattern */
const char *p_end;
/* first error recorded by match_error(), or NULL if none */
const char *error;
/* number of captures opened so far (closed or not) */
int level;
struct {
/* where the capture starts in the subject */
const char *init;
/* capture length, or CAP_UNFINISHED / CAP_POSITION */
ptrdiff_t len;
} capture[MAXCAPTURES];
};
/* Forward declaration: the helpers below call match() before its definition. */
static const char *match(struct match_state *, const char *, const char *);
static int
match_error(struct match_state *ms, const char *error)
{
ms->error = ms->error == NULL ? error : ms->error;
return (-1);
}
/*
 * Translate the digit character of a "%N" back-reference into a capture
 * index ('1' maps to 0) and validate it.
 *
 * Returns the index, or -1 (with an error recorded) when the index is out
 * of range or refers to a capture that is still open.
 */
static int
check_capture(struct match_state *ms, int l)
{
	int idx = l - '1';

	/* The range checks must come first: they guard the array access. */
	if (idx < 0 || idx >= ms->level)
		return match_error(ms, "invalid capture index");
	if (ms->capture[idx].len == CAP_UNFINISHED)
		return match_error(ms, "invalid capture index");
	return (idx);
}
/*
 * Find the most recently opened capture that has not been closed yet.
 *
 * Returns its index, or -1 (with an error recorded) when no capture is open.
 */
static int
capture_to_close(struct match_state *ms)
{
	int idx;

	/* Walk from the newest capture towards the oldest. */
	for (idx = ms->level - 1; idx >= 0; idx--) {
		if (ms->capture[idx].len == CAP_UNFINISHED)
			return (idx);
	}
	return match_error(ms, "invalid pattern capture");
}
/*
 * Return a pointer just past the single pattern item that starts at p:
 * a plain character, a '%'-escaped character, or a "[...]" set.
 *
 * On a malformed item an error is recorded in ms; the returned pointer is
 * then not meaningful and callers are expected to check ms->error.
 */
static const char *
classend(struct match_state *ms, const char *p)
{
	const char *q = p + 1;

	if (*p == L_ESC) {
		if (q == ms->p_end)
			match_error(ms,
			    "malformed pattern (ends with '%')");
		return (q + 1);
	}
	if (*p != '[')
		return (q);

	/* A leading '^' only negates the set; step over it. */
	if (*q == '^')
		q++;
	for (;;) {
		if (q == ms->p_end) {
			match_error(ms,
			    "malformed pattern (missing ']')");
			break;
		}
		/* An escape consumes the following character as well. */
		if (*q++ == L_ESC && q < ms->p_end)
			q++;
		if (*q == ']')
			break;
	}
	return (q + 1);
}
/*
 * Test character c against the class letter cl (the character that
 * follows '%' in a pattern).
 *
 * A lower-case class letter selects the <ctype.h> predicate below; the
 * upper-case form of the same letter negates it.  Any other cl is
 * compared literally against c.
 *
 * Returns non-zero on a match, 0 otherwise.
 */
static int
match_class(int c, int cl)
{
	int hit;

	switch (tolower(cl)) {
	case 'a':
		hit = isalpha(c);
		break;
	case 'c':
		hit = iscntrl(c);
		break;
	case 'd':
		hit = isdigit(c);
		break;
	case 'g':
		hit = isgraph(c);
		break;
	case 'l':
		hit = islower(c);
		break;
	case 'p':
		hit = ispunct(c);
		break;
	case 's':
		hit = isspace(c);
		break;
	case 'u':
		hit = isupper(c);
		break;
	case 'w':
		hit = isalnum(c);
		break;
	case 'x':
		hit = isxdigit(c);
		break;
	default:
		/* Not a class letter: plain character comparison. */
		return (cl == c);
	}

	if (islower(cl))
		return (hit);
	return (!hit);
}
/*
 * Test character c against a "[...]" set.
 *
 * p points at the opening '[' and ec at the closing ']'.  The set may
 * start with '^' (negation) and may contain '%'-escaped classes, "x-y"
 * ranges and plain characters.
 *
 * Returns 1 when c is accepted by the set, 0 otherwise.
 */
static int
matchbracketclass(int c, const char *p, const char *ec)
{
	int found = 1;

	if (p[1] == '^') {
		/* Negated set: a hit now means "no match". */
		found = 0;
		p++;
	}
	for (p++; p < ec; p++) {
		if (*p == L_ESC) {
			p++;
			if (match_class(c, uchar(*p)))
				return (found);
		} else if (p[1] == '-' && p + 2 < ec) {
			int lo = uchar(p[0]);
			int hi = uchar(p[2]);

			/* Step onto the upper bound of the range. */
			p += 2;
			if (lo <= c && c <= hi)
				return (found);
		} else if (uchar(*p) == c) {
			return (found);
		}
	}
	return (!found);
}
/*
 * Test whether the subject character at s matches the single pattern
 * item that starts at p and ends just before ep.
 *
 * Returns 0 at (or past) the end of the subject, otherwise non-zero on a
 * match and 0 on a mismatch.
 */
static int
singlematch(struct match_state *ms, const char *s, const char *p,
    const char *ep)
{
	int c;

	if (s >= ms->src_end)
		return (0);

	c = uchar(*s);
	if (*p == '.')
		return (1);
	if (*p == L_ESC)
		return match_class(c, uchar(p[1]));
	if (*p == '[')
		return matchbracketclass(c, p, ep - 1);
	return (uchar(*p) == c);
}
/*
 * Match a balanced "%bxy" sequence.  p points at the two delimiter
 * characters (x opens, y closes) and s at the candidate subject position.
 *
 * Returns a pointer just past the closing delimiter that balances the
 * opening one at s, or NULL when s does not start with the opener, the
 * subject ends unbalanced, or the pattern lacks the two delimiters (an
 * error is recorded in that last case).
 */
static const char *
matchbalance(struct match_state *ms, const char *s, const char *p)
{
	int opener, closer, depth;

	if (p >= ms->p_end - 1) {
		match_error(ms,
		    "malformed pattern (missing arguments to '%b')");
		return (NULL);
	}
	if (*s != *p)
		return (NULL);

	opener = *p;
	closer = p[1];
	depth = 1;
	for (s++; s < ms->src_end; s++) {
		/* The closer is tested first, as it matters when x == y. */
		if (*s == closer) {
			if (--depth == 0)
				return (s + 1);
		} else if (*s == opener) {
			depth++;
		}
	}
	return (NULL);
}
/*
 * Greedy repetition: consume as many occurrences of the item [p, ep) as
 * possible starting at s, then back off one occurrence at a time until
 * the rest of the pattern (after the suffix at ep) matches.
 *
 * Returns the end of the overall match, or NULL when no count works.
 */
static const char *
max_expand(struct match_state *ms, const char *s, const char *p, const char *ep)
{
	const char *res;
	ptrdiff_t count;

	for (count = 0; singlematch(ms, s + count, p, ep); count++)
		continue;

	for (; count >= 0; count--) {
		res = match(ms, s + count, ep + 1);
		if (res != NULL)
			return (res);
	}
	return (NULL);
}
/*
 * Lazy repetition: try to match the rest of the pattern (after the suffix
 * at ep) first, and consume one more occurrence of the item [p, ep) only
 * when that fails.
 *
 * Returns the end of the overall match, or NULL when the item stops
 * matching before the rest of the pattern does.
 */
static const char *
min_expand(struct match_state *ms, const char *s, const char *p, const char *ep)
{
	const char *res;

	while ((res = match(ms, s, ep + 1)) == NULL) {
		if (!singlematch(ms, s, p, ep))
			return (NULL);
		s++;
	}
	return (res);
}
/*
 * Open a new capture at subject position s and continue matching with
 * the pattern at p.  "what" is the initial length marker for the capture
 * (match() passes CAP_UNFINISHED or CAP_POSITION).
 *
 * Returns the result of the continued match.  The capture is dropped
 * again when that match fails.  Records an error and returns NULL when
 * the capture limit has been reached.
 */
static const char *
start_capture(struct match_state *ms, const char *s, const char *p, int what)
{
	const char *res;
	int slot = ms->level;

	if (slot >= ms->maxcaptures) {
		match_error(ms, "too many captures");
		return (NULL);
	}

	ms->capture[slot].init = s;
	ms->capture[slot].len = what;
	ms->level = slot + 1;

	res = match(ms, s, p);
	if (res == NULL)
		ms->level--;
	return (res);
}
/*
 * Close the innermost open capture at subject position s and continue
 * matching with the pattern at p.
 *
 * Returns the result of the continued match; the capture is re-opened
 * when that match fails.  Returns NULL when no capture is open (the
 * error is recorded by capture_to_close()).
 */
static const char *
end_capture(struct match_state *ms, const char *s, const char *p)
{
	const char *res;
	int idx = capture_to_close(ms);

	if (idx == -1)
		return (NULL);

	ms->capture[idx].len = s - ms->capture[idx].init;
	res = match(ms, s, p);
	if (res == NULL)
		ms->capture[idx].len = CAP_UNFINISHED;
	return (res);
}
/*
 * Match a back-reference: l is the digit character that followed '%'.
 *
 * Returns a pointer just past the repeated text when the subject at s
 * starts with the bytes of the referenced capture, NULL otherwise
 * (including an invalid capture index, which also records an error).
 */
static const char *
match_capture(struct match_state *ms, const char *s, int l)
{
	size_t len, avail;
	int idx = check_capture(ms, l);

	if (idx == -1)
		return (NULL);

	len = ms->capture[idx].len;
	avail = (size_t)(ms->src_end - s);
	if (avail < len)
		return (NULL);
	if (memcmp(ms->capture[idx].init, s, len) != 0)
		return (NULL);
	return (s + len);
}
/*
 * Core matcher: try to match the pattern starting at p against the
 * subject starting at s.
 *
 * Returns a pointer to the end of the match, or NULL when there is no
 * match.  Some pattern errors leave s untouched, so a non-NULL return does
 * not by itself mean success: callers also have to look at ms->error.
 *
 * Each call consumes one unit of ms->matchdepth, which is given back on
 * the normal return path.  Continuations that need no backtracking are
 * done with "goto init" instead of a nested call.
 */
static const char *
match(struct match_state *ms, const char *s, const char *p)
{
const char *ep, *res;
char previous;
/* nesting budget exhausted: record the error and give up */
if (ms->matchdepth-- == 0) {
match_error(ms, "pattern too complex");
return (NULL);
}
/* re-entry point for "tail calls" with updated s and/or p */
init:
/* reaching the end of the pattern means everything matched: return s */
if (p != ms->p_end) {
switch (*p) {
case '(':
/* "()" is a position capture, anything else opens a normal capture */
if (*(p + 1) == ')')
s = start_capture(ms, s, p + 2, CAP_POSITION);
else
s = start_capture(ms, s, p + 1, CAP_UNFINISHED);
break;
case ')':
/* close the innermost open capture */
s = end_capture(ms, s, p + 1);
break;
case '$':
/* '$' anchors only as the last pattern character; else it is an item */
if ((p + 1) != ms->p_end) {
goto dflt;
}
/* anchored: succeed only at the end of the subject */
s = (s == ms->src_end) ? s : NULL;
break;
case L_ESC:
/* escapes that are not a simple "class plus optional suffix" item */
switch (*(p + 1)) {
case 'b':
/* balanced pair "%bxy": the delimiters are at p + 2 and p + 3 */
s = matchbalance(ms, s, p + 2);
if (s != NULL) {
/* skip the four characters of "%bxy" and continue */
p += 4;
goto init;
}
break;
case 'f':
/* frontier "%f[set]" */
p += 2;
if (*p != '[') {
match_error(ms, "missing '['"
" after '%f' in pattern");
break;
}
/* ep points just past the set */
ep = classend(ms, p);
if (ms->error != NULL)
break;
/* the start of the subject counts as being preceded by '\0' */
previous =
(s == ms->src_init) ? '\0' : *(s - 1);
/* the previous character must be outside the set, the current one inside */
if (!matchbracketclass(uchar(previous),
p, ep - 1) &&
matchbracketclass(uchar(*s),
p, ep - 1)) {
p = ep;
goto init;
}
/* not at a frontier */
s = NULL;
break;
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
/* back-reference to an earlier capture */
s = match_capture(ms, s, uchar(*(p + 1)));
if (s != NULL) {
/* skip "%N" and continue */
p += 2;
goto init;
}
break;
default:
/* "%x" class or escaped character: handled as a regular item */
goto dflt;
}
break;
default:
/* single item (character, class or set) with an optional suffix */
dflt:
/* ep points at the optional suffix that follows the item */
ep = classend(ms, p);
if (ms->error != NULL)
break;
if (!singlematch(ms, s, p, ep)) {
/* the item does not match here; each such miss costs one unit */
if (ms->repetitioncounter-- == 0) {
match_error(ms, "max repetition items");
s = NULL;
} else if
(*ep == '*' || *ep == '?' || *ep == '-') {
/* these suffixes accept zero occurrences: skip item and suffix */
p = ep + 1;
goto init;
} else {
/* '+' or no suffix needs at least one occurrence */
s = NULL;
}
} else {
/* the item matched once; the suffix decides what happens next */
switch (*ep) {
case '?':
/* optional: first try with the character consumed ... */
if ((res =
match(ms, s + 1, ep + 1)) != NULL)
s = res;
else {
/* ... otherwise continue without consuming it */
p = ep + 1;
goto init;
}
break;
case '+':
/* one occurrence is already matched; the rest behaves like '*' */
s++;
/* FALLTHROUGH */
case '*':
/* greedy repetition */
s = max_expand(ms, s, p, ep);
break;
case '-':
/* lazy repetition */
s = min_expand(ms, s, p, ep);
break;
default:
/* no suffix: consume one character and move on to the next item */
s++;
p = ep;
goto init;
}
}
break;
}
}
/* give the nesting budget unit back on the way out */
ms->matchdepth++;
return s;
}
/*
 * Find the first occurrence of the l2 bytes at s2 inside the l1 bytes at
 * s1 (a length-based strstr).
 *
 * Returns a pointer to the occurrence, s1 itself when l2 is 0, or NULL
 * when there is none.
 */
static const char *
lmemfind(const char *s1, size_t l1,
    const char *s2, size_t l2)
{
	const char *hit;
	size_t tail, span;

	if (l2 == 0)
		return (s1);
	if (l2 > l1)
		return (NULL);

	/*
	 * memchr() locates the first byte of the needle; only the last
	 * "span" start positions leave room for the remaining "tail" bytes.
	 */
	tail = l2 - 1;
	span = l1 - tail;
	while (span > 0) {
		hit = memchr(s1, *s2, span);
		if (hit == NULL)
			break;
		if (memcmp(hit + 1, s2 + 1, tail) == 0)
			return (hit);
		/* Resume the search right after the failed candidate. */
		span -= (size_t)(hit + 1 - s1);
		s1 = hit + 1;
	}
	return (NULL);
}
/*
 * Store the offsets of capture i into *sm.
 *
 * When i is not an existing capture, the whole match [s, e) is stored
 * instead, but only for i == 0 or when the pattern opened no captures.
 * The end offset is never left below the start offset (this covers the
 * negative CAP_POSITION length).
 *
 * Returns 0 on success, -1 (with an error recorded) on an invalid index
 * or an unfinished capture.
 */
static int
push_onecapture(struct match_state *ms, int i, const char *s,
    const char *e, struct str_find *sm)
{
	ptrdiff_t len;

	if (i < ms->level) {
		len = ms->capture[i].len;
		if (len == CAP_UNFINISHED)
			return match_error(ms, "unfinished capture");
		sm->sm_so = ms->capture[i].init - ms->src_init;
		sm->sm_eo = sm->sm_so + len;
	} else if (i == 0 || ms->level == 0) {
		/* No such capture: report the whole match. */
		sm->sm_so = (off_t)(s - ms->src_init);
		sm->sm_eo = (off_t)(e - s) + sm->sm_so;
	} else {
		return match_error(ms, "invalid capture index");
	}

	if (sm->sm_eo < sm->sm_so)
		sm->sm_eo = sm->sm_so;
	return (0);
}
/*
 * Store the captures of a successful match into sm[0 .. nsm - 1].
 *
 * When the pattern opened no captures (and s is non-NULL), a single
 * entry is produced; push_onecapture() fills it with the whole match
 * [s, e).  Storing stops early if push_onecapture() reports an error.
 *
 * Returns the number of entries considered: the capture count, limited
 * to nsm.
 */
static int
push_captures(struct match_state *ms, const char *s, const char *e,
    struct str_find *sm, size_t nsm)
{
	unsigned int i, nlevels;

	if (ms->level <= 0 && s)
		nlevels = 1;
	else
		nlevels = ms->level;
	if (nlevels > nsm)
		nlevels = nsm;

	i = 0;
	while (i < nlevels) {
		if (push_onecapture(ms, i, s, e, sm + i) == -1)
			break;
		i++;
	}
	return (nlevels);
}
/*
 * Check whether the pattern p of length l is free of the characters in
 * SPECIALS.  The scan restarts after every NUL byte found within the
 * first l bytes.
 *
 * Returns 1 when no special character is present, 0 otherwise.
 */
static int
nospecials(const char *p, size_t l)
{
	size_t off = 0;

	for (;;) {
		if (strpbrk(p + off, SPECIALS) != NULL)
			return (0);
		/* Step over this segment and its terminating NUL. */
		off += strlen(p + off) + 1;
		if (off > l)
			break;
	}
	return (1);
}
/*
 * Search "string" for "pattern", starting at byte offset "init", and
 * store match offsets in sm[0 .. nsm - 1].
 *
 * On success sm[0] always covers the whole subject string [0, ls).  The
 * following entries hold the match itself when the pattern has no
 * captures (or no special characters), and the captures otherwise.
 *
 * ms must be zeroed by the caller; it receives the matcher limits and any
 * error message.
 *
 * Returns the number of sm entries that describe the match, 0 when there
 * is no match or a pattern error occurred during matching, and -1 for an
 * invalid argument (nsm == 0 or init past the end of the string).
 * ms->error is set on every error.
 */
static int
str_find_aux(struct match_state *ms, const char *pattern, const char *string,
    struct str_find *sm, size_t nsm, off_t init)
{
	size_t ls = strlen(string);
	size_t lp = strlen(pattern);
	const char *s = string;
	const char *p = pattern;
	const char *s1, *s2;
	int anchor, i;

	/*
	 * sm[0] is written unconditionally on a match, and both the capture
	 * limit and the space handed to push_captures() are computed from
	 * nsm - 1.  With nsm == 0 that would write past the array and wrap
	 * the unsigned subtraction, so reject it up front.
	 */
	if (nsm == 0)
		return match_error(ms, "no space for match results");

	if (init < 0)
		init = 0;
	else if (init > (off_t)ls)
		return match_error(ms, "starting after string's end");
	s1 = s + init;

	if (nospecials(p, lp)) {
		/* No special characters: a plain substring search will do. */
		s2 = lmemfind(s1, ls - (size_t)init, p, lp);
		if (s2 != NULL) {
			i = 0;
			sm[i].sm_so = 0;
			sm[i].sm_eo = ls;
			if (nsm > 1) {
				i++;
				sm[i].sm_so = s2 - s;
				sm[i].sm_eo = (s2 - s) + lp;
			}
			return (i + 1);
		}
		return (0);
	}

	/* A leading '^' anchors the match at s1 and is not part of the items. */
	anchor = (*p == '^');
	if (anchor) {
		p++;
		lp--;
	}

	/* One sm slot is reserved for the whole-string entry. */
	ms->maxcaptures = (nsm > MAXCAPTURES ? MAXCAPTURES : nsm) - 1;
	ms->matchdepth = MAXCCALLS;
	ms->repetitioncounter = MAXREPETITION;
	ms->src_init = s;
	ms->src_end = s + ls;
	ms->p_end = p + lp;

	/* Try every start position (only the first one when anchored). */
	do {
		const char *res;

		ms->level = 0;
		if ((res = match(ms, s1, p)) != NULL) {
			sm->sm_so = 0;
			sm->sm_eo = ls;
			return push_captures(ms, s1, res, sm + 1, nsm - 1) + 1;
		} else if (ms->error != NULL) {
			return 0;
		}
	} while (s1++ < ms->src_end && !anchor);

	return 0;
}
/*
 * Search "string" for "pattern" and store the resulting offsets in
 * sm[0 .. nsm - 1], which is zeroed first.
 *
 * Returns the count reported by str_find_aux().  On error, returns 0 and
 * sets *errstr to a message; otherwise *errstr is set to NULL.
 */
int
str_find(const char *string, const char *pattern, struct str_find *sm,
    size_t nsm, const char **errstr)
{
	struct match_state ms;
	int ret;

	memset(&ms, 0, sizeof(ms));
	memset(sm, 0, nsm * sizeof(*sm));

	ret = str_find_aux(&ms, pattern, string, sm, nsm, 0);

	/* ms.error is NULL exactly when no error was recorded. */
	*errstr = ms.error;
	if (ms.error != NULL)
		ret = 0;
	return (ret);
}
/*
 * Search "string" for "pattern" and return the matched substrings as
 * newly allocated strings in m->sm_match (m->sm_nmatch entries).
 *
 * *m is zeroed first.  Returns 0 on success with *errstr set to NULL; the
 * caller then releases the result with str_match_free().  Returns -1 when
 * nothing matched (*errstr is NULL), on a pattern error, or when an
 * allocation fails (*errstr holds the message).
 */
int
str_match(const char *string, const char *pattern, struct str_match *m,
    const char **errstr)
{
	struct str_find sm[MAXCAPTURES];
	struct match_state ms;
	size_t len;
	int i, nfound;

	memset(&ms, 0, sizeof(ms));
	memset(sm, 0, sizeof(sm));
	memset(m, 0, sizeof(*m));

	nfound = str_find_aux(&ms, pattern, string, sm, MAXCAPTURES, 0);
	if (nfound <= 0 || ms.error != NULL) {
		*errstr = ms.error;
		return (-1);
	}

	m->sm_match = calloc(nfound, sizeof(char *));
	if (m->sm_match == NULL) {
		*errstr = strerror(errno);
		return (-1);
	}
	m->sm_nmatch = nfound;

	for (i = 0; i < nfound; i++) {
		/* Entries with an inverted range are left as NULL. */
		if (sm[i].sm_so <= sm[i].sm_eo) {
			len = sm[i].sm_eo - sm[i].sm_so;
			m->sm_match[i] = strndup(string + sm[i].sm_so, len);
			if (m->sm_match[i] == NULL) {
				*errstr = strerror(errno);
				str_match_free(m);
				return (-1);
			}
		}
	}

	*errstr = NULL;
	return (0);
}
/*
 * Release the strings and the array held by *m and reset it to empty.
 */
void
str_match_free(struct str_match *m)
{
	unsigned int idx = 0;

	while (idx < m->sm_nmatch) {
		free(m->sm_match[idx]);
		idx++;
	}
	free(m->sm_match);
	m->sm_nmatch = 0;
	m->sm_match = NULL;
}