Free ggml context when model is garbage collected
This commit is contained in:
parent ad3a4ebc57
commit fa9ad13a39
|
@ -32,14 +32,14 @@ class RWKVModel:
|
||||||
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
assert os.path.isfile(model_path), f'{model_path} is not a file'
|
||||||
assert thread_count > 0, 'Thread count must be positive'
|
assert thread_count > 0, 'Thread count must be positive'
|
||||||
|
|
||||||
self.library = shared_library
|
self._library = shared_library
|
||||||
|
|
||||||
self.ctx = self.library.rwkv_init_from_file(model_path, thread_count)
|
self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
|
||||||
|
|
||||||
self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx)
|
self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
|
||||||
self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx)
|
self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
|
||||||
|
|
||||||
self.valid = True
|
self._valid = True
|
||||||
|
|
||||||
def eval(
|
def eval(
|
||||||
self,
|
self,
|
||||||
|
@ -69,7 +69,7 @@ class RWKVModel:
|
||||||
Logits vector of shape (n_vocab); state for the next step.
|
Logits vector of shape (n_vocab); state for the next step.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.valid, 'Model was freed'
|
assert self._valid, 'Model was freed'
|
||||||
|
|
||||||
def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
|
def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
|
||||||
assert buf.dtype == torch.float32, f'{name} is not of type float32'
|
assert buf.dtype == torch.float32, f'{name} is not of type float32'
|
||||||
|
@ -77,24 +77,24 @@ class RWKVModel:
|
||||||
assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
|
assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
|
||||||
|
|
||||||
if state_in is not None:
|
if state_in is not None:
|
||||||
validate_buffer(state_in, 'state_in', self.state_buffer_element_count)
|
validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
|
||||||
|
|
||||||
state_in_ptr = state_in.storage().data_ptr()
|
state_in_ptr = state_in.storage().data_ptr()
|
||||||
else:
|
else:
|
||||||
state_in_ptr = 0
|
state_in_ptr = 0
|
||||||
|
|
||||||
if state_out is not None:
|
if state_out is not None:
|
||||||
validate_buffer(state_out, 'state_out', self.state_buffer_element_count)
|
validate_buffer(state_out, 'state_out', self._state_buffer_element_count)
|
||||||
else:
|
else:
|
||||||
state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu')
|
state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||||
|
|
||||||
if logits_out is not None:
|
if logits_out is not None:
|
||||||
validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count)
|
validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count)
|
||||||
else:
|
else:
|
||||||
logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu')
|
logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu')
|
||||||
|
|
||||||
self.library.rwkv_eval(
|
self._library.rwkv_eval(
|
||||||
self.ctx,
|
self._ctx,
|
||||||
token,
|
token,
|
||||||
state_in_ptr,
|
state_in_ptr,
|
||||||
state_out.storage().data_ptr(),
|
state_out.storage().data_ptr(),
|
||||||
|
@ -110,8 +110,13 @@ class RWKVModel:
|
||||||
The object must not be used anymore after calling this method.
|
The object must not be used anymore after calling this method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.valid, 'Already freed'
|
assert self._valid, 'Already freed'
|
||||||
|
|
||||||
self.valid = False
|
self._valid = False
|
||||||
|
|
||||||
self.library.rwkv_free(self.ctx)
|
self._library.rwkv_free(self._ctx)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
# Free the context on GC in case user forgot to call free() explicitly.
|
||||||
|
if self._valid:
|
||||||
|
self.free()
|
||||||
|
|
Loading…
Reference in New Issue