@@ -57,31 +57,53 @@ ur_integrated_buffer_handle_t::ur_integrated_buffer_handle_t(
5757 ur_context_handle_t hContext, void *hostPtr, size_t size,
5858 device_access_mode_t accessMode)
5959 : ur_mem_buffer_t (hContext, size, accessMode) {
60- bool hostPtrImported =
61- maybeImportUSM (hContext->getPlatform ()->ZeDriverHandleExpTranslated ,
62- hContext->getZeHandle (), hostPtr, size);
63-
64- if (hostPtrImported) {
65- this ->ptr = usm_unique_ptr_t (hostPtr, [hContext](void *ptr) {
66- ZeUSMImport.doZeUSMRelease (
67- hContext->getPlatform ()->ZeDriverHandleExpTranslated , ptr);
68- });
69- } else {
70- void *rawPtr;
71- UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
72- hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &rawPtr));
60+ if (hostPtr) {
61+ // Host pointer provided - check if it's already USM or needs import
62+ ZeStruct<ze_memory_allocation_properties_t > memProps;
63+ auto ret =
64+ getMemoryAttrs (hContext->getZeHandle (), hostPtr, nullptr , &memProps);
65+
66+ if (ret == UR_RESULT_SUCCESS && memProps.type != ZE_MEMORY_TYPE_UNKNOWN) {
67+ // Already a USM allocation - just use it directly without import
68+ this ->ptr = usm_unique_ptr_t (hostPtr, [](void *) {});
69+ return ;
70+ }
7371
74- this ->ptr = usm_unique_ptr_t (rawPtr, [hContext](void *ptr) {
75- auto ret = hContext->getDefaultUSMPool ()->free (ptr);
76- if (ret != UR_RESULT_SUCCESS) {
77- UR_LOG (ERR, " Failed to free host memory: {}" , ret);
78- }
79- });
72+ // Not USM - try to import it
73+ bool hostPtrImported =
74+ maybeImportUSM (hContext->getPlatform ()->ZeDriverHandleExpTranslated ,
75+ hContext->getZeHandle (), hostPtr, size);
8076
81- if (hostPtr) {
82- std::memcpy (this ->ptr .get (), hostPtr, size);
83- writeBackPtr = hostPtr;
77+ if (hostPtrImported) {
78+ // Successfully imported - use it with release
79+ this ->ptr = usm_unique_ptr_t (hostPtr, [hContext](void *ptr) {
80+ ZeUSMImport.doZeUSMRelease (
81+ hContext->getPlatform ()->ZeDriverHandleExpTranslated , ptr);
82+ });
83+ // No copy-back needed for imported pointers
84+ return ;
8485 }
86+
87+ // Import failed - allocate backing buffer and set up copy-back
88+ }
89+
90+ // No host pointer, or import failed - allocate new USM host memory
91+ void *rawPtr;
92+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
93+ hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &rawPtr));
94+
95+ this ->ptr = usm_unique_ptr_t (rawPtr, [hContext](void *ptr) {
96+ auto ret = hContext->getDefaultUSMPool ()->free (ptr);
97+ if (ret != UR_RESULT_SUCCESS) {
98+ UR_LOG (ERR, " Failed to free host memory: {}" , ret);
99+ }
100+ });
101+
102+ if (hostPtr) {
103+ // Copy data from user pointer to our backing buffer
104+ std::memcpy (this ->ptr .get (), hostPtr, size);
105+ // Remember to copy back on destruction
106+ writeBackPtr = hostPtr;
85107 }
86108}
87109
@@ -111,20 +133,53 @@ void *ur_integrated_buffer_handle_t::getDevicePtr(
111133}
112134
113135void *ur_integrated_buffer_handle_t ::mapHostPtr(
114- ur_map_flags_t /* flags*/ , size_t offset, size_t /* size */ ,
136+ ur_map_flags_t flags, size_t offset, size_t mapSize ,
115137 ze_command_list_handle_t /* cmdList*/ , wait_list_view & /* waitListView*/ ) {
116- // TODO: if writeBackPtr is set, we should map to that pointer
117- // because that's what SYCL expects, SYCL will attempt to call free
118- // on the resulting pointer leading to double free with the current
119- // implementation. Investigate the SYCL implementation.
138+ if (writeBackPtr) {
139+ // Copy-back path: user gets back their original pointer
140+ void *mappedPtr = ur_cast<char *>(writeBackPtr) + offset;
141+
142+ if (flags & UR_MAP_FLAG_READ) {
143+ std::memcpy (mappedPtr, ur_cast<char *>(ptr.get ()) + offset, mapSize);
144+ }
145+
146+ // Track this mapping for unmap
147+ mappedRegions.emplace_back (usm_unique_ptr_t (mappedPtr, [](void *) {}),
148+ mapSize, offset, flags);
149+
150+ return mappedPtr;
151+ }
152+
153+ // Zero-copy path: for successfully imported or USM pointers
120154 return ur_cast<char *>(ptr.get ()) + offset;
121155}
122156
123157void ur_integrated_buffer_handle_t::unmapHostPtr (
124- void * /* pMappedPtr*/ , ze_command_list_handle_t /* cmdList*/ ,
158+ void *pMappedPtr, ze_command_list_handle_t /* cmdList*/ ,
125159 wait_list_view & /* waitListView*/ ) {
126- // TODO: if writeBackPtr is set, we should copy the data back
127- /* nop */
160+ if (writeBackPtr) {
161+ // Copy-back path: find the mapped region and copy data back if needed
162+ auto mappedRegion =
163+ std::find_if (mappedRegions.begin (), mappedRegions.end (),
164+ [pMappedPtr](const host_allocation_desc_t &desc) {
165+ return desc.ptr .get () == pMappedPtr;
166+ });
167+
168+ if (mappedRegion == mappedRegions.end ()) {
169+ UR_DFAILURE (" could not find pMappedPtr:" << pMappedPtr);
170+ throw UR_RESULT_ERROR_INVALID_ARGUMENT;
171+ }
172+
173+ if (mappedRegion->flags &
174+ (UR_MAP_FLAG_WRITE | UR_MAP_FLAG_WRITE_INVALIDATE_REGION)) {
175+ std::memcpy (ur_cast<char *>(ptr.get ()) + mappedRegion->offset ,
176+ mappedRegion->ptr .get (), mappedRegion->size );
177+ }
178+
179+ mappedRegions.erase (mappedRegion);
180+ return ;
181+ }
182+ // No op for zero-copy path, memory is synced
128183}
129184
130185static v2::raii::command_list_unique_handle
@@ -410,19 +465,16 @@ void ur_shared_buffer_handle_t::unmapHostPtr(
410465 // nop
411466}
412467
413- static bool useHostBuffer (ur_context_handle_t /* hContext */ ) {
468+ static bool useHostBuffer (ur_context_handle_t hContext) {
414469 // We treat integrated devices (physical memory shared with the CPU)
415470 // differently from discrete devices (those with distinct memories).
416471 // For integrated devices, allocating the buffer in the host memory
417472 // enables automatic access from the device, and makes copying
418473 // unnecessary in the map/unmap operations. This improves performance.
419474
420- // TODO: fix integrated buffer implementation
421- return false ;
422-
423- // return hContext->getDevices().size() == 1 &&
424- // hContext->getDevices()[0]->ZeDeviceProperties->flags &
425- // ZE_DEVICE_PROPERTY_FLAG_INTEGRATED;
475+ return hContext->getDevices ().size () == 1 &&
476+ hContext->getDevices ()[0 ]->ZeDeviceProperties ->flags &
477+ ZE_DEVICE_PROPERTY_FLAG_INTEGRATED;
426478}
427479
428480ur_mem_sub_buffer_t ::ur_mem_sub_buffer_t (ur_mem_handle_t hParent, size_t offset,
0 commit comments