import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread

# model_path = 'sail/Sailor-7B-Chat'
model_path = 'sail/Sailor2-20B-Chat'


# Load the tokenizer and model from the Hugging Face model hub.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)

# Move the model to CUDA when available for an optimal experience.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
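# Note: a 20B-parameter model in bf16 needs roughly 40 GB of memory. If a single
# GPU cannot hold it, `device_map="auto"` in `from_pretrained` (with `accelerate`
# installed) can shard the model across available devices instead of `.to(device)`.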

# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [151645]  # IDs of tokens where the generation should stop.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
                return True
        return False
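
# 151645 is the ID of "<|im_end|>" in the Qwen2 tokenizer family that Sailor2
# builds on; it could also be looked up at runtime instead of hard-coded, e.g.:
#   stop_ids = [tokenizer.convert_tokens_to_ids("<|im_end|>")]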


# system_role= 'system'
# user_role = 'question'
# assistant_role = "answer"

system_role = 'system'
user_role = 'user'
assistant_role = 'assistant'

sft_start_token = "<|im_start|>"
sft_end_token = "<|im_end|>"
ct_end_token = "<|endoftext|>"

# system_prompt= \
# 'You are an AI assistant named Sailor created by Sea AI Lab. \
# Your answer should be friendly, unbiased, faithful, informative and detailed.'
# system_prompt = f"<|im_start|>{system_role}\n{system_prompt}<|im_end|>"

system_prompt = \
'You are an AI assistant named Sailor2, created by Sea AI Lab. \
As an AI assistant, you can answer questions in English, Chinese, and Southeast Asian languages \
such as Burmese, Cebuano, Ilocano, Indonesian, Javanese, Khmer, Lao, Malay, Sundanese, Tagalog, Thai, Vietnamese, and Waray. \
Your responses should be friendly, unbiased, informative, detailed, and faithful.'

# Wrap the system prompt in ChatML tags so it matches the chat format the model
# was fine-tuned on (the commented-out Sailor-7B setup above did the same).
system_prompt = f"{sft_start_token}{system_role}\n{system_prompt}{sft_end_token}"

# Function to generate model predictions.

@spaces.GPU()
def predict(message, history):
    # Append the new user message with an empty assistant slot to be filled in.
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()

    # Format the conversation as ChatML: the system prompt followed by
    # alternating user/assistant turns, ending with an open assistant turn.
    messages = system_prompt + sft_end_token.join(
        [sft_end_token.join([f"\n{sft_start_token}{user_role}\n" + item[0],
                             f"\n{sft_start_token}{assistant_role}\n" + item[1]])
         for item in history_transformer_format])
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.75,
        top_k=60,
        temperature=0.2,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop]),
        repetition_penalty=1.1,
    )
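    # Low temperature plus top-p/top-k sampling keeps answers focused, while
    # repetition_penalty=1.1 discourages loops in long generations.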
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if sft_end_token in partial_message:  # Breaking the loop if the stop token is generated.
            break
        yield partial_message
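    # Note: with skip_special_tokens=True the streamer normally never emits
    # "<|im_end|>" as text, so the break above is a safety net; termination is
    # actually handled by StopOnTokens and max_new_tokens.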


css = """
full-height {
    height: 100%;
}
"""

prompt_examples = [
    'How to cook a fish?',
    'Cara memanggang ikan',
    'วิธีย่างปลา',
    'Cách nướng cá'
]
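# The last three examples ask "How to grill fish?" in Indonesian, Thai, and Vietnamese.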

# placeholder = """
# <div style="opacity: 0.5;">
#     <img src="https://raw.githubusercontent.com/sail-sg/sailor-llm/main/misc/banner.jpg" style="width:30%;">
#     <br>Sailor models are designed to understand and generate text across diverse linguistic landscapes of these SEA regions:
#     <br>🇮🇩Indonesian, 🇹🇭Thai, 🇻🇳Vietnamese, 🇲🇾Malay, and 🇱🇦Lao.
# </div>
# """
placeholder = ''

chatbot = gr.Chatbot(label='Sailor2', placeholder=placeholder)
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
    # gr.Markdown("""<center><font size=8>Sailor-Chat Bot⚓</center>""")
    # gr.Markdown("""<p align="center"><img src="https://github.com/sail-sg/sailor-llm/raw/main/misc/wide_sailor_banner.jpg" style="height: 110px"/><p>""")
    gr.Markdown("""<p align="center"><img src="https://github.com/sail-sg/sailor2/raw/main/misc/sailor2_wide_banner.jpg" style="height: 110px"/><p>""")
    gr.ChatInterface(predict, chatbot=chatbot, fill_height=True, examples=prompt_examples, css=css)

# Launch the web interface after the Blocks context closes, once the app is fully built.
demo.launch()
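
# To run locally: `python app.py` (requires the `spaces` package used on
# Hugging Face ZeroGPU, or remove the @spaces.GPU decorator).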