Skip to content

Commit

Permalink
Merge pull request #50 from DiSSCo/feature/pgcopy-temp-tables
Browse files Browse the repository at this point in the history
Copy batches into db
  • Loading branch information
southeo authored Dec 6, 2023
2 parents 283a425 + 918c901 commit f8b6d4c
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package eu.dissco.core.translator.configuration;

import java.sql.DriverManager;
import java.sql.SQLException;
import lombok.RequiredArgsConstructor;
import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties;
import org.springframework.context.annotation.Configuration;
import org.postgresql.copy.CopyManager;
import org.postgresql.core.BaseConnection;
import org.springframework.context.annotation.Bean;


@Configuration
@RequiredArgsConstructor
public class BatchInserterConfig {

  private final DataSourceProperties properties;

  /**
   * Exposes a PostgreSQL {@link CopyManager} for COPY-based bulk inserts, backed by a dedicated
   * JDBC connection built from the configured datasource properties.
   *
   * <p>NOTE(review): the underlying connection lives for the whole application lifetime and is
   * never explicitly closed; if the database drops it, the bean becomes stale — confirm this is
   * acceptable for the deployment model.
   *
   * @return a CopyManager bound to a fresh PostgreSQL connection
   * @throws SQLException if the connection cannot be established
   */
  @Bean
  public CopyManager copyManager() throws SQLException {
    var pgConnection = (BaseConnection) DriverManager.getConnection(
        properties.getUrl(), properties.getUsername(), properties.getPassword());
    return new CopyManager(pgConnection);
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@

public class DisscoRepositoryException extends Exception {

public DisscoRepositoryException(String message) {
super(message);
}

public DisscoRepositoryException(String message, Throwable cause) {
super(message, cause);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package eu.dissco.core.translator.repository;

import com.fasterxml.jackson.databind.JsonNode;
import eu.dissco.core.translator.exception.DisscoRepositoryException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.postgresql.copy.CopyManager;
import org.springframework.stereotype.Component;

@Component
@RequiredArgsConstructor
@Slf4j
public class BatchInserter {

  private final CopyManager copyManager;

  /**
   * Bulk-loads the given (id, json) pairs into {@code tableName} via PostgreSQL COPY, which is
   * far faster than row-by-row INSERT batches.
   *
   * @param tableName name of the (temp) table to copy into; interpolated into the COPY statement
   * @param dbRecords pairs of record id and JSON payload to persist
   * @throws DisscoRepositoryException if writing the CSV buffer or executing the COPY fails
   */
  public void batchCopy(String tableName, List<Pair<String, JsonNode>> dbRecords)
      throws DisscoRepositoryException {
    try (var buffer = new ByteArrayOutputStream()) {
      for (var dbRecord : dbRecords) {
        buffer.write(getCsvRow(dbRecord));
      }
      var copySql = "COPY " + tableName + " FROM stdin DELIMITER ','";
      copyManager.copyIn(copySql, new ByteArrayInputStream(buffer.toByteArray()));
    } catch (IOException | SQLException e) {
      throw new DisscoRepositoryException(
          String.format("An error has occurred inserting %d records into temp table %s",
              dbRecords.size(), tableName), e);
    }
  }

  // Renders one record as a single COPY text-format row: "<id>,<escaped json>\n".
  private static byte[] getCsvRow(Pair<String, JsonNode> dbRecord) {
    var row = dbRecord.getLeft() + "," + cleanString(dbRecord.getRight()) + "\n";
    return row.getBytes(StandardCharsets.UTF_8);
  }

  // Escapes a JSON payload for COPY text format. The replace order is significant:
  // escaped NUL sequences are stripped first, then existing backslashes are doubled,
  // and finally the delimiter character is escaped.
  private static String cleanString(JsonNode jsonNode) {
    if (jsonNode.isEmpty()) {
      return "{}";
    }
    return jsonNode.toString()
        .replace("\\u0000", "")
        .replace("\\", "\\\\")
        .replace(",", "\\,");
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import eu.dissco.core.translator.exception.DisscoRepositoryException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.jooq.DSLContext;
import org.jooq.Field;
import org.jooq.JSONB;
import org.jooq.Query;
import org.jooq.Record;
import org.jooq.Table;
import org.jooq.impl.DSL;
Expand All @@ -34,6 +33,7 @@ public class DwcaRepository {

private final ObjectMapper mapper;
private final DSLContext context;
private final BatchInserter batchInserter;

public void createTable(String tableName) {
context.createTable(tableName)
Expand All @@ -43,26 +43,13 @@ public void createTable(String tableName) {
context.createIndex().on(tableName, idField.getName()).execute();
}


// Wraps the raw table name in double quotes so jOOQ treats it as a quoted identifier.
private Table<Record> getTable(String tableName) {
  var quotedName = "\"" + tableName + "\"";
  return DSL.table(quotedName);
}

public void postRecords(String tableName, List<Pair<String, JsonNode>> dbRecords) {
var queries = dbRecords.stream().map(dbRecord -> recordToQuery(tableName, dbRecord)).filter(
Objects::nonNull).toList();
context.batch(queries).execute();
}

private Query recordToQuery(String tableName, Pair<String, JsonNode> dbRecord) {
try {
return context.insertInto(getTable(tableName)).set(idField, dbRecord.getLeft())
.set(dataField,
JSONB.jsonb(mapper.writeValueAsString(dbRecord.getRight()).replace("\\u0000", "")));
} catch (JsonProcessingException e) {
log.error("Unable to map JSON to JSONB, ignoring record: {}", dbRecord.getLeft(), e);
return null;
}
/**
 * Bulk-inserts the given (id, json) record pairs into the named staging table by delegating to
 * the COPY-based {@code BatchInserter}.
 *
 * @param tableName name of the (temp) table to copy into
 * @param dbRecords pairs of record id and JSON payload to persist
 * @throws DisscoRepositoryException if the batch copy fails
 */
public void postRecords(String tableName, List<Pair<String, JsonNode>> dbRecords)
throws DisscoRepositoryException {
batchInserter.batchCopy(tableName, dbRecords);
}

public Map<String, ObjectNode> getCoreRecords(List<String> batch, String tableName) {
Expand All @@ -89,6 +76,5 @@ public void deleteTable(String tableName) {
context.dropTableIfExists(tableName).execute();
}


}

23 changes: 15 additions & 8 deletions src/main/java/eu/dissco/core/translator/service/DwcaService.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import eu.dissco.core.translator.domain.DigitalSpecimenWrapper;
import eu.dissco.core.translator.domain.Enrichment;
import eu.dissco.core.translator.exception.DiSSCoDataException;
import eu.dissco.core.translator.exception.DisscoRepositoryException;
import eu.dissco.core.translator.exception.OrganisationException;
import eu.dissco.core.translator.properties.DwcaProperties;
import eu.dissco.core.translator.properties.EnrichmentProperties;
Expand Down Expand Up @@ -101,6 +102,8 @@ public void retrieveData() {
} catch (InterruptedException e) {
log.error("Failed during downloading file due to interruption", e);
Thread.currentThread().interrupt();
} catch (DisscoRepositoryException e) {
log.error("Failed during batch copy into temp tables with exception", e);
} finally {
if (archive != null) {
log.info("Cleaning up database tables");
Expand Down Expand Up @@ -329,14 +332,13 @@ private Collection<List<String>> prepareChunks(List<String> inputList, int chunk
}


private List<String> postArchiveToDatabase(Archive archive) {
private List<String> postArchiveToDatabase(Archive archive) throws DisscoRepositoryException {
var tableNames = generateTableNames(archive);
createTempTables(tableNames);
log.info("Created tables: {}", tableNames);
var idList = postCore(archive.getCore());
postExtensions(archive.getExtensions(), idList);
return idList;

}

private void removeTempTables(Archive archive) {
Expand All @@ -357,8 +359,13 @@ private List<String> generateTableNames(Archive archive) {

private String getTableName(ArchiveFile archiveFile) {
var fullSourceSystemId = webClientProperties.getSourceSystemId();
var minifiedSourceSystemId = fullSourceSystemId.substring(fullSourceSystemId.indexOf('/') + 1);
return minifiedSourceSystemId + "_" + archiveFile.getRowType().prefixedName();
var minifiedSourceSystemId = fullSourceSystemId.substring(fullSourceSystemId.indexOf('/') + 1)
.replace("-", "_");
var tableName = (minifiedSourceSystemId + "_" + archiveFile.getRowType()
.prefixedName()).toLowerCase()
.replace(":", "_");
tableName = tableName.replace("/", "_");
return tableName.replace(".", "_");
}

private void createTempTables(List<String> tableNames) {
Expand All @@ -367,7 +374,7 @@ private void createTempTables(List<String> tableNames) {
}
}

private ArrayList<String> postCore(ArchiveFile core) {
private ArrayList<String> postCore(ArchiveFile core) throws DisscoRepositoryException {
var dbRecords = new ArrayList<Pair<String, JsonNode>>();
var idList = new ArrayList<String>();
for (var rec : core) {
Expand All @@ -393,14 +400,14 @@ private ArrayList<String> postCore(ArchiveFile core) {
}

private void postToDatabase(ArchiveFile archiveFile,
ArrayList<Pair<String, JsonNode>> dbRecords) {
ArrayList<Pair<String, JsonNode>> dbRecords) throws DisscoRepositoryException {
log.info("Persisting {} records to database", dbRecords.size());
dwcaRepository.postRecords(getTableName(archiveFile), dbRecords);
dbRecords.clear();
}


private void postExtensions(Set<ArchiveFile> extensions, List<String> idsList) {
private void postExtensions(Set<ArchiveFile> extensions, List<String> idsList)
throws DisscoRepositoryException {
var dbRecords = new ArrayList<Pair<String, JsonNode>>();
for (var extension : extensions) {
log.info("Processing records of extension: {}", extension.getRowType().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class BaseRepositoryIT {
@Container
private static final PostgreSQLContainer<?> CONTAINER = new PostgreSQLContainer<>(POSTGIS);
protected DSLContext context;
private HikariDataSource dataSource;
protected HikariDataSource dataSource;

@BeforeEach
void prepareDatabase() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package eu.dissco.core.translator.repository;

import static eu.dissco.core.translator.TestUtils.MAPPER;
import static org.assertj.core.api.Assertions.assertThat;

import com.fasterxml.jackson.databind.JsonNode;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.jooq.Field;
import org.jooq.JSONB;
import org.jooq.Record;
import org.jooq.Table;
import org.jooq.impl.DSL;
import org.jooq.impl.SQLDataType;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.postgresql.copy.CopyManager;
import org.postgresql.core.BaseConnection;
import org.testcontainers.shaded.org.yaml.snakeyaml.events.Event.ID;

/**
 * Integration tests for {@link BatchInserter} against a real PostgreSQL container
 * (via {@code BaseRepositoryIT}).
 */
class BatchInserterTest extends BaseRepositoryIT {

  private static final String TABLE_NAME = "xxx_xxx_xxx_core";
  // ID_FIELD is now static final, consistent with the sibling constants (it was an instance
  // field despite the UPPER_SNAKE name; DSL.field is a pure factory so this is safe).
  private static final Field<String> ID_FIELD = DSL.field("dwcaid", String.class);
  private static final Field<JSONB> DATA_FIELD = DSL.field("data", JSONB.class);
  private static final String RECORD_ID = "11a8a4c6-3188-4305-9688-d68942f4038e";
  private static final String RECORD_ID_ALT = "32546f7b-f62a-4368-8c60-922f1cba4ab8";

  private BatchInserter batchInserter;
  // Held as a field so the connection opened in setup() can be closed in destroy();
  // previously it was a local variable and leaked on every test run.
  private Connection connection;

  @BeforeEach
  void setup() throws SQLException {
    connection = DriverManager.getConnection(dataSource.getJdbcUrl(), dataSource.getUsername(),
        dataSource.getPassword());
    var copyManager = new CopyManager((BaseConnection) connection);
    batchInserter = new BatchInserter(copyManager);
    context.createTable(TABLE_NAME)
        .column(ID_FIELD, SQLDataType.VARCHAR)
        .column(DATA_FIELD, SQLDataType.JSONB)
        .execute();
    context.createIndex().on(TABLE_NAME, ID_FIELD.getName()).execute();
  }

  @AfterEach
  void destroy() throws SQLException {
    context.dropTableIfExists(getTable(TABLE_NAME)).execute();
    // Fix: close the per-test connection instead of leaking it
    connection.close();
  }

  @Test
  void testBatchInsert() throws Exception {
    // Given
    var records = givenCoreRecords();
    var idField = context.meta().getTables(TABLE_NAME).get(0).field(ID_FIELD);

    // When
    batchInserter.batchCopy(TABLE_NAME, records);
    var result = context.select(getTable(TABLE_NAME).asterisk())
        .from(getTable(TABLE_NAME))
        .where(idField.eq(RECORD_ID))
        .fetchOne();

    // Then
    assertThat(MAPPER.readTree(result.get(DATA_FIELD).data())).isEqualTo(givenJsonNode());
  }

  @ParameterizedTest
  @MethodSource("badStrings")
  void testBadCharacters(String badString) throws Exception {
    // Given
    var node = MAPPER.createObjectNode();
    node.put("field", badString);
    var pair = List.of(Pair.of(RECORD_ID, (JsonNode) node));
    var idField = context.meta().getTables(TABLE_NAME).get(0).field(ID_FIELD);

    // When
    batchInserter.batchCopy(TABLE_NAME, pair);
    var result = context.select(getTable(TABLE_NAME).asterisk())
        .from(getTable(TABLE_NAME))
        .where(idField.eq(RECORD_ID))
        .fetchOne();

    // Then
    assertThat(MAPPER.readTree(result.get(DATA_FIELD).data())).isEqualTo(node);
  }

  // Renamed from testBadCharacters() to avoid a confusing overload of the parameterized test
  @Test
  void testNullCharacterIsStripped() throws Exception {
    // Given
    var node = MAPPER.createObjectNode();
    node.put("field", "\u0000");
    var pair = List.of(Pair.of(RECORD_ID, (JsonNode) node));
    var expected = MAPPER.readTree("""
        {
          "field":""
        }
        """);
    var idField = context.meta().getTables(TABLE_NAME).get(0).field(ID_FIELD);

    // When
    batchInserter.batchCopy(TABLE_NAME, pair);
    var result = context.select(getTable(TABLE_NAME).asterisk())
        .from(getTable(TABLE_NAME))
        .where(idField.eq(RECORD_ID))
        .fetchOne();

    // Then
    assertThat(MAPPER.readTree(result.get(DATA_FIELD).data())).isEqualTo(expected);
  }

  // Control characters and the COPY delimiter that must survive the round-trip intact.
  private static Stream<Arguments> badStrings() {
    return Stream.of(
        Arguments.of("bad \b string"),
        Arguments.of("bad \f string"),
        Arguments.of("bad \n string"),
        Arguments.of("bad \r string"),
        Arguments.of("bad \t string"),
        Arguments.of("bad, string"),
        Arguments.of("bad \\N string")
    );
  }

  private List<Pair<String, JsonNode>> givenCoreRecords() {
    var records = new ArrayList<Pair<String, JsonNode>>();
    records.add(Pair.of(RECORD_ID, givenJsonNode()));
    records.add(Pair.of(RECORD_ID_ALT, MAPPER.createObjectNode()));
    return records;
  }

  private JsonNode givenJsonNode() {
    var node = MAPPER.createObjectNode();
    node.put("test", "test");
    node.put("data", "value");
    return node;
  }

  private Table<Record> getTable(String tableName) {
    return DSL.table("\"" + tableName + "\"");
  }
}
Loading

0 comments on commit f8b6d4c

Please sign in to comment.