
/*
 * Copyright (C) 1996, Jonathan Layes <layes@loran.com>
 *
 * See the file "COPYING" for copyright information
 *
 *	Rewrote by A.Kuznetsov
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <unistd.h>
#include <syslog.h>
#include <fcntl.h>
#include <linux/netdevice.h>
#include <net/if_arp.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

/* Request handlers and the shared (dev, ip) table lookup they build on. */
static void arpd_update(struct arpd_request *req);
static void arpd_flush(struct arpd_request *req);
static struct arpd_request *arpd_lookup(struct arpd_request *req);
static struct arpd_request *arpd_find(unsigned long dev, __u32 ip,
				      struct arpd_request *newent);
#ifdef DEBUG
static void arpd_print(char *type, struct arpd_request *req);
#else
/* Without DEBUG, tracing calls expand to a no-op and arguments are not evaluated. */
#define arpd_print(type, req)	((void)0)
#endif

/*
 * Each device owns a table of HASH_PRIME buckets.  Despite its name,
 * HASH_PRIME is a power of two (1 << HASH_LOG), which is what lets a
 * hash key be reduced to a bucket index with a mask of HASH_PRIME-1.
 */
#define HASH_LOG	10
#define HASH_PRIME	(1<<HASH_LOG)

/* One cached request, chained into a hash bucket through 'next'. */
struct arpd_entry
{
	struct arpd_entry  *next;
	struct arpd_request body;
};

/* Per-device hash table; all devices are kept on a singly linked list. */
struct dev_list
{
	struct dev_list   *next;
	unsigned long      dev;
	struct arpd_entry *hash[HASH_PRIME];
};

/*
 * Head of the device list (a list, despite the name).  It is private to
 * this file, so it gets internal linkage.
 */
static struct dev_list *tree;

#ifndef DEBUG
/*
 * Turn the process into a background daemon.  The parent exits after the
 * fork, every inherited descriptor is closed, stdin/stdout/stderr are
 * pointed at /dev/null, and the child starts a new session.
 */
static void arpd_daemonize(void)
{
	pid_t pid;
	long maxfd, i;
	int fd;

	pid = fork();
	if (pid < 0) {
		/* Report the failure; the old code exited 0 with no daemon running. */
		perror("arpd: fork");
		exit(-1);
	}
	if (pid > 0)
		exit(0);

	/* OPEN_MAX is not defined by current C libraries; ask the system. */
	maxfd = sysconf(_SC_OPEN_MAX);
	if (maxfd < 0)
		maxfd = 256;
	for (i = 0; i < maxfd; i++)
		close((int)i);

	/*
	 * With everything closed this open normally returns 0.  It must still
	 * be dup'ed onto 1 and 2, so test for failure (< 0), not for fd 0.
	 */
	fd = open("/dev/null", O_RDWR);
	if (fd >= 0) {
		if (fd != 0)
			dup2(fd, 0);
		if (fd != 1)
			dup2(fd, 1);
		if (fd != 2)
			dup2(fd, 2);
		if (fd > 2)
			close(fd);
	}
	setsid();
}
#endif

/*
 * arpd entry point.  Unless built with DEBUG, detaches into the background.
 * It then opens /dev/arpd and serves fixed-size struct arpd_request messages
 * forever:
 *   ARPD_UPDATE - store/overwrite the entry in the cache,
 *   ARPD_LOOKUP - write back the cached entry, or the request itself with
 *                 'updated' cleared when nothing is cached,
 *   ARPD_FLUSH  - drop every cached entry of the request's device.
 * Any I/O error or short message is logged to syslog and is fatal.
 */
int main(int argc, char **argv)
{
	ssize_t status;
	int fd;
	struct arpd_request req;
	struct arpd_request *rep;

	(void)argc;
	(void)argv;

#ifndef DEBUG
	arpd_daemonize();
#endif

	openlog ("arpd", LOG_PID | LOG_CONS, LOG_DAEMON);

	fd = open("/dev/arpd", O_RDWR);
	if (fd < 0) {
		syslog(LOG_CRIT, "cannot open /dev/arpd: %m\n");
		exit(-1);
	}

	for (;;) {
		status = read(fd, &req, sizeof(req));
		if (status < 0) {
			if (errno == EINTR)
				continue;
			syslog(LOG_CRIT, "cannot read /dev/arpd: %m\n");
			exit(-1);
		}
		/* status is non-negative here, so the unsigned comparison is safe. */
		if ((size_t)status != sizeof(req)) {
			syslog(LOG_CRIT, "bad message length %d\n", (int)status);
			exit(-1);
		}
		switch (req.req)
		{
		case ARPD_UPDATE:
			arpd_print("UPDATE", &req);
			arpd_update (&req);
			break;
		case ARPD_LOOKUP:
			arpd_print("LOOKUP", &req);
			if ((rep = arpd_lookup (&req)) == NULL) {
				/* Cache miss: echo the request back, marked not updated. */
				req.updated = 0;
				rep = &req;
			}
			arpd_print("REPLY", rep);
			/* Retry on signal interruption, as the read path does. */
			do {
				status = write(fd, rep, sizeof(*rep));
			} while (status < 0 && errno == EINTR);
			if (status < 0) {
				syslog(LOG_CRIT, "cannot write /dev/arpd: %m");
				exit(-1);
			}
			if ((size_t)status != sizeof(*rep)) {
				syslog(LOG_CRIT, "write /dev/arpd returns %d\n", (int)status);
				exit(-1);
			}
			break;
		case ARPD_FLUSH:
			arpd_print("FLUSH", &req);
			arpd_flush (&req);
			break;
		default:
			/* Previously dropped silently; leave a trace but keep serving. */
			syslog(LOG_WARNING, "unknown request type %d\n", (int)req.req);
			break;
		}
	}

	return 0;
}

static void arpd_update (struct arpd_request * entry)
{
	arpd_find (entry->dev, entry->ip, entry);
}

/*
 * ARPD_LOOKUP handler: return the cached request whose (dev, ip) key
 * matches *entry, or NULL on a miss.  A pure lookup; nothing is allocated.
 */
static struct arpd_request* arpd_lookup (struct arpd_request * entry)
{
	struct arpd_request *cached;

	cached = arpd_find (entry->dev, entry->ip, NULL);
	return cached;
}

/*
 * Map an IPv4 address to a bucket index in [0, HASH_PRIME).  The index is
 * the XOR of the address's successive HASH_LOG-bit slices (shifts of 0, 1,
 * 2 and 3 times HASH_LOG).  The largest shift must stay below 32.
 */
static unsigned long arpd_hash (__u32 ip)
{
	unsigned long key = ip;

	key ^= ip >> HASH_LOG;
	key ^= ip >> (2 * HASH_LOG);
	key ^= ip >> (3 * HASH_LOG);
	return key & (HASH_PRIME - 1);
}

/*
 * Find the cached request for (dev, ip).
 *
 * newent == NULL: pure lookup.  Returns the cached entry or NULL and never
 *                 allocates.
 * newent != NULL: insert-or-overwrite.  Creates the device's table and/or
 *                 the entry as needed, then copies *newent into it.  NULL is
 *                 returned only when an allocation fails.
 *
 * The returned pointer refers to storage owned by the table; arpd_flush()
 * frees it together with the rest of the device's entries.
 */
static struct arpd_request * arpd_find (unsigned long dev, __u32 ip, struct arpd_request * newent)
{
	struct dev_list *dl;
	struct arpd_entry *ent, **entp;

	for (dl = tree; dl; dl = dl->next) {
		if (dl->dev == dev)
			break;
	}
	if (!dl) {
		/* Unknown device: a lookup misses, an update creates its table. */
		if (!newent)
			return NULL;
		/* calloc zero-fills, so every bucket head starts out NULL. */
		dl = calloc(1, sizeof *dl);
		if (!dl)
			return NULL;
		dl->dev  = dev;
		dl->next = tree;
		tree = dl;
	}

	entp = &dl->hash[arpd_hash(ip)];

	for (ent = *entp; ent; ent = ent->next) {
		if (ent->body.ip == ip) {
			if (newent)
				ent->body = *newent;
			return &ent->body;
		}
	}
	if (!newent)
		return NULL;

	/* Not cached yet: push a new entry onto the bucket's chain. */
	ent = malloc(sizeof *ent);
	if (!ent)
		return NULL;
	ent->body = *newent;
	ent->next = *entp;
	*entp = ent;
	return &ent->body;
}

/*
 * ARPD_FLUSH handler: unlink the table belonging to entry->dev from the
 * device list and free it together with every cached entry in it.  If the
 * device has no table, nothing happens.
 */
static void arpd_flush (struct arpd_request * entry)
{
	struct dev_list **link = &tree;
	struct dev_list *victim;
	int bucket;

	/* Walk the links so the matching node can be spliced out in place. */
	while (*link != NULL && (*link)->dev != entry->dev)
		link = &(*link)->next;

	victim = *link;
	if (victim == NULL)
		return;
	*link = victim->next;

	for (bucket = 0; bucket < HASH_PRIME; bucket++) {
		struct arpd_entry *cur = victim->hash[bucket];

		while (cur != NULL) {
			/* Read the successor before the node is released. */
			struct arpd_entry *following = cur->next;
			free(cur);
			cur = following;
		}
	}
	free(victim);
}

#ifdef DEBUG
/*
 * Dump a request to stderr (DEBUG builds only), labelled with 'type'.
 * Each integer argument is cast to the type its conversion specifier
 * expects.  Passing an unsigned long 'dev' to %x was undefined behaviour
 * on LP64 targets.
 */
static void arpd_print(char * type, struct arpd_request* req)
{
	/* NOTE(review): assumes req->ha holds at least 6 bytes — confirm. */
	unsigned char *ha=req->ha;

	fprintf(stderr, "Type: %s\n", type);
	fprintf(stderr, "Dst: %08x\n", (unsigned int)ntohl(req->ip));
	fprintf(stderr, "Dev: %08lx Stamp: %08lx Updated: %d\n",
		(unsigned long)req->dev, (unsigned long)req->stamp,
		(int)req->updated);
	fprintf(stderr, "HA: %02x:%02x:%02x:%02x:%02x:%02x\n\n",
		ha[0], ha[1], ha[2], ha[3], ha[4], ha[5]);
}
#endif
