From 0c964f63c5e9362d856778b0abf81fdaff004d57 Mon Sep 17 00:00:00 2001 From: Curtis McEnroe Date: Mon, 28 Oct 2019 00:39:16 -0400 Subject: Move entire login flow to state and reorganize it --- state.c | 297 ++++++++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 178 insertions(+), 119 deletions(-) (limited to 'state.c') diff --git a/state.c b/state.c index c63a6b2..85740ed 100644 --- a/state.c +++ b/state.c @@ -24,81 +24,112 @@ #include "bounce.h" -static struct { - char *origin; - char *welcome; - char *yourHost; - char *created; - char *myInfo[4]; -} intro; - -static struct { - char *nick; - char *origin; -} self; +typedef void Handler(struct Message *msg); -static void set(char **field, const char *value) { - if (*field) free(*field); - *field = strdup(value); - if (!*field) err(EX_OSERR, "strdup"); +static void require(const struct Message *msg, bool origin, size_t len) { + if (origin && !msg->origin) { + errx(EX_PROTOCOL, "%s missing origin", msg->cmd); + } + for (size_t i = 0; i < len; ++i) { + if (msg->params[i]) continue; + errx(EX_PROTOCOL, "%s missing parameter %zu", msg->cmd, 1 + i); + } } -struct Channel { - char *name; - char *topic; +static const char Base64[64] = { + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" }; -static struct { - struct Channel *ptr; - size_t cap, len; -} chans; - -static void chanAdd(const char *name) { - if (chans.len == chans.cap) { - chans.cap = (chans.cap ? chans.cap * 2 : 8); - chans.ptr = realloc(chans.ptr, sizeof(*chans.ptr) * chans.cap); - if (!chans.ptr) err(EX_OSERR, "realloc"); +static char *base64(const byte *src, size_t len) { + char *dst = malloc(1 + (len + 2) / 3 * 4); + if (!dst) err(EX_OSERR, "malloc"); + size_t i = 0; + while (len > 2) { + dst[i++] = Base64[0x3F & (src[0] >> 2)]; + dst[i++] = Base64[0x3F & (src[0] << 4 | src[1] >> 4)]; + dst[i++] = Base64[0x3F & (src[1] << 2 | src[2] >> 6)]; + dst[i++] = Base64[0x3F & src[2]]; + src += 3; + len -= 3; } - struct Channel *chan = &chans.ptr[chans.len++]; - chan->name = strdup(name); - if (!chan->name) err(EX_OSERR, "strdup"); - chan->topic = NULL; + if (len) { + dst[i++] = Base64[0x3F & (src[0] >> 2)]; + if (len > 1) { + dst[i++] = Base64[0x3F & (src[0] << 4 | src[1] >> 4)]; + dst[i++] = Base64[0x3F & (src[1] << 2)]; + } else { + dst[i++] = Base64[0x3F & (src[0] << 4)]; + dst[i++] = '='; + } + dst[i++] = '='; + } + dst[i] = '\0'; + return dst; } -static void chanTopic(const char *name, const char *topic) { - for (size_t i = 0; i < chans.len; ++i) { - if (strcmp(chans.ptr[i].name, name)) continue; - free(chans.ptr[i].topic); - chans.ptr[i].topic = strdup(topic); - if (!chans.ptr[i].topic) err(EX_OSERR, "strdup"); - break; +static char *plainBase64; + +void stateLogin( + const char *pass, const char *auth, + const char *nick, const char *user, const char *real +) { + if (auth) { + byte plain[1 + strlen(auth)]; + plain[0] = 0; + for (size_t i = 0; auth[i]; ++i) { + plain[1 + i] = (auth[i] == ':' ? 0 : auth[i]); + } + plainBase64 = base64(plain, sizeof(plain)); + serverFormat("CAP REQ :sasl\r\n"); } + if (pass) serverFormat("PASS :%s\r\n", pass); + serverFormat("NICK %s\r\n", nick); + serverFormat("USER %s 0 * :%s\r\n", user, real); } -static void chanRemove(const char *name) { - for (size_t i = 0; i < chans.len; ++i) { - if (strcmp(chans.ptr[i].name, name)) continue; - free(chans.ptr[i].name); - free(chans.ptr[i].topic); - chans.ptr[i] = chans.ptr[--chans.len]; - break; +static void handleCap(struct Message *msg) { + require(msg, false, 3); + if (strcmp(msg->params[1], "ACK") || strncmp(msg->params[2], "sasl", 4)) { + errx(EX_CONFIG, "server does not support SASL"); } + serverFormat("AUTHENTICATE PLAIN\r\n"); +} + +static void handleAuthenticate(struct Message *msg) { + (void)msg; + if (!plainBase64) errx(EX_PROTOCOL, "unsolicited AUTHENTICATE"); + serverFormat("AUTHENTICATE %s\r\n", plainBase64); + free(plainBase64); + plainBase64 = NULL; +} + +static void handleReplyLoggedIn(struct Message *msg) { + (void)msg; + serverFormat("CAP END\r\n"); +} + +static void handleErrorSASLFail(struct Message *msg) { + require(msg, false, 2); + errx(EX_CONFIG, "%s", msg->params[1]); } static struct { - char **tokens; - size_t cap, len; -} support; + char *nick; + char *origin; +} self; -static void supportAdd(const char *token) { - if (support.len == support.cap) { - support.cap = (support.cap ? support.cap * 2 : 8); - support.tokens = realloc(support.tokens, sizeof(char *) * support.cap); - if (!support.tokens) err(EX_OSERR, "realloc"); - } - support.tokens[support.len] = strdup(token); - if (!support.tokens[support.len]) err(EX_OSERR, "strdup"); - support.len++; +static struct { + char *origin; + char *welcome; + char *yourHost; + char *created; + char *myInfo[4]; +} intro; + +const char *stateEcho(void) { + if (self.origin) return self.origin; + if (self.nick) return self.nick; + return "*"; } bool stateReady(void) { @@ -110,125 +141,155 @@ bool stateReady(void) { && intro.myInfo[0]; } -const char *stateSelf(void) { - if (self.origin) return self.origin; - if (self.nick) return self.nick; - return "*"; -} - -typedef void Handler(struct Message *msg); - -static void handleCap(struct Message *msg) { - bool ack = msg->params[1] && !strcmp(msg->params[1], "ACK"); - bool sasl = msg->params[2] && !strncmp(msg->params[2], "sasl", 4); - if (!ack || !sasl) errx(EX_CONFIG, "server does not support SASL"); - serverFormat("AUTHENTICATE PLAIN\r\n"); -} - -static void handleAuthenticate(struct Message *msg) { - (void)msg; - serverAuth(); -} - -static void handleReplyLoggedIn(struct Message *msg) { - (void)msg; - serverFormat("CAP END\r\n"); +static void set(char **field, const char *value) { + if (*field) free(*field); + *field = strdup(value); + if (!*field) err(EX_OSERR, "strdup"); } -static void handleErrorSASLFail(struct Message *msg) { - if (!msg->params[1]) errx(EX_PROTOCOL, "RPL_SASLFAIL without message"); - errx(EX_CONFIG, "%s", msg->params[1]); +static void handleErrorNicknameInUse(struct Message *msg) { + if (self.nick) return; + require(msg, false, 2); + serverFormat("NICK %s_\r\n", msg->params[1]); } static void handleReplyWelcome(struct Message *msg) { - if (!msg->params[1]) errx(EX_PROTOCOL, "RPL_WELCOME without message"); + require(msg, true, 2); set(&intro.origin, msg->origin); set(&self.nick, msg->params[0]); set(&intro.welcome, msg->params[1]); } static void handleReplyYourHost(struct Message *msg) { - if (!msg->params[1]) errx(EX_PROTOCOL, "RPL_YOURHOST without message"); + require(msg, false, 2); set(&intro.yourHost, msg->params[1]); } static void handleReplyCreated(struct Message *msg) { - if (!msg->params[1]) errx(EX_PROTOCOL, "RPL_CREATED without message"); + require(msg, false, 2); set(&intro.created, msg->params[1]); } static void handleReplyMyInfo(struct Message *msg) { - if (!msg->params[4]) errx(EX_PROTOCOL, "RPL_MYINFO without 4 parameters"); + require(msg, false, 5); set(&intro.myInfo[0], msg->params[1]); set(&intro.myInfo[1], msg->params[2]); set(&intro.myInfo[2], msg->params[3]); set(&intro.myInfo[3], msg->params[4]); } +static struct { + char **tokens; + size_t cap, len; +} support; + +static void supportAdd(const char *token) { + if (support.len == support.cap) { + support.cap = (support.cap ? support.cap * 2 : 8); + support.tokens = realloc(support.tokens, sizeof(char *) * support.cap); + if (!support.tokens) err(EX_OSERR, "realloc"); + } + support.tokens[support.len] = strdup(token); + if (!support.tokens[support.len]) err(EX_OSERR, "strdup"); + support.len++; +} + static void handleReplyISupport(struct Message *msg) { + require(msg, false, 1); for (size_t i = 1; i < ParamCap; ++i) { if (!msg->params[i] || strchr(msg->params[i], ' ')) break; supportAdd(msg->params[i]); } } -static void handleErrorNicknameInUse(struct Message *msg) { - if (self.nick) return; - if (!msg->params[1]) errx(EX_PROTOCOL, "ERR_NICKNAMEINUSE without nick"); - serverFormat("NICK %s_\r\n", msg->params[1]); +struct Channel { + char *name; + char *topic; +}; + +static struct { + struct Channel *ptr; + size_t cap, len; +} chans; + +static void chanAdd(const char *name) { + if (chans.len == chans.cap) { + chans.cap = (chans.cap ? chans.cap * 2 : 8); + chans.ptr = realloc(chans.ptr, sizeof(*chans.ptr) * chans.cap); + if (!chans.ptr) err(EX_OSERR, "realloc"); + } + struct Channel *chan = &chans.ptr[chans.len++]; + chan->name = strdup(name); + if (!chan->name) err(EX_OSERR, "strdup"); + chan->topic = NULL; +} + +static void chanTopic(const char *name, const char *topic) { + for (size_t i = 0; i < chans.len; ++i) { + if (strcmp(chans.ptr[i].name, name)) continue; + set(&chans.ptr[i].topic, topic); + break; + } +} + +static void chanRemove(const char *name) { + for (size_t i = 0; i < chans.len; ++i) { + if (strcmp(chans.ptr[i].name, name)) continue; + free(chans.ptr[i].name); + free(chans.ptr[i].topic); + chans.ptr[i] = chans.ptr[--chans.len]; + break; + } } -static bool fromSelf(const struct Message *msg) { +static bool originSelf(const char *origin) { if (!self.nick) return false; + size_t len = strlen(self.nick); - if (strlen(msg->origin) < len) return false; - if (strncmp(msg->origin, self.nick, len)) return false; - if (msg->origin[len] != '!') return false; - if (!self.origin || strcmp(self.origin, msg->origin)) { - set(&self.origin, msg->origin); + if (strlen(origin) < len) return false; + if (strncmp(origin, self.nick, len)) return false; + if (origin[len] != '!') return false; + + if (!self.origin || strcmp(self.origin, origin)) { + set(&self.origin, origin); } return true; } static void handleNick(struct Message *msg) { - if (!msg->origin) errx(EX_PROTOCOL, "NICK without origin"); - if (!msg->params[0]) errx(EX_PROTOCOL, "NICK without nick"); - if (fromSelf(msg)) set(&self.nick, msg->params[0]); + require(msg, true, 1); + if (originSelf(msg->origin)) set(&self.nick, msg->params[0]); } static void handleJoin(struct Message *msg) { - if (!msg->origin) errx(EX_PROTOCOL, "JOIN without origin"); - if (!msg->params[0]) errx(EX_PROTOCOL, "JOIN without channel"); - if (fromSelf(msg)) chanAdd(msg->params[0]); + require(msg, true, 1); + if (originSelf(msg->origin)) chanAdd(msg->params[0]); } static void handlePart(struct Message *msg) { - if (!msg->origin) errx(EX_PROTOCOL, "PART without origin"); - if (!msg->params[0]) errx(EX_PROTOCOL, "PART without channel"); - if (fromSelf(msg)) chanRemove(msg->params[0]); + require(msg, true, 1); + if (originSelf(msg->origin)) chanRemove(msg->params[0]); } static void handleKick(struct Message *msg) { - if (!msg->params[0]) errx(EX_PROTOCOL, "KICK without channel"); - if (!msg->params[1]) errx(EX_PROTOCOL, "KICK without nick"); + require(msg, false, 2); if (self.nick && !strcmp(msg->params[1], self.nick)) { chanRemove(msg->params[0]); } } static void handleTopic(struct Message *msg) { - if (!msg->params[0]) errx(EX_PROTOCOL, "TOPIC without channel"); - if (!msg->params[1]) errx(EX_PROTOCOL, "TOPIC without topic"); + require(msg, false, 2); chanTopic(msg->params[0], msg->params[1]); } static void handleReplyTopic(struct Message *msg) { - if (!msg->params[1]) errx(EX_PROTOCOL, "RPL_TOPIC without channel"); - if (!msg->params[2]) errx(EX_PROTOCOL, "RPL_TOPIC without topic"); + require(msg, false, 3); chanTopic(msg->params[1], msg->params[2]); } static void handleError(struct Message *msg) { + require(msg, false, 1); errx(EX_UNAVAILABLE, "%s", msg->params[0]); } @@ -281,13 +342,11 @@ void stateSync(struct Client *client) { client, ":%s 001 %s :%s\r\n" ":%s 002 %s :%s\r\n" - ":%s 003 %s :%s\r\n", + ":%s 003 %s :%s\r\n" + ":%s 004 %s %s %s %s %s\r\n", intro.origin, self.nick, intro.welcome, intro.origin, self.nick, intro.yourHost, - intro.origin, self.nick, intro.created - ); - clientFormat( - client, ":%s 004 %s %s %s %s %s\r\n", + intro.origin, self.nick, intro.created, intro.origin, self.nick, intro.myInfo[0], intro.myInfo[1], intro.myInfo[2], intro.myInfo[3] ); -- cgit 1.4.1