diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py
index f637225c1..1d9462329 100644
--- a/xtuner/v1/datasets/sampler.py
+++ b/xtuner/v1/datasets/sampler.py
@@ -107,6 +107,7 @@ def __iter__(self) -> Iterator[int]:
 
     def __len__(self) -> int:
         """The number of samples in this rank."""
+        # TODO: is this inconsistent with LengthGroupedSampler.__len__? Confirm which one is intended.
         return self.num_samples - self.step
 
     def set_epoch(self, epoch: int) -> None:
@@ -137,10 +138,11 @@ def load_state_dict(self, state_dict) -> None:
             )
 
     def get_state_dict(self, step: int):
-        self.step = step % self.total_size
+        # NOTE: do not assign self.step here; doing so makes the next __iter__ yield fewer samples.
+        # Return the caller-supplied step without mutating the sampler's own state.
         return {
             "epoch": self.epoch,
-            "step": self.step,
+            "step": step,
             "world_size": self.world_size,
             "shuffle": self.shuffle,
             "round_up": self.round_up,
@@ -291,7 +293,8 @@ def get_state_dict(self, step: int):
         Returns:
             dict: The state of the sampler.
         """
-        self.step = step % self.total_size
+        # NOTE: do not assign self.step here; doing so makes the next __iter__ yield fewer samples.
+        # Return the caller-supplied step without mutating the sampler's own state.
         return {
             "epoch": self.epoch,
-            "step": self.step,
+            "step": step,