summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--bounce.c23
-rw-r--r--bounce.h8
-rw-r--r--client.c140
3 files changed, 153 insertions, 18 deletions
diff --git a/bounce.c b/bounce.c
index ff99b91..07ed4ae 100644
--- a/bounce.c
+++ b/bounce.c
@@ -43,6 +43,7 @@ static void loopAdd(int fd, struct Client *client) {
 
 	loop.fds[loop.len].fd = fd;
 	loop.fds[loop.len].events = POLLIN;
+	loop.fds[loop.len].revents = 0;
 	loop.clients[loop.len] = client;
 	loop.len++;
 }
@@ -64,7 +65,6 @@ static char *censor(char *arg) {
 int main(int argc, char *argv[]) {
 	const char *localHost = "localhost";
 	const char *localPort = "6697";
-	const char *localPass = NULL;
 	char certPath[PATH_MAX] = "";
 	char privPath[PATH_MAX] = "";
 
@@ -84,7 +84,7 @@ int main(int argc, char *argv[]) {
 			break; case 'H': localHost = optarg;
 			break; case 'K': strlcpy(privPath, optarg, sizeof(privPath));
 			break; case 'P': localPort = optarg;
-			break; case 'W': localPass = censor(optarg);
+			break; case 'W': clientPass = censor(optarg);
 			break; case 'a': auth = censor(optarg);
 			break; case 'h': host = optarg;
 			break; case 'j': join = optarg;
@@ -138,17 +138,22 @@ int main(int argc, char *argv[]) {
 		for (size_t i = 0; i < loop.len; ++i) {
 			if (!loop.fds[i].revents) continue;
 			if (i < bindLen) {
-				struct Client *client = clientAlloc();
-				loopAdd(listenAccept(&client->tls, loop.fds[i].fd), client);
+				struct tls *tls;
+				int fd = listenAccept(&tls, loop.fds[i].fd);
+				loopAdd(fd, clientAlloc(tls));
 			} else if (!loop.clients[i]) {
 				serverRecv();
-			} else if (loop.fds[i].revents & POLLERR) {
-				close(loop.fds[i].fd);
-				clientFree(loop.clients[i]);
-				loopRemove(i);
 			} else {
-				clientRecv(loop.clients[i]);
+				struct Client *client = loop.clients[i];
+				if (loop.fds[i].revents & POLLIN) clientRecv(client);
+				if (loop.fds[i].revents & ~POLLIN || clientClose(client)) {
+					clientFree(client);
+					close(loop.fds[i].fd);
+					loopRemove(i);
+					break;
+				}
 			}
 		}
 	}
+	err(EX_IOERR, "poll");
 }
diff --git a/bounce.h b/bounce.h
index 491616b..0929da7 100644
--- a/bounce.h
+++ b/bounce.h
@@ -27,10 +27,6 @@
 #define DEFAULT_PRIV_PATH "/usr/local/etc/letsencrypt/live/%s/privkey.pem"
 #endif
 
-struct Client {
-	struct tls *tls;
-};
-
 #define ARRAY_LEN(a) (sizeof(a) / sizeof(a[0]))
 
 enum { ParamCap = 15 };
@@ -73,6 +69,8 @@ void serverJoin(const char *join);
 void serverSend(const char *ptr, size_t len);
 void serverRecv(void);
 
-struct Client *clientAlloc(void);
+char *clientPass;
+struct Client *clientAlloc(struct tls *tls);
 void clientFree(struct Client *client);
+bool clientClose(const struct Client *client);
 void clientRecv(struct Client *client);
diff --git a/client.c b/client.c
index d012b28..0bba8d3 100644
--- a/client.c
+++ b/client.c
@@ -14,23 +14,155 @@
  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
+#include <assert.h>
 #include <err.h>
+#include <stdarg.h>
+#include <stdbool.h>
+#include <stdio.h>
 #include <stdlib.h>
+#include <string.h>
 #include <sysexits.h>
+#include <tls.h>
 
 #include "bounce.h"
 
-struct Client *clientAlloc(void) {
-	struct Client *client = calloc(1, sizeof(*client));
-	if (!client) err(EX_OSERR, "calloc");
+enum Need {
+	NeedNick = 1 << 0,
+	NeedUser = 1 << 1,
+	NeedPass = 1 << 2,
+	NeedCapEnd = 1 << 3,
+};
+
+struct Client {
+	bool close;
+	struct tls *tls;
+	enum Need need;
+	char buf[4096];
+	size_t len;
+};
+
+struct Client *clientAlloc(struct tls *tls) {
+	struct Client *client = malloc(sizeof(*client));
+	if (!client) err(EX_OSERR, "malloc");
+
+	client->close = false;
+	client->tls = tls;
+	client->need = NeedNick | NeedUser | (clientPass ? NeedPass : 0);
+	client->len = 0;
+
 	return client;
 }
 
 void clientFree(struct Client *client) {
+	tls_close(client->tls);
 	tls_free(client->tls);
 	free(client);
 }
 
-void clientRecv(struct Client *client) {
+bool clientClose(const struct Client *client) {
+	return client->close;
+}
+
+static void clientSend(struct Client *client, const char *ptr, size_t len) {
+	if (verbose) fprintf(stderr, "\x1B[34m%.*s\x1B[m", (int)len, ptr);
+	while (len) {
+		ssize_t ret = tls_write(client->tls, ptr, len);
+		// FIXME: Handle non-blocking?
+		if (ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT) continue;
+		if (ret < 0) {
+			warnx("tls_write: %s", tls_error(client->tls));
+			client->close = true;
+			return;
+		}
+		ptr += ret;
+		len -= ret;
+	}
+}
+
+static void format(struct Client *client, const char *format, ...) {
+	char buf[513];
+	va_list ap;
+	va_start(ap, format);
+	int len = vsnprintf(buf, sizeof(buf), format, ap);
+	va_end(ap);
+	assert(len > 0 && (size_t)len < sizeof(buf));
+	clientSend(client, buf, len);
+}
+
+typedef void Handler(struct Client *client, struct Command cmd);
+
+static void handleNick(struct Client *client, struct Command cmd) {
+	(void)cmd;
+	client->need &= ~NeedNick;
+}
+
+static void handleUser(struct Client *client, struct Command cmd) {
+	(void)cmd;
+	// TODO: Identify client by username.
+	client->need &= ~NeedUser;
+}
+
+static void handlePass(struct Client *client, struct Command cmd) {
+	if (!cmd.params[0] || strcmp(clientPass, cmd.params[0])) {
+		format(client, ":invalid 464 * :Password incorrect\r\n");
+		client->close = true;
+	} else {
+		client->need &= ~NeedPass;
+	}
+}
+
+static void handleCap(struct Client *client, struct Command cmd) {
 	// TODO...
 }
+
+static const struct {
+	const char *cmd;
+	Handler *fn;
+} Handlers[] = {
+	{ "CAP", handleCap },
+	{ "NICK", handleNick },
+	{ "PASS", handlePass },
+	{ "USER", handleUser },
+};
+
+static void clientParse(struct Client *client, char *line) {
+	struct Command cmd = parse(line);
+	if (!cmd.name) {
+		// FIXME: Identify client in message.
+		warnx("no command");
+		client->close = true;
+		return;
+	}
+	for (size_t i = 0; i < ARRAY_LEN(Handlers); ++i) {
+		if (strcmp(cmd.name, Handlers[i].cmd)) continue;
+		Handlers[i].fn(client, cmd);
+		break;
+	}
+}
+
+void clientRecv(struct Client *client) {
+	ssize_t read = tls_read(
+		client->tls,
+		&client->buf[client->len], sizeof(client->buf) - client->len
+	);
+	if (read == TLS_WANT_POLLIN || read == TLS_WANT_POLLOUT) return;
+	if (read < 0) warnx("tls_read: %s", tls_error(client->tls));
+	if (read < 1) {
+		client->close = true;
+		return;
+	}
+	client->len += read;
+
+	char *crlf;
+	char *line = client->buf;
+	for (;;) {
+		crlf = memmem(line, &client->buf[client->len] - line, "\r\n", 2);
+		if (!crlf) break;
+		crlf[0] = '\0';
+		if (verbose) fprintf(stderr, "\x1B[33m%s\x1B[m\n", line);
+		clientParse(client, line);
+		line = crlf + 2;
+	}
+	client->len -= line - client->buf;
+	memmove(client->buf, line, client->len);
+}