Skip to content

Commit

Permalink
Don't do commits/rollbacks on a transaction invalidated by Optimistic…
Browse files Browse the repository at this point in the history
…LockException or BadSessionException

If we received an OptimisticLockException or BadSession inside a transaction, this means that
the transaction was invalidated on the YDB side. Sending commit() or
rollback() after this is incorrect, since in this case they will always generate a
BadSessionException.

In order to write tests for this behavior that will work both on the
real YDB and on the mock, we have to add a check for read locks inside
the InMemoryRepository.

---------

Co-authored-by: Alexander Lavrukov <[email protected]>
  • Loading branch information
lavrukov and Alexander Lavrukov authored Jan 9, 2024
1 parent 62d9cce commit 9e0e385
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ public synchronized void rollback(long txId) {
}

@Nullable
public synchronized T find(long txId, long version, Entity.Id<T> id) {
public synchronized T find(long txId, long version, Entity.Id<T> id, InMemoryTxLockWatcher watcher) {
checkLocks(version, watcher);

InMemoryEntityLine entityLine = entityLines.get(id);
if (entityLine == null) {
return null;
Expand All @@ -105,7 +107,11 @@ public synchronized T find(long txId, long version, Entity.Id<T> id) {
}

@Nullable
public synchronized <V extends Table.View> V find(long txId, long version, Entity.Id<T> id, Class<V> viewType) {
public synchronized <V extends Table.View> V find(
long txId, long version, Entity.Id<T> id, Class<V> viewType, InMemoryTxLockWatcher watcher
) {
checkLocks(version, watcher);

InMemoryEntityLine entityLine = entityLines.get(id);
if (entityLine == null) {
return null;
Expand All @@ -114,7 +120,9 @@ public synchronized <V extends Table.View> V find(long txId, long version, Entit
return columns != null ? columns.toSchema(ViewSchema.of(viewType)) : null;
}

public synchronized List<T> findAll(long txId, long version) {
public synchronized List<T> findAll(long txId, long version, InMemoryTxLockWatcher watcher) {
checkLocks(version, watcher);

List<T> entities = new ArrayList<>();
for (InMemoryEntityLine entityLine : entityLines.values()) {
Columns columns = entityLine.get(txId, version);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tech.ydb.yoj.repository.db.cache.TransactionLocal;
import tech.ydb.yoj.repository.db.exception.IllegalTransactionIsolationLevelException;
import tech.ydb.yoj.repository.db.exception.IllegalTransactionScanException;
import tech.ydb.yoj.repository.db.exception.OptimisticLockException;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -34,8 +35,10 @@ public class InMemoryRepositoryTransaction implements BaseDb, RepositoryTransact
private final InMemoryTxLockWatcher watcher;
private final InMemoryStorage storage;

private boolean hasWrites = false;
private Long version = null;
private String closeAction = null; // used to detect of usage transaction after commit()/rollback()
private boolean isBadSession = false;

public InMemoryRepositoryTransaction(TxOptions options, InMemoryRepository repository) {
this.storage = repository.getStorage();
Expand All @@ -62,6 +65,9 @@ public final <T extends Entity<T>> InMemoryTable.DbMemory<T> getMemory(Class<T>

@Override
public void commit() {
if (isBadSession) {
throw new IllegalStateException("Transaction was invalidated. Commit isn't possible");
}
endTransaction("commit()", this::commitImpl);
}

Expand Down Expand Up @@ -125,6 +131,8 @@ final <T extends Entity<T>> void doInWriteTransaction(
Runnable query = () -> logTransaction(log, () -> {
WriteTxDataShard<T> shard = storage.getWriteTxDataShard(type, txId, getVersion());
consumer.accept(shard);

hasWrites = true;
});
if (options.isImmediateWrites()) {
query.run();
Expand All @@ -138,8 +146,14 @@ final <T extends Entity<T>, R> R doInTransaction(
String action, Class<T> type, Function<ReadOnlyTxDataShard<T>, R> func
) {
return logTransaction(action, () -> {
ReadOnlyTxDataShard<T> shard = storage.getReadOnlyTxDataShard(type, txId, getVersion());
return func.apply(shard);
InMemoryTxLockWatcher findWatcher = hasWrites ? watcher : InMemoryTxLockWatcher.NO_LOCKS;
ReadOnlyTxDataShard<T> shard = storage.getReadOnlyTxDataShard(type, txId, getVersion(), findWatcher);
try {
return func.apply(shard);
} catch (OptimisticLockException e) {
isBadSession = true;
throw e;
}
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,24 @@ public synchronized <T extends Entity<T>> WriteTxDataShard<T> getWriteTxDataShar
Class<T> type, long txId, long version
) {
uncommited.computeIfAbsent(txId, __ -> new HashSet<>()).add(type);
return getTxDataShard(type, txId, version);
return getTxDataShard(type, txId, version, InMemoryTxLockWatcher.NO_LOCKS);
}

public synchronized <T extends Entity<T>> ReadOnlyTxDataShard<T> getReadOnlyTxDataShard(
Class<T> type, long txId, long version
Class<T> type, long txId, long version, InMemoryTxLockWatcher watcher
) {
return getTxDataShard(type, txId, version);
return getTxDataShard(type, txId, version, watcher);
}

private <T extends Entity<T>> TxDataShardImpl<T> getTxDataShard(Class<T> type, long txId, long version) {
private <T extends Entity<T>> TxDataShardImpl<T> getTxDataShard(
Class<T> type, long txId, long version, InMemoryTxLockWatcher watcher
) {
@SuppressWarnings("unchecked")
InMemoryDataShard<T> shard = (InMemoryDataShard<T>) shards.get(type);
if (shard == null) {
throw new InMemoryRepositoryException("Table is not created: " + type.getSimpleName());
}
return new TxDataShardImpl<>(shard, txId, version);
return new TxDataShardImpl<>(shard, txId, version, watcher);
}

public synchronized void dropDb() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package tech.ydb.yoj.repository.test.inmemory;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import tech.ydb.yoj.repository.db.Entity;
import tech.ydb.yoj.repository.db.EntitySchema;
import tech.ydb.yoj.repository.db.Range;
Expand All @@ -11,9 +13,16 @@
import java.util.Map;
import java.util.Set;

@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public final class InMemoryTxLockWatcher {
private final Map<Class<?>, Set<Entity.Id<?>>> readRows = new HashMap<>();
private final Map<Class<?>, List<Range<?>>> readRanges = new HashMap<>();
public static final InMemoryTxLockWatcher NO_LOCKS = new InMemoryTxLockWatcher(Map.of(), Map.of());

private final Map<Class<?>, Set<Entity.Id<?>>> readRows;
private final Map<Class<?>, List<Range<?>>> readRanges;

public InMemoryTxLockWatcher() {
this(new HashMap<>(), new HashMap<>());
}

public <T extends Entity<T>> void markRowRead(Class<T> type, Entity.Id<T> id) {
readRows.computeIfAbsent(type, __ -> new HashSet<>()).add(id);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,34 @@
package tech.ydb.yoj.repository.test.inmemory;

import lombok.RequiredArgsConstructor;
import tech.ydb.yoj.repository.db.Entity;
import tech.ydb.yoj.repository.db.Table;

import javax.annotation.Nullable;
import java.util.List;

@RequiredArgsConstructor
final class TxDataShardImpl<T extends Entity<T>> implements ReadOnlyTxDataShard<T>, WriteTxDataShard<T> {
private final InMemoryDataShard<T> shard;
private final long txId;
private final long version;

public TxDataShardImpl(InMemoryDataShard<T> shard, long txId, long version) {
this.shard = shard;
this.txId = txId;
this.version = version;
}
private final InMemoryTxLockWatcher watcher;

@Nullable
@Override
public T find(Entity.Id<T> id) {
return shard.find(txId, version, id);
return shard.find(txId, version, id, watcher);
}

@Nullable
@Override
public <V extends Table.View> V find(Entity.Id<T> id, Class<V> viewType) {
return shard.find(txId, version, id, viewType);
return shard.find(txId, version, id, viewType, watcher);
}

@Override
public List<T> findAll() {
return shard.findAll(txId, version);
return shard.findAll(txId, version, watcher);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,88 @@ public void readTableViews() {
.isThrownBy(() -> db.tx(() -> db.typeFreaks().readTableIds(ReadTableParams.getDefault()).count()));
}

@Test
public void doNotCommitAfterTLI() {
Project.Id id1 = new Project.Id("id1");
Project.Id id2 = new Project.Id("id2");

RepositoryTransaction tx = repository.startTransaction(
TxOptions.create(IsolationLevel.SERIALIZABLE_READ_WRITE)
.withImmediateWrites(true)
.withFirstLevelCache(false)
);

tx.table(Project.class).find(id2);

db.tx(() -> db.projects().save(new Project(id2, "name2")));

tx.table(Project.class).save(new Project(id1, "name1")); // make tx available for TLI

assertThatExceptionOfType(OptimisticLockException.class)
.isThrownBy(() -> tx.table(Project.class).find(id2));

assertThatExceptionOfType(IllegalStateException.class)
.isThrownBy(tx::commit);

tx.rollback(); // YOJ-tx rollback is possible. session.rollbackCommit() won't execute
}

@Test
public void writeDontProduceTLI() {
Project.Id id = new Project.Id("id");

db.tx(() -> db.projects().save(new Project(id, "name")));

RepositoryTransaction tx = repository.startTransaction(
TxOptions.create(IsolationLevel.SERIALIZABLE_READ_WRITE)
.withImmediateWrites(true)
.withFirstLevelCache(false)
);

tx.table(Project.class).find(id);

db.tx(() -> {
db.projects().find(id);
db.projects().save(new Project(id, "name2"));
});

// write don't produce TLI
tx.table(Project.class).save(new Project(id, "name3"));

assertThatExceptionOfType(OptimisticLockException.class)
.isThrownBy(tx::commit);
}

@Test
public void consistencyCheckAllColumnsOnFind() {
Project.Id id1 = new Project.Id("id1");
Project.Id id2 = new Project.Id("id2");

db.tx(() -> {
db.projects().save(new Project(id1, "name"));
db.projects().save(new Project(id2, "name"));
});

RepositoryTransaction tx = repository.startTransaction(
TxOptions.create(IsolationLevel.SERIALIZABLE_READ_WRITE)
.withImmediateWrites(true)
.withFirstLevelCache(false)
);

tx.table(Project.class).save(new Project(new Project.Id("id3"), "name")); // make tx available for TLI

tx.table(Project.class).find(id1);
tx.table(Project.class).find(id2);

db.tx(() -> {
db.projects().find(id2);
db.projects().save(new Project(id2, "name2"));
});

assertThatExceptionOfType(OptimisticLockException.class)
.isThrownBy(() -> tx.table(Project.class).find(id1));
}

@Test
public void streamAllWithPartitioning() {
db.tx(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.yandex.ydb.ValueProtos;
import com.yandex.ydb.core.Result;
import com.yandex.ydb.core.Status;
import com.yandex.ydb.core.StatusCode;
import com.yandex.ydb.table.Session;
import com.yandex.ydb.table.query.DataQueryResult;
import com.yandex.ydb.table.query.Params;
Expand Down Expand Up @@ -36,12 +37,15 @@
import tech.ydb.yoj.repository.db.cache.TransactionLocal;
import tech.ydb.yoj.repository.db.exception.IllegalTransactionIsolationLevelException;
import tech.ydb.yoj.repository.db.exception.IllegalTransactionScanException;
import tech.ydb.yoj.repository.db.exception.OptimisticLockException;
import tech.ydb.yoj.repository.db.exception.RepositoryException;
import tech.ydb.yoj.repository.db.exception.UnavailableException;
import tech.ydb.yoj.repository.db.readtable.ReadTableParams;
import tech.ydb.yoj.repository.ydb.bulk.BulkMapper;
import tech.ydb.yoj.repository.ydb.client.ResultSetConverter;
import tech.ydb.yoj.repository.ydb.client.YdbConverter;
import tech.ydb.yoj.repository.ydb.client.YdbValidator;
import tech.ydb.yoj.repository.ydb.exception.BadSessionException;
import tech.ydb.yoj.repository.ydb.exception.ResultTruncatedException;
import tech.ydb.yoj.repository.ydb.exception.UnexpectedException;
import tech.ydb.yoj.repository.ydb.exception.YdbComponentUnavailableException;
Expand All @@ -65,7 +69,6 @@
import static com.google.common.base.Strings.emptyToNull;
import static java.lang.String.format;
import static java.util.stream.Collectors.toList;
import static tech.ydb.yoj.repository.ydb.client.YdbValidator.validate;
import static tech.ydb.yoj.repository.ydb.client.YdbValidator.validatePkConstraint;
import static tech.ydb.yoj.repository.ydb.client.YdbValidator.validateTruncatedResults;

Expand All @@ -89,6 +92,7 @@ public class YdbRepositoryTransaction<REPO extends YdbRepository>
protected String txId = null;
private String firstNonNullTxId = null; // used for logs
private String closeAction = null; // used to detect of usage transaction after commit()/rollback()
private boolean isBadSession = false;

public YdbRepositoryTransaction(REPO repo, @NonNull TxOptions options) {
this.repo = repo;
Expand All @@ -110,6 +114,9 @@ public <T extends Entity<T>> Table<T> table(Class<T> c) {

@Override
public void commit() {
if (isBadSession) {
throw new IllegalStateException("Transaction was invalidated. Commit isn't possible");
}
try {
flushPendingWrites();
} catch (Throwable t) {
Expand Down Expand Up @@ -162,8 +169,18 @@ private void closeStreams() {
}
}

private void validate(String request, StatusCode statusCode, String response) {
try {
YdbValidator.validate(request, statusCode, response);
} catch (BadSessionException | OptimisticLockException e) {
transactionLocal.log().info("Request got %s: DB tx was invalidated", e.getClass().getSimpleName());
isBadSession = true;
throw e;
}
}

private boolean isFinalActionNeeded(String actionName) {
if (session == null) {
if (session == null || isBadSession) {
transactionLocal.log().info("No-op %s: no active DB session", actionName);
return false;
}
Expand Down Expand Up @@ -430,7 +447,7 @@ public <IN> void bulkUpsert(BulkMapper<IN> mapper, List<IN> input, BulkParams pa
} catch (RepositoryException e) {
throw e;
} catch (Exception e) {
throw new UnexpectedException("Could not bulk insert into table " + tableName);
throw new UnexpectedException("Could not bulk insert into table " + tableName, e);
}
});
}
Expand Down Expand Up @@ -489,7 +506,7 @@ public <PARAMS, RESULT> Stream<RESULT> readTable(ReadTableMapper<PARAMS, RESULT>
} catch (RepositoryException e) {
throw e;
} catch (Exception e) {
throw new UnexpectedException("Could not read table " + tableName);
throw new UnexpectedException("Could not read table " + tableName, e);
}
}

Expand Down
Loading

0 comments on commit 9e0e385

Please sign in to comment.