Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utilizing circe-be extensions to pertain cohort covariates during its generation #2411

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

<commons-fileupload.version>1.5</commons-fileupload.version>

<circe.version>1.11.2</circe.version>
<circe.version>1.11.4-SNAPSHOT</circe.version>
<jersey.version>2.14</jersey.version>
<SqlRender.version>1.16.1</SqlRender.version>
<hive-jdbc.version>3.1.2</hive-jdbc.version>
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/ohdsi/webapi/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ interface Params {
String EXECUTABLE_FILE_NAME = "executableFilename";
String GENERATION_ID = "generation_id";
String DESIGN_HASH = "design_hash";
String RETAIN_COHORT_COVARIATES = "retains_cohort_covariates";
}

interface Variables {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ public CohortGenerationInfo(CohortDefinition definition, Integer sourceId)
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "created_by_id")
private UserEntity createdBy;
@Column(name = "is_choose_covariates")
private boolean isChooseCovariates;

public CohortGenerationInfoId getId() {
return id;
Expand Down Expand Up @@ -187,4 +189,12 @@ public void setCreatedBy(UserEntity createdBy) {
public UserEntity getCreatedBy() {
return createdBy;
}

public boolean isChooseCovariates() {
return isChooseCovariates;
}

public void setChooseCovariates(boolean chooseCovariates) {
isChooseCovariates = chooseCovariates;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ public class CohortGenerationRequest {
private String sessionId;
private String targetSchema;
private Integer targetId;
private Boolean retainCohortCovariates;
private Integer cohortId;

public CohortGenerationRequest(CohortExpression expression, Source source, String sessionId, Integer targetId, String targetSchema) {

Expand All @@ -19,6 +21,18 @@ public CohortGenerationRequest(CohortExpression expression, Source source, Strin
this.targetId = targetId;
this.targetSchema = targetSchema;
}

public CohortGenerationRequest(CohortExpression expression, Source source, String sessionId, Integer targetId,
String targetSchema, Boolean retainCohortCovariates, Integer cohortId) {

this.expression = expression;
this.source = source;
this.sessionId = sessionId;
this.targetId = targetId;
this.targetSchema = targetSchema;
this.retainCohortCovariates = retainCohortCovariates;
this.cohortId = cohortId;
}

public CohortExpression getExpression() {

Expand All @@ -44,4 +58,14 @@ public Integer getTargetId() {

return targetId;
}

public Boolean getRetainCohortCovariates() {

return retainCohortCovariates;
}

public Integer getCohortId() {

return cohortId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,26 @@ public class CohortGenerationRequestBuilder {
private String sessionId;
private String targetSchema;
private Integer targetId;

private Integer cohortId;
private Boolean retainCohortCovariates;

public CohortGenerationRequestBuilder(String sessionId, String targetSchema) {

this.sessionId = sessionId;
this.targetSchema = targetSchema;
}

public CohortGenerationRequestBuilder(String sessionId, String targetSchema, Boolean retainCohortCovariates) {

this.sessionId = sessionId;
this.targetSchema = targetSchema;
this.retainCohortCovariates = retainCohortCovariates;
}

public Boolean getRetainCohortCovariates() {

return retainCohortCovariates;
}

public CohortGenerationRequestBuilder withSource(Source source) {

Expand All @@ -34,6 +48,11 @@ public CohortGenerationRequestBuilder withTargetId(Integer targetId) {
this.targetId = targetId;
return this;
}

public CohortGenerationRequestBuilder withCohortId(Integer cohortId) {
this.cohortId = cohortId;
return this;
}

public CohortGenerationRequest build() {

Expand All @@ -43,4 +62,18 @@ public CohortGenerationRequest build() {

return new CohortGenerationRequest(expression, source, sessionId, targetId, targetSchema);
}

public CohortGenerationRequest buildWithRetainCohortCovariates() {

if (this.source == null || this.expression == null || this.targetId == null || this.retainCohortCovariates == null) {
throw new RuntimeException("CohortGenerationRequest should contain non-null expression, source and targetId");
}

return new CohortGenerationRequest(expression, source, sessionId, targetId, targetSchema,
retainCohortCovariates, cohortId);
}

public boolean hasRetainCohortCovariates() {
return retainCohortCovariates != null ? retainCohortCovariates.booleanValue() : false;
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.ohdsi.webapi.cohortdefinition;

import org.apache.commons.lang3.StringUtils;

import org.ohdsi.circe.cohortdefinition.CohortExpression;
import org.ohdsi.circe.cohortdefinition.CohortExpressionQueryBuilder;
import org.ohdsi.circe.cohortdefinition.InclusionRule;
import org.ohdsi.sql.SqlRender;
Expand All @@ -10,14 +10,14 @@
import org.ohdsi.webapi.source.Source;
import org.ohdsi.webapi.util.SourceUtils;
import org.springframework.jdbc.core.JdbcTemplate;

import org.springframework.util.ObjectUtils;

import java.util.Arrays;
import java.util.List;

import static org.ohdsi.webapi.Constants.Params.TARGET_DATABASE_SCHEMA;
import static org.ohdsi.webapi.Constants.Params.DESIGN_HASH;

import static org.ohdsi.webapi.Constants.Params.RESULTS_DATABASE_SCHEMA;
import static org.ohdsi.webapi.Constants.Tables.COHORT_CACHE;
import static org.ohdsi.webapi.Constants.Tables.COHORT_CENSOR_STATS_CACHE;
import static org.ohdsi.webapi.Constants.Tables.COHORT_INCLUSION_RESULT_CACHE;
Expand Down Expand Up @@ -46,6 +46,7 @@ public static void insertInclusionRules(CohortDefinition cohortDef, Source sourc
public static String[] buildGenerationSql(CohortGenerationRequest request) {

Source source = request.getSource();
CohortExpression expression = request.getExpression();

String cdmSchema = SourceUtils.getCdmQualifier(source);
String vocabSchema = SourceUtils.getVocabQualifierOrNull(source);
Expand All @@ -56,13 +57,15 @@ public static String[] buildGenerationSql(CohortGenerationRequest request) {
CohortExpressionQueryBuilder.BuildExpressionQueryOptions options = new CohortExpressionQueryBuilder.BuildExpressionQueryOptions();
options.cohortIdFieldName = DESIGN_HASH;
options.cohortId = request.getTargetId();
options.resultCohortId = request.getCohortId();
options.cdmSchema = cdmSchema;
options.vocabularySchema = vocabSchema;
options.generateStats = true; // always generate with stats
options.retainCohortCovariates = !ObjectUtils.isEmpty(request.getRetainCohortCovariates()) && request.getRetainCohortCovariates(); // this field decides whether to retain cohort covariates

final String oracleTempSchema = SourceUtils.getTempQualifier(source);

String expressionSql = expressionQueryBuilder.buildExpressionQuery(request.getExpression(), options);
String expressionSql = expressionQueryBuilder.buildExpressionQuery(expression, options);
expressionSql = SqlRender.renderSql(
expressionSql,
new String[] {"target_cohort_table",
Expand All @@ -81,6 +84,7 @@ public static String[] buildGenerationSql(CohortGenerationRequest request) {
"@target_database_schema.cohort_inclusion"
}
);
expressionSql = expressionSql.replaceAll("@results_database_schema", request.getTargetSchema());
sqlBuilder.append(expressionSql);

String renderedSql = SqlRender.renderSql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ protected String[] prepareQueries(ChunkContext chunkContext, CancelableJdbcTempl
Integer sourceId = Integer.parseInt(jobParams.get(SOURCE_ID).toString());
String targetSchema = jobParams.get(TARGET_DATABASE_SCHEMA).toString();
String sessionId = jobParams.getOrDefault(SESSION_ID, SessionUtils.sessionId()).toString();
Boolean retainCohortCovariates = Boolean.valueOf(jobParams.get(RETAIN_COHORT_COVARIATES).toString());

CohortDefinition cohortDefinition = cohortDefinitionRepository.findOneWithDetail(cohortDefinitionId);
Source source = sourceService.findBySourceId(sourceId);

CohortGenerationRequestBuilder generationRequestBuilder = new CohortGenerationRequestBuilder(
sessionId,
targetSchema
targetSchema,
retainCohortCovariates
);

int designHash = this.generationCacheHelper.computeHash(cohortDefinition.getDetails().getExpression());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public CohortGenerationInfoDTO convert(CohortGenerationInfo info) {
dto.setStartTime(info.getStartTime());
dto.setStatus(info.getStatus());
dto.setIsValid(info.isIsValid());
dto.setChooseCovariates(info.isChooseCovariates());

return dto;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public class CohortGenerationInfoDTO {

private UserDTO createdBy;

private boolean isChooseCovariates;

public CohortGenerationInfoId getId() {
return id;
}
Expand Down Expand Up @@ -124,4 +126,12 @@ public UserDTO getCreatedBy() {
public void setCreatedBy(UserDTO createdBy) {
this.createdBy = createdBy;
}

public boolean isChooseCovariates() {
return isChooseCovariates;
}

public void setChooseCovariates(boolean chooseCovariates) {
isChooseCovariates = chooseCovariates;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ private String getInclusionRuleInserts(FeasibilityStudy study)
private String getInclusionRuleQuery(CriteriaGroup inclusionRule)
{
String resultSql = INCLUSION_RULE_QUERY_TEMPLATE;
String additionalCriteriaQuery = "\nJOIN (\n" + cohortExpressionQueryBuilder.getCriteriaGroupQuery(inclusionRule, "#primary_events") + ") AC on AC.event_id = pe.event_id";
additionalCriteriaQuery = StringUtils.replace(additionalCriteriaQuery,"@indexId", "" + 0);
resultSql = StringUtils.replace(resultSql, "@additionalCriteriaQuery", additionalCriteriaQuery);
String additionalCriteriaQuery = "\nJOIN (\n"
+ cohortExpressionQueryBuilder.getCriteriaGroupQuery(inclusionRule, "#primary_events", false)
+ ") AC on AC.event_id = pe.event_id"; additionalCriteriaQuery = StringUtils.replace(additionalCriteriaQuery,"@indexId", "" + 0); resultSql = StringUtils.replace(resultSql, "@additionalCriteriaQuery", additionalCriteriaQuery);
return resultSql;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,20 @@ public CacheResult computeCacheIfAbsent(CohortDefinition cohortDefinition, Sourc
return transactionTemplateRequiresNew.execute(s -> {
log.info("Retrieves or invalidates cache for cohort id = {}", cohortDefinition.getId());
GenerationCache cache = generationCacheService.getCacheOrEraseInvalid(type, designHash, source.getSourceId());
if (cache == null) {
log.info("Cache is absent for cohort id = {}. Calculating with design hash = {}", cohortDefinition.getId(), designHash);
if (cache == null || requestBuilder.hasRetainCohortCovariates()) {
String messagePrefix = (cache == null ? "Cache is absent" : "Cache will not be used because the retain cohort covariates option is switched on");
log.info(messagePrefix + " for cohort id = {}. Calculating with design hash = {}", cohortDefinition.getId(), designHash);
// Ensure that there are no records in results schema with which we could mess up
generationCacheService.removeCache(type, source, designHash);
// the line below forces a cached entry to be really deleted and it is a bit unclear why this line was even present as the cache had to be null anyway
// without it there is a constraint violation exception when there was a cache entry present and the retain covariates option is on
GenerationCache cachedResultsStillPresent = generationCacheService.getCacheOrEraseInvalid(type, designHash, source.getSourceId());
CohortGenerationRequest cohortGenerationRequest = requestBuilder
.withExpression(cohortDefinition.getDetails().getExpressionObject())
.withSource(source)
.withTargetId(designHash)
.build();
.withCohortId(cohortDefinition.getId())
.buildWithRetainCohortCovariates();
String[] sqls = CohortGenerationUtils.buildGenerationSql(cohortGenerationRequest);
sqlExecutor.accept(designHash, sqls);
cache = generationCacheService.cacheResults(CacheableGenerationType.COHORT, designHash, source.getSourceId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ public IRAnalysisQueryBuilder(ObjectMapper objectMapper) {
private String getStrataQuery(CriteriaGroup strataCriteria)
{
String resultSql = STRATA_QUERY_TEMPLATE;
String additionalCriteriaQuery = "\nJOIN (\n" + cohortExpressionQueryBuilder.getCriteriaGroupQuery(strataCriteria, "#analysis_events") + ") AC on AC.person_id = pe.person_id AND AC.event_id = pe.event_id";
String additionalCriteriaQuery = "\nJOIN (\n"
+ cohortExpressionQueryBuilder.getCriteriaGroupQuery(strataCriteria, "#analysis_events", false)
+ ") AC on AC.person_id = pe.person_id AND AC.event_id = pe.event_id";
additionalCriteriaQuery = StringUtils.replace(additionalCriteriaQuery,"@indexId", "" + 0);
resultSql = StringUtils.replace(resultSql, "@additionalCriteriaQuery", additionalCriteriaQuery);
return resultSql;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,14 +570,14 @@ public CohortDTO saveCohortDefinition(@PathParam("id") final int id, CohortDTO d
@Produces(MediaType.APPLICATION_JSON)
@Path("/{id}/generate/{sourceKey}")
@Transactional
public JobExecutionResource generateCohort(@PathParam("id") final int id, @PathParam("sourceKey") final String sourceKey) {
public JobExecutionResource generateCohort(@PathParam("id") final int id, @PathParam("sourceKey") final String sourceKey, @QueryParam("retainCohortCovariates") String retainCohortCovariates) {

Source source = getSourceRepository().findBySourceKey(sourceKey);
CohortDefinition currentDefinition = this.cohortDefinitionRepository.findOne(id);
UserEntity user = userRepository.findByLogin(security.getSubject());
return cohortGenerationService.generateCohortViaJob(user, currentDefinition, source);
return cohortGenerationService.generateCohortViaJob(user, currentDefinition, source, Boolean.parseBoolean(retainCohortCovariates));
}

/**
* Cancel a cohort generation task
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.ohdsi.webapi.job.GeneratesNotification;
import org.ohdsi.webapi.job.JobExecutionResource;
import org.ohdsi.webapi.shiro.Entities.UserEntity;
import org.ohdsi.webapi.shiro.Entities.UserRepository;
import org.ohdsi.webapi.source.Source;
import org.ohdsi.webapi.source.SourceService;
import org.ohdsi.webapi.util.SessionUtils;
Expand All @@ -38,6 +37,7 @@
import static org.ohdsi.webapi.Constants.Params.COHORT_DEFINITION_ID;
import static org.ohdsi.webapi.Constants.Params.GENERATE_STATS;
import static org.ohdsi.webapi.Constants.Params.JOB_NAME;
import static org.ohdsi.webapi.Constants.Params.RETAIN_COHORT_COVARIATES;
import static org.ohdsi.webapi.Constants.Params.SESSION_ID;
import static org.ohdsi.webapi.Constants.Params.SOURCE_ID;
import static org.ohdsi.webapi.Constants.Params.TARGET_DATABASE_SCHEMA;
Expand Down Expand Up @@ -71,13 +71,14 @@ public CohortGenerationService(CohortDefinitionRepository cohortDefinitionReposi
this.generationCacheHelper = generationCacheHelper;
}

public JobExecutionResource generateCohortViaJob(UserEntity userEntity, CohortDefinition cohortDefinition, Source source) {
public JobExecutionResource generateCohortViaJob(UserEntity userEntity, CohortDefinition cohortDefinition, Source source, Boolean retainCohortCovariates) {

CohortGenerationInfo info = cohortDefinition.getGenerationInfoList().stream()
.filter(val -> Objects.equals(val.getId().getSourceId(), source.getSourceId())).findFirst()
.orElse(new CohortGenerationInfo(cohortDefinition, source.getSourceId()));

info.setCreatedBy(userEntity);
info.setChooseCovariates(retainCohortCovariates);

cohortDefinition.getGenerationInfoList().add(info);

Expand All @@ -86,7 +87,7 @@ public JobExecutionResource generateCohortViaJob(UserEntity userEntity, CohortDe

cohortDefinitionRepository.save(cohortDefinition);

return runGenerateCohortJob(cohortDefinition, source);
return runGenerateCohortJob(cohortDefinition, source, retainCohortCovariates);
}

private Job buildGenerateCohortJob(CohortDefinition cohortDefinition, Source source, JobParameters jobParameters) {
Expand Down Expand Up @@ -121,13 +122,13 @@ private Job buildGenerateCohortJob(CohortDefinition cohortDefinition, Source sou
return generateJobBuilder.build();
}

private JobExecutionResource runGenerateCohortJob(CohortDefinition cohortDefinition, Source source) {
final JobParametersBuilder jobParametersBuilder = getJobParametersBuilder(source, cohortDefinition);
private JobExecutionResource runGenerateCohortJob(CohortDefinition cohortDefinition, Source source, Boolean retainCohortCovariates) {
final JobParametersBuilder jobParametersBuilder = getJobParametersBuilder(source, cohortDefinition, retainCohortCovariates);
Job job = buildGenerateCohortJob(cohortDefinition, source, jobParametersBuilder.toJobParameters());
return jobService.runJob(job, jobParametersBuilder.toJobParameters());
}

private JobParametersBuilder getJobParametersBuilder(Source source, CohortDefinition cohortDefinition) {
private JobParametersBuilder getJobParametersBuilder(Source source, CohortDefinition cohortDefinition, Boolean retainCohortCovariates) {

JobParametersBuilder builder = new JobParametersBuilder();
builder.addString(JOB_NAME, String.format("Generating cohort %d : %s (%s)", cohortDefinition.getId(), source.getSourceName(), source.getSourceKey()));
Expand All @@ -136,6 +137,7 @@ private JobParametersBuilder getJobParametersBuilder(Source source, CohortDefini
builder.addString(COHORT_DEFINITION_ID, String.valueOf(cohortDefinition.getId()));
builder.addString(SOURCE_ID, String.valueOf(source.getSourceId()));
builder.addString(GENERATE_STATS, Boolean.TRUE.toString());
builder.addString(RETAIN_COHORT_COVARIATES, String.valueOf(retainCohortCovariates));
return builder;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE ${ohdsiSchema}.cohort_generation_info ADD is_choose_covariates BOOLEAN NOT NULL DEFAULT FALSE;
Loading