diff options
-rw-r--r-- | bin/dtch.c | 141 |
1 files changed, 69 insertions, 72 deletions
diff --git a/bin/dtch.c b/bin/dtch.c index a5982a5e..41659243 100644 --- a/bin/dtch.c +++ b/bin/dtch.c @@ -52,30 +52,21 @@ static struct sockaddr_un sockAddr(const char *home, const char *name) { return addr; } -static ssize_t writeAll(int fd, const char *buf, size_t len) { - ssize_t writeLen; - while (0 < (writeLen = write(fd, buf, len))) { - buf += writeLen; - len -= writeLen; - } - return writeLen; -} - static char z; static struct iovec iov = { .iov_base = &z, .iov_len = 1 }; static ssize_t sendFd(int sock, int fd) { - size_t len = CMSG_LEN(sizeof(int)); - char buf[len]; + size_t size = CMSG_LEN(sizeof(int)); + char buf[size]; struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1, .msg_control = buf, - .msg_controllen = len, + .msg_controllen = size, }; struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_len = len; + cmsg->cmsg_len = size; cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_type = SCM_RIGHTS; *(int *)CMSG_DATA(cmsg) = fd; @@ -84,13 +75,13 @@ static ssize_t sendFd(int sock, int fd) { } static int recvFd(int sock) { - size_t len = CMSG_LEN(sizeof(int)); - char buf[len]; + size_t size = CMSG_LEN(sizeof(int)); + char buf[size]; struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1, .msg_control = buf, - .msg_controllen = len, + .msg_controllen = size, }; ssize_t n = recvmsg(sock, &msg, 0); @@ -106,7 +97,6 @@ static int recvFd(int sock) { } static struct sockaddr_un addr; - static void unlinkAddr(void) { unlink(addr.sun_path); } @@ -114,23 +104,26 @@ static void unlinkAddr(void) { static int dtch(int argc, char *argv[]) { int error; - struct passwd *user = getUser(); + const struct passwd *user = getUser(); - char *name = user->pw_name; - char *cmd = user->pw_shell; - if (argc > 2) { - name = argv[1]; - cmd = argv[2]; - argv += 2; - } else if (argc > 1) { + const char *name = user->pw_name; + if (argc > 1) { name = argv[1]; argv++; + argc--; + } + if (argc > 1) { + argv++; + } else { + argv[0] = user->pw_shell; } int home = open(user->pw_dir, 0); - if (home < 0) err(EX_IOERR, "%s", user->pw_dir); - error = mkdirat(home, ".dtch", S_IRWXU); - if (error && errno != EEXIST) err(EX_IOERR, "%s/.dtch", user->pw_dir); + if (home < 0) err(EX_CANTCREAT, "%s", user->pw_dir); + + error = mkdirat(home, ".dtch", 0700); + if (error && errno != EEXIST) err(EX_CANTCREAT, "%s/.dtch", user->pw_dir); + error = close(home); if (error) err(EX_IOERR, "%s", user->pw_dir); @@ -139,21 +132,22 @@ static int dtch(int argc, char *argv[]) { addr = sockAddr(user->pw_dir, name); error = bind(server, (struct sockaddr *)&addr, sizeof(addr)); - if (error) err(EX_IOERR, "%s", addr.sun_path); + if (error) err(EX_CANTCREAT, "%s", addr.sun_path); atexit(unlinkAddr); error = chmod(addr.sun_path, 0600); if (error) err(EX_IOERR, "%s", addr.sun_path); - int master; - pid_t pid = forkpty(&master, NULL, NULL, NULL); + error = fcntl(server, F_SETFD, FD_CLOEXEC); + if (error) err(EX_IOERR, "fcntl(%d)", server); + + int pty; + pid_t pid = forkpty(&pty, NULL, NULL, NULL); if (pid < 0) err(EX_OSERR, "forkpty"); if (!pid) { - error = close(server); - if (error) warn("close(%d)", server); - execvp(cmd, argv); - err(EX_OSERR, "%s", cmd); + execvp(*argv, argv); + err(EX_NOINPUT, "%s", *argv); } error = listen(server, 0); @@ -163,26 +157,27 @@ static int dtch(int argc, char *argv[]) { int client = accept(server, NULL, NULL); if (client < 0) err(EX_IOERR, "accept(%d)", server); - ssize_t len = sendFd(client, master); - if (len < 0) warn("sendmsg(%d)", client); + ssize_t size = sendFd(client, pty); + if (size < 0) warn("sendmsg(%d)", client); - len = recv(client, &z, sizeof(z), 0); - if (len < 0) warn("recv(%d)", client); + size = recv(client, &z, sizeof(z), 0); + if (size < 0) warn("recv(%d)", client); error = close(client); if (error) warn("close(%d)", client); - pid_t dead = waitpid(pid, NULL, WNOHANG); - if (dead < 0) warn("waitpid(%d)", pid); - if (dead) exit(EX_OK); + int status; + pid_t dead = waitpid(pid, &status, WNOHANG); + if (dead < 0) err(EX_OSERR, "waitpid(%d)", pid); + if (dead) return WIFEXITED(status) ? WEXITSTATUS(status) : EX_SOFTWARE; } } static struct termios saveTerm; - static void restoreTerm(void) { tcsetattr(STDERR_FILENO, TCSADRAIN, &saveTerm); - printf( + fprintf( + stderr, "\x1b[?1049l" // rmcup "\x1b\x63\x1b[!p\x1b[?3;4l\x1b[4l\x1b>" // reset ); @@ -191,25 +186,25 @@ static void restoreTerm(void) { static int atch(int argc, char *argv[]) { int error; - struct passwd *user = getUser(); - char *name = (argc > 1) ? argv[1] : user->pw_name; + const struct passwd *user = getUser(); + const char *name = (argc > 1) ? argv[1] : user->pw_name; int client = socket(PF_LOCAL, SOCK_STREAM, 0); if (client < 0) err(EX_OSERR, "socket"); struct sockaddr_un addr = sockAddr(user->pw_dir, name); error = connect(client, (struct sockaddr *)&addr, sizeof(addr)); - if (error) err(EX_IOERR, "%s", addr.sun_path); + if (error) err(EX_NOINPUT, "%s", addr.sun_path); - int master = recvFd(client); - if (master < 0) err(EX_IOERR, "recvmsg(%d)", client); + int pty = recvFd(client); + if (pty < 0) err(EX_IOERR, "recvmsg(%d)", client); struct winsize window; error = ioctl(STDERR_FILENO, TIOCGWINSZ, &window); if (error) err(EX_IOERR, "ioctl(%d, TIOCGWINSZ)", STDERR_FILENO); - error = ioctl(master, TIOCSWINSZ, &window); - if (error) err(EX_IOERR, "ioctl(%d, TIOCSWINSZ)", master); + error = ioctl(pty, TIOCSWINSZ, &window); + if (error) err(EX_IOERR, "ioctl(%d, TIOCSWINSZ)", pty); error = tcgetattr(STDERR_FILENO, &saveTerm); if (error) err(EX_IOERR, "tcgetattr(%d)", STDERR_FILENO); @@ -220,43 +215,45 @@ static int atch(int argc, char *argv[]) { error = tcsetattr(STDERR_FILENO, TCSADRAIN, &raw); if (error) err(EX_IOERR, "tcsetattr(%d)", STDERR_FILENO); - char ctrlL = CTRL('L'); - ssize_t len = write(master, &ctrlL, 1); - if (len < 0) err(EX_IOERR, "write(%d)", master); + char c = CTRL('L'); + ssize_t size = write(pty, &c, 1); + if (size < 0) err(EX_IOERR, "write(%d)", pty); + char buf[4096]; struct pollfd fds[2] = { { .fd = STDIN_FILENO, .events = POLLIN }, - { .fd = master, .events = POLLIN }, + { .fd = pty, .events = POLLIN }, }; - while (0 < poll(fds, 2, -1)) { - char buf[4096]; - ssize_t len; + if (fds[0].revents == POLLIN) { + ssize_t readSize = read(STDIN_FILENO, buf, sizeof(buf)); + if (readSize < 0) err(EX_IOERR, "read(%d)", STDIN_FILENO); - if (fds[0].revents) { - len = read(STDIN_FILENO, buf, sizeof(buf)); - if (len < 0) err(EX_IOERR, "read(%d)", STDIN_FILENO); + if (readSize == 1 && buf[0] == CTRL('Q')) return EX_OK; - if (len && buf[0] == CTRL('Q')) exit(EX_OK); - - len = writeAll(master, buf, len); - if (len < 0) err(EX_IOERR, "write(%d)", master); + ssize_t writeSize = write(pty, buf, readSize); + if (writeSize < 0) err(EX_IOERR, "write(%d)", pty); + if (writeSize < readSize) errx(EX_IOERR, "short write(%d)", pty); } - if (fds[1].revents) { - len = read(master, buf, sizeof(buf)); - if (len < 0) err(EX_IOERR, "read(%d)", master); - len = writeAll(STDOUT_FILENO, buf, len); - if (len < 0) err(EX_IOERR, "write(%d)", STDOUT_FILENO); + if (fds[1].revents == POLLIN) { + ssize_t readSize = read(pty, buf, sizeof(buf)); + if (readSize < 0) err(EX_IOERR, "read(%d)", pty); + + ssize_t writeSize = write(STDOUT_FILENO, buf, readSize); + if (writeSize < 0) err(EX_IOERR, "write(%d)", STDOUT_FILENO); + if (writeSize < readSize) { + errx(EX_IOERR, "short write(%d)", STDOUT_FILENO); + } } } - err(EX_IOERR, "poll([%d,%d])", STDIN_FILENO, master); + err(EX_IOERR, "poll"); } int main(int argc, char *argv[]) { switch (argv[0][0]) { case 'd': return dtch(argc, argv); case 'a': return atch(argc, argv); + default: return EX_USAGE; } - return EX_USAGE; } |