Skip to content

Commit

Permalink
wip metals bsp
Browse files Browse the repository at this point in the history
  • Loading branch information
jgranstrom committed Jul 8, 2024
1 parent ef88dc8 commit 960c6ae
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 49 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ __pycache__/
.idea
/.vscode/
.cache
.metals
.scala-build
.pants.d
# TODO: We can probably delete these 3. They have not been used in a long time, if ever.
# In fact there's a lot of things we can clean up in this .gitignore. It's not harmful
Expand Down
158 changes: 126 additions & 32 deletions src/python/pants/backend/scala/bsp/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from __future__ import annotations

import dataclasses
import logging
import textwrap
from dataclasses import dataclass
Expand Down Expand Up @@ -50,7 +51,16 @@
)
from pants.core.util_rules.system_binaries import BashBinary, ReadlinkBinary
from pants.engine.addresses import Addresses
from pants.engine.fs import AddPrefix, CreateDigest, Digest, FileContent, MergeDigests, Workspace
from pants.engine.fs import (
AddPrefix,
CreateDigest,
Digest,
FileContent,
FileDigest,
FileEntry,
MergeDigests,
Workspace,
)
from pants.engine.internals.native_engine import Snapshot
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.process import Process, ProcessResult
Expand Down Expand Up @@ -114,7 +124,7 @@ class ThirdpartyModulesRequest:
@dataclass(frozen=True)
class ThirdpartyModules:
resolve: CoursierResolveKey
entries: dict[CoursierLockfileEntry, ClasspathEntry]
entries: dict[CoursierLockfileEntry, tuple[ClasspathEntry, list[CoursierLockfileEntry]]]
merged_digest: Digest


Expand All @@ -128,6 +138,11 @@ async def collect_thirdparty_modules(
lockfile = await Get(CoursierResolvedLockfile, CoursierResolveKey, resolve)

applicable_lockfile_entries: dict[CoursierLockfileEntry, CoarsenedTarget] = {}
applicable_lockfile_source_entries: dict[CoursierLockfileEntry, CoursierLockfileEntry] = {}
applicable_lockfile_source_entries_inverse: dict[
CoursierLockfileEntry, list[CoursierLockfileEntry]
] = {}

for ct in coarsened_targets.coarsened_closure():
for tgt in ct.members:
if not JvmArtifactFieldSet.is_applicable(tgt):
Expand All @@ -142,6 +157,21 @@ async def collect_thirdparty_modules(
continue
applicable_lockfile_entries[entry] = ct

artifact_requirement_source = dataclasses.replace(
artifact_requirement,
coordinate=dataclasses.replace(
artifact_requirement.coordinate, classifier="sources"
),
)
entrySource = get_entry_for_coord(lockfile, artifact_requirement_source.coordinate)
if not entrySource:
_logger.warning(
f"No lockfile source entry for {artifact_requirement_source.coordinate} in resolve {resolve.name}."
)
continue
applicable_lockfile_source_entries[entrySource] = entry
applicable_lockfile_source_entries_inverse[entry] = [entrySource]

classpath_entries = await MultiGet(
Get(
ClasspathEntry,
Expand All @@ -151,11 +181,35 @@ async def collect_thirdparty_modules(
for target in applicable_lockfile_entries.values()
)

resolve_digest = await Get(Digest, MergeDigests(cpe.digest for cpe in classpath_entries))
digests = []
for cpe in classpath_entries:
digests.append(cpe.digest)
for alse in applicable_lockfile_source_entries:
new_file = FileEntry(alse.file_name, alse.file_digest)
digest = await Get(Digest, CreateDigest([new_file]))
digests.append(digest)

for dep in alse.dependencies:
coord = Coordinate.from_coord_str(dep.to_coord_str())
dep_artifact_requirement = ArtifactRequirement(coord)
dep_entry = get_entry_for_coord(lockfile, dep_artifact_requirement.coordinate)
dep_new_file = FileEntry(dep_entry.file_name, dep_entry.file_digest)
dep_digest = await Get(Digest, CreateDigest([dep_new_file]))
digests.append(dep_digest)
src_ent = applicable_lockfile_source_entries.get(alse)
applicable_lockfile_source_entries_inverse.get(src_ent).append(dep_entry)

resolve_digest = await Get(Digest, MergeDigests(digests))
inverse = dict(zip(classpath_entries, applicable_lockfile_entries))

s = map(
lambda x: (x, applicable_lockfile_source_entries_inverse.get(inverse.get(x))),
classpath_entries,
)

return ThirdpartyModules(
resolve,
dict(zip(applicable_lockfile_entries, classpath_entries)),
dict(zip(applicable_lockfile_entries, s)),
resolve_digest,
)

Expand Down Expand Up @@ -300,7 +354,7 @@ async def bsp_resolve_scala_metadata(

def _jdk_request_sort_key(
jvm: JvmSubsystem,
) -> Callable[[JdkRequest,], tuple[int, ...]]:
) -> Callable[[JdkRequest,], tuple[int, ...],]:
def sort_key_function(request: JdkRequest) -> tuple[int, ...]:
if request == JdkRequest.SYSTEM:
return (-1,)
Expand Down Expand Up @@ -337,11 +391,15 @@ class HandleScalacOptionsResult:

@_uncacheable_rule
async def handle_bsp_scalac_options_request(
request: HandleScalacOptionsRequest, build_root: BuildRoot, workspace: Workspace, scalac: Scalac
request: HandleScalacOptionsRequest,
build_root: BuildRoot,
workspace: Workspace,
scalac: Scalac,
) -> HandleScalacOptionsResult:
targets = await Get(Targets, BuildTargetIdentifier, request.bsp_target_id)
thirdparty_modules = await Get(
ThirdpartyModules, ThirdpartyModulesRequest(Addresses(tgt.address for tgt in targets))
ThirdpartyModules,
ThirdpartyModulesRequest(Addresses(tgt.address for tgt in targets)),
)
resolve = thirdparty_modules.resolve

Expand All @@ -352,12 +410,16 @@ async def handle_bsp_scalac_options_request(

local_plugins_prefix = f"jvm/resolves/{resolve.name}/plugins"
local_plugins = await Get(
ScalaPlugins, ScalaPluginsRequest.from_target_plugins(scalac_plugin_targets, resolve)
ScalaPlugins,
ScalaPluginsRequest.from_target_plugins(scalac_plugin_targets, resolve),
)

thirdparty_modules_prefix = f"jvm/resolves/{resolve.name}/lib"
thirdparty_modules_digest, local_plugins_digest = await MultiGet(
Get(Digest, AddPrefix(thirdparty_modules.merged_digest, thirdparty_modules_prefix)),
Get(
Digest,
AddPrefix(thirdparty_modules.merged_digest, thirdparty_modules_prefix),
),
Get(Digest, AddPrefix(local_plugins.classpath.digest, local_plugins_prefix)),
)

Expand All @@ -370,7 +432,7 @@ async def handle_bsp_scalac_options_request(
build_root.pathlib_path.joinpath(
f".pants.d/bsp/{thirdparty_modules_prefix}/{filename}"
).as_uri()
for cp_entry in thirdparty_modules.entries.values()
for cp_entry, _ in thirdparty_modules.entries.values()
for filename in cp_entry.filenames
)

Expand All @@ -387,7 +449,9 @@ async def handle_bsp_scalac_options_request(


@rule
async def bsp_scalac_options_request(request: ScalacOptionsParams) -> ScalacOptionsResult:
async def bsp_scalac_options_request(
request: ScalacOptionsParams,
) -> ScalacOptionsResult:
results = await MultiGet(
Get(HandleScalacOptionsResult, HandleScalacOptionsRequest(btgt)) for btgt in request.targets
)
Expand All @@ -407,7 +471,9 @@ class ScalaMainClassesHandlerMapping(BSPHandlerMapping):


@rule
async def bsp_scala_main_classes_request(request: ScalaMainClassesParams) -> ScalaMainClassesResult:
async def bsp_scala_main_classes_request(
request: ScalaMainClassesParams,
) -> ScalaMainClassesResult:
# TODO: This is a stub. VSCode/Metals calls this RPC and expects it to exist.
return ScalaMainClassesResult(
items=(),
Expand All @@ -428,14 +494,22 @@ class ScalaTestClassesHandlerMapping(BSPHandlerMapping):


@rule
async def bsp_scala_test_classes_request(request: ScalaTestClassesParams) -> ScalaTestClassesResult:
async def bsp_scala_test_classes_request(
request: ScalaTestClassesParams,
) -> ScalaTestClassesResult:
# TODO: This is a stub. VSCode/Metals calls this RPC and expects it to exist.
return ScalaTestClassesResult(
items=(),
origin_id=request.origin_id,
)


# -----------------------------------------------------------------------------------------------
# Dependency Sources
# -----------------------------------------------------------------------------------------------

# TODO

# -----------------------------------------------------------------------------------------------
# Dependency Modules
# -----------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -464,33 +538,53 @@ async def scala_bsp_dependency_modules(
ThirdpartyModules,
ThirdpartyModulesRequest(Addresses(fs.address for fs in request.field_sets)),
)

resolve = thirdparty_modules.resolve

resolve_digest = await Get(
Digest, AddPrefix(thirdparty_modules.merged_digest, f"jvm/resolves/{resolve.name}/lib")
Digest,
AddPrefix(thirdparty_modules.merged_digest, f"jvm/resolves/{resolve.name}/lib"),
)

modules = [
DependencyModule(
name=f"{entry.coord.group}:{entry.coord.artifact}",
version=entry.coord.version,
data=MavenDependencyModule(
organization=entry.coord.group,
name=entry.coord.artifact,
modules = []

for entry, (cp_entry, source_entry) in thirdparty_modules.entries.items():
a1 = [
MavenDependencyModuleArtifact(
uri=build_root.pathlib_path.joinpath(
f".pants.d/bsp/jvm/resolves/{resolve.name}/lib/{filename}"
).as_uri(),
)
for filename in cp_entry.filenames
]

a2 = None
if source_entry is not None:
a2 = [
MavenDependencyModuleArtifact(
uri=build_root.pathlib_path.joinpath(
f".pants.d/bsp/jvm/resolves/{resolve.name}/lib/{se.file_name}"
).as_uri(),
classifier="sources",
)
for se in source_entry
]
else:
a2 = []

modules.append(
DependencyModule(
name=f"{entry.coord.group}:{entry.coord.artifact}",
version=entry.coord.version,
scope=None,
artifacts=tuple(
MavenDependencyModuleArtifact(
uri=build_root.pathlib_path.joinpath(
f".pants.d/bsp/jvm/resolves/{resolve.name}/lib/{filename}"
).as_uri()
)
for filename in cp_entry.filenames
data=MavenDependencyModule(
organization=entry.coord.group,
name=entry.coord.artifact,
version=entry.coord.version,
scope=None,
artifacts=tuple(a1 + a2),
),
),
)
)
for entry, cp_entry in thirdparty_modules.entries.items()
]

return BSPDependencyModulesResult(
modules=tuple(modules),
Expand Down
33 changes: 30 additions & 3 deletions src/python/pants/bsp/util_rules/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import itertools
import logging
import typing
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -57,6 +58,7 @@
Targets,
)
from pants.engine.unions import UnionMembership, UnionRule, union
from pants.jvm.bsp.spec import MavenDependencyModule, MavenDependencyModuleArtifact
from pants.source.source_root import SourceRootsRequest, SourceRootsResult
from pants.util.frozendict import FrozenDict
from pants.util.ordered_set import FrozenOrderedSet, OrderedSet
Expand Down Expand Up @@ -397,6 +399,7 @@ async def generate_one_bsp_build_target_request(
# directory or else be configurable by the user. It is used as a hint in IntelliJ for where to place the
# corresponding IntelliJ module.
source_info = await Get(BSPBuildTargetSourcesInfo, BSPBuildTargetInternal, request.bsp_target)

if source_info.source_roots:
roots = [build_root.pathlib_path.joinpath(p) for p in source_info.source_roots]
else:
Expand Down Expand Up @@ -479,7 +482,6 @@ async def materialize_bsp_build_target_sources(
) -> MaterializeBuildTargetSourcesResult:
bsp_target = await Get(BSPBuildTargetInternal, BuildTargetIdentifier, request.bsp_target_id)
source_info = await Get(BSPBuildTargetSourcesInfo, BSPBuildTargetInternal, bsp_target)

if source_info.source_roots:
roots = [build_root.pathlib_path.joinpath(p) for p in source_info.source_roots]
else:
Expand Down Expand Up @@ -516,6 +518,18 @@ async def bsp_build_target_sources(request: SourcesParams) -> SourcesResult:
# -----------------------------------------------------------------------------------------------


@dataclass(frozen=True)
class BSPDependencySourcesRequest:
"""Hook to allow language backends to provide dependency sources."""

params: DependencySourcesParams


@dataclass(frozen=True)
class BSPDependencyModulesResult:
result: DependencySourcesResult


class DependencySourcesHandlerMapping(BSPHandlerMapping):
method_name = "buildTarget/dependencySources"
request_type = DependencySourcesParams
Expand All @@ -524,9 +538,22 @@ class DependencySourcesHandlerMapping(BSPHandlerMapping):

@rule
async def bsp_dependency_sources(request: DependencySourcesParams) -> DependencySourcesResult:
# TODO: This is a stub.
dependency_modules = await Get(
DependencyModulesResult, DependencyModulesParams, DependencyModulesParams(request.targets)
)

sources = {}
for i in dependency_modules.items:
for m in i.modules:
mavenmod: MavenDependencyModule = m.data
for x in mavenmod.artifacts:
if x.classifier == "sources":
sources[x.uri] = x

files = sources.keys()

return DependencySourcesResult(
tuple(DependencySourcesItem(target=tgt, sources=()) for tgt in request.targets)
tuple(DependencySourcesItem(target=tgt, sources=tuple(files)) for tgt in request.targets)
)


Expand Down
Loading

0 comments on commit 960c6ae

Please sign in to comment.