File size: 5,548 Bytes
7cc7f9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModel
from torchcrf import CRF
from huggingface_hub import hf_hub_download
import PyPDF2
from docx import Document

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The encoding table is precomputed for ``max_len`` positions and kept as a
    registered (non-trainable) buffer named ``pe``.

    Args:
        d_model: Embedding width; must equal the last dimension of the input.
            Both even and odd widths are supported.
        dropout: Dropout probability applied after the encoding is added
            (identity in eval mode).
        max_len: Longest sequence length (input dimension 1) supported.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one fewer odd column than frequencies;
        # slice so the assignment shapes agree (no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (1, max_len, d_model) so it broadcasts across the batch axis.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe)`` for ``x`` shaped (batch, seq_len, d_model)."""
        x = x + self.pe[:, :x.size(1)]
        # Previously the configured dropout was built but never applied.
        return self.dropout(x)

class VanillaTransformer(nn.Module):
    """Batch-first Transformer encoder applied over a sequence of vectors.

    Sinusoidal position encodings are added to the input first. The result
    then passes through ``num_layers`` stacked ``nn.TransformerEncoderLayer``
    blocks that use GELU activation.
    """

    def __init__(self, d_model=768, nhead=8, num_layers=3, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        # nn.TransformerEncoder stacks num_layers copies of the template layer.
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=num_layers)

    def forward(self, src, src_key_padding_mask=None):
        """Encode ``src``; positions marked True in ``src_key_padding_mask`` are padding."""
        with_positions = self.pos_encoder(src)
        return self.transformer(with_positions, src_key_padding_mask=src_key_padding_mask)

class HierarchicalLegalSegModel(nn.Module):
    """Sentence encoder -> sentence-level Transformer -> CRF sequence labeller.

    Each sentence is encoded independently by ``longformer_model``. The
    sentence is represented by its first-token hidden state. A
    ``VanillaTransformer`` then contextualises the sentence vectors across the
    document. A linear layer turns them into per-sentence label scores, and a
    CRF decodes the best label sequence.

    Args:
        longformer_model: Encoder whose output exposes ``last_hidden_state``.
        num_labels: Number of output labels.
        hidden_dim: Width of the encoder hidden states; must match the encoder.
        transformer_layers: Depth of the sentence-level Transformer.
        transformer_heads: Attention heads in the sentence-level Transformer.
        dropout: Dropout used on sentence embeddings and inside the Transformer.
    """

    def __init__(self, longformer_model, num_labels, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1):
        super().__init__()
        self.longformer = longformer_model
        self.hidden_dim = hidden_dim
        self.vanilla_transformer = VanillaTransformer(d_model=hidden_dim, nhead=transformer_heads, num_layers=transformer_layers, dim_feedforward=hidden_dim*4, dropout=dropout)
        self.classifier = nn.Linear(hidden_dim, num_labels)
        self.crf = CRF(num_labels, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def encode_sentences(self, input_ids, attention_mask):
        """Return one first-token embedding per sentence.

        ``input_ids`` / ``attention_mask`` are (batch, num_sentences,
        max_seq_len); the result is (batch, num_sentences, hidden_dim).
        """
        batch_size, num_sentences, max_seq_len = input_ids.shape
        # Fold sentences into the batch axis so the encoder sees 2-D input.
        # reshape (unlike view) also works for non-contiguous tensors.
        input_ids_flat = input_ids.reshape(-1, max_seq_len)
        attention_mask_flat = attention_mask.reshape(-1, max_seq_len)
        outputs = self.longformer(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]
        return cls_embeddings.reshape(batch_size, num_sentences, self.hidden_dim)

    def forward(self, input_ids, attention_mask, sentence_mask=None):
        """Decode the most likely label sequence for each document.

        Args:
            input_ids: (batch, num_sentences, max_seq_len) token ids.
            attention_mask: Same shape as ``input_ids``.
            sentence_mask: Optional (batch, num_sentences) mask. It is truthy
                for real sentences and falsy for padding. Bool or integer
                dtypes are accepted. ``None`` means every sentence is real.

        Returns:
            The result of ``self.crf.decode``: one list of label ids per document.
        """
        embeddings = self.encode_sentences(input_ids, attention_mask)
        embeddings = self.dropout(embeddings)
        if sentence_mask is None:
            padding_mask = None
            # Give the CRF an explicit boolean all-valid mask instead of
            # relying on its internal default mask.
            crf_mask = torch.ones(embeddings.shape[:2], dtype=torch.bool, device=embeddings.device)
        else:
            # Coerce to bool so ``~`` is a logical NOT; on an integer mask it
            # would be a bitwise NOT (1 -> -2) and not a valid padding mask.
            crf_mask = sentence_mask.bool()
            padding_mask = ~crf_mask
        output = self.vanilla_transformer(embeddings, src_key_padding_mask=padding_mask)
        emissions = self.classifier(output)
        return self.crf.decode(emissions, mask=crf_mask)

device = torch.device("cpu")

# Maps the label ids emitted by the classifier/CRF to human-readable names.
id2label = {0: "Arguments of Petitioner", 1: "Arguments of Respondent", 2: "Decision", 3: "Facts", 4: "Issue", 5: "None", 6: "Reasoning"}

tokenizer = AutoTokenizer.from_pretrained("lexlms/legal-longformer-base")
longformer = AutoModel.from_pretrained("lexlms/legal-longformer-base").to(device)
# The encoder is only used for inference here; disable gradient tracking.
for param in longformer.parameters():
    param.requires_grad = False

# Derive the label count from id2label so the two cannot drift apart.
model = HierarchicalLegalSegModel(longformer, len(id2label)).to(device)
model_path = hf_hub_download(repo_id="Prateek0515/legal-document-segmentation", filename="model.pth")
# SECURITY: torch.load unpickles its input, and this file is downloaded from a
# remote repository. The checkpoint is passed straight to load_state_dict, so
# it only needs tensors/containers; weights_only=True refuses arbitrary objects.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()

def extract_text_from_pdf(file):
    """Return the extracted text of every page of a PDF.

    Args:
        file: Path or binary stream accepted by ``PyPDF2.PdfReader``.

    Returns:
        Page texts joined with newlines, with surrounding whitespace stripped.
    """
    reader = PyPDF2.PdfReader(file)
    # Join with "\n" so the last word of one page does not fuse with the
    # first word of the next. Defensively treat a None page result as empty.
    # Using join also avoids quadratic string concatenation.
    return "\n".join(page.extract_text() or "" for page in reader.pages).strip()

def extract_text_from_docx(file):
    """Return the text of all paragraphs in a .docx file.

    Paragraphs are separated by newlines, and the result has surrounding
    whitespace stripped.

    Args:
        file: Path or file-like object accepted by ``docx.Document``.
    """
    paragraphs = Document(file).paragraphs
    joined = "\n".join(p.text for p in paragraphs)
    return joined.strip()

def _uploaded_file_path(file_input):
    """Return the on-disk path of a ``gr.File`` value (path string or object with ``.name``)."""
    # Recent Gradio passes uploads as a path string; older releases pass a
    # tempfile-like object whose ``.name`` is the path. Accept both.
    if isinstance(file_input, str):
        return file_input
    return file_input.name


def _read_uploaded_text(path):
    """Extract text from a PDF/DOCX/TXT file; return None for other extensions.

    The extension check is case-insensitive.
    """
    lowered = path.lower()
    if lowered.endswith('.pdf'):
        return extract_text_from_pdf(path)
    if lowered.endswith('.docx'):
        return extract_text_from_docx(path)
    if lowered.endswith('.txt'):
        # Use explicit UTF-8 rather than the platform default encoding.
        # Replace undecodable bytes instead of failing the whole request.
        with open(path, 'r', encoding='utf-8', errors='replace') as f:
            return f.read()
    return None


def predict(text_input, file_input):
    """Predict a single segment label for typed or uploaded legal text.

    Args:
        text_input: Text from the textbox; used only when no file is uploaded.
        file_input: Optional ``gr.File`` value (PDF, DOCX or TXT).

    Returns:
        A result string with the predicted label and a text preview. On
        unsupported or empty input, or on any failure, a warning/error string
        is returned instead.
    """
    try:
        if file_input is not None:
            text = _read_uploaded_text(_uploaded_file_path(file_input))
            if text is None:
                return "❌ Unsupported file type"
        else:
            text = text_input

        # Whitespace-only input counts as empty.
        if not text or not text.strip():
            return "⚠️ Please provide text"

        # The whole input is treated as ONE "sentence". Only its first 512
        # tokens reach the encoder, and exactly one label is decoded.
        encoded = tokenizer(text, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
        # Insert a num_sentences axis of size 1: (batch=1, 1, 512).
        input_ids = encoded["input_ids"].unsqueeze(1).to(device)
        attention_mask = encoded["attention_mask"].unsqueeze(1).to(device)
        sentence_mask = torch.ones(1, 1, dtype=torch.bool, device=device)

        with torch.no_grad():
            predictions = model(input_ids, attention_mask, sentence_mask=sentence_mask)

        label = id2label[predictions[0][0]]
        # Only mark the preview as truncated when it actually is.
        preview = text[:300] + ("..." if len(text) > 300 else "")
        return f"✅ **Label:** {label}\n\n📄 **Text:** {preview}"
    except Exception as e:
        # UI boundary: report any failure as a message rather than a traceback.
        return f"❌ Error: {str(e)}"

# Two inputs: free text, or an uploaded PDF/DOCX/TXT file.
# predict() uses the file when one is provided.
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Textbox(label="Enter Legal Text", lines=5), gr.File(label="Or Upload (PDF/DOCX/TXT)")],
    outputs=gr.Textbox(label="Result", lines=5),
    title="⚖️ Legal Document Segmentation",
    api_name="predict",
)

if __name__ == "__main__":
    demo.launch()