languagemodels.inference

  1from typing import List
  2import requests
  3import re
  4import os
  5import sys
  6from time import perf_counter
  7
  8from languagemodels.models import get_model, get_model_info
  9from languagemodels.config import config
 10
 11
class InferenceException(Exception):
    """Raised when text generation fails, locally or via a remote API backend."""

    pass
 14
 15
 16def truncate_prompt(prompt):
 17    """Truncates a prompt to the maximum length allowed by the config"""
 18    max_prompt_length = config["max_prompt_length"]
 19    if len(prompt) > max_prompt_length:
 20        print(
 21            f"Warning: Prompt truncated from {len(prompt)} to "
 22            f"{max_prompt_length} characters to avoid OOM."
 23        )
 24        return prompt[:max_prompt_length]
 25    return prompt
 26
 27
 28def list_tokens(prompt):
 29    """Generates a list of tokens for a supplied prompt
 30
 31    >>> list_tokens("Hello, world!") # doctest: +SKIP
 32    [('▁Hello', 8774), (',', 6), ('▁world', 296), ('!', 55)]
 33
 34    >>> list_tokens("Hello, world!")
 35    [('...Hello', ...), ... ('...world', ...), ...]
 36    """
 37    prompt = truncate_prompt(prompt)
 38    tokenizer, _ = get_model("instruct")
 39
 40    output = tokenizer.encode(prompt, add_special_tokens=False)
 41    tokens = output.tokens
 42    ids = output.ids
 43
 44    return list(zip(tokens, ids))
 45
 46
 47def generate_ts(engine, prompt, max_tokens=200):
 48    """Generates a single text response for a prompt from a textsynth server
 49
 50    The server and API key are provided as environment variables:
 51
 52    LANGUAGEMODELS_TS_SERVER is the server such as http://localhost:8080
 53    LANGUAGEMODELS_TS_KEY is the API key
 54    """
 55    apikey = os.environ.get("LANGUAGEMODELS_TS_KEY") or ""
 56    server = os.environ.get("LANGUAGEMODELS_TS_SERVER") or "https://api.textsynth.com"
 57
 58    response = requests.post(
 59        f"{server}/v1/engines/{engine}/completions",
 60        headers={"Authorization": f"Bearer {apikey}"},
 61        json={"prompt": prompt, "max_tokens": max_tokens},
 62    )
 63    resp = response.json()
 64    if "text" in resp:
 65        return resp["text"]
 66    else:
 67        raise InferenceException(f"TextSynth error: {resp}")
 68
 69
 70def generate_oa(engine, prompt, max_tokens=200, temperature=0):
 71    """Generates a single text response for a prompt using OpenAI
 72
 73    The server and API key are provided as environment variables:
 74
 75    LANGUAGEMODELS_OA_KEY is the API key
 76    """
 77    apikey = os.environ.get("LANGUAGEMODELS_OA_KEY")
 78
 79    response = requests.post(
 80        "https://api.openai.com/v1/completions",
 81        headers={
 82            "Authorization": f"Bearer {apikey}",
 83            "Content-Type": "application/json",
 84        },
 85        json={
 86            "model": engine,
 87            "prompt": prompt,
 88            "max_tokens": max_tokens,
 89            "temperature": temperature,
 90        },
 91    )
 92    resp = response.json()
 93
 94    try:
 95        return resp["choices"][0]["text"]
 96    except KeyError:
 97        raise InferenceException(f"OpenAI error: {resp}")
 98
 99
def chat_oa(engine, prompt, max_tokens=200, temperature=0):
    """Generates a single text response for a prompt using OpenAI

    The server and API key are provided as environment variables:

    LANGUAGEMODELS_OA_KEY is the API key

    :param engine: OpenAI chat model name
    :param prompt: Prompt text sent as a single user message
    :param max_tokens: Maximum number of tokens to generate
    :param temperature: Sampling temperature (0 selects greedy decoding)
    :return: Generated assistant message content
    :raises InferenceException: if the response contains no choices
    """
    apikey = os.environ.get("LANGUAGEMODELS_OA_KEY")

    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {apikey}",
            "Content-Type": "application/json",
        },
        json={
            "model": engine,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "temperature": temperature,
        },
        # Without an explicit timeout, a stalled server hangs this call forever.
        timeout=60,
    )
    resp = response.json()

    try:
        return resp["choices"][0]["message"]["content"]
    except KeyError:
        raise InferenceException(f"OpenAI error: {resp}")
128
129
def stream_results(results, tokenizer):
    """Map a token iterator to a substring iterator

    Accumulates token ids from `results`, re-decodes the full sequence after
    each token, and yields only the newly produced suffix of the text.
    """
    collected = []
    emitted = 0

    for item in results:
        collected.append(item.token_id)
        decoded = tokenizer.decode(collected)
        # Yield just the text added by this token.
        yield decoded[emitted:]
        emitted = len(decoded)
140
141
def echo_results(results, tokenizer):
    """Output results to stderr as they are collected

    Accumulates token ids from `results`, writing each newly decoded text
    suffix to stderr immediately, and returns the full list of token ids.
    """
    collected = []
    written = 0
    err = sys.stderr

    for item in results:
        collected.append(item.token_id)
        decoded = tokenizer.decode(collected)
        # Write only the text added by this token, flushing for live output.
        err.write(decoded[written:])
        err.flush()
        written = len(decoded)

    err.write("\n\n")
    err.flush()
    return collected
157
158
def generate(
    instructions: List[str],
    max_tokens: int = 200,
    temperature: float = 0.1,
    topk: int = 1,
    repetition_penalty: float = 0.0,
    prefix: str = "",
    suppress: List[str] = [],
    model: str = "instruct",
    stream: bool = False,
):
    """Generates completions for a prompt

    This may use a local model, or it may make an API call to an external
    model if API keys are available.

    :param instructions: List of instruction prompts to complete
    :param max_tokens: Maximum tokens to decode per prompt
    :param temperature: Sampling temperature
    :param topk: Sampling top-k
    :param repetition_penalty: 0.0 is a sentinel meaning "use the model's
        configured default"
    :param prefix: Text forced at the start of the output (seq2seq path only)
    :param suppress: Strings whose token sequences are suppressed in output
    :param model: Name of the local model to load
    :param stream: If True, return an iterator of text fragments instead of
        a list of completed strings

    >>> generate(["What is the capital of France?"])
    ['...Paris...']

    >>> list(generate(["What is the capital of France?"], stream=True))
    ['...Paris...']
    """
    # Remote backends take precedence over the local model when configured.
    # NOTE(review): these paths pass `instructions` (a list) straight through
    # as the prompt and return a single stripped string rather than a list —
    # confirm callers handle both shapes.
    if os.environ.get("LANGUAGEMODELS_TS_KEY") or os.environ.get(
        "LANGUAGEMODELS_TS_SERVER"
    ):
        return generate_ts("flan_t5_xxl_q4", instructions, max_tokens).strip()

    if os.environ.get("LANGUAGEMODELS_OA_KEY"):
        return chat_oa("gpt-3.5-turbo", instructions, max_tokens).strip()

    # `model` is rebound here from the model *name* to the loaded model object.
    tokenizer, model = get_model(model)

    start_time = perf_counter()

    # Convert suppressed strings into token sequences for the decoder. The
    # mutable default [] above is benign only because this rebinding never
    # mutates the default list in place.
    suppress = [tokenizer.encode(s, add_special_tokens=False).tokens for s in suppress]

    # NOTE(review): prompt format and usage stats are always read for
    # "instruct" even when a different `model` name was requested — confirm
    # this is intended.
    model_info = get_model_info("instruct")

    fmt = model_info.get("prompt_fmt", "{instruction}")

    # 0.0 acts as "not set"; fall back to the model's configured penalty.
    if repetition_penalty == 0.0:
        repetition_penalty = model_info.get("repetition_penalty", 1.3)

    prompts = [fmt.replace("{instruction}", inst) for inst in instructions]
    truncated_prompts = [truncate_prompt(p) for p in prompts]

    prompts_tok = [tokenizer.encode(p).tokens for p in truncated_prompts]

    outputs_ids = []
    # Models exposing translate_batch are seq2seq (encoder-decoder); otherwise
    # fall through to the causal generator API below.
    if hasattr(model, "translate_batch"):
        prefix = tokenizer.encode(prefix, add_special_tokens=False).tokens
        # Token-by-token decoding path, used for streaming or echo; note it
        # only consumes the first prompt, so echo is gated to a single prompt.
        if stream or (config["echo"] and len(prompts_tok) == 1):
            results = model.generate_tokens(
                prompts_tok[0],
                target_prefix=prefix,
                repetition_penalty=repetition_penalty,
                max_decoding_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
            )

            if stream:
                return stream_results(results, tokenizer)
            else:
                outputs_ids = [echo_results(results, tokenizer)]
        else:
            # Batched greedy decoding (beam_size=1) over all prompts at once.
            results = model.translate_batch(
                prompts_tok,
                target_prefix=[prefix] * len(prompts),
                repetition_penalty=repetition_penalty,
                max_decoding_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
                beam_size=1,
            )
            # translate_batch yields token strings; map them back to ids so
            # both branches produce outputs_ids uniformly.
            outputs_tokens = [r.hypotheses[0] for r in results]
            for output in outputs_tokens:
                outputs_ids.append([tokenizer.token_to_id(t) for t in output])
    else:
        # Causal (decoder-only) model path; note `prefix` is not applied here.
        if stream or (config["echo"] and len(prompts_tok) == 1):
            results = model.generate_tokens(
                prompts_tok,
                repetition_penalty=repetition_penalty,
                max_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
            )

            if stream:
                return stream_results(results, tokenizer)
            else:
                outputs_ids = [echo_results(results, tokenizer)]
        else:
            results = model.generate_batch(
                prompts_tok,
                repetition_penalty=repetition_penalty,
                max_length=max_tokens,
                sampling_temperature=temperature,
                sampling_topk=topk,
                suppress_sequences=suppress,
                beam_size=1,
                include_prompt_in_result=False,
            )
            outputs_ids = [r.sequences_ids[0] for r in results]

    # Accumulate usage statistics on the shared model_info dict.
    model_info["requests"] = model_info.get("requests", 0) + len(prompts)

    in_toks = sum(len(p) for p in prompts_tok)
    model_info["input_tokens"] = model_info.get("input_tokens", 0) + in_toks

    out_toks = sum(len(o) for o in outputs_ids)
    model_info["output_tokens"] = model_info.get("output_tokens", 0) + out_toks

    elapsed_time = perf_counter() - start_time
    model_info["runtime"] = model_info.get("runtime", 0) + elapsed_time

    # Decode ids back to text, stripping any leading whitespace artifacts.
    return [tokenizer.decode(i, skip_special_tokens=True).lstrip() for i in outputs_ids]
279
280
def rank_instruct(inputs, targets):
    """Sorts a list of targets by their probabilities

    >>> rank_instruct(["Classify positive or negative: I love python. Classification:"],
    ... ['positive', 'negative'])
    [['positive', 'negative']]

    >>> rank_instruct(["Classify fantasy or documentary: "
    ... "The wizard raised their wand. Classification:"],
    ... ['fantasy', 'documentary'])
    [['fantasy', 'documentary']]

    >>> rank_instruct(["Say six", "Say seven"], ["six", "seven"])
    [['six', 'seven'], ['seven', 'six']]
    """
    tokenizer, model = get_model("instruct")

    # Apply the model's prompt template and length bound to every input.
    model_info = get_model_info("instruct")
    fmt = model_info.get("prompt_fmt", "{instruction}")
    prompts = [truncate_prompt(fmt.replace("{instruction}", inst)) for inst in inputs]

    # One copy of the tokenized target set per prompt.
    targ_tok = [
        tokenizer.encode(t, add_special_tokens=False).tokens for t in targets
    ] * len(prompts)

    # Each prompt is repeated once per target so rows pair up with targ_tok.
    in_tok = []
    for text in prompts:
        in_tok.extend([tokenizer.encode(text).tokens] * len(targets))

    if "Generator" in str(type(model)):
        # Decoder-only models score the concatenated prompt+target sequence.
        scores = model.score_batch([src + tgt for src, tgt in zip(in_tok, targ_tok)])
    else:
        # Seq2seq models score the target conditioned on the source.
        scores = model.score_batch(in_tok, target=targ_tok)

    # Regroup flat scores by prompt and sort targets by total log probability.
    ranked_per_input = []
    group = len(targets)
    for start in range(0, len(prompts) * group, group):
        totals = [sum(s.log_probs) for s in scores[start : start + group]]
        ordered = sorted(zip(targets, totals), key=lambda pair: -pair[1])
        ranked_per_input.append([name for name, _ in ordered])

    return ranked_per_input
323
324
def parse_chat(prompt):
    """Converts a chat prompt using special tokens to a plain-text prompt

    This is useful for prompting generic models that have not been fine-tuned
    for chat using specialized tokens.

    >>> parse_chat('User: What time is it?')
    Traceback (most recent call last):
        ....
    inference.InferenceException: Chat prompt must end with 'Assistant:'

    >>> parse_chat('''User: What time is it?
    ...
    ...               Assistant:''')
    [{'role': 'user', 'content': 'What time is it?'}]

    >>> parse_chat('''
    ...              A helpful assistant
    ...
    ...              User: What time is it?
    ...
    ...              Assistant:
    ...              ''')
    [{'role': 'system', 'content': 'A helpful assistant'},
     {'role': 'user', 'content': 'What time is it?'}]

    >>> parse_chat('''
    ...              A helpful assistant
    ...
    ...              User: What time is it?
    ...
    ...              Assistant: The time is
    ...              ''')
    Traceback (most recent call last):
        ....
    inference.InferenceException: Final assistant message must be blank

    >>> parse_chat('''
    ...              A helpful assistant
    ...
    ...              User: First para
    ...
    ...              Second para
    ...
    ...              Assistant:
    ...              ''')
    [{'role': 'system', 'content': 'A helpful assistant'},
     {'role': 'user', 'content': 'First para\\n\\nSecond para'}]

    >>> parse_chat('''
    ...              A helpful assistant
    ...
    ...              User: What time is it?
    ...
    ...              InvalidRole: Nothing
    ...
    ...              Assistant:
    ...              ''')
    Traceback (most recent call last):
        ....
    inference.InferenceException: Invalid chat role: invalidrole
    """

    # Text before any "Role:" tag is treated as an implicit system message.
    if not re.match(r"^\s*\w+:", prompt):
        prompt = "System: " + prompt

    prompt = "\n\n" + prompt

    # Split into alternating role / content chunks, dropping blanks.
    chunks = re.split(r"[\r\n]\s*(\w+):", prompt, flags=re.M)
    chunks = [m.strip() for m in chunks if m.strip()]

    roles = chunks[0::2]
    bodies = chunks[1::2]

    messages = []
    for idx, raw_role in enumerate(roles):
        if idx < len(bodies):
            # Collapse surrounding whitespace to clean paragraph breaks.
            body = re.sub(r"\s*\n\n\s*", "\n\n", bodies[idx])
        else:
            # A trailing role tag ("Assistant:") has no content chunk.
            body = ""
        messages.append({"role": raw_role.lower(), "content": body})

    for message in messages:
        if message["role"] not in ("system", "user", "assistant"):
            raise InferenceException(f"Invalid chat role: {message['role']}")

    if messages[-1]["role"] != "assistant":
        raise InferenceException("Chat prompt must end with 'Assistant:'")

    if messages[-1]["content"] != "":
        raise InferenceException("Final assistant message must be blank")

    # The trailing blank assistant message is a marker, not real content.
    return messages[:-1]
class InferenceException(builtins.Exception):
13class InferenceException(Exception):
14    pass

Common base class for all non-exit exceptions.

def truncate_prompt(prompt):
17def truncate_prompt(prompt):
18    """Truncates a prompt to the maximum length allowed by the config"""
19    max_prompt_length = config["max_prompt_length"]
20    if len(prompt) > max_prompt_length:
21        print(
22            f"Warning: Prompt truncated from {len(prompt)} to "
23            f"{max_prompt_length} characters to avoid OOM."
24        )
25        return prompt[:max_prompt_length]
26    return prompt

Truncates a prompt to the maximum length allowed by the config

def list_tokens(prompt):
29def list_tokens(prompt):
30    """Generates a list of tokens for a supplied prompt
31
32    >>> list_tokens("Hello, world!") # doctest: +SKIP
33    [('▁Hello', 8774), (',', 6), ('▁world', 296), ('!', 55)]
34
35    >>> list_tokens("Hello, world!")
36    [('...Hello', ...), ... ('...world', ...), ...]
37    """
38    prompt = truncate_prompt(prompt)
39    tokenizer, _ = get_model("instruct")
40
41    output = tokenizer.encode(prompt, add_special_tokens=False)
42    tokens = output.tokens
43    ids = output.ids
44
45    return list(zip(tokens, ids))

Generates a list of tokens for a supplied prompt

>>> list_tokens("Hello, world!") # doctest: +SKIP
[('▁Hello', 8774), (',', 6), ('▁world', 296), ('!', 55)]
>>> list_tokens("Hello, world!")
[('...Hello', ...), ... ('...world', ...), ...]
def generate_ts(engine, prompt, max_tokens=200):
48def generate_ts(engine, prompt, max_tokens=200):
49    """Generates a single text response for a prompt from a textsynth server
50
51    The server and API key are provided as environment variables:
52
53    LANGUAGEMODELS_TS_SERVER is the server such as http://localhost:8080
54    LANGUAGEMODELS_TS_KEY is the API key
55    """
56    apikey = os.environ.get("LANGUAGEMODELS_TS_KEY") or ""
57    server = os.environ.get("LANGUAGEMODELS_TS_SERVER") or "https://api.textsynth.com"
58
59    response = requests.post(
60        f"{server}/v1/engines/{engine}/completions",
61        headers={"Authorization": f"Bearer {apikey}"},
62        json={"prompt": prompt, "max_tokens": max_tokens},
63    )
64    resp = response.json()
65    if "text" in resp:
66        return resp["text"]
67    else:
68        raise InferenceException(f"TextSynth error: {resp}")

Generates a single text response for a prompt from a textsynth server

The server and API key are provided as environment variables:

LANGUAGEMODELS_TS_SERVER is the server such as http://localhost:8080 LANGUAGEMODELS_TS_KEY is the API key

def generate_oa(engine, prompt, max_tokens=200, temperature=0):
71def generate_oa(engine, prompt, max_tokens=200, temperature=0):
72    """Generates a single text response for a prompt using OpenAI
73
74    The server and API key are provided as environment variables:
75
76    LANGUAGEMODELS_OA_KEY is the API key
77    """
78    apikey = os.environ.get("LANGUAGEMODELS_OA_KEY")
79
80    response = requests.post(
81        "https://api.openai.com/v1/completions",
82        headers={
83            "Authorization": f"Bearer {apikey}",
84            "Content-Type": "application/json",
85        },
86        json={
87            "model": engine,
88            "prompt": prompt,
89            "max_tokens": max_tokens,
90            "temperature": temperature,
91        },
92    )
93    resp = response.json()
94
95    try:
96        return resp["choices"][0]["text"]
97    except KeyError:
98        raise InferenceException(f"OpenAI error: {resp}")

Generates a single text response for a prompt using OpenAI

The server and API key are provided as environment variables:

LANGUAGEMODELS_OA_KEY is the API key

def chat_oa(engine, prompt, max_tokens=200, temperature=0):
101def chat_oa(engine, prompt, max_tokens=200, temperature=0):
102    """Generates a single text response for a prompt using OpenAI
103
104    The server and API key are provided as environment variables:
105
106    LANGUAGEMODELS_OA_KEY is the API key
107    """
108    apikey = os.environ.get("LANGUAGEMODELS_OA_KEY")
109
110    response = requests.post(
111        "https://api.openai.com/v1/chat/completions",
112        headers={
113            "Authorization": f"Bearer {apikey}",
114            "Content-Type": "application/json",
115        },
116        json={
117            "model": engine,
118            "messages": [{"role": "user", "content": prompt}],
119            "max_tokens": max_tokens,
120            "temperature": temperature,
121        },
122    )
123    resp = response.json()
124
125    try:
126        return resp["choices"][0]["message"]["content"]
127    except KeyError:
128        raise InferenceException(f"OpenAI error: {resp}")

Generates a single text response for a prompt using OpenAI

The server and API key are provided as environment variables:

LANGUAGEMODELS_OA_KEY is the API key

def stream_results(results, tokenizer):
131def stream_results(results, tokenizer):
132    """Map a token iterator to a substring iterator"""
133    tokens = []
134    last_len = 0
135
136    for result in results:
137        tokens.append(result.token_id)
138        text = tokenizer.decode(tokens)
139        yield text[last_len:]
140        last_len = len(text)

Map a token iterator to a substring iterator

def echo_results(results, tokenizer):
143def echo_results(results, tokenizer):
144    """Output results to stderr as they are collected"""
145    tokens = []
146    last_len = 0
147
148    for result in results:
149        tokens.append(result.token_id)
150        text = tokenizer.decode(tokens)
151        sys.stderr.write(text[last_len:])
152        sys.stderr.flush()
153        last_len = len(text)
154
155    sys.stderr.write("\n\n")
156    sys.stderr.flush()
157    return tokens

Output results to stderr as they are collected

def generate( instructions: List[str], max_tokens: int = 200, temperature: float = 0.1, topk: int = 1, repetition_penalty: float = 0.0, prefix: str = '', suppress: List[str] = [], model: str = 'instruct', stream: bool = False):
160def generate(
161    instructions: List[str],
162    max_tokens: int = 200,
163    temperature: float = 0.1,
164    topk: int = 1,
165    repetition_penalty: float = 0.0,
166    prefix: str = "",
167    suppress: List[str] = [],
168    model: str = "instruct",
169    stream: bool = False,
170):
171    """Generates completions for a prompt
172
173    This may use a local model, or it may make an API call to an external
174    model if API keys are available.
175
176    >>> generate(["What is the capital of France?"])
177    ['...Paris...']
178
179    >>> list(generate(["What is the capital of France?"], stream=True))
180    ['...Paris...']
181    """
182    if os.environ.get("LANGUAGEMODELS_TS_KEY") or os.environ.get(
183        "LANGUAGEMODELS_TS_SERVER"
184    ):
185        return generate_ts("flan_t5_xxl_q4", instructions, max_tokens).strip()
186
187    if os.environ.get("LANGUAGEMODELS_OA_KEY"):
188        return chat_oa("gpt-3.5-turbo", instructions, max_tokens).strip()
189
190    tokenizer, model = get_model(model)
191
192    start_time = perf_counter()
193
194    suppress = [tokenizer.encode(s, add_special_tokens=False).tokens for s in suppress]
195
196    model_info = get_model_info("instruct")
197
198    fmt = model_info.get("prompt_fmt", "{instruction}")
199
200    if repetition_penalty == 0.0:
201        repetition_penalty = model_info.get("repetition_penalty", 1.3)
202
203    prompts = [fmt.replace("{instruction}", inst) for inst in instructions]
204    truncated_prompts = [truncate_prompt(p) for p in prompts]
205
206    prompts_tok = [tokenizer.encode(p).tokens for p in truncated_prompts]
207
208    outputs_ids = []
209    if hasattr(model, "translate_batch"):
210        prefix = tokenizer.encode(prefix, add_special_tokens=False).tokens
211        if stream or (config["echo"] and len(prompts_tok) == 1):
212            results = model.generate_tokens(
213                prompts_tok[0],
214                target_prefix=prefix,
215                repetition_penalty=repetition_penalty,
216                max_decoding_length=max_tokens,
217                sampling_temperature=temperature,
218                sampling_topk=topk,
219                suppress_sequences=suppress,
220            )
221
222            if stream:
223                return stream_results(results, tokenizer)
224            else:
225                outputs_ids = [echo_results(results, tokenizer)]
226        else:
227            results = model.translate_batch(
228                prompts_tok,
229                target_prefix=[prefix] * len(prompts),
230                repetition_penalty=repetition_penalty,
231                max_decoding_length=max_tokens,
232                sampling_temperature=temperature,
233                sampling_topk=topk,
234                suppress_sequences=suppress,
235                beam_size=1,
236            )
237            outputs_tokens = [r.hypotheses[0] for r in results]
238            for output in outputs_tokens:
239                outputs_ids.append([tokenizer.token_to_id(t) for t in output])
240    else:
241        if stream or (config["echo"] and len(prompts_tok) == 1):
242            results = model.generate_tokens(
243                prompts_tok,
244                repetition_penalty=repetition_penalty,
245                max_length=max_tokens,
246                sampling_temperature=temperature,
247                sampling_topk=topk,
248                suppress_sequences=suppress,
249            )
250
251            if stream:
252                return stream_results(results, tokenizer)
253            else:
254                outputs_ids = [echo_results(results, tokenizer)]
255        else:
256            results = model.generate_batch(
257                prompts_tok,
258                repetition_penalty=repetition_penalty,
259                max_length=max_tokens,
260                sampling_temperature=temperature,
261                sampling_topk=topk,
262                suppress_sequences=suppress,
263                beam_size=1,
264                include_prompt_in_result=False,
265            )
266            outputs_ids = [r.sequences_ids[0] for r in results]
267
268    model_info["requests"] = model_info.get("requests", 0) + len(prompts)
269
270    in_toks = sum(len(p) for p in prompts_tok)
271    model_info["input_tokens"] = model_info.get("input_tokens", 0) + in_toks
272
273    out_toks = sum(len(o) for o in outputs_ids)
274    model_info["output_tokens"] = model_info.get("output_tokens", 0) + out_toks
275
276    elapsed_time = perf_counter() - start_time
277    model_info["runtime"] = model_info.get("runtime", 0) + elapsed_time
278
279    return [tokenizer.decode(i, skip_special_tokens=True).lstrip() for i in outputs_ids]

Generates completions for a prompt

This may use a local model, or it may make an API call to an external model if API keys are available.

>>> generate(["What is the capital of France?"])
['...Paris...']
>>> list(generate(["What is the capital of France?"], stream=True))
['...Paris...']
def rank_instruct(inputs, targets):
282def rank_instruct(inputs, targets):
283    """Sorts a list of targets by their probabilities
284
285    >>> rank_instruct(["Classify positive or negative: I love python. Classification:"],
286    ... ['positive', 'negative'])
287    [['positive', 'negative']]
288
289    >>> rank_instruct(["Classify fantasy or documentary: "
290    ... "The wizard raised their wand. Classification:"],
291    ... ['fantasy', 'documentary'])
292    [['fantasy', 'documentary']]
293
294    >>> rank_instruct(["Say six", "Say seven"], ["six", "seven"])
295    [['six', 'seven'], ['seven', 'six']]
296    """
297    tokenizer, model = get_model("instruct")
298
299    model_info = get_model_info("instruct")
300    fmt = model_info.get("prompt_fmt", "{instruction}")
301    inputs = [fmt.replace("{instruction}", inst) for inst in inputs]
302    inputs = [truncate_prompt(i) for i in inputs]
303
304    targ_tok = [tokenizer.encode(t, add_special_tokens=False).tokens for t in targets]
305    targ_tok *= len(inputs)
306
307    in_tok = []
308    for input in inputs:
309        toks = [tokenizer.encode(input).tokens]
310        in_tok += toks * len(targets)
311
312    if "Generator" in str(type(model)):
313        scores = model.score_batch([i + t for i, t in zip(in_tok, targ_tok)])
314    else:
315        scores = model.score_batch(in_tok, target=targ_tok)
316
317    ret = []
318    for i in range(0, len(inputs) * len(targets), len(targets)):
319        logprobs = [sum(r.log_probs) for r in scores[i : i + len(targets)]]
320        results = sorted(zip(targets, logprobs), key=lambda r: -r[1])
321        ret.append([r[0] for r in results])
322
323    return ret

Sorts a list of targets by their probabilities

>>> rank_instruct(["Classify positive or negative: I love python. Classification:"],
... ['positive', 'negative'])
[['positive', 'negative']]
>>> rank_instruct(["Classify fantasy or documentary: "
... "The wizard raised their wand. Classification:"],
... ['fantasy', 'documentary'])
[['fantasy', 'documentary']]
>>> rank_instruct(["Say six", "Say seven"], ["six", "seven"])
[['six', 'seven'], ['seven', 'six']]
def parse_chat(prompt):
326def parse_chat(prompt):
327    """Converts a chat prompt using special tokens to a plain-text prompt
328
329    This is useful for prompting generic models that have not been fine-tuned
330    for chat using specialized tokens.
331
332    >>> parse_chat('User: What time is it?')
333    Traceback (most recent call last):
334        ....
335    inference.InferenceException: Chat prompt must end with 'Assistant:'
336
337    >>> parse_chat('''User: What time is it?
338    ...
339    ...               Assistant:''')
340    [{'role': 'user', 'content': 'What time is it?'}]
341
342    >>> parse_chat('''
343    ...              A helpful assistant
344    ...
345    ...              User: What time is it?
346    ...
347    ...              Assistant:
348    ...              ''')
349    [{'role': 'system', 'content': 'A helpful assistant'},
350     {'role': 'user', 'content': 'What time is it?'}]
351
352    >>> parse_chat('''
353    ...              A helpful assistant
354    ...
355    ...              User: What time is it?
356    ...
357    ...              Assistant: The time is
358    ...              ''')
359    Traceback (most recent call last):
360        ....
361    inference.InferenceException: Final assistant message must be blank
362
363    >>> parse_chat('''
364    ...              A helpful assistant
365    ...
366    ...              User: First para
367    ...
368    ...              Second para
369    ...
370    ...              Assistant:
371    ...              ''')
372    [{'role': 'system', 'content': 'A helpful assistant'},
373     {'role': 'user', 'content': 'First para\\n\\nSecond para'}]
374
375    >>> parse_chat('''
376    ...              A helpful assistant
377    ...
378    ...              User: What time is it?
379    ...
380    ...              InvalidRole: Nothing
381    ...
382    ...              Assistant:
383    ...              ''')
384    Traceback (most recent call last):
385        ....
386    inference.InferenceException: Invalid chat role: invalidrole
387    """
388
389    if not re.match(r"^\s*\w+:", prompt):
390        prompt = "System: " + prompt
391
392    prompt = "\n\n" + prompt
393
394    chunks = re.split(r"[\r\n]\s*(\w+):", prompt, flags=re.M)
395    chunks = [m.strip() for m in chunks if m.strip()]
396
397    messages = []
398
399    for i in range(0, len(chunks), 2):
400        role = chunks[i].lower()
401
402        try:
403            content = chunks[i + 1]
404            content = re.sub(r"\s*\n\n\s*", "\n\n", content)
405        except IndexError:
406            content = ""
407        messages.append({"role": role, "content": content})
408
409    for message in messages:
410        if message["role"] not in ["system", "user", "assistant"]:
411            raise InferenceException(f"Invalid chat role: {message['role']}")
412
413    if messages[-1]["role"] != "assistant":
414        raise InferenceException("Chat prompt must end with 'Assistant:'")
415
416    if messages[-1]["content"] != "":
417        raise InferenceException("Final assistant message must be blank")
418
419    return messages[:-1]

Converts a chat prompt using special tokens to a plain-text prompt

This is useful for prompting generic models that have not been fine-tuned for chat using specialized tokens.

>>> parse_chat('User: What time is it?')
Traceback (most recent call last):
    ....
inference.InferenceException: Chat prompt must end with 'Assistant:'
>>> parse_chat('''User: What time is it?
...
...               Assistant:''')
[{'role': 'user', 'content': 'What time is it?'}]
>>> parse_chat('''
...              A helpful assistant
...
...              User: What time is it?
...
...              Assistant:
...              ''')
[{'role': 'system', 'content': 'A helpful assistant'},
 {'role': 'user', 'content': 'What time is it?'}]
>>> parse_chat('''
...              A helpful assistant
...
...              User: What time is it?
...
...              Assistant: The time is
...              ''')
Traceback (most recent call last):
    ....
inference.InferenceException: Final assistant message must be blank
>>> parse_chat('''
...              A helpful assistant
...
...              User: First para
...
...              Second para
...
...              Assistant:
...              ''')
[{'role': 'system', 'content': 'A helpful assistant'},
 {'role': 'user', 'content': 'First para\n\nSecond para'}]
>>> parse_chat('''
...              A helpful assistant
...
...              User: What time is it?
...
...              InvalidRole: Nothing
...
...              Assistant:
...              ''')
Traceback (most recent call last):
    ....
inference.InferenceException: Invalid chat role: invalidrole