#!/usr/bin/python3

#===============================================================================
# Description
#===============================================================================

"""
Script to map reads to a reference genome in parallel using BWA mem, 
quality filter the bam files (q20) and add read groups.
"""

#===============================================================================
# Import modules
#===============================================================================

import argparse
import multiprocessing
import os
import shlex
import sys
from datetime import datetime

#===============================================================================
# Parse arguments
#===============================================================================

# Command line interface of the script
parser = argparse.ArgumentParser(
    description = 'Map paired end reads of all fastq files in parallel, quality filter bam files (q20) and add read groups.')

# Mandatory positional argument: the reference genome to map against
parser.add_argument('reference', help = 'Full path and name of the reference genome')

# Optional arguments
# -d: directory that is scanned for fastq files
parser.add_argument('-d', '--dir', type = str, default = './',
                    help = 'Directory containing fastq files (default current directory)')
# -s: only files whose name ends in this suffix are mapped
parser.add_argument('-s', '--suffix', type = str, default = '.fq',
                    help = 'Suffix of fastq files to take into account for mapping. Zipped fastq files are allowed. (default .fq)')
# -p: size of the multiprocessing pool
parser.add_argument('-p', '--processes', type = int, default = 4,
                    help = 'Define the number of parallel processes (default 4)')
# -l: file to which the per-sample BWA logs are appended
parser.add_argument('-l', '--log', type = str, default = "map_BWA_parallel.log",
                    help = 'Indicate the name of the log file (default map_BWA_parallel.log)')
# -o: directory for the mapped bam files; the filtered and read group bam
#     files go to directories named by appending '_q20' and '_q20_RG' to it
parser.add_argument('-o', '--out', type = str, default = './',
                    help = 'Directory for output bam files(default current directory)')

# Expose the parsed arguments as a plain dictionary keyed by option name
args = vars(parser.parse_args())

#===============================================================================
# Functions
#===============================================================================

def print_date ():
    """
    Write a timestamp banner to stderr: the current local date and time
    (to the minute) framed by two dashed lines and followed by an empty line.
    """
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M')
    banner = (
        '----------------\n',
        stamp + '\n',
        '----------------\n\n',
    )
    for part in banner:
        sys.stderr.write(part)

def commands_mapping (basename, dir = args['dir'], suffix = args['suffix'], reference = args['reference'], log = args['log'], out = args['out'], threads = 16):
    """
    Map the reads of one fastq file to the reference genome and produce a
    sorted and indexed bam file. All commands are run through the shell.

    Steps:
      1. bwa mem writes <out>/<basename>.BWA.sam, its stderr goes to
         <out>/<basename>.BWA.log
      2. samtools view | samtools sort writes <out>/<basename>.BWA.bam
      3. samtools index indexes that bam file
    The steps stop at the first command that exits with a non-zero status,
    because each step needs the output of the previous one; a warning is
    written to stderr in that case.
    Afterwards (always) the sample log is appended to the main log file and
    the sample log and the sam file are deleted.

    Arguments:
      basename   name of the fastq file without its suffix
      dir        directory containing the fastq file
      suffix     suffix of the fastq file
      reference  path of the reference genome given to bwa mem
      log        main log file to which the sample log is appended
      out        directory for the sam, bam and sample log files
      threads    number of threads given to bwa mem (-t), default 16
    """
    sys.stderr.write('*')

    # Quote everything that ends up in a shell command line: the file names
    # come from the file system and may contain spaces or shell metacharacters.
    # os.path.join makes a directory without trailing slash work as well.
    fastq = shlex.quote(os.path.join(dir, basename + suffix))
    prefix = os.path.join(out, basename)
    sam = shlex.quote(prefix + ".BWA.sam")
    bam = shlex.quote(prefix + ".BWA.bam")
    sample_log = shlex.quote(prefix + ".BWA.log")
    main_log = shlex.quote(log)

    steps = (
        # map reads to reference genome
        "bwa mem -M -t " + str(threads) + " " + shlex.quote(reference) + " " + fastq + " > " + sam + " 2>> " + sample_log,
        # convert sam to sorted bam
        "samtools view -bS " + sam + " | samtools sort -o " + bam,
        # index bam file
        "samtools index " + bam,
    )
    for cmd in steps:
        if os.system(cmd) != 0:
            # do not build on the output of a failed command
            sys.stderr.write('\nWarning: command failed for {}: {}\n'.format(basename, cmd))
            break

    # concat sample log to log file
    os.system("cat " + sample_log + " >> " + main_log)
    os.system("echo " + shlex.quote(basename + ".BWA.log finished") + ">> " + main_log)
    os.system("rm " + sample_log)

    # delete sam file (to save storage space)
    os.system("rm " + sam)

    return

def filter_bam_q20 (basename, out = args['out']):
    """
    Filter one bam file so that only aligned reads with a mapping quality
    score of 20 or higher are kept (this also filters for not uniquely mapped
    reads, q is 0 for not uniquely mapped reads), save the result to the
    folder <out>_q20 and index it.

    Arguments:
      basename   file name of the bam file in `out` (ending in .BWA.bam)
      out        directory containing the unfiltered bam file; the filtered
                 file is written to this path with '_q20' appended

    The index is only built when the filter command exits with status 0;
    otherwise a warning is written to stderr.
    """
    sys.stderr.write('*')

    sample = basename.replace('.BWA.bam','')
    # quote the paths, the file names may contain spaces or shell metacharacters
    bam_in = shlex.quote(out + "/" + basename)
    bam_q20 = shlex.quote(out + "_q20/" + sample + "_q20.BWA.bam")

    # filter bam file (all options are placed in front of the input file)
    cmd1 = "samtools view -q 20 -b -o " + bam_q20 + " " + bam_in
    if os.system(cmd1) != 0:
        sys.stderr.write('\nWarning: command failed for {}: {}\n'.format(basename, cmd1))
        return

    # index bam file
    cmd2 = "samtools index " + bam_q20
    os.system(cmd2)

    return

def add_RG (basename, out = args['out']):
    """
    Add a read group to one _q20.BWA.bam file with Picard
    AddOrReplaceReadGroups, save the result to the folder <out>_q20_RG
    and index it.

    Arguments:
      basename   file name of the unfiltered bam file (ending in .BWA.bam);
                 the names of the q20 and read group files are derived from it
      out        output directory of the mapping; input is read from this
                 path with '_q20' appended, output is written to this path
                 with '_q20_RG' appended

    The index is only built when Picard exits with status 0; otherwise a
    warning is written to stderr.
    """
    sys.stderr.write('*')

    sample = basename.replace('.BWA.bam','')
    q20_name = sample + "_q20.BWA.bam"
    # quote the paths and values, the file names may contain spaces or shell
    # metacharacters
    bam_q20 = shlex.quote(out + "_q20/" + q20_name)
    bam_rg = shlex.quote(out + "_q20_RG/" + sample + "_q20_RG.BWA.bam")

    # add readgroup
    # NOTE: an earlier version of this command called the executable
    # `picard-tools` instead of `PicardCommandLine`.
    # NOTE(review): RGPL=Hi-Seq is kept as it was; confirm that downstream
    # tools accept this platform value.
    cmd1 = ("PicardCommandLine AddOrReplaceReadGroups"
            + " INPUT=" + bam_q20
            + " OUTPUT=" + bam_rg
            + " RGID=2017_lane"
            + " RGSM=" + shlex.quote(q20_name)
            + " RGPL=Hi-Seq"
            + " RGLB=" + shlex.quote(q20_name + "_lib")
            + " RGPU=lane")
    if os.system(cmd1) != 0:
        sys.stderr.write('\nWarning: command failed for {}: {}\n'.format(basename, cmd1))
        return

    # index bam file
    cmd2 = "samtools index " + bam_rg
    os.system(cmd2)

    return

#===============================================================================
# Script
#===============================================================================

def list_basenames (directory, suffix):
    """
    Return the sorted names of all files in `directory` that end in `suffix`,
    with the suffix removed.
    """
    # len(f) - len(suffix) instead of a negative slice end: with an empty
    # suffix f[:-0] equals f[:0] and would reduce every name to ''
    return sorted(f[:len(f) - len(suffix)] for f in os.listdir(directory) if f.endswith(suffix))

if __name__ == '__main__':
    print_date()

    # create the mapping dir, the q20 filtered dir and the q20 filtered RG dir
    for directory in (args['out'], args['out'] + '_q20', args['out'] + '_q20_RG'):
        os.makedirs(directory, exist_ok = True)

    # build the list of mapping tasks (= all basenames of fastq files)
    fastq_tasks = list_basenames(args['dir'], args['suffix'])
    if not fastq_tasks:
        sys.stderr.write("* Warning: no files ending in '{}' found in {}\n".format(args['suffix'], args['dir']))

    # one pool of args['processes'] workers serves all three stages; p.map
    # blocks until a stage is complete, so the stages still run one after another
    with multiprocessing.Pool(args['processes']) as p:
        # map reads
        sys.stderr.write("* Started mapping of read files ...\n")
        p.map(commands_mapping, fastq_tasks)

        # build the list of bam tasks (= all file names of bam files); the
        # later stages write to other directories, so both can use this list
        bam_tasks = sorted(f for f in os.listdir(args['out']) if f.endswith('.BWA.bam'))

        # filter bam files
        sys.stderr.write("* Started filtering bam files with mapping quality score of 20 or higher...\n")
        p.map(filter_bam_q20, bam_tasks)

        # add readgroup
        sys.stderr.write("* Started adding readgroups to bam files with mapping quality score of 20 or higher...\n")
        p.map(add_RG, bam_tasks)

    sys.stderr.write('* Finished\n\n')
    print_date()