diff --git a/backends/metax_gpu/patch/tmp/mixed_vector.h b/backends/metax_gpu/patch/tmp/mixed_vector.h index e7cf1e626c9..1dcca9c71b4 100644 --- a/backends/metax_gpu/patch/tmp/mixed_vector.h +++ b/backends/metax_gpu/patch/tmp/mixed_vector.h @@ -386,7 +386,8 @@ class MixVector { // the unify method to access CPU or CUDA data. immutable. const T *Data(phi::Place place) const { - if (place.GetType() == phi::AllocationType::GPU) { + if (place.GetType() == phi::AllocationType::GPU || + place.GetType() == phi::AllocationType::CUSTOM) { return CUDAData(place); } else { return data(); @@ -395,7 +396,8 @@ class MixVector { // the unify method to access CPU or CUDA data. mutable. T *MutableData(phi::Place place) { - if (place.GetType() == phi::AllocationType::GPU) { + if (place.GetType() == phi::AllocationType::GPU || + place.GetType() == phi::AllocationType::CUSTOM) { return CUDAMutableData(place); } else { return data();