Skip to content

Commit

Permalink
fix: guard concurrent extension datatype setting with a lock (#3589)
Browse files Browse the repository at this point in the history
Fixes a bug where concurrent accesses to
`_ensure_registered_super_ext_type` might potentially cause race
conditions, erroring out on multiple calls to
`pa.register_extension_type(DaftExtension(pa.null()))` from different
threads.

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Dec 18, 2024
1 parent 4bb0413 commit 855a02d
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions daft/datatype.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import threading
from typing import TYPE_CHECKING, Union

from daft.context import get_context
Expand Down Expand Up @@ -576,38 +577,44 @@ def __hash__(self) -> int:
DataTypeLike = Union[DataType, type]


_EXT_TYPE_REGISTRATION_LOCK = threading.Lock()
_EXT_TYPE_REGISTERED = False
_STATIC_DAFT_EXTENSION = None


def _ensure_registered_super_ext_type():
global _EXT_TYPE_REGISTERED
global _STATIC_DAFT_EXTENSION

# Double-checked locking: avoid grabbing the lock if we know that the ext type
# has already been registered.
if not _EXT_TYPE_REGISTERED:
with _EXT_TYPE_REGISTRATION_LOCK:
if not _EXT_TYPE_REGISTERED:

class DaftExtension(pa.ExtensionType):
def __init__(self, dtype, metadata=b""):
# attributes need to be set first before calling
# super init (as that calls serialize)
self._metadata = metadata
super().__init__(dtype, "daft.super_extension")
class DaftExtension(pa.ExtensionType):
def __init__(self, dtype, metadata=b""):
# attributes need to be set first before calling
# super init (as that calls serialize)
self._metadata = metadata
super().__init__(dtype, "daft.super_extension")

def __reduce__(self):
return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())
def __reduce__(self):
return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())

def __arrow_ext_serialize__(self):
return self._metadata
def __arrow_ext_serialize__(self):
return self._metadata

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
return cls(storage_type, serialized)
@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
return cls(storage_type, serialized)

_STATIC_DAFT_EXTENSION = DaftExtension
pa.register_extension_type(DaftExtension(pa.null()))
import atexit
_STATIC_DAFT_EXTENSION = DaftExtension
pa.register_extension_type(DaftExtension(pa.null()))
import atexit

atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
_EXT_TYPE_REGISTERED = True
atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
_EXT_TYPE_REGISTERED = True


def get_super_ext_type():
Expand Down

0 comments on commit 855a02d

Please sign in to comment.