From d4ff0457718b573d2c9d20000c63014666bf5791 Mon Sep 17 00:00:00 2001 From: "C. McEnroe" Date: Sat, 9 Nov 2019 20:17:43 -0500 Subject: Maintain stateCaps and offer them to clients --- bounce.h | 2 ++ client.c | 8 +++----- state.c | 16 ++++++++++++---- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/bounce.h b/bounce.h index 5dd2536..026199f 100644 --- a/bounce.h +++ b/bounce.h @@ -62,6 +62,7 @@ static inline struct Message parse(char *line) { X("chghost", CapChghost) \ X("extended-join", CapExtendedJoin) \ X("invite-notify", CapInviteNotify) \ + X("sasl", CapSASL) \ X("server-time", CapServerTime) \ X("", CapUnsupported) @@ -144,6 +145,7 @@ size_t clientDiff(const struct Client *client); void clientConsume(struct Client *client); bool stateJoinNames; +enum Cap stateCaps; void stateLogin( const char *pass, bool sasl, const char *plain, const char *nick, const char *user, const char *real diff --git a/client.c b/client.c index 9bbc5d9..40bba1c 100644 --- a/client.c +++ b/client.c @@ -150,6 +150,7 @@ static void handlePass(struct Client *client, struct Message *msg) { static void handleCap(struct Client *client, struct Message *msg) { if (!msg->params[0]) msg->params[0] = ""; + enum Cap avail = CapServerTime | (stateCaps & ~CapSASL); if (!strcmp(msg->params[0], "END")) { if (!client->need) return; @@ -158,15 +159,12 @@ static void handleCap(struct Client *client, struct Message *msg) { } else if (!strcmp(msg->params[0], "LS")) { if (client->need) client->need |= NeedCapEnd; - clientFormat( - client, ":%s CAP * LS :%s\r\n", - ORIGIN, capList(CapServerTime) - ); + clientFormat(client, ":%s CAP * LS :%s\r\n", ORIGIN, capList(avail)); } else if (!strcmp(msg->params[0], "REQ") && msg->params[1]) { if (client->need) client->need |= NeedCapEnd; enum Cap caps = capParse(msg->params[1]); - if (caps == CapServerTime) { + if (caps == (avail & caps)) { client->caps |= caps; clientFormat(client, ":%s CAP * ACK :%s\r\n", ORIGIN, msg->params[1]); } else { diff --git a/state.c b/state.c index da5c96f..e17d9e6 100644 --- a/state.c +++ b/state.c @@ -45,7 +45,7 @@ void stateLogin( ) { if (pass) serverFormat("PASS :%s\r\n", pass); if (sasl) { - serverFormat("CAP REQ :sasl\r\n"); + serverFormat("CAP REQ :%s\r\n", capList(CapSASL)); if (plain) { byte buf[1 + strlen(plain)]; buf[0] = 0; @@ -59,14 +59,22 @@ void stateLogin( } serverFormat("NICK %s\r\n", nick); serverFormat("USER %s 0 * :%s\r\n", user, real); + serverFormat("CAP LS\r\n"); } static void handleCap(struct Message *msg) { require(msg, false, 3); - if (strcmp(msg->params[1], "ACK") || strncmp(msg->params[2], "sasl", 4)) { - errx(EX_CONFIG, "server does not support SASL"); + enum Cap caps = capParse(msg->params[2]); + if (!strcmp(msg->params[1], "ACK")) { + stateCaps |= caps; + if (caps & CapSASL) { + serverFormat( + "AUTHENTICATE %s\r\n", (plainBase64 ? "PLAIN" : "EXTERNAL") + ); + } + } else if (!strcmp(msg->params[1], "NAK")) { + errx(EX_CONFIG, "server does not support %s", msg->params[2]); } - serverFormat("AUTHENTICATE %s\r\n", (plainBase64 ? "PLAIN" : "EXTERNAL")); } static void handleAuthenticate(struct Message *msg) { -- cgit 1.4.1