Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
tanish1729 committed Aug 5, 2024
1 parent 365b6df commit 90b0e43
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,19 +659,19 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
if not isinstance(node.op.core_op, Cholesky):
return None

inputs = node.inputs[0]
[input] = node.inputs
# Check for use of pt.diag first
if (
inputs.owner
and isinstance(inputs.owner.op, AllocDiag)
and AllocDiag.is_offset_zero(inputs.owner)
input.owner
and isinstance(input.owner.op, AllocDiag)
and AllocDiag.is_offset_zero(input.owner)
):
diag_input = inputs.owner.inputs[0]
diag_input = input.owner.inputs[0]
cholesky_val = pt.diag(diag_input**0.5)
return [cholesky_val]

# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
inputs_or_none = _find_diag_from_eye_mul(inputs)
inputs_or_none = _find_diag_from_eye_mul(input)
if inputs_or_none is None:
return None

Expand All @@ -681,7 +681,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
if len(non_eye_inputs) != 1:
return None

Check warning on line 682 in pytensor/tensor/rewriting/linalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/linalg.py#L682

Added line #L682 was not covered by tests

non_eye_input = non_eye_inputs[0]
[non_eye_input] = non_eye_inputs

# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
Expand Down

0 comments on commit 90b0e43

Please sign in to comment.