-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
159 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# from https://github.com/li-plus/chatglm.cpp/blob/main/.github/workflows/python-package.yml | ||
|
||
name: Python package | ||
|
||
on: | ||
push: | ||
branches: [ "main" ] | ||
pull_request: | ||
branches: [ "main" ] | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
os: [ubuntu-latest, macos-latest] | ||
python-version: ["3.8", "3.9", "3.10", "3.11"] | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
with: | ||
submodules: true | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
python -m pip install pytest build | ||
python -m build --sdist | ||
pip install dist/*.tar.gz -v | ||
- name: Lint with black | ||
uses: psf/black@stable | ||
with: | ||
options: "--check --verbose" | ||
src: "qwen_cpp examples tests setup.py" | ||
- name: Test with pytest | ||
run: | | ||
cd tests | ||
pytest test_qwen_cpp.py | ||
build-windows: | ||
|
||
runs-on: windows-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
with: | ||
submodules: true | ||
- name: Set up Python 3.8 | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: "3.8" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
python -m pip install pytest | ||
pip install . -v | ||
- name: Test with pytest | ||
run: | | ||
cd tests | ||
pytest test_qwen_cpp.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# https://github.com/li-plus/chatglm.cpp/blob/main/.github/workflows/wheels.yml | ||
|
||
name: Build Wheels | ||
|
||
on: | ||
workflow_dispatch: | ||
|
||
jobs: | ||
build_wheels: | ||
name: Build wheels on ${{ matrix.os }} | ||
runs-on: ${{ matrix.os }} | ||
strategy: | ||
matrix: | ||
# macos-13 is an intel runner, macos-14 is apple silicon | ||
os: [ubuntu-latest, windows-latest, macos-13, macos-14] | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
submodules: true | ||
|
||
- name: Build wheels | ||
uses: pypa/[email protected] | ||
env: | ||
CIBW_BUILD: cp* | ||
CIBW_SKIP: "*-win32 *_i686 *musllinux*" | ||
CIBW_TEST_REQUIRES: pytest | ||
CIBW_TEST_COMMAND: pytest {package}/tests/test_qwen_cpp.py | ||
|
||
- uses: actions/upload-artifact@v4 | ||
with: | ||
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} | ||
path: ./wheelhouse/*.whl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from pathlib import Path | ||
import qwen_cpp | ||
import pytest | ||
|
||
PROJECT_ROOT = Path(__file__).resolve().parent.parent | ||
|
||
QWEN_MODEL_PATH = PROJECT_ROOT / "qwen2_1.8b_f16.bin" | ||
QWEN_TIKTOKEN_PAHT= PROJECT_ROOT / "qwen.tiktoken" | ||
LLAMA_MODEL_PATH = PROJECT_ROOT / "llama.bin" | ||
LLAMA_TIKTOKEN_PATH = PROJECT_ROOT / "llama3.tiktoken" | ||
|
||
|
||
def test_qwen_version(): | ||
print(qwen_cpp.__version__) | ||
|
||
|
||
def check_pipeline(model_path, tiktoken_path, prompt, target, gen_kwargs={}): | ||
messages = [qwen_cpp.ChatMessage(role="system", content="You are a helpful assistant."), qwen_cpp.ChatMessage(role="user", content=prompt)] | ||
|
||
pipeline = qwen_cpp.Pipeline(model_path, tiktoken_path) | ||
output = pipeline.chat(messages, do_sample=False, **gen_kwargs).content | ||
assert output == target | ||
|
||
stream_output = pipeline.chat(messages, do_sample=False, stream=True, **gen_kwargs) | ||
stream_output = "".join([msg.content for msg in stream_output]) | ||
assert stream_output == target | ||
|
||
|
||
@pytest.mark.skipif(not QWEN_MODEL_PATH.exists(), reason="model file not found") | ||
def test_pipeline_options(): | ||
# check max_length option | ||
pipeline = qwen_cpp.Pipeline(QWEN_MODEL_PATH, QWEN_TIKTOKEN_PAHT) | ||
assert pipeline.model.config.max_length == 4096 | ||
pipeline = qwen_cpp.Pipeline(QWEN_MODEL_PATH, QWEN_TIKTOKEN_PAHT, max_length=234) | ||
assert pipeline.model.config.max_length == 234 | ||
|
||
# check if resources are properly released | ||
for _ in range(100): | ||
qwen_cpp.Pipeline(QWEN_MODEL_PATH, QWEN_TIKTOKEN_PAHT) | ||
|
||
|
||
@pytest.mark.skipif(not QWEN_MODEL_PATH.exists(), reason="model file not found") | ||
def test_qwen_pipeline(): | ||
check_pipeline( | ||
model_path=QWEN_MODEL_PATH, | ||
tiktoken_path=QWEN_TIKTOKEN_PAHT, | ||
prompt="你好", | ||
target="你好!有什么我可以帮助你的吗?", | ||
) | ||
|
||
@pytest.mark.skipif(not LLAMA_MODEL_PATH.exists(), reason="model file not found") | ||
def test_llama_pipeline(): | ||
check_pipeline( | ||
model_path=LLAMA_MODEL_PATH, | ||
tiktoken_path=LLAMA_TIKTOKEN_PATH, | ||
prompt="hello", | ||
target="Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat? I'm here to assist you with any questions or topics you'd like to discuss.", | ||
) |