Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions linear_operator/operators/block_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,26 @@ def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: I
# Let's make sure that the slice dimensions perfectly correspond with the number of
# outputs per input that we have
# Otherwise - it's too complicated. We'll go with the base case
if (row_start % num_blocks) or (col_start % num_blocks) or (row_end % num_blocks) or (col_end % num_blocks):
block_size = num_rows // num_blocks
if (row_start % block_size) or (col_start % block_size) or (row_end % block_size) or (col_end % block_size):
return super()._getitem(row_index, col_index, *batch_indices)

# Otherwise - let's divide the slices by the number of outputs per input
row_index = slice(row_start // num_blocks, row_end // num_blocks, None)
col_index = slice(col_start // num_blocks, col_end // num_blocks, None)
# Compute block-level indices
block_row_idx = slice(row_start // block_size, row_end // block_size, None)
block_col_idx = slice(col_start // block_size, col_end // block_size, None)

# Now we can try the super call!
new_base_linear_op = self.base_linear_op._getitem(row_index, col_index, *batch_indices, _noop_index)
# If the row and column block ranges differ, this is a cross-block slice.
# The block-diagonal structure is lost, so fall back to the general case.
if block_row_idx != block_col_idx:
row_index = slice(row_start, row_end, row_step)
col_index = slice(col_start, col_end, col_step)
return super()._getitem(row_index, col_index, *batch_indices)

# Select blocks from the base operator's batch dimension.
# block_row_idx selects which blocks to keep; row/col are all (keep per-block matrix intact).
new_base_linear_op = self.base_linear_op._getitem(
slice(None), slice(None), *batch_indices, block_row_idx
)

# Now construct a kernel with those indices
return self.__class__(new_base_linear_op, block_dim=-3)
Expand Down
31 changes: 31 additions & 0 deletions test/operators/test_block_diag_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,34 @@ def test_metaclass_constructor(self):

if __name__ == "__main__":
unittest.main()


class TestBlockDiagCrossBlockSlicing(unittest.TestCase):
    """Shape tests for ``BlockDiagLinearOperator.__getitem__`` slicing.

    Covers cross-block slices (row and column block ranges differ),
    same-range block-aligned slices, non-aligned/non-square slices, and
    single-element access.

    NOTE(review): this class is added below the ``if __name__ == "__main__":
    unittest.main()`` guard, so running the file as a script will not discover
    it; move it above the guard, or rely on a collector (pytest / unittest
    discovery) that imports the module rather than executing ``main()`` first.
    """

    def test_cross_block_getitem(self):
        """Assert result shapes for a variety of slice patterns."""
        # 200 diagonal blocks, each T x T with T = 12.
        T, n_blocks = 12, 200
        blocks = torch.randn(n_blocks, T, T)
        dense = DenseLinearOperator(blocks)
        bd = BlockDiagLinearOperator(dense)

        total = n_blocks * T  # 2400 rows/cols in the assembled operator

        # Cross-block slice: row blocks != col blocks (train/test split pattern).
        n_train = 150
        sliced = bd[n_train * T:, :n_train * T]
        self.assertEqual(sliced.shape, (total - n_train * T, n_train * T))

        # Same-range, block-aligned slice covering exactly the first block.
        sliced2 = bd[:T, :T]
        self.assertEqual(sliced2.shape, (T, T))

        # Same-range slice spanning multiple blocks. NOTE: 500 and 600 are not
        # multiples of T=12, so this slice is NOT block-aligned — it exercises
        # the general fallback path rather than the block fast path.
        sliced3 = bd[500:600, 500:600]
        self.assertEqual(sliced3.shape, (100, 100))

        # Non-aligned, non-square slice.
        sliced4 = bd[5:100, 50:200]
        self.assertEqual(sliced4.shape, (95, 150))

        # Single-element slice.
        sliced5 = bd[0:1, 0:1]
        self.assertEqual(sliced5.shape, (1, 1))