Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Jp 3749 none #8868

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ jobs:
default_python: "3.12"
envs: |
- linux: check-dependencies
- linux: check-types
latest_crds_contexts:
uses: spacetelescope/crds/.github/workflows/contexts.yml@94138b4501c9487535fd6b977492fc1a2c319708 # 12.0.2
crds_context:
Expand Down
1 change: 1 addition & 0 deletions changes/8852.general.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added mypy type checking to CI checks
2 changes: 1 addition & 1 deletion jwst/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
_regex_git_hash = re.compile(r".*\+g(\w+)")
__version_commit__ = ""
if "+" in __version__:
commit = _regex_git_hash.match(__version__).groups()
commit = _regex_git_hash.match(__version__).groups() # type: ignore
if commit:
__version_commit__ = commit[0]
2 changes: 1 addition & 1 deletion jwst/ami/ami_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import copy

from jwst.datamodels import CubeModel, ImageModel
from jwst.datamodels import CubeModel, ImageModel # type: ignore[attr-defined]

from .find_affine2d_parameters import find_rotation
from . import instrument_data
Expand Down
2 changes: 1 addition & 1 deletion jwst/ami/leastsqnrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def matrix_operations(img, model, flux=None, linfit=False, dqm=None):

if linfit:
try:
from linearfit import linearfit
from linearfit import linearfit # type: ignore[import-not-found]

# dependent variables
M = np.asmatrix(flatimg)
Expand Down
2 changes: 1 addition & 1 deletion jwst/ami/matrix_dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def matrix_idft(*args, **kwargs):
return matrix_dft(*args, **kwargs)


matrix_idft.__doc__ = matrix_dft.__doc__.replace(
matrix_idft.__doc__ = matrix_dft.__doc__.replace( # type: ignore
'Perform a matrix discrete Fourier transform',
'Perform an inverse matrix discrete Fourier transform'
)
Expand Down
9 changes: 5 additions & 4 deletions jwst/assign_wcs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _reproject(x, y):


def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray],
disp_axis: int = None, pscale_ratio: float = None) -> float:
disp_axis: int | None = None, pscale_ratio: float | None = None) -> float:
"""Compute scaling transform.

Parameters
Expand Down Expand Up @@ -137,8 +137,8 @@ def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray],
crval_with_offsets = wcs(*crpix_with_offsets, with_bounding_box=False)

coords = SkyCoord(ra=crval_with_offsets[spatial_idx[0]], dec=crval_with_offsets[spatial_idx[1]], unit="deg")
xscale = np.abs(coords[0].separation(coords[1]).value)
yscale = np.abs(coords[0].separation(coords[2]).value)
xscale: float = np.abs(coords[0].separation(coords[1]).value)
yscale: float = np.abs(coords[0].separation(coords[2]).value)

if pscale_ratio is not None:
xscale *= pscale_ratio
Expand All @@ -149,7 +149,8 @@ def compute_scale(wcs: WCS, fiducial: Union[tuple, np.ndarray],
# Assuming disp_axis is consistent with DataModel.meta.wcsinfo.dispersion.direction
return yscale if disp_axis == 1 else xscale

return np.sqrt(xscale * yscale)
scale: float = np.sqrt(xscale * yscale)
return scale


def calc_rotation_matrix(roll_ref: float, v3i_yang: float, vparity: int = 1) -> List[float]:
Expand Down
2 changes: 1 addition & 1 deletion jwst/associations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

# Take version from the upstream package
from .. import __version__
from jwst import __version__


# Utility
Expand Down
4 changes: 2 additions & 2 deletions jwst/associations/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ class Association(MutableMapping):
GLOBAL_CONSTRAINT = None
"""Global constraints"""

INVALID_VALUES = None
INVALID_VALUES: tuple | None = None
"""Attribute values that indicate the
attribute is not specified.
"""

ioregistry = IORegistry()
ioregistry: IORegistry = IORegistry()
"""The association IO registry"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion jwst/associations/association_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

__all__ = []
__all__: list = []


# Define JSON encoder to convert `Member` to `dict`
Expand Down
2 changes: 1 addition & 1 deletion jwst/associations/lib/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class SimpleConstraintABC(abc.ABC):
"""

# Attributes to show in the string representation.
_str_attrs = ('name', 'value')
_str_attrs: tuple = ('name', 'value')

def __new__(cls, *args, **kwargs):
"""Force creation of the constraint attribute dict before anything else."""
Expand Down
2 changes: 1 addition & 1 deletion jwst/associations/lib/rules_level3_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class DMS_Level3_Base(DMSBaseMixin, Association):
INVALID_VALUES = _EMPTY

# Make sequences type-dependent
_sequences = defaultdict(Counter)
_sequences: defaultdict = defaultdict(Counter)

def __init__(self, *args, **kwargs):

Expand Down
4 changes: 2 additions & 2 deletions jwst/associations/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class BasePoolRule():

# Define the pools and testing parameters related to them.
# Each entry is a tuple starting with the path of the pool.
pools = []
pools: list = []

# Define the rules that SHOULD be present.
# Each entry is the class name of the rule.
valid_rules = []
valid_rules: list = []

def test_rules_exist(self):
rules = registry_level3_only()
Expand Down
11 changes: 5 additions & 6 deletions jwst/badpix_selfcal/badpix_selfcal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

import jwst.datamodels as dm
from jwst.datamodels import IFUImageModel # type: ignore[attr-defined]
from stcal.outlier_detection.utils import medfilt
from stdatamodels.jwst.datamodels.dqflags import pixel

Expand All @@ -12,7 +12,7 @@ def badpix_selfcal(minimg: np.ndarray,
flagfrac_lower: float = 0.001,
flagfrac_upper: float = 0.001,
kernel_size: int = 15,
dispaxis=None) -> np.ndarray:
dispaxis=None) -> tuple:
"""
Flag residual artifacts as bad pixels in the DQ array of a JWST exposure

Expand Down Expand Up @@ -59,26 +59,25 @@ def badpix_selfcal(minimg: np.ndarray,
flag_low, flag_high = np.nanpercentile(minimg_hpf, [flagfrac_lower * 100, (1 - flagfrac_upper) * 100])
bad = (minimg_hpf > flag_high) | (minimg_hpf < flag_low)
flagged_indices = np.where(bad)

return flagged_indices


def apply_flags(input_model: dm.IFUImageModel, flagged_indices: np.ndarray) -> dm.IFUImageModel:
def apply_flags(input_model: IFUImageModel, flagged_indices: np.ndarray) -> IFUImageModel:
"""
Apply the flagged indices to the input model. Sets the flagged pixels to NaN
and the DQ flag to DO_NOT_USE + OTHER_BAD_PIXEL

Parameters
----------
input_model : dm.IFUImageModel
input_model : IFUImageModel
Input science data to be corrected
flagged_indices : np.ndarray
Indices of the flagged pixels,
shaped like output from np.where

Returns
-------
output_model : dm.IFUImageModel
output_model : IFUImageModel
Flagged data model
"""

Expand Down
2 changes: 1 addition & 1 deletion jwst/cube_skymatch/cube_skymatch_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class CubeSkyMatchStep(Step):
binwidth = float(min=0.0, default=0.1) # Bin width for 'mode' and 'midpt' `skystat`, in sigma
"""

reference_file_types = []
reference_file_types: list = []

def process(self, input1, input2):
cube_models = ModelContainer(input1)
Expand Down
2 changes: 1 addition & 1 deletion jwst/datamodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
_deprecated_modules = ['schema']

# Deprecated models in stdatamodels
_deprecated_models = []
_deprecated_models: list[str] = []

# Import all submodules from stdatamodels.jwst.datamodels
for attr in dir(stdatamodels.jwst.datamodels):
Expand Down
2 changes: 1 addition & 1 deletion jwst/dq_init/dq_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from stdatamodels.jwst import datamodels

from ..lib import reffile_utils
from jwst.datamodels import dqflags
from jwst.datamodels import dqflags # type: ignore[attr-defined]

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
Expand Down
4 changes: 2 additions & 2 deletions jwst/dq_init/tests/test_dq_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,12 @@ def test_dq_add1_groupdq():

# Set parameters for multiple runs of guider data
args = "xstart, ystart, xsize, ysize, nints, ngroups, instrument, exp_type, detector"
test_data = [(1, 1, 2048, 2048, 2, 2, 'FGS', 'FGS_ID-IMAGE', 'GUIDER1'),
test_data_multiple = [(1, 1, 2048, 2048, 2, 2, 'FGS', 'FGS_ID-IMAGE', 'GUIDER1'),
(1, 1, 1032, 1024, 1, 5, 'MIRI', 'MIR_IMAGE', 'MIRIMAGE')]
ids = ["GuiderRawModel-Image", "RampModel"]


@pytest.mark.parametrize(args, test_data, ids=ids)
@pytest.mark.parametrize(args, test_data_multiple, ids=ids)
def test_fullstep(xstart, ystart, xsize, ysize, nints, ngroups, instrument, exp_type, detector):
"""Test that the full step runs"""

Expand Down
37 changes: 18 additions & 19 deletions jwst/extract_1d/apply_apcorr.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import abc

from typing import Tuple, Union, Type
from scipy.interpolate import RectBivariateSpline, interp1d
from astropy.io import fits
from stdatamodels import DataModel
from stdatamodels.jwst.datamodels import MultiSlitModel

from ..assign_wcs.util import compute_scale
Expand Down Expand Up @@ -55,10 +52,8 @@ class ApCorrBase(abc.ABC):
}
}

size_key = None

def __init__(self, input_model: DataModel, apcorr_table: fits.FITS_rec, sizeunit: str,
location: Tuple[float, float, float] = None, slit_name: str = None, **match_kwargs):
def __init__(self, input_model, apcorr_table, sizeunit,
location = None, slit_name = None, **match_kwargs):
self.correction = None

self.model = input_model
Expand All @@ -75,6 +70,10 @@ def __init__(self, input_model: DataModel, apcorr_table: fits.FITS_rec, sizeunit
self.apcorr_func = self.approximate()
self.tabulated_correction = None

@property
def size_key(self):
return None

def _convert_size_units(self):
"""If the SIZE or Radius column is in units of arcseconds, convert to pixels."""
if self.apcorr_sizeunits.startswith('arcsec'):
Expand Down Expand Up @@ -102,7 +101,7 @@ def _convert_size_units(self):
'pixels.'
)

def _get_match_keys(self) -> dict:
def _get_match_keys(self):
"""Get column keys needed for reducing the reference table based on input."""
instrument = self.model.meta.instrument.name.upper()
exptype = self.model.meta.exposure.type.upper()
Expand All @@ -113,7 +112,7 @@ def _get_match_keys(self) -> dict:
if key in exptype:
return relevant_pars[key]

def _get_match_pars(self) -> dict:
def _get_match_pars(self):
"""Get meta parameters required for reference table row-selection."""
match_pars = {}

Expand All @@ -125,7 +124,7 @@ def _get_match_pars(self) -> dict:

return match_pars

def _reduce_reftable(self) -> fits.FITS_record:
def _reduce_reftable(self):
"""Reduce full reference table to a single matched row."""
table = self._reference_table.copy()

Expand All @@ -145,7 +144,7 @@ def approximate(self):
"""Generate an approximate aperture correction function based on input data."""
pass

def apply(self, spec_table: fits.FITS_rec):
def apply(self, spec_table):
"""Apply interpolated aperture correction value to source-related extraction results in-place.

Parameters
Expand Down Expand Up @@ -181,14 +180,14 @@ class ApCorrPhase(ApCorrBase):
"""
size_key = 'size'

def __init__(self, *args, pixphase: float = 0.5, **kwargs):
def __init__(self, *args, pixphase = 0.5, **kwargs):
self.phase = pixphase # In the future we'll attempt to measure the pixel phase from inputs.

super().__init__(*args, **kwargs)

def approximate(self):
"""Generate an approximate function for interpolating apcorr values to input wavelength and size."""
def _approx_func(wavelength: float, size: float, pixel_phase: float) -> RectBivariateSpline:
def _approx_func(wavelength, size, pixel_phase):
"""Create a 'custom' approximation function that approximates the aperture correction in two stages based on
input data.

Expand Down Expand Up @@ -228,7 +227,7 @@ def _approx_func(wavelength: float, size: float, pixel_phase: float) -> RectBiva
def measure_phase(self): # Future method in determining pixel phase
pass

def tabulate_correction(self, spec_table: fits.FITS_rec):
def tabulate_correction(self, spec_table):
"""Tabulate the interpolated aperture correction value.

This will save time when applying it later, especially if it is to be applied to multiple integrations.
Expand All @@ -255,7 +254,7 @@ def tabulate_correction(self, spec_table: fits.FITS_rec):

self.tabulated_correction = np.asarray(coefs)

def apply(self, spec_table: fits.FITS_rec, use_tabulated=False):
def apply(self, spec_table, use_tabulated=False):
"""Apply interpolated aperture correction value to source-related extraction results in-place.

Parameters
Expand Down Expand Up @@ -297,8 +296,8 @@ def apply(self, spec_table: fits.FITS_rec, use_tabulated=False):
class ApCorrRadial(ApCorrBase):
"""Aperture correction class used with spectral data produced from an extraction aperture radius."""

def __init__(self, input_model: DataModel, apcorr_table,
location: Tuple[float, float, float] = None):
def __init__(self, input_model, apcorr_table,
location = None):

self.correction = None
self.model = input_model
Expand Down Expand Up @@ -329,7 +328,7 @@ def _convert_size_units(self):
'pixels.'
)

def apply(self, spec_table: fits.FITS_rec):
def apply(self, spec_table):
"""Apply interpolated aperture correction value to source-related extraction results in-place.

Parameters
Expand Down Expand Up @@ -410,7 +409,7 @@ def approximate(self):
return RectBivariateSpline(size, wavelength, apcorr.T, ky=1, kx=1)


def select_apcorr(input_model: DataModel) -> Union[Type[ApCorr], Type[ApCorrPhase], Type[ApCorrRadial]]:
def select_apcorr(input_model):
"""Select appropriate Aperture correction class based on input DataModel.

Parameters
Expand Down
Loading
Loading