
How to reduce GPU memory requirements during training #12

@Uwwal

Description

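Hi, I tried training on a Colab T4 and ran out of GPU memory. I lowered the batch size and the number of epochs on lines 53 and 55 of VAE_GAN_train.py and reduced the image size in GAN.py, but training still fails with an out-of-memory error. Is there a way to reduce the memory requirement, or to lower the batch size? The batch size shown in the output does not seem to match the value I set in VAE_GAN_train.py.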
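Activation memory in a convolutional network grows linearly with the batch size and with the number of pixels, so halving the image side length shrinks each feature map to a quarter, while halving the batch only halves it. The sketch below shows that arithmetic, assuming dense float32 feature maps (4 bytes per element). The shape is hypothetical and not read from GAN.py or AutoEncoder.py, and a real training step keeps many such maps plus gradients and optimizer state, so treat one tensor's size as a lower bound.

```python
def tensor_bytes(shape, bytes_per_element=4):
    """Bytes for one dense tensor: 4 bytes/element for float32, 2 for float16."""
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return num_elements * bytes_per_element


def gib(num_bytes):
    """Convert bytes to GiB (1024**3 bytes)."""
    return num_bytes / 1024**3


# Hypothetical (batch, channels, height, width) feature-map shapes, for illustration only
shapes = {
    "base": (8, 64, 512, 512),
    "half batch": (4, 64, 512, 512),
    "half resolution": (8, 64, 256, 256),
}

for name, shape in shapes.items():
    print(f"{name:16s} {gib(tensor_bytes(shape)):.3f} GiB")

# Storing the same map in float16 halves it again
print(f"{'base, float16':16s} {gib(tensor_bytes(shapes['base'], 2)):.3f} GiB")
```

Because the pixel count enters quadratically through height and width, reducing the input resolution is usually the more effective lever, provided the architecture accepts the smaller size.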

The dataset consists of four images; I did not provide a test set.
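With only four training images, every batch size of four or more produces the same single batch of four images, so lowering the setting from, say, 10 to 5 changes nothing. Only values below four shrink the per-step activations. The sketch below assumes the project splits its data the way `torch.utils.data.DataLoader` does with a sequential sampler (full batches, then a smaller last batch unless `drop_last` is set); that has not been confirmed from the repository. Separately, the traceback shows line 61 calling `model.train(train_epochs, 10)` with a literal `10`. What that second argument controls is not visible here, so check whether it overrides the value set on line 53.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Sizes of successive batches when num_samples items are split sequentially,
    following DataLoader-style batching: full batches first, then a remainder."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    sizes = [min(batch_size, num_samples - start)
             for start in range(0, num_samples, batch_size)]
    # drop_last discards an incomplete final batch
    if drop_last and sizes and sizes[-1] < batch_size:
        sizes.pop()
    return sizes


# Four training images, as in this issue
for bs in (10, 4, 3, 2, 1):
    print(f"batch_size={bs:2d} -> {batch_sizes(4, bs)}")
```

With `drop_last=True`, a batch size larger than the dataset yields no batches at all, which would silently skip training rather than reduce memory.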

/usr/local/lib/python3.11/dist-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).
warnings.warn(
Error: Test dataset was not found at ./datasets/test_dataset.pt.
Test images directory not found: ./test_images/
Generate test dataset randomly.
---------- Loading VAE-GAN model ----------
Try to load model from ./saved_model/VAE.pth
No saved model found at './saved_model/VAE.pth'. Starting from scratch.
No saved model found at './saved_model/Discriminator.pth'. Starting from scratch.
/usr/local/lib/python3.11/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)
return VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
Traceback (most recent call last):
File "/content/Unsupervised-Defect-Detection-Project-Based-on-VAE-GAN-Architecture/VAE_GAN_train.py", line 61, in <module>
model.train(train_epochs, 10)
File "/content/Unsupervised-Defect-Detection-Project-Based-on-VAE-GAN-Architecture/GAN.py", line 253, in train
y, z_mu_y, z_var_y = self.model(x)
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/content/Unsupervised-Defect-Detection-Project-Based-on-VAE-GAN-Architecture/AutoEncoder.py", line 251, in forward
predicted = self.decoder(x_sample)
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/content/Unsupervised-Defect-Detection-Project-Based-on-VAE-GAN-Architecture/AutoEncoder.py", line 211, in forward
identity3 = self.skip3(x2)
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/conv.py", line 952, in forward
return F.conv_transpose2d(
^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.19 GiB. GPU 0 has a total capacty of 14.74 GiB of which 352.12 MiB is free. Process 206937 has 14.39 GiB memory in use. Of the allocated memory 11.65 GiB is allocated by PyTorch, and 2.62 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
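The figures in the error can be read off directly. The failing `conv_transpose2d` call in `skip3` requested 2.19 GiB while only 352.12 MiB was free on the GPU. At the same time, PyTorch's caching allocator held 2.62 GiB that was reserved but unallocated. Because that cached memory exceeded the request yet could not satisfy it, it must have been split into smaller blocks, which is the fragmentation case the message mentions. The sketch below works through that arithmetic and then applies the `PYTORCH_CUDA_ALLOC_CONF` setting the message suggests. The value 128 is an arbitrary example, not a tuned setting for this project. `max_split_size_mb` only mitigates fragmentation and does not reduce how much memory the model needs, so a smaller batch or image size is likely still required.

```python
import os

# Figures copied from the OutOfMemoryError message above
requested_gib = 2.19
free_gib = 352.12 / 1024          # "352.12 MiB is free"
reserved_unallocated_gib = 2.62   # "reserved by PyTorch but unallocated"

shortfall_gib = requested_gib - free_gib
print(f"shortfall beyond free GPU memory: {shortfall_gib:.2f} GiB")

# The cache holds more unused memory than the request, but not as one block
# large enough; that points to fragmentation.
fragmentation_suspected = reserved_unallocated_gib >= requested_gib
print("fragmentation suspected:", fragmentation_suspected)

# Must take effect before the first CUDA allocation (safest: before `import torch`).
# 128 MB is an arbitrary example value.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
```

In Colab, the same variable can also be set on the command that launches the script, for example `!PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128 python VAE_GAN_train.py`.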

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
