diff --git a/tools/statvar_importer/data_sampler.py b/tools/statvar_importer/data_sampler.py index 671dbc3091..cf8d775a7f 100644 --- a/tools/statvar_importer/data_sampler.py +++ b/tools/statvar_importer/data_sampler.py @@ -54,6 +54,14 @@ flags.DEFINE_integer( 'sampler_rows_per_key', 5, 'The maximum number of rows to select for each unique value found.') +flags.DEFINE_integer( + 'sampler_uniques_per_column', 10, + 'The maximum number of unique values to track per column. ' + 'If 0 or -1, all unique values are tracked.') +flags.DEFINE_boolean( + 'sampler_exhaustive', False, + 'If True, sets sampler_output_rows and sampler_uniques_per_column to ' + 'infinity, and sampler_rows_per_key to 1, to capture every unique value.') flags.DEFINE_float( 'sampler_rate', -1, 'The sampling rate for random row selection (e.g., 0.1 for 10%).') @@ -65,6 +73,11 @@ flags.DEFINE_string( 'sampler_unique_columns', '', 'A comma-separated list of column names to use for selecting unique rows.') +flags.DEFINE_list( + 'sampler_column_keys', [], + 'A list of "column:file" pairs containing values that MUST be included ' + 'in the sample if they appear in the input data. ' + 'Example: "variableMeasured:prominent_svs.txt"') flags.DEFINE_string('sampler_input_delimiter', ',', 'The delimiter used in the input CSV file.') flags.DEFINE_string('sampler_input_encoding', 'UTF8', @@ -75,6 +88,7 @@ _FLAGS = flags.FLAGS import file_util +import mcf_file_util from config_map import ConfigMap from counters import Counters @@ -103,27 +117,37 @@ def __init__( self, config_dict: dict = None, counters: Counters = None, + column_include_values: dict = None, ): """Initializes the DataSampler object. Args: config_dict: A dictionary of configuration parameters. counters: A Counters object for tracking statistics. 
+ column_include_values: a dictionary of column-name to set of values + in the column to be included in the sample """ self._config = ConfigMap(config_dict=get_default_config()) if config_dict: self._config.add_configs(config_dict) self._counters = counters if counters is not None else Counters() + self._column_include_values = column_include_values self.reset() def reset(self) -> None: """Resets the state of the DataSampler. This method resets the internal state of the DataSampler, including the - counts of unique column values and the number of selected rows. This is - useful when you want to reuse the same DataSampler instance for sampling - multiple files. + counts of unique column values and the number of selected rows. If + sampler_exhaustive is set in the configuration, it applies overrides + to other configuration parameters to capture all unique values. """ + if self._config.get('sampler_exhaustive'): + # Exhaustive mode overrides limits to capture all unique values. + self._config.set_config('sampler_output_rows', -1) + self._config.set_config('sampler_uniques_per_column', -1) + self._config.set_config('sampler_rows_per_key', 1) + # Dictionary of unique values: count per column self._column_counts = {} # Dictionary of column index: list of header strings @@ -144,6 +168,16 @@ def reset(self) -> None: if col.strip() ] + # Must include values: dict of column_name -> set of values + self._must_include_values = load_column_keys( + self._config.get('sampler_column_keys', [])) + if self._column_include_values: + for col, vals in self._column_include_values.items(): + self._must_include_values.setdefault(col, set()).update(vals) + + # Map of column index -> set of values + self._must_include_indices = {} + def __del__(self) -> None: """Logs the column headers and counts upon object deletion.""" logging.log(2, f'Sampler column headers: {self._column_headers}') @@ -179,22 +213,35 @@ def _get_column_count(self, column_index: int, value: str) -> int: return 0 return 
col_values.get(value, 0) - def _should_track_column(self, column_index: int) -> bool: - """Determines if a column should be tracked for unique values. + def _is_unique_column(self, column_index: int) -> bool: + """Determines if a column is specified for unique value sampling. Args: column_index: The index of the column. Returns: - True if the column should be tracked (either no unique columns - specified or this column is in the unique columns list). + True if the column should be sampled for unique values. """ if not self._unique_column_names: - # No specific columns specified, track all + # No specific columns specified, track all for unique sampling return True # Check if this column is in our unique columns return column_index in self._unique_column_indices.values() + def _should_track_column(self, column_index: int) -> bool: + """Determines if a column should be tracked for counts. + + Args: + column_index: The index of the column. + + Returns: + True if the column should be tracked for unique values or is a + must-include column. + """ + if self._is_unique_column(column_index): + return True + return column_index in self._must_include_indices + def _process_header_row(self, row: list[str]) -> None: """Process a header row to build column name to index mapping. @@ -206,15 +253,27 @@ def _process_header_row(self, row: list[str]) -> None: Args: row: A header row containing column names. 
""" - if not self._unique_column_names: - return - for index, column_name in enumerate(row): - if column_name in self._unique_column_names: + if self._unique_column_names and column_name in self._unique_column_names: self._unique_column_indices[column_name] = index logging.level_debug() and logging.debug( f'Mapped unique column "{column_name}" to index {index}') + if self._must_include_values and column_name in self._must_include_values: + self._must_include_indices[index] = self._must_include_values[ + column_name] + logging.info( + f'Mapped must-include column "{column_name}" to index {index}' + ) + + def _is_must_include(self, column_index: int, value: str) -> bool: + """Checks if a column value is in the must-include list.""" + if column_index not in self._must_include_indices: + return False + # Normalize the input value before checking against the set + return mcf_file_util.strip_namespace( + value) in self._must_include_indices[column_index] + def _add_column_header(self, column_index: int, value: str) -> str: """Adds the first non-empty value of a column as its header. @@ -282,13 +341,26 @@ def select_row(self, row: list[str], sample_rate: float = -1) -> bool: # Too many rows already selected. Drop it. return False max_count = self._config.get('sampler_rows_per_key', 3) + if max_count <= 0: + max_count = sys.maxsize max_uniques_per_col = self._config.get('sampler_uniques_per_column', 10) + if max_uniques_per_col <= 0: + max_uniques_per_col = sys.maxsize + for index in range(len(row)): - # Skip columns not in unique_columns list - if not self._should_track_column(index): - continue value = row[index] value_count = self._get_column_count(index, value) + + # Rule 1: Always include if it's a must-include value and + # we haven't reached per-key limit. 
def load_column_keys(column_keys: list) -> dict:
    """Returns a dictionary of column name to a set of keys loaded from a file.

    The set of keys for a column is used as a must-include filter when
    sampling: rows whose column value matches one of the keys are selected.

    Args:
      column_keys: list (or comma separated string) of '<column>:<file>'
        entries, where the first column of each file holds the keys to load
        for that column.

    Returns:
      dictionary of column name to a set of keys for that column:
      { '<column1>': { key1, key2, ...}, '<column2>': { k1, k2, ...} ...}
      Entries without a file component are logged and skipped.
    """
    column_map = {}
    if not isinstance(column_keys, list):
        # Allow a plain comma-separated string as input.
        column_keys = column_keys.split(',')

    for col_file in column_keys:
        # Each entry is '<column>:<file>'. partition() tolerates a missing
        # separator (split(':', 1) would raise ValueError on 'column' alone).
        column_name, _, file_name = col_file.partition(':')
        if not file_name:
            logging.error(f'No file for column {column_name} in {column_keys}')
            continue

        # The keys are the values in the first column of the file.
        col_items = file_util.file_load_csv_dict(file_name)
        column_map[column_name] = set(col_items.keys())
        logging.info(
            f'Loaded {len(col_items)} keys for column {column_name} from {file_name}'
        )
    return column_map
@@ -500,20 +604,26 @@ def get_default_config() -> dict: # Use default values of flags for tests if not _FLAGS.is_parsed(): _FLAGS.mark_as_parsed() - return { + + config = { 'sampler_rate': _FLAGS.sampler_rate, 'sampler_input': _FLAGS.sampler_input, 'sampler_output': _FLAGS.sampler_output, 'sampler_output_rows': _FLAGS.sampler_output_rows, 'header_rows': _FLAGS.sampler_header_rows, 'sampler_rows_per_key': _FLAGS.sampler_rows_per_key, + 'sampler_uniques_per_column': _FLAGS.sampler_uniques_per_column, 'sampler_column_regex': _FLAGS.sampler_column_regex, 'sampler_unique_columns': _FLAGS.sampler_unique_columns, + 'sampler_column_keys': _FLAGS.sampler_column_keys, 'input_delimiter': _FLAGS.sampler_input_delimiter, 'output_delimiter': _FLAGS.sampler_output_delimiter, 'input_encoding': _FLAGS.sampler_input_encoding, + 'sampler_exhaustive': _FLAGS.sampler_exhaustive, } + return config + def main(_): sample_csv_file(_FLAGS.sampler_input, _FLAGS.sampler_output) diff --git a/tools/statvar_importer/data_sampler_test.py b/tools/statvar_importer/data_sampler_test.py index 7886f03937..6eabe16534 100644 --- a/tools/statvar_importer/data_sampler_test.py +++ b/tools/statvar_importer/data_sampler_test.py @@ -214,32 +214,157 @@ def test_unique_columns_partial_match_raises_error(self): self.assertNotIn('Name', error_msg) # Name should be found, not in error - @unittest.skip("TODO: Implement rows per key in DataSampler.") def test_rows_per_key(self): """Tests that the sampler respects the sampler_rows_per_key config.""" - config = {'sampler_rows_per_key': 2} - data_sampler.sample_csv_file(self.input_file, self.output_file, config) + # Use a controlled input + input_file = os.path.join(self._tmp_dir, 'rows_per_key.csv') + with open(input_file, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['Key', 'Value']) + writer.writerow(['A', 'v1']) + writer.writerow(['A', 'v2']) + writer.writerow(['A', 'v3']) + writer.writerow(['B', 'v4']) + + config = { + 'sampler_rows_per_key': 
2, + 'sampler_unique_columns': 'Key', + 'sampler_output_rows': -1, + 'header_rows': 1, + } + data_sampler.sample_csv_file(input_file, self.output_file, config) with open(self.output_file) as f: lines = f.readlines() - # The input file has 3 unique states. With sampler_rows_per_key=2, - # we expect 2 rows for each state, plus the header. - self.assertLessEqual(len(lines), 3 * 2 + 1) + # 2 rows for 'A', 1 row for 'B' + 1 header = 4 lines + self.assertEqual(len(lines), 4) - @unittest.skip("TODO: Implement cell value regex filtering in DataSampler.") def test_cell_value_regex(self): """Tests that sampler_column_regex filters based on cell values.""" - # This test checks if the sampler correctly uses the regex to identify - # and select rows based on the content of their cells. + input_file = os.path.join(self._tmp_dir, 'regex_test.csv') + with open(input_file, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['Key', 'Value']) + writer.writerow(['2021', 'v1']) + writer.writerow(['2022', 'v2']) + writer.writerow(['abc', 'v3']) + writer.writerow(['123', 'v4']) + config = { - 'sampler_column_regex': r'^\d{4}$' - } # Regex for a 4-digit year - data_sampler.sample_csv_file(self.input_file, self.output_file, config) + 'sampler_column_regex': r'^\d{4}$', + 'sampler_output_rows': -1, + 'header_rows': 1, + } + data_sampler.sample_csv_file(input_file, self.output_file, config) + with open(self.output_file) as f: + lines = f.readlines() + # Header + 2021 + 2022 = 3 lines. '123' is only 3 digits. + self.assertEqual(len(lines), 3) + self.assertIn('2021,v1\n', lines) + self.assertIn('2022,v2\n', lines) + + def test_exhaustive_mode(self): + """Tests that exhaustive mode captures all unique values.""" + input_file = os.path.join(self._tmp_dir, 'exhaustive.csv') + with open(input_file, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['Key', 'Value']) + for i in range(150): + writer.writerow([f'Key{i}', f'Value{i}']) + + # Default max rows is 100. 
Exhaustive should take all 150. + config = { + 'sampler_exhaustive': True, + 'sampler_unique_columns': 'Key', + 'header_rows': 1, + } + data_sampler.sample_csv_file(input_file, self.output_file, config) + with open(self.output_file) as f: + lines = f.readlines() + self.assertEqual(len(lines), 151) + + def test_must_include_values(self): + """Tests that the sampler always includes must-include values.""" + input_file = os.path.join(self._tmp_dir, 'must_include.csv') + with open(input_file, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['Key', 'Value']) + writer.writerow(['A', 'v1']) + writer.writerow(['B', 'v2']) + writer.writerow(['C', 'v3']) + + # Include list for column 'Key' + include_file = os.path.join(self._tmp_dir, 'include.txt') + with open(include_file, 'w') as f: + f.write('Key\n') # Add header + f.write('C\n') + + config = { + 'sampler_column_keys': [f'Key:{include_file}'], + 'sampler_rate': 0, # Disable random sampling + 'sampler_unique_columns': + 'Value', # Track Value for unique sampling + 'sampler_uniques_per_column': + 1, # Only first row ('A', 'v1') will be unique + 'sampler_output_rows': 10, + 'header_rows': 1, + } + + data_sampler.sample_csv_file(input_file, self.output_file, config) + with open(self.output_file) as f: + lines = f.readlines() + # Header + 'A' (unique) + 'C' (must-include) = 3 lines + # 'B' is skipped because it's not unique enough for 'Value' column + # (since v1 was already taken) and not in must-include. 
+ self.assertEqual(len(lines), 3) + self.assertIn('A,v1\n', lines) + self.assertIn('C,v3\n', lines) + self.assertNotIn('B,v2\n', lines) + + def test_uniques_per_column(self): + """Tests that the sampler respects sampler_uniques_per_column.""" + input_file = os.path.join(self._tmp_dir, 'uniques_per_col.csv') + with open(input_file, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['Key', 'Value']) + writer.writerow(['A', 'v1']) + writer.writerow(['B', 'v2']) + writer.writerow(['C', 'v3']) + writer.writerow(['D', 'v4']) + + config = { + 'sampler_uniques_per_column': 2, + 'sampler_unique_columns': 'Key', + 'sampler_output_rows': -1, + 'header_rows': 1, + } + data_sampler.sample_csv_file(input_file, self.output_file, config) with open(self.output_file) as f: lines = f.readlines() - # This would select the header and the one row in the test data that - # contains a year-like value. - self.assertEqual(len(lines), 2) - self.assertIn('2011', lines[1]) + # Header + 2 unique values = 3 lines + self.assertEqual(len(lines), 3) + + def test_load_column_keys(self): + """Tests that load_column_keys correctly parses the include list.""" + file1 = os.path.join(self._tmp_dir, 'file1.csv') + with open(file1, 'w') as f: + f.write('col1\nval1\nval2\n') + + file2 = os.path.join(self._tmp_dir, 'file2.csv') + with open(file2, 'w') as f: + f.write('col2\nval3\n') + + column_keys = [f'col1:{file1}', f'col2:{file2}'] + result = data_sampler.load_column_keys(column_keys) + + self.assertEqual(result, {'col1': {'val1', 'val2'}, 'col2': {'val3'}}) + + def test_get_default_config(self): + """Tests that get_default_config returns the expected dictionary.""" + config = data_sampler.get_default_config() + self.assertIn('sampler_rate', config) + self.assertIn('sampler_output_rows', config) + self.assertEqual(config['sampler_output_rows'], + 100) # Default flag value def test_non_existent_input_file(self): """Tests that the sampler handles a non-existent input file."""