Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/complex_kernel.cu" // NOLINT
#include "paddle/phi/kernels/complex_kernel.h"

PD_CUSTOM_KERNEL_REGISTER(conj,
iluvatar_gpu,
Expand All @@ -34,7 +35,7 @@ PD_CUSTOM_KERNEL_REGISTER(real,
phi::RealKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->OutputAt(0).SetDataType(phi::ToRealType(kernel_key.dtype()));
}

PD_CUSTOM_KERNEL_REGISTER(imag,
Expand All @@ -43,10 +44,10 @@ PD_CUSTOM_KERNEL_REGISTER(imag,
phi::ImagKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->OutputAt(0).SetDataType(phi::ToRealType(kernel_key.dtype()));
}

PD_CUSTOM_KERNEL_REGISTER(
complex, iluvatar_gpu, ALL_LAYOUT, phi::ComplexKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
kernel->OutputAt(0).SetDataType(phi::ToComplexType(kernel_key.dtype()));
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_kernel.cu" // NOLINT
#include "paddle/phi/kernels/fusion/fused_softmax_mask_upper_triangle_kernel.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h"

PD_CUSTOM_KERNEL_REGISTER(fused_softmax_mask_upper_triangle,
Expand Down
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

编译报错https://github.com/PaddlePaddle/PaddleCustomDevice/actions/runs/19928332563/job/57133953009?pr=2236#step:6:3099
phi::dtype::ToReal需要某个头文件,否则缺少声明导致编译报错

Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/complex_kernel.cu" // NOLINT
#include "paddle/phi/kernels/complex_kernel.h"

PD_CUSTOM_KERNEL_REGISTER(conj,
metax_gpu,
Expand All @@ -34,7 +35,7 @@ PD_CUSTOM_KERNEL_REGISTER(real,
phi::RealKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->OutputAt(0).SetDataType(phi::ToRealType(kernel_key.dtype()));
}

PD_CUSTOM_KERNEL_REGISTER(imag,
Expand All @@ -43,10 +44,10 @@ PD_CUSTOM_KERNEL_REGISTER(imag,
phi::ImagKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->OutputAt(0).SetDataType(phi::ToRealType(kernel_key.dtype()));
}

PD_CUSTOM_KERNEL_REGISTER(
complex, metax_gpu, ALL_LAYOUT, phi::ComplexKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
kernel->OutputAt(0).SetDataType(phi::ToComplexType(kernel_key.dtype()));
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_kernel.cu" // NOLINT
#include "paddle/phi/kernels/fusion/fused_softmax_mask_upper_triangle_kernel.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h"

PD_CUSTOM_KERNEL_REGISTER(fused_softmax_mask_upper_triangle,
Expand Down
Loading