feat: tools implementation

This commit is contained in:
cogwheel0
2025-08-19 20:26:19 +05:30
parent d3742944bc
commit 267a45cd9e
8 changed files with 189 additions and 14 deletions

26
lib/core/models/tool.dart Normal file
View File

@@ -0,0 +1,26 @@
import 'package:freezed_annotation/freezed_annotation.dart';
part 'tool.freezed.dart';
@freezed
sealed class Tool with _$Tool {
const Tool._();
const factory Tool({
required String id,
required String name,
String? description,
String? userId,
Map<String, dynamic>? meta,
}) = _Tool;
factory Tool.fromJson(Map<String, dynamic> json) {
return Tool(
id: json['id'] as String,
name: json['name'] as String,
description: json['description'] as String?,
userId: json['user_id'] as String?,
meta: json['meta'] as Map<String, dynamic>?,
);
}
}

View File

@@ -24,6 +24,9 @@ class ApiService {
late final ApiAuthInterceptor _authInterceptor; late final ApiAuthInterceptor _authInterceptor;
// Removed legacy websocket/socket.io fields // Removed legacy websocket/socket.io fields
// Public getter for dio instance
Dio get dio => _dio;
// Callback to notify when auth token becomes invalid // Callback to notify when auth token becomes invalid
void Function()? onAuthTokenInvalid; void Function()? onAuthTokenInvalid;
@@ -2415,7 +2418,7 @@ class ApiService {
required List<Map<String, dynamic>> messages, required List<Map<String, dynamic>> messages,
required String model, required String model,
String? conversationId, String? conversationId,
List<Map<String, dynamic>>? tools, List<String>? toolIds,
bool enableWebSearch = false, bool enableWebSearch = false,
Map<String, dynamic>? modelItem, Map<String, dynamic>? modelItem,
}) { }) {
@@ -2500,6 +2503,12 @@ class ApiService {
debugPrint('DEBUG: Web search enabled in SSE request'); debugPrint('DEBUG: Web search enabled in SSE request');
} }
// Add tool_ids if provided (Open-WebUI expects tool_ids as array of strings)
if (toolIds != null && toolIds.isNotEmpty) {
data['tool_ids'] = toolIds;
debugPrint('DEBUG: Including tool_ids in SSE request: $toolIds');
}
// Don't add session_id or id - they break SSE streaming! // Don't add session_id or id - they break SSE streaming!
// The server falls back to task-based async when these are present // The server falls back to task-based async when these are present

View File

@@ -0,0 +1,29 @@
import 'package:dio/dio.dart';
import 'package:conduit/core/models/tool.dart';
import 'package:conduit/core/services/api_service.dart';
import 'package:conduit/core/error/api_error_handler.dart';
import 'package:conduit/core/providers/app_providers.dart';
import 'package:flutter_riverpod/flutter_riverpod.dart';
class ToolsService {
final ApiService _apiService;
ToolsService(this._apiService);
Future<List<Tool>> getTools() async {
try {
final response = await _apiService.dio.get('/api/v1/tools/');
return (response.data as List)
.map((json) => Tool.fromJson(json))
.toList();
} on DioException catch (e) {
throw ApiErrorHandler().transformError(e);
}
}
}
final toolsServiceProvider = Provider<ToolsService?>((ref) {
final apiService = ref.watch(apiServiceProvider);
if (apiService == null) return null;
return ToolsService(apiService);
});

View File

@@ -507,20 +507,22 @@ Future<void> regenerateMessage(
Future<void> sendMessage( Future<void> sendMessage(
WidgetRef ref, WidgetRef ref,
String message, String message,
List<String>? attachments, List<String>? attachments, [
) async { List<String>? toolIds,
]) async {
debugPrint( debugPrint(
'DEBUG: sendMessage called with message: $message, attachments: $attachments', 'DEBUG: sendMessage called with message: $message, attachments: $attachments, tools: $toolIds',
); );
await _sendMessageInternal(ref, message, attachments); await _sendMessageInternal(ref, message, attachments, toolIds);
} }
// Internal send message implementation // Internal send message implementation
Future<void> _sendMessageInternal( Future<void> _sendMessageInternal(
dynamic ref, dynamic ref,
String message, String message,
List<String>? attachments, List<String>? attachments, [
) async { List<String>? toolIds,
]) async {
debugPrint('DEBUG: _sendMessageInternal called'); debugPrint('DEBUG: _sendMessageInternal called');
debugPrint('DEBUG: Message: $message'); debugPrint('DEBUG: Message: $message');
debugPrint('DEBUG: Attachments: $attachments'); debugPrint('DEBUG: Attachments: $attachments');
@@ -543,7 +545,7 @@ Future<void> _sendMessageInternal(
debugPrint('DEBUG: Active conversation before send: ${activeConversation?.id}'); debugPrint('DEBUG: Active conversation before send: ${activeConversation?.id}');
// Create user message first // Create user message first
debugPrint('DEBUG: Creating user message with attachments: $attachments'); debugPrint('DEBUG: Creating user message with attachments: $attachments, tools: $toolIds');
final userMessage = ChatMessage( final userMessage = ChatMessage(
id: const Uuid().v4(), id: const Uuid().v4(),
role: 'user', role: 'user',
@@ -794,8 +796,11 @@ Future<void> _sendMessageInternal(
// Debug log to track web search state // Debug log to track web search state
debugPrint('DEBUG: Web search toggle state: $webSearchEnabled'); debugPrint('DEBUG: Web search toggle state: $webSearchEnabled');
// No need for function calling tools since we're using retrieval directly // Prepare tools list - pass tool IDs directly
final tools = <Map<String, dynamic>>[]; final List<String>? toolIdsForApi = (toolIds != null && toolIds.isNotEmpty) ? toolIds : null;
if (toolIdsForApi != null) {
debugPrint('DEBUG: Including tool IDs: $toolIdsForApi');
}
try { try {
// Use the model's actual supported parameters if available // Use the model's actual supported parameters if available
@@ -927,7 +932,7 @@ Future<void> _sendMessageInternal(
messages: conversationMessages, messages: conversationMessages,
model: selectedModel.id, model: selectedModel.id,
conversationId: activeConversation?.id, conversationId: activeConversation?.id,
tools: tools.isNotEmpty ? tools : null, toolIds: toolIdsForApi,
enableWebSearch: webSearchEnabled, enableWebSearch: webSearchEnabled,
modelItem: modelItem, modelItem: modelItem,
); );

View File

@@ -22,6 +22,7 @@ import '../services/file_attachment_service.dart';
import '../../navigation/views/chats_list_page.dart'; import '../../navigation/views/chats_list_page.dart';
import '../../files/views/files_page.dart'; import '../../files/views/files_page.dart';
import '../../profile/views/profile_page.dart'; import '../../profile/views/profile_page.dart';
import '../../tools/providers/tools_providers.dart';
import '../../../shared/widgets/offline_indicator.dart'; import '../../../shared/widgets/offline_indicator.dart';
import '../../../core/services/connectivity_service.dart'; import '../../../core/services/connectivity_service.dart';
import '../../../core/models/chat_message.dart'; import '../../../core/models/chat_message.dart';
@@ -289,11 +290,16 @@ class _ChatPageState extends ConsumerState<ChatPage> {
debugPrint('DEBUG: Uploaded file IDs: $uploadedFileIds'); debugPrint('DEBUG: Uploaded file IDs: $uploadedFileIds');
// Send message with file attachments using existing provider logic // Get selected tools
final toolIds = ref.read(selectedToolIdsProvider);
debugPrint('DEBUG: Selected tool IDs: $toolIds');
// Send message with file attachments and tools using existing provider logic
await sendMessage( await sendMessage(
ref, ref,
text, text,
uploadedFileIds.isNotEmpty ? uploadedFileIds : null, uploadedFileIds.isNotEmpty ? uploadedFileIds : null,
toolIds.isNotEmpty ? toolIds : null,
); );
debugPrint('DEBUG: Message sent successfully'); debugPrint('DEBUG: Message sent successfully');
@@ -744,8 +750,6 @@ class _ChatPageState extends ConsumerState<ChatPage> {
).push(MaterialPageRoute(builder: (context) => const FilesPage())); ).push(MaterialPageRoute(builder: (context) => const FilesPage()));
} }
void _navigateToProfile() { void _navigateToProfile() {
Navigator.of( Navigator.of(
context, context,

View File

@@ -8,6 +8,8 @@ import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'dart:io' show Platform; import 'dart:io' show Platform;
import 'dart:async'; import 'dart:async';
import '../providers/chat_providers.dart'; import '../providers/chat_providers.dart';
import '../../tools/widgets/tool_selector.dart';
import '../../tools/providers/tools_providers.dart';
import '../../../shared/utils/platform_utils.dart'; import '../../../shared/utils/platform_utils.dart';
@@ -47,6 +49,7 @@ class _ModernChatInputState extends ConsumerState<ModernChatInput>
late AnimationController _pulseController; late AnimationController _pulseController;
Timer? _blurCollapseTimer; Timer? _blurCollapseTimer;
bool _hasAutoFocusedOnce = false; bool _hasAutoFocusedOnce = false;
bool _showToolSelector = false;
@override @override
void initState() { void initState() {
@@ -165,6 +168,11 @@ class _ModernChatInputState extends ConsumerState<ModernChatInput>
PlatformUtils.lightHaptic(); PlatformUtils.lightHaptic();
widget.onSendMessage(text); widget.onSendMessage(text);
_controller.clear(); _controller.clear();
setState(() {
_showToolSelector = false;
});
// Clear selected tools after sending
ref.read(selectedToolIdsProvider.notifier).state = [];
// Keep input expanded and focused for better UX - don't dismiss keyboard // Keep input expanded and focused for better UX - don't dismiss keyboard
// KeyboardUtils.dismissKeyboard(context); // KeyboardUtils.dismissKeyboard(context);
// _setExpanded(false); // _setExpanded(false);
@@ -346,6 +354,14 @@ class _ModernChatInputState extends ConsumerState<ModernChatInput>
// Expanded bottom row with additional options // Expanded bottom row with additional options
if (_isExpanded) ...[ if (_isExpanded) ...[
// Tool selector
if (_showToolSelector)
const Padding(
padding: EdgeInsets.symmetric(
horizontal: Spacing.inputPadding,
),
child: ToolSelector(),
),
Container( Container(
padding: const EdgeInsets.only( padding: const EdgeInsets.only(
left: Spacing.inputPadding, left: Spacing.inputPadding,
@@ -364,6 +380,20 @@ class _ModernChatInputState extends ConsumerState<ModernChatInput>
tooltip: 'Add attachment', tooltip: 'Add attachment',
), ),
const SizedBox(width: Spacing.sm), const SizedBox(width: Spacing.sm),
// Tools button
_buildRoundButton(
icon: Icons.build,
onTap: widget.enabled
? () {
setState(() {
_showToolSelector = !_showToolSelector;
});
}
: null,
tooltip: 'Tools',
isActive: _showToolSelector || ref.watch(selectedToolIdsProvider).isNotEmpty,
),
const SizedBox(width: Spacing.sm),
Flexible( Flexible(
child: Center(child: _buildResearchToggle()), child: Center(child: _buildResearchToggle()),
), ),

View File

@@ -0,0 +1,11 @@
import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'package:conduit/core/models/tool.dart';
import 'package:conduit/core/services/tools_service.dart';
final toolsListProvider = FutureProvider<List<Tool>>((ref) async {
final toolsService = ref.watch(toolsServiceProvider);
if (toolsService == null) return [];
return await toolsService.getTools();
});
final selectedToolIdsProvider = StateProvider<List<String>>((ref) => []);

View File

@@ -0,0 +1,61 @@
import 'package:flutter/material.dart';
import 'package:flutter_riverpod/flutter_riverpod.dart';
import 'package:conduit/features/tools/providers/tools_providers.dart';
class ToolSelector extends ConsumerWidget {
const ToolSelector({super.key});
@override
Widget build(BuildContext context, WidgetRef ref) {
final toolsAsync = ref.watch(toolsListProvider);
final selectedIds = ref.watch(selectedToolIdsProvider);
final theme = Theme.of(context);
return toolsAsync.when(
data: (tools) {
if (tools.isEmpty) {
return const SizedBox.shrink();
}
return Container(
height: 40,
margin: const EdgeInsets.symmetric(vertical: 8),
child: ListView.separated(
scrollDirection: Axis.horizontal,
padding: const EdgeInsets.symmetric(horizontal: 12),
itemCount: tools.length,
separatorBuilder: (context, index) => const SizedBox(width: 8),
itemBuilder: (context, index) {
final tool = tools[index];
final isSelected = selectedIds.contains(tool.id);
return FilterChip(
label: Text(tool.name),
selected: isSelected,
onSelected: (_) {
final currentIds = ref.read(selectedToolIdsProvider);
if (isSelected) {
ref.read(selectedToolIdsProvider.notifier).state =
currentIds.where((id) => id != tool.id).toList();
} else {
ref.read(selectedToolIdsProvider.notifier).state =
[...currentIds, tool.id];
}
},
avatar: Icon(
Icons.build,
size: 16,
color: isSelected
? theme.colorScheme.onSecondaryContainer
: theme.colorScheme.onSurfaceVariant,
),
);
},
),
);
},
loading: () => const SizedBox.shrink(),
error: (error, stack) => const SizedBox.shrink(),
);
}
}