"""
Enumerate lattices with a Python implementation of the algorithm from
    J. Heitzig and J. Reinhold, Counting finite lattices,
    Algebra Universalis 48 (2002) 43-53.

Author:  Peter Jipsen (2009-08-20): initial version

Note:    This file is Python code with no Sage dependencies but usable in Sage

Example: Find all nonisomorphic lattices of size 5:

         all_lattices(5)

         The output is a list of adjacency lists giving the upper cover
         relation on the set {0,...,n-1}, with 0 as the bottom and 1 as the
         top element. E.g. the 5-element nonmodular lattice N5 is
         [[3, 4], [], [1], [1], [2]].

         The algorithm also enumerates finite (meet or join) semilattices:
         deleting the top (respectively bottom) element of an (n+1)-element
         lattice gives an n-element meet (respectively join) semilattice.
"""

#*****************************************************************************
#           Copyright (C) 2009 Peter Jipsen <jipsen@chapman.edu>
#
# Distributed  under  the  terms  of  the  GNU  General  Public  License (GPL)
#                         http://www.gnu.org/licenses/
#*****************************************************************************

def permutations(m,n):
    """Return all permutations of {m,...,n-1}, as lists, in lexicographic order.

    The first entry is always the identity [m, m+1, ..., n-1], which
    is_canonical_lattice relies on.  When m >= n the result is [[]].
    """
    import itertools
    # itertools yields the permutations of a sorted input in lexicographic
    # order; each tuple is turned into a list because callers concatenate
    # the results with other lists (p+q).
    return [list(p) for p in itertools.permutations(range(m, n))]

def inverse_permutation(p):
    """Return the inverse q of permutation p, so that q[p[i]] == i for all i.

    p must be a permutation of {0,...,len(p)-1} (a list or tuple of ints).
    """
    q = [0] * len(p)  # a real list: range objects are immutable on Python 3
    for i, image in enumerate(p):
        q[image] = i
    return q

def lattice_antichains(L,lev,dep):
    """Return the admissible upper-cover sets for the next element of L.

    L is the lattice built so far, as upper-cover lists on {0,...,len(L)-1}
    (0 is the bottom, 1 the top); lev[k] lists the elements of depth k and
    dep[i] is the depth of element i.  The result is a list of lists A of
    elements of L-{0} such that
      * a,b in up(A) implies a^b in {0} U up(A),
      * sum(2^j for j in A) >= w(n-1), the weight of the last element's
        upper covers, and
      * A intersects lev(k-1) U lev(k) where k = dep(n-1) (elements are
        numbered level by level, so sets missing these levels are also too
        light to pass the weight test).

    The order relation is read from the module-level global ``le``
    (le[i][j] is True iff i <= j), which all_lattices maintains.
    """
    def achains0(A,x,B):
        # Append to As the admissible sets that contain A and at least one of
        # x, B.  A is a set of pairwise disjoint elements (no atom lies below
        # two of them); each a in A is disjoint from x and from all of B.
        A1 = A+[x]
        # up(A1) without the bottom element
        u = [y for y in range(1,len(L)) if any(le[c][y] for c in A1)]
        # every pair in up(A1) must either lie above a common element of A1
        # (so their meet stays in up(A1)) or share no atom (meet is 0)
        if sum(2**j for j in A1)>=wm1 and \
           all(all(any(le[c][u[i]] and le[c][u[j]] for c in A1) or\
                   not any(le[c][u[i]] and le[c][u[j]] for c in Zc)\
                   for j in range(i)) for i in range(len(u))): As.append(A1[:])
        if B:
            # sets without x: only worth exploring if they can still meet
            # the bottom two levels
            if blevs.intersection(A+B): achains0(A,B[0],B[1:])
            # sets with x: keep only the b that share no atom with x
            C = [c for c in Zc if le[c][x]]  # atoms below x
            B1 = [b for b in B if not any(le[c][b] for c in C)]
            # (to enumerate plain antichains instead, keep b when
            #  not(le[x][b] or le[b][x]))
            if B1: achains0(A1,B1[0],B1[1:])
    wm1 = sum(2**j for j in L[-1])
    k = dep[len(L)-1]
    blevs = set(lev[k-1]+lev[k]) # bottom two levels
    As = []
    Zc = L[0]  # covers of 0
    # list(...) so that A+B and B[1:] work on Python 3, where range is not a list
    if 1 in blevs: achains0([],1,list(range(2,len(L))))
    else: achains0([],2,list(range(3,len(L))))
    return As

def is_canonical_lattice(L,AutLm):
    """Decide whether L is the canonical labelling of its isomorphism class.

    L is a lattice given by upper-cover lists on {0,...,m-1}.  AutLm lists the
    automorphisms of Lm, the restriction of L to {0,...,len(AutLm[0])-1}
    (everything before the last level), each as a list mapping element to
    image; it is assumed to start with the identity, as all_lattices supplies.
    Every permutation that agrees with a member of AutLm on Lm and permutes
    the remaining elements {len(AutLm[0]),...,m-1} arbitrarily is tried.

    Returns (False, []) as soon as a relabelling gives a lexicographically
    smaller weight vector, otherwise (True, AutL) where AutL lists the
    permutations that leave the weight vector unchanged (identity first).
    """
    m = len(L)
    # Generate the candidates lazily, in the same order as before (every
    # member of AutLm for each permutation of the new elements).  This avoids
    # quadratic list concatenation and lets a non-canonical L be rejected
    # without building the whole product.
    candidates = (p+q for q in permutations(len(AutLm[0]),m) for p in AutLm)
    AutL = [next(candidates)]  # AutLm[0] followed by the identity on new elements
    w = [sum(2**j for j in L[i]) for i in range(2,m)] # weight of L
    for p in candidates:
        pw = [sum(2**p[j] for j in L[i]) for i in range(2,m)]
        q = inverse_permutation(p)
        # reorder so pw[r-2] is the weight of element r in the relabelled lattice
        pw = [pw[i-2] for i in q[2:]]
        if w > pw: return False,[]
        if w == pw: AutL.append(p)
    return True,AutL

def all_lattices(n,count=False): # construct (or count) all lattices of size n
    """Return all nonisomorphic lattices with n elements, or their number.

    Each lattice is a list of upper-cover lists on {0,...,n-1} with 0 as the
    bottom and 1 as the top element (for n == 1 the lone element 0 is both,
    and the lattice is [[]]).  With count=True only the number of lattices
    is returned and no lattice is stored.

    Raises ValueError if n < 1.

    Lattices are grown from the 3-element chain by repeatedly adding a new
    atom m whose upper covers are a set A from lattice_antichains, keeping
    only canonical labellings (is_canonical_lattice).  The order relation of
    the lattice under construction lives in the module-level global ``le``
    (read by lattice_antichains) and the running count in the global
    ``lat_count``, so two runs must not be interleaved.
    """
    global le, lat_count
    if n < 1:
        raise ValueError("a lattice must have at least one element")
    if n < 3:
        # The search below starts from the 3-element chain, so the unique
        # 1- and 2-element lattices are returned directly.
        if count: return 1
        return [[[]]] if n == 1 else [[[1], []]]
    def next_lattice(L,lev,dep,AutL,AutLm,n):
        # Extend L (lev: elements by depth, dep: depth of each element,
        # AutL / AutLm: automorphisms of L and of L without its last level)
        # by one element in every admissible way, recursing up to size n.
        global lat_count
        m = len(L) # new element to be added
        if m<n:
            for A in lattice_antichains(L,lev,dep):
                L_A = L+[A] # add covers of new element
                # m becomes an atom: it covers 0, and the elements of A no
                # longer do
                L_A[0] = [x for x in L_A[0] if not x in A]+[m]
                if max(A) in lev[-1]:
                    lev_A = lev+[[m]] # m starts a new, deeper level
                    AutLAm = AutL
                else:
                    lev_A = lev[:-1]+[lev[-1]+[m]] # m joins the last level
                    AutLAm = AutLm
                is_canon,AutLA = is_canonical_lattice(L_A,AutLAm)
                if is_canon:
                    for j in range(m): # update less_or_equal relation
                        le[m][j] = any(le[i][j] for i in A)
                    if max(A) in lev[-1]: dep_A = dep+[len(lev)] # update depth
                    else: dep_A = dep+[len(lev)-1]
                    next_lattice(L_A,lev_A,dep_A,AutLA,AutLAm,n)
        elif count: lat_count = lat_count+1
        else: lat_list.append([c[:] for c in L])
    lat_list = []
    lat_count = 0
    le = [[True if i==j or i==0 else False for j in range(n)] for i in range(n)]
    le[2][0] = False; le[2][1] = True # initialize less_or_equal relation
    # start from the 3-element chain 0 < 2 < 1
    next_lattice([[2],[],[1]],[[1],[2]],[2,0,1],[[0,1,2]],[[0,1]],n)
    return lat_count if count else lat_list

from time import *

def tlat(n):
    """Count the n-element lattices and time the computation.

    Returns the pair (number of lattices, elapsed wall-clock seconds).
    """
    started = time()
    total = all_lattices(n, True)
    return total, time() - started
