20 newsgroups数据集18000篇新闻文章,一共涉及到20种话题,所以称作20 newsgroups text dataset,分文两部分:训练集和测试集,通常用来做文本分类.
'''
fetch_20newsgroups(data_home=None,subset='train',categories=None,shuffle=True,random_state=42,remove=(),download_if_missing=True)
'''
'''
data_home指的是数据集的地址,如果默认的话,所有的数据都会在'~/scikit_learn_data'文件夹下.
subset就是train,test,all三种可选,分别对应训练集、测试集和所有样本。
categories:是指类别,如果指定类别,就会只提取出目标类,如果是默认,则是提取所有类别出来。
shuffle:是否打乱样本顺序,如果是相互独立的话。
random_state:打乱顺序的随机种子
remove:是一个元组,用来去除一些停用词的,例如标题引用之类的。
download_if_missing: 如果数据缺失,是否去下载。
'''
#coding=utf-8
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.datasets import get_data_home
print("数据集默认存放目录",get_data_home())
#首次运行需要下载数据集 大约14MB,可能会下载很慢,后面会介绍离线下载的方式
#subset就是train,test,all三种可选,分别对应训练集、测试集和所有样本。
data=fetch_20newsgroups(subset="all")
#数据集划分
x_train,x_test,y_train,y_test=train_test_split(data.data,data.target)
#转换器
transfer=TfidfVectorizer()
x_train=transfer.fit_transform(x_train)
x_test=transfer.transform(x_test)
#预估器
estimator=MultinomialNB()
estimator.fit(x_train,y_train)
y_predict=estimator.predict(x_test)
print("目标值与预测值:",y_test==y_predict)
print("准确率",estimator.score(x_test,y_test))