"""
A bunch of miscellaneous utility functions for Python that I seem to need
all the time.

2002-2005 by Erwin S. Andreasen -- http://www.andreasen.org/misc.shtml
This file is in the Public Domain

Version: $Id: util.py,v 1.22 2005/12/16 00:08:21 erwin Exp erwin $

This file needs Python 2.2+.

You can use ``from util import *'' in really short scripts, but in longer ones
I'd recommend importing each name individually, to allow for better static
error checking.

"""
from __future__ import generators
import sys, traceback, string, time, types, operator, re, os, linecache
import random, base64, md5, copy, cStringIO
from cStringIO import StringIO as IO

# Names exported by ``from util import *`` -- keep in sync with the
# definitions in this module.
__all__ = (
    'log fatal limitString cycleList shuffle getTraceback formEncode urlencodeDict generatePassword '
    'run capitalize commaify Enum debug Sort Uniq UniqOrdered Struct If timer MutableInt makedict sum '
    'noun each split indices oneArgument createMatrix stableSort groupStride hasDebug '
    'quote filterURL stableDSU fastDSU combinations setCompare flatten showProgress NullFile memoize '
    'displayLess callableDescr cartesianProduct identity dateTime annotate linefilter splitOn IO '
    'takeWeighted traceit nullfun any all ddict'
).split()

def identity(x):
    "Return the argument unchanged."
    result = x
    return result

# Types whose values are usually not useful to print in a traceback dump.
# formatLocals() skips any local variable or attribute whose exact type
# (type(v) membership, so subclasses do not match) is in this list.
# NOTE: types.ClassType, types.UnboundMethodType and types.TypeType exist
# only in Python 2's `types` module.
UnprintableTypes = [ types.ModuleType, types.ClassType, types.FunctionType,
                     types.MethodType, types.BuiltinMethodType,
                     types.UnboundMethodType, types.TypeType ]

def sum(s, *initial):
    """
    Fold s with the + operator: reduce(operator.add, s[, initial]).

    Works for any objects supporting + (numbers, strings, lists, ...).
    Without an initial value an empty s raises TypeError.
    """
    reduceArgs = (operator.add, s) + initial
    return reduce(*reduceArgs)

def formatLocals(locals, seen):
    """
    Format the dictionary #locals (name -> value) for a traceback dump.

    One line "name  repr(value)" is produced per entry, in sorted name
    order.  '__builtins__' and values whose exact type is in
    UnprintableTypes are skipped.  Lists and tuples are cut to their first
    100 items and every repr to 5000 characters.  For names not starting
    with '_', the value's public instance attributes are listed underneath
    (see _formatAttributes).

    #seen maps id() of already-expanded values to 1.  It is updated in place
    so that, across calls sharing it, each object's attributes are expanded
    only once.

    Returns a byte string in Python 2: unicode output is encoded as UTF-8.
    """
    parts = []
    names = list(locals.keys())
    names.sort()
    for name in names:
        value = locals[name]

        if name == '__builtins__' or type(value) in UnprintableTypes:
            continue

        # Truncate only what is shown.  `value` stays the live object, so the
        # id() recorded in #seen cannot be reused by a later temporary.
        shown = value
        if isinstance(value, (list, tuple)):
            shown = value[:100]

        # Ignore any errors from the repr()
        try:
            parts.append("%-20s  %s\n" % (name, repr(shown)[:5000]))
        except Exception:
            parts.append("%-20s  [error calling repr: %s]\n" % (name, sys.exc_info()[1]))
            continue

        if id(value) in seen:
            continue
        seen[id(value)] = 1

        if not name.startswith('_'):
            parts.extend(_formatAttributes(value))

    result = ''.join(parts)
    if not isinstance(result, str):
        # Unicode text (Python 2): encode so only byte strings are returned.
        # Byte strings are returned as-is; str.encode() would first decode
        # them as ASCII and fail on non-ASCII bytes.
        result = result.encode('utf-8')
    return result

def _formatAttributes(obj):
    """
    Return "  name  repr" lines for the public instance attributes of obj.

    Skips names starting with '_', attributes that also exist on obj's
    class (methods, properties, class constants), and values whose type is
    in UnprintableTypes.  Each repr is cut to 1000 characters plus '...'.
    An unreadable attribute is skipped and a failing repr() is reported
    inline, so one bad attribute does not hide the others.
    """
    lines = []
    klass = getattr(obj, '__class__', None)
    try:
        attrNames = dir(obj)
        attrNames.sort()
    except Exception:
        return lines
    for attr in attrNames:
        if attr[:1] == '_':
            continue
        # Don't display class attributes.  Checked before getattr() so
        # properties are not evaluated.
        if klass and hasattr(klass, attr):
            continue
        try:
            attrValue = getattr(obj, attr)
        except Exception:
            continue
        if type(attrValue) in UnprintableTypes:
            continue
        try:
            text = repr(attrValue)
        except Exception:
            text = '[error calling repr: %s]' % (sys.exc_info()[1],)
        if len(text) > 1000:
            text = text[:1000] + "..."
        lines.append("  %-20s  %s\n" % (attr, text))
    return lines
    

def getTraceback(showLocals=1,info=None,levels=10):
    """
    Return a formatted report of an exception as a string.

    info       -- optional (type, value, traceback) triple; defaults to
                  sys.exc_info(), i.e. the exception currently handled.
    showLocals -- if true, append the local variables (via formatLocals) of
                  the innermost traceback frame and of up to levels-1
                  calling frames above it.
    levels     -- total number of frames whose locals are shown.

    The report contains, in order:
      - the exception text;
      - the standard traceback;
      - any ``remote_stacktrace`` lines carried by the exception value;
      - the locals, if requested;
      - a final ``*file: path line`` line, "to make life nicer for emacs".

    Without a traceback object (third item of info is None) the locals
    section and the *file line are omitted.
    """
    etype, evalue, etb = info or sys.exc_info()
    try:
        report = "Error text: %s (%s)\n" % (evalue, etype)
    except Exception:
        report = "(Error formatting error text)\n"
    report += ''.join(traceback.format_exception(etype, evalue, etb)) + "\n"

    if hasattr(evalue, 'remote_stacktrace'):
        report += '\n--- Remote exception follows ---\n'
        report += '\n'.join(evalue.remote_stacktrace)
        report += '--- End of remote exception ---\n\n'

    if not showLocals or etb is None:
        return report

    # Walk to the innermost traceback entry (where the exception was raised).
    t = etb
    while t.tb_next is not None:
        t = t.tb_next
    seen = {}

    report += "** Locals (%s):\n" % t.tb_frame.f_code.co_name
    report += formatLocals(t.tb_frame.f_locals, seen)

    # Show the locals of the calling frames too, up to `levels` frames total.
    frame = t.tb_frame.f_back
    for level in range(1, levels):
        if frame is None:
            break
        report += ("\n** Locals (%s):\n" % frame.f_code.co_name) + formatLocals(frame.f_locals, seen)
        frame = frame.f_back

    # Display some information to make life nicer for emacs
    report += "*file: %s %d\n" % (os.path.abspath(t.tb_frame.f_code.co_filename),
                                  t.tb_lineno)
    return report

def log(s):
    """
    Write s to stderr followed by a newline, prefixed with the current local
    time as "[dd/Mon/yy HH:MM:SS] ".
    """
    stamp = time.strftime("%d/%b/%y %H:%M:%S", time.localtime(time.time()))
    message = "[%s] %s\n" % (stamp, s)
    sys.stderr.write(message)

# Abbreviated day and month names used by dateTime() for HTTP dates.
# weekdayname is indexed by time.gmtime()'s weekday field (Monday == 0).
# monthname is indexed by the month number 1-12, hence the None at index 0.
weekdayname = 'Mon Tue Wed Thu Fri Sat Sun'.split()

monthname = [None] + 'Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec'.split()

def fatal(s):
    "Write 'FATAL ERROR: <s>' to stderr, then exit the process with status 1."
    message = "FATAL ERROR: %s\n" % s
    sys.stderr.write(message)
    sys.exit(1)

def dateTime(offset = 0):
    """
    Return the current UTC date and time, shifted by offset seconds,
    formatted for an HTTP header, e.g. 'Thu, 01 Jan 1970 00:00:00 GMT'.
    """
    # This was possibly written by Oliver Jowett -- I don't remember
    fields = time.gmtime(time.time() + offset)
    return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekdayname[fields[6]], fields[2], monthname[fields[1]], fields[0],
        fields[3], fields[4], fields[5])

def limitString(s, length):
    """
    Limit string #s# to at most #length# characters.

    A string that is too long is cut, and its last three characters are
    replaced by '...'.  When length is below 3 there is no room for the
    ellipsis, so the string is simply truncated; a negative length counts
    as 0.
    """
    if len(s) <= length:
        return s
    if length < 3:
        # s[:length-3] would be a negative slice here and return far more
        # than `length` characters.
        return s[:max(length, 0)]
    return s[:length - 3] + "..."

def cycleList(l):
    """
    Rotate list l left by one, in place, and return the element that moved.

    The first element is removed and appended to the end, so repeated calls
    cycle through the list.  Raises IndexError for an empty list.
    """
    l.append(l.pop(0))
    return l[-1]

def shuffle(l):
    """
    Shuffle the list l randomly, in place (returns None).

    Delegates to random.shuffle, which makes every permutation equally
    likely and works for any elements, including unhashable ones such as
    lists, and duplicates.
    """
    random.shuffle(l)

def encodeString(s):
    """
    Return a one-way ASCII hash of s: the base64 encoding of its MD5
    digest, without the trailing newline base64.encodestring appends.
    """
    digest = md5.new(s).digest()
    encoded = base64.encodestring(digest)
    return encoded.rstrip('\n')

def generatePassword(count):
    """
    Generate a random password that is count characters long.

    Characters are drawn from lower-case letters and digits, except 'i',
    'l' and 'o'.  random.SystemRandom (OS entropy) is used when this Python
    provides it, because the default random generator is predictable and
    unsuitable for secrets.  Otherwise the random module's own functions
    are used.
    """
    validCharacters = 'abcdefghjkmnpqrstuvwxyz0123456789'
    if hasattr(random, 'SystemRandom'):
        rng = random.SystemRandom()
    else:
        rng = random
    return ''.join([rng.choice(validCharacters) for i in range(count)])

def urlencodeDict(dict):
    """
    Join a dictionary into a 'key=value&key=value' query string.

    A value whose type is exactly list or tuple produces one pair per
    element; any other value produces a single pair.  NOTE: keys and values
    are inserted with %s and are NOT URL-escaped.  Callers must escape them
    first if they may contain '&', '=', spaces, etc.
    """
    parts = []
    for key, value in dict.items():
        if type(value) not in (list, tuple):
            value = [value]
        for item in value:
            parts.append('%s=%s' % (key, item))
    return '&'.join(parts)

def formEncode(dict):
    """
    Encode a dictionary as hidden <INPUT> form fields, one per line.

    A value whose type is exactly list or tuple produces one field per
    element.  Both field names and values are HTML-quoted with quote(), so
    a key containing '"' or '<' cannot break out of the NAME attribute.
    """
    fields = []
    for k, v in dict.items():
        if type(v) in (list, tuple):
            values = v
        else:
            values = [v]
        name = quote(k)
        for x in values:
            fields.append('<INPUT TYPE="hidden" NAME="%s" VALUE="%s">\n' % (name, quote(x)))
    return ''.join(fields)


def toPrintable(x):
    """
    Return character x if it is printable ASCII (space through '~'),
    otherwise '.'.  DEL (0x7F) is a control character and maps to '.'.
    """
    if 32 <= ord(x) < 127:
        return x
    return '.'

def hexdump(data):
    """
    Return a hex dump of the byte string data, 16 bytes per line.

    Each line has two parts:
      - the bytes as space-separated two-digit upper-case hex, padded to
        50 columns;
      - a space, then the same bytes as text, with non-printable ones shown
        as '.' (see toPrintable).
    """
    lines = []
    # Index-based chunking avoids re-slicing the whole remainder each time.
    for start in range(0, len(data), 16):
        chunk = data[start:start + 16]
        hexPart = ' '.join(['%02X' % ord(x) for x in chunk])
        textPart = ''.join([toPrintable(x) for x in chunk])
        lines.append('%-50s %s\n' % (hexPart, textPart))
    return ''.join(lines)


def displayLess(s):
    if os.isatty(sys.stdout.fileno()) and os.getenv('TERM') != 'emacs':
        less = os.popen('less -', 'w')
        try:
            less.write(s)
            less.close()
        except IOError:
            pass
    else:
        print s
    
def run(main,cgi=0,levels=10):
    """
    Run this main function, display traceback if it throws an exception
    Example usage:
    import util
    def main:
      your main here...
      
    if __name__ == '__main'__: util.run(main)
    """

    if os.getenv('HOTSHOT'):
       import hotshot, hotshot.stats
       prof = hotshot.Profile('program.prof')
       try:
           prof.runcall(main)
       except (SystemExit, KeyboardInterrupt):
           pass
       prof.close()
       import psyco; psyco.full()
       stats = hotshot.stats.load("program.prof")
       stats.strip_dirs()
       stats.sort_stats('time', 'calls')
       stats.print_stats(100)
       out = cStringIO.StringIO(); old = sys.stdout; sys.stdout = out
       stats.print_stats(100).print_callees().print_callers()
       sys.stdout = old
       displayLess(out.getvalue())
       return
    try:
        main()
    except (SystemExit, KeyboardInterrupt):
        return
    except:
        if cgi:
            import cgi
            print "Content-Type: text/html\n\n<PRE>"
            print cgi.escape(getTraceback(1))
            print "</PRE>"
        else:
            tb = getTraceback(1,levels=levels)
            if hasattr(sys.stdout, 'fileno') and os.isatty(sys.stdout.fileno()):
                displayLess(tb)
            else:
                print >> sys.stderr, tb
        sys.exit(1)


def capitalize(name):
    """
    Upper-case the first character of name, leaving the rest alone:
    fooBar -> FooBar.  An empty or None name gives ''.
    """
    if name:
        return name[0].upper() + name[1:]
    return ''

def commaify(num):
    """
    Format an integer with commas between groups of thousands:
    1234567 -> '1,234,567', 1000 -> '1,000', -1000 -> '-1,000'.
    """
    if num < 0:
        return '-' + commaify(-num)
    if num >= 1000:
        rest, last = divmod(num, 1000)
        return '%s,%03d' % (commaify(rest), last)
    return str(num)

    

class Enum:
    """
    Map names to consecutive integers (0, 1, 2, ...) and create symbols for
    the names.
    E.g. names = Enum('bob joe alice') creates:
     names.bob == names.BOB == 0, names.joe == names.JOE == 1, ...
    names.lookup('bob') => names.BOB     (case-insensitive)
    names.toString(names.BOB) => 'bob'
    names.post(names.BOB) => names.JOE; names.pre(names.BOB) => None

    Symbols are installed with setattr(), so a name matching an attribute
    or method of this class (e.g. 'names', 'lookup') replaces it.
    """
    def __init__(self, names):
        self.numbers = {}                 # name -> number
        self.names = {}                   # number -> name
        num = 0
        self.enums = names.split()        # names in definition order
        for enum in self.enums:
            self.numbers[enum] = num
            self.names[num] = enum
            setattr(self, enum, num)
            setattr(self, enum.upper(), num)
            num += 1

    def lookup(self, name):
        """
        Return the number for name, ignoring case, or None if unknown.

        An exact lower-case name is preferred.  Otherwise the first defined
        name that matches case-insensitively is used, so names defined with
        capitals (Enum('Red Green')) can be looked up too.
        """
        key = name.lower()
        if key in self.numbers:
            return self.numbers[key]
        for enum in self.enums:
            if enum.lower() == key:
                return self.numbers[enum]
        return None

    def toString(self, symbol):
        "Return the name for the number symbol (KeyError if unknown)."
        return self.names[symbol]

    def post(self, symbol):
        "Return the number after symbol, or None if symbol is the last one."
        if symbol + 1 in self.names:
            return symbol + 1
        return None

    def pre(self, symbol):
        "Return the number before symbol, or None if symbol is the first one."
        if symbol - 1 in self.names:
            return symbol - 1
        return None

# Enabled debug channels.  debug() and hasDebug() only test whether a key
# is present; the values are not used.
debugLevels = {}

def debug(s):
    """
    Log s[2:] via log() if the debug channel s[0] is enabled.

    s is expected to look like 'X message': its first character names a
    key of debugLevels, and the message proper starts at index 2 (index 1
    is presumably a separator such as a space).
    """
    channel, message = s[0], s[2:]
    if channel in debugLevels:
        log(message)

def hasDebug(x):
    "Return true if x is an enabled key of debugLevels."
    return x in debugLevels
        
def Sort(l,sf=None):
    """
    Return a sorted copy of l (l itself is untouched).

    If sf is given (and true) it is passed to list.sort as the comparison
    function.
    """
    result = l[:]
    sortArgs = sf and (sf,) or ()
    result.sort(*sortArgs)
    return result

def Uniq(l):
    """
    Return the distinct elements of l (which must be hashable) in arbitrary
    order, as the keys of a temporary dictionary.  l itself is not modified.
    """
    return dict([(x, 1) for x in l]).keys()

def UniqOrdered(l):
    """
    Return the elements of l without duplicates, keeping the order of their
    first occurrence.  Elements must be hashable.
    """
    result = []
    seen = {}
    for item in l:
        if item not in seen:
            seen[item] = True
            result.append(item)
    return result
        

class Struct(dict):
    """
    A dictionary whose keys can also be used as attributes; it can be
    initialized with keyword arguments.

    Struct(a=1).a == 1; s.b = 2 stores s['b'] = 2; del s.b removes it.
    Reading or deleting a missing attribute raises AttributeError (not
    KeyError), so hasattr() and getattr() with a default work as usual.

    Instances hash by identity (id), so they can be dictionary keys or set
    members even though they are mutable.  NOTE: equality is still dict
    equality, so two equal Structs hash differently.
    """
    def __init__(self, **entries):
        self.update(entries)

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        # Attributes live in the dict items, so delete from there as well.
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name)

    def __hash__(self):
        return id(self)

def split(s,delim=None):
    """
    Like s.split(delim), but an empty or None s gives [] (''.split(',')
    would give ['']).
    """
    if s:
        return s.split(delim)
    return []

# Norvig (but I like If better than if_)
def If(test, result, alternative=None):
    """
    Expression-level if: return result when test is true, else alternative.

    The chosen branch is called first if it is callable, so lambdas can be
    used to delay evaluation of a branch.
    """
    if test:
        chosen = result
    else:
        chosen = alternative
    if callable(chosen):
        chosen = chosen()
    return chosen

def memoize(fn):
    """Memoize fn: make it remember the computed value for any argument list.

    Results are stored in a dictionary, available as the .cache attribute of
    the returned function and keyed by the tuple of positional arguments.
    All arguments must therefore be hashable, and keyword arguments are not
    supported.  Exceptions raised by fn are not cached.  Assigning a new
    dict to .cache clears it.
    Ex: def fib(n): return (n<=1 and 1) or (fib(n-1) + fib(n-2)); fib(9) ==> 55
    # Now we make it faster:
    fib = memoize(fib); fib(9) ==> 55"""
    def memoized_fn(*args):
        try:
            return memoized_fn.cache[args]
        except KeyError:
            pass
        result = fn(*args)
        memoized_fn.cache[args] = result
        return result
    memoized_fn.cache = {}
    # Carry over the wrapped function's documentation and name.
    memoized_fn.__doc__ = getattr(fn, '__doc__', None)
    try:
        memoized_fn.__name__ = fn.__name__
    except (AttributeError, TypeError):
        pass    # no name to copy, or __name__ is read-only (old Pythons)
    return memoized_fn

def timer(n, *fn_and_args):
    """Apply fn to args n times and return the number of seconds elapsed.
    You can leave out the n and it defaults to 1.
    Ex: timer(100, abs, -1); timer(1e4, f); timer(pow, 3, 4); timer(f)

    Time is measured with time.perf_counter() where available, otherwise
    time.clock().  Nothing is printed.
    """
    try:
        n, fn, args = int(n), fn_and_args[0], fn_and_args[1:]
    except TypeError: # n was not a number; it must be fn, with n=1 as default
        n, fn, args = 1, n, fn_and_args
    clock = getattr(time, 'perf_counter', None) or time.clock
    iterations = range(n)       # built before the clock starts
    start_time = clock()
    for i in iterations:
        fn(*args)
    return clock() - start_time


def indices(l,offset=0):
    """
    Pair every element of l with its index, counting from offset:
    [(l[0], offset), (l[1], offset + 1), ...].

    This allows iterating over a list and its indices simultaneously.
    l must support len().
    """
    positions = range(offset, offset + len(l))
    return zip(l, positions)

def gindices(l, offset = 0):
    "Generator version of indices(): yield (element, index) pairs, counting from offset."
    index = offset
    for item in l:
        yield item, index
        index = index + 1

class MutableInt(object):
    """
    A poor wrapper around int to make it mutable: the value lives in .val
    and can be replaced with set().  Truth-testing follows the wrapped
    value.
    """
    def __init__(self, val=0):
        self.val = val

    def set(self,val):
        self.val = val

    def __nonzero__(self):
        # Return a true boolean.  Returning self.val directly made bool()
        # raise for negative or non-int values.
        return not not self.val

    __bool__ = __nonzero__      # same truth test under the Python 3 name

def makedict(l, value=None):
    """
    Return a dictionary whose keys are the items of l.

    If value is callable, each key maps to value(key).  Otherwise every key
    maps to value itself (None by default).
    """
    result = {}
    computed = callable(value)
    for key in l:
        if computed:
            result[key] = value(key)
        else:
            result[key] = value
    return result

def noun(num, word):
    "Return '<num> <word>', pluralized with an 's' unless num == 1 (e.g. '3 cats')."
    if num != 1:
        return '%d %ss' % (num, word)
    return '1 %s' % word

class each:
    """
    Given a sequence variable, each will return an object that will
    apply all actions to this variable to each of the elements in the
    vector.

    Attribute access is forwarded to every element of the wrapped list l:
      each(l).attr          -> [x.attr for x in l]    (attr not callable)
      each(l).method(args)  -> [x.method(args) for x in l]
    Whether an attribute is treated as a method is decided from l[0] only.
    For an empty l, every attribute access returns a function that takes
    positional arguments and returns [].

    This is a classic (old-style) class in Python 2, so special methods are
    also looked up through __getattr__.  E.g. each(l)[i] gives
    [x[i] for x in l]; stableSort(), stableDSU() and takeWeighted() in this
    module use it that way.
    """
    def __init__(self, l):
        self.l = l                    # the wrapped sequence

    # Defined explicitly so numeric coercion does not go through
    # __getattr__ below.  Returning None signals that no coercion is possible.
    def __coerce__(self, other):
        return None

    def __getattr__(self, name):
        # Calls the named method on every element with the same arguments.
        def wrapper(*args, **kwargs):
            return [getattr(x, name)(*args, **kwargs) for x in self.l]
        # Refuse __iter__ so iter() does not pick up the forwarding wrapper,
        # which returns a list rather than an iterator.
        if name == '__iter__':
            raise AttributeError
        if not self.l:
            return lambda *args: []
        elif callable(getattr(self.l[0],name)):
            return wrapper
        else:
            return [getattr(x,name) for x in self.l]

    def __repr__(self):
        return '<each %r>' % self.l

class Zero:
    """
    Null object that is false in a boolean context and returns False when
    called with any positional arguments.
    """
    def __nonzero__(self):
        return False

    def __call__(self, *ignored):
        return self.__nonzero__()

## Needs more work
class any:
    """
    Proxy asking whether *any* element of a sequence satisfies something.

    any(l).method(args) -> True if method(args) is true for some element
                           (tried in order, stopping at the first true
                           result), else False.
    any(l).attribute    -> True if the attribute is true for some element.
    Whether an attribute is treated as a method is decided from l[0] only.
    For an empty l, a Zero() object is returned (false, and False when
    called).

    NOTE: this module later defines a function also named ``any`` (GvR's
    recipe), which replaces this class in the module namespace.
    """
    def __init__(self, l):
        self.l = l

    # Defined explicitly so numeric coercion does not go through __getattr__.
    def __coerce__(self, other):
        return None

    def __getattr__(self, name):
        def wrapper(*args, **kwargs):
            for x in self.l:
                # Call each method exactly once; there is no debug output.
                if getattr(x, name)(*args, **kwargs):
                    return True
            return False

        if not self.l:
            return Zero()

        elif callable(getattr(self.l[0],name)):
            return wrapper
        else:
            for x in self.l:
                if getattr(x,name):
                    return True
            return False

    def __repr__(self):
        return '<any %r>' % self.l

def oneArgument(args, what= ' '):
    """
    Classical one_argument function: split args at the first occurrence of
    what into (firstArg, rest).

    Returns ('', '') for empty args and (args, '') when what does not
    occur.  Otherwise returns the two-element list from
    args.split(what, 1).
    """
    if not args:
        return '',''
    if args.find(what) < 0:
        return args, ''
    return args.split(what, 1)

def createMatrix(default, *args):
    """
    Create a nested-list matrix whose dimensions are given by args, e.g.
    createMatrix(0, 2, 3) -> [[0, 0, 0], [0, 0, 0]].

    Every cell holds a shallow copy (copy.copy) of default, so mutable
    defaults are not shared between cells.
    """
    size = args[0]
    if len(args) == 1:
        return [copy.copy(default) for i in range(size)]
    return [createMatrix(default, *args[1:]) for i in range(size)]

def stableSort(l, compare=cmp):
    """
    Return a sorted copy of l, ordered by the cmp-style function compare.

    The sort is stable: elements that compare equal keep their original
    relative order, because ties are broken by original position.
    l itself is not modified.
    """
    def byValueThenPosition(a, b):
        return compare(a[0], b[0]) or cmp(a[1], b[1])
    decorated = indices(l)
    decorated.sort(byValueThenPosition)
    return [pair[0] for pair in decorated]

def stableDSU(l, decorator):
    """
    Return a new list of l's elements sorted by decorator(element).

    Elements are decorated as (key, original position, element), so equal
    keys keep their original order and the elements themselves are never
    compared.
    """
    triples = []
    for item, position in indices(l):
        triples.append((decorator(item), position, item))
    triples.sort()
    return [triple[2] for triple in triples]

def fastDSU(l, decorator):
    """
    Return a new list of l's elements sorted by decorator(element)
    (decorate-sort-undecorate).  decorator is called once per element.

    Elements with equal keys are ordered by comparing the elements
    themselves, so they must be mutually comparable in that case.  Use
    stableDSU() to keep their original order instead.
    """
    decorated = [(decorator(x), x) for x in l]
    decorated.sort()
    return [x for key, x in decorated]

def splitOn(l, pred):
    """
    Partition l into (matching, nonMatching) lists according to pred.

    l is traversed once and pred is called exactly once per element, so l
    may be an iterator and pred may have side effects.
    """
    matching = []
    rest = []
    for x in l:
        if pred(x):
            matching.append(x)
        else:
            rest.append(x)
    return matching, rest

def groupStride(seq, stride):
    """
    Split the sequence seq into consecutive slices of length stride (the
    last one may be shorter): groupStride([1,2,3,4,5], 2) ->
    [[1,2],[3,4],[5]].  Works for lists, tuples and strings.

    Raises ValueError if stride < 1, which could never make progress.
    """
    if stride < 1:
        raise ValueError("groupStride: stride must be >= 1, got %r" % (stride,))
    return [seq[i:i + stride] for i in range(0, len(seq), stride)]
        

def quote(s):
    """
    HTML-quote s for use in markup and double-quoted attribute values.

    None becomes ''; anything else is converted with str() first.  The
    characters & # < > " ( ) are replaced by entity or character references.
    Single quotes are left alone.
    """
    entities = {'&': '&amp;', '#': '&#35;', '<': '&lt;', '>': '&gt;',
                '"': '&quot;', '(': '&#40;', ')': '&#41;'}
    if s is None:
        s = ''
    text = str(s)
    return ''.join([entities.get(ch, ch) for ch in text])

def filterURL(s):
    """
    Return s with every character removed except ASCII letters, digits and
    '/', '.', ':', '-'.

    Uses string.ascii_letters, not the locale-dependent string.letters, so
    the result does not depend on the current locale.  The allowed set is
    built once per call rather than once per character.
    """
    allowed = string.ascii_letters + string.digits + '/.:-'
    return ''.join([c for c in s if c in allowed])

def flatten(l):
    """
    Return a flat list of the items of l, recursively expanding nested
    lists and tuples (strings and other sequences are kept as items).
    """
    flat = []
    for item in l:
        if not isinstance(item, (list, tuple)):
            flat.append(item)
            continue
        flat.extend(flatten(item))
    return flat

def cartesianProduct(sequence):
    """
    Generate every combination taking one element from each member of
    sequence, as lists, in order:
    cartesianProduct([[1, 2], 'ab']) yields [1,'a'], [1,'b'], [2,'a'], [2,'b'].

    An empty sequence yields one empty list, and an empty member yields
    nothing.  Each member is read once (when iteration starts), so members
    may be one-shot iterators.
    """
    pools = [tuple(member) for member in sequence]
    for combination in _cartesianPools(pools, 0):
        yield combination

def _cartesianPools(pools, start):
    "Yield the products of pools[start:] as lists (helper for cartesianProduct)."
    if start == len(pools):
        yield []
    else:
        for x in pools[start]:
            for rest in _cartesianPools(pools, start + 1):
                yield [x] + rest
    

def combinations(sequence, number):
    """Generate all combinations of NUMBER elements from list SEQUENCE.

    Each combination is a new list whose elements keep their order in
    SEQUENCE.  NUMBER == 0 yields one empty list; a NUMBER below 0 or
    above len(SEQUENCE) yields nothing.
    """
    # Adapted from Python 2.2 `test/test_generators.py'.
    # By Francois Pinard
    if number < 0 or number > len(sequence):
        return
    if number == 0:
        yield []
    else:
        first, remainder = sequence[0], sequence[1:]
        # Some combinations retain FIRST.
        for result in combinations(remainder, number-1):
            result.insert(0, first)
            yield result
        # Some combinations do not retain FIRST.
        for result in combinations(remainder, number):
            yield result

def setCompare(s1,s2):
    """
    Compare two collections of hashable items as sets.

    Returns (removed, same, added): the items only in s1, the items in
    both, and the items only in s2.  Duplicates are collapsed, and the order
    within each list is arbitrary.
    """
    before, after = makedict(s1), makedict(s2)
    beforeKeys = before.keys()
    removed = [x for x in beforeKeys if x not in after]
    same = [x for x in beforeKeys if x in after]
    added = [x for x in after.keys() if x not in before]
    return removed, same, added

_isatty = None      # cached "is stdout a terminal?" answer for showProgress
def showProgress(what, cur, max, _start={}):
    """
    Redraw a 50-column text progress bar ("what: ****....") on the current
    stdout line.  An ETA is appended once more than 5 seconds have passed.
    Does nothing unless stdout is a terminal (checked once, then cached in
    the module global _isatty).

    what   -- label; together with max it identifies the task
    cur    -- units done so far; cur == 0 (re)starts the task's clock
    max    -- total units; max <= 0 is drawn as a full bar
    _start -- private cache of start times keyed by (what, max).  The
              shared mutable default is deliberate: it persists across calls.
    """
    global _isatty
    if _isatty is None:
        try:
            _isatty = os.isatty(sys.stdout.fileno())
        except (AttributeError, ValueError, IOError, OSError):
            _isatty = False     # stdout is not a real file (e.g. StringIO)
    if not _isatty:
        return

    if cur == 0:
        stime = _start[what, max] = time.time()
    else:
        try:
            stime = _start[what, max]
        except KeyError:
            stime = _start[what, max] = time.time()

    if max > 0:
        dots = int((50 * cur) // max)
    else:
        dots = 50
    sys.stdout.write('\r%s: %s%s' % (what, '*' * dots, '.' * (50 - dots)))

    passed = time.time() - stime
    if passed > 5:
        # Remaining time extrapolated from the rate so far, rounded up to
        # the next multiple of 5 seconds.
        eta = (1 + int((passed / cur) * (max - cur)) // 5) * 5
        sys.stdout.write(' ETA %02d:%02d' % (eta // 60, eta % 60))

    sys.stdout.flush()

class NullFile:
    """
    A write-only file-like object that discards everything written to it,
    e.g. to silence output by assigning an instance to sys.stdout.

    flush() and writelines() are provided too, because code such as
    showProgress() and traceit() calls sys.stdout.flush().
    """
    def write(self, data):
        pass

    def writelines(self, lines):
        pass

    def flush(self):
        pass
    
def callableDescr(c):
    """
    Return a short description of the callable c for use in messages:

    - objects with an im_class attribute (Python 2 methods):
      'ClassName.name'
    - plain functions: 'module.name'
    - anything else: the name of its class (for a class object that is its
      metaclass, e.g. 'type'), or repr(c) if that fails.
    """
    # NOTE(review): the original author marked this "XXX BROKEN" without
    # saying what is wrong; the dispatch above has not been redesigned.
    if hasattr(c, 'im_class'):
        return '%s.%s' % (c.im_class.__name__, c.__name__)
    elif isinstance(c, types.FunctionType):
        return '%s.%s' % (c.__module__, c.__name__)
    else:
        try:
            return c.__class__.__name__
        except Exception:
            return repr(c)

_klasses = {}   # base type -> cached "Annotated" subclass used by annotate()
def annotate(val, **args):
    """
    Return a copy of val that carries the keyword arguments as attributes,
    e.g. x = annotate(5, unit='cm') gives x == 5 and x.unit == 'cm'.

    The copy is an instance of a subclass of type(val).  The subclass is
    created on first use for each type and cached in _klasses.  type(val)
    must be subclassable; presumably this is intended for immutable
    builtins such as int, float and str (TODO confirm for other types).
    """
    base = type(val)
    klass = _klasses.get(base)
    if klass is None:
        class Annotated(base):
            def __new__(cls, val, **args):
                v = base.__new__(cls, val)
                v.__dict__.update(args)
                return v
        _klasses[base] = klass = Annotated
    # Construct outside any KeyError handler, so errors raised during
    # construction are not mistaken for a cache miss.
    return klass(val, **args)
    

CommentRx = re.compile('#.*')   # from '#' to end of line (not across '\n')
def linefilter(f):
    """
    Yield the non-empty lines of f (any iterable of strings, e.g. an open
    file), with '#' comments removed and surrounding whitespace stripped.

    Everything from the first '#' on a line is treated as a comment, even
    inside quotes.
    """
    for line in f:
        # Remove the comment before stripping, so whitespace preceding the
        # '#' is not left at the end of the line.
        line = CommentRx.sub('', line).strip()
        if line:
            yield line
            
def takeWeighted(count, l):
    """
    Pick up to count distinct entries from l at random, without replacement.

    Each entry is a sequence whose first item is its integer weight, e.g.
    (weight, value).  An entry is chosen with probability proportional to
    its weight among the entries still left, so zero-weight entries are
    never chosen.

    Returns the chosen entries in the order they were picked.  Stops early
    when l runs out or the remaining weights sum to zero.  l itself is not
    modified.
    """
    pool = list(l)
    winners = []
    for i in range(count):
        total = 0
        for entry in pool:
            total += entry[0]
        if total <= 0:
            break
        num = random.randrange(total)       # 0 <= num < total
        for index in range(len(pool)):
            weight = pool[index][0]
            # num falls in this entry's interval [0, weight).
            if num < weight:
                break
            num -= weight
        winners.append(pool[index])
        del pool[index]
    return winners

        
# Source: Dalke
def traceit(frame, event, arg):
    """
    Trace function for sys.settrace().

    For every 'line' event it prints "module:lineno: source line" to stdout.
    It returns itself, so it also serves as the local trace function and
    keeps receiving 'line' events.

    The source file is the module's __file__ (with a .pyc/.pyo suffix mapped
    back to .py).  When the module has no __file__ (e.g. code run with
    exec), the code object's filename is used instead.
    """
    if event == "line":
        lineno = frame.f_lineno
        filename = frame.f_globals.get("__file__") or frame.f_code.co_filename
        if (filename.endswith(".pyc") or
            filename.endswith(".pyo")):
            filename = filename[:-1]
        name = frame.f_globals.get("__name__", "?")
        line = linecache.getline(filename, lineno)
        sys.stdout.write("%s:%s: %s\n" % (name, lineno, line.rstrip()))
        sys.stdout.flush()
    return traceit

def nullfun(*args, **kw):
    "Accept any arguments, do nothing and return None."
    return None

# GvR
def any(S):
    """
    Return the first true element of S, or None if there is none.
    Iteration stops at the first true element.  Note that this returns the
    element itself, not True/False.
    """
    for element in S:
        if not element:
            continue
        return element
    return None

def all(S):
    """
    Return True if every element of S is true (True for an empty S).
    Iteration stops at the first false element.
    """
    for element in S:
        if element:
            continue
        return False
    return True
    
class ddict(dict):
    """
    Dictionary whose missing keys read as a default value, Perl-style,
    suitable for e.g. d[x] += 1.

    NOTE: a missing key is not stored when read, so the same default object
    is returned each time.  With a mutable default (e.g. []), in-place
    changes such as d[x].append(1) modify that shared default and are not
    stored under x.  Only item access (d[x]) is affected; get(), 'in', etc.
    behave like a plain dict.
    """
    def __init__(self, default=0, **keys):
        dict.__init__(self, **keys)
        self.default = default

    def __getitem__(self, item):
        if item in self:
            return dict.__getitem__(self, item)
        return self.default
        
