# optimization.py
# AducSdr: An open and functional implementation of the ADUC-SDR architecture
# Copyright (C) August 4, 2025  Carlos Rodrigues dos Santos
#
# Contact:
# Carlos Rodrigues dos Santos
# [email protected]
# Rua Eduardo Carlos Pereira, 4125, B1 Ap32, Curitiba, PR, Brazil, CEP 8102025
#
# Related repositories and projects:
# GitHub: https://github.com/carlex22/Aduc-sdr
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# PENDING PATENT NOTICE: Please see NOTICE.md.
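
"""FP8 quantization helpers for the LTX pipeline.

Exposes `can_optimize_fp8`, which checks whether the current GPU (Hopper or
Ada Lovelace) has FP8 support, and `optimize_ltx_worker`, which applies
torchao's dynamic FP8 activation/weight quantization to an `LtxWorker`'s
transformer.
"""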

import torch
import logging
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

# Use string type hints (under TYPE_CHECKING) to avoid circular imports
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from ltx_manager_helpers import LtxWorker

logger = logging.getLogger(__name__)

def can_optimize_fp8():
    """Checks whether the current GPU supports FP8 optimizations."""
    if not torch.cuda.is_available():
        return False

    major, minor = torch.cuda.get_device_capability()

    if major >= 9:  # Hopper architecture or newer
        logger.info(f"GPU with Hopper architecture or newer (CC {major}.{minor}) detected. Enabling FP8 quantization.")
        return True

    if major == 8 and minor >= 9:  # Ada Lovelace architecture (e.g. L40, RTX 40 series)
        device_name = torch.cuda.get_device_name(0)
        logger.info(f"GPU with Ada Lovelace architecture (CC {major}.{minor}, name: {device_name}) detected. Enabling FP8 quantization.")
        return True

    logger.warning(f"The current GPU (CC {major}.{minor}) has no optimized FP8 support. Skipping quantization.")
    return False
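
# Reference points for the compute-capability checks above (per NVIDIA's published
# specs): A100 reports CC 8.0, RTX 30-series 8.6, L40 / RTX 40-series 8.9 (Ada
# Lovelace), and H100 9.0 (Hopper); FP8 tensor-core support begins with Ada and Hopper.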

@torch.no_grad()
def optimize_ltx_worker(worker: "LtxWorker"):
    """Applies FP8 quantization to the LTX pipeline's transformer (in place)."""
    pipeline = worker.pipeline
    device = worker.device

    logger.info(f"Starting FP8 quantization of the LTX transformer on device {device}...")
    # quantize_ mutates the module in place, swapping its linear weights for
    # FP8 dynamic-activation / FP8-weight variants.
    quantize_(pipeline.transformer, float8_dynamic_activation_float8_weight())

    torch.cuda.empty_cache()
    logger.info(f"FP8 quantization of the LTX worker on device {device} completed successfully!")