我們寫好一個python模塊供他人在命令行下調用時,有時需要在命令行指定參數傳入到模塊,通過argparse包可以便捷地接收參數。
例如我用pytorch定義了一個神經網絡模型,需要命令行運行main.py時傳入參數給初始化函數Net.init()
python main.py --task train --batch_size 10 --result_file ./result/deeptte.res --pooling_method attention --kernel_size 3 --alpha 0.1 --log_file run_log
import torch.nn as nn
class Net(nn.Module):
def __init__(self, kernel_size=3, num_filter=32, pooling_method='attention', num_final_fcs=3, final_fc_size=128,
alpha=0.3):
super(Net, self).__init__()
self.kernel_size = kernel_size
self.num_filter = num_filter
self.pooling_method = pooling_method
self.num_final_fcs = num_final_fcs
self.final_fc_size = final_fc_size
self.alpha = alpha
......
使用argparse首先通過ArgumentParser()
創建一個解析器parser,然後通過add_argument()
爲解析器添加要解析的參數,最後通過parse_args()
就可以獲得命令行傳來的參數
import argparse
# 創建解析器
parser = argparse.ArgumentParser()
# 添加參數
parser.add_argument('--task', type=str)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--weight_file', type=str)
parser.add_argument('--result_file', type=str)
parser.add_argument('--kernel_size', type=int)
parser.add_argument('--pooling_method', type=str)
parser.add_argument('--alpha', type=float)
parser.add_argument('--log_file', type=str)
# 解析參數
args = parser.parse_args()
接下來需要對參數進行過濾並利用參數來初始化神經網絡。我們獲取的參數args有的可能並不是Net.init()所需要的,因此通過inspect.getargspec()
方法來獲取Net.init()方法的參數列表model_args,然後將命令行獲得的參數轉化爲dict並進行遍歷,如果參數不在model_args中,則pop掉。最後利用剩下的參數去初始化Net
import inspect
model_args = inspect.getargspec(model_class.__init__).args # 初始化函數所需參數
shell_args = args._get_kwargs() # 命令行輸入的參數
# 對參數進行過濾
kwargs = dict(shell_args)
for arg, val in shell_args:
if arg not in model_args:
kwargs.pop(arg)
# 使用參數初始化神經網絡對象
model = Net(**kwargs)