mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-31 02:32:57 +08:00
Refactor gradio auth
This commit is contained in:
parent
674e80c625
commit
0f28aee9cd
37
webui.py
37
webui.py
@ -7,6 +7,7 @@ import re
|
|||||||
import warnings
|
import warnings
|
||||||
import json
|
import json
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
from fastapi import FastAPI, Response
|
from fastapi import FastAPI, Response
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@ -178,6 +179,32 @@ def validate_tls_options():
|
|||||||
startup_timer.record("TLS")
|
startup_timer.record("TLS")
|
||||||
|
|
||||||
|
|
||||||
|
def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
|
||||||
|
"""
|
||||||
|
Convert the gradio_auth and gradio_auth_path commandline arguments into
|
||||||
|
an iterable of (username, password) tuples.
|
||||||
|
"""
|
||||||
|
def process_credential_line(s) -> tuple[str, ...] | None:
|
||||||
|
s = s.strip()
|
||||||
|
if not s:
|
||||||
|
return None
|
||||||
|
return tuple(s.split(':', 1))
|
||||||
|
|
||||||
|
if cmd_opts.gradio_auth:
|
||||||
|
for cred in cmd_opts.gradio_auth.split(','):
|
||||||
|
cred = process_credential_line(cred)
|
||||||
|
if cred:
|
||||||
|
yield cred
|
||||||
|
|
||||||
|
if cmd_opts.gradio_auth_path:
|
||||||
|
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
|
||||||
|
for line in file.readlines():
|
||||||
|
for cred in line.strip().split(','):
|
||||||
|
cred = process_credential_line(cred)
|
||||||
|
if cred:
|
||||||
|
yield cred
|
||||||
|
|
||||||
|
|
||||||
def configure_sigint_handler():
|
def configure_sigint_handler():
|
||||||
# make the program just exit at ctrl+c without waiting for anything
|
# make the program just exit at ctrl+c without waiting for anything
|
||||||
def sigint_handler(sig, frame):
|
def sigint_handler(sig, frame):
|
||||||
@ -316,13 +343,7 @@ def webui():
|
|||||||
if not cmd_opts.no_gradio_queue:
|
if not cmd_opts.no_gradio_queue:
|
||||||
shared.demo.queue(64)
|
shared.demo.queue(64)
|
||||||
|
|
||||||
gradio_auth_creds = []
|
gradio_auth_creds = list(get_gradio_auth_creds()) or None
|
||||||
if cmd_opts.gradio_auth:
|
|
||||||
gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
|
|
||||||
if cmd_opts.gradio_auth_path:
|
|
||||||
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
|
|
||||||
for line in file.readlines():
|
|
||||||
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
|
|
||||||
|
|
||||||
# this restores the missing /docs endpoint
|
# this restores the missing /docs endpoint
|
||||||
if launch_api and not hasattr(FastAPI, 'original_setup'):
|
if launch_api and not hasattr(FastAPI, 'original_setup'):
|
||||||
@ -343,7 +364,7 @@ def webui():
|
|||||||
ssl_certfile=cmd_opts.tls_certfile,
|
ssl_certfile=cmd_opts.tls_certfile,
|
||||||
ssl_verify=cmd_opts.disable_tls_verify,
|
ssl_verify=cmd_opts.disable_tls_verify,
|
||||||
debug=cmd_opts.gradio_debug,
|
debug=cmd_opts.gradio_debug,
|
||||||
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
|
auth=gradio_auth_creds,
|
||||||
inbrowser=cmd_opts.autolaunch,
|
inbrowser=cmd_opts.autolaunch,
|
||||||
prevent_thread_lock=True,
|
prevent_thread_lock=True,
|
||||||
allowed_paths=cmd_opts.gradio_allowed_path,
|
allowed_paths=cmd_opts.gradio_allowed_path,
|
||||||
|
Loading…
Reference in New Issue
Block a user