datawithsuman committed on
Commit
213363a
1 Parent(s): a4aa40a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -87,22 +87,33 @@ if uploaded_files:
87
  index = VectorStoreIndex(nodes=nodes, storage_context=storage_context)
88
 
89
  # Retrieval
90
- bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=3)
91
  vector_retriever = index.as_retriever(similarity_top_k=3)
92
 
93
  # Hybrid Retriever class
94
  class HybridRetriever(BaseRetriever):
95
  def __init__(self, vector_retriever, bm25_retriever):
96
- self.vector_retriever = vector_retriever
97
  self.bm25_retriever = bm25_retriever
98
  super().__init__()
99
 
 
 
 
 
 
 
 
 
 
 
 
100
  def _retrieve(self, query, **kwargs):
101
  bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
102
- vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
103
  all_nodes = []
104
  node_ids = set()
105
- for n in bm25_nodes + vector_nodes:
106
  if n.node.node_id not in node_ids:
107
  all_nodes.append(n)
108
  node_ids.add(n.node.node_id)
 
87
  index = VectorStoreIndex(nodes=nodes, storage_context=storage_context)
88
 
89
  # Retrieval
90
+ bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=1)
91
  vector_retriever = index.as_retriever(similarity_top_k=3)
92
 
93
  # Hybrid Retriever class
94
  class HybridRetriever(BaseRetriever):
95
  def __init__(self, vector_retriever, bm25_retriever):
96
+ # self.vector_retriever = vector_retriever
97
  self.bm25_retriever = bm25_retriever
98
  super().__init__()
99
 
100
+ # def _retrieve(self, query, **kwargs):
101
+ # bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
102
+ # vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
103
+ # all_nodes = []
104
+ # node_ids = set()
105
+ # for n in bm25_nodes + vector_nodes:
106
+ # if n.node.node_id not in node_ids:
107
+ # all_nodes.append(n)
108
+ # node_ids.add(n.node.node_id)
109
+ # return all_nodes
110
+
111
  def _retrieve(self, query, **kwargs):
112
  bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
113
+ # vector_nodes = self.vector_retriever.retrieve(query, **kwargs)
114
  all_nodes = []
115
  node_ids = set()
116
+ for n in bm25_nodes
117
  if n.node.node_id not in node_ids:
118
  all_nodes.append(n)
119
  node_ids.add(n.node.node_id)