Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
WPengXiang committed Oct 17, 2024
1 parent 9bfe688 commit 5c7e22b
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
1 change: 1 addition & 0 deletions fealpy/fem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .linear_form import LinearForm
from .semilinear_form import SemilinearForm
from .block_form import BlockForm
from .linear_block_form import LinearBlockForm

### Cell Operator
from .scalar_diffusion_integrator import ScalarDiffusionIntegrator
Expand Down
34 changes: 34 additions & 0 deletions fealpy/fem/linear_block_form.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from .form import Form
from typing import List,overload, Literal, Optional
from ..typing import TensorLike
from ..sparse import COOTensor

from ..backend import backend_manager as bm




class LinearBlockForm(Form):
_V = None

def __init__(self, blocks:List):
self.blocks = blocks
self.sparse_shape = self._get_sparse_shape()

def _get_sparse_shape(self):
shape = [i._get_sparse_shape() for i in self.blocks]
return (bm.sum(shape), )

@overload
def assembly(self, *, retain_ints: bool=False) -> TensorLike: ...
@overload
def assembly(self, *, format: Literal['coo'], retain_ints: bool=False) -> COOTensor: ...
@overload
def assembly(self, *, format: Literal['dense'], retain_ints: bool=False) -> TensorLike: ...
def assembly(self, *, format='dense', retain_ints: bool=False):
self._V = [i.assembly(format=format, retain_ints=retain_ints)for i in self.blocks]
self._v = bm.concatenate(self._V)
return self._V

Form.register(LinearBlockForm)

1 change: 1 addition & 0 deletions fealpy/mesh/triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def quadrature_formula(self, q: int, etype: Union[int, str]='cell',

if etype == 2:
from ..quadrature.stroud_quadrature import StroudQuadrature
from ..quadrature import TriangleQuadrature
if q > 9:
quad = StroudQuadrature(2, q)
else:
Expand Down

0 comments on commit 5c7e22b

Please sign in to comment.