Skip to content

Commit

Permalink
Fix nested preloads with joined through associations (#4341)
Browse files Browse the repository at this point in the history
  • Loading branch information
greg-rychlewski authored Dec 21, 2023
1 parent 7b1695f commit b26b25b
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 44 deletions.
59 changes: 59 additions & 0 deletions integration_test/cases/joins.exs
Original file line number Diff line number Diff line change
Expand Up @@ -589,4 +589,63 @@ defmodule Ecto.Integration.JoinsTest do
assert [post] = TestRepo.all(query)
assert post.post_user_composite_pk
end

test "joining a through association with a nested preloads" do
post = TestRepo.insert!(%Post{title: "1"})
user = TestRepo.insert!(%User{name: "1"})
TestRepo.insert!(%Comment{text: "1", post_id: post.id})
TestRepo.insert!(%Permalink{post_id: post.id, user_id: user.id})

query =
from c in Comment,
join: pp in assoc(c, :post_permalink),
join: u in assoc(pp, :user),
preload: [post_permalink: {pp, [:post, user: u]}]

[comment] = TestRepo.all(query)

assert not Ecto.assoc_loaded?(comment.post)
assert %Permalink{user: %User{}, post: %Post{}} = comment.post_permalink
end

test "joining multiple through associations with a nested preloads" do
post = TestRepo.insert!(%Post{title: "1"})
user = TestRepo.insert!(%User{name: "1"})
TestRepo.insert!(%Comment{text: "1", post_id: post.id, author_id: user.id})
TestRepo.insert!(%Permalink{post_id: post.id, user_id: user.id})

query =
from c in Comment,
join: pp in assoc(c, :post_permalink),
join: ap in assoc(c, :author_permalink),
join: u1 in assoc(pp, :user),
join: u2 in assoc(ap, :user),
preload: [post_permalink: {pp, [:post, user: u1]}, author_permalink: {ap, [:post, user: u2]}]

[comment] = TestRepo.all(query)

assert not Ecto.assoc_loaded?(comment.post)
assert not Ecto.assoc_loaded?(comment.author)
assert %Permalink{user: %User{}, post: %Post{}} = comment.post_permalink
assert %Permalink{user: %User{}, post: %Post{}} = comment.author_permalink
end

test "joining nested through associations with a nested preloads" do
user = TestRepo.insert!(%User{name: "1"})
post = TestRepo.insert!(%Post{title: "1", author_id: user.id})
TestRepo.insert!(%Comment{text: "1", post_id: post.id})
TestRepo.insert!(%Permalink{post_id: post.id, user_id: user.id})

query =
from c in Comment,
join: pp in assoc(c, :post_permalink),
join: up in assoc(pp, :user_posts),
preload: [post_permalink: {pp, [:post, user_posts: {up, :comments}]}]

[comment] = TestRepo.all(query)

assert not Ecto.assoc_loaded?(comment.post)
assert %Permalink{post: %Post{}, user_posts: [%Post{}]} = comment.post_permalink
assert not Ecto.assoc_loaded?(comment.post_permalink.user)
end
end
2 changes: 2 additions & 0 deletions integration_test/support/schemas.exs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ defmodule Ecto.Integration.Comment do
belongs_to :post, Ecto.Integration.Post
belongs_to :author, Ecto.Integration.User
has_one :post_permalink, through: [:post, :permalink]
has_one :author_permalink, through: [:author, :permalink]
end

def changeset(schema, params) do
Expand All @@ -124,6 +125,7 @@ defmodule Ecto.Integration.Permalink do
belongs_to :update_post, Ecto.Integration.Post, on_replace: :update, foreign_key: :post_id, define_field: false
belongs_to :user, Ecto.Integration.User
has_many :post_comments_authors, through: [:post, :comments_authors]
has_many :user_posts, through: [:user, :posts]
end

def changeset(schema, params) do
Expand Down
139 changes: 96 additions & 43 deletions lib/ecto/repo/preloader.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ defmodule Ecto.Repo.Preloader do
Transforms a result set based on query preloads, loading
the associations onto their parent schema.
"""
@spec query([list], Ecto.Repo.t, list, Access.t, fun, {adapter_meta :: map, opts :: Keyword.t}) :: [list]
def query([], _repo_name, _preloads, _take, _fun, _tuplet), do: []
def query(rows, _repo_name, [], _take, fun, _tuplet), do: Enum.map(rows, fun)
@spec query([list], Ecto.Repo.t, list, Access.t, list, fun, {adapter_meta :: map, opts :: Keyword.t}) :: [list]
def query([], _repo_name, _preloads, _take, _assocs, _fun, _tuplet), do: []
def query(rows, _repo_name, [], _take, _assocs, fun, _tuplet), do: Enum.map(rows, fun)

def query(rows, repo_name, preloads, take, assocs, fun, tuplet) do
assocs = normalize_query_assocs(assocs)

def query(rows, repo_name, preloads, take, fun, tuplet) do
rows
|> extract()
|> normalize_and_preload_each(repo_name, preloads, take, tuplet)
|> normalize_and_preload_each(repo_name, preloads, take, assocs, tuplet)
|> unextract(rows, fun)
end

Expand All @@ -41,16 +43,16 @@ defmodule Ecto.Repo.Preloader do
end

def preload(structs, repo_name, preloads, {_adapter_meta, opts} = tuplet) when is_list(structs) do
normalize_and_preload_each(structs, repo_name, preloads, opts[:take], tuplet)
normalize_and_preload_each(structs, repo_name, preloads, opts[:take], %{}, tuplet)
end

def preload(struct, repo_name, preloads, {_adapter_meta, opts} = tuplet) when is_map(struct) do
normalize_and_preload_each([struct], repo_name, preloads, opts[:take], tuplet) |> hd()
normalize_and_preload_each([struct], repo_name, preloads, opts[:take], %{}, tuplet) |> hd()
end

defp normalize_and_preload_each(structs, repo_name, preloads, take, tuplet) do
defp normalize_and_preload_each(structs, repo_name, preloads, take, query_assocs, tuplet) do
preloads = normalize(preloads, take, preloads)
preload_each(structs, repo_name, preloads, tuplet)
preload_each(structs, repo_name, preloads, query_assocs, tuplet)
rescue
e ->
# Reraise errors so we ignore the preload inner stacktrace
Expand All @@ -59,21 +61,21 @@ defmodule Ecto.Repo.Preloader do

## Preloading

defp preload_each(structs, _repo_name, [], _tuplet), do: structs
defp preload_each([], _repo_name, _preloads, _tuplet), do: []
defp preload_each(structs, repo_name, preloads, tuplet) do
defp preload_each(structs, _repo_name, [], _query_assocs, _tuplet), do: structs
defp preload_each([], _repo_name, _preloads, _query_assocs, _tuplet), do: []
defp preload_each(structs, repo_name, preloads, query_assocs, tuplet) do
if sample = Enum.find(structs, & &1) do
module = sample.__struct__
prefix = preload_prefix(tuplet, sample)
{assocs, throughs, embeds} = expand(module, preloads, {%{}, %{}, []})
{assocs, throughs, embeds} = expand(module, preloads, query_assocs, {%{}, [], []})
structs = preload_embeds(structs, embeds, repo_name, tuplet)
structs = preload_throughs(structs, throughs, repo_name, query_assocs, tuplet)

{fetched_assocs, to_fetch_queries} =
prepare_queries(structs, module, assocs, prefix, repo_name, tuplet)

fetched_queries = maybe_pmap(to_fetch_queries, repo_name, tuplet)
assocs = preload_assocs(fetched_assocs, fetched_queries, repo_name, tuplet)
throughs = Map.values(throughs)
assocs = preload_assocs(fetched_assocs, fetched_queries, repo_name, query_assocs, tuplet)

for struct <- structs do
struct = Enum.reduce assocs, struct, &load_assoc/2
Expand Down Expand Up @@ -148,23 +150,24 @@ defmodule Ecto.Repo.Preloader do

# Then we unpack the query results, merge them, and preload recursively
defp preload_assocs(
[{assoc, query?, loaded_ids, loaded_structs, preloads} | assocs],
[{assoc, query?, loaded_ids, loaded_structs, sub_preloads} | assocs],
queries,
repo_name,
query_assocs,
tuplet
) do
{fetch_ids, fetch_structs, queries} = maybe_unpack_query(query?, queries)
all = preload_each(Enum.reverse(loaded_structs, fetch_structs), repo_name, preloads, tuplet)
sub_query_assocs = Map.get(query_assocs, assoc.field, %{})
all = preload_each(Enum.reverse(loaded_structs, fetch_structs), repo_name, sub_preloads, sub_query_assocs, tuplet)
entry = {:assoc, assoc, assoc_map(assoc.cardinality, Enum.reverse(loaded_ids, fetch_ids), all)}
[entry | preload_assocs(assocs, queries, repo_name, tuplet)]
[entry | preload_assocs(assocs, queries, repo_name, query_assocs, tuplet)]
end

defp preload_assocs([], [], _repo_name, _tuplet), do: []
defp preload_assocs([] = _assocs, [] = _queries, _, _, _), do: []

defp preload_embeds(structs, [], _repo_name, _tuplet), do: structs
defp preload_embeds(structs, [] = _embeds, _, _), do: structs

defp preload_embeds(structs, [embed | embeds], repo_name, tuplet) do

{%{field: field, cardinality: card}, sub_preloads} = embed

{embed_structs, counts} =
Expand All @@ -176,23 +179,59 @@ defmodule Ecto.Repo.Preloader do
struct, _counts -> raise ArgumentError, "expected #{inspect(struct)} to contain embed `#{field}`"
end)

embed_structs = preload_each(embed_structs, repo_name, sub_preloads, tuplet)
structs = load_embeds(card, field, structs, embed_structs, Enum.reverse(counts), [])
# It is not possible for an embed to be preloaded through Ecto.Query.preload
# Therefore, we don't consider associations coming from queries
embed_structs = preload_each(embed_structs, repo_name, sub_preloads, %{}, tuplet)
structs = put_through_or_embed(card, field, structs, embed_structs, Enum.reverse(counts), [])
preload_embeds(structs, embeds, repo_name, tuplet)
end

defp load_embeds(_card, _field, [], [], [], acc), do: Enum.reverse(acc)
defp preload_throughs(structs, [] = _throughs, _, _, _), do: structs

defp load_embeds(card, field, [struct | structs], embed_structs, [0 | counts], acc),
do: load_embeds(card, field, structs, embed_structs, counts, [struct | acc])
defp preload_throughs(
structs,
[{_, _, false = _from_query?} | throughs],
repo_name,
query_assocs,
tuplet
) do
# Through associations will not be preloaded directly unless they were
# loaded through a join using Ecto.Query.preload. When using Ecto.Repo.preload
# or Ecto.Query.preload where the through association is not part of a join,
# the chain of associations making up the through association are preloaded instead.
preload_throughs(structs, throughs, repo_name, query_assocs, tuplet)
end

defp load_embeds(:one, field, [struct | structs], [embed_struct | embed_structs], [1 | counts], acc),
do: load_embeds(:one, field, structs, embed_structs, counts, [Map.put(struct, field, embed_struct) | acc])
defp preload_throughs(structs, [through | throughs], repo_name, query_assocs, tuplet) do
{{_, %{field: field, cardinality: card}, _}, sub_preloads, true} = through
sub_query_assocs = Map.get(query_assocs, field, %{})

defp load_embeds(:many, field, [struct | structs], embed_structs, [count | counts], acc) do
{current_embeds, rest_embeds} = split_n(embed_structs, count, [])
acc = [Map.put(struct, field, Enum.reverse(current_embeds)) | acc]
load_embeds(:many, field, structs, rest_embeds, counts, acc)
{through_structs, counts} =
Enum.flat_map_reduce(structs, [], fn
%{^field => throughs}, counts when is_list(throughs) -> {throughs, [length(throughs) | counts]}
%{^field => nil}, counts -> {[], [0 | counts]}
%{^field => through}, counts -> {[through], [1 | counts]}
nil, counts -> {[], [0 | counts]}
struct, _counts -> raise ArgumentError, "expected #{inspect(struct)} to contain through association `#{field}`"
end)

through_structs = preload_each(through_structs, repo_name, sub_preloads, sub_query_assocs, tuplet)
structs = put_through_or_embed(card, field, structs, through_structs, Enum.reverse(counts), [])
preload_throughs(structs, throughs, repo_name, query_assocs, tuplet)
end

defp put_through_or_embed(_card, _field, [], [], [], acc), do: Enum.reverse(acc)

defp put_through_or_embed(card, field, [struct | structs], loaded_structs, [0 | counts], acc),
do: put_through_or_embed(card, field, structs, loaded_structs, counts, [struct | acc])

defp put_through_or_embed(:one, field, [struct | structs], [loaded | loaded_structs], [1 | counts], acc),
do: put_through_or_embed(:one, field, structs, loaded_structs, counts, [Map.put(struct, field, loaded) | acc])

defp put_through_or_embed(:many, field, [struct | structs], loaded_structs, [count | counts], acc) do
{current_loaded, rest_loaded} = split_n(loaded_structs, count, [])
acc = [Map.put(struct, field, Enum.reverse(current_loaded)) | acc]
put_through_or_embed(:many, field, structs, rest_loaded, counts, acc)
end

defp maybe_unpack_query(false, queries), do: {[], [], queries}
Expand Down Expand Up @@ -440,11 +479,10 @@ defmodule Ecto.Repo.Preloader do
Map.put(struct, field, loaded)
end

defp load_through({:through, _assoc, _throughs}, nil) do
nil
end
defp load_through({_, _, _}, nil), do: nil
defp load_through({_, _, true = _from_query?}, struct), do: struct

defp load_through({:through, assoc, throughs}, struct) do
defp load_through({{:through, assoc, throughs}, _, false = _from_query?}, struct) do
%{cardinality: cardinality, field: field, owner: owner} = assoc
{loaded, _} = Enum.reduce(throughs, {[struct], owner}, &recur_through/2)
Map.put(struct, field, maybe_first(loaded, cardinality))
Expand Down Expand Up @@ -574,9 +612,19 @@ defmodule Ecto.Repo.Preloader do
"preload expects an atom, a (nested) keyword or a (nested) list of atoms"
end

defp normalize_query_assocs([]), do: %{}

defp normalize_query_assocs(assocs) when is_list(assocs) do
Enum.reduce(assocs, %{}, &normalize_each_query_assoc(&1, &2))
end

defp normalize_each_query_assoc({field, {_idx, sub_assocs}}, acc) do
Map.put(acc, field, normalize_query_assocs(sub_assocs))
end

## Expand

def expand(schema, preloads, acc) do
def expand(schema, preloads, query_assocs, acc) do
Enum.reduce(preloads, acc, fn {preload, {fields, query, sub_preloads}},
{assocs, throughs, embeds} ->
assoc_or_embed = association_or_embed!(schema, preload)
Expand All @@ -590,13 +638,18 @@ defmodule Ecto.Repo.Preloader do
{assocs, throughs, embeds}

{:through, _, through} ->
through =
through
|> Enum.reverse()
|> Enum.reduce({fields, query, sub_preloads}, &{nil, nil, [{&1, &2}]})
|> elem(2)
case query_assocs do
%{^preload => _} ->
{assocs, [{info, sub_preloads, true} | throughs], embeds}

_ ->
{_, _, through} =
through
|> Enum.reverse()
|> Enum.reduce({fields, query, sub_preloads}, &{nil, nil, [{&1, &2}]})

expand(schema, through, {assocs, Map.put(throughs, preload, info), embeds})
expand(schema, through, query_assocs, {assocs, [{info, sub_preloads, false} | throughs], embeds})
end

:embed ->
if sub_preloads == [] do
Expand Down
2 changes: 1 addition & 1 deletion lib/ecto/repo/queryable.ex
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ defmodule Ecto.Repo.Queryable do
{count,
rows
|> Ecto.Repo.Assoc.query(assocs, sources, preprocessor)
|> Ecto.Repo.Preloader.query(name, preloads, take, postprocessor, tuplet)}
|> Ecto.Repo.Preloader.query(name, preloads, take, assocs, postprocessor, tuplet)}
end
end

Expand Down

0 comments on commit b26b25b

Please sign in to comment.