# plamo-2-1b CPU demo (Hugging Face Spaces app)
import asyncio
import ctypes
import threading

import gradio as gr
import transformers
from transformers import (
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Tokenizer and text-generation pipeline for pfnet/plamo-2-1b.
# trust_remote_code=True is required because the model repository ships
# custom modeling code that must be executed to load the model.
tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
pipeline = transformers.pipeline(
    "text-generation",
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,
)
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any of the given stop strings appears.

    The last two generated tokens are decoded together so a stop string
    that spans a token boundary is still detected.
    """

    def __init__(self, stops=None, encounters=1):
        # `encounters` is accepted for interface compatibility but unused.
        super().__init__()
        # BUG FIX: the original used a mutable default argument
        # (stops=[]) shared across instances; copy into a fresh list.
        self.stops = list(stops) if stops else []

    def __call__(self, input_ids, scores):
        # Decode only the tail of the sequence to keep the per-step
        # check cheap; `tokenizer` is the module-level tokenizer.
        tail_text = tokenizer.decode(input_ids[0][-2:])
        return any(stop in tail_text for stop in self.stops)
# Stop generation at a blank line (double newline).
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=["\n\n"])])
class CancelableThread(threading.Thread):
    """A thread that can be cancelled by injecting SystemExit into it.

    Cancellation uses CPython's ``PyThreadState_SetAsyncExc``, which
    raises an exception asynchronously inside the target thread the next
    time it executes Python bytecode.
    """

    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None):
        threading.Thread.__init__(self, group=group, target=target, name=name)
        self.args = args
        # BUG FIX: the original used a mutable default argument
        # (kwargs={}) shared across instances.
        self.kwargs = {} if kwargs is None else kwargs

    def run(self):
        # BUG FIX: PyThreadState_SetAsyncExc takes the *Python* thread
        # identifier (threading.get_ident(), same as self.ident), not the
        # OS-native id from threading.get_native_id(). With the native id
        # the call matched no thread state and cancellation never worked.
        self.id = threading.get_ident()
        self._target(*self.args, **self.kwargs)

    def get_id(self):
        """Return the Python thread id recorded when the thread started."""
        return self.id

    def raise_exception(self):
        """Asynchronously raise SystemExit inside this thread."""
        thread_id = self.get_id()
        modified = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread_id), ctypes.py_object(SystemExit)
        )
        if modified > 1:
            # More than one thread state was affected: undo and report.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), 0)
            print("Failure in raising exception")
class ThreadManager:
    """Context manager that starts a thread on entry and guarantees it
    is terminated (if still running) and joined on exit."""

    def __init__(self, thread: CancelableThread, **kwargs):
        self.thread = thread

    def __enter__(self):
        # Kick off the managed thread and hand it back to the caller.
        self.thread.start()
        return self.thread

    def __exit__(self, exc_type, exc_value, traceback):
        # Ask a still-running thread to stop, then wait for it to finish.
        worker = self.thread
        if worker.is_alive():
            print("trying to terminate thread")
            worker.raise_exception()
        worker.join()
        print("Thread has been successfully joined.")
def respond(prompt, max_tokens):
    """Stream a model completion for `prompt`, yielding the growing text.

    While tokens are streaming, each yield also disables both buttons;
    a final yield re-enables them once generation has finished.
    """
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        text_inputs=prompt,
        max_new_tokens=max_tokens,
        return_full_text=False,
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=stopping_criteria,
    )
    worker = CancelableThread(target=pipeline, kwargs=generation_kwargs)
    response = ""
    with ThreadManager(thread=worker):
        for chunk in streamer:
            if chunk:
                response += chunk
                yield (
                    response,
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                )
    yield (
        response,
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
def reset_textbox():
    """Clear both the input and the output textboxes."""
    cleared = [gr.update(value="") for _ in range(2)]
    return tuple(cleared)
def no_interactive():
    """Disable the submit and clear buttons while generation runs."""
    return tuple(gr.update(interactive=False) for _ in ("submit", "clear"))
# Build the Gradio UI: input textbox with Clear/Submit buttons, a
# streaming output textbox, and a parameters accordion.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">plamo-2-1b CPU demo</h1>""")
    gr.Markdown(
        "2 vCPU, 16 GB RAMでのデモです。10年前くらいのノートパソコンくらい。(GPUなしのHF中国镜像站の無料インスタンスで動いています。)vllmとかllama.cppが対応すればもっと高速に動くはず。"
    )
    with gr.Column(elem_id="col_container") as main_block:
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    lines=15, label="input_text", placeholder="これからの人工知能技術は"
                )
                with gr.Row():
                    with gr.Column(scale=3):
                        clear_button = gr.Button("Clear")
                    with gr.Column(scale=5):
                        submit_button = gr.Button("Submit")
            outputs = gr.Textbox(lines=20, label="Output")
        # inputs, top_p, temperature, top_k, repetition_penalty
        with gr.Accordion("Parameters", open=False):
            max_tokens = gr.Slider(
                minimum=1, maximum=4096, value=32, step=1, label="Max new tokens"
            )

    # Disable both buttons immediately on submit, then stream the response
    # (respond re-enables them when generation finishes).
    submit_button.click(no_interactive, [], [submit_button, clear_button])
    submit_button.click(
        respond,
        [input_text, max_tokens],
        [outputs, submit_button, clear_button],
    )
    clear_button.click(reset_textbox, [], [input_text, outputs], queue=False)

# BUG FIX: the original called demo.queue(max_size=20).launch(...) at import
# time and then demo.launch() a second time under the __main__ guard, so the
# app was launched twice when run as a script. Configure the queue once and
# launch only from the guard.
demo.queue(max_size=20)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)