# plamo-2-1b / app.py — Hugging Face Space by masanorihirano
# (commit 77f7edf, "Update app.py")
import asyncio
import gradio as gr
import transformers
from transformers import (
TextIteratorStreamer,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
import threading
import ctypes
# Load the PLaMo-2-1B tokenizer and text-generation pipeline at import time.
# trust_remote_code=True is required because the model repo ships custom code.
tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
pipeline = transformers.pipeline(
    "text-generation",
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,
)
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any configured stop string appears in the output.

    The criteria decodes only the last two generated tokens on each step, so
    multi-character stop strings (e.g. "\n\n") split across a token boundary
    are still detected.
    """

    def __init__(self, stops=None, encounters=1):
        """Create the criteria.

        Args:
            stops: Iterable of stop strings. Defaults to no stop strings.
                (The original used a mutable default ``stops=[]`` — a shared
                mutable-default-argument pitfall; fixed with a None sentinel.)
            encounters: Accepted for interface compatibility but unused.
        """
        super().__init__()
        # Copy so later caller-side mutation cannot affect this object.
        self.stops = list(stops) if stops is not None else []

    def __call__(self, input_ids, scores):
        """Return True when a stop string occurs in the decoded tail."""
        # Decode the last two tokens (uses the module-level ``tokenizer``).
        decoded_tail = tokenizer.decode(input_ids[0][-2:])
        return any(stop in decoded_tail for stop in self.stops)
# Stop generation as soon as a blank line ("\n\n") is produced.
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=["\n\n"])])
class CancelableThread(threading.Thread):
    """Thread that can be cancelled by injecting SystemExit into it.

    Cancellation uses ``ctypes.pythonapi.PyThreadState_SetAsyncExc``, which
    raises an exception asynchronously inside the target thread so that
    long-running pure-Python work (e.g. model generation) can be aborted.
    """

    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None):
        threading.Thread.__init__(self, group=group, target=target, name=name)
        self.args = args
        # Fix mutable default argument: the original used ``kwargs={}``,
        # which is shared between all instances.
        self.kwargs = kwargs if kwargs is not None else {}
        # Set in run(); None means the thread has not started yet.
        self.id = None

    def run(self):
        # BUG FIX: the original stored threading.get_native_id() (the OS-level
        # thread id), but PyThreadState_SetAsyncExc matches threads by the
        # *Python* identifier returned by threading.get_ident(). With the
        # native id the call found no thread (returned 0) and cancellation
        # silently did nothing.
        self.id = threading.get_ident()
        self._target(*self.args, **self.kwargs)

    def get_id(self):
        """Return the Python thread ident, or None if run() hasn't started."""
        return self.id

    def raise_exception(self):
        """Asynchronously raise SystemExit inside this thread."""
        thread_id = self.get_id()
        if thread_id is None:
            # run() has not executed yet; there is nothing to interrupt.
            return
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread_id), ctypes.py_object(SystemExit)
        )
        if res == 0:
            # No thread with that ident was found (already finished?).
            print("Failure in raising exception")
        elif res > 1:
            # More than one thread state was modified: undo the injection.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), 0)
            print("Failure in raising exception")
class ThreadManager:
    """Context manager owning a worker thread's start/stop lifecycle.

    On entry the wrapped thread is started; on exit it is joined. If the
    thread is still running at exit time, ``raise_exception()`` is called
    first to inject an exception and make it terminate.
    """

    def __init__(self, thread: CancelableThread, **kwargs):
        self.thread = thread

    def __enter__(self):
        # Start the worker and hand it to the caller.
        self.thread.start()
        return self.thread

    def __exit__(self, exc_type, exc_value, traceback):
        # Interrupt a still-running worker, then wait for it to finish.
        if self.thread.is_alive():
            print("trying to terminate thread")
            self.thread.raise_exception()
        self.thread.join()
        print("Thread has been successfully joined.")
def respond(prompt, max_tokens):
    """Stream a completion for *prompt*, yielding the growing text.

    Yields ``(text, submit_update, clear_update)`` tuples: while tokens are
    streaming both buttons are kept disabled; the final yield re-enables them.
    Generation runs in a CancelableThread so an abandoned request can be
    terminated when the ThreadManager context exits.
    """
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        text_inputs=prompt,
        max_new_tokens=max_tokens,
        return_full_text=False,
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=stopping_criteria,
    )
    worker = CancelableThread(target=pipeline, kwargs=generation_kwargs)
    accumulated = ""
    with ThreadManager(thread=worker):
        for chunk in streamer:
            if not chunk:
                continue
            accumulated += chunk
            yield accumulated, gr.update(interactive=False), gr.update(interactive=False)
    yield (
        accumulated,
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
def reset_textbox():
    """Clear both the input textbox and the output textbox."""
    return (gr.update(value=""), gr.update(value=""))
def no_interactive():
    """Disable the submit and clear buttons while generation is running."""
    return (gr.update(interactive=False), gr.update(interactive=False))
# Build the Gradio UI. (Nesting reconstructed from flattened source —
# layout within the column is a best-effort reconstruction.)
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">plamo-2-1b CPU demo</h1>""")
    gr.Markdown(
        # "HF中国镜像站" was mirror-substitution residue for "Hugging Face";
        # restored to the original wording.
        "2 vCPU, 16 GB RAMでのデモです。10年前くらいのノートパソコンくらい。(GPUなしのHugging Faceの無料インスタンスで動いています。)vllmとかllama.cppが対応すればもっと高速に動くはず。"
    )
    with gr.Column(elem_id="col_container") as main_block:
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    lines=15, label="input_text", placeholder="これからの人工知能技術は"
                )
                with gr.Row():
                    with gr.Column(scale=3):
                        clear_button = gr.Button("Clear")
                    with gr.Column(scale=5):
                        submit_button = gr.Button("Submit")
                outputs = gr.Textbox(lines=20, label="Output")
        # inputs, top_p, temperature, top_k, repetition_penalty
        with gr.Accordion("Parameters", open=False):
            max_tokens = gr.Slider(
                minimum=1, maximum=4096, value=32, step=1, label="Max new tokens"
            )
    # Disable both buttons immediately on click; respond() re-enables them
    # on its final yield once streaming finishes.
    submit_button.click(no_interactive, [], [submit_button, clear_button])
    submit_button.click(
        respond,
        [input_text, max_tokens],
        [outputs, submit_button, clear_button],
    )
    clear_button.click(reset_textbox, [], [input_text, outputs], queue=False)

# BUG FIX: the original called demo.queue(...).launch(...) unconditionally at
# module level AND demo.launch() again under the __main__ guard, attempting a
# second launch of the same app. Launch exactly once, under the guard
# (HF Spaces runs app.py as the main module, so the server still starts).
if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)