Silence PyTorch warnings by using untyped storage (#72)
This commit is contained in:
parent
7cbfbc55c8
commit
d26791b5bc
|
@ -79,7 +79,7 @@ class RWKVModel:
|
||||||
if state_in is not None:
|
if state_in is not None:
|
||||||
validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
|
validate_buffer(state_in, 'state_in', self._state_buffer_element_count)
|
||||||
|
|
||||||
state_in_ptr = state_in.storage().data_ptr()
|
state_in_ptr = state_in.untyped_storage().data_ptr()
|
||||||
else:
|
else:
|
||||||
state_in_ptr = 0
|
state_in_ptr = 0
|
||||||
|
|
||||||
|
@ -97,8 +97,8 @@ class RWKVModel:
|
||||||
self._ctx,
|
self._ctx,
|
||||||
token,
|
token,
|
||||||
state_in_ptr,
|
state_in_ptr,
|
||||||
state_out.storage().data_ptr(),
|
state_out.untyped_storage().data_ptr(),
|
||||||
logits_out.storage().data_ptr()
|
logits_out.untyped_storage().data_ptr()
|
||||||
)
|
)
|
||||||
|
|
||||||
return logits_out, state_out
|
return logits_out, state_out
|
||||||
|
|
Loading…
Reference in New Issue