from Bio.Align import AlignInfo 
from Bio import Clustalw
import math

def aln_to_distancematrix(input_filename):
    '''
    Given the path to a dna multiple-alignment file in .aln format, this returns a dictionary
        of Jukes-Cantor distances in which the keys (indices) are integer 2-tuples.
        The integers are given by the input sequence order.

    Only alignment columns containing no gap ('-') are used.  For sequences
    i and j, p is the fraction of those columns at which they differ, and the
    distance is d = -3/4 * ln(1 - 4/3 * p).  Every ordered pair (i, j) is a
    key, including i == j (distance 0), and d[(i, j)] == d[(j, i)].

    Raises ZeroDivisionError if the alignment has no gap-free column, and
    ValueError if some pair differs at 3/4 or more of the gap-free columns
    (the Jukes-Cantor correction is undefined there).
    '''
    alignment = Clustalw.parse_file(input_filename)
    # Every column holds one character per sequence.
    num_seqs = len(alignment.get_column(0))
    mismatches = {(i, j): 0 for i in range(num_seqs) for j in range(num_seqs)}
    total = 0
    for col_idx in range(alignment.get_alignment_length()):
        column = alignment.get_column(col_idx)
        if '-' in column:
            # Gapped columns are skipped entirely.
            continue
        total += 1
        for a in range(num_seqs):
            for b in range(a + 1, num_seqs):
                if column[a] != column[b]:
                    mismatches[a, b] += 1
                    mismatches[b, a] += 1
    if total == 0:
        # Same exception type as the bare division used to raise.
        raise ZeroDivisionError(
            'alignment has no gap-free columns; cannot compute distances')
    dist_dict = {}
    for pair, count in mismatches.items():
        # Float literals: under Python 2, 4/3 is integer division (== 1),
        # which would silently turn this into -3/4*ln(1-p).
        arg = 1.0 - (4.0 / 3.0) * count / float(total)
        if arg <= 0:
            raise ValueError(
                'Jukes-Cantor distance undefined for pair %r: sequences differ '
                'at %d of %d gap-free columns (>= 3/4)' % (pair, count, total))
        dist_dict[pair] = -.75 * math.log(arg)
    return dist_dict

# This may be useful in computing the arithmetic means between clusters.
def flatten(inlist, ltypes=(list, tuple)):
    '''
    Flattens a nested list.

    Returns a new list holding, in order, every element of inlist that is not
    itself an instance of ltypes.  Nested containers of those types are
    expanded recursively, and empty ones simply disappear.  Falsy scalars such
    as 0 or None are kept; cluster index 0 must survive flattening.  inlist
    is not modified.
    '''
    flatlist = list(inlist)
    index = 0
    while index < len(flatlist):
        item = flatlist[index]
        if isinstance(item, ltypes):
            # Splice the container's contents in place and re-examine the
            # same position: the first spliced-in element may itself be a
            # container, or the container may have been empty (in which case
            # it just vanishes and the bounds check above is re-run).
            flatlist[index:index + 1] = list(item)
        else:
            index += 1
    return flatlist

# UPGMA (unweighted pair-group, arithmetic-mean) clustering on a distance
# dictionary such as the one returned by aln_to_distancematrix.
def upgma2(dist_dict):
    '''
    Returns a string in Newick format that gives the UPGMA tree from a distance dictionary.

    dist_dict maps integer 2-tuples (i, j) to the distance between taxa i and
    j for every ordered pair with 0 <= i, j < n, so it has n*n entries.  Leaves
    are labelled by their integer index.  At each step, the two items with the
    smallest mean pairwise distance are joined.  An item is an unclustered
    taxon or a cluster, and the mean is taken over all cross pairs of
    members.  This repeats until a single rooted binary tree remains.  No
    terminating ';' is appended.

    Raises ValueError if dist_dict is empty.
    '''
    def leaves(cluster):
        # Iteratively collect the taxon indices inside a (possibly nested)
        # cluster; kept local so index 0 is never lost and deep nesting
        # cannot hit the recursion limit.
        found = []
        stack = [cluster]
        while stack:
            node = stack.pop()
            if isinstance(node, list):
                stack.extend(node)
            else:
                found.append(node)
        return found

    cluster_length = int(math.sqrt(len(dist_dict)))
    if cluster_length == 0:
        raise ValueError('dist_dict is empty')
    if cluster_length == 1:
        # A single taxon is already a complete tree.
        return '0'
    clustered = []
    unclustered = list(range(cluster_length))
    infinity = float('inf')
    # Continue until every taxon is clustered AND all clusters are joined
    # under one root; stopping at "nothing unclustered" could leave several
    # disjoint clusters.
    while unclustered or len(clustered) > 1:
        # Deterministic order by smallest member; a plain sort() on nested
        # lists mixing ints and lists raises TypeError on Python 3.
        clustered.sort(key=lambda c: min(leaves(c)))
        members = [leaves(c) for c in clustered]

        # check unclustered vs unclustered:
        ucmeanmin, ucmeanarg = infinity, None
        for i in unclustered:
            for j in unclustered:
                if i != j and dist_dict[i, j] < ucmeanmin:
                    ucmeanmin = dist_dict[i, j]
                    ucmeanarg = (i, j)

        # check unclustered vs clustered (mean distance from taxon j to
        # every member of cluster i):
        cmeanmin, cmeanarg = infinity, None
        for i, flat in enumerate(members):
            for j in unclustered:
                mean = sum(dist_dict[k, j] for k in flat) / float(len(flat))
                if mean < cmeanmin:
                    cmeanmin = mean
                    cmeanarg = (i, j)

        # check clustered vs clustered (mean over all cross-cluster pairs):
        ccmeanmin, ccmeanarg = infinity, None
        for i, flat1 in enumerate(members):
            for j in range(i + 1, len(members)):
                flat2 = members[j]
                total = sum(dist_dict[k1, k2] for k1 in flat1 for k2 in flat2)
                mean = total / float(len(flat1) * len(flat2))
                if mean < ccmeanmin:
                    ccmeanmin = mean
                    ccmeanarg = (i, j)

        # agglomerate the closest pair found above:
        if ucmeanmin < cmeanmin and ucmeanmin < ccmeanmin:
            # an unclustered vs unclustered had the minimum:
            clustered.append([ucmeanarg[0], ucmeanarg[1]])
            unclustered.remove(ucmeanarg[0])
            unclustered.remove(ucmeanarg[1])
        elif ccmeanmin < cmeanmin:
            # a clustered vs clustered had the minimum (first < second):
            first, second = ccmeanarg
            clustered[first] = [clustered[first], clustered[second]]
            del clustered[second]
        else:
            # an unclustered vs clustered had the minimum:
            target, taxon = cmeanarg
            clustered[target] = [clustered[target], taxon]
            unclustered.remove(taxon)
    newick_string = str(clustered[0])
    return newick_string.replace('[', '(').replace(']', ')')
