#!/usr/bin/env python3
#
# Copyright 2021-2024, Julian Catchen <jcatchen@illinois.edu>
#
# This file is part of Stacks.
#
# Stacks is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Stacks is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Stacks.  If not, see <http://www.gnu.org/licenses/>.
#

import argparse
import sys
import os
import math
from datetime import datetime
from enum import Enum
from operator import itemgetter, attrgetter

#
# Constant values
#
# Exact number of tab-separated columns every sumstats record must have;
# parse_sumstats_file() aborts on any other count.
SUMSTATS_COLS = 21
# Zero-based index of the first per-sample genotype column in a VCF record.
VCF_SAMP_STRT = 9

#
# Global configuration variables.
#
# Input/output locations; filled in by parse_command_line().
sumstats_path = ""
vcf_path      = ""
popmap_path   = ""
# Set to True by parse_command_line() when --popmap is supplied.
popmap_used   = False
out_path      = ""
# Switched to False by parse_sumstats_file() when the first locus has no
# chromosome or sits on chromosome "un".
ref_aligned   = True
# Chromosomes shorter than this are skipped by write_private_allele_summary();
# parse_sumstats_file() lowers it to 1 for data that are not reference aligned.
ctg_limit     = 500000
# Optional outputs requested on the command line.
polarize      = False
plots         = False
smooth        = False
# Kernel width used by smooth_chromosome().
sigma         = 150000.0
verbose       = False

class Strand(Enum):
    """Strand of a locus, as given by the '+'/'-' suffix of a VCF ID field."""

    # Anything other than '-' is treated as the plus strand by the parsers.
    plus = 0
    minus = 1

class SmoothStat:
    """A single statistic located on a chromosome, plus a slot for its smoothed value."""

    def __init__(self, loc_id, pop, col, chr, bp, allele_cnt, stat):
        # Which locus / population / SNP column the value belongs to.
        self.id, self.pop, self.col = loc_id, pop, col
        # Genomic position of the value.
        self.chr, self.bp = chr, bp
        # Allele count and the raw statistic itself.
        self.cnt, self.stat = allele_cnt, stat
        # Filled in later by the smoothing code; starts at zero.
        self.smoothed = 0.0
        
class Pop:
    """Allele frequencies of one population at one SNP."""

    def __init__(self, pop, p, q, cnt, priv):
        self.pop = pop          # Population name.
        self.p_freq = p         # Frequency of the SNP's p allele.
        self.q_freq = q         # Frequency of the SNP's q allele.
        self.cnt = cnt          # Count the frequencies are based on.
        self.private = priv     # Does this population hold a private allele here?
        self.priv_all = None    # If there is a private allele, which one is it?

    def __str__(self):
        """Return a one-line description; the private allele is appended when flagged."""
        summary = "{} - p freq: {:.3f}; q freq: {:.3f}; sample cnt: {}; private? {}".format(
            self.pop, self.p_freq, self.q_freq, self.cnt, self.private)
        suffix = " [allele: '{}']".format(self.priv_all) if self.private == True else ""
        return summary + suffix

class SNP:
    """One variable site within a locus, with per-population frequency records."""

    def __init__(self):
        # Basepair position and column of the site within its locus.
        self.bp, self.col = 0, 0
        # The two alleles segregating at this site.
        self.p_allele = self.q_allele = ""
        # Includes parental populations and 'vcf' for all VCF samples.
        self.pops = dict()
        # Indexed by population name; includes individual populations from VCF.
        self.indvpops = dict()
        # priv_pop: a private allele exists in one or more populations at this site.
        # priv_hyb: that private allele is also present in the hybrids.
        self.priv_pop = self.priv_hyb = False

    def __str__(self):
        """Return a one-line description of the site."""
        template = "Column {}; P allele: '{}', Q allele '{}', private? {}"
        return template.format(self.col, self.p_allele, self.q_allele, self.priv_pop)

class Locus:
    """A locus: an identifier, an optional genomic position and its SNPs."""

    def __init__(self, id):
        self.id       = id
        self.chr      = None
        self.bp       = 0
        # Strand of the locus, if known. It must exist as an attribute because
        # __str__ reports it; previously it was never created, so printing any
        # locus that had a chromosome raised AttributeError.
        self.strand   = None
        self.snps     = {}    # Indexed by column.
        self.private  = False # At least one SNP at this locus contains a private allele.

    def __str__(self):
        """
        Return a multi-line description: one header line for the locus, one
        indented line per SNP and one further-indented line per population
        recorded at that SNP. The position is shown only when a chromosome is
        set, and the strand only when it is known.
        """
        header = "Locus {}".format(self.id)
        if self.chr is not None:
            if self.strand is None:
                header += " [{}:{}]".format(self.chr, self.bp)
            else:
                header += " [{}:{}:{}]".format(self.chr, self.bp, self.strand)
        header += "; {} snp(s)".format(len(self.snps))

        lines = [header]
        for snp in self.snps.values():
            lines.append("  " + str(snp))
            for pop in snp.pops.values():
                lines.append("    " + str(pop))

        return "\n".join(lines) + "\n"


def parse_command_line():
    """
    Parse sys.argv and store the result in the module-level configuration
    variables (sumstats_path, vcf_path, popmap_path, popmap_used, out_path,
    polarize, plots, smooth, sigma, verbose).

    Optional settings that are not supplied leave their global untouched.
    The process exits with a non-zero status (after printing a message and
    the usage to stderr) when a required path is empty or when --sigma is
    not a finite, positive number.
    """
    global sumstats_path
    global vcf_path
    global popmap_path, popmap_used
    global out_path
    global polarize
    global plots, smooth, sigma
    global verbose

    desc = '''Supply the populations.sumstats.tsv file containing the 'parental' populations \
    which were used to define private alleles. Supply a VCF file from the 'hybrid' \
    populations which may contain those private alleles. Optionally export a set of VCF \
    files, one per parental population, that list SNPs with the non-private allele as \
    'REF' and the private allele as 'ALT' for plotting.'''

    p = argparse.ArgumentParser(description=desc)

    #
    # Add options.
    #
    p.add_argument("--sumstats", type=str, metavar="path", required=True,
                   help="path to parental populations sumstats file.")
    p.add_argument("--vcf", type=str, metavar="path", required=True,
                   help="path to hybrids VCF file.")
    p.add_argument("--popmap", type=str, metavar="path",
                   help="path to a population map used to generate the VCF file.")
    p.add_argument("--polarize", action="store_true",
                   help="Export a set of VCF files, one per parent population, with private alleles polarized.")
    p.add_argument("--plots", action="store_true",
                   help="Write a file for plotting private allele frequencies along the genome.")
    p.add_argument("--smooth", action="store_true",
                   help="Smooth private allele values in plotting files, if data are referenced aligned.")
    # The help text states the real default (150000); it used to claim 75000.
    p.add_argument("--sigma", type=float, metavar="size",
                   help="Control the smoothing. Default 150000bp, increase to smooth more, decrease to smooth less.")
    p.add_argument("-o", "--out-path", type=str, metavar="prefix", required=True,
                   help="prefix path for writing output files.")
    p.add_argument("--verbose", action="store_true", dest="verbose",
                   help="Ouput a summary of each locus.")

    #
    # Parse the command line
    #
    args = p.parse_args()

    # Required options are always present once parse_args() has returned.
    sumstats_path = args.sumstats
    vcf_path      = args.vcf
    out_path      = args.out_path

    if args.popmap is not None:
        popmap_path = args.popmap
        popmap_used = True

    # store_true options are always plain booleans.
    verbose  = args.verbose
    polarize = args.polarize
    plots    = args.plots
    smooth   = args.smooth

    if args.sigma is not None:
        # The smoothing kernel divides by sigma squared and sizes its window
        # from sigma, so zero, negative, NaN and infinite values cannot work.
        if not math.isfinite(args.sigma) or args.sigma <= 0:
            p.error("--sigma must be a finite, positive number.")
        sigma = args.sigma

    #
    # argparse accepts an empty string for a required option; reject those here.
    #
    required = (
        (sumstats_path, "You must specify the path to the sumstats file."),
        (vcf_path,      "You must specify the path to the hybrids VCF file."),
        (out_path,      "You must specify a path prefix to output files."),
    )
    for value, message in required:
        if len(value) == 0:
            print(message, file=sys.stderr)
            p.print_help(sys.stderr)
            sys.exit(1)


def complement(nuc):
    """
    Return the complementary base of a single nucleotide, preserving case.
    Anything that is not one of A, C, G, T (upper or lower case) yields None.
    """
    pairs = {
        'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A',
        'a': 't', 'c': 'g', 'g': 'c', 't': 'a',
    }
    return pairs.get(nuc)


def parse_popmap_file(path, vcfpops, vcfpops_rev):
    """
    Read a tab-separated population map (sample name, population name) and
    fill the two caller-supplied dictionaries:

      vcfpops      population name -> list of its sample names
      vcfpops_rev  sample name     -> population name

    Blank lines and lines starting with '#' are ignored. A line with fewer
    than two columns is reported on stderr and terminates the program with a
    non-zero exit status.
    """
    with open(path, "r") as fh:
        for lineno, line in enumerate(fh, 1):
            # Drop the line ending, including the '\r' of Windows-style files,
            # so it does not end up inside the population name.
            line = line.rstrip("\r\n")

            if not line.strip() or line[0] == "#":
                continue

            parts = line.split("\t")
            if len(parts) < 2:
                print("Error parsing popmap '{}' at line {}".format(path, lineno), file=sys.stderr)
                sys.exit(1)

            sample, pop = parts[0], parts[1]

            # Every sample is recorded, not only the first one that
            # introduces a new population.
            vcfpops.setdefault(pop, []).append(sample)
            vcfpops_rev[sample] = pop

    print("  Found {} samples in {} populations.".format(len(vcfpops_rev), len(vcfpops)), file=sys.stderr)
    return


def parse_sumstats_file(path, loci, populations_key):
    """
    Read a sumstats file and populate the caller-supplied dictionaries.

      loci             locus ID -> Locus, each holding its SNPs and, per SNP,
                       one Pop record for every population listed in the file.
      populations_key  population name -> list of sample names, taken from
                       the '#' header lines that precede the '# Locus ID' line.

    After reading, the private allele of every population flagged as private
    is identified (Pop.priv_all). Finally the module-level settings are
    adjusted: if the first locus has no chromosome, or sits on chromosome
    'un', the data are treated as not reference aligned, ctg_limit is set to
    1, and plotting and smoothing are switched off with a warning.

    A record with the wrong number of columns is reported on stderr and
    terminates the program with a non-zero exit status.
    """
    global ref_aligned, ctg_limit, smooth, plots

    in_header = True

    with open(path, "r") as fh:
        for lineno, line in enumerate(fh, 1):
            # Strip '\r' as well so the last column compares cleanly on CRLF files.
            line = line.rstrip("\r\n")

            if not line:
                continue

            if line[0] == "#":
                #
                # Header lines list the samples of each population; the
                # '# Locus ID' column heading ends the header. Any later
                # comment line is skipped.
                #
                if in_header:
                    if line[0:10] == "# Locus ID":
                        in_header = False
                    else:
                        parts = line.split("\t")
                        if len(parts) >= 2:
                            populations_key[parts[0][2:]] = parts[1].split(",")
                continue

            # The first data line also ends the header, and is itself parsed.
            in_header = False

            #
            # Parse one SNP record.
            #
            parts = line.split("\t")

            if len(parts) != SUMSTATS_COLS:
                print("Error parsing '{}' on line {}, incorrect number of columns ('{}'), file not in correct format.".format(
                    path, lineno, len(parts)), file=sys.stderr)
                sys.exit(1)

            locus_id = int(parts[0])
            column   = int(parts[3])
            popul    = parts[4]
            p_allele = parts[5]
            q_allele = parts[6]
            # Column 7 is doubled to obtain an allele count; presumably it
            # holds the number of diploid individuals - confirm against the
            # sumstats format.
            all_cnt  = int(parts[7]) * 2
            p_freq   = float(parts[8])
            q_freq   = 1 - p_freq
            private  = parts[20] == "1"

            pop = Pop(popul, p_freq, q_freq, all_cnt, private)

            #
            # Have we already seen this locus and SNP?
            #
            if locus_id in loci:
                loc = loci[locus_id]
            else:
                loc = Locus(locus_id)
                loci[locus_id] = loc

            loc.chr = parts[1]
            loc.bp  = int(parts[2])

            if private:
                loc.private = True

            if column in loc.snps:
                snp = loc.snps[column]
            else:
                snp = SNP()
                snp.bp  = int(parts[2])
                snp.col = column
                loc.snps[column] = snp

            # A '-' means the allele is not given on this line; keep what we have.
            if p_allele != "-":
                snp.p_allele = p_allele
            if q_allele != "-":
                snp.q_allele = q_allele
            if private:
                snp.priv_pop = True

            snp.pops[popul] = pop

    snp_cnt = 0
    for loc in loci.values():
        snp_cnt += len(loc.snps)
        #
        # Find which allele, or alleles, was/were the private one.
        # In most cases, a single allele will be private between the parent populations.
        # However, there can be alternative, private alleles where each parent population
        # has a unique allele.
        #
        if loc.private == False:
            continue

        for snp in loc.snps.values():
            # Number of populations in which each allele occurs at all.
            alleles = {snp.p_allele: 0, snp.q_allele: 0}

            for pop in snp.pops.values():
                if pop.p_freq > 0:
                    alleles[snp.p_allele] += 1
                if pop.q_freq > 0:
                    alleles[snp.q_allele] += 1

            # An allele seen in exactly one population is that population's private allele.
            for pop in snp.pops.values():
                if pop.private == True:
                    if alleles[snp.p_allele] == 1 and pop.p_freq > 0:
                        pop.priv_all = snp.p_allele
                    elif alleles[snp.q_allele] == 1 and pop.q_freq > 0:
                        pop.priv_all = snp.q_allele

    print("  Read {} loci containing {} SNPs.".format(len(loci), snp_cnt), file=sys.stderr)

    # Nothing was read: there is no first locus to inspect.
    if not loci:
        return

    first = loci[next(iter(loci))]
    if first.chr == "un" or first.chr is None:
        ref_aligned = False
        ctg_limit   = 1
        if smooth == True or plots == True:
            smooth = False
            plots  = False
            print("Warning: unable to generate plots or smooth them without reference-aligned data.", file=sys.stderr)


def _tally_genotype(pop, genotype, p_allele):
    # Add the called alleles of one sample's genotype to the running counts held in pop.
    for allele in genotype:
        if allele is None:
            continue
        pop.cnt += 1
        if allele == p_allele:
            pop.p_freq += 1.0
        else:
            pop.q_freq += 1.0


def _counts_to_freqs(pop):
    # Turn the allele counts accumulated in p_freq/q_freq into frequencies.
    if pop.cnt > 0:
        pop.p_freq = pop.p_freq / pop.cnt
        pop.q_freq = pop.q_freq / pop.cnt


def parse_hybrid_vcf(path, loci, populations_key, vcfpops, vcfpops_rev, chrs_list):
    """
    Read the hybrid VCF file and attach its allele frequencies to the SNPs
    already stored in loci.

      path             VCF file to read.
      loci             locus ID -> Locus, as filled from the sumstats file.
      populations_key  receives the list of VCF sample names under key 'vcf'.
      vcfpops          population name -> samples (from the population map).
      vcfpops_rev      sample name -> population name (from the population map).
      chrs_list        receives one (name, length) tuple per '##contig' header
                       line that gives both an ID and a length; ('un', 1) is
                       appended when the data are not reference aligned.

    For every VCF record whose locus and column are known, a Pop named 'vcf'
    holding the frequencies over all samples is stored in snp.pops, and, when
    a population map is in use, one Pop per map population is stored in
    snp.indvpops. snp.priv_hyb is set when a parental private allele occurs in
    the VCF samples.

    Alleles of minus-strand loci are complemented before being compared with
    the sumstats alleles. A record whose alleles do not match the sumstats
    file is reported on stderr and terminates the program with status 1.
    """
    not_found = 0
    snp_cnt   = 0
    locus_cnt = set()

    lineno = 0
    header = []

    with open(path, "r") as fh:
        #
        # Parse the chromosome names and lengths from the VCF header, if available.
        #
        for line in fh:
            lineno += 1
            line    = line.rstrip("\r\n")

            #
            # Skip the file definition lines, but record the contig lengths.
            # The fields inside '<...>' are read as key=value pairs so that
            # extra fields, or a different field order, do not break parsing.
            #
            if line[0:8] == "##contig":
                body   = line[line.find("<") + 1:line.rfind(">")]
                fields = dict(f.split("=", 1) for f in body.split(",") if "=" in f)
                if "ID" in fields and "length" in fields:
                    chrs_list.append( (fields["ID"], int(fields["length"])) )

            if line[0:6] == "#CHROM":
                header = line.split("\t")
                break

        if ref_aligned == False:
            chrs_list.append( ("un", 1) )

        #
        # List of sample names, to be taken from the VCF header.
        #
        populations_key['vcf'] = []
        sample_cols = {}   # VCF column index -> population name
        pop_cols    = {}   # population name  -> VCF column indexes

        if popmap_used == True:
            for p in vcfpops:
                pop_cols[p] = []

        #
        # Parse the sample names from the VCF header and record which columns the samples from
        # each population occur in.
        #
        popmap_samples_found = 0
        for i in range(VCF_SAMP_STRT, len(header)):
            sample = header[i]
            populations_key['vcf'].append(sample)

            if popmap_used and sample in vcfpops_rev:
                popmap_samples_found += 1
                sample_cols[i] = vcfpops_rev[sample]
                pop_cols[vcfpops_rev[sample]].append(i)

        if popmap_used == True:
            print("  Found {} samples from population map in VCF file.".format(popmap_samples_found), file=sys.stderr)

        for line in fh:
            lineno += 1
            line    = line.rstrip("\r\n")

            if not line:
                continue

            parts = line.split("\t")

            #
            # If referenced aligned, column 3 contains the locus ID, column of the SNP within the
            # locus and the strand, e.g. '152:79:-'
            # Otherwise, it contains just the locus ID and column of the SNP.
            #
            if ref_aligned == True:
                locus_id, column, strand = parts[2].split(":")
            else:
                locus_id, column = parts[2].split(":")
                strand = "+"
            locus_id = int(locus_id)
            column   = int(column)
            strand   = Strand.minus if strand == "-" else Strand.plus

            if locus_id not in loci:
                not_found += 1
                continue
            if column not in loci[locus_id].snps:
                not_found += 1
                continue

            locus_cnt.add(locus_id)
            snp_cnt += 1

            snp = loci[locus_id].snps[column]

            vcfpop = Pop("vcf", 0.0, 0.0, 0, False)

            indv_pops = {}
            if popmap_used:
                for p in pop_cols:
                    indv_pops[p] = Pop(p, 0.0, 0.0, 0, False)

            ref_allele = parts[3]
            alt_allele = parts[4]

            if strand == Strand.minus:
                ref_allele = complement(ref_allele)
                alt_allele = complement(alt_allele)

            if ((ref_allele != snp.p_allele and ref_allele != snp.q_allele) or
                (alt_allele != snp.p_allele and alt_allele != snp.q_allele)):
                print("Error parsing VCF file on line {}; ref/alt alleles {}/{} do not match sumstats file ({}/{}).\n".format(
                    lineno, ref_allele, alt_allele, snp.p_allele, snp.q_allele), file=sys.stderr)
                sys.exit(1)

            #
            # Translate genotype strings into allele pairs. Any string that is
            # not listed (missing data, haploid or multi-allelic calls) counts
            # as missing, instead of silently reusing the previous sample's
            # genotype.
            #
            calls = {
                "0/0": (ref_allele, ref_allele),
                "0/1": (ref_allele, alt_allele),
                "1/0": (ref_allele, alt_allele),
                "1/1": (alt_allele, alt_allele),
            }
            missing = (None, None)

            for i in range(VCF_SAMP_STRT, len(parts)):
                gtype = parts[i].split(",")[0].split(":")[0]
                # Phase does not matter for allele counting.
                genotype = calls.get(gtype.replace("|", "/"), missing)

                _tally_genotype(vcfpop, genotype, snp.p_allele)

                if popmap_used:
                    # Samples absent from the population map only count
                    # towards the overall 'vcf' population.
                    pop_name = sample_cols.get(i)
                    if pop_name is not None:
                        _tally_genotype(indv_pops[pop_name], genotype, snp.p_allele)

            _counts_to_freqs(vcfpop)
            for indvpop in indv_pops.values():
                _counts_to_freqs(indvpop)

            #
            # Is a parental private allele present among the VCF samples?
            #
            if snp.priv_pop == True:
                for pop in snp.pops.values():
                    if pop.private == True:
                        if pop.priv_all == snp.p_allele and vcfpop.p_freq > 0:
                            snp.priv_hyb = True
                        elif pop.priv_all == snp.q_allele and vcfpop.q_freq > 0:
                            snp.priv_hyb = True

            snp.pops[vcfpop.pop] = vcfpop

            if popmap_used:
                for p in indv_pops:
                    snp.indvpops[p] = indv_pops[p]

    print("  Read {} loci, {} snps, from {} samples.".format(
        len(locus_cnt), snp_cnt, len(populations_key['vcf'])), file=sys.stderr)

                    
def order_loci(chrs_list, loci, ordered_loci):
    """
    Group the loci that carry a private allele by chromosome.

    ordered_loci is filled in place (chromosome name -> list of loci sorted by
    basepair); loci without a chromosome are left out. chrs_list, a list of
    (name, length) tuples, is sorted in place from longest to shortest.
    """
    #
    # Bin loci containing private alleles by chromosome.
    #
    for locus in loci.values():
        if locus.chr is None or locus.private == False:
            continue
        ordered_loci.setdefault(locus.chr, []).append(locus)

    #
    # Order each chromosome's loci by position.
    #
    by_position = attrgetter('bp')
    for members in ordered_loci.values():
        members.sort(key=by_position)

    #
    # Longest chromosomes first.
    #
    chrs_list.sort(key=itemgetter(1), reverse=True)
    

def _private_allele_row(loc, snp, snpcol, pop, vcfpops):
    # Format one output row for the private allele of 'pop' at 'snp', or
    # return None when the private allele matches neither allele of the SNP.
    if pop.priv_all == snp.p_allele:
        freq = "p_freq"
    elif pop.priv_all == snp.q_allele:
        freq = "q_freq"
    else:
        return None

    hyb         = snp.pops['vcf']
    parent_freq = getattr(pop, freq)
    hybrid_freq = getattr(hyb, freq)

    row = "{}\t{}\t{}\t{}\t{}\t{}\t{:.3f}\t{}\t{}\t{:.3f}\t{}\t{}".format(
        loc.chr, snp.bp, loc.id, snpcol,
        pop.priv_all, pop.pop, parent_freq, pop.cnt,
        round(parent_freq * pop.cnt),
        hybrid_freq, hyb.cnt,
        round(hybrid_freq * hyb.cnt))

    if popmap_used:
        for vp in vcfpops:
            if vp in snp.indvpops:
                indv      = snp.indvpops[vp]
                indv_freq = getattr(indv, freq)
                row += "\t{:.3f}\t{}\t{}".format(indv_freq, indv.cnt, round(indv_freq * indv.cnt))
            else:
                row += "\t{}\t{}\t{}".format('-', '-', '-')

    return row + "\n"


def write_private_allele_summary(chrs_list, ordered_loci, populations_key, vcfpops, out_path):
    """
    Write '<out_path>_private_alleles.tsv', one row per parental private
    allele that was also found in the hybrid/VCF samples, and print
    per-chromosome statistics to stderr.

      chrs_list        (name, length) tuples; chromosomes are visited in this order.
      ordered_loci     chromosome name -> list of loci on that chromosome.
      populations_key  population name -> samples; supplies the population names.
      vcfpops          population-map populations; adds three columns each
                       when a population map is in use.
      out_path         prefix of the output file.

    Chromosomes without loci, or shorter than ctg_limit, are skipped.
    Statistics are printed for the first three chromosomes only, unless
    verbose is set.
    """
    datestamp    = datetime.now().strftime("%Y%m%d")
    summary_path = out_path + "_private_alleles.tsv"

    #
    # Open the output file; the 'with' block closes it even if writing fails.
    #
    with open(summary_path, 'w') as out_fh:
        out_fh.write("# Generated by stacks-private-alleles, date {}\n# {}\n".format(datestamp, " ".join(sys.argv)))
        out_fh.write("# Chr\tBP\tLocus\tColumn\tPrivAllele\tParentPop\tParentFreq\tTotCnt\tPrivCnt\tVCF-Freq\tTotCnt\tPrivCnt")

        if popmap_used:
            for vp in vcfpops:
                out_fh.write("\t{}-Freq\tTotCnt\tPrivCnt".format(vp))
        out_fh.write("\n")
        print("\nWriting a summary of private alleles: \n    {}\n".format(summary_path), file=sys.stderr)

        tot_snps_per_pop = {}
        priv_pop = {}
        priv_hyb = {}
        chrno    = 0

        #
        # Iterate over chromosomes/scaffolds, longest to shortest.
        #
        for chrom, chrlen in chrs_list:
            if chrom not in ordered_loci:
                continue
            if chrlen < ctg_limit:
                continue

            chrno += 1
            report = chrno <= 3 or verbose == True
            if report:
                print("{}: {} ({:.2f}Mbp): {} private loci".format(
                    chrno, chrom, chrlen/1000000.0, len(ordered_loci[chrom])), file=sys.stderr)

            #
            # Initialize counters; they are kept per chromosome.
            #
            for p in populations_key:
                if p == "vcf":
                    continue
                priv_pop[p] = 0
                priv_hyb[p] = 0
                tot_snps_per_pop[p] = 0

            loc_cnt  = 0
            snp_cnt  = 0
            missing_hybrid_cnt = 0

            for loc in ordered_loci[chrom]:
                loc_cnt += 1
                snp_cnt += len(loc.snps)
                for snpcol in loc.snps:
                    snp = loc.snps[snpcol]

                    if "vcf" not in snp.pops:
                        missing_hybrid_cnt += 1

                    #
                    # Record total number of SNPs found in each parent population
                    # as well as the number of populations that have private alleles from those SNPs.
                    #
                    for p in snp.pops:
                        if p == "vcf":
                            continue

                        tot_snps_per_pop[p] += 1

                        if snp.pops[p].private == True:
                            priv_pop[p] += 1

                    #
                    # Was this private allele found in the hybrid/vcf population? If so, print it.
                    #
                    if snp.priv_hyb != True:
                        continue

                    for p in snp.pops:
                        pop = snp.pops[p]
                        if pop.private != True:
                            continue

                        priv_hyb[p] += 1

                        # No row is written when the private allele could not
                        # be identified; previously the row of an earlier
                        # population was written again (or the name was unbound).
                        row = _private_allele_row(loc, snp, snpcol, pop, vcfpops)
                        if row is not None:
                            out_fh.write(row)

            if report:
                print("{} loci, including {} SNPs, analyzed; in {} cases SNP present in hybrid/VCF population.".format(
                    loc_cnt, snp_cnt, snp_cnt - missing_hybrid_cnt), file=sys.stderr)

                for p in priv_pop:
                    total = tot_snps_per_pop[p]
                    # A population may have no SNPs on this chromosome; report 0% rather than divide by zero.
                    priv_pct = (priv_pop[p] / float(total) * 100) if total > 0 else 0.0
                    hyb_pct  = (priv_hyb[p] / float(total) * 100) if total > 0 else 0.0
                    print(" '{}' had {} SNPs present; private alleles occured within {} of those SNPs ({:.2f}%)".format(
                        p, total, priv_pop[p], priv_pct), file=sys.stderr)
                    print("    Private allele found in {} SNPs of hybrid/VCF population ({:.2f}%)".format(
                        priv_hyb[p], hyb_pct), file=sys.stderr)

                print(file=sys.stderr)

        print("Specify --verbose for details on more chromosomes.", file=sys.stderr)


def invert_refalt_fields(fields):
    """
    Return a copy of a split VCF record with REF and ALT exchanged.

    VCF is defined as: CHROM,POS,ID,REF,ALT,QUAL,FILTER,INFO,FORMAT,SAMPLEGENOTYPES...
    Sample fields starting with '0/0' or '1/1' have that genotype flipped to
    match the exchanged alleles; all other sample fields are copied unchanged.
    The input list is not modified.
    """
    flipped = {"0/0": "1/1", "1/1": "0/0"}

    # CHROM,POS,ID + (ALT -> REF, REF -> ALT) + QUAL,FILTER,INFO,FORMAT
    inverted = list(fields[0:3]) + [fields[4], fields[3]] + list(fields[5:9])

    for sample in fields[9:]:
        call = sample[0:3]
        if call in flipped:
            sample = flipped[call] + sample[3:]
        inverted.append(sample)

    return inverted


def write_polarized_vcfs(chrs_list, loci, populations_key, vcf_path, out_path):
    """
    Write one VCF file per parental population, named
    '<out_path>_<population>.polarized.snps.vcf'.

    The header of the input VCF is copied to every output file, with the
    '##fileDate' and '##source' lines replaced. A record is written to a
    population's file when that population holds a private allele at the SNP
    and the private allele was found in the hybrids (snp.priv_hyb). If the
    private allele is the REF allele, REF and ALT are exchanged so that the
    private allele of the focal population is always ALT.

      chrs_list        not used by this function.
      loci             locus ID -> Locus, annotated by the earlier parsing steps.
      populations_key  population name -> samples; every key except 'vcf'
                       gets an output file.
      vcf_path         the hybrid VCF file, read a second time here.
      out_path         prefix of the output files.
    """
    datestamp = datetime.now().strftime("%Y%m%d")

    #
    # Write out a VCF file for each 'parental' population
    #
    vcf_fhs = {}

    try:
        for pop in populations_key:
            if pop == 'vcf':
                continue
            path = out_path + "_" + pop + ".polarized.snps.vcf"
            print("   {}".format(path), file=sys.stderr)
            vcf_fhs[pop] = open(path, "w")

        #
        # Open the VCF input file again.
        #
        with open(vcf_path, "r") as in_fh:
            #
            # Write the VCF header
            #
            for line in in_fh:
                if line[0:2] == "##":
                    if line[0:10] == "##fileDate":
                        line = "##fileDate={}\n".format(datestamp)
                    if line[0:8] == "##source":
                        line = "##source={}\n".format("\"stacks-private-alleles\"")

                    for out_fh in vcf_fhs.values():
                        out_fh.write(line)

                if line[0:6] == "#CHROM":
                    # Write the final line containing the column headings.
                    for out_fh in vcf_fhs.values():
                        out_fh.write(line)
                    break

            for line in in_fh:
                line = line.rstrip("\r\n")
                if not line:
                    continue
                parts = line.split("\t")

                if ref_aligned == True:
                    locus_id, column, strand = parts[2].split(":")
                else:
                    locus_id, column = parts[2].split(":")
                    strand = "+"
                locus_id = int(locus_id)
                column   = int(column)
                strand   = Strand.minus if strand == "-" else Strand.plus

                if locus_id not in loci:
                    continue
                loc = loci[locus_id]

                if column not in loc.snps:
                    continue
                snp = loc.snps[column]

                if snp.priv_hyb == False:
                    continue

                #
                # Private alleles are stored in the orientation of the
                # sumstats file. For a minus-strand locus the VCF alleles are
                # the complement of those, so complement REF before comparing,
                # exactly as is done when the hybrid VCF is first parsed.
                #
                ref_all = parts[3]
                if strand == Strand.minus:
                    ref_all = complement(ref_all)

                for popname in vcf_fhs:
                    #
                    # Does this population possess a private allele at this SNP?
                    #
                    if popname not in snp.pops:
                        continue

                    pop = snp.pops[popname]
                    if pop.private != True:
                        continue

                    #
                    # If so, check if the private allele is the REF allele; if
                    # so, swap the REF/ALT alleles so that private alleles for this
                    # focal population are always ALT alleles.
                    #
                    if pop.priv_all == ref_all:
                        invparts = invert_refalt_fields(parts)
                        vcf_fhs[popname].write("\t".join(invparts) + "\n")
                    else:
                        vcf_fhs[popname].write("\t".join(parts) + "\n")
    finally:
        # Close every output file that was opened, even after an error.
        for out_fh in vcf_fhs.values():
            out_fh.close()

    return


def calc_weights(sigma):
    """
    Calculate weights for window smoothing operations.

    Returns a list w of length int(3 * sigma) + 1 where
        w[d] = exp(-d^2 / (2 * sigma^2))
    is the Gaussian weight of a site d basepairs from the window centre.
    Every distance from 0 up to and including 3 sigma has an entry.

    Raises ValueError if sigma is not greater than zero.
    """
    if not sigma > 0:
        raise ValueError("sigma must be greater than zero.")

    # Same window size as smooth_chromosome() uses, int(3 * sigma), so that
    # every distance inside the window has a weight, also for fractional sigma.
    limit = int(3 * sigma)
    denom = 2 * math.pow(sigma, 2)

    # The range includes 'limit' itself; a site exactly 3 sigma away must not
    # be left with a weight of zero.
    return [math.exp((-1 * math.pow(i, 2)) / denom) for i in range(0, limit + 1)]


def determine_window_limits(sites, limit, center_bp, pos_l, pos_u):
    """
    Advance the two indexes of a sliding window over sites (ordered by bp).

    pos_l is moved forward to the first site at or beyond center_bp - limit
    (never below 0); pos_u is moved forward to the first site at or beyond
    center_bp + limit. Both indexes only ever move forward and stop at
    len(sites). Returns the pair (pos_l, pos_u).
    """
    n_sites  = len(sites)
    lower_bp = max(center_bp - limit, 0)
    upper_bp = center_bp + limit

    while pos_l < n_sites and sites[pos_l].bp < lower_bp:
        pos_l += 1

    while pos_u < n_sites and sites[pos_u].bp < upper_bp:
        pos_u += 1

    return pos_l, pos_u


def smooth_chromosome(sites):
    """
    Compute a kernel-smoothed moving average of the statistic along one
    chromosome and store it in the 'smoothed' attribute of every site.

    sites must be ordered by basepair; every site provides bp, cnt and stat.

    For each genomic region centered on a site c, the contribution of the
    statistic at site p to the region average is weighted by the Gaussian
    function:
      exp( (-1 * (p - c)^2) / (2 * sigma^2))

    In addition, each contributing site p is weighted by (n_p - 1), where n_p
    is the number of alleles sampled at that site (p.cnt); sites with one
    allele or none contribute nothing.

    sigma is the module-level setting; for computational efficiency, the
    average is only calculated out to 3 sigma. If no site in the window
    carries any weight, the smoothed value falls back to the site's own
    unsmoothed statistic.
    """
    limit   = int(3 * sigma)
    weights = calc_weights(sigma)
    n_wts   = len(weights)

    pos_l = 0
    pos_u = 0

    for c in sites:
        pos_l, pos_u = determine_window_limits(sites, limit, c.bp, pos_l, pos_u)

        weighted_sum = 0.0
        weight_total = 0.0

        for pos_p in range(pos_l, pos_u):
            p    = sites[pos_p]
            dist = abs(p.bp - c.bp)

            # Guard against a distance that has no entry in the weight table.
            if dist >= n_wts:
                continue

            # Weight by the allele count of the contributing site p. Using
            # the centre site's count here would cancel out in the division
            # below and have no effect at all.
            final_weight  = max(p.cnt - 1.0, 0.0) * weights[dist]
            weighted_sum += p.stat * final_weight
            weight_total += final_weight

        # Assign rather than accumulate, so a repeated call gives the same result.
        if weight_total > 0:
            c.smoothed = weighted_sum / weight_total
        else:
            c.smoothed = c.stat

    return


def write_plot_values(fh, pop_1, pop_2):
    """
    Merge two lists of sites, each ordered by basepair, into fh.

    One tab-separated line is written per site: chromosome, basepair,
    population and statistic, followed by the smoothed value when smoothing
    is enabled. Values of pop_2 are written negated. Where both lists hold a
    site at the same basepair, the pop_1 site is written first.
    """
    def emit(site, negate):
        # Write a single site, flipping the sign of its values if requested.
        stat     = (-1 * site.stat) if negate else site.stat
        if smooth:
            smoothed = (-1 * site.smoothed) if negate else site.smoothed
            row = "{}\t{}\t{}\t{:.3f}\t{:.5f}".format(site.chr, site.bp, site.pop, stat, smoothed)
        else:
            row = "{}\t{}\t{}\t{:.3f}".format(site.chr, site.bp, site.pop, stat)
        fh.write(row + "\n")

    i, j   = 0, 0
    n1, n2 = len(pop_1), len(pop_2)

    while i < n1 or j < n2:
        # pop_1 sites up to and including the next pop_2 position; all that
        # remain once pop_2 is exhausted.
        while i < n1 and (j >= n2 or pop_1[i].bp <= pop_2[j].bp):
            emit(pop_1[i], False)
            i += 1

        # pop_2 sites strictly before the next pop_1 position; all that
        # remain once pop_1 is exhausted.
        while j < n2 and (i >= n1 or pop_2[j].bp < pop_1[i].bp):
            emit(pop_2[j], True)
            j += 1
                
    
def _collect_plot_stats(chr_loci, plot_pops, vcfpops):
    """Gather the values to plot for the loci of a single chromosome.

    For every SNP flagged as `priv_hyb`, and every population in `snp.pops`
    flagged as `private`, one SmoothStat is recorded holding the frequency of
    the private allele in the 'vcf' population (`p_freq` when the private
    allele is the SNP's p allele, `q_freq` when it is the q allele). Only the
    two populations named by `plot_pops[0]` and `plot_pops[1]` are recorded.
    When a population map is in use, a matching SmoothStat is also recorded
    for each population in `vcfpops` that has a non-zero count at the SNP.

    Returns a tuple (pop_1, pop_2, pop_1_vcfpops, pop_2_vcfpops); the lists
    are sorted by basepair position, and the two dictionaries (keyed by
    population map name) are empty when no population map is in use.
    """
    pop_1 = []
    pop_2 = []
    pop_1_vcfpops = {}
    pop_2_vcfpops = {}
    if popmap_used:
        for vp in vcfpops:
            pop_1_vcfpops[vp] = []
            pop_2_vcfpops[vp] = []

    for loc in chr_loci:
        for snp in loc.snps.values():
            #
            # Was this private allele found in the hybrid/vcf population? If not, skip it.
            #
            if not snp.priv_hyb:
                continue

            for pop in snp.pops.values():
                if not pop.private:
                    continue

                #
                # Which allele is private decides which frequency is reported.
                #
                if pop.priv_all == snp.p_allele:
                    freq_attr = "p_freq"
                elif pop.priv_all == snp.q_allele:
                    freq_attr = "q_freq"
                else:
                    continue

                #
                # Only the two plotted populations are collected.
                #
                if pop.pop == plot_pops[0]:
                    stats, stats_by_vcfpop = pop_1, pop_1_vcfpops
                elif pop.pop == plot_pops[1]:
                    stats, stats_by_vcfpop = pop_2, pop_2_vcfpops
                else:
                    continue

                vcf = snp.pops['vcf']
                stats.append(SmoothStat(loc.id, pop.pop, snp.col, loc.chr, snp.bp,
                                        vcf.cnt, getattr(vcf, freq_attr)))

                if popmap_used:
                    for vp in vcfpops:
                        indv = snp.indvpops[vp]
                        if indv.cnt > 0:
                            stats_by_vcfpop[vp].append(
                                SmoothStat(loc.id, pop.pop, snp.col, loc.chr, snp.bp,
                                           indv.cnt, getattr(indv, freq_attr)))

    pop_1.sort(key=attrgetter("bp"))
    pop_2.sort(key=attrgetter("bp"))
    if popmap_used:
        for vp in vcfpops:
            pop_1_vcfpops[vp].sort(key=attrgetter("bp"))
            pop_2_vcfpops[vp].sort(key=attrgetter("bp"))

    return pop_1, pop_2, pop_1_vcfpops, pop_2_vcfpops


def write_plots(chrs_list, ordered_loci, populations_key, vcfpops, out_path):
    """Write the files used to plot private allele frequencies along the genome.

    The y-axis scales from 1 to -1, the x-axis is the genome coordinates.
    The first population in `populations_key` is positive on the y-axis, the
    second negative; the y value is the frequency of the private allele in
    the VCF population.

    One file, '<out_path>_plot_by_chromosome-all.tsv', is always written.
    When a population map is in use, one additional file per population in
    `vcfpops` is written. If the global `smooth` flag is set, each file also
    carries a column of smoothed values.

    Arguments:
        chrs_list:       (name, length) pairs; processed in the order given.
                         Entries shorter than the global `ctg_limit`, or with
                         no loci in `ordered_loci`, are skipped.
        ordered_loci:    loci to process, keyed by chromosome name.
        populations_key: its first two entries name the plotted populations.
        vcfpops:         population names from the population map.
        out_path:        prefix for the output file names.

    All files opened here are closed before returning, including when an
    error interrupts processing.
    """
    datestamp = datetime.now().strftime("%Y%m%d")
    plot_pops = list(populations_key)
    cmdline   = " ".join(sys.argv)

    #
    # Every handle is registered here as soon as it is opened so that the
    # finally clause can close all of them, whatever happens in between.
    #
    open_fhs = []
    try:
        #
        # The combined file is kept apart from the per-population handles so
        # that a population map entry named 'vcf' cannot replace it.
        #
        all_path = out_path + "_plot_by_chromosome-all.tsv"
        all_fh   = open(all_path, 'w')
        open_fhs.append(all_fh)
        print("   {}".format(all_path), file=sys.stderr)
        all_fh.write("# Generated by stacks-private-alleles, {}\n# {}\n".format(datestamp, cmdline))
        all_fh.write("#   Population '{}': positive values\n#   Population '{}': negative values\n".format(plot_pops[0], plot_pops[1]))
        if smooth:
            all_fh.write("# Chr\tBP\tPopulation\tVCF-PrivateAlleleFreq\tVCF-Smoothed\n")
        else:
            all_fh.write("# Chr\tBP\tPopulation\tVCF-PrivateAlleleFreq\n")

        pop_fhs = {}
        if popmap_used:
            for vp in vcfpops:
                pop_path = out_path + "_plot_by_chromosome-" + vp + ".tsv"
                pop_fh   = open(pop_path, 'w')
                open_fhs.append(pop_fh)
                pop_fhs[vp] = pop_fh
                print("   {}".format(pop_path), file=sys.stderr)
                pop_fh.write("# Generated by stacks-private-alleles, date {}\n# {}\n".format(datestamp, cmdline))
                pop_fh.write("#   Population '{}': positive values\n#   Population '{}': negative values\n".format(plot_pops[0], plot_pops[1]))
                if smooth:
                    pop_fh.write("# Chr\tBP\tPopulation\t{}-PrivateAlleleFreq\t{}-Smoothed\n".format(vp, vp))
                else:
                    pop_fh.write("# Chr\tBP\tPopulation\t{}-PrivateAlleleFreq\n".format(vp))

        if smooth:
            print("   Smoothing values", file=sys.stderr, end='', flush=True)

        #
        # Iterate over chromosomes/scaffolds in the order supplied by the caller.
        #
        for chrom, chrlen in chrs_list:
            if chrom not in ordered_loci or chrlen < ctg_limit:
                continue

            pop_1, pop_2, pop_1_vcfpops, pop_2_vcfpops = _collect_plot_stats(
                ordered_loci[chrom], plot_pops, vcfpops)

            #
            # Smooth the values; one dot of progress is printed per chromosome.
            #
            if smooth:
                smooth_chromosome(pop_1)
                smooth_chromosome(pop_2)
                if popmap_used:
                    for vp in vcfpops:
                        smooth_chromosome(pop_1_vcfpops[vp])
                        smooth_chromosome(pop_2_vcfpops[vp])
                print(".", file=sys.stderr, end='', flush=True)

            #
            # Write the values.
            #
            write_plot_values(all_fh, pop_1, pop_2)
            if popmap_used:
                for vp in vcfpops:
                    write_plot_values(pop_fhs[vp], pop_1_vcfpops[vp], pop_2_vcfpops[vp])
    finally:
        for fh in open_fhs:
            fh.close()

    #
    # Terminate the line of progress dots.
    #
    if smooth:
        print(file=sys.stderr)


def main():
    """Run the program from start to finish.

    Parses the command line, reads the optional population map, the sumstats
    file and the VCF file, orders the loci, and writes the private allele
    summary. The polarized VCF files and the plotting files are written only
    when the corresponding global flags (`polarize`, `plots`) are set.
    Progress messages go to standard error.
    """
    def report(msg):
        print(msg, file=sys.stderr)

    #
    # The global configuration variables are only meaningful after this call.
    #
    parse_command_line()

    all_loci     = {}
    loci_by_chr  = {}
    chromosomes  = []
    #
    # List of populations, and ordered sample names for each.
    #
    pop_key      = {}
    vcf_pops     = {}
    vcf_pops_rev = {}

    if popmap_used:
        report("Parsing population map: '{}'...".format(popmap_path))
        parse_popmap_file(popmap_path, vcf_pops, vcf_pops_rev)

    report("Parsing sumstats file: '{}'...".format(sumstats_path))
    parse_sumstats_file(sumstats_path, all_loci, pop_key)

    report("Parsing VCF file: '{}'...".format(vcf_path))
    parse_hybrid_vcf(vcf_path, all_loci, pop_key, vcf_pops, vcf_pops_rev, chromosomes)

    order_loci(chromosomes, all_loci, loci_by_chr)

    write_private_allele_summary(chromosomes, loci_by_chr, pop_key, vcf_pops, out_path)

    if polarize:
        report("\nWriting polarized VCF files:")
        write_polarized_vcfs(chromosomes, all_loci, pop_key, vcf_path, out_path)
        report("done.")

    if plots:
        report("\nGenerating plotting files:")
        write_plots(chromosomes, loci_by_chr, pop_key, vcf_pops, out_path)
        report("done.")

#                                                                              #
#------------------------------------------------------------------------------#
#                                                                              #
# Run the program only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
