用朴素贝叶斯算法对新闻进行分类

20 newsgroups数据集介绍

20 newsgroups数据集18000篇新闻文章,一共涉及到20种话题,所以称作20 newsgroups text dataset,分文两部分:训练集和测试集,通常用来做文本分类.

fetch_20newsgroups()方法介绍

'''
fetch_20newsgroups(data_home=None,subset='train',categories=None,shuffle=True,random_state=42,remove=(),download_if_missing=True)
'''
'''
data_home指的是数据集的地址,如果默认的话,所有的数据都会在'~/scikit_learn_data'文件夹下.

subset就是train,test,all三种可选,分别对应训练集、测试集和所有样本。

categories:是指类别,如果指定类别,就会只提取出目标类,如果是默认,则是提取所有类别出来。

shuffle:是否打乱样本顺序,如果是相互独立的话。

random_state:打乱顺序的随机种子

remove:是一个元组,用来去除一些停用词的,例如标题引用之类的。

download_if_missing: 如果数据缺失,是否去下载。
'''

用朴素贝叶斯算法对新闻进行分类

#coding=utf-8

from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer

from sklearn.naive_bayes import MultinomialNB

from sklearn.datasets import get_data_home
print("数据集默认存放目录",get_data_home())

#首次运行需要下载数据集 大约14MB,可能会下载很慢,后面会介绍离线下载的方式
#subset就是train,test,all三种可选,分别对应训练集、测试集和所有样本。
data=fetch_20newsgroups(subset="all")

#数据集划分
x_train,x_test,y_train,y_test=train_test_split(data.data,data.target)

#转换器
transfer=TfidfVectorizer()
x_train=transfer.fit_transform(x_train)
x_test=transfer.transform(x_test)

#预估器
estimator=MultinomialNB()
estimator.fit(x_train,y_train)

y_predict=estimator.predict(x_test)

print("目标值与预测值:",y_test==y_predict)

print("准确率",estimator.score(x_test,y_test))
返回笔记列表
入门小站