/*
 *  OpenVPN -- An application to securely tunnel IP networks
 *             over a single TCP/UDP port, with support for SSL/TLS-based
 *             session authentication and key exchange,
 *             packet encryption, packet authentication, and
 *             packet compression.
 *
 *  Copyright (C) 2002-2008 Telethra, Inc. <sales@openvpn.net>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License version 2
 *  as published by the Free Software Foundation.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program (see the file COPYING included with this
 *  distribution); if not, write to the Free Software Foundation, Inc.,
 *  59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

/* packet filter functions */

#include "syshead.h"

#if defined(ENABLE_PF)

#include "init.h"

#include "memdbg.h"

#include "pf-inline.h"

/*
 * Release a pf_set and everything it owns: the CN hash table (if one was
 * built), every client-rule element together with its heap-allocated
 * common name, every subnet-rule element, and finally the set itself.
 * A NULL pfs is accepted and ignored.
 */
static void
pf_destroy (struct pf_set *pfs)
{
  struct pf_cn_elem *cn_cursor;
  struct pf_subnet *sn_cursor;

  if (!pfs)
    return;

  if (pfs->cns.hash_table)
    hash_free (pfs->cns.hash_table);

  /* client rules: the cn string was allocated separately by add_client */
  cn_cursor = pfs->cns.list;
  while (cn_cursor != NULL)
    {
      struct pf_cn_elem *victim = cn_cursor;
      cn_cursor = cn_cursor->next;
      free (victim->rule.cn);
      free (victim);
    }

  /* subnet rules own no extra allocations */
  sn_cursor = pfs->sns.list;
  while (sn_cursor != NULL)
    {
      struct pf_subnet *victim = sn_cursor;
      sn_cursor = sn_cursor->next;
      free (victim);
    }

  free (pfs);
}

/*
 * Append a client (common-name) rule to the list whose tail slot is *next.
 * line is the common name (the text after the leading '+'/'-'); it is
 * copied, so the caller's buffer may be reused.  prefix and line_num are
 * not used by this function; the signature mirrors add_subnet.
 * Always returns true.
 */
static bool
add_client (const char *line, const char *prefix, const int line_num, struct pf_cn_elem ***next, const bool exclude)
{
  struct pf_cn_elem *elem;

  ALLOC_OBJ_CLEAR (elem, struct pf_cn_elem);
  elem->rule.cn = string_alloc (line, NULL);
  elem->rule.exclude = exclude;

  /* link into the current tail slot, then make our own next the new tail */
  **next = elem;
  *next = &elem->next;
  return true;
}

/*
 * Parse one subnet rule (the text after the leading '+'/'-') and append it
 * to the list whose tail slot is *next.
 *
 * line is "a.b.c.d", "a.b.c.d/n" with 0 <= n <= 32 (no suffix means /32),
 * or the special tag "unknown", which matches addresses that mroute could
 * not classify (stored as network 0 with an all-ones mask).
 *
 * NOTE: when a "/n" suffix is present the '/' in line is overwritten with
 * '\0', so despite the const qualifier the caller's buffer must be writable
 * (pf_init passes a pointer into its local line array).
 *
 * If the address has bits set outside the netmask (e.g. "10.0.0.1/24") it
 * is masked down to the network address ("10.0.0.0/24").  pf_addr_test_dowork
 * matches with (addr & netmask) == network, so an unmasked rule could never
 * match any address.
 *
 * prefix and line_num only label log messages.  Returns true on success,
 * false (after logging) when the rule is malformed.
 */
static bool
add_subnet (const char *line, const char *prefix, const int line_num, struct pf_subnet ***next, const bool exclude)
{
  struct in_addr network;
  in_addr_t netmask = 0;

  if (strcmp (line, "unknown"))
    {
      int netbits = 32;
      char *div = strchr (line, '/');

      if (div)
	{
	  /* split "addr/bits" in place */
	  *div++ = '\0';
	  if (sscanf (div, "%d", &netbits) != 1)
	    {
	      msg (D_PF_INFO, "PF: %s/%d: bad '/n' subnet specifier: '%s'", prefix, line_num, div);
	      return false;
	    }
	  if (netbits < 0 || netbits > 32)
	    {
	      msg (D_PF_INFO, "PF: %s/%d: bad '/n' subnet specifier: must be between 0 and 32: '%s'", prefix, line_num, div);
	      return false;
	    }
	}

      if (openvpn_inet_aton (line, &network) != OIA_IP)
	{
	  msg (D_PF_INFO, "PF: %s/%d: bad network address: '%s'", prefix, line_num, line);
	  return false;
	}
      netmask = netbits_to_netmask (netbits);

      /* Clear host bits beyond the mask; otherwise the rule is unmatchable. */
      if ((network.s_addr & htonl (netmask)) != network.s_addr)
	{
	  network.s_addr &= htonl (netmask);
	  msg (D_PF_INFO, "PF: %s/%d: incorrect subnet %s/%d changed to %s/%d",
	       prefix, line_num, line, netbits, inet_ntoa (network), netbits);
	}
    }
  else
    {
      /* match special "unknown" tag for addresses unrecognized by mroute */
      network.s_addr = htonl(0);
      netmask = ~0;
    }

  {
    struct pf_subnet *e;
    ALLOC_OBJ_CLEAR (e, struct pf_subnet);
    e->rule.exclude = exclude;
    /* rules are stored in host byte order */
    e->rule.network = ntohl (network.s_addr);
    e->rule.netmask = netmask;
    **next = e;
    *next = &e->next;
    return true;
  }
}

/*
 * Hash callback for the common-name table: hashes the NUL-terminated
 * string key, including its terminating NUL byte, seeded with iv.
 */
static uint32_t
cn_hash_function (const void *key, uint32_t iv)
{
  const char *cn = (const char *) key;
  const size_t n_bytes = strlen (cn) + 1;   /* include the '\0' */

  return hash_func ((uint8_t *) cn, n_bytes, iv);
}

/*
 * Equality callback for the common-name table: true when both keys are
 * identical NUL-terminated strings (case-sensitive).
 */
static bool
cn_compare_function (const void *key1, const void *key2)
{
  const char *lhs = key1;
  const char *rhs = key2;

  return strcmp (lhs, rhs) == 0;
}

/*
 * Build cns->hash_table from cns->list, keyed by common name, with each
 * element's rule as the value.  The table gets n_clients buckets (minimum 16).
 * Every element is tried; each duplicate common name is logged (prefix
 * labels the message).  Returns false if any duplicate was found.
 */
static bool
genhash (struct pf_cn_set *cns, const char *prefix, const int n_clients)
{
  const int n_buckets = (n_clients < 16) ? 16 : n_clients;
  bool all_unique = true;
  struct pf_cn_elem *elem;

  cns->hash_table = hash_init (n_buckets, 0, cn_hash_function, cn_compare_function);

  elem = cns->list;
  while (elem != NULL)
    {
      if (!hash_add (cns->hash_table, elem->rule.cn, &elem->rule, false))
	{
	  msg (D_PF_INFO, "PF: %s: duplicate common name in [clients] section: '%s'", prefix, elem->rule.cn);
	  all_unique = false;
	}
      elem = elem->next;
    }

  return all_unique;
}

/*
 * Parse a packet-filter rule set from the lines of bl.
 *
 * Each line has trailing "\r\n\t " stripped, then:
 *   - empty lines and lines beginning with '#' are skipped;
 *   - "[clients accept]" / "[clients drop]" open a clients section and set
 *     its default policy; "[subnets accept]" / "[subnets drop]" do the same
 *     for subnets (tags are compared case-insensitively);
 *   - "+<item>" adds an accept rule and "-<item>" an exclude rule to the
 *     current section (a common name via add_client, or an address /
 *     "addr/bits" / "unknown" via add_subnet);
 *   - "[end]" finishes parsing;
 *   - "[kill]" (honoured only when allow_kill is true) discards everything
 *     and returns a fresh, empty set with kill = true.
 *
 * prefix labels log messages.  Returns a newly allocated pf_set, or NULL
 * when bl is NULL or any error was counted: a missing "[end]", an unknown
 * tag, a line not starting with '+', '-' or '[', a malformed rule, or a
 * duplicate common name (detected by genhash).
 */
static struct pf_set *
pf_init (const struct buffer_list *bl, const char *prefix, const bool allow_kill)
{
  /* parser state: which section subsequent +/- lines belong to */
# define MODE_UNDEF   0
# define MODE_CLIENTS 1
# define MODE_SUBNETS 2
  int mode = MODE_UNDEF;
  int line_num = 0;
  int n_clients = 0;
  int n_subnets = 0;
  int n_errors = 0;
  struct pf_set *pfs = NULL;
  /* writable working copy of the current line (add_subnet modifies it) */
  char line[PF_MAX_LINE_LEN];

  ALLOC_OBJ_CLEAR (pfs, struct pf_set);
  if (bl)
    {
      /* tail slots, so rules are appended in file order */
      struct pf_cn_elem **cl = &pfs->cns.list;
      struct pf_subnet **sl = &pfs->sns.list;
      struct buffer_entry *be;

      for (be = bl->head; be != NULL; be = be->next)
	{
	  ++line_num;
	  /* bounded copy: lines longer than the buffer are truncated */
	  strncpynt (line, BSTR(&be->buf), sizeof(line));
	  rm_trailing_chars (line, "\r\n\t ");
	  if (line[0] == '\0' || line[0] == '#')
	    ; /* blank line or comment */
	  else if (line[0] == '+' || line[0] == '-')
	    {
	      bool exclude = (line[0] == '-');

	      if (line[1] =='\0')
		{
		  msg (D_PF_INFO, "PF: %s/%d: no data after +/-: '%s'", prefix, line_num, line);
		  ++n_errors;
		}
	      else if (mode == MODE_CLIENTS)
		{
		  if (add_client (&line[1], prefix, line_num, &cl, exclude))
		    ++n_clients;
		  else
		    ++n_errors;
		}
	      else if (mode == MODE_SUBNETS)
		{
		  if (add_subnet (&line[1], prefix, line_num, &sl, exclude))
		    ++n_subnets;
		  else
		    ++n_errors;
		}
	      else if (mode == MODE_UNDEF)
		; /* rule outside any recognized section: silently ignored */
	      else
		{
		  ASSERT (0);
		}
	    }
	  else if (line[0] == '[')
	    {
	      if (!strcasecmp (line, "[clients accept]"))
		{
		  mode = MODE_CLIENTS;
		  pfs->cns.default_allow = true;
		}
	      else if (!strcasecmp (line, "[clients drop]"))
		{
		  mode = MODE_CLIENTS;
		  pfs->cns.default_allow = false;
		}
	      else if (!strcasecmp (line, "[subnets accept]"))
		{
		  mode = MODE_SUBNETS;
		  pfs->sns.default_allow = true;
		}
	      else if (!strcasecmp (line, "[subnets drop]"))
		{
		  mode = MODE_SUBNETS;
		  pfs->sns.default_allow = false;
		}
	      else if (!strcasecmp (line, "[end]"))
		goto done;
	      else if (allow_kill && !strcasecmp (line, "[kill]"))
		goto kill;
	      else
		{
		  /* unknown tag: count an error and ignore rules until the next valid tag */
		  mode = MODE_UNDEF;
		  msg (D_PF_INFO, "PF: %s/%d unknown tag: '%s'", prefix, line_num, line);
		  ++n_errors;
		}
	    }
	  else
	    {
	      msg (D_PF_INFO, "PF: %s/%d line must begin with '+', '-', or '[' : '%s'", prefix, line_num, line);
	      ++n_errors;
	    }
	}
      /* reaching here means the input ran out without an [end] tag */
      ++n_errors;
      msg (D_PF_INFO, "PF: %s: missing [end]", prefix);
    }
  else
    {
      msg (D_PF_INFO, "PF: %s: cannot open", prefix);
      ++n_errors;
    }

 done:
  if (bl)
    {
      /* only build the CN index for an otherwise valid rule set */
      if (!n_errors)
	{
	  if (!genhash (&pfs->cns, prefix, n_clients))
	    ++n_errors;
	}
      if (n_errors)
	msg (D_PF_INFO, "PF: %s rejected due to %d error(s)", prefix, n_errors);
    }
  if (n_errors)
    {
      pf_destroy (pfs);
      pfs = NULL;
    }
  return pfs;
  
 kill:
  /* discard whatever was parsed and return an empty set flagged for kill */
  pf_destroy (pfs);
  ALLOC_OBJ_CLEAR (pfs, struct pf_set);
  pfs->kill = true;
  return pfs;
}

#ifdef PLUGIN_PF
/*
 * Load and parse the packet-filter file fn ([kill] is honoured).
 * Returns the parsed set, or NULL if the file could not be read (logged
 * with errno) or failed to parse.
 */
static struct pf_set *
pf_init_from_file (const char *fn)
{
  struct pf_set *pfs = NULL;
  struct buffer_list *bl = buffer_list_file (fn, PF_MAX_LINE_LEN);

  if (bl == NULL)
    {
      msg (D_PF_INFO|M_ERRNO, "PF: %s: cannot open", fn);
      return NULL;
    }

  pfs = pf_init (bl, fn, true);
  buffer_list_free (bl);
  return pfs;
}
#endif

#ifdef ENABLE_DEBUG

/* Render a policy decision as "ACCEPT" (true) or "DROP" (false). */
static const char *
drop_accept (const bool accept)
{
  if (accept)
    return "ACCEPT";
  return "DROP";
}

/* Name of a pf_cn_test direction: "SRC", "DEST", or "???" for anything else. */
static const char *
pct_name (const int type)
{
  const char *name = "???";

  if (type == PCT_SRC)
    name = "SRC";
  else if (type == PCT_DEST)
    name = "DEST";

  return name;
}

/*
 * Debug trace for a common-name filter decision.  When rule is non-NULL,
 * the matching rule's CN and its own accept/drop verdict are appended.
 */
static void
pf_cn_test_print (const char *prefix,
		  const int type,
		  const char *prefix2,
		  const char *cn,
		  const bool allow,
		  const struct pf_cn *rule)
{
  const char *direction = pct_name (type);
  const char *verdict = drop_accept (allow);

  if (!rule)
    {
      dmsg (D_PF_DEBUG, "PF: %s/%s/%s %s %s",
	    prefix, prefix2, direction, cn, verdict);
      return;
    }

  dmsg (D_PF_DEBUG, "PF: %s/%s/%s %s %s rule=[%s %s]",
	prefix, prefix2, direction, cn, verdict,
	rule->cn, drop_accept (!rule->exclude));
}

static void
pf_addr_test_print (const char *prefix,
		    const char *prefix2,
		    const struct context *src,
		    const struct mroute_addr *dest,
		    const bool allow,
		    const struct ipv4_subnet *rule)
{
  struct gc_arena gc = gc_new ();
  if (rule)
    {
      dmsg (D_PF_DEBUG, "PF: %s/%s %s %s %s rule=[%s/%s %s]",
	   prefix,
	   prefix2,
	   tls_common_name (src->c2.tls_multi, false),
	   mroute_addr_print_ex (dest, MAPF_SHOW_ARP, &gc),
	   drop_accept (allow),
	   print_in_addr_t (rule->network, 0, &gc),
	   print_in_addr_t (rule->netmask, 0, &gc),
	   drop_accept (!rule->exclude));
    }
  else
    {
      dmsg (D_PF_DEBUG, "PF: %s/%s %s %s %s",
	   prefix,
	   prefix2,
	   tls_common_name (src->c2.tls_multi, false),
	   mroute_addr_print_ex (dest, MAPF_SHOW_ARP, &gc),
	   drop_accept (allow));
    }
  gc_free (&gc);
}

#endif

/*
 * Find the client rule for common name cn in table h, where cn_hash is
 * cn's precomputed hash.  Returns the rule, or NULL when cn has no rule.
 */
static inline struct pf_cn *
lookup_cn_rule (struct hash *h, const char *cn, const uint32_t cn_hash)
{
  struct hash_element *const found =
    hash_lookup_fast (h, hash_bucket (h, cn_hash), cn, cn_hash);

  return (found != NULL) ? (struct pf_cn *) found->value : NULL;
}

/*
 * Decide whether the [clients] rules in pfs allow traffic involving the
 * peer whose TLS state is tm, identified by its common name.
 *
 * type (PCT_SRC / PCT_DEST) and prefix are used only for debug output.
 *
 * Returns true when the CN matches an accept rule, or matches no rule and
 * the section default is accept.  Returns false when the CN matches an
 * exclude rule, when the default is drop, and in every fault case:
 *   - no rule set is loaded (pfs == NULL),
 *   - the rule set is a [kill] set,
 *   - tls_common_name_hash() fails for tm (presumably no CN available yet).
 */
bool
pf_cn_test (struct pf_set *pfs, const struct tls_multi *tm, const int type, const char *prefix)
{
  /* pfs is NULL while PF is enabled but no rule set has been loaded yet
     (or every load so far failed), so it must be checked before use.
     pf_addr_test_dowork already performs the same check. */
  if (pfs && !pfs->kill)
    {
      const char *cn;
      uint32_t cn_hash;
      if (tls_common_name_hash (tm, &cn, &cn_hash))
	{
	  const struct pf_cn *rule = lookup_cn_rule (pfs->cns.hash_table, cn, cn_hash);
	  if (rule)
	    {
#ifdef ENABLE_DEBUG
	      if (check_debug_level (D_PF_DEBUG))
		pf_cn_test_print ("PF_CN_MATCH", type, prefix, cn, !rule->exclude, rule);
#endif
	      return !rule->exclude;
	    }
	  else
	    {
#ifdef ENABLE_DEBUG
	      if (check_debug_level (D_PF_DEBUG))
		pf_cn_test_print ("PF_CN_DEFAULT", type, prefix, cn, pfs->cns.default_allow, NULL);
#endif
	      return pfs->cns.default_allow ? true : false;
	    }
	}
    }
#ifdef ENABLE_DEBUG
  if (check_debug_level (D_PF_DEBUG))
    pf_cn_test_print ("PF_CN_FAULT", type, prefix, tls_common_name (tm, false), false, NULL);
#endif
  return false;
}

/*
 * Decide whether the [subnets] rules of src's rule set allow traffic to
 * dest.  Rules are tried in file order and the first rule whose
 * (addr & netmask) == network decides.  With no match, the section default
 * applies.  With no rule set loaded, or a [kill] set, the answer is false.
 * prefix is used only for debug output.
 */
bool
pf_addr_test_dowork (const struct context *src, const struct mroute_addr *dest, const char *prefix)
{
  struct pf_set *pfs = src->c2.pf.pfs;
  const struct pf_subnet *se;
  in_addr_t addr;

  if (!pfs || pfs->kill)
    {
#ifdef ENABLE_DEBUG
      if (check_debug_level (D_PF_DEBUG))
	pf_addr_test_print ("PF_ADDR_FAULT", prefix, src, dest, false, NULL);
#endif
      return false;
    }

  addr = in_addr_t_from_mroute_addr (dest);
  for (se = pfs->sns.list; se != NULL; se = se->next)
    {
      if ((addr & se->rule.netmask) != se->rule.network)
	continue;
#ifdef ENABLE_DEBUG
      if (check_debug_level (D_PF_DEBUG))
	pf_addr_test_print ("PF_ADDR_MATCH", prefix, src, dest, !se->rule.exclude, &se->rule);
#endif
      return !se->rule.exclude;
    }

#ifdef ENABLE_DEBUG
  if (check_debug_level (D_PF_DEBUG))
    pf_addr_test_print ("PF_ADDR_DEFAULT", prefix, src, dest, pfs->sns.default_allow, NULL);
#endif
  return pfs->sns.default_allow;
}

#ifdef PLUGIN_PF
void
pf_check_reload (struct context *c)
{
  const int slow_wakeup = 15;
  const int fast_wakeup = 1;
  const int wakeup_transition = 60;
  bool reloaded = false;

  if (c->c2.pf.enabled
      && c->c2.pf.filename
      && event_timeout_trigger (&c->c2.pf.reload, &c->c2.timeval, ETT_DEFAULT))
    {
      struct stat s;
      if (!stat (c->c2.pf.filename, &s))
	{
	  if (s.st_mtime > c->c2.pf.file_last_mod)
	    {
	      struct pf_set *pfs = pf_init_from_file (c->c2.pf.filename);
	      if (pfs)
		{
		  if (c->c2.pf.pfs)
		    pf_destroy (c->c2.pf.pfs);
		  c->c2.pf.pfs = pfs;
		  reloaded = true;
		  if (pf_kill_test (pfs))
		    {
		      c->sig->signal_received = SIGTERM;
		      c->sig->signal_text = "pf-kill";
		    }
		}
	      c->c2.pf.file_last_mod = s.st_mtime;
	    }
	}
      {
	int wakeup = slow_wakeup;
	if (!c->c2.pf.pfs && c->c2.pf.n_check_reload < wakeup_transition)
	  wakeup = fast_wakeup;
	event_timeout_init (&c->c2.pf.reload, wakeup, now);
	reset_coarse_timers (c);
	c->c2.pf.n_check_reload++;
      }
    }
#ifdef ENABLE_DEBUG
  if (reloaded && check_debug_level (D_PF_DEBUG))
    pf_context_print (&c->c2.pf, "pf_check_reload", D_PF_DEBUG);
#endif
}
#endif

#ifdef MANAGEMENT_PF
/*
 * Install a rule set received over the management interface ([kill] is
 * not honoured).  On a successful parse, any previous set is destroyed
 * and replaced, and true is returned.  On failure the current set is left
 * untouched and false is returned.
 */
bool
pf_load_from_buffer_list (struct context *c, const struct buffer_list *config)
{
  struct pf_set *fresh = pf_init (config, "[SERVER-PF]", false);

  if (fresh == NULL)
    return false;

  if (c->c2.pf.pfs)
    pf_destroy (c->c2.pf.pfs);
  c->c2.pf.pfs = fresh;
  return true;
}
#endif

/*
 * Enable packet filtering for context c, if a PF source is available.
 *
 * PLUGIN_PF: if a plugin registered OPENVPN_PLUGIN_ENABLE_PF, a temp file
 * name is exported to the plugin as env var "pf_file".  If the plugin call
 * succeeds, PF is enabled and that file is polled by pf_check_reload.
 * Otherwise a warning is logged.
 *
 * MANAGEMENT_PF: if PF was not enabled by a plugin and the management
 * interface has PF enabled, PF is enabled without a file.  Rules then
 * arrive via pf_load_from_buffer_list.
 *
 * In both cases c->c2.pf.pfs is not set here, so it stays NULL until a
 * rule set is loaded.
 */
void
pf_init_context (struct context *c)
{
  /* arena for the temporary file name; freed on return */
  struct gc_arena gc = gc_new ();
#ifdef PLUGIN_PF
  if (plugin_defined (c->plugins, OPENVPN_PLUGIN_ENABLE_PF))
    {
      const char *pf_file = create_temp_filename (c->options.tmp_dir, "pf", &gc);
      /* NOTE(review): presumably removes any stale file of that name so an
	 old rule set is not picked up -- confirm */
      delete_file (pf_file);
      setenv_str (c->c2.es, "pf_file", pf_file);

      if (plugin_call (c->plugins, OPENVPN_PLUGIN_ENABLE_PF, NULL, NULL, c->c2.es) == OPENVPN_PLUGIN_FUNC_SUCCESS)
	{
	  /* arm the reload timer with a 1-second interval */
	  event_timeout_init (&c->c2.pf.reload, 1, now);
	  /* persistent copy: pf_file lives in gc, which is freed below */
	  c->c2.pf.filename = string_alloc (pf_file, NULL);
	  c->c2.pf.enabled = true;
#ifdef ENABLE_DEBUG
	  if (check_debug_level (D_PF_DEBUG))
	    pf_context_print (&c->c2.pf, "pf_init_context#1", D_PF_DEBUG);
#endif
	}
      else
	{
	  msg (M_WARN, "WARNING: OPENVPN_PLUGIN_ENABLE_PF disabled");
	}
    }
#endif
#ifdef MANAGEMENT_PF
  /* management-driven PF is used only if the plugin did not enable PF */
  if (!c->c2.pf.enabled && management_enable_pf (management))
    {
      c->c2.pf.enabled = true;
#ifdef ENABLE_DEBUG
      if (check_debug_level (D_PF_DEBUG))
	pf_context_print (&c->c2.pf, "pf_init_context#2", D_PF_DEBUG);
#endif
    }
#endif
  gc_free (&gc);
}

/*
 * Release the resources held by a PF context.  With PLUGIN_PF, the
 * plugin's rule file is deleted and its name freed.  Any loaded rule set
 * is destroyed.
 *
 * The freed pointers are reset to NULL because pfc itself stays alive.
 * This prevents a second call, or later code inspecting pfc, from touching
 * freed memory (double free / use-after-free).
 */
void
pf_destroy_context (struct pf_context *pfc)
{
#ifdef PLUGIN_PF
  if (pfc->filename)
    {
      delete_file (pfc->filename);
      free (pfc->filename);
      pfc->filename = NULL;
    }
#endif
  if (pfc->pfs)
    {
      pf_destroy (pfc->pfs);
      pfc->pfs = NULL;
    }
}

#ifdef ENABLE_DEBUG

static void
pf_subnet_set_print (const struct pf_subnet_set *s, const int lev)
{
  struct gc_arena gc = gc_new ();
  if (s)
    {
      struct pf_subnet *e;

      msg (lev, "  ----- struct pf_subnet_set -----");
      msg (lev, "  default_allow=%s", drop_accept (s->default_allow));

      for (e = s->list; e != NULL; e = e->next)
	{
	  msg (lev, "   %s/%s %s",
	       print_in_addr_t (e->rule.network, 0, &gc),
	       print_in_addr_t (e->rule.netmask, 0, &gc),
	       drop_accept (!e->rule.exclude));
	}
    }
  gc_free (&gc);
}

static void
pf_cn_set_print (const struct pf_cn_set *s, const int lev)
{
  if (s)
    {
      struct hash_iterator hi;
      struct hash_element *he;

      msg (lev, "  ----- struct pf_cn_set -----");
      msg (lev, "  default_allow=%s", drop_accept (s->default_allow));

      if (s->hash_table)
	{
	  hash_iterator_init (s->hash_table, &hi, false);
	  while ((he = hash_iterator_next (&hi)))
	    {
	      struct pf_cn *e = (struct pf_cn *)he->value;
	      msg (lev, "   %s %s",
		   e->cn,
		   drop_accept (!e->exclude));
	    }

	  msg (lev, "  ----------");

	  {
	    struct pf_cn_elem *ce;
	    for (ce = s->list; ce != NULL; ce = ce->next)
	      {
		struct pf_cn *e = lookup_cn_rule (s->hash_table, ce->rule.cn, cn_hash_function (ce->rule.cn, 0));
		if (e)
		  {
		    msg (lev, "   %s %s",
			 e->cn,
			 drop_accept (!e->exclude));
		  }
		else
		  {
		    msg (lev, "   %s LOOKUP FAILED", ce->rule.cn);
		  }
	      }
	  }
	}
    }
}

/* Dump a complete rule set (kill flag, subnets, clients) at log level lev; NULL is ignored. */
static void
pf_set_print (const struct pf_set *pfs, const int lev)
{
  if (pfs == NULL)
    return;

  msg (lev, " ----- struct pf_set -----");
  msg (lev, " kill=%d", pfs->kill);
  pf_subnet_set_print (&pfs->sns, lev);
  pf_cn_set_print (&pfs->cns, lev);
}

/*
 * Dump a PF context at log level lev, framed by a header naming prefix
 * and a dashed footer.  It prints the enabled flag and, with PLUGIN_PF,
 * the file and reload-timer state, followed by the loaded rule set.
 * For a NULL context only the header and footer are printed.
 */
void
pf_context_print (const struct pf_context *pfc, const char *prefix, const int lev)
{
  msg (lev, "----- %s : struct pf_context -----", prefix);

  if (pfc == NULL)
    {
      msg (lev, "--------------------");
      return;
    }

  msg (lev, "enabled=%d", pfc->enabled);
#ifdef PLUGIN_PF
  msg (lev, "filename='%s'", np(pfc->filename));
  msg (lev, "file_last_mod=%u", (unsigned int)pfc->file_last_mod);
  msg (lev, "n_check_reload=%u", pfc->n_check_reload);
  msg (lev, "reload=[%d,%u,%u]", pfc->reload.defined, pfc->reload.n, (unsigned int)pfc->reload.last);
#endif
  pf_set_print (pfc->pfs, lev);

  msg (lev, "--------------------");
}

#endif

#endif