Personalized Recommendation System, Part 1: Dataset Exploration

Using the MovieLens ml-1m dataset

Downloading the dataset

The code below handles the whole download step: it fetches the ml-1m archive with a progress bar, verifies its MD5 checksum, and extracts it into the data directory.

import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
import seaborn as sns
from collections import Counter
import tensorflow as tf

import os
import pickle
import re
import shutil
from tensorflow.python.ops import math_ops
from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm
import zipfile
import hashlib

def _unzip(save_path, _, database_name, data_path):
    """
    Unzip wrapper with the same interface as _ungzip
    :param save_path: The path of the zip file
    :param database_name: Name of database
    :param data_path: Path to extract to
    :param _: HACK - Used to have to same interface as _ungzip
    """
    print('Extracting {}...'.format(database_name))
    with zipfile.ZipFile(save_path) as zf:
        zf.extractall(data_path)

def download_extract(database_name, data_path):
    """
    Download and extract database
    :param database_name: Database name
    :param data_path: Directory to download and extract into
    """
    DATASET_ML1M = 'ml-1m'

    if database_name == DATASET_ML1M:
        url = 'http://files.grouplens.org/datasets/movielens/ml-1m.zip'
        hash_code = 'c4d9eecfca2ab87c1945afe126590906'
        extract_path = os.path.join(data_path, 'ml-1m')
        save_path = os.path.join(data_path, 'ml-1m.zip')
        extract_fn = _unzip
    else:
        raise ValueError('Unknown database name: {}'.format(database_name))

    if os.path.exists(extract_path):
        print('Found {} Data'.format(database_name))
        return

    if not os.path.exists(data_path):
        os.makedirs(data_path)

    if not os.path.exists(save_path):
        with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Downloading {}'.format(database_name)) as pbar:
            urlretrieve(
                url,
                save_path,
                pbar.hook)

    with open(save_path, 'rb') as f:
        assert hashlib.md5(f.read()).hexdigest() == hash_code, \
            '{} file is corrupted.  Remove the file and try again.'.format(save_path)

    os.makedirs(extract_path)
    try:
        extract_fn(save_path, extract_path, database_name, data_path)
    except Exception as err:
        shutil.rmtree(extract_path)  # Remove extraction folder if there is an error
        raise err

    print('Done.')
    # Remove compressed data
#     os.remove(save_path)

class DLProgress(tqdm):
    """
    Handle Progress Bar while Downloading
    """
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """
        A hook function that will be called once on establishment of the network connection and
        once after each block read thereafter.
        :param block_num: A count of blocks transferred so far
        :param block_size: Block size in bytes
        :param total_size: The total size of the file. This may be -1 on older FTP servers which do not return
                            a file size in response to a retrieval request.
        """
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

data_dir = './'
download_extract('ml-1m', data_dir)
Found ml-1m Data
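
After extraction, the ./ml-1m directory should contain ratings.dat, movies.dat, and users.dat (plus the dataset README); these are the files explored below. A quick sanity check, assuming the data_dir='./' used above:

import os

# List the extracted ml-1m files; expected: README, movies.dat, ratings.dat, users.dat
print(sorted(os.listdir('./ml-1m')))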

Data exploration

import pandas as pd
import matplotlib.pyplot as plt

Ratings

def getRatings(file_path):
    rates = pd.read_table(
        file_path,
        header=None,
        sep="::",
        engine="python",
        names=["userID", "movieID", "rate", "timestamp"],
    )
    print("userID range: <{},{}>"
          .format(min(rates["userID"]), max(rates["userID"])))
    print("movieID range: <{},{}>"
          .format(min(rates["movieID"]), max(rates["movieID"])))
    print("Rating value range: <{},{}>"
          .format(min(rates["rate"]), max(rates["rate"])))
    print("Total number of records:\n{}".format(rates.count()))
    print("First 5 records:\n{}".format(rates.head(5)))
    df = rates["userID"].groupby(rates["userID"])
    print("Minimum number of ratings per user: {}".format(df.count().min()))

    scores = rates["rate"].groupby(rates["rate"]).count()
    # Annotate each bar with its count
    for x, y in zip(scores.keys(), scores.values):
        plt.text(x, y + 2, "%.0f" % y, ha="center", va="bottom", fontsize=12)
    plt.bar(scores.keys(), scores.values, fc="r", tick_label=scores.keys())
    plt.xlabel("Rating Score")
    plt.ylabel("Number of Ratings")
    plt.title("Number of Ratings per Score")
    plt.show()

getRatings("./ml-1m/ratings.dat")
userID range: <1,6040>
movieID range: <1,3952>
Rating value range: <1,5>
Total number of records:
userID       1000209
movieID      1000209
rate         1000209
timestamp    1000209
dtype: int64
First 5 records:
   userID  movieID  rate  timestamp
0       1     1193     5  978300760
1       1      661     3  978302109
2       1      914     3  978301968
3       1     3408     4  978300275
4       1     2355     5  978824291
Minimum number of ratings per user: 20

[Figure: bar chart of the number of ratings per rating score (output_9_2.png)]
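
The minimum of 20 ratings per user is expected: the ml-1m collection only includes users who rated at least 20 movies. To see the full per-user distribution rather than just the minimum, a minimal sketch (reusing the same ratings.dat path as above):

import pandas as pd

rates = pd.read_table("./ml-1m/ratings.dat", header=None, sep="::",
                      engine="python",
                      names=["userID", "movieID", "rate", "timestamp"])

# Count ratings per user and summarize the distribution
# (the minimum should be 20, matching the output above)
per_user = rates.groupby("userID")["movieID"].count()
print(per_user.describe())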

Movies

def getMovies(file_path):
    movies = pd.read_table(
        file_path,
        header=None,
        sep="::",
        engine="python",
        names=["movieID", "title", "genres"]
    )

    print("movieID range: <{},{}>"
          .format(min(movies["movieID"]), max(movies["movieID"])))
    print("Total number of records:\n{}".format(movies.count()))
    # Count how many movies fall under each genre
    moviesDict = dict()
    for line in movies["genres"].values:
        for one in line.split("|"):
            moviesDict.setdefault(one, 0)
            moviesDict[one] += 1

    print("Number of genres: {}".format(len(moviesDict)))
    print("Genres: {}".format(moviesDict.keys()))
    print(moviesDict)

    newMD = sorted(moviesDict.items(), key=lambda x: x[1], reverse=True)
    # Labels and values for the pie chart, sorted by count
    labels = [newMD[i][0] for i in range(len(newMD))]
    values = [newMD[i][1] for i in range(len(newMD))]
    # One offset per wedge; larger values push the wedge farther from the center
    explode = [x * 0.01 for x in range(len(newMD))]
    # Equal aspect ratio so the pie is drawn as a circle
    plt.axes(aspect=1)
    # labeldistance: distance of the labels from the center; pctdistance: distance of the percentage text
    # autopct: format of the percentage text; shadow: draw a shadow behind the pie
    plt.pie(
        x=values,
        labels=labels,
        explode=explode,
        autopct="%3.1f %%",
        shadow=False,
        labeldistance=1.1,
        startangle=0,
        pctdistance=0.8,
        center=(-1, 0),
    )
    # bbox_to_anchor positions the legend (first value: horizontal, second: vertical offset)
    # ncol sets the number of legend columns (default 1)
    plt.legend(loc=7, bbox_to_anchor=(1.3, 1.0), ncol=3, fancybox=True, shadow=True, fontsize=6)
    plt.show()

getMovies("./ml-1m/movies.dat")
movieID range: <1,3952>
Total number of records:
movieID    3883
title      3883
genres     3883
dtype: int64
Number of genres: 18
Genres: dict_keys(['Animation', "Children's", 'Comedy', 'Adventure', 'Fantasy', 'Romance', 'Drama', 'Action', 'Crime', 'Thriller', 'Horror', 'Sci-Fi', 'Documentary', 'War', 'Musical', 'Mystery', 'Film-Noir', 'Western'])
{'Animation': 105, "Children's": 251, 'Comedy': 1200, 'Adventure': 283, 'Fantasy': 68, 'Romance': 471, 'Drama': 1603, 'Action': 503, 'Crime': 211, 'Thriller': 492, 'Horror': 343, 'Sci-Fi': 276, 'Documentary': 127, 'War': 143, 'Musical': 114, 'Mystery': 106, 'Film-Noir': 44, 'Western': 68}

[Figure: pie chart of the share of movies per genre (output_12_2.png)]
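
Note that a single movie can carry several genres (the genres field is '|'-separated), so the per-genre counts above sum to more than the 3883 movies. The same tally can be computed more compactly with collections.Counter; a minimal sketch (the encoding argument is an assumption, since some titles in movies.dat are not UTF-8):

import pandas as pd
from collections import Counter

movies = pd.read_table("./ml-1m/movies.dat", header=None, sep="::",
                       engine="python", encoding="ISO-8859-1",
                       names=["movieID", "title", "genres"])

# Flatten the '|'-separated genre lists and count each genre once per movie
genre_counts = Counter(g for line in movies["genres"] for g in line.split("|"))
print(genre_counts.most_common(5))  # Drama and Comedy should come first, as in the counts above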

Users

def getUsers(file_path):
    users = pd.read_table(
        file_path,
        header=None,
        sep="::",
        engine="python",
        names=["userID", "gender", "age", "Occupation", "zip-code"],
    )
    print("userID range: <{},{}>".format(min(users["userID"]), max(users["userID"])))
    print("Total number of records:\n{}".format(users.count()))

    # Gender distribution as a pie chart
    usersGender = users["gender"].groupby(users["gender"]).count()
    print(usersGender)

    plt.axes(aspect=1)
    plt.pie(x=usersGender.values, labels=usersGender.keys(), autopct="%3.1f %%")
    plt.legend(bbox_to_anchor=(1.0, 1.0))
    plt.show()

    # Age distribution (the age values are bucket codes, not exact ages)
    usersAge = users["age"].groupby(users["age"]).count()
    print(usersAge)

    plt.plot(
        usersAge.keys(),
        usersAge.values,
        label="User age distribution",
        linewidth=3,
        color="r",
        marker="o",
        markerfacecolor="blue",
        markersize=12,
    )
    # Annotate each point with its count
    for x, y in zip(usersAge.keys(), usersAge.values):
        plt.text(x, y + 10, "%.0f" % y, ha="center", va="bottom", fontsize=12)
    plt.xlabel("Age")
    plt.ylabel("Number of Users")
    plt.title("Age of users")
    plt.show()

getUsers("./ml-1m/users.dat")
userID range: <1,6040>
Total number of records:
userID        6040
gender        6040
age           6040
Occupation    6040
zip-code      6040
dtype: int64
gender
F    1709
M    4331
Name: gender, dtype: int64

[Figure: pie chart of user gender distribution (output_15_2.png)]

age
1      222
18    1103
25    2096
35    1193
45     550
50     496
56     380
Name: age, dtype: int64

[Figure: line plot of the number of users per age group]
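
The age values are bucket codes rather than exact ages (per the ml-1m README, 1 stands for "Under 18", 18 for "18-24", 25 for "25-34", and so on). Gender and age can also be inspected jointly; a minimal sketch, again assuming users.dat sits in ./ml-1m:

import pandas as pd

users = pd.read_table("./ml-1m/users.dat", header=None, sep="::",
                      engine="python",
                      names=["userID", "gender", "age", "Occupation", "zip-code"])

# Cross-tabulate the age buckets against gender
print(pd.crosstab(users["age"], users["gender"]))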
