Skip to content

Commit

Permalink
Expect symbolic num argument
Browse files Browse the repository at this point in the history
Fill out `_broadcast_inputs` docstring
  • Loading branch information
jessegrabowski committed May 18, 2024
1 parent 6ad98e4 commit 5499ae5
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,8 @@ def _linspace_core(


def _broadcast_inputs(*args):
"""Helper function to preprocess inputs to *space Ops"""

args = map(ptb.as_tensor_variable, args)
args = broadcast_arrays(*args)

Expand All @@ -1651,14 +1653,23 @@ def _broadcast_base_with_inputs(start, stop, base, axis):
Parameters
----------
start
stop
base
axis
start: TensorVariable
The start value(s) of the sequence(s).
stop: TensorVariable
The end value(s) of the sequence(s)
base: TensorVariable
The log base value(s) of the sequence(s)
axis: int
The axis along which to generate samples.
Returns
-------
start: TensorVariable
The start value(s) of the sequence(s), broadcast with the base tensor if necessary.
stop: TensorVariable
The end value(s) of the sequence(s), broadcast with the base tensor if necessary.
base: TensorVariable
The log base value(s) of the sequence(s), broadcast with the start and stop tensors if necessary.
"""
base = ptb.as_tensor_variable(base)
if base.ndim > 0:
Expand Down Expand Up @@ -1839,10 +1850,9 @@ def geomspace(
)
result = base**result

if num > 0:
result = set_subtensor(result[0, ...], start)
if num > 1 and endpoint:
result = set_subtensor(result[-1, ...], stop)
result = switch(gt(num, 0), set_subtensor(result[0, ...], start), result)
if endpoint:
result = switch(gt(num, 1), set_subtensor(result[-1, ...], stop), result)

result = result * out_sign

Expand Down

0 comments on commit 5499ae5

Please sign in to comment.