PerceptionDLM / cache.py
MSALab's picture
Add files using upload-large-folder tool
cadf670 verified
Raw
History Blame Contribute Delete
2.86 kB
from dataclasses import dataclass
@dataclass
class dLLMCacheConfig:
prompt_interval_steps: int = 1
gen_interval_steps: int = 1
transfer_ratio: float = 0.0
cfg_interval_steps: int = 1
import torch
from collections import defaultdict
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class dLLMCache(metaclass=Singleton):
gen_interval_steps: int
prompt_interval_steps: int
cfg_interval_steps: int
prompt_length: int
transfer_ratio: float
__cache: defaultdict
__step_counter: defaultdict
@classmethod
def new_instance(
cls,
prompt_interval_steps: int = 1,
gen_interval_steps: int = 1,
cfg_interval_steps: int = 1,
transfer_ratio: float = 0.0,
) -> "dLLMCache":
ins = cls()
setattr(ins, "prompt_interval_steps", prompt_interval_steps)
setattr(ins, "gen_interval_steps", gen_interval_steps)
setattr(ins, "cfg_interval_steps", cfg_interval_steps)
setattr(ins, "transfer_ratio", transfer_ratio)
ins.init()
return ins
def init(self) -> None:
self.__cache = defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
)
self.__step_counter = defaultdict(lambda: defaultdict(lambda: 0))
def reset_cache(self, prompt_length: int = 0) -> None:
self.init()
torch.cuda.empty_cache()
self.prompt_length = prompt_length
self.cache_type = "no_cfg"
def set_cache(
self, layer_id: int, feature_name: str, features: torch.Tensor, cache_type: str
) -> None:
self.__cache[self.cache_type][cache_type][layer_id][feature_name] = {
0: features
}
def get_cache(
self, layer_id: int, feature_name: str, cache_type: str
) -> torch.Tensor:
output = self.__cache[self.cache_type][cache_type][layer_id][feature_name][0]
return output
def update_step(self, layer_id: int) -> None:
self.__step_counter[self.cache_type][layer_id] += 1
def refresh_gen(self, layer_id: int = 0) -> bool:
return (self.current_step - 1) % self.gen_interval_steps == 0
def refresh_prompt(self, layer_id: int = 0) -> bool:
return (self.current_step - 1) % self.prompt_interval_steps == 0
def refresh_cfg(self, layer_id: int = 0) -> bool:
return (
self.current_step - 1
) % self.cfg_interval_steps == 0 or self.current_step <= 5
@property
def current_step(self) -> int:
return max(list(self.__step_counter[self.cache_type].values()), default=1)
def __repr__(self):
return f"USE dLLMCache"