88//
99// ===----------------------------------------------------------------------===//
1010
11+ #include " memory.hpp"
1112#include " ../ur_interface_loader.hpp"
1213#include " context.hpp"
13- #include " memory.hpp"
1414
1515#include " ../helpers/memory_helpers.hpp"
1616#include " ../image_common.hpp"
@@ -66,33 +66,44 @@ ur_integrated_buffer_handle_t::ur_integrated_buffer_handle_t(
6666 if (ret == UR_RESULT_SUCCESS && memProps.type != ZE_MEMORY_TYPE_UNKNOWN) {
6767 // Already a USM allocation - just use it directly without import
6868 this ->ptr = usm_unique_ptr_t (hostPtr, [](void *) {});
69- } else {
70- // Not USM - try to import it
71- bool hostPtrImported =
72- maybeImportUSM (hContext->getPlatform ()->ZeDriverHandleExpTranslated ,
73- hContext->getZeHandle (), hostPtr, size);
69+ return ;
70+ }
7471
75- if (!hostPtrImported) {
76- throw UR_RESULT_ERROR_INVALID_VALUE;
77- }
72+ // Not USM - try to import it
73+ bool hostPtrImported =
74+ maybeImportUSM (hContext->getPlatform ()->ZeDriverHandleExpTranslated ,
75+ hContext->getZeHandle (), hostPtr, size);
7876
77+ if (hostPtrImported) {
78+ // Successfully imported - use it with release
7979 this ->ptr = usm_unique_ptr_t (hostPtr, [hContext](void *ptr) {
8080 ZeUSMImport.doZeUSMRelease (
8181 hContext->getPlatform ()->ZeDriverHandleExpTranslated , ptr);
8282 });
83+ // No copy-back needed for imported pointers
84+ return ;
8385 }
84- } else {
85- // No host pointer - allocate new USM host memory
86- void *rawPtr;
87- UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
88- hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &rawPtr));
8986
90- this ->ptr = usm_unique_ptr_t (rawPtr, [hContext](void *ptr) {
91- auto ret = hContext->getDefaultUSMPool ()->free (ptr);
92- if (ret != UR_RESULT_SUCCESS) {
93- UR_LOG (ERR, " Failed to free host memory: {}" , ret);
94- }
95- });
87+ // Import failed - allocate backing buffer and set up copy-back
88+ }
89+
90+ // No host pointer, or import failed - allocate new USM host memory
91+ void *rawPtr;
92+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
93+ hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &rawPtr));
94+
95+ this ->ptr = usm_unique_ptr_t (rawPtr, [hContext](void *ptr) {
96+ auto ret = hContext->getDefaultUSMPool ()->free (ptr);
97+ if (ret != UR_RESULT_SUCCESS) {
98+ UR_LOG (ERR, " Failed to free host memory: {}" , ret);
99+ }
100+ });
101+
102+ if (hostPtr) {
103+ // Copy data from user pointer to our backing buffer
104+ std::memcpy (this ->ptr .get (), hostPtr, size);
105+ // Remember to copy back on destruction
106+ writeBackPtr = hostPtr;
96107 }
97108}
98109
@@ -108,6 +119,12 @@ ur_integrated_buffer_handle_t::ur_integrated_buffer_handle_t(
108119 });
109120}
110121
122+ ur_integrated_buffer_handle_t ::~ur_integrated_buffer_handle_t () {
123+ if (writeBackPtr) {
124+ std::memcpy (writeBackPtr, ptr.get (), size);
125+ }
126+ }
127+
111128void *ur_integrated_buffer_handle_t ::getDevicePtr(
112129 ur_device_handle_t /* hDevice*/ , device_access_mode_t /* access*/ ,
113130 size_t offset, size_t /* size*/ , ze_command_list_handle_t /* cmdList*/ ,
@@ -116,16 +133,53 @@ void *ur_integrated_buffer_handle_t::getDevicePtr(
116133}
117134
118135void *ur_integrated_buffer_handle_t ::mapHostPtr(
119- ur_map_flags_t /* flags*/ , size_t offset, size_t /* size */ ,
136+ ur_map_flags_t flags, size_t offset, size_t mapSize ,
120137 ze_command_list_handle_t /* cmdList*/ , wait_list_view & /* waitListView*/ ) {
121- // For integrated devices, both device and host access the same memory
138+ if (writeBackPtr) {
139+ // Copy-back path: user gets back their original pointer
140+ void *mappedPtr = ur_cast<char *>(writeBackPtr) + offset;
141+
142+ if (flags & UR_MAP_FLAG_READ) {
143+ std::memcpy (mappedPtr, ur_cast<char *>(ptr.get ()) + offset, mapSize);
144+ }
145+
146+ // Track this mapping for unmap
147+ mappedRegions.emplace_back (usm_unique_ptr_t (mappedPtr, [](void *) {}),
148+ mapSize, offset, flags);
149+
150+ return mappedPtr;
151+ }
152+
153+ // Zero-copy path: for successfully imported or USM pointers
122154 return ur_cast<char *>(ptr.get ()) + offset;
123155}
124156
125157void ur_integrated_buffer_handle_t::unmapHostPtr (
126- void * /* pMappedPtr*/ , ze_command_list_handle_t /* cmdList*/ ,
158+ void *pMappedPtr, ze_command_list_handle_t /* cmdList*/ ,
127159 wait_list_view & /* waitListView*/ ) {
128- // No-op: integrated buffers use zero-copy, no synchronization needed
160+ if (writeBackPtr) {
161+ // Copy-back path: find the mapped region and copy data back if needed
162+ auto mappedRegion =
163+ std::find_if (mappedRegions.begin (), mappedRegions.end (),
164+ [pMappedPtr](const host_allocation_desc_t &desc) {
165+ return desc.ptr .get () == pMappedPtr;
166+ });
167+
168+ if (mappedRegion == mappedRegions.end ()) {
169+ UR_DFAILURE (" could not find pMappedPtr:" << pMappedPtr);
170+ throw UR_RESULT_ERROR_INVALID_ARGUMENT;
171+ }
172+
173+ if (mappedRegion->flags &
174+ (UR_MAP_FLAG_WRITE | UR_MAP_FLAG_WRITE_INVALIDATE_REGION)) {
175+ std::memcpy (ur_cast<char *>(ptr.get ()) + mappedRegion->offset ,
176+ mappedRegion->ptr .get (), mappedRegion->size );
177+ }
178+
179+ mappedRegions.erase (mappedRegion);
180+ return ;
181+ }
182+ // No op for zero-copy path, memory is synced
129183}
130184
131185static v2::raii::command_list_unique_handle
@@ -564,47 +618,12 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
564618 void *hostPtr = pProperties ? pProperties->pHost : nullptr ;
565619 auto accessMode = ur_mem_buffer_t::getDeviceAccessMode (flags);
566620
567- // For integrated devices, we can use zero-copy host buffers when:
568- // 1. No host pointer is provided (we'll allocate USM host memory)
569- // 2. Host pointer is already USM memory
570- // 3. Host pointer can be imported as USM
571- // Otherwise, fall back to discrete buffer (explicit copies).
572- if (useHostBuffer (hContext) && hostPtr) {
573- // Check what type of memory this pointer is
574- ZeStruct<ze_memory_allocation_properties_t > memProps;
575- auto ret =
576- getMemoryAttrs (hContext->getZeHandle (), hostPtr, nullptr , &memProps);
577-
578- if (ret == UR_RESULT_SUCCESS) {
579- if (memProps.type != ZE_MEMORY_TYPE_UNKNOWN) {
580- // Already USM memory (host, device, or shared) - use integrated path
581- *phBuffer = ur_mem_handle_t_::create<ur_integrated_buffer_handle_t >(
582- hContext, hostPtr, size, accessMode);
583- return UR_RESULT_SUCCESS;
584- }
585-
586- // Memory type is UNKNOWN - try to import it
587- bool canImport =
588- maybeImportUSM (hContext->getPlatform ()->ZeDriverHandleExpTranslated ,
589- hContext->getZeHandle (), hostPtr, size);
590- if (!canImport) {
591- // Cannot import: fall back to discrete buffer path
592- *phBuffer = ur_mem_handle_t_::create<ur_discrete_buffer_handle_t >(
593- hContext, hostPtr, size, accessMode);
594- return UR_RESULT_SUCCESS;
595- }
596- // Successfully imported: release it now, constructor will import again
597- ZeUSMImport.doZeUSMRelease (
598- hContext->getPlatform ()->ZeDriverHandleExpTranslated , hostPtr);
599- } else {
600- // Cannot get memory attributes: fall back to discrete buffer
601- *phBuffer = ur_mem_handle_t_::create<ur_discrete_buffer_handle_t >(
602- hContext, hostPtr, size, accessMode);
603- return UR_RESULT_SUCCESS;
604- }
605- }
606-
607- // Use integrated buffer path (no hostPtr, or hostPtr is USM/importable)
621+ // For integrated devices, use zero-copy host buffers. The integrated buffer
622+ // constructor will handle all cases:
623+ // 1. No host pointer - allocate USM host memory
624+ // 2. Host pointer is already USM - use directly
625+ // 3. Host pointer can be imported - import it
626+ // 4. Otherwise - allocate USM and copy-back on destruction
608627 if (useHostBuffer (hContext)) {
609628 *phBuffer = ur_mem_handle_t_::create<ur_integrated_buffer_handle_t >(
610629 hContext, hostPtr, size, accessMode);
0 commit comments