diff options
-rw-r--r-- | bounce.c | 175 |
1 files changed, 87 insertions, 88 deletions
diff --git a/bounce.c b/bounce.c index 3f06603..9b18050 100644 --- a/bounce.c +++ b/bounce.c @@ -23,7 +23,9 @@ #include <getopt.h> #include <limits.h> #include <poll.h> +#include <pwd.h> #include <signal.h> +#include <stdbool.h> #include <stdio.h> #include <stdlib.h> #include <string.h> @@ -39,10 +41,6 @@ #include <sys/capsicum.h> #endif -#ifndef SIGINFO -#define SIGINFO SIGUSR2 -#endif - static void hashPass(void) { char *pass = getpass("Password: "); byte rand[12]; @@ -52,9 +50,16 @@ static void hashPass(void) { printf("%s\n", crypt(pass, salt)); } +static size_t parseSize(const char *str) { + char *rest; + size_t size = strtoull(str, &rest, 0); + if (*rest) errx(EX_USAGE, "invalid size: %s", str); + return size; +} + static FILE *saveFile; -static void saveExit(void) { +static void saveSave(void) { int error = ringSave(saveFile); if (error) warn("fwrite"); error = fclose(saveFile); @@ -68,74 +73,68 @@ static void saveLoad(const char *path) { int error = flock(fileno(saveFile), LOCK_EX | LOCK_NB); if (error && errno != EWOULDBLOCK) err(EX_OSERR, "flock"); - if (error) errx(EX_CANTCREAT, "%s: lock held by other process", path); + if (error) errx(EX_CANTCREAT, "lock held by other process: %s", path); rewind(saveFile); ringLoad(saveFile); - error = ftruncate(fileno(saveFile), 0); if (error) err(EX_IOERR, "ftruncate"); - atexit(saveExit); + atexit(saveSave); } struct SplitPath { int dir; - const char *file; + char *file; int targetDir; }; -#ifdef __FreeBSD__ -static void splitLimit(struct SplitPath split, const cap_rights_t *rights) { - int error = cap_rights_limit(split.dir, rights); - if (error) err(EX_OSERR, "cap_rights_limit"); - if (split.targetDir < 0) return; - error = cap_rights_limit(split.targetDir, rights); - if (error) err(EX_OSERR, "cap_rights_limit"); +static bool linkTarget(char *target, size_t cap, int dir, const char *file) { + ssize_t len = readlinkat(dir, file, target, cap - 1); + if (len < 0 && errno == EINVAL) return false; + if (len < 0) err(EX_IOERR, "readlinkat"); + target[len] = '\0'; + return true; } -#endif static struct SplitPath splitPath(char *path) { struct SplitPath split = { .targetDir = -1 }; - char *base = strrchr(path, '/'); - if (base) { - *base++ = '\0'; + split.file = strrchr(path, '/'); + if (split.file) { + *split.file++ = '\0'; split.dir = open(path, O_DIRECTORY); } else { - base = path; + split.file = path; split.dir = open(".", O_DIRECTORY); } if (split.dir < 0) err(EX_NOINPUT, "%s", path); - split.file = base; // Capsicum workaround for certbot "live" symlinks to "../../archive". char target[PATH_MAX]; - ssize_t len = readlinkat(split.dir, split.file, target, sizeof(target) - 1); - if (len < 0 && errno == EINVAL) return split; - if (len < 0) err(EX_IOERR, "readlinkat"); - target[len] = '\0'; - - base = strrchr(target, '/'); - if (base) { - *base = '\0'; + if (!linkTarget(target, sizeof(target), split.dir, split.file)) { + return split; + } + char *file = strrchr(target, '/'); + if (file) { + *file = '\0'; split.targetDir = openat(split.dir, target, O_DIRECTORY); if (split.targetDir < 0) err(EX_NOINPUT, "%s", target); } + return split; } static FILE *splitOpen(struct SplitPath split) { if (split.targetDir >= 0) { char target[PATH_MAX]; - ssize_t len = readlinkat( - split.dir, split.file, target, sizeof(target) - 1 - ); - if (len < 0) err(EX_IOERR, "readlinkat"); - target[len] = '\0'; - + if (!linkTarget(target, sizeof(target), split.dir, split.file)) { + errx(EX_CONFIG, "file is no longer a symlink"); + } split.dir = split.targetDir; split.file = strrchr(target, '/'); - if (!split.file) errx(EX_CONFIG, "symlink no longer targets directory"); + if (!split.file) { + errx(EX_CONFIG, "symlink no longer targets directory"); + } split.file++; } @@ -146,6 +145,22 @@ static FILE *splitOpen(struct SplitPath split) { return file; } +#ifdef __FreeBSD__ +static void capLimit(int fd, const cap_rights_t *rights) { + int error = cap_rights_limit(fd, rights); + if (error) err(EX_OSERR, "cap_rights_limit"); +} +static void capLimitSplit(struct SplitPath split, const cap_rights_t *rights) { + capLimit(split.dir, rights); + if (split.targetDir >= 0) capLimit(split.targetDir, rights); +} +#endif + +static volatile sig_atomic_t signals[NSIG]; +static void signalHandler(int signal) { + signals[signal] = 1; +} + static struct { struct pollfd *fds; struct Client **clients; @@ -168,28 +183,26 @@ static void eventAdd(int fd, struct Client *client) { } static void eventRemove(size_t i) { + close(event.fds[i].fd); event.len--; event.fds[i] = event.fds[event.len]; event.clients[i] = event.clients[event.len]; } -static volatile sig_atomic_t signals[NSIG]; -static void signalHandler(int signal) { - signals[signal] = 1; -} - int main(int argc, char *argv[]) { + size_t ringSize = 4096; + const char *savePath = NULL; + const char *bindHost = "localhost"; const char *bindPort = "6697"; char bindPath[PATH_MAX] = ""; char certPath[PATH_MAX] = ""; char privPath[PATH_MAX] = ""; - const char *save = NULL; - size_t ring = 4096; bool insecure = false; const char *clientCert = NULL; const char *clientPriv = NULL; + const char *host = NULL; const char *port = "6697"; char *pass = NULL; @@ -198,6 +211,7 @@ int main(int argc, char *argv[]) { const char *nick = NULL; const char *user = NULL; const char *real = NULL; + const char *join = NULL; const char *away = "pounced :3"; const char *quit = "connection reset by purr"; @@ -247,18 +261,14 @@ int main(int argc, char *argv[]) { break; case 'a': sasl = true; plain = optarg; break; case 'c': clientCert = optarg; break; case 'e': sasl = true; - break; case 'f': save = optarg; + break; case 'f': savePath = optarg; break; case 'h': host = optarg; break; case 'j': join = optarg; break; case 'k': clientPriv = optarg; break; case 'n': nick = optarg; break; case 'p': port = optarg; break; case 'r': real = optarg; - break; case 's': { - char *rest; - ring = strtoull(optarg, &rest, 0); - if (*rest) errx(EX_USAGE, "invalid size: %s", optarg); - } + break; case 's': ringSize = parseSize(optarg); break; case 'u': user = optarg; break; case 'v': verbose = true; break; case 'w': pass = optarg; @@ -282,7 +292,8 @@ int main(int argc, char *argv[]) { if (!privPath[0]) { snprintf(privPath, sizeof(privPath), DEFAULT_PRIV_PATH, bindHost); } - if (!host) errx(EX_USAGE, "no host"); + + if (!host) errx(EX_USAGE, "host required"); if (!nick) { nick = getenv("USER"); if (!nick) errx(EX_CONFIG, "USER unset"); @@ -290,8 +301,8 @@ int main(int argc, char *argv[]) { if (!user) user = nick; if (!real) real = nick; - ringAlloc(ring); - if (save) saveLoad(save); + ringAlloc(ringSize); + if (savePath) saveLoad(savePath); struct SplitPath certSplit = splitPath(certPath); struct SplitPath privSplit = splitPath(privPath); @@ -320,18 +331,13 @@ int main(int argc, char *argv[]) { cap_rights_init(&bindRights, CAP_LISTEN, CAP_ACCEPT); cap_rights_merge(&bindRights, &sockRights); - if (saveFile) { - error = cap_rights_limit(fileno(saveFile), &saveRights); - if (error) err(EX_OSERR, "cap_rights_limit"); - } - splitLimit(certSplit, &fileRights); - splitLimit(privSplit, &fileRights); + if (saveFile) capLimit(fileno(saveFile), &saveRights); + capLimitSplit(certSplit, &fileRights); + capLimitSplit(privSplit, &fileRights); for (size_t i = 0; i < binds; ++i) { - error = cap_rights_limit(bind[i], &bindRights); - if (error) err(EX_OSERR, "cap_rights_limit"); + capLimit(bind[i], &bindRights); } - error = cap_rights_limit(server, &sockRights); - if (error) err(EX_OSERR, "cap_rights_limit"); + capLimit(server, &sockRights); #endif stateLogin(pass, sasl, plain, nick, user, real); @@ -342,6 +348,11 @@ int main(int argc, char *argv[]) { serverFormat("AWAY :%s\r\n", away); if (join) serverFormat("JOIN :%s\r\n", join); + signal(SIGINT, signalHandler); + signal(SIGTERM, signalHandler); + signal(SIGINFO, signalHandler); + signal(SIGUSR1, signalHandler); + for (size_t i = 0; i < binds; ++i) { int error = listen(bind[i], 1); if (error) err(EX_IOERR, "listen"); @@ -349,21 +360,17 @@ int main(int argc, char *argv[]) { } eventAdd(server, NULL); - signal(SIGINT, signalHandler); - signal(SIGTERM, signalHandler); - signal(SIGINFO, signalHandler); - signal(SIGUSR1, signalHandler); - size_t clients = 0; for (;;) { int nfds = poll(event.fds, event.len, -1); if (nfds < 0 && errno != EINTR) err(EX_IOERR, "poll"); - if (signals[SIGINT] || signals[SIGTERM]) break; + if (signals[SIGINFO]) { ringInfo(); signals[SIGINFO] = 0; } + if (signals[SIGUSR1]) { cert = splitOpen(certSplit); priv = splitOpen(privSplit); @@ -378,7 +385,12 @@ int main(int argc, char *argv[]) { short revents = event.fds[i].revents; if (!revents) continue; - if (i < binds) { + if (event.fds[i].fd == server) { + serverRecv(); + continue; + } + + if (!event.clients[i]) { int fd; struct tls *tls = listenAccept(&fd, event.fds[i].fd); int error = tls_handshake(tls); @@ -393,21 +405,11 @@ int main(int argc, char *argv[]) { continue; } - if (!event.clients[i]) { - if (revents & POLLIN) { - serverRecv(); - } else { - errx(EX_UNAVAILABLE, "server hung up"); - } - continue; - } - struct Client *client = event.clients[i]; if (revents & POLLIN) clientRecv(client); if (revents & POLLOUT) clientConsume(client); if (clientError(client) || revents & (POLLHUP | POLLERR)) { clientFree(client); - close(event.fds[i].fd); eventRemove(i); if (!--clients) serverFormat("AWAY :%s\r\n", away); } @@ -424,14 +426,11 @@ int main(int argc, char *argv[]) { } serverFormat("QUIT :%s\r\n", quit); - for (size_t i = 0; i < event.len; ++i) { - if (event.clients[i]) { - clientFormat( - event.clients[i], ":%s QUIT :%s\r\nERROR :Disconnecting\r\n", - stateEcho(), quit - ); - clientFree(event.clients[i]); - } + for (size_t i = binds + 1; i < event.len; ++i) { + assert(event.clients[i]); + clientFormat(event.clients[i], ":%s QUIT :%s\r\n", stateEcho(), quit); + clientFormat(event.clients[i], "ERROR :Disconnecting\r\n"); + clientFree(event.clients[i]); close(event.fds[i].fd); } } |