#!/usr/bin/python3
"""
Copyright (C) 2023  Michael Ablassmeier <abi@grinser.de>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import os
import sys
import signal
import logging
import argparse
from typing import List
from datetime import datetime
from functools import partial
from concurrent.futures import ThreadPoolExecutor, as_completed
from nbd import __version__ as __nbdversion__
from libvirtnbdbackup import sighandle
from libvirtnbdbackup import argopt
from libvirtnbdbackup import __version__
from libvirtnbdbackup import virt
from libvirtnbdbackup.objects import DomainDisk
from libvirtnbdbackup.virt import checkpoint
from libvirtnbdbackup import output
from libvirtnbdbackup.output import stream
from libvirtnbdbackup import common as lib
from libvirtnbdbackup.logcount import logCount
from libvirtnbdbackup import exceptions
from libvirtnbdbackup.backup import partialfile
from libvirtnbdbackup.backup import job
from libvirtnbdbackup.backup import disk
from libvirtnbdbackup.backup import metadata
from libvirtnbdbackup.backup import check
from libvirtnbdbackup.ssh.exceptions import sshError
from libvirtnbdbackup.virt.exceptions import (
    domainNotFound,
    connectionFailed,
)
from libvirtnbdbackup.output.exceptions import OutputException


def _createParser() -> argparse.ArgumentParser:
    """Build the command line parser of the backup utility.

    Registers the general, remote, logging and debug option groups. The
    remote, log color and generic debug options are contributed by the
    shared helpers of the `argopt` module.

    Returns:
        The configured parser; no arguments have been parsed yet.
    """
    parser = argparse.ArgumentParser(
        description="Backup libvirt/qemu virtual machines",
        epilog=(
            "Examples:\n"
            "   # full backup of domain 'webvm' with all attached disks:\n"
            "\t%(prog)s -d webvm -l full -o /backup/\n"
            "   # incremental backup:\n"
            "\t%(prog)s -d webvm -l inc -o /backup/\n"
            "   # differential backup:\n"
            "\t%(prog)s -d webvm -l diff -o /backup/\n"
            "   # full backup, exclude disk 'vda':\n"
            "\t%(prog)s -d webvm -l full -x vda -o /backup/\n"
            "   # full backup, backup only disk 'vdb':\n"
            "\t%(prog)s -d webvm -l full -i vdb -o /backup/\n"
            "   # full backup, compression enabled:\n"
            "\t%(prog)s -d webvm -l full -z -o /backup/\n"
            "   # full backup, create archive:\n"
            "\t%(prog)s -d webvm -l full -o - > backup.zip\n"
            "   # full backup of vm operating on remote libvirtd:\n"
            "\t%(prog)s -U qemu+ssh://root@remotehost/system "
            "--ssh-user root -d webvm -l full -o /backup/\n"
        ),
        # Keep the hand formatted examples of the epilog intact.
        formatter_class=argparse.RawTextHelpFormatter,
    )

    opt = parser.add_argument_group("General options")
    opt.add_argument("-d", "--domain", required=True, type=str, help="Domain to backup")
    opt.add_argument(
        "-l",
        "--level",
        default="copy",
        choices=["copy", "full", "inc", "diff", "auto"],
        type=str,
        help="Backup level. (default: %(default)s)",
    )
    opt.add_argument(
        "-t",
        "--type",
        default="stream",
        type=str,
        choices=["stream", "raw"],
        help="Output type: stream or raw. (default: %(default)s)",
    )
    opt.add_argument(
        "-r",
        "--raw",
        default=False,
        action="store_true",
        help="Include full provisioned disk images in backup. (default: %(default)s)",
    )
    opt.add_argument(
        "-o", "--output", required=True, type=str, help="Output target directory"
    )
    opt.add_argument(
        "-C",
        "--checkpointdir",
        required=False,
        default=None,
        type=str,
        help="Persistent libvirt checkpoint storage directory",
    )
    opt.add_argument(
        "--scratchdir",
        default="/var/tmp",
        required=False,
        type=str,
        help="Target dir for temporary scratch file. (default: %(default)s)",
    )
    opt.add_argument(
        "-S",
        "--start-domain",
        default=False,
        required=False,
        action="store_true",
        help="Start virtual machine if it is offline. (default: %(default)s)",
    )
    opt.add_argument(
        "--pause",
        default=False,
        required=False,
        action="store_true",
        help="Pause virtual machine while starting backup job. (default: %(default)s)",
    )
    opt.add_argument(
        "-i",
        "--include",
        default=None,
        type=str,
        help="Backup only disk with target dev name (-i vda)",
    )
    opt.add_argument(
        "-x",
        "--exclude",
        default=None,
        type=str,
        help="Exclude disk(s) with target dev name (-x vda,vdb)",
    )
    opt.add_argument(
        "-f",
        "--socketfile",
        # The process id makes the default socket path unique per run.
        default=f"/var/tmp/virtnbdbackup.{os.getpid()}",
        type=str,
        help="Use specified file for NBD Server socket (default: %(default)s)",
    )
    opt.add_argument(
        "-n",
        "--noprogress",
        default=False,
        help="Disable progress bar",
        action="store_true",
    )
    opt.add_argument(
        "-z",
        "--compress",
        # False if the option is absent, 2 if passed without a value,
        # otherwise the integer level given by the user.
        default=False,
        type=int,
        const=2,
        nargs="?",
        help="Compress with lz4 compression level. (default: %(default)s)",
        action="store",
    )
    opt.add_argument(
        "-w",
        "--worker",
        type=int,
        default=None,
        help=(
            "Amount of concurrent workers used "
            "to backup multiple disks. (default: amount of disks)"
        ),
    )
    opt.add_argument(
        "-F",
        "--freeze-mountpoint",
        type=str,
        default=None,
        help=(
            "If qemu agent available, freeze only filesystems on specified mountpoints within"
            " virtual machine (default: all)"
        ),
    )
    opt.add_argument(
        "-e",
        "--strict",
        default=False,
        help=(
            "Change exit code if warnings occur during backup operation. "
            "(default: %(default)s)"
        ),
        action="store_true",
    )
    opt.add_argument(
        "--no-sparse-detection",
        default=False,
        help=(
            "Skip detection of sparse ranges during incremental or differential backup. "
            "(default: %(default)s)"
        ),
        action="store_true",
    )
    opt.add_argument(
        "-T",
        "--threshold",
        type=int,
        default=None,
        help=("Execute backup only if threshold is reached."),
    )
    remopt = parser.add_argument_group("Remote Backup options")
    argopt.addRemoteArgs(remopt)
    logopt = parser.add_argument_group("Logging options")
    logopt.add_argument(
        "-L",
        "--syslog",
        default=False,
        action="store_true",
        help="Additionally send log messages to syslog (default: %(default)s)",
    )
    logopt.add_argument(
        "--quiet",
        default=False,
        action="store_true",
        help="Disable logging to stderr (default: %(default)s)",
    )
    argopt.addLogColorArgs(logopt)
    debopt = parser.add_argument_group("Debug options")
    debopt.add_argument(
        "-q",
        "--qemu",
        default=False,
        action="store_true",
        help="Use Qemu tools to query extents.",
    )
    debopt.add_argument(
        "-s",
        "--startonly",
        default=False,
        help="Only initialize backup job via libvirt, do not backup any data",
        action="store_true",
    )
    debopt.add_argument(
        "-k",
        "--killonly",
        default=False,
        help="Kill any running block job",
        action="store_true",
    )
    debopt.add_argument(
        "-p",
        "--printonly",
        default=False,
        help="Quit after printing estimated checkpoint size.",
        action="store_true",
    )
    argopt.addDebugArgs(debopt)

    return parser


def main() -> None:
    """Handle backup operation and settings.

    Parses the command line, prepares output target and log file,
    connects to libvirt, validates the domain and its disks, creates the
    checkpoint, starts the backup job and saves all disks concurrently
    using a thread pool. Afterwards the backup job is stopped and the
    metadata files are saved.

    The process is terminated via sys.exit() in these situations:
      0 - early exit after --killonly, --printonly, --startonly or if
          the --threshold is not reached
      1 - a setup or validation step failed, or errors were logged
          during the backup
      2 - warnings were logged and --strict is set

    The function only returns normally after a successful backup.
    """
    parser = _createParser()

    repository = output.target()
    args = lib.argparse(parser)

    lib.setThreadName()
    # Runtime state attached to the argument namespace, it is handed to
    # most helper functions below.
    args.stdout = args.output == "-"
    args.sshClient = None
    args.diskInfo = []
    args.offline = False

    # A quiet run implies no progress bar.
    if args.quiet is True:
        args.noprogress = True

    fileStream = stream.get(args, repository)

    try:
        # No target directory has to be created if output is "-".
        if not args.stdout:
            fileStream.create(args.output)
    except OutputException as e:
        logging.error("Can't open output file: [%s]", e)
        sys.exit(1)

    # Clamp nonsensical worker counts to a single worker.
    if args.worker is not None and args.worker < 1:
        args.worker = 1

    now = datetime.now().strftime("%m%d%Y%H%M%S")
    logFile = f"{args.output}/backup.{args.level}.{now}.log"
    # Abort if getLogFile returns a falsy value instead of a log target.
    fileLog = lib.getLogFile(logFile) or sys.exit(1)

    # The counter is consulted at the end to derive the exit code from
    # the amount of logged errors and warnings.
    counter = logCount()  # pylint: disable=unreachable
    lib.configLogger(args, fileLog, counter)
    lib.printVersion(__version__)

    logging.info("Backup level: [%s]", args.level)
    if args.compress is not False:
        logging.info("Compression enabled, level [%s]", args.compress)

    try:
        check.arguments(args)
    except exceptions.BackupException as e:
        logging.error(e)
        sys.exit(1)

    if partialfile.exists(args):
        sys.exit(1)

    # Default checkpoint storage lives below the output directory.
    if not args.checkpointdir:
        args.checkpointdir = f"{args.output}/checkpoints"
    else:
        logging.info("Store checkpoints in: [%s]", args.checkpointdir)

    fileStream.create(args.checkpointdir)

    def connectionError(_, reason, args):
        """Callback if the libvirt connection drops mid
        data transfer, used to potentially cleanup the
        leftover backup job.

        The first parameter is unused, `reason` is used as index into
        the table of close reasons and `args` is the opaque value passed
        during callback registration (the parsed arguments). The
        callback always terminates the process with exit code 1.
        """
        virConnectCloseReason = (
            "Misc I/O error",
            "End-of-file from server",
            "Keepalive timer triggered",
            "Client side connection close",
            "Unknown",
        )
        logging.error(
            "Libvirt connection error [%s], trying to reconnect",
            virConnectCloseReason[reason],
        )
        try:
            # A fresh connection is required to stop the leftover job.
            virtClient = virt.client(args)
        except connectionFailed as e:
            logging.error("Unrecoverable connection error: %s", e)
            sys.exit(1)
        domObj = virtClient.getDomain(args.domain)
        if not args.offline:
            logging.error("Attempting to stop backup task")
            virtClient.stopBackup(domObj)
        sys.exit(1)

    try:
        virtClient = virt.client(args)
        domObj = virtClient.getDomain(args.domain)
    except domainNotFound as e:
        logging.error("%s", e)
        sys.exit(1)
    except connectionFailed as e:
        logging.error("Can't connect libvirt daemon: [%s]", e)
        sys.exit(1)

    # args is passed as opaque data and handed back to the callback.
    virtClient._conn.registerCloseCallback(  # pylint: disable=W0212
        connectionError, args
    )

    logging.info("Libvirt library version: [%s]", virtClient.libvirtVersion)
    logging.info("NBD library version: [%s]", __nbdversion__)

    try:
        check.vmfeature(virtClient, domObj)
        checkpoint.checkForeign(args, domObj)
        check.vmstate(args, virtClient, domObj)
        check.targetDir(args)
    except exceptions.BackupException as e:
        logging.error(e)
        sys.exit(1)
    except exceptions.CheckpointException:
        sys.exit(1)

    if args.raw is True and args.level in ("inc", "diff"):
        logging.warning(
            "Raw disks can't be included during incremental or differential backup."
        )
        logging.warning("Excluding raw disks.")
        args.raw = False

    # Route SIGINT to the backup signal handler, bound to the current
    # arguments, domain and libvirt client.
    signal.signal(
        signal.SIGINT,
        partial(sighandle.Backup.catch, args, domObj, virtClient, logging),
    )

    # The option is only meaningful for incremental or differential
    # backups, reset it for all other levels.
    if args.level not in ("inc", "diff") and args.no_sparse_detection is True:
        args.no_sparse_detection = False

    vmConfig = virtClient.getDomainConfig(domObj)
    disks: List[DomainDisk] = virtClient.getDomainDisks(args, vmConfig)
    args.info = virtClient.getDomainInfo(vmConfig)
    if virtClient.getTPMDevice(vmConfig):
        logging.warning("Emulated TPM device attached: User action required.")
        logging.warning(
            "Please manually backup contents of: [/var/lib/libvirt/swtpm/%s/]",
            domObj.UUIDString(),
        )

    try:
        check.diskformat(args, disks)
    except exceptions.BackupException as e:
        # Not fatal: the message is only logged and the run continues.
        logging.info(e)

    if not disks:
        logging.error("Unable to detect disks suitable for backup.")
        # Metadata is saved even though no disk data can be backed up.
        metadata.saveFiles(args, vmConfig, disks, fileStream, logFile)
        sys.exit(1)

    try:
        check.blockjobs(args, virtClient, domObj, disks)
    except exceptions.BackupException as e:
        logging.error(e)
        sys.exit(1)

    logging.info(
        "Backup will save [%s] attached disks.",
        len(disks),
    )
    # Never use more workers than there are disks; default is one
    # worker per disk.
    if args.worker is None or args.worker > len(disks):
        args.worker = len(disks)
    logging.info("Concurrent backup processes: [%s]", args.worker)

    if args.killonly is True:
        logging.info("Stopping backup job")
        if not virtClient.stopBackup(domObj):
            sys.exit(1)
        sys.exit(0)

    try:
        checkpoint.create(args, domObj)
    except exceptions.CheckpointException as errmsg:
        logging.error(errmsg)
        sys.exit(1)

    # Size estimation requires a parent checkpoint and is skipped if the
    # run is flagged as offline.
    if args.printonly and args.cpt.parent and not args.offline:
        size = checkpoint.getSize(domObj, args.cpt.parent)
        logging.info("Estimated checkpoint backup size: [%s] Bytes", size)
        sys.exit(0)

    if args.threshold and args.cpt.parent and not args.offline:
        size = checkpoint.getSize(domObj, args.cpt.parent)
        if size < args.threshold:
            logging.info(
                "Backup size [%s] does not meet required threshold [%s], skipping backup.",
                size,
                args.threshold,
            )
            sys.exit(0)

    if virtClient.remoteHost != "":
        # Remote backup: an SSH session to the libvirt host is required.
        args.sshClient = lib.sshSession(args, virtClient.remoteHost)
        if not args.sshClient:
            logging.error("Remote backup detected but ssh session setup failed")
            sys.exit(1)
        logging.info(
            "Remote NBD Endpoint host: [%s]",
            virtClient.remoteHost,
        )
        if args.offline is True:
            logging.info(
                "Remote ports used for backup: [%s-%s]",
                args.nbd_port,
                args.nbd_port + args.worker,
            )
    else:
        logging.info("Local NBD Endpoint sockets:")
        for sdisk in disks:
            logging.info(
                "%s: [nbd+unix:///%s?socket=%s]",
                sdisk.target,
                sdisk.target,
                args.socketfile,
            )

    if args.offline is not True:
        logging.info("Temporary scratch file target directory: [%s]", args.scratchdir)
        fileStream.create(args.scratchdir)
        if not job.start(args, virtClient, domObj, disks):
            sys.exit(1)

    # Checkpoint information is saved for every level except copy and
    # diff, and only if the run is not flagged as offline.
    if args.level not in ("copy", "diff") and args.offline is False:
        logging.info("Started backup job with checkpoint, saving information.")
        try:
            checkpoint.save(args)
        except exceptions.CheckpointException as e:
            logging.error("Extending checkpoint file failed: [%s]", e)
            sys.exit(1)
        if not checkpoint.backup(args, domObj):
            # Do not leave the already started backup job behind.
            virtClient.stopBackup(domObj)
            sys.exit(1)

    if args.startonly is True:
        logging.info("Started backup job for debugging, exiting.")
        sys.exit(0)

    backupSize: int = 0
    try:
        # One task per disk, limited to args.worker concurrent threads.
        with ThreadPoolExecutor(max_workers=args.worker) as executor:
            futures = {
                executor.submit(
                    disk.backup, args, Disk, count, fileStream, virtClient
                ): Disk
                for count, Disk in enumerate(disks)
            }
            for future in as_completed(futures):
                # result() re-raises any exception of the worker thread.
                size, state = future.result()
                backupSize += size
                if state is not True:
                    raise exceptions.DiskBackupFailed("Backup of one disk failed")
    # Failures are only logged here: the backup job still has to be
    # stopped below, the logged errors decide about the exit code.
    except exceptions.BackupException as e:
        logging.error("Disk backup failed: [%s]", e)
    except sshError as e:
        logging.error("Remote Disk backup failed: [%s]", e)
    except Exception as e:  # pylint: disable=broad-except
        logging.critical("Unknown Exception during backup: %s", e)
        logging.exception(e)

    if args.offline is False:
        logging.info("Backup jobs finished, stopping backup task.")
        virtClient.stopBackup(domObj)

    virtClient.close()

    metadata.saveFiles(args, vmConfig, disks, fileStream, logFile)

    if domObj.autostart() == 1:
        metadata.backupAutoStart(args)

    # Release the SSH session before evaluating the error counter, so
    # it is closed on the failure path too and not only on success.
    if args.sshClient:
        args.sshClient.disconnect()

    if counter.count.errors > 0:
        logging.error("Error during backup")
        sys.exit(1)

    if counter.count.warnings > 0 and args.strict is True:
        logging.info(
            "[%s] Warnings detected during backup operation, forcing exit code 2",
            counter.count.warnings,
        )
        sys.exit(2)

    logging.info("Total saved disk data: [%s]", lib.humanize(backupSize))
    logging.info("Finished successfully")


# Script entry point: run the backup only if the file is executed
# directly, not if it is imported as a module.
if __name__ == "__main__":
    main()
