Text-to-Image
Diffusers
Safetensors
StableDiffusionPipeline
stable-diffusion
stable-diffusion-diffusers
Instructions to use CompVis/stable-diffusion-v1-4 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use CompVis/stable-diffusion-v1-4 with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", dtype=torch.bfloat16, device_map="cuda") prompt = "A high tech solarpunk utopia in the Amazon rainforest" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
- Local Apps
- Draw Things
- DiffusionBee
Commit ·
bd73f2a
1
Parent(s): 114c79c
Do not assume 8 devices in JAX (#154)
Browse files- Do not assume 8 devices in JAX (e124bbdca2dab1af0cdce19d575f8043eab9341e)
Co-authored-by: Pedro Cuenca <pcuenq@users.noreply.huggingface.co>
README.md
CHANGED
|
@@ -154,7 +154,7 @@ prompt_ids = pipeline.prepare_inputs(prompt)
|
|
| 154 |
|
| 155 |
# shard inputs and rng
|
| 156 |
params = replicate(params)
|
| 157 |
-
prng_seed = jax.random.split(prng_seed,
|
| 158 |
prompt_ids = shard(prompt_ids)
|
| 159 |
|
| 160 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
|
@@ -187,7 +187,7 @@ prompt_ids = pipeline.prepare_inputs(prompt)
|
|
| 187 |
|
| 188 |
# shard inputs and rng
|
| 189 |
params = replicate(params)
|
| 190 |
-
prng_seed = jax.random.split(prng_seed,
|
| 191 |
prompt_ids = shard(prompt_ids)
|
| 192 |
|
| 193 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
|
|
|
| 154 |
|
| 155 |
# shard inputs and rng
|
| 156 |
params = replicate(params)
|
| 157 |
+
prng_seed = jax.random.split(prng_seed, num_samples)
|
| 158 |
prompt_ids = shard(prompt_ids)
|
| 159 |
|
| 160 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
|
|
|
| 187 |
|
| 188 |
# shard inputs and rng
|
| 189 |
params = replicate(params)
|
| 190 |
+
prng_seed = jax.random.split(prng_seed, num_samples)
|
| 191 |
prompt_ids = shard(prompt_ids)
|
| 192 |
|
| 193 |
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|