import ctypes
import threading

import gradio as gr
import transformers
from transformers import (
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

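# Load the tokenizer and the text-generation pipeline for pfnet/plamo-2-1b.
# trust_remote_code=True is required because PLaMo's modeling code is shipped
# on the Hub rather than built into the transformers library.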
tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)
pipeline = transformers.pipeline(
    "text-generation",
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,
)


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation as soon as any of the given stop strings appears."""

    def __init__(self, stops=None):
        super().__init__()
        self.stops = stops or []

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the last two tokens so that a multi-token stop string such
        # as "\n\n" is caught even when it spans a token boundary.
        decoded = tokenizer.decode(input_ids[0][-2:])
        return any(stop in decoded for stop in self.stops)


stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=["\n\n"])])
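
# transformers offers no built-in way to cancel a running generate() call, so
# the pipeline is executed in a thread that can be aborted from the outside by
# injecting SystemExit via CPython's PyThreadState_SetAsyncExc API. This is a
# CPython-specific trick: the exception is only delivered once the thread
# executes Python bytecode again, so a call blocked in native code keeps
# running until it returns to the interpreter.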
class CancelableThread(threading.Thread):
    """A worker thread that can be terminated via an async exception."""

    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None):
        threading.Thread.__init__(self, group=group, target=target, name=name)
        self.args = args
        self.kwargs = kwargs or {}

    def run(self):
        # PyThreadState_SetAsyncExc expects the Python thread id as returned
        # by threading.get_ident(), not the OS-native thread id.
        self.id = threading.get_ident()
        self._target(*self.args, **self.kwargs)

    def get_id(self):
        return self.id

    def raise_exception(self):
        thread_id = self.get_id()
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread_id), ctypes.py_object(SystemExit)
        )
        if res > 1:
            # More than one thread state was modified: revert and report.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), 0)
            print("Failure in raising exception")


class ThreadManager:
    """Context manager that starts the thread on entry and makes sure it is
    terminated and joined on exit, even if the caller stops iterating early."""

    def __init__(self, thread: CancelableThread, **kwargs):
        self.thread = thread

    def __enter__(self):
        # Start the thread.
        self.thread.start()
        return self.thread

    def __exit__(self, exc_type, exc_value, traceback):
        # Wait for the thread to finish; abort it first if it is still running.
        if self.thread.is_alive():
            print("trying to terminate thread")
            self.thread.raise_exception()
        self.thread.join()
        print("Thread has been successfully joined.")
def respond(prompt, max_tokens):
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    thread = CancelableThread(
        target=pipeline,
        kwargs=dict(
            text_inputs=prompt,
            max_new_tokens=max_tokens,
            return_full_text=False,
            streamer=streamer,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria,
        ),
    )
    response = ""
    with ThreadManager(thread=thread):
        for output in streamer:
            if not output:
                continue
            response += output
            # Keep both buttons disabled while tokens are still streaming.
            yield response, gr.update(interactive=False), gr.update(interactive=False)
    # Generation finished: re-enable the buttons.
    yield (
        response,
        gr.update(interactive=True),
        gr.update(interactive=True),
    )


def reset_textbox():
    return gr.update(value=""), gr.update(value="")


def no_interactive():
    return gr.update(interactive=False), gr.update(interactive=False)


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">plamo-2-1b CPU demo</h1>""")
    gr.Markdown(
        "This demo runs on 2 vCPUs and 16 GB of RAM, roughly the specs of a "
        "laptop from about ten years ago. (It runs on a free Hugging Face "
        "instance without a GPU.) It should run much faster once vllm or "
        "llama.cpp support this model."
    )
    with gr.Column(elem_id="col_container") as main_block:
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    lines=15, label="input_text", placeholder="これからの人工知能技術は"
                )
                with gr.Row():
                    with gr.Column(scale=3):
                        clear_button = gr.Button("Clear")
                    with gr.Column(scale=5):
                        submit_button = gr.Button("Submit")
            outputs = gr.Textbox(lines=20, label="Output")
        with gr.Accordion("Parameters", open=False):
            max_tokens = gr.Slider(
                minimum=1, maximum=4096, value=32, step=1, label="Max new tokens"
            )
    # Disable both buttons immediately, then stream the response into the
    # output box; respond() re-enables the buttons when generation ends.
    submit_button.click(no_interactive, [], [submit_button, clear_button])
    submit_button.click(
        respond,
        [input_text, max_tokens],
        [outputs, submit_button, clear_button],
    )
    clear_button.click(reset_textbox, [], [input_text, outputs], queue=False)

demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)