import re
import math

def getwords(doc):
    """Extract the distinct lowercase words of `doc` as a feature dict.

    A word is a maximal run of word characters (letters, digits, underscore)
    whose length is between 3 and 19 characters inclusive. The result maps
    each such lowercased word to 1, e.g. {'water': 1, 'owns': 1, ...}.
    """
    # `\W+` rather than `\W*`: a pattern that can match the empty string makes
    # re.split (Python 3.7+) cut between every character, yielding no words.
    splitter = re.compile(r'\W+')
    words = [s.lower() for s in splitter.split(doc) if 2 < len(s) < 20]
    return dict.fromkeys(words, 1)

class classifier:
    """Base classifier that accumulates feature/category training counts.

    `getfeatures` is a callable mapping an item to an iterable of features
    (for example `getwords`). `filename` is accepted but not used here.
    All count accessors return floats so that divisions are never integer
    divisions under Python 2.
    """

    def __init__(self, getfeatures, filename=None):
        # Number of feature/category combinations: {feature: {category: count}}.
        self.fc = {}
        # Number of trained items (documents) per category: {category: count}.
        self.cc = {}
        self.getfeatures = getfeatures

    def incf(self, f, cat):
        """Increment the count of feature `f` within category `cat`."""
        counts = self.fc.setdefault(f, {})
        counts[cat] = counts.get(cat, 0) + 1

    def incc(self, cat):
        """Increment the number of items trained in category `cat`."""
        self.cc[cat] = self.cc.get(cat, 0) + 1

    def fcount(self, f, cat):
        """Return, as a float, how many times feature `f` appeared in `cat`."""
        return float(self.fc.get(f, {}).get(cat, 0))

    def catcount(self, cat):
        """Return, as a float, the number of items trained in `cat` (0.0 if unseen)."""
        return float(self.cc.get(cat, 0))

    def totalcount(self):
        """Return the total number of trained items across all categories."""
        return sum(self.cc.values())

    def categories(self):
        """Return the categories that have at least one trained item."""
        return self.cc.keys()

    def train(self, item, cat):
        """Record one training `item` labelled `cat`.

        Every feature of the item has its count for `cat` incremented, then
        the item count of `cat` is incremented.
        """
        for f in self.getfeatures(item):
            self.incf(f, cat)
        self.incc(cat)

    def fprob(self, f, cat):
        """Return Pr(f | cat) = fcount(f, cat) / catcount(cat).

        Returns 0.0 when `cat` has no trained items.
        """
        ncat = self.catcount(cat)
        if ncat == 0:
            return 0.0
        return self.fcount(f, cat) / ncat

    def weightedprob(self, f, cat, prf, weight=1.0, ap=0.5):
        """Blend `prf(f, cat)` with an assumed prior probability `ap`.

        The prior counts as `weight` observations; the measured probability
        counts as the number of times `f` was seen across all categories.
        An unseen feature therefore gets exactly `ap`.
        """
        basicprob = prf(f, cat)
        totals = sum(self.fcount(f, c) for c in self.categories())
        return ((weight * ap) + (totals * basicprob)) / (weight + totals)

class naivebayes(classifier):
    """Naive Bayes classifier built on the counts kept by `classifier`."""

    def docprob(self, item, cat):
        """Return Pr(item | cat) as the product of weighted feature probabilities.

        Features are treated as mutually independent. For items with many
        features this product can underflow to 0.0; `logprob` computes the
        same quantity in log space.
        """
        p = 1.0
        for f in self.getfeatures(item):
            p *= self.weightedprob(f, cat, self.fprob)
        return p

    def prob(self, item, cat):
        """Return Pr(item | cat) * Pr(cat), which is proportional to Pr(cat | item).

        Returns 0.0 when nothing has been trained yet (avoids dividing by zero).
        """
        total = self.totalcount()
        if total == 0:
            return 0.0
        catprob = float(self.catcount(cat)) / total
        return self.docprob(item, cat) * catprob

    def logprob(self, item, cat):
        """Return log(prob(item, cat)), summed in log space to avoid underflow.

        Returns -inf when the probability is zero: nothing trained, `cat`
        unseen, or a non-positive feature probability.
        """
        total = self.totalcount()
        ncat = self.catcount(cat)
        if total == 0 or ncat == 0:
            return float('-inf')
        logs = [math.log(float(ncat) / total)]
        for f in self.getfeatures(item):
            w = self.weightedprob(f, cat, self.fprob)
            if w <= 0:
                return float('-inf')
            logs.append(math.log(w))
        return math.fsum(logs)

    def classify(self, item, default=None):
        """Return the most probable category for `item`.

        Categories are ranked by `logprob`, so long documents whose raw
        probability underflows to 0.0 are still classified. Returns `default`
        if no category has a non-zero probability (e.g. nothing trained).
        Ties keep the first category encountered.
        """
        best = default
        bestscore = float('-inf')
        for cat in self.categories():
            score = self.logprob(item, cat)
            if score > bestscore:
                bestscore = score
                best = cat
        return best
            
            

import os

# Root of the mini_newsgroups corpus, relative to the working directory.
NEWS_DIR = os.path.join('NEWS', 'mini_newsgroups')


def _train_directory(cl, category):
    """Train `cl` on every file in NEWS_DIR/<category>, labelling it `category`.

    Prints each file name before training on it and 'ok' afterwards.
    """
    # List and read from the SAME directory so file names always exist.
    catdir = os.path.join(NEWS_DIR, category)
    for name in os.listdir(catdir):
        print(name)
        # NOTE(review): under Python 3 this decodes with the locale encoding;
        # confirm the corpus files decode cleanly or pass an explicit encoding.
        with open(os.path.join(catdir, name)) as fh:
            cl.train(fh.read(), category)
        print('ok')


# NOTE(review): training runs at import time; consider a __main__ guard.
nb = naivebayes(getwords)
for _category in ('sci.crypt', 'sci.med'):
    _train_directory(nb, _category)




def sampletrain(cl):
    """Feed `cl` a small fixed corpus of sentences labelled 'good' or 'bad'.

    Each (text, label) pair is passed to `cl.train` in the order listed.
    """
    corpus = (
        ('Nobody owns the water.', 'good'),
        ('the quick rabbit jumps fences', 'good'),
        ('buy pharmaceuticals now', 'bad'),
        ('make quick money at the online casino', 'bad'),
        ('the quick brown fox jumps', 'good'),
    )
    for text, label in corpus:
        cl.train(text, label)
    cl.train('the quick brown fox jumps','good')


