Rui Tao's Portfolio

Implementing WebSocket Gateways in NestJS

Published on
5 mins read

Basic Gateway Implementation

First, let's implement a basic WebSocket gateway:

import {
  WebSocketGateway,
  WebSocketServer,
  SubscribeMessage,
  OnGatewayConnection,
  OnGatewayDisconnect,
  WsResponse,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import { UseGuards } from '@nestjs/common';
import { WsAuthGuard } from '../guards/ws-auth.guard';
 
/**
 * Socket.IO gateway served on the `/ws` namespace.
 *
 * Tracks one socket per user id, joins every connected client to a
 * `user:<id>` room, and answers `message` events.
 *
 * NOTE(review): the user id is read from the client-supplied handshake
 * (`handshake.auth.userId`), not from a verified token. Confirm whether
 * `WsAuthGuard` runs before `handleConnection`; if it does not, the id used
 * here is unauthenticated.
 */
@WebSocketGateway({
  cors: {
    origin: process.env.CLIENT_URL,
    credentials: true
  },
  namespace: '/ws'
})
@UseGuards(WsAuthGuard)
export class AppGateway implements OnGatewayConnection, OnGatewayDisconnect {
  @WebSocketServer()
  server!: Server;

  /**
   * Most recently connected socket per user id. A second connection from the
   * same user overwrites the first one.
   */
  private readonly connectedClients: Map<string, Socket> = new Map();

  /**
   * Registers a newly connected client and joins it to its `user:<id>` room.
   * Clients that send no `userId` in the handshake are disconnected.
   *
   * @param client The socket that just connected.
   */
  async handleConnection(client: Socket): Promise<void> {
    try {
      const userId = client.handshake.auth.userId;
      if (!userId) {
        client.disconnect();
        return;
      }

      this.connectedClients.set(userId, client);

      // Per-user room so events can be addressed by user id.
      await client.join(`user:${userId}`);

      console.log(`Client connected: ${userId}`);
    } catch (error: unknown) {
      console.error('Connection error:', error);
      client.disconnect();
    }
  }

  /**
   * Forgets a client once it has disconnected.
   *
   * @param client The socket that disconnected.
   */
  handleDisconnect(client: Socket): void {
    try {
      const userId = client.handshake.auth.userId;
      if (!userId) {
        return;
      }

      // Only drop the entry when it still refers to this socket. If the user
      // reconnected (or opened another tab) the map already points at the
      // newer socket, and deleting it would un-track a live connection.
      if (this.connectedClients.get(userId) === client) {
        this.connectedClients.delete(userId);
      }
      console.log(`Client disconnected: ${userId}`);
    } catch (error: unknown) {
      console.error('Disconnection error:', error);
    }
  }

  /**
   * Handles a `message` event: processes the payload, optionally re-emits the
   * result to the sender's `user:<id>` room, and acknowledges the sender.
   *
   * @param client  The socket that sent the message.
   * @param payload Message body; a truthy `broadcast` field requests the room emit.
   * @returns A `message` response carrying the processed result, or an `error`
   *          response carrying the failure message.
   */
  @SubscribeMessage('message')
  async handleMessage(client: Socket, payload: any): Promise<WsResponse<any>> {
    try {
      const userId = client.handshake.auth.userId;

      const result = await this.processMessage(payload);

      if (payload.broadcast) {
        this.server.to(`user:${userId}`).emit('message', result);
      }

      return { event: 'message', data: result };
    } catch (error: unknown) {
      return { event: 'error', data: { message: AppGateway.describeError(error) } };
    }
  }

  /**
   * Placeholder for message processing; currently returns the payload unchanged.
   */
  private async processMessage(payload: any): Promise<any> {
    // Add your message processing logic here
    return payload;
  }

  /** Extracts a printable message from a value caught in a `catch` clause. */
  private static describeError(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }
}

Connection Manager Service

Implement a service to manage WebSocket connections:

import { Injectable, OnModuleInit } from '@nestjs/common';
import { Server, Socket } from 'socket.io';
import { WebSocketServer } from '@nestjs/websockets';
 
/**
 * In-memory registry of the sockets that belong to each user.
 *
 * A user may hold several sockets at once. While the module is alive, every
 * registered socket receives a periodic `heartbeat` event.
 */
@Injectable()
export class ConnectionManagerService implements OnModuleInit {
  // NOTE(review): this field is never read in this class. `@WebSocketServer()`
  // is documented for gateway classes; confirm it is populated on a plain
  // provider before relying on it.
  @WebSocketServer()
  private server!: Server;

  /** Open sockets keyed by user id. */
  private readonly connections: Map<string, Set<Socket>> = new Map();
  private heartbeatInterval = 30000; // 30 seconds

  /** Handle of the running heartbeat timer, kept so it can be cleared. */
  private heartbeatTimer: ReturnType<typeof setInterval> | undefined;

  /** Starts the heartbeat once the module has been initialised. */
  onModuleInit(): void {
    this.startHeartbeat();
  }

  /**
   * Stops the heartbeat timer so it does not outlive the module (and does not
   * keep the process or a test run alive). Nest invokes this hook by method
   * name when the application shuts down.
   */
  onModuleDestroy(): void {
    if (this.heartbeatTimer !== undefined) {
      clearInterval(this.heartbeatTimer);
      this.heartbeatTimer = undefined;
    }
  }

  /**
   * Registers a socket for a user, creating the user's set on first use.
   *
   * @param userId Owner of the socket.
   * @param socket Socket to track.
   */
  addConnection(userId: string, socket: Socket): void {
    let sockets = this.connections.get(userId);
    if (sockets === undefined) {
      sockets = new Set();
      this.connections.set(userId, sockets);
    }
    sockets.add(socket);
  }

  /**
   * Unregisters a socket; the user's entry is removed once its last socket is gone.
   *
   * @param userId Owner of the socket.
   * @param socket Socket to stop tracking.
   */
  removeConnection(userId: string, socket: Socket): void {
    const sockets = this.connections.get(userId);
    if (sockets === undefined) {
      return;
    }
    sockets.delete(socket);
    if (sockets.size === 0) {
      this.connections.delete(userId);
    }
  }

  /**
   * @param userId User to look up.
   * @returns `true` when at least one socket is registered for the user.
   */
  isConnected(userId: string): boolean {
    const sockets = this.connections.get(userId);
    return sockets !== undefined && sockets.size > 0;
  }

  /**
   * @param userId User to look up.
   * @returns A fresh array of the user's sockets; empty when none are registered.
   */
  getUserSockets(userId: string): Socket[] {
    const sockets = this.connections.get(userId);
    return sockets === undefined ? [] : Array.from(sockets);
  }

  /**
   * Emits an event to every socket registered for a user.
   *
   * @param userId Recipient user.
   * @param event  Event name.
   * @param data   Event payload.
   */
  broadcastToUser(userId: string, event: string, data: any): void {
    for (const socket of this.getUserSockets(userId)) {
      socket.emit(event, data);
    }
  }

  /**
   * Starts emitting `heartbeat` to all registered sockets every
   * `heartbeatInterval` milliseconds. Calling it again while a timer is
   * already running is a no-op, so timers cannot stack up.
   */
  private startHeartbeat(): void {
    if (this.heartbeatTimer !== undefined) {
      return;
    }
    this.heartbeatTimer = setInterval(() => {
      for (const sockets of this.connections.values()) {
        for (const socket of sockets) {
          socket.emit('heartbeat', { timestamp: Date.now() });
        }
      }
    }, this.heartbeatInterval);
  }
}

Specialized Gateway Example

Here's an example of a specialized gateway for handling real-time updates:

/**
 * Socket.IO gateway on the `/updates` namespace that lets clients subscribe
 * to and unsubscribe from `topic:<name>` rooms.
 *
 * NOTE(review): the user id comes from the client-supplied handshake
 * (`handshake.auth.userId`); confirm it is verified elsewhere before it is
 * trusted for authorisation.
 */
@WebSocketGateway({
  namespace: '/updates'
})
export class UpdatesGateway implements OnGatewayConnection {
  constructor(
    private readonly connectionManager: ConnectionManagerService,
    private readonly authService: AuthService,
  ) {}

  /**
   * Required by `OnGatewayConnection`. Registers the socket with the
   * connection manager when the handshake carries a user id.
   *
   * @param client The socket that just connected.
   */
  handleConnection(client: Socket): void {
    const userId = client.handshake.auth.userId;
    if (userId) {
      this.connectionManager.addConnection(userId, client);
    }
  }

  /**
   * Removes the socket from the connection manager so disconnected sockets
   * are not kept in its registry. Nest invokes this hook by method name.
   *
   * @param client The socket that disconnected.
   */
  handleDisconnect(client: Socket): void {
    const userId = client.handshake.auth.userId;
    if (userId) {
      this.connectionManager.removeConnection(userId, client);
    }
  }

  /**
   * Subscribes the client to the requested topics after checking access.
   *
   * @param client  The requesting socket.
   * @param payload Request body listing the topics to join.
   * @returns A `subscribed` response echoing the topics, or an `error`
   *          response carrying the failure message.
   */
  @SubscribeMessage('subscribe')
  async handleSubscribe(
    client: Socket,
    payload: { topics: string[] }
  ): Promise<WsResponse<any>> {
    try {
      const topics = UpdatesGateway.readTopics(payload);

      await this.validateSubscription(client, topics);

      await Promise.all(topics.map(topic => client.join(`topic:${topic}`)));

      return {
        event: 'subscribed',
        data: { topics }
      };
    } catch (error: unknown) {
      return {
        event: 'error',
        data: { message: UpdatesGateway.describeError(error) }
      };
    }
  }

  /**
   * Removes the client from the listed topic rooms.
   *
   * @param client  The requesting socket.
   * @param payload Request body listing the topics to leave.
   * @returns An `unsubscribed` response echoing the topics, or an `error`
   *          response carrying the failure message.
   */
  @SubscribeMessage('unsubscribe')
  async handleUnsubscribe(
    client: Socket,
    payload: { topics: string[] }
  ): Promise<WsResponse<any>> {
    try {
      const topics = UpdatesGateway.readTopics(payload);

      await Promise.all(topics.map(topic => client.leave(`topic:${topic}`)));

      return {
        event: 'unsubscribed',
        data: { topics }
      };
    } catch (error: unknown) {
      return {
        event: 'error',
        data: { message: UpdatesGateway.describeError(error) }
      };
    }
  }

  /**
   * Ensures the connecting user may access every requested topic.
   *
   * @param client The requesting socket; its handshake supplies the user id.
   * @param topics Topics the client wants to join.
   * @throws Error when access to at least one topic is denied.
   */
  private async validateSubscription(
    client: Socket,
    topics: string[]
  ): Promise<void> {
    const userId = client.handshake.auth.userId;
    const user = await this.authService.validateUser(userId);

    // All checks run in parallel; a single denial rejects the whole request.
    const accessResults = await Promise.all(
      topics.map(topic => this.authService.canAccessTopic(user, topic))
    );

    if (!accessResults.every(Boolean)) {
      throw new Error('Unauthorized subscription request');
    }
  }

  /**
   * Validates the wire payload, which is untrusted despite its declared type.
   *
   * @throws Error when `topics` is missing or is not an array of strings.
   */
  private static readTopics(payload: { topics: string[] } | undefined): string[] {
    const topics = payload ? payload.topics : undefined;
    if (!Array.isArray(topics) || !topics.every(topic => typeof topic === 'string')) {
      throw new Error('Invalid payload: "topics" must be an array of strings');
    }
    return topics;
  }

  /** Extracts a printable message from a value caught in a `catch` clause. */
  private static describeError(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }
}

Authentication Guard

Implement a WebSocket authentication guard:

import { CanActivate, ExecutionContext, Injectable } from '@nestjs/common';
import { WsException } from '@nestjs/websockets';
import { AuthService } from './auth.service';
import { Socket } from 'socket.io';
 
/**
 * Guard that authenticates a WebSocket client from the token in its
 * handshake (`handshake.auth.token`) and stores the resolved user on
 * `client.data.user`.
 */
@Injectable()
export class WsAuthGuard implements CanActivate {
  constructor(private readonly authService: AuthService) {}

  /**
   * @param context Execution context of the incoming WebSocket event.
   * @returns `true` once the token has been validated.
   * @throws WsException when the token is missing, invalid, or validation fails.
   */
  async canActivate(context: ExecutionContext): Promise<boolean> {
    try {
      const client = context.switchToWs().getClient<Socket>();
      const token = client.handshake.auth.token;

      if (!token) {
        throw new WsException('Missing auth token');
      }

      const user = await this.authService.validateToken(token);
      if (!user) {
        throw new WsException('Invalid auth token');
      }

      // Make the authenticated user available to later handlers.
      client.data.user = user;

      return true;
    } catch (error: unknown) {
      // Errors raised above are already WsExceptions; pass them through
      // instead of wrapping them a second time.
      if (error instanceof WsException) {
        throw error;
      }
      // Anything else (e.g. a failure inside the auth service) is converted
      // so that callers only ever see a WsException.
      throw new WsException(error instanceof Error ? error.message : 'Unauthorized');
    }
  }
}

Usage Example

Here's how to use these components:

// main.ts
import { NestFactory } from '@nestjs/core';
import { AppModule } from './app.module';
 
/**
 * Creates the Nest application from `AppModule`, enables CORS for the
 * origin in `CLIENT_URL`, and starts listening on port 3000.
 */
async function bootstrap(): Promise<void> {
  const app = await NestFactory.create(AppModule);

  // This configures CORS on the HTTP application. The gateways declare their
  // own `cors` option in @WebSocketGateway().
  app.enableCors({
    origin: process.env.CLIENT_URL,
    credentials: true
  });

  await app.listen(3000);
}

// Report startup failures explicitly instead of leaving the promise unhandled.
bootstrap().catch((error: unknown) => {
  console.error('Failed to start application:', error);
  process.exit(1);
});
 
// app.module.ts
/**
 * Root module. Registers both gateways, the connection manager and the auth
 * service as providers; `AuthService` is what `UpdatesGateway` and
 * `WsAuthGuard` receive through their constructors.
 */
@Module({
  imports: [],
  providers: [
    AppGateway,
    UpdatesGateway,
    ConnectionManagerService,
    AuthService,
  ],
})
export class AppModule {}
 
// client-side usage
// Connect to the `/ws` namespace; the `auth` object is what the server reads
// from `client.handshake.auth`.
const socket = io('http://localhost:3000/ws', {
  auth: {
    token: 'your-auth-token',
    userId: 'user-id'
  }
});

socket.on('connect', () => {
  console.log('Connected to WebSocket');
});

// Connection and handshake failures are reported through `connect_error`,
// not through the application-level `error` event below.
socket.on('connect_error', (error) => {
  console.error('WebSocket connection failed:', error.message);
});

socket.on('message', (data) => {
  console.log('Received message:', data);
});

// Application-level errors that the gateway sends back as an `error` event.
socket.on('error', (error) => {
  console.error('WebSocket error:', error);
});

socket.emit('message', { text: 'Hello!' });

Notes

  • Implement proper error handling
  • Use authentication guards
  • Manage connections efficiently
  • Handle reconnection scenarios
  • Monitor connection health
  • Implement rate limiting
  • Use proper typing for messages
  • Handle cleanup on disconnect