# -*- coding: utf-8 -*-
import itertools as it
import numpy as np
from scipy.sparse import csc_matrix
from numpy.linalg import matrix_rank


# Return the list of all forests of d sticks (pairings of the vertices) on S,
# where |S| = 2d.
def Forests(S,d):
    """Return every way to split ``S`` into ``d`` sticks (unordered pairs).

    Each forest is a list of tuples ``(first, partner)``, where ``first`` is
    the smallest vertex still unpaired at that recursion step, so the sticks
    appear in increasing order of their first vertex.  ``d == 0`` yields the
    single empty forest ``[[]]``.
    """
    if d==0:
        return [[]]
    first=min(S)
    others=[v for v in S if v!=first]
    forests=[]
    for partner in others:
        remaining=[v for v in others if v!=partner]
        # Pair `first` with `partner`, then pair up what is left recursively.
        forests.extend([(first,partner)]+tail for tail in Forests(remaining,d-1))
    return forests

# Yield all forests of d sticks whose single vertices are labelled:
# a of them with x and b of them with y.
# The vertices are 0, ..., 2d+a+b-1.
def LabForests(d,a,b):
    """Generate labelled forests as ``[sticks, x_vertices, y_vertices]``.

    ``sticks`` is a forest as returned by ``Forests`` on the chosen stick
    vertices.  ``x_vertices`` and ``y_vertices`` are increasing lists that
    together cover exactly the vertices not used by a stick.  Results come
    ordered by stick vertex set (``itertools.combinations`` order), then by
    forest, then by the choice of x-labelled vertices.
    """
    n=2*d+a+b
    vertices=range(n)
    for stick_set in it.combinations(vertices,2*d):
        # Vertices left over once the sticks are placed; they get the labels.
        free=[v for v in vertices if v not in stick_set]
        for sticks in Forests(stick_set,d):
            for x_part in it.combinations(free,a):
                y_part=[v for v in free if v not in x_part]
                yield [sticks,list(x_part),y_part]

# Return a sparse matrix representing the differential (a+b, d) -> (a+b+2, d-1)
# of weight a-b on 2d+a+b vertices (full forests).
def diff(d,a,b):
    """Return the matrix of the differential on labelled forests.

    Columns are indexed by ``LabForests(d, a, b)`` and rows by
    ``LabForests(d-1, a+1, b+1)``, each in generation order.  Removing the
    stick ``(u, v)`` from a column forest produces two row forests: one with
    ``u`` labelled x and ``v`` labelled y, and one with the labels swapped.
    Each of these contributes a +1/-1 entry.

    Returns ``[]`` when ``d == 0``.  Otherwise it returns a float32
    ``scipy.sparse.csc_matrix`` of shape ``(len(rows), len(cols))``.
    """
    if d==0:
        return []
    rows=list(LabForests(d-1,a+1,b+1))
    cols=list(LabForests(d,a,b))

    def key(G):
        # Hashable form of a labelled forest [sticks, x-vertices, y-vertices].
        return (tuple(G[0]),tuple(G[1]),tuple(G[2]))

    # Index the target forests once: O(1) lookups instead of a linear
    # list.index scan per entry, which made the construction quadratic.
    row_of={key(G): r for r,G in enumerate(rows)}

    I=[]     # column (source forest) indices
    J=[]     # row (target forest) indices
    data=[]
    # The stick endpoints are named u, v so they do not shadow parameters a, b.
    for col,(sticks,xs,ys) in enumerate(cols):
        for pos,(u,v) in enumerate(sticks):
            rest=[s for s in sticks if s!=(u,v)]
            G1=[rest,sorted(xs+[u]),sorted(ys+[v])]
            G2=[rest,sorted(xs+[v]),sorted(ys+[u])]
            # The coefficient accounts for the sign from the Leibniz rule and
            # the sign from reordering the labelled vertices of d(G_{u,v}).
            data.append((-1)**(pos+len(xs)-G1[1].index(u)+G1[2].index(v)))
            data.append((-1)**(pos+len(xs)-G2[1].index(v)+G2[2].index(u)))
            I.extend([col,col])
            J.extend([row_of[key(G1)],row_of[key(G2)]])
    # An explicit shape fixes the dimensions to the forest counts instead of
    # inferring them from the largest row/column index that happens to occur.
    return csc_matrix((np.array(data),(np.array(J),np.array(I))),
                      shape=(len(rows),len(cols)),dtype='f')

# Return the dimensions of the subgroups of gr^(2d+k) H^(k+2q, d-q)(Conf(C, 2d+k)/C)
# on which the maximal torus T of SL2 acts with weight k, for q = 0, ..., d.
def h(d,k):
    """Return the cohomology dimensions of the forest complex for q = 0..d.

    Step q of the complex consists of (d-q)-stick forests with k+q x-labelled
    and q y-labelled vertices; ``diff(d-q, k+q, q)`` maps it to step q+1.
    Entry q of the result is ``dim C_q - rank(d_q) - rank(d_{q-1})``.  The
    ranks outside the complex (before q = 0 and out of q = d) count as 0.
    """
    dims=[0]*(d+1)
    ranks=[0]*(d+1)   # ranks[d] stays 0: nothing leaves the last group
    last=None
    for q in range(d):
        last=diff(d-q,k+q,q)
        dims[q]=last.get_shape()[1]
        ranks[q]=matrix_rank(last.toarray())
    if d<=0:
        # Only the single 0-stick forest (all k vertices labelled x).
        dims[0]=1
    else:
        # The final group is the target (row space) of the last differential.
        dims[d]=last.get_shape()[0]
    return [dims[q]-ranks[q]-(ranks[q-1] if q>0 else 0) for q in range(d+1)]


# Run the example computations only when executed as a script, so importing
# this module does not trigger the (potentially expensive) rank computations.
if __name__ == "__main__":
    print(h(3,2))
    print(h(3,0))
    print(h(2,2))
    # Larger cases, presumably disabled because they are expensive to compute:
    #print(h(4,0))
    #print(matrix_rank(diff(4,2,0).toarray()))