languagemodels.inference

from typing import List
import requests
import re
import os
import sys
from time import perf_counter

from languagemodels.models import get_model, get_model_info
from languagemodels.config import config


class InferenceException(Exception):
    pass


def list_tokens(prompt):
    """Generates a list of tokens for a supplied prompt

    >>> list_tokens("Hello, world!") # doctest: +SKIP
    [('▁Hello', 8774), (',', 6), ('▁world', 296), ('!', 55)]

    >>> list_tokens("Hello, world!")
    [('...Hello', ...), ... ('...world', ...), ...]
    """
    tokenizer, _ = get_model("instruct")

    output = tokenizer.encode(prompt, add_special_tokens=False)
    tokens = output.tokens
    ids = output.ids

    return list(zip(tokens, ids))


def generate_ts(engine, prompt, max_tokens=200):
    """Generates a single text response for a prompt from a TextSynth server

    The server and API key are provided as environment variables:

    LANGUAGEMODELS_TS_SERVER is the server such as http://localhost:8080
    LANGUAGEMODELS_TS_KEY is the API key
    """
    apikey = os.environ.get("LANGUAGEMODELS_TS_KEY") or ""
    server = os.environ.get("LANGUAGEMODELS_TS_SERVER") or "https://api.textsynth.com"

    response = requests.post(
        f"{server}/v1/engines/{engine}/completions",
        headers={"Authorization": f"Bearer {apikey}"},
        json={"prompt": prompt, "max_tokens": max_tokens},
    )
    resp = response.json()
    if "text" in resp:
        return resp["text"]
    else:
        raise InferenceException(f"TextSynth error: {resp}")


def generate_oa(engine, prompt, max_tokens=200, temperature=0):
    """Generates a single text response for a prompt using OpenAI

    The API key is provided as an environment variable:

    LANGUAGEMODELS_OA_KEY is the API key
    """
    apikey = os.environ.get("LANGUAGEMODELS_OA_KEY")

    response = requests.post(
        "https://api.openai.com/v1/completions",
        headers={
            "Authorization": f"Bearer {apikey}",
            "Content-Type": "application/json",
        },
        json={
            "model": engine,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
        },
    )
    resp = response.json()

    try:
        return resp["choices"][0]["text"]
    except KeyError:
        raise InferenceException(f"OpenAI error: {resp}")


def chat_oa(engine, prompt, max_tokens=200, temperature=0):
    """Generates a single text response for a prompt using OpenAI

    The API key is provided as an environment variable:

    LANGUAGEMODELS_OA_KEY is the API key
    """
    apikey = os.environ.get("LANGUAGEMODELS_OA_KEY")

    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {apikey}",
            "Content-Type": "application/json",
        },
        json={
            "model": engine,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "temperature": temperature,
        },
    )
    resp = response.json()

    try:
        return resp["choices"][0]["message"]["content"]
    except KeyError:
        raise InferenceException(f"OpenAI error: {resp}")


def stream_results(results, tokenizer):
    """Map a token iterator to a substring iterator"""
    tokens = []
    last_len = 0

    for result in results:
        tokens.append(result.token_id)
        text = tokenizer.decode(tokens)
        yield text[last_len:]
        last_len = len(text)


def echo_results(results, tokenizer):
    """Output results to stderr as they are collected"""
    tokens = []
    last_len = 0

    for result in results:
        tokens.append(result.token_id)
        text = tokenizer.decode(tokens)
        sys.stderr.write(text[last_len:])
        sys.stderr.flush()
        last_len = len(text)

    sys.stderr.write("\n\n")
    sys.stderr.flush()
    return tokens


def generate(
    instructions: List[str],
    max_tokens: int = 200,
    temperature: float = 0.1,
    topk: int = 1,
    repetition_penalty: float = 0.0,
    prefix: str = "",
    suppress: List[str] = [],
    model: str = "instruct",
    stream: bool = False,
):
    """Generates completions for a list of instruction prompts

    This may use a local model, or it may make an API call to an external
    model if API keys are available.

    >>> generate(["What is the capital of France?"])
    ['...Paris...']

    >>> list(generate(["What is the capital of France?"], stream=True))
    ['...Paris...']
    """
    # Dispatch to a hosted backend when API credentials are configured
    if os.environ.get("LANGUAGEMODELS_TS_KEY") or os.environ.get(
        "LANGUAGEMODELS_TS_SERVER"
    ):
        return generate_ts("flan_t5_xxl_q4", instructions, max_tokens).strip()

    if os.environ.get("LANGUAGEMODELS_OA_KEY"):
        return chat_oa("gpt-3.5-turbo", instructions, max_tokens).strip()

    tokenizer, model = get_model(model)

    start_time = perf_counter()

    suppress = [tokenizer.encode(s, add_special_tokens=False).tokens for s in suppress]

    model_info = get_model_info("instruct")

    fmt = model_info.get("prompt_fmt", "{instruction}")

    if repetition_penalty == 0.0:
        repetition_penalty = model_info.get("repetition_penalty", 1.3)

    prompts = [fmt.replace("{instruction}", inst) for inst in instructions]
    prompts_tok = [tokenizer.encode(p).tokens for p in prompts]

    outputs_ids = []
    # Encoder-decoder models expose translate_batch; decoder-only models are handled below
    if hasattr(model, "translate_batch"):
        prefix = tokenizer.encode(prefix, add_special_tokens=False).tokens
        if stream or (config["echo"] and len(prompts_tok) == 1):
            results = model.generate_tokens(
                prompts_tok[0],
                target_prefix=prefix,
                repetition_penalty=repetition_penalty,
                max_decoding_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
            )

            if stream:
                return stream_results(results, tokenizer)
            else:
                outputs_ids = [echo_results(results, tokenizer)]
        else:
            results = model.translate_batch(
                prompts_tok,
                target_prefix=[prefix] * len(prompts),
                repetition_penalty=repetition_penalty,
                max_decoding_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
                beam_size=1,
            )
            outputs_tokens = [r.hypotheses[0] for r in results]
            for output in outputs_tokens:
                outputs_ids.append([tokenizer.token_to_id(t) for t in output])
    else:
        if stream or (config["echo"] and len(prompts_tok) == 1):
            results = model.generate_tokens(
                prompts_tok,
                repetition_penalty=repetition_penalty,
                max_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
            )

            if stream:
                return stream_results(results, tokenizer)
            else:
                outputs_ids = [echo_results(results, tokenizer)]
        else:
            results = model.generate_batch(
                prompts_tok,
                repetition_penalty=repetition_penalty,
                max_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
                beam_size=1,
                include_prompt_in_result=False,
            )
            outputs_ids = [r.sequences_ids[0] for r in results]

    model_info["requests"] = model_info.get("requests", 0) + len(prompts)

    in_toks = sum(len(p) for p in prompts_tok)
    model_info["input_tokens"] = model_info.get("input_tokens", 0) + in_toks

    out_toks = sum(len(o) for o in outputs_ids)
    model_info["output_tokens"] = model_info.get("output_tokens", 0) + out_toks

    elapsed_time = perf_counter() - start_time
    model_info["runtime"] = model_info.get("runtime", 0) + elapsed_time

    return [tokenizer.decode(i, skip_special_tokens=True).lstrip() for i in outputs_ids]


def rank_instruct(inputs, targets):
    """Sorts a list of targets by their probabilities

    >>> rank_instruct(["Classify positive or negative: I love python. Classification:"],
    ... ['positive', 'negative'])
    [['positive', 'negative']]

    >>> rank_instruct(["Classify fantasy or documentary: "
    ... "The wizard raised their wand. Classification:"],
    ... ['fantasy', 'documentary'])
    [['fantasy', 'documentary']]

    >>> rank_instruct(["Say six", "Say seven"], ["six", "seven"])
    [['six', 'seven'], ['seven', 'six']]
    """
    tokenizer, model = get_model("instruct")

    targ_tok = [tokenizer.encode(t).tokens for t in targets]
    targ_tok *= len(inputs)

    in_tok = []
    for input in inputs:
        toks = [tokenizer.encode(input).tokens]
        in_tok += toks * len(targets)

    # Generator (decoder-only) models score the concatenated prompt and target;
    # seq2seq models score the target given the source
    if "Generator" in str(type(model)):
        scores = model.score_batch([i + t for i, t in zip(in_tok, targ_tok)])
    else:
        scores = model.score_batch(in_tok, target=targ_tok)

    ret = []
    for i in range(0, len(inputs) * len(targets), len(targets)):
        logprobs = [sum(r.log_probs) for r in scores[i : i + len(targets)]]
        results = sorted(zip(targets, logprobs), key=lambda r: -r[1])
        ret.append([r[0] for r in results])

    return ret


def parse_chat(prompt):
    """Converts a plain-text chat prompt into a list of chat messages

    This is useful for prompting generic models that have not been fine-tuned
    for chat using specialized tokens.

    >>> parse_chat('User: What time is it?')
    Traceback (most recent call last):
        ....
    inference.InferenceException: Chat prompt must end with 'Assistant:'

    >>> parse_chat('''User: What time is it?
    ...
    ...               Assistant:''')
    [{'role': 'user', 'content': 'What time is it?'}]

    >>> parse_chat('''
    ...              A helpful assistant
    ...
    ...              User: What time is it?
    ...
    ...              Assistant:
    ...              ''')
    [{'role': 'system', 'content': 'A helpful assistant'},
     {'role': 'user', 'content': 'What time is it?'}]

    >>> parse_chat('''
    ...              A helpful assistant
    ...
    ...              User: What time is it?
    ...
    ...              Assistant: The time is
    ...              ''')
    Traceback (most recent call last):
        ....
    inference.InferenceException: Final assistant message must be blank

    >>> parse_chat('''
    ...              A helpful assistant
    ...
    ...              User: First para
    ...
    ...              Second para
    ...
    ...              Assistant:
    ...              ''')
    [{'role': 'system', 'content': 'A helpful assistant'},
     {'role': 'user', 'content': 'First para\\n\\nSecond para'}]

    >>> parse_chat('''
    ...              A helpful assistant
    ...
    ...              User: What time is it?
    ...
    ...              InvalidRole: Nothing
    ...
    ...              Assistant:
    ...              ''')
    Traceback (most recent call last):
        ....
    inference.InferenceException: Invalid chat role: invalidrole
    """

    # Text before the first role marker is treated as a system message
    if not re.match(r"^\s*\w+:", prompt):
        prompt = "System: " + prompt

    prompt = "\n\n" + prompt

    # Split on "Role:" markers that begin a line
    chunks = re.split(r"[\r\n]\s*(\w+):", prompt, flags=re.M)
    chunks = [m.strip() for m in chunks if m.strip()]

    messages = []

    for i in range(0, len(chunks), 2):
        role = chunks[i].lower()

        try:
            content = chunks[i + 1]
            content = re.sub(r"\s*\n\n\s*", "\n\n", content)
        except IndexError:
            content = ""
        messages.append({"role": role, "content": content})

    for message in messages:
        if message["role"] not in ["system", "user", "assistant"]:
            raise InferenceException(f"Invalid chat role: {message['role']}")

    if messages[-1]["role"] != "assistant":
        raise InferenceException("Chat prompt must end with 'Assistant:'")

    if messages[-1]["content"] != "":
        raise InferenceException("Final assistant message must be blank")

    return messages[:-1]

class InferenceException(builtins.Exception):

Raised by the inference helpers when a backend cannot return a usable result, for example an error response from TextSynth or OpenAI, or a malformed chat prompt.

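The remote helpers and parse_chat below all signal failure by raising this exception, so callers can handle every backend the same way. A minimal sketch (the key value is a placeholder):

    import os
    from languagemodels.inference import chat_oa, InferenceException

    os.environ["LANGUAGEMODELS_OA_KEY"] = "sk-placeholder"  # hypothetical key

    try:
        text = chat_oa("gpt-3.5-turbo", "Say hello", max_tokens=8)
    except InferenceException as err:
        # Raised when the response contains no completion, e.g. a bad key or quota error
        print(f"inference failed: {err}")
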
def list_tokens(prompt):

Generates a list of tokens for a supplied prompt

>>> list_tokens("Hello, world!") # doctest: +SKIP
[('▁Hello', 8774), (',', 6), ('▁world', 296), ('!', 55)]
>>> list_tokens("Hello, world!")
[('...Hello', ...), ... ('...world', ...), ...]
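
One practical use is checking how many tokens a prompt will consume before calling generate; a minimal sketch (the 200-token budget is an arbitrary example, not a library limit):

    from languagemodels.inference import list_tokens

    prompt = "Summarize the following paragraph..."
    n_tokens = len(list_tokens(prompt))
    if n_tokens > 200:
        print(f"prompt is {n_tokens} tokens; consider trimming it")
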
def generate_ts(engine, prompt, max_tokens=200):

Generates a single text response for a prompt from a TextSynth server

The server and API key are provided as environment variables:

LANGUAGEMODELS_TS_SERVER is the server, such as http://localhost:8080
LANGUAGEMODELS_TS_KEY is the API key
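
A minimal usage sketch, assuming a self-hosted server; the engine name reuses flan_t5_xxl_q4 from elsewhere in this module, but any TextSynth engine name works the same way:

    import os
    os.environ["LANGUAGEMODELS_TS_SERVER"] = "http://localhost:8080"  # assumed local server
    os.environ["LANGUAGEMODELS_TS_KEY"] = "my-api-key"                # placeholder key

    from languagemodels.inference import generate_ts
    print(generate_ts("flan_t5_xxl_q4", "Translate to French: Hello", max_tokens=32))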

def generate_oa(engine, prompt, max_tokens=200, temperature=0):

Generates a single text response for a prompt using OpenAI

The API key is provided as an environment variable:

LANGUAGEMODELS_OA_KEY is the API key
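
A minimal sketch against the plain completions endpoint; the engine name is an assumption (any completions-capable model the account can access would work):

    import os
    os.environ["LANGUAGEMODELS_OA_KEY"] = "sk-placeholder"  # hypothetical key

    from languagemodels.inference import generate_oa
    print(generate_oa("gpt-3.5-turbo-instruct", "Q: What is 2 + 2?\nA:", max_tokens=8))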

def chat_oa(engine, prompt, max_tokens=200, temperature=0):

Generates a single text response for a prompt using OpenAI

The API key is provided as an environment variable:

LANGUAGEMODELS_OA_KEY is the API key

def stream_results(results, tokenizer):

Map a token iterator to a substring iterator
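
Each yielded piece is the newly decoded suffix: the accumulated token ids are re-decoded on every step and only the text added since the previous step is emitted, which avoids yielding raw tokenizer pieces (such as the ▁ word marker) that do not decode cleanly in isolation. A small sketch with a stand-in result type (MockResult is hypothetical; real results come from generate_tokens):

    from collections import namedtuple
    from languagemodels.models import get_model
    from languagemodels.inference import stream_results

    MockResult = namedtuple("MockResult", "token_id")
    tokenizer, _ = get_model("instruct")
    ids = tokenizer.encode("Hello, world!", add_special_tokens=False).ids

    for piece in stream_results((MockResult(i) for i in ids), tokenizer):
        print(repr(piece))  # e.g. 'Hello', ',', ' world', '!'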

def echo_results(results, tokenizer):

Output results to stderr as they are collected

def generate( instructions: List[str], max_tokens: int = 200, temperature: float = 0.1, topk: int = 1, repetition_penalty: float = 0.0, prefix: str = '', suppress: List[str] = [], model: str = 'instruct', stream: bool = False):

Generates completions for a list of instruction prompts

This may use a local model, or it may make an API call to an external model if API keys are available.

>>> generate(["What is the capital of France?"])
['...Paris...']
>>> list(generate(["What is the capital of France?"], stream=True))
['...Paris...']
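
A minimal sketch of the local-model path (no API keys set), showing both calling styles; outputs are illustrative and come back in the same order as the instructions:

    from languagemodels.inference import generate

    # Batch: one decoded string per instruction
    replies = generate(
        ["What is the capital of France?", "What is the capital of Germany?"],
        max_tokens=64,
    )

    # Streaming: an iterator of text fragments for a single instruction
    for piece in generate(["Write one sentence about Paris."], stream=True):
        print(piece, end="", flush=True)
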
def rank_instruct(inputs, targets):

Sorts a list of targets by their probabilities

>>> rank_instruct(["Classify positive or negative: I love python. Classification:"],
... ['positive', 'negative'])
[['positive', 'negative']]
>>> rank_instruct(["Classify fantasy or documentary: "
... "The wizard raised their wand. Classification:"],
... ['fantasy', 'documentary'])
[['fantasy', 'documentary']]
>>> rank_instruct(["Say six", "Say seven"], ["six", "seven"])
[['six', 'seven'], ['seven', 'six']]
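
In practice this acts as a zero-shot classifier: every candidate target is scored against each input, and the first entry of the returned ranking is the most probable label. A minimal sketch (the prompt and labels are illustrative):

    from languagemodels.inference import rank_instruct

    ranked = rank_instruct(
        ["Classify positive or negative: The soup was cold and bland. Classification:"],
        ["positive", "negative"],
    )
    label = ranked[0][0]  # most probable target for the first (and only) input
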
def parse_chat(prompt):

Converts a plain-text chat prompt into a list of chat messages

This is useful for prompting generic models that have not been fine-tuned for chat using specialized tokens.

>>> parse_chat('User: What time is it?')
Traceback (most recent call last):
    ....
inference.InferenceException: Chat prompt must end with 'Assistant:'
>>> parse_chat('''User: What time is it?
...
...               Assistant:''')
[{'role': 'user', 'content': 'What time is it?'}]
>>> parse_chat('''
...              A helpful assistant
...
...              User: What time is it?
...
...              Assistant:
...              ''')
[{'role': 'system', 'content': 'A helpful assistant'},
 {'role': 'user', 'content': 'What time is it?'}]
>>> parse_chat('''
...              A helpful assistant
...
...              User: What time is it?
...
...              Assistant: The time is
...              ''')
Traceback (most recent call last):
    ....
inference.InferenceException: Final assistant message must be blank
>>> parse_chat('''
...              A helpful assistant
...
...              User: First para
...
...              Second para
...
...              Assistant:
...              ''')
[{'role': 'system', 'content': 'A helpful assistant'},
 {'role': 'user', 'content': 'First para\n\nSecond para'}]
>>> parse_chat('''
...              A helpful assistant
...
...              User: What time is it?
...
...              InvalidRole: Nothing
...
...              Assistant:
...              ''')
Traceback (most recent call last):
    ....
inference.InferenceException: Invalid chat role: invalidrole