languagemodels.embeddings

import numpy as np
from time import perf_counter

from languagemodels.models import get_model, get_model_info


def embed(docs):
    """Compute embeddings for a batch of documents

    >>> embed(["I love Python!"])[0].shape
    (384,)

    >>> embed(["I love Python!"])[0][-3:]
    array([0.1..., 0.1..., 0.0...], dtype=float32)

    >>> float(np.linalg.norm(embed(["I love Python!"])[0]))
    1.0

    Embeddings are computed by running the first 512 tokens of each doc
    through a forward pass of the embedding model. The last hidden state
    of the model is mean pooled to produce a single vector.

    Documents are processed in batches. The batch size is fixed at 64,
    as this size was found to maximize throughput on a number of test
    systems while limiting memory usage.
    """

    tokenizer, model = get_model("embedding")
    model_info = get_model_info("embedding")

    start_time = perf_counter()

    # Truncate each doc to 8192 characters, then keep the first 512 tokens
    tokens = [tokenizer.encode(doc[:8192]).ids[:512] for doc in docs]

    def mean_pool(last_hidden_state):
        embedding = np.mean(last_hidden_state, axis=0)
        embedding = embedding / np.linalg.norm(embedding)
        return embedding

    bs = 64
    embeddings = []
    for i in range(0, len(docs), bs):
        outputs = model.forward_batch(tokens[i : i + bs])
        embeddings += [mean_pool(lhs) for lhs in np.array(outputs.last_hidden_state)]

    model_info["requests"] = model_info.get("requests", 0) + len(tokens)

    in_toks = sum(len(d) for d in tokens)
    model_info["input_tokens"] = model_info.get("input_tokens", 0) + in_toks

    runtime = perf_counter() - start_time
    model_info["runtime"] = model_info.get("runtime", 0) + runtime

    return embeddings
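

# Illustrative sketch (not part of the library API): because `embed`
# returns unit-length vectors, the dot product of two embeddings is
# their cosine similarity. Assumes the default embedding model is
# installed locally.
def _example_similarity():
    a, b = embed(["I love Python!", "Python is a programming language."])
    return float(np.dot(a, b))  # approaches 1.0 for similar meanings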


def search(query, docs, count=16):
    """Return the top `count` `docs` sorted by match against `query`

    :param query: Input to match in search
    :param docs: List of docs to search against
    :param count: Number of documents to return
    :return: List of (doc_num, score) tuples sorted by score descending
    """

    prefix = get_model_info("embedding").get("query_prefix", "")

    query_embedding = embed([f"{prefix}{query}"])[0]

    # Embeddings are unit vectors, so a dot product is cosine similarity
    scores = np.dot([d.embedding for d in docs], query_embedding)

    return [(i, scores[i]) for i in reversed(np.argsort(scores)[-count:])]
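

# Illustrative sketch (not part of the library API): `search` accepts any
# objects exposing an `.embedding` attribute, such as the `Document`
# class defined later in this module.
def _example_search():
    docs = [Document("Paris is in France."), Document("Mars is a planet.")]
    # Returns (doc_num, score) tuples; actual scores vary by model
    return search("Where is Paris?", docs, count=1)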


def get_token_ids(doc):
    """Return list of token ids for a document

    Note that the tokenizer used here is from the generative model.

    This is used for token counting for the context, not for tokenization
    before embedding.
    """

    generative_tokenizer, _ = get_model("instruct", tokenizer_only=True)

    # We need to disable and re-enable truncation here.
    # This allows us to tokenize very large documents.
    # We won't be feeding the tokens themselves to a model, so this
    # shouldn't cause any problems.
    trunk = generative_tokenizer.truncation
    if trunk:
        generative_tokenizer.no_truncation()
    ids = generative_tokenizer.encode(doc, add_special_tokens=False).ids
    if trunk:
        generative_tokenizer.enable_truncation(
            trunk["max_length"], stride=trunk["stride"], strategy=trunk["strategy"]
        )

    return ids
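

# Illustrative sketch (not part of the library API): token counts from
# `get_token_ids` are what the retrieval code below uses to budget
# context length.
def _example_token_count():
    # The exact count depends on the instruct model's tokenizer
    return len(get_token_ids("Hello world."))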


def chunk_doc(doc, name="", chunk_size=64, chunk_overlap=8):
    """Break a document into chunks

    :param doc: Document to chunk
    :param name: Optional document name
    :param chunk_size: Length of individual chunks in tokens
    :param chunk_overlap: Number of tokens to overlap when breaking chunks
    :return: List of strings representing the chunks

    The simple chunking approach used here consists of the following:

    1. Attempt to chunk the remainder of the document.
    2. If we can't fit all tokens in chunk_size, backtrack to look for a
    meaningful cut point.
    3. If a cut point is found, use that as the chunk boundary. There will
    be no overlap between this chunk and the next in this case.
    4. If a cut point is not found, use chunk_size as the boundary. There
    will be chunk_overlap overlapping tokens starting the next chunk.
    5. Repeat until the entire document has been split into chunks.

    >>> chunk_doc("")
    []

    >>> chunk_doc(
    ... "It was the best of times, it was the worst of times, it was the age "
    ... "of wisdom, it was the age of foolishness, it was the epoch of belief, "
    ... "it was the epoch of incredulity, it was the season of Light, it was "
    ... "the season of Darkness, it was the spring of hope, it was the winter "
    ... "of despair, we had everything before us, we had nothing before us, we "
    ... "were all going direct to Heaven, we were all going direct the other "
    ... "way—in short, the period was so far like the present period, that "
    ... "some of its noisiest authorities insisted on its being received, for "
    ... "good or for evil, in the superlative degree of comparison only.")
    ['It was the best of times...']

    >>> chunk_doc(
    ... "One morning, when Gregor Samsa woke from troubled dreams, he found "
    ... "himself transformed in his bed into a horrible vermin. He lay on his "
    ... "armour-like back, and if he lifted his head a little he could see "
    ... "his brown belly, slightly domed and divided by arches into stiff "
    ... "sections. The bedding was hardly able to cover it and seemed ready "
    ... "to slide off any moment. His many legs, pitifully thin compared with "
    ... "the size of the rest of him, waved about helplessly as he looked.")
    ['One morning, ...']

    >>> chunk_doc("Hello")
    ['Hello']

    >>> chunk_doc("Hello " * 65)
    ['Hello Hello...', 'Hello...']

    >>> chunk_doc("Hello world. " * 24)[0]
    'Hello world. ...Hello world.'

    >>> len(chunk_doc("Hello world. " * 20))
    1

    >>> len(chunk_doc("Hello world. " * 24))
    2

    # Check to make sure sentences aren't broken on decimal points
    >>> chunk_doc(('z. ' + ' 37.468 ' * 5) * 3)[0]
    'z. 37.468 ...z.'
    """
    generative_tokenizer, _ = get_model("instruct", tokenizer_only=True)

    tokens = get_token_ids(doc)

    separator_tokens = [".", "!", "?", ").", "\n\n", "\n", '."']

    separators = [get_token_ids(t)[-1] for t in separator_tokens]

    name_tokens = []

    label = f"From {name} document:" if name else ""

    if name:
        name_tokens = get_token_ids(label)

    i = 0
    chunks = []
    chunk = name_tokens.copy()
    while i < len(tokens):
        token = tokens[i]
        chunk.append(token)
        i += 1

        # Save the last chunk if we're done
        if i == len(tokens):
            chunks.append(generative_tokenizer.decode(chunk))
            break

        if len(chunk) == chunk_size:
            # Backtrack to find a reasonable cut point
            for j in range(1, chunk_size // 2):
                if chunk[chunk_size - j] in separators:
                    ctx = generative_tokenizer.decode(
                        chunk[chunk_size - j : chunk_size - j + 2]
                    )
                    if " " in ctx or "\n" in ctx:
                        # Found a good separator
                        text = generative_tokenizer.decode(chunk[: chunk_size - j + 1])
                        chunks.append(text)
                        chunk = name_tokens + chunk[chunk_size - j + 1 :]
                        break
            else:
                # No semantically meaningful cut point found
                # Default to a hard cut
                text = generative_tokenizer.decode(chunk)
                chunks.append(text)
                # Share some overlap with next chunk
                overlap = max(
                    chunk_overlap, chunk_size - len(name_tokens) - (len(tokens) - i)
                )
                chunk = name_tokens + chunk[-overlap:]

    return chunks
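

# Illustrative sketch (not part of the library API): passing a name
# prepends a label to every chunk, and hard cuts carry `chunk_overlap`
# tokens into the next chunk.
def _example_chunking():
    chunks = chunk_doc("Python is popular. " * 30, name="Wikipedia")
    return chunks  # each chunk begins with 'From Wikipedia document:'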


class Document:
    """
    A document used for semantic search

    Documents have content and an embedding that is used to match the content
    against other semantically similar documents.
    """

    def __init__(self, content, name="", embedding=None):
        self.content = content
        self.embedding = embedding if embedding is not None else embed([content])[0]
        self.name = name


class RetrievalContext:
    """
    Provides a context for document retrieval

    Documents are embedded and cached for later search.

    Example usage:

    >>> rc = RetrievalContext()
    >>> rc.store("Paris is in France.")
    >>> rc.store("The sky is blue.")
    >>> rc.store("Mars is a planet.")
    >>> rc.get_match("Paris is in France.")
    'Paris is in France.'

    >>> rc.get_match("Where is Paris?")
    'Paris is in France.'

    >>> rc.clear()
    >>> rc.get_match("Where is Paris?")

    >>> rc.clear()
    >>> rc.store(' '.join(['Python'] * 4096))
    >>> len(rc.chunks)
    73

    >>> rc.clear()
    >>> rc.store(' '.join(['Python'] * 232))
    >>> len(rc.chunks)
    4

    >>> rc.get_context("What is Python?")
    'Python Python Python...'

    >>> [len(c.content.split()) for c in rc.chunks]
    [64, 64, 64, 64]

    >>> len(rc.get_context("What is Python?").split())
    128
    """

    def __init__(self, chunk_size=64, chunk_overlap=8):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.clear()

    def clear(self):
        self.docs = []
        self.chunks = []

    def store(self, doc, name=""):
        """Store a document along with embeddings

        This stores both the full document and its chunks

        >>> rc = RetrievalContext()
        >>> rc.clear()
        >>> rc.store(' '.join(['Python'] * 233))
        >>> len(rc.chunks)
        5

        >>> rc.clear()
        >>> rc.store(' '.join(['Python'] * 232))
        >>> len(rc.chunks)
        4

        >>> rc.clear()
        >>> rc.store('Python')
        >>> len(rc.chunks)
        1

        >>> rc.clear()
        >>> rc.store('It is a language.', 'Python')
        >>> len(rc.chunks)
        1
        >>> [c.content for c in rc.chunks]
        ['From Python document: It is a language.']

        >>> rc = RetrievalContext()
        >>> rc.clear()
        >>> rc.store(' '.join(['details'] * 217), 'Python')
        >>> len(rc.chunks)
        5

        >>> rc.clear()
        >>> rc.store(' '.join(['details'] * 216), 'Python')
        >>> len(rc.chunks)
        4
        >>> [c.content for c in rc.chunks]
        ['From Python document: details details details...']
        """

        # `self.docs` holds Document objects, so compare against content
        # to avoid storing the same document twice
        if doc not in [d.content for d in self.docs]:
            self.docs.append(Document(doc))
            self.store_chunks(doc, name)

    def store_chunks(self, doc, name=""):
        chunks = chunk_doc(doc, name, self.chunk_size, self.chunk_overlap)

        embeddings = embed(chunks)

        for embedding, chunk in zip(embeddings, chunks):
            self.chunks.append(Document(chunk, embedding=embedding))

    def get_context(self, query, max_tokens=128):
        """Get context matching a query

        Context is capped by token length and is retrieved from stored
        document chunks
        """

        if len(self.chunks) == 0:
            return None

        results = search(query, self.chunks)

        chunks = []
        tokens = 0

        # Keep the best-scoring chunks that fit within the token budget
        for chunk_id, score in results:
            chunk = self.chunks[chunk_id].content
            chunk_tokens = len(get_token_ids(chunk))
            if tokens + chunk_tokens <= max_tokens and score > 0.1:
                chunks.append(chunk)
                tokens += chunk_tokens

        context = "\n\n".join(chunks)

        return context

    def get_match(self, query):
        if len(self.docs) == 0:
            return None

        return self.docs[search(query, self.docs)[0][0]].content
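

# Illustrative sketch (not part of the library API) of the full
# retrieval flow: store documents, then fetch a token-budgeted context
# for a query. Assumes the default embedding and instruct models are
# installed locally.
def _example_retrieval():
    rc = RetrievalContext()
    rc.store("Paris is the capital of France.", name="Geography")
    rc.store("Python is a programming language.")
    return rc.get_context("Where is Paris?", max_tokens=64)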