from transformers import pipeline
from difflib import get_close_matches
from pathlib import Path


class BadQueryDetector:
    def __init__(self):
        self.detector = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

    def is_bad_query(self, query):
        result = self.detector(query)[0]
        label = result["label"]
        score = result["score"]
        # Treat high-confidence negative sentiment as a proxy for a malicious or bad query
        if label == "NEGATIVE" and score > 0.8:
            print(f"Detected malicious query with high confidence ({score:.4f}): {query}")
            return True
        return False


class QueryTransformer:
    def transform_query(self, query):
        # Simple keyword-based screen; in practice this could use a
        # sequence-to-sequence model such as T5 to rephrase or sanitize
        # the query (a sketch follows this class).
        # Uppercase the query so lowercase injection attempts are caught too.
        if "DROP TABLE" in query.upper() or "SELECT *" in query.upper():
            return "Your query appears to contain SQL injection elements. Please rephrase."
        return query

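# A minimal sketch of the T5-style rephrasing mentioned above; it is not part
# of the original system. The "text2text-generation" pipeline and "t5-small"
# checkpoint exist in transformers, but t5-small is not fine-tuned for
# paraphrasing, so the "paraphrase:" task prefix is an assumption and a
# paraphrase-tuned checkpoint would be needed in practice.
class T5QueryTransformer:
    def __init__(self):
        self.rephraser = pipeline("text2text-generation", model="t5-small")

    def transform_query(self, query):
        # The generated text is used directly as the rewritten query.
        result = self.rephraser(f"paraphrase: {query}", max_new_tokens=64)
        return result[0]["generated_text"]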

class DocumentRetriever:
    def __init__(self):
        self.documents = []

    def load_documents(self, source_dir):
        data_dir = Path(source_dir)
        if not data_dir.exists():
            print(f"Source directory not found: {source_dir}")
            return

        for file in data_dir.glob("*.txt"):
            with open(file, "r", encoding="utf-8") as f:
                self.documents.append(f.read())

        print(f"Loaded {len(self.documents)} documents.")

    def retrieve(self, query):
        # Fuzzy string matching of the query against full document texts;
        # cutoff=0.3 keeps loosely similar documents, at most five.
        # An embedding-based alternative is sketched after this class.
        matches = get_close_matches(query, self.documents, n=5, cutoff=0.3)
        return matches if matches else ["No matching documents found."]

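# A hedged sketch of embedding-based retrieval as a semantic alternative to
# the difflib matching above; not part of the original system. It assumes the
# sentence-transformers package and the "all-MiniLM-L6-v2" model.
from sentence_transformers import SentenceTransformer, util

class EmbeddingRetriever:
    def __init__(self):
        self.model = SentenceTransformer("all-MiniLM-L6-v2")
        self.documents = []
        self.embeddings = None

    def load_documents(self, docs):
        # Pre-compute one embedding per document for cosine-similarity search.
        self.documents = docs
        self.embeddings = self.model.encode(docs, convert_to_tensor=True)

    def retrieve(self, query, top_k=5):
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        hits = util.semantic_search(query_embedding, self.embeddings, top_k=top_k)[0]
        return [self.documents[hit["corpus_id"]] for hit in hits]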

class SemanticResponseGenerator:
    def __init__(self):
        self.generator = pipeline("text-generation", model="gpt2")

    def generate_response(self, retrieved_docs):
        # Generate a semantic response using retrieved documents
        combined_docs = " ".join(retrieved_docs[:2])  # Use top 2 matches for response
        # max_new_tokens bounds the generated continuation regardless of prompt
        # length (max_length would count prompt tokens too); truncation keeps
        # long prompts within GPT-2's 1024-token context window.
        response = self.generator(
            f"Based on the following information: {combined_docs}",
            max_new_tokens=100,
            truncation=True,
        )
        return response[0]["generated_text"]


class DocumentSearchSystem:
    def __init__(self):
        self.detector = BadQueryDetector()
        self.transformer = QueryTransformer()
        self.retriever = DocumentRetriever()
        self.response_generator = SemanticResponseGenerator()

    def process_query(self, query):
        if self.detector.is_bad_query(query):
            return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}

        transformed_query = self.transformer.transform_query(query)
        retrieved_docs = self.retriever.retrieve(transformed_query)

        if "No matching documents found." in retrieved_docs:
            return {"status": "no_results", "message": "No relevant documents found for your query."}

        response = self.response_generator.generate_response(retrieved_docs)
        return {"status": "success", "response": response}


# Exercise the system end to end with a normal and a malicious query
def test_system():
    system = DocumentSearchSystem()
    system.retriever.load_documents("/path/to/documents")

    # Test with a normal query
    normal_query = "Tell me about great acting performances."
    normal_result = system.process_query(normal_query)
    print("\nNormal Query Result:")
    print(normal_result)

    # Test with a malicious query
    malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
    malicious_result = system.process_query(malicious_query)
    print("\nMalicious Query Result:")
    print(malicious_result)


if __name__ == "__main__":
    test_system()