“雙月”數據集的生成

                                        “雙月”數據集的生成

問題描述:

最近在看《神經網絡與機器學習》一書,裏面的實驗都是使用的“雙月”數據集,如圖所示:

但是書中並沒有給出這個數據集的來源或者生成方式,所以我自己根據圖中的描述信息,利用python來生成數據集,考慮到問題的靈活性,所以並沒有把數據集存儲下來,而是需要用的時候再生成。

實現思路:

基本假設:

  1. 區域A固定不動
  2. 區域B的圓心可以隨着d的不同以及r的不同而調整
  3. 區域A和區域B的半徑相同
  4. 半徑r表示圓心到圓環中部的位置

半圓的生成:(參數方程表示)

  1. 半徑:\large r\sim U(r-\frac{w}{2}, r+\frac{w}{2})(均勻分佈)
  2. 角度:區域A \large \theta\sim U(0,\pi ),區域B \large \theta\sim U(-\pi, 0) (角度都是均勻分佈)

代碼實現:

import numpy as np
import matplotlib.pyplot as plt

def generate_data(n, r=10, w=6, d=1):
    theta1 = np.random.uniform(0, np.pi, size=n)
    theta2 = np.random.uniform(-np.pi, 0, size=n)
    w1 = np.random.uniform(-w/2, w/2, size=n)
    w2 = np.random.uniform(-w/2, w/2, size=n)
    one = np.ones_like(theta1)

    # data_A_i = [1, coord_x, coord_y, label], label = 1 or -1
    data_A = np.array([one, (r+w1)*np.cos(theta1), (r+w1)*np.sin(theta1), one]).T
    data_B = np.array([one, r + (r+w2)*np.cos(theta2), -d + (r+w2)*np.sin(theta2), -one]).T
    return data_A, data_B

class Data:
    def __init__(self, n, r=10, w=6, d=1):
        self.n = n          # 數據對數,A、B區域各有n個點
        self.r = r          # 半徑
        self.w = w          # 圓環的寬度
        self.d = d          # 區域B的向下偏移量
        self.data_A = []    # 區域A的數據集
        self.data_B = []    # 區域B的數據集
        self.data_AB = []   # 混合區域A和區域B的數據集
    
    def get_data(self):
        self.data_A, self.data_B = generate_data(self.n, self.r, self.w, self.d)
        all_data = np.vstack([self.data_A,self.data_B])
        np.random.shuffle(all_data)
        self.data_AB = all_data

    def plot(self):
        fig = plt.figure()
        plt.scatter(self.data_A[:, 1], self.data_A[:, 2], marker='x')
        plt.scatter(self.data_B[:, 1], self.data_B[:, 2], marker='+')
        plt.show()
        

if __name__ == "__main__":
    # data_A, data_B = generate_data(1000)
    # print(data_A)
    # print(data_B)
    # fig = plt.figure()
    # plt.scatter(data_A[:,1], data_A[:, 2], marker='x')
    # plt.scatter(data_B[:,1], data_B[:, 2], marker='+')
    # plt.show()

    data_set = Data(1000)
    data_set.get_data()
    print(data_set.data_AB)
    data_set.plot()

示例:

n = 1000,r = 10,w = 6,d = 1.

 

注:

  1. 當時生成數據是爲了模式分類實驗去的,所以生成的數據帶有偏置項,同時帶有標籤項(區域A爲+1,區域B爲-1),如果有需要可以自行處理。
  2. 目前編程菜鳥一隻,不會面向對象編程卻硬要寫,大佬勿噴!如有疑問,敬請評論!
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章