1717"""
1818
1919import collections
20+ import functools
2021import inspect
2122import math
2223import warnings
@@ -1043,11 +1044,13 @@ def stateless_call(
10431044 if self ._remat_mode is not None :
10441045 outputs = self .rematerialized_call (
10451046 self .quantized_call , * args , ** kwargs
1046- )
1047+ )( * args , ** kwargs )
10471048 else :
10481049 outputs = self .quantized_call (* args , ** kwargs )
10491050 elif self ._remat_mode is not None :
1050- outputs = self .rematerialized_call (self .call , * args , ** kwargs )
1051+ outputs = self .rematerialized_call (self .call , * args , ** kwargs )(
1052+ * args , ** kwargs
1053+ )
10511054 else :
10521055 outputs = self .call (* args , ** kwargs )
10531056 if return_losses :
@@ -1601,13 +1604,13 @@ def compute_size(x):
16011604
16021605 # Full rematerialization
16031606 if self ._remat_mode .mode == "full" :
1604- return remat .remat (layer_call )( * args , ** kwargs )
1607+ return remat .remat (layer_call )
16051608
16061609 # Apply rematerialization to specific layers
16071610 elif self ._remat_mode .mode == "list_of_layers" and (
16081611 self .name in self ._remat_mode .layer_names
16091612 ):
1610- return remat .remat (layer_call )( * args , ** kwargs )
1613+ return remat .remat (layer_call )
16111614
16121615 # Apply rematerialization based on output size threshold
16131616 elif self ._remat_mode .mode == "larger_than" :
@@ -1619,20 +1622,24 @@ def compute_size(x):
16191622 output_size
16201623 and output_size > self ._remat_mode .output_size_threshold
16211624 ):
1622- return remat .remat (layer_call )( * args , ** kwargs )
1625+ return remat .remat (layer_call )
16231626 elif self ._remat_mode .mode == "activations" :
16241627 has_activation = (
16251628 hasattr (self , "activation" ) and self .activation is not None
16261629 )
16271630 if has_activation :
1628- not_rematted_activation = self .activation
1629- try :
1630- self .activation = remat .remat (not_rematted_activation )
1631- return layer_call (* args , ** kwargs )
1632- finally :
1633- self .activation = not_rematted_activation
16341631
1635- return layer_call (* args , ** kwargs )
1632+ @functools .wraps (layer_call )
1633+ def rematerialized_activation_call_wrapper (* args , ** kwargs ):
1634+ original_activation = self .activation
1635+ self .activation = remat .remat (original_activation )
1636+ try :
1637+ return layer_call (* args , ** kwargs )
1638+ finally :
1639+ self .activation = original_activation
1640+
1641+ return rematerialized_activation_call_wrapper
1642+ return layer_call
16361643
16371644
16381645def is_backend_tensor_or_symbolic (x , allow_none = False ):
0 commit comments