From 736f9a0e546a624363f4b8d8cac522b210c53ccf Mon Sep 17 00:00:00 2001 From: pitx-perf <> Date: Thu, 27 Feb 2025 12:58:02 +0100 Subject: [PATCH] Feature: give the opportunity to explicitly pass the list of arguments types --- src/register.jl | 36 +++++++++++++++++++++++++++++------- src/wrapper-types.jl | 21 ++++++++++++++++----- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/register.jl b/src/register.jl index 8d1c01dc1..963fdd446 100644 --- a/src/register.jl +++ b/src/register.jl @@ -23,9 +23,11 @@ overwriting. ``` See `@register_array_symbolic` to register functions which return arrays. """ -macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays = true) +macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays = true, args_list = []) f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, Ts) + args_list = parse_args_list(args_list) + args′ = map((a, T) -> :($a::$T), argnames, Ts) ret_type = isnothing(ret_type) ? Real : ret_type @@ -42,7 +44,7 @@ macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays = else return $wrap(res) end - end $wrap_arrays) + end $wrap_arrays $args_list) if define_promotion fexpr = :($fexpr; (::$typeof($promote_symtype))(::$ftype, args...) = $ret_type) @@ -96,7 +98,7 @@ symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: SymbolicUtils.Symbolic symbolic_eltype(::AbstractArray{Num}) = Real symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: Arr{eT}} = eT -function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :(), define_promotion = true, wrap_arrays = true) +function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :(), define_promotion = true, wrap_arrays = true, args_list = []) def_assignments = MacroTools.rmlines(partial_defs).args defs = map(def_assignments) do ex @assert ex.head == :(=) @@ -123,7 +125,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs else return $wrap(res) end - end $wrap_arrays + end $wrap_arrays $args_list end |> esc if define_promotion @@ -161,6 +163,23 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs return fexpr end +function parse_args_list(args_list) + if isa(args_list, Expr) && args_list.head == :vect + args_list = args_list.args + + # for each element of args_list, convert exp to a tuple + args_list = map(args_list) do exp + if exp isa Expr && exp.head == :tuple + exp.args + else + (exp,) + end + end + end + + return args_list +end + """ @register_array_symbolic(expr, define_promotion = true) @@ -193,7 +212,10 @@ overloads for one function, all the rest of the registers must set `define_promotion` to `false` except for the first one, to avoid method overwriting. """ -macro register_array_symbolic(expr, block, define_promotion = true, wrap_arrays = true) +macro register_array_symbolic(expr, block, define_promotion = true, wrap_arrays = true, args_list = []) f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, :([])) - register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion, wrap_arrays) -end + + args_list = parse_args_list(args_list) + + register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion, wrap_arrays, args_list) +end \ No newline at end of file diff --git a/src/wrapper-types.jl b/src/wrapper-types.jl index fd575b66c..8abd9b98e 100644 --- a/src/wrapper-types.jl +++ b/src/wrapper-types.jl @@ -55,7 +55,7 @@ function wraps_type end has_symwrapper(::Type) = false is_wrapper_type(::Type) = false -function wrap_func_expr(mod, expr, wrap_arrays = true) +function wrap_func_expr(mod, expr, wrap_arrays = true, args_list = []) @assert expr.head == :function || (expr.head == :(=) && expr.args[1] isa Expr && expr.args[1].head == :call) @@ -142,8 +142,19 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) impl = :(function $impl_name($self, $(names...)) $body end) - # TODO: maybe don't drop first lol - methods = map(Iterators.drop(Iterators.product(types...), 1)) do Ts + + if isempty(args_list) + # TODO: maybe don't drop first lol + it = Iterators.drop(Iterators.product(types...), 1) + else + # sanity check + if length(args_list[1]) != length(names) + error("args_list must have the same length as the number of arguments") + end + it = args_list + end + + methods = map(it) do Ts method_args = map(names, Ts) do n, T :($n::$T) end @@ -171,6 +182,6 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) end |> esc end -macro wrapped(expr, wrap_arrays = true) - wrap_func_expr(__module__, expr, wrap_arrays) +macro wrapped(expr, wrap_arrays = true, args_list = []) + wrap_func_expr(__module__, expr, wrap_arrays, args_list) end