FLUX (by Black Forest Labs) has taken the world of AI image generation by storm in the last few months. Not only has it beaten Stable Diffusion (the prior open-source king) on many benchmarks, it has also surpassed proprietary models like DALL-E and Midjourney on some metrics.
But how would you go about using FLUX in one of your apps? You might think of using serverless hosts like Replicate, but these can get very expensive very quickly and may not provide the flexibility you need. That’s where creating your own FLUX server comes in handy.
In this article, we’ll walk you through creating your own FLUX server using Python. This server will allow you to generate images based on text prompts via a simple API. Whether you’re running this server for personal use or deploying it as part of a production application, this guide will help you get started.
Prerequisites
Before diving into the code, let’s ensure you have the necessary tools and libraries set up:
- Python: You’ll need Python 3 installed on your machine, preferably version 3.10.
- torch: The deep learning framework we’ll use to run FLUX.
- diffusers: Provides access to the FLUX model.
- transformers: Required dependency of diffusers.
- sentencepiece: Required to run the FLUX tokenizer.
- protobuf: Required to run FLUX.
- accelerate: Helps load the FLUX model more efficiently in some cases.
- fastapi: Framework to create a web server that can accept image generation requests.
- uvicorn: Required to run the fastapi server.
- psutil: Allows us to check how much RAM there is on our machine.
You can install all the libraries by running the following command:
pip install torch diffusers transformers sentencepiece protobuf accelerate fastapi uvicorn psutil
Note for macOS users: If you’re using a Mac with an Apple silicon chip (such as the M1 or M2), you should set up PyTorch with Metal for optimal performance. Follow the official PyTorch with Metal guide before proceeding.
Step 1: Setting Up the Environment
Let’s start the script by picking the right device to run inference based on the hardware we’re using.
import torch

device = 'cuda'  # can also be 'cpu' or 'mps'

# Fail early if the requested device is not available
if device == 'mps' and not torch.backends.mps.is_available():
    raise Exception("Device set to MPS, but MPS is not available")
elif device == 'cuda' and not torch.cuda.is_available():
    raise Exception("Device set to CUDA, but CUDA is not available")
You can specify cpu, cuda (for NVIDIA GPUs), or mps (for Apple’s Metal Performance Shaders). The script then checks if the selected device is available and raises an exception if it’s not.
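If you would rather not hard-code the device, the choice can be derived from what the machine reports. The following is a minimal sketch and not part of the original script: choose_device is a hypothetical helper, and the preference order (CUDA, then MPS, then CPU) is an assumption you may want to change.

```python
def choose_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name given what the machine supports."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# In the server script this would be called with PyTorch's own checks:
# device = choose_device(torch.cuda.is_available(),
#                        torch.backends.mps.is_available())
print(choose_device(cuda_available=False, mps_available=True))  # prints "mps"
```

Keeping the decision in a plain function that takes booleans makes it easy to test without a GPU.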
Step 2: Loading the FLUX Model
Next, we load the FLUX model. We’ll load the model in fp16 precision, which will save us some memory without much loss in quality.
Note: At this point, you may be asked to authenticate with HuggingFace, as the FLUX model is gated. In order to authenticate successfully, you’ll need to create a HuggingFace account, go to the model page, accept the terms, and then create a HuggingFace token from your account settings and add it on your machine as the HF_TOKEN environment variable.
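Because a missing token only shows up once the download starts, it can help to check for it up front. This is a small sketch under the assumption that you authenticate through the HF_TOKEN environment variable described above; hf_token_is_set is a hypothetical helper, not something provided by diffusers.

```python
import os


def hf_token_is_set(env=os.environ) -> bool:
    """Return True if HF_TOKEN is present and not blank."""
    return bool(env.get("HF_TOKEN", "").strip())


if not hf_token_is_set():
    # Only a warning: you may have authenticated in some other way
    print("HF_TOKEN is not set; loading a gated model may fail")
```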
from diffusers import DDIMScheduler, FluxPipeline
import psutil

model_name = "black-forest-labs/FLUX.1-dev"
print(f"Loading {model_name} on {device}")

# Load the pipeline weights in fp16 to reduce memory use
pipeline = FluxPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    use_safetensors=True
).to(device)

# Replace the pipeline's default scheduler with DDIM
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
Here, we’re loading the FLUX model using the diffusers library. The model we’re using is black-forest-labs/FLUX.1-dev, loaded in fp16 precision. There is also a FLUX Pro model, which is stronger, but it is not open-source, so it cannot be used here.
We’ll use the DDIM scheduler here, but you may also choose another one like Euler or UniPC. You can read more on schedulers here.
Since image generation can be resource-intensive, it’s crucial to optimize memory usage, especially when running on a CPU or a device with limited memory.
# Recommended if running on MPS or CPU with less than 64 GB of RAM
total_memory = psutil.virtual_memory().total
total_memory_gb = total_memory / (1024 ** 3)
if (device == 'cpu' or device == 'mps') and total_memory_gb < 64:
    print("Enabling attention slicing")
    pipeline.enable_attention_slicing()
This code checks the total system memory and, when running on CPU or MPS, enables attention slicing if the system has less than 64 GB of RAM. Attention slicing reduces memory usage during image generation, which is essential for devices with limited resources.
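The decision itself is simple arithmetic: convert bytes to GiB and compare against the threshold. The sketch below isolates that logic so it can be checked without loading the model; should_slice_attention and its threshold_gb parameter are hypothetical names introduced for illustration.

```python
GIB = 1024 ** 3


def should_slice_attention(device: str, total_memory_bytes: int,
                           threshold_gb: float = 64) -> bool:
    """Mirror the check above: only CPU/MPS machines below the threshold."""
    total_memory_gb = total_memory_bytes / GIB
    return device in ("cpu", "mps") and total_memory_gb < threshold_gb


print(should_slice_attention("mps", 32 * GIB))   # True
print(should_slice_attention("mps", 64 * GIB))   # False
print(should_slice_attention("cuda", 16 * GIB))  # False
```

In the server script, total_memory_bytes would come from psutil.virtual_memory().total.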
Step 3: Creating the API with FastAPI
Next, we’ll set up the FastAPI server, which will provide an API to generate images.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, conint, confloat
from fastapi.middleware.gzip import GZipMiddleware
from io import BytesIO
import base64
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=7)
FastAPI is a popular framework for building web APIs with Python. In this case, we’re using it to create a server that can accept requests for image generation. We’re also using GZip middleware to compress the response, which is particularly useful when sending images back in base64 format.
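To see why compression helps with base64 payloads, it is worth looking at the sizes involved. The sketch below uses random bytes as a stand-in for PNG data (an assumption: like a PNG, random data is already incompressible), encodes it as base64, and gzips the result at the same compression level used above.

```python
import base64
import gzip
import random

# Stand-in for PNG bytes; seeded so the run is repeatable
rng = random.Random(0)
raw = bytes(rng.getrandbits(8) for _ in range(30_000))

encoded = base64.b64encode(raw)
compressed = gzip.compress(encoded, compresslevel=7)

print(len(raw))      # 30000
print(len(encoded))  # 40000: base64 emits 4 characters per 3 input bytes
print(len(compressed) < len(encoded))  # True
```

Base64 inflates the payload by a third, and gzip can win most of that back because each base64 character carries only 6 bits of information. Note that the middleware only compresses responses for clients that send an Accept-Encoding header that includes gzip.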
Note: In a production environment, you might want to store the generated images in an S3 bucket or other cloud storage and return the URLs instead of the base64-encoded strings, to take advantage of a CDN and other optimizations.
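As a rough illustration of the "store and return a URL" approach, the sketch below writes image bytes to a local directory and returns the URL they would be served from. Everything here is an assumption for illustration: store_image is a hypothetical helper, the local directory stands in for a bucket, and cdn.example.com is a placeholder. A real deployment would upload with the storage provider’s SDK instead.

```python
import hashlib
import tempfile
from pathlib import Path


def store_image(png_bytes: bytes, directory: Path, base_url: str) -> str:
    """Write the image to `directory` and return the URL it would be served from."""
    directory.mkdir(parents=True, exist_ok=True)
    # Content-addressed name: identical images map to the same file
    name = hashlib.sha256(png_bytes).hexdigest()[:16] + ".png"
    (directory / name).write_bytes(png_bytes)
    return f"{base_url.rstrip('/')}/{name}"


with tempfile.TemporaryDirectory() as tmp:
    url = store_image(b"placeholder bytes", Path(tmp), "https://cdn.example.com/images/")
    print(url)
```

The endpoint would then return a list of these URLs rather than the base64 strings.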
Step 4: Defining the Request Model
We need to define a model for the requests that our API will accept.
class GenerateRequest(BaseModel):
    prompt: str
    negative_prompt: str
    seed: conint(ge=0) = Field(..., description="Seed for random number generation")
    height: conint(gt=0) = Field(..., description="Height of the generated image, must be a positive integer and a multiple of 8")
    width: conint(gt=0) = Field(..., description="Width of the generated image, must be a positive integer and a multiple of 8")
    cfg: confloat(gt=0) = Field(..., description="CFG (classifier-free guidance scale), must be a positive number")
    steps: conint(ge=0) = Field(..., description="Number of steps")
    batch_size: conint(gt=0) = Field(..., description="Number of images to generate in a batch")
This GenerateRequest model defines the parameters required to generate an image. The prompt is the text description of the image you want to create. The negative_prompt can be used to specify what you don’t want in the image. Other fields include the image dimensions, the number of inference steps, and the batch size.
Step 5: Creating the Image Generation Endpoint
Now, let’s create the endpoint that will handle image generation requests.
@app.post("/")
async def generate_image(request: GenerateRequest):
    # Reject dimensions the model cannot handle
    if request.height % 8 != 0 or request.width % 8 != 0:
        raise HTTPException(status_code=400, detail="Height and width must both be multiples of 8")

    # One generator per image, seeded with consecutive values starting at request.seed
    generator = [torch.Generator(device="cpu").manual_seed(i) for i in range(request.seed, request.seed + request.batch_size)]

    images = pipeline(
        height=request.height,
        width=request.width,
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        generator=generator,
        num_inference_steps=request.steps,
        guidance_scale=request.cfg,
        num_images_per_prompt=request.batch_size
    ).images

    # Encode each image as a base64 PNG string
    base64_images = []
    for image in images:
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        base64_images.append(img_str)

    return {
        "images": base64_images,
    }
This endpoint handles the image generation process. It first validates that the height and width are multiples of 8, as required by FLUX. It then generates images based on the provided prompt and returns them as base64-encoded strings.
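Two small pieces of logic in the endpoint are easy to check in isolation: the dimension validation and the way seeds are assigned across a batch. The sketch below restates both as plain functions; seeds_for_batch and is_valid_size are hypothetical names used only for illustration.

```python
def seeds_for_batch(seed: int, batch_size: int) -> list[int]:
    """Each image in the batch gets its own consecutive seed."""
    return list(range(seed, seed + batch_size))


def is_valid_size(height: int, width: int, multiple: int = 8) -> bool:
    """Both dimensions must be divisible by `multiple`."""
    return height % multiple == 0 and width % multiple == 0


print(seeds_for_batch(42, 3))   # [42, 43, 44]
print(is_valid_size(512, 512))  # True
print(is_valid_size(512, 500))  # False, since 500 % 8 == 4
```

Giving every image its own generator means the images in a batch differ from one another while the whole batch stays repeatable for a given seed.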
Step 6: Starting the Server
Finally, let’s add some code to start the server when the script is run.
@app.on_event("startup")
async def startup_event():
    print("Image generation server running")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
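This code starts the FastAPI server on port 8000, making it accessible from http://localhost:8000. Because the host is set to 0.0.0.0, the server also accepts connections from other machines on your network.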
Step 7: Testing Your Server Locally
Now that your FLUX server is up and running, it’s time to test it. You can use curl, a command-line tool for making HTTP requests, to interact with your server:
curl -X POST "http://localhost:8000/" \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "A futuristic cityscape at sunset",
    "negative_prompt": "low quality, blurry",
    "seed": 42,
    "height": 512,
    "width": 512,
    "cfg": 7.5,
    "steps": 50,
    "batch_size": 1
  }'
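The response is a JSON object whose images field holds base64 strings, so you will usually want to decode them into files. Here is a minimal client sketch using only the standard library; request_images and save_images are hypothetical helpers, and the image_0.png naming scheme is an assumption.

```python
import base64
import json
import urllib.request
from pathlib import Path


def request_images(payload: dict, url: str = "http://localhost:8000/") -> dict:
    """POST the payload as JSON and return the parsed JSON response."""
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())


def save_images(response: dict, out_dir: Path) -> list[Path]:
    """Decode each base64 string in response["images"] and write it as a PNG."""
    out_dir.mkdir(parents=True, exist_ok=True)
    paths = []
    for i, img_str in enumerate(response["images"]):
        path = out_dir / f"image_{i}.png"
        path.write_bytes(base64.b64decode(img_str))
        paths.append(path)
    return paths


# With the server running, using the same fields as the curl example:
# payload = {"prompt": "A futuristic cityscape at sunset",
#            "negative_prompt": "low quality, blurry", "seed": 42,
#            "height": 512, "width": 512, "cfg": 7.5,
#            "steps": 50, "batch_size": 1}
# save_images(request_images(payload), Path("output"))
```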
Conclusion
Congratulations! You’ve successfully created your own FLUX server using Python. This setup allows you to generate images based on text prompts via a simple API. If you’re not satisfied with the results of the base FLUX model, you might consider fine-tuning the model for even better performance or specific use cases.
Full Code
You may find the full code used in this guide below:
import torch

device = 'cuda'  # can also be 'cpu' or 'mps'

if device == 'mps' and not torch.backends.mps.is_available():
    raise Exception("Device set to MPS, but MPS is not available")
elif device == 'cuda' and not torch.cuda.is_available():
    raise Exception("Device set to CUDA, but CUDA is not available")

from diffusers import DDIMScheduler, FluxPipeline
import psutil

model_name = "black-forest-labs/FLUX.1-dev"
print(f"Loading {model_name} on {device}")

pipeline = FluxPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    use_safetensors=True
).to(device)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

# Recommended if running on MPS or CPU with less than 64 GB of RAM
total_memory = psutil.virtual_memory().total
total_memory_gb = total_memory / (1024 ** 3)
if (device == 'cpu' or device == 'mps') and total_memory_gb < 64:
    print("Enabling attention slicing")
    pipeline.enable_attention_slicing()

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, conint, confloat
from fastapi.middleware.gzip import GZipMiddleware
from io import BytesIO
import base64

app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=7)

class GenerateRequest(BaseModel):
    prompt: str
    negative_prompt: str
    seed: conint(ge=0) = Field(..., description="Seed for random number generation")
    height: conint(gt=0) = Field(..., description="Height of the generated image, must be a positive integer and a multiple of 8")
    width: conint(gt=0) = Field(..., description="Width of the generated image, must be a positive integer and a multiple of 8")
    cfg: confloat(gt=0) = Field(..., description="CFG (classifier-free guidance scale), must be a positive number")
    steps: conint(ge=0) = Field(..., description="Number of steps")
    batch_size: conint(gt=0) = Field(..., description="Number of images to generate in a batch")

@app.post("/")
async def generate_image(request: GenerateRequest):
    if request.height % 8 != 0 or request.width % 8 != 0:
        raise HTTPException(status_code=400, detail="Height and width must both be multiples of 8")

    # One generator per image, seeded with consecutive values starting at request.seed
    generator = [torch.Generator(device="cpu").manual_seed(i) for i in range(request.seed, request.seed + request.batch_size)]

    images = pipeline(
        height=request.height,
        width=request.width,
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        generator=generator,
        num_inference_steps=request.steps,
        guidance_scale=request.cfg,
        num_images_per_prompt=request.batch_size
    ).images

    # Encode each image as a base64 PNG string
    base64_images = []
    for image in images:
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        base64_images.append(img_str)

    return {
        "images": base64_images,
    }

@app.on_event("startup")
async def startup_event():
    print("Image generation server running")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)