Make the method registry the single source of truth for foundation-mo…#108
Merged
Conversation
…del row limits; fix torch deprecations and consolidate duplicated code
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR centralizes common utilities and training-row capping behavior across methods, reducing duplicated code and making per-method row limits come from the method registry.
Changes:
- Added shared
check_softmax()utility and removed duplicated inline implementations across multiple methods. - Introduced
resolve_sample_size()/subsample_train_rows()onMethodto standardize training-row caps based ontrain_row_limit(withgeneral.sample_sizeoverride). - Updated configs to remove hard-coded
sample_sizedefaults and adjusted some inference-related Torch AMP/checkpoint usage.
Reviewed changes
Copilot reviewed 39 out of 39 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| readme.md | Documents registry as source of truth for training-row caps and overrides. |
| TALENT/model/utils.py | Adds shared check_softmax() helper. |
| TALENT/model/methods/trompt.py | Switches to shared check_softmax() import. |
| TALENT/model/methods/tabptm.py | Removes duplicate torch import. |
| TALENT/model/methods/tabpfn_v3.py | Replaces local sample-size logic with subsample_train_rows(). |
| TALENT/model/methods/tabpfn_v2_5.py | Replaces local sample-size logic with subsample_train_rows(). |
| TALENT/model/methods/tabpfn_v2.py | Avoids mutable default arg; uses centralized resolve_sample_size(). |
| TALENT/model/methods/tabpfn_real.py | Avoids mutable default arg for cat_indices. |
| TALENT/model/methods/tabpfn.py | Replaces local sample-size logic with subsample_train_rows(). |
| TALENT/model/methods/tabnet.py | Removes duplicate torch import. |
| TALENT/model/methods/tabm.py | Switches to shared check_softmax() import. |
| TALENT/model/methods/tabicl_v2.py | Switches to shared check_softmax() and centralized row capping. |
| TALENT/model/methods/tabicl.py | Switches to shared check_softmax(), avoids mutable default arg, centralized row capping. |
| TALENT/model/methods/tabdpt.py | Replaces local sample-size logic with subsample_train_rows(). |
| TALENT/model/methods/tabcaps.py | Removes duplicate torch import. |
| TALENT/model/methods/ptarl.py | Removes duplicate torch import. |
| TALENT/model/methods/mitra.py | Adds centralized row capping; makes max_samples_* more robust with defaults. |
| TALENT/model/methods/limix.py | Removes duplicate torch import. |
| TALENT/model/methods/hyperfast.py | Removes duplicate torch import. |
| TALENT/model/methods/grownet.py | Removes duplicate torch import. |
| TALENT/model/methods/excelformer.py | Switches to shared check_softmax() import. |
| TALENT/model/methods/base.py | Adds resolve_sample_size() + subsample_train_rows() utilities for consistent capping. |
| TALENT/model/lib/tabpfn/utils.py | Updates checkpoint + autocast usage (incl. use_reentrant=False). |
| TALENT/model/lib/tabpfn/layer.py | Updates checkpoint usage to pass use_reentrant=False. |
| TALENT/model/lib/limix/model/transformer.py | Updates AMP autocast usage. |
| TALENT/model/classical_methods/base.py | Removes duplicated check_softmax() in favor of shared utility. |
| TALENT/configs/opt_space/tabpfn_v3.json | Removes sample_size; changes n_estimators. |
| TALENT/configs/opt_space/tabpfn_v2_5.json | Removes sample_size. |
| TALENT/configs/opt_space/tabpfn_v2.json | Removes sample_size and keeps empty general. |
| TALENT/configs/opt_space/tabpfn_real.json | Removes sample_size and keeps empty general. |
| TALENT/configs/opt_space/tabicl_v2.json | Removes sample_size. |
| TALENT/configs/opt_space/tabdpt.json | Removes sample_size. |
| TALENT/configs/default/tabpfn_v3.json | Removes sample_size. |
| TALENT/configs/default/tabpfn_v2_5.json | Removes sample_size. |
| TALENT/configs/default/tabpfn_v2.json | Removes sample_size and keeps empty general. |
| TALENT/configs/default/tabpfn_real.json | Removes sample_size and keeps empty general. |
| TALENT/configs/default/tabpfn.json | Removes sample_size and keeps empty general. |
| TALENT/configs/default/tabicl_v2.json | Removes sample_size. |
| TALENT/configs/default/tabdpt.json | Removes sample_size. |
Comments suppressed due to low confidence (1)
TALENT/model/utils.py:1
check_softmax()checks normalization usingsum(axis=-1)but computes the softmax usingaxis=1, which is inconsistent and will produce incorrect results (or errors) iflogitsisn’t strictly 2D with classes on axis 1. Use a consistent axis (typicallyaxis=-1) fornp.max/np.sumand the normalization check so the function works correctly for any(…, C)shaped input.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+102
to
+121
| def subsample_train_rows(self, X, y): | ||
| """ | ||
| Cap the training rows at ``resolve_sample_size()`` rows. | ||
|
|
||
| Classification subsamples stratified by label to keep class | ||
| proportions; regression takes a uniform random subset. Both are | ||
| seeded with ``args.seed`` for reproducibility. Returns (X, y) | ||
| unchanged when no cap applies. | ||
| """ | ||
| sample_size = self.resolve_sample_size() | ||
| if sample_size is None or X.shape[0] <= sample_size: | ||
| return X, y | ||
| if not self.is_regression: | ||
| from sklearn.model_selection import train_test_split | ||
| X, _, y, _ = train_test_split( | ||
| X, y, | ||
| train_size=sample_size, | ||
| stratify=y, | ||
| random_state=self.args.seed, | ||
| ) |
Comment on lines
+82
to
+92
| def resolve_sample_size(self): | ||
| """ | ||
| Resolve the effective training-row cap for this method. | ||
|
|
||
| An explicit ``config['general']['sample_size']`` takes precedence as a | ||
| per-run override; otherwise the method's ``train_row_limit`` from the | ||
| method registry applies (the single source of truth for row limits). | ||
| Returns None when neither is set (no cap). | ||
| """ | ||
| general = self.args.config.get('general', {}) or {} | ||
| sample_size = general.get('sample_size') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR eliminates drift between the method registry and the config JSONs for foundation-model row limits, fixes the
torch.cuda.amp/use_reentrantdeprecation warnings, and removes duplicated code found during a package-wide consistency audit. No new features — only correctness, consistency, and a single source of truth.1. Registry as single source of truth for training-row caps
Previously each foundation model's row cap lived in two places —
train_row_limitinmodel/method_registry.pyandsample_sizein the default/opt_space config JSONs — and several had silently drifted (e.g.tabicl_v2.jsonsaid 60k while the registry said 1M;tabdpt.jsoncapped at 100k while TabDPT is unlimited).Method.resolve_sample_size()/Method.subsample_train_rows()helpers inmodel/methods/base.py: an explicitconfig['general']['sample_size']acts as a per-run override; otherwise the registry'strain_row_limitapplies.tabpfn,tabpfn_v2,tabpfn_real,tabpfn_v2_5,tabpfn_v3,tabicl,tabicl_v2,tabdpt,mitra) now share this one seeded, stratified-for-classification subsampling path.sample_sizeremoved from all default and opt_space config JSONs, so configs can no longer disagree with the registry.Effective caps after this PR (registry-governed):
2. Torch deprecation fixes
torch.cuda.amp.autocast(enabled=...)→torch.amp.autocast("cuda", enabled=...)inmodel/lib/tabpfn/utils.pyandmodel/lib/limix/model/transformer.py.use_reentrant=Falseadded to alltorch.utils.checkpoint.checkpoint(...)call sites inmodel/lib/tabpfn/utils.pyandmodel/lib/tabpfn/layer.py(including thepartial(checkpoint, ...)variant). This silences the current warnings and is required before torch makes the default a hard error.3. Consistency / dead-code cleanup
check_softmaxwas defined 7 times across the package; it now lives once inmodel/utils.pyand is imported everywhere else (names remain importable from their old locations).import torchstatements from 11 method files.cat_indices=[]) replaced with theNonepattern intabpfn_v2.py,tabpfn_real.py,tabicl.py(matchingtabpfn_v3.py).predict()no longer raisesKeyErrorwhenmax_samples_support/max_samples_queryare absent from the config (falls back to 8192 / 1024).opt_space/tabpfn_v3.jsonhad drifted from the default config (n_estimators4 vs 32,sample_size50k vs 1M) — aligned.args.seedlike every other method.train_row_limitin the registry governs row caps andsample_sizeis a per-run override.Verification
python -m compileallclean overmodel/methods,model/classical_methods,model/utils.py,model/method_registry.py,api.py.sample_sizekeys remain.padding_obs_query__inmitra.py— confirmed it matches the upstream MitraTab2D.forwardsignature (not a typo; unchanged).Fixes
feature-attention CUDA kernel batch, fixing
CUDA error: invalid configuration argumenton wide datasets with large test sets (predictions are identical; rows are scored
independently).
UndefinedMetricWarnings by passingzero_division=0toprecision_scorein bothmetric()implementations (reported values unchanged —sklearn already returned 0 for ill-defined precision).