From 8fedda30891362a68bc962c43d6bc64527ea7dfb Mon Sep 17 00:00:00 2001 From: Vinicius Stock Date: Thu, 12 Sep 2024 10:21:56 -0400 Subject: [PATCH] Add keyword support to signature match Co-authored-by: Andy Waite --- lib/ruby_indexer/lib/ruby_indexer/entry.rb | 78 +++++++++++++++-- lib/ruby_indexer/test/method_test.rb | 98 ++++++++++++++++++---- 2 files changed, 153 insertions(+), 23 deletions(-) diff --git a/lib/ruby_indexer/lib/ruby_indexer/entry.rb b/lib/ruby_indexer/lib/ruby_indexer/entry.rb index e68170ee8..396a4941e 100644 --- a/lib/ruby_indexer/lib/ruby_indexer/entry.rb +++ b/lib/ruby_indexer/lib/ruby_indexer/entry.rb @@ -600,11 +600,26 @@ def format # Returns `true` if the given call node arguments array matches this method signature. This method will prefer # returning `true` for situations that cannot be analyzed statically, like the presence of splats, keyword splats - # or forwarding arguments + # or forwarding arguments. + # + # Since this method is used to detect which overload should be displayed in signature help, it will also return + # `true` if there are missing arguments since the user may not be done typing yet. For example: + # + # ```ruby + # def foo(a, b); end + # # All of the following are considered matches because the user might be in the middle of typing and we have to + # # show them the signature + # foo + # foo(1) + # foo(1, 2) + # ``` sig { params(arguments: T::Array[Prism::Node]).returns(T::Boolean) } def matches?(arguments) min_pos = 0 - max_pos = T.let(0, Numeric) + max_pos = T.let(0, T.any(Integer, Float)) + names = [] + has_forward = T.let(false, T::Boolean) + has_keyword_rest = T.let(false, T::Boolean) @parameters.each do |param| case param @@ -617,15 +632,66 @@ def matches?(arguments) max_pos = Float::INFINITY when ForwardingParameter max_pos = Float::INFINITY + has_forward = true + when KeywordParameter, OptionalKeywordParameter + names << param.name + when KeywordRestParameter + has_keyword_rest = true end end - _keyword_hash_node, positional_args = arguments.partition { |arg| arg.is_a?(Prism::KeywordHashNode) } - argument_length_is_unknown = positional_args.any? do |arg| - arg.is_a?(Prism::SplatNode) || arg.is_a?(Prism::ForwardingArgumentsNode) + keyword_hash_nodes, positional_args = arguments.partition { |arg| arg.is_a?(Prism::KeywordHashNode) } + keyword_args = T.cast(keyword_hash_nodes.first, T.nilable(Prism::KeywordHashNode))&.elements + forwarding_arguments, positionals = positional_args.partition do |arg| + arg.is_a?(Prism::ForwardingArgumentsNode) end - argument_length_is_unknown || (min_pos..max_pos).cover?(positional_args.length) + return true if has_forward && min_pos == 0 + + # If the only argument passed is a forwarding argument, then anything will match + (positionals.empty? && forwarding_arguments.any?) || + ( + # Check if positional arguments match. This includes required, optional, rest arguments. We also need to + # verify if there's a trailing forwading argument, like `def foo(a, ...); end` + positional_arguments_match?(positionals, forwarding_arguments, keyword_args, min_pos, max_pos) && + # If the positional arguments match, we move on to checking keyword, optional keyword and keyword rest + # arguments. If there's a forward argument, then it will always match. If the method accepts a keyword rest + # (**kwargs), then we can't analyze statically because the user could be passing a hash and we don't know + # what the runtime values inside the hash are. + # + # If none of those match, then we verify if the user is passing the expect names for the keyword arguments + (has_forward || has_keyword_rest || keyword_arguments_match?(keyword_args, names)) + ) + end + + sig do + params( + positional_args: T::Array[Prism::Node], + forwarding_arguments: T::Array[Prism::Node], + keyword_args: T.nilable(T::Array[Prism::Node]), + min_pos: Integer, + max_pos: T.any(Integer, Float), + ).returns(T::Boolean) + end + def positional_arguments_match?(positional_args, forwarding_arguments, keyword_args, min_pos, max_pos) + # If the method accepts at least one positional argument and a splat has been passed + (min_pos > 0 && positional_args.any? { |arg| arg.is_a?(Prism::SplatNode) }) || + # If there's at least one positional argument unaccounted for and a keyword splat has been passed + (min_pos - positional_args.length > 0 && keyword_args&.any? { |arg| arg.is_a?(Prism::AssocSplatNode) }) || + # If there's at least one positional argument unaccounted for and a forwarding argument has been passed + (min_pos - positional_args.length > 0 && forwarding_arguments.any?) || + # If the number of positional arguments is within the expected range + (min_pos > 0 && positional_args.length <= max_pos) || + (min_pos == 0 && positional_args.empty?) + end + + sig { params(args: T.nilable(T::Array[Prism::Node]), names: T::Array[Symbol]).returns(T::Boolean) } + def keyword_arguments_match?(args, names) + return true unless args + return true if args.any? { |arg| arg.is_a?(Prism::AssocSplatNode) } + + arg_names = args.filter_map { |arg| arg.key.value.to_sym if arg.is_a?(Prism::AssocNode) } + (arg_names - names).empty? end end end diff --git a/lib/ruby_indexer/test/method_test.rb b/lib/ruby_indexer/test/method_test.rb index a443b3655..d02f2cfc6 100644 --- a/lib/ruby_indexer/test/method_test.rb +++ b/lib/ruby_indexer/test/method_test.rb @@ -500,6 +500,7 @@ def bar(a, b = 123) entry = T.must(@index["bar"].first) # Matching calls + assert_signature_matches(entry, "bar()") assert_signature_matches(entry, "bar(1)") assert_signature_matches(entry, "bar(1, 2)") assert_signature_matches(entry, "bar(...)") @@ -510,15 +511,16 @@ def bar(a, b = 123) assert_signature_matches(entry, "bar(*a, 2)") assert_signature_matches(entry, "bar(1, **a)") assert_signature_matches(entry, "bar(1) {}") + # This call is impossible to analyze statically because it depends on whether there are elements inside `a` or + # not. If there's nothing, the call will fail. But if there's anything inside, the hash will become the first + # positional argument + assert_signature_matches(entry, "bar(**a)") # Non matching calls - refute_signature_matches(entry, "bar()") refute_signature_matches(entry, "bar(1, 2, 3)") - - # TODO: uncomment after keyword support - # refute_signature_matches(entry, "bar(1, b: 2)") - # refute_signature_matches(entry, "bar(1, 2, c: 3)") + refute_signature_matches(entry, "bar(1, b: 2)") + refute_signature_matches(entry, "bar(1, 2, c: 3)") end def test_signature_matches_for_a_method_with_argument_forwarding @@ -570,8 +572,7 @@ def bar(a, ...) assert_signature_matches(entry, "bar(1) {}") assert_signature_matches(entry, "bar(1, 2, 3)") assert_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}") - - refute_signature_matches(entry, "bar()") + assert_signature_matches(entry, "bar()") end def test_signature_matches_for_destructured_parameters @@ -585,6 +586,8 @@ def bar(a, (b, c)) entry = T.must(@index["bar"].first) # All calls with at least one positional argument match + assert_signature_matches(entry, "bar()") + assert_signature_matches(entry, "bar(1)") assert_signature_matches(entry, "bar(1, 2)") assert_signature_matches(entry, "bar(...)") assert_signature_matches(entry, "bar(1, ...)") @@ -593,15 +596,11 @@ def bar(a, (b, c)) assert_signature_matches(entry, "bar(*a, 2)") # This matches because `bar(1, *[], 2)` would result in `bar(1, 2)`, which is a valid call assert_signature_matches(entry, "bar(1, *a, 2)") + assert_signature_matches(entry, "bar(1, **a)") + assert_signature_matches(entry, "bar(1) {}") - refute_signature_matches(entry, "bar()") - refute_signature_matches(entry, "bar(1)") - refute_signature_matches(entry, "bar(1, **a)") refute_signature_matches(entry, "bar(1, 2, 3)") - refute_signature_matches(entry, "bar(1) {}") - - # TODO: uncomment after keyword support - # refute_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}") + refute_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}") end def test_signature_matches_for_post_parameters @@ -626,11 +625,76 @@ def bar(*splat, a) assert_signature_matches(entry, "bar(1, **a)") assert_signature_matches(entry, "bar(1, 2, 3)") assert_signature_matches(entry, "bar(1) {}") + assert_signature_matches(entry, "bar()") + + refute_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}") + end + + def test_signature_matches_for_keyword_parameters + index(<<~RUBY) + class Foo + def bar(a:, b: 123) + end + end + RUBY + + entry = T.must(@index["bar"].first) + + assert_signature_matches(entry, "bar(...)") + assert_signature_matches(entry, "bar()") + assert_signature_matches(entry, "bar(a: 1)") + assert_signature_matches(entry, "bar(a: 1, b: 32)") + + refute_signature_matches(entry, "bar(a: 1, c: 2)") + refute_signature_matches(entry, "bar(1, ...)") + refute_signature_matches(entry, "bar(1) {}") + refute_signature_matches(entry, "bar(1, *a)") + refute_signature_matches(entry, "bar(*a, 2)") + refute_signature_matches(entry, "bar(1, *a, 2)") + refute_signature_matches(entry, "bar(1, **a)") + refute_signature_matches(entry, "bar(*a)") + refute_signature_matches(entry, "bar(1)") + refute_signature_matches(entry, "bar(1, 2)") + refute_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}") + end + + def test_signature_matches_for_keyword_splats + index(<<~RUBY) + class Foo + def bar(a, b:, **kwargs) + end + end + RUBY + + entry = T.must(@index["bar"].first) + + assert_signature_matches(entry, "bar(...)") + assert_signature_matches(entry, "bar()") + assert_signature_matches(entry, "bar(1)") + assert_signature_matches(entry, "bar(1, b: 2)") + assert_signature_matches(entry, "bar(1, b: 2, c: 3, d: 4)") + + refute_signature_matches(entry, "bar(1, 2, b: 2)") + end + + def test_partial_signature_matches + # It's important to match signatures partially, because we want to figure out which signature we should show while + # the user is in the middle of typing + index(<<~RUBY) + class Foo + def bar(a:, b:) + end - refute_signature_matches(entry, "bar()") + def baz(a, b) + end + end + RUBY - # TODO: uncomment after keyword support - # refute_signature_matches(entry, "bar(1, 2, a: 1, b: 5) {}") + entry = T.must(@index["bar"].first) + assert_signature_matches(entry, "bar(a: 1)") + + entry = T.must(@index["baz"].first) + assert_signature_matches(entry, "baz(1)") end private