import importlib
import os
import re
import types
import unittest

class FunctionTest(unittest.TestCase):
    """Runs a single function as a test. Produces output that looks
    a little different than that of unittest's own FunctionTestCase."""
    
    def __init__(self, func, *args):
        """@param func:  function to run as a test."""
        self.func = func
        self.args = args
        unittest.TestCase.__init__(self)

    def shortDescription(self):
        """Uses the function's doc string as the ideal form. Otherwise,
        the name is used."""
        if self.func.__doc__:
            name = self.func.__doc__
        else:
            name = self.func.func_name
        return "%s (%s)" % (name, self.func.__module__)

    def runTest(self):
        self.func(*self.args)
        
class SetupFunction(FunctionTest):
    """Wraps a module-level setup or teardown hook so it can be scheduled
    inside a TestSuite like an ordinary test, while reporting zero from
    countTestCases() so it is left out of the suite's case total."""

    def countTestCases(self):
        """A setup/teardown hook is not itself a test case, so it counts 0."""
        hook_case_count = 0
        return hook_case_count

def cmpFunctionLocation(a, b):
    """Compares two functions' line numbers to ensure tests run
    in file order.

    Returns -1, 0, or 1 when ``a`` starts on an earlier, the same, or a
    later source line than ``b`` (old-style cmp contract)."""
    # __code__ exists on Python 2.6+ and 3; func_code and the builtin cmp()
    # are Python 2 only.
    line_a = a.__code__.co_firstlineno
    line_b = b.__code__.co_firstlineno
    return (line_a > line_b) - (line_a < line_b)

class Collector(unittest.TestSuite):
    """A TestSuite that is automatically built from tests in the supplied
    package.
    """
    regex = None

    def __init__(self, packagename, regex=None):
        """@packagename  a string name for the package to look up and search
        
        @regex  a regular expression to match against filenames and test names"""
        super(Collector, self).__init__()
        
        if regex:
            self.regex = re.compile(regex)
            
        mod = __import__(packagename, dict(), dict())
        packagedir = os.path.abspath(os.path.dirname(mod.__file__))
        
        for root, dirs, files in os.walk(packagedir):
            if ".svn" in root or ".cvs" in root:
                continue
                
            currentdir = root[len(packagedir)+1:]
            if currentdir:
                currentpackage = "%s.%s" % (packagename, 
                    currentdir.replace(os.path.sep, "."))
            else:
                currentpackage = packagename
                
            self._processFiles(files, currentpackage)
            
    def _processFiles(self, files, currentpackage):
            regex = self.regex
            
            for file in files:
                ismatch = False
                if regex:
                    if regex.search(file):
                        ismatch = True
                if file.startswith("test_"):
                    filename, ext = os.path.splitext(file)
                    if ext != ".py":
                        continue
                    self._processModule(filename, currentpackage, 
                        ismatch)
                        
    def _processModule(self, filename, currentpackage, ismatch):
        regex = self.regex
        
        mod = __import__("%s.%s" % (currentpackage, filename),
            dict(), dict(), currentpackage)
        candidates = dir(mod)
        
        if hasattr(mod, "setup_module"):
            self.addTest(SetupFunction(mod.setup_module, mod))
        addInOrder = []
        for candidatename in candidates:
            if regex and not ismatch and not regex.search(candidatename):
                continue
            candidate = getattr(mod, candidatename)
            
            if type(candidate) == types.TypeType and \
                issubclass(candidate, unittest.TestCase):
                loader = unittest.defaultTestLoader
                self.addTest(loader.loadTestsFromTestCase(
                    candidate))
            elif type(candidate) == types.FunctionType:
                if not candidatename.lower().startswith("test"):
                    continue
                addInOrder.append(candidate)
                
        addInOrder.sort(cmpFunctionLocation)
        for func in addInOrder:
            self.addTest(FunctionTest(func))
        
        if hasattr(mod, "teardown_module"):
            self.addTest(SetupFunction(mod.teardown_module, mod))


# syntax highlighted by Code2HTML, v. 0.9.1