Update app.py
app.py
CHANGED
@@ -9,8 +9,11 @@ import PyPDF2
 from docx import Document
 import re
 
+# -------------------------
+# Positional Encoding
+# -------------------------
 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model, dropout=0.1, max_len=5000):
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
         pe = torch.zeros(max_len, d_model)
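
Note: the hunk elides the lines that fill pe (old lines 17-22). A minimal sketch of the standard sinusoidal fill they presumably contain; the exact original code is not visible in this diff:

    import math
    import torch

    max_len, d_model = 5000, 768
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine
    pe = pe.unsqueeze(0)  # (1, max_len, d_model), matching self.pe[:, :x.size(1)] in forward()
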
@@ -23,8 +26,11 @@ class PositionalEncoding(nn.Module):
     def forward(self, x):
         return x + self.pe[:, :x.size(1)]
 
+# -------------------------
+# Vanilla Transformer
+# -------------------------
 class VanillaTransformer(nn.Module):
-    def __init__(self, d_model=768, nhead=8, num_layers=3, dim_feedforward=2048, dropout=0.1):
+    def __init__(self, d_model=768, nhead=8, num_layers=3, dim_feedforward=2048, dropout=0.1):
         super().__init__()
         self.pos_encoder = PositionalEncoding(d_model, dropout)
         encoder_layer = nn.TransformerEncoderLayer(
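
Note: the constructor body is cut off after encoder_layer = nn.TransformerEncoderLayer(. A hedged sketch of the usual completion, assuming batch_first=True (consistent with PositionalEncoding treating x.size(1) as the sequence dimension):

    import torch.nn as nn

    encoder_layer = nn.TransformerEncoderLayer(
        d_model=768, nhead=8, dim_feedforward=2048,
        dropout=0.1, batch_first=True,  # batch_first is an assumption, not shown in the diff
    )
    transformer = nn.TransformerEncoder(encoder_layer, num_layers=3)  # assumed wiring
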
@@ -37,14 +43,17 @@ class VanillaTransformer(nn.Module):
         src = self.pos_encoder(src)
         return self.transformer(src, src_key_padding_mask=src_key_padding_mask)
 
+# -------------------------
+# Hierarchical Model
+# -------------------------
 class HierarchicalLegalSegModel(nn.Module):
-    def __init__(self, longformer_model, num_labels, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1):
+    def __init__(self, longformer_model, num_labels, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1):
         super().__init__()
         self.longformer = longformer_model
         self.hidden_dim = hidden_dim
         self.vanilla_transformer = VanillaTransformer(
             d_model=hidden_dim, nhead=transformer_heads, num_layers=transformer_layers,
-            dim_feedforward=hidden_dim*4, dropout=dropout
+            dim_feedforward=hidden_dim * 4, dropout=dropout
         )
         self.classifier = nn.Linear(hidden_dim, num_labels)
         self.crf = CRF(num_labels, batch_first=True)
@@ -64,7 +73,7 @@ class HierarchicalLegalSegModel(nn.Module):
         sentence_embeddings = self.encode_sentences(input_ids, attention_mask)
         sentence_embeddings = self.dropout(sentence_embeddings)
         transformer_output = self.vanilla_transformer(
-            sentence_embeddings,
+            sentence_embeddings,
             src_key_padding_mask=~sentence_mask if sentence_mask is not None else None
         )
         emissions = self.classifier(transformer_output)
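
Note: nn.TransformerEncoder expects src_key_padding_mask to be True at positions it should ignore, while sentence_mask uses True for real sentences; hence the ~ inversion:

    import torch

    sentence_mask = torch.tensor([[True, True, False]])  # True = real sentence
    padding_mask = ~sentence_mask                         # True = padding position to ignore
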
@@ -75,6 +84,9 @@ class HierarchicalLegalSegModel(nn.Module):
         predictions = self.crf.decode(emissions, mask=sentence_mask)
         return predictions
 
+# -------------------------
+# Load Model
+# -------------------------
 print("Loading model...")
 device = torch.device("cpu")
 
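
Note: CRF.decode from pytorch-crf runs Viterbi decoding and returns plain Python lists (one best label sequence per batch element) rather than a tensor, which is why predict() below can feed predictions[0] straight into id2label. A minimal sketch:

    import torch
    from torchcrf import CRF

    crf = CRF(7, batch_first=True)
    emissions = torch.randn(1, 4, 7)          # (batch, sentences, num_labels)
    mask = torch.ones(1, 4, dtype=torch.bool)
    best = crf.decode(emissions, mask=mask)   # e.g. [[3, 0, 6, 2]], plain Python ints
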
@@ -84,8 +96,10 @@ longformer = AutoModel.from_pretrained("lexlms/legal-longformer-base").to(device
 for param in longformer.parameters():
     param.requires_grad = False
 
-model = HierarchicalLegalSegModel(
-
+model = HierarchicalLegalSegModel(
+    longformer, num_labels=7, hidden_dim=768,
+    transformer_layers=3, transformer_heads=8, dropout=0.1
+).to(device)
 
 model_path = hf_hub_download(
     repo_id="Prateek0515/legal-document-segmentation",
@@ -93,14 +107,26 @@ model_path = hf_hub_download(
 )
 
 checkpoint = torch.load(model_path, map_location=device)
-if isinstance(checkpoint, dict)
-
+if isinstance(checkpoint, dict):
+    if 'model_state_dict' in checkpoint:
+        model.load_state_dict(checkpoint['model_state_dict'])
+    else:
+        model.load_state_dict(checkpoint)
 else:
     model.load_state_dict(checkpoint)
 
 model.eval()
 print("Model loaded successfully!")
 
+# 🔍 Debug model info
+print("\n>>> MODEL CHECK <<<")
+checkpoint = torch.load(model_path, map_location=device)
+print("Checkpoint keys:", checkpoint.keys())
+print("Tokenizer used:", tokenizer.name_or_path)
+
+# -------------------------
+# Label mapping
+# -------------------------
 id2label = {
     0: "Arguments of Petitioner",
     1: "Arguments of Respondent",
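
Note: the new branching accepts both a raw state_dict and a training checkpoint that nests one under 'model_state_dict'. A hypothetical save-side counterpart (the training script is not part of this diff):

    # Either save format is handled by the loader above:
    torch.save(model.state_dict(), "model.pt")                        # raw state_dict
    torch.save({"model_state_dict": model.state_dict()}, "model.pt")  # nested checkpoint

The debug block then reloads the checkpoint a second time just to print its keys; reusing the checkpoint already in memory would avoid the extra torch.load.
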
@@ -111,8 +137,10 @@ id2label = {
     6: "Reasoning"
 }
 
+# -------------------------
+# Helpers
+# -------------------------
 def split_sentences(text):
-    """Split text into sentences"""
     sentences = re.split(r'(?<=[.!?])\s+', text)
     return [s.strip() for s in sentences if s.strip()]
 
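
Note: the splitter breaks on whitespace that follows '.', '!' or '?':

    import re

    text = "The appeal is allowed. Costs awarded! Any objections?"
    print(re.split(r'(?<=[.!?])\s+', text))
    # ['The appeal is allowed.', 'Costs awarded!', 'Any objections?']

Abbreviations such as "No. 42" also trigger a split; acceptable for a demo, but worth knowing for legal text.
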
@@ -134,6 +162,9 @@ def extract_text_from_docx(file_path):
     except Exception as e:
         return f"Error reading DOCX: {str(e)}"
 
+# -------------------------
+# Prediction Function
+# -------------------------
 def predict(text_input, file_input):
     try:
         text = None
@@ -160,11 +191,10 @@ def predict(text_input, file_input):
             return "⚠️ No text content found"
 
         sentences = split_sentences(text)
-
         if not sentences:
             return "⚠️ Could not split text into sentences"
 
-        #
+        # Encode sentences
         encoded_sentences = []
         for sentence in sentences:
             encoded = tokenizer(
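
Note: the per-sentence tokenizer call is elided (old lines 170-179). A hypothetical shape for it; the real options and max_length are not visible in this diff:

    encoded = tokenizer(
        sentence,
        padding="max_length",  # assumed: the torch.stack below needs equal lengths
        truncation=True,
        max_length=128,        # assumed value, not shown in the diff
        return_tensors="pt",
    )

Some fixed-length padding of this kind must be present, since the torch.stack over encoded_sentences requires every sentence tensor to have the same shape.
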
@@ -180,11 +210,24 @@ def predict(text_input, file_input):
         attention_mask = torch.stack([e["attention_mask"].squeeze(0) for e in encoded_sentences]).unsqueeze(0).to(device)
         sentence_mask = torch.ones(1, len(sentences), dtype=torch.bool).to(device)
 
+        # 🧩 Debug info
+        print(">>> DEBUG INFO <<<")
+        print("input_ids:", input_ids.shape)
+        print("attention_mask:", attention_mask.shape)
+        print("sentence_mask:", sentence_mask.shape)
+
         with torch.no_grad():
-
-
+            sentence_embeddings = model.encode_sentences(input_ids, attention_mask)
+            print("sentence_embeddings:", sentence_embeddings.shape)
+            transformer_output = model.vanilla_transformer(sentence_embeddings)
+            print("transformer_output mean:", transformer_output.mean().item())
+            emissions = model.classifier(transformer_output)
+            print("emissions shape:", emissions.shape, " | mean:", emissions.mean().item())
+
+            predictions = model.crf.decode(emissions, mask=sentence_mask)
+            print("Predictions (raw):", predictions)
+
         predicted_labels = predictions[0]
-
         results = []
         for sentence, label_id in zip(sentences, predicted_labels):
             label = id2label.get(label_id, "Unknown")
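
Note: the unrolled debug path mirrors the model's own forward pass but skips model.dropout and passes no padding mask to vanilla_transformer. Both are harmless here: sentence_mask is all True for a single unpadded document, and dropout is the identity once model.eval() has been called:

    import torch

    drop = torch.nn.Dropout(p=0.1)
    drop.eval()
    x = torch.ones(3)
    assert torch.equal(drop(x), x)  # dropout is a no-op in eval mode
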
@@ -195,6 +238,9 @@ def predict(text_input, file_input):
     except Exception as e:
         return f"❌ Error during prediction: {str(e)}"
 
+# -------------------------
+# Gradio Interface
+# -------------------------
 demo = gr.Interface(
     fn=predict,
     inputs=[
@@ -202,7 +248,7 @@ demo = gr.Interface(
         gr.File(label="Or Upload File (PDF, DOCX, TXT)")
     ],
     outputs=gr.Textbox(label="Per-Sentence Predictions", lines=10),
-    title="⚖️ Legal Document Segmentation",
+    title="⚖️ Legal Document Segmentation (Debug Mode)",
     description="Classify legal documents sentence-by-sentence into: Arguments (Petitioner/Respondent), Decision, Facts, Issue, or Reasoning",
     examples=[
         ["The appellant filed a petition against the respondent. The court decides that the appellant is liable.", None],
@@ -210,5 +256,5 @@ demo = gr.Interface(
     api_name="predict"
 )
 
-if __name__ == "__main__":
+if __name__ == "__main__":
     demo.launch()
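
Note: because the interface registers api_name="predict", the Space can also be called programmatically. A hedged sketch using gradio_client; the actual Space id is assumed, not shown in this diff:

    from gradio_client import Client

    client = Client("Prateek0515/legal-document-segmentation")  # assumed Space id
    result = client.predict(
        "The appellant filed a petition against the respondent.",
        None,                # no file upload
        api_name="/predict",
    )
    print(result)
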