Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"extraPaths": [
"rerun_py/rerun_sdk"
"rerun_py/rerun_sdk",
"rerun_py/tests/api_sandbox"
],
"exclude": [
"**/node_modules",
Expand Down
4 changes: 2 additions & 2 deletions rerun_py/rerun_bindings/rerun_bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import datafusion as dfn
import numpy as np
import numpy.typing as npt
import pyarrow as pa
from rerun.catalog import CatalogClient
from typing_extensions import deprecated # type: ignore[misc, unused-ignore]

from .types import (
Expand Down Expand Up @@ -1283,8 +1282,9 @@ class Entry:
def name(self) -> str:
"""The entry's name."""

# TODO(RR-2938): this should return `CatalogClient`
@property
def catalog(self) -> CatalogClient:
def catalog(self) -> CatalogClientInternal:
"""The catalog client that this entry belongs to."""

@property
Expand Down
72 changes: 52 additions & 20 deletions rerun_py/tests/api_sandbox/rerun_draft/catalog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import atexit
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Any

from rerun import catalog as _catalog
Expand All @@ -16,6 +19,8 @@ class CatalogClient:

def __init__(self, address: str, token: str | None = None) -> None:
self._inner = _catalog.CatalogClient(address, token)
self.tmpdirs = []
atexit.register(self._cleanup)

def __repr__(self) -> str:
return repr(self._inner)
Expand Down Expand Up @@ -60,18 +65,14 @@ def get_dataset_entry(self, *, id: EntryId | str | None = None, name: str | None
"""Returns a dataset entry by its ID or name."""
return DatasetEntry(self._inner.get_dataset_entry(id=id, name=name))

def get_table_entry(self, *, id: EntryId | str | None = None, name: str | None = None) -> TableEntry:
def get_table(self, *, id: EntryId | str | None = None, name: str | None = None) -> TableEntry:
"""Returns a table entry by its ID or name."""
return TableEntry(self._inner.get_table_entry(id=id, name=name))

def get_dataset(self, *, id: EntryId | str | None = None, name: str | None = None) -> DatasetEntry:
"""Returns a dataset by its ID or name."""
return DatasetEntry(self._inner.get_dataset(id=id, name=name))

def get_table(self, *, id: EntryId | str | None = None, name: str | None = None) -> datafusion.DataFrame:
"""Returns a table by its ID or name as a DataFrame."""
return self._inner.get_table(id=id, name=name)

def create_dataset(self, name: str) -> DatasetEntry:
"""Creates a new dataset with the given name."""
return DatasetEntry(self._inner.create_dataset(name))
Expand All @@ -80,8 +81,12 @@ def register_table(self, name: str, url: str) -> TableEntry:
"""Registers a foreign Lance table as a new table entry."""
return TableEntry(self._inner.register_table(name, url))

def create_table_entry(self, name: str, schema, url: str) -> TableEntry:
def create_table(self, name: str, schema, url: str | None = None) -> TableEntry:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: we tried to establish the convention that: xyz would return "raw data" (aka dataframe), while xyz_entry returns the python object representation. We couldn't really make it super consistent, but the df/object hybrid obviously doesn't help here.

"""Create and register a new table."""
if url is None:
tmpdir = tempfile.TemporaryDirectory()
self.tmpdirs.append(tmpdir)
url = Path(tmpdir.name).as_uri()
Comment on lines +86 to +89
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love the utter nastiness we are allowed in here 😅

return TableEntry(self._inner.create_table_entry(name, schema, url))

def write_table(self, name: str, batches, insert_mode) -> None:
Expand All @@ -101,6 +106,14 @@ def ctx(self) -> datafusion.SessionContext:
"""Returns a DataFusion session context for querying the catalog."""
return self._inner.ctx

def _cleanup(self) -> None:
# Safety net: avoid warning if GC happens late
try:
for tmpdir in self.tmpdirs:
tmpdir.cleanup()
except Exception:
pass


class Entry:
"""An entry in the catalog."""
Expand Down Expand Up @@ -295,23 +308,42 @@ class TableEntry(Entry):

def __init__(self, inner: _catalog.TableEntry) -> None:
super().__init__(inner)
# Cache the dataframe for forwarding
self._df = inner.df()
self._inner = inner

def __datafusion_table_provider__(self) -> Any:
return self._inner.__datafusion_table_provider__()
def client(self) -> CatalogClient:
"""Returns the CatalogClient associated with this table."""
inner_catalog = _catalog.CatalogClient.__new__(_catalog.CatalogClient) # bypass __init__
inner_catalog._raw_client = self._inner.catalog
outer_catalog = CatalogClient.__new__(CatalogClient) # bypass __init__
outer_catalog._inner = inner_catalog

def to_arrow_reader(self) -> pa.RecordBatchReader:
return self._inner.to_arrow_reader()
return outer_catalog

def __getattr__(self, name: str) -> Any:
"""Forward DataFrame methods to the underlying dataframe."""
# First try to get from Entry base class
try:
return super().__getattribute__(name)
except AttributeError:
# Then forward to the dataframe
return getattr(self._df, name)
def append(self, **named_params: Any) -> None:
"""Convert Python objects into columns of data and append them to a table."""
self.client().append_to_table(self._inner.name, **named_params)

def update(self, *, name: str | None = None) -> None:
return self._inner.update(name=name)

def reader(self) -> datafusion.DataFrame:
"""
Exposes the contents of the table via a datafusion DataFrame.

Note: this is equivalent to `catalog.ctx.table(<tablename>)`.

This operation is lazy. The data will not be read from the source table until consumed
from the DataFrame.
"""
return self._inner.df()

def schema(self) -> pa.Schema:
"""Returns the schema of the table."""
return self.reader().schema()

def to_polars(self) -> Any:
"""Returns the table as a Polars DataFrame."""
return self.reader().to_polars()


AlreadyExistsError = _catalog.AlreadyExistsError
Expand Down
36 changes: 33 additions & 3 deletions rerun_py/tests/api_sandbox/rerun_draft/server.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,47 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from rerun import server as _server

from .catalog import CatalogClient

if TYPE_CHECKING:
from os import PathLike
from types import TracebackType


class Server:
__init__ = _server.Server.__init__
address = _server.Server.address
is_running = _server.Server.is_running
shutdown = _server.Server.shutdown
__enter__ = _server.Server.__enter__
__exit__ = _server.Server.__exit__

def __init__(
self,
*,
address: str = "0.0.0.0",
port: int | None = None,
datasets: dict[str, PathLike[str]] | None = None,
tables: dict[str, PathLike[str]] | None = None,
) -> None:
self._internal = _server.Server(
address=address,
port=port,
datasets=datasets,
tables=tables,
)

def __enter__(self) -> Server:
self._internal.__enter__()
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self._internal.__exit__(exc_type, exc_value, traceback)

def client(self) -> CatalogClient:
return CatalogClient(address=self.address())
60 changes: 60 additions & 0 deletions rerun_py/tests/api_sandbox/test_current/test_table_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import datafusion
import pyarrow as pa
import rerun as rr
from inline_snapshot import snapshot as inline_snapshot

if TYPE_CHECKING:
import pytest


def test_table_api(tmp_path_factory: pytest.TempPathFactory) -> None:
with rr.server.Server() as server:
client = server.client()

tmp_path = tmp_path_factory.mktemp("my_table")

table = client.create_table_entry(
"my_table",
pa.schema([
("rerun_segment_id", pa.string()),
("operator", pa.string()),
]),
tmp_path.as_uri(),
)

assert isinstance(table.df(), datafusion.DataFrame)

assert str(table.df().schema()) == inline_snapshot("""\
rerun_segment_id: string
operator: string
-- schema metadata --
sorbet:version: '0.1.1'\
""")

assert str(table.df().collect()) == inline_snapshot("[]")

client.append_to_table(
"my_table",
rerun_segment_id=["segment_001", "segment_002", "segment_003"],
operator=["alice", "bob", "carol"],
)

assert str(table.df().select("rerun_segment_id", "operator")) == inline_snapshot("""\
┌─────────────────────┬─────────────────────┐
│ rerun_segment_id ┆ operator │
│ --- ┆ --- │
│ type: nullable Utf8 ┆ type: nullable Utf8 │
╞═════════════════════╪═════════════════════╡
│ segment_001 ┆ alice │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ segment_002 ┆ bob │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ segment_003 ┆ carol │
└─────────────────────┴─────────────────────┘\
""")

assert str(table.df()) == str(client.ctx.table("my_table"))
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_catalog_basics(tmp_path: Path) -> None:
client = server.client()

client.create_dataset("my_dataset")
client.create_table_entry("my_table", pa.schema([]), tmp_path.as_uri())
client.create_table("my_table", pa.schema([]), tmp_path.as_uri())

df = client.entries()

Expand Down
6 changes: 3 additions & 3 deletions rerun_py/tests/api_sandbox/test_draft/test_polars_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_entries_to_polars(tmp_path: Path) -> None:
client = server.client()

client.create_dataset("my_dataset")
client.create_table_entry("my_table", pa.schema([]), tmp_path.as_uri())
client.create_table("my_table", pa.schema([]), tmp_path.as_uri())

df = client.entries().to_polars()

Expand Down Expand Up @@ -51,14 +51,14 @@ def test_entries_to_polars(tmp_path: Path) -> None:
def test_table_to_polars(tmp_path: Path) -> None:
with rr.server.Server() as server:
client = server.client()
client.create_table_entry(
client.create_table(
"my_table",
pa.schema([pa.field("int16", pa.int16()), pa.field("string_list", pa.list_(pa.string()))]),
tmp_path.as_uri(),
)
client.append_to_table("my_table", int16=[12], string_list=[["a", "b", "c"]])

df = client.get_table_entry(name="my_table").to_polars()
df = client.get_table(name="my_table").to_polars()

assert str(df) == inline_snapshot("""\
shape: (1, 2)
Expand Down
51 changes: 51 additions & 0 deletions rerun_py/tests/api_sandbox/test_draft/test_table_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

import datafusion
import pyarrow as pa
import rerun_draft as rr
from inline_snapshot import snapshot as inline_snapshot


def test_table_api() -> None:
with rr.server.Server() as server:
client = server.client()

table = client.create_table(
"my_table",
pa.schema([
("rerun_segment_id", pa.string()),
("operator", pa.string()),
]),
)

assert isinstance(table.reader(), datafusion.DataFrame)

assert str(table.schema()) == inline_snapshot("""\
rerun_segment_id: string
operator: string
-- schema metadata --
sorbet:version: '0.1.1'\
""")

assert str(table.reader().collect()) == inline_snapshot("[]")

table.append(
rerun_segment_id=["segment_001", "segment_002", "segment_003"],
operator=["alice", "bob", "carol"],
)

assert str(table.reader().select("rerun_segment_id", "operator")) == inline_snapshot("""\
┌─────────────────────┬─────────────────────┐
│ rerun_segment_id ┆ operator │
│ --- ┆ --- │
│ type: nullable Utf8 ┆ type: nullable Utf8 │
╞═════════════════════╪═════════════════════╡
│ segment_001 ┆ alice │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ segment_002 ┆ bob │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ segment_003 ┆ carol │
└─────────────────────┴─────────────────────┘\
""")

assert str(table.reader()) == str(client.ctx.table("my_table"))