
Commit cf40b4e

feat: Slice view (#893)
1 parent 74f28d9 commit cf40b4e

68 files changed: +1051 −494 lines


crates/cubecl-attention/src/components/args.rs

Lines changed: 1 addition & 1 deletion
@@ -493,7 +493,7 @@ impl<EI: Numeric, EO: Numeric, GA: AttentionArgs> TensorOutput<EI, EO, GA> {

     /// Get the buffer length of the tensor.
     pub fn buffer_len(&self) -> u32 {
-        unsafe { GA::len_out(&(*self.state)) }
+        unsafe { GA::buffer_len_out(&(*self.state)) }
     }

     /// Get the line size of the tensor.
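Note: `buffer_len` now delegates to `GA::buffer_len_out` rather than `GA::len_out`, so it reports the allocated buffer length, which may exceed the logical length (for example, with padding). A minimal sketch of a bounds guard built on it; the helper is hypothetical, and only the `TensorOutput` bounds come from the diff above:

    // Hypothetical guard: skip writes past the allocated buffer.
    fn in_buffer<EI: Numeric, EO: Numeric, GA: AttentionArgs>(
        out: &TensorOutput<EI, EO, GA>,
        index: u32,
    ) -> bool {
        // buffer_len() is the allocation size, not the logical length.
        index < out.buffer_len()
    }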

crates/cubecl-attention/src/components/global/dummy/attention.rs

Lines changed: 6 additions & 5 deletions
@@ -112,7 +112,7 @@ impl<
     ) -> DummyQueryLoader<AP, Self::Config> {
         comment!("Global: Init Query Loader");
         let layout =
-            SimpleGlobalLayout::new(&query, config.global_memory_config(FlashIdent::Query));
+            SimpleGlobalLayout::new(&query, 0, config.global_memory_config(FlashIdent::Query));
         DummyQueryLoader::<AP, Self::Config>::new(q_offset, query.view(layout), config)
     }

@@ -121,7 +121,7 @@ impl<
         #[comptime] config: Self::Config,
     ) -> Self::KeyLoader {
         comment!("Global: Init Key Loader");
-        let layout = SimpleGlobalLayout::new(&key, config.global_memory_config(FlashIdent::Key));
+        let layout = SimpleGlobalLayout::new(&key, 0, config.global_memory_config(FlashIdent::Key));
         DummyKeyLoader::new(key.view(layout), config)
     }

@@ -131,7 +131,7 @@ impl<
     ) -> Self::ValueLoader {
         comment!("Global: Init Value Loader");
         let layout =
-            SimpleGlobalLayout::new(&value, config.global_memory_config(FlashIdent::Value));
+            SimpleGlobalLayout::new(&value, 0, config.global_memory_config(FlashIdent::Value));
         DummyValueLoader::new(value.view(layout), config)
     }

@@ -141,7 +141,8 @@ impl<
         #[comptime] config: Self::Config,
     ) -> Self::Writer {
         comment!("Global: Init Writer");
-        let layout = SimpleGlobalLayout::new(&out, config.global_memory_config(FlashIdent::Out));
-        SA::init_writer(q_offset, out.view_mut(layout))
+        let layout = SimpleGlobalLayout::new(&out, 0, config.global_memory_config(FlashIdent::Out));
+        let out = out.view_mut(layout);
+        SA::init_writer(out.slice_mut_unchecked((q_offset, 0), out.shape()))
     }
 }
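The writer hunk above is the core of the slice-view change: the layout is built at a fixed base (offset 0) and the per-cube offset is applied by slicing the typed 2D view, so `q_offset` no longer needs to be threaded through the writer. A condensed sketch using only calls shown in this commit (`slice` on the read side appears in load.rs below; that `slice_mut_unchecked` skips bounds checks because offsets are validated upstream is an assumption based on its name):

    // Read side: bounds-checked slice starting at the query-row offset.
    let query = query.slice((q_offset, 0), query.shape());

    // Write side: position the mutable view once, then hand it to the writer.
    let out = out.view_mut(layout);
    let writer = SA::init_writer(out.slice_mut_unchecked((q_offset, 0), out.shape()));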

crates/cubecl-attention/src/components/global/dummy/load.rs

Lines changed: 19 additions & 17 deletions
@@ -4,7 +4,7 @@ use cubecl_matmul::components::global::memory::{TensorReader, ViewDirection};
 use cubecl_matmul::components::stage::{FullStageReader, StageMemory};
 use cubecl_matmul::components::tile::Tile;
 use cubecl_matmul::components::{MatrixLayout, StageIdent};
-use cubecl_std::tensor::{View, layout::Coords3d};
+use cubecl_std::tensor::{View, layout::Coords2d};
 use std::marker::PhantomData;

 use crate::components::global::base::GlobalAttentionConfig;
@@ -41,8 +41,9 @@ pub struct DummyValueLoader<AP: AttentionPrecision, G: GlobalAttentionConfig> {

 #[cube]
 impl<AP: AttentionPrecision, G: GlobalAttentionConfig> DummyQueryLoader<AP, G> {
-    pub fn new(q_offset: u32, query: View<Line<AP::EI>, Coords3d>, #[comptime] _config: G) -> Self {
-        let tensor_reader = TensorReader::new(query, (0u32.runtime(), q_offset, 0u32.runtime()));
+    pub fn new(q_offset: u32, query: View<Line<AP::EI>, Coords2d>, #[comptime] _config: G) -> Self {
+        let query = query.slice((q_offset, 0), query.shape());
+        let tensor_reader = TensorReader::new(query);

         DummyQueryLoader::<AP, G> {
             tensor_reader,
@@ -55,14 +56,17 @@ impl<AP: AttentionPrecision, G: GlobalAttentionConfig> DummyQueryLoader<AP, G> {

         let attention_tile_size = config.stage_config().tile_config().attention_tile_size();
         let tile = Tile::<AP::EI> {
-            slice: self.tensor_reader.view.slice(
-                (
-                    self.tensor_reader.row_offset.read() * attention_tile_size.seq_q,
-                    0u32.runtime(),
-                    0u32.runtime(),
-                ),
-                attention_tile_size.query_size(),
-            ),
+            slice: self
+                .tensor_reader
+                .view
+                .slice(
+                    (
+                        self.tensor_reader.row_offset.read() * attention_tile_size.seq_q,
+                        0u32.runtime(),
+                    ),
+                    (1u32, attention_tile_size.query_size()).runtime(),
+                )
+                .to_linear_slice(),
             stride: attention_tile_size.num_cols(FlashIdent::Query),
             layout: MatrixLayout::RowMajor,
         };
@@ -73,9 +77,8 @@ impl<AP: AttentionPrecision, G: GlobalAttentionConfig> DummyQueryLoader<AP, G> {

 #[cube]
 impl<AP: AttentionPrecision, G: GlobalAttentionConfig> DummyKeyLoader<AP, G> {
-    pub fn new(key: View<Line<AP::EI>, Coords3d>, #[comptime] config: G) -> Self {
-        let tensor_reader =
-            TensorReader::new(key, (0u32.runtime(), 0u32.runtime(), 0u32.runtime()));
+    pub fn new(key: View<Line<AP::EI>, Coords2d>, #[comptime] config: G) -> Self {
+        let tensor_reader = TensorReader::new(key);
         let stage_memory = StageMemory::new::<G::ScoreStageMemoryConfig>(
             1u32,
             StageIdent::Rhs,
@@ -140,9 +143,8 @@ impl<AP: AttentionPrecision, G: GlobalAttentionConfig> DummyKeyLoader<AP, G> {

 #[cube]
 impl<AP: AttentionPrecision, G: GlobalAttentionConfig> DummyValueLoader<AP, G> {
-    pub fn new(value: View<Line<AP::EI>, Coords3d>, #[comptime] config: G) -> Self {
-        let tensor_reader =
-            TensorReader::new(value, (0u32.runtime(), 0u32.runtime(), 0u32.runtime()));
+    pub fn new(value: View<Line<AP::EI>, Coords2d>, #[comptime] config: G) -> Self {
+        let tensor_reader = TensorReader::new(value);
         let stage_memory = StageMemory::new::<G::ValueStageMemoryConfig>(
             1u32,
             StageIdent::Rhs,
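With the base offset folded into the layout (the new second argument to `SimpleGlobalLayout::new`), the loaders now operate on 2D (row, col) views, so the 3D coordinate tuples and per-reader zero offsets disappear. Tile extraction becomes a chain: slice a (1, query_size) window at the tile's row, then flatten it. A sketch of that chain, mirroring the diff:

    // Pick the tile's row within the already-offset 2D view.
    let row = self.tensor_reader.row_offset.read() * attention_tile_size.seq_q;
    // Slice a (1, query_size) window, then flatten it into a linear slice
    // for the Tile, whose stride/layout describe the row-major data.
    let slice = self
        .tensor_reader
        .view
        .slice((row, 0u32.runtime()), (1u32, attention_tile_size.query_size()).runtime())
        .to_linear_slice();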

crates/cubecl-attention/src/components/stage/base.rs

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@ use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
 use cubecl_matmul::components::stage::{StageMemoryConfig, StageReaderFamily};
 use cubecl_std::CubeOption;
-use cubecl_std::tensor::{View, layout::Coords3d};
+use cubecl_std::tensor::{View, layout::Coords2d};
 use std::{fmt::Debug, hash::Hash};

 use crate::components::{
@@ -96,7 +96,7 @@ pub trait StageAttention<AP: AttentionPrecision>: 'static + Send + Sync {
         #[comptime] global_config: G,
     );

-    fn init_writer(q_offset: u32, tensor: View<Line<AP::EO>, Coords3d, ReadWrite>) -> Self::Writer;
+    fn init_writer(tensor: View<Line<AP::EO>, Coords2d, ReadWrite>) -> Self::Writer;

     fn init_fragments(
         query_reader: QueryRegisterReader<AP::EI>,

crates/cubecl-attention/src/components/stage/dummy/attention.rs

Lines changed: 3 additions & 3 deletions
@@ -3,7 +3,7 @@ use cubecl_core::prelude::*;
 use cubecl_matmul::components::{stage::StageReader, tile::loader::Strided};
 use cubecl_std::CubeOption;
 use cubecl_std::tensor::View;
-use cubecl_std::tensor::layout::Coords3d;
+use cubecl_std::tensor::layout::Coords2d;
 use std::marker::PhantomData;

 use crate::components::global::dummy::QueryRegisterReader;
@@ -83,8 +83,8 @@ impl<AP: AttentionPrecision, R: StageReader<AP::ES, TileKind = Strided>, TA: Til
         TA::write::<G>(acc, writer, stage_config.tile_config(), global_config);
     }

-    fn init_writer(q_offset: u32, out: View<Line<AP::EO>, Coords3d, ReadWrite>) -> Self::Writer {
-        TA::init_writer(q_offset, out)
+    fn init_writer(out: View<Line<AP::EO>, Coords2d, ReadWrite>) -> Self::Writer {
+        TA::init_writer(out)
     }

     fn init_fragments(

crates/cubecl-attention/src/components/tile/base.rs

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@ use cubecl_matmul::components::{
     tile::Tile,
 };
 use cubecl_std::CubeOption;
-use cubecl_std::tensor::{View, layout::Coords3d};
+use cubecl_std::tensor::{View, layout::Coords2d};

 use crate::components::global::dummy::QueryRegisterReader;
 use crate::components::{
@@ -84,7 +84,7 @@ pub trait TileAttention<AP: AttentionPrecision>: 'static + Send + Sync {
         #[comptime] global_config: G,
     );

-    fn init_writer(q_offset: u32, tensor: View<Line<AP::EO>, Coords3d, ReadWrite>) -> Self::Writer;
+    fn init_writer(tensor: View<Line<AP::EO>, Coords2d, ReadWrite>) -> Self::Writer;

     fn init_fragments(
         query_reader: QueryRegisterReader<AP::EI>,

crates/cubecl-attention/src/components/tile/dummy/attention.rs

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@ use cubecl_core as cubecl;
 use cubecl_core::prelude::*;
 use cubecl_matmul::components::tile::Tile;
 use cubecl_std::tensor::View;
-use cubecl_std::tensor::layout::Coords3d;
+use cubecl_std::tensor::layout::Coords2d;
 use cubecl_std::{CubeOption, CubeOptionExpand};
 use std::marker::PhantomData;

@@ -135,8 +135,8 @@ impl<AP: AttentionPrecision, FM: FlashMatmul<AP::FlashPrecision>> TileAttention<
         )
     }

-    fn init_writer(q_offset: u32, out: View<Line<AP::EO>, Coords3d, ReadWrite>) -> Self::Writer {
-        DummyWriter::new(out, q_offset, 0, 0)
+    fn init_writer(out: View<Line<AP::EO>, Coords2d, ReadWrite>) -> Self::Writer {
+        DummyWriter::new(out)
     }

     fn init_fragments(

crates/cubecl-attention/src/components/tile/dummy/writer.rs

Lines changed: 3 additions & 8 deletions
@@ -3,7 +3,7 @@ use cubecl_core::prelude::*;
 use cubecl_matmul::components::{global::memory::TensorWriter, stage::StageMemoryConfig as _};
 use cubecl_std::{
     div_ceil,
-    tensor::{View, layout::Coords3d},
+    tensor::{View, layout::Coords2d},
 };

 use crate::components::{FlashIdent, global::GlobalAttentionConfig};
@@ -17,14 +17,9 @@ pub struct DummyWriter<EO: Numeric> {

 #[cube]
 impl<EO: Numeric> DummyWriter<EO> {
-    pub fn new(
-        tensor: View<Line<EO>, Coords3d, ReadWrite>,
-        x_offset: u32,
-        y_offset: u32,
-        batch_offset: u32,
-    ) -> Self {
+    pub fn new(tensor: View<Line<EO>, Coords2d, ReadWrite>) -> Self {
         DummyWriter::<EO> {
-            tensor_writer: TensorWriter::new(tensor, x_offset, y_offset, batch_offset),
+            tensor_writer: TensorWriter::new(tensor),
         }
     }

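Since the view is now positioned by the caller, the writer constructor collapses from four arguments to one. A before/after sketch of the call site (the slicing line mirrors the global attention diff above):

    // Before: the writer carried its own x/y/batch offsets.
    // let writer = DummyWriter::new(out, q_offset, 0, 0);

    // After: the caller slices the view into position; the writer just wraps it.
    let out = out.slice_mut_unchecked((q_offset, 0), out.shape());
    let writer = DummyWriter::new(out);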
crates/cubecl-attention/src/lib.rs

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 #![allow(unknown_lints)] // `manual_div_ceil` only appeared in 1.83
-#![allow(clippy::manual_div_ceil)]
+#![allow(clippy::manual_div_ceil, clippy::manual_is_multiple_of)]

 mod base;
 /// Components for matrix multiplication
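The widened allow covers clippy's `manual_is_multiple_of` lint, which suggests rewriting `x % n == 0` as `x.is_multiple_of(n)`. Keeping the modulo form is consistent with the existing `unknown_lints` comment about toolchain age: `u32::is_multiple_of` is a recent addition, so the suggestion cannot be applied on older compilers (the exact motivation here is an assumption). The flagged pattern looks like:

    // clippy::manual_is_multiple_of fires on this and suggests
    // `offset.is_multiple_of(line_size)` instead.
    fn is_aligned(offset: u32, line_size: u32) -> bool {
        offset % line_size == 0
    }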

crates/cubecl-convolution/src/components/global/base.rs

Lines changed: 15 additions & 11 deletions
@@ -6,7 +6,10 @@ use cubecl_matmul::components::{
     global::StageUnloader,
     stage::{ContiguousTilingLayout, RowMajorTilingOrder},
 };
-use cubecl_std::{CubeOption, tensor::r#virtual::VirtualTensor};
+use cubecl_std::{
+    CubeOption,
+    tensor::{layout::Coords2d, r#virtual::VirtualTensor},
+};

 use crate::{
     components::{ConvGemmConfig, ConvolutionProblem, global::entry_point::ConvolutionLaunch},
@@ -44,7 +47,7 @@ pub trait GlobalConvolution<MP: MatmulPrecision>: 'static + Send + Sync {
     type Config: ConvGemmConfig;

     /// The writer used to write the results to the output feature map
-    type StageWriter: StageUnloader<AccG<MP>>;
+    type StageUnloader: StageUnloader<AccG<MP>>;
     /// The type of the tile matmul accumulator
     type Accumulators: CubeType;

@@ -58,7 +61,7 @@
         lhs_loader: Self::LhsStageLoader,
         rhs_loader: Self::RhsStageLoader,
         acc_loader: Self::AccStageLoader,
-        writer: Self::StageWriter,
+        writer: Self::StageUnloader,
         acc: &mut Self::Accumulators,
         k_range: (u32, u32),
         #[comptime] config: Self::Config,
@@ -67,17 +70,17 @@
     /// Initializes the loader for the input feature map with an appropriate layout
     fn init_lhs_loader(
         lhs: VirtualTensor<LhsG<MP>>,
-        x_offset: u32,
-        y_offset: u32,
+        offset: Coords2d,
+        view_shape: Coords2d,
         runtime_args: &RuntimeArgs,
         #[comptime] config: Self::Config,
     ) -> Self::LhsStageLoader;

     /// Initializes the loader for the weights with an appropriate layout
     fn init_rhs_loader(
         rhs: VirtualTensor<RhsG<MP>>,
-        x_offset: u32,
-        y_offset: u32,
+        offset: Coords2d,
+        view_shape: Coords2d,
         runtime_args: &RuntimeArgs,
         #[comptime] config: Self::Config,
     ) -> Self::RhsStageLoader;
@@ -86,17 +89,18 @@
     fn init_bias_loader(
         bias: CubeOption<VirtualTensor<AccG<MP>>>,
         n_offset: u32,
+        slice_size: u32,
         #[comptime] config: Self::Config,
     ) -> Self::AccStageLoader;

     /// Initializes the output feature map loader with an appropriate layout
-    fn init_writer(
+    fn init_global_writer(
         out: VirtualTensor<AccG<MP>, ReadWrite>,
-        x_offset: u32,
-        y_offset: u32,
+        offset: Coords2d,
+        view_shape: Coords2d,
         runtime_args: &RuntimeArgs,
         #[comptime] config: Self::Config,
-    ) -> Self::StageWriter;
+    ) -> Self::StageUnloader;

     /// Initializes a new accumulator for the tile matmul
     fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulators;
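The convolution trait follows the same slice-view migration: loader and writer initializers take an (offset, view_shape) pair of `Coords2d` tuples instead of separate x/y scalars, and the writer initializer is renamed to `init_global_writer`, returning the renamed `Self::StageUnloader`. A hypothetical call site under those signatures; the offset and shape variable names are illustrative:

    // Hypothetical call: position the output view at (m_offset, n_offset)
    // with a (view_rows, view_cols) window, matching the new signature.
    let writer = Self::init_global_writer(
        out,
        (m_offset, n_offset),   // offset: Coords2d
        (view_rows, view_cols), // view_shape: Coords2d
        &runtime_args,
        config,
    );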
