Skip to content

Commit 1a4269e

Browse files
committed
allow opaque structs to compare equal if the underlying address is the same
1 parent 01f2be1 commit 1a4269e

File tree

3 files changed

+176
-0
lines changed

3 files changed

+176
-0
lines changed

cuda_bindings/cuda/bindings/driver.pyx.in

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6525,6 +6525,10 @@ cdef class CUcontext:
65256525
return '<CUcontext ' + str(hex(self.__int__())) + '>'
65266526
def __index__(self):
65276527
return self.__int__()
6528+
def __eq__(self, other):
6529+
if not isinstance(other, CUcontext):
6530+
return False
6531+
return self._pvt_ptr[0] == (<CUcontext>other)._pvt_ptr[0]
65286532
def __int__(self):
65296533
return <void_ptr>self._pvt_ptr[0]
65306534
def getPtr(self):
@@ -6556,6 +6560,10 @@ cdef class CUmodule:
65566560
return '<CUmodule ' + str(hex(self.__int__())) + '>'
65576561
def __index__(self):
65586562
return self.__int__()
6563+
def __eq__(self, other):
6564+
if not isinstance(other, CUmodule):
6565+
return False
6566+
return self._pvt_ptr[0] == (<CUmodule>other)._pvt_ptr[0]
65596567
def __int__(self):
65606568
return <void_ptr>self._pvt_ptr[0]
65616569
def getPtr(self):
@@ -6587,6 +6595,10 @@ cdef class CUfunction:
65876595
return '<CUfunction ' + str(hex(self.__int__())) + '>'
65886596
def __index__(self):
65896597
return self.__int__()
6598+
def __eq__(self, other):
6599+
if not isinstance(other, CUfunction):
6600+
return False
6601+
return self._pvt_ptr[0] == (<CUfunction>other)._pvt_ptr[0]
65906602
def __int__(self):
65916603
return <void_ptr>self._pvt_ptr[0]
65926604
def getPtr(self):
@@ -6618,6 +6630,10 @@ cdef class CUlibrary:
66186630
return '<CUlibrary ' + str(hex(self.__int__())) + '>'
66196631
def __index__(self):
66206632
return self.__int__()
6633+
def __eq__(self, other):
6634+
if not isinstance(other, CUlibrary):
6635+
return False
6636+
return self._pvt_ptr[0] == (<CUlibrary>other)._pvt_ptr[0]
66216637
def __int__(self):
66226638
return <void_ptr>self._pvt_ptr[0]
66236639
def getPtr(self):
@@ -6649,6 +6665,10 @@ cdef class CUkernel:
66496665
return '<CUkernel ' + str(hex(self.__int__())) + '>'
66506666
def __index__(self):
66516667
return self.__int__()
6668+
def __eq__(self, other):
6669+
if not isinstance(other, CUkernel):
6670+
return False
6671+
return self._pvt_ptr[0] == (<CUkernel>other)._pvt_ptr[0]
66526672
def __int__(self):
66536673
return <void_ptr>self._pvt_ptr[0]
66546674
def getPtr(self):
@@ -6680,6 +6700,10 @@ cdef class CUarray:
66806700
return '<CUarray ' + str(hex(self.__int__())) + '>'
66816701
def __index__(self):
66826702
return self.__int__()
6703+
def __eq__(self, other):
6704+
if not isinstance(other, CUarray):
6705+
return False
6706+
return self._pvt_ptr[0] == (<CUarray>other)._pvt_ptr[0]
66836707
def __int__(self):
66846708
return <void_ptr>self._pvt_ptr[0]
66856709
def getPtr(self):
@@ -6711,6 +6735,10 @@ cdef class CUmipmappedArray:
67116735
return '<CUmipmappedArray ' + str(hex(self.__int__())) + '>'
67126736
def __index__(self):
67136737
return self.__int__()
6738+
def __eq__(self, other):
6739+
if not isinstance(other, CUmipmappedArray):
6740+
return False
6741+
return self._pvt_ptr[0] == (<CUmipmappedArray>other)._pvt_ptr[0]
67146742
def __int__(self):
67156743
return <void_ptr>self._pvt_ptr[0]
67166744
def getPtr(self):
@@ -6742,6 +6770,10 @@ cdef class CUtexref:
67426770
return '<CUtexref ' + str(hex(self.__int__())) + '>'
67436771
def __index__(self):
67446772
return self.__int__()
6773+
def __eq__(self, other):
6774+
if not isinstance(other, CUtexref):
6775+
return False
6776+
return self._pvt_ptr[0] == (<CUtexref>other)._pvt_ptr[0]
67456777
def __int__(self):
67466778
return <void_ptr>self._pvt_ptr[0]
67476779
def getPtr(self):
@@ -6773,6 +6805,10 @@ cdef class CUsurfref:
67736805
return '<CUsurfref ' + str(hex(self.__int__())) + '>'
67746806
def __index__(self):
67756807
return self.__int__()
6808+
def __eq__(self, other):
6809+
if not isinstance(other, CUsurfref):
6810+
return False
6811+
return self._pvt_ptr[0] == (<CUsurfref>other)._pvt_ptr[0]
67766812
def __int__(self):
67776813
return <void_ptr>self._pvt_ptr[0]
67786814
def getPtr(self):
@@ -6804,6 +6840,10 @@ cdef class CUevent:
68046840
return '<CUevent ' + str(hex(self.__int__())) + '>'
68056841
def __index__(self):
68066842
return self.__int__()
6843+
def __eq__(self, other):
6844+
if not isinstance(other, CUevent):
6845+
return False
6846+
return self._pvt_ptr[0] == (<CUevent>other)._pvt_ptr[0]
68076847
def __int__(self):
68086848
return <void_ptr>self._pvt_ptr[0]
68096849
def getPtr(self):
@@ -6835,6 +6875,10 @@ cdef class CUstream:
68356875
return '<CUstream ' + str(hex(self.__int__())) + '>'
68366876
def __index__(self):
68376877
return self.__int__()
6878+
def __eq__(self, other):
6879+
if not isinstance(other, CUstream):
6880+
return False
6881+
return self._pvt_ptr[0] == (<CUstream>other)._pvt_ptr[0]
68386882
def __int__(self):
68396883
return <void_ptr>self._pvt_ptr[0]
68406884
def getPtr(self):
@@ -6866,6 +6910,10 @@ cdef class CUgraphicsResource:
68666910
return '<CUgraphicsResource ' + str(hex(self.__int__())) + '>'
68676911
def __index__(self):
68686912
return self.__int__()
6913+
def __eq__(self, other):
6914+
if not isinstance(other, CUgraphicsResource):
6915+
return False
6916+
return self._pvt_ptr[0] == (<CUgraphicsResource>other)._pvt_ptr[0]
68696917
def __int__(self):
68706918
return <void_ptr>self._pvt_ptr[0]
68716919
def getPtr(self):
@@ -6897,6 +6945,10 @@ cdef class CUexternalMemory:
68976945
return '<CUexternalMemory ' + str(hex(self.__int__())) + '>'
68986946
def __index__(self):
68996947
return self.__int__()
6948+
def __eq__(self, other):
6949+
if not isinstance(other, CUexternalMemory):
6950+
return False
6951+
return self._pvt_ptr[0] == (<CUexternalMemory>other)._pvt_ptr[0]
69006952
def __int__(self):
69016953
return <void_ptr>self._pvt_ptr[0]
69026954
def getPtr(self):
@@ -6928,6 +6980,10 @@ cdef class CUexternalSemaphore:
69286980
return '<CUexternalSemaphore ' + str(hex(self.__int__())) + '>'
69296981
def __index__(self):
69306982
return self.__int__()
6983+
def __eq__(self, other):
6984+
if not isinstance(other, CUexternalSemaphore):
6985+
return False
6986+
return self._pvt_ptr[0] == (<CUexternalSemaphore>other)._pvt_ptr[0]
69316987
def __int__(self):
69326988
return <void_ptr>self._pvt_ptr[0]
69336989
def getPtr(self):
@@ -6959,6 +7015,10 @@ cdef class CUgraph:
69597015
return '<CUgraph ' + str(hex(self.__int__())) + '>'
69607016
def __index__(self):
69617017
return self.__int__()
7018+
def __eq__(self, other):
7019+
if not isinstance(other, CUgraph):
7020+
return False
7021+
return self._pvt_ptr[0] == (<CUgraph>other)._pvt_ptr[0]
69627022
def __int__(self):
69637023
return <void_ptr>self._pvt_ptr[0]
69647024
def getPtr(self):
@@ -6990,6 +7050,10 @@ cdef class CUgraphNode:
69907050
return '<CUgraphNode ' + str(hex(self.__int__())) + '>'
69917051
def __index__(self):
69927052
return self.__int__()
7053+
def __eq__(self, other):
7054+
if not isinstance(other, CUgraphNode):
7055+
return False
7056+
return self._pvt_ptr[0] == (<CUgraphNode>other)._pvt_ptr[0]
69937057
def __int__(self):
69947058
return <void_ptr>self._pvt_ptr[0]
69957059
def getPtr(self):
@@ -7021,6 +7085,10 @@ cdef class CUgraphExec:
70217085
return '<CUgraphExec ' + str(hex(self.__int__())) + '>'
70227086
def __index__(self):
70237087
return self.__int__()
7088+
def __eq__(self, other):
7089+
if not isinstance(other, CUgraphExec):
7090+
return False
7091+
return self._pvt_ptr[0] == (<CUgraphExec>other)._pvt_ptr[0]
70247092
def __int__(self):
70257093
return <void_ptr>self._pvt_ptr[0]
70267094
def getPtr(self):
@@ -7052,6 +7120,10 @@ cdef class CUmemoryPool:
70527120
return '<CUmemoryPool ' + str(hex(self.__int__())) + '>'
70537121
def __index__(self):
70547122
return self.__int__()
7123+
def __eq__(self, other):
7124+
if not isinstance(other, CUmemoryPool):
7125+
return False
7126+
return self._pvt_ptr[0] == (<CUmemoryPool>other)._pvt_ptr[0]
70557127
def __int__(self):
70567128
return <void_ptr>self._pvt_ptr[0]
70577129
def getPtr(self):
@@ -7083,6 +7155,10 @@ cdef class CUuserObject:
70837155
return '<CUuserObject ' + str(hex(self.__int__())) + '>'
70847156
def __index__(self):
70857157
return self.__int__()
7158+
def __eq__(self, other):
7159+
if not isinstance(other, CUuserObject):
7160+
return False
7161+
return self._pvt_ptr[0] == (<CUuserObject>other)._pvt_ptr[0]
70867162
def __int__(self):
70877163
return <void_ptr>self._pvt_ptr[0]
70887164
def getPtr(self):
@@ -7114,6 +7190,10 @@ cdef class CUgraphDeviceNode:
71147190
return '<CUgraphDeviceNode ' + str(hex(self.__int__())) + '>'
71157191
def __index__(self):
71167192
return self.__int__()
7193+
def __eq__(self, other):
7194+
if not isinstance(other, CUgraphDeviceNode):
7195+
return False
7196+
return self._pvt_ptr[0] == (<CUgraphDeviceNode>other)._pvt_ptr[0]
71177197
def __int__(self):
71187198
return <void_ptr>self._pvt_ptr[0]
71197199
def getPtr(self):
@@ -7145,6 +7225,10 @@ cdef class CUasyncCallbackHandle:
71457225
return '<CUasyncCallbackHandle ' + str(hex(self.__int__())) + '>'
71467226
def __index__(self):
71477227
return self.__int__()
7228+
def __eq__(self, other):
7229+
if not isinstance(other, CUasyncCallbackHandle):
7230+
return False
7231+
return self._pvt_ptr[0] == (<CUasyncCallbackHandle>other)._pvt_ptr[0]
71487232
def __int__(self):
71497233
return <void_ptr>self._pvt_ptr[0]
71507234
def getPtr(self):
@@ -7176,6 +7260,10 @@ cdef class CUgreenCtx:
71767260
return '<CUgreenCtx ' + str(hex(self.__int__())) + '>'
71777261
def __index__(self):
71787262
return self.__int__()
7263+
def __eq__(self, other):
7264+
if not isinstance(other, CUgreenCtx):
7265+
return False
7266+
return self._pvt_ptr[0] == (<CUgreenCtx>other)._pvt_ptr[0]
71797267
def __int__(self):
71807268
return <void_ptr>self._pvt_ptr[0]
71817269
def getPtr(self):
@@ -7205,6 +7293,10 @@ cdef class CUlinkState:
72057293
return '<CUlinkState ' + str(hex(self.__int__())) + '>'
72067294
def __index__(self):
72077295
return self.__int__()
7296+
def __eq__(self, other):
7297+
if not isinstance(other, CUlinkState):
7298+
return False
7299+
return self._pvt_ptr[0] == (<CUlinkState>other)._pvt_ptr[0]
72087300
def __int__(self):
72097301
return <void_ptr>self._pvt_ptr[0]
72107302
def getPtr(self):
@@ -7236,6 +7328,10 @@ cdef class CUdevResourceDesc:
72367328
return '<CUdevResourceDesc ' + str(hex(self.__int__())) + '>'
72377329
def __index__(self):
72387330
return self.__int__()
7331+
def __eq__(self, other):
7332+
if not isinstance(other, CUdevResourceDesc):
7333+
return False
7334+
return self._pvt_ptr[0] == (<CUdevResourceDesc>other)._pvt_ptr[0]
72397335
def __int__(self):
72407336
return <void_ptr>self._pvt_ptr[0]
72417337
def getPtr(self):
@@ -7265,6 +7361,10 @@ cdef class CUlogsCallbackHandle:
72657361
return '<CUlogsCallbackHandle ' + str(hex(self.__int__())) + '>'
72667362
def __index__(self):
72677363
return self.__int__()
7364+
def __eq__(self, other):
7365+
if not isinstance(other, CUlogsCallbackHandle):
7366+
return False
7367+
return self._pvt_ptr[0] == (<CUlogsCallbackHandle>other)._pvt_ptr[0]
72687368
def __int__(self):
72697369
return <void_ptr>self._pvt_ptr[0]
72707370
def getPtr(self):
@@ -7296,6 +7396,10 @@ cdef class CUeglStreamConnection:
72967396
return '<CUeglStreamConnection ' + str(hex(self.__int__())) + '>'
72977397
def __index__(self):
72987398
return self.__int__()
7399+
def __eq__(self, other):
7400+
if not isinstance(other, CUeglStreamConnection):
7401+
return False
7402+
return self._pvt_ptr[0] == (<CUeglStreamConnection>other)._pvt_ptr[0]
72997403
def __int__(self):
73007404
return <void_ptr>self._pvt_ptr[0]
73017405
def getPtr(self):
@@ -7325,6 +7429,10 @@ cdef class EGLImageKHR:
73257429
return '<EGLImageKHR ' + str(hex(self.__int__())) + '>'
73267430
def __index__(self):
73277431
return self.__int__()
7432+
def __eq__(self, other):
7433+
if not isinstance(other, EGLImageKHR):
7434+
return False
7435+
return self._pvt_ptr[0] == (<EGLImageKHR>other)._pvt_ptr[0]
73287436
def __int__(self):
73297437
return <void_ptr>self._pvt_ptr[0]
73307438
def getPtr(self):
@@ -7354,6 +7462,10 @@ cdef class EGLStreamKHR:
73547462
return '<EGLStreamKHR ' + str(hex(self.__int__())) + '>'
73557463
def __index__(self):
73567464
return self.__int__()
7465+
def __eq__(self, other):
7466+
if not isinstance(other, EGLStreamKHR):
7467+
return False
7468+
return self._pvt_ptr[0] == (<EGLStreamKHR>other)._pvt_ptr[0]
73577469
def __int__(self):
73587470
return <void_ptr>self._pvt_ptr[0]
73597471
def getPtr(self):
@@ -7383,6 +7495,10 @@ cdef class EGLSyncKHR:
73837495
return '<EGLSyncKHR ' + str(hex(self.__int__())) + '>'
73847496
def __index__(self):
73857497
return self.__int__()
7498+
def __eq__(self, other):
7499+
if not isinstance(other, EGLSyncKHR):
7500+
return False
7501+
return self._pvt_ptr[0] == (<EGLSyncKHR>other)._pvt_ptr[0]
73867502
def __int__(self):
73877503
return <void_ptr>self._pvt_ptr[0]
73887504
def getPtr(self):

cuda_bindings/cuda/bindings/nvrtc.pyx.in

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ cdef class nvrtcProgram:
109109
return '<nvrtcProgram ' + str(hex(self.__int__())) + '>'
110110
def __index__(self):
111111
return self.__int__()
112+
def __eq__(self, other):
113+
if not isinstance(other, nvrtcProgram):
114+
return False
115+
return self._pvt_ptr[0] == (<nvrtcProgram>other)._pvt_ptr[0]
112116
def __int__(self):
113117
return <void_ptr>self._pvt_ptr[0]
114118
def getPtr(self):

0 commit comments

Comments
 (0)