@@ -21,6 +21,8 @@ resulting data will be shuffled after its creation; if it is not
2121shuffled then all the repeated samples will be together at the
2222end, sorted by class. Defaults to `true`.
2323
24+ The output will contain both the resampled data and classes.
25+
2426```julia
2527# 6 observations with 3 features each
2628X = rand(3, 6)
@@ -40,14 +42,7 @@ X_bal, Y_bal = oversample(X, Y)
4042```
4143
4244For this function to work, the type of `data` must implement
43- [`numobs`](@ref) and [`getobs`](@ref). For example, the following
44- code allows `oversample` to work on a `DataFrame`.
45-
46- ```julia
47- # Make DataFrames.jl work
48- MLUtils.getobs(data::DataFrame, i) = data[i,:]
49- MLUtils.numobs(data::DataFrame) = nrow(data)
50- ```
45+ [`numobs`](@ref) and [`getobs`](@ref).
5146
5247Note that if `data` is a tuple and `classes` is not given,
5348then it will be assumed that the last element of the tuple contains the classes.
@@ -98,16 +93,22 @@ function oversample(data, classes; fraction=1, shuffle::Bool=true)
9893 append! (inds, inds_for_lbl)
9994 end
10095 if num_extra_needed > 0
101- append! (inds, sample (inds_for_lbl, num_extra_needed; replace= false ))
96+ if shuffle
97+ append! (inds, sample (inds_for_lbl, num_extra_needed; replace= false ))
98+ else
99+ append! (inds, inds_for_lbl[1 : num_extra_needed])
100+ end
102101 end
103102 end
104103
105104 shuffle && shuffle! (inds)
106- return obsview (data, inds)
105+ return obsview (data, inds), obsview (classes, inds)
107106end
108107
109- oversample (data:: Tuple ; kws... ) = oversample (data, data[end ]; kws... )
110-
108+ function oversample (data:: Tuple ; kws... )
109+ d, c = oversample (data[1 : end - 1 ], data[end ]; kws... )
110+ return (d... , c)
111+ end
111112
112113"""
113114 undersample(data, classes; shuffle=true)
@@ -123,6 +124,8 @@ resulting data will be shuffled after its creation; if it is not
123124shuffled then all the observations will be in their original
124125order. Defaults to `false`.
125126
127+ The output will contain both the resampled data and classes.
128+
126129```julia
127130# 6 observations with 3 features each
128131X = rand(3, 6)
@@ -142,14 +145,8 @@ X_bal, Y_bal = undersample(X, Y)
142145```
143146
144147For this function to work, the type of `data` must implement
145- [`numobs`](@ref) and [`getobs`](@ref). For example, the following
146- code allows `undersample` to work on a `DataFrame`.
148+ [`numobs`](@ref) and [`getobs`](@ref).
147149
148- ```julia
149- # Make DataFrames.jl work
150- MLUtils.getobs(data::DataFrame, i) = data[i,:]
151- MLUtils.numobs(data::DataFrame) = nrow(data)
152- ```
153150Note that if `data` is a tuple, then it will be assumed that the
154151last element of the tuple contains the targets.
155152
@@ -186,11 +183,18 @@ function undersample(data, classes; shuffle::Bool=true)
186183 inds = Int[]
187184
188185 for (lbl, inds_for_lbl) in lm
189- append! (inds, sample (inds_for_lbl, mincount; replace= false ))
186+ if shuffle
187+ append! (inds, sample (inds_for_lbl, mincount; replace= false ))
188+ else
189+ append! (inds, inds_for_lbl[1 : mincount])
190+ end
190191 end
191192
192193 shuffle ? shuffle! (inds) : sort! (inds)
193- return obsview (data, inds)
194+ return obsview (data, inds), obsview (classes, inds)
194195end
195196
196- undersample (data:: Tuple ; kws... ) = undersample (data, data[end ]; kws... )
197+ function undersample (data:: Tuple ; kws... )
198+ d, c = undersample (data[1 : end - 1 ], data[end ]; kws... )
199+ return (d... , c)
200+ end
0 commit comments