I like simple things and few dependencies. When architecting Jina, I keep its core dependencies as lean as numpy, pyzmq, protobuf, and grpcio while still delivering full-stack functionality. This allows Jina developers and users to quickly bootstrap an end-to-end neural search system without extra packages.

On the vector indexing and querying side, Jina ships a baseline vector indexer called NumpyIndexer, which is purely based on numpy. The implementation is pretty straightforward: it writes vectors directly to disk and queries nearest neighbors via dot product. It is simple, requires no extra dependencies, and its performance is reasonable on small data. As the default vector indexer, we have been using it since day one for quick demos, toy examples, and tutorials.

Recently, a community issue caught my attention. I realized there was room for improvement, even for this baseline indexer. In the end, I managed to speed up indexing by 1.6x and querying by 2.8x while keeping the memory footprint constant (i.e., invariant to the size of the indexed data). This blog post summarizes the tricks I used.


The Scalability Problem

As the issue points out, NumpyIndexer has a severe scalability problem: at query time it loads all indexed data into memory, making it (and every indexer inheriting from it) unusable on big data. If a user has 8GB of physical memory on their laptop, the biggest index NumpyIndexer can support is around 8GB, which is at most a seven-digit number of embeddings. With this in mind, my solution naturally builds on memmap.

numpy.memmap Instead of numpy.frombuffer

Replacing the full read with numpy.memmap is trivial. Thanks to the ndarray format Jina follows at index time, the only necessary change is how raw_ndarray is built at query time.

Before:

with gzip.open(self.index_abspath, 'rb') as fp:
    self.raw_ndarray = np.frombuffer(fp.read(), dtype=self.dtype).reshape([-1, self.num_dim])

After:

self.raw_ndarray = np.memmap(self.index_abspath, dtype=self.dtype, mode='r', shape=(self.size, self.num_dim))

This reduces the loading time to basically zero: raw_ndarray is now just a pointer into virtual memory, so it can be built instantly. The data behind this pointer is loaded on demand, and simple operations such as ndarray[1] or ndarray.shape won't trigger a full read.
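As a quick illustration of this on-demand behavior (the file name and shape below are made up for the example):

import numpy as np

# assumption: 'vecs.bin' holds 1,000,000 x 512 float32 vectors written with ndarray.tofile()
vecs = np.memmap('vecs.bin', dtype='float32', mode='r', shape=(1_000_000, 512))

print(vecs.shape)          # metadata only, nothing is read from disk yet
first = np.array(vecs[0])  # copies one row; only the touched pages are loaded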

Batching with Care

To query nearest neighbors, NumpyIndexer computes dot products between the query vectors and all indexed vectors row-wise, so it has to scan over the full data anyway. If we stop here and don't optimize further, the memory consumption climbs back to its high watermark after the first scan. The correct way is to add batching to all heavy computations, e.g. dot product and cosine distance. Jina has provided a powerful @batching decorator from day one, and here is how to use it:

Before:

...
elif self.metric == 'euclidean':
    dist = _euclidean(queries, self.raw_ndarray)

After:

elif self.metric == 'euclidean':
    dist = self._batch_euclidean(queries, self.raw_ndarray)

...

@batching(merge_over_axis=1, slice_on=2)
def _batch_euclidean(self, raw_A, raw_B):
    return _euclidean(raw_A, raw_B)
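For intuition, here is a deliberately simplified sketch of what such a batching decorator could look like. Jina's actual @batching is more general (configurable batch size, generator inputs, different merge strategies), so the b_size parameter and the exact structure below are illustrative only:

import functools
import numpy as np

def batching(merge_over_axis=0, slice_on=1, b_size=1024):
    # slice one positional argument into chunks of b_size rows, call the wrapped
    # function on each chunk, then concatenate the partial results
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            args = list(args)
            big = args[slice_on]  # the argument to slice, e.g. raw_B when slice_on=2
            parts = []
            for start in range(0, big.shape[0], b_size):
                args[slice_on] = big[start:start + b_size]  # zero-copy slice of one batch
                parts.append(func(*args, **kwargs))
            return np.concatenate(parts, axis=merge_over_axis)
        return wrapper
    return decorator

With merge_over_axis=1 and slice_on=2, raw_B is sliced into chunks of rows, each partial distance matrix has shape (num_query, batch), and the chunks are concatenated back along axis 1.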

Lifecycle of memmap

This decorator slices raw_B into batches and computes _euclidean on each batch in turn; all partial results are then merged before returning. However, it didn't work as expected. Below is the memory footprint when running 100 x 10,000 queries against 10,000 x 10,000 indexed vectors:

Line #    Mem usage    Increment   Line Contents
================================================
    27     62.1 MiB     62.1 MiB   @profile
    28                             def query():
    29     85.0 MiB      7.7 MiB       q = [np.random.random([num_query, num_dim]) for j in range(3)]
    30     85.0 MiB      0.0 MiB       with TimeContext('query'):
    31     85.0 MiB      0.1 MiB           with NumpyIndexer.load('a.bin') as ni:
    32   1172.8 MiB   1087.8 MiB               ni.query(q[0], top_k=10)
    33                                         # query again and see if memory increasing further
    34   1180.6 MiB      7.8 MiB               ni.query(q[1], top_k=10)
    35                                         # query again and see if memory increasing further
    36   1195.1 MiB     14.5 MiB               ni.query(q[2], top_k=10)

One can observe that all data is loaded into memory after the first query. Even though _batch_euclidean only works on partial data at each step, the program eventually loads all data into the cache buffer. This buffer (part of virtual memory) is controlled at the OS level. Some argue that the OS will release this buffer automatically under memory pressure; I decided to release it manually by restricting the lifecycle of the memmap to one iteration of the batching for-loop.

data = np.memmap(self.index_abspath, dtype=self.dtype, mode='r', shape=(self.size, self.num_dim))

for _slice in batch_iterator(data[:total_size], b_size, split_over_axis, yield_slice=yield_slice):
    # make a new memmap for this batch only
    new_memmap = np.memmap(data.filename, dtype=data.dtype, mode='r', shape=data.shape)
    p_data = new_memmap[_slice]

    # do the memory-intensive computation
    _euclidean(q, p_data)

    # drop the memmap manually so the OS can release its cache buffer
    del new_memmap

Zero-copy slicing

When generating each slice, the previous version of Jina used numpy.take. One gotcha of numpy.take is that it copies the data instead of slicing in place, costing extra memory. One can use Python's built-in slice to implement a zero-copy version of numpy.take as follows:

numpy.take:

np.take(data, range(start, end), axis, mode='clip')

Built-in slice (zero-copy):

_d = data.ndim
sl = [slice(None)] * _d
sl[axis] = slice(start, end)
data[tuple(sl)]
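A quick way to see the difference (a toy check, not Jina code) is numpy's shares_memory:

import numpy as np

data = np.arange(12, dtype='float32').reshape(3, 4)
start, end, axis = 0, 2, 0

taken = np.take(data, range(start, end), axis, mode='clip')  # allocates a copy

sl = [slice(None)] * data.ndim
sl[axis] = slice(start, end)
view = data[tuple(sl)]                                       # zero-copy view

print(np.shares_memory(data, taken))  # False
print(np.shares_memory(data, view))   # True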

Memory-efficient Euclidean and Cosine

When computing Euclidean and cosine distances row-wise on two matrices, one can simply use the broadcasting feature of ndarray:

numpy.sqrt(((A[:, :, None] - B[:, :, None].T) ** 2).sum(1))

However, A[:, :, None] - B[:, :, None].T creates a (num_query, num_dim, num_data) tensor in memory, so this method quickly becomes infeasible on big data: the space complexity is $O(N^3)$. In Jina, we use a more memory-efficient way to compute the Euclidean distance:

def _ext_A(A):
    nA, dim = A.shape
    A_ext = _get_ones(nA, dim * 3)
    A_ext[:, dim:2 * dim] = A
    A_ext[:, 2 * dim:] = A ** 2
    return A_ext


def _ext_B(B):
    nB, dim = B.shape
    B_ext = _get_ones(dim * 3, nB)
    B_ext[:dim] = (B ** 2).T
    B_ext[dim:2 * dim] = -2.0 * B.T
    return B_ext


def _euclidean(A_ext, B_ext):
    sqdist = A_ext.dot(B_ext).clip(min=0)
    return np.sqrt(sqdist)

The idea is simply to expand $(a-b)^2 = a^2 - 2ab + b^2$ explicitly: first prefill the values of $a^2$, $b^2$, $a$, and $-2b$, then compute the cross term $-2ab$ via a single matrix product. This reduces the memory to a (num_query, num_dim * 3) matrix plus a (num_dim * 3, num_data) matrix, and the overall space complexity to $O(N^2)$.

Notice that when computing _euclidean with the @batching decorator, A is the query matrix and remains constant across all batching iterations. Hence, A_ext can be precomputed before the batching loop to avoid unnecessary computation.
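As a sanity check, the extended-matrix trick matches the naive broadcast formula. Here is a minimal sketch using the three functions above, under the assumption that _get_ones simply allocates a matrix of ones:

import numpy as np

def _get_ones(x, y):
    # assumption: Jina's helper simply allocates a ones matrix of the given shape
    return np.ones((x, y))

A = np.random.random([5, 16])    # 5 query vectors, 16 dimensions
B = np.random.random([100, 16])  # 100 indexed vectors

fast = _euclidean(_ext_A(A), _ext_B(B))                           # extended-matrix trick
naive = np.sqrt(((A[:, :, None] - B[:, :, None].T) ** 2).sum(1))  # broadcast formula

print(fast.shape)                # (5, 100)
print(np.allclose(fast, naive))  # True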

Finally, the cosine distance is nothing more than the "normed" Euclidean distance. It can simply be built on top of the _euclidean function by feeding it _norm(A) and _norm(B):

def _norm(A):
    return A / np.linalg.norm(A, ord=2, axis=1, keepdims=True)

Removing gzip compression

Jina uses gzip to store the vector data; this feature has been around since day one. However, it does not play well with memmap: decompression has to load all data into memory, which defeats the idea of on-demand loading.
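Concretely, dropping gzip boils down to writing the raw ndarray bytes directly so that the same file can later be opened with np.memmap; a rough sketch with made-up file names (not the exact Jina write path):

import gzip
import numpy as np

vectors = np.random.random([1000, 128]).astype('float32')

# before: compressed write -- the file must be fully decompressed before querying
with gzip.open('index.bin.gz', 'wb') as fp:
    fp.write(vectors.tobytes())

# after: plain write -- np.memmap can open this file directly at query time
with open('index.bin', 'wb') as fp:
    fp.write(vectors.tobytes())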

While removing gzip, I found it didn't bring much benefit in the first place. The table below compares v0.5.5 with and without gzip compression (this version does not include any of the improvements mentioned above). One can see that compression increases the index time by roughly 14x while saving only about 40MB of disk space.

v0.5.5 Mode           Time (second)   On-disk Space (MB)
Index                 2.13            800
Index (compressed)    29.36           759
Query                 4.56            -
Query (compressed)    5.52            -

The script for benchmarking can be found here.

Summary

Putting everything together, we now have a pretty good baseline vector indexer that uses constant memory. We have rolled out the improved NumpyIndexer in the v0.5.6 patch, so all examples, tutorials, and jina hello-world immediately benefit from this improvement.

Version               Index Time   Query Time   Index Memory   Query Memory
v0.5.6                1.34s        1.61s        7.3MB          273.6MB
v0.5.5                2.14s        4.56s        11.5MB         910.2MB
v0.5.5 (compressed)   29.36s       5.52s        11.4MB         902.4MB

To me, the use of memmap is the most inspiring part; it gives me a lot to think about. This function, together with memoryview(), could be used more extensively in Jina to further improve its scalability.