Skip to content

Feature: give the opportunity to explicitly pass the list of arguments to register #1455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ overwriting.
```
See `@register_array_symbolic` to register functions which return arrays.
"""
macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays = true)
macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays = true, args_list = [])
f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, Ts)

args_list = parse_args_list(args_list)

args′ = map((a, T) -> :($a::$T), argnames, Ts)
ret_type = isnothing(ret_type) ? Real : ret_type

Expand All @@ -42,7 +44,7 @@ macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays =
else
return $wrap(res)
end
end $wrap_arrays)
end $wrap_arrays $args_list)

if define_promotion
fexpr = :($fexpr; (::$typeof($promote_symtype))(::$ftype, args...) = $ret_type)
Expand Down Expand Up @@ -96,7 +98,7 @@ symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: SymbolicUtils.Symbolic
symbolic_eltype(::AbstractArray{Num}) = Real
symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: Arr{eT}} = eT

function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :(), define_promotion = true, wrap_arrays = true)
function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :(), define_promotion = true, wrap_arrays = true, args_list = [])
def_assignments = MacroTools.rmlines(partial_defs).args
defs = map(def_assignments) do ex
@assert ex.head == :(=)
Expand All @@ -123,7 +125,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
else
return $wrap(res)
end
end $wrap_arrays
end $wrap_arrays $args_list
end |> esc

if define_promotion
Expand Down Expand Up @@ -161,6 +163,23 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
return fexpr
end

function parse_args_list(args_list)
if isa(args_list, Expr) && args_list.head == :vect
args_list = args_list.args

# for each element of args_list, convert exp to a tuple
args_list = map(args_list) do exp
if exp isa Expr && exp.head == :tuple
exp.args
else
(exp,)
end
end
end

return args_list
end

"""
@register_array_symbolic(expr, define_promotion = true)

Expand Down Expand Up @@ -193,7 +212,10 @@ overloads for one function, all the rest of the registers must set
`define_promotion` to `false` except for the first one, to avoid method
overwriting.
"""
macro register_array_symbolic(expr, block, define_promotion = true, wrap_arrays = true)
macro register_array_symbolic(expr, block, define_promotion = true, wrap_arrays = true, args_list = [])
f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, :([]))
register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion, wrap_arrays)
end

args_list = parse_args_list(args_list)

register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion, wrap_arrays, args_list)
end
21 changes: 16 additions & 5 deletions src/wrapper-types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function wraps_type end
has_symwrapper(::Type) = false
is_wrapper_type(::Type) = false

function wrap_func_expr(mod, expr, wrap_arrays = true)
function wrap_func_expr(mod, expr, wrap_arrays = true, args_list = [])
@assert expr.head == :function || (expr.head == :(=) &&
expr.args[1] isa Expr &&
expr.args[1].head == :call)
Expand Down Expand Up @@ -142,8 +142,19 @@ function wrap_func_expr(mod, expr, wrap_arrays = true)
impl = :(function $impl_name($self, $(names...))
$body
end)
# TODO: maybe don't drop first lol
methods = map(Iterators.drop(Iterators.product(types...), 1)) do Ts

if isempty(args_list)
# TODO: maybe don't drop first lol
it = Iterators.drop(Iterators.product(types...), 1)
else
# sanity check
if length(args_list[1]) != length(names)
error("args_list must have the same length as the number of arguments")
end
it = args_list
end

methods = map(it) do Ts
method_args = map(names, Ts) do n, T
:($n::$T)
end
Expand Down Expand Up @@ -171,6 +182,6 @@ function wrap_func_expr(mod, expr, wrap_arrays = true)
end |> esc
end

macro wrapped(expr, wrap_arrays = true)
wrap_func_expr(__module__, expr, wrap_arrays)
macro wrapped(expr, wrap_arrays = true, args_list = [])
wrap_func_expr(__module__, expr, wrap_arrays, args_list)
end