[Python] 언어모델의 출력을 스트리밍 방식으로 출력하기

2023. 6. 11. 22:15·Python

 언어모델을 GPU에 올려 사용할 경우 출력까지 길게는 30초 이상 걸리는 경우가 있습니다. 만약 서비스에 적용한다고 하면 사용자 입장에서는 이 시간이 길게 느껴질 것입니다. 따라서 이 부분을 스트리밍 방식으로 출력하면 결과까지 걸리는 시간은 같지만 사용자는 기다린다는 느낌이 적어지므로 UX관점에서 해결할 수 있습니다. ChatGPT를 웹에서 사용할 때의 방식이라고 보시면 될 것 같습니다.

스트리밍 방식의 출력

0. 한 번에 출력하는 방식

 Koalpaca 5.8b 모델로 다음과 같은 질문을 했을 때 20초 정도가 걸렸습니다. 이걸 허깅페이스를 통해 스트리밍 방식으로 바꿔보겠습니다.

%%time 
inputs = tokenizer("###질문:피보나치 수열을 파이썬 코드로 만들어줘", return_tensors="pt",return_token_type_ids=False)
generated = model.generate(**inputs,max_new_tokens=256,eos_token_id=2,top_p=0.9,early_stopping=True)[0]
print(tokenizer.decode(generated))

 

###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if n <= 1:
        return n
    else:
        return fibonacci(n-1) + fibonacci(n-2)

for i in range(10):
    print(fibonacci(i))
```

출력:
```
fibonacci(10):
    print(i)
```<|endoftext|>
CPU times: user 24.3 s, sys: 387 ms, total: 24.6 s
Wall time: 20.7 s

1. 스트리밍 방식

스트리밍으로 출력을 하려면 허깅페이스의 Thread와 TextIteratorStreamer를 사용해야 합니다.

from threading import Thread
from transformers TextIteratorStreamer

1-1

TextIteratorStreamer를 사용하여 streamer를 만듭니다. 그리고 생성에 사용할 변수들을 딕셔너리형태로 만들어줍니다.

inputs = tokenizer("###질문:피보나치 수열을 파이썬 코드로 만들어줘", return_tensors="pt",return_token_type_ids=False)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=256,eos_token_id=2,top_p=0.9,early_stopping=True)

1-2

생성된 텍스트를 끊기지 않게(논-블록킹 방식) 가져오려면 별도의 스레드가 필요합니다.

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

1-3

생성된 제너레이터를 for문으로 출력하면 스트리밍 방식으로 토큰이 늘어납니다.

generated_text = ""
for new_text in streamer:
    generated_text += new_text
    print(generated_text)

 

###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    
###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if 
    
###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if n 
    
###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if n <= 
    
###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if n <= 
    
###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if n <= 1:

...

2. Gradio 적용

from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer


torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())


model = AutoModelForCausalLM.from_pretrained(
    "beomi/KoAlpaca-Polyglot-5.8B", 
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained("beomi/KoAlpaca-Polyglot-5.8B")

def run_generation(user_text):
    model_inputs = tokenizer([user_text], return_tensors="pt",return_token_type_ids=False).to(torch_device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=256,
        top_p=0.9,
        early_stopping=True,
        eos_token_id=2
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    return model_output


def reset_textbox():
    return gr.update(value='')


with gr.Blocks() as demo:

    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder="",
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")


    user_text.submit(run_generation, [user_text], model_output)
    button_submit.click(run_generation, [user_text], model_output)

    demo.queue(max_size=32).launch(enable_queue=True,share=True)

※ KoAlpaca는 [###질문: ###답변:]의 프롬프트로 학습했기 때문에 이런 형식의 프롬프트를 입력해야 제대로 된 결과가 나옵니다.

'Python' 카테고리의 다른 글

[Python] 컴파일로 속도 개선하기(Cython)  (0) 2023.10.31
[Python] 프로파일링으로 병목 찾기 (line_profiler)  (1) 2023.10.09
[Python] PEFT 라이브러리 알아보기  (0) 2023.05.29
[Python] Folium으로 지도에 행정구역 경계 표시하기  (0) 2023.02.26
[Python] selenium 사용 시 chromedriver 자동 업데이트하기  (0) 2023.02.12
'Python' 카테고리의 다른 글
  • [Python] 컴파일로 속도 개선하기(Cython)
  • [Python] 프로파일링으로 병목 찾기 (line_profiler)
  • [Python] PEFT 라이브러리 알아보기
  • [Python] Folium으로 지도에 행정구역 경계 표시하기
gunuuu
gunuuu
주로 AI, ML에 관한 글을 씁니다.
  • gunuuu
    gunuuu
    gunuuu
  • 전체
    오늘
    어제
    • 분류 전체보기 (40)
      • AI/ML (11)
        • NLP (8)
        • RAG (1)
      • Pytorch (5)
      • Python (11)
      • SQL (2)
      • Causal Inference (3)
        • DoWhy (1)
      • 일상 (3)
      • 책 (5)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    onnx
    Tree of Thought
    window function
    SQL
    DetectGPT
    미니어처 라이프 서울
    인과추론
    인과 추론
    허깅페이스
    bm25
    nlp
    TALLRec
    line_profiler
    DAPT
    FAISS
    TAPT
    cython
    Hybrid search
    Chain of Thought
    대규모언어모델
    PEFT
    벡터 db
    Low-Rank Adaptation
    itertuples
    KoAlpaca
    모델경량화
    sparse vector
    DoWhy
    DataDistributedParallel
    크레마S
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
gunuuu
[Python] 언어모델의 출력을 스트리밍 방식으로 출력하기
상단으로

티스토리툴바