summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--bounce.c10
-rw-r--r--client.c24
2 files changed, 22 insertions, 12 deletions
diff --git a/bounce.c b/bounce.c
index ec1925c..5b5e524 100644
--- a/bounce.c
+++ b/bounce.c
@@ -486,15 +486,7 @@ int main(int argc, char *argv[]) {
 					warn("accept");
 					continue;
 				}
-
-				error = tls_handshake(tls);
-				if (error) {
-					warnx("tls_handshake: %s", tls_error(tls));
-					tls_free(tls);
-					close(fd);
-				} else {
-					eventAdd(fd, clientAlloc(tls));
-				}
+				eventAdd(fd, clientAlloc(tls));
 				continue;
 			}
 
diff --git a/client.c b/client.c
index 36f8008..6f36539 100644
--- a/client.c
+++ b/client.c
@@ -48,6 +48,7 @@ char *clientAway;
 static size_t active;
 
 enum Need {
+	BIT(NeedHandshake),
 	BIT(NeedNick),
 	BIT(NeedUser),
 	BIT(NeedPass),
@@ -69,11 +70,23 @@ struct Client *clientAlloc(struct tls *tls) {
 	struct Client *client = calloc(1, sizeof(*client));
 	if (!client) err(EX_OSERR, "calloc");
 	client->tls = tls;
-	client->need = NeedNick | NeedUser | (clientPass ? NeedPass : 0);
-	if ((clientCaps & CapSASL) && tls_peer_cert_provided(tls)) {
+	client->need = NeedHandshake | NeedNick | NeedUser;
+	if (clientPass) client->need |= NeedPass;
+	return client;
+}
+
+static void clientHandshake(struct Client *client) {
+	int error = tls_handshake(client->tls);
+	if (error == TLS_WANT_POLLIN || error == TLS_WANT_POLLOUT) return;
+	if (error) {
+		warnx("client tls_handshake: %s", tls_error(client->tls));
+		client->error = true;
+		return;
+	}
+	client->need &= ~NeedHandshake;
+	if ((clientCaps & CapSASL) && tls_peer_cert_provided(client->tls)) {
 		client->need &= ~NeedPass;
 	}
-	return client;
 }
 
 void clientFree(struct Client *client) {
@@ -369,6 +382,11 @@ static bool intercept(const char *line, size_t len) {
 }
 
 void clientRecv(struct Client *client) {
+	if (client->need & NeedHandshake) {
+		clientHandshake(client);
+		return;
+	}
+
 	ssize_t read = tls_read(
 		client->tls,
 		&client->buf[client->len], sizeof(client->buf) - client->len