Free ggml context when model is garbage collected

This commit is contained in:
saharNooby 2023-04-05 15:55:47 +04:00
parent ad3a4ebc57
commit fa9ad13a39
1 changed files with 21 additions and 16 deletions

View File

@ -32,14 +32,14 @@ class RWKVModel:
assert os.path.isfile(model_path), f'{model_path} is not a file' assert os.path.isfile(model_path), f'{model_path} is not a file'
assert thread_count > 0, 'Thread count must be positive' assert thread_count > 0, 'Thread count must be positive'
self.library = shared_library self._library = shared_library
self.ctx = self.library.rwkv_init_from_file(model_path, thread_count) self._ctx = self._library.rwkv_init_from_file(model_path, thread_count)
self.state_buffer_element_count = self.library.rwkv_get_state_buffer_element_count(self.ctx) self._state_buffer_element_count = self._library.rwkv_get_state_buffer_element_count(self._ctx)
self.logits_buffer_element_count = self.library.rwkv_get_logits_buffer_element_count(self.ctx) self._logits_buffer_element_count = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
self.valid = True self._valid = True
def eval( def eval(
self, self,
@ -69,7 +69,7 @@ class RWKVModel:
Logits vector of shape (n_vocab); state for the next step. Logits vector of shape (n_vocab); state for the next step.
""" """
assert self.valid, 'Model was freed' assert self._valid, 'Model was freed'
def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None: def validate_buffer(buf: torch.Tensor, name: str, size: int) -> None:
assert buf.dtype == torch.float32, f'{name} is not of type float32' assert buf.dtype == torch.float32, f'{name} is not of type float32'
@ -77,24 +77,24 @@ class RWKVModel:
assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})' assert buf.shape == (size,), f'{name} has invalid shape {buf.shape}, expected ({size})'
if state_in is not None: if state_in is not None:
validate_buffer(state_in, 'state_in', self.state_buffer_element_count) validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
state_in_ptr = state_in.storage().data_ptr() state_in_ptr = state_in.storage().data_ptr()
else: else:
state_in_ptr = 0 state_in_ptr = 0
if state_out is not None: if state_out is not None:
validate_buffer(state_out, 'state_out', self.state_buffer_element_count) validate_buffer(state_out, 'state_out', self._state_buffer_element_count)
else: else:
state_out = torch.zeros(self.state_buffer_element_count, dtype=torch.float32, device='cpu') state_out = torch.zeros(self._state_buffer_element_count, dtype=torch.float32, device='cpu')
if logits_out is not None: if logits_out is not None:
validate_buffer(logits_out, 'logits_out', self.logits_buffer_element_count) validate_buffer(logits_out, 'logits_out', self._logits_buffer_element_count)
else: else:
logits_out = torch.zeros(self.logits_buffer_element_count, dtype=torch.float32, device='cpu') logits_out = torch.zeros(self._logits_buffer_element_count, dtype=torch.float32, device='cpu')
self.library.rwkv_eval( self._library.rwkv_eval(
self.ctx, self._ctx,
token, token,
state_in_ptr, state_in_ptr,
state_out.storage().data_ptr(), state_out.storage().data_ptr(),
@ -110,8 +110,13 @@ class RWKVModel:
The object must not be used anymore after calling this method. The object must not be used anymore after calling this method.
""" """
assert self.valid, 'Already freed' assert self._valid, 'Already freed'
self.valid = False self._valid = False
self.library.rwkv_free(self.ctx) self._library.rwkv_free(self._ctx)
def __del__(self):
# Free the context on GC in case user forgot to call free() explicitly.
if self._valid:
self.free()