# submaster/whisper_project/infra/marian_adapter.py
# (file-viewer metadata: 118 lines, 4.7 KiB, Python)

from typing import Callable, List, Optional
def _default_translator_factory(model_name: str = "Helsinki-NLP/opus-mt-en-es", batch_size: int = 8):
"""Crea una función translator(texts: List[str]) -> List[str] usando transformers.
La creación se hace perezosamente para evitar obligar la dependencia en import-time.
"""
def make():
try:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
except Exception as e:
raise RuntimeError("transformers no disponible: instale 'transformers' y 'sentencepiece' para traducción local") from e
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def translator(texts: List[str]) -> List[str]:
outs = []
# procesar en batches simples
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
enc = tok(batch, return_tensors="pt", padding=True, truncation=True)
gen = model.generate(**enc, max_length=512)
dec = tok.batch_decode(gen, skip_special_tokens=True)
outs.extend([d.strip() for d in dec])
return outs
return translator
return make()
def _parse_simple_srt(raw_text: str) -> List[dict]:
    """Minimal SRT parser used when the ``srt`` package is unavailable.

    Cues are separated by blank or whitespace-only lines. Within a cue, line 1
    is the index, line 2 the timing line and the remaining lines the content.
    Blocks with fewer than three lines are skipped. Returns dicts with the keys
    ``"index"``, ``"times"`` and ``"content"``.
    """
    cues: List[dict] = []
    block: List[str] = []
    # A trailing empty sentinel flushes the final block.
    for line in raw_text.splitlines() + [""]:
        if line.strip():
            block.append(line)
            continue
        if len(block) >= 3:
            cues.append({
                "index": block[0].strip(),
                "times": block[1].strip(),
                "content": "\n".join(block[2:]),
            })
        block = []
    return cues


def _compose_simple_srt(cues: List[dict]) -> str:
    """Serialize cues produced by ``_parse_simple_srt`` back into SRT text."""
    return "".join(f"{cue['index']}\n{cue['times']}\n{cue['content']}\n\n" for cue in cues)


def translate_srt(in_path: str, out_path: str, *, model_name: str = "Helsinki-NLP/opus-mt-en-es", batch_size: int = 8, translator: Optional[Callable[[List[str]], List[str]]] = None) -> None:
    """Translate an SRT file, keeping cue indices and timestamps.

    Uses the ``srt`` package when it can be imported. Otherwise it falls back to
    ``_parse_simple_srt`` / ``_compose_simple_srt``, which only handle simple
    blank-line-separated cues.

    Args:
        in_path: SRT file to read (UTF-8; a leading byte-order mark is ignored).
        out_path: path the translated SRT is written to (UTF-8).
        model_name, batch_size: forwarded to ``_default_translator_factory``
            when ``translator`` is None.
        translator: optional callable that takes the list of cue texts and
            returns a list of translated texts of the same length.

    Raises:
        RuntimeError: if the translator returns a different number of texts
            than there are cues.
    """
    # utf-8-sig drops a leading BOM, which would otherwise end up in the first
    # cue index (fallback parser) or precede the first cue handed to srt.parse.
    with open(in_path, "r", encoding="utf-8-sig") as f:
        raw = f.read()

    # Only the import is guarded: parse errors must surface instead of silently
    # switching to the less capable fallback parser.
    try:
        import srt as srt_lib  # type: ignore
    except ImportError:
        srt_lib = None

    if srt_lib is not None:
        cues = list(srt_lib.parse(raw))
        texts = [cue.content.strip() for cue in cues]
    else:
        cues = _parse_simple_srt(raw)
        texts = [cue["content"].strip() for cue in cues]

    translated: List[str] = []
    # Skip translation entirely for empty input, so the default translator
    # (and its model) is not built for nothing.
    if cues:
        if translator is None:
            translator = _default_translator_factory(model_name=model_name, batch_size=batch_size)
        translated = list(translator(texts))
        if len(translated) != len(cues):
            raise RuntimeError("El traductor devolvió un número distinto de segmentos traducidos")

    # Write the translations back into whichever cue representation is in use.
    if srt_lib is not None:
        for cue, text in zip(cues, translated):
            cue.content = text.strip()
        composed = srt_lib.compose(cues)
    else:
        for cue, text in zip(cues, translated):
            cue["content"] = text.strip()
        composed = _compose_simple_srt(cues)

    with open(out_path, "w", encoding="utf-8") as f:
        f.write(composed)
class MarianTranslator:
    """Adapter exposing a simple SRT-translation API for use cases.

    Delegates to the module-level ``translate_srt``. A translator can be
    injected per call (e.g. in tests). When none is injected, a default
    transformers-backed translator is built on first use and reused by later
    calls, so the model is not reloaded for every file.
    """

    def __init__(self, model_name: str = "Helsinki-NLP/opus-mt-en-es", batch_size: int = 8):
        self.model_name = model_name
        self.batch_size = batch_size
        # Default translator built lazily, plus the (model_name, batch_size)
        # it was built for, so a change to either attribute triggers a rebuild.
        self._cached_translator: Optional[Callable[[List[str]], List[str]]] = None
        self._cached_key: Optional[tuple] = None

    def _default_translate(self, texts: List[str]) -> List[str]:
        """Translate *texts* with the cached default translator, building it on first use."""
        key = (self.model_name, self.batch_size)
        if self._cached_translator is None or self._cached_key != key:
            self._cached_translator = _default_translator_factory(model_name=self.model_name, batch_size=self.batch_size)
            self._cached_key = key
        return self._cached_translator(texts)

    def translate_srt(self, in_srt: str, out_srt: str, translator: Optional[Callable[[List[str]], List[str]]] = None) -> None:
        """Translate ``in_srt`` into ``out_srt``.

        Args:
            in_srt: path of the SRT file to translate.
            out_srt: path the translated SRT is written to.
            translator: optional callable mapping a list of texts to their
                translations; when None, the cached default translator is used.
        """
        # Pass a bound method rather than None so the model is loaded at most once per
        # adapter (and only if a translation is actually requested).
        effective = translator if translator is not None else self._default_translate
        translate_srt(in_srt, out_srt, model_name=self.model_name, batch_size=self.batch_size, translator=effective)