| import argparse |
| import torch |
|
|
| from dassl.utils import setup_logger, set_random_seed, collect_env_info |
| from dassl.config import get_cfg_default |
| from dassl.engine import build_trainer |
|
|
| |
| import datasets.oxford_pets |
| import datasets.oxford_flowers |
| import datasets.fgvc_aircraft |
| import datasets.dtd |
| import datasets.eurosat |
| import datasets.stanford_cars |
| import datasets.food101 |
| import datasets.sun397 |
| import datasets.caltech101 |
| import datasets.ucf101 |
| import datasets.imagenet |
|
|
| import datasets.imagenet_sketch |
| import datasets.imagenetv2 |
| import datasets.imagenet_a |
| import datasets.imagenet_r |
|
|
| import trainers.coop |
| import trainers.cocoop |
| import trainers.kgcoop |
| import trainers.zsclip |
| import trainers.maple |
| import trainers.independentVL |
| import trainers.promptsrc |
| import trainers.tcp |
| import trainers.supr |
| import trainers.supr_ens |
| import trainers.elp_promptsrc |
| import trainers.supr_promptsrc |
|
|
|
|
def print_args(args, cfg):
    """Print the parsed command-line arguments followed by the config.

    Args:
        args: Object whose ``__dict__`` holds the parsed options (e.g. an
            ``argparse.Namespace``). Entries are printed one per line as
            ``name: value``, sorted by name.
        cfg: Config object; printed as-is via ``print``.
    """
    for banner_line in ("***************", "** Arguments **", "***************"):
        print(banner_line)

    options = vars(args)
    for name in sorted(options):
        print("{}: {}".format(name, options[name]))

    for banner_line in ("************", "** Config **", "************"):
        print(banner_line)
    print(cfg)
|
|
|
|
def reset_cfg(cfg, args):
    """Copy command-line overrides from ``args`` into ``cfg`` (in place).

    Each option is copied only when it was actually supplied; options left at
    their empty default (``""`` / ``None``) do not touch the config.

    Args:
        cfg: Config object exposing ``DATASET``, ``INPUT``, ``TRAINER`` and
            ``MODEL.BACKBONE`` / ``MODEL.HEAD`` sub-nodes plus the top-level
            ``OUTPUT_DIR``, ``RESUME`` and ``SEED`` entries.
        args: Parsed command-line options with the attributes ``root``,
            ``output_dir``, ``resume``, ``seed``, ``source_domains``,
            ``target_domains``, ``transforms``, ``trainer``, ``backbone``
            and ``head``.
    """
    if args.root:
        cfg.DATASET.ROOT = args.root

    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir

    if args.resume:
        cfg.RESUME = args.resume

    # Compare against None rather than relying on truthiness: 0 is falsy, so
    # a plain `if args.seed:` silently discards `--seed 0`, although main()
    # accepts any cfg.SEED >= 0 as a fixed seed. Every other integer
    # (including the -1 default) is copied exactly as before.
    if args.seed is not None:
        cfg.SEED = args.seed

    if args.source_domains:
        cfg.DATASET.SOURCE_DOMAINS = args.source_domains

    if args.target_domains:
        cfg.DATASET.TARGET_DOMAINS = args.target_domains

    if args.transforms:
        cfg.INPUT.TRANSFORMS = args.transforms

    if args.trainer:
        cfg.TRAINER.NAME = args.trainer

    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone

    if args.head:
        cfg.MODEL.HEAD.NAME = args.head
|
|
|
|
def extend_cfg(cfg):
    """
    Register the extra config entries this script layers on top of ``cfg``.

    A fresh ``CfgNode`` with default values is attached under ``cfg.TRAINER``
    for each group below, and a handful of entries under ``cfg.DATASET``,
    ``cfg.TEST`` and ``cfg.OPTIM`` are set as well. ``cfg`` is modified in
    place; nothing is returned.

    To add options for another trainer, follow the same pattern, e.g.
        trainer_cfg.MY_MODEL = _node(PARAM_A=1.0, PARAM_B=0.5, PARAM_C=False)
    """
    from yacs.config import CfgNode as CN

    def _node(**defaults):
        """Return a new CfgNode populated with the given key/value pairs."""
        node = CN()
        for key, value in defaults.items():
            setattr(node, key, value)
        return node

    trainer_cfg = cfg.TRAINER

    # --- COOP ---
    trainer_cfg.COOP = _node(
        N_CTX=16,
        CSC=False,
        CTX_INIT="",
        PREC="fp16",
        W=8.0,
        CLASS_TOKEN_POSITION="end",
    )

    # --- COCOOP ---
    trainer_cfg.COCOOP = _node(
        N_CTX=16,
        CTX_INIT="",
        PREC="fp16",
    )

    # --- MAPLE ---
    trainer_cfg.MAPLE = _node(
        N_CTX=2,
        CTX_INIT="a photo of a",
        PREC="fp16",
        PROMPT_DEPTH=9,
    )
    cfg.DATASET.SUBSAMPLE_CLASSES = "all"

    # --- PROMPTSRC ---
    trainer_cfg.PROMPTSRC = _node(
        N_CTX_VISION=4,
        N_CTX_TEXT=4,
        CTX_INIT="a photo of a",
        PREC="fp16",
        PROMPT_DEPTH_VISION=9,
        PROMPT_DEPTH_TEXT=9,
        TEXT_LOSS_WEIGHT=25,
        IMAGE_LOSS_WEIGHT=10,
        GPA_MEAN=15,
        GPA_STD=1,
    )

    # --- IVLP ---
    trainer_cfg.IVLP = _node(
        N_CTX_VISION=2,
        N_CTX_TEXT=2,
        CTX_INIT="a photo of a",
        PREC="fp16",
        PROMPT_DEPTH_VISION=9,
        PROMPT_DEPTH_TEXT=9,
    )
    # Same key and value as the assignment after the MAPLE group; kept so the
    # sequence of writes matches the original exactly.
    cfg.DATASET.SUBSAMPLE_CLASSES = "all"
    cfg.TEST.NO_TEST = False

    # --- LINEAR_PROBE ---
    trainer_cfg.LINEAR_PROBE = _node(
        TYPE="linear",
        WEIGHT=0.3,
        TEST_TIME_FUSION=True,
    )

    # --- FILM (plus the optimizer entries set alongside it) ---
    trainer_cfg.FILM = _node(LINEAR_PROBE=True)
    cfg.OPTIM.LR_EXP = 6.5
    cfg.OPTIM.NEW_LAYERS = ["linear_probe", "film"]

    # --- TCP ---
    trainer_cfg.TCP = _node(
        N_CTX=4,
        CSC=False,
        CTX_INIT="",
        PREC="fp16",
        W=1.0,
        CLASS_TOKEN_POSITION="end",
    )

    # --- SUPR ---
    trainer_cfg.SUPR = _node(
        N_CTX_VISION=4,
        N_CTX_TEXT=4,
        CTX_INIT="a photo of a",
        PREC="fp16",
        PROMPT_DEPTH_VISION=9,
        PROMPT_DEPTH_TEXT=9,
        SPACE_DIM=7,
        ENSEMBLE_NUM=3,
        REG_LOSS_WEIGHT=60,
        LAMBDA=0.7,
        SVD=True,
        HARD_PROMPT_PATH="configs/trainers/SuPr/hard_prompts/",
        TRAINER_BACKBONE="SuPr",
    )
|
|
| |
def setup_cfg(args):
    """Build the frozen config for this run from defaults, files and CLI.

    The config is assembled in this order:
      1. defaults from ``get_cfg_default()`` plus the entries added by
         ``extend_cfg``;
      2. the dataset config file (``args.dataset_config_file``), if given;
      3. the method config file (``args.config_file``), if given;
      4. the dedicated command-line options, via ``reset_cfg``;
      5. the free-form ``args.opts`` list, via ``merge_from_list``.

    Args:
        args: Parsed command-line options.

    Returns:
        The config object after ``freeze()`` has been called on it.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # Dataset file first, then the method file.
    for yaml_path in (args.dataset_config_file, args.config_file):
        if yaml_path:
            cfg.merge_from_file(yaml_path)

    # Dedicated CLI flags, then the trailing free-form option list.
    reset_cfg(cfg, args)
    cfg.merge_from_list(args.opts)

    cfg.freeze()
    return cfg
|
|
|
|
def main(args):
    """Build the config and trainer, then evaluate or train.

    Steps: build the config with ``setup_cfg``; fix the random seed when
    ``cfg.SEED >= 0``; set up logging into ``cfg.OUTPUT_DIR``; enable the
    cuDNN benchmark flag when CUDA is available and ``cfg.USE_CUDA`` is set;
    print the arguments, config and environment info; build the trainer.

    With ``args.eval_only`` the model is loaded from ``args.model_dir`` at
    ``args.load_epoch`` and only ``trainer.test()`` runs. Otherwise
    ``trainer.train()`` runs unless ``args.no_train`` is set.

    Args:
        args: Parsed command-line options.
    """
    cfg = setup_cfg(args)

    seed = cfg.SEED
    if seed >= 0:
        print("Setting fixed seed: {}".format(seed))
        set_random_seed(seed)
    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available():
        if cfg.USE_CUDA:
            torch.backends.cudnn.benchmark = True

    print_args(args, cfg)
    print("Collecting env info ...")
    env_report = collect_env_info()
    print("** System info **\n{}\n".format(env_report))

    trainer = build_trainer(cfg)

    # Evaluation-only mode: load the requested checkpoint, test, and stop.
    if args.eval_only:
        trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test()
        return

    if args.no_train:
        return
    trainer.train()
|
|
|
|
if __name__ == "__main__":
    # Command-line interface. Options are registered in the same order as
    # before so the generated --help output is unchanged.
    parser = argparse.ArgumentParser()

    # Plain string options that default to the empty string.
    for flag, description in (
        ("--root", "path to dataset"),
        ("--output-dir", "output directory"),
        ("--resume", "checkpoint directory (from which the training resumes)"),
    ):
        parser.add_argument(flag, type=str, default="", help=description)

    # NOTE(review): the help text says "positive", while main() enables a
    # fixed seed for any cfg.SEED >= 0 -- confirm which is intended.
    parser.add_argument(
        "--seed",
        type=int,
        default=-1,
        help="only positive value enables a fixed seed",
    )

    # Options taking one or more string values.
    for flag, description in (
        ("--source-domains", "source domains for DA/DG"),
        ("--target-domains", "target domains for DA/DG"),
        ("--transforms", "data augmentation methods"),
    ):
        parser.add_argument(flag, type=str, nargs="+", help=description)

    for flag, description in (
        ("--config-file", "path to config file"),
        ("--dataset-config-file", "path to config file for dataset setup"),
        ("--trainer", "name of trainer"),
        ("--backbone", "name of CNN backbone"),
        ("--head", "name of head"),
    ):
        parser.add_argument(flag, type=str, default="", help=description)

    parser.add_argument("--eval-only", action="store_true", help="evaluation only")
    parser.add_argument(
        "--model-dir",
        type=str,
        default="",
        help="load model from this directory for eval-only mode",
    )
    parser.add_argument(
        "--load-epoch",
        type=int,
        help="load model weights at this epoch for evaluation",
    )
    parser.add_argument(
        "--no-train",
        action="store_true",
        help="do not call trainer.train()",
    )

    # Everything left on the command line is collected into `opts`.
    parser.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="modify config options using the command-line",
    )

    args = parser.parse_args()
    main(args)
|
|