Commit 203509f
Peter committed
Parent(s): a04dbc6

✨ add constrained gen script

Signed-off-by: Peter <[email protected]>
- constrained_generation.py (+255, -0)

constrained_generation.py (ADDED)
@@ -0,0 +1,255 @@
"""
constrained_generation.py - use constrained beam search to generate text from a model, forcing entered constraints to appear in the output
"""

import copy
import logging
import time

import yake
from transformers import AutoTokenizer, PhrasalConstraint

def get_tokenizer(model_name="gpt2", verbose=False):
    """
    get_tokenizer - returns a tokenizer object

    :param model_name: name of the model to use, default gpt2
    :param verbose: verbosity
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, add_special_tokens=False, padding=True, truncation=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    if verbose:
        print(f"loaded tokenizer {model_name}")
    return tokenizer

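# Illustrative usage (added in this edit, not part of the original commit);
# assumes the default "gpt2" checkpoint can be downloaded:
# >>> tokenizer = get_tokenizer("gpt2", verbose=True)
# >>> tokenizer("hello world").input_ids
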
def unique_words(list_of_strings):
    """
    unique_words - filter a list of strings, keeping only those whose words
    have not been seen earlier in the list. Tracks a running list of seen
    words to drop phrases that overlap.
    """
    seen_words = []
    output_list = []
    for string in list_of_strings:
        # split string into words
        words = string.split()
        # keep the string only if every one of its words is new
        unique_status = True
        for word in words:
            if word not in seen_words:
                seen_words.append(word)
            else:
                unique_status = False
                break
        if unique_status:
            output_list.append(string)

    return output_list

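# Illustrative example (added in this edit, not part of the original commit):
# >>> unique_words(["beam search", "constrained beam", "decoding"])
# ['beam search', 'decoding']
# "constrained beam" is dropped because "beam" was already seen.
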
def create_kw_extractor(
    language="en",
    max_ngram_size=3,
    deduplication_algo="seqm",
    windowSize=10,
    numOfKeywords=10,
    ddpt=0.7,
):
    """
    creates a YAKE keyword extractor object

    :param language: language of the text
    :param max_ngram_size: maximum ngram size for keywords
    :param deduplication_algo: deduplication algorithm
    :param windowSize: window size
    :param numOfKeywords: number of keywords to return
    :param ddpt: deduplication percentage threshold

    :return: keyword extractor object
    """
    assert 0 <= ddpt <= 1, f"need 0 <= ddpt <= 1, got {ddpt}"
    return yake.KeywordExtractor(
        lan=language,
        n=max_ngram_size,
        dedupLim=ddpt,
        dedupFunc=deduplication_algo,
        windowsSize=windowSize,
        top=numOfKeywords,
        features=None,
    )

def simple_kw(body_text: str, yake_ex=None, max_kw=10, verbose=False):
    """
    simple_kw - extract keywords from a text using yake

    Args:
        body_text (str): text to extract keywords from
        yake_ex (yake.KeywordExtractor, optional): yake keyword extractor. Defaults to None.
        max_kw (int, optional): maximum number of keywords to extract. Defaults to 10.
        verbose (bool, optional): Defaults to False.

    Returns:
        list: list of keywords
    """
    yake_ex = yake_ex or create_kw_extractor(
        max_ngram_size=2,
        ddpt=0.8,
        windowSize=10,
        deduplication_algo="seqm",
        numOfKeywords=max_kw,
    )  # defaults per optuna study

    keywords = yake_ex.extract_keywords(body_text)
    keywords_list = [str(kw[0]).lower() for kw in keywords]
    logging.info(
        f"YAKE: found {len(keywords_list)} keywords, the top {max_kw} are: {keywords_list[:max_kw]}"
    )

    if verbose:
        print(f"found {len(keywords_list)} keywords, the top {max_kw} are:")
        print(keywords_list[:max_kw])

    return keywords_list[:max_kw]

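# Illustrative usage (added in this edit, not part of the original commit).
# Exact output depends on YAKE's scoring; the result shown is only a sketch:
# >>> simple_kw("The rover landed on Mars and sent back images of the surface.")
# ['rover landed', 'sent back', ...]
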
def constrained_generation(
    prompt: str,
    pipeline,
    tokenizer=None,
    no_repeat_ngram_size=2,
    length_penalty=0.7,
    repetition_penalty=3.5,
    num_beams=4,
    max_generated_tokens=48,
    min_generated_tokens=2,
    timeout=300,
    num_return_sequences=1,
    verbose=False,
    full_text=False,
    force_word: str = None,
    speaker_name: str = "Person Alpha",
    responder_name: str = "Person Beta",
    **kwargs,
):
    """
    constrained_generation - generate text based on a prompt and constraints

    USAGE
    -----
    response = constrained_generation("hey man - how have you been lately?",
                                      my_chatbot, verbose=True,
                                      force_word=" meme", num_beams=32)

    Parameters
    ----------
    prompt : str, prompt to use for generation
    pipeline : transformers.pipeline, pipeline to use, must be compatible with a text-generation model
    tokenizer : transformers.PreTrainedTokenizer, optional, defaults to the pipeline's tokenizer
    no_repeat_ngram_size : int, optional, default=2
    num_beams : int, optional, default=4
    max_generated_tokens : int, optional, default=48
    min_generated_tokens : int, optional, default=2
    verbose : bool, optional, default=False, print output
    force_word : str, optional, default=None, word forced to appear in the generation
    speaker_name : str, optional, default="Person Alpha", name of speaker
    responder_name : str, optional, default="Person Beta", name of responder

    Returns
    -------
    response : str, generated text
    """
    st = time.perf_counter()
    tokenizer = tokenizer or copy.deepcopy(pipeline.tokenizer)
    tokenizer.add_prefix_space = True
    tokenizer.add_special_tokens = False

    prompt_length = len(tokenizer(prompt, truncation=True).input_ids)
    if responder_name.lower() not in prompt.lower():
        prompt = f"{prompt}\n\n{responder_name}:\n"
    # key_prompt_phrases = get_keyberts(prompt)  # alternative keyword extractor
    key_prompt_phrases = simple_kw(prompt)

    try:
        responder_name_words = responder_name.lower().split()
        speaker_name_words = speaker_name.lower().split()
    except Exception as e:
        responder_name_words = []
        speaker_name_words = []
        logging.info(f"could not split names: {e}")

    # drop keyword phrases that contain either participant's name
    key_prompt_phrases = [
        p
        for p in key_prompt_phrases
        if not any(name in p for name in responder_name_words)
        and not any(name in p for name in speaker_name_words)
    ]
    force_flexible = unique_words(key_prompt_phrases)
    if len(force_flexible) == 0:
        force_flexible = None

    logging.info(f"found the following keywords: {force_flexible}")
    if verbose:
        print(f"found keywords: {force_flexible}")
    if force_word is not None:
        logging.info(f"forcing the word: {force_word}")

    constraints = (
        [
            PhrasalConstraint(
                tokenizer(force_word, add_special_tokens=False).input_ids,
            ),
        ]
        if force_word is not None
        else None
    )
    force_words_ids = (
        [
            tokenizer(
                force_flexible,
            ).input_ids,
        ]
        if force_flexible is not None
        else None
    )

    try:
        logging.info("generating text..")
        result = pipeline(
            prompt,
            constraints=constraints,
            force_words_ids=force_words_ids,
            max_length=None,
            max_new_tokens=max_generated_tokens,
            min_length=min_generated_tokens + prompt_length if full_text else min_generated_tokens,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            max_time=timeout,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            return_full_text=full_text,
            remove_invalid_values=True,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
            early_stopping=True,
            do_sample=False,
            **kwargs,
        )
        response = result[0]["generated_text"]
        rt = round((time.perf_counter() - st) / 60, 3)
        logging.info(f"generated response in {rt} minutes")
        if verbose:
            print(f"input prompt:\n\t{prompt}")
            print(f"response:\n\t{response}")
    except Exception as e:
        logging.info(f"could not generate response: {e}")
        response = "Sorry, I don't know how to respond to that."
    return response
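
# Minimal end-to-end sketch (added in this edit, not part of the original
# commit). The checkpoint name and prompt are placeholder assumptions; any
# transformers text-generation pipeline whose model supports constrained
# beam search should slot in here.
if __name__ == "__main__":
    from transformers import pipeline as hf_pipeline

    logging.basicConfig(level=logging.INFO)
    chatbot = hf_pipeline("text-generation", model="gpt2")  # placeholder model
    reply = constrained_generation(
        "hey man - how have you been lately?",
        chatbot,
        verbose=True,
        num_beams=8,
        max_generated_tokens=32,
    )
    print(reply)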