diff --git a/.github/dependabot.yml b/.github/dependabot.yml index a39cef8b6b5..55ae20da8df 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -15,11 +15,6 @@ updates: schedule: interval: "daily" - - package-ecosystem: "docker" - directory: "/lambdas/molecule" - schedule: - interval: "daily" - - package-ecosystem: "docker" directory: "/catalog" schedule: diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index a1eca6112fc..0a2d81dd7e3 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -9,6 +9,7 @@ - [ ] Automated tests (e.g. Preflight) - [ ] Confirm that this change meets security best practices and does not violate the security model - [ ] Documentation + - [ ] run `optipng` on any new PNGs - [ ] [Python: Run `build.py`](../tree/master/gendocs/build.py) for new docstrings - [ ] JavaScript: basic explanation and screenshot of new features - [ ] Markdown somewhere in docs/**/*.md that explains the feature to end users (said .md files should be linked from SUMMARY.md so they appear on https://docs.quiltdata.com) diff --git a/.github/workflows/deploy-catalog.yaml b/.github/workflows/deploy-catalog.yaml new file mode 100644 index 00000000000..fc4f8aed0fe --- /dev/null +++ b/.github/workflows/deploy-catalog.yaml @@ -0,0 +1,68 @@ +name: Deploy catalog to ECR + +on: + push: + branches: + - master + paths: + - '.github/workflows/deploy-catalog.yaml' + - 'catalog/**' + - 'shared/**' + +jobs: + deploy-catalog-ecr: + runs-on: ubuntu-latest + defaults: + run: + working-directory: catalog + # These permissions are needed to interact with GitHub's OIDC Token endpoint. + permissions: + id-token: write + contents: read + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version-file: 'catalog/package.json' + cache: 'npm' + cache-dependency-path: 'catalog/package-lock.json' + - run: npm ci + - run: npm run build + - name: Configure AWS credentials from Prod account + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::730278974607:role/github/GitHub-Quilt + aws-region: us-east-1 + - name: Login to Prod ECR + id: login-prod-ecr + uses: aws-actions/amazon-ecr-login@v2 + - name: Login to MP ECR + id: login-mp-ecr + uses: aws-actions/amazon-ecr-login@v2 + with: + registries: 709825985650 + - name: Configure AWS credentials from GovCloud account + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws-us-gov:iam::313325871032:role/github/GitHub-Quilt + aws-region: us-gov-east-1 + - name: Login to GovCloud ECR + id: login-govcloud-ecr + uses: aws-actions/amazon-ecr-login@v2 + - name: Build and push Docker image to Prod, MP and GovCloud ECR + env: + ECR_REGISTRY_PROD: ${{ steps.login-prod-ecr.outputs.registry }} + ECR_REGISTRY_GOVCLOUD: ${{ steps.login-govcloud-ecr.outputs.registry }} + ECR_REGISTRY_MP: ${{ steps.login-mp-ecr.outputs.registry }} + ECR_REPOSITORY: quiltdata/catalog + ECR_REPOSITORY_MP: quilt-data/quilt-payg-catalog + IMAGE_TAG: ${{ github.sha }} + run: | + docker buildx build \ + -t $ECR_REGISTRY_PROD/$ECR_REPOSITORY:$IMAGE_TAG \ + -t $ECR_REGISTRY_GOVCLOUD/$ECR_REPOSITORY:$IMAGE_TAG \ + -t $ECR_REGISTRY_MP/$ECR_REPOSITORY_MP:$IMAGE_TAG \ + . 
+ docker push $ECR_REGISTRY_PROD/$ECR_REPOSITORY:$IMAGE_TAG + docker push $ECR_REGISTRY_GOVCLOUD/$ECR_REPOSITORY:$IMAGE_TAG + docker push $ECR_REGISTRY_MP/$ECR_REPOSITORY_MP:$IMAGE_TAG diff --git a/.github/workflows/deploy-lambdas.yaml b/.github/workflows/deploy-lambdas.yaml new file mode 100644 index 00000000000..e453ef18f5c --- /dev/null +++ b/.github/workflows/deploy-lambdas.yaml @@ -0,0 +1,91 @@ +name: Deploy lambdas to S3 and ECR + +on: + push: + branches: + - master + paths: + - '.github/workflows/deploy-lambdas.yaml' + - 'lambdas/**' + +jobs: + deploy-lambda-s3: + strategy: + matrix: + path: + - access_counts + - indexer + - pkgevents + - pkgpush + - preview + - s3hash + - status_reports + - tabular_preview + - transcode + runs-on: ubuntu-latest + # These permissions are needed to interact with GitHub's OIDC Token endpoint. + permissions: + id-token: write + contents: read + steps: + - uses: actions/checkout@v4 + - name: Build zip + run: | + BUILDER_IMAGE=quiltdata/lambda:build-$(cat lambdas/${{ matrix.path }}/.python-version) + + touch ./out.zip + + docker run --rm \ + --entrypoint /build_zip.sh \ + -v "$PWD/lambdas/${{ matrix.path }}":/lambda/function:z \ + -v "$PWD/out.zip":/out.zip:z \ + -v "$PWD/lambdas/scripts/build_zip.sh":/build_zip.sh:z \ + "$BUILDER_IMAGE" + - name: Configure AWS credentials from Prod account + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::730278974607:role/github/GitHub-Quilt + aws-region: us-east-1 + - name: Upload zips to Prod S3 + run: | + s3_key="${{ matrix.path }}/${{ github.sha }}.zip" + ./lambdas/scripts/upload_zip.sh ./out.zip "$AWS_REGION" "$s3_key" + - name: Configure AWS credentials from GovCloud account + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws-us-gov:iam::313325871032:role/github/GitHub-Quilt + aws-region: us-gov-east-1 + - name: Upload zips to GovCloud S3 + run: | + s3_key="${{ matrix.path }}/${{ github.sha }}.zip" + ./lambdas/scripts/upload_zip.sh ./out.zip "$AWS_REGION" "$s3_key" + + deploy-lambda-ecr: + strategy: + matrix: + path: + - thumbnail + runs-on: ubuntu-latest + # These permissions are needed to interact with GitHub's OIDC Token endpoint. + permissions: + id-token: write + contents: read + steps: + - uses: actions/checkout@v4 + - name: Build Docker image + working-directory: ./lambdas/${{ matrix.path }} + run: docker buildx build -t "quiltdata/lambdas/${{ matrix.path }}:${{ github.sha }}" -f Dockerfile . 
+ - name: Configure AWS credentials from Prod account + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::730278974607:role/github/GitHub-Quilt + aws-region: us-east-1 + - name: Push Docker image to Prod ECR + run: ./lambdas/scripts/upload_ecr.sh 730278974607 "quiltdata/lambdas/${{ matrix.path }}:${{ github.sha }}" + - name: Configure AWS credentials from GovCloud account + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws-us-gov:iam::313325871032:role/github/GitHub-Quilt + aws-region: us-gov-east-1 + - name: Push Docker image to GovCloud ECR + run: ./lambdas/scripts/upload_ecr.sh 313325871032 "quiltdata/lambdas/${{ matrix.path }}:${{ github.sha }}" diff --git a/.github/workflows/deploy-s3-proxy.yaml b/.github/workflows/deploy-s3-proxy.yaml new file mode 100644 index 00000000000..0548ca45e81 --- /dev/null +++ b/.github/workflows/deploy-s3-proxy.yaml @@ -0,0 +1,62 @@ +name: Deploy S3 Proxy to ECR + +on: + push: + branches: + - master + paths: + - .github/workflows/deploy-s3-proxy.yaml + - 's3-proxy/**' + +jobs: + deploy-s3-proxy: + runs-on: ubuntu-latest + defaults: + run: + working-directory: s3-proxy + # These permissions are needed to interact with GitHub's OIDC Token endpoint. + permissions: + id-token: write + contents: read + steps: + - uses: actions/checkout@v4 + - name: Configure AWS credentials from Prod account + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::730278974607:role/github/GitHub-Quilt + aws-region: us-east-1 + - name: Login to Prod ECR + id: login-prod-ecr + uses: aws-actions/amazon-ecr-login@v2 + - name: Login to MP ECR + id: login-mp-ecr + uses: aws-actions/amazon-ecr-login@v2 + with: + registries: 709825985650 + - name: Configure AWS credentials from GovCloud account + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws-us-gov:iam::313325871032:role/github/GitHub-Quilt + aws-region: us-gov-east-1 + - name: Login to GovCloud ECR + id: login-govcloud-ecr + uses: aws-actions/amazon-ecr-login@v2 + - name: Build and push Docker image to ECR + env: + ECR_REGISTRY_PROD: ${{ steps.login-prod-ecr.outputs.registry }} + ECR_REGISTRY_GOVCLOUD: ${{ steps.login-govcloud-ecr.outputs.registry }} + ECR_REGISTRY_MP: ${{ steps.login-mp-ecr.outputs.registry }} + ECR_REPOSITORY: quiltdata/s3-proxy + ECR_REPOSITORY_MP: quilt-data/quilt-payg-s3-proxy + IMAGE_TAG: ${{ github.sha }} + run: | + docker buildx build \ + -t $ECR_REGISTRY_PROD/$ECR_REPOSITORY:$IMAGE_TAG \ + -t $ECR_REGISTRY_GOVCLOUD/$ECR_REPOSITORY:$IMAGE_TAG \ + -t $ECR_REGISTRY_MP/$ECR_REPOSITORY_MP:$IMAGE_TAG \ + . 
+ docker push $ECR_REGISTRY_PROD/$ECR_REPOSITORY:$IMAGE_TAG + docker push $ECR_REGISTRY_GOVCLOUD/$ECR_REPOSITORY:$IMAGE_TAG + # push to MP last because it can't be re-pushed using the same tag + # so we can re-run the job in case something has failed + docker push $ECR_REGISTRY_MP/$ECR_REPOSITORY_MP:$IMAGE_TAG diff --git a/.github/workflows/js-ci.yml b/.github/workflows/js-ci.yml index 47ca43b59bb..c6e23434716 100644 --- a/.github/workflows/js-ci.yml +++ b/.github/workflows/js-ci.yml @@ -12,13 +12,11 @@ jobs: defaults: run: working-directory: catalog - env: - NODE_OPTIONS: --max-old-space-size=4096 steps: - uses: actions/checkout@v4 - uses: actions/setup-node@v4 with: - node-version: '16.11' + node-version-file: 'catalog/package.json' cache: 'npm' cache-dependency-path: 'catalog/package-lock.json' - run: npm ci @@ -27,20 +25,19 @@ jobs: - env: BUNDLEWATCH_GITHUB_TOKEN: ${{ secrets.BUNDLEWATCH_GITHUB_TOKEN }} run: npm run bundlewatch - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} with: flags: catalog name: ${{ github.job }} lint-docs: runs-on: ubuntu-latest - defaults: - run: - working-directory: docs env: NODE_OPTIONS: --max-old-space-size=4096 steps: - uses: actions/checkout@v4 - uses: actions/setup-node@v4 with: - node-version: '16.11' - - run: npx --package=markdownlint-cli markdownlint --ignore node_modules **/*.md + node-version-file: 'catalog/package.json' + - run: npx --package=markdownlint-cli markdownlint . diff --git a/.github/workflows/py-ci.yml b/.github/workflows/py-ci.yml index 9574aceefd8..f0b467ea918 100644 --- a/.github/workflows/py-ci.yml +++ b/.github/workflows/py-ci.yml @@ -13,14 +13,14 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.7' + python-version: '3.11' - name: Install dependencies run: | python -m pip install --upgrade pip setuptools - python -m pip install 'pylint==2.10.2' 'pycodestyle>=2.6.1' + python -m pip install 'pylint==3.2.7' 'pycodestyle>=2.6.1' - name: Run pylint run: | - pylint $(find -name '*.py' -not -path './venv/*') + pylint . 
- name: Run pycodestyle run: | pycodestyle $(find -name '*.py' -not -path './venv/*') @@ -31,7 +31,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.7' + python-version: '3.11' - name: Install dependencies run: | python -m pip install --upgrade pip setuptools @@ -48,11 +48,13 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.7' + # on newer versions it crashes with + # TypeError: )> is not a callable object + python-version: '3.9' - name: install deps run: | python -m pip install --upgrade pip setuptools - python -m pip install PyYAML~=5.4.1 api/python nbconvert git+https://github.com/quiltdata/pydoc-markdown.git@quilt + python -m pip install api/python nbconvert git+https://github.com/quiltdata/pydoc-markdown.git@v2.0.5+quilt3.2 - name: generate docs run: cd gendocs && python build.py - name: show invisible changes via cat @@ -68,7 +70,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.11' - name: install poetry run: python -m pip install poetry - name: install deps @@ -80,7 +82,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] + python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] runs-on: ${{ matrix.os }} env: QUILT_DISABLE_USAGE_METRICS: true @@ -96,8 +98,9 @@ jobs: - name: Run Pytest run: | pytest --cov=api/python api/python - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} OS: ${{ matrix.os }} PYTHON_VERSION: ${{ matrix.python-version }} with: @@ -130,15 +133,15 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.7' + python-version: '3.12' + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools + python -m pip install build==1.2.2.post1 twine==5.1.1 - name: verify git tag vs. version env: CIRCLE_TAG: ${{ github.ref_name }} run: python setup.py verify - - name: Install dependencies - run: | - python -m pip install --upgrade pip setuptools - python -m pip install build==0.8.0 twine==4.0.0 - name: build run: python -m build - name: upload to PyPI @@ -149,25 +152,20 @@ jobs: test-lambda: strategy: + fail-fast: false matrix: path: - access_counts - - es/indexer - - molecule + - indexer - pkgevents - pkgpush - - pkgselect - preview - s3hash - - s3select - shared - status_reports - tabular_preview + - thumbnail - transcode - python-version: ['3.7'] - include: - - path: thumbnail - python-version: '3.9' runs-on: ubuntu-latest env: QUILT_DISABLE_USAGE_METRICS: true @@ -177,33 +175,33 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version-file: lambdas/${{ matrix.path }}/.python-version - name: Install dependencies run: | - # Due to behavior change in pip>=23.1 installing tifffile==0.15.1 - # from thumbnail lambda fails whithout installed wheel. - # See https://github.com/pypa/pip/issues/8559. - python -m pip install wheel - if [ ${{ matrix.path }} == "shared" ] + python -m pip install -e lambdas/shared[tests] + if [ ${{ matrix.path }} == "thumbnail" ] then - python -m pip install -e lambdas/shared - python -m pip install -e lambdas/${{ matrix.path }} + # Due to behavior change in pip>=23.1 installing tifffile==0.15.1 + # from thumbnail lambda fails without installed wheel.
+ # See https://github.com/pypa/pip/issues/8559. + # HACK: Pre-install numpy v1 as a build dependency for tifffile to prevent it from using v2 and failing to build + python -m pip install wheel 'numpy<2' fi - python -m pip install -r lambdas/${{ matrix.path }}/test-requirements.txt # Try to simulate the lambda .zip file: # - Use --no-deps to ensure that second-order dependencies are included in the requirements file # - Remove "tests" directories # - Run "strip" on shared libraries - python -m pip install -t deps --no-deps -r lambdas/${{ matrix.path }}/requirements.txt + python -m pip install -t deps --no-deps -r lambdas/${{ matrix.path }}/requirements.txt lambdas/${{ matrix.path }} find deps -name tests -type d -exec rm -r \{} \+ find deps \( -name '*.so.*' -o -name '*.so' \) -type f -exec strip \{} \+ + + python -m pip install -r lambdas/${{ matrix.path }}/test-requirements.txt - name: Pytest run: | pytest --cov=lambdas lambdas/${{ matrix.path }} - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} LAMBDA: ${{ matrix.path }} with: flags: lambda diff --git a/.github/workflows/test-quilt3-admin-codegen.yaml b/.github/workflows/test-quilt3-admin-codegen.yaml new file mode 100644 index 00000000000..d98719c8223 --- /dev/null +++ b/.github/workflows/test-quilt3-admin-codegen.yaml @@ -0,0 +1,36 @@ +name: Test quilt3.admin code generation + +on: + push: + paths: + - '.github/workflows/test-quilt3-admin-codegen.yaml' + - 'shared/graphql/schema.graphql' + - 'api/python/quilt3-admin/**' + - 'api/python/quilt3/admin/_graphql_client/**' + pull_request: + paths: + - '.github/workflows/test-quilt3-admin-codegen.yaml' + - 'shared/graphql/schema.graphql' + - 'api/python/quilt3-admin/**' + - 'api/python/quilt3/admin/_graphql_client/**' + merge_group: + +jobs: + test-quilt3-admin-codegen: + name: test quilt3.admin generated code is up-to-date + runs-on: ubuntu-latest + defaults: + run: + working-directory: ./api/python/quilt3-admin + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version-file: 'api/python/quilt3-admin/.python-version' + cache: 'pip' + cache-dependency-path: 'api/python/quilt3-admin/requirements.txt' + - run: pip install -r requirements.txt + - run: rm -r ../quilt3/admin/_graphql_client + - run: ariadne-codegen + - name: Check for changes + run: git diff --exit-code diff --git a/.gitignore b/.gitignore index aa9b758d10c..64317deaec1 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ build dist node_modules tmp_* +.venv **/.DS_Store @@ -46,3 +47,6 @@ testdocs/notebooks testdocs/scripts .env + +.venv +.aider* diff --git a/.markdownlint.jsonc b/.markdownlint.jsonc new file mode 100644 index 00000000000..36e6024b7c3 --- /dev/null +++ b/.markdownlint.jsonc @@ -0,0 +1,7 @@ +{ + "default": true, + "no-blanks-blockquote": false, + "no-duplicate-heading": { + "siblings_only": true + } +} diff --git a/.markdownlintignore b/.markdownlintignore new file mode 100644 index 00000000000..6010fbdad16 --- /dev/null +++ b/.markdownlintignore @@ -0,0 +1,10 @@ +# Not ready for lint yet +gendocs +testdocs + +# Autogenerated +docs/api-reference + +.git +catalog/node_modules +venv diff --git a/api/python/.gitattributes b/api/python/.gitattributes new file mode 100644 index 00000000000..aa9cc2f3e99 --- /dev/null +++ b/api/python/.gitattributes @@ -0,0 +1 @@ +quilt3/admin/_graphql_client/** linguist-generated diff --git a/api/python/quilt3-admin/.python-version 
b/api/python/quilt3-admin/.python-version new file mode 100644 index 00000000000..e4fba218358 --- /dev/null +++ b/api/python/quilt3-admin/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/api/python/quilt3-admin/README.md b/api/python/quilt3-admin/README.md new file mode 100644 index 00000000000..9e6c66d280d --- /dev/null +++ b/api/python/quilt3-admin/README.md @@ -0,0 +1,11 @@ +# quilt3.admin GraphQL code generation + +```sh +python -m venv venv +. venv/bin/activate +python -m pip install -r requirements.txt +ariadne-codegen +``` + +This will generate the GraphQL client in `api/python/quilt3/admin/_graphql_client/` using +GraphQL queries from `queries.graphql`. diff --git a/api/python/quilt3-admin/base_client.py b/api/python/quilt3-admin/base_client.py new file mode 100644 index 00000000000..a3346665914 --- /dev/null +++ b/api/python/quilt3-admin/base_client.py @@ -0,0 +1,211 @@ +# This is +# https://github.com/mirumee/ariadne-codegen/blob/5bfd63c5e7e3a8cc5293eb94deee638b7adab98d/ariadne_codegen/client_generators/dependencies/base_client.py +# modified to use our requests session instead of httpx. +# pylint: disable=relative-beyond-top-level +import json +from typing import IO, Any, Dict, List, Optional, Tuple, TypeVar, cast + +import requests +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from quilt3 import session + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) + +Self = TypeVar("Self", bound="BaseClient") + + +class BaseClient: + def __init__( + self, + ) -> None: + self.url = session.get_registry_url() + "/graphql" + + self.http_client = session.get_session() + + def __enter__(self: Self) -> Self: + return self + + def __exit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + self.http_client.close() + + def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> requests.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: requests.Response) -> Dict[str, Any]: + if not 200 <= response.status_code < 300: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return
self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> requests.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return self.http_client.post(url=self.url, data=data, files=files, **kwargs) + + def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> requests.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return self.http_client.post( + url=self.url, + data=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) diff --git a/api/python/quilt3-admin/exceptions.py b/api/python/quilt3-admin/exceptions.py new file mode 100644 index 00000000000..0876664a226 --- /dev/null +++ b/api/python/quilt3-admin/exceptions.py @@ -0,0 +1,87 @@ +# This is +# https://github.com/mirumee/ariadne-codegen/blob/5bfd63c5e7e3a8cc5293eb94deee638b7adab98d/ariadne_codegen/client_generators/dependencies/exceptions.py +# modified to use our requests instead of httpx. 
+# pylint: disable=super-init-not-called +from typing import Any, Dict, List, Optional, Union + +import requests + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: requests.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: requests.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." 
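The exception classes above ship into the generated package via `files_to_include` (see `pyproject.toml` below), and `BaseClient.get_data()` is what raises them: `GraphQLClientHttpError` for non-2xx responses, `GraphQLClientInvalidResponseError` for malformed payloads, and `GraphQLClientGraphQLMultiError` when the response carries an `errors` list. A minimal sketch of how a caller might handle them, assuming a logged-in `quilt3` session (the client resolves `<registry>/graphql` from it); the query string and prints are illustrative only:

```python
# Sketch: error handling around the generated GraphQL client.
from quilt3.admin._graphql_client import Client
from quilt3.admin._graphql_client.exceptions import (
    GraphQLClientGraphQLMultiError,
    GraphQLClientHttpError,
    GraphQLClientInvalidResponseError,
)

client = Client()  # uses the quilt3 session; requires a prior quilt3.login()
try:
    response = client.execute(query="query rolesList { roles { __typename } }")
    data = client.get_data(response)  # raises one of the exceptions below
except GraphQLClientHttpError as exc:
    print(f"HTTP {exc.status_code}")  # raw requests.Response on exc.response
except GraphQLClientInvalidResponseError:
    print("response was not a well-formed GraphQL payload")
except GraphQLClientGraphQLMultiError as exc:
    # one GraphQLClientGraphQLError per entry in the response's "errors" list
    for err in exc.errors:
        print(f"{err.message} (path={err.path})")
else:
    print(data["roles"])
```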
diff --git a/api/python/quilt3-admin/pyproject.toml b/api/python/quilt3-admin/pyproject.toml new file mode 100644 index 00000000000..a43cf7442ad --- /dev/null +++ b/api/python/quilt3-admin/pyproject.toml @@ -0,0 +1,21 @@ +[tool.ariadne-codegen] +schema_path = "../../../shared/graphql/schema.graphql" +queries_path = "queries.graphql" +target_package_path = "../quilt3/admin/" +target_package_name = "_graphql_client" +files_to_include = [ + "exceptions.py", +] +async_client = false +base_client_file_path = "base_client.py" +base_client_name = "BaseClient" +include_all_inputs = false +include_all_enums = false +plugins = [ + "ariadne_codegen.contrib.shorter_results.ShorterResultsPlugin", + "ariadne_codegen.contrib.shorter_results.ShorterResultsPlugin", + "ariadne_codegen.contrib.shorter_results.ShorterResultsPlugin", +] + +[tool.ariadne-codegen.scalars.Datetime] +type = "datetime.datetime" diff --git a/api/python/quilt3-admin/queries.graphql b/api/python/quilt3-admin/queries.graphql new file mode 100644 index 00000000000..0fb50433554 --- /dev/null +++ b/api/python/quilt3-admin/queries.graphql @@ -0,0 +1,243 @@ +fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn +} +fragment ManagedRoleSelection on ManagedRole { + id + name + arn +} +fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection +} +fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } +} +fragment UserMutationSelection on UserResult { + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection +} +fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } +} +fragment OperationErrorSelection on OperationError { + message + name + context +} +fragment SsoConfigSelection on SsoConfig { + text + timestamp + uploader { + ...UserSelection + } +} + +query rolesList { + roles { + ...RoleSelection + } +} + +query usersGet($name: String!) { + admin { + user { + get(name: $name) { + ...UserSelection + } + } + } +} + +query usersList { + admin { + user { + list { + ...UserSelection + } + } + } +} + +mutation usersCreate($input: UserInput!) { + admin { + user { + create(input: $input) { + ...UserMutationSelection + } + } + } +} + +mutation usersDelete($name: String!) { + admin { + user { + mutate(name: $name) { + delete { + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + } +} + +mutation usersSetEmail($email: String!, $name: String!) { + admin { + user { + mutate(name: $name) { + setEmail(email: $email) { + ...UserMutationSelection + } + } + } + } +} + +mutation usersSetAdmin($name: String!, $admin: Boolean!) { + admin { + user { + mutate(name: $name) { + setAdmin(admin: $admin) { + ...UserMutationSelection + } + } + } + } +} + +mutation usersSetActive($active: Boolean!, $name: String!) { + admin { + user { + mutate(name: $name) { + setActive(active: $active) { + ...UserMutationSelection + } + } + } + } +} + +mutation usersResetPassword($name: String!) { + admin { + user { + mutate(name: $name) { + resetPassword { + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + } +} + +mutation usersSetRole($name: String!, $role: String!, $extraRoles: [String!], $append: Boolean!) 
{ + admin { + user { + mutate(name: $name) { + setRole(role: $role, extraRoles: $extraRoles, append: $append) { + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + } +} + +mutation usersAddRoles($name: String!, $roles: [String!]!) { + admin { + user { + mutate(name: $name) { + addRoles(roles: $roles) { + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + } +} + +mutation usersRemoveRoles($name: String!, $roles: [String!]!, $fallback: String) { + admin { + user { + mutate(name: $name) { + removeRoles(roles: $roles, fallback: $fallback) { + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + } +} + +query ssoConfigGet { + admin { + ssoConfig { + ...SsoConfigSelection + } + } +} + +mutation ssoConfigSet($config: String) { + admin { + setSsoConfig(config: $config) { + ...SsoConfigSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + } +} + +query bucketTabulatorTablesList($name: String!) { + bucketConfig(name: $name) { + tabulatorTables { + name + config + } + } +} + +mutation bucketTabulatorTableSet($bucketName: String!, $tableName: String!, $config: String) { + admin { + bucketSetTabulatorTable(bucketName: $bucketName, tableName: $tableName, config: $config) { + __typename + ...InvalidInputSelection + ...OperationErrorSelection + } + } +} + +mutation bucketTabulatorTableRename($bucketName: String!, $tableName: String!, $newTableName: String!) { + admin { + bucketRenameTabulatorTable(bucketName: $bucketName, tableName: $tableName, newTableName: $newTableName) { + __typename + ...InvalidInputSelection + ...OperationErrorSelection + } + } +} diff --git a/api/python/quilt3-admin/requirements.in b/api/python/quilt3-admin/requirements.in new file mode 100644 index 00000000000..2f122266b8e --- /dev/null +++ b/api/python/quilt3-admin/requirements.in @@ -0,0 +1 @@ +ariadne-codegen diff --git a/api/python/quilt3-admin/requirements.txt b/api/python/quilt3-admin/requirements.txt new file mode 100644 index 00000000000..0a5c604f8e3 --- /dev/null +++ b/api/python/quilt3-admin/requirements.txt @@ -0,0 +1,62 @@ +# +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: +# +# pip-compile requirements.in +# +annotated-types==0.7.0 + # via pydantic +anyio==4.4.0 + # via httpx +ariadne-codegen==0.13.0 + # via -r requirements.in +autoflake==2.3.1 + # via ariadne-codegen +black==24.4.2 + # via ariadne-codegen +certifi==2024.7.4 + # via + # httpcore + # httpx +click==8.1.7 + # via + # ariadne-codegen + # black +graphql-core==3.2.3 + # via ariadne-codegen +h11==0.14.0 + # via httpcore +httpcore==1.0.5 + # via httpx +httpx==0.27.0 + # via ariadne-codegen +idna==3.7 + # via + # anyio + # httpx +isort==5.13.2 + # via ariadne-codegen +mypy-extensions==1.0.0 + # via black +packaging==24.1 + # via black +pathspec==0.12.1 + # via black +platformdirs==4.2.2 + # via black +pydantic==2.7.3 + # via ariadne-codegen +pydantic-core==2.18.4 + # via pydantic +pyflakes==3.2.0 + # via autoflake +sniffio==1.3.1 + # via + # anyio + # httpx +toml==0.10.2 + # via ariadne-codegen +typing-extensions==4.12.2 + # via + # pydantic + # pydantic-core diff --git a/api/python/quilt3/VERSION b/api/python/quilt3/VERSION index 8a30e8f94a3..f3b5af39e43 100644 --- a/api/python/quilt3/VERSION +++ b/api/python/quilt3/VERSION @@ -1 +1 @@ -5.4.0 +6.1.1 diff --git a/api/python/quilt3/__init__.py b/api/python/quilt3/__init__.py index 43dd88fda79..5fb4a2ca6e5 100644 --- 
a/api/python/quilt3/__init__.py +++ b/api/python/quilt3/__init__.py @@ -22,6 +22,6 @@ from .bucket import Bucket from .imports import start_data_package_loader from .packages import Package -from .session import logged_in, login, logout +from .session import get_boto3_session, logged_in, login, logout start_data_package_loader() diff --git a/api/python/quilt3/admin.py b/api/python/quilt3/admin.py deleted file mode 100644 index 6bdae483c74..00000000000 --- a/api/python/quilt3/admin.py +++ /dev/null @@ -1,56 +0,0 @@ -"""APIs for Quilt administrators. 'Registry' refers to Quilt stack backend services, including identity management.""" -import typing as T - -from .session import get_registry_url, get_session - - -def create_user(*, username: str, email: str): - """ - Create a new user in the registry. - - Required parameters: - username (str): Username of user to create. - email (str): Email of user to create. - """ - session = get_session() - response = session.post( - get_registry_url() + "/api/users/create", - json={ - "username": username, - "email": email, - }, - ) - - -def delete_user(*, username: str): - """ - Delete user from the registry. - - Required parameters: - username (str): Username of user to delete. - """ - session = get_session() - response = session.post( - get_registry_url() + "/api/users/delete", - json={ - "username": username, - }, - ) - - -def set_role(*, username: str, role_name: T.Optional[str]): - """ - Set the named Quilt role for a user. - - Required parameters: - username (str): Username of user to update. - role_name (str): Quilt role name assign to the user. Set a `None` value to unassign the role. - """ - session = get_session() - session.post( - get_registry_url() + "/api/users/set_role", - json={ - "username": username, - "role": role_name or "", - }, - ) diff --git a/api/python/quilt3/admin/__init__.py b/api/python/quilt3/admin/__init__.py new file mode 100644 index 00000000000..daf17ac2fdf --- /dev/null +++ b/api/python/quilt3/admin/__init__.py @@ -0,0 +1,13 @@ +""" +APIs for Quilt administrators. 'Registry' refers to Quilt stack backend services, including identity management. +""" + +# This wraps code generated by ariadne-codegen to provide a more user-friendly API. + +from . 
import roles, sso_config, tabulator, users +from .exceptions import ( + BucketNotFoundError, + Quilt3AdminError, + UserNotFoundError, +) +from .types import ManagedRole, SSOConfig, TabulatorTable, UnmanagedRole, User diff --git a/api/python/quilt3/admin/_graphql_client/__init__.py b/api/python/quilt3/admin/_graphql_client/__init__.py new file mode 100644 index 00000000000..717f43537df --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/__init__.py @@ -0,0 +1,253 @@ +# Generated by ariadne-codegen + +from .base_client import BaseClient +from .base_model import BaseModel, Upload +from .bucket_tabulator_table_rename import ( + BucketTabulatorTableRename, + BucketTabulatorTableRenameAdmin, + BucketTabulatorTableRenameAdminBucketRenameTabulatorTableBucketConfig, + BucketTabulatorTableRenameAdminBucketRenameTabulatorTableInvalidInput, + BucketTabulatorTableRenameAdminBucketRenameTabulatorTableOperationError, +) +from .bucket_tabulator_table_set import ( + BucketTabulatorTableSet, + BucketTabulatorTableSetAdmin, + BucketTabulatorTableSetAdminBucketSetTabulatorTableBucketConfig, + BucketTabulatorTableSetAdminBucketSetTabulatorTableInvalidInput, + BucketTabulatorTableSetAdminBucketSetTabulatorTableOperationError, +) +from .bucket_tabulator_tables_list import ( + BucketTabulatorTablesList, + BucketTabulatorTablesListBucketConfig, + BucketTabulatorTablesListBucketConfigTabulatorTables, +) +from .client import Client +from .fragments import ( + InvalidInputSelection, + InvalidInputSelectionErrors, + ManagedRoleSelection, + OperationErrorSelection, + SsoConfigSelection, + SsoConfigSelectionUploader, + UnmanagedRoleSelection, + UserSelection, + UserSelectionExtraRolesManagedRole, + UserSelectionExtraRolesUnmanagedRole, + UserSelectionRoleManagedRole, + UserSelectionRoleUnmanagedRole, +) +from .input_types import UserInput +from .roles_list import ( + RolesList, + RolesListRolesManagedRole, + RolesListRolesUnmanagedRole, +) +from .sso_config_get import SsoConfigGet, SsoConfigGetAdmin, SsoConfigGetAdminSsoConfig +from .sso_config_set import ( + SsoConfigSet, + SsoConfigSetAdmin, + SsoConfigSetAdminSetSsoConfigInvalidInput, + SsoConfigSetAdminSetSsoConfigOperationError, + SsoConfigSetAdminSetSsoConfigSsoConfig, +) +from .users_add_roles import ( + UsersAddRoles, + UsersAddRolesAdmin, + UsersAddRolesAdminUser, + UsersAddRolesAdminUserMutate, + UsersAddRolesAdminUserMutateAddRolesInvalidInput, + UsersAddRolesAdminUserMutateAddRolesOperationError, + UsersAddRolesAdminUserMutateAddRolesUser, +) +from .users_create import ( + UsersCreate, + UsersCreateAdmin, + UsersCreateAdminUser, + UsersCreateAdminUserCreateInvalidInput, + UsersCreateAdminUserCreateOperationError, + UsersCreateAdminUserCreateUser, +) +from .users_delete import ( + UsersDelete, + UsersDeleteAdmin, + UsersDeleteAdminUser, + UsersDeleteAdminUserMutate, + UsersDeleteAdminUserMutateDeleteInvalidInput, + UsersDeleteAdminUserMutateDeleteOk, + UsersDeleteAdminUserMutateDeleteOperationError, +) +from .users_get import UsersGet, UsersGetAdmin, UsersGetAdminUser, UsersGetAdminUserGet +from .users_list import ( + UsersList, + UsersListAdmin, + UsersListAdminUser, + UsersListAdminUserList, +) +from .users_remove_roles import ( + UsersRemoveRoles, + UsersRemoveRolesAdmin, + UsersRemoveRolesAdminUser, + UsersRemoveRolesAdminUserMutate, + UsersRemoveRolesAdminUserMutateRemoveRolesInvalidInput, + UsersRemoveRolesAdminUserMutateRemoveRolesOperationError, + UsersRemoveRolesAdminUserMutateRemoveRolesUser, +) +from .users_reset_password import ( + 
UsersResetPassword, + UsersResetPasswordAdmin, + UsersResetPasswordAdminUser, + UsersResetPasswordAdminUserMutate, + UsersResetPasswordAdminUserMutateResetPasswordInvalidInput, + UsersResetPasswordAdminUserMutateResetPasswordOk, + UsersResetPasswordAdminUserMutateResetPasswordOperationError, +) +from .users_set_active import ( + UsersSetActive, + UsersSetActiveAdmin, + UsersSetActiveAdminUser, + UsersSetActiveAdminUserMutate, + UsersSetActiveAdminUserMutateSetActiveInvalidInput, + UsersSetActiveAdminUserMutateSetActiveOperationError, + UsersSetActiveAdminUserMutateSetActiveUser, +) +from .users_set_admin import ( + UsersSetAdmin, + UsersSetAdminAdmin, + UsersSetAdminAdminUser, + UsersSetAdminAdminUserMutate, + UsersSetAdminAdminUserMutateSetAdminInvalidInput, + UsersSetAdminAdminUserMutateSetAdminOperationError, + UsersSetAdminAdminUserMutateSetAdminUser, +) +from .users_set_email import ( + UsersSetEmail, + UsersSetEmailAdmin, + UsersSetEmailAdminUser, + UsersSetEmailAdminUserMutate, + UsersSetEmailAdminUserMutateSetEmailInvalidInput, + UsersSetEmailAdminUserMutateSetEmailOperationError, + UsersSetEmailAdminUserMutateSetEmailUser, +) +from .users_set_role import ( + UsersSetRole, + UsersSetRoleAdmin, + UsersSetRoleAdminUser, + UsersSetRoleAdminUserMutate, + UsersSetRoleAdminUserMutateSetRoleInvalidInput, + UsersSetRoleAdminUserMutateSetRoleOperationError, + UsersSetRoleAdminUserMutateSetRoleUser, +) + +__all__ = [ + "BaseClient", + "BaseModel", + "BucketTabulatorTableRename", + "BucketTabulatorTableRenameAdmin", + "BucketTabulatorTableRenameAdminBucketRenameTabulatorTableBucketConfig", + "BucketTabulatorTableRenameAdminBucketRenameTabulatorTableInvalidInput", + "BucketTabulatorTableRenameAdminBucketRenameTabulatorTableOperationError", + "BucketTabulatorTableSet", + "BucketTabulatorTableSetAdmin", + "BucketTabulatorTableSetAdminBucketSetTabulatorTableBucketConfig", + "BucketTabulatorTableSetAdminBucketSetTabulatorTableInvalidInput", + "BucketTabulatorTableSetAdminBucketSetTabulatorTableOperationError", + "BucketTabulatorTablesList", + "BucketTabulatorTablesListBucketConfig", + "BucketTabulatorTablesListBucketConfigTabulatorTables", + "Client", + "InvalidInputSelection", + "InvalidInputSelectionErrors", + "ManagedRoleSelection", + "OperationErrorSelection", + "RolesList", + "RolesListRolesManagedRole", + "RolesListRolesUnmanagedRole", + "SsoConfigGet", + "SsoConfigGetAdmin", + "SsoConfigGetAdminSsoConfig", + "SsoConfigSelection", + "SsoConfigSelectionUploader", + "SsoConfigSet", + "SsoConfigSetAdmin", + "SsoConfigSetAdminSetSsoConfigInvalidInput", + "SsoConfigSetAdminSetSsoConfigOperationError", + "SsoConfigSetAdminSetSsoConfigSsoConfig", + "UnmanagedRoleSelection", + "Upload", + "UserInput", + "UserSelection", + "UserSelectionExtraRolesManagedRole", + "UserSelectionExtraRolesUnmanagedRole", + "UserSelectionRoleManagedRole", + "UserSelectionRoleUnmanagedRole", + "UsersAddRoles", + "UsersAddRolesAdmin", + "UsersAddRolesAdminUser", + "UsersAddRolesAdminUserMutate", + "UsersAddRolesAdminUserMutateAddRolesInvalidInput", + "UsersAddRolesAdminUserMutateAddRolesOperationError", + "UsersAddRolesAdminUserMutateAddRolesUser", + "UsersCreate", + "UsersCreateAdmin", + "UsersCreateAdminUser", + "UsersCreateAdminUserCreateInvalidInput", + "UsersCreateAdminUserCreateOperationError", + "UsersCreateAdminUserCreateUser", + "UsersDelete", + "UsersDeleteAdmin", + "UsersDeleteAdminUser", + "UsersDeleteAdminUserMutate", + "UsersDeleteAdminUserMutateDeleteInvalidInput", + "UsersDeleteAdminUserMutateDeleteOk", + 
"UsersDeleteAdminUserMutateDeleteOperationError", + "UsersGet", + "UsersGetAdmin", + "UsersGetAdminUser", + "UsersGetAdminUserGet", + "UsersList", + "UsersListAdmin", + "UsersListAdminUser", + "UsersListAdminUserList", + "UsersRemoveRoles", + "UsersRemoveRolesAdmin", + "UsersRemoveRolesAdminUser", + "UsersRemoveRolesAdminUserMutate", + "UsersRemoveRolesAdminUserMutateRemoveRolesInvalidInput", + "UsersRemoveRolesAdminUserMutateRemoveRolesOperationError", + "UsersRemoveRolesAdminUserMutateRemoveRolesUser", + "UsersResetPassword", + "UsersResetPasswordAdmin", + "UsersResetPasswordAdminUser", + "UsersResetPasswordAdminUserMutate", + "UsersResetPasswordAdminUserMutateResetPasswordInvalidInput", + "UsersResetPasswordAdminUserMutateResetPasswordOk", + "UsersResetPasswordAdminUserMutateResetPasswordOperationError", + "UsersSetActive", + "UsersSetActiveAdmin", + "UsersSetActiveAdminUser", + "UsersSetActiveAdminUserMutate", + "UsersSetActiveAdminUserMutateSetActiveInvalidInput", + "UsersSetActiveAdminUserMutateSetActiveOperationError", + "UsersSetActiveAdminUserMutateSetActiveUser", + "UsersSetAdmin", + "UsersSetAdminAdmin", + "UsersSetAdminAdminUser", + "UsersSetAdminAdminUserMutate", + "UsersSetAdminAdminUserMutateSetAdminInvalidInput", + "UsersSetAdminAdminUserMutateSetAdminOperationError", + "UsersSetAdminAdminUserMutateSetAdminUser", + "UsersSetEmail", + "UsersSetEmailAdmin", + "UsersSetEmailAdminUser", + "UsersSetEmailAdminUserMutate", + "UsersSetEmailAdminUserMutateSetEmailInvalidInput", + "UsersSetEmailAdminUserMutateSetEmailOperationError", + "UsersSetEmailAdminUserMutateSetEmailUser", + "UsersSetRole", + "UsersSetRoleAdmin", + "UsersSetRoleAdminUser", + "UsersSetRoleAdminUserMutate", + "UsersSetRoleAdminUserMutateSetRoleInvalidInput", + "UsersSetRoleAdminUserMutateSetRoleOperationError", + "UsersSetRoleAdminUserMutateSetRoleUser", +] diff --git a/api/python/quilt3/admin/_graphql_client/base_client.py b/api/python/quilt3/admin/_graphql_client/base_client.py new file mode 100644 index 00000000000..298d14e5a8a --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/base_client.py @@ -0,0 +1,213 @@ +# Generated by ariadne-codegen + +# This is +# https://github.com/mirumee/ariadne-codegen/blob/5bfd63c5e7e3a8cc5293eb94deee638b7adab98d/ariadne_codegen/client_generators/dependencies/base_client.py +# modified to use our requests session instead of httpx. 
+# pylint: disable=relative-beyond-top-level +import json +from typing import IO, Any, Dict, List, Optional, Tuple, TypeVar, cast + +import requests +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from quilt3 import session + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) + +Self = TypeVar("Self", bound="BaseClient") + + +class BaseClient: + def __init__( + self, + ) -> None: + self.url = session.get_registry_url() + "/graphql" + + self.http_client = session.get_session() + + def __enter__(self: Self) -> Self: + return self + + def __exit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + self.http_client.close() + + def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> requests.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: requests.Response) -> Dict[str, Any]: + if not 200 <= response.status_code < 300: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = 
value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> requests.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return self.http_client.post(url=self.url, data=data, files=files, **kwargs) + + def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> requests.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return self.http_client.post( + url=self.url, + data=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) diff --git a/api/python/quilt3/admin/_graphql_client/base_model.py b/api/python/quilt3/admin/_graphql_client/base_model.py new file mode 100644 index 00000000000..76b84873a6f --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/base_model.py @@ -0,0 +1,29 @@ +# Generated by ariadne-codegen + +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/api/python/quilt3/admin/_graphql_client/bucket_tabulator_table_rename.py b/api/python/quilt3/admin/_graphql_client/bucket_tabulator_table_rename.py new file mode 100644 index 00000000000..9eb7f822c0c --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/bucket_tabulator_table_rename.py @@ -0,0 +1,41 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection + + +class BucketTabulatorTableRename(BaseModel): + admin: "BucketTabulatorTableRenameAdmin" + + +class BucketTabulatorTableRenameAdmin(BaseModel): + bucket_rename_tabulator_table: Union[ + "BucketTabulatorTableRenameAdminBucketRenameTabulatorTableBucketConfig", + "BucketTabulatorTableRenameAdminBucketRenameTabulatorTableInvalidInput", + "BucketTabulatorTableRenameAdminBucketRenameTabulatorTableOperationError", + ] = Field(alias="bucketRenameTabulatorTable", 
discriminator="typename__") + + +class BucketTabulatorTableRenameAdminBucketRenameTabulatorTableBucketConfig(BaseModel): + typename__: Literal["BucketConfig"] = Field(alias="__typename") + + +class BucketTabulatorTableRenameAdminBucketRenameTabulatorTableInvalidInput( + InvalidInputSelection +): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class BucketTabulatorTableRenameAdminBucketRenameTabulatorTableOperationError( + OperationErrorSelection +): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +BucketTabulatorTableRename.model_rebuild() +BucketTabulatorTableRenameAdmin.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/bucket_tabulator_table_set.py b/api/python/quilt3/admin/_graphql_client/bucket_tabulator_table_set.py new file mode 100644 index 00000000000..6173234db53 --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/bucket_tabulator_table_set.py @@ -0,0 +1,41 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection + + +class BucketTabulatorTableSet(BaseModel): + admin: "BucketTabulatorTableSetAdmin" + + +class BucketTabulatorTableSetAdmin(BaseModel): + bucket_set_tabulator_table: Union[ + "BucketTabulatorTableSetAdminBucketSetTabulatorTableBucketConfig", + "BucketTabulatorTableSetAdminBucketSetTabulatorTableInvalidInput", + "BucketTabulatorTableSetAdminBucketSetTabulatorTableOperationError", + ] = Field(alias="bucketSetTabulatorTable", discriminator="typename__") + + +class BucketTabulatorTableSetAdminBucketSetTabulatorTableBucketConfig(BaseModel): + typename__: Literal["BucketConfig"] = Field(alias="__typename") + + +class BucketTabulatorTableSetAdminBucketSetTabulatorTableInvalidInput( + InvalidInputSelection +): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class BucketTabulatorTableSetAdminBucketSetTabulatorTableOperationError( + OperationErrorSelection +): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +BucketTabulatorTableSet.model_rebuild() +BucketTabulatorTableSetAdmin.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/bucket_tabulator_tables_list.py b/api/python/quilt3/admin/_graphql_client/bucket_tabulator_tables_list.py new file mode 100644 index 00000000000..a91d4dc9b8e --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/bucket_tabulator_tables_list.py @@ -0,0 +1,29 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import List, Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class BucketTabulatorTablesList(BaseModel): + bucket_config: Optional["BucketTabulatorTablesListBucketConfig"] = Field( + alias="bucketConfig" + ) + + +class BucketTabulatorTablesListBucketConfig(BaseModel): + tabulator_tables: List["BucketTabulatorTablesListBucketConfigTabulatorTables"] = ( + Field(alias="tabulatorTables") + ) + + +class BucketTabulatorTablesListBucketConfigTabulatorTables(BaseModel): + name: str + config: str + + +BucketTabulatorTablesList.model_rebuild() +BucketTabulatorTablesListBucketConfig.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/client.py b/api/python/quilt3/admin/_graphql_client/client.py new file mode 100644 index 00000000000..e4ef1dac7e7 --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/client.py @@ -0,0 +1,1166 @@ +# Generated by ariadne-codegen +# 
Source: queries.graphql + +from typing import Any, Dict, List, Optional, Union + +from .base_client import BaseClient +from .base_model import UNSET, UnsetType +from .bucket_tabulator_table_rename import ( + BucketTabulatorTableRename, + BucketTabulatorTableRenameAdminBucketRenameTabulatorTableBucketConfig, + BucketTabulatorTableRenameAdminBucketRenameTabulatorTableInvalidInput, + BucketTabulatorTableRenameAdminBucketRenameTabulatorTableOperationError, +) +from .bucket_tabulator_table_set import ( + BucketTabulatorTableSet, + BucketTabulatorTableSetAdminBucketSetTabulatorTableBucketConfig, + BucketTabulatorTableSetAdminBucketSetTabulatorTableInvalidInput, + BucketTabulatorTableSetAdminBucketSetTabulatorTableOperationError, +) +from .bucket_tabulator_tables_list import ( + BucketTabulatorTablesList, + BucketTabulatorTablesListBucketConfig, +) +from .input_types import UserInput +from .roles_list import ( + RolesList, + RolesListRolesManagedRole, + RolesListRolesUnmanagedRole, +) +from .sso_config_get import SsoConfigGet, SsoConfigGetAdminSsoConfig +from .sso_config_set import ( + SsoConfigSet, + SsoConfigSetAdminSetSsoConfigInvalidInput, + SsoConfigSetAdminSetSsoConfigOperationError, + SsoConfigSetAdminSetSsoConfigSsoConfig, +) +from .users_add_roles import UsersAddRoles, UsersAddRolesAdminUserMutate +from .users_create import ( + UsersCreate, + UsersCreateAdminUserCreateInvalidInput, + UsersCreateAdminUserCreateOperationError, + UsersCreateAdminUserCreateUser, +) +from .users_delete import UsersDelete, UsersDeleteAdminUserMutate +from .users_get import UsersGet, UsersGetAdminUserGet +from .users_list import UsersList, UsersListAdminUserList +from .users_remove_roles import UsersRemoveRoles, UsersRemoveRolesAdminUserMutate +from .users_reset_password import UsersResetPassword, UsersResetPasswordAdminUserMutate +from .users_set_active import UsersSetActive, UsersSetActiveAdminUserMutate +from .users_set_admin import UsersSetAdmin, UsersSetAdminAdminUserMutate +from .users_set_email import UsersSetEmail, UsersSetEmailAdminUserMutate +from .users_set_role import UsersSetRole, UsersSetRoleAdminUserMutate + + +def gql(q: str) -> str: + return q + + +class Client(BaseClient): + def roles_list( + self, **kwargs: Any + ) -> List[Union[RolesListRolesUnmanagedRole, RolesListRolesManagedRole]]: + query = gql( + """ + query rolesList { + roles { + ...RoleSelection + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + """ + ) + variables: Dict[str, object] = {} + response = self.execute( + query=query, operation_name="rolesList", variables=variables, **kwargs + ) + data = self.get_data(response) + return RolesList.model_validate(data).roles + + def users_get(self, name: str, **kwargs: Any) -> Optional[UsersGetAdminUserGet]: + query = gql( + """ + query usersGet($name: String!) 
{ + admin { + user { + get(name: $name) { + ...UserSelection + } + } + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = {"name": name} + response = self.execute( + query=query, operation_name="usersGet", variables=variables, **kwargs + ) + data = self.get_data(response) + return UsersGet.model_validate(data).admin.user.get + + def users_list(self, **kwargs: Any) -> List[UsersListAdminUserList]: + query = gql( + """ + query usersList { + admin { + user { + list { + ...UserSelection + } + } + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = {} + response = self.execute( + query=query, operation_name="usersList", variables=variables, **kwargs + ) + data = self.get_data(response) + return UsersList.model_validate(data).admin.user.list + + def users_create(self, input: UserInput, **kwargs: Any) -> Union[ + UsersCreateAdminUserCreateUser, + UsersCreateAdminUserCreateInvalidInput, + UsersCreateAdminUserCreateOperationError, + ]: + query = gql( + """ + mutation usersCreate($input: UserInput!) { + admin { + user { + create(input: $input) { + __typename + ...UserMutationSelection + } + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserMutationSelection on UserResult { + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = {"input": input} + response = self.execute( + query=query, operation_name="usersCreate", variables=variables, **kwargs + ) + data = self.get_data(response) + return UsersCreate.model_validate(data).admin.user.create + + def users_delete( + self, name: str, **kwargs: Any + ) -> Optional[UsersDeleteAdminUserMutate]: + query = gql( + """ + mutation usersDelete($name: String!) 
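+ # delete resolves to Ok, InvalidInput, or OperationError.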
{ + admin { + user { + mutate(name: $name) { + delete { + __typename + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + """ + ) + variables: Dict[str, object] = {"name": name} + response = self.execute( + query=query, operation_name="usersDelete", variables=variables, **kwargs + ) + data = self.get_data(response) + return UsersDelete.model_validate(data).admin.user.mutate + + def users_set_email( + self, email: str, name: str, **kwargs: Any + ) -> Optional[UsersSetEmailAdminUserMutate]: + query = gql( + """ + mutation usersSetEmail($email: String!, $name: String!) { + admin { + user { + mutate(name: $name) { + setEmail(email: $email) { + __typename + ...UserMutationSelection + } + } + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserMutationSelection on UserResult { + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = {"email": email, "name": name} + response = self.execute( + query=query, operation_name="usersSetEmail", variables=variables, **kwargs + ) + data = self.get_data(response) + return UsersSetEmail.model_validate(data).admin.user.mutate + + def users_set_admin( + self, name: str, admin: bool, **kwargs: Any + ) -> Optional[UsersSetAdminAdminUserMutate]: + query = gql( + """ + mutation usersSetAdmin($name: String!, $admin: Boolean!) 
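+ # setAdmin resolves to User, InvalidInput, or OperationError.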
{ + admin { + user { + mutate(name: $name) { + setAdmin(admin: $admin) { + __typename + ...UserMutationSelection + } + } + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserMutationSelection on UserResult { + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = {"name": name, "admin": admin} + response = self.execute( + query=query, operation_name="usersSetAdmin", variables=variables, **kwargs + ) + data = self.get_data(response) + return UsersSetAdmin.model_validate(data).admin.user.mutate + + def users_set_active( + self, active: bool, name: str, **kwargs: Any + ) -> Optional[UsersSetActiveAdminUserMutate]: + query = gql( + """ + mutation usersSetActive($active: Boolean!, $name: String!) { + admin { + user { + mutate(name: $name) { + setActive(active: $active) { + __typename + ...UserMutationSelection + } + } + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserMutationSelection on UserResult { + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = {"active": active, "name": name} + response = self.execute( + query=query, operation_name="usersSetActive", variables=variables, **kwargs + ) + data = self.get_data(response) + return UsersSetActive.model_validate(data).admin.user.mutate + + def users_reset_password( + self, name: str, **kwargs: Any + ) -> Optional[UsersResetPasswordAdminUserMutate]: + query = gql( + """ + mutation usersResetPassword($name: String!) 
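+ # resetPassword resolves to Ok, InvalidInput, or OperationError.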
{ + admin { + user { + mutate(name: $name) { + resetPassword { + __typename + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + """ + ) + variables: Dict[str, object] = {"name": name} + response = self.execute( + query=query, + operation_name="usersResetPassword", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return UsersResetPassword.model_validate(data).admin.user.mutate + + def users_set_role( + self, + name: str, + role: str, + append: bool, + extra_roles: Union[Optional[List[str]], UnsetType] = UNSET, + **kwargs: Any + ) -> Optional[UsersSetRoleAdminUserMutate]: + query = gql( + """ + mutation usersSetRole($name: String!, $role: String!, $extraRoles: [String!], $append: Boolean!) { + admin { + user { + mutate(name: $name) { + setRole(role: $role, extraRoles: $extraRoles, append: $append) { + __typename + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = { + "name": name, + "role": role, + "extraRoles": extra_roles, + "append": append, + } + response = self.execute( + query=query, operation_name="usersSetRole", variables=variables, **kwargs + ) + data = self.get_data(response) + return UsersSetRole.model_validate(data).admin.user.mutate + + def users_add_roles( + self, name: str, roles: List[str], **kwargs: Any + ) -> Optional[UsersAddRolesAdminUserMutate]: + query = gql( + """ + mutation usersAddRoles($name: String!, $roles: [String!]!) 
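+ # addRoles resolves to User, InvalidInput, or OperationError.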
{ + admin { + user { + mutate(name: $name) { + addRoles(roles: $roles) { + __typename + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = {"name": name, "roles": roles} + response = self.execute( + query=query, operation_name="usersAddRoles", variables=variables, **kwargs + ) + data = self.get_data(response) + return UsersAddRoles.model_validate(data).admin.user.mutate + + def users_remove_roles( + self, + name: str, + roles: List[str], + fallback: Union[Optional[str], UnsetType] = UNSET, + **kwargs: Any + ) -> Optional[UsersRemoveRolesAdminUserMutate]: + query = gql( + """ + mutation usersRemoveRoles($name: String!, $roles: [String!]!, $fallback: String) { + admin { + user { + mutate(name: $name) { + removeRoles(roles: $roles, fallback: $fallback) { + __typename + ...UserSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = { + "name": name, + "roles": roles, + "fallback": fallback, + } + response = self.execute( + query=query, + operation_name="usersRemoveRoles", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return UsersRemoveRoles.model_validate(data).admin.user.mutate + + def sso_config_get(self, **kwargs: Any) -> Optional[SsoConfigGetAdminSsoConfig]: + query = gql( + """ + query ssoConfigGet { + admin { + ssoConfig { + ...SsoConfigSelection + } + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment SsoConfigSelection on SsoConfig { + text + timestamp + uploader { + ...UserSelection + } + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = {} + response = self.execute( + query=query, operation_name="ssoConfigGet", variables=variables, **kwargs + ) + data = self.get_data(response) + return 
SsoConfigGet.model_validate(data).admin.sso_config + + def sso_config_set( + self, config: Union[Optional[str], UnsetType] = UNSET, **kwargs: Any + ) -> Optional[ + Union[ + SsoConfigSetAdminSetSsoConfigSsoConfig, + SsoConfigSetAdminSetSsoConfigInvalidInput, + SsoConfigSetAdminSetSsoConfigOperationError, + ] + ]: + query = gql( + """ + mutation ssoConfigSet($config: String) { + admin { + setSsoConfig(config: $config) { + __typename + ...SsoConfigSelection + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment ManagedRoleSelection on ManagedRole { + id + name + arn + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + + fragment RoleSelection on Role { + __typename + ...UnmanagedRoleSelection + ...ManagedRoleSelection + } + + fragment SsoConfigSelection on SsoConfig { + text + timestamp + uploader { + ...UserSelection + } + } + + fragment UnmanagedRoleSelection on UnmanagedRole { + id + name + arn + } + + fragment UserSelection on User { + name + email + dateJoined + lastLogin + isActive + isAdmin + isSsoOnly + isService + role { + ...RoleSelection + } + extraRoles { + ...RoleSelection + } + } + """ + ) + variables: Dict[str, object] = {"config": config} + response = self.execute( + query=query, operation_name="ssoConfigSet", variables=variables, **kwargs + ) + data = self.get_data(response) + return SsoConfigSet.model_validate(data).admin.set_sso_config + + def bucket_tabulator_tables_list( + self, name: str, **kwargs: Any + ) -> Optional[BucketTabulatorTablesListBucketConfig]: + query = gql( + """ + query bucketTabulatorTablesList($name: String!) { + bucketConfig(name: $name) { + tabulatorTables { + name + config + } + } + } + """ + ) + variables: Dict[str, object] = {"name": name} + response = self.execute( + query=query, + operation_name="bucketTabulatorTablesList", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return BucketTabulatorTablesList.model_validate(data).bucket_config + + def bucket_tabulator_table_set( + self, + bucket_name: str, + table_name: str, + config: Union[Optional[str], UnsetType] = UNSET, + **kwargs: Any + ) -> Union[ + BucketTabulatorTableSetAdminBucketSetTabulatorTableBucketConfig, + BucketTabulatorTableSetAdminBucketSetTabulatorTableInvalidInput, + BucketTabulatorTableSetAdminBucketSetTabulatorTableOperationError, + ]: + query = gql( + """ + mutation bucketTabulatorTableSet($bucketName: String!, $tableName: String!, $config: String) { + admin { + bucketSetTabulatorTable( + bucketName: $bucketName + tableName: $tableName + config: $config + ) { + __typename + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + """ + ) + variables: Dict[str, object] = { + "bucketName": bucket_name, + "tableName": table_name, + "config": config, + } + response = self.execute( + query=query, + operation_name="bucketTabulatorTableSet", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return BucketTabulatorTableSet.model_validate( + data + ).admin.bucket_set_tabulator_table + + def bucket_tabulator_table_rename( + self, bucket_name: str, table_name: str, new_table_name: str, **kwargs: Any + ) -> Union[ + 
BucketTabulatorTableRenameAdminBucketRenameTabulatorTableBucketConfig, + BucketTabulatorTableRenameAdminBucketRenameTabulatorTableInvalidInput, + BucketTabulatorTableRenameAdminBucketRenameTabulatorTableOperationError, + ]: + query = gql( + """ + mutation bucketTabulatorTableRename($bucketName: String!, $tableName: String!, $newTableName: String!) { + admin { + bucketRenameTabulatorTable( + bucketName: $bucketName + tableName: $tableName + newTableName: $newTableName + ) { + __typename + ...InvalidInputSelection + ...OperationErrorSelection + } + } + } + + fragment InvalidInputSelection on InvalidInput { + errors { + path + message + name + context + } + } + + fragment OperationErrorSelection on OperationError { + message + name + context + } + """ + ) + variables: Dict[str, object] = { + "bucketName": bucket_name, + "tableName": table_name, + "newTableName": new_table_name, + } + response = self.execute( + query=query, + operation_name="bucketTabulatorTableRename", + variables=variables, + **kwargs + ) + data = self.get_data(response) + return BucketTabulatorTableRename.model_validate( + data + ).admin.bucket_rename_tabulator_table diff --git a/api/python/quilt3/admin/_graphql_client/enums.py b/api/python/quilt3/admin/_graphql_client/enums.py new file mode 100644 index 00000000000..638363665ea --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/enums.py @@ -0,0 +1,3 @@ +# Generated by ariadne-codegen +# Source: ../../../shared/graphql/schema.graphql + diff --git a/api/python/quilt3/admin/_graphql_client/exceptions.py b/api/python/quilt3/admin/_graphql_client/exceptions.py new file mode 100644 index 00000000000..f42118680cc --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/exceptions.py @@ -0,0 +1,89 @@ +# Generated by ariadne-codegen + +# This is +# https://github.com/mirumee/ariadne-codegen/blob/5bfd63c5e7e3a8cc5293eb94deee638b7adab98d/ariadne_codegen/client_generators/dependencies/exceptions.py +# modified to use our requests instead of httpx. +# pylint: disable=super-init-not-called +from typing import Any, Dict, List, Optional, Union + +import requests + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: requests.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: requests.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." 
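For orientation, here is a minimal sketch (not part of the generated module) of how a GraphQL `errors` payload maps onto the exception types in this file; `from_dict` and `from_errors_dicts` are defined just below:

```python
from typing import Any, Dict, List

from quilt3.admin._graphql_client.exceptions import GraphQLClientGraphQLMultiError

# Hypothetical payload, shaped like a GraphQL "errors" array:
errors_payload: List[Dict[str, Any]] = [
    {"message": "Bucket not found", "path": ["admin", "bucketSetTabulatorTable"]},
    {"message": "Operation not permitted"},
]

try:
    raise GraphQLClientGraphQLMultiError.from_errors_dicts(errors_payload)
except GraphQLClientGraphQLMultiError as exc:
    print(exc)                 # Bucket not found; Operation not permitted
    print(exc.errors[0].path)  # ['admin', 'bucketSetTabulatorTable']
```

str() of the multi-error joins the individual messages with "; ", so a single except clause can report every server-side error at once.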
+ + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + original: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.original = original + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + original=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/api/python/quilt3/admin/_graphql_client/fragments.py b/api/python/quilt3/admin/_graphql_client/fragments.py new file mode 100644 index 00000000000..1568dbcc7c4 --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/fragments.py @@ -0,0 +1,98 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from datetime import datetime +from typing import Annotated, Any, List, Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel + + +class InvalidInputSelection(BaseModel): + errors: List["InvalidInputSelectionErrors"] + + +class InvalidInputSelectionErrors(BaseModel): + path: Optional[str] + message: str + name: str + context: Optional[Any] + + +class ManagedRoleSelection(BaseModel): + id: str + name: str + arn: str + + +class OperationErrorSelection(BaseModel): + message: str + name: str + context: Optional[Any] + + +class UnmanagedRoleSelection(BaseModel): + id: str + name: str + arn: str + + +class UserSelection(BaseModel): + name: str + email: str + date_joined: datetime = Field(alias="dateJoined") + last_login: datetime = Field(alias="lastLogin") + is_active: bool = Field(alias="isActive") + is_admin: bool = Field(alias="isAdmin") + is_sso_only: bool = Field(alias="isSsoOnly") + is_service: bool = Field(alias="isService") + role: Optional[ + Annotated[ + Union["UserSelectionRoleUnmanagedRole", "UserSelectionRoleManagedRole"], + Field(discriminator="typename__"), + ] + ] + extra_roles: List[ + Annotated[ + Union[ + "UserSelectionExtraRolesUnmanagedRole", + "UserSelectionExtraRolesManagedRole", + ], + Field(discriminator="typename__"), + ] + ] = Field(alias="extraRoles") + + +class UserSelectionRoleUnmanagedRole(UnmanagedRoleSelection): + typename__: Literal["UnmanagedRole"] = Field(alias="__typename") + + +class UserSelectionRoleManagedRole(ManagedRoleSelection): + typename__: Literal["ManagedRole"] = Field(alias="__typename") + + +class UserSelectionExtraRolesUnmanagedRole(UnmanagedRoleSelection): + typename__: Literal["UnmanagedRole"] = 
Field(alias="__typename") + + +class UserSelectionExtraRolesManagedRole(ManagedRoleSelection): + typename__: Literal["ManagedRole"] = Field(alias="__typename") + + +class SsoConfigSelection(BaseModel): + text: str + timestamp: datetime + uploader: "SsoConfigSelectionUploader" + + +class SsoConfigSelectionUploader(UserSelection): + pass + + +InvalidInputSelection.model_rebuild() +ManagedRoleSelection.model_rebuild() +OperationErrorSelection.model_rebuild() +UnmanagedRoleSelection.model_rebuild() +UserSelection.model_rebuild() +SsoConfigSelection.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/input_types.py b/api/python/quilt3/admin/_graphql_client/input_types.py new file mode 100644 index 00000000000..cb87f01d15a --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/input_types.py @@ -0,0 +1,16 @@ +# Generated by ariadne-codegen +# Source: ../../../shared/graphql/schema.graphql + +from datetime import datetime +from typing import Annotated, Any, List, Optional, Union + +from pydantic import Field, PlainSerializer + +from .base_model import BaseModel, Upload + + +class UserInput(BaseModel): + name: str + email: str + role: str + extra_roles: Optional[List[str]] = Field(alias="extraRoles", default=None) diff --git a/api/python/quilt3/admin/_graphql_client/roles_list.py b/api/python/quilt3/admin/_graphql_client/roles_list.py new file mode 100644 index 00000000000..37598459172 --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/roles_list.py @@ -0,0 +1,29 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Annotated, List, Literal, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import ManagedRoleSelection, UnmanagedRoleSelection + + +class RolesList(BaseModel): + roles: List[ + Annotated[ + Union["RolesListRolesUnmanagedRole", "RolesListRolesManagedRole"], + Field(discriminator="typename__"), + ] + ] + + +class RolesListRolesUnmanagedRole(UnmanagedRoleSelection): + typename__: Literal["UnmanagedRole"] = Field(alias="__typename") + + +class RolesListRolesManagedRole(ManagedRoleSelection): + typename__: Literal["ManagedRole"] = Field(alias="__typename") + + +RolesList.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/sso_config_get.py b/api/python/quilt3/admin/_graphql_client/sso_config_get.py new file mode 100644 index 00000000000..b57c6dbb47c --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/sso_config_get.py @@ -0,0 +1,25 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Optional + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import SsoConfigSelection + + +class SsoConfigGet(BaseModel): + admin: "SsoConfigGetAdmin" + + +class SsoConfigGetAdmin(BaseModel): + sso_config: Optional["SsoConfigGetAdminSsoConfig"] = Field(alias="ssoConfig") + + +class SsoConfigGetAdminSsoConfig(SsoConfigSelection): + pass + + +SsoConfigGet.model_rebuild() +SsoConfigGetAdmin.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/sso_config_set.py b/api/python/quilt3/admin/_graphql_client/sso_config_set.py new file mode 100644 index 00000000000..af7e32af9fe --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/sso_config_set.py @@ -0,0 +1,46 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Annotated, Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import ( + InvalidInputSelection, + 
OperationErrorSelection, + SsoConfigSelection, +) + + +class SsoConfigSet(BaseModel): + admin: "SsoConfigSetAdmin" + + +class SsoConfigSetAdmin(BaseModel): + set_sso_config: Optional[ + Annotated[ + Union[ + "SsoConfigSetAdminSetSsoConfigSsoConfig", + "SsoConfigSetAdminSetSsoConfigInvalidInput", + "SsoConfigSetAdminSetSsoConfigOperationError", + ], + Field(discriminator="typename__"), + ] + ] = Field(alias="setSsoConfig") + + +class SsoConfigSetAdminSetSsoConfigSsoConfig(SsoConfigSelection): + typename__: Literal["SsoConfig"] = Field(alias="__typename") + + +class SsoConfigSetAdminSetSsoConfigInvalidInput(InvalidInputSelection): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class SsoConfigSetAdminSetSsoConfigOperationError(OperationErrorSelection): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +SsoConfigSet.model_rebuild() +SsoConfigSetAdmin.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/users_add_roles.py b/api/python/quilt3/admin/_graphql_client/users_add_roles.py new file mode 100644 index 00000000000..101a2f8de2d --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/users_add_roles.py @@ -0,0 +1,47 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection, UserSelection + + +class UsersAddRoles(BaseModel): + admin: "UsersAddRolesAdmin" + + +class UsersAddRolesAdmin(BaseModel): + user: "UsersAddRolesAdminUser" + + +class UsersAddRolesAdminUser(BaseModel): + mutate: Optional["UsersAddRolesAdminUserMutate"] + + +class UsersAddRolesAdminUserMutate(BaseModel): + add_roles: Union[ + "UsersAddRolesAdminUserMutateAddRolesUser", + "UsersAddRolesAdminUserMutateAddRolesInvalidInput", + "UsersAddRolesAdminUserMutateAddRolesOperationError", + ] = Field(alias="addRoles", discriminator="typename__") + + +class UsersAddRolesAdminUserMutateAddRolesUser(UserSelection): + typename__: Literal["User"] = Field(alias="__typename") + + +class UsersAddRolesAdminUserMutateAddRolesInvalidInput(InvalidInputSelection): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class UsersAddRolesAdminUserMutateAddRolesOperationError(OperationErrorSelection): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +UsersAddRoles.model_rebuild() +UsersAddRolesAdmin.model_rebuild() +UsersAddRolesAdminUser.model_rebuild() +UsersAddRolesAdminUserMutate.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/users_create.py b/api/python/quilt3/admin/_graphql_client/users_create.py new file mode 100644 index 00000000000..2b696f4dcd3 --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/users_create.py @@ -0,0 +1,42 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection, UserSelection + + +class UsersCreate(BaseModel): + admin: "UsersCreateAdmin" + + +class UsersCreateAdmin(BaseModel): + user: "UsersCreateAdminUser" + + +class UsersCreateAdminUser(BaseModel): + create: Union[ + "UsersCreateAdminUserCreateUser", + "UsersCreateAdminUserCreateInvalidInput", + "UsersCreateAdminUserCreateOperationError", + ] = Field(discriminator="typename__") + + +class UsersCreateAdminUserCreateUser(UserSelection): + typename__: Literal["User"] = 
Field(alias="__typename") + + +class UsersCreateAdminUserCreateInvalidInput(InvalidInputSelection): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class UsersCreateAdminUserCreateOperationError(OperationErrorSelection): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +UsersCreate.model_rebuild() +UsersCreateAdmin.model_rebuild() +UsersCreateAdminUser.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/users_delete.py b/api/python/quilt3/admin/_graphql_client/users_delete.py new file mode 100644 index 00000000000..3f5cdc726a2 --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/users_delete.py @@ -0,0 +1,47 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection + + +class UsersDelete(BaseModel): + admin: "UsersDeleteAdmin" + + +class UsersDeleteAdmin(BaseModel): + user: "UsersDeleteAdminUser" + + +class UsersDeleteAdminUser(BaseModel): + mutate: Optional["UsersDeleteAdminUserMutate"] + + +class UsersDeleteAdminUserMutate(BaseModel): + delete: Union[ + "UsersDeleteAdminUserMutateDeleteOk", + "UsersDeleteAdminUserMutateDeleteInvalidInput", + "UsersDeleteAdminUserMutateDeleteOperationError", + ] = Field(discriminator="typename__") + + +class UsersDeleteAdminUserMutateDeleteOk(BaseModel): + typename__: Literal["Ok"] = Field(alias="__typename") + + +class UsersDeleteAdminUserMutateDeleteInvalidInput(InvalidInputSelection): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class UsersDeleteAdminUserMutateDeleteOperationError(OperationErrorSelection): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +UsersDelete.model_rebuild() +UsersDeleteAdmin.model_rebuild() +UsersDeleteAdminUser.model_rebuild() +UsersDeleteAdminUserMutate.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/users_get.py b/api/python/quilt3/admin/_graphql_client/users_get.py new file mode 100644 index 00000000000..3a93f98aa2d --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/users_get.py @@ -0,0 +1,28 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Optional + +from .base_model import BaseModel +from .fragments import UserSelection + + +class UsersGet(BaseModel): + admin: "UsersGetAdmin" + + +class UsersGetAdmin(BaseModel): + user: "UsersGetAdminUser" + + +class UsersGetAdminUser(BaseModel): + get: Optional["UsersGetAdminUserGet"] + + +class UsersGetAdminUserGet(UserSelection): + pass + + +UsersGet.model_rebuild() +UsersGetAdmin.model_rebuild() +UsersGetAdminUser.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/users_list.py b/api/python/quilt3/admin/_graphql_client/users_list.py new file mode 100644 index 00000000000..bd85e2399cf --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/users_list.py @@ -0,0 +1,28 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import List + +from .base_model import BaseModel +from .fragments import UserSelection + + +class UsersList(BaseModel): + admin: "UsersListAdmin" + + +class UsersListAdmin(BaseModel): + user: "UsersListAdminUser" + + +class UsersListAdminUser(BaseModel): + list: List["UsersListAdminUserList"] + + +class UsersListAdminUserList(UserSelection): + pass + + +UsersList.model_rebuild() +UsersListAdmin.model_rebuild() +UsersListAdminUser.model_rebuild() diff --git 
a/api/python/quilt3/admin/_graphql_client/users_remove_roles.py b/api/python/quilt3/admin/_graphql_client/users_remove_roles.py new file mode 100644 index 00000000000..521dda0feef --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/users_remove_roles.py @@ -0,0 +1,47 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection, UserSelection + + +class UsersRemoveRoles(BaseModel): + admin: "UsersRemoveRolesAdmin" + + +class UsersRemoveRolesAdmin(BaseModel): + user: "UsersRemoveRolesAdminUser" + + +class UsersRemoveRolesAdminUser(BaseModel): + mutate: Optional["UsersRemoveRolesAdminUserMutate"] + + +class UsersRemoveRolesAdminUserMutate(BaseModel): + remove_roles: Union[ + "UsersRemoveRolesAdminUserMutateRemoveRolesUser", + "UsersRemoveRolesAdminUserMutateRemoveRolesInvalidInput", + "UsersRemoveRolesAdminUserMutateRemoveRolesOperationError", + ] = Field(alias="removeRoles", discriminator="typename__") + + +class UsersRemoveRolesAdminUserMutateRemoveRolesUser(UserSelection): + typename__: Literal["User"] = Field(alias="__typename") + + +class UsersRemoveRolesAdminUserMutateRemoveRolesInvalidInput(InvalidInputSelection): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class UsersRemoveRolesAdminUserMutateRemoveRolesOperationError(OperationErrorSelection): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +UsersRemoveRoles.model_rebuild() +UsersRemoveRolesAdmin.model_rebuild() +UsersRemoveRolesAdminUser.model_rebuild() +UsersRemoveRolesAdminUserMutate.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/users_reset_password.py b/api/python/quilt3/admin/_graphql_client/users_reset_password.py new file mode 100644 index 00000000000..65b54546dc5 --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/users_reset_password.py @@ -0,0 +1,49 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection + + +class UsersResetPassword(BaseModel): + admin: "UsersResetPasswordAdmin" + + +class UsersResetPasswordAdmin(BaseModel): + user: "UsersResetPasswordAdminUser" + + +class UsersResetPasswordAdminUser(BaseModel): + mutate: Optional["UsersResetPasswordAdminUserMutate"] + + +class UsersResetPasswordAdminUserMutate(BaseModel): + reset_password: Union[ + "UsersResetPasswordAdminUserMutateResetPasswordOk", + "UsersResetPasswordAdminUserMutateResetPasswordInvalidInput", + "UsersResetPasswordAdminUserMutateResetPasswordOperationError", + ] = Field(alias="resetPassword", discriminator="typename__") + + +class UsersResetPasswordAdminUserMutateResetPasswordOk(BaseModel): + typename__: Literal["Ok"] = Field(alias="__typename") + + +class UsersResetPasswordAdminUserMutateResetPasswordInvalidInput(InvalidInputSelection): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class UsersResetPasswordAdminUserMutateResetPasswordOperationError( + OperationErrorSelection +): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +UsersResetPassword.model_rebuild() +UsersResetPasswordAdmin.model_rebuild() +UsersResetPasswordAdminUser.model_rebuild() +UsersResetPasswordAdminUserMutate.model_rebuild() diff --git 
a/api/python/quilt3/admin/_graphql_client/users_set_active.py b/api/python/quilt3/admin/_graphql_client/users_set_active.py new file mode 100644 index 00000000000..1d6b06056dc --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/users_set_active.py @@ -0,0 +1,47 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection, UserSelection + + +class UsersSetActive(BaseModel): + admin: "UsersSetActiveAdmin" + + +class UsersSetActiveAdmin(BaseModel): + user: "UsersSetActiveAdminUser" + + +class UsersSetActiveAdminUser(BaseModel): + mutate: Optional["UsersSetActiveAdminUserMutate"] + + +class UsersSetActiveAdminUserMutate(BaseModel): + set_active: Union[ + "UsersSetActiveAdminUserMutateSetActiveUser", + "UsersSetActiveAdminUserMutateSetActiveInvalidInput", + "UsersSetActiveAdminUserMutateSetActiveOperationError", + ] = Field(alias="setActive", discriminator="typename__") + + +class UsersSetActiveAdminUserMutateSetActiveUser(UserSelection): + typename__: Literal["User"] = Field(alias="__typename") + + +class UsersSetActiveAdminUserMutateSetActiveInvalidInput(InvalidInputSelection): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class UsersSetActiveAdminUserMutateSetActiveOperationError(OperationErrorSelection): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +UsersSetActive.model_rebuild() +UsersSetActiveAdmin.model_rebuild() +UsersSetActiveAdminUser.model_rebuild() +UsersSetActiveAdminUserMutate.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/users_set_admin.py b/api/python/quilt3/admin/_graphql_client/users_set_admin.py new file mode 100644 index 00000000000..67e3213265a --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/users_set_admin.py @@ -0,0 +1,47 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection, UserSelection + + +class UsersSetAdmin(BaseModel): + admin: "UsersSetAdminAdmin" + + +class UsersSetAdminAdmin(BaseModel): + user: "UsersSetAdminAdminUser" + + +class UsersSetAdminAdminUser(BaseModel): + mutate: Optional["UsersSetAdminAdminUserMutate"] + + +class UsersSetAdminAdminUserMutate(BaseModel): + set_admin: Union[ + "UsersSetAdminAdminUserMutateSetAdminUser", + "UsersSetAdminAdminUserMutateSetAdminInvalidInput", + "UsersSetAdminAdminUserMutateSetAdminOperationError", + ] = Field(alias="setAdmin", discriminator="typename__") + + +class UsersSetAdminAdminUserMutateSetAdminUser(UserSelection): + typename__: Literal["User"] = Field(alias="__typename") + + +class UsersSetAdminAdminUserMutateSetAdminInvalidInput(InvalidInputSelection): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class UsersSetAdminAdminUserMutateSetAdminOperationError(OperationErrorSelection): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +UsersSetAdmin.model_rebuild() +UsersSetAdminAdmin.model_rebuild() +UsersSetAdminAdminUser.model_rebuild() +UsersSetAdminAdminUserMutate.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/users_set_email.py b/api/python/quilt3/admin/_graphql_client/users_set_email.py new file mode 100644 index 00000000000..4e434d313f0 --- /dev/null +++ 
b/api/python/quilt3/admin/_graphql_client/users_set_email.py @@ -0,0 +1,47 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection, UserSelection + + +class UsersSetEmail(BaseModel): + admin: "UsersSetEmailAdmin" + + +class UsersSetEmailAdmin(BaseModel): + user: "UsersSetEmailAdminUser" + + +class UsersSetEmailAdminUser(BaseModel): + mutate: Optional["UsersSetEmailAdminUserMutate"] + + +class UsersSetEmailAdminUserMutate(BaseModel): + set_email: Union[ + "UsersSetEmailAdminUserMutateSetEmailUser", + "UsersSetEmailAdminUserMutateSetEmailInvalidInput", + "UsersSetEmailAdminUserMutateSetEmailOperationError", + ] = Field(alias="setEmail", discriminator="typename__") + + +class UsersSetEmailAdminUserMutateSetEmailUser(UserSelection): + typename__: Literal["User"] = Field(alias="__typename") + + +class UsersSetEmailAdminUserMutateSetEmailInvalidInput(InvalidInputSelection): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class UsersSetEmailAdminUserMutateSetEmailOperationError(OperationErrorSelection): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +UsersSetEmail.model_rebuild() +UsersSetEmailAdmin.model_rebuild() +UsersSetEmailAdminUser.model_rebuild() +UsersSetEmailAdminUserMutate.model_rebuild() diff --git a/api/python/quilt3/admin/_graphql_client/users_set_role.py b/api/python/quilt3/admin/_graphql_client/users_set_role.py new file mode 100644 index 00000000000..e5af3f1b997 --- /dev/null +++ b/api/python/quilt3/admin/_graphql_client/users_set_role.py @@ -0,0 +1,47 @@ +# Generated by ariadne-codegen +# Source: queries.graphql + +from typing import Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import InvalidInputSelection, OperationErrorSelection, UserSelection + + +class UsersSetRole(BaseModel): + admin: "UsersSetRoleAdmin" + + +class UsersSetRoleAdmin(BaseModel): + user: "UsersSetRoleAdminUser" + + +class UsersSetRoleAdminUser(BaseModel): + mutate: Optional["UsersSetRoleAdminUserMutate"] + + +class UsersSetRoleAdminUserMutate(BaseModel): + set_role: Union[ + "UsersSetRoleAdminUserMutateSetRoleUser", + "UsersSetRoleAdminUserMutateSetRoleInvalidInput", + "UsersSetRoleAdminUserMutateSetRoleOperationError", + ] = Field(alias="setRole", discriminator="typename__") + + +class UsersSetRoleAdminUserMutateSetRoleUser(UserSelection): + typename__: Literal["User"] = Field(alias="__typename") + + +class UsersSetRoleAdminUserMutateSetRoleInvalidInput(InvalidInputSelection): + typename__: Literal["InvalidInput"] = Field(alias="__typename") + + +class UsersSetRoleAdminUserMutateSetRoleOperationError(OperationErrorSelection): + typename__: Literal["OperationError"] = Field(alias="__typename") + + +UsersSetRole.model_rebuild() +UsersSetRoleAdmin.model_rebuild() +UsersSetRoleAdminUser.model_rebuild() +UsersSetRoleAdminUserMutate.model_rebuild() diff --git a/api/python/quilt3/admin/exceptions.py b/api/python/quilt3/admin/exceptions.py new file mode 100644 index 00000000000..591077095c3 --- /dev/null +++ b/api/python/quilt3/admin/exceptions.py @@ -0,0 +1,14 @@ +class Quilt3AdminError(Exception): + def __init__(self, details): + super().__init__(details) + self.details = details + + +class UserNotFoundError(Quilt3AdminError): + def __init__(self): + super().__init__(None) + + +class 
BucketNotFoundError(Quilt3AdminError): + def __init__(self): + super().__init__(None) diff --git a/api/python/quilt3/admin/roles.py b/api/python/quilt3/admin/roles.py new file mode 100644 index 00000000000..ec7b6450919 --- /dev/null +++ b/api/python/quilt3/admin/roles.py @@ -0,0 +1,10 @@ +from typing import List + +from . import types, util + + +def list() -> List[types.Role]: + """ + Get a list of all roles in the registry. + """ + return [types.role_adapter.validate_python(r.model_dump()) for r in util.get_client().roles_list()] diff --git a/api/python/quilt3/admin/sso_config.py b/api/python/quilt3/admin/sso_config.py new file mode 100644 index 00000000000..9e2ef36314f --- /dev/null +++ b/api/python/quilt3/admin/sso_config.py @@ -0,0 +1,19 @@ +import typing as T + +from . import types, util + + +def get() -> T.Optional[types.SSOConfig]: + """ + Get the current SSO configuration. + """ + result = util.get_client().sso_config_get() + return None if result is None else types.SSOConfig(**result.model_dump()) + + +def set(config: T.Optional[str]) -> T.Optional[types.SSOConfig]: + """ + Set the SSO configuration. Pass `None` to remove the SSO configuration. + """ + result = util.get_client().sso_config_set(config) + return None if result is None else types.SSOConfig(**util.handle_errors(result).model_dump()) diff --git a/api/python/quilt3/admin/tabulator.py b/api/python/quilt3/admin/tabulator.py new file mode 100644 index 00000000000..e6e2188a3d6 --- /dev/null +++ b/api/python/quilt3/admin/tabulator.py @@ -0,0 +1,29 @@ +import typing as T + +from . import exceptions, types, util + + +def list_tables(bucket_name: str) -> list[types.TabulatorTable]: + """ + List all tabulator tables in a bucket. + """ + result = util.get_client().bucket_tabulator_tables_list(bucket_name) + if result is None: + raise exceptions.BucketNotFoundError + return [types.TabulatorTable(**x.model_dump()) for x in result.tabulator_tables] + + +def set_table(bucket_name: str, table_name: str, config: T.Optional[str]) -> None: + """ + Set the tabulator table configuration. Pass `None` to remove the table. + """ + result = util.get_client().bucket_tabulator_table_set(bucket_name, table_name, config) + util.handle_errors(result) + + +def rename_table(bucket_name: str, table_name: str, new_table_name: str) -> None: + """ + Rename a tabulator table.
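+ + Args: + bucket_name: Name of the bucket. + table_name: Current name of the table. + new_table_name: New name for the table. + + Raises: + Quilt3AdminError: If the server rejects the rename.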
+ """ + result = util.get_client().bucket_tabulator_table_rename(bucket_name, table_name, new_table_name) + util.handle_errors(result) diff --git a/api/python/quilt3/admin/types.py b/api/python/quilt3/admin/types.py new file mode 100644 index 00000000000..855f170c2a8 --- /dev/null +++ b/api/python/quilt3/admin/types.py @@ -0,0 +1,52 @@ +from datetime import datetime +from typing import Annotated, List, Literal, Optional, Union + +import pydantic + + +@pydantic.dataclasses.dataclass +class ManagedRole: + id: str + name: str + arn: str + typename__: Literal["ManagedRole"] + + +@pydantic.dataclasses.dataclass +class UnmanagedRole: + id: str + name: str + arn: str + typename__: Literal["UnmanagedRole"] + + +Role = Union[ManagedRole, UnmanagedRole] +AnnotatedRole = Annotated[Role, pydantic.Field(discriminator="typename__")] +role_adapter = pydantic.TypeAdapter(AnnotatedRole) + + +@pydantic.dataclasses.dataclass +class User: + name: str + email: str + date_joined: datetime + last_login: datetime + is_active: bool + is_admin: bool + is_sso_only: bool + is_service: bool + role: Optional[AnnotatedRole] + extra_roles: List[AnnotatedRole] + + +@pydantic.dataclasses.dataclass +class SSOConfig: + text: str + timestamp: datetime + uploader: User + + +@pydantic.dataclasses.dataclass +class TabulatorTable: + name: str + config: str diff --git a/api/python/quilt3/admin/users.py b/api/python/quilt3/admin/users.py new file mode 100644 index 00000000000..fb667de9de8 --- /dev/null +++ b/api/python/quilt3/admin/users.py @@ -0,0 +1,165 @@ + +from typing import List, Optional + +from . import _graphql_client, exceptions, types, util + + +def get(name: str) -> Optional[types.User]: + """ + Get a specific user from the registry. Return `None` if the user does not exist. + + Args: + name: Username of user to get. + """ + result = util.get_client().users_get(name=name) + if result is None: + return None + return types.User(**result.model_dump()) + + +def list() -> List[types.User]: + """ + Get a list of all users in the registry. + """ + return [types.User(**u.model_dump()) for u in util.get_client().users_list()] + + +def create(name: str, email: str, role: str, extra_roles: Optional[List[str]] = None) -> types.User: + """ + Create a new user in the registry. + + Args: + name: Username of user to create. + email: Email of user to create. + role: Active role of the user. + extra_roles: Additional roles to assign to the user. + """ + + return util.handle_user_mutation( + util.get_client().users_create( + input=_graphql_client.UserInput(name=name, email=email, role=role, extraRoles=extra_roles) + ) + ) + + +def delete(name: str) -> None: + """ + Delete user from the registry. + + Args: + name: Username of user to delete. + """ + result = util.get_client().users_delete(name=name) + if result is None: + raise exceptions.UserNotFoundError + util.handle_errors(result.delete) + + +def set_email(name: str, email: str) -> types.User: + """ + Set the email for a user. + + Args: + name: Username of user to update. + email: Email to set for the user. + """ + result = util.get_client().users_set_email(name=name, email=email) + if result is None: + raise exceptions.UserNotFoundError + return util.handle_user_mutation(result.set_email) + + +def set_admin(name: str, admin: bool) -> types.User: + """ + Set the admin status for a user. + + Args: + name: Username of user to update. + admin: Admin status to set for the user. 
+ """ + result = util.get_client().users_set_admin(name=name, admin=admin) + if result is None: + raise exceptions.UserNotFoundError + return util.handle_user_mutation(result.set_admin) + + +def set_active(name: str, active: bool) -> types.User: + """ + Set the active status for a user. + + Args: + name: Username of user to update. + active: Active status to set for the user. + """ + result = util.get_client().users_set_active(name=name, active=active) + if result is None: + raise exceptions.UserNotFoundError + return util.handle_user_mutation(result.set_active) + + +def reset_password(name: str) -> None: + """ + Reset the password for a user. + + Args: + name: Username of user to update. + """ + result = util.get_client().users_reset_password(name=name) + if result is None: + raise exceptions.UserNotFoundError + util.handle_errors(result.reset_password) + + +def set_role( + name: str, + role: str, + extra_roles: Optional[List[str]] = None, + *, + append: bool = False, +) -> types.User: + """ + Set the active and extra roles for a user. + + Args: + name: Username of user to update. + role: Role to be set as the active role. + extra_roles: Additional roles to assign to the user. + append: If True, append the extra roles to the existing roles. If False, replace the existing roles. + """ + result = util.get_client().users_set_role(name=name, role=role, extra_roles=extra_roles, append=append) + if result is None: + raise exceptions.UserNotFoundError + return util.handle_user_mutation(result.set_role) + + +def add_roles(name: str, roles: List[str]) -> types.User: + """ + Add roles to a user. + + Args: + name: Username of user to update. + roles: Roles to add to the user. + """ + result = util.get_client().users_add_roles(name=name, roles=roles) + if result is None: + raise exceptions.UserNotFoundError + return util.handle_user_mutation(result.add_roles) + + +def remove_roles( + name: str, + roles: List[str], + fallback: Optional[str] = None, +) -> types.User: + """ + Remove roles from a user. + + Args: + name: Username of user to update. + roles: Roles to remove from the user. + fallback: If set, the role to assign to the user if the active role is removed. + """ + result = util.get_client().users_remove_roles(name=name, roles=roles, fallback=fallback) + if result is None: + raise exceptions.UserNotFoundError + return util.handle_user_mutation(result.remove_roles) diff --git a/api/python/quilt3/admin/util.py b/api/python/quilt3/admin/util.py new file mode 100644 index 00000000000..bf3a3f9c988 --- /dev/null +++ b/api/python/quilt3/admin/util.py @@ -0,0 +1,15 @@ +from . 
import _graphql_client, exceptions, types + + +def handle_errors(result: _graphql_client.BaseModel) -> _graphql_client.BaseModel: + if isinstance(result, (_graphql_client.InvalidInputSelection, _graphql_client.OperationErrorSelection)): + raise exceptions.Quilt3AdminError(result) + return result + + +def handle_user_mutation(result: _graphql_client.BaseModel) -> types.User: + return types.User(**handle_errors(result).model_dump()) + + +def get_client(): + return _graphql_client.Client() diff --git a/api/python/quilt3/api.py b/api/python/quilt3/api.py index 7e5d8b6e591..561a2a8d35a 100644 --- a/api/python/quilt3/api.py +++ b/api/python/quilt3/api.py @@ -1,3 +1,5 @@ +import typing as T + from .backends import get_package_registry from .data_transfer import copy_file from .search_util import search_api @@ -164,38 +166,23 @@ def _disable_telemetry(): @ApiTelemetry("api.search") -def search(query, limit=10): +def search(query: T.Union[str, dict], limit: int = 10) -> T.List[dict]: """ Execute a search against the configured search endpoint. Args: - query (str): query string to search - limit (number): maximum number of results to return. Defaults to 10 + query: query string to search if passed as `str`, DSL query body if passed as `dict` + limit: maximum number of results to return. Defaults to 10 Query Syntax: - [simple query string query]( - https://www.elastic.co/guide/en/elasticsearch/reference/6.8/query-dsl-simple-query-string-query.html) - + [Query String Query]( + https://www.elastic.co/guide/en/elasticsearch/reference/6.8/query-dsl-query-string-query.html) + [Query DSL](https://www.elastic.co/guide/en/elasticsearch/reference/6.8/query-dsl.html) Returns: - a list of objects with the following structure: - ``` - [{ - "_id": - "_index": , - "_score": - "_source": - "key": , - "size": , - "user_meta": , - "last_modified": , - "updated": , - "version_id": - "_type": - }, ...] - ``` + search results """ # force a call to configure_from_default if no config exists _config() - raw_results = search_api(query, '*', limit) + raw_results = search_api(query, '_all', limit) return raw_results['hits']['hits'] diff --git a/api/python/quilt3/bucket.py b/api/python/quilt3/bucket.py index 23b05431e09..e0f58d69d2a 100644 --- a/api/python/quilt3/bucket.py +++ b/api/python/quilt3/bucket.py @@ -5,6 +5,7 @@ over an s3 bucket. """ import pathlib +import typing as T from .data_transfer import ( copy_file, @@ -36,36 +37,23 @@ def __init__(self, bucket_uri): if self._pk.path or self._pk.version_id is not None: raise QuiltException("Bucket URI shouldn't contain a path or a version ID") - def search(self, query, limit=10): + def search(self, query: T.Union[str, dict], limit: int = 10) -> T.List[dict]: """ Execute a search against the configured search endpoint. Args: - query (str): query string to search - limit (number): maximum number of results to return. Defaults to 10 + query: query string to search if passed as `str`, DSL query body if passed as `dict` + limit: maximum number of results to return. Defaults to 10 Query Syntax: - By default, a normal plaintext search will be executed over the query string. - You can use field-match syntax to filter on exact matches for fields in - your metadata. - The syntax for field match is `user_meta.$field_name:"exact_match"`. 
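Since both `quilt3.search` and `Bucket.search` now accept either a plain query string or an Elasticsearch DSL body, a hedged usage sketch follows (it assumes the `dict` form is passed through to the search endpoint as the request body, and the metadata field name is purely illustrative):

```python
import quilt3

# Query-string form (Elasticsearch 6.8 query_string syntax):
hits = quilt3.search('user_meta.species:"Homo sapiens"', limit=5)

# Equivalent Query DSL form, passed as a dict:
hits = quilt3.search(
    {"query": {"query_string": {"query": 'user_meta.species:"Homo sapiens"'}}},
    limit=5,
)

for hit in hits:
    # Each hit is a raw Elasticsearch hit dict (_index, _score, _source, ...).
    print(hit["_index"], hit["_score"])
```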
+ [Query String Query]( + https://www.elastic.co/guide/en/elasticsearch/reference/6.8/query-dsl-query-string-query.html) + [Query DSL](https://www.elastic.co/guide/en/elasticsearch/reference/6.8/query-dsl.html) Returns: - a list of objects with the following structure: - ``` - [{ - "key": , - "version_id": , - "operation": <"Create" or "Delete">, - "meta": , - "size": , - "text": , - "source": , - "time": , - }...] - ``` - """ - return search_api(query, index=self._pk.bucket, limit=limit) + search results + """ + return search_api(query, index=f"{self._pk.bucket},{self._pk.bucket}_packages", limit=limit)["hits"]["hits"] def put_file(self, key, path): """ diff --git a/api/python/quilt3/data_transfer.py b/api/python/quilt3/data_transfer.py index e8e9a29d5ff..7ae97ef5bf4 100644 --- a/api/python/quilt3/data_transfer.py +++ b/api/python/quilt3/data_transfer.py @@ -1,3 +1,4 @@ +import binascii import concurrent import functools import hashlib @@ -14,12 +15,12 @@ import warnings from codecs import iterdecode from collections import defaultdict, deque -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass from enum import Enum from threading import Lock -from typing import List, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple -import boto3 import jsonlines from boto3.s3.transfer import TransferConfig from botocore import UNSIGNED @@ -30,12 +31,7 @@ HTTPClientError, ReadTimeoutError, ) -from s3transfer.utils import ( - ChunksizeAdjuster, - OSUtils, - signal_not_transferring, - signal_transferring, -) +from s3transfer.utils import ReadFileChunk from tenacity import ( retry, retry_if_not_result, @@ -46,7 +42,7 @@ from tqdm import tqdm from . import util -from .session import create_botocore_session +from .session import get_boto3_session from .util import DISABLE_TQDM, PhysicalKey, QuiltException MAX_COPY_FILE_LIST_RETRIES = 3 @@ -140,7 +136,7 @@ def find_correct_client(self, api_type, bucket, param_dict): S3Api.LIST_OBJECTS_V2: check_list_objects_v2_works_for_client, S3Api.LIST_OBJECT_VERSIONS: check_list_object_versions_works_for_client } - assert api_type in check_fn_mapper.keys(), f"Only certain APIs are supported with unsigned_client. The " \ + assert api_type in check_fn_mapper, f"Only certain APIs are supported with unsigned_client. The " \ f"API '{api_type}' is not currently supported. You may want to use S3ClientProvider.standard_client " \ f"instead " check_fn = check_fn_mapper[api_type] @@ -155,38 +151,24 @@ def find_correct_client(self, api_type, bucket, param_dict): raise S3NoValidClientError(f"S3 AccessDenied for {api_type} on bucket: {bucket}") def get_boto_session(self): - botocore_session = create_botocore_session() - boto_session = boto3.Session(botocore_session=botocore_session) - return boto_session - - def register_signals(self, s3_client): - # Enable/disable file read callbacks when uploading files. 
- # Copied from https://github.com/boto/s3transfer/blob/develop/s3transfer/manager.py#L501 - event_name = 'request-created.s3' - s3_client.meta.events.register_first( - event_name, signal_not_transferring, - unique_id='datatransfer-not-transferring') - s3_client.meta.events.register_last( - event_name, signal_transferring, - unique_id='datatransfer-transferring') - - def _build_client(self, get_config): + return get_boto3_session() + + def _build_client(self, is_unsigned): session = self.get_boto_session() - return session.client('s3', config=get_config(session)) + conf_kwargs = { + "max_pool_connections": MAX_CONCURRENCY, + } + if is_unsigned(session): + conf_kwargs["signature_version"] = UNSIGNED + + return session.client('s3', config=Config(**conf_kwargs)) def _build_standard_client(self): - s3_client = self._build_client( - lambda session: - Config(signature_version=UNSIGNED) - if session.get_credentials() is None - else None - ) - self.register_signals(s3_client) + s3_client = self._build_client(lambda session: session.get_credentials() is None) self._standard_client = s3_client def _build_unsigned_client(self): - s3_client = self._build_client(lambda session: Config(signature_version=UNSIGNED)) - self.register_signals(s3_client) + s3_client = self._build_client(lambda session: True) self._unsigned_client = s3_client @@ -249,7 +231,66 @@ def read_file_chunks(file, chunksize=s3_transfer_config.io_chunksize): UPLOAD_ETAG_OPTIMIZATION_THRESHOLD = 1024 -def _copy_local_file(ctx, size, src_path, dest_path): +# 8 MiB - same as TransferConfig().multipart_threshold - but hard-coded to guarantee it won't change. +CHECKSUM_MULTIPART_THRESHOLD = 8 * 1024 * 1024 + +# Maximum number of parts supported by S3 +CHECKSUM_MAX_PARTS = 10_000 + + +@dataclass +class WorkerContext: + s3_client_provider: S3ClientProvider + progress: Callable[[int], None] + done: Callable[[PhysicalKey, Optional[str]], None] + run: Callable[..., None] + + +def get_checksum_chunksize(file_size: int) -> int: + """ + Calculate the chunk size to be used for the checksum. It is normally 8 MiB, + but gets doubled as long as the number of parts exceeds the maximum of 10,000. + + It is the same as + `ChunksizeAdjuster().adjust_chunksize(s3_transfer_config.multipart_chunksize, file_size)`, + but hard-coded to guarantee it won't change and make the current behavior a part of the API. + """ + chunksize = 8 * 1024 * 1024 + num_parts = math.ceil(file_size / chunksize) + + while num_parts > CHECKSUM_MAX_PARTS: + chunksize *= 2 + num_parts = math.ceil(file_size / chunksize) + + return chunksize + + +def is_mpu(file_size: int) -> bool: + return file_size >= CHECKSUM_MULTIPART_THRESHOLD + + +_EMPTY_STRING_SHA256 = hashlib.sha256(b'').digest() + + +def _simple_s3_to_quilt_checksum(s3_checksum: str) -> str: + """ + Converts a SHA256 hash from a regular (non-multipart) S3 upload into a multipart hash, + i.e., base64(sha256(bytes)) -> base64(sha256([sha256(bytes)])). + + Edge case: a 0-byte upload is treated as an empty list of chunks, rather than a list of a 0-byte chunk. + Its checksum is sha256(''), NOT sha256(sha256('')). + """ + s3_checksum_bytes = binascii.a2b_base64(s3_checksum) + + if s3_checksum_bytes == _EMPTY_STRING_SHA256: + # Do not hash it again. 
+ return s3_checksum + + quilt_checksum_bytes = hashlib.sha256(s3_checksum_bytes).digest() + return binascii.b2a_base64(quilt_checksum_bytes, newline=False).decode() + + +def _copy_local_file(ctx: WorkerContext, size: int, src_path: str, dest_path: str): pathlib.Path(dest_path).parent.mkdir(parents=True, exist_ok=True) # TODO(dima): More detailed progress. @@ -257,31 +298,33 @@ def _copy_local_file(ctx, size, src_path, dest_path): ctx.progress(size) shutil.copymode(src_path, dest_path) - ctx.done(PhysicalKey.from_path(dest_path)) + ctx.done(PhysicalKey.from_path(dest_path), None) -def _upload_file(ctx, size, src_path, dest_bucket, dest_key): +def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_key: str): s3_client = ctx.s3_client_provider.standard_client - if size < s3_transfer_config.multipart_threshold: - with OSUtils().open_file_chunk_reader(src_path, 0, size, [ctx.progress]) as fd: + if not is_mpu(size): + with ReadFileChunk.from_filename(src_path, 0, size, [ctx.progress]) as fd: resp = s3_client.put_object( Body=fd, Bucket=dest_bucket, Key=dest_key, + ChecksumAlgorithm='SHA256', ) version_id = resp.get('VersionId') # Absent in unversioned buckets. - ctx.done(PhysicalKey(dest_bucket, dest_key, version_id)) + checksum = _simple_s3_to_quilt_checksum(resp['ChecksumSHA256']) + ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum) else: resp = s3_client.create_multipart_upload( Bucket=dest_bucket, Key=dest_key, + ChecksumAlgorithm='SHA256', ) upload_id = resp['UploadId'] - adjuster = ChunksizeAdjuster() - chunksize = adjuster.adjust_chunksize(s3_transfer_config.multipart_chunksize, size) + chunksize = get_checksum_chunksize(size) chunk_offsets = list(range(0, size, chunksize)) @@ -292,16 +335,21 @@ def _upload_file(ctx, size, src_path, dest_bucket, dest_key): def upload_part(i, start, end): nonlocal remaining part_id = i + 1 - with OSUtils().open_file_chunk_reader(src_path, start, end-start, [ctx.progress]) as fd: + with ReadFileChunk.from_filename(src_path, start, end-start, [ctx.progress]) as fd: part = s3_client.upload_part( Body=fd, Bucket=dest_bucket, Key=dest_key, UploadId=upload_id, - PartNumber=part_id + PartNumber=part_id, + ChecksumAlgorithm='SHA256', ) with lock: - parts[i] = {"PartNumber": part_id, "ETag": part["ETag"]} + parts[i] = dict( + PartNumber=part_id, + ETag=part['ETag'], + ChecksumSHA256=part['ChecksumSHA256'], + ) remaining -= 1 done = remaining == 0 @@ -310,17 +358,25 @@ def upload_part(i, start, end): Bucket=dest_bucket, Key=dest_key, UploadId=upload_id, - MultipartUpload={"Parts": parts} + MultipartUpload={'Parts': parts}, ) version_id = resp.get('VersionId') # Absent in unversioned buckets. 
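For multipart uploads S3 reports `ChecksumSHA256` in composite form, `base64(sha256(concat(sha256(part_i))))` suffixed with `-<part count>`, which is why the completion code just below splits the reported value on `-`. A sketch of computing the same composite checksum locally (the file path is hypothetical), mirroring the `get_checksum_chunksize` doubling rule above:

```python
import base64
import hashlib
import math

CHECKSUM_MAX_PARTS = 10_000  # S3's part-count limit


def composite_sha256(path: str, size: int) -> str:
    # Same rule as get_checksum_chunksize(): start at 8 MiB and double
    # until the file fits into at most 10,000 parts.
    chunksize = 8 * 1024 * 1024
    while math.ceil(size / chunksize) > CHECKSUM_MAX_PARTS:
        chunksize *= 2

    part_hashes = []
    with open(path, "rb") as fd:
        while chunk := fd.read(chunksize):
            part_hashes.append(hashlib.sha256(chunk).digest())

    # Hash of the concatenated part hashes, base64-encoded: the value Quilt
    # stores as 'sha2-256-chunked' and compares against S3's composite checksum.
    return base64.b64encode(hashlib.sha256(b"".join(part_hashes)).digest()).decode()
```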
- ctx.done(PhysicalKey(dest_bucket, dest_key, version_id)) + checksum, _ = resp['ChecksumSHA256'].split('-', 1) + ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum) for i, start in enumerate(chunk_offsets): end = min(start + chunksize, size) ctx.run(upload_part, i, start, end) -def _download_file(ctx, size, src_bucket, src_key, src_version, dest_path): +def _download_file( + ctx: WorkerContext, + size: int, + src_bucket: str, + src_key: str, + src_version: Optional[str], + dest_path: str +): dest_file = pathlib.Path(dest_path) if dest_file.is_reserved(): raise ValueError("Cannot download to %r: reserved file name" % dest_path) @@ -346,6 +402,8 @@ def _download_file(ctx, size, src_bucket, src_key, src_version, dest_path): if src_version is not None: params.update(VersionId=src_version) + # Note: we are not calculating checksums when downloading, + # so we're free to use S3 defaults (or anything else) here. part_size = s3_transfer_config.multipart_chunksize is_multi_part = ( is_regular_file @@ -384,14 +442,14 @@ def download_part(part_number): remaining_counter -= 1 done = remaining_counter == 0 if done: - ctx.done(PhysicalKey.from_path(dest_path)) + ctx.done(PhysicalKey.from_path(dest_path), None) for part_number in part_numbers: ctx.run(download_part, part_number) -def _copy_remote_file(ctx, size, src_bucket, src_key, src_version, - dest_bucket, dest_key, extra_args=None): +def _copy_remote_file(ctx: WorkerContext, size: int, src_bucket: str, src_key: str, src_version: Optional[str], + dest_bucket: str, dest_key: str, extra_args: Optional[Iterable[Tuple[str, Any]]] = None): src_params = dict( Bucket=src_bucket, Key=src_key @@ -403,11 +461,12 @@ def _copy_remote_file(ctx, size, src_bucket, src_key, src_version, s3_client = ctx.s3_client_provider.standard_client - if size < s3_transfer_config.multipart_threshold: - params = dict( + if not is_mpu(size): + params: Dict[str, Any] = dict( CopySource=src_params, Bucket=dest_bucket, Key=dest_key, + ChecksumAlgorithm='SHA256', ) if extra_args: @@ -416,16 +475,17 @@ def _copy_remote_file(ctx, size, src_bucket, src_key, src_version, resp = s3_client.copy_object(**params) ctx.progress(size) version_id = resp.get('VersionId') # Absent in unversioned buckets. - ctx.done(PhysicalKey(dest_bucket, dest_key, version_id)) + checksum = _simple_s3_to_quilt_checksum(resp['CopyObjectResult']['ChecksumSHA256']) + ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum) else: resp = s3_client.create_multipart_upload( Bucket=dest_bucket, Key=dest_key, + ChecksumAlgorithm='SHA256', ) upload_id = resp['UploadId'] - adjuster = ChunksizeAdjuster() - chunksize = adjuster.adjust_chunksize(s3_transfer_config.multipart_chunksize, size) + chunksize = get_checksum_chunksize(size) chunk_offsets = list(range(0, size, chunksize)) @@ -442,10 +502,14 @@ def upload_part(i, start, end): Bucket=dest_bucket, Key=dest_key, UploadId=upload_id, - PartNumber=part_id + PartNumber=part_id, ) with lock: - parts[i] = {"PartNumber": part_id, "ETag": part["CopyPartResult"]["ETag"]} + parts[i] = dict( + PartNumber=part_id, + ETag=part['CopyPartResult']['ETag'], + ChecksumSHA256=part['CopyPartResult']['ChecksumSHA256'], + ) remaining -= 1 done = remaining == 0 @@ -456,55 +520,75 @@ def upload_part(i, start, end): Bucket=dest_bucket, Key=dest_key, UploadId=upload_id, - MultipartUpload={"Parts": parts} + MultipartUpload={'Parts': parts}, ) version_id = resp.get('VersionId') # Absent in unversioned buckets. 
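Single-part uploads and copies return a plain `base64(sha256(bytes))` instead, which `_simple_s3_to_quilt_checksum` above folds into the chunked scheme so that every entry carries one checksum format. A worked example of that conversion:

```python
import binascii
import hashlib

data = b"hello"

# What S3 reports for a single-part upload: base64(sha256(bytes)).
inner = hashlib.sha256(data).digest()
s3_checksum = binascii.b2a_base64(inner, newline=False).decode()

# The Quilt form treats the object as a one-chunk multipart object:
# base64(sha256(sha256(bytes))).
quilt_checksum = binascii.b2a_base64(hashlib.sha256(inner).digest(), newline=False).decode()

# Edge case from the docstring above: a 0-byte object keeps sha256(b"")
# as-is rather than being hashed a second time.
```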
- ctx.done(PhysicalKey(dest_bucket, dest_key, version_id)) + checksum, _ = resp['ChecksumSHA256'].split('-', 1) + ctx.done(PhysicalKey(dest_bucket, dest_key, version_id), checksum) for i, start in enumerate(chunk_offsets): end = min(start + chunksize, size) ctx.run(upload_part, i, start, end) -def _upload_or_copy_file(ctx, size, src_path, dest_bucket, dest_path): +def _calculate_local_checksum(path: str, size: int): + chunksize = get_checksum_chunksize(size) + + part_hashes = [] + for start in range(0, size, chunksize): + end = min(start + chunksize, size) + part_hashes.append(_calculate_local_part_checksum(path, start, end - start)) + + return _make_checksum_from_parts(part_hashes) + + +def _reuse_remote_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str): # Optimization: check if the remote file already exists and has the right ETag, # and skip the upload. - if size >= UPLOAD_ETAG_OPTIMIZATION_THRESHOLD: - try: - params = dict(Bucket=dest_bucket, Key=dest_path) - s3_client = ctx.s3_client_provider.find_correct_client(S3Api.HEAD_OBJECT, dest_bucket, params) - resp = s3_client.head_object(**params) - except ClientError: - # Destination doesn't exist, so fall through to the normal upload. - pass - except S3NoValidClientError: - # S3ClientProvider can't currently distinguish between a user that has PUT but not LIST permissions and a - # user that has no permissions. If we can't find a valid client, proceed to the upload stage anyway. - pass - else: - # Check the ETag. - dest_size = resp['ContentLength'] - dest_etag = resp['ETag'] - dest_version_id = resp.get('VersionId') - if size == dest_size and resp.get('ServerSideEncryption') != 'aws:kms': - src_etag = _calculate_etag(src_path) - if src_etag == dest_etag: - # Nothing more to do. We should not attempt to copy the object because - # that would cause the "copy object to itself" error. - ctx.progress(size) - ctx.done(PhysicalKey(dest_bucket, dest_path, dest_version_id)) - return # Optimization succeeded. + if size < UPLOAD_ETAG_OPTIMIZATION_THRESHOLD: + return None + try: + params = dict(Bucket=dest_bucket, Key=dest_path) + s3_client = ctx.s3_client_provider.find_correct_client(S3Api.HEAD_OBJECT, dest_bucket, params) + resp = s3_client.head_object(**params, ChecksumMode="ENABLED") + except ClientError: + # Destination doesn't exist, so fall through to the normal upload. + pass + except S3NoValidClientError: + # S3ClientProvider can't currently distinguish between a user that has PUT but not LIST permissions and a + # user that has no permissions. If we can't find a valid client, proceed to the upload stage anyway. + pass + else: + dest_size = resp["ContentLength"] + if dest_size != size: + return None + # TODO: we could check hashes of parts, to finish faster + s3_checksum = resp.get("ChecksumSHA256") + if s3_checksum is not None: + if "-" in s3_checksum: + checksum, num_parts_str = s3_checksum.split("-", 1) + num_parts = int(num_parts_str) + else: + checksum = _simple_s3_to_quilt_checksum(s3_checksum) + num_parts = None + expected_num_parts = math.ceil(size / get_checksum_chunksize(size)) if is_mpu(size) else None + if num_parts == expected_num_parts and checksum == _calculate_local_checksum(src_path, size): + return resp.get("VersionId"), checksum + elif resp.get("ServerSideEncryption") != "aws:kms" and resp["ETag"] == _calculate_etag(src_path): + return resp.get("VersionId"), _calculate_local_checksum(src_path, size) - # If the optimization didn't happen, do the normal upload. 
- _upload_file(ctx, size, src_path, dest_bucket, dest_path) + return None -class WorkerContext: - def __init__(self, s3_client_provider, progress, done, run): - self.s3_client_provider = s3_client_provider - self.progress = progress - self.done = done - self.run = run +def _upload_or_reuse_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str): + result = _reuse_remote_file(ctx, size, src_path, dest_bucket, dest_path) + if result is not None: + dest_version_id, checksum = result + ctx.progress(size) + ctx.done(PhysicalKey(dest_bucket, dest_path, dest_version_id), checksum) + return # Optimization succeeded. + # If the optimization didn't happen, do the normal upload. + _upload_file(ctx, size, src_path, dest_bucket, dest_path) def _copy_file_list_last_retry(retry_state): @@ -527,7 +611,7 @@ def _copy_file_list_internal(file_list, results, message, callback, exceptions_t if not file_list: return [] - logger.info('copy files: started') + logger.debug('copy files: started') assert len(file_list) == len(results) @@ -562,11 +646,11 @@ def worker(idx, src, dest, size): if stopped: raise Exception("Interrupted") - def done_callback(value): + def done_callback(value, checksum): assert value is not None with lock: assert results[idx] is None - results[idx] = value + results[idx] = (value, checksum) if callback is not None: callback(src, dest, size) @@ -584,7 +668,7 @@ def done_callback(value): else: if dest.version_id: raise ValueError("Cannot set VersionId on destination") - _upload_or_copy_file(ctx, size, src.path, dest.bucket, dest.path) + _upload_or_reuse_file(ctx, size, src.path, dest.bucket, dest.path) else: if dest.is_local(): _download_file(ctx, size, src.bucket, src.path, src.version_id, dest.path) @@ -622,7 +706,7 @@ def done_callback(value): # Make sure all tasks exit quickly if the main thread exits before they're done. stopped = True - logger.info('copy files: finished') + logger.debug('copy files: finished') return results @@ -637,12 +721,11 @@ def _calculate_etag(file_path): """ size = pathlib.Path(file_path).stat().st_size with open(file_path, 'rb') as fd: - if size < s3_transfer_config.multipart_threshold: + if not is_mpu(size): contents = fd.read() etag = hashlib.md5(contents).hexdigest() else: - adjuster = ChunksizeAdjuster() - chunksize = adjuster.adjust_chunksize(s3_transfer_config.multipart_chunksize, size) + chunksize = get_checksum_chunksize(size) hashes = [] for contents in read_file_chunks(fd, chunksize): @@ -889,15 +972,130 @@ def get_size_and_version(src: PhysicalKey): return size, version -def calculate_sha256(src_list: List[PhysicalKey], sizes: List[int]): +def calculate_checksum(src_list: List[PhysicalKey], sizes: List[int]) -> List[bytes]: assert len(src_list) == len(sizes) if not src_list: return [] - return _calculate_sha256_internal(src_list, sizes, [None] * len(src_list)) + return _calculate_checksum_internal(src_list, sizes, [None] * len(src_list)) + + +def with_lock(f): + lock = threading.Lock() + + @functools.wraps(f) + def wrapper(*args, **kwargs): + with lock: + return f(*args, **kwargs) + return wrapper + + +def _calculate_local_part_checksum(src: str, offset: int, length: int, callback=None) -> bytes: + hash_obj = hashlib.sha256() + bytes_remaining = length + with open(src, "rb") as fd: + fd.seek(offset) + while bytes_remaining > 0: + chunk = fd.read(min(s3_transfer_config.io_chunksize, bytes_remaining)) + if not chunk: + # Should not happen, but let's not get stuck in an infinite loop. 
+ raise QuiltException("Unexpected end of file") + hash_obj.update(chunk) + if callback is not None: + callback(len(chunk)) + bytes_remaining -= len(chunk) + + return hash_obj.digest() -def _calculate_hash_get_s3_chunks(ctx, src, size): +def _make_checksum_from_parts(parts: List[bytes]) -> str: + return binascii.b2a_base64(hashlib.sha256(b"".join(parts)).digest(), newline=False).decode() + + +@retry(stop=stop_after_attempt(MAX_FIX_HASH_RETRIES), + wait=wait_exponential(multiplier=1, min=1, max=10), + retry=retry_if_result(lambda results: any(r is None or isinstance(r, Exception) for r in results)), + retry_error_callback=lambda retry_state: retry_state.outcome.result(), + ) +def _calculate_checksum_internal(src_list, sizes, results) -> List[bytes]: + total_size = sum( + size + for size, result in zip(sizes, results) + if result is None or isinstance(result, Exception) + ) + stopped = False + + with tqdm(desc="Hashing", total=total_size, unit='B', unit_scale=True, disable=DISABLE_TQDM) as progress, \ + ThreadPoolExecutor(MAX_CONCURRENCY) as executor: + + find_correct_client = with_lock(S3ClientProvider().find_correct_client) + progress_update = with_lock(progress.update) + + def _process_url_part(src: PhysicalKey, offset: int, length: int): + if src.is_local(): + return _calculate_local_part_checksum(src.path, offset, length, progress_update) + else: + hash_obj = hashlib.sha256() + end = offset + length - 1 + params = dict( + Bucket=src.bucket, + Key=src.path, + Range=f'bytes={offset}-{end}', + ) + if src.version_id is not None: + params.update(VersionId=src.version_id) + + s3_client = find_correct_client(S3Api.GET_OBJECT, src.bucket, params) + + try: + body = s3_client.get_object(**params)['Body'] + for chunk in read_file_chunks(body): + hash_obj.update(chunk) + progress_update(len(chunk)) + if stopped: + return None + except (ConnectionError, HTTPClientError, ReadTimeoutError) as ex: + return ex + + return hash_obj.digest() + + futures: List[Tuple[int, List[Future]]] = [] + + for idx, (src, size, result) in enumerate(zip(src_list, sizes, results)): + if result is None or isinstance(result, Exception): + chunksize = get_checksum_chunksize(size) + + src_future_list = [] + for start in range(0, size, chunksize): + end = min(start + chunksize, size) + future = executor.submit(_process_url_part, src, start, end-start) + src_future_list.append(future) + + futures.append((idx, src_future_list)) + + try: + for idx, future_list in futures: + future_results = [future.result() for future in future_list] + exceptions = [ex for ex in future_results if isinstance(ex, Exception)] + results[idx] = exceptions[0] if exceptions else _make_checksum_from_parts(future_results) + finally: + stopped = True + for _, future_list in futures: + for future in future_list: + future.cancel() + + return results + + +def legacy_calculate_checksum(src_list: List[PhysicalKey], sizes: List[int]) -> List[bytes]: + assert len(src_list) == len(sizes) + + if not src_list: + return [] + return _legacy_calculate_checksum_internal(src_list, sizes, [None] * len(src_list)) + + +def _legacy_calculate_hash_get_s3_chunks(ctx, src, size): params = dict(Bucket=src.bucket, Key=src.path) if src.version_id is not None: params.update(VersionId=src.version_id) @@ -970,22 +1168,12 @@ def iter_queue(part_number): itertools.starmap(generators.popleft, itertools.repeat((), len(part_numbers)))) -def with_lock(f): - lock = threading.Lock() - - @functools.wraps(f) - def wrapper(*args, **kwargs): - with lock: - return f(*args, **kwargs) - return 
wrapper - - @retry(stop=stop_after_attempt(MAX_FIX_HASH_RETRIES), wait=wait_exponential(multiplier=1, min=1, max=10), retry=retry_if_result(lambda results: any(r is None or isinstance(r, Exception) for r in results)), retry_error_callback=lambda retry_state: retry_state.outcome.result(), ) -def _calculate_sha256_internal(src_list, sizes, results): +def _legacy_calculate_checksum_internal(src_list, sizes, results) -> List[bytes]: total_size = sum( size for size, result in zip(sizes, results) @@ -1018,7 +1206,7 @@ def _process_url(src, size): (get_file_chunks(src, size), ()) if src.is_local() else ( - _calculate_hash_get_s3_chunks(s3_context, src, size), + _legacy_calculate_hash_get_s3_chunks(s3_context, src, size), (ConnectionError, HTTPClientError, ReadTimeoutError) ) ) @@ -1066,6 +1254,23 @@ def _process_url(src, size): return results +def calculate_checksum_bytes(data: bytes) -> str: + size = len(data) + chunksize = get_checksum_chunksize(size) + + hashes = [] + for start in range(0, size, chunksize): + end = min(start + chunksize, size) + hashes.append(hashlib.sha256(data[start:end]).digest()) + + hashes_hash = hashlib.sha256(b''.join(hashes)).digest() + return binascii.b2a_base64(hashes_hash, newline=False).decode() + + +def legacy_calculate_checksum_bytes(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + def select(src, query, meta=None, raw=False, **kwargs): """Perform an S3 Select SQL query, return results as a Pandas DataFrame diff --git a/api/python/quilt3/formats.py b/api/python/quilt3/formats.py index 7ca158b193f..34580c7a591 100644 --- a/api/python/quilt3/formats.py +++ b/api/python/quilt3/formats.py @@ -77,13 +77,9 @@ import warnings from abc import ABC, abstractmethod from collections import defaultdict +from importlib import metadata as importlib_metadata from pathlib import Path -try: - from importlib import metadata as importlib_metadata -except ImportError: - import importlib_metadata - from .util import QuiltException # Constants diff --git a/api/python/quilt3/packages.py b/api/python/quilt3/packages.py index c87074b5cfb..09f1391b60c 100644 --- a/api/python/quilt3/packages.py +++ b/api/python/quilt3/packages.py @@ -12,6 +12,7 @@ import tempfile import textwrap import time +import typing as T import uuid import warnings from collections import deque @@ -24,12 +25,16 @@ from . import util, workflows from .backends import get_package_registry from .data_transfer import ( - calculate_sha256, + calculate_checksum, + calculate_checksum_bytes, copy_file, copy_file_list, get_bytes, get_size_and_version, + legacy_calculate_checksum, + legacy_calculate_checksum_bytes, list_object_versions, + list_objects, list_url, put_bytes, ) @@ -65,11 +70,25 @@ if MANIFEST_MAX_RECORD_SIZE is None: MANIFEST_MAX_RECORD_SIZE = DEFAULT_MANIFEST_MAX_RECORD_SIZE +SHA256_HASH_NAME = 'SHA256' +SHA256_CHUNKED_HASH_NAME = 'sha2-256-chunked' + SUPPORTED_HASH_TYPES = ( - "SHA256", + SHA256_HASH_NAME, + SHA256_CHUNKED_HASH_NAME, ) +class CopyFileListFn(T.Protocol): + def __call__( + self, + file_list: T.List[T.Tuple[PhysicalKey, PhysicalKey, int]], + message: T.Optional[str] = None, + callback: T.Optional[T.Callable] = None, + ) -> T.List[T.Tuple[PhysicalKey, T.Optional[str]]]: + ... 
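`CopyFileListFn` is the shape `Package._push` expects for an injected copier (see `copy_file_list_fn` further down); the test suite's `_mock_copy_file_list` follows the same contract. A minimal conforming stub, assuming only the types above:

```python
import typing as T

from quilt3.util import PhysicalKey


def noop_copy_file_list(
    file_list: T.List[T.Tuple[PhysicalKey, PhysicalKey, int]],
    message: T.Optional[str] = None,
    callback: T.Optional[T.Callable] = None,
) -> T.List[T.Tuple[PhysicalKey, T.Optional[str]]]:
    # Report every destination key as "copied" with no checksum; entries
    # without checksums are filled in later by _calculate_missing_hashes().
    return [(dest, None) for _src, dest, _size in file_list]
```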
+ + def _fix_docstring(**kwargs): def f(wrapped): if sys.flags.optimize < 2: @@ -209,9 +228,17 @@ def _verify_hash(self, read_bytes): """ if self.hash is None: raise QuiltException("Hash missing - need to build the package") - _check_hash_type_support(self.hash.get('type')) - digest = hashlib.sha256(read_bytes).hexdigest() - if digest != self.hash.get('value'): + hash_type = self.hash.get('type') + _check_hash_type_support(hash_type) + + if hash_type == SHA256_CHUNKED_HASH_NAME: + expected_value = calculate_checksum_bytes(read_bytes) + elif hash_type == SHA256_HASH_NAME: + expected_value = legacy_calculate_checksum_bytes(read_bytes) + else: + assert False + + if expected_value != self.hash.get('value'): raise QuiltException("Hash validation failed") def set(self, path=None, meta=None): @@ -521,7 +548,7 @@ def install(cls, name, registry=None, top_hash=None, dest=None, dest_registry=No if subpkg_key is not None: if subpkg_key not in pkg: - raise QuiltException(f"Package {name} doesn't contain {subpkg_key!r}.") + raise QuiltException(f"Package {name!r} doesn't contain {subpkg_key!r}.") entry = pkg[subpkg_key] entries = entry.walk() if isinstance(entry, Package) else ((subpkg_key.split('/')[-1], entry),) else: @@ -804,7 +831,7 @@ def _load(cls, readable_file): subpkg.set_meta(obj['meta']) continue if key in subpkg._children: - raise PackageException("Duplicate logical key while loading package") + raise PackageException(f"Duplicate logical key {key!r} while loading package entry: {obj!r}") subpkg._children[key] = PackageEntry( PhysicalKey.from_url(obj['physical_keys'][0]), obj['size'], @@ -815,7 +842,7 @@ def _load(cls, readable_file): gc.enable() return pkg - def set_dir(self, lkey, path=None, meta=None, update_policy="incoming"): + def set_dir(self, lkey, path=None, meta=None, update_policy="incoming", unversioned: bool = False): """ Adds all files from `path` to the package. @@ -832,6 +859,7 @@ def set_dir(self, lkey, path=None, meta=None, update_policy="incoming"): If 'incoming', whenever logical keys match, always take the new entry from set_dir. If 'existing', whenever logical keys match, retain existing entries and ignore new entries from set_dir. + unversioned(bool): when True, do not retrieve VersionId for S3 physical keys. Returns: self @@ -841,7 +869,7 @@ def set_dir(self, lkey, path=None, meta=None, update_policy="incoming"): ValueError: When `update_policy` is invalid. 
""" if update_policy not in PACKAGE_UPDATE_POLICY: - raise ValueError(f"Update policy should be one of {PACKAGE_UPDATE_POLICY}, not {update_policy!r}") + raise ValueError(f"Update policy should be one of {PACKAGE_UPDATE_POLICY!r}, not {update_policy!r}") lkey = lkey.strip("/") @@ -862,7 +890,7 @@ def set_dir(self, lkey, path=None, meta=None, update_policy="incoming"): if src.is_local(): src_path = pathlib.Path(src.path) if not src_path.is_dir(): - raise PackageException("The specified directory doesn't exist") + raise PackageException(f"The specified directory {src_path!r} doesn't exist") files = src_path.rglob('*') ignore = src_path / '.quiltignore' @@ -884,10 +912,13 @@ def set_dir(self, lkey, path=None, meta=None, update_policy="incoming"): src_path = src.path if src.basename() != '': src_path += '/' - objects, _ = list_object_versions(src.bucket, src_path) + if not unversioned: + objects, _ = list_object_versions(src.bucket, src_path) + objects = filter(lambda obj: obj["IsLatest"], objects) + else: + objects = list_objects(src.bucket, src_path, recursive=True) + for obj in objects: - if not obj['IsLatest']: - continue # Skip S3 pseduo directory files and Keys that end in / if obj['Key'].endswith('/'): if obj['Size'] != 0: @@ -920,7 +951,7 @@ def get(self, logical_key): """ obj = self[logical_key] if not isinstance(obj, PackageEntry): - raise ValueError("Key does not point to a PackageEntry") + raise ValueError(f"Key {logical_key!r} does not point to a PackageEntry") return obj.get() def readme(self): @@ -944,11 +975,11 @@ def set_meta(self, meta): self._meta['user_meta'] = meta return self - def _fix_sha256(self): + def _calculate_missing_hashes(self): """ Calculate and set missing hash values """ - logger.info('fix package hashes: started') + logger.debug('fix package hashes: started') self._incomplete_entries = [entry for key, entry in self.walk() if entry.hash is None] @@ -958,19 +989,19 @@ def _fix_sha256(self): physical_keys.append(entry.physical_key) sizes.append(entry.size) - results = calculate_sha256(physical_keys, sizes) + results = calculate_checksum(physical_keys, sizes) exc = None - for entry, obj_hash in zip(self._incomplete_entries, results): - if isinstance(obj_hash, Exception): - exc = obj_hash + for entry, result in zip(self._incomplete_entries, results): + if isinstance(result, Exception): + exc = result else: - entry.hash = dict(type='SHA256', value=obj_hash) + entry.hash = dict(type=SHA256_CHUNKED_HASH_NAME, value=result) if exc: incomplete_manifest_path = self._dump_manifest_to_scratch() msg = "Unable to reach S3 for some hash values. Incomplete manifest saved to {path}." 
raise PackageException(msg.format(path=incomplete_manifest_path)) from exc - logger.info('fix package hashes: finished') + logger.debug('fix package hashes: finished') def _set_commit_message(self, msg): """ @@ -1045,7 +1076,7 @@ def _build(self, name, registry, message): registry = get_package_registry(registry) self._set_commit_message(message) - self._fix_sha256() + self._calculate_missing_hashes() top_hash = self.top_hash self._push_manifest(name, registry, top_hash) @@ -1108,7 +1139,15 @@ def manifest(self): for logical_key, entry in self.walk(): yield {'logical_key': logical_key, **entry.as_dict()} - def set(self, logical_key, entry=None, meta=None, serialization_location=None, serialization_format_opts=None): + def set( + self, + logical_key, + entry=None, + meta=None, + serialization_location=None, + serialization_format_opts=None, + unversioned: bool = False, + ): """ Returns self with the object at logical_key set to entry. @@ -1128,6 +1167,7 @@ def set(self, logical_key, entry=None, meta=None, serialization_location=None, s https://github.com/quiltdata/quilt/blob/master/api/python/quilt3/formats.py serialization_location(string): Optional. If passed in, only used if entry is an object. Where the serialized object should be written, e.g. "./mydataframe.parquet" + unversioned(bool): when True, do not retrieve VersionId for S3 physical keys. Returns: self @@ -1136,13 +1176,21 @@ def set(self, logical_key, entry=None, meta=None, serialization_location=None, s entry=entry, meta=meta, serialization_location=serialization_location, - serialization_format_opts=serialization_format_opts) - - def _set(self, logical_key, entry=None, meta=None, serialization_location=None, serialization_format_opts=None): + serialization_format_opts=serialization_format_opts, + unversioned=unversioned) + + def _set( + self, + logical_key, + entry=None, + meta=None, + serialization_location=None, + serialization_format_opts=None, + unversioned: bool = False, + ): if not logical_key or logical_key.endswith('/'): raise QuiltException( - f"Invalid logical key {logical_key!r}. " - f"A package entry logical key cannot be a directory." + f"A package entry logical key {logical_key!r} must be a file." ) validate_key(logical_key) @@ -1155,7 +1203,7 @@ def _set(self, logical_key, entry=None, meta=None, serialization_location=None, size, version_id = get_size_and_version(src) # Determine if a new version needs to be appended. - if not src.is_local() and src.version_id is None and version_id is not None: + if not src.is_local() and src.version_id is None and version_id is not None and not unversioned: src.version_id = version_id entry = PackageEntry(src, size, None, None) elif isinstance(entry, PackageEntry): @@ -1189,7 +1237,7 @@ def _set(self, logical_key, entry=None, meta=None, serialization_location=None, if len(format_handlers) == 0: error_message = f'Quilt does not know how to serialize a {type(entry)}' if ext is not None: - error_message += f' as a {ext} file.' + error_message += f' as a {ext!r} file.' error_message += '. 
If you think this should be supported, please open an issue or PR at ' \ 'https://github.com/quiltdata/quilt' raise QuiltException(error_message) @@ -1222,7 +1270,7 @@ def _set(self, logical_key, entry=None, meta=None, serialization_location=None, pkg = self._ensure_subpackage(path[:-1], ensure_no_entry=True) if path[-1] in pkg and isinstance(pkg[path[-1]], Package): - raise QuiltException("Cannot overwrite directory with PackageEntry") + raise QuiltException(f"Cannot overwrite directory {path[-1]!r} with PackageEntry") pkg._children[path[-1]] = entry return self @@ -1243,7 +1291,10 @@ def _ensure_subpackage(self, path, ensure_no_entry=False): for key_fragment in path: if ensure_no_entry and key_fragment in pkg \ and isinstance(pkg[key_fragment], PackageEntry): - raise QuiltException("Already a PackageEntry along the path.") + raise QuiltException( + f"Already a PackageEntry for {key_fragment!r} " + f"along the path {path!r}: {pkg[key_fragment].physical_key!r}", + ) pkg = pkg._children.setdefault(key_fragment, Package()) return pkg @@ -1293,7 +1344,7 @@ def _get_top_hash_parts(cls, meta, entries): for logical_key, entry in entries: if entry.hash is None or entry.size is None: raise QuiltException( - "PackageEntry missing hash and/or size: %s" % entry.physical_key + "PackageEntry missing hash and/or size: %r" % entry.physical_key ) yield { 'hash': entry.hash, @@ -1336,8 +1387,8 @@ def push( Args: name: name for package in registry dest: where to copy the objects in the package. Must be either an S3 URI prefix (e.g., s3://$bucket/$key) - in the registry bucket, or a callable that takes logical_key, package_entry, and top_hash - and returns an S3 URI. + in the registry bucket, or a callable that takes logical_key and package_entry, and returns an S3 URI. + (Changed in 6.0.0a1) previously top_hash was passed to the callable dest as a third argument. registry: registry where to create the new package message: the commit message for the new package selector_fn: An optional function that determines which package entries should be copied to S3. @@ -1360,12 +1411,16 @@ def push( def _push( self, name, registry=None, dest=None, message=None, selector_fn=None, *, - workflow, print_info, force: bool, dedupe: bool + workflow, print_info, force: bool, dedupe: bool, + copy_file_list_fn: T.Optional[CopyFileListFn] = None, ): if selector_fn is None: def selector_fn(*args): return True + if copy_file_list_fn is None: + copy_file_list_fn = copy_file_list + validate_package_name(name) if registry is None: @@ -1400,7 +1455,7 @@ def dest_fn(*args, **kwargs): raise TypeError(f'{dest!r} returned {url!r}, but str is expected') pk = PhysicalKey.from_url(url) if pk.is_local(): - raise util.URLParseError("Unexpected scheme: 'file'") + raise util.URLParseError(f"Unexpected scheme: 'file' for {pk!r}") if pk.version_id: raise ValueError(f'{dest!r} returned {url!r}, but URI must not include versionId') return pk @@ -1435,8 +1490,8 @@ def check_hash_conficts(latest_hash): if self._origin is None or latest_hash != self._origin.top_hash: raise QuiltConflictException( - f"Package with hash {latest_hash} already exists at the destination; " - f"expected {None if self._origin is None else self._origin.top_hash}. " + f"Package with hash {latest_hash!r} already exists at the destination; " + f"expected {None if self._origin is None else self._origin.top_hash!r}. " "Use force=True (Python) or --force (CLI) to overwrite." 
) @@ -1451,21 +1506,9 @@ def check_hash_conficts(latest_hash): if not force: check_hash_conficts(latest_hash) - self._fix_sha256() - pkg = self.__class__() pkg._meta = self._meta pkg._set_commit_message(message) - top_hash = self._calculate_top_hash(pkg._meta, self.walk()) - pkg._origin = PackageRevInfo(str(registry.base), name, top_hash) - - if dedupe and top_hash == latest_hash: - if print_info: - print( - f"Skipping since package with hash {latest_hash} already exists " - "at the destination and dedupe parameter is true." - ) - return self # Since all that is modified is physical keys, pkg will have the same top hash file_list = [] @@ -1479,7 +1522,7 @@ def check_hash_conficts(latest_hash): # Copy the datafiles in the package. physical_key = entry.physical_key - new_physical_key = dest_fn(logical_key, entry, top_hash) + new_physical_key = dest_fn(logical_key, entry) if ( physical_key.bucket == new_physical_key.bucket and physical_key.path == new_physical_key.path @@ -1490,14 +1533,31 @@ def check_hash_conficts(latest_hash): entries.append((logical_key, entry)) file_list.append((physical_key, new_physical_key, entry.size)) - results = copy_file_list(file_list, message="Copying objects") + results = copy_file_list_fn(file_list, message="Copying objects") - for (logical_key, entry), versioned_key in zip(entries, results): + for (logical_key, entry), (versioned_key, checksum) in zip(entries, results): # Create a new package entry pointing to the new remote key. assert versioned_key is not None new_entry = entry.with_physical_key(versioned_key) + if checksum is not None: + new_entry.hash = dict(type=SHA256_CHUNKED_HASH_NAME, value=checksum) pkg._set(logical_key, new_entry) + # Some entries may miss hash values (e.g because of selector_fn), so we need + # to fix them before calculating the top hash. + pkg._calculate_missing_hashes() + top_hash = pkg._calculate_top_hash(pkg._meta, pkg.walk()) + + if dedupe and top_hash == latest_hash: + if print_info: + print( + f"Skipping since package with hash {latest_hash} already exists " + "at the destination and dedupe parameter is true." 
+ ) + return self + + pkg._origin = PackageRevInfo(str(registry.base), name, top_hash) + def physical_key_is_temp_file(pk): if not pk.is_local(): return False @@ -1668,26 +1728,48 @@ def verify(self, src, extra_files_ok=False): src = PhysicalKey.from_url(fix_url(src)) src_dict = dict(list_url(src)) + + expected_hash_list = [] url_list = [] size_list = [] + + legacy_expected_hash_list = [] + legacy_url_list = [] + legacy_size_list = [] + for logical_key, entry in self.walk(): src_size = src_dict.pop(logical_key, None) - if src_size is None: - return False - if entry.size != src_size: + if src_size is None or entry.size != src_size: return False entry_url = src.join(logical_key) - url_list.append(entry_url) - size_list.append(src_size) + hash_type = entry.hash['type'] + hash_value = entry.hash['value'] + if hash_type == SHA256_CHUNKED_HASH_NAME: + expected_hash_list.append(hash_value) + url_list.append(entry_url) + size_list.append(src_size) + elif hash_type == SHA256_HASH_NAME: + legacy_expected_hash_list.append(hash_value) + legacy_url_list.append(entry_url) + legacy_size_list.append(src_size) + else: + assert False, hash_type if src_dict and not extra_files_ok: return False - hash_list = calculate_sha256(url_list, size_list) - for (logical_key, entry), url_hash in zip(self.walk(), hash_list): + hash_list = calculate_checksum(url_list, size_list) + for expected_hash, url_hash in zip(expected_hash_list, hash_list): + if isinstance(url_hash, Exception): + raise url_hash + if expected_hash != url_hash: + return False + + legacy_hash_list = legacy_calculate_checksum(legacy_url_list, legacy_size_list) + for expected_hash, url_hash in zip(legacy_expected_hash_list, legacy_hash_list): if isinstance(url_hash, Exception): raise url_hash - if entry.hash['value'] != url_hash: + if expected_hash != url_hash: return False return True diff --git a/api/python/quilt3/search_util.py b/api/python/quilt3/search_util.py index 78c7b0ebf63..9a59f6c7268 100644 --- a/api/python/quilt3/search_util.py +++ b/api/python/quilt3/search_util.py @@ -3,53 +3,22 @@ Contains search-related glue code """ -import re -from urllib.parse import quote, urlencode, urlparse +import json +import typing as T -import requests -from aws_requests_auth.aws_auth import AWSRequestsAuth +from . import session -from .session import create_botocore_session -from .util import QuiltException, get_from_config - -def search_credentials(host, region, service): - credentials = create_botocore_session().get_credentials() - if credentials: - # use registry-provided credentials if present, otherwise - # standard boto credentials - creds = credentials.get_frozen_credentials() - auth = AWSRequestsAuth(aws_access_key=creds.access_key, - aws_secret_access_key=creds.secret_key, - aws_host=host, - aws_region=region, - aws_service=service, - aws_token=creds.token, - ) - else: - auth = None - - return auth - - -def search_api(query, index, limit=10): +def search_api(query: T.Union[str, dict], index: str, limit: int = 10): """ - Sends a query to the search API (supports simple search - queries only) + Send a query to the search API """ - api_gateway = get_from_config('apiGatewayEndpoint') - api_gateway_host = urlparse(api_gateway).hostname - match = re.match(r".*\.([a-z]{2}-[a-z]+-\d)\.amazonaws\.com$", api_gateway_host) - region = match.groups()[0] - auth = search_credentials(api_gateway_host, region, 'execute-api') - # Encode the parameters manually because AWS Auth requires spaces to be encoded as '%20' rather than '+'. 
-    encoded_params = urlencode(dict(index=index, action='search', query=query), quote_via=quote)
-    response = requests.get(
-        f"{api_gateway}/search?{encoded_params}",
-        auth=auth
+    if isinstance(query, dict):
+        params = dict(index=index, action="freeform", body=json.dumps(query), size=limit)
+    else:
+        params = dict(index=index, action="search", query=query, size=limit)
+    response = session.get_session().get(
+        f"{session.get_registry_url()}/api/search",
+        params=params,
     )
-
-    if not response.ok:
-        raise QuiltException(response.text)
-
     return response.json()
diff --git a/api/python/quilt3/session.py b/api/python/quilt3/session.py
index 67a8e794fdc..7865808b354 100644
--- a/api/python/quilt3/session.py
+++ b/api/python/quilt3/session.py
@@ -9,12 +9,10 @@
 import subprocess
 import sys
 import time
+import typing as T
+from importlib import metadata
 
-try:
-    from importlib import metadata
-except ImportError:
-    import importlib_metadata as metadata
-
+import boto3
 import botocore.session
 import requests
 from botocore.credentials import (
@@ -31,10 +29,11 @@
 
 
 def _load_auth():
-    if AUTH_PATH.exists():
+    try:
         with open(AUTH_PATH, encoding='utf-8') as fd:
             return json.load(fd)
-    return {}
+    except FileNotFoundError:
+        return {}
 
 
 def _save_auth(cfg):
@@ -45,10 +44,11 @@
 
 
 def _load_credentials():
-    if CREDENTIALS_PATH.exists():
+    try:
         with open(CREDENTIALS_PATH, encoding='utf-8') as fd:
             return json.load(fd)
-    return {}
+    except FileNotFoundError:
+        return {}
 
 
 def _save_credentials(creds):
@@ -290,14 +290,34 @@ def load(self):
         return creds
 
 
-def create_botocore_session():
+def create_botocore_session(*, credentials: T.Optional[dict] = None) -> botocore.session.Session:
     botocore_session = botocore.session.get_session()
 
     # If we have saved credentials, use them. Otherwise, create a normal Boto session.
-    credentials = _load_credentials()
+    if credentials is None:
+        credentials = _load_credentials()
     if credentials:
         provider = QuiltProvider(credentials)
         resolver = CredentialResolver([provider])
         botocore_session.register_component('credential_provider', resolver)
 
     return botocore_session
+
+
+def get_boto3_session(*, fallback: bool = True) -> boto3.Session:
+    """
+    Return a Boto3 session with Quilt stack credentials and AWS region.
+    If no Quilt credentials are found, return a "normal" Boto3 session if `fallback` is `True`,
+    otherwise raise a `QuiltException`.
+
+    > Note: if you previously called `quilt3.config("https://your-catalog-homepage/")` with quilt3 < 6.1.0,
+    call it again so that the region is saved to the config and set on the session.
+ """ + if not (credentials := _load_credentials()): + if fallback: + return boto3.Session() + raise QuiltException("No Quilt credentials found.") + return boto3.Session( + botocore_session=create_botocore_session(credentials=credentials), + region_name=get_from_config("region"), + ) diff --git a/api/python/quilt3/util.py b/api/python/quilt3/util.py index 5c922d030d6..973c658320f 100644 --- a/api/python/quilt3/util.py +++ b/api/python/quilt3/util.py @@ -13,7 +13,7 @@ urlparse, urlunparse, ) -from urllib.request import pathname2url, url2pathname +from urllib.request import url2pathname import requests # Third-Party @@ -81,6 +81,9 @@ def get_bool_from_env(var_name: str): default_registry_version: 1 +# AWS region +region: + """.format(BASE_PATH.as_uri() + '/packages') @@ -136,8 +139,6 @@ def __init__(self, bucket, path, version_id): assert version_id is None, "Local keys cannot have a version ID" if os.name == 'nt': assert '\\' not in path, "Paths must use / as a separator" - else: - assert not path.startswith('/'), "S3 paths must not start with '/'" self.bucket = bucket self.path = path @@ -222,13 +223,13 @@ def __repr__(self): def __str__(self): if self.bucket is None: - return urlunparse(('file', '', pathname2url(self.path.replace('/', os.path.sep)), None, None, None)) + return pathlib.PurePath(self.path).as_uri() else: if self.version_id is None: params = {} else: params = {'versionId': self.version_id} - return urlunparse(('s3', self.bucket, quote(self.path), None, urlencode(params), None)) + return urlunparse(('s3', self.bucket, quote('/' + self.path), None, urlencode(params), None)) def fix_url(url): @@ -292,26 +293,19 @@ def write_yaml(data, yaml_path, keep_backup=False): :param keep_backup: If set, a timestamped backup will be kept in the same dir. """ path = pathlib.Path(yaml_path) - now = str(datetime.datetime.now()) - - # XXX unicode colon for Windows/NTFS -- looks prettier, but could be confusing. We could use '_' instead. - if os.name == 'nt': - now = now.replace(':', '\ua789') - + now = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%S.%fZ") # ISO 8601 'basic format' backup_path = path.with_name(path.name + '.backup.' + now) try: if path.exists(): - path.rename(backup_path) - if not path.parent.exists(): - path.parent.mkdir(parents=True) + # TODO: use something from tempfile to make sure backup_path doesn't exist. + path.replace(backup_path) + path.parent.mkdir(parents=True, exist_ok=True) with path.open('w') as config_file: yaml.dump(data, config_file) except Exception: # intentionally wide catch -- reraised immediately. if backup_path.exists(): - if path.exists(): - path.unlink() - backup_path.rename(path) + backup_path.replace(path) raise if backup_path.exists() and not keep_backup: @@ -438,8 +432,10 @@ def load_config(): Read the local config using defaults from CONFIG_TEMPLATE. 
""" local_config = read_yaml(CONFIG_TEMPLATE) - if CONFIG_PATH.exists(): + try: local_config.update(read_yaml(CONFIG_PATH)) + except FileNotFoundError: + pass return local_config diff --git a/api/python/setup.py b/api/python/setup.py index 797f1b95cdf..bc688d8b117 100644 --- a/api/python/setup.py +++ b/api/python/setup.py @@ -40,16 +40,16 @@ def run(self): packages=find_packages(exclude=("tests", "tests.*")), description='Quilt: where data comes together', long_description=readme(), - python_requires='>=3.7', + python_requires='>=3.9', classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', ], author='quiltdata', author_email='contact@quiltdata.io', @@ -58,16 +58,15 @@ def run(self): keywords='', install_requires=[ 'platformdirs>=2', - 'aws-requests-auth>=0.4.2', - 'boto3>=1.10.0', + 'boto3>=1.21.7', 'jsonlines==1.2.0', 'PyYAML>=5.1', 'requests>=2.12.4', - 'tenacity>=5.1.1', + 'tenacity>=5.1.1,!=8.4.0', 'tqdm>=4.32', 'requests_futures==1.0.0', 'jsonschema>=3,<5', - 'importlib_metadata; python_version < "3.8"', + "pydantic>=2.0.0,<3.0.0", ], extras_require={ 'pyarrow': [ diff --git a/api/python/tests/integration/test_packages.py b/api/python/tests/integration/test_packages.py index 0fb9ab0d5cb..1a3e34664ad 100644 --- a/api/python/tests/integration/test_packages.py +++ b/api/python/tests/integration/test_packages.py @@ -25,10 +25,13 @@ LocalPackageRegistryV2, ) from quilt3.backends.s3 import S3PackageRegistryV1, S3PackageRegistryV2 +from quilt3.exceptions import PackageException +from quilt3.packages import PackageEntry from quilt3.util import ( PhysicalKey, QuiltConflictException, QuiltException, + URLParseError, validate_package_name, ) @@ -44,7 +47,7 @@ def _mock_copy_file_list(file_list, callback=None, message=None): - return [key for _, key, _ in file_list] + return [(key, None) for _, key, _ in file_list] class PackageTest(QuiltTestCase): @@ -221,11 +224,13 @@ def setup_s3_stubber_upload_pkg_data(self, pkg_registry, pkg_name, *, lkey, data method='put_object', service_response={ 'VersionId': version, + 'ChecksumSHA256': 'ASNFZ4mrze8BI0VniavN7w==', }, expected_params={ 'Body': ANY, # TODO: use data here. 
'Bucket': pkg_registry.root.bucket, 'Key': f'{pkg_name}/{lkey}', + 'ChecksumAlgorithm': 'SHA256', } ) @@ -474,7 +479,7 @@ def add_pkg_file(pkg, lk, filename, data, *, version): self.setup_s3_stubber_push_manifest( pkg_registry, pkg_name, - '7fd8e7f49a344aadf4154a2210fe6b08297ecb23218d95027963dc0410548440', + 'b8cc6e8caa93d1250afe3a4ae1d47bb4f03a900076d9d12bcb6797df57b273d0', pointer_name=str(timestamp1), ) with patch('time.time', return_value=timestamp1), \ @@ -488,7 +493,7 @@ def add_pkg_file(pkg, lk, filename, data, *, version): self.setup_s3_stubber_push_manifest( pkg_registry, pkg_name, - 'd4efbb1734a53726d97086824d153e6cb5e9d8bc31d15ead0dbc019022cfe539', + 'b8cc6e8caa93d1250afe3a4ae1d47bb4f03a900076d9d12bcb6797df57b273d0', pointer_name=str(timestamp2), ) with patch('time.time', return_value=timestamp2), \ @@ -621,6 +626,30 @@ def test_s3_set_dir(self): list_object_versions_mock.assert_called_with('bucket', 'foo/') + @patch("quilt3.packages.list_object_versions") + def test_set_dir_root_folder_named_slash(self, list_object_versions_mock): + list_object_versions_mock.return_value = ( + [dict(Key="/foo/a.txt", VersionId="xyz", IsLatest=True, Size=10)], + [], + ) + pkg = Package() + pkg.set_dir("bar", "s3://bucket//foo") # top-level '/' folder + + assert pkg["bar"]["a.txt"].get() == "s3://bucket//foo/a.txt?versionId=xyz" + assert pkg["bar"]["a.txt"].size == 10 + + list_object_versions_mock.assert_called_once_with("bucket", "/foo/") + + @patch("quilt3.packages.get_size_and_version", return_value=(123, "v1")) + def test_set_file_root_folder_named_slash(self, get_size_and_version_mock): + pkg = Package() + pkg.set("bar.txt", "s3://bucket//foo/a.txt") + + assert pkg["bar.txt"].get() == "s3://bucket//foo/a.txt?versionId=v1" + assert pkg["bar.txt"].size == 123 + + get_size_and_version_mock.assert_called_once_with(PhysicalKey("bucket", "/foo/a.txt", "v1")) + def test_set_dir_wrong_update_policy(self): """Verify non existing update policy raises value error.""" pkg = Package() @@ -629,6 +658,25 @@ def test_set_dir_wrong_update_policy(self): pkg.set_dir("nested", DATA_DIR, update_policy='invalid_policy') assert expected_err in str(e.value) + @mock.patch("quilt3.packages.list_objects") + @mock.patch("quilt3.packages.list_object_versions") + def test_set_dir_unversioned(self, list_object_versions_mock, list_objects_mock): + list_objects_mock.return_value = [ + { + "Key": "foo/bar.txt", + "Size": 123, + }, + ] + + pkg = Package().set_dir(".", "s3://bucket/foo", unversioned=True) + + list_object_versions_mock.assert_not_called() + list_objects_mock.assert_called_once_with("bucket", "foo/", recursive=True) + assert [ + (lk, e.get()) + for lk, e in pkg.walk() + ] == [("bar.txt", "s3://bucket/foo/bar.txt")] + def test_package_entry_meta(self): pkg = ( Package() @@ -726,7 +774,7 @@ def test_set_package_entry_as_object(self): file_path = entry.physical_key.path assert pathlib.Path(file_path).exists(), "The serialization files should exist" - pkg._fix_sha256() + pkg._calculate_missing_hashes() for lk, entry in pkg.walk(): assert df.equals(entry.deserialize()), "The deserialized PackageEntry should be equal to the object " \ "that was serialized" @@ -744,6 +792,17 @@ def test_set_package_entry_as_object(self): file_path = pkg[lk].physical_key.path assert not pathlib.Path(file_path).exists(), "These temp files should have been deleted during push()" + @patch("quilt3.packages.get_size_and_version", mock.Mock(return_value=(123, "v1"))) + def test_set_package_entry_unversioned_flag(self): + for flag_value, 
version_id in { + True: None, + False: "v1", + }.items(): + with self.subTest(flag_value=flag_value, version_id=version_id): + pkg = Package() + pkg.set("bar", "s3://bucket/bar", unversioned=flag_value) + assert pkg["bar"].physical_key == PhysicalKey("bucket", "bar", version_id) + def test_tophash_changes(self): test_file = Path('test.txt') test_file.write_text('asdf', 'utf-8') @@ -799,7 +858,9 @@ def test_iter(self): def test_invalid_set_key(self): """Verify an exception when setting a key with a path object.""" pkg = Package() - with pytest.raises(TypeError): + with pytest.raises(TypeError, + match="Expected a string for entry, but got an instance of " + r"\."): pkg.set('asdf/jkl', Package()) def test_brackets(self): @@ -1032,7 +1093,11 @@ def _test_remote_revision_delete_setup_stubber(self, pkg_registry, pkg_name, *, ) self.s3_stubber.add_response( method='copy_object', - service_response={}, + service_response={ + 'CopyObjectResult': { + 'ChecksumSHA256': 'ASNFZ4mrze8BI0VniavN7w==', + }, + }, expected_params={ 'CopySource': { 'Bucket': pkg_registry.root.bucket, @@ -1040,6 +1105,7 @@ def _test_remote_revision_delete_setup_stubber(self, pkg_registry, pkg_name, *, }, 'Bucket': pkg_registry.root.bucket, 'Key': pkg_registry.pointer_latest_pk(pkg_name).path, + 'ChecksumAlgorithm': 'SHA256', } ) @@ -1203,7 +1269,8 @@ def test_commit_message_on_push(self, mocked_workflow_validate): ) def test_overwrite_dir_fails(self): - with pytest.raises(QuiltException): + with pytest.raises(QuiltException, + match="Cannot overwrite directory 'asdf' with PackageEntry"): pkg = Package() pkg.set('asdf/jkl', LOCAL_MANIFEST) pkg.set('asdf', LOCAL_MANIFEST) @@ -1323,7 +1390,7 @@ def test_filter(self): pkg['b'].set_meta({'foo': 'bar'}) p_copy = pkg.filter(lambda lk, entry: lk == 'a/', include_directories=True) - assert list(p_copy) == [] + assert not list(p_copy) p_copy = pkg.filter(lambda lk, entry: lk in ('a/', 'a/df'), include_directories=True) assert list(p_copy) == ['a'] and list(p_copy['a']) == ['df'] @@ -1625,9 +1692,20 @@ def test_verify(self): Path('test/blah').unlink() assert pkg.verify('test') - def test_verify_poo_hash_type(self): - expected_err_msg = "Unsupported hash type: '💩'. Supported types: SHA256. Try to update quilt3." + # Legacy hash + pkg['foo'].hash = dict( + type='SHA256', + value='12345', + ) + assert not pkg.verify('test') + + pkg['foo'].hash = dict( + type='SHA256', + value='dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f', + ) + assert pkg.verify('test') + def test_verify_poo_hash_type(self): self.patch_local_registry('shorten_top_hash', return_value='7a67ff4') pkg = Package() @@ -1637,9 +1715,8 @@ def test_verify_poo_hash_type(self): pkg['foo'].hash['type'] = '💩' def _test_verify_fails(*args, **kwargs): - with pytest.raises(QuiltException) as excinfo: + with pytest.raises(QuiltException, match="Unsupported hash type: '💩'") as excinfo: pkg.verify(*args, **kwargs) - assert str(excinfo.value) == expected_err_msg Package.install('quilt/test', LOCAL_REGISTRY, dest='test') _test_verify_fails('test') @@ -1657,33 +1734,33 @@ def _test_verify_fails(*args, **kwargs): _test_verify_fails('test') _test_verify_fails('test', extra_files_ok=True) - @patch('quilt3.packages.calculate_sha256') - def test_fix_sha256_fail(self, mocked_calculate_sha256): + @patch('quilt3.packages.calculate_checksum') + def test_calculate_missing_hashes_fail(self, mocked_calculate_checksum): data = b'Hello, World!' 
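The verify tests above exercise both manifest hash types. For a payload under the 8 MiB multipart threshold the two encodings differ as follows (a sketch mirroring `legacy_calculate_checksum_bytes` and `calculate_checksum_bytes`):

```python
import base64
import hashlib

payload = b"Hello, World!"

# Legacy type 'SHA256': hex digest of the raw bytes. For this payload it is
# the 'dffd6021...' value asserted in test_verify above.
legacy = hashlib.sha256(payload).hexdigest()

# Type 'sha2-256-chunked': the payload fits in a single 8 MiB chunk, so the
# value is base64(sha256(sha256(payload))).
chunked = base64.b64encode(
    hashlib.sha256(hashlib.sha256(payload).digest()).digest()
).decode()
```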
pkg = Package() pkg.set('foo', data) _, entry = next(pkg.walk()) exc = Exception('test exception') - mocked_calculate_sha256.return_value = [exc] + mocked_calculate_checksum.return_value = [exc] with pytest.raises(quilt3.exceptions.PackageException) as excinfo: - pkg._fix_sha256() - mocked_calculate_sha256.assert_called_once_with([entry.physical_key], [len(data)]) + pkg._calculate_missing_hashes() + mocked_calculate_checksum.assert_called_once_with([entry.physical_key], [len(data)]) assert entry.hash is None assert excinfo.value.__cause__ == exc - @patch('quilt3.packages.calculate_sha256') - def test_fix_sha256(self, mocked_calculate_sha256): + @patch('quilt3.packages.calculate_checksum') + def test_calculate_missing_hashes(self, mocked_calculate_checksum): data = b'Hello, World!' pkg = Package() pkg.set('foo', data) _, entry = next(pkg.walk()) hash_ = object() - mocked_calculate_sha256.return_value = [hash_] - pkg._fix_sha256() - mocked_calculate_sha256.assert_called_once_with([entry.physical_key], [len(data)]) - assert entry.hash == {'type': 'SHA256', 'value': hash_} + mocked_calculate_checksum.return_value = [(hash_)] + pkg._calculate_missing_hashes() + mocked_calculate_checksum.assert_called_once_with([entry.physical_key], [len(data)]) + assert entry.hash == {'type': 'sha2-256-chunked', 'value': hash_} def test_resolve_hash_invalid_pkg_name(self): with pytest.raises(QuiltException, match='Invalid package name'): @@ -1716,9 +1793,9 @@ def test_resolve_hash(self): with pytest.raises(QuiltException, match='Found multiple matches'): Package.resolve_hash(pkg_name, LOCAL_REGISTRY, hash_prefix) - @patch('quilt3.Package._fix_sha256', wraps=quilt3.Package._fix_sha256) + @patch('quilt3.Package._calculate_missing_hashes', wraps=quilt3.Package._calculate_missing_hashes) @patch('quilt3.Package._build', wraps=quilt3.Package._build) - def test_workflow_validation_error(self, build_mock, fix_hashes): + def test_workflow_validation_error(self, build_mock, calculate_missing_hashes): self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4') pkg = Package().set('foo', DATA_DIR / 'foo.txt') @@ -1733,7 +1810,7 @@ def test_workflow_validation_error(self, build_mock, fix_hashes): assert excinfo.value is workflow_validate_mock.side_effect workflow_validate_mock.assert_called_once() assert not build_mock.mock_calls - assert not fix_hashes.mock_calls + assert not calculate_missing_hashes.mock_calls assert pkg._workflow is None @patch('quilt3.packages.copy_file_list') @@ -1795,7 +1872,8 @@ def test_push_dest_fn_non_string(self): with self.subTest(value=val): with pytest.raises(TypeError) as excinfo: pkg.push('foo/bar', registry='s3://test-bucket', - dest=(lambda v: lambda *args, **kwargs: v)(val), force=True) + # pylint: disable=cell-var-from-loop + dest=lambda *args, **kwargs: val, force=True) assert 'str is expected' in str(excinfo.value) @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None)) @@ -1805,7 +1883,8 @@ def test_push_dest_fn_non_supported_uri(self): with self.subTest(value=val): with pytest.raises(quilt3.util.URLParseError): pkg.push('foo/bar', registry='s3://test-bucket', - dest=(lambda v: lambda *args, **kwargs: v)(val), force=True) + # pylint: disable=cell-var-from-loop + dest=lambda *args, **kwargs: val, force=True) @patch('quilt3.workflows.validate', mock.MagicMock(return_value=None)) def test_push_dest_fn_s3_uri_with_version_id(self): @@ -1832,18 +1911,20 @@ def test_push_dest_fn(self): method='put_object', service_response={ 'VersionId': '1', + 'ChecksumSHA256': 
'ASNFZ4mrze8BI0VniavN7w==', }, expected_params={ 'Body': ANY, 'Bucket': dest_bucket, 'Key': dest_key, + 'ChecksumAlgorithm': 'SHA256', } ) push_manifest_mock = self.patch_s3_registry('push_manifest') self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4') pkg.push(pkg_name, registry='s3://test-bucket', dest=dest_fn, force=True) - dest_fn.assert_called_once_with(lk, pkg[lk], mock.sentinel.top_hash) + dest_fn.assert_called_once_with(lk, pkg[lk]) push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY) assert Package.load( BytesIO(push_manifest_mock.call_args[0][2]) @@ -1865,10 +1946,11 @@ def test_push_selector_fn_false(self): selector_fn = mock.MagicMock(return_value=False) push_manifest_mock = self.patch_s3_registry('push_manifest') self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4') - with patch('quilt3.packages.calculate_sha256', return_value=["a" * 64]): + with patch('quilt3.packages.calculate_checksum', return_value=["a" * 64]) as calculate_checksum_mock: pkg.push(pkg_name, registry=f's3://{dst_bucket}', selector_fn=selector_fn, force=True) selector_fn.assert_called_once_with(lk, pkg[lk]) + calculate_checksum_mock.assert_called_once_with([PhysicalKey(src_bucket, src_key, src_version)], [0]) push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY) assert Package.load( BytesIO(push_manifest_mock.call_args[0][2]) @@ -1894,6 +1976,9 @@ def test_push_selector_fn_true(self): method='copy_object', service_response={ 'VersionId': dst_version, + 'CopyObjectResult': { + 'ChecksumSHA256': 'ASNFZ4mrze8BI0VniavN7w==', + }, }, expected_params={ 'Bucket': dst_bucket, @@ -1903,14 +1988,16 @@ def test_push_selector_fn_true(self): 'Key': src_key, 'VersionId': src_version, }, + 'ChecksumAlgorithm': 'SHA256', } ) push_manifest_mock = self.patch_s3_registry('push_manifest') self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4') - with patch('quilt3.packages.calculate_sha256', return_value=["a" * 64]): + with patch('quilt3.packages.calculate_checksum', return_value=[]) as calculate_checksum_mock: pkg.push(pkg_name, registry=f's3://{dst_bucket}', selector_fn=selector_fn, force=True) selector_fn.assert_called_once_with(lk, pkg[lk]) + calculate_checksum_mock.assert_called_once_with([], []) push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY) assert Package.load( BytesIO(push_manifest_mock.call_args[0][2]) @@ -1936,7 +2023,7 @@ def test_max_manifest_record_size(self): with mock.patch('quilt3.packages.MANIFEST_MAX_RECORD_SIZE', 1): with pytest.raises(QuiltException) as excinfo: Package().dump(buf) - assert 'Size of manifest record for package metadata' in str(excinfo.value) + assert "Size of manifest record for package metadata" in str(excinfo.value) with mock.patch('quilt3.packages.MANIFEST_MAX_RECORD_SIZE', 10_000): with pytest.raises(QuiltException) as excinfo: @@ -2165,3 +2252,78 @@ def test_set_dir_update_policy_s3(update_policy, expected_a_url, expected_xy_url assert pkg['z.txt'].get() == 's3://bucket/bar/z.txt?versionId=123' assert list_object_versions_mock.call_count == 2 list_object_versions_mock.assert_has_calls([call('bucket', 'foo/'), call('bucket', 'bar/')]) + + +def create_test_file(filename): + file_path = Path(filename) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text("test") + return filename + + +def test_set_meta_error(): + with pytest.raises(PackageException, match="Must specify either path or meta"): + entry = PackageEntry( + 
PhysicalKey("test-bucket", "without-hash", "without-hash"), + 42, + None, + {}, + ) + entry.set() + + +def test_loading_duplicate_logical_key_error(): + # Create a manifest with duplicate logical keys + KEY = "duplicate_key" + ROW = {"logical_key": KEY, "physical_keys": [f"s3://bucket/{KEY}"], "size": 123, "hash": None, "meta": {}} + buf = io.BytesIO() + jsonlines.Writer(buf).write_all([{"version": "v0"}, ROW, ROW]) + buf.seek(0) + + # Attempt to load the package, which should raise the error + with pytest.raises(PackageException, match=f"Duplicate logical key {KEY!r} while loading package entry: .*"): + Package.load(buf) + + +def test_directory_not_exist_error(): + pkg = Package() + with pytest.raises(PackageException, match="The specified directory .*non_existent_directory'. doesn't exist"): + pkg.set_dir("foo", "non_existent_directory") + + +def test_key_not_point_to_package_entry_error(): + DIR = "foo" + KEY = create_test_file(f"{DIR}/foo.txt") + pkg = Package().set(KEY) + + with pytest.raises(ValueError, match=f"Key {DIR!r} does not point to a PackageEntry"): + pkg.get(DIR) + + +def test_commit_message_type_error(): + pkg = Package() + with pytest.raises( + ValueError, + match="The package commit message must be a string, " + "but the message provided is an instance of .", + ): + pkg.build("test/pkg", message=123) + + +def test_already_package_entry_error(): + DIR = "foo" + KEY = create_test_file(f"{DIR}/foo.txt") + KEY2 = create_test_file(f"{DIR}/bar.txt") + pkg = Package().set(DIR, KEY) + with pytest.raises( + QuiltException, match=f"Already a PackageEntry for {DIR!r} along the path " rf"\['{DIR}'\]: .*/{KEY}" + ): + pkg.set(KEY2) + + +@patch("quilt3.workflows.validate", return_value=None) +def test_unexpected_scheme_error(workflow_validate_mock): + KEY = create_test_file("foo.txt") + pkg = Package().set(KEY) + with pytest.raises(URLParseError, match="Unexpected scheme: 'file' for .*"): + pkg.push("foo/bar", registry="s3://test-bucket", dest=lambda lk, entry: "file:///foo.txt", force=True) diff --git a/api/python/tests/test_admin_api.py b/api/python/tests/test_admin_api.py new file mode 100644 index 00000000000..0a444b506e0 --- /dev/null +++ b/api/python/tests/test_admin_api.py @@ -0,0 +1,425 @@ +import contextlib +import datetime +from unittest import mock + +import pytest + +from quilt3 import admin +from quilt3.admin import _graphql_client + +UNMANAGED_ROLE = { + "__typename": "UnmanagedRole", + "id": "d7d15bef-c482-4086-ae6b-d0372b6145d2", + "name": "UnmanagedRole", + "arn": "arn:aws:iam::000000000000:role/UnmanagedRole", +} +MANAGED_ROLE = { + "__typename": "ManagedRole", + "id": "b1bab604-98fd-4b46-a20b-958cf2541c91", + "name": "ManagedRole", + "arn": "arn:aws:iam::000000000000:role/ManagedRole", +} +USER = { + "__typename": "User", + "name": "test", + "email": "test@example.com", + "dateJoined": datetime.datetime(2024, 6, 14, 11, 42, 27, 857128, tzinfo=datetime.timezone.utc), + "lastLogin": datetime.datetime(2024, 6, 14, 11, 42, 27, 857128, tzinfo=datetime.timezone.utc), + "isActive": True, + "isAdmin": False, + "isSsoOnly": False, + "isService": False, + "role": UNMANAGED_ROLE, + "extraRoles": [MANAGED_ROLE], +} +SSO_CONFIG = { + "__typename": "SsoConfig", + "text": "", + "timestamp": datetime.datetime(2024, 6, 14, 11, 42, 27, 857128, tzinfo=datetime.timezone.utc), + "uploader": USER, +} +TABULATOR_TABLE = { + "name": "table", + "config": "config", +} +MUTATION_ERRORS = ( + ( + { + "__typename": "InvalidInput", + "errors": [ + { + "path": "error path", + "message": "error 
message", + "name": "error name", + "context": {}, + } + ], + }, + admin.Quilt3AdminError, + ), + ( + { + "__typename": "OperationError", + "message": "error message", + "name": "error name", + "context": {}, + }, + admin.Quilt3AdminError, + ), +) +USER_MUTATION_ERRORS = ( + *MUTATION_ERRORS, + (None, admin.UserNotFoundError), +) + + +def _camel_to_snake(name: str) -> str: + return "".join("_" + c.lower() if c.isupper() else c for c in name).lstrip("_") + + +def _as_dataclass_kwargs(data: dict) -> dict: + return { + "typename__" if k == "__typename" else _camel_to_snake(k): ( + _as_dataclass_kwargs(v) + if isinstance(v, dict) + else [_as_dataclass_kwargs(x) for x in v] if isinstance(v, list) else v + ) + for k, v in data.items() + } + + +def _make_nested_dict(path: str, value) -> dict: + if "." in path: + key, rest = path.split(".", 1) + return {key: _make_nested_dict(rest, value)} + return {path: value} + + +@contextlib.contextmanager +def mock_client(data, operation_name, variables=None): + with mock.patch("quilt3.session.get_registry_url", return_value="https://registry.example.com"): + with mock.patch( + "quilt3.admin._graphql_client.Client.execute", return_value=mock.sentinel.RESPONSE + ) as execute_mock: + with mock.patch("quilt3.admin._graphql_client.Client.get_data", return_value=data) as get_data_mock: + yield + + execute_mock.assert_called_once_with(query=mock.ANY, operation_name=operation_name, variables=variables or {}) + get_data_mock.assert_called_once_with(mock.sentinel.RESPONSE) + + +def test_get_roles(): + with mock_client({"roles": [UNMANAGED_ROLE, MANAGED_ROLE]}, "rolesList"): + assert admin.roles.list() == [ + admin.UnmanagedRole(**_as_dataclass_kwargs(UNMANAGED_ROLE)), + admin.ManagedRole(**_as_dataclass_kwargs(MANAGED_ROLE)), + ] + + +@pytest.mark.parametrize( + "data,result", + [ + (USER, admin.User(**_as_dataclass_kwargs(USER))), + (None, None), + ], +) +def test_get_user(data, result): + with mock_client(_make_nested_dict("admin.user.get", data), "usersGet", variables={"name": "test"}): + assert admin.users.get("test") == result + + +def test_get_users(): + with mock_client(_make_nested_dict("admin.user.list", [USER]), "usersList"): + assert admin.users.list() == [admin.User(**_as_dataclass_kwargs(USER))] + + +@pytest.mark.parametrize( + "data,result", + [ + (USER, admin.User(**_as_dataclass_kwargs(USER))), + *MUTATION_ERRORS, + ], +) +def test_create_user(data, result): + with mock_client( + _make_nested_dict("admin.user.create", data), + "usersCreate", + variables={ + "input": _graphql_client.UserInput( + name="test", email="test@example.com", role="UnmanagedRole", extraRoles=[] + ) + }, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.users.create("test", "test@example.com", "UnmanagedRole", []) + else: + assert admin.users.create("test", "test@example.com", "UnmanagedRole", []) == result + + +@pytest.mark.parametrize( + "data,result", + MUTATION_ERRORS, +) +def test_delete_user(data, result): + with mock_client( + _make_nested_dict("admin.user.mutate.delete", data), + "usersDelete", + variables={"name": "test"}, + ): + with pytest.raises(result): + admin.users.delete("test") + + +@pytest.mark.parametrize( + "data,result", + [ + (USER, admin.User(**_as_dataclass_kwargs(USER))), + *USER_MUTATION_ERRORS, + ], +) +def test_set_user_email(data, result): + with mock_client( + ( + _make_nested_dict("admin.user.mutate", None) + if data is None + else _make_nested_dict("admin.user.mutate.setEmail", data) + ), + 
"usersSetEmail", + variables={"name": "test", "email": "test@example.com"}, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.users.set_email("test", "test@example.com") + else: + assert admin.users.set_email("test", "test@example.com") == result + + +@pytest.mark.parametrize( + "data,result", + [ + (USER, admin.User(**_as_dataclass_kwargs(USER))), + *USER_MUTATION_ERRORS, + ], +) +def test_set_user_admin(data, result): + with mock_client( + ( + _make_nested_dict("admin.user.mutate", None) + if data is None + else _make_nested_dict("admin.user.mutate.setAdmin", data) + ), + "usersSetAdmin", + variables={"name": "test", "admin": True}, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.users.set_admin("test", True) + else: + assert admin.users.set_admin("test", True) == result + + +@pytest.mark.parametrize( + "data,result", + [ + (USER, admin.User(**_as_dataclass_kwargs(USER))), + *USER_MUTATION_ERRORS, + ], +) +def test_set_user_active(data, result): + with mock_client( + ( + _make_nested_dict("admin.user.mutate", None) + if data is None + else _make_nested_dict("admin.user.mutate.setActive", data) + ), + "usersSetActive", + variables={"name": "test", "active": True}, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.users.set_active("test", True) + else: + assert admin.users.set_active("test", True) == result + + +@pytest.mark.parametrize( + "data,result", + USER_MUTATION_ERRORS, +) +def test_reset_user_password(data, result): + with mock_client( + ( + _make_nested_dict("admin.user.mutate", None) + if data is None + else _make_nested_dict("admin.user.mutate.resetPassword", data) + ), + "usersResetPassword", + variables={"name": "test"}, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.users.reset_password("test") + else: + assert admin.users.reset_password("test") == result + + +@pytest.mark.parametrize( + "data,result", + [ + (USER, admin.User(**_as_dataclass_kwargs(USER))), + *USER_MUTATION_ERRORS, + ], +) +def test_set_role(data, result): + with mock_client( + ( + _make_nested_dict("admin.user.mutate", None) + if data is None + else _make_nested_dict("admin.user.mutate.setRole", data) + ), + "usersSetRole", + variables={"name": "test", "role": "UnamangedRole", "extraRoles": [], "append": True}, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.users.set_role("test", "UnamangedRole", [], append=True) + else: + assert admin.users.set_role("test", "UnamangedRole", [], append=True) == result + + +@pytest.mark.parametrize( + "data,result", + [ + (USER, admin.User(**_as_dataclass_kwargs(USER))), + *USER_MUTATION_ERRORS, + ], +) +def test_add_roles(data, result): + with mock_client( + ( + _make_nested_dict("admin.user.mutate", None) + if data is None + else _make_nested_dict("admin.user.mutate.addRoles", data) + ), + "usersAddRoles", + variables={"name": "test", "roles": ["ManagedRole"]}, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.users.add_roles("test", ["ManagedRole"]) + else: + assert admin.users.add_roles("test", ["ManagedRole"]) == result + + +@pytest.mark.parametrize( + "data,result", + [ + (USER, admin.User(**_as_dataclass_kwargs(USER))), + *USER_MUTATION_ERRORS, + ], +) +def test_remove_roles(data, result): + with mock_client( + ( + 
_make_nested_dict("admin.user.mutate", None) + if data is None + else _make_nested_dict("admin.user.mutate.removeRoles", data) + ), + "usersRemoveRoles", + variables={"name": "test", "roles": ["ManagedRole"], "fallback": "UnamanagedRole"}, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.users.remove_roles("test", ["ManagedRole"], fallback="UnamanagedRole") + else: + assert admin.users.remove_roles("test", ["ManagedRole"], fallback="UnamanagedRole") == result + + +@pytest.mark.parametrize( + "data,result", + [ + (SSO_CONFIG, admin.SSOConfig(**_as_dataclass_kwargs(SSO_CONFIG))), + (None, None), + ], +) +def test_sso_config_get(data, result): + with mock_client(_make_nested_dict("admin.sso_config", data), "ssoConfigGet"): + assert admin.sso_config.get() == result + + +@pytest.mark.parametrize( + "data,result", + [ + (SSO_CONFIG, admin.SSOConfig(**_as_dataclass_kwargs(SSO_CONFIG))), + (None, None), + *MUTATION_ERRORS, + ], +) +def test_sso_config_set(data, result): + with mock_client(_make_nested_dict("admin.set_sso_config", data), "ssoConfigSet", variables={"config": ""}): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.sso_config.set("") + else: + assert admin.sso_config.set("") == result + + +@pytest.mark.parametrize( + "data, result", + [ + ({"tabulator_tables": [TABULATOR_TABLE]}, [admin.TabulatorTable(**TABULATOR_TABLE)]), + (None, admin.BucketNotFoundError), + ], +) +def test_tabulator_list(data, result): + with mock_client( + _make_nested_dict("bucket_config", data), + "bucketTabulatorTablesList", + variables={"name": "test"}, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.tabulator.list_tables("test") + else: + assert admin.tabulator.list_tables("test") == result + + +@pytest.mark.parametrize( + "data,result", + [ + ({"__typename": "BucketConfig"}, None), + *MUTATION_ERRORS, + ], +) +def test_tabulator_set(data, result): + with mock_client( + _make_nested_dict("admin.bucket_set_tabulator_table", data), + "bucketTabulatorTableSet", + variables={"bucketName": "test", "tableName": "table", "config": ""}, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.tabulator.set_table("test", "table", "") + else: + assert admin.tabulator.set_table("test", "table", "") == result + + +@pytest.mark.parametrize( + "data,result", + [ + ({"__typename": "BucketConfig"}, None), + *MUTATION_ERRORS, + ], +) +def test_tabulator_rename(data, result): + with mock_client( + _make_nested_dict("admin.bucket_rename_tabulator_table", data), + "bucketTabulatorTableRename", + variables={"bucketName": "test", "tableName": "table", "newTableName": "new_table"}, + ): + if isinstance(result, type) and issubclass(result, Exception): + with pytest.raises(result): + admin.tabulator.rename_table("test", "table", "new_table") + else: + assert admin.tabulator.rename_table("test", "table", "new_table") == result diff --git a/api/python/tests/test_api.py b/api/python/tests/test_api.py index 549495dbecd..9db8f16c924 100644 --- a/api/python/tests/test_api.py +++ b/api/python/tests/test_api.py @@ -17,7 +17,8 @@ def test_config(self): 'telemetry_disabled': False, 's3Proxy': None, 'apiGatewayEndpoint': None, - 'binaryApiGatewayEndpoint': None + 'binaryApiGatewayEndpoint': None, + "region": "us-west-2", } self.requests_mock.add(responses.GET, 'https://foo.bar/config.json', json=content, status=200) @@ 
-42,18 +43,3 @@ def test_config_invalid_host(self): # present. ..but, a bad port causes an error.. with pytest.raises(util.QuiltException, match='Port must be a number'): he.config('https://fliff:fluff') - - def test_set_role(self): - self.requests_mock.add(responses.POST, DEFAULT_URL + '/api/users/set_role', - json={}, status=200) - - not_found_result = { - 'message': "No user exists by the provided name." - } - self.requests_mock.add(responses.POST, DEFAULT_URL + '/api/users/set_role', - json=not_found_result, status=400) - - he.admin.set_role(username='test_user', role_name='test_role') - - with pytest.raises(util.QuiltException): - he.admin.set_role(username='not_found', role_name='test_role') diff --git a/api/python/tests/test_data_transfer.py b/api/python/tests/test_data_transfer.py index d4a466722a2..05a71597a98 100644 --- a/api/python/tests/test_data_transfer.py +++ b/api/python/tests/test_data_transfer.py @@ -1,6 +1,5 @@ """ Testing for data_transfer.py """ -import hashlib import io import os import pathlib @@ -14,7 +13,7 @@ import botocore.client import pandas as pd import pytest -from botocore.exceptions import ClientError, ReadTimeoutError +from botocore.exceptions import ClientError, ConnectionError, ReadTimeoutError from botocore.stub import ANY from quilt3 import data_transfer @@ -165,11 +164,13 @@ def test_simple_upload(self): self.s3_stubber.add_response( method='put_object', service_response={ + 'ChecksumSHA256': 'ASNFZ4mrze8BI0VniavN7w==', }, expected_params={ 'Body': ANY, 'Bucket': 'example', 'Key': 'foo.csv', + 'ChecksumAlgorithm': 'SHA256', } ) @@ -183,11 +184,13 @@ def test_multi_upload(self): self.s3_stubber.add_response( method='put_object', service_response={ + 'ChecksumSHA256': 'ASNFZ4mrze8BI0VniavN7w==', }, expected_params={ 'Body': ANY, 'Bucket': 'example1', 'Key': 'foo.csv', + 'ChecksumAlgorithm': 'SHA256', } ) @@ -195,12 +198,14 @@ def test_multi_upload(self): self.s3_stubber.add_response( method='put_object', service_response={ - 'VersionId': 'v123' + 'VersionId': 'v123', + 'ChecksumSHA256': 'ASNFZ4mrze8BI0VniavN7w==', }, expected_params={ 'Body': ANY, 'Bucket': 'example2', 'Key': 'foo.txt', + 'ChecksumAlgorithm': 'SHA256', } ) @@ -211,8 +216,14 @@ def test_multi_upload(self): (PhysicalKey.from_path(path2), PhysicalKey.from_url('s3://example2/foo.txt'), path2.stat().st_size), ]) - assert urls[0] == PhysicalKey.from_url('s3://example1/foo.csv') - assert urls[1] == PhysicalKey.from_url('s3://example2/foo.txt?versionId=v123') + assert urls[0] == ( + PhysicalKey.from_url('s3://example1/foo.csv'), + 'Ij4KFgr52goD5t0sRxnFb11mpjPL6E54qqnzc1hlUio=', + ) + assert urls[1] == ( + PhysicalKey.from_url('s3://example2/foo.txt?versionId=v123'), + 'Ij4KFgr52goD5t0sRxnFb11mpjPL6E54qqnzc1hlUio=', + ) def test_upload_large_file(self): path = DATA_DIR / 'large_file.npy' @@ -223,25 +234,31 @@ def test_upload_large_file(self): expected_params={ 'Bucket': 'example', 'Key': 'large_file.npy', + 'ChecksumMode': 'ENABLED', } ) self.s3_stubber.add_response( method='put_object', service_response={ - 'VersionId': 'v1' + 'VersionId': 'v1', + 'ChecksumSHA256': 'ASNFZ4mrze8BI0VniavN7w==', }, expected_params={ 'Body': ANY, 'Bucket': 'example', 'Key': 'large_file.npy', + 'ChecksumAlgorithm': 'SHA256', } ) urls = data_transfer.copy_file_list([ (PhysicalKey.from_path(path), PhysicalKey.from_url('s3://example/large_file.npy'), path.stat().st_size), ]) - assert urls[0] == PhysicalKey.from_url('s3://example/large_file.npy?versionId=v1') + assert urls[0] == ( + 
PhysicalKey.from_url('s3://example/large_file.npy?versionId=v1'), + 'Ij4KFgr52goD5t0sRxnFb11mpjPL6E54qqnzc1hlUio=', + ) def test_upload_large_file_etag_match(self): path = DATA_DIR / 'large_file.npy' @@ -256,13 +273,17 @@ def test_upload_large_file_etag_match(self): expected_params={ 'Bucket': 'example', 'Key': 'large_file.npy', + 'ChecksumMode': 'ENABLED', } ) urls = data_transfer.copy_file_list([ (PhysicalKey.from_path(path), PhysicalKey.from_url('s3://example/large_file.npy'), path.stat().st_size), ]) - assert urls[0] == PhysicalKey.from_url('s3://example/large_file.npy?versionId=v1') + assert urls[0] == ( + PhysicalKey.from_url('s3://example/large_file.npy?versionId=v1'), + "IsygGcHBbQgZ3DCzdPy9+0od5VqDJjcW4R0mF2v/Bu8=", + ) def test_upload_large_file_etag_mismatch(self): path = DATA_DIR / 'large_file.npy' @@ -277,25 +298,244 @@ def test_upload_large_file_etag_mismatch(self): expected_params={ 'Bucket': 'example', 'Key': 'large_file.npy', + 'ChecksumMode': 'ENABLED', + } + ) + + self.s3_stubber.add_response( + method='put_object', + service_response={ + 'VersionId': 'v2', + # b2a_base64(a2b_hex(b'0123456789abcdef0123456789abcdef')) + 'ChecksumSHA256': 'ASNFZ4mrze8BI0VniavN7w==', + }, + expected_params={ + 'Body': ANY, + 'Bucket': 'example', + 'Key': 'large_file.npy', + 'ChecksumAlgorithm': 'SHA256', + } + ) + + urls = data_transfer.copy_file_list([ + (PhysicalKey.from_path(path), PhysicalKey.from_url('s3://example/large_file.npy'), path.stat().st_size), + ]) + assert urls[0] == ( + PhysicalKey.from_url('s3://example/large_file.npy?versionId=v2'), + 'Ij4KFgr52goD5t0sRxnFb11mpjPL6E54qqnzc1hlUio=', + ) + + def test_upload_file_checksum_match(self): + path = DATA_DIR / 'large_file.npy' + assert path.stat().st_size < data_transfer.CHECKSUM_MULTIPART_THRESHOLD + + self.s3_stubber.add_response( + method='head_object', + service_response={ + 'ContentLength': path.stat().st_size, + 'ETag': '"123"', + 'VersionId': 'v1', + 'ChecksumSHA256': 'J+KTXLmOXrP7AmRZQQZWSj6DznTh7TbeeP6YbL1j+5w=', + }, + expected_params={ + 'Bucket': 'example', + 'Key': 'large_file.npy', + 'ChecksumMode': 'ENABLED', + } + ) + + urls = data_transfer.copy_file_list([ + (PhysicalKey.from_path(path), PhysicalKey.from_url('s3://example/large_file.npy'), path.stat().st_size), + ]) + assert urls[0] == ( + PhysicalKey.from_url('s3://example/large_file.npy?versionId=v1'), + "IsygGcHBbQgZ3DCzdPy9+0od5VqDJjcW4R0mF2v/Bu8=", + ) + + def test_upload_file_checksum_match_unexpected_parts(self): + path = DATA_DIR / 'large_file.npy' + assert path.stat().st_size < data_transfer.CHECKSUM_MULTIPART_THRESHOLD + + self.s3_stubber.add_response( + method='head_object', + service_response={ + 'ContentLength': path.stat().st_size, + 'ETag': '"123"', + 'VersionId': 'v1', + 'ChecksumSHA256': 'IsygGcHBbQgZ3DCzdPy9+0od5VqDJjcW4R0mF2v/Bu8=-1', + }, + expected_params={ + 'Bucket': 'example', + 'Key': 'large_file.npy', + 'ChecksumMode': 'ENABLED', + } + ) + + self.s3_stubber.add_response( + method='put_object', + service_response={ + 'VersionId': 'v2', + # b2a_base64(a2b_hex(b'0123456789abcdef0123456789abcdef')) + 'ChecksumSHA256': 'ASNFZ4mrze8BI0VniavN7w==', + }, + expected_params={ + 'Body': ANY, + 'Bucket': 'example', + 'Key': 'large_file.npy', + 'ChecksumAlgorithm': 'SHA256', + } + ) + + urls = data_transfer.copy_file_list([ + (PhysicalKey.from_path(path), PhysicalKey.from_url('s3://example/large_file.npy'), path.stat().st_size), + ]) + assert urls[0] == ( + PhysicalKey.from_url('s3://example/large_file.npy?versionId=v2'), + 
"Ij4KFgr52goD5t0sRxnFb11mpjPL6E54qqnzc1hlUio=", + ) + + def test_upload_file_checksum_multipart_match(self): + path = pathlib.Path("test-file") + path.write_bytes(bytes(data_transfer.CHECKSUM_MULTIPART_THRESHOLD)) + + self.s3_stubber.add_response( + method='head_object', + service_response={ + 'ContentLength': path.stat().st_size, + 'ETag': '"123"', + 'VersionId': 'v1', + 'ChecksumSHA256': 'MIsGKY+ykqN4CPj3gGGu4Gv03N7OWKWpsZqEf+OrGJs=-1', + }, + expected_params={ + 'Bucket': 'example', + 'Key': 'large_file.npy', + 'ChecksumMode': 'ENABLED', + } + ) + + urls = data_transfer.copy_file_list([ + (PhysicalKey.from_path(path), PhysicalKey.from_url('s3://example/large_file.npy'), path.stat().st_size), + ]) + assert urls[0] == ( + PhysicalKey.from_url('s3://example/large_file.npy?versionId=v1'), + "MIsGKY+ykqN4CPj3gGGu4Gv03N7OWKWpsZqEf+OrGJs=", + ) + + def test_upload_file_checksum_multipart_match_unexpected_parts(self): + path = pathlib.Path("test-file") + path.write_bytes(bytes(data_transfer.CHECKSUM_MULTIPART_THRESHOLD)) + + self.s3_stubber.add_response( + method='head_object', + service_response={ + 'ContentLength': path.stat().st_size, + 'ETag': '"123"', + 'VersionId': 'v1', + 'ChecksumSHA256': 'La6x82CVtEsxhBCz9Oi12Yncx7sCPRQmxJLasKMFPnQ=', + }, + expected_params={ + 'Bucket': 'example', + 'Key': 'large_file.npy', + 'ChecksumMode': 'ENABLED', + } + ) + + self.s3_stubber.add_response( + method='create_multipart_upload', + service_response={ + 'UploadId': '123' + }, + expected_params={ + 'Bucket': 'example', + 'Key': 'large_file.npy', + 'ChecksumAlgorithm': 'SHA256', + } + ) + self.s3_stubber.add_response( + method='upload_part', + service_response={ + 'ETag': '"123"', + 'ChecksumSHA256': 'La6x82CVtEsxhBCz9Oi12Yncx7sCPRQmxJLasKMFPnQ=', + }, + expected_params={ + 'Bucket': 'example', + 'Key': 'large_file.npy', + 'UploadId': '123', + 'Body': ANY, + 'PartNumber': 1, + 'ChecksumAlgorithm': 'SHA256', + } + ) + self.s3_stubber.add_response( + method='complete_multipart_upload', + service_response={ + 'ChecksumSHA256': "MIsGKY+ykqN4CPj3gGGu4Gv03N7OWKWpsZqEf+OrGJs=-1", + 'VersionId': 'v1', + }, + expected_params={ + 'Bucket': 'example', + 'Key': 'large_file.npy', + 'UploadId': '123', + 'MultipartUpload': { + 'Parts': [ + { + 'ETag': '"123"', + 'ChecksumSHA256': 'La6x82CVtEsxhBCz9Oi12Yncx7sCPRQmxJLasKMFPnQ=', + 'PartNumber': 1, + }, + ] + } + } + ) + + urls = data_transfer.copy_file_list([ + (PhysicalKey.from_path(path), PhysicalKey.from_url('s3://example/large_file.npy'), path.stat().st_size), + ]) + assert urls[0] == ( + PhysicalKey.from_url('s3://example/large_file.npy?versionId=v1'), + "MIsGKY+ykqN4CPj3gGGu4Gv03N7OWKWpsZqEf+OrGJs=", + ) + + def test_upload_file_size_mismatch(self): + path = DATA_DIR / 'large_file.npy' + + self.s3_stubber.add_response( + method='head_object', + service_response={ + 'ContentLength': path.stat().st_size + 1, + 'ETag': data_transfer._calculate_etag(path), + 'VersionId': 'v1', + 'ChecksumSHA256': 'IsygGcHBbQgZ3DCzdPy9+0od5VqDJjcW4R0mF2v/Bu8=-1', + }, + expected_params={ + 'Bucket': 'example', + 'Key': 'large_file.npy', + 'ChecksumMode': 'ENABLED', } ) self.s3_stubber.add_response( method='put_object', service_response={ - 'VersionId': 'v2' + 'VersionId': 'v2', + # b2a_base64(a2b_hex(b'0123456789abcdef0123456789abcdef')) + 'ChecksumSHA256': 'ASNFZ4mrze8BI0VniavN7w==', }, expected_params={ 'Body': ANY, 'Bucket': 'example', 'Key': 'large_file.npy', + 'ChecksumAlgorithm': 'SHA256', } ) urls = data_transfer.copy_file_list([ (PhysicalKey.from_path(path), 
PhysicalKey.from_url('s3://example/large_file.npy'), path.stat().st_size), ]) - assert urls[0] == PhysicalKey.from_url('s3://example/large_file.npy?versionId=v2') + assert urls[0] == ( + PhysicalKey.from_url('s3://example/large_file.npy?versionId=v2'), + "Ij4KFgr52goD5t0sRxnFb11mpjPL6E54qqnzc1hlUio=", + ) def test_multipart_upload(self): name = 'very_large_file.bin' @@ -317,6 +557,7 @@ def test_multipart_upload(self): expected_params={ 'Bucket': 'example', 'Key': name, + 'ChecksumMode': 'ENABLED', } ) @@ -328,6 +569,7 @@ def test_multipart_upload(self): expected_params={ 'Bucket': 'example', 'Key': name, + 'ChecksumAlgorithm': 'SHA256', } ) @@ -335,20 +577,24 @@ def test_multipart_upload(self): self.s3_stubber.add_response( method='upload_part', service_response={ - 'ETag': 'etag%d' % part_num + 'ETag': 'etag%d' % part_num, + 'ChecksumSHA256': 'hash%d' % part_num, }, expected_params={ 'Bucket': 'example', 'Key': name, 'UploadId': '123', 'Body': ANY, - 'PartNumber': part_num + 'PartNumber': part_num, + 'ChecksumAlgorithm': 'SHA256', } ) self.s3_stubber.add_response( method='complete_multipart_upload', - service_response={}, + service_response={ + 'ChecksumSHA256': f'123456-{chunks}', + }, expected_params={ 'Bucket': 'example', 'Key': name, @@ -356,6 +602,7 @@ def test_multipart_upload(self): 'MultipartUpload': { 'Parts': [{ 'ETag': 'etag%d' % i, + 'ChecksumSHA256': 'hash%d' % i, 'PartNumber': i } for i in range(1, chunks+1)] } @@ -387,6 +634,7 @@ def test_multipart_copy(self): expected_params={ 'Bucket': 'example2', 'Key': 'large_file2.npy', + 'ChecksumAlgorithm': 'SHA256', } ) @@ -395,7 +643,8 @@ def test_multipart_copy(self): method='upload_part_copy', service_response={ 'CopyPartResult': { - 'ETag': 'etag%d' % part_num + 'ETag': 'etag%d' % part_num, + 'ChecksumSHA256': 'hash%d' % part_num, } }, expected_params={ @@ -416,7 +665,9 @@ def test_multipart_copy(self): self.s3_stubber.add_response( method='complete_multipart_upload', - service_response={}, + service_response={ + 'ChecksumSHA256': f'123456-{chunks}', + }, expected_params={ 'Bucket': 'example2', 'Key': 'large_file2.npy', @@ -424,6 +675,7 @@ def test_multipart_copy(self): 'MultipartUpload': { 'Parts': [{ 'ETag': 'etag%d' % i, + 'ChecksumSHA256': 'hash%d' % i, 'PartNumber': i } for i in range(1, chunks+1)] } @@ -454,7 +706,7 @@ def test_calculate_sha256_read_timeout(self, mocked_api_call): pk = PhysicalKey(bucket, key, vid) exc = ReadTimeoutError('Error Uploading', endpoint_url="s3://foobar") mocked_api_call.side_effect = exc - results = data_transfer.calculate_sha256([pk], [len(a_contents)]) + results = data_transfer.calculate_checksum([pk], [len(a_contents)]) assert mocked_api_call.call_count == data_transfer.MAX_FIX_HASH_RETRIES assert results == [exc] @@ -512,6 +764,44 @@ def side_effect(operation_name, *args, **kwargs): with pytest.raises(ClientError): data_transfer.copy_file_list([(src, dst, size)]) + def test_calculate_checksum_retry(self): + src = PhysicalKey('test-bucket', 'dir/a', None) + + # TODO: copy_file_list also retries ClientError. Should calculate_checksum do that? 
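        # The retry contract exercised by this test and the partial-retry test
        # below: each object is hashed independently, a transient ConnectionError
        # is retried up to MAX_FIX_HASH_RETRIES times, and an object that still
        # fails contributes its exception to the results list rather than
        # raising. Roughly (a sketch, not the actual implementation; _hash_one
        # is a hypothetical helper):
        #
        #     def _hash_with_retries(pk, size):
        #         for attempt in range(1, data_transfer.MAX_FIX_HASH_RETRIES + 1):
        #             try:
        #                 return _hash_one(pk, size)
        #             except ConnectionError as exc:
        #                 if attempt == data_transfer.MAX_FIX_HASH_RETRIES:
        #                     return exc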
+ with mock.patch('botocore.client.BaseClient._make_api_call', + side_effect=ConnectionError(error='foo')) as mocked_api_call: + result = data_transfer.calculate_checksum([src], [1]) + assert isinstance(result[0], ConnectionError) + self.assertEqual(mocked_api_call.call_count, data_transfer.MAX_FIX_HASH_RETRIES) + + def test_calculate_checksum_partial_retry(self): + src1 = PhysicalKey('test-bucket', 'dir/a', None) + src2 = PhysicalKey('test-bucket', 'dir/b', None) + + def side_effect(operation_name, *args, **kwargs): + if args[0]['Key'] == 'dir/a': + # src1 succeeds + return { + 'Body': io.BytesIO(b'a'), + } + else: + # src2 fails twice, then succeeds + if side_effect.counter < 2: + side_effect.counter += 1 + raise ConnectionError(error='foo') + return { + 'Body': io.BytesIO(b'b'), + } + + side_effect.counter = 0 + + with mock.patch('botocore.client.BaseClient._make_api_call', + side_effect=side_effect) as mocked_api_call: + result = data_transfer.calculate_checksum([src1, src2], [1, 2]) + assert result[0] == 'v106/7c+/S7Gw2rTES3ZM+/tY8Thy//PqI4nWcFE8tg=' + assert result[1] == 'OTYRYJA8ZpXGgEtxV8e9EAE+m6ibH5VCQ7yOOZCwjbk=' + self.assertEqual(mocked_api_call.call_count, 4) + @mock.patch.multiple( 'quilt3.data_transfer.s3_transfer_config', multipart_threshold=1, @@ -598,11 +888,7 @@ def test_progress_updateds(self, mocked_update): def handler(request, **kwargs): request.body.read(2) - mocked_update.assert_called_once_with(2) - - mocked_update.reset_mock() - request.body.seek(0) - mocked_update.assert_called_once_with(-2) + mocked_update.assert_called_once() raise Success @@ -662,50 +948,84 @@ def test_threshold_eq_chunk_gt_size(self): class S3HashingTest(QuiltTestCase): - data = b'0123456789abcdef' - size = len(data) - hasher = hashlib.sha256 - bucket = 'test-bucket' key = 'test-key' src = PhysicalKey(bucket, key, None) - def _hashing_subtest(self, *, threshold, chunksize, data=data): + def test_adjust_chunksize(self): + default = 8 * 1024 * 1024 + + # "Normal" file sizes + assert data_transfer.get_checksum_chunksize(8 * 1024 * 1024) == default + assert data_transfer.get_checksum_chunksize(1024 * 1024 * 1024) == default + assert data_transfer.get_checksum_chunksize(10_000 * default) == default + + # Big file: exceeds 10,000 parts + assert data_transfer.get_checksum_chunksize(10_000 * default + 1) == default * 2 + assert data_transfer.get_checksum_chunksize(2 * 10_000 * default + 1) == default * 4 + + def test_single(self): + data = b'0123456789abcdef' + size = len(data) + + chunksize = 8 * 1024 * 1024 + + ranges = { + f'bytes=0-{size-1}': data, + } + with self.s3_test_multi_thread_download( - self.bucket, self.key, data, threshold=threshold, chunksize=chunksize + self.bucket, self.key, ranges, threshold=chunksize, chunksize=chunksize ): - assert data_transfer.calculate_sha256([self.src], [self.size]) == [self.hasher(self.data).hexdigest()] - - def test_single_request(self): - params = ( - (self.size + 1, 5), - (self.size, self.size), - (self.size, self.size + 1), - (5, self.size), - ) - for threshold, chunksize in params: - with self.subTest(threshold=threshold, chunksize=chunksize): - self._hashing_subtest(threshold=threshold, chunksize=chunksize) - - def test_multi_request(self): - params = ( - ( - self.size, 5, { - 'bytes=0-4': self.data[:5], - 'bytes=5-9': self.data[5:10], - 'bytes=10-14': self.data[10:15], - 'bytes=15-15': self.data[15:], - } - ), - ( - 5, self.size - 1, { - 'bytes=0-14': self.data[:15], - 'bytes=15-15': self.data[15:], - } - ), - ) - for threshold, chunksize, data in 
params: - for concurrency in (len(data), 1): - with mock.patch('quilt3.data_transfer.s3_transfer_config.max_request_concurrency', concurrency): - with self.subTest(threshold=threshold, chunksize=chunksize, data=data, concurrency=concurrency): - self._hashing_subtest(threshold=threshold, chunksize=chunksize, data=data) + hash1 = data_transfer.calculate_checksum([self.src], [size])[0] + hash2 = data_transfer.calculate_checksum_bytes(data) + assert hash1 == hash2 + assert hash1 == 'Xb1PbjJeWof4zD7zuHc9PI7sLiz/Ykj4gphlaZEt3xA=' + + def test_multipart(self): + data = b'1234567890abcdefgh' * 1024 * 1024 + size = len(data) + + chunksize = 8 * 1024 * 1024 + + ranges = { + f'bytes=0-{chunksize-1}': data[:chunksize], + f'bytes={chunksize}-{chunksize*2-1}': data[chunksize:chunksize*2], + f'bytes={chunksize*2}-{size-1}': data[chunksize*2:], + } + + with self.s3_test_multi_thread_download( + self.bucket, self.key, ranges, threshold=chunksize, chunksize=chunksize + ): + hash1 = data_transfer.calculate_checksum([self.src], [size])[0] + hash2 = data_transfer.calculate_checksum_bytes(data) + assert hash1 == hash2 + assert hash1 == 'T+rt/HKRJOiAkEGXKvc+DhCwRcrZiDrFkjKonDT1zgs=' + + def test_one_part(self): + # Edge case: file length is exactly the threshold, resulting in a 1-part multipart upload. + data = b'12345678' * 1024 * 1024 + size = len(data) + + chunksize = 8 * 1024 * 1024 + + ranges = { + f'bytes=0-{size-1}': data, + } + + with self.s3_test_multi_thread_download( + self.bucket, self.key, ranges, threshold=chunksize, chunksize=chunksize + ): + hash1 = data_transfer.calculate_checksum([self.src], [size])[0] + hash2 = data_transfer.calculate_checksum_bytes(data) + assert hash1 == hash2 + assert hash1 == '7V3rZ3Q/AmAYax2wsQBZbc7N1EMIxlxRyMiMthGRdwg=' + + def test_empty(self): + data = b'' + size = len(data) + + hash1 = data_transfer.calculate_checksum([self.src], [size])[0] + hash2 = data_transfer.calculate_checksum_bytes(data) + assert hash1 == hash2 + assert hash1 == '47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=' diff --git a/api/python/tests/test_s3clientprovider.py b/api/python/tests/test_s3clientprovider.py new file mode 100644 index 00000000000..65a8f3a4708 --- /dev/null +++ b/api/python/tests/test_s3clientprovider.py @@ -0,0 +1,34 @@ +import os +from unittest import mock + +import botocore +import pytest + +from quilt3.data_transfer import S3ClientProvider + +PATCH_UNSET_CREDENTIALS = mock.patch.dict(os.environ, {"AWS_SHARED_CREDENTIALS_FILE": "/not-exist"}, clear=True) +PATCH_SET_CREDENTIALS = mock.patch.dict( + os.environ, + dict.fromkeys( + ( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + ), + "blah" + ), + clear=True, +) + + +@pytest.mark.parametrize( + "credentials_context_manager, client, is_unsigned", + [ + (PATCH_UNSET_CREDENTIALS, "standard_client", True), + (PATCH_UNSET_CREDENTIALS, "unsigned_client", True), + (PATCH_SET_CREDENTIALS, "standard_client", False), + (PATCH_SET_CREDENTIALS, "unsigned_client", True), + ], +) +def test_client(credentials_context_manager, client, is_unsigned): + with credentials_context_manager: + assert (getattr(S3ClientProvider(), client).meta.config.signature_version == botocore.UNSIGNED) is is_unsigned diff --git a/api/python/tests/test_search.py b/api/python/tests/test_search.py index 42689182d94..e24f401f529 100644 --- a/api/python/tests/test_search.py +++ b/api/python/tests/test_search.py @@ -1,7 +1,8 @@ +from unittest import mock + import responses from quilt3 import search -from quilt3.util import get_from_config from .utils import QuiltTestCase 
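# Aside on the scheme exercised by S3HashingTest above: the expected values are
# consistent with a chunked SHA-256 ("sha2-256-chunked") in which a small object
# is hashed directly, while an object at or above the multipart threshold hashes
# the concatenation of its per-chunk digests (matching S3's multipart
# ChecksumSHA256 minus the trailing "-<parts>" suffix), with the chunk size
# doubled as needed to stay within S3's 10,000-part limit. A sketch under those
# assumptions, not the library code itself:

import base64
import hashlib

DEFAULT_CHUNKSIZE = 8 * 1024 * 1024


def get_checksum_chunksize(size, chunksize=DEFAULT_CHUNKSIZE):
    # double until the object fits in 10,000 parts
    while size > chunksize * 10_000:
        chunksize *= 2
    return chunksize


def calculate_checksum_bytes(data, threshold=DEFAULT_CHUNKSIZE):
    if len(data) < threshold:
        digest = hashlib.sha256(data).digest()
    else:
        chunksize = get_checksum_chunksize(len(data))
        digest = hashlib.sha256(b''.join(
            hashlib.sha256(data[i:i + chunksize]).digest()
            for i in range(0, len(data), chunksize)
        )).digest()
    return base64.b64encode(digest).decode()


# e.g. the empty-object value asserted in test_empty:
assert calculate_checksum_bytes(b'') == '47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU='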
@@ -11,11 +12,8 @@ class ResponseMock: class SearchTestCase(QuiltTestCase): - def test_all_bucket_search(self): - navigator_url = get_from_config('navigator_url') - api_gateway_url = get_from_config('apiGatewayEndpoint') - search_url = api_gateway_url + '/search' + registry_url = "https://registry.example.com" mock_search = { 'hits': { 'hits': [{ @@ -32,10 +30,27 @@ def test_all_bucket_search(self): } } - self.requests_mock.add(responses.GET, - f"{search_url}?index=%2A&action=search&query=%2A", - json=mock_search, - status=200, - match_querystring=True) - results = search("*") - assert len(results) == 1 + for user_query, expected_query_param in ( + ("*", dict(action="search", index="_all", size=10, query="*")), + ( + {"query": {"query_string": {"query": "handle:test*"}}}, + dict( + action="freeform", + index="_all", + size=10, + body='{"query": {"query_string": {"query": "handle:test*"}}}', + ), + ), + ): + with self.subTest(user_query=user_query, expected_query_param=expected_query_param): + self.requests_mock.get( + f"{registry_url}/api/search", + match=[responses.matchers.query_param_matcher(expected_query_param)], + json=mock_search, + status=200, + ) + with mock.patch("quilt3.session.get_registry_url", return_value=registry_url) as get_registry_url_mock: + results = search(user_query) + + get_registry_url_mock.assert_called_with() + assert len(results) == 1 diff --git a/api/python/tests/test_session.py b/api/python/tests/test_session.py index 258fa08a4a1..9cfc9e21927 100644 --- a/api/python/tests/test_session.py +++ b/api/python/tests/test_session.py @@ -5,13 +5,20 @@ import datetime from unittest.mock import patch +import boto3 +import pytest import responses import quilt3 +import quilt3.util from .utils import QuiltTestCase +def format_date(date): + return date.replace(tzinfo=datetime.timezone.utc, microsecond=0).isoformat() + + class TestSession(QuiltTestCase): @patch('quilt3.session.open_url') @patch('quilt3.session.input', return_value='123456') @@ -67,9 +74,6 @@ def test_login_with_token(self, mock_save_credentials, mock_save_auth): @patch('quilt3.session._save_credentials') @patch('quilt3.session._load_credentials') def test_create_botocore_session(self, mock_load_credentials, mock_save_credentials): - def format_date(date): - return date.replace(tzinfo=datetime.timezone.utc, microsecond=0).isoformat() - # Test good credentials. 
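            # The payload mocked below is the shape botocore's refreshable
            # credentials expect: static keys plus an ISO-8601 'expiry_time'
            # (hence format_date above), i.e.
            #
            #     dict(
            #         access_key="access-key",
            #         secret_key="secret-key",
            #         token="session-token",
            #         expiry_time=format_date(future_date),
            #     )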
future_date = datetime.datetime.utcnow() + datetime.timedelta(hours=1) @@ -121,6 +125,62 @@ def format_date(date): mock_save_credentials.assert_called() + @patch("quilt3.util.load_config") + @patch("quilt3.session._load_credentials") + def test_get_boto3_session(self, mock_load_credentials, mock_load_config): + for kw in ( + {"fallback": False}, + {"fallback": True}, + {}, + ): + mock_load_credentials.reset_mock() + mock_load_config.reset_mock() + with self.subTest(kwargs=kw): + region = "us-west-2" + config = quilt3.util.load_config() + mock_load_config.return_value = { + **config, + "region": region, + } + + future_date = datetime.datetime.now() + datetime.timedelta(hours=1) + mock_load_credentials.return_value = dict( + access_key="access-key", + secret_key="secret-key", + token="session-token", + expiry_time=format_date(future_date), + ) + + session = quilt3.get_boto3_session() + mock_load_credentials.assert_called_once_with() + mock_load_config.assert_called_with() + + assert isinstance(session, boto3.Session) + credentials = session.get_credentials() + + assert credentials.access_key == "access-key" + assert credentials.secret_key == "secret-key" + assert credentials.token == "session-token" + + assert session.region_name == region + + @patch("quilt3.session.create_botocore_session") + @patch("quilt3.session._load_credentials", return_value={}) + def test_get_boto3_session_no_credentials_fallback_true(self, mock_load_credentials, mock_create_botocore_session): + session = quilt3.get_boto3_session() + mock_load_credentials.assert_called_once_with() + mock_create_botocore_session.assert_not_called() + + assert isinstance(session, boto3.Session) + + @patch("quilt3.session._load_credentials", return_value={}) + def test_get_boto3_session_no_credentials_fallback_false(self, mock_load_credentials): + with pytest.raises(quilt3.util.QuiltException) as exc_info: + quilt3.get_boto3_session(fallback=False) + + mock_load_credentials.assert_called_once_with() + assert "No Quilt credentials found" in str(exc_info.value) + def test_logged_in(self): registry_url = quilt3.session.get_registry_url() other_registry_url = registry_url + 'other' diff --git a/api/python/tests/test_util.py b/api/python/tests/test_util.py index cb929a6ce4c..3976a87927b 100644 --- a/api/python/tests/test_util.py +++ b/api/python/tests/test_util.py @@ -35,6 +35,21 @@ def test_write_yaml(tmpdir): assert fname.read_text('utf-8') == 'testing: bar\n' +@pytest.mark.parametrize("keep_backup", [False, True]) +def test_write_yaml_exception(tmp_path, keep_backup): + fname = tmp_path / "some_file.yml" + fname.write_text("42") + + exc = Exception("test exception") + with mock.patch("quilt3.util.yaml.dump", side_effect=exc): + with pytest.raises(Exception) as exc_info: + util.write_yaml("test", fname) + + assert exc_info.value == exc + assert fname.read_text("utf-8") == "42" + assert list(tmp_path.iterdir()) == [fname] + + def test_read_yaml(tmpdir): # Read a string parsed_string = util.read_yaml(TEST_YAML) diff --git a/catalog/.eslintrc.js b/catalog/.eslintrc.js index f8e68546c82..21a1a1d8d1d 100644 --- a/catalog/.eslintrc.js +++ b/catalog/.eslintrc.js @@ -51,6 +51,7 @@ module.exports = { 'max-classes-per-file': 0, 'no-console': 2, 'no-nested-ternary': 1, + 'no-restricted-globals': [2, "event", "location", "stop"], 'no-underscore-dangle': [2, { allow: ['_', '__', '__typename', '_tag'] }], 'prefer-arrow-callback': [2, { allowNamedFunctions: true }], 'prefer-template': 2, diff --git a/catalog/.graphqlrc.yml b/catalog/.graphqlrc.yml index 
297ac6c267b..1b2a191112c 100644 --- a/catalog/.graphqlrc.yml +++ b/catalog/.graphqlrc.yml @@ -1,5 +1,4 @@ schema: '../shared/graphql/schema.graphql' -documents: './app/**/*.{graphql,js,ts,tsx}' extensions: codegen: hooks: diff --git a/catalog/CHANGELOG.md b/catalog/CHANGELOG.md new file mode 100644 index 00000000000..b5ba576e5cc --- /dev/null +++ b/catalog/CHANGELOG.md @@ -0,0 +1,46 @@ + +# Changelog + +Changes are listed in reverse chronological order (newer entries at the top). +The entry format is + +```markdown +- [Verb] Change description ([#](https://github.com/quiltdata/quilt/pull/)) +``` + +where verb is one of + +- Removed +- Added +- Fixed +- Changed + +## Changes + +- [Added] Support "html" type in `quilt_summarize.json` ([#4252](https://github.com/quiltdata/quilt/pull/4252)) +- [Fixed] Resolve caching issues where changes in `.quilt/{workflows,catalog}` were not applied ([#4245](https://github.com/quiltdata/quilt/pull/4245)) +- [Added] A shortcut to enable adding files to a package from the current bucket ([#4245](https://github.com/quiltdata/quilt/pull/4245)) +- [Changed] Qurator: propagate error messages from Bedrock ([#4192](https://github.com/quiltdata/quilt/pull/4192)) +- [Added] Qurator Developer Tools ([#4192](https://github.com/quiltdata/quilt/pull/4192)) +- [Changed] JsonDisplay: handle dates and functions ([#4192](https://github.com/quiltdata/quilt/pull/4192)) +- [Fixed] Keep default Intercom launcher closed when closing Package Dialog ([#4244](https://github.com/quiltdata/quilt/pull/4244)) +- [Fixed] Handle invalid bucket name in `ui.sourceBuckets` in bucket config ([#4242](https://github.com/quiltdata/quilt/pull/4242)) +- [Added] Preview Markdown while editing ([#4153](https://github.com/quiltdata/quilt/pull/4153)) +- [Changed] Athena: hide data catalogs user doesn't have access to ([#4239](https://github.com/quiltdata/quilt/pull/4239)) +- [Added] Enable MixPanel tracking in Embed mode ([#4237](https://github.com/quiltdata/quilt/pull/4237)) +- [Fixed] Fix embed files listing ([#4236](https://github.com/quiltdata/quilt/pull/4236)) +- [Changed] Qurator: switch to Claude 3.5 Sonnet **v2** ([#4234](https://github.com/quiltdata/quilt/pull/4234)) +- [Changed] Add `catalog` fragment to Quilt+ URIs (and to documentation) ([#4213](https://github.com/quiltdata/quilt/pull/4213)) +- [Fixed] Athena: fix minor UI bugs ([#4232](https://github.com/quiltdata/quilt/pull/4232)) +- [Fixed] Show Athena query editor when no named queries ([#4230](https://github.com/quiltdata/quilt/pull/4230)) +- [Fixed] Fix some doc URLs in catalog ([#4205](https://github.com/quiltdata/quilt/pull/4205)) +- [Changed] S3 Select -> GQL API calls for getting access counts ([#4218](https://github.com/quiltdata/quilt/pull/4218)) +- [Changed] Athena: improve loading state and errors visuals; fix minor bugs; alphabetize and persist selection in workgroups, catalog names and databases ([#4208](https://github.com/quiltdata/quilt/pull/4208)) +- [Changed] Show stack release version in footer ([#4200](https://github.com/quiltdata/quilt/pull/4200)) +- [Added] Selective package downloading ([#4173](https://github.com/quiltdata/quilt/pull/4173)) +- [Added] Qurator Omni: initial public release ([#4032](https://github.com/quiltdata/quilt/pull/4032), [#4181](https://github.com/quiltdata/quilt/pull/4181)) +- [Added] Admin: UI for configuring longitudinal queries (Tabulator) ([#4135](https://github.com/quiltdata/quilt/pull/4135), [#4164](https://github.com/quiltdata/quilt/pull/4164), 
[#4165](https://github.com/quiltdata/quilt/pull/4165)) +- [Changed] Admin: Move bucket settings to a separate page ([#4122](https://github.com/quiltdata/quilt/pull/4122)) +- [Changed] Athena: always show catalog name, simplify setting execution context ([#4123](https://github.com/quiltdata/quilt/pull/4123)) +- [Added] Support `ui.actions.downloadObject` and `ui.actions.downloadPackage` options for configuring visibility of download buttons under "Bucket" and "Packages" respectively ([#4111](https://github.com/quiltdata/quilt/pull/4111)) +- [Added] Bootstrap the change log ([#4112](https://github.com/quiltdata/quilt/pull/4112)) diff --git a/catalog/Dockerfile b/catalog/Dockerfile index c747f568b3e..cfdd0e23740 100644 --- a/catalog/Dockerfile +++ b/catalog/Dockerfile @@ -1,4 +1,4 @@ -FROM amazonlinux:2023.3.20231218.0 +FROM amazonlinux:2023.6.20241121.0 MAINTAINER Quilt Data, Inc. contact@quiltdata.io ENV LC_ALL=C.UTF-8 @@ -16,18 +16,22 @@ RUN ln -sf /dev/stdout /var/log/nginx/access.log && \ COPY nginx.conf /etc/nginx/nginx.conf COPY nginx-web.conf /etc/nginx/conf.d/default.conf +ARG NGINX_STATIC_DIR=/usr/share/nginx/html # Copy pre-built catalog assets to nginx -RUN rm -rf /usr/share/nginx/html -COPY build /usr/share/nginx/html +RUN rm -rf $NGINX_STATIC_DIR +COPY build $NGINX_STATIC_DIR # Copy config file COPY config.json.tmpl config.json.tmpl +RUN ln -s /tmp/config.json $NGINX_STATIC_DIR/config.json && \ + ln -s /tmp/config.js $NGINX_STATIC_DIR/config.js + # Use SIGQUIT for a "graceful" shutdown STOPSIGNAL SIGQUIT # Substitute environment variables into config.json and generate config.js based on that before starting nginx. # Note: use "exec" because otherwise the shell will catch Ctrl-C and other signals. -CMD envsubst < config.json.tmpl > /usr/share/nginx/html/config.json \ - && echo "window.QUILT_CATALOG_CONFIG = `cat /usr/share/nginx/html/config.json`" > /usr/share/nginx/html/config.js \ +CMD envsubst < config.json.tmpl > /tmp/config.json \ + && echo "window.QUILT_CATALOG_CONFIG = `cat /tmp/config.json`" > /tmp/config.js \ && exec nginx -g 'daemon off;' diff --git a/catalog/README.md b/catalog/README.md index 5c3f310704d..933afe6b259 100644 --- a/catalog/README.md +++ b/catalog/README.md @@ -2,8 +2,6 @@ The catalog is a web frontend for browsing meta-data held by the Quilt registry. -# Developer - ## Configuration The app configuration (API endpoints, bucket federations, etc.) is read from @@ -41,8 +39,10 @@ $ npm start ### Fetch -- An accurate check for a successful fetch() would include checking that the promise resolved, then checking that the Response.ok property has a value of true. The code would look something like this: - [msdn fetch doc](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API/Using_Fetch) +- An accurate check for a successful fetch() would include checking that +the promise resolved, then checking that the Response.ok property has +a value of true. 
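In practice that is roughly (a minimal sketch; error handling will vary):

```js
async function getJson(url) {
  const response = await fetch(url) // rejects only on network failure
  if (!response.ok) throw new Error(`HTTP ${response.status}`)
  return response.json()
}
```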
The code would look something like this: + [msdn fetch doc](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API/Using_Fetch) ## Deployment (for Quilt internal usage) diff --git a/catalog/app/app.tsx b/catalog/app/app.tsx index 5b6966a3d90..4fa88afa393 100644 --- a/catalog/app/app.tsx +++ b/catalog/app/app.tsx @@ -22,11 +22,10 @@ Sentry.init(cfg, history) import 'sanitize.css' // Import the rest of our modules -import { ExperimentsProvider } from 'components/Experiments' +import * as Assistant from 'components/Assistant' import * as Intercom from 'components/Intercom' import Placeholder from 'components/Placeholder' import App from 'containers/App' -import GTMLoader from 'utils/gtm' import * as Auth from 'containers/Auth' import * as Errors from 'containers/Errors' import * as Notifications from 'containers/Notifications' @@ -39,8 +38,10 @@ import * as APIConnector from 'utils/APIConnector' import * as GraphQL from 'utils/GraphQL' import { BucketCacheProvider } from 'utils/BucketCache' import GlobalAPI from 'utils/GlobalAPI' +import WithGlobalDialogs from 'utils/GlobalDialogs' import log from 'utils/Logging' import * as NamedRoutes from 'utils/NamedRoutes' +import { PFSCookieManager } from 'utils/PFSCookieManager' import * as Cache from 'utils/ResourceCache' import * as Store from 'utils/Store' import fontLoader from 'utils/fontLoader' @@ -104,7 +105,6 @@ const render = () => { Notifications.Provider, [APIConnector.Provider, { fetch, middleware: [Auth.apiMiddleware] }], [Auth.Provider, { storage }], - [GTMLoader, { gtmId: cfg.gtmId }], [ Intercom.Provider, { @@ -115,15 +115,18 @@ const render = () => { vertical_padding: 59, }, ], - ExperimentsProvider, [Tracking.Provider, { userSelector: Auth.selectors.username }], AWS.Credentials.Provider, AWS.Config.Provider, AWS.Athena.Provider, AWS.S3.Provider, + Assistant.Provider, + Assistant.WithUI, Notifications.WithNotifications, + WithGlobalDialogs, Errors.ErrorBoundary, BucketCacheProvider, + PFSCookieManager, App, ), MOUNT_NODE, diff --git a/catalog/app/components/Assistant/Model/Assistant.tsx b/catalog/app/components/Assistant/Model/Assistant.tsx new file mode 100644 index 00000000000..74667acf28c --- /dev/null +++ b/catalog/app/components/Assistant/Model/Assistant.tsx @@ -0,0 +1,96 @@ +import * as Eff from 'effect' +import invariant from 'invariant' + +import * as React from 'react' + +import * as AWS from 'utils/AWS' +import * as Actor from 'utils/Actor' + +import * as Bedrock from './Bedrock' +import * as Context from './Context' +import * as Conversation from './Conversation' +import * as GlobalContext from './GlobalContext' +import useIsEnabled from './enabled' + +export const DISABLED = Symbol('DISABLED') + +function usePassThru(val: T) { + const ref = React.useRef(val) + ref.current = val + return ref +} + +function useConstructAssistantAPI() { + const passThru = usePassThru({ + bedrock: AWS.Bedrock.useClient(), + context: Context.useLayer(), + }) + const layerEff = Eff.Effect.sync(() => + Eff.Layer.merge( + Bedrock.LLMBedrock(passThru.current.bedrock), + passThru.current.context, + ), + ) + const [state, dispatch] = Actor.useActorLayer( + Conversation.ConversationActor, + Conversation.init, + layerEff, + ) + + GlobalContext.use() + + // XXX: move this to actor state? 
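  // Conversation state lives in the actor above; pure UI state (panel
  // visibility) is kept in plain React state here, hence the XXX: folding it
  // into the actor would put all assistant state in one place.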
+ const [visible, setVisible] = React.useState(false) + const show = React.useCallback(() => setVisible(true), []) + const hide = React.useCallback(() => setVisible(false), []) + + const assist = React.useCallback( + (msg?: string) => { + if (msg) dispatch(Conversation.Action.Ask({ content: msg })) + show() + }, + [show, dispatch], + ) + + return { + visible, + show, + hide, + assist, + state, + dispatch, + } +} + +export type AssistantAPI = ReturnType +export type { AssistantAPI as API } + +const Ctx = React.createContext(null) + +function AssistantAPIProvider({ children }: React.PropsWithChildren<{}>) { + return {children} +} + +function DisabledAPIProvider({ children }: React.PropsWithChildren<{}>) { + return {children} +} + +export function AssistantProvider({ children }: React.PropsWithChildren<{}>) { + return useIsEnabled() ? ( + + {children} + + ) : ( + {children} + ) +} + +export function useAssistantAPI() { + const api = React.useContext(Ctx) + invariant(api, 'AssistantAPI must be used within an AssistantProvider') + return api === DISABLED ? null : api +} + +export function useAssistant() { + return useAssistantAPI()?.assist +} diff --git a/catalog/app/components/Assistant/Model/Bedrock.ts b/catalog/app/components/Assistant/Model/Bedrock.ts new file mode 100644 index 00000000000..11501f73697 --- /dev/null +++ b/catalog/app/components/Assistant/Model/Bedrock.ts @@ -0,0 +1,160 @@ +import type * as AWSSDK from 'aws-sdk' +import BedrockRuntime from 'aws-sdk/clients/bedrockruntime' +import * as Eff from 'effect' + +import * as Log from 'utils/Logging' + +import * as Content from './Content' +import * as LLM from './LLM' + +const MODULE = 'Bedrock' + +const MODEL_ID = 'us.anthropic.claude-3-5-sonnet-20241022-v2:0' + +const mapContent = (contentBlocks: BedrockRuntime.ContentBlocks | undefined) => + Eff.pipe( + contentBlocks, + Eff.Option.fromNullable, + Eff.Option.map( + Eff.Array.flatMapNullable((c) => { + if (c.document) { + return Content.ResponseMessageContentBlock.Document({ + format: c.document.format as $TSFixMe, + source: c.document.source.bytes as $TSFixMe, + name: c.document.name, + }) + } + if (c.image) { + return Content.ResponseMessageContentBlock.Image({ + format: c.image.format as $TSFixMe, + source: c.image.source.bytes as $TSFixMe, + }) + } + if (c.text) { + return Content.ResponseMessageContentBlock.Text({ text: c.text }) + } + if (c.toolUse) { + return Content.ResponseMessageContentBlock.ToolUse(c.toolUse as $TSFixMe) + } + // if (c.guardContent) { + // // TODO + // return acc + // } + // if (c.toolResult) { + // // XXX: is it really supposed to occur here in LLM response? 
+ // return acc + // } + return null + }), + ), + ) + +// TODO: use Schema +const contentToBedrock = Content.PromptMessageContentBlock.$match({ + GuardContent: ({ text }) => ({ guardContent: { text: { text } } }), + ToolResult: ({ toolUseId, status, content }) => ({ + toolResult: { + toolUseId, + status, + content: content.map( + Content.ToolResultContentBlock.$match({ + Json: ({ _tag, ...rest }) => rest, + Text: ({ _tag, ...rest }) => rest, + // XXX: be careful with base64/non-base64 encoding + Image: ({ format, source }) => ({ + image: { format, source: { bytes: source } }, + }), + Document: ({ format, source, name }) => ({ + document: { format, source: { bytes: source }, name }, + }), + }), + ), + }, + }), + ToolUse: ({ _tag, ...toolUse }) => ({ toolUse }), + Text: ({ _tag, ...rest }) => rest, + Image: ({ format, source }) => ({ image: { format, source: { bytes: source } } }), + Document: ({ format, source, name }) => ({ + document: { format, source: { bytes: source }, name }, + }), +}) + +const messagesToBedrock = ( + messages: Eff.Array.NonEmptyArray, +): BedrockRuntime.Message[] => + // create an array of alternating assistant and user messages + Eff.pipe( + messages, + Eff.Array.groupWith((m1, m2) => m1.role === m2.role), + Eff.Array.map((group) => ({ + role: group[0].role, + content: group.map((m) => contentToBedrock(m.content)), + })), + ) + +const toolConfigToBedrock = ( + toolConfig: LLM.ToolConfig, +): BedrockRuntime.ToolConfiguration => ({ + tools: Object.entries(toolConfig.tools).map(([name, { description, schema }]) => ({ + toolSpec: { + name, + description, + inputSchema: { json: schema }, + }, + })), + toolChoice: + toolConfig.choice && + LLM.ToolChoice.$match(toolConfig.choice, { + Auto: () => ({ auto: {} }), + Any: () => ({ any: {} }), + Specific: ({ name }) => ({ tool: { name } }), + }), +}) + +function isAWSError(e: any): e is AWSSDK.AWSError { + return e.code !== undefined && e.message !== undefined +} + +// a layer providing the service over aws.bedrock +export function LLMBedrock(bedrock: BedrockRuntime) { + const converse = (prompt: LLM.Prompt, opts?: LLM.Options) => + Log.scoped({ + name: `${MODULE}.converse`, + enter: [ + Log.br, + 'model id:', + MODEL_ID, + Log.br, + 'prompt:', + prompt, + Log.br, + 'opts:', + opts, + ], + })( + Eff.Effect.tryPromise({ + try: () => + bedrock + .converse({ + modelId: MODEL_ID, + system: [{ text: prompt.system }], + messages: messagesToBedrock(prompt.messages), + toolConfig: prompt.toolConfig && toolConfigToBedrock(prompt.toolConfig), + ...opts, + }) + .promise() + .then((backendResponse) => ({ + backendResponse, + content: mapContent(backendResponse.output.message?.content), + })), + catch: (e) => + new LLM.LLMError({ + message: isAWSError(e) + ? `Bedrock error (${e.code}): ${e.message}` + : `Unexpected error: ${e}`, + }), + }), + ) + + return Eff.Layer.succeed(LLM.LLM, { converse }) +} diff --git a/catalog/app/components/Assistant/Model/Content.ts b/catalog/app/components/Assistant/Model/Content.ts new file mode 100644 index 00000000000..9c45cde3ee3 --- /dev/null +++ b/catalog/app/components/Assistant/Model/Content.ts @@ -0,0 +1,107 @@ +import * as Eff from 'effect' + +import { JsonRecord } from 'utils/types' + +// XXX: schema for en/decoding to/from aws bedrock types? 
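// For orientation: the tagged enums below mirror the content-block shapes of
// Bedrock's `converse` API. Bedrock.ts flattens them on the way out, so a
// prompt Text block serializes to `{ text: '...' }` inside
// `{ role, content: [...] }`, and a ToolResult to
// `{ toolResult: { toolUseId, status, content: [...] } }`.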
+ +export const DOCUMENT_FORMATS = [ + 'pdf', + 'csv', + 'doc', + 'docx', + 'xls', + 'xlsx', + 'html', + 'txt', + 'md', +] as const +export type DocumentFormat = (typeof DOCUMENT_FORMATS)[number] + +export interface DocumentBlock { + format: DocumentFormat + name: string + // A base64-encoded string of a UTF-8 encoded file, that is the document to include in the message. + source: Buffer | Uint8Array | Blob | string +} + +export const IMAGE_FORMATS = ['png', 'jpeg', 'gif', 'webp'] as const +export type ImageFormat = (typeof IMAGE_FORMATS)[number] + +export interface ImageBlock { + format: ImageFormat + // The raw image bytes for the image. If you use an AWS SDK, you don't need to base64 encode the image bytes. + source: Buffer | Uint8Array | Blob | string +} + +export interface JsonBlock { + json: JsonRecord +} + +export interface TextBlock { + text: string +} + +export interface GuardBlock { + text: string +} + +export interface ToolUseBlock { + toolUseId: string + name: string + input: JsonRecord +} + +export type ToolResultContentBlock = Eff.Data.TaggedEnum<{ + Json: JsonBlock + Text: TextBlock + Image: ImageBlock + Document: DocumentBlock +}> + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const ToolResultContentBlock = Eff.Data.taggedEnum() + +export type ToolResultStatus = 'success' | 'error' + +export interface ToolResultBlock { + toolUseId: string + content: ToolResultContentBlock[] + status: ToolResultStatus +} + +export type ResponseMessageContentBlock = Eff.Data.TaggedEnum<{ + // GuardContent: {} + // ToolResult: {} + ToolUse: ToolUseBlock + Text: TextBlock + Image: ImageBlock + Document: DocumentBlock +}> + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const ResponseMessageContentBlock = + Eff.Data.taggedEnum() + +export type MessageContentBlock = Eff.Data.TaggedEnum<{ + Text: TextBlock + Image: ImageBlock + Document: DocumentBlock +}> + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const MessageContentBlock = Eff.Data.taggedEnum() + +export type PromptMessageContentBlock = Eff.Data.TaggedEnum<{ + GuardContent: GuardBlock + ToolResult: ToolResultBlock + ToolUse: ToolUseBlock + Text: TextBlock + Image: ImageBlock + Document: DocumentBlock +}> + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const PromptMessageContentBlock = Eff.Data.taggedEnum() + +export const text = (first: string, ...rest: string[]) => + PromptMessageContentBlock.Text({ text: [first, ...rest].join('\n') }) diff --git a/catalog/app/components/Assistant/Model/Context.tsx b/catalog/app/components/Assistant/Model/Context.tsx new file mode 100644 index 00000000000..9b92a413129 --- /dev/null +++ b/catalog/app/components/Assistant/Model/Context.tsx @@ -0,0 +1,139 @@ +import * as Eff from 'effect' +import invariant from 'invariant' +import * as R from 'ramda' +import * as React from 'react' +import * as uuid from 'uuid' + +import { runtime } from 'utils/Effect' +import useConst from 'utils/useConstant' +import useMemoEq from 'utils/useMemoEq' + +import * as Tool from './Tool' +import useIsEnabled from './enabled' + +export interface ContextShape { + messages: string[] + tools: Tool.Collection + markers: Record +} + +interface ContextAggregator { + push: (context: Partial) => () => void + counter: number + getValues: () => Partial[] +} + +export const ContextAggregatorCtx = React.createContext(null) + +export function ContextAggregatorProvider({ children }: React.PropsWithChildren<{}>) { + const mountedRef = React.useRef>>({}) + 
const [counter, setCounter] = React.useState(0) + + const push = React.useCallback( + (context: Partial) => { + const id = uuid.v4() + mountedRef.current[id] = context + setCounter((c) => c + 1) + return () => { + delete mountedRef.current[id] + } + }, + [mountedRef], + ) + + const getValues = React.useCallback( + () => Object.values(mountedRef.current), + [mountedRef], + ) + + const value = { push, counter, getValues } + + return ( + + {children} + + ) +} + +const ROOT_CONTEXT: ContextShape = { + tools: {}, + messages: [], + markers: {}, +} + +function aggregateContext(contexts: Partial[]) { + return contexts.reduce( + (acc: ContextShape, next) => ({ + // XXX: check for conflicts? + tools: { ...acc.tools, ...next.tools }, + messages: acc.messages.concat(next.messages || []), + markers: { ...acc.markers, ...next.markers }, + }), + ROOT_CONTEXT, + ) +} + +export function useAggregatedContext(): ContextShape { + const ctx = React.useContext(ContextAggregatorCtx) + invariant(ctx, 'ContextAggregator must be used within a ContextAggregatorProvider') + + const { getValues, counter } = ctx + const [computed, setComputed] = React.useState(() => aggregateContext(getValues())) + + React.useEffect(() => { + const values = getValues() + const aggregated = aggregateContext(values) + setComputed(aggregated) + }, [setComputed, getValues, counter]) + + return computed +} + +export function usePushContext(context: Partial) { + const ctx = React.useContext(ContextAggregatorCtx) + invariant(ctx, 'ContextAggregator must be used within a ContextAggregatorProvider') + const { push } = ctx + const contextMemo = useMemoEq(context, R.identity) + React.useEffect(() => push(contextMemo), [push, contextMemo]) +} + +export function Push(context: Partial) { + usePushContext(context) + return null +} + +export type ContextProviderHook = (props: Props) => Partial + +export function LazyContext(useContext: ContextProviderHook) { + function ProvideContext(props: Props) { + usePushContext(useContext(props)) + return null + } + return function WithLazyContext(props: Props & JSX.IntrinsicAttributes) { + return useIsEnabled() ? 
: null + } +} + +export class ConversationContext extends Eff.Context.Tag('ConversationContext')< + ConversationContext, + { + context: Eff.Effect.Effect + } +>() {} + +export function useLayer() { + const contextObj = useAggregatedContext() + const ref = React.useRef(contextObj) + if (ref.current !== contextObj) ref.current = contextObj + const context = React.useMemo(() => Eff.Effect.sync(() => ref.current), [ref]) + return Eff.Layer.succeed(ConversationContext, { context }) +} + +export function useMarkersRef() { + const { markers } = useAggregatedContext() + const ref = useConst(() => runtime.runSync(Eff.SubscriptionRef.make(markers))) + React.useEffect(() => { + runtime.runFork(Eff.SubscriptionRef.set(ref, markers)) + }, [markers, ref]) + return ref +} diff --git a/catalog/app/components/Assistant/Model/Conversation.ts b/catalog/app/components/Assistant/Model/Conversation.ts new file mode 100644 index 00000000000..50a7ee8053e --- /dev/null +++ b/catalog/app/components/Assistant/Model/Conversation.ts @@ -0,0 +1,431 @@ +import * as Eff from 'effect' +import * as uuid from 'uuid' + +import * as Actor from 'utils/Actor' +import * as Log from 'utils/Logging' +import * as XML from 'utils/XML' + +import * as Content from './Content' +import * as Context from './Context' +import * as LLM from './LLM' +import * as Tool from './Tool' + +const MODULE = 'Conversation' + +// TODO: make this a globally available service? +const genId = Eff.Effect.sync(uuid.v4) + +// TODO: use effect/DateTime after upgrading +const getNow = Eff.Clock.currentTimeMillis.pipe(Eff.Effect.map((t) => new Date(t))) + +export interface ToolCall { + readonly name: string + readonly input: Record + readonly fiber: Eff.Fiber.RuntimeFiber +} + +export type ToolUseId = string + +export type ToolCalls = Record + +interface EventBase { + readonly id: string + readonly timestamp: Date + readonly discarded?: boolean +} + +export type Event = Eff.Data.TaggedEnum<{ + Message: EventBase & { + readonly role: 'user' | 'assistant' + readonly content: Content.MessageContentBlock + } + ToolUse: EventBase & { + readonly toolUseId: string + readonly name: string + readonly input: Record + readonly result: Tool.Result + } +}> + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const Event = Eff.Data.taggedEnum() + +interface ConversationError { + message: string + details: string +} + +interface StateBase { + readonly events: Event[] + readonly timestamp: Date +} + +export type State = Eff.Data.TaggedEnum<{ + /** + * Waiting for user input + */ + Idle: StateBase & { + readonly error: Eff.Option.Option + } + + /** + * Waiting for assistant (LLM) to respond + */ + WaitingForAssistant: StateBase & { + readonly requestFiber: Eff.Fiber.RuntimeFiber + } + + /** + * Tool use in progress + */ + ToolUse: StateBase & { + // TODO: use HashMap? 
+ readonly calls: Record + // readonly retries: number + } +}> + +const idle = (events: Event[], error?: ConversationError) => + Eff.Effect.map(getNow, (timestamp) => + State.Idle({ events, timestamp, error: Eff.Option.fromNullable(error) }), + ) + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const State = Eff.Data.taggedEnum() + +export type Action = Eff.Data.TaggedEnum<{ + Ask: { + readonly content: string + } + LLMError: { + readonly error: LLM.LLMError + } + LLMResponse: { + readonly content: Exclude[] + readonly toolUses: Extract[] + } + ToolUse: { + readonly toolUseId: string + readonly name: string + readonly input: Record + } + ToolResult: { + readonly id: string + readonly result: Tool.ResultOption + } + Abort: {} + Clear: {} + Discard: { readonly id: string } +}> + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const Action = Eff.Data.taggedEnum() + +export const init = Eff.Effect.gen(function* () { + return State.Idle({ + timestamp: yield* getNow, + events: [], + error: Eff.Option.none(), + }) +}) + +const llmRequest = (events: Event[]) => + Log.scoped({ + name: `${MODULE}.llmRequest`, + enter: [Log.br, 'events:', events], + })( + Eff.Effect.gen(function* () { + const llm = yield* LLM.LLM + const ctxService = yield* Context.ConversationContext + const ctx = yield* ctxService.context + const filteredEvents = events.filter((e) => !e.discarded) + const prompt = yield* constructPrompt(filteredEvents, ctx) + + const response = yield* llm.converse(prompt) + + if (Eff.Option.isNone(response.content)) { + return yield* Eff.Effect.fail( + new LLM.LLMError({ message: 'No content in LLM response' }), + ) + } + + const [toolUses, content] = Eff.Array.partitionMap(response.content.value, (c) => + c._tag === 'ToolUse' ? Eff.Either.left(c) : Eff.Either.right(c), + ) + + return { content, toolUses } + }), + ) + +// XXX: separate "service" from handlers +export const ConversationActor = Eff.Effect.succeed( + Log.scopedFn(`${MODULE}.ConversationActor`)( + Actor.taggedHandler({ + Idle: { + Ask: (state, action, dispatch) => + Eff.Effect.gen(function* () { + const timestamp = yield* getNow + const event = Event.Message({ + id: yield* genId, + timestamp, + role: 'user', + content: Content.text(action.content), + }) + const events = state.events.concat(event) + + const requestFiber = yield* Actor.forkRequest( + llmRequest(events), + dispatch, + (r) => Eff.Effect.succeed(Action.LLMResponse(r)), + (error) => Eff.Effect.succeed(Action.LLMError({ error })), + ) + return State.WaitingForAssistant({ events, timestamp, requestFiber }) + }), + Clear: () => idle([]), + Discard: (state, { id }) => + Eff.Effect.succeed({ + ...state, + events: state.events.map((e) => + e.id === id ? 
{ ...e, discarded: true } : e, + ), + }), + }, + WaitingForAssistant: { + LLMError: ({ events }, { error }) => + idle(events, { + message: 'Error while interacting with LLM.', + details: error.message, + }), + LLMResponse: (state, { content, toolUses }, dispatch) => + Eff.Effect.gen(function* () { + const timestamp = yield* getNow + + let { events } = state + if (content.length) { + events = events.concat( + yield* Eff.Effect.all( + content.map((c) => + Eff.Effect.andThen(genId, (id) => + Event.Message({ + id, + timestamp, + role: 'assistant', + content: c, + }), + ), + ), + ), + ) + } + + if (!toolUses.length) { + return State.Idle({ events, timestamp, error: Eff.Option.none() }) + } + + const ctxService = yield* Context.ConversationContext + const { tools } = yield* ctxService.context + const calls: Record = {} + for (const tu of toolUses) { + const fiber = yield* Eff.Effect.fork( + Eff.Effect.gen(function* () { + const result = yield* Tool.execute(tools, tu.name, tu.input) + yield* dispatch(Action.ToolResult({ id: tu.toolUseId, result })) + }), + ) + calls[tu.toolUseId] = { + name: tu.name, + input: tu.input, + fiber, + } + } + + return State.ToolUse({ events, timestamp: state.timestamp, calls }) + }), + Abort: ({ events, requestFiber }) => + Eff.Effect.gen(function* () { + // interrupt current request fiber and go back to idle + yield* Eff.Fiber.interruptFork(requestFiber) + + return State.Idle({ + events, + timestamp: yield* getNow, + error: Eff.Option.none(), + }) + }), + }, + ToolUse: { + ToolResult: (state, { id, result }, dispatch) => + Eff.Effect.gen(function* () { + if (!(id in state.calls)) return state + + const calls = { ...state.calls } + const call = calls[id] + delete calls[id] + + let events = state.events + if (Eff.Option.isSome(result)) { + const event = Event.ToolUse({ + id: yield* genId, + timestamp: yield* getNow, + toolUseId: id, + name: call.name, + input: call.input, + result: result.value, + }) + events = events.concat(event) + } + + if (Object.keys(calls).length) { + // some calls still in progress + return State.ToolUse({ events, timestamp: state.timestamp, calls }) + } + + // all calls completed, send results back to LLM + const requestFiber = yield* Actor.forkRequest( + llmRequest(events), + dispatch, + (r) => Eff.Effect.succeed(Action.LLMResponse(r)), + (error) => Eff.Effect.succeed(Action.LLMError({ error })), + ) + + return State.WaitingForAssistant({ + events, + timestamp: yield* getNow, + requestFiber, + }) + }), + Abort: ({ events, calls }) => + Eff.Effect.gen(function* () { + // interrupt current tool use fibers and go back to idle + yield* Eff.pipe( + calls, + Eff.Record.collect((_k, v) => v.fiber), + Eff.Array.map(Eff.Fiber.interruptFork), + Eff.Effect.all, + ) + + return State.Idle({ + events, + timestamp: yield* getNow, + error: Eff.Option.none(), + }) + }), + }, + }), + ), +) + +const NAME = 'Qurator' + +const SYSTEM = ` +You are ${NAME}, an AI assistant created by Quilt Data. +Your primary purpose is assisting users of Quilt Data products. +Persona: conservative and creative scientist. +` + +// TODO: mention the client company? 
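+// Illustrative summary (editorial note, not part of the change):
+// ConversationActor is a state machine over Idle / WaitingForAssistant /
+// ToolUse. A typical round trip, assuming `dispatch` comes from the actor:
+//
+//   Idle --Ask--> WaitingForAssistant --LLMResponse (text only)--> Idle
+//   Idle --Ask--> WaitingForAssistant --LLMResponse (tool uses)--> ToolUse
+//        --ToolResult (last pending call)--> WaitingForAssistant --...--> Idle
+//
+//   dispatch(Action.Ask({ content: 'find CSV files in quilt-example' }))
+//
+// Abort interrupts the in-flight request or tool-call fibers and returns the
+// actor to Idle.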
+const TASK_CONTEXT = XML.tag( + 'task-context', + {}, + 'You act as a chatbot deployed on the Quilt Catalog web app.', + 'You have access to the the Quilt Catalog UI via context and tools.', +).toString() + +// detailed task description and rules +const TASK_DESCRIPTION = XML.tag( + 'task-description', + {}, + 'When asked a question about Quilt or Quilt Data, refer to the documentation at https://docs.quiltdata.com.', +).toString() + +const CONVERSATION_START = ` +Following is the conversation history: +### CONVERSATION START ### +` + +const CONVERSATION_END = ` +### CONVERSATION END ### +` + +const IMMEDIATE_TASK = ` +Advance the provided conversation in the most helpful way possible. +Use tools proactively, but don't mention that unnecessarily, so that it feels transparent. + +Think step by step and carefully analyze the provided context to prevent giving +incomplete or inaccurate information. + +Never make things up, always double-check your responses and use the context +to your advantage. + +Use GitHub Flavored Markdown syntax for formatting when appropriate. +` + +export const constructPrompt = ( + events: Event[], + context: Context.ContextShape, +): Eff.Effect.Effect => + Log.scoped({ + name: `${MODULE}.constructPrompt`, + enter: [Log.br, 'events:', events, Log.br, 'context:', context], + })( + Eff.Effect.gen(function* () { + // XXX: add context about quilt products? + // XXX: add context about catalog structure and features? + + const [msgEvents, toolEvents] = Eff.Array.partitionMap( + events, + Event.$match({ + Message: (m) => Eff.Either.left(m), + ToolUse: (t) => Eff.Either.right(t), + }), + ) + + const toolMessages = Eff.Array.flatMap( + toolEvents, + ({ toolUseId, name, input, result }) => [ + LLM.assistantMessage( + Content.PromptMessageContentBlock.ToolUse({ toolUseId, name, input }), + ), + LLM.userMessage( + Content.PromptMessageContentBlock.ToolResult({ toolUseId, ...result }), + ), + ], + ) + + const currentTime = (yield* getNow).toISOString() + + // prompt structure + // - task context + // - tone context + // - Background data, documents, and images + // - detailed task description and rules + // - examples + // - input data + // - conversation history + // - user input + // - immediate task + // - precognition + // - output formatting + // - prefill + const messages: Eff.Array.NonEmptyArray = [ + LLM.userMessage( + Content.text( + TASK_CONTEXT, + TASK_DESCRIPTION, + XML.tag('context', {}, ...context.messages).toString(), + XML.tag('current-time', {}, currentTime).toString(), + CONVERSATION_START, + ), + ), + ...msgEvents.map(({ role, content }) => LLM.PromptMessage({ role, content })), + LLM.userMessage(Content.text(CONVERSATION_END, IMMEDIATE_TASK)), + ...toolMessages, + ] + + return { + system: SYSTEM, + messages, + toolConfig: { tools: context.tools }, + } + }), + ) diff --git a/catalog/app/components/Assistant/Model/GlobalContext/__snapshots__/navigation.spec.ts.snap b/catalog/app/components/Assistant/Model/GlobalContext/__snapshots__/navigation.spec.ts.snap new file mode 100644 index 00000000000..8fba6749d1c --- /dev/null +++ b/catalog/app/components/Assistant/Model/GlobalContext/__snapshots__/navigation.spec.ts.snap @@ -0,0 +1,803 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`components/Assistant/Model/GlobalTools/navigation NavigateSchema produced JSON Schema should match the snapshot 1`] = ` +{ + "$defs": { + "BooleanPredicateState": { + "additionalProperties": false, + "properties": { + "false": { + "type": "boolean", + }, + "true": { + "type": 
"boolean", + }, + }, + "required": [], + "title": "Boolean predicate state", + "type": "object", + }, + "BooleanTaggedPredicate": { + "additionalProperties": false, + "properties": { + "type": { + "enum": [ + "Boolean", + ], + }, + "value": { + "$ref": "#/$defs/BooleanPredicateState", + }, + }, + "required": [ + "type", + "value", + ], + "type": "object", + }, + "DatetimePredicateState": { + "additionalProperties": false, + "properties": { + "gte": { + "description": "ISO date string", + "format": "date-time", + "type": "string", + }, + "lte": { + "description": "ISO date string", + "format": "date-time", + "type": "string", + }, + }, + "required": [], + "title": "Datetime predicate state", + "type": "object", + }, + "DatetimeTaggedPredicate": { + "additionalProperties": false, + "properties": { + "type": { + "enum": [ + "Datetime", + ], + }, + "value": { + "$ref": "#/$defs/DatetimePredicateState", + }, + }, + "required": [ + "type", + "value", + ], + "type": "object", + }, + "KeywordEnumPredicateState": { + "additionalProperties": false, + "properties": { + "terms": { + "items": { + "type": "string", + }, + "type": "array", + }, + }, + "required": [], + "title": "KeywordEnum predicate state", + "type": "object", + }, + "KeywordEnumTaggedPredicate": { + "additionalProperties": false, + "properties": { + "type": { + "enum": [ + "KeywordEnum", + ], + }, + "value": { + "$ref": "#/$defs/KeywordEnumPredicateState", + }, + }, + "required": [ + "type", + "value", + ], + "type": "object", + }, + "KeywordWildcardPredicateState": { + "additionalProperties": false, + "properties": { + "wildcard": { + "type": "string", + }, + }, + "required": [], + "title": "KeywordWildcard predicate state", + "type": "object", + }, + "KeywordWildcardTaggedPredicate": { + "additionalProperties": false, + "properties": { + "type": { + "enum": [ + "KeywordWildcard", + ], + }, + "value": { + "$ref": "#/$defs/KeywordWildcardPredicateState", + }, + }, + "required": [ + "type", + "value", + ], + "type": "object", + }, + "NumberPredicateState": { + "additionalProperties": false, + "properties": { + "gte": { + "type": "number", + }, + "lte": { + "type": "number", + }, + }, + "required": [], + "title": "Number predicate state", + "type": "object", + }, + "NumberTaggedPredicate": { + "additionalProperties": false, + "properties": { + "type": { + "enum": [ + "Number", + ], + }, + "value": { + "$ref": "#/$defs/NumberPredicateState", + }, + }, + "required": [ + "type", + "value", + ], + "type": "object", + }, + "TextPredicateState": { + "additionalProperties": false, + "properties": { + "queryString": { + "type": "string", + }, + }, + "required": [], + "title": "Text predicate state", + "type": "object", + }, + "TextTaggedPredicate": { + "additionalProperties": false, + "properties": { + "type": { + "enum": [ + "Text", + ], + }, + "value": { + "$ref": "#/$defs/TextPredicateState", + }, + }, + "required": [ + "type", + "value", + ], + "type": "object", + }, + }, + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": false, + "description": "navigate to a provided route", + "properties": { + "route": { + "anyOf": [ + { + "additionalProperties": false, + "description": "Home page", + "properties": { + "name": { + "enum": [ + "home", + ], + }, + "params": { + "$id": "/schemas/{}", + "anyOf": [ + { + "type": "object", + }, + { + "type": "array", + }, + ], + }, + }, + "required": [ + "name", + "params", + ], + "type": "object", + }, + { + "additionalProperties": false, + "description": "Installation page", + 
"properties": { + "name": { + "enum": [ + "install", + ], + }, + "params": { + "$id": "/schemas/{}", + "anyOf": [ + { + "type": "object", + }, + { + "type": "array", + }, + ], + }, + }, + "required": [ + "name", + "params", + ], + "type": "object", + }, + { + "additionalProperties": false, + "description": "Search page", + "properties": { + "name": { + "enum": [ + "search", + ], + }, + "params": { + "additionalProperties": false, + "properties": { + "buckets": { + "description": "A list of buckets to search in (keep empty to search in all buckets)", + "items": { + "type": "string", + }, + "title": "Search buckets", + "type": "array", + }, + "order": { + "$comment": "/schemas/enums", + "anyOf": [ + { + "enum": [ + "BEST_MATCH", + ], + "title": "BEST_MATCH", + }, + { + "enum": [ + "NEWEST", + ], + "title": "NEWEST", + }, + { + "enum": [ + "OLDEST", + ], + "title": "OLDEST", + }, + ], + "description": "Order of search results", + "title": "Search result order", + }, + "params": { + "anyOf": [ + { + "additionalProperties": false, + "properties": { + "filter": { + "description": "Filter results by system metadata", + "items": { + "anyOf": [ + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "modified", + ], + }, + "predicate": { + "$ref": "#/$defs/DatetimePredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "Last modified date", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "size", + ], + }, + "predicate": { + "$ref": "#/$defs/NumberPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "File size in bytes", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "name", + ], + }, + "predicate": { + "$ref": "#/$defs/KeywordWildcardPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "Package name (aka namespace or handle)", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "hash", + ], + }, + "predicate": { + "$ref": "#/$defs/KeywordWildcardPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "Package revision hash", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "entries", + ], + }, + "predicate": { + "$ref": "#/$defs/NumberPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "Number of package entries", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "comment", + ], + }, + "predicate": { + "$ref": "#/$defs/TextPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "Package revision comment (aka commit message)", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "workflow", + ], + }, + "predicate": { + "$ref": "#/$defs/KeywordEnumPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "Package workflow", + "type": "object", + }, + ], + }, + "title": "Result filters (system metadata)", + "type": "array", + }, + "resultType": { + "enum": [ + "p", + ], + "title": "Result type: Quilt Package", + }, + "userMetaFilters": { + "description": "Filter results by user metadata", + "items": { + "additionalProperties": false, + "properties": { + "path": { + "type": "string", + }, + "predicate": { + "anyOf": [ + { + "$ref": 
"#/$defs/BooleanTaggedPredicate", + }, + { + "$ref": "#/$defs/DatetimeTaggedPredicate", + }, + { + "$ref": "#/$defs/KeywordEnumTaggedPredicate", + }, + { + "$ref": "#/$defs/KeywordWildcardTaggedPredicate", + }, + { + "$ref": "#/$defs/NumberTaggedPredicate", + }, + { + "$ref": "#/$defs/TextTaggedPredicate", + }, + ], + }, + }, + "required": [ + "path", + "predicate", + ], + "type": "object", + }, + "title": "Result filters (user metadata)", + "type": "array", + }, + }, + "required": [ + "resultType", + ], + "title": "Package-specific search parameters", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "filter": { + "items": { + "anyOf": [ + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "modified", + ], + }, + "predicate": { + "$ref": "#/$defs/DatetimePredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "Last modified date", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "size", + ], + }, + "predicate": { + "$ref": "#/$defs/NumberPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "File size in bytes", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "ext", + ], + }, + "predicate": { + "$ref": "#/$defs/KeywordEnumPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "File extensions (with a leading dot)", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "key", + ], + }, + "predicate": { + "$ref": "#/$defs/KeywordWildcardPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "File name (aka S3 Object Key)", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "content", + ], + }, + "predicate": { + "$ref": "#/$defs/TextPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "Indexed text contents", + "type": "object", + }, + { + "additionalProperties": false, + "properties": { + "key": { + "enum": [ + "deleted", + ], + }, + "predicate": { + "$ref": "#/$defs/BooleanPredicateState", + }, + }, + "required": [ + "key", + "predicate", + ], + "title": "Whether a file is a delete marker", + "type": "object", + }, + ], + }, + "title": "Result filters", + "type": "array", + }, + "resultType": { + "enum": [ + "o", + ], + "title": "Result type: S3 Object", + }, + }, + "required": [ + "resultType", + ], + "title": "Object-specific search parameters", + "type": "object", + }, + ], + "title": "Result type-specific parameters", + }, + "searchString": { + "description": "A String to search for. ElasticSearch syntax supported. For packages, searches in package name, comment (commit message), and metadata. 
For objects, searches in object key and indexed content.", + "title": "Search string", + "type": "string", + }, + }, + "required": [ + "params", + ], + "type": "object", + }, + }, + "required": [ + "name", + "params", + ], + "type": "object", + }, + { + "additionalProperties": false, + "description": "TBD", + "properties": { + "name": { + "enum": [ + "activate", + ], + }, + "params": { + "$id": "/schemas/{}", + "anyOf": [ + { + "type": "object", + }, + { + "type": "array", + }, + ], + }, + }, + "required": [ + "name", + "params", + ], + "type": "object", + }, + { + "additionalProperties": false, + "description": "S3 Object (aka File) page", + "properties": { + "name": { + "enum": [ + "bucket.object", + ], + }, + "params": { + "additionalProperties": false, + "properties": { + "bucket": { + "type": "string", + }, + "mode": { + "description": "Contents preview mode", + "title": "Viewing Mode", + "type": "string", + }, + "path": { + "description": "S3 Object Key aka File Path", + "title": "Path", + "type": "string", + }, + "version": { + "description": "S3 Object Version ID (omit for latest version)", + "title": "Version ID", + "type": "string", + }, + }, + "required": [ + "bucket", + "path", + ], + "type": "object", + }, + }, + "required": [ + "name", + "params", + ], + "type": "object", + }, + { + "additionalProperties": false, + "description": "S3 Prefix (aka Directory) page", + "properties": { + "name": { + "enum": [ + "bucket.prefix", + ], + }, + "params": { + "additionalProperties": false, + "properties": { + "bucket": { + "type": "string", + }, + "path": { + "description": "S3 Prefix aka Directory Path", + "title": "Path", + "type": "string", + }, + }, + "required": [ + "bucket", + "path", + ], + "type": "object", + }, + }, + "required": [ + "name", + "params", + ], + "type": "object", + }, + { + "additionalProperties": false, + "description": "Bucket overview page", + "properties": { + "name": { + "enum": [ + "bucket.overview", + ], + }, + "params": { + "additionalProperties": false, + "properties": { + "bucket": { + "type": "string", + }, + }, + "required": [ + "bucket", + ], + "type": "object", + }, + }, + "required": [ + "name", + "params", + ], + "type": "object", + }, + ], + }, + }, + "required": [ + "route", + ], + "title": "navigate the catalog", + "type": "object", +} +`; diff --git a/catalog/app/components/Assistant/Model/GlobalContext/index.ts b/catalog/app/components/Assistant/Model/GlobalContext/index.ts new file mode 100644 index 00000000000..44da93c352d --- /dev/null +++ b/catalog/app/components/Assistant/Model/GlobalContext/index.ts @@ -0,0 +1,17 @@ +import * as Context from '../Context' + +import { useNavigate, useRouteContext } from './navigation' +import { useGetObject } from './preview' +import { useStackInfo } from './stack' + +export function useGlobalContext() { + Context.usePushContext({ + tools: { + navigate: useNavigate(), + catalog_global_getObject: useGetObject(), + }, + messages: [useStackInfo(), useRouteContext()], + }) +} + +export { useGlobalContext as use } diff --git a/catalog/app/components/Assistant/Model/GlobalContext/navigation.spec.ts b/catalog/app/components/Assistant/Model/GlobalContext/navigation.spec.ts new file mode 100644 index 00000000000..8983f4e61f7 --- /dev/null +++ b/catalog/app/components/Assistant/Model/GlobalContext/navigation.spec.ts @@ -0,0 +1,98 @@ +import * as Eff from 'effect' +import { JSONSchema, Schema } from '@effect/schema' + +import * as nav from './navigation' + +jest.mock( + 'constants/config', + jest.fn(() => ({})), +) 
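+// Example (illustrative, derived from the snapshot above): a minimal input
+// accepted by the `navigate` tool per NavigateSchema -- the top level requires
+// `route`, and every route variant requires both `name` and `params`:
+//
+//   const input = {
+//     route: { name: 'bucket.overview', params: { bucket: 'quilt-example' } },
+//   }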
+ +describe('components/Assistant/Model/GlobalTools/navigation', () => { + describe('NavigateSchema', () => { + describe('produced JSON Schema', () => { + it('should match the snapshot', () => { + const jsonSchema = JSONSchema.make(nav.NavigateSchema) + expect(jsonSchema).toMatchSnapshot() + }) + }) + }) + describe('routes', () => { + const TEST_CASES = [ + { + route: { + name: 'search', + params: { + searchString: '', + buckets: [], + order: 'NEWEST', + params: { + resultType: 'p', + filter: [], + userMetaFilters: [ + { + path: '/author', + predicate: { + type: 'KeywordEnum', + value: { + terms: ['Aneesh', 'Maksim'], + }, + }, + }, + ], + }, + }, + }, + loc: { + pathname: '/search', + search: 'o=NEWEST&meta.e%2Fauthor=%22Aneesh%22%2C%22Maksim%22', + hash: '', + }, + }, + { + route: { + name: 'bucket.overview', + params: { + bucket: 'test-bucket', + }, + }, + loc: { + pathname: '/b/test-bucket', + search: '', + hash: '', + }, + }, + { + route: { + name: 'bucket.prefix', + params: { + bucket: 'quilt-example', + path: 'data/random-data-benchmark/100kb/', + }, + }, + loc: { + pathname: '/b/quilt-example/tree/data/random-data-benchmark/100kb/', + search: '', + hash: '', + }, + }, + ] + + const encode = Eff.flow( + Schema.decodeUnknown(nav.NavigableRouteSchema), + Eff.Effect.andThen(nav.locationFromRoute), + Eff.Effect.runPromise, + ) + + TEST_CASES.forEach((tc, i) => { + describe(`${i + 1}: ${tc.route.name}`, () => { + it('should encode', async () => { + expect(await encode(tc.route)).toEqual(tc.loc) + }) + it('should decode', async () => { + expect(nav.matchLocation(tc.loc)?.decoded).toEqual(tc.route) + }) + }) + }) + }) +}) diff --git a/catalog/app/components/Assistant/Model/GlobalContext/navigation.ts b/catalog/app/components/Assistant/Model/GlobalContext/navigation.ts new file mode 100644 index 00000000000..df76765cf59 --- /dev/null +++ b/catalog/app/components/Assistant/Model/GlobalContext/navigation.ts @@ -0,0 +1,256 @@ +import * as Eff from 'effect' +import * as React from 'react' +import * as RR from 'react-router-dom' +import { Schema as S } from '@effect/schema' + +import bucketRoutes from 'containers/Bucket/Routes' +import search from 'containers/Search/Route' +import * as ROUTES from 'constants/routes' +import * as Log from 'utils/Logging' +import * as Nav from 'utils/Navigation' +import * as XML from 'utils/XML' + +import * as Content from '../Content' +import * as Context from '../Context' +import * as Tool from '../Tool' + +const MODULE = 'GlobalContext/navigation' + +// the routes are in the order of matching +// TODO: specify/describe all the *relevant* routes +const routeList = [ + Nav.makeRoute({ + name: 'home', + path: ROUTES.home.path, + exact: true, + description: 'Home page', + // searchParams: S.Struct({ + // // XXX: passing this param doesn't actually work bc of how it's implemented in + // // website/pages/Landing/Buckets/Buckets.js + // q: SearchParamLastOpt.annotations({ + // title: 'bucket filter query', + // description: 'filter buckets in the bucket grid', + // }), + // }), + }), + Nav.makeRoute({ + name: 'install', + path: ROUTES.install.path, + description: 'Installation page', + }), + search, + Nav.makeRoute({ + name: 'activate', + path: ROUTES.activate.path, + description: 'TBD', + }), + // + // + // + // + // + // + // + // + // + // + // + // {(cfg.passwordAuth === true || cfg.ssoAuth === true) && ( + // + // + // + // )} + // {!!cfg.passwordAuth && ( + // + // + // + // )} + // {!!cfg.passwordAuth && ( + // + // + // + // )} + // + // + // + // + // + // + 
// + // + // + // {cfg.mode === 'OPEN' && ( + // // XXX: show profile in all modes? + // + // + // + // )} + // + // + // + // + // + // + // + // + // + ...bucketRoutes, +] as const + +type KnownRoute = (typeof routeList)[number] +type KnownRouteMap = { + [K in KnownRoute['name']]: Extract +} +export const routes = Object.fromEntries( + routeList.map((r) => [r.name, r]), +) as KnownRouteMap + +export const NavigableRouteSchema = S.Union( + ...routeList.map((r) => r.navigableRouteSchema), +) + +type NavigableRoute = typeof NavigableRouteSchema.Type + +export const locationFromRoute = (route: NavigableRoute) => + // @ts-expect-error + S.encode(routes[route.name].paramsSchema)(route.params) + +type History = ReturnType + +const WAIT_TIMEOUT = Eff.Duration.seconds(30) +const NAV_LAG = Eff.Duration.seconds(1) + +const navigate = ( + route: NavigableRoute, + history: History, + markers: Eff.SubscriptionRef.SubscriptionRef>, +) => + Log.scoped({ + name: `${MODULE}.navigate`, + enter: [`to: ${route.name}`, Log.br, 'params:', route.params], + })( + Eff.pipe( + locationFromRoute(route), + Eff.Effect.tap((loc) => Eff.Effect.log(`Navigating to location:`, Log.br, loc)), + Eff.Effect.andThen((loc) => Eff.Effect.sync(() => history.push(loc))), + Eff.Effect.andThen(() => + Eff.Effect.gen(function* () { + const { waitForMarkers } = routes[route.name] + if (!waitForMarkers.length) return + yield* Eff.Effect.log(`Waiting for markers: ${waitForMarkers.join(', ')}`) + yield* Eff.Effect.sleep(NAV_LAG) + yield* Eff.pipe( + markers.changes, + Eff.Stream.timeoutFail(() => ({ _tag: 'timeout' as const }), WAIT_TIMEOUT), + Eff.Stream.runForEachWhile((currentMarkers) => + Eff.Effect.succeed(!waitForMarkers.every((k) => currentMarkers[k])), + ), + Eff.Effect.andThen(() => Eff.Effect.log('Markers found')), + Eff.Effect.catchTag('timeout', () => + Eff.Effect.log( + `Timed out after ${Eff.Duration.format( + WAIT_TIMEOUT, + )} while waiting for markers`, + ), + ), + ) + }), + ), + ), + ) + +export interface Match { + descriptor: KnownRoute + decoded: NavigableRoute | null +} + +export const matchLocation = (loc: typeof Nav.Location.Type): Match | null => + Eff.pipe( + Eff.Array.findFirst(routeList, (route) => + RR.matchPath(loc.pathname, { + path: route.path, + exact: route.exact, + strict: route.strict, + }) + ? Eff.Option.some(route) + : Eff.Option.none(), + ), + Eff.Option.map((descriptor) => ({ + descriptor, + decoded: Eff.pipe( + loc, + // @ts-expect-error + S.decodeOption(descriptor.paramsSchema), + Eff.Option.map((params) => ({ name: descriptor.name, params }) as NavigableRoute), + Eff.Option.getOrNull, + ), + })), + Eff.Option.getOrNull, + ) + +export function useRouteContext() { + const loc = RR.useLocation() + + const match = React.useMemo( + () => matchLocation({ pathname: loc.pathname, search: loc.search, hash: '' }), + [loc.pathname, loc.search], + ) + + const description = React.useMemo(() => { + if (!match) return '' + const params = match.decoded?.params + ? 
XML.tag('parameters', {}, JSON.stringify(match.decoded.params, null, 2)) + : null + return XML.tag( + 'route-info', + {}, + `Name: "${match.descriptor.name}"`, + XML.tag('description', {}, match.descriptor.description), + params, + ) + }, [match]) + + const msg = React.useMemo( + () => + XML.tag( + 'viewport', + {}, + XML.tag('current-location', {}, JSON.stringify(loc, null, 2)), + description, + 'Refer to "navigate" tool schema for navigable routes and their parameters.', + ).toString(), + [description, loc], + ) + + return msg +} + +export const NavigateSchema = S.Struct({ + route: NavigableRouteSchema, +}).annotations({ + title: 'navigate the catalog', + description: 'navigate to a provided route', +}) + +export function useNavigate() { + const history = RR.useHistory() + const markers = Context.useMarkersRef() + + return Tool.useMakeTool( + NavigateSchema, + ({ route }) => + Eff.pipe( + navigate(route, history, markers), + Eff.Effect.match({ + onSuccess: () => + Tool.succeed(Content.text(`Navigating to the '${route.name}' route.`)), + onFailure: (e) => + Tool.fail( + Content.text(`Failed to navigate to the '${route.name}' route: ${e}`), + ), + }), + Eff.Effect.map(Eff.Option.some), + ), + [history, markers], + ) +} diff --git a/catalog/app/components/Assistant/Model/GlobalContext/preview.ts b/catalog/app/components/Assistant/Model/GlobalContext/preview.ts new file mode 100644 index 00000000000..087f6c634b5 --- /dev/null +++ b/catalog/app/components/Assistant/Model/GlobalContext/preview.ts @@ -0,0 +1,439 @@ +import { basename, extname } from 'path' + +import type AWSSDK from 'aws-sdk' +import * as Eff from 'effect' +import * as S from '@effect/schema/Schema' + +import cfg from 'constants/config' +import { S3ObjectLocation } from 'model/S3' +import * as AWS from 'utils/AWS' +import * as Log from 'utils/Logging' +import mkSearch from 'utils/mkSearch' + +import * as Content from '../Content' +import * as Tool from '../Tool' + +const MODULE = 'GlobalContext/preview' + +type AWSEffect = Eff.Effect.Effect + +// TODO: move this out +export class S3 extends Eff.Context.Tag('S3')< + S3, + { + headObject(handle: S3ObjectLocation): AWSEffect + getObject(handle: S3ObjectLocation): AWSEffect + } +>() {} + +export const fromS3Client = (client: AWSSDK.S3) => + Eff.Layer.succeed(S3, { + headObject: (handle) => + Eff.Effect.tryPromise({ + try: () => + client + .headObject({ + Bucket: handle.bucket, + Key: handle.key, + VersionId: handle.version, + }) + .promise(), + catch: (e) => e as AWSSDK.AWSError, + }), + getObject: (handle) => + Eff.Effect.tryPromise({ + try: () => + client + .getObject({ + Bucket: handle.bucket, + Key: handle.key, + VersionId: handle.version, + }) + .promise(), + catch: (e) => e as AWSSDK.AWSError, + }), + }) + +export interface S3SignerOptions { + urlExpiration?: number // in seconds + forceProxy?: boolean +} + +export class S3Signer extends Eff.Context.Tag('S3Signer')< + S3Signer, + { + sign(handle: S3ObjectLocation, options?: S3SignerOptions): Eff.Effect.Effect + } +>() {} + +export const fromS3Signer = ( + signer: (handle: S3ObjectLocation, options?: S3SignerOptions) => string, +) => + Eff.Layer.succeed(S3Signer, { + sign: (...args) => Eff.Effect.sync(() => signer(...args)), + }) + +// The document file name can only contain: +// - alphanumeric characters +// - whitespace characters +// - hyphens +// - parentheses and square brackets +// The name can't contain more than one consecutive whitespace character +const normalizeDocumentName = (name: string) => + name + 
.replace(/[^a-zA-Z0-9\s\-\(\)\[\]]/g, ' ') // Remove invalid characters + .replace(/\s+/g, ' ') // Replace multiple whitespace characters with a single space + .trim() // Remove leading and trailing whitespace + +const THRESHOLD = 500 * 1024 // 500 KiB + +const GetObjectSchema = S.Struct({ + bucket: S.String, + key: S.String, + version: S.optional(S.String).annotations({}), + // XXX: force type? +}).annotations({ + description: 'Get contents and metadata of an S3 object', +}) + +// return format: +// - metadata block (text or json) +// - content block (json | text | image | document) +export function useGetObject() { + const s3Client = AWS.S3.use() + const s3Signer = AWS.Signer.useS3Signer() + + return Tool.useMakeTool( + GetObjectSchema, + Eff.flow( + getObject, + Eff.Effect.map(Eff.Option.some), + Eff.Effect.provide(fromS3Client(s3Client)), + Eff.Effect.provide(fromS3Signer(s3Signer)), + ), + [s3Client, s3Signer], + ) +} + +const getObject = (handle: S3ObjectLocation) => + Log.scoped({ + name: `${MODULE}.getObject`, + enter: [Log.br, 'handle:', handle], + })( + Eff.Effect.gen(function* () { + const s3 = yield* S3 + const headE = yield* Eff.Effect.either(s3.headObject(handle)) + if (Eff.Either.isLeft(headE)) { + return Tool.fail( + Content.text( + 'Error while getting S3 object metadata:\n', + `\n${headE.left}'n`, + ), + ) + } + const head = headE.right + + const metaBlock = Content.text( + 'Got S3 object metadata:\n', + `\n${JSON.stringify(head, null, 2)}\n`, + ) + + const size = head.ContentLength + if (size == null) { + return Tool.succeed( + metaBlock, + Content.text('Could not determine object content length'), + ) + } + + const fileType = detectFileType(handle.key) + + const contentBlocks: Content.ToolResultContentBlock[] = yield* FileType.$match( + fileType, + { + Image: () => + getImagePreview(handle).pipe( + Eff.Effect.map(({ format, bytes }) => + Content.ToolResultContentBlock.Image({ + format, + source: bytes as $TSFixMe, + }), + ), + Eff.Effect.catchAll((e) => + Eff.Effect.succeed( + Content.text( + 'Error while getting image preview:\n', + `\n${e}'n`, + ), + ), + ), + Eff.Effect.map(Eff.Array.of), + ), + Document: ({ format }) => + size > THRESHOLD + ? 
Eff.Effect.succeed([ + Content.text('Object is too large to include its contents directly'), + ]) + : getDocumentPreview(handle, format), + Unidentified: () => + Eff.Effect.succeed([ + Content.text( + 'Error while getting object contents:\n', + `\nUnidentified file type\n`, + ), + ]), + }, + ) + + return Tool.succeed(metaBlock, ...contentBlocks) + }), + ) + +const SUPPORTED_IMAGE_EXTENSIONS = [ + '.jpg', + '.jpeg', + '.png', + '.gif', + '.webp', + '.bmp', + '.tiff', + '.tif', + '.czi', +] + +// =< 1568px as per anthropic/claude guidelines +const PREVIEW_SIZE = `w1024h768` + +interface ImagePreview { + format: Content.ImageFormat + bytes: ArrayBuffer +} + +const getImagePreview = (handle: S3ObjectLocation) => + Eff.Effect.gen(function* () { + const signer = yield* S3Signer + const url = yield* signer.sign(handle) + const src = `${cfg.apiGatewayEndpoint}/thumbnail${mkSearch({ + url, + size: PREVIEW_SIZE, + })}` + const r = yield* Eff.Effect.tryPromise(() => fetch(src)) + if (r.status !== 200) { + const text = yield* Eff.Effect.promise(() => r.text()) + return yield* new Eff.Cause.UnknownException(text, text) + } + const bytes = yield* Eff.Effect.promise(() => + r.blob().then((blob) => blob.arrayBuffer()), + ) + const format = yield* Eff.Effect.try(() => { + const info = r.headers.get('X-Quilt-Info') + if (!info) throw new Error('X-Quilt-Info header not found') + const parsed = JSON.parse(info) + switch (parsed.thumbnail_format) { + case 'JPG': + return 'jpeg' + case 'PNG': + return 'png' + case 'GIF': + return 'gif' + default: + throw new Error(`Unknown thumbnail format: ${parsed.thumbnail_format}`) + } + }) + return { format, bytes } as ImagePreview + }) + +const getDocumentPreview = (handle: S3ObjectLocation, format: Content.DocumentFormat) => + S3.pipe( + Eff.Effect.andThen((s3) => s3.getObject(handle)), + Eff.Effect.map((body) => { + const blob = body.Body + if (!blob) { + return Content.text('Could not get object contents') + } + return Content.ToolResultContentBlock.Document({ + name: normalizeDocumentName( + `${handle.bucket} ${handle.key} ${handle.version || ''}`, + ), + format, + source: blob as $TSFixMe, + }) + }), + Eff.Effect.catchAll((e) => + Eff.Effect.succeed( + Content.text( + 'Error while getting object contents:\n', + `\n${e}'n`, + ), + ), + ), + Eff.Effect.map(Eff.Array.of), + ) + +// const hasNoExt = (key: string) => !extname(key) +// +// const isCsv = utils.extIs('.csv') +// +// const isExcel = utils.extIn(['.xls', '.xlsx']) +// +// const isJsonl = utils.extIs('.jsonl') +// +// const isParquet = R.anyPass([ +// utils.extIn(['.parquet', '.pq']), +// R.test(/.+_0$/), +// R.test(/[.-]c\d{3,5}$/gi), +// ]) +// +// const isTsv = utils.extIn(['.tsv', '.tab']) +// +// +// type TabularType = 'csv' | 'jsonl' | 'excel' | 'parquet' | 'tsv' +// +// const detectTabularType: (type: string) => TabularType = R.pipe( +// utils.stripCompression, +// R.cond([ +// [isCsv, R.always('csv')], +// [isExcel, R.always('excel')], +// [isJsonl, R.always('jsonl')], +// [isParquet, R.always('parquet')], +// [isTsv, R.always('tsv')], +// [R.T, R.always('csv')], +// ]), +// ) + +const LANGS = { + accesslog: /\.log$/, + bash: /\.(ba|z)?sh$/, + clojure: /\.clj$/, + coffeescript: /\.(coffee|cson|iced)$/, + coq: /\.v$/, + c: /\.(c|h)$/, + cpp: /\.((c(c|\+\+|pp|xx)?)|(h(\+\+|pp|xx)?))$/, + csharp: /\.cs$/, + css: /\.css$/, + diff: /\.(diff|patch)$/, + dockerfile: /^dockerfile$/, + erlang: /\.erl$/, + go: /\.go$/, + haskell: /\.hs$/, + ini: /\.(ini|toml)$/, + java: /\.(java|jsp)$/, + javascript: /\.m?jsx?$/, 
+ json: /\.jsonl?$/, + lisp: /\.lisp$/, + makefile: /^(gnu)?makefile$/, + matlab: /\.m$/, + ocaml: /\.mli?$/, + perl: /\.pl$/, + php: /\.php[3-7]?$/, + plaintext: + /((^license)|(^readme)|(^\.\w*(ignore|rc|config))|(\.txt)|(\.(c|t)sv)|(\.(big)?bed)|(\.cef)|(\.fa)|(\.fsa)|(\.fasta)|(\.(san)?fastq)|(\.fq)|(\.sam)|(\.gff(2|3)?)|(\.gtf)|(\.index)|(\.readme)|(changelog)|(.*notes)|(\.pdbqt)|(\.results)(\.(inn|out)ie))$/, + python: /\.(py|gyp)$/, + r: /\.r$/, + ruby: /\.rb$/, + rust: /\.rs$/, + scala: /\.scala$/, + scheme: /\.s(s|ls|cm)$/, + sql: /\.sql$/, + typescript: /\.tsx?$/, + xml: /\.(xml|x?html|rss|atom|xjb|xsd|xsl|plist)$/, + yaml: /((\.ya?ml$)|(^snakefile))/, +} + +const langPairs = Object.entries(LANGS) + +function isText(name: string) { + const normalized = basename(name).toLowerCase() + return langPairs.some(([, re]) => re.test(normalized)) +} + +// const loaderChain = { +// Audio: extIn(['.flac', '.mp3', '.ogg', '.ts', '.tsa', '.wav']), +// Fcs: R.pipe(utils.stripCompression, utils.extIs('.fcs')), +// Json: 'json', +// Manifest: R.allPass([R.startsWith('.quilt/packages/'), hasNoExt]), +// NamedPackage: R.startsWith('.quilt/named_packages/'), +// Ngl: R.pipe( +// utils.stripCompression, +// utils.extIn(['.cif', '.ent', '.mol', '.mol2', '.pdb', '.sdf']),), +// Notebook: R.pipe(utils.stripCompression, utils.extIs('.ipynb')), +// Tabular: R.pipe( +// utils.stripCompression, +// R.anyPass([isCsv, isExcel, isJsonl, isParquet, isTsv]),), +// Vcf: R.pipe(utils.stripCompression, utils.extIs('.vcf')), +// Video: utils.extIn(['.m2t', '.m2ts', '.mp4', '.webm']), +// Text: R.pipe(findLang, Boolean), +// } +// TODO: convert pptx? + +type FileType = Eff.Data.TaggedEnum<{ + Image: {} + Document: { + readonly format: Content.DocumentFormat + } + Unidentified: {} +}> + +// eslint-disable-next-line @typescript-eslint/no-redeclare +const FileType = Eff.Data.taggedEnum() + +// const getExt = (key: string) => extname(key).toLowerCase().slice(1) + +// const COMPRESSION_TYPES = { gz: '.gz', bz2: '.bz2' } +// type CompressionType = keyof typeof COMPRESSION_TYPES +// +// const getCompression = (key: string): [string, CompressionType | null] => { +// for (const [type, ext] of Object.entries(COMPRESSION_TYPES)) { +// if (key.toLowerCase().endsWith(ext)) { +// return [ +// key.slice(0, -ext.length), +// type as CompressionType, +// ] +// } +// } +// return [key, null] +// } + +// TODO +const detectFileType = (key: string): FileType => { + // XXX: support compression? + // const [withoutCompression, compression] = getCompression(key) + // const ext = extname(withoutCompression).toLowerCase() + const ext = extname(key).toLowerCase() + + if (SUPPORTED_IMAGE_EXTENSIONS.includes(ext)) { + return FileType.Image() + } + if (['.htm', '.html'].includes(ext)) { + return FileType.Document({ format: 'html' }) + } + if (['.md', '.rmd'].includes(ext)) { + return FileType.Document({ format: 'md' }) + } + if (ext === '.pdf') { + return FileType.Document({ format: 'pdf' }) + } + if (ext === '.csv') { + // XXX: does it support TSV? 
+ return FileType.Document({ format: 'csv' }) + } + if (ext === '.docx') { + return FileType.Document({ format: 'docx' }) + } + if (ext === '.doc') { + return FileType.Document({ format: 'doc' }) + } + if (ext === '.xls') { + return FileType.Document({ format: 'xls' }) + } + if (ext === '.xlsx') { + return FileType.Document({ format: 'xlsx' }) + } + if (isText(key)) { + return FileType.Document({ format: 'txt' }) + } + if (ext === '.ipynb') { + return FileType.Document({ format: 'txt' }) + } + return FileType.Unidentified() +} diff --git a/catalog/app/components/Assistant/Model/GlobalContext/stack.ts b/catalog/app/components/Assistant/Model/GlobalContext/stack.ts new file mode 100644 index 00000000000..147b88fc1b8 --- /dev/null +++ b/catalog/app/components/Assistant/Model/GlobalContext/stack.ts @@ -0,0 +1,33 @@ +import * as React from 'react' + +import * as BucketConfig from 'utils/BucketConfig' +import * as XML from 'utils/XML' + +export function useStackInfo() { + const bucketConfigs = BucketConfig.useRelevantBucketConfigs() + + return React.useMemo(() => { + const buckets = XML.tag( + 'buckets', + {}, + 'Buckets attached to this stack:', + ...bucketConfigs.map((b) => + XML.tag( + 'bucket', + {}, + JSON.stringify( + { + name: b.name, + title: b.title, + description: b.description, + tags: b.tags, + }, + null, + 2, + ), + ), + ), + ) + return XML.tag('quilt-stack-info', {}, buckets).toString() + }, [bucketConfigs]) +} diff --git a/catalog/app/components/Assistant/Model/LLM.ts b/catalog/app/components/Assistant/Model/LLM.ts new file mode 100644 index 00000000000..98be8ad3b26 --- /dev/null +++ b/catalog/app/components/Assistant/Model/LLM.ts @@ -0,0 +1,77 @@ +import BedrockRuntime from 'aws-sdk/clients/bedrockruntime' +import * as Eff from 'effect' + +import { JsonRecord } from 'utils/types' + +import * as Content from './Content' +import * as Tool from './Tool' + +export type Role = 'user' | 'assistant' + +// XXX: explicitly restrict specific content blocks for each role? 
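+// Illustrative sketch (editorial note, not part of the change): a minimal
+// one-turn prompt assembled with the helpers defined below:
+//
+//   const prompt: Prompt = {
+//     system: 'You are a helpful assistant.',
+//     messages: [userMessage(Content.text('What is Quilt?'))],
+//   }
+//
+// `toolConfig` is optional; when present, its `tools` collection is translated
+// into a Bedrock ToolConfiguration by the adapter in Bedrock.ts.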
+export interface PromptMessage { + role: Role + content: Content.PromptMessageContentBlock +} + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const PromptMessage = Eff.Data.case() + +export const userMessage = (content: Content.PromptMessageContentBlock) => + PromptMessage({ role: 'user', content }) + +export const assistantMessage = (content: Content.PromptMessageContentBlock) => + PromptMessage({ role: 'assistant', content }) + +export type ToolChoice = Eff.Data.TaggedEnum<{ + Auto: {} + Any: {} + Specific: { + readonly name: string + } +}> + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const ToolChoice = Eff.Data.taggedEnum() + +export interface ToolConfig { + tools: Tool.Collection + choice?: ToolChoice +} + +export interface Prompt { + system: string + messages: Eff.Array.NonEmptyArray + toolConfig?: ToolConfig +} + +export interface Options { + inferenceConfig?: BedrockRuntime.InferenceConfiguration + guardrailConfig?: BedrockRuntime.GuardrailConfiguration + additionalModelRequestFields?: JsonRecord + additionalModelResponseFieldPaths?: BedrockRuntime.ConverseRequestAdditionalModelResponseFieldPathsList +} + +interface ConverseResponse { + content: Eff.Option.Option + backendResponse: BedrockRuntime.ConverseResponse +} + +export class LLMError { + message: string + + constructor({ message }: { message: string }) { + this.message = message + } +} + +// a service +export class LLM extends Eff.Context.Tag('LLM')< + LLM, + { + converse: ( + prompt: Prompt, + opts?: Options, + ) => Eff.Effect.Effect + } +>() {} diff --git a/catalog/app/components/Assistant/Model/Tool.ts b/catalog/app/components/Assistant/Model/Tool.ts new file mode 100644 index 00000000000..1d9bb21b5a8 --- /dev/null +++ b/catalog/app/components/Assistant/Model/Tool.ts @@ -0,0 +1,93 @@ +import * as Eff from 'effect' +import * as React from 'react' +import { JSONSchema, Schema } from '@effect/schema' + +import * as Content from './Content' + +export interface Result { + readonly content: Content.ToolResultContentBlock[] + readonly status: 'success' | 'error' +} + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const Result = Eff.Data.case() + +export const succeed = (...content: Content.ToolResultContentBlock[]) => + Result({ status: 'success', content }) + +export const fail = (...content: Content.ToolResultContentBlock[]) => + Result({ status: 'error', content }) + +export type ResultOption = Eff.Option.Option + +export type Executor = (params: I) => Eff.Effect.Effect + +export interface Descriptor { + description?: string + schema: JSONSchema.JsonSchema7Root + executor: Executor +} + +export type Collection = Record> + +export function make(schema: Schema.Schema, fn: Executor): Descriptor { + const jsonSchema = JSONSchema.make(schema) + + const decode = Schema.decodeUnknown(schema, { + errors: 'all', + onExcessProperty: 'error', + }) + + const wrappedFn = (params: unknown) => + decode(params).pipe( + Eff.Effect.andThen(fn), + Eff.Effect.catchAll((error) => + Eff.Effect.succeed( + Eff.Option.some( + Result({ + status: 'error', + content: [ + Content.ToolResultContentBlock.Text({ + text: `Error while executing tool:\n${error.message}`, + }), + ], + }), + ), + ), + ), + ) + + return { + description: jsonSchema.description, + schema: jsonSchema, + executor: wrappedFn, + } +} + +const EMPTY_DEPS: React.DependencyList = [] + +export function useMakeTool( + schema: Schema.Schema, + fn: Executor, + deps: React.DependencyList = EMPTY_DEPS, +): Descriptor { + // 
eslint-disable-next-line react-hooks/exhaustive-deps + const fnMemo = React.useCallback(fn, deps) + return React.useMemo(() => make(schema, fnMemo), [schema, fnMemo]) +} + +export const execute = (tools: Collection, name: string, input: unknown) => + name in tools + ? tools[name].executor(input) + : Eff.Effect.succeed( + Eff.Option.some( + Result({ + status: 'error', + content: [ + Content.ToolResultContentBlock.Text({ + text: `Tool "${name}" not found`, + }), + ], + }), + ), + ) diff --git a/catalog/app/components/Assistant/Model/enabled.ts b/catalog/app/components/Assistant/Model/enabled.ts new file mode 100644 index 00000000000..ff230da0961 --- /dev/null +++ b/catalog/app/components/Assistant/Model/enabled.ts @@ -0,0 +1,9 @@ +import * as redux from 'react-redux' + +import { authenticated as authenticatedSelector } from 'containers/Auth/selectors' +import cfg from 'constants/config' + +export default function useIsEnabled() { + const authenticated = redux.useSelector(authenticatedSelector) + return cfg.qurator && authenticated +} diff --git a/catalog/app/components/Assistant/Model/index.ts b/catalog/app/components/Assistant/Model/index.ts new file mode 100644 index 00000000000..9eec41162a0 --- /dev/null +++ b/catalog/app/components/Assistant/Model/index.ts @@ -0,0 +1,9 @@ +export { useAssistant, useAssistantAPI, AssistantProvider } from './Assistant' +export { default as useIsEnabled } from './enabled' + +export * as Assistant from './Assistant' +export * as Content from './Content' +export * as Context from './Context' +export * as Conversation from './Conversation' +export * as LLM from './LLM' +export * as Tool from './Tool' diff --git a/catalog/app/components/Assistant/UI/Chat/Chat.tsx b/catalog/app/components/Assistant/UI/Chat/Chat.tsx new file mode 100644 index 00000000000..0b58b93acc6 --- /dev/null +++ b/catalog/app/components/Assistant/UI/Chat/Chat.tsx @@ -0,0 +1,489 @@ +import cx from 'classnames' +import * as Eff from 'effect' +import * as React from 'react' +import * as M from '@material-ui/core' + +import JsonDisplay from 'components/JsonDisplay' +import Markdown from 'components/Markdown' +import usePrevious from 'utils/usePrevious' + +import * as Model from '../../Model' + +import DevTools from './DevTools' +import Input from './Input' + +const BG = { + intense: M.colors.indigo[900], + bright: M.colors.indigo[500], + faint: M.colors.common.white, +} + +const useMessageContainerStyles = M.makeStyles((t) => ({ + align_left: {}, + align_right: {}, + color_intense: {}, + color_bright: {}, + color_faint: {}, + messageContainer: { + display: 'flex', + flexDirection: 'column', + gap: `${t.spacing(0.5)}px`, + '&$align_left': { + alignItems: 'flex-start', + }, + '&$align_right': { + alignItems: 'flex-end', + }, + }, + contentWrapper: { + display: 'flex', + flexDirection: 'column', + maxWidth: '100%', + }, + contentArea: { + borderRadius: `${t.spacing(1)}px`, + '$color_intense &': { + background: BG.intense, + color: M.fade(t.palette.common.white, 0.8), + }, + '$color_bright &': { + background: BG.bright, + color: t.palette.common.white, + }, + '$color_faint &': { + background: BG.faint, + color: t.palette.text.primary, + }, + '$align_right &': { + borderBottomRightRadius: 0, + }, + '$align_left &': { + borderBottomLeftRadius: 0, + }, + }, + contents: { + ...t.typography.body2, + padding: `${t.spacing(2)}px`, + }, + footer: { + ...t.typography.caption, + color: t.palette.text.hint, + display: 'flex', + gap: t.spacing(1), + justifyContent: 'flex-end', + paddingLeft: t.spacing(4), + 
paddingTop: '6px', + }, + actions: { + opacity: 0.7, + '$messageContainer:hover &': { + opacity: 1, + }, + }, +})) + +interface MessageContainerProps { + color?: 'intense' | 'bright' | 'faint' + align?: 'left' | 'right' + children: React.ReactNode + actions?: React.ReactNode + timestamp?: Date +} + +function MessageContainer({ + color = 'faint', + align = 'left', + children, + actions, + timestamp, +}: MessageContainerProps) { + const classes = useMessageContainerStyles() + return ( +
+    <div
+      className={cx(
+        classes.messageContainer,
+        classes[`align_${align}`],
+        classes[`color_${color}`],
+      )}
+    >
+      <div className={classes.contentWrapper}>
+        <div className={classes.contentArea}>
+          <div className={classes.contents}>{children}</div>
+        </div>
+        {!!(actions || timestamp) && (
+          <div className={classes.footer}>
+            {!!actions && <div className={classes.actions}>{actions}</div>}
+            {timestamp && <span>{timestamp.toLocaleTimeString()}</span>}
+          </div>
+        )}
+      </div>
+    </div>
+ ) +} + +const useMessageActionStyles = M.makeStyles({ + action: { + cursor: 'pointer', + opacity: 0.7, + '&:hover': { + opacity: 1, + }, + }, +}) + +interface MessageActionProps { + children: React.ReactNode + className?: string + onClick?: () => void +} + +function MessageAction({ children, onClick }: MessageActionProps) { + const classes = useMessageActionStyles() + return ( + + {children} + + ) +} + +interface ConversationDispatchProps { + dispatch: Model.Assistant.API['dispatch'] +} + +interface ConversationStateProps { + state: Model.Conversation.State['_tag'] +} + +type MessageEventProps = ConversationDispatchProps & + ConversationStateProps & + ReturnType + +function MessageEvent({ + state, + id, + timestamp, + dispatch, + role, + content, +}: MessageEventProps) { + const discard = React.useMemo( + () => + state === 'Idle' ? () => dispatch(Model.Conversation.Action.Discard({ id })) : null, + [dispatch, id, state], + ) + + return ( + discard} + timestamp={timestamp} + > + {Model.Content.MessageContentBlock.$match(content, { + Text: ({ text }) => , + Image: ({ format }) => `${format} image`, + Document: ({ name, format }) => `${format} document "${name}"`, + })} + + ) +} + +type ToolUseEventProps = ConversationDispatchProps & + ConversationStateProps & + ReturnType + +function ToolUseEvent({ + state, + id, + timestamp, + toolUseId, + name, + input, + result, + dispatch, +}: ToolUseEventProps) { + const discard = React.useMemo( + () => + state === 'Idle' ? () => dispatch(Model.Conversation.Action.Discard({ id })) : null, + [dispatch, id, state], + ) + const details = React.useMemo( + () => ({ toolUseId, input, result }), + [toolUseId, input, result], + ) + return ( + discard} + > + + Tool Use: {name} ({result.status}) + + + + + + ) +} + +interface ToolUseStateProps extends ConversationDispatchProps { + timestamp: Date + calls: Model.Conversation.ToolCalls +} + +function ToolUseState({ timestamp, dispatch, calls }: ToolUseStateProps) { + const abort = React.useCallback( + () => dispatch(Model.Conversation.Action.Abort()), + [dispatch], + ) + + const details = React.useMemo( + () => Eff.Record.map(calls, Eff.Struct.pick('name', 'input')), + [calls], + ) + + const names = Eff.Record.collect(calls, (_k, v) => v.name) + + return ( + abort} + > + + Tool Use: {names.join(', ')} + + + + + + ) +} + +interface WaitingStateProps extends ConversationDispatchProps { + timestamp: Date +} + +function WaitingState({ timestamp, dispatch }: WaitingStateProps) { + const abort = React.useCallback( + () => dispatch(Model.Conversation.Action.Abort()), + [dispatch], + ) + return ( + abort} + > + Processing... + + ) +} + +interface MenuProps { + state: Model.Assistant.API['state'] + dispatch: Model.Assistant.API['dispatch'] + onToggleDevTools: () => void + devToolsOpen: boolean + className?: string +} + +function Menu({ state, dispatch, devToolsOpen, onToggleDevTools, className }: MenuProps) { + const [menuOpen, setMenuOpen] = React.useState(null) + + const isIdle = state._tag === 'Idle' + + const toggleMenu = React.useCallback( + (e: React.BaseSyntheticEvent) => + setMenuOpen((prev) => (prev ? 
null : e.currentTarget)), + [setMenuOpen], + ) + const closeMenu = React.useCallback(() => setMenuOpen(null), [setMenuOpen]) + + const startNewSession = React.useCallback(() => { + if (isIdle) dispatch(Model.Conversation.Action.Clear()) + closeMenu() + }, [closeMenu, isIdle, dispatch]) + + const showDevTools = React.useCallback(() => { + onToggleDevTools() + closeMenu() + }, [closeMenu, onToggleDevTools]) + + return ( + <> + + + menu + + + + + + close + + + + + + New session + + Developer Tools + + + ) +} + +const useStyles = M.makeStyles((t) => ({ + chat: { + display: 'flex', + flexDirection: 'column', + flexGrow: 1, + overflow: 'hidden', + }, + menu: { + position: 'absolute', + right: t.spacing(1), + top: t.spacing(1), + zIndex: 1, + }, + devTools: { + height: '50%', + }, + historyContainer: { + flexGrow: 1, + overflowY: 'auto', + // TODO: nice overflow markers + // position: 'relative', + // '&::before': { + // content: '""', + // position: 'absolute', + // }, + // '&::after': { + // }, + }, + history: { + display: 'flex', + flexDirection: 'column', + gap: `${t.spacing(2)}px`, + justifyContent: 'flex-end', + minHeight: '100%', + padding: `${t.spacing(3)}px`, + paddingBottom: 0, + }, + input: {}, +})) + +interface ChatProps { + state: Model.Assistant.API['state'] + dispatch: Model.Assistant.API['dispatch'] +} + +export default function Chat({ state, dispatch }: ChatProps) { + const classes = useStyles() + const scrollRef = React.useRef(null) + + const inputDisabled = state._tag !== 'Idle' + + const stateFingerprint = `${state._tag}:${state.timestamp.getTime()}` + + usePrevious(stateFingerprint, (prev) => { + if (prev && stateFingerprint !== prev) { + scrollRef.current?.scrollIntoView({ + block: 'end', + behavior: 'smooth', + }) + } + }) + + const ask = React.useCallback( + (content: string) => { + dispatch(Model.Conversation.Action.Ask({ content })) + }, + [dispatch], + ) + + const [devToolsOpen, setDevToolsOpen] = React.useState(false) + + const toggleDevTools = React.useCallback( + () => setDevToolsOpen((prev) => !prev), + [setDevToolsOpen], + ) + + return ( +
+ + + + + + +
+
+ + Hi! I'm Qurator, your AI assistant. How can I help you? + + {state.events + .filter((e) => !e.discarded) + .map( + Model.Conversation.Event.$match({ + Message: (event) => ( + + ), + ToolUse: (event) => ( + + ), + }), + )} + {Model.Conversation.State.$match(state, { + Idle: (s) => + Eff.Option.match(s.error, { + onSome: (e) => ( + + {e.message} +
+ {e.details} +
+ // TODO: retry / discard + ), + onNone: () => null, + }), + WaitingForAssistant: (s) => ( + + ), + ToolUse: (s) => ( + + ), + })} +
+
+
+ +
+ ) +} diff --git a/catalog/app/components/Assistant/UI/Chat/DevTools.tsx b/catalog/app/components/Assistant/UI/Chat/DevTools.tsx new file mode 100644 index 00000000000..d7ed7e5501c --- /dev/null +++ b/catalog/app/components/Assistant/UI/Chat/DevTools.tsx @@ -0,0 +1,61 @@ +import * as Eff from 'effect' +import * as React from 'react' +import * as M from '@material-ui/core' + +import JsonDisplay from 'components/JsonDisplay' + +import * as Model from '../../Model' + +const useStyles = M.makeStyles((t) => ({ + root: { + display: 'flex', + flexDirection: 'column', + height: '100%', + }, + heading: { + ...t.typography.h5, + borderBottom: `1px solid ${t.palette.divider}`, + lineHeight: '64px', + paddingLeft: t.spacing(2), + }, + contents: { + flexGrow: 1, + overflow: 'auto', + }, + json: { + margin: t.spacing(2, 0), + padding: t.spacing(0, 2), + }, +})) + +interface DevToolsProps { + state: Model.Assistant.API['state'] +} + +export default function DevTools({ state }: DevToolsProps) { + const classes = useStyles() + + const context = Model.Context.useAggregatedContext() + + const prompt = React.useMemo( + () => + Eff.Effect.runSync( + Model.Conversation.constructPrompt( + state.events.filter((e) => !e.discarded), + context, + ), + ), + [state, context], + ) + + return ( +
+    <div className={classes.root}>
+      <div className={classes.heading}>Qurator Developer Tools</div>
+      <div className={classes.contents}>
+        <JsonDisplay className={classes.json} name="context" value={context} />
+        <JsonDisplay className={classes.json} name="prompt" value={prompt} />
+      </div>
+    </div>
+ ) +} diff --git a/catalog/app/components/Assistant/UI/Chat/Input.tsx b/catalog/app/components/Assistant/UI/Chat/Input.tsx new file mode 100644 index 00000000000..5f7b89b8663 --- /dev/null +++ b/catalog/app/components/Assistant/UI/Chat/Input.tsx @@ -0,0 +1,116 @@ +import cx from 'classnames' +import * as React from 'react' +import * as M from '@material-ui/core' + +import { createCustomAppTheme } from 'constants/style' + +const useStyles = M.makeStyles((t) => ({ + input: { + alignItems: 'center', + display: 'flex', + paddingLeft: `${t.spacing(2)}px`, + paddingRight: `${t.spacing(2)}px`, + }, + textField: { + marginTop: 0, + }, + hint: { + color: t.palette.text.hint, + }, +})) + +const backgroundColor = M.colors.indigo[900] +const backgroundColorLt = M.lighten(backgroundColor, 0.1) + +const useInputStyles = M.makeStyles({ + focused: {}, + disabled: {}, + root: { + backgroundColor, + borderRadius: '8px', + color: M.fade(M.colors.common.white, 0.8), + '&:hover': { + backgroundColor: backgroundColorLt, + // Reset on touch devices, it doesn't add specificity + '@media (hover: none)': { + backgroundColor, + }, + }, + '&$focused': { + backgroundColor, + }, + '&$disabled': { + backgroundColor: backgroundColorLt, + }, + }, +}) + +const useLabelStyles = M.makeStyles({ + focused: {}, + root: { + color: M.fade(M.colors.common.white, 0.6), + '&$focused': { + color: M.fade(M.colors.common.white, 0.6), + }, + }, +}) + +const darkTheme = createCustomAppTheme({ palette: { type: 'dark' } } as any) + +interface ChatInputProps { + className?: string + disabled?: boolean + onSubmit: (value: string) => void +} + +export default function ChatInput({ className, disabled, onSubmit }: ChatInputProps) { + const classes = useStyles() + + const [value, setValue] = React.useState('') + + const handleSubmit = React.useCallback( + (event) => { + event.preventDefault() + if (!value || disabled) return + onSubmit(value) + setValue('') + }, + [disabled, onSubmit, value], + ) + + return ( +
+ + setValue(e.target.value)} + value={value} + variant="filled" + autoFocus + fullWidth + margin="normal" + label="Ask Qurator" + helperText="Qurator may make errors. Verify important information." + InputProps={{ + disableUnderline: true, + classes: useInputStyles(), + endAdornment: ( + + + send + + + ), + }} + InputLabelProps={{ classes: useLabelStyles() }} + FormHelperTextProps={{ classes: { root: classes.hint } }} + /> + +
+ ) +} diff --git a/catalog/app/components/Assistant/UI/Chat/index.ts b/catalog/app/components/Assistant/UI/Chat/index.ts new file mode 100644 index 00000000000..9ceb934d30f --- /dev/null +++ b/catalog/app/components/Assistant/UI/Chat/index.ts @@ -0,0 +1 @@ +export { default } from './Chat' diff --git a/catalog/app/components/Assistant/UI/UI.tsx b/catalog/app/components/Assistant/UI/UI.tsx new file mode 100644 index 00000000000..ebedc8afacc --- /dev/null +++ b/catalog/app/components/Assistant/UI/UI.tsx @@ -0,0 +1,66 @@ +import * as React from 'react' +import * as M from '@material-ui/core' + +import * as style from 'constants/style' + +import * as Model from '../Model' +import Chat from './Chat' + +const useSidebarStyles = M.makeStyles({ + sidebar: { + background: M.colors.indigo[50], + display: 'flex', + height: '100%', + maxWidth: '40rem', + width: '50vw', + }, +}) + +function Sidebar() { + const classes = useSidebarStyles() + + const api = Model.useAssistantAPI() + if (!api) return null + + return ( + + +
+ +
+
+
+ ) +} + +const useTriggerStyles = M.makeStyles({ + trigger: { + bottom: '50px', + position: 'fixed', + right: '100px', + zIndex: 1, + }, +}) + +function Trigger() { + const classes = useTriggerStyles() + const api = Model.useAssistantAPI() + if (!api) return null + return ( + + + assistant + + + ) +} + +export function WithAssistantUI({ children }: React.PropsWithChildren<{}>) { + return ( + <> + {children} + + + + ) +} diff --git a/catalog/app/components/Assistant/UI/index.ts b/catalog/app/components/Assistant/UI/index.ts new file mode 100644 index 00000000000..d9cf6579ad9 --- /dev/null +++ b/catalog/app/components/Assistant/UI/index.ts @@ -0,0 +1 @@ +export * from './UI' diff --git a/catalog/app/components/Assistant/index.ts b/catalog/app/components/Assistant/index.ts new file mode 100644 index 00000000000..3bc0795dae1 --- /dev/null +++ b/catalog/app/components/Assistant/index.ts @@ -0,0 +1,14 @@ +import { useAssistant, AssistantProvider } from './Model' +import { WithAssistantUI } from './UI' + +export { + WithAssistantUI, + WithAssistantUI as WithUI, + AssistantProvider, + AssistantProvider as Provider, + useAssistant, + useAssistant as use, +} + +export * as Context from './Model/Context' +export * as Model from './Model' diff --git a/catalog/app/components/BreadCrumbs/BreadCrumbs.spec.tsx b/catalog/app/components/BreadCrumbs/BreadCrumbs.spec.tsx index 073f49f73b0..63b067feb24 100644 --- a/catalog/app/components/BreadCrumbs/BreadCrumbs.spec.tsx +++ b/catalog/app/components/BreadCrumbs/BreadCrumbs.spec.tsx @@ -80,7 +80,7 @@ describe('components/BreadCrumbs', () => { expect(tree).toMatchSnapshot() }) }) - test('copyWithoutSpaces', () => { + it('copyWithoutSpaces', () => { const input = `ROOT / aa a / bb-b / c / / d_d` expect(BreadCrumbs.trimSeparatorSpaces(input)).toBe('/aa a/bb-b/c//d_d') diff --git a/catalog/app/components/Chat/Chat.tsx b/catalog/app/components/Chat/Chat.tsx new file mode 100644 index 00000000000..688296a3d27 --- /dev/null +++ b/catalog/app/components/Chat/Chat.tsx @@ -0,0 +1,91 @@ +import * as React from 'react' +import * as M from '@material-ui/core' +import * as Lab from '@material-ui/lab' + +import Skeleton from 'components/Skeleton' +import type * as AWS from 'utils/AWS' + +import History from './History' +import Input from './Input' + +const useStyles = M.makeStyles((t) => ({ + root: { + display: 'flex', + flexDirection: 'column', + flexGrow: 1, + overflow: 'hidden', + }, + error: { + marginTop: t.spacing(2), + }, + history: { + ...t.typography.body1, + maxHeight: t.spacing(70), + overflowY: 'auto', + }, + input: { + marginTop: t.spacing(2), + }, +})) + +const noMessages: AWS.Bedrock.Message[] = [] + +export function ChatSkeleton() { + const classes = useStyles() + return ( +
+ + +
+ ) +} + +const Submitting = Symbol('Submitting') + +interface ChatProps { + initializing: boolean + history: AWS.Bedrock.History + onSubmit: (value: string) => Promise +} + +export default function Chat({ history, onSubmit, initializing }: ChatProps) { + const classes = useStyles() + + const [value, setValue] = React.useState('') + const [state, setState] = React.useState(null) + + const handleSubmit = React.useCallback(async () => { + if (state) return + + setState(Submitting) + try { + await onSubmit(value) + setValue('') + } catch (e) { + setState(e instanceof Error ? e : new Error('Failed to submit message')) + } + setState(null) + }, [state, onSubmit, value]) + + return ( +
+ + {state instanceof Error && ( + + {state.message} + + )} + +
+  )
+}
diff --git a/catalog/app/components/Chat/History.tsx b/catalog/app/components/Chat/History.tsx
new file mode 100644
index 00000000000..bb1ef818756
--- /dev/null
+++ b/catalog/app/components/Chat/History.tsx
@@ -0,0 +1,101 @@
+import cx from 'classnames'
+import * as React from 'react'
+import * as M from '@material-ui/core'
+
+import usePrevious from 'utils/usePrevious'
+import * as AWS from 'utils/AWS'
+
+import * as Messages from './Message'
+
+const useStyles = M.makeStyles((t) => ({
+  assistant: {
+    animation: `$show 300ms ease-out`,
+  },
+  message: {
+    '& + &': {
+      marginTop: t.spacing(2),
+    },
+  },
+  user: {
+    animation: `$slide 150ms ease-out`,
+    marginLeft: 'auto',
+    width: '60%',
+  },
+  '@keyframes slide': {
+    '0%': {
+      transform: `translateX(${t.spacing(8)}px)`,
+    },
+    '100%': {
+      transform: `translateX(0)`,
+    },
+  },
+  '@keyframes show': {
+    '0%': {
+      opacity: 0.7,
+    },
+    '100%': {
+      opacity: 1,
+    },
+  },
+}))
+
+interface HistoryProps {
+  className?: string
+  loading: boolean
+  messages: AWS.Bedrock.Message[]
+}
+
+export default function History({ className, loading, messages }: HistoryProps) {
+  const classes = useStyles()
+
+  const list = React.useMemo(
+    () => messages.filter((message) => message.role !== 'system'),
+    [messages],
+  )
+
+  const ref = React.useRef<HTMLDivElement | null>(null)
+  usePrevious(messages, (prev) => {
+    if (prev && messages.length > prev.length) {
+      ref.current?.scroll({
+        top: ref.current?.firstElementChild?.clientHeight,
+        behavior: 'smooth',
+      })
+    }
+  })
+
+  return (
+
+
+ {list.map((message, index) => { + switch (message.role) { + case 'user': + return ( + + ) + case 'summarize': + return ( + + ) + case 'assistant': + return ( + + ) + } + })} + {loading && } +
+
+ ) +} diff --git a/catalog/app/components/Chat/Input.tsx b/catalog/app/components/Chat/Input.tsx new file mode 100644 index 00000000000..6605883d1a9 --- /dev/null +++ b/catalog/app/components/Chat/Input.tsx @@ -0,0 +1,52 @@ +import * as React from 'react' +import * as M from '@material-ui/core' + +interface ChatInputProps { + className?: string + disabled?: boolean + onChange: (value: string) => void + onSubmit: () => void + value: string +} + +export default function ChatInput({ + className, + disabled, + onChange, + onSubmit, + value, +}: ChatInputProps) { + const handleSubmit = React.useCallback( + (event) => { + event.preventDefault() + if (!value || disabled) return + onSubmit() + }, + [disabled, onSubmit, value], + ) + return ( +
+ onChange(e.target.value)} + size="small" + value={value} + variant="outlined" + InputProps={{ + endAdornment: ( + + + send + + + ), + }} + /> + + ) +} diff --git a/catalog/app/components/Chat/Message.tsx b/catalog/app/components/Chat/Message.tsx new file mode 100644 index 00000000000..390a265ae1f --- /dev/null +++ b/catalog/app/components/Chat/Message.tsx @@ -0,0 +1,73 @@ +import cx from 'classnames' +import * as React from 'react' +import * as M from '@material-ui/core' +import { fade } from '@material-ui/core/styles' + +import Markdown from 'components/Markdown' +import Skel from 'components/Skeleton' + +const useSkeletonStyles = M.makeStyles((t) => ({ + text: { + height: t.spacing(2), + '& + &': { + marginTop: t.spacing(1), + }, + }, +})) + +interface SkeletonProps { + className?: string +} + +export function Skeleton({ className }: SkeletonProps) { + const classes = useSkeletonStyles() + return ( +
+    <div className={className}>
+      <Skel className={classes.text} />
+      <Skel className={classes.text} />
+      <Skel className={classes.text} />
+    </div>
+ ) +} + +interface AssistantProps { + className?: string + content: string +} + +export function Assistant({ className, content }: AssistantProps) { + return ( +
+    <Markdown className={className} data={content} />
+ ) +} + +const useUserStyles = M.makeStyles((t) => ({ + root: { + borderRadius: t.shape.borderRadius, + background: t.palette.primary.main, + }, + inner: { + padding: t.spacing(2), + background: fade(t.palette.background.paper, 0.9), + }, +})) + +interface UserProps { + className?: string + content: string +} + +export function User({ className, content }: UserProps) { + const classes = useUserStyles() + + return ( +
+    <div className={cx(classes.root, className)}>
+      <div className={classes.inner}>
+        <Markdown data={content} />
+      </div>
+    </div>
+ ) +} diff --git a/catalog/app/components/Chat/index.ts b/catalog/app/components/Chat/index.ts new file mode 100644 index 00000000000..ef4f2666527 --- /dev/null +++ b/catalog/app/components/Chat/index.ts @@ -0,0 +1,3 @@ +export { default as Input } from './Input' +export { default as History } from './History' +export { default, ChatSkeleton } from './Chat' diff --git a/catalog/app/components/Dialog/Confirm.tsx b/catalog/app/components/Dialog/Confirm.tsx index 15d1749d709..0146c0e071c 100644 --- a/catalog/app/components/Dialog/Confirm.tsx +++ b/catalog/app/components/Dialog/Confirm.tsx @@ -44,6 +44,7 @@ interface PromptProps { title: string } +// TODO: Re-use utils/Dialog export function useConfirm({ cancelTitle, title, onSubmit, submitTitle }: PromptProps) { const [key, setKey] = React.useState(0) const [opened, setOpened] = React.useState(false) diff --git a/catalog/app/components/Dialog/Prompt.tsx b/catalog/app/components/Dialog/Prompt.tsx index 999d331fa6d..dd87c916468 100644 --- a/catalog/app/components/Dialog/Prompt.tsx +++ b/catalog/app/components/Dialog/Prompt.tsx @@ -4,19 +4,23 @@ import * as M from '@material-ui/core' import * as Lab from '@material-ui/lab' interface DialogProps { + children: React.ReactNode initialValue?: string onCancel: () => void onSubmit: (value: string) => void open: boolean + placeholder?: string title: string validate: (value: string) => Error | undefined } function Dialog({ + children, initialValue, - open, onCancel, onSubmit, + open, + placeholder, title, validate, }: DialogProps) { @@ -26,6 +30,7 @@ function Dialog({ const handleChange = React.useCallback((event) => setValue(event.target.value), []) const handleSubmit = React.useCallback( (event) => { + event.stopPropagation() event.preventDefault() setSubmitted(true) if (!error) onSubmit(value) @@ -37,11 +42,13 @@ function Dialog({
{title} + {children} {!!error && !!submitted && ( @@ -69,11 +76,19 @@ function Dialog({ interface PromptProps { initialValue?: string onSubmit: (value: string) => void + placeholder?: string title: string validate: (value: string) => Error | undefined } -export function usePrompt({ initialValue, title, onSubmit, validate }: PromptProps) { +// TODO: Re-use utils/Dialog +export function usePrompt({ + onSubmit, + initialValue, + placeholder, + validate, + title, +}: PromptProps) { const [key, setKey] = React.useState(0) const [opened, setOpened] = React.useState(false) const open = React.useCallback(() => { @@ -89,20 +104,22 @@ export function usePrompt({ initialValue, title, onSubmit, validate }: PromptPro [close, onSubmit], ) const render = React.useCallback( - () => ( + (children?: React.ReactNode) => ( ), - [initialValue, key, close, handleSubmit, opened, title, validate], + [close, handleSubmit, initialValue, key, opened, placeholder, title, validate], ) return React.useMemo( () => ({ diff --git a/catalog/app/components/Experiments/Experiments.js b/catalog/app/components/Experiments/Experiments.js deleted file mode 100644 index 25a064b7cde..00000000000 --- a/catalog/app/components/Experiments/Experiments.js +++ /dev/null @@ -1,51 +0,0 @@ -import * as R from 'ramda' -import * as React from 'react' - -// map of experiment name to array of variants -const EXPERIMENTS = { - cta: [ - 'Ready to get your data organized?', - 'Ready to experiment faster?', - 'Ready to maximize return on data?', - ], - lede: ['Accelerate from data to impact', 'Manage data like code', 'Discover faster'], -} - -const Ctx = React.createContext() - -const pickRandom = (arr) => arr[Math.floor(Math.random() * arr.length)] - -const mapKeys = (fn) => - R.pipe( - R.toPairs, - R.map(([k, v]) => [fn(k, v), v]), - R.fromPairs, - ) - -export function ExperimentsProvider({ children }) { - const ref = React.useRef({}) - - const get = React.useCallback( - (name) => { - if (!(name in ref.current)) { - ref.current[name] = pickRandom(EXPERIMENTS[name]) - } - return ref.current[name] - }, - [ref], - ) - - const getSelectedVariants = React.useCallback( - (prefix = '') => mapKeys((k) => `${prefix}${k}`)(ref.current), - [ref], - ) - - return {children} -} - -export function useExperiments(experiment) { - const exps = React.useContext(Ctx) - return experiment ? 
exps.get(experiment) : exps -} - -export { ExperimentsProvider as Provider, useExperiments as use } diff --git a/catalog/app/components/Experiments/index.js b/catalog/app/components/Experiments/index.js deleted file mode 100644 index f0a23d3783a..00000000000 --- a/catalog/app/components/Experiments/index.js +++ /dev/null @@ -1 +0,0 @@ -export * from './Experiments' diff --git a/catalog/app/components/FileEditor/Controls.tsx b/catalog/app/components/FileEditor/Controls.tsx index dad90109590..be174abce4b 100644 --- a/catalog/app/components/FileEditor/Controls.tsx +++ b/catalog/app/components/FileEditor/Controls.tsx @@ -16,6 +16,26 @@ export function AddFileButton({ onClick }: AddFileButtonProps) { ) } +interface PreviewButtonProps extends EditorState { + className?: string + onPreview: NonNullable +} + +export function PreviewButton({ className, preview, onPreview }: PreviewButtonProps) { + const handleClick = React.useCallback(() => onPreview(!preview), [onPreview, preview]) + return ( + event.stopPropagation()} + className={className} + control={ + + } + label="Preview" + labelPlacement="end" + /> + ) +} + interface ControlsProps extends EditorState { className?: string } @@ -63,7 +83,7 @@ export function Controls({ setAnchorEl(null)}> {types.map((type) => ( handleTypeClick(type)} key={type.brace}> - Edit as {type.title || type.brace} + {type.title || 'Edit file'} ))} diff --git a/catalog/app/components/FileEditor/FileEditor.tsx b/catalog/app/components/FileEditor/FileEditor.tsx index 14d6c2bed33..e34851c84db 100644 --- a/catalog/app/components/FileEditor/FileEditor.tsx +++ b/catalog/app/components/FileEditor/FileEditor.tsx @@ -1,7 +1,10 @@ +import cx from 'classnames' import * as React from 'react' +import * as M from '@material-ui/core' -import * as PreviewUtils from 'components/Preview/loaders/utils' import PreviewDisplay from 'components/Preview/Display' +import * as PreviewUtils from 'components/Preview/loaders/utils' +import { QuickPreview } from 'components/Preview/quick' import type * as Model from 'model' import AsyncResult from 'utils/AsyncResult' @@ -15,12 +18,14 @@ import { EditorInputType } from './types' export { detect, isSupportedFileType } from './loader' interface EditorProps extends EditorState { + className: string editing: EditorInputType empty?: boolean handle: Model.S3.S3ObjectLocation } function EditorSuspended({ + className, saving, empty, error, @@ -37,6 +42,7 @@ function EditorSuspended({ if (empty) return editing.brace === '__quiltConfig' ? ( ) : ( - + ) return data.case({ _: () => , @@ -61,6 +74,7 @@ function EditorSuspended({ if (editing.brace === '__quiltConfig') { return ( ) }, }) } +const useStyles = M.makeStyles({ + tab: { + display: 'none', + width: '100%', + }, + active: { + display: 'block', + }, +}) + export function Editor(props: EditorProps) { + const classes = useStyles() return ( }> - +
+ +
+ {props.preview && ( +
+ +
+ )}
) } diff --git a/catalog/app/components/FileEditor/HelpLinks.spec.tsx b/catalog/app/components/FileEditor/HelpLinks.spec.tsx new file mode 100644 index 00000000000..feae2501025 --- /dev/null +++ b/catalog/app/components/FileEditor/HelpLinks.spec.tsx @@ -0,0 +1,97 @@ +import * as React from 'react' +import renderer from 'react-test-renderer' + +import { MissingSourceBucket, WorkflowsConfigLink } from './HelpLinks' + +jest.mock( + 'constants/config', + jest.fn(() => ({})), +) + +jest.mock( + 'utils/StyledLink', + () => + ({ href, to, children }: React.PropsWithChildren<{ href: string; to: string }>) => ( +
+      <a href={href || to}>{children}</a>
+    ),
+      <div>
+        {title}
+        {children}
+      </div>
+ ), +) + +jest.mock('components/Code', () => ({ children }: React.PropsWithChildren<{}>) => ( + {children} +)) + +jest.mock('utils/NamedRoutes', () => ({ + ...jest.requireActual('utils/NamedRoutes'), + use: () => ({ + urls: { + bucketFile: (b: string, k: string, opts: Record) => { + const params = new URLSearchParams(opts) + return `/b/${b}/tree/k/${k}?${params}` + }, + }, + }), +})) + +jest.mock('utils/GlobalDialogs', () => ({ + use: jest.fn(), +})) + +const useLocation = jest.fn( + () => ({ pathname: '/a/b/c', search: '?foo=bar' }) as Record, +) + +const useParams = jest.fn(() => ({ bucket: 'buck' }) as Record) + +jest.mock('react-router-dom', () => ({ + ...jest.requireActual('react-router-dom'), + useParams: jest.fn(() => useParams()), + useLocation: jest.fn(() => useLocation()), +})) + +describe('components/FileEditor/HelpLinks', () => { + describe('WorkflowsConfigLink', () => { + it('should render', () => { + const tree = renderer + .create(Test) + .toJSON() + expect(tree).toMatchSnapshot() + }) + + it('should throw outside bucket', () => { + jest.spyOn(console, 'error').mockImplementationOnce(jest.fn()) + useParams.mockImplementationOnce(() => ({})) + const tree = () => renderer.create(Any) + expect(tree).toThrowError('`bucket` must be defined') + }) + }) + + describe('MissingSourceBucket', () => { + it('should render', () => { + const tree = renderer + .create(Disabled button) + .toJSON() + expect(tree).toMatchSnapshot() + }) + + it('should throw outside bucket', () => { + jest.spyOn(console, 'error').mockImplementationOnce(jest.fn()) + useParams.mockImplementationOnce(() => ({})) + const tree = renderer + .create(Any) + .toJSON() + expect(tree).toMatchSnapshot() + }) + }) +}) diff --git a/catalog/app/components/FileEditor/HelpLinks.tsx b/catalog/app/components/FileEditor/HelpLinks.tsx new file mode 100644 index 00000000000..41cfe1df14f --- /dev/null +++ b/catalog/app/components/FileEditor/HelpLinks.tsx @@ -0,0 +1,194 @@ +import invariant from 'invariant' +import * as React from 'react' +import * as RRDom from 'react-router-dom' +import * as M from '@material-ui/core' +import * as Lab from '@material-ui/lab' + +import Code from 'components/Code' +import Lock from 'components/Lock' +import * as quiltConfigs from 'constants/quiltConfigs' +import { docs } from 'constants/urls' +import type * as Model from 'model' +import * as BucketPreferences from 'utils/BucketPreferences' +import { createBoundary } from 'utils/ErrorBoundary' +import * as Dialogs from 'utils/GlobalDialogs' +import Log from 'utils/Logging' +import * as NamedRoutes from 'utils/NamedRoutes' +import StyledLink from 'utils/StyledLink' +import StyledTooltip from 'utils/StyledTooltip' + +function useRouteToEditFile(handle: Model.S3.S3ObjectLocation) { + const { urls } = NamedRoutes.use() + const { pathname, search } = RRDom.useLocation() + const next = pathname + search + return urls.bucketFile(handle.bucket, handle.key, { edit: true, next }) +} + +interface WrapperProps { + children: React.ReactNode +} + +export function WorkflowsConfigLink({ children }: WrapperProps) { + const { bucket } = RRDom.useParams<{ bucket: string }>() + invariant(bucket, '`bucket` must be defined') + + const toConfig = useRouteToEditFile({ bucket, key: quiltConfigs.workflows }) + return {children} +} + +const Loading = Symbol('loading') + +interface AddMissingSourceBucketProps { + bucket: string + close: Dialogs.Close + onSubmit: () => Promise +} + +function MissingSourceBucketAddConfirmation({ + bucket, + close, + onSubmit, +}: 
AddMissingSourceBucketProps) { + const [state, setState] = React.useState() + const handleSubmit = React.useCallback(async () => { + setState(Loading) + try { + await onSubmit() + + setState() + close() + } catch (error) { + Log.error(error) + setState(error instanceof Error ? error : new Error('Unknown error')) + } + }, [close, onSubmit]) + return ( + <> + Add {bucket} to the source buckets list + {state instanceof Error && ( + + {state.message} + + )} + {state === Loading && ( + + + + )} + + + Cancel + + + Update config + + + + ) +} + +const DIALOG_PROPS = { + maxWidth: 'sm' as const, + fullWidth: true, +} + +const useMissingSourceBucketTooltipStyles = M.makeStyles({ + // browsers break the word on '-' + nowrap: { + whiteSpace: 'nowrap', + }, +}) + +function MissingSourceBucketTooltip() { + const { bucket } = RRDom.useParams<{ bucket: string }>() + invariant(bucket, '`bucket` must be defined') + + const classes = useMissingSourceBucketTooltipStyles() + + const { handle, update } = BucketPreferences.use() + + const toConfig = useRouteToEditFile( + handle || { bucket, key: quiltConfigs.bucketPreferences[0] }, + ) + + const open = Dialogs.use() + + const autoAdd = React.useCallback(async () => { + await update(BucketPreferences.sourceBucket(bucket)) + }, [bucket, update]) + + const showConfirmation = React.useCallback(() => { + open( + ({ close }) => ( + + ), + DIALOG_PROPS, + ) + }, [autoAdd, bucket, open]) + + return ( + <> + + Config property ui.sourceBuckets is empty.{' '} + + Learn more + + . + + + Edit manually or{' '} + + auto-add current bucket ( + s3://{bucket}) + + + + ) +} + +const ErrorBoundary = createBoundary( + (props: Lab.AlertProps) => (error: Error) => ( + + {error.message || 'Unexpected Error'} + + ), + 'MissingSourceBucketErrorBoundary', +) + +interface MissingSourceBucketProps { + className?: string + children: React.ReactNode +} + +const tooltipStyles = M.makeStyles((t) => ({ + tooltip: { + maxWidth: t.spacing(36), + }, +})) + +export function MissingSourceBucket({ className, children }: MissingSourceBucketProps) { + const classes = tooltipStyles() + return ( + + + + } + > +
{children}
+
+ ) +} diff --git a/catalog/app/components/FileEditor/QuiltConfigEditor/BucketPreferences.tsx b/catalog/app/components/FileEditor/QuiltConfigEditor/BucketPreferences.tsx index 67737f681f5..9abbd53a7a4 100644 --- a/catalog/app/components/FileEditor/QuiltConfigEditor/BucketPreferences.tsx +++ b/catalog/app/components/FileEditor/QuiltConfigEditor/BucketPreferences.tsx @@ -12,7 +12,10 @@ function Header() { return ( Configuration for Catalog UI: show and hide features, set default values. See{' '} - + the docs diff --git a/catalog/app/components/FileEditor/QuiltConfigEditor/Dummy.tsx b/catalog/app/components/FileEditor/QuiltConfigEditor/Dummy.tsx index 6daad7532f1..316a15df636 100644 --- a/catalog/app/components/FileEditor/QuiltConfigEditor/Dummy.tsx +++ b/catalog/app/components/FileEditor/QuiltConfigEditor/Dummy.tsx @@ -1,6 +1,6 @@ import * as React from 'react' -import type { JsonSchema } from 'utils/json-schema' +import type { JsonSchema } from 'utils/JSONSchema' export interface ConfigDetailsProps { children: (props: { diff --git a/catalog/app/components/FileEditor/QuiltConfigEditor/QuiltConfigEditor.tsx b/catalog/app/components/FileEditor/QuiltConfigEditor/QuiltConfigEditor.tsx index 01a091f0b12..340ff0ea457 100644 --- a/catalog/app/components/FileEditor/QuiltConfigEditor/QuiltConfigEditor.tsx +++ b/catalog/app/components/FileEditor/QuiltConfigEditor/QuiltConfigEditor.tsx @@ -1,10 +1,11 @@ import type { ErrorObject } from 'ajv' +import cx from 'classnames' import * as React from 'react' import * as M from '@material-ui/core' import JsonEditor from 'components/JsonEditor' import JsonValidationErrors from 'components/JsonValidationErrors' -import { JsonSchema, makeSchemaValidator } from 'utils/json-schema' +import { JsonSchema, makeSchemaValidator } from 'utils/JSONSchema' import * as YAML from 'utils/yaml' const useStyles = M.makeStyles((t) => ({ @@ -22,6 +23,7 @@ const useStyles = M.makeStyles((t) => ({ })) export interface QuiltConfigEditorProps { + className?: string disabled?: boolean error: Error | null initialValue?: string @@ -34,6 +36,7 @@ interface QuiltConfigEditorEssentialProps { } export default function QuiltConfigEditorSuspended({ + className, disabled, error, header, @@ -56,7 +59,7 @@ export default function QuiltConfigEditorSuspended({ [onChange, validate], ) return ( -
+
{!!header &&
{header}
} Configuration for data quality workflows. See{' '} - + the docs diff --git a/catalog/app/components/FileEditor/Skeleton.tsx b/catalog/app/components/FileEditor/Skeleton.tsx index 727e8350687..fef3e1d579c 100644 --- a/catalog/app/components/FileEditor/Skeleton.tsx +++ b/catalog/app/components/FileEditor/Skeleton.tsx @@ -6,7 +6,7 @@ import Skel from 'components/Skeleton' const useSkeletonStyles = M.makeStyles((t) => ({ root: { display: 'flex', - height: t.spacing(30), + height: ({ height }: { height: number }) => t.spacing(height), width: '100%', }, lineNumbers: { @@ -16,6 +16,7 @@ const useSkeletonStyles = M.makeStyles((t) => ({ content: { flexGrow: 1, marginLeft: t.spacing(2), + overflow: 'hidden', }, line: { height: t.spacing(2), @@ -25,8 +26,12 @@ const useSkeletonStyles = M.makeStyles((t) => ({ const fakeLines = [80, 50, 100, 60, 30, 80, 50, 100, 60, 30, 20, 70] -export default function Skeleton() { - const classes = useSkeletonStyles() +interface SkeletonProps { + height?: number +} + +export default function Skeleton({ height = 30 }: SkeletonProps) { + const classes = useSkeletonStyles({ height }) return (
diff --git a/catalog/app/components/FileEditor/State.tsx b/catalog/app/components/FileEditor/State.tsx index 1bf86f92919..75c57b249fc 100644 --- a/catalog/app/components/FileEditor/State.tsx +++ b/catalog/app/components/FileEditor/State.tsx @@ -2,6 +2,7 @@ import * as React from 'react' import * as RRDom from 'react-router-dom' import type * as Model from 'model' +import { isQuickPreviewAvailable } from 'components/Preview/quick' import * as AddToPackage from 'containers/AddToPackage' import * as NamedRoutes from 'utils/NamedRoutes' import parseSearch from 'utils/parseSearch' @@ -32,7 +33,9 @@ export interface EditorState { onCancel: () => void onChange: (value: string) => void onEdit: (type: EditorInputType | null) => void + onPreview: ((p: boolean) => void) | null onSave: () => Promise + preview: boolean saving: boolean types: EditorInputType[] value?: string @@ -48,6 +51,7 @@ export function useState(handle: Model.S3.S3ObjectLocation): EditorState { const [editing, setEditing] = React.useState( edit ? types[0] : null, ) + const [preview, setPreview] = React.useState(false) const [saving, setSaving] = React.useState(false) const writeFile = useWriteData(handle) const redirect = useRedirect() @@ -80,11 +84,13 @@ export function useState(handle: Model.S3.S3ObjectLocation): EditorState { onCancel, onChange: setValue, onEdit: setEditing, + onPreview: isQuickPreviewAvailable(editing) ? setPreview : null, onSave, + preview, saving, types, value, }), - [editing, error, onCancel, onSave, saving, types, value], + [editing, error, onCancel, onSave, preview, saving, types, value], ) } diff --git a/catalog/app/components/FileEditor/TextEditor.tsx b/catalog/app/components/FileEditor/TextEditor.tsx index c3eb8872337..c9aded95ec1 100644 --- a/catalog/app/components/FileEditor/TextEditor.tsx +++ b/catalog/app/components/FileEditor/TextEditor.tsx @@ -1,7 +1,7 @@ import * as brace from 'brace' +import cx from 'classnames' import * as React from 'react' import * as M from '@material-ui/core' -import * as Lab from '@material-ui/lab' import Lock from 'components/Lock' @@ -11,33 +11,50 @@ import 'brace/theme/eclipse' const useEditorTextStyles = M.makeStyles((t) => ({ root: { - border: `1px solid ${t.palette.divider}`, - width: '100%', + display: 'flex', + flexDirection: 'column', position: 'relative', + width: '100%', }, editor: { - height: t.spacing(50), + border: `1px solid ${t.palette.divider}`, + flexGrow: 1, resize: 'vertical', }, error: { - marginTop: t.spacing(1), + '& $editor': { + borderColor: t.palette.error.main, + }, + '& $helperText': { + color: t.palette.error.main, + }, + }, + helperText: { + marginTop: t.spacing(0.5), + whiteSpace: 'pre-wrap', // TODO: use JsonValidationErrors }, })) interface TextEditorProps { + autoFocus?: boolean + className: string disabled?: boolean + error: Error | null + leadingChange?: boolean onChange: (value: string) => void type: EditorInputType - value?: string - error: Error | null + initialValue?: string } export default function TextEditor({ - error, + className, + autoFocus, disabled, - type, - value = '', + error, + leadingChange = true, onChange, + type, + initialValue = '', }: TextEditorProps) { const classes = useEditorTextStyles() const ref = React.useRef(null) @@ -53,23 +70,35 @@ export default function TextEditor({ editor.getSession().setMode(`ace/mode/${type.brace}`) editor.setTheme('ace/theme/eclipse') - editor.setValue(value, -1) - onChange(editor.getValue()) // initially fill the value + + editor.$blockScrolling = Infinity + 
editor.setValue(initialValue, -1) + if (leadingChange) { + // Initially fill the value in the parent component. + // TODO: Re-design fetching data, so leading onChange won't be necessary + // probably, by putting data fetch into FileEditor/State + onChange(editor.getValue()) + } editor.on('change', () => onChange(editor.getValue())) + if (autoFocus) { + editor.focus() + wrapper.scrollIntoView() + } + return () => { resizeObserver.unobserve(wrapper) editor.destroy() } - }, [onChange, ref, type.brace, value]) + }, [autoFocus, leadingChange, onChange, ref, type.brace, initialValue]) return ( -
+
{error && ( - + {error.message} - + )} {disabled && }
diff --git a/catalog/app/components/FileEditor/__snapshots__/HelpLinks.spec.tsx.snap b/catalog/app/components/FileEditor/__snapshots__/HelpLinks.spec.tsx.snap new file mode 100644 index 00000000000..c74723e412e --- /dev/null +++ b/catalog/app/components/FileEditor/__snapshots__/HelpLinks.spec.tsx.snap @@ -0,0 +1,93 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`components/FileEditor/HelpLinks MissingSourceBucket should render 1`] = ` +
+

+ Config property + + ui.sourceBuckets + + is empty. + + + Learn more + + . +

+

+ + Edit manually + + or + + + + auto-add + +  current bucket ( + + s3:// + buck + + ) + +

+
+
+ Disabled button +
+
+`; + +exports[`components/FileEditor/HelpLinks MissingSourceBucket should throw outside bucket 1`] = ` +
+
+
+ + + +
+
+ \`bucket\` must be defined +
+
+
+
+ Any +
+
+`; + +exports[`components/FileEditor/HelpLinks WorkflowsConfigLink should render 1`] = ` + + Test + +`; diff --git a/catalog/app/components/FileEditor/index.ts b/catalog/app/components/FileEditor/index.ts index 9e0b2815bfe..f3beb287509 100644 --- a/catalog/app/components/FileEditor/index.ts +++ b/catalog/app/components/FileEditor/index.ts @@ -2,3 +2,4 @@ export * from './Controls' export * from './CreateFile' export * from './FileEditor' export * from './State' +export * from './types' diff --git a/catalog/app/components/FileEditor/loader.spec.ts b/catalog/app/components/FileEditor/loader.spec.ts index 57638f76a9b..6e1934db5be 100644 --- a/catalog/app/components/FileEditor/loader.spec.ts +++ b/catalog/app/components/FileEditor/loader.spec.ts @@ -1,5 +1,12 @@ import { isSupportedFileType } from './loader' +jest.mock( + 'constants/config', + jest.fn(() => ({ + apiGatewayEndpoint: '', + })), +) + describe('components/FileEditor/loader', () => { describe('isSupportedFileType', () => { it('should return true for supported files', () => { diff --git a/catalog/app/components/FileEditor/loader.ts b/catalog/app/components/FileEditor/loader.ts index df61fc2671b..feb4cbde3b5 100644 --- a/catalog/app/components/FileEditor/loader.ts +++ b/catalog/app/components/FileEditor/loader.ts @@ -24,7 +24,7 @@ export const loadMode = (mode: Mode) => { const isQuiltConfig = (path: string) => quiltConfigs.all.some((quiltConfig) => quiltConfig.includes(path)) const typeQuiltConfig: EditorInputType = { - title: 'Quilt config helper', + title: 'Edit with config helper', brace: '__quiltConfig', } @@ -44,13 +44,11 @@ const typeMarkdown: EditorInputType = { const isText = PreviewUtils.extIn(['.txt', '']) const typeText: EditorInputType = { - title: 'Plain text', brace: 'plain_text', } const isYaml = PreviewUtils.extIn(['.yaml', '.yml']) const typeYaml: EditorInputType = { - title: 'YAML', brace: 'yaml', } diff --git a/catalog/app/components/Footer/Footer.js b/catalog/app/components/Footer/Footer.tsx similarity index 76% rename from catalog/app/components/Footer/Footer.js rename to catalog/app/components/Footer/Footer.tsx index ab3cae20dd0..d10b6299b3a 100644 --- a/catalog/app/components/Footer/Footer.js +++ b/catalog/app/components/Footer/Footer.tsx @@ -1,4 +1,3 @@ -import cx from 'classnames' import * as React from 'react' import { Link } from 'react-router-dom' import * as M from '@material-ui/core' @@ -10,7 +9,6 @@ import * as style from 'constants/style' import * as URLS from 'constants/urls' import * as Notifications from 'containers/Notifications' import * as CatalogSettings from 'utils/CatalogSettings' -import HashLink from 'utils/HashLink' import * as NamedRoutes from 'utils/NamedRoutes' import copyToClipboard from 'utils/clipboard' @@ -37,52 +35,59 @@ function Version() { const classes = useVersionStyles() const { push } = Notifications.use() const handleCopy = React.useCallback(() => { - copyToClipboard(process.env.REVISION_HASH) + copyToClipboard(cfg.stackVersion) push('Web catalog container hash has been copied to clipboard') }, [push]) return ( -
- - Revision: {process.env.REVISION_HASH.substring(0, 8)} - -
+ + Version: {cfg.stackVersion} + ) } const FooterLogo = () => -const NavLink = (props) => ( - +const NavLink = (props: M.LinkProps) => ( + ) const NavSpacer = () => -const NavIcon = ({ icon, ...props }) => ( - - - -) +const useNavIconStyles = M.makeStyles({ + root: { + display: 'block', + height: '18px', + }, +}) + +interface NavIconProps extends M.BoxProps { + href: string + icon: string + target: string +} + +const NavIcon = ({ icon, ...props }: NavIconProps) => { + const classes = useNavIconStyles() + return ( + + + + ) +} const useStyles = M.makeStyles((t) => ({ - padded: {}, root: { background: `left / 64px url(${bg})`, boxShadow: [ '0px -12px 24px 0px rgba(25, 22, 59, 0.05)', '0px -16px 40px 0px rgba(25, 22, 59, 0.07)', '0px -24px 88px 0px rgba(25, 22, 59, 0.16)', - ], + ].join(', '), height: 230, paddingTop: t.spacing(6), position: 'relative', @@ -91,14 +96,6 @@ const useStyles = M.makeStyles((t) => ({ display: 'flex', paddingTop: 0, }, - // padding for marketing CTA - '&$padded': { - [t.breakpoints.down('sm')]: { - backgroundSize: 'contain', - height: 230 + 64, - paddingBottom: 64, - }, - }, }, container: { alignItems: 'center', @@ -124,6 +121,9 @@ const useStyles = M.makeStyles((t) => ({ `, }, }, + logoLink: { + display: 'block', + }, })) export default function Footer() { @@ -135,9 +135,7 @@ export default function Footer() { const reservedSpaceForIntercom = !intercom.dummy && !intercom.isCustom return ( -