diff --git a/.github/workflows/array_api.yml b/.github/workflows/array_api.yml index 2d656c21..f8cd9fdc 100644 --- a/.github/workflows/array_api.yml +++ b/.github/workflows/array_api.yml @@ -30,7 +30,7 @@ jobs: cd /tmp git clone https://github.com/kokkos/pykokkos-base.git cd pykokkos-base - python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF + python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF -DENABLE_VIEW_RANKS=5 - name: Install pykokkos run: | python -m pip install . @@ -49,4 +49,4 @@ jobs: # for hypothesis-driven test case generation pytest $GITHUB_WORKSPACE/pre_compile_tools/pre_compile_ufuncs.py -s # only run a subset of the conformance tests to get started - pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor + pytest array_api_tests/meta/test_broadcasting.py array_api_tests/meta/test_equality_mapping.py array_api_tests/meta/test_signatures.py array_api_tests/meta/test_special_cases.py array_api_tests/test_constants.py array_api_tests/meta/test_utils.py array_api_tests/test_creation_functions.py::test_ones array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_data_type_functions.py::test_result_type array_api_tests/test_operators_and_elementwise_functions.py::test_log10 array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_isfinite array_api_tests/test_operators_and_elementwise_functions.py::test_log2 array_api_tests/test_operators_and_elementwise_functions.py::test_log1p array_api_tests/test_operators_and_elementwise_functions.py::test_isinf array_api_tests/test_operators_and_elementwise_functions.py::test_log array_api_tests/test_array_object.py::test_scalar_casting array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_square array_api_tests/test_operators_and_elementwise_functions.py::test_cos array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_floor "array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]" diff --git a/.github/workflows/main_ci.yml b/.github/workflows/main_ci.yml index 000af412..35326b0c 100644 --- a/.github/workflows/main_ci.yml +++ b/.github/workflows/main_ci.yml @@ -30,7 +30,7 @@ jobs: cd /tmp git clone https://github.com/kokkos/pykokkos-base.git cd pykokkos-base - python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF + python setup.py install -- -DENABLE_LAYOUTS=ON -DENABLE_MEMORY_TRAITS=OFF -DENABLE_VIEW_RANKS=5 - name: Install pykokkos run: | python -m pip install . diff --git a/pykokkos/lib/ufunc_workunits.py b/pykokkos/lib/ufunc_workunits.py index 06f12530..6bcfbb70 100644 --- a/pykokkos/lib/ufunc_workunits.py +++ b/pykokkos/lib/ufunc_workunits.py @@ -1,6 +1,355 @@ import pykokkos as pk +@pk.workunit +def add_impl_1d_double(tid: int, viewA: pk.View1D[pk.double], viewB: pk.View1D[pk.double], out: pk.View1D[pk.double], ): + out[tid] = viewA[tid] + viewB[tid % viewB.extent(0)] + + +@pk.workunit +def add_impl_1d_float(tid: int, viewA: pk.View1D[pk.float], viewB: pk.View1D[pk.float], out: pk.View1D[pk.float]): + out[tid] = viewA[tid] + viewB[tid] + + +@pk.workunit +def add_impl_1d_int8(tid: int, viewA: pk.View1D[pk.int8], viewB: pk.View1D[pk.int8], out: pk.View1D[pk.int8]): + out[tid] = viewA[tid] + viewB[tid] + + +@pk.workunit +def add_impl_1d_int16(tid: int, viewA: pk.View1D[pk.int16], viewB: pk.View1D[pk.int16], out: pk.View1D[pk.int16]): + out[tid] = viewA[tid] + viewB[tid] + + +@pk.workunit +def add_impl_1d_int32(tid: int, viewA: pk.View1D[pk.int32], viewB: pk.View1D[pk.int32], out: pk.View1D[pk.int32]): + out[tid] = viewA[tid] + viewB[tid] + + +@pk.workunit +def add_impl_1d_int64(tid: int, viewA: pk.View1D[pk.int64], viewB: pk.View1D[pk.int64], out: pk.View1D[pk.int64]): + out[tid] = viewA[tid] + viewB[tid] + + +@pk.workunit +def add_impl_1d_uint8(tid: int, viewA: pk.View1D[pk.uint8], viewB: pk.View1D[pk.uint8], out: pk.View1D[pk.uint8]): + out[tid] = viewA[tid] + viewB[tid] + + +@pk.workunit +def add_impl_1d_uint16(tid: int, viewA: pk.View1D[pk.uint16], viewB: pk.View1D[pk.uint16], out: pk.View1D[pk.uint16]): + out[tid] = viewA[tid] + viewB[tid] + + +@pk.workunit +def add_impl_1d_uint32(tid: int, viewA: pk.View1D[pk.uint32], viewB: pk.View1D[pk.uint32], out: pk.View1D[pk.uint32]): + out[tid] = viewA[tid] + viewB[tid] + + +@pk.workunit +def add_impl_1d_uint64(tid: int, viewA: pk.View1D[pk.uint64], viewB: pk.View1D[pk.uint64], out: pk.View1D[pk.uint64]): + out[tid] = viewA[tid] + viewB[tid] + + +@pk.workunit +def add_impl_2d_double(tid: int, viewA: pk.View2D[pk.double], viewB: pk.View2D[pk.double], out: pk.View2D[pk.double]): + for i in range(viewA.extent(1)): + out[tid][i] = viewA[tid][i] + viewB[tid][i] + + +@pk.workunit +def add_impl_2d_uint8(tid: int, viewA: pk.View2D[pk.uint8], viewB: pk.View2D[pk.uint8], out: pk.View2D[pk.uint8]): + for i in range(viewA.extent(1)): + out[tid][i] = viewA[tid][i] + viewB[tid][i] + + +@pk.workunit +def add_impl_2d_float(tid: int, viewA: pk.View2D[pk.float], viewB: pk.View2D[pk.float], out: pk.View2D[pk.float]): + for i in range(viewA.extent(1)): + out[tid][i] = viewA[tid][i] + viewB[tid][i] + + +@pk.workunit +def add_impl_2d_int8(tid: int, viewA: pk.View2D[pk.int8], viewB: pk.View2D[pk.int8], out: pk.View2D[pk.int8]): + for i in range(viewA.extent(1)): + out[tid][i] = viewA[tid][i] + viewB[tid][i] + + +@pk.workunit +def add_impl_2d_int16(tid: int, viewA: pk.View2D[pk.int16], viewB: pk.View2D[pk.int16], out: pk.View2D[pk.int16]): + for i in range(viewA.extent(1)): + out[tid][i] = viewA[tid][i] + viewB[tid][i] + + +@pk.workunit +def add_impl_2d_int32(tid: int, viewA: pk.View2D[pk.int32], viewB: pk.View2D[pk.int32], out: pk.View2D[pk.int32]): + for i in range(viewA.extent(1)): + out[tid][i] = viewA[tid][i] + viewB[tid][i] + + +@pk.workunit +def add_impl_2d_int64(tid: int, viewA: pk.View2D[pk.int64], viewB: pk.View2D[pk.int64], out: pk.View2D[pk.int64]): + for i in range(viewA.extent(1)): + out[tid][i] = viewA[tid][i] + viewB[tid][i] + + +@pk.workunit +def add_impl_3d_float(tid: int, viewA: pk.View3D[pk.float], viewB: pk.View3D[pk.float], out: pk.View3D[pk.float]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + out[tid][i][j] = viewA[tid][i][j] + viewB[tid][i][j] + + +@pk.workunit +def add_impl_3d_double(tid: int, viewA: pk.View3D[pk.double], viewB: pk.View3D[pk.double], out: pk.View3D[pk.double]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + out[tid][i][j] = viewA[tid][i][j] + viewB[tid][i][j] + + +@pk.workunit +def add_impl_3d_uint8(tid: int, viewA: pk.View3D[pk.uint8], viewB: pk.View3D[pk.uint8], out: pk.View3D[pk.uint8]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + out[tid][i][j] = viewA[tid][i][j] + viewB[tid][i][j] + + +@pk.workunit +def add_impl_3d_uint16(tid: int, viewA: pk.View3D[pk.uint16], viewB: pk.View3D[pk.uint16], out: pk.View3D[pk.uint16]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + out[tid][i][j] = viewA[tid][i][j] + viewB[tid][i][j] + + +@pk.workunit +def add_impl_3d_uint32(tid: int, viewA: pk.View3D[pk.uint32], viewB: pk.View3D[pk.uint32], out: pk.View3D[pk.uint32]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + out[tid][i][j] = viewA[tid][i][j] + viewB[tid][i][j] + + +@pk.workunit +def add_impl_3d_uint64(tid: int, viewA: pk.View3D[pk.uint64], viewB: pk.View3D[pk.uint64], out: pk.View3D[pk.uint64]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + out[tid][i][j] = viewA[tid][i][j] + viewB[tid][i][j] + +@pk.workunit +def add_impl_3d_int8(tid: int, viewA: pk.View3D[pk.int8], viewB: pk.View3D[pk.int8], out: pk.View3D[pk.int8]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + out[tid][i][j] = viewA[tid][i][j] + viewB[tid][i][j] + + +@pk.workunit +def add_impl_3d_int16(tid: int, viewA: pk.View3D[pk.int16], viewB: pk.View3D[pk.int16], out: pk.View3D[pk.int16]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + out[tid][i][j] = viewA[tid][i][j] + viewB[tid][i][j] + + +@pk.workunit +def add_impl_3d_int32(tid: int, viewA: pk.View3D[pk.int32], viewB: pk.View3D[pk.int32], out: pk.View3D[pk.int32]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + out[tid][i][j] = viewA[tid][i][j] + viewB[tid][i][j] + + +@pk.workunit +def add_impl_3d_int64(tid: int, viewA: pk.View3D[pk.int64], viewB: pk.View3D[pk.int64], out: pk.View3D[pk.int64]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + out[tid][i][j] = viewA[tid][i][j] + viewB[tid][i][j] + + +@pk.workunit +def add_impl_4d_float(tid: int, viewA: pk.View4D[pk.float], viewB: pk.View4D[pk.float], out: pk.View4D[pk.float]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + out[tid][i][j][k] = viewA[tid][i][j][k] + viewB[tid][i][j][k] + + +@pk.workunit +def add_impl_4d_double(tid: int, viewA: pk.View4D[pk.double], viewB: pk.View4D[pk.double], out: pk.View4D[pk.double]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + out[tid][i][j][k] = viewA[tid][i][j][k] + viewB[tid][i][j][k] + + +@pk.workunit +def add_impl_4d_uint8(tid: int, viewA: pk.View4D[pk.uint8], viewB: pk.View4D[pk.uint8], out: pk.View4D[pk.uint8]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + out[tid][i][j][k] = viewA[tid][i][j][k] + viewB[tid][i][j][k] + + +@pk.workunit +def add_impl_4d_uint16(tid: int, viewA: pk.View4D[pk.uint16], viewB: pk.View4D[pk.uint16], out: pk.View4D[pk.uint16]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + out[tid][i][j][k] = viewA[tid][i][j][k] + viewB[tid][i][j][k] + + +@pk.workunit +def add_impl_4d_uint32(tid: int, viewA: pk.View4D[pk.uint32], viewB: pk.View4D[pk.uint32], out: pk.View4D[pk.uint32]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + out[tid][i][j][k] = viewA[tid][i][j][k] + viewB[tid][i][j][k] + + +@pk.workunit +def add_impl_4d_uint64(tid: int, viewA: pk.View4D[pk.uint64], viewB: pk.View4D[pk.uint64], out: pk.View4D[pk.uint64]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + out[tid][i][j][k] = viewA[tid][i][j][k] + viewB[tid][i][j][k] + + +@pk.workunit +def add_impl_4d_int8(tid: int, viewA: pk.View4D[pk.int8], viewB: pk.View4D[pk.int8], out: pk.View4D[pk.int8]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + out[tid][i][j][k] = viewA[tid][i][j][k] + viewB[tid][i][j][k] + + +@pk.workunit +def add_impl_4d_int16(tid: int, viewA: pk.View4D[pk.int16], viewB: pk.View4D[pk.int16], out: pk.View4D[pk.int16]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + out[tid][i][j][k] = viewA[tid][i][j][k] + viewB[tid][i][j][k] + + +@pk.workunit +def add_impl_4d_int32(tid: int, viewA: pk.View4D[pk.int32], viewB: pk.View4D[pk.int32], out: pk.View4D[pk.int32]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + out[tid][i][j][k] = viewA[tid][i][j][k] + viewB[tid][i][j][k] + + +@pk.workunit +def add_impl_4d_int64(tid: int, viewA: pk.View4D[pk.int64], viewB: pk.View4D[pk.int64], out: pk.View4D[pk.int64]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + out[tid][i][j][k] = viewA[tid][i][j][k] + viewB[tid][i][j][k] + + +@pk.workunit +def add_impl_5d_float(tid: int, viewA: pk.View5D[pk.float], viewB: pk.View5D[pk.float], out: pk.View5D[pk.float]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + for l in range(viewA.extent(4)): + out[tid][i][j][k][l] = viewA[tid][i][j][k][l] + viewB[tid][i][j][k][l] + + +@pk.workunit +def add_impl_5d_double(tid: int, viewA: pk.View5D[pk.double], viewB: pk.View5D[pk.double], out: pk.View5D[pk.double]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + for l in range(viewA.extent(4)): + out[tid][i][j][k][l] = viewA[tid][i][j][k][l] + viewB[tid][i][j][k][l] + + +@pk.workunit +def add_impl_5d_uint8(tid: int, viewA: pk.View5D[pk.uint8], viewB: pk.View5D[pk.uint8], out: pk.View5D[pk.uint8]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + for l in range(viewA.extent(4)): + out[tid][i][j][k][l] = viewA[tid][i][j][k][l] + viewB[tid][i][j][k][l] + + +@pk.workunit +def add_impl_5d_uint16(tid: int, viewA: pk.View5D[pk.uint16], viewB: pk.View5D[pk.uint16], out: pk.View5D[pk.uint16]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + for l in range(viewA.extent(4)): + out[tid][i][j][k][l] = viewA[tid][i][j][k][l] + viewB[tid][i][j][k][l] + + +@pk.workunit +def add_impl_5d_uint32(tid: int, viewA: pk.View5D[pk.uint32], viewB: pk.View5D[pk.uint32], out: pk.View5D[pk.uint32]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + for l in range(viewA.extent(4)): + out[tid][i][j][k][l] = viewA[tid][i][j][k][l] + viewB[tid][i][j][k][l] + + +@pk.workunit +def add_impl_5d_uint64(tid: int, viewA: pk.View5D[pk.uint64], viewB: pk.View5D[pk.uint64], out: pk.View5D[pk.uint64]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + for l in range(viewA.extent(4)): + out[tid][i][j][k][l] = viewA[tid][i][j][k][l] + viewB[tid][i][j][k][l] + + +@pk.workunit +def add_impl_5d_int8(tid: int, viewA: pk.View5D[pk.int8], viewB: pk.View5D[pk.int8], out: pk.View5D[pk.int8]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + for l in range(viewA.extent(4)): + out[tid][i][j][k][l] = viewA[tid][i][j][k][l] + viewB[tid][i][j][k][l] + + +@pk.workunit +def add_impl_5d_int16(tid: int, viewA: pk.View5D[pk.int16], viewB: pk.View5D[pk.int16], out: pk.View5D[pk.int16]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + for l in range(viewA.extent(4)): + out[tid][i][j][k][l] = viewA[tid][i][j][k][l] + viewB[tid][i][j][k][l] + + +@pk.workunit +def add_impl_5d_int32(tid: int, viewA: pk.View5D[pk.int32], viewB: pk.View5D[pk.int32], out: pk.View5D[pk.int32]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + for l in range(viewA.extent(4)): + out[tid][i][j][k][l] = viewA[tid][i][j][k][l] + viewB[tid][i][j][k][l] + + +@pk.workunit +def add_impl_5d_int64(tid: int, viewA: pk.View5D[pk.int64], viewB: pk.View5D[pk.int64], out: pk.View5D[pk.int64]): + for i in range(viewA.extent(1)): + for j in range(viewA.extent(2)): + for k in range(viewA.extent(3)): + for l in range(viewA.extent(4)): + out[tid][i][j][k][l] = viewA[tid][i][j][k][l] + viewB[tid][i][j][k][l] + + +@pk.workunit +def add_impl_2d_uint16(tid: int, viewA: pk.View2D[pk.uint16], viewB: pk.View2D[pk.uint16], out: pk.View2D[pk.uint16]): + for i in range(viewA.extent(1)): + out[tid][i] = viewA[tid][i] + viewB[tid][i] + + +@pk.workunit +def add_impl_2d_uint32(tid: int, viewA: pk.View2D[pk.uint32], viewB: pk.View2D[pk.uint32], out: pk.View2D[pk.uint32]): + for i in range(viewA.extent(1)): + out[tid][i] = viewA[tid][i] + viewB[tid][i] + + +@pk.workunit +def add_impl_2d_uint64(tid: int, viewA: pk.View2D[pk.uint64], viewB: pk.View2D[pk.uint64], out: pk.View2D[pk.uint64]): + for i in range(viewA.extent(1)): + out[tid][i] = viewA[tid][i] + viewB[tid][i] + + @pk.workunit def floor_impl_1d_double(tid: int, view: pk.View1D[pk.double], out: pk.View1D[pk.double]): out[tid] = floor(view[tid]) diff --git a/pykokkos/lib/ufuncs.py b/pykokkos/lib/ufuncs.py index 13be86e2..28def314 100644 --- a/pykokkos/lib/ufuncs.py +++ b/pykokkos/lib/ufuncs.py @@ -9,6 +9,59 @@ kernel_dict = dict(getmembers(ufunc_workunits, isfunction)) +def _broadcast_views(view1, view2): + # support broadcasting by using the same + # shape matching rules as NumPy + # TODO: determine if this can be done with + # more memory efficiency? + if view1.shape != view2.shape: + new_shape = np.broadcast_shapes(view1.shape, view2.shape) + view1_new = pk.View([*new_shape], dtype=view1.dtype) + view1_new[:] = view1 + view1 = view1_new + view2_new = pk.View([*new_shape], dtype=view2.dtype) + view2_new[:] = view2 + view2 = view2_new + return view1, view2 + + +def _typematch_views(view1, view2): + # very crude casting implementation + # for binary ufuncs + dtype1 = view1.dtype + dtype2 = view2.dtype + dtype_extractor = re.compile(r".*(?:data_types|DataType)\.(\w+)") + res1 = dtype_extractor.match(str(dtype1)) + res2 = dtype_extractor.match(str(dtype2)) + effective_dtype = dtype1 + if res1 is not None and res2 is not None: + res1_dtype_str = res1.group(1) + res2_dtype_str = res2.group(1) + if res1_dtype_str == "double": + res1_dtype_str = "float64" + elif res1_dtype_str == "float": + res1_dtype_str = "float32" + if res2_dtype_str == "double": + res2_dtype_str = "float64" + elif res2_dtype_str == "float": + res2_dtype_str = "float32" + if (("int" in res1_dtype_str and "int" in res2_dtype_str) or + ("float" in res1_dtype_str and "float" in res2_dtype_str)): + dtype_1_width = int(res1_dtype_str.split("t")[1]) + dtype_2_width = int(res2_dtype_str.split("t")[1]) + if dtype_1_width >= dtype_2_width: + effective_dtype = dtype1 + view2_new = pk.View([*view2.shape], dtype=effective_dtype) + view2_new[:] = view2 + view2 = view2_new + else: + effective_dtype = dtype2 + view1_new = pk.View([*view1.shape], dtype=effective_dtype) + view1_new[:] = view1 + view1 = view1_new + return view1, view2, effective_dtype + + def _supported_types_check(dtype_str, supported_type_strings): options = "" for type_str in supported_type_strings: @@ -771,21 +824,6 @@ def sign(view): return out -@pk.workunit -def add_impl_1d_double(tid: int, viewA: pk.View1D[pk.double], viewB: pk.View1D[pk.double], out: pk.View1D[pk.double], ): - out[tid] = viewA[tid] + viewB[tid % viewB.extent(0)] - - -@pk.workunit -def add_impl_1d_float(tid: int, viewA: pk.View1D[pk.float], viewB: pk.View1D[pk.float], out: pk.View1D[pk.float]): - out[tid] = viewA[tid] + viewB[tid] - -@pk.workunit -def add_impl_2d_1d_double(tid: int, viewA: pk.View2D[pk.double], viewB: pk.View1D[pk.double], out: pk.View2D[pk.double]): - for i in range(viewA.extent(1)): - out[tid][i] = viewA[tid][i] + viewB[i % viewB.extent(0)] - - def add(viewA, viewB): """ Sums positionally corresponding elements @@ -804,39 +842,28 @@ def add(viewA, viewB): Output view. """ - if not isinstance(viewB, pk.View): - view_temp = pk.View([1], pk.double) - view_temp[0] = viewB - viewB = view_temp - - if viewA.rank() == 2: - out = pk.View(viewA.shape, pk.double) - pk.parallel_for( - viewA.shape[0], - add_impl_2d_1d_double, - viewA=viewA, - viewB=viewB, - out=out) - - elif str(viewA.dtype) == "DataType.double" and str(viewB.dtype) == "DataType.double": - out = pk.View([viewA.shape[0]], pk.double) - pk.parallel_for( - viewA.shape[0], - add_impl_1d_double, - viewA=viewA, - viewB=viewB, - out=out) - - elif str(viewA.dtype) == "DataType.float" and str(viewB.dtype) == "DataType.float": - out = pk.View([viewA.shape[0]], pk.float) - pk.parallel_for( - viewA.shape[0], - add_impl_1d_float, - viewA=viewA, - viewB=viewB, - out=out) + view1, view2 = _broadcast_views(viewA, viewB) + dtype1 = view1.dtype + dtype2 = view2.dtype + view1, view2, effective_dtype = _typematch_views(view1, view2) + ndims = len(view1.shape) + if ndims > 5: + raise NotImplementedError("add() ufunc only supports up to 5D views") + if view1.size == 0: + return pk.View([*view1.shape], dtype=effective_dtype) + out = pk.View([*view1.shape], dtype=effective_dtype) + if view1.shape == (): + tid = 1 else: - raise RuntimeError("Incompatible Types") + tid = view1.shape[0] + _ufunc_kernel_dispatcher(tid=tid, + dtype=effective_dtype, + ndims=ndims, + op="add", + sub_dispatcher=pk.parallel_for, + out=out, + viewA=view1, + viewB=view2) return out