使用MovieLens數據集
下載數據集
下面這段代碼負責下載並解壓數據集,流程較爲完備(含進度條、MD5 校驗與出錯清理)。
import hashlib
import os
import pickle
import re
import shutil
import zipfile
from collections import Counter
from os.path import isfile, isdir
from urllib.request import urlretrieve

import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.python.ops import math_ops
from tqdm import tqdm
def _unzip(save_path, _, database_name, data_path):
"""
Unzip wrapper with the same interface as _ungzip
:param save_path: The path of the gzip files
:param database_name: Name of database
:param data_path: Path to extract to
:param _: HACK - Used to have to same interface as _ungzip
"""
print('Extracting {}...'.format(database_name))
with zipfile.ZipFile(save_path) as zf:
zf.extractall(data_path)
def download_extract(database_name, data_path):
    """
    Download and extract a dataset if it is not already present on disk.

    :param database_name: Database name (only 'ml-1m' is recognized)
    :param data_path: Directory the archive is downloaded to and extracted under
    :raises AssertionError: if the downloaded archive fails the MD5 check
    """
    DATASET_ML1M = 'ml-1m'
    if database_name == DATASET_ML1M:
        url = 'http://files.grouplens.org/datasets/movielens/ml-1m.zip'
        # MD5 of the official ml-1m.zip, used to detect corrupted downloads.
        hash_code = 'c4d9eecfca2ab87c1945afe126590906'
        extract_path = os.path.join(data_path, 'ml-1m')
        save_path = os.path.join(data_path, 'ml-1m.zip')
        extract_fn = _unzip
    # NOTE(review): an unrecognized database_name falls through and raises
    # NameError below; preserved as the original behavior.
    if os.path.exists(extract_path):
        print('Found {} Data'.format(database_name))
        return
    if not os.path.exists(data_path):
        os.makedirs(data_path)
    if not os.path.exists(save_path):
        with DLProgress(unit='B', unit_scale=True, miniters=1,
                        desc='Downloading {}'.format(database_name)) as pbar:
            urlretrieve(
                url,
                save_path,
                pbar.hook)
    # Verify the archive inside a context manager so the file handle is
    # closed (the original left it open).
    with open(save_path, 'rb') as archive:
        archive_md5 = hashlib.md5(archive.read()).hexdigest()
    assert archive_md5 == hash_code, \
        '{} file is corrupted. Remove the file and try again.'.format(save_path)
    os.makedirs(extract_path)
    try:
        extract_fn(save_path, extract_path, database_name, data_path)
    except Exception as err:
        # Remove the partial extraction so a retry starts from a clean state.
        # (Fixed: shutil was referenced here but never imported.)
        shutil.rmtree(extract_path)
        raise err
    print('Done.')
    # Remove compressed data
    # os.remove(save_path)
class DLProgress(tqdm):
    """
    tqdm subclass that adapts urlretrieve report-hook callbacks into
    progress-bar updates.
    """
    # Block index reported on the previous callback.
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """
        Report hook called once when the network connection is established
        and once after each block read thereafter.

        :param block_num: Count of blocks transferred so far
        :param block_size: Block size in bytes
        :param total_size: Total file size; may be -1 on older FTP servers
                           that do not report a size for retrieval requests
        """
        self.total = total_size
        blocks_read = block_num - self.last_block
        self.update(blocks_read * block_size)
        self.last_block = block_num
# Fetch the MovieLens 1M dataset into the current directory
# (no-op if the ml-1m folder already exists).
data_dir = './'
download_extract('ml-1m', data_dir)
Found ml-1m Data
數據探索
import pandas as pd
import matplotlib.pyplot as plt
rating
def getRatings(file_path):
    """
    Load ratings.dat, print summary statistics (ID/score ranges, row counts,
    head, minimum ratings per user) and plot the number of ratings per score.

    :param file_path: Path to the MovieLens ratings.dat file
    """
    rates = pd.read_table(
        file_path,
        header=None,
        sep="::",
        names=["userID", "movieID", "rate", "timestamp"],
        # Multi-character separators require the python engine; this also
        # silences the ParserWarning the original emitted.
        engine="python",
    )
    print("userID的範圍爲: <{},{}>"
          .format(min(rates["userID"]), max(rates["userID"])))
    print("movieID的範圍爲: <{},{}>"
          .format(min(rates["movieID"]), max(rates["movieID"])))
    print("評分值的範圍爲: <{},{}>"
          .format(min(rates["rate"]), max(rates["rate"])))
    print("數據總條數爲:\n{}".format(rates.count()))
    print("數據前5條記錄爲:\n{}".format(rates.head(5)))
    # Group by user to find the fewest ratings any single user made.
    df = rates["userID"].groupby(rates["userID"])
    print("用戶評分記錄最少條數爲:{}".format(df.count().min()))
    # Count how many ratings were given at each score value (1..5).
    scores = rates["rate"].groupby(rates["rate"]).count()
    # Annotate each bar with its count.
    for x, y in zip(scores.keys(), scores.values):
        plt.text(x, y + 2, "%.0f" % y, ha="center", va="bottom", fontsize=12)
    plt.bar(scores.keys(), scores.values, fc="r", tick_label=scores.keys())
    plt.xlabel("Rating Score")
    plt.ylabel("People Number")
    plt.title("Rating Scores And related People Number")
    plt.show()
getRatings("./ml-1m/ratings.dat")
/home/gz/anaconda3/envs/tf2/lib/python3.7/site-packages/ipykernel_launcher.py:6: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
userID的範圍爲: <1,6040>
movieID的範圍爲: <1,3952>
評分值的範圍爲: <1,5>
數據總條數爲:
userID 1000209
movieID 1000209
rate 1000209
timestamp 1000209
dtype: int64
數據前5條記錄爲:
userID movieID rate timestamp
0 1 1193 5 978300760
1 1 661 3 978302109
2 1 914 3 978301968
3 1 3408 4 978300275
4 1 2355 5 978824291
用戶評分記錄最少條數爲:20
Movies
def getMovies(file_path):
    """
    Load movies.dat, print summary statistics, tally the per-genre counts
    and plot them as a pie chart.

    :param file_path: Path to the MovieLens movies.dat file
    """
    movies = pd.read_table(
        file_path,
        header=None,
        sep="::",
        names=["movieID", "title", "genres"],
        # Multi-character separators require the python engine; this also
        # silences the ParserWarning the original emitted.
        engine="python",
    )
    print("movieID的範圍爲: <{},{}>"
          .format(min(movies["movieID"]), max(movies["movieID"])))
    print("數據總條數爲:\n{}".format(movies.count()))
    # Tally genre occurrences; each movie's genres are '|'-separated.
    moviesDict = dict()
    for line in movies["genres"].values:
        for one in line.split("|"):
            moviesDict.setdefault(one, 0)
            moviesDict[one] += 1
    print("電影類型總數爲:{}".format(len(moviesDict)))
    print("電影類型分別爲:{}".format(moviesDict.keys()))
    print(moviesDict)
    # Sort genres by count, descending, for the pie chart.
    newMD = sorted(moviesDict.items(), key=lambda x: x[1], reverse=True)
    labels = [newMD[i][0] for i in range(len(newMD))]
    values = [newMD[i][1] for i in range(len(newMD))]
    # Matches labels; larger values sit further from the center.
    explode = [x * 0.01 for x in range(len(newMD))]
    # Force a 1:1 aspect ratio so the pie is circular.
    plt.axes(aspect=1)
    # labeldistance/pctdistance control how far labels and percentages sit
    # from the center; autopct formats the percentages.
    plt.pie(
        x=values,
        labels=labels,
        explode=explode,
        autopct="%3.1f %%",
        shadow=False,
        labeldistance=1.1,
        startangle=0,
        pctdistance=0.8,
        center=(-1, 0),
    )
    # bbox_to_anchor: first value shifts the legend horizontally, second
    # vertically; ncol sets the number of legend columns (default 1).
    plt.legend(loc=7, bbox_to_anchor=(1.3, 1.0), ncol=3, fancybox=True, shadow=True, fontsize=6)
    plt.show()
getMovies("./ml-1m/movies.dat")
/home/gz/anaconda3/envs/tf2/lib/python3.7/site-packages/ipykernel_launcher.py:6: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
movieID的範圍爲: <1,3952>
數據總條數爲:
movieID 3883
title 3883
genres 3883
dtype: int64
電影類型總數爲:18
電影類型分別爲:dict_keys(['Animation', "Children's", 'Comedy', 'Adventure', 'Fantasy', 'Romance', 'Drama', 'Action', 'Crime', 'Thriller', 'Horror', 'Sci-Fi', 'Documentary', 'War', 'Musical', 'Mystery', 'Film-Noir', 'Western'])
{'Animation': 105, "Children's": 251, 'Comedy': 1200, 'Adventure': 283, 'Fantasy': 68, 'Romance': 471, 'Drama': 1603, 'Action': 503, 'Crime': 211, 'Thriller': 492, 'Horror': 343, 'Sci-Fi': 276, 'Documentary': 127, 'War': 143, 'Musical': 114, 'Mystery': 106, 'Film-Noir': 44, 'Western': 68}
User
def getUsers(file_path):
    """
    Load users.dat, print summary statistics, then plot the gender split as
    a pie chart and the age-bucket distribution as an annotated line chart.

    :param file_path: Path to the MovieLens users.dat file
    """
    users = pd.read_table(
        file_path,
        header=None,
        sep="::",
        names=["userID", "gender", "age", "Occupation", "zip-code"],
        # Multi-character separators require the python engine; this also
        # silences the ParserWarning the original emitted.
        engine="python",
    )
    print("userID的範圍爲: <{},{}>".format(min(users["userID"]), max(users["userID"])))
    print("數據總條數爲:\n{}".format(users.count()))
    # Gender distribution pie chart.
    usersGender = users["gender"].groupby(users["gender"]).count()
    print(usersGender)
    plt.axes(aspect=1)
    plt.pie(x=usersGender.values, labels=usersGender.keys(), autopct="%3.1f %%")
    plt.legend(bbox_to_anchor=(1.0, 1.0))
    plt.show()
    # Age-bucket distribution line chart.
    usersAge = users["age"].groupby(users["age"]).count()
    print(usersAge)
    plt.plot(
        usersAge.keys(),
        usersAge.values,
        label="用戶年齡信息展示",
        linewidth=3,
        color="r",
        marker="o",
        markerfacecolor="blue",
        markersize=12,
    )
    # Annotate each point with its count.
    for x, y in zip(usersAge.keys(), usersAge.values):
        plt.text(x, y+10, "%.0f" % y, ha="center", va="bottom", fontsize=12)
    plt.xlabel("Age")
    plt.ylabel("Users Number")
    plt.title("Age of users")
    plt.show()
getUsers("./ml-1m/users.dat")
/home/gz/anaconda3/envs/tf2/lib/python3.7/site-packages/ipykernel_launcher.py:6: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
userID的範圍爲: <1,6040>
數據總條數爲:
userID 6040
gender 6040
age 6040
Occupation 6040
zip-code 6040
dtype: int64
gender
F 1709
M 4331
Name: gender, dtype: int64
age
1 222
18 1103
25 2096
35 1193
45 550
50 496
56 380
Name: age, dtype: int64