“双月”数据集的生成

                                        “双月”数据集的生成

问题描述:

最近在看《神经网络与机器学习》一书,里面的实验都是使用的“双月”数据集,如图所示:

但是书中并没有给出这个数据集的来源或者生成方式,所以我自己根据图中的描述信息,利用python来生成数据集,考虑到问题的灵活性,所以并没有把数据集存储下来,而是需要用的时候再生成。

实现思路:

基本假设:

  1. 区域A固定不动
  2. 区域B的圆心可以随着d的不同以及r的不同而调整
  3. 区域A和区域B的半径相同
  4. 半径r表示圆心到圆环中部的位置

半圆的生成:(参数方程表示)

  1. 半径:\large r\sim U(r-\frac{w}{2}, r+\frac{w}{2})(均匀分布)
  2. 角度:区域A \large \theta\sim U(0,\pi ),区域B \large \theta\sim U(-\pi, 0) (角度都是均匀分布)

代码实现:

import numpy as np
import matplotlib.pyplot as plt

def generate_data(n, r=10, w=6, d=1):
    theta1 = np.random.uniform(0, np.pi, size=n)
    theta2 = np.random.uniform(-np.pi, 0, size=n)
    w1 = np.random.uniform(-w/2, w/2, size=n)
    w2 = np.random.uniform(-w/2, w/2, size=n)
    one = np.ones_like(theta1)

    # data_A_i = [1, coord_x, coord_y, label], label = 1 or -1
    data_A = np.array([one, (r+w1)*np.cos(theta1), (r+w1)*np.sin(theta1), one]).T
    data_B = np.array([one, r + (r+w2)*np.cos(theta2), -d + (r+w2)*np.sin(theta2), -one]).T
    return data_A, data_B

class Data:
    def __init__(self, n, r=10, w=6, d=1):
        self.n = n          # 数据对数,A、B区域各有n个点
        self.r = r          # 半径
        self.w = w          # 圆环的宽度
        self.d = d          # 区域B的向下偏移量
        self.data_A = []    # 区域A的数据集
        self.data_B = []    # 区域B的数据集
        self.data_AB = []   # 混合区域A和区域B的数据集
    
    def get_data(self):
        self.data_A, self.data_B = generate_data(self.n, self.r, self.w, self.d)
        all_data = np.vstack([self.data_A,self.data_B])
        np.random.shuffle(all_data)
        self.data_AB = all_data

    def plot(self):
        fig = plt.figure()
        plt.scatter(self.data_A[:, 1], self.data_A[:, 2], marker='x')
        plt.scatter(self.data_B[:, 1], self.data_B[:, 2], marker='+')
        plt.show()
        

if __name__ == "__main__":
    # data_A, data_B = generate_data(1000)
    # print(data_A)
    # print(data_B)
    # fig = plt.figure()
    # plt.scatter(data_A[:,1], data_A[:, 2], marker='x')
    # plt.scatter(data_B[:,1], data_B[:, 2], marker='+')
    # plt.show()

    data_set = Data(1000)
    data_set.get_data()
    print(data_set.data_AB)
    data_set.plot()

示例:

n = 1000,r = 10,w = 6,d = 1.

 

注:

  1. 当时生成数据是为了模式分类实验去的,所以生成的数据带有偏置项,同时带有标签项(区域A为+1,区域B为-1),如果有需要可以自行处理。
  2. 目前编程菜鸟一只,不会面向对象编程却硬要写,大佬勿喷!如有疑问,敬请评论!
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章