Some tweaks to draft table APIs #11880
The first file reworks `CatalogClient` and `TableEntry`:

```diff
@@ -1,5 +1,8 @@
 from __future__ import annotations

+import atexit
+import tempfile
+from pathlib import Path
 from typing import TYPE_CHECKING, Any

 from rerun import catalog as _catalog
@@ -16,6 +19,8 @@ class CatalogClient:

     def __init__(self, address: str, token: str | None = None) -> None:
         self._inner = _catalog.CatalogClient(address, token)
+        self.tmpdirs = []
+        atexit.register(self._cleanup)

     def __repr__(self) -> str:
         return repr(self._inner)
@@ -60,18 +65,14 @@ def get_dataset_entry(self, *, id: EntryId | str | None = None, name: str | None
         """Returns a dataset entry by its ID or name."""
         return DatasetEntry(self._inner.get_dataset_entry(id=id, name=name))

-    def get_table_entry(self, *, id: EntryId | str | None = None, name: str | None = None) -> TableEntry:
+    def get_table(self, *, id: EntryId | str | None = None, name: str | None = None) -> TableEntry:
         """Returns a table entry by its ID or name."""
         return TableEntry(self._inner.get_table_entry(id=id, name=name))

     def get_dataset(self, *, id: EntryId | str | None = None, name: str | None = None) -> DatasetEntry:
         """Returns a dataset by its ID or name."""
         return DatasetEntry(self._inner.get_dataset(id=id, name=name))

-    def get_table(self, *, id: EntryId | str | None = None, name: str | None = None) -> datafusion.DataFrame:
-        """Returns a table by its ID or name as a DataFrame."""
-        return self._inner.get_table(id=id, name=name)
-
     def create_dataset(self, name: str) -> DatasetEntry:
         """Creates a new dataset with the given name."""
         return DatasetEntry(self._inner.create_dataset(name))
@@ -80,8 +81,12 @@ def register_table(self, name: str, url: str) -> TableEntry:
         """Registers a foreign Lance table as a new table entry."""
         return TableEntry(self._inner.register_table(name, url))

-    def create_table_entry(self, name: str, schema, url: str) -> TableEntry:
+    def create_table(self, name: str, schema, url: str | None = None) -> TableEntry:
```
> **Member:** Note: we tried to establish the convention that:
```diff
         """Create and register a new table."""
+        if url is None:
+            tmpdir = tempfile.TemporaryDirectory()
+            self.tmpdirs.append(tmpdir)
+            url = Path(tmpdir.name).as_uri()
```
**Comment on lines +86 to +89**

> **Member:** I love the utter nastiness we are allowed in here 😅
```diff
         return TableEntry(self._inner.create_table_entry(name, schema, url))

     def write_table(self, name: str, batches, insert_mode) -> None:
@@ -101,6 +106,14 @@ def ctx(self) -> datafusion.SessionContext:
         """Returns a DataFusion session context for querying the catalog."""
         return self._inner.ctx

+    def _cleanup(self) -> None:
+        # Safety net: avoid warning if GC happens late
+        try:
+            for tmpdir in self.tmpdirs:
+                tmpdir.cleanup()
+        except Exception:
+            pass
+

 class Entry:
     """An entry in the catalog."""
@@ -295,23 +308,42 @@ class TableEntry(Entry):

     def __init__(self, inner: _catalog.TableEntry) -> None:
         super().__init__(inner)
-        # Cache the dataframe for forwarding
-        self._df = inner.df()
         self._inner = inner

-    def __datafusion_table_provider__(self) -> Any:
-        return self._inner.__datafusion_table_provider__()
+    def client(self) -> CatalogClient:
+        """Returns the CatalogClient associated with this table."""
+        inner_catalog = _catalog.CatalogClient.__new__(_catalog.CatalogClient)  # bypass __init__
+        inner_catalog._raw_client = self._inner.catalog
+        outer_catalog = CatalogClient.__new__(CatalogClient)  # bypass __init__
+        outer_catalog._inner = inner_catalog

-    def to_arrow_reader(self) -> pa.RecordBatchReader:
-        return self._inner.to_arrow_reader()
+        return outer_catalog

-    def __getattr__(self, name: str) -> Any:
-        """Forward DataFrame methods to the underlying dataframe."""
-        # First try to get from Entry base class
-        try:
-            return super().__getattribute__(name)
-        except AttributeError:
-            # Then forward to the dataframe
-            return getattr(self._df, name)
+    def append(self, **named_params: Any) -> None:
+        """Convert Python objects into columns of data and append them to a table."""
+        self.client().append_to_table(self._inner.name, **named_params)
+
+    def update(self, *, name: str | None = None) -> None:
+        return self._inner.update(name=name)
+
+    def reader(self) -> datafusion.DataFrame:
+        """
+        Exposes the contents of the table via a datafusion DataFrame.
+
+        Note: this is equivalent to `catalog.ctx.table(<tablename>)`.
+
+        This operation is lazy. The data will not be read from the source table until consumed
+        from the DataFrame.
+        """
+        return self._inner.df()
+
+    def schema(self) -> pa.Schema:
+        """Returns the schema of the table."""
+        return self.reader().schema()
+
+    def to_polars(self) -> Any:
+        """Returns the table as a Polars DataFrame."""
+        return self.reader().to_polars()


 AlreadyExistsError = _catalog.AlreadyExistsError
```
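Taken together, the rename and the new `TableEntry` helpers change the calling pattern. A minimal usage sketch, assuming a locally running server as in the tests further down (the table name and column values are placeholders):

```python
import pyarrow as pa
import rerun_draft as rr

with rr.server.Server() as server:
    client = server.client()

    # `url` is now optional: when omitted, the table is backed by a
    # TemporaryDirectory that the client cleans up at interpreter exit.
    table = client.create_table(
        "example_table",
        pa.schema([("operator", pa.string())]),
    )

    # Columns are appended by keyword, directly on the entry.
    table.append(operator=["alice", "bob"])

    # Explicit, lazy DataFrame access replaces the old `__getattr__` forwarding.
    print(table.schema())
    print(table.reader().collect())
```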
The second file replaces the attribute-forwarding `Server` wrapper with an explicit delegating class:

```diff
@@ -1,17 +1,47 @@
 from __future__ import annotations

+from typing import TYPE_CHECKING
+
 from rerun import server as _server

 from .catalog import CatalogClient

+if TYPE_CHECKING:
+    from os import PathLike
+    from types import TracebackType
+

 class Server:
-    __init__ = _server.Server.__init__
-    address = _server.Server.address
-    is_running = _server.Server.is_running
-    shutdown = _server.Server.shutdown
-    __enter__ = _server.Server.__enter__
-    __exit__ = _server.Server.__exit__
+    def __init__(
+        self,
+        *,
+        address: str = "0.0.0.0",
+        port: int | None = None,
+        datasets: dict[str, PathLike[str]] | None = None,
+        tables: dict[str, PathLike[str]] | None = None,
+    ) -> None:
+        self._internal = _server.Server(
+            address=address,
+            port=port,
+            datasets=datasets,
+            tables=tables,
+        )
+
+    def __enter__(self) -> Server:
+        self._internal.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
+        self._internal.__exit__(exc_type, exc_value, traceback)

     def client(self) -> CatalogClient:
         return CatalogClient(address=self.address())
```
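The new keyword-only constructor also takes name-to-path mappings for pre-registering data at startup. A sketch under the assumption that the mappings point at existing on-disk dataset and table directories (the paths and port here are hypothetical):

```python
from pathlib import Path

import rerun_draft as rr

# Hypothetical on-disk locations for pre-registered entries.
with rr.server.Server(
    port=9876,
    datasets={"my_dataset": Path("data/my_dataset")},
    tables={"my_table": Path("data/my_table")},
) as server:
    client = server.client()
    # The pre-registered table is then queryable through the session context.
    print(client.ctx.table("my_table"))
```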
The third file adds a test against the underlying `rerun` catalog API:

```diff
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import datafusion
+import pyarrow as pa
+import rerun as rr
+from inline_snapshot import snapshot as inline_snapshot
+
+if TYPE_CHECKING:
+    import pytest
+
+
+def test_table_api(tmp_path_factory: pytest.TempPathFactory) -> None:
+    with rr.server.Server() as server:
+        client = server.client()
+
+        tmp_path = tmp_path_factory.mktemp("my_table")
+
+        table = client.create_table_entry(
+            "my_table",
+            pa.schema([
+                ("rerun_segment_id", pa.string()),
+                ("operator", pa.string()),
+            ]),
+            tmp_path.as_uri(),
+        )
+
+        assert isinstance(table.df(), datafusion.DataFrame)
+
+        assert str(table.df().schema()) == inline_snapshot("""\
+rerun_segment_id: string
+operator: string
+-- schema metadata --
+sorbet:version: '0.1.1'\
+""")
+
+        assert str(table.df().collect()) == inline_snapshot("[]")
+
+        client.append_to_table(
+            "my_table",
+            rerun_segment_id=["segment_001", "segment_002", "segment_003"],
+            operator=["alice", "bob", "carol"],
+        )
+
+        assert str(table.df().select("rerun_segment_id", "operator")) == inline_snapshot("""\
+┌─────────────────────┬─────────────────────┐
+│ rerun_segment_id    ┆ operator            │
+│ ---                 ┆ ---                 │
+│ type: nullable Utf8 ┆ type: nullable Utf8 │
+╞═════════════════════╪═════════════════════╡
+│ segment_001         ┆ alice               │
+├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
+│ segment_002         ┆ bob                 │
+├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
+│ segment_003         ┆ carol               │
+└─────────────────────┴─────────────────────┘\
+""")
+
+        assert str(table.df()) == str(client.ctx.table("my_table"))
```
The fourth file adds the equivalent test against the reworked `rerun_draft` API:

```diff
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+import datafusion
+import pyarrow as pa
+import rerun_draft as rr
+from inline_snapshot import snapshot as inline_snapshot
+
+
+def test_table_api() -> None:
+    with rr.server.Server() as server:
+        client = server.client()
+
+        table = client.create_table(
+            "my_table",
+            pa.schema([
+                ("rerun_segment_id", pa.string()),
+                ("operator", pa.string()),
+            ]),
+        )
+
+        assert isinstance(table.reader(), datafusion.DataFrame)
+
+        assert str(table.schema()) == inline_snapshot("""\
+rerun_segment_id: string
+operator: string
+-- schema metadata --
+sorbet:version: '0.1.1'\
+""")
+
+        assert str(table.reader().collect()) == inline_snapshot("[]")
+
+        table.append(
+            rerun_segment_id=["segment_001", "segment_002", "segment_003"],
+            operator=["alice", "bob", "carol"],
+        )
+
+        assert str(table.reader().select("rerun_segment_id", "operator")) == inline_snapshot("""\
+┌─────────────────────┬─────────────────────┐
+│ rerun_segment_id    ┆ operator            │
+│ ---                 ┆ ---                 │
+│ type: nullable Utf8 ┆ type: nullable Utf8 │
+╞═════════════════════╪═════════════════════╡
+│ segment_001         ┆ alice               │
+├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
+│ segment_002         ┆ bob                 │
+├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
+│ segment_003         ┆ carol               │
+└─────────────────────┴─────────────────────┘\
+""")
+
+        assert str(table.reader()) == str(client.ctx.table("my_table"))
```
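Since `table.reader()` is lazy and equivalent to `client.ctx.table(<tablename>)`, query plans can be composed before any data is read. A minimal sketch of that equivalence under the same draft API as above (`col`/`lit` are standard datafusion helpers; the table name and filter predicate are placeholders):

```python
import pyarrow as pa
import rerun_draft as rr
from datafusion import col, lit

with rr.server.Server() as server:
    client = server.client()
    table = client.create_table("lazy_demo", pa.schema([("operator", pa.string())]))
    table.append(operator=["alice", "bob"])

    # Both plans are lazy: no data is read until collect() is called.
    via_entry = table.reader().filter(col("operator") == lit("alice"))
    via_ctx = client.ctx.table("lazy_demo").filter(col("operator") == lit("alice"))

    assert str(via_entry.collect()) == str(via_ctx.collect())
```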