@@ -739,20 +739,31 @@ impl<'tcx> CodegenCx<'tcx> {
739
739
. decorate ( var_id. unwrap ( ) , Decoration :: Invariant , std:: iter:: empty ( ) ) ;
740
740
}
741
741
if let Some ( per_primitive_ext) = attrs. per_primitive_ext {
742
- if storage_class != Ok ( StorageClass :: Output ) {
743
- self . tcx . dcx ( ) . span_fatal (
744
- per_primitive_ext. span ,
745
- "`#[spirv(per_primitive_ext)]` is only valid on Output variables" ,
746
- ) ;
747
- }
748
- if !( execution_model == ExecutionModel :: MeshEXT
749
- || execution_model == ExecutionModel :: MeshNV )
750
- {
751
- self . tcx . dcx ( ) . span_fatal (
752
- per_primitive_ext. span ,
753
- "`#[spirv(per_primitive_ext)]` is only valid in mesh shaders" ,
754
- ) ;
742
+ match execution_model {
743
+ ExecutionModel :: Fragment => {
744
+ if storage_class != Ok ( StorageClass :: Input ) {
745
+ self . tcx . dcx ( ) . span_fatal (
746
+ per_primitive_ext. span ,
747
+ "`#[spirv(per_primitive_ext)]` in fragment shaders is only valid on Input variables" ,
748
+ ) ;
749
+ }
750
+ }
751
+ ExecutionModel :: MeshNV | ExecutionModel :: MeshEXT => {
752
+ if storage_class != Ok ( StorageClass :: Output ) {
753
+ self . tcx . dcx ( ) . span_fatal (
754
+ per_primitive_ext. span ,
755
+ "`#[spirv(per_primitive_ext)]` in mesh shaders is only valid on Output variables" ,
756
+ ) ;
757
+ }
758
+ }
759
+ _ => {
760
+ self . tcx . dcx ( ) . span_fatal (
761
+ per_primitive_ext. span ,
762
+ "`#[spirv(per_primitive_ext)]` is only valid in fragment or mesh shaders" ,
763
+ ) ;
764
+ }
755
765
}
766
+
756
767
self . emit_global ( ) . decorate (
757
768
var_id. unwrap ( ) ,
758
769
Decoration :: PerPrimitiveEXT ,
0 commit comments