diff options
Diffstat (limited to 'litterbox.c')
-rw-r--r-- | litterbox.c | 275 |
1 files changed, 151 insertions, 124 deletions
diff --git a/litterbox.c b/litterbox.c index 39fce8f..658822c 100644 --- a/litterbox.c +++ b/litterbox.c @@ -16,9 +16,7 @@ #include <assert.h> #include <err.h> -#include <sqlite3.h> #include <stdarg.h> -#include <stdbool.h> #include <stdio.h> #include <stdlib.h> #include <string.h> @@ -28,37 +26,29 @@ #include "database.h" -static sqlite3 *db; +static struct { + sqlite3_stmt *context; + sqlite3_stmt *event; + sqlite3_stmt *events; +} insert; -static const char *CreateJoins = SQL( - CREATE TEMPORARY TABLE joins ( - nick TEXT NOT NULL, - channel TEXT NOT NULL, - UNIQUE (nick, channel) +static void prepare(void) { + const char *CreateJoins = SQL( + CREATE TEMPORARY TABLE joins ( + nick TEXT NOT NULL, + channel TEXT NOT NULL, + UNIQUE (nick, channel) + ); ); -); - -enum { - InsertContext, - InsertName, - InsertEvent, - InsertJoin, - DeleteJoin, - UpdateJoin, - InsertEvents, -}; -static const char *Statements[] = { - [InsertContext] = SQL( + dbExec(CreateJoins); + + const char *InsertContext = SQL( INSERT OR IGNORE INTO contexts (network, name, query) VALUES (:network, :context, :query); - ), - - [InsertName] = SQL( - INSERT OR IGNORE INTO names (nick, user, host) - VALUES (:nick, :user, :host); - ), + ); + dbPersist(&insert.context, InsertContext); - [InsertEvent] = SQL( + const char *InsertEvent = SQL( INSERT INTO events (time, type, context, name, target, message) SELECT coalesce(datetime(:time), datetime('now')), @@ -69,9 +59,10 @@ static const char *Statements[] = { AND names.nick = :nick AND names.user = :user AND names.host = :host; - ), + ); + dbPersist(&insert.event, InsertEvent); - [InsertEvents] = SQL( + const char *InsertEvents = SQL( INSERT INTO events (time, type, context, name, target, message) SELECT coalesce(datetime(:time), datetime('now')), @@ -83,90 +74,107 @@ static const char *Statements[] = { AND names.nick = :nick AND names.user = :user AND names.host = :host; - ), - - [InsertJoin] = SQL( - INSERT INTO joins (nick, channel) VALUES (:nick, :channel); - ), - [DeleteJoin] = SQL( - DELETE FROM joins WHERE nick = :nick AND channel = :channel; - ), - [UpdateJoin] = SQL( - UPDATE joins SET nick = :new WHERE nick = :old; - ), -}; - -static sqlite3_stmt *stmts[ARRAY_LEN(Statements)]; - -static void prepare(void) { - dbExec(db, CreateJoins); - for (size_t i = 0; i < ARRAY_LEN(stmts); ++i) { - stmts[i] = dbPrepare(db, true, Statements[i]); - } + ); + dbPersist(&insert.events, InsertEvents); } static void bindNetwork(const char *network) { - dbBindTextCopy(stmts[InsertContext], ":network", network); - dbBindTextCopy(stmts[InsertEvent], ":network", network); - dbBindTextCopy(stmts[InsertEvents], ":network", network); + dbBindTextCopy(insert.context, ":network", network); + dbBindTextCopy(insert.event, ":network", network); + dbBindTextCopy(insert.events, ":network", network); } static void insertContext(const char *context, bool query) { - dbBindText(stmts[InsertContext], ":context", context); - dbBindInt(stmts[InsertContext], ":query", query); - dbRun(stmts[InsertContext]); - dbBindText(stmts[InsertEvent], ":context", context); -} - -static void insertName(const char *nick, const char *user, const char *host) { - dbBindText(stmts[InsertName], ":nick", nick); - dbBindText(stmts[InsertName], ":user", user); - dbBindText(stmts[InsertName], ":host", host); - dbRun(stmts[InsertName]); - dbBindText(stmts[InsertEvent], ":nick", nick); - dbBindText(stmts[InsertEvent], ":user", user); - dbBindText(stmts[InsertEvent], ":host", host); - dbBindText(stmts[InsertEvents], ":nick", nick); - dbBindText(stmts[InsertEvents], ":user", user); - dbBindText(stmts[InsertEvents], ":host", host); + dbBindText(insert.context, ":context", context); + dbBindInt(insert.context, ":query", query); + dbRun(insert.context); } static void insertEvent( - const char *time, enum Type type, const char *target, const char *message + const char *time, enum Type type, const char *context, + const char *nick, const char *user, const char *host, + const char *target, const char *message ) { - dbBindText(stmts[InsertEvent], ":time", time); - dbBindInt(stmts[InsertEvent], ":type", type); - dbBindText(stmts[InsertEvent], ":target", target); - dbBindText(stmts[InsertEvent], ":message", message); - dbRun(stmts[InsertEvent]); + dbBindText(insert.event, ":time", time); + dbBindInt(insert.event, ":type", type); + dbBindText(insert.event, ":context", context); + dbBindText(insert.event, ":nick", nick); + dbBindText(insert.event, ":user", user); + dbBindText(insert.event, ":host", host); + dbBindText(insert.event, ":target", target); + dbBindText(insert.event, ":message", message); + dbRun(insert.event); } static void insertEvents( - const char *time, enum Type type, const char *target, const char *message + const char *time, enum Type type, + const char *nick, const char *user, const char *host, + const char *target, const char *message ) { - dbBindText(stmts[InsertEvents], ":time", time); - dbBindInt(stmts[InsertEvents], ":type", type); - dbBindText(stmts[InsertEvents], ":target", target); - dbBindText(stmts[InsertEvents], ":message", message); - dbRun(stmts[InsertEvents]); + dbBindText(insert.events, ":time", time); + dbBindInt(insert.events, ":type", type); + dbBindText(insert.events, ":nick", nick); + dbBindText(insert.events, ":user", user); + dbBindText(insert.events, ":host", host); + dbBindText(insert.events, ":target", target); + dbBindText(insert.events, ":message", message); + dbRun(insert.events); +} + +static void insertName(const char *nick, const char *user, const char *host) { + static sqlite3_stmt *stmt; + const char *sql = SQL( + INSERT OR IGNORE INTO names (nick, user, host) + VALUES (:nick, :user, :host); + ); + dbPersist(&stmt, sql); + dbBindText(stmt, ":nick", nick); + dbBindText(stmt, ":user", user); + dbBindText(stmt, ":host", host); + dbRun(stmt); } static void insertJoin(const char *nick, const char *channel) { - dbBindText(stmts[InsertJoin], ":nick", nick); - dbBindText(stmts[InsertJoin], ":channel", channel); - dbRun(stmts[InsertJoin]); + static sqlite3_stmt *stmt; + const char *sql = SQL( + INSERT OR IGNORE INTO joins (nick, channel) VALUES (:nick, :channel); + ); + dbPersist(&stmt, sql); + dbBindText(stmt, ":nick", nick); + dbBindText(stmt, ":channel", channel); + dbRun(stmt); } static void deleteJoin(const char *nick, const char *channel) { - dbBindText(stmts[DeleteJoin], ":nick", nick); - dbBindText(stmts[DeleteJoin], ":channel", channel); - dbRun(stmts[DeleteJoin]); + static sqlite3_stmt *stmt; + const char *sql = SQL( + DELETE FROM joins WHERE nick = :nick AND channel = :channel; + ); + dbPersist(&stmt, sql); + dbBindText(stmt, ":nick", nick); + dbBindText(stmt, ":channel", channel); + dbRun(stmt); } static void updateJoin(const char *old, const char *new) { - dbBindText(stmts[UpdateJoin], ":old", old); - dbBindText(stmts[UpdateJoin], ":new", new); - dbRun(stmts[UpdateJoin]); + static sqlite3_stmt *stmt; + const char *sql = SQL( + UPDATE joins SET nick = :new WHERE nick = :old; + ); + dbPersist(&stmt, sql); + dbBindText(stmt, ":old", old); + dbBindText(stmt, ":new", new); + dbRun(stmt); +} + +static void clearJoins(const char *channel) { + static sqlite3_stmt *stmt; + const char *sql = SQL( + DELETE FROM joins WHERE channel = :channel; + ); + dbPersist(&stmt, sql); + dbBindText(stmt, ":channel", channel); + dbRun(stmt); } static struct tls *client; @@ -288,31 +296,35 @@ static void handlePrivmsg(struct Message *msg) { require(msg, 2); if (!msg->nick) return; - insertName(msg->nick, msg->user, msg->host); - if (strchr(chanTypes, msg->params[0][0])) { - insertContext(msg->params[0], false); - } else if (strcmp(msg->params[0], self)) { - insertContext(msg->params[0], true); - } else { - insertContext(msg->nick, true); + bool query = true; + const char *context = msg->params[0]; + if (strchr(chanTypes, context[0])) query = false; + if (!strcmp(context, self)) context = msg->nick; + + enum Type type = (!strcmp(msg->cmd, "NOTICE") ? Notice : Privmsg); + char *message = msg->params[1]; + if (!strncmp(message, "\1ACTION ", 8)) { + message += 8; + message[strcspn(message, "\1")] = '\0'; + type = Action; } - if (!strncmp(msg->params[1], "\1ACTION ", 8)) { - char *action = &msg->params[1][8]; - action[strcspn(action, "\1")] = '\0'; - insertEvent(msg->time, Action, NULL, action); - } else if (!strcmp(msg->cmd, "NOTICE")) { - insertEvent(msg->time, Notice, NULL, msg->params[1]); - } else { - insertEvent(msg->time, Privmsg, NULL, msg->params[1]); - } + insertContext(context, query); + insertName(msg->nick, msg->user, msg->host); + insertEvent( + msg->time, type, context, + msg->nick, msg->user, msg->host, NULL, message + ); } static void handleJoin(struct Message *msg) { require(msg, 1); insertContext(msg->params[0], false); insertName(msg->nick, msg->user, msg->host); - insertEvent(msg->time, Join, NULL, NULL); + insertEvent( + msg->time, Join, msg->params[0], + msg->nick, msg->user, msg->host, NULL, NULL + ); insertJoin(msg->nick, msg->params[0]); } @@ -320,25 +332,41 @@ static void handlePart(struct Message *msg) { require(msg, 1); insertContext(msg->params[0], false); insertName(msg->nick, msg->user, msg->host); - insertEvent(msg->time, Part, NULL, msg->params[1]); - deleteJoin(msg->nick, msg->params[0]); - // TODO: Clear joins if self. + insertEvent( + msg->time, Part, msg->params[0], + msg->nick, msg->user, msg->host, NULL, msg->params[1] + ); + if (!strcmp(msg->nick, self)) { + clearJoins(msg->params[0]); + } else { + deleteJoin(msg->nick, msg->params[0]); + } } static void handleKick(struct Message *msg) { require(msg, 2); insertContext(msg->params[0], false); insertName(msg->nick, msg->user, msg->host); - insertEvent(msg->time, Kick, msg->params[1], msg->params[2]); - deleteJoin(msg->params[1], msg->params[0]); - // TODO: Clear joins if self. + insertEvent( + msg->time, Kick, msg->params[0], + msg->nick, msg->user, msg->host, + msg->params[1], msg->params[2] + ); + if (!strcmp(msg->params[1], self)) { + clearJoins(msg->params[0]); + } else { + deleteJoin(msg->params[1], msg->params[0]); + } } static void handleNick(struct Message *msg) { require(msg, 1); if (!strcmp(msg->nick, self)) set(&self, msg->params[0]); insertName(msg->nick, msg->user, msg->host); - insertEvents(msg->time, Nick, msg->params[0], NULL); + insertEvents( + msg->time, Nick, + msg->nick, msg->user, msg->host, msg->params[0], NULL + ); updateJoin(msg->nick, msg->params[0]); } @@ -369,9 +397,9 @@ static void handle(struct Message msg) { for (size_t i = 0; i < ARRAY_LEN(Handlers); ++i) { if (strcmp(msg.cmd, Handlers[i].cmd)) continue; if (Handlers[i].transaction) { - dbExec(db, SQL(BEGIN TRANSACTION;)); + dbExec(SQL(BEGIN TRANSACTION;)); Handlers[i].fn(&msg); - dbExec(db, SQL(COMMIT TRANSACTION;)); + dbExec(SQL(COMMIT TRANSACTION;)); } else { Handlers[i].fn(&msg); } @@ -413,19 +441,17 @@ int main(int argc, char *argv[]) { int flags = SQLITE_OPEN_READWRITE; if (init) flags |= SQLITE_OPEN_CREATE; - - db = dbFind(path, flags); - if (!db) errx(EX_NOINPUT, "database not found"); + dbFind(path, flags); if (init) { - dbInit(db); + dbInit(); return EX_OK; } if (migrate) { - dbMigrate(db); + dbMigrate(); return EX_OK; } - if (dbVersion(db) != DatabaseVersion) { + if (dbVersion() != DatabaseVersion) { errx(EX_CONFIG, "database out of date; migrate with -m"); } @@ -468,6 +494,7 @@ int main(int argc, char *argv[]) { ssize_t ret = tls_read(client, &buf[len], sizeof(buf) - len); if (ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT) continue; if (ret < 0) errx(EX_IOERR, "tls_read: %s", tls_error(client)); + if (!ret) break; len += ret; char *line = buf; @@ -482,5 +509,5 @@ int main(int argc, char *argv[]) { memmove(buf, line, len); } - // TODO: Clean up statements and db on exit. + dbClose(); } |