summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--bin/pngo.c38
1 files changed, 23 insertions, 15 deletions
diff --git a/bin/pngo.c b/bin/pngo.c
index 6dd88c87..425d4a9a 100644
--- a/bin/pngo.c
+++ b/bin/pngo.c
@@ -97,13 +97,13 @@ struct PACKED Header {
 };
 
 static uint8_t bytesPerPixel(struct Header header) {
-    assert(header.depth == 8);
+    assert(header.depth >= 8);
     switch (header.color) {
-        case GRAYSCALE:       return 1;
-        case TRUECOLOR:       return 3;
-        case INDEXED:         return 1;
-        case GRAYSCALE_ALPHA: return 2;
-        case TRUECOLOR_ALPHA: return 4;
+        case GRAYSCALE:       return 1 * header.depth / 8;
+        case TRUECOLOR:       return 3 * header.depth / 8;
+        case INDEXED:         return 1 * header.depth / 8;
+        case GRAYSCALE_ALPHA: return 2 * header.depth / 8;
+        case TRUECOLOR_ALPHA: return 4 * header.depth / 8;
     }
 }
 
@@ -220,6 +220,7 @@ enum PACKED FilterType {
     FILT_AVERAGE,
     FILT_PAETH,
 };
+#define FILT__COUNT (FILT_PAETH + 1)
 
 struct FilterBytes {
     uint8_t x;
@@ -230,9 +231,9 @@ struct FilterBytes {
 
 static uint8_t paethPredictor(struct FilterBytes f) {
     int32_t p = (int32_t)f.a + (int32_t)f.b - (int32_t)f.c;
-    int32_t pa = labs(p - (int32_t)f.a);
-    int32_t pb = labs(p - (int32_t)f.b);
-    int32_t pc = labs(p - (int32_t)f.c);
+    int32_t pa = abs(p - (int32_t)f.a);
+    int32_t pb = abs(p - (int32_t)f.b);
+    int32_t pc = abs(p - (int32_t)f.c);
     if (pa <= pb && pa <= pc) return f.a;
     if (pb <= pc) return f.b;
     return f.c;
@@ -273,7 +274,7 @@ static struct Scanline *scanlines(
     for (uint32_t y = 0; y < header.height; ++y) {
         lines[y].type = &data[y * stride];
         lines[y].data = &data[y * stride + 1];
-        if (*lines[y].type > FILT_PAETH) {
+        if (*lines[y].type >= FILT__COUNT) {
             errx(EX_DATAERR, "%s: invalid filter type %hhu", path, *lines[y].type);
         }
     }
@@ -307,11 +308,18 @@ static void reconData(struct Header header, const struct Scanline *lines) {
 static void filterData(struct Header header, const struct Scanline *lines) {
     uint8_t bpp = bytesPerPixel(header);
     for (uint32_t y = header.height - 1; y < header.height; --y) {
-        // TODO: Filter type heuristic.
-        *lines[y].type = FILT_PAETH;
-        for (uint32_t i = (bpp * header.width) - 1; i < bpp * header.width; --i) {
-            lines[y].data[i] = filt(*lines[y].type, filterBytes(lines, bpp, y, i));
+        uint8_t filter[FILT__COUNT][bpp * header.width];
+        uint32_t heuristic[FILT__COUNT] = { 0 };
+        enum FilterType minType = FILT_NONE;
+        for (enum FilterType type = FILT_NONE; type < FILT__COUNT; ++type) {
+            for (uint32_t i = 0; i < bpp * header.width; ++i) {
+                filter[type][i] = filt(type, filterBytes(lines, bpp, y, i));
+                heuristic[type] += abs((int8_t)filter[type][i]);
+            }
+            if (heuristic[type] < heuristic[minType]) minType = type;
         }
+        *lines[y].type = minType;
+        memcpy(lines[y].data, filter[minType], bpp * header.width);
     }
 }
 
@@ -380,7 +388,7 @@ int main(int argc, char *argv[]) {
             path, header.interlace
         );
     }
-    if (header.depth != 8) {
+    if (header.depth < 8) {
         errx(EX_CONFIG, "%s: unsupported bit depth %hhu", path, header.depth);
     }