diff options
-rw-r--r-- | database.h | 106 | ||||
-rw-r--r-- | litterbox.c | 275 | ||||
-rw-r--r-- | unscoop.c | 27 |
3 files changed, 222 insertions, 186 deletions
diff --git a/database.h b/database.h index 1d5ef1c..8d8c985 100644 --- a/database.h +++ b/database.h @@ -45,13 +45,14 @@ enum Type { }; static bool verbose; +static sqlite3 *db; -static inline void dbExec(sqlite3 *db, const char *sql) { +static inline void dbExec(const char *sql) { int error = sqlite3_exec(db, sql, NULL, NULL, NULL); if (error) errx(EX_SOFTWARE, "%s: %s", sqlite3_errmsg(db), sql); } -static inline sqlite3 *dbOpen(char *path, int flags) { +static inline void dbOpen(char *path, int flags) { char *base = strrchr(path, '/'); if (flags & SQLITE_OPEN_CREATE && base) { *base = '\0'; @@ -60,22 +61,24 @@ static inline sqlite3 *dbOpen(char *path, int flags) { *base = '/'; } - sqlite3 *db; int error = sqlite3_open_v2(path, &db, flags, NULL); if (error == SQLITE_CANTOPEN) { sqlite3_close(db); - return NULL; + db = NULL; + return; } if (error) errx(EX_NOINPUT, "%s: %s", path, sqlite3_errmsg(db)); sqlite3_busy_timeout(db, 1000); - dbExec(db, SQL(PRAGMA foreign_keys = true;)); - - return db; + dbExec(SQL(PRAGMA foreign_keys = true;)); } -static inline sqlite3 *dbFind(char *path, int flags) { - if (path) return dbOpen(path, flags); +static inline void dbFind(char *path, int flags) { + if (path) { + dbOpen(path, flags); + if (db) return; + errx(EX_NOINPUT, "%s: database not found", path); + } const char *home = getenv("HOME"); const char *dataHome = getenv("XDG_DATA_HOME"); @@ -88,55 +91,71 @@ static inline sqlite3 *dbFind(char *path, int flags) { if (!home) errx(EX_CONFIG, "HOME unset"); snprintf(buf, sizeof(buf), "%s/.local/share/" DATABASE_PATH, home); } - sqlite3 *db = dbOpen(buf, flags); - if (db) return db; + dbOpen(buf, flags); + if (db) return; if (!dataDirs) dataDirs = "/usr/local/share:/usr/share"; while (*dataDirs) { size_t len = strcspn(dataDirs, ":"); snprintf(buf, sizeof(buf), "%.*s/" DATABASE_PATH, (int)len, dataDirs); - db = dbOpen(buf, flags); - if (db) return db; + dbOpen(buf, flags); + if (db) return; dataDirs += len; if (*dataDirs) dataDirs++; } - return NULL; + errx(EX_NOINPUT, "database not found"); } -static inline sqlite3_stmt * -dbPrepare(sqlite3 *db, bool persistent, const char *sql) { +static struct Persist { sqlite3_stmt *stmt; + struct Persist *prev; +} *persistHead; + +static inline void dbPersist(sqlite3_stmt **stmt, const char *sql) { + if (*stmt) return; + int error = sqlite3_prepare_v3( - db, sql, -1, (persistent ? SQLITE_PREPARE_PERSISTENT : 0), &stmt, NULL + db, sql, -1, SQLITE_PREPARE_PERSISTENT, stmt, NULL ); if (error) errx(EX_SOFTWARE, "%s: %s", sqlite3_errmsg(db), sql); + + struct Persist *persist = malloc(sizeof(*persist)); + persist->stmt = *stmt; + persist->prev = persistHead; + persistHead = persist; +} + +static inline void dbClose(void) { + for (struct Persist *persist = persistHead; persist;) { + sqlite3_finalize(persist->stmt); + struct Persist *prev = persist->prev; + free(persist); + persist = prev; + } + sqlite3_close(db); +} + +static inline sqlite3_stmt *dbPrepare(const char *sql) { + sqlite3_stmt *stmt; + int error = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); + if (error) err(EX_SOFTWARE, "%s: %s", sqlite3_errmsg(db), sql); return stmt; } static inline int dbParam(sqlite3_stmt *stmt, const char *param) { int index = sqlite3_bind_parameter_index(stmt, param); if (index) return index; - errx( - EX_SOFTWARE, "no such parameter %s: %s", - param, sqlite3_sql(stmt) - ); + errx(EX_SOFTWARE, "no such parameter %s: %s", param, sqlite3_sql(stmt)); } static inline void dbBindNull(sqlite3_stmt *stmt, const char *param) { if (!sqlite3_bind_null(stmt, dbParam(stmt, param))) return; - errx( - EX_SOFTWARE, "sqlite3_bind_null: %s", - sqlite3_errmsg(sqlite3_db_handle(stmt)) - ); + errx(EX_SOFTWARE, "sqlite3_bind_null: %s", sqlite3_errmsg(db)); } -static inline void -dbBindInt(sqlite3_stmt *stmt, const char *param, int value) { +static inline void dbBindInt(sqlite3_stmt *stmt, const char *param, int value) { if (!sqlite3_bind_int(stmt, dbParam(stmt, param), value)) return; - errx( - EX_SOFTWARE, "sqlite3_bind_int: %s", - sqlite3_errmsg(sqlite3_db_handle(stmt)) - ); + errx(EX_SOFTWARE, "sqlite3_bind_int: %s", sqlite3_errmsg(db)); } static inline void dbBindText5( @@ -146,11 +165,7 @@ static inline void dbBindText5( int error = sqlite3_bind_text( stmt, dbParam(stmt, param), text, len, (copy ? SQLITE_TRANSIENT : NULL) ); - if (!error) return; - errx( - EX_SOFTWARE, "sqlite3_bind_text: %s", - sqlite3_errmsg(sqlite3_db_handle(stmt)) - ); + if (error) err(EX_SOFTWARE, "sqlite3_bind_text: %s", sqlite3_errmsg(db)); } static inline void @@ -171,10 +186,7 @@ dbBindTextCopy(sqlite3_stmt *stmt, const char *param, const char *text) { static inline int dbStep(sqlite3_stmt *stmt) { int error = sqlite3_step(stmt); if (error == SQLITE_ROW || error == SQLITE_DONE) return error; - errx( - EX_SOFTWARE, "%s: %s", - sqlite3_errmsg(sqlite3_db_handle(stmt)), sqlite3_expanded_sql(stmt) - ); + errx(EX_SOFTWARE, "%s: %s", sqlite3_errmsg(db), sqlite3_expanded_sql(stmt)); } static inline void dbRun(sqlite3_stmt *stmt) { @@ -187,8 +199,8 @@ static inline void dbRun(sqlite3_stmt *stmt) { sqlite3_reset(stmt); } -static inline int dbVersion(sqlite3 *db) { - sqlite3_stmt *stmt = dbPrepare(db, false, SQL(PRAGMA user_version;)); +static inline int dbVersion(void) { + sqlite3_stmt *stmt = dbPrepare(SQL(PRAGMA user_version;)); dbStep(stmt); int version = sqlite3_column_int(stmt, 0); sqlite3_finalize(stmt); @@ -258,16 +270,16 @@ static const char *InitSQL = SQL( COMMIT TRANSACTION; ); -static inline void dbInit(sqlite3 *db) { - dbExec(db, InitSQL); +static inline void dbInit(void) { + dbExec(InitSQL); } static const char *MigrationSQL[] = { NULL, }; -static inline void dbMigrate(sqlite3 *db) { - for (int version = dbVersion(db); version < DatabaseVersion; ++version) { - dbExec(db, MigrationSQL[version]); +static inline void dbMigrate(void) { + for (int version = dbVersion(); version < DatabaseVersion; ++version) { + dbExec(MigrationSQL[version]); } } diff --git a/litterbox.c b/litterbox.c index 39fce8f..658822c 100644 --- a/litterbox.c +++ b/litterbox.c @@ -16,9 +16,7 @@ #include <assert.h> #include <err.h> -#include <sqlite3.h> #include <stdarg.h> -#include <stdbool.h> #include <stdio.h> #include <stdlib.h> #include <string.h> @@ -28,37 +26,29 @@ #include "database.h" -static sqlite3 *db; +static struct { + sqlite3_stmt *context; + sqlite3_stmt *event; + sqlite3_stmt *events; +} insert; -static const char *CreateJoins = SQL( - CREATE TEMPORARY TABLE joins ( - nick TEXT NOT NULL, - channel TEXT NOT NULL, - UNIQUE (nick, channel) +static void prepare(void) { + const char *CreateJoins = SQL( + CREATE TEMPORARY TABLE joins ( + nick TEXT NOT NULL, + channel TEXT NOT NULL, + UNIQUE (nick, channel) + ); ); -); - -enum { - InsertContext, - InsertName, - InsertEvent, - InsertJoin, - DeleteJoin, - UpdateJoin, - InsertEvents, -}; -static const char *Statements[] = { - [InsertContext] = SQL( + dbExec(CreateJoins); + + const char *InsertContext = SQL( INSERT OR IGNORE INTO contexts (network, name, query) VALUES (:network, :context, :query); - ), - - [InsertName] = SQL( - INSERT OR IGNORE INTO names (nick, user, host) - VALUES (:nick, :user, :host); - ), + ); + dbPersist(&insert.context, InsertContext); - [InsertEvent] = SQL( + const char *InsertEvent = SQL( INSERT INTO events (time, type, context, name, target, message) SELECT coalesce(datetime(:time), datetime('now')), @@ -69,9 +59,10 @@ static const char *Statements[] = { AND names.nick = :nick AND names.user = :user AND names.host = :host; - ), + ); + dbPersist(&insert.event, InsertEvent); - [InsertEvents] = SQL( + const char *InsertEvents = SQL( INSERT INTO events (time, type, context, name, target, message) SELECT coalesce(datetime(:time), datetime('now')), @@ -83,90 +74,107 @@ static const char *Statements[] = { AND names.nick = :nick AND names.user = :user AND names.host = :host; - ), - - [InsertJoin] = SQL( - INSERT INTO joins (nick, channel) VALUES (:nick, :channel); - ), - [DeleteJoin] = SQL( - DELETE FROM joins WHERE nick = :nick AND channel = :channel; - ), - [UpdateJoin] = SQL( - UPDATE joins SET nick = :new WHERE nick = :old; - ), -}; - -static sqlite3_stmt *stmts[ARRAY_LEN(Statements)]; - -static void prepare(void) { - dbExec(db, CreateJoins); - for (size_t i = 0; i < ARRAY_LEN(stmts); ++i) { - stmts[i] = dbPrepare(db, true, Statements[i]); - } + ); + dbPersist(&insert.events, InsertEvents); } static void bindNetwork(const char *network) { - dbBindTextCopy(stmts[InsertContext], ":network", network); - dbBindTextCopy(stmts[InsertEvent], ":network", network); - dbBindTextCopy(stmts[InsertEvents], ":network", network); + dbBindTextCopy(insert.context, ":network", network); + dbBindTextCopy(insert.event, ":network", network); + dbBindTextCopy(insert.events, ":network", network); } static void insertContext(const char *context, bool query) { - dbBindText(stmts[InsertContext], ":context", context); - dbBindInt(stmts[InsertContext], ":query", query); - dbRun(stmts[InsertContext]); - dbBindText(stmts[InsertEvent], ":context", context); -} - -static void insertName(const char *nick, const char *user, const char *host) { - dbBindText(stmts[InsertName], ":nick", nick); - dbBindText(stmts[InsertName], ":user", user); - dbBindText(stmts[InsertName], ":host", host); - dbRun(stmts[InsertName]); - dbBindText(stmts[InsertEvent], ":nick", nick); - dbBindText(stmts[InsertEvent], ":user", user); - dbBindText(stmts[InsertEvent], ":host", host); - dbBindText(stmts[InsertEvents], ":nick", nick); - dbBindText(stmts[InsertEvents], ":user", user); - dbBindText(stmts[InsertEvents], ":host", host); + dbBindText(insert.context, ":context", context); + dbBindInt(insert.context, ":query", query); + dbRun(insert.context); } static void insertEvent( - const char *time, enum Type type, const char *target, const char *message + const char *time, enum Type type, const char *context, + const char *nick, const char *user, const char *host, + const char *target, const char *message ) { - dbBindText(stmts[InsertEvent], ":time", time); - dbBindInt(stmts[InsertEvent], ":type", type); - dbBindText(stmts[InsertEvent], ":target", target); - dbBindText(stmts[InsertEvent], ":message", message); - dbRun(stmts[InsertEvent]); + dbBindText(insert.event, ":time", time); + dbBindInt(insert.event, ":type", type); + dbBindText(insert.event, ":context", context); + dbBindText(insert.event, ":nick", nick); + dbBindText(insert.event, ":user", user); + dbBindText(insert.event, ":host", host); + dbBindText(insert.event, ":target", target); + dbBindText(insert.event, ":message", message); + dbRun(insert.event); } static void insertEvents( - const char *time, enum Type type, const char *target, const char *message + const char *time, enum Type type, + const char *nick, const char *user, const char *host, + const char *target, const char *message ) { - dbBindText(stmts[InsertEvents], ":time", time); - dbBindInt(stmts[InsertEvents], ":type", type); - dbBindText(stmts[InsertEvents], ":target", target); - dbBindText(stmts[InsertEvents], ":message", message); - dbRun(stmts[InsertEvents]); + dbBindText(insert.events, ":time", time); + dbBindInt(insert.events, ":type", type); + dbBindText(insert.events, ":nick", nick); + dbBindText(insert.events, ":user", user); + dbBindText(insert.events, ":host", host); + dbBindText(insert.events, ":target", target); + dbBindText(insert.events, ":message", message); + dbRun(insert.events); +} + +static void insertName(const char *nick, const char *user, const char *host) { + static sqlite3_stmt *stmt; + const char *sql = SQL( + INSERT OR IGNORE INTO names (nick, user, host) + VALUES (:nick, :user, :host); + ); + dbPersist(&stmt, sql); + dbBindText(stmt, ":nick", nick); + dbBindText(stmt, ":user", user); + dbBindText(stmt, ":host", host); + dbRun(stmt); } static void insertJoin(const char *nick, const char *channel) { - dbBindText(stmts[InsertJoin], ":nick", nick); - dbBindText(stmts[InsertJoin], ":channel", channel); - dbRun(stmts[InsertJoin]); + static sqlite3_stmt *stmt; + const char *sql = SQL( + INSERT OR IGNORE INTO joins (nick, channel) VALUES (:nick, :channel); + ); + dbPersist(&stmt, sql); + dbBindText(stmt, ":nick", nick); + dbBindText(stmt, ":channel", channel); + dbRun(stmt); } static void deleteJoin(const char *nick, const char *channel) { - dbBindText(stmts[DeleteJoin], ":nick", nick); - dbBindText(stmts[DeleteJoin], ":channel", channel); - dbRun(stmts[DeleteJoin]); + static sqlite3_stmt *stmt; + const char *sql = SQL( + DELETE FROM joins WHERE nick = :nick AND channel = :channel; + ); + dbPersist(&stmt, sql); + dbBindText(stmt, ":nick", nick); + dbBindText(stmt, ":channel", channel); + dbRun(stmt); } static void updateJoin(const char *old, const char *new) { - dbBindText(stmts[UpdateJoin], ":old", old); - dbBindText(stmts[UpdateJoin], ":new", new); - dbRun(stmts[UpdateJoin]); + static sqlite3_stmt *stmt; + const char *sql = SQL( + UPDATE joins SET nick = :new WHERE nick = :old; + ); + dbPersist(&stmt, sql); + dbBindText(stmt, ":old", old); + dbBindText(stmt, ":new", new); + dbRun(stmt); +} + +static void clearJoins(const char *channel) { + static sqlite3_stmt *stmt; + const char *sql = SQL( + DELETE FROM joins WHERE channel = :channel; + ); + dbPersist(&stmt, sql); + dbBindText(stmt, ":channel", channel); + dbRun(stmt); } static struct tls *client; @@ -288,31 +296,35 @@ static void handlePrivmsg(struct Message *msg) { require(msg, 2); if (!msg->nick) return; - insertName(msg->nick, msg->user, msg->host); - if (strchr(chanTypes, msg->params[0][0])) { - insertContext(msg->params[0], false); - } else if (strcmp(msg->params[0], self)) { - insertContext(msg->params[0], true); - } else { - insertContext(msg->nick, true); + bool query = true; + const char *context = msg->params[0]; + if (strchr(chanTypes, context[0])) query = false; + if (!strcmp(context, self)) context = msg->nick; + + enum Type type = (!strcmp(msg->cmd, "NOTICE") ? Notice : Privmsg); + char *message = msg->params[1]; + if (!strncmp(message, "\1ACTION ", 8)) { + message += 8; + message[strcspn(message, "\1")] = '\0'; + type = Action; } - if (!strncmp(msg->params[1], "\1ACTION ", 8)) { - char *action = &msg->params[1][8]; - action[strcspn(action, "\1")] = '\0'; - insertEvent(msg->time, Action, NULL, action); - } else if (!strcmp(msg->cmd, "NOTICE")) { - insertEvent(msg->time, Notice, NULL, msg->params[1]); - } else { - insertEvent(msg->time, Privmsg, NULL, msg->params[1]); - } + insertContext(context, query); + insertName(msg->nick, msg->user, msg->host); + insertEvent( + msg->time, type, context, + msg->nick, msg->user, msg->host, NULL, message + ); } static void handleJoin(struct Message *msg) { require(msg, 1); insertContext(msg->params[0], false); insertName(msg->nick, msg->user, msg->host); - insertEvent(msg->time, Join, NULL, NULL); + insertEvent( + msg->time, Join, msg->params[0], + msg->nick, msg->user, msg->host, NULL, NULL + ); insertJoin(msg->nick, msg->params[0]); } @@ -320,25 +332,41 @@ static void handlePart(struct Message *msg) { require(msg, 1); insertContext(msg->params[0], false); insertName(msg->nick, msg->user, msg->host); - insertEvent(msg->time, Part, NULL, msg->params[1]); - deleteJoin(msg->nick, msg->params[0]); - // TODO: Clear joins if self. + insertEvent( + msg->time, Part, msg->params[0], + msg->nick, msg->user, msg->host, NULL, msg->params[1] + ); + if (!strcmp(msg->nick, self)) { + clearJoins(msg->params[0]); + } else { + deleteJoin(msg->nick, msg->params[0]); + } } static void handleKick(struct Message *msg) { require(msg, 2); insertContext(msg->params[0], false); insertName(msg->nick, msg->user, msg->host); - insertEvent(msg->time, Kick, msg->params[1], msg->params[2]); - deleteJoin(msg->params[1], msg->params[0]); - // TODO: Clear joins if self. + insertEvent( + msg->time, Kick, msg->params[0], + msg->nick, msg->user, msg->host, + msg->params[1], msg->params[2] + ); + if (!strcmp(msg->params[1], self)) { + clearJoins(msg->params[0]); + } else { + deleteJoin(msg->params[1], msg->params[0]); + } } static void handleNick(struct Message *msg) { require(msg, 1); if (!strcmp(msg->nick, self)) set(&self, msg->params[0]); insertName(msg->nick, msg->user, msg->host); - insertEvents(msg->time, Nick, msg->params[0], NULL); + insertEvents( + msg->time, Nick, + msg->nick, msg->user, msg->host, msg->params[0], NULL + ); updateJoin(msg->nick, msg->params[0]); } @@ -369,9 +397,9 @@ static void handle(struct Message msg) { for (size_t i = 0; i < ARRAY_LEN(Handlers); ++i) { if (strcmp(msg.cmd, Handlers[i].cmd)) continue; if (Handlers[i].transaction) { - dbExec(db, SQL(BEGIN TRANSACTION;)); + dbExec(SQL(BEGIN TRANSACTION;)); Handlers[i].fn(&msg); - dbExec(db, SQL(COMMIT TRANSACTION;)); + dbExec(SQL(COMMIT TRANSACTION;)); } else { Handlers[i].fn(&msg); } @@ -413,19 +441,17 @@ int main(int argc, char *argv[]) { int flags = SQLITE_OPEN_READWRITE; if (init) flags |= SQLITE_OPEN_CREATE; - - db = dbFind(path, flags); - if (!db) errx(EX_NOINPUT, "database not found"); + dbFind(path, flags); if (init) { - dbInit(db); + dbInit(); return EX_OK; } if (migrate) { - dbMigrate(db); + dbMigrate(); return EX_OK; } - if (dbVersion(db) != DatabaseVersion) { + if (dbVersion() != DatabaseVersion) { errx(EX_CONFIG, "database out of date; migrate with -m"); } @@ -468,6 +494,7 @@ int main(int argc, char *argv[]) { ssize_t ret = tls_read(client, &buf[len], sizeof(buf) - len); if (ret == TLS_WANT_POLLIN || ret == TLS_WANT_POLLOUT) continue; if (ret < 0) errx(EX_IOERR, "tls_read: %s", tls_error(client)); + if (!ret) break; len += ret; char *line = buf; @@ -482,5 +509,5 @@ int main(int argc, char *argv[]) { memmove(buf, line, len); } - // TODO: Clean up statements and db on exit. + dbClose(); } diff --git a/unscoop.c b/unscoop.c index 0b15430..085c134 100644 --- a/unscoop.c +++ b/unscoop.c @@ -221,12 +221,12 @@ static sqlite3_stmt *insertEvent; static int paramNetwork; static int paramContext; -static void prepareInsert(sqlite3 *db) { +static void prepareInsert(void) { const char *InsertName = SQL( INSERT OR IGNORE INTO names (nick, user, host) VALUES (:nick, coalesce(:user, '*'), coalesce(:host, '*')); ); - insertName = dbPrepare(db, true, InsertName); + dbPersist(&insertName, InsertName); const char *InsertEvent = SQL( INSERT INTO events (time, type, context, name, target, message) @@ -244,7 +244,7 @@ static void prepareInsert(sqlite3 *db) { AND names.user = coalesce(:user, '*') AND names.host = coalesce(:host, '*'); ); - insertEvent = dbPrepare(db, true, InsertEvent); + dbPersist(&insertEvent, InsertEvent); paramNetwork = dbParam(insertEvent, ":network"); paramContext = dbParam(insertEvent, ":context"); } @@ -293,7 +293,7 @@ static void dedupEvents(sqlite3 *db) { ), duplicates AS (SELECT event FROM potentials WHERE diff > 50) DELETE FROM events WHERE event IN duplicates; ); - dbExec(db, Delete); + dbExec(Delete); printf("deleted %d events\n", sqlite3_changes(db)); } @@ -317,9 +317,8 @@ int main(int argc, char *argv[]) { } } - sqlite3 *db = dbFind(path, SQLITE_OPEN_READWRITE); - if (!db) errx(EX_NOINPUT, "database not found"); - if (dbVersion(db) != DatabaseVersion) { + dbFind(path, SQLITE_OPEN_READWRITE); + if (dbVersion() != DatabaseVersion) { errx(EX_CONFIG, "database out of date; migrate with litterbox -m"); } @@ -335,6 +334,7 @@ int main(int argc, char *argv[]) { regex[i] = compile(format->matchers[i].pattern); } + sqlite3_stmt *insertContext = NULL; const char *InsertContext = SQL( INSERT OR IGNORE INTO contexts (network, name, query) VALUES ( @@ -342,11 +342,11 @@ int main(int argc, char *argv[]) { NOT (:context LIKE '#%' OR :context LIKE '&%') ); ); - sqlite3_stmt *insertContext = dbPrepare(db, true, InsertContext); + dbPersist(&insertContext, InsertContext); dbBindText(insertContext, ":network", network); dbBindText(insertContext, ":context", context); - prepareInsert(db); + prepareInsert(); dbBindText(insertEvent, ":network", network); dbBindText(insertEvent, ":context", context); @@ -376,7 +376,7 @@ int main(int argc, char *argv[]) { FILE *file = fopen(argv[i], "r"); if (!file) err(EX_NOINPUT, "%s", argv[i]); - dbExec(db, SQL(BEGIN TRANSACTION;)); + dbExec(SQL(BEGIN TRANSACTION;)); regmatch_t pathNetwork = match[i][format->network]; regmatch_t pathContext = match[i][format->context]; @@ -403,12 +403,9 @@ int main(int argc, char *argv[]) { if (ferror(file)) err(EX_IOERR, "%s", argv[i]); fclose(file); - dbExec(db, SQL(COMMIT TRANSACTION;)); + dbExec(SQL(COMMIT TRANSACTION;)); } printf("\n"); - sqlite3_finalize(insertContext); - sqlite3_finalize(insertName); - sqlite3_finalize(insertEvent); - sqlite3_close(db); + dbClose(); } |