#!/usr/bin/python3

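#===============================================================================
# Description
#===============================================================================
#
# Assign reads that were mapped separately onto two subgenome references (A and B)
# to one of the subgenomes. Every BAM file in the directory of reference A is
# processed together with the BAM file of the same name in the directory of
# reference B. Reads are assigned to the reference on which more of them are
# mapped or, when both references map them equally, to the reference with the
# lowest weighted number of mismatches. Assigned and unassigned reads are written
# to separate BAM files, and a table with the read counts per criterion is written
# to the output directory.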

# Yves BAWIN July 2018

#===============================================================================
# Import modules
#===============================================================================

import os, sys, argparse, multiprocessing, datetime

import pysam        # make sure this is version 1.11
def _release_tuple(version_string):
    """
    Return the leading numeric components of a version string as a (major, minor, patch) tuple.

    Non-numeric suffixes are ignored (e.g. '0.15.0rc1' -> (0, 15, 0)) and missing
    components count as 0 (e.g. '0.12' -> (0, 12, 0)).
    """
    numbers = []
    for part in version_string.split('.')[:3]:
        digits = ''
        for char in part:
            if not char.isdigit():
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    while len(numbers) < 3:
        numbers.append(0)
    return tuple(numbers)


# The script requires pysam 0.11.1 or later. Compare the release numerically so that
# 0.11.1 itself and later major versions are accepted.
version = pysam.__version__
if _release_tuple(version) < (0, 11, 1):
    sys.exit('Please make sure Pysam version 0.11.1 (or later) is installed (current version = {})\nRun sudo pip3 install pysam==0.11.1 to fix this\n'.format(version))
import numpy as np
import pandas as pd
#import plotnine as p9

#===============================================================================
# Parse arguments
#===============================================================================

# Create an ArgumentParser object.
parser = argparse.ArgumentParser(description = 'Create separate BAM files (one for each subgenome) with the alignments that were uniquely assigned to the corresponding subgenome, and Ref0 BAM files with the reads that could not be assigned.')

# Mandatory arguments: the directories containing the BAM files of all reads mapped onto
# reference genome A and reference genome B. Every processing step builds file paths from
# these directories, so the script cannot run without them.
parser.add_argument('-d_A', '--dir_A',
                    required = True,
                    type = str,
                    help = 'Directory containing unfiltered bam files of mappings using subgenome A')

parser.add_argument('-d_B', '--dir_B',
                    required = True,
                    type = str,
                    help = 'Directory containing unfiltered bam files of mappings using subgenome B')

# Optional arguments: output directory, pair-aware categorisation, number of parallel
# processes, suffixes of the output bam files, merging of the output bam files, and
# plotting (pdf/png).
parser.add_argument('-d', '--dir',
                    default = '.',
                    type = str,
                    help = 'Output directory (default current directory)')
parser.add_argument('--pair_aware',
                    dest = 'pairs',
                    action = 'store_true',
                    help = 'Categorise consecutive reads with the same name (the mates of a read pair) together instead of each read separately (default False).')
parser.add_argument('-p', '--processes',
                    default = 4,
                    type = int,
                    help = 'Define the number of parallel processes (default 4)')
parser.add_argument('-a', '--suffix_a',
                    default = 'RefA',
                    type = str,
                    help = 'Suffix that is used to name the bam file with reads assigned to subgenome A (default RefA)')
parser.add_argument('-b', '--suffix_b',
                    default = 'RefB',
                    type = str,
                    help = 'Suffix that is used to name the bam file with reads assigned to subgenome B (default RefB)')
parser.add_argument('--merge_bam',
                    dest = 'merge_bam_files',
                    action = 'store_true',
                    help = 'Use this option if you want to merge the categorised bam files.')
parser.add_argument('--plot',
                    dest = 'plot',
                    action = 'store_true',
                    help = 'Use this option to generate plots of the number of reads (unmapped reads, uniquely mapped reads, reads designated assigned on MAPQ score, reads assigned on alignment length, reads assigned on number of mismatches, unassigned reads) (default False).')
parser.add_argument('--pdf',
                    dest = 'plot_type',
                    action = 'store_true',
                    help = 'Use this option if you want to plot pdf graphs (default).')
parser.add_argument('--png',
                    dest = 'plot_type',
                    action = 'store_false',
                    help = 'Use this option if you want to plot png graphs.')
parser.set_defaults(plot = False, plot_type=True)

# Arguments for read assignment: the minimum mapping quality, the weights of clipped and
# indel nucleotides in the mismatch score, and the maximum number of mismatches.
parser.add_argument('-mmq', '--min_mapping_quality',
                    default = 20,
                    type = int,
                    help = 'Minimum mapping quality (MAPQ) for a read to be written to the bam file of the subgenome it is assigned to; reads with a lower MAPQ are written to the Ref0 bam file of that subgenome (default = 20).')
parser.add_argument('--clipping_weight',
                    default = 1,
                    type = float,
                    help = 'Weight of soft- and hard-clipped nucleotides (default = 1).')
parser.add_argument('--indel_weight',
                    default = 1,
                    type = float,
                    help = 'Weight of indel mutations (default = 1).')
parser.add_argument('-m', '--maximum_mismatches',
                    default = 100,
                    type = int,
                    help = 'Maximum number of allowed mismatches (default = 100).')

# Arguments for deleting intermediate files and for adding read groups (for GATK SNP
# calling) to the categorised bam files.
parser.add_argument('--DeleteIntermediateFiles',
                    dest = 'DeleteIntermediateFiles',
                    action = 'store_true',
                    help = 'Boolean for deleting the bam files without secondary alignment and the sorted bam files that are created by the script(default False)')
parser.set_defaults(DeleteIntermediateFiles=False)

parser.add_argument('--Read_groups',
                    dest = 'Read_groups',
                    action = 'store_true',
                    help = 'Add read groups to the bam files after the read categorisation and index the bam files')
parser.set_defaults(Read_groups=False)

# Parse arguments to a dictionary.
args = vars(parser.parse_args())

#===============================================================================
# Functions
#===============================================================================

def print_date():
    """
    Write a timestamp banner to stderr.

    The banner is a dashed line, the current local time as 'YYYY-mm-dd HH:MM:SS',
    another dashed line and an empty line.
    """
    rule = '-------------------'
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    for piece in (rule + '\n', timestamp + '\n', rule + '\n\n'):
        sys.stderr.write(piece)


def mismatches(read, dir_A = args['dir_A'], dir_B = args['dir_B'], clipping_weight = args['clipping_weight'], indel_weight = args['indel_weight']):
    """
    Return a weighted mismatch score for one alignment.

    For a read with CIGAR operations the score is the sum of
      * inserted (I) and deleted (D) nucleotides, each multiplied by indel_weight,
      * soft-clipped (S) and hard-clipped (H) nucleotides, each multiplied by clipping_weight,
      * the number of substitutions recorded in the MD tag (weight 1).
    A read without CIGAR operations (e.g. an unmapped read) scores its query length.

    dir_A and dir_B are not used; they are kept so that existing calls remain valid.

    Raises KeyError when a read with CIGAR operations has no MD tag.
    """
    cigar = read.cigartuples
    if not cigar:
        return read.query_length

    # pysam CIGAR operation codes: 1 = insertion, 2 = deletion, 4 = soft clip, 5 = hard clip.
    # Other operations do not contribute to the score.
    weights = {1: indel_weight, 2: indel_weight, 4: clipping_weight, 5: clipping_weight}
    mismatch = 0
    for operation, length in cigar:
        if operation in weights:
            mismatch += int(length) * weights[operation]

    if not read.has_tag('MD'):
        raise KeyError('Read {} has no MD tag; the MD tag is needed to count substitutions (it can be added with samtools calmd)'.format(read.query_name))
    md_tag = read.get_tag('MD')
    # Count each letter directly preceded by a digit. This counts the substituted bases;
    # deleted bases follow a '^' and are therefore skipped.
    for previous, current in zip(md_tag, md_tag[1:]):
        if previous.isnumeric() and current.isalpha():
            mismatch += 1
    return mismatch


def count_mapped_reads(reads_A, reads_B):
    """
    Count the mapped reads in two parallel read collections.

    The collections are walked in lockstep, so only the first
    min(len(reads_A), len(reads_B)) entries of each are considered.
    Returns a tuple (mapped reads in reads_A, mapped reads in reads_B).
    """
    paired = list(zip(reads_A, reads_B))
    total_A = sum(1 for read, _ in paired if not read.is_unmapped)
    total_B = sum(1 for _, read in paired if not read.is_unmapped)
    return total_A, total_B


def _prepare_name_sorted_bam(directory, name):
    """Remove secondary/supplementary alignments (flag 0x900) from directory/name.bam and sort by read name; return the sorted path."""
    source = '{}/{}.bam'.format(directory, name)
    without_secondary = '{}/woSecAl/{}_woSecAl.bam'.format(directory, name)
    name_sorted = '{}/sorted/{}_sorted.bam'.format(directory, name)
    # Intermediate files from an earlier run are reused.
    if not os.path.isfile(without_secondary):
        pysam.view('-h', '-F 0x900', '-o', without_secondary, source, catch_stdout=False)
    if not os.path.isfile(name_sorted):
        pysam.sort('-n', '-o', name_sorted, without_secondary)
    return name_sorted


def _open_quietly(path):
    """Open a BAM file for reading with pysam verbosity temporarily set to 0."""
    save = pysam.set_verbosity(0)
    try:
        return pysam.AlignmentFile(path, 'rb')
    finally:
        pysam.set_verbosity(save)


def _paired_reads(iter_A, iter_B, name):
    """Yield (read_A, read_B) pairs, checking that both files list the same reads in the same order."""
    import itertools
    for read_A, read_B in itertools.zip_longest(iter_A, iter_B):
        if read_A is None or read_B is None:
            raise ValueError('The BAM files of sample {} for reference A and reference B do not contain the same number of reads'.format(name))
        if read_A.query_name != read_B.query_name:
            raise ValueError('The reads from sample {} in BAM file from reference A are not ordered in the same way than the reads in the BAM file from reference B'.format(name))
        yield read_A, read_B


def _read_groups(read_pairs, pair_aware):
    """Group (read_A, read_B) pairs into (reads_A, reads_B) lists that are categorised together.

    Without pair_aware every read forms its own group. With pair_aware, consecutive reads with
    the same query name form one group; a read without a mate forms a group on its own.
    The last group is always yielded.
    """
    group_A, group_B = [], []
    for read_A, read_B in read_pairs:
        same_template = pair_aware and group_A and group_A[-1].query_name == read_A.query_name
        if group_A and not same_template:
            yield group_A, group_B
            group_A, group_B = [], []
        group_A.append(read_A)
        group_B.append(read_B)
    if group_A:
        yield group_A, group_B


def _write_by_mapq(reads, reference, outputs, counts, criterion, min_mapq):
    """Write reads with MAPQ >= min_mapq to the reference's file (counted under criterion), others to its Ref0 file."""
    for read in reads:
        if int(read.mapping_quality) >= min_mapq:
            outputs[reference].write(read)
            counts['{}_{}'.format(criterion, reference)] += 1
        else:
            outputs[reference + '0'].write(read)
            counts['low_mapq_' + reference] += 1


def _write_to_ref0(reads, reference, outputs, counts, key):
    """Write all reads to the Ref0 file of reference, counting each read under key."""
    for read in reads:
        outputs[reference + '0'].write(read)
        counts[key] += 1


def _write_both_to_ref0(reads_A, reads_B, outputs, counts, key):
    """Write the reads of both references to their Ref0 files, counting each (read_A, read_B) pair under key."""
    for read_A, read_B in zip(reads_A, reads_B):
        outputs['A0'].write(read_A)
        outputs['B0'].write(read_B)
        counts[key] += 1


def _categorise_group(reads_A, reads_B, outputs, counts, min_mapq, max_mismatches):
    """Assign one read group to reference A, reference B or neither, writing the reads and updating counts."""
    mapped_A, mapped_B = count_mapped_reads(reads_A, reads_B)
    if mapped_A > mapped_B:
        _write_by_mapq(reads_A, 'A', outputs, counts, 'unique', min_mapq)
    elif mapped_B > mapped_A:
        _write_by_mapq(reads_B, 'B', outputs, counts, 'unique', min_mapq)
    elif mapped_A == 0:
        _write_both_to_ref0(reads_A, reads_B, outputs, counts, 'unmapped')
    else:
        mismatches_A = sum(mismatches(read) for read in reads_A)
        mismatches_B = sum(mismatches(read) for read in reads_B)
        if mismatches_A == mismatches_B:
            # Neither reference fits better: the read (pair) cannot be assigned.
            _write_both_to_ref0(reads_A, reads_B, outputs, counts, 'unassigned')
            return
        if mismatches_A < mismatches_B:
            winner, reads, score = 'A', reads_A, mismatches_A
        else:
            winner, reads, score = 'B', reads_B, mismatches_B
        if score < max_mismatches:
            _write_by_mapq(reads, winner, outputs, counts, 'least_mismatches', min_mapq)
        else:
            _write_to_ref0(reads, winner, outputs, counts, 'max_mismatches_' + winner)


def categorise(bam, dir = args['dir'], dir_A = args['dir_A'], dir_B = args['dir_B'], RefA = args['suffix_a'], RefB = args['suffix_b'], pair_aware = args['pairs'], min_mapq = args['min_mapping_quality'], plot = args['plot'], plot_type = args['plot_type'], merge_bam = args['merge_bam_files'], max_mismatches = args['maximum_mismatches']):
    """
    Categorise the reads of one sample mapped onto reference A and reference B.

    bam is the file name of a BAM file in dir_A; a BAM file with the same name must exist
    in dir_B. In both directories, secondary and supplementary alignments (flag 0x900) are
    removed and the reads are sorted by name. The intermediate files are written to the
    woSecAl/ and sorted/ subdirectories, which must already exist; existing intermediate
    files are reused.

    Reads are evaluated one at a time, or, with pair_aware, per group of consecutive reads
    with the same name (the mates of a pair):
      * more reads of the group are mapped on one reference -> that reference ('Uniquely_mapped');
      * no read is mapped on either reference -> Ref0 files of both references ('Unmapped');
      * otherwise the reference with the lower total of mismatches() wins ('Least_mismatches'),
        unless that total is >= max_mismatches ('Max_mismatches', Ref0 file of the winner);
        equal totals -> Ref0 files of both references ('Unassigned').
    Reads assigned to a reference but with a MAPQ below min_mapq go to that reference's Ref0
    file ('Low_mapping_quality'). When a group is assigned to one reference, the reads of the
    other reference are not written.

    Files written to dir: <name>.<RefA>.bam, <name>.<RefB>.bam, <name>.Ref0.<RefA>.bam and
    <name>.Ref0.<RefB>.bam. With merge_bam, <name>_categorised.bam and <name>.Ref0.bam are
    written as well. plot and plot_type are not used here.

    Returns a pandas DataFrame with columns Sample_ID, Reference, Criterion and Count.
    Raises ValueError when the two BAM files do not list the same reads in the same order,
    or when not every read was processed.
    """
    import contextlib

    # Extract the sample name and report it, to give an on-screen overview of the progress.
    name = os.path.splitext(bam)[0]
    sys.stderr.write('* Read categorisation of sample {} \n'.format(name))

    sorted_A = _prepare_name_sorted_bam(dir_A, name)
    sorted_B = _prepare_name_sorted_bam(dir_B, name)

    counts = dict.fromkeys(['unique_A', 'unique_B', 'least_mismatches_A', 'least_mismatches_B',
                            'max_mismatches_A', 'max_mismatches_B', 'low_mapq_A', 'low_mapq_B',
                            'unmapped', 'unassigned'], 0)
    count = 0

    # ExitStack closes every input and output BAM file, also on error, before any merging.
    with contextlib.ExitStack() as stack:
        bamfile_A = stack.enter_context(_open_quietly(sorted_A))
        bamfile_B = stack.enter_context(_open_quietly(sorted_B))
        outputs = {}
        for key, file_name, template in (('A', '{}.{}.bam'.format(name, RefA), bamfile_A),
                                         ('B', '{}.{}.bam'.format(name, RefB), bamfile_B),
                                         ('A0', '{}.Ref0.{}.bam'.format(name, RefA), bamfile_A),
                                         ('B0', '{}.Ref0.{}.bam'.format(name, RefB), bamfile_B)):
            outputs[key] = stack.enter_context(pysam.AlignmentFile('{}/{}'.format(dir, file_name), 'wb', template = template))

        read_pairs = _paired_reads(bamfile_A.fetch(until_eof=True), bamfile_B.fetch(until_eof=True), name)
        for reads_A, reads_B in _read_groups(read_pairs, pair_aware):
            count += len(reads_A)
            _categorise_group(reads_A, reads_B, outputs, counts, min_mapq, max_mismatches)

    # Merge BAM files if preferred.
    if merge_bam:
        pysam.merge('-f', '-n', '{}/{}_categorised.bam'.format(dir, name), '{}/{}.{}.bam'.format(dir, name, RefA), '{}/{}.{}.bam'.format(dir, name, RefB))
        pysam.merge('-f', '-n', '{}/{}.Ref0.bam'.format(dir, name), '{}/{}.Ref0.{}.bam'.format(dir, name, RefA), '{}/{}.Ref0.{}.bam'.format(dir, name, RefB))

    # Check whether the number of processed reads equals the number of reads in the sorted BAM file.
    with _open_quietly(sorted_A) as check_A:
        Count_A = check_A.count(until_eof=True)
    if Count_A != count:
        raise ValueError('Number of processed reads from sample {} does not equal the number of reads in the original BAM files: {} vs {}'.format(name, count, Count_A))

    # Combine all counts of the considered sample into one dataframe.
    colnames = ['Sample_ID', 'Reference', 'Criterion', 'Count']
    rows = []
    for criterion, key in (('Uniquely_mapped', 'unique'), ('Least_mismatches', 'least_mismatches'),
                           ('Max_mismatches', 'max_mismatches'), ('Low_mapping_quality', 'low_mapq')):
        rows.append([name, RefA, criterion, counts[key + '_A']])
        rows.append([name, RefB, criterion, counts[key + '_B']])
    rows.append([name, 'Both', 'Unmapped', counts['unmapped']])
    rows.append([name, 'Both', 'Unassigned', counts['unassigned']])
    return pd.DataFrame(rows, columns = colnames)


def DeleteIntermediate(dir_A = args['dir_A'], dir_B = args['dir_B'], dir = args['dir']):
    '''
    Delete the intermediate bam files created during the read categorisation and their directories.

    For dir_A and dir_B, every *_woSecAl.bam file in the woSecAl/ subdirectory and every
    *_sorted.bam file in the sorted/ subdirectory is removed, followed by the subdirectory
    itself. Missing subdirectories are skipped (for example when dir_A and dir_B are the same
    directory). A subdirectory that cannot be removed (e.g. because it still contains other
    files) is left in place and reported on stderr.
    The dir parameter is not used; it is kept so that existing calls remain valid.
    '''
    for directory in (dir_A, dir_B):
        for subdir, suffix in (('woSecAl', '_woSecAl.bam'), ('sorted', '_sorted.bam')):
            path = '{}/{}'.format(directory, subdir)
            if not os.path.isdir(path):
                continue
            # Scan the subdirectory itself rather than deriving names from the bam files in
            # directory, which may also hold output bam files without intermediates.
            for file_name in os.listdir(path):
                if file_name.endswith(suffix):
                    os.remove('{}/{}'.format(path, file_name))
            try:
                os.rmdir(path)
            except OSError as error:
                sys.stderr.write('* Could not remove directory {}: {}\n'.format(path, error))
    return


def add_RG(sample_name, dir = args['dir']):
    """
    Add a read group to a bam file in dir, write the result to dir/RG/ and index it.

    dir/<sample_name> is coordinate-sorted into dir/<sample>_sorted.bam. Picard
    AddOrReplaceReadGroups then writes dir/RG/<sample>_RG.bam, which is indexed with pysam.
    Raises subprocess.CalledProcessError when Picard exits with a non-zero status.
    """
    import subprocess

    sorted_bam = dir + '/' + sample_name.replace('.bam', '_sorted.bam')
    rg_bam = dir + '/RG/' + sample_name.replace('.bam', '_RG.bam')
    pysam.sort('-o', sorted_bam, dir + '/' + sample_name)
    # Pass the arguments as a list (no shell), so paths with spaces or shell characters are safe.
    # Each Picard option must be a single KEY=VALUE token.
    subprocess.run(['PicardCommandLine', 'AddOrReplaceReadGroups',
                    'INPUT=' + sorted_bam,
                    'OUTPUT=' + rg_bam,
                    'RGID=2017_lane',
                    'RGSM=' + sample_name.replace('.bam', '_RG.bam'),
                    'RGPL=Hi-Seq',
                    'RGLB=' + sample_name.replace('.bam', '_RG.bam_lib'),
                    'RGPU=lane'],
                   check=True)
    pysam.index(rg_bam)
    return

#===============================================================================
# Script
#===============================================================================

if __name__ == '__main__':

    print_date()

    # Create the directories for the intermediate files (bam files without secondary
    # alignments, and bam files sorted by read name) and the output directory.
    for directory in (args['dir_A'], args['dir_B']):
        for subdir in ('woSecAl', 'sorted'):
            os.makedirs(directory + '/' + subdir, exist_ok=True)
    os.makedirs(args['dir'], exist_ok=True)

    # Categorise the reads of all bam files in the given directory.
    bam_files_A = [f for f in os.listdir(args['dir_A']) if f.endswith('.bam')]
    if not bam_files_A:
        sys.exit('No .bam files found in {}\n'.format(args['dir_A']))
    with multiprocessing.Pool(args['processes']) as p:
        results = p.map(categorise, bam_files_A)

    # Concatenate the results of all files and convert the results into a pivot table.
    Output = pd.concat(results, ignore_index = True)
    Log = Output.pivot_table(index=['Criterion', 'Reference'], columns='Sample_ID')
    Log.columns.name = None
    Log.index.name = None
    Log.to_csv('{}/Results_pivot_ReadCategorisation.txt'.format(args['dir']), sep = '\t')

    # Optional: delete intermediate files.
    if args['DeleteIntermediateFiles']:
        sys.stderr.write('* Deleting intermediate bam files ...\n')
        DeleteIntermediate()

    # Optional: make plot. This script does not define a plot function (the plotnine
    # import is commented out), so look it up instead of calling it directly; a direct
    # call would raise NameError after all the processing is done.
    if args['plot']:
        plot_function = globals().get('plot')
        if callable(plot_function):
            sys.stderr.write('* Creating a stacked bar graph with the results of the read categorisation ...\n')
            plot_function()
        else:
            sys.stderr.write('* Plotting is not available in this script; ignoring --plot\n')

    # Optional: add read groups to the merged (*_categorised.bam) files.
    if args['Read_groups']:
        os.makedirs(args['dir'] + '/RG', exist_ok=True)
        samples = [f for f in os.listdir(args['dir']) if f.endswith('_categorised.bam')]
        if not samples:
            sys.stderr.write('* No *_categorised.bam files found in {}; read groups are only added to merged bam files (use --merge_bam)\n'.format(args['dir']))
        else:
            with multiprocessing.Pool(args['processes']) as p:
                sys.stderr.write("* Adding readgroups to the categorised bam files ...\n")
                p.map(add_RG, samples)

    sys.stderr.write('* Finished!\n')

    print_date()
