From 19cafa40a1ad37bf95d2b2464d203f2792449d48 Mon Sep 17 00:00:00 2001 From: Curtis McEnroe Date: Wed, 23 Oct 2019 17:49:24 -0400 Subject: Implement some amount of client connection --- bounce.c | 23 +++++++---- bounce.h | 8 ++-- client.c | 140 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 153 insertions(+), 18 deletions(-) diff --git a/bounce.c b/bounce.c index ff99b91..07ed4ae 100644 --- a/bounce.c +++ b/bounce.c @@ -43,6 +43,7 @@ static void loopAdd(int fd, struct Client *client) { loop.fds[loop.len].fd = fd; loop.fds[loop.len].events = POLLIN; + loop.fds[loop.len].revents = 0; loop.clients[loop.len] = client; loop.len++; } @@ -64,7 +65,6 @@ static char *censor(char *arg) { int main(int argc, char *argv[]) { const char *localHost = "localhost"; const char *localPort = "6697"; - const char *localPass = NULL; char certPath[PATH_MAX] = ""; char privPath[PATH_MAX] = ""; @@ -84,7 +84,7 @@ int main(int argc, char *argv[]) { break; case 'H': localHost = optarg; break; case 'K': strlcpy(privPath, optarg, sizeof(privPath)); break; case 'P': localPort = optarg; - break; case 'W': localPass = censor(optarg); + break; case 'W': clientPass = censor(optarg); break; case 'a': auth = censor(optarg); break; case 'h': host = optarg; break; case 'j': join = optarg; @@ -138,17 +138,22 @@ int main(int argc, char *argv[]) { for (size_t i = 0; i < loop.len; ++i) { if (!loop.fds[i].revents) continue; if (i < bindLen) { - struct Client *client = clientAlloc(); - loopAdd(listenAccept(&client->tls, loop.fds[i].fd), client); + struct tls *tls; + int fd = listenAccept(&tls, loop.fds[i].fd); + loopAdd(fd, clientAlloc(tls)); } else if (!loop.clients[i]) { serverRecv(); - } else if (loop.fds[i].revents & POLLERR) { - close(loop.fds[i].fd); - clientFree(loop.clients[i]); - loopRemove(i); } else { - clientRecv(loop.clients[i]); + struct Client *client = loop.clients[i]; + if (loop.fds[i].revents & POLLIN) clientRecv(client); + if (loop.fds[i].revents & ~POLLIN || clientClose(client)) { + clientFree(client); + close(loop.fds[i].fd); + loopRemove(i); + break; + } } } } + err(EX_IOERR, "poll"); } diff --git a/bounce.h b/bounce.h index 491616b..0929da7 100644 --- a/bounce.h +++ b/bounce.h @@ -27,10 +27,6 @@ #define DEFAULT_PRIV_PATH "/usr/local/etc/letsencrypt/live/%s/privkey.pem" #endif -struct Client { - struct tls *tls; -}; - #define ARRAY_LEN(a) (sizeof(a) / sizeof(a[0])) enum { ParamCap = 15 }; @@ -73,6 +69,8 @@ void serverJoin(const char *join); void serverSend(const char *ptr, size_t len); void serverRecv(void); -struct Client *clientAlloc(void); +char *clientPass; +struct Client *clientAlloc(struct tls *tls); void clientFree(struct Client *client); +bool clientClose(const struct Client *client); void clientRecv(struct Client *client); diff --git a/client.c b/client.c index d012b28..0bba8d3 100644 --- a/client.c +++ b/client.c @@ -14,23 +14,155 @@ * along with this program. If not, see . */ +#include #include +#include +#include +#include #include +#include #include +#include #include "bounce.h" -struct Client *clientAlloc(void) { - struct Client *client = calloc(1, sizeof(*client)); - if (!client) err(EX_OSERR, "calloc"); +enum Need { + NeedNick = 1 << 0, + NeedUser = 1 << 1, + NeedPass = 1 << 2, + NeedCapEnd = 1 << 3, +}; + +struct Client { + bool close; + struct tls *tls; + enum Need need; + char buf[4096]; + size_t len; +}; + +struct Client *clientAlloc(struct tls *tls) { + struct Client *client = malloc(sizeof(*client)); + if (!client) err(EX_OSERR, "malloc"); + + client->close = false; + client->tls = tls; + client->need = NeedNick | NeedUser | (clientPass ? NeedPass : 0); + client->len = 0; + return client; } void clientFree(struct Client *client) { + tls_close(client->tls); tls_free(client->tls); free(client); } -void clientRecv(struct Client *client) { +bool clientClose(const struct Client *client) { + return client->close; +} + +static void clientSend(struct Client *client, const char *ptr, size_t len) { + if (verbose) fprintf(stderr, "\x1B[34m%.*s\x1B[m", (int)len, ptr); + while (len) { + ssize_t ret = tls_write(client->tls, ptr, len); + // FIXME: Handle non-blocking? + if (ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT) continue; + if (ret < 0) { + warnx("tls_write: %s", tls_error(client->tls)); + client->close = true; + return; + } + ptr += ret; + len -= ret; + } +} + +static void format(struct Client *client, const char *format, ...) { + char buf[513]; + va_list ap; + va_start(ap, format); + int len = vsnprintf(buf, sizeof(buf), format, ap); + va_end(ap); + assert(len > 0 && (size_t)len < sizeof(buf)); + clientSend(client, buf, len); +} + +typedef void Handler(struct Client *client, struct Command cmd); + +static void handleNick(struct Client *client, struct Command cmd) { + (void)cmd; + client->need &= ~NeedNick; +} + +static void handleUser(struct Client *client, struct Command cmd) { + (void)cmd; + // TODO: Identify client by username. + client->need &= ~NeedUser; +} + +static void handlePass(struct Client *client, struct Command cmd) { + if (!cmd.params[0] || strcmp(clientPass, cmd.params[0])) { + format(client, ":invalid 464 * :Password incorrect\r\n"); + client->close = true; + } else { + client->need &= ~NeedPass; + } +} + +static void handleCap(struct Client *client, struct Command cmd) { // TODO... } + +static const struct { + const char *cmd; + Handler *fn; +} Handlers[] = { + { "CAP", handleCap }, + { "NICK", handleNick }, + { "PASS", handlePass }, + { "USER", handleUser }, +}; + +static void clientParse(struct Client *client, char *line) { + struct Command cmd = parse(line); + if (!cmd.name) { + // FIXME: Identify client in message. + warnx("no command"); + client->close = true; + return; + } + for (size_t i = 0; i < ARRAY_LEN(Handlers); ++i) { + if (strcmp(cmd.name, Handlers[i].cmd)) continue; + Handlers[i].fn(client, cmd); + break; + } +} + +void clientRecv(struct Client *client) { + ssize_t read = tls_read( + client->tls, + &client->buf[client->len], sizeof(client->buf) - client->len + ); + if (read == TLS_WANT_POLLIN || read == TLS_WANT_POLLOUT) return; + if (read < 0) warnx("tls_read: %s", tls_error(client->tls)); + if (read < 1) { + client->close = true; + return; + } + client->len += read; + + char *crlf; + char *line = client->buf; + for (;;) { + crlf = memmem(line, &client->buf[client->len] - line, "\r\n", 2); + if (!crlf) break; + crlf[0] = '\0'; + if (verbose) fprintf(stderr, "\x1B[33m%s\x1B[m\n", line); + clientParse(client, line); + line = crlf + 2; + } + client->len -= line - client->buf; + memmove(client->buf, line, client->len); +} -- cgit 1.4.1