[Python] 언어모델의 출력을 스트리밍 방식으로 출력하기

2023. 6. 11. 22:15·Python

 언어모델을 GPU에 올려 사용할 경우 출력까지 길게는 30초 이상 걸리는 경우가 있습니다. 만약 서비스에 적용한다고 하면 사용자 입장에서는 이 시간이 길게 느껴질 것입니다. 따라서 이 부분을 스트리밍 방식으로 출력하면 결과까지 걸리는 시간은 같지만 사용자는 기다린다는 느낌이 적어지므로 UX관점에서 해결할 수 있습니다. ChatGPT를 웹에서 사용할 때의 방식이라고 보시면 될 것 같습니다.

스트리밍 방식의 출력

0. 한 번에 출력하는 방식

 Koalpaca 5.8b 모델로 다음과 같은 질문을 했을 때 20초 정도가 걸렸습니다. 이걸 허깅페이스를 통해 스트리밍 방식으로 바꿔보겠습니다.

%%time 
inputs = tokenizer("###질문:피보나치 수열을 파이썬 코드로 만들어줘", return_tensors="pt",return_token_type_ids=False)
generated = model.generate(**inputs,max_new_tokens=256,eos_token_id=2,top_p=0.9,early_stopping=True)[0]
print(tokenizer.decode(generated))

 

###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if n <= 1:
        return n
    else:
        return fibonacci(n-1) + fibonacci(n-2)

for i in range(10):
    print(fibonacci(i))
```

출력:
```
fibonacci(10):
    print(i)
```<|endoftext|>
CPU times: user 24.3 s, sys: 387 ms, total: 24.6 s
Wall time: 20.7 s

1. 스트리밍 방식

스트리밍으로 출력을 하려면 허깅페이스의 Thread와 TextIteratorStreamer를 사용해야 합니다.

from threading import Thread
from transformers TextIteratorStreamer

1-1

TextIteratorStreamer를 사용하여 streamer를 만듭니다. 그리고 생성에 사용할 변수들을 딕셔너리형태로 만들어줍니다.

inputs = tokenizer("###질문:피보나치 수열을 파이썬 코드로 만들어줘", return_tensors="pt",return_token_type_ids=False)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=256,eos_token_id=2,top_p=0.9,early_stopping=True)

1-2

생성된 텍스트를 끊기지 않게(논-블록킹 방식) 가져오려면 별도의 스레드가 필요합니다.

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

1-3

생성된 제너레이터를 for문으로 출력하면 스트리밍 방식으로 토큰이 늘어납니다.

generated_text = ""
for new_text in streamer:
    generated_text += new_text
    print(generated_text)

 

###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    
###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if 
    
###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if n 
    
###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if n <= 
    
###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if n <= 
    
###질문:피보나치 수열을 파이썬 코드로 만들어줘.

### 맥락:
```python
def fibonacci(n):
    if n <= 1:

...

2. Gradio 적용

from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer


torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())


model = AutoModelForCausalLM.from_pretrained(
    "beomi/KoAlpaca-Polyglot-5.8B", 
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained("beomi/KoAlpaca-Polyglot-5.8B")

def run_generation(user_text):
    model_inputs = tokenizer([user_text], return_tensors="pt",return_token_type_ids=False).to(torch_device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=256,
        top_p=0.9,
        early_stopping=True,
        eos_token_id=2
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    return model_output


def reset_textbox():
    return gr.update(value='')


with gr.Blocks() as demo:

    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder="",
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")


    user_text.submit(run_generation, [user_text], model_output)
    button_submit.click(run_generation, [user_text], model_output)

    demo.queue(max_size=32).launch(enable_queue=True,share=True)

※ KoAlpaca는 [###질문: ###답변:]의 프롬프트로 학습했기 때문에 이런 형식의 프롬프트를 입력해야 제대로 된 결과가 나옵니다.

'Python' 카테고리의 다른 글

[Python] 컴파일로 속도 개선하기(Cython)  (0) 2023.10.31
[Python] 프로파일링으로 병목 찾기 (line_profiler)  (1) 2023.10.09
[Python] PEFT 라이브러리 알아보기  (0) 2023.05.29
[Python] Folium으로 지도에 행정구역 경계 표시하기  (0) 2023.02.26
[Python] selenium 사용 시 chromedriver 자동 업데이트하기  (0) 2023.02.12
'Python' 카테고리의 다른 글
  • [Python] 컴파일로 속도 개선하기(Cython)
  • [Python] 프로파일링으로 병목 찾기 (line_profiler)
  • [Python] PEFT 라이브러리 알아보기
  • [Python] Folium으로 지도에 행정구역 경계 표시하기
gunuuu
gunuuu
주로 AI, ML에 관한 글을 씁니다.
  • gunuuu
    gunuuu
    gunuuu
  • 전체
    오늘
    어제
    • 분류 전체보기 (40)
      • AI/ML (11)
        • NLP (8)
        • RAG (1)
      • Pytorch (5)
      • Python (11)
      • SQL (2)
      • Causal Inference (3)
        • DoWhy (1)
      • 일상 (3)
      • 책 (5)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    itertuples
    크레마S
    KoAlpaca
    line_profiler
    cython
    Hybrid search
    TALLRec
    Chain of Thought
    window function
    Tree of Thought
    SQL
    대규모언어모델
    모델경량화
    허깅페이스
    인과 추론
    nlp
    DoWhy
    sparse vector
    PEFT
    DetectGPT
    TAPT
    DAPT
    FAISS
    DataDistributedParallel
    벡터 db
    미니어처 라이프 서울
    인과추론
    Low-Rank Adaptation
    bm25
    onnx
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
gunuuu
[Python] 언어모델의 출력을 스트리밍 방식으로 출력하기
상단으로

티스토리툴바