Skip to content

Commit

Permalink
Remove static find() method from YqlStatement (#102)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Lavrukov <[email protected]>
  • Loading branch information
lavrukov and Alexander Lavrukov authored Dec 4, 2024
1 parent 51623c8 commit 0bee8ca
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,48 @@
import tech.ydb.yoj.databind.schema.Schema;
import tech.ydb.yoj.repository.db.Entity;
import tech.ydb.yoj.repository.db.EntitySchema;
import tech.ydb.yoj.repository.ydb.yql.YqlPredicate;
import tech.ydb.yoj.repository.ydb.yql.YqlOrderBy;
import tech.ydb.yoj.repository.ydb.yql.YqlStatementPart;

import java.util.ArrayList;
import java.util.Collection;
import java.util.function.Function;
import java.util.List;

import static java.util.Comparator.comparing;
import static java.util.stream.Collectors.joining;

public class FindStatement<ENTITY extends Entity<ENTITY>, RESULT> extends PredicateStatement<Collection<? extends YqlStatementPart<?>>, ENTITY, RESULT> {
private final boolean distinct;
private final Collection<? extends YqlStatementPart<?>> parts;
private final List<YqlStatementPart<?>> parts;

public FindStatement(
public static <E extends Entity<E>, R> FindStatement<E, R> from(
@NonNull EntitySchema<E> schema,
@NonNull Schema<R> outSchema,
@NonNull Collection<? extends YqlStatementPart<?>> parts,
boolean distinct
) {
List<YqlStatementPart<?>> partsList = new ArrayList<>(parts);
if (!distinct) {
if (parts.stream().noneMatch(s -> s.getType().equals(YqlOrderBy.TYPE))) {
partsList.add(ORDER_BY_ID_ASCENDING);
}
}

return new FindStatement<>(schema, outSchema, partsList, distinct);
}

private FindStatement(
@NonNull EntitySchema<ENTITY> schema,
@NonNull Schema<RESULT> outSchema,
@NonNull Collection<? extends YqlStatementPart<?>> parts,
@NonNull Function<Collection<? extends YqlStatementPart<?>>, YqlPredicate> predicateFrom,
boolean distinct,
String tableName) {
super(schema, outSchema, parts, predicateFrom, tableName);
@NonNull List<YqlStatementPart<?>> parts,
boolean distinct
) {
super(schema, outSchema, parts, YqlStatement::predicateFrom);
this.distinct = distinct;
this.parts = parts;
}

@Override
public String getQuery(String tablespace) {
return declarations()
+ "SELECT " + (distinct ? "DISTINCT " : "") + outNames()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ public FindYqlStatement(@NonNull EntitySchema<ENTITY> schema, @NonNull Schema<RE
super(schema, resultSchema);
}

public FindYqlStatement(@NonNull EntitySchema<ENTITY> schema, @NonNull Schema<RESULT> resultSchema, String tableName) {
super(schema, resultSchema, tableName);
}

@Override
public List<YqlStatementParam> getParams() {
return schema.flattenId().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import tech.ydb.yoj.repository.ydb.yql.YqlStatementPart;
import tech.ydb.yoj.repository.ydb.yql.YqlType;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -77,26 +76,6 @@ public static <PARAMS, ENTITY extends Entity<ENTITY>> Statement<PARAMS, ENTITY>
return new UpsertYqlStatement<>(type);
}

public static <PARAMS, ENTITY extends Entity<ENTITY>> Statement<PARAMS, ENTITY> find(
Class<ENTITY> type
) {
EntitySchema<ENTITY> schema = EntitySchema.of(type);
return find(schema, schema);
}

public static <PARAMS, ENTITY extends Entity<ENTITY>, VIEW extends View> Statement<PARAMS, VIEW> find(
Class<ENTITY> type,
Class<VIEW> viewType
) {
return find(EntitySchema.of(type), ViewSchema.of(viewType));
}

private static <PARAMS, ENTITY extends Entity<ENTITY>, RESULT> Statement<PARAMS, RESULT> find(
EntitySchema<ENTITY> schema,
Schema<RESULT> resultSchema) {
return new FindYqlStatement<>(schema, resultSchema);
}

public static <ENTITY extends Entity<ENTITY>, ID extends Entity.Id<ENTITY>> Statement<Range<ID>, ENTITY> findRange(
Class<ENTITY> type,
Range<ID> range
Expand Down Expand Up @@ -188,38 +167,6 @@ private static <PARAMS, ENTITY extends Entity<ENTITY>, RESULT> Statement<PARAMS,
return new FindAllYqlStatement<>(schema, outSchema);
}

public static <ENTITY extends Entity<ENTITY>> Statement<Collection<? extends YqlStatementPart<?>>, ENTITY> find(
Class<ENTITY> type,
Collection<? extends YqlStatementPart<?>> parts
) {
EntitySchema<ENTITY> schema = EntitySchema.of(type);
return find(schema, schema, false, parts);
}

public static <ENTITY extends Entity<ENTITY>, VIEW extends View> Statement<Collection<? extends YqlStatementPart<?>>, VIEW> find(
Class<ENTITY> type,
Class<VIEW> viewType,
Collection<? extends YqlStatementPart<?>> parts
) {
return find(type, viewType, false, parts);
}

public static <ENTITY extends Entity<ENTITY>, VIEW extends View> Statement<Collection<? extends YqlStatementPart<?>>, VIEW> find(
Class<ENTITY> type,
Class<VIEW> viewType,
boolean distinct,
Collection<? extends YqlStatementPart<?>> parts
) {
return find(EntitySchema.of(type), ViewSchema.of(viewType), distinct, parts);
}

public static <ENTITY extends Entity<ENTITY>, ID extends Entity.Id<ENTITY>> Statement<Collection<? extends YqlStatementPart<?>>, ID> findIds(
Class<ENTITY> type,
Collection<? extends YqlStatementPart<?>> parts
) {
return find(EntitySchema.of(type), EntityIdSchema.ofEntity(type), false, parts);
}

public static <ENTITY extends Entity<ENTITY>, ID extends Entity.Id<ENTITY>> Statement<Range<ID>, ID> findIds(
Class<ENTITY> type,
Range<ID> range
Expand Down Expand Up @@ -247,31 +194,6 @@ public String getDeclaration(String name, String type) {
return String.format("DECLARE %s AS %s;\n", name, type);
}

private static <ENTITY extends Entity<ENTITY>, RESULT> Statement<Collection<? extends YqlStatementPart<?>>, RESULT> find(
EntitySchema<ENTITY> schema,
Schema<RESULT> resultSchema,
boolean distinct,
Collection<? extends YqlStatementPart<?>> parts
) {
return find(schema, resultSchema, distinct, parts, schema.getName());
}

static <ENTITY extends Entity<ENTITY>, RESULT> Statement<Collection<? extends YqlStatementPart<?>>, RESULT> find(
EntitySchema<ENTITY> schema,
Schema<RESULT> resultSchema,
boolean distinct,
Collection<? extends YqlStatementPart<?>> parts,
String tableName
) {
List<YqlStatementPart<?>> partList = new ArrayList<>(parts);
if (!distinct) {
if (parts.stream().noneMatch(s -> s.getType().equals(YqlOrderBy.TYPE))) {
partList.add(ORDER_BY_ID_ASCENDING);
}
}
return new FindStatement<>(schema, resultSchema, parts, YqlStatement::predicateFrom, distinct, tableName);
}

public static <ENTITY extends Entity<ENTITY>> Statement<Collection<? extends YqlStatementPart<?>>, Count> count(
Class<ENTITY> entityType,
Collection<? extends YqlStatementPart<?>> parts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import tech.ydb.yoj.repository.ydb.bulk.BulkMapperImpl;
import tech.ydb.yoj.repository.ydb.readtable.EntityIdKeyMapper;
import tech.ydb.yoj.repository.ydb.readtable.ReadTableMapper;
import tech.ydb.yoj.repository.ydb.statement.FindStatement;
import tech.ydb.yoj.repository.ydb.statement.FindYqlStatement;
import tech.ydb.yoj.repository.ydb.statement.Statement;
import tech.ydb.yoj.repository.ydb.statement.UpdateInStatement;
import tech.ydb.yoj.repository.ydb.statement.UpdateModel;
Expand Down Expand Up @@ -223,14 +225,17 @@ public T find(Entity.Id<T> id) {
throw new IllegalArgumentException("Cannot use partial id in find method");
}
return executor.getTransactionLocal().firstLevelCache().get(id, __ -> {
List<T> res = postLoad(executor.execute(YqlStatement.find(type), id));
var statement = new FindYqlStatement<>(schema, schema);
List<T> res = postLoad(executor.execute(statement, id));
return res.isEmpty() ? null : res.get(0);
});
}

@Override
public <V extends View> V find(Class<V> viewType, Entity.Id<T> id) {
List<V> res = executor.execute(YqlStatement.find(type, viewType), id);
ViewSchema<V> viewSchema = ViewSchema.of(viewType);
var statement = new FindYqlStatement<>(schema, viewSchema);
List<V> res = executor.execute(statement, id);
return res.isEmpty() ? null : res.get(0);
}

Expand All @@ -254,7 +259,8 @@ public final List<T> find(YqlStatementPart<?> part, YqlStatementPart<?>... other
}

public List<T> find(Collection<? extends YqlStatementPart<?>> parts) {
return postLoad(executor.execute(YqlStatement.find(type, parts), parts));
var statement = FindStatement.from(schema, schema, parts, false);
return postLoad(executor.execute(statement, parts));
}

@Override
Expand Down Expand Up @@ -385,15 +391,19 @@ public <V extends View> List<V> find(Class<V> viewType, YqlStatementPart<?> part
}

public <V extends View> List<V> find(Class<V> viewType, Collection<? extends YqlStatementPart<?>> parts, boolean distinct) {
return executor.execute(YqlStatement.find(type, viewType, distinct, parts), parts);
ViewSchema<V> viewSchema = ViewSchema.of(viewType);
var statement = FindStatement.from(schema, viewSchema, parts, distinct);
return executor.execute(statement, parts);
}

public <ID extends Entity.Id<T>> List<ID> findIds(YqlStatementPart<?> part, YqlStatementPart<?>... otherParts) {
return findIds(toList(part, otherParts));
}

private <ID extends Entity.Id<T>> List<ID> findIds(Collection<? extends YqlStatementPart<?>> parts) {
return executor.execute(YqlStatement.findIds(type, parts), parts);
EntityIdSchema<ID> idSchema = EntityIdSchema.ofEntity(type);
var statement = FindStatement.from(schema, idSchema, parts, false);
return executor.execute(statement, parts);
}

@Override
Expand Down Expand Up @@ -450,7 +460,8 @@ public void delete(Entity.Id<T> id) {
* @param <ID> entity ID type
*/
public <ID extends Id<T>> void migrate(ID id) {
List<T> foundRaw = executor.execute(YqlStatement.find(type), id);
var statement = new FindYqlStatement<>(schema, schema);
List<T> foundRaw = executor.execute(statement, id);
if (foundRaw.isEmpty()) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import tech.ydb.table.query.DataQueryResult;
import tech.ydb.table.query.Params;
import tech.ydb.table.values.StructType;
import tech.ydb.yoj.repository.db.EntitySchema;
import tech.ydb.yoj.repository.db.Range;
import tech.ydb.yoj.repository.db.exception.EntityAlreadyExistsException;
import tech.ydb.yoj.repository.test.sample.TestEntityOperations;
import tech.ydb.yoj.repository.test.sample.model.Complex;
import tech.ydb.yoj.repository.test.sample.model.Complex.Id;
import tech.ydb.yoj.repository.ydb.client.SessionManager;
import tech.ydb.yoj.repository.ydb.client.YdbConverter;
import tech.ydb.yoj.repository.ydb.statement.FindYqlStatement;
import tech.ydb.yoj.repository.ydb.statement.MultipleVarsYqlStatement;
import tech.ydb.yoj.repository.ydb.statement.YqlStatement;

Expand Down Expand Up @@ -350,6 +352,8 @@ private CompletableFuture<Result<DataQueryResult>> convertEntity(List<Complex> c
}

private Params convertId(Id id) {
return YdbConverter.convertToParams(YqlStatement.find(Complex.class).toQueryParameters(id));
EntitySchema<Complex> schema = EntitySchema.of(Complex.class);
var statement = new FindYqlStatement<>(schema, schema);
return YdbConverter.convertToParams(statement.toQueryParameters(id));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import tech.ydb.yoj.repository.ydb.sample.model.HintInt64Range;
import tech.ydb.yoj.repository.ydb.sample.model.HintTablePreset;
import tech.ydb.yoj.repository.ydb.sample.model.HintUniform;
import tech.ydb.yoj.repository.ydb.statement.FindStatement;
import tech.ydb.yoj.repository.ydb.statement.YqlStatement;
import tech.ydb.yoj.repository.ydb.table.YdbTable;
import tech.ydb.yoj.repository.ydb.yql.YqlPredicate;
Expand Down Expand Up @@ -888,12 +889,13 @@ public void complexIdLtYsingYqlPredicate() {

private void executeQuery(String expectSqlQuery, List<IndexedEntity> expectRows,
Collection<? extends YqlStatementPart<?>> query) {
var statement = YqlStatement.find(IndexedEntity.class, query);
EntitySchema<IndexedEntity> schema = EntitySchema.of(IndexedEntity.class);
var statement = FindStatement.from(schema, schema, new ArrayList<>(query), false);
var sqlQuery = statement.getQuery("ts/");
assertEquals(expectSqlQuery, sqlQuery);

// Check we use index and query was not failed
var actual = db.tx(() -> ((YdbTable<IndexedEntity>) db.indexedTable()).find(query));
var actual = db.tx(() -> ((YdbTable<IndexedEntity>) db.indexedTable()).find(new ArrayList<>(query)));
assertEquals(expectRows, actual);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import org.assertj.core.api.Assertions;
import org.junit.Test;
import tech.ydb.yoj.repository.db.Entity;
import tech.ydb.yoj.repository.db.EntitySchema;
import tech.ydb.yoj.repository.db.cache.RepositoryCache;
import tech.ydb.yoj.repository.db.cache.RepositoryCacheImpl;
import tech.ydb.yoj.repository.test.sample.model.Primitive;
import tech.ydb.yoj.repository.test.sample.model.Project;
import tech.ydb.yoj.repository.ydb.YdbRepository;
import tech.ydb.yoj.repository.ydb.statement.FindYqlStatement;
import tech.ydb.yoj.repository.ydb.statement.Statement;
import tech.ydb.yoj.repository.ydb.statement.YqlStatement;

Expand Down Expand Up @@ -154,7 +156,9 @@ private <T extends Entity<T>> YdbRepository.Query<?> insert(T p) {

@SuppressWarnings("unchecked")
private <T extends Entity<T>> YdbRepository.Query<?> find(T p) {
return new YdbRepository.Query<>(YqlStatement.find((Class<T>) p.getClass()), p.getId());
EntitySchema<T> schema = EntitySchema.of((Class<T>) p.getClass());
var statement = new FindYqlStatement<>(schema, schema);
return new YdbRepository.Query<>(statement, p.getId());
}

@SuppressWarnings("unchecked")
Expand Down

0 comments on commit 0bee8ca

Please sign in to comment.