about summary refs log tree commit diff
path: root/local.c
diff options
context:
space:
mode:
Diffstat (limited to 'local.c')
-rw-r--r--local.c160
1 files changed, 57 insertions, 103 deletions
diff --git a/local.c b/local.c
index a697e15..fcd670a 100644
--- a/local.c
+++ b/local.c
@@ -1,4 +1,4 @@
-/* Copyright (C) 2019  C. McEnroe <june@causal.agency>
+/* Copyright (C) 2019  June McEnroe <june@causal.agency>
  *
  * This program is free software: you can redistribute it and/or modify
  * it under the terms of the GNU General Public License as published by
@@ -27,82 +27,51 @@
 
 #include <err.h>
 #include <errno.h>
-#include <fcntl.h>
 #include <limits.h>
 #include <netdb.h>
 #include <netinet/in.h>
-#include <netinet/tcp.h>
 #include <stdbool.h>
 #include <stdio.h>
 #include <stdlib.h>
-#include <string.h>
 #include <sys/socket.h>
-#include <sys/stat.h>
 #include <sys/un.h>
 #include <sysexits.h>
 #include <tls.h>
 #include <unistd.h>
 
-#ifdef __FreeBSD__
-#include <sys/capsicum.h>
-#endif
-
 #include "bounce.h"
 
-#ifdef __APPLE__
-#define TCP_KEEPIDLE TCP_KEEPALIVE
-#endif
-
 static struct tls *server;
 
-static byte *readFile(size_t *len, FILE *file) {
-	struct stat stat;
-	int error = fstat(fileno(file), &stat);
-	if (error) err(EX_IOERR, "fstat");
-
-	byte *buf = malloc(stat.st_size);
-	if (!buf) err(EX_OSERR, "malloc");
-
-	rewind(file);
-	*len = fread(buf, 1, stat.st_size, file);
-	if (ferror(file)) err(EX_IOERR, "fread");
-
-	return buf;
-}
-
-void localConfig(FILE *cert, FILE *priv, FILE *ca, bool require) {
-	tls_free(server);
-	server = tls_server();
+int localConfig(
+	const char *cert, const char *priv, const char *ca, bool require
+) {
+	if (!server) server = tls_server();
 	if (!server) errx(EX_SOFTWARE, "tls_server");
 
 	struct tls_config *config = tls_config_new();
 	if (!config) errx(EX_SOFTWARE, "tls_config_new");
 
-	size_t len;
-	byte *buf = readFile(&len, cert);
-	int error = tls_config_set_cert_mem(config, buf, len);
-	if (error) {
-		errx(EX_CONFIG, "tls_config_set_cert_mem: %s", tls_config_error(config));
+	int error;
+	char buf[PATH_MAX];
+	for (int i = 0; configPath(buf, sizeof(buf), cert, i); ++i) {
+		error = tls_config_set_cert_file(config, buf);
+		if (!error) break;
 	}
-	free(buf);
+	if (error) goto fail;
 
-	buf = readFile(&len, priv);
-	error = tls_config_set_key_mem(config, buf, len);
-	if (error) {
-		errx(EX_CONFIG, "tls_config_set_key_mem: %s", tls_config_error(config));
+	for (int i = 0; configPath(buf, sizeof(buf), priv, i); ++i) {
+		error = tls_config_set_key_file(config, buf);
+		if (!error) break;
 	}
-	free(buf);
+	if (error) goto fail;
 
 	if (ca) {
-		buf = readFile(&len, ca);
-		error = tls_config_set_ca_mem(config, buf, len);
-		if (error) {
-			errx(
-				EX_CONFIG, "tls_config_set_ca_mem: %s",
-				tls_config_error(config)
-			);
+		for (int i = 0; configPath(buf, sizeof(buf), ca, i); ++i) {
+			error = tls_config_set_ca_file(config, buf);
+			if (!error) break;
 		}
-		free(buf);
+		if (error) goto fail;
 		if (require) {
 			tls_config_verify_client(config);
 		} else {
@@ -113,6 +82,12 @@ void localConfig(FILE *cert, FILE *priv, FILE *ca, bool require) {
 	error = tls_configure(server, config);
 	if (error) errx(EX_SOFTWARE, "tls_configure: %s", tls_error(server));
 	tls_config_free(config);
+	return 0;
+
+fail:
+	warnx("%s", tls_config_error(config));
+	tls_config_free(config);
+	return -1;
 }
 
 size_t localBind(int fds[], size_t cap, const char *host, const char *port) {
@@ -150,50 +125,39 @@ size_t localBind(int fds[], size_t cap, const char *host, const char *port) {
 }
 
 static bool unix;
-static int unixDir = -1;
-static char unixFile[PATH_MAX];
-
-static void unixUnlink(void) {
-	int error = unlinkat(unixDir, unixFile, 0);
-	if (error) warn("unlinkat");
-}
-
-size_t localUnix(int fds[], size_t cap, const char *path) {
-	if (!cap) return 0;
-
-	int sock = socket(PF_UNIX, SOCK_STREAM, 0);
-	if (sock < 0) err(EX_OSERR, "socket");
 
+static int unixBind(int sock, const char *path) {
 	struct sockaddr_un addr = { .sun_family = AF_UNIX };
-	int len = snprintf(
-		addr.sun_path, sizeof(addr.sun_path), "%s", path
-	);
+	int len = snprintf(addr.sun_path, sizeof(addr.sun_path), "%s", path);
 	if ((size_t)len >= sizeof(addr.sun_path)) {
 		errx(EX_CONFIG, "path too long: %s", path);
 	}
 
 	int error = bind(sock, (struct sockaddr *)&addr, SUN_LEN(&addr));
-	if (error) err(EX_UNAVAILABLE, "%s", path);
+	if (!error || errno != EADDRINUSE) return error;
+
+	int check = socket(PF_UNIX, SOCK_STREAM, 0);
+	if (check < 0) err(EX_OSERR, "socket");
 
-	char dir[PATH_MAX] = ".";
-	const char *base = strrchr(path, '/');
-	if (base) {
-		snprintf(dir, sizeof(dir), "%.*s", (int)(base - path), path);
-		base++;
-	} else {
-		base = path;
+	error = connect(check, (struct sockaddr *)&addr, SUN_LEN(&addr));
+	close(check);
+	if (!error) {
+		errno = EADDRINUSE;
+		return -1;
 	}
-	snprintf(unixFile, sizeof(unixFile), "%s", base);
 
-	unixDir = open(dir, O_DIRECTORY);
-	if (unixDir < 0) err(EX_UNAVAILABLE, "%s", dir);
-	atexit(unixUnlink);
+	unlink(path);
+	return bind(sock, (struct sockaddr *)&addr, SUN_LEN(&addr));
+}
+
+size_t localUnix(int fds[], size_t cap, const char *path) {
+	if (!cap) return 0;
+
+	int sock = socket(PF_UNIX, SOCK_STREAM, 0);
+	if (sock < 0) err(EX_OSERR, "socket");
 
-#ifdef __FreeBSD__
-	cap_rights_t rights;
-	error = cap_rights_limit(unixDir, cap_rights_init(&rights, CAP_UNLINKAT));
-	if (error) err(EX_OSERR, "cap_rights_limit");
-#endif
+	int error = unixBind(sock, path);
+	if (error) err(EX_UNAVAILABLE, "%s", path);
 
 	unix = true;
 	fds[0] = sock;
@@ -221,29 +185,19 @@ static int recvfd(int sock) {
 	return *(int *)CMSG_DATA(cmsg);
 }
 
-struct tls *localAccept(int *fd, int bind) {
-	*fd = accept(bind, NULL, NULL);
-	if (*fd < 0) return NULL;
+int localAccept(struct tls **client, int bind) {
+	int fd = accept(bind, NULL, NULL);
+	if (fd < 0) return fd;
 
 	if (unix) {
-		int sent = recvfd(*fd);
-		if (sent < 0) err(EX_IOERR, "recvfd");
-		close(*fd);
-		*fd = sent;
+		int sent = recvfd(fd);
+		close(fd);
+		if (sent < 0) return sent;
+		fd = sent;
 	}
 
-	int on = 1;
-	int error = setsockopt(*fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof(on));
-	if (error) err(EX_OSERR, "setsockopt");
-
-#ifdef TCP_KEEPIDLE
-	int idle = 15 * 60;
-	error = setsockopt(*fd, IPPROTO_TCP, TCP_KEEPIDLE, &idle, sizeof(idle));
-	if (error) err(EX_OSERR, "setsockopt");
-#endif
-
-	struct tls *client;
-	error = tls_accept_socket(server, &client, *fd);
+	int error = tls_accept_socket(server, client, fd);
 	if (error) errx(EX_SOFTWARE, "tls_accept_socket: %s", tls_error(server));
-	return client;
+
+	return fd;
 }