Refactor gradio auth

This commit is contained in:
Aarni Koskela 2023-05-19 17:28:41 +03:00
parent 674e80c625
commit 0f28aee9cd

View File

@ -7,6 +7,7 @@ import re
import warnings import warnings
import json import json
from threading import Thread from threading import Thread
from typing import Iterable
from fastapi import FastAPI, Response from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -178,6 +179,32 @@ def validate_tls_options():
startup_timer.record("TLS") startup_timer.record("TLS")
def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
"""
Convert the gradio_auth and gradio_auth_path commandline arguments into
an iterable of (username, password) tuples.
"""
def process_credential_line(s) -> tuple[str, ...] | None:
s = s.strip()
if not s:
return None
return tuple(s.split(':', 1))
if cmd_opts.gradio_auth:
for cred in cmd_opts.gradio_auth.split(','):
cred = process_credential_line(cred)
if cred:
yield cred
if cmd_opts.gradio_auth_path:
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
for cred in line.strip().split(','):
cred = process_credential_line(cred)
if cred:
yield cred
def configure_sigint_handler(): def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame): def sigint_handler(sig, frame):
@ -316,13 +343,7 @@ def webui():
if not cmd_opts.no_gradio_queue: if not cmd_opts.no_gradio_queue:
shared.demo.queue(64) shared.demo.queue(64)
gradio_auth_creds = [] gradio_auth_creds = list(get_gradio_auth_creds()) or None
if cmd_opts.gradio_auth:
gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
if cmd_opts.gradio_auth_path:
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
# this restores the missing /docs endpoint # this restores the missing /docs endpoint
if launch_api and not hasattr(FastAPI, 'original_setup'): if launch_api and not hasattr(FastAPI, 'original_setup'):
@ -343,7 +364,7 @@ def webui():
ssl_certfile=cmd_opts.tls_certfile, ssl_certfile=cmd_opts.tls_certfile,
ssl_verify=cmd_opts.disable_tls_verify, ssl_verify=cmd_opts.disable_tls_verify,
debug=cmd_opts.gradio_debug, debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None, auth=gradio_auth_creds,
inbrowser=cmd_opts.autolaunch, inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True, prevent_thread_lock=True,
allowed_paths=cmd_opts.gradio_allowed_path, allowed_paths=cmd_opts.gradio_allowed_path,