| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 
 | import sys
 
 import jieba
 import pymysql
 import kashgari
 from collections import Counter
 from textrank4zh import TextRank4Keyword, TextRank4Sentence
 from kashgari.tasks.classification import CNN_Model
 from kashgari.embeddings import BERTEmbedding
 
 
 def getDirtyWord(path='/Users/roarboil/Desktop/news-python/dataprocess/word.txt'):
     """Load the filter-word list, one entry per line of a text file.

     Args:
         path: Location of the word file. Defaults to the path this script
             has always read, so existing zero-argument calls are unchanged.

     Returns:
         A list with one string per line of the file, in file order, with
         the trailing newline removed. Blank lines are kept as empty strings.

     Raises:
         OSError: If the file cannot be opened.
     """
     # NOTE(review): no explicit encoding is passed, so the platform default
     # applies, exactly as before — confirm the encoding of the word file.
     # The context manager closes the handle even if reading fails.
     with open(path, 'r') as f:
         return [line.rstrip("\n") for line in f]
 
 
 def process(content):
     """Segment *content* with jieba and drop every filtered word.

     Args:
         content: The text to tokenize.

     Returns:
         The tokens yielded by ``jieba.cut(content)``, in their original
         order, excluding every token that appears in the module-level
         ``wordList``.
     """
     # wordList is a module-level name assigned by the script entry point; it
     # is read here at call time. A set makes each membership test O(1), and
     # filtering in one pass replaces the repeated list.remove() scans while
     # yielding the same tokens in the same order.
     banned = set(wordList)
     return [token for token in jieba.cut(content) if token not in banned]
 
 
 def getData():
     """Read news rows from MySQL and build the training/validation/test lists.

     For each category in a fixed list, every row of the ``news`` table with
     that category is fetched. Column index 6 is used as the article content
     and column index 4 as the label. The first 300 characters of the content
     are tokenized with ``process``. Rows whose content is NULL or empty are
     skipped.

     Returns:
         A 6-tuple ``(x_list, y_list, val_x, val_y, test_x, test_y)``. Each
         ``*_x`` list holds token lists and each ``*_y`` list holds the
         matching labels. ``x_list``/``y_list`` contain every kept row; per
         category, the first ``5 * (count // 10)`` rows also go to the
         validation lists and the remaining rows also go to the test lists.
     """
     categoryList = ["politics", "fortune", "local", "health", "tech"]
     x_list, y_list = [], []
     val_x, val_y = [], []
     test_x, test_y = [], []

     # NOTE(review): host and credentials are hard-coded in source; consider
     # loading them from configuration or the environment.
     database = pymysql.connect(host="47.97.90.30",
                                user="news",
                                port=3306,
                                passwd="news",
                                db="news",
                                charset='utf8')
     try:
         cursor = database.cursor()
         try:
             for category in categoryList:
                 # Parameterized query: the driver escapes the value.
                 cursor.execute("select * from news where category = %s",
                                (category,))
                 tmpx = []
                 tmpy = []
                 for row in cursor.fetchall():
                     # NOTE(review): positional indexes depend on the table's
                     # column order because of "select *" — verify the schema.
                     raw_content = row[6]
                     if raw_content is None:
                         # str(None) would give the literal text "None" and
                         # slip past the empty check below.
                         continue
                     content = str(raw_content)
                     if content == "":
                         continue
                     tmpy.append(row[4])
                     tmpx.append(process(content[:300]))

                 # NOTE(review): the training lists receive every row, and the
                 # validation/test lists are the two halves of those same
                 # rows, so all three sets overlap. Kept as-is to preserve the
                 # returned data; confirm whether a disjoint split was meant.
                 cut = 5 * (len(tmpx) // 10)
                 y_list += tmpy
                 x_list += tmpx
                 val_y += tmpy[:cut]
                 val_x += tmpx[:cut]
                 test_y += tmpy[cut:]
                 test_x += tmpx[cut:]
         finally:
             cursor.close()
     finally:
         # Release the connection even if a query or tokenization fails.
         database.close()
     return x_list, y_list, val_x, val_y, test_x, test_y
 
 
 if __name__ == '__main__':
     # Script entry point: build a BERT-embedding CNN classifier, train it on
     # the lists returned by getData(), evaluate it, and save it to disk.

     # process() reads this module-level name, so it is assigned before
     # getData() is called.
     wordList = getDirtyWord()

     BERT_PATH = '/Users/roarboil/Desktop/news-python/dataprocess/chinese_L-12_H-768_A-12'
     embedding = BERTEmbedding(
         BERT_PATH,
         task=kashgari.CLASSIFICATION,
         sequence_length=100,
     )
     classifier = CNN_Model(embedding)

     train_x, train_y, val_x, val_y, test_x, test_y = getData()

     # Training hyper-parameters.
     batch_size, epochs = 16, 3

     classifier.fit(
         train_x,
         train_y,
         val_x,
         val_y,
         batch_size=batch_size,
         epochs=epochs,
     )
     classifier.evaluate(test_x, test_y)
     classifier.save('/Users/roarboil/Desktop/news-python/dataprocess/model')
 
 |