Skip to content

Commit

Permalink
Refactor @einsum (#223)
Browse files Browse the repository at this point in the history
* Refactor `@einsum`

* Update docs for einsum
  • Loading branch information
KeitaNakamura authored Oct 16, 2024
1 parent 8ffe6f4 commit d955d96
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 109 deletions.
183 changes: 74 additions & 109 deletions src/einsum.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,32 @@
"""
@einsum (i,j...) -> expr
@einsum expr
@einsum [TensorType] (i,j...) -> expr
@einsum [TensorType] expr
Conducts tensor computation based on [Einstein summation convention](https://en.wikipedia.org/wiki/Einstein_notation).
The arguments of the anonymous function are regard as **free indices**.
If arguments are not given, they are guessed based on the order that indices appears from left to right.
Performs tensor computations using the [Einstein summation convention](https://en.wikipedia.org/wiki/Einstein_notation).
The arguments of the anonymous function are treated as **free indices**.
If no arguments are provided, they are inferred based on the order
in which the indices appear from left to right. Since `@einsum` cannot
fully infer tensor symmetries, it is possible to annotate the returned
tensor type (though this is not checked for correctness).
This can help eliminate the computation of the symmetric part, improving performance.
# Examples
```jldoctest einsum
julia> A = rand(Mat{3,3})
3×3 Tensor{Tuple{3, 3}, Float64, 2, 9}:
0.325977 0.894245 0.953125
0.549051 0.353112 0.795547
0.218587 0.394255 0.49425
julia> B = rand(Mat{3,3})
3×3 Tensor{Tuple{3, 3}, Float64, 2, 9}:
0.748415 0.00744801 0.682533
0.578232 0.199377 0.956741
0.727935 0.439243 0.647855
julia> @einsum (i,j) -> A[i,k] * B[k,j]
3×3 Tensor{Tuple{3, 3}, Float64, 2, 9}:
1.45486 0.599373 1.69554
1.19421 0.42393 1.22798
0.751346 0.297329 0.846595
julia> @einsum A[i,k] * B[k,j] # same as above
3×3 Tensor{Tuple{3, 3}, Float64, 2, 9}:
1.45486 0.599373 1.69554
1.19421 0.42393 1.22798
0.751346 0.297329 0.846595
julia> @einsum A[i,j] * B[i,j]
2.7026716125808266
julia> A = rand(Mat{3,3});
julia> B = rand(Mat{3,3});
julia> (@einsum (i,j) -> A[j,k] * B[k,i]) ≈ (A ⋅ B)'
true
julia> (@einsum A[i,k] * B[k,j]) ≈ A ⋅ B
true
julia> (@einsum A[i,j] * A[i,j]) ≈ A ⊡ A
true
julia> (@einsum SymmetricSecondOrderTensor{3} A[k,i] * A[k,j]) ≈ A' ⋅ A
true
```
"""
macro einsum(expr)
Expand All @@ -49,7 +42,7 @@ function einsum_exprssion(TT, expr)
freeinds === nothing && return einex.ex
isempty(freeinds) && return einex.ex
perm = find_perm(einex.freeinds => freeinds)
:(convert($TT, permutedims($(einex.ex), $(ValTuple(perm...)))))
:(permutedims($(einex.ex), $(ValTuple(perm...))))
end

ValTuple(x...) = Val(x)
Expand Down Expand Up @@ -156,51 +149,35 @@ function einsum_instantiate_addition(op::Symbol, lhs::EinsumExpr, rhs::EinsumExp
end

# contraction
function einsum_instantiate_contraction(lhs::EinsumExpr, rhs::EinsumExpr)
if isscalarexpr(lhs) || isscalarexpr(rhs)
ex = Expr(:call, :*, lhs.ex, rhs.ex)
return EinsumExpr(ex, [lhs.freeinds; rhs.freeinds], [lhs.allinds; rhs.allinds])
else
freeinds = find_freeindices([lhs.freeinds; rhs.freeinds])
allinds = [lhs.allinds; rhs.allinds]
ex = :($einsum_contraction(Any, $(ValTuple(freeinds...)), ($(lhs.ex), $(rhs.ex)), ($(ValTuple(lhs.freeinds...)), $(ValTuple(rhs.freeinds...)))))
return EinsumExpr(ex, freeinds, allinds)
end
end

function einsum_instantiate_contraction(TT, exprs::Vector{EinsumExpr})
freeinds = find_freeindices(mapreduce(x->x.freeinds, vcat, exprs))

list = findall(exprs) do einex # tensors having only dummy indices
dummies_list = findall(exprs) do einex # tensors having only dummy indices
isscalarexpr(einex) || !any(in(freeinds), einex.freeinds)
end

if !isempty(list)
dummy_tensors = exprs[list]
deleteat!(exprs, list)
if !isempty(exprs)
dummy_tensors = [dummy_tensors; popfirst!(exprs)]
end
push!(exprs, reduce(einsum_instantiate_contraction, dummy_tensors))
# compute dummy indices first
if !isempty(dummies_list)
dummies = exprs[dummies_list]
deleteat!(exprs, dummies_list)
exprs = [dummies; exprs]
end

length(exprs) == 1 && return only(exprs)
# lastly apply `TT`
ex = foldl(einsum_instantiate_contraction, exprs[1:end-1])
einsum_instantiate_contraction(ex, exprs[end], TT)
end

exprs::Vector{EinsumExpr} = foldl(exprs) do x, y
lhs::EinsumExpr = x isa Vector ? x[end] : x
rhs::EinsumExpr = y
if isscalarexpr(lhs) || isscalarexpr(rhs)
ex = Expr(:call, :*, lhs.ex, rhs.ex)
tails = [EinsumExpr(ex, [lhs.freeinds; rhs.freeinds], [lhs.allinds; rhs.allinds])]
else
tails = [lhs, rhs]
end
x isa Vector ? append!(x[1:end-1], tails) : tails
function einsum_instantiate_contraction(lhs::EinsumExpr, rhs::EinsumExpr, TT = :Any)
if isscalarexpr(lhs) || isscalarexpr(rhs)
ex = Expr(:call, :*, lhs.ex, rhs.ex)
return EinsumExpr(ex, [lhs.freeinds; rhs.freeinds], [lhs.allinds; rhs.allinds])
else
freeinds = find_freeindices([lhs.freeinds; rhs.freeinds])
allinds = [lhs.allinds; rhs.allinds]
ex = :($einsum_contraction($TT, $(ValTuple(freeinds...)), ($(lhs.ex), $(rhs.ex)), ($(ValTuple(lhs.freeinds...)), $(ValTuple(rhs.freeinds...)))))
return EinsumExpr(ex, freeinds, allinds)
end

allinds = mapreduce(x->x.allinds, vcat, exprs)
ex = :($einsum_contraction($TT, $(ValTuple(freeinds...)), ($([x.ex for x in exprs]...),), ($([ValTuple(x.freeinds...) for x in exprs]...),)))
return EinsumExpr(ex, freeinds, allinds)
end

# for dummy indices
Expand All @@ -211,62 +188,50 @@ end
x
end

function einsum_contraction_expr(freeinds::Vector, tensors::Vector, tensorinds::Vector{<: AbstractVector})
@assert length(tensors) == length(tensorinds)
function einsum_contraction_expr(free_indices::Vector, tensors::Vector, tensor_indices::Vector{<: AbstractVector})
@assert length(tensors) == length(tensor_indices)

allinds = mapreduce(collect, vcat, tensorinds)
dummyinds = setdiff(allinds, freeinds)
allinds = [freeinds; dummyinds]
all_indices = mapreduce(collect, vcat, tensor_indices)
dummy_indices = setdiff(all_indices, free_indices)
all_indices = [free_indices; dummy_indices]

# check dimensions
dummyaxes = Base.OneTo{Int}[]
for di in dummyinds
dim = 0
count = 0
for (i, inds) in enumerate(tensorinds)
for I in findall(==(di), inds)
if dim == 0
dim = size(tensors[i], I)
push!(dummyaxes, axes(tensors[i], I))
else
size(tensors[i], I) == dim || error("@einsum: dimension mismatch")
end
count += 1
end
dummy_axes = Base.OneTo{Int}[]
for dummy_index in dummy_indices
axs = mapreduce(vcat, zip(tensors, tensor_indices)) do (tensor, inds) # (A, [:i,:j])
map(i -> axes(tensor, i), findall(==(dummy_index), inds))
end
count == 2 || error("@einsum: index $symbol appears more than twice")
length(axs) < 2 && error("@einsum: wrong free indices given")
length(axs) > 2 && error("@einsum: index $dummy_index appears more than twice")
ax = unique(axs)
length(ax) == 1 && push!(dummy_axes, only(ax))
length(ax) > 1 && error("@einsum: dimension mismatch at index $dummy_index")
end

# tensor -> global indices
whichindices = Vector{Int}[]
for (i, inds) in enumerate(tensorinds)
length(inds) == ndims(tensors[i]) || error("@einsum: the number of indices does not match the number of dimensions")
whichinds = map(inds) do index
I = findall(==(index), allinds)
@assert I !== nothing
only(I)
end
push!(whichindices, whichinds)
# create indexmaps from each tensor to `all_indices`
indexmaps = Vector{Int}[]
for (tensor, inds) in zip(tensors, tensor_indices) # (A, [:i,:j])
ndims(tensor) == length(inds) || error("@einsum: the number of indices does not match the order of tensor #$i")
indices = map(index -> only(findall(==(index), all_indices)), inds)
push!(indexmaps, indices)
end

T = promote_type(map(eltype, tensors)...)
if isempty(freeinds)
if isempty(free_indices)
TT = T
freeaxes = ()
free_axes = ()
else
perm = map(freeinds) do index
only(findall(==(index), reduce(vcat, tensorinds)))
end
perm = map(index -> only(findall(==(index), reduce(vcat, tensor_indices))), free_indices)
TT = tensortype(_permutedims(otimes(map(Space, tensors)...), Val(tuple(perm...)))){T}
freeaxes = axes(TT)
free_axes = axes(TT)
end

sumexps = map(CartesianIndices(freeaxes)) do finds
xs = map(CartesianIndices(Tuple(dummyaxes))) do dinds
ainds = [Tuple(finds)..., Tuple(dinds)...]
exps = map(enumerate(tensors)) do (i, t)
inds = ainds[whichindices[i]]
getindex_expr(t, :(tensors[$i]), inds...)
sumexps = map(CartesianIndices(free_axes)) do free_cartesian_index
xs = map(CartesianIndices(Tuple(dummy_axes))) do dummy_cartesian_index
cartesian_index = Tuple(CartesianIndex(free_cartesian_index, dummy_cartesian_index))
exps = map(enumerate(tensors)) do (i, tensor)
indices = cartesian_index[indexmaps[i]]
getindex_expr(tensor, :(tensors[$i]), indices...)
end
Expr(:call, :*, exps...)
end
Expand Down
7 changes: 7 additions & 0 deletions test/einsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ end
check_value_and_type((@einsum -a[μ]*a[v] + b[μ]*b[v]), -a a + b b, (@tensor t[μ,v] := -Array(a)[μ] * Array(a)[v] + Array(b)[μ] * Array(b)[v]))
check_value_and_type((@einsum b[μ]*b[v] + -a[μ]*a[v]), b b + -a a, (@tensor t[μ,v] := Array(b)[μ] * Array(b)[v] + -Array(a)[μ] * Array(a)[v]))
end
@testset "type annotation" begin
A = rand(SecondOrderTensor{4})
B = rand(Tensor{Tuple{4,@Symmetry{4,4}}})
ans = @einsum A[σp,σ]*A[μp,μ]*A[νp,ν]*B[σp,μp,νp]
@test (@einsum Tensor{Tuple{4,@Symmetry{4,4}}, Float64} A[σp,σ]*A[μp,μ]*A[νp,ν]*B[σp,μp,νp])::Tensor{Tuple{4,@Symmetry{4,4}}, Float64} ans
@test (@einsum Tensor{Tuple{4,@Symmetry{4,4}}, Float32} A[σp,σ]*A[μp,μ]*A[νp,ν]*B[σp,μp,νp])::Tensor{Tuple{4,@Symmetry{4,4}}, Float32} ans
end
@testset "errors" begin
S1 = rand(SymmetricSecondOrderTensor{3})
v = rand(Vec{2})
Expand Down

0 comments on commit d955d96

Please sign in to comment.