diff --git a/nlcodec/db/__init__.py b/nlcodec/db/__init__.py
index 9986ac5..cae5598 100644
--- a/nlcodec/db/__init__.py
+++ b/nlcodec/db/__init__.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 #
-# Author: Thamme Gowda [tg (at) isi (dot) edu]
+# Author: Thamme Gowda [tg (at) isi (dot) edu]
 # Created: 7/19/20
 
-from .core import SeqField, Db, MultipartDb, best_dtype
\ No newline at end of file
+from .core import SeqField, Db, MultipartDb, best_int_type
\ No newline at end of file
diff --git a/nlcodec/db/core.py b/nlcodec/db/core.py
index de71a5d..6417161 100644
--- a/nlcodec/db/core.py
+++ b/nlcodec/db/core.py
@@ -13,7 +13,7 @@
 import shutil
 from collections import namedtuple, defaultdict
 from pathlib import Path
-from typing import List, Iterator, Dict, Any, Tuple
+from typing import List, Iterator, Dict, Any, Tuple, Union
 import copy
 import random
 import numpy as np
@@ -23,6 +23,7 @@
 
 Array = np.ndarray
 Record = Tuple[Array]
+DType = Union[int, float]
 DEF_TYPE = np.uint16  # uint16 is [0, 65,535]
 DEF_MIN = np.iinfo(DEF_TYPE).min
@@ -31,7 +32,8 @@
 DEF_MAX_PARTS = 1_000
 
 
-def best_dtype(mn, mx):
+
+def best_int_type(mn, mx):
     """
     determines best (integer) data type for a given range of integer
     :param mn: min value
@@ -48,7 +50,8 @@ def best_dtype(mn, mx):
         if info.min <= mn and mx <= info.max:  # completely fits inside
             found_type = int_type
             break
-    assert found_type, f'Could not find an integet type for [{min},{max}]'
+    else:
+        raise ValueError(f'Could not find an integer type for [{mn},{mx}]')
     return found_type
@@ -66,7 +69,7 @@ class Builder:
 
         # 100MB; each int takes 28 bytes + 8 byte for ref
         BUFFER_SIZE = 100_000_000 / (28 + 8)
 
-        def __init__(self, name, buf_size=None):
+        def __init__(self, name, buf_size=None, dtype: DType = int):
             self.name = name
             self.ids = {}
@@ -75,9 +78,13 @@ def __init__(self, name, buf_size=None):
             self.refs = []
             self.max_len = 0
             self.buf_size = buf_size or self.BUFFER_SIZE
+            assert dtype in (int, float, np.integer, np.floating), 'Only int and float types are supported'
+            dtype = {int: np.integer, float: np.floating}.get(dtype, dtype)  # convert to numpy type
+            self.orig_dtype = dtype
 
         def append(self, id, arr):
-            assert len(arr) == 0 or isinstance(arr[0], (int, np.integer))
+
+            assert len(arr) == 0 or np.issubdtype(type(arr[0]), self.orig_dtype)
             self.ids[id] = len(self.refs)
             self.max_len = max(len(arr), self.max_len)
             self.refs.append([len(self.frozen) + len(self.buffer), len(arr)])  # start, len
@@ -86,9 +93,25 @@ def append(self, id, arr):
                 self.shrink()
             return self
 
+        def best_dtype(self, buffer: List[DType], mn=None, mx=None):
+            assert buffer
+            if mn is None:
+                mn = min(buffer)
+            if mx is None:
+                mx = max(buffer)
+            assert type(mn) is type(mx)
+            if self.orig_dtype is np.integer:
+                return best_int_type(mn=mn, mx=mx)
+            elif self.orig_dtype is np.floating:
+                # TODO: implement float type, use np.float32 for now
+                # TODO: use np.finfo to find the best type looking at precision and resolution
+                return np.float32
+            else:
+                raise ValueError(f'Unknown dtype: {self.orig_dtype}')
+
         def shrink(self):
             if self.buffer:
-                dtype = best_dtype(mn=min(self.buffer), mx=max(self.buffer))
+                dtype = self.best_dtype(self.buffer)
                 data = np.array(self.buffer, dtype=dtype)
                 self.frozen = np.concatenate((self.frozen, data))
                 self.buffer = []
@@ -96,7 +119,8 @@
         def build(self):
             self.shrink()
             assert not self.buffer
-            refs_type = best_dtype(mn=0, mx=max(len(self.frozen), self.max_len))
+            # reference type is integer
+            refs_type = best_int_type(mn=0, mx=max(len(self.frozen), self.max_len))
             return SeqField(name=self.name, ids=self.ids, refs=np.array(self.refs, dtype=refs_type),
                             data=self.frozen)
@@ -135,15 +159,16 @@ def lengths(self):
         return ((id, self.refs[idx, 1]) for id, idx in self.ids.items())
 
     @classmethod
-    def create(cls, name, recs) -> 'SeqField':
-        builder = cls.Builder(name=name)
+    def create(cls, name, dtype: DType, recs) -> 'SeqField':
+        builder = cls.Builder(name=name, dtype=dtype)
         for id, arr in recs:
             builder.append(id, arr)
         return builder.build()
 
     @classmethod
-    def create_many(cls, names: List[str], recs: Iterator[List[List[int]]]) -> List['SeqField']:
-        builders = [cls.Builder(name) for name in names]
+    def create_many(cls, names: List[str], dtypes: List[DType], recs: Iterator[List[List[int]]]) -> List['SeqField']:
+        assert len(names) == len(dtypes)
+        builders = [cls.Builder(name, dtype=dtype) for name, dtype in zip(names, dtypes)]
         for id, rec in recs:
             assert len(rec) == len(builders)
             for b, col in zip(builders, rec):
@@ -196,7 +221,7 @@ def load(cls, path, rec_type=None, shuffle=False) -> 'Db':
         return obj
 
     @classmethod
-    def create(cls, recs, field_names, has_id=False, path=None):
+    def create(cls, recs, field_names, dtypes=None, has_id=False, path=None):
         """
         :param recs: Iterator[List[List[int]]] or Iterator[(id, List[List[int]])]
         :param field_names: field names in records
@@ -205,9 +230,15 @@ def create(cls, recs, field_names, has_id=False, path=None):
         :param path: path to save on disk (optional)
         :return:
         """
+
         if not has_id:
             recs = enumerate(recs)
-        fields = SeqField.create_many(field_names, recs)
+        if dtypes is None:
+            log.warning("dtypes not provided, assuming all int. This will be deprecated soon.")
+            dtypes = [int] * len(field_names)
+        else:
+            assert len(field_names) == len(dtypes)
+        fields = SeqField.create_many(field_names, dtypes=dtypes, recs=recs)
         db = cls(fields=fields)
         if path:
             db.save(path)
@@ -220,7 +251,7 @@ def __getitem__(self, _id):
     def __iter__(self):
         ids = self.ids
         if self.shuffle:
-            ids = copy.copy(ids)
+            ids = list(ids)
             random.shuffle(ids)
         for _id in ids:
             yield self[_id]
@@ -291,11 +322,11 @@ def slices(cls, stream, size):
 
     @classmethod
     def create(cls, path, recs, field_names, has_id=False, overwrite=False,
-               part_size=DEF_PART_SIZE, max_parts=DEF_MAX_PARTS):
+               part_size=DEF_PART_SIZE, max_parts=DEF_MAX_PARTS, dtypes=None):
         if not has_id:
             recs = enumerate(recs)
         builder = cls.Writer(path=path, field_names=field_names, overwrite=overwrite,
-                             max_parts=max_parts)
+                             max_parts=max_parts, dtypes=dtypes)
 
         part_num = -1
         for sliced in cls.slices(recs, part_size):
@@ -363,8 +394,9 @@ def make_eq_len_ran_batches(self, max_toks, max_sents=float('inf'), join_ratio=0
 
     class Writer:
         def __init__(self, path, field_names: List[str], overwrite=False,
-                     max_parts=DEF_MAX_PARTS):
+                     max_parts=DEF_MAX_PARTS, dtypes=None):
             self.field_names = field_names
+            self.dtypes = dtypes
             path = as_path(path)
             if path.exists() and len(os.listdir(path)) > 0:
                 if overwrite:
@@ -375,10 +407,11 @@ def __init__(self, path, field_names: List[str], overwrite=False,
             path.mkdir(parents=True, exist_ok=True)
             self.path = path
             self.part_path_pad = part_path_pads(max_parts)
+
 
         def __call__(self, part_num: int, recs):
             # assume recs have ids created externally
-            part = Db.create(recs, field_names=self.field_names, has_id=True)
+            part = Db.create(recs, field_names=self.field_names, has_id=True, dtypes=self.dtypes)
             part_path = self.path / f'part-{part_num:0{self.part_path_pad}d}'
             part.save(part_path)
             meta_path = part_path.with_suffix('.meta')
diff --git a/tests/test_db.py b/tests/test_db.py
index 336149c..5bfb990 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -3,7 +3,7 @@
 # Author: Thamme Gowda [tg (at) isi (dot) edu]
 # Created: 7/19/20
 
-from nlcodec.db.core import Db, MultipartDb, best_dtype, log
+from nlcodec.db.core import Db, MultipartDb, best_int_type as best_dtype, log
 from nlcodec import spark
 import numpy as np
 import random
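
Usage sketch (illustrative only, not part of the patch): with the dtypes parameter added to Db.create and MultipartDb.create above, a record set with one integer field and one float field could be stored roughly as below. The field names ('x', 'y') and the sample values are made up for this example.

    from nlcodec.db.core import Db

    # two parallel fields per record: token ids (ints) and per-token scores (floats)
    recs = [
        ([1, 2, 3], [0.1, 0.2, 0.3]),
        ([4, 5], [0.4, 0.5]),
    ]
    db = Db.create(recs, field_names=['x', 'y'], dtypes=[int, float])

    # omitting dtypes keeps the old behaviour: every field is assumed int,
    # and a warning about the upcoming deprecation is logged
    db_ints = Db.create([([1, 2], [3, 4])], field_names=['x', 'y'])

Per the TODOs in Builder.best_dtype, float fields are stored as np.float32 for now, while integer fields still pick the smallest integer type via best_int_type.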