languagemodels.inference
from typing import List, Optional

import requests
import re
import os
import sys
from time import perf_counter

from languagemodels.models import get_model, get_model_info
from languagemodels.config import config


class InferenceException(Exception):
    """Raised when a local or remote inference backend reports an error."""

    pass


def truncate_prompt(prompt):
    """Truncates a prompt to the maximum length allowed by the config

    Returns the prompt unchanged when it fits; otherwise returns a prefix of
    ``config["max_prompt_length"]`` characters and prints a warning, since an
    over-long prompt can exhaust memory during tokenization/inference.
    """
    max_prompt_length = config["max_prompt_length"]
    if len(prompt) > max_prompt_length:
        print(
            f"Warning: Prompt truncated from {len(prompt)} to "
            f"{max_prompt_length} characters to avoid OOM."
        )
        return prompt[:max_prompt_length]
    return prompt


def list_tokens(prompt):
    """Generates a list of tokens for a supplied prompt

    Returns a list of ``(token_text, token_id)`` pairs produced by the
    tokenizer of the default "instruct" model.

    >>> list_tokens("Hello, world!") # doctest: +SKIP
    [('▁Hello', 8774), (',', 6), ('▁world', 296), ('!', 55)]

    >>> list_tokens("Hello, world!")
    [('...Hello', ...), ... ('...world', ...), ...]
    """
    prompt = truncate_prompt(prompt)
    tokenizer, _ = get_model("instruct")

    output = tokenizer.encode(prompt, add_special_tokens=False)
    tokens = output.tokens
    ids = output.ids

    return list(zip(tokens, ids))


def generate_ts(engine, prompt, max_tokens=200):
    """Generates a single text response for a prompt from a textsynth server

    The server and API key are provided as environment variables:

    LANGUAGEMODELS_TS_SERVER is the server such as http://localhost:8080
    LANGUAGEMODELS_TS_KEY is the API key

    Raises InferenceException when the server response has no "text" field.
    """
    apikey = os.environ.get("LANGUAGEMODELS_TS_KEY") or ""
    server = os.environ.get("LANGUAGEMODELS_TS_SERVER") or "https://api.textsynth.com"

    response = requests.post(
        f"{server}/v1/engines/{engine}/completions",
        headers={"Authorization": f"Bearer {apikey}"},
        json={"prompt": prompt, "max_tokens": max_tokens},
    )
    resp = response.json()
    if "text" in resp:
        return resp["text"]
    else:
        raise InferenceException(f"TextSynth error: {resp}")


def generate_oa(engine, prompt, max_tokens=200, temperature=0):
    """Generates a single text response for a prompt using OpenAI

    The server and API key are provided as environment variables:

    LANGUAGEMODELS_OA_KEY is the API key

    Raises InferenceException when the response has no completion choices.
    """
    apikey = os.environ.get("LANGUAGEMODELS_OA_KEY")

    response = requests.post(
        "https://api.openai.com/v1/completions",
        headers={
            "Authorization": f"Bearer {apikey}",
            "Content-Type": "application/json",
        },
        json={
            "model": engine,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
        },
    )
    resp = response.json()

    try:
        return resp["choices"][0]["text"]
    except KeyError as exc:
        # Chain the original KeyError so the traceback shows why the
        # response shape was unexpected.
        raise InferenceException(f"OpenAI error: {resp}") from exc


def chat_oa(engine, prompt, max_tokens=200, temperature=0):
    """Generates a single text response for a prompt using OpenAI

    The server and API key are provided as environment variables:

    LANGUAGEMODELS_OA_KEY is the API key

    Raises InferenceException when the response has no completion choices.
    """
    apikey = os.environ.get("LANGUAGEMODELS_OA_KEY")

    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {apikey}",
            "Content-Type": "application/json",
        },
        json={
            "model": engine,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "temperature": temperature,
        },
    )
    resp = response.json()

    try:
        return resp["choices"][0]["message"]["content"]
    except KeyError as exc:
        # Chain the original KeyError so the traceback shows why the
        # response shape was unexpected.
        raise InferenceException(f"OpenAI error: {resp}") from exc


def stream_results(results, tokenizer):
    """Map a token iterator to a substring iterator

    Decodes the accumulated token ids after each new token and yields only
    the newly-added suffix, so the caller receives incremental text.
    """
    tokens = []
    last_len = 0

    for result in results:
        tokens.append(result.token_id)
        text = tokenizer.decode(tokens)
        yield text[last_len:]
        last_len = len(text)


def echo_results(results, tokenizer):
    """Output results to stderr as they are collected

    Writes each newly-decoded substring to stderr (flushing as it goes) and
    returns the full list of generated token ids.
    """
    tokens = []
    last_len = 0

    for result in results:
        tokens.append(result.token_id)
        text = tokenizer.decode(tokens)
        sys.stderr.write(text[last_len:])
        sys.stderr.flush()
        last_len = len(text)

    sys.stderr.write("\n\n")
    sys.stderr.flush()
    return tokens


def generate(
    instructions: List[str],
    max_tokens: int = 200,
    temperature: float = 0.1,
    topk: int = 1,
    repetition_penalty: float = 0.0,
    prefix: str = "",
    suppress: Optional[List[str]] = None,
    model: str = "instruct",
    stream: bool = False,
):
    """Generates completions for a prompt

    This may use a local model, or it may make an API call to an external
    model if API keys are available.

    A ``repetition_penalty`` of 0.0 means "use the model's configured
    default" (falling back to 1.3). ``suppress`` is a list of strings whose
    token sequences are barred from the output.

    NOTE(review): when a TextSynth or OpenAI key is set, this returns a
    single stripped ``str`` rather than a ``List[str]`` — confirm callers
    handle both shapes.

    >>> generate(["What is the capital of France?"])
    ['...Paris...']

    >>> list(generate(["What is the capital of France?"], stream=True))
    ['...Paris...']
    """
    if os.environ.get("LANGUAGEMODELS_TS_KEY") or os.environ.get(
        "LANGUAGEMODELS_TS_SERVER"
    ):
        return generate_ts("flan_t5_xxl_q4", instructions, max_tokens).strip()

    if os.environ.get("LANGUAGEMODELS_OA_KEY"):
        return chat_oa("gpt-3.5-turbo", instructions, max_tokens).strip()

    # Avoid the mutable-default-argument pitfall: [] as a default would be
    # shared across calls.
    if suppress is None:
        suppress = []

    tokenizer, model = get_model(model)

    start_time = perf_counter()

    suppress = [tokenizer.encode(s, add_special_tokens=False).tokens for s in suppress]

    # NOTE(review): model info is always read for "instruct" even when a
    # different model name was requested — confirm this is intended.
    model_info = get_model_info("instruct")

    fmt = model_info.get("prompt_fmt", "{instruction}")

    if repetition_penalty == 0.0:
        repetition_penalty = model_info.get("repetition_penalty", 1.3)

    prompts = [fmt.replace("{instruction}", inst) for inst in instructions]
    truncated_prompts = [truncate_prompt(p) for p in prompts]

    prompts_tok = [tokenizer.encode(p).tokens for p in truncated_prompts]

    outputs_ids = []
    if hasattr(model, "translate_batch"):
        # Encoder-decoder (CTranslate2 Translator) path.
        prefix = tokenizer.encode(prefix, add_special_tokens=False).tokens
        if stream or (config["echo"] and len(prompts_tok) == 1):
            # Token-by-token generation; only a single prompt is supported.
            results = model.generate_tokens(
                prompts_tok[0],
                target_prefix=prefix,
                repetition_penalty=repetition_penalty,
                max_decoding_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
            )

            if stream:
                return stream_results(results, tokenizer)
            else:
                outputs_ids = [echo_results(results, tokenizer)]
        else:
            results = model.translate_batch(
                prompts_tok,
                target_prefix=[prefix] * len(prompts),
                repetition_penalty=repetition_penalty,
                max_decoding_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
                beam_size=1,
            )
            outputs_tokens = [r.hypotheses[0] for r in results]
            for output in outputs_tokens:
                outputs_ids.append([tokenizer.token_to_id(t) for t in output])
    else:
        # Decoder-only (CTranslate2 Generator) path.
        if stream or (config["echo"] and len(prompts_tok) == 1):
            results = model.generate_tokens(
                prompts_tok,
                repetition_penalty=repetition_penalty,
                max_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
            )

            if stream:
                return stream_results(results, tokenizer)
            else:
                outputs_ids = [echo_results(results, tokenizer)]
        else:
            results = model.generate_batch(
                prompts_tok,
                repetition_penalty=repetition_penalty,
                max_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
                beam_size=1,
                include_prompt_in_result=False,
            )
            outputs_ids = [r.sequences_ids[0] for r in results]

    # Accumulate usage statistics on the shared model info dict.
    model_info["requests"] = model_info.get("requests", 0) + len(prompts)

    in_toks = sum(len(p) for p in prompts_tok)
    model_info["input_tokens"] = model_info.get("input_tokens", 0) + in_toks

    out_toks = sum(len(o) for o in outputs_ids)
    model_info["output_tokens"] = model_info.get("output_tokens", 0) + out_toks

    elapsed_time = perf_counter() - start_time
    model_info["runtime"] = model_info.get("runtime", 0) + elapsed_time

    return [tokenizer.decode(i, skip_special_tokens=True).lstrip() for i in outputs_ids]


def rank_instruct(inputs, targets):
    """Sorts a list of targets by their probabilities

    For each input, every target is scored as a continuation of that input
    and the targets are returned ordered from most to least probable.

    >>> rank_instruct(["Classify positive or negative: I love python. Classification:"],
    ... ['positive', 'negative'])
    [['positive', 'negative']]

    >>> rank_instruct(["Classify fantasy or documentary: "
    ... "The wizard raised their wand. Classification:"],
    ... ['fantasy', 'documentary'])
    [['fantasy', 'documentary']]

    >>> rank_instruct(["Say six", "Say seven"], ["six", "seven"])
    [['six', 'seven'], ['seven', 'six']]
    """
    tokenizer, model = get_model("instruct")

    model_info = get_model_info("instruct")
    fmt = model_info.get("prompt_fmt", "{instruction}")
    inputs = [fmt.replace("{instruction}", inst) for inst in inputs]
    inputs = [truncate_prompt(i) for i in inputs]

    # One target token list per (input, target) pair, grouped by input.
    targ_tok = [tokenizer.encode(t, add_special_tokens=False).tokens for t in targets]
    targ_tok *= len(inputs)

    in_tok = []
    for inp in inputs:  # renamed from `input` to avoid shadowing the builtin
        toks = [tokenizer.encode(inp).tokens]
        in_tok += toks * len(targets)

    if "Generator" in str(type(model)):
        # Decoder-only models score the concatenated prompt + target.
        scores = model.score_batch([i + t for i, t in zip(in_tok, targ_tok)])
    else:
        scores = model.score_batch(in_tok, target=targ_tok)

    ret = []
    for i in range(0, len(inputs) * len(targets), len(targets)):
        logprobs = [sum(r.log_probs) for r in scores[i : i + len(targets)]]
        results = sorted(zip(targets, logprobs), key=lambda r: -r[1])
        ret.append([r[0] for r in results])

    return ret


def parse_chat(prompt):
    """Converts a chat prompt using special tokens to a plain-text prompt

    This is useful for prompting generic models that have not been fine-tuned
    for chat using specialized tokens.

    >>> parse_chat('User: What time is it?')
    Traceback (most recent call last):
    ....
    inference.InferenceException: Chat prompt must end with 'Assistant:'

    >>> parse_chat('''User: What time is it?
    ...
    ... Assistant:''')
    [{'role': 'user', 'content': 'What time is it?'}]

    >>> parse_chat('''
    ... A helpful assistant
    ...
    ... User: What time is it?
    ...
    ... Assistant:
    ... ''')
    [{'role': 'system', 'content': 'A helpful assistant'},
    {'role': 'user', 'content': 'What time is it?'}]

    >>> parse_chat('''
    ... A helpful assistant
    ...
    ... User: What time is it?
    ...
    ... Assistant: The time is
    ... ''')
    Traceback (most recent call last):
    ....
    inference.InferenceException: Final assistant message must be blank

    >>> parse_chat('''
    ... A helpful assistant
    ...
    ... User: First para
    ...
    ... Second para
    ...
    ... Assistant:
    ... ''')
    [{'role': 'system', 'content': 'A helpful assistant'},
    {'role': 'user', 'content': 'First para\\n\\nSecond para'}]

    >>> parse_chat('''
    ... A helpful assistant
    ...
    ... User: What time is it?
    ...
    ... InvalidRole: Nothing
    ...
    ... Assistant:
    ... ''')
    Traceback (most recent call last):
    ....
    inference.InferenceException: Invalid chat role: invalidrole
    """

    # Leading text with no "Role:" tag is treated as the system message.
    if not re.match(r"^\s*\w+:", prompt):
        prompt = "System: " + prompt

    # Ensure the first role tag is preceded by a newline so the split below
    # catches it.
    prompt = "\n\n" + prompt

    # Split into alternating [role, content, role, content, ...] chunks.
    chunks = re.split(r"[\r\n]\s*(\w+):", prompt, flags=re.M)
    chunks = [m.strip() for m in chunks if m.strip()]

    messages = []

    for i in range(0, len(chunks), 2):
        role = chunks[i].lower()

        try:
            content = chunks[i + 1]
            # Normalize paragraph separators to exactly one blank line.
            content = re.sub(r"\s*\n\n\s*", "\n\n", content)
        except IndexError:
            # A trailing role tag with no content (e.g. "Assistant:").
            content = ""
        messages.append({"role": role, "content": content})

    for message in messages:
        if message["role"] not in ["system", "user", "assistant"]:
            raise InferenceException(f"Invalid chat role: {message['role']}")

    if messages[-1]["role"] != "assistant":
        raise InferenceException("Chat prompt must end with 'Assistant:'")

    if messages[-1]["content"] != "":
        raise InferenceException("Final assistant message must be blank")

    # The trailing blank assistant message is a marker, not content.
    return messages[:-1]
Common base class for all non-exit exceptions.
17def truncate_prompt(prompt): 18 """Truncates a prompt to the maximum length allowed by the config""" 19 max_prompt_length = config["max_prompt_length"] 20 if len(prompt) > max_prompt_length: 21 print( 22 f"Warning: Prompt truncated from {len(prompt)} to " 23 f"{max_prompt_length} characters to avoid OOM." 24 ) 25 return prompt[:max_prompt_length] 26 return prompt
Truncates a prompt to the maximum length allowed by the config
29def list_tokens(prompt): 30 """Generates a list of tokens for a supplied prompt 31 32 >>> list_tokens("Hello, world!") # doctest: +SKIP 33 [('▁Hello', 8774), (',', 6), ('▁world', 296), ('!', 55)] 34 35 >>> list_tokens("Hello, world!") 36 [('...Hello', ...), ... ('...world', ...), ...] 37 """ 38 prompt = truncate_prompt(prompt) 39 tokenizer, _ = get_model("instruct") 40 41 output = tokenizer.encode(prompt, add_special_tokens=False) 42 tokens = output.tokens 43 ids = output.ids 44 45 return list(zip(tokens, ids))
Generates a list of tokens for a supplied prompt
>>> list_tokens("Hello, world!") # doctest: +SKIP
[('▁Hello', 8774), (',', 6), ('▁world', 296), ('!', 55)]
>>> list_tokens("Hello, world!")
[('...Hello', ...), ... ('...world', ...), ...]
48def generate_ts(engine, prompt, max_tokens=200): 49 """Generates a single text response for a prompt from a textsynth server 50 51 The server and API key are provided as environment variables: 52 53 LANGUAGEMODELS_TS_SERVER is the server such as http://localhost:8080 54 LANGUAGEMODELS_TS_KEY is the API key 55 """ 56 apikey = os.environ.get("LANGUAGEMODELS_TS_KEY") or "" 57 server = os.environ.get("LANGUAGEMODELS_TS_SERVER") or "https://api.textsynth.com" 58 59 response = requests.post( 60 f"{server}/v1/engines/{engine}/completions", 61 headers={"Authorization": f"Bearer {apikey}"}, 62 json={"prompt": prompt, "max_tokens": max_tokens}, 63 ) 64 resp = response.json() 65 if "text" in resp: 66 return resp["text"] 67 else: 68 raise InferenceException(f"TextSynth error: {resp}")
Generates a single text response for a prompt from a textsynth server
The server and API key are provided as environment variables:
LANGUAGEMODELS_TS_SERVER is the server such as http://localhost:8080 LANGUAGEMODELS_TS_KEY is the API key
71def generate_oa(engine, prompt, max_tokens=200, temperature=0): 72 """Generates a single text response for a prompt using OpenAI 73 74 The server and API key are provided as environment variables: 75 76 LANGUAGEMODELS_OA_KEY is the API key 77 """ 78 apikey = os.environ.get("LANGUAGEMODELS_OA_KEY") 79 80 response = requests.post( 81 "https://api.openai.com/v1/completions", 82 headers={ 83 "Authorization": f"Bearer {apikey}", 84 "Content-Type": "application/json", 85 }, 86 json={ 87 "model": engine, 88 "prompt": prompt, 89 "max_tokens": max_tokens, 90 "temperature": temperature, 91 }, 92 ) 93 resp = response.json() 94 95 try: 96 return resp["choices"][0]["text"] 97 except KeyError: 98 raise InferenceException(f"OpenAI error: {resp}")
Generates a single text response for a prompt using OpenAI
The server and API key are provided as environment variables:
LANGUAGEMODELS_OA_KEY is the API key
101def chat_oa(engine, prompt, max_tokens=200, temperature=0): 102 """Generates a single text response for a prompt using OpenAI 103 104 The server and API key are provided as environment variables: 105 106 LANGUAGEMODELS_OA_KEY is the API key 107 """ 108 apikey = os.environ.get("LANGUAGEMODELS_OA_KEY") 109 110 response = requests.post( 111 "https://api.openai.com/v1/chat/completions", 112 headers={ 113 "Authorization": f"Bearer {apikey}", 114 "Content-Type": "application/json", 115 }, 116 json={ 117 "model": engine, 118 "messages": [{"role": "user", "content": prompt}], 119 "max_tokens": max_tokens, 120 "temperature": temperature, 121 }, 122 ) 123 resp = response.json() 124 125 try: 126 return resp["choices"][0]["message"]["content"] 127 except KeyError: 128 raise InferenceException(f"OpenAI error: {resp}")
Generates a single text response for a prompt using OpenAI
The server and API key are provided as environment variables:
LANGUAGEMODELS_OA_KEY is the API key
131def stream_results(results, tokenizer): 132 """Map a token iterator to a substring iterator""" 133 tokens = [] 134 last_len = 0 135 136 for result in results: 137 tokens.append(result.token_id) 138 text = tokenizer.decode(tokens) 139 yield text[last_len:] 140 last_len = len(text)
Map a token iterator to a substring iterator
143def echo_results(results, tokenizer): 144 """Output results to stderr as they are collected""" 145 tokens = [] 146 last_len = 0 147 148 for result in results: 149 tokens.append(result.token_id) 150 text = tokenizer.decode(tokens) 151 sys.stderr.write(text[last_len:]) 152 sys.stderr.flush() 153 last_len = len(text) 154 155 sys.stderr.write("\n\n") 156 sys.stderr.flush() 157 return tokens
Output results to stderr as they are collected
160def generate( 161 instructions: List[str], 162 max_tokens: int = 200, 163 temperature: float = 0.1, 164 topk: int = 1, 165 repetition_penalty: float = 0.0, 166 prefix: str = "", 167 suppress: List[str] = [], 168 model: str = "instruct", 169 stream: bool = False, 170): 171 """Generates completions for a prompt 172 173 This may use a local model, or it may make an API call to an external 174 model if API keys are available. 175 176 >>> generate(["What is the capital of France?"]) 177 ['...Paris...'] 178 179 >>> list(generate(["What is the capital of France?"], stream=True)) 180 ['...Paris...'] 181 """ 182 if os.environ.get("LANGUAGEMODELS_TS_KEY") or os.environ.get( 183 "LANGUAGEMODELS_TS_SERVER" 184 ): 185 return generate_ts("flan_t5_xxl_q4", instructions, max_tokens).strip() 186 187 if os.environ.get("LANGUAGEMODELS_OA_KEY"): 188 return chat_oa("gpt-3.5-turbo", instructions, max_tokens).strip() 189 190 tokenizer, model = get_model(model) 191 192 start_time = perf_counter() 193 194 suppress = [tokenizer.encode(s, add_special_tokens=False).tokens for s in suppress] 195 196 model_info = get_model_info("instruct") 197 198 fmt = model_info.get("prompt_fmt", "{instruction}") 199 200 if repetition_penalty == 0.0: 201 repetition_penalty = model_info.get("repetition_penalty", 1.3) 202 203 prompts = [fmt.replace("{instruction}", inst) for inst in instructions] 204 truncated_prompts = [truncate_prompt(p) for p in prompts] 205 206 prompts_tok = [tokenizer.encode(p).tokens for p in truncated_prompts] 207 208 outputs_ids = [] 209 if hasattr(model, "translate_batch"): 210 prefix = tokenizer.encode(prefix, add_special_tokens=False).tokens 211 if stream or (config["echo"] and len(prompts_tok) == 1): 212 results = model.generate_tokens( 213 prompts_tok[0], 214 target_prefix=prefix, 215 repetition_penalty=repetition_penalty, 216 max_decoding_length=max_tokens, 217 sampling_temperature=temperature, 218 sampling_topk=topk, 219 suppress_sequences=suppress, 220 ) 221 222 if stream: 
223 return stream_results(results, tokenizer) 224 else: 225 outputs_ids = [echo_results(results, tokenizer)] 226 else: 227 results = model.translate_batch( 228 prompts_tok, 229 target_prefix=[prefix] * len(prompts), 230 repetition_penalty=repetition_penalty, 231 max_decoding_length=max_tokens, 232 sampling_temperature=temperature, 233 sampling_topk=topk, 234 suppress_sequences=suppress, 235 beam_size=1, 236 ) 237 outputs_tokens = [r.hypotheses[0] for r in results] 238 for output in outputs_tokens: 239 outputs_ids.append([tokenizer.token_to_id(t) for t in output]) 240 else: 241 if stream or (config["echo"] and len(prompts_tok) == 1): 242 results = model.generate_tokens( 243 prompts_tok, 244 repetition_penalty=repetition_penalty, 245 max_length=max_tokens, 246 sampling_temperature=temperature, 247 sampling_topk=topk, 248 suppress_sequences=suppress, 249 ) 250 251 if stream: 252 return stream_results(results, tokenizer) 253 else: 254 outputs_ids = [echo_results(results, tokenizer)] 255 else: 256 results = model.generate_batch( 257 prompts_tok, 258 repetition_penalty=repetition_penalty, 259 max_length=max_tokens, 260 sampling_temperature=temperature, 261 sampling_topk=topk, 262 suppress_sequences=suppress, 263 beam_size=1, 264 include_prompt_in_result=False, 265 ) 266 outputs_ids = [r.sequences_ids[0] for r in results] 267 268 model_info["requests"] = model_info.get("requests", 0) + len(prompts) 269 270 in_toks = sum(len(p) for p in prompts_tok) 271 model_info["input_tokens"] = model_info.get("input_tokens", 0) + in_toks 272 273 out_toks = sum(len(o) for o in outputs_ids) 274 model_info["output_tokens"] = model_info.get("output_tokens", 0) + out_toks 275 276 elapsed_time = perf_counter() - start_time 277 model_info["runtime"] = model_info.get("runtime", 0) + elapsed_time 278 279 return [tokenizer.decode(i, skip_special_tokens=True).lstrip() for i in outputs_ids]
Generates completions for a prompt
This may use a local model, or it may make an API call to an external model if API keys are available.
>>> generate(["What is the capital of France?"])
['...Paris...']
>>> list(generate(["What is the capital of France?"], stream=True))
['...Paris...']
282def rank_instruct(inputs, targets): 283 """Sorts a list of targets by their probabilities 284 285 >>> rank_instruct(["Classify positive or negative: I love python. Classification:"], 286 ... ['positive', 'negative']) 287 [['positive', 'negative']] 288 289 >>> rank_instruct(["Classify fantasy or documentary: " 290 ... "The wizard raised their wand. Classification:"], 291 ... ['fantasy', 'documentary']) 292 [['fantasy', 'documentary']] 293 294 >>> rank_instruct(["Say six", "Say seven"], ["six", "seven"]) 295 [['six', 'seven'], ['seven', 'six']] 296 """ 297 tokenizer, model = get_model("instruct") 298 299 model_info = get_model_info("instruct") 300 fmt = model_info.get("prompt_fmt", "{instruction}") 301 inputs = [fmt.replace("{instruction}", inst) for inst in inputs] 302 inputs = [truncate_prompt(i) for i in inputs] 303 304 targ_tok = [tokenizer.encode(t, add_special_tokens=False).tokens for t in targets] 305 targ_tok *= len(inputs) 306 307 in_tok = [] 308 for input in inputs: 309 toks = [tokenizer.encode(input).tokens] 310 in_tok += toks * len(targets) 311 312 if "Generator" in str(type(model)): 313 scores = model.score_batch([i + t for i, t in zip(in_tok, targ_tok)]) 314 else: 315 scores = model.score_batch(in_tok, target=targ_tok) 316 317 ret = [] 318 for i in range(0, len(inputs) * len(targets), len(targets)): 319 logprobs = [sum(r.log_probs) for r in scores[i : i + len(targets)]] 320 results = sorted(zip(targets, logprobs), key=lambda r: -r[1]) 321 ret.append([r[0] for r in results]) 322 323 return ret
Sorts a list of targets by their probabilities
>>> rank_instruct(["Classify positive or negative: I love python. Classification:"],
... ['positive', 'negative'])
[['positive', 'negative']]
>>> rank_instruct(["Classify fantasy or documentary: "
... "The wizard raised their wand. Classification:"],
... ['fantasy', 'documentary'])
[['fantasy', 'documentary']]
>>> rank_instruct(["Say six", "Say seven"], ["six", "seven"])
[['six', 'seven'], ['seven', 'six']]
326def parse_chat(prompt): 327 """Converts a chat prompt using special tokens to a plain-text prompt 328 329 This is useful for prompting generic models that have not been fine-tuned 330 for chat using specialized tokens. 331 332 >>> parse_chat('User: What time is it?') 333 Traceback (most recent call last): 334 .... 335 inference.InferenceException: Chat prompt must end with 'Assistant:' 336 337 >>> parse_chat('''User: What time is it? 338 ... 339 ... Assistant:''') 340 [{'role': 'user', 'content': 'What time is it?'}] 341 342 >>> parse_chat(''' 343 ... A helpful assistant 344 ... 345 ... User: What time is it? 346 ... 347 ... Assistant: 348 ... ''') 349 [{'role': 'system', 'content': 'A helpful assistant'}, 350 {'role': 'user', 'content': 'What time is it?'}] 351 352 >>> parse_chat(''' 353 ... A helpful assistant 354 ... 355 ... User: What time is it? 356 ... 357 ... Assistant: The time is 358 ... ''') 359 Traceback (most recent call last): 360 .... 361 inference.InferenceException: Final assistant message must be blank 362 363 >>> parse_chat(''' 364 ... A helpful assistant 365 ... 366 ... User: First para 367 ... 368 ... Second para 369 ... 370 ... Assistant: 371 ... ''') 372 [{'role': 'system', 'content': 'A helpful assistant'}, 373 {'role': 'user', 'content': 'First para\\n\\nSecond para'}] 374 375 >>> parse_chat(''' 376 ... A helpful assistant 377 ... 378 ... User: What time is it? 379 ... 380 ... InvalidRole: Nothing 381 ... 382 ... Assistant: 383 ... ''') 384 Traceback (most recent call last): 385 .... 
386 inference.InferenceException: Invalid chat role: invalidrole 387 """ 388 389 if not re.match(r"^\s*\w+:", prompt): 390 prompt = "System: " + prompt 391 392 prompt = "\n\n" + prompt 393 394 chunks = re.split(r"[\r\n]\s*(\w+):", prompt, flags=re.M) 395 chunks = [m.strip() for m in chunks if m.strip()] 396 397 messages = [] 398 399 for i in range(0, len(chunks), 2): 400 role = chunks[i].lower() 401 402 try: 403 content = chunks[i + 1] 404 content = re.sub(r"\s*\n\n\s*", "\n\n", content) 405 except IndexError: 406 content = "" 407 messages.append({"role": role, "content": content}) 408 409 for message in messages: 410 if message["role"] not in ["system", "user", "assistant"]: 411 raise InferenceException(f"Invalid chat role: {message['role']}") 412 413 if messages[-1]["role"] != "assistant": 414 raise InferenceException("Chat prompt must end with 'Assistant:'") 415 416 if messages[-1]["content"] != "": 417 raise InferenceException("Final assistant message must be blank") 418 419 return messages[:-1]
Converts a chat prompt using special tokens to a plain-text prompt
This is useful for prompting generic models that have not been fine-tuned for chat using specialized tokens.
>>> parse_chat('User: What time is it?')
Traceback (most recent call last):
....
inference.InferenceException: Chat prompt must end with 'Assistant:'
>>> parse_chat('''User: What time is it?
...
... Assistant:''')
[{'role': 'user', 'content': 'What time is it?'}]
>>> parse_chat('''
... A helpful assistant
...
... User: What time is it?
...
... Assistant:
... ''')
[{'role': 'system', 'content': 'A helpful assistant'},
{'role': 'user', 'content': 'What time is it?'}]
>>> parse_chat('''
... A helpful assistant
...
... User: What time is it?
...
... Assistant: The time is
... ''')
Traceback (most recent call last):
....
inference.InferenceException: Final assistant message must be blank
>>> parse_chat('''
... A helpful assistant
...
... User: First para
...
... Second para
...
... Assistant:
... ''')
[{'role': 'system', 'content': 'A helpful assistant'},
{'role': 'user', 'content': 'First para\n\nSecond para'}]
>>> parse_chat('''
... A helpful assistant
...
... User: What time is it?
...
... InvalidRole: Nothing
...
... Assistant:
... ''')
Traceback (most recent call last):
....
inference.InferenceException: Invalid chat role: invalidrole