#!/usr/bin/python

# osh
# Copyright (C) 2005 Jack Orenstein <jao@geophile.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

import traceback

from osh.oshapi import *

o = None
e = None

def save(L, x):
    """Append x to the list L.

    Always returns None, so it can be used as the body of a lambda that
    records a value without producing a result of its own.
    """
    L.append(x)

def output_handlers():
    """Reset the module-level capture lists and return handlers that fill them.

    Rebinds the globals o and e to fresh empty lists, then returns a dict
    with an 'o' entry and an 'e' entry, each an f command that records what
    it receives via save(). The 'o' handler stores its arguments as one
    tuple; the 'e' handler stores a (command, input, message) tuple.

    The lambdas look up o and e when they are invoked, so they record into
    whatever lists those globals refer to at that moment.
    """
    global o, e
    o, e = [], []
    record_output = f(lambda *x: save(o, x))
    record_error = f(lambda command, input, message: save(e, (command, input, message)))
    return {'o': record_output, 'e': record_error}

def smoketest(label, pipeline, expected_output):
    """Run one osh pipeline and report any mismatch with the expected output.

    label           -- text printed before the pipeline runs.
    pipeline        -- sequence of osh commands; it is not modified.
    expected_output -- list the captured output (global o) must equal.

    The capturing handlers from output_handlers() are added as the final
    pipeline stage. Anything captured in the global e is printed, and a
    mismatch between o and expected_output is printed as an
    expected/actual pair. Nothing is returned.

    An exception raised while running or checking the pipeline is reported
    with a traceback on stderr and does not propagate, so later tests
    still run. KeyboardInterrupt and SystemExit are not intercepted.
    """
    print(label)
    # Build a new list instead of appending to the caller's, so a pipeline
    # object can be passed in more than once. output_handlers() also resets
    # the o and e globals for this run.
    commands = list(pipeline) + [output_handlers()]
    try:
        osh(*commands)
        if e:
            print(e)
        if o != expected_output:
            print('expected: %s' % str(expected_output))
            print('actual:   %s' % str(o))
    except Exception:
        # print_exc() writes to sys.stderr by default, so this block does
        # not need the name sys to be in scope.
        traceback.print_exc()
    
# gen on its own: each generated value is captured as a 1-tuple.
smoketest(label='gen and out',
          pipeline=[gen(3)],
          expected_output=[(0,), (1,), (2,)])

# f applied to every generated value, with the function given first as a
# lambda and then as a string.
smoketest(label='f (lambda)',
          pipeline=[gen(3), f(lambda x: x * 10)],
          expected_output=[(0,), (10,), (20,)])
smoketest(label='f (string)',
          pipeline=[gen(3), f('x: x * 10')],
          expected_output=[(0,), (10,), (20,)])

# select keeps only the values for which the predicate holds (odd numbers
# here), with the predicate given as a lambda and as a string.
smoketest(label='select (lambda)',
          pipeline=[gen(4), select(lambda x: x % 2 == 1)],
          expected_output=[(1,), (3,)])
smoketest(label='select (string)',
          pipeline=[gen(4), select('x: x % 2 == 1')],
          expected_output=[(1,), (3,)])

# agg: a running sum over the whole stream, then per-group sums using
# group(...) and consecutive(...). The expected group keys 0, 1, 2 assume
# that x / 2 is integer division.
smoketest(label='agg (lambda)',
          pipeline=[gen(5), agg(0, lambda sum, x: sum + x)],
          expected_output=[(10,)])
smoketest(label='agg (string)',
          pipeline=[gen(5), agg(0, 'sum, x: sum + x')],
          expected_output=[(10,)])
smoketest(label='agg group (lambda)',
          pipeline=[gen(5),
                    f(lambda x: (x / 2, x)),
                    agg(group(lambda halfx, x: halfx), 0, lambda sum, halfx, x: sum + x)],
          expected_output=[(0, 1), (1, 5), (2, 4)])
smoketest(label='agg group (string)',
          pipeline=[gen(5),
                    f('x: (x / 2, x)'),
                    agg(group('halfx, x: halfx'), 0, 'sum, halfx, x: sum + x')],
          expected_output=[(0, 1), (1, 5), (2, 4)])
smoketest(label='agg consecutive (lambda)',
          pipeline=[gen(5),
                    f(lambda x: (x / 2, x)),
                    agg(consecutive(lambda halfx, x: halfx), 0, lambda sum, halfx, x: sum + x)],
          expected_output=[(0, 1), (1, 5), (2, 4)])
smoketest(label='agg consecutive (string)',
          pipeline=[gen(5),
                    f('x: (x / 2, x)'),
                    agg(consecutive('halfx, x: halfx'), 0, 'sum, halfx, x: sum + x')],
          expected_output=[(0, 1), (1, 5), (2, 4)])

# sort with a key function that negates the value, giving descending order;
# the key is supplied as a lambda and as a string.
smoketest(label='sort (lambda)',
          pipeline=[gen(5), sort(lambda x: -x)],
          expected_output=[(4,), (3,), (2,), (1,), (0,)])
smoketest(label='sort (string)',
          pipeline=[gen(5), sort('x: -x')],
          expected_output=[(4,), (3,), (2,), (1,), (0,)])

# expand turns one sequence into one output per element; with a position
# argument, only the element at that position of the input is expanded.
smoketest(label='expand',
          pipeline=[gen(4), f(lambda x: [x] * x), expand()],
          expected_output=[(1,), (2,), (2,), (3,), (3,), (3,)])
smoketest(label='expand (position)',
          pipeline=[gen(3), f(lambda x: (x, (5, 6))), expand(1)],
          expected_output=[(0, 5), (0, 6), (1, 5), (1, 6), (2, 5), (2, 6)])

# squish with no arguments, with a single operator, and with several
# operators applied column by column.
smoketest(label='squish (0-1)',
          pipeline=[gen(9), window(disjoint(3)), squish(), squish('+')],
          expected_output=[(3,), (12,), (21,)])
smoketest(label='squish (>1)',
          pipeline=[gen(27), window(disjoint(3)), squish(), window(disjoint(3)), squish('+ max min')],
          expected_output=[(9, 7, 2), (36, 16, 11), (63, 25, 20)])

# window: disjoint and overlapping fixed-size windows, and windows started
# by a predicate (given as a lambda and as a string).
# NOTE(review): 'window (default)' runs exactly the same pipeline as
# 'window (disjoint)', so it does not exercise a distinct default form --
# confirm what the intended default invocation is.
smoketest(label='window (default)',
          pipeline=[gen(9), window(disjoint(3)), squish()],
          expected_output=[(0, 1, 2), (3, 4, 5), (6, 7, 8)])
smoketest(label='window (disjoint)',
          pipeline=[gen(9), window(disjoint(3)), squish()],
          expected_output=[(0, 1, 2), (3, 4, 5), (6, 7, 8)])
smoketest(label='window (overlap)',
          pipeline=[gen(9), window(overlap(3)), squish()],
          expected_output=[(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6), (5, 6, 7), (6, 7, 8), (7, 8, None), (8, None, None)])
smoketest(label='window (predicate function)',
          pipeline=[gen(9), window(lambda x: x % 3 == 0), squish()],
          expected_output=[(0, 1, 2), (3, 4, 5), (6, 7, 8)])
smoketest(label='window (predicate string)',
          pipeline=[gen(9), window('x: x % 3 == 0'), squish()],
          expected_output=[(0, 1, 2), (3, 4, 5), (6, 7, 8)])

# smoketest captures the output of the last out commands, which is the input
# object rather than the formatted object, so this test does not work and is
# left disabled.
## smoketest('relabel',
##           [gen(2),
##            f(lambda x: x / (x - 1)),
##            {'o': f(lambda x: 0),
##             'e': f(lambda *x: 1)},
##            {'o': out('OUT: %s'),
##             'e': out('ERR: %s')}],
##           ['OUT: 0', 'OUT: 1'])

# unique: each value x is repeated x times and expanded, then duplicates
# are removed -- once with plain unique() and once with
# unique(consecutive()). The final sort() fixes the order being compared.
smoketest(label='unique',
          pipeline=[gen(3, 1),
                    f(lambda x: [x for i in range(x)]),
                    expand(),
                    unique(),
                    sort()],
          expected_output=[(1,), (2,), (3,)])
smoketest(label='unique (consecutive)',
          pipeline=[gen(3, 1),
                    f(lambda x: [x for i in range(x)]),
                    expand(),
                    unique(consecutive()),
                    sort()],
          expected_output=[(1,), (2,), (3,)])

# Hmm. Testing stdin is tricky.
# ./smoketest_api "echo 'abc' | osh ^ f 's: [c for c in s]' $" "[['a', 'b', 'c']]"

# sh runs a shell command; its output line is then split into characters.
smoketest(label='sh',
          pipeline=[sh('echo abc'), f(lambda s: [c for c in s])],
          expected_output=[('a', 'b', 'c')])

# print 'n()'
# ./smoketest_api "osh gen 3 ^ f 'x: (x, n(), n())' $" "[(0, 0, 1), (1, 2, 3), (2, 4, 5)]"

# print 'pipeline'
# ./smoketest_api "osh [ gen 3 ] ^ f 'x: (x, n(), n())' $" "[(0, 0, 1), (1, 2, 3), (2, 4, 5)]"
# ./smoketest_api "osh gen 3 ^ [ f 'x: (x, n(), n())' ] $" "[(0, 0, 1), (1, 2, 3), (2, 4, 5)]"
# ./smoketest_api "osh [ gen 3  ^ f 'x: (x, n(), n())' ] $" "[(0, 0, 1), (1, 2, 3), (2, 4, 5)]"
# ./smoketest_api "osh [ gen 3  ^ [ f 'x: (x, n(), n())' ] ] $" "[(0, 0, 1), (1, 2, 3), (2, 4, 5)]"

# Error handling

def collect(out):
    """Return a one-argument function that appends its argument to out.

    The returned function passes back the result of out.append(x), which
    is None when out is a list.
    """
    def append_to_out(x):
        return out.append(x)
    return append_to_out

# With no error handler installed, the ZeroDivisionError raised by
# x / (x - 1) at x == 1 is expected to propagate out of osh(). Only the
# value produced before the failure (0) should have been collected.
print('error handling (default)')
o = []
try:
    osh(gen(3),
        f(lambda x: x / (x - 1)),
        f(collect(o)))
except ZeroDivisionError:
    # This is the expected outcome, so there is nothing to report.
    pass
else:
    print('ZeroDivisionError should have been thrown')
if o != [0]:
    print('o is wrong: %s' % o)

# With explicit 'o' and 'e' handlers, the division by zero at x == 1 must
# not escape osh(): the two good results go to o and a single OshError is
# delivered to e.
print('error handling (overridden)')
o, e = [], []
try:
    osh(gen(3),
        f(lambda x: x / (x - 1)),
        {'o': f(collect(o)),
         'e': f(collect(e))})
except ZeroDivisionError:
    print("ZeroDivisionError shouldn't have happened.")
if o != [0, 2]:
    print('o is wrong: %s' % o)
if len(e) != 1 or not isinstance(e[0], OshError):
    print('e is wrong: %s' % e)
