summary refs log tree commit diff
path: root/litterbox.c
diff options
context:
space:
mode:
authorJune McEnroe <june@causal.agency>2019-12-24 00:58:26 -0500
committerJune McEnroe <june@causal.agency>2019-12-24 00:58:26 -0500
commit2d7108f29b9f7d80709dfbe0c2e2434b07c5a45c (patch)
tree4e00cf884cd6b0291694f830140d79b240d19c28 /litterbox.c
parentActually only use a transaction for handlers that need it (diff)
downloadlitterbox-2d7108f29b9f7d80709dfbe0c2e2434b07c5a45c.tar.gz
litterbox-2d7108f29b9f7d80709dfbe0c2e2434b07c5a45c.zip
It's The Big Refactor
Diffstat (limited to 'litterbox.c')
-rw-r--r--litterbox.c275
1 files changed, 151 insertions, 124 deletions
diff --git a/litterbox.c b/litterbox.c
index 39fce8f..658822c 100644
--- a/litterbox.c
+++ b/litterbox.c
@@ -16,9 +16,7 @@
 
 #include <assert.h>
 #include <err.h>
-#include <sqlite3.h>
 #include <stdarg.h>
-#include <stdbool.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
@@ -28,37 +26,29 @@
 
 #include "database.h"
 
-static sqlite3 *db;
+static struct {
+	sqlite3_stmt *context;
+	sqlite3_stmt *event;
+	sqlite3_stmt *events;
+} insert;
 
-static const char *CreateJoins = SQL(
-	CREATE TEMPORARY TABLE joins (
-		nick TEXT NOT NULL,
-		channel TEXT NOT NULL,
-		UNIQUE (nick, channel)
+static void prepare(void) {
+	const char *CreateJoins = SQL(
+		CREATE TEMPORARY TABLE joins (
+			nick TEXT NOT NULL,
+			channel TEXT NOT NULL,
+			UNIQUE (nick, channel)
+		);
 	);
-);
-
-enum {
-	InsertContext,
-	InsertName,
-	InsertEvent,
-	InsertJoin,
-	DeleteJoin,
-	UpdateJoin,
-	InsertEvents,
-};
-static const char *Statements[] = {
-	[InsertContext] = SQL(
+	dbExec(CreateJoins);
+
+	const char *InsertContext = SQL(
 		INSERT OR IGNORE INTO contexts (network, name, query)
 		VALUES (:network, :context, :query);
-	),
-
-	[InsertName] = SQL(
-		INSERT OR IGNORE INTO names (nick, user, host)
-		VALUES (:nick, :user, :host);
-	),
+	);
+	dbPersist(&insert.context, InsertContext);
 
-	[InsertEvent] = SQL(
+	const char *InsertEvent = SQL(
 		INSERT INTO events (time, type, context, name, target, message)
 		SELECT
 			coalesce(datetime(:time), datetime('now')),
@@ -69,9 +59,10 @@ static const char *Statements[] = {
 			AND names.nick = :nick
 			AND names.user = :user
 			AND names.host = :host;
-	),
+	);
+	dbPersist(&insert.event, InsertEvent);
 
-	[InsertEvents] = SQL(
+	const char *InsertEvents = SQL(
 		INSERT INTO events (time, type, context, name, target, message)
 		SELECT
 			coalesce(datetime(:time), datetime('now')),
@@ -83,90 +74,107 @@ static const char *Statements[] = {
 			AND names.nick = :nick
 			AND names.user = :user
 			AND names.host = :host;
-	),
-
-	[InsertJoin] = SQL(
-		INSERT INTO joins (nick, channel) VALUES (:nick, :channel);
-	),
-	[DeleteJoin] = SQL(
-		DELETE FROM joins WHERE nick = :nick AND channel = :channel;
-	),
-	[UpdateJoin] = SQL(
-		UPDATE joins SET nick = :new WHERE nick = :old;
-	),
-};
-
-static sqlite3_stmt *stmts[ARRAY_LEN(Statements)];
-
-static void prepare(void) {
-	dbExec(db, CreateJoins);
-	for (size_t i = 0; i < ARRAY_LEN(stmts); ++i) {
-		stmts[i] = dbPrepare(db, true, Statements[i]);
-	}
+	);
+	dbPersist(&insert.events, InsertEvents);
 }
 
 static void bindNetwork(const char *network) {
-	dbBindTextCopy(stmts[InsertContext], ":network", network);
-	dbBindTextCopy(stmts[InsertEvent], ":network", network);
-	dbBindTextCopy(stmts[InsertEvents], ":network", network);
+	dbBindTextCopy(insert.context, ":network", network);
+	dbBindTextCopy(insert.event, ":network", network);
+	dbBindTextCopy(insert.events, ":network", network);
 }
 
 static void insertContext(const char *context, bool query) {
-	dbBindText(stmts[InsertContext], ":context", context);
-	dbBindInt(stmts[InsertContext], ":query", query);
-	dbRun(stmts[InsertContext]);
-	dbBindText(stmts[InsertEvent], ":context", context);
-}
-
-static void insertName(const char *nick, const char *user, const char *host) {
-	dbBindText(stmts[InsertName], ":nick", nick);
-	dbBindText(stmts[InsertName], ":user", user);
-	dbBindText(stmts[InsertName], ":host", host);
-	dbRun(stmts[InsertName]);
-	dbBindText(stmts[InsertEvent], ":nick", nick);
-	dbBindText(stmts[InsertEvent], ":user", user);
-	dbBindText(stmts[InsertEvent], ":host", host);
-	dbBindText(stmts[InsertEvents], ":nick", nick);
-	dbBindText(stmts[InsertEvents], ":user", user);
-	dbBindText(stmts[InsertEvents], ":host", host);
+	dbBindText(insert.context, ":context", context);
+	dbBindInt(insert.context, ":query", query);
+	dbRun(insert.context);
 }
 
 static void insertEvent(
-	const char *time, enum Type type, const char *target, const char *message
+	const char *time, enum Type type, const char *context,
+	const char *nick, const char *user, const char *host,
+	const char *target, const char *message
 ) {
-	dbBindText(stmts[InsertEvent], ":time", time);
-	dbBindInt(stmts[InsertEvent], ":type", type);
-	dbBindText(stmts[InsertEvent], ":target", target);
-	dbBindText(stmts[InsertEvent], ":message", message);
-	dbRun(stmts[InsertEvent]);
+	dbBindText(insert.event, ":time", time);
+	dbBindInt(insert.event, ":type", type);
+	dbBindText(insert.event, ":context", context);
+	dbBindText(insert.event, ":nick", nick);
+	dbBindText(insert.event, ":user", user);
+	dbBindText(insert.event, ":host", host);
+	dbBindText(insert.event, ":target", target);
+	dbBindText(insert.event, ":message", message);
+	dbRun(insert.event);
 }
 
 static void insertEvents(
-	const char *time, enum Type type, const char *target, const char *message
+	const char *time, enum Type type,
+	const char *nick, const char *user, const char *host,
+	const char *target, const char *message
 ) {
-	dbBindText(stmts[InsertEvents], ":time", time);
-	dbBindInt(stmts[InsertEvents], ":type", type);
-	dbBindText(stmts[InsertEvents], ":target", target);
-	dbBindText(stmts[InsertEvents], ":message", message);
-	dbRun(stmts[InsertEvents]);
+	dbBindText(insert.events, ":time", time);
+	dbBindInt(insert.events, ":type", type);
+	dbBindText(insert.events, ":nick", nick);
+	dbBindText(insert.events, ":user", user);
+	dbBindText(insert.events, ":host", host);
+	dbBindText(insert.events, ":target", target);
+	dbBindText(insert.events, ":message", message);
+	dbRun(insert.events);
+}
+
+static void insertName(const char *nick, const char *user, const char *host) {
+	static sqlite3_stmt *stmt;
+	const char *sql = SQL(
+		INSERT OR IGNORE INTO names (nick, user, host)
+		VALUES (:nick, :user, :host);
+	);
+	dbPersist(&stmt, sql);
+	dbBindText(stmt, ":nick", nick);
+	dbBindText(stmt, ":user", user);
+	dbBindText(stmt, ":host", host);
+	dbRun(stmt);
 }
 
 static void insertJoin(const char *nick, const char *channel) {
-	dbBindText(stmts[InsertJoin], ":nick", nick);
-	dbBindText(stmts[InsertJoin], ":channel", channel);
-	dbRun(stmts[InsertJoin]);
+	static sqlite3_stmt *stmt;
+	const char *sql = SQL(
+		INSERT OR IGNORE INTO joins (nick, channel) VALUES (:nick, :channel);
+	);
+	dbPersist(&stmt, sql);
+	dbBindText(stmt, ":nick", nick);
+	dbBindText(stmt, ":channel", channel);
+	dbRun(stmt);
 }
 
 static void deleteJoin(const char *nick, const char *channel) {
-	dbBindText(stmts[DeleteJoin], ":nick", nick);
-	dbBindText(stmts[DeleteJoin], ":channel", channel);
-	dbRun(stmts[DeleteJoin]);
+	static sqlite3_stmt *stmt;
+	const char *sql = SQL(
+		DELETE FROM joins WHERE nick = :nick AND channel = :channel;
+	);
+	dbPersist(&stmt, sql);
+	dbBindText(stmt, ":nick", nick);
+	dbBindText(stmt, ":channel", channel);
+	dbRun(stmt);
 }
 
 static void updateJoin(const char *old, const char *new) {
-	dbBindText(stmts[UpdateJoin], ":old", old);
-	dbBindText(stmts[UpdateJoin], ":new", new);
-	dbRun(stmts[UpdateJoin]);
+	static sqlite3_stmt *stmt;
+	const char *sql = SQL(
+		UPDATE joins SET nick = :new WHERE nick = :old;
+	);
+	dbPersist(&stmt, sql);
+	dbBindText(stmt, ":old", old);
+	dbBindText(stmt, ":new", new);
+	dbRun(stmt);
+}
+
+static void clearJoins(const char *channel) {
+	static sqlite3_stmt *stmt;
+	const char *sql = SQL(
+		DELETE FROM joins WHERE channel = :channel;
+	);
+	dbPersist(&stmt, sql);
+	dbBindText(stmt, ":channel", channel);
+	dbRun(stmt);
 }
 
 static struct tls *client;
@@ -288,31 +296,35 @@ static void handlePrivmsg(struct Message *msg) {
 	require(msg, 2);
 	if (!msg->nick) return;
 
-	insertName(msg->nick, msg->user, msg->host);
-	if (strchr(chanTypes, msg->params[0][0])) {
-		insertContext(msg->params[0], false);
-	} else if (strcmp(msg->params[0], self)) {
-		insertContext(msg->params[0], true);
-	} else {
-		insertContext(msg->nick, true);
+	bool query = true;
+	const char *context = msg->params[0];
+	if (strchr(chanTypes, context[0])) query = false;
+	if (!strcmp(context, self)) context = msg->nick;
+
+	enum Type type = (!strcmp(msg->cmd, "NOTICE") ? Notice : Privmsg);
+	char *message = msg->params[1];
+	if (!strncmp(message, "\1ACTION ", 8)) {
+		message += 8;
+		message[strcspn(message, "\1")] = '\0';
+		type = Action;
 	}
 
-	if (!strncmp(msg->params[1], "\1ACTION ", 8)) {
-		char *action = &msg->params[1][8];
-		action[strcspn(action, "\1")] = '\0';
-		insertEvent(msg->time, Action, NULL, action);
-	} else if (!strcmp(msg->cmd, "NOTICE")) {
-		insertEvent(msg->time, Notice, NULL, msg->params[1]);
-	} else {
-		insertEvent(msg->time, Privmsg, NULL, msg->params[1]);
-	}
+	insertContext(context, query);
+	insertName(msg->nick, msg->user, msg->host);
+	insertEvent(
+		msg->time, type, context,
+		msg->nick, msg->user, msg->host, NULL, message
+	);
 }
 
 static void handleJoin(struct Message *msg) {
 	require(msg, 1);
 	insertContext(msg->params[0], false);
 	insertName(msg->nick, msg->user, msg->host);
-	insertEvent(msg->time, Join, NULL, NULL);
+	insertEvent(
+		msg->time, Join, msg->params[0],
+		msg->nick, msg->user, msg->host, NULL, NULL
+	);
 	insertJoin(msg->nick, msg->params[0]);
 }
 
@@ -320,25 +332,41 @@ static void handlePart(struct Message *msg) {
 	require(msg, 1);
 	insertContext(msg->params[0], false);
 	insertName(msg->nick, msg->user, msg->host);
-	insertEvent(msg->time, Part, NULL, msg->params[1]);
-	deleteJoin(msg->nick, msg->params[0]);
-	// TODO: Clear joins if self.
+	insertEvent(
+		msg->time, Part, msg->params[0],
+		msg->nick, msg->user, msg->host, NULL, msg->params[1]
+	);
+	if (!strcmp(msg->nick, self)) {
+		clearJoins(msg->params[0]);
+	} else {
+		deleteJoin(msg->nick, msg->params[0]);
+	}
 }
 
 static void handleKick(struct Message *msg) {
 	require(msg, 2);
 	insertContext(msg->params[0], false);
 	insertName(msg->nick, msg->user, msg->host);
-	insertEvent(msg->time, Kick, msg->params[1], msg->params[2]);
-	deleteJoin(msg->params[1], msg->params[0]);
-	// TODO: Clear joins if self.
+	insertEvent(
+		msg->time, Kick, msg->params[0],
+		msg->nick, msg->user, msg->host,
+		msg->params[1], msg->params[2]
+	);
+	if (!strcmp(msg->params[1], self)) {
+		clearJoins(msg->params[0]);
+	} else {
+		deleteJoin(msg->params[1], msg->params[0]);
+	}
 }
 
 static void handleNick(struct Message *msg) {
 	require(msg, 1);
 	if (!strcmp(msg->nick, self)) set(&self, msg->params[0]);
 	insertName(msg->nick, msg->user, msg->host);
-	insertEvents(msg->time, Nick, msg->params[0], NULL);
+	insertEvents(
+		msg->time, Nick,
+		msg->nick, msg->user, msg->host, msg->params[0], NULL
+	);
 	updateJoin(msg->nick, msg->params[0]);
 }
 
@@ -369,9 +397,9 @@ static void handle(struct Message msg) {
 	for (size_t i = 0; i < ARRAY_LEN(Handlers); ++i) {
 		if (strcmp(msg.cmd, Handlers[i].cmd)) continue;
 		if (Handlers[i].transaction) {
-			dbExec(db, SQL(BEGIN TRANSACTION;));
+			dbExec(SQL(BEGIN TRANSACTION;));
 			Handlers[i].fn(&msg);
-			dbExec(db, SQL(COMMIT TRANSACTION;));
+			dbExec(SQL(COMMIT TRANSACTION;));
 		} else {
 			Handlers[i].fn(&msg);
 		}
@@ -413,19 +441,17 @@ int main(int argc, char *argv[]) {
 
 	int flags = SQLITE_OPEN_READWRITE;
 	if (init) flags |= SQLITE_OPEN_CREATE;
-
-	db = dbFind(path, flags);
-	if (!db) errx(EX_NOINPUT, "database not found");
+	dbFind(path, flags);
 
 	if (init) {
-		dbInit(db);
+		dbInit();
 		return EX_OK;
 	}
 	if (migrate) {
-		dbMigrate(db);
+		dbMigrate();
 		return EX_OK;
 	}
-	if (dbVersion(db) != DatabaseVersion) {
+	if (dbVersion() != DatabaseVersion) {
 		errx(EX_CONFIG, "database out of date; migrate with -m");
 	}
 
@@ -468,6 +494,7 @@ int main(int argc, char *argv[]) {
 		ssize_t ret = tls_read(client, &buf[len], sizeof(buf) - len);
 		if (ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT) continue;
 		if (ret < 0) errx(EX_IOERR, "tls_read: %s", tls_error(client));
+		if (!ret) break;
 		len += ret;
 
 		char *line = buf;
@@ -482,5 +509,5 @@ int main(int argc, char *argv[]) {
 		memmove(buf, line, len);
 	}
 
-	// TODO: Clean up statements and db on exit.
+	dbClose();
 }