Skip to content

Commit

Permalink
Merge pull request #433 from nipreps/enh/nibabies-fit-apply
Browse files Browse the repository at this point in the history
RF: Remove much of the hardcoded `T1w` fields in favor of `anat`
  • Loading branch information
effigies authored Jul 26, 2024
2 parents 7852cba + 4a8a782 commit c361b3a
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 265 deletions.
88 changes: 41 additions & 47 deletions smriprep/workflows/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@
from .fit.registration import init_register_template_wf
from .outputs import (
init_anat_reports_wf,
init_anat_second_derivatives_wf,
init_ds_anat_volumes_wf,
init_ds_dseg_wf,
init_ds_fs_registration_wf,
init_ds_fs_segs_wf,
init_ds_grayord_metrics_wf,
init_ds_mask_wf,
init_ds_surface_metrics_wf,
Expand Down Expand Up @@ -287,7 +287,6 @@ def init_anat_preproc_wf(
ds_std_volumes_wf = init_ds_anat_volumes_wf(
bids_root=bids_root,
output_dir=output_dir,
name='ds_std_volumes_wf',
)

workflow.connect([
Expand Down Expand Up @@ -320,10 +319,10 @@ def init_anat_preproc_wf(
]),
(anat_fit_wf, ds_std_volumes_wf, [
('outputnode.t1w_valid_list', 'inputnode.source_files'),
('outputnode.t1w_preproc', 'inputnode.t1w_preproc'),
('outputnode.t1w_mask', 'inputnode.t1w_mask'),
('outputnode.t1w_dseg', 'inputnode.t1w_dseg'),
('outputnode.t1w_tpms', 'inputnode.t1w_tpms'),
('outputnode.t1w_preproc', 'inputnode.anat_preproc'),
('outputnode.t1w_mask', 'inputnode.anat_mask'),
('outputnode.t1w_dseg', 'inputnode.anat_dseg'),
('outputnode.t1w_tpms', 'inputnode.anat_tpms'),
]),
(template_iterator_wf, ds_std_volumes_wf, [
('outputnode.std_t1w', 'inputnode.ref_file'),
Expand All @@ -335,17 +334,12 @@ def init_anat_preproc_wf(
]) # fmt:skip

if freesurfer:
anat_second_derivatives_wf = init_anat_second_derivatives_wf(
ds_fs_segs_wf = init_ds_fs_segs_wf(
bids_root=bids_root,
output_dir=output_dir,
cifti_output=cifti_output,
)
surface_derivatives_wf = init_surface_derivatives_wf(
cifti_output=cifti_output,
)
ds_surfaces_wf = init_ds_surfaces_wf(
bids_root=bids_root, output_dir=output_dir, surfaces=['inflated']
)
surface_derivatives_wf = init_surface_derivatives_wf()
ds_surfaces_wf = init_ds_surfaces_wf(output_dir=output_dir, surfaces=['inflated'])
ds_curv_wf = init_ds_surface_metrics_wf(
bids_root=bids_root, output_dir=output_dir, metrics=['curv'], name='ds_curv_wf'
)
Expand All @@ -355,7 +349,7 @@ def init_anat_preproc_wf(
('outputnode.t1w_preproc', 'inputnode.reference'),
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
('outputnode.subject_id', 'inputnode.subject_id'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2anat_xfm'),
]),
(anat_fit_wf, ds_surfaces_wf, [
('outputnode.t1w_valid_list', 'inputnode.source_files'),
Expand All @@ -369,12 +363,12 @@ def init_anat_preproc_wf(
(surface_derivatives_wf, ds_curv_wf, [
('outputnode.curv', 'inputnode.curv'),
]),
(anat_fit_wf, anat_second_derivatives_wf, [
(anat_fit_wf, ds_fs_segs_wf, [
('outputnode.t1w_valid_list', 'inputnode.source_files'),
]),
(surface_derivatives_wf, anat_second_derivatives_wf, [
('outputnode.out_aseg', 'inputnode.t1w_fs_aseg'),
('outputnode.out_aparc', 'inputnode.t1w_fs_aparc'),
(surface_derivatives_wf, ds_fs_segs_wf, [
('outputnode.out_aseg', 'inputnode.anat_fs_aseg'),
('outputnode.out_aparc', 'inputnode.anat_fs_aparc'),
]),
(surface_derivatives_wf, outputnode, [
('outputnode.out_aseg', 't1w_aseg'),
Expand Down Expand Up @@ -765,10 +759,12 @@ def init_anat_fit_wf(
longitudinal=longitudinal,
omp_nthreads=omp_nthreads,
num_files=num_t1w,
contrast='T1w',
image_type='T1w',
name='anat_template_wf',
)
ds_template_wf = init_ds_template_wf(output_dir=output_dir, num_t1w=num_t1w)
ds_template_wf = init_ds_template_wf(
output_dir=output_dir, num_anat=num_t1w, image_type='T1w'
)

# fmt:off
workflow.connect([
Expand All @@ -781,11 +777,11 @@ def init_anat_fit_wf(
('outputnode.out_report', 'inputnode.t1w_conform_report'),
]),
(anat_template_wf, ds_template_wf, [
('outputnode.anat_realign_xfm', 'inputnode.t1w_ref_xfms'),
('outputnode.anat_realign_xfm', 'inputnode.anat_ref_xfms'),
]),
(sourcefile_buffer, ds_template_wf, [('source_files', 'inputnode.source_files')]),
(t1w_buffer, ds_template_wf, [('t1w_preproc', 'inputnode.t1w_preproc')]),
(ds_template_wf, outputnode, [('outputnode.t1w_preproc', 't1w_preproc')]),
(t1w_buffer, ds_template_wf, [('t1w_preproc', 'inputnode.anat_preproc')]),
(ds_template_wf, outputnode, [('outputnode.anat_preproc', 't1w_preproc')]),
])
# fmt:on
else:
Expand Down Expand Up @@ -954,16 +950,16 @@ def init_anat_fit_wf(
workflow.connect([
(fast, lut_t1w_dseg, [('partial_volume_map', 'in_dseg')]),
(sourcefile_buffer, ds_dseg_wf, [('source_files', 'inputnode.source_files')]),
(lut_t1w_dseg, ds_dseg_wf, [('out', 'inputnode.t1w_dseg')]),
(ds_dseg_wf, seg_buffer, [('outputnode.t1w_dseg', 't1w_dseg')]),
(lut_t1w_dseg, ds_dseg_wf, [('out', 'inputnode.anat_dseg')]),
(ds_dseg_wf, seg_buffer, [('outputnode.anat_dseg', 't1w_dseg')]),
])
if not have_tpms:
ds_tpms_wf = init_ds_tpms_wf(output_dir=output_dir)
workflow.connect([
(fast, fast2bids, [('partial_volume_files', 'inlist')]),
(sourcefile_buffer, ds_tpms_wf, [('source_files', 'inputnode.source_files')]),
(fast2bids, ds_tpms_wf, [('out', 'inputnode.t1w_tpms')]),
(ds_tpms_wf, seg_buffer, [('outputnode.t1w_tpms', 't1w_tpms')]),
(fast2bids, ds_tpms_wf, [('out', 'inputnode.anat_tpms')]),
(ds_tpms_wf, seg_buffer, [('outputnode.anat_tpms', 't1w_tpms')]),
])
# fmt:on
else:
Expand Down Expand Up @@ -998,7 +994,9 @@ def init_anat_fit_wf(
omp_nthreads=omp_nthreads,
templates=templates,
)
ds_template_registration_wf = init_ds_template_registration_wf(output_dir=output_dir)
ds_template_registration_wf = init_ds_template_registration_wf(
output_dir=output_dir, image_type='T1w'
)

# fmt:off
workflow.connect([
Expand Down Expand Up @@ -1081,17 +1079,17 @@ def init_anat_fit_wf(

fsnative_xfms = precomputed.get('transforms', {}).get('fsnative')
if not fsnative_xfms:
ds_fs_registration_wf = init_ds_fs_registration_wf(output_dir=output_dir)
ds_fs_registration_wf = init_ds_fs_registration_wf(output_dir=output_dir, image_type='T1w')
# fmt:off
workflow.connect([
(sourcefile_buffer, ds_fs_registration_wf, [
('source_files', 'inputnode.source_files'),
]),
(surface_recon_wf, ds_fs_registration_wf, [
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2anat_xfm'),
]),
(ds_fs_registration_wf, outputnode, [
('outputnode.fsnative2t1w_xfm', 'fsnative2t1w_xfm'),
('outputnode.fsnative2anat_xfm', 'fsnative2t1w_xfm'),
]),
])
# fmt:on
Expand All @@ -1114,7 +1112,7 @@ def init_anat_fit_wf(
(surface_recon_wf, refinement_wf, [
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
('outputnode.subject_id', 'inputnode.subject_id'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2anat_xfm'),
]),
(t1w_buffer, refinement_wf, [
('t1w_preproc', 'inputnode.reference_image'),
Expand All @@ -1135,7 +1133,7 @@ def init_anat_fit_wf(
longitudinal=longitudinal,
omp_nthreads=omp_nthreads,
num_files=len(t2w),
contrast='T2w',
image_type='T2w',
name='t2w_template_wf',
)
bbreg = pe.Node(
Expand Down Expand Up @@ -1216,15 +1214,13 @@ def init_anat_fit_wf(
LOGGER.info(f'ANAT Stage 8: Creating GIFTI surfaces for {surfs + spheres}')
if surfs:
gifti_surfaces_wf = init_gifti_surfaces_wf(surfaces=surfs)
ds_surfaces_wf = init_ds_surfaces_wf(
bids_root=bids_root, output_dir=output_dir, surfaces=surfs
)
ds_surfaces_wf = init_ds_surfaces_wf(output_dir=output_dir, surfaces=surfs)
# fmt:off
workflow.connect([
(surface_recon_wf, gifti_surfaces_wf, [
('outputnode.subject_id', 'inputnode.subject_id'),
('outputnode.subjects_dir', 'inputnode.subjects_dir'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2t1w_xfm'),
('outputnode.fsnative2t1w_xfm', 'inputnode.fsnative2anat_xfm'),
]),
(gifti_surfaces_wf, surfaces_buffer, [
(f'outputnode.{surf}', surf) for surf in surfs
Expand All @@ -1240,7 +1236,7 @@ def init_anat_fit_wf(
surfaces=spheres, to_scanner=False, name='gifti_spheres_wf'
)
ds_spheres_wf = init_ds_surfaces_wf(
bids_root=bids_root, output_dir=output_dir, surfaces=spheres, name='ds_spheres_wf'
output_dir=output_dir, surfaces=spheres, name='ds_spheres_wf'
)
# fmt:off
workflow.connect([
Expand Down Expand Up @@ -1316,7 +1312,6 @@ def init_anat_fit_wf(
LOGGER.info('ANAT Stage 9: Creating fsLR registration sphere')
fsLR_reg_wf = init_fsLR_reg_wf()
ds_fsLR_reg_wf = init_ds_surfaces_wf(
bids_root=bids_root,
output_dir=output_dir,
surfaces=['sphere_reg_fsLR'],
name='ds_fsLR_reg_wf',
Expand All @@ -1341,7 +1336,6 @@ def init_anat_fit_wf(
LOGGER.info('ANAT Stage 10: Creating MSM-Sulc registration sphere')
msm_sulc_wf = init_msm_sulc_wf(sloppy=sloppy)
ds_msmsulc_wf = init_ds_surfaces_wf(
bids_root=bids_root,
output_dir=output_dir,
surfaces=['sphere_reg_msm'],
name='ds_msmsulc_wf',
Expand Down Expand Up @@ -1375,7 +1369,7 @@ def init_anat_template_wf(
longitudinal: bool,
omp_nthreads: int,
num_files: int,
contrast: str,
image_type: ty.Literal['T1w', 'T2w'],
name: str = 'anat_template_wf',
):
"""
Expand All @@ -1388,7 +1382,7 @@ def init_anat_template_wf(
from smriprep.workflows.anatomical import init_anat_template_wf
wf = init_anat_template_wf(
longitudinal=False, omp_nthreads=1, num_files=1, contrast="T1w"
longitudinal=False, omp_nthreads=1, num_files=1, image_type="T1w"
)
Parameters
Expand All @@ -1400,8 +1394,8 @@ def init_anat_template_wf(
Maximum number of threads an individual process may use
num_files : :obj:`int`
Number of images
contrast : :obj:`str`
Name of contrast, for reporting purposes, e.g., T1w, T2w, PDw
image_type : :obj:`str`
MR image type (T1w, T2w, etc.)
name : :obj:`str`, optional
Workflow name (default: anat_template_wf)
Expand All @@ -1427,8 +1421,8 @@ def init_anat_template_wf(
if num_files > 1:
fs_ver = fs.Info().looseversion() or '(version unknown)'
workflow.__desc__ = f"""\
An anatomical {contrast}-reference map was computed after registration of
{num_files} {contrast} images (after INU-correction) using
An anatomical {image_type}-reference map was computed after registration of
{num_files} {image} images (after INU-correction) using
`mri_robust_template` [FreeSurfer {fs_ver}, @fs_template].
"""

Expand Down
Loading

0 comments on commit c361b3a

Please sign in to comment.