diff --git a/backend/app/hashing.py b/backend/app/hashing.py index ac6e145852258bc2df0e576822f2db0db52241ce..c2f7e79daf188e979f8c82d108b727e117f84685 100644 --- a/backend/app/hashing.py +++ b/backend/app/hashing.py @@ -75,5 +75,7 @@ def get_jwt_user(token: str = Depends(reuseable_oauth), db: Session = Depends(ge return db_user + + def get_jwt_user_from_refresh_token(token: str = Depends(reuseable_refresh_oauth), db: Session = Depends(get_db)): return get_jwt_user(token, db) diff --git a/backend/app/main.py b/backend/app/main.py index f0d37578af8f88dcd0cd22e1227725da91d168f8..529db91996b6e94591c6696e6034f6d361b75247 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -227,17 +227,17 @@ def read_messages(session_id: int, skip: int = 0, limit: int = 100, db: Session return crud.get_messages(db, session_id, skip=skip, limit=limit) -def send_websoket_message(session_id: int, message: schemas.Message): +async def send_websoket_message(session_id: int, message: schemas.Message): content = json.dumps({ "type": "message", "action": "create", - "data": message + "data": message.to_dict() }) for _, user_websockets in websocket_users[session_id].items(): for user_websocket in user_websockets: - user_websocket.send_text(content) + await user_websocket.send_text(content) @sessionsRouter.post("/{session_id}/messages", status_code=status.HTTP_201_CREATED) def create_message(session_id: int, message: schemas.MessageCreate, background_tasks: BackgroundTasks, db: Session = Depends(get_db), current_user: schemas.User = Depends(hashing.get_jwt_user)): @@ -250,13 +250,18 @@ def create_message(session_id: int, message: schemas.MessageCreate, background_t message = crud.create_message(db, message, current_user, db_session) - background_tasks.add_task(send_websoket_message, session_id, message) + background_tasks.add_task(send_websoket_message, session_id, schemas.Message.model_validate(message)) return message.id -@websocketRouter.websocket("/{session_id}") -async def websocket_session(session_id: int, websocket: WebSocket, db: Session = Depends(get_db), current_user: schemas.User = Depends(hashing.get_jwt_user)): +@websocketRouter.websocket("/{token}/{session_id}") +async def websocket_session(token: str, session_id: int, websocket: WebSocket, db: Session = Depends(get_db)): + current_user = hashing.get_jwt_user(token=token, db=db) + + if current_user is None: + raise HTTPException(status_code=401, detail="Invalid token") + db_session = crud.get_session(db, session_id) if db_session is None: raise HTTPException(status_code=404, detail="Session not found") diff --git a/backend/app/schemas.py b/backend/app/schemas.py index fd241af74df751e52a103fc6c944b4b33747a373..88e7963b4a31d017203c4fc1c40b410be1f7d10b 100644 --- a/backend/app/schemas.py +++ b/backend/app/schemas.py @@ -63,6 +63,15 @@ class Message(BaseModel): class Config: from_attributes = True + def to_dict(self): + return { + "id": self.id, + "content": self.content, + "user_id": self.user_id, + "session_id": self.session_id, + "created_at": self.created_at.isoformat(), + } + class MessageCreate(BaseModel): content: str diff --git a/frontend/src/lib/components/sessions/chatbox.svelte b/frontend/src/lib/components/sessions/chatbox.svelte index 540c4d119286879df3117f087a8c7e85b53158ab..9a9af5e8c3404b500a3015c0a23b31b61b81b996 100644 --- a/frontend/src/lib/components/sessions/chatbox.svelte +++ b/frontend/src/lib/components/sessions/chatbox.svelte @@ -6,10 +6,16 @@ import { Icon, PaperAirplane } from 'svelte-hero-icons'; import { toastAlert } from '$lib/utils/toasts'; import { get } from 'svelte/store'; + import Message from '$lib/types/message'; let message = ''; export let session: Session; let htmlMessages: HTMLElement; + let messages = get(session.messages); + + session.messages.subscribe((newMessages) => { + messages = newMessages; + }); onMount(async () => { await session.loadMessages(); @@ -35,7 +41,7 @@ <div class="flex flex-col md:my-8 min-w-fit w-full max-w-4xl border-2"> <div class="flex-grow h-48 overflow-auto flex-col-reverse px-4 flex" bind:this={htmlMessages}> - {#each get(session.messages).sort((a, b) => b.created_at.getTime() - a.created_at.getTime()) as message (message.id)} + {#each messages.sort((a, b) => b.created_at.getTime() - a.created_at.getTime()) as message (message.id)} <MessageC {message} /> {/each} </div> diff --git a/frontend/src/lib/components/sessions/message.svelte b/frontend/src/lib/components/sessions/message.svelte index df5205713ecd546d422c235cb36b3bd8ca347d92..fc213e2fb1bd29d301826ab4e0a2c8307c974b87 100644 --- a/frontend/src/lib/components/sessions/message.svelte +++ b/frontend/src/lib/components/sessions/message.svelte @@ -19,8 +19,9 @@ <div class="w-full flex" class:justify-end={isSender}> <div - class="bg-gray-200 rounded-b-xl my-2 p-4 w-fit" + class="rounded-b-xl my-2 p-4 w-fit" class:bg-blue-200={isSender} + class:bg-gray-200={!isSender} class:rounded-tl-xl={isSender} class:rounded-tr-xl={!isSender} > diff --git a/frontend/src/lib/types/session.ts b/frontend/src/lib/types/session.ts index cb7be7f7af7699c7b4d5c99acd912ce6007c76b2..1bb480d48107592fcbb38c3b36a49e796cf6b96c 100644 --- a/frontend/src/lib/types/session.ts +++ b/frontend/src/lib/types/session.ts @@ -131,7 +131,12 @@ export default class Session { const message = new Message(id, content, new Date(), sender, this); - this._messages.update((messages) => [...messages, message]); + this._messages.update((messages) => { + if (!messages.find((m) => m.id === message.id)) { + return [...messages, message]; + } + return messages.map((m) => (m.id === message.id ? message : m)); + }); return message; } @@ -139,7 +144,9 @@ export default class Session { public wsConnect() { if (this._ws_connected) return; - this._ws = new WebSocket(`${WS_URL}/${this.id}`); + const token = localStorage.getItem('accessToken'); + + this._ws = new WebSocket(`${WS_URL}/${token}/${this.id}`); this._ws.onopen = () => { this._ws_connected = true; diff --git a/frontend/src/lib/utils/login.ts b/frontend/src/lib/utils/login.ts index fb8716830481b16ec4959ea2c62b3b6a01656854..5004c82e4b6e1642b20f3037b30fec97bcac8daa 100644 --- a/frontend/src/lib/utils/login.ts +++ b/frontend/src/lib/utils/login.ts @@ -2,7 +2,7 @@ import session from '$lib/stores/JWTSession'; export function requireLogin(): boolean { if (!session.isLoggedIn()) { - window.location.href = '/login?redirect=' + window.location.pathname; + window.location.href = '/login?redirect=' + encodeURIComponent(window.location.href); return false; } return true; diff --git a/frontend/src/routes/login/+page.svelte b/frontend/src/routes/login/+page.svelte index 5a54fa2939c5d6c0bb4b27d1c08046e49401e337..04af7768ac712d0da620dfaff8295d8c53e8f4dc 100644 --- a/frontend/src/routes/login/+page.svelte +++ b/frontend/src/routes/login/+page.svelte @@ -23,7 +23,9 @@ return; } - const redirect = new URLSearchParams(window.location.search).get('redirect') ?? '/'; + const redirect = decodeURIComponent( + new URLSearchParams(window.location.search).get('redirect') ?? '/' + ); window.location.href = redirect; } </script>