Spaces:
Runtime error
Runtime error
Refactor #6
Browse files- README.md +1 -1
- src/ctp_slack_bot/app.py +14 -13
- src/ctp_slack_bot/containers.py +39 -34
- src/ctp_slack_bot/controllers/__init__.py +1 -1
- src/ctp_slack_bot/controllers/application_health_controller.py +5 -3
- src/ctp_slack_bot/controllers/base.py +54 -20
- src/ctp_slack_bot/db/mongo_db.py +11 -33
- src/ctp_slack_bot/mime_type_handlers/__init__.py +1 -1
- src/ctp_slack_bot/mime_type_handlers/base.py +29 -11
- src/ctp_slack_bot/mime_type_handlers/text/vtt.py +2 -2
- src/ctp_slack_bot/services/__init__.py +1 -1
- src/ctp_slack_bot/services/content_ingestion_service.py +9 -6
- src/ctp_slack_bot/services/event_brokerage_service.py +9 -8
- src/ctp_slack_bot/services/google_drive_service.py +0 -1
- src/ctp_slack_bot/services/question_dispatch_service.py +9 -6
- src/ctp_slack_bot/services/slack_service.py +59 -45
- src/ctp_slack_bot/services/{schedule_service.py → task_service.py} +13 -17
README.md
CHANGED
|
@@ -101,7 +101,7 @@ Not every file or folder is listed, but the important stuff is here.
|
|
| 101 |
* `google_drive_service.py`: interfaces with Google Drive
|
| 102 |
* `language_model_service.py`: answers questions using relevant context
|
| 103 |
* `question_dispatch_service.py`: listens for questions and retrieves relevant context to get answers
|
| 104 |
-
* `
|
| 105 |
* `slack_service.py`: handles events from Slack and sends back responses
|
| 106 |
* `vectorization_service.py`: converts chunks into chunks with embeddings
|
| 107 |
* `tasks/`: scheduled tasks to run in the background
|
|
|
|
| 101 |
* `google_drive_service.py`: interfaces with Google Drive
|
| 102 |
* `language_model_service.py`: answers questions using relevant context
|
| 103 |
* `question_dispatch_service.py`: listens for questions and retrieves relevant context to get answers
|
| 104 |
+
* `task_service.py`: runs periodic background tasks
|
| 105 |
* `slack_service.py`: handles events from Slack and sends back responses
|
| 106 |
* `vectorization_service.py`: converts chunks into chunks with embeddings
|
| 107 |
* `tasks/`: scheduled tasks to run in the background
|
src/ctp_slack_bot/app.py
CHANGED
|
@@ -7,8 +7,11 @@ from containers import Container
|
|
| 7 |
from core.logging import setup_logging
|
| 8 |
|
| 9 |
|
| 10 |
-
async def handle_shutdown_signal() -> None:
|
| 11 |
logger.info("Received shutdown signal.")
|
|
|
|
|
|
|
|
|
|
| 12 |
for task in all_tasks():
|
| 13 |
if task is not current_task() and not task.done():
|
| 14 |
task.cancel()
|
|
@@ -16,9 +19,9 @@ async def handle_shutdown_signal() -> None:
|
|
| 16 |
logger.info("Cancelled all tasks.")
|
| 17 |
|
| 18 |
|
| 19 |
-
def create_shutdown_signal_handler() -> Callable[[], None]:
|
| 20 |
def shutdown_signal_handler() -> None:
|
| 21 |
-
create_task(handle_shutdown_signal())
|
| 22 |
return shutdown_signal_handler
|
| 23 |
|
| 24 |
|
|
@@ -32,29 +35,27 @@ async def main() -> None:
|
|
| 32 |
container.wire(packages=["ctp_slack_bot"])
|
| 33 |
logger.debug("Created dependency injection container with providers: {}", '; '.join(container.providers))
|
| 34 |
|
| 35 |
-
# Initialize/instantiate services which should be
|
| 36 |
-
container.content_ingestion_service()
|
| 37 |
-
container.question_dispatch_service()
|
| 38 |
http_server = await container.http_server()
|
| 39 |
-
|
| 40 |
-
container.
|
| 41 |
logger.debug("Initialized services.")
|
| 42 |
|
| 43 |
# Install the shutdown signal handler.
|
| 44 |
-
shutdown_signal_handler = create_shutdown_signal_handler()
|
| 45 |
loop = get_running_loop()
|
| 46 |
loop.add_signal_handler(SIGINT, shutdown_signal_handler)
|
| 47 |
loop.add_signal_handler(SIGTERM, shutdown_signal_handler)
|
| 48 |
|
| 49 |
# Start the HTTP server and Slack socket mode handler in the background; clean up resources when shut down.
|
| 50 |
try:
|
| 51 |
-
logger.info("Starting
|
| 52 |
-
await gather(http_server.start(),
|
| 53 |
except CancelledError:
|
| 54 |
logger.info("Shutting down application…")
|
| 55 |
finally:
|
| 56 |
-
await socket_mode_handler.close_async()
|
| 57 |
-
logger.info("Stopped Slack Socket Mode handler.")
|
| 58 |
await container.shutdown_resources()
|
| 59 |
|
| 60 |
|
|
|
|
| 7 |
from core.logging import setup_logging
|
| 8 |
|
| 9 |
|
| 10 |
+
async def handle_shutdown_signal(*args) -> None:
|
| 11 |
logger.info("Received shutdown signal.")
|
| 12 |
+
for arg in args:
|
| 13 |
+
await arg()
|
| 14 |
+
logger.info("Executed shutdown tasks.")
|
| 15 |
for task in all_tasks():
|
| 16 |
if task is not current_task() and not task.done():
|
| 17 |
task.cancel()
|
|
|
|
| 19 |
logger.info("Cancelled all tasks.")
|
| 20 |
|
| 21 |
|
| 22 |
+
def create_shutdown_signal_handler(*args) -> Callable[[], None]:
|
| 23 |
def shutdown_signal_handler() -> None:
|
| 24 |
+
create_task(handle_shutdown_signal(*args))
|
| 25 |
return shutdown_signal_handler
|
| 26 |
|
| 27 |
|
|
|
|
| 35 |
container.wire(packages=["ctp_slack_bot"])
|
| 36 |
logger.debug("Created dependency injection container with providers: {}", '; '.join(container.providers))
|
| 37 |
|
| 38 |
+
# Initialize/instantiate services which should be available from the start.
|
| 39 |
+
await container.content_ingestion_service()
|
| 40 |
+
await container.question_dispatch_service()
|
| 41 |
http_server = await container.http_server()
|
| 42 |
+
slack_service = await container.slack_service()
|
| 43 |
+
task_service = await container.task_service()
|
| 44 |
logger.debug("Initialized services.")
|
| 45 |
|
| 46 |
# Install the shutdown signal handler.
|
| 47 |
+
shutdown_signal_handler = create_shutdown_signal_handler(http_server.stop, slack_service.stop, task_service.stop)
|
| 48 |
loop = get_running_loop()
|
| 49 |
loop.add_signal_handler(SIGINT, shutdown_signal_handler)
|
| 50 |
loop.add_signal_handler(SIGTERM, shutdown_signal_handler)
|
| 51 |
|
| 52 |
# Start the HTTP server and Slack socket mode handler in the background; clean up resources when shut down.
|
| 53 |
try:
|
| 54 |
+
logger.info("Starting services…")
|
| 55 |
+
await gather(http_server.start(), slack_service.start(), task_service.start())
|
| 56 |
except CancelledError:
|
| 57 |
logger.info("Shutting down application…")
|
| 58 |
finally:
|
|
|
|
|
|
|
| 59 |
await container.shutdown_resources()
|
| 60 |
|
| 61 |
|
src/ctp_slack_bot/containers.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
|
|
|
|
|
| 1 |
from dependency_injector.containers import DeclarativeContainer
|
| 2 |
-
from dependency_injector.providers import
|
| 3 |
from importlib import import_module
|
| 4 |
from itertools import chain
|
| 5 |
from openai import AsyncOpenAI
|
| 6 |
from pkgutil import iter_modules
|
| 7 |
-
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
|
| 8 |
-
from slack_bolt.async_app import AsyncApp
|
| 9 |
from types import ModuleType
|
| 10 |
-
from typing import Sequence
|
| 11 |
|
| 12 |
from ctp_slack_bot.controllers import ControllerBase, ControllerRegistry
|
| 13 |
from ctp_slack_bot.core import Settings
|
|
@@ -16,7 +16,7 @@ from ctp_slack_bot.db.repositories.mongo_db_vectorized_chunk_repository import M
|
|
| 16 |
from ctp_slack_bot.mime_type_handlers import MimeTypeHandlerRegistry
|
| 17 |
from ctp_slack_bot.services.answer_retrieval_service import AnswerRetrievalService
|
| 18 |
from ctp_slack_bot.services.application_health_service import ApplicationHealthService
|
| 19 |
-
from ctp_slack_bot.services.content_ingestion_service import
|
| 20 |
from ctp_slack_bot.services.context_retrieval_service import ContextRetrievalService
|
| 21 |
from ctp_slack_bot.services.embeddings_model_service import EmbeddingsModelService
|
| 22 |
from ctp_slack_bot.services.event_brokerage_service import EventBrokerageService
|
|
@@ -24,26 +24,38 @@ from ctp_slack_bot.services.google_drive_service import GoogleDriveService
|
|
| 24 |
from ctp_slack_bot.services.http_client_service import HTTPClientServiceResource
|
| 25 |
from ctp_slack_bot.services.http_server_service import HTTPServerResource
|
| 26 |
from ctp_slack_bot.services.language_model_service import LanguageModelService
|
| 27 |
-
from ctp_slack_bot.services.question_dispatch_service import
|
| 28 |
-
from ctp_slack_bot.services.schedule_service import ScheduleServiceResource
|
| 29 |
from ctp_slack_bot.services.slack_service import SlackServiceResource
|
|
|
|
| 30 |
from ctp_slack_bot.services.vectorization_service import VectorizationService
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
class Container(DeclarativeContainer): # TODO: audit for potential async-related bugs.
|
| 34 |
-
async def
|
| 35 |
-
return [controller_class(**{dependency_name: await
|
| 36 |
-
for dependency_name
|
| 37 |
-
in
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
__self__ = Self()
|
| 43 |
settings = Singleton(Settings)
|
| 44 |
event_brokerage_service = Singleton(EventBrokerageService)
|
| 45 |
-
schedule_service = Resource (ScheduleServiceResource,
|
| 46 |
-
settings=settings)
|
| 47 |
http_client = Resource (HTTPClientServiceResource)
|
| 48 |
mongo_db = Resource (MongoDBResource,
|
| 49 |
settings=settings)
|
|
@@ -58,7 +70,7 @@ class Container(DeclarativeContainer): # TODO: audit for potential async-related
|
|
| 58 |
vectorization_service = Singleton(VectorizationService,
|
| 59 |
settings=settings,
|
| 60 |
embeddings_model_service=embeddings_model_service)
|
| 61 |
-
content_ingestion_service =
|
| 62 |
settings=settings,
|
| 63 |
event_brokerage_service=event_brokerage_service,
|
| 64 |
vectorized_chunk_repository=vectorized_chunk_repository,
|
|
@@ -74,25 +86,18 @@ class Container(DeclarativeContainer): # TODO: audit for potential async-related
|
|
| 74 |
settings=settings,
|
| 75 |
event_brokerage_service=event_brokerage_service,
|
| 76 |
language_model_service=language_model_service)
|
| 77 |
-
question_dispatch_service =
|
| 78 |
settings=settings,
|
| 79 |
event_brokerage_service=event_brokerage_service,
|
| 80 |
-
content_ingestion_service=content_ingestion_service,
|
| 81 |
context_retrieval_service=context_retrieval_service,
|
| 82 |
answer_retrieval_service=answer_retrieval_service)
|
| 83 |
-
slack_bolt_app = Singleton(lambda settings: AsyncApp(token=settings.slack_bot_token.get_secret_value()),
|
| 84 |
-
settings)
|
| 85 |
slack_service = Resource (SlackServiceResource,
|
|
|
|
| 86 |
event_brokerage_service=event_brokerage_service,
|
| 87 |
-
http_client=http_client
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
slack_bolt_app,
|
| 92 |
-
settings)
|
| 93 |
-
mime_type_handlers = Dict ({mime_type: Singleton(handler)
|
| 94 |
-
for mime_type, handler
|
| 95 |
-
in MimeTypeHandlerRegistry.get_registry().items()})
|
| 96 |
google_drive_service = Singleton(GoogleDriveService,
|
| 97 |
settings=settings)
|
| 98 |
# file_monitor_service = Singleton(FileMonitorService,
|
|
@@ -101,8 +106,8 @@ class Container(DeclarativeContainer): # TODO: audit for potential async-related
|
|
| 101 |
# mime_type_handler_factory=mime_type_handler_factory)
|
| 102 |
application_health_service = Singleton(ApplicationHealthService,
|
| 103 |
services=List(mongo_db, slack_service))
|
| 104 |
-
|
| 105 |
-
|
| 106 |
http_server = Resource (HTTPServerResource,
|
| 107 |
settings=settings,
|
| 108 |
-
controllers=
|
|
|
|
| 1 |
+
from asyncio import iscoroutine, isfuture
|
| 2 |
+
from dependency_injector import providers
|
| 3 |
from dependency_injector.containers import DeclarativeContainer
|
| 4 |
+
from dependency_injector.providers import Callable, Configuration, Dict, List, Resource, Singleton
|
| 5 |
from importlib import import_module
|
| 6 |
from itertools import chain
|
| 7 |
from openai import AsyncOpenAI
|
| 8 |
from pkgutil import iter_modules
|
|
|
|
|
|
|
| 9 |
from types import ModuleType
|
| 10 |
+
from typing import Any, Iterator, Sequence
|
| 11 |
|
| 12 |
from ctp_slack_bot.controllers import ControllerBase, ControllerRegistry
|
| 13 |
from ctp_slack_bot.core import Settings
|
|
|
|
| 16 |
from ctp_slack_bot.mime_type_handlers import MimeTypeHandlerRegistry
|
| 17 |
from ctp_slack_bot.services.answer_retrieval_service import AnswerRetrievalService
|
| 18 |
from ctp_slack_bot.services.application_health_service import ApplicationHealthService
|
| 19 |
+
from ctp_slack_bot.services.content_ingestion_service import ContentIngestionServiceResource
|
| 20 |
from ctp_slack_bot.services.context_retrieval_service import ContextRetrievalService
|
| 21 |
from ctp_slack_bot.services.embeddings_model_service import EmbeddingsModelService
|
| 22 |
from ctp_slack_bot.services.event_brokerage_service import EventBrokerageService
|
|
|
|
| 24 |
from ctp_slack_bot.services.http_client_service import HTTPClientServiceResource
|
| 25 |
from ctp_slack_bot.services.http_server_service import HTTPServerResource
|
| 26 |
from ctp_slack_bot.services.language_model_service import LanguageModelService
|
| 27 |
+
from ctp_slack_bot.services.question_dispatch_service import QuestionDispatchServiceResource
|
|
|
|
| 28 |
from ctp_slack_bot.services.slack_service import SlackServiceResource
|
| 29 |
+
from ctp_slack_bot.services.task_service import TaskServiceResource
|
| 30 |
from ctp_slack_bot.services.vectorization_service import VectorizationService
|
| 31 |
|
| 32 |
|
| 33 |
+
async def _await_or_return(value):
|
| 34 |
+
if iscoroutine(value) or isfuture(value):
|
| 35 |
+
return await value
|
| 36 |
+
return value
|
| 37 |
+
|
| 38 |
+
|
| 39 |
class Container(DeclarativeContainer): # TODO: audit for potential async-related bugs.
|
| 40 |
+
async def __get_http_controller_providers(container) -> Sequence[ControllerBase]:
|
| 41 |
+
return [controller_class(**{dependency_name: await _await_or_return(container.providers[dependency_name]())
|
| 42 |
+
for dependency_name
|
| 43 |
+
in controller_class.model_fields.keys() & container.providers.keys()})
|
| 44 |
+
for controller_class in ControllerRegistry.get_registry()]
|
| 45 |
+
|
| 46 |
+
def __iter_mime_type_handler_providers() -> Iterator[tuple[str, Singleton]]:
|
| 47 |
+
handler_provider_map = {}
|
| 48 |
+
for mime_type, handler in MimeTypeHandlerRegistry.get_registry().items():
|
| 49 |
+
if handler in handler_provider_map:
|
| 50 |
+
provider = handler_provider_map[handler]
|
| 51 |
+
else:
|
| 52 |
+
provider = Singleton(handler)
|
| 53 |
+
handler_provider_map[handler] = provider
|
| 54 |
+
yield (mime_type, provider)
|
| 55 |
|
| 56 |
+
__self__ = providers.Self()
|
| 57 |
settings = Singleton(Settings)
|
| 58 |
event_brokerage_service = Singleton(EventBrokerageService)
|
|
|
|
|
|
|
| 59 |
http_client = Resource (HTTPClientServiceResource)
|
| 60 |
mongo_db = Resource (MongoDBResource,
|
| 61 |
settings=settings)
|
|
|
|
| 70 |
vectorization_service = Singleton(VectorizationService,
|
| 71 |
settings=settings,
|
| 72 |
embeddings_model_service=embeddings_model_service)
|
| 73 |
+
content_ingestion_service = Resource (ContentIngestionServiceResource,
|
| 74 |
settings=settings,
|
| 75 |
event_brokerage_service=event_brokerage_service,
|
| 76 |
vectorized_chunk_repository=vectorized_chunk_repository,
|
|
|
|
| 86 |
settings=settings,
|
| 87 |
event_brokerage_service=event_brokerage_service,
|
| 88 |
language_model_service=language_model_service)
|
| 89 |
+
question_dispatch_service = Resource (QuestionDispatchServiceResource,
|
| 90 |
settings=settings,
|
| 91 |
event_brokerage_service=event_brokerage_service,
|
|
|
|
| 92 |
context_retrieval_service=context_retrieval_service,
|
| 93 |
answer_retrieval_service=answer_retrieval_service)
|
|
|
|
|
|
|
| 94 |
slack_service = Resource (SlackServiceResource,
|
| 95 |
+
settings=settings,
|
| 96 |
event_brokerage_service=event_brokerage_service,
|
| 97 |
+
http_client=http_client)
|
| 98 |
+
mime_type_handlers = Dict ({mime_type: handler_provider
|
| 99 |
+
for mime_type, handler_provider
|
| 100 |
+
in __iter_mime_type_handler_providers()})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
google_drive_service = Singleton(GoogleDriveService,
|
| 102 |
settings=settings)
|
| 103 |
# file_monitor_service = Singleton(FileMonitorService,
|
|
|
|
| 106 |
# mime_type_handler_factory=mime_type_handler_factory)
|
| 107 |
application_health_service = Singleton(ApplicationHealthService,
|
| 108 |
services=List(mongo_db, slack_service))
|
| 109 |
+
task_service = Resource (TaskServiceResource,
|
| 110 |
+
settings=settings)
|
| 111 |
http_server = Resource (HTTPServerResource,
|
| 112 |
settings=settings,
|
| 113 |
+
controllers=Callable(__get_http_controller_providers, __self__))
|
src/ctp_slack_bot/controllers/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
from .application_health_controller import ApplicationHealthController
|
| 2 |
-
from .base import ControllerBase, ControllerRegistry, Route
|
|
|
|
| 1 |
from .application_health_controller import ApplicationHealthController
|
| 2 |
+
from .base import controller, ControllerBase, ControllerRegistry, delete, get, patch, post, put, Route, route
|
src/ctp_slack_bot/controllers/application_health_controller.py
CHANGED
|
@@ -2,19 +2,21 @@ from aiohttp.web import json_response, Request, Response
|
|
| 2 |
from pydantic import ConfigDict
|
| 3 |
from typing import Self
|
| 4 |
|
| 5 |
-
from .base import ControllerBase,
|
| 6 |
from ctp_slack_bot.services import ApplicationHealthService
|
| 7 |
|
| 8 |
|
| 9 |
-
@
|
| 10 |
class ApplicationHealthController(ControllerBase):
|
| 11 |
"""
|
| 12 |
Application health reporting endpoints.
|
| 13 |
"""
|
| 14 |
|
|
|
|
|
|
|
| 15 |
application_health_service: ApplicationHealthService
|
| 16 |
|
| 17 |
-
@
|
| 18 |
async def get_health(self: Self, request: Request) -> Response:
|
| 19 |
health_statuses = await self.application_health_service.get_health()
|
| 20 |
return json_response(dict(health_statuses), status=200 if all(health_statuses.values()) else 503)
|
|
|
|
| 2 |
from pydantic import ConfigDict
|
| 3 |
from typing import Self
|
| 4 |
|
| 5 |
+
from .base import ControllerBase, controller, get
|
| 6 |
from ctp_slack_bot.services import ApplicationHealthService
|
| 7 |
|
| 8 |
|
| 9 |
+
@controller("/health")
|
| 10 |
class ApplicationHealthController(ControllerBase):
|
| 11 |
"""
|
| 12 |
Application health reporting endpoints.
|
| 13 |
"""
|
| 14 |
|
| 15 |
+
model_config = ConfigDict(frozen=True)
|
| 16 |
+
|
| 17 |
application_health_service: ApplicationHealthService
|
| 18 |
|
| 19 |
+
@get("")
|
| 20 |
async def get_health(self: Self, request: Request) -> Response:
|
| 21 |
health_statuses = await self.application_health_service.get_health()
|
| 22 |
return json_response(dict(health_statuses), status=200 if all(health_statuses.values()) else 503)
|
src/ctp_slack_bot/controllers/base.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
|
|
| 1 |
from aiohttp.web import Request, Response
|
| 2 |
-
from functools import partial
|
| 3 |
from importlib import import_module
|
| 4 |
from inspect import getmembers, ismethod
|
| 5 |
from pydantic import BaseModel, ConfigDict
|
| 6 |
-
from typing import Awaitable, Callable, ClassVar, Mapping, Self, Sequence, TypeVar
|
| 7 |
|
| 8 |
from ctp_slack_bot.core import ApplicationComponentBase
|
| 9 |
|
|
@@ -18,21 +19,20 @@ class Route(BaseModel):
|
|
| 18 |
path: str
|
| 19 |
handler: AsyncHandler
|
| 20 |
|
| 21 |
-
@staticmethod
|
| 22 |
-
def get(path: str) -> Callable[[AsyncHandler], AsyncHandler]:
|
| 23 |
-
def decorator(function: AsyncHandler) -> AsyncHandler:
|
| 24 |
-
function._http_method = "GET"
|
| 25 |
-
function._http_path = path
|
| 26 |
-
return function
|
| 27 |
-
return decorator
|
| 28 |
-
|
| 29 |
|
| 30 |
class ControllerBase(ApplicationComponentBase):
|
| 31 |
|
| 32 |
def get_routes(self: Self) -> Sequence[Route]:
|
| 33 |
-
return tuple(Route(method=method._http_method,
|
|
|
|
|
|
|
| 34 |
for name, method in getmembers(self, predicate=ismethod)
|
| 35 |
-
if name != 'get_routes' and hasattr(method, "_http_method") and hasattr(method, "_http_path"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
T = TypeVar('T', bound=ControllerBase)
|
|
@@ -40,16 +40,50 @@ T = TypeVar('T', bound=ControllerBase)
|
|
| 40 |
|
| 41 |
class ControllerRegistry:
|
| 42 |
|
| 43 |
-
|
| 44 |
|
| 45 |
@classmethod
|
| 46 |
-
def get_registry(cls) ->
|
| 47 |
import_module(__package__)
|
| 48 |
-
return tuple(cls.
|
| 49 |
|
| 50 |
@classmethod
|
| 51 |
-
def register(cls):
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
from aiohttp.web import Request, Response
|
| 3 |
+
from functools import partial, wraps
|
| 4 |
from importlib import import_module
|
| 5 |
from inspect import getmembers, ismethod
|
| 6 |
from pydantic import BaseModel, ConfigDict
|
| 7 |
+
from typing import Awaitable, Callable, ClassVar, Collection, Mapping, Optional, overload, ParamSpec, Self, Sequence, TypeVar
|
| 8 |
|
| 9 |
from ctp_slack_bot.core import ApplicationComponentBase
|
| 10 |
|
|
|
|
| 19 |
path: str
|
| 20 |
handler: AsyncHandler
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
class ControllerBase(ApplicationComponentBase):
|
| 24 |
|
| 25 |
def get_routes(self: Self) -> Sequence[Route]:
|
| 26 |
+
return tuple(Route(method=method._http_method,
|
| 27 |
+
path="/".join(filter(None, (self.prefix, method._http_path))),
|
| 28 |
+
handler=method)
|
| 29 |
for name, method in getmembers(self, predicate=ismethod)
|
| 30 |
+
if name != 'get_routes' and name != 'prefix' and hasattr(method, "_http_method") and hasattr(method, "_http_path"))
|
| 31 |
+
|
| 32 |
+
@property
|
| 33 |
+
@abstractmethod
|
| 34 |
+
def prefix(self: Self) -> str:
|
| 35 |
+
pass
|
| 36 |
|
| 37 |
|
| 38 |
T = TypeVar('T', bound=ControllerBase)
|
|
|
|
| 40 |
|
| 41 |
class ControllerRegistry:
|
| 42 |
|
| 43 |
+
__registry: ClassVar[list[T]] = []
|
| 44 |
|
| 45 |
@classmethod
|
| 46 |
+
def get_registry(cls) -> Collection[T]:
|
| 47 |
import_module(__package__)
|
| 48 |
+
return tuple(cls.__registry)
|
| 49 |
|
| 50 |
@classmethod
|
| 51 |
+
def register(cls, controller_cls: T) -> None:
|
| 52 |
+
cls.__registry.append(controller_cls)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@overload
|
| 56 |
+
def controller(cls: T) -> T: ...
|
| 57 |
+
|
| 58 |
+
@overload
|
| 59 |
+
def controller(prefix: str = "/") -> Callable[[T], T]: ...
|
| 60 |
+
|
| 61 |
+
def controller(cls_or_prefix=None):
|
| 62 |
+
def implement_prefix_property_and_register_controller(cls: T, prefix: Optional[str] = "/") -> T:
|
| 63 |
+
def prefix_getter(self: T) -> str:
|
| 64 |
+
return prefix
|
| 65 |
+
setattr(cls, 'prefix', property(prefix_getter))
|
| 66 |
+
if hasattr(cls, '__abstractmethods__'):
|
| 67 |
+
cls.__abstractmethods__ = frozenset(method for method in cls.__abstractmethods__ if method != 'prefix')
|
| 68 |
+
ControllerRegistry.register(cls)
|
| 69 |
+
return cls
|
| 70 |
+
if isinstance(cls_or_prefix, type):
|
| 71 |
+
return implement_prefix_property_and_register_controller(cls_or_prefix)
|
| 72 |
+
def decorator(cls: T) -> T:
|
| 73 |
+
return implement_prefix_property_and_register_controller(cls, cls_or_prefix)
|
| 74 |
+
return decorator
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def route(method: str, path: str = "") -> Callable[[AsyncHandler], AsyncHandler]:
|
| 78 |
+
def decorator(function: AsyncHandler) -> AsyncHandler:
|
| 79 |
+
function._http_method = method
|
| 80 |
+
function._http_path = path
|
| 81 |
+
return function
|
| 82 |
+
return decorator
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
get = partial(route, "GET")
|
| 86 |
+
post = partial(route, "POST")
|
| 87 |
+
put = partial(route, "PUT")
|
| 88 |
+
delete = partial(route, "DELETE")
|
| 89 |
+
patch = partial(route, "PATCH")
|
src/ctp_slack_bot/db/mongo_db.py
CHANGED
|
@@ -65,30 +65,16 @@ class MongoDB(HealthReportingApplicationComponentBase):
|
|
| 65 |
"""
|
| 66 |
Get a collection by name or creates it if it doesn’t exist.
|
| 67 |
"""
|
| 68 |
-
# First ensure we can connect at all.
|
| 69 |
-
if not await self.ping():
|
| 70 |
-
logger.error("Cannot get collection '{}' because a MongoDB connection is not available.", name)
|
| 71 |
-
raise ConnectionError("MongoDB connection is not available.")
|
| 72 |
-
|
| 73 |
try:
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
if name not in collection_names:
|
| 79 |
-
logger.info("Collection '{}' does not exist. Creating it…", name)
|
| 80 |
-
|
| 81 |
-
# Create the collection.
|
| 82 |
-
await self._db.create_collection(name)
|
| 83 |
-
logger.debug("Successfully created collection: {}", name)
|
| 84 |
else:
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
# Get and return the collection.
|
| 88 |
-
collection = self._db[name]
|
| 89 |
return collection
|
| 90 |
except Exception as e:
|
| 91 |
-
logger.error("Error accessing collection
|
| 92 |
raise e
|
| 93 |
|
| 94 |
def close(self: Self) -> None:
|
|
@@ -117,19 +103,11 @@ class MongoDBResource(AsyncResource):
|
|
| 117 |
|
| 118 |
async def _test_connection(self: Self, mongo_db: MongoDB) -> None:
|
| 119 |
"""Test MongoDB connection and log the result."""
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
else:
|
| 125 |
-
logger.error("MongoDB connection test failed!")
|
| 126 |
-
except Exception as e:
|
| 127 |
-
logger.error("Error testing MongoDB connection: {}", e)
|
| 128 |
-
raise e
|
| 129 |
|
| 130 |
async def shutdown(self: Self, mongo_db: MongoDB) -> None:
|
| 131 |
"""Close MongoDB connection on shutdown."""
|
| 132 |
-
|
| 133 |
-
mongo_db.close()
|
| 134 |
-
except Exception as e:
|
| 135 |
-
logger.error("Error closing MongoDB connection: {}", e)
|
|
|
|
| 65 |
"""
|
| 66 |
Get a collection by name or creates it if it doesn’t exist.
|
| 67 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
try:
|
| 69 |
+
if name not in await self._db.list_collection_names():
|
| 70 |
+
collection = await self._db.create_collection(name)
|
| 71 |
+
logger.debug("Created previously nonexistent collection, {}.", name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
else:
|
| 73 |
+
collection = self._db[name]
|
| 74 |
+
logger.debug("Retrieved collection, {}.", name)
|
|
|
|
|
|
|
| 75 |
return collection
|
| 76 |
except Exception as e:
|
| 77 |
+
logger.error("Error accessing collection, {}: {}", name, e)
|
| 78 |
raise e
|
| 79 |
|
| 80 |
def close(self: Self) -> None:
|
|
|
|
| 103 |
|
| 104 |
async def _test_connection(self: Self, mongo_db: MongoDB) -> None:
|
| 105 |
"""Test MongoDB connection and log the result."""
|
| 106 |
+
if await mongo_db.ping():
|
| 107 |
+
logger.info("MongoDB connection test successful!")
|
| 108 |
+
else:
|
| 109 |
+
logger.error("MongoDB connection test failed!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
async def shutdown(self: Self, mongo_db: MongoDB) -> None:
|
| 112 |
"""Close MongoDB connection on shutdown."""
|
| 113 |
+
mongo_db.close()
|
|
|
|
|
|
|
|
|
src/ctp_slack_bot/mime_type_handlers/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
-
from .base import MimeTypeHandler, MimeTypeHandlerRegistry
|
| 2 |
from .text.vtt import WebVTTMimeTypeHandler
|
|
|
|
| 1 |
+
from .base import MimeTypeHandler, mime_type_handler, MimeTypeHandlerRegistry
|
| 2 |
from .text.vtt import WebVTTMimeTypeHandler
|
src/ctp_slack_bot/mime_type_handlers/base.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
-
from abc import
|
| 2 |
from importlib import import_module
|
| 3 |
from types import MappingProxyType
|
| 4 |
-
from typing import Any, ClassVar, Mapping, Optional
|
| 5 |
|
| 6 |
from ctp_slack_bot.core import ApplicationComponentBase
|
| 7 |
from ctp_slack_bot.models import Content
|
|
@@ -14,20 +14,38 @@ class MimeTypeHandler(ApplicationComponentBase):
|
|
| 14 |
pass
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
class MimeTypeHandlerRegistry:
|
| 18 |
|
| 19 |
-
|
| 20 |
|
| 21 |
@classmethod
|
| 22 |
-
def get_registry(cls) -> Mapping[str,
|
| 23 |
import_module(__package__)
|
| 24 |
-
return MappingProxyType(cls.
|
| 25 |
|
| 26 |
@classmethod
|
| 27 |
-
def register(cls,
|
| 28 |
-
|
| 29 |
-
if mime_type in cls.
|
| 30 |
raise ValueError(f"The MIME type, {mime_type}, is already registered.")
|
| 31 |
-
cls.
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
from importlib import import_module
|
| 3 |
from types import MappingProxyType
|
| 4 |
+
from typing import Any, Callable, ClassVar, Mapping, Optional, overload, Set, TypeVar
|
| 5 |
|
| 6 |
from ctp_slack_bot.core import ApplicationComponentBase
|
| 7 |
from ctp_slack_bot.models import Content
|
|
|
|
| 14 |
pass
|
| 15 |
|
| 16 |
|
| 17 |
+
T = TypeVar('T', bound=MimeTypeHandler)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
class MimeTypeHandlerRegistry:
|
| 21 |
|
| 22 |
+
__registry: ClassVar[dict[str, T]] = {}
|
| 23 |
|
| 24 |
@classmethod
|
| 25 |
+
def get_registry(cls) -> Mapping[str, T]:
|
| 26 |
import_module(__package__)
|
| 27 |
+
return MappingProxyType(cls.__registry)
|
| 28 |
|
| 29 |
@classmethod
|
| 30 |
+
def register(cls, mime_types: Set[str], handler_cls: T):
|
| 31 |
+
for mime_type in mime_types:
|
| 32 |
+
if mime_type in cls.__registry:
|
| 33 |
raise ValueError(f"The MIME type, {mime_type}, is already registered.")
|
| 34 |
+
cls.__registry[mime_type] = handler_cls
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@overload
|
| 38 |
+
def mime_type_handler(cls: T) -> T: ...
|
| 39 |
+
|
| 40 |
+
@overload
|
| 41 |
+
def mime_type_handler(mime_types: Optional[Set[str] | str] = None) -> Callable[[T], T]: ...
|
| 42 |
+
|
| 43 |
+
def mime_type_handler(cls_or_mime_types=None):
|
| 44 |
+
def register_mime_type_handler(cls: T, mime_types: Optional[Set[str]] = None) -> T:
|
| 45 |
+
MimeTypeHandlerRegistry.register({mime_types} if isinstance(mime_types, str) else mime_types, cls)
|
| 46 |
+
return cls
|
| 47 |
+
if isinstance(cls_or_mime_types, type):
|
| 48 |
+
return register_mime_type_handler(cls_or_mime_types)
|
| 49 |
+
def decorator(cls: T) -> T:
|
| 50 |
+
return register_mime_type_handler(cls, cls_or_mime_types)
|
| 51 |
+
return decorator
|
src/ctp_slack_bot/mime_type_handlers/text/vtt.py
CHANGED
|
@@ -6,11 +6,11 @@ from types import MappingProxyType
|
|
| 6 |
from typing import Any, ClassVar, Mapping, Optional, Self
|
| 7 |
from webvtt import WebVTT
|
| 8 |
|
| 9 |
-
from ctp_slack_bot.mime_type_handlers.base import MimeTypeHandler,
|
| 10 |
from ctp_slack_bot.models import Content, WebVTTContent, WebVTTFrame
|
| 11 |
|
| 12 |
|
| 13 |
-
@
|
| 14 |
class WebVTTMimeTypeHandler(MimeTypeHandler):
|
| 15 |
|
| 16 |
model_config = ConfigDict(frozen=True)
|
|
|
|
| 6 |
from typing import Any, ClassVar, Mapping, Optional, Self
|
| 7 |
from webvtt import WebVTT
|
| 8 |
|
| 9 |
+
from ctp_slack_bot.mime_type_handlers.base import MimeTypeHandler, mime_type_handler
|
| 10 |
from ctp_slack_bot.models import Content, WebVTTContent, WebVTTFrame
|
| 11 |
|
| 12 |
|
| 13 |
+
@mime_type_handler("text/vtt")
|
| 14 |
class WebVTTMimeTypeHandler(MimeTypeHandler):
|
| 15 |
|
| 16 |
model_config = ConfigDict(frozen=True)
|
src/ctp_slack_bot/services/__init__.py
CHANGED
|
@@ -8,6 +8,6 @@ from .google_drive_service import GoogleDriveService
|
|
| 8 |
from .http_server_service import HTTPServer
|
| 9 |
from .language_model_service import LanguageModelService
|
| 10 |
from .question_dispatch_service import QuestionDispatchService
|
| 11 |
-
from. schedule_service import ScheduleService
|
| 12 |
from .slack_service import SlackService
|
|
|
|
| 13 |
from .vectorization_service import VectorizationService
|
|
|
|
| 8 |
from .http_server_service import HTTPServer
|
| 9 |
from .language_model_service import LanguageModelService
|
| 10 |
from .question_dispatch_service import QuestionDispatchService
|
|
|
|
| 11 |
from .slack_service import SlackService
|
| 12 |
+
from. task_service import TaskService
|
| 13 |
from .vectorization_service import VectorizationService
|
src/ctp_slack_bot/services/content_ingestion_service.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from loguru import logger
|
| 2 |
from pydantic import ConfigDict
|
| 3 |
from typing import Any, Self, Sequence, Set
|
|
@@ -18,15 +19,9 @@ class ContentIngestionService(ApplicationComponentBase):
|
|
| 18 |
model_config = ConfigDict(frozen=True)
|
| 19 |
|
| 20 |
settings: Settings
|
| 21 |
-
event_brokerage_service: EventBrokerageService
|
| 22 |
vectorized_chunk_repository: VectorizedChunkRepository
|
| 23 |
vectorization_service: VectorizationService
|
| 24 |
|
| 25 |
-
def model_post_init(self: Self, context: Any, /) -> None:
|
| 26 |
-
super().model_post_init(context)
|
| 27 |
-
self.event_brokerage_service.subscribe(EventType.INCOMING_CONTENT, self.process_incoming_content)
|
| 28 |
-
# self.event_brokerage_service.subscribe(EventType.INCOMING_SLACK_MESSAGE, self.process_incoming_slack_message)
|
| 29 |
-
|
| 30 |
async def process_incoming_content(self: Self, content: Content) -> None:
|
| 31 |
logger.debug("Content ingestion service received content with metadata: {}", content.get_metadata())
|
| 32 |
if self.vectorized_chunk_repository.count_by_id(content.get_id()):
|
|
@@ -49,3 +44,11 @@ class ContentIngestionService(ApplicationComponentBase):
|
|
| 49 |
@property
|
| 50 |
def name(self: Self) -> str:
|
| 51 |
return "content_ingestion_service"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dependency_injector.resources import AsyncResource
|
| 2 |
from loguru import logger
|
| 3 |
from pydantic import ConfigDict
|
| 4 |
from typing import Any, Self, Sequence, Set
|
|
|
|
| 19 |
model_config = ConfigDict(frozen=True)
|
| 20 |
|
| 21 |
settings: Settings
|
|
|
|
| 22 |
vectorized_chunk_repository: VectorizedChunkRepository
|
| 23 |
vectorization_service: VectorizationService
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
async def process_incoming_content(self: Self, content: Content) -> None:
|
| 26 |
logger.debug("Content ingestion service received content with metadata: {}", content.get_metadata())
|
| 27 |
if self.vectorized_chunk_repository.count_by_id(content.get_id()):
|
|
|
|
| 44 |
@property
|
| 45 |
def name(self: Self) -> str:
|
| 46 |
return "content_ingestion_service"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ContentIngestionServiceResource(AsyncResource):
|
| 50 |
+
async def init(self: Self, settings: Settings, event_brokerage_service: EventBrokerageService, vectorized_chunk_repository: VectorizedChunkRepository, vectorization_service: VectorizationService) -> ContentIngestionService:
|
| 51 |
+
content_ingestion_service = ContentIngestionService(settings=settings, vectorized_chunk_repository=vectorized_chunk_repository, vectorization_service=vectorization_service)
|
| 52 |
+
await event_brokerage_service.subscribe(EventType.INCOMING_CONTENT, content_ingestion_service.process_incoming_content)
|
| 53 |
+
# await event_brokerage_service.subscribe(EventType.INCOMING_SLACK_MESSAGE, content_ingestion_service.process_incoming_slack_message)
|
| 54 |
+
return content_ingestion_service
|
src/ctp_slack_bot/services/event_brokerage_service.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from asyncio import create_task, iscoroutinefunction, to_thread
|
| 2 |
from collections import defaultdict
|
| 3 |
from loguru import logger
|
| 4 |
from pydantic import ConfigDict, PrivateAttr
|
|
@@ -15,18 +15,19 @@ class EventBrokerageService(ApplicationComponentBase):
|
|
| 15 |
|
| 16 |
model_config = ConfigDict(frozen=True)
|
| 17 |
|
| 18 |
-
|
|
|
|
| 19 |
|
| 20 |
-
def subscribe(self: Self, type: EventType, callback: Callable) -> None:
|
| 21 |
"""Subscribe to an event type with a callback function."""
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
|
| 27 |
async def publish(self: Self, type: EventType, data: Any = None) -> None:
|
| 28 |
"""Publish an event with optional data to all subscribers."""
|
| 29 |
-
subscribers = self.
|
| 30 |
if not subscribers:
|
| 31 |
logger.debug("No subscribers handle event {}: {}", type, len(subscribers), data)
|
| 32 |
return
|
|
|
|
| 1 |
+
from asyncio import create_task, iscoroutinefunction, Lock, to_thread
|
| 2 |
from collections import defaultdict
|
| 3 |
from loguru import logger
|
| 4 |
from pydantic import ConfigDict, PrivateAttr
|
|
|
|
| 15 |
|
| 16 |
model_config = ConfigDict(frozen=True)
|
| 17 |
|
| 18 |
+
__write_lock: Lock = PrivateAttr(default_factory=Lock)
|
| 19 |
+
__subscribers: MutableMapping[EventType, tuple[Callable]] = PrivateAttr(default_factory=lambda: defaultdict(tuple))
|
| 20 |
|
| 21 |
+
async def subscribe(self: Self, type: EventType, callback: Callable) -> None:
|
| 22 |
"""Subscribe to an event type with a callback function."""
|
| 23 |
+
async with self.__write_lock:
|
| 24 |
+
subscribers = self.__subscribers[type]
|
| 25 |
+
self.__subscribers[type] = subscribers + (callback, )
|
| 26 |
+
logger.debug("One new subscriber was added for event type {} ({} subscriber(s) in total).", type, len(subscribers))
|
| 27 |
|
| 28 |
async def publish(self: Self, type: EventType, data: Any = None) -> None:
|
| 29 |
"""Publish an event with optional data to all subscribers."""
|
| 30 |
+
subscribers = self.__subscribers[type]
|
| 31 |
if not subscribers:
|
| 32 |
logger.debug("No subscribers handle event {}: {}", type, len(subscribers), data)
|
| 33 |
return
|
src/ctp_slack_bot/services/google_drive_service.py
CHANGED
|
@@ -39,7 +39,6 @@ class GoogleDriveService(ApplicationComponentBase):
|
|
| 39 |
"token_uri": self.settings.google_token_uri,
|
| 40 |
}, scopes=["https://www.googleapis.com/auth/drive"])
|
| 41 |
self._google_drive_client = build('drive', 'v3', credentials=credentials)
|
| 42 |
-
logger.info(type(self._google_drive_client))
|
| 43 |
|
| 44 |
def _resolve_folder_id(self: Self, folder_path: str) -> Optional[str]:
|
| 45 |
"""Resolve a folder path to a Google Drive ID."""
|
|
|
|
| 39 |
"token_uri": self.settings.google_token_uri,
|
| 40 |
}, scopes=["https://www.googleapis.com/auth/drive"])
|
| 41 |
self._google_drive_client = build('drive', 'v3', credentials=credentials)
|
|
|
|
| 42 |
|
| 43 |
def _resolve_folder_id(self: Self, folder_path: str) -> Optional[str]:
|
| 44 |
"""Resolve a folder path to a Google Drive ID."""
|
src/ctp_slack_bot/services/question_dispatch_service.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from loguru import logger
|
| 2 |
from pydantic import ConfigDict
|
| 3 |
from typing import Any, Self
|
|
@@ -18,15 +19,10 @@ class QuestionDispatchService(ApplicationComponentBase):
|
|
| 18 |
model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
|
| 19 |
|
| 20 |
settings: Settings
|
| 21 |
-
event_brokerage_service: EventBrokerageService
|
| 22 |
context_retrieval_service: ContextRetrievalService
|
| 23 |
answer_retrieval_service: AnswerRetrievalService
|
| 24 |
|
| 25 |
-
def
|
| 26 |
-
super().model_post_init(context)
|
| 27 |
-
self.event_brokerage_service.subscribe(EventType.INCOMING_SLACK_MESSAGE, self.__process_incoming_slack_message)
|
| 28 |
-
|
| 29 |
-
async def __process_incoming_slack_message(self: Self, message: SlackMessage) -> None:
|
| 30 |
if message.subtype != 'bot_message':
|
| 31 |
logger.debug("Question dispatch service received an answerable question: {}", message.text)
|
| 32 |
context = await self.context_retrieval_service.get_context(message)
|
|
@@ -35,3 +31,10 @@ class QuestionDispatchService(ApplicationComponentBase):
|
|
| 35 |
@property
|
| 36 |
def name(self: Self) -> str:
|
| 37 |
return "question_dispatch_service"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dependency_injector.resources import AsyncResource
|
| 2 |
from loguru import logger
|
| 3 |
from pydantic import ConfigDict
|
| 4 |
from typing import Any, Self
|
|
|
|
| 19 |
model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
|
| 20 |
|
| 21 |
settings: Settings
|
|
|
|
| 22 |
context_retrieval_service: ContextRetrievalService
|
| 23 |
answer_retrieval_service: AnswerRetrievalService
|
| 24 |
|
| 25 |
+
async def process_incoming_slack_message(self: Self, message: SlackMessage) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
if message.subtype != 'bot_message':
|
| 27 |
logger.debug("Question dispatch service received an answerable question: {}", message.text)
|
| 28 |
context = await self.context_retrieval_service.get_context(message)
|
|
|
|
| 31 |
@property
|
| 32 |
def name(self: Self) -> str:
|
| 33 |
return "question_dispatch_service"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class QuestionDispatchServiceResource(AsyncResource):
|
| 37 |
+
async def init(self: Self, settings: Settings, event_brokerage_service: EventBrokerageService, context_retrieval_service: ContextRetrievalService, answer_retrieval_service: AnswerRetrievalService) -> QuestionDispatchService:
|
| 38 |
+
question_dispatch_service = QuestionDispatchService(settings=settings, context_retrieval_service=context_retrieval_service, answer_retrieval_service=answer_retrieval_service)
|
| 39 |
+
await event_brokerage_service.subscribe(EventType.INCOMING_SLACK_MESSAGE, question_dispatch_service.process_incoming_slack_message)
|
| 40 |
+
return question_dispatch_service
|
src/ctp_slack_bot/services/slack_service.py
CHANGED
|
@@ -2,14 +2,15 @@ from dependency_injector.resources import AsyncResource
|
|
| 2 |
from httpx import AsyncClient
|
| 3 |
from loguru import logger
|
| 4 |
from openai import OpenAI
|
| 5 |
-
from pydantic import ConfigDict
|
| 6 |
from re import compile as compile_re, Pattern
|
| 7 |
-
from
|
| 8 |
from slack_bolt.async_app import AsyncApp
|
|
|
|
| 9 |
from slack_sdk.web.async_slack_response import AsyncSlackResponse
|
| 10 |
from typing import Any, ClassVar, Mapping, MutableMapping, Optional, Self, Set
|
| 11 |
|
| 12 |
-
from ctp_slack_bot.core import HealthReportingApplicationComponentBase
|
| 13 |
from ctp_slack_bot.enums import EventType
|
| 14 |
from ctp_slack_bot.models import SlackMessage, SlackResponse
|
| 15 |
from .event_brokerage_service import EventBrokerageService
|
|
@@ -25,17 +26,54 @@ class SlackService(HealthReportingApplicationComponentBase):
|
|
| 25 |
_SLACK_USER_ID_PATTERN: ClassVar[Pattern] = compile_re(r"U\d+")
|
| 26 |
_SLACK_USER_MENTION_PATTERN: ClassVar[Pattern] = compile_re(r"<@(U[A-Z0-9]+)>")
|
| 27 |
|
|
|
|
| 28 |
event_brokerage_service: EventBrokerageService
|
| 29 |
http_client: AsyncClient
|
| 30 |
slack_bolt_app: AsyncApp
|
| 31 |
-
|
|
|
|
| 32 |
|
| 33 |
-
def initialize(self: Self) -> None:
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
self.slack_bolt_app.event("message")(self._handle_message_event)
|
| 36 |
self.slack_bolt_app.event("app_mention")(self._handle_app_mention_event)
|
| 37 |
logger.debug("Registered 2 handlers for Slack Bolt message and app mention events.")
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
async def send_message(self: Self, message: SlackResponse) -> None:
|
| 40 |
await self.slack_bolt_app.client.chat_postMessage(channel=message.channel, text=message.text, thread_ts=message.thread_ts)
|
| 41 |
|
|
@@ -64,16 +102,16 @@ class SlackService(HealthReportingApplicationComponentBase):
|
|
| 64 |
)
|
| 65 |
|
| 66 |
async def _ensure_ids_in_id_name_map(self: Self, ids: Set[str]) -> None:
|
| 67 |
-
unknown_ids = ids - self.
|
| 68 |
if len(unknown_ids) == 0:
|
| 69 |
return
|
| 70 |
async with TaskGroup() as task_group:
|
| 71 |
update_tasks = {unknown_id: task_group.create_task(self._look_up_name(unknown_id)) for unknown_id in unknown_ids}
|
| 72 |
-
self.
|
| 73 |
|
| 74 |
async def _get_name(self: Self, id: str) -> str:
|
| 75 |
await self._ensure_ids_in_id_name_map({id})
|
| 76 |
-
return self.
|
| 77 |
|
| 78 |
async def _handle_message_event(self: Self, body: Mapping[str, Any]) -> None:
|
| 79 |
logger.debug("Ignored regular message: {}", body.get("event", {}).get("text"))
|
|
@@ -113,46 +151,13 @@ class SlackService(HealthReportingApplicationComponentBase):
|
|
| 113 |
start, end = match.span()
|
| 114 |
parts.append(text[previous_end:start])
|
| 115 |
user_id = match.group(1)
|
| 116 |
-
parts.append(f"@{self.
|
| 117 |
previous_end = end
|
| 118 |
parts.append(text[previous_end:])
|
| 119 |
return ''.join(parts)
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
async def init(self: Self, event_brokerage_service: EventBrokerageService, http_client: AsyncClient, slack_bolt_app: AsyncApp) -> SlackService:
|
| 124 |
-
async def get_users_list():
|
| 125 |
-
cursor = None
|
| 126 |
-
while True:
|
| 127 |
-
try:
|
| 128 |
-
response = await slack_bolt_app.client.users_list(cursor=cursor, limit=200)
|
| 129 |
-
except SlackApiError as e:
|
| 130 |
-
logger.warning("Could not get a list of users: {}", e)
|
| 131 |
-
break
|
| 132 |
-
match response:
|
| 133 |
-
case AsyncSlackResponse(status_code=200, data={"ok": True, "members": users}):
|
| 134 |
-
for user in users:
|
| 135 |
-
yield user
|
| 136 |
-
match response.data:
|
| 137 |
-
case {"response_metadata": {"next_cursor": cursor}} if cursor:
|
| 138 |
-
continue
|
| 139 |
-
case AsyncSlackResponse(status_code=status_code) if status_code != 200:
|
| 140 |
-
logger.warning("Could not get a list of users: response status {}", status_code)
|
| 141 |
-
case AsyncSlackResponse(data={"ok": False}):
|
| 142 |
-
logger.warning("Could not get a list of users: non-OK response")
|
| 143 |
-
case _:
|
| 144 |
-
logger.warning("Could not get a list of users.")
|
| 145 |
-
break
|
| 146 |
-
id_name_map = {user["id"]: self._get_name(user)
|
| 147 |
-
async for user
|
| 148 |
-
in get_users_list()}
|
| 149 |
-
logger.debug("Obtained a list of {} user name(s) for the workspace: {}", len(id_name_map), id_name_map)
|
| 150 |
-
slack_service = SlackService(event_brokerage_service=event_brokerage_service, http_client=http_client, slack_bolt_app=slack_bolt_app, id_name_map=id_name_map)
|
| 151 |
-
slack_service.initialize()
|
| 152 |
-
return slack_service
|
| 153 |
-
|
| 154 |
-
@classmethod
|
| 155 |
-
def _get_name(cls, user: Mapping[str, Any]):
|
| 156 |
match user:
|
| 157 |
case {"real_name": real_name}:
|
| 158 |
return real_name
|
|
@@ -160,3 +165,12 @@ class SlackServiceResource(AsyncResource):
|
|
| 160 |
return display_name
|
| 161 |
case {"name": name}:
|
| 162 |
return name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from httpx import AsyncClient
|
| 3 |
from loguru import logger
|
| 4 |
from openai import OpenAI
|
| 5 |
+
from pydantic import ConfigDict, PrivateAttr
|
| 6 |
from re import compile as compile_re, Pattern
|
| 7 |
+
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
|
| 8 |
from slack_bolt.async_app import AsyncApp
|
| 9 |
+
from slack_sdk.errors import SlackApiError
|
| 10 |
from slack_sdk.web.async_slack_response import AsyncSlackResponse
|
| 11 |
from typing import Any, ClassVar, Mapping, MutableMapping, Optional, Self, Set
|
| 12 |
|
| 13 |
+
from ctp_slack_bot.core import HealthReportingApplicationComponentBase, Settings
|
| 14 |
from ctp_slack_bot.enums import EventType
|
| 15 |
from ctp_slack_bot.models import SlackMessage, SlackResponse
|
| 16 |
from .event_brokerage_service import EventBrokerageService
|
|
|
|
| 26 |
_SLACK_USER_ID_PATTERN: ClassVar[Pattern] = compile_re(r"U\d+")
|
| 27 |
_SLACK_USER_MENTION_PATTERN: ClassVar[Pattern] = compile_re(r"<@(U[A-Z0-9]+)>")
|
| 28 |
|
| 29 |
+
settings: Settings
|
| 30 |
event_brokerage_service: EventBrokerageService
|
| 31 |
http_client: AsyncClient
|
| 32 |
slack_bolt_app: AsyncApp
|
| 33 |
+
socket_mode_handler: AsyncSocketModeHandler
|
| 34 |
+
_id_name_map: MutableMapping[str, str] = PrivateAttr(default={}) # TODO: Spin message processing out into its own service.
|
| 35 |
|
| 36 |
+
async def initialize(self: Self) -> None:
|
| 37 |
+
async def get_users_list():
|
| 38 |
+
cursor = None
|
| 39 |
+
while True:
|
| 40 |
+
try:
|
| 41 |
+
response = await self.slack_bolt_app.client.users_list(cursor=cursor, limit=200)
|
| 42 |
+
except SlackApiError as e:
|
| 43 |
+
logger.warning("Could not get a list of users: {}", e)
|
| 44 |
+
break
|
| 45 |
+
match response:
|
| 46 |
+
case AsyncSlackResponse(status_code=200, data={"ok": True, "members": users}):
|
| 47 |
+
for user in users:
|
| 48 |
+
yield user
|
| 49 |
+
match response.data:
|
| 50 |
+
case {"response_metadata": {"next_cursor": cursor}} if cursor:
|
| 51 |
+
continue
|
| 52 |
+
case AsyncSlackResponse(status_code=status_code) if status_code != 200:
|
| 53 |
+
logger.warning("Could not get a list of users: response status {}", status_code)
|
| 54 |
+
case AsyncSlackResponse(data={"ok": False}):
|
| 55 |
+
logger.warning("Could not get a list of users: non-OK response")
|
| 56 |
+
case _:
|
| 57 |
+
logger.warning("Could not get a list of users.")
|
| 58 |
+
break
|
| 59 |
+
id_name_map = {user["id"]: self._resolve_user_name(user)
|
| 60 |
+
async for user
|
| 61 |
+
in get_users_list()}
|
| 62 |
+
self._id_name_map.update(id_name_map)
|
| 63 |
+
logger.debug("Obtained a list of {} user name(s) for the workspace: {}", len(id_name_map), id_name_map)
|
| 64 |
+
|
| 65 |
+
await self.event_brokerage_service.subscribe(EventType.OUTGOING_SLACK_RESPONSE, self.send_message)
|
| 66 |
self.slack_bolt_app.event("message")(self._handle_message_event)
|
| 67 |
self.slack_bolt_app.event("app_mention")(self._handle_app_mention_event)
|
| 68 |
logger.debug("Registered 2 handlers for Slack Bolt message and app mention events.")
|
| 69 |
|
| 70 |
+
async def start(self: Self) -> None:
|
| 71 |
+
await self.socket_mode_handler.start_async()
|
| 72 |
+
|
| 73 |
+
async def stop(self: Self) -> None:
|
| 74 |
+
await self.socket_mode_handler.close_async()
|
| 75 |
+
logger.info("Stopped Slack Bolt socket mode handler and Slack service.")
|
| 76 |
+
|
| 77 |
async def send_message(self: Self, message: SlackResponse) -> None:
|
| 78 |
await self.slack_bolt_app.client.chat_postMessage(channel=message.channel, text=message.text, thread_ts=message.thread_ts)
|
| 79 |
|
|
|
|
| 102 |
)
|
| 103 |
|
| 104 |
async def _ensure_ids_in_id_name_map(self: Self, ids: Set[str]) -> None:
|
| 105 |
+
unknown_ids = ids - self._id_name_map.keys()
|
| 106 |
if len(unknown_ids) == 0:
|
| 107 |
return
|
| 108 |
async with TaskGroup() as task_group:
|
| 109 |
update_tasks = {unknown_id: task_group.create_task(self._look_up_name(unknown_id)) for unknown_id in unknown_ids}
|
| 110 |
+
self._id_name_map.update({id: task.result() for id, task in update_tasks.items() if task.result()})
|
| 111 |
|
| 112 |
async def _get_name(self: Self, id: str) -> str:
|
| 113 |
await self._ensure_ids_in_id_name_map({id})
|
| 114 |
+
return self._id_name_map.get(id, id)
|
| 115 |
|
| 116 |
async def _handle_message_event(self: Self, body: Mapping[str, Any]) -> None:
|
| 117 |
logger.debug("Ignored regular message: {}", body.get("event", {}).get("text"))
|
|
|
|
| 151 |
start, end = match.span()
|
| 152 |
parts.append(text[previous_end:start])
|
| 153 |
user_id = match.group(1)
|
| 154 |
+
parts.append(f"@{self._id_name_map.get(user_id, user_id)}")
|
| 155 |
previous_end = end
|
| 156 |
parts.append(text[previous_end:])
|
| 157 |
return ''.join(parts)
|
| 158 |
|
| 159 |
+
@staticmethod
|
| 160 |
+
def _resolve_user_name(user: Mapping[str, Any]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
match user:
|
| 162 |
case {"real_name": real_name}:
|
| 163 |
return real_name
|
|
|
|
| 165 |
return display_name
|
| 166 |
case {"name": name}:
|
| 167 |
return name
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class SlackServiceResource(AsyncResource):
|
| 171 |
+
async def init(self: Self, settings: Settings, event_brokerage_service: EventBrokerageService, http_client: AsyncClient) -> SlackService:
|
| 172 |
+
slack_bolt_app = AsyncApp(token=settings.slack_bot_token.get_secret_value())
|
| 173 |
+
socket_mode_handler = AsyncSocketModeHandler(slack_bolt_app, settings.slack_app_token.get_secret_value())
|
| 174 |
+
slack_service = SlackService(settings=settings, event_brokerage_service=event_brokerage_service, http_client=http_client, slack_bolt_app=slack_bolt_app, socket_mode_handler=socket_mode_handler)
|
| 175 |
+
await slack_service.initialize()
|
| 176 |
+
return slack_service
|
src/ctp_slack_bot/services/{schedule_service.py → task_service.py}
RENAMED
|
@@ -1,17 +1,16 @@
|
|
| 1 |
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
| 2 |
from apscheduler.triggers.cron import CronTrigger
|
| 3 |
-
from asyncio import create_task, iscoroutinefunction, to_thread
|
| 4 |
from datetime import datetime
|
| 5 |
-
from dependency_injector.resources import
|
| 6 |
from loguru import logger
|
| 7 |
from pydantic import ConfigDict
|
| 8 |
from pytz import timezone
|
| 9 |
-
from typing import Any,
|
| 10 |
|
| 11 |
from ctp_slack_bot.core import ApplicationComponentBase, Settings
|
| 12 |
|
| 13 |
|
| 14 |
-
class
|
| 15 |
"""
|
| 16 |
Service for running scheduled tasks.
|
| 17 |
"""
|
|
@@ -43,28 +42,25 @@ class ScheduleService(ApplicationComponentBase):
|
|
| 43 |
# )
|
| 44 |
pass
|
| 45 |
|
| 46 |
-
def start(self: Self) -> None:
|
|
|
|
| 47 |
self._scheduler.start()
|
| 48 |
|
| 49 |
-
def stop(self: Self) -> None:
|
| 50 |
if self._scheduler.running:
|
| 51 |
self._scheduler.shutdown()
|
|
|
|
| 52 |
else:
|
| 53 |
logger.debug("The scheduler is not running. There is no scheduler to shut down.")
|
| 54 |
|
| 55 |
@property
|
| 56 |
def name(self: Self) -> str:
|
| 57 |
-
return "
|
| 58 |
|
| 59 |
|
| 60 |
-
class
|
| 61 |
-
def init(self: Self, settings: Settings) ->
|
| 62 |
-
|
| 63 |
-
schedule_service = ScheduleService(settings=settings)
|
| 64 |
-
schedule_service.start()
|
| 65 |
-
return schedule_service
|
| 66 |
|
| 67 |
-
def shutdown(self: Self,
|
| 68 |
-
|
| 69 |
-
schedule_service.stop()
|
| 70 |
-
logger.info("Stopped scheduler.")
|
|
|
|
| 1 |
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
| 2 |
from apscheduler.triggers.cron import CronTrigger
|
|
|
|
| 3 |
from datetime import datetime
|
| 4 |
+
from dependency_injector.resources import AsyncResource
|
| 5 |
from loguru import logger
|
| 6 |
from pydantic import ConfigDict
|
| 7 |
from pytz import timezone
|
| 8 |
+
from typing import Any, Self
|
| 9 |
|
| 10 |
from ctp_slack_bot.core import ApplicationComponentBase, Settings
|
| 11 |
|
| 12 |
|
| 13 |
+
class TaskService(ApplicationComponentBase):
|
| 14 |
"""
|
| 15 |
Service for running scheduled tasks.
|
| 16 |
"""
|
|
|
|
| 42 |
# )
|
| 43 |
pass
|
| 44 |
|
| 45 |
+
async def start(self: Self) -> None:
|
| 46 |
+
logger.info("Starting scheduler…")
|
| 47 |
self._scheduler.start()
|
| 48 |
|
| 49 |
+
async def stop(self: Self) -> None:
|
| 50 |
if self._scheduler.running:
|
| 51 |
self._scheduler.shutdown()
|
| 52 |
+
logger.info("Stopped scheduler.")
|
| 53 |
else:
|
| 54 |
logger.debug("The scheduler is not running. There is no scheduler to shut down.")
|
| 55 |
|
| 56 |
@property
|
| 57 |
def name(self: Self) -> str:
|
| 58 |
+
return "task_service"
|
| 59 |
|
| 60 |
|
| 61 |
+
class TaskServiceResource(AsyncResource):
|
| 62 |
+
async def init(self: Self, settings: Settings) -> TaskService:
|
| 63 |
+
return TaskService(settings=settings)
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
+
async def shutdown(self: Self, task_service: TaskService) -> None:
|
| 66 |
+
await task_service.stop()
|
|
|
|
|
|