diff --git a/src/main/java/org/itsallcode/jdbc/BatchInsertBuilder.java b/src/main/java/org/itsallcode/jdbc/BatchInsertBuilder.java index 945bac9..c8640f2 100644 --- a/src/main/java/org/itsallcode/jdbc/BatchInsertBuilder.java +++ b/src/main/java/org/itsallcode/jdbc/BatchInsertBuilder.java @@ -4,6 +4,7 @@ import java.sql.PreparedStatement; import java.util.*; +import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Stream; @@ -17,15 +18,15 @@ public class BatchInsertBuilder { private static final Logger LOG = Logger.getLogger(BatchInsertBuilder.class.getName()); private static final int DEFAULT_MAX_BATCH_SIZE = 200_000; - private final SimpleConnection connection; + private final Function statementFactory; private final Context context; private String sql; private RowPreparedStatementSetter mapper; private Iterator rows; private int maxBatchSize = DEFAULT_MAX_BATCH_SIZE; - BatchInsertBuilder(final SimpleConnection connection, final Context context) { - this.connection = connection; + BatchInsertBuilder(final Function statementFactory, final Context context) { + this.statementFactory = statementFactory; this.context = context; } @@ -129,7 +130,7 @@ public void start() { Objects.requireNonNull(this.mapper, "mapper"); Objects.requireNonNull(this.rows, "rows"); LOG.finest(() -> "Running insert statement '" + sql + "'..."); - final SimplePreparedStatement statement = connection.prepareStatement(sql); + final SimplePreparedStatement statement = statementFactory.apply(sql); try (BatchInsert batch = new BatchInsert<>(statement, this.mapper, this.maxBatchSize)) { while (rows.hasNext()) { batch.add(rows.next()); diff --git a/src/main/java/org/itsallcode/jdbc/SimpleConnection.java b/src/main/java/org/itsallcode/jdbc/SimpleConnection.java index a2d1b87..058c35b 100644 --- a/src/main/java/org/itsallcode/jdbc/SimpleConnection.java +++ b/src/main/java/org/itsallcode/jdbc/SimpleConnection.java @@ -108,7 +108,7 @@ SimplePreparedStatement prepareStatement(final String sql) { * @return batch insert builder */ public BatchInsertBuilder batchInsert(final Class rowType) { - return new BatchInsertBuilder<>(this, context); + return new BatchInsertBuilder<>(this::prepareStatement, context); } private PreparedStatement prepare(final String sql) {