summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--bounce.h2
-rw-r--r--client.c8
-rw-r--r--state.c16
3 files changed, 17 insertions, 9 deletions
diff --git a/bounce.h b/bounce.h
index 5dd2536..026199f 100644
--- a/bounce.h
+++ b/bounce.h
@@ -62,6 +62,7 @@ static inline struct Message parse(char *line) {
 	X("chghost", CapChghost) \
 	X("extended-join", CapExtendedJoin) \
 	X("invite-notify", CapInviteNotify) \
+	X("sasl", CapSASL) \
 	X("server-time", CapServerTime) \
 	X("", CapUnsupported)
 
@@ -144,6 +145,7 @@ size_t clientDiff(const struct Client *client);
 void clientConsume(struct Client *client);
 
 bool stateJoinNames;
+enum Cap stateCaps;
 void stateLogin(
 	const char *pass, bool sasl, const char *plain,
 	const char *nick, const char *user, const char *real
diff --git a/client.c b/client.c
index 9bbc5d9..40bba1c 100644
--- a/client.c
+++ b/client.c
@@ -150,6 +150,7 @@ static void handlePass(struct Client *client, struct Message *msg) {
 
 static void handleCap(struct Client *client, struct Message *msg) {
 	if (!msg->params[0]) msg->params[0] = "";
+	enum Cap avail = CapServerTime | (stateCaps & ~CapSASL);
 
 	if (!strcmp(msg->params[0], "END")) {
 		if (!client->need) return;
@@ -158,15 +159,12 @@ static void handleCap(struct Client *client, struct Message *msg) {
 
 	} else if (!strcmp(msg->params[0], "LS")) {
 		if (client->need) client->need |= NeedCapEnd;
-		clientFormat(
-			client, ":%s CAP * LS :%s\r\n",
-			ORIGIN, capList(CapServerTime)
-		);
+		clientFormat(client, ":%s CAP * LS :%s\r\n", ORIGIN, capList(avail));
 
 	} else if (!strcmp(msg->params[0], "REQ") && msg->params[1]) {
 		if (client->need) client->need |= NeedCapEnd;
 		enum Cap caps = capParse(msg->params[1]);
-		if (caps == CapServerTime) {
+		if (caps == (avail & caps)) {
 			client->caps |= caps;
 			clientFormat(client, ":%s CAP * ACK :%s\r\n", ORIGIN, msg->params[1]);
 		} else {
diff --git a/state.c b/state.c
index da5c96f..e17d9e6 100644
--- a/state.c
+++ b/state.c
@@ -45,7 +45,7 @@ void stateLogin(
 ) {
 	if (pass) serverFormat("PASS :%s\r\n", pass);
 	if (sasl) {
-		serverFormat("CAP REQ :sasl\r\n");
+		serverFormat("CAP REQ :%s\r\n", capList(CapSASL));
 		if (plain) {
 			byte buf[1 + strlen(plain)];
 			buf[0] = 0;
@@ -59,14 +59,22 @@ void stateLogin(
 	}
 	serverFormat("NICK %s\r\n", nick);
 	serverFormat("USER %s 0 * :%s\r\n", user, real);
+	serverFormat("CAP LS\r\n");
 }
 
 static void handleCap(struct Message *msg) {
 	require(msg, false, 3);
-	if (strcmp(msg->params[1], "ACK") || strncmp(msg->params[2], "sasl", 4)) {
-		errx(EX_CONFIG, "server does not support SASL");
+	enum Cap caps = capParse(msg->params[2]);
+	if (!strcmp(msg->params[1], "ACK")) {
+		stateCaps |= caps;
+		if (caps & CapSASL) {
+			serverFormat(
+				"AUTHENTICATE %s\r\n", (plainBase64 ? "PLAIN" : "EXTERNAL")
+			);
+		}
+	} else if (!strcmp(msg->params[1], "NAK")) {
+		errx(EX_CONFIG, "server does not support %s", msg->params[2]);
 	}
-	serverFormat("AUTHENTICATE %s\r\n", (plainBase64 ? "PLAIN" : "EXTERNAL"));
 }
 
 static void handleAuthenticate(struct Message *msg) {