diff --git a/Project.toml b/Project.toml index 822c5ee2d..7cc90648a 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 37bde4a65..ae4579ef9 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -58,6 +58,7 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2 using LogExpFunctions: softplus using StatsBase using TensorCore +using Tullio using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield # Hack to work around Zygote type inference problems. @@ -66,8 +67,14 @@ const Distances_pairwise = Distances.pairwise abstract type Kernel end abstract type SimpleKernel <: Kernel end +# A general binary op type not respecting Distances metric rules +abstract type AbstractBinaryOp end +const BinaryOp = Union{AbstractBinaryOp,Distances.PreMetric} + include("utils.jl") + include("distances/pairwise.jl") +include("distances/euclidean.jl") include("distances/dotproduct.jl") include("distances/delta.jl") include("distances/sinus.jl") diff --git a/src/distances/delta.jl b/src/distances/delta.jl index f41370eae..f25ba633a 100644 --- a/src/distances/delta.jl +++ b/src/distances/delta.jl @@ -1,6 +1,6 @@ -# Delta is not following the PreMetric rules since d(x, x) == 1 -struct Delta <: Distances.UnionPreMetric end +struct Delta <: AbstractBinaryOp end +# Basic definitions (dist::Delta)(a::Number, b::Number) = a == b Base.@propagate_inbounds function (dist::Delta)( a::AbstractArray{<:Number}, b::AbstractArray{<:Number} @@ -14,5 +14,3 @@ Base.@propagate_inbounds function (dist::Delta)( end return a == b end - -Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool diff --git a/src/distances/dotproduct.jl b/src/distances/dotproduct.jl index 1cef13ab5..7c8095c5e 100644 --- a/src/distances/dotproduct.jl +++ b/src/distances/dotproduct.jl @@ -1,21 +1,31 @@ ## DotProduct is not following the PreMetric rules since d(x, x) != 0 and d(x, y) >= 0 for all x, y -struct DotProduct <: Distances.UnionPreMetric end - -@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector) - @boundscheck if length(a) != length(b) - throw( - DimensionMismatch( - "first array has length $(length(a)) which does not match the length of the second, $(length(b)).", - ), - ) - end - return dot(a, b) +struct DotProduct <: AbstractBinaryOp end + +(::DotProduct)(a::AbstractVector, b::AbstractVector) = dot(a, b) + +(::DotProduct)(a::Number, b::Number) = a * b + +function pairwise(::DotProduct, x::ColVecs, y::ColVecs) + return @tullio out[i, j] := x.X[k, i] * y.X[k, j] +end + +function pairwise(::DotProduct, x::RowVecs, y::RowVecs) + return @tullio out[i, j] := x.X[i, k] * y.X[j, k] end -Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb) +# Simplification for x == y +function colwise(::DotProduct, x::RowVecs) + return @tullio out[i] := x.X[i, k]^2 +end + +function colwise(::DotProduct, x::ColVecs) + return @tullio out[i] := x.X[k, i]^2 +end + +function colwise(::DotProduct, x::RowVecs, y::RowVecs) + return @tullio out[i] := x.X[i, k] * y.X[i, k] +end -@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b -@inline function (dist::DotProduct)(a::AbstractArray, b::AbstractArray) - return Distances._evaluate(dist, a, b) +function colwise(::DotProduct, x::ColVecs, y::ColVecs) + return @tullio out[i] := x.X[k, i] * y.X[k, i] end -@inline (dist::DotProduct)(a::Number, b::Number) = a * b diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl new file mode 100644 index 000000000..b137445d3 --- /dev/null +++ b/src/distances/euclidean.jl @@ -0,0 +1,53 @@ +# Tullio specialization for Euclidean and SqEuclidean metrics + +function pairwise(::Euclidean, x::ColVecs, y::ColVecs) + return @tullio out[i, j] := sqrt <| (x.X[k, i] - y.X[k, j])^2 +end + +function ChainRulesCore.rrule(::typeof(pairwise), d::Euclidean, x::ColVecs, y::ColVecs) + D = pairwise(d, x, y) + function pairwise_pullback(Δ) + @tullio ΔX[l, k] := Δ[k, i] * (x.X[l, k] - y.X[l, i]) / D[k, i] + @tullio ΔY[l, i] := Δ[k, i] * (y.X[l, i] - x.X[l, k]) / D[k, i] + return NoTangent(), NoTangent(), Tangent{ColVecs}(; X=ΔX), Tangent{ColVecs}(; X=ΔY) + end + return D, pairwise_pullback +end + +function pairwise(::Euclidean, x::RowVecs, y::RowVecs) + return @tullio out[i, j] := sqrt <| (x.X[i, k] - y.X[j, k])^2 +end + +function ChainRulesCore.rrule(::typeof(pairwise), d::Euclidean, x::RowVecs, y::RowVecs) + D = pairwise(d, x, y) + function pairwise_pullback(Δ) + @tullio ΔX[k, l] := Δ[k, i] * (x.X[k, l] - y.X[i, l]) / D[k, i] + @tullio ΔY[i, l] := Δ[k, i] * (y.X[i, l] - x.X[k, l]) / D[k, i] + return NoTangent(), NoTangent(), Tangent{RowVecs}(; X=ΔX), Tangent{RowVecs}(; X=ΔY) + end + return D, pairwise_pullback +end + +function colwise(::Euclidean, x::ColVecs, y::ColVecs) + return @tullio out[i] := sqrt <| (x.X[k, i] - y.X[k, i])^2 +end + +function colwise(::Euclidean, x::RowVecs, y::RowVecs) + return @tullio out[i] := sqrt <| (x.X[i, k] - y.X[i, k])^2 +end + +function pairwise(::SqEuclidean, x::ColVecs, y::ColVecs) + return @tullio out[i, j] := (x.X[k, i] - y.X[k, j])^2 +end + +function pairwise(::SqEuclidean, x::RowVecs, y::RowVecs) + return @tullio out[i, j] := (x.X[i, k] - y.X[j, k])^2 +end + +function colwise(::SqEuclidean, x::ColVecs, y::ColVecs) + return @tullio out[i] := (x.X[k, i] - y.X[k, i])^2 +end + +function colwise(::SqEuclidean, x::RowVecs, y::RowVecs) + return @tullio out[i] := (x.X[i, k] - y.X[i, k])^2 +end \ No newline at end of file diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 8b5cb43e7..3a25845a4 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -1,70 +1,28 @@ # Add our own pairwise function to be able to apply it on vectors -function pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector) - return broadcast(d, X, permutedims(Y)) +function pairwise(d::BinaryOp, X::AbstractVector, Y::AbstractVector=X) + return @tullio out[i, j] := d(X[i], Y[j]) end -pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X) - -function pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector, Y::AbstractVector) - return broadcast!(d, out, X, permutedims(Y)) -end - -pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X) - -function pairwise(d::PreMetric, x::AbstractVector{<:Real}) - return Distances_pairwise(d, reshape(x, :, 1); dims=1) -end - -function pairwise(d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) - return Distances_pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1) -end - -function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}) - return Distances.pairwise!(out, d, reshape(x, :, 1); dims=1) -end - -function pairwise!( - out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real} -) - return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1) +function pairwise!(out::AbstractMatrix, d::BinaryOp, X::AbstractVector, Y::AbstractVector=X) + return @tullio out[i, j] = d(X[i], Y[j]) end # Also defines the colwise method for abstractvectors - -function colwise(d::PreMetric, x::AbstractVector) +# We have different methods for PreMetric and AbstractBinaryOp +# Since colwise on AbstractBinaryOp is not guaranteed to be equal to 0 +function colwise(d::Distances.PreMetric, x::AbstractVector) return zeros(Distances.result_type(d, x, x), length(x)) # Valid since d(x,x) == 0 by definition end -function colwise(d::PreMetric, x::ColVecs) +function colwise(d::Distances.PreMetric, x::Union{ColVecs,RowVecs}) return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition end -function colwise(d::PreMetric, x::RowVecs) - return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition -end - -## The following is a hack for DotProduct and Delta to still work -function colwise(d::Distances.UnionPreMetric, x::ColVecs) - return Distances.colwise(d, x.X, x.X) -end - -function colwise(d::Distances.UnionPreMetric, x::RowVecs) - return Distances.colwise(d, x.X', x.X') -end - -function colwise(d::Distances.UnionPreMetric, x::AbstractVector) - return map(d, x, x) -end - -function colwise(d::PreMetric, x::ColVecs, y::ColVecs) - return Distances.colwise(d, x.X, y.X) -end - -function colwise(d::PreMetric, x::RowVecs, y::RowVecs) - return Distances.colwise(d, x.X', y.X') +function colwise(d::AbstractBinaryOp, x::AbstractVector) + return @tullio out[i] := d(x[i], x[i]) end -function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) - return map(d, x, y) +function colwise(d::BinaryOp, x::AbstractVector, y::AbstractVector) + return @tullio out[i] := d(x[i], y[i]) end diff --git a/src/distances/sinus.jl b/src/distances/sinus.jl index 51d14c47d..56759148a 100644 --- a/src/distances/sinus.jl +++ b/src/distances/sinus.jl @@ -1,5 +1,5 @@ -struct Sinus{T} <: Distances.UnionSemiMetric - r::Vector{T} +struct Sinus{T,V<:AbstractVector{T}} <: Distances.SemiMetric + r::V end Sinus(r::Real) = Sinus([r]) diff --git a/src/utils.jl b/src/utils.jl index 75dd62110..5c2034ef9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -80,21 +80,6 @@ Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X)) dim(x::ColVecs) = size(x.X, 1) -pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2) -pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2) -function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs) - return Distances_pairwise(d, reduce(hcat, x), y.X; dims=2) -end -function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector) - return Distances_pairwise(d, x.X, reduce(hcat, y); dims=2) -end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs) - return Distances.pairwise!(out, d, x.X; dims=2) -end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs) - return Distances.pairwise!(out, d, x.X, y.X; dims=2) -end - """ RowVecs(X::AbstractMatrix) @@ -150,25 +135,16 @@ Base.vcat(a::RowVecs, b::RowVecs) = RowVecs(vcat(a.X, b.X)) dim(x::RowVecs) = size(x.X, 2) -pairwise(d::PreMetric, x::RowVecs) = Distances_pairwise(d, x.X; dims=1) -pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1) -function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs) - return Distances_pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1) -end -function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector) - return Distances_pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1) -end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs) - return Distances.pairwise!(out, d, x.X; dims=1) +# Resolve ambiguity error for ColVecs vs RowVecs. #346 +pairwise(d::BinaryOp, x::ColVecs, y::RowVecs) = pairwise(d, x, ColVecs(permutedims(y.X))) +pairwise(d::BinaryOp, x::RowVecs, y::ColVecs) = pairwise(d, ColVecs(permutedims(x.X)), y) +function pairwise!(out::AbstractMatrix, d::BinaryOp, x::ColVecs, y::RowVecs) + return pairwise!(out, d, x, ColVecs(permutedims(y.X))) end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs) - return Distances.pairwise!(out, d, x.X, y.X; dims=1) +function pairwise!(out::AbstractMatrix, d::BinaryOp, x::RowVecs, y::ColVecs) + return pairwise!(out, d, ColVecs(permutedims(x.X)), y) end -# Resolve ambiguity error for ColVecs vs RowVecs. #346 -pairwise(d::PreMetric, x::ColVecs, y::RowVecs) = pairwise(d, x, ColVecs(permutedims(y.X))) -pairwise(d::PreMetric, x::RowVecs, y::ColVecs) = pairwise(d, ColVecs(permutedims(x.X)), y) - dim(x) = 0 # This is the passes-by-default choice. For a proper check, implement `KernelFunctions.dim` for your datatype. dim(x::AbstractVector) = dim(first(x)) dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))