summary refs log tree commit diff
path: root/bounce.c
diff options
context:
space:
mode:
Diffstat (limited to 'bounce.c')
-rw-r--r--bounce.c175
1 files changed, 87 insertions, 88 deletions
diff --git a/bounce.c b/bounce.c
index 3f06603..9b18050 100644
--- a/bounce.c
+++ b/bounce.c
@@ -23,7 +23,9 @@
 #include <getopt.h>
 #include <limits.h>
 #include <poll.h>
+#include <pwd.h>
 #include <signal.h>
+#include <stdbool.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
@@ -39,10 +41,6 @@
 #include <sys/capsicum.h>
 #endif
 
-#ifndef SIGINFO
-#define SIGINFO SIGUSR2
-#endif
-
 static void hashPass(void) {
 	char *pass = getpass("Password: ");
 	byte rand[12];
@@ -52,9 +50,16 @@ static void hashPass(void) {
 	printf("%s\n", crypt(pass, salt));
 }
 
+static size_t parseSize(const char *str) {
+	char *rest;
+	size_t size = strtoull(str, &rest, 0);
+	if (*rest) errx(EX_USAGE, "invalid size: %s", str);
+	return size;
+}
+
 static FILE *saveFile;
 
-static void saveExit(void) {
+static void saveSave(void) {
 	int error = ringSave(saveFile);
 	if (error) warn("fwrite");
 	error = fclose(saveFile);
@@ -68,74 +73,68 @@ static void saveLoad(const char *path) {
 
 	int error = flock(fileno(saveFile), LOCK_EX | LOCK_NB);
 	if (error && errno != EWOULDBLOCK) err(EX_OSERR, "flock");
-	if (error) errx(EX_CANTCREAT, "%s: lock held by other process", path);
+	if (error) errx(EX_CANTCREAT, "lock held by other process: %s", path);
 
 	rewind(saveFile);
 	ringLoad(saveFile);
-
 	error = ftruncate(fileno(saveFile), 0);
 	if (error) err(EX_IOERR, "ftruncate");
 
-	atexit(saveExit);
+	atexit(saveSave);
 }
 
 struct SplitPath {
 	int dir;
-	const char *file;
+	char *file;
 	int targetDir;
 };
 
-#ifdef __FreeBSD__
-static void splitLimit(struct SplitPath split, const cap_rights_t *rights) {
-	int error = cap_rights_limit(split.dir, rights);
-	if (error) err(EX_OSERR, "cap_rights_limit");
-	if (split.targetDir < 0) return;
-	error = cap_rights_limit(split.targetDir, rights);
-	if (error) err(EX_OSERR, "cap_rights_limit");
+static bool linkTarget(char *target, size_t cap, int dir, const char *file) {
+	ssize_t len = readlinkat(dir, file, target, cap - 1);
+	if (len < 0 && errno == EINVAL) return false;
+	if (len < 0) err(EX_IOERR, "readlinkat");
+	target[len] = '\0';
+	return true;
 }
-#endif
 
 static struct SplitPath splitPath(char *path) {
 	struct SplitPath split = { .targetDir = -1 };
-	char *base = strrchr(path, '/');
-	if (base) {
-		*base++ = '\0';
+	split.file = strrchr(path, '/');
+	if (split.file) {
+		*split.file++ = '\0';
 		split.dir = open(path, O_DIRECTORY);
 	} else {
-		base = path;
+		split.file = path;
 		split.dir = open(".", O_DIRECTORY);
 	}
 	if (split.dir < 0) err(EX_NOINPUT, "%s", path);
-	split.file = base;
 
 	// Capsicum workaround for certbot "live" symlinks to "../../archive".
 	char target[PATH_MAX];
-	ssize_t len = readlinkat(split.dir, split.file, target, sizeof(target) - 1);
-	if (len < 0 && errno == EINVAL) return split;
-	if (len < 0) err(EX_IOERR, "readlinkat");
-	target[len] = '\0';
-
-	base = strrchr(target, '/');
-	if (base) {
-		*base = '\0';
+	if (!linkTarget(target, sizeof(target), split.dir, split.file)) {
+		return split;
+	}
+	char *file = strrchr(target, '/');
+	if (file) {
+		*file = '\0';
 		split.targetDir = openat(split.dir, target, O_DIRECTORY);
 		if (split.targetDir < 0) err(EX_NOINPUT, "%s", target);
 	}
+
 	return split;
 }
 
 static FILE *splitOpen(struct SplitPath split) {
 	if (split.targetDir >= 0) {
 		char target[PATH_MAX];
-		ssize_t len = readlinkat(
-			split.dir, split.file, target, sizeof(target) - 1
-		);
-		if (len < 0) err(EX_IOERR, "readlinkat");
-		target[len] = '\0';
-
+		if (!linkTarget(target, sizeof(target), split.dir, split.file)) {
+			errx(EX_CONFIG, "file is no longer a symlink");
+		}
 		split.dir = split.targetDir;
 		split.file = strrchr(target, '/');
-		if (!split.file) errx(EX_CONFIG, "symlink no longer targets directory");
+		if (!split.file) {
+			errx(EX_CONFIG, "symlink no longer targets directory");
+		}
 		split.file++;
 	}
 
@@ -146,6 +145,22 @@ static FILE *splitOpen(struct SplitPath split) {
 	return file;
 }
 
+#ifdef __FreeBSD__
+static void capLimit(int fd, const cap_rights_t *rights) {
+	int error = cap_rights_limit(fd, rights);
+	if (error) err(EX_OSERR, "cap_rights_limit");
+}
+static void capLimitSplit(struct SplitPath split, const cap_rights_t *rights) {
+	capLimit(split.dir, rights);
+	if (split.targetDir >= 0) capLimit(split.targetDir, rights);
+}
+#endif
+
+static volatile sig_atomic_t signals[NSIG];
+static void signalHandler(int signal) {
+	signals[signal] = 1;
+}
+
 static struct {
 	struct pollfd *fds;
 	struct Client **clients;
@@ -168,28 +183,26 @@ static void eventAdd(int fd, struct Client *client) {
 }
 
 static void eventRemove(size_t i) {
+	close(event.fds[i].fd);
 	event.len--;
 	event.fds[i] = event.fds[event.len];
 	event.clients[i] = event.clients[event.len];
 }
 
-static volatile sig_atomic_t signals[NSIG];
-static void signalHandler(int signal) {
-	signals[signal] = 1;
-}
-
 int main(int argc, char *argv[]) {
+	size_t ringSize = 4096;
+	const char *savePath = NULL;
+
 	const char *bindHost = "localhost";
 	const char *bindPort = "6697";
 	char bindPath[PATH_MAX] = "";
 	char certPath[PATH_MAX] = "";
 	char privPath[PATH_MAX] = "";
-	const char *save = NULL;
-	size_t ring = 4096;
 
 	bool insecure = false;
 	const char *clientCert = NULL;
 	const char *clientPriv = NULL;
+
 	const char *host = NULL;
 	const char *port = "6697";
 	char *pass = NULL;
@@ -198,6 +211,7 @@ int main(int argc, char *argv[]) {
 	const char *nick = NULL;
 	const char *user = NULL;
 	const char *real = NULL;
+
 	const char *join = NULL;
 	const char *away = "pounced :3";
 	const char *quit = "connection reset by purr";
@@ -247,18 +261,14 @@ int main(int argc, char *argv[]) {
 			break; case 'a': sasl = true; plain = optarg;
 			break; case 'c': clientCert = optarg;
 			break; case 'e': sasl = true;
-			break; case 'f': save = optarg;
+			break; case 'f': savePath = optarg;
 			break; case 'h': host = optarg;
 			break; case 'j': join = optarg;
 			break; case 'k': clientPriv = optarg;
 			break; case 'n': nick = optarg;
 			break; case 'p': port = optarg;
 			break; case 'r': real = optarg;
-			break; case 's': {
-				char *rest;
-				ring = strtoull(optarg, &rest, 0);
-				if (*rest) errx(EX_USAGE, "invalid size: %s", optarg);
-			}
+			break; case 's': ringSize = parseSize(optarg);
 			break; case 'u': user = optarg;
 			break; case 'v': verbose = true;
 			break; case 'w': pass = optarg;
@@ -282,7 +292,8 @@ int main(int argc, char *argv[]) {
 	if (!privPath[0]) {
 		snprintf(privPath, sizeof(privPath), DEFAULT_PRIV_PATH, bindHost);
 	}
-	if (!host) errx(EX_USAGE, "no host");
+
+	if (!host) errx(EX_USAGE, "host required");
 	if (!nick) {
 		nick = getenv("USER");
 		if (!nick) errx(EX_CONFIG, "USER unset");
@@ -290,8 +301,8 @@ int main(int argc, char *argv[]) {
 	if (!user) user = nick;
 	if (!real) real = nick;
 
-	ringAlloc(ring);
-	if (save) saveLoad(save);
+	ringAlloc(ringSize);
+	if (savePath) saveLoad(savePath);
 
 	struct SplitPath certSplit = splitPath(certPath);
 	struct SplitPath privSplit = splitPath(privPath);
@@ -320,18 +331,13 @@ int main(int argc, char *argv[]) {
 	cap_rights_init(&bindRights, CAP_LISTEN, CAP_ACCEPT);
 	cap_rights_merge(&bindRights, &sockRights);
 
-	if (saveFile) {
-		error = cap_rights_limit(fileno(saveFile), &saveRights);
-		if (error) err(EX_OSERR, "cap_rights_limit");
-	}
-	splitLimit(certSplit, &fileRights);
-	splitLimit(privSplit, &fileRights);
+	if (saveFile) capLimit(fileno(saveFile), &saveRights);
+	capLimitSplit(certSplit, &fileRights);
+	capLimitSplit(privSplit, &fileRights);
 	for (size_t i = 0; i < binds; ++i) {
-		error = cap_rights_limit(bind[i], &bindRights);
-		if (error) err(EX_OSERR, "cap_rights_limit");
+		capLimit(bind[i], &bindRights);
 	}
-	error = cap_rights_limit(server, &sockRights);
-	if (error) err(EX_OSERR, "cap_rights_limit");
+	capLimit(server, &sockRights);
 #endif
 
 	stateLogin(pass, sasl, plain, nick, user, real);
@@ -342,6 +348,11 @@ int main(int argc, char *argv[]) {
 	serverFormat("AWAY :%s\r\n", away);
 	if (join) serverFormat("JOIN :%s\r\n", join);
 
+	signal(SIGINT, signalHandler);
+	signal(SIGTERM, signalHandler);
+	signal(SIGINFO, signalHandler);
+	signal(SIGUSR1, signalHandler);
+
 	for (size_t i = 0; i < binds; ++i) {
 		int error = listen(bind[i], 1);
 		if (error) err(EX_IOERR, "listen");
@@ -349,21 +360,17 @@ int main(int argc, char *argv[]) {
 	}
 	eventAdd(server, NULL);
 
-	signal(SIGINT, signalHandler);
-	signal(SIGTERM, signalHandler);
-	signal(SIGINFO, signalHandler);
-	signal(SIGUSR1, signalHandler);
-
 	size_t clients = 0;
 	for (;;) {
 		int nfds = poll(event.fds, event.len, -1);
 		if (nfds < 0 && errno != EINTR) err(EX_IOERR, "poll");
-
 		if (signals[SIGINT] || signals[SIGTERM]) break;
+
 		if (signals[SIGINFO]) {
 			ringInfo();
 			signals[SIGINFO] = 0;
 		}
+
 		if (signals[SIGUSR1]) {
 			cert = splitOpen(certSplit);
 			priv = splitOpen(privSplit);
@@ -378,7 +385,12 @@ int main(int argc, char *argv[]) {
 			short revents = event.fds[i].revents;
 			if (!revents) continue;
 
-			if (i < binds) {
+			if (event.fds[i].fd == server) {
+				serverRecv();
+				continue;
+			}
+
+			if (!event.clients[i]) {
 				int fd;
 				struct tls *tls = listenAccept(&fd, event.fds[i].fd);
 				int error = tls_handshake(tls);
@@ -393,21 +405,11 @@ int main(int argc, char *argv[]) {
 				continue;
 			}
 
-			if (!event.clients[i]) {
-				if (revents & POLLIN) {
-					serverRecv();
-				} else {
-					errx(EX_UNAVAILABLE, "server hung up");
-				}
-				continue;
-			}
-
 			struct Client *client = event.clients[i];
 			if (revents & POLLIN) clientRecv(client);
 			if (revents & POLLOUT) clientConsume(client);
 			if (clientError(client) || revents & (POLLHUP | POLLERR)) {
 				clientFree(client);
-				close(event.fds[i].fd);
 				eventRemove(i);
 				if (!--clients) serverFormat("AWAY :%s\r\n", away);
 			}
@@ -424,14 +426,11 @@ int main(int argc, char *argv[]) {
 	}
 
 	serverFormat("QUIT :%s\r\n", quit);
-	for (size_t i = 0; i < event.len; ++i) {
-		if (event.clients[i]) {
-			clientFormat(
-				event.clients[i], ":%s QUIT :%s\r\nERROR :Disconnecting\r\n",
-				stateEcho(), quit
-			);
-			clientFree(event.clients[i]);
-		}
+	for (size_t i = binds + 1; i < event.len; ++i) {
+		assert(event.clients[i]);
+		clientFormat(event.clients[i], ":%s QUIT :%s\r\n", stateEcho(), quit);
+		clientFormat(event.clients[i], "ERROR :Disconnecting\r\n");
+		clientFree(event.clients[i]);
 		close(event.fds[i].fd);
 	}
 }