From 79efb28a44651738f17bd41b1fddd618bfc6a24e Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Tue, 19 Oct 2021 16:43:01 +0200 Subject: [PATCH 01/15] Introduction of Tullio to perform pairwise operations --- Project.toml | 1 + src/KernelFunctions.jl | 8 +++++ src/distances/binaryop.jl | 0 src/distances/delta.jl | 6 ++-- src/distances/dotproduct.jl | 31 ++++++++++--------- src/distances/pairwise.jl | 61 +++++++------------------------------ src/distances/sinus.jl | 4 +-- 7 files changed, 40 insertions(+), 71 deletions(-) create mode 100644 src/distances/binaryop.jl diff --git a/Project.toml b/Project.toml index 2c0bafd1d..047926f01 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 37bde4a65..e65f8e591 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -58,6 +58,7 @@ using IrrationalConstants: logtwo, twoπ, invsqrt2 using LogExpFunctions: softplus using StatsBase using TensorCore +using Tullio using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield # Hack to work around Zygote type inference problems. @@ -67,6 +68,13 @@ abstract type Kernel end abstract type SimpleKernel <: Kernel end include("utils.jl") + +const VecOfVecs = Union{ColVecs,RowVecs} + +# A general binary op type not respecting Distances metric rules +abstract type AbstractBinaryOp end +const BinaryOp = Union{AbstractBinaryOp,Distances.PreMetric} + include("distances/pairwise.jl") include("distances/dotproduct.jl") include("distances/delta.jl") diff --git a/src/distances/binaryop.jl b/src/distances/binaryop.jl new file mode 100644 index 000000000..e69de29bb diff --git a/src/distances/delta.jl b/src/distances/delta.jl index f41370eae..f25ba633a 100644 --- a/src/distances/delta.jl +++ b/src/distances/delta.jl @@ -1,6 +1,6 @@ -# Delta is not following the PreMetric rules since d(x, x) == 1 -struct Delta <: Distances.UnionPreMetric end +struct Delta <: AbstractBinaryOp end +# Basic definitions (dist::Delta)(a::Number, b::Number) = a == b Base.@propagate_inbounds function (dist::Delta)( a::AbstractArray{<:Number}, b::AbstractArray{<:Number} @@ -14,5 +14,3 @@ Base.@propagate_inbounds function (dist::Delta)( end return a == b end - -Distances.result_type(::Delta, Ta::Type, Tb::Type) = Bool diff --git a/src/distances/dotproduct.jl b/src/distances/dotproduct.jl index 1cef13ab5..f689d3cb1 100644 --- a/src/distances/dotproduct.jl +++ b/src/distances/dotproduct.jl @@ -1,21 +1,22 @@ ## DotProduct is not following the PreMetric rules since d(x, x) != 0 and d(x, y) >= 0 for all x, y -struct DotProduct <: Distances.UnionPreMetric end +struct DotProduct <: AbstractBinaryOp end -@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector) - @boundscheck if length(a) != length(b) - throw( - DimensionMismatch( - "first array has length $(length(a)) which does not match the length of the second, $(length(b)).", - ), - ) - end - return dot(a, b) +(::DotProduct)(a::AbstractVector, b::AbstractVector) = dot(a, b) + +(::DotProduct)(a::Number, b::Number) = a * b + +function pairwise(::DotProduct, x::ColVecs, y::ColVecs) + return @tullio out[i, j] := x.X[k, i] * y.X[k, j] end -Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb) +function pairwise(::DotProduct, x::RowVecs, y::RowVecs) + return @tullio out[i, j] := x.X[i, k] * y.X[j, k] +end -@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b -@inline function (dist::DotProduct)(a::AbstractArray, b::AbstractArray) - return Distances._evaluate(dist, a, b) +function colwise(::DotProduct, x::RowVecs, y::RowVecs=x) + return @tullio out[i] := x.X[i, k] * y.X[i, k] end -@inline (dist::DotProduct)(a::Number, b::Number) = a * b + +function colwise(::DotProduct, x::ColVecs, y::ColVecs=x) + return @tullio out[i] := x.X[k, i] * y.X[k, i] +end \ No newline at end of file diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 8b5cb43e7..daa894a81 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -1,34 +1,16 @@ # Add our own pairwise function to be able to apply it on vectors -function pairwise(d::PreMetric, X::AbstractVector, Y::AbstractVector) - return broadcast(d, X, permutedims(Y)) +function pairwise(d::BinaryOp, X::AbstractVector, Y::AbstractVector) + return @tullio out[i, j] := d(X[i], Y[j]) end -pairwise(d::PreMetric, X::AbstractVector) = pairwise(d, X, X) +pairwise(d::BinaryOp, X::AbstractVector) = pairwise(d, X, X) -function pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector, Y::AbstractVector) - return broadcast!(d, out, X, permutedims(Y)) +function pairwise!(out::AbstractMatrix, d::BinaryOp, X::AbstractVector, Y::AbstractVector) + return @tullio out[i, j] = d(X[i], Y[j]) end -pairwise!(out::AbstractMatrix, d::PreMetric, X::AbstractVector) = pairwise!(out, d, X, X) - -function pairwise(d::PreMetric, x::AbstractVector{<:Real}) - return Distances_pairwise(d, reshape(x, :, 1); dims=1) -end - -function pairwise(d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) - return Distances_pairwise(d, reshape(x, :, 1), reshape(y, :, 1); dims=1) -end - -function pairwise!(out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}) - return Distances.pairwise!(out, d, reshape(x, :, 1); dims=1) -end - -function pairwise!( - out::AbstractMatrix, d::PreMetric, x::AbstractVector{<:Real}, y::AbstractVector{<:Real} -) - return Distances.pairwise!(out, d, reshape(x, :, 1), reshape(y, :, 1); dims=1) -end +pairwise!(out::AbstractMatrix, d::BinaryOp, X::AbstractVector) = pairwise!(out, d, X, X) # Also defines the colwise method for abstractvectors @@ -36,35 +18,14 @@ function colwise(d::PreMetric, x::AbstractVector) return zeros(Distances.result_type(d, x, x), length(x)) # Valid since d(x,x) == 0 by definition end -function colwise(d::PreMetric, x::ColVecs) - return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition -end - -function colwise(d::PreMetric, x::RowVecs) +function colwise(d::PreMetric, x::VecOfVecs) return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition end -## The following is a hack for DotProduct and Delta to still work -function colwise(d::Distances.UnionPreMetric, x::ColVecs) - return Distances.colwise(d, x.X, x.X) -end - -function colwise(d::Distances.UnionPreMetric, x::RowVecs) - return Distances.colwise(d, x.X', x.X') -end - -function colwise(d::Distances.UnionPreMetric, x::AbstractVector) - return map(d, x, x) -end - -function colwise(d::PreMetric, x::ColVecs, y::ColVecs) - return Distances.colwise(d, x.X, y.X) -end - -function colwise(d::PreMetric, x::RowVecs, y::RowVecs) - return Distances.colwise(d, x.X', y.X') +function colwise(d::AbstractBinaryOp, x::AbstractVector) + return @tullio out[i] := d(x[i], x[i]) end function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) - return map(d, x, y) -end + return @tullio out[i] := d(x[i], y[i]) +end \ No newline at end of file diff --git a/src/distances/sinus.jl b/src/distances/sinus.jl index 4bcf4bdf0..1c063f41e 100644 --- a/src/distances/sinus.jl +++ b/src/distances/sinus.jl @@ -1,5 +1,5 @@ -struct Sinus{T} <: Distances.UnionSemiMetric - r::Vector{T} +struct Sinus{T,V<:AbstractVector{T}} <: Distances.SemiMetric + r::V end Distances.parameters(d::Sinus) = d.r From a6afab117243e5a97e3d2f0d4f3d75cd6ec844ed Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Tue, 19 Oct 2021 17:01:59 +0200 Subject: [PATCH 02/15] Fix formatting --- src/distances/binaryop.jl | 0 src/distances/dotproduct.jl | 2 +- src/distances/pairwise.jl | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) delete mode 100644 src/distances/binaryop.jl diff --git a/src/distances/binaryop.jl b/src/distances/binaryop.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/distances/dotproduct.jl b/src/distances/dotproduct.jl index f689d3cb1..fb89a4bee 100644 --- a/src/distances/dotproduct.jl +++ b/src/distances/dotproduct.jl @@ -19,4 +19,4 @@ end function colwise(::DotProduct, x::ColVecs, y::ColVecs=x) return @tullio out[i] := x.X[k, i] * y.X[k, i] -end \ No newline at end of file +end diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index daa894a81..00fb9620f 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -28,4 +28,4 @@ end function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) return @tullio out[i] := d(x[i], y[i]) -end \ No newline at end of file +end From ec4b5cdae2e9af789599fd628e2dd677aa38fb97 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Tue, 19 Oct 2021 17:08:02 +0200 Subject: [PATCH 03/15] Fix one colwise --- src/distances/pairwise.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 00fb9620f..2327ac3ce 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -26,6 +26,6 @@ function colwise(d::AbstractBinaryOp, x::AbstractVector) return @tullio out[i] := d(x[i], x[i]) end -function colwise(d::PreMetric, x::AbstractVector, y::AbstractVector) +function colwise(d::BinaryOp, x::AbstractVector, y::AbstractVector) return @tullio out[i] := d(x[i], y[i]) end From 5b279f279555f6d28b7978a204d689c69b433958 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Tue, 19 Oct 2021 17:11:56 +0200 Subject: [PATCH 04/15] Corrections --- src/distances/pairwise.jl | 15 ++++++--------- src/utils.jl | 36 ++++-------------------------------- 2 files changed, 10 insertions(+), 41 deletions(-) diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 2327ac3ce..4751c92bb 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -1,24 +1,21 @@ # Add our own pairwise function to be able to apply it on vectors -function pairwise(d::BinaryOp, X::AbstractVector, Y::AbstractVector) +function pairwise(d::BinaryOp, X::AbstractVector, Y::AbstractVector=X) return @tullio out[i, j] := d(X[i], Y[j]) end -pairwise(d::BinaryOp, X::AbstractVector) = pairwise(d, X, X) - -function pairwise!(out::AbstractMatrix, d::BinaryOp, X::AbstractVector, Y::AbstractVector) +function pairwise!(out::AbstractMatrix, d::BinaryOp, X::AbstractVector, Y::AbstractVector=X) return @tullio out[i, j] = d(X[i], Y[j]) end -pairwise!(out::AbstractMatrix, d::BinaryOp, X::AbstractVector) = pairwise!(out, d, X, X) - # Also defines the colwise method for abstractvectors - -function colwise(d::PreMetric, x::AbstractVector) +# We have different methods for PreMetric and AbstractBinaryOp +# Since colwise on AbstractBinaryOp is not guaranteed to be equal to 0 +function colwise(d::Distances.PreMetric, x::AbstractVector) return zeros(Distances.result_type(d, x, x), length(x)) # Valid since d(x,x) == 0 by definition end -function colwise(d::PreMetric, x::VecOfVecs) +function colwise(d::Distances.PreMetric, x::VecOfVecs) return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition end diff --git a/src/utils.jl b/src/utils.jl index 7eea4358c..f32cb49b6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -80,21 +80,6 @@ Base.vcat(a::ColVecs, b::ColVecs) = ColVecs(hcat(a.X, b.X)) dim(x::ColVecs) = size(x.X, 1) -pairwise(d::PreMetric, x::ColVecs) = Distances_pairwise(d, x.X; dims=2) -pairwise(d::PreMetric, x::ColVecs, y::ColVecs) = Distances_pairwise(d, x.X, y.X; dims=2) -function pairwise(d::PreMetric, x::AbstractVector, y::ColVecs) - return Distances_pairwise(d, reduce(hcat, x), y.X; dims=2) -end -function pairwise(d::PreMetric, x::ColVecs, y::AbstractVector) - return Distances_pairwise(d, x.X, reduce(hcat, y); dims=2) -end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs) - return Distances.pairwise!(out, d, x.X; dims=2) -end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::ColVecs, y::ColVecs) - return Distances.pairwise!(out, d, x.X, y.X; dims=2) -end - """ RowVecs(X::AbstractMatrix) @@ -150,24 +135,11 @@ Base.vcat(a::RowVecs, b::RowVecs) = RowVecs(vcat(a.X, b.X)) dim(x::RowVecs) = size(x.X, 2) -pairwise(d::PreMetric, x::RowVecs) = Distances_pairwise(d, x.X; dims=1) -pairwise(d::PreMetric, x::RowVecs, y::RowVecs) = Distances_pairwise(d, x.X, y.X; dims=1) -function pairwise(d::PreMetric, x::AbstractVector, y::RowVecs) - return Distances_pairwise(d, permutedims(reduce(hcat, x)), y.X; dims=1) -end -function pairwise(d::PreMetric, x::RowVecs, y::AbstractVector) - return Distances_pairwise(d, x.X, permutedims(reduce(hcat, y)); dims=1) -end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs) - return Distances.pairwise!(out, d, x.X; dims=1) -end -function pairwise!(out::AbstractMatrix, d::PreMetric, x::RowVecs, y::RowVecs) - return Distances.pairwise!(out, d, x.X, y.X; dims=1) -end - # Resolve ambiguity error for ColVecs vs RowVecs. #346 -pairwise(d::PreMetric, x::ColVecs, y::RowVecs) = pairwise(d, x, ColVecs(permutedims(y.X))) -pairwise(d::PreMetric, x::RowVecs, y::ColVecs) = pairwise(d, ColVecs(permutedims(x.X)), y) +pairwise(d::BinaryOp, x::ColVecs, y::RowVecs) = pairwise(d, x, ColVecs(permutedims(y.X))) +pairwise(d::BinaryOp, x::RowVecs, y::ColVecs) = pairwise(d, ColVecs(permutedims(x.X)), y) +pairwise!(out::AbstractMatrix, d::BinaryOp, x::ColVecs, y::RowVecs) = pairwise!(out, d, x, ColVecs(permutedims(y.X))) +pairwise!(out::AbstractMatrix, d::BinaryOp, x::RowVecs, y::ColVecs) = pairwise!(out, d, ColVecs(permutedims(x.X)), y) dim(x) = 0 # This is the passes-by-default choice. For a proper check, implement `KernelFunctions.dim` for your datatype. dim(x::AbstractVector) = dim(first(x)) From 5f3ea125f9ba1dd0d1b77a490836abb088a37b07 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Tue, 19 Oct 2021 17:22:00 +0200 Subject: [PATCH 05/15] Add specialization Euclidean --- src/KernelFunctions.jl | 1 + src/distances/euclidean.jl | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 src/distances/euclidean.jl diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index e65f8e591..8e8549732 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -76,6 +76,7 @@ abstract type AbstractBinaryOp end const BinaryOp = Union{AbstractBinaryOp,Distances.PreMetric} include("distances/pairwise.jl") +include("distances/euclidean.jl") include("distances/dotproduct.jl") include("distances/delta.jl") include("distances/sinus.jl") diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl new file mode 100644 index 000000000..78200a4f7 --- /dev/null +++ b/src/distances/euclidean.jl @@ -0,0 +1,17 @@ +# Tullio specialization for Euclidean and SqEuclidean metrics + +function pairwise(::Euclidean, x::ColVecs, y::ColVecs) + return @tullio out[i, j] := sqrt <| x.X[k, i] ^ 2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j] ^ 2 +end + +function pairwise(::Euclidean, x::RowVecs, y::RowVecs) + return @tullio out[i, j] := sqrt <| x.X[i, k] ^ 2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k] ^ 2 +end + +function pairwise(::SqEuclidean, x::ColVecs, y::ColVecs) + return @tullio out[i, j] := x.X[k, i] ^ 2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j] ^ 2 +end + +function pairwise(::SqEuclidean, x::RowVecs, y::RowVecs) + return @tullio out[i, j] := x.X[i, k] ^ 2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k] ^ 2 +end \ No newline at end of file From 8ac029446fb8ff5ed5791e548783e27405e6cf25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 19 Oct 2021 17:25:00 +0200 Subject: [PATCH 06/15] Formatting fixes Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/distances/euclidean.jl | 10 ++++++---- src/utils.jl | 8 ++++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl index 78200a4f7..c6bcebb5b 100644 --- a/src/distances/euclidean.jl +++ b/src/distances/euclidean.jl @@ -1,17 +1,19 @@ # Tullio specialization for Euclidean and SqEuclidean metrics function pairwise(::Euclidean, x::ColVecs, y::ColVecs) - return @tullio out[i, j] := sqrt <| x.X[k, i] ^ 2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j] ^ 2 + return @tullio out[i, j] := + sqrt <| x.X[k, i]^2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j]^2 end function pairwise(::Euclidean, x::RowVecs, y::RowVecs) - return @tullio out[i, j] := sqrt <| x.X[i, k] ^ 2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k] ^ 2 + return @tullio out[i, j] := + sqrt <| x.X[i, k]^2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k]^2 end function pairwise(::SqEuclidean, x::ColVecs, y::ColVecs) - return @tullio out[i, j] := x.X[k, i] ^ 2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j] ^ 2 + return @tullio out[i, j] := x.X[k, i]^2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j]^2 end function pairwise(::SqEuclidean, x::RowVecs, y::RowVecs) - return @tullio out[i, j] := x.X[i, k] ^ 2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k] ^ 2 + return @tullio out[i, j] := x.X[i, k]^2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k]^2 end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index f32cb49b6..f9c7a2345 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -138,8 +138,12 @@ dim(x::RowVecs) = size(x.X, 2) # Resolve ambiguity error for ColVecs vs RowVecs. #346 pairwise(d::BinaryOp, x::ColVecs, y::RowVecs) = pairwise(d, x, ColVecs(permutedims(y.X))) pairwise(d::BinaryOp, x::RowVecs, y::ColVecs) = pairwise(d, ColVecs(permutedims(x.X)), y) -pairwise!(out::AbstractMatrix, d::BinaryOp, x::ColVecs, y::RowVecs) = pairwise!(out, d, x, ColVecs(permutedims(y.X))) -pairwise!(out::AbstractMatrix, d::BinaryOp, x::RowVecs, y::ColVecs) = pairwise!(out, d, ColVecs(permutedims(x.X)), y) +function pairwise!(out::AbstractMatrix, d::BinaryOp, x::ColVecs, y::RowVecs) + return pairwise!(out, d, x, ColVecs(permutedims(y.X))) +end +function pairwise!(out::AbstractMatrix, d::BinaryOp, x::RowVecs, y::ColVecs) + return pairwise!(out, d, ColVecs(permutedims(x.X)), y) +end dim(x) = 0 # This is the passes-by-default choice. For a proper check, implement `KernelFunctions.dim` for your datatype. dim(x::AbstractVector) = dim(first(x)) From 98a9f634f1813ff89c75656ae538708a445ae637 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Tue, 19 Oct 2021 17:30:43 +0200 Subject: [PATCH 07/15] Remove VecOfVecs and reorder types --- src/KernelFunctions.jl | 6 ++---- src/distances/euclidean.jl | 2 +- src/distances/pairwise.jl | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 8e8549732..ae4579ef9 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -67,14 +67,12 @@ const Distances_pairwise = Distances.pairwise abstract type Kernel end abstract type SimpleKernel <: Kernel end -include("utils.jl") - -const VecOfVecs = Union{ColVecs,RowVecs} - # A general binary op type not respecting Distances metric rules abstract type AbstractBinaryOp end const BinaryOp = Union{AbstractBinaryOp,Distances.PreMetric} +include("utils.jl") + include("distances/pairwise.jl") include("distances/euclidean.jl") include("distances/dotproduct.jl") diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl index 78200a4f7..a6e1a982e 100644 --- a/src/distances/euclidean.jl +++ b/src/distances/euclidean.jl @@ -14,4 +14,4 @@ end function pairwise(::SqEuclidean, x::RowVecs, y::RowVecs) return @tullio out[i, j] := x.X[i, k] ^ 2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k] ^ 2 -end \ No newline at end of file +end diff --git a/src/distances/pairwise.jl b/src/distances/pairwise.jl index 4751c92bb..3a25845a4 100644 --- a/src/distances/pairwise.jl +++ b/src/distances/pairwise.jl @@ -15,7 +15,7 @@ function colwise(d::Distances.PreMetric, x::AbstractVector) return zeros(Distances.result_type(d, x, x), length(x)) # Valid since d(x,x) == 0 by definition end -function colwise(d::Distances.PreMetric, x::VecOfVecs) +function colwise(d::Distances.PreMetric, x::Union{ColVecs,RowVecs}) return zeros(Distances.result_type(d, x.X, x.X), length(x)) # Valid since d(x,x) == 0 by definition end From cabf7156f2080b244274cf61fc8cba36a014592d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 26 Oct 2021 11:08:16 +0200 Subject: [PATCH 08/15] Update src/distances/euclidean.jl Co-authored-by: David Widmann --- src/distances/euclidean.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl index 96b58c69a..ad0ae37d9 100644 --- a/src/distances/euclidean.jl +++ b/src/distances/euclidean.jl @@ -2,7 +2,7 @@ function pairwise(::Euclidean, x::ColVecs, y::ColVecs) return @tullio out[i, j] := - sqrt <| x.X[k, i]^2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j]^2 + sqrt <| (x.X[k, i] - y.X[k, j])^2 end function pairwise(::Euclidean, x::RowVecs, y::RowVecs) From 3d16cd7c29a71f6135fd8be5d6a16a490abfa9ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 26 Oct 2021 12:20:32 +0200 Subject: [PATCH 09/15] Update src/distances/euclidean.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/distances/euclidean.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl index ad0ae37d9..8030620f0 100644 --- a/src/distances/euclidean.jl +++ b/src/distances/euclidean.jl @@ -1,8 +1,7 @@ # Tullio specialization for Euclidean and SqEuclidean metrics function pairwise(::Euclidean, x::ColVecs, y::ColVecs) - return @tullio out[i, j] := - sqrt <| (x.X[k, i] - y.X[k, j])^2 + return @tullio out[i, j] := sqrt <| (x.X[k, i] - y.X[k, j])^2 end function pairwise(::Euclidean, x::RowVecs, y::RowVecs) From adac02ce89ca58eeaecce66f63354580ff7821cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 26 Oct 2021 13:38:46 +0200 Subject: [PATCH 10/15] Apply suggestions from code review Co-authored-by: David Widmann --- src/distances/euclidean.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl index 8030620f0..61bf18321 100644 --- a/src/distances/euclidean.jl +++ b/src/distances/euclidean.jl @@ -6,13 +6,13 @@ end function pairwise(::Euclidean, x::RowVecs, y::RowVecs) return @tullio out[i, j] := - sqrt <| x.X[i, k]^2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k]^2 + sqrt <| (x.X[i, k] - y.X[j, k])^2 end function pairwise(::SqEuclidean, x::ColVecs, y::ColVecs) - return @tullio out[i, j] := x.X[k, i]^2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j]^2 + return @tullio out[i, j] := (x.X[k, i] - y.X[k, j])^2 end function pairwise(::SqEuclidean, x::RowVecs, y::RowVecs) - return @tullio out[i, j] := x.X[i, k]^2 - 2 * x.X[i, k] * y.X[j, k] + y.X[j, k]^2 + return @tullio out[i, j] := (x.X[i, k] - y.X[j, k])^2 end From 5d3cd59b026f4c4d7a3fd530e393bfac7a3ad9cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Tue, 26 Oct 2021 13:44:27 +0200 Subject: [PATCH 11/15] Update src/distances/euclidean.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/distances/euclidean.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl index 61bf18321..55c4d1ca4 100644 --- a/src/distances/euclidean.jl +++ b/src/distances/euclidean.jl @@ -5,8 +5,7 @@ function pairwise(::Euclidean, x::ColVecs, y::ColVecs) end function pairwise(::Euclidean, x::RowVecs, y::RowVecs) - return @tullio out[i, j] := - sqrt <| (x.X[i, k] - y.X[j, k])^2 + return @tullio out[i, j] := sqrt <| (x.X[i, k] - y.X[j, k])^2 end function pairwise(::SqEuclidean, x::ColVecs, y::ColVecs) From 5b36f3d010392d895c82627b362834f9a7cf9c88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 27 Oct 2021 11:41:21 +0200 Subject: [PATCH 12/15] Add simplification for dotproduct --- src/distances/dotproduct.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/distances/dotproduct.jl b/src/distances/dotproduct.jl index fb89a4bee..20027250a 100644 --- a/src/distances/dotproduct.jl +++ b/src/distances/dotproduct.jl @@ -13,7 +13,16 @@ function pairwise(::DotProduct, x::RowVecs, y::RowVecs) return @tullio out[i, j] := x.X[i, k] * y.X[j, k] end -function colwise(::DotProduct, x::RowVecs, y::RowVecs=x) +# Simplification for x == y +function colwise(::DotProduct, x::RowVecs) + return @tullio out[i] := x.X[i, k]^2 +end + +function colwise(::DotProduct, x::ColVecs) + return @tullio out[i] := x.X[k, i]^2 +end + +function colwise(::DotProduct, x::RowVecs, y::RowVecs) return @tullio out[i] := x.X[i, k] * y.X[i, k] end From e6bf11fcd66e8157a31a3cc3c75a256f423fe6ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Thu, 13 Jan 2022 19:25:05 +0100 Subject: [PATCH 13/15] Add rrules --- src/distances/dotproduct.jl | 2 +- src/distances/euclidean.jl | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/distances/dotproduct.jl b/src/distances/dotproduct.jl index 20027250a..7c8095c5e 100644 --- a/src/distances/dotproduct.jl +++ b/src/distances/dotproduct.jl @@ -26,6 +26,6 @@ function colwise(::DotProduct, x::RowVecs, y::RowVecs) return @tullio out[i] := x.X[i, k] * y.X[i, k] end -function colwise(::DotProduct, x::ColVecs, y::ColVecs=x) +function colwise(::DotProduct, x::ColVecs, y::ColVecs) return @tullio out[i] := x.X[k, i] * y.X[k, i] end diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl index 55c4d1ca4..0b7639b32 100644 --- a/src/distances/euclidean.jl +++ b/src/distances/euclidean.jl @@ -4,10 +4,30 @@ function pairwise(::Euclidean, x::ColVecs, y::ColVecs) return @tullio out[i, j] := sqrt <| (x.X[k, i] - y.X[k, j])^2 end +function ChainRulesCore.rrule(::typeof(pairwise), d::Euclidean, x::ColVecs, y::ColVecs) + D = pairwise(d, x, y) + function pairwise_pullback(Δ) + @tullio ΔX[l, k] := Δ[k, i] * (x.X[l, k] - y.X[l, i]) / D[k, i] + @tullio ΔY[l, i] := Δ[k, i] * (y.X[l, i] - x.X[l, k]) / D[k, i] + return NoTangent(), NoTangent(), Tangent{ColVecs}(;X=ΔX), Tangent{ColVecs}(;X=ΔY) + end + return D, pairwise_pullback +end + function pairwise(::Euclidean, x::RowVecs, y::RowVecs) return @tullio out[i, j] := sqrt <| (x.X[i, k] - y.X[j, k])^2 end +function ChainRulesCore.rrule(::typeof(pairwise), d::Euclidean, x::RowVecs, y::RowVecs) + D = pairwise(d, x, y) + function pairwise_pullback(Δ) + @tullio ΔX[k, l] := Δ[k, i] * (x.X[k, l] - y.X[i, l]) / D[k, i] + @tullio ΔY[i, l] := Δ[k, i] * (y.X[i, l] - x.X[k, l]) / D[k, i] + return NoTangent(), NoTangent(), Tangent{RowVecs}(;X=ΔX), Tangent{RowVecs}(;X=ΔY) + end + return D, pairwise_pullback +end + function pairwise(::SqEuclidean, x::ColVecs, y::ColVecs) return @tullio out[i, j] := (x.X[k, i] - y.X[k, j])^2 end From d7bc47f446c4280fafba5197d2a6e25878417b0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Thu, 13 Jan 2022 19:33:19 +0100 Subject: [PATCH 14/15] Solve formatting --- src/distances/euclidean.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl index 0b7639b32..5054f6e74 100644 --- a/src/distances/euclidean.jl +++ b/src/distances/euclidean.jl @@ -9,7 +9,7 @@ function ChainRulesCore.rrule(::typeof(pairwise), d::Euclidean, x::ColVecs, y::C function pairwise_pullback(Δ) @tullio ΔX[l, k] := Δ[k, i] * (x.X[l, k] - y.X[l, i]) / D[k, i] @tullio ΔY[l, i] := Δ[k, i] * (y.X[l, i] - x.X[l, k]) / D[k, i] - return NoTangent(), NoTangent(), Tangent{ColVecs}(;X=ΔX), Tangent{ColVecs}(;X=ΔY) + return NoTangent(), NoTangent(), Tangent{ColVecs}(; X=ΔX), Tangent{ColVecs}(; X=ΔY) end return D, pairwise_pullback end @@ -23,7 +23,7 @@ function ChainRulesCore.rrule(::typeof(pairwise), d::Euclidean, x::RowVecs, y::R function pairwise_pullback(Δ) @tullio ΔX[k, l] := Δ[k, i] * (x.X[k, l] - y.X[i, l]) / D[k, i] @tullio ΔY[i, l] := Δ[k, i] * (y.X[i, l] - x.X[k, l]) / D[k, i] - return NoTangent(), NoTangent(), Tangent{RowVecs}(;X=ΔX), Tangent{RowVecs}(;X=ΔY) + return NoTangent(), NoTangent(), Tangent{RowVecs}(; X=ΔX), Tangent{RowVecs}(; X=ΔY) end return D, pairwise_pullback end From 48f672757e0254d0bfd849d9dbac33a5a6241a28 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Fri, 14 Jan 2022 17:09:27 +0100 Subject: [PATCH 15/15] Add colwise for euclidean --- src/distances/euclidean.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/distances/euclidean.jl b/src/distances/euclidean.jl index 5054f6e74..b137445d3 100644 --- a/src/distances/euclidean.jl +++ b/src/distances/euclidean.jl @@ -28,6 +28,14 @@ function ChainRulesCore.rrule(::typeof(pairwise), d::Euclidean, x::RowVecs, y::R return D, pairwise_pullback end +function colwise(::Euclidean, x::ColVecs, y::ColVecs) + return @tullio out[i] := sqrt <| (x.X[k, i] - y.X[k, i])^2 +end + +function colwise(::Euclidean, x::RowVecs, y::RowVecs) + return @tullio out[i] := sqrt <| (x.X[i, k] - y.X[i, k])^2 +end + function pairwise(::SqEuclidean, x::ColVecs, y::ColVecs) return @tullio out[i, j] := (x.X[k, i] - y.X[k, j])^2 end @@ -35,3 +43,11 @@ end function pairwise(::SqEuclidean, x::RowVecs, y::RowVecs) return @tullio out[i, j] := (x.X[i, k] - y.X[j, k])^2 end + +function colwise(::SqEuclidean, x::ColVecs, y::ColVecs) + return @tullio out[i] := (x.X[k, i] - y.X[k, i])^2 +end + +function colwise(::SqEuclidean, x::RowVecs, y::RowVecs) + return @tullio out[i] := (x.X[i, k] - y.X[i, k])^2 +end \ No newline at end of file