Device TypeError while running PyTorch model

Running Stability AI's SDXL-Base-1.0 model in a Docker container on Windows under WSL 2 raises a device TypeError during inference.

    import torch
    from diffusers import DiffusionPipeline

    # load the base model into memory (works fine);
    # cache_dir is defined earlier in the script
    base = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
        cache_dir=cache_dir)
    base.to("cuda")

    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        text_encoder_2=base.text_encoder_2,
        vae=base.vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
        cache_dir=cache_dir)
    refiner.to("cuda")

Inference is then run with a prompt stored in the `prompt` variable (`n_steps`, `high_noise_frac`, and `imsize` are defined earlier in the script):

    # run both experts
    b_latent = base(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
        height=imsize,
        width=imsize,
    ).images

    # hand the base latents to the refiner to finish denoising
    r_image = refiner(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=b_latent,
    ).images[0]

This results in the following error:

      File "/opt/miniconda/envs/inf-conda-env/lib/python3.10/site-packages/azureml_inference_server_http/server/user_script.py", line 132, in invoke_run
        run_output = self._wrapped_user_run(**run_parameters, request_headers=dict(request.headers))
      File "/opt/miniconda/envs/inf-conda-env/lib/python3.10/site-packages/azureml_inference_server_http/server/user_script.py", line 156, in <lambda>
        self._wrapped_user_run = lambda request_headers, **kwargs: self._user_run(**kwargs)
      File "/var/azureml-app/src/az.score.py", line 146, in run
        b_latent = base(
      File "/opt/miniconda/envs/inf-conda-env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
        return func(*args, **kwargs)
      File "/opt/miniconda/envs/inf-conda-env/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 1094, in __call__
        timesteps, num_inference_steps = retrieve_timesteps(
      File "/opt/miniconda/envs/inf-conda-env/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 158, in retrieve_timesteps
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
      File "/opt/miniconda/envs/inf-conda-env/lib/python3.10/site-packages/diffusers/schedulers/scheduling_euler_discrete.py", line 383, in set_timesteps
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
      File "/opt/miniconda/envs/inf-conda-env/lib/python3.10/site-packages/torch/_tensor.py", line 1060, in __array__
        return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
      File "/opt/miniconda/envs/inf-conda-env/lib/python3.10/site-packages/torch/overrides.py", line 1604, in handle_torch_function
        result = mode.__torch_function__(public_api, types, args, kwargs)
      File "/opt/miniconda/envs/inf-conda-env/lib/python3.10/site-packages/torch/utils/_device.py", line 77, in __torch_function__
        return func(*args, **kwargs)
      File "/opt/miniconda/envs/inf-conda-env/lib/python3.10/site-packages/torch/_tensor.py", line 1062, in __array__
        return self.numpy()
    TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

This seems to be an error raised inside the SDXL pipeline's own source code, so I can't do anything about it directly. Is there something I need to do when configuring the container that will solve this issue, or what else could be the solution?
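
For context, this TypeError is PyTorch's standard error when a CUDA tensor is converted to a NumPy array without first being copied to the CPU; in the traceback it fires inside the scheduler's `set_timesteps`, where `self.alphas_cumprod` apparently lives on `cuda:0`. A minimal sketch that reproduces the same message (assuming a CUDA-capable machine):

    import torch

    t = torch.ones(3, device="cuda")   # a tensor living on the GPU
    try:
        t.numpy()                      # fails: NumPy can only read host memory
    except TypeError as e:
        print(e)                       # "can't convert cuda:0 device type tensor to numpy. ..."
    host = t.cpu().numpy()             # works: copy to host memory first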

PS: the container does have GPU access enabled; running the following in the container's Conda environment confirms it:

    >>> import torch
    >>> torch.cuda.is_available()
    True
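
One more observation, in case it helps: the traceback passes through `torch/utils/_device.py`, which (as far as I understand) is only on the call path when a default-device context is active, e.g. `torch.set_default_device("cuda")` or a `with torch.device("cuda"):` block somewhere in the scoring script. A quick check for that, as a sketch:

    import torch

    # A freshly created tensor lands on the default device; if this
    # prints cuda:0, a global CUDA default has been set somewhere.
    print(torch.empty(1).device)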