File size: 12,060 Bytes
8bab08d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# file: tests/test_pipeline.py
import pytest
import json
from unittest.mock import Mock, AsyncMock, patch, mock_open
from app.orchestrator import Orchestrator
from app.schema import Company, Prospect
from pathlib import Path
import asyncio

@pytest.mark.asyncio
async def test_pipeline_happy_path():
    """Drive the full pipeline for one prospect with every dependency mocked.

    These are all replaced by test doubles:
    - companies.json
    - the MCP clients (store, search, email, calendar)
    - the footer file
    - the Writer's retriever
    - the Writer's aiohttp session (made to fail so the Writer uses its fallback)

    Only the shape of the event stream and a couple of MCP calls are checked,
    not the content the agents produce.
    """
    company_record = {
        "id": "test",
        "name": "Test Co",
        "domain": "test.com",
        "industry": "SaaS",
        "size": 100,
        "pains": ["Low NPS scores"],
        "notes": ["Growing company"],
    }

    # Store client: every save succeeds, nothing is suppressed, no known contacts.
    store = AsyncMock()
    for method in ("save_prospect", "save_company", "save_fact",
                   "save_contact", "save_handoff"):
        setattr(store, method, AsyncMock(return_value=None))
    store.check_suppression = AsyncMock(return_value=False)
    store.list_contacts_by_domain = AsyncMock(return_value=[])

    # Search client: a single canned fact about the company.
    search = AsyncMock()
    search.query = AsyncMock(return_value=[
        {
            "text": "Test Co focuses on customer experience",
            "source": "Industry Report",
            "confidence": 0.85,
        },
    ])

    # Email client: sending succeeds and the thread holds that one outbound message.
    thread_id, message_id = "test-thread-123", "msg-456"
    email = AsyncMock()
    email.send = AsyncMock(return_value={
        "thread_id": thread_id,
        "message_id": message_id,
        "prospect_id": "test",
    })
    email.get_thread = AsyncMock(return_value={
        "id": thread_id,
        "prospect_id": "test",
        "messages": [{
            "id": message_id,
            "thread_id": thread_id,
            "direction": "outbound",
            "subject": "Test Subject",
            "body": "Test Body",
            "sent_at": "2024-01-01T00:00:00",
        }],
    })

    # Calendar client: one proposed slot and a stub ICS payload.
    calendar = AsyncMock()
    calendar.suggest_slots = AsyncMock(return_value=[
        {"start_iso": "2024-01-02T14:00:00", "end_iso": "2024-01-02T14:30:00"},
    ])
    calendar.generate_ics = AsyncMock(return_value="BEGIN:VCALENDAR...")

    # Registry instance handed out by the patched MCPRegistry class.
    mcp = Mock()
    mcp.get_store_client.return_value = store
    mcp.get_search_client.return_value = search
    mcp.get_email_client.return_value = email
    mcp.get_calendar_client.return_value = calendar

    # Vector retriever used by the Writer.
    retriever = Mock()
    retriever.retrieve.return_value = [{"text": "Relevant fact 1", "score": 0.9}]

    # Session used by the Writer for Ollama; posting raises, pushing the
    # Writer onto its fallback path.
    failing_session = AsyncMock()
    failing_session.post.side_effect = Exception("Connection failed")

    companies_json = json.dumps([company_record])
    with patch("builtins.open", mock_open(read_data=companies_json)), \
            patch("app.orchestrator.MCPRegistry") as registry_cls, \
            patch.object(Path, "exists", return_value=True), \
            patch.object(Path, "read_text", return_value="\n---\nTest Footer"), \
            patch("agents.writer.Retriever") as retriever_cls, \
            patch("agents.writer.aiohttp.ClientSession") as session_cls:
        registry_cls.return_value = mcp
        retriever_cls.return_value = retriever
        session_cls.return_value.__aenter__.return_value = failing_session

        orchestrator = Orchestrator()
        events = [event async for event in orchestrator.run_pipeline(["test"])]

        seen_types = [event.get("type") for event in events]

        # Agents ran and talked to the MCP layer.
        for expected in ("agent_start", "agent_end", "mcp_call", "mcp_response"):
            assert expected in seen_types

        # Either the run finished (llm_done) or a policy check stopped it.
        # Which one depends on whether the Writer's fallback produced a draft.
        assert "llm_done" in seen_types or "policy_block" in seen_types

        # The core MCP operations were attempted.
        assert store.save_prospect.called
        assert search.query.called

@pytest.mark.asyncio
async def test_pipeline_compliance_block():
    """Test that compliance violations block the pipeline.

    The mocked store reports the prospect's domain, and any e-mail value
    containing it, as suppressed. The pipeline must then emit an event whose
    message or payload reason mentions "suppressed", "dropped" or "blocked".

    The fixture's id, name and domain deliberately avoid those three words.
    The assertion is a substring search over event text. An identifier such
    as "blocked-test" / "Blocked Co" / "blocked.com" would let the test pass
    as soon as any event merely mentioned the company, whether or not
    suppression was enforced.
    """
    suppressed_domain = "optout.com"
    test_company = {
        "id": "optout-test",
        "name": "Optout Co",
        "domain": suppressed_domain,
        "industry": "SaaS",
        "size": 100,
        "pains": ["Test pain"],
        "notes": []
    }

    with patch('builtins.open', mock_open(read_data=json.dumps([test_company]))):
        with patch('app.orchestrator.MCPRegistry') as MockMCPRegistry:
            mock_mcp = Mock()
            MockMCPRegistry.return_value = mock_mcp

            # Mock store that reports the prospect's domain as suppressed
            mock_store = AsyncMock()
            mock_store.save_prospect = AsyncMock(return_value=None)
            mock_store.save_fact = AsyncMock(return_value=None)
            mock_store.save_contact = AsyncMock(return_value=None)

            # Parameter names stay `type`/`value` (locally shadowing the builtin)
            # in case the orchestrator calls the store client with keyword
            # arguments. NOTE(review): confirm against the real client signature.
            async def check_suppression(type, value):
                if type == "domain" and value == suppressed_domain:
                    return True
                if type == "email" and suppressed_domain in value:
                    return True
                return False

            mock_store.check_suppression = AsyncMock(side_effect=check_suppression)
            mock_store.list_contacts_by_domain = AsyncMock(return_value=[])

            # Mock search
            mock_search = AsyncMock()
            mock_search.query = AsyncMock(return_value=[])

            # Mock email and calendar
            mock_email = AsyncMock()
            mock_calendar = AsyncMock()

            mock_mcp.get_store_client.return_value = mock_store
            mock_mcp.get_search_client.return_value = mock_search
            mock_mcp.get_email_client.return_value = mock_email
            mock_mcp.get_calendar_client.return_value = mock_calendar

            with patch.object(Path, 'exists', return_value=True):
                with patch.object(Path, 'read_text', return_value="\n---\nTest Footer"):
                    with patch('agents.writer.Retriever') as MockRetriever:
                        mock_retriever = Mock()
                        mock_retriever.retrieve.return_value = []
                        MockRetriever.return_value = mock_retriever

                        orchestrator = Orchestrator()

                        events = []
                        async for event in orchestrator.run_pipeline([test_company["id"]]):
                            events.append(event)

                        # Search every event's message and payload reason.
                        # `or {}` tolerates events whose payload is present but None.
                        messages = [str(e.get("message", "")).lower() for e in events]
                        reasons = [
                            str((e.get("payload") or {}).get("reason", "")).lower()
                            for e in events
                        ]
                        all_text = " ".join(messages + reasons)

                        keywords = ("suppressed", "dropped", "blocked")
                        assert any(word in all_text for word in keywords), (
                            "Should have suppression/dropped/blocked message; got messages: "
                            f"{[e.get('message') for e in events]!r}"
                        )

@pytest.mark.asyncio
async def test_pipeline_scorer_drop():
    """Test that low scores drop prospects.

    The fixture company is deliberately a poor fit: unknown industry, 10
    employees, no pains. The pipeline is expected to emit an event that says
    the prospect was "dropped" or had a "low fit score". That text may appear
    in the event message, or in the payload's reason or status.
    """

    test_company = {
        "id": "low-score",
        "name": "Small Co",
        "domain": "small.com",
        "industry": "Unknown",  # Low value industry
        "size": 10,  # Too small
        "pains": [],  # No pains
        "notes": []
    }

    with patch('builtins.open', mock_open(read_data=json.dumps([test_company]))):
        with patch('app.orchestrator.MCPRegistry') as MockMCPRegistry:
            mock_mcp = Mock()
            MockMCPRegistry.return_value = mock_mcp

            mock_store = AsyncMock()
            mock_store.save_prospect = AsyncMock(return_value=None)
            mock_store.save_fact = AsyncMock(return_value=None)
            mock_store.save_contact = AsyncMock(return_value=None)
            mock_store.check_suppression = AsyncMock(return_value=False)
            mock_store.list_contacts_by_domain = AsyncMock(return_value=[])

            mock_search = AsyncMock()
            mock_search.query = AsyncMock(return_value=[])

            mock_email = AsyncMock()
            mock_calendar = AsyncMock()

            mock_mcp.get_store_client.return_value = mock_store
            mock_mcp.get_search_client.return_value = mock_search
            mock_mcp.get_email_client.return_value = mock_email
            mock_mcp.get_calendar_client.return_value = mock_calendar

            with patch.object(Path, 'exists', return_value=True):
                with patch.object(Path, 'read_text', return_value="\n---\nTest Footer"):
                    with patch('agents.writer.Retriever') as MockRetriever:
                        mock_retriever = Mock()
                        mock_retriever.retrieve.return_value = []
                        MockRetriever.return_value = mock_retriever

                        orchestrator = Orchestrator()

                        events = []
                        async for event in orchestrator.run_pipeline(["low-score"]):
                            events.append(event)

                        def is_drop_event(event):
                            """Return True if the event reports the prospect being dropped."""
                            # `or {}` tolerates events whose payload is present but None.
                            payload = event.get("payload") or {}
                            message = str(event.get("message", "")).lower()
                            reason = str(payload.get("reason", "")).lower()
                            status = str(payload.get("status", "")).lower()
                            return (
                                "dropped" in message
                                or "dropped" in reason
                                or "dropped" in status
                                or "low fit score" in message
                                or "low fit score" in reason
                            )

                        assert any(is_drop_event(e) for e in events), (
                            "Should have found drop message; got events: "
                            f"{[(e.get('type'), e.get('message')) for e in events]!r}"
                        )