《PROMPT2MODEL: Generating Deployable Models from Natural Language Instructions》論文學習

一、Introduction

傳統上,從零開始構建一個自然語言處理(NLP)模型是一項重大任務。一個尋求解決新問題的NLP從業者需要定義他們的任務範圍,找到或創建目標任務領域的行爲數據,選擇合適的模型架構,訓練模型,通過評估評估其性能,然後將其部署到實際應用中。

Prompt2Model is a framework for generating a small yet accurate model from a prompt.

類似於GPT-3的語言生成模型(LLM)提供了一種通過“prompt提示”實現的輕量級的自然語言處理系統構建範式。從業者現在可以編寫一個prompt提示,指定預期的系統行爲(可以選擇性地提供一些演示),然後要求LLM通過text completion生成所需的輸出。這使得開發者可以快速地爲各種應用開發自然語言處理系統的原型,而無需編寫任何代碼。

然而,需要指出的是,目前仍存在着概念驗證原型和實際部署之間的差距。

  • 一方面,使用LLM進行提示可能很昂貴,因爲它們要麼需要大量的計算資源,要麼需要訪問商業API,並且它們對輸入提示質量的依賴使其與經過訓練的微調模型相比不穩定。因爲從業者通常沒有足夠的驗證數據來衡量系統的性能,所以在部署之前調試系統也更具挑戰性。
  • 此外,通過LLM進行提示的系統還存在使用上的挑戰。從業者對使用LLM的高成本和較慢的預測時間表示關注,而在高風險領域工作的人員由於隱私問題不能依賴商業LLM API。例如,在美國,將用戶數據與LLM服務提供商共享對於許多應用是非法的。

在這項工作中,我們提出了Prompt2Model,這是一個系統,它通過提示保留了以輕量級方式指定系統行爲的能力,同時產生一個可部署的特定目標模型,保持了微調模型的所有優勢。Prompt2Model被設計爲一個自動化流水線,通過三個渠道提取用戶提示中的關鍵任務信息,然後自動收集和綜合任務特定的知識:

  • 數據集檢索:在可能的情況下,我們通過檢索與任務相關的標註數據來收集訓練數據。
  • 數據集生成:我們利用一個LLM(“教師模型”)提取知識,通過使用它來生成一個僞標記數據集。之前的工作已經證明,這樣的數據集可以用來訓練一個較小的“學生”模型,以模擬教師模型的行爲。
  • 模型檢索:根據提示,我們確定一個預訓練的語言模型,其參數化知識適用於用戶的意圖。這個選擇的模型作爲學生模型,並利用生成和檢索的數據進行進一步的微調和評估。

Prompt2Model被設計爲支持每個組件的可插拔替換。

我們提供一個參考實現,其中展示了它與基於gpt-3.5-turbo的數據集生成器、基於DataFinder的數據集檢索器以及使用BM25的模型檢索器的實用性。我們對三個任務進行評估,涵蓋傳統的自然語言處理基準測試和新穎的應用,發現Prompt2Model在某些情況下生成的小模型在使用相同的提示作爲輸入時優於gpt-3.5-turbo。在這3個任務中的2箇中,我們觀察到與gpt-3.5-turbo基準相比的改進幅度超過20個點,儘管Prompt2Model生成的最終模型體積最多隻有gpt-3.5-turbo的700倍小。我們還發現我們可以生成有效的評估數據集。

我們相信Prompt2Model可以爲社區提供以下用途:

  1. 快速構建小型高效的自然語言處理系統的工具:Prompt2Model可以直接用於在幾小時內生成優於LLMs的特定任務模型,而無需進行任何手動數據標註或架構設計。該方法填補了概念驗證LLM原型和模型的實際部署之間的差距。
  2. 基於提示的端到端模型訓練的測試平臺:鑑於Prompt2Model的可擴展設計,它可以提供一個平臺,用於探索模型蒸餾、數據集生成、合成評估、數據集檢索和模型檢索等新技術。我們的平臺允許使用外在的下游度量標準研究這些組件,從而在這些研究領域取得經驗上的進展。

參考鏈接:

https://arxiv.org/pdf/2308.12261.pdf 

 

二、Prompt2Model Framework

我們的系統Prompt2Model提供了一個自動化的機器學習流程平臺:數據收集、模型訓練、評估和部署。

我們在下圖中展示了我們的自動化流程。 

The Prompt2Model architecture seeks to automate the core machine learning development pipeline, allowing us to train a small yet accurate model from just a prompt. 

核心是我們的自動化數據收集系統,它利用數據集檢索和基於LLM的數據集生成來獲取與用戶需求相關的標記數據。

然後,我們檢索預訓練模型,並對收集到的數據集的訓練集進行微調。

最後,我們在相同的數據集的測試集上評估我們訓練過的模型,並可選擇創建一個可以與模型交互的Web用戶界面。

我們通用的方法設計成模塊化和可擴展的,每個組件可以由從業者以不同方式實現或禁用。

0x1:Prompt Parser

作爲我們系統的主要輸入,用戶提供LLMs的prompt提示。這些提示包括一條指令和可選的幾個預期行爲的演示。

雖然這種開放式的界面對用戶來說很方便,但端到端的機器學習流水線強依賴於一個處理這種輸入的提示解析器,例如將提示分割成指令和單個演示,或將指令翻譯成英文。

我們將提示解析爲指令和示範字段,其中,

  • 指令代表主要任務或目標
  • 示範展示所需行爲

爲了實現這一點,我們利用具有上下文學習的LLM來分割用戶提示,在實驗中使用OpenAI gpt-3.5-turbo-0613。如果提供的指令被確定爲非英語語言,則使用DeepL API將其翻譯成英語。

0x2:Dataset Retriever 

給定一個prompt提示,我們首先嚐試發現現有的人工標註數據,以支持用戶的任務描述。

數據集檢索器有幾個設計決策:

  1. 搜索哪些數據集?
  2. 如何爲搜索建立數據集索引?
  3. 用戶任務需要哪些數據集列,哪些列應被忽略?

Färber和Leisinger以及Viswanathan等人的先前工作介紹了用於數據集搜索的系統。我們在我們的實現中使用了後者,稱爲DataFinder。

通過提取Hugging Face數據集中每個數據集的用戶生成數據集描述,我們利用DataFinder訓練的雙編碼檢索器對數據集進行排序,以找出最相關的數據集。一旦確定了一個相關的數據集,下一步是確定數據集的哪些列對應於用戶指定的輸入和期望的輸出。

由於自動識別任何數據集的正確模式可能具有挑戰性,我們採用了人機協作的方法。我們向用戶呈現默認情況下爲k=25個的前k個數據集,並允許用戶選擇最相關的數據集,或者聲明沒有一個數據集適合其任務。然後,我們要求用戶從數據集的模式中識別適當的輸入和輸出列。

0x3:Dataset Generator

並非所有任務都有現有的標註數據,許多任務僅與現有數據集存在弱相關。

爲了支持各種任務,我們引入了一個數據集生成器,根據Prompt解析器解析的用戶特定要求生成合成訓練數據。這個組件面臨成本效益、生成速度、示例多樣性和質量控制方面的挑戰。

我們精心設計了我們的數據集生成器,以實現速度優化的低成本生成,同時創建多樣且高質量的示例。我們的策略包括以下組成部分。

1、高質量多樣性的few-shot prompt(High-Diversity Few-Shot Prompting)

我們使用自動提示工程來生成多樣化的數據集。我們將用戶提供的演示示例與之前生成的示例的隨機樣本相結合,以促進多樣性並避免生成重複的示例。如果沒有這個策略,200個生成的問答示例中有120個是重複的;有了這個策略,只有25個是重複的。

2、溫度退火策略(Temperature Annealing)

我們根據已生成示例的數量,按比例調整採樣溫度,從低(偏向確定性輸出)調整到高(鼓勵多樣性探索)。

這種調節有助於保持輸出質量,同時逐漸鼓勵多樣性。

3、自一致解碼(Self-Consistency Decoding)

鑑於語言模型可能對相同的輸入生成非唯一或不正確的輸出,我們使用自一致性過濾來選擇僞標籤。

具體而言,我們通過選擇最常見的答案爲每個唯一輸入創建一個共識輸出,當出現常見答案之間的平局情況,我們啓發式地選擇最短的答案。這在確保唯一示例的同時提高了生成數據集的準確性。

4、異步批處理(Asynchronous Batching)

使用zeno-build,我們並行化API請求。我們使用額外的機制,如動態批處理大小和節流控制,來優化API的使用。

0x4:Model Retriever 

除了訓練數據之外,我們還必須確定一個適當的模型進行微調。

爲了支持多個任務使用統一的模型接口,我們目前限制在Hugging Face的編碼器-解碼器架構上,這是根據最近的研究表明編碼器-解碼器模型在模型蒸餾中具有更高的數據效率。這個限制仍然有很多預訓練模型可供選擇,例如:

  • 用於編碼相關任務的Salesforce/codet5-base
  • 用於阿拉伯語到英語翻譯的MaryaAI/opus-mt-ar-en-finetuned-ar-to-en

我們把選擇預訓練模型的問題看作一個搜索問題。根據用戶的指令作爲查詢,我們在Hugging Face的所有模型的文本描述中進行搜索。這個搜索任務具有挑戰性,因爲Hugging Face模型的描述往往很稀疏,包含很多模板化的文本,通常只有幾個詞表明模型的內容。

爲了解決這個問題,我們採用HyDE框架,首先使用gpt-3.5-turbo根據用戶的指令創建一個假設的模型描述。我們在下圖中展示了一個針對問答指令生成的假設文檔的示例。然後,我們將這個描述作爲擴展查詢,並應用BM25算法計算查詢-模型的相似度得分。 

For our model retriever, we first construct a hypothetical model description for a query, then compute similarity scores between that hypothetical model description and the descriptions of real models. 

爲了確保部署的便利性,我們過濾掉大小(以字節爲單位)超過用戶指定閾值的模型(默認設置爲3GB)。根據高下載量的模型往往更具質量的直覺,我們通過以下排名選擇頂級模型: 

0x5:Training

基於已獲取和生成的數據集以及預訓練模型,我們使用一個模型訓練器來在數據的子集上對模型進行微調。

目前,我們通過將所有任務視爲文本到文本生成的方式來訓練模型,但這個組件可以在未來擴展以支持新的方法。

1、Dataset Processing

我們通過利用兩個數據集來訓練模型,

  • 一個是生成的數據集
  • 一個是檢索的數據集

爲了避免領域特定建模的挑戰(例如爲分類或生成任務構建專門的架構),我們將所有數據集都視爲“文本到文本”問題。我們將每個數據集的輸入列文本化,並在輸入之前添加用戶的指令來指導模型。

2、Finetuning

我們將檢索到的數據集和生成的數據集連接起來,並在訓練學生模型之前對它們進行洗牌。我們爲所有任務使用相同的默認超參數。我們使用AdamW優化器進行訓練,lr = 5e-5,訓練3個時期,大約需要一個小時完成所有任務。

0x6:Evaluation

在對檢索到的和生成的數據集的部分進行模型訓練後,我們將剩餘的數據交給一個模型評估器模塊。

我們的目標是支持各種任務,但是爲任意任務選擇正確的任務特定度量標準是一個困難的問題。 

我們的模型評估器使用三個通用度量自動評估所有任務的模型:

  • 精確匹配(Exact Match):精確匹配度量模型輸出與參考答案完全匹配的頻率。
  • ChrF++:ChrF++平衡了精確度和召回率,用於評估文本生成質量。
  • BERTScore:BERTScore通過比較模型輸出和嵌入空間中的參考答案來捕捉語義相似性,儘管用詞或短語不同。

我們使用XLM-R作爲BERTScore的編碼器,以支持多語言評估。 

0x7:Web App Creation

爲了使開發者能夠向合作伙伴或用戶展示模型,我們包含了一個可選的組件,稱爲Demo Creator,以創建一個可視化界面來與模型進行交互,這個基於Gradio構建的Web應用可以輕鬆地在服務器上公開部署。

 

三、Discussion and Conclusion

我們提出了Prompt2Model框架,該框架僅使用自然語言提示自動生成任務特定模型。我們的概念驗證實驗證明,儘管使用了與LLMs相似的易於使用的界面,Prompt2Model仍然能夠生成小型但準確的模型,並且其生成的數據集可以用於估計實際性能。除了我們提供的可直接使用的參考實現工具外,Prompt2Model的可擴展設計和模塊化實現使其成爲推進模型蒸餾、數據集生成、合成評估、數據集檢索和模型檢索的平臺。

我們相信我們的Prompt2Model框架可以激發各種新穎的研究問題。我們希望我們的平臺能夠促使未來的工作更深入地研究生成數據和模型的質量保證。有趣的問題包括:

  • 我們應該爲下游模型訓練生成多少數據以及它應該具有多大的多樣性?
  • 我們如何有效地混合檢索和生成的數據集,以實現互補的優勢(例如,使用數據集生成來專注於檢索數據集未涵蓋的模型預期輸入)?
  • 由於用戶通常很難事先準確表達他們的需求,未來的擴展應該解決人在環路糾正的挑戰 - 要麼通過提供潛在策略來幫助人們迭代地完善提示,要麼允許人們在任務元數據提取和生成數據與其意圖不符合時進行事後修復。

我們希望提出明確的挑戰,並邀請社區爲我們的框架的各個組件提供新穎的實現。

 

四、Limitations

我們系統的一個主要限制是,

  • 我們目前的實驗都是使用gpt-3.5-turbo API進行的(用於提示解析、數據集生成和模型檢索)。這個LLM是付費閉源的,這使得它作爲科學文物存在問題(Rogers等,2023年)
  • 此外,這個LLM的服務提供商OpenAI禁止使用他們的API創建可能與OpenAI競爭的模型,這在商業應用中可能引起法律問題。我們正在探索集成開源LLM以避免對專有API的依賴。

我們工作的另一個侷限性是

  • Prompt2Model對處理英語以外的其他語言的任務的能力有限。雖然我們已經展示了我們的系統在從日語自然語言查詢生成代碼的支持方面的侷限性,但我們的系統很可能在處理資源較少的語言時遇到更大的困難。在我們的參考實現中,我們使用了未公開的gpt-3.5-turbo模型作爲我們的數據集生成器。該模型被認爲與GPT-3相似,後者是在93%的英語文檔、1%的德語文檔、1%的法語文檔和<5%的其他語言文檔上進行訓練的。
  • 我們對這個模型的使用可能加劇高資源語言和低資源語言之間現有語言技術差距的存在。

還有一個潛在的侷限性是

  • 我們只在3個任務上測試了我們的方法,每個任務只有一個數據集和一個評估指標。我們之所以做出這個決定,是因爲我們的重點是提供一個可擴展的軟件系統,而不是在許多數據集上建立最先進的結果,但我們認爲我們的結果表明具有更廣泛的適用性。

 

五、代碼示例

"""An commend line demo to run the whole system."""

import json
import logging
import os
import time
from pathlib import Path

import datasets
import pyfiglet
import torch
import transformers
import yaml
from datasets import concatenate_datasets, load_from_disk
from termcolor import colored

from prompt2model.dataset_generator.base import DatasetSplit
from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator
from prompt2model.dataset_processor.textualize import TextualizeProcessor
from prompt2model.dataset_retriever import DescriptionDatasetRetriever
from prompt2model.demo_creator import create_gradio
from prompt2model.model_evaluator import Seq2SeqEvaluator
from prompt2model.model_executor import GenerationModelExecutor
from prompt2model.model_retriever import DescriptionModelRetriever
from prompt2model.model_trainer.generate import GenerationModelTrainer
from prompt2model.prompt_parser import (
    MockPromptSpec,
    PromptBasedInstructionParser,
    TaskType,
)
from prompt2model.utils.logging_utils import get_formatted_logger


def line_print(input_str: str) -> None:
    """Print the given input string surrounded by horizontal lines.

    Args:
        input_str: The string to be printed.
    """
    print(f"{input_str}")


def print_logo():
    """Print the logo of Prompt2Model."""
    figlet = pyfiglet.Figlet(width=200)
    # Create ASCII art for each word and split into lines
    words = ["Prompt", "2", "Model"]
    colors = ["red", "green", "blue"]
    ascii_art_parts = [figlet.renderText(word).split("\n") for word in words]

    # Calculate the maximum height among the words
    max_height = max(len(part) for part in ascii_art_parts)

    # Equalize the height by adding empty lines at the bottom
    for part in ascii_art_parts:
        while len(part) < max_height:
            part.append("")

    # Zip the lines together, color them, and join them with a space
    ascii_art_lines = []
    for lines in zip(*ascii_art_parts):
        colored_line = " ".join(
            colored(line, color) for line, color in zip(lines, colors)
        )
        ascii_art_lines.append(colored_line)

    # Join the lines together to get the ASCII art
    ascii_art = "\n".join(ascii_art_lines)

    # Get the width of the terminal
    term_width = os.get_terminal_size().columns

    # Center the ASCII art
    centered_ascii_art = "\n".join(
        line.center(term_width) for line in ascii_art.split("\n")
    )

    line_print(centered_ascii_art)


def main():
    """The main function running the whole system."""
    print_logo()
    # Save the status of Prompt2Model for this session,
    # in case the user wishes to stop and continue later.
    if os.path.isfile("status.yaml"):
        with open("status.yaml", "r") as f:
            status = yaml.safe_load(f)
    else:
        status = {}

    while True:
        line_print("Do you want to start from scratch? (y/n)")
        answer = input()
        if answer.lower() == "n":
            if os.path.isfile("status.yaml"):
                with open("status.yaml", "r") as f:
                    status = yaml.safe_load(f)
                    print(f"Current status:\n{json.dumps(status, indent=4)}")
                    break
            else:
                status = {}
                break
        elif answer.lower() == "y":
            status = {}
            break
        else:
            continue

    propmt_has_been_parsed = status.get("prompt_has_been_parsed", False)
    dataset_has_been_retrieved = status.get("dataset_has_been_retrieved", False)
    model_has_been_retrieved = status.get("model_has_been_retrieved", False)
    dataset_has_been_generated = status.get("dataset_has_been_generated", False)
    model_has_been_trained = status.get("model_has_been_trained", False)
    if not propmt_has_been_parsed:
        prompt = ""
        line_print(
            "Enter your task description and few-shot examples (or 'done' to finish):"
        )
        time.sleep(2)
        while True:
            line = input()
            if line == "done":
                break
            prompt += line + "\n"
        line_print("Parsing prompt...")
        prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
        prompt_spec.parse_from_prompt(prompt)

        propmt_has_been_parsed = True
        status["instruction"] = prompt_spec.instruction
        status["examples"] = prompt_spec.examples
        status["prompt_has_been_parsed"] = True
        with open("status.yaml", "w") as f:
            yaml.safe_dump(status, f)
        line_print("Prompt parsed.")

    if propmt_has_been_parsed and not dataset_has_been_retrieved:
        prompt_spec = MockPromptSpec(
            TaskType.TEXT_GENERATION, status["instruction"], status["examples"]
        )
        line_print("Retrieving dataset...")
        retriever = DescriptionDatasetRetriever()
        retrieved_dataset_dict = retriever.retrieve_dataset_dict(prompt_spec)
        dataset_has_been_retrieved = True
        if retrieved_dataset_dict is not None:
            retrieved_dataset_dict.save_to_disk("retrieved_dataset_dict")
            status["retrieved_dataset_dict_root"] = "retrieved_dataset_dict"
        else:
            status["retrieved_dataset_dict_root"] = None
        status["dataset_has_been_retrieved"] = True
        with open("status.yaml", "w") as f:
            yaml.safe_dump(status, f)

    if (
        propmt_has_been_parsed
        and dataset_has_been_retrieved
        and not model_has_been_retrieved
    ):
        line_print("Retrieving model...")
        prompt_spec = MockPromptSpec(
            TaskType.TEXT_GENERATION, status["instruction"], status["examples"]
        )
        retriever = DescriptionModelRetriever(
            model_descriptions_index_path="huggingface_data/huggingface_models/model_info/",  # noqa E501
            use_bm25=True,
            use_HyDE=True,
        )
        top_model_name = retriever.retrieve(prompt_spec)
        line_print("Here are the models we retrieved.")
        for idx, each in enumerate(top_model_name):
            line_print(f"# {idx + 1}: {each}")
        while True:
            line_print(
                "Enter the number of the model you want to use. Range from 1 to 5."
            )
            line = input()
            try:
                rank = int(line)
                assert 1 <= rank <= 5
                break
            except Exception:
                line_print("Invalid input. Please enter a number.")
        model_has_been_retrieved = True
        status["model_has_been_retrieved"] = True
        status["model_name"] = top_model_name[rank - 1]
        with open("status.yaml", "w") as f:
            yaml.safe_dump(status, f)

    if (
        propmt_has_been_parsed
        and dataset_has_been_retrieved
        and model_has_been_retrieved
        and not dataset_has_been_generated
    ):
        prompt_spec = MockPromptSpec(
            TaskType.TEXT_GENERATION, status["instruction"], status["examples"]
        )
        generator_logger = get_formatted_logger("DatasetGenerator")
        generator_logger.setLevel(logging.INFO)
        line_print("The dataset generation has not finished.")
        time.sleep(2)
        line_print(f"Your input instruction:\n\n{prompt_spec.instruction}")
        time.sleep(2)
        line_print(f"Your input few-shot examples:\n\n{prompt_spec.examples}")
        time.sleep(2)
        while True:
            line_print("Enter the number of examples you wish to generate:")
            line = input()
            try:
                num_expected = int(line)
                break
            except ValueError:
                line_print("Invalid input. Please enter a number.")
        while True:
            line_print("Enter the initial temperature:")
            line = input()
            try:
                initial_temperature = float(line)
                assert 0 <= initial_temperature <= 2.0
                break
            except Exception:
                line_print(
                    "Invalid initial temperature. Please enter a number (float) between 0 and 2."  # noqa E501
                )
        while True:
            line_print("Enter the max temperature (we suggest 1.4):")
            line = input()
            try:
                max_temperature = float(line)
                assert 0 <= max_temperature <= 2.0
                break
            except Exception:
                line_print(
                    "Invalid max temperature. Please enter a float between 0 and 2."
                )
        line_print("Starting to generate dataset. This may take a while...")
        time.sleep(2)
        unlimited_dataset_generator = PromptBasedDatasetGenerator(
            initial_temperature=initial_temperature,
            max_temperature=max_temperature,
            responses_per_request=3,
        )
        generated_dataset = unlimited_dataset_generator.generate_dataset_split(
            prompt_spec, num_expected, split=DatasetSplit.TRAIN
        )
        generated_dataset.save_to_disk("generated_dataset")
        dataset_has_been_generated = True
        status["dataset_has_been_generated"] = True
        with open("status.yaml", "w") as f:
            yaml.safe_dump(status, f)
        line_print("The generated dataset is ready.")
        time.sleep(2)

    if (
        propmt_has_been_parsed
        and dataset_has_been_retrieved
        and model_has_been_retrieved
        and dataset_has_been_generated
        and not model_has_been_trained
    ):
        line_print("The model has not been trained.")
        time.sleep(2)
        dataset_root = Path("generated_dataset")
        if not dataset_root.exists():
            raise ValueError("Dataset has not been generated yet.")
        trained_model_root = Path("result/trained_model")
        trained_tokenizer_root = Path("result/trained_tokenizer")
        RESULT_PATH = Path("result/result")
        trained_model_root.mkdir(parents=True, exist_ok=True)
        trained_tokenizer_root.mkdir(parents=True, exist_ok=True)
        RESULT_PATH.mkdir(parents=True, exist_ok=True)
        dataset = load_from_disk(dataset_root)
        if status["retrieved_dataset_dict_root"] is not None:
            cached_retrieved_dataset_dict = datasets.load_from_disk(
                status["retrieved_dataset_dict_root"]
            )
            dataset_list = [dataset, cached_retrieved_dataset_dict["train"]]
        else:
            dataset_list = [dataset]

        line_print("Processing datasets.")
        instruction = status["instruction"]
        t5_processor = TextualizeProcessor(has_encoder=True)
        t5_modified_dataset_dicts = t5_processor.process_dataset_lists(
            instruction,
            dataset_list,
            train_proportion=0.6,
            val_proportion=0.2,
            maximum_example_num=3000,
        )
        processor_logger = get_formatted_logger("DatasetProcessor")
        processor_logger.setLevel(logging.INFO)
        training_datasets = []
        validation_datasets = []
        test_datasets = []
        for idx, modified_dataset_dict in enumerate(t5_modified_dataset_dicts):
            training_datasets.append(modified_dataset_dict["train"])
            validation_datasets.append(modified_dataset_dict["val"])
            test_datasets.append(modified_dataset_dict["test"])
        trainer_logger = get_formatted_logger("ModelTrainer")
        trainer_logger.setLevel(logging.INFO)
        evaluator_logger = get_formatted_logger("ModelEvaluator")
        evaluator_logger.setLevel(logging.INFO)

        while True:
            line = input("Enter the training batch size:")
            try:
                train_batch_size = int(line)
                assert 0 < train_batch_size
                break
            except Exception:
                line_print("The training batch size must be greater than 0.")
        time.sleep(1)

        while True:
            line = input("Enter the number of epochs to train for:")
            try:
                num_epochs = int(line)
                break
            except ValueError:
                line_print("Invalid input. Please enter a number.")
        time.sleep(1)

        trainer = GenerationModelTrainer(
            status["model_name"],
            has_encoder=True,
            executor_batch_size=train_batch_size,
            tokenizer_max_length=1024,
            sequence_max_length=1280,
        )
        args_output_root = Path("result/training_output")
        args_output_root.mkdir(parents=True, exist_ok=True)
        line_print("Starting training.")
        trained_model, trained_tokenizer = trainer.train_model(
            hyperparameter_choices={
                "output_dir": str(args_output_root),
                "save_strategy": "epoch",
                "num_train_epochs": num_epochs,
                "per_device_train_batch_size": train_batch_size,
                "evaluation_strategy": "epoch",
            },
            training_datasets=training_datasets,
            validation_datasets=validation_datasets,
        )
        trained_model.save_pretrained(trained_model_root)
        trained_tokenizer.save_pretrained(trained_tokenizer_root)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        trained_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
            trained_model_root
        ).to(device)
        trained_tokenizer = transformers.AutoTokenizer.from_pretrained(
            trained_tokenizer_root
        )
        line_print("Finished training. Now evaluating on the test set.")
        test_dataset = concatenate_datasets(test_datasets)

        model_executor = GenerationModelExecutor(
            trained_model,
            trained_tokenizer,
            train_batch_size,
            tokenizer_max_length=1024,
            sequence_max_length=1280,
        )
        t5_outputs = model_executor.make_prediction(
            test_set=test_dataset, input_column="model_input"
        )
        evaluator = Seq2SeqEvaluator()
        metric_values = evaluator.evaluate_model(
            test_dataset,
            "model_output",
            t5_outputs,
            encoder_model_name="xlm-roberta-base",
        )
        line_print(metric_values)
        with open(RESULT_PATH / "metric.txt", "w") as result_file:
            for metric_name, metric_value in metric_values.items():
                result_file.write(f"{metric_name}: {metric_value}\n")
        status["model_has_been_trained"] = model_has_been_trained = True
        status["trained_model_root"] = str(trained_model_root)
        status["trained_tokenizer_root"] = str(trained_tokenizer_root)
        with open("status.yaml", "w") as f:
            yaml.safe_dump(status, f)
        line_print("Model has been trained and evaluated.")

    t5_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
        status["trained_model_root"]
    ).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    t5_tokenizer = transformers.AutoTokenizer.from_pretrained(
        status["trained_tokenizer_root"]
    )
    model_executor = GenerationModelExecutor(
        t5_model, t5_tokenizer, 1, tokenizer_max_length=1024, sequence_max_length=1280
    )
    prompt_spec = MockPromptSpec(
        TaskType.TEXT_GENERATION, status["instruction"], status["examples"]
    )
    interface_t5 = create_gradio(model_executor, prompt_spec)
    interface_t5.launch(share=True)


if __name__ == "__main__":
    main()
View Code

參考鏈接:

https://colab.research.google.com/github/neulab/prompt2model/blob/main/prompt2model_demo.ipynb
https://github.com/neulab/prompt2model 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章