from spamEmail import spamEmailBayes
import re
from matplotlib import pyplot as plt 
#spam类对象
spam=spamEmailBayes()
#保存词频的词典
spamDict={}
normDict={}
testDict={}
#保存每封邮件中出现的词
wordsList=[]
wordsDict={}
#保存预测结果,key为文件名，值为预测类别
testResult={}
#分别获得正常邮件、垃圾邮件及测试文件名称列表
normFileList=spam.get_File_List(
        r".\data\normal")
spamFileList=spam.get_File_List(
        r".\data\spam")
testFileList=spam.get_File_List(
        r".\data\test")

#获取训练集中正常邮件与垃圾邮件的数量
normFilelen=len(normFileList)
spamFilelen=len(spamFileList)
#获得停用词表，用于对停用词过滤
stopList=spam.getStopWords()
#获得正常邮件中的词频
for fileName in normFileList:
    #清空列表wordsList
    wordsList.clear()
    for line in open("../data/normal/"+fileName):
        #过滤掉非中文字符
        rule=re.compile(r"[^\u4e00-\u9fa5]")#正则
        line=rule.sub("-",line)
        #将每封邮件出现的词保存在wordsList中
        spam.get_word_list(line,wordsList,stopList)
    #统计每个词在所有邮件中出现的次数
    spam.addToDict(wordsList, wordsDict)
normDict=wordsDict.copy()  

#获得垃圾邮件中的词频
wordsDict.clear()
for fileName in spamFileList:
    wordsList.clear()
    for line in open("../data/spam/"+fileName):
        rule=re.compile(r"[^\u4e00-\u9fa5]")
        line=rule.sub("-",line)
        spam.get_word_list(line,wordsList,stopList)
    #get_word_list统计一个邮件的词种
    #addToDict统计某个词在所有邮件中出现的次数
    spam.addToDict(wordsList, wordsDict)
spamDict=wordsDict.copy()
#spamDict词典包含垃圾邮件中的词语，及出现它的邮件份数

# 测试邮件2.0
fpr=[]
tpr=[]
#i控制阈值变化
for i in range(10):
    testResult.clear()
    n=i/10
    print("第{}次执行，阈值设置为{}".format(i,n))
    for fileName in testFileList:
        #清除数据
        testDict.clear()
        wordsDict.clear()
        wordsList.clear()
        #逐个处理文件
        for line in open("../data/test/"+fileName):
            #正则变换
            rule=re.compile(r"[^\u4e00-\u9fa5]")
            line=rule.sub("-",line)
            #这个line文件的词种
            spam.get_word_list(line,wordsList,stopList)
        #test下所有文件的词及频率
        spam.addToDict(wordsList, wordsDict)
        #得到测试集数据，包括单词，单词出现的邮件份数
        testDict=wordsDict.copy()
        #通过计算每个文件中p(s|w)来得到对分类影响最大的15个词
        wordProbList=spam.getTestWords(testDict, spamDict,normDict,normFilelen,spamFilelen)
        #print(wordProbList)
        p=spam.calBayes(wordProbList, spamDict, normDict,normFilelen,spamFilelen)
        #print("第{}次测试的贝叶斯P值：{},比较的值为：{}".format(i+1,p,n))
        if(p>n):
            #print("p比较判断的Boolean值：")
            #print(p>(i+1)/10)
            testResult.setdefault(fileName,1)
            #print("p大于n" % len(testResult))
        else:
            testResult.setdefault(fileName,0)
            #print("p不大于n" % len(testResult))
            #print("证明进入else中")
        #print("testResylt")
        #print(testResult)
    #计算分类正确率（测试集中文件名低于1000的为正常邮件）
    testAccuracy=spam.calAccuracy(testResult)
    #for i,ic in testResult.items():
    #    print(i+"/"+str(ic))
    print("正确率:")
    print(testAccuracy)  
    #纵TPR横FPR
    #TPR即是recall
    precision,recall,F1,FPR=spam.getPrecisionRecallF1(testResult)
    print("准确率：")
    print(precision)
    print("召回率：")
    print(recall)
    print("f1分数：")
    print(F1)
    fpr.append(FPR)
    tpr.append(recall)

fpr = [0] + fpr + [1]
tpr = [0] + tpr + [1]
#print("横序列fpr:")
#print(fpr)
#print("纵序列tpr:")
#print(tpr)
plt.plot(fpr, tpr)
plt.show()

AUC=0
for i in range(len(fpr)-1):
    AUC+=(tpr[i]+tpr[i+1])*(fpr[i+1]-fpr[i])/2
print("AUC值为：")
print(AUC)