about summary refs log tree commit diff
path: root/state.c
diff options
context:
space:
mode:
Diffstat (limited to 'state.c')
-rw-r--r--state.c39
1 files changed, 22 insertions, 17 deletions
diff --git a/state.c b/state.c
index a0336dc..ba6f8d6 100644
--- a/state.c
+++ b/state.c
@@ -40,21 +40,23 @@ static void require(const struct Message *msg, bool origin, size_t len) {
 static char *plainBase64;
 
 void stateLogin(
-	const char *pass, const char *auth,
+	const char *pass, bool sasl, const char *plain,
 	const char *nick, const char *user, const char *real
 ) {
-	if (auth) {
-		byte plain[1 + strlen(auth)];
-		plain[0] = 0;
-		for (size_t i = 0; auth[i]; ++i) {
-			plain[1 + i] = (auth[i] == ':' ? 0 : auth[i]);
-		}
-		plainBase64 = malloc(BASE64_SIZE(sizeof(plain)));
-		if (!plainBase64) err(EX_OSERR, "malloc");
-		base64(plainBase64, plain, sizeof(plain));
+	if (pass) serverFormat("PASS :%s\r\n", pass);
+	if (sasl) {
 		serverFormat("CAP REQ :sasl\r\n");
+		if (plain) {
+			byte buf[1 + strlen(plain)];
+			buf[0] = 0;
+			for (size_t i = 0; plain[i]; ++i) {
+				buf[1 + i] = (plain[i] == ':' ? 0 : plain[i]);
+			}
+			plainBase64 = malloc(BASE64_SIZE(sizeof(buf)));
+			if (!plainBase64) err(EX_OSERR, "malloc");
+			base64(plainBase64, buf, sizeof(buf));
+		}
 	}
-	if (pass) serverFormat("PASS :%s\r\n", pass);
 	serverFormat("NICK %s\r\n", nick);
 	serverFormat("USER %s 0 * :%s\r\n", user, real);
 }
@@ -64,16 +66,19 @@ static void handleCap(struct Message *msg) {
 	if (strcmp(msg->params[1], "ACK") || strncmp(msg->params[2], "sasl", 4)) {
 		errx(EX_CONFIG, "server does not support SASL");
 	}
-	serverFormat("AUTHENTICATE PLAIN\r\n");
+	serverFormat("AUTHENTICATE %s\r\n", (plainBase64 ? "PLAIN" : "EXTERNAL"));
 }
 
 static void handleAuthenticate(struct Message *msg) {
 	(void)msg;
-	if (!plainBase64) errx(EX_PROTOCOL, "unsolicited AUTHENTICATE");
-	serverFormat("AUTHENTICATE %s\r\n", plainBase64);
-	explicit_bzero(plainBase64, strlen(plainBase64));
-	free(plainBase64);
-	plainBase64 = NULL;
+	if (plainBase64) {
+		serverFormat("AUTHENTICATE %s\r\n", plainBase64);
+		explicit_bzero(plainBase64, strlen(plainBase64));
+		free(plainBase64);
+		plainBase64 = NULL;
+	} else {
+		serverFormat("AUTHENTICATE +\r\n");
+	}
 }
 
 static void handleReplyLoggedIn(struct Message *msg) {