diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5f19a8d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+# Cache files/python bytecode
+*.pyc
+*.pyo
+__pycache__
+*.pyd
+/.mypy_cache/
+
+# IDE configs
+/.vscode/
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
new file mode 100644
index 0000000..c6cc609
--- /dev/null
+++ b/CONTRIBUTORS.md
@@ -0,0 +1,12 @@
+# Contributors
+
+This document lists the code authors of REST-Attacker.
+
+Any file in this project that doesn't state otherwise is licensed under the terms of the GNU Lesser General Public License Version 3, or any later version (LGPL3+). A copy of the license can be found in [COPYING](/COPYING).
+
+Name | Aliases | Contributions
+--------------------|-----------------------|-------------------------
+Christoph Heine | heinezen | Initial design & release
+
+If you're a first-time contributor, add yourself to the above list. We keep it mainly for licensing reasons, but also to keep
+an overview of who has done what.
diff --git a/COPYING b/COPYING
new file mode 100644
index 0000000..0a04128
--- /dev/null
+++ b/COPYING
@@ -0,0 +1,165 @@
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..9a2a1f0
--- /dev/null
+++ b/README.md
@@ -0,0 +1,111 @@
+# REST-Attacker
+
+REST-Attacker is an automated penetration testing framework for APIs following the REST architecture style.
+The tool's focus is on streamlining the analysis of generic REST API implementations by completely automating
+the testing process - including test generation, access control handling, and report generation - with minimal
+configuration effort. Additionally, REST-Attacker is designed to be flexible and extensible with support
+for both large-scale testing and fine-grained analysis.
+
+REST-Attacker is maintained by the [Chair of Network & Data Security](https://informatik.rub.de/nds/) of the Ruhr University of Bochum.
+
+
+## Features
+
+REST-Attacker currently provides these features:
+
+- **Automated generation of tests**
+ - Utilize an OpenAPI description to automatically generate test runs
+ - 32 integrated security tests based on [OWASP](https://owasp.org/www-project-api-security/) and other scientific contributions
+ - Built-in creation of security reports
+- **Streamlined API communication**
+ - Custom request interface for the REST security use case (based on the Python3 [requests](https://requests.readthedocs.io/en/latest/) module)
+ - Communicate with any generic REST API
+- **Handling of access control**
+ - Background authentication/authorization with API
+ - Support for the most popular access control mechanisms: OAuth2, HTTP Basic Auth, API keys and more
+- **Easy to use & extend**
+ - Usable as standalone (CLI) tool or as a module
+ - Adapt test runs to specific APIs with extensive configuration options
+ - Create custom test cases or access control schemes with the tool's interfaces
+
+
+## Install
+
+Get the tool by downloading or cloning the repository:
+
+```
+git clone https://github.com/RUB-NDS/REST-Attacker.git
+```
+
+You need Python >3.10 for running the tool.
+
+You also need to install the following packages with pip:
+
+```
+python3 -m pip install pyyaml requests requests_oauthlib oauthlib jsf jsonschema
+```
+
+## Quickstart
+
+Here you can find a quick rundown of the most common and useful commands. You can find more
+information on each command and about other available configuration options in our [usage guides](doc/usage).
+
+Get the list of supported test cases:
+
+```
+python3 -m rest_attacker --list
+```
+
+Basic test run (with load-time test case generation):
+
+```
+python3 -m rest_attacker --generate
+```
+
+Full test run (with load-time and runtime test case generation + rate limit handling):
+
+```
+python3 -m rest_attacker --generate --propose --handle-limits
+```
+
+Test run with only selected test cases (only generates test cases for test cases `scopes.TestTokenRequestScopeOmit` and `resources.FindSecurityParameters`):
+
+```
+python3 -m rest_attacker --generate --test-cases scopes.TestTokenRequestScopeOmit resources.FindSecurityParameters
+```
+
+Rerun a test run from a report:
+
+```
+python3 -m rest_attacker --run /path/to/report.json
+```
+
+
+## Documentation
+
+Usage guides and configuration format documentation can be found in the [documentation](/doc) subfolders.
+
+
+## Troubleshooting
+
+For fixes/mitigations for known problems with the tool, see the [troubleshooting docs](/doc/troubleshooting.md) or the [Issues](https://github.com/RUB-NDS/REST-Attacker/issues) section.
+
+
+## Contributing
+
+Contributions of all kinds are appreciated! If you found a bug or want to make a suggestion or feature request, feel free
+to create a new [issue](https://github.com/RUB-NDS/REST-Attacker/issues) in the issue tracker. You can also submit fixes
+or code amendments via a [pull request](https://github.com/RUB-NDS/REST-Attacker/pulls).
+
+Unfortunately, we can be very busy sometimes, so it may take a while before we respond to comments in this repository.
+
+
+## License
+
+This project is licensed under **GNU LGPLv3 or later** (LGPL3+). See [COPYING](/COPYING) for the full license text and
+[CONTRIBUTORS.md](/CONTRIBUTORS.md) for the list of authors.
+
+
+
+
+
diff --git a/doc/README.md b/doc/README.md
new file mode 100644
index 0000000..0bfd55f
--- /dev/null
+++ b/doc/README.md
@@ -0,0 +1,12 @@
+# Documentation
+
+Documentation for REST-Attacker.
+
+- [Basic Usage Guide](/doc/guides/basic_usage.md)
+- [Advanced Usage Guide](/doc/guides/advanced_usage.md)
+- Configuration Formats
+ - [Info](/doc/formats/info.md)
+ - [Credentials & Access Control](/doc/formats/auth.md)
+ - [Report](/doc/formats/report.md)
+ - [Test Run](/doc/formats/run.md)
+ - [Meta](/doc/formats/meta.md)
diff --git a/doc/formats/auth.md b/doc/formats/auth.md
new file mode 100644
index 0000000..988cc8c
--- /dev/null
+++ b/doc/formats/auth.md
@@ -0,0 +1,458 @@
+# Authfile
+
+The auth file stores information that is used to authenticate with the service or make authorized requests.
+This includes credential information, schemes for creating authenticated/authorized request payloads and
+user sessions to retrieve OAuth2 tokens.
+
+The auth file is referenced in the mandatory [info file](info.md). The preferred filename is `credentials.json`.
+
+## Quick Reference
+
+```json
+{
+ "creds": {
+ "client0": {
+ "type": "oauth2_client",
+ "description": "OAuth Client",
+ "client_id": "aabbccddeeff123456789",
+ "client_secret": "abcdef12345678998765431fedcba",
+ "redirect_uri": "https://localhost:1234/test/",
+ "authorization_endpoint": "https://example.com/login/oauth/authorize",
+ "token_endpoint": "https://example.com/login/oauth/token",
+ "grants": [
+ "code",
+ "token"
+ ],
+ "scopes": [
+ "user"
+ ],
+ "flags": []
+ }
+ },
+ "schemes": {
+ "scheme0": {
+ "type": "header",
+ "key_id": "authorization",
+ "payload": "token {0}",
+ "params": {
+ "0": {
+ "id": "access_token",
+ "from": [
+                        "token0"
+ ]
+ }
+ }
+ }
+ },
+ "required_always": {},
+ "required_auth": {
+ "req0": [
+ "scheme0"
+ ]
+ },
+ "users": {
+ "user0": {
+ "account_id": "user",
+ "user_id": "userXYZ",
+ "owned_resources": {},
+ "allowed_resources": {},
+ "sessions": {
+ "gbrowser": {
+ "type": "browser",
+ "exec_path": "/usr/bin/chromium",
+ "local_port": "1234"
+ }
+ },
+ "credentials": [
+ "client0"
+ ]
+ }
+ }
+}
+```
+
+## Attributes
+
+Parameter | Type | Optional
+------------------|---------------|----------
+[creds] | Object | No
+[schemes] | Object | No
+[users] | Object | Yes
+required_always | Object | Yes
+required_auth | Object | Yes
+
+[creds](#creds-object)
+[schemes](#schemes-object)
+[users](#users-object)
+
+
+**creds**
+Credentials for the service defined as [credentials objects](#creds-object). Keys
+are used as identifiers for referencing the specific credential object.
+
+**schemes**
+Schemes for authenticated/authorized requests defined as [scheme objects](#schemes-object). Keys
+are used as identifiers for referencing the specific scheme object.
+
+**required_always**
+This attribute can be used to define the minimum required schemes to make requests.
+
+Contains groups of schemes, where each group consists of a list of scheme IDs. At least one scheme per
+group should be included in a request. The first scheme ID in each group is internally used as the default
+scheme.
+
+*Example:*
+
+```json
+{
+    "required_always": {
+ "group0": [
+ "header0"
+ ],
+ "group1": [
+            "query0",
+ "cookie0"
+ ]
+ }
+}
+```
+
+*To make an **unauthenticated/unauthorized** request, one scheme from `group0` and one scheme from `group1`*
+*should be included in the request. For `group0` the only option is using scheme `header0`. For `group1`, we can*
+*choose between the schemes `query0` and `cookie0`. By default, the tool chooses the first scheme listed for each*
+*group, i.e. the request will include the schemes `header0` and `query0`.*
+
+
+**required_auth**
+This attribute can be used to define the minimum required schemes to make **authenticated/authorized**
+requests.
+
+Contains groups of schemes, where each group consists of a list of scheme IDs. At least one scheme per
+group should be included in a request. The first scheme ID in each group is internally used as the default
+scheme.
+
+*Example:*
+
+```json
+{
+ "required_auth": {
+ "group0": [
+ "header0"
+ ],
+ "group1": [
+            "query0",
+ "cookie0"
+ ]
+ }
+}
+```
+
+*To make an **authenticated/authorized** request, one scheme from `group0` and one scheme from `group1`*
+*should be included in the request. For `group0` the only option is using scheme `header0`. For `group1`, we can*
+*choose between the schemes `query0` and `cookie0`. By default, the tool chooses the first scheme listed for each*
+*group, i.e. the request will include the schemes `header0` and `query0`.*
+
+
+### `creds` Object
+
+Parameter | Type | Optional
+----------------|--------|----------
+type | String | No
+description | String | Yes
+*type-specific* | Any | -
+
+**type**
+Type of credentials. This type determines which additional type-specific parameters are expected to be included in this object
+by the tool. Type-specific parameters for each type are linked below.
+
+The following types are currently supported:
+
+- `[oauth2_client](#oauth2-client-type)`: Credentials for an OAuth2 client and information about authorization/token endpoints.
+- `[token](#token-type)`: (Access) Tokens
+- `[api_key](#api-key-type)`: API Keys/Tokens
+- `[basic](#basic-type)`: Credentials for HTTP Basic Authentication (i.e. username/password)
+
+Only the `oauth2_client` type currently matters because these credentials are converted to `OAuth2TokenGenerator`s.
+
+**description**
+Human-readable description of the credentials.
+
+
+#### `oauth2_client` Type
+
+Parameter | Type | Optional
+---------------|---------------|----------
+client_id | String | No
+client_secret | String | No
+auth_endpoint | String | No
+token_endpoint | String | No
+redirect_uri | String | No
+grants | Array[String] | No
+scopes | Array[String] | Yes
+
+**client_id**
+ID of the configured client.
+
+**client_secret**
+Secret for the configured client.
+
+**auth_endpoint**
+The OAuth2 authorization endpoint of the service.
+
+**token_endpoint**
+The OAuth2 token endpoint of the service.
+
+**redirect_uri**
+Redirect URI configured for this client. Currently only one redirect URI can be specified here.
+
+**grants**
+Grants supported by the client. The tool can understand the OAuth2 grant types `code`, `token` and `refresh_token`.
+
+**scopes**
+Scopes supported by the client. If not present, the tool assumes that the client supports all scopes listed in
+the [info file](info.md). If no scopes were specified there, checks that require a list of claimed scopes
+may be skipped by the tool.
+
+
+#### `token` Type
+
+Parameter | Type | Optional
+---------------|---------------|----------
+access_token | String | No
+expires_at | Number | Yes
+scopes | Array[String] | Yes
+
+**access_token**
+Access token value.
+
+**expires_at**
+UNIX time of the expiration date of the token.
+
+**scopes**
+Scopes assigned to the access token. If not present, the tool assumes that the access token is valid for all scopes.
+
+
+#### `api_key` Type
+
+Parameter | Type | Optional
+---------------|---------------|----------
+key | String | No
+client_id | String | Yes
+
+**key**
+API key value.
+
+**client_id**
+ID of the client for which the API key was generated.
+
+
+#### `basic` Type
+
+Parameter | Type | Optional
+---------------|---------------|----------
+username | String | No
+password | String | No
+
+**username**
+Username of the user.
+
+**password**
+Password of the user.
+
+
+### `users` Object
+
+Parameter | Type | Optional
+------------------|---------------|----------
+account_id | String | No
+user_id | String | No
+userinfo_endpoint | Array | Yes
+owned_resources | Object | Yes
+allowed_resources | Object | Yes
+[sessions] | Object | Yes
+credentials | Array | Yes
+
+[sessions](#sessions-object)
+
+**account_id**
+Login name of the account at the service, e.g., `rest@attacker.com`. CURRENTLY UNUSED
+
+**user_id**
+Internal user ID at the service, e.g., `master-hacker-1234`. CURRENTLY UNUSED
+
+**userinfo_endpoint**
+Contains an endpoint definition for testing the authorized communication with the API.
+This can be used to check if the service has denied access to the client (because of rate/access limits or security protections).
+
+The order of arguments is `[, , ]`.
+
+**owned_resources**
+Resources that are tied to the users account.
+
+**allowed_resources**
+Resources that the user is allowed to access but does not own. CURRENTLY UNUSED
+
+**sessions**
+Sessions for authenticated users that can be used to establish user agents for OAuth2 authorization
+and token requests. They are defined as [session objects](#session-object). Keys are used as identifiers
+for referencing the specific session object.
+
+**credentials**
+Credentials that should be used for authorizing the user. CURRENTLY UNUSED
+
+
+### `schemes` Object
+
+Parameter | Type | Optional
+----------|--------|----------
+type | String | No
+key_id | String | No
+payload | String | No
+params | Object | No
+
+**type**
+Type of the scheme that also determines the location of the created payload in the request.
+
+The following types are supported by the tool:
+
+- `header`: Creates a HTTP header payload
+- `query`: Creates a query parameter key-value pair
+- `cookie`: Creates a HTTP cookie key-value pair
+- `basic`: Creates a HTTP Basic Authentication payload
+
+**key_id**
+The string used for the key of the created key-value pair for the types `query` and `cookie`. If the type is
+`header`, this value is used as the header name. If the type is `basic`, this value is ignored as it's assumed
+to use the *Authorization* header.
+
+**payload**
+Pattern of the payload to create the value of the key-value pair for `query` and `cookie`, or the header payload
+for type `header`, respectively. For type `basic`, the payload is inserted into the Base64-encoded part of the
+header payload.
+
+The pattern may contain parameters for dynamic auth data that is inserted at runtime, e.g. an access token value.
+Parameters are referenced as IDs enclosed by curly braces (`{}`). The credentials used for this dynamic auth data
+must be referenced in the `params` attribute.
+
+**params**
+Defines the source of the auth data for parameters used in the pattern defined by the `payload` attribute. Keys
+in the `params` object are parameter IDs. Each value is a parameter object that contains the following attributes.
+
+Parameter | Type | Optional
+----------|---------------|----------
+id | String | No
+from | Array[String] | No
+
+The **from** attribute is a list of credentials IDs that can be used as a source for the auth value. The value of
+**id** is the name of the key in the credentials object that is used to access the auth value.
+
+---
+
+Example (`header`):
+
+```json
+{
+ "type": "header",
+ "key_id": "authorization",
+ "payload": "token {0}",
+ "params": {
+ "0": {
+ "id": "access_token",
+ "from": [
+                "token0"
+ ]
+ }
+ }
+}
+```
+
+Assuming the access token for credentials `token0` is `12341234abab`, this scheme will result in the following header payload:
+
+```
+authorization: token 12341234abab
+```
+
+Example (`basic`):
+
+```json
+{
+ "type": "basic",
+ "key_id": "authorization",
+ "payload": "{0}",
+ "params": {
+ "0": {
+ "id": "access_token",
+ "from": [
+                "token0"
+ ]
+ }
+ }
+}
+```
+
+Assuming the access token for credentials `token0` is `12341234abab`, this scheme will result in the following header payload:
+
+```
+authorization: Basic MTIzNDEyMzRhYmFi
+```
+
+
+### `session` Object
+
+Parameter | Type | Optional
+----------------|--------|----------
+type | String | No
+test_url | String | Yes
+*type-specific* | Any | -
+
+**type**
+Type of the session that also determines how the session is established.
+
+**test_url**
+URL to a protected resource that only the owner of the session should be able to access.
+After establishing the session, the tool will send a request to the URL. If the response
+status is not a 2XX status code, the session is not considered valid.
+
+The following types are currently supported:
+
+- `[weblogin](#weblogin-type)`: Login via POST request to a login endpoint to create a new session
+- `[cookie](#cookie-type)`: Use an established user session from cookies
+- `[browser](#browser-type)`: (**Recommended**) Use an established user session in a browser
+
+
+#### `weblogin` Type
+
+Parameter | Type | Optional
+----------|---------|----------
+url | String | No
+params | Object | No
+
+**url**
+Login endpoint URL.
+
+**params**
+Key-value pairs of parameters sent in the HTTP body to the URL.
+
+
+#### `cookie` Type
+
+Parameter | Type | Optional
+----------|---------|----------
+params | Object | No
+
+**params**
+Key-value pairs of cookies copied from the established session.
+
+
+#### `browser` Type
+
+Parameter | Type | Optional
+-----------|---------|----------
+exec_path | String | No
+local_port | String | No
+
+**exec_path**
+Path to the browser executable. Firefox or Chrome should both work, other browsers are untested.
+
+**local_port**
+Port used for the tool's internal HTTP server. The port should be available when starting the tool.
diff --git a/doc/formats/info.md b/doc/formats/info.md
new file mode 100644
index 0000000..276cee4
--- /dev/null
+++ b/doc/formats/info.md
@@ -0,0 +1,129 @@
+# Infofile
+
+The info file stores a service configuration that is used to initialize the checks in REST Attacker.
+For this reason, the info file is mandatory when testing a service.
+
+The filename must be `info.json`.
+
+## Quick Reference
+
+```json
+{
+ "descriptions": {
+ "openapi0": {
+ "available": true,
+ "date": "2021-09-04",
+ "path": "openapi.json",
+ "alt_versions": [],
+ "format": "openapi",
+ "official": false
+ }
+ },
+ "meta": "meta.json",
+ "credentials": "credentials.json",
+ "auth_methods": [
+ "oauth2"
+ ],
+ "scopes": [
+ "user",
+ "admin"
+ ],
+ "content_types": [
+ "application/json"
+ ],
+ "custom_headers": {
+ "rate_limit_max": "RateLimit-Limit",
+ "rate_limit_remaining": "RateLimit-Remaining"
+ }
+}
+```
+
+## Attributes
+
+Parameter | Type | Optional
+---------------|---------------|----------
+[descriptions] | Object | Yes
+meta | String | Yes
+credentials | String | Yes
+scopes | Array[String] | Yes
+content_types | Array[String] | Yes
+custom_headers | Object | Yes
+
+[descriptions](#descriptions-object)
+
+**descriptions**
+API descriptions for the service ([see here for more details](#descriptions-object)). Keys
+are used as identifiers for referencing the specific API description.
+
+**meta**
+Path to the [meta file](meta.md) of the service.
+
+**credentials**
+Path to the [credentials file](auth.md) of the service. If no credentials file is referenced,
+REST Attacker will only execute unauthenticated/unauthorized checks.
+
+**scopes**
+Scopes supported by the service.
+
+**content_types**
+Content types supported by the service.
+
+**custom_headers**
+Maps handler types to response header IDs which are tracked by the tool during the analysis.
+After every check, the last response received by the tool is returned and can be analyzed by
+a `ResponseHandler`. Currently, this is only used for tracking rate limits, but it may be used
+for tracking other types of information.
+
+The built-in handler types can be referenced with these keys:
+
+Key | Handler Type | Description
+---------------------|--------------|------------
+rate_limit_max       | Rate limit   | Maximum rate limit (per interval)
+rate_limit_remaining | Rate limit   | Remaining rate limit (in interval)
+rate_limit_reset | Rate limit | Reset time of rate limit
+
+
+## `descriptions` Object
+
+Defines an API description of the service's API.
+
+Parameter | Type | Optional
+---------------|---------------|----------
+path | String | No
+available | Boolean | No
+official | Boolean | Yes
+date | String | Yes
+alt_versions | Array[String] | Yes
+format | String | Yes
+
+**path**
+Relative path to the description file.
+
+**available**
+Signifies whether the analysis tool is allowed to use this API description. If `false`, the description
+is not loaded.
+
+**official**
+Signifies whether this file is from an official source (i.e. the service's documentation) or created
+externally by other parties.
+
+This attribute is currently purely informational and has no influence on the analysis. It may be used
+in the future to compare official and unofficial descriptions of the same API.
+
+**date**
+The date the API description was created (if known). Expects ISO 8601 format.
+
+Currently this attribute is not used. REST Attacker may use the `date` attribute in the future to compare
+different versions of the same API.
+
+**alt_versions**
+IDs of alternative versions of the API description, i.e. versions that use a different file format (e.g. YAML)
+or API description format (e.g. RAML).
+
+Currently this attribute is not used. REST Attacker may use the `alt_versions` attribute in the future to compare
+differences in API descriptions for the same API version.
+
+**format**
+The description format used for the file referenced in `path`. The only supported value is `openapi`. If this
+attribute is missing, REST Attacker will use `openapi` by default.
+
diff --git a/doc/formats/meta.md b/doc/formats/meta.md
new file mode 100644
index 0000000..8ad1aa0
--- /dev/null
+++ b/doc/formats/meta.md
@@ -0,0 +1,42 @@
+# Metafile
+
+The meta file stores descriptions and meta information about a service. All information is stored in JSON format.
+It can be used to categorize a service if multiple services are tested. Meta files are currently not used during
+testing, although they may be used for debug output in the future.
+
+The meta file is referenced in the mandatory [info file](info.md). The preferred filename is `meta.json`.
+
+## Quick Reference
+
+```json
+{
+ "name": "MyService",
+ "description": "An example service",
+ "tags": [
+ "misc",
+ "example"
+ ],
+ "docs": "https://example.com/docs/restapi/"
+}
+```
+
+## Attributes
+
+Parameter | Type | Optional
+------------|---------------|---------
+name | String | No
+description | String | No
+tags | Array[String] | No
+docs | String | No
+
+**name**
+Human-readable name of the service or name of the API.
+
+**description**
+Human-readable description of the service or API.
+
+**tags**
+An array of strings that act as tags to categorize the service. Tags can be chosen arbitrarily by users.
+
+**docs**
+URL that links to the documentation for the service or API.
diff --git a/doc/formats/report.md b/doc/formats/report.md
new file mode 100644
index 0000000..0d10354
--- /dev/null
+++ b/doc/formats/report.md
@@ -0,0 +1,218 @@
+# Reportfile
+
+The report file contains the results of an executed test run as well as statistics
+and meta information about the analyzed API.
+
+Report files can also be used as run configurations to replicate a run if check
+parameters were exported alongside the check results.
+
+The default filename is `report.json`.
+
+
+## Quick Reference
+
+```json
+{
+ "type": "report",
+ "meta": {
+ "name": "MyService",
+ "description": "An example service"
+ },
+ "stats": {
+ "start": "2022-07-16T14-27-20Z",
+ "end": "2022-07-16T14-27-25Z",
+ "planned": 1,
+ "finished": 1,
+ "skipped": 0,
+ "aborted": 0,
+ "errors": 0,
+ "analytical_checks": 0,
+ "security_checks": 1
+ },
+ "args": [
+ "example.json",
+ "--generate"
+ ],
+ "reports": [
+ {
+ "check_id": 0,
+ "test_type": "security",
+ "test_case": "https.TestHTTPAvailable",
+ "status": "finished",
+ "issue": "security_flaw",
+ "value": {
+ "status_code": 200
+ },
+ "curl": "curl -X GET http://api.example.com/user",
+ "config": {
+ "request_info": {
+ "url": "http://api.example.com",
+ "path": "/user",
+ "operation": "get",
+ "kwargs": {
+ "allow_redirects": false
+ }
+ },
+ "auth_info": {
+ "scheme_ids": null,
+ "scopes": null,
+ "policy": "DEFAULT"
+ }
+ }
+ }
+ ]
+}
+```
+
+
+## Attributes
+
+Parameter | Type | Optional
+---------------|---------------|----------
+type | String | No
+meta | Object | Yes
+[stats] | Object | Yes
+args | Array[String] | Yes
+[reports] | Array[Object] | No
+
+[stats](#stats-object)
+[reports](#reports-object)
+
+**type**
+Report type. Can be either `report` for a completed run (every check executed) or
+`partial` for a run that was aborted before completion.
+
+**meta**
+Meta information of the service. Contains the content of the [meta file](meta.md) if
+it is configured.
+
+**stats**
+Statistics about the test run ([see here for more details](#stats-object)).
+
+**args**
+Command-line arguments passed at the start of the test run if the run was initiated
+via CLI.
+
+**reports**
+Reports for the individual checks ([see here for more details](#reports-object)).
+
+
+## `stats` Object
+
+Statistics about the test run.
+
+Parameter | Type | Optional
+------------------|---------------|----------
+start | String | No
+end | String | No
+planned | Number | No
+finished | Number | No
+skipped | Number | No
+aborted | Number | No
+errors | Number | No
+analytical_checks | Number | No
+security_checks | Number | No
+
+**start**
+Start date and time of the test run.
+
+**end**
+End date and time of the test run.
+
+**planned**
+Number of checks that were planned for the test run (i.e. the number of checks
+passed to the engine during initialization).
+
+**finished**
+Number of checks that were completed successfully (without being skipped or
+generating uncaught exceptions).
+
+**skipped**
+Number of checks skipped during the test run.
+
+**aborted**
+Number of aborted checks when terminating the test run early.
+
+**errors**
+Number of checks that failed because of unexpected errors.
+
+**analytical_checks**
+Number of *planned* analysis checks.
+
+**security_checks**
+Number of *planned* security checks.
+
+
+## `reports` Object
+
+Report for an individual check.
+
+Parameter | Type | Optional
+-----------|-------------|----------
+report_id | Number | No
+check_id | Number | No
+test_type | String | No
+test_case | String | No
+status | String | No
+issue | String | No
+value | Object | No
+curl | String | Yes
+config | Object | Yes
+
+**report_id**
+Reference ID for this report.
+
+**check_id**
+Reference ID of the check the report belongs to.
+
+**test_type**
+Type of test case that was executed. Can be one of these values:
+
+Value | Description
+----------|------------
+security | Checks for security issues or flaws.
+analysis | Analyzes behaviour or configuration of the API.
+
+**test_case**
+ID of the test case of the check.
+
+**status**
+Status of the check after execution. Can be one of these values.
+
+Value | Description
+----------|------------
+finished | Check completed without unexpected errors.
+skipped | Check was skipped.
+error | Check failed with an unexpected error.
+aborted | Check was aborted because the test run was terminated.
+
+**issue**
+Simple classification of the detected issue. The possible values are different depending
+on the test type.
+
+For the security test type, these values are possible:
+
+Value | Description
+------------------|------------
+security_okay | No security issue has been found.
+security_problem | Indicators for a security issue have been found, but a flaw could not be confirmed.
+security_flaw | A security issue was found and could be confirmed.
+
+For the analysis test type, these values are possible:
+
+Value | Description
+-------------------|------------
+analysis_candidate | The check detected the behaviour/configuration it was looking for.
+analysis_none | The check did not find the behaviour/configuration it was looking for.
+
+**value**
+Additional parameters for interpreting the check result. The format of this object
+depends on the test case.
+
+**curl**
+Curl command for replicating an API request sent by the check.
+
+**config**
+Serialized configuration of the check. If present, the check can be replicated
+when the report file is passed to the tool as a run configuration. The format of this object
+depends on the test case.
\ No newline at end of file
diff --git a/doc/formats/run.md b/doc/formats/run.md
new file mode 100644
index 0000000..64e8db3
--- /dev/null
+++ b/doc/formats/run.md
@@ -0,0 +1,70 @@
+# Runfile
+
+Run configuration file for creating a test run with specified checks.
+
+
+## Quick Reference
+
+```json
+{
+ "type": "run",
+ "checks": [
+ {
+ "check_id": 0,
+ "test_case": "https.TestHTTPAvailable",
+ "config": {
+ "request_info": {
+ "url": "http://api.example.com",
+ "path": "/user",
+ "operation": "get",
+ "kwargs": {
+ "allow_redirects": false
+ }
+ },
+ "auth_info": {
+ "scheme_ids": null,
+ "scopes": null,
+ "policy": "DEFAULT"
+ }
+ }
+ }
+ ]
+}
+```
+
+
+## Attributes
+
+Parameter | Type | Optional
+------------|---------------|----------
+type | String | No
+[checks] | Array[Object] | No
+
+[checks](#checks-object)
+
+
+**type**
+Run configuration type. Value must be `run`.
+
+**checks**
+Configuration parameters for the individual checks ([see here for more details](#checks-object)).
+
+
+## `checks` Object
+
+Check definitions for the test run.
+
+Parameter | Type | Optional
+-----------|-------------|----------
+check_id | Number | No
+test_case | String | No
+config | Object | Yes
+
+**check_id**
+(Unique) reference ID of the check.
+
+**test_case**
+ID of the test case of the check.
+
+**config**
+Serialized configuration of the check. The format of this object depends on the test case.
\ No newline at end of file
diff --git a/doc/guides/advanced_usage.md b/doc/guides/advanced_usage.md
new file mode 100644
index 0000000..7501efa
--- /dev/null
+++ b/doc/guides/advanced_usage.md
@@ -0,0 +1,218 @@
+# Advanced Usage
+
+1. [Using REST-Attacker as a Module](#using-rest-attacker-as-a-module)
+ 1. [Loading Configuration](#loading-configuration)
+ 1. [Check Generation](#check-generation)
+ 1. [Initializing the Test Engine](#initializing-the-test-engine)
+ 1. [Controlling the Run](#controlling-the-run)
+1. [Creating a Custom Test Case](#creating-a-custom-test-case)
+ 1. [Test Case Interface](#test-case-interface)
+ 1. [Sending Requests to the API](#sending-requests-to-the-api)
+ 1. [Getting Access Control Information](#getting-access-control-information)
+
+
+## Using REST-Attacker as a Module
+
+Choosing to use REST-Attacker as a module can give you much more control over
+the check generation, testing, and reporting processes. Furthermore, it allows you to
+write your own test cases for security or analysis checks.
+
+
+### Loading Configuration
+
+REST-Attacker provides parsers for every input configuration format in the
+`rest_attacker.util.parsers` module.
+
+1. [Auth files](/doc/formats/auth.md): `rest_attacker.util.parsers.config_auth`
+1. [Info files](/doc/formats/info.md): `rest_attacker.util.parsers.config_info`
+1. [Run files](/doc/formats/run.md): `rest_attacker.util.parsers.config_run`
+1. OpenAPI: `rest_attacker.util.parsers.openapi`
+
+These parsers automatically create configuration objects that the test engine needs.
+If you don't want to use REST-Attacker's configuration formats, you need to
+create the objects yourself. Check generation and starting a test run requires
+at least an `EngineConfig` object which references all other configuration data.
+
+
+### Check Generation
+
+REST-Attacker's internal check generation is implemented in the `rest_attacker.engine.generate_checks`
+module. The `generate_checks(..)` method receives an `EngineConfig` object, a dict
+of test cases mapped to their test case ID, and optional filters. You can pass
+your own implemented test cases here if they inherit from the `TestCase` base class
+([see here for implementing your own test cases](#test-case-interface)).
+
+`generate_checks(..)` filters the provided test cases (if filters are defined) and then
+calls their respective `generate(..)` methods to create checks. The list of generated
+checks is returned in the end.
+
+
+### Initializing the Test Engine
+
+![Engine](images/engine.svg)
+
+REST-Attacker executes a test run via the implemented `Engine` class that
+you find in the `rest_attacker.engine.engine` module. It requires an
+`EngineConfig` and a list of checks for initialization.
+
+On initialization, the test engine will set up its statistics and internal
+state tracking (`InternalState`) which you can access with the `state` member.
+`InternalState` also manages trackers for rate limiting.
+
+
+### Controlling the Run
+
+The easiest way to start the test run is to call the `run()` method. This
+starts an automated test run that iterates through all checks you provided and
+manages the complete execution and internal updates.
+
+If you want more fine-grained control, you can also access the methods used in each
+iteration directly.
+
+- `current_check(..)`: Executes the check at the current index (`index` member of `Engine`). The index is not automatically incremented by the method. You need to update the engine's `index` member manually.
+- `update_handlers(..)`: Updates rate limit detection and checks if rate limits are reached.
+- `status(..)`: Prints the current index and the total number of checks to `stdout`.
+
+After the test run is finished, you can write the results to a [report file](/doc/formats/report.md)
+using the `export(..)` method. You may also access the individual reports for each check object
+by accessing their `report(..)` method.
+
+
+## Creating a Custom Test Case
+
+Writing a custom test case allows you to execute your own security and analysis
+checks with REST-Attacker. Test cases are all classes inheriting from the generic
+`TestCase` interface in the `rest_attacker.checks.generic` module. The interface
+describes the required methods that you need to implement which we will now explain in
+detail.
+
+
+### Test Case Interface
+
+![Test Case Interface](images/test_case_interface.svg)
+
+**`__init__(..)`**
+Initializes a check object for the test case. All parameters needed for test execution
+should be passed to this method. The only mandatory member you need to define
+is `self.check_id`.
+
+It is recommended to call `super().__init__(..)` at the start of your initialization.
+This automatically creates a `TestResult` object which can be used to store the results
+of the test execution. You can access the member via `self.result`.
+
+**`run()`**
+Implements the security tests for the test case. It is called by the test engine during the execution of its
+`current_check(..)` method.
+
+For communication with the API, you should use REST-Attacker's integrated [request backend](#sending-requests-to-the-api)
+and [auth backend](#getting-access-control-information). However, you are free
+to implement any tests you want in `run()`. The method has no restrictions on what
+can or cannot be analyzed.
+
+We recommend you use the `TestResult` object referenced by `self.result` to store
+results of your tests and track the status of the test execution, although this
+is not required. With `TestResult` a result summary can be easily exported using
+its `dump()` method. The test engine also uses the `TestResult` object of a check
+to update its internal statistics.
+
+**`report(..)`**
+Creates an exportable report for the check. The method must return a `Report` object
+that contains a JSON-compatible dict (i.e. it can be printed as a JSON object).
+
+The easiest way to create a basic report is to call `dump(..)` of the `TestResult`
+object and pass the result to a `Report` object. You can choose to add other values
+in your report if you want.
+
+**`generate(..)`**
+Implements load-time check generation for the test case. This method is called when
+you pass the `--generate` flag to the CLI to automatically create a test run.
+
+`generate(..)` can use any values in the API configuration provided by an `EngineConfig`
+object for its generation. It must return a list of checks initialized from any
+`TestCase` class. However, we recommend that you only generate checks for the same
+`TestCase` class that implements the `generate(..)` method to avoid interdependencies.
+
+If you don't want to enable automated load-time generation for the test case, return an empty list.
+
+**`propose(..)`**
+Implements run-time check generation for the test case. This method is called by the
+test engine after a check has been executed if you pass the `--propose` flag to the CLI.
+
+`propose(..)` can use any values in the API configuration **and** any parameters of
+the check object (for example the `TestResult`) for its generation. It must return
+a list of checks initialized from any `TestCase` class.
+
+If you don't want to enable automated run-time generation for the test case, return an empty list.
+
+**`serialize(..)`**
+Saves the initialization parameters of a check to a JSON-compatible dict. This
+method may be used to export the configuration of a check, so that it can be
+reproduced later.
+
+**`deserialize(..)`**
+Creates a check from a serialized configuration.
+
+
+### Sending Requests to the API
+
+![Request Backend](images/request_info.svg)
+
+You can prepare and send requests with REST-Attacker's `RequestInfo` class in
+the `rest_attacker.util.request.request_info` module. `RequestInfo` allows you
+to specify an API request consisting of
+
+- API operation (HTTP method)
+- API base URL
+- Resource path
+
+`RequestInfo` wraps around Python's `requests` HTTP library. Therefore, it accepts
+all additional parameters that are also supported by `requests.Request`.
+
+To make an API request, call the `send(..)` method of your initialized `RequestInfo`
+object. The method will prepare the corresponding HTTP request and send it
+via the `requests.request(..)` method. The received `requests.Response` is
+returned.
+
+`send(..)` allows you to pass access control payloads separately from other request
+parameters. This is useful if you want to make the same API request with different
+access levels or authentication methods, e.g., to compare how the API responds.
+The method expects access control payloads generated by REST-Attacker's auth backend.
+
+
+### Getting Access Control Information
+
+![Auth Backend](images/auth_backend.svg)
+
+Access control payloads for authorized API requests can be created via REST-Attacker's
+auth backend. When starting via CLI, the backend is initialized from an [auth config](/doc/formats/auth.md)
+referenced by the info file in the passed config directory. The auth backend gives
+you several options for assembling access control payloads from automated generation
+of payloads to more fine-grained control over the used schemes, credentials, and access
+levels.
+
+![Auth Generator](images/auth_gen.svg)
+
+The simplest method to get an access control payload is using the `AuthGenerator` object
+that can be accessed via the `auth` member of the test engine's `EngineConfig`. `AuthGenerator`
+provides the method `get_auth(..)` which will automatically try to assemble a valid
+access control payload using the requirements in the auth config. `get_auth(..)`
+optionally allows you to define the desired access level of OAuth2 credentials via
+the `scopes` parameter.
+
+You can also generate access control payloads for specific authentication schemes. To do
+this, you can either pass a list of scheme IDs to `get_auth(..)` or request a payload
+for a specific scheme via the auth generator's `get_auth_scheme(..)` method. `get_auth_scheme(..)`
+also allows you to manually pass credential information that should be used for the payload
+via the `credentials_map` parameter.
+
+![Token Generator](images/token_generator.svg)
+
+Credentials from the auth config are stored in the `credentials` member of the test
+engine's `EngineConfig`. Credentials are referenced by their ID in the auth config.
+Plaintext credentials like passwords and API keys can be accessed directly,
+while dynamic credentials such as OAuth2 tokens have to be requested from the API.
+For this purpose, REST-Attacker provides the `OAuth2TokenGenerator` class which
+is initialized for every OAuth2 client defined in the auth config. `OAuth2TokenGenerator`
+supports OAuth2's authorization code, implicit, and refresh flows and can
+handle token retrieval in the background if user session information is
+defined in the auth config.
diff --git a/doc/guides/basic_usage.md b/doc/guides/basic_usage.md
new file mode 100644
index 0000000..4b57100
--- /dev/null
+++ b/doc/guides/basic_usage.md
@@ -0,0 +1,169 @@
+# Basic Usage
+
+1. [Setup](#setup)
+1. [First Start](#first-start)
+1. [Viewing Results of a Test Run](#viewing-results-of-a-test-run)
+1. [Configuring More API Parameters](#configuring-more-api-parameters)
+1. [Automated Access Control](#automated-access-control)
+1. [Automated Rate Limit Detection](#automated-rate-limit-detection)
+
+## Setup
+
+Running REST-Attacker requires at minimum:
+
+* **OpenAPI file** (version 2.0+) that describes the API you want to analyze
+
+Optionally, you may additionally configure:
+
+* **[Info file](/doc/formats/info.md)** for extended API configuration
+ * Multiple API descriptions
+ * Miscellaneous API configuration parameters (if not available in API description)
+ * Content types
+ * OAuth2 scopes
+ * Custom headers for rate limiting
+* **[Auth file](/doc/formats/auth.md)** for automated handling of authentication/authorization with the tool
+ * Authentication/Authorization schemes
+ * Required access control methods
+ * Credentials for user account(s) at the service
+ * Credentials for OAuth2 client(s) at the service
+
+## First Start
+
+Starting a test run requires the setup configuration and a run configuration containing
+checks and their parameters. REST-Attacker can also generate the run configuration for you
+if you pass it the `--generate` flag:
+
+```
+python3 -m rest_attacker --generate
+```
+
+With this option, REST-Attacker will try to automatically generate checks for all
+built-in test cases using the given OpenAPI file. REST-Attacker's test cases cover a variety of different
+analysis and security issues. The code for the implemented test cases is available in the
+`checks` submodule.
+
+You can also see the list of test cases by running:
+
+```
+python3 -m rest_attacker --list
+```
+
+If you only want to generate checks for specific test cases, you can pass a list of test case IDs to the
+`--test-cases` argument:
+
+```
+python3 -m rest_attacker --generate --test-cases scopes.TestTokenRequestScopeOmit resources.FindSecurityParameters
+```
+
+For example, this command would only generate a test run with checks for the
+`scopes.TestTokenRequestScopeOmit` and `resources.FindSecurityParameters` test cases.
+
+Check generation can be enhanced with run-time generation using the `--propose` flag.
+This option will generate checks during a test run using test results from test execution.
+`--propose` and `--generate` can also be combined:
+
+```
+python3 -m rest_attacker --generate --propose
+```
+
+Currently, `--propose` only works for a few of the built-in checks, so you may not see any
+run-time generated checks if you filter for certain test cases.
+
+
+## Viewing Results of a Test Run
+
+Results of a test run are exported to the directory `rest_attacker/out` by default.
+Alternatively, you can specify the report folder with the `--output-dir` argument:
+
+```
+python3 -m rest_attacker --generate --output-dir /tmp/example_run/
+```
+
+For more information on report files, see the [report docs](/doc/guides/report.md).
+
+
+## Configuring More API Parameters
+
+Some of the more advanced features require further configuration in addition to the OpenAPI description,
+namely the automated handling of access control and rate limit detection. You can find
+templates for all configuration formats in the [formats documentation directory](/doc/formats/).
+If this is your first time using the configuration formats, the easiest way to start is
+to create a copy of the "Quickstart" templates for each format that is mentioned here and fill
+in the respective config values for the API you want to test.
+
+Configuration files must be stored in a directory that you pass to the tool as the config.
+A file called `info.json` ([format documentation](/doc/formats/info.md)) needs to be present
+in this directory. It contains references to other config files and is the only
+required config file.
+
+`info.json` must specify at least one OpenAPI file in its `descriptions` attribute. You may
+add alternative OpenAPI descriptions for the service. However, only the first available
+OpenAPI description is used for the automated check generation with the `--generate` flag
+by default.
+
+Starting the tool with a custom configuration looks like this:
+
+```
+python3 -m rest_attacker <config> --generate
+```
+
+As config path you can either pass
+
+1. A relative or absolute directory *path* on your system
+2. The *name* of a directory inside the `rest_attacker/cfg/` subfolder
+
+
+### Automated Access Control
+
+If the API you want to test has protected endpoints or requires the usage of access control
+mechanisms, you can supply an *auth config* to REST-Attacker to automate the necessary
+access control flows. REST-Attacker is able to handle many authentication and authorization
+processes in the background, without requiring manual intervention. This includes building
+authorized API requests, retrieval of OAuth2 access tokens, and determining the
+correct access levels to use for the respective endpoints. Some built-in
+test cases also require an existing auth config for their automated check generation.
+
+The auth config for the API is placed in an auth file (usually called `auth.json` or
+`credentials.json`) ([format documentation](/doc/formats/auth.md)). The path to this file
+**must be referenced in the mandatory `info.json` file**.
+
+The auth config allows you to configure:
+
+- Credentials
+ - Static (e.g. username/password, token values)
+ - Dynamic (e.g. OAuth2 clients)
+- Authentication/Authorization schemes for the API request
+- API requirements for schemes
+- User sessions
+
+### Automated Rate Limit Detection
+
+REST-Attacker can check and detect whether the API has blocked requests of the tool during
+the test execution process due to rate limits. It can currently check for two types of rate limits:
+
+1. Standard rate limits limiting the general number of API requests
+2. Generic access limits that block access to the API (these can be caused by multiple factors, e.g. sending
+too many requests to the same endpoint)
+
+You can activate rate limit detection by passing the `--handle-limits` flag to
+the CLI call:
+
+```
+python3 -m rest_attacker --generate --handle-limits
+```
+
+By default, this can only detect if a standard rate limit has been reached by looking for
+HTTP status code `429` in API responses. However, you can enhance the naive rate limit detection
+by supplying additional configuration.
+
+Some APIs may return headers in their API responses that communicate the remaining
+rate limit. REST-Attacker can utilize these headers to avoid triggering a rate limit
+and to pause a test run in case it needs to. Rate limit headers can be configured
+in the `custom_headers` attribute of the info file.
+
+To detect generic access limits, you need to configure at least one user in the auth config
+with the `userinfo_endpoint` attribute. This endpoint must be accessible to the configured
+user and must return a `2XX` status code in response to an authorized API request. During a test
+run, REST-Attacker will send regular API requests to this endpoint to check if it can
+still access the endpoint. If the API response no longer contains a `2XX` response code,
+REST-Attacker will assume an access limit has been reached and will terminate the test run.
\ No newline at end of file
diff --git a/doc/guides/images/auth_backend.svg b/doc/guides/images/auth_backend.svg
new file mode 100644
index 0000000..ae8e175
--- /dev/null
+++ b/doc/guides/images/auth_backend.svg
@@ -0,0 +1,140 @@
+
+
+
diff --git a/doc/guides/images/auth_gen.svg b/doc/guides/images/auth_gen.svg
new file mode 100644
index 0000000..e471348
--- /dev/null
+++ b/doc/guides/images/auth_gen.svg
@@ -0,0 +1,39 @@
+
+
+
diff --git a/doc/guides/images/engine.svg b/doc/guides/images/engine.svg
new file mode 100644
index 0000000..476fffa
--- /dev/null
+++ b/doc/guides/images/engine.svg
@@ -0,0 +1,127 @@
+
+
+
diff --git a/doc/guides/images/request_info.svg b/doc/guides/images/request_info.svg
new file mode 100644
index 0000000..55c5010
--- /dev/null
+++ b/doc/guides/images/request_info.svg
@@ -0,0 +1,61 @@
+
+
+
diff --git a/doc/guides/images/test_case_interface.svg b/doc/guides/images/test_case_interface.svg
new file mode 100644
index 0000000..141527e
--- /dev/null
+++ b/doc/guides/images/test_case_interface.svg
@@ -0,0 +1,87 @@
+
+
+
diff --git a/doc/guides/images/token_generator.svg b/doc/guides/images/token_generator.svg
new file mode 100644
index 0000000..eb8806e
--- /dev/null
+++ b/doc/guides/images/token_generator.svg
@@ -0,0 +1,92 @@
+
+
+
diff --git a/doc/guides/report.md b/doc/guides/report.md
new file mode 100644
index 0000000..ab4b99b
--- /dev/null
+++ b/doc/guides/report.md
@@ -0,0 +1,56 @@
+# Reports
+
+Reports contain the results of checks in a test run. They also contain other helpful
+information about the run such as statistics and test configuration parameters to
+replicate the run.
+
+For every test run, a report and a logfile is created. Files are saved to the directory
+specified with the `--output-dir` flag, or `rest-attacker/out/` if no output directory
+was specified.
+
+
+## Test Results in the Report
+
+A report is stored as a JSON file (`report.json`). The complete format is documented
+[here](doc/formats/report.md). In this document, we will briefly cover the most
+relevant parts of the report format.
+
+You can see if your test run was completed successfully by checking the value of the `type`
+attribute. If the type is `report`, all checks were completed. If the type is `partial`,
+then the run was aborted at some point. This can happen if the tool detects that it
+reached an unrecoverable rate/access limit or if the test run was manually aborted
+via a `KeyboardInterrupt`. Aborted runs can be continued by using them as a run
+configuration:
+
+```
+python3 -m rest_attacker --continue report.json
+```
+
+`stats` shows statistics of the run, e.g. start and end times of the test run as
+well as the number of completed checks.
+
+If you started the test run via command-line, the arguments passed to the CLI are
+listed in the `args` attribute.
+
+Results for the individual checks can be found in the `reports` array. Every check
+gets its own report. A simple summary of the detected issue can be seen in the `issue`
+attribute of the check report. Its value tells you if the tool found a security issue
+or behaviour that should be analyzed. The `value` attribute contains more information
+that helps you interpret the issue result.
+
+
+## Reproducing a Run from a Report File
+
+Reports can be used to reproduce a run if the check configuration parameters
+are stored in the individual check reports (this should be active by default).
+To do so, simply use the report file as a run configuration:
+
+```
+python3 -m rest_attacker --run report.json
+```
+
+The service config should not be altered significantly between runs as the
+configuration parameters only reference IDs of API descriptions, authentication
+schemes, and credentials and not the used values themselves. Beware that any
+authorization data for the initial test run, such as OAuth2 tokens, are not reused
+in the reproduced run. Instead, they are requested again from the service.
diff --git a/doc/troubleshooting.md b/doc/troubleshooting.md
new file mode 100644
index 0000000..cbfba60
--- /dev/null
+++ b/doc/troubleshooting.md
@@ -0,0 +1,26 @@
+# Troubleshooting
+
+## "Mismatching state"/"CSRF Warning!" error during token generation
+
+This can happen if the Chrome browser is used for handling user sessions. Chrome sometimes calls the redirect URI in OAuth2 authorization flows twice. However, the current implementation can only handle one request at the time. Thus, the next authorization flow will handle the second (outdated) Chrome request first.
+
+If this happens, you should restart the run.
+
+
+## "Rate Limit reached" / "Access Limit reached. Aborting Run" message
+
+In this case, the tool has detected that it cannot access the API anymore. This usually happens if too many authorization requests are sent to the authorization server or too many unauthorized requests to the API.
+
+If this happens, wait for a while and continue the test run by providing the report of the aborted run as a run configuration. This will continue the run.
+Alternatively, you can omit the `--handle-limits` flag in the CLI commands. Then, the rate/access limit detection is deactivated.
+
+
+## Test case execution stops / "Rate Limit reached" message
+
+In this case, the tool has detected that a rate limit has been exceeded. The run is halted until the rate limit resets (this may be detected from a response header).
+The run will continue after the rate limit has been reset. Alternatively, you can abort the run with `CTRL + C` and view the report for the already executed checks. You can continue the run by supplying the report file as a run configuration.
+
+
+## Firefox/Chrome does not open even though browser session is defined in service configuration
+
+`ROBrowserSession` currently does not work inside a Docker container. Open the repository outside of the Docker container and restart the test run.
diff --git a/rest_attacker/.gitignore b/rest_attacker/.gitignore
new file mode 100644
index 0000000..a3153e8
--- /dev/null
+++ b/rest_attacker/.gitignore
@@ -0,0 +1,2 @@
+# Output files
+/out/
diff --git a/rest_attacker/__init__.py b/rest_attacker/__init__.py
new file mode 100644
index 0000000..8c1441c
--- /dev/null
+++ b/rest_attacker/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+REST-Attacker is a pentesting and analysis tool for
+testing and evaluating the security of REST APIs.
+"""
diff --git a/rest_attacker/__main__.py b/rest_attacker/__main__.py
new file mode 100644
index 0000000..1e39b1e
--- /dev/null
+++ b/rest_attacker/__main__.py
@@ -0,0 +1,358 @@
+#!/usr/bin/env python3
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Main entrypoint for the tool. Handles argument parsing.
+"""
+from __future__ import annotations
+import typing
+
+import argparse
+from datetime import datetime
+import os
+import pathlib
+import sys
+import time
+import logging
+
+from rest_attacker.checks.types import AuthType, LiveType, TestCaseType
+from rest_attacker.engine.config import EngineConfig
+from rest_attacker.engine.generate_checks import generate_checks
+from rest_attacker.engine.internal_state import EngineStatus
+from rest_attacker.util.auth.token_generator import OAuth2TokenGenerator
+from rest_attacker.util.log import setup_logging
+import rest_attacker.util.parsers.config_info as config_info
+import rest_attacker.util.parsers.config_run as config_run
+from rest_attacker.util.request.http_methods import SAFE_METHODS
+from rest_attacker.util.request.request_info import AuthRequestInfo, RequestInfo
+from rest_attacker.util.response_handler import AccessLimitHandler, RateLimitHandler
+from rest_attacker.util.version import GetVersion
+from rest_attacker.util.enum_test_cases import GetTestCases, get_test_cases
+from rest_attacker.engine.engine import Engine
+
+if typing.TYPE_CHECKING:
+ from rest_attacker.checks.generic import TestCase
+
+
def parse_args() -> argparse.Namespace:
    """
    Parse CLI arguments to initialize the tool.

    :return: Parsed CLI arguments.
    :rtype: argparse.Namespace
    """
    parser = argparse.ArgumentParser(
        "REST-Attacker",
        description=("Pentesting tool for analyzing REST APIs")
    )

    # Positional argument: where to find the service configuration.
    parser.add_argument("config", default=None,
                        help=("Path to the service configuration. "
                              "Can be a directory or OpenAPI file."))

    # Informational actions (custom argparse actions that print and exit).
    parser.add_argument("--version", nargs=0, action=GetVersion,
                        help="Print version number.")
    parser.add_argument("--list", "-l", nargs=0, action=GetTestCases,
                        help="List all available test cases.")
    parser.add_argument("--output-dir", default=None,
                        help="Export path for logs and reports.")

    # Maps 1:1 to Python loglevels after multiplying by 10 (see setup_outputs).
    parser.add_argument("--loglevel", type=int, default=3, choices={1, 2, 3, 4, 5},
                        help=("Set the loglevel for the CLI. "
                              "Choices map to Python loglevels. "
                              "Logging to file is always level 5."))

    # parser.add_argument("--demo", action="store_true",
    #                     help="Run the demo. Shortcut for --config-dir demo")

    parser.add_argument("--handle-limits", action="store_true",
                        help="Handle rate and access limits during the test run.")

    # Safety switches for potentially destructive behavior.
    parser.add_argument("--safemode", action="store_true", default=False,
                        help=("Deactivate modifying/destructive API operations."))
    parser.add_argument("--fake-inputs", action="store_true",
                        help=("Generate fake input parameter values when generating checks. "
                              "WARNING: May result in destructive behavior."))

    # A run is either started fresh or continued -- never both.
    run_cfg = parser.add_mutually_exclusive_group()
    run_cfg.add_argument("--run", default=None,
                         help="Start test run from a run configuration file.")
    run_cfg.add_argument("--continue", default=None, dest='cont',
                         help="Continue test run from a run configuration file.")

    parser.add_argument("--generate", action="store_true",
                        help="Automatically generate checks at load-time.")
    parser.add_argument("--propose", action="store_true",
                        help="Automatically generate checks at run-time.")

    # Filters that narrow down which checks are executed (see setup_run).
    filters = parser.add_argument_group()
    filters.add_argument("--test-cases", action="extend", nargs="+", type=str,
                         help="Only execute checks with the specified test case IDs.")
    filters.add_argument("--test-type", action="extend", nargs="+", type=str,
                         choices={"ANALYTICAL", "SECURITY", "COMPARISON"},
                         help="Only execute checks with the specified TestType.")
    filters.add_argument("--auth-type", action="extend", nargs="+", type=str,
                         choices={"NOPE", "OPTIONAL",
                                  "RECOMMENDED", "REQUIRED"},
                         help="Only execute checks with the specified AuthType.")
    filters.add_argument("--live-type", action="extend", nargs="+", type=str,
                         choices={"ONLINE", "OFFLINE"},
                         help="Only execute checks with the specified LiveType.")

    verbosity = parser.add_mutually_exclusive_group()
    verbosity.add_argument("--verbosity", '-v', action='count', default=1,
                           help="Set output verbosity level.")
    # verbosity.add_argument("--quiet", action="store_true",
    #                        help="Run without terminal output.")

    # Optional HTTP(S) proxy setup for outgoing requests.
    proxy = parser.add_argument_group()
    proxy.add_argument("--proxy", default=None, type=str,
                       help="Define a HTTP/HTTPS proxy server for requests.")
    proxy.add_argument("--cacert", default=None, type=str,
                       help="Path to a custom CA certificate.")

    return parser.parse_args()
+
+
def setup_outputs(args: argparse.Namespace) -> None:
    """
    Set the output directory for run reports and logging and initialize logging.

    Mutates ``args``: sets ``args.output_path`` and rescales ``args.loglevel``
    from the CLI range (1-5) to Python logging levels (10-50).

    :param args: CLI arguments from argparse.
    :type args: argparse.Namespace
    """
    if args.output_dir:
        args.output_path = pathlib.Path(args.output_dir)

    else:
        # Default layout: rest_attacker/out/<UTC timestamp>-<service name>
        # time.gmtime avoids the deprecated datetime.utcfromtimestamp().
        out_folder = time.strftime('%Y-%m-%dT%H-%M-%SZ', time.gmtime())
        # Append service config name (part of the filename before the first dot).
        out_folder += f"-{pathlib.Path(args.config).name.split('.')[0]}"

        args.output_path = pathlib.Path().resolve() / "rest_attacker" / "out" / out_folder

    # Create the directory tree; exist_ok avoids the check-then-create race.
    args.output_path.mkdir(parents=True, exist_ok=True)

    # Setup logging
    logpath = args.output_path / "run.log"
    # Multiply by 10 to match Python loglevels (e.g. 3 -> 30 == logging.WARNING).
    args.loglevel = args.loglevel * 10
    setup_logging(cli_loglevel=args.loglevel, logpath=logpath)
+
+
+def setup_config(args: argparse.Namespace) -> EngineConfig:
+ """
+ Setup the service configuration for the engine.
+
+ :param args: CLI aguments from argparse.
+ :type args: argparse.Namespace
+ """
+ # if args.demo:
+ # args.config = "demo"
+ # logging.info("Starting demo.")
+ # args.run = pathlib.Path().resolve() / "rest_attacker" / \
+ # "cfg" / "demo" / "runs" / "sample.json"
+
+ args.config_path = pathlib.Path(args.config)
+ if not args.config_path.exists():
+ # Assume it's a name of a folder in cfg and get the actual path
+ args.config_path = pathlib.Path().resolve() / "rest_attacker" / \
+ "cfg" / args.config
+
+ if args.config_path.exists():
+ cfg = config_info.load_config(args.config_path)
+
+ else:
+ raise Exception(f"No service configuration found at '{args.config}'.")
+
+ elif args.config_path.is_dir():
+ cfg = config_info.load_config(args.config_path)
+
+ elif args.config_path.is_file():
+ cfg = config_info.create_config_from_openapi(args.config_path)
+
+ else:
+ raise Exception(f"No service configuration found for '{args.config}'.")
+
+ logging.info(f"Using service configuration at: {args.config_path}")
+
+ cfg.cli_args = args
+
+ return cfg
+
+
+def setup_run(cfg: EngineConfig, args: argparse.Namespace) -> list[TestCase]:
+ """
+ Setup the run configuration for the engine.
+
+ :param args: CLI aguments from argparse.
+ :type args: argparse.Namespace
+ """
+ # Test case filters
+ test_filters = {}
+ if args.test_cases:
+ test_filters.update({
+ "test_cases": args.test_cases
+ })
+
+ if args.test_type:
+ test_filters.update({
+ "test_type": [TestCaseType[test_type] for test_type in args.test_type]
+ })
+ if args.auth_type:
+ test_filters.update({
+ "auth_type": [AuthType[auth_type] for auth_type in args.auth_type]
+ })
+ if args.live_type:
+ test_filters.update({
+ "live_type": [LiveType[live_type] for live_type in args.live_type]
+ })
+
+ checks = []
+ if args.run:
+ args.run_path = pathlib.Path(args.run)
+ logging.info(f"Using run configuration at: {args.run_path}")
+ checks = config_run.load_config(get_test_cases(), cfg, args.run_path)
+
+ elif args.cont:
+ args.run_path = pathlib.Path(args.cont)
+ logging.info(f"Continuing run from configuration at: {args.run_path}")
+ checks = config_run.load_config(get_test_cases(), cfg, args.run_path, continue_run=True)
+
+ elif args.generate:
+ logging.info("No run configuration found.")
+ logging.info("Resuming with automatically generated checks.")
+ checks = generate_checks(cfg, get_test_cases(), test_filters)
+
+ else:
+ logging.warning("No run configuration found.")
+ logging.warning("Use --run to specify run configuration or "
+ "--generate to automatically generate checks.")
+
+ return checks
+
+
def setup_request_backend(args: argparse.Namespace) -> None:
    """
    Configure the request backend.

    Applies global settings (safe mode, proxy, CA certificate) to the
    RequestInfo class.

    :param args: CLI aguments from argparse.
    :type args: argparse.Namespace
    """
    if args.safemode:
        # Restrict requests to safe (non-modifying) HTTP methods.
        RequestInfo.allowed_ops = SAFE_METHODS

    if args.proxy:
        # Route both HTTP and HTTPS traffic through the same proxy.
        proxy_settings = {
            "http": args.proxy,
            "https": args.proxy,
        }
        RequestInfo.global_kwargs["proxies"] = proxy_settings

    if args.cacert:
        # Custom CA bundle used for TLS verification.
        RequestInfo.global_kwargs["verify"] = pathlib.Path(args.cacert)
+
+
def setup_limits(cfg: EngineConfig, args: argparse.Namespace) -> list:
    """
    Setup handling of rate and access limits.

    :param cfg: Engine configuration for the service under test.
    :type cfg: EngineConfig
    :param args: CLI aguments from argparse.
    :type args: argparse.Namespace
    :return: Response handlers; empty unless --handle-limits was passed.
    :rtype: list
    """
    handlers = []
    if args.handle_limits:
        # Rate limit handler
        headers = {}
        if "custom_headers" in cfg.info.keys():
            headers.update(cfg.info["custom_headers"])

        # NOTE(review): 'headers' is currently unused because the rate limit
        # handler below is disabled -- confirm whether this is intentional.
        # handlers.append(RateLimitHandler(headers=headers))

        # Access limit handler
        # Use the userinfo endpoint of the (default) user
        if cfg.current_user_id:
            default_user = cfg.users[cfg.current_user_id]
            if default_user.userinfo_endpoint:
                # userinfo_endpoint is indexed as a 3-element sequence here;
                # presumably (url, path, operation) per RequestInfo's
                # positional parameters -- TODO confirm.
                request_info = RequestInfo(
                    default_user.userinfo_endpoint[0],
                    default_user.userinfo_endpoint[1],
                    default_user.userinfo_endpoint[2]
                )
                # Request required scopes
                # TODO: Look up endpoint and check security requirements (if defined)
                user_cred_ids = default_user.credentials
                scopes = None
                for cred_id in user_cred_ids:
                    # Find a suitable client and request all scopes
                    cred = cfg.credentials[cred_id]

                    if isinstance(cred, OAuth2TokenGenerator):
                        scopes = cred.client_info.supported_scopes
                        break

                auth_info = AuthRequestInfo(cfg.auth, scopes=scopes)

                # TODO: Make interval configurable
                handlers.append(AccessLimitHandler(
                    request_info, auth_info, interval=20))

    return handlers
+
+
def main():
    """
    CLI entrypoint of REST-Attacker.

    Parses arguments, prepares outputs/configuration/checks, runs the engine
    and always attempts to export a report afterwards.
    """
    args = parse_args()

    # Setup output files
    setup_outputs(args)

    # Setup service configuration
    cfg = setup_config(args)

    if args.fake_inputs:
        # Ask for confirmation before doing this.
        print("Do you really want to generate fake inputs?")
        print("This may retrieve, MODIFY or DELETE resources of other users.")
        confirm = input("To proceed anyway type 'Yes, I understand' here: ")

        if confirm != "Yes, I understand":
            return

    # Setup run configuration
    checks = setup_run(cfg, args)
    if len(checks) == 0:
        # Nothing to execute.
        return

    # Setup Request Backend
    setup_request_backend(args)

    # Rate/access limit handlers (only populated with --handle-limits).
    handlers = setup_limits(cfg, args)

    # Start run
    engine = Engine(cfg, checks, handlers=handlers)

    try:
        engine.run()

    except KeyboardInterrupt:
        # User abort; partial results are still exported below.
        logging.warning("Aborting run: KeyboardInterrupt")
        engine.abort()

    except Exception as e:
        engine.state.status = EngineStatus.ERROR
        logging.exception(
            "Execution failed with the following error:", exc_info=e)

    # Export results
    try:
        engine.export(args.output_path)

    except Exception as e:
        logging.exception(
            "Exporting results failed with the following error:", exc_info=e)
+
+
if __name__ == '__main__':
    # Script entrypoint: main() returns None, so the process exits with status 0.
    sys.exit(main())
diff --git a/rest_attacker/cfg/.gitignore b/rest_attacker/cfg/.gitignore
new file mode 100644
index 0000000..4d75efc
--- /dev/null
+++ b/rest_attacker/cfg/.gitignore
@@ -0,0 +1,5 @@
+*.json
+*.yaml
+
+!demo/**
+!demo/runs/**
diff --git a/rest_attacker/cfg/readme.txt b/rest_attacker/cfg/readme.txt
new file mode 100644
index 0000000..cf2f9a5
--- /dev/null
+++ b/rest_attacker/cfg/readme.txt
@@ -0,0 +1 @@
+Place config files for individual services in subfolders.
\ No newline at end of file
diff --git a/rest_attacker/checks/__init__.py b/rest_attacker/checks/__init__.py
new file mode 100644
index 0000000..859e5b3
--- /dev/null
+++ b/rest_attacker/checks/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+This module contains analysis and security test cases for investigating
+service implementations.
+"""
diff --git a/rest_attacker/checks/body.py b/rest_attacker/checks/body.py
new file mode 100644
index 0000000..4d58b48
--- /dev/null
+++ b/rest_attacker/checks/body.py
@@ -0,0 +1,551 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Test cases for analyzing response body parameters.
+"""
+
+from email import header
+import json
+import logging
+from os import stat
+import jsonschema
+
+from rest_attacker.checks.generic import TestCase
+from rest_attacker.checks.types import AuthType, LiveType, TestCaseType
+from rest_attacker.report.report import Report
+from rest_attacker.util.input_gen import replace_params
+from rest_attacker.util.request.request_info import AuthRequestInfo, RequestInfo
+from rest_attacker.util.test_result import CheckStatus, IssueType
+
+
class CompareHTTPBodyToSchema(TestCase):
    """
    Compare the JSON body of a response to a JSON schema definition from an API description.
    """
    test_type = TestCaseType.ANALYTICAL
    auth_type = AuthType.OPTIONAL
    live_type = LiveType.ONLINE

    def __init__(
        self,
        check_id: int,
        request_info: RequestInfo,
        schema: dict,
        auth_info: AuthRequestInfo = None,
    ) -> None:
        """
        Creates a new check for CompareHTTPBodyToSchema.

        :param check_id: Unique identifier of the check.
        :type check_id: int
        :param request_info: RequestInfo object that stores data to make the request.
        :type request_info: RequestInfo
        :param schema: JSON schema definition the response body is validated against.
        :type schema: dict
        :param auth_info: AuthRequestInfo object that is used for authentication if specified.
        :type auth_info: AuthRequestInfo
        """
        super().__init__(check_id)

        self.request_info = request_info
        self.auth_info = auth_info

        self.schema = schema

    def run(self):
        self.result.status = CheckStatus.RUNNING

        # Fetch an auth payload only if authentication was requested.
        auth_data = None
        if self.auth_info:
            auth_data = self.auth_info.auth_gen.get_auth(
                scheme_ids=self.auth_info.scheme_ids,
                scopes=self.auth_info.scopes,
                policy=self.auth_info.policy
            )

        response = self.request_info.send(auth_data)
        self.result.last_response = response

        try:
            json_body = response.json()

        except ValueError as err:
            # No JSON body -> the check cannot be evaluated.
            logging.warning("Response contained no valid JSON payload.")
            self.result.status = CheckStatus.ERROR
            self.result.error = err
            return

        try:
            jsonschema.validate(json_body, self.schema)

            # Payload matches schema
            self.result.issue_type = IssueType.MATCH
            self.result.value = {
                "valid": True,
            }

        except jsonschema.ValidationError as err:
            # Payload does not match to schema
            self.result.issue_type = IssueType.DIFFERENT
            self.result.value = {
                "valid": False,
                "invalid_subschema": err.schema  # Stores the faulty parts of schema
            }

        except jsonschema.SchemaError as err:
            # The schema itself (from the API description) is broken.
            logging.warning("JSON schema is invalid.")
            self.result.status = CheckStatus.ERROR
            self.result.error = err
            return

        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Curl request
        if self.auth_info:
            # TODO: Save used auth payload somewhere
            # auth_data = self.auth_info.auth_gen.get_auth()
            # NOTE(review): both branches currently emit the same command; the
            # auth payload is not included yet (see TODO above).
            report["curl"] = self.request_info.get_curl_command()

        else:
            report["curl"] = self.request_info.get_curl_command()

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start):
        # This test case does not create follow-up checks.
        return []

    @ classmethod
    def generate(cls, config, check_id_start=0):
        # Without an API description there is nothing to compare against.
        if not config.descriptions:
            return []

        cur_check_id = check_id_start
        test_cases = []
        for descr in config.descriptions.values():
            for server in descr.servers:
                for path_id, path in descr.endpoints.items():
                    for op_id, op in path.items():
                        replacments = None
                        if descr.requires_parameters(path_id, op_id):
                            if not config.users or not config.current_user_id:
                                # No replacement parameters defined
                                continue

                            default_user = config.users[config.current_user_id]
                            req_parameters = descr.get_required_param_defs(path_id, op_id)

                            replacments = replace_params(path, default_user, req_parameters)
                            if not replacments:
                                # No replacements found
                                continue

                        for status_code, response in op["responses"].items():
                            # Only JSON responses with a schema can be validated.
                            if not "content" in response.keys():
                                continue

                            if not "application/json" in response["content"].keys():
                                continue

                            if not "schema" in response["content"]["application/json"].keys():
                                continue

                            schema_def = response["content"]["application/json"]["schema"]
                            # Composite schemas are split into one check per subschema.
                            schemas = [schema_def]
                            if "allOf" in schema_def.keys():
                                schemas = schema_def["allOf"]

                            elif "oneOf" in schema_def.keys():
                                schemas = schema_def["oneOf"]

                            for schema in schemas:
                                if replacments:
                                    # replacments layout per the call below:
                                    # (path, headers, params, cookies)
                                    request_info = RequestInfo(
                                        server["url"],
                                        replacments[0],
                                        op_id,
                                        headers=replacments[1],
                                        params=replacments[2],
                                        cookies=replacments[3]
                                    )

                                else:
                                    request_info = RequestInfo(
                                        server["url"],
                                        path_id,
                                        op_id
                                    )

                                # Authenticate only for success/redirect responses.
                                # NOTE(review): OpenAPI also allows a 'default'
                                # response key; int('default') would raise
                                # ValueError here -- TODO confirm descriptions
                                # are pre-filtered.
                                auth_info = None
                                if config.auth and 200 <= int(status_code) < 400:
                                    auth_info = AuthRequestInfo(
                                        config.auth
                                    )

                                test_cases.append(CompareHTTPBodyToSchema(
                                    cur_check_id, request_info, schema, auth_info))
                                cur_check_id += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        serialized = {
            "request_info": self.request_info.serialize(),
            "schema": self.schema,
        }

        # auth_info is optional and only serialized when present.
        if self.auth_info:
            serialized.update({
                "auth_info": self.auth_info.serialize(),
            })

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        request_info = RequestInfo.deserialize(serialized.pop("request_info"))
        schema = serialized.pop("schema")
        auth_info = None
        if "auth_info" in serialized:
            auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)

        return CompareHTTPBodyToSchema(check_id, request_info, schema, auth_info)
+
+
class CompareHTTPBodyAuthNonauth(TestCase):
    """
    Compare the response bodies of a non-auth request and an auth request for the same endpoint.
    """
    test_type = TestCaseType.ANALYTICAL
    auth_type = AuthType.REQUIRED
    live_type = LiveType.ONLINE

    def __init__(
        self,
        check_id: int,
        request_info: RequestInfo,
        auth_info: AuthRequestInfo,
    ) -> None:
        """
        Creates a new check for CompareHTTPBodyAuthNonauth.

        :param check_id: Unique identifier of the check.
        :type check_id: int
        :param request_info: RequestInfo object that stores data to make the request.
        :type request_info: RequestInfo
        :param auth_info: AuthRequestInfo object that is used for authentication.
        :type auth_info: AuthRequestInfo
        """
        super().__init__(check_id)

        self.request_info = request_info
        self.auth_info = auth_info

    def run(self):
        self.result.status = CheckStatus.RUNNING

        # 1st request: No authorization
        auth_data = self.auth_info.auth_gen.get_min_auth()
        response1 = self.request_info.send(auth_data)

        # 2nd request: Authorized
        auth_data = self.auth_info.auth_gen.get_auth(
            scheme_ids=self.auth_info.scheme_ids,
            scopes=self.auth_info.scopes,
            policy=self.auth_info.policy
        )
        response2 = self.request_info.send(auth_data)
        self.result.last_response = response2

        try:
            response1_json = response1.json()
            response2_json = response2.json()

        except json.JSONDecodeError as err:
            logging.warning("JSON payload could not be decoded.")
            self.result.status = CheckStatus.ERROR
            self.result.error = err
            return

        # Diff the two payloads recursively; 'left' is the unauthorized
        # response, 'right' the authorized one.
        common_values, unique_values_left, unique_values_right = \
            _recursive_diff(
                response1_json,
                response2_json
            )

        if len(unique_values_left) == len(unique_values_right) == 0:
            # Same body with and without authorization.
            self.result.issue_type = IssueType.MATCH

        else:
            self.result.issue_type = IssueType.DIFFERENT

        self.result.value = {
            "common_values": common_values,
            "unique_values_left": unique_values_left,
            "unique_values_right": unique_values_right,
        }

        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Curl request
        if self.auth_info:
            # TODO: Save used auth payload somewhere
            # auth_data = self.auth_info.auth_gen.get_auth()
            # NOTE(review): both branches currently emit the same command; the
            # auth payload is not included yet (see TODO above).
            report["curl"] = self.request_info.get_curl_command()

        else:
            report["curl"] = self.request_info.get_curl_command()

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start):
        # This test case does not create follow-up checks.
        return []

    @ classmethod
    def generate(cls, config, check_id_start=0):
        if not config.descriptions:
            return []

        # This test case compares auth vs. non-auth, so auth is mandatory.
        if not config.auth:
            return []

        cur_check_id = check_id_start
        test_cases = []
        for descr in config.descriptions.values():
            for server in descr.servers:
                for path_id, path in descr.endpoints.items():
                    for op_id, op in path.items():
                        replacments = None
                        if descr.requires_parameters(path_id, op_id):
                            if not config.users or not config.current_user_id:
                                # No replacement parameters defined
                                continue

                            default_user = config.users[config.current_user_id]
                            req_parameters = descr.get_required_param_defs(path_id, op_id)

                            replacments = replace_params(path, default_user, req_parameters)
                            if not replacments:
                                # No replacements found
                                continue

                        if replacments:
                            # replacments layout per the call below:
                            # (path, headers, params, cookies)
                            request_info = RequestInfo(
                                server["url"],
                                replacments[0],
                                op_id,
                                headers=replacments[1],
                                params=replacments[2],
                                cookies=replacments[3]
                            )

                        else:
                            request_info = RequestInfo(
                                server["url"],
                                path_id,
                                op_id
                            )

                        auth_info = AuthRequestInfo(
                            config.auth
                        )

                        test_cases.append(
                            CompareHTTPBodyAuthNonauth(cur_check_id, request_info, auth_info)
                        )
                        cur_check_id += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        serialized = {
            "request_info": self.request_info.serialize(),
            "auth_info": self.auth_info.serialize(),
        }

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        request_info = RequestInfo.deserialize(serialized.pop("request_info"))
        auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)

        # NOTE(review): both known keys were popped above, so 'serialized'
        # should be empty here; any leftover entries would be passed as
        # unexpected keyword arguments -- confirm this is intended.
        return CompareHTTPBodyAuthNonauth(check_id, request_info, auth_info, **serialized)
+
+
+def _recursive_diff(left, right):
+ """
+ Compares two JSON payloads recursively and returns a comparison containing the common values
+ and the unique values of 'left' and 'right'.
+
+ :param left: First payload.
+ :param right: Second payload.
+ :type left: dict|list
+ :type right: dict|list
+ :return: Common values, unique values of left, unique values of right (in that order).
+ :rtype: tuple
+ """
+ if type(left) is not type(right):
+ # Different types cannot be compared
+ return {}, [left], [right]
+
+ if isinstance(left, dict):
+ # Dict comparison recurse
+ return _recursive_diff_dict(left, right)
+
+ if isinstance(left, list):
+ # List comparison recurse
+ return _recursive_diff_list(left, right)
+
+ # Primitive values
+ if left != right:
+ return {}, [left], [right]
+
+ return [left], [], []
+
+
+def _recursive_diff_dict(left, right):
+ """
+ Compares two dicts recursively and returns a comparison containing the common values
+ and the unique values of 'left' and 'right'.
+
+ :param left: First dict.
+ :param right: Second dict.
+ :type left: dict
+ :type right: dict
+ :return: Common values, unique values of left, unique values of right (in that order).
+ :rtype: tuple[dict]
+ """
+ unique_values_left = {}
+ unique_values_right = {}
+ common_values = {}
+
+ if left == right:
+ common_values.update(left)
+
+ else:
+ for key_left, value_left in left.items():
+ if key_left in right.keys():
+ value_right = right[key_left]
+ if value_left == value_right:
+ common_values.update({key_left: value_left})
+
+ else:
+ if type(value_left) is not type(value_right):
+ # Different types cannot be compared
+ unique_values_left.update({key_left: value_left})
+ unique_values_right.update({key_left: value_right})
+
+ if isinstance(value_left, dict):
+ # Dict comparison recurse
+ common, un_left, un_right = _recursive_diff_dict(
+ value_left, value_right)
+ common_values.update({key_left: common})
+ unique_values_left.update({key_left: un_left})
+ unique_values_right.update({key_left: un_right})
+
+ elif isinstance(value_left, list):
+ # List comparison recurse
+ common, un_left, un_right = _recursive_diff_list(
+ value_left, value_right)
+ common_values.update({key_left: common})
+ unique_values_left.update({key_left: un_left})
+ unique_values_right.update({key_left: un_right})
+
+ else:
+ unique_values_left.update({key_left: value_left})
+ unique_values_right.update({key_left: value_right})
+
+ else:
+ unique_values_left.update({key_left: value_left})
+
+ for key_right, value_right in right.items():
+ if key_right in left.keys():
+ # Should already be in dict because of 'left' for-loop
+ pass
+
+ else:
+ unique_values_right.update({key_right: value_right})
+
+ return common_values, unique_values_left, unique_values_right
+
+
+def _recursive_diff_list(left, right):
+ """
+ Compares two lists recursively and returns a comparison containing the common values
+ and the unique values of 'left' and 'right'.
+
+ :param left: First list.
+ :param right: Second list.
+ :type left: list
+ :type right: list
+ :return: Common values, unique values of left, unique values of right (in that order).
+ :rtype: tuple[list]
+ """
+ unique_values_left = []
+ unique_values_right = []
+ common_values = []
+
+ if left == right:
+ common_values.extend(left)
+
+ else:
+ min_length = 0
+ if len(left) < len(right):
+ min_length = len(left)
+
+ else:
+ min_length = len(right)
+
+ for array_idx in range(min_length):
+ value_left = left[array_idx]
+ value_right = right[array_idx]
+ if value_left != value_right:
+ if type(value_left) is not type(value_right):
+ # Different types cannot be compared
+ unique_values_left.append(value_left)
+ unique_values_right.append(value_right)
+
+ elif isinstance(value_left, dict):
+ # Dict comparison recurse
+ common, un_left, un_right = _recursive_diff_dict(
+ value_left, value_right)
+ common_values.append(common)
+ unique_values_left.append(un_left)
+ unique_values_right.append(un_right)
+
+ elif isinstance(value_left, list):
+ # List comparison recurse
+ common, un_left, un_right = _recursive_diff_list(
+ value_left, value_right)
+ common_values.append(common)
+ unique_values_left.append(un_left)
+ unique_values_right.append(un_right)
+
+ else:
+ unique_values_left.append(value_left)
+ unique_values_right.append(value_right)
+
+ else:
+ common_values.append(value_left)
+
+ if len(left) < len(right):
+ unique_values_right.extend(right[min_length:])
+
+ else:
+ unique_values_left.extend(left[min_length:])
+
+ return common_values, unique_values_left, unique_values_right
diff --git a/rest_attacker/checks/generic.py b/rest_attacker/checks/generic.py
new file mode 100644
index 0000000..b694b9c
--- /dev/null
+++ b/rest_attacker/checks/generic.py
@@ -0,0 +1,113 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Implementation of the generic test case super class.
+"""
+from __future__ import annotations
+import typing
+
+from abc import ABC, abstractmethod
+
+from rest_attacker.util.test_result import CheckStatus, TestResult
+
+if typing.TYPE_CHECKING:
+ from rest_attacker.engine.config import EngineConfig
+ from rest_attacker.report.report import Report
+ from rest_attacker.checks.types import AuthType, LiveType, TestCaseType
+
+
class TestCase(ABC):
    """
    Interface for test cases.

    Concrete test cases implement check execution (run), reporting, optional
    follow-up check proposal and (de)serialization for run configurations.
    """
    # Classification flags; unset (None) on the abstract base.
    test_type: TestCaseType | None = None
    auth_type: AuthType | None = None
    live_type: LiveType | None = None
    # Test cases this one can generate follow-up checks for, if any.
    generates_for: tuple[typing.Type[TestCase], ...] | None = None

    def __init__(self, check_id: int) -> None:
        """
        Create a new check from the test case.

        :param check_id: Unique identifier of the check generated from the test case.
        :type check_id: int
        """
        self.check_id = check_id

        # Stores the result of a run.
        self.result = TestResult(self)

    @abstractmethod
    def run(self) -> None:
        """
        Execute the check instance for this test case.
        """

    @abstractmethod
    def report(self, verbosity: int = 2) -> Report:
        """
        Generate a report for the check instance.

        :param verbosity: Verbosity of the exported results.
        :type verbosity: int
        """

    @abstractmethod
    def propose(self, config: EngineConfig, check_id_start: int) -> list[TestCase]:
        """
        Propose checks for the test case based on the results of the check.

        :param config: Engine configuration for a service
        :type config: EngineConfig
        :param check_id_start: Starting index for assigning the check IDs.
        :type check_id_start: int
        :raises Exception: If the check has not finished yet.
        """
        if self.result.status is not CheckStatus.FINISHED:
            raise Exception(f"Cannot propose checks for {self}. Check is not finished.")

    @classmethod
    @abstractmethod
    def generate(cls, config: EngineConfig, check_id_start: int = 0) -> list[TestCase]:
        """
        Generate checks for the test case from information at load-time.

        :param config: Engine configuration for a service
        :type config: EngineConfig
        :param check_id_start: Starting index for assigning the check IDs.
        :type check_id_start: int
        """

    @abstractmethod
    def serialize(self) -> dict | None:
        """
        Serialize a check to a JSON-compatible dict.

        :return: A JSON-compatible dict if the check/test case is serializable, else None.
        :rtype: dict | None
        """

    @classmethod
    @abstractmethod
    def deserialize(cls, serialized: dict, config: EngineConfig, check_id: int = 0) -> TestCase | None:
        """
        Deserialize a check from a JSON-compatible dict to a TestCase object.

        :param serialized: Serialized representation of the check.
        :type serialized: dict
        :param config: Engine configuration for a service
        :type config: EngineConfig
        :param check_id: Identifier assigned to the deserialized check.
        :type check_id: int
        :return: A check of the test case if the check/test case is deserializable, else None.
        :rtype: TestCase | None
        """

    @classmethod
    def get_test_case_id(cls) -> str:
        """
        Get the identifier of the test case.

        :return: Identifier in the format '<module name>.<class name>'.
        :rtype: str
        """
        return f"{cls.__module__.rsplit('.',maxsplit=1)[-1]}.{cls.__name__}"

    def __repr__(self):
        # e.g. <CompareHTTPBodyToSchema<3>>
        return f"<{type(self).__name__}<{self.check_id}>>"
diff --git a/rest_attacker/checks/headers.py b/rest_attacker/checks/headers.py
new file mode 100644
index 0000000..14aac9c
--- /dev/null
+++ b/rest_attacker/checks/headers.py
@@ -0,0 +1,532 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Test cases for analyzing HTTP headers.
+"""
+
+import logging
+
+from rest_attacker.checks.generic import TestCase
+from rest_attacker.checks.types import AuthType, LiveType, TestCaseType
+from rest_attacker.report.report import Report
+from rest_attacker.util.auth.token_generator import AccessLevelPolicy
+from rest_attacker.util.request.request_info import AuthRequestInfo, RequestInfo
+from rest_attacker.util.test_result import CheckStatus, IssueType
+
+
+class FindCustomHeaders(TestCase):
+ """
+ Searches a response for custom (= non-standardized) HTTP headers.
+ """
+ test_type = TestCaseType.ANALYTICAL
+ auth_type = AuthType.RECOMMENDED
+ live_type = LiveType.ONLINE
+
+ def __init__(
+ self,
+ check_id: int,
+ request_info: RequestInfo,
+ auth_info: AuthRequestInfo = None
+ ) -> None:
+ """
+ Creates a new check for FindCustomHeaders.
+
+ :param request_info: RequestInfo object that stores data to make the request.
+ :type request_info: RequestInfo
+ :param auth_info: AuthRequestInfo object that is used for authentication if specified.
+ :type auth_info: AuthRequestInfo
+ """
+ super().__init__(check_id)
+
+ self.request_info = request_info
+ self.auth_info = auth_info
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ auth_data = None
+ if self.auth_info:
+ auth_data = self.auth_info.auth_gen.get_auth(
+ scheme_ids=self.auth_info.scheme_ids,
+ scopes=self.auth_info.scopes,
+ policy=self.auth_info.policy
+ )
+
+ response = self.request_info.send(auth_data)
+ self.result.last_response = response
+
+ response_headers = response.headers
+
+ unique_headers = {}
+ for header_id, header in response_headers.items():
+ if header_id.lower() in STANDARD_HEADERS or header_id.lower() in COMMON_HEADERS:
+ continue
+
+ unique_headers.update({
+ header_id: header
+ })
+
+ if len(unique_headers) > 0:
+ self.result.issue_type = IssueType.CANDIDATE
+
+ else:
+ self.result.issue_type = IssueType.NO_CANDIDATE
+
+ self.result.value = unique_headers
+
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Curl request
+ if self.auth_info:
+ # TODO: Save used auth payload somewhere
+ # auth_data = self.auth_info.auth_gen.get_auth()
+ report["curl"] = self.request_info.get_curl_command()
+
+ else:
+ report["curl"] = self.request_info.get_curl_command()
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ return []
+
+ @classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for descr in config.descriptions.values():
+ for server in descr["servers"]:
+ nonauth_request = RequestInfo(
+ server["url"],
+ "/", # TODO: better default path
+ "get"
+ )
+ test_cases.append(FindCustomHeaders(cur_check_id, nonauth_request))
+ cur_check_id += 1
+
+ if config.auth:
+ auth_request = RequestInfo(
+ server["url"],
+ "/", # TODO: better default path
+ "get"
+ )
+ auth_info = AuthRequestInfo(
+ config.auth,
+ policy=AccessLevelPolicy.MAX
+ )
+ test_cases.append(FindCustomHeaders(cur_check_id, auth_request, auth_info))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def serialize(self) -> dict:
+ serialized = {
+ "request_info": self.request_info.serialize(),
+ }
+
+ if self.auth_info:
+ serialized.update({
+ "auth_info": self.auth_info.serialize(),
+ })
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ request_info = RequestInfo.deserialize(serialized.pop("request_info"))
+ auth_info = None
+ if "auth_info" in serialized:
+ auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)
+
+ return FindCustomHeaders(check_id, request_info, auth_info)
+
+
+class FindSecurityHeaders(TestCase):
+ """
+ Searches a response for security-related HTTP headers.
+ """
+ test_type = TestCaseType.ANALYTICAL
+ auth_type = AuthType.RECOMMENDED
+ live_type = LiveType.ONLINE
+
+ def __init__(
+ self,
+ check_id: int,
+ request_info: RequestInfo,
+ auth_info: AuthRequestInfo = None
+ ) -> None:
+ """
+ Creates a new check for FindSecurityHeaders.
+
+ :param request_info: RequestInfo object that stores data to make the request.
+ :type request_info: RequestInfo
+ :param auth_info: AuthRequestInfo object that is used for authentication if specified.
+ :type auth_info: AuthRequestInfo
+ """
+ super().__init__(check_id)
+
+ self.request_info = request_info
+ self.auth_info = auth_info
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ auth_data = None
+ if self.auth_info:
+ auth_data = self.auth_info.auth_gen.get_auth(
+ scheme_ids=self.auth_info.scheme_ids,
+ scopes=self.auth_info.scopes,
+ policy=self.auth_info.policy
+ )
+
+ response = self.request_info.send(auth_data)
+ self.result.last_response = response
+
+ response_headers = response.headers
+
+ security_headers = {}
+ for header_id, header in response_headers.items():
+ if header_id.lower() in SECURITY_HEADERS:
+ security_headers.update({
+ header_id: header
+ })
+
+ if len(security_headers) > 0:
+ self.result.issue_type = IssueType.CANDIDATE
+
+ else:
+ self.result.issue_type = IssueType.NO_CANDIDATE
+
+ self.result.value = security_headers
+
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Curl request
+ if self.auth_info:
+ # TODO: Save used auth payload somewhere
+ # auth_data = self.auth_info.auth_gen.get_auth()
+ report["curl"] = self.request_info.get_curl_command()
+
+ else:
+ report["curl"] = self.request_info.get_curl_command()
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ return []
+
+ @classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for descr in config.descriptions.values():
+ for server in descr["servers"]:
+ nonauth_request = RequestInfo(
+ server["url"],
+ "/", # TODO: better default path
+ "get"
+ )
+ test_cases.append(FindSecurityHeaders(cur_check_id, nonauth_request))
+ cur_check_id += 1
+
+ if config.auth:
+ auth_request = RequestInfo(
+ server["url"],
+ "/", # TODO: better default path
+ "get"
+ )
+ auth_info = AuthRequestInfo(
+ config.auth
+ )
+ test_cases.append(FindSecurityHeaders(cur_check_id, auth_request, auth_info))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def serialize(self) -> dict:
+ serialized = {
+ "request_info": self.request_info.serialize(),
+ }
+
+ if self.auth_info:
+ serialized.update({
+ "auth_info": self.auth_info.serialize(),
+ })
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ request_info = RequestInfo.deserialize(serialized.pop("request_info"))
+ auth_info = None
+ if "auth_info" in serialized:
+ auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)
+
+ return FindSecurityHeaders(check_id, request_info, auth_info, **serialized)
+
+
+class MetaCompareHeaders(TestCase):
+ """
+ Compare the custom HTTP headers found in two checks of either FindCustomHeaders
+ or FindSecurityHeaders.
+ """
+ test_type = TestCaseType.META
+ auth_type = AuthType.NOPE
+ live_type = LiveType.OFFLINE
+ generates_for = (FindCustomHeaders, FindSecurityHeaders)
+
+ def __init__(self, check_id, check_left: TestCase, check_right: TestCase) -> None:
+ """
+ Creates a new check for MetaCompareHeaders.
+
+ :param check_left: Left comparison check.
+ :type check_left: TestCase
+ :param check_right: Right comparison check.
+ :type check_right: TestCase
+ """
+ super().__init__(check_id)
+
+ self.check_left = check_left
+ self.check_right = check_right
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ if not (self.check_left.result.status is CheckStatus.FINISHED and
+ self.check_right.result.status is CheckStatus.FINISHED):
+ raise Exception(f"Cannot run meta check {self}. Dependent checks are not finished.")
+
+ unique_headers_left = set()
+ unique_headers_right = set()
+ common_headers = set()
+
+ for header in self.check_left.result.value:
+ if header in self.check_right.result.value:
+ common_headers.add(header)
+
+ else:
+ unique_headers_left.add(header)
+
+ for header in self.check_right.result.value:
+ if header in self.check_left.result.value:
+ common_headers.add(header)
+
+ else:
+ unique_headers_right.add(header)
+
+ if len(unique_headers_left) == len(unique_headers_right) == 0:
+ self.result.issue_type = IssueType.MATCH
+
+ else:
+ self.result.issue_type = IssueType.DIFFERENT
+
+ self.result.status = CheckStatus.FINISHED
+
+ self.result.value = {
+ "left": list(unique_headers_left),
+ "right": list(unique_headers_right),
+ "common": list(common_headers),
+ }
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ return []
+
+ @classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ new_checks = FindCustomHeaders.generate(
+ config,
+ check_id_start
+ )
+ test_cases.extend(new_checks)
+ cur_check_id += len(new_checks)
+
+ if len(new_checks) >= 2:
+ # Only compare if there are enough checks
+ test_cases.append(MetaCompareHeaders(cur_check_id, new_checks[0], new_checks[1]))
+ cur_check_id += 1
+
+ new_checks = FindSecurityHeaders.generate(
+ config,
+ check_id_start
+ )
+ test_cases.extend(new_checks)
+ cur_check_id += len(new_checks)
+
+ if len(new_checks) >= 2:
+ # same here as above
+ test_cases.append(MetaCompareHeaders(cur_check_id, new_checks[0], new_checks[1]))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def serialize(self) -> dict:
+ serialized = {
+ "check_left_id": self.check_left.check_id,
+ "check_right_id": self.check_right.check_id,
+ }
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ # TODO: Reference checks from deserialized config
+
+ # return MetaCompareHeaders(check_id, **serialized)
+ return None
+
+
+# Standard headers defined in HTTP
+# from https://en.wikipedia.org/wiki/List_of_HTTP_header_fields#Response_fields
+# and https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers
+STANDARD_HEADERS = {
+ "accept-ch",
+ "access-control-allow-origin",
+ "access-control-allow-credentials",
+ "access-control-expose-headers",
+ "access-control-max-age",
+ "access-control-allow-methods",
+ "access-control-allow-headers",
+ "accept-patch",
+ "accept-ranges",
+ "age",
+ "allow",
+ "alt-svc",
+ "cache-control",
+ "connection",
+ "content-disposition",
+ "content-encoding",
+ "content-language",
+ "content-length",
+ "content-location",
+ "content-md5",
+ "content-range",
+ "content-type",
+ "date",
+ "delta-base",
+ "etag",
+ "expires",
+ "im",
+ "last-modified",
+ "link",
+ "location",
+ "p3p",
+ "pragma",
+ "preference-applied",
+ "proxy-authenticate",
+ "public-key-pins",
+ "referrer-policy", # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy
+ "retry-after",
+ "server",
+ "set-cookie",
+ "strict-transport-security",
+ "trailer",
+ "transfer-encoding",
+ "tk",
+ "upgrade",
+ "vary",
+ "via",
+ "warning",
+ "www-authenticate",
+ "x-frame-options",
+}
+
+# Common non-standard headers in HTTP
+# from https://en.wikipedia.org/wiki/List_of_HTTP_header_fields#Common_non-standard_response_fields
+COMMON_HEADERS = {
+ "cache-control",
+ "cross-origin-embedder-policy",
+ "cross-origin-opener-policy",
+ "cross-origin-resource-policy",
+ "content-security-policy",
+ "content-security-policy-report-only",
+ "expect-ct",
+ "nel",
+ "permissions-policy",
+ "refresh",
+ "report-to",
+ "status",
+ "timing-allow-origin",
+ "x-content-duration",
+ "x-content-security-policy",
+ "x-content-type-options",
+ "x-correlation-id",
+ "x-powered-by",
+ "x-redirect-by",
+ "x-request-id",
+ "x-ua-compatible",
+ "x-webkit-csp",
+ "x-xss-protection",
+}
+
+# Deprecated interesting headers in HTTP
+# from https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers
+# DEPRECATED_HEADERS = {
+# "expect-ct",
+# "set-cookie2",
+# }
+
+# Security headers in HTTP
+# either used for setting security policy or containing security info
+SECURITY_HEADERS = {
+ "access-control-allow-origin",
+ "access-control-allow-credentials",
+ "access-control-expose-headers",
+ "access-control-max-age",
+ "access-control-allow-methods",
+ "access-control-allow-headers",
+ "cross-origin-embedder-policy",
+ "cross-origin-opener-policy",
+ "cross-origin-resource-policy",
+ "content-security-policy",
+ "content-security-policy-report-only",
+ "referrer-policy",
+ "set-cookie",
+ "strict-transport-security",
+ "warning",
+ "www-authenticate",
+ "x-content-security-policy",
+ "x-content-type-options",
+ "x-frame-options",
+ "x-webkit-csp",
+ "x-xss-protection",
+}
diff --git a/rest_attacker/checks/https.py b/rest_attacker/checks/https.py
new file mode 100644
index 0000000..855e17a
--- /dev/null
+++ b/rest_attacker/checks/https.py
@@ -0,0 +1,538 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Test cases for analyzing HTTPS support.
+"""
+
+import logging
+
+from urllib.parse import urlparse, urlunparse
+
+from rest_attacker.checks.generic import TestCase
+from rest_attacker.util.auth.token_generator import AccessLevelPolicy
+from rest_attacker.util.openapi.wrapper import OpenAPI
+from rest_attacker.util.request.request_info import AuthRequestInfo, RequestInfo
+from rest_attacker.util.test_result import CheckStatus, IssueType
+from rest_attacker.checks.types import AuthType, LiveType, TestCaseType
+from rest_attacker.report.report import Report
+
+
+class TestHTTPSAvailable(TestCase):
+ """
+ Checks whether an endpoint can be accessed via HTTPS.
+ """
+ test_type = TestCaseType.SECURITY
+ auth_type = AuthType.OPTIONAL
+ live_type = LiveType.ONLINE
+
+ def __init__(
+ self,
+ check_id: int,
+ request_info: RequestInfo,
+ auth_info: AuthRequestInfo = None
+ ) -> None:
+ """
+ Creates a new check for TestHTTPSAvailable.
+
+ :param request_info: RequestInfo object that stores data to make the request.
+ :type request_info: RequestInfo
+ :param auth_info: AuthRequestInfo object that is used for authentication if specified.
+ :type auth_info: AuthRequestInfo
+ """
+ super().__init__(check_id)
+
+ self.request_info = request_info
+ self.auth_info = auth_info
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ if urlparse(self.request_info.url)[0] != "https":
+ logging.info(
+ f"Scheme of provided URL {self.request_info.url} is not HTTPS.")
+
+ # Construct HTTPS URL if none was given
+ self.request_info.url = ("https", *self.request_info._url[1:])
+ logging.info(
+ f"Using constructed URL {self.request_info.url} with HTTPS scheme.")
+
+ auth_data = None
+ if self.auth_info:
+ auth_data = self.auth_info.auth_gen.get_auth(
+ scheme_ids=self.auth_info.scheme_ids,
+ scopes=self.auth_info.scopes,
+ policy=self.auth_info.policy
+ )
+
+ response = self.request_info.send(auth_data)
+ self.result.last_response = response
+
+ if 200 <= response.status_code < 300:
+ self.result.issue_type = IssueType.OKAY
+
+ else:
+ self.result.issue_type = IssueType.PROBLEM
+
+ if 300 <= response.status_code < 400:
+ # Check if the redirect URL is HTTPS
+ redirect_url = urlparse(response.headers["location"])
+ if redirect_url.scheme == 'https':
+ self.result.issue_type = IssueType.OKAY
+
+ self.result.value = {
+ "status_code": response.status_code,
+ "redirect": 300 <= response.status_code < 400,
+ }
+
+ if 300 <= response.status_code < 400:
+ self.result.value.update({
+ "redirect_url": response.headers["location"]
+ })
+
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Curl request
+ if self.auth_info:
+ # TODO: Save used auth payload somewhere
+ # auth_data = self.auth_info.auth_gen.get_auth()
+ report["curl"] = self.request_info.get_curl_command()
+
+ else:
+ report["curl"] = self.request_info.get_curl_command()
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ super().propose(config, check_id_start)
+
+ new_checks = []
+ if self.result.value["redirect"]:
+ # Check if the redirect works
+ new_request = RequestInfo(
+ self.result.value["redirect_url"],
+ "", # path should be part of redirect URL
+ self.request_info.operation,
+ allow_redirects=False
+ )
+ new_checks.append(TestHTTPSAvailable(check_id_start, new_request))
+
+ logging.debug(f"Proposed {len(new_checks)} new checks from check {self}")
+
+ return new_checks
+
+ @classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for descr in config.descriptions.values():
+ for server in descr["servers"]:
+ for path_id, path in descr.endpoints.items():
+ for op_id, op in path.items():
+ nonauth_request = RequestInfo(
+ server["url"],
+ path_id,
+ op_id,
+ allow_redirects=False
+ )
+ test_cases.append(TestHTTPSAvailable(cur_check_id, nonauth_request))
+ cur_check_id += 1
+
+ if config.auth:
+ auth_request = RequestInfo(
+ server["url"],
+ path_id,
+ op_id,
+ allow_redirects=False
+ )
+ auth_info = AuthRequestInfo(
+ config.auth,
+ policy=AccessLevelPolicy.MAX
+ )
+ test_cases.append(
+ TestHTTPSAvailable(cur_check_id, auth_request, auth_info)
+ )
+ cur_check_id += 1
+
+ logging.debug(
+ f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def serialize(self) -> dict:
+ serialized = {
+ "request_info": self.request_info.serialize(),
+ }
+
+ if self.auth_info:
+ serialized.update({
+ "auth_info": self.auth_info.serialize(),
+ })
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ request_info = RequestInfo.deserialize(serialized.pop("request_info"))
+ auth_info = None
+ if "auth_info" in serialized:
+ auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)
+
+ return TestHTTPSAvailable(check_id, request_info, auth_info, **serialized)
+
+
+class TestHTTPAvailable(TestCase):
+ """
+ Checks whether an endpoint can be accessed via plain HTTP (without TLS).
+ """
+ test_type = TestCaseType.SECURITY
+ auth_type = AuthType.OPTIONAL
+ live_type = LiveType.ONLINE
+
+ def __init__(
+ self,
+ check_id: int,
+ request_info: RequestInfo,
+ auth_info: AuthRequestInfo = None
+ ) -> None:
+ """
+ Creates a new check for TestHTTPAvailable.
+
+ :param request_info: RequestInfo object that stores data to make the request.
+ :type request_info: RequestInfo
+ :param auth_info: AuthRequestInfo object that is used for authentication if specified.
+ :type auth_info: AuthRequestInfo
+ """
+ super().__init__(check_id)
+
+ self.request_info = request_info
+ self.auth_info = auth_info
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ if urlparse(self.request_info.url)[0] != "http":
+ logging.info(
+ f"Scheme of provided URL {self.request_info.url} is not HTTP.")
+
+ # Construct HTTPS URL if none was given
+ self.request_info.url = ("http", *self.request_info._url[1:])
+ logging.info(
+ f"Using constructed URL {self.request_info.url} with HTTP scheme.")
+
+ auth_data = None
+ if self.auth_info:
+ auth_data = self.auth_info.auth_gen.get_auth(
+ scheme_ids=self.auth_info.scheme_ids,
+ scopes=self.auth_info.scopes,
+ policy=self.auth_info.policy
+ )
+
+ response = self.request_info.send(auth_data)
+ self.result.last_response = response
+
+ if 200 <= response.status_code < 300:
+ self.result.issue_type = IssueType.FLAW
+
+ else:
+ self.result.issue_type = IssueType.OKAY
+
+ if 300 <= response.status_code < 400:
+ # Check if the redirect URL is HTTPS
+ redirect_url = urlparse(response.headers["location"])
+ if redirect_url.scheme != 'https':
+ self.result.issue_type = IssueType.PROBLEM
+
+ self.result.status = CheckStatus.FINISHED
+
+ self.result.value = {
+ "status_code": response.status_code,
+ "redirect": 300 <= response.status_code < 400,
+ }
+
+ if 300 <= response.status_code < 400:
+ self.result.value.update({
+ "redirect_url": response.headers["location"]
+ })
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Curl request
+ if self.auth_info:
+ # TODO: Save used auth payload somewhere
+ # auth_data = self.auth_info.auth_gen.get_auth()
+ report["curl"] = self.request_info.get_curl_command()
+
+ else:
+ report["curl"] = self.request_info.get_curl_command()
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ super().propose(config, check_id_start)
+
+ new_checks = []
+ if self.result.value["redirect"]:
+ # Check if the redirect is HTTPS
+ new_request = RequestInfo(
+ self.result.value["redirect_url"],
+ "", # path should be part of redirect URL
+ self.request_info.operation,
+ allow_redirects=False
+ )
+ new_checks.append(TestHTTPSAvailable(check_id_start, new_request))
+
+ logging.debug(f"Proposed {len(new_checks)} new checks from check {self}")
+
+ return new_checks
+
+ @classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for descr in config.descriptions.values():
+ for server in descr["servers"]:
+ for path_id, path in descr.endpoints.items():
+ for op_id, op in path.items():
+ nonauth_request = RequestInfo(
+ server["url"],
+ path_id,
+ op_id,
+ allow_redirects=False
+ )
+ test_cases.append(TestHTTPAvailable(cur_check_id, nonauth_request))
+ cur_check_id += 1
+
+ if config.auth:
+ auth_request = RequestInfo(
+ server["url"],
+ path_id,
+ op_id,
+ allow_redirects=False
+ )
+ auth_info = AuthRequestInfo(config.auth)
+ test_cases.append(
+ TestHTTPAvailable(cur_check_id, auth_request, auth_info)
+ )
+ cur_check_id += 1
+
+ logging.debug(
+ f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def serialize(self) -> dict:
+ serialized = {
+ "request_info": self.request_info.serialize(),
+ }
+
+ if self.auth_info:
+ serialized.update({
+ "auth_info": self.auth_info.serialize(),
+ })
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ request_info = RequestInfo.deserialize(serialized.pop("request_info"))
+ auth_info = None
+ if "auth_info" in serialized:
+ auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)
+
+ return TestHTTPAvailable(check_id, request_info, auth_info, **serialized)
+
+
+class TestDescriptionURLs(TestCase):
+ """
+ Checks which protocol schemes are defined for servers in the API description.
+ """
+ test_type = TestCaseType.SECURITY
+ auth_type = AuthType.NOPE
+ live_type = LiveType.OFFLINE
+
+ def __init__(self, check_id: int, description: OpenAPI) -> None:
+ """
+ Creates a new check for TestDescriptionURLs.
+
+ :param description: API description.
+ :type description: OpenAPI
+ """
+ super().__init__(check_id)
+
+ self.description = description
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ try:
+ global_server_urls = self.description["servers"]
+
+ except KeyError as error:
+ logging.warning("Could not find 'servers' entry in API description.")
+ self.result.error = error
+ self.result.finished = False
+ return
+
+ self.result.value = {
+ "http_urls": [],
+ "https_urls": [],
+ "unknown_scheme_urls": [],
+ "paths_with_servers": []
+ }
+
+ paths_with_servers = set()
+
+ # Global server URLs
+ for server_url in global_server_urls:
+ url = urlparse(server_url["url"])
+ if url.scheme == "http":
+ self.result.value["http_urls"].append(server_url)
+
+ elif url.scheme == "https":
+ self.result.value["https_urls"].append(server_url)
+
+ else:
+ self.result.value["unknown_scheme_urls"].append(server_url)
+
+ # Endpoint server URLs
+ for path_id, path in self.description.endpoints.items():
+ if not "servers" in path.keys():
+ continue
+
+ for server_url in path["servers"]:
+ url = urlparse(server_url["url"])
+ if url.scheme == "http":
+ self.result.value["http_urls"].append(server_url)
+ paths_with_servers.update(path_id)
+
+ elif url.scheme == "https":
+ self.result.value["https_urls"].append(server_url)
+ paths_with_servers.update(path_id)
+
+ else:
+ self.result.value["unknown_scheme_urls"].append(server_url)
+ paths_with_servers.update(path_id)
+
+ self.result.value["paths_with_servers"] = sorted(list(paths_with_servers))
+
+ if len(self.result.value["http_urls"]) > 0:
+ self.result.issue_type = IssueType.FLAW
+
+ if len(self.result.value["unknown_scheme_urls"]) > 0:
+ self.result.issue_type = IssueType.PROBLEM
+
+ else:
+ self.result.issue_type = IssueType.OKAY
+
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ super().propose(config, check_id_start)
+
+ new_checks = []
+ # Check the found URls and see if they work
+ for http_url in self.result.value["http_urls"]:
+ new_request = RequestInfo(
+ http_url['url'],
+ "/",
+ "get"
+ )
+ if config.auth:
+ auth_info = AuthRequestInfo(
+ config.auth,
+ policy=AccessLevelPolicy.MAX
+ )
+
+ else:
+ auth_info = None
+
+ new_checks.append(TestHTTPAvailable(
+ check_id_start,
+ new_request,
+ auth_info
+ ))
+
+ check_id_start += 1
+
+ for https_url in self.result.value["https_urls"]:
+ new_request = RequestInfo(
+ https_url['url'],
+ "/",
+ "get"
+ )
+ if config.auth:
+ auth_info = AuthRequestInfo(
+ config.auth,
+ policy=AccessLevelPolicy.MAX
+ )
+
+ else:
+ auth_info = None
+
+ new_checks.append(TestHTTPAvailable(
+ check_id_start,
+ new_request,
+ auth_info
+ ))
+
+ check_id_start += 1
+
+ logging.debug(f"Proposed {len(new_checks)} new checks from check {self}")
+
+ return new_checks
+
+ @classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for descr in config.descriptions.values():
+ test_cases.append(TestDescriptionURLs(cur_check_id, descr))
+
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def serialize(self) -> dict:
+ serialized = {
+ "description": self.description.description_id,
+ }
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ description = config.descriptions[serialized["description"]]
+
+ return TestDescriptionURLs(check_id, description)
diff --git a/rest_attacker/checks/misc.py b/rest_attacker/checks/misc.py
new file mode 100644
index 0000000..192cbe6
--- /dev/null
+++ b/rest_attacker/checks/misc.py
@@ -0,0 +1,226 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Test cases for miscellaneous getting resource values via the API.
+"""
+
+from rest_attacker.checks.generic import TestCase
+from rest_attacker.checks.types import AuthType, LiveType, TestCaseType
+from rest_attacker.report.report import Report
+from rest_attacker.util.request.request_info import AuthRequestInfo, RequestInfo
+from rest_attacker.util.test_result import CheckStatus, IssueType
+
+
+class GetHeaders(TestCase):
+ """
+ Get value of a specified HTTP header in a HTTP response.
+ """
+ test_type = TestCaseType.ANALYTICAL
+ auth_type = AuthType.RECOMMENDED
+ live_type = LiveType.ONLINE
+
+ def __init__(
+ self,
+ check_id: int,
+ request_info: RequestInfo,
+ auth_info: AuthRequestInfo = None,
+ headers=[]
+ ) -> None:
+ """
+ Creates a new check for GetHeaders.
+
+ :param request_info: RequestInfo object that stores data to make the request.
+ :type request_info: RequestInfo
+ :param auth_info: AuthRequestInfo object that is used for authentication if specified.
+ :type auth_info: AuthRequestInfo
+ :param parameters: List of header IDs to fetch.
+ :type parameters: list[tuple[str]]
+ """
+ super().__init__(check_id)
+
+ self.request_info = request_info
+ self.auth_info = auth_info
+
+ self.headers = headers
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ auth_data = None
+ if self.auth_info:
+ auth_data = self.auth_info.auth_gen.get_auth(
+ scheme_ids=self.auth_info.scheme_ids,
+ scopes=self.auth_info.scopes,
+ policy=self.auth_info.policy
+ )
+
+ response = self.request_info.send(auth_data)
+ self.result.last_response = response
+
+ self.result.value = {}
+ for header_id in self.headers:
+ self.result.value.update({
+ header_id: response.headers[header_id]
+ })
+
+ self.result.issue_type = IssueType.CANDIDATE
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Curl request
+ if self.auth_info:
+ # TODO: Save used auth payload somewhere
+ # auth_data = self.auth_info.auth_gen.get_auth()
+ report["curl"] = self.request_info.get_curl_command()
+
+ else:
+ report["curl"] = self.request_info.get_curl_command()
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ return []
+
+ @ classmethod
+ def generate(cls, config, check_id_start=0):
+ # No checks generated
+ return []
+
+ def serialize(self) -> dict:
+ serialized = {
+ "request_info": self.request_info.serialize(),
+ "headers": self.headers
+ }
+
+ if self.auth_info:
+ serialized.update({
+ "auth_info": self.auth_info.serialize(),
+ })
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ request_info = RequestInfo.deserialize(serialized.pop("request_info"))
+ auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)
+
+ return GetParameters(check_id, request_info, auth_info, **serialized)
+
+
+class GetParameters(TestCase):
+ """
+ Get value of a specified JSON key in a HTTP response body.
+ """
+ test_type = TestCaseType.ANALYTICAL
+ auth_type = AuthType.RECOMMENDED
+ live_type = LiveType.ONLINE
+
+ def __init__(
+ self,
+ check_id: int,
+ request_info: RequestInfo,
+ auth_info: AuthRequestInfo = None,
+ parameters=[]
+ ) -> None:
+ """
+ Creates a new check for GetParameters.
+
+ :param request_info: RequestInfo object that stores data to make the request.
+ :type request_info: RequestInfo
+ :param auth_info: AuthRequestInfo object that is used for authentication if specified.
+ :type auth_info: AuthRequestInfo
+ :param parameters: List of parameter paths to fetch. Parameter paths are
+ passed as string tuples..
+ :type parameters: list[tuple[str]]
+ """
+ super().__init__(check_id)
+
+ self.request_info = request_info
+ self.auth_info = auth_info
+
+ self.parameters = parameters
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ auth_data = None
+ if self.auth_info:
+ auth_data = self.auth_info.auth_gen.get_auth(
+ scheme_ids=self.auth_info.scheme_ids,
+ scopes=self.auth_info.scopes,
+ policy=self.auth_info.policy
+ )
+
+ response = self.request_info.send(auth_data)
+ self.result.last_response = response
+
+ if 400 <= response.status_code < 500:
+ self.result.status = CheckStatus.ERROR
+ return
+
+ response_body = response.json()
+
+ self.result.value = {}
+ for param in self.parameters:
+ current_item = response_body
+ for param_part in param:
+ current_item = current_item[param_part]
+
+ self.result.value.update({
+ "/".join(param): current_item
+ })
+
+ self.result.issue_type = IssueType.CANDIDATE
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Curl request
+ if self.auth_info:
+ # TODO: Save used auth payload somewhere
+ # auth_data = self.auth_info.auth_gen.get_auth()
+ report["curl"] = self.request_info.get_curl_command()
+
+ else:
+ report["curl"] = self.request_info.get_curl_command()
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ return []
+
+ @ classmethod
+ def generate(cls, config, check_id_start=0):
+ # No checks generated
+ return []
+
+ def serialize(self) -> dict:
+ serialized = {
+ "request_info": self.request_info.serialize(),
+ "parameters": self.parameters
+ }
+
+ if self.auth_info:
+ serialized.update({
+ "auth_info": self.auth_info.serialize(),
+ })
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ request_info = RequestInfo.deserialize(serialized.pop("request_info"))
+ auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)
+
+ return GetParameters(check_id, request_info, auth_info, **serialized)
diff --git a/rest_attacker/checks/resources.py b/rest_attacker/checks/resources.py
new file mode 100644
index 0000000..237b80b
--- /dev/null
+++ b/rest_attacker/checks/resources.py
@@ -0,0 +1,1075 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Test cases for analyzing resources and input parameters.
+"""
+
+import logging
+from pydoc import describe
+from rest_attacker.util.input_gen import replace_params
+from rest_attacker.util.openapi.wrapper import OpenAPI
+from rest_attacker.checks.misc import GetParameters
+from rest_attacker.util.test_result import CheckStatus, IssueType
+
+from rest_attacker.checks.generic import TestCase
+from rest_attacker.checks.types import AuthType, LiveType, TestCaseType
+from rest_attacker.util.request.request_info import AuthRequestInfo, RequestInfo
+from rest_attacker.report.report import Report
+
+
+class TestObjectIDInvalidUserAccess(TestCase):
+ """
+ Check if an object (resource with ID) is accessible without providing a
+ sufficient access level (= unauthorized access).
+ """
+ test_type = TestCaseType.SECURITY
+ auth_type = AuthType.OPTIONAL
+ live_type = LiveType.OFFLINE
+
+ def __init__(
+ self,
+ check_id: int,
+ request_info: RequestInfo,
+ auth_info: AuthRequestInfo = None,
+ object_id: str = None,
+ object_name: str = None
+ ) -> None:
+ """
+ Creates a new check for TestObjectIDInvalidUserAccess.
+
+ :param request_info: RequestInfo object that stores data to make the request.
+ :type request_info: RequestInfo
+ :param auth_info: AuthRequestInfo object that is used for authentication if specified.
+ :type auth_info: AuthRequestInfo
+ :param object_id: ID of the object that is requested.
+ :type object_id: str
+ :param object_name: Resource name of the object that is requested.
+ :type object_name: str
+ """
+ super().__init__(check_id)
+
+ self.request_info = request_info
+ self.auth_info = auth_info
+
+ self.object_id = object_id
+ self.object_name = object_name
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ auth_data = None
+ if self.auth_info:
+ auth_data = self.auth_info.auth_gen.get_auth(
+ scheme_ids=self.auth_info.scheme_ids,
+ scopes=self.auth_info.scopes,
+ policy=self.auth_info.policy
+ )
+
+ response = self.request_info.send(auth_data)
+ self.result.last_response = response
+
+ self.result.value = {}
+ self.result.issue_type = IssueType.OKAY
+
+ if response.status_code not in (401, 403, 404):
+ # Not Unauthorized/Forbidden/Not Found
+ self.result.issue_type = IssueType.PROBLEM
+
+ if 200 <= response.status_code < 300:
+ # Direct access possible
+ self.result.issue_type = IssueType.FLAW
+
+ self.result.value["status_code"] = response.status_code
+
+ if self.object_id:
+ self.result.value["object_id"] = self.object_id
+
+ if self.object_name:
+ self.result.value["object_name"] = self.object_name
+
+ if self.result.issue_type in (IssueType.PROBLEM, IssueType.FLAW):
+ try:
+ # Try to export received data
+ self.result.value["response_body"] = response.json()
+
+ except ValueError:
+ self.result.value["response_body"] = None
+
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Curl request
+ if self.auth_info:
+ # TODO: Save used auth payload somewhere
+ # auth_data = self.auth_info.auth_gen.get_auth()
+ report["curl"] = self.request_info.get_curl_command()
+
+ else:
+ report["curl"] = self.request_info.get_curl_command()
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ return []
+
+ @classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for descr in config.descriptions.values():
+ for server in descr["servers"]:
+ for path_id, path in descr.endpoints.items():
+ for op_id, op in path.items():
+ replacments = None
+ if descr.requires_parameters(path_id, op_id):
+ if not config.users or not config.current_user_id:
+ # No replacement parameters defined
+ continue
+
+ default_user = config.users[config.current_user_id]
+ req_parameters = descr.get_required_param_defs(path_id, op_id)
+
+ replacments = replace_params(path, default_user, req_parameters)
+ if not replacments:
+ # No replacements found
+ continue
+
+ else:
+ # Nothing is referenced
+ continue
+
+ if replacments:
+ request_info = RequestInfo(
+ server["url"],
+ replacments[0],
+ op_id,
+ headers=replacments[1],
+ params=replacments[2],
+ cookies=replacments[3]
+ )
+
+ else:
+ request_info = RequestInfo(
+ server["url"],
+ path_id,
+ op_id
+ )
+
+ test_cases.append(
+ TestObjectIDInvalidUserAccess(cur_check_id, request_info)
+ )
+ cur_check_id += 1
+
+ if config.auth:
+ auth_info = AuthRequestInfo(config.auth)
+ test_cases.append(
+ TestObjectIDInvalidUserAccess(cur_check_id, request_info, auth_info)
+ )
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def serialize(self) -> dict:
+ serialized = {
+ "request_info": self.request_info.serialize(),
+ "object_id": self.object_id,
+ "object_name": self.object_name,
+ }
+
+ if self.auth_info:
+ serialized.update({
+ "auth_info": self.auth_info.serialize(),
+ })
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ request_info = RequestInfo.deserialize(serialized.pop("request_info"))
+ auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)
+
+ return TestObjectIDInvalidUserAccess(check_id, request_info, auth_info, **serialized)
+
+
+class CountParameterRequiredRefs(TestCase):
+ """
+ Determine frequency of required request parameters in an OpenAPI description. This is a
+ naive search that counts each occurence of the names of required parameters.
+ """
+ test_type = TestCaseType.ANALYTICAL
+ auth_type = AuthType.NOPE
+ live_type = LiveType.OFFLINE
+
+ def __init__(self, check_id: int, description: OpenAPI) -> None:
+ """
+ Creates a new check for FindIDParameters.
+
+ :param description: API description.
+ :type description: dict
+ """
+ super().__init__(check_id)
+
+ self.description = description
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ try:
+ paths = self.description.endpoints
+
+ except KeyError as error:
+ logging.warning("Could not find 'paths' entry in API description.")
+ self.result.error = error
+ self.result.status = CheckStatus.ERROR
+ return
+
+ unique_parameter_count = dict()
+ self.result.issue_type = IssueType.NO_CANDIDATE
+ for path_id, path in paths.items():
+ for op_id, op in path.items():
+ for param_id in self.description.get_required_param_ids(path_id, op_id):
+ if param_id not in unique_parameter_count.keys():
+ unique_parameter_count.update({
+ param_id: 0
+ })
+
+ unique_parameter_count[param_id] += 1
+
+ self.result.issue_type = IssueType.CANDIDATE
+
+ self.result.value = dict(reversed(sorted(
+ unique_parameter_count.items(),
+ key=lambda dic: dic[1] # Sort params by count
+ )))
+
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ super().propose(config, check_id_start)
+
+ new_checks = []
+
+ new_checks.append(FindParameterReturns(
+ check_id_start,
+ self.description,
+ list(self.result.value.keys())
+ ))
+
+ logging.debug(f"Proposed {len(new_checks)} new checks from check {self}")
+
+ return new_checks
+
+ @classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for descr in config.descriptions.values():
+ test_cases.append(CountParameterRequiredRefs(cur_check_id, descr))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def serialize(self) -> dict:
+ serialized = {
+ "description": self.description.description_id,
+ }
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ description = config.descriptions[serialized["description"]]
+
+ return FindIDParameters(check_id, description)
+
+
+class FindIDParameters(TestCase):
+    """
+    Find parameters that could be resource IDs or other object references.
+    This test case may be used to find candidates for testing object level authorization.
+    """
+    test_type = TestCaseType.ANALYTICAL
+    auth_type = AuthType.NOPE
+    live_type = LiveType.OFFLINE
+
+    def __init__(self, check_id: int, description: OpenAPI) -> None:
+        """
+        Creates a new check for FindIDParameters.
+
+        :param description: API description.
+        :type description: OpenAPI
+        """
+        super().__init__(check_id)
+
+        self.description = description
+
+    def run(self):
+        self.result.status = CheckStatus.RUNNING
+
+        try:
+            paths = self.description.endpoints
+
+        except KeyError as error:
+            logging.warning("Could not find 'paths' entry in API description.")
+            self.result.error = error
+            self.result.status = CheckStatus.ERROR
+            return
+
+        # Substrings that hint that a parameter references an object
+        candidate_strings = {'id', 'name', 'obj'}
+
+        unique_parameters = set()
+        unique_parameter_count = dict()
+        endpoints = dict()
+        self.result.issue_type = IssueType.NO_CANDIDATE
+        for path_id, path in paths.items():
+            for op_id, op in path.items():
+                # Only required parameters are considered
+                parameters = self.description.get_required_param_defs(path_id, op_id).values()
+
+                for param in parameters:
+                    if "$ref" in param.keys():
+                        param = self.description.resolve_ref(param["$ref"])
+
+                    param_name = param["name"]
+                    param_loc = param["in"]
+
+                    if not isinstance(param_name, str):
+                        # parameter name can be malformed
+                        continue
+
+                    # NOTE(review): a name matching several candidate substrings
+                    # (e.g. containing both 'id' and 'name') is counted once per
+                    # match -- confirm this inflation is intended
+                    for candidate in candidate_strings:
+                        if candidate in param_name.lower():
+                            if param_name not in unique_parameters:
+                                unique_parameter_count.update({
+                                    param_name: 0
+                                })
+                                unique_parameters.add(param_name)
+
+                            unique_parameter_count[param_name] += 1
+
+                            # NOTE(review): only the last matching parameter per
+                            # path is kept; earlier (op, location, name) tuples
+                            # for the same path are overwritten here
+                            endpoints.update(
+                                {path_id: (op_id, param_loc, param_name)}
+                            )
+
+                            self.result.issue_type = IssueType.CANDIDATE
+
+        self.result.value = {
+            "unique_parameters": sorted(list(unique_parameters)),
+            "unique_parameter_count": dict(reversed(sorted(
+                unique_parameter_count.items(),
+                key=lambda dic: dic[1]  # Sort params by count
+            ))),
+            "endpoints": endpoints
+        }
+
+        self.result.status = CheckStatus.FINISHED
+
+    def report(self, verbosity: int = 2):
+        """
+        Create a report for this check.
+
+        :param verbosity: Verbosity level of the report.
+        :type verbosity: int
+        """
+        report = {}
+        report.update(self.result.dump(verbosity=verbosity))
+
+        # Check params
+        report["config"] = self.serialize()
+
+        return Report(self.check_id, content=report)
+
+    def propose(self, config, check_id_start):
+        super().propose(config, check_id_start)
+
+        new_checks = []
+
+        # Search responses for every ID-like parameter found by run()
+        new_checks.append(FindParameterReturns(
+            check_id_start,
+            self.description,
+            list(self.result.value["unique_parameters"])
+        ))
+
+        logging.debug(f"Proposed {len(new_checks)} new checks from check {self}")
+
+        return new_checks
+
+    @classmethod
+    def generate(cls, config, check_id_start=0):
+        """
+        Generate one check per configured API description.
+        """
+        if not config.descriptions:
+            return []
+
+        cur_check_id = check_id_start
+        test_cases = []
+        for descr in config.descriptions.values():
+            test_cases.append(FindIDParameters(cur_check_id, descr))
+            cur_check_id += 1
+
+        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+        return test_cases
+
+    def serialize(self) -> dict:
+        serialized = {
+            "description": self.description.description_id,
+        }
+
+        return serialized
+
+    @classmethod
+    def deserialize(cls, serialized, config, check_id: int = 0):
+        description = config.descriptions[serialized["description"]]
+
+        return FindIDParameters(check_id, description)
+
+
+class FindParameterReturns(TestCase):
+ """
+ Find endpoints which return specified parameters in their response.
+ """
+ test_type = TestCaseType.ANALYTICAL
+ auth_type = AuthType.NOPE
+ live_type = LiveType.OFFLINE
+
+ def __init__(self, check_id: int, description: OpenAPI, parameters: list[str]) -> None:
+ """
+ Creates a new check for FindParameterReturns.
+
+ :param description: API description.
+ :type description: OpenAPI
+ :param parameters: Names of the parameters.
+ :type parameters: list[str]
+ """
+ super().__init__(check_id)
+
+ self.description = description
+
+ self.search_parameters = parameters
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ try:
+ paths = self.description.endpoints
+
+ except KeyError as error:
+ logging.warning("Could not find 'paths' entry in API description.")
+ self.result.error = error
+ self.result.status = CheckStatus.ERROR
+ return
+
+ param_locations = dict()
+ self.result.issue_type = IssueType.NO_CANDIDATE
+ for path_id, path in paths.items():
+ for op_id, op in path.items():
+ responses = op["responses"]
+ for status_code, response in responses.items():
+ if not "content" in response.keys():
+ continue
+
+ response_content = response["content"]
+ for cty in response_content.values():
+ if not "schema" in cty.keys():
+ continue
+
+ schema_defs = []
+ schema_def = cty["schema"]
+
+ if "$ref" in schema_def.keys():
+ schema_def = self.description.resolve_ref(schema_def["$ref"])
+
+ # Skip choices in
+ if "allOf" in schema_def.keys():
+ schema_defs.extend(schema_def["allOf"])
+
+ elif "oneOf" in schema_def.keys():
+ schema_defs.extend(schema_def["oneOf"])
+
+ else:
+ schema_defs.append(schema_def)
+
+ for schema in schema_defs:
+ if not "properties" in schema.keys():
+ continue
+
+ schema_object = schema["properties"]
+
+ found_params = self._recursive_search(schema_object)
+ if len(found_params) == 0:
+ continue
+
+ self.result.issue_type = IssueType.CANDIDATE
+
+ for param_id in found_params:
+ param_name = param_id.split("/")[-1]
+ if param_name not in param_locations.keys():
+ param_locations[param_name] = {
+ "endpoints": []
+ }
+
+ param_locations[param_name]["endpoints"].append({
+ "path": path_id,
+ "op": op_id,
+ "status_code": status_code,
+ "location": param_id,
+ # Number of required parameters to access the endpoint
+ # useful if we want to get the search parameters
+ "required_param_count": len(
+ self.description.get_required_param_ids(path_id, op_id)
+ )
+ })
+
+ for param_loc in param_locations.values():
+ param_loc["endpoints"].sort(
+ # Sort endpoints by required parameter count
+ key=lambda item: item.get("required_param_count")
+ )
+
+ # Sort result by param name
+ param_locations = dict(sorted(param_locations.items(), key=lambda dic: dic[0]))
+
+ self.result.value = param_locations
+
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ super().propose(config, check_id_start)
+
+ new_checks = []
+
+ # TODO: Fetch parameter values?
+ logging.debug(f"Proposed {len(new_checks)} new checks from check {self}")
+
+ return new_checks
+
+ @ classmethod
+ def generate(cls, config, check_id_start=0):
+ # Proposed by FindIDParameters
+ return []
+
+ def _recursive_search(self, schema_object: dict) -> list[str]:
+ """
+ Recursively search parameter definitions in a JSON schema object.
+
+ :param schema_object: JSON schema definition.
+ :type schema_object: dict
+ """
+ found_params = []
+ for param_name, param in schema_object.items():
+ param_descr = ""
+
+ if "description" in param.keys():
+ param_descr = param["description"]
+
+ for candidate in self.search_parameters:
+ if candidate in param_name.lower() or \
+ candidate in param_descr.lower():
+ found_params.append(param_name)
+
+ if isinstance(param, dict):
+ if not "properties" in param.keys():
+ continue
+
+ subschema_object = param["properties"]
+ subparams = self._recursive_search(subschema_object)
+ for subparam_name in subparams:
+ subparam_id = "/".join((param_name, subparam_name))
+ found_params.append(subparam_id)
+
+ return found_params
+
+ def serialize(self) -> dict:
+ serialized = {
+ "description": self.description.description_id,
+ "parameters": self.search_parameters,
+ }
+
+ return serialized
+
+ @ classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ description = config.descriptions[serialized["description"]]
+ parameters = serialized["parameters"]
+
+ return FindParameterReturns(check_id, description, parameters)
+
+
+class FindSecurityParameters(TestCase):
+ """
+ Find parameters that could be security-related, i.e. they contain access control
+ data or other information for authentication/authorization.
+ """
+ test_type = TestCaseType.ANALYTICAL
+ auth_type = AuthType.NOPE
+ live_type = LiveType.OFFLINE
+
+ def __init__(self, check_id: int, description: OpenAPI) -> None:
+ """
+ Creates a new check for FindSecurityParameters.
+
+ :param description: API description.
+ :type description: OpenAPI
+ """
+ super().__init__(check_id)
+
+ self.description = description
+
+ self._candidate_strings = {'token', 'key', 'auth', 'pass', 'pw', 'session'}
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ try:
+ paths = self.description.endpoints
+
+ except KeyError as error:
+ logging.warning("Could not find 'paths' entry in API description.")
+ self.result.error = error
+ self.result.status = CheckStatus.ERROR
+ return
+
+ security_descr_set = set()
+ security_params = set()
+ endpoints = dict()
+ self.result.issue_type = IssueType.NO_CANDIDATE
+ for path_id, path in paths.items():
+ for op_id, op in path.items():
+ responses = op["responses"]
+ for status_code, response in responses.items():
+ if "description" not in response.keys():
+ logging.info("Response has no description")
+ continue
+
+ description = response["description"]
+
+ for candidate in self._candidate_strings:
+ if candidate in description.lower():
+ security_descr_set.add(description)
+ endpoints.update(
+ {
+ path_id: {
+ "method": op_id,
+ "status_code": status_code,
+ "descr": description,
+ "params": [],
+ }
+ }
+ )
+
+ self.result.issue_type = IssueType.CANDIDATE
+
+ if not "content" in response.keys():
+ continue
+
+ response_content = response["content"]
+ for cty in response_content.values():
+ if not "schema" in cty.keys():
+ continue
+
+ schema_defs = []
+ schema_def = cty["schema"]
+
+ if "$ref" in schema_def.keys():
+ schema_def = self.description.resolve_ref(schema_def["$ref"])
+
+ # Skip choices in
+ if "allOf" in schema_def.keys():
+ schema_defs.extend(schema_def["allOf"])
+
+ elif "oneOf" in schema_def.keys():
+ schema_defs.extend(schema_def["oneOf"])
+
+ else:
+ schema_defs.append(schema_def)
+
+ for schema in schema_defs:
+ if not "properties" in schema.keys():
+ continue
+
+ schema_object = schema["properties"]
+
+ found_params = self._recursive_search(schema_object)
+ if len(found_params) == 0:
+ self.result.issue_type = IssueType.NO_CANDIDATE
+ continue
+
+ self.result.issue_type = IssueType.CANDIDATE
+ security_params.update(found_params)
+
+ if path_id not in endpoints.keys():
+ endpoints.update({
+ path_id: {
+ "method": op_id,
+ "status_code": status_code,
+ "descr": description,
+ "params": [],
+ }
+ })
+
+ endpoints[path_id]["params"].extend(found_params)
+
+ self.result.value = {
+ "security_descriptions": sorted(list(security_descr_set)),
+ "security_params": sorted(list(security_params)),
+ "endpoints": endpoints
+ }
+
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ super().propose(config, check_id_start)
+
+ new_checks = []
+
+ # Fetch parameter values
+ for path_name, path_item in self.result.value["endpoints"].items():
+ new_request = RequestInfo(
+ self.description["servers"][0]["url"], # TODO: What if there are multiple servers?
+ path_name,
+ path_item["method"]
+ )
+
+ if config.auth:
+ new_auth_info = AuthRequestInfo(config.auth)
+
+ else:
+ new_auth_info = None
+
+ # Split parameter subpaths
+ params = []
+ for param in path_item["params"]:
+ param_parts = param.split("/")
+ params.append(param_parts)
+
+ new_checks.append(GetParameters(
+ check_id_start,
+ new_request,
+ new_auth_info,
+ parameters=params
+ ))
+
+ check_id_start += 1
+
+ logging.debug(f"Proposed {len(new_checks)} new checks from check {self}")
+
+ return new_checks
+
+ @ classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for descr in config.descriptions.values():
+ test_cases.append(FindSecurityParameters(cur_check_id, descr))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def _recursive_search(self, schema_object: dict) -> list[str]:
+ """
+ Recursively search parameter definitions in a JSON schema object.
+
+ :param schema_object: JSON schema definition.
+ :type schema_object: dict
+ """
+ found_params = []
+ for param_name, param in schema_object.items():
+ param_descr = ""
+
+ if "description" in param.keys():
+ param_descr = param["description"]
+
+ for candidate in self._candidate_strings:
+ if candidate in param_name.lower() or \
+ candidate in param_descr.lower():
+ found_params.append(param_name)
+
+ if isinstance(param, dict):
+ if not "properties" in param.keys():
+ continue
+
+ subschema_object = param["properties"]
+ subparams = self._recursive_search(subschema_object)
+ for subparam_name in subparams:
+ subparam_id = "/".join((param_name, subparam_name))
+ found_params.append(subparam_id)
+
+ return found_params
+
+ def serialize(self) -> dict:
+ serialized = {
+ "description": self.description.description_id,
+ }
+
+ return serialized
+
+ @ classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ description = config.descriptions[serialized["description"]]
+
+ return FindSecurityParameters(check_id, description)
+
+
+class FindDuplicateParameters(TestCase):
+ """
+ Find parameters that are returned at multiple endpoints. This can be used to find
+ alternative ways to access a specific parameter.
+ """
+ test_type = TestCaseType.ANALYTICAL
+ auth_type = AuthType.NOPE
+ live_type = LiveType.OFFLINE
+
+ def __init__(self, check_id: int, description: OpenAPI) -> None:
+ """
+ Creates a new check for FindDuplicateParameters.
+
+ :param description: API description.
+ :type description: OpenAPI
+ """
+ super().__init__(check_id)
+
+ self.description = description
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ try:
+ paths = self.description.endpoints
+
+ except KeyError as error:
+ logging.warning("Could not find 'paths' entry in API description.")
+ self.result.error = error
+ self.result.status = CheckStatus.ERROR
+ return
+
+ # Parameters by name
+ params = {}
+
+ # Components by reference
+ components = {}
+
+ # Search for parameters and components
+ for path_id, path in paths.items():
+ for op_id, op in path.items():
+ responses = op["responses"]
+ for status_code, response in responses.items():
+ if not "content" in response.keys():
+ logging.info("Response has no content.")
+ continue
+
+ response_content = response["content"]
+ for cty in response_content.values():
+ if not "schema" in cty.keys():
+ logging.info("Response content has no schema.")
+ continue
+
+ schema_defs = []
+ schema_def = cty["schema"]
+
+ if "$ref" in schema_def.keys():
+ component_ref = schema_def["$ref"]
+ if component_ref not in components.keys():
+ components.update({
+ component_ref: {
+ "endpoints": [{
+ "op": op_id,
+ "path": path_id,
+ "response_code": status_code
+ }],
+ "count": 1,
+ }
+ })
+
+ else:
+ components[component_ref]["endpoints"].append({
+ "op": op_id,
+ "path": path_id,
+ "response_code": status_code
+ })
+ components[component_ref]["count"] += 1
+
+ schema_def = self.description.resolve_ref(schema_def["$ref"])
+
+ # Handle choices
+ if "allOf" in schema_def.keys():
+ schema_defs.extend(schema_def["allOf"])
+
+ elif "oneOf" in schema_def.keys():
+ schema_defs.extend(schema_def["oneOf"])
+
+ else:
+ schema_defs.append(schema_def)
+
+ for schema in schema_defs:
+ if not "properties" in schema.keys():
+ continue
+
+ schema_object = schema["properties"]
+
+ found_params = self._recursive_search(schema_object)
+ if len(found_params) == 0:
+ continue
+
+ for param_id, param_meta in found_params.items():
+ param_meta.update({
+ "endpoints": [{
+ "op": op_id,
+ "path": path_id,
+ "response_code": status_code
+ }],
+ })
+
+ if param_id not in params.keys():
+ params.update({
+ param_id: param_meta
+ })
+
+ else:
+ params[param_id]["endpoints"].append(param_meta["endpoints"])
+ params[param_id]["count"] += param_meta["count"]
+
+ # Look for duplicate parameters (i.e. count > 1)
+ candidate_params = {}
+ for param_name, param_data in params.items():
+ if param_data["count"] > 1:
+ candidate_params.update({
+ param_name: param_data
+ })
+
+ # Look for duplicate components (i.e. count > 1)
+ candidate_components = {}
+ for component_ref, component_data in components.items():
+ if component_data["count"] > 1:
+ candidate_components.update({
+ component_ref: component_data
+ })
+
+ if len(candidate_params) > 0 or len(candidate_components) > 0:
+ self.result.issue_type = IssueType.CANDIDATE
+
+ else:
+ self.result.issue_type = IssueType.NO_CANDIDATE
+
+ self.result.value = {
+ "params_count": len(candidate_params),
+ "params": dict(sorted(candidate_params.items(), key=lambda dic: dic[0])),
+ "components_count": len(candidate_params),
+ "components": dict(sorted(candidate_components.items(), key=lambda dic: dic[0])),
+ }
+
+ self.result.status = CheckStatus.FINISHED
+
+ def report(self, verbosity: int = 2):
+ report = {}
+ report.update(self.result.dump(verbosity=verbosity))
+
+ # Check params
+ report["config"] = self.serialize()
+
+ return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start):
+ return []
+
+ @ classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for descr in config.descriptions.values():
+ test_cases.append(FindDuplicateParameters(cur_check_id, descr))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def serialize(self) -> dict:
+ serialized = {
+ "description": self.description.description_id,
+ }
+
+ return serialized
+
+ @ classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ description = config.descriptions[serialized["description"]]
+
+ return FindDuplicateParameters(check_id, description)
+
+ def _recursive_search(self, schema_object: dict) -> dict[str, dict]:
+ """
+ Recursively search parameter definitions in a JSON schema object.
+
+ :param schema_object: JSON schema definition.
+ :type schema_object: dict
+ """
+ found_params: dict[str, dict] = {}
+ for param_name, param in schema_object.items():
+ if param_name not in found_params.keys():
+ found_params.update({
+ param_name: {
+ "count": 1
+ }
+ })
+
+ else:
+ found_params[param_name]["count"] += 1
+
+ if "properties" in param.keys():
+ subschema_object = param["properties"]
+ subparams = self._recursive_search(subschema_object)
+ for subparam_name in subparams.keys():
+ if subparam_name not in found_params.keys():
+ found_params.update({
+ param_name: {
+ "count": 1
+ }
+ })
+
+ else:
+ found_params[param_name]["count"] += 1
+
+ return found_params
diff --git a/rest_attacker/checks/scopes.py b/rest_attacker/checks/scopes.py
new file mode 100644
index 0000000..e55525b
--- /dev/null
+++ b/rest_attacker/checks/scopes.py
@@ -0,0 +1,1286 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Test cases for analyzing scope handling.
+"""
+import typing
+
+import logging
+
+from oauthlib.oauth2.rfc6749.tokens import OAuth2Token
+
+from rest_attacker.checks.generic import TestCase
+from rest_attacker.checks.types import AuthType, LiveType, TestCaseType
+from rest_attacker.report.report import Report
+from rest_attacker.util.auth.token_generator import AccessLevelPolicy, ClientInfo, OAuth2TokenGenerator
+from rest_attacker.util.openapi.wrapper import OpenAPI
+from rest_attacker.util.request.request_info import AuthRequestInfo, RequestInfo
+from rest_attacker.util.test_result import CheckStatus, IssueType
+
+
+class CheckScopesEndpoint(TestCase):
+ """
+ Check if an endpoint can be accessed with a specified authorization level (using OAuth2 scopes).
+ """
+ test_type = TestCaseType.ANALYTICAL
+ auth_type = AuthType.REQUIRED
+ live_type = LiveType.ONLINE
+
+    def __init__(
+        self,
+        check_id: int,
+        request_info: RequestInfo,
+        auth_info: AuthRequestInfo
+    ) -> None:
+        """
+        Creates a new check for CheckScopesEndpoint.
+
+        :param check_id: Unique ID of this check.
+        :type check_id: int
+        :param request_info: RequestInfo object that stores data to make the request.
+        :type request_info: RequestInfo
+        :param auth_info: AuthRequestInfo object that is used for authentication if specified.
+        :type auth_info: AuthRequestInfo
+        """
+        super().__init__(check_id)
+
+        self.request_info = request_info
+        # NOTE(review): run() tolerates auth_info being falsy, so None is
+        # effectively accepted here despite the non-Optional annotation.
+        self.auth_info = auth_info
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ auth_data = None
+ if self.auth_info:
+ auth_data = self.auth_info.auth_gen.get_auth(
+ scheme_ids=self.auth_info.scheme_ids,
+ scopes=self.auth_info.scopes,
+ policy=self.auth_info.policy
+ )
+
+ response = self.request_info.send(auth_data)
+ self.result.last_response = response
+
+ if 200 < response.status_code < 300:
+ self.result.issue_type = IssueType.CANDIDATE
+ self.result.value = {
+ "accepted": True
+ }
+
+ else:
+ self.result.issue_type = IssueType.NO_CANDIDATE
+ self.result.value = {
+ "accepted": False
+ }
+
+ self.result.status = CheckStatus.FINISHED
+
+    def report(self, verbosity: int = 2):
+        """
+        Create a report for this check from the stored result.
+
+        :param verbosity: Verbosity level of the result dump.
+        :type verbosity: int
+        """
+        report = {}
+        report.update(self.result.dump(verbosity=verbosity))
+
+        # Curl request
+        # NOTE(review): both branches are currently identical; the auth branch
+        # is a placeholder until the used auth payload is persisted (see TODO).
+        if self.auth_info:
+            # TODO: Save used auth payload somewhere
+            # auth_data = self.auth_info.auth_gen.get_auth()
+            report["curl"] = self.request_info.get_curl_command()
+
+        else:
+            report["curl"] = self.request_info.get_curl_command()
+
+        # Check params
+        report["config"] = self.serialize()
+
+        return Report(self.check_id, content=report)
+
+    def propose(self, config, check_id_start):
+        """This check proposes no follow-up checks."""
+        return []
+
+ @classmethod
+ def generate(cls, config, check_id_start=0):
+ if not config.descriptions:
+ return []
+
+ if not config.auth:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+
+ # TODO: Reactivate for services with no rate limits or
+ # services that actually define their security requirements
+ # Currently this test case drains the rate/access limit significantly :(
+ return test_cases
+
+ for descr in config.descriptions.values():
+ # Check for security schemes that support scoped security, e.g. OAuth
+ if "components" not in descr:
+ continue
+
+ if "security_schemes" not in descr["components"]:
+ continue
+
+ scoped_schemes = set()
+ for scheme_id, scheme in descr["components"]["security_schemes"].items():
+ if scheme["type"] != "oauth2":
+ # TODO: Other scoped schemes?
+ continue
+
+ scoped_schemes.add(scheme_id)
+
+ if len(scoped_schemes) == 0:
+ continue
+
+ for path_id, path_item in descr["paths"]:
+ if "parameters" in path_item.keys():
+ # TODO: Analyze paths with parameters
+ continue
+
+ for op_id, op in path_item.items():
+ if "parameters" in op.keys():
+ # TODO: Analyze operations with parameters
+ continue
+
+ for server in descr["servers"]:
+ nonauth_request = RequestInfo(
+ server["url"],
+ path_id,
+ op_id
+ )
+ test_cases.append(CheckScopesEndpoint(cur_check_id, nonauth_request))
+ cur_check_id += 1
+
+ auth_request = RequestInfo(
+ server["url"],
+ path_id,
+ op_id
+ )
+ auth_info = AuthRequestInfo(
+ config.auth
+ )
+ test_cases.append(CheckScopesEndpoint(
+ cur_check_id, auth_request, auth_info))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+ def serialize(self) -> dict:
+ serialized = {
+ "request_info": self.request_info.serialize(),
+ "auth_info": self.auth_info.serialize(),
+ }
+
+ return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ request_info = RequestInfo.deserialize(serialized.pop("request_info"))
+ auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)
+
+ return CheckScopesEndpoint(check_id, request_info, auth_info, **serialized)
+
+
+class ScopeMappingDescription(TestCase):
+ """
+ Creates a mapping of OAuth2 scopes to the endpoints they can access based on the information
+ in an OpenAPI description.
+ """
+ test_type = TestCaseType.ANALYTICAL
+ auth_type = AuthType.NOPE
+ live_type = LiveType.OFFLINE
+
+    def __init__(self, check_id: int, description: OpenAPI) -> None:
+        """
+        Creates a new check for ScopeMappingDescription.
+
+        :param check_id: Unique ID of this check.
+        :type check_id: int
+        :param description: API description.
+        :type description: OpenAPI
+        """
+        super().__init__(check_id)
+
+        self.description = description
+
+ def run(self):
+ self.result.status = CheckStatus.RUNNING
+
+ # Check for security schemes that support scoped security, e.g. OAuth
+ if "components" not in self.description:
+ logging.debug("API description has no components defined.")
+ self.result.status = CheckStatus.SKIPPED
+ return
+
+ if "securitySchemes" not in self.description["components"]:
+ logging.debug("API description has no security schemes defined.")
+ self.result.status = CheckStatus.SKIPPED
+ return
+
+ scoped_schemes = set()
+ scopemap = {}
+ for scheme_id, scheme in self.description["components"]["securitySchemes"].items():
+ if scheme["type"] != "oauth2":
+ # TODO: Other scoped schemes?
+ continue
+
+ scoped_schemes.add(scheme_id)
+ # Create scopemap from available scopes
+ for flow in scheme["flows"].values():
+ for scope_id in flow["scopes"]:
+ scopemap[scope_id] = []
+
+ if len(scoped_schemes) == 0:
+ logging.debug("API description has defined no scoped schemes (e.g. OAuth2).")
+ self.result.status = CheckStatus.SKIPPED
+ return
+
+ # Top-level security requirements
+ top_level_requirements = []
+ if "security" in self.description:
+ top_level_requirements = self.description["security"]
+ if len(top_level_requirements) == 0:
+ # This is non-standard because there should always be at least one object
+ # No security --> Empty security requirement object
+ logging.debug("API description has empty security requirements.")
+
+ # Operation-level security requirements
+ for path_id, path in self.description.endpoints.items():
+ for op_id, op in path.items():
+ if "security" not in op.keys():
+ # Use top-level requirements
+ requirements = top_level_requirements
+
+ else:
+ requirements = op["security"]
+
+ if len(op["security"]) == 0:
+ # This is non-standard because there should always be at least one object
+ # No security --> Empty security requirement object
+ logging.debug(f"{op_id} / {path_id} has empty security requirements.")
+
+ for requirement in requirements:
+ requirement_name = list(requirement.keys())[0]
+ if requirement_name not in scoped_schemes:
+ continue
+
+ for scopename in requirement[requirement_name]:
+ if scopename not in scopemap.keys():
+ # Shouldn't happen, but maybe interesting for analysis?
+ logging.debug(
+ f"Scope {scopename} required for {op_id} / {path_id} "
+ "not declared in security schemes.")
+ scopemap[scopename] = []
+
+ scopemap[scopename].append((path_id, op_id))
+
+ self.result.issue_type = IssueType.CANDIDATE
+ self.result.value = scopemap
+
+ self.result.status = CheckStatus.FINISHED
+
+    def report(self, verbosity: int = 2):
+        """
+        Create a report for this check from the stored result.
+
+        :param verbosity: Verbosity level of the result dump.
+        :type verbosity: int
+        """
+        report = {}
+        report.update(self.result.dump(verbosity=verbosity))
+
+        # Check params
+        report["config"] = self.serialize()
+
+        return Report(self.check_id, content=report)
+
+    def propose(self, config, check_id_start):
+        """This check proposes no follow-up checks."""
+        return []
+
+    @classmethod
+    def generate(cls, config, check_id_start=0):
+        """
+        Create one ScopeMappingDescription check per configured description.
+
+        :param config: Engine configuration holding the loaded descriptions.
+        :param check_id_start: First check ID to assign.
+        """
+        if not config.descriptions:
+            return []
+
+        cur_check_id = check_id_start
+        test_cases = []
+        for descr in config.descriptions.values():
+            test_cases.append(ScopeMappingDescription(cur_check_id, descr))
+
+            cur_check_id += 1
+
+        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+        return test_cases
+
+    def serialize(self) -> dict:
+        """Return the parameters needed to recreate this check."""
+        serialized = {
+            "description": self.description.description_id,
+        }
+
+        return serialized
+
+    @classmethod
+    def deserialize(cls, serialized, config, check_id: int = 0):
+        """Recreate a check from its serialized parameters."""
+        description = config.descriptions[serialized["description"]]
+
+        return ScopeMappingDescription(check_id, description)
+
+
+class CompareTokenScopesToClientScopes(TestCase):
+ """
+ Check if scopes assigned to an OAuth2 token are available to the client that requests
+ the token.
+ """
+ test_type: TestCaseType = TestCaseType.SECURITY
+ auth_type: AuthType = AuthType.NOPE
+ live_type: LiveType = LiveType.OFFLINE
+
+    def __init__(
+        self,
+        check_id: int,
+        token: OAuth2Token,
+        client_info: ClientInfo
+    ) -> None:
+        """
+        Create a new check for CompareTokenScopesToClientScopes.
+
+        :param check_id: Unique ID of this check.
+        :type check_id: int
+        :param token: OAuth2 token with assigned scopes.
+        :type token: OAuth2Token
+        :param client_info: Information about the client for which the token was issued.
+        :type client_info: ClientInfo
+        """
+        super().__init__(check_id)
+
+        self.token = token
+        self.client_info = client_info
+
+ def run(self) -> None:
+ self.result.status = CheckStatus.RUNNING
+
+ # Check if we received any scopes that the client does not support
+ received_scopes = set(self.token.scopes)
+ extra_scopes = received_scopes.difference(set(self.client_info.supported_scopes))
+ self.result.value = {}
+ if len(extra_scopes) > 0:
+ # Privilege escalation?
+ self.result.issue_type = IssueType.FLAW
+
+ else:
+ self.result.issue_type = IssueType.OKAY
+
+ self.result.value["supported_by_client"] = self.client_info.supported_scopes
+ self.result.value["unsupported_by_client"] = sorted(list(extra_scopes))
+
+ self.result.status = CheckStatus.FINISHED
+
+    def report(self, verbosity: int = 2):
+        """
+        Create a report for this check from the stored result.
+
+        :param verbosity: Verbosity level of the result dump.
+        :type verbosity: int
+        """
+        report = {}
+        report.update(self.result.dump(verbosity=verbosity))
+
+        # Check params
+        report["config"] = self.serialize()
+
+        return Report(self.check_id, content=report)
+
+    def propose(self, config, check_id_start) -> list:
+        """This check proposes no follow-up checks."""
+        return []
+
+ @classmethod
+ def generate(cls, config, check_id_start=0) -> list:
+ if not config.auth:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for cred in config.credentials.values():
+ if isinstance(cred, OAuth2TokenGenerator):
+ token = cred.get_token()
+ test_cases.append(CompareTokenScopesToClientScopes(
+ cur_check_id, token, cred.client_info))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+    def serialize(self) -> dict:
+        """
+        Return the parameters needed to recreate this check.
+
+        NOTE(review): the raw OAuth2Token object is stored here; presumably
+        the report writer can handle it -- verify it is JSON-serializable.
+        """
+        serialized = {
+            "token": self.token,
+            "client_id": self.client_info.client_id,
+        }
+
+        return serialized
+
+    @classmethod
+    def deserialize(cls, serialized, config, check_id: int = 0):
+        """
+        Recreate a check from its serialized parameters.
+
+        Looks up the client by ID among the configured OAuth2 token generators.
+        :raises Exception: If no generator matches the stored client ID.
+        """
+        token = serialized["token"]
+
+        client_info = None
+        for cred in config.credentials.values():
+            if isinstance(cred, OAuth2TokenGenerator):
+                if cred.client_info.client_id == serialized["client_id"]:
+                    client_info = cred.client_info
+                    break
+
+        if not client_info:
+            raise Exception(f"Client with ID {serialized['client_id']} not found.")
+
+        serialized.pop("client_id")
+
+        return CompareTokenScopesToClientScopes(check_id, token, client_info)
+
+
+class TestTokenRequestScopeOmit(TestCase):
+ """
+ Check which scopes are assigned to an OAuth2 token if the scope parameter is omitted.
+ This means the OAuth2 authorization request is sent without a scope parameter.
+ """
+ test_type: TestCaseType = TestCaseType.ANALYTICAL
+ auth_type: AuthType = AuthType.REQUIRED
+ live_type: LiveType = LiveType.ONLINE
+
+    def __init__(
+        self,
+        check_id: int,
+        token_gen: OAuth2TokenGenerator,
+        grant_type: str,
+        claims: typing.Optional[list[str]] = None
+    ) -> None:
+        """
+        Create a new check for TestTokenRequestScopeOmit.
+
+        :param check_id: Unique ID of this check.
+        :type check_id: int
+        :param token_gen: Token Generator for OAuth2 tokens.
+        :type token_gen: OAuth2TokenGenerator
+        :param grant_type: Grant used to request the token (code or token).
+        :type grant_type: str
+        :param claims: Optional list of scopes that are expected to be returned.
+        :type claims: list[str]
+        """
+        super().__init__(check_id)
+
+        self.token_gen = token_gen
+        self.grant_type = grant_type
+        self.claims = claims
+        # Token received by run(); None until the check has run (or on error).
+        self.token: typing.Optional[OAuth2Token] = None
+
+ def run(self) -> None:
+ self.result.status = CheckStatus.RUNNING
+
+ try:
+ # Force omission of scope parameter
+ self.token = self.token_gen.request_new_token(
+ scopes=None,
+ grant_type=self.grant_type,
+ policy=AccessLevelPolicy.NOPE
+ )
+
+ except Exception as err:
+ self.result.error = err
+ self.token = None
+
+ self.result.value = {}
+
+ if not self.token:
+ # Exit if no token could be created
+ logging.warning(f"No token received for checking {repr(self)}.")
+ self.result.issue_type = IssueType.NO_CANDIDATE
+ self.result.status = CheckStatus.ERROR
+ return
+
+ if not self.token.scopes:
+ # Auth server should indicate scope according to
+ # https://datatracker.ietf.org/doc/html/rfc6749#section-3.3
+ logging.info(f"No scope information received in token.")
+ self.result.issue_type = IssueType.CANDIDATE
+ self.result.value["received_scopes"] = None
+ self.result.status = CheckStatus.FINISHED
+ return
+
+ self.result.issue_type = IssueType.CANDIDATE
+ self.result.value["received_scopes"] = self.token.scopes
+ self.result.status = CheckStatus.FINISHED
+
+    def report(self, verbosity: int = 2):
+        """
+        Create a report for this check from the stored result.
+
+        :param verbosity: Verbosity level of the result dump.
+        :type verbosity: int
+        """
+        report = {}
+        report.update(self.result.dump(verbosity=verbosity))
+
+        # Check params
+        report["config"] = self.serialize()
+
+        return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start) -> list:
+ cur_check_id = check_id_start
+ test_cases = []
+ if self.result.value["received_scopes"] is not None:
+ test_cases.append(CompareTokenScopesToClientScopes(
+ cur_check_id, self.token, self.token_gen.client_info))
+ cur_check_id += 1
+
+ return []
+
+    @classmethod
+    def generate(cls, config, check_id_start=0) -> list:
+        """
+        Create one check per OAuth2 token generator and supported grant type.
+
+        Skips clients that require a scope parameter in the token request.
+        """
+        if not config.credentials:
+            return []
+
+        cur_check_id = check_id_start
+        test_cases = []
+        for cred in config.credentials.values():
+            if not isinstance(cred, OAuth2TokenGenerator):
+                continue
+
+            # NOTE(review): flag key looks like a typo of "scope_required";
+            # left unchanged because it must match the key set elsewhere --
+            # verify against ClientInfo.flags producers.
+            if "scope_reqired" in cred.client_info.flags:
+                continue
+
+            if 'code' in cred.client_info.supported_grants:
+                test_cases.append(TestTokenRequestScopeOmit(cur_check_id, cred, 'code'))
+                cur_check_id += 1
+
+            if 'token' in cred.client_info.supported_grants:
+                test_cases.append(TestTokenRequestScopeOmit(cur_check_id, cred, 'token'))
+                cur_check_id += 1
+
+        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+        return test_cases
+
+    def serialize(self) -> dict:
+        """Return the parameters needed to recreate this check."""
+        serialized = {
+            "client_id": self.token_gen.client_info.client_id,
+            "grant_type": self.grant_type,
+            "claims": self.claims,
+        }
+
+        return serialized
+
+    @classmethod
+    def deserialize(cls, serialized, config, check_id: int = 0):
+        """
+        Recreate a check from its serialized parameters.
+
+        :raises Exception: If no token generator matches the stored client ID.
+        """
+        token_gen = None
+        for cred in config.credentials.values():
+            if isinstance(cred, OAuth2TokenGenerator):
+                if cred.client_info.client_id == serialized["client_id"]:
+                    token_gen = cred
+                    break
+
+        if not token_gen:
+            raise Exception(f"Client with ID {serialized['client_id']} not found.")
+
+        serialized.pop("client_id")
+
+        return TestTokenRequestScopeOmit(check_id, token_gen, **serialized)
+
+
+class TestTokenRequestScopeEmpty(TestCase):
+ """
+ Check which scopes are assigned to an OAuth2 token if the scope parameter is empty.
+ This means the scope query parameter looks like this: scope=
+ """
+ test_type: TestCaseType = TestCaseType.ANALYTICAL
+ auth_type: AuthType = AuthType.REQUIRED
+ live_type: LiveType = LiveType.ONLINE
+
+    def __init__(
+        self,
+        check_id: int,
+        token_gen: OAuth2TokenGenerator,
+        grant_type: str,
+        claims: typing.Optional[list[str]] = None
+    ) -> None:
+        """
+        Create a new check for TestTokenRequestScopeEmpty.
+
+        :param check_id: Unique ID of this check.
+        :type check_id: int
+        :param token_gen: Token Generator for OAuth2 tokens.
+        :type token_gen: OAuth2TokenGenerator
+        :param grant_type: Grant used to request the token (code or token).
+        :type grant_type: str
+        :param claims: Optional list of scopes that are expected to be returned.
+        :type claims: Optional[List]
+        """
+        super().__init__(check_id)
+
+        self.token_gen = token_gen
+        self.grant_type = grant_type
+        self.claims = claims
+        # Token received by run(); None until the check has run (or on error).
+        self.token: typing.Optional[OAuth2Token] = None
+
+    def run(self) -> None:
+        """
+        Request a token with an empty scope parameter (scope=) and record
+        which scopes the authorization server assigns.
+        """
+        self.result.status = CheckStatus.RUNNING
+
+        try:
+            # Empty scope list produces 'scope=' in the request.
+            self.token = self.token_gen.request_new_token(
+                scopes=[],
+                grant_type=self.grant_type,
+                policy=AccessLevelPolicy.NOPE
+            )
+
+        except Exception as err:
+            # Record the failure instead of crashing the engine.
+            self.result.error = err
+            self.token = None
+
+        self.result.value = {}
+
+        if not self.token:
+            # Exit if no token could be created
+            logging.warning(f"No token received for checking {repr(self)}.")
+            self.result.issue_type = IssueType.NO_CANDIDATE
+            self.result.status = CheckStatus.ERROR
+            return
+
+        if not self.token.scopes:
+            # Auth server should indicate scope according to
+            # https://datatracker.ietf.org/doc/html/rfc6749#section-3.3
+            logging.info(f"No scope information received in token.")
+            self.result.issue_type = IssueType.CANDIDATE
+            self.result.value["received_scopes"] = None
+
+            self.result.status = CheckStatus.FINISHED
+            return
+
+        self.result.issue_type = IssueType.CANDIDATE
+        self.result.value["received_scopes"] = self.token.scopes
+        self.result.status = CheckStatus.FINISHED
+
+    def report(self, verbosity: int = 2):
+        """
+        Create a report for this check from the stored result.
+
+        :param verbosity: Verbosity level of the result dump.
+        :type verbosity: int
+        """
+        report = {}
+        report.update(self.result.dump(verbosity=verbosity))
+
+        # Check params
+        report["config"] = self.serialize()
+
+        return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start) -> list:
+ cur_check_id = check_id_start
+ test_cases = []
+ if self.result.value["received_scopes"] is not None:
+ test_cases.append(CompareTokenScopesToClientScopes(
+ cur_check_id, self.token, self.token_gen.client_info))
+ cur_check_id += 1
+
+ return []
+
+ @ classmethod
+ def generate(cls, config, check_id_start=0) -> list:
+ if not config.credentials:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for cred in config.credentials.values():
+ if not isinstance(cred, OAuth2TokenGenerator):
+ continue
+
+ if 'code' in cred.client_info.supported_grants:
+ test_cases.append(TestTokenRequestScopeEmpty(cur_check_id, cred, 'code'))
+ cur_check_id += 1
+
+ if 'token' in cred.client_info.supported_grants:
+ test_cases.append(TestTokenRequestScopeEmpty(cur_check_id, cred, 'token'))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+    def serialize(self) -> dict:
+        """Return the parameters needed to recreate this check."""
+        serialized = {
+            "client_id": self.token_gen.client_info.client_id,
+            "grant_type": self.grant_type,
+            "claims": self.claims,
+        }
+
+        return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ token_gen = None
+ for cred in config.credentials.values():
+ if isinstance(cred, OAuth2TokenGenerator):
+ if cred.client_info.client_id == serialized["client_id"]:
+ token_gen = cred
+ break
+
+ if not token_gen:
+ raise Exception(f"Client with ID {serialized['client_id']} not found.")
+
+ serialized.pop("client_id")
+
+ return TestTokenRequestScopeOmit(check_id, token_gen, **serialized)
+
+
+class TestTokenRequestScopeInvalid(TestCase):
+ """
+ Check which scopes are assigned to an OAuth2 token if the scope parameter is invalid.
+ Invalid means the scope is not supported by the service.
+ """
+ test_type: TestCaseType = TestCaseType.ANALYTICAL
+ auth_type: AuthType = AuthType.REQUIRED
+ live_type: LiveType = LiveType.ONLINE
+
+    def __init__(
+        self,
+        check_id: int,
+        token_gen: OAuth2TokenGenerator,
+        grant_type: str,
+        claims: typing.Optional[list[str]] = None
+    ) -> None:
+        """
+        Create a new check for TestTokenRequestScopeInvalid.
+
+        :param check_id: Unique ID of this check.
+        :type check_id: int
+        :param token_gen: Token Generator for OAuth2 tokens.
+        :type token_gen: OAuth2TokenGenerator
+        :param grant_type: Grant used to request the token (code or token).
+        :type grant_type: str
+        :param claims: Optional list of scopes that are expected to be returned.
+        :type claims: Optional[List]
+        """
+        super().__init__(check_id)
+
+        self.token_gen = token_gen
+        self.grant_type = grant_type
+        self.claims = claims
+        # Token received by run(); None until the check has run (or on error).
+        self.token: typing.Optional[OAuth2Token] = None
+
+    def run(self) -> None:
+        """
+        Request a token with a made-up (unsupported) scope value and record
+        which scopes the authorization server assigns anyway.
+        """
+        self.result.status = CheckStatus.RUNNING
+
+        # Use an invalid scope value:
+
+        # MD5(REST-Attacker)
+        # scope = "8516bfad8d65603b872d2c4a688135d7"
+
+        # Use a pseudo-random 16 Bit number and hash it
+        # then use hex value as scope
+        from random import randint
+        from hashlib import sha256
+        rand_val = randint(0, 2 ** 16 - 1)
+        rand_bytes = rand_val.to_bytes(length=2, byteorder='little')
+        scope = sha256(rand_bytes).hexdigest()
+
+        # Record the random input so the run is reproducible in the report.
+        self.result.value = {}
+        self.result.value["random_number"] = rand_val
+        self.result.value["scope"] = scope
+
+        # NOTE(review): unlike the sibling scope checks, no
+        # policy=AccessLevelPolicy.NOPE is passed here -- confirm whether
+        # that is intentional.
+        try:
+            self.token = self.token_gen.request_new_token(
+                scopes=[scope],
+                grant_type=self.grant_type
+            )
+
+        except Exception as err:
+            # Record the failure instead of crashing the engine.
+            self.result.error = err
+            self.token = None
+
+        if not self.token:
+            # Exit if no token could be created
+            logging.warning(f"No token received for checking {repr(self)}.")
+            self.result.issue_type = IssueType.NO_CANDIDATE
+            self.result.status = CheckStatus.ERROR
+            return
+
+        if not self.token.scopes:
+            # Auth server should indicate scope according to
+            # https://datatracker.ietf.org/doc/html/rfc6749#section-3.3
+            logging.info(f"No scope information received in token.")
+            self.result.issue_type = IssueType.CANDIDATE
+            self.result.value["received_scopes"] = None
+
+            self.result.status = CheckStatus.FINISHED
+            return
+
+        self.result.issue_type = IssueType.CANDIDATE
+        self.result.value["received_scopes"] = self.token.scopes
+        self.result.status = CheckStatus.FINISHED
+
+    def report(self, verbosity: int = 2):
+        """
+        Create a report for this check from the stored result.
+
+        :param verbosity: Verbosity level of the result dump.
+        :type verbosity: int
+        """
+        report = {}
+        report.update(self.result.dump(verbosity=verbosity))
+
+        # Check params
+        report["config"] = self.serialize()
+
+        return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start) -> list:
+ cur_check_id = check_id_start
+ test_cases = []
+ if self.result.value["received_scopes"] is not None:
+ test_cases.append(CompareTokenScopesToClientScopes(
+ cur_check_id, self.token, self.token_gen.client_info))
+ cur_check_id += 1
+
+ return []
+
+ @ classmethod
+ def generate(cls, config, check_id_start=0) -> list:
+ if not config.credentials:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for cred in config.credentials.values():
+ if not isinstance(cred, OAuth2TokenGenerator):
+ continue
+
+ if 'code' in cred.client_info.supported_grants:
+ test_cases.append(TestTokenRequestScopeInvalid(cur_check_id, cred, 'code'))
+ cur_check_id += 1
+
+ if 'token' in cred.client_info.supported_grants:
+ test_cases.append(TestTokenRequestScopeInvalid(cur_check_id, cred, 'token'))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+    def serialize(self) -> dict:
+        """Return the parameters needed to recreate this check."""
+        serialized = {
+            "client_id": self.token_gen.client_info.client_id,
+            "grant_type": self.grant_type,
+            "claims": self.claims,
+        }
+
+        return serialized
+
+ @classmethod
+ def deserialize(cls, serialized, config, check_id: int = 0):
+ token_gen = None
+ for cred in config.credentials.values():
+ if isinstance(cred, OAuth2TokenGenerator):
+ if cred.client_info.client_id == serialized["client_id"]:
+ token_gen = cred
+ break
+
+ if not token_gen:
+ raise Exception(f"Client with ID {serialized['client_id']} not found.")
+
+ serialized.pop("client_id")
+
+ return TestTokenRequestScopeOmit(check_id, token_gen, **serialized)
+
+
+class TestRefreshTokenRequestScopeOmit(TestCase):
+ """
+ Check which scopes are assigned to a refreshed OAuth2 token
+ if the scope parameter is omitted. This means the OAuth2 refresh request is sent
+ without a scope parameter.
+ """
+ test_type: TestCaseType = TestCaseType.ANALYTICAL
+ auth_type: AuthType = AuthType.REQUIRED
+ live_type: LiveType = LiveType.ONLINE
+
+    def __init__(
+        self,
+        check_id: int,
+        token_gen: OAuth2TokenGenerator,
+        first_token: typing.Optional[OAuth2Token] = None,
+        claims: typing.Optional[list[str]] = None
+    ) -> None:
+        """
+        Create a new check for TestRefreshTokenRequestScopeOmit.
+
+        :param check_id: Unique ID of this check.
+        :type check_id: int
+        :param token_gen: Token Generator for OAuth2 tokens.
+        :type token_gen: OAuth2TokenGenerator
+        :param first_token: Token that is refreshed. A new token is requested if no token was specified.
+        :type first_token: OAuth2Token
+        :param claims: Optional list of scopes that are expected to be returned.
+        :type claims: Optional[List]
+        """
+        super().__init__(check_id)
+
+        self.token_gen = token_gen
+        self.claims = claims
+        self.first_token = first_token
+        # Refreshed token produced by run(); None until the check has run.
+        self.refreshed_token: typing.Optional[OAuth2Token] = None
+
+ def run(self) -> None:
+ self.result.status = CheckStatus.RUNNING
+
+ if not self.first_token:
+ self.first_token = self.token_gen.request_new_token(
+ scopes=None,
+ grant_type='code',
+ policy=AccessLevelPolicy.NOPE
+ )
+
+ try:
+ self.refreshed_token = self.token_gen.refresh_token(self.first_token, scopes=None)
+
+ except Exception as err:
+ self.result.error = err
+ self.refreshed_token = None
+
+ self.result.value = {}
+
+ if not self.refreshed_token:
+ # Exit if no token could be created
+ logging.warning(f"No token received for checking {repr(self)}.")
+ self.result.issue_type = IssueType.NO_CANDIDATE
+ self.result.status = CheckStatus.ERROR
+ return
+
+ if not self.refreshed_token.scopes:
+ # Auth server should indicate scope according to
+ # https://datatracker.ietf.org/doc/html/rfc6749#section-3.3
+ logging.info(f"No scope information received in token.")
+ self.result.issue_type = IssueType.CANDIDATE
+ self.result.value["received_scopes"] = None
+
+ self.result.status = CheckStatus.FINISHED
+ return
+
+ self.result.issue_type = IssueType.CANDIDATE
+ self.result.value["received_scopes"] = self.refreshed_token.scopes
+ self.result.status = CheckStatus.FINISHED
+
+    def report(self, verbosity: int = 2):
+        """
+        Create a report for this check from the stored result.
+
+        :param verbosity: Verbosity level of the result dump.
+        :type verbosity: int
+        """
+        report = {}
+        report.update(self.result.dump(verbosity=verbosity))
+
+        # Check params
+        report["config"] = self.serialize()
+
+        return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start) -> list:
+ cur_check_id = check_id_start
+ test_cases = []
+ if self.result.value["received_scopes"] is not None:
+ test_cases.append(CompareTokenScopesToClientScopes(
+ cur_check_id, self.first_token, self.token_gen.client_info))
+ cur_check_id += 1
+
+ return []
+
+    @classmethod
+    def generate(cls, config, check_id_start=0) -> list:
+        """
+        Create one check per OAuth2 token generator that supports refreshing.
+
+        Skips clients that require a scope parameter in the token request.
+        """
+        if not config.credentials:
+            return []
+
+        cur_check_id = check_id_start
+        test_cases = []
+        for cred in config.credentials.values():
+            if not isinstance(cred, OAuth2TokenGenerator):
+                continue
+
+            # NOTE(review): flag key looks like a typo of "scope_required";
+            # left unchanged because it must match the key set elsewhere.
+            if "scope_reqired" in cred.client_info.flags:
+                continue
+
+            if not 'refresh_token' in cred.client_info.supported_grants:
+                # Generator must support refreshing tokens
+                continue
+
+            if 'code' in cred.client_info.supported_grants:
+                test_cases.append(TestRefreshTokenRequestScopeOmit(cur_check_id, cred))
+                cur_check_id += 1
+
+        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+        return test_cases
+
+    def serialize(self) -> dict:
+        """
+        Return the parameters needed to recreate this check.
+
+        NOTE(review): the raw OAuth2Token object is stored for 'first_token';
+        verify it is JSON-serializable for the report writer.
+        """
+        serialized = {
+            "client_id": self.token_gen.client_info.client_id,
+            "first_token": self.first_token,
+            "claims": self.claims,
+        }
+
+        return serialized
+
+    @classmethod
+    def deserialize(cls, serialized, config, check_id: int = 0):
+        """
+        Recreate a check from its serialized parameters.
+
+        :raises Exception: If no token generator matches the stored client ID.
+        """
+        token_gen = None
+        for cred in config.credentials.values():
+            if isinstance(cred, OAuth2TokenGenerator):
+                if cred.client_info.client_id == serialized["client_id"]:
+                    token_gen = cred
+                    break
+
+        if not token_gen:
+            raise Exception(f"Client with ID {serialized['client_id']} not found.")
+
+        serialized.pop("client_id")
+
+        return TestRefreshTokenRequestScopeOmit(check_id, token_gen, **serialized)
+
+
+class TestRefreshTokenRequestScopeEmpty(TestCase):
+ """
+ Check which scopes are assigned to a refreshed OAuth2 token if the scope parameter is empty.
+ This means the scope query parameter looks like this: scope=
+ """
+ test_type: TestCaseType = TestCaseType.ANALYTICAL
+ auth_type: AuthType = AuthType.REQUIRED
+ live_type: LiveType = LiveType.ONLINE
+
+    def __init__(
+        self,
+        check_id: int,
+        token_gen: OAuth2TokenGenerator,
+        first_token: typing.Optional[OAuth2Token] = None,
+        claims: typing.Optional[list[str]] = None
+    ) -> None:
+        """
+        Create a new check for TestRefreshTokenRequestScopeEmpty.
+
+        :param check_id: Unique ID of this check.
+        :type check_id: int
+        :param token_gen: Token Generator for OAuth2 tokens.
+        :type token_gen: OAuth2TokenGenerator
+        :param first_token: Token that is refreshed. A new token is requested if no token was specified.
+        :type first_token: OAuth2Token
+        :param claims: Optional list of scopes that are expected to be returned.
+        :type claims: Optional[List]
+        """
+        super().__init__(check_id)
+
+        self.token_gen = token_gen
+        self.claims = claims
+        self.first_token = first_token
+        # Refreshed token produced by run(); None until the check has run.
+        self.refreshed_token: typing.Optional[OAuth2Token] = None
+
+ def run(self) -> None:
+ self.result.status = CheckStatus.RUNNING
+
+ if not self.first_token:
+ self.first_token = self.token_gen.request_new_token(
+ scopes=None,
+ grant_type='code',
+ policy=AccessLevelPolicy.NOPE
+ )
+
+ try:
+ self.refreshed_token = self.token_gen.refresh_token(self.first_token, scopes=[])
+
+ except Exception as err:
+ self.result.error = err
+ self.refreshed_token = None
+
+ self.result.value = {}
+
+ if not self.refreshed_token:
+ # Exit if no token could be created
+ logging.warning(f"No token received for checking {repr(self)}.")
+ self.result.issue_type = IssueType.NO_CANDIDATE
+ self.result.status = CheckStatus.ERROR
+ return
+
+ if not self.refreshed_token.scopes:
+ # Auth server should indicate scope according to
+ # https://datatracker.ietf.org/doc/html/rfc6749#section-3.3
+ logging.info(f"No scope information received in token.")
+ self.result.issue_type = IssueType.CANDIDATE
+ self.result.value["received_scopes"] = None
+ self.result.status = CheckStatus.FINISHED
+ return
+
+ self.result.issue_type = IssueType.CANDIDATE
+ self.result.value["received_scopes"] = self.refreshed_token.scopes
+ self.result.status = CheckStatus.FINISHED
+
+    def report(self, verbosity: int = 2):
+        """
+        Create a report for this check from the stored result.
+
+        :param verbosity: Verbosity level of the result dump.
+        :type verbosity: int
+        """
+        report = {}
+        report.update(self.result.dump(verbosity=verbosity))
+
+        # Check params
+        report["config"] = self.serialize()
+
+        return Report(self.check_id, content=report)
+
+ def propose(self, config, check_id_start) -> list:
+ cur_check_id = check_id_start
+ test_cases = []
+ if self.result.value["received_scopes"] is not None:
+ test_cases.append(CompareTokenScopesToClientScopes(
+ cur_check_id, self.first_token, self.token_gen.client_info))
+ cur_check_id += 1
+
+ return []
+
+ @ classmethod
+ def generate(cls, config, check_id_start=0) -> list:
+ if not config.credentials:
+ return []
+
+ cur_check_id = check_id_start
+ test_cases = []
+ for cred in config.credentials.values():
+ if not isinstance(cred, OAuth2TokenGenerator):
+ continue
+
+ if not 'refresh_token' in cred.client_info.supported_grants:
+ # Generator must support refreshing tokens
+ continue
+
+ if 'code' in cred.client_info.supported_grants:
+ test_cases.append(TestRefreshTokenRequestScopeEmpty(cur_check_id, cred))
+ cur_check_id += 1
+
+ logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")
+
+ return test_cases
+
+    def serialize(self) -> dict:
+        """
+        Return the parameters needed to recreate this check.
+
+        NOTE(review): the raw OAuth2Token object is stored for 'first_token';
+        verify it is JSON-serializable for the report writer.
+        """
+        serialized = {
+            "client_id": self.token_gen.client_info.client_id,
+            "first_token": self.first_token,
+            "claims": self.claims,
+        }
+
+        return serialized
+
+    @classmethod
+    def deserialize(cls, serialized, config, check_id: int = 0):
+        """
+        Recreate a check from its serialized parameters.
+
+        :raises Exception: If no token generator matches the stored client ID.
+        """
+        token_gen = None
+        for cred in config.credentials.values():
+            if isinstance(cred, OAuth2TokenGenerator):
+                if cred.client_info.client_id == serialized["client_id"]:
+                    token_gen = cred
+                    break
+
+        if not token_gen:
+            raise Exception(f"Client with ID {serialized['client_id']} not found.")
+
+        serialized.pop("client_id")
+
+        return TestRefreshTokenRequestScopeEmpty(check_id, token_gen, **serialized)
+
+
class TestRefreshTokenRequestScopeInvalid(TestCase):
    """
    Check which scopes are assigned to a refreshed OAuth2 token if the scope parameter is invalid.
    Invalid means the scope is not supported by the service.
    """
    test_type: TestCaseType = TestCaseType.ANALYTICAL
    auth_type: AuthType = AuthType.REQUIRED
    live_type: LiveType = LiveType.ONLINE

    def __init__(
        self,
        check_id: int,
        token_gen: OAuth2TokenGenerator,
        first_token: OAuth2Token = None,
        claims: list[str] = None
    ) -> None:
        """
        Create a new check for TestRefreshTokenRequestScopeInvalid.

        :param token_gen: Token Generator for OAuth2 tokens.
        :type token_gen: OAuth2TokenGenerator
        :param first_token: Token that is refreshed. A new token is requested if no token was specified.
        :type first_token: OAuth2Token
        :param claims: Optional list of scopes that are expected to be returned.
        :type claims: Optional[List]
        """
        super().__init__(check_id)

        self.token_gen = token_gen
        self.claims = claims
        self.first_token = first_token
        # Set in run() once the refresh request succeeded
        self.refreshed_token = None

    def run(self) -> None:
        self.result.status = CheckStatus.RUNNING

        if not self.first_token:
            self.first_token = self.token_gen.request_new_token(scopes=None, grant_type='code')

        # Use a pseudo-random 16 bit number and hash it,
        # then use the hex digest as a scope value the service
        # cannot possibly support.
        from random import randint
        from hashlib import sha256
        rand_val = randint(0, 2 ** 16 - 1)
        rand_bytes = rand_val.to_bytes(length=2, byteorder='little')
        scope = sha256(rand_bytes).hexdigest()

        self.result.value = {}

        self.result.value["random_number"] = rand_val
        self.result.value["scope"] = scope

        try:
            self.refreshed_token = self.token_gen.refresh_token(self.first_token, scopes=[scope])

        except Exception as err:
            self.result.error = err
            self.refreshed_token = None

        if not self.refreshed_token:
            # Exit if no token could be created
            logging.warning(f"No token received for checking {repr(self)}.")
            self.result.issue_type = IssueType.NO_CANDIDATE
            self.result.status = CheckStatus.ERROR
            return

        if not self.refreshed_token.scopes:
            # Auth server should indicate scope according to
            # https://datatracker.ietf.org/doc/html/rfc6749#section-3.3
            logging.info("No scope information received in token.")
            self.result.issue_type = IssueType.CANDIDATE
            self.result.value["received_scopes"] = None
            self.result.status = CheckStatus.FINISHED
            return

        self.result.issue_type = IssueType.CANDIDATE
        self.result.value["received_scopes"] = self.refreshed_token.scopes
        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        """
        Build a Report for this check.

        :param verbosity: Verbosity level passed through to the result dump.
        :type verbosity: int
        """
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start) -> list:
        """
        Propose follow-up checks based on the result of this check.

        :param config: Run configuration containing credentials.
        :param check_id_start: First check ID assigned to proposed checks.
        :return: List of proposed test cases.
        """
        cur_check_id = check_id_start
        test_cases = []
        if self.result.value["received_scopes"] is not None:
            test_cases.append(CompareTokenScopesToClientScopes(
                cur_check_id, self.first_token, self.token_gen.client_info))
            cur_check_id += 1

        # Bugfix: previously returned an empty list, silently discarding
        # the proposed checks built above.
        return test_cases

    @classmethod
    def generate(cls, config, check_id_start=0) -> list:
        """
        Generate one check per OAuth2 client that supports both the
        refresh token and authorization code grants.
        """
        if not config.credentials:
            return []

        cur_check_id = check_id_start
        test_cases = []
        for cred in config.credentials.values():
            if not isinstance(cred, OAuth2TokenGenerator):
                continue

            if 'refresh_token' not in cred.client_info.supported_grants:
                # Generator must support refreshing tokens
                continue

            if 'code' in cred.client_info.supported_grants:
                # Bugfix: previously instantiated TestRefreshTokenRequestScopeEmpty
                # (copy-paste error), so this test case never generated itself.
                test_cases.append(TestRefreshTokenRequestScopeInvalid(cur_check_id, cred))
                cur_check_id += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        """Serialize the check parameters into a plain dict."""
        serialized = {
            "client_id": self.token_gen.client_info.client_id,
            "first_token": self.first_token,
            "claims": self.claims,
        }

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        """
        Recreate a check from serialized parameters.

        :raises Exception: If no credential matches the serialized client ID.
        """
        token_gen = None
        for cred in config.credentials.values():
            if isinstance(cred, OAuth2TokenGenerator):
                if cred.client_info.client_id == serialized["client_id"]:
                    token_gen = cred
                    break

        if not token_gen:
            raise Exception(f"Client with ID {serialized['client_id']} not found.")

        # Bugfix: 'client_id' must be removed before unpacking the remaining
        # parameters; __init__ has no such keyword and would raise TypeError.
        serialized.pop("client_id")

        # Bugfix: previously returned TestRefreshTokenRequestScopeOmit
        # (copy-paste error) instead of this class.
        return TestRefreshTokenRequestScopeInvalid(check_id, token_gen, **serialized)
diff --git a/rest_attacker/checks/token.py b/rest_attacker/checks/token.py
new file mode 100644
index 0000000..ea08c03
--- /dev/null
+++ b/rest_attacker/checks/token.py
@@ -0,0 +1,660 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Test cases for analyzing tokens provided by the service.
+"""
+
+import logging
+import base64
+import time
+
+from oauthlib.oauth2.rfc6749.tokens import OAuth2Token
+
+from rest_attacker.checks.generic import TestCase
+from rest_attacker.checks.types import AuthType, LiveType, TestCaseType
+from rest_attacker.report.report import Report
+from rest_attacker.util.auth.token_generator import AccessLevelPolicy, OAuth2TokenGenerator
+from rest_attacker.util.test_result import CheckStatus, IssueType
+from rest_attacker.util.request.request_info import AuthRequestInfo, RequestInfo
+
+
class TestReadOAuth2Expiration(TestCase):
    """
    Check the expiration time of the provided token.
    """
    test_type = TestCaseType.ANALYTICAL
    auth_type = AuthType.NOPE
    live_type = LiveType.OFFLINE

    def __init__(self, check_id: int, token: OAuth2Token) -> None:
        """
        Creates a new check for TestReadOAuth2Expiration.

        :param token: OAuth2 token from token request.
        :type token: OAuth2Token
        """
        super().__init__(check_id)

        self.token = token

    def run(self):
        self.result.status = CheckStatus.RUNNING

        # Bugfix: the branches below previously each *replaced* result.value,
        # so the 'expires_in' information was lost whenever 'expires_at'
        # was also present. Collect both into one dict instead.
        self.result.value = {}

        if "expires_in" in self.token:
            self.result.issue_type = IssueType.CANDIDATE
            # Bugfix: key was misspelled as "vailidity_length"
            self.result.value["validity_length"] = self.token["expires_in"]

        if "expires_at" in self.token:
            self.result.issue_type = IssueType.CANDIDATE
            self.result.value["expires_at"] = self.token["expires_at"]

        if not self.result.value:
            logging.info("Token has no expiration time information.")
            self.result.issue_type = IssueType.NO_CANDIDATE

        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        """
        Build a Report for this check.

        :param verbosity: Verbosity level passed through to the result dump.
        :type verbosity: int
        """
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start):
        # TODO: Refresh/Expiration checks
        return []

    @classmethod
    def generate(cls, config, check_id_start=0):
        """Generate one check per OAuth2 token generator in the run config."""
        if not config.auth:
            return []

        cur_check_id = check_id_start
        test_cases = []
        for cred in config.credentials.values():
            if isinstance(cred, OAuth2TokenGenerator):
                token = cred.get_token()
                test_cases.append(TestReadOAuth2Expiration(cur_check_id, token))
                cur_check_id += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        """Serialize the check parameters into a plain dict."""
        serialized = {
            "token": self.token,
        }

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        """Recreate a check from serialized parameters."""
        token = serialized["token"]

        return TestReadOAuth2Expiration(check_id, token)
+
+
class TestOAuth2Expiration(TestCase):
    """
    Check if the provided token expires after the specified time.
    """
    test_type = TestCaseType.ANALYTICAL
    auth_type = AuthType.REQUIRED
    live_type = LiveType.ONLINE

    def __init__(
        self,
        check_id: int,
        token: OAuth2Token,
        request_info: RequestInfo,
        auth_info: AuthRequestInfo
    ) -> None:
        """
        Creates a new check for TestOAuth2Expiration.

        :param token: OAuth2 token from a token request.
        :type token: OAuth2Token
        :param request_info: RequestInfo object that stores data to make the request.
        :type request_info: RequestInfo
        :param auth_info: AuthRequestInfo object that contains the specified token.
        :type auth_info: AuthRequestInfo
        """
        super().__init__(check_id)

        self.request_info = request_info
        self.auth_info = auth_info

        self.token = token

    def run(self):
        self.result.status = CheckStatus.RUNNING

        if "expires_at" in self.token:
            # Wake up 30 seconds before the expiration time
            sleep_time = self.token["expires_at"] - time.time() - 30

        else:
            logging.warning("Token has no expiration time information.")
            self.result.status = CheckStatus.ERROR
            return

        if sleep_time < 0:
            logging.info("Token is already expired. Skipping check.")
            self.result.status = CheckStatus.SKIPPED
            return

        # Sleep until expiration time reached
        time.sleep(sleep_time)

        auth_data = None
        if self.auth_info:
            auth_data = self.auth_info.auth_gen.get_auth(
                scheme_ids=self.auth_info.scheme_ids,
                scopes=self.auth_info.scopes,
                policy=self.auth_info.policy
            )

        # Test various times to determine expiration allowance
        # 30 seconds before expiration
        logging.debug("Testing token expiration validation: -30 seconds after expiration time.")
        response = self.request_info.send(auth_data)
        self.result.last_response = response

        # Bugfix: the pre-expiration status is now checked on the
        # authenticated response. Previously a second, *unauthenticated*
        # request overwrote `response` before the check, so the result
        # reflected the wrong request.
        if not 200 <= response.status_code < 300:
            logging.warning("Token could not be used before expiration time was reached.")
            self.result.status = CheckStatus.ERROR
            return

        validity_time = 0
        # Test various times to determine expiration allowance
        # 1 seconds after expiration
        logging.debug("Testing token expiration validation: 1 seconds after expiration time.")
        time.sleep(31)
        response = self.request_info.send(auth_data)
        self.result.last_response = response

        if 200 <= response.status_code < 300:
            logging.debug("Token accepted: 1 second after expiration time.")
            # Set to PROBLEM because there might be some allowance
            self.result.issue_type = IssueType.PROBLEM
            validity_time = 1

        # 60 seconds after expiration
        logging.debug("Testing token expiration validation: 60 seconds after expiration time.")
        time.sleep(60)
        response = self.request_info.send(auth_data)
        self.result.last_response = response

        if 200 <= response.status_code < 300:
            logging.debug("Token accepted: 60 seconds after expiration time.")
            self.result.issue_type = IssueType.FLAW
            validity_time = 60

        # 300 seconds (= 5 mins) after expiration
        logging.debug("Testing token expiration validation: 300 seconds after expiration time.")
        time.sleep(240)
        response = self.request_info.send(auth_data)
        self.result.last_response = response

        if 200 <= response.status_code < 300:
            logging.debug("Token accepted: 300 seconds after expiration time.")
            self.result.issue_type = IssueType.FLAW
            validity_time = 300

        self.result.value = {
            "min_validity_time": validity_time
        }
        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        """
        Build a Report for this check, including a curl reproduction command.

        :param verbosity: Verbosity level passed through to the result dump.
        :type verbosity: int
        """
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Curl request
        if self.auth_info:
            # TODO: Save used auth payload somewhere
            # auth_data = self.auth_info.auth_gen.get_auth()
            report["curl"] = self.request_info.get_curl_command()

        else:
            report["curl"] = self.request_info.get_curl_command()

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start):
        return []

    @classmethod
    def generate(cls, config, check_id_start=0):
        """Generation is currently disabled; see the TODO below."""
        if not config.auth:
            return []

        cur_check_id = check_id_start
        test_cases = []
        return test_cases

        # TODO: Reactivate
        # Currently these tests take ages to complete
        # NOTE(review): this dead code also calls the constructor without the
        # required request_info/auth_info arguments — fix before reactivating.

        for cred in config.credentials.values():
            if isinstance(cred, OAuth2TokenGenerator):
                token = cred.get_token()
                test_cases.append(TestOAuth2Expiration(cur_check_id, token))
                cur_check_id += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        """Serialize the check parameters into a plain dict."""
        serialized = {
            "request_info": self.request_info.serialize(),
            "auth_info": self.auth_info.serialize(),
            "token": self.token,
        }

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        """Recreate a check from serialized parameters."""
        token = serialized["token"]
        request_info = RequestInfo.deserialize(serialized.pop("request_info"))
        auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)

        return TestOAuth2Expiration(check_id, token, request_info, auth_info)
+
+
class TestDecodeOAuth2JWT(TestCase):
    """
    Check if the OAuth2 Token is a JWT.
    """
    test_type = TestCaseType.ANALYTICAL
    auth_type = AuthType.NOPE
    live_type = LiveType.OFFLINE

    def __init__(self, check_id, token: OAuth2Token) -> None:
        """
        Creates a new check for TestDecodeOAuth2JWT.

        :param token: OAuth2 token from token request.
        :type token: OAuth2Token
        """
        super().__init__(check_id)

        self.token = token

    def run(self):
        self.result.status = CheckStatus.RUNNING

        self.result.value = {}
        if "access_token" not in self.token:
            logging.info("Token has no access token defined.")
            self.result.issue_type = IssueType.NO_CANDIDATE

        else:
            access_token = self.token["access_token"]
            # A JWT consists of three Base64Url segments separated by '.'
            jwt_candidate = access_token.split('.')

            if len(jwt_candidate) == 3:
                # Test if header and payload can be decoded
                try:
                    # Bugfix: pad with (-len % 4) instead of (4 - len % 4).
                    # The old formula appended '====' when the segment length
                    # was already a multiple of 4, producing invalid padding.
                    header_candidate = base64.urlsafe_b64decode(
                        jwt_candidate[0] + '=' * (-len(jwt_candidate[0]) % 4)
                    )
                    self.result.issue_type = IssueType.CANDIDATE
                    logging.info("Token header could be decoded with Base64Url.")

                    payload_candidate = base64.urlsafe_b64decode(
                        jwt_candidate[1] + '=' * (-len(jwt_candidate[1]) % 4)
                    )
                    # Store the decoded text; the earlier intermediate
                    # bytes-valued "header" assignment was redundant.
                    self.result.value["header"] = header_candidate.decode('utf-8')
                    self.result.value["payload"] = payload_candidate.decode('utf-8')
                    logging.info("Token payload could be decoded with Base64Url.")

                except ValueError:
                    logging.info("Token could not be decoded.")
                    self.result.issue_type = IssueType.NO_CANDIDATE

            else:
                self.result.issue_type = IssueType.NO_CANDIDATE

        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        """
        Build a Report for this check.

        :param verbosity: Verbosity level passed through to the result dump.
        :type verbosity: int
        """
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start):
        return []

    @classmethod
    def generate(cls, config, check_id_start=0):
        """Generate one check per access token, plus one per refresh token."""
        if not config.auth:
            return []

        cur_check_id = check_id_start
        test_cases = []
        for cred in config.credentials.values():
            if isinstance(cred, OAuth2TokenGenerator):
                token = cred.get_token()
                test_cases.append(TestDecodeOAuth2JWT(cur_check_id, token))
                cur_check_id += 1

                # Refresh token
                if 'refresh_token' in token:
                    # Bugfix: wrap the refresh token so the check actually
                    # decodes it; previously the same token object was passed
                    # again, which just re-analyzed the access token.
                    refresh_token = OAuth2Token({"access_token": token["refresh_token"]})
                    test_cases.append(TestDecodeOAuth2JWT(cur_check_id, refresh_token))
                    cur_check_id += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        """Serialize the check parameters into a plain dict."""
        serialized = {
            "token": self.token,
        }

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        """Recreate a check from serialized parameters."""
        token = serialized["token"]

        return TestDecodeOAuth2JWT(check_id, token)
+
+
class TestRefreshTokenRevocation(TestCase):
    """
    Check if refresh tokens are single-use, i.e. they are invalidated after redeeming them once.
    """
    test_type: TestCaseType = TestCaseType.SECURITY
    auth_type: AuthType = AuthType.REQUIRED
    live_type: LiveType = LiveType.ONLINE

    def __init__(
        self,
        check_id: int,
        token_gen: OAuth2TokenGenerator,
        token: OAuth2Token = None,
    ) -> None:
        """
        Create a new check for TestRefreshTokenRevocation.

        :param token_gen: Token Generator for OAuth2 tokens.
        :type token_gen: OAuth2TokenGenerator
        :param token: Token with a refresh token. A new token is requested if none is specified.
        :type token: OAuth2Token
        """
        super().__init__(check_id)

        self.token = token
        self.token_gen = token_gen

    def run(self) -> None:
        self.result.status = CheckStatus.RUNNING

        if not self.token:
            self.token = self.token_gen.request_new_token(
                grant_type='code',
                policy=AccessLevelPolicy.MAX
            )

        if 'refresh_token' not in self.token:
            # Token is not refreshable
            logging.warning(f"No refresh token available received for checking {repr(self)}.")
            self.result.status = CheckStatus.ERROR
            return

        self.result.value = {}

        # First redemption; this should be fine
        self.token_gen.refresh_token(self.token)

        # Second redemption; this may be rejected according to
        # https://datatracker.ietf.org/doc/html/rfc6749#section-6
        try:
            self.token_gen.refresh_token(self.token)
            self.result.issue_type = IssueType.PROBLEM
            self.result.value["refresh_token"] = self.token["refresh_token"]
            self.result.value["single_use"] = False

        # Bugfix: catch Exception instead of a bare except, which would
        # also swallow SystemExit/KeyboardInterrupt.
        except Exception:
            self.result.issue_type = IssueType.OKAY
            self.result.value["refresh_token"] = self.token["refresh_token"]
            self.result.value["single_use"] = True

        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        """
        Build a Report for this check.

        :param verbosity: Verbosity level passed through to the result dump.
        :type verbosity: int
        """
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start) -> list:
        return []

    @classmethod
    def generate(cls, config, check_id_start=0) -> list:
        """
        Generate one check per OAuth2 client that supports both the
        refresh token and authorization code grants.
        """
        if not config.credentials:
            return []

        cur_check_id = check_id_start
        test_cases = []
        for cred in config.credentials.values():
            if not isinstance(cred, OAuth2TokenGenerator):
                continue

            if 'refresh_token' not in cred.client_info.supported_grants:
                # Generator must support refreshing tokens
                continue

            if 'code' in cred.client_info.supported_grants:
                test_cases.append(TestRefreshTokenRevocation(cur_check_id, cred))
                cur_check_id += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        """Serialize the check parameters into a plain dict."""
        serialized = {
            "token": self.token,
            "client_id": self.token_gen.client_info.client_id,
        }

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        """
        Recreate a check from serialized parameters.

        :raises Exception: If no credential matches the serialized client ID.
        """
        token = serialized["token"]

        token_gen = None
        for cred in config.credentials.values():
            if isinstance(cred, OAuth2TokenGenerator):
                if cred.client_info.client_id == serialized["client_id"]:
                    token_gen = cred
                    break

        if not token_gen:
            raise Exception(f"Client with ID {serialized['client_id']} not found.")

        return TestRefreshTokenRevocation(check_id, token_gen, token)
+
+
class TestRefreshTokenClientBinding(TestCase):
    """
    Check if refresh tokens are bound to the client that requests the corresponding access tokens.
    """
    test_type: TestCaseType = TestCaseType.SECURITY
    auth_type: AuthType = AuthType.REQUIRED
    live_type: LiveType = LiveType.ONLINE

    def __init__(
        self,
        check_id: int,
        client0_token_gen: OAuth2TokenGenerator,
        client1_token_gen: OAuth2TokenGenerator,
        token: OAuth2Token = None,
    ) -> None:
        """
        Create a new check for TestRefreshTokenClientBinding.

        :param client0_token_gen: Token Generator that generated the token.
        :type client0_token_gen: OAuth2TokenGenerator
        :param client1_token_gen: Token Generator with different client information.
        :type client1_token_gen: OAuth2TokenGenerator
        :param token: Token with a refresh token. A new token is requested if none is specified.
        :type token: OAuth2Token
        """
        super().__init__(check_id)

        self.token = token
        self.token_gen0 = client0_token_gen
        self.token_gen1 = client1_token_gen

    def run(self) -> None:
        self.result.status = CheckStatus.RUNNING

        if self.token_gen0.client_info.client_id == self.token_gen1.client_info.client_id:
            # Clients are identical
            logging.info(f"Skipping check {repr(self)}: Clients are identical.")
            self.result.status = CheckStatus.SKIPPED
            return

        if not self.token:
            self.token = self.token_gen0.request_new_token(
                grant_type='code',
                policy=AccessLevelPolicy.MAX
            )

        if 'refresh_token' not in self.token:
            # Token is not refreshable
            logging.warning(f"No refresh token available received for checking {repr(self)}.")
            self.result.status = CheckStatus.ERROR
            return

        self.result.value = {}
        self.result.value["initial_client"] = self.token_gen0.client_info.client_id

        try:
            # Try refreshing with different client than the one who
            # received the token
            self.token_gen1.refresh_token(self.token)
            self.result.issue_type = IssueType.FLAW
            self.result.value["bound"] = False
            self.result.value["refresh_token"] = self.token["refresh_token"]
            self.result.value["refresher_client"] = self.token_gen1.client_info.client_id

        # Bugfix: catch Exception instead of a bare except, which would
        # also swallow SystemExit/KeyboardInterrupt.
        except Exception:
            self.result.issue_type = IssueType.OKAY
            self.result.value["bound"] = True
            self.result.value["refresh_token"] = self.token["refresh_token"]

        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        """
        Build a Report for this check.

        :param verbosity: Verbosity level passed through to the result dump.
        :type verbosity: int
        """
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start) -> list:
        return []

    @classmethod
    def generate(cls, config, check_id_start=0) -> list:
        """
        Generate checks pairing every refreshable client with every
        other OAuth2 client in the run configuration.
        """
        if not config.credentials:
            return []

        cur_check_id = check_id_start
        test_cases = []

        # Needs 2 or more clients for testing
        # and at least 1 refreshable client
        oauth2_clients = []
        refreshable_clients = []
        for cred in config.credentials.values():
            if not isinstance(cred, OAuth2TokenGenerator):
                continue

            oauth2_clients.append(cred)

            if 'refresh_token' not in cred.client_info.supported_grants:
                # Generator must support refreshing tokens
                continue

            if 'code' in cred.client_info.supported_grants:
                refreshable_clients.append(cred)

        if len(oauth2_clients) > 1 and len(refreshable_clients) > 0:
            for refr_client in refreshable_clients:
                for other_client in oauth2_clients:
                    if refr_client is other_client:
                        continue

                    test_cases.append(TestRefreshTokenClientBinding(
                        cur_check_id, refr_client, other_client))
                    cur_check_id += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        """Serialize the check parameters into a plain dict."""
        serialized = {
            "token": self.token,
            "client_id0": self.token_gen0.client_info.client_id,
            "client_id1": self.token_gen1.client_info.client_id,
        }

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        """
        Recreate a check from serialized parameters.

        :raises Exception: If either serialized client ID cannot be found.
        """
        token = serialized["token"]
        token_gen0 = None
        token_gen1 = None
        for cred in config.credentials.values():
            if isinstance(cred, OAuth2TokenGenerator):
                if cred.client_info.client_id == serialized["client_id0"]:
                    token_gen0 = cred

                elif cred.client_info.client_id == serialized["client_id1"]:
                    token_gen1 = cred

                if token_gen0 and token_gen1:
                    break

        if not token_gen0:
            raise Exception(f"Client with ID {serialized['client_id0']} not found.")

        if not token_gen1:
            raise Exception(f"Client with ID {serialized['client_id1']} not found.")

        return TestRefreshTokenClientBinding(check_id, token_gen0, token_gen1, token)
diff --git a/rest_attacker/checks/types.py b/rest_attacker/checks/types.py
new file mode 100644
index 0000000..6532d77
--- /dev/null
+++ b/rest_attacker/checks/types.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Test case type enums.
+"""
+
+import enum
+
+
@enum.unique
class TestCaseType(enum.Enum):
    """
    Classification of a test case, used when interpreting its results.
    """
    ANALYTICAL = "analytical"
    SECURITY = "security"
    # unused; originally intended for comparing checks between runs
    COMPARISON = "comparison"
    # TODO: "meta" could be its own class instead of a type value,
    # similar to unittest's TestSuite
    META = "meta"
+
+
@enum.unique
class AuthType(enum.Enum):
    """
    Specifies whether access control data is required for this test case.
    """
    NOPE = "nope"                # no access control data used
    OPTIONAL = "optional"        # can be used but is not required
    RECOMMENDED = "recommended"  # should be used but is not required
    REQUIRED = "required"        # must be used
+
+
@enum.unique
class LiveType(enum.Enum):
    """
    Specifies whether a test case requires live access to the API.
    """
    ONLINE = "online"    # sends API requests
    OFFLINE = "offline"  # does not send API requests
diff --git a/rest_attacker/checks/undocumented.py b/rest_attacker/checks/undocumented.py
new file mode 100644
index 0000000..062e99e
--- /dev/null
+++ b/rest_attacker/checks/undocumented.py
@@ -0,0 +1,577 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Test cases for analyzing undocumented behaviour.
+"""
+
+import logging
+from rest_attacker.util.request.request_info import AuthRequestInfo, RequestInfo
+from rest_attacker.util.test_result import CheckStatus, IssueType
+from rest_attacker.checks.types import AuthType, LiveType, TestCaseType
+from rest_attacker.checks.generic import TestCase
+from rest_attacker.report.report import Report
+
+
class TestOptionsHTTPMethod(TestCase):
    """
    Checks which HTTP methods are allowed for a path using the OPTIONS HTTP method.
    """
    test_type = TestCaseType.ANALYTICAL
    auth_type = AuthType.RECOMMENDED
    live_type = LiveType.ONLINE

    def __init__(
        self,
        check_id: int,
        request_info: RequestInfo,
        auth_info: AuthRequestInfo = None,
        claims=None
    ) -> None:
        """
        Creates a new check for TestOptionsHTTPMethod.

        :param request_info: RequestInfo object that stores data to make the request.
        :type request_info: RequestInfo
        :param auth_info: AuthRequestInfo object that is used for authentication if specified.
        :type auth_info: AuthRequestInfo
        :param claims: Methods that the endpoint claims to support.
        :type claims: list[str]
        """
        super().__init__(check_id)

        self.request_info = request_info
        self.auth_info = auth_info

        self.claims = claims

    def run(self):
        self.result.status = CheckStatus.RUNNING

        auth_data = None
        if self.auth_info:
            auth_data = self.auth_info.auth_gen.get_auth(
                scheme_ids=self.auth_info.scheme_ids,
                scopes=self.auth_info.scopes,
                policy=self.auth_info.policy
            )

        response = self.request_info.send(auth_data)
        self.result.last_response = response

        self.result.value = {
            "path": self.request_info.path,
            "status_code": response.status_code,
        }

        if "allow" in response.headers.keys():
            # Robustness: strip each entry so "a,b" and "a, b" both parse
            allowed_methods = [
                method.strip() for method in response.headers["allow"].lower().split(",")
            ]
            self.result.value["allowed_methods"] = allowed_methods

        elif "access-control-allow-methods" in response.headers.keys():
            # CORS sometimes reveals same info
            allowed_methods = [
                method.strip()
                for method in response.headers["access-control-allow-methods"].lower().split(",")
            ]
            self.result.value["allowed_methods_cors"] = allowed_methods

        else:
            self.result.status = CheckStatus.ERROR
            self.result.error = Exception(
                "Could not retrieve allowed methods "
                f"for path {self.request_info.path} using OPTIONS method")
            return

        # Bugfix: the method names were lowercased above, so compare against
        # 'options'; the original compared "OPTIONS" which never matched.
        if len(allowed_methods) == 1 and "options" in allowed_methods:
            # only method is OPTIONS
            self.result.issue_type = IssueType.NO_CANDIDATE

        else:
            self.result.issue_type = IssueType.CANDIDATE

        if self.claims:
            # Check if claims and results match up
            wrong_claims = list(set(self.claims) - set(allowed_methods))
            missing_claims = list(set(allowed_methods) - set(self.claims))

            self.result.value.update({
                "claims": {
                    "claimed": self.claims,
                    "wrong_claims": sorted(wrong_claims),
                    "missing_claims": sorted(missing_claims),
                }
            })

        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        """
        Build a Report for this check, including a curl reproduction command.

        :param verbosity: Verbosity level passed through to the result dump.
        :type verbosity: int
        """
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Curl request
        if self.auth_info:
            # TODO: Save used auth payload somewhere
            # auth_data = self.auth_info.auth_gen.get_auth()
            report["curl"] = self.request_info.get_curl_command()

        else:
            report["curl"] = self.request_info.get_curl_command()

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start):
        return []

    @classmethod
    def generate(cls, config, check_id_start=0):
        """Generate one check per (server, path) pair of every description."""
        if not config.descriptions:
            return []

        cur_check_id = check_id_start
        test_cases = []
        for descr in config.descriptions.values():
            server_urls = []
            for server in descr["servers"]:
                server_urls.append(server["url"])

            for path_id, path_item in descr.endpoints.items():
                claims = list(path_item.keys())
                for server_url in server_urls:
                    auth_request = RequestInfo(
                        server_url,
                        path_id,
                        "options"
                    )

                    if config.auth:
                        auth_info = AuthRequestInfo(config.auth)

                    else:
                        auth_info = None

                    test_cases.append(TestOptionsHTTPMethod(cur_check_id,
                                                            auth_request,
                                                            auth_info,
                                                            claims=claims))
                    cur_check_id += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        """Serialize the check parameters into a plain dict."""
        serialized = {
            "request_info": self.request_info.serialize(),
            "claims": self.claims,
        }

        if self.auth_info:
            serialized.update({
                "auth_info": self.auth_info.serialize(),
            })

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        """Recreate a check from serialized parameters."""
        request_info = RequestInfo.deserialize(serialized.pop("request_info"))

        # Bugfix: 'auth_info' is only serialized when authentication was
        # configured; popping it unconditionally raised KeyError.
        auth_serialized = serialized.pop("auth_info", None)
        auth_info = None
        if auth_serialized is not None:
            auth_info = AuthRequestInfo.deserialize(auth_serialized, config.auth)

        return TestOptionsHTTPMethod(check_id, request_info, auth_info, **serialized)
+
+
class MetaTestOptionsHTTPMethod(TestCase):
    """
    Aggregate the results of TestOptionsHTTPMethod. This only aggregates results
    of checks for which 'claims' were defined.
    """
    test_type = TestCaseType.META
    auth_type = AuthType.NOPE
    live_type = LiveType.OFFLINE
    generates_for = (TestOptionsHTTPMethod,)

    def __init__(self, check_id, checks=None) -> None:
        """
        Creates a new check for MetaTestOptionsHTTPMethod.

        :param checks: List of checks investigated.
        :type checks: list[TestOptionsHTTPMethod]
        """
        super().__init__(check_id)

        # Bugfix: the default was a mutable list ([]), which would be shared
        # between all instances created without an explicit 'checks' argument.
        self.checks = checks if checks is not None else []

    def run(self):
        self.result.status = CheckStatus.RUNNING

        if any(check.result.status == CheckStatus.QUEUED for check in self.checks):
            self.result.status = CheckStatus.ERROR
            self.result.error = Exception(
                f"Cannot run meta check {self}. Some checks are not finished.")
            return

        # Sort out error and skipped checks
        checks = []
        for check in self.checks:
            if check.result.status == CheckStatus.FINISHED:
                checks.append(check)

        self.result.value = {
            "affected_paths": 0,
            "skipped_paths": 0,  # Skipped because of no claims set
            "total_wrong_claims": 0,
            "total_missing_claims": 0,
            "paths": []
        }
        for check in checks:
            if check.result.issue_type == IssueType.CANDIDATE:
                if "claims" not in check.result.value.keys():
                    # Only aggregate if claims exists
                    self.result.value["skipped_paths"] += 1
                    continue

                path = check.result.value["path"]
                claims = check.result.value["claims"]

                if path not in self.result.value["paths"]:
                    self.result.value["affected_paths"] += 1

                self.result.value["paths"].append(path)
                self.result.value["total_wrong_claims"] += len(claims["wrong_claims"])
                self.result.value["total_missing_claims"] += len(claims["missing_claims"])

        if self.result.value["affected_paths"] > 0:
            self.result.issue_type = IssueType.CANDIDATE

        else:
            self.result.issue_type = IssueType.NO_CANDIDATE

        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        """
        Build a Report for this check.

        :param verbosity: Verbosity level passed through to the result dump.
        :type verbosity: int
        """
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start):
        return []

    @classmethod
    def generate(cls, config, check_id_start=0):
        """Generate the subchecks plus one aggregating meta check."""
        subchecks = TestOptionsHTTPMethod.generate(config, check_id_start)

        test_cases = []
        test_cases.extend(subchecks)
        check_id_start += len(subchecks)

        test_cases.append(MetaTestOptionsHTTPMethod(check_id_start, subchecks))

        check_id_start += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        """Serialize the check parameters (only the IDs of the subchecks)."""
        serialized = {
            "check_ids": [check.check_id for check in self.checks],
        }

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        # TODO: Reference checks from deserialized config

        # return MetaTestOptionsHTTPMethod(check_id, **serialized)
        return None
+
+
class TestAllowedHTTPMethod(TestCase):
    """
    Checks if a defined path supports a specified HTTP method/API operation.
    """
    test_type = TestCaseType.ANALYTICAL
    auth_type = AuthType.RECOMMENDED
    live_type = LiveType.ONLINE

    def __init__(
        self,
        check_id: int,
        request_info: RequestInfo,
        auth_info: AuthRequestInfo = None
    ) -> None:
        """
        Creates a new check for TestAllowedHTTPMethod.

        :param request_info: RequestInfo object that stores data to make the request.
        :type request_info: RequestInfo
        :param auth_info: AuthRequestInfo object that is used for authentication if specified.
        :type auth_info: AuthRequestInfo
        """
        super().__init__(check_id)

        self.request_info = request_info
        self.auth_info = auth_info

    def run(self):
        self.result.status = CheckStatus.RUNNING

        # Generate authentication data only if auth was configured
        auth_data = None
        if self.auth_info:
            auth_data = self.auth_info.auth_gen.get_auth(
                scheme_ids=self.auth_info.scheme_ids,
                scopes=self.auth_info.scopes,
                policy=self.auth_info.policy
            )

        response = self.request_info.send(auth_data)
        self.result.last_response = response

        self.result.value = {
            "path": self.request_info.path,
            "http_method": self.request_info.operation,
            "status_code": response.status_code
        }
        if response.status_code == 405:
            # 405: Method not allowed
            self.result.issue_type = IssueType.NO_CANDIDATE

        else:
            # Anything else might indicate access
            self.result.issue_type = IssueType.CANDIDATE

        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        """
        Create a report including a curl command to reproduce the request.

        :param verbosity: Verbosity level of the result dump.
        :type verbosity: int
        """
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Curl request
        # TODO: Save the auth payload used by run() somewhere so it can be
        # embedded in the curl command for authenticated requests.
        # (Both branches previously produced the identical command.)
        report["curl"] = self.request_info.get_curl_command()

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start):
        # No follow-up checks are proposed by this test case
        return []

    @classmethod
    def generate(cls, config, check_id_start=0):
        """
        Generate one check per (server, path, undocumented HTTP method) triple.
        """
        if not config.descriptions:
            return []

        cur_check_id = check_id_start
        test_cases = []
        for descr in config.descriptions.values():
            server_urls = [server["url"] for server in descr["servers"]]

            for path_id, path in descr.endpoints.items():
                documented_ops = path.keys()

                for op in HTTP_REQUEST_METHODS:
                    if op in documented_ops:
                        # Only probe methods that are NOT documented for the path
                        continue

                    for server_url in server_urls:
                        request_info = RequestInfo(
                            server_url,
                            path_id,
                            op
                        )

                        if config.auth:
                            auth_info = AuthRequestInfo(config.auth)

                        else:
                            auth_info = None

                        test_cases.append(TestAllowedHTTPMethod(cur_check_id,
                                                                request_info,
                                                                auth_info))
                        cur_check_id += 1

        # NOTE: WebDAV methods (see COMMON_WEBDAV_REQUEST_METHODS) are
        # intentionally not probed here; that is unnecessary in most cases.

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        serialized = {
            "request_info": self.request_info.serialize(),
        }

        # auth_info is optional and only serialized when present
        if self.auth_info:
            serialized.update({
                "auth_info": self.auth_info.serialize(),
            })

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        request_info = RequestInfo.deserialize(serialized.pop("request_info"))

        # serialize() omits "auth_info" when no auth was configured, so the
        # key may be missing here (an unconditional pop raised KeyError).
        auth_info = None
        if "auth_info" in serialized:
            auth_info = AuthRequestInfo.deserialize(serialized.pop("auth_info"), config.auth)

        return TestAllowedHTTPMethod(check_id, request_info, auth_info, **serialized)
+
+
class MetaTestAllowedHTTPMethod(TestCase):
    """
    Aggregate the results of TestAllowedHTTPMethod.
    """
    test_type = TestCaseType.META
    auth_type = AuthType.NOPE
    live_type = LiveType.OFFLINE
    generates_for = (TestAllowedHTTPMethod,)

    def __init__(self, check_id, checks: list[TestCase] = None) -> None:
        """
        Creates a new check for MetaTestAllowedHTTPMethod.

        :param checks: List of checks investigated.
        :type checks: list[TestAllowedHTTPMethod]
        """
        super().__init__(check_id)

        # None default avoids the shared mutable default argument pitfall
        self.checks = checks if checks is not None else []

    def run(self):
        self.result.status = CheckStatus.RUNNING

        # Consistent with MetaTestOptionsHTTPMethod: record an error result
        # instead of raising, and skip (rather than abort on) checks that
        # errored or were skipped.
        if any(check.result.status == CheckStatus.QUEUED for check in self.checks):
            self.result.status = CheckStatus.ERROR
            self.result.error = Exception(
                f"Cannot run meta check {self}. Dependent checks are not finished.")
            return

        self.result.value = {
            "affected_paths": 0,
            "affected_methods": 0,
            "found_methods": {}
        }
        self.result.issue_type = IssueType.NO_CANDIDATE
        for check in self.checks:
            if check.result.status != CheckStatus.FINISHED:
                # Only aggregate successfully finished checks
                continue

            if check.result.issue_type == IssueType.CANDIDATE:
                self.result.value["affected_methods"] += 1
                path = check.result.value["path"]
                method = check.result.value["http_method"]
                status_code = check.result.value["status_code"]

                if path not in self.result.value["found_methods"]:
                    self.result.value["affected_paths"] += 1
                    self.result.value["found_methods"][path] = {
                        "undocumented_methods": []
                    }

                self.result.value["found_methods"][path]["undocumented_methods"].append(
                    {
                        "method": method,
                        "status_code": status_code
                    }
                )
                self.result.issue_type = IssueType.CANDIDATE

        self.result.status = CheckStatus.FINISHED

    def report(self, verbosity: int = 2):
        """
        Create a report from the aggregated result.

        :param verbosity: Verbosity level of the result dump.
        :type verbosity: int
        """
        report = {}
        report.update(self.result.dump(verbosity=verbosity))

        # Check params
        report["config"] = self.serialize()

        return Report(self.check_id, content=report)

    def propose(self, config, check_id_start):
        # Meta checks do not propose follow-up checks
        return []

    @classmethod
    def generate(cls, config, check_id_start=0):
        """
        Generate the subchecks plus one aggregating meta check.
        """
        subchecks = TestAllowedHTTPMethod.generate(config, check_id_start)

        test_cases = []
        test_cases.extend(subchecks)
        check_id_start += len(subchecks)

        test_cases.append(MetaTestAllowedHTTPMethod(check_id_start, subchecks))

        check_id_start += 1

        logging.debug(f"Generated {len(test_cases)} checks from test case {cls}")

        return test_cases

    def serialize(self) -> dict:
        # Only the IDs of the aggregated checks are stored
        serialized = {
            "check_ids": [check.check_id for check in self.checks],
        }

        return serialized

    @classmethod
    def deserialize(cls, serialized, config, check_id: int = 0):
        # TODO: Reference checks from deserialized config

        # return MetaTestAllowedHTTPMethod(check_id, **serialized)
        return None
+
+
# HTTP methods probed by TestAllowedHTTPMethod.generate() against every
# documented path (lowercase, matching the operation keys of OpenAPI paths).
HTTP_REQUEST_METHODS = [
    # CRUD methods
    "get",
    "post",
    "put",
    "delete",
    "patch",


    # Other HTTP methods; ignored for now
    # "head",     # should be same result as get
    # "connect",  # not relevant (?)
    # "options",  # tested in extra test case
    # "trace",    # not relevant (?)
]

# Methods that hint at WebDAV usage.
# Currently unused; see the commented-out WebDAV probing in
# TestAllowedHTTPMethod.generate().
COMMON_WEBDAV_REQUEST_METHODS = [
    "copy",
    "lock",
    "mkcol",
    "move",
    "propfind",
    "proppatch",
    "unlock",
]
diff --git a/rest_attacker/engine/__init__.py b/rest_attacker/engine/__init__.py
new file mode 100644
index 0000000..d009c11
--- /dev/null
+++ b/rest_attacker/engine/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+This module contains the core implementation for the test engine that
+executes checks at runtime.
+"""
diff --git a/rest_attacker/engine/config.py b/rest_attacker/engine/config.py
new file mode 100644
index 0000000..28a94e4
--- /dev/null
+++ b/rest_attacker/engine/config.py
@@ -0,0 +1,74 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Stores the main service configuration for the engine.
+"""
+from __future__ import annotations
+import typing
+
+from argparse import Namespace
+
+from rest_attacker.util.auth.auth_generator import AuthGenerator
+
+if typing.TYPE_CHECKING:
+ from rest_attacker.util.openapi.wrapper import OpenAPI
+
+
class EngineConfig:
    """
    Store configuration information for the test run.
    """

    def __init__(
        self,
        meta: dict,
        info: dict,
        credentials: dict,
        users: dict | None = None,
        current_user_id: str | None = None,
        auth_gen: AuthGenerator | None = None,
        descriptions: dict[str, OpenAPI] | None = None,
        cli_args: Namespace | None = None
    ) -> None:
        """
        Create a new configuration object for the engine.

        :param meta: Metadata (name, etc.) for the service.
        :type meta: dict
        :param info: Analysis information (scopes, etc.) for the service.
        :type info: dict
        :param credentials: Authentication information for the service.
        :type credentials: dict
        :param users: User definitions for the service.
        :type users: dict
        :param current_user_id: ID of the currently active user.
        :type current_user_id: str
        :param auth_gen: Authentication generator that handles creation of
                         authentication data for checks.
        :type auth_gen: AuthGenerator
        :param descriptions: Available API descriptions.
        :type descriptions: dict
        :param cli_args: Arguments from the RATT CLI interface.
        :type cli_args: argparse.Namespace
        """
        # Metadata about the service
        self.meta = meta

        # Information for the analysis
        self.info = info

        # Credentials info
        self.credentials = credentials

        # User definitions
        self.users = users
        self.current_user_id = current_user_id

        # Auth Generator
        self.auth = auth_gen

        # API Descriptions
        self.descriptions = descriptions

        # CLI args
        self.cli_args = cli_args
diff --git a/rest_attacker/engine/engine.py b/rest_attacker/engine/engine.py
new file mode 100644
index 0000000..7f53984
--- /dev/null
+++ b/rest_attacker/engine/engine.py
@@ -0,0 +1,252 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Implementation of the generic test case super class.
+"""
+
+import typing
+
+from datetime import datetime
+import json
+import logging
+import sys
+import time
+
+from rest_attacker.checks.generic import TestCase
+from rest_attacker.checks.types import LiveType, TestCaseType
+from rest_attacker.engine.config import EngineConfig
+from rest_attacker.engine.internal_state import EngineStatus, InternalState
+from rest_attacker.util.errors import RestrictedOperationError
+from rest_attacker.util.test_result import CheckStatus, TestResult
+
+
class Engine:
    """
    Test engine for a test run.

    Executes the configured checks in order, tracks run statistics in an
    InternalState, reacts to rate/access limits and exports the results.
    """

    def __init__(self, config: EngineConfig, checks: list[TestCase], handlers=None) -> None:
        """
        Create a new engine for a test run.

        :param config: Configuration for the service.
        :type config: EngineConfig
        :param checks: Ordered list of checks that should be executed in the run.
        :type checks: list
        :param handlers: Handlers for tracking rate limits imposed by the service.
        :type handlers: list
        """
        self.checks = checks
        self.config = config
        self.state = InternalState()
        # Index of the check that is currently executed
        self.index = 0

        # Initialize statistics
        self.state.planned_check_count = len(self.checks)

        for check in self.checks:
            if check.test_type == TestCaseType.ANALYTICAL:
                self.state.analytical_check_count += 1

            elif check.test_type == TestCaseType.SECURITY:
                self.state.security_check_count += 1

        # Setup handlers (None default avoids a shared mutable default argument)
        if handlers:
            for handler in handlers:
                self.state.set_limit_handler(handler)

    def run(self) -> None:
        """
        Run all checks of the test run.
        """
        self.state.status = EngineStatus.RUNNING
        logging.info("Starting: Engine run.")
        while self.index < len(self.checks):
            self.current_check()

            # Check if rate/access limits are reached
            self.update_handlers()

            self.status()
            self.index += 1

        if self.state.status is not EngineStatus.ABORTED:
            logging.info("Finished: Engine run.")
            self.state.status = EngineStatus.FINISHED

        self.state.end_time = time.time()

    def current_check(self) -> None:
        """
        Execute the check at the current index and update the run statistics.
        """
        current_check = self.checks[self.index]
        logging.debug(f"Starting: Check {current_check.check_id} "
                      f"({current_check.get_test_case_id()}).")

        try:
            current_check.run()

        except RestrictedOperationError as err:
            # Restricted operations are skipped, not counted as errors
            current_check.result.status = CheckStatus.SKIPPED
            current_check.result.error = err
            logging.warning((f"Check {current_check.check_id} did not execute: "
                             "API operation is restricted"))

        except Exception as err:
            current_check.result.status = CheckStatus.ERROR
            current_check.result.error = err
            logging.warning(
                f"Check {current_check.check_id} ({current_check.get_test_case_id()}) "
                f"produced the following error:\n{err}")

        # Update statistics
        if current_check.result.status is CheckStatus.FINISHED:
            self.state.finished_check_count += 1
            if self.config.cli_args:
                if self.config.cli_args.propose:
                    # Insert proposed follow-up checks right after the current one
                    for check in current_check.propose(self.config, len(self.checks)):
                        self.checks.insert(self.index + 1, check)

        elif current_check.result.status is CheckStatus.SKIPPED:
            self.state.skipped_check_count += 1

        elif current_check.result.status is CheckStatus.ERROR:
            self.state.error_check_count += 1

        logging.debug(f"Finished: Check {current_check.check_id} "
                      f"({current_check.get_test_case_id()}) "
                      " with status: "
                      f"{current_check.result.status.value}")

    def export(self, output_dir) -> None:
        """
        Export the results of a run to file.

        :param output_dir: Directory the report files are exported to.
        :type output_dir: pathlib.Path
        """
        output: dict[str, typing.Any] = {}

        # Aborted runs only contain partial results
        if self.state.status is EngineStatus.ABORTED:
            output["type"] = "partial"

        else:
            output["type"] = "report"

        # Service info
        if self.config.meta:
            output["meta"] = self.config.meta

        # Run statistics
        output["stats"] = self.state.dump()

        # Run args
        if self.config.cli_args:
            output["args"] = sys.argv[1:]

        reports = []
        for check in self.checks:
            try:
                report = check.report().dump()
                reports.append(report)

            except Exception as exc:
                # A failing report must not prevent export of the others
                logging.exception(
                    f"Report for check {check.check_id} could not be generated.", exc_info=exc)

        output["reports"] = reports

        output_str = json.dumps(output, indent=4)

        report_file = output_dir / "report.json"
        with report_file.open('w') as repf:
            repf.write(output_str)

        logging.info(f"Exported results to: {report_file}")

    def update_handlers(self):
        """
        Execute response handlers assigned to the run.
        """
        current_check = self.checks[self.index]

        # Currently only online checks are relevant here
        if current_check.live_type is LiveType.ONLINE:
            if self.state.rate_limit:
                rl_limit_reached = self.state.rate_limit.update(current_check.result.last_response)

                if rl_limit_reached:
                    reset_time = self.state.rate_limit.get_reset_wait_time()
                    logging.warning(
                        f"Rate limit reached: Next check possible in {reset_time} seconds.")
                    self.pause(time.time() + reset_time)

                    self.state.rate_limit.reset()

            if self.state.access_limit:
                # Check interval
                if self.state.access_limit.current_pos >= self.state.access_limit.interval:
                    acc_limit_reached = self.state.access_limit.update()

                    if acc_limit_reached:
                        logging.warning("Access limit reached.")
                        # TODO: Idea: Switch to a different client/user and resume run
                        # may affect the remaining preconfigured checks?

                        # Roll back to last successful check
                        while (self.checks[self.index].check_id != self.state.access_limit.last_check_id
                               and self.index > -1):
                            # Clear test result
                            self.checks[self.index].result = TestResult(self.checks[self.index])
                            self.index -= 1

                        # TODO: This currently also rolls back offline checks. Maybe only roll
                        # back online checks that actively make requests

                        # TODO: What if any checks were proposed based on the results of the checks
                        # that were faulty and are now rolled back?

                        # abort the run
                        self.abort()
                        return

                    else:
                        self.state.access_limit.last_check_id = self.checks[self.index].check_id
                        self.state.access_limit.reset()

                else:
                    self.state.access_limit.current_pos += 1

    def pause(self, until: int):
        """
        Pause a run until a point in time.

        :param until: UNIX timestamp of the time when the run should resume.
        :type until: int
        """
        logging.warning(f"Pausing run until {datetime.fromtimestamp(until)}.")
        print("Press CTRL + C to abort run.")
        # time.sleep() takes a duration in seconds, while 'until' is an
        # absolute timestamp; sleep only for the remaining time (previously
        # the timestamp itself was passed, pausing for ~50 years).
        time.sleep(max(0.0, until - time.time()))

    def abort(self):
        """
        Abort a run. This skips all remaining checks and immediately ends the run.
        Already gathered results can be exported. The run can be resumed if the check
        params were exported.
        """
        for idx in range(self.index, len(self.checks)):
            self.checks[idx].result.status = CheckStatus.ABORTED

        self.state.aborted_check_count = len(self.checks) - self.index

        # Jump past the end of the check list so run() terminates
        self.index = len(self.checks)
        self.state.status = EngineStatus.ABORTED
        self.state.end_time = time.time()

        logging.warning("Run successfully aborted.")

    def status(self):
        """
        Print the current status to CLI.
        """
        sys.stdout.write(f"Executed {self.index + 1}/{len(self.checks)} checks\r")
diff --git a/rest_attacker/engine/generate_checks.py b/rest_attacker/engine/generate_checks.py
new file mode 100644
index 0000000..53f8c6a
--- /dev/null
+++ b/rest_attacker/engine/generate_checks.py
@@ -0,0 +1,90 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Generate checks for a run.
+"""
+
+from __future__ import annotations
+import typing
+
+import logging
+
+from rest_attacker.checks.types import TestCaseType
+
+if typing.TYPE_CHECKING:
+ from rest_attacker.checks.generic import TestCase
+ from rest_attacker.engine.config import EngineConfig
+
+
def generate_checks(
    config: EngineConfig,
    test_cases: dict[str, type[TestCase]],
    filters: dict | None = None
) -> list[TestCase]:
    """
    Generate checks for a set of test cases from a service configuration.

    :param config: Configuration for the service.
    :type config: EngineConfig
    :param test_cases: Test case classes that are used for the generation. Maps test case IDs to test case classes.
    :type test_cases: dict[str, type[TestCase]]
    :param filters: Only generate test cases with a specific type. The input dict must
                    have strings of the name of the class type member as keys and
                    a list of the allowed types as values. If a type is not specified
                    in the keys, all variants of this type are allowed.
    :type filters: dict
    :returns: Generated checks with consecutive check IDs starting at 0.
    """
    index = 0
    checks = []

    # Check if test case conforms to filter
    if filters:
        filtered_test_cases = {}
        if "test_cases" in filters.keys():
            # Check if IDs given here are valid test case IDs
            for allowed_type in filters["test_cases"]:
                if allowed_type not in test_cases.keys():
                    raise Exception(
                        f"Could not find test case with ID '{allowed_type}' "
                        "in available test cases.")

        for test_case_id, test_case in test_cases.items():
            # A test case is kept only if it passes every filter entry
            allowed = True
            for filter_type, allowed_types in filters.items():
                if filter_type in ("test_type", "auth_type", "live_type"):
                    case_type = getattr(test_case, filter_type)

                    if case_type not in allowed_types:
                        allowed = False
                        break

                elif filter_type == "test_cases":
                    if test_case_id not in allowed_types:
                        allowed = False
                        break

            if allowed:
                logging.debug(f"Added test case: {test_case_id}")
                filtered_test_cases.update({
                    test_case_id: test_case
                })

    else:
        filtered_test_cases = test_cases.copy()

    # Search for meta test cases that generate checks for their subcases
    # The subcases can be removed from the generation
    dedupl_test_cases = filtered_test_cases.copy()
    for test_case in filtered_test_cases.values():
        if test_case.test_type is TestCaseType.META:
            if test_case.generates_for:
                for test_case_cls in test_case.generates_for:
                    test_case_id = test_case_cls.get_test_case_id()
                    dedupl_test_cases.pop(test_case_id, None)

    for test_case in dedupl_test_cases.values():
        # Each test case continues the check ID sequence of the previous one
        new_checks = test_case.generate(config, check_id_start=index)
        checks.extend(new_checks)
        index += len(new_checks)

    return checks
diff --git a/rest_attacker/engine/internal_state.py b/rest_attacker/engine/internal_state.py
new file mode 100644
index 0000000..e6b3162
--- /dev/null
+++ b/rest_attacker/engine/internal_state.py
@@ -0,0 +1,89 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Keeps track of the internal state of the engine.
+"""
+
+import enum
+import time
+from datetime import datetime
+
+from rest_attacker.util.response_handler import AccessLimitHandler, RateLimitHandler
+
+
class EngineStatus(enum.Enum):
    """
    Lifecycle status of the engine during a test run.
    """
    QUEUED = "queued"        # run is configured and queued
    RUNNING = "running"      # run was started
    FINISHED = "finished"    # run finished successfully
    ABORTED = "aborted"      # run was aborted by engine or user
    ERROR = "error"          # run failed with error
+
+
class InternalState:
    """
    Keeps track of dynamic information about the internal state of the test run.
    """

    def __init__(self) -> None:
        """
        Create a new internal state for an engine.
        """
        # Current lifecycle status of the engine
        # (comment previously copy-pasted from start_time)
        self.status = EngineStatus.QUEUED

        # Unix time when the run was started
        self.start_time = time.time()

        # Unix time when the run ended; -1.0 until the run has finished
        self.end_time: float = -1.0

        # Current rate limit handler (set via set_limit_handler)
        self.rate_limit = None

        # Current access limit handler (set via set_limit_handler)
        # (comment previously copy-pasted from rate_limit)
        self.access_limit = None

        # Number of planned checks (when the run started)
        self.planned_check_count = 0

        # Number of already executed checks
        self.finished_check_count = 0

        # Statistics for checks
        self.analytical_check_count = 0
        self.security_check_count = 0
        self.error_check_count = 0
        self.skipped_check_count = 0
        self.aborted_check_count = 0

    def set_limit_handler(self, handler):
        """
        Set rate/access limit handlers for the test run.

        :param handler: Handler used for analyzing the internal state.
        :type handler: RateLimitHandler|AccessLimitHandler
        """
        if isinstance(handler, RateLimitHandler):
            self.rate_limit = handler

        elif isinstance(handler, AccessLimitHandler):
            self.access_limit = handler

    def dump(self) -> dict:
        """
        Generate a dictionary with information from the internal state.
        """
        # Local import keeps the module-level import block unchanged
        from datetime import timezone

        # datetime.utcfromtimestamp() is deprecated since Python 3.12;
        # formatting an aware UTC datetime yields the identical string.
        time_format = '%Y-%m-%dT%H-%M-%SZ'

        def utc_str(timestamp: float) -> str:
            # Format a UNIX timestamp as a UTC time string
            return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(time_format)

        return {
            "start": utc_str(self.start_time),
            "end": utc_str(self.end_time),
            "planned": self.planned_check_count,
            "finished": self.finished_check_count,
            "skipped": self.skipped_check_count,
            "aborted": self.aborted_check_count,
            "errors": self.error_check_count,
            "analytical_checks": self.analytical_check_count,
            "security_checks": self.security_check_count,
        }
diff --git a/rest_attacker/report/__init__.py b/rest_attacker/report/__init__.py
new file mode 100644
index 0000000..b25c55f
--- /dev/null
+++ b/rest_attacker/report/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+This module contains functions to create and parse reports from
+executed checks.
+"""
diff --git a/rest_attacker/report/report.py b/rest_attacker/report/report.py
new file mode 100644
index 0000000..9726231
--- /dev/null
+++ b/rest_attacker/report/report.py
@@ -0,0 +1,55 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Implementation of report objects.
+"""
+
+import json
+
+
class Report:
    """
    Report of an individual check.
    """

    def __init__(self, check_id: int, content: dict = None) -> None:
        """
        Create a new report.

        :param check_id: Identifier of the check that the report is generated for.
        :type check_id: int
        :param content: Content (= parameters and values) of the report.
        :type content: dict
        """
        self.report_id = check_id
        # Content is copied so later mutation of the caller's dict
        # does not change the report
        self.content = {}

        if content:
            self.content.update(content)

    def dump(self) -> dict:
        """
        Create a dict with the report contents, including the report ID.
        """
        output = {
            "report_id": self.report_id
        }
        output.update(self.content)

        return output

    def dumps(self, outformat: str = "json") -> str:
        """
        Create a string representation of the report.

        :param outformat: Data representation format of the report.
        :type outformat: str
        :raises ValueError: If the requested format is unknown.
        """
        if outformat == "json":
            # Serialize dump() (not just the content) so the report ID is
            # included, consistent with dump()
            return json.dumps(self.dump(), sort_keys=True, indent=4)

        raise ValueError(
            f"Format '{outformat}' unknown: Cannot be used to generate reports.")
diff --git a/rest_attacker/util/__init__.py b/rest_attacker/util/__init__.py
new file mode 100644
index 0000000..e04e58d
--- /dev/null
+++ b/rest_attacker/util/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+This module contains helper functions and tools that aid the analysis.
+"""
diff --git a/rest_attacker/util/auth/__init__.py b/rest_attacker/util/auth/__init__.py
new file mode 100644
index 0000000..921b9bd
--- /dev/null
+++ b/rest_attacker/util/auth/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Helper functions and classes for OAuth2 authentication.
+"""
diff --git a/rest_attacker/util/auth/auth_generator.py b/rest_attacker/util/auth/auth_generator.py
new file mode 100644
index 0000000..081a1da
--- /dev/null
+++ b/rest_attacker/util/auth/auth_generator.py
@@ -0,0 +1,158 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Generates authentication information from authentication schemes.
+"""
+
+import logging
+
+from rest_attacker.util.auth.auth_scheme import AuthScheme, AuthType
+from rest_attacker.util.auth.token_generator import AccessLevelPolicy
+
+
class AuthGenerator:
    """
    Generates authentication/authorization information for a request.
    """

    def __init__(
        self,
        schemes: dict[str, AuthScheme] = None,
        required_min: dict[str, list[str]] = None,
        required_auth: dict[str, list[str]] = None
    ) -> None:
        """
        Create a new AuthGenerator.

        :param schemes: Dict of schemes that can be used by the auth generator.
                        Defaults to no schemes.
        :type schemes: dict[AuthScheme]
        :param required_min: Dict of scheme lists that are required for non-authenticated
                             requests.
                             The first ID in each list is used as the default.
        :type required_min: dict
        :param required_auth: Dict of scheme lists that are required for authenticated requests.
                              The first ID in each list is used as the default.
        :type required_auth: dict
        """
        # None defaults avoid the shared mutable default argument pitfall
        # (the previous '{}' defaults were shared between all instances)
        self.supported_schemes = schemes if schemes is not None else {}
        self.required_min = required_min if required_min is not None else {}
        self.required_auth = required_auth if required_auth is not None else {}

    def get_auth(
        self,
        scheme_ids: list[str] = None,
        scopes: list[str] = None,
        policy: AccessLevelPolicy = AccessLevelPolicy.DEFAULT,
    ) -> list[tuple[AuthType, dict]]:
        """
        Get an authentication infos that can be inserted into a request. The location
        of each info in the request is returned as the first parameter of each tuple.

        :param scheme_ids: Optional list of IDs of the authentication schemes that should be used.
                           If omitted, the default scheme of every required_auth entry is used.
        :type scheme_ids: list[str]
        :param scopes: Authorization scopes that should be requested.
        :type scopes: list[str]
        :param policy: Access level policy passed through to the schemes.
        :type policy: AccessLevelPolicy
        """
        auth_infos = []
        if scheme_ids:
            logging.debug(f"Generating auth infos for schemes {scheme_ids}")
            for scheme_id in scheme_ids:
                auth_infos.append(
                    self.get_auth_scheme(scheme_id=scheme_id, scopes=scopes, policy=policy)
                )

            return auth_infos

        logging.debug(f"Generating auth infos from required schemes {self.required_auth}")
        for scheme_list in self.required_auth.values():
            # The first scheme ID in each list is the default
            scheme_id = scheme_list[0]
            auth_infos.append(
                self.get_auth_scheme(scheme_id=scheme_id, scopes=scopes, policy=policy)
            )

        return auth_infos

    def get_min_auth(self) -> list[tuple[AuthType, dict]]:
        """
        Get authentication infos that are required for every request. The location
        of each info in the request is returned as the first parameter of each tuple.
        """
        auth_infos = []
        logging.debug(f"Generating auth infos from required schemes {self.required_min}")
        for scheme_list in self.required_min.values():
            # The first scheme ID in each list is the default
            scheme_id = scheme_list[0]
            auth_infos.append(self.get_auth_scheme(scheme_id=scheme_id))

        return auth_infos

    def get_auth_scheme(
        self,
        scheme_id: str = None,
        auth_type: AuthType = None,
        credentials_map: dict[str, str] = None,
        scopes: list[str] = None,
        policy: AccessLevelPolicy = AccessLevelPolicy.DEFAULT,
    ) -> tuple[AuthType, dict]:
        """
        Get an authentication info for a scheme. The location of the info in the request is
        returned as the first parameter. Scheme ID, authentication type and credentials
        can be specified independently. The auth generator will try to find the best match.

        :param scheme_id: ID of the preferred authentication scheme.
        :type scheme_id: str
        :param auth_type: Preferred location of authentication.
        :type auth_type: AuthType
        :param credentials_map: Map of parameter ID to credential ID. Overrides the preference
                                in the parameter config.
        :type credentials_map: dict[str,str]
        :param scopes: Authorization scopes that should be requested.
        :type scopes: list[str]
        :param policy: Access level policy passed through to the scheme.
        :type policy: AccessLevelPolicy
        """
        if scheme_id:
            # Get the specific scheme
            scheme = self.supported_schemes[scheme_id]

            if auth_type and not scheme.auth_type is auth_type:
                # Check if the scheme has the correct auth type
                raise Exception(
                    f"scheme '{scheme.scheme_id}' does not match auth type '{auth_type}'")

            if credentials_map and not scheme.supports_credentials(credentials_map.keys()):
                # Check if the scheme supports the credentials
                raise Exception(
                    f"scheme '{scheme.scheme_id}' does not support "
                    f"credentials '{credentials_map.keys()}'")

        elif auth_type:
            # Use the first matching scheme that can be found
            for sch in self.supported_schemes.values():
                if sch.auth_type is auth_type:
                    scheme = sch
                    break

            else:
                raise Exception(
                    f"Could not find scheme with auth type '{auth_type}'")

        elif credentials_map:
            # Use the first matching scheme that can be found
            for sch in self.supported_schemes.values():
                if sch.supports_credentials(credentials_map.keys()):
                    scheme = sch
                    break

            else:
                raise Exception(
                    f"Could not find scheme that uses credentials '{credentials_map.keys()}'")

        else:
            raise Exception(
                "Generator cannot select scheme. Specify at least one of: "
                "scheme ID, auth type or credentials ID.")

        logging.debug(f"Generating auth info for scheme '{scheme.scheme_id}'")
        auth_type, auth_info = scheme.get_auth(credentials_map=credentials_map,
                                               scopes=scopes,
                                               access_policy=policy)

        return auth_type, auth_info
diff --git a/rest_attacker/util/auth/auth_scheme.py b/rest_attacker/util/auth/auth_scheme.py
new file mode 100644
index 0000000..8caf9b9
--- /dev/null
+++ b/rest_attacker/util/auth/auth_scheme.py
@@ -0,0 +1,285 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Manages authentication schemes.
+"""
+
+from abc import ABC, abstractmethod
+import base64
+import enum
+import re
+from typing import Collection, Mapping
+
+
+from rest_attacker.util.auth.token_generator import AccessLevelPolicy, OAuth2TokenGenerator
+
+
class AuthType(enum.Enum):
    """
    Authentication Types.

    Identifies where authentication data is placed in a request.
    """
    QUERY = "query"    # query parameter of the URL
    HEADER = "header"  # custom request header
    BASIC = "basic"    # HTTP Basic Authentication
    COOKIE = "cookie"  # cookie value
+
+
class AuthScheme(ABC):
    """
    Stores patterns to generate authentication/authorization information.
    """

    def __init__(
        self,
        scheme_id: str,
        auth_type: AuthType,
        credentials: dict | None = None,
        default_creds: Mapping[str, str] = None
    ) -> None:
        """
        Create a new AuthScheme.

        :param scheme_id: ID of the scheme.
        :type scheme_id: str
        :param auth_type: Location of the scheme in the request.
        :type auth_type: AuthType
        :param credentials: Credentials used by the scheme.
        :type credentials: dict[str,dict]
        :param default_creds: Map of parameter ID to credential ID. Overrides the preference in the
                              parameter config.
        :type default_creds: dict[str,str]
        """
        self.scheme_id = scheme_id
        self.auth_type = auth_type
        # Fresh dict per instance: a mutable default argument ({}) would be
        # shared by every scheme created without explicit credentials.
        self.credentials = credentials if credentials is not None else {}
        self.default_creds = None

        if default_creds:
            # Resolve credential IDs to the stored credential objects up front.
            self.default_creds = {
                param_id: self.credentials[cred_id]
                for param_id, cred_id in default_creds.items()
            }

    def supports_credential_id(self, credentials_id: str) -> bool:
        """
        Checks if the auth scheme supports the credentials with the specified ID.

        :param credentials_id: ID of credentials.
        :type credentials_id: str
        """
        return credentials_id in self.credentials

    def supports_credentials(self, credentials: Collection) -> bool:
        """
        Checks if the auth scheme supports all credentials with the IDs in the collection.

        :param credentials: Collection of credential IDs.
        :type credentials: Collection
        """
        return all(self.supports_credential_id(cred_id) for cred_id in credentials)

    @abstractmethod
    def get_auth(
        self,
        credentials_map: Mapping[str, str] = None,
        scopes: list[str] = None,
        access_policy: AccessLevelPolicy = AccessLevelPolicy.DEFAULT
    ) -> tuple[AuthType, dict]:
        """
        Create the authentication info for the scheme.

        :param credentials_map: Map of parameter ID to preferred credentials ID.
        :type credentials_map: dict[str,str]
        :param scopes: Authorization scopes that should be requested. These are ignored
                       for credentials that do not support scoped access control.
        :type scopes: list[str]
        """
+
+
class KeyValueAuthScheme(AuthScheme):
    """
    Stores patterns to generate key-value based authentication info.
    """

    def __init__(
        self,
        scheme_id: str,
        auth_type: AuthType,
        key_id: str,
        payload_pattern: str,
        params_cfg: dict,
        credentials: dict | None = None,
        default_creds: Mapping[str, str] = None
    ) -> None:
        """
        Create a new KeyValueAuthScheme.

        :param key_id: ID of the key of the key-value pair.
        :type key_id: str
        :param payload_pattern: Pattern for building the payload. Parameters are
                                referenced as '{<param-index>}' placeholders.
        :type payload_pattern: str
        :param params_cfg: Config for the parameters used in the payload.
        :type params_cfg: dict
        """
        # Normalize the mutable default here so this class does not rely on the
        # base class accepting None.
        super().__init__(scheme_id, auth_type,
                         credentials=credentials if credentials is not None else {},
                         default_creds=default_creds)

        self.key_id = key_id
        self.payload_pattern = payload_pattern
        self.params_cfg = params_cfg

    def get_auth(
        self,
        credentials_map=None,
        scopes: list[str] = None,
        access_policy: AccessLevelPolicy = AccessLevelPolicy.DEFAULT
    ) -> tuple[AuthType, dict]:
        """
        Create the key-value authentication info for the scheme.

        :param credentials_map: Map of parameter ID to preferred credentials ID.
        :type credentials_map: dict[str,str]
        :param scopes: Authorization scopes that should be requested.
        :type scopes: list[str]
        """
        params = re.findall(r"\{[0-9]+\}", self.payload_pattern)

        creds = None
        if credentials_map:
            # Keep the credential IDs; they are resolved (and validated) below.
            # Resolving to objects here previously made the later
            # self.credentials[...] lookup fail for dict credentials.
            creds = dict(credentials_map)

        elif self.default_creds:
            # NOTE(review): AuthScheme.__init__ stores already-resolved credential
            # objects in default_creds, while the lookup below expects credential
            # IDs — confirm the intended value type of default_creds.
            creds = self.default_creds

        payload_str = self.payload_pattern
        for p_param in params:
            # '{0}' -> params_cfg['0']
            param_cfg = self.params_cfg[p_param[1:-1]]
            cred_value_id = param_cfg["id"]

            if not creds:
                # Use the first entry in the list as default
                cred_src_id = param_cfg["from"][0]

            else:
                # Use the preconfigured preference for the scheme
                cred_src_id = creds[cred_value_id]

            try:
                cred = self.credentials[cred_src_id]

            except KeyError as err:
                raise Exception(f"Could not generate auth info for scheme '{self.scheme_id}': "
                                f"Could not find credentials '{cred_src_id}' in dict "
                                f"of credentials for the scheme") from err

            try:
                if isinstance(cred, OAuth2TokenGenerator):
                    # Token generators create/refresh a token on demand.
                    param_val = cred.get_token(scopes, policy=access_policy)[cred_value_id]

                elif isinstance(cred, dict):
                    param_val = cred[cred_value_id]

                else:
                    raise Exception(
                        f"Unknown credentials format '{type(cred)}'. "
                        "Expected dict or OAuth2TokenGenerator")

            except KeyError as err:
                raise Exception(f"Could not generate auth info for scheme '{self.scheme_id}': "
                                f"Could not find parameter '{cred_value_id}' in credentials "
                                f"{cred_src_id}") from err

            # str.replace instead of re.sub: credential values may contain
            # characters (e.g. backslashes) that re.sub would interpret as
            # escape sequences in the replacement string.
            payload_str = payload_str.replace(p_param, param_val)

        value = {
            self.key_id: payload_str
        }

        return self.auth_type, value
+
+
class BasicAuthScheme(AuthScheme):
    """
    Stores patterns to generate a HTTP Basic Authentication header.
    """

    def __init__(
        self,
        scheme_id: str,
        payload_pattern: str,
        params_cfg: dict,
        credentials: dict | None = None,
        default_creds: Mapping[str, str] = None
    ) -> None:
        """
        Create a new BasicAuthScheme.

        :param payload_pattern: Pattern for building the payload after the 'Basic' keyword.
                                Parameters are referenced as '{<param-index>}' placeholders.
        :type payload_pattern: str
        :param params_cfg: Config for the parameters used in the payload.
        :type params_cfg: dict
        """
        # Normalize the mutable default here so this class does not rely on the
        # base class accepting None.
        super().__init__(scheme_id,
                         AuthType.BASIC,
                         credentials=credentials if credentials is not None else {},
                         default_creds=default_creds)

        self.key_id = 'authorization'
        self.payload_pattern = payload_pattern
        self.params_cfg = params_cfg

    def get_auth(
        self,
        credentials_map=None,
        scopes: list[str] = None,
        access_policy: AccessLevelPolicy = AccessLevelPolicy.DEFAULT
    ) -> tuple[AuthType, dict]:
        """
        Create the HTTP Basic Authentication header for the scheme.

        :param credentials_map: Map of parameter ID to preferred credentials ID.
        :type credentials_map: dict[str,str]
        :param scopes: Authorization scopes that should be requested.
        :type scopes: list[str]
        """
        params = re.findall(r"\{[0-9]+\}", self.payload_pattern)

        creds = None
        if credentials_map:
            # Keep the credential IDs; they are resolved (and validated) below.
            creds = dict(credentials_map)

        elif self.default_creds:
            # NOTE(review): AuthScheme.__init__ stores already-resolved credential
            # objects in default_creds, while the lookup below expects credential
            # IDs — confirm the intended value type of default_creds.
            creds = self.default_creds

        payload_str = self.payload_pattern
        for p_param in params:
            # '{0}' -> params_cfg['0']
            param_cfg = self.params_cfg[p_param[1:-1]]
            cred_value_id = param_cfg["id"]

            if not creds:
                # Use the first entry in the list as default
                cred_src_id = param_cfg["from"][0]

            else:
                # Use the preconfigured preference for the scheme
                cred_src_id = creds[cred_value_id]

            try:
                cred = self.credentials[cred_src_id]

            except KeyError as err:
                raise Exception(f"Could not generate auth info for scheme '{self.scheme_id}': "
                                f"Could not find credentials '{cred_src_id}' in dict "
                                f"of credentials for the scheme") from err

            try:
                if isinstance(cred, OAuth2TokenGenerator):
                    # Token generators create/refresh a token on demand.
                    param_val = cred.get_token(scopes, policy=access_policy)[cred_value_id]

                elif isinstance(cred, dict):
                    param_val = cred[cred_value_id]

                else:
                    # Consistent with KeyValueAuthScheme: fail loudly on
                    # unknown credential formats instead of indexing blindly.
                    raise Exception(
                        f"Unknown credentials format '{type(cred)}'. "
                        "Expected dict or OAuth2TokenGenerator")

            except KeyError as err:
                raise Exception(f"Could not generate auth info for scheme '{self.scheme_id}': "
                                f"Could not find parameter '{cred_value_id}' in credentials "
                                f"{cred_src_id}") from err

            # str.replace instead of re.sub: credential values may contain
            # characters (e.g. backslashes) that re.sub would interpret as
            # escape sequences in the replacement string.
            payload_str = payload_str.replace(p_param, param_val)

        # Convert to bytes for Base64 encoding
        payload_bytes = payload_str.encode('ascii')
        payload_str = f"Basic {base64.b64encode(payload_bytes).decode('ascii')}"

        header = {
            self.key_id: payload_str
        }

        # Basic auth travels in the 'Authorization' request header, so report
        # the HEADER location to callers rather than AuthType.BASIC.
        return AuthType.HEADER, header
diff --git a/rest_attacker/util/auth/server.py b/rest_attacker/util/auth/server.py
new file mode 100644
index 0000000..9651363
--- /dev/null
+++ b/rest_attacker/util/auth/server.py
@@ -0,0 +1,79 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Listen to OAuth2 redirects to localhost.
+"""
+
+from http.server import BaseHTTPRequestHandler
+from os.path import dirname
+from urllib.parse import unquote
+
# Content of the debug page shown in the browser.
# Read via a context manager so the file handle is closed immediately
# instead of lingering until garbage collection.
with open(f"{dirname(__file__)}/server_payload.html", encoding='utf-8') as _payload_file:
    PAYLOAD = _payload_file.read()
+
+
class RedirectHandler(BaseHTTPRequestHandler):
    """
    Listens to incoming HTTP requests.

    The redirect result is stored on the class itself because HTTPServer
    creates a fresh handler instance per request; the code that started the
    server reads (and resets) these class attributes afterwards.
    """
    # True after a redirect request has been received
    called = False
    # Reconstructed URL of the last received redirect
    call_url = None

    def do_GET(self) -> None:
        """
        Handle a redirect request and record the redirect URL.
        """
        query_params = self._get_query_params()
        if "fragment" in query_params:
            # Second roundtrip of the implicit grant: the served payload page
            # forwards the URL fragment as a 'fragment' query parameter.
            return self._implicit_handler()

        # Always return HTTPS URL because oauthlib does not like plain HTTP
        call_url = (
            "https://"
            f"{self.server.server_address[0]}"
            f":{self.server.server_address[1]}"
            f"{self.path}"
        )

        RedirectHandler.called = True
        RedirectHandler.call_url = call_url

        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(PAYLOAD.encode('utf-8'))

    def _get_query_params(self) -> dict[str, str]:
        """
        Get the query parameters from the URL path.

        Values are returned as-is (still percent-encoded); parameters without
        a value are mapped to an empty string.
        """
        path_parts = self.path.split("?", 1)
        if len(path_parts) < 2:
            # No query string in path
            return {}

        query_params = {}
        for query_param_string in path_parts[1].split("&"):
            # Split on the first '=' only: values may themselves contain '='
            # (the previous split("=") truncated such values and crashed on
            # valueless parameters).
            key, _, value = query_param_string.partition("=")
            query_params[key] = value

        return query_params

    def _implicit_handler(self) -> None:
        """
        Handle a request that contains the fragment value of the implicit grant.
        """
        query_params = self._get_query_params()
        fragment = query_params.pop("fragment")

        call_url = (
            "https://"
            f"{self.server.server_address[0]}"
            f":{self.server.server_address[1]}"
            f"{self.path.split('?')[0]}"
            f"{unquote(fragment)}"
        )

        RedirectHandler.called = True
        RedirectHandler.call_url = call_url

        self.send_response(200)
        self.end_headers()
diff --git a/rest_attacker/util/auth/server_payload.html b/rest_attacker/util/auth/server_payload.html
new file mode 100644
index 0000000..14fa67f
--- /dev/null
+++ b/rest_attacker/util/auth/server_payload.html
@@ -0,0 +1,89 @@
+
+
+
+
+
+ RATT AuthResponse
+
+
+
+
Test
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/rest_attacker/util/auth/session.py b/rest_attacker/util/auth/session.py
new file mode 100644
index 0000000..1fa9f41
--- /dev/null
+++ b/rest_attacker/util/auth/session.py
@@ -0,0 +1,167 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Creates and manages an OAuth Resource Owner session.
+"""
+
+import webbrowser
+from abc import ABC, abstractmethod
+import time
+import requests
+
+
class ROSession(ABC):
    """
    A user session of a resource owner.
    """

    def __init__(self, session_id: str, test_url: str = None) -> None:
        """
        Create a new Resource Owner Session.

        :param session_id: Identifier of the session.
        :type session_id: str
        :param test_url: URL for testing if the session is valid.
        :type test_url: str
        """
        # Identifier under which the session is referenced.
        self.session_id = session_id
        # Optional endpoint that subclasses probe to check session validity.
        self.test_url = test_url

    @abstractmethod
    def setup(self) -> None:
        """
        Create a new session.
        """
+
+
class ROCookieSession(ROSession):
    """
    User session using session cookies from an existing browser session.
    """

    def __init__(
        self,
        session_id: str,
        cookies: dict,
        test_url: str = None,
        expires: int = None
    ) -> None:
        """
        Create a new ROCookieSession.

        :param cookies: Cookies from a browser session.
        :type cookies: dict
        :param expires: Expiration time of the session in UNIX time.
                        If omitted, the session is treated as non-expiring.
        :type expires: int
        """
        super().__init__(session_id, test_url=test_url)

        self.cookies = cookies
        self.expires = expires

        # Build the requests session immediately so it is usable right away.
        self.setup()

    def setup(self) -> None:
        """
        Create a requests session preloaded with the configured cookies.
        """
        cookies = requests.cookies.cookiejar_from_dict(self.cookies)
        self.session = requests.Session()
        self.session.cookies = cookies

    def is_expired(self) -> bool:
        """
        Checks if the current session is expired.

        Sessions without a configured expiration time are treated as not expired.
        """
        if self.expires is None:
            # Fix: previously this raised a TypeError when 'expires' was left
            # at its default (None).
            return False

        return self.expires - time.time() < 0

    def is_valid(self) -> bool:
        """
        Checks if the session is still accepted by the service by requesting
        the configured test URL.
        """
        response = self.session.get(self.test_url)

        return 200 <= response.status_code < 300
+
+
class ROWebSession(ROSession):
    """
    User session using a web login.
    """

    def __init__(
        self,
        session_id: str,
        login_url: str,
        login_data: dict,
        test_url: str = None
    ) -> None:
        """
        Create a new ROWebSession. Note that this type of session likely does not work
        for services that use 2FA or CSRF tokens.

        :param login_url: Web login endpoint of the service. Must be accessible via POST method.
        :type login_url: str
        :param login_data: Body parameters with the login data.
        :type login_data: dict
        """
        super().__init__(session_id, test_url=test_url)

        self.login_url = login_url
        self.login_data = login_data

        # Log in immediately so the session is usable right away.
        self.setup()

    def setup(self) -> None:
        """
        Log in at the configured endpoint and keep the authenticated session.
        """
        self.session = requests.Session()
        login_response = self.session.post(self.login_url, data=self.login_data)

        # 4xx means the login itself was rejected.
        if 400 <= login_response.status_code < 500:
            raise Exception(f"Failed to create session '{self.session_id}'.")

    def is_valid(self) -> bool:
        """
        Checks if the session is still accepted by the service.
        """
        probe = self.session.get(self.test_url)

        return 200 <= probe.status_code < 300
+
+
class ROBrowserSession(ROSession):
    """
    User session using a browser session.
    """

    def __init__(
        self,
        session_id: str,
        executable: str,
        port: int,
        test_url: str = None
    ) -> None:
        """
        Create a new ROBrowserSession.

        :param executable: Browser executable path.
        :type executable: str
        :param port: Port number used for the local server.
        :type port: int
        """
        super().__init__(session_id, test_url)

        self.executable = executable
        self.port = port

        # Register the browser right away so authorize() can use it.
        self.setup()

    def setup(self) -> None:
        """
        Register the configured browser executable with the webbrowser module.
        """
        webbrowser.register(
            "ratt-browser",
            None,
            webbrowser.BackgroundBrowser(self.executable),
            preferred=True
        )

    def authorize(self, url: str) -> None:
        """
        Send the authorization request with a browser using the invocation.

        :param url: URL of the authorization request.
        :type url: str
        """
        # webbrowser.get() returns the preferred (just-registered) browser.
        webbrowser.get().open_new_tab(url)
diff --git a/rest_attacker/util/auth/token_generator.py b/rest_attacker/util/auth/token_generator.py
new file mode 100644
index 0000000..19fdbb4
--- /dev/null
+++ b/rest_attacker/util/auth/token_generator.py
@@ -0,0 +1,447 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Generates and keeps track of tokens for an OAuth provider.
+"""
+
+
+from http.server import HTTPServer
+import logging
+import os
+import time
+import enum
+
+from oauthlib.oauth2.rfc6749.clients.mobile_application import MobileApplicationClient
+from oauthlib.oauth2.rfc6749.tokens import OAuth2Token
+from requests_oauthlib.oauth2_session import OAuth2Session
+
+from rest_attacker.util.auth.server import RedirectHandler
+from rest_attacker.util.auth.session import ROBrowserSession, ROCookieSession, ROWebSession
+from rest_attacker.util.auth.userinfo import UserInfo
+
+
@enum.unique
class AccessLevelPolicy(enum.Enum):
    """
    Policy for retrieving an access level if no access level is specified.
    """
    # Scope parameter is omitted (i.e. let authorization server decide)
    NOPE = "nope"
    # If the service requires a scope parameter: MAX; otherwise NOPE
    DEFAULT = "default"
    # Get all available scopes (e.g. for guaranteed access)
    MAX = "max"
+
+
@enum.unique
class ClientInfoFlag(enum.Enum):
    """
    Flags for client settings.
    """
    # Authorization requests must contain the 'scope' parameter
    SCOPE_REQUIRED = "scope_required"
+
+
class ClientInfo:
    """
    Stores OAuth2 client information.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        redirect_urls: list[str],
        auth_url: str,
        token_url: str,
        grants: list[str],
        revoke_url: str = None,
        scopes: list[str] | None = None,
        flags: list[ClientInfoFlag] | None = None,
        description: str = None
    ) -> None:
        """
        Creates a new ClientInfo object.

        :param client_id: Client ID.
        :param client_secret: Client secret.
        :param redirect_urls: Redirect URIs registered by the client at the service.
        :param auth_url: URL of the authorization endpoint.
        :param token_url: URL of the token endpoint.
        :param revoke_url: URL for manual revocation of tokens. Implementation is service-specific
                           and not standardized in OAuth2.
        :param grants: List of supported grants.
        :param scopes: List of supported scopes (optional).
        :param flags: List of flags for token generator settings (optional).
        :param description: Description or name of the client (optional).
        :type client_id: str
        :type client_secret: str
        :type redirect_urls: list[str]
        :type auth_url: str
        :type token_url: str
        :type revoke_url: str
        :type grants: list[str]
        :type scopes: list[str]
        :type flags: list[ClientInfoFlag]
        :type description: str
        """
        self.description = description
        self.client_id = client_id
        self.client_secret = client_secret

        self.redirect_urls = redirect_urls
        self.auth_url = auth_url
        self.token_url = token_url
        self.revoke_url = revoke_url

        self.supported_grants = grants
        # Fresh lists per instance: the previous mutable default arguments
        # ([]) were shared by all instances created without these parameters.
        self.supported_scopes = scopes if scopes is not None else []
        self.flags = flags if flags is not None else []
+
+
class OAuth2TokenGenerator:
    """
    Generates and keeps track of tokens.
    """

    def __init__(self, client_info: ClientInfo, user_agent: UserInfo = None) -> None:
        """
        Create a new OAuth2TokenGenerator.

        :param client_info: Object that contains information about the service
                            and authentication data for a client.
        :type client_info: ClientInfo
        :param user_agent: User info which can contain a resource owner session which
                           is used to authorize requests.
        :type user_agent: UserInfo
        """

        self.client_info = client_info
        self.user_agent = user_agent

        self.active_token: OAuth2Token | None = None
        # Requested scopes of active token (not necessarily the scopes of the active token)
        self._req_scopes: list[str] | None = None

        # Log of all previously requested tokens.
        # New tokens are inserted at the front; the currently active token is
        # only moved here once it gets replaced.
        self.token_history: list[OAuth2Token] = []

        self._setup_oauthlib()

    def get_token(
        self,
        scopes: list[str] = None,
        policy: AccessLevelPolicy = AccessLevelPolicy.DEFAULT
    ) -> OAuth2Token:
        """
        Get a valid token for authorization.

        By default, the last generated token is returned. If a token is expired, it will
        be refreshed for a new token if a refresh token is available. If scopes were
        requested and the scopes do not exactly match the scopes of the active token
        (i.e. scope value AND order of scopes), a new token is requested.

        :param scopes: Scopes that must be assigned to this token.
        :type scopes: list[str]
        :param policy: Policy applied when no scopes are specified.
        :type policy: AccessLevelPolicy
        """
        if not self.active_token:
            return self.request_new_token(scopes=scopes, policy=policy)

        if self._is_expired(self.active_token) and self._is_refreshable(self.active_token):
            # TODO: Use requests_oauthlib's token refresher for automatic updates
            return self.request_new_token(scopes=scopes, grant_type='refresh_token', policy=policy)

        if scopes and self._req_scopes != scopes:
            # Compared against the *requested* scopes because the service may
            # grant different scopes than were asked for.
            return self.request_new_token(scopes=scopes, policy=policy)

        return self.active_token

    def get_access_token(self, scopes: list[str] = None) -> str:
        """
        Gets only the access token part required for authentication from a valid token.

        :param scopes: Scopes that must be assigned to this token.
        :type scopes: list[str]
        """
        return self.get_token(scopes=scopes)['access_token']

    def request_new_token(
        self,
        scopes: list[str] = None,
        grant_type: str = 'code',
        policy: AccessLevelPolicy = AccessLevelPolicy.DEFAULT
    ) -> OAuth2Token:
        """
        Fetch a new token from the service.

        :param scopes: Scopes that must be assigned to this token.
        :type scopes: list[str]
        :param grant_type: Grant type used for the OAuth2 exchange
                           ('code', 'token' or 'refresh_token').
        :type grant_type: str
        :param policy: Policy applied when no scopes are specified.
        :type policy: AccessLevelPolicy
        """
        if scopes is None:
            # Flags may hold ClientInfoFlag members or the raw config strings;
            # accept both (the plain string test alone never matched enum members).
            scope_required = (ClientInfoFlag.SCOPE_REQUIRED in self.client_info.flags
                              or "scope_required" in self.client_info.flags)

            if policy is AccessLevelPolicy.MAX:
                scopes = self.client_info.supported_scopes

            elif scope_required and policy is AccessLevelPolicy.DEFAULT:
                # If no scopes are requested, but the service requires the scope
                # parameter, request all supported scopes from the authorization server
                scopes = self.client_info.supported_scopes

            elif policy is AccessLevelPolicy.NOPE:
                # Explicitly omit the scope parameter
                scopes = None

        logging.debug(f"Requesting new token from client '{self.client_info.client_id}'")
        if self.active_token:
            self.token_history.insert(0, self.active_token)

        if grant_type not in self.client_info.supported_grants:
            raise Exception(f"grant type '{grant_type}' is not available for service")

        if grant_type == 'code':
            self.active_token = self._request_auth_grant(scopes=scopes)
            self._req_scopes = scopes

        elif grant_type == 'token':
            self.active_token = self._request_impl_grant(scopes=scopes)
            self._req_scopes = scopes

        elif grant_type == 'refresh_token':
            if not self.active_token:
                raise Exception("Action 'refresh_token' not available: No currently active token")

            self.active_token = self.refresh_token(token=self.active_token, scopes=scopes)
            self._req_scopes = scopes

        else:
            raise Exception(f"unrecognized grant type: {grant_type}")

        return self.active_token

    def _request_auth_grant(self, scopes: list[str] = None) -> OAuth2Token:
        """
        Fetch a new token from the service using the Authorization Code Grant.

        :param scopes: Scopes that must be assigned to this token.
        :type scopes: list[str]
        """
        logging.debug("Using Authorization Grant to request token")
        session = OAuth2Session(
            client_id=self.client_info.client_id,
            redirect_uri=self._get_redirect_url(),
            scope=scopes
        )
        authorization_url, _state = session.authorization_url(self.client_info.auth_url)

        redirect_result = self._get_redirect_result(authorization_url)

        # Exchange the authorization code from the redirect for a token.
        token = session.fetch_token(self.client_info.token_url,
                                    client_secret=self.client_info.client_secret,
                                    authorization_response=redirect_result)

        return token

    def _request_impl_grant(self, scopes: list[str] = None) -> OAuth2Token:
        """
        Fetch a new token from the service using the Implicit Grant.

        :param scopes: Scopes that must be assigned to this token.
        :type scopes: list[str]
        """
        logging.debug("Using Implicit Grant to request token")
        session = OAuth2Session(
            client=MobileApplicationClient(client_id=self.client_info.client_id),
            redirect_uri=self._get_redirect_url(),
            scope=scopes
        )
        authorization_url, _state = session.authorization_url(
            self.client_info.auth_url)

        # The implicit grant delivers the token in the URL fragment.
        redirect_result = self._get_redirect_result(authorization_url, fragment_expected=True)

        token = session.token_from_fragment(redirect_result)

        return token

    def switch_user(self, user_agent: UserInfo) -> None:
        """
        Switch to a different user for authorizing requests. This will also clear
        the currently active token.

        :param user_agent: UserInfo object definition for the new user.
        :type user_agent: UserInfo
        """
        logging.debug(f"Switching to user '{user_agent.internal_id}'")
        if self.active_token:
            self.token_history.insert(0, self.active_token)
            self.active_token = None

        self.user_agent = user_agent

    def refresh_token(self, token: OAuth2Token, scopes: list[str] = None) -> OAuth2Token:
        """
        Refresh a given token.

        :param token: Token that should be refreshed.
        :type token: OAuth2Token
        :param scopes: Scopes that must be assigned to this token.
        :type scopes: list[str]
        """
        logging.debug("Using Refresh Token Grant to request token")
        session = OAuth2Session(
            client_id=self.client_info.client_id,
            scope=scopes,
            token=token
        )

        new_token = session.refresh_token(
            self.client_info.token_url,
            client_id=self.client_info.client_id,
            client_secret=self.client_info.client_secret
        )

        return new_token

    def _get_redirect_result(self, auth_url: str, fragment_expected: bool = False) -> str | None:
        """
        Authorize the resource owner by using the supplied user agent (if it exists) and
        return the resulting redirect URL.

        Returns None if a browser session was used but no redirect arrived
        before the timeout.

        :param auth_url: Authorization URL for the authorization request.
        :type auth_url: str
        :param fragment_expected: True if the redirect URL should contain a fragment, else False.
        :type fragment_expected: bool
        """
        user_session = self._get_user_session()
        if user_session:
            logging.debug(f"Selected user agent via established session: {user_session}")
            if isinstance(user_session, ROBrowserSession):
                server = HTTPServer(
                    ('127.0.0.1', user_session.port),
                    RedirectHandler,
                    bind_and_activate=True
                )

                # Wait max 30 seconds for answer
                server.timeout = 30
                user_session.authorize(auth_url)
                server.handle_request()
                # time.sleep(5) # Helps when too many auth requests are sent (?)
                if RedirectHandler.called and fragment_expected:
                    RedirectHandler.called = False
                    RedirectHandler.call_url = None

                    # Handle a second request that contains the fragment value
                    server.handle_request()

                if RedirectHandler.called:
                    redirect_result = RedirectHandler.call_url

                    # Reset handler class
                    RedirectHandler.called = False
                    RedirectHandler.call_url = None

                else:
                    logging.warning(f"No redirect received after {server.timeout} seconds. "
                                    "Token could not be generated")
                    return None  # TODO: Custom error type

            else:
                # Cookie/web sessions are already authenticated; the service
                # answers the authorization request with a redirect directly.
                redirect_response = user_session.session.get(auth_url, allow_redirects=False)
                redirect_result = redirect_response.headers["location"]

        else:
            logging.debug("No established session found. Creating manual request.")
            cli_manual_request = ("Authorization required for OAuth2 "
                                  f"Authorization Grant:\n{auth_url}")
            print(cli_manual_request)
            logging.debug(cli_manual_request)

            redirect_result = input('Enter the returned URI: ')

        logging.debug(f"Got Redirect URI: {redirect_result}")
        return redirect_result

    def _get_user_session(self):
        """
        Get the user session that is used for authorizing the request.

        Cookie and browser sessions are preferred; a web session is only used
        if no other session type is available.
        """
        user_session = None
        if self.user_agent and self.user_agent.sessions:
            web_session = None
            # Select established session
            for _, session in self.user_agent.sessions.items():
                # Prefer the cookie session
                if isinstance(session, ROCookieSession):
                    user_session = session
                    break

                elif isinstance(session, ROBrowserSession):
                    user_session = session
                    break

                elif isinstance(session, ROWebSession):
                    web_session = session

            else:
                # Loop finished without break:
                # use the web session if no alternative is available
                user_session = web_session

        return user_session

    def _get_redirect_url(self) -> str:
        """
        Get the redirect URL that is used for redirecting the user agent.
        """
        if isinstance(self._get_user_session(), ROBrowserSession):
            # Browser sessions need a HTTP redirect URL to avoid TLS handshake
            # problem in HTTPServer
            for url in self.client_info.redirect_urls:
                if url.startswith("http:"):
                    return url

        else:
            # Use HTTPS redirect URL for everything else
            for url in self.client_info.redirect_urls:
                if url.startswith("https:"):
                    return url

        # Default URL (= first URL) as fallback
        return self.client_info.redirect_urls[0]

    def _is_expired(self, token: OAuth2Token) -> bool:
        """
        Check if a given token is expired.

        Tokens without an 'expires_at' entry are treated as non-expiring.

        :param token: Token that is checked.
        :type token: OAuth2Token
        """
        if "expires_at" in token:
            return time.time() - token["expires_at"] > 0

        return False

    def _is_refreshable(self, token: OAuth2Token) -> bool:
        """
        Check if a given token has a refresh token.

        :param token: Token that is checked.
        :type token: OAuth2Token
        """
        return "refresh_token" in token

    def __getitem__(self, key):
        """
        Access the active token's parameters like a dict.
        """
        return self.get_token()[key]

    @staticmethod
    def _setup_oauthlib() -> None:
        """
        Setup environment variables for oauthlib.
        """
        # Deactivate warning when using HTTP redirects
        os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = '1'

        # Deactivate warning when scopes change on refresh
        os.environ['OAUTHLIB_RELAX_TOKEN_SCOPE'] = '1'
diff --git a/rest_attacker/util/auth/userinfo.py b/rest_attacker/util/auth/userinfo.py
new file mode 100644
index 0000000..66e012c
--- /dev/null
+++ b/rest_attacker/util/auth/userinfo.py
@@ -0,0 +1,62 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Handles information about a user/resource owner.
+"""
+
+from __future__ import annotations
+import typing
+
+
+if typing.TYPE_CHECKING:
+ from rest_attacker.util.auth.session import ROSession
+
+
class UserInfo:
    """
    Stores user information.
    """

    def __init__(
        self,
        internal_id: str,
        account_id: str,
        user_id: str,
        userinfo_endpoint: str = None,
        owned_resources: dict[str, list[str]] | None = None,
        allowed_resources: dict[str, list[str]] | None = None,
        sessions: dict[str, ROSession] = None,
        credentials: list[str] | None = None,
    ) -> None:
        """
        Creates a new UserInfo object.

        :param internal_id: Internal ID for the user in the tool.
        :param account_id: ID of the user account (i.e. login name).
        :param user_id: ID of the user in the service.
        :param userinfo_endpoint: API endpoint where user information can be fetched (optional).
        :param owned_resources: Dict of resource IDs mapped to usable object IDs that are
                                owned by the user (optional).
        :param allowed_resources: Dict of resource IDs mapped to usable object IDs that the
                                  user has access to, but does not own (optional).
        :param sessions: Sessions that can be used as a user-agent to fetch authorization
                         tokens (optional).
        :param credentials: Credentials that can be used for this user (optional).
        :type internal_id: str
        :type account_id: str
        :type user_id: str
        :type userinfo_endpoint: str
        :type owned_resources: dict[str, list[str]]
        :type allowed_resources: dict[str, list[str]]
        :type sessions: dict[str, ROSession]
        :type credentials: list[str]
        """
        self.internal_id = internal_id

        self.account_id = account_id
        self.user_id = user_id

        self.userinfo_endpoint = userinfo_endpoint

        # Fresh containers per instance: the previous mutable default
        # arguments ({}/[]) were shared by every UserInfo created without them.
        self.owned_resources = owned_resources if owned_resources is not None else {}
        self.allowed_resources = allowed_resources if allowed_resources is not None else {}

        self.sessions = sessions
        self.credentials = credentials if credentials is not None else []
diff --git a/rest_attacker/util/enum_test_cases.py b/rest_attacker/util/enum_test_cases.py
new file mode 100644
index 0000000..baf1cad
--- /dev/null
+++ b/rest_attacker/util/enum_test_cases.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Enumerate the test cases provided by the tool.
+"""
+
+import argparse
+import pkgutil
+import inspect
+import sys
+from rest_attacker.checks.generic import TestCase
+import rest_attacker.checks as checks
+
+
def get_test_cases():
    """
    Enumerate all test cases in the 'checks' submodule.

    :return: Map of test case ID to test case class.
    :rtype: dict
    """
    test_cases = {}

    name_prefix = checks.__name__ + "."
    for _, modname, ispkg in pkgutil.iter_modules(checks.__path__):
        if ispkg:
            continue

        abs_modname = f"{name_prefix}{modname}"

        if abs_modname == "rest_attacker.checks.generic":
            # Ignore the module that only contains the TestCase ABC
            continue

        # Needs non-empty 'fromlist' to import the actual module
        module = __import__(abs_modname, fromlist=" ")
        for _cls_name, mod_cls in inspect.getmembers(module, inspect.isclass):
            if mod_cls is TestCase:
                # Skip the ABC itself; class identity check ('is'), not '=='
                continue

            if issubclass(mod_cls, TestCase):
                test_cases[mod_cls.get_test_case_id()] = mod_cls

    return test_cases
+
+
class GetTestCases(argparse.Action):
    """
    Enumerate all test cases in the 'checks' submodule.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Print all available test case IDs and exit the program.
        """
        print("List of available test cases:")

        # Iterating the dict yields its keys (the test case IDs).
        for test_case in get_test_cases():
            print(f"  {test_case}")

        sys.exit(0)
diff --git a/rest_attacker/util/errors.py b/rest_attacker/util/errors.py
new file mode 100644
index 0000000..16a67c8
--- /dev/null
+++ b/rest_attacker/util/errors.py
@@ -0,0 +1,21 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Errors and exceptions raised by REST-Attacker.
+"""
+
+
class RestrictedOperationError(Exception):
    """
    Raised when the tool tries to execute an API endpoint operation that
    is not allowlisted (e.g. DELETE when using safemode).
    """
+
+
class RateLimitException(Exception):
    """
    Raised when the tool detects that it has reached some kind of rate
    limit while communicating with an API.
    """
diff --git a/rest_attacker/util/input_gen.py b/rest_attacker/util/input_gen.py
new file mode 100644
index 0000000..109dabd
--- /dev/null
+++ b/rest_attacker/util/input_gen.py
@@ -0,0 +1,167 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Generate inputs for parameters in a request.
+"""
+
+from __future__ import annotations
+import typing
+
+import jsonschema
+import re
+import logging
+
+from jsf import JSF
+
+if typing.TYPE_CHECKING:
+ from rest_attacker.util.auth.userinfo import UserInfo
+
+
def fake_param(param_schema: dict):
    """
    Create a fake value for a parameter from a JSON schema definition.

    :param param_schema: JSON schema definition for the parameter.
    :type param_schema: dict
    """
    return JSF(param_schema).generate()
+
+
def fake_path_params(path: str, param_schemas: dict[str, dict]) -> str:
    """
    Create fake parameter values for parameters in a given path. Returns the
    parametrized path.

    :param path: The path string. Parameters in the path are enclosed by curly brackets.
    :type path: str
    :param param_schemas: JSON schema definition for each parameter.
    :type param_schemas: dict
    """
    new_path = path
    params = re.findall(r"\{[a-zA-Z0-9]+\}", path)

    # Replace all placeholders with fake values
    for p_param in params:
        param_schema = param_schemas[p_param[1:-1]]
        fake_value = fake_param(param_schema)

        # BUG FIX: JSF may generate non-string values (numbers, booleans), which
        # previously made re.sub() raise TypeError. Coerce to str and use a
        # callable replacement so backslashes in the value are taken literally.
        search_param = re.escape(p_param)
        new_path = re.sub(search_param, lambda _m, _v=str(fake_value): _v, new_path)

    return new_path
+
+
def replace_params(
    path: str,
    user_info: UserInfo,
    param_defs: dict[str, dict]
) -> tuple[str, dict, dict, dict] | None:
    """
    Replace parameter definitions by user defined values.

    Replacement values are taken from the user's owned resources when
    available, falling back to the allowed resources. Returns None when
    the user defines neither.

    :param path: Path containing parameter placeholders.
    :type path: str
    :param user_info: User whose resources supply the replacement values.
    :type user_info: UserInfo
    :param param_defs: OpenAPI parameter definitions.
    :type param_defs: dict
    """
    if user_info.owned_resources:
        replacements = user_info.owned_resources
    elif user_info.allowed_resources:
        replacements = user_info.allowed_resources
    else:
        # No replacement parameters defined
        return None

    parametrized_path = replace_uri_params(path, replacements)
    headers, queries, cookies = replace_http_params(replacements, param_defs)

    return parametrized_path, headers, queries, cookies
+
+
def replace_uri_params(
    path: str,
    defined_params: dict[str, list[str]],
    required_schemas: dict[str, dict] | None = None
) -> str:
    """
    Replace path parameter placeholders with user-defined values.

    :param path: The path string. Parameters in the path are enclosed by curly brackets.
    :type path: str
    :param defined_params: Parameter values. Each entry maps a parameter ID to a
                           list of candidate values; the first value is used.
    :type defined_params: dict
    :param required_schemas: JSON schema definitions for the parameters (optional).
    :type required_schemas: dict
    :return: The parametrized path, or an empty string if a value is missing.
    """
    new_path = path
    params = re.findall(r"\{[a-zA-Z0-9]+\}", path)

    # Replace all placeholders
    for param_def in params:
        param_id = param_def[1:-1]
        if param_id not in defined_params:
            # Exit if a required parameter cannot be found
            logging.warning(f"Could not find parameter '{param_id}' in lookup dict.")
            return ""

        param_value = defined_params[param_id][0]
        if required_schemas and param_id in required_schemas:
            # Optional schema validation
            try:
                jsonschema.validate(param_value, required_schemas[param_id])

            except jsonschema.ValidationError:
                # Continue if the value does not conform, but log it
                logging.info(
                    f"Parameter '{param_id}' does not conform to requested schema.")

            except jsonschema.SchemaError:
                # Continue if schema is not correct but log error
                logging.info(
                    f"Requested schema for parameter '{param_id}' is invalid.")

        # BUG FIX: coerce the value to str (it may be numeric) and use a
        # callable replacement so re.sub() cannot fail on non-string values
        # or interpret backslashes in the value as escape sequences.
        search_param = re.escape(param_def)
        new_path = re.sub(search_param, lambda _m, _v=str(param_value): _v, new_path)

    return new_path
+
+
def replace_http_params(
    defined_params: dict[str, list[str]],
    param_defs: dict[str, dict]
) -> tuple[dict, dict, dict]:
    """
    Replace parameter values for an endpoint from OpenAPI parameter definitions.

    :param defined_params: Replacement values (first list entry per parameter is used).
    :type defined_params: dict
    :param param_defs: OpenAPI parameter definitions.
    :type param_defs: dict
    :return: Tuple of (header_params, query_params, cookie_params). All three
             are empty if any parameter has no replacement value.
    """
    header_params = {}
    query_params = {}
    cookie_params = {}

    # Dispatch table: OpenAPI parameter location -> output dict.
    # Other locations (e.g. 'path') are intentionally ignored here.
    buckets = {
        "header": header_params,
        "query": query_params,
        "cookie": cookie_params,
    }

    # Replace all placeholders
    for param_id, param_def in param_defs.items():
        if param_id not in defined_params:
            # Exit if a required parameter cannot be found
            logging.warning(f"Could not find parameter '{param_id}' in lookup dict.")
            return {}, {}, {}

        # TODO: Schema validations?

        bucket = buckets.get(param_def["in"])
        if bucket is not None:
            bucket[param_id] = defined_params[param_id][0]

    return header_params, query_params, cookie_params
diff --git a/rest_attacker/util/log.py b/rest_attacker/util/log.py
new file mode 100644
index 0000000..5c2d2a5
--- /dev/null
+++ b/rest_attacker/util/log.py
@@ -0,0 +1,37 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Setup logging for the tool.
+"""
+
+import logging
+
+
+def setup_logging(cli_loglevel=logging.WARNING, file_loglevel=logging.DEBUG, logpath=None):
+ """
+ Setup logging for the tool.
+
+ :param cli_loglevel: Loglevel for logging to the CLI.
+ :type cli_loglevel: int
+ :param file_loglevel: Loglevel for logging to file.
+ :type file_loglevel: EngineConfig
+ :param logpath: Path to the log file. If 'None', no file handler is created.
+ :type logpath: pathlib.Path
+ """
+ logger = logging.getLogger()
+ logger.setLevel(logging.DEBUG)
+ formatter = logging.Formatter("[%(levelname)s] %(message)s")
+
+ # CLI output
+ handler = logging.StreamHandler()
+ handler.setLevel(cli_loglevel)
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+
+ # File output
+ handler = logging.FileHandler(str(logpath.resolve()))
+ handler.setLevel(file_loglevel)
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+
+ logging.info(f"Logfile created at: {logpath}")
diff --git a/rest_attacker/util/openapi/__init__.py b/rest_attacker/util/openapi/__init__.py
new file mode 100644
index 0000000..e3026c0
--- /dev/null
+++ b/rest_attacker/util/openapi/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Helper functions and classes for handling Swagger/OpenAPI definitions.
+"""
diff --git a/rest_attacker/util/openapi/wrapper.py b/rest_attacker/util/openapi/wrapper.py
new file mode 100644
index 0000000..69674d1
--- /dev/null
+++ b/rest_attacker/util/openapi/wrapper.py
@@ -0,0 +1,319 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Wrapper for Swagger 2.0 and OpenAPI 3.0 formats
+"""
+
+from collections import defaultdict
+from urllib.parse import unquote
+
+
+class OpenAPI:
+ """
+ Wrapper for an OpenAPI definition.
+ """
+
+ def __init__(self, description_id: str, content: dict) -> None:
+ """
+ Create a new OpenAPI description.
+
+ :param description_id: Identifier for the description.
+ :type description_id: str
+ :param content: Content of the description file.
+ :type content: dict
+ """
+ self.description_id = description_id
+ self.definition = content
+
+ self.version = None
+ if "swagger" in self.definition.keys():
+ if self.definition["swagger"] == "2.0":
+ self.version = self.definition["swagger"]
+ self.transform()
+
+ elif "openapi" in self.definition.keys():
+ self.version = self.definition["openapi"]
+
+ else:
+ raise Exception("Could not find version in OpenAPI description.")
+
+ def transform(self) -> None:
+ """
+ Transform the format from swagger 2.0 to OpenAPI 3.0.
+ """
+ # Multiple hosts are supported
+ host = self.definition.pop("host")
+ base_path = self.definition.pop("basePath")
+ schemes = self.definition.pop("schemes")
+
+ servers = []
+ for scheme in schemes:
+ server_url = f"{scheme}://{host}{base_path}"
+ servers.append({"url": server_url})
+
+ self.definition["servers"] = servers
+
+ global_consumes = self.definition.pop("consumes", [])
+ global_produces = self.definition.pop("produces", [])
+ for path in self.definition["paths"].values():
+ for method in path.values():
+ local_consumes = method.pop("consumes", [])
+ if len(local_consumes) == 0:
+ if len(global_consumes) == 0:
+ local_consumes = []
+
+ else:
+ local_consumes = global_consumes[0]
+
+ input_parameters = method.pop("parameters", [])
+ if len(input_parameters) > 0:
+ response_body = None
+ response_form = None
+ for param in input_parameters:
+ if param["in"] == "body":
+ response_body = {
+ "description": param["description"],
+ "content": {
+ local_consumes: param["schema"]
+ }
+ }
+
+ elif param["in"] == "form":
+ response_form = {
+ "description": param["description"],
+ "content": {
+ local_consumes: param["schema"]
+ }
+ }
+
+ if response_body:
+ method["responseBody"] = response_body
+
+ if response_form:
+ method["responseForm"] = response_form
+
+ for response in method["responses"].values():
+ response_schema = response.pop("schema", [])
+ if response_schema:
+ response.update({
+ "content": {
+ global_produces[0]: response_schema
+ }
+ })
+
+ def resolve_ref(self, ref: str) -> None | dict:
+ """
+ Get the referenced object to a relative reference. The reference can be an
+ URI or a JSON pointer (RFC 6901).
+
+ :param ref: Reference URI.
+ :type ref: str
+ """
+ if ref[0] == "#":
+ # JSON pointer
+ # Remove URI encoding
+ new_ref = unquote(ref)
+
+ # Split into parts
+ parts = new_ref[2:].split('/')
+
+ # Start at root
+ current_item = self.definition
+ for part in parts:
+ # Replace escaped symbols_ '~', '/'
+ part_ref = part.replace('~0', '~')
+ part_ref = part_ref.replace("~1", "/")
+
+ if isinstance(current_item, dict):
+ # JSON object
+ current_item = current_item[part_ref]
+
+ elif isinstance(current_item, list):
+ # JSON array
+ current_item = current_item[int(part_ref)]
+
+ else:
+ return Exception(f"Item at {part} in {new_ref} must be a JSON object or array.")
+
+ return current_item
+
+ # TODO: External references
+ return None
+
+ def get_security_requirements(self, path: str, operation: str) -> list[dict]:
+ """
+ Get the security requirements of an endpoint.
+ """
+ endpoint_def = self.paths[path][operation]
+ if "security" in endpoint_def.keys():
+ return endpoint_def["security"]
+
+ # Fall back to default security requirements if they exist
+ elif "security" in self.definition.keys():
+ return self.definition["security"]
+
+ return []
+
+ def requires_auth(self, path: str, operation: str) -> bool:
+ """
+ Check whether an endpoint requires authentication or authorization for access.
+ """
+ endpoint_reqs = self.get_security_requirements(path, operation)
+
+ return len(endpoint_reqs) > 0
+
+ def get_required_param_defs(self, path: str, operation: str) -> dict[str, dict]:
+ """
+ Get the parameter requirement definitions for an endpoint.
+ """
+ path_def = self.paths[path]
+ endpoint_def = path_def[operation]
+ params = {}
+
+ # Path parameters
+ if "parameters" in path_def.keys():
+ for param in path_def["parameters"]:
+ if "$ref" in param.keys():
+ param = self.resolve_ref(param["$ref"])
+
+ if "required" in param.keys() and param["required"] == True:
+ params.update({
+ param["name"]: param
+ })
+
+ # Endpoint parameters (overwrite path parameter definitions)
+ if "parameters" in endpoint_def.keys():
+ for param in endpoint_def["parameters"]:
+ if "$ref" in param.keys():
+ param = self.resolve_ref(param["$ref"])
+
+ if "required" in param.keys() and param["required"] == True:
+ params.update({
+ param["name"]: param
+ })
+
+ elif "required" in param.keys() and param["required"] == False:
+ params.pop(param["name"], None)
+
+ return params
+
+ def get_required_param_ids(self, path: str, operation: str) -> list[str]:
+ """
+ Get the IDs of the required parameter of an endpoint.
+ """
+ return list(self.get_required_param_defs(path, operation).keys())
+
+ def requires_parameters(self, path: str, operation: str) -> bool:
+ """
+ Check whether an endpoint requires one or more input parameters.
+ """
+ endpoint_reqs = self.get_required_param_ids(path, operation)
+
+ return len(endpoint_reqs) > 0
+
+ def get_nosec_endpoints(self) -> dict[str, list[str]]:
+ """
+ Get all endpoint IDs that require no security.
+ """
+ endpoints = defaultdict(list)
+
+ search_endpoints = self.endpoints
+ for path_id, path in search_endpoints.items():
+ for op_id, _ in path.items():
+ if not self.requires_auth(path_id, op_id):
+ endpoints[path_id].append(op_id)
+
+ return dict(endpoints)
+
+ def get_sec_endpoints(self) -> dict[str, list[str]]:
+ """
+ Get all endpoint IDs that have at least one security requirement.
+ """
+ endpoints = defaultdict(list)
+
+ search_endpoints = self.endpoints
+ for path_id, path in search_endpoints.items():
+ for op_id, _ in path.items():
+ if self.requires_auth(path_id, op_id):
+ endpoints[path_id].append(op_id)
+
+ return dict(endpoints)
+
+ def get_param_endpoints(self) -> dict[str, list[str]]:
+ """
+ Get all endpoint IDs that have at least one parameter requirement.
+ """
+ endpoints = defaultdict(list)
+
+ search_endpoints = self.endpoints
+ for path_id, path in search_endpoints.items():
+ for op_id, _ in path.items():
+ if self.requires_parameters(path_id, op_id):
+ endpoints[path_id].append(op_id)
+
+ return dict(endpoints)
+
+ def get_noparam_endpoints(self) -> dict[str, list[str]]:
+ """
+ Get all endpoint IDs that require no parameter.
+ """
+ endpoints = defaultdict(list)
+
+ search_endpoints = self.endpoints
+ for path_id, path in search_endpoints.items():
+ for op_id, _ in path.items():
+ if not self.requires_parameters(path_id, op_id):
+ endpoints[path_id].append(op_id)
+
+ return dict(endpoints)
+
+ @property
+ def components(self) -> dict:
+ """
+ Get the component definitions of the description.
+ """
+ return self.definition["components"]
+
+ @property
+ def endpoints(self) -> dict[str, dict]:
+ """
+ Get only the path + operation definitions of the description. Other
+ fields from the PathItem object (summary, description, servers, parameters)
+ are excluded.
+ """
+ endpoints = defaultdict(dict)
+
+ for path_id, path in self.paths.items():
+ if "$ref" in path.keys():
+ # Follow reference
+ path = self.resolve_ref(path["$ref"])
+
+ for op_id, operation in path.items():
+ if op_id in ("summary", "description", "servers", "parameters"):
+ continue
+
+ endpoints[path_id].update({
+ op_id: operation
+ })
+
+ return dict(endpoints)
+
+ @ property
+ def paths(self) -> dict[str, dict]:
+ """
+ Get the path definitions of the description.
+ """
+ return self.definition["paths"]
+
+ @ property
+ def servers(self) -> list[dict]:
+ """
+ Get the server definitions of the description.
+ """
+ return self.definition["servers"]
+
+ def __getitem__(self, key):
+ return self.definition[key]
+
+ def __contains__(self, key):
+ return key in self.definition.keys()
diff --git a/rest_attacker/util/parsers/__init__.py b/rest_attacker/util/parsers/__init__.py
new file mode 100644
index 0000000..19dfede
--- /dev/null
+++ b/rest_attacker/util/parsers/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+This module contains parsers for various formats.
+"""
diff --git a/rest_attacker/util/parsers/config_auth.py b/rest_attacker/util/parsers/config_auth.py
new file mode 100644
index 0000000..0f7fb1c
--- /dev/null
+++ b/rest_attacker/util/parsers/config_auth.py
@@ -0,0 +1,230 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Loads the auth configuration format.
+"""
+
+from __future__ import annotations
+import typing
+
+import json
+import logging
+
+from rest_attacker.util.auth.auth_generator import AuthGenerator
+from rest_attacker.util.auth.auth_scheme import AuthType, BasicAuthScheme, KeyValueAuthScheme
+from rest_attacker.util.auth.session import ROBrowserSession, ROCookieSession, ROWebSession
+from rest_attacker.util.auth.token_generator import ClientInfo, OAuth2TokenGenerator
+from rest_attacker.util.auth.userinfo import UserInfo
+
+if typing.TYPE_CHECKING:
+ from pathlib import Path
+
+
def load_auth_file(path: Path) -> tuple[typing.Optional[dict[str, UserInfo]],
                                        dict,
                                        AuthGenerator]:
    """
    Load a credentials and authentication info from JSON.

    :param path: Path to the credentials file.
    :type path: pathlib.Path
    :return: Tuple of (users, credentials, auth generator). 'users' is None
             when the file defines no "users" section; 'credentials' maps
             credential IDs to their raw config dict or, for OAuth2 clients,
             to an OAuth2TokenGenerator.
    :raises Exception: If the file does not exist or is not a regular file.
    :raises ValueError: If a session entry has an unrecognized type.
    """
    logging.debug("Starting: Loading credentials configuration.")

    if not path.exists():
        raise Exception(f"Configuration in '{path}' does not exist.")

    if not path.is_file():
        raise Exception(f"{path} is not a file")

    with path.open() as credfile:
        auth_data = json.load(credfile)

    logging.debug(f"Using service credentials file at: {path}")

    # Users + sessions (optional)
    users = None
    if "users" in auth_data.keys():
        users = {}
        users_info = auth_data["users"]

        for user_internal_id, user_data in users_info.items():
            # Mandatory user fields; optional ones are merged in below.
            user_info_data = {
                "internal_id": user_internal_id,
                "account_id": user_data["account_id"],
                "user_id": user_data["user_id"]
            }

            if "userinfo_endpoint" in user_data.keys():
                user_info_data["userinfo_endpoint"] = user_data["userinfo_endpoint"]

            if "owned_resources" in user_data.keys():
                user_info_data["owned_resources"] = user_data["owned_resources"]

            if "allowed_resources" in user_data.keys():
                user_info_data["allowed_resources"] = user_data["allowed_resources"]

            if "credentials" in user_data.keys():
                user_info_data["credentials"] = user_data["credentials"]

            if "sessions" in user_data.keys():
                # Each session entry is dispatched on its "type" field to the
                # matching read-only session class.
                sessions = {}
                sessions_info = user_data["sessions"]

                for session_id, session_data in sessions_info.items():
                    test_url = None
                    if "test_url" in session_data.keys():
                        test_url = session_data["test_url"]

                    if session_data["type"] == "weblogin":
                        login_data = session_data["params"]
                        login_url = session_data["url"]
                        session = ROWebSession(session_id, login_url, login_data, test_url=test_url)

                        sessions.update({
                            session_id: session
                        })

                    elif session_data["type"] == "cookie":
                        cookies = session_data["params"]
                        session = ROCookieSession(session_id, cookies, test_url=test_url)

                        sessions.update({
                            session_id: session
                        })

                    elif session_data["type"] == "browser":
                        executable = session_data["exec_path"]
                        port = int(session_data["local_port"])
                        session = ROBrowserSession(session_id, executable, port, test_url=test_url)

                        sessions.update({
                            session_id: session
                        })

                    else:
                        raise ValueError(f"Unrecognized session type: '{session_data['type']}'")

                    logging.debug(f"Added session info: {session_id}")

                user_info_data["sessions"] = sessions

            users.update({
                user_internal_id: UserInfo(**user_info_data)
            })

            logging.debug(f"Added user info: {user_internal_id}")

    # Required auth infos (optional)
    required_min = {}
    if "required_always" in auth_data.keys():
        required_min = auth_data["required_always"]

        logging.debug("Added unauthorized scheme requirements")

    else:
        logging.debug("No unauthorized schemes specified.")

    required_auth = {}
    if "required_auth" in auth_data.keys():
        required_auth = auth_data["required_auth"]

        logging.debug("Added authenticated scheme requirements")

    else:
        logging.debug("No authenticated schemes specified.")

    # Credentials
    # NOTE: this is a reference to the "creds" dict inside auth_data, not a
    # copy — OAuth2 entries are replaced in place by token generators below.
    credentials = auth_data["creds"]

    # Create token generators for OAuth2
    for cred_id, cred in auth_data["creds"].items():
        if cred["type"] == "oauth2_client":
            client_info_data = {
                "client_id": cred["client_id"],
                "client_secret": cred["client_secret"],
                "redirect_urls": cred["redirect_uris"],
                "auth_url": cred["authorization_endpoint"],
                "token_url": cred["token_endpoint"],
            }

            if "revocation_endpoint" in cred.keys():
                client_info_data.update(
                    {"revoke_url": cred["revocation_endpoint"]}
                )

            if "scopes" in cred.keys():
                client_info_data.update(
                    {"scopes": cred["scopes"]}
                )

            if "grants" in cred.keys():
                client_info_data.update(
                    {"grants": cred["grants"]}
                )

            if "description" in cred.keys():
                client_info_data.update(
                    {"description": cred["description"]}
                )

            if "flags" in cred.keys():
                client_info_data.update(
                    {"flags": cred["flags"]}
                )

            client_info = ClientInfo(**client_info_data)

            # Initialize with a user-agent session to retrieve tokens
            # (the first configured user acts as the default user agent).
            default_user = None
            if users and len(users) > 0:
                default_user = list(users.values())[0]

            token_gen = OAuth2TokenGenerator(client_info, user_agent=default_user)

            # Replace data with reference to token generator
            credentials.update({
                cred_id: token_gen
            })

            logging.debug(f"Created token generator for credentials: {cred_id}")

        logging.debug(f"Added credentials: {cred_id}")

    # Schemes
    schemes = {}
    scheme_info = auth_data["schemes"]
    for scheme_id, scheme in scheme_info.items():
        auth_type = AuthType[scheme["type"].upper()]
        payload_pattern = scheme["payload"]
        params_cfg = scheme["params"]

        # Collect the credentials each scheme parameter draws from.
        scheme_creds = {}
        for _, param in params_cfg.items():
            param_srcs = param["from"]
            for param_src in param_srcs:
                scheme_creds.update({
                    param_src: credentials[param_src]
                })

        if auth_type is AuthType.BASIC:
            schemes.update({
                scheme_id: BasicAuthScheme(
                    scheme_id, payload_pattern, params_cfg, credentials=scheme_creds)
            })

        else:
            # All non-BASIC types carry their payload under a key identifier.
            key_id = scheme["key_id"]
            schemes.update({
                scheme_id: KeyValueAuthScheme(
                    scheme_id, auth_type, key_id, payload_pattern,
                    params_cfg, credentials=scheme_creds)
            })

        logging.debug(f"Added scheme: {scheme_id}")

    auth_gen = AuthGenerator(schemes, required_min, required_auth)

    logging.debug("Finished: Loading credentials configuration.")

    return users, credentials, auth_gen
diff --git a/rest_attacker/util/parsers/config_info.py b/rest_attacker/util/parsers/config_info.py
new file mode 100644
index 0000000..f3368c2
--- /dev/null
+++ b/rest_attacker/util/parsers/config_info.py
@@ -0,0 +1,125 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Loads the configuration formats.
+"""
+
+import json
+import logging
+
+from pathlib import Path
+
+from rest_attacker.util.parsers.config_auth import load_auth_file
+from rest_attacker.util.parsers.openapi import load_openapi
+from rest_attacker.engine.config import EngineConfig
+
+
def load_config(path: Path) -> EngineConfig:
    """
    Load a service info and metadata files from JSON.

    :param path: Path to the service directory.
    :type path: pathlib.Path
    """
    logging.debug("Starting: Loading service configuration.")

    if not path.exists():
        raise Exception(f"Configuration in '{path}' does not exist.")

    if not path.is_dir():
        raise Exception(f"{path} is not a directory.")

    # Mandatory info file
    info_path = path / "info.json"
    with info_path.open() as infofile:
        info = json.load(infofile)

    logging.debug(f"Using service info file at: {info_path}")

    # Non-mandatory information files
    meta = {}
    if "meta" in info:
        meta_path = path / info["meta"]
        with meta_path.open() as metafile:
            meta = json.load(metafile)

        logging.debug(f"Using service meta file at: {meta_path}")
    else:
        logging.debug("No service meta file specified: Skipping meta info load.")

    users, credentials, auth_gen = None, {}, None
    if "credentials" in info:
        users, credentials, auth_gen = load_auth_file(path / info["credentials"])
    else:
        logging.debug("No service credentials file specified: Skipping credential info load.")

    # The first configured user becomes the initially active one.
    current_user_id = next(iter(users), None) if users else None

    descriptions = None
    if "descriptions" in info:
        descriptions = {}
        for descr_key, descr in info["descriptions"].items():
            if not descr["available"]:
                logging.debug(f"Skipping API description '{descr_key}'")
                continue

            logging.debug(f"Using OpenAPI description '{descr_key}'")
            descriptions[descr_key] = load_openapi(descr_key, path / descr["path"])

        logging.debug(f"{len(descriptions)} API descriptions available.")
    else:
        logging.debug("No API descriptions found.")

    logging.debug("Finished: Loading service configuration.")

    return EngineConfig(
        meta,
        info,
        credentials,
        users=users,
        current_user_id=current_user_id,
        auth_gen=auth_gen,
        descriptions=descriptions
    )
+
+
def create_config_from_openapi(path: Path) -> EngineConfig:
    """
    Create a temporary config from an OpenAPI description.

    :param path: Path to the OpenAPI file.
    :type path: pathlib.Path
    """
    logging.debug("Starting: Loading OpenAPI file.")

    if not path.exists():
        raise Exception(f"Configuration in '{path}' does not exist.")

    if not path.is_file():
        raise Exception(f"{path} is not a file.")

    # Use filename as key
    descr_key = path.name
    logging.debug(f"Using OpenAPI description '{descr_key}'")
    descriptions = {descr_key: load_openapi(descr_key, path)}

    logging.debug("Finished: Loading service configuration.")

    # Empty meta/info/credentials: only the description is available.
    return EngineConfig({}, {}, {}, descriptions=descriptions)
diff --git a/rest_attacker/util/parsers/config_run.py b/rest_attacker/util/parsers/config_run.py
new file mode 100644
index 0000000..16890d7
--- /dev/null
+++ b/rest_attacker/util/parsers/config_run.py
@@ -0,0 +1,92 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Loads the run configuration format (check/report definitions for a tool run).
+"""
+
+from __future__ import annotations
+import typing
+
+import json
+import logging
+from pathlib import Path
+
+if typing.TYPE_CHECKING:
+ from rest_attacker.checks.generic import TestCase
+ from rest_attacker.engine.config import EngineConfig
+
+
def load_config(
    test_cases: dict[str, TestCase],
    engine_cfg: EngineConfig,
    path: Path,
    continue_run: bool = False
) -> list[TestCase]:
    """
    Load check configurations.

    :param test_cases: Available test cases by test case ID.
    :type test_cases: dict
    :param engine_cfg: Configuration for the service
    :type engine_cfg: EngineConfig
    :param path: Path to the run configuration file.
    :type path: pathlib.Path
    :param continue_run: For 'partial' configurations, only load checks that
                         were aborted in the previous run.
    :type continue_run: bool
    :raises ValueError: If the configuration type is not recognized.
    """
    logging.debug("Starting: Loading run configuration.")

    if not path.exists():
        # BUG FIX: stray extra quote removed from the error message
        raise Exception(f"Configuration in '{path}' does not exist.")

    if not path.is_file():
        raise Exception(f"{path} is not a file")

    with path.open() as checkfile:
        check_cfg = json.load(checkfile)

    logging.debug(f"Using checks file at: {path}")

    if check_cfg["type"] == "report":
        check_defs = check_cfg["reports"]

    elif check_cfg["type"] == "run":
        check_defs = check_cfg["checks"]

    elif check_cfg["type"] == "partial":
        part_check_defs = check_cfg["reports"]

        if continue_run:
            # Only run the aborted checks
            check_defs = [check_def for check_def in part_check_defs
                          if check_def["status"] == "aborted"]

        else:
            check_defs = part_check_defs

    else:
        # BUG FIX: an unknown type previously left 'check_defs' unbound and
        # crashed with NameError further down.
        raise ValueError(f"Unrecognized run configuration type: '{check_cfg['type']}'")

    checks = []
    for check_def in check_defs:
        if "config" not in check_def:
            # No serialization provided
            logging.warning(f"Skipping check {check_def['check_id']}: No serialization found.")
            continue

        check_id = check_def["check_id"]
        test_case_id = check_def["test_case"]
        test_case_cls = test_cases[test_case_id]

        check = test_case_cls.deserialize(check_def["config"], engine_cfg, check_id)
        if check:
            checks.append(check)
            logging.debug(
                f"Configured check for test case '{check_def['test_case']}' loaded.")

        else:
            # BUG FIX: typos in the log message ("coild", "avaliable")
            logging.info(
                f"Check '{check_def['check_id']}' could not be loaded. "
                "No deserialization available")

    logging.debug(f"{len(checks)} checks loaded.")
    logging.debug("Finished: Loading run configuration.")

    return checks
diff --git a/rest_attacker/util/parsers/openapi.py b/rest_attacker/util/parsers/openapi.py
new file mode 100644
index 0000000..8a438f3
--- /dev/null
+++ b/rest_attacker/util/parsers/openapi.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Parses an OpenAPI file. OpenAPI files are distributed as YAML
+or JSON files.
+"""
+
+import json
+import logging
+import yaml
+
+from pathlib import Path
+
+from rest_attacker.util.openapi.wrapper import OpenAPI
+
+
def load_openapi(description_id: str, path: Path):
    """
    Load an OpenAPI YAML or JSON definition.

    :param description_id: Identifier for the description.
    :type description_id: str
    :param path: Path to the file.
    :type path: pathlib.Path
    :raises Exception: If the path is not a file or has an unsupported extension.
    """
    if not path.is_file():
        raise Exception(f"{path} is not a file")

    with path.open() as apifile:
        logging.debug(f"Loading OpenAPI description at: {path}")
        if path.suffix == ".json":
            return OpenAPI(description_id, json.load(apifile))

        # Generalization: also accept the common '.yml' extension.
        # NOTE(review): FullLoader is kept for backward compatibility; consider
        # yaml.safe_load for untrusted descriptions.
        elif path.suffix in (".yaml", ".yml"):
            return OpenAPI(description_id, yaml.load(apifile, Loader=yaml.loader.FullLoader))

        else:
            raise Exception(
                (f"{path.suffix} is not a recognized extension. "
                 "Expected '.json' or '.yaml'"))
diff --git a/rest_attacker/util/request/__init__.py b/rest_attacker/util/request/__init__.py
new file mode 100644
index 0000000..7552cab
--- /dev/null
+++ b/rest_attacker/util/request/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Helper functions and classes for handling (online) API requests.
+"""
diff --git a/rest_attacker/util/request/http_methods.py b/rest_attacker/util/request/http_methods.py
new file mode 100644
index 0000000..dbe201b
--- /dev/null
+++ b/rest_attacker/util/request/http_methods.py
@@ -0,0 +1,35 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+HTTP method names and categories.
+"""
+
+# Non-destructive (or read-only)
+SAFE_METHODS = [
+ "get",
+ "head",
+ "options",
+
+ # unsupported by requests module?
+ "trace",
+]
+
+# Potentially destructive
+UNSAFE_METHODS = [
+ "post",
+ "put",
+ "patch",
+ "delete",
+
+ # unsupported by requests module?
+ "connect",
+]
+
+# Multiple identical requests => same result on the server
+IDEMPOTENT_METHODS = [
+ "put",
+ "delete",
+]
+
+# All methods
+METHODS = list(set(SAFE_METHODS + UNSAFE_METHODS))
diff --git a/rest_attacker/util/request/request_info.py b/rest_attacker/util/request/request_info.py
new file mode 100644
index 0000000..b2531b4
--- /dev/null
+++ b/rest_attacker/util/request/request_info.py
@@ -0,0 +1,478 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Store request info for checks.
+"""
+
+import logging
+from urllib.parse import urlparse, urlunparse
+import requests
+
+from requests.models import Response
+
+from rest_attacker.util.auth.auth_scheme import AuthType
+from rest_attacker.util.auth.auth_generator import AuthGenerator
+from rest_attacker.util.auth.token_generator import AccessLevelPolicy
+from rest_attacker.util.errors import RestrictedOperationError
+from rest_attacker.util.request.http_methods import METHODS as HTTP_METHODS
+
+
+class RequestInfo:
+ """
+ Request for a check. Wraps around the requests interface to allow
+ separate specification of endpoint info, i.e. server URL, endpoint
+ path and endpoint operation.
+ """
+ # Settings that are used for every request. Overwritten by self.kwargs.
+ global_kwargs: dict = {}
+ allowed_ops = HTTP_METHODS
+
+ def __init__(self, url: str, path: str, operation: str, **kwargs) -> None:
+ """
+ Creates a new Request object.
+
+ :param url: URL string containing the base path to the REST API.
+ You can also pass a tuple returned from urllib.parse
+ :type url: str | tuple
+ :param path: Endpoint path.
+ :type path: str
+ :param operation: Endpoint operation.
+ :type operation: str
+ :param kwargs: Additional parameters that will be passed to requests.request() method.
+ :type kwargs: dict
+ """
+ if isinstance(url, str):
+ self._url = urlparse(url)
+
+ elif isinstance(url, tuple):
+ self._url = url
+
+ for idx in range(3, 6):
+ # Confirm that the URL contains no query/fragment information
+ if len(self._url[idx]) > 0:
+ logging.warning("URL for request contains more than scheme, netloc, path.")
+
+ self.path = path
+ self.operation = operation
+
+ self.kwargs = kwargs
+
+ def send(self, auth_data: list[tuple[AuthType, dict]] = None) -> Response:
+ """
+ Send the request and return the response.
+
+ :return: Response to the request.
+ :rtype: requests.models.Response
+ """
+ if self.operation.lower() not in self.allowed_ops:
+ raise RestrictedOperationError(f"HTTP Method {self.operation} is not allowed.")
+
+ kwargs = {}
+ kwargs.update(self.global_kwargs)
+
+ if auth_data:
+ kwargs = self._prepare_auth_args(auth_data)
+
+ else:
+ kwargs.update(self.kwargs)
+
+ return requests.request(self.operation, self.endpoint_url, **kwargs)
+
+ def _prepare_auth_args(self, auth_data: list[tuple[AuthType, dict]]) -> dict:
+ """
+ Prepare request arguments and include auth data.
+
+ :param auth_data: List of auth payloads specialized by auth type.
+ :type auth_data: list[tuple[AuthType, dict]]
+ """
+ # Make a copy to not pollute normal request info
+ tmp_kwargs = self.kwargs.copy()
+
+ for auth_type, auth_payload in auth_data:
+ if auth_type in (AuthType.HEADER, AuthType.BASIC):
+ if not "headers" in tmp_kwargs.keys():
+ tmp_kwargs["headers"] = {}
+
+ tmp_kwargs["headers"].update(auth_payload)
+
+ elif auth_type is AuthType.QUERY:
+ if not "params" in tmp_kwargs.keys():
+ tmp_kwargs["params"] = {}
+
+ tmp_kwargs["params"].update(auth_payload)
+
+ elif auth_type is AuthType.COOKIE:
+ if not "cookies" in tmp_kwargs.keys():
+ tmp_kwargs["cookies"] = {}
+
+ tmp_kwargs["cookies"].update(auth_payload)
+
+ return tmp_kwargs
+
+ def get_curl_command(self, auth_data: list[tuple[AuthType, dict]] = None) -> str:
+ """
+ Build a curl CLI command from the request info.
+ """
+ kwargs = {}
+ kwargs.update(self.global_kwargs)
+
+ if auth_data:
+ kwargs = self._prepare_auth_args(auth_data)
+
+ else:
+ kwargs.update(self.kwargs)
+
+ output = "curl "
+
+ # TODO: Proxies, verify, cert
+
+ # Headers
+ if len(self.headers) > 0:
+ output += " ".join(
+ f"-H \"{header_id}: {header_payload}\""
+ for header_id, header_payload in self.headers.items()
+ )
+ output += " "
+
+ # Cookies
+ if len(self.cookies) > 0:
+ output += " ".join(
+ f"-b \"{cookie_id}={cookie_payload}\""
+ for cookie_id, cookie_payload in self.cookies.items()
+ )
+ output += " "
+
+ # Body Data
+ if self.data:
+ output += f"-d {self.data} "
+
+ # Follow redirects (by default curl does not follow them)
+ if self.allow_redirects:
+ output += "-L "
+
+ # Timeout
+ if self.timeout:
+ output += f"-m {self.timeout} "
+
+ # HTTP method
+ output += f"-X {self.operation.upper()} "
+
+ # scheme + host + path
+ output += self.endpoint_url
+
+ # Query params
+ if self.params:
+ output += "?" + "&".join(
+ f"{param_id}={param_payload}"
+ for param_id, param_payload in self.params.items()
+ )
+
+ return output
+
+ @property
+ def endpoint_url(self) -> str:
+ """
+ Get the endpoint URL as a string.
+ """
+ return f"{urlunparse(self._url)}{self.path}"
+
+ @property
+ def url(self) -> str:
+ """
+ Get the server URL as a string.
+ """
+ return urlunparse(self._url)
+
+ @url.setter
+ def url(self, value) -> None:
+ """
+ Set the server URL. Can use either a tuple or a str.
+ """
+ if isinstance(value, str):
+ self._url = urlparse(value)
+
+ elif isinstance(value, tuple):
+ self._url = value
+
+ # Set optional parameters of requests library
+ # TODO: There must be a better way to do this. Subclass Request maybe?
+ @property
+ def params(self):
+ """
+ Get the query parameters of the request.
+ """
+ if "params" in self.kwargs:
+ return self.kwargs["params"]
+
+ if "params" in self.global_kwargs:
+ return self.global_kwargs["params"]
+
+ return {}
+
+ @params.setter
+ def params(self, value):
+ """
+ Set the query parameters of the request.
+ """
+ self.kwargs["params"] = value
+
+ @property
+ def data(self):
+ """
+ Get the body parameters or data of the request.
+ """
+ if "data" in self.kwargs:
+ return self.kwargs["data"]
+
+ if "data" in self.global_kwargs:
+ return self.global_kwargs["data"]
+
+ return {}
+
+ @data.setter
+ def data(self, value):
+ """
+ Set the body parameters or data of the request.
+ """
+ self.kwargs["data"] = value
+
+ @property
+ def json(self):
+ """
+ Get the JSON payload of the request.
+ """
+ if "json" in self.kwargs:
+ return self.kwargs["json"]
+
+ if "json" in self.global_kwargs:
+ return self.global_kwargs["json"]
+
+ return {}
+
+ @json.setter
+ def json(self, value):
+ """
+ Set the JSON payload of the request.
+ """
+ self.kwargs["json"] = value
+
+ @property
+ def headers(self):
+ """
+ Get the headers of the request.
+ """
+ if "headers" in self.kwargs:
+ return self.kwargs["headers"]
+
+ if "headers" in self.global_kwargs:
+ return self.global_kwargs["headers"]
+
+ return {}
+
+ @headers.setter
+ def headers(self, value):
+ """
+ Set the headers of the request.
+ """
+ self.kwargs["headers"] = value
+
+ @property
+ def cookies(self):
+ """
+ Get the cookies of the request.
+ """
+ if "cookies" in self.kwargs:
+ return self.kwargs["cookies"]
+
+ if "cookies" in self.global_kwargs:
+ return self.global_kwargs["cookies"]
+
+ return {}
+
+ @cookies.setter
+ def cookies(self, value):
+ """
+ Set the cookies of the request.
+ """
+ self.kwargs["cookies"] = value
+
+ @property
+ def timeout(self):
+ """
+ Get the timeout limit of the request.
+ """
+ if "timeout" in self.kwargs:
+ return self.kwargs["timeout"]
+
+ if "timeout" in self.global_kwargs:
+ return self.global_kwargs["timeout"]
+
+ return None
+
+ @timeout.setter
+ def timeout(self, value):
+ """
+ Set the timeout limit of the request.
+ """
+ self.kwargs["timeout"] = value
+
+ @property
+ def allow_redirects(self):
+ """
+ Get the redirect setting of the request.
+ """
+ if "allow_redirects" in self.kwargs:
+ return self.kwargs["allow_redirects"]
+
+ if "allow_redirects" in self.global_kwargs:
+ return self.global_kwargs["allow_redirects"]
+
+ return True
+
+ @allow_redirects.setter
+ def allow_redirects(self, value):
+ """
+ Set the redirect setting of the request.
+ """
+ self.kwargs["allow_redirects"] = value
+
+ @property
+ def proxies(self):
+ """
+ Get the proxy settings of the request.
+ """
+ if "proxies" in self.kwargs:
+ return self.kwargs["proxies"]
+
+ if "proxies" in self.global_kwargs:
+ return self.global_kwargs["proxies"]
+
+ return {}
+
+ @proxies.setter
+ def proxies(self, value):
+ """
+ Set the proxy settings of the request.
+ """
+ self.kwargs["proxies"] = value
+
+ @property
+ def verify(self):
+ """
+ Get the CA verification settings of the request.
+ """
+ if "verify" in self.kwargs:
+ return self.kwargs["verify"]
+
+ if "verify" in self.global_kwargs:
+ return self.global_kwargs["verify"]
+
+ return {}
+
+ @verify.setter
+ def verify(self, value):
+ """
+ Set the CA verification settings of the request.
+ """
+ self.kwargs["verify"] = value
+
+ @property
+ def cert(self):
+ """
+ Get the client-side cert settings of the request.
+ """
+ if "cert" in self.kwargs:
+ return self.kwargs["cert"]
+
+ if "cert" in self.global_kwargs:
+ return self.global_kwargs["cert"]
+
+ return {}
+
+ @cert.setter
+ def cert(self, value):
+ """
+ Set the client-side cert settings of the request.
+ """
+ self.kwargs["cert"] = value
+
+ def serialize(self) -> dict:
+ """
+ Serialize a request to a JSON-compatible dict.
+ """
+ return {
+ "url": self.url,
+ "path": self.path,
+ "operation": self.operation,
+ "kwargs": self.kwargs,
+ # Global kwargs should be reconfigurable?
+ # either way they could be stored somewhere else
+ # "global_args": self.global_kwargs
+ }
+
+ @classmethod
+ def deserialize(cls, serialized: dict):
+ """
+ Deserialize a request from a JSON-compatible dict to a RequestInfo object.
+
+ :param serialized: Serialized representation of the request.
+ :type serialized: dict
+ """
+ url = serialized.pop("url")
+ path = serialized.pop("path")
+ operation = serialized.pop("operation")
+ kwargs = serialized.pop("kwargs")
+ return RequestInfo(url, path, operation, **kwargs)
+
+
class AuthRequestInfo:
    """
    Auth information for an online check. Contains a generator for dynamically
    creating authentication and authorization payloads for the request.
    """

    def __init__(
        self,
        auth_gen: AuthGenerator,
        scheme_ids: list[str] = None,
        scopes: list[str] = None,
        policy: AccessLevelPolicy = AccessLevelPolicy.DEFAULT
    ) -> None:
        """
        Creates a new AuthRequestInfo object.

        :param auth_gen: AuthGenerator for creating auth info.
        :type auth_gen: AuthGenerator
        :param scheme_ids: Optional list of scheme IDs that auth info should be generated for.
        :type scheme_ids: list[str]
        :param scopes: List of scopes that are requested if OAuth2 credentials are used.
        :type scopes: list[str]
        :param policy: Access level policy applied when tokens are generated.
        :type policy: AccessLevelPolicy
        """
        self.auth_gen = auth_gen
        self.scheme_ids = scheme_ids
        self.scopes = scopes
        self.policy = policy

    def serialize(self) -> dict:
        """
        Serialize authorized request information to a JSON-compatible dict.
        """
        return {
            # auth_gen is recreated dynamically because its settings can change between runs
            "scheme_ids": self.scheme_ids,
            "scopes": self.scopes,
            # Stored by enum member name; restored via lookup in deserialize().
            "policy": self.policy.name
        }

    @classmethod
    def deserialize(cls, serialized: dict, auth_gen: AuthGenerator):
        """
        Deserialize authorized request information from a JSON-compatible dict
        to an AuthRequestInfo object.

        :param serialized: Serialized representation of the authorized request information.
        :type serialized: dict
        :param auth_gen: AuthGenerator for creating auth info.
        :type auth_gen: AuthGenerator
        """
        # Work on a copy so the caller's dict is not mutated by pop().
        data = dict(serialized)
        policy = AccessLevelPolicy[data.pop("policy")]
        return cls(auth_gen, policy=policy, **data)
diff --git a/rest_attacker/util/response_handler.py b/rest_attacker/util/response_handler.py
new file mode 100644
index 0000000..ef3d47d
--- /dev/null
+++ b/rest_attacker/util/response_handler.py
@@ -0,0 +1,164 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Handles responses to HTTP requests made by checks for the main engine.
+"""
+
+import time
+from requests import Response
+
+from rest_attacker.util.request.request_info import AuthRequestInfo, RequestInfo
+
+
class RateLimitHandler:
    """
    Manages the rate limit from the last response.
    """

    def __init__(
        self,
        max_limit=1000,
        remaining=1000,
        reset_time: int = 3600,
        headers: dict[str, str] = None
    ) -> None:
        """
        Create a new RateLimitHandler.

        :param max_limit: Maximum number of requests until the rate limit must be reset.
        :type max_limit: int
        :param remaining: Remaining requests until max rate limit is reached.
        :type remaining: int
        :param reset_time: Seconds to wait until next request can be made after rate limit has been reached.
        :type reset_time: int
        :param headers: Identifiers for response headers indicating the current rate limit status.
                        Hints for max, remaining and reset time can be given.
        :type headers: dict[str,str]
        """
        # Bugfix: 'headers' previously used a mutable default argument ({}).
        if headers is None:
            headers = {}

        self.max_limit = max_limit
        self.remaining = remaining

        self.reset_time = reset_time

        # Service-specific header names that report the rate limit status.
        self.header_id_max = headers.get("rate_limit_max", None)
        self.header_id_cur = headers.get("rate_limit_remaining", None)
        self.header_id_reset = headers.get("rate_limit_reset", None)

    def _header_value(self, response: Response, header_id):
        """
        Read an int-valued rate limit header from a response.

        Returns None when no header ID is configured, no response is given,
        or the header is absent from the response.
        """
        if response is not None and header_id and header_id in response.headers:
            return int(response.headers[header_id])

        return None

    def setup(self, response: Response) -> None:
        """
        Initialize the handler from the first response.
        """
        value = self._header_value(response, self.header_id_max)
        if value is not None:
            self.max_limit = value

        value = self._header_value(response, self.header_id_cur)
        if value is not None:
            self.remaining = value

        value = self._header_value(response, self.header_id_reset)
        if value is not None:
            self.reset_time = value

    def reset(self, response: Response = None) -> None:
        """
        Reset the limit to the max limit or use the header values from the response.
        """
        # Bugfix: previously a configured header that was missing from the
        # response raised a KeyError; now it falls back like setup()/update().
        value = self._header_value(response, self.header_id_max)
        if value is not None:
            self.max_limit = value

        value = self._header_value(response, self.header_id_cur)
        if value is not None:
            self.remaining = value

        else:
            self.remaining = self.max_limit

        value = self._header_value(response, self.header_id_reset)
        if value is not None:
            self.reset_time = value

    def update(self, response: Response) -> bool:
        """
        Update the handler from a response. Return True if limit has been reached.
        """
        if response.status_code == 429:
            # NOTE(review): a 429 means the server reports the limit as hit,
            # yet this returns False (limit "not reached"). Kept as-is since
            # the engine apparently handles 429 elsewhere - confirm intended.
            return False

        value = self._header_value(response, self.header_id_cur)
        if value is not None:
            self.remaining = value

        else:
            # No header available: assume exactly one request was consumed.
            self.remaining -= 1

        value = self._header_value(response, self.header_id_reset)
        if value is not None:
            self.reset_time = value

        return self.remaining <= 0

    def get_reset_wait_time(self):
        """
        Get the time required until the rate limit resets.

        NOTE(review): this treats reset_time as an absolute epoch timestamp
        (as reported by many APIs' reset headers), while the constructor
        default (3600) reads like a relative duration - confirm which
        semantics the targeted services use.
        """
        current_time = time.time()
        required_time = int(self.reset_time - current_time) + 1

        return required_time
+
+
class AccessLimitHandler:
    """
    Manages responses to reaching the access limit of a service.
    """

    def __init__(
        self,
        test_request: RequestInfo,
        auth_info: AuthRequestInfo,
        interval: int = 10
    ) -> None:
        """
        Create a new AccessLimitHandler.

        :param test_request: Request that is used to test whether the access limit has been reached.
                             This should ideally be a GET operation to an endpoint that is protected
                             with access control measures. The resource should be accessible by the
                             currently active user.
        :type test_request: RequestInfo
        :param auth_info: Auth information for authenticating/authorizing the request.
        :type auth_info: AuthRequestInfo
        :param interval: Number of (online) checks that can be executed before the test request is sent.
        :type interval: int
        """
        self.test_request = test_request
        self.auth_info = auth_info

        self.interval = interval

        # Current position in the interval
        self.current_pos = interval

        # ID of the check before the last successful AccessLimitHandler check.
        self.last_check_id = None

    def reset(self) -> None:
        """
        Reset the interval.
        """
        self.current_pos = 0

    def update(self) -> bool:
        """
        Check if the endpoint specified in the test request is still accessible.
        Return True if the access limit has been reached.
        """
        credentials = self.auth_info.auth_gen.get_auth(
            scheme_ids=self.auth_info.scheme_ids,
            scopes=self.auth_info.scopes,
            policy=self.auth_info.policy
        )

        status = self.test_request.send(auth_data=credentials).status_code

        if status == 429:
            # Handled by RateLimitHandler
            # TODO: More verbose return values than bool types may be helpful here
            return False

        # Any 2xx status means the endpoint is still reachable.
        return not (200 <= status < 300)
diff --git a/rest_attacker/util/test_result.py b/rest_attacker/util/test_result.py
new file mode 100644
index 0000000..aa99cb3
--- /dev/null
+++ b/rest_attacker/util/test_result.py
@@ -0,0 +1,102 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Implementation of test result objects.
+"""
+from __future__ import annotations
+import typing
+
+import enum
+import logging
+
+if typing.TYPE_CHECKING:
+ from rest_attacker.checks.generic import TestCase
+
+
class CheckStatus(enum.Enum):
    """
    Status of the check.
    """
    # Test result is waiting for the check to execute.
    QUEUED = "queued"
    # Check was started.
    RUNNING = "running"
    # Check finished successfully.
    FINISHED = "finished"
    # Check was skipped.
    SKIPPED = "skipped"
    # Check was aborted because the run finished early.
    ABORTED = "aborted"
    # Check failed with an error.
    ERROR = "error"
+
+
class IssueType(enum.Enum):
    """
    Type of issue found by the check. Depends on the test case type.
    """
    # Analytical: found what the check was looking for.
    CANDIDATE = "analysis_candidate"
    # Analytical: found nothing unusual.
    NO_CANDIDATE = "analysis_none"

    # Security check: intended (secure) behaviour.
    OKAY = "security_okay"
    # Security check: unintended or undocumented behaviour.
    PROBLEM = "security_problem"
    # Security check: insecure behaviour.
    FLAW = "security_flaw"

    # Comparison check: check result values are equal.
    MATCH = "comparison_match"
    # Comparison check: check result values are different.
    DIFFERENT = "comparison_different"
+
+
class TestResult:
    """
    Stores the result of a check.
    """

    def __init__(self, check: TestCase) -> None:
        """
        Create a new TestResult object.

        :param check: Reference to the check the result is created for.
        :type check: TestCase
        """
        # Type of issue found; set by the check when it finishes.
        self.issue_type: IssueType = None
        # Execution status; starts queued, updated by the engine.
        self.status = CheckStatus.QUEUED
        # Exception raised during execution, if any.
        self.error: Exception = None
        self.check: TestCase = check

        # Result value of the check. Should be a dict.
        self.value: dict = None

        # Last HTTP response received by the check, if any.
        self.last_response = None

    def dump(self, verbosity: int = 2) -> dict[str, typing.Any]:
        """
        Generate a dictionary with information from the test result.

        :param verbosity: Verbosity of the exported results.
                          0 -> check_id, status, issue type
                          1 -> 0 + error
                          2 -> 1 + value (default)
        :type verbosity: int
        """
        if self.status is CheckStatus.QUEUED:
            logging.warning(f"{self}: Dumping test result for unfinished check.")

        output = {
            "check_id": self.check.check_id,
            "test_type": self.check.test_type.value,
            "test_case": self.check.get_test_case_id(),
            "status": self.status.value,
        }

        if self.status not in (CheckStatus.SKIPPED, CheckStatus.ABORTED, CheckStatus.ERROR):
            if self.issue_type:
                output["issue"] = self.issue_type.value

            else:
                # Bugfix: previously this logged the warning and then crashed
                # with an AttributeError on None.value; now the 'issue' key
                # is simply omitted.
                logging.warning(f"{self}: Dumping test result with unspecified issue type.")

        if verbosity >= 1:
            if self.error:
                output["error"] = str(self.error)

        if verbosity >= 2:
            if self.value is not None:
                output["value"] = self.value

        return output
diff --git a/rest_attacker/util/version.py b/rest_attacker/util/version.py
new file mode 100644
index 0000000..b3f51d7
--- /dev/null
+++ b/rest_attacker/util/version.py
@@ -0,0 +1,22 @@
+# Copyright (C) 2021-2022 the REST-Attacker authors. See COPYING and CONTRIBUTORS.md for legal info.
+
+"""
+Retrieves the version number.
+"""
+
+import subprocess
+import sys
+from argparse import Action
+
+
class GetVersion(Action):
    """
    Retrieves version number using 'git describe'.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Print the version number and exit the program with status 0.

        Falls back to 'unknown' instead of crashing with a traceback when
        git is unavailable or the source tree is not a git checkout
        (e.g. installed from a release archive).
        """
        try:
            version = subprocess.check_output(
                ["git", "describe", "--always"],
                stderr=subprocess.DEVNULL).strip()
            print(version.decode("utf8"))

        except (OSError, subprocess.CalledProcessError):
            # git missing or 'git describe' failed outside a repository.
            print("unknown")

        sys.exit(0)