@@ -1,5 +1,3 @@
-from collections import Counter
-
 import numpy as np
 import torch
 
@@ -17,6 +15,56 @@ def convert_to_torch_tensor(data_list, use_cuda):
     return data_list
 
 
+class BaseSampler(object):
+    """The base class of all samplers.
+
+    Sub-classes must implement the __call__ method.
+    __call__ takes a DataSet object and returns a list of int - the sampling indices.
+    """
+
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class SequentialSampler(BaseSampler):
+    """Sample data in the original order.
+
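+    A deterministic illustration, assuming only that data_set supports len()::
+
+        sampler = SequentialSampler()
+        sampler(data_set)  # with len(data_set) == 4, returns [0, 1, 2, 3]
+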
+    """
+
+    def __call__(self, data_set):
+        return list(range(len(data_set)))
+
+
+class RandomSampler(BaseSampler):
+    """Sample data in random permutation order.
+
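+    An illustration; the permutation itself is random, so the output below is
+    only one possibility::
+
+        sampler = RandomSampler()
+        sampler(data_set)  # with len(data_set) == 4, returns e.g. [2, 0, 3, 1]
+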
+    """
+
+    def __call__(self, data_set):
+        return list(np.random.permutation(len(data_set)))
+
+
+def simple_sort_bucketing(lengths):
+    """Sort all example indices by example length, in ascending order.
+
+    :param lengths: list of int, the lengths of all examples.
+    :return data: list of int, the example indices sorted by length. Despite the
+        function name, no bucketing is performed yet (see the TODO in the body);
+        an illustration follows.
+
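+    E.g. ::
+
+        simple_sort_bucketing([5, 3, 8])
+        # -> [1, 0, 2]
+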
+    """
+    lengths_mapping = [(idx, length) for idx, length in enumerate(lengths)]
+    sorted_lengths = sorted(lengths_mapping, key=lambda x: x[1])
+    # TODO: need to return buckets
+    return [idx for idx, _ in sorted_lengths]
+
 def k_means_1d(x, k, max_iter=100):
     """Perform k-means on 1-D data.
 
@@ -46,18 +94,10 @@ def k_means_1d(x, k, max_iter=100):
     return np.array(centroids), assign
 
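+# Illustrative call of k_means_1d with hypothetical numbers: centroids holds
+# the k cluster means, and assign maps each element of x to a centroid index.
+# The exact cluster labeling depends on the centroid initialization, e.g.:
+#
+#     centroids, assign = k_means_1d(np.array([1, 2, 10, 11]), 2)
+#     # centroids ~ [1.5, 10.5]; assign e.g. [0, 0, 1, 1]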
 
-def k_means_bucketing(all_inst, buckets):
+def k_means_bucketing(lengths, buckets):
     """Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths.
 
-    :param all_inst: 3-level list
-        E.g. ::
-
-            [
-                [[word_11, word_12, word_13], [label_11. label_12]],  # sample 1
-                [[word_21, word_22, word_23], [label_21. label_22]],  # sample 2
-                ...
-            ]
-
+    :param lengths: list of int, the lengths of all samples.
     :param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
         threshold for each bucket (This is usually None.).
     :return data: 2-level list
@@ -72,7 +112,6 @@ def k_means_bucketing(all_inst, buckets):
     """
     bucket_data = [[] for _ in buckets]
     num_buckets = len(buckets)
-    lengths = np.array([len(inst[0]) for inst in all_inst])
     _, assignments = k_means_1d(lengths, num_buckets)
 
     for idx, bucket_id in enumerate(assignments):
@@ -81,102 +120,33 @@ def k_means_bucketing(all_inst, buckets):
     return bucket_data
 
 
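+# Illustrative use of k_means_bucketing with hypothetical numbers: given two
+# unbounded buckets, k-means on the lengths separates short from long samples.
+# Bucket order depends on the clustering, so the result is only one possibility:
+#
+#     k_means_bucketing(np.array([3, 4, 5, 20, 22, 25]), [None, None])
+#     # -> e.g. [[0, 1, 2], [3, 4, 5]]
+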
-class BaseSampler(object):
-    """The base class of all samplers.
-
-    """
-
-    def __call__(self, *args, **kwargs):
-        raise NotImplementedError
-
-
-class SequentialSampler(BaseSampler):
-    """Sample data in the original order.
-
-    """
-
-    def __call__(self, data_set):
-        return list(range(len(data_set)))
-
-
-class RandomSampler(BaseSampler):
-    """Sample data in random permutation order.
-
-    """
-
-    def __call__(self, data_set):
-        return list(np.random.permutation(len(data_set)))
-
-
-
-class Batchifier(object):
-    """Wrap random or sequential sampler to generate a mini-batch.
-
-    """
-
-    def __init__(self, sampler, batch_size, drop_last=True):
-        """
-
-        :param sampler: a Sampler object
-        :param batch_size: int, the size of the mini-batch
-        :param drop_last: bool, whether to drop the last examples that are not enough to make a mini-batch.
-
-        """
-        super(Batchifier, self).__init__()
-        self.sampler = sampler
-        self.batch_size = batch_size
-        self.drop_last = drop_last
-
-    def __iter__(self):
-        batch = []
-        for example in self.sampler:
-            batch.append(example)
-            if len(batch) == self.batch_size:
-                yield batch
-                batch = []
-        if 0 < len(batch) < self.batch_size and self.drop_last is False:
-            yield batch
-
-
-class BucketBatchifier(Batchifier):
+class BucketSampler(BaseSampler):
     """Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
     In sampling, first randomly choose a bucket, then sample data from it.
     The number of buckets is decided dynamically by the variance of sentence lengths.
-    TODO: merge it into Batch
+
     """
 
-    def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None):
+    def __call__(self, data_set, batch_size, num_buckets):
+        return self._process(data_set, batch_size, num_buckets)
+
+    def _process(self, data_set, batch_size, num_buckets, use_kmeans=False):
         """
 
-        :param data_set: three-level list, shape [num_samples, 2]
+        :param data_set: a DataSet object
         :param batch_size: int
         :param num_buckets: int, number of buckets for grouping these sequences.
-        :param drop_last: bool, useless currently.
-        :param sampler: Sampler, useless currently.
+        :param use_kmeans: bool, whether to use k-means to create buckets.
 
         """
-        super(BucketBatchifier, self).__init__(sampler, batch_size, drop_last)
         buckets = ([None] * num_buckets)
-        self.data = data_set
-        self.batch_size = batch_size
-        self.length_freq = dict(Counter([len(example) for example in data_set]))
-        self.buckets = k_means_bucketing(data_set, buckets)
-
-    def __iter__(self):
-        """Make a min-batch of data."""
-        for _ in range(len(self.data) // self.batch_size):
-            bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))]
-            np.random.shuffle(bucket_samples)
-            yield [self.data[idx] for idx in bucket_samples[:batch_size]]
-
-
-if __name__ == "__main__":
-    import random
-
-    data = [[[y] * random.randint(0, 50), [y]] for y in range(500)]
-    batch_size = 8
-    iterator = iter(BucketBatchifier(data, batch_size, num_buckets=5))
-    for d in iterator:
-        print("\n batch:")
-        for dd in d:
-            print(len(dd[0]), end=" ")
+        # Both bucketing strategies expect example lengths rather than the
+        # examples themselves (assumes len(example) gives an example's length,
+        # as the removed Batchifier code did).
+        lengths = np.array([len(example) for example in data_set])
+        if use_kmeans is True:
+            buckets = k_means_bucketing(lengths, buckets)
+        else:
+            # simple_sort_bucketing returns a flat index list sorted by length
+            # (see its TODO), so wrap it as a single bucket for the loop below.
+            buckets = [simple_sort_bucketing(lengths)]
+        index_list = []
+        for _ in range(len(data_set) // batch_size):
+            chosen_bucket = buckets[np.random.randint(0, len(buckets))]
+            np.random.shuffle(chosen_bucket)
+            index_list += chosen_bucket[:batch_size]
+        return index_list
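+
+# Illustrative use of BucketSampler (assumes each example in data_set
+# supports len()):
+#
+#     sampler = BucketSampler()
+#     indices = sampler(data_set, batch_size=8, num_buckets=5)
+#     # indices is a flat list of example indices in which each consecutive
+#     # run of 8 indices is drawn from one bucket of similar-length examples.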