From 0f28aee9cd12b8294df80506e6466cd90a9ae195 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Fri, 19 May 2023 17:28:41 +0300
Subject: [PATCH] Refactor gradio auth

---
 webui.py | 37 +++++++++++++++++++++++++++++--------
 1 file changed, 29 insertions(+), 8 deletions(-)

diff --git a/webui.py b/webui.py
index e568ef42c..64b113dd8 100644
--- a/webui.py
+++ b/webui.py
@@ -7,6 +7,7 @@ import re
 import warnings
 import json
 from threading import Thread
+from typing import Iterable
 
 from fastapi import FastAPI, Response
 from fastapi.middleware.cors import CORSMiddleware
@@ -178,6 +179,32 @@ def validate_tls_options():
     startup_timer.record("TLS")
 
 
+def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
+    """
+    Convert the gradio_auth and gradio_auth_path commandline arguments into
+    an iterable of (username, password) tuples.
+    """
+    def process_credential_line(s) -> tuple[str, ...] | None:
+        s = s.strip()
+        if not s:
+            return None
+        return tuple(s.split(':', 1))
+
+    if cmd_opts.gradio_auth:
+        for cred in cmd_opts.gradio_auth.split(','):
+            cred = process_credential_line(cred)
+            if cred:
+                yield cred
+
+    if cmd_opts.gradio_auth_path:
+        with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
+            for line in file.readlines():
+                for cred in line.strip().split(','):
+                    cred = process_credential_line(cred)
+                    if cred:
+                        yield cred
+
+
 def configure_sigint_handler():
     # make the program just exit at ctrl+c without waiting for anything
     def sigint_handler(sig, frame):
@@ -316,13 +343,7 @@ def webui():
         if not cmd_opts.no_gradio_queue:
             shared.demo.queue(64)
 
-        gradio_auth_creds = []
-        if cmd_opts.gradio_auth:
-            gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
-        if cmd_opts.gradio_auth_path:
-            with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
-                for line in file.readlines():
-                    gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
+        gradio_auth_creds = list(get_gradio_auth_creds()) or None
 
         # this restores the missing /docs endpoint
         if launch_api and not hasattr(FastAPI, 'original_setup'):
@@ -343,7 +364,7 @@ def webui():
             ssl_certfile=cmd_opts.tls_certfile,
             ssl_verify=cmd_opts.disable_tls_verify,
             debug=cmd_opts.gradio_debug,
-            auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
+            auth=gradio_auth_creds,
             inbrowser=cmd_opts.autolaunch,
             prevent_thread_lock=True,
             allowed_paths=cmd_opts.gradio_allowed_path,