Skip to content

Commit

Permalink
Fix properties to be batched
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Jan 6, 2024
1 parent a077aa2 commit c893beb
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def is_skippable(self) -> bool:

@property
def is_active(self) -> bool:
return self.k1 != 0
return any(self.k1 != 0)

def split(self, resolution: torch.Tensor) -> list[Element]:
split_elements = []
Expand Down Expand Up @@ -472,10 +472,11 @@ def __init__(

@property
def hx(self) -> torch.Tensor:
if self.length == 0.0:
return torch.tensor(0.0, device=self.length.device, dtype=self.length.dtype)
else:
return self.angle / self.length
value = torch.zeros_like(self.length)
value[self.length != 0] = (
self.angle[self.length != 0] / self.length[self.length != 0]
)
return value

@property
def is_skippable(self) -> bool:
Expand Down Expand Up @@ -712,7 +713,7 @@ def is_skippable(self) -> bool:

@property
def is_active(self) -> bool:
return self.angle != 0
return any(self.angle != 0)

def split(self, resolution: torch.Tensor) -> list[Element]:
split_elements = []
Expand Down Expand Up @@ -799,7 +800,7 @@ def is_skippable(self) -> bool:

@property
def is_active(self) -> bool:
return self.angle != 0
return any(self.angle != 0)

def split(self, resolution: torch.Tensor) -> list[Element]:
split_elements = []
Expand Down Expand Up @@ -875,7 +876,7 @@ def __init__(

@property
def is_active(self) -> bool:
return self.voltage != 0
return any(self.voltage != 0)

@property
def is_skippable(self) -> bool:
Expand Down Expand Up @@ -1798,7 +1799,7 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:

@property
def is_active(self) -> bool:
return self.k != 0
return any(self.k != 0)

def is_skippable(self) -> bool:
return True
Expand Down

0 comments on commit c893beb

Please sign in to comment.