Prateek0515 commited on
Commit
68a58fd
Β·
verified Β·
1 Parent(s): 36d9d2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -17
app.py CHANGED
@@ -9,8 +9,11 @@ import PyPDF2
9
  from docx import Document
10
  import re
11
 
 
 
 
12
  class PositionalEncoding(nn.Module):
13
- def __init__(self, d_model, dropout=0.1, max_len=5000): # βœ… DOUBLE UNDERSCORE
14
  super().__init__()
15
  self.dropout = nn.Dropout(p=dropout)
16
  pe = torch.zeros(max_len, d_model)
@@ -23,8 +26,11 @@ class PositionalEncoding(nn.Module):
23
  def forward(self, x):
24
  return x + self.pe[:, :x.size(1)]
25
 
 
 
 
26
  class VanillaTransformer(nn.Module):
27
- def __init__(self, d_model=768, nhead=8, num_layers=3, dim_feedforward=2048, dropout=0.1): # βœ… DOUBLE UNDERSCORE
28
  super().__init__()
29
  self.pos_encoder = PositionalEncoding(d_model, dropout)
30
  encoder_layer = nn.TransformerEncoderLayer(
@@ -37,14 +43,17 @@ class VanillaTransformer(nn.Module):
37
  src = self.pos_encoder(src)
38
  return self.transformer(src, src_key_padding_mask=src_key_padding_mask)
39
 
 
 
 
40
  class HierarchicalLegalSegModel(nn.Module):
41
- def __init__(self, longformer_model, num_labels, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1): # βœ… DOUBLE UNDERSCORE
42
  super().__init__()
43
  self.longformer = longformer_model
44
  self.hidden_dim = hidden_dim
45
  self.vanilla_transformer = VanillaTransformer(
46
  d_model=hidden_dim, nhead=transformer_heads, num_layers=transformer_layers,
47
- dim_feedforward=hidden_dim*4, dropout=dropout
48
  )
49
  self.classifier = nn.Linear(hidden_dim, num_labels)
50
  self.crf = CRF(num_labels, batch_first=True)
@@ -64,7 +73,7 @@ class HierarchicalLegalSegModel(nn.Module):
64
  sentence_embeddings = self.encode_sentences(input_ids, attention_mask)
65
  sentence_embeddings = self.dropout(sentence_embeddings)
66
  transformer_output = self.vanilla_transformer(
67
- sentence_embeddings,
68
  src_key_padding_mask=~sentence_mask if sentence_mask is not None else None
69
  )
70
  emissions = self.classifier(transformer_output)
@@ -75,6 +84,9 @@ class HierarchicalLegalSegModel(nn.Module):
75
  predictions = self.crf.decode(emissions, mask=sentence_mask)
76
  return predictions
77
 
 
 
 
78
  print("Loading model...")
79
  device = torch.device("cpu")
80
 
@@ -84,8 +96,10 @@ longformer = AutoModel.from_pretrained("lexlms/legal-longformer-base").to(device
84
  for param in longformer.parameters():
85
  param.requires_grad = False
86
 
87
- model = HierarchicalLegalSegModel(longformer, num_labels=7, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1)
88
- model = model.to(device)
 
 
89
 
90
  model_path = hf_hub_download(
91
  repo_id="Prateek0515/legal-document-segmentation",
@@ -93,14 +107,26 @@ model_path = hf_hub_download(
93
  )
94
 
95
  checkpoint = torch.load(model_path, map_location=device)
96
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
97
- model.load_state_dict(checkpoint['model_state_dict'])
 
 
 
98
  else:
99
  model.load_state_dict(checkpoint)
100
 
101
  model.eval()
102
  print("Model loaded successfully!")
103
 
 
 
 
 
 
 
 
 
 
104
  id2label = {
105
  0: "Arguments of Petitioner",
106
  1: "Arguments of Respondent",
@@ -111,8 +137,10 @@ id2label = {
111
  6: "Reasoning"
112
  }
113
 
 
 
 
114
  def split_sentences(text):
115
- """Split text into sentences"""
116
  sentences = re.split(r'(?<=[.!?])\s+', text)
117
  return [s.strip() for s in sentences if s.strip()]
118
 
@@ -134,6 +162,9 @@ def extract_text_from_docx(file_path):
134
  except Exception as e:
135
  return f"Error reading DOCX: {str(e)}"
136
 
 
 
 
137
  def predict(text_input, file_input):
138
  try:
139
  text = None
@@ -160,11 +191,10 @@ def predict(text_input, file_input):
160
  return "⚠️ No text content found"
161
 
162
  sentences = split_sentences(text)
163
-
164
  if not sentences:
165
  return "⚠️ Could not split text into sentences"
166
 
167
- # SIMPLE - PROCESS ALL SENTENCES TOGETHER (WORKS PERFECTLY)
168
  encoded_sentences = []
169
  for sentence in sentences:
170
  encoded = tokenizer(
@@ -180,11 +210,24 @@ def predict(text_input, file_input):
180
  attention_mask = torch.stack([e["attention_mask"].squeeze(0) for e in encoded_sentences]).unsqueeze(0).to(device)
181
  sentence_mask = torch.ones(1, len(sentences), dtype=torch.bool).to(device)
182
 
 
 
 
 
 
 
183
  with torch.no_grad():
184
- predictions = model(input_ids, attention_mask, sentence_mask=sentence_mask)
185
-
 
 
 
 
 
 
 
 
186
  predicted_labels = predictions[0]
187
-
188
  results = []
189
  for sentence, label_id in zip(sentences, predicted_labels):
190
  label = id2label.get(label_id, "Unknown")
@@ -195,6 +238,9 @@ def predict(text_input, file_input):
195
  except Exception as e:
196
  return f"❌ Error during prediction: {str(e)}"
197
 
 
 
 
198
  demo = gr.Interface(
199
  fn=predict,
200
  inputs=[
@@ -202,7 +248,7 @@ demo = gr.Interface(
202
  gr.File(label="Or Upload File (PDF, DOCX, TXT)")
203
  ],
204
  outputs=gr.Textbox(label="Per-Sentence Predictions", lines=10),
205
- title="βš–οΈ Legal Document Segmentation",
206
  description="Classify legal documents sentence-by-sentence into: Arguments (Petitioner/Respondent), Decision, Facts, Issue, or Reasoning",
207
  examples=[
208
  ["The appellant filed a petition against the respondent. The court decides that the appellant is liable.", None],
@@ -210,5 +256,5 @@ demo = gr.Interface(
210
  api_name="predict"
211
  )
212
 
213
- if __name__ == "__main__": # βœ… DOUBLE UNDERSCORE
214
  demo.launch()
 
9
  from docx import Document
10
  import re
11
 
12
+ # -------------------------
13
+ # Positional Encoding
14
+ # -------------------------
15
  class PositionalEncoding(nn.Module):
16
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
17
  super().__init__()
18
  self.dropout = nn.Dropout(p=dropout)
19
  pe = torch.zeros(max_len, d_model)
 
26
  def forward(self, x):
27
  return x + self.pe[:, :x.size(1)]
28
 
29
+ # -------------------------
30
+ # Vanilla Transformer
31
+ # -------------------------
32
  class VanillaTransformer(nn.Module):
33
+ def __init__(self, d_model=768, nhead=8, num_layers=3, dim_feedforward=2048, dropout=0.1):
34
  super().__init__()
35
  self.pos_encoder = PositionalEncoding(d_model, dropout)
36
  encoder_layer = nn.TransformerEncoderLayer(
 
43
  src = self.pos_encoder(src)
44
  return self.transformer(src, src_key_padding_mask=src_key_padding_mask)
45
 
46
+ # -------------------------
47
+ # Hierarchical Model
48
+ # -------------------------
49
  class HierarchicalLegalSegModel(nn.Module):
50
+ def __init__(self, longformer_model, num_labels, hidden_dim=768, transformer_layers=3, transformer_heads=8, dropout=0.1):
51
  super().__init__()
52
  self.longformer = longformer_model
53
  self.hidden_dim = hidden_dim
54
  self.vanilla_transformer = VanillaTransformer(
55
  d_model=hidden_dim, nhead=transformer_heads, num_layers=transformer_layers,
56
+ dim_feedforward=hidden_dim * 4, dropout=dropout
57
  )
58
  self.classifier = nn.Linear(hidden_dim, num_labels)
59
  self.crf = CRF(num_labels, batch_first=True)
 
73
  sentence_embeddings = self.encode_sentences(input_ids, attention_mask)
74
  sentence_embeddings = self.dropout(sentence_embeddings)
75
  transformer_output = self.vanilla_transformer(
76
+ sentence_embeddings,
77
  src_key_padding_mask=~sentence_mask if sentence_mask is not None else None
78
  )
79
  emissions = self.classifier(transformer_output)
 
84
  predictions = self.crf.decode(emissions, mask=sentence_mask)
85
  return predictions
86
 
87
+ # -------------------------
88
+ # Load Model
89
+ # -------------------------
90
  print("Loading model...")
91
  device = torch.device("cpu")
92
 
 
96
  for param in longformer.parameters():
97
  param.requires_grad = False
98
 
99
+ model = HierarchicalLegalSegModel(
100
+ longformer, num_labels=7, hidden_dim=768,
101
+ transformer_layers=3, transformer_heads=8, dropout=0.1
102
+ ).to(device)
103
 
104
  model_path = hf_hub_download(
105
  repo_id="Prateek0515/legal-document-segmentation",
 
107
  )
108
 
109
  checkpoint = torch.load(model_path, map_location=device)
110
+ if isinstance(checkpoint, dict):
111
+ if 'model_state_dict' in checkpoint:
112
+ model.load_state_dict(checkpoint['model_state_dict'])
113
+ else:
114
+ model.load_state_dict(checkpoint)
115
  else:
116
  model.load_state_dict(checkpoint)
117
 
118
  model.eval()
119
  print("Model loaded successfully!")
120
 
121
+ # πŸ” Debug model info
122
+ print("\n>>> MODEL CHECK <<<")
123
+ checkpoint = torch.load(model_path, map_location=device)
124
+ print("Checkpoint keys:", checkpoint.keys())
125
+ print("Tokenizer used:", tokenizer.name_or_path)
126
+
127
+ # -------------------------
128
+ # Label mapping
129
+ # -------------------------
130
  id2label = {
131
  0: "Arguments of Petitioner",
132
  1: "Arguments of Respondent",
 
137
  6: "Reasoning"
138
  }
139
 
140
+ # -------------------------
141
+ # Helpers
142
+ # -------------------------
143
  def split_sentences(text):
 
144
  sentences = re.split(r'(?<=[.!?])\s+', text)
145
  return [s.strip() for s in sentences if s.strip()]
146
 
 
162
  except Exception as e:
163
  return f"Error reading DOCX: {str(e)}"
164
 
165
+ # -------------------------
166
+ # Prediction Function
167
+ # -------------------------
168
  def predict(text_input, file_input):
169
  try:
170
  text = None
 
191
  return "⚠️ No text content found"
192
 
193
  sentences = split_sentences(text)
 
194
  if not sentences:
195
  return "⚠️ Could not split text into sentences"
196
 
197
+ # Encode sentences
198
  encoded_sentences = []
199
  for sentence in sentences:
200
  encoded = tokenizer(
 
210
  attention_mask = torch.stack([e["attention_mask"].squeeze(0) for e in encoded_sentences]).unsqueeze(0).to(device)
211
  sentence_mask = torch.ones(1, len(sentences), dtype=torch.bool).to(device)
212
 
213
+ # 🧩 Debug info
214
+ print(">>> DEBUG INFO <<<")
215
+ print("input_ids:", input_ids.shape)
216
+ print("attention_mask:", attention_mask.shape)
217
+ print("sentence_mask:", sentence_mask.shape)
218
+
219
  with torch.no_grad():
220
+ sentence_embeddings = model.encode_sentences(input_ids, attention_mask)
221
+ print("sentence_embeddings:", sentence_embeddings.shape)
222
+ transformer_output = model.vanilla_transformer(sentence_embeddings)
223
+ print("transformer_output mean:", transformer_output.mean().item())
224
+ emissions = model.classifier(transformer_output)
225
+ print("emissions shape:", emissions.shape, " | mean:", emissions.mean().item())
226
+
227
+ predictions = model.crf.decode(emissions, mask=sentence_mask)
228
+ print("Predictions (raw):", predictions)
229
+
230
  predicted_labels = predictions[0]
 
231
  results = []
232
  for sentence, label_id in zip(sentences, predicted_labels):
233
  label = id2label.get(label_id, "Unknown")
 
238
  except Exception as e:
239
  return f"❌ Error during prediction: {str(e)}"
240
 
241
+ # -------------------------
242
+ # Gradio Interface
243
+ # -------------------------
244
  demo = gr.Interface(
245
  fn=predict,
246
  inputs=[
 
248
  gr.File(label="Or Upload File (PDF, DOCX, TXT)")
249
  ],
250
  outputs=gr.Textbox(label="Per-Sentence Predictions", lines=10),
251
+ title="βš–οΈ Legal Document Segmentation (Debug Mode)",
252
  description="Classify legal documents sentence-by-sentence into: Arguments (Petitioner/Respondent), Decision, Facts, Issue, or Reasoning",
253
  examples=[
254
  ["The appellant filed a petition against the respondent. The court decides that the appellant is liable.", None],
 
256
  api_name="predict"
257
  )
258
 
259
+ if __name__ == "__main__":
260
  demo.launch()