diff --git a/lib/repl_type_completor.rb b/lib/repl_type_completor.rb index e04d736..ead3381 100644 --- a/lib/repl_type_completor.rb +++ b/lib/repl_type_completor.rb @@ -67,6 +67,23 @@ def analyze_code(code, binding = Object::TOPLEVEL_BINDING) calculate_scope = -> { TypeAnalyzer.calculate_target_type_scope(binding, parents, target_node).last } calculate_type_scope = ->(node) { TypeAnalyzer.calculate_target_type_scope binding, [*parents, target_node], node } + calculate_lvar_or_method = ->(name) { + if parents[-1].is_a?(Prism::ArgumentsNode) && parents[-2].is_a?(Prism::CallNode) + kwarg_call_node = parents[-2] + kwarg_method_sym = kwarg_call_node.message.to_sym + end + kwarg_call_receiver = nil + lvar_or_method_scope = TypeAnalyzer.calculate_target_type_scope binding, parents, target_node do |dig_targets| + if kwarg_call_node&.receiver + dig_targets.on kwarg_call_node.receiver do |type, _scope| + kwarg_call_receiver = type + end + end + end.last + kwarg_call_receiver = lvar_or_method_scope.self_type if kwarg_call_node && kwarg_call_node.receiver.nil? + [:lvar_or_method, name, lvar_or_method_scope, kwarg_call_receiver && [kwarg_call_receiver, kwarg_method_sym]] + } + case target_node when Prism::StringNode, Prism::InterpolatedStringNode call_node, args_node = parents.last(2) @@ -90,7 +107,7 @@ def analyze_code(code, binding = Object::TOPLEVEL_BINDING) end when Prism::CallNode name = target_node.message.to_s - return [:lvar_or_method, name, calculate_scope.call] if target_node.receiver.nil? + return calculate_lvar_or_method.call(name) if target_node.receiver.nil? self_call = target_node.receiver.is_a? Prism::SelfNode op = target_node.call_operator @@ -98,7 +115,7 @@ def analyze_code(code, binding = Object::TOPLEVEL_BINDING) receiver_type = receiver_type.nonnillable if op == '&.' [op == '::' ? :call_or_const : :call, name, receiver_type, self_call] when Prism::LocalVariableReadNode, Prism::LocalVariableTargetNode - [:lvar_or_method, target_node.name.to_s, calculate_scope.call] + calculate_lvar_or_method.call(target_node.name.to_s) when Prism::ConstantReadNode, Prism::ConstantTargetNode name = target_node.name.to_s if parents.last.is_a? Prism::ConstantPathNode diff --git a/lib/repl_type_completor/result.rb b/lib/repl_type_completor/result.rb index 1de88e5..3c32ef5 100644 --- a/lib/repl_type_completor/result.rb +++ b/lib/repl_type_completor/result.rb @@ -61,8 +61,9 @@ def completion_candidates Symbol.all_symbols.map { _1.inspect[1..] } in [:call, name, type, self_call] (self_call ? type.all_methods : type.methods).map(&:to_s) - HIDDEN_METHODS - in [:lvar_or_method, name, scope] - scope.self_type.all_methods.map(&:to_s) | scope.local_variables | RESERVED_WORDS + in [:lvar_or_method, name, scope, kwarg_call] + kwargs = kwarg_call ? Types.method_kwargs_names(*kwarg_call).map { "#{_1}:" } : [] + scope.self_type.all_methods.map(&:to_s) | scope.local_variables | kwargs | RESERVED_WORDS else [] end @@ -93,7 +94,7 @@ def doc_namespace(matched) value_doc scope[prefix + matched] in [:call, prefix, type, _self_call] method_doc type, prefix + matched - in [:lvar_or_method, prefix, scope] + in [:lvar_or_method, prefix, scope, kwarg_call] if scope.local_variables.include?(prefix + matched) value_doc scope[prefix + matched] else diff --git a/lib/repl_type_completor/type_analyzer.rb b/lib/repl_type_completor/type_analyzer.rb index f6b16ab..a1dc2a6 100644 --- a/lib/repl_type_completor/type_analyzer.rb +++ b/lib/repl_type_completor/type_analyzer.rb @@ -8,16 +8,20 @@ module ReplTypeCompletor class TypeAnalyzer class DigTarget - def initialize(parents, receiver, &block) - @dig_ids = parents.to_h { [_1.__id__, true] } - @target_id = receiver.__id__ - @block = block + def initialize(parents) + @dig_ids = Set.new(parents.map(&:__id__)) + @events = {} end - def dig?(node) = @dig_ids[node.__id__] - def target?(node) = @target_id == node.__id__ - def resolve(type, scope) - @block.call type, scope + def on(target, &block) + @dig_ids << target.__id__ + @events[target.__id__] = block + end + + def dig?(node) = @dig_ids.include?(node.__id__) + def target?(node) = @events.key?(node.__id__) + def trigger(node, type, scope) + @events[node.__id__]&.call type, scope end end @@ -46,7 +50,7 @@ def evaluate(node, scope) else result = Types::NIL end - @dig_targets.resolve result, scope if @dig_targets.target? node + @dig_targets.trigger node, result, scope result end @@ -242,7 +246,7 @@ def evaluate_call_node(node, scope) # method(args, &:completion_target) call_block_proc = ->(block_args, _self_type) do block_receiver = block_args.first || Types::OBJECT - @dig_targets.resolve block_receiver, scope + @dig_targets.trigger block_sym_node, block_receiver, scope Types::OBJECT end else @@ -892,7 +896,7 @@ def evaluate_constant_node_info(node, scope) name = node.name.to_s type = scope[name] end - @dig_targets.resolve type, scope if @dig_targets.target? node + @dig_targets.trigger node, type, scope [type, receiver, parent_module, name] end @@ -1167,9 +1171,11 @@ def method_call(receiver, method_name, args, kwargs, block, scope, name_match: t end def self.calculate_target_type_scope(binding, parents, target) - dig_targets = DigTarget.new(parents, target) do |type, scope| + dig_targets = DigTarget.new(parents) + dig_targets.on target do |type, scope| return type, scope end + yield dig_targets if block_given? program = parents.first scope = Scope.from_binding(binding, program.locals) new(dig_targets).evaluate program, scope diff --git a/lib/repl_type_completor/types.rb b/lib/repl_type_completor/types.rb index c002284..727af18 100644 --- a/lib/repl_type_completor/types.rb +++ b/lib/repl_type_completor/types.rb @@ -61,13 +61,39 @@ def self.method_return_type(type, method_name) types = receivers.flat_map do |receiver_type, klass, singleton| method = rbs_search_method klass, method_name, singleton next [] unless method - method.method_types.map do |method| - from_rbs_type(method.type.return_type, receiver_type, {}) + method.method_types.map do |method_type| + from_rbs_type(method_type.type.return_type, receiver_type, {}) end end UnionType[*types] end + def self.method_kwargs_names(type, method_name) + receivers = type.types.map do |t| + case t + in SingletonType + [t.module_or_class, true] + in InstanceType + [t.klass, false] + end + end + parameters_keywords = receivers.flat_map do |klass, singleton| + method_obj = singleton ? klass.method(method_name) : klass.instance_method(method_name) + method_obj.parameters.filter_map { _2 if _1 == :key || _1 == :keyreq } + rescue NameError + [] + end + rbs_keywords = receivers.flat_map do |klass, singleton| + method = rbs_search_method klass, method_name, singleton + next [] unless method + + method.method_types.flat_map do |method_type| + method_type.type.required_keywords.keys | method_type.type.optional_keywords.keys + end + end + (parameters_keywords | rbs_keywords).sort + end + def self.rbs_methods(type, method_name, args_types, kwargs_type, has_block) return [] unless rbs_builder diff --git a/test/repl_type_completor/test_repl_type_completor.rb b/test/repl_type_completor/test_repl_type_completor.rb index c80bd97..509de84 100644 --- a/test/repl_type_completor/test_repl_type_completor.rb +++ b/test/repl_type_completor/test_repl_type_completor.rb @@ -76,6 +76,23 @@ def test_lvar assert_doc_namespace('lvar = ""; lvar.ascii_only?', 'String#ascii_only?', binding: bind) end + def test_kwarg + o = Object.new; def o.foo(bar:, baz: true); end + m = Module.new; def m.foo(foobar:, foobaz: true); end + bind = binding + # kwarg name from method.parameters + assert_completion('o.foo ba', binding: bind, include: ['r:', 'z:']) + assert_completion('m.foo fo', binding: bind, include: ['obar:', 'obaz:']) + assert_completion('foo ba', binding: o.instance_eval { binding }, include: ['r:', 'z:']) + assert_completion('foo fo', binding: m.instance_eval { binding }, include: ['obar:', 'obaz:']) + # kwarg name from RBS + assert_completion('"".each_line ch', binding: bind, include: 'omp:') + assert_completion('String.new en', binding: bind, include: 'coding:') + # assert completion when kwarg name is not found + assert_completion('o.inspect ra', binding: bind, include: 'nd') + assert_completion('o.undefined_method ra', binding: bind, include: 'nd') + end + def test_const assert_completion('Ar', include: 'ray') assert_completion('::Ar', include: 'ray') diff --git a/test/repl_type_completor/test_types.rb b/test/repl_type_completor/test_types.rb index 642b2dd..9e29aed 100644 --- a/test/repl_type_completor/test_types.rb +++ b/test/repl_type_completor/test_types.rb @@ -82,5 +82,19 @@ def bo.foobar; end type = ReplTypeCompletor::Types.type_from_object bo assert type.all_methods.include?(:foobar) end + + def test_kwargs_names + bo = BasicObject.new + def bo.foobar(bo_kwarg1: nil, bo_kwarg2:); end + bo_type = ReplTypeCompletor::Types.type_from_object bo + assert_equal %i[bo_kwarg1 bo_kwarg2], ReplTypeCompletor::Types.method_kwargs_names(bo_type, :foobar) + str_type = ReplTypeCompletor::Types::STRING + assert_include ReplTypeCompletor::Types.method_kwargs_names(str_type, :each_line), :chomp + singleton_type = ReplTypeCompletor::Types::SingletonType.new String + assert_include ReplTypeCompletor::Types.method_kwargs_names(singleton_type, :new), :encoding + union_type = ReplTypeCompletor::Types::UnionType[bo_type, str_type, singleton_type] + assert_include ReplTypeCompletor::Types.method_kwargs_names(union_type, :each_line), :chomp + assert_equal ReplTypeCompletor::Types.method_kwargs_names(str_type, :undefined_method), [] + end end end