Skip to content

Improved Scan #855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
d4e3738
intial changes
keptsecret Mar 18, 2025
10d9c39
subgroup2 implementations
keptsecret Mar 27, 2025
f2a281c
some fixes, example
keptsecret Mar 27, 2025
4622f1f
changed template parameters
keptsecret Mar 28, 2025
abfaf67
working subgroup2 template and funcs
keptsecret Mar 31, 2025
f2d6d8a
fix reduction bug
keptsecret Mar 31, 2025
eeec20a
minor fix
keptsecret Apr 1, 2025
53ffc60
latest example
keptsecret Apr 2, 2025
0efeb8d
merge master, fix conflicts
keptsecret Apr 7, 2025
1478837
new example number
keptsecret Apr 7, 2025
e88f51a
partial spec for items per invoc =1
keptsecret Apr 9, 2025
a8e02a3
changes to Params, Config handling types
keptsecret Apr 10, 2025
237ac09
rework specializations for native, emulated funcs
keptsecret Apr 10, 2025
859c313
added OpSelect intrinsic for mix, fix mix behavior with bool
keptsecret Apr 4, 2025
c5a3223
use mix instead of ternary op
keptsecret Apr 11, 2025
87bca2b
fixes to subgroup2 funcs
keptsecret Apr 11, 2025
49fd605
changes to handle coalesced data loads
keptsecret Apr 21, 2025
4ae51a1
merge master, fix example conflicts
keptsecret Apr 21, 2025
609ad85
fixes to inclusive_scan for coalesced
keptsecret Apr 21, 2025
6b692f4
removed redundant code
keptsecret Apr 21, 2025
d0acb31
enabled handling vectors in spirv group ops with templates and enable_if
keptsecret Apr 23, 2025
fc92538
added impl component wise inclusive scan for inclusive scan
keptsecret Apr 23, 2025
8ad4843
revert to scans using consecutive data loads
keptsecret Apr 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
#include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
#include "nbl/builtin/hlsl/concepts.hlsl"


namespace nbl
Expand Down
47 changes: 47 additions & 0 deletions include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_


#include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"

#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl"
#include "nbl/builtin/hlsl/concepts.hlsl"


namespace nbl
{
namespace hlsl
{
namespace subgroup2
{

// Compile-time parameter bundle consumed by the subgroup2 reduction / scan functors.
//   Config                      - must satisfy `is_configuration_v` (enforced by the requires clause)
//   BinOp                       - binary operator; its `type_t` is taken as the scalar element type
//   _ItemsPerInvocation         - number of elements held per invocation (width of `type_t`)
//   device_capabilities         - forwarded to `device_capabilities_traits` to query `shaderSubgroupArithmetic`
//   OverrideUseNativeIntrinsics - pass `false` to always use the emulated versions
template<typename Config, class BinOp, int32_t _ItemsPerInvocation=1, class device_capabilities=void, bool OverrideUseNativeIntrinsics=true NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
struct ArithmeticParams
{
    using config_t = Config;
    using binop_t = BinOp;
    // BinOp is expected to be instantiated with a scalar type
    using scalar_t = typename BinOp::type_t;
    // always a vector, even when _ItemsPerInvocation is 1
    // (a scalar-or-vector `conditional_t` was considered here but is not in use)
    using type_t = vector<scalar_t, _ItemsPerInvocation>;

    // true  -> tries to use the native SPIR-V intrinsics (needs `shaderSubgroupArithmetic` from the device traits)
    // false -> will always use the emulated versions
    // TODO: extend with a heuristic for when the native path is actually faster
    NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = OverrideUseNativeIntrinsics && device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic;
    NBL_CONSTEXPR_STATIC_INLINE int32_t ItemsPerInvocation = _ItemsPerInvocation;
};

// Public reduction functor: picks the impl from Params::ItemsPerInvocation and Params::UseNativeIntrinsics.
template<typename Params>
struct reduction
    : impl::reduction<Params, Params::ItemsPerInvocation, Params::UseNativeIntrinsics>
{
};
// Public inclusive-scan functor: picks the impl from Params::ItemsPerInvocation and Params::UseNativeIntrinsics.
template<typename Params>
struct inclusive_scan
    : impl::inclusive_scan<Params, Params::ItemsPerInvocation, Params::UseNativeIntrinsics>
{
};
// Public exclusive-scan functor: picks the impl from Params::ItemsPerInvocation and Params::UseNativeIntrinsics.
template<typename Params>
struct exclusive_scan
    : impl::exclusive_scan<Params, Params::ItemsPerInvocation, Params::UseNativeIntrinsics>
{
};

}
}
}

#endif
183 changes: 183 additions & 0 deletions include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_

// #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
// #include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"

// #include "nbl/builtin/hlsl/subgroup/ballot.hlsl"

// #include "nbl/builtin/hlsl/functional.hlsl"

#include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"

namespace nbl
{
namespace hlsl
{
namespace subgroup2
{

namespace impl
{

// Inclusive scan where every invocation holds `ItemsPerInvocation` consecutive elements.
//
// Steps:
//  1. scan the invocation's own elements locally,
//  2. run a scalar subgroup exclusive scan over each invocation's local total
//     (the last local component) to obtain the prefix contributed by lower invocations,
//  3. combine that prefix into every local component.
//
// `native` is forwarded to `subgroup::impl::exclusive_scan` to choose between the
// native and the emulated scalar scan.
template<class Params, uint32_t ItemsPerInvocation, bool native>
struct inclusive_scan
{
    using type_t = typename Params::type_t;
    using scalar_t = typename Params::scalar_t;
    using binop_t = typename Params::binop_t;
    using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;

    // value:   this invocation's `ItemsPerInvocation` input elements
    // returns: the inclusive scan results for those same element slots
    type_t operator()(NBL_CONST_REF_ARG(type_t) value)
    {
        binop_t binop;

        // 1. local inclusive scan over the components
        type_t retval;
        retval[0] = value[0];
        [unroll]
        for (uint32_t i = 1; i < ItemsPerInvocation; i++)
            retval[i] = binop(retval[i-1], value[i]);

        // 2. prefix of all lower invocations' totals
        exclusive_scan_op_t op;
        const scalar_t exclusive = op(retval[ItemsPerInvocation-1]);

        // 3. the prefix covers elements that come BEFORE the local ones, so it is the
        //    left operand; this keeps the result right for order-sensitive operators
        //    and matches the operand order of the local loop above
        [unroll]
        for (uint32_t i = 0; i < ItemsPerInvocation; i++)
            retval[i] = binop(exclusive, retval[i]);
        return retval;
    }
};

// Exclusive scan where every invocation holds `ItemsPerInvocation` consecutive elements.
//
// Runs the inclusive scan first, then shifts every result one element to the right:
// component i takes the inclusive result of component i-1, and component 0 takes the
// last inclusive component of the previous invocation (`binop_t::identity` on invocation 0).
//
// NOTE(review): reviewer feedback attached to this block, kept for follow-up:
//  - with a coalesced (non-consecutive) data layout the shift would instead be a plain
//    subgroup shuffle of the whole vector (cyclic/modulo rather than relative), followed
//    by a conditional set of the first element;
//  - that modulo shuffle could be replaced by the SPV_KHR_subgroup_rotate intrinsic once
//    device_capabilities_traits exposes it;
//  - prefer mix(T,T,bool) over the ternary operator, because HLSL short-circuits
//    ternaries and turns them into branches.
template<class Params, uint32_t ItemsPerInvocation, bool native>
struct exclusive_scan
{
    using type_t = typename Params::type_t;
    using scalar_t = typename Params::scalar_t;
    using binop_t = typename Params::binop_t;
    using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params, ItemsPerInvocation, native>;

    // value:   this invocation's `ItemsPerInvocation` input elements
    // returns: the exclusive scan results for those same element slots
    type_t operator()(type_t value)
    {
        inclusive_scan_op_t op;
        value = op(value);

        // Only the last component of the previous invocation is ever consumed, so shuffle
        // just that scalar instead of the entire vector.
        const scalar_t left = glsl::subgroupShuffleUp<scalar_t>(value[ItemsPerInvocation-1],1);

        type_t retval;
        // invocation 0 has no predecessor, its first element starts from the identity
        retval[0] = bool(glsl::gl_SubgroupInvocationID()) ? left : binop_t::identity;
        [unroll]
        for (uint32_t i = 1; i < ItemsPerInvocation; i++)
            retval[i] = value[i-1];
        return retval;
    }
};

// Reduction over `ItemsPerInvocation` elements per invocation: folds the invocation's own
// components with the binary operator, then hands the scalar partial result to
// `subgroup::impl::reduction` (native or emulated, selected by `native`).
template<class Params, uint32_t ItemsPerInvocation, bool native>
struct reduction
{
    using type_t = typename Params::type_t;
    using scalar_t = typename Params::scalar_t;
    using binop_t = typename Params::binop_t;
    using op_t = subgroup::impl::reduction<binop_t, native>;

    // value:   this invocation's `ItemsPerInvocation` input elements
    // returns: the scalar produced by the subgroup reduction of the per-invocation partials
    scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
    {
        binop_t combine;

        // fold the local components left to right
        scalar_t localPartial = value[0];
        [unroll]
        for (uint32_t item = 1; item < ItemsPerInvocation; ++item)
            localPartial = combine(localPartial, value[item]);

        op_t subgroupReduce;
        return subgroupReduce(localPartial);
    }
};


// specs for N=1 uses subgroup funcs
// specialize native
// #define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<typename T> struct NAME<BINOP<T>,true> \
// { \
// using type_t = T; \
// \
// type_t operator()(NBL_CONST_REF_ARG(type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
// }

// #define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE(reduction,BINOP,SUBGROUP_OP); \
// SPECIALIZE(inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
// SPECIALIZE(exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);

// SPECIALIZE_ALL(bit_and,And);
// SPECIALIZE_ALL(bit_or,Or);
// SPECIALIZE_ALL(bit_xor,Xor);

// SPECIALIZE_ALL(plus,Add);
// SPECIALIZE_ALL(multiplies,Mul);

// SPECIALIZE_ALL(minimum,Min);
// SPECIALIZE_ALL(maximum,Max);

// #undef SPECIALIZE_ALL
// #undef SPECIALIZE

// specialize portability
// Single item per invocation: no local pass is required, so the call is forwarded
// unchanged to the scalar `subgroup::impl::inclusive_scan`.
template<class Params, bool native>
struct inclusive_scan<Params, 1, native>
{
    using type_t = typename Params::type_t;
    using scalar_t = typename Params::scalar_t;
    using binop_t = typename Params::binop_t;
    using op_t = subgroup::impl::inclusive_scan<binop_t, native>;
    // expectation: type_t is effectively the scalar type and matches binop_t's type

    type_t operator()(NBL_CONST_REF_ARG(type_t) value)
    {
        op_t subgroupScan;
        return subgroupScan(value);
    }
};

// Single item per invocation: no local pass is required, so the call is forwarded
// unchanged to the scalar `subgroup::impl::exclusive_scan`.
template<class Params, bool native>
struct exclusive_scan<Params, 1, native>
{
    using type_t = typename Params::type_t;
    using scalar_t = typename Params::scalar_t;
    using binop_t = typename Params::binop_t;
    using op_t = subgroup::impl::exclusive_scan<binop_t, native>;

    type_t operator()(NBL_CONST_REF_ARG(type_t) value)
    {
        op_t subgroupScan;
        return subgroupScan(value);
    }
};

// Single item per invocation: nothing to fold locally, so the call is forwarded
// unchanged to the scalar `subgroup::impl::reduction`.
template<class Params, bool native>
struct reduction<Params, 1, native>
{
    using type_t = typename Params::type_t;
    using scalar_t = typename Params::scalar_t;
    using binop_t = typename Params::binop_t;
    using op_t = subgroup::impl::reduction<binop_t, native>;

    scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
    {
        op_t subgroupReduce;
        return subgroupReduce(value);
    }
};

}

}
}
}

#endif
36 changes: 36 additions & 0 deletions include/nbl/builtin/hlsl/subgroup2/ballot.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_BALLOT_INCLUDED_
#define _NBL_BUILTIN_HLSL_SUBGROUP2_BALLOT_INCLUDED_

namespace nbl
{
namespace hlsl
{
namespace subgroup2
{

// Compile-time description of a subgroup whose size is 2^SubgroupSizeLog2.
template<uint32_t SubgroupSizeLog2>
struct Configuration
{
    NBL_CONSTEXPR_STATIC_INLINE uint16_t SizeLog2 = uint16_t(SubgroupSizeLog2);
    NBL_CONSTEXPR_STATIC_INLINE uint16_t Size = uint16_t(0x1u) << SubgroupSizeLog2;

    // uint32_t1 while SizeLog2 is below 6, uint32_t2 at exactly 6, uint32_t4 from 7 upwards
    using mask_t = conditional_t<
        (SubgroupSizeLog2 < 6),
        uint32_t1,
        conditional_t<(SubgroupSizeLog2 < 7), uint32_t2, uint32_t4>
    >;
};

// Type trait: true only for instantiations of `Configuration`.
template<class Candidate>
struct is_configuration
    : bool_constant<false>
{
};

// match: any `Configuration<SizeLog2>`
template<uint32_t SizeLog2>
struct is_configuration<Configuration<SizeLog2> >
    : bool_constant<true>
{
};

// convenience variable template for `is_configuration<T>::value`
template<typename Candidate>
NBL_CONSTEXPR bool is_configuration_v = is_configuration<Candidate>::value;

}
}
}

#endif