itlevy committed
Commit e9d0db3
1 Parent(s): d311379
Files changed (1)
  1. variable_cache.py +11 -9
variable_cache.py CHANGED
@@ -32,18 +32,20 @@ class VariableCache(Cache_4_44_2, Cache):
     The cache of each layer is allocated to the same gpu as the layer itself.
     """
 
-    def __init__(self,
-                 config: DeciLMConfig,
-                 max_batch_size: int,
-                 max_cache_len: int | None,
-                 device: torch.device | str | None = None,
-                 dtype: torch.dtype | None = None,
-                 **kwargs: Any,
-                 ):
+    def __init__(
+        self,
+        config: DeciLMConfig,
+        batch_size: int = None,
+        max_cache_len: int = None,
+        device: torch.device = None,
+        dtype: torch.dtype = torch.float32,
+        max_batch_size: Optional[int] = None,
+        **kwargs: Any,
+    ) -> None:
         Cache_4_44_2.__init__(self)
 
         self.config = config
-        self.max_batch_size = batch_size or max_batch_size
+        self.max_batch_size = batch_size or max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.dtype = dtype
 
51