Python小練習:裁減函數(Clip Function)

Python小練習:裁減函數(Clip Function)

作者:凱魯嘎吉 - 博客園 http://www.cnblogs.com/kailugaji/

本文介紹兩種數據裁剪方法,將原始數據裁剪到某一指定範圍內。

1. clip_function_test.py

 1 # -*- coding: utf-8 -*-
 2 # Author:凱魯嘎吉 Coral Gajic
 3 # https://www.cnblogs.com/kailugaji/
 4 # Python小練習:裁減函數(Clip Function)
 5 import torch
 6 import numpy as np
 7 import matplotlib.pyplot as plt
 8 plt.rc('font',family='Times New Roman')
 9 # 裁剪範圍
10 LOG_STD_MAX = 2
11 LOG_STD_MIN = -10
12 def clip_function(
13     x: torch.Tensor,
14     bound_mode: str
15 ) -> torch.Tensor:
16     if bound_mode == "clamp": # 將x裁剪到[-10, 2]
17         # 大於2的統一設爲2,小於-10的統一設爲-10
18         x = torch.clamp(x, LOG_STD_MIN, LOG_STD_MAX)
19     elif bound_mode == "tanh": # 將x裁剪到[-10, 2]
20         scale = (LOG_STD_MAX-LOG_STD_MIN) / 2 # 6
21         x = (torch.tanh(x)+1) * scale + LOG_STD_MIN
22         # tanh:[-1, 1], torch.tanh()+1:[0, 2]
23         # (torch.tanh(x)+1) * scale:[0, 12]
24         # (torch.tanh(x)+1) * scale + LOG_STD_MIN:[-10, 2]
25     elif bound_mode == "no":
26         x = x
27     else:
28         raise NotImplementedError
29     return x
30 
31 torch.manual_seed(0)
32 x = torch.randn(2, 3)*10
33 print('原始數據:\n', x)
34 
35 str1 = 'clamp'
36 print('裁剪算子:', str1)
37 y = clip_function(x, str1)
38 print('裁剪後:\n', y)
39 
40 str2 = 'tanh'
41 print('裁剪算子:', str2)
42 y = clip_function(x, str2)
43 print('裁剪後:\n', y)
44 
45 # --------------------畫圖------------------------
46 num = 1000
47 a = torch.randn(num)*10.0
48 a, _ = torch.sort(a)
49 b1 = clip_function(a, str1)
50 b2 = clip_function(a, str2)
51 # 手動設置橫縱座標範圍
52 plt.xlim([0, num])
53 plt.ylim([a.min(), a.max()])
54 aa = np.arange(0, num)
55 plt.plot(aa, a, color = 'green', ls = '-', label = 'data')
56 plt.plot(aa, b1, color = 'red', ls = '-', label = str1)
57 plt.plot(aa, b2, color = 'blue', ls = '-', label = str2)
58 # 畫2條不起眼的虛線
59 plt.plot([0, num], [LOG_STD_MIN, LOG_STD_MIN], color = 'gray', ls = '--', alpha = 0.3)
60 plt.plot([0, num], [LOG_STD_MAX, LOG_STD_MAX], color = 'gray', ls = '--', alpha = 0.3)
61 # 橫縱座標軸
62 plt.xlabel('x')
63 plt.ylabel('clip(x)')
64 plt.legend(loc = 2)
65 plt.tight_layout()
66 plt.savefig('Clip Function.png', bbox_inches='tight', dpi=500)
67 plt.show()

2. 結果

D:\ProgramData\Anaconda3\python.exe "D:/Python code/2023.3 exercise/Other/clip_function_test.py"
原始數據:
 tensor([[ 15.4100,  -2.9343, -21.7879],
        [  5.6843, -10.8452, -13.9860]])
裁剪算子: clamp
裁剪後:
 tensor([[  2.0000,  -2.9343, -10.0000],
        [  2.0000, -10.0000, -10.0000]])
裁剪算子: tanh
裁剪後:
 tensor([[  2.0000,  -9.9662, -10.0000],
        [  1.9999, -10.0000, -10.0000]])

Process finished with exit code 0

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章