Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nlcodec/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
#
# Author: Thamme Gowda [tg (at) isi (dot) edu]
# Author: Thamme Gowda [tg (at) isi (dot) edu]
# Created: 7/19/20

from .core import SeqField, Db, MultipartDb, best_dtype
from .core import SeqField, Db, MultipartDb, best_int_type
69 changes: 51 additions & 18 deletions nlcodec/db/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import shutil
from collections import namedtuple, defaultdict
from pathlib import Path
from typing import List, Iterator, Dict, Any, Tuple
from typing import List, Iterator, Dict, Any, Tuple, Union
import copy
import random
import numpy as np
Expand All @@ -23,6 +23,7 @@

Array = np.ndarray
Record = Tuple[Array]
DType = Union[int, float]

DEF_TYPE = np.uint16 # uint16 is [0, 65,535]
DEF_MIN = np.iinfo(DEF_TYPE).min
Expand All @@ -31,7 +32,8 @@
DEF_MAX_PARTS = 1_000


def best_dtype(mn, mx):

def best_int_type(mn, mx):
"""
determines best (integer) data type for a given range of integer
:param mn: min value
Expand All @@ -48,7 +50,8 @@ def best_dtype(mn, mx):
if info.min <= mn and mx <= info.max: # completely fits inside
found_type = int_type
break
assert found_type, f'Could not find an integet type for [{min},{max}]'
else:
assert ValueError(f'Could not find an integer type for [{min},{max}]')
return found_type


Expand All @@ -66,7 +69,7 @@ class Builder:
# 100MB; each int takes 28 bytes + 8 byte for ref
BUFFER_SIZE = 100_000_000 / (28 + 8)

def __init__(self, name, buf_size=None):
def __init__(self, name, buf_size=None, dtype:DType=int):

self.name = name
self.ids = {}
Expand All @@ -75,9 +78,13 @@ def __init__(self, name, buf_size=None):
self.refs = []
self.max_len = 0
self.buf_size = buf_size or self.BUFFER_SIZE
assert dtype in (int, float, np.integer, np.floating), 'Only int and float types are supported'
dtype = {int: np.integer, float: np.floating}.get(dtype, dtype) # convert to numpy type
self.orig_dtype = dtype

def append(self, id, arr):
assert len(arr) == 0 or isinstance(arr[0], (int, np.integer))

assert len(arr) == 0 or isinstance(arr[0], self.orig_dtype)
self.ids[id] = len(self.refs)
self.max_len = max(len(arr), self.max_len)
self.refs.append([len(self.frozen) + len(self.buffer), len(arr)]) # start, len
Expand All @@ -86,17 +93,34 @@ def append(self, id, arr):
self.shrink()
return self

def best_dtype(self, buffer: List[DType], mn=None, mx=None):
    """
    Pick a numpy dtype able to hold every value in ``buffer``.

    :param buffer: non-empty list of values about to be stored
    :param mn: optional precomputed minimum; defaults to ``min(buffer)``
    :param mx: optional precomputed maximum; defaults to ``max(buffer)``
    :return: narrowest integer dtype via ``best_int_type`` when this builder
        holds integers, else ``np.float32`` for floats
    :raises ValueError: when the builder's declared dtype is neither
        ``np.integer`` nor ``np.floating``
    """
    assert buffer
    mn = min(buffer) if mn is None else mn
    mx = max(buffer) if mx is None else mx
    # bounds must be of the same type, else the comparison below is suspect
    assert type(mn) is type(mx)
    if self.orig_dtype is np.integer:
        return best_int_type(mn=mn, mx=mx)
    if self.orig_dtype is np.floating:
        # TODO: implement float type, use np.float32 for now
        # TODO: use np.finfo to find the best type looking at precision and resolution
        return np.float32
    raise ValueError(f'Unknown dtype: {self.orig_dtype}')

def shrink(self):
    """Flush the staging buffer: pack it into the narrowest suitable numpy
    dtype (per :meth:`best_dtype`) and append it to the frozen array."""
    if not self.buffer:
        return
    packed = np.array(self.buffer, dtype=self.best_dtype(self.buffer))
    self.frozen = np.concatenate((self.frozen, packed))
    self.buffer = []

def build(self):
    """Finalize this builder and return an immutable SeqField.

    Flushes any buffered values first, then materializes the (start, len)
    reference table with the smallest integer dtype that fits.
    """
    self.shrink()
    assert not self.buffer  # everything must be frozen after shrink()
    # references (offsets and lengths) are always non-negative integers
    ref_dtype = best_int_type(mn=0, mx=max(len(self.frozen), self.max_len))
    refs_arr = np.array(self.refs, dtype=ref_dtype)
    return SeqField(name=self.name, ids=self.ids, refs=refs_arr, data=self.frozen)

Expand Down Expand Up @@ -135,15 +159,16 @@ def lengths(self):
return ((id, self.refs[idx, 1]) for id, idx in self.ids.items())

@classmethod
def create(cls, name, dtype: DType, recs) -> 'SeqField':
    """
    Build a SeqField from a stream of records.

    :param name: field name
    :param dtype: value type of the sequences (``int`` or ``float``)
    :param recs: iterable of ``(id, sequence)`` pairs
    :return: a SeqField holding all the given records
    """
    # BUG FIX: Builder.__init__ declares the keyword as `dtype` (not
    # `orig_dtype`); passing orig_dtype= raised TypeError on every call.
    builder = cls.Builder(name=name, dtype=dtype)
    for id, arr in recs:
        builder.append(id, arr)
    return builder.build()

@classmethod
def create_many(cls, names: List[str], recs: Iterator[List[List[int]]]) -> List['SeqField']:
builders = [cls.Builder(name) for name in names]
def create_many(cls, names: List[str], dtypes:List[DType], recs: Iterator[List[List[int]]]) -> List['SeqField']:
assert len(names) == len(dtypes)
builders = [cls.Builder(name, dtype=dtype) for name, dtype in zip(names, dtypes)]
for id, rec in recs:
assert len(rec) == len(builders)
for b, col in zip(builders, rec):
Expand Down Expand Up @@ -196,7 +221,7 @@ def load(cls, path, rec_type=None, shuffle=False) -> 'Db':
return obj

@classmethod
def create(cls, recs, field_names, has_id=False, path=None):
def create(cls, recs, field_names, dtypes=None, has_id=False, path=None):
"""
:param recs: Iterator[List[List[int]]] or Iterator[(id, List[List[int]])]
:param field_names: field names in records
Expand All @@ -205,9 +230,15 @@ def create(cls, recs, field_names, has_id=False, path=None):
:param path: path to save on disk (optional)
:return:
"""

if not has_id:
recs = enumerate(recs)
fields = SeqField.create_many(field_names, recs)
if dtypes is None:
log.warning("dtypes not provided, assuming all int. This will be deprecated soon.")
dtypes = [int] * len(field_names)
else:
assert len(field_names) == len(dtypes)
fields = SeqField.create_many(field_names, dtypes=dtypes, recs=recs)
db = cls(fields=fields)
if path:
db.save(path)
Expand All @@ -220,7 +251,7 @@ def __getitem__(self, _id):
def __iter__(self):
    """Yield records one by one; in a random order when shuffle is enabled."""
    order = self.ids
    if self.shuffle:
        # copy before shuffling so the db's own id collection is untouched
        order = list(order)
        random.shuffle(order)
    for rec_id in order:
        yield self[rec_id]
Expand Down Expand Up @@ -291,11 +322,11 @@ def slices(cls, stream, size):

@classmethod
def create(cls, path, recs, field_names, has_id=False, overwrite=False,
part_size=DEF_PART_SIZE, max_parts=DEF_MAX_PARTS):
part_size=DEF_PART_SIZE, max_parts=DEF_MAX_PARTS, dtypes=None):
if not has_id:
recs = enumerate(recs)
builder = cls.Writer(path=path, field_names=field_names, overwrite=overwrite,
max_parts=max_parts)
max_parts=max_parts, dtypes=dtypes)

part_num = -1
for sliced in cls.slices(recs, part_size):
Expand Down Expand Up @@ -363,8 +394,9 @@ def make_eq_len_ran_batches(self, max_toks, max_sents=float('inf'), join_ratio=0
class Writer:

def __init__(self, path, field_names: List[str], overwrite=False,
max_parts=DEF_MAX_PARTS):
max_parts=DEF_MAX_PARTS, dtypes=None):
self.field_names = field_names
self.dtypes = dtypes
path = as_path(path)
if path.exists() and len(os.listdir(path)) > 0:
if overwrite:
Expand All @@ -375,10 +407,11 @@ def __init__(self, path, field_names: List[str], overwrite=False,
path.mkdir(parents=True, exist_ok=True)
self.path = path
self.part_path_pad = part_path_pads(max_parts)


def __call__(self, part_num: int, recs):
# assume recs have ids created externally
part = Db.create(recs, field_names=self.field_names, has_id=True)
part = Db.create(recs, field_names=self.field_names, has_id=True, dtypes=self.dtypes)
part_path = self.path / f'part-{part_num:0{self.part_path_pad}d}'
part.save(part_path)
meta_path = part_path.with_suffix('.meta')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Author: Thamme Gowda [tg (at) isi (dot) edu]
# Created: 7/19/20

from nlcodec.db.core import Db, MultipartDb, best_dtype, log
from nlcodec.db.core import Db, MultipartDb, best_int_type as best_dtype, log
from nlcodec import spark
import numpy as np
import random
Expand Down