Skip to content

Commit

Permalink
fix(sqlite): use source file path when overriding default datasource
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Jul 29, 2024
1 parent 2da273c commit 23ee9bf
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/prisma/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,13 @@ def _make_sqlite_datasource(self) -> DatasourceOverride:
"""
return {
'name': self._default_datasource['name'],
'url': self._make_sqlite_url(self._default_datasource['url']),
'url': self._make_sqlite_url(
self._default_datasource['url'],
relative_to=self._default_datasource.get('source_file_path'),
),
}

def _make_sqlite_url(self, url: str, *, relative_to: Path | None = None) -> str:
def _make_sqlite_url(self, url: str, *, relative_to: Path | str | None = None) -> str:
url_path = removeprefix(removeprefix(url, 'file:'), 'sqlite:')
if url_path == url:
return url
Expand All @@ -242,6 +245,9 @@ def _make_sqlite_url(self, url: str, *, relative_to: Path | None = None) -> str:
if relative_to is None:
relative_to = self._schema_path.parent

if isinstance(relative_to, str):
relative_to = Path(relative_to)

return f'file:{relative_to.joinpath(url_path).resolve()}'

def _prepare_connect_args(
Expand All @@ -268,10 +274,12 @@ def _prepare_connect_args(
ds.setdefault('name', self._default_datasource_name)
datasources = [ds]
elif self._active_provider == 'sqlite':
log.debug('overriding default SQLite datasource path')
# Override the default SQLite path to protect against
# https://github.com/RobertCraigie/prisma-client-py/issues/409
datasources = [self._make_sqlite_datasource()]

log.debug('datasources %s', datasources)
return timeout, datasources

def _make_query_builder(
Expand Down
1 change: 1 addition & 0 deletions src/prisma/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class DatasourceOverride(_DatasourceOverrideOptional):

class _DatasourceOptional(TypedDict, total=False):
env: str
source_file_path: str | None


class Datasource(_DatasourceOptional):
Expand Down
2 changes: 2 additions & 0 deletions src/prisma/generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ class Datasource(BaseModel):
active_provider: str = FieldInfo(alias='activeProvider')
url: 'OptionalValueFromEnvVar'

source_file_path: Optional[str] = FieldInfo(alias='sourceFilePath')


class Generator(GenericModel, Generic[ConfigT]):
name: str
Expand Down
1 change: 1 addition & 0 deletions src/prisma/generator/templates/client.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class Prisma({% if is_async %}AsyncBasePrisma{% else %}SyncBasePrisma{% endif %}
return {
'name': '{{ datasources[0].name }}',
'url': OptionalValueFromEnvVar(**{{ model_dict(datasources[0].url, by_alias=True) }}).resolve(),
'source_file_path': '{{ datasources[0].source_file_path }}',
}

{% if active_provider != 'mongodb' %}
Expand Down

0 comments on commit 23ee9bf

Please sign in to comment.