Silence PyTorch warnings by using untyped storage (#72)
This commit is contained in:
parent
7cbfbc55c8
commit
d26791b5bc
|
@ -79,7 +79,7 @@ class RWKVModel:
|
|||
if state_in is not None:
|
||||
validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
|
||||
|
||||
state_in_ptr = state_in.storage().data_ptr()
|
||||
state_in_ptr = state_in.untyped_storage().data_ptr()
|
||||
else:
|
||||
state_in_ptr = 0
|
||||
|
||||
|
@ -97,8 +97,8 @@ class RWKVModel:
|
|||
self._ctx,
|
||||
token,
|
||||
state_in_ptr,
|
||||
state_out.storage().data_ptr(),
|
||||
logits_out.storage().data_ptr()
|
||||
state_out.untyped_storage().data_ptr(),
|
||||
logits_out.untyped_storage().data_ptr()
|
||||
)
|
||||
|
||||
return logits_out, state_out
|
||||
|
|
Loading…
Reference in New Issue