"""
Enumerate lattices with a Python implementation of the algorithm from
    J. Heitzig and J. Reinhold, Counting finite lattices,
    Algebra Universalis 48 (2002) 43-53.

Author:  Peter Jipsen (2009-08-20): initial version

Note:    This file is Python code with no Sage dependencies but usable in Sage

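Example: Find all nonisomorphic lattices of size 5:

         all_lattices(5)

         The output is a list of lattices. Each lattice L is given by its
         upper cover relation on {0,...,n-3}, i.e. on L with its bottom and
         top elements removed: entry i is the list of elements that cover i,
         and an empty entry means i is covered only by the top element.
         E.g. the 5-element nonmodular lattice N5 is [[], [], [0]].

         The algorithm also enumerates finite semilattices: adjoining only the
         bottom element (not the top) to each output gives the finite
         meet-semilattices, and adjoining only the top element gives the
         finite join-semilattices.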
"""

#*****************************************************************************
#           Copyright (C) 2009 Peter Jipsen <jipsen@chapman.edu>
#
# Distributed  under  the  terms  of  the  GNU  General  Public  License (GPL)
#                         http://www.gnu.org/licenses/
#*****************************************************************************

try:
    # psyco is an optional accelerator.  Nothing else in this module refers
    # to it, so when it is not installed the code simply runs unaccelerated.
    import psyco
    psyco.full()
except ImportError:
    pass

def permutations(m,n):
    """Return all permutations of {m,...,n-1} as lists, in lexicographic order.

    The identity [m, m+1, ..., n-1] is always the first entry.  When n <= m
    the result is [[]], the single empty permutation.
    """
    import itertools
    # itertools.permutations emits lexicographic order for sorted input.
    return [list(p) for p in itertools.permutations(range(m, n))]

def inverse_permutation(p):
    """Return the inverse q of a permutation p of {0,...,len(p)-1}.

    q satisfies q[p[i]] == i for every i.  p is assumed to really be a
    permutation of {0,...,len(p)-1}; this is not checked.
    """
    q = [0]*len(p)
    for i, image in enumerate(p):
        q[image] = i
    return q

def achains0(A,x,B):
    """Recursively collect candidate cover sets into the global list As.

    The set A1 = A + [x] is appended (as a copy) to As when both hold:
      * its weight, sum(2**j for j in A1), is at least wm1, and
      * for every pair of elements in the up-set u of A1, either some member
        of A1 lies below both, or no minimal element (from Zc) lies below both.
    Then subsets of A U [x] U B that extend A are explored: once without x
    (continuing with B[0]) and once with x (keeping only those b in B that
    have no minimal element in common with x).

    A is a set of pairwise disjoint elements; each a in A is disjoint from x
    and from all elements of B.  B may be a list or a range.

    Reads the globals m, le, wm1, Zc and blevs (set by lattice_antichains)
    and appends to the global As.
    """
    A1 = A+[x]
    u = [y for y in range(m) if any(le[c][y] for c in A1)]  # up-set of A1
    if sum(2**j for j in A1)>=wm1 and \
       all(any(le[c][u[i]] and le[c][u[j]] for c in A1) or
           not any(le[c][u[i]] and le[c][u[j]] for c in Zc)
           for i in range(len(u)) for j in range(i)):
        As.append(A1[:])
    if B:
        # Branch without x: every set found there is a subset of A + B.
        # lattice_antichains only wants sets that meet blevs, so skip the
        # branch when A + B misses blevs entirely.  (The set is tested with
        # isdisjoint: comparing a set with [] is always "not equal", so such
        # a test never prunes anything.)
        if not (blevs.isdisjoint(A) and blevs.isdisjoint(B)):
            achains0(A,B[0],B[1:])
        # Branch with x: drop every b that shares a minimal element with x.
        C = [c for c in Zc if le[c][x]]
        B1 = [b for b in B if not any(le[c][b] for c in C)]
        if B1: achains0(A1,B1[0],B1[1:])

def lattice_antichains(L,lev,dep):
    """Return the candidate cover sets for the next element added to L.

    Finds subsets A of L-{0} such that a,b in up(A) implies a^b in {0} U up(A),
    A intersects lev(k-1) U lev(k) where k = dep(n-1), and
    sum(2^j for j in A) >= w(n-1).  When L has a single level, the empty set
    is also returned, as the first entry.

    Side effect: sets the module globals wm1, m, blevs, As and Zc, which
    achains0 reads while filling As.
    """
    global wm1, m, blevs, As, Zc
    wm1 = sum([2**j for j in L[-1]])  # weight of the cover set of element m-1
    m = len(L)
    k = dep[m-1]
    blevs = set(lev[k-1]+lev[k]) # bottom two levels
    As = []
    # Elements that are nobody's upper cover, i.e. the minimal elements.
    Zc = set(range(m)).difference(*L)
    # Pass a real list: achains0 concatenates and slices it.
    achains0([],0,list(range(1,m)))
    if len(lev)==1: return [[]]+As
    return As

def is_canonical_lattice(L,lev):
    """Return True if L is lexicographically minimal among its relabellings.

    L[i] lists the upper covers of element i, and lev lists the levels, each
    a run of consecutive element indices.  The weight of element i is
    sum(2**c for c in L[i]).  For every non-identity permutation p that only
    moves elements within their own level, the weight vector of the
    relabelled structure (element i becomes p[i], each cover c becomes p[c])
    is computed.  L is rejected if its own weight vector is lexicographically
    greater than any of these.
    """
    import itertools
    n = len(L)
    w = [sum([2**j for j in L[i]]) for i in range(n)] # weight of L
    # Enumerate level-preserving permutations lazily.  Materializing the
    # product of per-level permutations costs memory proportional to the
    # product of the level factorials, even though we may return early.
    relabellings = itertools.product(
        *[itertools.permutations(range(l[0], l[-1]+1)) for l in lev])
    next(relabellings)  # the first combination is the identity; skip it
    for parts in relabellings:
        p = list(itertools.chain.from_iterable(parts))
        pw = [0]*n
        for i in range(n):
            pw[p[i]] = sum([2**p[j] for j in L[i]])
        if w > pw: return False
    return True

def next_lattice(L,lev,dep,n,count):
    """Recursively extend L one element at a time until it has n elements.

    L[i] lists the upper covers of element i, lev lists the levels and dep[i]
    is the level of element i.  Every candidate cover set A from
    lattice_antichains gives a one-element extension.  Extensions accepted
    by is_canonical_lattice get row m of the global order relation le filled
    in and are extended further.  Once L has n elements it is recorded:
    lat_count is incremented when count is true, otherwise a copy of L is
    appended to lat_list.
    """
    global lat_count, lat_list
    m = len(L)  # index that the new element will receive
    if m >= n:
        if count:
            lat_count = lat_count+1
        else:
            lat_list.append([c[:] for c in L])
        return
    for A in lattice_antichains(L,lev,dep):
        # The new element opens a new level exactly when the last entry of
        # its cover list lies on the current last level.
        starts_level = A!=[] and A[-1] in lev[-1]
        L_A = L+[A]
        if starts_level:
            lev_A = lev+[[m]]
        else:
            lev_A = lev[:-1]+[lev[-1]+[m]]
        if not is_canonical_lattice(L_A,lev_A):
            continue
        # Row m of the order: m <= j iff some cover of m is <= j.
        for j in range(m):
            le[m][j] = any(le[i][j] for i in A)
        dep_A = dep+[len(lev) if starts_level else len(lev)-1]
        next_lattice(L_A,lev_A,dep_A,n,count)

def all_lattices(n,count=False): # construct (or count) all lattices of size n
    """Construct (or, if count is true, count) all lattices of size n.

    Each lattice is returned as its upper cover lists on {0,...,n-3}, i.e.
    with the bottom and top elements removed.  Returns an int when count is
    true, otherwise a list of such lattices.

    Raises ValueError for n < 2: bottom and top must be distinct for the
    "remove bottom and top" representation to make sense.
    """
    global le, lat_count, lat_list
    if n < 2:
        raise ValueError("all_lattices needs n >= 2, got %r" % (n,))
    if n == 2:
        # Only the 2-element chain: nothing lies strictly between bottom and top.
        return 1 if count else [[]]
    lat_list = []
    lat_count = 0
    # le[i][j] is True iff element i <= element j.  Start from the identity;
    # next_lattice fills in row m when element m is added.
    le = [[i==j for j in range(n-2)] for i in range(n-2)]
    next_lattice([[]],[[0]],[0],n-2,count)
    return lat_count if count else lat_list

from time import *
def tlat(n):
    """Count the lattices of size n and time the computation.

    Returns (all_lattices(n, True), elapsed seconds).
    """
    import time
    # time.clock was removed in Python 3.8; use perf_counter when available.
    timer = getattr(time, 'perf_counter', None) or time.clock
    t = timer()
    return all_lattices(n,True), timer()-t
