summary refs log tree commit diff
path: root/dispatch.c
diff options
context:
space:
mode:
Diffstat (limited to 'dispatch.c')
-rw-r--r--dispatch.c273
1 files changed, 273 insertions, 0 deletions
diff --git a/dispatch.c b/dispatch.c
new file mode 100644
index 0000000..3f28e7f
--- /dev/null
+++ b/dispatch.c
@@ -0,0 +1,273 @@
+/* Copyright (C) 2019  C. McEnroe <june@causal.agency>
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include <err.h>
+#include <fcntl.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <poll.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <sysexits.h>
+#include <unistd.h>
+
+static struct {
+	struct pollfd *ptr;
+	size_t len, cap;
+} event;
+
+static void eventAdd(int fd) {
+	if (event.len == event.cap) {
+		event.cap = (event.cap ? event.cap * 2 : 8);
+		event.ptr = realloc(event.ptr, sizeof(*event.ptr) * event.cap);
+		if (!event.ptr) err(EX_OSERR, "malloc");
+	}
+	event.ptr[event.len++] = (struct pollfd) {
+		.fd = fd,
+		.events = POLLIN,
+	};
+}
+
+static void eventRemove(size_t i) {
+	close(event.ptr[i].fd);
+	event.ptr[i] = event.ptr[--event.len];
+}
+
+static ssize_t sendfd(int sock, int fd) {
+	size_t len = CMSG_SPACE(sizeof(int));
+	char buf[len];
+
+	char x = 0;
+	struct iovec iov = { .iov_base = &x, .iov_len = 1 };
+	struct msghdr msg = {
+		.msg_iov = &iov,
+		.msg_iovlen = 1,
+		.msg_control = buf,
+		.msg_controllen = len,
+	};
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_len = CMSG_LEN(sizeof(int));
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	*(int *)CMSG_DATA(cmsg) = fd;
+
+	return sendmsg(sock, &msg, 0);
+}
+
+static struct {
+	uint8_t buf[4096];
+	uint8_t *ptr;
+	size_t len;
+} peek;
+
+static void skip(size_t skip) {
+	if (peek.len < skip) skip = peek.len;
+	peek.ptr += skip;
+	peek.len -= skip;
+}
+static uint8_t uint8(void) {
+	if (peek.len < 1) return 0;
+	peek.len--;
+	return *peek.ptr++;
+}
+static uint16_t uint16(void) {
+	uint16_t val = uint8();
+	return val << 8 | uint8();
+}
+
+static char *serverName(void) {
+	peek.ptr = peek.buf;
+	// TLSPlaintext
+	if (uint8() != 22) return NULL;
+	skip(4);
+	// Handshake
+	if (uint8() != 1) return NULL;
+	skip(3);
+	// ClientHello
+	skip(34);
+	skip(uint8());
+	skip(uint16());
+	skip(uint8());
+	peek.len = uint16();
+	while (peek.len) {
+		// Extension
+		uint16_t type = uint16();
+		uint16_t len = uint16();
+		if (type != 0) {
+			skip(len);
+			continue;
+		}
+		// ServerNameList
+		skip(2);
+		// ServerName
+		if (uint8() != 0) return NULL;
+		// HostName
+		len = uint16();
+		char *name = (char *)peek.ptr;
+		skip(len);
+		*peek.ptr = '\0';
+		return name;
+	}
+	return NULL;
+}
+
+int main(int argc, char *argv[]) {
+	const char *host = "localhost";
+	const char *port = "6697";
+	const char *path = NULL;
+	int timeout = 1000;
+
+	int opt;
+	while (0 < (opt = getopt(argc, argv, "H:P:t:"))) {
+		switch (opt) {
+			break; case 'H': host = optarg;
+			break; case 'P': port = optarg;
+			break; case 't': {
+				char *rest;
+				timeout = strtol(optarg, &rest, 0);
+				if (*rest) errx(EX_USAGE, "invalid timeout: %s", optarg);
+			}
+			break; default:  return EX_USAGE;
+		}
+	}
+	if (optind < argc) {
+		path = argv[optind];
+	} else {
+		errx(EX_USAGE, "directory required");
+	}
+
+	int dir = open(path, O_DIRECTORY);
+	if (dir < 0) err(EX_NOINPUT, "%s", path);
+
+	int error = fchdir(dir);
+	if (error) err(EX_NOINPUT, "%s", path);
+
+	struct addrinfo *head;
+	struct addrinfo hints = {
+		.ai_family = AF_UNSPEC,
+		.ai_socktype = SOCK_STREAM,
+		.ai_protocol = IPPROTO_TCP,
+	};
+	error = getaddrinfo(host, port, &hints, &head);
+	if (error) errx(EX_NOHOST, "%s:%s: %s", host, port, gai_strerror(error));
+
+	size_t binds = 0;
+	for (struct addrinfo *ai = head; ai; ai = ai->ai_next) {
+		int sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
+		if (sock < 0) err(EX_OSERR, "socket");
+
+		int yes = 1;
+		error = setsockopt(
+			sock, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)
+		);
+		if (error) err(EX_OSERR, "setsockopt");
+
+		error = bind(sock, ai->ai_addr, ai->ai_addrlen);
+		if (error) {
+			warn("%s:%s", host, port);
+			close(sock);
+			continue;
+		}
+
+		eventAdd(sock);
+		binds++;
+	}
+	if (!binds) errx(EX_UNAVAILABLE, "could not bind any sockets");
+	freeaddrinfo(head);
+
+	for (size_t i = 0; i < binds; ++i) {
+		error = listen(event.ptr[i].fd, 1);
+		if (error) err(EX_IOERR, "listen");
+	}
+
+	for (;;) {
+		int nfds = poll(
+			event.ptr, event.len, (event.len > binds ? timeout : -1)
+		);
+		if (nfds < 0) err(EX_IOERR, "poll");
+
+		if (!nfds) {
+			for (size_t i = event.len - 1; i >= binds; --i) {
+				eventRemove(i);
+			}
+			continue;
+		}
+
+		for (size_t i = event.len - 1; i < event.len; --i) {
+			if (!event.ptr[i].revents) continue;
+
+			if (i < binds) {
+				int sock = accept(event.ptr[i].fd, NULL, NULL);
+				if (sock < 0) err(EX_IOERR, "accept");
+
+				int yes = 1;
+				error = setsockopt(
+					sock, SOL_SOCKET, SO_NOSIGPIPE, &yes, sizeof(yes)
+				);
+				if (error) err(EX_OSERR, "setsockopt");
+
+				eventAdd(sock);
+				continue;
+			}
+
+			if (event.ptr[i].revents & (POLLHUP | POLLERR)) {
+				eventRemove(i);
+				continue;
+			}
+
+			ssize_t len = recv(
+				event.ptr[i].fd, peek.buf, sizeof(peek.buf) - 1, MSG_PEEK
+			);
+			if (len < 0) {
+				warn("recv");
+				eventRemove(i);
+				continue;
+			}
+			peek.len = len;
+
+			char *name = serverName();
+			if (!name || name[0] == '.' || name[0] == '/') {
+				eventRemove(i);
+				continue;
+			}
+
+			int sock = socket(PF_UNIX, SOCK_STREAM, 0);
+			if (sock < 0) err(EX_OSERR, "socket");
+
+			struct sockaddr_un addr = { .sun_family = AF_UNIX };
+			strncpy(addr.sun_path, name, sizeof(addr.sun_path));
+#ifdef __FreeBSD__
+			error = connectat(
+				dir, sock, (struct sockaddr *)&addr, SUN_LEN(&addr)
+			);
+#else
+			error = connect(sock, (struct sockaddr *)&addr, SUN_LEN(&addr));
+#endif
+			if (error) warn("%s", name);
+
+			len = sendfd(sock, event.ptr[i].fd);
+			if (len < 0) warn("%s", name);
+
+			close(sock);
+			eventRemove(i);
+		}
+	}
+}