import argparse
import glob
from pathlib import Path
from typing import List, Optional

import runner
from fq_utils import process_fastqs
from runner import ProcessCmd, create_dragen_cmd, create_setup_ref_cmds


def _cnv_normals_args(cnv_normals_dir: Optional[str]) -> List[str]:
    """Return one ``--cnv-normals-file`` argument per entry of *cnv_normals_dir*.

    Returns an empty list when no directory is given. Matches are sorted so the
    generated command line is deterministic (``glob`` returns entries in
    arbitrary filesystem order).
    """
    if not cnv_normals_dir:
        return []
    # Join via Path so the pattern is "<dir>/*" whether or not the caller
    # included a trailing slash; a bare f'{dir}*' without the slash would match
    # the directory itself (and any sibling sharing its name as a prefix).
    pattern = str(Path(cnv_normals_dir) / '*')
    return [f'--cnv-normals-file {file}' for file in sorted(glob.glob(pattern))]


def _bam_output_prefix(bam: str) -> str:
    """Derive an output prefix from a BAM path: its basename minus the last suffix.

    ``'dir/sample.sorted.bam'`` -> ``'sample.sorted'``. A name with no suffix is
    kept whole instead of collapsing to an empty prefix.
    """
    return Path(bam).stem


def create_batch_cmds(
        static_dragen_cmd: str, fastqs: List[str] = (), bam_batch: List[str] = (),
        cnv_normals_dir: Optional[str] = None
) -> List[ProcessCmd]:
    """Build one DRAGEN ``ProcessCmd`` per FASTQ sample and one per BAM file.

    Args:
        static_dragen_cmd: Command prefix shared by every generated command.
        fastqs: FASTQ paths; grouped into samples by ``process_fastqs``.
        bam_batch: BAM paths; each becomes its own command.
        cnv_normals_dir: Optional directory whose entries are each passed as a
            ``--cnv-normals-file`` argument to every command.

    Returns:
        The FASTQ-sample commands first, followed by the BAM commands.
    """
    # The normals directory is the same for every command, so glob it once.
    cnv_args = _cnv_normals_args(cnv_normals_dir)
    process_cmds: List[ProcessCmd] = []

    # FASTQ inputs: process_fastqs yields the sample names plus a fastq list
    # (presumably a path to a generated list file - TODO confirm) shared by all samples.
    fq_sample_names, fastq_list = process_fastqs(fastqs, Path('.'))
    for sample_name in fq_sample_names:
        process_cmds.append(ProcessCmd(' '.join([
            static_dragen_cmd,
            f'--output-file-prefix {sample_name}',
            f'--vc-sample-name {sample_name}',
            f'--fastq-list {fastq_list}',
            f'--fastq-list-sample-id {sample_name}',
            *cnv_args,
        ])))

    # BAM inputs: one command per file, prefixed by the BAM's stem.
    for bam in bam_batch:
        process_cmds.append(ProcessCmd(' '.join([
            static_dragen_cmd,
            f'--output-file-prefix {_bam_output_prefix(bam)}',
            f'--bam-input {bam}',
            *cnv_args,
        ])))
    return process_cmds


def execute_batch(
        ref_tar: str, static_args: List[str], fastqs: List[str], bam_batch: List[str], cnv_normals_tar: Optional[str]
) -> None:
    """Prepare reference (and optional CNV panel of normals), then run the batch.

    Args:
        ref_tar: Reference archive handed to ``create_setup_ref_cmds``.
        static_args: Extra DRAGEN arguments shared by every command.
        fastqs: FASTQ paths to process.
        bam_batch: BAM paths to process.
        cnv_normals_tar: Optional tar archive of CNV normals; when given it is
            extracted with ``tar -C`` into a fixed ephemeral directory.
    """
    setup_cmds: List[ProcessCmd] = create_setup_ref_cmds(ref_tar)

    normals_dir: Optional[str] = '/ephemeral/cnv-normals/' if cnv_normals_tar else None
    if normals_dir is not None:
        extract_cmd = f'tar -C {normals_dir} -xf {cnv_normals_tar}'
        setup_cmds.append(ProcessCmd(extract_cmd, normals_dir))

    # Setup must complete before the batch commands are built: create_batch_cmds
    # globs the extracted normals directory to list the files.
    runner.run_dragen_cmds(setup_cmds)

    batch_cmds = create_batch_cmds(create_dragen_cmd(static_args), fastqs, bam_batch, normals_dir)
    runner.run_dragen_cmds(batch_cmds)


def main():
    """Parse command-line options and run the batch.

    Options not recognised here are collected by ``parse_known_args`` and
    forwarded to ``execute_batch`` as the static DRAGEN arguments.
    """
    arg_parser = argparse.ArgumentParser()
    for list_flag in ('--fastqs', '--bam-batch'):
        arg_parser.add_argument(list_flag, nargs='*', default=[], type=str)
    arg_parser.add_argument('--cnv-normals-tar', type=str)
    arg_parser.add_argument('--ref-tar', type=str)
    arg_parser.add_argument('--debug', action='store_true')
    known, passthrough = arg_parser.parse_known_args()

    def _run_batch() -> None:
        execute_batch(known.ref_tar, passthrough, known.fastqs, known.bam_batch, known.cnv_normals_tar)

    runner.execute_with_no_traceback(_run_batch, known.debug)


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
