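"""Gradio streaming demo for pfnet/plamo-2-1b on CPU.

Generation runs in a background thread that feeds a TextIteratorStreamer;
the thread is force-terminated via ctypes if the stream is abandoned.
"""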
import ctypes
import threading

import gradio as gr
import transformers
from transformers import (
    TextIteratorStreamer,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

tokenizer = AutoTokenizer.from_pretrained("pfnet/plamo-2-1b", trust_remote_code=True)

pipeline = transformers.pipeline(
    "text-generation",
    model="pfnet/plamo-2-1b",
    trust_remote_code=True,
)


class StoppingCriteriaSub(StoppingCriteria):
    """Stops generation as soon as any of the given stop strings appears."""

    def __init__(self, stops=None):
        super().__init__()
        # Avoid a mutable default argument; copy the stop list per instance.
        self.stops = list(stops) if stops is not None else []

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the last two tokens so a stop string that spans a token
        # boundary (e.g. "\n\n") is still detected.
        last_tokens = input_ids[0][-2:]
        decoded = tokenizer.decode(last_tokens)
        return any(stop in decoded for stop in self.stops)


# Stop at the first blank line so completions end promptly on CPU.
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=["\n\n"])])


class CancelableThread(threading.Thread):
    """Thread that can be terminated by injecting an exception via ctypes."""

    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None):
        threading.Thread.__init__(self, group=group, target=target, name=name)
        self.args = args
        self.kwargs = kwargs if kwargs is not None else {}

    def run(self):
        # PyThreadState_SetAsyncExc needs the Python thread ident,
        # not the OS-native thread id.
        self.id = threading.get_ident()
        self._target(*self.args, **self.kwargs)

    def get_id(self):
        return self.id

    def raise_exception(self):
        thread_id = self.get_id()
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread_id), ctypes.py_object(SystemExit)
        )
        if res > 1:
            # More than one thread was affected: undo and report.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), 0)
            print("Failure in raising exception")


class ThreadManager:
    def __init__(self, thread: CancelableThread):
        self.thread = thread

    def __enter__(self):
        # Start the thread.
        self.thread.start()
        return self.thread

    def __exit__(self, exc_type, exc_value, traceback):
        # If the stream was abandoned mid-generation, force the thread to stop.
        if self.thread.is_alive():
            print("trying to terminate thread")
            self.thread.raise_exception()
        self.thread.join()
        print("Thread has been successfully joined.")


def respond(prompt, max_tokens):
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Run generation in a cancelable background thread; this generator
    # drains the streamer and yields partial output to the UI.
    thread = CancelableThread(
        target=pipeline,
        kwargs=dict(
            text_inputs=prompt,
            max_new_tokens=max_tokens,
            return_full_text=False,
            streamer=streamer,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria,
        ),
    )
    response = ""

    with ThreadManager(thread=thread):
        # Keep the buttons disabled while tokens are streaming in.
        for output in streamer:
            if not output:
                continue
            response += output
            yield response, gr.update(interactive=False), gr.update(interactive=False)
        # Generation finished: re-enable the buttons.
        yield (
            response,
            gr.update(interactive=True),
            gr.update(interactive=True),
        )


def reset_textbox():
    return gr.update(value=""), gr.update(value="")


def no_interactive():
    return gr.update(interactive=False), gr.update(interactive=False)


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">plamo-2-1b CPU demo</h1>""")
    gr.Markdown(
        "A demo on 2 vCPUs and 16 GB of RAM, roughly a ten-year-old laptop. "
        "(It runs on a free Hugging Face instance with no GPU.) "
        "It should run much faster once vllm or llama.cpp supports the model."
    )
    with gr.Column(elem_id="col_container") as main_block:
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    lines=15,
                    label="input_text",
                    # Japanese example prompt: "The future of AI technology is..."
                    placeholder="これからの人工知能技術は",
                )
                with gr.Row():
                    with gr.Column(scale=3):
                        clear_button = gr.Button("Clear")
                    with gr.Column(scale=5):
                        submit_button = gr.Button("Submit")
            outputs = gr.Textbox(lines=20, label="Output")

        with gr.Accordion("Parameters", open=False):
            max_tokens = gr.Slider(
                minimum=1, maximum=4096, value=32, step=1, label="Max new tokens"
            )

    submit_button.click(no_interactive, [], [submit_button, clear_button])
    submit_button.click(
        respond,
        [input_text, max_tokens],
        [outputs, submit_button, clear_button],
    )
    clear_button.click(reset_textbox, [], [input_text, outputs], queue=False)

    demo.queue(max_size=20)


if __name__ == "__main__":
    # Bind to all interfaces so the demo is reachable inside a container.
    demo.launch(server_name="0.0.0.0", server_port=7860)
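# To run locally (a sketch; exact dependencies may vary, and plamo-2-1b's
# remote code may pull in extra packages):
#   pip install gradio transformers torch
#   python app.py
# The UI is then served at http://localhost:7860.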