#include <err.h>
#include <errno.h>
#include <limits.h>
#include <poll.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

#include <tls.h>

/* Number of elements in each static copy buffer used by the copy helpers. */
#define BUFSIZE 16384

/* Compile-time switch: set to a non-zero value to enable dbg() tracing. */
#define TLS_DEBUG 0

/*
 * dbg(fmt, ...): printf-style trace helper.
 * With TLS_DEBUG enabled it prints to stderr; otherwise it expands to a
 * no-op expression, so its arguments are neither evaluated nor checked
 * against the format string.
 */
#if TLS_DEBUG
# define dbg(...) fprintf(stderr, __VA_ARGS__)
#else
# define dbg(...) ((void)0)
#endif

/*
 * Read one chunk from standard input and forward all of it over TLS.
 *
 * ctx: TLS context the data is written to.
 * fd:  set to -1 when read() reports end-of-file or fails with an error
 *      other than EINTR; left untouched otherwise. A negative fd in a
 *      pollfd entry makes poll() ignore that entry.
 *
 * Exits the process via errx() if tls_write() reports a hard error.
 */
static void copy_from_stdin_to_tls(struct tls *ctx, int *fd)
{
	/*
	 * Must be a byte buffer: the resume offset below is counted in bytes,
	 * so the element size has to be 1 for "buf + off" to be correct.
	 */
	static char buf[BUFSIZE];
	size_t off = 0;
	ssize_t n;

	dbg("DEBUG: data from STDIN\n");
	do {
		n = read(STDIN_FILENO, buf, sizeof(buf));
		dbg("read %zd\n", n);
	} while (n < 0 && errno == EINTR);

	if (n < 1) {
		*fd = -1;
		return;
	}

	/* tls_write() may accept fewer bytes than requested; keep going. */
	while (n > 0) {
		ssize_t r = tls_write(ctx, buf + off, (size_t)n);

		/* Not an error: the same call simply has to be repeated. */
		if (r == TLS_WANT_POLLIN || r == TLS_WANT_POLLOUT)
			continue;
		/* errx, not err: errno is unrelated to the libtls error text. */
		if (r < 0)
			errx(1, "tls_write: %s", tls_error(ctx));
		off += (size_t)r;
		n -= r;
	}
}

/*
 * Read one chunk from the TLS connection and write all of it to standard
 * output.
 *
 * ctx: TLS context to read from.
 *
 * Returns 1 when tls_read() returns 0 (no more data from the peer),
 * 0 after a chunk has been written out completely.
 * Exits the process on a tls_read() error or a write() error.
 */
static int copy_from_tls_to_stdout(struct tls *ctx)
{
	/*
	 * Must be a byte buffer: the resume offset below is counted in bytes,
	 * so the element size has to be 1 for "buf + off" to be correct.
	 */
	static char buf[BUFSIZE];
	size_t off = 0;
	ssize_t n;

	dbg("DEBUG: data from TLS\n");
	/* Both TLS_WANT_* results mean "repeat the same call". */
	do {
		n = tls_read(ctx, buf, sizeof(buf));
	} while (n == TLS_WANT_POLLIN || n == TLS_WANT_POLLOUT);
	/* errx, not err: errno is unrelated to the libtls error text. */
	if (n < 0)
		errx(1, "tls read: %s", tls_error(ctx));

	if (n == 0)
		return 1;

	/* write() may be short or interrupted; loop until all is flushed. */
	while (n > 0) {
		ssize_t r = write(STDOUT_FILENO, buf + off, (size_t)n);

		if (r < 0) {
			if (errno == EINTR)
				continue;
			err(1, "write");
		}
		off += (size_t)r;
		n -= r;
	}
	return 0;
}

/*
 * Wait in poll() with no timeout until it succeeds.
 *
 * fds/nfds: descriptor set handed straight to poll().
 *
 * A failure with EINTR or ENOMEM is retried; any other failure terminates
 * the process via err(). Returns poll()'s non-negative result.
 */
int do_poll(struct pollfd *fds, int nfds)
{
	for (;;) {
		int ready = poll(fds, nfds, -1);

		if (ready >= 0)
			return ready;
		if (errno == EINTR || errno == ENOMEM)
			continue;
		err(1, "poll");
	}
}

/*
 * Relay data in both directions between stdin/stdout and the TLS session.
 *
 * ctx: TLS context used for the encrypted side.
 * sfd: descriptor polled for incoming TLS data.
 *
 * Each pass waits for activity, forwards stdin to TLS if stdin reported any
 * event, then forwards TLS to stdout if sfd reported any event. The loop
 * ends as soon as copy_from_tls_to_stdout() returns non-zero.
 */
static void copy_loop(struct tls *ctx, int sfd)
{
	enum { SLOT_STDIN, SLOT_TLS, SLOT_COUNT };
	struct pollfd watch[SLOT_COUNT];

	watch[SLOT_STDIN].fd = STDIN_FILENO;
	watch[SLOT_STDIN].events = POLLIN;
	watch[SLOT_STDIN].revents = 0;

	watch[SLOT_TLS].fd = sfd;
	watch[SLOT_TLS].events = POLLIN;
	watch[SLOT_TLS].revents = 0;

	for (;;) {
		(void)do_poll(watch, SLOT_COUNT);

		/* The helper may set the stdin slot's fd to -1 to retire it. */
		if (watch[SLOT_STDIN].revents != 0)
			copy_from_stdin_to_tls(ctx, &watch[SLOT_STDIN].fd);

		if (watch[SLOT_TLS].revents == 0)
			continue;
		if (copy_from_tls_to_stdout(ctx) != 0)
			return;
	}
}

/*
 * Print the command-line synopsis and terminate the process.
 *
 * prog: program name shown in the synopsis.
 * ret:  process exit status.
 *
 * The text goes to stdout when ret is 0 (help was asked for) and to stderr
 * otherwise, so a usage error never ends up in the stream this program
 * uses for relayed data.
 */
void usage(const char *prog, int ret) {
	FILE *out = (ret == 0) ? stdout : stderr;

	fprintf(out, "usage: %s [-s FD] [-I] -n SNI\n", prog);
	exit(ret);
}

/*
 * Entry point: wrap an already-open descriptor in a TLS client session and
 * relay stdin/stdout through it.
 *
 * Options:
 *   -h      print usage and exit with status 0
 *   -s FD   descriptor used for both TLS reads and writes (default 1)
 *   -n SNI  server name handed to tls_connect_fds()
 *   -I      disable certificate, name and validity-time verification
 *
 * Returns 0 once copy_loop() finishes; every failure exits with status 1.
 */
int main(int argc, char *argv[])
{
	int c, sfd = 1;
	const char *sni = NULL;
	const char *label;
	struct tls_config *tc;
	struct tls *ctx;
	int insecure = 0;

	while ((c = getopt(argc, argv, "hs:n:I")) != -1) {
		switch (c) {
		case 'h':
			usage(argv[0], 0);
			break;
		case 's': {
			char *end;
			long v;

			/* strtol instead of atoi so garbage is not read as fd 0. */
			errno = 0;
			v = strtol(optarg, &end, 10);
			if (end == optarg || *end != '\0' || errno == ERANGE ||
			    v < 0 || v > INT_MAX)
				errx(1, "%s: invalid file descriptor", optarg);
			sfd = (int)v;
			break;
		}
		case 'n':
			sni = optarg;
			break;
		case 'I':
			insecure = 1;
			break;
		case '?':
		default:
			usage(argv[0], 1);
		}
	}

	/* Passing NULL to a "%s" conversion is undefined; use a placeholder. */
	label = (sni != NULL) ? sni : "(no server name)";

	if (tls_init() == -1)
		errx(1, "tls_init() failed");

	if ((ctx = tls_client()) == NULL)
		errx(1, "tls_client() failed");

	if (insecure) {
		if ((tc = tls_config_new()) == NULL)
			errx(1, "tls_config_new() failed");
		tls_config_insecure_noverifycert(tc);
		tls_config_insecure_noverifyname(tc);
		tls_config_insecure_noverifytime(tc);
		/* errx, not err: errno is unrelated to the libtls error text. */
		if (tls_configure(ctx, tc) == -1)
			errx(1, "tls_configure: %s", tls_error(ctx));
		tls_config_free(tc);
	}

	if (tls_connect_fds(ctx, sfd, sfd, sni) == -1)
		errx(1, "%s: TLS connect failed", label);

	if (tls_handshake(ctx) == -1)
		errx(1, "%s: %s", label, tls_error(ctx));

	copy_loop(ctx, sfd);
	tls_close(ctx);
	return 0;
}