diff --git a/whispercpp.pxd b/whispercpp.pxd index 9381078..1a033db 100644 --- a/whispercpp.pxd +++ b/whispercpp.pxd @@ -1,9 +1,8 @@ #!python # cython: language_level=3 - from libc.stdint cimport int64_t -cdef: +cdef nogil: int WHISPER_SAMPLE_RATE = 16000 int WHISPER_N_FFT = 400 int WHISPER_N_MEL = 80 @@ -13,6 +12,9 @@ cdef: char* TEST_FILE = b'test.wav' char* DEFAULT_MODEL = b'ggml-tiny.bin' char* LANGUAGE = b'fr' + ctypedef struct audio_data: + float* frames; + int n_frames; cdef extern from "whisper.h" nogil: enum whisper_sampling_strategy: @@ -109,8 +111,3 @@ cdef extern from "whisper.h" nogil: const char* whisper_print_system_info() const char* whisper_full_get_segment_text(whisper_context*, int) - -ctypedef struct audio_data: - float* frames; - int n_frames; - diff --git a/whispercpp.pyx b/whispercpp.pyx index 2737063..187c177 100644 --- a/whispercpp.pyx +++ b/whispercpp.pyx @@ -8,12 +8,14 @@ import numpy as np import requests import os + cimport numpy as cnp cdef int SAMPLE_RATE = 16000 -cdef char* TEST_FILE = b'test.wav' +cdef char* TEST_FILE = 'test.wav' cdef char* DEFAULT_MODEL = 'tiny' cdef char* LANGUAGE = b'fr' +cdef int N_THREADS = os.cpu_count() MODELS = { 'model_ggml_tiny.bin': 'https://ggml.ggerganov.com/ggml-model-whisper-tiny.bin', @@ -67,7 +69,7 @@ cdef audio_data load_audio(bytes file, int sr = SAMPLE_RATE): return data -cdef whisper_full_params default_params(): +cdef whisper_full_params default_params() nogil: cdef whisper_full_params params = whisper_full_default_params( whisper_sampling_strategy.WHISPER_SAMPLING_GREEDY ) @@ -75,6 +77,7 @@ cdef whisper_full_params default_params(): params.print_progress = True params.translate = False params.language = LANGUAGE + n_threads = N_THREADS return params @@ -83,25 +86,26 @@ cdef class Whisper: cdef whisper_full_params params def __init__(self, model=DEFAULT_MODEL, pb=None): - model_fullname = f'model_ggml_{model.decode()}.bin'.encode('utf8') + model_fullname = f'model_ggml_{model}.bin'.encode('utf8') download_model(model_fullname) cdef bytes model_b = model_fullname self.ctx = whisper_init(model_b) self.params = default_params() + whisper_print_system_info() def __dealloc__(self): whisper_free(self.ctx) - def transcribe(self): - cdef audio_data data = load_audio(TEST_FILE) + def transcribe(self, filename=TEST_FILE): + cdef audio_data data = load_audio(filename) return whisper_full(self.ctx, self.params, data.frames, data.n_frames) - cpdef str extract_text(self, int res): + cpdef list extract_text(self, int res): if res != 0: raise RuntimeError cdef int n_segments = whisper_full_n_segments(self.ctx) - return b'\n'.join([ - whisper_full_get_segment_text(self.ctx, i) for i in range(n_segments) - ]).decode() + return [ + whisper_full_get_segment_text(self.ctx, i).decode() for i in range(n_segments) + ]