#!/usr/bin/python3
#  OpenVPN 3 Linux client -- Next generation OpenVPN client
#
#  SPDX-License-Identifier: AGPL-3.0-only
#
#  Copyright (C) 2019 - 2023  OpenVPN Inc <sales@openvpn.net>
#  Copyright (C) 2019 - 2023  David Sommerseth <davids@openvpn.net>
#

##
# @file  openvpn3-as
#
# @brief  Utility to download and import a configuration profile directly
#         from an OpenVPN Access Server

import sys
import os
import argparse
from getpass import getpass
import urllib.parse
import urllib.request
import ssl
import openvpn3
import pwd
import dbus
import xml.etree.ElementTree as xmlET
import socket


##
#  Challenge/Response based authentication
#
#  If the Access Server demands more steps for the authentication,
#  it will send an error response with a challenge string (CRV1 string).
#  This needs to be processed locally before the client tries to reconnect
#  with the proper response.
#
#  The CRV1 protocol technically also supports passing information to the
#  user without a response being sent back to the server.  This is not
#  properly implemented nor tested.
#
#  This is implemented as an Exception class.  This is because the server
#  sends an error message and the server connection is lost.  This made
#  it simpler to keep a fairly sane programmatic logic when this situation
#  occurs.
#
class CRV1Authentication(Exception):
    ##
    #  Constructor
    #
    #  @param username  Username of the user getting the challenge/response
    #                   authentication request
    #  @param crv1str   The complete CRV1 string which will be parsed and
    #                   processed by this class
    #
    #  @throws ValueError if crv1str does not carry at least the marker,
    #                     flags and state ID fields
    #
    def __init__(self, username, crv1str):
        # Keep the constructor arguments as the exception arguments
        super().__init__(username, crv1str)
        self.__username = username

        # The CRV1 string contains several fields separated by colon,
        # but only the last field can contain colon as part of data.
        # Limiting the number of splits keeps that last field intact.
        # Field 1 carries the flags, field 2 the state ID and field 4
        # the challenge text; field 3 is not used by this class.
        fields = crv1str.split(':', 4)
        if len(fields) < 3:
            raise ValueError('Malformed CRV1 challenge string')

        self.__flags = fields[1].split(',')
        self.__stateid = fields[2]
        self.__challenge = fields[4] if len(fields) > 4 else ''
        self.__response = None


    ##
    #  This processes the challenge/response request and will
    #  both print and request the user for input, according to
    #  the request.
    #
    #  The 'R' flag indicates a response is required.  When the 'E' flag
    #  is also set the typed response is echoed, otherwise it is read
    #  without echo.
    #
    #  @return Returns True if a response is expected to be
    #          sent back to the server.
    #
    def Process(self):
        print(self.__challenge, end='', flush=True)
        if 'R' not in self.__flags:
            # Information only; terminate the line since no prompt follows
            print()
            return False

        read_response = input if 'E' in self.__flags else getpass
        self.__response = read_response(': ')
        return True


    ##
    #  This generates the response needed according to the
    #  CRV1 protocol.  This response is expected to be passed
    #  in the password field in the next connection to the server.
    #
    #  If no response was required or collected, the response part
    #  of the string is left empty.
    #
    #  @return Returns a properly formatted CRV1 response string
    #
    def GetResponse(self):
        if 'R' not in self.__flags or self.__response is None:
            return 'CRV1::%s::' % (self.__stateid,)
        return 'CRV1::%s::%s' % (self.__stateid, self.__response)


##
#  Helper class to integrate with a few systemd operations
#
class SystemdServiceUnit(object):
    """Minimal helper to enable and start a single systemd service unit
    through the systemd manager D-Bus interface"""

    def __init__(self, dbuscon, unit_name):
        """Look up the systemd manager object on the given D-Bus connection

        dbuscon is the D-Bus connection used to reach systemd and
        unit_name is the name of the unit this object operates on.
        """
        self._dbuscon = dbuscon
        self._unit_name = unit_name

        # Proxy object for the main systemd manager
        self._srvmngr_obj = dbuscon.get_object(
            'org.freedesktop.systemd1',
            '/org/freedesktop/systemd1')

        # Manager interface on that object; used for all method calls
        self._srvmgr = dbus.Interface(
            self._srvmngr_obj,
            dbus_interface='org.freedesktop.systemd1.Manager')

    def Enable(self):
        """Ask systemd to enable the unit file so it is started at boot"""
        # NOTE(review): the two booleans are presumably the "runtime" and
        # "force" arguments of EnableUnitFiles -- confirm against the
        # systemd D-Bus API documentation
        unit_files = [self._unit_name]
        self._srvmgr.EnableUnitFiles(unit_files, False, True)

    def Start(self):
        """Ask systemd to start the unit, using the 'replace' mode"""
        self._srvmgr.StartUnit(self._unit_name, 'replace')


##
#  Fetches a configuration profile from an OpenVPN Access Server
#  for a specific user.
#
#  @param uri        URL to the OpenVPN Access Server to connect to
#  @param username   Username to connect as
#  @param password   Password to use for the authentication
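#  @param autologin  Boolean, if True it will attempt to retrieve the
#                    auto-login profile.  If this is available for the
#                    user, the user will not be prompted for credentials
#                    when connecting to the VPN server
#  @param insecure_certs  Boolean, if True the HTTPS server certificate
#                    and hostname are not verified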
#
#  @return Returns a string containing the configuration profile for the user
#          on success.  Otherwise an exception with the error is thrown.
#
def fetch_profile(uri, username, password, autologin, insecure_certs):
    """Download a configuration profile from an OpenVPN Access Server.

    uri is a parsed URL (as returned by urllib.parse.urlparse) pointing at
    the Access Server; only the https scheme is accepted.  username and
    password are used for HTTP basic authentication.  When autologin is
    true the auto-login profile is requested, otherwise the user-login
    profile.  When insecure_certs is true the server certificate and
    hostname are not verified.

    Returns the profile as a string.

    Raises CRV1Authentication when the server answers with a CRV1
    challenge, RuntimeError for errors which could be translated into a
    message, and otherwise re-raises the urllib error.
    """
    if 'https' != uri.scheme:
        raise RuntimeError('Only the https protocol is supported')

    endpoint = 'GetAutologin' if autologin else 'GetUserlogin'
    resturi = uri.geturl() + '/rest/' + endpoint + "?tls-cryptv2=1&action=import"

    # Prepare credentials
    authinfo = urllib.request.HTTPBasicAuthHandler()
    authinfo.add_password('OpenVPN Access Server',
                          resturi, username, password)

    # Setup the SSL context.  The library defaults are kept as-is; the
    # OP_NO_SSLv2/OP_NO_SSLv3 options must NOT be cleared, as that would
    # allow the obsolete SSLv2/SSLv3 protocols on a connection carrying
    # the user credentials.
    sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    sslctx.load_default_certs()
    if insecure_certs:
        # check_hostname must be disabled before verify_mode can be
        # lowered to CERT_NONE
        sslctx.check_hostname = False
        sslctx.verify_mode = ssl.CERT_NONE
    else:
        sslctx.verify_mode = ssl.CERT_REQUIRED
        sslctx.check_hostname = True

    debug = 'OPENVPN3_DEBUG' in os.environ
    handler = urllib.request.HTTPSHandler(context=sslctx)
    if debug:
        handler.set_http_debuglevel(1)

    # Prepare urllib, providing credentials
    opener = urllib.request.build_opener(handler,
                                         authinfo,
                                         urllib.request.CacheFTPHandler)
    opener.addheaders = [('User-agent',
                          'openvpn3-linux::openvpn3-as/' + openvpn3.constants.VERSION)]
    urllib.request.install_opener(opener)

    # Fetch the URI/configuration profile
    if debug:
        print('REST API URL: ' + resturi)
    try:
        with urllib.request.urlopen(resturi) as response:
            contents = response.read()
        return contents.decode('utf-8')
    except urllib.error.URLError as e:
        try:
            body = e.read()
        except AttributeError:
            # No response body available; the failure happened before
            # an HTTP response was received
            raise RuntimeError(_describe_url_error(e)) from e

        # The Access Server reports errors as an XML document.  Anything
        # else (for example an HTML error page) is reported through the
        # original HTTP error.
        try:
            errxml = xmlET.fromstring(body.decode('utf-8'))
        except (xmlET.ParseError, UnicodeDecodeError):
            errxml = None

        if errxml is None or 'Error' != errxml.tag:
            raise e

        typetag = errxml.find('Type')
        msgtag = errxml.find('Message')
        if msgtag is None or typetag is None:
            raise e

        # An empty <Message/> element has no text at all
        message = msgtag.text or ''
        if 'Authorization Required' == typetag.text and message.startswith('CRV1:'):
            raise CRV1Authentication(username, message)

        raise RuntimeError(message or str(e))


def _describe_url_error(e):
    """Translate an urllib URLError without a response body into a message."""
    if isinstance(e, (urllib.error.HTTPError, urllib.error.ContentTooShortError)):
        return str(e)

    reason = e.reason
    # isinstance() is required here; certificate failures are reported
    # using a subclass of ssl.SSLError
    if isinstance(reason, ssl.SSLError):
        if 'CERTIFICATE_VERIFY_FAILED' == reason.reason:
            return "Server HTTPS certificate validation failed"
        return reason.reason or str(reason)
    if isinstance(reason, socket.gaierror):
        return reason.strerror
    return '[%s] %s' % (type(reason), str(reason))


def main():
    """Download a profile from an Access Server and import it.

    Parses the command line, asks for the credentials, downloads the
    configuration profile (handling CRV1 challenge/response rounds),
    imports it into the OpenVPN 3 Configuration Manager and optionally
    transfers the ownership and enables/starts a systemd unit for it.

    Raises RuntimeError on failures which should be reported to the user;
    command line and privilege errors are printed and end the program via
    sys.exit().
    """
    # Parse the command line
    cmdparser = argparse.ArgumentParser(description='OpenVPN 3 Access Server profile downloader',
                                        usage='%(prog)s [options] <https://...>')
    cmdparser.add_argument('--autologin', action='store_true',
                           help='Download the auto-login profile (default: user-login profile)')
    cmdparser.add_argument('--systemd-start', action='store_true',
                           help='Enable and start this configuration during boot automatically')

    cmdparser.add_argument('--impersistent', action='store_true',
                           help='Do not import the configuration profile as a persistent profile (default: persistent)')
    cmdparser.add_argument('--name', metavar='CONFIG-NAME',
                           help='Override the automatically generated profile name (default: AS:$servername)')
    cmdparser.add_argument('--owner', metavar='USERNAME',
                           help='Import this configuration with USERNAME as the profile owner (requires root)')
    cmdparser.add_argument('--username', metavar='USERNAME',
                           help='Username to use for OpenVPN Access Server login')
    cmdparser.add_argument('--insecure-certs', action='store_true',
                           help='Ignore invalid https server certificates used by the server')

    cmdparser.add_argument('server', metavar='<https://...>', type=str, nargs=1,
                           help='URL to OpenVPN Access Server')

    opts = cmdparser.parse_args()
    if len(opts.server) != 1:
        print('ERROR: Missing OpenVPN Access Server URL')
        sys.exit(1)

    if opts.systemd_start and os.geteuid() != 0:
        print('ERROR: This must be run as root when using --systemd-start')
        sys.exit(1)

    owner_uid = None
    if opts.owner:
        if os.geteuid() != 0:
            print('ERROR: This must be run as root when using --owner')
            sys.exit(1)
        try:
            owner_uid = pwd.getpwnam(opts.owner).pw_uid
        except KeyError:
            raise RuntimeError('ERROR: User "%s" was not found' % opts.owner) from None

    # Connect to the D-Bus system bus and establish a link
    # to the OpenVPN 3 Configuration Manager
    sysbus = dbus.SystemBus()
    cfgmgr = openvpn3.ConfigurationManager(sysbus)

    # OpenVPN Access URI and credentials
    uri = urllib.parse.urlparse(opts.server[0].rstrip('/'))
    username = opts.username or input('OpenVPN Access Server Username: ')
    password = getpass('OpenVPN Access Server Password: ')

    # Retrieve the configuration profile and generate a profile name
    profile_name = opts.name or 'AS:' + uri.netloc
    profile = None
    while True:
        try:
            profile = fetch_profile(uri, username, password, opts.autologin, opts.insecure_certs)
            break
        except CRV1Authentication as crv1auth:
            # Show the challenge and collect the response, if one is
            # requested.  The CRV1 reply is used as the password for the
            # next attempt in both cases; GetResponse() provides the
            # empty-response form when no input was required.  Retrying
            # with the unchanged password would repeat the same request
            # over and over again.
            crv1auth.Process()
            password = crv1auth.GetResponse()
        except Exception as e:
            raise RuntimeError('Failed to download configuration profile: %s' % str(e)) from e

    if profile is None or len(profile) < 1:
        print('ERROR: No configuration profile retrieved')
        sys.exit(3)

    # Import the configuration profile to the configuration manager
    cfgobj = cfgmgr.Import(profile_name, profile, False, not opts.impersistent,
                           'openvpn3-as')
    cfgobj.SetProperty('locked_down', True)

    print('-'*60)
    print('Profile imported successfully')
    print('Configuration name: %s' % profile_name)
    print('Configuration path: %s' % cfgobj.GetPath())

    # Grant root access before the ownership is handed over to another
    # user, when the profile is also going to be started via systemd
    if opts.systemd_start and owner_uid is not None:
        print('Granting root user access to profile ... ', end='', flush=True)
        cfgobj.AccessGrant(0)
        cfgobj.SetOwnershipTransfer(True)
        print('Done')

    if owner_uid is not None:
        print('Setting the profile ownership to "%s" (uid %u) ... ' % (opts.owner, owner_uid), end='', flush=True)
        cfgmgr.TransferOwnership(cfgobj.GetPath(), owner_uid)
        print('Done')

    if opts.systemd_start:
        unit_name = 'openvpn3-session@%s.service' % cfgobj.GetConfigName()
        service = SystemdServiceUnit(sysbus, unit_name)

        print('Enabling %s during boot ... ' % unit_name, end='', flush=True)
        service.Enable()
        print('Done')

        print('Starting %s ... ' % unit_name, end='', flush=True)
        service.Start()
        print('Done')

#
#  MAIN
#
if __name__ == '__main__':
    try:
        main()
    except RuntimeError as e:
        # Report the failure and make it visible to the caller through
        # a non-zero exit status
        print(e)
        sys.exit(1)
    except KeyboardInterrupt:
        # Finish the interrupted prompt line; 130 is the conventional
        # exit status for a program stopped by SIGINT
        print()
        sys.exit(130)
