1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
| import sys
import jieba import pymysql import kashgari from collections import Counter from textrank4zh import TextRank4Keyword, TextRank4Sentence from kashgari.tasks.classification import CNN_Model from kashgari.embeddings import BERTEmbedding
def getDirtyWord(path='/Users/roarboil/Desktop/news-python/dataprocess/word.txt',
                 encoding='utf-8'):
    """Load the filter-word list: one entry per line of a text file.

    Args:
        path: Location of the word list. Defaults to the path this script
            has always used, so zero-argument calls behave as before.
        encoding: Text encoding used to decode the file.
            NOTE(review): assumes the word list is stored as UTF-8 — confirm.

    Returns:
        A list with one string per line, in file order, with the trailing
        newline removed. Blank lines are kept as empty strings.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid text in ``encoding``.
    """
    # The context manager closes the handle even if reading fails, and the
    # explicit encoding keeps the result independent of the platform locale.
    with open(path, 'r', encoding=encoding) as f:
        return [line.rstrip("\n") for line in f]
def process(content, dirty_words=None):
    """Segment *content* with jieba and drop every filtered word.

    Args:
        content: Text to segment.
        dirty_words: Optional iterable of tokens to remove. When omitted,
            the module-level ``wordList`` is used, which is what existing
            one-argument callers rely on.

    Returns:
        The jieba tokens in their original order, minus every token that
        appears in the filter list (all occurrences are removed).
    """
    # A set gives O(1) membership tests; the previous collect-then-remove
    # approach rescanned the token list once per hit.
    blocked = set(wordList if dirty_words is None else dirty_words)
    return [token for token in jieba.cut(content) if token not in blocked]
def getData():
    """Fetch news rows per category and build the training lists.

    For each category the rows are read from the ``news`` table, rows with
    missing or empty content (column index 6) are skipped, and the first 300
    characters of the content are tokenized with ``process``. The label is
    taken from column index 4.

    Returns:
        A 6-tuple ``(x_list, y_list, val_x, val_y, test_x, test_y)``:
        ``x_list``/``y_list`` hold every kept sample; within each category
        the first ``5 * (n // 10)`` samples go to the validation lists and
        the remainder to the test lists.

    NOTE(review): validation and test samples are also part of
    ``x_list``/``y_list``, so evaluation runs on data the model was trained
    on. Left unchanged here because it alters the returned data; confirm
    whether a disjoint split is intended.
    """
    categoryList = ["politics", "fortune", "local", "health", "tech"]
    sql = "select * from news where category = %s"

    x_list, y_list = [], []
    val_x, val_y = [], []
    test_x, test_y = [], []

    # NOTE(review): host and credentials are hard-coded in source; consider
    # loading them from configuration or the environment.
    database = pymysql.connect(host="47.97.90.30", user="news", port=3306,
                               passwd="news", db="news", charset='utf8')
    try:
        cursor = database.cursor()
        try:
            for category in categoryList:
                cursor.execute(sql, (category,))
                tmpx = []
                tmpy = []
                for row in cursor.fetchall():
                    raw_content = row[6]
                    # str(None) would yield the literal text "None" and slip
                    # past the empty-content check below.
                    if raw_content is None:
                        continue
                    content = str(raw_content)
                    if content == "":
                        continue
                    tmpy.append(row[4])
                    tmpx.append(process(content[:300]))

                percent = len(tmpx) // 10
                cut = 5 * percent
                y_list += tmpy
                x_list += tmpx
                val_y += tmpy[:cut]
                val_x += tmpx[:cut]
                test_y += tmpy[cut:]
                test_x += tmpx[cut:]
        finally:
            cursor.close()
    finally:
        # Release the connection even if a query or tokenization fails.
        database.close()

    return x_list, y_list, val_x, val_y, test_x, test_y
if __name__ == '__main__':
    # process() looks this name up at module level, so it has to be bound
    # before getData() starts tokenizing rows.
    wordList = getDirtyWord()

    # Build the classifier on top of the pre-trained Chinese BERT embedding.
    BERT_PATH = '/Users/roarboil/Desktop/news-python/dataprocess/chinese_L-12_H-768_A-12'
    embedding = BERTEmbedding(
        BERT_PATH,
        task=kashgari.CLASSIFICATION,
        sequence_length=100,
    )
    model = CNN_Model(embedding)

    # Pull the samples from the database.
    train_x, train_y, val_x, val_y, test_x, test_y = getData()

    # Training hyper-parameters.
    batch_size = 16
    epochs = 3

    # Train, evaluate on the test lists, then persist the model.
    model.fit(
        train_x,
        train_y,
        val_x,
        val_y,
        batch_size=batch_size,
        epochs=epochs,
    )
    model.evaluate(test_x, test_y)
    model.save('/Users/roarboil/Desktop/news-python/dataprocess/model')
|