FluxPipeline produces noise when .enable_vae_slicing() is used, and FluxImg2ImgPipeline does not support .enable_vae_slicing(). #11540

Open

Description

@Meatfucker

Describe the bug

With FluxPipeline, enabling VAE slicing via .enable_vae_slicing() causes generation to produce pure noise instead of images. With FluxImg2ImgPipeline, the method does not exist at all, so VAE slicing cannot be enabled.
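
A minimal sketch that isolates the FluxPipeline half of the problem, without the quantization used in the full reproduction below (the checkpoint is the standard FLUX.1-dev repo; the prompt and output path are illustrative):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_vae_slicing()  # after this call, decoded images come back as pure noise
pipe.to("cuda")
image = pipe("a photo of a cat", num_inference_steps=30).images[0]
image.save("noise.png")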

Reproduction

from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL, FluxTransformer2DModel, FluxPipeline, FluxImg2ImgPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from optimum.quanto import freeze, qfloat8, quantize
import torch
import gc


async def generate_flux(prompt,
                        width,
                        height,
                        steps,
                        batch_size,
                        image=None,
                        strength=None,
                        model_name=None,
                        lora_name=None):
    width = width if width is not None else 1024
    height = height if height is not None else 1024
    steps = steps if steps is not None else 30
    batch_size = batch_size if batch_size is not None else 4
    strength = strength if strength is not None else 0.7
    revision = None  # avoids a NameError when a custom model_name is passed
    if model_name is None:
        model_name = "black-forest-labs/FLUX.1-dev"
        revision = "refs/pr/3"
    dtype = torch.bfloat16
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler", revision=revision)
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder_2 = T5EncoderModel.from_pretrained(model_name, subfolder="text_encoder_2", torch_dtype=dtype,
                                                    revision=revision)
    tokenizer_2 = T5TokenizerFast.from_pretrained(model_name, subfolder="tokenizer_2", revision=revision)
    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=dtype, revision=revision)
    transformer = FluxTransformer2DModel.from_pretrained(model_name, subfolder="transformer", torch_dtype=dtype, revision=revision)

    # Choose img2img vs. text-to-image depending on whether an init image was supplied
    if image is not None:
        generator = FluxImg2ImgPipeline(scheduler=scheduler,
                                        text_encoder=text_encoder,
                                        tokenizer=tokenizer,
                                        text_encoder_2=text_encoder_2,
                                        tokenizer_2=tokenizer_2,
                                        vae=vae,
                                        transformer=transformer)
        generator.enable_vae_slicing()  # fails: FluxImg2ImgPipeline has no such method (AttributeError)
    else:
        generator = FluxPipeline(scheduler=scheduler,
                                 text_encoder=text_encoder,
                                 tokenizer=tokenizer,
                                 text_encoder_2=text_encoder_2,
                                 tokenizer_2=tokenizer_2,
                                 vae=vae,
                                 transformer=transformer)
        generator.enable_vae_slicing()  # succeeds, but generation then produces pure noise
    if lora_name is not None:
        try:
            generator.load_lora_weights(f"loras/flux/{lora_name}", weight_name=lora_name)
        except Exception as e:
            print(f"FLUX LORA ERROR: {e}")
    # Quantize the two largest components to fp8 to fit in VRAM
    quantize(transformer, weights=qfloat8)
    freeze(transformer)
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)
    generator.to("cuda")
    generator.set_progress_bar_config(disable=True)
    if image is not None:
        images = generator(prompt=prompt,
                           image=image,
                           width=width, height=height,
                           num_inference_steps=steps,
                           strength=strength,
                           num_images_per_prompt=batch_size).images
    else:
        images = generator(prompt=prompt,
                           width=width, height=height,
                           num_inference_steps=steps,
                           num_images_per_prompt=batch_size).images
    generator.to("cpu")

    del generator, scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, transformer
    gc.collect()
    torch.cuda.empty_cache()
    return images
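
For context, a hypothetical invocation of the function above (prompt and output path are illustrative):

import asyncio

images = asyncio.run(generate_flux("a lighthouse at dusk", 1024, 1024, 30, 1))
images[0].save("flux_out.png")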

Logs

The FluxPipeline run produces no errors (it simply returns noise); the FluxImg2ImgPipeline run fails because it cannot find .enable_vae_slicing (AttributeError).
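
As a possible workaround for the missing method, slicing can be enabled on the VAE component directly, since AutoencoderKL itself exposes enable_slicing(). Whether this also sidesteps the noise problem is untested here; this is a sketch, not a verified fix:

# Hypothetical workaround: bypass the missing pipeline wrapper and enable
# slicing on the underlying AutoencoderKL directly.
generator.vae.enable_slicing()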

System Info

accelerate==1.4.0
aiofiles==24.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.18
aiosignal==1.3.2
annotated-types==0.7.0
anyio==4.8.0
attrs==25.3.0
backoff==2.2.1
beautifulsoup4==4.13.4
bitsandbytes==0.45.3
certifi==2025.1.31
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.1
click==8.1.8
contourpy==1.3.2
cryptography==44.0.2
cycler==0.12.1
dataclasses-json==0.6.7
diffusers==0.33.1
emoji==2.14.1
eval_type_backport==0.2.2
faiss-cpu==1.10.0
fastapi==0.115.11
filelock==3.17.0
filetype==1.2.0
fonttools==4.58.0
frozenlist==1.6.0
fsspec==2025.2.0
ftfy==6.3.1
greenlet==3.2.1
h11==0.14.0
html5lib==1.1
httpcore==1.0.7
httpx==0.28.1
httpx-sse==0.4.0
huggingface-hub==0.29.2
idna==3.10
imageio==2.37.0
imageio-ffmpeg==0.6.0
importlib_metadata==8.6.1
Jinja2==3.1.6
joblib==1.4.2
jsonpatch==1.33
jsonpointer==3.0.0
kiwisolver==1.4.8
langchain-core==0.3.56
langchain-text-splitters==0.3.8
langdetect==1.0.9
langsmith==0.3.34
loguru==0.7.3
lxml==5.4.0
MarkupSafe==3.0.2
marshmallow==3.26.1
matplotlib==3.10.3
mpmath==1.3.0
multidict==6.4.3
mypy_extensions==1.1.0
nest-asyncio==1.6.0
networkx==3.4.2
ninja==1.11.1.3
nltk==3.9.1
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
olefile==0.47
opencv-python==4.11.0.86
optimum==1.24.0
optimum-quanto==0.2.7
orjson==3.10.16
packaging==24.2
peft==0.14.0
pillow==11.1.0
propcache==0.3.1
protobuf==6.30.0
psutil==7.0.0
pycparser==2.22
pydantic==2.11.3
pydantic-settings==2.9.1
pydantic_core==2.33.1
pyparsing==3.2.3
pypdf==5.4.0
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-iso639==2025.2.18
python-magic==0.4.27
python-oxmsg==0.0.2
PyYAML==6.0.2
RapidFuzz==3.13.0
regex==2024.11.6
requests==2.32.3
requests-toolbelt==1.0.0
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.2
sentence-transformers==4.1.0
sentencepiece==0.2.0
setuptools==75.8.2
six==1.17.0
sniffio==1.3.1
soupsieve==2.7
SQLAlchemy==2.0.40
starlette==0.46.0
sympy==1.13.1
tenacity==9.1.2
threadpoolctl==3.6.0
timm==1.0.15
tokenizers==0.21.0
torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
tqdm==4.67.1
transformers==4.50.2
triton==3.2.0
typing-inspect==0.9.0
typing-inspection==0.4.0
typing_extensions==4.12.2
unstructured==0.17.2
unstructured-client==0.34.0
urllib3==2.3.0
uvicorn==0.34.0
wcwidth==0.2.13
webencodings==0.5.1
wrapt==1.17.2
yarl==1.20.0
zipp==3.21.0
zstandard==0.23.0

Who can help?

No response
