diff --git a/src/main/java/io/supertokens/storage/postgresql/Start.java b/src/main/java/io/supertokens/storage/postgresql/Start.java index 1ee2997f..214bee77 100644 --- a/src/main/java/io/supertokens/storage/postgresql/Start.java +++ b/src/main/java/io/supertokens/storage/postgresql/Start.java @@ -54,7 +54,6 @@ import io.supertokens.pluginInterface.multitenancy.exceptions.DuplicateThirdPartyIdException; import io.supertokens.pluginInterface.multitenancy.exceptions.TenantOrAppNotFoundException; import io.supertokens.pluginInterface.multitenancy.sqlStorage.MultitenancySQLStorage; -import io.supertokens.pluginInterface.oauth.exceptions.OAuth2ClientAlreadyExistsForAppException; import io.supertokens.pluginInterface.oauth.sqlStorage.OAuthSQLStorage; import io.supertokens.pluginInterface.passwordless.PasswordlessCode; import io.supertokens.pluginInterface.passwordless.PasswordlessDevice; @@ -3094,19 +3093,10 @@ public boolean doesClientIdExistForApp(AppIdentifier appIdentifier, String clien @Override public void addOrUpdateClientForApp(AppIdentifier appIdentifier, String clientId, boolean isClientCredentialsOnly) - throws StorageQueryException, OAuth2ClientAlreadyExistsForAppException { + throws StorageQueryException { try { - OAuthQueries.insertClientIdForAppId(this, clientId, appIdentifier); + OAuthQueries.insertClientIdForAppId(this, appIdentifier, clientId, isClientCredentialsOnly); } catch (SQLException e) { - - if (e instanceof PSQLException) { - PostgreSQLConfig config = Config.getConfig(this); - ServerErrorMessage serverMessage = ((PSQLException) e).getServerErrorMessage(); - - if (isPrimaryKeyError(serverMessage, config.getOAuthClientTable())) { - throw new OAuth2ClientAlreadyExistsForAppException(); - } - } throw new StorageQueryException(e); } } @@ -3164,12 +3154,20 @@ public int countTotalNumberOfM2MTokensCreatedSince(AppIdentifier appIdentifier, @Override public int countTotalNumberOfClientCredentialsOnlyClientsForApp(AppIdentifier appIdentifier) throws StorageQueryException { - return 0; // TODO + try { + return OAuthQueries.countTotalNumberOfClientsForApp(this, appIdentifier, true); + } catch (SQLException e) { + throw new StorageQueryException(e); + } } @Override public int countTotalNumberOfClientsForApp(AppIdentifier appIdentifier) throws StorageQueryException { - return 0; // TODO + try { + return OAuthQueries.countTotalNumberOfClientsForApp(this, appIdentifier, false); + } catch (SQLException e) { + throw new StorageQueryException(e); + } } @TestOnly diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java index 3518b9fe..19c6715f 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/OAuthQueries.java @@ -14,7 +14,6 @@ import static io.supertokens.storage.postgresql.QueryExecutorTemplate.execute; import static io.supertokens.storage.postgresql.QueryExecutorTemplate.update; - public class OAuthQueries { public static String getQueryToCreateOAuthClientTable(Start start) { String schema = Config.getConfig(start).getTableSchema(); @@ -23,6 +22,7 @@ public static String getQueryToCreateOAuthClientTable(Start start) { return "CREATE TABLE IF NOT EXISTS " + oAuth2ClientTable + " (" + "app_id VARCHAR(64) DEFAULT 'public'," + "client_id VARCHAR(128) NOT NULL," + + "is_client_credentials_only BOOLEAN NOT NULL," + "CONSTRAINT " + Utils.getConstraintName(schema, oAuth2ClientTable, "client_id", "pkey") + " PRIMARY KEY (app_id, client_id)," + "CONSTRAINT " + Utils.getConstraintName(schema, oAuth2ClientTable, "app_id", "fkey") @@ -82,13 +82,17 @@ public static List listClientsForApp(Start start, AppIdentifier appIdent }); } - public static void insertClientIdForAppId(Start start, String clientId, AppIdentifier appIdentifier) + public static void insertClientIdForAppId(Start start, AppIdentifier appIdentifier, String clientId, + boolean isClientCredentialsOnly) throws SQLException, StorageQueryException { String INSERT = "INSERT INTO " + Config.getConfig(start).getOAuthClientTable() - + "(app_id, client_id) VALUES(?, ?)"; + + "(app_id, client_id, is_client_credentials_only) VALUES(?, ?, ?) " + + "ON CONFLICT (app_id, client_id) DO UPDATE SET is_client_credentials_only = ?"; update(start, INSERT, pst -> { pst.setString(1, appIdentifier.getAppId()); pst.setString(2, clientId); + pst.setBoolean(3, isClientCredentialsOnly); + pst.setBoolean(4, isClientCredentialsOnly); }); } @@ -119,7 +123,8 @@ public static void revoke(Start start, AppIdentifier appIdentifier, String targe }); } - public static boolean isRevoked(Start start, AppIdentifier appIdentifier, String[] targetTypes, String[] targetValues, long issuedAt) + public static boolean isRevoked(Start start, AppIdentifier appIdentifier, String[] targetTypes, + String[] targetValues, long issuedAt) throws SQLException, StorageQueryException { String QUERY = "SELECT app_id FROM " + Config.getConfig(start).getOAuthRevokeTable() + " WHERE app_id = ? AND timestamp > ? AND ("; @@ -147,4 +152,32 @@ public static boolean isRevoked(Start start, AppIdentifier appIdentifier, String } }, ResultSet::next); } + + public static int countTotalNumberOfClientsForApp(Start start, AppIdentifier appIdentifier, + boolean filterByClientCredentialsOnly) throws SQLException, StorageQueryException { + if (filterByClientCredentialsOnly) { + String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthClientTable() + + " WHERE app_id = ? AND is_client_credentials_only = ?"; + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + pst.setBoolean(2, true); + }, result -> { + if (result.next()) { + return result.getInt("c"); + } + return 0; + }); + } else { + String QUERY = "SELECT COUNT(*) as c FROM " + Config.getConfig(start).getOAuthClientTable() + + " WHERE app_id = ?"; + return execute(start, QUERY, pst -> { + pst.setString(1, appIdentifier.getAppId()); + }, result -> { + if (result.next()) { + return result.getInt("c"); + } + return 0; + }); + } + } }