submaster/whisper_project/infra/transcribe_adapter.py


"""Transcribe service adapter.
Provides a small class that wraps transcription and SRT helper functions
so callers can depend on an object instead of free functions.
"""
from typing import Optional
import logging
"""Transcribe service with inlined implementation.
This class contains the transcription and SRT utilities previously in
`transcribe_impl.py`. Keeping it here as a single adapter simplifies DI
and makes it easier to unit-test.
"""
from pathlib import Path
logger = logging.getLogger(__name__)
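
# Typical use, a sketch assuming faster-whisper is installed (the input and
# output paths below are illustrative, not part of the adapter's contract):
#
#     svc = TranscribeService(model="base", compute_type="int8")
#     segments = svc.transcribe_faster("audio.wav")
#     svc.write_srt(svc.dedupe_adjacent_segments(segments), "audio.srt")
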
class TranscribeService:
    def __init__(self, model: str = "base", compute_type: str = "int8") -> None:
        self.model = model
        self.compute_type = compute_type

    def transcribe_openai(self, file: str) -> Optional[list]:
        import whisper

        logger.info("Loading openai-whisper model=%s on CPU...", self.model)
        m = whisper.load_model(self.model, device="cpu")
        logger.info("Transcribing...")
        result = m.transcribe(file, fp16=False)
        segments = result.get("segments")
        if segments:
            for seg in segments:
                logger.debug(seg.get("text", ""))
            return segments
        logger.debug(result.get("text", ""))
        return None

    def transcribe_transformers(self, file: str) -> None:
        import torch
        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

        device = "cpu"
        torch_dtype = torch.float32
        logger.info("Loading transformers model=%s on CPU...", self.model)
        model_obj = AutoModelForSpeechSeq2Seq.from_pretrained(
            self.model, torch_dtype=torch_dtype, low_cpu_mem_usage=True
        )
        model_obj.to(device)
        processor = AutoProcessor.from_pretrained(self.model)
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model_obj,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=-1,  # -1 forces CPU execution
        )
        logger.info("Transcribing...")
        result = pipe(file)
        if isinstance(result, dict):
            logger.debug(result.get("text", ""))
        else:
            logger.debug(result)
        # The pipeline returns plain text without segment timestamps, so there
        # is nothing to hand to write_srt; return None.
        return None

    def transcribe_faster(self, file: str):
        from faster_whisper import WhisperModel

        logger.info(
            "Loading faster-whisper model=%s on CPU compute_type=%s...",
            self.model,
            self.compute_type,
        )
        model_obj = WhisperModel(self.model, device="cpu", compute_type=self.compute_type)
        logger.info("Transcribing...")
        segments_gen, info = model_obj.transcribe(file, beam_size=5)
        segments = list(segments_gen)  # materialize the generator before reuse
        text = "".join(seg.text for seg in segments)
        logger.debug(text)
        return segments

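    # Note: transcribe_faster returns faster-whisper Segment objects, which
    # expose .start/.end/.text and can be passed straight to write_srt below.
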
    def _format_timestamp(self, seconds: float) -> str:
        millis = int((seconds - int(seconds)) * 1000)
        h = int(seconds // 3600)
        m = int((seconds % 3600) // 60)
        s = int(seconds % 60)
        return f"{h:02d}:{m:02d}:{s:02d},{millis:03d}"

    def write_srt(self, segments, out_path: str):
        lines = []
        for i, seg in enumerate(segments, start=1):
            if hasattr(seg, "start"):
                start = float(seg.start)
                end = float(seg.end)
                text = seg.text if hasattr(seg, "text") else str(seg)
            else:
                start = float(seg.get("start", 0.0))
                end = float(seg.get("end", 0.0))
                text = seg.get("text", "")
            start_ts = self._format_timestamp(start)
            end_ts = self._format_timestamp(end)
            lines.append(str(i))
            lines.append(f"{start_ts} --> {end_ts}")
            for line in str(text).strip().splitlines():
                lines.append(line)
            lines.append("")
        Path(out_path).write_text("\n".join(lines), encoding="utf-8")

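    # A single cue in the resulting file looks like:
    #
    #     1
    #     00:00:00,000 --> 00:00:02,500
    #     Hello world
    #
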
    def dedupe_adjacent_segments(self, segments):
        if not segments:
            return segments
        # Normalize to dicts so object-style and dict-style segments mix freely.
        norm = []
        for s in segments:
            if hasattr(s, "start"):
                norm.append({"start": float(s.start), "end": float(s.end), "text": getattr(s, "text", "")})
            else:
                norm.append({"start": float(s.get("start", 0.0)), "end": float(s.get("end", 0.0)), "text": s.get("text", "")})
        out = [norm[0].copy()]
        for seg in norm[1:]:
            prev = out[-1]
            a = (prev.get("text") or "").strip()
            b = (seg.get("text") or "").strip()
            if not a or not b:
                out.append(seg.copy())
                continue
            a_words = a.split()
            b_words = b.split()
            # Find the longest word-level overlap (up to 10 words) where the
            # end of the previous text is repeated at the start of the next.
            max_ol = 0
            max_k = min(len(a_words), len(b_words), 10)
            for k in range(1, max_k + 1):
                if a_words[-k:] == b_words[:k]:
                    max_ol = k
            if max_ol > 0:
                new_seg = seg.copy()
                new_seg["text"] = " ".join(b_words[max_ol:]).strip()
                out.append(new_seg)
            else:
                out.append(seg.copy())
        return out

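    # Worked example: if the previous segment ends "...hello there world" and
    # the next begins "there world again", the two-word overlap is detected
    # and the next segment's text becomes "again".
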
    def get_audio_duration(self, file_path: str):
        try:
            import subprocess

            cmd = [
                "ffprobe",
                "-v",
                "error",
                "-show_entries",
                "format=duration",
                "-of",
                "default=noprint_wrappers=1:nokey=1",
                file_path,
            ]
            out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
            return float(out.strip())
        except Exception:
            return None

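    # Equivalent shell command (input path illustrative):
    #     ffprobe -v error -show_entries format=duration \
    #         -of default=noprint_wrappers=1:nokey=1 input.wav
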
    def make_uniform_segments(self, duration: float, seg_seconds: float):
        segments = []
        if duration <= 0 or seg_seconds <= 0:
            return segments
        start = 0.0
        while start < duration:
            end = min(start + seg_seconds, duration)
            segments.append({"start": round(start, 3), "end": round(end, 3)})
            start = end
        return segments

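    # Worked example: make_uniform_segments(10.5, 4.0) ->
    #     [{"start": 0.0, "end": 4.0}, {"start": 4.0, "end": 8.0},
    #      {"start": 8.0, "end": 10.5}]
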
    def transcribe_segmented_with_tempfiles(self, src_file: str, segments: list, backend: str = "faster-whisper", model: str = "base", compute_type: str = "int8", overlap: float = 0.2):
        import subprocess
        import tempfile

        results = []
        for seg in segments:
            # Pad each cut with a small overlap so words on the boundary are
            # not clipped; dedupe_adjacent_segments removes the repetition.
            start = max(0.0, float(seg["start"]) - overlap)
            end = float(seg["end"]) + overlap
            duration = end - start
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
                tmp_path = tmp.name
                # ffmpeg overwrites the still-open temp file in place (-y);
                # this relies on POSIX file semantics.
                cmd = [
                    "ffmpeg",
                    "-y",
                    "-ss",
                    str(start),
                    "-t",
                    str(duration),
                    "-i",
                    src_file,
                    "-ar",
                    "16000",
                    "-ac",
                    "1",
                    tmp_path,
                ]
                try:
                    subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
                except Exception:
                    results.append({"start": seg["start"], "end": seg["end"], "text": ""})
                    continue
                try:
                    if backend == "openai-whisper":
                        import whisper

                        m = whisper.load_model(model, device="cpu")
                        res = m.transcribe(tmp_path, fp16=False)
                        text = res.get("text", "")
                    elif backend == "transformers":
                        import torch
                        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

                        torch_dtype = torch.float32
                        model_obj = AutoModelForSpeechSeq2Seq.from_pretrained(
                            model, torch_dtype=torch_dtype, low_cpu_mem_usage=True
                        )
                        model_obj.to("cpu")
                        processor = AutoProcessor.from_pretrained(model)
                        pipe = pipeline(
                            "automatic-speech-recognition",
                            model=model_obj,
                            tokenizer=processor.tokenizer,
                            feature_extractor=processor.feature_extractor,
                            device=-1,
                        )
                        out = pipe(tmp_path)
                        text = out["text"] if isinstance(out, dict) else str(out)
                    else:
                        from faster_whisper import WhisperModel

                        wmodel = WhisperModel(model, device="cpu", compute_type=compute_type)
                        segs_gen, info = wmodel.transcribe(tmp_path, beam_size=5)
                        text = "".join(s.text for s in list(segs_gen))
                except Exception:
                    text = ""
            results.append({"start": seg["start"], "end": seg["end"], "text": text})
        return results

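    # Each cut above corresponds to a shell command of the form
    # (paths illustrative):
    #     ffmpeg -y -ss <start> -t <duration> -i src.wav -ar 16000 -ac 1 tmp.wav
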
    def tts_synthesize(self, text: str, out_path: str, model: str = "kokoro"):
        # Prefer Coqui TTS; fall back to the offline pyttsx3 engine.
        try:
            from TTS.api import TTS

            tts = TTS(model_name=model, progress_bar=False, gpu=False)
            tts.tts_to_file(text=text, file_path=out_path)
            return True
        except Exception:
            try:
                import pyttsx3

                engine = pyttsx3.init()
                engine.save_to_file(text, out_path)
                engine.runAndWait()
                return True
            except Exception:
                return False

    def ensure_tts_model(self, repo_id: str):
        try:
            from huggingface_hub import snapshot_download

            try:
                local_dir = snapshot_download(repo_id, repo_type="model")
            except Exception:
                # Fall back for huggingface_hub versions that reject repo_type.
                local_dir = snapshot_download(repo_id)
            return local_dir
        except Exception:
            # huggingface_hub unavailable: let callers resolve the repo id.
            return repo_id

__all__ = ["TranscribeService"]
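
# Minimal smoke test, runnable as a script. The audio path, output name, and
# segment length below are illustrative assumptions, not part of the adapter
# API; ffmpeg/ffprobe and faster-whisper must be available.
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)
    audio = sys.argv[1] if len(sys.argv) > 1 else "sample.wav"
    svc = TranscribeService(model="base", compute_type="int8")
    duration = svc.get_audio_duration(audio)
    if duration is None:
        logger.error("Could not determine duration of %s (is ffprobe installed?)", audio)
    else:
        plan = svc.make_uniform_segments(duration, seg_seconds=30.0)
        results = svc.transcribe_segmented_with_tempfiles(audio, plan)
        svc.write_srt(svc.dedupe_adjacent_segments(results), "sample.srt")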