languagemodels.inference
from typing import List
import requests
import re
import os
import sys
from time import perf_counter

from languagemodels.models import get_model, get_model_info
from languagemodels.config import config


class InferenceException(Exception):
    pass
Exception raised when inference fails, such as when an external API returns an error response or a chat prompt is malformed.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
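Code that calls into this module can catch InferenceException around any call that may reach a remote backend or parse a chat prompt; a minimal sketch using the module's own generate function:

from languagemodels.inference import generate, InferenceException

try:
    replies = generate(["What is the capital of France?"])
    print(replies[0])
except InferenceException as exc:
    # API errors embed the raw response payload in the exception message
    print(f"inference failed: {exc}")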
def list_tokens(prompt):
    """Generates a list of tokens for a supplied prompt

    >>> list_tokens("Hello, world!") # doctest: +SKIP
    [('▁Hello', 8774), (',', 6), ('▁world', 296), ('!', 55)]

    >>> list_tokens("Hello, world!")
    [('...Hello', ...), ... ('...world', ...), ...]
    """
    tokenizer, _ = get_model("instruct")

    output = tokenizer.encode(prompt, add_special_tokens=False)
    tokens = output.tokens
    ids = output.ids

    return list(zip(tokens, ids))
Generates a list of tokens for a supplied prompt
>>> list_tokens("Hello, world!") # doctest: +SKIP
[('▁Hello', 8774), (',', 6), ('▁world', 296), ('!', 55)]
>>> list_tokens("Hello, world!")
[('...Hello', ...), ... ('...world', ...), ...]
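As a quick usage sketch, list_tokens can be used to inspect how a prompt is split and to estimate prompt length in tokens; the exact token strings and ids depend on the loaded instruct model:

from languagemodels.inference import list_tokens

pairs = list_tokens("Hello, world!")
for token, token_id in pairs:
    print(f"{token!r} -> {token_id}")

# Rough prompt-budget check before calling generate()
print("prompt tokens:", len(pairs))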
def generate_ts(engine, prompt, max_tokens=200):
    """Generates a single text response for a prompt from a TextSynth server

    The server and API key are provided as environment variables:

    LANGUAGEMODELS_TS_SERVER is the server such as http://localhost:8080
    LANGUAGEMODELS_TS_KEY is the API key
    """
    apikey = os.environ.get("LANGUAGEMODELS_TS_KEY") or ""
    server = os.environ.get("LANGUAGEMODELS_TS_SERVER") or "https://api.textsynth.com"

    response = requests.post(
        f"{server}/v1/engines/{engine}/completions",
        headers={"Authorization": f"Bearer {apikey}"},
        json={"prompt": prompt, "max_tokens": max_tokens},
    )
    resp = response.json()
    if "text" in resp:
        return resp["text"]
    else:
        raise InferenceException(f"TextSynth error: {resp}")
Generates a single text response for a prompt from a TextSynth server
The server and API key are provided as environment variables:
LANGUAGEMODELS_TS_SERVER is the server such as http://localhost:8080
LANGUAGEMODELS_TS_KEY is the API key
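A sketch of pointing generate_ts at a self-hosted server; the URL is only an example, and the flan_t5_xxl_q4 engine name mirrors the one generate() uses when these variables are set:

import os
from languagemodels.inference import generate_ts, InferenceException

os.environ["LANGUAGEMODELS_TS_SERVER"] = "http://localhost:8080"  # example server
# LANGUAGEMODELS_TS_KEY may be left unset for a local server without auth

try:
    text = generate_ts("flan_t5_xxl_q4", "Translate to French: Hello", max_tokens=50)
    print(text.strip())
except InferenceException as exc:
    print(exc)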
def generate_oa(engine, prompt, max_tokens=200, temperature=0):
    """Generates a single text response for a prompt using the OpenAI completions API

    The API key is provided as an environment variable:

    LANGUAGEMODELS_OA_KEY is the API key
    """
    apikey = os.environ.get("LANGUAGEMODELS_OA_KEY")

    response = requests.post(
        "https://api.openai.com/v1/completions",
        headers={
            "Authorization": f"Bearer {apikey}",
            "Content-Type": "application/json",
        },
        json={
            "model": engine,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
        },
    )
    resp = response.json()

    try:
        return resp["choices"][0]["text"]
    except KeyError:
        raise InferenceException(f"OpenAI error: {resp}")
Generates a single text response for a prompt using the OpenAI completions API
The API key is provided as an environment variable:
LANGUAGEMODELS_OA_KEY is the API key
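A hedged sketch of calling generate_oa directly; the engine name below is an illustrative assumption, and any completions-capable OpenAI model name can be substituted:

import os
from languagemodels.inference import generate_oa

os.environ["LANGUAGEMODELS_OA_KEY"] = "sk-..."  # your API key

# "gpt-3.5-turbo-instruct" is only an example engine name; pass whichever
# completions model your account has access to
print(generate_oa("gpt-3.5-turbo-instruct", "Say hello.", max_tokens=20))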
def chat_oa(engine, prompt, max_tokens=200, temperature=0):
    """Generates a single chat response for a prompt using the OpenAI chat completions API

    The API key is provided as an environment variable:

    LANGUAGEMODELS_OA_KEY is the API key
    """
    apikey = os.environ.get("LANGUAGEMODELS_OA_KEY")

    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {apikey}",
            "Content-Type": "application/json",
        },
        json={
            "model": engine,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "temperature": temperature,
        },
    )
    resp = response.json()

    try:
        return resp["choices"][0]["message"]["content"]
    except KeyError:
        raise InferenceException(f"OpenAI error: {resp}")
Generates a single chat response for a prompt using the OpenAI chat completions API
The API key is provided as an environment variable:
LANGUAGEMODELS_OA_KEY is the API key
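A short sketch of calling chat_oa directly; gpt-3.5-turbo matches the engine generate() selects when LANGUAGEMODELS_OA_KEY is set:

import os
from languagemodels.inference import chat_oa

os.environ["LANGUAGEMODELS_OA_KEY"] = "sk-..."  # your API key

reply = chat_oa("gpt-3.5-turbo", "What is the capital of France?", temperature=0)
print(reply)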
def stream_results(results, tokenizer):
    """Map a token iterator to a substring iterator"""
    tokens = []
    last_len = 0

    for result in results:
        tokens.append(result.token_id)
        text = tokenizer.decode(tokens)
        yield text[last_len:]
        last_len = len(text)
Map a token iterator to a substring iterator
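To illustrate the mapping without running a model, the sketch below feeds stream_results a stand-in iterator whose items expose a token_id attribute, as the real generate_tokens results do; SimpleNamespace is only used here to mimic that shape:

from types import SimpleNamespace

from languagemodels.inference import list_tokens, stream_results
from languagemodels.models import get_model

tokenizer, _ = get_model("instruct")

# Stand-in "results": objects exposing .token_id, one per generated token
fake_results = (SimpleNamespace(token_id=i) for _, i in list_tokens("Hello, world!"))

for chunk in stream_results(fake_results, tokenizer):
    print(repr(chunk))  # each chunk is only the newly decoded substring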
def echo_results(results, tokenizer):
    """Output results to stderr as they are collected"""
    tokens = []
    last_len = 0

    for result in results:
        tokens.append(result.token_id)
        text = tokenizer.decode(tokens)
        sys.stderr.write(text[last_len:])
        sys.stderr.flush()
        last_len = len(text)

    sys.stderr.write("\n\n")
    sys.stderr.flush()
    return tokens
Output results to stderr as they are collected
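echo_results is invoked by generate() when config["echo"] is enabled and a single prompt is being generated; a small sketch, assuming the config mapping accepts item assignment for the echo setting:

from languagemodels.config import config
from languagemodels.inference import generate

config["echo"] = True  # stream decoded text to stderr while generating

reply = generate(["What is the capital of France?"])[0]
print(reply)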
def generate(
    instructions: List[str],
    max_tokens: int = 200,
    temperature: float = 0.1,
    topk: int = 1,
    repetition_penalty: float = 0.0,
    prefix: str = "",
    suppress: List[str] = [],
    model: str = "instruct",
    stream: bool = False,
):
    """Generates completions for a prompt

    This may use a local model, or it may make an API call to an external
    model if API keys are available.

    >>> generate(["What is the capital of France?"])
    ['...Paris...']

    >>> list(generate(["What is the capital of France?"], stream=True))
    ['...Paris...']
    """
    if os.environ.get("LANGUAGEMODELS_TS_KEY") or os.environ.get(
        "LANGUAGEMODELS_TS_SERVER"
    ):
        return generate_ts("flan_t5_xxl_q4", instructions, max_tokens).strip()

    if os.environ.get("LANGUAGEMODELS_OA_KEY"):
        return chat_oa("gpt-3.5-turbo", instructions, max_tokens).strip()

    tokenizer, model = get_model(model)

    start_time = perf_counter()

    suppress = [tokenizer.encode(s, add_special_tokens=False).tokens for s in suppress]

    model_info = get_model_info("instruct")

    fmt = model_info.get("prompt_fmt", "{instruction}")

    if repetition_penalty == 0.0:
        repetition_penalty = model_info.get("repetition_penalty", 1.3)

    prompts = [fmt.replace("{instruction}", inst) for inst in instructions]
    prompts_tok = [tokenizer.encode(p).tokens for p in prompts]

    outputs_ids = []
    if hasattr(model, "translate_batch"):
        prefix = tokenizer.encode(prefix, add_special_tokens=False).tokens
        if stream or (config["echo"] and len(prompts_tok) == 1):
            results = model.generate_tokens(
                prompts_tok[0],
                target_prefix=prefix,
                repetition_penalty=repetition_penalty,
                max_decoding_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
            )

            if stream:
                return stream_results(results, tokenizer)
            else:
                outputs_ids = [echo_results(results, tokenizer)]
        else:
            results = model.translate_batch(
                prompts_tok,
                target_prefix=[prefix] * len(prompts),
                repetition_penalty=repetition_penalty,
                max_decoding_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
                beam_size=1,
            )
            outputs_tokens = [r.hypotheses[0] for r in results]
            for output in outputs_tokens:
                outputs_ids.append([tokenizer.token_to_id(t) for t in output])
    else:
        if stream or (config["echo"] and len(prompts_tok) == 1):
            results = model.generate_tokens(
                prompts_tok,
                repetition_penalty=repetition_penalty,
                max_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
            )

            if stream:
                return stream_results(results, tokenizer)
            else:
                outputs_ids = [echo_results(results, tokenizer)]
        else:
            results = model.generate_batch(
                prompts_tok,
                repetition_penalty=repetition_penalty,
                max_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
                beam_size=1,
                include_prompt_in_result=False,
            )
            outputs_ids = [r.sequences_ids[0] for r in results]

    model_info["requests"] = model_info.get("requests", 0) + len(prompts)

    in_toks = sum(len(p) for p in prompts_tok)
    model_info["input_tokens"] = model_info.get("input_tokens", 0) + in_toks

    out_toks = sum(len(o) for o in outputs_ids)
    model_info["output_tokens"] = model_info.get("output_tokens", 0) + out_toks

    elapsed_time = perf_counter() - start_time
    model_info["runtime"] = model_info.get("runtime", 0) + elapsed_time

    return [tokenizer.decode(i, skip_special_tokens=True).lstrip() for i in outputs_ids]
Generates completions for a prompt
This may use a local model, or it may make an API call to an external model if API keys are available.
>>> generate(["What is the capital of France?"])
['...Paris...']
>>> list(generate(["What is the capital of France?"], stream=True))
['...Paris...']
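A usage sketch of the local path with illustrative parameter values; with stream=True the call returns an iterator of substrings and is intended for a single instruction:

from languagemodels.inference import generate

# Batched generation: one completion per instruction
replies = generate(
    ["What is the capital of France?", "What is the capital of Japan?"],
    max_tokens=64,
    temperature=0.1,
    topk=1,
)
print(replies)

# Streaming: substrings are yielded as they are decoded
for chunk in generate(["What is the capital of France?"], stream=True):
    print(chunk, end="", flush=True)
print()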
def rank_instruct(inputs, targets):
    """Sorts a list of targets by their probabilities

    >>> rank_instruct(["Classify positive or negative: I love python. Classification:"],
    ... ['positive', 'negative'])
    [['positive', 'negative']]

    >>> rank_instruct(["Classify fantasy or documentary: "
    ... "The wizard raised their wand. Classification:"],
    ... ['fantasy', 'documentary'])
    [['fantasy', 'documentary']]

    >>> rank_instruct(["Say six", "Say seven"], ["six", "seven"])
    [['six', 'seven'], ['seven', 'six']]
    """
    tokenizer, model = get_model("instruct")

    targ_tok = [tokenizer.encode(t).tokens for t in targets]
    targ_tok *= len(inputs)

    in_tok = []
    for input in inputs:
        toks = [tokenizer.encode(input).tokens]
        in_tok += toks * len(targets)

    if "Generator" in str(type(model)):
        scores = model.score_batch([i + t for i, t in zip(in_tok, targ_tok)])
    else:
        scores = model.score_batch(in_tok, target=targ_tok)

    ret = []
    for i in range(0, len(inputs) * len(targets), len(targets)):
        logprobs = [sum(r.log_probs) for r in scores[i : i + len(targets)]]
        results = sorted(zip(targets, logprobs), key=lambda r: -r[1])
        ret.append([r[0] for r in results])

    return ret
Sorts a list of targets by their probabilities
>>> rank_instruct(["Classify positive or negative: I love python. Classification:"],
... ['positive', 'negative'])
[['positive', 'negative']]
>>> rank_instruct(["Classify fantasy or documentary: "
... "The wizard raised their wand. Classification:"],
... ['fantasy', 'documentary'])
[['fantasy', 'documentary']]
>>> rank_instruct(["Say six", "Say seven"], ["six", "seven"])
[['six', 'seven'], ['seven', 'six']]
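A short sketch of using rank_instruct for zero-shot classification, mirroring the doctests above:

from languagemodels.inference import rank_instruct

prompt = "Classify positive or negative: I love python. Classification:"
ranking = rank_instruct([prompt], ["positive", "negative"])

# ranking[0] orders the targets from most to least likely for the first prompt
print(ranking[0][0])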
def parse_chat(prompt):
    """Converts a chat prompt using special tokens to a plain-text prompt

    This is useful for prompting generic models that have not been fine-tuned
    for chat using specialized tokens.

    >>> parse_chat('User: What time is it?')
    Traceback (most recent call last):
    ....
    inference.InferenceException: Chat prompt must end with 'Assistant:'

    >>> parse_chat('''User: What time is it?
    ...
    ... Assistant:''')
    [{'role': 'user', 'content': 'What time is it?'}]

    >>> parse_chat('''
    ... A helpful assistant
    ...
    ... User: What time is it?
    ...
    ... Assistant:
    ... ''')
    [{'role': 'system', 'content': 'A helpful assistant'},
    {'role': 'user', 'content': 'What time is it?'}]

    >>> parse_chat('''
    ... A helpful assistant
    ...
    ... User: What time is it?
    ...
    ... Assistant: The time is
    ... ''')
    Traceback (most recent call last):
    ....
    inference.InferenceException: Final assistant message must be blank

    >>> parse_chat('''
    ... A helpful assistant
    ...
    ... User: First para
    ...
    ... Second para
    ...
    ... Assistant:
    ... ''')
    [{'role': 'system', 'content': 'A helpful assistant'},
    {'role': 'user', 'content': 'First para\\n\\nSecond para'}]

    >>> parse_chat('''
    ... A helpful assistant
    ...
    ... User: What time is it?
    ...
    ... InvalidRole: Nothing
    ...
    ... Assistant:
    ... ''')
    Traceback (most recent call last):
    ....
    inference.InferenceException: Invalid chat role: invalidrole
    """

    if not re.match(r"^\s*\w+:", prompt):
        prompt = "System: " + prompt

    prompt = "\n\n" + prompt

    chunks = re.split(r"[\r\n]\s*(\w+):", prompt, flags=re.M)
    chunks = [m.strip() for m in chunks if m.strip()]

    messages = []

    for i in range(0, len(chunks), 2):
        role = chunks[i].lower()

        try:
            content = chunks[i + 1]
            content = re.sub(r"\s*\n\n\s*", "\n\n", content)
        except IndexError:
            content = ""
        messages.append({"role": role, "content": content})

    for message in messages:
        if message["role"] not in ["system", "user", "assistant"]:
            raise InferenceException(f"Invalid chat role: {message['role']}")

    if messages[-1]["role"] != "assistant":
        raise InferenceException("Chat prompt must end with 'Assistant:'")

    if messages[-1]["content"] != "":
        raise InferenceException("Final assistant message must be blank")

    return messages[:-1]
Converts a chat prompt using special tokens to a plain-text prompt
This is useful for prompting generic models that have not been fine-tuned for chat using specialized tokens.
>>> parse_chat('User: What time is it?')
Traceback (most recent call last):
....
inference.InferenceException: Chat prompt must end with 'Assistant:'
>>> parse_chat('''User: What time is it?
...
... Assistant:''')
[{'role': 'user', 'content': 'What time is it?'}]
>>> parse_chat('''
... A helpful assistant
...
... User: What time is it?
...
... Assistant:
... ''')
[{'role': 'system', 'content': 'A helpful assistant'},
{'role': 'user', 'content': 'What time is it?'}]
>>> parse_chat('''
... A helpful assistant
...
... User: What time is it?
...
... Assistant: The time is
... ''')
Traceback (most recent call last):
....
inference.InferenceException: Final assistant message must be blank
>>> parse_chat('''
... A helpful assistant
...
... User: First para
...
... Second para
...
... Assistant:
... ''')
[{'role': 'system', 'content': 'A helpful assistant'},
{'role': 'user', 'content': 'First para\n\nSecond para'}]
>>> parse_chat('''
... A helpful assistant
...
... User: What time is it?
...
... InvalidRole: Nothing
...
... Assistant:
... ''')
Traceback (most recent call last):
....
inference.InferenceException: Invalid chat role: invalidrole
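A usage sketch mirroring the doctests above; the returned messages can then be formatted however the target model expects:

from languagemodels.inference import parse_chat

messages = parse_chat('''
A helpful assistant

User: What time is it?

Assistant:
''')

# The trailing empty assistant turn is validated and then dropped
for m in messages:
    print(m["role"], "->", m["content"])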