Skip to content

Commit

Permalink
Use custom paging provider for aggregate DAO queries (#5606)
Browse files Browse the repository at this point in the history
This commit adds a custom paging provider that is used only by the aggregate
DAO. This is required because the standard paging provider that ships with
Spring Batch 4.x does not properly handle sort key aliases when using nested
ROW_NUMBER clauses.

* This also sneaks in Mac ARM64 support for DB2.

Resolves #5531
  • Loading branch information
onobc authored Dec 22, 2023
1 parent 7bc89c6 commit 6cb57b3
Show file tree
Hide file tree
Showing 8 changed files with 347 additions and 9 deletions.
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ You can follow the steps in the [MSSQL on Mac ARM64](https://github.com/spring-c
----

## Running Locally w/ IBM DB2
By default, the Dataflow server jar does not include the DB2 database driver dependency.
If you want to use DB2 for development/testing when running locally, you can specify the `local-dev-db2` Maven profile when building.
The following command will include the DB2 driver dependency in the jar:
```
$ ./mvnw -s .settings.xml clean package -Plocal-dev-db2
```
You can follow the steps in the [DB2 on Mac ARM64](https://github.com/spring-cloud/spring-cloud-dataflow/wiki/DB2-on-Mac-ARM64#running-dataflow-locally-against-db2) Wiki to run DB2 locally in Docker with Dataflow pointing at it.

> **NOTE:** If you are not running on Mac ARM64, skip the steps related to Homebrew and Colima.
----

## Contributing

We welcome contributions! See the [CONTRIBUTING](./CONTRIBUTING.adoc) guide for details.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.cloud.dataflow.server.repository;

import java.lang.reflect.Field;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Instant;
Expand Down Expand Up @@ -47,7 +48,12 @@
import org.springframework.batch.core.repository.dao.JdbcJobExecutionDao;
import org.springframework.batch.item.database.Order;
import org.springframework.batch.item.database.PagingQueryProvider;
import org.springframework.batch.item.database.support.AbstractSqlPagingQueryProvider;
import org.springframework.batch.item.database.support.Db2PagingQueryProvider;
import org.springframework.batch.item.database.support.OraclePagingQueryProvider;
import org.springframework.batch.item.database.support.SqlPagingQueryProviderFactoryBean;
import org.springframework.batch.item.database.support.SqlPagingQueryUtils;
import org.springframework.batch.item.database.support.SqlServerPagingQueryProvider;
import org.springframework.cloud.dataflow.core.DataFlowPropertyKeys;
import org.springframework.cloud.dataflow.core.database.support.DatabaseType;
import org.springframework.cloud.dataflow.rest.job.JobInstanceExecutions;
Expand Down Expand Up @@ -75,6 +81,7 @@
import org.springframework.jdbc.core.RowMapper;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;

/**
Expand Down Expand Up @@ -802,7 +809,7 @@ private PagingQueryProvider getPagingQueryProvider(String fields, String fromCla
* @throws Exception if page provider is not created.
*/
private PagingQueryProvider getPagingQueryProvider(String fields, String fromClause, String whereClause, Map<String, Order> sortKeys) throws Exception {
SqlPagingQueryProviderFactoryBean factory = new SqlPagingQueryProviderFactoryBean();
SqlPagingQueryProviderFactoryBean factory = new SafeSqlPagingQueryProviderFactoryBean();
factory.setDataSource(dataSource);
fromClause = "AGGREGATE_JOB_INSTANCE I JOIN AGGREGATE_JOB_EXECUTION E ON I.JOB_INSTANCE_ID=E.JOB_INSTANCE_ID AND I.SCHEMA_TARGET=E.SCHEMA_TARGET" + (fromClause == null ? "" : " " + fromClause);
factory.setFromClause(fromClause);
Expand All @@ -811,7 +818,7 @@ private PagingQueryProvider getPagingQueryProvider(String fields, String fromCla
}
if (fields.contains("E.JOB_EXECUTION_ID") && this.useRowNumberOptimization) {
Order order = sortKeys.get("E.JOB_EXECUTION_ID");
String orderString = Optional.ofNullable(order).map(orderKey -> orderKey == Order.DESCENDING ? "DESC" : "ASC").orElse("DESC");
String orderString = (order == null || order == Order.DESCENDING) ? "DESC" : "ASC";
fields += ", ROW_NUMBER() OVER (PARTITION BY E.JOB_EXECUTION_ID ORDER BY E.JOB_EXECUTION_ID " + orderString + ") as RN";
}
factory.setSelectClause(fields);
Expand All @@ -832,4 +839,201 @@ private boolean determineSupportsRowNumberFunction(DataSource dataSource) {
}
return false;
}

/**
 * Factory bean that swaps the stock {@code Oracle}, {@code MSSQL} and {@code DB2} paging providers produced by
 * {@link SqlPagingQueryProviderFactoryBean} for the variants that properly handle sort aliases
 * ({@link SafeOraclePagingQueryProvider}, {@link SafeSqlServerPagingQueryProvider} and
 * {@link SafeDb2PagingQueryProvider}). A provider of any other type is returned exactly as the parent created it.
 * <p><b>NOTE:</b> nested within the aggregate DAO as this is the only place that needs this specialization.
 */
static class SafeSqlPagingQueryProviderFactoryBean extends SqlPagingQueryProviderFactoryBean {

	// Kept locally so the replacement providers can be initialized against the same data source.
	private DataSource dataSource;

	/**
	 * Hands the data source to the parent factory and remembers it for the replacement providers.
	 * @param dataSource the data source used to create and initialize the provider
	 */
	@Override
	public void setDataSource(DataSource dataSource) {
		super.setDataSource(dataSource);
		this.dataSource = dataSource;
	}

	/**
	 * Creates the provider via the parent factory and, when it is one of the three affected types,
	 * replaces it with the corresponding alias-aware provider seeded from the original.
	 * @return the alias-aware provider for Oracle/MSSQL/DB2, otherwise the parent's provider
	 * @throws Exception if the parent factory fails to create the provider
	 */
	@Override
	public PagingQueryProvider getObject() throws Exception {
		PagingQueryProvider standardProvider = super.getObject();
		if (standardProvider instanceof OraclePagingQueryProvider) {
			return new SafeOraclePagingQueryProvider((AbstractSqlPagingQueryProvider) standardProvider,
					this.dataSource);
		}
		if (standardProvider instanceof SqlServerPagingQueryProvider) {
			return new SafeSqlServerPagingQueryProvider((SqlServerPagingQueryProvider) standardProvider,
					this.dataSource);
		}
		if (standardProvider instanceof Db2PagingQueryProvider) {
			return new SafeDb2PagingQueryProvider((Db2PagingQueryProvider) standardProvider, this.dataSource);
		}
		return standardProvider;
	}

}

/**
 * A {@link AbstractSqlPagingQueryProvider paging provider} for {@code Oracle} that works around the fact that the
 * Oracle provider in Spring Batch 4.x does not properly handle sort aliases when using nested {@code ROW_NUMBER}
 * clauses.
 * <p>An instance is seeded from an already configured provider: its from/where/select/group clauses and sort keys
 * are copied over and the new instance is then initialized against the supplied data source.
 */
static class SafeOraclePagingQueryProvider extends AbstractSqlPagingQueryProvider {

	/**
	 * Copies the query configuration of {@code delegate} and initializes this provider.
	 * @param delegate the configured provider whose clauses and sort keys are copied
	 * @param dataSource the data source passed to {@code init}
	 * @throws IllegalStateException if a provider field cannot be located or initialization fails
	 */
	SafeOraclePagingQueryProvider(AbstractSqlPagingQueryProvider delegate, DataSource dataSource) {
		// Have to use reflection to retrieve the provider fields
		this.setFromClause(extractField(delegate, "fromClause", String.class));
		this.setWhereClause(extractField(delegate, "whereClause", String.class));
		// The field is declared as a Map of sort key to Order; only the raw class token can be passed in.
		@SuppressWarnings("unchecked")
		Map<String, Order> delegateSortKeys = extractField(delegate, "sortKeys", Map.class);
		this.setSortKeys(delegateSortKeys);
		this.setSelectClause(extractField(delegate, "selectClause", String.class));
		this.setGroupClause(extractField(delegate, "groupClause", String.class));
		try {
			this.init(dataSource);
		}
		catch (Exception e) {
			throw new IllegalStateException("Failed to initialize " + getClass().getSimpleName(), e);
		}
	}

	/**
	 * Reads a (non-public) field declared on {@link AbstractSqlPagingQueryProvider} from the given provider.
	 * @param target the provider instance to read from
	 * @param fieldName the name of the field
	 * @param fieldType the declared type of the field
	 * @param <T> the field type
	 * @return the field value (may be {@code null})
	 * @throws IllegalStateException if no field with that name and type exists
	 */
	private <T> T extractField(AbstractSqlPagingQueryProvider target, String fieldName, Class<T> fieldType) {
		Field field = ReflectionUtils.findField(AbstractSqlPagingQueryProvider.class, fieldName, fieldType);
		// findField returns null when nothing matches; fail with a meaningful message instead of an NPE
		if (field == null) {
			throw new IllegalStateException("Field '" + fieldName + "' of type " + fieldType.getName()
					+ " not found on " + AbstractSqlPagingQueryProvider.class.getName());
		}
		ReflectionUtils.makeAccessible(field);
		return fieldType.cast(ReflectionUtils.getField(field, target));
	}

	@Override
	public String generateFirstPageQuery(int pageSize) {
		return generateRowNumSqlQuery(false, pageSize);
	}

	@Override
	public String generateRemainingPagesQuery(int pageSize) {
		return generateRowNumSqlQuery(true, pageSize);
	}

	/**
	 * Builds the nested row-number query that selects the sort key values of the row at the start of the page
	 * containing {@code itemIndex}. The inner select uses the sort keys with aliases, the outer select uses
	 * them without aliases.
	 */
	@Override
	public String generateJumpToItemQuery(int itemIndex, int pageSize) {
		int page = itemIndex / pageSize;
		int offset = (page * pageSize);
		// an offset of 0 (first page) is mapped to row number 1
		offset = (offset == 0) ? 1 : offset;
		String sortKeyInnerSelect = this.getSortKeySelect(true);
		String sortKeyOuterSelect = this.getSortKeySelect(false);
		return SqlPagingQueryUtils.generateRowNumSqlQueryWithNesting(this, sortKeyInnerSelect, sortKeyOuterSelect,
				false, "TMP_ROW_NUM = " + offset);
	}

	/**
	 * Returns the sort key names joined by commas.
	 * @param withAliases whether to use the sort keys including their aliases
	 */
	private String getSortKeySelect(boolean withAliases) {
		Map<String, Order> sortKeys = (withAliases) ? this.getSortKeys() : this.getSortKeysWithoutAliases();
		return String.join(",", sortKeys.keySet());
	}

	// Taken from SqlPagingQueryUtils.generateRowNumSqlQuery but use sortKeysWithoutAlias
	// for outer sort condition.
	private String generateRowNumSqlQuery(boolean remainingPageQuery, int pageSize) {
		StringBuilder sql = new StringBuilder();
		sql.append("SELECT * FROM (SELECT ").append(getSelectClause());
		sql.append(" FROM ").append(this.getFromClause());
		if (StringUtils.hasText(this.getWhereClause())) {
			sql.append(" WHERE ").append(this.getWhereClause());
		}
		if (StringUtils.hasText(this.getGroupClause())) {
			sql.append(" GROUP BY ").append(this.getGroupClause());
		}
		// inner sort by
		sql.append(" ORDER BY ").append(SqlPagingQueryUtils.buildSortClause(this));
		sql.append(") WHERE ROWNUM <= ").append(pageSize);
		if (remainingPageQuery) {
			sql.append(" AND ");
			// For the outer sort we want to use sort keys w/o aliases. However,
			// SqlPagingQueryUtils.buildSortConditions does not allow sort keys to be passed in.
			// Therefore, we temporarily set the 'sortKeys' for the call to 'buildSortConditions'.
			// The alternative is to clone the 'buildSortConditions' method here and allow the sort keys to be
			// passed in BUT method is gigantic and this approach is the lesser of the two evils.
			// NOTE(review): the temporary swap mutates instance state, so concurrent calls on a shared
			// instance would presumably interfere - confirm the provider is never shared across threads.
			Map<String, Order> originalSortKeys = this.getSortKeys();
			this.setSortKeys(this.getSortKeysWithoutAliases());
			try {
				SqlPagingQueryUtils.buildSortConditions(this, sql);
			}
			finally {
				this.setSortKeys(originalSortKeys);
			}
		}
		return sql.toString();
	}
}

/**
 * A {@link SqlServerPagingQueryProvider paging provider} for {@code MSSQL} that works around the fact that the
 * MSSQL provider in Spring Batch 4.x does not properly handle sort aliases when generating jump to page queries.
 * <p>An instance is seeded from an already configured provider: its from/where/select/group clauses and sort keys
 * are copied over and the new instance is then initialized against the supplied data source.
 */
static class SafeSqlServerPagingQueryProvider extends SqlServerPagingQueryProvider {

	/**
	 * Copies the query configuration of {@code delegate} and initializes this provider.
	 * @param delegate the configured provider whose clauses and sort keys are copied
	 * @param dataSource the data source passed to {@code init}
	 * @throws IllegalStateException if a provider field cannot be located or initialization fails
	 */
	SafeSqlServerPagingQueryProvider(SqlServerPagingQueryProvider delegate, DataSource dataSource) {
		// Have to use reflection to retrieve the provider fields
		this.setFromClause(extractField(delegate, "fromClause", String.class));
		this.setWhereClause(extractField(delegate, "whereClause", String.class));
		// The field is declared as a Map of sort key to Order; only the raw class token can be passed in.
		@SuppressWarnings("unchecked")
		Map<String, Order> delegateSortKeys = extractField(delegate, "sortKeys", Map.class);
		this.setSortKeys(delegateSortKeys);
		this.setSelectClause(extractField(delegate, "selectClause", String.class));
		this.setGroupClause(extractField(delegate, "groupClause", String.class));
		try {
			this.init(dataSource);
		}
		catch (Exception e) {
			throw new IllegalStateException("Failed to initialize " + getClass().getSimpleName(), e);
		}
	}

	/**
	 * Reads a (non-public) field declared on {@link AbstractSqlPagingQueryProvider} from the given provider.
	 * @param target the provider instance to read from
	 * @param fieldName the name of the field
	 * @param fieldType the declared type of the field
	 * @param <T> the field type
	 * @return the field value (may be {@code null})
	 * @throws IllegalStateException if no field with that name and type exists
	 */
	private <T> T extractField(AbstractSqlPagingQueryProvider target, String fieldName, Class<T> fieldType) {
		Field field = ReflectionUtils.findField(AbstractSqlPagingQueryProvider.class, fieldName, fieldType);
		// findField returns null when nothing matches; fail with a meaningful message instead of an NPE
		if (field == null) {
			throw new IllegalStateException("Field '" + fieldName + "' of type " + fieldType.getName()
					+ " not found on " + AbstractSqlPagingQueryProvider.class.getName());
		}
		ReflectionUtils.makeAccessible(field);
		return fieldType.cast(ReflectionUtils.getField(field, target));
	}

	@Override
	protected String getOverClause() {
		// Overrides the parent impl to use 'getSortKeys' instead of 'getSortKeysWithoutAliases'
		return " ORDER BY " + SqlPagingQueryUtils.buildSortClause(this.getSortKeys());
	}

}

/**
 * A {@link Db2PagingQueryProvider paging provider} for {@code DB2} that works around the fact that the
 * DB2 provider in Spring Batch 4.x does not properly handle sort aliases when generating jump to page queries.
 * <p>An instance is seeded from an already configured provider: its from/where/select/group clauses and sort keys
 * are copied over and the new instance is then initialized against the supplied data source.
 */
static class SafeDb2PagingQueryProvider extends Db2PagingQueryProvider {

	/**
	 * Copies the query configuration of {@code delegate} and initializes this provider.
	 * @param delegate the configured provider whose clauses and sort keys are copied
	 * @param dataSource the data source passed to {@code init}
	 * @throws IllegalStateException if a provider field cannot be located or initialization fails
	 */
	SafeDb2PagingQueryProvider(Db2PagingQueryProvider delegate, DataSource dataSource) {
		// Have to use reflection to retrieve the provider fields
		this.setFromClause(extractField(delegate, "fromClause", String.class));
		this.setWhereClause(extractField(delegate, "whereClause", String.class));
		// The field is declared as a Map of sort key to Order; only the raw class token can be passed in.
		@SuppressWarnings("unchecked")
		Map<String, Order> delegateSortKeys = extractField(delegate, "sortKeys", Map.class);
		this.setSortKeys(delegateSortKeys);
		this.setSelectClause(extractField(delegate, "selectClause", String.class));
		this.setGroupClause(extractField(delegate, "groupClause", String.class));
		try {
			this.init(dataSource);
		}
		catch (Exception e) {
			throw new IllegalStateException("Failed to initialize " + getClass().getSimpleName(), e);
		}
	}

	/**
	 * Reads a (non-public) field declared on {@link AbstractSqlPagingQueryProvider} from the given provider.
	 * @param target the provider instance to read from
	 * @param fieldName the name of the field
	 * @param fieldType the declared type of the field
	 * @param <T> the field type
	 * @return the field value (may be {@code null})
	 * @throws IllegalStateException if no field with that name and type exists
	 */
	private <T> T extractField(AbstractSqlPagingQueryProvider target, String fieldName, Class<T> fieldType) {
		Field field = ReflectionUtils.findField(AbstractSqlPagingQueryProvider.class, fieldName, fieldType);
		// findField returns null when nothing matches; fail with a meaningful message instead of an NPE
		if (field == null) {
			throw new IllegalStateException("Field '" + fieldName + "' of type " + fieldType.getName()
					+ " not found on " + AbstractSqlPagingQueryProvider.class.getName());
		}
		ReflectionUtils.makeAccessible(field);
		return fieldType.cast(ReflectionUtils.getField(field, target));
	}

	@Override
	protected String getOverClause() {
		// Overrides the parent impl to use 'getSortKeys' instead of 'getSortKeysWithoutAliases'
		return " ORDER BY " + SqlPagingQueryUtils.buildSortClause(this.getSortKeys());
	}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class SqlPagingQueryProviderFactoryBean implements FactoryBean<PagingQuer
private final static Map<DatabaseType, AbstractSqlPagingQueryProvider> providers;

static {
Map<DatabaseType, AbstractSqlPagingQueryProvider> providerMap = new HashMap<DatabaseType, AbstractSqlPagingQueryProvider>();
Map<DatabaseType, AbstractSqlPagingQueryProvider> providerMap = new HashMap<DatabaseType, AbstractSqlPagingQueryProvider>();
providerMap.put(DatabaseType.HSQL, new HsqlPagingQueryProvider());
providerMap.put(DatabaseType.H2, new H2PagingQueryProvider());
providerMap.put(DatabaseType.MYSQL, new MySqlPagingQueryProvider());
Expand Down
10 changes: 10 additions & 0 deletions spring-cloud-dataflow-server/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -408,5 +408,15 @@
</dependency>
</dependencies>
</profile>
<profile>
<!-- Opt-in profile for local development/testing against IBM DB2: adds the DB2 JDBC driver
     (com.ibm.db2:jcc) to the server's dependencies. Enable it with -Plocal-dev-db2. -->
<id>local-dev-db2</id>
<dependencies>
<dependency>
<groupId>com.ibm.db2</groupId>
<artifactId>jcc</artifactId>
<version>11.5.8.0</version>
</dependency>
</dependencies>
</profile>
</profiles>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,21 @@ void shouldListJobExecutionsUsingPerformantRowNumberQuery(
createdExecutionIdsBySchemaTarget.add(schemaVersionTarget, execution1.getExecutionId());
TaskExecution execution2 = testUtils.createSampleJob("job2", 3, BatchStatus.COMPLETED, new JobParameters(), schemaVersionTarget);
createdExecutionIdsBySchemaTarget.add(schemaVersionTarget, execution2.getExecutionId());

// Get all executions and ensure the count and that the row number function was (or not) used
jobExecutions = taskJobService.listJobExecutionsWithStepCount(Pageable.ofSize(100));
assertThat(jobExecutions).hasSize(originalCount + 4);
String expectedSqlFragment = (this.supportsRowNumberFunction()) ?
"as STEP_COUNT, ROW_NUMBER() OVER (PARTITION" :
"as STEP_COUNT FROM AGGREGATE_JOB_INSTANCE";
Awaitility.waitAtMost(Duration.ofSeconds(5))
.untilAsserted(() -> assertThat(output).contains(expectedSqlFragment));

// Verify that paging works as well
jobExecutions = taskJobService.listJobExecutionsWithStepCount(Pageable.ofSize(2).withPage(0));
assertThat(jobExecutions).hasSize(2);
jobExecutions = taskJobService.listJobExecutionsWithStepCount(Pageable.ofSize(2).withPage(1));
assertThat(jobExecutions).hasSize(2);
}

static Stream<SchemaVersionTarget> schemaVersionTargetsProvider() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,22 @@

import javax.sql.DataSource;

import com.zaxxer.hikari.HikariDataSource;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import org.springframework.batch.core.BatchStatus;
import org.springframework.batch.core.JobExecution;
import org.springframework.batch.core.JobInstance;
import org.springframework.batch.core.JobParameters;
import org.springframework.batch.core.repository.dao.JdbcJobInstanceDao;
import org.springframework.batch.item.database.support.DataFieldMaxValueIncrementerFactory;
import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties;
import org.springframework.cloud.dataflow.core.database.support.DatabaseType;
import org.springframework.cloud.dataflow.core.database.support.MultiSchemaIncrementerFactory;
import org.springframework.cloud.dataflow.schema.SchemaVersionTarget;
import org.springframework.cloud.dataflow.schema.service.SchemaService;
import org.springframework.cloud.dataflow.schema.service.impl.DefaultSchemaService;
import org.springframework.cloud.dataflow.server.repository.TaskBatchDaoContainer;
import org.springframework.cloud.dataflow.server.repository.TaskExecutionDaoContainer;
import org.springframework.cloud.task.batch.listener.TaskBatchDao;
Expand Down Expand Up @@ -130,4 +137,37 @@ private JobExecution saveJobExecution(JobExecution jobExecution, JdbcTemplate jd
private Timestamp timestampFromDate(Date date) {
return (date != null) ? Timestamp.valueOf(date.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime()) : null;
}


/**
 * Test utility that generates hundreds of job executions which can be useful when debugging paging issues.
 * <p>To run, adjust the datasource properties accordingly and then execute the test manually in your editor.
 * <p>The class is {@code @Disabled} so that it never runs as part of the regular build.
 */
@Disabled
static class JobExecutionTestDataGenerator {

	/**
	 * Creates one sample job with 200 executions for each of the {@code boot2} and {@code boot3} schema targets
	 * in the database the datasource properties point at.
	 */
	@Test
	void generateJobExecutions() {
		// Adjust these properties as necessary to point to your env
		DataSourceProperties dataSourceProperties = new DataSourceProperties();
		dataSourceProperties.setUrl("jdbc:oracle:thin:@localhost:1521/dataflow");
		dataSourceProperties.setUsername("spring");
		dataSourceProperties.setPassword("spring");
		dataSourceProperties.setDriverClassName("oracle.jdbc.OracleDriver");

		// The Hikari pool holds open connections; close it once the data has been generated
		try (HikariDataSource dataSource = dataSourceProperties.initializeDataSourceBuilder()
				.type(HikariDataSource.class).build()) {
			SchemaService schemaService = new DefaultSchemaService();
			TaskExecutionDaoContainer taskExecutionDaoContainer = new TaskExecutionDaoContainer(dataSource, schemaService);
			TaskBatchDaoContainer taskBatchDaoContainer = new TaskBatchDaoContainer(dataSource, schemaService);
			JobExecutionTestUtils generator = new JobExecutionTestUtils(taskExecutionDaoContainer, taskBatchDaoContainer);
			generator.createSampleJob(jobName("boot2"), 200, BatchStatus.COMPLETED, new JobParameters(),
					schemaService.getTarget("boot2"));
			generator.createSampleJob(jobName("boot3"), 200, BatchStatus.COMPLETED, new JobParameters(),
					schemaService.getTarget("boot3"));
		}
	}

	// Appends the current time in millis so that repeated runs produce distinct job names
	private String jobName(String schemaTarget) {
		return schemaTarget + "-job-" + System.currentTimeMillis();
	}
}
}
Loading

0 comments on commit 6cb57b3

Please sign in to comment.