# Loop:
#     Optional:
#         # Set up a new base classifier for testing.
#         new_classifier(), or set_classifier()
#     # Run tests against (possibly variants of) this classifier.
#     Loop:
#         Loop:
#             Optional:
#                 # train on more ham and spam
#                 train(ham, spam)
#             Optional:
#                 # Forget training for some subset of ham and spam.
#                 untrain(ham, spam)
#         # Predict against other data.
#         Loop:
#             test(ham, spam)
#         # Display stats against all runs on this classifier variant.
#         # This also saves the trained classifier, if desired (option
#         # save_trained_pickles).
#         finishtest()
# # Display stats against all runs.
# alldone()

try:
    from sets import Set
except ImportError:
    from spambayes.compatsets import Set

import cPickle as pickle

try:
    from heapq import heapreplace
except ImportError:
    from spambayes.compatheapq import heapreplace

from spambayes.Options import options
from spambayes import Tester
from spambayes import classifier
from spambayes.Histogram import Hist

# Compatibility shim: define True/False on interpreters that lack the
# boolean builtins (the original comment names Python 2.2).
try:
    # Merely evaluating the names raises NameError if they do not exist.
    True, False
except NameError:
    # Maintain compatibility with Python 2.2 by binding module-level
    # stand-ins with the same integer values.
    True, False = 1, 0


def printhist(tag, ham, spam, nbuckets=options["TestDriver", "nbuckets"]):
    print
    print "-> <stat> Ham scores for", tag,
    ham.display(nbuckets)

    print
    print "-> <stat> Spam scores for", tag,
    spam.display(nbuckets)

    if not options["TestDriver", "compute_best_cutoffs_from_histograms"]:
        return
    if ham.n == 0 or spam.n == 0:
        return

    # Figure out "the best" ham & spam cutoff points, meaning the ones that
    # minimize
    #    num_fp * fp_weight + num_fn + fn_weight + num_unsure * unsure_weight
    # the total number of misclassified msgs (other definitions are
    # certainly possible!).

    # At cutoff 0, everything is called spam, so there are no false negatives,
    # and every ham is a false positive.
    assert ham.nbuckets == spam.nbuckets
    n = ham.nbuckets
    FPW = options["TestDriver", "best_cutoff_fp_weight"]
    FNW = options["TestDriver", "best_cutoff_fn_weight"]
    UNW = options["TestDriver", "best_cutoff_unsure_weight"]

    # Get running totals:  {h,s}total[i] is # of ham/spam below bucket i
    htotal = [0] * (n+1)
    stotal = [0] * (n+1)
    for i in range(1, n+1):
        htotal[i] = htotal[i-1] + ham.buckets[i-1]
        stotal[i] = stotal[i-1] + spam.buckets[i-1]
    assert htotal[-1] == ham.n
    assert stotal[-1] == spam.n

    best_cost = 1e200   # infinity
    bests = []          # best h and s cutoffs

    for h in range(n+1):
        num_fn = stotal[h]
        fn_cost = num_fn * FNW
        for s in xrange(h, n+1):
            # ham  0:h  correct
            #      h:s  unsure
            #      s:   FP
            # spam 0:h  FN
            #      h:s  unsure
            #      s:   correct
            num_fp = htotal[-1] - htotal[s]
            num_un = htotal[s] - htotal[h] + stotal[s] - stotal[h]
            cost = num_fp * FPW + fn_cost + num_un * UNW
            if cost <= best_cost:
                if cost < best_cost:
                    best_cost = cost
                    bests = []
                bests.append((h, s))

    print '-> best cost for %s $%.2f' % (tag, best_cost)
    print '-> per-fp cost $%.2f; per-fn cost $%.2f; per-unsure cost $%.2f' % (
          FPW, FNW, UNW)

    if len(bests) > 1:
        print '-> achieved at', len(bests), 'cutoff pairs'
        info = [('smallest ham & spam cutoffs', bests[0]),
                ('largest ham & spam cutoffs', bests[-1])]
    else:
        info = [('achieved at ham & spam cutoffs', bests[0])]

    for tag, (h, s) in info:
        print '-> %s %g & %g' % (tag, float(h)/n, float(s)/n)
        num_fn = stotal[h]
        num_fp = htotal[-1] - htotal[s]
        num_unh = htotal[s] - htotal[h]
        num_uns = stotal[s] - stotal[h]
        print '->     fp %d; fn %d; unsure ham %d; unsure spam %d' % (
              num_fp, num_fn, num_unh, num_uns)
        print '->     fp rate %.3g%%; fn rate %.3g%%; unsure rate %.3g%%' % (
              num_fp*1e2 / ham.n, num_fn*1e2 / spam.n,
              (num_unh + num_uns)*1e2 / (ham.n + spam.n))

    return float(bests[0][0])/n,float(bests[0][1])/n

def printmsg(msg, prob, clues):
    print msg.tag
    print "prob =", prob
    for clue in clues:
        print "prob(%r) = %g" % clue
    print
    guts = str(msg)
    if options["TestDriver", "show_charlimit"] > 0:
        guts = guts[:options["TestDriver", "show_charlimit"]]
    print guts


class Driver:

    def __init__(self):
        self.falsepos = Set()
        self.falseneg = Set()
        self.unsure = Set()
        self.global_ham_hist = Hist()
        self.global_spam_hist = Hist()
        self.ntimes_finishtest_called = 0
        self.new_classifier()
        from spambayes import CostCounter
        self.cc=CostCounter.default()

    def new_classifier(self):
        """Create and use a new, virgin classifier."""
        self.set_classifier(classifier.Bayes())

    def set_classifier(self, classifier):
        """Specify a classifier to be used for further testing."""
        self.classifier = classifier
        self.tester = Tester.Test(classifier)
        self.trained_ham_hist = Hist()
        self.trained_spam_hist = Hist()

    def train(self, ham, spam):
        print "-> Training on", ham, "&", spam, "...",
        c = self.classifier
        nham, nspam = c.nham, c.nspam
        self.tester.train(ham, spam)
        print c.nham - nham, "hams &", c.nspam- nspam, "spams"

    def untrain(self, ham, spam):
        print "-> Forgetting", ham, "&", spam, "...",
        c = self.classifier
        nham, nspam = c.nham, c.nspam
        self.tester.untrain(ham, spam)
        print nham - c.nham, "hams &", nspam - c.nspam, "spams"

    def finishtest(self):
        if options["TestDriver", "show_histograms"]:
            printhist("all in this training set:",
                      self.trained_ham_hist, self.trained_spam_hist)
        self.global_ham_hist += self.trained_ham_hist
        self.global_spam_hist += self.trained_spam_hist
        self.trained_ham_hist = Hist()
        self.trained_spam_hist = Hist()

        self.ntimes_finishtest_called += 1
        if options["TestDriver", "save_trained_pickles"]:
            fname = "%s%d.pik" % (options["TestDriver", "pickle_basename"],
                                  self.ntimes_finishtest_called)
            print "    saving pickle to", fname
            fp = file(fname, 'wb')
            pickle.dump(self.classifier, fp, 1)
            fp.close()

    def alldone(self):
        if options["TestDriver", "show_histograms"]:
            besthamcut,bestspamcut = printhist("all runs:",
                                               self.global_ham_hist,
                                               self.global_spam_hist)
        else:
            besthamcut = options["Categorization", "ham_cutoff"]
            bestspamcut = options["Categorization", "spam_cutoff"]
            self.global_ham_hist.compute_stats()
            self.global_spam_hist.compute_stats()
        nham = self.global_ham_hist.n
        nspam = self.global_spam_hist.n
        nfp = len(self.falsepos)
        nfn = len(self.falseneg)
        nun = len(self.unsure)
        print "-> <stat> all runs false positives:", nfp
        print "-> <stat> all runs false negatives:", nfn
        print "-> <stat> all runs unsure:", nun
        print "-> <stat> all runs false positive %:", (nfp * 1e2 / nham)
        print "-> <stat> all runs false negative %:", (nfn * 1e2 / nspam)
        print "-> <stat> all runs unsure %:", (nun * 1e2 / (nham + nspam))
        print "-> <stat> all runs cost: $%.2f" % (
              nfp * options["TestDriver", "best_cutoff_fp_weight"] +
              nfn * options["TestDriver", "best_cutoff_fn_weight"] +
              nun * options["TestDriver", "best_cutoff_unsure_weight"])
        # Set back the options for the delayed calculations in self.cc
        options["Categorization", "ham_cutoff"] = besthamcut
        options["Categorization", "spam_cutoff"] = bestspamcut
        print self.cc

        if options["TestDriver", "save_histogram_pickles"]:
            for f, h in (('ham', self.global_ham_hist),
                         ('spam', self.global_spam_hist)):
                fname = "%s_%shist.pik" % (options["TestDriver",
                                                   "pickle_basename"], f)
                print "    saving %s histogram pickle to %s" %(f, fname)
                fp = file(fname, 'wb')
                pickle.dump(h, fp, 1)
                fp.close()

    def test(self, ham, spam):
        c = self.classifier
        t = self.tester
        local_ham_hist = Hist()
        local_spam_hist = Hist()

        def new_ham(msg, prob, lo=options["TestDriver", "show_ham_lo"],
                               hi=options["TestDriver", "show_ham_hi"]):
            local_ham_hist.add(prob * 100.0)
            self.cc.ham(prob)
            if lo <= prob <= hi:
                print
                print "Ham with prob =", prob
                prob, clues = c.spamprob(msg, True)
                printmsg(msg, prob, clues)

        def new_spam(msg, prob, lo=options["TestDriver", "show_spam_lo"],
                                hi=options["TestDriver", "show_spam_hi"]):
            local_spam_hist.add(prob * 100.0)
            self.cc.spam(prob)
            if lo <= prob <= hi:
                print
                print "Spam with prob =", prob
                prob, clues = c.spamprob(msg, True)
                printmsg(msg, prob, clues)

        t.reset_test_results()
        print "-> Predicting", ham, "&", spam, "..."
        t.predict(spam, True, new_spam)
        t.predict(ham, False, new_ham)
        print "-> <stat> tested", t.nham_tested, "hams &", t.nspam_tested, \
              "spams against", c.nham, "hams &", c.nspam, "spams"

        print "-> <stat> false positive %:", t.false_positive_rate()
        print "-> <stat> false negative %:", t.false_negative_rate()
        print "-> <stat> unsure %:", t.unsure_rate()
        print "-> <stat> cost: $%.2f" % (
               t.nham_wrong * options["TestDriver", "best_cutoff_fp_weight"] +
               t.nspam_wrong * options["TestDriver", "best_cutoff_fn_weight"] +
               (t.nham_unsure + t.nspam_unsure) *
               options["TestDriver", "best_cutoff_unsure_weight"])

        newfpos = Set(t.false_positives()) - self.falsepos
        self.falsepos |= newfpos
        print "-> <stat> %d new false positives" % len(newfpos)
        if newfpos:
            print "    new fp:", [e.tag for e in newfpos]
        if not options["TestDriver", "show_false_positives"]:
            newfpos = ()
        for e in newfpos:
            print '*' * 78
            prob, clues = c.spamprob(e, True)
            printmsg(e, prob, clues)

        newfneg = Set(t.false_negatives()) - self.falseneg
        self.falseneg |= newfneg
        print "-> <stat> %d new false negatives" % len(newfneg)
        if newfneg:
            print "    new fn:", [e.tag for e in newfneg]
        if not options["TestDriver", "show_false_negatives"]:
            newfneg = ()
        for e in newfneg:
            print '*' * 78
            prob, clues = c.spamprob(e, True)
            printmsg(e, prob, clues)

        newunsure = Set(t.unsures()) - self.unsure
        self.unsure |= newunsure
        print "-> <stat> %d new unsure" % len(newunsure)
        if newunsure:
            print "    new unsure:", [e.tag for e in newunsure]
        if not options["TestDriver", "show_unsure"]:
            newunsure = ()
        for e in newunsure:
            print '*' * 78
            prob, clues = c.spamprob(e, True)
            printmsg(e, prob, clues)

        if options["TestDriver", "show_histograms"]:
            printhist("this pair:", local_ham_hist, local_spam_hist)
        self.trained_ham_hist += local_ham_hist
        self.trained_spam_hist += local_spam_hist


# syntax highlighted by Code2HTML, v. 0.9.1