118 lines
4.7 KiB
Python
118 lines
4.7 KiB
Python
from typing import Callable, List, Optional
|
|
|
|
|
|
def _default_translator_factory(
    model_name: str = "Helsinki-NLP/opus-mt-en-es",
    batch_size: int = 8,
    max_length: int = 512,
) -> Callable[[List[str]], List[str]]:
    """Build a ``translator(texts: List[str]) -> List[str]`` backed by ``transformers``.

    ``transformers`` is imported inside this function, not at module level, so
    the dependency is only required when a default translator is actually built.

    Args:
        model_name: model id passed to ``from_pretrained`` for both the
            tokenizer and the seq2seq model.
        batch_size: number of texts tokenized and passed to ``model.generate``
            per step; must be >= 1.
        max_length: forwarded to ``model.generate`` (defaults to the previously
            hard-coded 512).

    Returns:
        A callable that returns one stripped output string per input text, in
        the same order as the inputs.

    Raises:
        ValueError: if ``batch_size`` < 1 (checked before any model loading).
        RuntimeError: if ``transformers`` cannot be imported.
    """
    # Fail fast, before a potentially slow model download/load. A step of 0
    # makes ``range`` raise, and a negative step silently yields no batches,
    # which would return an empty list for any input.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    try:
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    except ImportError as e:
        raise RuntimeError("transformers no disponible: instale 'transformers' y 'sentencepiece' para traducción local") from e

    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def translator(texts: List[str]) -> List[str]:
        """Translate *texts* in chunks of ``batch_size``; output order matches input."""
        outs: List[str] = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]
            # padding=True lets inputs of different lengths share one tensor.
            # truncation=True shortens over-long inputs instead of erroring
            # (per the transformers tokenizer docs).
            enc = tok(batch, return_tensors="pt", padding=True, truncation=True)
            gen = model.generate(**enc, max_length=max_length)
            outs.extend(d.strip() for d in tok.batch_decode(gen, skip_special_tokens=True))
        return outs

    return translator
|
|
|
|
|
|
def _parse_simple_srt(raw_text: str) -> List[dict]:
    """Parse SRT text into ``{"index", "times", "content"}`` dicts.

    Minimal fallback used when the ``srt`` package is not installed. It does
    not validate timestamps or support every SRT variant. Blocks are separated
    by one or more blank (empty or whitespace-only) lines. Blocks with fewer
    than three lines (index, timing, text) are skipped.
    """
    blocks: List[List[str]] = []
    current: List[str] = []
    for line in raw_text.splitlines():
        if line.strip():
            current.append(line)
        elif current:
            # A blank line closes the block being collected.
            blocks.append(current)
            current = []
    if current:
        blocks.append(current)

    return [
        {"index": lines[0].strip(), "times": lines[1].strip(), "content": "\n".join(lines[2:])}
        for lines in blocks
        if len(lines) >= 3
    ]


def _compose_simple_srt(parsed: List[dict]) -> str:
    """Serialize dicts produced by ``_parse_simple_srt`` back into SRT text."""
    return "".join(f"{item['index']}\n{item['times']}\n{item['content']}\n\n" for item in parsed)


def _clean_subtitle_text(text: str) -> str:
    """Strip *text* and drop blank lines, which would terminate an SRT block early."""
    return "\n".join(line for line in text.strip().splitlines() if line.strip())


def translate_srt(in_path: str, out_path: str, *, model_name: str = "Helsinki-NLP/opus-mt-en-es", batch_size: int = 8, translator: Optional[Callable[[List[str]], List[str]]] = None) -> None:
    """Translate an SRT file, keeping each subtitle's index and timing line.

    Parameters:
        in_path, out_path: input/output file paths (read/written as UTF-8).
        model_name, batch_size: used only when ``translator`` is None.
        translator: optional callable receiving the list of subtitle texts and
            returning the same number of translated texts, in order.

    Uses the ``srt`` package when it can be imported. Otherwise a minimal
    built-in parser is used. Parse errors raised by ``srt`` propagate; they no
    longer silently switch to the fallback parser.
    NOTE(review): with ``srt``, ``srt.compose`` is called with its default
    arguments, which per the srt docs renumber subtitles. Confirm this is
    acceptable for the "keep indices" promise.

    Raises:
        RuntimeError: if the translator returns a different number of texts
            than there are subtitles.
    """
    # ``utf-8-sig`` reads plain UTF-8 identically, but also drops a leading BOM
    # that would otherwise end up glued to the first subtitle index.
    with open(in_path, "r", encoding="utf-8-sig") as f:
        raw = f.read()

    # Import ``srt`` lazily. Only its absence selects the minimal parser.
    try:
        import srt  # type: ignore
    except ImportError:
        srt = None

    if srt is not None:
        subs = list(srt.parse(raw))
        texts = [sub.content.strip() for sub in subs]
    else:
        subs = _parse_simple_srt(raw)
        texts = [item["content"].strip() for item in subs]

    if texts:
        if translator is None:
            translator = _default_translator_factory(model_name=model_name, batch_size=batch_size)
        translated = list(translator(texts))
    else:
        # Nothing to translate: avoid loading a model (or calling the
        # translator) just to process an empty list.
        translated = []

    if len(translated) != len(subs):
        raise RuntimeError("El traductor devolvió un número distinto de segmentos traducidos")

    cleaned = [_clean_subtitle_text(t) for t in translated]
    if srt is not None:
        for sub, text in zip(subs, cleaned):
            sub.content = text
        output = srt.compose(subs)
    else:
        for item, text in zip(subs, cleaned):
            item["content"] = text
        output = _compose_simple_srt(subs)

    with open(out_path, "w", encoding="utf-8") as f:
        f.write(output)
|
|
|
|
|
|
class MarianTranslator:
    """Thin adapter exposing SRT translation to use-case code.

    Stores the model configuration and forwards each call to the module-level
    ``translate_srt`` function. A custom ``translator`` callable can be supplied
    per call (e.g. in tests) in place of the default model.
    """

    def __init__(self, model_name: str = "Helsinki-NLP/opus-mt-en-es", batch_size: int = 8):
        # Both settings are forwarded unchanged to ``translate_srt`` on every call.
        self.model_name, self.batch_size = model_name, batch_size

    def translate_srt(self, in_srt: str, out_srt: str, translator: Optional[Callable[[List[str]], List[str]]] = None) -> None:
        """Translate ``in_srt`` into ``out_srt`` using this adapter's settings."""
        # Inside the method body, ``translate_srt`` resolves to the module-level
        # function, not to this method.
        options = {
            "model_name": self.model_name,
            "batch_size": self.batch_size,
            "translator": translator,
        }
        translate_srt(in_srt, out_srt, **options)
|