refactor: sockets to use riverpod

This commit is contained in:
cogwheel0
2025-09-29 00:22:12 +05:30
parent 0ba48030c8
commit 3dfa5c6ec8
5 changed files with 410 additions and 49 deletions

View File

@@ -0,0 +1,105 @@
import 'package:flutter/foundation.dart';
/// Identifies which socket channel emitted a conversation delta.
enum ConversationDeltaSource { chat, channel }
/// Describes the parameters needed to bind a socket-backed stream.
@immutable
class ConversationDeltaRequest {
const ConversationDeltaRequest._(
this.source, {
this.conversationId,
this.sessionId,
this.requireFocus = true,
});
const ConversationDeltaRequest.chat({
String? conversationId,
String? sessionId,
bool requireFocus = true,
}) : this._(
ConversationDeltaSource.chat,
conversationId: conversationId,
sessionId: sessionId,
requireFocus: requireFocus,
);
const ConversationDeltaRequest.channel({
String? conversationId,
String? sessionId,
bool requireFocus = true,
}) : this._(
ConversationDeltaSource.channel,
conversationId: conversationId,
sessionId: sessionId,
requireFocus: requireFocus,
);
final ConversationDeltaSource source;
final String? conversationId;
final String? sessionId;
final bool requireFocus;
@override
int get hashCode =>
Object.hash(source, conversationId, sessionId, requireFocus);
@override
bool operator ==(Object other) {
return other is ConversationDeltaRequest &&
other.source == source &&
other.conversationId == conversationId &&
other.sessionId == sessionId &&
other.requireFocus == requireFocus;
}
@override
String toString() {
return 'ConversationDeltaRequest(source: $source, conversationId: '
'$conversationId, sessionId: $sessionId, requireFocus: '
'$requireFocus)';
}
}
/// Carries a socket event payload along with metadata.
@immutable
class ConversationDelta {
const ConversationDelta({
required this.source,
required this.raw,
this.type,
this.payload,
this.ack,
});
factory ConversationDelta.fromSocketEvent(
ConversationDeltaSource source,
Map<String, dynamic> event,
void Function(dynamic response)? ack,
) {
final data = event['data'];
String? type;
Map<String, dynamic>? payload;
if (data is Map<String, dynamic>) {
type = data['type']?.toString();
final dynamic inner = data['data'];
if (inner is Map<String, dynamic>) {
payload = inner;
}
}
return ConversationDelta(
source: source,
raw: event,
type: type,
payload: payload,
ack: ack,
);
}
final ConversationDeltaSource source;
final Map<String, dynamic> raw;
final String? type;
final Map<String, dynamic>? payload;
final void Function(dynamic response)? ack;
}

View File

@@ -24,6 +24,7 @@ import '../services/settings_service.dart';
import '../services/optimized_storage_service.dart'; import '../services/optimized_storage_service.dart';
import '../services/socket_service.dart'; import '../services/socket_service.dart';
import '../utils/debug_logger.dart'; import '../utils/debug_logger.dart';
import '../models/socket_event.dart';
part 'app_providers.g.dart'; part 'app_providers.g.dart';
@@ -281,41 +282,30 @@ enum SocketConnectionState { disconnected, connecting, connected }
class SocketConnectionStream extends _$SocketConnectionStream { class SocketConnectionStream extends _$SocketConnectionStream {
StreamController<SocketConnectionState>? _controller; StreamController<SocketConnectionState>? _controller;
ProviderSubscription<AsyncValue<SocketService?>>? _serviceSubscription; ProviderSubscription<AsyncValue<SocketService?>>? _serviceSubscription;
void Function()? _cancelConnectListener; VoidCallback? _cancelConnectListener;
void Function()? _cancelDisconnectListener; VoidCallback? _cancelDisconnectListener;
SocketConnectionState _latestState = SocketConnectionState.disconnected;
@override @override
Stream<SocketConnectionState> build() { Stream<SocketConnectionState> build() {
final controller = StreamController<SocketConnectionState>.broadcast(); final controller = StreamController<SocketConnectionState>.broadcast(
sync: true,
);
controller
..onListen = _primeState
..onCancel = _maybeNotifyDisconnected;
_controller = controller; _controller = controller;
void emitState(SocketService? service) { final initialService = ref
if (service == null) {
controller.add(SocketConnectionState.disconnected);
_unbindSocket();
return;
}
controller.add(
service.isConnected
? SocketConnectionState.connected
: SocketConnectionState.connecting,
);
_bindSocket(service);
}
emitState(
ref
.watch(socketServiceManagerProvider) .watch(socketServiceManagerProvider)
.maybeWhen(data: (service) => service, orElse: () => null), .maybeWhen(data: (service) => service, orElse: () => null);
); _handleServiceChange(initialService);
_serviceSubscription = ref.listen<AsyncValue<SocketService?>>( _serviceSubscription = ref.listen<AsyncValue<SocketService?>>(
socketServiceManagerProvider, socketServiceManagerProvider,
(previous, next) { (_, next) => _handleServiceChange(
emitState(
next.maybeWhen(data: (service) => service, orElse: () => null), next.maybeWhen(data: (service) => service, orElse: () => null),
); ),
},
); );
ref.onDispose(() { ref.onDispose(() {
@@ -329,15 +319,45 @@ class SocketConnectionStream extends _$SocketConnectionStream {
return controller.stream; return controller.stream;
} }
/// Publishes a disconnected state when the final listener cancels.
void _maybeNotifyDisconnected() {
try {
_controller?.add(SocketConnectionState.disconnected);
_latestState = SocketConnectionState.disconnected;
} catch (_) {}
}
/// Replays the cached state to new listeners.
void _primeState() {
try {
_controller?.add(_latestState);
} catch (_) {}
}
void _handleServiceChange(SocketService? service) {
if (service == null) {
_unbindSocket();
_emit(SocketConnectionState.disconnected);
return;
}
_emit(
service.isConnected
? SocketConnectionState.connected
: SocketConnectionState.connecting,
);
_bindSocket(service);
}
void _bindSocket(SocketService service) { void _bindSocket(SocketService service) {
_unbindSocket(); _unbindSocket();
void handleConnect(dynamic _) { void handleConnect(dynamic _) {
_controller?.add(SocketConnectionState.connected); _emit(SocketConnectionState.connected);
} }
void handleDisconnect(dynamic _) { void handleDisconnect(dynamic _) {
_controller?.add(SocketConnectionState.disconnected); _emit(SocketConnectionState.disconnected);
} }
service.socket?.on('connect', handleConnect); service.socket?.on('connect', handleConnect);
@@ -351,14 +371,149 @@ class SocketConnectionStream extends _$SocketConnectionStream {
}; };
} }
void _emit(SocketConnectionState next) {
if (_latestState == next) {
return;
}
_latestState = next;
try {
_controller?.add(next);
} catch (_) {}
}
void _unbindSocket() { void _unbindSocket() {
_cancelConnectListener?.call(); _cancelConnectListener?.call();
_cancelDisconnectListener?.call(); _cancelDisconnectListener?.call();
_cancelConnectListener = null; _cancelConnectListener = null;
_cancelDisconnectListener = null; _cancelDisconnectListener = null;
} }
/// Forces a best-effort reconnect of the underlying socket service.
Future<void> reconnect({
Duration timeout = const Duration(seconds: 2),
}) async {
final service = ref.read(socketServiceProvider);
if (service == null) {
return;
}
final connected = await service.ensureConnected(timeout: timeout);
_emit(
connected
? SocketConnectionState.connected
: SocketConnectionState.connecting,
);
}
/// Exposes the latest cached state for imperative reads.
SocketConnectionState get latest => _latestState;
} }
@Riverpod(keepAlive: true)
class ConversationDeltaStream extends _$ConversationDeltaStream {
StreamController<ConversationDelta>? _controller;
ProviderSubscription<AsyncValue<SocketService?>>? _serviceSubscription;
SocketEventSubscription? _socketSubscription;
@override
Stream<ConversationDelta> build(ConversationDeltaRequest request) {
final controller = StreamController<ConversationDelta>.broadcast(
sync: true,
onCancel: _maybeTearDownSocket,
);
_controller = controller;
final initialService = ref
.watch(socketServiceManagerProvider)
.maybeWhen(data: (service) => service, orElse: () => null);
_bindSocket(initialService, request);
_serviceSubscription = ref.listen<AsyncValue<SocketService?>>(
socketServiceManagerProvider,
(_, next) => _bindSocket(
next.maybeWhen(data: (service) => service, orElse: () => null),
request,
),
);
ref.onDispose(() {
_serviceSubscription?.close();
_serviceSubscription = null;
_socketSubscription?.dispose();
_socketSubscription = null;
_controller?.close();
_controller = null;
});
return controller.stream;
}
void _bindSocket(SocketService? service, ConversationDeltaRequest request) {
_socketSubscription?.dispose();
_socketSubscription = null;
if (service == null) {
return;
}
switch (request.source) {
case ConversationDeltaSource.chat:
_socketSubscription = service.addChatEventHandler(
conversationId: request.conversationId,
sessionId: request.sessionId,
requireFocus: request.requireFocus,
handler: (event, ack) {
_controller?.add(
ConversationDelta.fromSocketEvent(
ConversationDeltaSource.chat,
event,
ack,
),
);
},
);
break;
case ConversationDeltaSource.channel:
_socketSubscription = service.addChannelEventHandler(
conversationId: request.conversationId,
sessionId: request.sessionId,
requireFocus: request.requireFocus,
handler: (event, ack) {
_controller?.add(
ConversationDelta.fromSocketEvent(
ConversationDeltaSource.channel,
event,
ack,
),
);
},
);
break;
}
}
void _maybeTearDownSocket() {
if (_controller?.hasListener == true) {
return;
}
_socketSubscription?.dispose();
_socketSubscription = null;
}
Stream<ConversationDelta> get stream =>
_controller?.stream ?? const Stream<ConversationDelta>.empty();
}
final conversationDeltaEventsProvider =
StreamProvider.family<ConversationDelta, ConversationDeltaRequest>((
ref,
request,
) {
final notifier = ref.watch(
conversationDeltaStreamProvider(request).notifier,
);
return notifier.stream;
});
// Attachment upload queue provider // Attachment upload queue provider
final attachmentUploadQueueProvider = Provider<AttachmentUploadQueue?>((ref) { final attachmentUploadQueueProvider = Provider<AttachmentUploadQueue?>((ref) {
final api = ref.watch(apiServiceProvider); final api = ref.watch(apiServiceProvider);

View File

@@ -1,8 +1,13 @@
import 'dart:async'; import 'dart:async';
import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'package:dio/dio.dart'; import 'package:dio/dio.dart';
import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'package:riverpod_annotation/riverpod_annotation.dart';
import '../providers/app_providers.dart'; import '../providers/app_providers.dart';
part 'connectivity_service.g.dart';
enum ConnectivityStatus { online, offline, checking } enum ConnectivityStatus { online, offline, checking }
class ConnectivityService { class ConnectivityService {
@@ -211,10 +216,29 @@ final connectivityServiceProvider = Provider<ConnectivityService>((ref) {
return service; return service;
}); });
final connectivityStatusProvider = StreamProvider<ConnectivityStatus>((ref) { @Riverpod(keepAlive: true)
class ConnectivityStatusNotifier extends _$ConnectivityStatusNotifier {
StreamSubscription<ConnectivityStatus>? _subscription;
@override
FutureOr<ConnectivityStatus> build() {
final service = ref.watch(connectivityServiceProvider); final service = ref.watch(connectivityServiceProvider);
return service.connectivityStream;
}); _subscription?.cancel();
_subscription = service.connectivityStream.listen(
(status) => state = AsyncValue.data(status),
onError: (error, stackTrace) =>
state = AsyncValue.error(error, stackTrace),
);
ref.onDispose(() {
_subscription?.cancel();
_subscription = null;
});
return service.currentStatus;
}
}
final isOnlineProvider = Provider<bool>((ref) { final isOnlineProvider = Provider<bool>((ref) {
// In reviewer mode, treat app as online to enable flows // In reviewer mode, treat app as online to enable flows

View File

@@ -4,6 +4,7 @@ import 'dart:convert';
import 'package:flutter/material.dart'; import 'package:flutter/material.dart';
import '../../core/models/chat_message.dart'; import '../../core/models/chat_message.dart';
import '../../core/models/socket_event.dart';
import '../../core/services/persistent_streaming_service.dart'; import '../../core/services/persistent_streaming_service.dart';
import '../../core/services/socket_service.dart'; import '../../core/services/socket_service.dart';
import '../../core/utils/inactivity_watchdog.dart'; import '../../core/utils/inactivity_watchdog.dart';
@@ -25,7 +26,7 @@ class ActiveSocketStream {
}); });
final StreamSubscription<String> streamSubscription; final StreamSubscription<String> streamSubscription;
final List<SocketEventSubscription> socketSubscriptions; final List<VoidCallback> socketSubscriptions;
final VoidCallback disposeWatchdog; final VoidCallback disposeWatchdog;
} }
@@ -44,6 +45,8 @@ ActiveSocketStream attachUnifiedChunkedStreaming({
required String? activeConversationId, required String? activeConversationId,
required dynamic api, required dynamic api,
required SocketService? socketService, required SocketService? socketService,
Stream<ConversationDelta>? chatEvents,
Stream<ConversationDelta>? channelEvents,
// Message update callbacks // Message update callbacks
required void Function(String) appendToLastMessage, required void Function(String) appendToLastMessage,
required void Function(String) replaceLastMessageContent, required void Function(String) replaceLastMessageContent,
@@ -91,8 +94,10 @@ ActiveSocketStream attachUnifiedChunkedStreaming({
); );
InactivityWatchdog? socketWatchdog; InactivityWatchdog? socketWatchdog;
final socketSubscriptions = <SocketEventSubscription>[]; final socketSubscriptions = <VoidCallback>[];
if (socketService != null) { final hasSocketSignals =
socketService != null || chatEvents != null || channelEvents != null;
if (hasSocketSignals) {
// Increase timeout to match OpenWebUI's more generous timeouts for long responses // Increase timeout to match OpenWebUI's more generous timeouts for long responses
socketWatchdog = InactivityWatchdog( socketWatchdog = InactivityWatchdog(
window: const Duration(minutes: 15), // Increased from 5 to 15 minutes window: const Duration(minutes: 15), // Increased from 5 to 15 minutes
@@ -102,8 +107,10 @@ ActiveSocketStream attachUnifiedChunkedStreaming({
scope: 'streaming/helper', scope: 'streaming/helper',
); );
try { try {
for (final sub in socketSubscriptions) { for (final dispose in socketSubscriptions) {
sub.dispose(); try {
dispose();
} catch (_) {}
} }
socketSubscriptions.clear(); socketSubscriptions.clear();
} catch (_) {} } catch (_) {}
@@ -124,9 +131,9 @@ ActiveSocketStream attachUnifiedChunkedStreaming({
if (socketSubscriptions.isEmpty) { if (socketSubscriptions.isEmpty) {
return; return;
} }
for (final sub in socketSubscriptions) { for (final dispose in socketSubscriptions) {
try { try {
sub.dispose(); dispose();
} catch (_) {} } catch (_) {}
} }
socketSubscriptions.clear(); socketSubscriptions.clear();
@@ -983,22 +990,40 @@ ActiveSocketStream attachUnifiedChunkedStreaming({
} catch (_) {} } catch (_) {}
} }
if (socketService != null) { if (chatEvents != null) {
final subscription = chatEvents.listen((event) {
socketWatchdog?.ping();
chatHandler(event.raw, event.ack);
});
socketSubscriptions.add(() {
unawaited(subscription.cancel());
});
} else if (socketService != null) {
final chatSub = socketService.addChatEventHandler( final chatSub = socketService.addChatEventHandler(
conversationId: activeConversationId, conversationId: activeConversationId,
sessionId: sessionId, sessionId: sessionId,
requireFocus: false, requireFocus: false,
handler: chatHandler, handler: chatHandler,
); );
socketSubscriptions.add(chatSub); socketSubscriptions.add(chatSub.dispose);
}
if (channelEvents != null) {
final subscription = channelEvents.listen((event) {
socketWatchdog?.ping();
channelEventsHandler(event.raw, event.ack);
});
socketSubscriptions.add(() {
unawaited(subscription.cancel());
});
} else if (socketService != null) {
final channelSub = socketService.addChannelEventHandler( final channelSub = socketService.addChannelEventHandler(
conversationId: activeConversationId, conversationId: activeConversationId,
sessionId: sessionId, sessionId: sessionId,
requireFocus: false, requireFocus: false,
handler: channelEventsHandler, handler: channelEventsHandler,
); );
socketSubscriptions.add(channelSub); socketSubscriptions.add(channelSub.dispose);
} }
final subscription = persistentController.stream.listen( final subscription = persistentController.stream.listen(

View File

@@ -6,10 +6,10 @@ import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'package:uuid/uuid.dart'; import 'package:uuid/uuid.dart';
import '../../../core/utils/tool_calls_parser.dart'; import '../../../core/utils/tool_calls_parser.dart';
import '../../../core/services/streaming_helper.dart'; import '../../../core/services/streaming_helper.dart';
import '../../../core/services/socket_service.dart';
import '../../../core/models/chat_message.dart'; import '../../../core/models/chat_message.dart';
import '../../../core/models/conversation.dart'; import '../../../core/models/conversation.dart';
import '../../../core/providers/app_providers.dart'; import '../../../core/providers/app_providers.dart';
import '../../../core/models/socket_event.dart';
import '../../../core/auth/auth_state_manager.dart'; import '../../../core/auth/auth_state_manager.dart';
import '../../../core/utils/inactivity_watchdog.dart'; import '../../../core/utils/inactivity_watchdog.dart';
import '../services/reviewer_mode_service.dart'; import '../services/reviewer_mode_service.dart';
@@ -89,7 +89,7 @@ class ChatMessagesNotifier extends Notifier<List<ChatMessage>> {
StreamSubscription? _messageStream; StreamSubscription? _messageStream;
ProviderSubscription? _conversationListener; ProviderSubscription? _conversationListener;
final List<StreamSubscription> _subscriptions = []; final List<StreamSubscription> _subscriptions = [];
final List<SocketEventSubscription> _socketSubscriptions = []; final List<VoidCallback> _socketSubscriptions = [];
VoidCallback? _socketTeardown; VoidCallback? _socketTeardown;
// Activity-based watchdog to prevent stuck typing indicator // Activity-based watchdog to prevent stuck typing indicator
InactivityWatchdog? _typingWatchdog; InactivityWatchdog? _typingWatchdog;
@@ -397,7 +397,7 @@ class ChatMessagesNotifier extends Notifier<List<ChatMessage>> {
} }
void setSocketSubscriptions( void setSocketSubscriptions(
List<SocketEventSubscription> subscriptions, { List<VoidCallback> subscriptions, {
VoidCallback? onDispose, VoidCallback? onDispose,
}) { }) {
cancelSocketSubscriptions(); cancelSocketSubscriptions();
@@ -411,9 +411,9 @@ class ChatMessagesNotifier extends Notifier<List<ChatMessage>> {
_socketTeardown = null; _socketTeardown = null;
return; return;
} }
for (final sub in _socketSubscriptions) { for (final dispose in _socketSubscriptions) {
try { try {
sub.dispose(); dispose();
} catch (_) {} } catch (_) {}
} }
_socketSubscriptions.clear(); _socketSubscriptions.clear();
@@ -1332,6 +1332,30 @@ Future<void> regenerateMessage(
}); });
} catch (_) {} } catch (_) {}
final chatEventsStream = ref
.read(
conversationDeltaStreamProvider(
ConversationDeltaRequest.chat(
conversationId: activeConversation.id,
sessionId: effectiveSessionId,
requireFocus: false,
),
).notifier,
)
.stream;
final channelEventsStream = ref
.read(
conversationDeltaStreamProvider(
ConversationDeltaRequest.channel(
conversationId: activeConversation.id,
sessionId: effectiveSessionId,
requireFocus: false,
),
).notifier,
)
.stream;
final activeStream = attachUnifiedChunkedStreaming( final activeStream = attachUnifiedChunkedStreaming(
stream: stream, stream: stream,
webSearchEnabled: webSearchEnabled, webSearchEnabled: webSearchEnabled,
@@ -1342,6 +1366,8 @@ Future<void> regenerateMessage(
activeConversationId: activeConversation.id, activeConversationId: activeConversation.id,
api: api, api: api,
socketService: socketService, socketService: socketService,
chatEvents: chatEventsStream,
channelEvents: channelEventsStream,
appendToLastMessage: (c) => appendToLastMessage: (c) =>
ref.read(chatMessagesProvider.notifier).appendToLastMessage(c), ref.read(chatMessagesProvider.notifier).appendToLastMessage(c),
replaceLastMessageContent: (c) => replaceLastMessageContent: (c) =>
@@ -1867,6 +1893,30 @@ Future<void> _sendMessageInternal(
}); });
} catch (_) {} } catch (_) {}
final chatEventsStream = ref
.read(
conversationDeltaStreamProvider(
ConversationDeltaRequest.chat(
conversationId: activeConversation?.id,
sessionId: effectiveSessionId,
requireFocus: false,
),
).notifier,
)
.stream;
final channelEventsStream = ref
.read(
conversationDeltaStreamProvider(
ConversationDeltaRequest.channel(
conversationId: activeConversation?.id,
sessionId: effectiveSessionId,
requireFocus: false,
),
).notifier,
)
.stream;
final activeStream = attachUnifiedChunkedStreaming( final activeStream = attachUnifiedChunkedStreaming(
stream: stream, stream: stream,
webSearchEnabled: webSearchEnabled, webSearchEnabled: webSearchEnabled,
@@ -1877,6 +1927,8 @@ Future<void> _sendMessageInternal(
activeConversationId: activeConversation?.id, activeConversationId: activeConversation?.id,
api: api, api: api,
socketService: socketService, socketService: socketService,
chatEvents: chatEventsStream,
channelEvents: channelEventsStream,
appendToLastMessage: (c) => appendToLastMessage: (c) =>
ref.read(chatMessagesProvider.notifier).appendToLastMessage(c), ref.read(chatMessagesProvider.notifier).appendToLastMessage(c),
replaceLastMessageContent: (c) => replaceLastMessageContent: (c) =>