Проблема:
Время ответа чат-бота составляет примерно 80 секунд на один запрос, что кажется слишком долгим. Вот разбивка задержки:
Вычисление эмбеддинга запроса и поиск по сходству (AstraDB): ~5 секунд.
Генерация ответа LLM (Mistral 7B): ~75 секунд (метод answer_query).
Моя цель:
Я хочу сократить общее время ответа до значения менее 10 секунд, чтобы с чат-ботом было удобно работать.
Мои вопросы:
Как оптимизировать вычисление эмбеддингов и поиск по сходству в AstraDB, чтобы ускорить этот этап?
Существуют ли конкретные способы настроить Mistral 7B или повысить её эффективность, чтобы ускорить генерацию ответа?
Стоит ли масштабировать или перенастроить мой экземпляр EC2 — или вообще выбрать другую стратегию хостинга?
Будем очень признательны за любые советы, подсказки или инструменты, которые вы можете порекомендовать для повышения производительности!
Ниже приведён код, который я использую для получения ответа на запрос:
class EnhancedChatbot:
    """Retrieval-augmented chatbot over a document collection.

    A query is embedded, the nearest stored documents are selected by
    Euclidean distance, and a short context built from them is passed to the
    shared LLM (``global_llm``) through a stuff-documents chain.
    """

    # Embedding model shared by all instances. The module-level answer_query()
    # creates a new EnhancedChatbot for every request, so caching on the
    # instance would not help. Previously the model was re-created on every
    # get_relevant_documents() call, which added avoidable retrieval latency.
    _embeddings_model = None

    def __init__(self, collection):
        """Bind the chatbot to *collection* and the globally loaded LLM.

        Args:
            collection: store that is read via ``find()`` and written via
                ``insert_one()``.
        """
        self.llm = global_llm
        self.collection = collection
        self.prompt_template = PromptTemplate(
            template="""
Based on the following context, please answer the question.
Context: {context}
Question: {question}
Answer:""",
            input_variables=["context", "question"]
        )

    @classmethod
    def _get_embeddings_model(cls):
        """Return the shared query-embedding model, loading it on first use."""
        if cls._embeddings_model is None:
            cls._embeddings_model = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-mpnet-base-v2",
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': False}
            )
        return cls._embeddings_model

    def get_relevant_documents(self, query: str, k: int = 6) -> List[Document]:
        """Return up to *k* stored documents closest to *query*, nearest first.

        Closeness is the Euclidean distance between the query embedding and
        each document's stored ``embedding``. A smaller distance means a better
        match. The distance is recorded in ``metadata['similarity']``; the key
        name is kept for backward compatibility.

        Documents that are not dicts, lack ``content``/``embedding``, or whose
        embedding shape differs from the query's are skipped.

        NOTE(review): ``self.collection.find()`` fetches the whole collection
        and ranks it client-side on every query. If the Astra collection has
        vector search enabled, a server-side vector sort would avoid that
        transfer. Confirm the astrapy API version in use before switching.

        Args:
            query: the user question.
            k: maximum number of documents to return (default 6).

        Returns:
            The selected documents, or ``[]`` when nothing usable is stored or
            an unexpected error occurs (the error is printed).
        """
        try:
            query_embedding = np.asarray(
                self._get_embeddings_model().embed_query(query), dtype=float
            )
            results = self.collection.find()
            if not results or not isinstance(results, dict) or 'data' not in results:
                print("No documents found in collection")
                return []
            documents = results['data'].get('documents', [])
            if not documents:
                print("No documents in collection data")
                return []

            # (distance, Document) pairs. The sort key is kept separate so a
            # stored 'similarity' field in the document metadata cannot
            # override it.
            scored = []
            for doc in documents:
                try:
                    if not isinstance(doc, dict):
                        continue
                    content = doc.get('content')
                    embedding = doc.get('embedding')
                    if not content or not embedding:
                        continue
                    embedding = np.asarray(embedding, dtype=float)
                    # Must match the query embedding's dimensionality
                    # (768 for all-mpnet-base-v2).
                    if embedding.shape != query_embedding.shape:
                        print(f"Skipping document with invalid embedding shape: {embedding.shape}")
                        continue
                    distance = float(np.linalg.norm(embedding - query_embedding))
                    scored.append((
                        distance,
                        Document(
                            page_content=content,
                            metadata={'similarity': distance, **(doc.get('metadata') or {})}
                        )
                    ))
                except Exception as doc_error:
                    print(f"Error processing document: {str(doc_error)}")
                    continue

            # Ascending distance puts the nearest documents first. The previous
            # descending sort returned the *least* similar documents.
            scored.sort(key=lambda pair: pair[0])
            return [document for _, document in scored[:k]]
        except Exception as e:
            print(f"Error in get_relevant_documents: {str(e)}")
            return []

    def answer_query(self, query: str) -> str:
        """Answer *query* from retrieved context and return a formatted reply.

        The context is capped at 3 chunks of 200 tokens and then at 800
        characters, which keeps the prompt short.

        NOTE(review): generation in the ``chain(...)`` call below is reported
        to dominate response time (~75 s). Its speed depends on how
        ``global_llm`` is loaded (see ``initialize_llm``), which is outside
        this class. Thread count, GPU layer offload and a cap on generated
        tokens are the usual levers; confirm against the loader in use.

        Returns:
            The string produced by ``format_response``. Errors are reported
            inside that string rather than raised.
        """
        try:
            if not query or not isinstance(query, str):
                return self.format_response("Invalid query provided", [], query)
            relevant_docs = self.get_relevant_documents(query)
            if not relevant_docs:
                return self.format_response(
                    "No relevant information found in the database.",
                    [],
                    query
                )
            text_splitter = TokenTextSplitter(
                chunk_size=200,
                chunk_overlap=20
            )
            texts = text_splitter.split_documents(relevant_docs)
            max_chunks = 3
            texts = texts[:max_chunks]
            llm_chain = LLMChain(
                llm=self.llm,
                prompt=self.prompt_template
            )
            chain = StuffDocumentsChain(
                llm_chain=llm_chain,
                document_variable_name="context"
            )
            context = "\n".join([doc.page_content for doc in texts])
            max_context_chars = 800
            if len(context) > max_context_chars:
                context = context[:max_context_chars] + "..."
            response = chain({
                "question": query,
                "input_documents": [Document(page_content=context)]
            })
            response_text = response.get('output_text', 'No response generated')
            print(f"Response text: {response_text}")
            return self.format_response(
                response_text,
                [doc.metadata.get('source', 'Unknown') for doc in relevant_docs],
                query
            )
        except Exception as e:
            print(f"Error in answer_query: {str(e)}")
            return self.format_response(
                f"An error occurred while processing your query: {str(e)}",
                [],
                query
            )

    def format_response(self, message: str, sources: List[str], query: str) -> str:
        """Render *message*, *query* and *sources* as a plain-text block.

        When *sources* is empty, the detail line is prefixed with
        "Unable to provide detailed information due to:" and the sources line
        reads "No relevant sources found".

        NOTE(review): when *sources* is non-empty, *message* appears twice
        (first line and detail line). Confirm whether that is intended.
        """
        try:
            sources_text = "\n".join([f"{source}" for source in sources]) if sources else "No relevant sources found"
            print(f"source response: {sources_text}")
            detail = message if sources else f"Unable to provide detailed information due to: {message}"
            return f"""
{message}
Query: {query}
{detail}
{sources_text}
"""
        except Exception as e:
            print(f"Error formatting response: {str(e)}")
            return """
Error formatting response
An error occurred while formatting the response
No sources available
"""

    def train_on_example(self, example: Dict[str, str]):
        """Store a question/answer pair in the collection as a training example.

        Args:
            example: dict with at least ``'question'`` and ``'answer'`` keys.

        Errors (bad format, missing collection, insert failure) are printed,
        not raised.
        """
        try:
            if not isinstance(example, dict) or 'question' not in example or 'answer' not in example:
                raise ValueError("Invalid training example format")
            formatted_prompt = self.prompt_template.format(
                context="",
                question=example['question']
            )
            training_example = {
                'prompt': formatted_prompt,
                'response': example['answer'],
                'metadata': {
                    'timestamp': str(time.time()),
                    'type': 'training_example'
                }
            }
            # Explicit None check: some collection types (e.g. pymongo's)
            # raise when truth-tested.
            if self.collection is not None:
                self.collection.insert_one(training_example)
                print(f"Training example added successfully for question: {example['question']}")
            else:
                raise ValueError("Collection not initialized")
        except Exception as e:
            print(f"Error adding training example: {str(e)}")
def initialize_application(model_path: str):
    """Load the application's LLM from *model_path*.

    Delegates to ``initialize_llm`` with ``model_type='mistral'``. A success
    message is printed on completion. On failure the error is printed and the
    original exception is re-raised to the caller.
    """
    try:
        initialize_llm(model_path, model_type='mistral')
        print("Application initialized successfully")
    except Exception as init_error:
        failure_message = f"Failed to initialize application: {str(init_error)}"
        print(failure_message)
        raise
def answer_query(query: str):
    """Answer *query* using a fresh EnhancedChatbot bound to the global collection.

    Returns:
        The chatbot's formatted answer string, or None when
        ``chatbot_collection`` is not set or the chatbot raised (the problem is
        printed in both cases).
    """
    collection = chatbot_collection
    if collection is None:
        print("Collection not initialized")
        return None
    try:
        bot = EnhancedChatbot(collection=collection)
        return bot.answer_query(query)
    except Exception as failure:
        print(f"Error processing query: {str(failure)}")
        return None
def _bearer_user_id():
    """Return the user id decoded from an ``Authorization: Bearer <token>`` header.

    Returns None when the header is absent or carries no token, instead of
    raising on ``split(' ')[1]``. Errors from ``decode_jwt_token`` propagate.
    """
    auth_header = request.headers.get('Authorization') or ''
    parts = auth_header.split(' ')
    if len(parts) < 2 or not parts[1]:
        return None
    return decode_jwt_token(parts[1])


def _find_or_create_active_session(user_id):
    """Return the id of *user_id*'s active session, creating one if none exists.

    Returns None if a new session could not be written.
    """
    user = db.users.find_one(
        {
            "_id": ObjectId(user_id),
            "queries.status": "active"
        },
        {"queries": {"$elemMatch": {"status": "active"}}}
    )
    if user and user.get('queries'):
        return user['queries'][0]['session_id']
    # Create new session if no active session exists
    new_session = {
        "session_id": str(uuid.uuid4()),
        "created_at": datetime.utcnow(),
        "status": "active",
        "queries": []
    }
    result = db.users.update_one(
        {"_id": ObjectId(user_id)},
        {"$push": {"queries": new_session}}
    )
    if result.modified_count == 0:
        return None
    return new_session["session_id"]


def _is_session_active(user_id, session_id):
    """Return True if *session_id* is an active session belonging to *user_id*."""
    user = db.users.find_one({
        "_id": ObjectId(user_id),
        "queries": {
            "$elemMatch": {
                "session_id": session_id,
                "status": "active"
            }
        }
    })
    return bool(user)


def _store_query(user_id, session_id, question, response):
    """Append a question/answer pair to an active session (best effort).

    Failures are logged, never raised, so the caller still returns the answer.
    The filter re-checks ``status == "active"``: the session may have been
    closed while the LLM was generating. The positional ``$`` targets the
    element matched by ``$elemMatch``.
    """
    try:
        query_data = {
            "query": question,
            "response": response,
            "timestamp": datetime.utcnow()
        }
        result = db.users.update_one(
            {
                "_id": ObjectId(user_id),
                "queries": {
                    "$elemMatch": {
                        "session_id": session_id,
                        "status": "active"
                    }
                }
            },
            {"$push": {"queries.$.queries": query_data}}
        )
        if not result.modified_count:
            logger.error(f"Failed to update session {session_id}")
    except Exception as e:
        logger.error(f"Error storing query in session: {str(e)}")


@app.route('/query', methods=['POST'])
async def query():
    """Answer a chatbot question and record it in the user's session history.

    Request JSON: ``{"query": str, "session_id": str (optional)}``.

    Anonymous callers are gated by ``check_query_limit()`` and a
    cookie-session counter. Authenticated callers (Bearer token) have their
    question/answer pair appended to an active session in ``db.users``; a
    session is found or created when none is supplied.

    NOTE(review): declared ``async`` but nothing is awaited. ``answer_query``
    blocks for the whole LLM generation.
    """
    try:
        authenticated = is_authenticated()
        print(f"authenticated: {authenticated}", flush=True)
        if not authenticated:
            if not check_query_limit():
                print("Query limit exceeded", flush=True)
                return jsonify({
                    "error": "Query limit exceeded. Please login to continue.",
                    "status": "error",
                    "requires_login": True
                }), 403
            session['query_count'] = session.get('query_count', 0) + 1

        # silent=True: a missing or non-JSON body yields None (-> 400) instead
        # of an exception that the outer handler would turn into a 500.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({"error": "No data provided", "status": "error"}), 400
        question = data.get('query')
        session_id = data.get('session_id')
        if not question:
            return jsonify({"error": "No question provided", "status": "error"}), 400

        user_id = None
        record_history = False
        if authenticated:
            user_id = _bearer_user_id()
            if not user_id:
                # ObjectId(None) would silently mint a brand-new id.
                return jsonify({
                    "error": "Invalid or missing authorization token",
                    "status": "error"
                }), 401
            if not session_id:
                session_id = _find_or_create_active_session(user_id)
                if session_id is None:
                    return jsonify({
                        "error": "Failed to create session",
                        "status": "error"
                    }), 500
                record_history = True
            else:
                # Validate a client-supplied session *before* the expensive LLM
                # call, so an ended session no longer costs a full generation.
                try:
                    active = _is_session_active(user_id, session_id)
                except Exception as e:
                    # Best effort, as before: a failed lookup still answers the
                    # question but skips recording it.
                    logger.error(f"Error checking session {session_id}: {str(e)}")
                else:
                    if not active:
                        return jsonify({
                            "error": "Session not found or has ended",
                            "status": "error"
                        }), 400
                    record_history = True

        response = answer_query(question)
        if not response:
            return jsonify({
                "error": "Failed to generate response",
                "status": "error"
            }), 500

        if record_history:
            _store_query(user_id, session_id, question, response)

        # NOTE(review): 2 must match the limit enforced by check_query_limit().
        remaining_queries = (
            max(0, 2 - session.get('query_count', 0)) if not authenticated else None
        )
        return jsonify({
            "answer": response,
            "status": "success",
            "remaining_queries": remaining_queries,
            "authenticated": authenticated,
            "session_id": session_id
        }), 200
    except Exception as e:
        logger.error(f"Query endpoint error: {str(e)}")
        return jsonify({
            "error": f"An error occurred: {str(e)}",
            "status": "error"
        }), 500
if __name__ == '__main__':
    # Bring up the database connection and load the local GGUF model before
    # serving. Presumably connect_to_astra_db() populates chatbot_collection,
    # which answer_query() requires (confirm).
    connect_to_astra_db()
    gguf_model_path = './models/models/capybarahermes-2.5-mistral-7b.Q4_K_M.gguf'
    initialize_application(gguf_model_path)
    # Flask development server, listening on all interfaces.
    app.run(host='0.0.0.0', port=5000)
Подробнее здесь: https://stackoverflow.com/questions/792 ... mistral-7b