#!/usr/bin/env python

import argparse
import configparser
import sys
import time
import XenAPI


# Fallback delay between two PBD plug attempts. Used by
# get_pbd_plug_retry_delay() when the configuration does not provide a usable
# 'pbd-plug-retry-delay' value.
PBD_PLUG_RETRY_DELAY = 30  # In seconds.


# ------------------------------------------------------------------------------

def strtobool(str):
    """Convert a textual flag into a boolean.

    Any falsy argument (``None``, empty string, ...) yields False. Otherwise
    the value is lower-cased and matched against the accepted spellings:
    'y', 'yes', 't', 'true', 'on', '1' give True and
    'n', 'no', 'f', 'false', 'off', '0' give False.

    Raises:
        ValueError: if the lower-cased value is not one of the spellings
            above.
    """
    # NOTE: the parameter name shadows the builtin `str`; it is kept as is
    # because it is part of the function's signature.
    if not str:
        return False

    lowered = str.lower()
    spellings = (
        (('y', 'yes', 't', 'true', 'on', '1'), True),
        (('n', 'no', 'f', 'false', 'off', '0'), False),
    )
    for words, result in spellings:
        if lowered in words:
            return result

    raise ValueError("invalid truth value '{}'".format(lowered))


# ------------------------------------------------------------------------------

def get_xapi_session():
    """Open a XenAPI session on the local host and log in as 'root'.

    Returns:
        The logged-in session object created by ``XenAPI.xapi_local()``.

    Raises:
        Exception: if the login call fails; the message embeds the original
            error.
    """
    local_session = XenAPI.xapi_local()
    try:
        local_session.xenapi.login_with_password(
            'root', '', '', 'plug-late-sr'
        )
    except Exception as e:
        message = 'Failed to get local Xen API session: {}'.format(e)
        raise Exception(message)
    else:
        return local_session


def get_local_host_uuid():
    """Read the local host UUID from /etc/xensource-inventory.

    Returns:
        The text found between the first pair of single quotes on the first
        line starting with 'INSTALLATION_UUID', or None when the file has no
        such line.
    """
    with open('/etc/xensource-inventory', 'r') as inventory:
        matching_values = (
            entry.split("'")[1]
            for entry in inventory
            if entry.startswith('INSTALLATION_UUID')
        )
        # Only the first matching line is consumed, while the file is open.
        return next(matching_values, None)


def get_local_host_ref(session):
    """Return the XAPI host reference matching the local host UUID.

    Args:
        session: logged-in XenAPI session used for the lookup.
    """
    local_uuid = get_local_host_uuid()
    return session.xenapi.host.get_by_uuid(local_uuid)


def get_master_ref(session):
    """Return the host reference of the pool master.

    The master is looked up on the first pool returned by ``pool.get_all()``.

    Args:
        session: logged-in XenAPI session used for the lookup.
    """
    pool_api = session.xenapi.pool
    first_pool_ref = pool_api.get_all()[0]
    return pool_api.get_master(first_pool_ref)


# ------------------------------------------------------------------------------

def get_pbd_plug_retry_delay(config):
    """Return the delay, in seconds, to wait between two PBD plug attempts.

    The value is read from the 'pbd-plug-retry-delay' key of `config` and is
    never lower than 5. If the key is missing or the value cannot be
    converted to an integer (or any other error occurs while reading it),
    PBD_PLUG_RETRY_DELAY is returned instead.

    Args:
        config: mapping holding the options of one configuration section.
    """
    try:
        configured_delay = int(config.get('pbd-plug-retry-delay'))
    except Exception:
        return PBD_PLUG_RETRY_DELAY

    # Clamp to the 5 second floor.
    if configured_delay < 5:
        return 5
    return configured_delay


def has_auto_poweron_vms(config):
    """Tell whether the 'auto-poweron-vms' option of `config` is enabled.

    The option defaults to '1' (enabled) when the key is absent and is parsed
    with strtobool(), which raises ValueError on an unrecognized value.

    Args:
        config: mapping holding the options of one configuration section.
    """
    raw_flag = config.get('auto-poweron-vms', '1')
    return strtobool(raw_flag)


def plug_pbds(session, master_ref, sr_ref, sr_uuid, retry_delay):
    """Plug every PBD of an SR, the master's PBD first.

    Each PBD is retried indefinitely, sleeping `retry_delay` seconds between
    attempts, until the plug call succeeds. Every failed attempt is reported
    on stdout so that a storage that never comes up can be diagnosed.

    Args:
        session: logged-in XenAPI session.
        master_ref: host reference of the pool master.
        sr_ref: reference of the SR whose PBDs must be plugged.
        sr_uuid: UUID of the SR, only used in messages.
        retry_delay: seconds to wait after a failed plug attempt.

    Returns:
        False when the master's PBD was already attached before any plug
        attempt, True otherwise (including when no master PBD was found).
        NOTE(review): presumably an already attached master PBD means the SR
        was available at boot and its VMs need no late start -- confirm.
    """
    pbd_refs = session.xenapi.SR.get_PBDs(sr_ref)

    # Find master PBD and move it before the slaves.
    master_pbd_index = None
    for index, pbd_ref in enumerate(pbd_refs):
        pbd = session.xenapi.PBD.get_record(pbd_ref)
        if pbd['host'] == master_ref:
            master_pbd_index = index
            break

    should_start_vms = True
    if master_pbd_index is not None:
        if master_pbd_index != 0:
            pbd_refs.insert(0, pbd_refs.pop(master_pbd_index))
        if session.xenapi.PBD.get_currently_attached(pbd_refs[0]):
            should_start_vms = False
    else:
        print('Unable to find master PBD of {}'.format(sr_uuid))

    # Plug all PBDs.
    for pbd_ref in pbd_refs:
        while True:
            try:
                session.xenapi.PBD.plug(pbd_ref)
                break
            except Exception as e:
                # Keep retrying (the storage may simply not be ready yet),
                # but never silently: report why the attempt failed.
                print('Failed to plug PBD {} of SR {}, retrying in {}s: {}'
                      .format(pbd_ref, sr_uuid, retry_delay, e))
                time.sleep(retry_delay)
    return should_start_vms


def get_vm_refs(session, sr_ref):
    """List the VMs attached, through a VBD, to any VDI stored on an SR.

    Args:
        session: logged-in XenAPI session.
        sr_ref: reference of the SR whose VDIs are inspected.

    Returns:
        A list with one VM reference per VBD found; the same VM can appear
        several times.
    """
    sr_filter = 'field "SR" = "{}"'.format(sr_ref)
    vdi_records = session.xenapi.VDI.get_all_records_where(sr_filter)

    return [
        session.xenapi.VBD.get_record(vbd_ref)['VM']
        for vdi_record in vdi_records.values()
        for vbd_ref in vdi_record['VBDs']
    ]


def start_vms(session, vm_refs):
    """Start the VMs whose other_config enables 'auto_poweron'.

    VMs without other_config, without the 'auto_poweron' key, or with a
    value that strtobool() maps to False are skipped. A VM whose value is
    not a recognized boolean spelling is skipped too, with a message.
    Start failures are printed, except VM_BAD_POWER_STATE which is ignored.

    Args:
        session: logged-in XenAPI session.
        vm_refs: iterable of VM references to consider.
    """
    for vm_ref in vm_refs:
        vm = session.xenapi.VM.get_record(vm_ref)
        other_config = vm.get('other_config')
        if not other_config:
            continue

        # other_config values are strings: a plain truthiness test would
        # treat 'false' or '0' as enabled, so parse the flag explicitly.
        try:
            auto_poweron = strtobool(other_config.get('auto_poweron'))
        except ValueError as e:
            print('Ignoring auto_poweron of VM {}: {}'.format(vm_ref, e))
            continue
        if not auto_poweron:
            continue

        try:
            session.xenapi.VM.start(vm_ref, False, False)
        except XenAPI.Failure as e:
            if e.details[0] != 'VM_BAD_POWER_STATE':
                print('Failed to start VM: {}'.format(e))


# ------------------------------------------------------------------------------

def main():
    """Plug the PBDs of the configured SRs, then start auto power-on VMs.

    The work is only done on the pool master; on any other host the script
    exits with status 0 right away. The configuration file given with
    -c/--config holds one section per SR UUID, or a '*' section which applies
    to every SR of the pool and takes precedence over all other sections.
    The XAPI session is always logged out before returning or exiting.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', required=True, metavar='PATH')
    args = parser.parse_args()

    session = get_xapi_session()
    try:
        master_ref = get_master_ref(session)
        if get_local_host_ref(session) != master_ref:
            sys.exit(0)

        config = configparser.ConfigParser()
        config.read(args.config)

        # Use get()/items() rather than iteritems(): the latter does not
        # exist on Python 3 dictionaries, while these work on both versions.
        sections = config._sections
        all_srs_config = sections.get('*')

        vm_refs = []

        if all_srs_config is not None:
            # Handle wildcard.
            retry_delay = get_pbd_plug_retry_delay(all_srs_config)
            srs = session.xenapi.SR.get_all_records()
            auto_poweron = has_auto_poweron_vms(all_srs_config)

            for sr_ref, sr in srs.items():
                should_start_vms = plug_pbds(
                    session, master_ref, sr_ref, sr['uuid'], retry_delay
                )
                if auto_poweron and should_start_vms:
                    vm_refs.extend(get_vm_refs(session, sr_ref))
        else:
            # Handle specific SRs.
            for sr_uuid, sr_config in sections.items():
                try:
                    sr_ref = session.xenapi.SR.get_by_uuid(sr_uuid)
                except Exception as e:
                    print('Failed to get SR: {}'.format(e))
                    continue

                retry_delay = get_pbd_plug_retry_delay(sr_config)
                should_start_vms = plug_pbds(
                    session, master_ref, sr_ref, sr_uuid, retry_delay
                )
                if should_start_vms and has_auto_poweron_vms(sr_config):
                    vm_refs.extend(get_vm_refs(session, sr_ref))

        # A VM can be reached through several VBDs: start each one once.
        start_vms(session, set(vm_refs))
    finally:
        # Do not leak the XAPI session, whatever the exit path is
        # (including the early sys.exit() on non-master hosts).
        session.xenapi.session.logout()


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
