Add Barspoon for multi-target prediction and several updates#154
Pull request overview
This PR adds multi-target classification support to STAMP via a new Barspoon encoder-decoder transformer model and makes several architectural improvements. The changes enable predicting multiple classification targets simultaneously (e.g., subtype and grade) while maintaining backward compatibility with single-target tasks.
Changes:
- Introduced Barspoon model for multi-target classification with encoder-decoder transformer architecture
- Refactored data loading to handle multi-target ground truths as dictionaries
- Improved type safety by removing most `type: ignore` pragmas and adding explicit type annotations
- Optimized I/O with HDF5 handle caching and moved utility modules (seed, config, cache) to `stamp.utils`
- Corrected survival analysis stratification to use event/status instead of raw strings
- Added RedDino and KEEP feature extractors
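The multi-target ground-truth refactor above can be pictured with a small sketch. The target names (`subtype`, `grade`) and the `collate_targets` helper are illustrative, not part of STAMP's API:

```python
# Hypothetical sketch of multi-target ground truths stored as dicts:
# one dict per sample, mapping each target label to its class label.
from collections.abc import Mapping


def collate_targets(samples: list[Mapping[str, str]]) -> dict[str, list[str]]:
    """Group per-sample target dicts into per-target label lists."""
    targets: dict[str, list[str]] = {}
    for sample in samples:
        for label, value in sample.items():
            targets.setdefault(label, []).append(value)
    return targets


batch = [
    {"subtype": "lobular", "grade": "G2"},
    {"subtype": "ductal", "grade": "G3"},
]
print(collate_targets(batch))
# {'subtype': ['lobular', 'ductal'], 'grade': ['G2', 'G3']}
```

A dict-of-lists layout like this keeps single-target tasks as the trivial one-key case, which is one way to preserve backward compatibility.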
Reviewed changes
Copilot reviewed 51 out of 54 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `src/stamp/modeling/models/barspoon.py` | New encoder-decoder transformer for multi-target classification |
| `src/stamp/modeling/data.py` | Refactored data loading with multi-target support and HDF5 caching |
| `src/stamp/modeling/train.py` | Updated training pipeline to handle multi-target categories |
| `src/stamp/modeling/deploy.py` | Extended deployment to support multi-target predictions |
| `src/stamp/modeling/crossval.py` | Fixed survival stratification and added multi-target support |
| `src/stamp/statistics/__init__.py` | Added multi-target statistics computation |
| `src/stamp/statistics/categorical.py` | Implemented per-target metrics aggregation |
| `src/stamp/statistics/roc.py` | Optimized bootstrap sampling with pre-allocation |
| `src/stamp/statistics/prc.py` | Replaced `scipy.interp1d` with `numpy.interp` for simplicity |
| `src/stamp/utils/*.py` | Refactored utility modules (seed, config, cache) |
| `src/stamp/preprocessing/extractor/*.py` | Added RedDino and KEEP extractors |
| `tests/*.py` | Comprehensive test coverage for multi-target functionality |
| `pyproject.toml`, `uv.lock` | Version bump to 2.4.1 |
Comments suppressed due to low confidence (8)
tests/test_train_deploy.py:130
- Tests are suppressing a warning about type hint violations for tuples. This suggests there may be a runtime type-checking issue (likely from beartype) where survival ground truths (tuples) don't match the expected type annotation. While suppressing the warning allows tests to pass, the underlying type annotations should be reviewed to ensure they correctly represent the actual data structures, especially for survival analysis where ground truth is now stored as `(time, event)` tuples.
src/stamp/modeling/models/barspoon.py:137
- The class token initialization uses `torch.rand()`, which produces values in [0, 1). For learnable parameters that will be optimized, it's generally better to use a proper initialization scheme such as Xavier/Glorot or Kaiming. Uniform [0, 1) values can lead to suboptimal training dynamics, especially in deeper networks. Consider `nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, d_model)).squeeze(0))` or similar (note that `xavier_uniform_` requires a tensor with at least two dimensions).

```python
sanitize(target_label): torch.rand(d_model)
```
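A minimal sketch of the suggested Xavier initialization, assuming one learnable token per target label; the label names and `d_model` are illustrative, and the `(1, d_model)` shape plus squeeze works around `xavier_uniform_` requiring a 2-D tensor:

```python
import torch
import torch.nn as nn

d_model = 8
target_labels = ["subtype", "grade"]  # illustrative target names

# One Xavier-initialized class token per target, instead of torch.rand().
class_tokens = nn.ParameterDict({
    label: nn.Parameter(
        nn.init.xavier_uniform_(torch.empty(1, d_model)).squeeze(0)
    )
    for label in target_labels
})

# Xavier ties the initial scale to d_model rather than fixing it to [0, 1).
print({label: tuple(p.shape) for label, p in class_tokens.items()})
```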
src/stamp/modeling/models/barspoon.py:81
- Typo in docstring: "percieves" should be "perceives".

```python
Since neither reduced performance and the author percieves the first one to
```
src/stamp/modeling/models/barspoon.py:86
- Typo in docstring: "descibed" should be "described".

```python
The architecture _differs_ from the one descibed in [Attention Is All You
```
src/stamp/modeling/data.py:250
- The comment at line 247 says "Do NOT call .split here", and the code at line 250 then sets `status_str = "nan"` without attempting to split. This appears intentional to handle the structured tuple case, but the logic could be stated more clearly: if `gt` is already a tuple/list (line 237), it is handled there; otherwise it is converted to a string and the status is marked as unknown. Consider clarifying the comment to explain that this branch is the fallback for non-tuple ground truths.

```python
# Legacy string form supported historically, but prefer the
# structured (time, event) tuple. Do NOT call .split here;
# treat the entire value as the time string and mark status
# as unknown. This avoids AttributeError when gt is a tuple.
time_str, status_str = str(gt), "nan"
```
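The two-branch behavior the comment describes could be made explicit in a small helper; `parse_survival_gt` is a hypothetical name sketching the structured-tuple path and the legacy-string fallback:

```python
def parse_survival_gt(gt: object) -> tuple[str, str]:
    """Return (time_str, status_str) for a survival ground truth."""
    # Preferred structured form: a (time, event) pair.
    if isinstance(gt, (tuple, list)) and len(gt) == 2:
        time, event = gt
        return str(time), str(event)
    # Legacy fallback: treat the whole value as the time string and
    # mark the status as unknown, without calling .split on it.
    return str(gt), "nan"


print(parse_survival_gt((34.5, 1)))  # ('34.5', '1')
print(parse_survival_gt("34.5"))     # ('34.5', 'nan')
```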
src/stamp/modeling/data.py:598
- The HDF5 handle cache uses `popitem()` to remove an arbitrary entry when the cache reaches 128 entries. This can be problematic if the dataset iteration pattern revisits files frequently, as the least-recently-used file might not be the one removed. Consider an LRU cache or `collections.OrderedDict` with `popitem(last=False)` so the oldest entries are evicted first, which better matches typical access patterns.

```python
if bag_file not in self._h5_handle_cache:
    # Limit open handles to avoid reaching OS ulimits
    if len(self._h5_handle_cache) >= 128:
        _, h = self._h5_handle_cache.popitem()
        h.close()
```
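A sketch of the suggested LRU policy, wrapped for illustration in a hypothetical `HandleCache` class (the real cache lives inline in `data.py`); `opener` stands in for `h5py.File`:

```python
from collections import OrderedDict


class HandleCache:
    """Keep at most `max_handles` open handles, evicting least recently used."""

    def __init__(self, opener, max_handles: int = 128):
        self._open = opener
        self._max = max_handles
        self._cache: OrderedDict = OrderedDict()

    def get(self, path: str):
        if path in self._cache:
            # Move to the end so it counts as most recently used.
            self._cache.move_to_end(path)
            return self._cache[path]
        if len(self._cache) >= self._max:
            # Evict the oldest (least recently used) handle first.
            _, handle = self._cache.popitem(last=False)
            handle.close()
        handle = self._open(path)
        self._cache[path] = handle
        return handle
```

With `popitem(last=False)` plus `move_to_end` on hits, a frequently revisited file stays cached, whereas a plain `popitem()` can evict the most recently inserted entry.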
src/stamp/modeling/crossval.py:339
- Typo in error message: "Grounf" should be "Ground".

```python
raise RuntimeError("Grounf truth label is required for regression")
```
src/stamp/modeling/crossval.py:356
- Typo in error message: "Grounf" should be "Ground".

```python
"Grounf truth label is required for classification"
```
This PR implements multi-target classification as its primary objective, which required substantial modifications to the data loader structure and the training pipeline. It also marks the first time a coding agent was used to assist with architectural refactoring and performance optimization. Given the number of updates, STAMP will be bumped to 2.4.1.