Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 127 additions & 17 deletions tools/statvar_importer/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@
flags.DEFINE_integer(
'sampler_rows_per_key', 5,
'The maximum number of rows to select for each unique value found.')
flags.DEFINE_integer(
'sampler_uniques_per_column', 10,
'The maximum number of unique values to track per column. '
'If 0 or -1, all unique values are tracked.')
flags.DEFINE_boolean(
'sampler_exhaustive', False,
'If True, sets sampler_output_rows and sampler_uniques_per_column to '
'infinity, and sampler_rows_per_key to 1, to capture every unique value.')
flags.DEFINE_float(
'sampler_rate', -1,
'The sampling rate for random row selection (e.g., 0.1 for 10%).')
Expand All @@ -65,6 +73,11 @@
flags.DEFINE_string(
'sampler_unique_columns', '',
'A comma-separated list of column names to use for selecting unique rows.')
flags.DEFINE_list(
'sampler_column_keys', [],
'A list of "column:file" pairs containing values that MUST be included '
'in the sample if they appear in the input data. '
'Example: "variableMeasured:prominent_svs.txt"')
flags.DEFINE_string('sampler_input_delimiter', ',',
'The delimiter used in the input CSV file.')
flags.DEFINE_string('sampler_input_encoding', 'UTF8',
Expand All @@ -75,6 +88,7 @@
_FLAGS = flags.FLAGS

import file_util
import mcf_file_util

from config_map import ConfigMap
from counters import Counters
Expand Down Expand Up @@ -103,27 +117,37 @@ def __init__(
self,
config_dict: dict = None,
counters: Counters = None,
column_include_values: dict = None,
):
"""Initializes the DataSampler object.

Args:
config_dict: A dictionary of configuration parameters.
counters: A Counters object for tracking statistics.
column_include_values: a dictionary of column-name to set of values
in the column to be included in the sample
"""
self._config = ConfigMap(config_dict=get_default_config())
if config_dict:
self._config.add_configs(config_dict)
self._counters = counters if counters is not None else Counters()
self._column_include_values = column_include_values
self.reset()

def reset(self) -> None:
"""Resets the state of the DataSampler.

This method resets the internal state of the DataSampler, including the
counts of unique column values and the number of selected rows. This is
useful when you want to reuse the same DataSampler instance for sampling
multiple files.
counts of unique column values and the number of selected rows. If
sampler_exhaustive is set in the configuration, it applies overrides
to other configuration parameters to capture all unique values.
"""
if self._config.get('sampler_exhaustive'):
# Exhaustive mode overrides limits to capture all unique values.
self._config.set_config('sampler_output_rows', -1)
self._config.set_config('sampler_uniques_per_column', -1)
self._config.set_config('sampler_rows_per_key', 1)

# Dictionary of unique values: count per column
self._column_counts = {}
# Dictionary of column index: list of header strings
Expand All @@ -144,6 +168,16 @@ def reset(self) -> None:
if col.strip()
]

# Must include values: dict of column_name -> set of values
self._must_include_values = load_column_keys(
self._config.get('sampler_column_keys', []))
if self._column_include_values:
for col, vals in self._column_include_values.items():
self._must_include_values.setdefault(col, set()).update(vals)

# Map of column index -> set of values
self._must_include_indices = {}

def __del__(self) -> None:
"""Logs the column headers and counts upon object deletion."""
logging.log(2, f'Sampler column headers: {self._column_headers}')
Expand Down Expand Up @@ -179,22 +213,35 @@ def _get_column_count(self, column_index: int, value: str) -> int:
return 0
return col_values.get(value, 0)

def _should_track_column(self, column_index: int) -> bool:
"""Determines if a column should be tracked for unique values.
def _is_unique_column(self, column_index: int) -> bool:
"""Determines if a column is specified for unique value sampling.

Args:
column_index: The index of the column.

Returns:
True if the column should be tracked (either no unique columns
specified or this column is in the unique columns list).
True if the column should be sampled for unique values.
"""
if not self._unique_column_names:
# No specific columns specified, track all
# No specific columns specified, track all for unique sampling
return True
# Check if this column is in our unique columns
return column_index in self._unique_column_indices.values()

def _should_track_column(self, column_index: int) -> bool:
    """Returns True if counts should be kept for the given column.

    A column is tracked when it is eligible for unique-value sampling or
    when it carries must-include values.

    Args:
        column_index: The index of the column.

    Returns:
        True if the column is a unique-sampling column or a must-include
        column.
    """
    is_must_include_col = column_index in self._must_include_indices
    return self._is_unique_column(column_index) or is_must_include_col

def _process_header_row(self, row: list[str]) -> None:
"""Process a header row to build column name to index mapping.

Expand All @@ -206,15 +253,27 @@ def _process_header_row(self, row: list[str]) -> None:
Args:
row: A header row containing column names.
"""
if not self._unique_column_names:
return

for index, column_name in enumerate(row):
if column_name in self._unique_column_names:
if self._unique_column_names and column_name in self._unique_column_names:
self._unique_column_indices[column_name] = index
logging.level_debug() and logging.debug(
f'Mapped unique column "{column_name}" to index {index}')

if self._must_include_values and column_name in self._must_include_values:
self._must_include_indices[index] = self._must_include_values[
column_name]
logging.info(
f'Mapped must-include column "{column_name}" to index {index}'
)

def _is_must_include(self, column_index: int, value: str) -> bool:
    """Returns True if the value is a must-include key for this column."""
    include_values = self._must_include_indices.get(column_index)
    if include_values is None:
        # Column has no must-include values configured.
        return False
    # Strip any namespace prefix from the value before the set lookup.
    return mcf_file_util.strip_namespace(value) in include_values

def _add_column_header(self, column_index: int, value: str) -> str:
"""Adds the first non-empty value of a column as its header.

Expand Down Expand Up @@ -282,13 +341,26 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool:
# Too many rows already selected. Drop it.
return False
max_count = self._config.get('sampler_rows_per_key', 3)
if max_count <= 0:
max_count = sys.maxsize
max_uniques_per_col = self._config.get('sampler_uniques_per_column', 10)
if max_uniques_per_col <= 0:
max_uniques_per_col = sys.maxsize

for index in range(len(row)):
# Skip columns not in unique_columns list
if not self._should_track_column(index):
continue
value = row[index]
value_count = self._get_column_count(index, value)

# Rule 1: Always include if it's a must-include value and
# we haven't reached per-key limit.
if value_count < max_count and self._is_must_include(index, value):
self._counters.add_counter('sampler-selected-must-include', 1)
return True

# Skip columns not in unique_columns list for general unique sampling
if not self._is_unique_column(index):
continue

if value_count == 0 or value_count < max_count:
# This is a new value for this column.
col_counts = self._column_counts.get(index, {})
Expand All @@ -301,7 +373,7 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool:
# No new unique value for the row.
# Check random sampler.
if sample_rate < 0:
sample_rate = self._config.get('sampler_rate')
sample_rate = self._config.get('sampler_rate', -1)
if random.random() <= sample_rate:
self._counters.add_counter('sampler-sampled-rows', 1)
return True
Expand Down Expand Up @@ -425,6 +497,36 @@ def sample_csv_file(self, input_file: str, output_file: str = '') -> str:
return output_file


def load_column_keys(column_keys: list) -> dict:
    """Returns a dictionary of column name to set of keys loaded from a file.

    The set of keys for a column is used as a filter when sampling.

    Args:
        column_keys: list (or comma separated string) of column_name:<csv file>
            entries. The first column of each CSV file contains the keys to
            be loaded for that column.

    Returns:
        dictionary of column name to a set of keys for that column
        { <column-name1>: { key1, key2, ...}, <column-name2>: { k1, k2...} ...}
        Malformed entries (no ':' separator or an empty file name) are
        logged and skipped.
    """
    column_map = {}
    if not isinstance(column_keys, list):
        # Allow a comma separated string such as 'col1:file1,col2:file2'.
        column_keys = column_keys.split(',')

    for col_file in column_keys:
        # partition() never raises, unlike split(':', 1) which raised
        # ValueError on unpacking when an entry had no ':' separator.
        column_name, sep, file_name = col_file.partition(':')
        if not sep or not file_name:
            logging.error(f'No file for column {column_name} in {column_keys}')
            continue

        col_items = file_util.file_load_csv_dict(file_name)
        column_map[column_name] = set(col_items.keys())
        logging.info(
            f'Loaded {len(col_items)} for column {column_name} from {file_name}'
        )
    return column_map


def sample_csv_file(input_file: str,
output_file: str = '',
config: dict = None) -> str:
Expand All @@ -443,6 +545,8 @@ def sample_csv_file(input_file: str,
- sampler_output_rows: The maximum number of rows to include in the
sample.
- sampler_rate: The sampling rate to use for random selection.
- sampler_exhaustive: If True, overrides limits to capture all unique
values.
- header_rows: The number of header rows to copy from the input file
and search for sampler_unique_columns. Increase this if column names
appear in later header rows (e.g., after a title row).
Expand Down Expand Up @@ -500,20 +604,26 @@ def get_default_config() -> dict:
# Use default values of flags for tests
if not _FLAGS.is_parsed():
_FLAGS.mark_as_parsed()
return {

config = {
'sampler_rate': _FLAGS.sampler_rate,
'sampler_input': _FLAGS.sampler_input,
'sampler_output': _FLAGS.sampler_output,
'sampler_output_rows': _FLAGS.sampler_output_rows,
'header_rows': _FLAGS.sampler_header_rows,
'sampler_rows_per_key': _FLAGS.sampler_rows_per_key,
'sampler_uniques_per_column': _FLAGS.sampler_uniques_per_column,
'sampler_column_regex': _FLAGS.sampler_column_regex,
'sampler_unique_columns': _FLAGS.sampler_unique_columns,
'sampler_column_keys': _FLAGS.sampler_column_keys,
'input_delimiter': _FLAGS.sampler_input_delimiter,
'output_delimiter': _FLAGS.sampler_output_delimiter,
'input_encoding': _FLAGS.sampler_input_encoding,
'sampler_exhaustive': _FLAGS.sampler_exhaustive,
}

return config


def main(_):
    """Samples rows from --sampler_input into --sampler_output per flag settings."""
    sample_csv_file(_FLAGS.sampler_input, _FLAGS.sampler_output)
Expand Down
Loading
Loading