Commit

Merge pull request #7 from prabha-git/mixtral
Mixtral
prabha-git authored Jan 16, 2024
2 parents e631891 + 2af541d commit 8759ca2
Showing 4 changed files with 72 additions and 6 deletions.
GoogleEvaluator.py: 4 changes (2 additions & 2 deletions)
@@ -56,8 +56,8 @@ async def get_response_from_model(self, prompt):
if __name__ == "__main__":
    # Tons of defaults set, check out the LLMNeedleHaystackTester's init for more info
    ht = GoogleEvaluator(model_name='gemini-pro', evaluation_method='substring_match',
-                        save_results_dir='gemini-pro-results-run2',
-                        save_contexts_dir='gemini-pro-context-run2'
+                        save_results_dir='results/gemini-pro/gemini-pro-results-run2',
+                        save_contexts_dir='contexts/gemini-pro/gemini-pro-context-run2'
                         )

    ht.start_test()
LLMNeedleHaystackTester.py: 5 changes (3 additions & 2 deletions)
@@ -219,16 +219,17 @@ async def evaluate_and_log(self, context_length, depth_percent):
            if not os.path.exists(self.save_contexts_dir):
                os.makedirs(self.save_contexts_dir)

-           with open(f'{self.save_contexts_dir}/{self.save_contexts_dir}_{context_file_location}_context.txt', 'w') as f:
+           with open(f'{self.save_contexts_dir}/{context_file_location}_context.txt', 'w') as f:
                f.write(context)


        if self.save_results:
            # Save the context to file for retesting
            if not os.path.exists(self.save_results_dir):
                os.makedirs(self.save_results_dir)

            # Save the result to file for retesting
-           with open(f'{self.save_results_dir}/{self.save_results_dir}_{context_file_location}_results.json', 'w') as f:
+           with open(f'{self.save_results_dir}/{context_file_location}_results.json', 'w') as f:
                json.dump(results, f)

        if self.seconds_to_sleep_between_completions:
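For illustration only (not part of the commit): with the nested directories now used by the evaluators, the old filename pattern repeated the save directory inside the file name, which points at sub-directories that os.makedirs never created, while the new pattern writes directly into the created directory. A minimal sketch using the 'contexts/gemini-pro/gemini-pro-context-run2' value from GoogleEvaluator.py; the context_file_location value is a hypothetical example:

save_contexts_dir = 'contexts/gemini-pro/gemini-pro-context-run2'
context_file_location = 'gemini_pro_len_1000_depth_50'  # hypothetical example value

# Old pattern: the directory path is embedded in the file name itself.
old_path = f'{save_contexts_dir}/{save_contexts_dir}_{context_file_location}_context.txt'
# -> 'contexts/gemini-pro/gemini-pro-context-run2/contexts/gemini-pro/...' (directory does not exist)

# New pattern: the file lands directly in the directory created by os.makedirs.
new_path = f'{save_contexts_dir}/{context_file_location}_context.txt'
# -> 'contexts/gemini-pro/gemini-pro-context-run2/gemini_pro_len_1000_depth_50_context.txt'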
OSS_TogetherEvaluator.py: 62 changes (62 additions & 0 deletions)
@@ -0,0 +1,62 @@
import os
import tiktoken
from LLMNeedleHaystackTester import LLMNeedleHaystackTester
from langchain_together import Together
import time


class OSS_TogetherEvaluator(LLMNeedleHaystackTester):
    def __init__(self, **kwargs):
        if 'together_api_key' not in kwargs and not os.getenv('TOGETHER_API_KEY'):
            raise ValueError("Either together_api_key must be supplied with init, or TOGETHER_API_KEY must be in env")

        if 'model_name' not in kwargs:
            raise ValueError("model_name must be supplied with init, accepted model_names are 'Mixtral-8x7B-Instruct-v0.1' and 'Mistral-7B-Instruct-v0.2'")
        elif kwargs['model_name'] not in ['Mixtral-8x7B-Instruct-v0.1', 'Mistral-7B-Instruct-v0.2']:
            raise ValueError(f"Model name must be in this list (Mixtral-8x7B-Instruct-v0.1, Mistral-7B-Instruct-v0.2) but given {kwargs['model_name']}")

        if 'evaluation_method' not in kwargs:
            print("Since evaluation_method is not specified, 'gpt4' will be used for evaluation")
        elif kwargs['evaluation_method'] not in ('gpt4', 'substring_match'):
            raise ValueError("evaluation_method must be 'substring_match' or 'gpt4'")

        # The Together client below reads TOGETHER_API_KEY from the environment.
        self.together_api_key = kwargs.pop('together_api_key', os.getenv('TOGETHER_API_KEY'))
        self.model_name = kwargs['model_name']
        self.model_to_test_description = kwargs.pop('model_name')
        self.tokenizer = tiktoken.encoding_for_model('gpt-4-1106-preview')  # Probably need to change in future
        self.model_to_test = Together(model=f'mistralai/{self.model_name}', temperature=0)
        kwargs['context_lengths_max'] = 32000
        super().__init__(**kwargs)

    def get_encoding(self, context):
        return self.tokenizer.encode(context)

    def get_decoding(self, encoded_context):
        return self.tokenizer.decode(encoded_context)

    def get_prompt(self, context):
        with open('Google_prompt.txt', 'r') as file:  # Need to change it
            prompt = file.read()
        return prompt.format(retrieval_question=self.retrieval_question, context=context)

    async def get_response_from_model(self, prompt):
        response = self.model_to_test.invoke(  # Need to make it async
            prompt
        )
        return response


if __name__ == "__main__":
    # Tons of defaults set, check out the LLMNeedleHaystackTester's init for more info
    ht = OSS_TogetherEvaluator(model_name='Mistral-7B-Instruct-v0.2', evaluation_method='substring_match',
                               save_results_dir='results/Mixtral_7B/results-run2',
                               save_contexts_dir='contexts/Mixtral_7B/contexts-run2'
                               )

    ht.start_test()
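The "# Need to make it async" note above could be handled, for example, by pushing the blocking Together.invoke call onto a worker thread. A minimal sketch, assuming Python 3.9+ for asyncio.to_thread; the helper name ainvoke_together is hypothetical and not part of this commit:

import asyncio

from langchain_together import Together


async def ainvoke_together(model: Together, prompt: str) -> str:
    # Run the synchronous .invoke() in a worker thread so the event loop
    # is not blocked while Together generates the completion.
    return await asyncio.to_thread(model.invoke, prompt)


# Usage inside OSS_TogetherEvaluator could then look like:
#     async def get_response_from_model(self, prompt):
#         return await ainvoke_together(self.model_to_test, prompt)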
requirements.txt: 7 changes (5 additions & 2 deletions)
@@ -18,8 +18,8 @@ huggingface-hub==0.19.4
idna==3.6
jsonpatch==1.33
jsonpointer==2.4
-langchain==0.0.341
-langchain-core==0.0.6
+langchain==0.1.0
+langchain-core==0.1.10
langsmith==0.0.67
marshmallow==3.20.1
multidict==6.0.4
@@ -43,3 +43,6 @@ typing-inspect==0.9.0
typing_extensions==4.8.0
urllib3==2.1.0
yarl==1.9.3

+langchain-together==0.0.2.post1
+together==0.2.10
