summary refs log tree commit diff
path: root/state.c
diff options
context:
space:
mode:
authorJune McEnroe <june@causal.agency>2019-10-28 00:39:16 -0400
committerJune McEnroe <june@causal.agency>2019-10-28 00:39:16 -0400
commit0c964f63c5e9362d856778b0abf81fdaff004d57 (patch)
treec356a244640731ba6e39137bbb3e90e6c75fc8c6 /state.c
parentWait for SASL success before sending CAP END (diff)
downloadpounce-0c964f63c5e9362d856778b0abf81fdaff004d57.tar.gz
pounce-0c964f63c5e9362d856778b0abf81fdaff004d57.zip
Move entire login flow to state and reorganize it
Diffstat (limited to 'state.c')
-rw-r--r--state.c297
1 files changed, 178 insertions, 119 deletions
diff --git a/state.c b/state.c
index c63a6b2..85740ed 100644
--- a/state.c
+++ b/state.c
@@ -24,81 +24,112 @@
 
 #include "bounce.h"
 
-static struct {
-	char *origin;
-	char *welcome;
-	char *yourHost;
-	char *created;
-	char *myInfo[4];
-} intro;
-
-static struct {
-	char *nick;
-	char *origin;
-} self;
+typedef void Handler(struct Message *msg);
 
-static void set(char **field, const char *value) {
-	if (*field) free(*field);
-	*field = strdup(value);
-	if (!*field) err(EX_OSERR, "strdup");
+static void require(const struct Message *msg, bool origin, size_t len) {
+	if (origin && !msg->origin) {
+		errx(EX_PROTOCOL, "%s missing origin", msg->cmd);
+	}
+	for (size_t i = 0; i < len; ++i) {
+		if (msg->params[i]) continue;
+		errx(EX_PROTOCOL, "%s missing parameter %zu", msg->cmd, 1 + i);
+	}
 }
 
-struct Channel {
-	char *name;
-	char *topic;
+static const char Base64[64] = {
+	"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
 };
 
-static struct {
-	struct Channel *ptr;
-	size_t cap, len;
-} chans;
-
-static void chanAdd(const char *name) {
-	if (chans.len == chans.cap) {
-		chans.cap = (chans.cap ? chans.cap * 2 : 8);
-		chans.ptr = realloc(chans.ptr, sizeof(*chans.ptr) * chans.cap);
-		if (!chans.ptr) err(EX_OSERR, "realloc");
+static char *base64(const byte *src, size_t len) {
+	char *dst = malloc(1 + (len + 2) / 3 * 4);
+	if (!dst) err(EX_OSERR, "malloc");
+	size_t i = 0;
+	while (len > 2) {
+		dst[i++] = Base64[0x3F & (src[0] >> 2)];
+		dst[i++] = Base64[0x3F & (src[0] << 4 | src[1] >> 4)];
+		dst[i++] = Base64[0x3F & (src[1] << 2 | src[2] >> 6)];
+		dst[i++] = Base64[0x3F & src[2]];
+		src += 3;
+		len -= 3;
 	}
-	struct Channel *chan = &chans.ptr[chans.len++];
-	chan->name = strdup(name);
-	if (!chan->name) err(EX_OSERR, "strdup");
-	chan->topic = NULL;
+	if (len) {
+		dst[i++] = Base64[0x3F & (src[0] >> 2)];
+		if (len > 1) {
+			dst[i++] = Base64[0x3F & (src[0] << 4 | src[1] >> 4)];
+			dst[i++] = Base64[0x3F & (src[1] << 2)];
+		} else {
+			dst[i++] = Base64[0x3F & (src[0] << 4)];
+			dst[i++] = '=';
+		}
+		dst[i++] = '=';
+	}
+	dst[i] = '\0';
+	return dst;
 }
 
-static void chanTopic(const char *name, const char *topic) {
-	for (size_t i = 0; i < chans.len; ++i) {
-		if (strcmp(chans.ptr[i].name, name)) continue;
-		free(chans.ptr[i].topic);
-		chans.ptr[i].topic = strdup(topic);
-		if (!chans.ptr[i].topic) err(EX_OSERR, "strdup");
-		break;
+static char *plainBase64;
+
+void stateLogin(
+	const char *pass, const char *auth,
+	const char *nick, const char *user, const char *real
+) {
+	if (auth) {
+		byte plain[1 + strlen(auth)];
+		plain[0] = 0;
+		for (size_t i = 0; auth[i]; ++i) {
+			plain[1 + i] = (auth[i] == ':' ? 0 : auth[i]);
+		}
+		plainBase64 = base64(plain, sizeof(plain));
+		serverFormat("CAP REQ :sasl\r\n");
 	}
+	if (pass) serverFormat("PASS :%s\r\n", pass);
+	serverFormat("NICK %s\r\n", nick);
+	serverFormat("USER %s 0 * :%s\r\n", user, real);
 }
 
-static void chanRemove(const char *name) {
-	for (size_t i = 0; i < chans.len; ++i) {
-		if (strcmp(chans.ptr[i].name, name)) continue;
-		free(chans.ptr[i].name);
-		free(chans.ptr[i].topic);
-		chans.ptr[i] = chans.ptr[--chans.len];
-		break;
+static void handleCap(struct Message *msg) {
+	require(msg, false, 3);
+	if (strcmp(msg->params[1], "ACK") || strncmp(msg->params[2], "sasl", 4)) {
+		errx(EX_CONFIG, "server does not support SASL");
 	}
+	serverFormat("AUTHENTICATE PLAIN\r\n");
+}
+
+static void handleAuthenticate(struct Message *msg) {
+	(void)msg;
+	if (!plainBase64) errx(EX_PROTOCOL, "unsolicited AUTHENTICATE");
+	serverFormat("AUTHENTICATE %s\r\n", plainBase64);
+	free(plainBase64);
+	plainBase64 = NULL;
+}
+
+static void handleReplyLoggedIn(struct Message *msg) {
+	(void)msg;
+	serverFormat("CAP END\r\n");
+}
+
+static void handleErrorSASLFail(struct Message *msg) {
+	require(msg, false, 2);
+	errx(EX_CONFIG, "%s", msg->params[1]);
 }
 
 static struct {
-	char **tokens;
-	size_t cap, len;
-} support;
+	char *nick;
+	char *origin;
+} self;
 
-static void supportAdd(const char *token) {
-	if (support.len == support.cap) {
-		support.cap = (support.cap ? support.cap * 2 : 8);
-		support.tokens = realloc(support.tokens, sizeof(char *) * support.cap);
-		if (!support.tokens) err(EX_OSERR, "realloc");
-	}
-	support.tokens[support.len] = strdup(token);
-	if (!support.tokens[support.len]) err(EX_OSERR, "strdup");
-	support.len++;
+static struct {
+	char *origin;
+	char *welcome;
+	char *yourHost;
+	char *created;
+	char *myInfo[4];
+} intro;
+
+const char *stateEcho(void) {
+	if (self.origin) return self.origin;
+	if (self.nick) return self.nick;
+	return "*";
 }
 
 bool stateReady(void) {
@@ -110,125 +141,155 @@ bool stateReady(void) {
 		&& intro.myInfo[0];
 }
 
-const char *stateSelf(void) {
-	if (self.origin) return self.origin;
-	if (self.nick) return self.nick;
-	return "*";
-}
-
-typedef void Handler(struct Message *msg);
-
-static void handleCap(struct Message *msg) {
-	bool ack = msg->params[1] && !strcmp(msg->params[1], "ACK");
-	bool sasl = msg->params[2] && !strncmp(msg->params[2], "sasl", 4);
-	if (!ack || !sasl) errx(EX_CONFIG, "server does not support SASL");
-	serverFormat("AUTHENTICATE PLAIN\r\n");
-}
-
-static void handleAuthenticate(struct Message *msg) {
-	(void)msg;
-	serverAuth();
-}
-
-static void handleReplyLoggedIn(struct Message *msg) {
-	(void)msg;
-	serverFormat("CAP END\r\n");
+static void set(char **field, const char *value) {
+	if (*field) free(*field);
+	*field = strdup(value);
+	if (!*field) err(EX_OSERR, "strdup");
 }
 
-static void handleErrorSASLFail(struct Message *msg) {
-	if (!msg->params[1]) errx(EX_PROTOCOL, "RPL_SASLFAIL without message");
-	errx(EX_CONFIG, "%s", msg->params[1]);
+static void handleErrorNicknameInUse(struct Message *msg) {
+	if (self.nick) return;
+	require(msg, false, 2);
+	serverFormat("NICK %s_\r\n", msg->params[1]);
 }
 
 static void handleReplyWelcome(struct Message *msg) {
-	if (!msg->params[1]) errx(EX_PROTOCOL, "RPL_WELCOME without message");
+	require(msg, true, 2);
 	set(&intro.origin, msg->origin);
 	set(&self.nick, msg->params[0]);
 	set(&intro.welcome, msg->params[1]);
 }
 
 static void handleReplyYourHost(struct Message *msg) {
-	if (!msg->params[1]) errx(EX_PROTOCOL, "RPL_YOURHOST without message");
+	require(msg, false, 2);
 	set(&intro.yourHost, msg->params[1]);
 }
 
 static void handleReplyCreated(struct Message *msg) {
-	if (!msg->params[1]) errx(EX_PROTOCOL, "RPL_CREATED without message");
+	require(msg, false, 2);
 	set(&intro.created, msg->params[1]);
 }
 
 static void handleReplyMyInfo(struct Message *msg) {
-	if (!msg->params[4]) errx(EX_PROTOCOL, "RPL_MYINFO without 4 parameters");
+	require(msg, false, 5);
 	set(&intro.myInfo[0], msg->params[1]);
 	set(&intro.myInfo[1], msg->params[2]);
 	set(&intro.myInfo[2], msg->params[3]);
 	set(&intro.myInfo[3], msg->params[4]);
 }
 
+static struct {
+	char **tokens;
+	size_t cap, len;
+} support;
+
+static void supportAdd(const char *token) {
+	if (support.len == support.cap) {
+		support.cap = (support.cap ? support.cap * 2 : 8);
+		support.tokens = realloc(support.tokens, sizeof(char *) * support.cap);
+		if (!support.tokens) err(EX_OSERR, "realloc");
+	}
+	support.tokens[support.len] = strdup(token);
+	if (!support.tokens[support.len]) err(EX_OSERR, "strdup");
+	support.len++;
+}
+
 static void handleReplyISupport(struct Message *msg) {
+	require(msg, false, 1);
 	for (size_t i = 1; i < ParamCap; ++i) {
 		if (!msg->params[i] || strchr(msg->params[i], ' ')) break;
 		supportAdd(msg->params[i]);
 	}
 }
 
-static void handleErrorNicknameInUse(struct Message *msg) {
-	if (self.nick) return;
-	if (!msg->params[1]) errx(EX_PROTOCOL, "ERR_NICKNAMEINUSE without nick");
-	serverFormat("NICK %s_\r\n", msg->params[1]);
+struct Channel {
+	char *name;
+	char *topic;
+};
+
+static struct {
+	struct Channel *ptr;
+	size_t cap, len;
+} chans;
+
+static void chanAdd(const char *name) {
+	if (chans.len == chans.cap) {
+		chans.cap = (chans.cap ? chans.cap * 2 : 8);
+		chans.ptr = realloc(chans.ptr, sizeof(*chans.ptr) * chans.cap);
+		if (!chans.ptr) err(EX_OSERR, "realloc");
+	}
+	struct Channel *chan = &chans.ptr[chans.len++];
+	chan->name = strdup(name);
+	if (!chan->name) err(EX_OSERR, "strdup");
+	chan->topic = NULL;
+}
+
+static void chanTopic(const char *name, const char *topic) {
+	for (size_t i = 0; i < chans.len; ++i) {
+		if (strcmp(chans.ptr[i].name, name)) continue;
+		set(&chans.ptr[i].topic, topic);
+		break;
+	}
+}
+
+static void chanRemove(const char *name) {
+	for (size_t i = 0; i < chans.len; ++i) {
+		if (strcmp(chans.ptr[i].name, name)) continue;
+		free(chans.ptr[i].name);
+		free(chans.ptr[i].topic);
+		chans.ptr[i] = chans.ptr[--chans.len];
+		break;
+	}
 }
 
-static bool fromSelf(const struct Message *msg) {
+static bool originSelf(const char *origin) {
 	if (!self.nick) return false;
+
 	size_t len = strlen(self.nick);
-	if (strlen(msg->origin) < len) return false;
-	if (strncmp(msg->origin, self.nick, len)) return false;
-	if (msg->origin[len] != '!') return false;
-	if (!self.origin || strcmp(self.origin, msg->origin)) {
-		set(&self.origin, msg->origin);
+	if (strlen(origin) < len) return false;
+	if (strncmp(origin, self.nick, len)) return false;
+	if (origin[len] != '!') return false;
+
+	if (!self.origin || strcmp(self.origin, origin)) {
+		set(&self.origin, origin);
 	}
 	return true;
 }
 
 static void handleNick(struct Message *msg) {
-	if (!msg->origin) errx(EX_PROTOCOL, "NICK without origin");
-	if (!msg->params[0]) errx(EX_PROTOCOL, "NICK without nick");
-	if (fromSelf(msg)) set(&self.nick, msg->params[0]);
+	require(msg, true, 1);
+	if (originSelf(msg->origin)) set(&self.nick, msg->params[0]);
 }
 
 static void handleJoin(struct Message *msg) {
-	if (!msg->origin) errx(EX_PROTOCOL, "JOIN without origin");
-	if (!msg->params[0]) errx(EX_PROTOCOL, "JOIN without channel");
-	if (fromSelf(msg)) chanAdd(msg->params[0]);
+	require(msg, true, 1);
+	if (originSelf(msg->origin)) chanAdd(msg->params[0]);
 }
 
 static void handlePart(struct Message *msg) {
-	if (!msg->origin) errx(EX_PROTOCOL, "PART without origin");
-	if (!msg->params[0]) errx(EX_PROTOCOL, "PART without channel");
-	if (fromSelf(msg)) chanRemove(msg->params[0]);
+	require(msg, true, 1);
+	if (originSelf(msg->origin)) chanRemove(msg->params[0]);
 }
 
 static void handleKick(struct Message *msg) {
-	if (!msg->params[0]) errx(EX_PROTOCOL, "KICK without channel");
-	if (!msg->params[1]) errx(EX_PROTOCOL, "KICK without nick");
+	require(msg, false, 2);
 	if (self.nick && !strcmp(msg->params[1], self.nick)) {
 		chanRemove(msg->params[0]);
 	}
 }
 
 static void handleTopic(struct Message *msg) {
-	if (!msg->params[0]) errx(EX_PROTOCOL, "TOPIC without channel");
-	if (!msg->params[1]) errx(EX_PROTOCOL, "TOPIC without topic");
+	require(msg, false, 2);
 	chanTopic(msg->params[0], msg->params[1]);
 }
 
 static void handleReplyTopic(struct Message *msg) {
-	if (!msg->params[1]) errx(EX_PROTOCOL, "RPL_TOPIC without channel");
-	if (!msg->params[2]) errx(EX_PROTOCOL, "RPL_TOPIC without topic");
+	require(msg, false, 3);
 	chanTopic(msg->params[1], msg->params[2]);
 }
 
 static void handleError(struct Message *msg) {
+	require(msg, false, 1);
 	errx(EX_UNAVAILABLE, "%s", msg->params[0]);
 }
 
@@ -281,13 +342,11 @@ void stateSync(struct Client *client) {
 		client,
 		":%s 001 %s :%s\r\n"
 		":%s 002 %s :%s\r\n"
-		":%s 003 %s :%s\r\n",
+		":%s 003 %s :%s\r\n"
+		":%s 004 %s %s %s %s %s\r\n",
 		intro.origin, self.nick, intro.welcome,
 		intro.origin, self.nick, intro.yourHost,
-		intro.origin, self.nick, intro.created
-	);
-	clientFormat(
-		client, ":%s 004 %s %s %s %s %s\r\n",
+		intro.origin, self.nick, intro.created,
 		intro.origin, self.nick,
 		intro.myInfo[0], intro.myInfo[1], intro.myInfo[2], intro.myInfo[3]
 	);