2023-11-20 19:47:09 +08:00
from __future__ import annotations
2023-11-11 18:01:13 +08:00
import configparser
2024-03-10 20:14:04 +08:00
import dataclasses
2022-10-31 22:36:45 +08:00
import os
2023-05-16 01:57:11 +08:00
import threading
2023-11-12 00:58:26 +08:00
import re
2022-10-31 22:36:45 +08:00
2023-08-09 15:25:35 +08:00
from modules import shared , errors , cache , scripts
2023-05-29 05:41:12 +08:00
from modules . gitpython_hack import Repo
2023-05-10 14:02:23 +08:00
from modules . paths_internal import extensions_dir , extensions_builtin_dir , script_path # noqa: F401
2022-10-31 22:36:45 +08:00
2024-03-17 16:51:40 +08:00
# Every discovered extension, in scan order (built-in directory first); rebuilt by list_extensions().
extensions: list[Extension] = []
# Extension directory path -> Extension; used by find_extension() to map a file to its extension.
extension_paths: dict[str, Extension] = {}
# Canonical extension name -> Extension; used for duplicate detection and requirement lookups.
loaded_extensions: dict[str, Extension] = {}

# Make sure the user extensions directory exists as soon as this module is imported.
os.makedirs(extensions_dir, exist_ok=True)
2023-03-27 15:02:30 +08:00
2022-10-31 22:36:45 +08:00
def active():
    """Return the list of extensions that should currently run.

    The "all" mode (command-line flag or UI option) yields nothing. The "extra" mode
    (flag or UI option) keeps only enabled built-in extensions. Otherwise every
    enabled extension is returned.
    """
    if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
        return []

    builtin_only = shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra"
    return [ext for ext in extensions if ext.enabled and (ext.is_builtin or not builtin_only)]
2022-10-31 22:36:45 +08:00
2024-03-10 20:14:04 +08:00
@dataclasses.dataclass
class CallbackOrderInfo:
    """Ordering constraints for one callback, taken from a ``[callbacks/<name>]`` ini section.

    ``before`` and ``after`` are the names listed in that section's ``Before`` and
    ``After`` fields.
    """

    name: str  # callback name: the section name with its "callbacks/" prefix removed
    before: list[str]
    after: list[str]
2023-11-20 19:47:09 +08:00
class ExtensionMetadata:
    """Parsed contents of an extension's optional ``metadata.ini`` file."""

    filename = "metadata.ini"
    config: configparser.ConfigParser
    canonical_name: str
    requires: list | None

    def __init__(self, path, canonical_name):
        """Read ``metadata.ini`` from the extension directory *path*.

        Args:
            path: the extension's directory.
            canonical_name: fallback name, used when the ini has no ``[Extension] Name``
                entry (list_extensions() passes the directory name).
        """
        self.config = configparser.ConfigParser()

        filepath = os.path.join(path, self.filename)
        # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
        # so no need to check whether the file exists beforehand.
        try:
            self.config.read(filepath)
        except Exception:
            errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)

        # Prefer the name declared in the ini, falling back to the argument. Normalise it,
        # because requirement names are lower-cased before they are looked up.
        self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name).lower().strip()

        # Filled in by list_extensions() once every extension has been registered.
        self.requires = None

    def get_script_requirements(self, field, section, extra_section=None):
        """reads a list of requirements from the config; field is the name of the field in the ini file,
        like Requires or Before, and section is the name of the [section] in the ini file; additionally,
        reads more requirements from [extra_section] if specified."""
        x = self.config.get(section, field, fallback='')
        if extra_section:
            x = x + ', ' + self.config.get(extra_section, field, fallback='')

        listed_requirements = self.parse_list(x.lower())
        res = []
        for requirement in listed_requirements:
            # "a|b" lists alternatives: use the first one that is loaded, otherwise keep the
            # whole string so the caller can report it as not installed.
            loaded_requirements = (alt for alt in requirement.split("|") if alt in loaded_extensions)
            relevant_requirement = next(loaded_requirements, requirement)
            res.append(relevant_requirement)

        return res

    def parse_list(self, text):
        """converts a line from config ("ext1 ext2, ext3") into a python list (["ext1", "ext2", "ext3"])"""
        if not text:
            return []

        # both "," and " " are accepted as separator
        return [x for x in re.split(r"[, \s]+", text.strip()) if x]

    def list_callback_order_instructions(self):
        """Yield a CallbackOrderInfo for each ``[callbacks/<name>]`` section of the ini.

        A section whose callback name does not start with this extension's canonical name
        is reported and skipped.
        """
        for section in self.config.sections():
            if not section.startswith("callbacks/"):
                continue

            callback_name = section[len("callbacks/"):]

            if not callback_name.startswith(self.canonical_name):
                errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
                continue

            before = self.parse_list(self.config.get(section, 'Before', fallback=''))
            after = self.parse_list(self.config.get(section, 'After', fallback=''))

            yield CallbackOrderInfo(callback_name, before, after)
2023-11-20 19:47:09 +08:00
2022-10-31 22:36:45 +08:00
class Extension:
    """One installed extension: its directory, enabled/built-in flags, and git repository info."""

    # Class attribute: a single lock shared by every Extension instance. It serializes the
    # git reads done inside read_info_from_repo().
    lock = threading.Lock()
    # Attributes round-tripped through to_dict()/from_dict() for the 'extensions-git' cache.
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']

    metadata: ExtensionMetadata

    def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
        """
        Args:
            name: extension name (list_extensions() passes the directory name).
            path: extension directory.
            enabled: whether the extension is enabled.
            is_builtin: True for extensions from the built-in directory; git info is never
                read for these.
            metadata: pre-parsed ExtensionMetadata; parsed from *path* when not given.
        """
        self.name = name
        self.path = path
        self.enabled = enabled
        self.status = ''
        self.can_update = False
        self.is_builtin = is_builtin
        self.commit_hash = ''
        self.commit_date = None
        self.version = ''
        self.branch = None
        self.remote = None
        self.have_info_from_repo = False
        self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
        # Read from self.metadata, not the `metadata` argument, which may be None.
        self.canonical_name = self.metadata.canonical_name

    def to_dict(self):
        """Return the cached git fields (see ``cached_fields``) as a dict."""
        return {x: getattr(self, x) for x in self.cached_fields}

    def from_dict(self, d):
        """Restore the cached git fields from a dict produced by ``to_dict``; every key must be present."""
        for field in self.cached_fields:
            setattr(self, field, d[field])

    def read_info_from_repo(self):
        """Populate the git fields (remote, branch, commit hash/date, version) via the file cache.

        Does nothing for built-in extensions or when the info has already been read.
        Afterwards, ``status`` is set to 'unknown' if nothing else has set it.
        """
        if self.is_builtin or self.have_info_from_repo:
            return

        def read_from_repo():
            with self.lock:
                # Another thread may have read the info while we waited for the lock. Still
                # return the fields: returning None here would make from_dict() fail.
                if not self.have_info_from_repo:
                    self.do_read_info_from_repo()
                return self.to_dict()

        try:
            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
            self.from_dict(d)
        except FileNotFoundError:
            pass
        self.status = 'unknown' if self.status == '' else self.status

    def do_read_info_from_repo(self):
        """Read git info directly from the repository at ``self.path``.

        Errors are reported through ``errors.report`` rather than raised.
        """
        repo = None
        try:
            if os.path.exists(os.path.join(self.path, ".git")):
                repo = Repo(self.path)
        except Exception:
            errors.report(f"Error reading github repository info from {self.path}", exc_info=True)

        if repo is None or repo.bare:
            self.remote = None
        else:
            try:
                self.remote = next(repo.remote().urls, None)
                commit = repo.head.commit
                self.commit_date = commit.committed_date
                # active_branch raises on a detached HEAD (e.g. a checked-out commit or tag).
                # Skip the branch name instead of failing the whole read.
                if not repo.head.is_detached:
                    self.branch = repo.active_branch.name
                self.commit_hash = commit.hexsha
                self.version = self.commit_hash[:8]

            except Exception:
                errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
                self.remote = None

        self.have_info_from_repo = True

    def list_files(self, subdir, extension):
        """Return ScriptFile entries for regular files in ``<path>/<subdir>``, sorted by name.

        Only files whose lower-cased file extension equals *extension* are kept, so pass it
        in lower case (e.g. ".py"). Returns [] if the subdirectory does not exist.
        """
        dirpath = os.path.join(self.path, subdir)
        if not os.path.isdir(dirpath):
            return []

        res = []
        for filename in sorted(os.listdir(dirpath)):
            res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))

        res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

        return res

    def check_updates(self):
        """Set ``can_update`` and ``status`` by comparing the checkout with its remote.

        First inspects the results of a dry-run fetch, then compares HEAD with ``origin``.
        """
        repo = Repo(self.path)
        for fetch in repo.remote().fetch(dry_run=True):
            # When the branch is known, only consider fetch info for that branch's remote ref.
            if self.branch and fetch.name != f'{repo.remote().name}/{self.branch}':
                continue
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
                self.status = "new commits"
                return

        try:
            origin = repo.rev_parse('origin')
            if repo.head.commit != origin:
                self.can_update = True
                self.status = "behind HEAD"
                return
        except Exception:
            self.can_update = False
            self.status = "unknown (remote error)"
            return

        self.can_update = False
        self.status = "latest"

    def fetch_and_reset_hard(self, commit='origin'):
        """Fetch all remotes, then hard-reset the working tree to *commit*."""
        repo = Repo(self.path)
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
        repo.git.fetch(all=True)
        repo.git.reset(commit, hard=True)
        # Git info is now stale; let read_info_from_repo() read it again.
        self.have_info_from_repo = False
2022-10-31 22:36:45 +08:00
def list_extensions():
    """Rescan the extension directories and rebuild the module-level registries.

    Clears and repopulates ``extensions``, ``extension_paths`` and ``loaded_extensions``.
    It then resolves each extension's ``Requires`` list and reports, without raising,
    any requirement that is missing or disabled.
    """
    extensions.clear()
    extension_paths.clear()
    loaded_extensions.clear()

    if shared.cmd_opts.disable_all_extensions:
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "all":
        print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
    elif shared.cmd_opts.disable_extra_extensions:
        print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")

    # scan through extensions directory and load metadata
    for dirname in [extensions_builtin_dir, extensions_dir]:
        if not os.path.isdir(dirname):
            continue

        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue

            canonical_name = extension_dirname
            metadata = ExtensionMetadata(path, canonical_name)

            # check for duplicated canonical names
            already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
            if already_loaded_extension is not None:
                errors.report(f'Duplicate canonical name "{metadata.canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
                continue

            is_builtin = dirname == extensions_builtin_dir
            extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
            extensions.append(extension)
            extension_paths[extension.path] = extension
            # Register under the normalised canonical name. This is the same key the
            # duplicate check above and the (lower-cased) requirement lookups below use.
            loaded_extensions[metadata.canonical_name] = extension

    # Resolve requirements only after every extension is registered, because "a|b"
    # alternatives are matched against loaded_extensions.
    for extension in extensions:
        extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension")

    # check for requirements
    for extension in extensions:
        if not extension.enabled:
            continue

        for req in extension.metadata.requires:
            required_extension = loaded_extensions.get(req)
            if required_extension is None:
                errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
                continue

            if not required_extension.enabled:
                errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
2023-11-11 18:01:13 +08:00
2024-03-10 12:52:57 +08:00
def find_extension(filename):
    """Return the registered Extension whose directory contains *filename*, or None.

    Symlinks in *filename* are resolved first. Each ancestor directory is then checked
    against ``extension_paths``, nearest first, until the walk stops changing.
    """
    def _ancestors():
        # Yields successive parent directories. It stops when dirname() no longer changes
        # the path, i.e. at the filesystem root.
        previous = filename
        current = os.path.dirname(os.path.realpath(filename))
        while current != previous:
            yield current
            previous, current = current, os.path.dirname(current)

    for directory in _ancestors():
        match = extension_paths.get(directory)
        if match is not None:
            return match

    return None