284 lines
9.9 KiB
Python
284 lines
9.9 KiB
Python
"""Transcribe service adapter.

Provides a small class that wraps transcription and SRT helper functions
so callers can depend on an object instead of free functions.
"""
|
|
from typing import Optional
|
|
import logging
|
|
|
|
"""Transcribe service with inlined implementation.

This class contains the transcription and SRT utilities previously in
`transcribe_impl.py`. Keeping it here as a single adapter simplifies
dependency injection (DI) and makes it easier to unit-test.
"""
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
# Module-level logger, named after this module, used by TranscribeService.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TranscribeService:
    """Adapter bundling CPU transcription backends with SRT/segment helpers.

    Three backends are wrapped (openai-whisper, Hugging Face transformers and
    faster-whisper); each backend package is imported lazily inside the method
    that needs it, so only the backend actually used has to be installed.

    Attributes:
        model: Model name/identifier passed to the selected backend.
        compute_type: ``compute_type`` passed to faster-whisper.
    """

    def __init__(self, model: str = "base", compute_type: str = "int8") -> None:
        """Store the model identifier and faster-whisper compute type."""
        self.model = model
        self.compute_type = compute_type

    # ------------------------------------------------------------------
    # Whole-file transcription
    # ------------------------------------------------------------------

    def transcribe_openai(self, file: str):
        """Transcribe ``file`` with openai-whisper on CPU.

        Returns:
            The ``"segments"`` list from the whisper result, or ``None`` when
            the result contains no segments (the plain text is then only
            logged at debug level).
        """
        import whisper

        logger.info("Cargando openai-whisper modelo=%s en CPU...", self.model)
        whisper_model = whisper.load_model(self.model, device="cpu")
        logger.info("Transcribiendo...")
        result = whisper_model.transcribe(file, fp16=False)

        segments = result.get("segments")
        if not segments:
            logger.debug(result.get("text", ""))
            return None

        for seg in segments:
            logger.debug(seg.get("text", ""))
        return segments

    @staticmethod
    def _build_transformers_pipeline(model_name: str):
        """Build a CPU, float32 speech-recognition pipeline for ``model_name``."""
        import torch
        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

        model_obj = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name, torch_dtype=torch.float32, low_cpu_mem_usage=True
        )
        model_obj.to("cpu")
        processor = AutoProcessor.from_pretrained(model_name)
        return pipeline(
            "automatic-speech-recognition",
            model=model_obj,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=-1,
        )

    def transcribe_transformers(self, file: str):
        """Transcribe ``file`` with a Hugging Face transformers pipeline on CPU.

        The recognised text is only written to the debug log.

        Returns:
            Always ``None``; unlike the other backends no segment list is
            produced. NOTE(review): kept as-is for interface compatibility.
        """
        logger.info("Cargando transformers modelo=%s en CPU...", self.model)
        pipe = self._build_transformers_pipeline(self.model)

        logger.info("Transcribiendo...")
        result = pipe(file)
        if isinstance(result, dict):
            logger.debug(result.get("text", ""))
        else:
            logger.debug(result)
        return None

    def transcribe_faster(self, file: str):
        """Transcribe ``file`` with faster-whisper on CPU (beam size 5).

        Returns:
            The list of segment objects yielded by the backend.
        """
        from faster_whisper import WhisperModel

        logger.info(
            "Cargando faster-whisper modelo=%s en CPU compute_type=%s...",
            self.model,
            self.compute_type,
        )
        model_obj = WhisperModel(self.model, device="cpu", compute_type=self.compute_type)
        logger.info("Transcribiendo...")
        segments_gen, _info = model_obj.transcribe(file, beam_size=5)
        # The backend returns a generator; materialise it so the segments can
        # be both logged here and returned to the caller.
        segments = list(segments_gen)
        logger.debug("".join(seg.text for seg in segments))
        return segments

    # ------------------------------------------------------------------
    # SRT helpers
    # ------------------------------------------------------------------

    def _format_timestamp(self, seconds: float) -> str:
        """Format ``seconds`` as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to the nearest millisecond *before* it is split
        into fields, so e.g. ``1.001`` yields ``,001`` (not ``,000`` through
        float truncation) and ``59.9996`` carries over to ``00:01:00,000``.
        Negative inputs are clamped to zero.
        """
        total_ms = int(round(max(0.0, float(seconds)) * 1000))
        total_seconds, millis = divmod(total_ms, 1000)
        total_minutes, secs = divmod(total_seconds, 60)
        hours, minutes = divmod(total_minutes, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def write_srt(self, segments, out_path: str):
        """Write ``segments`` to ``out_path`` as a UTF-8 SRT file.

        Args:
            segments: Iterable of segments, either objects exposing
                ``start``/``end``/``text`` attributes or dicts with those
                keys. ``None`` is treated as an empty list.
            out_path: Destination file path (overwritten if it exists).
        """
        lines = []
        for index, seg in enumerate(segments or [], start=1):
            if hasattr(seg, "start"):
                start = float(seg.start)
                end = float(seg.end)
                text = seg.text if hasattr(seg, "text") else str(seg)
            else:
                start = float(seg.get("start", 0.0))
                end = float(seg.get("end", 0.0))
                text = seg.get("text", "")

            lines.append(str(index))
            lines.append(
                f"{self._format_timestamp(start)} --> {self._format_timestamp(end)}"
            )
            # A blank line terminates an SRT cue, so blank lines inside the
            # text must not be emitted or the rest of the cue would be lost.
            lines.extend(
                line for line in str(text).strip().splitlines() if line.strip()
            )
            lines.append("")  # cue separator

        Path(out_path).write_text("\n".join(lines), encoding="utf-8")

    def dedupe_adjacent_segments(self, segments):
        """Strip words repeated across the boundary of consecutive segments.

        Every segment is normalised to a ``{"start", "end", "text"}`` dict.
        For each segment after the first, the longest run of leading words
        (up to 10) equal to the trailing words of the previous output segment
        is removed from its text.

        Returns:
            A new list of dicts, or ``segments`` itself when it is empty/None.
        """
        if not segments:
            return segments

        normalised = []
        for seg in segments:
            if hasattr(seg, "start"):
                normalised.append(
                    {
                        "start": float(seg.start),
                        "end": float(seg.end),
                        "text": getattr(seg, "text", ""),
                    }
                )
            else:
                normalised.append(
                    {
                        "start": float(seg.get("start", 0.0)),
                        "end": float(seg.get("end", 0.0)),
                        "text": seg.get("text", ""),
                    }
                )

        out = [normalised[0].copy()]
        for seg in normalised[1:]:
            current = seg.copy()
            prev_words = (out[-1].get("text") or "").split()
            cur_words = (current.get("text") or "").split()

            if prev_words and cur_words:
                limit = min(len(prev_words), len(cur_words), 10)
                # Search from the longest candidate down: first hit is the max.
                overlap = next(
                    (
                        k
                        for k in range(limit, 0, -1)
                        if prev_words[-k:] == cur_words[:k]
                    ),
                    0,
                )
                if overlap:
                    current["text"] = " ".join(cur_words[overlap:])

            out.append(current)

        return out

    # ------------------------------------------------------------------
    # Segmentation helpers
    # ------------------------------------------------------------------

    def get_audio_duration(self, file_path: str) -> Optional[float]:
        """Return the duration reported by ``ffprobe`` for ``file_path``.

        Returns:
            The duration as a float, or ``None`` when ffprobe cannot be run,
            fails, or prints something that is not a number.
        """
        import subprocess

        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            file_path,
        ]
        try:
            out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
            return float(out.strip())
        except Exception:
            # Best effort by design: callers get None, the cause goes to the log.
            logger.debug("ffprobe could not read duration of %s", file_path, exc_info=True)
            return None

    def make_uniform_segments(self, duration: float, seg_seconds: float):
        """Split ``[0, duration]`` into consecutive windows of ``seg_seconds``.

        The last window is shortened to end at ``duration``. Bounds are
        rounded to 3 decimals.

        Returns:
            A list of ``{"start", "end"}`` dicts; empty when either argument
            is not positive.
        """
        segments = []
        if duration <= 0 or seg_seconds <= 0:
            return segments

        cursor = 0.0
        while cursor < duration:
            window_end = min(cursor + seg_seconds, duration)
            segments.append({"start": round(cursor, 3), "end": round(window_end, 3)})
            cursor = window_end
        return segments

    def _build_segment_transcriber(self, backend: str, model: str, compute_type: str):
        """Load ``backend`` once and return a ``path -> text`` callable."""
        if backend == "openai-whisper":
            import whisper

            whisper_model = whisper.load_model(model, device="cpu")

            def run_whisper(path: str) -> str:
                return whisper_model.transcribe(path, fp16=False).get("text", "")

            return run_whisper

        if backend == "transformers":
            pipe = self._build_transformers_pipeline(model)

            def run_transformers(path: str) -> str:
                out = pipe(path)
                return out["text"] if isinstance(out, dict) else str(out)

            return run_transformers

        # Any other backend name falls back to faster-whisper.
        from faster_whisper import WhisperModel

        fw_model = WhisperModel(model, device="cpu", compute_type=compute_type)

        def run_faster(path: str) -> str:
            segs, _info = fw_model.transcribe(path, beam_size=5)
            return "".join(s.text for s in segs)

        return run_faster

    @staticmethod
    def _extract_clip(src_file: str, start: float, duration: float, dest: str) -> bool:
        """Cut a 16 kHz mono clip from ``src_file`` into ``dest`` via ffmpeg."""
        import subprocess

        cmd = [
            "ffmpeg",
            "-y",
            "-ss",
            str(start),
            "-t",
            str(duration),
            "-i",
            src_file,
            "-ar",
            "16000",
            "-ac",
            "1",
            dest,
        ]
        try:
            subprocess.run(
                cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )
        except Exception:
            logger.warning(
                "ffmpeg failed to extract %ss from %s at %ss",
                duration,
                src_file,
                start,
                exc_info=True,
            )
            return False
        return True

    def transcribe_segmented_with_tempfiles(
        self,
        src_file: str,
        segments: list,
        backend: str = "faster-whisper",
        model: str = "base",
        compute_type: str = "int8",
        overlap: float = 0.2,
    ):
        """Transcribe ``src_file`` window by window through temporary WAV clips.

        Each window is widened by ``overlap`` seconds on both sides, extracted
        with ffmpeg and transcribed with ``backend``. The model is loaded once
        (on the first successfully extracted clip) and reused for every
        window. Failures are logged and yield an empty text for the affected
        window instead of raising.

        Args:
            src_file: Path of the source media file.
            segments: Sequence of dicts with ``"start"`` and ``"end"`` keys.
            backend: ``"openai-whisper"``, ``"transformers"``; anything else
                selects faster-whisper.
            model: Model identifier for the backend. Note this is the
                argument, not ``self.model``.
            compute_type: faster-whisper compute type.
            overlap: Padding in seconds added around each window.

        Returns:
            A list of ``{"start", "end", "text"}`` dicts, one per input
            segment, carrying the segment's original bounds.
        """
        import tempfile

        if not segments:
            return []

        results = []
        transcriber = None
        load_failed = False

        # A directory (rather than an open NamedTemporaryFile) lets ffmpeg
        # create/overwrite the clip itself, which also works on platforms that
        # forbid reopening a file that is still held open.
        with tempfile.TemporaryDirectory() as tmp_dir:
            clip_path = str(Path(tmp_dir) / "segment.wav")

            for seg in segments:
                clip_start = max(0.0, float(seg["start"]) - overlap)
                clip_end = float(seg["end"]) + overlap
                text = ""

                if not load_failed and self._extract_clip(
                    src_file, clip_start, clip_end - clip_start, clip_path
                ):
                    if transcriber is None:
                        try:
                            transcriber = self._build_segment_transcriber(
                                backend, model, compute_type
                            )
                        except Exception:
                            logger.exception(
                                "Could not load backend %s (model=%s)", backend, model
                            )
                            # No point retrying the load for every window.
                            load_failed = True

                    if transcriber is not None:
                        try:
                            text = transcriber(clip_path)
                        except Exception:
                            logger.exception(
                                "Transcription failed for window %s-%s",
                                seg["start"],
                                seg["end"],
                            )
                            text = ""

                results.append({"start": seg["start"], "end": seg["end"], "text": text})

        return results

    # ------------------------------------------------------------------
    # Text-to-speech helpers
    # ------------------------------------------------------------------

    def tts_synthesize(self, text: str, out_path: str, model: str = "kokoro"):
        """Synthesize ``text`` into ``out_path``.

        Tries the ``TTS`` package first and falls back to ``pyttsx3`` when
        that fails for any reason (including the package being absent).

        Returns:
            ``True`` when one of the engines completed, ``False`` otherwise.
        """
        try:
            from TTS.api import TTS

            tts = TTS(model_name=model, progress_bar=False, gpu=False)
            tts.tts_to_file(text=text, file_path=out_path)
            return True
        except Exception:
            logger.debug("TTS backend failed; falling back to pyttsx3", exc_info=True)

        try:
            import pyttsx3

            engine = pyttsx3.init()
            engine.save_to_file(text, out_path)
            engine.runAndWait()
            return True
        except Exception:
            logger.warning("No TTS engine could synthesize the text", exc_info=True)
            return False

    def ensure_tts_model(self, repo_id: str):
        """Download ``repo_id`` from the Hugging Face Hub if possible.

        Returns:
            The local snapshot directory, or ``repo_id`` unchanged when
            ``huggingface_hub`` is unavailable or the download fails.
        """
        try:
            from huggingface_hub import snapshot_download

            try:
                return snapshot_download(repo_id, repo_type="model")
            except Exception:
                # Retry without the explicit repo type before giving up.
                return snapshot_download(repo_id)
        except Exception:
            logger.debug("Could not download %s; using it as-is", repo_id, exc_info=True)
            return repo_id
|
|
|
|
|
|
# Names exported by `from <module> import *`.
__all__ = ["TranscribeService"]
|