#!/usr/bin/python3
#  Author: Jamie Strandboge <jamie@ubuntu.com>
#  Copyright (C) 2015 Canonical Ltd.
#
#  This script is distributed under the terms and conditions of the GNU General
#  Public License, Version 3 or later. See http://www.gnu.org/copyleft/gpl.html
#  for details.

import codecs
import optparse
import os
import re
import shutil
import sys
import tempfile
import yaml
from yaml.loader import SafeLoader


# Force 15.1 to be 15.10. This is not ideal, but otherwise we lose data when
# the yaml has 'policy_version: 15.10'. Note that this applies to all floats,
# but only one key (policy_version) is allowed to be a float, so it is not an
# issue.
#
# The pattern matches exactly two digits, a dot and a single digit (eg '15.1').
float_pat = re.compile(r'^[0-9][0-9]\.[0-9]$')


def float_constructor(loader, node):
    '''Construct yaml floats as strings so that no digits are lost

    The scalar text is returned as-is, except that text matching float_pat
    (eg '15.1') gets a '0' appended (eg '15.10').
    '''
    text = loader.construct_scalar(node)
    return text + '0' if float_pat.search(text) else text


# Route every yaml float scalar loaded through SafeLoader to
# float_constructor() instead of the default float conversion
yaml.add_constructor(u'tag:yaml.org,2002:float', float_constructor,
                     Loader=SafeLoader)


# When True, debug() writes its messages to stderr. main() turns this on for
# -d/--debug.
DEBUGGING = False


#
# Configuration
#
# Directory searched first for templates and policy groups. main() overrides
# it with --policy-dir when that option names an existing directory.
sc_system_policy_dir = "/usr/share/seccomp"
# Directory searched when a template or policy group is not found under
# sc_system_policy_dir. main() overrides it with --include-policy-dir when
# that option names an existing directory.
sc_include_policy_dir = "/var/lib/snappy/seccomp"

#
# Helpers
#


def debug(out):
    '''Write a debug message to stderr when debugging is enabled

    Nothing is written unless DEBUGGING is set. An IOError raised while
    writing is ignored.
    '''
    if not DEBUGGING:
        return

    try:
        sys.stderr.write("DEBUG: %s\n" % out)
    except IOError:
        pass


def valid_path(path, relative_ok=False):
    '''Check that path is safe to use

    A path is rejected (with a debug message) when it is relative and
    relative_ok is False, contains a double quote, contains '../', cannot be
    normalized, or differs from its normalized form. Returns True when the
    path is accepted and False otherwise.
    '''
    m = "Invalid path: %s" % (path)

    # Find the first reason to reject the path, if there is one
    problem = None
    if not (relative_ok or path.startswith('/')):
        problem = "%s (relative)" % (m)
    elif '"' in path:  # We double quote elsewhere
        problem = "%s (quote)" % (m)
    elif '../' in path:
        problem = "%s (../ path escape)" % (m)
    else:
        try:
            normalized = os.path.normpath(path)
        except Exception:
            problem = "%s (could not normalize)" % (m)
        else:
            if normalized != path:
                problem = "%s (normalized path != path (%s != %s))" % (
                    m, normalized, path)

    if problem is not None:
        debug(problem)
        return False

    # Nothing objected, so the path is safe
    return True


def _is_safe(s):
    '''Known safe regex

    Returns True only when the whole of s consists of one or more ASCII
    letters, digits, '_', '-' or '.'; returns False otherwise.
    '''
    # fullmatch() rather than search() with '$': '$' also matches just before
    # a trailing newline, which would let 'name\n' through
    return re.fullmatch(r'[a-zA-Z_0-9\-\.]+', s) is not None


def valid_policy_vendor(s):
    '''Verify the policy vendor

    Returns True when s passes the _is_safe() check, False otherwise.
    '''
    return _is_safe(s)


def valid_policy_version(v):
    '''Verify the policy version

    v may be a number or a string holding a number. Returns True when v
    converts to a finite float that is not negative. Returns False otherwise,
    including for 'nan', 'inf' and values float() cannot convert at all (eg
    None).
    '''
    try:
        version = float(v)
    except (TypeError, ValueError, OverflowError):
        return False

    # Every comparison with NaN is False, and the upper bound excludes
    # infinity, so only finite non-negative versions pass
    return 0 <= version < float('inf')


def valid_template_name(s):
    '''Verify the template name

    Returns True when s passes the _is_safe() check, False otherwise.
    '''
    return _is_safe(s)


def valid_policy_group_name(s):
    '''Verify policy group name

    Returns True when s passes the _is_safe() check, False otherwise.
    '''
    return _is_safe(s)


def open_file_read(path):
    '''Open path read-only as UTF-8 and return the file object

    Any exception raised while opening the file propagates to the caller.
    '''
    return codecs.open(path, 'r', "UTF-8")


def autodetect_vendor_version():
    '''Try to autodetect the vendor and version

    The vendor is taken from the 'channel:' line of
    /etc/system-image/channel.ini when that file exists, otherwise from
    DISTRIB_ID in /etc/lsb-release. The version is always taken from
    DISTRIB_RELEASE in /etc/lsb-release.

    Returns a (vendor, version) tuple of strings. Returns (None, None) when
    either value could not be determined or when there is no templates
    directory for that vendor and version under sc_system_policy_dir.
    '''
    vendor = None
    version = None

    channel = "/etc/system-image/channel.ini"
    lsb = "/etc/lsb-release"

    # Read lsb-release once (it may be needed for both vendor and version)
    # and close the handle right away
    lsb_lines = []
    if os.path.isfile(lsb):
        with open_file_read(lsb) as f:
            lsb_lines = f.readlines()

    if os.path.isfile(channel):
        with open_file_read(channel) as f:
            channel_lines = f.readlines()
        for line in channel_lines:
            if re.search(r'^channel: [a-zA-Z]+', line):
                # eg 'channel: vendor/...' -> 'vendor'
                vendor = line.strip().split()[1].split('/')[0].lower()
    else:
        for line in lsb_lines:
            if re.search(r'DISTRIB_ID=[a-zA-Z]+', line):
                vendor = line.strip().split('=')[1].lower()

    for line in lsb_lines:
        if re.search(r'DISTRIB_RELEASE=[0-9]+(\.[0-9]+)?$', line):
            version = line.strip().split('=')[1].lower()
            try:
                float(version)
            except ValueError:
                version = None

    if vendor is None or version is None:
        return (None, None)

    template_dir = os.path.join("%s/templates" % sc_system_policy_dir,
                                vendor, version)
    if not os.path.isdir(template_dir):
        # Report the directory that was actually checked
        debug("Could not find policy in %s" % template_dir)
        return (None, None)

    debug("Autodetected vendor=%s, version=%s" % (vendor, version))
    return (vendor, version)


def valid_syscall(s):
    '''Verify the syscall

    Returns True when s is exactly '@unrestricted' or consists entirely of
    one or more lowercase ASCII letters, digits or '_'; returns False
    otherwise.
    '''
    if s == "@unrestricted":
        return True
    # fullmatch() rather than search() with '$': '$' also matches just before
    # a trailing newline, which would let 'read\n' through
    return re.fullmatch(r'[a-z_0-9]+', s) is not None


def valid_deny(s):
    '''Verify denied syscall

    s is a policy line expected to look like 'deny <syscall>'. Returns the
    syscall name when the first word is 'deny' and the second word passes
    valid_syscall(). Returns None otherwise, including for empty or
    whitespace-only input.
    '''
    tmp = s.split()
    # Check the length first so that empty input cannot raise IndexError
    if len(tmp) > 1 and tmp[0] == 'deny' and valid_syscall(tmp[1]):
        return tmp[1]
    return None


def read_manifest(m):
    '''Read and verify a yaml manifest

    m is the path to the manifest. Returns a dict of options in the form
    SeccompFilterGen.parse_args() consumes:
      - 'template' (from 'template' or 'security-template'; 'default' when
        neither is given)
      - 'policy_groups' (from 'policy_groups' or 'caps', joined with ',')
      - 'syscalls' (joined with ',')
      - 'policy_vendor' and 'policy_version' (copied as-is)

    Raises SeccompFilterGenException when the file cannot be read or parsed,
    is not a yaml mapping, has unknown or mutually exclusive keys, or has a
    value of the wrong type.
    '''
    list_keys = ['caps', 'policy_groups', 'syscalls']
    string_keys = ['policy_vendor', 'security-template', 'template']
    known = list_keys + string_keys + ['policy_version']

    try:
        # Close the manifest as soon as its contents have been read
        with open_file_read(m) as f:
            y = yaml.safe_load(f.read())
    except Exception as e:
        raise SeccompFilterGenException("Could not parse yaml: %s" % str(e))

    if not isinstance(y, dict):
        raise SeccompFilterGenException("Improper yaml format")

    if 'security-template' in y and 'template' in y:
        raise SeccompFilterGenException("Must use either 'security-template' "
                                        "or 'template', not both")
    if 'caps' in y and 'policy_groups' in y:
        raise SeccompFilterGenException("Must use either 'caps' or "
                                        "'policy_groups', not both")

    opt = dict()
    for key in y:
        value = y[key]

        # Type checks
        if key not in known:
            raise SeccompFilterGenException("Found unknown key '%s'" % key)
        elif key in list_keys:
            if not isinstance(value, list):
                raise SeccompFilterGenException("'%s' is not a list" % key)
            # The items are joined with ',' below, so each must be a string
            if not all(isinstance(item, str) for item in value):
                raise SeccompFilterGenException("'%s' must only contain "
                                                "strings" % key)
        elif key in string_keys:
            if not isinstance(value, str):
                raise SeccompFilterGenException("'%s' is not a string" % key)
        elif key == 'policy_version':
            try:
                float(value)
            except Exception:
                raise SeccompFilterGenException("'%s' is not a number" % key)

        # Map the manifest keys onto the option names
        if key in ['security-template', 'template']:
            opt['template'] = value
        elif key in ['caps', 'policy_groups']:
            opt['policy_groups'] = ",".join(value)
        elif key == 'syscalls':
            opt[key] = ",".join(value)
        else:
            opt[key] = value

    if 'template' not in opt:
        opt['template'] = 'default'

    return opt


class SeccompFilterGen:
    '''Generate a seccomp filter from a template, policy groups and syscalls

    The inputs come either from a yaml manifest (when opt.manifest is set) or
    directly from the parsed command line options. After construction the
    instance carries the validated template, policy_groups, policy_vendor,
    policy_version and syscalls attributes.
    '''
    def __init__(self, opt):
        '''Collect the arguments from opt and verify them

        opt is the options object returned by the option parser. Raises
        SeccompFilterGenException when the manifest does not exist or when
        any argument is invalid.
        '''
        if opt.manifest:
            if not os.path.exists(opt.manifest):
                raise SeccompFilterGenException("Could not find '%s'" %
                                                opt.manifest)
            args = read_manifest(opt.manifest)
        else:
            args = vars(opt)

        self.parse_args(args)

    @staticmethod
    def _unique_sorted(value, is_valid, what):
        '''Split a comma-separated string into a sorted list of unique items

        Every item must pass is_valid(); otherwise a
        SeccompFilterGenException naming the item and what it is (eg
        'syscall') is raised.
        '''
        items = []
        for i in value.split(','):
            if not is_valid(i):
                raise SeccompFilterGenException("Invalid %s '%s'" % (what, i))
            if i not in items:
                items.append(i)
        items.sort()
        return items

    def parse_args(self, args):
        '''Verify arguments and store them on the instance

        args is a dict that may contain 'template', 'policy_groups',
        'policy_vendor', 'policy_version' and 'syscalls'. The vendor and
        version must be given together; when neither is given they are
        autodetected. Raises SeccompFilterGenException on invalid values.
        '''
        # Use the same default as the command line and the manifest reader so
        # that a missing 'template' does not raise KeyError
        template = args.get('template')
        if template is None:
            template = 'default'
        if not valid_template_name(template):
            raise SeccompFilterGenException("Invalid template '%s'" %
                                            template)
        self.template = template

        self.policy_groups = []
        if args.get('policy_groups') is not None:
            self.policy_groups = self._unique_sorted(args['policy_groups'],
                                                     valid_policy_group_name,
                                                     "policy group")

        self.policy_vendor = None
        vendor = args.get('policy_vendor')
        if vendor is not None:
            if not valid_policy_vendor(vendor):
                raise SeccompFilterGenException("Invalid policy vendor '%s'" %
                                                vendor)
            self.policy_vendor = vendor

        self.policy_version = None
        version = args.get('policy_version')
        if version is not None:
            if not valid_policy_version(version):
                raise SeccompFilterGenException("Invalid policy version '%s'" %
                                                version)
            self.policy_version = version

        # Exactly one of the two being set is an error
        if (self.policy_vendor is None) != (self.policy_version is None):
            raise SeccompFilterGenException("Must specify both policy version "
                                            "and vendor")

        if self.policy_vendor is None:
            (self.policy_vendor, self.policy_version) = \
                autodetect_vendor_version()

        self.syscalls = []
        if args.get('syscalls') is not None:
            self.syscalls = self._unique_sorted(args['syscalls'],
                                                valid_syscall, "syscall")

    def parse_policy(self, policy):
        '''Parse policy

        policy is the combined policy text. Comments, blank lines and valid
        syscalls are kept as they are. A 'deny <syscall>' line marks that
        syscall as denied, and every line naming a denied syscall is output
        as a '# EXPLICITLY DENIED: <syscall>' comment instead. Returns the
        resulting text, one line per input line. Raises
        SeccompFilterGenException for a line that cannot be parsed.
        '''
        lines = []
        denied = set()
        for line in policy.splitlines():
            if line.startswith('#') or re.search(r'^\s*$', line) or \
               valid_syscall(line):
                lines.append(line)
                continue

            d = valid_deny(line)
            if d is None:
                raise SeccompFilterGenException("Could not parse: '%s'" %
                                                line)
            denied.add(d)
            lines.append(d)

        # Done as a second pass so that a deny also overrides a syscall that
        # was listed before the deny line
        out = []
        for line in lines:
            if line in denied:
                out.append("# EXPLICITLY DENIED: %s\n" % line)
            else:
                out.append("%s\n" % line)

        return "".join(out)

    def get_template(self):
        '''Get security template contents

        The template is looked up under sc_system_policy_dir (in the vendor
        and version subdirectory when both are known) and then under
        sc_include_policy_dir. Raises SeccompFilterGenException when it is
        found in neither.
        '''
        if self.policy_vendor and self.policy_version:
            fn = os.path.join(sc_system_policy_dir,
                              "templates/%s/%s" % (self.policy_vendor,
                                                   self.policy_version),
                              self.template)
        else:
            fn = os.path.join(sc_system_policy_dir, "templates", self.template)

        if not os.path.exists(fn):
            inc_t = os.path.join(sc_include_policy_dir,
                                 "templates",
                                 self.template)
            if not os.path.exists(inc_t):
                raise SeccompFilterGenException("Invalid template '%s'" %
                                                self.template)

            fn = inc_t

        with open_file_read(fn) as f:
            return f.read()

    def get_policy_groups(self):
        '''Get security policy group contents

        Each policy group is looked up under sc_system_policy_dir (in the
        vendor and version subdirectory when both are known) and then under
        sc_include_policy_dir. Returns the contents of all groups, each
        preceded by a newline. Raises SeccompFilterGenException for a policy
        group that is found in neither directory.
        '''
        pg_dir = os.path.join(sc_system_policy_dir, "policygroups")
        if self.policy_vendor and self.policy_version:
            pg_dir = os.path.join(sc_system_policy_dir, "policygroups/%s/%s" %
                                  (self.policy_vendor, self.policy_version))

        parts = []
        for p in self.policy_groups:
            fn = os.path.join(pg_dir, p)
            if not os.path.exists(fn):
                inc_p = os.path.join(sc_include_policy_dir, "policygroups", p)
                if not os.path.exists(inc_p):
                    raise SeccompFilterGenException("Invalid policy group "
                                                    "'%s'" % p)
                fn = inc_p

            with open_file_read(fn) as f:
                parts.append("\n%s" % f.read())

        return "".join(parts)

    def get_syscalls(self):
        '''Get user-specified syscalls

        Returns an empty string when no syscalls were specified, otherwise a
        '# Specified syscalls' comment followed by one syscall per line.
        '''
        if not self.syscalls:
            return ""

        return "\n# Specified syscalls\n%s\n" % "\n".join(self.syscalls)

    def get_additional(self, orig):
        '''Get any additional policy

        The additional policy is read from '<orig>.additional'. Returns an
        empty string when that file does not exist.
        '''
        fn = "%s.additional" % orig
        if not os.path.isfile(fn):
            return ""

        with open_file_read(fn) as f:
            return "\n# Additional\n%s" % f.read()

    def output_policy(self, out_fn=None):
        '''Output policy

        Without out_fn the policy is written to stdout. Otherwise it is
        written to a temporary file which is then moved to out_fn with mode
        0644. The parent directory of out_fn is created when it is missing.
        Raises SeccompFilterGenException when out_fn already exists or its
        parent is not a directory.
        '''
        tmp = self.get_template()
        tmp += self.get_policy_groups()
        tmp += self.get_syscalls()
        if out_fn is not None:
            tmp += self.get_additional(out_fn)

        policy = self.parse_policy(tmp)

        if not out_fn:
            sys.stdout.write('%s\n' % policy)
            return

        if os.path.exists(out_fn):
            raise SeccompFilterGenException("'%s' already exists" % out_fn)

        # dirname() is '' for a bare filename, which means the current
        # directory (and os.mkdir('') would fail)
        out_dir = os.path.dirname(out_fn) or os.curdir
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)

        if not os.path.isdir(out_dir):
            raise SeccompFilterGenException("'%s' is not a directory" %
                                            out_dir)

        if not isinstance(policy, bytes):
            policy = policy.encode('utf-8')

        fd, tmp_fn = tempfile.mkstemp(prefix='sc-filtergen')
        try:
            # The file object owns the descriptor and closes it even when the
            # write fails
            with os.fdopen(fd, 'wb') as f:
                f.write(policy)

            os.chmod(tmp_fn, 0o0644)
            shutil.move(tmp_fn, out_fn)
        except Exception:
            # Do not leave the temporary file behind on failure
            try:
                os.unlink(tmp_fn)
            except OSError:
                pass
            raise


#
# End helpers
#


class SeccompFilterGenException(Exception):
    '''This class represents SeccompFilterGen exceptions

    value is the error message. It is kept on the instance as 'value' and
    str() of the exception is repr(value).
    '''
    def __init__(self, value):
        # Initialize the base class too so that 'args' (and therefore
        # repr() of the exception) carries the message
        super().__init__(value)
        self.value = value

    def __str__(self):
        return repr(self.value)


def main():
    '''Parse the command line and output the generated policy

    -d/--debug turns on debug output. --policy-dir and --include-policy-dir
    replace the default policy directories, but only when they name an
    existing directory.
    '''
    global DEBUGGING, sc_system_policy_dir, sc_include_policy_dir

    parser = optparse.OptionParser()
    parser.add_option("-d", "--debug", action='store_true', default=False,
                      help="Show debugging output")
    parser.add_option("-t", "--template", "--security-template",
                      dest="template", metavar="TEMPLATE", default='default',
                      help="Use non-default policy template")
    parser.add_option("-p", "--policy-groups", "--caps",
                      dest="policy_groups", metavar="POLICYGROUPS", type=str,
                      help="Comma-separated list of policy groups")
    parser.add_option("--syscalls",
                      dest="syscalls", metavar="SYSCALLS", type=str,
                      help="Comma-separated list of syscalls")
    parser.add_option("--policy-version",
                      dest="policy_version", metavar="VERSION", type=str,
                      help="Specify version for templates and policy groups")
    parser.add_option("--policy-vendor",
                      dest="policy_vendor", metavar="VENDOR", type=str,
                      help="Specify vendor for templates and policy groups")
    parser.add_option("-o", "--output-file",
                      dest="output_file", metavar="FILE", default=None,
                      help="Output to file (default to stdout)")
    parser.add_option("--policy-dir",
                      dest="policy_dir", metavar="DIR", default=None,
                      help="Use non-default policy directory")
    parser.add_option("--include-policy-dir",
                      dest="include_policy_dir", metavar="DIR", default=None,
                      help="Use non-default include policy directory")
    parser.add_option("-m", "--manifest",
                      dest="manifest", metavar="FILE", type=str,
                      help="Yaml manifest file")

    # Positional arguments are not used
    opt = parser.parse_args()[0]

    if opt.debug:
        DEBUGGING = True

    policy_dir = opt.policy_dir
    if policy_dir is not None and os.path.isdir(policy_dir):
        sc_system_policy_dir = policy_dir
        debug("sc_system_policy_dir: %s" % policy_dir)

    include_dir = opt.include_policy_dir
    if include_dir is not None and os.path.isdir(include_dir):
        sc_include_policy_dir = include_dir
        debug("sc_include_policy_dir: %s" % include_dir)

    SeccompFilterGen(opt).output_policy(opt.output_file)

# Only run when executed as a script. main() has no return value, so the exit
# status is 0 unless an exception escapes.
if __name__ == "__main__":
    sys.exit(main())
