Python小练习:裁减函数(Clip Function)

Python小练习:裁减函数(Clip Function)

作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/

本文介绍两种数据裁剪方法,将原始数据裁剪到某一指定范围内。

1. clip_function_test.py

 1 # -*- coding: utf-8 -*-
 2 # Author:凯鲁嘎吉 Coral Gajic
 3 # https://www.cnblogs.com/kailugaji/
 4 # Python小练习:裁减函数(Clip Function)
 5 import torch
 6 import numpy as np
 7 import matplotlib.pyplot as plt
 8 plt.rc('font',family='Times New Roman')
 9 # 裁剪范围
10 LOG_STD_MAX = 2
11 LOG_STD_MIN = -10
12 def clip_function(
13     x: torch.Tensor,
14     bound_mode: str
15 ) -> torch.Tensor:
16     if bound_mode == "clamp": # 将x裁剪到[-10, 2]
17         # 大于2的统一设为2,小于-10的统一设为-10
18         x = torch.clamp(x, LOG_STD_MIN, LOG_STD_MAX)
19     elif bound_mode == "tanh": # 将x裁剪到[-10, 2]
20         scale = (LOG_STD_MAX-LOG_STD_MIN) / 2 # 6
21         x = (torch.tanh(x)+1) * scale + LOG_STD_MIN
22         # tanh:[-1, 1], torch.tanh()+1:[0, 2]
23         # (torch.tanh(x)+1) * scale:[0, 12]
24         # (torch.tanh(x)+1) * scale + LOG_STD_MIN:[-10, 2]
25     elif bound_mode == "no":
26         x = x
27     else:
28         raise NotImplementedError
29     return x
30 
31 torch.manual_seed(0)
32 x = torch.randn(2, 3)*10
33 print('原始数据:\n', x)
34 
35 str1 = 'clamp'
36 print('裁剪算子:', str1)
37 y = clip_function(x, str1)
38 print('裁剪后:\n', y)
39 
40 str2 = 'tanh'
41 print('裁剪算子:', str2)
42 y = clip_function(x, str2)
43 print('裁剪后:\n', y)
44 
45 # --------------------画图------------------------
46 num = 1000
47 a = torch.randn(num)*10.0
48 a, _ = torch.sort(a)
49 b1 = clip_function(a, str1)
50 b2 = clip_function(a, str2)
51 # 手动设置横纵座标范围
52 plt.xlim([0, num])
53 plt.ylim([a.min(), a.max()])
54 aa = np.arange(0, num)
55 plt.plot(aa, a, color = 'green', ls = '-', label = 'data')
56 plt.plot(aa, b1, color = 'red', ls = '-', label = str1)
57 plt.plot(aa, b2, color = 'blue', ls = '-', label = str2)
58 # 画2条不起眼的虚线
59 plt.plot([0, num], [LOG_STD_MIN, LOG_STD_MIN], color = 'gray', ls = '--', alpha = 0.3)
60 plt.plot([0, num], [LOG_STD_MAX, LOG_STD_MAX], color = 'gray', ls = '--', alpha = 0.3)
61 # 横纵座标轴
62 plt.xlabel('x')
63 plt.ylabel('clip(x)')
64 plt.legend(loc = 2)
65 plt.tight_layout()
66 plt.savefig('Clip Function.png', bbox_inches='tight', dpi=500)
67 plt.show()

2. 结果

D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/Other/clip_function_test.py"
原始数据:
 tensor([[ 15.4100,  -2.9343, -21.7879],
        [  5.6843, -10.8452, -13.9860]])
裁剪算子: clamp
裁剪后:
 tensor([[  2.0000,  -2.9343, -10.0000],
        [  2.0000, -10.0000, -10.0000]])
裁剪算子: tanh
裁剪后:
 tensor([[  2.0000,  -9.9662, -10.0000],
        [  1.9999, -10.0000, -10.0000]])

Process finished with exit code 0

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章