diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 4d63f041e4..a1a9f39907 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -1,4 +1,5 @@ import os +from copy import deepcopy import ray import torch @@ -40,7 +41,11 @@ async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None: ) if enable_replay: config.replay_buffer = ReplayBufferConfig(enable=True) - sql_writer = SQLWriter(config.to_storage_config()) + writer_config = deepcopy(config) + writer_config.batch_size = put_batch_size + # Create buffer by writer, so buffer.batch_size will be set to put_batch_size + # This will check whether read_batch_size takes effect + sql_writer = SQLWriter(writer_config.to_storage_config()) sql_reader = SQLReader(config.to_storage_config()) exps = [ Experience( diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index b3b1d14c12..6f8e7c0334 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -21,7 +21,7 @@ def __init__(self, config: StorageConfig): def read(self, batch_size: Optional[int] = None, **kwargs) -> List: try: - batch_size = batch_size or self.read_batch_size + batch_size = self.read_batch_size if batch_size is None else batch_size exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs)) if len(exps) != batch_size: raise TimeoutError( @@ -32,7 +32,7 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List: return exps async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: - batch_size = batch_size or self.read_batch_size + batch_size = self.read_batch_size if batch_size is None else batch_size exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs) if len(exps) != batch_size: raise TimeoutError( diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index f7572c628c..fd1425bb8f 100644 ---
a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -16,15 +16,18 @@ class SQLReader(BufferReader): def __init__(self, config: StorageConfig) -> None: assert config.storage_type == StorageType.SQL.value self.wrap_in_ray = config.wrap_in_ray + self.read_batch_size = config.batch_size self.storage = SQLStorage.get_wrapper(config) def read(self, batch_size: Optional[int] = None, **kwargs) -> List: + batch_size = self.read_batch_size if batch_size is None else batch_size if self.wrap_in_ray: return ray.get(self.storage.read.remote(batch_size, **kwargs)) else: return self.storage.read(batch_size, **kwargs) async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: + batch_size = self.read_batch_size if batch_size is None else batch_size if self.wrap_in_ray: try: return await self.storage.read.remote(batch_size, **kwargs) diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index 04f0c20bda..eb1097cdc7 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -197,7 +197,7 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]: if self.stopped: raise StopIteration() - batch_size = batch_size or self.batch_size + batch_size = self.batch_size if batch_size is None else batch_size return self._read_method(batch_size, **kwargs) @classmethod @@ -248,7 +248,7 @@ def read(self, batch_size: Optional[int] = None) -> List[Task]: raise StopIteration() if self.offset > self.total_samples: raise StopIteration() - batch_size = batch_size or self.batch_size + batch_size = self.batch_size if batch_size is None else batch_size with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: query = ( session.query(self.table_model_cls)