/*
 * Copyright (c) 2003, 2004, 2005 Nokia
 * Author: Timo Savola <tsavola@movial.fi>
 *
 * This program is licensed under GPL (see COPYING for details)
 */

#include "fakeroot.h"
#include "types.h"
#include "common.h"
#include "daemon.h"
#include "mount.h"

#define _GNU_SOURCE

#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <signal.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/param.h>
#include <sys/stat.h>

/* Process environment; fakeroot_relay() swaps it temporarily so that
 * getenv() searches the handler's environment instead of our own. */
extern char **environ;

/* Which way copy_msg() moves a message through a relay node. */
typedef enum {
	DAEMON_TO_CLIENT = 0,	/* read from the daemon socket, written to the client */
	CLIENT_TO_DAEMON = 1	/* read from the client socket, written to the daemon */
} direction_t;

/*
 * One relayed connection: the socket accepted from a local client paired
 * with the socket opened to the fakeroot daemon on its behalf.  Nodes form
 * a singly-linked list owned by do_relay().
 */
typedef struct node_s {
	struct node_s *next;	/* following node, or NULL at the tail */
	int client;		/* socket connected to the client */
	int daemon;		/* socket connected to the fakeroot daemon */
	fake_dev_t stored_dev;	/* st.dev of the last client message, exactly as
				 * received (network byte order), restored into
				 * the daemon's answer */
} node_t;

/*
 * Cleared by sig_exit() on SIGHUP/SIGTERM; do_relay() loops while it is set.
 * It is written from a signal handler, so it must be volatile sig_atomic_t
 * (ISO C 7.14.1.1) rather than a plain flag type.
 */
static volatile sig_atomic_t parent_alive = TRUE;
/* Relay listening socket; -1 until fork_relay() installs it in the child. */
static volatile int listen_sd = -1;
/* Local IPv4 address of the client connection (host byte order), stamped
 * into client messages whose remote ID is 0. */
static uint32_t relay_id;

/*
 * Allocate a node pairing client_sd with daemon_sd and link it at the tail
 * of the list rooted at *head.  Allocation failure is reported through
 * send_error() and terminates the process with status 1.
 */
static void append_node(handler_t *data, node_t **head, int client_sd, int daemon_sd)
{
	node_t **tail;
	node_t *fresh = calloc(1, sizeof (*fresh));

	if (fresh == NULL) {
		errno = 0;
		send_error(data, oom);
		exit(1);
	}

	fresh->client = client_sd;
	fresh->daemon = daemon_sd;

	/* Follow the link slots until the terminating NULL one. */
	tail = head;
	while (*tail != NULL)
		tail = &(*tail)->next;

	*tail = fresh;
}

/*
 * Unlink node from the list rooted at *head, close both of its sockets and
 * free it.  node must be a member of the list.
 */
static void remove_node(node_t **head, node_t *node)
{
	node_t **link = head;

	/* Locate the slot that currently points at node. */
	while (*link != node)
		link = &(*link)->next;

	*link = node->next;

	debug("Removing node: client=%d, daemon=%d", node->client, node->daemon);

	close(node->client);
	close(node->daemon);

	free(node);
}

/*
 * For every NFS mount whose point_dev is still 0, stat() the mount point
 * and record its st_dev in point_dev.  A failing stat() is reported via
 * send_error() and terminates the process with status 1.
 */
static void get_device_numbers(handler_t *data)
{
	mount_t **cursor = data->mounts;

	if (cursor == NULL) {
		debug("No mounts");
		return;
	}

	for (; *cursor != NULL; ++cursor) {
		mount_t *mnt = *cursor;
		struct stat sb;

		/* Skip non-NFS mounts and those already resolved. */
		if (mnt->info.type != MTYPE_NFS || mnt->point_dev != 0LL)
			continue;

		if (stat(mnt->info.point, &sb) < 0) {
			send_error(data, "Can't stat %s", mnt->info.point);
			exit(1);
		}

		mnt->point_dev = sb.st_dev;

		debug("Device number of %s is %lld", mnt->info.device, mnt->info.device_dev);
		debug("Device number of %s is %lld", mnt->info.point, mnt->point_dev);
	}
}

/*
 * Remember the client's raw st->dev in node->stored_dev.  If it matches
 * the point_dev of a mount, replace st->dev with that mount's device_dev
 * (network byte order) and return TRUE; otherwise leave it and return FALSE.
 */
static bool_t translate_dev(handler_t *data, node_t *node, struct fakestat *st)
{
	const fake_dev_t client_dev = ntohll(st->dev);
	mount_t **it;

	/* Saved untouched so restore_dev() can put it back into the answer. */
	node->stored_dev = st->dev;

	if (!data->mounts) {
		debug("No mounts");
		return FALSE;
	}

	for (it = data->mounts; *it != NULL; it++) {
		const mount_t *mnt = *it;

		if (mnt->point_dev != client_dev)
			continue;

		debug("Translating device number: %lld -> %lld", client_dev, mnt->info.device_dev);
		st->dev = htonll(mnt->info.device_dev);
		return TRUE;
	}

	debug("No match for device number %lld", client_dev);
	return FALSE;
}

/*
 * Put back the st->dev value that translate_dev() saved from the last
 * client message on this node.
 */
static void restore_dev(node_t *node, struct fakestat *st)
{
	const fake_dev_t saved = node->stored_dev;

	debug("Restoring device number");
	st->dev = saved;
}

/*
 * Read from fd until size bytes are stored in buf, end-of-file is reached,
 * or an error other than EINTR occurs.  Stream sockets may return fewer
 * bytes than requested, so a single read() is not enough.
 * Returns the byte count (smaller than size only on end-of-file), or -1
 * with errno set.
 */
static ssize_t relay_read_all(int fd, void *buf, size_t size)
{
	char *p = buf;
	size_t done = 0;

	while (done < size) {
		ssize_t n = read(fd, p + done, size - done);

		if (n < 0) {
			if (errno == EINTR)
				continue;
			return -1;
		}

		if (n == 0)
			break;

		done += (size_t) n;
	}

	return (ssize_t) done;
}

/*
 * Write all size bytes of buf to fd, restarting after EINTR and after
 * short writes.  Returns 0 on success, or -1 with errno set.
 */
static int relay_write_all(int fd, const void *buf, size_t size)
{
	const char *p = buf;
	size_t done = 0;

	while (done < size) {
		ssize_t n = write(fd, p + done, size - done);

		if (n < 0) {
			if (errno == EINTR)
				continue;
			return -1;
		}

		done += (size_t) n;
	}

	return 0;
}

/*
 * Move one fixed-size struct fake_msg across node in the given direction.
 *
 * Client-to-daemon messages get their device number translated
 * (translate_dev()).  When that succeeds, the remote ID is forced to 0;
 * otherwise a zero remote ID is replaced with relay_id.  Daemon-to-client
 * messages get the client's original device number restored.
 *
 * Once select() reports the source readable, this blocks until the whole
 * message has arrived.
 *
 * Returns 0 when a message was forwarded, or -1 when the source closed its
 * end (possibly mid-message) and the node should be removed.  Socket errors
 * are fatal: reported via send_error() followed by exit(1).
 */
static int copy_msg(handler_t *data, node_t *node, direction_t direction)
{
	int source, destination;
	struct fake_msg buf;
	ssize_t len;

	if (direction == CLIENT_TO_DAEMON) {
		source = node->client;
		destination = node->daemon;

		debug("Processing message from client=%d", source);
	} else {
		source = node->daemon;
		destination = node->client;

		debug("Processing message from daemon=%d", source);
	}

	len = relay_read_all(source, &buf, sizeof (buf));
	if (len < 0) {
		send_error(data, "Can't read from socket");
		exit(1);
	}

	if (len == 0) {
		debug("No message");
		return -1;
	}

	/* Peer closed in the middle of a message: never forward a partial one. */
	if ((size_t) len < sizeof (buf)) {
		debug("Incomplete message (%ld bytes)", (long) len);
		return -1;
	}

	if (direction == CLIENT_TO_DAEMON) {
		if (translate_dev(data, node, &buf.st)) {
			buf.remote = htonl(0);
		} else if (ntohl(buf.remote) == 0) {
			debug("Setting remote ID");
			buf.remote = htonl(relay_id);
		}
	} else {
		restore_dev(node, &buf.st);
	}

	if (direction == CLIENT_TO_DAEMON) {
		debug("Forwarding message to daemon=%d", destination);
	} else {
		debug("Forwarding message to client=%d", destination);
	}

	if (relay_write_all(destination, &buf, sizeof (buf)) < 0) {
		send_error(data, "Can't write to socket");
		exit(1);
	}

	return 0;
}

/*
 * Open a TCP connection to the fakeroot daemon at *addr and return the
 * connected socket.
 *
 * An interrupted connect() makes the process exit with status 0; any other
 * failure is reported via send_error() and ends in exit(1).
 */
static int get_daemon(handler_t *data, struct sockaddr_in *addr)
{
	const int sd = socket(PF_INET, SOCK_STREAM, 0);

	if (sd < 0) {
		send_error(data, "Can't create socket");
		exit(1);
	}

	if (connect(sd, (struct sockaddr *) addr, sizeof (*addr)) == 0) {
		debug("Connected to daemon=%d", sd);
		return sd;
	}

	if (errno == EINTR) {
		debug("Connect interrupted");
		exit(0);
	}

	send_error(data, "Connect failed");
	exit(1);
}

/*
 * Accept a pending client connection on listen_sd and return its socket.
 *
 * A shutdown signal (see sig_exit()) makes the process exit with status 0;
 * any other accept() failure is reported via send_error() and ends in
 * exit(1).
 */
static int get_client(handler_t *data)
{
	struct sockaddr_in addr;
	socklen_t len = sizeof (addr);
	int sd;

	sd = accept(listen_sd, (struct sockaddr *) &addr, &len);
	if (sd < 0) {
		/*
		 * sig_exit() clears parent_alive and closes listen_sd.  If the
		 * signal lands after select() returned, accept() fails on the
		 * closed descriptor (not with EINTR).  That is still an orderly
		 * shutdown, not an error.
		 */
		if (errno == EINTR || !parent_alive) {
			debug("Accept interrupted");
			exit(0);
		}

		send_error(data, "Accept failed");
		exit(1);
	}

	debug("Connection from client=%d", sd);

	return sd;
}

/*
 * Main loop of the relay process.
 *
 * Resolves data->host, then multiplexes the listening socket and every
 * client/daemon socket pair with select().  Each accepted client gets its
 * own connection to the fakeroot daemon at data->host:port.  Messages are
 * copied in both directions until one side closes, at which point the
 * node is removed.
 *
 * Returns once parent_alive has been cleared.  Fatal errors are reported
 * via send_error() and end in exit(1).
 */
static void do_relay(handler_t *data, int port)
{
	struct hostent *host;
	struct sockaddr_in addr = { 0 };
	node_t *sd_head = NULL;
	fd_set fds;

	host = gethostbyname(data->host);
	if (!host) {
		send_error(data, "Can't resolve host: %s", data->host);
		exit(1);
	}

	/*
	 * Check the entry and copy its bytes rather than reading h_addr
	 * through an (unsigned int *).  That cast assumed an IPv4 entry and
	 * an int-aligned buffer, and it violated strict aliasing.
	 */
	if (host->h_addrtype != AF_INET ||
	    host->h_length != (int) sizeof (addr.sin_addr) ||
	    !host->h_addr_list[0]) {
		errno = 0;	/* no system error behind this failure */
		send_error(data, "Not an IPv4 host: %s", data->host);
		exit(1);
	}

	addr.sin_family = AF_INET;
	memcpy(&addr.sin_addr, host->h_addr_list[0], sizeof (addr.sin_addr));
	addr.sin_port = htons(port);

	debug("Fakeroot daemon assumed to be at %s:%d", data->host, port);

	while (parent_alive) {
		int maxfd, count;
		node_t *n, *next;

		/* Watch the listening socket plus both ends of every node. */
		FD_ZERO(&fds);

		FD_SET(listen_sd, &fds);
		maxfd = listen_sd;

		for (n = sd_head, count = 0; n; n = n->next, count++) {
			FD_SET(n->daemon, &fds);
			FD_SET(n->client, &fds);
			maxfd = MAX(maxfd, n->daemon);
			maxfd = MAX(maxfd, n->client);
		}

		debug("Selecting (%d nodes)", count);

		if (select(maxfd + 1, &fds, NULL, NULL, NULL) < 0) {
			if (!parent_alive) {
				debug("SIGEXIT caught during select call");
				break;
			}

			if (errno == EINTR) {
				debug("Select interrupted");
				continue;
			}

			send_error(data, "Select failed");
			exit(1);
		}

		/* next is saved first because remove_node() frees n. */
		for (n = sd_head; n; n = next) {
			next = n->next;

			if (FD_ISSET(n->daemon, &fds) && copy_msg(data, n, DAEMON_TO_CLIENT) < 0) {
				remove_node(&sd_head, n);
				continue;
			}

			if (FD_ISSET(n->client, &fds) && copy_msg(data, n, CLIENT_TO_DAEMON) < 0) {
				remove_node(&sd_head, n);
			}
		}

		if (FD_ISSET(listen_sd, &fds)) {
			int cli, dae;

			cli = get_client(data);
			dae = get_daemon(data, &addr);

			append_node(data, &sd_head, cli, dae);
		}
	}

	debug("Exiting");
}

/*
 * SIGHUP/SIGTERM handler of the relay process: clear parent_alive so
 * do_relay() stops, and close the listening socket.  errno is preserved
 * for the interrupted code.
 *
 * Only the first delivery closes listen_sd.  A second exit signal would
 * otherwise close the same descriptor number again, and by then that
 * number may have been reused for another socket.
 */
static void sig_exit(int sig)
{
	int stored_errno = errno;

	(void) sig;

	if (parent_alive) {
		parent_alive = FALSE;

		if (listen_sd >= 0)
			close(listen_sd);
	}

	errno = stored_errno;
}

/*
 * Create the relay's listening socket and fork the relay process.
 *
 * On entry *portp holds the fakeroot daemon's port.  Once the listening
 * socket exists, it is overwritten with the port the relay listens on.
 * The relay's remote ID is the local IPv4 address of data->sd (host byte
 * order).
 *
 * Returns the child's pid in the parent, or -1 after send_error() on
 * failure.  The child does not return: it runs do_relay() and exits.
 */
static pid_t fork_relay(handler_t *data, uint16_t *portp)
{
	int faked_port, sd;
	struct sockaddr_in addr;
	socklen_t len;
	uint32_t id;
	pid_t pid;
	struct sigaction act_exit;

	faked_port = *portp;

	sd = socket(PF_INET, SOCK_STREAM, 0);
	if (sd < 0) {
		send_error(data, "Can't create socket");
		return -1;
	}

	if (setsockopt_bool(sd, SOL_SOCKET, SO_REUSEADDR, TRUE)) {
		send_error(data, "Can't set socket option: SO_REUSEADDR");
		goto _error;
	}

	/* No bind(): the system chooses the port, getsockname() reports it. */
	if (listen(sd, SOMAXCONN) < 0) {
		send_error(data, "Can't listen to socket");
		goto _error;
	}

	len = sizeof (addr);
	if (getsockname(sd, (struct sockaddr *) &addr, &len) < 0) {
		send_error(data, "Can't get name of relay listening socket");
		goto _error;
	}

	*portp = ntohs(addr.sin_port);

	len = sizeof (addr);
	if (getsockname(data->sd, (struct sockaddr *) &addr, &len) < 0) {
		send_error(data, "Can't get name of client connection socket");
		goto _error;
	}

	id = ntohl(addr.sin_addr.s_addr);

	pid = fork();
	if (pid < 0) {
		send_error(data, "Can't fork");
		goto _error;
	}

	if (pid > 0) {
		/* Parent: */
		close(sd);
		return pid;
	}

	/* Child: */

	set_debug_name("RELAY");
	debug("Relaying fakeroot messages at port %d", *portp);
	/* %lx needs an unsigned long; passing a uint32_t directly is UB. */
	debug("Remote ID of this relay is 0x%lx", (unsigned long) id);

	daemonize(sd);

	act_exit.sa_handler = sig_exit;
	sigemptyset(&act_exit.sa_mask);
	/* Keep SIGHUP and SIGTERM from interrupting each other's handler. */
	sigaddset(&act_exit.sa_mask, SIGHUP);
	sigaddset(&act_exit.sa_mask, SIGTERM);
	act_exit.sa_flags = 0;

	sigaction(SIGHUP, &act_exit, NULL);
	sigaction(SIGTERM, &act_exit, NULL);

	get_device_numbers(data);

	listen_sd = sd;
	relay_id = id;

	do_relay(data, faked_port);
	exit(0);

_error:
	close(sd);
	return -1;
}

/*
 * Start a relay if the handler's environment contains FAKEROOTKEY_ENV.
 *
 * The variable's value is parsed as the fakeroot daemon's TCP port.  On
 * success data->fakerootkey is set to a malloc()ed "FAKEROOTKEY_ENV=<port>"
 * string naming the relay's own port, and the relay's pid is returned.
 *
 * Returns 0 if the variable is unset or empty, and -1 on error (reported
 * via send_error()).  If the allocation fails, the already-forked relay is
 * sent SIGTERM.
 */
pid_t fakeroot_relay(handler_t *data)
{
	char **old_environ, *str, *end;
	long value;
	uint16_t port;
	pid_t pid;

	/* TODO: better. */
	/* Temporarily swap environ so getenv() searches the handler's environment. */
	old_environ = environ;
	environ = data->param.environ;
	str = getenv(FAKEROOTKEY_ENV);
	environ = old_environ;

	if (!str || !*str)
		return 0;

	debug("Creating relay process");

	/*
	 * strtol() instead of atoi(): values outside 1..65535 (including
	 * negative ones) used to be silently truncated into the uint16_t.
	 */
	errno = 0;
	value = strtol(str, &end, 10);
	if (end == str || errno == ERANGE || value <= 0 || value > 65535) {
		send_error(data, "Invalid " FAKEROOTKEY_ENV ": %s", str);
		return -1;
	}
	port = (uint16_t) value;

	pid = fork_relay(data, &port);
	if (pid < 0)
		return -1;

	/* Name, '=', at most five digits for a 16-bit port, terminating NUL. */
	data->fakerootkey = malloc(strlen(FAKEROOTKEY_ENV) + 1 + 5 + 1);
	if (!data->fakerootkey) {
		errno = 0;
		send_error(data, oom);
		goto _kill;
	}

	sprintf(data->fakerootkey, FAKEROOTKEY_ENV "=%d", port);

	return pid;

_kill:
	kill(pid, SIGTERM);
	return -1;
}
