From 3dfa5c6ec8e2963f86a8d6cddc92dc3b75d17008 Mon Sep 17 00:00:00 2001 From: cogwheel0 <172976095+cogwheel0@users.noreply.github.com> Date: Mon, 29 Sep 2025 00:22:12 +0530 Subject: [PATCH] refactor: sockets to use riverpod --- lib/core/models/socket_event.dart | 105 +++++++++ lib/core/providers/app_providers.dart | 213 +++++++++++++++--- lib/core/services/connectivity_service.dart | 34 ++- lib/core/services/streaming_helper.dart | 45 +++- .../chat/providers/chat_providers.dart | 62 ++++- 5 files changed, 410 insertions(+), 49 deletions(-) create mode 100644 lib/core/models/socket_event.dart diff --git a/lib/core/models/socket_event.dart b/lib/core/models/socket_event.dart new file mode 100644 index 0000000..ea6517a --- /dev/null +++ b/lib/core/models/socket_event.dart @@ -0,0 +1,105 @@ +import 'package:flutter/foundation.dart'; + +/// Identifies which socket channel emitted a conversation delta. +enum ConversationDeltaSource { chat, channel } + +/// Describes the parameters needed to bind a socket-backed stream. +@immutable +class ConversationDeltaRequest { + const ConversationDeltaRequest._( + this.source, { + this.conversationId, + this.sessionId, + this.requireFocus = true, + }); + + const ConversationDeltaRequest.chat({ + String? conversationId, + String? sessionId, + bool requireFocus = true, + }) : this._( + ConversationDeltaSource.chat, + conversationId: conversationId, + sessionId: sessionId, + requireFocus: requireFocus, + ); + + const ConversationDeltaRequest.channel({ + String? conversationId, + String? sessionId, + bool requireFocus = true, + }) : this._( + ConversationDeltaSource.channel, + conversationId: conversationId, + sessionId: sessionId, + requireFocus: requireFocus, + ); + + final ConversationDeltaSource source; + final String? conversationId; + final String? sessionId; + final bool requireFocus; + + @override + int get hashCode => + Object.hash(source, conversationId, sessionId, requireFocus); + + @override + bool operator ==(Object other) { + return other is ConversationDeltaRequest && + other.source == source && + other.conversationId == conversationId && + other.sessionId == sessionId && + other.requireFocus == requireFocus; + } + + @override + String toString() { + return 'ConversationDeltaRequest(source: $source, conversationId: ' + '$conversationId, sessionId: $sessionId, requireFocus: ' + '$requireFocus)'; + } +} + +/// Carries a socket event payload along with metadata. +@immutable +class ConversationDelta { + const ConversationDelta({ + required this.source, + required this.raw, + this.type, + this.payload, + this.ack, + }); + + factory ConversationDelta.fromSocketEvent( + ConversationDeltaSource source, + Map event, + void Function(dynamic response)? ack, + ) { + final data = event['data']; + String? type; + Map? payload; + if (data is Map) { + type = data['type']?.toString(); + final dynamic inner = data['data']; + if (inner is Map) { + payload = inner; + } + } + + return ConversationDelta( + source: source, + raw: event, + type: type, + payload: payload, + ack: ack, + ); + } + + final ConversationDeltaSource source; + final Map raw; + final String? type; + final Map? payload; + final void Function(dynamic response)? ack; +} diff --git a/lib/core/providers/app_providers.dart b/lib/core/providers/app_providers.dart index 356b77a..9715b27 100644 --- a/lib/core/providers/app_providers.dart +++ b/lib/core/providers/app_providers.dart @@ -24,6 +24,7 @@ import '../services/settings_service.dart'; import '../services/optimized_storage_service.dart'; import '../services/socket_service.dart'; import '../utils/debug_logger.dart'; +import '../models/socket_event.dart'; part 'app_providers.g.dart'; @@ -281,41 +282,30 @@ enum SocketConnectionState { disconnected, connecting, connected } class SocketConnectionStream extends _$SocketConnectionStream { StreamController? _controller; ProviderSubscription>? _serviceSubscription; - void Function()? _cancelConnectListener; - void Function()? _cancelDisconnectListener; + VoidCallback? _cancelConnectListener; + VoidCallback? _cancelDisconnectListener; + SocketConnectionState _latestState = SocketConnectionState.disconnected; @override Stream build() { - final controller = StreamController.broadcast(); + final controller = StreamController.broadcast( + sync: true, + ); + controller + ..onListen = _primeState + ..onCancel = _maybeNotifyDisconnected; _controller = controller; - void emitState(SocketService? service) { - if (service == null) { - controller.add(SocketConnectionState.disconnected); - _unbindSocket(); - return; - } - controller.add( - service.isConnected - ? SocketConnectionState.connected - : SocketConnectionState.connecting, - ); - _bindSocket(service); - } - - emitState( - ref - .watch(socketServiceManagerProvider) - .maybeWhen(data: (service) => service, orElse: () => null), - ); + final initialService = ref + .watch(socketServiceManagerProvider) + .maybeWhen(data: (service) => service, orElse: () => null); + _handleServiceChange(initialService); _serviceSubscription = ref.listen>( socketServiceManagerProvider, - (previous, next) { - emitState( - next.maybeWhen(data: (service) => service, orElse: () => null), - ); - }, + (_, next) => _handleServiceChange( + next.maybeWhen(data: (service) => service, orElse: () => null), + ), ); ref.onDispose(() { @@ -329,15 +319,45 @@ class SocketConnectionStream extends _$SocketConnectionStream { return controller.stream; } + /// Publishes a disconnected state when the final listener cancels. + void _maybeNotifyDisconnected() { + try { + _controller?.add(SocketConnectionState.disconnected); + _latestState = SocketConnectionState.disconnected; + } catch (_) {} + } + + /// Replays the cached state to new listeners. + void _primeState() { + try { + _controller?.add(_latestState); + } catch (_) {} + } + + void _handleServiceChange(SocketService? service) { + if (service == null) { + _unbindSocket(); + _emit(SocketConnectionState.disconnected); + return; + } + + _emit( + service.isConnected + ? SocketConnectionState.connected + : SocketConnectionState.connecting, + ); + _bindSocket(service); + } + void _bindSocket(SocketService service) { _unbindSocket(); void handleConnect(dynamic _) { - _controller?.add(SocketConnectionState.connected); + _emit(SocketConnectionState.connected); } void handleDisconnect(dynamic _) { - _controller?.add(SocketConnectionState.disconnected); + _emit(SocketConnectionState.disconnected); } service.socket?.on('connect', handleConnect); @@ -351,14 +371,149 @@ class SocketConnectionStream extends _$SocketConnectionStream { }; } + void _emit(SocketConnectionState next) { + if (_latestState == next) { + return; + } + _latestState = next; + try { + _controller?.add(next); + } catch (_) {} + } + void _unbindSocket() { _cancelConnectListener?.call(); _cancelDisconnectListener?.call(); _cancelConnectListener = null; _cancelDisconnectListener = null; } + + /// Forces a best-effort reconnect of the underlying socket service. + Future reconnect({ + Duration timeout = const Duration(seconds: 2), + }) async { + final service = ref.read(socketServiceProvider); + if (service == null) { + return; + } + final connected = await service.ensureConnected(timeout: timeout); + _emit( + connected + ? SocketConnectionState.connected + : SocketConnectionState.connecting, + ); + } + + /// Exposes the latest cached state for imperative reads. + SocketConnectionState get latest => _latestState; } +@Riverpod(keepAlive: true) +class ConversationDeltaStream extends _$ConversationDeltaStream { + StreamController? _controller; + ProviderSubscription>? _serviceSubscription; + SocketEventSubscription? _socketSubscription; + + @override + Stream build(ConversationDeltaRequest request) { + final controller = StreamController.broadcast( + sync: true, + onCancel: _maybeTearDownSocket, + ); + _controller = controller; + + final initialService = ref + .watch(socketServiceManagerProvider) + .maybeWhen(data: (service) => service, orElse: () => null); + _bindSocket(initialService, request); + + _serviceSubscription = ref.listen>( + socketServiceManagerProvider, + (_, next) => _bindSocket( + next.maybeWhen(data: (service) => service, orElse: () => null), + request, + ), + ); + + ref.onDispose(() { + _serviceSubscription?.close(); + _serviceSubscription = null; + _socketSubscription?.dispose(); + _socketSubscription = null; + _controller?.close(); + _controller = null; + }); + + return controller.stream; + } + + void _bindSocket(SocketService? service, ConversationDeltaRequest request) { + _socketSubscription?.dispose(); + _socketSubscription = null; + + if (service == null) { + return; + } + + switch (request.source) { + case ConversationDeltaSource.chat: + _socketSubscription = service.addChatEventHandler( + conversationId: request.conversationId, + sessionId: request.sessionId, + requireFocus: request.requireFocus, + handler: (event, ack) { + _controller?.add( + ConversationDelta.fromSocketEvent( + ConversationDeltaSource.chat, + event, + ack, + ), + ); + }, + ); + break; + case ConversationDeltaSource.channel: + _socketSubscription = service.addChannelEventHandler( + conversationId: request.conversationId, + sessionId: request.sessionId, + requireFocus: request.requireFocus, + handler: (event, ack) { + _controller?.add( + ConversationDelta.fromSocketEvent( + ConversationDeltaSource.channel, + event, + ack, + ), + ); + }, + ); + break; + } + } + + void _maybeTearDownSocket() { + if (_controller?.hasListener == true) { + return; + } + _socketSubscription?.dispose(); + _socketSubscription = null; + } + + Stream get stream => + _controller?.stream ?? const Stream.empty(); +} + +final conversationDeltaEventsProvider = + StreamProvider.family(( + ref, + request, + ) { + final notifier = ref.watch( + conversationDeltaStreamProvider(request).notifier, + ); + return notifier.stream; + }); + // Attachment upload queue provider final attachmentUploadQueueProvider = Provider((ref) { final api = ref.watch(apiServiceProvider); diff --git a/lib/core/services/connectivity_service.dart b/lib/core/services/connectivity_service.dart index 40910d0..d9ce86b 100644 --- a/lib/core/services/connectivity_service.dart +++ b/lib/core/services/connectivity_service.dart @@ -1,8 +1,13 @@ import 'dart:async'; -import 'package:flutter_riverpod/flutter_riverpod.dart'; + import 'package:dio/dio.dart'; +import 'package:flutter_riverpod/flutter_riverpod.dart'; +import 'package:riverpod_annotation/riverpod_annotation.dart'; + import '../providers/app_providers.dart'; +part 'connectivity_service.g.dart'; + enum ConnectivityStatus { online, offline, checking } class ConnectivityService { @@ -211,10 +216,29 @@ final connectivityServiceProvider = Provider((ref) { return service; }); -final connectivityStatusProvider = StreamProvider((ref) { - final service = ref.watch(connectivityServiceProvider); - return service.connectivityStream; -}); +@Riverpod(keepAlive: true) +class ConnectivityStatusNotifier extends _$ConnectivityStatusNotifier { + StreamSubscription? _subscription; + + @override + FutureOr build() { + final service = ref.watch(connectivityServiceProvider); + + _subscription?.cancel(); + _subscription = service.connectivityStream.listen( + (status) => state = AsyncValue.data(status), + onError: (error, stackTrace) => + state = AsyncValue.error(error, stackTrace), + ); + + ref.onDispose(() { + _subscription?.cancel(); + _subscription = null; + }); + + return service.currentStatus; + } +} final isOnlineProvider = Provider((ref) { // In reviewer mode, treat app as online to enable flows diff --git a/lib/core/services/streaming_helper.dart b/lib/core/services/streaming_helper.dart index 48754e0..f56ae50 100644 --- a/lib/core/services/streaming_helper.dart +++ b/lib/core/services/streaming_helper.dart @@ -4,6 +4,7 @@ import 'dart:convert'; import 'package:flutter/material.dart'; import '../../core/models/chat_message.dart'; +import '../../core/models/socket_event.dart'; import '../../core/services/persistent_streaming_service.dart'; import '../../core/services/socket_service.dart'; import '../../core/utils/inactivity_watchdog.dart'; @@ -25,7 +26,7 @@ class ActiveSocketStream { }); final StreamSubscription streamSubscription; - final List socketSubscriptions; + final List socketSubscriptions; final VoidCallback disposeWatchdog; } @@ -44,6 +45,8 @@ ActiveSocketStream attachUnifiedChunkedStreaming({ required String? activeConversationId, required dynamic api, required SocketService? socketService, + Stream? chatEvents, + Stream? channelEvents, // Message update callbacks required void Function(String) appendToLastMessage, required void Function(String) replaceLastMessageContent, @@ -91,8 +94,10 @@ ActiveSocketStream attachUnifiedChunkedStreaming({ ); InactivityWatchdog? socketWatchdog; - final socketSubscriptions = []; - if (socketService != null) { + final socketSubscriptions = []; + final hasSocketSignals = + socketService != null || chatEvents != null || channelEvents != null; + if (hasSocketSignals) { // Increase timeout to match OpenWebUI's more generous timeouts for long responses socketWatchdog = InactivityWatchdog( window: const Duration(minutes: 15), // Increased from 5 to 15 minutes @@ -102,8 +107,10 @@ ActiveSocketStream attachUnifiedChunkedStreaming({ scope: 'streaming/helper', ); try { - for (final sub in socketSubscriptions) { - sub.dispose(); + for (final dispose in socketSubscriptions) { + try { + dispose(); + } catch (_) {} } socketSubscriptions.clear(); } catch (_) {} @@ -124,9 +131,9 @@ ActiveSocketStream attachUnifiedChunkedStreaming({ if (socketSubscriptions.isEmpty) { return; } - for (final sub in socketSubscriptions) { + for (final dispose in socketSubscriptions) { try { - sub.dispose(); + dispose(); } catch (_) {} } socketSubscriptions.clear(); @@ -983,22 +990,40 @@ ActiveSocketStream attachUnifiedChunkedStreaming({ } catch (_) {} } - if (socketService != null) { + if (chatEvents != null) { + final subscription = chatEvents.listen((event) { + socketWatchdog?.ping(); + chatHandler(event.raw, event.ack); + }); + socketSubscriptions.add(() { + unawaited(subscription.cancel()); + }); + } else if (socketService != null) { final chatSub = socketService.addChatEventHandler( conversationId: activeConversationId, sessionId: sessionId, requireFocus: false, handler: chatHandler, ); - socketSubscriptions.add(chatSub); + socketSubscriptions.add(chatSub.dispose); + } + if (channelEvents != null) { + final subscription = channelEvents.listen((event) { + socketWatchdog?.ping(); + channelEventsHandler(event.raw, event.ack); + }); + socketSubscriptions.add(() { + unawaited(subscription.cancel()); + }); + } else if (socketService != null) { final channelSub = socketService.addChannelEventHandler( conversationId: activeConversationId, sessionId: sessionId, requireFocus: false, handler: channelEventsHandler, ); - socketSubscriptions.add(channelSub); + socketSubscriptions.add(channelSub.dispose); } final subscription = persistentController.stream.listen( diff --git a/lib/features/chat/providers/chat_providers.dart b/lib/features/chat/providers/chat_providers.dart index 9e24864..649242a 100644 --- a/lib/features/chat/providers/chat_providers.dart +++ b/lib/features/chat/providers/chat_providers.dart @@ -6,10 +6,10 @@ import 'package:flutter_riverpod/flutter_riverpod.dart'; import 'package:uuid/uuid.dart'; import '../../../core/utils/tool_calls_parser.dart'; import '../../../core/services/streaming_helper.dart'; -import '../../../core/services/socket_service.dart'; import '../../../core/models/chat_message.dart'; import '../../../core/models/conversation.dart'; import '../../../core/providers/app_providers.dart'; +import '../../../core/models/socket_event.dart'; import '../../../core/auth/auth_state_manager.dart'; import '../../../core/utils/inactivity_watchdog.dart'; import '../services/reviewer_mode_service.dart'; @@ -89,7 +89,7 @@ class ChatMessagesNotifier extends Notifier> { StreamSubscription? _messageStream; ProviderSubscription? _conversationListener; final List _subscriptions = []; - final List _socketSubscriptions = []; + final List _socketSubscriptions = []; VoidCallback? _socketTeardown; // Activity-based watchdog to prevent stuck typing indicator InactivityWatchdog? _typingWatchdog; @@ -397,7 +397,7 @@ class ChatMessagesNotifier extends Notifier> { } void setSocketSubscriptions( - List subscriptions, { + List subscriptions, { VoidCallback? onDispose, }) { cancelSocketSubscriptions(); @@ -411,9 +411,9 @@ class ChatMessagesNotifier extends Notifier> { _socketTeardown = null; return; } - for (final sub in _socketSubscriptions) { + for (final dispose in _socketSubscriptions) { try { - sub.dispose(); + dispose(); } catch (_) {} } _socketSubscriptions.clear(); @@ -1332,6 +1332,30 @@ Future regenerateMessage( }); } catch (_) {} + final chatEventsStream = ref + .read( + conversationDeltaStreamProvider( + ConversationDeltaRequest.chat( + conversationId: activeConversation.id, + sessionId: effectiveSessionId, + requireFocus: false, + ), + ).notifier, + ) + .stream; + + final channelEventsStream = ref + .read( + conversationDeltaStreamProvider( + ConversationDeltaRequest.channel( + conversationId: activeConversation.id, + sessionId: effectiveSessionId, + requireFocus: false, + ), + ).notifier, + ) + .stream; + final activeStream = attachUnifiedChunkedStreaming( stream: stream, webSearchEnabled: webSearchEnabled, @@ -1342,6 +1366,8 @@ Future regenerateMessage( activeConversationId: activeConversation.id, api: api, socketService: socketService, + chatEvents: chatEventsStream, + channelEvents: channelEventsStream, appendToLastMessage: (c) => ref.read(chatMessagesProvider.notifier).appendToLastMessage(c), replaceLastMessageContent: (c) => @@ -1867,6 +1893,30 @@ Future _sendMessageInternal( }); } catch (_) {} + final chatEventsStream = ref + .read( + conversationDeltaStreamProvider( + ConversationDeltaRequest.chat( + conversationId: activeConversation?.id, + sessionId: effectiveSessionId, + requireFocus: false, + ), + ).notifier, + ) + .stream; + + final channelEventsStream = ref + .read( + conversationDeltaStreamProvider( + ConversationDeltaRequest.channel( + conversationId: activeConversation?.id, + sessionId: effectiveSessionId, + requireFocus: false, + ), + ).notifier, + ) + .stream; + final activeStream = attachUnifiedChunkedStreaming( stream: stream, webSearchEnabled: webSearchEnabled, @@ -1877,6 +1927,8 @@ Future _sendMessageInternal( activeConversationId: activeConversation?.id, api: api, socketService: socketService, + chatEvents: chatEventsStream, + channelEvents: channelEventsStream, appendToLastMessage: (c) => ref.read(chatMessagesProvider.notifier).appendToLastMessage(c), replaceLastMessageContent: (c) =>