import os
from datasets import load_dataset
from haystack import Pipeline, Document
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.writers import DocumentWriter
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from mixedbread_ai_haystack import MixedbreadAIDocumentEmbedder, MixedbreadAITextEmbedder, MixedbreadAIReranker
# Set API Key
os.environ["MXBAI_API_KEY"] = "YOUR_API_KEY"
# Load Dataset and Prepare Documents
ds = load_dataset("rajuptvs/ecommerce_products_clip")
documents = [
Document(
id=str(i),
content=data["Description"],
meta={
"name": data["Product_name"],
"price": data["Price"],
"colors": data["colors"],
"pattern": data["Pattern"],
"extra": data["Other Details"]
}
) for i, data in enumerate(ds["train"])
]
meta_fields = documents[0].meta.keys()
# Define Components
document_store = InMemoryDocumentStore(embedding_similarity_function="cosine")
document_writer = DocumentWriter(document_store=document_store)
embedding_retriever = InMemoryEmbeddingRetriever(document_store=document_store, top_k=20)
embed_model = "mixedbread-ai/mxbai-embed-large-v1"
reranking_model = "mixedbread-ai/mxbai-rerank-large-v1"
text_embedder = MixedbreadAITextEmbedder(model=embed_model)
document_embedder = MixedbreadAIDocumentEmbedder(model=embed_model, max_concurrency=3, meta_fields_to_embed=meta_fields, show_progress_bar=True)
reranker = MixedbreadAIReranker(model=reranking_model, meta_fields_to_rank=meta_fields, top_k=5)
# Indexing Pipeline
indexing_pipeline = Pipeline()
indexing_pipeline.add_component(instance=document_embedder, name="document_embedder")
indexing_pipeline.add_component(instance=document_writer, name="document_writer")
indexing_pipeline.connect("document_embedder", "document_writer")
# Query Pipeline
query_pipeline = Pipeline()
query_pipeline.add_component(instance=text_embedder, name="text_embedder")
query_pipeline.add_component(instance=embedding_retriever, name="embedding_retriever")
query_pipeline.add_component(instance=reranker, name="reranker")
query_pipeline.connect("text_embedder", "embedding_retriever")
query_pipeline.connect("embedding_retriever.documents", "reranker.documents")
# Index the dataset
indexing_pipeline.run({"document_embedder": {"documents": documents}})
# Query to get results
query = "I am looking for a regular fit t-shirt in blue color. Ideally without any prints. What are my options?"
results = query_pipeline.run(
{
"text_embedder": {"text": query},
"reranker": {"query": query}
}
)
print(results["reranker"]["documents"])