最近看到一篇推文是在不量化、不損失精度的情況下使用一張16G的顯卡推理70B的大模型。方案來自於kaggle的一個方案,具體流程爲:
1.創建一個空的(例如,沒有權重的)模型
2.決定每一層將要去哪裏(當有多個設備可用時)
3.在內存中加載其權重的一部分
4.在空模型中加載這些權重
5.將權重移動到設備上進行推理
6.從第3步重複,直到所有的權重都被加載
PyTorch 1.9引入了一種新的設備,稱爲元設備(meta device)。
這使我們能夠創建沒有任何數據附加的張量,元設備上的張量只需要一個shape,只要你在元設備上,你就可以創建任意大的張量,而不必擔心CPU(或GPU)的RAM夠不夠。
比如下面的代碼,內存不夠的話就會崩掉
1 import torch 2 large_tensor = torch.randn(100000, 100000)
這個大張量需要4 * 10**10字節(默認精度是FP32,所以張量的每個元素佔用4字節),因此需要40GB的RAM。然而,在元設備上執行相同的操作就可以正常運行:
1 import torch 2 large_tensor = torch.randn(100000, 100000, device="meta")
這個張量沒有關聯的數據,只有一個形狀。你可以直接在元設備上實例化一個模型:
1 large_model = torch.nn.Linear(100000, 100000, device="meta")
但是對於現成的模型來說,這種語法需要你重寫所有的建模代碼,以便每個模型的子部分都接受並傳遞一個設備關鍵字參數。由於這對Transformers庫的預訓練模型來說不切實際,accelerate庫有一個context manager,整合了meta device可以實例化一個空模型。
1 # Load meta model (no memory used) 2 with init_empty_weights(): 3 self.model = AutoModelForCausalLM.from_config(self.config, trust_remote_code=True) 4 self.model.tie_weights()
這一步很關鍵,我們知道每個權重的形狀,因此我們可以知道一旦我們完全加載預訓練的張量,它們將消耗多少內存。因此,我們可以決定如何在CPU和GPU之間分割我們的模型。
除此之外,定義了兩個關鍵的方法,分別是load_layer_to_cpu,負責把 權重從disk挪到CPU,另外一個是move_layer_to_device,負責把權重從cpu挪到顯卡。還有一個釋放顯存的方法clean_memory,負責清空顯存。
1 def load_layer_to_cpu(self, layer_name): 2 self.weights_loader.set_state_dict(layer_name, self.device) 3 state_dict = self.weights_loader.get_state_dict(self.device) 4 if "value_head.weight" in state_dict: 5 state_dict = {"lm_head.weight" : state_dict["value_head.weight"]} 6 return state_dict 7 8 def move_layer_to_device(self, state_dict): 9 for param_name, param in state_dict.items(): 10 assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)" 11 set_module_tensor_to_device(self.model, param_name, self.device, value=param, dtype=self.dtype) 12 13 def clean_memory(): 14 gc.collect() 15 ctypes.CDLL("libc.so.6").malloc_trim(0) 16 torch.cuda.empty_cache()
下面展示完整的代碼
1 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel 2 from accelerate import init_empty_weights 3 from accelerate.utils.modeling import set_module_tensor_to_device 4 from safetensors.torch import load_file 5 from optimum.bettertransformer import BetterTransformer 6 7 N_BATCHES = 3 8 MAX_LENGTH = 4096 9 10 def clean_memory(): 11 gc.collect() 12 ctypes.CDLL("libc.so.6").malloc_trim(0) 13 torch.cuda.empty_cache() 14 15 16 # Class for sharded llama 17 class ShardedLlama: 18 def __init__(self, checkpoint_path, weights_loader, device="cuda:0", dtype=torch.float16): 19 20 # Save parameters 21 self.checkpoint_path = Path(checkpoint_path) 22 self.weights_loader = weights_loader 23 self.device = device 24 self.dtype = dtype 25 26 # Create model 27 self.config = AutoConfig.from_pretrained(self.checkpoint_path) 28 self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path) 29 self.tokenizer.pad_token = self.tokenizer.eos_token 30 self.tokenizer.padding_side = "right" 31 self.init_model() 32 self.layer_names = ["model.embed_tokens"] + [f"model.layers.{i}" for i in range(len(self.model.model.layers))] + ["model.norm", "value_head"] 33 34 def init_model(self): 35 36 # Load meta model (no memory used) 37 with init_empty_weights(): 38 self.model = AutoModelForCausalLM.from_config(self.config) 39 self.model.lm_head = torch.nn.Linear(8192, 8, bias=False) # originally 32k 40 self.model.eval() 41 self.model = BetterTransformer.transform(self.model) # enable flash attention 42 self.model.tie_weights() 43 44 self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) + [self.model.model.norm, self.model.lm_head] 45 46 # Move buffers to device (note that much GPU memory used) 47 for buffer_name, buffer in self.model.named_buffers(): 48 set_module_tensor_to_device(self.model, buffer_name, self.device, value=buffer, dtype=self.dtype) 49 50 def load_layer_to_cpu(self, layer_name): 51 self.weights_loader.set_state_dict(layer_name, self.device) 52 state_dict = self.weights_loader.get_state_dict(self.device) 53 if "value_head.weight" in state_dict: 54 state_dict = {"lm_head.weight" : state_dict["value_head.weight"]} 55 return state_dict 56 57 def move_layer_to_device(self, state_dict): 58 for param_name, param in state_dict.items(): 59 assert param.dtype != torch.int8, "int8 not supported (need to add fp16_statistics)" 60 set_module_tensor_to_device(self.model, param_name, self.device, value=param, dtype=self.dtype) 61 62 def __call__(self, inputs): 63 # inputs = [(prefix, suffix), ...] with prefix.shape[0] = 1 and suffix.shape[0] = 5 64 65 # Reboot the model to make sure buffers are loaded and memory is clean 66 del self.model 67 clean_memory() 68 self.init_model() 69 70 # Send batch to device 71 batch = [(prefix.to(self.device), suffix.to(self.device)) for prefix, suffix in inputs] 72 n_suffixes = len(batch[0][1]) 73 suffix_eos = [(suffix != self.tokenizer.pad_token_id).sum(1) - 1 for _, suffix in inputs] 74 75 # Create attention mask for the largest input, and position ids to use KV cache 76 attention_mask = torch.ones(MAX_LENGTH, MAX_LENGTH) 77 attention_mask = attention_mask.triu(diagonal=1)[None, None, ...] == 0 78 attention_mask = attention_mask.to(self.device) 79 position_ids = torch.arange(MAX_LENGTH, dtype=torch.long, device=self.device)[None, :] 80 81 with ThreadPoolExecutor() as executor, torch.inference_mode(): 82 83 # Load first layer 84 future = executor.submit(self.load_layer_to_cpu, "model.embed_tokens") 85 86 for i, (layer_name, layer) in tqdm(enumerate(zip(self.layer_names, self.layers)), desc=self.device, total=len(self.layers)): 87 88 # Load current layer and prepare next layer 89 state_dict = future.result() 90 if (i + 1) < len(self.layer_names): 91 future = executor.submit(self.load_layer_to_cpu, self.layer_names[i + 1]) 92 self.move_layer_to_device(state_dict) 93 94 # Run layer 95 for j, (prefix, suffix) in enumerate(batch): 96 if layer_name == "model.embed_tokens": 97 batch[j] = (layer(prefix), layer(suffix)) 98 elif layer_name == "model.norm": 99 # Only keep the last token at this point 100 batch[j] = (None, layer(suffix[torch.arange(n_suffixes), suffix_eos[j]][:, None])) 101 elif layer_name == "value_head": 102 batch[j] = layer(suffix)[:, 0].mean(1).detach().cpu().numpy() 103 else: 104 # Run prefix 105 len_p, len_s = prefix.shape[1], suffix.shape[1] 106 new_prefix, (k_cache, v_cache) = layer(prefix, use_cache=True, attention_mask=attention_mask[:, :, -len_p:, -len_p:]) 107 108 # Run suffix 109 pos = position_ids[:, len_p:len_p + len_s].expand(n_suffixes, -1) 110 attn = attention_mask[:, :, -len_s:, -len_p - len_s:].expand(n_suffixes, -1, -1, -1) 111 kv_cache = (k_cache.expand(n_suffixes, -1, -1, -1), v_cache.expand(n_suffixes, -1, -1, -1)) 112 new_suffix = layer(suffix, past_key_value=kv_cache, position_ids=pos, attention_mask=attn)[0] 113 batch[j] = (new_prefix, new_suffix) 114 115 # Remove previous layer from memory (including buffers) 116 layer.to("meta") 117 clean_memory() # proposed by CPMP 118 119 # Get scores 120 return batch 121 122 123 124 125 def run_model(device, df, weights_loader): 126 model = ShardedLlama(checkpoint_path, weights_loader, device=device) 127 f = partial(get_tokens, tokenizer=model.tokenizer) 128 inputs = df.apply(f, axis=1).values 129 batches = np.array_split(inputs, N_BATCHES) 130 outputs = [] 131 for i, batch in enumerate(batches): 132 outputs += model(batch) 133 return outputs
完整代碼參考:https://www.kaggle.com/code/simjeg/platypus2-70b-without-wikipedia-rag
文章來源: