import unittest
import sys
import time

import parens as parens

if sys.hexversion < 0x02050000:
    # Refuse to run on interpreters older than 2.5. Report on stderr and exit
    # with a non-zero status so shells/CI treat this as a failure (a bare
    # sys.exit() would exit with status 0). sys.stderr.write works on every
    # Python version, unlike the Python-2-only print statement.
    sys.stderr.write("6.006 code was designed for Python 2.5, and you are running " +
                     "an older version. http://python.org/download\n")
    sys.exit(1)

def flatten(collection):
    """Return the leaves of an arbitrarily nested collection as a flat list.

    Every element that defines ``__iter__`` and is not a string is recursed
    into; everything else is kept as a leaf. Leaves keep their left-to-right
    order.
    """
    # basestring exists only on Python 2; on Python 3 str/bytes define
    # __iter__, so they must be excluded explicitly or they would be split
    # into characters.
    try:
        string_types = basestring
    except NameError:
        string_types = (str, bytes)

    result = []

    def _walk(items):
        # Append leaves straight into the shared result list instead of
        # building and re-copying a sub-list at every nesting level.
        for item in items:
            if hasattr(item, "__iter__") and not isinstance(item, string_types):
                _walk(item)
            else:
                result.append(item)

    _walk(collection)
    return result

class TestParens(unittest.TestCase):
    """Checks parens.solve on matrix-chain multiplication instances.

    An instance is a list of dimensions d; matrix k has shape d[k] x d[k+1].
    A solution is a nested binary grouping of the matrix indices. It passes
    when its leaves are exactly 0..n-1 in order and its total
    scalar-multiplication cost equals the expected optimum.
    """

    def test1_tiny(self):
        """tiny tests"""
        self.paren_test([100, 1000, 10], 1000000)
        self.paren_test([5, 9, 3], 135)

    def test2_small(self):
        """small tests"""
        self.paren_test([1, 2, 3, 4, 5], 38)
        self.paren_test([5, 4, 3, 2, 1], 38)

    def test3_medium(self):
        """medium tests"""
        self.paren_test([10, 2, 11, 3, 12, 4, 13, 5, 14, 6], 896)
        self.paren_test([61, 8, 25, 5, 39, 4, 21, 4, 6], 5476)
        self.paren_test([100, 5, 3, 5, 20, 5, 3, 25, 95], 37790)

    def test4_large(self):
        """large tests"""
        self.file_test('parens_20', 18469330)
        self.file_test('parens_40', 12146250)
        self.file_test('parens_60', 3214978)
        self.file_test('parens_80', 4333337)
        self.file_test('parens_100', 48534520)

    def test5_huge(self):
        """huge test"""
        self.file_test('parens_300', 72331164)

    def file_test(self, file_name, best_cost):
        """tests a file-based instance of the matrix chain multiplication problem"""
        # File layout: a count on the first line, then that many integer
        # dimensions, one per line. The path is relative to the current
        # working directory.
        s_file = open('data/' + file_name)
        try:
            length = int(s_file.readline())
            array = [int(s_file.readline()) for _ in range(length)]
        finally:
            # Always release the handle. try/finally (rather than `with`)
            # keeps this usable on Python 2.5 without a __future__ import.
            s_file.close()
        self.paren_test(array, best_cost)

    def paren_test(self, dimensions, best_cost):
        """tests an instance of the matrix chain multiplication problem"""
        # paren_sizes[k] == [rows, cols] of matrix k.
        paren_sizes = [[dimensions[i - 1], dimensions[i]] for i in range(1, len(dimensions))]
        answer = parens.solve(paren_sizes)

        # answer should contain all terms in order. Wrapping it in a list
        # means a bare (non-iterable) answer such as a single index is also
        # handled; flatten() would otherwise try to iterate it.
        flat_answer = flatten([answer])
        self.assertEqual(len(paren_sizes), len(flat_answer))
        for expected, actual in enumerate(flat_answer):
            self.assertEqual(expected, actual)

        # compute cost
        def multiplication_cost(start1, end1, start2, end2):
            # Multiplying the product of matrices start1..end1 by the product
            # of start2..end2 costs rows(start1) * shared * cols(end2). The
            # shared dimension is read as rows(start2); end1 is unused.
            return paren_sizes[start1][0] * paren_sizes[start2][0] * paren_sizes[end2][1]

        def xcost(subanswer):
            # Returns [cost, first_index, last_index] for a subtree of answer.
            # Tuples count as groups too, matching flatten(), which recurses
            # into any non-string iterable.
            if not isinstance(subanswer, (list, tuple)):
                return [0, subanswer, subanswer]
            # make sure multiplication order is well determined
            self.assertEqual(2, len(subanswer))
            xc1 = xcost(subanswer[0])
            xc2 = xcost(subanswer[1])
            return [xc1[0] + xc2[0] + multiplication_cost(xc1[1], xc1[2], xc2[1], xc2[2]),
                    xc1[1], xc2[2]]

        answer_cost = xcost(answer)[0]
        self.assertEqual(best_cost, answer_cost)
        
if __name__ == '__main__':
    # Use the module-level `sys` import. `unittest.sys` is not a public
    # attribute and is absent once unittest became a package (2.7+/3.x).
    # '--verbose' goes right after the program name: getopt-based argv
    # parsing stops at the first non-option, so appending it after
    # user-supplied test names would make it be read as a test name.
    unittest.main(argv=sys.argv[:1] + ['--verbose'] + sys.argv[1:])
