本文利用了transformers中的BertModel,对部分cnews数据集进行了文本分类,用来对BERT模型练手还是不错的。
数据描述
数据集是从清华大学的THUCNews中提取出来的部分数据。
训练集中有5万条数据,分成了10类,每类5000条数据。
1 | {'体育': 5000, '娱乐': 5000, '家居': 5000, '房产': 5000, '教育': 5000, '时尚': 5000, '时政': 5000, '游戏': 5000, '科技': 5000, '财经': 5000} |
验证集中有5000条数据,每类500条数据。
1 | {'体育': 500, '娱乐': 500, '家居': 500, '房产': 500, '教育': 500, '时尚': 500, '时政': 500, '游戏': 500, '科技': 500, '财经': 500} |
模型描述
整个分类模型首先把句子输入到Bert预训练模型,然后将句子的embedding输入给一个全连接层,最后把全连接层的输出输入到softmax中。模型代码如下:
1 | # model.py |
数据处理
生成BERT的输入,并封装成TensorDataset
1 | def get_bert_input(text, tokenizer, max_len=512): |
训练
训练模型,并对验证集进行验证,然后把最优模型保存下来
1 | def main(): |
预测
用训练好的最优模型对新闻进行预测
1 | import torch |
结果
1 | # output.txt |
完整项目放在了GitHub上: https://github.com/illiterate/BertClassifier